-
Notifications
You must be signed in to change notification settings - Fork 359
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FIX] Unify API #1023
base: main
Are you sure you want to change the base?
[FIX] Unify API #1023
Changes from 74 commits
72771c0
dd9f26e
e0ee8d1
e7bbf30
3419432
75bea55
ef019d1
4313c13
8101656
14fbf32
ae6d73c
302489e
0dcb6a2
9160647
b73c097
20c18c5
a26ac29
b20fe3f
1070f1d
452388f
2d3762f
2419eb5
bffa8d1
f02b50f
f80c59b
a60498b
b5ba554
f4de0ff
a4ec70d
706ef74
829fc17
efe2e76
47c36f7
998e813
99c4b14
a4e4ee7
ff89950
b3fafc3
87af3ac
af070a9
ec32f28
2801f19
6c3b2af
ccf8b2d
8cba223
e35f5e1
5fc0437
9c52adb
9d5a2bc
97507f0
5494554
e9bc822
baf7014
fffbda3
9c727cc
1a0ba55
6bb64be
a8a9362
d6e24de
430732f
6a472dc
ae49324
d681fdf
abe522b
932fd55
1f52b8e
0b980c0
63984e6
bbea7ba
6f2272c
a4c8b54
8ee4592
96ab536
030dabe
ddc617f
c529ced
b2c7691
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe we could also merge let me know what you think. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense, then we can still fire up multiple runners (for the sake of keeping test time under control it makes sense to split the tests) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,35 +2,39 @@ | |
import time | ||
|
||
import fire | ||
import numpy as np | ||
# import numpy as np | ||
import pandas as pd | ||
import pytorch_lightning as pl | ||
import torch | ||
# import pytorch_lightning as pl | ||
# import torch | ||
|
||
import neuralforecast | ||
# import neuralforecast | ||
from neuralforecast.core import NeuralForecast | ||
|
||
from neuralforecast.models.gru import GRU | ||
from neuralforecast.models.rnn import RNN | ||
from neuralforecast.models.tcn import TCN | ||
# from neuralforecast.models.rnn import RNN | ||
# from neuralforecast.models.tcn import TCN | ||
from neuralforecast.models.lstm import LSTM | ||
from neuralforecast.models.dilated_rnn import DilatedRNN | ||
from neuralforecast.models.deepar import DeepAR | ||
from neuralforecast.models.mlp import MLP | ||
from neuralforecast.models.nhits import NHITS | ||
from neuralforecast.models.nbeats import NBEATS | ||
# from neuralforecast.models.deepar import DeepAR | ||
# from neuralforecast.models.mlp import MLP | ||
# from neuralforecast.models.nhits import NHITS | ||
# from neuralforecast.models.nbeats import NBEATS | ||
Comment on lines
+14
to
+21
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is the commented code going to be restored in the future? if this change is permanent, maybe we could delete those lines instead. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm kind of treating this file also as a testing file locally, we can delete it (it's mainly so that testing locally is faster that you don't need to type all that stuff every time) |
||
from neuralforecast.models.nbeatsx import NBEATSx | ||
from neuralforecast.models.tft import TFT | ||
from neuralforecast.models.vanillatransformer import VanillaTransformer | ||
from neuralforecast.models.informer import Informer | ||
from neuralforecast.models.autoformer import Autoformer | ||
from neuralforecast.models.patchtst import PatchTST | ||
# from neuralforecast.models.tft import TFT | ||
# from neuralforecast.models.vanillatransformer import VanillaTransformer | ||
# from neuralforecast.models.informer import Informer | ||
# from neuralforecast.models.autoformer import Autoformer | ||
# from neuralforecast.models.patchtst import PatchTST | ||
|
||
from neuralforecast.auto import ( | ||
AutoMLP, AutoNHITS, AutoNBEATS, AutoDilatedRNN, AutoTFT | ||
# AutoMLP, | ||
AutoNHITS, | ||
AutoNBEATS, | ||
# AutoDilatedRNN, | ||
# AutoTFT | ||
) | ||
|
||
from neuralforecast.losses.pytorch import SMAPE, MAE | ||
from neuralforecast.losses.pytorch import MAE | ||
from ray import tune | ||
|
||
from src.data import get_data | ||
|
@@ -49,32 +53,18 @@ def main(dataset: str = 'M3', group: str = 'Monthly') -> None: | |
"scaler_type": "minmax1", | ||
"random_seed": tune.choice([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), | ||
} | ||
config = { | ||
"hidden_size": tune.choice([256, 512]), | ||
"num_layers": tune.choice([2, 4]), | ||
"input_size": tune.choice([2 * horizon]), | ||
"max_steps": 1000, | ||
"val_check_steps": 300, | ||
"scaler_type": "minmax1", | ||
"random_seed": tune.choice([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), | ||
} | ||
config_drnn = {'input_size': tune.choice([2 * horizon]), | ||
'encoder_hidden_size': tune.choice([124]), | ||
"max_steps": 300, | ||
"val_check_steps": 100, | ||
"random_seed": tune.choice([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]),} | ||
models = [ | ||
LSTM(h=horizon, input_size=2 * horizon, encoder_hidden_size=50, max_steps=300), | ||
DilatedRNN(h=horizon, input_size=2 * horizon, encoder_hidden_size=50, max_steps=300), | ||
GRU(h=horizon, input_size=2 * horizon, encoder_hidden_size=50, max_steps=300), | ||
LSTM(h=horizon, input_size=2 * horizon, encoder_hidden_size=64, max_steps=300), | ||
DilatedRNN(h=horizon, input_size=2 * horizon, encoder_hidden_size=64, max_steps=300), | ||
GRU(h=horizon, input_size=2 * horizon, encoder_hidden_size=64, max_steps=300), | ||
AutoNBEATS(h=horizon, loss=MAE(), config=config_nbeats, num_samples=2, cpus=1), | ||
AutoNHITS(h=horizon, loss=MAE(), config=config_nbeats, num_samples=2, cpus=1), | ||
NBEATSx(h=horizon, input_size=2 * horizon, loss=MAE(), max_steps=1000), | ||
PatchTST(h=horizon, input_size=2 * horizon, patch_len=4, stride=4, loss=MAE(), scaler_type='minmax1', windows_batch_size=512, max_steps=1000, val_check_steps=500), | ||
# PatchTST(h=horizon, input_size=2 * horizon, patch_len=4, stride=4, loss=MAE(), scaler_type='minmax1', windows_batch_size=512, max_steps=1000, val_check_steps=500), | ||
] | ||
|
||
# Models | ||
for model in models[:-1]: | ||
for model in models: | ||
model_name = type(model).__name__ | ||
print(50*'-', model_name, 50*'-') | ||
start = time.time() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i noticed that this file and
action_files/test_models/src/evaluation2.py
are quite similar. i have a couple of suggestions:utilsforecast
evaluation features. we could replacemae
andsmape
from thelosses
module and theevaluate
function..github/workflows/ci.yaml
file.the idea would be to abstract the code in the
if __name__ == '__main__':
clause, something like this:and then you could use fire inside the main clause:
this way, we can run it for different models inside
.github/workflows/ci.yaml
:python -m action_files.test_models.src.evaluation --models <list of models>
. wdyt?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this also could apply to
action_files/test_models/src/multivariate_evaluation.py
. since we are changing models and datasets, we could definemain(models: list, dataset: str)
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea, one remarkt - why favour ci over circleci? (I'm ambivalent, don't know why we would prefer one over the other)