Fix pyre errors in tracin #1427

Closed · wants to merge 1 commit
72 changes: 27 additions & 45 deletions captum/influence/_core/tracincp.py
@@ -72,10 +72,10 @@ def __init__(
         self,
         model: Module,
         train_dataset: Union[Dataset, DataLoader],
-        # pyre-fixme[24]: Generic type `Iterator` expects 1 type parameter.
-        checkpoints: Union[str, List[str], Iterator],
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        checkpoints_load_func: Callable = _load_flexible_state_dict,
+        checkpoints: Union[str, List[str], Iterator[str]],
+        checkpoints_load_func: Callable[
+            [Module, str], float
+        ] = _load_flexible_state_dict,
         # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
         loss_fn: Optional[Union[Module, Callable]] = None,
         batch_size: Union[int, None] = 1,
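
Reviewer note (illustrative, not part of the diff): the new `Callable[[Module, str], float]` annotation matches the shape of `_load_flexible_state_dict`. A minimal sketch of a custom loader that would also satisfy it, assuming a checkpoint dict with hypothetical "model_state_dict" and "learning_rate" keys:

import torch
from torch.nn import Module

def my_checkpoints_load_func(model: Module, path: str) -> float:
    # Hypothetical loader: put the saved weights into the model and return a
    # float (e.g. the stored learning rate) for TracIn to use when weighting
    # this checkpoint's contribution.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    return float(checkpoint.get("learning_rate", 1.0))
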
@@ -146,9 +146,8 @@ def __init__(
         """

         self.model: Module = model
-
         self.checkpoints = checkpoints  # type: ignore
-
+        self._checkpoints: List[str] = self.checkpoints
         self.checkpoints_load_func = checkpoints_load_func
         self.loss_fn = loss_fn
         # If test_loss_fn not provided, it's assumed to be same as loss_fn
@@ -184,12 +183,10 @@ def __init__(

     @property
     def checkpoints(self) -> List[str]:
-        # pyre-fixme[16]: `TracInCPBase` has no attribute `_checkpoints`.
         return self._checkpoints

     @checkpoints.setter
-    # pyre-fixme[24]: Generic type `Iterator` expects 1 type parameter.
-    def checkpoints(self, checkpoints: Union[str, List[str], Iterator]) -> None:
+    def checkpoints(self, checkpoints: Union[str, List[str], Iterator[str]]) -> None:
         if isinstance(checkpoints, str):
             self._checkpoints = AV.sort_files(glob.glob(join(checkpoints, "*")))
         elif isinstance(checkpoints, List) and isinstance(checkpoints[0], str):
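
Reviewer note (illustrative, not part of the diff): a quick usage sketch of the three forms the annotated setter accepts, using throwaway data and a checkpoint written on the spot:

import torch
from torch.utils.data import TensorDataset
from captum.influence import TracInCP

net = torch.nn.Linear(10, 2)
train_dataset = TensorDataset(torch.randn(8, 10), torch.randint(0, 2, (8,)))
torch.save(net.state_dict(), "ckpt_0.pt")  # throwaway checkpoint so the paths below exist

tracin = TracInCP(net, train_dataset, checkpoints=["ckpt_0.pt"])        # an explicit List[str]
tracin = TracInCP(net, train_dataset, checkpoints=iter(["ckpt_0.pt"]))  # any Iterator[str]
tracin = TracInCP(net, train_dataset, checkpoints=".")                  # a directory: contents are globbed and sorted
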
@@ -450,10 +447,10 @@ def __init__(
         self,
         model: Module,
         train_dataset: Union[Dataset, DataLoader],
-        # pyre-fixme[24]: Generic type `Iterator` expects 1 type parameter.
-        checkpoints: Union[str, List[str], Iterator],
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        checkpoints_load_func: Callable = _load_flexible_state_dict,
+        checkpoints: Union[str, List[str], Iterator[str]],
+        checkpoints_load_func: Callable[
+            [Module, str], float
+        ] = _load_flexible_state_dict,
         layers: Optional[List[str]] = None,
         # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
         loss_fn: Optional[Union[Module, Callable]] = None,
@@ -584,13 +581,11 @@ def __init__(
         self.sample_wise_grads_per_batch = sample_wise_grads_per_batch

         # check `loss_fn`
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.reduction_type = _check_loss_fn(
+        self.reduction_type: str = _check_loss_fn(
             self, loss_fn, "loss_fn", sample_wise_grads_per_batch
         )
         # check `test_loss_fn` if it was provided
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.test_reduction_type = (
+        self.test_reduction_type: str = (
             self.reduction_type
             if test_loss_fn is None
             else _check_loss_fn(
@@ -603,8 +598,7 @@ def __init__(
         within influence to restore after every influence call)? or make a copy so that
         changes to grad_requires aren't persistent after using TracIn.
         """
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.layer_modules = None
+        self.layer_modules: Optional[List[Module]] = None
         if layers is not None:
             self.layer_modules = _set_active_parameters(model, layers)

@@ -760,9 +754,8 @@ def _sum_jacobians(

         inputs_batch = next(inputs_iter)

-        # pyre-fixme[3]: Return type must be annotated.
-        # pyre-fixme[2]: Parameter must be annotated.
-        def get_batch_contribution(inputs_batch):
+        # pyre-fixme[2]: Parameter `inputs_batch` must have a type that does not contain `Any`. # noqa: E501
+        def get_batch_contribution(inputs_batch: Tuple[Any, ...]) -> Tuple[Tensor, ...]:
             _input_jacobians = self._basic_computation_tracincp(
                 inputs_batch[0:-1],
                 inputs_batch[-1],
@@ -871,12 +864,10 @@ def compute_intermediate_quantities(
             the variable d in the top of page 15 of the TracIn paper:
             https://arxiv.org/pdf/2002.08484.pdf.
         """
-        # If `inputs` is not a `DataLoader`, turn it into one.
-        inputs = _format_inputs_dataset(inputs)
+        f_inputs: DataLoader = _format_inputs_dataset(inputs)

-        # pyre-fixme[3]: Return type must be annotated.
-        # pyre-fixme[2]: Parameter must be annotated.
-        def get_checkpoint_contribution(checkpoint):
+        def get_checkpoint_contribution(checkpoint: str) -> Tensor:
+            nonlocal f_inputs
             assert (
                 checkpoint is not None
             ), "None returned from `checkpoints`, cannot load."
@@ -885,19 +876,13 @@ def get_checkpoint_contribution(checkpoint):
             # get jacobians as tuple of tensors
             if aggregate:
                 inputs_jacobians = self._sum_jacobians(
-                    # pyre-fixme[6]: For 1st argument expected
-                    # `DataLoader[typing.Any]` but got `Union[DataLoader[typing.Any],
-                    # typing.Tuple[typing.Any, ...]]`.
-                    inputs,
+                    f_inputs,
                     self.loss_fn,
                     self.reduction_type,
                 )
             else:
                 inputs_jacobians = self._concat_jacobians(
-                    # pyre-fixme[6]: For 1st argument expected
-                    # `DataLoader[typing.Any]` but got `Union[DataLoader[typing.Any],
-                    # typing.Tuple[typing.Any, ...]]`.
-                    inputs,
+                    f_inputs,
                     self.loss_fn,
                     self.reduction_type,
                 )
@@ -932,9 +917,9 @@ def _influence_batch_tracincp(
             computed by `_get_checkpoint_jacobians`.
         """

-        # pyre-fixme[3]: Return type must be annotated.
-        # pyre-fixme[2]: Parameter must be annotated.
-        def get_checkpoint_contribution(input_jacobians, checkpoint):
+        def get_checkpoint_contribution(
+            input_jacobians: Tuple[Tensor, ...], checkpoint: str
+        ) -> Tensor:

             assert (
                 checkpoint is not None
@@ -1224,7 +1209,7 @@ def _self_influence_by_checkpoints(
         if show_progress:
             # Try to determine length of inner progress bar if possible, with a default
             # of `None`.
-            inputs_len = None
+            inputs_len: Optional[int] = None
             try:
                 inputs_len = len(inputs)
             except TypeError:
@@ -1237,9 +1222,8 @@ def _self_influence_by_checkpoints(
                     stacklevel=1,
                 )

-        # pyre-fixme[3]: Return type must be annotated.
-        # pyre-fixme[2]: Parameter must be annotated.
-        def calculate_via_vector_norm(layer_jacobian):
+        def calculate_via_vector_norm(layer_jacobian) -> Tensor:
             # Helper to efficiently calculate vector norm if pytorch version permits.
             return (
                 torch.linalg.vector_norm(
@@ -1249,10 +1233,8 @@ def calculate_via_vector_norm(layer_jacobian):
                 ** 2
             )

-        # pyre-fixme[53]: Captured variable `inputs_len` is not annotated.
-        # pyre-fixme[3]: Return type must be annotated.
-        # pyre-fixme[2]: Parameter must be annotated.
-        def get_checkpoint_contribution(checkpoint):
+        def get_checkpoint_contribution(checkpoint: str) -> Tensor:
+            nonlocal inputs_len
             # This function returns a 1D tensor representing the contribution to the
             # self influence score for the given checkpoint, for all batches in
             # `inputs`. The length of the 1D tensor is the total number of
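
Reviewer note (illustrative, not part of the diff): the body of `calculate_via_vector_norm` is partly collapsed above, so here is a standalone sketch of the pattern it annotates: a squared L2 norm per sample, reducing over every non-batch dimension. The reduction dimensions and tensor shapes below are assumptions made for the illustration.

import torch

def squared_norm_per_sample(layer_jacobian: torch.Tensor) -> torch.Tensor:
    # Norm over all dimensions except the leading batch dimension, then square,
    # mirroring the helper's torch.linalg.vector_norm(...) ** 2 usage.
    return (
        torch.linalg.vector_norm(
            layer_jacobian, dim=list(range(1, layer_jacobian.dim()))
        )
        ** 2
    )

jac = torch.randn(4, 3, 5)           # e.g. (batch, out_features, in_features)
print(squared_norm_per_sample(jac))  # tensor of shape (4,)
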
9 changes: 5 additions & 4 deletions captum/influence/_utils/common.py
@@ -65,7 +65,7 @@ def _tensor_batch_dot(t1: Tensor, t2: Tensor) -> Tensor:


 def _gradient_dot_product(
-    input_grads: Tuple[Tensor], src_grads: Tuple[Tensor]
+    input_grads: Tuple[Tensor, ...], src_grads: Tuple[Tensor, ...]
 ) -> Tensor:
     r"""
     Computes the dot product between the gradient vector for a model on an input batch
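
Reviewer note (illustrative, not part of the diff): the variadic `Tuple[Tensor, ...]` matters because the gradient tuples hold one tensor per parameter. A hedged sketch of the kind of computation the signature describes; the flatten-and-matmul accumulation below is an assumption used only to motivate the annotation, not a copy of the library code:

from typing import Tuple

import torch
from torch import Tensor

def gradient_dot_product_sketch(
    input_grads: Tuple[Tensor, ...], src_grads: Tuple[Tensor, ...]
) -> Tensor:
    # Accumulate, parameter by parameter, the pairwise dot products between
    # per-example gradients of the two batches.
    total = None
    for ig, sg in zip(input_grads, src_grads):
        contrib = ig.flatten(start_dim=1) @ sg.flatten(start_dim=1).T
        total = contrib if total is None else total + contrib
    return total  # shape: (input batch size, src batch size)

a = (torch.randn(2, 3, 4), torch.randn(2, 5))   # per-parameter grads for 2 examples
b = (torch.randn(3, 3, 4), torch.randn(3, 5))   # per-parameter grads for 3 examples
print(gradient_dot_product_sketch(a, b).shape)  # torch.Size([2, 3])
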
@@ -334,9 +334,10 @@ def __len__(self) -> int:
         return len(self._l)


-# pyre-fixme[3]: Return type must be annotated.
-# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-def _format_inputs_dataset(inputs_dataset: Union[Tuple[Any, ...], DataLoader]):
+def _format_inputs_dataset(
+    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
+    inputs_dataset: Union[Tuple[Any, ...], DataLoader]
+) -> DataLoader:
     # if `inputs_dataset` is not a `DataLoader`, turn it into one.
     # `_DatasetFromList` turns a list into a `Dataset` where `__getitem__`
     # returns an element in the list, and using it to construct a `DataLoader`
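
Reviewer note (illustrative, not part of the diff): the new `-> DataLoader` return type documents the normalization the comment describes. A minimal sketch of that behavior, assuming (as the comment says) that a tuple of tensors is wrapped so the DataLoader yields it back as a single batch; `_SingleBatchDataset` is a stand-in for captum's `_DatasetFromList`:

from typing import Any, Tuple, Union

import torch
from torch.utils.data import DataLoader, Dataset

class _SingleBatchDataset(Dataset):
    # __getitem__ hands back the stored batch unchanged, so a DataLoader over
    # this dataset yields exactly one batch.
    def __init__(self, batch: Tuple[Any, ...]) -> None:
        self.batch = batch

    def __getitem__(self, index: int) -> Tuple[Any, ...]:
        return self.batch

    def __len__(self) -> int:
        return 1

def format_inputs_dataset_sketch(
    inputs_dataset: Union[Tuple[Any, ...], DataLoader]
) -> DataLoader:
    if isinstance(inputs_dataset, DataLoader):
        return inputs_dataset
    return DataLoader(_SingleBatchDataset(inputs_dataset), batch_size=None)

loader = format_inputs_dataset_sketch((torch.randn(4, 3), torch.zeros(4)))
print(len(list(loader)))  # 1
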
2 changes: 1 addition & 1 deletion tests/helpers/influence/common.py
@@ -409,7 +409,7 @@ def get_random_model_and_data(
         in_features, out_features, num_samples, use_gpu, unpack_inputs
     )

-    net: Union[BasicLinearNet, MultLinearNet, Linear, UnpackLinear]
+    net: Module  # Union[BasicLinearNet, MultLinearNet, Linear, UnpackLinear]
     if model_type == "random":
         net = (
             BasicLinearNet(in_features, hidden_nodes, out_features)