AI Development/PyTorch

[PyTorch] PyTorch 모델을 저장하는 방법 및 고려해야할 점

꾸준희 2022. 3. 23. 19:07
728x90
반응형

 

PyTorch 모델을 학습 한 뒤, 모델을 저장하고 불러오는 방법은 다음과 같다. 

 

모델 저장하고 불러오기 

다음과 같이 PyTorch 모델은 학습한 매개변수를 state_dict 이라고 불리는 internal state dictionary에 저장한다. 이 state 값들은 torch.save 함수를 이용하여 저장 할 수 있다고 한다. 모델의 가중치를 불러와서 저장하려면 저장하려는 모델의 인스턴스를 생성한 다음 load_state_dict() 함수를 사용하여 매개변수를 불러온다. 참고로 state_dict은 dictionary 이며 이 형태에 맞게 데이터를 저장하거나 불러오는 것이 가능하다. 이는 각 계층을 매개변수 Tensor로 매핑하며 학습 가능한 매개변수를 갖는 계층(conv layer, linear layer, ...) 등이 모델의 state_dict에 항목을 가지게 된다.

 

여기서 중요한 점은 모델을 추론하기 전에 model.eval() 함수를 호출하여 dropout과 batch normalization을 evaluation mode로 설정해야한다. 그렇지 않을 경우 일관성 없는 추론 결과가 생성된다. 

 

모델을 저장하고 불러오는 방법에는 2가지가 방법이 있다. 저장한 방식에 따라 불러오는 방식도 달라진다. 참고로 PyTorch 에서는 모델을 저장할 때 .pt 또는 .pth 확장자를 사용한다. 그리고 tar 를 통한 압축 형태로 *.pth.tar 와 같이 많이 사용한다. 

 

 

1. 모델의 형태를 포함하여 저장하는 방법

torch.save(model, 'model.pth')
torch.load('model.pth')

 

 

2. 학습된 모델의 매개변수(state_dict)만 저장하는 방법

torch.save(model.state_dict(), 'model.pth')
model.load_state_dict(torch.load('model.pth'))

 

 

 

다음은 모델을 저장하는 예시 코드이다. 이는 모델의 매개변수만 저장하는 후자의 방식을 사용하여 저장한다. 보통 모델 저장 시 모델의 형태를 포함하여 저장하는 전자의 경우 보다는 state_dict()을 저장하는 방식을 권장한다고 한다. 

 

import torch
import torchvision.models as models

model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')

model = models.vgg16() # 기본 가중치를 불러오지 않으므로 pretrained=True를 지정하지 않음
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

 

참고로 PyTorch 1.6 버전에서는 new zipfile-based file format을 사용하도록 torch.save 함수가 변경되었다고 한다. torch.load는 예전 방식의 파일들을 읽어올 수 있다고 한다. 만약 torch.save 함수를 예전 방식대로 사용하고 싶다면 _use_new_zipfile_serialization=False 이라는 옵션을 주면 된다고 한다. 

 

 

 


 

 

모델을 저장할 때 static_dict()을 이용하여 모델을 저장하는 방식이 권장되는 이유는 전체 모델 형태를 포함하여 저장하게 되면 모델의 파라미터 뿐만 아니라 Optimizer, Epoch 등 모든 상태를 저장하는 것이다. 이는 당연히 저장되는 모델의 크기가 커질 수 있다. 경험 상 2배~3배 정도 차이나는 듯 하다. 

 

이러한 방법으로 저장되는 모델은 pickle 형태로 저장되는데, 이는 모델 클래스 자체를 저장하는 것이 아니라 클래스를 포함하는 파일에 대한 경로를 저장한다. 이 경로는 로드 시점에 사용되며, 이로 인해 이 모델을 다른 프로젝트에서 사용하게 될 경우 정상적으로 모델이 로드가 안될 수도 있다고 한다. 간단히 말하면 serialization을 제대로 하지 못할 수도 있게 된다. 

 

또한 어떤 블로그에 의하면 모델 파라미터를 serialization 하여 저장 할 때 PyTorch 버전에 따라 구조가 바뀔 수도 있다고 한다. 생각해보면 PyTorch 버전에 따라 모델에 포함되는 모듈들이 조금씩 다르게 구현되는 경우가 있는데 모델 구조를 그대로 포함하여 Serialization 할 경우 다른 버전에서 사용하게 된다면 문제가 생길 것 같다. 

 

반면 state_dict 을 사용하는 방식은 매개변수가 담겨있는 딕셔너리이며, weight 와 bias 가 포함되어 있다. 이는 모델의 형태가 저장되어 있지 않기 때문에 모델 구현 즉 모델 클래스가 코드 상에 존재하고, 이를 로드해야 사용할 수 있다. state_dict 만 저장하면 모델의 용량이 가벼워진다는 장점이 있다. 

 

 

 

 

 

 

 

PyTorch 모델 저장 관련 참고자료 

 

참고자료 1 : 모델 저장하기 & 불러오기

참고자료 2 : 모델 저장하고 불러오기

참고자료 3 : PYTORCH에서 일반적인 체크포인트(CHECKPOINT) 저장하기 & 불러오기

참고자료 4 : PYTORCH에서 여러 모델을 하나의 파일에 저장하기 & 불러오기

참고자료 5 : https://jdjin3000.tistory.com/17 

참고자료 6 : https://stackoverflow.com/questions/42703500/best-way-to-save-a-trained-model-in-pytorch

참고자료 7 : https://gaussian37.github.io/dl-pytorch-snippets/

 

 

 

 

 

 

 

728x90
반응형