Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch | |
| from mmcv.cnn.bricks import GeneralizedAttention | |
def test_context_block():
    """Smoke-test ``GeneralizedAttention`` over its main configurations.

    Each section builds a block, checks that the sub-modules/attributes
    expected for that configuration exist, runs a forward pass and asserts
    that the output shape equals the input shape.

    NOTE(review): the function name mentions "context block" although only
    ``GeneralizedAttention`` is exercised; the name is kept unchanged so
    existing test IDs remain stable.
    """
    # attention_type='1000': query/key/value projections take 16 channels.
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='1000')
    assert gen_attention_block.query_conv.in_channels == 16
    assert gen_attention_block.key_conv.in_channels == 16
    # Check the value projection as well, rather than key_conv a second time.
    assert gen_attention_block.value_conv.in_channels == 16
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # attention_type='0100': query projection plus geometry FC layers.
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='0100')
    assert gen_attention_block.query_conv.in_channels == 16
    assert gen_attention_block.appr_geom_fc_x.in_features == 8
    assert gen_attention_block.appr_geom_fc_y.in_features == 8
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # attention_type='0010': key projection plus an ``appr_bias`` attribute.
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='0010')
    assert gen_attention_block.key_conv.in_channels == 16
    assert hasattr(gen_attention_block, 'appr_bias')
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # attention_type='0001': geometry FC layers plus a ``geom_bias``
    # attribute.
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='0001')
    assert gen_attention_block.appr_geom_fc_x.in_features == 8
    assert gen_attention_block.appr_geom_fc_y.in_features == 8
    assert hasattr(gen_attention_block, 'geom_bias')
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # spatial_range >= 0: a ``local_constraint_map`` attribute must exist.
    imgs = torch.randn(2, 256, 20, 20)
    gen_attention_block = GeneralizedAttention(256, spatial_range=10)
    assert hasattr(gen_attention_block, 'local_constraint_map')
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # q_stride > 1: a query down-sampling module must be configured.
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, q_stride=2)
    assert gen_attention_block.q_downsample is not None
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # kv_stride > 1: a key/value down-sampling module must be configured.
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, kv_stride=2)
    assert gen_attention_block.kv_downsample is not None
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # fp16 with attention_type='1111'; only exercised when CUDA is present.
    if torch.cuda.is_available():
        imgs = torch.randn(2, 16, 20, 20).cuda().to(torch.half)
        gen_attention_block = GeneralizedAttention(
            16,
            spatial_range=-1,
            num_heads=8,
            attention_type='1111',
            kv_stride=2)
        # Move the module to the GPU and cast it to half precision so it
        # matches the dtype/device of ``imgs``.
        gen_attention_block.cuda().type(torch.half)
        out = gen_attention_block(imgs)
        assert out.shape == imgs.shape