728x90
반응형
pre-trained model의 일부분만 바꿔서 모델을 학습시키려고 할 때
pre-trained model의 파라미터 일부를 가져와 현재 모델의 state dict을 업데이트 하여 학습해야한다.
즉, pre-trained model의 파라미터를 로드하고, 현재 모델에 덮어 씌워주는 과정이다.
먼저, 아래와 같이 pre-trained model과 현재 model을 각각 로드해준다.
pretrained_dict = torch.load(PATH) # pretrained 모델 static dict 로드
model_dict = model.state_dict() # 현재 모델 state dict 로드
그 다음 아래와 같이 pre-trained model의 값들 중에서 현재 model의 state dict과 일치하는 값들만 넣어준다.
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
그 다음 현재 모델의 state dict을 pre-trained model의 state dict으로 업데이트 해주고 로드해준다.
model_dict.update(pretrained_dict)
model.load_state_dict(pretrained_dict)
728x90
반응형
'Development & Tools > Frameworks & Libraries' 카테고리의 다른 글
[PyTorch] contiguous 연산의 필요성, Grad strides do not match bucket view strides (1) | 2023.12.04 |
---|---|
[TFLite] TensorFlow Lite 개념 (0) | 2023.01.18 |
[ONNX] Brevitas, QAT 모델을 Standard ONNX 모델로 생성하는 라이브러리 (0) | 2022.07.04 |
[NVIDIA TAO Toolkit] TAO Toolkit 개요 (1) | 2022.05.03 |
[PyTorch] PyTorch 모델을 저장하는 방법 및 고려해야할 점 (2) | 2022.03.23 |