Skip to content

Commit

Permalink
Improving typing of additional_forward_args
Browse files Browse the repository at this point in the history
Summary: Change `additional_forward_args` to resolve pyre errors and address the feedback on D64998803.

Differential Revision: D65178564
  • Loading branch information
Zach Carmichael authored and facebook-github-bot committed Oct 30, 2024
1 parent 07470af commit 1a4e012
Show file tree
Hide file tree
Showing 48 changed files with 172 additions and 210 deletions.
3 changes: 1 addition & 2 deletions captum/_utils/av.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,8 +351,7 @@ def _compute_and_save_activations(
inputs: Union[Tensor, Tuple[Tensor, ...]],
identifier: str,
num_id: str,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: Optional[object] = None,
load_from_disk: bool = True,
) -> None:
r"""
Expand Down
25 changes: 20 additions & 5 deletions captum/_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,25 @@ def _format_float_or_tensor_into_tuples(
return inputs


@overload
def _format_additional_forward_args(
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
additional_forward_args: Union[Tensor, Tuple]
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
) -> Tuple: ...


@overload
def _format_additional_forward_args( # type: ignore
additional_forward_args: Optional[object],
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
) -> Union[None, Tuple]: ...


def _format_additional_forward_args(
additional_forward_args: Optional[object],
) -> Union[None, Tuple[object, ...]]:
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
) -> Union[None, Tuple]:
if additional_forward_args is not None and not isinstance(
additional_forward_args, tuple
):
Expand All @@ -284,8 +300,8 @@ def _format_additional_forward_args(


def _expand_additional_forward_args(
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any,
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
additional_forward_args: Union[None, Tuple],
n_steps: int,
expansion_type: ExpansionTypes = ExpansionTypes.repeat,
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
Expand Down Expand Up @@ -557,8 +573,7 @@ def _run_forward(
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
inputs: Any,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: Optional[object] = None,
) -> Union[Tensor, Future[Tensor]]:
forward_func_args = signature(forward_func).parameters
if len(forward_func_args) == 0:
Expand Down
39 changes: 16 additions & 23 deletions captum/_utils/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,7 @@ def compute_gradients(
forward_fn: Callable,
inputs: Union[Tensor, Tuple[Tensor, ...]],
target_ind: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: Optional[object] = None,
) -> Tuple[Tensor, ...]:
r"""
Computes gradients of the output with respect to inputs for an
Expand Down Expand Up @@ -175,8 +174,7 @@ def _forward_layer_eval(
forward_fn: Callable,
inputs: Union[Tensor, Tuple[Tensor, ...]],
layer: List[Module],
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: Optional[object] = None,
device_ids: Union[None, List[int]] = None,
attribute_to_layer_input: bool = False,
grad_enabled: bool = False,
Expand All @@ -191,8 +189,7 @@ def _forward_layer_eval(
forward_fn: Callable,
inputs: Union[Tensor, Tuple[Tensor, ...]],
layer: Module,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: Optional[object] = None,
device_ids: Union[None, List[int]] = None,
attribute_to_layer_input: bool = False,
grad_enabled: bool = False,
Expand All @@ -204,7 +201,7 @@ def _forward_layer_eval(
forward_fn: Callable,
inputs: Union[Tensor, Tuple[Tensor, ...]],
layer: ModuleOrModuleList,
additional_forward_args: Any = None,
additional_forward_args: Optional[object] = None,
device_ids: Union[None, List[int]] = None,
attribute_to_layer_input: bool = False,
grad_enabled: bool = False,
Expand Down Expand Up @@ -233,8 +230,7 @@ def _forward_layer_distributed_eval(
inputs: Any,
layer: ModuleOrModuleList,
target_ind: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: Optional[object] = None,
attribute_to_layer_input: bool = False,
forward_hook_with_return: Literal[False] = False,
require_layer_grads: bool = False,
Expand All @@ -250,7 +246,7 @@ def _forward_layer_distributed_eval(
inputs: Any,
layer: ModuleOrModuleList,
target_ind: TargetType = None,
additional_forward_args: Any = None,
additional_forward_args: Optional[object] = None,
attribute_to_layer_input: bool = False,
*,
forward_hook_with_return: Literal[True],
Expand All @@ -264,7 +260,7 @@ def _forward_layer_distributed_eval(
inputs: Any,
layer: ModuleOrModuleList,
target_ind: TargetType = None,
additional_forward_args: Any = None,
additional_forward_args: Optional[object] = None,
attribute_to_layer_input: bool = False,
forward_hook_with_return: bool = False,
require_layer_grads: bool = False,
Expand Down Expand Up @@ -427,8 +423,7 @@ def _forward_layer_eval_with_neuron_grads(
forward_fn: Callable,
inputs: Union[Tensor, Tuple[Tensor, ...]],
layer: Module,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: Optional[object] = None,
*,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
Expand All @@ -446,7 +441,7 @@ def _forward_layer_eval_with_neuron_grads(
forward_fn: Callable,
inputs: Union[Tensor, Tuple[Tensor, ...]],
layer: List[Module],
additional_forward_args: Any = None,
additional_forward_args: Optional[object] = None,
gradient_neuron_selector: None = None,
grad_enabled: bool = False,
device_ids: Union[None, List[int]] = None,
Expand All @@ -462,7 +457,7 @@ def _forward_layer_eval_with_neuron_grads(
forward_fn: Callable,
inputs: Union[Tensor, Tuple[Tensor, ...]],
layer: Module,
additional_forward_args: Any = None,
additional_forward_args: Optional[object] = None,
gradient_neuron_selector: None = None,
grad_enabled: bool = False,
device_ids: Union[None, List[int]] = None,
Expand All @@ -475,7 +470,7 @@ def _forward_layer_eval_with_neuron_grads(
forward_fn: Callable,
inputs: Union[Tensor, Tuple[Tensor, ...]],
layer: ModuleOrModuleList,
additional_forward_args: Any = None,
additional_forward_args: Optional[object] = None,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
gradient_neuron_selector: Union[
None, int, Tuple[Union[int, slice], ...], Callable
Expand Down Expand Up @@ -549,8 +544,7 @@ def compute_layer_gradients_and_eval(
layer: Module,
inputs: Union[Tensor, Tuple[Tensor, ...]],
target_ind: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: Optional[object] = None,
*,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
Expand All @@ -571,7 +565,7 @@ def compute_layer_gradients_and_eval(
layer: List[Module],
inputs: Union[Tensor, Tuple[Tensor, ...]],
target_ind: TargetType = None,
additional_forward_args: Any = None,
additional_forward_args: Optional[object] = None,
gradient_neuron_selector: None = None,
device_ids: Union[None, List[int]] = None,
attribute_to_layer_input: bool = False,
Expand All @@ -590,7 +584,7 @@ def compute_layer_gradients_and_eval(
layer: Module,
inputs: Union[Tensor, Tuple[Tensor, ...]],
target_ind: TargetType = None,
additional_forward_args: Any = None,
additional_forward_args: Optional[object] = None,
gradient_neuron_selector: None = None,
device_ids: Union[None, List[int]] = None,
attribute_to_layer_input: bool = False,
Expand All @@ -606,7 +600,7 @@ def compute_layer_gradients_and_eval(
layer: ModuleOrModuleList,
inputs: Union[Tensor, Tuple[Tensor, ...]],
target_ind: TargetType = None,
additional_forward_args: Any = None,
additional_forward_args: Optional[object] = None,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
gradient_neuron_selector: Union[
None, int, Tuple[Union[int, slice], ...], Callable
Expand Down Expand Up @@ -792,8 +786,7 @@ def grad_fn(
forward_fn: Callable,
inputs: TensorOrTupleOfTensorsGeneric,
target_ind: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
additional_forward_args: Optional[object] = None,
) -> Tuple[Tensor, ...]:
_, grads = _forward_layer_eval_with_neuron_grads(
forward_fn,
Expand Down
18 changes: 9 additions & 9 deletions captum/attr/_core/deep_lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# pyre-strict
import typing
import warnings
from typing import Callable, cast, Dict, List, Literal, Tuple, Type, Union
from typing import Callable, cast, Dict, List, Literal, Optional, Tuple, Type, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -117,7 +117,7 @@ def attribute(
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
additional_forward_args: object = None,
additional_forward_args: Optional[Tuple[object, ...]] = None,
*,
return_convergence_delta: Literal[True],
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
Expand All @@ -129,7 +129,7 @@ def attribute(
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
additional_forward_args: object = None,
additional_forward_args: Optional[Tuple[object, ...]] = None,
return_convergence_delta: Literal[False] = False,
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
) -> TensorOrTupleOfTensorsGeneric: ...
Expand All @@ -140,7 +140,7 @@ def attribute( # type: ignore
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
additional_forward_args: object = None,
additional_forward_args: Optional[Tuple[object, ...]] = None,
return_convergence_delta: bool = False,
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
) -> Union[
Expand Down Expand Up @@ -370,7 +370,7 @@ def _construct_forward_func(
forward_func: Callable[..., Tensor],
inputs: Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]],
target: TargetType = None,
additional_forward_args: object = None,
additional_forward_args: Optional[Tuple[object, ...]] = None,
) -> Callable[[], Tensor]:
def forward_fn() -> Tensor:
model_out = cast(
Expand Down Expand Up @@ -604,7 +604,7 @@ def attribute(
TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
],
target: TargetType = None,
additional_forward_args: object = None,
additional_forward_args: Optional[Tuple[object, ...]] = None,
*,
return_convergence_delta: Literal[True],
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
Expand All @@ -618,7 +618,7 @@ def attribute(
TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
],
target: TargetType = None,
additional_forward_args: object = None,
additional_forward_args: Optional[Tuple[object, ...]] = None,
return_convergence_delta: Literal[False] = False,
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
) -> TensorOrTupleOfTensorsGeneric: ...
Expand All @@ -631,7 +631,7 @@ def attribute( # type: ignore
TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
],
target: TargetType = None,
additional_forward_args: object = None,
additional_forward_args: Optional[Tuple[object, ...]] = None,
return_convergence_delta: bool = False,
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
) -> Union[
Expand Down Expand Up @@ -840,7 +840,7 @@ def _expand_inputs_baselines_targets(
baselines: Tuple[Tensor, ...],
inputs: Tuple[Tensor, ...],
target: TargetType,
additional_forward_args: object,
additional_forward_args: Optional[Tuple[object, ...]],
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...], TargetType, object]:
inp_bsz = inputs[0].shape[0]
base_bsz = baselines[0].shape[0]
Expand Down
6 changes: 3 additions & 3 deletions captum/attr/_core/feature_ablation.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def attribute(
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
additional_forward_args: object = None,
additional_forward_args: Optional[object] = None,
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
perturbations_per_eval: int = 1,
show_progress: bool = False,
Expand Down Expand Up @@ -408,7 +408,7 @@ def attribute_future(
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
additional_forward_args: object = None,
additional_forward_args: Optional[object] = None,
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
perturbations_per_eval: int = 1,
show_progress: bool = False,
Expand Down Expand Up @@ -655,7 +655,7 @@ def _ith_input_ablation_generator(
self,
i: int,
inputs: TensorOrTupleOfTensorsGeneric,
additional_args: object,
additional_args: Optional[Tuple[object, ...]],
target: TargetType,
baselines: BaselineType,
input_mask: Union[None, Tensor, Tuple[Tensor, ...]],
Expand Down
6 changes: 3 additions & 3 deletions captum/attr/_core/feature_permutation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3

# pyre-strict
from typing import Any, Callable, Tuple, Union
from typing import Any, Callable, Optional, Tuple, Union

import torch
from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
Expand Down Expand Up @@ -99,7 +99,7 @@ def attribute( # type: ignore
self,
inputs: TensorOrTupleOfTensorsGeneric,
target: TargetType = None,
additional_forward_args: object = None,
additional_forward_args: Optional[object] = None,
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
perturbations_per_eval: int = 1,
show_progress: bool = False,
Expand Down Expand Up @@ -280,7 +280,7 @@ def attribute_future(
self,
inputs: TensorOrTupleOfTensorsGeneric,
target: TargetType = None,
additional_forward_args: object = None,
additional_forward_args: Optional[object] = None,
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
perturbations_per_eval: int = 1,
show_progress: bool = False,
Expand Down
Loading

0 comments on commit 1a4e012

Please sign in to comment.