return total grad norm in torchrec grad clipping (pytorch#2507)
Summary:
Pull Request resolved: pytorch#2507

This keeps the return value consistent with torch.nn.utils.clip_grad_norm_, which also returns the total gradient norm.

Reviewed By: awgu

Differential Revision: D64712277

fbshipit-source-id: 689e02bd21dc37568c3347d5b0833c573f042c15
weifengpy authored and facebook-github-bot committed Oct 23, 2024
1 parent 1a57ce1 commit b34da0d
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion torchrec/optim/clipping.py
@@ -136,7 +136,7 @@ def step(self, closure: Any = None) -> None:
         self._step_num += 1
 
     @torch.no_grad()
-    def clip_grad_norm_(self) -> None:
+    def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         """Clip the gradient norm of all parameters."""
         max_norm = self._max_gradient
         norm_type = float(self._norm_type)
@@ -224,6 +224,7 @@ def clip_grad_norm_(self) -> None:
         clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
         clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
         torch._foreach_mul_(all_grads, clip_coef_clamped)
+        return total_grad_norm
 
 
 def _batch_cal_norm(
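
For reference, a minimal sketch of the convention this commit aligns with: torch.nn.utils.clip_grad_norm_ returns the total gradient norm (as a tensor), which callers typically log or feed into monitoring. After this change, the torchrec clip_grad_norm_ above returns the total gradient norm as well instead of None; the torchrec-specific optimizer wiring is omitted here since it depends on how the optimizer is wrapped.

    # Sketch of the torch.nn.utils.clip_grad_norm_ behavior being matched.
    import torch

    model = torch.nn.Linear(8, 4)
    loss = model(torch.randn(2, 8)).sum()
    loss.backward()

    # Returns the total norm of the gradients (before clipping) as a tensor.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    print(f"total grad norm: {total_norm.item():.4f}")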
