Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import random | |
| from tempfile import TemporaryDirectory | |
| import numpy as np | |
| import pytest | |
| import torch | |
| from scipy import stats | |
| from torch import nn | |
| from mmcv.cnn import (Caffe2XavierInit, ConstantInit, KaimingInit, NormalInit, | |
| PretrainedInit, TruncNormalInit, UniformInit, XavierInit, | |
| bias_init_with_prob, caffe2_xavier_init, constant_init, | |
| initialize, kaiming_init, normal_init, trunc_normal_init, | |
| uniform_init, xavier_init) | |
# The weight-init utilities exercised below are not supported on the parrots
# backend, so the whole module is skipped there at collection time.
if torch.__version__ == 'parrots':
    pytest.skip('not supported in parrots now', allow_module_level=True)
def test_constant_init():
    """constant_init fills weights with ``val``; bias defaults to zero."""
    conv_module = nn.Conv2d(3, 16, 3)
    constant_init(conv_module, 0.1)
    assert conv_module.weight.allclose(
        torch.full_like(conv_module.weight, 0.1))
    # No bias value was passed, so the bias is expected to be zero-filled.
    assert conv_module.bias.allclose(torch.zeros_like(conv_module.bias))

    # A module without a bias must be handled too: check the weight of the
    # bias-free module itself, not the module initialised above.
    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    constant_init(conv_module_no_bias, 0.1)
    assert conv_module_no_bias.bias is None
    assert conv_module_no_bias.weight.allclose(
        torch.full_like(conv_module_no_bias.weight, 0.1))
def test_xavier_init():
    """xavier_init: bias fill, uniform variant, bad name, bias-free module."""
    conv = nn.Conv2d(3, 16, 3)
    xavier_init(conv, bias=0.1)
    expected_bias = torch.full_like(conv.bias, 0.1)
    assert conv.bias.allclose(expected_bias)

    xavier_init(conv, distribution='uniform')
    # TODO: sanity check of weight distribution, e.g. mean, std

    # An unsupported distribution name is expected to fail an assertion.
    with pytest.raises(AssertionError):
        xavier_init(conv, distribution='student-t')

    biasless_conv = nn.Conv2d(3, 16, 3, bias=False)
    xavier_init(biasless_conv)
def test_normal_init():
    """normal_init honours ``bias`` and accepts modules without a bias."""
    with_bias = nn.Conv2d(3, 16, 3)
    normal_init(with_bias, bias=0.1)
    # TODO: sanity check of weight distribution, e.g. mean, std
    target = torch.full_like(with_bias.bias, 0.1)
    assert with_bias.bias.allclose(target)

    without_bias = nn.Conv2d(3, 16, 3, bias=False)
    normal_init(without_bias)
    # TODO: sanity check distribution, e.g. mean, std
def test_trunc_normal_init():
    """trunc_normal_init draws weights from the requested truncated normal."""

    def _uniform(low, high):
        # Sample uniformly from [low, high) using the ``random`` module.
        return low + (high - low) * random.random()

    def _looks_trunc_normal(tensor, mean, std, a, b):
        # scipy's ``truncnorm`` is parameterised for a standard normal, so
        # standardise both the samples and the bounds before the KS test.
        standardised = ((tensor.view(-1) - mean) / std).tolist()
        lower = (a - mean) / std
        upper = (b - mean) / std
        p_value = stats.kstest(
            standardised, 'truncnorm', args=(lower, upper))[1]
        return p_value > 0.0001

    conv = nn.Conv2d(3, 16, 3)
    mean = _uniform(-3, 3)
    std = _uniform(.01, 1)
    # Bounds are drawn within two standard deviations on either side.
    a = _uniform(mean - 2 * std, mean)
    b = _uniform(mean, mean + 2 * std)
    trunc_normal_init(conv, mean, std, a, b, bias=0.1)
    assert _looks_trunc_normal(conv.weight, mean, std, a, b)
    assert conv.bias.allclose(torch.full_like(conv.bias, 0.1))

    conv_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    trunc_normal_init(conv_no_bias)
    # TODO: sanity check distribution, e.g. mean, std
def test_uniform_init():
    """uniform_init honours ``bias`` and accepts modules without a bias."""
    biased = nn.Conv2d(3, 16, 3)
    uniform_init(biased, bias=0.1)
    # TODO: sanity check of weight distribution, e.g. mean, std
    target = torch.full_like(biased.bias, 0.1)
    assert biased.bias.allclose(target)

    unbiased = nn.Conv2d(3, 16, 3, bias=False)
    uniform_init(unbiased)
def test_kaiming_init():
    """kaiming_init: bias fill, uniform variant, bad name, bias-free module."""
    conv = nn.Conv2d(3, 16, 3)
    kaiming_init(conv, bias=0.1)
    # TODO: sanity check of weight distribution, e.g. mean, std
    assert conv.bias.allclose(torch.full_like(conv.bias, 0.1))

    kaiming_init(conv, distribution='uniform')
    # An unsupported distribution name is expected to fail an assertion.
    with pytest.raises(AssertionError):
        kaiming_init(conv, distribution='student-t')

    biasless_conv = nn.Conv2d(3, 16, 3, bias=False)
    kaiming_init(biasless_conv)
def test_caffe_xavier_init():
    """Smoke test: caffe2_xavier_init runs on a Conv2d that has a bias."""
    # No value checks yet; this only verifies the call does not raise.
    caffe2_xavier_init(nn.Conv2d(3, 16, 3))
def test_bias_init_with_prob():
    """bias_init_with_prob(p) should equal the logit -log((1 - p) / p)."""
    conv_module = nn.Conv2d(3, 16, 3)
    prior_prob = 0.1
    # Feed ``prior_prob`` itself (not a duplicated literal) so the expected
    # value computed below always refers to the same probability.
    normal_init(conv_module, bias=bias_init_with_prob(prior_prob))
    # TODO: sanity check of weight distribution, e.g. mean, std
    bias = float(-np.log((1 - prior_prob) / prior_prob))
    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, bias))
def test_constaninit():
    """test ConstantInit class."""

    def _filled(tensor, value):
        # Exact element-wise equality against a constant-filled tensor.
        return torch.equal(tensor, torch.full(tensor.shape, value))

    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    conv, linear = model[0], model[2]

    # Only Conv2d is selected, so the Linear layer must keep its values.
    ConstantInit(val=1, bias=2, layer='Conv2d')(model)
    assert _filled(conv.weight, 1.)
    assert _filled(conv.bias, 2.)
    assert not _filled(linear.weight, 1.)
    assert not _filled(linear.bias, 2.)

    # With bias_prob, the Linear bias is expected to equal
    # bias_init_with_prob(0.01); the Conv2d layer is left as it was.
    ConstantInit(val=3, bias_prob=0.01, layer='Linear')(model)
    prob_bias = bias_init_with_prob(0.01)
    assert _filled(conv.weight, 1.)
    assert _filled(linear.weight, 3.)
    assert _filled(conv.bias, 2.)
    assert _filled(linear.bias, prob_bias)

    # ``layer`` may name a base class: '_ConvNd' selects Conv2d and Conv1d.
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
    ConstantInit(val=4., bias=5., layer='_ConvNd')(model)
    for idx in (0, 2):
        assert torch.all(model[idx].weight == 4.)
    for idx in (0, 2):
        assert torch.all(model[idx].bias == 5.)

    # Wrongly-typed bias, bias_prob and layer arguments raise TypeError.
    for bad_kwargs in (dict(bias='1'), dict(bias_prob='1'), dict(layer=1)):
        with pytest.raises(TypeError):
            ConstantInit(val=1, **bad_kwargs)
def test_xavierinit():
    """test XavierInit class."""
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    func = XavierInit(bias=0.1, layer='Conv2d')
    func(model)
    # Compare each bias with a tensor of its *own* shape so a mix-up between
    # layers cannot be masked by broadcasting. Only Conv2d is selected, so
    # the Linear bias must not be 0.1.
    assert model[0].bias.allclose(torch.full_like(model[0].bias, 0.1))
    assert not model[2].bias.allclose(torch.full_like(model[2].bias, 0.1))

    # Zero both layers first so the effect of XavierInit is observable.
    constant_func = ConstantInit(val=0, bias=0, layer=['Conv2d', 'Linear'])
    func = XavierInit(gain=100, bias_prob=0.01, layer=['Conv2d', 'Linear'])
    model.apply(constant_func)
    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.))
    res = bias_init_with_prob(0.01)
    func(model)
    assert not torch.equal(model[0].weight,
                           torch.full(model[0].weight.shape, 0.))
    assert not torch.equal(model[2].weight,
                           torch.full(model[2].weight.shape, 0.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, res))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res))

    # test layer key with base class name
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
    func = ConstantInit(val=4., bias=5., layer='_ConvNd')
    func(model)
    assert torch.all(model[0].weight == 4.)
    assert torch.all(model[2].weight == 4.)
    assert torch.all(model[0].bias == 5.)
    assert torch.all(model[2].bias == 5.)
    func = XavierInit(gain=100, bias_prob=0.01, layer='_ConvNd')
    func(model)
    assert not torch.all(model[0].weight == 4.)
    assert not torch.all(model[2].weight == 4.)
    assert torch.all(model[0].bias == res)
    assert torch.all(model[2].bias == res)

    # test bias input type
    with pytest.raises(TypeError):
        XavierInit(bias='0.1', layer='Conv2d')
    # test layer input type
    with pytest.raises(TypeError):
        XavierInit(bias=0.1, layer=1)
def test_normalinit():
    """test NormalInit class."""
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    # std=1e-5 keeps every sample numerically close to the mean.
    NormalInit(
        mean=100, std=1e-5, bias=200, layer=['Conv2d', 'Linear'])(model)
    for idx in (0, 2):
        assert model[idx].weight.allclose(torch.tensor(100.))
    for idx in (0, 2):
        assert model[idx].bias.allclose(torch.tensor(200.))

    # bias_prob replaces the explicit bias with bias_init_with_prob(0.01).
    initializer = NormalInit(
        mean=300, std=1e-5, bias_prob=0.01, layer=['Conv2d', 'Linear'])
    prob_bias = bias_init_with_prob(0.01)
    initializer(model)
    for idx in (0, 2):
        assert model[idx].weight.allclose(torch.tensor(300.))
    for idx in (0, 2):
        assert model[idx].bias.allclose(torch.tensor(prob_bias))

    # ``layer`` may name the shared base class of Conv2d and Conv1d.
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
    NormalInit(mean=300, std=1e-5, bias_prob=0.01, layer='_ConvNd')(model)
    for idx in (0, 2):
        assert model[idx].weight.allclose(torch.tensor(300.))
    for idx in (0, 2):
        assert torch.all(model[idx].bias == prob_bias)
def test_truncnormalinit():
    """test TruncNormalInit class."""
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    # std=1e-5 keeps every sample numerically close to the mean.
    initializer = TruncNormalInit(
        mean=100, std=1e-5, bias=200, a=0, b=200, layer=['Conv2d', 'Linear'])
    initializer(model)
    for idx in (0, 2):
        assert model[idx].weight.allclose(torch.tensor(100.))
    for idx in (0, 2):
        assert model[idx].bias.allclose(torch.tensor(200.))

    # bias_prob replaces the explicit bias with bias_init_with_prob(0.01).
    initializer = TruncNormalInit(
        mean=300,
        std=1e-5,
        a=100,
        b=400,
        bias_prob=0.01,
        layer=['Conv2d', 'Linear'])
    prob_bias = bias_init_with_prob(0.01)
    initializer(model)
    for idx in (0, 2):
        assert model[idx].weight.allclose(torch.tensor(300.))
    for idx in (0, 2):
        assert model[idx].bias.allclose(torch.tensor(prob_bias))

    # ``layer`` may name the shared base class of Conv2d and Conv1d.
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
    initializer = TruncNormalInit(
        mean=300, std=1e-5, a=100, b=400, bias_prob=0.01, layer='_ConvNd')
    initializer(model)
    for idx in (0, 2):
        assert model[idx].weight.allclose(torch.tensor(300.))
    for idx in (0, 2):
        assert torch.all(model[idx].bias == prob_bias)
def test_uniforminit():
    """test UniformInit class."""
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    # With a == b every sample equals a, so weights can be compared exactly.
    func = UniformInit(a=1, b=1, bias=2, layer=['Conv2d', 'Linear'])
    func(model)
    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 1.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.))

    func = UniformInit(a=100, b=100, layer=['Conv2d', 'Linear'], bias=10)
    func(model)
    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape,
                                                   100.))
    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape,
                                                   100.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))

    # test layer key with base class name
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
    func = UniformInit(a=100, b=100, bias_prob=0.01, layer='_ConvNd')
    res = bias_init_with_prob(0.01)
    func(model)
    assert torch.all(model[0].weight == 100.)
    assert torch.all(model[2].weight == 100.)
    assert torch.all(model[0].bias == res)
    assert torch.all(model[2].bias == res)
def test_kaiminginit():
    """test KaimingInit class."""

    def _filled(tensor, value):
        # Exact element-wise equality against a constant-filled tensor.
        return torch.equal(tensor, torch.full(tensor.shape, value))

    # Only Conv2d is selected, so only its bias is set to 0.1.
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    KaimingInit(bias=0.1, layer='Conv2d')(model)
    assert _filled(model[0].bias, 0.1)
    assert not _filled(model[2].bias, 0.1)

    # Zero both layers, then check that KaimingInit(a=100, bias=10) replaces
    # the zero weights and writes 10 into every selected bias.
    kaiming = KaimingInit(a=100, bias=10, layer=['Conv2d', 'Linear'])
    zero_init = ConstantInit(val=0, bias=0, layer=['Conv2d', 'Linear'])
    model.apply(zero_init)
    for idx in (0, 2):
        assert _filled(model[idx].weight, 0.)
    for idx in (0, 2):
        assert _filled(model[idx].bias, 0.)
    kaiming(model)
    for idx in (0, 2):
        assert not _filled(model[idx].weight, 0.)
    for idx in (0, 2):
        assert _filled(model[idx].bias, 10.)

    # Same checks when ``layer`` names the shared base class '_ConvNd'.
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
    KaimingInit(bias=0.1, layer='_ConvNd')(model)
    for idx in (0, 2):
        assert torch.all(model[idx].bias == 0.1)
    kaiming = KaimingInit(a=100, bias=10, layer='_ConvNd')
    zero_init = ConstantInit(val=0, bias=0, layer='_ConvNd')
    model.apply(zero_init)
    for idx in (0, 2):
        assert _filled(model[idx].weight, 0.)
    for idx in (0, 2):
        assert _filled(model[idx].bias, 0.)
    kaiming(model)
    for idx in (0, 2):
        assert not _filled(model[idx].weight, 0.)
    for idx in (0, 2):
        assert _filled(model[idx].bias, 10.)
def test_caffe2xavierinit():
    """test Caffe2XavierInit."""
    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    conv, linear = model[0], model[2]
    Caffe2XavierInit(bias=0.1, layer='Conv2d')(model)
    # Only the Conv2d layer is selected, so only its bias becomes 0.1.
    assert torch.equal(conv.bias, torch.full(conv.bias.shape, 0.1))
    assert not torch.equal(linear.bias, torch.full(linear.bias.shape, 0.1))
class FooModule(nn.Module):
    """Container of three named layers shared by the init/pretrained tests."""

    def __init__(self):
        super().__init__()
        # Assignment order (linear, conv2d, conv2d_2) determines the
        # state_dict key order, so it is kept as-is.
        self.linear = nn.Linear(1, 2)
        self.conv2d = nn.Conv2d(3, 1, 3)
        self.conv2d_2 = nn.Conv2d(3, 2, 3)
def test_pretrainedinit():
    """test PretrainedInit class."""
    import os.path as osp

    modelA = FooModule()
    constant_func = ConstantInit(val=1, bias=2, layer=['Conv2d', 'Linear'])
    modelA.apply(constant_func)
    modelB = FooModule()
    modelC = nn.Linear(1, 2)
    # TemporaryDirectory() does not change the working directory, so the
    # checkpoint path must point inside it explicitly; a bare 'modelA.pth'
    # would be written to (and left behind in) the current directory.
    with TemporaryDirectory() as tmp_dir:
        checkpoint = osp.join(tmp_dir, 'modelA.pth')
        torch.save(modelA.state_dict(), checkpoint)
        funcB = PretrainedInit(checkpoint=checkpoint)
        funcC = PretrainedInit(checkpoint=checkpoint, prefix='linear.')

        # Loading the full state dict copies every layer of modelA.
        funcB(modelB)
        assert torch.equal(modelB.linear.weight,
                           torch.full(modelB.linear.weight.shape, 1.))
        assert torch.equal(modelB.linear.bias,
                           torch.full(modelB.linear.bias.shape, 2.))
        assert torch.equal(modelB.conv2d.weight,
                           torch.full(modelB.conv2d.weight.shape, 1.))
        assert torch.equal(modelB.conv2d.bias,
                           torch.full(modelB.conv2d.bias.shape, 2.))
        assert torch.equal(modelB.conv2d_2.weight,
                           torch.full(modelB.conv2d_2.weight.shape, 1.))
        assert torch.equal(modelB.conv2d_2.bias,
                           torch.full(modelB.conv2d_2.bias.shape, 2.))

        # With prefix='linear.' only the ``linear`` sub-dict is loaded into a
        # bare nn.Linear.
        funcC(modelC)
        assert torch.equal(modelC.weight, torch.full(modelC.weight.shape, 1.))
        assert torch.equal(modelC.bias, torch.full(modelC.bias.shape, 2.))
def test_initialize():
    """test ``initialize`` with layer/override keys and a pretrained cfg.

    After every successful call the test also checks that ``init_cfg``
    still equals what was passed in, i.e. the config is not mutated.
    """
    import os.path as osp

    model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    foonet = FooModule()

    # test layer key
    init_cfg = dict(type='Constant', layer=['Conv2d', 'Linear'], val=1, bias=2)
    initialize(model, init_cfg)
    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 1.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.))
    assert init_cfg == dict(
        type='Constant', layer=['Conv2d', 'Linear'], val=1, bias=2)

    # test init_cfg with list type
    init_cfg = [
        dict(type='Constant', layer='Conv2d', val=1, bias=2),
        dict(type='Constant', layer='Linear', val=3, bias=4)
    ]
    initialize(model, init_cfg)
    assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
    assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 3.))
    assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
    assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 4.))
    assert init_cfg == [
        dict(type='Constant', layer='Conv2d', val=1, bias=2),
        dict(type='Constant', layer='Linear', val=3, bias=4)
    ]

    # test layer key and override key
    init_cfg = dict(
        type='Constant',
        val=1,
        bias=2,
        layer=['Conv2d', 'Linear'],
        override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
    initialize(foonet, init_cfg)
    assert torch.equal(foonet.linear.weight,
                       torch.full(foonet.linear.weight.shape, 1.))
    assert torch.equal(foonet.linear.bias,
                       torch.full(foonet.linear.bias.shape, 2.))
    assert torch.equal(foonet.conv2d.weight,
                       torch.full(foonet.conv2d.weight.shape, 1.))
    assert torch.equal(foonet.conv2d.bias,
                       torch.full(foonet.conv2d.bias.shape, 2.))
    assert torch.equal(foonet.conv2d_2.weight,
                       torch.full(foonet.conv2d_2.weight.shape, 3.))
    assert torch.equal(foonet.conv2d_2.bias,
                       torch.full(foonet.conv2d_2.bias.shape, 4.))
    assert init_cfg == dict(
        type='Constant',
        val=1,
        bias=2,
        layer=['Conv2d', 'Linear'],
        override=dict(type='Constant', name='conv2d_2', val=3, bias=4))

    # test override key
    init_cfg = dict(
        type='Constant', val=5, bias=6, override=dict(name='conv2d_2'))
    initialize(foonet, init_cfg)
    assert not torch.equal(foonet.linear.weight,
                           torch.full(foonet.linear.weight.shape, 5.))
    assert not torch.equal(foonet.linear.bias,
                           torch.full(foonet.linear.bias.shape, 6.))
    assert not torch.equal(foonet.conv2d.weight,
                           torch.full(foonet.conv2d.weight.shape, 5.))
    assert not torch.equal(foonet.conv2d.bias,
                           torch.full(foonet.conv2d.bias.shape, 6.))
    assert torch.equal(foonet.conv2d_2.weight,
                       torch.full(foonet.conv2d_2.weight.shape, 5.))
    assert torch.equal(foonet.conv2d_2.bias,
                       torch.full(foonet.conv2d_2.bias.shape, 6.))
    assert init_cfg == dict(
        type='Constant', val=5, bias=6, override=dict(name='conv2d_2'))

    # test Pretrained init_cfg combined with an override.
    # TemporaryDirectory() does not change the working directory, so the
    # checkpoint path must point inside it explicitly; a bare 'modelA.pth'
    # would be written to (and left behind in) the current directory.
    modelA = FooModule()
    constant_func = ConstantInit(val=1, bias=2, layer=['Conv2d', 'Linear'])
    modelA.apply(constant_func)
    with TemporaryDirectory() as tmp_dir:
        checkpoint = osp.join(tmp_dir, 'modelA.pth')
        torch.save(modelA.state_dict(), checkpoint)
        init_cfg = dict(
            type='Pretrained',
            checkpoint=checkpoint,
            override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
        initialize(foonet, init_cfg)
        assert torch.equal(foonet.linear.weight,
                           torch.full(foonet.linear.weight.shape, 1.))
        assert torch.equal(foonet.linear.bias,
                           torch.full(foonet.linear.bias.shape, 2.))
        assert torch.equal(foonet.conv2d.weight,
                           torch.full(foonet.conv2d.weight.shape, 1.))
        assert torch.equal(foonet.conv2d.bias,
                           torch.full(foonet.conv2d.bias.shape, 2.))
        assert torch.equal(foonet.conv2d_2.weight,
                           torch.full(foonet.conv2d_2.weight.shape, 3.))
        assert torch.equal(foonet.conv2d_2.bias,
                           torch.full(foonet.conv2d_2.bias.shape, 4.))
        assert init_cfg == dict(
            type='Pretrained',
            checkpoint=checkpoint,
            override=dict(type='Constant', name='conv2d_2', val=3, bias=4))

    # test init_cfg type
    with pytest.raises(TypeError):
        init_cfg = 'init_cfg'
        initialize(foonet, init_cfg)

    # test override value type
    with pytest.raises(TypeError):
        init_cfg = dict(
            type='Constant',
            val=1,
            bias=2,
            layer=['Conv2d', 'Linear'],
            override='conv')
        initialize(foonet, init_cfg)

    # test override name
    with pytest.raises(RuntimeError):
        init_cfg = dict(
            type='Constant',
            val=1,
            bias=2,
            layer=['Conv2d', 'Linear'],
            override=dict(type='Constant', name='conv2d_3', val=3, bias=4))
        initialize(foonet, init_cfg)

    # test list override name
    with pytest.raises(RuntimeError):
        init_cfg = dict(
            type='Constant',
            val=1,
            bias=2,
            layer=['Conv2d', 'Linear'],
            override=[
                dict(type='Constant', name='conv2d', val=3, bias=4),
                dict(type='Constant', name='conv2d_3', val=5, bias=6)
            ])
        initialize(foonet, init_cfg)

    # test override with args except type key
    with pytest.raises(ValueError):
        init_cfg = dict(
            type='Constant',
            val=1,
            bias=2,
            override=dict(name='conv2d_2', val=3, bias=4))
        initialize(foonet, init_cfg)

    # test override without name
    with pytest.raises(ValueError):
        init_cfg = dict(
            type='Constant',
            val=1,
            bias=2,
            override=dict(type='Constant', val=3, bias=4))
        initialize(foonet, init_cfg)