Skip to content

Commit

Permalink
Fix GenerationStrategy and GenerataionNode todos (#330)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebook/Ax#2142


In this diff we update the Ax GenerationStrategy code to remove all todos related to aepsych, and use the standard GS code flow for aepsych usecases. This required some minimal updates to the aepsych criterion and one of the storage tests.

In following diffs we will:
- revisit storage
- update AEPsych GSs as needed
- determine if run indefinetly can be replaced by simply having gen_unlimited_trials = true

Reviewed By: lena-kashtelyan

Differential Revision: D52898268
  • Loading branch information
mgarrard authored and facebook-github-bot committed Jan 26, 2024
1 parent 43df46e commit e44d213
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
17 changes: 14 additions & 3 deletions aepsych/generators/completion_criterion/min_asks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,19 @@


class MinAsks(TransitionCriterion, ConfigurableMixin):
def __init__(self, threshold: int) -> None:
def __init__(
self,
threshold: int,
block_transition_if_unmet: Optional[bool] = True,
block_gen_if_met: Optional[bool] = False,
) -> None:
self.threshold = threshold
self.block_transition_if_unmet = block_transition_if_unmet
self.block_gen_if_met = block_gen_if_met

def is_met(self, experiment: Experiment) -> bool:
def is_met(
self, experiment: Experiment, trials_from_node: Optional[Set[int]] = None
) -> bool:
return experiment.num_asks >= self.threshold

def block_continued_generation_error(
Expand All @@ -32,5 +41,7 @@ def block_continued_generation_error(
@classmethod
def get_config_options(cls, config: Config, name: str) -> Dict[str, Any]:
min_asks = config.getint(name, "min_asks", fallback=1)
options = {"threshold": min_asks}
options = {
"threshold": min_asks,
}
return options
13 changes: 11 additions & 2 deletions aepsych/generators/completion_criterion/run_indefinitely.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,19 @@


class RunIndefinitely(TransitionCriterion, ConfigurableMixin):
def __init__(self, run_indefinitely: bool) -> None:
def __init__(
self,
run_indefinitely: bool,
block_transition_if_unmet: Optional[bool] = False,
block_gen_if_met: Optional[bool] = False,
) -> None:
self.run_indefinitely = run_indefinitely
self.block_transition_if_unmet = block_transition_if_unmet
self.block_gen_if_met = block_gen_if_met

def is_met(self, experiment: Experiment) -> bool:
def is_met(
self, experiment: Experiment, trials_from_node: Optional[Set[int]] = None
) -> bool:
return not self.run_indefinitely

def block_continued_generation_error(
Expand Down

0 comments on commit e44d213

Please sign in to comment.