Merge pull request #32 from microsoft/fm_bench
Add quick guides for benchmarking foundation models
imJiawen authored Dec 13, 2024
2 parents ce62aff + 188a605 commit 11eb9be
Showing 61 changed files with 796 additions and 138 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -164,4 +164,6 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

1.sh
1.sh
log/
.vscode/
89 changes: 45 additions & 44 deletions README.md
@@ -7,6 +7,8 @@

## News :tada:

:triangular_flag_on_post: **Dec 2024**: Added quick guides for benchmarking foundation models. Visit [this page](./docs/benchmark/foundation_model/README.md) for detailed instructions.

:triangular_flag_on_post: **Oct 2024**: ProbTS now includes the ElasTST model! Check out the [ElasTST branch](https://github.com/microsoft/ProbTS/tree/elastst) to reproduce all results reported in paper or run `bash scripts/run_elastst.sh` for a quick start.

:triangular_flag_on_post: **Oct 2024**: The [camera-ready version](https://arxiv.org/abs/2310.07446) of ProbTS is now available, with more in-depth analyses on the impact of normalization.
@@ -16,9 +18,9 @@
A wide range of industrial applications desire precise point and distributional forecasting for diverse prediction horizons. ProbTS serves as a benchmarking tool to aid in understanding how advanced time-series models fulfill these essential forecasting needs. It also sheds light on their advantages and disadvantages in addressing different challenges and unveil the possibilities for future research.

To achieve these objectives, ProbTS provides a unified pipeline that implements [cutting-edge models](#-available-models) from different research threads, including:
- Long-term point forecasting approaches, such as [PatchTST](https://arxiv.org/abs/2211.14730), [iTransformer](https://arxiv.org/abs/2310.06625), etc.
- Short-term probabilistic forecasting methods, such as [TimeGrad](https://arxiv.org/abs/2101.12072), [CSDI](https://arxiv.org/abs/2107.03502), etc.
- Recent time-series foundation models for universal forecasting, such as [TimesFM](https://arxiv.org/abs/2310.10688), [MOIRAI](https://arxiv.org/abs/2402.02592), etc.
- Supervised long-term point forecasting models, such as [PatchTST](https://arxiv.org/abs/2211.14730), [iTransformer](https://arxiv.org/abs/2310.06625), etc.
- Supervised short-term probabilistic forecasting models, such as [TimeGrad](https://arxiv.org/abs/2101.12072), [CSDI](https://arxiv.org/abs/2107.03502), etc.
- Pre-trained time-series foundation models for zero-shot forecasting, such as [TimesFM](https://arxiv.org/abs/2310.10688), [MOIRAI](https://arxiv.org/abs/2402.02592), etc.

Specifically, ProbTS emphasizes the differences in their primary methodological designs, including:
- Supporting point or distributional forecasts
@@ -37,36 +39,36 @@ ProbTS includes both classical time-series models, specializing in long-term poi
| **Model** | **Original Eval. Horizon** | **Estimation** | **Decoding Scheme** | **Class Path** |
| --- | --- | --- | --- | --- |
| Linear | - | Point | Auto / Non-auto | `probts.model.forecaster.point_forecaster.LinearForecaster` |
| [GRU](https://arxiv.org/abs/1412.3555) | - | Point | Auto / Non-auto | `probts.model.forecaster.point_forecaster.GRUForecaster` |
| [Transformer](https://arxiv.org/abs/1706.03762) | - | Point | Auto / Non-auto | `probts.model.forecaster.point_forecaster.TransformerForecaster` |
| [Autoformer](https://arxiv.org/abs/2106.13008) | Long-trem | Point | Non-auto | `probts.model.forecaster.point_forecaster.Autoformer` |
| [N-HiTS](https://arxiv.org/abs/2201.12886) | Long-trem | Point | Non-auto | `probts.model.forecaster.point_forecaster.NHiTS` |
| [NLinear](https://arxiv.org/abs/2205.13504) | Long-trem | Point | Non-auto | `probts.model.forecaster.point_forecaster.NLinear` |
| [DLinear](https://arxiv.org/abs/2205.13504) | Long-trem | Point | Non-auto | `probts.model.forecaster.point_forecaster.DLinear` |
| [TSMixer](https://arxiv.org/abs/2303.06053) | Long-trem | Point | Non-auto | `probts.model.forecaster.point_forecaster.TSMixer` |
| [TimesNet](https://arxiv.org/abs/2210.02186) | Short- / Long-term | Point | Non-auto | `probts.model.forecaster.point_forecaster.TimesNet` |
| [PatchTST](https://arxiv.org/abs/2211.14730) | Long-trem | Point | Non-auto | `probts.model.forecaster.point_forecaster.PatchTST` |
| [iTransformer](https://arxiv.org/abs/2310.06625) | Long-trem | Point | Non-auto | `probts.model.forecaster.point_forecaster.iTransformer` |
| [ElasTST](https://arxiv.org/abs/2411.01842) | Long-trem | Point | Non-auto | `probts.model.forecaster.point_forecaster.ElasTST` |
| [GRU NVP](https://arxiv.org/abs/2002.06103) | Short-term | Probabilistic | Auto | `probts.model.forecaster.prob_forecaster.GRU_NVP` |
| [GRU MAF](https://arxiv.org/abs/2002.06103) | Short-term | Probabilistic | Auto | `probts.model.forecaster.prob_forecaster.GRU_MAF` |
| [Trans MAF](https://arxiv.org/abs/2002.06103) | Short-term | Probabilistic | Auto | `probts.model.forecaster.prob_forecaster.Trans_MAF` |
| [TimeGrad](https://arxiv.org/abs/2101.12072) | Short-term | Probabilistic | Auto | `probts.model.forecaster.prob_forecaster.TimeGrad` |
| [CSDI](https://arxiv.org/abs/2107.03502) | Short-term | Probabilistic | Non-auto | `probts.model.forecaster.prob_forecaster.CSDI` |
| [TSDiff](https://arxiv.org/abs/2307.11494) | Short-term | Probabilistic | Non-auto | `probts.model.forecaster.prob_forecaster.TSDiffCond` |

### Fundation Models
| [GRU](https://arxiv.org/abs/1412.3555) | - | Point | AR / NAR | `probts.model.forecaster.point_forecaster.GRUForecaster` |
| [Transformer](https://arxiv.org/abs/1706.03762) | - | Point | AR / NAR | `probts.model.forecaster.point_forecaster.TransformerForecaster` |
| [Autoformer](https://arxiv.org/abs/2106.13008) | Long | Point | NAR | `probts.model.forecaster.point_forecaster.Autoformer` |
| [N-HiTS](https://arxiv.org/abs/2201.12886) | Long | Point | NAR | `probts.model.forecaster.point_forecaster.NHiTS` |
| [NLinear](https://arxiv.org/abs/2205.13504) | Long | Point | NAR | `probts.model.forecaster.point_forecaster.NLinear` |
| [DLinear](https://arxiv.org/abs/2205.13504) | Long | Point | NAR | `probts.model.forecaster.point_forecaster.DLinear` |
| [TSMixer](https://arxiv.org/abs/2303.06053) | Long | Point | NAR | `probts.model.forecaster.point_forecaster.TSMixer` |
| [TimesNet](https://arxiv.org/abs/2210.02186) | Short / Long | Point | NAR | `probts.model.forecaster.point_forecaster.TimesNet` |
| [PatchTST](https://arxiv.org/abs/2211.14730) | Long | Point | NAR | `probts.model.forecaster.point_forecaster.PatchTST` |
| [iTransformer](https://arxiv.org/abs/2310.06625) | Long | Point | NAR | `probts.model.forecaster.point_forecaster.iTransformer` |
| [ElasTST](https://arxiv.org/abs/2411.01842) | Long | Point | NAR | `probts.model.forecaster.point_forecaster.ElasTST` |
| [GRU NVP](https://arxiv.org/abs/2002.06103) | Short | Probabilistic | AR | `probts.model.forecaster.prob_forecaster.GRU_NVP` |
| [GRU MAF](https://arxiv.org/abs/2002.06103) | Short | Probabilistic | AR | `probts.model.forecaster.prob_forecaster.GRU_MAF` |
| [Trans MAF](https://arxiv.org/abs/2002.06103) | Short | Probabilistic | AR | `probts.model.forecaster.prob_forecaster.Trans_MAF` |
| [TimeGrad](https://arxiv.org/abs/2101.12072) | Short | Probabilistic | AR | `probts.model.forecaster.prob_forecaster.TimeGrad` |
| [CSDI](https://arxiv.org/abs/2107.03502) | Short | Probabilistic | NAR | `probts.model.forecaster.prob_forecaster.CSDI` |
| [TSDiff](https://arxiv.org/abs/2307.11494) | Short | Probabilistic | NAR | `probts.model.forecaster.prob_forecaster.TSDiffCond` |

### Foundation Models

| **Model** | **Any Horizon** | **Estimation** | **Decoding Scheme** | **Class Path** |
| --- | --- | --- | --- | --- |
| [Lag-Llama](https://arxiv.org/abs/2310.08278) | ✔ | Probabilistic | Auto | `probts.model.forecaster.prob_forecaster.LagLlama` |
| [ForecastPFN](https://arxiv.org/abs/2311.01933) | ✔ | Point | Non-auto | `probts.model.forecaster.point_forecaster.ForecastPFN` |
| [TimesFM](https://arxiv.org/abs/2310.10688) | ✔ | Point | Auto | `probts.model.forecaster.point_forecaster.TimesFM` |
| [TTM](https://arxiv.org/abs/2401.03955) | ✘ | Point | Non-auto | `probts.model.forecaster.point_forecaster.TinyTimeMixer` |
| [Timer](https://arxiv.org/abs/2402.02368) | ✔ | Point | Auto | `probts.model.forecaster.point_forecaster.Timer` |
| [MOIRAI](https://arxiv.org/abs/2402.02592) | ✔ | Probabilistic | Non-auto | `probts.model.forecaster.prob_forecaster.Moirai` |
| [UniTS](https://arxiv.org/abs/2403.00131) | ✔ | Point | Non-auto | `probts.model.forecaster.point_forecaster.UniTS` |
| [Chronos](https://arxiv.org/abs/2403.07815) | ✔ | Probabilistic | Auto | `probts.model.forecaster.prob_forecaster.Chronos` |
| [Lag-Llama](https://arxiv.org/abs/2310.08278) | ✔ | Probabilistic | AR | `probts.model.forecaster.prob_forecaster.LagLlama` |
| [ForecastPFN](https://arxiv.org/abs/2311.01933) | ✔ | Point | NAR | `probts.model.forecaster.point_forecaster.ForecastPFN` |
| [TimesFM](https://arxiv.org/abs/2310.10688) | ✔ | Point | AR | `probts.model.forecaster.point_forecaster.TimesFM` |
| [TTM](https://arxiv.org/abs/2401.03955) | ✘ | Point | NAR | `probts.model.forecaster.point_forecaster.TinyTimeMixer` |
| [Timer](https://arxiv.org/abs/2402.02368) | ✔ | Point | AR | `probts.model.forecaster.point_forecaster.Timer` |
| [MOIRAI](https://arxiv.org/abs/2402.02592) | ✔ | Probabilistic | NAR | `probts.model.forecaster.prob_forecaster.Moirai` |
| [UniTS](https://arxiv.org/abs/2403.00131) | ✔ | Point | NAR | `probts.model.forecaster.point_forecaster.UniTS` |
| [Chronos](https://arxiv.org/abs/2403.07815) | ✔ | Probabilistic | AR | `probts.model.forecaster.prob_forecaster.Chronos` |

Stay tuned for more models to be added in the future.
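
The **Class Path** column in the tables above holds dotted Python paths. As a minimal sketch of how such a string can be turned into a class, the snippet below uses `importlib` from the standard library. The helper name `resolve_class_path` is hypothetical, and the diff does not show how ProbTS itself resolves these paths, so treat this as an illustration of the general technique rather than the project's loader. A standard-library class stands in for a ProbTS model so the example runs without the package installed.

```python
import importlib


def resolve_class_path(class_path: str):
    """Import the module part of a dotted path and return the named attribute.

    Hypothetical helper for illustration; not part of ProbTS.
    """
    module_name, _, attr_name = class_path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'package.module.ClassName', got {class_path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


# Stand-in for a path such as
# "probts.model.forecaster.point_forecaster.PatchTST" (requires ProbTS installed).
cls = resolve_class_path("collections.OrderedDict")
print(cls.__module__, cls.__name__)
```

With ProbTS installed, the same call should work for any entry in the tables, provided the class is importable from the module named in its path.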

@@ -87,8 +89,13 @@ pip install .
pip uninstall -y probts # recommended to uninstall the root package (optional)
```

[Optional] For time-series foundation models, you need to install basic packages and additional dependencies:
<details>

<summary>Optional for TSFMs reproducibility</summary>

For time-series foundation models, you need to install basic packages and additional dependencies:

**1. Set Up Environment**
```bash
# Create a new conda environment
conda create -n probts_fm python=3.10
@@ -100,17 +107,14 @@ git submodule update --init --recursive
# Install additional packages for foundation models
pip install ".[tsfm]"
pip uninstall -y probts # recommended to uninstall the root package (optional)
```

**2. Initialize Submodules**
```bash
# For MOIRAI, we fix the version of the package for better performance
cd submodules/uni2ts
git reset --hard fce6a6f57bc3bc1a57c7feb3abc6c7eb2f264301
```

<details>

<summary>Optional for TSFMs reproducibility</summary>

```bash
# For TimesFM, fix the version for reproducibility (optional)
cd submodules/timesfm
git reset --hard 5c7b905
@@ -128,14 +132,12 @@ git reset --hard bb125c14a05e4231636d6b64f8951d5fe96da1dc

### Datasets

For a complete dataset list, refer to the [Datasets Overview](./docs/documentation/README.md#datasets-overview).

- **Short-Term Forecasting**: We use datasets from [GluonTS](https://github.com/awslabs/gluonts).
Configure the datasets using `--data.data_manager.init_args.dataset {DATASET_NAME}`. You can choose from multivariate or univariate datasets as per your requirement.
```bash
# Multivariate Datasets
['exchange_rate_nips', 'electricity_nips', 'traffic_nips', 'solar_nips', 'wiki2000_nips']

# Univariate Datasets
['tourism_monthly', 'tourism_quarterly', 'tourism_yearly', 'm4_hourly', 'm4_daily', 'm4_weekly', 'm4_monthly', 'm4_quarterly', 'm4_yearly', 'm5']
```

- **Long-Term Forecasting**: To download the [long-term forecasting datasets](https://drive.google.com/drive/folders/1ZOYpTUa82_jCcxIdTmyr0LXQfvaM9vIy), please follow these steps:
@@ -145,7 +147,6 @@ git reset --hard bb125c14a05e4231636d6b64f8951d5fe96da1dc

Configure the datasets using `--data.data_manager.init_args.dataset {DATASET_NAME}` with the following list of available datasets:
```bash
# Long-term Forecasting
['etth1', 'etth2','ettm1','ettm2','traffic_ltsf', 'electricity_ltsf', 'exchange_ltsf', 'illness_ltsf', 'weather_ltsf', 'caiso', 'nordpool']
```
*Note: When utilizing long-term forecasting datasets, you must explicitly specify the `context_length` and `prediction_length` parameters. For example, to set a context length of 96 and a prediction length of 192, use the following command-line arguments:*
@@ -160,7 +161,7 @@ git reset --hard bb125c14a05e4231636d6b64f8951d5fe96da1dc
- Navigate to the target dataset, such as the [Electricity Hourly Dataset](https://zenodo.org/records/4656140).
- Download the `.tsf` file and place it in your local `datasets` directory (e.g., `./datasets`).

2. **Configure the Dataset**:
1. **Configure the Dataset**:
- Use the following configuration to specify the dataset, file path, and frequency:
```bash
--data.data_manager.init_args.dataset {DATASET_NAME} \
@@ -346,4 +347,4 @@ If you have used ProbTS for research or production, please cite it as follows.
booktitle={NeurIPS Datasets and Benchmarks Track},
year={2024}
}
```
```
11 changes: 11 additions & 0 deletions config/default/patchtst.yaml
@@ -23,6 +23,17 @@ model:
fc_dropout: 0.2
head_dropout: 0
individual: false
optimizer_config:
class_name: torch.optim.Adam
init_args:
weight_decay: 0
lr_scheduler_config:
class_name: torch.optim.lr_scheduler.OneCycleLR
init_args:
max_lr: 0.0001
steps_per_epoch: 100
pct_start: 0.3
epochs: 50
learning_rate: 0.0001
quantiles_num: 20
data:
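
The hunk above adds an Adam optimizer and a `torch.optim.lr_scheduler.OneCycleLR` scheduler with `max_lr: 0.0001`, `steps_per_epoch: 100`, `pct_start: 0.3`, and `epochs: 50`. To make the resulting learning-rate curve concrete, here is a dependency-free sketch of a two-phase, cosine-annealed one-cycle schedule. The function name `one_cycle_lr` is hypothetical. The values `div_factor=25` and `final_div_factor=1e4` are assumptions based on PyTorch's documented defaults, since the config does not set them. Whether ProbTS adjusts `steps_per_epoch` at runtime is not visible in this diff.

```python
import math


def one_cycle_lr(step, max_lr=1e-4, epochs=50, steps_per_epoch=100,
                 pct_start=0.3, div_factor=25.0, final_div_factor=1e4):
    """Learning rate at 0-based optimizer step `step` of a one-cycle schedule.

    Sketch only: two phases (warm-up, then decay), both cosine-annealed.
    div_factor and final_div_factor are assumed defaults, not config values.
    """
    total_steps = epochs * steps_per_epoch
    if not 0 <= step < total_steps:
        raise ValueError(f"step must be in [0, {total_steps - 1}], got {step}")

    initial_lr = max_lr / div_factor          # rate at step 0
    min_lr = initial_lr / final_div_factor    # rate at the final step
    warmup_end = pct_start * total_steps - 1  # last step of the warm-up phase
    last_step = total_steps - 1

    def cosine(start, end, pct):
        # Moves from `start` (pct=0) to `end` (pct=1) along half a cosine wave.
        return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1.0)

    if step <= warmup_end:
        return cosine(initial_lr, max_lr, step / warmup_end)
    return cosine(max_lr, min_lr, (step - warmup_end) / (last_step - warmup_end))


# 50 epochs * 100 steps per epoch = 5000 scheduler steps in total.
schedule = [one_cycle_lr(s) for s in range(50 * 100)]
peak_step = max(range(len(schedule)), key=schedule.__getitem__)
print(f"start={schedule[0]:.2e}  peak={schedule[peak_step]:.2e} "
      f"at step {peak_step}  end={schedule[-1]:.2e}")
```

Under these assumptions the rate warms up over the first 30% of the 5000 steps, peaks at `max_lr`, and then decays for the remaining 70%. The separate `learning_rate: 0.0001` entry matches `max_lr`, so the peak of the cycle coincides with the previously configured constant rate.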
1 change: 1 addition & 0 deletions config/tsfm/moirai.yaml
@@ -26,6 +26,7 @@ data:
dataset: solar_nips
split_val: true
scaler: identity # identity, standard, temporal
auto_search: true
batch_size: 64
test_batch_size: 64
num_workers: 8
1 change: 1 addition & 0 deletions config/tsfm/moirai/context_5000/electricity_ltsf.yaml
@@ -28,6 +28,7 @@ data:
scaler: standard # identity, standard, temporal
var_specific_norm: false
context_length: 5000
auto_search: true
batch_size: 1
test_batch_size: 1
num_workers: 8
1 change: 1 addition & 0 deletions config/tsfm/moirai/context_5000/electricity_nips.yaml
@@ -28,6 +28,7 @@ data:
scaler: standard # identity, standard, temporal
var_specific_norm: true
context_length: 3800 # maximum history length
auto_search: true
batch_size: 1
test_batch_size: 1
num_workers: 8
1 change: 1 addition & 0 deletions config/tsfm/moirai/context_5000/etth1.yaml
@@ -27,6 +27,7 @@ data:
split_val: true
scaler: standard # identity, standard, temporal
context_length: 5000
auto_search: true
batch_size: 64
test_batch_size: 64
num_workers: 8
1 change: 1 addition & 0 deletions config/tsfm/moirai/context_5000/etth2.yaml
@@ -27,6 +27,7 @@ data:
split_val: true
scaler: standard # identity, standard, temporal
context_length: 5000
auto_search: true
batch_size: 64
test_batch_size: 64
num_workers: 8
1 change: 1 addition & 0 deletions config/tsfm/moirai/context_5000/ettm1.yaml
@@ -27,6 +27,7 @@ data:
split_val: true
scaler: standard # identity, standard, temporal
context_length: 5000
auto_search: true
batch_size: 64
test_batch_size: 64
num_workers: 8
1 change: 1 addition & 0 deletions config/tsfm/moirai/context_5000/ettm2.yaml
@@ -27,6 +27,7 @@ data:
split_val: true
scaler: standard # identity, standard, temporal
context_length: 5000
auto_search: true
batch_size: 64
test_batch_size: 64
num_workers: 8
1 change: 1 addition & 0 deletions config/tsfm/moirai/context_5000/exchange_rate_nips.yaml
@@ -28,6 +28,7 @@ data:
scaler: standard # identity, standard, temporal
var_specific_norm: false
context_length: 5000
auto_search: true
batch_size: 1
test_batch_size: 1
num_workers: 8
1 change: 1 addition & 0 deletions config/tsfm/moirai/context_5000/solar_nips.yaml
@@ -28,6 +28,7 @@ data:
scaler: identity # identity, standard, temporal
var_specific_norm: false
context_length: 5000
auto_search: true
batch_size: 1
test_batch_size: 1
num_workers: 8
1 change: 1 addition & 0 deletions config/tsfm/moirai/context_5000/weather_ltsf.yaml
@@ -28,6 +28,7 @@ data:
scaler: standard # identity, standard, temporal
var_specific_norm: true
context_length: 5000
auto_search: true
batch_size: 64
test_batch_size: 64
num_workers: 8
1 change: 1 addition & 0 deletions config/tsfm/moirai/context_96/electricity_ltsf.yaml
@@ -28,6 +28,7 @@ data:
scaler: standard # identity, standard, temporal
var_specific_norm: false
context_length: 96
auto_search: true
batch_size: 4
test_batch_size: 4
num_workers: 8
1 change: 1 addition & 0 deletions config/tsfm/moirai/context_96/electricity_nips.yaml
@@ -28,6 +28,7 @@ data:
scaler: standard # identity, standard, temporal
var_specific_norm: true
context_length: 96
auto_search: true
batch_size: 1
test_batch_size: 1
num_workers: 8
1 change: 1 addition & 0 deletions config/tsfm/moirai/context_96/etth1.yaml
@@ -28,6 +28,7 @@ data:
scaler: standard # identity, standard, temporal
var_specific_norm: false
context_length: 96
auto_search: true
batch_size: 64
test_batch_size: 64
num_workers: 8
1 change: 1 addition & 0 deletions config/tsfm/moirai/context_96/etth2.yaml
@@ -28,6 +28,7 @@ data:
scaler: standard # identity, standard, temporal
var_specific_norm: false
context_length: 96
auto_search: true
batch_size: 64
test_batch_size: 64
num_workers: 8
1 change: 1 addition & 0 deletions config/tsfm/moirai/context_96/ettm1.yaml
@@ -28,6 +28,7 @@ data:
scaler: standard # identity, standard, temporal
var_specific_norm: false
context_length: 96
auto_search: true
batch_size: 64
test_batch_size: 64
num_workers: 8
1 change: 1 addition & 0 deletions config/tsfm/moirai/context_96/ettm2.yaml
@@ -28,6 +28,7 @@ data:
scaler: standard # identity, standard, temporal
var_specific_norm: false
context_length: 96
auto_search: true
batch_size: 64
test_batch_size: 64
num_workers: 8
1 change: 1 addition & 0 deletions config/tsfm/moirai/context_96/exchange_rate_nips.yaml
@@ -28,6 +28,7 @@ data:
scaler: standard # identity, standard, temporal
var_specific_norm: true
context_length: 96
auto_search: true
batch_size: 1
test_batch_size: 1
num_workers: 8
1 change: 1 addition & 0 deletions config/tsfm/moirai/context_96/solar_nips.yaml
@@ -28,6 +28,7 @@ data:
scaler: identity # identity, standard, temporal
var_specific_norm: false
context_length: 96
auto_search: true
batch_size: 1
test_batch_size: 1
num_workers: 8
1 change: 1 addition & 0 deletions config/tsfm/moirai/context_96/weather_ltsf.yaml
@@ -28,6 +28,7 @@ data:
scaler: standard # identity, standard, temporal
var_specific_norm: true
context_length: 96
auto_search: true
batch_size: 64
test_batch_size: 64
num_workers: 8