PyTorch에서 발생하는 KeyError: "unexpected key "module.encoder.embedding.weight" in state_dict" 오류 해결
PyTorch에서 발생하는 KeyError: 'unexpected key "module.encoder.embedding.weight" in state_dict' 오류 해결
PyTorch 모델을 학습 후 저장하고 다시 불러올 때 다음과 같은 오류가 발생할 수 있습니다.
KeyError: 'unexpected key "module.encoder.embedding.weight" in state_dict'
원인:
이 오류는 모델 저장 시 nn.DataParallel
을 사용했지만, 불러올 때는 사용하지 않아 발생합니다. nn.DataParallel
은 모델을 여러 GPU에 분산하여 학습시키는 데 사용되는 모듈입니다. 모델을 저장할 때 nn.DataParallel
을 사용하면 모델 이름 앞에 module.
이라는 접두사가 추가됩니다. 하지만 모델을 불러올 때 nn.DataParallel
을 사용하지 않으면 모델 이름에 module.
접두사가 없어 오류가 발생합니다.
해결 방법:
nn.DataParallel
사용:
모델을 불러올 때도 nn.DataParallel
을 사용하면 오류 없이 모델을 불러올 수 있습니다. 다음 코드와 같이 모델을 감싸줍니다.
model = nn.DataParallel(model)
model.load_state_dict(torch.load(path))
module.
접두사 제거:
모델을 불러올 때 nn.DataParallel
을 사용하지 않을 경우, 저장된 모델 state_dict에서 module.
접두사를 제거해야 합니다. 다음 코드와 같이 collections.OrderedDict
를 사용하여 접두사를 제거할 수 있습니다.
state_dict = torch.load(path)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
strict=False
사용:
model.load_state_dict
함수의 strict
매개변수를 False
로 설정하면 모델 이름에 module.
접두사가 없더라도 오류 없이 모델을 불러올 수 있습니다. 하지만, 모델 구조가 변경되었거나 일부 키가 누락된 경우 예상치 못한 결과가 발생할 수 있습니다.
model.load_state_dict(torch.load(path), strict=False)
주의 사항:
nn.DataParallel
을 사용하지 않고 모델을 불러올 경우, 모델 이름에module.
접두사가 있는지 확인해야 합니다.collections.OrderedDict
를 사용하여module.
접두사를 제거하는 방법은 모델 구조가 변경되지 않았을 때만 사용할 수 있습니다.strict=False
를 사용하면 모델 구조가 변경되었거나 일부 키가 누락된 경우 예상치 못한 결과가 발생할 수 있습니다.
예제 코드
import torch
from collections import OrderedDict
# 모델 정의
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.embedding = torch.nn.Embedding(10, 10)
# 모델 학습
model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(10):
# ...
# 모델 저장
torch.save(model.state_dict(), "model.pth")
# 모델 불러오기 (nn.DataParallel 사용)
model = MyModel()
model = nn.DataParallel(model)
model.load_state_dict(torch.load("model.pth"))
# 모델 불러오기 (module. 접두사 제거)
model = MyModel()
state_dict = torch.load("model.pth")
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
# 모델 불러오기 (strict=False 사용)
model = MyModel()
model.load_state_dict(torch.load("model.pth"), strict=False)
- 이 코드는 예시이며, 실제 코드는 상황에 따라 변경될 수 있습니다.
- 모델 학습 및 저장 코드는 생략되었습니다.
KeyError: 'unexpected key "module.encoder.embedding.weight" in state_dict' 오류 해결 방법
모델 이름 변경:
모델을 저장하기 전에 모델 이름을 module
로 변경합니다. 다음 코드와 같이 setattr
함수를 사용하여 모델 이름을 변경할 수 있습니다.
setattr(model, "module", model)
torch.save(model.state_dict(), "model.pth")
모델을 불러올 때는 getattr
함수를 사용하여 모델 이름을 원래 이름으로 되돌립니다.
model = torch.load("model.pth")
model = getattr(model, "module")
torch.jit.load 사용:
torch.jit.load
함수를 사용하여 모델을 불러오면 module.
접두사가 없더라도 오류 없이 모델을 불러올 수 있습니다. 하지만, 모델을 불러온 후에는 모델을 수정할 수 없습니다.
model = torch.jit.load("model.pth")
pickle 사용:
pickle
모듈을 사용하여 모델을 저장하고 불러올 수 있습니다. pickle
은 모델을 직렬화하여 저장하기 때문에 모델 이름에 대한 정보가 유지되지 않습니다.
import pickle
with open("model.pkl", "wb") as f:
pickle.dump(model, f)
with open("model.pkl", "rb") as f:
model = pickle.load(f)
- 모델 이름을 변경하는 방법은 모델 구조가 변경되지 않았을 때만 사용할 수 있습니다.
torch.jit.load
를 사용하면 모델을 수정할 수 없습니다.pickle
은 모델을 직렬화하여 저장하기 때문에 모델 구조가 변경되거나 일부 키가 누락된 경우 예상치 못한 결과가 발생할 수 있습니다.
pytorch