AI Development/ONNX

[ONNX] Pytorch 에서 Onnx 로 변환 시 Gather op 때문에 export 안되는 문제

꾸준희 2020. 4. 28. 11:00
728x90
반응형

Pytorch 에서 Onnx 모델로 변환시 Gather 와 같은 옵션 때문에 변환이 안되는 문제가 발생한다. 

이유는 Onnx 에서 Gather 라는 옵션을 지원하지 않기 때문이다. (2020.04.28 기준)

 

아래와 같이 interpolate 를 scale factor 로 구현하였을 때 

class ResizeModel(nn.Module):
    def __init__(self):
        super(ResizeModel, self).__init__()
    def forward(self, x):
        return F.interpolate(x, scale_factor=(2, 2), mode='nearest')

생성되는 그래프는 아래와 같다

 

 

 

위와 같이 구현할 경우 "Gather" 라는 op 이 붙어서 그래프가 생성되는데, 

Onnx 에서는 Gather 라는 옵션을 지원하지 않기 때문에 (2020.04.28 기준)

아래와 같이 interpolate의 scale factor 를 구현해야한다.

class ResizeModel(nn.Module):
    def __init__(self):
        super(ResizeModel, self).__init__()
    def forward(self, x):
        sh = torch.tensor(x.shape)
        return F.interpolate(x, size=(sh[2] * 2, sh[3] * 2), mode='nearest')

그러면 Gather 옵션을 피해서 아래와 같이 그래프가 생성된다.

 

 

 

 

 

 

 

 

참고자료 : https://github.com/onnx/onnx-tensorrt/issues/192

 

Gather in Upsample problem · Issue #192 · onnx/onnx-tensorrt

Hi! Cant export model from onnx to tensorrt. `---------------------------------------------------------------- Input filename: model.onnx ONNX IR version: 0.0.4 Opset version: 9 Producer name: pyto...

github.com

 

 

728x90
반응형
1 ··· 4 5 6 7 8