Pytorch Tensor shaping (reshape, squeeze, unsqueeze)
Pytorch shaping
-
텐서의 shape를 변경할 수 있다.
-
reshape
,squeeze
,unsqueeze
함수를 사용할 수 있다.
[코드]
import torch
x = torch.FloatTensor([[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]],
[[9, 10],
[11, 12]]])
print(x.size())
[결과]
torch.Size([3, 2, 2])
1. reshape 함수
- -1은 “나머지 숫자는 알아서 계산해줘”라는 의미로 사용 가능하다.
[코드]
#----- 두 연산의 결과 값은 같음
print(x.reshape(12)) # 3*2*2 = 12
print(x.reshape(-1))
[결과]
tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.])
tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.])
[코드]
#----- 두 연산의 결과 값은 같음
print(x.reshape(3, 4)) # 3*4 = 12
print(x.reshape(3, -1))
[결과]
tensor([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]])
tensor([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]])
[코드]
print(x.reshape(3, 1, 4))
print(x.reshape(-1, 1, 4))
[결과]
tensor([[[ 1., 2., 3., 4.]],
[[ 5., 6., 7., 8.]],
[[ 9., 10., 11., 12.]]])
tensor([[[ 1., 2., 3., 4.]],
[[ 5., 6., 7., 8.]],
[[ 9., 10., 11., 12.]]])
2. squeeze 함수
- 오직 1개의 element를 가지고 있는 dimension을 제거할 수 있다.
- 또한 제거할 dimensiond을 선택할 수 있다.
[코드]
x = torch.FloatTensor([[[1, 2],
[3, 4]]])
print(x.size())
[결과]
torch.Size([1, 2, 2])
[코드]
print(x.squeeze())
print(x.squeeze().size())
[결과]
tensor([[1., 2.],
[3., 4.]])
torch.Size([2, 2])
- 제거할 dim을 선택할 수 있으며, 만약 선택한 dim이 element를 한 개 가지고 있는 것이 아니면, 아무 일도 일어나지 않는다.
[코드]
print(x.squeeze(0), x.squeeze(0).size())
print()
print(x.squeeze(1).size()) # 아무 변화도 일어나지 않음
[결과]
tensor([[1., 2.],
[3., 4.]]) torch.Size([2, 2])
torch.Size([1, 2, 2])
3. unsqueeze 함수
- 원하는 dim 자리에 dim을 추가할 수 있다.
[코드]
x = torch.FloatTensor([[1, 2],
[3, 4]])
print(x.size())
[결과]
torch.Size([2, 2])
[코드]
z = x.unsqueeze(0)
print(z, z.size())
print("-------------------------------------------\n")
z = x.unsqueeze(-1)
print(z, z.size())
[결과]
tensor([[[1., 2.],
[3., 4.]]]) torch.Size([1, 2, 2])
-------------------------------------------
tensor([[[1.],
[2.]],
[[3.],
[4.]]]) torch.Size([2, 2, 1])
이 포스팅은 패스트캠퍼스 김기현의 딥러닝 유치원 강의 기반으로 작성되었다.