Skip to content

Commit

Permalink
remove Tensor._replace_by
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Mar 11, 2023
1 parent 665224d commit d5861cb
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 13 deletions.
9 changes: 0 additions & 9 deletions nn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,15 +364,6 @@ def mark_as_default_output(self) -> Tensor:
res.mark_as_output()
return res

def _replace_by(self, tensor: nn.Tensor):
"""
Replace this tensor by the given tensor.
This is a workaround in case other refs point to this tensor object.
"""
assert isinstance(tensor, nn.Tensor)
self.raw_tensor = tensor.raw_tensor # type: nn.NameCtx
self.data = tensor.data

def _sis_hash(self):
# noinspection PyProtectedMember
return self.raw_tensor._sis_hash()
Expand Down
10 changes: 6 additions & 4 deletions nn/naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,10 +483,12 @@ def prepare_for_config_serialization(self, root_module: nn.Module):
assert sub_out.layer_dict["class"] == "copy"
sub_real_out = sub_out.layer_dict["from"]
assert isinstance(sub_real_out, nn.Tensor)
# noinspection PyProtectedMember
sub_out.tensor._replace_by(sub_real_out)
# noinspection PyProtectedMember
root_mod_call.tensor._replace_by(sub_real_out)
# Replace this tensor by the given tensor.
# This is a workaround in case other refs point to this tensor object.
sub_out.tensor.raw_tensor = sub_real_out.raw_tensor
sub_out.tensor.data = sub_real_out.data
root_mod_call.tensor.raw_tensor = sub_real_out.raw_tensor
root_mod_call.tensor.data = sub_real_out.data

# Do not use self.move_tensor_here(root_mod_call.tensor) because we don't want the extra logic.
self.module = root_module
Expand Down

0 comments on commit d5861cb

Please sign in to comment.