Skip to content

Commit

Permalink
Flag to disable tracking for training task (#1104)
Browse files Browse the repository at this point in the history
* Flag to disable tracking for training task

Signed-off-by: Sachidanand Alle <[email protected]>

* Flag to disable tracking for training task

Signed-off-by: Sachidanand Alle <[email protected]>

Signed-off-by: Sachidanand Alle <[email protected]>
  • Loading branch information
SachidanandAlle authored Oct 28, 2022
1 parent a1a2b84 commit 166ff7c
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion monailabel/tasks/train/basic_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
PersistentDataset,
SmartCacheDataset,
ThreadDataLoader,
get_track_meta,
partition_dataset,
set_track_meta,
)
from monai.engines import SupervisedEvaluator, SupervisedTrainer
from monai.handlers import (
Expand Down Expand Up @@ -167,6 +169,7 @@ def __init__(
self._find_unused_parameters = find_unused_parameters
self._load_strict = load_strict
self._labels = [] if labels is None else [labels] if isinstance(labels, str) else labels
self._disable_tracking = kwargs.get("disable_tracking", True)

@abstractmethod
def network(self, context: Context):
Expand Down Expand Up @@ -455,7 +458,16 @@ def train(self, rank, world_size, request, datalist):

# Finalize and Run Training
self.finalize(context)
context.trainer.run()

# Disable Tracking
meta_tracking = get_track_meta()
if self._disable_tracking:
set_track_meta(False)

try:
context.trainer.run()
finally:
set_track_meta(meta_tracking) # In case of same process (restore)

if context.multi_gpu:
torch.distributed.destroy_process_group()
Expand Down

0 comments on commit 166ff7c

Please sign in to comment.