Docker #2

Merged
merged 9 commits on Aug 20, 2023

1 change: 1 addition & 0 deletions .dockerignore
@@ -0,0 +1 @@
tmp
15 changes: 15 additions & 0 deletions docker/tableshift.dockerfile
@@ -0,0 +1,15 @@
FROM python:3.8-bullseye

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y git
COPY requirements.txt requirements.txt
RUN python -m pip install --upgrade pip
RUN python -m pip install -r requirements.txt

RUN mkdir /tableshift
COPY . /tableshift
WORKDIR /tableshift
RUN python -m pip install --no-deps .

# Add tableshift to pythonpath; necessary to ensure
# tableshift module imports work inside docker.
ENV PYTHONPATH "${PYTHONPATH}:/tableshift"
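A minimal sketch (not part of this diff) of how the image might be built from the repository root and smoke-tested for the import behaviour the PYTHONPATH line above is meant to guarantee; the image tag "tableshift" is an assumption:

import subprocess

# Build the image from the repository root using the new dockerfile.
subprocess.run(
    ["docker", "build", "-f", "docker/tableshift.dockerfile", "-t", "tableshift", "."],
    check=True)
# Check that the tableshift module imports inside the container.
subprocess.run(
    ["docker", "run", "--rm", "tableshift", "python", "-c", "import tableshift"],
    check=True)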
1 change: 1 addition & 0 deletions environment.yml
@@ -4,6 +4,7 @@ channels:
dependencies:
- python=3.8
- pip
- lightgbm=3.3
- pip:
- -r requirements.txt
- -e .
27 changes: 19 additions & 8 deletions examples/run_expt.py
@@ -1,11 +1,14 @@
import argparse
import logging

import torch
from sklearn.metrics import accuracy_score

from tableshift.core import get_dataset
from tableshift.models.training import train
from tableshift.models.utils import get_estimator
from tableshift.models.default_hparams import get_default_config


LOG_LEVEL = logging.DEBUG

@@ -23,15 +26,23 @@ def main(experiment, cache_dir, model, debug: bool):

dset = get_dataset(experiment, cache_dir)
X, y, _, _ = dset.get_pandas("train")
estimator = get_estimator(model)
estimator = train(estimator, dset)
if dset.is_domain_split:
X_te, y_te, _, _ = dset.get_pandas("ood_test")
config = get_default_config(model, dset)
estimator = get_estimator(model, **config)
estimator = train(estimator, dset, config=config)

if not isinstance(estimator, torch.nn.Module):
# Case: non-pytorch estimator; perform test-split evaluation.
test_split = "ood_test" if dset.is_domain_split else "test"
# Fetch predictions and labels for a sklearn model.
X_te, y_te, _, _ = dset.get_pandas(test_split)
yhat_te = estimator.predict(X_te)

acc = accuracy_score(y_true=y_te, y_pred=yhat_te)
print(f"training completed! {test_split} accuracy: {acc:.4f}")

else:
X_te, y_te, _, _ = dset.get_pandas("test")
yhat_te = estimator.predict(X_te)
acc = accuracy_score(y_true=y_te, y_pred=yhat_te)
print(f"training completed! test accuracy: {acc:.4f}")
# Case: pytorch estimator; eval is already performed + printed by train().
print("training completed!")
return


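Taken together, the updated main() reduces to the standalone sketch below; the experiment name, model name, and cache directory are illustrative placeholders, not values taken from this PR:

from sklearn.metrics import accuracy_score

from tableshift.core import get_dataset
from tableshift.models.default_hparams import get_default_config
from tableshift.models.training import train
from tableshift.models.utils import get_estimator

dset = get_dataset("diabetes_readmission", "tmp")  # experiment name and cache dir are placeholders
config = get_default_config("xgb", dset)
estimator = get_estimator("xgb", **config)
estimator = train(estimator, dset, config=config)

# Non-pytorch estimators are evaluated here; pytorch estimators report their
# own metrics inside train().
test_split = "ood_test" if dset.is_domain_split else "test"
X_te, y_te, _, _ = dset.get_pandas(test_split)
print(f"{test_split} accuracy: {accuracy_score(y_te, estimator.predict(X_te)):.4f}")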
4 changes: 1 addition & 3 deletions requirements.txt
@@ -25,7 +25,7 @@ botocore==1.29.106
build==0.10.0
cachetools==5.2.0
catalogue==2.0.8
catboost==1.1.1
catboost==1.2
category-encoders==2.6.0
certifi==2022.09.24
cffi==1.15.1
@@ -118,7 +118,6 @@ kaggle==1.5.13
keyring==23.13.1
kiwisolver==1.4.4
langcodes==3.3.0
lightgbm==3.3.3
lightgbm-ray==0.1.8
lightning-utilities==0.8.0
llvmlite==0.39.1
@@ -170,7 +169,6 @@ pkgutil_resolve_name==1.3.10
platformdirs==2.5.4
plotly==5.14.0
pluggy==1.0.0
pmdarima==1.8.5
preshed==3.0.8
prometheus-client==0.13.1
promise==2.3
3 changes: 1 addition & 2 deletions tableshift/configs/benchmark_configs.py
@@ -7,8 +7,7 @@
from tableshift.configs.experiment_config import ExperimentConfig
from tableshift.configs.experiment_defaults import DEFAULT_ID_TEST_SIZE, \
DEFAULT_OOD_VAL_SIZE, DEFAULT_ID_VAL_SIZE, DEFAULT_RANDOM_STATE
from tableshift.core import RandomSplitter, Grouper, PreprocessorConfig, \
DomainSplitter
from tableshift.core import Grouper, PreprocessorConfig, DomainSplitter
from tableshift.datasets import BRFSS_YEARS, ACS_YEARS, NHANES_YEARS
from tableshift.datasets.mimic_extract import MIMIC_EXTRACT_STATIC_FEATURES
from tableshift.datasets.mimic_extract_feature_lists import \
12 changes: 4 additions & 8 deletions tableshift/core/tabular_dataset.py
@@ -16,7 +16,7 @@
from torch.utils.data import DataLoader

from tableshift.third_party.domainbed import InfiniteDataLoader
from .features import Preprocessor, PreprocessorConfig
from .features import Preprocessor, PreprocessorConfig, is_categorical
from .grouper import Grouper
from .metrics import metrics_by_group
from .splitter import Splitter, DomainSplitter
@@ -79,7 +79,7 @@ def uid(self) -> str:
@property
def is_domain_split(self) -> bool:
"""Return True if this dataset uses a DomainSplitter, else False."""
return self.domain_label_colname is not None
return isinstance(self.splitter, DomainSplitter)

@property
def eval_split_names(self) -> Tuple:
@@ -93,10 +93,7 @@ def eval_split_names(self) -> Tuple:

@property
def domain_split_varname(self):
if not self.is_domain_split:
return None

elif isinstance(self.splitter, DomainSplitter):
if isinstance(self.splitter, DomainSplitter):
return self.splitter.domain_split_varname
else:
return self.domain_label_colname
@@ -255,8 +252,7 @@ def n_domains(self) -> int:

@property
def cat_idxs(self) -> List[int]:
# TODO: implement this.
raise
return [i for i, col in enumerate(self._df.columns) if is_categorical(self._df[col])]

def get_domains(self, split) -> Union[List[str], None]:
"""Fetch a list of the domains."""
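The new cat_idxs implementation above relies on tableshift's own is_categorical helper and the dataset's internal dataframe; the standalone pandas sketch below illustrates the same idea with plain dtype checks (illustrative only, not the PR's code):

import pandas as pd

df = pd.DataFrame({
    "age": [30, 41],
    "sex": pd.Categorical(["F", "M"]),
    "state": ["CA", "WA"],
})
# Indices of columns that look categorical (explicit Categorical dtype or object dtype).
cat_idxs = [i for i, col in enumerate(df.columns)
            if isinstance(df[col].dtype, pd.CategoricalDtype) or df[col].dtype == object]
print(cat_idxs)  # [1, 2]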
5 changes: 4 additions & 1 deletion tableshift/models/compat.py
@@ -140,7 +140,10 @@ def is_domain_adaptation_model_name(model_name: str) -> bool:
def is_pytorch_model_name(model: str) -> bool:
"""Helper function to determine whether a model name is a pytorch model.

ISee description of is_pytorch_model() above."""
See description of is_pytorch_model() above."""
if model == "catboost":
logging.warning("Catboost models are not supported in Ray hyperparameter training."
" Instead, use the provided catboost-specific script.")
is_sklearn = model in SKLEARN_MODEL_NAMES
is_pt = model in PYTORCH_MODEL_NAMES
assert is_sklearn or is_pt, f"unknown model name {model}"
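For reference, a few illustrative calls; the model names here are assumptions about what is registered in SKLEARN_MODEL_NAMES and PYTORCH_MODEL_NAMES:

from tableshift.models.compat import is_pytorch_model_name

is_pytorch_model_name("xgb")       # False if "xgb" is a sklearn-style model name
is_pytorch_model_name("mlp")       # True if "mlp" is a pytorch model name
is_pytorch_model_name("catboost")  # logs the warning above, then returns False if registered as sklearn-style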
17 changes: 9 additions & 8 deletions tableshift/models/default_hparams.py
@@ -140,22 +140,23 @@


def get_default_config(model: str, dset: TabularDataset) -> dict:
"""Get a default config for a model by name."""
"""Get a default config for a model, by name."""
config = _DEFAULT_CONFIGS.get(model, {})
model_is_pt = is_pytorch_model_name(model)

d_in = dset.X_shape[1]
if is_pytorch_model_name(model) and model != "ft_transformer":
if model_is_pt and model != "ft_transformer":
config.update({"d_in": d_in,
"activation": "ReLU"})
elif is_pytorch_model_name(model):
elif model_is_pt:
config.update({"n_num_features": d_in})

if model in ("tabtransformer", "saint"):
cat_idxs = dset.cat_idxs
config["cat_idxs"] = cat_idxs
config["categories"] = [2] * len(cat_idxs)

# Models that use non-cross-entropy training objectives.
# Set the training objective and any associated hyperparameters.
if model == "dro":
config["criterion"] = DROLoss(size=config["size"],
reg=config["reg"],
@@ -170,10 +171,10 @@ def get_default_config(model: str, dset: TabularDataset) -> dict:
config["criterion"] = GroupDROLoss(n_groups=2)


else:
elif model_is_pt:
config["criterion"] = F.binary_cross_entropy_with_logits

if is_pytorch_model_name(model) and model != "dann":
if model_is_pt and model != "dann":
# Note: for DANN model, lr and weight decay are set separately for D
# and G.
config.update({"lr": 0.01,
@@ -182,9 +183,9 @@ def get_default_config(model: str, dset: TabularDataset) -> dict:

# Do not overwrite batch size or epochs if they are set in the default
# config for the model.
if "batch_size" not in config:
if "batch_size" not in config and model_is_pt:
config["batch_size"] = DEFAULT_BATCH_SIZE
if "n_epochs" not in config:
if "n_epochs" not in config and model_is_pt:
config["n_epochs"] = 1

if model == "saint" and d_in > 100:
6 changes: 5 additions & 1 deletion tableshift/models/torchutils.py
@@ -69,12 +69,16 @@ def apply_model(model: torch.nn.Module, x):


@torch.no_grad()
def get_predictions_and_labels(model, loader, device, as_logits=False) -> Tuple[
def get_predictions_and_labels(model, loader, device=None, as_logits=False) -> Tuple[
np.ndarray, np.ndarray]:
"""Get the predictions (as logits, or probabilities) and labels."""
prediction = []
label = []

if not device:
device = f"cuda:{torch.cuda.current_device()}" \
if torch.cuda.is_available() else "cpu"

modelname = model.__class__.__name__
for batch in tqdm(loader, desc=f"{modelname}:getpreds"):
batch_x, batch_y, _, _ = unpack_batch(batch)
5 changes: 4 additions & 1 deletion tableshift/models/training.py
@@ -120,10 +120,13 @@ def get_eval_loaders(


def _train_pytorch(estimator: SklearnStylePytorchModel, dset: TabularDataset,
device: str,
config=PYTORCH_DEFAULTS,
device: str = None,
tune_report_split: str = None):
"""Helper function to train a pytorch estimator."""
if not device:
device = f"cuda:{torch.cuda.current_device()}" \
if torch.cuda.is_available() else "cpu"
logging.debug(f"config is {config}")
logging.debug(f"estimator is of type {type(estimator)}")
logging.debug(f"dset name is {dset.name}")
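The torchutils.py and training.py changes above add the same CUDA-or-CPU fallback; restated as a standalone helper for clarity (the name resolve_device is not part of the PR):

import torch

def resolve_device(device: str = None) -> str:
    """Return the given device, or the current CUDA device if available, else CPU."""
    if device:
        return device
    return f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu"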
16 changes: 15 additions & 1 deletion tableshift/models/utils.py
@@ -22,7 +22,21 @@
from tableshift.models.wcs import WeightedCovariateShiftClassifier


def get_estimator(model, d_out=1, **kwargs):
def get_estimator(model: str, d_out=1, **kwargs):
"""
Fetch an estimator for training.

Args:
model: the string name of the model to use.
d_out: output dimension of the model (set to 1 for binary classification).
kwargs: named arguments to pass to the model's class constructor. These
vary by model; see the constructors below for details. Only kwargs
accepted by the model's class constructor are used; any others are ignored.
Returns:
An instance of the class specified by the `model` string, with
any hyperparameters set according to kwargs.
"""
if model == "aldro":
assert d_out == 1, "assume binary classification."
return AdversarialLabelDROModel(
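A hypothetical call for illustration; the model name and keyword arguments are assumptions and must match what the chosen model's constructor accepts:

from tableshift.models.utils import get_estimator

# kwargs not accepted by the underlying constructor are ignored, per the docstring above.
estimator = get_estimator("xgb", n_estimators=200, max_depth=6)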