Help using adahessian in TensorFlow #16

Open
Cyberface opened this issue Feb 8, 2021 · 5 comments

@Cyberface

Hi, I'm trying to use adahessian in TensorFlow for a simple regression experiment but I'm having trouble.

I have a simple example in this google colab notebook: https://colab.research.google.com/drive/1EbKZ0YHhyu6g8chFlJD74dzWrbo82mbV?usp=sharing

I am getting the following error:

ValueError: Variable <tf.Variable 'dense_12/kernel:0' shape=(1, 100) dtype=float32> has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.

In the notebook I first write a little training loop that works with standard optimisers such as Adam; see "example training with Adam".

Then, in the next section, "example training with Adahessian", I basically copy the previous code and make a few modifications to try to get Adahessian to work.

Specifically, I only changed

from

optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

to

optimizer = AdaHessian(learning_rate=0.01)

and from

grads = tape.gradient(current_loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))

to

grads, Hessian = optimizer.get_gradients_hessian(current_loss, model.trainable_weights)
optimizer.apply_gradients_hessian(zip(grads, Hessian, model.trainable_weights))

Can anyone see what I'm doing wrong? Thanks!

@KimiHsieh

I have the same issue

@KimiHsieh

Environment: adahessian_tf/environment.yml

I think the issue is caused by grads = gradients.gradients(loss, params) in get_gradients_hessian(self, loss, params): if you check the return value of gradients.gradients(loss, params), it is None. But I don't know how to fix this issue.

My training step is below, followed by the get_gradients_hessian implementation for reference:

@tf.function
def step(x, y, training):
    with tf.GradientTape() as tape:
        r_loss = tf.add_n(model.losses)
        outs = model(x, training)
        c_loss = loss_fn(y, outs)
        loss = c_loss + r_loss
    if training:
        if optim_method != 'adahessian':
            gradients = tape.gradient(loss, model.trainable_weights)
            optimizer.apply_gradients(zip(gradients, model.trainable_weights))
        else:
            gradients, Hessian = optimizer.get_gradients_hessian(loss, model.trainable_weights)
            optimizer.apply_gradients_hessian(zip(gradients, Hessian, model.trainable_weights))

def get_gradients_hessian(self, loss, params):
    """Returns gradients and Hessian of `loss` with respect to `params`.
    Arguments:
        loss: Loss tensor.
        params: List of variables.
    Returns:
        List of gradient and Hessian tensors.
    Raises:
        ValueError: In case any gradient cannot be computed (e.g. if gradient
            function not implemented).
    """
    params = nest.flatten(params)
    with backend.get_graph().as_default(), backend.name_scope(self._name + "/gradients"):
        grads = gradients.gradients(loss, params)
        for grad, param in zip(grads, params):
            if grad is None:
                raise ValueError("Variable {} has `None` for gradient. "
                                 "Please make sure that all of your ops have a "
                                 "gradient defined (i.e. are differentiable). "
                                 "Common ops without gradient: "
                                 "K.argmax, K.round, K.eval.".format(param))
        # WARNING: for now we do not support gradient clip
        # grads = self._clip_gradients(grads)
        v = [np.random.uniform(0, 1, size=p.shape) for p in params]
        for vi in v:
            vi[vi < 0.5] = -1
            vi[vi >= 0.5] = 1
        v = [tf.convert_to_tensor(vi, dtype=tf.dtypes.float32) for vi in v]
        vprod = tf.reduce_sum([tf.reduce_sum(vi * grad) for vi, grad in zip(v, grads)])
        Hv = gradients.gradients(vprod, params)
        Hd = [tf.abs(Hvi * vi) for Hvi, vi in zip(Hv, v)]
    return grads, Hd
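For context, gradients.gradients above is TensorFlow's graph-mode tf.gradients. A minimal sketch, independent of AdaHessian, of how that API behaves eagerly versus inside a @tf.function (the exact error text may vary by TF version):

import tensorflow as tf

x = tf.Variable(3.0)

# Called eagerly, tf.gradients is not supported: TF raises a RuntimeError
# pointing you to tf.GradientTape instead.
try:
    tf.gradients(x * x, [x])
except RuntimeError as err:
    print("eager:", err)

# Inside a tf.function, the same call runs in a graph context and works.
@tf.function
def grad_fn():
    y = x * x
    return tf.gradients(y, [x])

print("graph:", grad_fn())  # [tf.Tensor(6.0, ...)]

Presumably, because get_gradients_hessian wraps the call in backend.get_graph().as_default(), the eagerly computed loss is captured into that graph with no connection to the variables, which would explain why the symptom here is None gradients rather than a RuntimeError.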

@lpupp

lpupp commented Aug 12, 2024

I have the same issue. Has this been solved?

@lpupp

lpupp commented Aug 19, 2024

In the original post:

I have a simple example in this google colab notebook: https://colab.research.google.com/drive/1EbKZ0YHhyu6g8chFlJD74dzWrbo82mbV?usp=sharing

I am getting the following error ...

Wrapping the train function in a @tf.function decorator solves it for me.

tf.gradients is only valid in a graph context (see the official docs), which I guess is what was missing.
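To make this concrete, here is a sketch of the original notebook's training step with that fix applied. It only relies on the get_gradients_hessian / apply_gradients_hessian calls already quoted in this thread; the model definition, loss_fn, and train_step are illustrative placeholders for what the notebook defines, and AdaHessian is assumed to be imported from adahessian_tf:

import tensorflow as tf
# AdaHessian is assumed to be imported from adahessian_tf, as in the
# notebook (exact import path omitted here).

# Illustrative stand-in for the notebook's small regression model.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(100, activation='relu', input_shape=(1,)),
    tf.keras.layers.Dense(1),
])

optimizer = AdaHessian(learning_rate=0.01)
loss_fn = tf.keras.losses.MeanSquaredError()

@tf.function  # the key change: the step now runs in a graph context
def train_step(x, y):
    current_loss = loss_fn(y, model(x, training=True))
    grads, hessian = optimizer.get_gradients_hessian(
        current_loss, model.trainable_weights)
    optimizer.apply_gradients_hessian(
        zip(grads, hessian, model.trainable_weights))
    return current_loss

Without the @tf.function decorator, the same code reproduces the None-gradient ValueError from the original post.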
