Skip to content

Commit

Permalink
Fix: Ability load models saved using versions before 1.7 (#1207)
Browse files Browse the repository at this point in the history
  • Loading branch information
tylernisonoff authored Nov 19, 2024
1 parent 196ebb3 commit 642ced4
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 9 deletions.
19 changes: 15 additions & 4 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1626,15 +1626,26 @@
" except FileNotFoundError:\n",
" raise Exception('No configuration found in directory.')\n",
"\n",
" # in 1.6.4, `local_scaler_type` / `scalers_` lived on the dataset.\n",
" # in order to preserve backwards-compatibility, we check to see if these are found on the dataset\n",
" # in case they cannot be found in `config_dict`\n",
" default_scalar_type = getattr(dataset, \"local_scaler_type\", None)\n",
" default_scalars_ = getattr(dataset, \"scalers_\", None)\n",
"\n",
" # Create NeuralForecast object\n",
" neuralforecast = NeuralForecast(\n",
" models=models,\n",
" freq=config_dict['freq'],\n",
" local_scaler_type=config_dict['local_scaler_type'],\n",
" local_scaler_type=config_dict.get(\"local_scaler_type\", default_scalar_type),\n",
" )\n",
"\n",
" for attr in ['id_col', 'time_col', 'target_col']:\n",
" setattr(neuralforecast, attr, config_dict[attr])\n",
" attr_to_default = {\n",
" \"id_col\": \"unique_id\",\n",
" \"time_col\": \"ds\",\n",
" \"target_col\": \"y\"\n",
" }\n",
" for attr, default in attr_to_default.items():\n",
" setattr(neuralforecast, attr, config_dict.get(attr, default))\n",
" # only restore attribute if available\n",
" for attr in ['prediction_intervals', '_cs_df']:\n",
" if attr in config_dict.keys():\n",
Expand All @@ -1655,7 +1666,7 @@
" # Fitted flag\n",
" neuralforecast._fitted = config_dict['_fitted']\n",
"\n",
" neuralforecast.scalers_ = config_dict['scalers_']\n",
" neuralforecast.scalers_ = config_dict.get(\"scalers_\", default_scalars_)\n",
"\n",
" return neuralforecast\n",
" \n",
Expand Down
2 changes: 1 addition & 1 deletion nbs/nbdev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ project:

website:
title: "neuralforecast"
site-url: "https://Nixtla.github.io/neuralforecast/"
site-url: "https://nixtlaverse.nixtla.io/neuralforecast/"
description: "Time series forecasting suite using deep learning models"
repo-branch: main
repo-url: "https://github.com/Nixtla/neuralforecast/"
15 changes: 11 additions & 4 deletions neuralforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1623,15 +1623,22 @@ def load(path, verbose=False, **kwargs):
except FileNotFoundError:
raise Exception("No configuration found in directory.")

# in 1.6.4, `local_scaler_type` / `scalers_` lived on the dataset.
# in order to preserve backwards-compatibility, we check to see if these are found on the dataset
# in case they cannot be found in `config_dict`
default_scalar_type = getattr(dataset, "local_scaler_type", None)
default_scalars_ = getattr(dataset, "scalers_", None)

# Create NeuralForecast object
neuralforecast = NeuralForecast(
models=models,
freq=config_dict["freq"],
local_scaler_type=config_dict["local_scaler_type"],
local_scaler_type=config_dict.get("local_scaler_type", default_scalar_type),
)

for attr in ["id_col", "time_col", "target_col"]:
setattr(neuralforecast, attr, config_dict[attr])
attr_to_default = {"id_col": "unique_id", "time_col": "ds", "target_col": "y"}
for attr, default in attr_to_default.items():
setattr(neuralforecast, attr, config_dict.get(attr, default))
# only restore attribute if available
for attr in ["prediction_intervals", "_cs_df"]:
if attr in config_dict.keys():
Expand All @@ -1652,7 +1659,7 @@ def load(path, verbose=False, **kwargs):
# Fitted flag
neuralforecast._fitted = config_dict["_fitted"]

neuralforecast.scalers_ = config_dict["scalers_"]
neuralforecast.scalers_ = config_dict.get("scalers_", default_scalars_)

return neuralforecast

Expand Down

0 comments on commit 642ced4

Please sign in to comment.