-
Notifications
You must be signed in to change notification settings - Fork 101
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into documentation
- Loading branch information
Showing
11 changed files
with
173 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Config for loading a pretrained MoiraiForecast model from a checkpoint
# (presumably consumed via Hydra/OmegaConf instantiation, given `_target_` —
# verify against the caller).
_target_: uni2ts.model.moirai.MoiraiForecast.load_from_checkpoint
# Path to the Lightning checkpoint (.ckpt); `...` is a placeholder to fill in.
checkpoint_path: ...
# Number of forecast samples drawn per series.
num_samples: 100
# `???` marks mandatory values that must be supplied at run time.
patch_size: ???
context_length: ???
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Benchmark

This directory contains the code and scripts for benchmarking.

## Chronos

`run_chronos.py` is the code to run Chronos on a given dataset.

`chronos_scripts` contains the scripts to run Chronos on different datasets.

Example:
```
sh chronos_scripts/monash_chronos_base.sh
```
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Evaluate Chronos-T5 (base) on every Monash benchmark dataset.
model_size=base
model_path="amazon/chronos-t5-${model_size}"

# Datasets to evaluate, one run_chronos.py invocation per dataset.
datasets="us_births saugeenday sunspot_with_missing temperature_rain_with_missing covid_deaths hospital rideshare_with_missing traffic_weekly traffic_hourly fred_md car_parts_with_missing electricity_weekly electricity_hourly solar_weekly solar_10_minutes nn5_weekly nn5_daily_with_missing weather kdd_cup_2018_with_missing vehicle_trips_with_missing pedestrian_counts bitcoin_with_missing dominick australian_electricity_demand cif_2016_12 cif_2016_6 tourism_monthly tourism_quarterly m4_hourly m4_daily m4_weekly m4_monthly monash_m3_other monash_m3_monthly m1_monthly m1_yearly monash_m3_yearly m4_yearly tourism_yearly m1_quarterly monash_m3_quarterly m4_quarterly kaggle_web_traffic_weekly kaggle_web_traffic_with_missing bitcoin"

for ds in ${datasets}; do
    python run_chronos.py --model_path="${model_path}" --dataset="${ds}" --run_name="chronos-${model_size}"
done
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Evaluate Chronos-T5 (mini) on every Monash benchmark dataset.
model_size=mini
model_path="amazon/chronos-t5-${model_size}"

# Datasets to evaluate, one run_chronos.py invocation per dataset.
datasets="us_births saugeenday sunspot_with_missing temperature_rain_with_missing covid_deaths hospital rideshare_with_missing traffic_weekly traffic_hourly fred_md car_parts_with_missing electricity_weekly electricity_hourly solar_weekly solar_10_minutes nn5_weekly nn5_daily_with_missing weather kdd_cup_2018_with_missing vehicle_trips_with_missing pedestrian_counts bitcoin_with_missing dominick australian_electricity_demand cif_2016_12 cif_2016_6 tourism_monthly tourism_quarterly m4_hourly m4_daily m4_weekly m4_monthly monash_m3_other monash_m3_monthly m1_monthly m1_yearly monash_m3_yearly m4_yearly tourism_yearly m1_quarterly monash_m3_quarterly m4_quarterly kaggle_web_traffic_weekly kaggle_web_traffic_with_missing bitcoin"

for ds in ${datasets}; do
    python run_chronos.py --model_path="${model_path}" --dataset="${ds}" --run_name="chronos-${model_size}"
done
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Evaluate Chronos-T5 (small) on every Monash benchmark dataset.
model_size=small
model_path="amazon/chronos-t5-${model_size}"

# Datasets to evaluate, one run_chronos.py invocation per dataset.
datasets="us_births saugeenday sunspot_with_missing temperature_rain_with_missing covid_deaths hospital rideshare_with_missing traffic_weekly traffic_hourly fred_md car_parts_with_missing electricity_weekly electricity_hourly solar_weekly solar_10_minutes nn5_weekly nn5_daily_with_missing weather kdd_cup_2018_with_missing vehicle_trips_with_missing pedestrian_counts bitcoin_with_missing dominick australian_electricity_demand cif_2016_12 cif_2016_6 tourism_monthly tourism_quarterly m4_hourly m4_daily m4_weekly m4_monthly monash_m3_other monash_m3_monthly m1_monthly m1_yearly monash_m3_yearly m4_yearly tourism_yearly m1_quarterly monash_m3_quarterly m4_quarterly kaggle_web_traffic_weekly kaggle_web_traffic_with_missing bitcoin"

for ds in ${datasets}; do
    python run_chronos.py --model_path="${model_path}" --dataset="${ds}" --run_name="chronos-${model_size}"
done
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Evaluate Chronos-T5 (tiny) on every Monash benchmark dataset.
model_size=tiny
model_path="amazon/chronos-t5-${model_size}"

# Datasets to evaluate, one run_chronos.py invocation per dataset.
datasets="us_births saugeenday sunspot_with_missing temperature_rain_with_missing covid_deaths hospital rideshare_with_missing traffic_weekly traffic_hourly fred_md car_parts_with_missing electricity_weekly electricity_hourly solar_weekly solar_10_minutes nn5_weekly nn5_daily_with_missing weather kdd_cup_2018_with_missing vehicle_trips_with_missing pedestrian_counts bitcoin_with_missing dominick australian_electricity_demand cif_2016_12 cif_2016_6 tourism_monthly tourism_quarterly m4_hourly m4_daily m4_weekly m4_monthly monash_m3_other monash_m3_monthly m1_monthly m1_yearly monash_m3_yearly m4_yearly tourism_yearly m1_quarterly monash_m3_quarterly m4_quarterly kaggle_web_traffic_weekly kaggle_web_traffic_with_missing bitcoin"

for ds in ${datasets}; do
    python run_chronos.py --model_path="${model_path}" --dataset="${ds}" --run_name="chronos-${model_size}"
done
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
import argparse | ||
import os | ||
|
||
import numpy as np | ||
import torch | ||
from chronos import ChronosPipeline | ||
from gluonts.dataset.repository import get_dataset | ||
from gluonts.dataset.split import split | ||
from gluonts.ev.metrics import ( | ||
MAE, | ||
MAPE, | ||
MASE, | ||
MSE, | ||
MSIS, | ||
ND, | ||
NRMSE, | ||
RMSE, | ||
SMAPE, | ||
MeanWeightedSumQuantileLoss, | ||
) | ||
from gluonts.itertools import batcher | ||
|
||
# from gluonts.model.evaluation import evaluate_forecasts | ||
from gluonts.model.forecast import SampleForecast | ||
from tqdm.auto import tqdm | ||
|
||
from uni2ts.eval_util.data import get_gluonts_test_dataset | ||
from uni2ts.eval_util.evaluation import evaluate_forecasts | ||
from uni2ts.eval_util.metrics import MedianMSE | ||
|
||
|
||
def evaluate(pipeline, dataset, save_path, num_samples=20, batch_size=512):
    """Run a Chronos pipeline on a GluonTS test split and save metrics to CSV.

    Args:
        pipeline: ChronosPipeline used to draw sample forecasts.
        dataset: Name of the dataset, resolved via ``get_gluonts_test_dataset``.
        save_path: Path of the CSV file the metrics DataFrame is written to.
        num_samples: Number of forecast samples drawn per series.
        batch_size: Initial batch size; halved on CUDA OOM until it fits.

    Returns:
        A one-row metrics DataFrame indexed by ``dataset``.

    Raises:
        RuntimeError: If the GPU runs out of memory even at batch_size == 1.
    """
    print("-" * 5, f"Evaluating {dataset}", "-" * 5)
    test_data, metadata = get_gluonts_test_dataset(dataset)
    prediction_length = metadata.prediction_length

    # Generate forecast samples, retrying with a smaller batch on GPU OOM.
    while True:
        try:
            forecast_samples = []
            for batch in tqdm(batcher(test_data.input, batch_size=batch_size)):
                context = [torch.tensor(entry["target"]) for entry in batch]
                forecast_samples.append(
                    pipeline.predict(
                        context,
                        prediction_length=prediction_length,
                        num_samples=num_samples,
                        limit_prediction_length=False,  # We disable the limit on prediction length.
                    ).numpy()
                )
            forecast_samples = np.concatenate(forecast_samples)
            break
        except torch.cuda.OutOfMemoryError:
            # BUG FIX: previously batch_size could be halved down to 0,
            # making the retry loop spin forever. Fail loudly instead.
            if batch_size <= 1:
                raise RuntimeError(
                    "CUDA out of memory even at batch_size=1; cannot evaluate."
                )
            print(
                f"OutOfMemoryError at batch_size {batch_size}, reducing to {batch_size//2}"
            )
            batch_size //= 2

    # Convert forecast samples into gluonts SampleForecast objects; the
    # forecast window starts right after the last observed target value.
    sample_forecasts = []
    for item, ts in zip(forecast_samples, test_data.input):
        forecast_start_date = ts["start"] + len(ts["target"])
        sample_forecasts.append(
            SampleForecast(samples=item, start_date=forecast_start_date)
        )

    # Evaluate with the standard point and probabilistic metrics.
    metrics_df = evaluate_forecasts(
        sample_forecasts,
        test_data=test_data,
        metrics=[
            MSE(),
            MAE(),
            MAPE(),
            SMAPE(),
            MSIS(),
            RMSE(),
            NRMSE(),
            ND(),
            MASE(),
            MedianMSE(),
            MeanWeightedSumQuantileLoss(np.arange(0.1, 1.0, 0.1)),
        ],
    )
    metrics_df.index = [dataset]
    print(metrics_df)
    metrics_df.to_csv(save_path)
    print(f"Results saved to {save_path}")
    print("-" * 5, f"Evaluation of {dataset} complete", "-" * 5)
    return metrics_df
|
||
|
||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Load a model and dataset, then make predictions."
    )
    parser.add_argument(
        "--model_path", type=str, required=True, help="Path to load the model"
    )
    parser.add_argument(
        "--dataset", type=str, required=True, help="Name of the dataset to use"
    )
    parser.add_argument(
        "--save_dir", type=str, default="results", help="Directory to save the results"
    )
    parser.add_argument(
        "--num_samples", type=int, default=20, help="Number of samples to generate"
    )
    parser.add_argument(
        "--batch_size", type=int, default=512, help="Batch size for generating samples"
    )
    parser.add_argument("--run_name", type=str, default="test", help="Name of the run")

    args = parser.parse_args()
    # Load Chronos on the first GPU in bfloat16.
    pipeline = ChronosPipeline.from_pretrained(
        args.model_path,
        device_map="cuda:0",
        torch_dtype=torch.bfloat16,
    )
    output_dir = os.path.join(args.save_dir, args.run_name)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)
    # BUG FIX: --num_samples and --batch_size were parsed but never forwarded
    # to evaluate(), so the CLI flags were silently ignored.
    evaluate(
        pipeline,
        args.dataset,
        os.path.join(output_dir, f"{args.dataset}.csv"),
        num_samples=args.num_samples,
        batch_size=args.batch_size,
    )