transformer结构在近几年可谓大放异彩,这几天正好有个时间序列预测的小场景,所以动手试了一下transformer的时序预测能力。同时,为了作为对比,也使用了xgboost树模型对预测效果进行了对比。
数据准备
在数据方面,手动生成了具有一定周期和趋势的一维序列数据,数据形态和代码如下:
|
|
transformer时间序列模型
本文使用的transformer预测时序任务架构参考的是 https://arxiv.org/abs/2001.08317, 整体结构如下:
由于我们的数据是时序的,因此训练的时候需要对src和tgt做mask,同时tgt的输入和输出是一对一的关系。
在做推理时,我们往往只知道需要预测的tgt序列当中最早的一个数据,因此我们需要循环着做推理,即把上一时刻的预测值加入到tgt输入中去预测下一时刻的值,再把下一时刻的值加入tgt如此循环。
最终transformer时间序列预测的效果图如下,预测值mae大概0.122:
xgboost时间序列模型
作为对比,使用了传统的提升树模型xgboost训练了同样一组数据,简单的做了一些调参,预测值mae大概0.124,
指标上看xgboost只比transformer模型弱了一点点,但是从可视化的图片来看,tranformer模型的预测结果方差要比xgboost小一些,从数据上看也是这样,transformer模型的预测误差方差为0.084,而xgboost预测误差方差为0.106。
具体实现代码如下:
TimeSeriesTransformer类:
TransformerDataset类
|
|
训练和推理代码
xgboost训练代码: