Skip to content

基于pytorch实现的时间序列预测训练框架,各个部分模块化,方便修改模型。包含时间序列预测模型、训练、验证、测试、wandb可视化、onnx导出、onnx推理、tensorrt导出、tensorrt推理。

Notifications You must be signed in to change notification settings

TWK2022/TimeSeriesForecasting

Repository files navigation

pytorch时间序列预测训练框架

代码兼容性较强,使用的是一些基本的库、基础的函数
在argparse中可以选择使用wandb,能在wandb网站中生成可视化的训练过程

1,环境

torch:https://pytorch.org/get-started/previous-versions/

pip install tqdm wandb -i https://pypi.tuna.tsinghua.edu.cn/simple

2,数据格式

参考dataset中的样例

3,run.py

模型训练时运行该文件,argparse中有对每个参数的说明

4,predict_pt.py

使用训练好的pt模型预测

5,export_onnx.py

将pt模型导出为onnx模型

6,predict_onnx.py

使用导出的onnx模型预测

7,export_trt_record

文档中有onnx模型导出为tensort模型的详细说明

8,predict_trt.py

使用导出的trt模型预测

其他

学习笔记:https://github.com/TWK2022/notebook


ETTh1.csv

模型(m) input_column output_column input_size output_size divide train_mse_decay val_mse
tsf all all 96 24 19:1 0.188 0.268
itransformer all all 96 24 19:1 0.219 0.260
nlinear all all 96 24 19:1 0.228 0.255
lstm all all 96 24 19:1 0.241 0.260
linear all all 96 24 19:1 0.247 0.267
pathformer all all 96 24 19:1 0.229 0.275
crossformer all all 96 24 19:1 0.258 0.278
diffusion_ts all all 96 24 19:1 0.212 0.330

About

基于pytorch实现的时间序列预测训练框架,各个部分模块化,方便修改模型。包含时间序列预测模型、训练、验证、测试、wandb可视化、onnx导出、onnx推理、tensorrt导出、tensorrt推理。

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages