This repository has been archived by the owner on Nov 1, 2021. It is now read-only.

Make autograd ignore computing gradients for certain functions #152

Open
MLpatzer opened this issue Sep 3, 2016 · 2 comments
Comments

@MLpatzer

MLpatzer commented Sep 3, 2016

Hi, I am a bit new to torch and autograd, so sorry if this is an obvious question.

I am trying to use a function not supported by autograd, namely torch.sort(), in order to backprop only through certain components of my loss function. Since the sort just produces indices that are used later on, it seems there should be some way to make autograd execute it but leave it out of the gradient computation. I've tried a few variants of this but can't get it to work.

The pseudocode would look something like this:

function features(x)
   ...
end

function Loss(x, y)
   local feat = features(x)

   -- torch.sort returns (sorted, indices); I only need the indices,
   -- and I want autograd to ignore this line but still execute it
   local _, inds = torch.sort(feat, true)

   x = x:index(1, inds[{{1,20}}])
   y = y:index(1, inds[{{1,20}}])
   ...
end

@allanzelener
Contributor

I think the easiest way to do this is to define an nn module that returns inds on the forward pass and just returns gradOutput on the backward pass, and then functionalize that module.
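
For concreteness, a rough, untested sketch of the kind of module described above (the class name nn.SortIndices and the functionalize call are mine, not from this thread; check the autograd README for the exact functionalize signature):

require 'nn'
local autograd = require 'autograd'

-- Forward pass returns the (descending) sort indices; backward pass just
-- passes gradOutput straight through, treating the op as identity.
local SortIndices, parent = torch.class('nn.SortIndices', 'nn.Module')

function SortIndices:__init()
   parent.__init(self)
end

function SortIndices:updateOutput(input)
   local _, inds = torch.sort(input, true)
   self.output = inds
   return self.output
end

function SortIndices:updateGradInput(input, gradOutput)
   self.gradInput = gradOutput
   return self.gradInput
end

-- Then wrap it for use inside an autograd function, e.g.
-- (call form from memory -- see the autograd README):
-- local sortIndices = autograd.functionalize(nn.SortIndices())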

@alexbw
Collaborator

alexbw commented Sep 7, 2016

I think you'd want something like this, outside of your loss function:

-- Build your own custom module
local grad = require 'autograd'
mymodule = {}

-- Make a sort function in your module, which just returns the sorted array
mymodule.sort = function(x)
   local sorted, _ = torch.sort(x)
   return sorted
end

-- Define the gradient for the module
grad.overload.module("mymodule", mymodule, function(module)
   module.gradient("sort", {
      function(g, ans, x)
         -- You need to define the gradient of sort here.
         -- It would involve "unsorting" g and returning it (see the sketch below)
      end
   })
end)

You can't ignore the gradient of sort, because it changes the indexing for subsequent use of your data, which needs to be undone in the backwards pass.
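
As an illustration of that "unsorting" step, here is my own untested sketch for a 1-D tensor (not part of the original reply; a real autograd gradient definition may need to stick to out-of-place, autograd-supported ops, so this only shows the arithmetic). Since sorted[i] == x[inds[i]], the gradient with respect to x[inds[i]] is just g[i]:

-- Untested sketch of the gradient body for mymodule.sort, 1-D case only
local sortGrad = function(g, ans, x)
   local _, inds = torch.sort(x)     -- recover the permutation used in the forward pass
   local gx = torch.zeros(x:size())
   gx:indexCopy(1, inds, g)          -- scatter back: gx[inds[i]] = g[i]
   return gx
end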

Also, in autograd, you don't want to use function calls of the form x:blah(), because they are often in-place, and we don't support that in autograd. You'll want to rewrite it as torch.blah(x).
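
For example, the method form of add mutates x in place, while the torch.* form returns a new tensor:

-- In-place method form, not supported by autograd:
x:add(y)
-- Out-of-place form, returns a new tensor:
x = torch.add(x, y)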
