AI Development/ONNX

[ONNX] Pytorch 모델을 ONNX 모델로 변환하기

꾸준희 2020. 8. 23. 20:53
728x90
반응형

 

 

ONNX 모델은 여러 다양한 플랫폼과 하드웨어에서 효율적인 추론을 가능하게 한다. 여기서 하드웨어는 리눅스, 윈도우, 맥 뿐만 아니라 여러 CPU, GPU 등의 하드웨어를 뜻한다. 

 

 

 

 

 

ONNX 모델 변환을 위해 필요한 import 문은 다음과 같다. 

# 필요한 import문
import io
import numpy as np

from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx

 

 

 

 

예제 모델은 아래에서 소개된 모델을 기반으로 한다. 

“Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network” - Shi et a

 

Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network

Recently, several models based on deep neural networks have achieved great success in terms of both reconstruction accuracy and computational performance for single image super-resolution. In these methods, the low resolution (LR) input image is upscaled t

arxiv.org

# PyTorch에서 구현된 초해상도 모델
import torch.nn as nn
import torch.nn.init as init

class SuperResolutionNet(nn.Module):
    def __init__(self, upscale_factor, inplace=False):
        super(SuperResolutionNet, self).__init__()

        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x

    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)

# 위에서 정의된 모델을 사용하여 초해상도 모델 생성
torch_model = SuperResolutionNet(upscale_factor=3)

 

 

 

 

 

 

변환하고자 하는 pytorch 의 모델(모델 구조, 가중치)을 준비하고, 아래와 같이 모델을 변환하기 전에 모델을 추론 모드로 바꾸기 위해서 torch_model.eval() 또는 torch_model.train(False) 를 호출하는 것이 필요하다.

# 미리 학습된 가중치를 읽어옵니다
model_url = 'model.pth'
batch_size = 1    # 임의의 수

# 모델을 미리 학습된 가중치로 초기화합니다
map_location = lambda storage, loc: storage
if torch.cuda.is_available():
    map_location = None
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))

# 모델을 추론 모드로 전환합니다
torch_model.eval()

 

 

 

 

 

위와 같이 모델을 변환할 준비가 되었다면 아래와 같이 torch.onnx.export 를 이용하여 변환을 수행한다.

특히 변환 시 random torch 값을 입력 값으로 넣어주는데, 이 텐서의 값은 알맞은 자료형과 모양이라면 랜덤하게 결정되어도 무방하다. 특정 차원을 동적인 차원으로 지정하지 않는 이상, ONNX로 변환된 그래프의 경우 입력값의 사이즈는 모든 차원에 대해 고정된다. 

x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
torch_out = torch_model(x)

# 모델 변환
torch.onnx.export(torch_model,               # 실행될 모델
                  x,                         # 모델 입력값 (튜플 또는 여러 입력값들도 가능)
                  "super_resolution.onnx",   # 모델 저장 경로 (파일 또는 파일과 유사한 객체 모두 가능)
                  export_params=True,        # 모델 파일 안에 학습된 모델 가중치를 저장할지의 여부
                  opset_version=10,          # 모델을 변환할 때 사용할 ONNX 버전
                  do_constant_folding=True,  # 최적하시 상수폴딩을 사용할지의 여부
                  input_names = ['input'],   # 모델의 입력값을 가리키는 이름
                  output_names = ['output'], # 모델의 출력값을 가리키는 이름
                  dynamic_axes={'input' : {0 : 'batch_size'},    # 가변적인 길이를 가진 차원
                                'output' : {0 : 'batch_size'}})

 

 

 

 

 

 

아래와 같이 onnx api 를 이용하여 onnx 모델을 확인 할 수 있다.

import onnx

onnx_model = onnx.load("super_resolution.onnx")
onnx.checker.check_model(onnx_model)

 

 

 

 

또한, 아래 사이트에서 ONNX Visualization 도 가능하다.

https://lutzroeder.github.io/netron/

 

Netron

Version {version} This app uses cookies to report errors and anonymous usage information. Accept Open Model… Download App . . .

lutzroeder.github.io

 

 

 

 

참고자료 1 : https://tutorials.pytorch.kr/advanced/super_resolution_with_onnxruntime.html

 

(선택) PyTorch 모델을 ONNX으로 변환하고 ONNX 런타임에서 실행하기 — PyTorch Tutorials 1.6.0 documentation

Note Click here to download the full example code (선택) PyTorch 모델을 ONNX으로 변환하고 ONNX 런타임에서 실행하기 이 튜토리얼에서는 어떻게 PyTorch에서 정의된 모델을 ONNX 형식으로 변환하고 또 어떻게 그 ��

tutorials.pytorch.kr

참고자료 2 : https://github.com/onnx/onnx

 

onnx/onnx

Open standard for machine learning interoperability - onnx/onnx

github.com

 

728x90
반응형
1 2 3 4 5 6 7 8