在处理序列预测问题时,长短期记忆网络(Long Short Term Memory,简称LSTM)是一种特殊的递归神经网络(Recurrent Neural Network,简称RNN),它能够学习序列数据中的顺序依赖性。与传统RNN不同,LSTM能够解决长期依赖问题,即在长序列中保持信息的能力,使其在时间序列数据处理、预测和分类任务中表现出色。LSTM由Hochreiter和Schmidhuber提出,它的设计初衷是为了改善传统RNN在处理长序列时的性能下降问题。
LSTM网络由多个神经网络单元和记忆块组成,这些记忆块被称为“单元”,它们以链式结构排列。一个标准的LSTM单元包括一个单元、一个输入门、一个输出门和一个遗忘门。这三个门控制着信息流入和流出单元的过程,而单元则负责在任意时间间隔内保持值。LSTM算法非常适合对不确定时长的时间序列进行分类、分析和预测。
LSTM单元中的记忆块负责存储信息,而门则负责操作记忆。具体来说,有三个门控制着信息的流动:
输入门:它决定哪些输入值应该用来改变记忆。sigmoid函数决定是否允许0或1值通过,而tanh函数则为提供的数据分配权重,确定它们的重要性,范围在-1到1之间。
遗忘门:它确定应该从记忆中移除哪些细节。这也是由sigmoid函数决定的。对于单元状态Ct-1中的每个数字,它都会查看前一个状态(ht-1)和内容输入(Xt),并产生一个介于0(忽略这个)和1(保留这个)之间的数字。
输出门:它使用块的输入和记忆来确定输出。sigmoid函数决定是否允许0或1值通过,而tanh函数则确定哪些值被允许通过0,1。同时,tanh函数为提供的数据分配权重,确定它们的重要性,并将其与sigmoid输出相乘。
所有递归神经网络都由一系列重复的神经网络模块组成。在传统RNN中,这个重复模块可能结构简单,例如只有一个tanh层。当前时间步的输出成为下一个时间步的输入,这就是所谓的递归。在序列的每个元素中,模型不仅查看当前输入,还查看它对之前元素的了解。
LSTM的重复模块由四个相互作用的层组成。在图中顶部的水平线表示单元状态,对LSTM至关重要。单元状态在某些方面类似于传送带。当它沿着整个链移动时,只有一些微小的线性交互。数据可以很容易地沿着它向下移动而不受改变。
LSTM可以通过称为门的结构来删除或添加单元状态中的信息,这些门严格控制着信息的流动。门由sigmoid神经网络层和逐点乘法操作组成。sigmoid层产生从零到一的整数,指示应该允许多少每个组成部分通过。值为零表示“不允许任何东西通过”,而值为一表示“允许所有东西通过”。
LSTM的工作周期分为四个步骤:
1. 使用遗忘门,从之前的时间步中识别出要忘记的信息。
2. 使用输入门和tanh,寻找更新单元状态的新信息。
3. 使用上述两个门的信息来更新单元状态。
4. 使用输出门和压缩操作提供有用信息。
双向LSTM是LSTM的一个常见改进。在双向递归神经网络(BRNN)中,每个训练序列都以正向和反向呈现给两个独立的递归网络,这两个网络都与同一个输出层耦合。这意味着BRNN对给定序列中的每个点之前和之后的所有点都有全面的、顺序的知识。
传统RNN的缺点是只能使用之前的上下文。双向RNN(BRNN)通过以两种方式处理数据,并通过两个隐藏层向前馈送到同一个输出层来实现这一点。当BRNN和LSTM结合时,将得到一个可以访问两个输入方向的长距离上下文的双向LSTM。
LSTM有许多知名的应用,包括:
- 图像描述
- 机器翻译
- 语言建模
- 手写生成
- 问答聊天机器人
from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout, Embedding, Masking
model = Sequential()
model.add(Embedding(input_dim=num_words, input_length=training_length, output_dim=100, weights=[embedding_matrix], trainable=False, mask_zero=True))
model.add(Masking(mask_value=0.0))
model.add(LSTM(64, return_sequences=False, dropout=0.1, recurrent_dropout=0.1))
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_words, activation='softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])