Skip to content

Commit

Permalink
Merge branch 'main' into documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
chenghaoliu89 authored Aug 16, 2024
2 parents 07656c4 + ea6b44e commit 387903e
Show file tree
Hide file tree
Showing 11 changed files with 173 additions and 2 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ repos:
rev: 24.2.0
hooks:
- id: black
language_version: python3.10
language_version: python3
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
name: isort (python)
name: isort (python)
5 changes: 5 additions & 0 deletions cli/conf/eval/model/moirai_lightning_ckpt.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Evaluation config: instantiate MoiraiForecast from a Lightning checkpoint.
_target_: uni2ts.model.moirai.MoiraiForecast.load_from_checkpoint
checkpoint_path: ...  # path to the .ckpt file to load — replace the placeholder
num_samples: 100  # number of forecast sample paths drawn per series
patch_size: ???  # no default — presumably a Hydra/OmegaConf mandatory value; must be set at launch
context_length: ???  # no default — presumably a Hydra/OmegaConf mandatory value; must be set at launch
1 change: 1 addition & 0 deletions cli/conf/finetune/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ run_name: ???
seed: 0
tf32: true
compile: false # set to mode: default, reduce-overhead, max-autotune
ckpt_path: null
trainer:
_target_: lightning.Trainer
accelerator: auto
Expand Down
1 change: 1 addition & 0 deletions cli/conf/pretrain/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ run_name: ???
seed: 0
tf32: true
compile: false # set to mode: default, reduce-overhead, max-autotune
ckpt_path: null # set to "last" to resume training
trainer:
_target_: lightning.Trainer
accelerator: auto
Expand Down
1 change: 1 addition & 0 deletions cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def main(cfg: DictConfig):
trainer.fit(
model,
datamodule=DataModule(cfg, train_dataset, val_dataset),
ckpt_path=cfg.ckpt_path,
)


Expand Down
14 changes: 14 additions & 0 deletions project/benchmarks/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Benchmark
This directory contains the code and scripts for benchmarking.


## Chronos
`run_chronos.py` is the code to run Chronos on a given dataset.

`chronos_scripts` contains the scripts to run Chronos on different datasets.

Example:
```sh
sh chronos_scripts/monash_chronos_base.sh
```

6 changes: 6 additions & 0 deletions project/benchmarks/chronos_scripts/monash_chronos_base.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Run Chronos (base) on the full Monash benchmark suite.
model_size=base
model_path="amazon/chronos-t5-${model_size}"

# All Monash datasets evaluated, one run per dataset.
datasets="us_births saugeenday sunspot_with_missing temperature_rain_with_missing covid_deaths hospital rideshare_with_missing traffic_weekly traffic_hourly fred_md car_parts_with_missing electricity_weekly electricity_hourly solar_weekly solar_10_minutes nn5_weekly nn5_daily_with_missing weather kdd_cup_2018_with_missing vehicle_trips_with_missing pedestrian_counts bitcoin_with_missing dominick australian_electricity_demand cif_2016_12 cif_2016_6 tourism_monthly tourism_quarterly m4_hourly m4_daily m4_weekly m4_monthly monash_m3_other monash_m3_monthly m1_monthly m1_yearly monash_m3_yearly m4_yearly tourism_yearly m1_quarterly monash_m3_quarterly m4_quarterly kaggle_web_traffic_weekly kaggle_web_traffic_with_missing bitcoin"

for dataset in ${datasets}; do
    python run_chronos.py \
        --model_path="${model_path}" \
        --dataset="${dataset}" \
        --run_name="chronos-${model_size}"
done
6 changes: 6 additions & 0 deletions project/benchmarks/chronos_scripts/monash_chronos_mini.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Run Chronos (mini) on the full Monash benchmark suite.
model_size=mini
model_path="amazon/chronos-t5-${model_size}"

# All Monash datasets evaluated, one run per dataset.
datasets="us_births saugeenday sunspot_with_missing temperature_rain_with_missing covid_deaths hospital rideshare_with_missing traffic_weekly traffic_hourly fred_md car_parts_with_missing electricity_weekly electricity_hourly solar_weekly solar_10_minutes nn5_weekly nn5_daily_with_missing weather kdd_cup_2018_with_missing vehicle_trips_with_missing pedestrian_counts bitcoin_with_missing dominick australian_electricity_demand cif_2016_12 cif_2016_6 tourism_monthly tourism_quarterly m4_hourly m4_daily m4_weekly m4_monthly monash_m3_other monash_m3_monthly m1_monthly m1_yearly monash_m3_yearly m4_yearly tourism_yearly m1_quarterly monash_m3_quarterly m4_quarterly kaggle_web_traffic_weekly kaggle_web_traffic_with_missing bitcoin"

for dataset in ${datasets}; do
    python run_chronos.py \
        --model_path="${model_path}" \
        --dataset="${dataset}" \
        --run_name="chronos-${model_size}"
done
6 changes: 6 additions & 0 deletions project/benchmarks/chronos_scripts/monash_chronos_small.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Run Chronos (small) on the full Monash benchmark suite.
model_size=small
model_path="amazon/chronos-t5-${model_size}"

# All Monash datasets evaluated, one run per dataset.
datasets="us_births saugeenday sunspot_with_missing temperature_rain_with_missing covid_deaths hospital rideshare_with_missing traffic_weekly traffic_hourly fred_md car_parts_with_missing electricity_weekly electricity_hourly solar_weekly solar_10_minutes nn5_weekly nn5_daily_with_missing weather kdd_cup_2018_with_missing vehicle_trips_with_missing pedestrian_counts bitcoin_with_missing dominick australian_electricity_demand cif_2016_12 cif_2016_6 tourism_monthly tourism_quarterly m4_hourly m4_daily m4_weekly m4_monthly monash_m3_other monash_m3_monthly m1_monthly m1_yearly monash_m3_yearly m4_yearly tourism_yearly m1_quarterly monash_m3_quarterly m4_quarterly kaggle_web_traffic_weekly kaggle_web_traffic_with_missing bitcoin"

for dataset in ${datasets}; do
    python run_chronos.py \
        --model_path="${model_path}" \
        --dataset="${dataset}" \
        --run_name="chronos-${model_size}"
done
6 changes: 6 additions & 0 deletions project/benchmarks/chronos_scripts/monash_chronos_tiny.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Run Chronos (tiny) on the full Monash benchmark suite.
model_size=tiny
model_path="amazon/chronos-t5-${model_size}"

# All Monash datasets evaluated, one run per dataset.
datasets="us_births saugeenday sunspot_with_missing temperature_rain_with_missing covid_deaths hospital rideshare_with_missing traffic_weekly traffic_hourly fred_md car_parts_with_missing electricity_weekly electricity_hourly solar_weekly solar_10_minutes nn5_weekly nn5_daily_with_missing weather kdd_cup_2018_with_missing vehicle_trips_with_missing pedestrian_counts bitcoin_with_missing dominick australian_electricity_demand cif_2016_12 cif_2016_6 tourism_monthly tourism_quarterly m4_hourly m4_daily m4_weekly m4_monthly monash_m3_other monash_m3_monthly m1_monthly m1_yearly monash_m3_yearly m4_yearly tourism_yearly m1_quarterly monash_m3_quarterly m4_quarterly kaggle_web_traffic_weekly kaggle_web_traffic_with_missing bitcoin"

for dataset in ${datasets}; do
    python run_chronos.py \
        --model_path="${model_path}" \
        --dataset="${dataset}" \
        --run_name="chronos-${model_size}"
done
125 changes: 125 additions & 0 deletions project/benchmarks/run_chronos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import argparse
import os

import numpy as np
import torch
from chronos import ChronosPipeline
from gluonts.dataset.repository import get_dataset
from gluonts.dataset.split import split
from gluonts.ev.metrics import (
MAE,
MAPE,
MASE,
MSE,
MSIS,
ND,
NRMSE,
RMSE,
SMAPE,
MeanWeightedSumQuantileLoss,
)
from gluonts.itertools import batcher

# from gluonts.model.evaluation import evaluate_forecasts
from gluonts.model.forecast import SampleForecast
from tqdm.auto import tqdm

from uni2ts.eval_util.data import get_gluonts_test_dataset
from uni2ts.eval_util.evaluation import evaluate_forecasts
from uni2ts.eval_util.metrics import MedianMSE


def evaluate(pipeline, dataset, save_path, num_samples=20, batch_size=512):
    """Evaluate a Chronos pipeline on a GluonTS test dataset and save metrics.

    Generates ``num_samples`` forecast sample paths per series, wraps them as
    :class:`SampleForecast` objects, computes point and quantile metrics, and
    writes a single-row CSV to ``save_path``.

    Args:
        pipeline: a loaded ``ChronosPipeline`` used to produce forecasts.
        dataset: dataset name accepted by ``get_gluonts_test_dataset``.
        save_path: path of the CSV file the metrics row is written to.
        num_samples: number of sample paths drawn per forecast.
        batch_size: initial inference batch size; halved on each CUDA
            out-of-memory error until inference succeeds.

    Returns:
        A one-row ``DataFrame`` of metrics, indexed by the dataset name.

    Raises:
        RuntimeError: if inference still runs out of GPU memory at
            ``batch_size == 1``.
    """
    print("-" * 5, f"Evaluating {dataset}", "-" * 5)
    test_data, metadata = get_gluonts_test_dataset(dataset)
    prediction_length = metadata.prediction_length

    # Retry inference with progressively smaller batches on GPU OOM.
    while True:
        try:
            # Generate forecast samples
            forecast_samples = []
            for batch in tqdm(batcher(test_data.input, batch_size=batch_size)):
                context = [torch.tensor(entry["target"]) for entry in batch]
                forecast_samples.append(
                    pipeline.predict(
                        context,
                        prediction_length=prediction_length,
                        num_samples=num_samples,
                        limit_prediction_length=False,  # We disable the limit on prediction length.
                    ).numpy()
                )
            forecast_samples = np.concatenate(forecast_samples)
            break
        except torch.cuda.OutOfMemoryError:
            print(
                f"OutOfMemoryError at batch_size {batch_size}, reducing to {batch_size//2}"
            )
            batch_size //= 2
            # Fix: without this floor, repeated OOMs drive batch_size to 0 and
            # the retry loop never terminates.
            if batch_size < 1:
                raise RuntimeError(
                    "CUDA out of memory even at batch_size 1; cannot evaluate."
                )

    # Convert forecast samples into gluonts SampleForecast objects
    sample_forecasts = []
    for item, ts in zip(forecast_samples, test_data.input):
        # Forecast window starts immediately after the observed target.
        forecast_start_date = ts["start"] + len(ts["target"])
        sample_forecasts.append(
            SampleForecast(samples=item, start_date=forecast_start_date)
        )

    # Evaluate
    metrics_df = evaluate_forecasts(
        sample_forecasts,
        test_data=test_data,
        metrics=[
            MSE(),
            MAE(),
            MAPE(),
            SMAPE(),
            MSIS(),
            RMSE(),
            NRMSE(),
            ND(),
            MASE(),
            MedianMSE(),
            MeanWeightedSumQuantileLoss(np.arange(0.1, 1.0, 0.1)),
        ],
    )
    metrics_df.index = [dataset]
    print(metrics_df)
    metrics_df.to_csv(save_path)
    print(f"Results saved to {save_path}")
    print("-" * 5, f"Evaluation of {dataset} complete", "-" * 5)
    return metrics_df


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Load a model and dataset, then make predictions."
    )
    parser.add_argument(
        "--model_path", type=str, required=True, help="Path to load the model"
    )
    parser.add_argument(
        "--dataset", type=str, required=True, help="Name of the dataset to use"
    )
    parser.add_argument(
        "--save_dir", type=str, default="results", help="Directory to save the results"
    )
    parser.add_argument(
        "--num_samples", type=int, default=20, help="Number of samples to generate"
    )
    parser.add_argument(
        "--batch_size", type=int, default=512, help="Batch size for generating samples"
    )
    parser.add_argument("--run_name", type=str, default="test", help="Name of the run")

    args = parser.parse_args()
    # Load Chronos
    pipeline = ChronosPipeline.from_pretrained(
        args.model_path,
        device_map="cuda:0",
        torch_dtype=torch.bfloat16,
    )
    output_dir = os.path.join(args.save_dir, args.run_name)
    # exist_ok avoids the check-then-create race of the previous exists() test.
    os.makedirs(output_dir, exist_ok=True)
    # Fix: --num_samples and --batch_size were previously parsed but never
    # forwarded to evaluate(), so the CLI flags were silently ignored.
    evaluate(
        pipeline,
        args.dataset,
        os.path.join(output_dir, f"{args.dataset}.csv"),
        num_samples=args.num_samples,
        batch_size=args.batch_size,
    )

0 comments on commit 387903e

Please sign in to comment.