add support for categorical parameters #449

Open
wants to merge 3 commits into main
aepsych/config.py (35 changes: 31 additions & 4 deletions)
@@ -176,14 +176,19 @@ def update(
             par_names = self.getlist(
                 "common", "parnames", element_type=str, fallback=[]
             )
-            lb = [None] * len(par_names)
-            ub = [None] * len(par_names)
+            lb = []
+            ub = []
             for i, par_name in enumerate(par_names):
                 # Validate the parameter-specific block
                 self._check_param_settings(par_name)
 
-                lb[i] = self[par_name]["lower_bound"]
-                ub[i] = self[par_name]["upper_bound"]
+                if self[par_name]["par_type"] == "categorical":
+                    choices = self.getlist(par_name, "choices", element_type=str)
+                    lb.append("0")
+                    ub.append(str(len(choices) - 1))
+                else:
+                    lb.append(self[par_name]["lower_bound"])
+                    ub.append(self[par_name]["upper_bound"])
 
             self["common"]["lb"] = f"[{', '.join(lb)}]"
             self["common"]["ub"] = f"[{', '.join(ub)}]"
@@ -260,6 +265,28 @@ def _check_param_settings(self, param_name: str) -> None:
                 raise ValueError(
                     f"Parameter {param_name} is missing the upper_bound setting."
                 )
+        elif param_block["par_type"] == "integer":
+            # Check that bounds exist and are actually integers
+            if "lower_bound" not in param_block:
+                raise ValueError(
+                    f"Parameter {param_name} is missing the lower_bound setting."
+                )
+            if "upper_bound" not in param_block:
+                raise ValueError(
+                    f"Parameter {param_name} is missing the upper_bound setting."
+                )
+
+            if not (
+                self.getint(param_name, "lower_bound") % 1 == 0
+                and self.getint(param_name, "upper_bound") % 1 == 0
+            ):
+                raise ValueError(f"Parameter {param_name} has non-integer bounds.")
+        elif param_block["par_type"] == "categorical":
+            # Need a choices array
+            if "choices" not in param_block:
+                raise ValueError(
+                    f"Parameter {param_name} is missing the choices setting."
+                )
         else:
             raise ValueError(
                 f"Parameter {param_name} has an unsupported parameter type {param_block['par_type']}."
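
Together with the existing branch above (presumably the continuous case), the validator now accepts three parameter types. A hedged example of a config fragment that should pass these checks; the parameter names are illustrative, and the bracketed list syntax for choices is assumed to follow the convention already used for parnames:

```python
# Illustrative config fragment exercising all three supported par_types.
config_str = """
[common]
parnames = [contrast, trials, color]

[contrast]
par_type = continuous
lower_bound = 0
upper_bound = 1

[trials]
par_type = integer
lower_bound = 1
upper_bound = 10

[color]
par_type = categorical
choices = [red, green, blue]
"""
# Dropping `choices` from [color] would raise:
#   ValueError: Parameter color is missing the choices setting.
```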
aepsych/server/server.py (21 changes: 15 additions & 6 deletions)
@@ -276,22 +276,31 @@ def can_pregen_ask(self):
         return self.strat is not None and self.enable_pregen
 
     def _tensor_to_config(self, next_x):
+        next_x = self.strat.transforms.indices_to_str(next_x.unsqueeze(0))[0]
         config = {}
         for name, val in zip(self.parnames, next_x):
-            if val.dim() == 0:
+            if isinstance(val, str):
+                config[name] = [val]
+            elif isinstance(val, (int, float)):
                 config[name] = [float(val)]
+            elif isinstance(val[0], str):
+                config[name] = val
             else:
-                config[name] = np.array(val)
+                config[name] = np.array(val, dtype="float64")
         return config
 
     def _config_to_tensor(self, config):
         unpacked = [config[name] for name in self.parnames]
 
         # handle config elements being either scalars or length-1 lists
         if isinstance(unpacked[0], list):
-            x = torch.tensor(np.stack(unpacked, axis=0)).squeeze(-1)
+            x = np.stack(unpacked, axis=0, dtype="O").squeeze(-1)
         else:
-            x = torch.tensor(np.stack(unpacked))
+            x = np.stack(unpacked, dtype="O")
+
+        # Unsqueeze batch dimension
+        x = np.expand_dims(x, 0)
+
+        x = self.strat.transforms.str_to_indices(x)[0]
 
         return x
 
     def __getstate__(self):
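
For reviewers, a rough sketch of the round trip these two methods now perform. The real work happens in self.strat.transforms.indices_to_str / str_to_indices, which are not part of this diff, so toy stand-ins (and a made-up "color" parameter) are used here:

```python
import numpy as np

choices = ["red", "green", "blue"]  # hypothetical categorical parameter "color"

def indices_to_str(batch):
    # toy stand-in: numeric indices -> choice labels, as an object-dtype array
    return np.array([[choices[int(i)] for i in row] for row in batch], dtype="O")

def str_to_indices(batch):
    # toy stand-in: choice labels -> numeric indices
    return np.array([[float(choices.index(s)) for s in row] for row in batch])

# server -> client (_tensor_to_config): a generated index becomes a label
next_x = np.array([[2.0]])
labels = indices_to_str(next_x)[0]
config = {"color": [labels[0]]}              # {'color': ['blue']}

# client -> server (_config_to_tensor): the label becomes an index again
unpacked = np.stack([config["color"]], axis=0, dtype="O").squeeze(-1)
x = np.expand_dims(unpacked, 0)              # unsqueeze the batch dimension
print(str_to_indices(x)[0])                  # [2.]
```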