AI Development/PyTorch
[PyTorch] 모델 파라미터 일부만 업데이트 하기
꾸준희
2023. 4. 6. 13:53
728x90
반응형
pre-trained model의 일부분만 바꿔서 모델을 학습시키려고 할 때
pre-trained model의 파라미터 일부를 가져와 현재 모델의 state dict을 업데이트 하여 학습해야한다.
즉, pre-trained model의 파라미터를 로드하고, 현재 모델에 덮어 씌워주는 과정이다.
먼저, 아래와 같이 pre-trained model과 현재 model을 각각 로드해준다.
pretrained_dict = torch.load(PATH) # pretrained 모델 static dict 로드
model_dict = model.state_dict() # 현재 모델 state dict 로드
그 다음 아래와 같이 pre-trained model의 값들 중에서 현재 model의 state dict과 일치하는 값들만 넣어준다.
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
그 다음 현재 모델의 state dict을 pre-trained model의 state dict으로 업데이트 해주고 로드해준다.
model_dict.update(pretrained_dict)
model.load_state_dict(pretrained_dict)
728x90
반응형