背景
为了方便RealBasicVSR部署到其他平台,尝试将RealBasicVSR的模型转换成onnx格式,记录一下过程中遇到的问题和解决方案
导出onnx模型
我直接在inference_realbasicvsr.py
中的outputs = model(inputs, test_mode=True)['output'].cpu()
后面加入下列代码导出onnx模型
torch.onnx.export(model, inputs, 'model.onnx', input_names= ['input'], output_names=['output'], opset_version=11, dynamic_axes={'input' : {0 : 'batch_size', 3 : 'w', 4 : 'h'}, 'output' : {0 : 'batch_size', 3 : 'dstw', 4 : 'dsth'}})
通过dynamic_axes
设置输入输出为动态大小
使用ONNX模型
上一个最小可用的例子,就是简单的加载图片和模型,然后处理,显示和保存,使用了numpy的一些方法让数据的shape和模型输入保持一致
import impimport onnxruntime as ortimport numpy as npimport onnximport cv2img = cv2.imread("./data/demos/anim2.jpg")img = np.expand_dims((img/255.0).astype(np.float32).transpose(2,0,1), axis=0)imgs = np.array([img])print(imgs.shape)onnx_model = onnx.load_model("model.onnx")sess = ort.InferenceSession(onnx_model.SerializeToString())sess.set_providers(['CPUExecutionProvider'])input_name = sess.get_inputs()[0].nameoutput_name = sess.get_outputs()[1].nameprint(sess.get_outputs()[0].shape)output = sess.run([output_name], {input_name : imgs})print(output[0].shape)res = output[0][0][0].transpose(1, 2, 0)cv2.imshow("res", res)cv2.imwrite("results/anims/anim03.jpg", (res * 255).astype(np.uint8))cv2.waitKey(0)
遇到的问题
SRGAN model does not support forward_train
function.
这个按照我的理解是test_mode字段没有传下去,于是直接把下面的代码改了下,原本test_mode默认为False,我改成了True,不确定是否有其他影响,只能说目前可用
@auto_fp16(apply_to=('lq', ))def forward(self, lq, gt=None, test_mode=True, **kwargs): """Forward function. Args: lq (Tensor): Input lq images. gt (Tensor): Ground-truth image. Default: None. test_mode (bool): Whether in test mode or not. Default: False. kwargs (dict): Other arguments. """ if test_mode: return self.forward_test(lq, gt, **kwargs) raise ValueError( 'SRGAN model does not support `forward_train` function.')
output出来的数据shape和输入的一样
这个找了很久,最后发现模型其实有2个输出,应该导出模型的时候都命名,然后选择第二个就好了。上面的使用代码中直接选择了第二个output
output_name = sess.get_outputs()[1].name原文:https://juejin.cn/post/7101652566005514248