[Deep Learning] Pre-trained model로 gray image를 학습하는 방법
보통 pre-trained model(ResNet 50, ...)의 경우 RGB 3채널을 가지는 color image 를 기반으로 학습을 진행하기 때문에 모델의 입력 정보는 (height, weight, channel=3) 으로 이루어지게 된다. 따라서 이러한 모델에 임의로 1채널을 가지는 gray image를 입력할 경우 shape error 가 발생하게 된다. 이는 단순히 모델의 첫 레이어의 채널을 변경한다고해서 해결되지 않는다. 정확히 말하면 에러는 해결되지만 학습 진행이 안된다.
따라서 pre-trained model로 gray image를 학습시키는 법은 다음과 같다.
1. 첫번째 conv layer 의 채널을 1로 변경하기
2. pretrained weight load 시 1채널로 변경된 첫번째 conv layer의 weight를 하나로 sum 시켜준다.
단, 이렇게 변경하여 학습 진행 후, 저장된 모델을 다시 로드하여 fine-tuning 진행 시
모델의 채널을 기존 3채널에서 1채널로 바꾸어주어야한다.
아래 코드를 참고하면 내용 이해가 잘 될 것이다. 아래 코드는 fastai의 코드이다.
def _load_pretrained_weights(new_layer, previous_layer):
"Load pretrained weights based on number of input channels"
n_in = getattr(new_layer, 'in_channels')
if n_in==1:
# we take the sum
new_layer.weight.data = previous_layer.weight.data.sum(dim=1, keepdim=True)
elif n_in==2:
# we take first 2 channels + 50%
new_layer.weight.data = previous_layer.weight.data[:,:2] * 1.5
else:
# keep 3 channels weights and set others to null
new_layer.weight.data[:,:3] = previous_layer.weight.data
new_layer.weight.data[:,3:].zero_()
def _update_first_layer(model, n_in, pretrained):
"Change first layer based on number of input channels"
if n_in == 3: return
first_layer, parent, name = _get_first_layer(model)
assert isinstance(first_layer, nn.Conv2d), f'Change of input channels only supported with Conv2d, found {first_layer.__class__.__name__}'
assert getattr(first_layer, 'in_channels') == 3, f'Unexpected number of input channels, found {getattr(first_layer, "in_channels")} while expecting 3'
params = {attr:getattr(first_layer, attr) for attr in 'out_channels kernel_size stride padding dilation groups padding_mode'.split()}
params['bias'] = getattr(first_layer, 'bias') is not None
params['in_channels'] = n_in
new_layer = nn.Conv2d(**params)
if pretrained:
_load_pretrained_weights(new_layer, first_layer)
setattr(parent, name, new_layer)
참고로 stackoverflow 에 있는 질문이 가장 많은 도움이 되었다.
성능은 RGB image > Gray image(3ch) > Gray image(1ch) 이라고 한다.
Unmodified ResNet50 w/ RGB Images : Prec @1: 75.6, Prec @5: 92.8
Unmodified ResNet50 w/ 3-chan Grayscale Images: Prec @1: 64.6, Prec @5: 86.4
Modified 1-chan ResNet50 w/ 1-chan Grayscale Images: Prec @1: 63.8, Prec @5: 86.1
이 외에도 gray image를 RGB 처럼 3채널로 만드는 방법이 있다. 이는 단순 1채널 이미지를 3채널 형식에 복사하는 형식이다. 또한 누군가가 gray image로 학습시켜놓은 pre-trained weight를 사용하는 방법도 있다. Github에 검색해보면 몇 개 나오는 것 같다.
참고자료 1 : https://monghead.blogspot.com/2020/12/error-deep-learning-pretrained-model.html
참고자료 2 : https://github.com/zhaoyuzhi/PyTorch-Special-Pre-trained-Models
참고자료 3 : https://github.com/fastai/fastai/
'AI Research Topic > Deep Learning' 카테고리의 다른 글
[Paper Reveiw] NMS Strikes Back (0) | 2023.02.22 |
---|---|
[Paper Review] Attention Mechanisms in Computer Vision, A Survey (0) | 2022.07.05 |
[Paper Review] ResNet strikes back: An improved training procedure in timm (0) | 2022.03.30 |
[Deep Learning] Weight Standardization (+ 2D, 3D 구현 방법) (0) | 2020.12.20 |
[Paper Review] DCNv2 : Deformable Convolutional Networks v2 (3) | 2020.11.01 |