Skip to content

Commit

Permalink
Fix: deserialize callbacks for fit params from dict definition (#1389)
Browse files Browse the repository at this point in the history
* Fix: deserialize callbacks for fit params from dict definition

* Fix: formatting

* Fix: better checking for None callbacks

* Fix: update build_callbacks reference
  • Loading branch information
RollerKnobster authored Jul 1, 2024
1 parent 1627a27 commit 23c5a73
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 24 deletions.
4 changes: 4 additions & 0 deletions gordo/machine/model/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,10 @@ def get_params(self, **params):
params = super().get_params(**params)
params.update({"kind": self.kind})
params.update(self.kwargs)
if self.kwargs.get("callbacks") is not None and any(
isinstance(callback, dict) for callback in self.kwargs["callbacks"]
):
params["callbacks"] = serializer.build_callbacks(self.kwargs["callbacks"])
return params

def _prepare_model(self):
Expand Down
6 changes: 5 additions & 1 deletion gordo/serializer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from .from_definition import from_definition, load_params_from_definition
from .from_definition import (
from_definition,
load_params_from_definition,
build_callbacks,
)
from .into_definition import into_definition, load_definition_from_params
from .serializer import (
dump,
Expand Down
46 changes: 23 additions & 23 deletions gordo/serializer/from_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,28 +248,6 @@ def _build_step(
)


def _build_callbacks(definitions: list):
"""
Parameters
----------
definitions
List of callbacks definitions
Examples
--------
>>> callbacks=_build_callbacks([{'tensorflow.keras.callbacks.EarlyStopping': {'monitor': 'val_loss,', 'patience': 10}}])
>>> type(callbacks[0])
<class 'keras.src.callbacks.early_stopping.EarlyStopping'>
Returns
-------
"""
callbacks = []
for callback in definitions:
callbacks.append(_build_step(callback))
return callbacks


def _load_param_classes(params: dict):
"""
Inspect the params' values and determine if any can be loaded as a class.
Expand Down Expand Up @@ -350,7 +328,7 @@ def _load_param_classes(params: dict):
kwargs = _load_param_classes(sub_params)
params[key] = create_instance(Model, **kwargs) # type: ignore
elif key == "callbacks" and isinstance(value, list):
params[key] = _build_callbacks(value)
params[key] = build_callbacks(value)
return params


Expand All @@ -367,3 +345,25 @@ def load_params_from_definition(definition: dict) -> dict:
"Expected definition to be a dict," f"found: {type(definition)}"
)
return _load_param_classes(definition)


def build_callbacks(definitions: list):
"""
Parameters
----------
definitions
List of callbacks definitions
Examples
--------
>>> callbacks=build_callbacks([{'tensorflow.keras.callbacks.EarlyStopping': {'monitor': 'val_loss,', 'patience': 10}}])
>>> type(callbacks[0])
<class 'keras.src.callbacks.early_stopping.EarlyStopping'>
Returns
-------
"""
callbacks = []
for callback in definitions:
callbacks.append(_build_step(callback))
return callbacks
2 changes: 2 additions & 0 deletions tests/gordo/machine/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ def test_keras_autoencoder_fits_callbacks():
assert isinstance(first_callback, EarlyStopping)
assert first_callback.monitor == "val_loss"
assert first_callback.patience == 10
X, y = np.random.rand(10, 10), np.random.rand(10, 10)
model.fit(X, y)


def test_parse_module_path():
Expand Down

0 comments on commit 23c5a73

Please sign in to comment.