首页>>人工智能->RealBasicVSR如何将模型转换成onnx

RealBasicVSR如何将模型转换成onnx

时间:2023-11-30 本站 点击:1

背景

为了方便RealBasicVSR部署到其他平台,尝试将RealBasicVSR的模型转换成onnx格式,记录一下过程中遇到的问题和解决方案

导出onnx模型

我直接在inference_realbasicvsr.py中的outputs = model(inputs, test_mode=True)['output'].cpu()后面加入下列代码导出onnx模型

torch.onnx.export(model,                inputs,                'model.onnx',                input_names= ['input'],                output_names=['output'],                opset_version=11,                dynamic_axes={'input' : {0 : 'batch_size', 3 : 'w', 4 : 'h'}, 'output' : {0 : 'batch_size', 3 : 'dstw', 4 : 'dsth'}})

通过dynamic_axes设置输入输出为动态大小

使用ONNX模型

上一个最小可用的例子,就是简单的加载图片和模型,然后处理,显示和保存,使用了numpy的一些方法让数据的shape和模型输入保持一致

import impimport onnxruntime as ortimport numpy as npimport onnximport cv2img = cv2.imread("./data/demos/anim2.jpg")img = np.expand_dims((img/255.0).astype(np.float32).transpose(2,0,1), axis=0)imgs = np.array([img])print(imgs.shape)onnx_model = onnx.load_model("model.onnx")sess = ort.InferenceSession(onnx_model.SerializeToString())sess.set_providers(['CPUExecutionProvider'])input_name = sess.get_inputs()[0].nameoutput_name = sess.get_outputs()[1].nameprint(sess.get_outputs()[0].shape)output = sess.run([output_name], {input_name : imgs})print(output[0].shape)res = output[0][0][0].transpose(1, 2, 0)cv2.imshow("res", res)cv2.imwrite("results/anims/anim03.jpg", (res * 255).astype(np.uint8))cv2.waitKey(0)

遇到的问题

SRGAN model does not support forward_train function.

这个按照我的理解是test_mode字段没有传下去,于是直接把下面的代码改了下,原本test_mode默认为False,我改成了True,不确定是否有其他影响,只能说目前可用

@auto_fp16(apply_to=('lq', ))def forward(self, lq, gt=None, test_mode=True, **kwargs):    """Forward function.    Args:        lq (Tensor): Input lq images.        gt (Tensor): Ground-truth image. Default: None.        test_mode (bool): Whether in test mode or not. Default: False.        kwargs (dict): Other arguments.    """    if test_mode:        return self.forward_test(lq, gt, **kwargs)    raise ValueError(        'SRGAN model does not support `forward_train` function.')

output出来的数据shape和输入的一样

这个找了很久,最后发现模型其实有2个输出,应该导出模型的时候都命名,然后选择第二个就好了。上面的使用代码中直接选择了第二个output

output_name = sess.get_outputs()[1].name
原文:https://juejin.cn/post/7101652566005514248


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:/AI/2545.html