Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from unittest.mock import patch | |
| import pytest | |
| import torch | |
| import torch.nn as nn | |
| from mmcv.cnn.bricks import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d, | |
| Linear, MaxPool2d, MaxPool3d) | |
# Stand-in version string: 'parrots' when running on the parrots backend,
# otherwise the fixed string '1.1' for any regular PyTorch build.
torch_version = 'parrots' if torch.__version__ == 'parrots' else '1.1'
@patch('torch.__version__', torch_version)
@pytest.mark.parametrize(
    'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
    [(10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 3, 3, 5, 2, 1, 2)])
def test_conv2d(in_w, in_h, in_channel, out_channel, kernel_size, stride,
                padding, dilation):
    """Check the mmcv ``Conv2d`` wrapper against ``torch.nn.Conv2d``.

    The wrapper must accept an empty batch (batch size 0) in both train and
    eval mode, produce the same per-sample output shape as the reference
    layer, populate its weight gradient when backpropagating from the empty
    output, and give exactly the same result as the reference layer on a
    non-empty batch when both are built from the same seed.
    """
    layer_args = (in_channel, out_channel, kernel_size)
    layer_kwargs = dict(stride=stride, padding=padding, dilation=dilation)

    # train mode: wrapper fed an empty batch (batch size 0)
    x_empty = torch.randn(0, in_channel, in_h, in_w)
    torch.manual_seed(0)
    wrapper = Conv2d(*layer_args, **layer_kwargs)
    wrapper_out = wrapper(x_empty)

    # reference torch op fed a batch of 3 to obtain the expected shape;
    # re-seeding gives it the same initial weights as the wrapper
    x_normal = torch.randn(3, in_channel, in_h, in_w).requires_grad_(True)
    torch.manual_seed(0)
    ref = nn.Conv2d(*layer_args, **layer_kwargs)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    # backward through the empty output must still populate weight.grad
    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode: an empty batch must also be accepted without error
    x_empty = torch.randn(0, in_channel, in_h, in_w)
    wrapper = Conv2d(*layer_args, **layer_kwargs)
    wrapper.eval()
    wrapper(x_empty)
@patch('torch.__version__', torch_version)
@pytest.mark.parametrize(
    'in_w,in_h,in_t,in_channel,out_channel,'
    'kernel_size,stride,padding,dilation',
    [(10, 10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 20, 3, 3, 5, 2, 1, 2)])
def test_conv3d(in_w, in_h, in_t, in_channel, out_channel, kernel_size, stride,
                padding, dilation):
    """Check the mmcv ``Conv3d`` wrapper against ``torch.nn.Conv3d``.

    The wrapper must accept an empty batch (batch size 0) in both train and
    eval mode, produce the same per-sample output shape as the reference
    layer, populate its weight gradient when backpropagating from the empty
    output, and give exactly the same result as the reference layer on a
    non-empty batch when both are built from the same seed.
    """
    layer_args = (in_channel, out_channel, kernel_size)
    layer_kwargs = dict(stride=stride, padding=padding, dilation=dilation)

    # train mode: wrapper fed an empty batch (batch size 0)
    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w)
    torch.manual_seed(0)
    wrapper = Conv3d(*layer_args, **layer_kwargs)
    wrapper_out = wrapper(x_empty)

    # reference torch op fed a batch of 3 to obtain the expected shape;
    # re-seeding gives it the same initial weights as the wrapper
    x_normal = torch.randn(3, in_channel, in_t, in_h,
                           in_w).requires_grad_(True)
    torch.manual_seed(0)
    ref = nn.Conv3d(*layer_args, **layer_kwargs)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    # backward through the empty output must still populate weight.grad
    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode: an empty batch must also be accepted without error
    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w)
    wrapper = Conv3d(*layer_args, **layer_kwargs)
    wrapper.eval()
    wrapper(x_empty)
@patch('torch.__version__', torch_version)
@pytest.mark.parametrize(
    'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
    [(10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 3, 3, 5, 2, 1, 2)])
def test_conv_transposed_2d(in_w, in_h, in_channel, out_channel, kernel_size,
                            stride, padding, dilation):
    """Check the mmcv ``ConvTranspose2d`` wrapper against torch's op.

    The wrapper must accept an empty batch (batch size 0) in both train and
    eval mode, produce the same per-sample output shape as
    ``torch.nn.ConvTranspose2d``, populate its weight gradient when
    backpropagating from the empty output, and give exactly the same result
    as the reference layer on a non-empty batch when both are built from the
    same seed.
    """
    # output_padding must be smaller than either stride or dilation
    op = min(stride, dilation) - 1
    if torch.__version__ == 'parrots':
        # NOTE(review): forced to 0 on parrots, presumably because non-zero
        # output_padding is unsupported there -- confirm
        op = 0
    layer_args = (in_channel, out_channel, kernel_size)
    layer_kwargs = dict(
        stride=stride, padding=padding, dilation=dilation, output_padding=op)

    # train mode: wrapper fed an empty batch (batch size 0)
    x_empty = torch.randn(0, in_channel, in_h, in_w, requires_grad=True)
    torch.manual_seed(0)
    wrapper = ConvTranspose2d(*layer_args, **layer_kwargs)
    wrapper_out = wrapper(x_empty)

    # reference torch op fed a batch of 3 to obtain the expected shape;
    # re-seeding gives it the same initial weights as the wrapper
    x_normal = torch.randn(3, in_channel, in_h, in_w)
    torch.manual_seed(0)
    ref = nn.ConvTranspose2d(*layer_args, **layer_kwargs)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    # backward through the empty output must still populate weight.grad
    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode: an empty batch must also be accepted without error
    x_empty = torch.randn(0, in_channel, in_h, in_w)
    wrapper = ConvTranspose2d(*layer_args, **layer_kwargs)
    wrapper.eval()
    wrapper(x_empty)
@patch('torch.__version__', torch_version)
@pytest.mark.parametrize(
    'in_w,in_h,in_t,in_channel,out_channel,'
    'kernel_size,stride,padding,dilation',
    [(10, 10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 20, 3, 3, 5, 2, 1, 2)])
def test_conv_transposed_3d(in_w, in_h, in_t, in_channel, out_channel,
                            kernel_size, stride, padding, dilation):
    """Check the mmcv ``ConvTranspose3d`` wrapper against torch's op.

    The wrapper must accept an empty batch (batch size 0) in both train and
    eval mode, produce the same per-sample output shape as
    ``torch.nn.ConvTranspose3d``, populate its weight gradient when
    backpropagating from the empty output, and give exactly the same result
    as the reference layer on a non-empty batch when both are built from the
    same seed.
    """
    # output_padding must be smaller than either stride or dilation
    op = min(stride, dilation) - 1
    layer_args = (in_channel, out_channel, kernel_size)
    layer_kwargs = dict(
        stride=stride, padding=padding, dilation=dilation, output_padding=op)

    # train mode: wrapper fed an empty batch (batch size 0)
    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w, requires_grad=True)
    torch.manual_seed(0)
    wrapper = ConvTranspose3d(*layer_args, **layer_kwargs)
    wrapper_out = wrapper(x_empty)

    # reference torch op fed a batch of 3 to obtain the expected shape;
    # re-seeding gives it the same initial weights as the wrapper
    x_normal = torch.randn(3, in_channel, in_t, in_h, in_w)
    torch.manual_seed(0)
    ref = nn.ConvTranspose3d(*layer_args, **layer_kwargs)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    # backward through the empty output must still populate weight.grad
    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode: an empty batch must also be accepted without error
    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w)
    wrapper = ConvTranspose3d(*layer_args, **layer_kwargs)
    wrapper.eval()
    wrapper(x_empty)
@patch('torch.__version__', torch_version)
@pytest.mark.parametrize(
    'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
    [(10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 3, 3, 5, 2, 1, 2)])
def test_max_pool_2d(in_w, in_h, in_channel, out_channel, kernel_size, stride,
                     padding, dilation):
    """Check the mmcv ``MaxPool2d`` wrapper against ``torch.nn.MaxPool2d``.

    The wrapper must accept an empty batch (batch size 0), produce the same
    per-sample output shape as the reference op, and give exactly the same
    result on a non-empty batch. ``out_channel`` is unused here: the pooling
    layers take no output-channel argument.
    """
    pool_kwargs = dict(stride=stride, padding=padding, dilation=dilation)

    # wrapper fed an empty batch (batch size 0)
    x_empty = torch.randn(0, in_channel, in_h, in_w, requires_grad=True)
    wrapper = MaxPool2d(kernel_size, **pool_kwargs)
    wrapper_out = wrapper(x_empty)

    # reference torch op fed a batch of 3 to obtain the expected shape
    x_normal = torch.randn(3, in_channel, in_h, in_w)
    ref = nn.MaxPool2d(kernel_size, **pool_kwargs)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]
    assert torch.equal(wrapper(x_normal), ref_out)
@patch('torch.__version__', torch_version)
@pytest.mark.parametrize(
    'in_w,in_h,in_t,in_channel,out_channel,'
    'kernel_size,stride,padding,dilation',
    [(10, 10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 20, 3, 3, 5, 2, 1, 2)])
@pytest.mark.skipif(
    torch.__version__ == 'parrots' and not torch.cuda.is_available(),
    reason='the parrots branch of this test moves tensors to CUDA')
def test_max_pool_3d(in_w, in_h, in_t, in_channel, out_channel, kernel_size,
                     stride, padding, dilation):
    """Check the mmcv ``MaxPool3d`` wrapper against ``torch.nn.MaxPool3d``.

    The wrapper must accept an empty batch (batch size 0), produce the same
    per-sample output shape as the reference op, and give exactly the same
    result on a non-empty batch. On parrots, inputs are moved to CUDA, so the
    test is skipped there when no GPU is available. ``out_channel`` is unused:
    the pooling layers take no output-channel argument.
    """
    on_parrots = torch.__version__ == 'parrots'
    pool_kwargs = dict(stride=stride, padding=padding, dilation=dilation)

    # wrapper fed an empty batch (batch size 0)
    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w, requires_grad=True)
    wrapper = MaxPool3d(kernel_size, **pool_kwargs)
    if on_parrots:
        x_empty = x_empty.cuda()
    wrapper_out = wrapper(x_empty)

    # reference torch op fed a batch of 3 to obtain the expected shape
    x_normal = torch.randn(3, in_channel, in_t, in_h, in_w)
    ref = nn.MaxPool3d(kernel_size, **pool_kwargs)
    if on_parrots:
        x_normal = x_normal.cuda()
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]
    assert torch.equal(wrapper(x_normal), ref_out)
@patch('torch.__version__', torch_version)
@pytest.mark.parametrize('in_w,in_h,in_feature,out_feature',
                         [(10, 10, 1, 1), (20, 20, 3, 3)])
def test_linear(in_w, in_h, in_feature, out_feature):
    """Check the mmcv ``Linear`` wrapper against ``torch.nn.Linear``.

    The wrapper must accept an empty batch (batch size 0) in both train and
    eval mode, produce the same per-sample output shape as the reference
    layer, populate its weight gradient when backpropagating from the empty
    output, and give exactly the same result as the reference layer on a
    non-empty batch when both are built from the same seed. ``in_w`` and
    ``in_h`` are not used by this test.
    """
    # train mode: wrapper fed an empty batch (batch size 0)
    x_empty = torch.randn(0, in_feature, requires_grad=True)
    torch.manual_seed(0)
    wrapper = Linear(in_feature, out_feature)
    wrapper_out = wrapper(x_empty)

    # reference torch op fed a batch of 3 to obtain the expected shape;
    # re-seeding gives it the same initial weights as the wrapper
    x_normal = torch.randn(3, in_feature)
    torch.manual_seed(0)
    ref = nn.Linear(in_feature, out_feature)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    # backward through the empty output must still populate weight.grad
    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode: an empty batch must also be accepted without error
    x_empty = torch.randn(0, in_feature)
    wrapper = Linear(in_feature, out_feature)
    wrapper.eval()
    wrapper(x_empty)
def test_nn_op_forward_called():
    """Check that every mmcv wrapper delegates to the wrapped torch op.

    For each wrapper, ``torch.nn.<Op>.forward`` is replaced by a mock. The
    wrapper is then called on an empty batch (batch size 0) and on a batch
    of one, and the mock must have been called with exactly that input
    each time.
    """
    # (torch.nn class name, mmcv wrapper, constructor args, per-sample shape)
    # An explicit table replaces the previous eval()-based class lookup.
    cases = [
        ('Conv2d', Conv2d, (3, 2, 1), (3, 10, 10)),
        ('ConvTranspose2d', ConvTranspose2d, (3, 2, 1), (3, 10, 10)),
        ('MaxPool2d', MaxPool2d, (3, 2, 1), (3, 10, 10)),
        ('Conv3d', Conv3d, (3, 2, 1), (3, 10, 10, 10)),
        ('ConvTranspose3d', ConvTranspose3d, (3, 2, 1), (3, 10, 10, 10)),
        ('MaxPool3d', MaxPool3d, (3, 2, 1), (3, 10, 10, 10)),
        ('Linear', Linear, (3, 3), (3, )),
    ]
    for name, wrapper_cls, ctor_args, sample_shape in cases:
        with patch(f'torch.nn.{name}.forward') as nn_module_forward:
            # empty input (batch size 0)
            x_empty = torch.randn(0, *sample_shape)
            wrapper = wrapper_cls(*ctor_args)
            wrapper(x_empty)
            nn_module_forward.assert_called_with(x_empty)

            # non-empty input (batch size 1)
            x_normal = torch.randn(1, *sample_shape)
            wrapper = wrapper_cls(*ctor_args)
            wrapper(x_normal)
            nn_module_forward.assert_called_with(x_normal)