diff --git a/aepsych/config.py b/aepsych/config.py
index 4753d60ba..1464311b7 100644
--- a/aepsych/config.py
+++ b/aepsych/config.py
@@ -176,14 +176,19 @@ def update(
         par_names = self.getlist(
             "common", "parnames", element_type=str, fallback=[]
         )
-        lb = [None] * len(par_names)
-        ub = [None] * len(par_names)
+        lb = []
+        ub = []
         for i, par_name in enumerate(par_names):
             # Validate the parameter-specific block
             self._check_param_settings(par_name)
 
-            lb[i] = self[par_name]["lower_bound"]
-            ub[i] = self[par_name]["upper_bound"]
+            if self[par_name]["par_type"] == "categorical":
+                choices = self.getlist(par_name, "choices", element_type=str)
+                lb.append("0")
+                ub.append(str(len(choices) - 1))
+            else:
+                lb.append(self[par_name]["lower_bound"])
+                ub.append(self[par_name]["upper_bound"])
 
         self["common"]["lb"] = f"[{', '.join(lb)}]"
         self["common"]["ub"] = f"[{', '.join(ub)}]"
@@ -260,6 +265,28 @@ def _check_param_settings(self, param_name: str) -> None:
                 raise ValueError(
                     f"Parameter {param_name} is missing the upper_bound setting."
                 )
+        elif param_block["par_type"] == "integer":
+            # Check if bounds exist and are actually integers
+            if "lower_bound" not in param_block:
+                raise ValueError(
+                    f"Parameter {param_name} is missing the lower_bound setting."
+                )
+            if "upper_bound" not in param_block:
+                raise ValueError(
+                    f"Parameter {param_name} is missing the upper_bound setting."
+                )
+
+            if not (
+                self.getint(param_name, "lower_bound") % 1 == 0
+                and self.getint(param_name, "upper_bound") % 1 == 0
+            ):
+                raise ValueError(f"Parameter {param_name} has non-integer bounds.")
+        elif param_block["par_type"] == "categorical":
+            # Need a choices array
+            if "choices" not in param_block:
+                raise ValueError(
+                    f"Parameter {param_name} is missing the choices setting."
+                )
         else:
             raise ValueError(
                 f"Parameter {param_name} has an unsupported parameter type {param_block['par_type']}."
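With this change, a categorical parameter declares only its `choices` and `Config.update()` derives index bounds of `0` through `len(choices) - 1` for it. A minimal, hypothetical sketch of that behavior; the parameter names and values are made up, and it assumes a minimal config string like this is accepted by `Config(config_str=...)`:

```python
from aepsych.config import Config

# Hypothetical two-parameter problem: "contrast" is continuous, "color" is categorical.
config_str = """
[common]
parnames = [contrast, color]

[contrast]
par_type = continuous
lower_bound = 0.5
upper_bound = 1.0

[color]
par_type = categorical
choices = [red, green, blue]
"""

config = Config(config_str=config_str)

# update() writes the derived bounds back into the common block:
print(config["common"]["lb"])  # [0.5, 0]
print(config["common"]["ub"])  # [1.0, 2]
```
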
diff --git a/aepsych/server/server.py b/aepsych/server/server.py
index fd2d33e3e..df1dfdfce 100644
--- a/aepsych/server/server.py
+++ b/aepsych/server/server.py
@@ -276,22 +276,31 @@ def can_pregen_ask(self):
         return self.strat is not None and self.enable_pregen
 
     def _tensor_to_config(self, next_x):
+        next_x = self.strat.transforms.indices_to_str(next_x.unsqueeze(0))[0]
         config = {}
         for name, val in zip(self.parnames, next_x):
-            if val.dim() == 0:
+            if isinstance(val, str):
+                config[name] = [val]
+            elif isinstance(val, (int, float)):
                 config[name] = [float(val)]
+            elif isinstance(val[0], str):
+                config[name] = val
             else:
-                config[name] = np.array(val)
+                config[name] = np.array(val, dtype="float64")
         return config
 
     def _config_to_tensor(self, config):
         unpacked = [config[name] for name in self.parnames]
-
-        # handle config elements being either scalars or length-1 lists
         if isinstance(unpacked[0], list):
-            x = torch.tensor(np.stack(unpacked, axis=0)).squeeze(-1)
+            x = np.stack(unpacked, axis=0, dtype="O").squeeze(-1)
         else:
-            x = torch.tensor(np.stack(unpacked))
+            x = np.stack(unpacked, dtype="O")
+
+        # Unsqueeze batch dimension
+        x = np.expand_dims(x, 0)
+
+        x = self.strat.transforms.str_to_indices(x)[0]
+
         return x
 
     def __getstate__(self):
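The server change above means ask/tell payloads can now carry categorical values as strings; `_tensor_to_config` and `_config_to_tensor` round-trip them through the strategy's transforms. A minimal sketch against the patched `aepsych.transforms.parameters` module; the two-parameter layout and the "color" choices are made up for illustration:

```python
import numpy as np
import torch
from aepsych.transforms.parameters import Categorical, ParameterTransforms

# Hypothetical transforms standing in for strat.transforms in the server code above:
# index 0 is continuous, index 1 is a categorical "color" parameter.
transforms = ParameterTransforms(
    color=Categorical(indices=[1], categorical_map={1: ["red", "green", "blue"]})
)

# ask direction: model-space indices -> strings (what _tensor_to_config consumes)
next_x = torch.tensor([0.731, 2.0])  # index 2 maps to "blue"
print(transforms.indices_to_str(next_x.unsqueeze(0))[0])
# object array like [0.731 'blue']; the server packs it as
# {"contrast": [0.731], "color": ["blue"]}

# tell direction: mixed object array -> indices (what _config_to_tensor produces)
obj = np.array([[0.731, "blue"]], dtype="O")
print(transforms.str_to_indices(obj)[0])
# tensor([0.7310, 2.0000], dtype=torch.float64)
```
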
diff --git a/aepsych/transforms/parameters.py b/aepsych/transforms/parameters.py
index 066e90e3c..66798b617 100644
--- a/aepsych/transforms/parameters.py
+++ b/aepsych/transforms/parameters.py
@@ -9,7 +9,18 @@
 from abc import ABC
 from configparser import NoOptionError
 from copy import deepcopy
-from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
 
 import numpy as np
 import torch
@@ -17,7 +28,13 @@
 from aepsych.generators.base import AEPsychGenerator
 from aepsych.models.base import AEPsychMixin, ModelProtocol
 from botorch.acquisition import AcquisitionFunction
-from botorch.models.transforms.input import ChainedInputTransform, Log10, Normalize
+from botorch.models.transforms.input import (
+    ChainedInputTransform,
+    Log10,
+    Normalize,
+    ReversibleInputTransform,
+    InputTransform,
+)
 from botorch.models.transforms.utils import subset_transform
 from botorch.posteriors import Posterior
 from torch import Tensor
@@ -39,10 +56,34 @@ class ParameterTransforms(ChainedInputTransform, ConfigurableMixin):
     space back into raw space.
     """
 
+    def __init__(
+        self,
+        **transforms: InputTransform,
+    ) -> None:
+        self.cat_map_raw = {}
+        transform_keys = list(transforms.keys())
+        for key in transform_keys:
+            if isinstance(transforms[key], Categorical):
+                categorical = transforms.pop(key)
+                self.cat_map_raw.update(categorical.cat_map_raw)
+
+        if len(self.cat_map_raw) > 0:
+            # Remake the categorical and put it at the end
+            transforms["_CombinedCategorical"] = Categorical(
+                indices=list(self.cat_map_raw.keys()), categorical_map=self.cat_map_raw
+            )
+            self.cat_map_transformed = transforms[
+                "_CombinedCategorical"
+            ].cat_map_transformed
+        else:
+            self.cat_map_transformed = {}
+
+        super().__init__(**transforms)
+
     def _temporary_reshape(func: Callable) -> Callable:
         # Decorator to reshape tensors to the expected 2D shape, even if the input was
         # 1D or 3D and after the transform reshape it back to the original.
-        def wrapper(self, X: Tensor) -> Tensor:
+        def wrapper(self, X: Tensor, **kwargs) -> Tensor:
             squeeze = False
             if len(X.shape) == 1:  # For 1D inputs, primarily for transforming arguments
                 X = X.unsqueeze(0)
@@ -54,7 +95,7 @@ def wrapper(self, X: Tensor) -> Tensor:
                 X = X.swapaxes(-2, -1).reshape(-1, dim)
                 reshape = True
 
-            X = func(self, X)
+            X = func(self, X, **kwargs)
 
             if reshape:
                 X = X.reshape(batch, stim, -1).swapaxes(-1, -2)
@@ -80,6 +121,27 @@ def transform(self, X: Tensor) -> Tensor:
         """
         return super().transform(X)
 
+    @_temporary_reshape
+    def transform_bounds(
+        self, X: Tensor, bound: Optional[Literal["lb", "ub"]] = None
+    ) -> Tensor:
+        r"""Transform bounds of a parameter.
+
+        Individual transforms are applied in sequence. Each transform's own
+        transform_bounds method is used where it defines one; otherwise the normal
+        transform is applied.
+
+        Args:
+            X: A tensor of inputs. Either `[dim]`, `[batch, dim]`, or `[batch, dim, stimuli]`.
+            bound: Which bound is being transformed ("lb" or "ub"); if None, X is
+                assumed to stack both bounds as `[2, dim]`.
+
+        Returns:
+            A tensor of transformed inputs with the same shape as the input.
+        """
+        for tf in self.values():
+            X = tf.transform_bounds(X, bound=bound)
+
+        return X
+
     @_temporary_reshape
     def untransform(self, X: Tensor) -> Tensor:
         r"""Un-transform the inputs to a model.
@@ -94,6 +156,45 @@ def untransform(self, X: Tensor) -> Tensor:
         """
         return super().untransform(X)
 
+    @_temporary_reshape
+    def indices_to_str(self, X: Tensor) -> np.ndarray:
+        r"""Return a NumPy array of objects where the categorical parameters will be
+        strings.
+
+        Args:
+            X (Tensor): A tensor shaped `[batch, dim]` to turn into a mixed type NumPy
+                array.
+
+        Returns:
+            np.ndarray: An array with the object type where the categorical parameters
+                are strings.
+        """
+        obj_arr = X.cpu().numpy().astype("O")
+
+        for idx, cats in self.cat_map_raw.items():
+            obj_arr[:, idx] = [cats[int(i)] for i in obj_arr[:, idx]]
+
+        return obj_arr
+
+    @_temporary_reshape
+    def str_to_indices(self, obj_arr: np.ndarray) -> Tensor:
+        r"""Return a Tensor where the categorical parameters are converted from strings
+        to indices.
+
+        Args:
+            obj_arr (np.ndarray): A NumPy array `[batch, dim]` where the categorical
+                parameters are strings.
+
+        Returns:
+            Tensor: A tensor with the categorical parameters converted to indices.
+ """ + obj_arr = obj_arr[:] + + for idx, cats in self.cat_map_raw.items(): + obj_arr[:, idx] = [cats.index(cat) for cat in obj_arr[:, idx]] + + return torch.tensor(obj_arr.astype("float64"), dtype=torch.float64) + @classmethod def get_config_options( cls, @@ -132,6 +233,31 @@ def get_config_options( for par in parnames: # This is the order that transforms are potentially applied, order matters + try: + par_type = config[par]["par_type"] + except KeyError: # Probably because par doesn't have its own section + par_type = "continuous" + + # Integer variable + if par_type == "integer": + round = Round.from_config( + config=config, name=par, options=transform_options + ) + + # Transform bounds + transform_options["bounds"] = round.transform_bounds( + transform_options["bounds"] + ) + transform_dict[f"{par}_Round"] = round + + # Categorical variable + elif par_type == "categorical": + categorical = Categorical.from_config( + config=config, name=par, options=transform_options + ) + + transform_dict[f"{par}_Categorical"] = categorical + # Log scale if config.getboolean(par, "log_scale", fallback=False): log10 = Log10Plus.from_config( @@ -144,8 +270,10 @@ def get_config_options( ) transform_dict[f"{par}_Log10Plus"] = log10 - # Normalize scale (defaults true) - if config.getboolean(par, "normalize_scale", fallback=True): + # Normalize scale (defaults true), don't do this for categoricals + if par_type != "categorical" and config.getboolean( + par, "normalize_scale", fallback=True + ): normalize = NormalizeScale.from_config( config=config, name=par, options=transform_options ) @@ -720,7 +848,67 @@ def get_config_options( return options -class Log10Plus(Log10, ConfigurableMixin): +class Transform(ReversibleInputTransform, ConfigurableMixin, ABC): + """Base class for individual transforms. These transforms are intended to be stacked + together using the ParameterTransforms class. + """ + + def transform_bounds( + self, X: torch.Tensor, bound: Optional[Literal["lb", "ub"]] = None, **kwargs + ) -> torch.Tensor: + r"""Return the bounds X transformed. + + Args: + X (torch.Tensor): Either a `[1, dim]` or `[2, dim]` tensor of parameter + bounds. + bound (Literal["lb", "ub"], optional): Which bound this is to transform, if + None, it's the `[2, dim]` form with both bounds stacked. + **kwargs: Keyword arguments for specific transforms, they should have + default values. + + Returns: + torch.Tensor: A transformed set of parameter bounds. + """ + return self.transform(X) + + @classmethod + def get_config_options( + cls, + config: Config, + name: Optional[str] = None, + options: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """Return a dictionary of the relevant options to initialize a Log10Plus + transform for the named parameter within the config. + + Args: + config (Config): Config to look for options in. + name (str): Parameter to find options for. + options (Dict[str, Any]): Options to override from the config. + + Returns: + Dict[str, Any]: A diciontary of options to initialize this class with, + including the transformed bounds. 
+ """ + if name is None: + raise ValueError(f"{name} must be set to initialize a transform.") + + if options is None: + options = {} + else: + options = deepcopy(options) + + # Figure out the index of this parameter + parnames = config.getlist("common", "parnames", element_type=str) + idx = parnames.index(name) + + if "indices" not in options: + options["indices"] = [idx] + + return options + + +class Log10Plus(Log10, Transform): """Base-10 log transform that we add a constant to the values""" def __init__( @@ -803,7 +991,7 @@ def get_config_options( Dict[str, Any]: A diciontary of options to initialize this class with, including the transformed bounds. """ - options = _get_parameter_options(config, name, options) + options = super().get_config_options(config=config, name=name, options=options) # Make sure we have bounds ready if "bounds" not in options: @@ -823,7 +1011,7 @@ def get_config_options( return options -class NormalizeScale(Normalize, ConfigurableMixin): +class NormalizeScale(Normalize, Transform): def __init__( self, d: int, @@ -901,15 +1089,281 @@ def get_config_options( Dict[str, Any]: A diciontary of options to initialize this class with, including the transformed bounds. """ - options = _get_parameter_options(config, name, options) + options = super().get_config_options(config=config, name=name, options=options) # Make sure we have bounds ready if "bounds" not in options: options["bounds"] = get_bounds(config) if "d" not in options: - parnames = config.getlist("common", "parnames", element_type=str) - options["d"] = len(parnames) + options["d"] = options["bounds"].shape[1] + + return options + + +class Round(Transform, torch.nn.Module): + def __init__( + self, + indices: list[int], + transform_on_train: bool = True, + transform_on_eval: bool = True, + transform_on_fantasize: bool = True, + reverse: bool = False, + **kwargs, + ) -> None: + """Initialize a round transform. This operation rounds the inputs at the indices + in both direction. + + Args: + indices: The indices of the inputs to round. + transform_on_train: A boolean indicating whether to apply the + transforms in train() mode. Default: True. + transform_on_eval: A boolean indicating whether to apply the + transform in eval() mode. Default: True. + transform_on_fantasize: Currently will not do anything, here to conform to + API. + reverse: Whether to round in forward or backward passes. + **kwargs: Accepted to conform to API. + """ + super().__init__() + self.register_buffer("indices", torch.tensor(indices, dtype=torch.long)) + self.transform_on_train = transform_on_train + self.transform_on_eval = transform_on_eval + self.transform_on_fantasize = transform_on_fantasize + self.reverse = reverse + + @subset_transform + def _transform(self, X: torch.Tensor) -> torch.Tensor: + r"""Round the inputs to a model to be discrete. This rounding is the same both + in the forward and the backward pass. + + Args: + X (torch.Tensor): A `batch_shape x n x d`-dim tensor of inputs. + + Returns: + torch.Tensor: The input tensor with values rounded. + """ + return X.round() + + @subset_transform + def _untransform(self, X: Tensor) -> Tensor: + r"""Round the inputs to a model to be discrete. This rounding is the same both + in the forward and the backward pass. + + Args: + X (torch.Tensor): A `batch_shape x n x d`-dim tensor of transformed inputs. + + Returns: + torch.Tensor: The input tensor with values rounded. 
+ """ + return X.round() + + def transform_bounds( + self, X: torch.Tensor, bound: Optional[Literal["lb", "ub"]] = None, **kwargs + ) -> torch.Tensor: + r"""Return the bounds X transformed. + + Args: + X (torch.Tensor): Either a `[1, dim]` or `[2, dim]` tensor of parameter + bounds. + bound (Literal["lb", "ub"], optional): The bound that this is, if None, we + will assume the input is both bounds with a `[2, dim]` X. + **kwargs: passed to _transform_bounds + epsilon: will modify the offset for the rounding to ensure each discrete + value has equal space in the parameter space. + + Returns: + torch.Tensor: A transformed set of parameter bounds. + """ + epsilon = kwargs.get("epsilon", 1e-6) + return self._transform_bounds(X, bound=bound, epsilon=epsilon) + + def _transform_bounds( + self, + X: torch.Tensor, + bound: Optional[Literal["lb", "ub"]] = None, + epsilon: float = 1e-6, + ) -> torch.Tensor: + r"""Return the bounds X transformed. + + Args: + X (torch.Tensor): Either a `[1, dim]` or `[2, dim]` tensor of parameter + bounds. + bound (Literal["lb", "ub"], optional): The bound that this is, if None, we + will assume the input is both bounds with a `[2, dim]` X. + epsilon: + **kwargs: other kwargs + + Returns: + torch.Tensor: A transformed set of parameter bounds. + """ + X = X.clone() + + if bound == "lb": + X[0, self.indices] -= torch.tensor([0.5] * len(self.indices)) + elif bound == "ub": + X[0, self.indices] += torch.tensor([0.5 - epsilon] * len(self.indices)) + else: # Both bounds + X[0, self.indices] -= torch.tensor([0.5] * len(self.indices)) + X[1, self.indices] += torch.tensor([0.5 - epsilon] * len(self.indices)) + + return X + + +class Categorical(ReversibleInputTransform, torch.nn.Module, ConfigurableMixin): + is_one_to_many = True + + def __init__( + self, + indices: list[int], + categorical_map: Dict[int, List[str]], + transform_on_train: bool = True, + transform_on_eval: bool = True, + transform_on_fantasize: bool = True, + reverse: bool = False, + **kwargs, + ) -> None: + """Initialize a Categorical transform. Takes the integer at the indices and + converts it to one_hot starting from that indices (and therefore pushing) + forward other indices. + + Args: + indices: The indices of the inputs to turn into categoricals. + categorical_map: A dictionary where the key is the index of the categorical + variable and the values are the possible categories. + transform_on_train: A boolean indicating whether to apply the + transforms in train() mode. Default: True. + transform_on_eval: A boolean indicating whether to apply the + transform in eval() mode. Default: True. + transform_on_fantasize: Currently will not do anything, here to conform to + API. + reverse: A boolean indicating whether the forward pass should untransform + the parameters. + **kwargs: Accepted to conform to API. 
+ """ + # indices needs to be sorted + indices = sorted(indices) + + # Multiple categoricals need to shift indices + categorical_offset = 0 + new_indices = [] + cat_map_transformed = {} + for idx in indices: + category_values = categorical_map[idx] + num_classes = len(categorical_map[idx]) + new_idx = idx + categorical_offset + categorical_offset += num_classes - 1 + + new_indices.append(new_idx) + cat_map_transformed[new_idx] = category_values + + super().__init__() + self.register_buffer("indices", torch.tensor(new_indices, dtype=torch.long)) + self.transform_on_train = transform_on_train + self.transform_on_eval = transform_on_eval + self.transform_on_fantasize = transform_on_fantasize + self.reverse = reverse + self.cat_map_raw = categorical_map + self.cat_map_transformed = cat_map_transformed + + def _transform(self, X: torch.Tensor) -> torch.Tensor: + for idx in self.indices: + num_classes = len(self.cat_map_transformed[idx.item()]) + + # Turns indices into one hot + idxs = X[:, idx].to(torch.long) + one_hot = torch.nn.functional.one_hot(idxs, num_classes=num_classes) + one_hot = one_hot.view(X.shape[0], num_classes) + + # Chop up X and stick one_hot in + pre_categorical = X[:, :idx] + post_categorical = X[:, idx + 1 :] + X = torch.cat((pre_categorical, one_hot, post_categorical), dim=1) + + return X + + def _untransform(self, X: torch.Tensor) -> torch.Tensor: + for idx in reversed(self.indices): + num_classes = len(self.cat_map_transformed[idx.item()]) + + # Chop up X around the one_hot + pre_categorical = X[:, :idx] + one_hot = X[:, idx : idx + num_classes] + post_categorical = X[:, idx + num_classes :] + + # Turn one_hot back into indices + idxs = torch.argmax(one_hot, dim=1).unsqueeze(-1).to(X) + + X = torch.cat((pre_categorical, idxs, post_categorical), dim=1) + + return X + + def transform_bounds( + self, X: torch.Tensor, bound: Optional[Literal["lb", "ub"]] = None + ) -> torch.Tensor: + r"""Return the bounds X transformed. + + Args: + X (torch.Tensor): Either a `[1, dim]` or `[2, dim]` tensor of parameter + bounds. + bound (Literal["lb", "ub"], optional): The bound that this is, if None, we + will assume the input is both bounds with a `[2, dim]` X. + + Returns: + torch.Tensor: A transformed set of parameter bounds. + """ + for idx in self.indices: + num_classes = len(self.cat_map_transformed[idx.item()]) + + # Turns indices into one hot + idxs = X[:, idx].to(torch.long) + one_hot = torch.nn.functional.one_hot(idxs, num_classes=num_classes) + one_hot = one_hot.view(X.shape[0], num_classes) + + if bound == "lb": + one_hot[:] = 0.0 + elif bound == "ub": + one_hot[:] = 1.0 + else: # Both bounds + one_hot[0, :] = 0.0 + one_hot[1, :] = 1.0 + + # Chop up X and stick one_hot in + pre_categorical = X[:, :idx] + post_categorical = X[:, idx + 1 :] + X = torch.cat((pre_categorical, one_hot, post_categorical), dim=1) + + return X + + @classmethod + def get_config_options( + cls, + config: Config, + name: Optional[str] = None, + options: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """ + Return a dictionary of the relevant options to initialize the categorical + transform from the config for the named transform. + + Args: + config (Config): Config to look for options in. + name (str, optional): The parameter to find options for. + options (Dict[str, Any], optional): Options to override from the config, + defaults to None. + + Return: + Dict[str, Any]: A dictionary of options to initialize this class. 
+ """ + options = _get_parameter_options(config, name, options) + + if "categorical_map" not in options: + if name is None: + raise ValueError("name argument must be set to initialize from config.") + + options["categorical_map"] = { + options["indices"][0]: config.getlist(name, "choices", element_type=str) + } return options @@ -938,7 +1392,10 @@ def transform_options( value = np.array(value, dtype=float) value = torch.tensor(value).to(torch.float64) - value = transforms.transform(value) + if option in ["ub", "lb"]: + value = transforms.transform_bounds(value, bound=option) + else: + value = transforms.transform(value) def _arr_to_list(iter): if hasattr(iter, "__iter__"): @@ -982,35 +1439,3 @@ def get_bounds(config: Config) -> torch.Tensor: bounds = torch.stack((_lb, _ub)) return bounds - - -def _get_parameter_options( - config: Config, name: Optional[str] = None, options: Optional[Dict[str, Any]] = None -) -> Dict[str, Any]: - """Return options for a parameter in a config. - - Args: - config (Config): Config to search for parameter. - name (str): Name of parameter. - options (Dict[str, Any], optional): dictionary of options to overwrite config - options, defaults to an empty dictionary. - - Returns: - Dict[str, Any]: Dictionary of options to initialize a transform from config. - """ - if name is None: - raise ValueError(f"{name} must be set to initialize a transform.") - - if options is None: - options = {} - else: - options = deepcopy(options) - - # Figure out the index of this parameter - parnames = config.getlist("common", "parnames", element_type=str) - idx = parnames.index(name) - - if "indices" not in options: - options["indices"] = [idx] - - return options diff --git a/docs/parameters.md b/docs/parameters.md index ef7dfaf14..f93526b66 100644 --- a/docs/parameters.md +++ b/docs/parameters.md @@ -27,6 +27,19 @@ parameters can have any non-infinite ranges. This means that continuous paramete include negative values (e.g., lower bound = -1, upper bound = 1) or have very large ranges (e.g., lower bound = 0, upper bound = 1,000,000). +