728x90
반응형
pre-trained model의 일부분만 바꿔서 모델을 학습시키려고 할 때
pre-trained model의 파라미터 일부를 가져와 현재 모델의 state dict을 업데이트 하여 학습해야한다.
즉, pre-trained model의 파라미터를 로드하고, 현재 모델에 덮어 씌워주는 과정이다.
먼저, 아래와 같이 pre-trained model과 현재 model을 각각 로드해준다.
pretrained_dict = torch.load(PATH) # pretrained 모델 static dict 로드
model_dict = model.state_dict() # 현재 모델 state dict 로드
그 다음 아래와 같이 pre-trained model의 값들 중에서 현재 model의 state dict과 일치하는 값들만 넣어준다.
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
그 다음 현재 모델의 state dict을 pre-trained model의 state dict으로 업데이트 해주고 로드해준다.
model_dict.update(pretrained_dict)
model.load_state_dict(pretrained_dict)
728x90
반응형
'AI Development > PyTorch' 카테고리의 다른 글
[PyTorch] contiguous 연산의 필요성, Grad strides do not match bucket view strides (1) | 2023.12.04 |
---|---|
[PyTorch] PyTorch 모델을 저장하는 방법 및 고려해야할 점 (2) | 2022.03.23 |
[PyTorch] 파이토치에서 TensorBoard 사용하기 (3) | 2022.02.04 |
[Pytorch] 이미지 데이터세트에 대한 평균(mean)과 표준편차(std) 구하기 (0) | 2021.05.11 |
[Pytorch] 파이토치 시간 측정, How to measure time in PyTorch (5) | 2020.07.27 |