From 3e6776c8b1b1c44c90e840c7a5d8204497a8d2dc Mon Sep 17 00:00:00 2001
From: Christy Sauper
Date: Wed, 30 Oct 2024 14:29:27 -0700
Subject: [PATCH] Fix pyre errors in tracin

Summary: more pyre error fixing in tracin

Differential Revision: D65232602
---
 captum/influence/_core/tracincp.py | 72 +++++++++++-------------------
 captum/influence/_utils/common.py  |  9 ++--
 tests/helpers/influence/common.py  |  2 +-
 3 files changed, 33 insertions(+), 50 deletions(-)

diff --git a/captum/influence/_core/tracincp.py b/captum/influence/_core/tracincp.py
index a603634d5..a546ae6bc 100644
--- a/captum/influence/_core/tracincp.py
+++ b/captum/influence/_core/tracincp.py
@@ -72,10 +72,10 @@ def __init__(
         self,
         model: Module,
         train_dataset: Union[Dataset, DataLoader],
-        # pyre-fixme[24]: Generic type `Iterator` expects 1 type parameter.
-        checkpoints: Union[str, List[str], Iterator],
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        checkpoints_load_func: Callable = _load_flexible_state_dict,
+        checkpoints: Union[str, List[str], Iterator[str]],
+        checkpoints_load_func: Callable[
+            [Module, str], float
+        ] = _load_flexible_state_dict,
         # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
         loss_fn: Optional[Union[Module, Callable]] = None,
         batch_size: Union[int, None] = 1,
@@ -146,9 +146,8 @@ def __init__(
         """
 
         self.model: Module = model
 
-        self.checkpoints = checkpoints  # type: ignore
-
+        self._checkpoints: List[str] = self.checkpoints
         self.checkpoints_load_func = checkpoints_load_func
         self.loss_fn = loss_fn
         # If test_loss_fn not provided, it's assumed to be same as loss_fn
@@ -184,12 +183,10 @@ def __init__(
 
     @property
     def checkpoints(self) -> List[str]:
-        # pyre-fixme[16]: `TracInCPBase` has no attribute `_checkpoints`.
        return self._checkpoints
 
    @checkpoints.setter
-    # pyre-fixme[24]: Generic type `Iterator` expects 1 type parameter.
-    def checkpoints(self, checkpoints: Union[str, List[str], Iterator]) -> None:
+    def checkpoints(self, checkpoints: Union[str, List[str], Iterator[str]]) -> None:
         if isinstance(checkpoints, str):
             self._checkpoints = AV.sort_files(glob.glob(join(checkpoints, "*")))
         elif isinstance(checkpoints, List) and isinstance(checkpoints[0], str):
@@ -450,10 +447,10 @@ def __init__(
         self,
         model: Module,
         train_dataset: Union[Dataset, DataLoader],
-        # pyre-fixme[24]: Generic type `Iterator` expects 1 type parameter.
-        checkpoints: Union[str, List[str], Iterator],
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        checkpoints_load_func: Callable = _load_flexible_state_dict,
+        checkpoints: Union[str, List[str], Iterator[str]],
+        checkpoints_load_func: Callable[
+            [Module, str], float
+        ] = _load_flexible_state_dict,
         layers: Optional[List[str]] = None,
         # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
         loss_fn: Optional[Union[Module, Callable]] = None,
@@ -584,13 +581,11 @@ def __init__(
         self.sample_wise_grads_per_batch = sample_wise_grads_per_batch
 
         # check `loss_fn`
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.reduction_type = _check_loss_fn(
+        self.reduction_type: str = _check_loss_fn(
             self, loss_fn, "loss_fn", sample_wise_grads_per_batch
         )
         # check `test_loss_fn` if it was provided
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.test_reduction_type = (
+        self.test_reduction_type: str = (
             self.reduction_type
             if test_loss_fn is None
             else _check_loss_fn(
@@ -603,8 +598,7 @@ def __init__(
         within influence to restore after every influence call)? or make a copy so that
         changes to grad_requires aren't persistent after using TracIn.
         """
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.layer_modules = None
+        self.layer_modules: Optional[List[Module]] = None
         if layers is not None:
             self.layer_modules = _set_active_parameters(model, layers)
 
@@ -760,9 +754,8 @@ def _sum_jacobians(
 
         inputs_batch = next(inputs_iter)
 
-        # pyre-fixme[3]: Return type must be annotated.
-        # pyre-fixme[2]: Parameter must be annotated.
-        def get_batch_contribution(inputs_batch):
+        # pyre-fixme[2]: Parameter `inputs_batch` must have a type that does not contain `Any`.
+        def get_batch_contribution(inputs_batch: Tuple[Any, ...]) -> Tuple[Tensor, ...]:
             _input_jacobians = self._basic_computation_tracincp(
                 inputs_batch[0:-1],
                 inputs_batch[-1],
@@ -871,12 +864,10 @@ def compute_intermediate_quantities(
             the variable d in the top of page 15 of the TracIn paper:
             https://arxiv.org/pdf/2002.08484.pdf.
         """
-        # If `inputs` is not a `DataLoader`, turn it into one.
-        inputs = _format_inputs_dataset(inputs)
+        f_inputs: DataLoader = _format_inputs_dataset(inputs)
 
-        # pyre-fixme[3]: Return type must be annotated.
-        # pyre-fixme[2]: Parameter must be annotated.
-        def get_checkpoint_contribution(checkpoint):
+        def get_checkpoint_contribution(checkpoint: str) -> Tensor:
+            nonlocal f_inputs
             assert (
                 checkpoint is not None
             ), "None returned from `checkpoints`, cannot load."
@@ -885,19 +876,13 @@ def get_checkpoint_contribution(checkpoint):
             # get jacobians as tuple of tensors
             if aggregate:
                 inputs_jacobians = self._sum_jacobians(
-                    # pyre-fixme[6]: For 1st argument expected
-                    #  `DataLoader[typing.Any]` but got `Union[DataLoader[typing.Any],
-                    #  typing.Tuple[typing.Any, ...]]`.
-                    inputs,
+                    f_inputs,
                     self.loss_fn,
                     self.reduction_type,
                 )
             else:
                 inputs_jacobians = self._concat_jacobians(
-                    # pyre-fixme[6]: For 1st argument expected
-                    #  `DataLoader[typing.Any]` but got `Union[DataLoader[typing.Any],
-                    #  typing.Tuple[typing.Any, ...]]`.
-                    inputs,
+                    f_inputs,
                     self.loss_fn,
                     self.reduction_type,
                 )
@@ -932,9 +917,9 @@ def _influence_batch_tracincp(
             computed by `_get_checkpoint_jacobians`.
         """
 
-        # pyre-fixme[3]: Return type must be annotated.
-        # pyre-fixme[2]: Parameter must be annotated.
-        def get_checkpoint_contribution(input_jacobians, checkpoint):
+        def get_checkpoint_contribution(
+            input_jacobians: Tuple[Tensor, ...], checkpoint: str
+        ) -> Tensor:
 
             assert (
                 checkpoint is not None
@@ -1224,7 +1209,7 @@ def _self_influence_by_checkpoints(
         if show_progress:
             # Try to determine length of inner progress bar if possible, with a default
             # of `None`.
-            inputs_len = None
+            inputs_len: Optional[int] = None
             try:
                 inputs_len = len(inputs)
             except TypeError:
@@ -1237,9 +1222,8 @@ def _self_influence_by_checkpoints(
                     stacklevel=1,
                 )
 
-        # pyre-fixme[3]: Return type must be annotated.
         # pyre-fixme[2]: Parameter must be annotated.
-        def calculate_via_vector_norm(layer_jacobian):
+        def calculate_via_vector_norm(layer_jacobian) -> Tensor:
             # Helper to efficiently calculate vector norm if pytorch version permits.
             return (
                 torch.linalg.vector_norm(
                     layer_jacobian,
                     dim=list(range(1, len(layer_jacobian.shape))),
                 )
                 ** 2
             )
 
-        # pyre-fixme[53]: Captured variable `inputs_len` is not annotated.
-        # pyre-fixme[3]: Return type must be annotated.
-        # pyre-fixme[2]: Parameter must be annotated.
-        def get_checkpoint_contribution(checkpoint):
+        def get_checkpoint_contribution(checkpoint: str) -> Tensor:
+            nonlocal inputs_len
             # This function returns a 1D tensor representing the contribution to the
             # self influence score for the given checkpoint, for all batches in
             # `inputs`. The length of the 1D tensor is the total number of
diff --git a/captum/influence/_utils/common.py b/captum/influence/_utils/common.py
index 4a25ccb72..ba3ba0f85 100644
--- a/captum/influence/_utils/common.py
+++ b/captum/influence/_utils/common.py
@@ -65,7 +65,7 @@ def _tensor_batch_dot(t1: Tensor, t2: Tensor) -> Tensor:
 
 
 def _gradient_dot_product(
-    input_grads: Tuple[Tensor], src_grads: Tuple[Tensor]
+    input_grads: Tuple[Tensor, ...], src_grads: Tuple[Tensor, ...]
 ) -> Tensor:
     r"""
     Computes the dot product between the gradient vector for a model on an input batch
@@ -334,9 +334,10 @@ def __len__(self) -> int:
         return len(self._l)
 
 
-# pyre-fixme[3]: Return type must be annotated.
-# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-def _format_inputs_dataset(inputs_dataset: Union[Tuple[Any, ...], DataLoader]):
+def _format_inputs_dataset(
+    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
+    inputs_dataset: Union[Tuple[Any, ...], DataLoader]
+) -> DataLoader:
     # if `inputs_dataset` is not a `DataLoader`, turn it into one.
     # `_DatasetFromList` turns a list into a `Dataset` where `__getitem__`
     # returns an element in the list, and using it to construct a `DataLoader`
diff --git a/tests/helpers/influence/common.py b/tests/helpers/influence/common.py
index bedba7693..1369d96d4 100644
--- a/tests/helpers/influence/common.py
+++ b/tests/helpers/influence/common.py
@@ -409,7 +409,7 @@ def get_random_model_and_data(
         in_features, out_features, num_samples, use_gpu, unpack_inputs
     )
 
-    net: Union[BasicLinearNet, MultLinearNet, Linear, UnpackLinear]
+    net: Module  # Union[BasicLinearNet, MultLinearNet, Linear, UnpackLinear]
     if model_type == "random":
         net = (
             BasicLinearNet(in_features, hidden_nodes, out_features)
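
Note for reviewers (not part of the patch): the sketch below is a minimal, standalone illustration of the annotation pattern these hunks apply, i.e. parameterizing bare `Callable`/`Iterator` generics and declaring attribute types at the assignment site so that `# pyre-fixme[24]`/`# pyre-fixme[4]` suppressions can be dropped. `ToyModel`, `load_checkpoint`, and `Example` are hypothetical stand-ins for `torch.nn.Module`, `_load_flexible_state_dict`, and the TracIn classes; nothing here is taken from the captum API beyond what the diff itself shows.

from typing import Callable, Iterator, List, Optional, Union


class ToyModel:
    """Hypothetical stand-in for `torch.nn.Module`."""


def load_checkpoint(model: ToyModel, path: str) -> float:
    """Hypothetical stand-in for `_load_flexible_state_dict`.

    Pretends to load `path` into `model` and returns the stored learning rate.
    """
    return 1.0


class Example:
    def __init__(
        self,
        model: ToyModel,
        # Fully parameterized generics: spelling out the iterator's element type
        # and the callable's signature removes the need for `# pyre-fixme[24]`.
        checkpoints: Union[str, List[str], Iterator[str]],
        checkpoints_load_func: Callable[[ToyModel, str], float] = load_checkpoint,
    ) -> None:
        self.model = model
        # Declaring attribute types at the assignment site lets pyre infer them
        # everywhere else, replacing `# pyre-fixme[4]` suppressions.
        self.layer_names: Optional[List[str]] = None
        # Normalize `checkpoints` to a concrete `List[str]`, mirroring the
        # `_checkpoints` attribute annotation added in the patch.
        self._checkpoints: List[str] = (
            [checkpoints] if isinstance(checkpoints, str) else list(checkpoints)
        )
        self.checkpoints_load_func = checkpoints_load_func

Usage: `Example(ToyModel(), checkpoints=["ckpt_0.pt"])` type-checks without needing any of the suppressions this patch removes.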