Skip to content

Commit

Permalink
Merge pull request #44 from Innixma/make_deterministic
Browse files Browse the repository at this point in the history
Make experiments deterministic
  • Loading branch information
geoalgo authored Oct 17, 2023
2 parents 366c063 + 5b5ce7c commit 07165cc
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions scripts/baseline_comparison/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def sample_and_pick_best(
print(f"missing data {tid} {fold} {framework_type}")

# shuffle the rows
df_sub = df_sub.sample(frac=1).reset_index(drop=True)
df_sub = df_sub.sample(frac=1, random_state=0).reset_index(drop=True)

# pick only configurations up to max_runtime
if max_runtime:
Expand Down Expand Up @@ -339,8 +339,8 @@ def filter_configurations_above_budget(repo, test_tid, configs, max_runtime, qua
).quantile(q=quantile, axis=1).sort_values()

n_initial_configs = len(configs)
configs_fast_enough = df_configs_runtime[df_configs_runtime < max_runtime].index.tolist()
configs = list(set(configs_fast_enough).intersection(configs))
configs_fast_enough = set(df_configs_runtime[df_configs_runtime < max_runtime].index.tolist())
configs = [c for c in configs if c in configs_fast_enough]
print(f"kept only {len(configs)} from initial {n_initial_configs} for runtime {max_runtime}")
return configs

Expand Down Expand Up @@ -406,6 +406,10 @@ def evaluate_dataset(test_dataset, n_portfolio, n_ensemble, n_training_dataset,
else:
configs += list(np.random.choice(models_framework, n_training_config, replace=False))

# Randomly shuffle the config order with seed 0
rng = np.random.default_rng(seed=0)
configs = list(rng.choice(configs, len(configs), replace=False))

# # exclude configurations from zeroshot selection whose runtime exceeds runtime budget by large amount
if max_runtime:
configs = filter_configurations_above_budget(repo, test_tid, configs, max_runtime)
Expand Down Expand Up @@ -457,7 +461,7 @@ def evaluate_dataset(test_dataset, n_portfolio, n_ensemble, n_training_dataset,
assert not any(df_rank.isna().values.reshape(-1))

model_frameworks = {
framework: [x for x in repo.list_models() if framework in x]
framework: sorted([x for x in repo.list_models() if framework in x])
for framework in framework_types
}

Expand Down

0 comments on commit 07165cc

Please sign in to comment.