转摘【时空序列预测第三篇】PredRNN: Recurrent Neural Networks for Predictive Learning using Spatiotemporal LSTMs
前言
接下来保持住节奏,每周起码一篇paper reading,要时刻了解研究的前沿,是一个不管是工程岗位还是研究岗位AIer必备的工作,共勉!
一、Address
这是nips2017年的一篇paper,来自于清华的团队
PredRNN: Recurrent Neural Networks for Predictive Learning using Spatiotemporal LSTMs
http://ise.thss.tsinghua.edu.cn/ml/doc/2017/predrnn-nips17.pdf
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200131205417502.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9jaGVob25nc2h1LmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70)
二、Introduction
2.1 创新思路
在Abstract中直接点名了本model的innovation,平时的时间和空间记忆都是在LSTM或者GRU cell中做文章,本paper的思路转移到stacked RNN layers中,即模型的堆叠结构中存在可以记忆的单元。
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200131205733396.png)
2.2 时间信息和空间信息
![在这里插入图片描述](https://img-blog.csdnimg.cn/2020013121021722.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9jaGVob25nc2h1LmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70)
文章在这里指出 时间信息和空间信息都是十分重要的。
并且在文章前面又是再一次的说2015-2017年左右的时空序列模型主要都集中在lstm的内部的memory的改造,并且主要集中于temporal的信息提取。
2.3 时空问题
这里作者又对时空序列问题进行一波定义和说明并且对施行建博士的开山之作ConvLSTM模型又进行介绍,这两个部分我都已介绍过了,请看我之前的文章。
[【时空序列预测第二篇】Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting](https://blog.csdn.net/qq_33431368/article/details/100053949)
但是这篇文章中指出了这样一个问题
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200131212749370.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9jaGVob25nc2h1LmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70)
加入有四层的ConvLSTM的一个encoding-forcasting结构,输入帧进入第一层,将来的视频序列产生在第四层,在这个过程中,空间维度随着每层的cnn结构被逐步编码,而时间维度的memory cells属于彼此独立,在每个时间步被更新,这种情况下,最底层就会忽略之前的时间步中的最高层的时间信息,这也是ConvLSTM的层与层之间独立mermory mechanism的缺点。
实际上简单点说,就是这种简单的并行stacked结构中,堆叠之后层与层之间是独立的,t时刻的最底层cell会忽略到t-1时刻的最顶层cell的时间信息。
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200131221427119.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9jaGVob25nc2h1LmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70)
其实强调的就是对应色调的cell之间没有时间信息联系
三、PredRNN
3.1 Spatiotemporal memory flow
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200131221616256.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9jaGVob25nc2h1LmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70)
对于研究时空序列预测问题,network的 basic building blocks一般先采用ConvLSTM进行研究
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200201093157677.png)
这里再次强调结构是每一层每一层的extract,并且cell states只在 水平方向,其实说的就是每一层独立,c只在每一层的时间步传播,而 有一部分信息,这部分主要是空间信息只在 hidden state上向上传播。
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200201093413390.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9jaGVob25nc2h1LmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70)
![在这里插入图片描述](https://img-blog.csdnimg.cn/2020020109355269.png)
我们假设输入序列的信息应该是被保留的,我需要不同level cnn提取到的信息。
其实意思就是每一个输入,经过每一层网络结构有一个信息提取,这个提取到的最后的抽象信息,应该是需要保留给下一次第一层的输入的。
所以提出这样一个网络结构
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200201093840179.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9jaGVob25nc2h1LmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70)
其中M就是cell output,只是为了图中区别,标成了M。
此时的ConvLSTM公式为
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200201093933221.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9jaGVob25nc2h1LmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70)
原始的ConvLSTM公式为
![在这里插入图片描述](https://img-blog.csdnimg.cn/2020020109425958.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9jaGVob25nc2h1LmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70)
这里我用各种颜色的标注一下你就知道区别了,其实就是根据结构来改变的公式本身。
原始的ConvLSTM
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200201094430925.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9jaGVob25nc2h1LmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70)
输入的hidden state和cell output都是上一个时刻的
此时更改的结构:
![在这里插入图片描述](https://img-blog.csdnimg.cn/2020020109461418.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9jaGVob25nc2h1LmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70)
红色表示在非最底层时的单个网络cell的公式变换,输入的hidden state和cell output都是前一层的(L-1)
而公式中的紫色部分说明L=1的时候有特殊情况,即图中的折线top到bottom的传播部分
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200201094845552.png)
这幅图比较直观
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200201095121505.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9jaGVob25nc2h1LmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70)
3.2 Spatiotemporal LSTM
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200201095440977.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9jaGVob25nc2h1LmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70)
这里文中又指出上面提出的那出结构的一些缺点
- 去掉水平方向的时间流,会牺牲时间上的一致性,因为在同一层的不同时间没有时间流了。
- 记忆需要在遥远的状态之间流动更长的路径,更容易造成梯度消失。
所以引入了一个新的building blocks为ST-LSTM。我想大家可能不会陌生对于它。
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200201100136149.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9jaGVob25nc2h1LmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70)
我们可以转换成更为肉眼所理解的图。
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200201100241598.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9jaGVob25nc2h1LmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70)
你会惊奇的发现这上下其实是完全一样的。
我们再来看一下LSTM的结构。
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200201101205939.png)
你细品,发现没,其实这两个完全一样的结构就是LSTM,只是下面的cell output和hidden state都由M代替了,其他的输出部分其实就相当于把两个LSTM结构的输出整合在一起分别输出计算了,我这里自己标了一下供大家来观察。
文中把上半部分称为'Standard Temporal Memory',下半部分称为'Spatiotemporal Memory',上半部分和普通的LSTM 没有任何区别,下半部分相当于把c和h一起更改为M,M即时空记忆状态。
(下图来自于https://blog.csdn.net/The_lastest/article/details/88230959,感谢这位朋友的付出)
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200201102041612.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9jaGVob25nc2h1LmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70)
ST-LSTM的公式![在这里插入图片描述](https://img-blog.csdnimg.cn/2020020110220034.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9jaGVob25nc2h1LmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70)
现在这图结构图迎刃而解了。
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200201102557190.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9jaGVob25nc2h1LmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70)
相当于在原始的基础结构上,多了一个M 状态,用M状态进行折线连接上一时刻top层的信息流入到此时刻的bottom层。并且在垂直方向引入M状态。其实你再仔细看看,这个结构其实就相当于把前面3.1(left)所讲的结构中的两个状态整合在一起成为一个状态M,之后把这个结构和3.1(right)进行整合,最后得到上图的结构,不同的是这里用一个ST-LSTM巧妙的解决了这个问题。
四、 Experiments
这里只简单说明下Moving MNIST dataset数据集和雷达数据集的结果(可以和上一篇对比)
对应的训练参数
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200201104534918.png)
4.1 Moving MNIST dataset
Moving MNIST dataset数据集不再过多介绍,请看我第二篇时空序列文章。
这里与之前那篇不同的在于数据集的玩法,这里清华团队是自己随机生成train数据集,而test是固定的
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200201152504141.png)
对运动的数字给一个速度,和随机的方向, 这个方向是单位圆也就是360°等分的一个角度,之后运动的振幅在3到5之间,并且存在两个数字的位置有覆盖的情况,故理论上可以生成无线数量的训练数据集。
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200201152745185.png)
作者测试集的用法是每次挑选训练数据中,也就是除去与随机生成的训练数据集中相同的样本以外的测试数据集作为最终的测试数据集。
并且用两个数字的训练集训练的模型去预测图中有三个数字的测试集,这也是ConvLSTM中同样用到的测试方法,无非是想测试模型的泛化性和迁移性 。
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200201153118788.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9jaGVob25nc2h1LmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70)
可以看到ST-LSTM的PredRNN的效果最好,这里给出的参数最好表现是128的hidden state 维度和4层的stacked结构
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200201153622951.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9jaGVob25nc2h1LmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70)
几个模型的结果,很直观的可以看到对于数字没有重叠的情况下,PredRNN与VPN baseline效果差不多,但是在有重叠的情况下,VPN baseline把8预测成了3,文中把这种预测的情况叫成 sharp,说明VPN baseline模型对于复杂的情况还是没法很好的预测,并且整体的模型都是对于长时间的预测随着时间步的越来越长,变得越来越模糊。
4.2 Radar echo dataset
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200201154211327.png)
这里的这个总结我十分赞同且准确,对于雷达数据集的最难的地方就在于它没有所谓的明显的周期性,并且移动的速度也是不固定的,变换也不是极具严格的,比如Moving MNIST dataset数据集运动的对象是数字,这个数字本身空间的信息基本上是不变的,这个和识别问题类似,而雷达数据集会因为各种天气原因,慢慢的积累、消散或变化,或者快速的积累、消散或变化,所以预测问题也是十分艰难的,其实本身数据还有着大量的噪声,因为地形等因素造成的。
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200201154644359.png)
这里其实在数据准备阶段说的较施行建博士的文章要相对清晰,10000个连续的雷达数据,每6分钟一个,转换成图片并压缩到100✖100大小,切片序列为20,输入10,输出10, 总共9600个序列,其中随机分到7800为训练集,1800为测试集,这个方法在时空序列预测问题上很常见,基本上的baseline的代码都有这个步骤,如果自己处理整体连续数据的话。
![在这里插入图片描述](https://img-blog.csdnimg.cn/20200201155850769.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9jaGVob25nc2h1LmJsb2cuY3Nkbi5uZXQ=,size_16,color_FFFFFF,t_70)
因为降水预测需要实时性,所以这里把训练速度以及占用的内存全都列出来了。
可以很直观的看出predrnn的效果确实要较ConvLSTM和VPN baseline要好很多。并且运行速度也不是特别慢(VPN就很慢,因为它的预测是递归的,预测下一个时刻,之后再利用预测下一时刻的去预测下下一时刻,比较耗时)
五、Conclusions
- 提出了一个新的端到端结构PredRNN
- 提出了新的LSTM结构,ST-LSTM,并作为PredRNN的basic building blocks
- 得到了最好的结果在时空序列预测数据集以及问题上
又不知不觉,码了8k多字,不为了别的,就为了简单、通俗、易懂、全面,共勉。
===========================
【来源: CSDN】
【作者: AI蜗牛车】
【原文链接】 https://chehongshu.blog.csdn.net/article/details/104128016
声明:转载此文是出于传递更多信息之目的。若有来源标注错误或侵犯了您的合法权益,请作者持权属证明与本网联系,我们将及时更正、删除,谢谢。