Valentin Boussot
commited on
Commit
·
6ec9956
1
Parent(s):
a1cf612
Add TotalSegmentator KonfAI weights
Browse files- Build.py +82 -0
- Destination_Unet_1.txt +100 -0
- Destination_Unet_2.txt +82 -0
- Destination_Unet_3.txt +64 -0
- M291.pt +3 -0
- M292.pt +3 -0
- M293.pt +3 -0
- M294.pt +3 -0
- M295.pt +3 -0
- M297.pt +3 -0
- M298.pt +3 -0
- M730.pt +3 -0
- M731.pt +3 -0
- M732.pt +3 -0
- M733.pt +3 -0
- M850.pt +3 -0
- M851.pt +3 -0
- M852.pt +3 -0
- M853.pt +3 -0
- Model.py +53 -0
- Prediction_CT.yml +99 -0
- Prediction_MR.yml +92 -0
Build.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import requests
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import zipfile
|
| 5 |
+
import shutil
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import os
|
| 8 |
+
from functools import partial
|
| 9 |
+
from Model import Unet_TS
|
| 10 |
+
|
| 11 |
+
def convert_torchScript_full(model_name: str, model: torch.nn.Module, type: int, url: str) -> None:
|
| 12 |
+
state_dict = download(url)
|
| 13 |
+
tmp = {}
|
| 14 |
+
with open("Destination_Unet_{}.txt".format(type)) as f2:
|
| 15 |
+
it = iter(state_dict.keys())
|
| 16 |
+
for l1 in f2:
|
| 17 |
+
key = next(it)
|
| 18 |
+
while "decoder.seg_layers" in key:
|
| 19 |
+
if type == 1:
|
| 20 |
+
if "decoder.seg_layers.4" in key :
|
| 21 |
+
break
|
| 22 |
+
if type == 2:
|
| 23 |
+
if "decoder.seg_layers.3" in key:
|
| 24 |
+
break
|
| 25 |
+
if type == 3:
|
| 26 |
+
if "decoder.seg_layers.2" in key:
|
| 27 |
+
break
|
| 28 |
+
key = next(it)
|
| 29 |
+
|
| 30 |
+
while "all_modules" in key or "decoder.encoder" in key:
|
| 31 |
+
key = next(it)
|
| 32 |
+
tmp[l1.replace("\n", "")] = state_dict[key]
|
| 33 |
+
|
| 34 |
+
model.load_state_dict(tmp)
|
| 35 |
+
torch.save({"Model" : {"Unet_TS" : tmp}}, f"{model_name}.pt")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def download(url: str) -> dict[str, torch.Tensor]:
|
| 40 |
+
with open(url.split("/")[-1], 'wb') as f:
|
| 41 |
+
with requests.get(url, stream=True) as r:
|
| 42 |
+
r.raise_for_status()
|
| 43 |
+
|
| 44 |
+
total_size = int(r.headers.get('content-length', 0))
|
| 45 |
+
progress_bar = tqdm(total=total_size, unit='B', unit_scale=True, desc="Downloading")
|
| 46 |
+
for chunk in r.iter_content(chunk_size=8192 * 16):
|
| 47 |
+
progress_bar.update(len(chunk))
|
| 48 |
+
f.write(chunk)
|
| 49 |
+
progress_bar.close()
|
| 50 |
+
with zipfile.ZipFile(url.split("/")[-1], 'r') as zip_f:
|
| 51 |
+
zip_f.extractall(url.split("/")[-1].replace(".zip", ""))
|
| 52 |
+
os.remove(url.split("/")[-1])
|
| 53 |
+
state_dict = torch.load(next(Path(url.split("/")[-1].replace(".zip", "")).rglob("checkpoint_final.pth"), None), weights_only=False)["network_weights"]
|
| 54 |
+
shutil.rmtree(url.split("/")[-1].replace(".zip", ""))
|
| 55 |
+
return state_dict
|
| 56 |
+
|
| 57 |
+
url = "https://github.com/wasserth/TotalSegmentator/releases/download/"
|
| 58 |
+
|
| 59 |
+
UnetCPP_1 = partial(Unet_TS, channels = [1,32,64,128,256,320,320])
|
| 60 |
+
UnetCPP_2 = partial(Unet_TS, channels = [1,32,64,128,256,320])
|
| 61 |
+
UnetCPP_3 = partial(Unet_TS, channels = [1,32,64,128,256])
|
| 62 |
+
|
| 63 |
+
models = {
|
| 64 |
+
"M291" : (UnetCPP_1(nb_class=25), 1, url+"v2.0.0-weights/Dataset291_TotalSegmentator_part1_organs_1559subj.zip"),
|
| 65 |
+
"M292" : (UnetCPP_1(nb_class=27), 1, url+"v2.0.0-weights/Dataset292_TotalSegmentator_part2_vertebrae_1532subj.zip"),
|
| 66 |
+
"M293" : (UnetCPP_1(nb_class=19), 1, url+"v2.0.0-weights/Dataset293_TotalSegmentator_part3_cardiac_1559subj.zip"),
|
| 67 |
+
"M294" : (UnetCPP_1(nb_class=24), 1, url+"v2.0.0-weights/Dataset294_TotalSegmentator_part4_muscles_1559subj.zip"),
|
| 68 |
+
"M295" : (UnetCPP_1(nb_class=27), 1, url+"v2.0.0-weights/Dataset295_TotalSegmentator_part5_ribs_1559subj.zip"),
|
| 69 |
+
"M297" : (UnetCPP_2(nb_class=118), 2, url+"v2.0.4-weights/Dataset297_TotalSegmentator_total_3mm_1559subj_v204.zip"),
|
| 70 |
+
"M298" : (UnetCPP_2(nb_class=118), 2, url+"v2.0.0-weights/Dataset298_TotalSegmentator_total_6mm_1559subj.zip"),
|
| 71 |
+
"M730" : (UnetCPP_1(nb_class=30, mri = True), 1, url+"v2.2.0-weights/Dataset730_TotalSegmentatorMRI_part1_organs_495subj.zip"),
|
| 72 |
+
"M731" : (UnetCPP_1(nb_class=28, mri = True), 1, url+"v2.2.0-weights/Dataset731_TotalSegmentatorMRI_part2_muscles_495subj.zip"),
|
| 73 |
+
"M732" : (UnetCPP_2(nb_class=57), 2, url+"v2.2.0-weights/Dataset732_TotalSegmentatorMRI_total_3mm_495subj.zip"),
|
| 74 |
+
"M733" : (UnetCPP_3(nb_class=57), 3, url+"v2.2.0-weights/Dataset733_TotalSegmentatorMRI_total_6mm_495subj.zip"),
|
| 75 |
+
"M850" : (UnetCPP_1(nb_class=30, mri = True), 1, url+"v2.5.0-weights/Dataset850_TotalSegMRI_part1_organs_1088subj.zip"),
|
| 76 |
+
"M851" : (UnetCPP_1(nb_class=22, mri = True), 1, url+"v2.5.0-weights/Dataset851_TotalSegMRI_part2_muscles_1088subj.zip"),
|
| 77 |
+
"M852" : (UnetCPP_2(nb_class=51), 2, url+"v2.5.0-weights/Dataset852_TotalSegMRI_total_3mm_1088subj.zip"),
|
| 78 |
+
"M853" : (UnetCPP_3(nb_class=51), 3, url+"v2.5.0-weights/Dataset853_TotalSegMRI_total_6mm_1088subj.zip")}
|
| 79 |
+
|
| 80 |
+
if __name__ == "__main__":
|
| 81 |
+
for name, model in models.items():
|
| 82 |
+
convert_torchScript_full(name, model[0], model[1], model[2])
|
Destination_Unet_1.txt
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
UNetBlock.DownConvBlock.Conv_0.weight
|
| 2 |
+
UNetBlock.DownConvBlock.Conv_0.bias
|
| 3 |
+
UNetBlock.DownConvBlock.Norm_0.weight
|
| 4 |
+
UNetBlock.DownConvBlock.Norm_0.bias
|
| 5 |
+
UNetBlock.DownConvBlock.Conv_1.weight
|
| 6 |
+
UNetBlock.DownConvBlock.Conv_1.bias
|
| 7 |
+
UNetBlock.DownConvBlock.Norm_1.weight
|
| 8 |
+
UNetBlock.DownConvBlock.Norm_1.bias
|
| 9 |
+
UNetBlock.UNetBlock.DownConvBlock.Conv_0.weight
|
| 10 |
+
UNetBlock.UNetBlock.DownConvBlock.Conv_0.bias
|
| 11 |
+
UNetBlock.UNetBlock.DownConvBlock.Norm_0.weight
|
| 12 |
+
UNetBlock.UNetBlock.DownConvBlock.Norm_0.bias
|
| 13 |
+
UNetBlock.UNetBlock.DownConvBlock.Conv_1.weight
|
| 14 |
+
UNetBlock.UNetBlock.DownConvBlock.Conv_1.bias
|
| 15 |
+
UNetBlock.UNetBlock.DownConvBlock.Norm_1.weight
|
| 16 |
+
UNetBlock.UNetBlock.DownConvBlock.Norm_1.bias
|
| 17 |
+
UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_0.weight
|
| 18 |
+
UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_0.bias
|
| 19 |
+
UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_0.weight
|
| 20 |
+
UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_0.bias
|
| 21 |
+
UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_1.weight
|
| 22 |
+
UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_1.bias
|
| 23 |
+
UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_1.weight
|
| 24 |
+
UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_1.bias
|
| 25 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_0.weight
|
| 26 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_0.bias
|
| 27 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_0.weight
|
| 28 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_0.bias
|
| 29 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_1.weight
|
| 30 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_1.bias
|
| 31 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_1.weight
|
| 32 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_1.bias
|
| 33 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_0.weight
|
| 34 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_0.bias
|
| 35 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_0.weight
|
| 36 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_0.bias
|
| 37 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_1.weight
|
| 38 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_1.bias
|
| 39 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_1.weight
|
| 40 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_1.bias
|
| 41 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_0.weight
|
| 42 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_0.bias
|
| 43 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_0.weight
|
| 44 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_0.bias
|
| 45 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_1.weight
|
| 46 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_1.bias
|
| 47 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_1.weight
|
| 48 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_1.bias
|
| 49 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Conv_0.weight
|
| 50 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Conv_0.bias
|
| 51 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Norm_0.weight
|
| 52 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Norm_0.bias
|
| 53 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Conv_1.weight
|
| 54 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Conv_1.bias
|
| 55 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Norm_1.weight
|
| 56 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Norm_1.bias
|
| 57 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Conv_0.weight
|
| 58 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Conv_0.bias
|
| 59 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Norm_0.weight
|
| 60 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Norm_0.bias
|
| 61 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Conv_1.weight
|
| 62 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Conv_1.bias
|
| 63 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Norm_1.weight
|
| 64 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Norm_1.bias
|
| 65 |
+
UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Conv_0.weight
|
| 66 |
+
UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Conv_0.bias
|
| 67 |
+
UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Norm_0.weight
|
| 68 |
+
UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Norm_0.bias
|
| 69 |
+
UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Conv_1.weight
|
| 70 |
+
UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Conv_1.bias
|
| 71 |
+
UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Norm_1.weight
|
| 72 |
+
UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Norm_1.bias
|
| 73 |
+
UNetBlock.UNetBlock.UpConvBlock.Conv_0.weight
|
| 74 |
+
UNetBlock.UNetBlock.UpConvBlock.Conv_0.bias
|
| 75 |
+
UNetBlock.UNetBlock.UpConvBlock.Norm_0.weight
|
| 76 |
+
UNetBlock.UNetBlock.UpConvBlock.Norm_0.bias
|
| 77 |
+
UNetBlock.UNetBlock.UpConvBlock.Conv_1.weight
|
| 78 |
+
UNetBlock.UNetBlock.UpConvBlock.Conv_1.bias
|
| 79 |
+
UNetBlock.UNetBlock.UpConvBlock.Norm_1.weight
|
| 80 |
+
UNetBlock.UNetBlock.UpConvBlock.Norm_1.bias
|
| 81 |
+
UNetBlock.UpConvBlock.Conv_0.weight
|
| 82 |
+
UNetBlock.UpConvBlock.Conv_0.bias
|
| 83 |
+
UNetBlock.UpConvBlock.Norm_0.weight
|
| 84 |
+
UNetBlock.UpConvBlock.Norm_0.bias
|
| 85 |
+
UNetBlock.UpConvBlock.Conv_1.weight
|
| 86 |
+
UNetBlock.UpConvBlock.Conv_1.bias
|
| 87 |
+
UNetBlock.UpConvBlock.Norm_1.weight
|
| 88 |
+
UNetBlock.UpConvBlock.Norm_1.bias
|
| 89 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.CONV_TRANSPOSE.weight
|
| 90 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.CONV_TRANSPOSE.bias
|
| 91 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.CONV_TRANSPOSE.weight
|
| 92 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.CONV_TRANSPOSE.bias
|
| 93 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.CONV_TRANSPOSE.weight
|
| 94 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.CONV_TRANSPOSE.bias
|
| 95 |
+
UNetBlock.UNetBlock.UNetBlock.CONV_TRANSPOSE.weight
|
| 96 |
+
UNetBlock.UNetBlock.UNetBlock.CONV_TRANSPOSE.bias
|
| 97 |
+
UNetBlock.UNetBlock.CONV_TRANSPOSE.weight
|
| 98 |
+
UNetBlock.UNetBlock.CONV_TRANSPOSE.bias
|
| 99 |
+
Head.Conv.weight
|
| 100 |
+
Head.Conv.bias
|
Destination_Unet_2.txt
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
UNetBlock.DownConvBlock.Conv_0.weight
|
| 2 |
+
UNetBlock.DownConvBlock.Conv_0.bias
|
| 3 |
+
UNetBlock.DownConvBlock.Norm_0.weight
|
| 4 |
+
UNetBlock.DownConvBlock.Norm_0.bias
|
| 5 |
+
UNetBlock.DownConvBlock.Conv_1.weight
|
| 6 |
+
UNetBlock.DownConvBlock.Conv_1.bias
|
| 7 |
+
UNetBlock.DownConvBlock.Norm_1.weight
|
| 8 |
+
UNetBlock.DownConvBlock.Norm_1.bias
|
| 9 |
+
UNetBlock.UNetBlock.DownConvBlock.Conv_0.weight
|
| 10 |
+
UNetBlock.UNetBlock.DownConvBlock.Conv_0.bias
|
| 11 |
+
UNetBlock.UNetBlock.DownConvBlock.Norm_0.weight
|
| 12 |
+
UNetBlock.UNetBlock.DownConvBlock.Norm_0.bias
|
| 13 |
+
UNetBlock.UNetBlock.DownConvBlock.Conv_1.weight
|
| 14 |
+
UNetBlock.UNetBlock.DownConvBlock.Conv_1.bias
|
| 15 |
+
UNetBlock.UNetBlock.DownConvBlock.Norm_1.weight
|
| 16 |
+
UNetBlock.UNetBlock.DownConvBlock.Norm_1.bias
|
| 17 |
+
UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_0.weight
|
| 18 |
+
UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_0.bias
|
| 19 |
+
UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_0.weight
|
| 20 |
+
UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_0.bias
|
| 21 |
+
UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_1.weight
|
| 22 |
+
UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_1.bias
|
| 23 |
+
UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_1.weight
|
| 24 |
+
UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_1.bias
|
| 25 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_0.weight
|
| 26 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_0.bias
|
| 27 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_0.weight
|
| 28 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_0.bias
|
| 29 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_1.weight
|
| 30 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_1.bias
|
| 31 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_1.weight
|
| 32 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_1.bias
|
| 33 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_0.weight
|
| 34 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_0.bias
|
| 35 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_0.weight
|
| 36 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_0.bias
|
| 37 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_1.weight
|
| 38 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_1.bias
|
| 39 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_1.weight
|
| 40 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_1.bias
|
| 41 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Conv_0.weight
|
| 42 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Conv_0.bias
|
| 43 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Norm_0.weight
|
| 44 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Norm_0.bias
|
| 45 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Conv_1.weight
|
| 46 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Conv_1.bias
|
| 47 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Norm_1.weight
|
| 48 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Norm_1.bias
|
| 49 |
+
UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Conv_0.weight
|
| 50 |
+
UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Conv_0.bias
|
| 51 |
+
UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Norm_0.weight
|
| 52 |
+
UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Norm_0.bias
|
| 53 |
+
UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Conv_1.weight
|
| 54 |
+
UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Conv_1.bias
|
| 55 |
+
UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Norm_1.weight
|
| 56 |
+
UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Norm_1.bias
|
| 57 |
+
UNetBlock.UNetBlock.UpConvBlock.Conv_0.weight
|
| 58 |
+
UNetBlock.UNetBlock.UpConvBlock.Conv_0.bias
|
| 59 |
+
UNetBlock.UNetBlock.UpConvBlock.Norm_0.weight
|
| 60 |
+
UNetBlock.UNetBlock.UpConvBlock.Norm_0.bias
|
| 61 |
+
UNetBlock.UNetBlock.UpConvBlock.Conv_1.weight
|
| 62 |
+
UNetBlock.UNetBlock.UpConvBlock.Conv_1.bias
|
| 63 |
+
UNetBlock.UNetBlock.UpConvBlock.Norm_1.weight
|
| 64 |
+
UNetBlock.UNetBlock.UpConvBlock.Norm_1.bias
|
| 65 |
+
UNetBlock.UpConvBlock.Conv_0.weight
|
| 66 |
+
UNetBlock.UpConvBlock.Conv_0.bias
|
| 67 |
+
UNetBlock.UpConvBlock.Norm_0.weight
|
| 68 |
+
UNetBlock.UpConvBlock.Norm_0.bias
|
| 69 |
+
UNetBlock.UpConvBlock.Conv_1.weight
|
| 70 |
+
UNetBlock.UpConvBlock.Conv_1.bias
|
| 71 |
+
UNetBlock.UpConvBlock.Norm_1.weight
|
| 72 |
+
UNetBlock.UpConvBlock.Norm_1.bias
|
| 73 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.CONV_TRANSPOSE.weight
|
| 74 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.UNetBlock.CONV_TRANSPOSE.bias
|
| 75 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.CONV_TRANSPOSE.weight
|
| 76 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.CONV_TRANSPOSE.bias
|
| 77 |
+
UNetBlock.UNetBlock.UNetBlock.CONV_TRANSPOSE.weight
|
| 78 |
+
UNetBlock.UNetBlock.UNetBlock.CONV_TRANSPOSE.bias
|
| 79 |
+
UNetBlock.UNetBlock.CONV_TRANSPOSE.weight
|
| 80 |
+
UNetBlock.UNetBlock.CONV_TRANSPOSE.bias
|
| 81 |
+
Head.Conv.weight
|
| 82 |
+
Head.Conv.bias
|
Destination_Unet_3.txt
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
UNetBlock.DownConvBlock.Conv_0.weight
|
| 2 |
+
UNetBlock.DownConvBlock.Conv_0.bias
|
| 3 |
+
UNetBlock.DownConvBlock.Norm_0.weight
|
| 4 |
+
UNetBlock.DownConvBlock.Norm_0.bias
|
| 5 |
+
UNetBlock.DownConvBlock.Conv_1.weight
|
| 6 |
+
UNetBlock.DownConvBlock.Conv_1.bias
|
| 7 |
+
UNetBlock.DownConvBlock.Norm_1.weight
|
| 8 |
+
UNetBlock.DownConvBlock.Norm_1.bias
|
| 9 |
+
UNetBlock.UNetBlock.DownConvBlock.Conv_0.weight
|
| 10 |
+
UNetBlock.UNetBlock.DownConvBlock.Conv_0.bias
|
| 11 |
+
UNetBlock.UNetBlock.DownConvBlock.Norm_0.weight
|
| 12 |
+
UNetBlock.UNetBlock.DownConvBlock.Norm_0.bias
|
| 13 |
+
UNetBlock.UNetBlock.DownConvBlock.Conv_1.weight
|
| 14 |
+
UNetBlock.UNetBlock.DownConvBlock.Conv_1.bias
|
| 15 |
+
UNetBlock.UNetBlock.DownConvBlock.Norm_1.weight
|
| 16 |
+
UNetBlock.UNetBlock.DownConvBlock.Norm_1.bias
|
| 17 |
+
UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_0.weight
|
| 18 |
+
UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_0.bias
|
| 19 |
+
UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_0.weight
|
| 20 |
+
UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_0.bias
|
| 21 |
+
UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_1.weight
|
| 22 |
+
UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_1.bias
|
| 23 |
+
UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_1.weight
|
| 24 |
+
UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_1.bias
|
| 25 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_0.weight
|
| 26 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_0.bias
|
| 27 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_0.weight
|
| 28 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_0.bias
|
| 29 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_1.weight
|
| 30 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Conv_1.bias
|
| 31 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_1.weight
|
| 32 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.DownConvBlock.Norm_1.bias
|
| 33 |
+
UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Conv_0.weight
|
| 34 |
+
UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Conv_0.bias
|
| 35 |
+
UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Norm_0.weight
|
| 36 |
+
UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Norm_0.bias
|
| 37 |
+
UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Conv_1.weight
|
| 38 |
+
UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Conv_1.bias
|
| 39 |
+
UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Norm_1.weight
|
| 40 |
+
UNetBlock.UNetBlock.UNetBlock.UpConvBlock.Norm_1.bias
|
| 41 |
+
UNetBlock.UNetBlock.UpConvBlock.Conv_0.weight
|
| 42 |
+
UNetBlock.UNetBlock.UpConvBlock.Conv_0.bias
|
| 43 |
+
UNetBlock.UNetBlock.UpConvBlock.Norm_0.weight
|
| 44 |
+
UNetBlock.UNetBlock.UpConvBlock.Norm_0.bias
|
| 45 |
+
UNetBlock.UNetBlock.UpConvBlock.Conv_1.weight
|
| 46 |
+
UNetBlock.UNetBlock.UpConvBlock.Conv_1.bias
|
| 47 |
+
UNetBlock.UNetBlock.UpConvBlock.Norm_1.weight
|
| 48 |
+
UNetBlock.UNetBlock.UpConvBlock.Norm_1.bias
|
| 49 |
+
UNetBlock.UpConvBlock.Conv_0.weight
|
| 50 |
+
UNetBlock.UpConvBlock.Conv_0.bias
|
| 51 |
+
UNetBlock.UpConvBlock.Norm_0.weight
|
| 52 |
+
UNetBlock.UpConvBlock.Norm_0.bias
|
| 53 |
+
UNetBlock.UpConvBlock.Conv_1.weight
|
| 54 |
+
UNetBlock.UpConvBlock.Conv_1.bias
|
| 55 |
+
UNetBlock.UpConvBlock.Norm_1.weight
|
| 56 |
+
UNetBlock.UpConvBlock.Norm_1.bias
|
| 57 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.CONV_TRANSPOSE.weight
|
| 58 |
+
UNetBlock.UNetBlock.UNetBlock.UNetBlock.CONV_TRANSPOSE.bias
|
| 59 |
+
UNetBlock.UNetBlock.UNetBlock.CONV_TRANSPOSE.weight
|
| 60 |
+
UNetBlock.UNetBlock.UNetBlock.CONV_TRANSPOSE.bias
|
| 61 |
+
UNetBlock.UNetBlock.CONV_TRANSPOSE.weight
|
| 62 |
+
UNetBlock.UNetBlock.CONV_TRANSPOSE.bias
|
| 63 |
+
Head.Conv.weight
|
| 64 |
+
Head.Conv.bias
|
M291.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:78065b11af6e339feaa49c72c3aa45d78ae272be1432b27bc17285428a851e6a
|
| 3 |
+
size 124807929
|
M292.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1e0a8d8c17572b392bfdbc480edd169ac46fd8962fdf30b81de2406a6f32f275
|
| 3 |
+
size 124808185
|
M293.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9638deb528c1dafc8edfe5287404eefc764770c913810628cd62b23b5950a4c0
|
| 3 |
+
size 124807161
|
M294.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:11879030dad493ff88d49cbcbbbbee1f5671c79df743ae931b1bd7bbf7302b5e
|
| 3 |
+
size 124807801
|
M295.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4e586a1082dcdb3a054e2ccd670d2506c7436e1f0f6f616cdc771bfe5c41d948
|
| 3 |
+
size 124808185
|
M297.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6490a06ac9d242757af99fc674ba8f74ecb5c62009ecda6c62cda217b9cbbcb7
|
| 3 |
+
size 66225317
|
M298.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8f04169298561a6e1fcad186ef1a24623b48edf5f776710793de13a5a9c6d40e
|
| 3 |
+
size 66225317
|
M730.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:462c3c7ca2eb1a40a0984d5b1f004cfbf617c3e9218af65f442006e132fd594e
|
| 3 |
+
size 123170169
|
M731.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f9998319e380221bc0e69dae2d0f120f7e9d7c38254dc982afd05aa9f44877b1
|
| 3 |
+
size 123169913
|
M732.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ce4d777f8da2ebdeb2a4b19f976f8eac36bba252c5f63fccf76f9e3dd8bc296c
|
| 3 |
+
size 66217253
|
M733.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3340e70c92bb9c095f1db8e139319efeb7fb8e444f04ffbd8421c70e657fc444
|
| 3 |
+
size 22435409
|
M850.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dbcf0d153ff695b696748e6192a912b305f0b8deba3d04fec1d685ea39df4c37
|
| 3 |
+
size 123170169
|
M851.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:570f5d944fd07c26c89dd41cd105dbe470286bc91d01d92e3ba63d9904f03903
|
| 3 |
+
size 123169145
|
M852.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:22f810c3c77079ca3a196f577d9ab21733aa1e94cb006896c9d35d7d7f726157
|
| 3 |
+
size 66216485
|
M853.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:024cadec3839c0b552fb07b4ba866073430c3856d06ef5a2e0bb2896d20e7596
|
| 3 |
+
size 22434641
|
Model.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from konfai.network import network, blocks
|
| 3 |
+
|
| 4 |
+
class ConvBlock(network.ModuleArgsDict):
|
| 5 |
+
def __init__(self, in_channels : int, out_channels : int, stride: int = 1 ) -> None:
|
| 6 |
+
super().__init__()
|
| 7 |
+
self.add_module("Conv_0", torch.nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=True))
|
| 8 |
+
self.add_module("Norm_0", torch.nn.InstanceNorm3d(num_features=out_channels, affine=True))
|
| 9 |
+
self.add_module("Activation_0", torch.nn.LeakyReLU(negative_slope=0.01, inplace=True))
|
| 10 |
+
self.add_module("Conv_1", torch.nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=True))
|
| 11 |
+
self.add_module("Norm_1", torch.nn.InstanceNorm3d(num_features=out_channels, affine=True))
|
| 12 |
+
self.add_module("Activation_1", torch.nn.LeakyReLU(negative_slope=0.01, inplace=True))
|
| 13 |
+
|
| 14 |
+
class UNetHead(network.ModuleArgsDict):
|
| 15 |
+
def __init__(self, in_channels: int, nb_class: int) -> None:
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.add_module("Conv", torch.nn.Conv3d(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
|
| 18 |
+
|
| 19 |
+
class UNetBlock(network.ModuleArgsDict):
|
| 20 |
+
|
| 21 |
+
def __init__(self, channels, nb_class: int, mri: bool, i : int = 0) -> None:
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.add_module("DownConvBlock", ConvBlock(in_channels=channels[0], out_channels=channels[1], stride= ((1,2,2) if mri and i > 4 else 2) if i>0 else 1))
|
| 24 |
+
|
| 25 |
+
if len(channels) > 2:
|
| 26 |
+
self.add_module("UNetBlock", UNetBlock(channels[1:], nb_class, mri, i+1))
|
| 27 |
+
self.add_module("UpConvBlock", ConvBlock(in_channels=channels[1]*2, out_channels=channels[1]))
|
| 28 |
+
|
| 29 |
+
if i > 0:
|
| 30 |
+
self.add_module("CONV_TRANSPOSE", torch.nn.ConvTranspose3d(in_channels = channels[1], out_channels = channels[0], kernel_size = (1,2,2) if mri and i > 4 else 2, stride = (1,2,2) if mri and i > 4 else 2, padding = 0))
|
| 31 |
+
self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])
|
| 32 |
+
|
| 33 |
+
class Unet_TS(network.Network):
|
| 34 |
+
|
| 35 |
+
def __init__(self,
|
| 36 |
+
optimizer: network.OptimizerLoader = network.OptimizerLoader(),
|
| 37 |
+
schedulers: dict[str, network.LRSchedulersLoader] = {
|
| 38 |
+
"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
|
| 39 |
+
},
|
| 40 |
+
outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
|
| 41 |
+
channels = [1, 32, 64, 128, 320, 320],
|
| 42 |
+
nb_class: int = 41,
|
| 43 |
+
mri: bool = False) -> None:
|
| 44 |
+
super().__init__(
|
| 45 |
+
in_channels=channels[0],
|
| 46 |
+
optimizer=optimizer,
|
| 47 |
+
schedulers=schedulers,
|
| 48 |
+
outputs_criterions=outputs_criterions,
|
| 49 |
+
patch=None,
|
| 50 |
+
dim=3,
|
| 51 |
+
)
|
| 52 |
+
self.add_module("UNetBlock", UNetBlock(channels, nb_class, mri))
|
| 53 |
+
self.add_module("Head", UNetHead(channels[1], nb_class))
|
Prediction_CT.yml
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Predictor:
|
| 2 |
+
Model:
|
| 3 |
+
classpath: Model:Unet_TS
|
| 4 |
+
Unet_TS:
|
| 5 |
+
outputs_criterions: None
|
| 6 |
+
channels:
|
| 7 |
+
- 1
|
| 8 |
+
- 32
|
| 9 |
+
- 64
|
| 10 |
+
- 128
|
| 11 |
+
- 256
|
| 12 |
+
- 320
|
| 13 |
+
- 320
|
| 14 |
+
nb_class: 25
|
| 15 |
+
mri: false
|
| 16 |
+
Dataset:
|
| 17 |
+
groups_src:
|
| 18 |
+
Volume:
|
| 19 |
+
groups_dest:
|
| 20 |
+
Volume:
|
| 21 |
+
transforms:
|
| 22 |
+
TensorCast:
|
| 23 |
+
dtype: float32
|
| 24 |
+
inverse: false
|
| 25 |
+
Canonical:
|
| 26 |
+
inverse: true
|
| 27 |
+
Clip:
|
| 28 |
+
min_value: -1024
|
| 29 |
+
max_value: 276
|
| 30 |
+
save_clip_min: false
|
| 31 |
+
save_clip_max: false
|
| 32 |
+
mask: None
|
| 33 |
+
Standardize:
|
| 34 |
+
lazy: false
|
| 35 |
+
mean: -370.00039267657144
|
| 36 |
+
std: 436.5998675471528
|
| 37 |
+
mask: None
|
| 38 |
+
inverse: true
|
| 39 |
+
ResampleToResolution:
|
| 40 |
+
spacing:
|
| 41 |
+
- 1.5
|
| 42 |
+
- 1.5
|
| 43 |
+
- 1.5
|
| 44 |
+
inverse: true
|
| 45 |
+
Padding:
|
| 46 |
+
padding:
|
| 47 |
+
- 32
|
| 48 |
+
- 32
|
| 49 |
+
- 32
|
| 50 |
+
- 32
|
| 51 |
+
- 32
|
| 52 |
+
- 32
|
| 53 |
+
mode: constant
|
| 54 |
+
inverse: true
|
| 55 |
+
patch_transforms: None
|
| 56 |
+
is_input: true
|
| 57 |
+
augmentations: None
|
| 58 |
+
Patch:
|
| 59 |
+
patch_size:
|
| 60 |
+
- 96
|
| 61 |
+
- 128
|
| 62 |
+
- 160
|
| 63 |
+
overlap: 32
|
| 64 |
+
mask: None
|
| 65 |
+
pad_value: 0
|
| 66 |
+
extend_slice: 0
|
| 67 |
+
subset: None
|
| 68 |
+
filter: None
|
| 69 |
+
dataset_filenames:
|
| 70 |
+
- ./Dataset/:nii.gz
|
| 71 |
+
use_cache: false
|
| 72 |
+
batch_size: 1
|
| 73 |
+
outputs_dataset:
|
| 74 |
+
Head:Conv:
|
| 75 |
+
OutputDataset:
|
| 76 |
+
name_class: OutSameAsGroupDataset
|
| 77 |
+
before_reduction_transforms: None
|
| 78 |
+
after_reduction_transforms: None
|
| 79 |
+
final_transforms:
|
| 80 |
+
Softmax:
|
| 81 |
+
dim: 0
|
| 82 |
+
Argmax:
|
| 83 |
+
dim: 0
|
| 84 |
+
TensorCast:
|
| 85 |
+
dtype: uint8
|
| 86 |
+
inverse: true
|
| 87 |
+
dataset_filename: Dataset:mha
|
| 88 |
+
group: MASK
|
| 89 |
+
same_as_group: Volume:Volume
|
| 90 |
+
patch_combine: Cosinus
|
| 91 |
+
inverse_transform: true
|
| 92 |
+
reduction: Mean
|
| 93 |
+
train_name: Curvas
|
| 94 |
+
manual_seed: 32
|
| 95 |
+
gpu_checkpoints: None
|
| 96 |
+
images_log: None
|
| 97 |
+
combine: Mean
|
| 98 |
+
autocast: false
|
| 99 |
+
data_log: None
|
Prediction_MR.yml
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Predictor:
|
| 2 |
+
Model:
|
| 3 |
+
classpath: Model:Unet_TS
|
| 4 |
+
Unet_TS:
|
| 5 |
+
outputs_criterions: None
|
| 6 |
+
channels:
|
| 7 |
+
- 1
|
| 8 |
+
- 32
|
| 9 |
+
- 64
|
| 10 |
+
- 128
|
| 11 |
+
- 256
|
| 12 |
+
- 320
|
| 13 |
+
- 320
|
| 14 |
+
nb_class: 30
|
| 15 |
+
Dataset:
|
| 16 |
+
groups_src:
|
| 17 |
+
Volume:
|
| 18 |
+
groups_dest:
|
| 19 |
+
Volume:
|
| 20 |
+
transforms:
|
| 21 |
+
TensorCast:
|
| 22 |
+
dtype: float32
|
| 23 |
+
inverse: false
|
| 24 |
+
Canonical:
|
| 25 |
+
inverse: true
|
| 26 |
+
Standardize:
|
| 27 |
+
lazy: false
|
| 28 |
+
mean: None
|
| 29 |
+
std: None
|
| 30 |
+
mask: None
|
| 31 |
+
inverse: false
|
| 32 |
+
ResampleToResolution:
|
| 33 |
+
spacing:
|
| 34 |
+
- 1.5
|
| 35 |
+
- 1.5
|
| 36 |
+
- 1.5
|
| 37 |
+
inverse: true
|
| 38 |
+
Padding:
|
| 39 |
+
padding:
|
| 40 |
+
- 32
|
| 41 |
+
- 32
|
| 42 |
+
- 32
|
| 43 |
+
- 32
|
| 44 |
+
- 32
|
| 45 |
+
- 32
|
| 46 |
+
mode: constant
|
| 47 |
+
inverse: true
|
| 48 |
+
patch_transforms: None
|
| 49 |
+
is_input: true
|
| 50 |
+
augmentations: None
|
| 51 |
+
Patch:
|
| 52 |
+
patch_size:
|
| 53 |
+
- 96
|
| 54 |
+
- 128
|
| 55 |
+
- 160
|
| 56 |
+
overlap: 32
|
| 57 |
+
mask: None
|
| 58 |
+
pad_value: 0
|
| 59 |
+
extend_slice: 0
|
| 60 |
+
subset: None
|
| 61 |
+
filter: None
|
| 62 |
+
dataset_filenames:
|
| 63 |
+
- ./Dataset/:nii.gz
|
| 64 |
+
use_cache: false
|
| 65 |
+
batch_size: 1
|
| 66 |
+
outputs_dataset:
|
| 67 |
+
Head:Conv:
|
| 68 |
+
OutputDataset:
|
| 69 |
+
name_class: OutSameAsGroupDataset
|
| 70 |
+
before_reduction_transforms: None
|
| 71 |
+
after_reduction_transforms: None
|
| 72 |
+
final_transforms:
|
| 73 |
+
Softmax:
|
| 74 |
+
dim: 0
|
| 75 |
+
Argmax:
|
| 76 |
+
dim: 0
|
| 77 |
+
TensorCast:
|
| 78 |
+
dtype: uint8
|
| 79 |
+
inverse: true
|
| 80 |
+
dataset_filename: Dataset:mha
|
| 81 |
+
group: MASK
|
| 82 |
+
same_as_group: Volume:Volume
|
| 83 |
+
patch_combine: Cosinus
|
| 84 |
+
inverse_transform: true
|
| 85 |
+
reduction: Mean
|
| 86 |
+
train_name: Curvas
|
| 87 |
+
manual_seed: 32
|
| 88 |
+
gpu_checkpoints: None
|
| 89 |
+
images_log: None
|
| 90 |
+
combine: Mean
|
| 91 |
+
autocast: false
|
| 92 |
+
data_log: None
|