add support for fixed parameters (facebookresearch#457)
Summary: Pull Request resolved: facebookresearch#457. Add support for fixed parameters. These parameters are set only in the config and are removed before any model or generator is aware of them (then added back in whenever a model or generator is asked for an output).

Differential Revision: D66012863
1 parent cb9f00c · commit e0b6c77 · 6 changed files with 288 additions and 6 deletions.
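To make the summary concrete, below is a minimal sketch of how a fixed parameter might be declared and read from a config. The "value" key matches the config.getfloat(name, "value") lookup in Fixed.get_config_options in the diff below; the section name "duration", the "par_type = fixed" line, and the config_str constructor argument are illustrative assumptions that this diff does not show.

from aepsych.config import Config

# Hypothetical config block: "duration" is pinned to 0.5 and should never reach
# the model or generator directly. "par_type = fixed" is an assumed spelling for
# selecting this transform; only the "value" key is confirmed by the diff.
config_str = """
[duration]
par_type = fixed
value = 0.5
"""

config = Config(config_str=config_str)

# Mirrors the lookup performed by Fixed.get_config_options in the new file.
print(config.getfloat("duration", "value"))  # 0.5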
@@ -0,0 +1,121 @@

#!/usr/bin/env python3
# Copyright (c) Meta, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List, Optional, Union

import torch
from aepsych.config import Config
from aepsych.transforms.ops.base import Transform


class Fixed(Transform, torch.nn.Module):
    def __init__(
        self,
        indices: List[int],
        values: List[Union[float, int]],
        transform_on_train: bool = True,
        transform_on_eval: bool = True,
        transform_on_fantasize: bool = True,
        reverse: bool = False,
        **kwargs,
    ) -> None:
        """Initialize a fixed transform. It will add and remove fixed values from
        tensors.

        Args:
            indices (List[int]): The indices of the parameters to be fixed.
            values (List[Union[float, int]]): The values of the fixed parameters.
            transform_on_train (bool): A boolean indicating whether to apply the
                transform in train() mode. Default: True.
            transform_on_eval (bool): A boolean indicating whether to apply the
                transform in eval() mode. Default: True.
            transform_on_fantasize (bool): A boolean indicating whether to apply the
                transform when called from within a `fantasize` call. Default: True.
            reverse (bool): A boolean indicating whether the forward pass should
                untransform the inputs. Default: False.
            **kwargs: Accepted to conform to API.
        """
        # Turn indices and values into tensors
        indices_ = torch.tensor(indices, dtype=torch.long)
        values_ = torch.tensor(values, dtype=torch.float64)

        # Sort indices and values together so values are re-inserted in index order
        sort_idx = torch.argsort(indices_)
        indices_ = indices_[sort_idx]
        values_ = values_[sort_idx]

        super().__init__()
        self.register_buffer("indices", indices_)
        self.register_buffer("values", values_)
        self.transform_on_train = transform_on_train
        self.transform_on_eval = transform_on_eval
        self.transform_on_fantasize = transform_on_fantasize
        self.reverse = reverse

    def _transform(self, X: torch.Tensor) -> torch.Tensor:
        r"""Transform the input tensor by popping out the fixed parameters at the
        specified indices.

        Args:
            X (torch.Tensor): A `batch_shape x n x d`-dim tensor of inputs.

        Returns:
            torch.Tensor: The input tensor with the fixed parameters removed.
        """
        X = X.clone()

        mask = ~torch.isin(torch.arange(X.shape[1]), self.indices)

        X = X[:, mask]

        return X

    def _untransform(self, X: torch.Tensor) -> torch.Tensor:
        r"""Transform the input tensor by adding back in the fixed parameters at the
        specified indices.

        Args:
            X (torch.Tensor): A `batch_shape x n x d`-dim tensor of transformed inputs.

        Returns:
            torch.Tensor: The same tensor as the input with the fixed parameters added
                back in.
        """
        X = X.clone()

        for i, idx in enumerate(self.indices):
            pre_fixed = X[:, :idx]
            post_fixed = X[:, idx:]
            fixed = torch.tile(self.values[i], (X.shape[0], 1))
            X = torch.cat((pre_fixed, fixed, post_fixed), dim=1)

        return X

    @classmethod
    def get_config_options(
        cls,
        config: Config,
        name: Optional[str] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Return a dictionary of the relevant options to initialize a Fixed parameter
        transform for the named parameter within the config.

        Args:
            config (Config): Config to look for options in.
            name (str, optional): Parameter to find options for.
            options (Dict[str, Any], optional): Options to override from the config.

        Returns:
            Dict[str, Any]: A dictionary of options to initialize this class with,
                including the fixed value read from the config.
        """
        options = super().get_config_options(config=config, name=name, options=options)

        if "values" not in options:
            options["values"] = [config.getfloat(name, "value")]

        return options
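As an illustrative round trip (not part of the commit), the sketch below constructs the transform directly and exercises the _transform/_untransform hooks defined above; the import path is an assumption, since this diff does not show where the new file lives.

import torch

from aepsych.transforms.ops.fixed import Fixed  # assumed module path; the diff omits the filename

# Fix the parameter at column index 1 to the value 0.5.
fixed = Fixed(indices=[1], values=[0.5])

X = torch.tensor(
    [[0.1, 0.5, 0.9],
     [0.2, 0.5, 0.8]],
    dtype=torch.float64,
)

# _transform drops the fixed column, so models and generators only ever see two columns.
X_inner = fixed._transform(X)  # shape (2, 2): columns 0 and 2 remain

# _untransform re-inserts the fixed value at index 1, restoring the original layout.
X_back = fixed._untransform(X_inner)  # shape (2, 3)
assert torch.equal(X_back, X)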