ONNX 모델은 여러 다양한 플랫폼과 하드웨어에서 효율적인 추론을 가능하게 한다. 여기서 하드웨어는 리눅스, 윈도우, 맥 뿐만 아니라 여러 CPU, GPU 등의 하드웨어를 뜻한다.
ONNX 모델 변환을 위해 필요한 import 문은 다음과 같다.
# 필요한 import문
import io
import numpy as np
from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx
예제 모델은 아래에서 소개된 모델을 기반으로 한다.
# PyTorch에서 구현된 초해상도 모델
import torch.nn as nn
import torch.nn.init as init
class SuperResolutionNet(nn.Module):
def __init__(self, upscale_factor, inplace=False):
super(SuperResolutionNet, self).__init__()
self.relu = nn.ReLU(inplace=inplace)
self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
self._initialize_weights()
def forward(self, x):
x = self.relu(self.conv1(x))
x = self.relu(self.conv2(x))
x = self.relu(self.conv3(x))
x = self.pixel_shuffle(self.conv4(x))
return x
def _initialize_weights(self):
init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
init.orthogonal_(self.conv4.weight)
# 위에서 정의된 모델을 사용하여 초해상도 모델 생성
torch_model = SuperResolutionNet(upscale_factor=3)
변환하고자 하는 pytorch 의 모델(모델 구조, 가중치)을 준비하고, 아래와 같이 모델을 변환하기 전에 모델을 추론 모드로 바꾸기 위해서 torch_model.eval() 또는 torch_model.train(False) 를 호출하는 것이 필요하다.
# 미리 학습된 가중치를 읽어옵니다
model_url = 'model.pth'
batch_size = 1 # 임의의 수
# 모델을 미리 학습된 가중치로 초기화합니다
map_location = lambda storage, loc: storage
if torch.cuda.is_available():
map_location = None
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))
# 모델을 추론 모드로 전환합니다
torch_model.eval()
위와 같이 모델을 변환할 준비가 되었다면 아래와 같이 torch.onnx.export 를 이용하여 변환을 수행한다.
특히 변환 시 random torch 값을 입력 값으로 넣어주는데, 이 텐서의 값은 알맞은 자료형과 모양이라면 랜덤하게 결정되어도 무방하다. 특정 차원을 동적인 차원으로 지정하지 않는 이상, ONNX로 변환된 그래프의 경우 입력값의 사이즈는 모든 차원에 대해 고정된다.
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
torch_out = torch_model(x)
# 모델 변환
torch.onnx.export(torch_model, # 실행될 모델
x, # 모델 입력값 (튜플 또는 여러 입력값들도 가능)
"super_resolution.onnx", # 모델 저장 경로 (파일 또는 파일과 유사한 객체 모두 가능)
export_params=True, # 모델 파일 안에 학습된 모델 가중치를 저장할지의 여부
opset_version=10, # 모델을 변환할 때 사용할 ONNX 버전
do_constant_folding=True, # 최적하시 상수폴딩을 사용할지의 여부
input_names = ['input'], # 모델의 입력값을 가리키는 이름
output_names = ['output'], # 모델의 출력값을 가리키는 이름
dynamic_axes={'input' : {0 : 'batch_size'}, # 가변적인 길이를 가진 차원
'output' : {0 : 'batch_size'}})
아래와 같이 onnx api 를 이용하여 onnx 모델을 확인 할 수 있다.
import onnx
onnx_model = onnx.load("super_resolution.onnx")
onnx.checker.check_model(onnx_model)
또한, 아래 사이트에서 ONNX Visualization 도 가능하다.
https://lutzroeder.github.io/netron/
참고자료 1 : https://tutorials.pytorch.kr/advanced/super_resolution_with_onnxruntime.html
참고자료 2 : https://github.com/onnx/onnx
'AI Development > ONNX' 카테고리의 다른 글
[ONNX] Pytorch 모델을 ONNX 모델로 변환 할 때 dynamic_axes 지정하는 방법 (0) | 2021.06.29 |
---|---|
[ONNX] ONNX 배치 사이즈 변경하는 방법 + 삽질 (3) | 2021.02.02 |
[ONNX] onnx-graphsurgeon 이용하여 plugin 사용하기 - Group Normalization (3) | 2020.07.21 |
[ONNX] Netron : ONNX model Visualization (0) | 2020.07.16 |
[Onnx] Onnx Tutorials (정리중) (0) | 2020.05.14 |