Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove pickle from data formats, memmap for Tabular predictions #45

Merged
merged 7 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ dmypy.json
# Pyre type checker
.pyre/

data/results/
data/results/*
.DS_Store
autogluon/
.pkl
Expand Down
229 changes: 229 additions & 0 deletions data/metadata/task_metric_names.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
dataset,eval_metric,problem_type
geoalgo marked this conversation as resolved.
Show resolved Hide resolved
kropt,log_loss,multiclass
Titanic,roc_auc,binary
microaggregation2,log_loss,multiclass
guillermo,roc_auc,binary
kc1,roc_auc,binary
2dplanes,roc_auc,binary
OVA_Endometrium,roc_auc,binary
hill-valley,roc_auc,binary
pm10,roc_auc,binary
MagicTelescope,roc_auc,binary
one-hundred-plants-margin,log_loss,multiclass
Bioresponse,roc_auc,binary
okcupid-stem,log_loss,multiclass
cpu_act,roc_auc,binary
KDDCup09-Upselling,roc_auc,binary
segment,log_loss,multiclass
Fashion-MNIST,log_loss,multiclass
meta,roc_auc,binary
nomao,roc_auc,binary
eucalyptus,log_loss,multiclass
texture,log_loss,multiclass
cpu_small,roc_auc,binary
Brazilian_houses,root_mean_squared_error,regression
bank-marketing,roc_auc,binary
fried,roc_auc,binary
diabetes,roc_auc,binary
cylinder-bands,roc_auc,binary
wind,roc_auc,binary
fri_c3_1000_25,roc_auc,binary
space_ga,root_mean_squared_error,regression
fri_c0_1000_5,roc_auc,binary
kick,roc_auc,binary
Indian_pines,log_loss,multiclass
ada,roc_auc,binary
house_sales,root_mean_squared_error,regression
eye_movements,log_loss,multiclass
porto-seguro,roc_auc,binary
arcene,roc_auc,binary
bank32nh,roc_auc,binary
robert,log_loss,multiclass
synthetic_control,log_loss,multiclass
OnlineNewsPopularity,root_mean_squared_error,regression
delta_elevators,roc_auc,binary
climate-model-simulation-crashes,roc_auc,binary
albert,roc_auc,binary
steel-plates-fault,log_loss,multiclass
pol,root_mean_squared_error,regression
sensory,root_mean_squared_error,regression
kr-vs-k,log_loss,multiclass
eeg-eye-state,roc_auc,binary
fri_c1_1000_50,roc_auc,binary
quake,root_mean_squared_error,regression
anneal,log_loss,multiclass
jasmine,roc_auc,binary
volkert,log_loss,multiclass
pc2,roc_auc,binary
volcanoes-b1,log_loss,multiclass
Allstate_Claims_Severity,root_mean_squared_error,regression
collins,log_loss,multiclass
qsar-biodeg,roc_auc,binary
airlines,roc_auc,binary
bank8FM,roc_auc,binary
spambase,roc_auc,binary
GAMETES_Heterogeneity_20atts_1600_Het_0_4_0_2_75_EDM-2_001,roc_auc,binary
Mercedes_Benz_Greener_Manufacturing,root_mean_squared_error,regression
diamonds,root_mean_squared_error,regression
Click_prediction_small,roc_auc,binary
volcanoes-b6,log_loss,multiclass
dilbert,log_loss,multiclass
puma8NH,roc_auc,binary
blood-transfusion-service-center,roc_auc,binary
fabert,log_loss,multiclass
OVA_Prostate,roc_auc,binary
ldpa,log_loss,multiclass
socmob,root_mean_squared_error,regression
autoUniv-au1-1000,roc_auc,binary
kin8nm,roc_auc,binary
phoneme,roc_auc,binary
ailerons,roc_auc,binary
riccardo,roc_auc,binary
wine-quality-red,log_loss,multiclass
letter,log_loss,multiclass
abalone,root_mean_squared_error,regression
madelon,roc_auc,binary
MiceProtein,log_loss,multiclass
no2,roc_auc,binary
pc4,roc_auc,binary
OVA_Kidney,roc_auc,binary
fri_c3_500_10,roc_auc,binary
volcanoes-b5,log_loss,multiclass
APSFailure,roc_auc,binary
Satellite,roc_auc,binary
mammography,roc_auc,binary
tecator,root_mean_squared_error,regression
wine_quality,root_mean_squared_error,regression
house_prices_nominal,root_mean_squared_error,regression
volcanoes-a2,log_loss,multiclass
led24,log_loss,multiclass
OVA_Ovary,roc_auc,binary
eating,log_loss,multiclass
wilt,roc_auc,binary
fri_c3_1000_10,roc_auc,binary
balance-scale,log_loss,multiclass
yprop_4_1,root_mean_squared_error,regression
boston,root_mean_squared_error,regression
nursery,log_loss,multiclass
hiva_agnostic,roc_auc,binary
churn,roc_auc,binary
analcatdata_dmft,log_loss,multiclass
semeion,log_loss,multiclass
house_16H,root_mean_squared_error,regression
UMIST_Faces_Cropped,log_loss,multiclass
pendigits,log_loss,multiclass
micro-mass,log_loss,multiclass
cardiotocography,log_loss,multiclass
Australian,roc_auc,binary
first-order-theorem-proving,log_loss,multiclass
artificial-characters,log_loss,multiclass
Traffic_violations,log_loss,multiclass
elevators,root_mean_squared_error,regression
connect-4,log_loss,multiclass
pollen,roc_auc,binary
kdd_internet_usage,roc_auc,binary
waveform-5000,log_loss,multiclass
car,log_loss,multiclass
SpeedDating,roc_auc,binary
CIFAR_10,log_loss,multiclass
fars,log_loss,multiclass
kdd_el_nino-small,roc_auc,binary
isolet,log_loss,multiclass
har,log_loss,multiclass
houses,roc_auc,binary
OVA_Colon,roc_auc,binary
kc2,roc_auc,binary
autoUniv-au7-1100,log_loss,multiclass
fri_c2_500_50,roc_auc,binary
vehicle,log_loss,multiclass
volcanoes-d4,log_loss,multiclass
sylvine,roc_auc,binary
splice,log_loss,multiclass
mnist_784,log_loss,multiclass
gina,roc_auc,binary
visualizing_soil,roc_auc,binary
volcanoes-a3,log_loss,multiclass
QSAR-TID-11,root_mean_squared_error,regression
volcanoes-d1,log_loss,multiclass
GAMETES_Epistasis_2-Way_1000atts_0_4H_EDM-1_EDM-1_1,roc_auc,binary
twonorm,roc_auc,binary
nyc-taxi-green-dec-2016,root_mean_squared_error,regression
mc1,roc_auc,binary
fri_c4_500_100,roc_auc,binary
pc3,roc_auc,binary
soybean,log_loss,multiclass
madeline,roc_auc,binary
Santander_transaction_value,root_mean_squared_error,regression
volcanoes-b2,log_loss,multiclass
fri_c3_500_50,roc_auc,binary
christine,roc_auc,binary
jannis,log_loss,multiclass
black_friday,root_mean_squared_error,regression
dna,log_loss,multiclass
tokyo1,roc_auc,binary
analcatdata_authorship,log_loss,multiclass
mfeat-factors,log_loss,multiclass
topo_2_1,root_mean_squared_error,regression
colleges,root_mean_squared_error,regression
fri_c2_1000_25,roc_auc,binary
autoUniv-au6-750,log_loss,multiclass
Amazon_employee_access,roc_auc,binary
Diabetes130US,log_loss,multiclass
boston_corrected,roc_auc,binary
baseball,log_loss,multiclass
Kuzushiji-MNIST,log_loss,multiclass
Buzzinsocialmedia_Twitter,root_mean_squared_error,regression
SAT11-HAND-runtime-regression,root_mean_squared_error,regression
cnae-9,log_loss,multiclass
Internet-Advertisements,roc_auc,binary
electricity,roc_auc,binary
arsenic-female-bladder,roc_auc,binary
KDDCup09_appetency,roc_auc,binary
MIP-2016-regression,root_mean_squared_error,regression
puma32H,roc_auc,binary
hypothyroid,log_loss,multiclass
wine-quality-white,log_loss,multiclass
tamilnadu-electricity,log_loss,multiclass
ilpd,roc_auc,binary
credit-g,roc_auc,binary
GAMETES_Epistasis_3-Way_20atts_0_2H_EDM-1_1,roc_auc,binary
Yolanda,root_mean_squared_error,regression
dresses-sales,roc_auc,binary
cmc,log_loss,multiclass
delta_ailerons,roc_auc,binary
PhishingWebsites,roc_auc,binary
Run_or_walk_information,roc_auc,binary
GesturePhaseSegmentationProcessed,log_loss,multiclass
parity5_plus_5,roc_auc,binary
wall-robot-navigation,log_loss,multiclass
rmftsa_ladata,roc_auc,binary
autoUniv-au7-700,log_loss,multiclass
OVA_Lung,roc_auc,binary
OVA_Breast,roc_auc,binary
page-blocks,log_loss,multiclass
Moneyball,root_mean_squared_error,regression
QSAR-TID-10980,root_mean_squared_error,regression
pbcseq,roc_auc,binary
yeast,log_loss,multiclass
colleges_usnews,roc_auc,binary
LED-display-domain-7digit,log_loss,multiclass
GAMETES_Heterogeneity_20atts_1600_Het_0_4_0_2_50_EDM-2_001,roc_auc,binary
GAMETES_Epistasis_2-Way_20atts_0_1H_EDM-1_1,roc_auc,binary
volcanoes-a4,log_loss,multiclass
ringnorm,roc_auc,binary
MiniBooNE,roc_auc,binary
optdigits,log_loss,multiclass
jungle_chess_2pcs_raw_endgame_complete,log_loss,multiclass
adult,roc_auc,binary
fri_c0_500_5,roc_auc,binary
mozilla4,roc_auc,binary
ozone-level-8hr,roc_auc,binary
jm1,roc_auc,binary
volcanoes-e1,log_loss,multiclass
satimage,log_loss,multiclass
us_crime,root_mean_squared_error,regression
GAMETES_Epistasis_2-Way_20atts_0_4H_EDM-1_1,roc_auc,binary
numerai28_6,roc_auc,binary
pc1,roc_auc,binary
shuttle,log_loss,multiclass
philippine,roc_auc,binary
7 changes: 4 additions & 3 deletions scripts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@

output_path = Path(__file__).parent

def load_context(version: str = "BAG_D244_F3_C1416", filter_very_large_dataset: bool = True) -> EvaluationRepository:
def load_context(version: str = "BAG_D244_F3_C1416", filter_very_large_dataset: bool = True, ignore_cache: bool = False) -> EvaluationRepository:
def _load_fun():
repo = load(version=version)
repo = repo.subset(models=[m for m in repo.list_models() if not "NeuralNetFastAI" in m])
return repo
repo = cache_function(_load_fun, cache_name=f"repo_{version}")
return repo.force_to_dense(verbose=True)
repo = cache_function(_load_fun, cache_name=f"repo_{version}", ignore_cache=ignore_cache)


if filter_very_large_dataset:
# For some reason, only 184 datasets are found from this list
Expand Down
2 changes: 1 addition & 1 deletion scripts/baseline_comparison/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def filter_configurations_above_budget(repo, test_tid, configs, max_runtime, qua
n_initial_configs = len(configs)
configs_fast_enough = set(df_configs_runtime[df_configs_runtime < max_runtime].index.tolist())
configs = [c for c in configs if c in configs_fast_enough]
print(f"kept only {len(configs)} from initial {n_initial_configs} for runtime {max_runtime}")
# print(f"kept only {len(configs)} from initial {n_initial_configs} for runtime {max_runtime}")
return configs


Expand Down
11 changes: 6 additions & 5 deletions scripts/baseline_comparison/evaluate_baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@

from autogluon.common.savers import save_pd
from tabrepo.repository.evaluation_repository import (
load,
EvaluationRepository,
)
from tabrepo.utils import catchtime
from tabrepo.utils.cache import cache_function, cache_function_dataframe
from scripts.baseline_comparison.baselines import (
automl_results,
Expand Down Expand Up @@ -250,7 +250,7 @@ def save_total_runtime_to_file(total_time_h):
# n_training_configs = list(range(10, 210, 10))
n_training_datasets = [1, 5, 10, 25, 50, 75, 100, 125, 150, 175, 199]
n_training_configs = [1, 5, 10, 25, 50, 75, 100, 125, 150, 175, 200]
n_seeds = 20
n_seeds = 1
n_training_folds = [1, 2, 5, 10]
n_ensembles = [10, 20, 40, 80]
linestyle_ensemble = "--"
Expand Down Expand Up @@ -333,9 +333,10 @@ def save_total_runtime_to_file(total_time_h):
))


df = pd.concat([
experiment.data(ignore_cache=ignore_cache) for experiment in experiments
])
with catchtime("total time to generate evaluations"):
df = pd.concat([
experiment.data(ignore_cache=ignore_cache) for experiment in experiments
])
# De-duplicate in case we ran a config multiple times
df = df.drop_duplicates(subset=["method", "tid", "fold"])
df = rename_dataframe(df)
Expand Down
58 changes: 58 additions & 0 deletions scripts/benchmark_tabpred.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from pathlib import Path

from scripts import load_context
from tabrepo.simulation.convert_memmap import convert_memmap_pred_from_pickle
from tabrepo.simulation.tabular_predictions import TabularPredictionsMemmap
from tabrepo.simulation.tabular_predictions_old import TabularPicklePerTaskPredictions
from tabrepo.utils import catchtime

filepath = Path(__file__)

if __name__ == '__main__':
    # Benchmark script: reads the validation/test predictions of a handful of
    # models for every (dataset, fold) pair, once through the memmap-backed
    # reader and once through the older pickle-per-task reader, and prints a
    # running sum of the prediction means for each so the two can be compared.
    #
    # Output recorded by the original author:
    #
    #   start: Compute sum with memmap
    #   Sum obtained with memmap: 1176878.8741704822
    #   Time for Compute sum with memmap: 0.0547 secs
    #
    #   start: Compute sum with pickle per task
    #   Sum obtained with pickle per task: 1176878.874170535
    #   Time for Compute sum with pickle per task: 7.5385 secs
    models = [f"CatBoost_r{i}_BAG_L1" for i in range(1, 10)]
    repeats = 1

    # To download the predictions locally first, uncomment:
    # load_context("BAG_D244_F3_C1416_micro", ignore_cache=False)

    # Both locations are resolved relative to the directory two levels above
    # this file.
    base_dir = filepath.parent.parent
    pickle_path = base_dir / "data/results/2023_08_21/zeroshot_metadata/"
    memmap_dir = base_dir / "data/results/2023_08_21/model_predictions/"

    # The memmap files are generated from the pickles only when the target
    # directory does not exist yet.
    if not memmap_dir.exists():
        print("converting to memmap")
        convert_memmap_pred_from_pickle(pickle_path, memmap_dir)

    # NOTE(review): the memmap reader is constructed inside the timed region,
    # while the pickle reader below is constructed before its timed region.
    with catchtime("Compute sum with memmap"):
        for _ in range(repeats):
            preds = TabularPredictionsMemmap(data_dir=memmap_dir)
            datasets = preds.datasets
            res = 0
            for name in datasets:
                for split_id in preds.folds:
                    val_pred = preds.predict_val(name, split_id, models)
                    test_pred = preds.predict_test(name, split_id, models)
                    res += val_pred.mean() + test_pred.mean()
        print(f"Sum obtained with memmap: {res}")

    # Load previous format to compare performance
    paths = list(pickle_path.rglob("*zeroshot_pred_proba.pkl"))
    preds = TabularPicklePerTaskPredictions.from_paths(paths, output_dir=pickle_path)

    # The dataset list is the one obtained from the memmap reader above, so
    # both passes iterate over the same datasets.
    with catchtime("Compute sum with pickle per task"):
        for _ in range(repeats):
            res = 0
            for name in datasets:
                for split_id in preds.folds:
                    val_pred, test_pred = preds.predict(
                        dataset=name,
                        fold=split_id,
                        models=models,
                        splits=["val", "test"],
                    )
                    res += val_pred.mean() + test_pred.mean()
        print(f"Sum obtained with pickle per task: {res}")
Loading
Loading