Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Innixma committed Oct 17, 2024
1 parent bc7e3d7 commit 4099d06
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion tst/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ def f():
return pd.DataFrame({"a": [1, 2], "b": [3, 4]})

for ignore_cache in [True, False]:
res = cache_function_dataframe(f, "f", ignore_cache=ignore_cache)
res = cache_function_dataframe(f, "f", cache_path="tmp_cache_dir", ignore_cache=ignore_cache)
pd.testing.assert_frame_equal(res, pd.DataFrame({"a": [1, 2], "b": [3, 4]}))
13 changes: 9 additions & 4 deletions tst/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def verify_equivalent_repository(
assert repo1.datasets() == repo2.datasets()
assert sorted(repo1.dataset_fold_config_pairs()) == sorted(repo2.dataset_fold_config_pairs())
if verify_metrics:
assert repo1.metrics().equals(repo2.metrics())
metrics1 = repo1.metrics().sort_index()
metrics2 = repo2.metrics().sort_index()
assert metrics1.equals(metrics2)
if verify_predictions:
for dataset in repo1.datasets():
for f in repo1.folds:
Expand Down Expand Up @@ -60,7 +62,9 @@ def verify_equivalent_repository(
columns1 = sorted(list(baselines1.columns))
columns2 = sorted(list(baselines2.columns))
assert columns1 == columns2
assert baselines1[columns1].equals(baselines2[columns1])
baselines1 = baselines1[columns1].sort_values(by=columns1, ignore_index=True)
baselines2 = baselines2[columns1].sort_values(by=columns1, ignore_index=True)
assert baselines1.equals(baselines2)
else:
assert baselines1 == baselines2
if verify_metadata:
Expand All @@ -71,9 +75,10 @@ def verify_equivalent_repository(
else:
columns1 = sorted(list(metadata1.columns))
columns2 = sorted(list(metadata2.columns))
print(len(metadata1))
assert columns1 == columns2
assert metadata1[columns1].equals(metadata2[columns1])
metadata1 = metadata1[columns1].sort_values(by=columns1, ignore_index=True)
metadata2 = metadata2[columns1].sort_values(by=columns1, ignore_index=True)
assert metadata1.equals(metadata2)
if verify_configs_hyperparameters:
assert repo1.configs_hyperparameters() == repo2.configs_hyperparameters()

Expand Down

0 comments on commit 4099d06

Please sign in to comment.