
[ONNX] Simplifying a Model with ONNX Simplifier

꾸준희 2021. 7. 26. 18:13



ONNX Simplifier is a tool that simplifies an ONNX model, i.e. its tangle of ONNX nodes.

It infers the whole computation graph and then replaces the redundant operators with their constant outputs (constant folding).
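The core idea — evaluate any node whose inputs are all known constants and replace it with its precomputed output — can be sketched on a toy graph. This is only an illustration of the principle, not onnx-simplifier's actual implementation:

```python
# Toy constant folding: a node whose inputs are all constants is
# evaluated once and replaced by its constant output.
# (Illustration only -- not onnx-simplifier's real code.)

def fold_constants(nodes, constants):
    """nodes: list of (name, op, input_names); constants: dict name -> value."""
    ops = {"Add": lambda a, b: a + b, "Mul": lambda a, b: a * b}
    remaining = []
    for name, op, inputs in nodes:
        if all(i in constants for i in inputs):
            # every input is already known -> compute now, drop the node
            constants[name] = ops[op](*(constants[i] for i in inputs))
        else:
            remaining.append((name, op, inputs))
    return remaining, constants

# "x" is a runtime input; "two" and "three" are graph constants.
nodes = [
    ("five", "Add", ["two", "three"]),   # 2 + 3 -> foldable
    ("y", "Mul", ["x", "five"]),         # depends on runtime input x
]
folded, consts = fold_constants(nodes, {"two": 2, "three": 3})
print(folded)          # only the Mul node survives
print(consts["five"])  # 5
```

After folding, only the node that truly depends on the runtime input remains; the `Add` has become the constant `5`.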



In the figure below, the left side is the original ONNX model and the right side is the same model after passing through onnx simplifier.

(It is hard to see at this scale, but if you look closely you can see the simplified structure. The model file size shrinks as well.)






The figure below makes it clearer; this is the general idea.

Combinations of unnecessary Gather, Unsqueeze, and similar operators are replaced with a single Reshape.

Before tools like this existed, you had to replace those node chains with Reshape by hand, one at a time.
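These Shape/Gather/Unsqueeze/Concat chains typically appear when dynamic shape arithmetic (e.g. PyTorch's `x.view(x.size(0), -1)`) is exported to ONNX; once the input shape is fixed, the whole chain can collapse into one Reshape with a constant shape. A toy peephole pass over a flat op list sketches the idea (illustration only — onnx-simplifier works on real ONNX graphs, not name lists):

```python
# Toy peephole rewrite: collapse a Shape->Gather->Unsqueeze->Concat chain
# feeding a Reshape into a single Reshape node.
# (Sketch of the idea only -- not onnx-simplifier's real pass.)

PATTERN = ["Shape", "Gather", "Unsqueeze", "Concat", "Reshape"]

def collapse_reshape_chain(ops):
    out, i = [], 0
    while i < len(ops):
        if ops[i:i + len(PATTERN)] == PATTERN:
            out.append("Reshape")   # the whole chain becomes one Reshape
            i += len(PATTERN)
        else:
            out.append(ops[i])
            i += 1
    return out

ops = ["Conv", "Shape", "Gather", "Unsqueeze", "Concat", "Reshape", "Gemm"]
print(collapse_reshape_chain(ops))   # ['Conv', 'Reshape', 'Gemm']
```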








Installing ONNX Simplifier

$ pip3 install -U pip && pip3 install onnx-simplifier


Using ONNX Simplifier

$ python3 -m onnxsim input_onnx_model output_onnx_model


Viewing ONNX Simplifier options

$ python3 -m onnxsim -h




The command-line API is structured as follows.

import argparse
import sys

import onnx     # type: ignore
import onnxsim
import numpy as np


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('input_model', help='Input ONNX model')
    parser.add_argument('output_model', help='Output ONNX model')
    parser.add_argument('check_n', help='Check whether the output is correct with n random inputs',
                        nargs='?', type=int, default=3)
    parser.add_argument('--enable-fuse-bn', help='This option is deprecated. Fusing bn into conv is enabled by default.',
                        action='store_true')
    parser.add_argument('--skip-fuse-bn', help='Skip fusing batchnorm into conv.',
                        action='store_true')
    parser.add_argument('--skip-optimization', help='Skip optimization of ONNX optimizers.',
                        action='store_true')
    parser.add_argument(
        '--input-shape', help='The manually-set static input shape, useful when the input shape is dynamic. The value should be "input_name:dim0,dim1,...,dimN" or simply "dim0,dim1,...,dimN" when there is only one input, for example, "data:1,3,224,224" or "1,3,224,224". Note: you might want to use some visualization tools like netron to make sure what the input name and dimension ordering (NCHW or NHWC) is.', type=str, nargs='+')
    parser.add_argument(
        '--skip-optimizer', help='Skip a certain ONNX optimizer', type=str, nargs='+')
    parser.add_argument('--skip-shape-inference',
                        help='Skip shape inference. Shape inference causes segfault on some large models', action='store_true')
    parser.add_argument('--dynamic-input-shape', help='This option enables dynamic input shape support. "Shape" ops will not be eliminated in this case. Note that "--input-shape" is also needed for generating random inputs and checking equality. If "dynamic_input_shape" is False, the input shape in simplified model will be overwritten by the value of "input_shapes" param.', action='store_true')
    parser.add_argument(
        '--input-data-path', help='input data, The value should be "input_name1:xxx1.bin"  "input_name2:xxx2.bin ...", input data should be a binary data file.', type=str, nargs='+')
    parser.add_argument(
        '--custom-lib', help="custom lib path which should be absolute path, if you have custom onnxruntime backend you should use this to register you custom op", type=str)

    args = parser.parse_args()

    if args.dynamic_input_shape and args.input_shape is None:
        raise RuntimeError(
            'Please pass "--input-shape" argument for generating random input and checking equality. Run "python3 -m onnxsim -h" for details.')
    if args.input_shape is not None and not args.dynamic_input_shape:
        print("Note: The input shape of the simplified model will be overwritten by the value of '--input-shape' argument. Pass '--dynamic-input-shape' if it is not what you want. Run 'python3 -m onnxsim -h' for details.")

    input_shapes = dict()
    if args.input_shape is not None:
        for x in args.input_shape:
            if ':' not in x:
                input_shapes[None] = list(map(int, x.split(',')))
            else:
                pieces = x.split(':')
                # for the input name like input:0
                name, shape = ':'.join(
                    pieces[:-1]), list(map(int, pieces[-1].split(',')))
                input_shapes.update({name: shape})

    input_data_paths = dict()
    if args.input_data_path is not None:
        for x in args.input_data_path:
            pieces = x.split(':')
            name, data = ':'.join(pieces[:-1]), pieces[-1]
            input_data_paths.update({name: data})

    input_tensors = dict()
    if len(input_data_paths) > 0 and args.input_shape is not None:
        for name in input_shapes.keys():
            input_data = np.fromfile(input_data_paths[name], dtype=np.float32)
            input_data = input_data.reshape(input_shapes[name])
            input_tensors.update({name: input_data})

    model_opt, check_ok = onnxsim.simplify(
        args.input_model,
        check_n=args.check_n,
        perform_optimization=not args.skip_optimization,
        skip_fuse_bn=args.skip_fuse_bn,
        input_shapes=input_shapes,
        skipped_optimizers=args.skip_optimizer,
        skip_shape_inference=args.skip_shape_inference,
        input_data=input_tensors,
        dynamic_input_shape=args.dynamic_input_shape,
        custom_lib=args.custom_lib)

    onnx.save(model_opt, args.output_model)

    if check_ok:
        print("Ok!")
    else:
        print("Check failed, please be careful to use the simplified model, or try specifying \"--skip-fuse-bn\" or \"--skip-optimization\" (run \"python3 -m onnxsim -h\" for details)")
        sys.exit(1)


if __name__ == '__main__':
    main()


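The `check_n` step above verifies that the simplified model produces the same outputs as the original on n random inputs. The same idea in plain numpy (illustration only; onnx-simplifier actually runs both ONNX models through a runtime backend):

```python
import numpy as np

# Original graph: y = x * (2 + 3); "simplified" graph: y = x * 5.
def original(x):
    return x * (np.float32(2) + np.float32(3))

def simplified(x):
    return x * np.float32(5)

def check(f, g, shape=(1, 3, 32, 32), n=3, seed=0):
    """Compare f and g on n random inputs, like onnxsim's check_n."""
    rng = np.random.default_rng(seed)
    return all(np.allclose(f(x), g(x))
               for x in (rng.standard_normal(shape, dtype=np.float32)
                         for _ in range(n)))

print(check(original, simplified))   # True
```

If the check fails, the README's advice applies: keep the original model, or retry with "--skip-fuse-bn" or "--skip-optimization".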




References:


GitHub - daquexian/onnx-simplifier: Simplify your onnx model
