From d3215f2d687282f480f5855544c9af34bfbbd004 Mon Sep 17 00:00:00 2001 From: jpgard Date: Thu, 10 Aug 2023 21:41:30 -0400 Subject: [PATCH 1/9] initial commit --- .dockerignore | 1 + 1 file changed, 1 insertion(+) create mode 100644 .dockerignore diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000..a9a5aecf42 --- /dev/null +++ b/.dockerignore @@ -0,0 +1 @@ +tmp From 2e28ae94e52c1dcb4d0a6f8576c1045573f692e7 Mon Sep 17 00:00:00 2001 From: jpgard Date: Thu, 10 Aug 2023 21:42:10 -0400 Subject: [PATCH 2/9] initial commit --- docker/tableshift.dockerfile | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docker/tableshift.dockerfile diff --git a/docker/tableshift.dockerfile b/docker/tableshift.dockerfile new file mode 100644 index 0000000000..81647f3143 --- /dev/null +++ b/docker/tableshift.dockerfile @@ -0,0 +1,5 @@ +FROM python:3.8-bullseye + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install git +COPY requirements.txt requirements.txt +RUN python -m pip install -r requirements.txt \ No newline at end of file From 7c7dc368ee67c524445ad57a5c6ad8ca9e5d6750 Mon Sep 17 00:00:00 2001 From: Josh Gardner Date: Sun, 20 Aug 2023 15:42:34 -0400 Subject: [PATCH 3/9] install lightgbm via conda --- environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/environment.yml b/environment.yml index 76f8918c35..a9df03ecab 100644 --- a/environment.yml +++ b/environment.yml @@ -4,6 +4,7 @@ channels: dependencies: - python=3.8 - pip + - lightgbm=3.3 - pip: - -r requirements.txt - -e . From e18fc2c5ac01c9a7c5a2b46fa782763ad321a887 Mon Sep 17 00:00:00 2001 From: Josh Gardner Date: Sun, 20 Aug 2023 18:06:30 -0400 Subject: [PATCH 4/9] complete dockerfile; add local files, installation of tableshift module, and set apropriate env variables --- docker/tableshift.dockerfile | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/docker/tableshift.dockerfile b/docker/tableshift.dockerfile index 81647f3143..338e532e83 100644 --- a/docker/tableshift.dockerfile +++ b/docker/tableshift.dockerfile @@ -2,4 +2,14 @@ FROM python:3.8-bullseye RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install git COPY requirements.txt requirements.txt -RUN python -m pip install -r requirements.txt \ No newline at end of file +RUN python -m pip install --upgrade pip +RUN python -m pip install -r requirements.txt + +RUN mkdir /tableshift +COPY . /tableshift +WORKDIR /tableshift +RUN python -m pip install --no-deps . + +# Add tableshift to pythonpath; necessary to ensure +# tableshift module imports work inside docker. +ENV PYTHONPATH "${PYTHONPATH}:/tableshift" From dacb9bfddc4a30b2393b8301c9728358400c5a42 Mon Sep 17 00:00:00 2001 From: Josh Gardner Date: Sun, 20 Aug 2023 18:06:51 -0400 Subject: [PATCH 5/9] modifications to better support both pytorch and sklearn models --- examples/run_expt.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/examples/run_expt.py b/examples/run_expt.py index 81a11708d5..1ef94c6dce 100644 --- a/examples/run_expt.py +++ b/examples/run_expt.py @@ -1,11 +1,14 @@ import argparse import logging +import torch from sklearn.metrics import accuracy_score from tableshift.core import get_dataset from tableshift.models.training import train from tableshift.models.utils import get_estimator +from tableshift.models.default_hparams import get_default_config + LOG_LEVEL = logging.DEBUG @@ -23,15 +26,23 @@ def main(experiment, cache_dir, model, debug: bool): dset = get_dataset(experiment, cache_dir) X, y, _, _ = dset.get_pandas("train") - estimator = get_estimator(model) - estimator = train(estimator, dset) - if dset.is_domain_split: - X_te, y_te, _, _ = dset.get_pandas("ood_test") + config = get_default_config(model, dset) + estimator = get_estimator(model, **config) + estimator = train(estimator, dset, config=config) + + if not isinstance(estimator, torch.nn.Module): + # Case: non-pytorch estimator; perform test-split evaluation. + test_split = "ood_test" if dset.is_domain_split else "test" + # Fetch predictions and labels for a sklearn model. + X_te, y_te, _, _ = dset.get_pandas(test_split) + yhat_te = estimator.predict(X_te) + + acc = accuracy_score(y_true=y_te, y_pred=yhat_te) + print(f"training completed! {test_split} accuracy: {acc:.4f}") + else: - X_te, y_te, _, _ = dset.get_pandas("test") - yhat_te = estimator.predict(X_te) - acc = accuracy_score(y_true=y_te, y_pred=yhat_te) - print(f"training completed! test accuracy: {acc:.4f}") + # Case: pytorch estimator; eval is already performed + printed by train(). + print("training completed!") return From fdd0231c689f7bd22ee4b52f12645b902639ddfe Mon Sep 17 00:00:00 2001 From: Josh Gardner Date: Sun, 20 Aug 2023 18:07:11 -0400 Subject: [PATCH 6/9] remove dependencies not needed for docker setup --- requirements.txt | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 0776eb1957..c8e359c724 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,7 +25,7 @@ botocore==1.29.106 build==0.10.0 cachetools==5.2.0 catalogue==2.0.8 -catboost==1.1.1 +catboost==1.2 category-encoders==2.6.0 certifi==2022.09.24 cffi==1.15.1 @@ -118,7 +118,6 @@ kaggle==1.5.13 keyring==23.13.1 kiwisolver==1.4.4 langcodes==3.3.0 -lightgbm==3.3.3 lightgbm-ray==0.1.8 lightning-utilities==0.8.0 llvmlite==0.39.1 @@ -170,7 +169,6 @@ pkgutil_resolve_name==1.3.10 platformdirs==2.5.4 plotly==5.14.0 pluggy==1.0.0 -pmdarima==1.8.5 preshed==3.0.8 prometheus-client==0.13.1 promise==2.3 From 6fc691cb491346962353f2111256486f3d88878b Mon Sep 17 00:00:00 2001 From: Josh Gardner Date: Sun, 20 Aug 2023 18:07:32 -0400 Subject: [PATCH 7/9] remove unused import --- tableshift/configs/benchmark_configs.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tableshift/configs/benchmark_configs.py b/tableshift/configs/benchmark_configs.py index 171149f6f0..4f0ef7cfd3 100644 --- a/tableshift/configs/benchmark_configs.py +++ b/tableshift/configs/benchmark_configs.py @@ -7,8 +7,7 @@ from tableshift.configs.experiment_config import ExperimentConfig from tableshift.configs.experiment_defaults import DEFAULT_ID_TEST_SIZE, \ DEFAULT_OOD_VAL_SIZE, DEFAULT_ID_VAL_SIZE, DEFAULT_RANDOM_STATE -from tableshift.core import RandomSplitter, Grouper, PreprocessorConfig, \ - DomainSplitter +from tableshift.core import Grouper, PreprocessorConfig, DomainSplitter from tableshift.datasets import BRFSS_YEARS, ACS_YEARS, NHANES_YEARS from tableshift.datasets.mimic_extract import MIMIC_EXTRACT_STATIC_FEATURES from tableshift.datasets.mimic_extract_feature_lists import \ From de58de00c630d6e6196e0ff48719d6e7f503c20d Mon Sep 17 00:00:00 2001 From: Josh Gardner Date: Sun, 20 Aug 2023 18:08:04 -0400 Subject: [PATCH 8/9] implement TabularDataset.cat_idxs() --- tableshift/core/tabular_dataset.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tableshift/core/tabular_dataset.py b/tableshift/core/tabular_dataset.py index 6afda6d655..74f2dee34b 100644 --- a/tableshift/core/tabular_dataset.py +++ b/tableshift/core/tabular_dataset.py @@ -16,7 +16,7 @@ from torch.utils.data import DataLoader from tableshift.third_party.domainbed import InfiniteDataLoader -from .features import Preprocessor, PreprocessorConfig +from .features import Preprocessor, PreprocessorConfig, is_categorical from .grouper import Grouper from .metrics import metrics_by_group from .splitter import Splitter, DomainSplitter @@ -79,7 +79,7 @@ def uid(self) -> str: @property def is_domain_split(self) -> bool: """Return True if this dataset uses a DomainSplitter, else False.""" - return self.domain_label_colname is not None + return isinstance(self.splitter, DomainSplitter) @property def eval_split_names(self) -> Tuple: @@ -93,10 +93,7 @@ def eval_split_names(self) -> Tuple: @property def domain_split_varname(self): - if not self.is_domain_split: - return None - - elif isinstance(self.splitter, DomainSplitter): + if isinstance(self.splitter, DomainSplitter): return self.splitter.domain_split_varname else: return self.domain_label_colname @@ -255,8 +252,7 @@ def n_domains(self) -> int: @property def cat_idxs(self) -> List[int]: - # TODO: implement this. - raise + return [i for i, col in enumerate(self._df.columns) if is_categorical(self._df[col])] def get_domains(self, split) -> Union[List[str], None]: """Fetch a list of the domains.""" From 35ab0dbc662c98ec4c5389c0306e90be84373ae9 Mon Sep 17 00:00:00 2001 From: Josh Gardner Date: Sun, 20 Aug 2023 18:08:41 -0400 Subject: [PATCH 9/9] updates to fix various small model-specific bugs so that all models can be trained with run_expt.py --- tableshift/models/compat.py | 5 ++++- tableshift/models/default_hparams.py | 17 +++++++++-------- tableshift/models/torchutils.py | 6 +++++- tableshift/models/training.py | 5 ++++- tableshift/models/utils.py | 16 +++++++++++++++- 5 files changed, 37 insertions(+), 12 deletions(-) diff --git a/tableshift/models/compat.py b/tableshift/models/compat.py index 5e8712ebc7..7f4f2e3d28 100644 --- a/tableshift/models/compat.py +++ b/tableshift/models/compat.py @@ -140,7 +140,10 @@ def is_domain_adaptation_model_name(model_name: str) -> bool: def is_pytorch_model_name(model: str) -> bool: """Helper function to determine whether a model name is a pytorch model. - ISee description of is_pytorch_model() above.""" + See description of is_pytorch_model() above.""" + if model=="catboost": + logging.warning("Catboost models are not suported in Ray hyperparameter training." + " Instead, use the provided catboost-specific script.") is_sklearn = model in SKLEARN_MODEL_NAMES is_pt = model in PYTORCH_MODEL_NAMES assert is_sklearn or is_pt, f"unknown model name {model}" diff --git a/tableshift/models/default_hparams.py b/tableshift/models/default_hparams.py index d31782ccf1..27e8d79a47 100644 --- a/tableshift/models/default_hparams.py +++ b/tableshift/models/default_hparams.py @@ -140,14 +140,15 @@ def get_default_config(model: str, dset: TabularDataset) -> dict: - """Get a default config for a model by name.""" + """Get a default config for a model, by name.""" config = _DEFAULT_CONFIGS.get(model, {}) + model_is_pt = is_pytorch_model_name(model) d_in = dset.X_shape[1] - if is_pytorch_model_name(model) and model != "ft_transformer": + if model_is_pt and model != "ft_transformer": config.update({"d_in": d_in, "activation": "ReLU"}) - elif is_pytorch_model_name(model): + elif model_is_pt: config.update({"n_num_features": d_in}) if model in ("tabtransformer", "saint"): @@ -155,7 +156,7 @@ def get_default_config(model: str, dset: TabularDataset) -> dict: config["cat_idxs"] = cat_idxs config["categories"] = [2] * len(cat_idxs) - # Models that use non-cross-entropy training objectives. + # Set the training objective and any associated hypperparameters. if model == "dro": config["criterion"] = DROLoss(size=config["size"], reg=config["reg"], @@ -170,10 +171,10 @@ def get_default_config(model: str, dset: TabularDataset) -> dict: config["criterion"] = GroupDROLoss(n_groups=2) - else: + elif model_is_pt: config["criterion"] = F.binary_cross_entropy_with_logits - if is_pytorch_model_name(model) and model != "dann": + if model_is_pt and model != "dann": # Note: for DANN model, lr and weight decay are set separately for D # and G. config.update({"lr": 0.01, @@ -182,9 +183,9 @@ def get_default_config(model: str, dset: TabularDataset) -> dict: # Do not overwrite batch size or epochs if they are set in the default # config for the model. - if "batch_size" not in config: + if "batch_size" not in config and model_is_pt: config["batch_size"] = DEFAULT_BATCH_SIZE - if "n_epochs" not in config: + if "n_epochs" not in config and model_is_pt: config["n_epochs"] = 1 if model == "saint" and d_in > 100: diff --git a/tableshift/models/torchutils.py b/tableshift/models/torchutils.py index c960d5d152..a2d30e7c88 100644 --- a/tableshift/models/torchutils.py +++ b/tableshift/models/torchutils.py @@ -69,12 +69,16 @@ def apply_model(model: torch.nn.Module, x): @torch.no_grad() -def get_predictions_and_labels(model, loader, device, as_logits=False) -> Tuple[ +def get_predictions_and_labels(model, loader, device=None, as_logits=False) -> Tuple[ np.ndarray, np.ndarray]: """Get the predictions (as logits, or probabilities) and labels.""" prediction = [] label = [] + if not device: + device = f"cuda:{torch.cuda.current_device()}" \ + if torch.cuda.is_available() else "cpu" + modelname = model.__class__.__name__ for batch in tqdm(loader, desc=f"{modelname}:getpreds"): batch_x, batch_y, _, _ = unpack_batch(batch) diff --git a/tableshift/models/training.py b/tableshift/models/training.py index 7f20f5f9aa..f129481e0b 100644 --- a/tableshift/models/training.py +++ b/tableshift/models/training.py @@ -120,10 +120,13 @@ def get_eval_loaders( def _train_pytorch(estimator: SklearnStylePytorchModel, dset: TabularDataset, - device: str, config=PYTORCH_DEFAULTS, + device: str=None, tune_report_split: str = None): """Helper function to train a pytorch estimator.""" + if not device: + device = f"cuda:{torch.cuda.current_device()}" \ + if torch.cuda.is_available() else "cpu" logging.debug(f"config is {config}") logging.debug(f"estimator is of type {type(estimator)}") logging.debug(f"dset name is {dset.name}") diff --git a/tableshift/models/utils.py b/tableshift/models/utils.py index 95dcb87357..cda392692b 100644 --- a/tableshift/models/utils.py +++ b/tableshift/models/utils.py @@ -22,7 +22,21 @@ from tableshift.models.wcs import WeightedCovariateShiftClassifier -def get_estimator(model, d_out=1, **kwargs): +def get_estimator(model:str, d_out=1, **kwargs): + """ + Fetch an estimator for training. + + Args: + model: the string name of the model to use. + d_out: output dimension of the model (set to 1 for binary classification). + kwargs: named arguments to pass to the model's class constructor. These + vary by model; for more details see below. Note that only a specific + subset of the kwargs will be used; passing arbitrary kwargs not accepted by + the model's class constructor will result in those kwargs being ignored. + Returns: + An instance of the class specified by the `model` string, with + any hyperparameters set according to kwargs. + """ if model == "aldro": assert d_out == 1, "assume binary classification." return AdversarialLabelDROModel(