代码兼容性较强,使用的是一些基本的库、基础的函数
在argparse中可以选择使用wandb,能在wandb网站中生成可视化的训练过程
torch:https://pytorch.org/get-started/previous-versions/
pip install tqdm wandb -i https://pypi.tuna.tsinghua.edu.cn/simple
参考dataset中的样例
模型训练时运行该文件,argparse中有对每个参数的说明
使用训练好的pt模型预测
将pt模型导出为onnx模型
使用导出的onnx模型预测
文档中有onnx模型导出为tensort模型的详细说明
使用导出的trt模型预测
模型(m) | input_column | output_column | input_size | output_size | divide | train_mse_decay | val_mse |
---|---|---|---|---|---|---|---|
tsf | all | all | 96 | 24 | 19:1 | 0.188 | 0.268 |
itransformer | all | all | 96 | 24 | 19:1 | 0.219 | 0.260 |
nlinear | all | all | 96 | 24 | 19:1 | 0.228 | 0.255 |
lstm | all | all | 96 | 24 | 19:1 | 0.241 | 0.260 |
linear | all | all | 96 | 24 | 19:1 | 0.247 | 0.267 |
pathformer | all | all | 96 | 24 | 19:1 | 0.229 | 0.275 |
crossformer | all | all | 96 | 24 | 19:1 | 0.258 | 0.278 |
diffusion_ts | all | all | 96 | 24 | 19:1 | 0.212 | 0.330 |