Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import os | |
| from functools import wraps | |
| import onnx | |
| import pytest | |
| import torch | |
| from mmcv.ops import nms | |
| from mmcv.tensorrt.preprocess import preprocess_onnx | |
| if torch.__version__ == 'parrots': | |
| pytest.skip('not supported in parrots now', allow_module_level=True) | |
| def remove_tmp_file(func): | |
| def wrapper(*args, **kwargs): | |
| onnx_file = 'tmp.onnx' | |
| kwargs['onnx_file'] = onnx_file | |
| try: | |
| result = func(*args, **kwargs) | |
| finally: | |
| if os.path.exists(onnx_file): | |
| os.remove(onnx_file) | |
| return result | |
| return wrapper | |
| def export_nms_module_to_onnx(module, onnx_file): | |
| torch_model = module() | |
| torch_model.eval() | |
| input = (torch.rand([100, 4], dtype=torch.float32), | |
| torch.rand([100], dtype=torch.float32)) | |
| torch.onnx.export( | |
| torch_model, | |
| input, | |
| onnx_file, | |
| opset_version=11, | |
| input_names=['boxes', 'scores'], | |
| output_names=['output']) | |
| onnx_model = onnx.load(onnx_file) | |
| return onnx_model | |
| def test_can_handle_nms_with_constant_maxnum(): | |
| class ModuleNMS(torch.nn.Module): | |
| def forward(self, boxes, scores): | |
| return nms(boxes, scores, iou_threshold=0.4, max_num=10) | |
| onnx_model = export_nms_module_to_onnx(ModuleNMS) | |
| preprocess_onnx_model = preprocess_onnx(onnx_model) | |
| for node in preprocess_onnx_model.graph.node: | |
| if 'NonMaxSuppression' in node.name: | |
| assert len(node.attribute) == 5, 'The NMS must have 5 attributes.' | |
| def test_can_handle_nms_with_undefined_maxnum(): | |
| class ModuleNMS(torch.nn.Module): | |
| def forward(self, boxes, scores): | |
| return nms(boxes, scores, iou_threshold=0.4) | |
| onnx_model = export_nms_module_to_onnx(ModuleNMS) | |
| preprocess_onnx_model = preprocess_onnx(onnx_model) | |
| for node in preprocess_onnx_model.graph.node: | |
| if 'NonMaxSuppression' in node.name: | |
| assert len(node.attribute) == 5, \ | |
| 'The NMS must have 5 attributes.' | |
| assert node.attribute[2].i > 0, \ | |
| 'The max_output_boxes_per_class is not defined correctly.' | |