Skip to content

Commit

Permalink
syntax fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mmarcinkiewicz committed Dec 12, 2023
1 parent 5969941 commit b5cc52e
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 5 deletions.
1 change: 0 additions & 1 deletion image_segmentation/pytorch/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@


def main():
mllog.config(filename=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'unet3d.log'))
mllog.config(filename=os.path.join("/results", 'unet3d.log'))
mllogger = mllog.get_mllogger()
mllogger.logger.propagate = False
Expand Down
7 changes: 3 additions & 4 deletions image_segmentation/pytorch/runtime/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def lr_warmup(optimizer, init_lr, lr, current_samples, warmup_samples):


def lr_decay(optimizer, lr_decay_samples, lr_decay_factor, total_samples):
if total_samples > lr_decay_samples[0]:
if len(lr_decay_samples) > 0 and total_samples > lr_decay_samples[0]:
lr_decay_samples = lr_decay_samples[1:]
for param_group in optimizer.param_groups:
param_group['lr'] *= lr_decay_factor
Expand Down Expand Up @@ -73,7 +73,7 @@ def train(flags, model, train_loader, val_loader, loss_fn, score_fn, device, cal
while not diverged and not is_successful:
mllog_start(key=CONSTANTS.BLOCK_START, sync=False,
metadata={CONSTANTS.FIRST_EPOCH_NUM: total_samples,
CONSTANTS.EPOCH_COUNT: next_eval_at})
CONSTANTS.EPOCH_COUNT: EVALUATE_EVERY})

t0 = time()
while total_samples < next_eval_at:
Expand Down Expand Up @@ -112,7 +112,6 @@ def train(flags, model, train_loader, val_loader, loss_fn, score_fn, device, cal
optimizer.zero_grad()
iteration += 1

print(f"Throughput: {round(EVALUATE_EVERY / (time() - t0), 2)} samples/s. Time {time() - t0}")
# Evaluation
del output
if total_samples >= START_EVAL_AT:
Expand All @@ -134,7 +133,7 @@ def train(flags, model, train_loader, val_loader, loss_fn, score_fn, device, cal

mllog_end(key=CONSTANTS.BLOCK_STOP, sync=False,
metadata={CONSTANTS.FIRST_EPOCH_NUM: total_samples,
CONSTANTS.EPOCH_COUNT: next_eval_at})
CONSTANTS.EPOCH_COUNT: EVALUATE_EVERY})
next_eval_at += EVALUATE_EVERY

mllog_end(key=CONSTANTS.RUN_STOP, sync=True,
Expand Down

0 comments on commit b5cc52e

Please sign in to comment.