[ONNX] Pytorch 에서 Onnx 로 변환 시 Gather op 때문에 export 안되는 문제
|2020. 4. 28. 11:00
728x90
반응형
Pytorch 에서 Onnx 모델로 변환시 Gather 와 같은 옵션 때문에 변환이 안되는 문제가 발생한다.
이유는 Onnx 에서 Gather 라는 옵션을 지원하지 않기 때문이다. (2020.04.28 기준)
아래와 같이 interpolate 를 scale factor 로 구현하였을 때
class ResizeModel(nn.Module):
def __init__(self):
super(ResizeModel, self).__init__()
def forward(self, x):
return F.interpolate(x, scale_factor=(2, 2), mode='nearest')
생성되는 그래프는 아래와 같다
위와 같이 구현할 경우 "Gather" 라는 op 이 붙어서 그래프가 생성되는데,
Onnx 에서는 Gather 라는 옵션을 지원하지 않기 때문에 (2020.04.28 기준)
아래와 같이 interpolate의 scale factor 를 구현해야한다.
class ResizeModel(nn.Module):
def __init__(self):
super(ResizeModel, self).__init__()
def forward(self, x):
sh = torch.tensor(x.shape)
return F.interpolate(x, size=(sh[2] * 2, sh[3] * 2), mode='nearest')
그러면 Gather 옵션을 피해서 아래와 같이 그래프가 생성된다.
참고자료 : https://github.com/onnx/onnx-tensorrt/issues/192
Gather in Upsample problem · Issue #192 · onnx/onnx-tensorrt
Hi! Cant export model from onnx to tensorrt. `---------------------------------------------------------------- Input filename: model.onnx ONNX IR version: 0.0.4 Opset version: 9 Producer name: pyto...
github.com
728x90
반응형
'Development & Tools > Frameworks & Libraries' 카테고리의 다른 글
[TensorRT] TensorRT 및 Tensor Core에서 NCHW vs NHWC 형식의 성능 차이 (1) | 2020.04.29 |
---|---|
[TensorRT] 지원되는 연산자 목록 (2020.04.29 기준) (0) | 2020.04.29 |
[TensorRT] NVIDIA TensorRT 개념, 설치방법, 사용하기 (17) | 2020.04.21 |
[TensorRT] AttributeError: 'NoneType' object has no attribute 'serialize' (2) | 2020.03.12 |
[Pytorch] 장치간 모델 불러오기 (GPU / CPU) (1) | 2020.01.09 |