feat: Deprecate activation functions for GRU #1198

Merged · 6 commits · Nov 21, 2024
26 changes: 22 additions & 4 deletions nbs/models.gru.ipynb
@@ -1,5 +1,14 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%set_env PYTORCH_ENABLE_MPS_FALLBACK=1"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -70,6 +79,7 @@
"outputs": [],
"source": [
"#| export\n",
"import warnings\n",
"from typing import Optional\n",
"\n",
"import torch\n",
@@ -91,7 +101,7 @@
" \"\"\" GRU\n",
"\n",
" Multi Layer Recurrent Network with Gated Units (GRU), and\n",
" MLP decoder. The network has `tanh` or `relu` non-linearities, it is trained \n",
" MLP decoder. The network has non-linear activation functions, it is trained \n",
" using ADAM stochastic gradient descent. The network accepts static, historic \n",
" and future exogenous data, flattens the inputs.\n",
"\n",
@@ -101,7 +111,7 @@
" `inference_input_size`: int, maximum sequence length for truncated inference. Default -1 uses all history.<br>\n",
" `encoder_n_layers`: int=2, number of layers for the GRU.<br>\n",
" `encoder_hidden_size`: int=200, units for the GRU's hidden state size.<br>\n",
" `encoder_activation`: str=`tanh`, type of GRU activation from `tanh` or `relu`.<br>\n",
" `encoder_activation`: Optional[str]=None, Deprecated. Activation function in GRU is frozen in PyTorch.<br>\n",
" `encoder_bias`: bool=True, whether or not to use biases b_ih, b_hh within GRU units.<br>\n",
" `encoder_dropout`: float=0., dropout regularization applied to GRU outputs.<br>\n",
" `context_size`: int=10, size of context vector for each timestamp on the forecasting window.<br>\n",
@@ -143,7 +153,7 @@
" inference_input_size: int = -1,\n",
" encoder_n_layers: int = 2,\n",
" encoder_hidden_size: int = 200,\n",
" encoder_activation: str = 'tanh',\n",
" encoder_activation: Optional[str] = None,\n",
" encoder_bias: bool = True,\n",
" encoder_dropout: float = 0.,\n",
" context_size: int = 10,\n",
@@ -199,6 +209,14 @@
" **trainer_kwargs\n",
" )\n",
"\n",
" if encoder_activation is not None:\n",
" warnings.warn(\n",
" \"The 'encoder_activation' argument is deprecated and will be removed in \"\n",
" \"future versions. The activation function in GRU is frozen in PyTorch and \"\n",
" \"it cannot be modified.\",\n",
" DeprecationWarning,\n",
" )\n",
"\n",
" # RNN\n",
" self.encoder_n_layers = encoder_n_layers\n",
" self.encoder_hidden_size = encoder_hidden_size\n",
@@ -322,7 +340,7 @@
"import matplotlib.pyplot as plt\n",
"\n",
"from neuralforecast import NeuralForecast\n",
"from neuralforecast.models import GRU\n",
"# from neuralforecast.models import GRU\n",
"from neuralforecast.losses.pytorch import DistributionLoss\n",
"from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic\n",
"\n",
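The notebook's new first cell sets PYTORCH_ENABLE_MPS_FALLBACK=1. For reference, a plain-Python equivalent of that magic (a sketch, not part of this PR) sets the variable before PyTorch does any MPS work, so operators missing from Apple's MPS backend fall back to the CPU instead of raising an error:

import os

# Equivalent of `%set_env PYTORCH_ENABLE_MPS_FALLBACK=1`: let PyTorch run
# unsupported MPS operators on the CPU instead of failing.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # imported after the variable is set so the fallback is honored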
19 changes: 14 additions & 5 deletions neuralforecast/models/gru.py
@@ -3,7 +3,8 @@
# %% auto 0
__all__ = ['GRU']

# %% ../../nbs/models.gru.ipynb 6
# %% ../../nbs/models.gru.ipynb 7
import warnings
from typing import Optional

import torch
@@ -13,12 +14,12 @@
from ..common._base_recurrent import BaseRecurrent
from ..common._modules import MLP

# %% ../../nbs/models.gru.ipynb 7
# %% ../../nbs/models.gru.ipynb 8
class GRU(BaseRecurrent):
"""GRU

Multi Layer Recurrent Network with Gated Units (GRU), and
MLP decoder. The network has `tanh` or `relu` non-linearities, it is trained
MLP decoder. The network has non-linear activation functions, it is trained
using ADAM stochastic gradient descent. The network accepts static, historic
and future exogenous data, flattens the inputs.

@@ -28,7 +29,7 @@ class GRU(BaseRecurrent):
`inference_input_size`: int, maximum sequence length for truncated inference. Default -1 uses all history.<br>
`encoder_n_layers`: int=2, number of layers for the GRU.<br>
`encoder_hidden_size`: int=200, units for the GRU's hidden state size.<br>
`encoder_activation`: str=`tanh`, type of GRU activation from `tanh` or `relu`.<br>
`encoder_activation`: Optional[str]=None, Deprecated. Activation function in GRU is frozen in PyTorch.<br>
`encoder_bias`: bool=True, whether or not to use biases b_ih, b_hh within GRU units.<br>
`encoder_dropout`: float=0., dropout regularization applied to GRU outputs.<br>
`context_size`: int=10, size of context vector for each timestamp on the forecasting window.<br>
@@ -72,7 +73,7 @@ def __init__(
inference_input_size: int = -1,
encoder_n_layers: int = 2,
encoder_hidden_size: int = 200,
encoder_activation: str = "tanh",
encoder_activation: Optional[str] = None,
encoder_bias: bool = True,
encoder_dropout: float = 0.0,
context_size: int = 10,
@@ -129,6 +130,14 @@ def __init__(
**trainer_kwargs
)

if encoder_activation is not None:
warnings.warn(
"The 'encoder_activation' argument is deprecated and will be removed in "
"future versions. The activation function in GRU is frozen in PyTorch and "
"it cannot be modified.",
DeprecationWarning,
)

# RNN
self.encoder_n_layers = encoder_n_layers
self.encoder_hidden_size = encoder_hidden_size
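With this change, passing any value for encoder_activation only emits a DeprecationWarning; the GRU gates keep PyTorch's fixed activations (sigmoid gates, tanh candidate state) regardless. A minimal sketch of the new behavior, assuming the constructor shown in the diff above:

import warnings

from neuralforecast.models import GRU

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model = GRU(h=12, encoder_activation="relu")  # deprecated argument

# The warning is emitted, but the model is still built on PyTorch's
# standard nn.GRU, whose activation functions cannot be changed.
assert any(issubclass(w.category, DeprecationWarning) for w in caught)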