Valentin Boussot commited on
Commit
2d34814
·
1 Parent(s): 6ec9956

Add fast inference

Browse files
Build.py CHANGED
@@ -34,8 +34,6 @@ def convert_torchScript_full(model_name: str, model: torch.nn.Module, type: int,
34
  model.load_state_dict(tmp)
35
  torch.save({"Model" : {"Unet_TS" : tmp}}, f"{model_name}.pt")
36
 
37
-
38
-
39
  def download(url: str) -> dict[str, torch.Tensor]:
40
  with open(url.split("/")[-1], 'wb') as f:
41
  with requests.get(url, stream=True) as r:
 
34
  model.load_state_dict(tmp)
35
  torch.save({"Model" : {"Unet_TS" : tmp}}, f"{model_name}.pt")
36
 
 
 
37
  def download(url: str) -> dict[str, torch.Tensor]:
38
  with open(url.split("/")[-1], 'wb') as f:
39
  with requests.get(url, stream=True) as r:
Model.py CHANGED
@@ -1,5 +1,6 @@
1
  import torch
2
  from konfai.network import network, blocks
 
3
 
4
  class ConvBlock(network.ModuleArgsDict):
5
  def __init__(self, in_channels : int, out_channels : int, stride: int = 1 ) -> None:
@@ -15,15 +16,16 @@ class UNetHead(network.ModuleArgsDict):
15
  def __init__(self, in_channels: int, nb_class: int) -> None:
16
  super().__init__()
17
  self.add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
 
18
 
19
  class UNetBlock(network.ModuleArgsDict):
20
 
21
- def __init__(self, channels, nb_class: int, mri: bool, i : int = 0) -> None:
22
  super().__init__()
23
  self.add_module("DownConvBlock", ConvBlock(in_channels=channels[0], out_channels=channels[1], stride= ((1,2,2) if mri and i > 4 else 2) if i>0 else 1))
24
 
25
  if len(channels) > 2:
26
- self.add_module("UNetBlock", UNetBlock(channels[1:], nb_class, mri, i+1))
27
  self.add_module("UpConvBlock", ConvBlock(in_channels=channels[1]*2, out_channels=channels[1]))
28
 
29
  if i > 0:
@@ -39,7 +41,6 @@ class Unet_TS(network.Network):
39
  },
40
  outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
41
  channels = [1, 32, 64, 128, 320, 320],
42
- nb_class: int = 41,
43
  mri: bool = False) -> None:
44
  super().__init__(
45
  in_channels=channels[0],
@@ -49,5 +50,29 @@ class Unet_TS(network.Network):
49
  patch=None,
50
  dim=3,
51
  )
52
- self.add_module("UNetBlock", UNetBlock(channels, nb_class, mri))
53
- self.add_module("Head", UNetHead(channels[1], nb_class))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  from konfai.network import network, blocks
3
+ from konfai.predictor import Reduction
4
 
5
  class ConvBlock(network.ModuleArgsDict):
6
  def __init__(self, in_channels : int, out_channels : int, stride: int = 1 ) -> None:
 
16
  def __init__(self, in_channels: int, nb_class: int) -> None:
17
  super().__init__()
18
  self.add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
19
+ self.add_module("Softmax", torch.nn.Softmax(dim=1))
20
 
21
  class UNetBlock(network.ModuleArgsDict):
22
 
23
+ def __init__(self, channels, mri: bool, i : int = 0) -> None:
24
  super().__init__()
25
  self.add_module("DownConvBlock", ConvBlock(in_channels=channels[0], out_channels=channels[1], stride= ((1,2,2) if mri and i > 4 else 2) if i>0 else 1))
26
 
27
  if len(channels) > 2:
28
+ self.add_module("UNetBlock", UNetBlock(channels[1:], mri, i+1))
29
  self.add_module("UpConvBlock", ConvBlock(in_channels=channels[1]*2, out_channels=channels[1]))
30
 
31
  if i > 0:
 
41
  },
42
  outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
43
  channels = [1, 32, 64, 128, 320, 320],
 
44
  mri: bool = False) -> None:
45
  super().__init__(
46
  in_channels=channels[0],
 
50
  patch=None,
51
  dim=3,
52
  )
53
+ self.add_module("UNetBlock", UNetBlock(channels, mri))
54
+ self.add_module("Head", UNetHead(channels[1], 42))
55
+
56
+ def load(
57
+ self,
58
+ state_dict: dict[str, dict[str, torch.Tensor] | int],
59
+ init: bool = True,
60
+ ema: bool = False,
61
+ ):
62
+ nb_class, in_channels = state_dict["Model"]["Unet_TS"]["Head.Conv.weight"].shape[:2]
63
+ self["Head"].add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
64
+ super().load(state_dict, init, ema)
65
+
66
+ class Combine(Reduction):
67
+
68
+ def __init__(self):
69
+ pass
70
+
71
+ def __call__(self, tensors: list[torch.Tensor]) -> torch.Tensor:
72
+ fg_all = torch.cat([p[:, 1:, ...] for p in tensors], dim=1)
73
+ sum_fg = fg_all.sum(dim=1, keepdim=True)
74
+ bg = (1.0 - sum_fg).clamp(min=1e-6)
75
+
76
+ probs = torch.cat([bg, fg_all], dim=1)
77
+
78
+ return probs / probs.sum(dim=1, keepdim=True)
Prediction_CT.yml CHANGED
@@ -11,7 +11,6 @@ Predictor:
11
  - 256
12
  - 320
13
  - 320
14
- nb_class: 25
15
  mri: false
16
  Dataset:
17
  groups_src:
@@ -85,15 +84,15 @@ Predictor:
85
  dtype: uint8
86
  inverse: true
87
  dataset_filename: Dataset:mha
88
- group: MASK
89
  same_as_group: Volume:Volume
90
  patch_combine: Cosinus
91
  inverse_transform: true
92
  reduction: Mean
93
- train_name: Curvas
94
  manual_seed: 32
95
  gpu_checkpoints: None
96
  images_log: None
97
- combine: Mean
98
  autocast: false
99
  data_log: None
 
11
  - 256
12
  - 320
13
  - 320
 
14
  mri: false
15
  Dataset:
16
  groups_src:
 
84
  dtype: uint8
85
  inverse: true
86
  dataset_filename: Dataset:mha
87
+ group: Seg
88
  same_as_group: Volume:Volume
89
  patch_combine: Cosinus
90
  inverse_transform: true
91
  reduction: Mean
92
+ train_name: TotalSegmentator
93
  manual_seed: 32
94
  gpu_checkpoints: None
95
  images_log: None
96
+ combine: Model:Combine
97
  autocast: false
98
  data_log: None
Prediction_CT_Fast.yml ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Predictor:
2
+ Model:
3
+ classpath: Model:Unet_TS
4
+ Unet_TS:
5
+ outputs_criterions: None
6
+ channels:
7
+ - 1
8
+ - 32
9
+ - 64
10
+ - 128
11
+ - 256
12
+ - 320
13
+ mri: false
14
+ Dataset:
15
+ groups_src:
16
+ Volume:
17
+ groups_dest:
18
+ Volume:
19
+ transforms:
20
+ TensorCast:
21
+ dtype: float32
22
+ inverse: false
23
+ Canonical:
24
+ inverse: true
25
+ Clip:
26
+ min_value: -1024
27
+ max_value: 276
28
+ save_clip_min: false
29
+ save_clip_max: false
30
+ mask: None
31
+ Standardize:
32
+ lazy: false
33
+ mean: -370.00039267657144
34
+ std: 436.5998675471528
35
+ mask: None
36
+ inverse: true
37
+ ResampleToResolution:
38
+ spacing:
39
+ - 3
40
+ - 3
41
+ - 3
42
+ inverse: true
43
+ Padding:
44
+ padding:
45
+ - 32
46
+ - 32
47
+ - 32
48
+ - 32
49
+ - 32
50
+ - 32
51
+ mode: constant
52
+ inverse: true
53
+ patch_transforms: None
54
+ is_input: true
55
+ augmentations: None
56
+ Patch:
57
+ patch_size:
58
+ - 96
59
+ - 128
60
+ - 160
61
+ overlap: 32
62
+ mask: None
63
+ pad_value: 0
64
+ extend_slice: 0
65
+ subset: None
66
+ filter: None
67
+ dataset_filenames:
68
+ - ./Dataset/:nii.gz
69
+ use_cache: false
70
+ batch_size: 1
71
+ outputs_dataset:
72
+ Head:Conv:
73
+ OutputDataset:
74
+ name_class: OutSameAsGroupDataset
75
+ before_reduction_transforms: None
76
+ after_reduction_transforms: None
77
+ final_transforms:
78
+ Softmax:
79
+ dim: 0
80
+ Argmax:
81
+ dim: 0
82
+ TensorCast:
83
+ dtype: uint8
84
+ inverse: true
85
+ dataset_filename: Dataset:mha
86
+ group: Seg
87
+ same_as_group: Volume:Volume
88
+ patch_combine: Cosinus
89
+ inverse_transform: true
90
+ reduction: Mean
91
+ train_name: TotalSegmentator
92
+ manual_seed: 32
93
+ gpu_checkpoints: None
94
+ images_log: None
95
+ combine: Model:Combine
96
+ autocast: false
97
+ data_log: None
Prediction_MR.yml CHANGED
@@ -11,7 +11,7 @@ Predictor:
11
  - 256
12
  - 320
13
  - 320
14
- nb_class: 30
15
  Dataset:
16
  groups_src:
17
  Volume:
@@ -78,15 +78,15 @@ Predictor:
78
  dtype: uint8
79
  inverse: true
80
  dataset_filename: Dataset:mha
81
- group: MASK
82
  same_as_group: Volume:Volume
83
  patch_combine: Cosinus
84
  inverse_transform: true
85
  reduction: Mean
86
- train_name: Curvas
87
  manual_seed: 32
88
  gpu_checkpoints: None
89
  images_log: None
90
- combine: Mean
91
  autocast: false
92
  data_log: None
 
11
  - 256
12
  - 320
13
  - 320
14
+ mri: true
15
  Dataset:
16
  groups_src:
17
  Volume:
 
78
  dtype: uint8
79
  inverse: true
80
  dataset_filename: Dataset:mha
81
+ group: Seg
82
  same_as_group: Volume:Volume
83
  patch_combine: Cosinus
84
  inverse_transform: true
85
  reduction: Mean
86
+ train_name: TotalSegmentator
87
  manual_seed: 32
88
  gpu_checkpoints: None
89
  images_log: None
90
+ combine: Model:Combine
91
  autocast: false
92
  data_log: None
Prediction_MR_Fast.yml ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Predictor:
2
+ Model:
3
+ classpath: Model:Unet_TS
4
+ Unet_TS:
5
+ outputs_criterions: None
6
+ channels:
7
+ - 1
8
+ - 32
9
+ - 64
10
+ - 128
11
+ - 256
12
+ - 320
13
+ mri: true
14
+ Dataset:
15
+ groups_src:
16
+ Volume:
17
+ groups_dest:
18
+ Volume:
19
+ transforms:
20
+ TensorCast:
21
+ dtype: float32
22
+ inverse: false
23
+ Canonical:
24
+ inverse: true
25
+ Standardize:
26
+ lazy: false
27
+ mean: None
28
+ std: None
29
+ mask: None
30
+ inverse: false
31
+ ResampleToResolution:
32
+ spacing:
33
+ - 3
34
+ - 3
35
+ - 3
36
+ inverse: true
37
+ Padding:
38
+ padding:
39
+ - 32
40
+ - 32
41
+ - 32
42
+ - 32
43
+ - 32
44
+ - 32
45
+ mode: constant
46
+ inverse: true
47
+ patch_transforms: None
48
+ is_input: true
49
+ augmentations: None
50
+ Patch:
51
+ patch_size:
52
+ - 96
53
+ - 128
54
+ - 160
55
+ overlap: 32
56
+ mask: None
57
+ pad_value: 0
58
+ extend_slice: 0
59
+ subset: None
60
+ filter: None
61
+ dataset_filenames:
62
+ - ./Dataset/:nii.gz
63
+ use_cache: false
64
+ batch_size: 1
65
+ outputs_dataset:
66
+ Head:Conv:
67
+ OutputDataset:
68
+ name_class: OutSameAsGroupDataset
69
+ before_reduction_transforms: None
70
+ after_reduction_transforms: None
71
+ final_transforms:
72
+ Softmax:
73
+ dim: 0
74
+ Argmax:
75
+ dim: 0
76
+ TensorCast:
77
+ dtype: uint8
78
+ inverse: true
79
+ dataset_filename: Dataset:mha
80
+ group: Seg
81
+ same_as_group: Volume:Volume
82
+ patch_combine: Cosinus
83
+ inverse_transform: true
84
+ reduction: Mean
85
+ train_name: TotalSegmentator
86
+ manual_seed: 32
87
+ gpu_checkpoints: None
88
+ images_log: None
89
+ combine: Model:Combine
90
+ autocast: false
91
+ data_log: None