PyTorch에서 reshape와 view의 차이점
PyTorch에서 reshape와 view의 차이점
기본적인 동작
- reshape: 텐서의 크기와 모양을 새로운 크기로 변경합니다.
- view: 텐서의 크기와 모양을 새로운 크기로 변경하는 듯 보이지만, 실제로는 기존 데이터와 같은 메모리 공간을 공유하며 stride 크기만 변경합니다.
주요 차이점
- contiguous 속성:
- view: 텐서가 contiguous해야만 작동합니다.
- reshape: 텐서가 contiguous 하지 않더라도 작동합니다.
- 데이터 복사:
- view: 텐서가 contiguous 하면 데이터를 복사하지 않습니다.
예시
import torch
# 3x3 텐서 생성
x = torch.arange(9).view(3, 3)
# reshape: 텐서 크기 변경 (3x3 -> 1x9)
y = x.reshape(1, 9)
# view: 텐서 크기 변경 (3x3 -> 9x1)
z = x.view(9, 1)
# 확인
print(x)
print(y)
print(z)
# 텐서 메모리 주소 확인
print(x.data_ptr() == y.data_ptr()) # False
print(x.data_ptr() == z.data_ptr()) # True
결론
- view: 텐서 크기 변경 + 메모리 공간 공유 (contiguous 조건 충족)
- reshape: 텐서 크기 변경 + 데이터 복사 (contiguous 조건 상관 없음)
예제 코드
import torch
# 3x3 텐서 생성
x = torch.arange(9).view(3, 3)
# reshape: 텐서 크기 변경 (3x3 -> 1x9)
y = x.reshape(1, 9)
# view: 텐서 크기 변경 (3x3 -> 9x1)
z = x.view(9, 1)
# 확인
print(x)
print(y)
print(z)
# 텐서 메모리 주소 확인
print(x.data_ptr() == y.data_ptr()) # False
print(x.data_ptr() == z.data_ptr()) # True
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8]])
tensor([[0],
[1],
[2],
[3],
[4],
[5],
[6],
[7],
[8]])
False
True
설명
x
는 3x3 크기의 텐서입니다.y
는x
를 reshape하여 1x9 크기의 텐서로 변환합니다.data_ptr()
비교 결과를 통해x
와y
는 서로 다른 메모리 공간을 사용한다는 것을 확인할 수 있습니다.
추가 예시
- 텐서의 크기 변경과 동시에 차원 순서 변경:
x = torch.arange(9).view(3, 3)
y = x.view(3, 1, 3) # 3x3 텐서를 3x1x3 텐서로 변환 (차원 순서 변경)
print(x)
print(y)
print(x.data_ptr() == y.data_ptr()) # True
- 텐서의 크기 변경 없이 stride만 변경:
x = torch.arange(9).view(3, 3)
y = x.view(-1) # 텐서 크기 변경 없이 stride만 변경 (1차원 텐서로 변환)
print(x)
print(y)
print(x.data_ptr() == y.data_ptr()) # True
reshape와 view의 대체 방법
슬라이싱
- 특정 차원의 일부 요소만 선택하여 텐서 크기를 줄일 때 유용합니다.
x = torch.arange(9).view(3, 3)
y = x[:, 0] # 첫 번째 열만 선택 (3x1 텐서)
print(x)
print(y)
인덱싱
x = torch.arange(9).view(3, 3)
y = x[1, :] # 두 번째 행만 선택 (1x3 텐서)
print(x)
print(y)
concatenate
- 여러 텐서를 연결하여 텐서 크기를 늘릴 때 유용합니다.
x = torch.arange(3)
y = torch.arange(3, 6)
z = torch.cat((x, y), dim=0) # x와 y를 0차원 기준으로 연결 (6x1 텐서)
print(x)
print(y)
print(z)
stack
x = torch.arange(3)
y = torch.arange(3, 6)
z = torch.stack((x, y), dim=0) # x와 y를 0차원 기준으로 차원 추가 (2x3 텐서)
print(x)
print(y)
print(z)
transpose
- 텐서의 차원 순서를 변경할 때 유용합니다.
x = torch.arange(9).view(3, 3)
y = x.transpose(0, 1) # 0차원과 1차원 순서 변경 (3x3 텐서)
print(x)
print(y)
permute
x = torch.arange(9).view(3, 3)
y = x.permute(1, 0, 2) # 1차원, 0차원, 2차원 순서 변경 (3x3 텐서)
print(x)
print(y)
unfold
- 텐서를 윈도우 형태로 분할하여 새로운 텐서를 생성할 때 유용합니다.
x = torch.arange(9).view(3, 3)
y = x.unfold(0, 2, 1) # 0차원 기준으로 윈도우 크기 2x1로 분할 (2x3 텐서)
print(x)
print(y)
fold
- 윈도우 형태로 분할된 텐서를 다시 하나의 텐서로 합칠 때 유용합니다.
x = torch.arange(9).view(3, 3)
y = x.unfold(0, 2, 1)
z = y.fold(0, 2, 1) # 0차원 기준으로 윈도우 크기 2x1로 합
print(x)
print(y)
print(z)
squeeze
- 텐서의 차원 중 크기가 1인 차원을 제거할 때 유용합니다.
x = torch.arange(9).view(3, 1, 3)
y = x.squeeze(1) # 크기가 1인 1차원 제거 (3x3 텐서)
print(x)
print(y)
unsqueeze
x = torch.arange(9).view(3, 3)
y = x.unsqueeze(
python pytorch