Skip to content

Commit

Permalink
docs: Fix README example to be runnable with current APIs (#54)
Browse files Browse the repository at this point in the history
* Modern versions of scikit-learn do not import most modules of the
  library, and so these must now be imported before they can be used.
* Update Vset arguments to use 'vfuncs' and 'vfunc_keys' args over previous
  'module' args.
  • Loading branch information
matthewfeickert authored Feb 8, 2024
1 parent 22fcb92 commit 855116c
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,40 +39,40 @@ simply using `vflow`.

```python
import sklearn
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, balanced_accuracy_score
from vflow import init_args, Vset
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

from vflow import Vset, init_args

# initialize data
X, y = sklearn.datasets.make_classification()
X, y = make_classification()
X_train, X_test, y_train, y_test = init_args(
sklearn.model_selection.train_test_split(X, y),
names=['X_train', 'X_test', 'y_train', 'y_test'] # optionally name the args
train_test_split(X, y),
names=["X_train", "X_test", "y_train", "y_test"], # optionally name the args
)

# subsample data
subsampling_funcs = [
sklearn.utils.resample for _ in range(3)
]
subsampling_set = Vset(name='subsampling',
modules=subsampling_funcs,
output_matching=True)
subsampling_funcs = [sklearn.utils.resample for _ in range(3)]
subsampling_set = Vset(
name="subsampling", vfuncs=subsampling_funcs, output_matching=True
)
X_trains, y_trains = subsampling_set(X_train, y_train)

# fit models
models = [
sklearn.linear_model.LogisticRegression(),
sklearn.tree.DecisionTreeClassifier()
]
modeling_set = Vset(name='modeling',
modules=models,
module_keys=["LR", "DT"])
models = [LogisticRegression(), DecisionTreeClassifier()]
modeling_set = Vset(name="modeling", vfuncs=models, vfunc_keys=["LR", "DT"])
modeling_set.fit(X_trains, y_trains)
preds_test = modeling_set.predict(X_test)

# get metrics
binary_metrics_set = Vset(name='binary_metrics',
modules=[accuracy_score, balanced_accuracy_score],
module_keys=["Acc", "Bal_Acc"])
binary_metrics_set = Vset(
name="binary_metrics",
vfuncs=[accuracy_score, balanced_accuracy_score],
vfunc_keys=["Acc", "Bal_Acc"],
)
binary_metrics = binary_metrics_set.evaluate(preds_test, y_test)
```

Expand Down

0 comments on commit 855116c

Please sign in to comment.