Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import numpy as np | |
| import pytest | |
| import torch | |
| import torch.nn as nn | |
| from mmcv.runner.fp16_utils import auto_fp16, cast_tensor_type, force_fp32 | |
def test_cast_tensor_type():
    """Exercise ``cast_tensor_type`` on tensors, containers and other values.

    Covers: tensors whose dtype matches ``src_type`` (result has
    ``dst_type``), a tensor whose dtype does not match (dtype is kept),
    ``str`` / ``np.ndarray`` / ``int`` inputs (the Python type of the result
    is checked), and ``dict`` / ``list`` containers of tensors (every element
    ends up with ``dst_type``).
    """
    # float32 -> int32, then: convert torch.float to torch.half
    for src, dst in ((torch.float32, torch.int32), (torch.float, torch.half)):
        casted = cast_tensor_type(torch.FloatTensor([5.]), src, dst)
        assert isinstance(casted, torch.Tensor)
        assert casted.dtype == dst

    # skip the conversion when the type of input is not the same as src_type
    int_tensor = torch.IntTensor([5])
    untouched = cast_tensor_type(int_tensor, torch.float, torch.half)
    assert isinstance(untouched, torch.Tensor)
    assert untouched.dtype == int_tensor.dtype

    # a plain string comes back as a string
    text_out = cast_tensor_type('tensor', str, str)
    assert isinstance(text_out, str)

    # a numpy array comes back as a numpy array
    array_out = cast_tensor_type(np.array([5.]), np.ndarray, np.ndarray)
    assert isinstance(array_out, np.ndarray)

    # dict of tensors: the result is a dict and each value is converted
    tensor_map = {
        'tensor_a': torch.FloatTensor([1.]),
        'tensor_b': torch.FloatTensor([2.]),
    }
    map_out = cast_tensor_type(tensor_map, torch.float32, torch.int32)
    assert isinstance(map_out, dict)
    assert map_out['tensor_a'].dtype == torch.int32
    assert map_out['tensor_b'].dtype == torch.int32

    # list of tensors: the result is a list and each item is converted
    tensor_seq = [torch.FloatTensor([1.]), torch.FloatTensor([2.])]
    seq_out = cast_tensor_type(tensor_seq, torch.float32, torch.int32)
    assert isinstance(seq_out, list)
    assert seq_out[0].dtype == torch.int32
    assert seq_out[1].dtype == torch.int32

    # an int with no src/dst type given comes back as an int
    int_out = cast_tensor_type(5, None, None)
    assert isinstance(int_out, int)
def test_auto_fp16():
    """Exercise the ``auto_fp16`` decorator.

    The decorated method must belong to an ``nn.Module`` (otherwise a
    ``TypeError`` is expected on call).  While ``fp16_enabled`` is not set on
    the module the inputs pass through unchanged; once it is set to ``True``
    the float32 tensor arguments selected by ``apply_to`` are expected to be
    cast to half, and with ``out_fp32=True`` the outputs are expected to be
    cast back to float32.  Each scenario is repeated on GPU when CUDA is
    available.
    """

    with pytest.raises(TypeError):
        # ExampleObject is not a subclass of nn.Module

        class ExampleObject:

            @auto_fp16()
            def __call__(self, x):
                return x

        model = ExampleObject()
        input_x = torch.ones(1, dtype=torch.float32)
        model(input_x)

    # apply to all input args
    class ExampleModule(nn.Module):

        @auto_fp16()
        def forward(self, x, y):
            return x, y

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.float32)
    input_y = torch.ones(1, dtype=torch.float32)
    # fp16 is not enabled yet, so nothing is cast
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.float32

    model.fp16_enabled = True
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.half

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y = model(input_x.cuda(), input_y.cuda())
        assert output_x.dtype == torch.half
        assert output_y.dtype == torch.half

    # apply to specified input args
    class ExampleModule(nn.Module):

        @auto_fp16(apply_to=('x', ))
        def forward(self, x, y):
            return x, y

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.float32)
    input_y = torch.ones(1, dtype=torch.float32)
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.float32

    model.fp16_enabled = True
    output_x, output_y = model(input_x, input_y)
    # only ``x`` is listed in ``apply_to``; ``y`` keeps its dtype
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.float32

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y = model(input_x.cuda(), input_y.cuda())
        assert output_x.dtype == torch.half
        assert output_y.dtype == torch.float32

    # apply to optional input args
    class ExampleModule(nn.Module):

        @auto_fp16(apply_to=('x', 'y'))
        def forward(self, x, y=None, z=None):
            return x, y, z

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.float32)
    input_y = torch.ones(1, dtype=torch.float32)
    input_z = torch.ones(1, dtype=torch.float32)
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.float32
    assert output_z.dtype == torch.float32

    model.fp16_enabled = True
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    # keyword argument ``y`` is cast as well; ``z`` is not in ``apply_to``
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.half
    assert output_z.dtype == torch.float32

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y, output_z = model(
            input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
        assert output_x.dtype == torch.half
        assert output_y.dtype == torch.half
        assert output_z.dtype == torch.float32

    # out_fp32=True
    class ExampleModule(nn.Module):

        @auto_fp16(apply_to=('x', 'y'), out_fp32=True)
        def forward(self, x, y=None, z=None):
            return x, y, z

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.half)
    input_y = torch.ones(1, dtype=torch.float32)
    input_z = torch.ones(1, dtype=torch.float32)
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.float32
    assert output_z.dtype == torch.float32

    model.fp16_enabled = True
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    # half outputs are converted back to float32 on the way out
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.float32
    assert output_z.dtype == torch.float32

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y, output_z = model(
            input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
        assert output_x.dtype == torch.float32
        assert output_y.dtype == torch.float32
        assert output_z.dtype == torch.float32
def test_force_fp32():
    """Exercise the ``force_fp32`` decorator.

    The decorated method must belong to an ``nn.Module`` (otherwise a
    ``TypeError`` is expected on call).  While ``fp16_enabled`` is not set on
    the module the inputs pass through unchanged; once it is set to ``True``
    the half tensor arguments selected by ``apply_to`` are expected to be
    cast to float32, and with ``out_fp16=True`` the outputs are expected to
    be cast back to half.  Each scenario is repeated on GPU when CUDA is
    available.
    """

    with pytest.raises(TypeError):
        # ExampleObject is not a subclass of nn.Module

        class ExampleObject:

            @force_fp32()
            def __call__(self, x):
                return x

        model = ExampleObject()
        input_x = torch.ones(1, dtype=torch.float32)
        model(input_x)

    # apply to all input args
    class ExampleModule(nn.Module):

        @force_fp32()
        def forward(self, x, y):
            return x, y

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.half)
    input_y = torch.ones(1, dtype=torch.half)
    # fp16 is not enabled yet, so nothing is cast
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.half

    model.fp16_enabled = True
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.float32

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y = model(input_x.cuda(), input_y.cuda())
        assert output_x.dtype == torch.float32
        assert output_y.dtype == torch.float32

    # apply to specified input args
    class ExampleModule(nn.Module):

        @force_fp32(apply_to=('x', ))
        def forward(self, x, y):
            return x, y

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.half)
    input_y = torch.ones(1, dtype=torch.half)
    output_x, output_y = model(input_x, input_y)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.half

    model.fp16_enabled = True
    output_x, output_y = model(input_x, input_y)
    # only ``x`` is listed in ``apply_to``; ``y`` keeps its dtype
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.half

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y = model(input_x.cuda(), input_y.cuda())
        assert output_x.dtype == torch.float32
        assert output_y.dtype == torch.half

    # apply to optional input args
    class ExampleModule(nn.Module):

        @force_fp32(apply_to=('x', 'y'))
        def forward(self, x, y=None, z=None):
            return x, y, z

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.half)
    input_y = torch.ones(1, dtype=torch.half)
    input_z = torch.ones(1, dtype=torch.half)
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.half
    assert output_z.dtype == torch.half

    model.fp16_enabled = True
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    # keyword argument ``y`` is cast as well; ``z`` is not in ``apply_to``
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.float32
    assert output_z.dtype == torch.half

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y, output_z = model(
            input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
        assert output_x.dtype == torch.float32
        assert output_y.dtype == torch.float32
        assert output_z.dtype == torch.half

    # out_fp16=True
    class ExampleModule(nn.Module):

        @force_fp32(apply_to=('x', 'y'), out_fp16=True)
        def forward(self, x, y=None, z=None):
            return x, y, z

    model = ExampleModule()
    input_x = torch.ones(1, dtype=torch.float32)
    input_y = torch.ones(1, dtype=torch.half)
    input_z = torch.ones(1, dtype=torch.half)
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    assert output_x.dtype == torch.float32
    assert output_y.dtype == torch.half
    assert output_z.dtype == torch.half

    model.fp16_enabled = True
    output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
    # float32 outputs are converted back to half on the way out
    assert output_x.dtype == torch.half
    assert output_y.dtype == torch.half
    assert output_z.dtype == torch.half

    if torch.cuda.is_available():
        model.cuda()
        output_x, output_y, output_z = model(
            input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
        assert output_x.dtype == torch.half
        assert output_y.dtype == torch.half
        assert output_z.dtype == torch.half