
Code example for stop_gradient() #580

Answered by awni
RahulBhalley asked this question in Q&A
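import mlx.core as mx

def fun(x):
  # Same forward value as exp(x) + exp(x), but no gradient flows through the second term.
  return mx.exp(x) + mx.stop_gradient(mx.exp(x))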

print(mx.grad(fun)(mx.array(1.0)))

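This prints array(2.71828, dtype=float32).

So the gradient flows only through the first mx.exp. The stop_gradient term still contributes its value to the output, but it contributes nothing to the derivative.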
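Worked by hand, writing sg for stop_gradient (identity on the forward pass, treated as a constant when differentiating), the derivative at x = 1 matches the printed value:

$$
f(x) = e^{x} + \operatorname{sg}\!\left(e^{x}\right), \qquad f'(x) = e^{x} + 0 = e^{x}, \qquad f'(1) = e \approx 2.71828
$$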

Compare to:

def fun(x):
  return mx.exp(x) + mx.exp(x)

print(mx.grad(fun)(mx.array(1.0)))

This gives twice the first result, since the gradient now flows through both paths: array(5.43656, dtype=float32).
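To make the mechanism concrete, here is a minimal pure-Python sketch of what stop_gradient does to derivatives, assuming a toy forward-mode autodiff built on dual numbers. It is not how MLX implements `mx.grad` (which differentiates in reverse mode); the names `Dual`, `exp`, `stop_gradient`, and `grad` below are hypothetical stand-ins that only mirror the MLX calls in spirit. Each value carries its derivative (tangent), and `stop_gradient` keeps the value while zeroing the tangent, which reproduces both numbers above.

```python
import math
from dataclasses import dataclass


@dataclass
class Dual:
    """Hypothetical toy type: a value paired with its derivative w.r.t. the input."""
    val: float
    tan: float

    def __add__(self, other: "Dual") -> "Dual":
        # Sum rule: derivatives of the two terms add.
        return Dual(self.val + other.val, self.tan + other.tan)


def exp(x: Dual) -> Dual:
    e = math.exp(x.val)
    # Chain rule: d/dx exp(u) = exp(u) * du/dx
    return Dual(e, e * x.tan)


def stop_gradient(x: Dual) -> Dual:
    # Forward value passes through unchanged; the derivative contribution is dropped.
    return Dual(x.val, 0.0)


def grad(f):
    """Hypothetical helper: derivative of a scalar function via forward mode."""
    def df(x: float) -> float:
        # Seed the input tangent with 1.0 and read the output tangent.
        return f(Dual(x, 1.0)).tan
    return df


def fun_stopped(x: Dual) -> Dual:
    return exp(x) + stop_gradient(exp(x))


def fun_both(x: Dual) -> Dual:
    return exp(x) + exp(x)


print(round(grad(fun_stopped)(1.0), 5))  # 2.71828
print(round(grad(fun_both)(1.0), 5))     # 5.43656
```

Both functions compute the same forward value (2e at x = 1); only the derivative differs, which is exactly the distinction the MLX example draws.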

Answer selected by RahulBhalley