Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import warnings | |
| from unittest.mock import patch | |
| import pytest | |
| import torch | |
| import torch.nn as nn | |
| from mmcv.cnn.bricks import CONV_LAYERS, ConvModule, HSigmoid, HSwish | |
| from mmcv.utils import TORCH_VERSION, digit_version | |
@CONV_LAYERS.register_module()
class ExampleConv(nn.Module):
    """Toy conv layer with its own ``init_weights``, built via ``conv_cfg``.

    The class is registered in ``CONV_LAYERS`` so that
    ``conv_cfg=dict(type='ExampleConv')`` can resolve it by name; without the
    registration that lookup cannot succeed.

    Only ``in_channels``, ``out_channels`` and ``kernel_size`` are forwarded
    to the wrapped ``nn.Conv2d``; the remaining arguments are merely stored as
    attributes.
    NOTE(review): the stored attributes presumably mirror what ``ConvModule``
    reads from the conv layer it wraps -- confirm against ``ConvModule``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size passed to the inner ``nn.Conv2d``.
        stride: Stored only; not applied to the inner conv.
        padding: Stored only; not applied to the inner conv.
        dilation: Stored only; not applied to the inner conv.
        groups: Stored only; not applied to the inner conv.
        bias: Stored only; the inner conv keeps its default bias setting.
        norm_cfg: Stored only.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 norm_cfg=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.bias = bias
        self.norm_cfg = norm_cfg
        self.output_padding = (0, 0, 0)
        self.transposed = False

        self.conv0 = nn.Conv2d(in_channels, out_channels, kernel_size)
        # Zero the weights right away so a test can tell that this layer's
        # own initialisation was the one that ended up in effect.
        self.init_weights()

    def forward(self, x):
        """Apply the inner convolution to ``x`` and return the result."""
        x = self.conv0(x)
        return x

    def init_weights(self):
        """Set every weight of the inner convolution to zero."""
        nn.init.constant_(self.conv0.weight, 0)
def test_conv_module():
    """Check ConvModule argument validation, sub-layers and output shapes."""
    # conv_cfg must be a dict or None
    with pytest.raises(AssertionError):
        ConvModule(3, 8, 2, conv_cfg='conv')

    # norm_cfg must be a dict or None
    with pytest.raises(AssertionError):
        ConvModule(3, 8, 2, norm_cfg='norm')

    # softmax is not supported
    with pytest.raises(KeyError):
        ConvModule(3, 8, 2, act_cfg=dict(type='softmax'))

    # Output shapes for a 256x256 input:
    shrunk_shape = (1, 8, 255, 255)  # kernel 2, no padding
    same_shape = (1, 8, 256, 256)  # kernel 3, padding 1

    # conv + norm + act
    module = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    assert module.with_activation
    assert hasattr(module, 'activate')
    assert module.with_norm
    assert hasattr(module, 'norm')
    inputs = torch.rand(1, 3, 256, 256)
    assert module(inputs).shape == shrunk_shape

    # conv + act
    module = ConvModule(3, 8, 2)
    assert module.with_activation
    assert hasattr(module, 'activate')
    assert not module.with_norm
    assert module.norm is None
    inputs = torch.rand(1, 3, 256, 256)
    assert module(inputs).shape == shrunk_shape

    # conv only
    module = ConvModule(3, 8, 2, act_cfg=None)
    assert not module.with_norm
    assert module.norm is None
    assert not module.with_activation
    assert not hasattr(module, 'activate')
    inputs = torch.rand(1, 3, 256, 256)
    assert module(inputs).shape == shrunk_shape

    # conv with its own `init_weights` method: the zeroed weights must survive
    custom = ConvModule(
        3, 8, 2, conv_cfg=dict(type='ExampleConv'), act_cfg=None)
    assert torch.equal(custom.conv.conv0.weight, torch.zeros(8, 3, 2, 2))

    # The remaining cases all reuse the last random input.

    # with_spectral_norm=True
    module = ConvModule(3, 8, 3, padding=1, with_spectral_norm=True)
    assert hasattr(module.conv, 'weight_orig')
    assert module(inputs).shape == same_shape

    # padding_mode='reflect'
    module = ConvModule(3, 8, 3, padding=1, padding_mode='reflect')
    assert isinstance(module.padding_layer, nn.ReflectionPad2d)
    assert module(inputs).shape == same_shape

    # non-existing padding mode
    with pytest.raises(KeyError):
        ConvModule(3, 8, 3, padding=1, padding_mode='non_exists')

    # activations that map directly onto a torch.nn class
    plain_activations = (
        ('LeakyReLU', nn.LeakyReLU),
        ('Tanh', nn.Tanh),
        ('Sigmoid', nn.Sigmoid),
        ('PReLU', nn.PReLU),
    )
    for act_type, act_cls in plain_activations:
        module = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type=act_type))
        assert isinstance(module.activate, act_cls)
        assert module(inputs).shape == same_shape

    # HSwish: the expected class depends on the torch build/version
    module = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='HSwish'))
    legacy_torch = (
        TORCH_VERSION == 'parrots'
        or digit_version(TORCH_VERSION) < digit_version('1.7'))
    # The conditional expression keeps `nn.Hardswish` unevaluated on the
    # legacy branch, exactly like the if/else it replaces.
    hswish_cls = HSwish if legacy_torch else nn.Hardswish
    assert isinstance(module.activate, hswish_cls)
    assert module(inputs).shape == same_shape

    # HSigmoid
    module = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='HSigmoid'))
    assert isinstance(module.activate, HSigmoid)
    assert module(inputs).shape == same_shape
def test_bias():
    """Check the conv-bias choice and the redundant-bias warning."""
    redundant_bias_msg = 'Unnecessary conv bias before batch/instance norm'

    # bias: auto, without norm -> the conv keeps a bias
    assert ConvModule(3, 8, 2).conv.bias is not None

    # bias: auto, with norm -> the conv has no bias
    assert ConvModule(3, 8, 2, norm_cfg=dict(type='BN')).conv.bias is None

    # bias: False, without norm
    assert ConvModule(3, 8, 2, bias=False).conv.bias is None

    # bias: True, with batch norm / instance norm -> exactly one warning
    for norm_type in ('BN', 'IN'):
        with pytest.warns(UserWarning) as record:
            ConvModule(3, 8, 2, bias=True, norm_cfg=dict(type=norm_type))
        assert len(record) == 1
        assert record[0].message.args[0] == redundant_bias_msg

    # bias: True, with other norm -> no warning expected from ConvModule.
    # A sentinel warning is raised by hand so that `pytest.warns` has
    # something to capture; it must then be the only recorded warning.
    with pytest.warns(UserWarning) as record:
        ConvModule(
            3, 8, 2, bias=True, norm_cfg=dict(type='GN', num_groups=1))
        warnings.warn('No warnings')
    assert len(record) == 1
    assert record[0].message.args[0] == 'No warnings'
def conv_forward(self, x):
    """Stand-in ``forward`` that tags its input by appending ``'_conv'``."""
    tag = '_conv'
    return x + tag
def bn_forward(self, x):
    """Stand-in ``forward`` that tags its input by appending ``'_bn'``."""
    tag = '_bn'
    return x + tag
def relu_forward(self, x):
    """Stand-in ``forward`` that tags its input by appending ``'_relu'``."""
    tag = '_relu'
    return x + tag
@patch('torch.nn.ReLU.forward', relu_forward)
@patch('torch.nn.BatchNorm2d.forward', bn_forward)
@patch('torch.nn.Conv2d.forward', conv_forward)
def test_order():
    """Check validation of ``order`` and the order in which layers run.

    The ``forward`` methods of ``nn.Conv2d``, ``nn.BatchNorm2d`` and
    ``nn.ReLU`` are patched with helpers that append a tag to a string, so
    calling a module on ``'input'`` yields a trace of the layers applied.
    Without these patches the string inputs below could not be processed.
    """
    with pytest.raises(AssertionError):
        # order must be a tuple
        order = ['conv', 'norm', 'act']
        ConvModule(3, 8, 2, order=order)

    with pytest.raises(AssertionError):
        # length of order must be 3
        order = ('conv', 'norm')
        ConvModule(3, 8, 2, order=order)

    with pytest.raises(AssertionError):
        # order must be an order of 'conv', 'norm', 'act'
        order = ('conv', 'norm', 'norm')
        ConvModule(3, 8, 2, order=order)

    with pytest.raises(AssertionError):
        # order must be an order of 'conv', 'norm', 'act'
        order = ('conv', 'norm', 'something')
        ConvModule(3, 8, 2, order=order)

    # ('conv', 'norm', 'act')
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    out = conv('input')
    assert out == 'input_conv_bn_relu'

    # ('norm', 'conv', 'act')
    conv = ConvModule(
        3, 8, 2, norm_cfg=dict(type='BN'), order=('norm', 'conv', 'act'))
    out = conv('input')
    assert out == 'input_bn_conv_relu'

    # ('conv', 'norm', 'act'), activate=False
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    out = conv('input', activate=False)
    assert out == 'input_conv_bn'

    # ('conv', 'norm', 'act'), norm=False
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    out = conv('input', norm=False)
    assert out == 'input_conv_relu'