PyTorch에서 모델 요약 출력 방법
torchsummary 사용
torchsummary
는 모델 요약을 출력하는 간단한 라이브러리입니다.
설치:
pip install torchsummary
사용:
from torchsummary import summary
model = ... # 모델 정의
summary(model)
summary
함수는 모델 구조, 각 레이어의 입력/출력 크기, 매개변수 수 등을 출력합니다.
Model.summary() 사용
PyTorch 1.8 이상 버전을 사용하면 Model.summary()
메서드를 사용하여 모델 요약을 출력할 수 있습니다.
model = ... # 모델 정의
model.summary()
summary()
메서드는 torchsummary
라이브러리와 비슷한 정보를 출력합니다.
예시
다음 예시는 torchsummary
라이브러리를 사용하여 모델 요약을 출력하는 방법을 보여줍니다.
import torch
from torchsummary import summary
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 32, 3, 1, 1)
self.conv2 = torch.nn.Conv2d(32, 64, 3, 1, 1)
self.fc1 = torch.nn.Linear(64 * 10 * 10, 10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = x.view(-1, 64 * 10 * 10)
x = self.fc1(x)
return x
model = MyModel()
summary(model)
출력:
----------------------------------------------------------------
Layer (type) Input Output
----------------------------------------------------------------
Conv2d-1 (1, 28, 28) (32, 28, 28)
ReLU-2 (32, 28, 28) (32, 28, 28)
Conv2d-3 (32, 28, 28) (64, 26, 26)
ReLU-4 (64, 26, 26) (64, 26, 26)
Flatten-5 (64, 26, 26) (6400,)
Linear-6 (6400,) (10,)
================================================================
Total params: 1,089,610
Trainable params: 1,089,610
Non-trainable params: 0
----------------------------------------------------------------
예제 코드
import torch
from torchsummary import summary
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 32, 3, 1, 1)
self.conv2 = torch.nn.Conv2d(32, 64, 3, 1, 1)
self.fc1 = torch.nn.Linear(64 * 10 * 10, 10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = x.view(-1, 64 * 10 * 10)
x = self.fc1(x)
return x
model = MyModel()
summary(model)
----------------------------------------------------------------
Layer (type) Input Output
----------------------------------------------------------------
Conv2d-1 (1, 28, 28) (32, 28, 28)
ReLU-2 (32, 28, 28) (32, 28, 28)
Conv2d-3 (32, 28, 28) (64, 26, 26)
ReLU-4 (64, 26, 26) (64, 26, 26)
Flatten-5 (64, 26, 26) (6400,)
Linear-6 (6400,) (10,)
================================================================
Total params: 1,089,610
Trainable params: 1,089,610
Non-trainable params: 0
----------------------------------------------------------------
실행 방법
- Python 3.6 이상 버전을 설치합니다.
torch
와torchsummary
라이브러리를 설치합니다.
pip install torch
pip install torchsummary
- 예제 코드를 저장합니다.
- 코드를 실행합니다.
python example.py
출력
코드를 실행하면 다음과 같은 출력이 나타납니다.
----------------------------------------------------------------
Layer (type) Input Output
----------------------------------------------------------------
Conv2d-1 (1, 28, 28) (32, 28, 28)
ReLU-2 (32, 28, 28) (32, 28, 28)
Conv2d-3 (32, 28, 28) (64, 26, 26)
ReLU-4 (64, 26, 26) (64, 26, 26)
Flatten-5 (64, 26, 26) (6400,)
Linear-6 (6400,) (10,)
================================================================
Total params: 1,089,610
Trainable params: 1,089,610
Non-trainable params: 0
----------------------------------------------------------------
설명
출력은 다음과 같이 구성됩니다.
- Layer (type): 레이어 이름과 유형
- Input: 레이어 입력 크기
- Total params: 모델의 총 매개변수 수
- Trainable params: 학습 가능한 매개변수 수
PyTorch 모델 요약 출력 대체 방법
print() 함수 사용
가장 간단한 방법은 print()
함수를 사용하는 것입니다.
model = ... # 모델 정의
print(model)
print()
함수는 모델 구조와 각 레이어의 매개변수 수를 출력합니다. 하지만 torchsummary
라이브러리만큼 정보가 상세하지는 않습니다.
Model.summary() 메서드 사용
model = ... # 모델 정의
model.summary()
그래프 시각화
torchviz
라이브러리를 사용하여 모델 구조를 그래프로 시각화할 수 있습니다.
from torchviz import make_dot
model = ... # 모델 정의
dot = make_dot(model)
dot.render('model.png', format='png')
make_dot()
함수는 모델 구조를 그래프 객체로 생성하고, render()
함수는 그래프 객체를 PNG 파일로 저장합니다.
비교
방법 | 장점 | 단점 |
---|---|---|
torchsummary | 정보가 상세하고 보기 편함 | 별도의 라이브러리 설치 필요 |
print() | 간단하고 빠름 | 정보가 상세하지 않음 |
Model.summary() | torchsummary 와 비슷하지만 PyTorch 1.8 이상 버전만 지원 | 별도의 라이브러리 설치 필요 없음 |
torchviz | 모델 구조를 시각적으로 확인할 수 있음 | 코드 작성량이 더 많음 |
선택 가이드
- 정보가 상세하고 보기 편한 요약 출력을 원한다면
torchsummary
라이브러리를 사용하는 것이 좋습니다. - 간단하고 빠르게 요약 출력을 원한다면
print()
함수를 사용하십시오. - PyTorch 1.8 이상 버전을 사용하고 있으며
torchsummary
라이브러리를 설치하지 않고 싶다면Model.summary()
메서드를 사용하십시오. - 모델 구조를 시각적으로 확인하고 싶다면
torchviz
라이브러리를 사용하십시오.
python machine-learning deep-learning