Skip to content

Commit

Permalink
Pyre fixes for common.py [3/n] (pytorch#1424)
Browse files Browse the repository at this point in the history
Summary:

Rewriting from D64259572 after BE week

Reviewed By: cyrjano, vivekmig

Differential Revision: D65011997
  • Loading branch information
csauper authored and facebook-github-bot committed Oct 30, 2024
1 parent 492ae0e commit bbeab93
Showing 1 changed file with 13 additions and 32 deletions.
45 changes: 13 additions & 32 deletions captum/_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Dict,
List,
Literal,
Optional,
overload,
Sequence,
Tuple,
Expand Down Expand Up @@ -272,28 +273,9 @@ def _format_float_or_tensor_into_tuples(
return inputs


@overload
def _format_additional_forward_args(additional_forward_args: None) -> None: ...


@overload
def _format_additional_forward_args(
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
additional_forward_args: Union[Tensor, Tuple]
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
) -> Tuple: ...


@overload
def _format_additional_forward_args( # type: ignore
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any,
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
) -> Union[None, Tuple]: ...


# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
def _format_additional_forward_args(additional_forward_args: Any) -> Union[None, Tuple]:
additional_forward_args: Optional[object],
) -> Union[None, Tuple[object, ...]]:
if additional_forward_args is not None and not isinstance(
additional_forward_args, tuple
):
Expand Down Expand Up @@ -853,8 +835,7 @@ def _register_backward_hook(
module: Module,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
hook: Callable,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
attr_obj: Any,
attr_obj: Union[object, None],
) -> List[torch.utils.hooks.RemovableHandle]:
grad_out: Dict[device, Tensor] = {}

Expand All @@ -864,10 +845,9 @@ def forward_hook(
out: Union[Tensor, Tuple[Tensor, ...]],
) -> None:
nonlocal grad_out
grad_out = {}

# pyre-fixme[53]: Captured variable `grad_out` is not annotated.
def output_tensor_hook(output_grad: Tensor) -> None:
nonlocal grad_out
grad_out[output_grad.device] = output_grad

if isinstance(out, tuple):
Expand All @@ -878,18 +858,19 @@ def output_tensor_hook(output_grad: Tensor) -> None:
else:
out.register_hook(output_tensor_hook)

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def pre_hook(module, inp):
# pyre-fixme[53]: Captured variable `module` is not annotated.
# pyre-fixme[3]: Return type must be annotated.
def input_tensor_hook(input_grad: Tensor):
def pre_hook(module: Module, inp: Union[Tensor, Tuple[Tensor, ...]]) -> Tensor:
def input_tensor_hook(
input_grad: Tensor,
) -> Union[None, Tensor, Tuple[Tensor, ...]]:
nonlocal grad_out

if len(grad_out) == 0:
return
return None
hook_out = hook(module, input_grad, grad_out[input_grad.device])

if hook_out is not None:
return hook_out[0] if isinstance(hook_out, tuple) else hook_out
return None

if isinstance(inp, tuple):
assert (
Expand Down

0 comments on commit bbeab93

Please sign in to comment.