Skip to content

Commit

Permalink
Add wrapper to offset adustments for scoring functions
Browse files Browse the repository at this point in the history
  • Loading branch information
milesgranger committed Aug 23, 2019
1 parent 69316c4 commit fdaa851
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
13 changes: 11 additions & 2 deletions gordo_components/builder/build_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@

from sklearn.base import BaseEstimator
from sklearn.model_selection import cross_val_score, TimeSeriesSplit
from sklearn.metrics import explained_variance_score, make_scorer
from sklearn.pipeline import Pipeline

from gordo_components.util import disk_registry
from gordo_components import serializer, __version__, MAJOR_VERSION, MINOR_VERSION
from gordo_components.dataset.dataset import _get_dataset
from gordo_components.dataset.base import GordoBaseDataset
from gordo_components.model.base import GordoBase
from gordo_components.model.utils import metric_wrapper


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -75,8 +78,14 @@ def build_model(
start = time.time()

scores: Dict[str, Any]
if hasattr(model, "score"):
cv_scores = cross_val_score(model, X, y, cv=TimeSeriesSplit(n_splits=3))
if hasattr(model, "predict"):
cv_scores = cross_val_score(
model,
X,
y,
scoring=make_scorer(metric_wrapper(explained_variance_score)),
cv=TimeSeriesSplit(n_splits=3),
)
scores = {
"explained-variance": {
"mean": cv_scores.mean(),
Expand Down
14 changes: 14 additions & 0 deletions gordo_components/model/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-

import typing
import functools
from typing import Optional, Union, List
from datetime import timedelta, datetime

Expand All @@ -10,6 +11,19 @@
from gordo_components.dataset.sensor_tag import SensorTag


def metric_wrapper(metric):
"""
Ensures that a given metric works properly when the model itself returns
a y which is shorter than the target y.
"""

@functools.wraps(metric)
def _wrapper(y_true, y_pred, *args, **kwargs):
return metric(y_true[-len(y_pred) :], y_pred, *args, **kwargs)

return _wrapper


def make_base_dataframe(
tags: typing.Union[typing.List[SensorTag], typing.List[str]],
model_input: np.ndarray,
Expand Down
1 change: 0 additions & 1 deletion tests/gordo_components/builder/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from tempfile import TemporaryDirectory

import pytest
import sklearn
import numpy as np

import gordo_components
Expand Down

0 comments on commit fdaa851

Please sign in to comment.