The merge semantics for `Arc<Mutex<OwnedTape<_, _>>>` seem a bit unintuitive #841
FYI I hacked something together using …; the basic idea is to provide another layer of indirection (an outer `Arc<Mutex<_>>` around the existing `Arc<Mutex<OwnedTape<E, D>>>`) so we can mutate the pointer to the underlying `OwnedTape`:

```rust
// Merging with a NoneTape is a no-op: keep our tape as-is.
impl<E, D: Storage<E>> Merge<NoneTape> for Arc<Mutex<Arc<Mutex<OwnedTape<E, D>>>>> {
    fn merge(self, _: NoneTape) -> Self {
        self
    }
}

impl<E, D: Storage<E>> Merge<Self> for Arc<Mutex<Arc<Mutex<OwnedTape<E, D>>>>> {
    fn merge(self, other: Self) -> Self {
        // Nothing to do if both handles are the same outer cell.
        if !Arc::ptr_eq(&self, &other) {
            let pointer_lhs = self.lock().unwrap();
            let mut pointer_rhs = other.lock().unwrap();
            // Only move data if the outer cells point at different tapes
            // (locking the same inner Mutex twice would deadlock or panic).
            if !Arc::ptr_eq(&pointer_lhs, &pointer_rhs) {
                let mut lhs = pointer_lhs.lock().unwrap();
                let mut rhs = pointer_rhs.lock().unwrap();
                // Move the RHS gradients, leaf ids, and backward ops into the LHS tape.
                lhs.gradients
                    .gradient_by_id
                    .append(&mut rhs.gradients.gradient_by_id);
                if let Some(leafs) = &mut rhs.gradients.leaf_ids {
                    lhs.gradients
                        .leaf_ids
                        .get_or_insert_with(Default::default)
                        .append(leafs);
                }
                lhs.operations.append(&mut rhs.operations);
            }
            // Update the RHS so it points to the same underlying OwnedTape.
            *pointer_rhs = pointer_lhs.clone();
        }
        self
    }
}

impl<E, D: Storage<E>> Tape<E, D> for Arc<Mutex<Arc<Mutex<OwnedTape<E, D>>>>> {
    const OWNS_TAPE: bool = true;
    fn add_backward_op<F>(&mut self, operation: F)
    where
        F: 'static + FnOnce(&mut Gradients<E, D>) -> Result<(), D::Err>,
    {
        // Forward to the inner Arc<Mutex<OwnedTape>> this handle currently points at.
        let mut tape = self.lock().unwrap();
        tape.add_backward_op(operation);
    }
}
```

TBH anything with …
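The double-indirection idea in the hack above can be exercised without dfdx. This is a minimal std-only sketch, assuming a hypothetical `Tape` type (a `Vec<String>` of op names) in place of `OwnedTape<E, D>`; `Handle`, `new_handle`, `record`, and `merge` are illustrative names, not dfdx APIs. After the merge, a handle that still refers to the right-hand side sees the same single tape as the result.

```rust
use std::sync::{Arc, Mutex};

// Hypothetical stand-in for OwnedTape: just a list of recorded op names.
#[derive(Debug, Default)]
struct Tape {
    ops: Vec<String>,
}

// Outer cell: shared by every clone of a tensor's tape handle.
// Inner Arc<Mutex<Tape>>: the tape that handle currently records onto.
type Handle = Arc<Mutex<Arc<Mutex<Tape>>>>;

fn new_handle() -> Handle {
    Arc::new(Mutex::new(Arc::new(Mutex::new(Tape::default()))))
}

// Push an op onto whatever tape the handle currently points at.
fn record(handle: &Handle, op: &str) {
    let inner = handle.lock().unwrap();
    inner.lock().unwrap().ops.push(op.to_string());
}

// Move rhs's ops into lhs's tape, then repoint rhs's outer cell at lhs's tape.
fn merge(lhs: Handle, other: Handle) -> Handle {
    if !Arc::ptr_eq(&lhs, &other) {
        let pointer_lhs = lhs.lock().unwrap();
        let mut pointer_rhs = other.lock().unwrap();
        if !Arc::ptr_eq(&*pointer_lhs, &*pointer_rhs) {
            let mut l = pointer_lhs.lock().unwrap();
            let mut r = pointer_rhs.lock().unwrap();
            l.ops.append(&mut r.ops);
        }
        *pointer_rhs = Arc::clone(&*pointer_lhs);
    }
    lhs
}

fn main() {
    let x = new_handle();
    let y = new_handle();
    record(&x, "x_op");
    record(&y, "y_op");
    let y_alias = Arc::clone(&y); // e.g. another tensor still holding y's handle
    let z = merge(x, y);
    record(&z, "add");
    record(&y_alias, "mul");
    let zt = z.lock().unwrap();
    let yt = y_alias.lock().unwrap();
    assert!(Arc::ptr_eq(&*zt, &*yt)); // one shared tape
    println!("{:?}", zt.lock().unwrap().ops); // ["x_op", "y_op", "add", "mul"]
}
```

One caveat of this scheme as sketched: only the right-hand outer cell is repointed. A handle that was previously merged *into* the right-hand side still points at the old, now emptied, tape, so a complete fix would likely need something closer to union-find.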
So you would want both tapes to be merged together and have the same gradients?
The merge semantics for `Arc<Mutex<OwnedTape<_, _>>>` seem a bit unintuitive. In particular, one would hope that when these tapes are merged, they would essentially be replaced with one tape which is the union of the two input tapes.
But what actually happens is the left tape becomes the union and the right tape becomes the empty tape.
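To make the described behaviour concrete, here is a minimal std-only sketch, assuming a hypothetical `Tape` type (a `Vec<String>` of op names) in place of dfdx's `OwnedTape<E, D>` and a free function `merge` in place of the `Merge` trait. It mirrors the left-becomes-union, right-becomes-empty semantics; it is not dfdx's actual code.

```rust
use std::sync::{Arc, Mutex};

// Hypothetical stand-in for OwnedTape: a list of recorded op names.
#[derive(Debug, Default)]
struct Tape {
    ops: Vec<String>,
}

// rhs's ops are moved into lhs: lhs holds the union, rhs is left empty
// but is still a separate allocation that other handles may point at.
fn merge(lhs: Arc<Mutex<Tape>>, rhs: Arc<Mutex<Tape>>) -> Arc<Mutex<Tape>> {
    if !Arc::ptr_eq(&lhs, &rhs) {
        let mut l = lhs.lock().unwrap();
        let mut r = rhs.lock().unwrap();
        l.ops.append(&mut r.ops);
    }
    lhs
}

fn main() {
    let left = Arc::new(Mutex::new(Tape { ops: vec!["a".into()] }));
    let right = Arc::new(Mutex::new(Tape { ops: vec!["b".into()] }));
    let right_alias = Arc::clone(&right); // another tensor sharing `right`
    let merged = merge(left, right);
    println!("{:?}", merged.lock().unwrap().ops); // ["a", "b"]
    println!("{:?}", right_alias.lock().unwrap().ops); // []
}
```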
This can lead to a similar problem as with the plain `OwnedTape<_, _>`, where the gradient tape is partitioned across several objects, and it's up to the programmer to figure out which object has which part of the gradients.

For reference, here's the merge code for `Arc<Mutex<OwnedTape<_, _>>>`:

If we invoke this code by writing `let z = x + y.clone()` and then write `let zz = y * 2.0`, we'll end up with gradients partitioned between `z` and `zz`. The merge in `x + y.clone()` moves everything from the tape shared by `y` and `y.clone()` into `x`'s tape (now held by `z`), so `y * 2.0` records onto `y`'s emptied tape. Each of `z` and `zz` then holds an `Arc<Mutex<_>>` pointer to a distinct `OwnedTape<_, _>` object.

I'd personally find it more natural if the merge operation mutated the tapes of both arguments to make them both point to the same underlying data, so that we only have one `OwnedTape<_, _>` at the end of the day.

Thoughts?
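The `z`/`zz` scenario can be reproduced with a small std-only toy model. This is a sketch under stated assumptions: `Tape`, `Tensor`, `add`, and `mul_scalar` are hypothetical stand-ins for dfdx's `OwnedTape`, tensors, `x + y`, and `y * 2.0`, and, following the issue's description, the left operand's tape is the merge target.

```rust
use std::sync::{Arc, Mutex};

// Hypothetical stand-in for OwnedTape: a list of recorded op names.
#[derive(Debug, Default)]
struct Tape {
    ops: Vec<String>,
}

// Hypothetical toy tensor: it only carries a shared tape handle.
#[derive(Clone)]
struct Tensor {
    tape: Arc<Mutex<Tape>>,
}

// Current semantics: lhs becomes the union, rhs is emptied.
fn merge(lhs: Arc<Mutex<Tape>>, rhs: Arc<Mutex<Tape>>) -> Arc<Mutex<Tape>> {
    if !Arc::ptr_eq(&lhs, &rhs) {
        let mut l = lhs.lock().unwrap();
        let mut r = rhs.lock().unwrap();
        l.ops.append(&mut r.ops);
    }
    lhs
}

// Stand-in for `x + y`: merge both tapes, then record the op on the result.
fn add(x: Tensor, y: Tensor) -> Tensor {
    let tape = merge(x.tape, y.tape);
    tape.lock().unwrap().ops.push("add".into());
    Tensor { tape }
}

// Stand-in for `y * 2.0`: records onto whatever tape `y` currently holds.
fn mul_scalar(y: Tensor) -> Tensor {
    y.tape.lock().unwrap().ops.push("mul".into());
    Tensor { tape: y.tape }
}

fn main() {
    let x = Tensor { tape: Arc::new(Mutex::new(Tape { ops: vec!["x".into()] })) };
    let y = Tensor { tape: Arc::new(Mutex::new(Tape { ops: vec!["y".into()] })) };
    let z = add(x, y.clone()); // y's shared tape is drained into x's tape
    let zz = mul_scalar(y); // records onto y's now-empty tape
    assert!(!Arc::ptr_eq(&z.tape, &zz.tape)); // two distinct tapes
    println!("{:?} {:?}", z.tape.lock().unwrap().ops, zz.tape.lock().unwrap().ops);
    // ["x", "y", "add"] ["mul"]
}
```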