[WIP] Add easy model fitting and comparison #77
Open
prateekdesai04 wants to merge 90 commits into autogluon:main from prateekdesai04:wip
Changes from all commits (90 commits)
64ed36d  adding test scripts
851bfb0  matching tabrepo and fit df, using zeroshot_context
b0b2552  plotting functionality
ef2aa9a  Update (Innixma)
275df36  WIP exec.py
2ec9b6e  Add updates (Innixma)
66ea368  Add v2 scripts (Innixma)
2c5b3c3  Remove y_uncleaned (Innixma)
3b3f791  resolve merge conflicts (Innixma)
a1df0a4  resolve merge conflicts (Innixma)
022fc3f  resolve merge conflicts (Innixma)
6ab5304  adding test scripts
fd1d0a9  plotting functionality
a411f5e  Initial Class implementation
7227ab2  typo
08b266c  minor updates (Innixma)
095ceed  add run_scripts_v4 (Innixma)
41b098e  making run_experiment a staticmethod (prateekdesai04)
1ef8070  Updated run_experiments (prateekdesai04)
8b25bac  Cleanup, add TabPFNv2 prototype (Innixma)
a69596d  Cleanup (Innixma)
f5fe3c7  Cleanup (Innixma)
b95a76e  Cleanup (Innixma)
f8b8da4  Cleanup (Innixma)
74df85f  Cleanup (Innixma)
8f62e02  bug fix (Innixma)
6c3833f  Add run_tabpfn_v2_benchmark.py + additional bugfixes (Innixma)
175a38c  Add TabForestPFN_class.py (Innixma)
3401d4e  Add TabForestPFN_class.py (Innixma)
e75dbd1  Delete old files (Innixma)
5e6afab  Update file locations (Innixma)
c4eb4e1  Add AutoGluon_class.py, tabforestpfn_model.py (Innixma)
4623a37  add hyperparameter/init_args support (prateekdesai04)
90a7bad  Add run_tabforestpfn_benchmark.py (Innixma)
9f9a269  removing unused files (prateekdesai04)
5d044f0  Update add simulation_artifacts support (Innixma)
3070f31  Add simulation ensemble comparison support via `evaluate_ensemble_wit… (Innixma)
e801a17  update (Innixma)
d3b4cfe  update (Innixma)
7b2195d  minor cleanup (prateekdesai04)
53bf01b  minor cleanup (prateekdesai04)
36e0715  Update evaluate_ensemble_with_time (Innixma)
1048d34  Fix bug in zeroshot_configs (Innixma)
5e30d89  Refactor baselines.py (Innixma)
e405606  Add repo.evaluate_ensemble_with_time_multi (Innixma)
0ceb91b  Update repo.evaluate_ensemble to return DataFrame (Innixma)
0097ef2  Add logger module, and adding wrapper logs to run scripts, will add d… (prateekdesai04)
17344dd  minor update (Innixma)
8570fca  Refactor evaluate_ensemble (Innixma)
0339c91  Refactor evaluate_ensemble (Innixma)
1d2ec71  Refactor evaluate_ensemble (Innixma)
426d041  Cleanup (Innixma)
a12dd5c  Cleanup (Innixma)
1546337  Cleanup (Innixma)
5508542  Add logic to context.py (Innixma)
e429dd6  minor update (Innixma)
3ed7678  Add save/load logic to ZeroshotSimulatorContext (Innixma)
84d4d81  Add save/load logic to EvaluationRepository (Innixma)
697b477  Align column names in model fits (Innixma)
8b8e06c  Add unit tests for repo save/load (Innixma)
92c0cf7  Add extra unit tests for repo save/load (Innixma)
9b65c66  Fix Self import (Innixma)
bc7e3d7  Fix imports (Innixma)
4099d06  fix tests (Innixma)
248e9cf  simplify run_quickstart_from_scratch.py (Innixma)
d4e8b59  minor update (Innixma)
1f827e1  update `repo.from_raw` (Innixma)
56016bf  Add root, app and console loggers (prateekdesai04)
9003e1a  addition to logging module (prateekdesai04)
f1abdb1  add context save/load with json + relative path support (Innixma)
0102ac4  add ebm and tabpfnv2 models (Innixma)
16c0329  add ebm and tabpfnv2 models (Innixma)
1474e65  update (Innixma)
315aa99  update (Innixma)
b5d2838  update (Innixma)
2b962fb  update (Innixma)
b9380dc  update (Innixma)
3b0a932  Support loading repo artifact from cloned directory (Innixma)
c5c69a6  minor fix (Innixma)
2f8df7e  cleanup (Innixma)
afdb8b9  update (Innixma)
6d22d4d  Update (Innixma)
e7390a9  cleanup (Innixma)
d0484d8  Add simple benchmark runner (Innixma)
798219a  cleanup (Innixma)
7440051  Update for ag12 (Innixma)
1393238  Update for ag12 (Innixma)
8c78273  Update for ag12 (Innixma)
6223f8c  TabPFN support stopped at best epoch (Innixma)
f65c4ef  update (Innixma)
Files changed

@@ -0,0 +1,24 @@
from __future__ import annotations

from typing import List, Union

from autogluon_benchmark import OpenMLTaskWrapper
from tabrepo import EvaluationRepository


class ContextDataLoader(OpenMLTaskWrapper):
    """
    Fetches the train/test split of a dataset that belongs to a loaded context.
    """
    def get_context_train_test_split(self, repo: EvaluationRepository, task_id: Union[int, List[int]], repeat: int = 0,
                                     fold: int = 0, sample: int = 0):
        # NOTE: only a single task_id is currently handled; list support is not implemented yet.
        if repo.tid_to_dataset(task_id) in repo.datasets():
            train_indices, test_indices = self.task.get_train_test_split_indices(repeat=repeat, fold=fold,
                                                                                 sample=sample)
            X_train = self.X.loc[train_indices]
            y_train = self.y[train_indices]
            X_test = self.X.loc[test_indices]
            y_test = self.y[test_indices]
            return X_train, y_train, X_test, y_test
        else:
            raise KeyError(f"Dataset for task_id {task_id} not found.")


# TODO: Add another function to just get the X and y for a random state
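A minimal usage sketch for this loader. This is hedged and illustrative only: the from_task_id constructor and the task id 3904 are assumptions for the example, not part of this diff.

from tabrepo import load_repository

repo = load_repository("D244_F3_C1530_30", cache=True)
# Assumption: OpenMLTaskWrapper exposes a from_task_id helper; 3904 is an example task id
loader = ContextDataLoader.from_task_id(3904)
X_train, y_train, X_test, y_test = loader.get_context_train_test_split(repo=repo, task_id=3904, fold=0)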
@@ -0,0 +1,159 @@
import pandas as pd

from autogluon.core.data import LabelCleaner
from autogluon.core.metrics import get_metric, Scorer
from autogluon.core.utils import generate_train_test_split
from autogluon.features import AutoMLPipelineFeatureGenerator
from autogluon_benchmark.frameworks.autogluon.run import ag_eval_metric_map
from autogluon_benchmark.utils.time_utils import Timer
from tabpfn import TabPFNClassifier  # used by the commented-out local TabPFN v1 path in fit_custom


def fit_outer(task, fold: int, task_name: str, method: str, init_args: dict = None, **kwargs):
    """Fit a single (task, fold) with the given method and return a one-row results DataFrame."""
    if init_args is None:
        init_args = {}
    if 'eval_metric' not in init_args:
        init_args['eval_metric'] = ag_eval_metric_map[task.problem_type]

    X_train, y_train, X_test, y_test = task.get_train_test_split(fold=fold)

    out = fit_custom_clean(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test,
                           problem_type=task.problem_type, eval_metric=init_args['eval_metric'], label=task.label)

    out["framework"] = method
    out["dataset"] = task_name
    out["tid"] = task.task_id
    out["fold"] = fold
    out["problem_type"] = task.problem_type
    print(f"Task Name: {out['dataset']}")
    print(f"Task ID: {out['tid']}")
    print(f"Metric: {out['eval_metric']}")
    print(f"Test Error: {out['test_error']:.4f}")
    print(f"Fit Time: {out['time_fit']:.3f}s")
    print(f"Infer Time: {out['time_predict']:.3f}s")

    # Drop per-row artifacts that don't belong in the aggregate results table
    out.pop("predictions")
    out.pop("probabilities")
    out.pop("truth")

    df_results = pd.DataFrame([out])
    ordered_columns = ["dataset", "fold", "framework", "test_error", "eval_metric", "time_fit"]
    columns_reorder = ordered_columns + [c for c in df_results.columns if c not in ordered_columns]
    df_results = df_results[columns_reorder]
    return df_results
# TODO: Nick: This works for 99.99% of cases, but to handle all possible edge cases,
#  we probably want to use Tabular's LabelCleaner during metric calculation to avoid any oddities.
#  This can be done as a follow-up.
#  We also need to track positive_class for binary classification.
def calc_error(
    y_true: pd.Series,
    y_pred: pd.Series,
    y_pred_proba: pd.DataFrame,
    problem_type: str,
    scorer: Scorer,
) -> float:
    if scorer.needs_pred:  # use y_pred
        error = scorer.error(y_true=y_true, y_pred=y_pred)
    elif problem_type == "binary":  # use the positive-class column of y_pred_proba
        error = scorer.error(y_true=y_true, y_pred=y_pred_proba.iloc[:, 1])
    else:  # multiclass: use the full y_pred_proba
        error = scorer.error(y_true=y_true, y_pred=y_pred_proba)
    return error
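For intuition, a small sketch of the binary branch above with toy values (not from the PR). roc_auc is probability-based, so scorer.needs_pred is False and the positive-class column is used:

import pandas as pd
from autogluon.core.metrics import get_metric

y_true = pd.Series([0, 1, 1, 0])
y_pred = pd.Series([0, 1, 0, 0])
y_pred_proba = pd.DataFrame({0: [0.9, 0.2, 0.6, 0.7], 1: [0.1, 0.8, 0.4, 0.3]})

scorer = get_metric(metric="roc_auc", problem_type="binary")
# needs_pred is False for roc_auc, so calc_error scores y_pred_proba.iloc[:, 1]
error = calc_error(y_true=y_true, y_pred=y_pred, y_pred_proba=y_pred_proba,
                   problem_type="binary", scorer=scorer)
print(error)  # error = 1 - AUC; 0.0 here because the probability ranking is perfect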
def fit_custom_clean(X_train, y_train, X_test, y_test, problem_type: str = None, eval_metric: str = None, **kwargs):
    # Clean labels so the downstream model sees contiguous integer classes
    label_cleaner = LabelCleaner.construct(problem_type=problem_type, y=y_train)
    y_train_clean = label_cleaner.transform(y_train)
    y_test_clean = label_cleaner.transform(y_test)

    # TODO: Nick: For now, I'm preprocessing via AutoGluon's feature generator because otherwise TabPFN crashes on some datasets.
    feature_generator = AutoMLPipelineFeatureGenerator()
    X_train_clean = feature_generator.fit_transform(X=X_train, y=y_train)
    X_test_clean = feature_generator.transform(X=X_test)

    out = fit_custom(
        X_train=X_train_clean,
        y_train=y_train_clean,
        X_test=X_test_clean,
        y_test=y_test_clean,
        problem_type=problem_type,
        eval_metric=eval_metric,
        **kwargs,
    )

    y_pred_test_clean = out["predictions"]
    y_pred_proba_test_clean = out["probabilities"]

    scorer: Scorer = get_metric(metric=eval_metric, problem_type=problem_type)

    test_error = calc_error(
        y_true=y_test_clean,
        y_pred=y_pred_test_clean,
        y_pred_proba=y_pred_proba_test_clean,
        problem_type=problem_type,
        scorer=scorer,
    )

    # Map predictions back to the original label space before returning
    y_pred_test = label_cleaner.inverse_transform(y_pred_test_clean)
    out["predictions"] = y_pred_test

    if y_pred_proba_test_clean is not None:
        y_pred_proba_test = label_cleaner.inverse_transform_proba(y_pred_proba_test_clean, as_pandas=True)
        out["probabilities"] = y_pred_proba_test

    out["test_error"] = test_error
    out["eval_metric"] = eval_metric
    out["truth"] = y_test

    return out
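To make the label-cleaning round-trip above concrete, a minimal standalone sketch with toy labels (not from the source):

import pandas as pd
from autogluon.core.data import LabelCleaner

y_raw = pd.Series(["no", "yes", "yes", "no"])
cleaner = LabelCleaner.construct(problem_type="binary", y=y_raw)
y_clean = cleaner.transform(y_raw)           # contiguous integer labels, e.g. 0/1
y_back = cleaner.inverse_transform(y_clean)  # back to the original "no"/"yes" labels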
def fit_custom(
    X_train: pd.DataFrame,
    y_train: pd.Series,
    X_test: pd.DataFrame,
    y_test: pd.Series,
    eval_metric: str,
    problem_type: str = None,
    label: str = None,
) -> dict:
    # FIXME: Nick: This is a hack specific to TabPFN, since it doesn't handle large data; parameterize later
    sample_limit = 4096
    if len(X_train) > sample_limit:
        X_train, _, y_train, _ = generate_train_test_split(
            X=X_train,
            y=y_train,
            problem_type=problem_type,
            train_size=sample_limit,
            random_state=0,
            min_cls_count_train=1,
        )

    # Local TabPFN v1 path, kept for reference:
    # with Timer() as timer_fit:
    #     model = TabPFNClassifier(device='cpu', N_ensemble_configurations=32).fit(X_train, y_train, overwrite_warning=True)

    # NOTE: TabPFNRegressor is imported but not wired up yet; only classification is
    # handled by the hosted client below.
    from tabpfn_client.estimator import TabPFNClassifier as TabPFNClassifierV2, TabPFNRegressor
    model = TabPFNClassifierV2(model="latest_tabpfn_hosted", n_estimators=32)
    with Timer() as timer_fit:
        model = model.fit(X_train, y_train)

    is_classification = problem_type in ['binary', 'multiclass']
    if is_classification:
        with Timer() as timer_predict:
            y_pred_proba = model.predict_proba(X_test)
        y_pred_proba = pd.DataFrame(y_pred_proba, columns=model.classes_, index=X_test.index)
        y_pred = y_pred_proba.idxmax(axis=1)
    else:
        with Timer() as timer_predict:
            y_pred = model.predict(X_test)
        y_pred = pd.Series(y_pred, name=label, index=X_test.index)
        y_pred_proba = None

    return {
        'predictions': y_pred,
        'probabilities': y_pred_proba,
        'time_fit': timer_fit.duration,
        'time_predict': timer_predict.duration,
    }
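End to end, the three functions compose as below. This is a hedged sketch: it assumes OpenMLTaskWrapper.from_task_id exists in autogluon_benchmark, and uses OpenML task 146818 (Australian) and the method name "TabPFNv2_hosted" purely as illustrations.

from autogluon_benchmark import OpenMLTaskWrapper

# Assumption: from_task_id is the wrapper's construct-from-id helper
task = OpenMLTaskWrapper.from_task_id(146818)
df_results = fit_outer(task=task, fold=0, task_name="Australian", method="TabPFNv2_hosted")
print(df_results[["dataset", "fold", "framework", "test_error", "eval_metric", "time_fit"]])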
@@ -0,0 +1,105 @@
from __future__ import annotations

import pandas as pd

from autogluon_benchmark.tasks.experiment_utils import run_experiments
from tabrepo import load_repository, EvaluationRepository


def convert_leaderboard_to_configs(leaderboard: pd.DataFrame, minimal: bool = True) -> pd.DataFrame:
    # Rename leaderboard columns to the TabRepo config schema
    df_configs = leaderboard.rename(columns=dict(
        time_fit="time_train_s",
        time_predict="time_infer_s",
        test_error="metric_error",
        eval_metric="metric",
        val_error="metric_error_val",
    ))
    if minimal:
        df_configs = df_configs[[
            "dataset",
            "fold",
            "framework",
            "metric_error",
            "metric",
            "problem_type",
            "time_train_s",
            "time_infer_s",
            "tid",
        ]]
    return df_configs
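A quick illustration of what the rename does, using a toy row rather than real benchmark output:

import pandas as pd

leaderboard = pd.DataFrame([{
    "dataset": "Australian", "fold": 0, "framework": "LightGBM_c1_BAG_L1_V2",
    "test_error": 0.13, "eval_metric": "roc_auc", "problem_type": "binary",
    "time_fit": 2.1, "time_predict": 0.05, "tid": 146818,
}])
df_configs = convert_leaderboard_to_configs(leaderboard)
# df_configs columns (minimal=True): dataset, fold, framework, metric_error,
# metric, problem_type, time_train_s, time_infer_s, tid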
if __name__ == '__main__':
    # Load Context
    context_name = "D244_F3_C1530_30"
    repo: EvaluationRepository = load_repository(context_name, cache=True)

    expname = "./initial_experiment_ag_models"  # folder location of all experiment artifacts
    ignore_cache = False  # set to True to overwrite existing caches and re-run experiments from scratch

    datasets = [
        "blood-transfusion-service-center",
        "Australian",
    ]
    tids = [repo.dataset_to_tid(dataset) for dataset in datasets]
    folds = repo.folds

    # all_configs = repo.configs()
    # import random
    # reproduce_configs = random.sample(all_configs, k=10)

    reproduce_configs = [
        "RandomForest_c1_BAG_L1",
        "ExtraTrees_c1_BAG_L1",
        "LightGBM_c1_BAG_L1",
        "XGBoost_c1_BAG_L1",
        "CatBoost_c1_BAG_L1",
        "TabPFN_c1_BAG_L1",
        "NeuralNetTorch_c1_BAG_L1",
        "NeuralNetFastAI_c1_BAG_L1",
    ]

    methods_dict = {}
    for c in reproduce_configs:
        ag_hyperparameters = repo.autogluon_hyperparameters_dict(configs=[c])
        methods_dict[c + "_V2"] = {"hyperparameters": ag_hyperparameters}

    extra_kwargs = {
        "fit_weighted_ensemble": False,
        "num_bag_folds": 8,
        "num_bag_sets": 1,
    }

    for k, v in methods_dict.items():
        v.update(extra_kwargs)

    methods = list(methods_dict.keys())

    results_lst = run_experiments(
        expname=expname,
        tids=tids,
        folds=folds,
        methods=methods,
        methods_dict=methods_dict,
        task_metadata=repo.task_metadata,
        ignore_cache=ignore_cache,
    )
    results_df = pd.concat(results_lst, ignore_index=True)
    results_df = convert_leaderboard_to_configs(results_df)

    with pd.option_context("display.max_rows", None, "display.max_columns", None, "display.width", 1000):
        print(results_df)

    metrics = repo.compare_metrics(
        results_df,
        datasets=datasets,
        folds=folds,
        baselines=["AutoGluon_bq_4h8c_2023_11_14"],
        configs=reproduce_configs,
    )
    with pd.option_context("display.max_rows", None, "display.max_columns", None, "display.width", 1000):
        print(f"Config Metrics Example:\n{metrics}")
    evaluator_output = repo.plot_overall_rank_comparison(
        results_df=metrics,
        save_dir=expname,
    )
Review comment:
Nice to get those!