From 4099d062adaa37942f4ad22438bdc6c68e95ed57 Mon Sep 17 00:00:00 2001 From: innixma Date: Thu, 17 Oct 2024 23:56:23 +0000 Subject: [PATCH] fix tests --- tst/test_cache.py | 2 +- tst/test_repository.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/tst/test_cache.py b/tst/test_cache.py index 302a8a63..e4fdd7f5 100644 --- a/tst/test_cache.py +++ b/tst/test_cache.py @@ -17,5 +17,5 @@ def f(): return pd.DataFrame({"a": [1, 2], "b": [3, 4]}) for ignore_cache in [True, False]: - res = cache_function_dataframe(f, "f", ignore_cache=ignore_cache) + res = cache_function_dataframe(f, "f", cache_path="tmp_cache_dir", ignore_cache=ignore_cache) pd.testing.assert_frame_equal(res, pd.DataFrame({"a": [1, 2], "b": [3, 4]})) diff --git a/tst/test_repository.py b/tst/test_repository.py index e69b308c..dd4f0782 100644 --- a/tst/test_repository.py +++ b/tst/test_repository.py @@ -27,7 +27,9 @@ def verify_equivalent_repository( assert repo1.datasets() == repo2.datasets() assert sorted(repo1.dataset_fold_config_pairs()) == sorted(repo2.dataset_fold_config_pairs()) if verify_metrics: - assert repo1.metrics().equals(repo2.metrics()) + metrics1 = repo1.metrics().sort_index() + metrics2 = repo2.metrics().sort_index() + assert metrics1.equals(metrics2) if verify_predictions: for dataset in repo1.datasets(): for f in repo1.folds: @@ -60,7 +62,9 @@ def verify_equivalent_repository( columns1 = sorted(list(baselines1.columns)) columns2 = sorted(list(baselines2.columns)) assert columns1 == columns2 - assert baselines1[columns1].equals(baselines2[columns1]) + baselines1 = baselines1[columns1].sort_values(by=columns1, ignore_index=True) + baselines2 = baselines2[columns1].sort_values(by=columns1, ignore_index=True) + assert baselines1.equals(baselines2) else: assert baselines1 == baselines2 if verify_metadata: @@ -71,9 +75,10 @@ def verify_equivalent_repository( else: columns1 = sorted(list(metadata1.columns)) columns2 = sorted(list(metadata2.columns)) - print(len(metadata1)) assert columns1 == columns2 - assert metadata1[columns1].equals(metadata2[columns1]) + metadata1 = metadata1[columns1].sort_values(by=columns1, ignore_index=True) + metadata2 = metadata2[columns1].sort_values(by=columns1, ignore_index=True) + assert metadata1.equals(metadata2) if verify_configs_hyperparameters: assert repo1.configs_hyperparameters() == repo2.configs_hyperparameters()