import torch

from konfai.network import network, blocks
from konfai.predictor import Reduction


def _head_conv(in_channels: int, nb_class: int) -> torch.nn.Conv3d:
    """Build the 1x1x1 classification convolution of the segmentation head.

    Shared by ``UNetHead.__init__`` and ``Unet_TS.load`` so that the layer
    rebuilt from a checkpoint is configured exactly like the one created at
    construction time.
    """
    return torch.nn.Conv3d(
        in_channels=in_channels,
        out_channels=nb_class,
        kernel_size=1,
        stride=1,
        padding=0,
    )


class ConvBlock(network.ModuleArgsDict):
    """Two stages of Conv3d(kernel 3, padding 1) -> InstanceNorm3d -> LeakyReLU.

    Only the first convolution applies ``stride``; the second always uses
    stride 1. With kernel 3 and padding 1, a stride-1 convolution preserves the
    spatial size.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int | tuple[int, int, int] = 1,
    ) -> None:
        """
        Args:
            in_channels: number of channels entering the block.
            out_channels: number of channels produced by both convolutions.
            stride: stride of the first convolution, either an int or a
                per-axis 3-tuple (``UNetBlock`` passes ``(1, 2, 2)``).
        """
        super().__init__()
        # (input channels, stride) of each stage; registration order and the
        # "_0"/"_1" suffixes define the state-dict keys, so keep them stable.
        stages = ((in_channels, stride), (out_channels, 1))
        for idx, (stage_in, stage_stride) in enumerate(stages):
            self.add_module(
                f"Conv_{idx}",
                torch.nn.Conv3d(
                    in_channels=stage_in,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=stage_stride,
                    padding=1,
                    bias=True,
                ),
            )
            self.add_module(
                f"Norm_{idx}",
                torch.nn.InstanceNorm3d(num_features=out_channels, affine=True),
            )
            self.add_module(
                f"Activation_{idx}",
                torch.nn.LeakyReLU(negative_slope=0.01, inplace=True),
            )


class UNetHead(network.ModuleArgsDict):
    """1x1x1 convolution to ``nb_class`` channels followed by a softmax.

    The softmax runs over dim 1, which is the channel axis for batched
    (N, C, D, H, W) Conv3d output.
    """

    def __init__(self, in_channels: int, nb_class: int) -> None:
        super().__init__()
        self.add_module("Conv", _head_conv(in_channels, nb_class))
        self.add_module("Softmax", torch.nn.Softmax(dim=1))


class UNetBlock(network.ModuleArgsDict):
    """One resolution level of the U-Net, recursively containing all deeper levels.

    ``channels[0]`` is the width this level receives and returns to its
    parent, ``channels[1]`` its working width, and ``i`` its depth
    (0 = full-resolution level).
    """

    def __init__(self, channels, mri: bool, i : int = 0) -> None:
        super().__init__()
        # Resampling factor between level i-1 and level i. It is used both to
        # downsample (DownConvBlock) and to upsample back (CONV_TRANSPOSE), so
        # it is computed once to keep the two in lockstep. With ``mri`` set,
        # levels deeper than 4 leave the first spatial axis unscaled.
        scale = (1, 2, 2) if mri and i > 4 else 2

        # Level 0 keeps full resolution; every deeper level downsamples on entry.
        self.add_module(
            "DownConvBlock",
            ConvBlock(
                in_channels=channels[0],
                out_channels=channels[1],
                stride=scale if i > 0 else 1,
            ),
        )

        if len(channels) > 2:
            self.add_module("UNetBlock", UNetBlock(channels[1:], mri, i + 1))
            # The nested level ends with CONV_TRANSPOSE (channels[1] wide here)
            # concatenated with a second branch by its SkipConnection, hence
            # the doubled input width (assumes that branch is channels[1]
            # wide, i.e. the nested level's input — confirm).
            self.add_module(
                "UpConvBlock",
                ConvBlock(in_channels=channels[1] * 2, out_channels=channels[1]),
            )

        if i > 0:
            self.add_module(
                "CONV_TRANSPOSE",
                torch.nn.ConvTranspose3d(
                    in_channels=channels[1],
                    out_channels=channels[0],
                    kernel_size=scale,
                    stride=scale,
                    padding=0,
                ),
            )
            # NOTE(review): branch 1 presumably carries this level's input
            # (the parent's DownConvBlock output) through konfai's branch
            # routing; confirm against ModuleArgsDict.
            self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])


class Unet_TS(network.Network):
    """3D U-Net (a recursive ``UNetBlock``) followed by a softmax ``UNetHead``.

    The head is constructed with 42 output classes; ``load`` rebuilds its
    convolution from the checkpoint, so the class count follows the weights.
    """

    # NOTE(review): the mutable defaults below (list/dict/loader instances) are
    # shared between calls; they are not mutated here. The signature is left
    # untouched because konfai may introspect it for configuration — confirm
    # before changing.
    def __init__(
        self,
        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
        schedulers: dict[str, network.LRSchedulersLoader] = {
            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
        },
        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
        channels = [1, 32, 64, 128, 320, 320],
        mri: bool = False,
    ) -> None:
        super().__init__(
            in_channels=channels[0],
            optimizer=optimizer,
            schedulers=schedulers,
            outputs_criterions=outputs_criterions,
            patch=None,
            dim=3,
        )
        self.add_module("UNetBlock", UNetBlock(channels, mri))
        self.add_module("Head", UNetHead(channels[1], 42))

    def load(
        self,
        state_dict: dict[str, dict[str, torch.Tensor] | int],
        init: bool = True,
        ema: bool = False,
    ):
        """Resize ``Head.Conv`` to match the checkpoint, then delegate to ``super().load``.

        Args:
            state_dict: checkpoint mapping; the head weight is read from
                ``state_dict["Model"]["Unet_TS"]["Head.Conv.weight"]``.
            init: forwarded unchanged to ``super().load``.
            ema: forwarded unchanged to ``super().load``.
        """
        # NOTE(review): the checkpoint key "Unet_TS" is hard-coded.
        weight = state_dict["Model"]["Unet_TS"]["Head.Conv.weight"]
        # Conv3d weights are laid out (out_channels, in_channels, kD, kH, kW).
        nb_class, in_channels = weight.shape[:2]
        # NOTE(review): the replacement conv is created with torch defaults
        # (CPU, default dtype); if load can run after the model is moved to a
        # device, this layer would need moving too — confirm.
        self["Head"].add_module("Conv", _head_conv(in_channels, nb_class))
        super().load(state_dict, init, ema)