torch.squeeze()
t = torch.randn(1, 2, 1, 2, 3, 1) # t의 모양: [1, 2, 1, 2, 3, 1]
# dim=1 차원의 크기는 2이므로, 이 차원은 제거되지않는다.
s_t = torch.squeeze(t, dim=1)
print("s_t.shape: {}".format(s_t.shape)) # t의 모양: [1, 2, 1, 2, 3, 1]
# dim=0 차원의 크기는 1이므로, 이 차원은 제거된다.
s_t.squeeze_(0)
print("s_t.shape: {}".format(s_t.shape)) # t의 모양: [2, 1, 2, 3, 1]
# dim이 1인 모든 차원을 축소한다.
s_t.squeeze_()
print("s_t.shape: {}".format(s_t.shape)) # t의 모양: [2, 2, 3]