diff --git a/megatron/training.py b/megatron/training.py
index ef32cd3856..7f160c775c 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -747,7 +747,10 @@ def train_step(forward_step_func, data_iterator,
     # Update learning rate.
     if args.deepspeed:
         skipped_iter = 0
-        grad_norm = None
+        if hasattr(model[0], 'get_global_grad_norm'):
+            grad_norm = model[0].get_global_grad_norm()
+        else:
+            grad_norm = None
         num_zeros_in_grad = None
 
         loss_reduced = {}