forked from M3DV/FracNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
98 lines (78 loc) · 2.86 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from functools import partial
import torch
from torch import save, nn
from torchinfo import summary
from fastai.basic_train import Learner
from fastai.train import ShowGraph
from fastai.data_block import DataBunch
from torch import optim
from dataset.fracnet_dataset import FracNetTrainDataset
from dataset import transforms as tsfm
from utils.metrics import dice, recall, precision, fbeta_score
from model.unet import UNet
from model.losses import MixLoss, DiceLoss
def main(args):
    """Train the FracNet UNet with fastai v1 and optionally save its weights.

    Steps: build a ``UNet`` wrapped in ``nn.DataParallel`` on the GPU, print a
    torchinfo summary, build training/validation loaders from
    ``FracNetTrainDataset``, fit with SGD on a one-cycle schedule for 200
    epochs, and -- if requested -- write the unwrapped model's ``state_dict``
    to ``./output_train/model_weights.pth`` (the directory is created if
    missing).

    Args:
        args: Namespace providing ``train_image_dir``, ``train_label_dir``,
            ``val_image_dir``, ``val_label_dir`` (directory paths forwarded
            to ``FracNetTrainDataset``) and ``save_model`` (any truthy value
            triggers saving the weights).

    Note:
        Requires a CUDA device: the model is moved with ``.cuda()``
        unconditionally.
    """
    import os  # local import, matching the script's local-import style

    train_image_dir = args.train_image_dir
    train_label_dir = args.train_label_dir
    val_image_dir = args.val_image_dir
    val_label_dir = args.val_label_dir

    batch_size = 4
    num_workers = 4
    optimizer = optim.SGD
    # NOTE(review): MixLoss arguments look like (loss, weight, loss, weight),
    # i.e. 0.5 * BCE + 1 * Dice -- confirm against model.losses.MixLoss.
    criterion = MixLoss(nn.BCEWithLogitsLoss(), 0.5, DiceLoss(), 1)

    # Threshold shared by the recall / precision / F-beta metrics.
    thresh = 0.1
    recall_partial = partial(recall, thresh=thresh)
    precision_partial = partial(precision, thresh=thresh)
    fbeta_score_partial = partial(fbeta_score, thresh=thresh)

    # NOTE(review): presumably 1 input channel -> 1 output channel; confirm in
    # model.unet.UNet.
    model = UNet(1, 1, first_out_channels=16)
    model = nn.DataParallel(model.cuda())

    # Print summary of model. The dummy input is presumably
    # (batch, channel, D, H, W) -- confirm against the UNet definition.
    x = torch.rand(16, 1, 64, 64, 64)
    print(str(summary(model=model, input_data=x, verbose=0)))

    # NOTE(review): presumably clips intensities to [-200, 1000] and rescales
    # that range -- confirm in dataset.transforms.
    transforms = [
        tsfm.Window(-200, 1000),
        tsfm.MinMaxNorm(-200, 1000),
    ]

    ds_train = FracNetTrainDataset(train_image_dir, train_label_dir,
                                   transforms=transforms)
    # NOTE(review): the third positional argument is presumably ``shuffle``;
    # confirm against FracNetTrainDataset.get_dataloader.
    dl_train = FracNetTrainDataset.get_dataloader(ds_train, batch_size, False,
                                                  num_workers)
    ds_val = FracNetTrainDataset(val_image_dir, val_label_dir,
                                 transforms=transforms)
    dl_val = FracNetTrainDataset.get_dataloader(ds_val, batch_size, False,
                                                num_workers)
    databunch = DataBunch(dl_train, dl_val,
                          collate_fn=FracNetTrainDataset.collate_fn)

    learn = Learner(
        databunch,
        model,
        opt_func=optimizer,
        loss_func=criterion,
        metrics=[dice, recall_partial, precision_partial, fbeta_score_partial],
    )

    # 200 epochs with a maximum LR of 1e-1. With pct_start=0 no iterations are
    # spent in the increasing-LR phase (see fastai v1 ``fit_one_cycle``).
    learn.fit_one_cycle(
        200,
        1e-1,
        pct_start=0,
        div_factor=1000,
        callbacks=[
            ShowGraph(learn),
        ],
    )

    if args.save_model:
        # torch.save does not create parent directories. Without this, a
        # missing ./output_train would fail only after training has finished.
        output_dir = "./output_train"
        os.makedirs(output_dir, exist_ok=True)
        # Save the wrapped module's weights so the state-dict keys carry no
        # DataParallel "module." prefix.
        save(model.module.state_dict(),
             os.path.join(output_dir, "model_weights.pth"))
def parse_bool_arg(value):
    """Convert a command-line truth value to ``bool``.

    Accepts (case-insensitively) ``1/true/t/yes/y/on`` as True and
    ``0/false/f/no/n/off`` as False. ``bool`` inputs are returned unchanged.

    Raises:
        ValueError: if ``value`` is not a recognised truth value; argparse
            reports this as an invalid argument value.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


def build_parser():
    """Build the command-line parser for the training script.

    ``--save_model`` is converted with ``parse_bool_arg`` instead of being
    kept as a raw string, because any non-empty string (including
    ``"False"``) is truthy. The flag can be given bare (``--save_model``) or
    with an explicit value (``--save_model true`` / ``--save_model false``).
    It defaults to False.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--train_image_dir", required=True,
                        help="The training image nii directory.")
    parser.add_argument("--train_label_dir", required=True,
                        help="The training label nii directory.")
    parser.add_argument("--val_image_dir", required=True,
                        help="The validation image nii directory.")
    parser.add_argument("--val_label_dir", required=True,
                        help="The validation label nii directory.")
    parser.add_argument("--save_model", type=parse_bool_arg, nargs="?",
                        const=True, default=False,
                        help="Whether to save the trained model.")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    main(args)