import torch

from konfai.network import network, blocks
from konfai.predictor import Reduction

class ConvBlock(network.ModuleArgsDict):
    """Two 3x3x3 convolutions, each followed by instance normalisation and LeakyReLU."""

    def __init__(self, in_channels: int, out_channels: int, stride: int | tuple[int, int, int] = 1) -> None:
        super().__init__()
        # The first convolution optionally downsamples through its stride.
        self.add_module("Conv_0", torch.nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=True))
        self.add_module("Norm_0", torch.nn.InstanceNorm3d(num_features=out_channels, affine=True))
        self.add_module("Activation_0", torch.nn.LeakyReLU(negative_slope=0.01, inplace=True))
        # The second convolution keeps the resolution (stride 1).
        self.add_module("Conv_1", torch.nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=True))
        self.add_module("Norm_1", torch.nn.InstanceNorm3d(num_features=out_channels, affine=True))
        self.add_module("Activation_1", torch.nn.LeakyReLU(negative_slope=0.01, inplace=True))

class UNetHead(network.ModuleArgsDict):
    """Segmentation head: 1x1x1 convolution to the class channels, then a softmax over the class dimension."""

    def __init__(self, in_channels: int, nb_class: int) -> None:
        super().__init__()
        self.add_module("Conv", torch.nn.Conv3d(in_channels=in_channels, out_channels=nb_class, kernel_size=1, stride=1, padding=0))
        self.add_module("Softmax", torch.nn.Softmax(dim=1))

class UNetBlock(network.ModuleArgsDict):
    """Recursive U-Net stage: strided encoder block, deeper stages, skip concatenation, decoder block."""

    def __init__(self, channels: list[int], mri: bool, i: int = 0) -> None:
        super().__init__()
        # Encoder: the top level keeps the input resolution, deeper levels downsample by 2
        # (for MRI, levels with i > 4 downsample only along the last two axes, i.e. stride (1, 2, 2)).
        self.add_module("DownConvBlock", ConvBlock(in_channels=channels[0], out_channels=channels[1], stride=((1, 2, 2) if mri and i > 4 else 2) if i > 0 else 1))
        if len(channels) > 2:
            # Recurse into the next level; its output is routed to branch 1 so that branch 0 keeps the
            # encoder features, then both branches are concatenated and refined by the decoder block.
            self.add_module("UNetBlock", UNetBlock(channels[1:], mri, i + 1), out_branch=[1])
            self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])
            self.add_module("UpConvBlock", ConvBlock(in_channels=channels[1] * 2, out_channels=channels[1]))
        if i > 0:
            # Upsample back to the parent level's resolution, mirroring the downsampling factor above.
            self.add_module("CONV_TRANSPOSE", torch.nn.ConvTranspose3d(in_channels=channels[1], out_channels=channels[0], kernel_size=(1, 2, 2) if mri and i > 4 else 2, stride=(1, 2, 2) if mri and i > 4 else 2, padding=0))

class Unet_TS(network.Network):
    """3D U-Net built from the recursive UNetBlock, with a softmax segmentation head."""

    def __init__(self,
                 optimizer: network.OptimizerLoader = network.OptimizerLoader(),
                 schedulers: dict[str, network.LRSchedulersLoader] = {
                     "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
                 },
                 outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
                 channels: list[int] = [1, 32, 64, 128, 320, 320],
                 mri: bool = False) -> None:
        super().__init__(
            in_channels=channels[0],
            optimizer=optimizer,
            schedulers=schedulers,
            outputs_criterions=outputs_criterions,
            patch=None,
            dim=3,
        )
        self.add_module("UNetBlock", UNetBlock(channels, mri))
        # Default head with 42 classes; `load` rebuilds it to match the checkpoint.
        self.add_module("Head", UNetHead(channels[1], 42))

    def load(
        self,
        state_dict: dict[str, dict[str, dict[str, torch.Tensor]] | int],
        init: bool = True,
        ema: bool = False,
    ):
        # Rebuild the head so its shape matches the checkpoint: a Conv3d weight is laid out as
        # (out_channels, in_channels, kD, kH, kW), so shape[:2] gives (nb_class, in_channels).
        nb_class, in_channels = state_dict["Model"]["Unet_TS"]["Head.Conv.weight"].shape[:2]
        self["Head"].add_module("Conv", torch.nn.Conv3d(in_channels=in_channels, out_channels=nb_class, kernel_size=1, stride=1, padding=0))
        super().load(state_dict, init, ema)

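# --------------------------------------------------------------------------
# Illustration (not from the original source): the head rebuild in `load`
# relies only on the Conv3d weight layout (out_channels, in_channels, kD, kH, kW),
# which this small, hypothetical self-check demonstrates with plain PyTorch.
# --------------------------------------------------------------------------
if __name__ == "__main__":
    _head = torch.nn.Conv3d(in_channels=32, out_channels=42, kernel_size=1)
    _nb_class, _in_channels = _head.weight.shape[:2]
    assert (_nb_class, _in_channels) == (42, 32)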