Fix pyre errors in tracin (#1427)
Summary:

more pyre error fixing in tracin

Reviewed By: cyrjano

Differential Revision: D65232602
csauper authored and facebook-github-bot committed Oct 30, 2024
1 parent 87a0a9e commit eb90717
Showing 3 changed files with 33 additions and 50 deletions.
72 changes: 27 additions & 45 deletions captum/influence/_core/tracincp.py
@@ -72,10 +72,10 @@ def __init__(
         self,
         model: Module,
         train_dataset: Union[Dataset, DataLoader],
-        # pyre-fixme[24]: Generic type `Iterator` expects 1 type parameter.
-        checkpoints: Union[str, List[str], Iterator],
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        checkpoints_load_func: Callable = _load_flexible_state_dict,
+        checkpoints: Union[str, List[str], Iterator[str]],
+        checkpoints_load_func: Callable[
+            [Module, str], float
+        ] = _load_flexible_state_dict,
         # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
         loss_fn: Optional[Union[Module, Callable]] = None,
         batch_size: Union[int, None] = 1,
@@ -146,9 +146,8 @@ def __init__(
         """

         self.model: Module = model
-
         self.checkpoints = checkpoints  # type: ignore
-
+        self._checkpoints: List[str] = self.checkpoints
         self.checkpoints_load_func = checkpoints_load_func
         self.loss_fn = loss_fn
         # If test_loss_fn not provided, it's assumed to be same as loss_fn
@@ -184,12 +183,10 @@ def __init__(

     @property
     def checkpoints(self) -> List[str]:
-        # pyre-fixme[16]: `TracInCPBase` has no attribute `_checkpoints`.
         return self._checkpoints

     @checkpoints.setter
-    # pyre-fixme[24]: Generic type `Iterator` expects 1 type parameter.
-    def checkpoints(self, checkpoints: Union[str, List[str], Iterator]) -> None:
+    def checkpoints(self, checkpoints: Union[str, List[str], Iterator[str]]) -> None:
         if isinstance(checkpoints, str):
             self._checkpoints = AV.sort_files(glob.glob(join(checkpoints, "*")))
         elif isinstance(checkpoints, List) and isinstance(checkpoints[0], str):
@@ -450,10 +447,10 @@ def __init__(
         self,
         model: Module,
         train_dataset: Union[Dataset, DataLoader],
-        # pyre-fixme[24]: Generic type `Iterator` expects 1 type parameter.
-        checkpoints: Union[str, List[str], Iterator],
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        checkpoints_load_func: Callable = _load_flexible_state_dict,
+        checkpoints: Union[str, List[str], Iterator[str]],
+        checkpoints_load_func: Callable[
+            [Module, str], float
+        ] = _load_flexible_state_dict,
         layers: Optional[List[str]] = None,
         # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
         loss_fn: Optional[Union[Module, Callable]] = None,
@@ -584,13 +581,11 @@ def __init__(
         self.sample_wise_grads_per_batch = sample_wise_grads_per_batch

         # check `loss_fn`
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.reduction_type = _check_loss_fn(
+        self.reduction_type: str = _check_loss_fn(
             self, loss_fn, "loss_fn", sample_wise_grads_per_batch
         )
         # check `test_loss_fn` if it was provided
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.test_reduction_type = (
+        self.test_reduction_type: str = (
             self.reduction_type
             if test_loss_fn is None
             else _check_loss_fn(
@@ -603,8 +598,7 @@ def __init__(
         within influence to restore after every influence call)? or make a copy so that
         changes to grad_requires aren't persistent after using TracIn.
         """
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.layer_modules = None
+        self.layer_modules: Optional[List[Module]] = None
         if layers is not None:
             self.layer_modules = _set_active_parameters(model, layers)

@@ -760,9 +754,8 @@ def _sum_jacobians(

         inputs_batch = next(inputs_iter)

-        # pyre-fixme[3]: Return type must be annotated.
-        # pyre-fixme[2]: Parameter must be annotated.
-        def get_batch_contribution(inputs_batch):
+        # pyre-fixme[2]: Parameter `inputs_batch` must have a type that does not contain `Any`.  # noqa: E501
+        def get_batch_contribution(inputs_batch: Tuple[Any, ...]) -> Tuple[Tensor, ...]:
             _input_jacobians = self._basic_computation_tracincp(
                 inputs_batch[0:-1],
                 inputs_batch[-1],
@@ -871,12 +864,10 @@ def compute_intermediate_quantities(
             the variable d in the top of page 15 of the TracIn paper:
             https://arxiv.org/pdf/2002.08484.pdf.
         """
-        # If `inputs` is not a `DataLoader`, turn it into one.
-        inputs = _format_inputs_dataset(inputs)
+        f_inputs: DataLoader = _format_inputs_dataset(inputs)

-        # pyre-fixme[3]: Return type must be annotated.
-        # pyre-fixme[2]: Parameter must be annotated.
-        def get_checkpoint_contribution(checkpoint):
+        def get_checkpoint_contribution(checkpoint: str) -> Tensor:
+            nonlocal f_inputs
             assert (
                 checkpoint is not None
             ), "None returned from `checkpoints`, cannot load."
@@ -885,19 +876,13 @@ def get_checkpoint_contribution(checkpoint):
             # get jacobians as tuple of tensors
             if aggregate:
                 inputs_jacobians = self._sum_jacobians(
-                    # pyre-fixme[6]: For 1st argument expected
-                    # `DataLoader[typing.Any]` but got `Union[DataLoader[typing.Any],
-                    # typing.Tuple[typing.Any, ...]]`.
-                    inputs,
+                    f_inputs,
                     self.loss_fn,
                     self.reduction_type,
                 )
             else:
                 inputs_jacobians = self._concat_jacobians(
-                    # pyre-fixme[6]: For 1st argument expected
-                    # `DataLoader[typing.Any]` but got `Union[DataLoader[typing.Any],
-                    # typing.Tuple[typing.Any, ...]]`.
-                    inputs,
+                    f_inputs,
                     self.loss_fn,
                     self.reduction_type,
                 )
@@ -932,9 +917,9 @@ def _influence_batch_tracincp(
         computed by `_get_checkpoint_jacobians`.
         """

-        # pyre-fixme[3]: Return type must be annotated.
-        # pyre-fixme[2]: Parameter must be annotated.
-        def get_checkpoint_contribution(input_jacobians, checkpoint):
+        def get_checkpoint_contribution(
+            input_jacobians: Tuple[Tensor, ...], checkpoint: str
+        ) -> Tensor:

             assert (
                 checkpoint is not None
@@ -1224,7 +1209,7 @@ def _self_influence_by_checkpoints(
         if show_progress:
             # Try to determine length of inner progress bar if possible, with a default
             # of `None`.
-            inputs_len = None
+            inputs_len: Optional[int] = None
             try:
                 inputs_len = len(inputs)
             except TypeError:
@@ -1237,9 +1222,8 @@ def _self_influence_by_checkpoints(
                     stacklevel=1,
                 )

-        # pyre-fixme[3]: Return type must be annotated.
         # pyre-fixme[2]: Parameter must be annotated.
-        def calculate_via_vector_norm(layer_jacobian):
+        def calculate_via_vector_norm(layer_jacobian) -> Tensor:
             # Helper to efficiently calculate vector norm if pytorch version permits.
             return (
                 torch.linalg.vector_norm(
@@ -1249,10 +1233,8 @@ def calculate_via_vector_norm(layer_jacobian):
                 ** 2
             )

-        # pyre-fixme[53]: Captured variable `inputs_len` is not annotated.
-        # pyre-fixme[3]: Return type must be annotated.
-        # pyre-fixme[2]: Parameter must be annotated.
-        def get_checkpoint_contribution(checkpoint):
+        def get_checkpoint_contribution(checkpoint: str) -> Tensor:
+            nonlocal inputs_len
             # This function returns a 1D tensor representing the contribution to the
             # self influence score for the given checkpoint, for all batches in
             # `inputs`. The length of the 1D tensor is the total number of
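
For reference, the tightened annotation `Callable[[Module, str], float]` describes loaders shaped like the sketch below. This is an editorial illustration, not part of the commit: the function name, the checkpoint-dict keys, and the stored learning rate are assumptions, and `_load_flexible_state_dict` remains the default in the diff above.

import torch
from torch.nn import Module


def my_checkpoints_load_func(model: Module, path: str) -> float:
    """Hypothetical loader matching Callable[[Module, str], float]."""
    checkpoint = torch.load(path, map_location="cpu")
    # Assumes the checkpoint file stores weights under "model_state_dict";
    # adjust the key to match however the checkpoints were written.
    model.load_state_dict(checkpoint["model_state_dict"])
    # The returned float is treated as the checkpoint's learning rate, which
    # TracIn uses to weight that checkpoint's contribution; fall back to 1.0.
    return float(checkpoint.get("learning_rate", 1.0))
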
9 changes: 5 additions & 4 deletions captum/influence/_utils/common.py
@@ -65,7 +65,7 @@ def _tensor_batch_dot(t1: Tensor, t2: Tensor) -> Tensor:


 def _gradient_dot_product(
-    input_grads: Tuple[Tensor], src_grads: Tuple[Tensor]
+    input_grads: Tuple[Tensor, ...], src_grads: Tuple[Tensor, ...]
 ) -> Tensor:
     r"""
     Computes the dot product between the gradient vector for a model on an input batch
@@ -334,9 +334,10 @@ def __len__(self) -> int:
         return len(self._l)


-# pyre-fixme[3]: Return type must be annotated.
-# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-def _format_inputs_dataset(inputs_dataset: Union[Tuple[Any, ...], DataLoader]):
+def _format_inputs_dataset(
+    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
+    inputs_dataset: Union[Tuple[Any, ...], DataLoader]
+) -> DataLoader:
     # if `inputs_dataset` is not a `DataLoader`, turn it into one.
     # `_DatasetFromList` turns a list into a `Dataset` where `__getitem__`
     # returns an element in the list, and using it to construct a `DataLoader`
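
The return annotation added here makes it explicit that callers always get a `DataLoader` back, which is what lets the tracincp.py change above bind `f_inputs: DataLoader = _format_inputs_dataset(inputs)` without a pyre-fixme. Below is a rough, self-contained sketch of the same "tuple or DataLoader in, DataLoader out" idea; the class and function names are invented for illustration and are not captum's implementation.

from typing import Any, Tuple, Union

from torch.utils.data import DataLoader, Dataset


class _SingleBatchDataset(Dataset):
    """Wraps one pre-collated batch so a DataLoader yields it unchanged."""

    def __init__(self, batch: Tuple[Any, ...]) -> None:
        self.batch = batch

    def __len__(self) -> int:
        return 1

    def __getitem__(self, idx: int) -> Tuple[Any, ...]:
        return self.batch


def _as_single_batch_loader(
    inputs: Union[Tuple[Any, ...], DataLoader]
) -> DataLoader:
    if isinstance(inputs, DataLoader):
        return inputs
    # `batch_size=None` disables re-collation, so the stored batch comes back
    # exactly as it was passed in.
    return DataLoader(_SingleBatchDataset(inputs), batch_size=None)
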
2 changes: 1 addition & 1 deletion tests/helpers/influence/common.py
@@ -409,7 +409,7 @@ def get_random_model_and_data(
         in_features, out_features, num_samples, use_gpu, unpack_inputs
     )

-    net: Union[BasicLinearNet, MultLinearNet, Linear, UnpackLinear]
+    net: Module  # Union[BasicLinearNet, MultLinearNet, Linear, UnpackLinear]
    if model_type == "random":
         net = (
             BasicLinearNet(in_features, hidden_nodes, out_features)
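
The test-helper change swaps a long `Union` annotation for the common base class `Module`, keeping the original union as a trailing comment. A tiny, invented mini-example of the same pattern (names are not from the repo): annotate with the base class when the concrete subclass is chosen at runtime, so a type checker accepts every branch.

import torch.nn as nn

use_conv: bool = False

layer: nn.Module  # concrete type depends on a runtime flag
if use_conv:
    layer = nn.Conv1d(4, 4, kernel_size=3)
else:
    layer = nn.Linear(4, 4)

print(type(layer).__name__)  # "Linear" here, but the annotation stays valid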
