PyTorch에서 학습된 모델을 저장하는 방법
PyTorch에서 학습된 모델을 저장하는 방법
모델 state_dict 저장
가장 간단한 방법은 모델의 state_dict를 저장하는 것입니다. State_dict는 모델의 모든 학습된 매개변수를 포함하는 Python 딕셔너리입니다.
# 모델 state_dict 저장
torch.save(model.state_dict(), "model_state_dict.pth")
장점:
- 간단하고 빠르게 저장할 수 있습니다.
- 모델의 매개변수만 저장하기 때문에 용량이 적게 듭니다.
단점:
- 모델의 아키텍처 정보는 저장되지 않습니다.
- 모델을 불러올 때 동일한 아키텍처의 모델을 만들어야 합니다.
모델 전체 저장
모델 전체를 저장하면 모델의 아키텍처 정보와 학습된 매개변수를 모두 포함하는 하나의 파일로 저장됩니다.
# 모델 전체 저장
torch.save(model, "model.pth")
- 모델 아키텍처 정보와 매개변수를 모두 저장하기 때문에 모델을 불러올 때 별도의 작업이 필요하지 않습니다.
- 모델을 다른 장치나 환경에서 쉽게 배포할 수 있습니다.
- 모델 state_dict를 저장하는 것보다 용량이 더 많이 듭니다.
TorchScript로 저장
TorchScript는 PyTorch 모델을 추론에 최적화된 형식으로 변환하는 도구입니다. TorchScript로 변환된 모델은 C++과 같은 다른 프로그래밍 언어에서도 사용할 수 있습니다.
# 모델을 TorchScript로 변환
model_scripted = torch.jit.trace(model, example_inputs)
# TorchScript 모델 저장
torch.jit.save(model_scripted, "model_scripted.pt")
- 추론 속도를 크게 향상시킬 수 있습니다.
- 다른 프로그래밍 언어에서도 모델을 사용할 수 있습니다.
- 모델 변환 과정이 다소 복잡합니다.
- 모든 모델이 TorchScript로 변환되는 것은 아닙니다.
예제 코드
# 모델 정의
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(10, 10)
self.linear2 = torch.nn.Linear(10, 1)
def forward(self, x):
x = x.view(-1)
x = self.linear1(x)
x = torch.relu(x)
x = self.linear2(x)
return x
# 모델 학습
model = MyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(10):
# ...
# 모델 저장
# 모델 state_dict 저장
torch.save(model.state_dict(), "model_state_dict.pth")
# 모델 전체 저장
torch.save(model, "model.pth")
# TorchScript로 저장
model_scripted = torch.jit.trace(model, example_inputs)
torch.jit.save(model_scripted, "model_scripted.pt")
- 위 코드는 예시이며, 실제 상황에 맞게 수정해야 합니다.
- 모델 학습 코드는 생략되었습니다.
대체 방법
PyTorch 모델을 pickle로 저장하는 방법도 있습니다. Pickle은 Python 객체를 직렬화하는 표준 라이브러리입니다.
# 모델 pickle 저장
import pickle
with open("model.pkl", "wb") as f:
pickle.dump(model, f)
- 매우 간단하게 사용할 수 있습니다.
- 모델 state_dict뿐만 아니라 모델 전체를 저장할 수 있습니다.
- pickle은 안전하지 않을 수 있습니다. 악의적인 코드를 포함하는 pickle 파일을 로드하면 시스템이 손상될 수 있습니다.
- pickle은 모든 Python 객체를 직렬화할 수 있는 것은 아닙니다.
ONNX
ONNX는 Open Neural Network Exchange의 약자로, 다양한 프레임워크에서 사용할 수 있는 모델 형식입니다. PyTorch 모델을 ONNX로 변환하면 다른 프레임워크에서 모델을 추론하는 데 사용할 수 있습니다.
# 모델을 ONNX로 변환
import onnx
torch.onnx.export(model, example_inputs, "model.onnx")
- 다른 프레임워크에서 모델을 추론하는 데 사용할 수 있습니다.
- 모델 크기를 줄일 수 있습니다.
- 모든 모델이 ONNX로 변환되는 것은 아닙니다.
TensorFlow SavedModel
TensorFlow SavedModel은 TensorFlow 모델을 저장하는 표준 형식입니다. PyTorch 모델을 TensorFlow SavedModel로 변환하면 TensorFlow에서 모델을 추론하는 데 사용할 수 있습니다.
# 모델을 TensorFlow SavedModel로 변환
import tensorflow as tf
tf.saved_model.save(model, "model")
- 모든 모델이 TensorFlow SavedModel로 변환되는 것은 아닙니다.
python serialization deep-learning