长短期记忆网络(LSTM)的解析

长短期记忆网络(Long Short-Term Memory,简称LSTM)是一种特殊的递归神经网络(Recurrent Neural Network,简称RNN),它在处理序列数据时表现出色。LSTM网络之所以能够取得如此显著的成就,是因为它解决了传统RNN所面临的一些关键问题。在90年代,科学家们就已经注意到了RNN在实际应用中的一些障碍,其中最主要的是长期依赖问题和梯度消失/爆炸问题。

长期依赖问题

在之前的文章中,已经提到了长期依赖的概念。使用RNN的一个优势是可以将先前的信息与当前任务联系起来,从而进行预测。例如,可以尝试根据句子中的前一个词来预测下一个词。当需要查看先前的信息来执行特定任务时,RNN表现出色。然而,情况并非总是如此。

以语言建模中的文本预测为例。有时,可以根据前一个词来预测文本中的下一个词,但更多时候,需要的信息不仅仅是前一个词。需要上下文,需要句子中的更多词。在这方面,RNN的表现并不理想。需要的先前状态越多,对标准RNN来说问题就越大。

梯度消失/爆炸问题

可能已经知道,梯度表示权重相对于误差变化的变化。因此,如果不知道梯度,就无法正确地更新权重,从而减少误差。在每个人工神经网络中,都面临一个问题,即如果梯度非常小,它将阻止权重改变其值。例如,当多次应用sigmoid函数时,数据会被压平,直到没有可检测的斜率。数据被压平,直到在很长的范围内没有可检测的斜率。同样的事情也会发生在梯度上,因为它通过神经网络的许多层传递。

在RNN中,这个问题更加严重,因为在递归神经网络的学习过程中,通过时间更新它们。使用时间反向传播(Backpropagation Through Time),这个过程引入了比常规反向传播更多的乘法和操作。当然,这个梯度问题也可以朝另一个方向发展。梯度可能会变得越来越大,导致RNN中的蝴蝶效应,即所谓的梯度爆炸问题。

LSTM架构

为了解决递归神经网络所面临的问题,Hochreiter & Schmidhuber(1997)提出了长短期记忆网络的概念。它们在处理一系列问题时表现出色,并且现在非常流行。它们的结构与标准RNN的结构相似,意味着它们将输出信息反馈到输入,但这些网络不会遇到标准RNN所面临的问题。在它们的架构中,它们实现了一种机制来记住长期依赖。

如果观察标准RNN的展开表示,认为它是一个重复结构的链。这种结构通常相当简单,例如单个tanh层。尽管LSTM遵循类似的原则,但它们的架构要复杂得多。它们通常有四层,而不是单层,并且它们在四个层次上有更多的操作。单个LSTM单元的结构可以在下面的图像中看到:

正如在图像中看到的,有四个网络层。一层使用tanh函数,两层使用sigmoid函数。还有一些点对点操作和向量上的常规操作,例如连接。将在下一章中更详细地探讨这一点。

然而,需要注意的重要一点是上面用字母C标记的数据流。这个数据流在递归神经网络的正常流程之外保存信息,被称为细胞状态。本质上,使用细胞状态传输的数据就像计算机内存中的数据一样,因为数据可以存储在这个细胞状态中或从中读取。LSTM单元决定将哪些信息添加到细胞状态中,使用称为门的模拟机制。

这些门是通过sigmoid(范围0-1)的逐元素乘法实现的。它们类似于神经网络的节点。根据输入信号的强度,它们将该信息传递到细胞状态或阻止它。为此,它们有自己的一组权重,这些权重在训练过程中得到维护。这意味着LSTM单元被训练来过滤某些数据并保留误差,这些误差后来可以在反向传播中使用。通过这种机制,LSMT网络能够在许多时间步骤上学习。

LSTM处理过程

LSTM单元的操作分为几个步骤:

1. 从细胞状态中忘记不必要的信息

2. 向细胞状态添加信息

3. 计算输出

LSTM单元中的第一个sigmoid层被称为遗忘门层。使用这一层/门,LSTM单元决定细胞状态Ct-1中的前一个状态的重要性,并决定将什么移除。由于使用的是sigmoid函数,范围从0到1,所以数据可以完全移除、部分移除或完全保留。这个决定是通过查看前一个状态Ht-1和当前输入Xt来做出的。使用这些信息和sigmoid级别,LSTM单元为细胞状态Ct-1中的每个数字生成一个介于0和1之间的数字:

forget_gate = sigmoid(Wf * [Ht-1, Xt] + bf)

当不必要的信息被移除后,LSTM单元决定哪些数据将被添加到细胞状态中。这是通过第二个sigmoid和tanh层的组合来完成的。首先,第二个sigmoid层,也称为输入门层,决定细胞状态中的哪些值将被更新(输出i),然后tanh层创建一个可能添加到细胞状态的新候选值向量(输出~Ct)。

input_gate = sigmoid(Wi * [Ht-1, Xt] + bi) candidate_values = tanh(Wc * [Ht-1, Xt] + bc)

有效地,使用这些信息,计算细胞状态的新值Ct。

Ct = forget_gate * Ct-1 + input_gate * candidate_values

需要做的最后一步是计算LSTM单元的输出。这是通过第三个sigmoid级别和额外的tanh过滤器来完成的。输出值基于细胞状态中的值,但也通过sigmoid层进行过滤。基本上,sigmoid层决定细胞状态的哪些部分将影响输出值。然后将细胞状态值通过tanh过滤器(将所有值推到-1和1之间)并乘以第三个sigmoid级别的输出。

output_gate = sigmoid(Wo * [Ht-1, Xt] + bo) output_value = tanh(Ct) * output_gate
沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485