Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import os | |
| from functools import partial | |
| from typing import Callable | |
| import numpy as np | |
| import onnx | |
| import pytest | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
# Module-level preconditions. Each failed check skips the whole module, and
# the checks run in order: mmcv TensorRT bindings, CUDA, compiled plugins.
try:
    from mmcv.tensorrt import (TRTWrapper, is_tensorrt_plugin_loaded, onnx2trt,
                               save_trt_engine)
except ImportError:
    pytest.skip(
        'TensorRT should be installed from source.', allow_module_level=True)

# `pytest.skip` raises, so the second check is only reached when CUDA exists.
if not torch.cuda.is_available():
    pytest.skip(
        'CUDA is required for this test module', allow_module_level=True)
elif not is_tensorrt_plugin_loaded():
    pytest.skip(
        'Test requires to complie TensorRT plugins in mmcv',
        allow_module_level=True)
class WrapFunction(nn.Module):
    """Adapter that exposes an arbitrary callable as an ``nn.Module``.

    The tests use it so that plain functions (or ``functools.partial``
    objects) can be handed to ``torch.onnx.export``, which is called here
    with a module as its first argument.

    Args:
        wrapped_function: The callable invoked by :meth:`forward`.
    """

    def __init__(self, wrapped_function):
        super().__init__()
        self.wrapped_function = wrapped_function

    def forward(self, *args, **kwargs):
        """Forward every positional and keyword argument to the callable."""
        target = self.wrapped_function
        return target(*args, **kwargs)
# Scratch files shared by every test in this module: the exported ONNX graph
# and the serialized TensorRT engine built from it. Both paths are relative
# to the current working directory.
onnx_file = 'tmp.onnx'
trt_file = 'tmp.engine'
def test_roialign():
    """Check that mmcv ``RoIAlign`` gives the same result under TensorRT.

    For every hand-written (feature map, rois) pair the module is exported
    to ONNX, converted to a TensorRT engine, executed through ``TRTWrapper``
    and compared with the eager PyTorch output via ``torch.allclose``.

    The temporary ONNX and engine files are removed in a ``finally`` clause
    so that a failing export/conversion does not leave them behind.
    """
    try:
        from mmcv.ops import RoIAlign
    except (ImportError, ModuleNotFoundError):
        pytest.skip('test requires compilation')

    # trt config
    fp16_mode = False
    max_workspace_size = 1 << 30

    # roi align config
    pool_h = 2
    pool_w = 2
    spatial_scale = 1.0
    sampling_ratio = 2

    # each case is (feature map, rois)
    inputs = [([[[[1., 2.], [3., 4.]]]], [[0., 0., 0., 1., 1.]]),
              ([[[[1., 2.], [3., 4.]], [[4., 3.],
                                        [2., 1.]]]], [[0., 0., 0., 1., 1.]]),
              ([[[[1., 2., 5., 6.], [3., 4., 7., 8.], [9., 10., 13., 14.],
                  [11., 12., 15., 16.]]]], [[0., 0., 0., 3., 3.]])]

    wrapped_model = RoIAlign((pool_w, pool_h), spatial_scale, sampling_ratio,
                             'avg', True).cuda()

    for raw_feats, raw_rois in inputs:
        np_input = np.array(raw_feats, dtype=np.float32)
        np_rois = np.array(raw_rois, dtype=np.float32)
        feats = torch.from_numpy(np_input).cuda()
        rois = torch.from_numpy(np_rois).cuda()

        try:
            with torch.no_grad():
                torch.onnx.export(
                    wrapped_model, (feats, rois),
                    onnx_file,
                    export_params=True,
                    keep_initializers_as_inputs=True,
                    input_names=['input', 'rois'],
                    output_names=['roi_feat'],
                    opset_version=11)
            onnx_model = onnx.load(onnx_file)

            # create trt engine and wrapper; the three shapes given per
            # input are identical, i.e. the engine is built for one shape
            opt_shape_dict = {
                'input': [list(feats.shape),
                          list(feats.shape),
                          list(feats.shape)],
                'rois': [list(rois.shape),
                         list(rois.shape),
                         list(rois.shape)]
            }
            trt_engine = onnx2trt(
                onnx_model,
                opt_shape_dict,
                fp16_mode=fp16_mode,
                max_workspace_size=max_workspace_size)
            save_trt_engine(trt_engine, trt_file)
            trt_model = TRTWrapper(trt_file, ['input', 'rois'], ['roi_feat'])

            with torch.no_grad():
                trt_outputs = trt_model({'input': feats, 'rois': rois})
                trt_roi_feat = trt_outputs['roi_feat']

            # compute pytorch_output
            with torch.no_grad():
                pytorch_roi_feat = wrapped_model(feats, rois)
        finally:
            # always drop the scratch files, even if a step above raised
            for tmp_path in (onnx_file, trt_file):
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)

        # allclose
        assert torch.allclose(pytorch_roi_feat, trt_roi_feat)
def test_nms():
    """Check mmcv ``nms`` under TensorRT against the PyTorch implementation.

    The op is exported on CPU with ``ONNX_BACKEND`` set to ``MMCVTensorRT``,
    converted to a TensorRT engine and executed on the CUDA tensors. Scores
    (column 4 of ``dets``) are compared with ``atol=1e-3`` and the kept
    indices must match exactly. TensorRT outputs are truncated to the number
    of boxes kept by PyTorch before comparison.

    The environment variable and the temporary files are restored/removed in
    a ``finally`` clause so a failure cannot leak state into other tests.
    """
    try:
        import mmcv
        from mmcv.ops import nms
    except (ImportError, ModuleNotFoundError):
        pytest.skip('test requires compilation')

    # remember any user-provided value so it can be put back afterwards
    previous_backend = os.environ.get('ONNX_BACKEND')
    os.environ['ONNX_BACKEND'] = 'MMCVTensorRT'
    try:
        # trt config
        fp16_mode = False
        max_workspace_size = 1 << 30

        data = mmcv.load('./tests/data/batched_nms_data.pkl')
        boxes = torch.from_numpy(data['boxes']).cuda()
        scores = torch.from_numpy(data['scores']).cuda()

        nms_func = partial(
            nms, iou_threshold=0.7, offset=0, score_threshold=0.1,
            max_num=100)
        wrapped_model = WrapFunction(nms_func)
        wrapped_model.cpu().eval()

        with torch.no_grad():
            torch.onnx.export(
                wrapped_model, (boxes.detach().cpu(), scores.detach().cpu()),
                onnx_file,
                export_params=True,
                keep_initializers_as_inputs=True,
                input_names=['boxes', 'scores'],
                output_names=['dets', 'inds'],
                opset_version=11)
        onnx_model = onnx.load(onnx_file)

        # create trt engine and wrapper
        opt_shape_dict = {
            'boxes': [list(boxes.shape),
                      list(boxes.shape),
                      list(boxes.shape)],
            'scores': [list(scores.shape),
                       list(scores.shape),
                       list(scores.shape)]
        }
        trt_engine = onnx2trt(
            onnx_model,
            opt_shape_dict,
            fp16_mode=fp16_mode,
            max_workspace_size=max_workspace_size)
        save_trt_engine(trt_engine, trt_file)
        trt_model = TRTWrapper(trt_file, ['boxes', 'scores'],
                               ['dets', 'inds'])

        with torch.no_grad():
            trt_outputs = trt_model({'boxes': boxes, 'scores': scores})
            trt_dets = trt_outputs['dets']
            trt_inds = trt_outputs['inds']
            trt_inds = trt_inds.long()

        # compute pytorch_output
        with torch.no_grad():
            pytorch_outputs = wrapped_model(boxes, scores)
            pytorch_dets, pytorch_inds = pytorch_outputs
    finally:
        if previous_backend is None:
            os.environ.pop('ONNX_BACKEND', None)
        else:
            os.environ['ONNX_BACKEND'] = previous_backend
        for tmp_path in (onnx_file, trt_file):
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    # allclose: only the first `num_boxes` TensorRT rows are compared
    num_boxes = pytorch_dets.shape[0]
    trt_dets = trt_dets[:num_boxes, ...]
    trt_inds = trt_inds[:num_boxes]
    trt_scores = trt_dets[:, 4]
    pytorch_scores = pytorch_dets[:, 4]
    assert torch.allclose(pytorch_scores, trt_scores, atol=1e-3)
    assert torch.equal(pytorch_inds, trt_inds)
def test_batched_nms():
    """Check mmcv ``batched_nms`` under TensorRT against PyTorch.

    The class-aware (``class_agnostic=False``) op is exported on CPU with
    ``ONNX_BACKEND`` set to ``MMCVTensorRT``, converted to a TensorRT engine
    and executed on the CUDA tensors. Scores (column 4 of ``dets``) and kept
    indices are compared after truncating the TensorRT outputs to the number
    of boxes kept by PyTorch.

    The environment variable and the temporary files are restored/removed in
    a ``finally`` clause so a failure cannot leak state into other tests.
    """
    try:
        import mmcv
        from mmcv.ops import batched_nms
    except (ImportError, ModuleNotFoundError):
        pytest.skip('test requires compilation')

    # remember any user-provided value so it can be put back afterwards
    previous_backend = os.environ.get('ONNX_BACKEND')
    # trt config
    os.environ['ONNX_BACKEND'] = 'MMCVTensorRT'
    try:
        fp16_mode = False
        max_workspace_size = 1 << 30

        data = mmcv.load('./tests/data/batched_nms_data.pkl')
        nms_cfg = dict(type='nms', iou_threshold=0.7, score_threshold=0.1)
        boxes = torch.from_numpy(data['boxes']).cuda()
        scores = torch.from_numpy(data['scores']).cuda()
        idxs = torch.from_numpy(data['idxs']).cuda()
        class_agnostic = False

        nms_func = partial(
            batched_nms, nms_cfg=nms_cfg, class_agnostic=class_agnostic)
        wrapped_model = WrapFunction(nms_func)
        wrapped_model.cpu().eval()

        input_data = (boxes.detach().cpu(), scores.detach().cpu(),
                      idxs.detach().cpu())
        input_names = ['boxes', 'scores', 'idxs']
        output_names = ['dets', 'inds']
        with torch.no_grad():
            torch.onnx.export(
                wrapped_model,
                input_data,
                onnx_file,
                export_params=True,
                keep_initializers_as_inputs=True,
                input_names=input_names,
                output_names=output_names,
                opset_version=11)
        onnx_model = onnx.load(onnx_file)

        # create trt engine and wrapper
        opt_shape_dict = {
            'boxes': [list(boxes.shape),
                      list(boxes.shape),
                      list(boxes.shape)],
            'scores': [list(scores.shape),
                       list(scores.shape),
                       list(scores.shape)],
            'idxs': [list(idxs.shape),
                     list(idxs.shape),
                     list(idxs.shape)]
        }
        trt_engine = onnx2trt(
            onnx_model,
            opt_shape_dict,
            fp16_mode=fp16_mode,
            max_workspace_size=max_workspace_size)
        save_trt_engine(trt_engine, trt_file)
        trt_model = TRTWrapper(trt_file, input_names, output_names)

        with torch.no_grad():
            trt_outputs = trt_model({
                'boxes': boxes,
                'scores': scores,
                'idxs': idxs
            })
            trt_dets = trt_outputs['dets']
            trt_inds = trt_outputs['inds']
            trt_inds = trt_inds.long()

        # compute pytorch_output
        with torch.no_grad():
            pytorch_outputs = wrapped_model(boxes, scores, idxs)
            pytorch_dets, pytorch_inds = pytorch_outputs
    finally:
        if previous_backend is None:
            os.environ.pop('ONNX_BACKEND', None)
        else:
            os.environ['ONNX_BACKEND'] = previous_backend
        for tmp_path in (onnx_file, trt_file):
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    # allclose: only the first `num_boxes` TensorRT rows are compared
    num_boxes = pytorch_dets.shape[0]
    trt_dets = trt_dets[:num_boxes, ...]
    trt_inds = trt_inds[:num_boxes]
    trt_scores = trt_dets[:, 4]
    pytorch_scores = pytorch_dets[:, 4]
    assert torch.allclose(pytorch_scores, trt_scores)
    assert torch.equal(pytorch_inds, trt_inds)
def test_scatternd():
    """Check in-place slice updates survive the ONNX -> TensorRT round trip.

    A small function that adds/subtracts 1 on slices of a 4x4 tensor is
    exported to ONNX (opset 11), converted to a TensorRT engine and compared
    with the eager result. The test name suggests the slice assignments are
    expected to lower to ONNX ``ScatterND`` - not verified here.

    The temporary ONNX and engine files are removed in a ``finally`` clause.
    """

    def func(data):
        # in-place slice updates; callers pass a clone to keep `data` intact
        data[:, :-2] += 1
        data[:2, :] -= 1
        return data

    data = torch.zeros(4, 4).cuda()
    wrapped_model = WrapFunction(func).eval().cuda()

    input_names = ['input']
    output_names = ['output']

    try:
        with torch.no_grad():
            torch.onnx.export(
                wrapped_model, (data.clone(), ),
                onnx_file,
                export_params=True,
                keep_initializers_as_inputs=True,
                input_names=input_names,
                output_names=output_names,
                opset_version=11)

        onnx_model = onnx.load(onnx_file)

        # create trt engine and wrapper
        opt_shape_dict = {
            'input': [list(data.shape),
                      list(data.shape),
                      list(data.shape)],
        }
        # trt config
        fp16_mode = False
        max_workspace_size = 1 << 30

        trt_engine = onnx2trt(
            onnx_model,
            opt_shape_dict,
            fp16_mode=fp16_mode,
            max_workspace_size=max_workspace_size)

        save_trt_engine(trt_engine, trt_file)
        trt_model = TRTWrapper(trt_file, input_names, output_names)

        with torch.no_grad():
            trt_outputs = trt_model({'input': data.clone()})
            trt_results = trt_outputs['output']

        # compute pytorch_output
        with torch.no_grad():
            pytorch_results = wrapped_model(data.clone())
    finally:
        # always drop the scratch files, even if a step above raised
        for tmp_path in (onnx_file, trt_file):
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    # allclose
    assert torch.allclose(pytorch_results, trt_results)
def test_deform_conv():
    """Check mmcv ``DeformConv2dPack`` under TensorRT against PyTorch.

    A 1->1 channel, 2x2-kernel layer with fixed offset/deform weights is run
    on a single 3x3 input. The layer is exported to ONNX (opset 11),
    converted to a TensorRT engine and the two outputs are compared with
    ``torch.allclose``.

    The temporary ONNX and engine files are removed in a ``finally`` clause.
    """
    try:
        from mmcv.ops import DeformConv2dPack
    except (ImportError, ModuleNotFoundError):
        pytest.skip('test requires compilation')

    raw_input = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]]
    offset_weight = [[[0.1, 0.4, 0.6, 0.1]], [[0.3, 0.2, 0.1, 0.3]],
                     [[0.5, 0.5, 0.2, 0.8]], [[0.8, 0.3, 0.9, 0.1]],
                     [[0.3, 0.1, 0.2, 0.5]], [[0.3, 0.7, 0.5, 0.3]],
                     [[0.6, 0.2, 0.5, 0.3]], [[0.4, 0.1, 0.8, 0.4]]]
    offset_bias = [0.7, 0.1, 0.8, 0.5, 0.6, 0.5, 0.4, 0.7]
    deform_weight = [[[0.4, 0.2, 0.1, 0.9]]]

    c_in = 1
    c_out = 1
    x = torch.Tensor(raw_input).cuda()
    x.requires_grad = True
    model = DeformConv2dPack(c_in, c_out, 2, stride=1, padding=0)
    # overwrite the randomly initialised parameters with fixed values
    model.conv_offset.weight.data = torch.nn.Parameter(
        torch.Tensor(offset_weight).reshape(8, 1, 2, 2))
    model.conv_offset.bias.data = torch.nn.Parameter(
        torch.Tensor(offset_bias).reshape(8))
    model.weight.data = torch.nn.Parameter(
        torch.Tensor(deform_weight).reshape(1, 1, 2, 2))
    model.cuda().eval()

    input_names = ['input']
    output_names = ['output']

    try:
        with torch.no_grad():
            torch.onnx.export(
                model, (x.clone(), ),
                onnx_file,
                export_params=True,
                keep_initializers_as_inputs=True,
                input_names=input_names,
                output_names=output_names,
                opset_version=11)

        onnx_model = onnx.load(onnx_file)

        # create trt engine and wrapper
        opt_shape_dict = {
            'input': [list(x.shape), list(x.shape),
                      list(x.shape)],
        }
        # trt config
        fp16_mode = False
        max_workspace_size = 1 << 30

        trt_engine = onnx2trt(
            onnx_model,
            opt_shape_dict,
            fp16_mode=fp16_mode,
            max_workspace_size=max_workspace_size)

        save_trt_engine(trt_engine, trt_file)
        trt_model = TRTWrapper(trt_file, input_names, output_names)

        with torch.no_grad():
            trt_outputs = trt_model({'input': x.clone()})
            trt_results = trt_outputs['output']

        # compute pytorch_output
        with torch.no_grad():
            pytorch_results = model(x.clone())
    finally:
        # always drop the scratch files, even if a step above raised
        for tmp_path in (onnx_file, trt_file):
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    # allclose
    assert torch.allclose(pytorch_results, trt_results)
@pytest.mark.parametrize('with_bias', [True, False])
def test_modulated_deform_conv(with_bias):
    """Check mmcv ``ModulatedDeformConv2dPack`` under TensorRT.

    A 1->1 channel, 2x2-kernel layer (weights filled with 1) is run on a
    single 3x3 input, exported to ONNX (opset 11), converted to a TensorRT
    engine and compared with the eager PyTorch output.

    Args:
        with_bias (bool): Forwarded as ``bias=`` to the layer. The test is
            parametrized over both values so pytest can supply the argument.

    The temporary ONNX and engine files are removed in a ``finally`` clause.
    """
    try:
        from mmcv.ops import ModulatedDeformConv2dPack
    except (ImportError, ModuleNotFoundError):
        pytest.skip('test requires compilation')

    raw_input = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]]

    x = torch.Tensor(raw_input).cuda()
    model = ModulatedDeformConv2dPack(
        1,
        1,
        kernel_size=(2, 2),
        stride=1,
        padding=1,
        deform_groups=1,
        bias=with_bias)
    model.weight.data.fill_(1.)
    model.type(torch.float32)
    model = model.cuda().eval()

    input_names = ['input']
    output_names = ['output']

    try:
        with torch.no_grad():
            torch.onnx.export(
                model, (x.clone(), ),
                onnx_file,
                export_params=True,
                keep_initializers_as_inputs=True,
                input_names=input_names,
                output_names=output_names,
                opset_version=11)

        onnx_model = onnx.load(onnx_file)

        # create trt engine and wrapper
        opt_shape_dict = {
            'input': [list(x.shape), list(x.shape),
                      list(x.shape)],
        }
        # trt config
        fp16_mode = False
        max_workspace_size = 1 << 30

        trt_engine = onnx2trt(
            onnx_model,
            opt_shape_dict,
            fp16_mode=fp16_mode,
            max_workspace_size=max_workspace_size)

        save_trt_engine(trt_engine, trt_file)
        trt_model = TRTWrapper(trt_file, input_names, output_names)

        with torch.no_grad():
            trt_outputs = trt_model({'input': x.clone()})
            trt_results = trt_outputs['output']

        # compute pytorch_output
        with torch.no_grad():
            pytorch_results = model(x.clone())
    finally:
        # always drop the scratch files, even if a step above raised
        for tmp_path in (onnx_file, trt_file):
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    # allclose
    torch.testing.assert_allclose(pytorch_results, trt_results)
@pytest.mark.parametrize('mode', ['bilinear', 'nearest'])
@pytest.mark.parametrize('padding_mode', ['zeros', 'border', 'reflection'])
@pytest.mark.parametrize('align_corners', [True, False])
def test_grid_sample(mode, padding_mode, align_corners):
    """Check ``F.grid_sample`` under TensorRT against PyTorch.

    A random 1x1x10x10 input is sampled on a 15x15 identity affine grid. The
    function is exported to ONNX (after registering mmcv's extra symbolics
    for opset 11), converted to a TensorRT engine and compared with the
    eager output via ``torch.allclose``.

    Args:
        mode (str): Interpolation mode forwarded to ``F.grid_sample``.
        padding_mode (str): Padding mode forwarded to ``F.grid_sample``.
        align_corners (bool): Forwarded to ``F.grid_sample``.

    The test is parametrized over these arguments so pytest can supply them.
    The temporary ONNX and engine files are removed in a ``finally`` clause.
    """
    from mmcv.onnx.symbolic import register_extra_symbolics

    register_extra_symbolics(11)

    feats = torch.rand(1, 1, 10, 10).cuda()
    grid = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
    grid = F.affine_grid(grid, (1, 1, 15, 15)).type_as(feats).cuda()

    def func(sample_input, sample_grid):
        return F.grid_sample(
            sample_input,
            sample_grid,
            mode=mode,
            padding_mode=padding_mode,
            align_corners=align_corners)

    wrapped_model = WrapFunction(func).eval().cuda()

    input_names = ['input', 'grid']
    output_names = ['output']

    try:
        with torch.no_grad():
            torch.onnx.export(
                wrapped_model, (feats.clone(), grid.clone()),
                onnx_file,
                export_params=True,
                keep_initializers_as_inputs=True,
                input_names=input_names,
                output_names=output_names,
                opset_version=11)

        onnx_model = onnx.load(onnx_file)

        # create trt engine and wrapper
        opt_shape_dict = {
            'input': [list(feats.shape),
                      list(feats.shape),
                      list(feats.shape)],
            'grid': [list(grid.shape),
                     list(grid.shape),
                     list(grid.shape)],
        }
        # trt config
        fp16_mode = False
        max_workspace_size = 1 << 30

        trt_engine = onnx2trt(
            onnx_model,
            opt_shape_dict,
            fp16_mode=fp16_mode,
            max_workspace_size=max_workspace_size)

        save_trt_engine(trt_engine, trt_file)
        trt_model = TRTWrapper(trt_file, input_names, output_names)

        with torch.no_grad():
            trt_outputs = trt_model({
                'input': feats.clone(),
                'grid': grid.clone()
            })
            trt_results = trt_outputs['output']

        # compute pytorch_output
        with torch.no_grad():
            pytorch_results = wrapped_model(feats.clone(), grid.clone())
    finally:
        # always drop the scratch files, even if a step above raised
        for tmp_path in (onnx_file, trt_file):
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    # allclose
    assert torch.allclose(pytorch_results, trt_results)
@pytest.mark.parametrize('func', [torch.cummax, torch.cummin])
def test_cummin_cummax(func: Callable):
    """Check ``torch.cummax`` / ``torch.cummin`` under TensorRT.

    For each sample tensor and every valid ``dim`` the op is exported to
    ONNX, converted to a TensorRT engine and both outputs (values and
    indices) are compared with the eager PyTorch results.

    Args:
        func (Callable): The cumulative op under test; called as
            ``func(tensor, dim=dim)`` and expected to return
            ``(values, indices)``. The test is parametrized so pytest can
            supply the argument.

    The temporary ONNX and engine files are removed in a ``finally`` clause
    once the engine has been loaded, or as soon as any step fails.
    """
    # Note generally `cummax` or `cummin` is exportable to ONNX
    # as long as the pytorch version >= 1.5.0, since `torch.cummax`
    # is only supported with torch >= 1.5.0.
    # But when `cummax` or `cummin` serves as an intermediate component
    # whose outputs is used as inputs for another modules, it's expected
    # that pytorch version must be >= 1.7.0. Otherwise error appears like:
    # `RuntimeError: tuple appears in op that does not forward tuples,
    # unsupported 'kind: prim::PythonOp`.
    from packaging import version
    if version.parse(torch.__version__) < version.parse('1.7.0'):
        pytest.skip('test_cummax_cummin should be ran with pytorch >= 1.7.0')

    opset = 11
    # register custom op `mmcv::cummax` and `mmcv::cummin`
    from mmcv.onnx.symbolic import register_extra_symbolics
    register_extra_symbolics(opset)

    input_list = [
        # arbitrary shape, e.g. 1-D, 2-D, 3-D, ...
        torch.rand((2, 3, 4, 1, 5)).cuda(),
        torch.rand(1).cuda()
    ]

    input_names = ['input']
    output_names = ['output', 'indices']

    # trt config
    fp16_mode = False
    max_workspace_size = 1 << 30

    for sample in input_list:
        ndims = sample.dim()
        # valid dim range is [-ndims, ndims-1]
        # test for all `dim` value which is valid
        for dim in range(-ndims, ndims):
            cum_func = partial(func, dim=dim)
            wrapped_model = WrapFunction(cum_func).eval().cuda()

            try:
                with torch.no_grad():
                    torch.onnx.export(
                        wrapped_model,
                        sample,
                        onnx_file,
                        export_params=True,
                        keep_initializers_as_inputs=False,
                        input_names=input_names,
                        output_names=output_names,
                        opset_version=opset)

                onnx_model = onnx.load(onnx_file)

                # create trt engine and wrapper
                opt_shape_dict = {
                    'input':
                    [list(sample.shape),
                     list(sample.shape),
                     list(sample.shape)]
                }

                trt_engine = onnx2trt(
                    onnx_model,
                    opt_shape_dict,
                    fp16_mode=fp16_mode,
                    max_workspace_size=max_workspace_size)

                # save TensorRT model
                save_trt_engine(trt_engine, trt_file)

                # load and wrap TensorRT model
                trt_model = TRTWrapper(trt_file)
            finally:
                # both files are no longer needed once the engine is
                # loaded; also cleans up when a step above raised
                for tmp_path in (onnx_file, trt_file):
                    if os.path.exists(tmp_path):
                        os.remove(tmp_path)

            # compute trt output
            with torch.no_grad():
                trt_results = trt_model(
                    {'input': sample.contiguous().clone()})
                trt_output = trt_results['output']
                trt_indices = trt_results['indices']

            # compute pytorch output
            with torch.no_grad():
                pytorch_results = wrapped_model(sample.clone())
                pytorch_output = pytorch_results[0]
                pytorch_indices = pytorch_results[1]

            torch.testing.assert_allclose(trt_output, pytorch_output)
            torch.testing.assert_allclose(trt_indices, pytorch_indices)
@pytest.mark.parametrize('dynamic_export', [True, False])
@pytest.mark.parametrize('fp16_mode', [True, False])
def test_instance_norm(dynamic_export, fp16_mode):
    """Check ``nn.InstanceNorm2d`` under TensorRT against PyTorch.

    An affine ``InstanceNorm2d`` over a random (2, 3, 10, 10) input is
    exported to ONNX (opset 11), converted to a TensorRT engine and compared
    with the eager output via ``torch.allclose``.

    Args:
        dynamic_export (bool): When true, batch/height/width are exported as
            dynamic axes and the last shape given to ``onnx2trt`` is doubled
            along those axes.
        fp16_mode (bool): Forwarded to ``onnx2trt``.

    The test is parametrized over both flags so pytest can supply them.
    The temporary ONNX and engine files are removed in a ``finally`` clause.
    """
    n, c, h, w = 2, 3, 10, 10
    data = torch.randn(n, c, h, w).cuda()
    norm = nn.InstanceNorm2d(c, affine=True)

    wrapped_model = WrapFunction(norm).eval().cuda()

    input_names = ['input']
    output_names = ['output']
    dynamic_axes = None
    if dynamic_export:
        dynamic_axes = {
            'input': {
                0: 'n',
                2: 'h',
                3: 'w',
            },
            'output': {
                0: 'n',
                2: 'h',
                3: 'w',
            },
        }

    try:
        with torch.no_grad():
            torch.onnx.export(
                wrapped_model, (data.clone(), ),
                onnx_file,
                export_params=True,
                keep_initializers_as_inputs=True,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                opset_version=11)

        onnx_model = onnx.load(onnx_file)

        # create trt engine and wrapper
        if dynamic_export:
            opt_shape_dict = {
                'input':
                [list(data.shape),
                 list(data.shape), [2 * n, c, 2 * h, 2 * w]],
            }
        else:
            opt_shape_dict = {
                'input': [list(data.shape),
                          list(data.shape),
                          list(data.shape)],
            }
        # trt config
        max_workspace_size = 1 << 30

        trt_engine = onnx2trt(
            onnx_model,
            opt_shape_dict,
            fp16_mode=fp16_mode,
            max_workspace_size=max_workspace_size)

        save_trt_engine(trt_engine, trt_file)
        trt_model = TRTWrapper(trt_file, input_names, output_names)

        with torch.no_grad():
            trt_outputs = trt_model({'input': data.clone()})
            trt_results = trt_outputs['output']

        # compute pytorch_output
        with torch.no_grad():
            pytorch_results = wrapped_model(data.clone())
    finally:
        # always drop the scratch files, even if a step above raised
        for tmp_path in (onnx_file, trt_file):
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    # allclose
    assert torch.allclose(pytorch_results, trt_results)
@pytest.mark.parametrize('mode', ['top', 'bottom', 'left', 'right'])
def test_corner_pool(mode):
    """Check mmcv ``CornerPool`` under TensorRT against PyTorch.

    For several random (n, c, h, w) inputs the pooling module is exported to
    ONNX, converted to a TensorRT engine and compared with the eager output
    using ``torch.allclose(..., atol=1e-5)``.

    Args:
        mode (str): Pooling direction passed to ``CornerPool``. The test is
            parametrized over the four directions so pytest can supply it.

    The temporary ONNX and engine files are removed in a ``finally`` clause.
    """
    try:
        from mmcv.ops import CornerPool
    except (ImportError, ModuleNotFoundError):
        pytest.skip('test requires compilation')

    opset = 11
    # register custom op `mmcv::MMCVCornerPool`
    from mmcv.onnx.symbolic import register_extra_symbolics
    register_extra_symbolics(opset)

    # trt config
    fp16_mode = False
    max_workspace_size = 1 << 30

    inputs = [
        # (n, c, h, w)
        torch.rand((2, 3, 5, 5)),
        torch.rand((1, 2, 4, 6)),
        torch.rand((2, 1, 3, 2)),
    ]

    class CornerPoolWrapper(CornerPool):

        def __init__(self, mode):
            super().__init__(mode)

        def forward(self, x):
            # no use `torch.cummax`, instead `corner_pool` is used
            # for various torch version
            return self.corner_pool.apply(x)

    wrapped_model = CornerPoolWrapper(mode).cuda()

    for cpu_sample in inputs:
        sample = cpu_sample.cuda()

        try:
            with torch.no_grad():
                torch.onnx.export(
                    wrapped_model, (sample, ),
                    onnx_file,
                    export_params=True,
                    keep_initializers_as_inputs=True,
                    input_names=['input'],
                    output_names=['output'],
                    opset_version=opset)
            onnx_model = onnx.load(onnx_file)

            # create trt engine and wrapper
            opt_shape_dict = {
                'input': [list(sample.shape),
                          list(sample.shape),
                          list(sample.shape)],
            }
            trt_engine = onnx2trt(
                onnx_model,
                opt_shape_dict,
                fp16_mode=fp16_mode,
                max_workspace_size=max_workspace_size)
            save_trt_engine(trt_engine, trt_file)
            trt_model = TRTWrapper(trt_file, ['input'], ['output'])

            with torch.no_grad():
                trt_outputs = trt_model({'input': sample})
                trt_pool_feat = trt_outputs['output']

            # compute pytorch_output
            with torch.no_grad():
                pytorch_pool_feat = wrapped_model(sample)
        finally:
            # always drop the scratch files, even if a step above raised
            for tmp_path in (onnx_file, trt_file):
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)

        # allclose
        assert torch.allclose(pytorch_pool_feat, trt_pool_feat, atol=1e-5)