AI Development/PyTorch

[PyTorch] 모델 파라미터 일부만 업데이트 하기

꾸준희 2023. 4. 6. 13:53
728x90
반응형

 

 

 

pre-trained model의 일부분만 바꿔서 모델을 학습시키려고 할 때 

pre-trained model의 파라미터 일부를 가져와 현재 모델의 state dict을 업데이트 하여 학습해야한다. 

즉, pre-trained model의 파라미터를 로드하고, 현재 모델에 덮어 씌워주는 과정이다. 

 

 

먼저, 아래와 같이 pre-trained model과 현재 model을 각각 로드해준다. 

pretrained_dict = torch.load(PATH)  # pretrained 모델 static dict 로드 
model_dict = model.state_dict() # 현재 모델 state dict 로드

 

그 다음 아래와 같이 pre-trained model의 값들 중에서 현재 model의 state dict과 일치하는 값들만 넣어준다. 

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

 

그 다음 현재 모델의 state dict을 pre-trained model의 state dict으로 업데이트 해주고 로드해준다. 

model_dict.update(pretrained_dict) 
model.load_state_dict(pretrained_dict)
728x90
반응형