diff --git a/mlforecast/auto.py b/mlforecast/auto.py
index 0629913b..bc015256 100644
--- a/mlforecast/auto.py
+++ b/mlforecast/auto.py
@@ -512,7 +512,12 @@ def fit(
         if loss is None:
 
             def loss(df, train_df):  # noqa: ARG001
-                return smape(df, models=["model"])["model"].mean()
+                return smape(
+                    df,
+                    models=["model"],
+                    id_col=id_col,
+                    target_col=target_col,
+                )["model"].mean()
 
         if study_kwargs is None:
             study_kwargs = {}
@@ -554,8 +559,14 @@ def config_fn(trial: optuna.Trial) -> Dict[str, Any]:
             study.optimize(objective, n_trials=num_samples, **optimize_kwargs)
             self.results_[name] = study
             best_config = study.best_trial.user_attrs["config"]
-            best_config["mlf_fit_params"].pop("fitted", None)
-            best_config["mlf_fit_params"].pop("prediction_intervals", None)
+            for arg in (
+                "fitted",
+                "prediction_intervals",
+                "id_col",
+                "time_col",
+                "target_col",
+            ):
+                best_config["mlf_fit_params"].pop(arg, None)
             best_model = clone(auto_model.model)
             best_model.set_params(**best_config["model_params"])
             self.models_[name] = MLForecast(
@@ -567,6 +578,9 @@ def config_fn(trial: optuna.Trial) -> Dict[str, Any]:
                 df,
                 fitted=fitted,
                 prediction_intervals=prediction_intervals,
+                id_col=id_col,
+                time_col=time_col,
+                target_col=target_col,
                 **best_config["mlf_fit_params"],
             )
         return self
diff --git a/nbs/auto.ipynb b/nbs/auto.ipynb
index 63493b6c..bdb10958 100644
--- a/nbs/auto.ipynb
+++ b/nbs/auto.ipynb
@@ -589,7 +589,12 @@
     "\n",
     "        if loss is None:\n",
     "            def loss(df, train_df): # noqa: ARG001\n",
-    "                return smape(df, models=['model'])['model'].mean()\n",
+    "                return smape(\n",
+    "                    df,\n",
+    "                    models=['model'],\n",
+    "                    id_col=id_col,\n",
+    "                    target_col=target_col,\n",
+    "                )['model'].mean()\n",
     "        if study_kwargs is None:\n",
     "            study_kwargs = {}\n",
     "        if 'sampler' not in study_kwargs:\n",
@@ -629,8 +634,10 @@
     "            study.optimize(objective, n_trials=num_samples, **optimize_kwargs)\n",
     "            self.results_[name] = study\n",
     "            best_config = study.best_trial.user_attrs['config']\n",
-    "            best_config['mlf_fit_params'].pop('fitted', None)\n",
-    "            best_config['mlf_fit_params'].pop('prediction_intervals', None)\n",
+    "            for arg in (\n",
+    "                'fitted', 'prediction_intervals', 'id_col', 'time_col', 'target_col'\n",
+    "            ):\n",
+    "                best_config['mlf_fit_params'].pop(arg, None)\n",
     "            best_model = clone(auto_model.model)\n",
     "            best_model.set_params(**best_config['model_params'])\n",
     "            self.models_[name] = MLForecast(\n",
@@ -642,6 +649,9 @@
     "                df,\n",
     "                fitted=fitted,\n",
     "                prediction_intervals=prediction_intervals,\n",
+    "                id_col=id_col,\n",
+    "                time_col=time_col,\n",
+    "                target_col=target_col,\n",
     "                **best_config['mlf_fit_params'],\n",
     "            )\n",
     "        return self\n",
@@ -904,7 +914,7 @@
       "text/markdown": [
        "---\n",
        "\n",
-       "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L570){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+       "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L574){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
        "\n",
        "### AutoMLForecast.predict\n",
        "\n",
@@ -924,7 +934,7 @@
       "text/plain": [
        "---\n",
        "\n",
-       "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L570){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+       "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L574){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
        "\n",
        "### AutoMLForecast.predict\n",
        "\n",
@@ -962,7 +972,7 @@
       "text/markdown": [
        "---\n",
        "\n",
-       "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L602){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+       "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L606){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
        "\n",
        "### AutoMLForecast.save\n",
        "\n",
@@ -978,7 +988,7 @@
       "text/plain": [
        "---\n",
        "\n",
-       "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L602){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+       "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L606){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
        "\n",
        "### AutoMLForecast.save\n",
        "\n",
@@ -1012,7 +1022,7 @@
       "text/markdown": [
        "---\n",
        "\n",
-       "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L612){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+       "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L616){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
        "\n",
        "### AutoMLForecast.forecast_fitted_values\n",
        "\n",
@@ -1030,7 +1040,7 @@
       "text/plain": [
        "---\n",
        "\n",
-       "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L612){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+       "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L616){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
        "\n",
        "### AutoMLForecast.forecast_fitted_values\n",
        "\n",
@@ -1062,6 +1072,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "import pandas as pd\n",
     "from datasetsforecast.m4 import M4, M4Evaluation, M4Info\n",
     "from sklearn.linear_model import Ridge\n",
     "from sklearn.compose import ColumnTransformer\n",
@@ -1740,6 +1751,43 @@
     "metric_step_1 = auto_mlf2.results_['ridge'].best_trial.value\n",
     "assert abs(metric_step_h / metric_step_1 - 1) > 0.02"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a3c00fd0-ebee-4d40-b2aa-dc7ae98ee94c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#| hide\n",
+    "# default loss with non standard names\n",
+    "auto_mlf = AutoMLForecast(\n",
+    "    freq=1,\n",
+    "    season_length=season_length,\n",
+    "    models={'ridge': AutoRidge()},\n",
+    ")\n",
+    "fit_kwargs = dict(\n",
+    "    n_windows=2,\n",
+    "    h=h,\n",
+    "    step_size=1,\n",
+    "    num_samples=2,\n",
+    "    optimize_kwargs={'timeout': 60},\n",
+    ")\n",
+    "preds = auto_mlf.fit(df=train, **fit_kwargs).predict(5)\n",
+    "\n",
+    "train2 = train.rename(columns={'unique_id': 'id', 'ds': 'time', 'y': 'target'})\n",
+    "preds2 = auto_mlf.fit(\n",
+    "    df=train2,\n",
+    "    id_col='id',\n",
+    "    time_col='time',\n",
+    "    target_col='target',\n",
+    "    **fit_kwargs,\n",
+    ").predict(5)\n",
+    "pd.testing.assert_frame_equal(\n",
+    "    preds,\n",
+    "    preds2.rename(columns={'id': 'unique_id', 'time': 'ds'}),\n",
+    ")"
+   ]
   }
  ],
 "metadata": {