Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import pytest | |
| import torch | |
| from mmcv.cnn.bricks import HSigmoid | |
def _reference_hsigmoid(x, bias, divisor, min_value, max_value):
    """Compute the hard-sigmoid reference used to check ``HSigmoid``.

    Evaluates ``min(max((x + bias) / divisor, min_value), max_value)``
    element-wise with plain tensor ops, independently of the module
    under test.

    Args:
        x (torch.Tensor): Input tensor.
        bias (float): Value added to ``x`` before dividing.
        divisor (float): Value the shifted input is divided by.
        min_value (float): Lower clamp bound.
        max_value (float): Upper clamp bound.

    Returns:
        torch.Tensor: Tensor with the same shape as ``x``.
    """
    scaled = (x + bias) / divisor
    lower = torch.full_like(scaled, min_value)
    upper = torch.full_like(scaled, max_value)
    return torch.min(torch.max(scaled, lower), upper)


def test_hsigmoid():
    """Check ``HSigmoid`` argument validation and its forward pass."""
    # A zero divisor must be rejected at construction time.
    with pytest.raises(AssertionError):
        HSigmoid(divisor=0)

    input_shape = torch.Size([1, 3, 64, 64])

    # Each case: (positional constructor args, reference parameters).
    cases = [
        # Default parameters: expected min(max((x + 3) / 6, 0), 1).
        ((), dict(bias=3, divisor=6, min_value=0, max_value=1)),
        # Designated parameters, passed positionally as in the original
        # test; expected min(max((x + 1) / 2, 0), 1).
        # NOTE(review): the positional order is presumably
        # (bias, divisor, min_value, max_value) — confirm against HSigmoid.
        ((1, 2, 0, 1), dict(bias=1, divisor=2, min_value=0, max_value=1)),
    ]

    for args, ref_kwargs in cases:
        act = HSigmoid(*args)
        # Named `x` rather than `input` to avoid shadowing the builtin.
        x = torch.randn(input_shape)
        output = act(x)
        expected_output = _reference_hsigmoid(x, **ref_kwargs)

        # test output shape
        assert output.shape == expected_output.shape, \
            f'shape mismatch for HSigmoid{args}'
        # test output value; exact equality as in the original test.
        # NOTE(review): assumes HSigmoid performs the same arithmetic as
        # the reference; switch to torch.allclose if that ever changes.
        assert torch.equal(output, expected_output), \
            f'value mismatch for HSigmoid{args}'