Skip to content

Commit

Permalink
grad_wei can't be NoneType when running with DeepSpeed, for zero3 wil…
Browse files Browse the repository at this point in the history
…l divided the gradient
  • Loading branch information
ys950902 committed Sep 20, 2024
1 parent fc989b8 commit 05779cf
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion megatron/core/tensor_parallel/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
args = get_args()
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias

Expand Down Expand Up @@ -368,7 +369,12 @@ def backward(ctx, grad_output):
# grad_weight = grad_output.t().matmul(total_input)
from megatron.core.tensor_parallel.weight_grad_store import WeightGradStore
WeightGradStore.put(total_input, grad_output, weight, gradientUpdateFunction)
grad_weight = None

if args.enable_zbh1_pipeline:
grad_weight = None
else:
grad_weight = weight.grad

grad_bias = grad_output.sum(dim=0) if use_bias else None

if ctx.sequence_parallel:
Expand Down

0 comments on commit 05779cf

Please sign in to comment.