PyTorch에서 .flatten()과 .view(-1)의 차이점
PyTorch에서 .flatten()과 .view(-1)의 차이점
작동 방식
.flatten(start_dim=d, end_dim=-1)
: 주어진 차원(d)부터 마지막 차원까지 텐서를 단일 차원으로 펼칩니다..view(-1)
: 텐서를 단일 차원으로 펼칩니다.-1
은 텐서의 모든 요소를 하나의 차원으로 결합하도록 PyTorch에 지시합니다.
예시
import torch
# 3차원 텐서 생성
x = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
# .flatten() 사용
y = x.flatten()
print(y)
# 출력:
# tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
# .view(-1) 사용
z = x.view(-1)
print(z)
# 출력:
# tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
주요 차이점
.flatten()
:- 시작 차원을 지정할 수 있습니다.
- 메모리 레이아웃을 변경합니다.
- 원래 텐서의 크기 정보를 유지하지 않습니다.
어떤 것을 사용해야 할까요?
- 텐서를 단순히 펼치고 싶을 때는
.view(-1)
을 사용하는 것이 좋습니다. - 텐서를 펼치고 특정 차원부터 작업하고 싶을 때는
.flatten()
을 사용하는 것이 좋습니다. - 원래 텐서의 크기 정보를 유지해야 하는 경우
.view(-1)
을 사용해야 합니다.
예제 코드
import torch
# 3차원 텐서 생성
x = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
# .flatten() 사용
y = x.flatten()
print(y)
# 출력:
# tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
# .view(-1) 사용
z = x.view(-1)
print(z)
# 출력:
# tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
# 2차원으로 펼치기
y = x.flatten(start_dim=1)
print(y)
# 출력:
# tensor([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9],
# [10, 11, 12]])
# 원래 텐서 크기 정보 확인
print(x.size())
# 출력:
# torch.Size([2, 2, 3])
# .flatten() 사용 후 크기 정보 확인
print(y.size())
# 출력:
# torch.Size([8])
# .view(-1) 사용 후 크기 정보 확인
print(z.size())
# 출력:
# torch.Size([12])
설명
- 첫 번째 코드 블록은
.flatten()
과.view(-1)
을 사용하여 3차원 텐서를 1차원으로 펼치는 방법을 보여줍니다. - 세 번째 코드 블록은
.flatten()
과.view(-1)
사용 후 원래 텐서의 크기 정보가 어떻게 변하는지 보여줍니다.
대체 방법
.reshape()
.reshape()
메서드는 텐서의 크기를 원하는 형태로 변경하는 데 사용할 수 있습니다.
x = torch.tensor([1, 2, 3, 4, 5, 6])
# 3행 2열로 변환
y = x.reshape(3, 2)
print(y)
# 출력:
# tensor([[1, 2],
# [3, 4],
# [5, 6]])
for 루프
간단한 경우 for 루프를 사용하여 텐서를 펼칠 수 있습니다.
x = torch.tensor([1, 2, 3, 4, 5, 6])
# 1차원 리스트로 변환
y = []
for i in range(x.size(0)):
y.append(x[i])
print(y)
# 출력:
# [1, 2, 3, 4, 5, 6]
.tolist()
.tolist()
메서드는 텐서를 Python 리스트로 변환합니다.
x = torch.tensor([1, 2, 3, 4, 5, 6])
# 리스트로 변환
y = x.tolist()
print(y)
# 출력:
# [1, 2, 3, 4, 5, 6]
.squeeze()
.squeeze()
메서드는 텐서에서 차원을 제거하는 데 사용할 수 있습니다.
x = torch.tensor([[[1, 2, 3], [4, 5, 6]]])
# 3차원 텐서를 2차원 텐서로 변환
y = x.squeeze()
print(y)
# 출력:
# tensor([[1, 2, 3],
# [4, 5, 6]])
x = torch.tensor([1, 2, 3, 4, 5, 6])
# 1차원 텐서를 2차원 텐서로 변환
y = x.unsqueeze(dim=1)
print(y)
# 출력:
# tensor([[1],
# [2],
# [3],
# [4],
# [5],
# [6]])
선택 가이드
어떤 방법을 사용할지는 상황에 따라 다릅니다.
- for 루프는 간단한 경우에 유용하지만, 속도가 느릴 수 있습니다.
.squeeze()
는 불필요한 차원을 제거할 때 유용합니다.
python machine-learning deep-learning