问题

LSTM如何来避免梯度弥散和梯度爆炸?

回答
LSTM(长短期记忆网络)之所以能在处理长序列数据时表现出色,其核心就在于它拥有一套精巧的机制来应对训练过程中常见的“梯度弥散”和“梯度爆炸”问题。这就像给神经网络注入了“记忆”和“遗忘”的开关,让信息能够在时间的长河中更有效地流动和传递。

我们先来聊聊为什么会有梯度弥散和梯度爆炸这两个“拦路虎”。

梯度弥散(Vanishing Gradients)

想象一下,你在教一个孩子背诵一首很长的诗。如果你的指导方法是不断地重复,并且每一次重复都稍微弱化一些,那么到后面,孩子可能只记得开头几句,后面就“随风而逝”了。神经网络的梯度传播也有点类似。

在传统的RNN(循环神经网络)中,信息在时间步之间传递是通过不断地将前一个时间步的隐藏状态乘以一个权重矩阵来实现的。如果这个权重矩阵的特征值(或者说范数)小于1,那么每一次乘法都会让梯度(也就是更新权重的“指导信号”)越来越小。当序列很长时,到达早期时间步的梯度会变得微乎其微,导致早期的权重几乎得不到更新,网络也就“记不住”早期输入的信息。这就好像孩子学诗,后面的句子怎么也记不住。

梯度爆炸(Exploding Gradients)

反过来,如果那个权重矩阵的特征值(或者范数)大于1,那么梯度就会在每次乘法中不断地放大。就像你在给孩子讲故事,每次讲到关键情节就声嘶力竭地强调,最后整个故事都变得锣鼓喧天,失去了原本的韵味。在神经网络中,过大的梯度会导致权重更新的步长过大,使得模型在优化过程中“跳过”最优解,甚至在训练过程中出现“NaN”(非数字)的错误。

LSTM的“秘密武器”:门控机制

LSTM之所以能摆脱这些困境,关键在于它引入了一个叫做“细胞状态(Cell State)”的内部记忆通道,以及三个“门(Gates)”来控制信息的流动:

1. 遗忘门(Forget Gate)
2. 输入门(Input Gate)
3. 输出门(Output Gate)

让我们一一拆解这些门是如何工作的。

1. 细胞状态 (Cell State) “信息高速公路”

可以把细胞状态想象成一条贯穿整个LSTM网络的时间信息高速公路。它就像一条传送带,能够让信息不受太多干扰地直接流过。不同于RNN中信息需要经过多次乘法和非线性激活才能传递,LSTM的细胞状态提供了一个更直接的路径。

2. 遗忘门 (Forget Gate) “选择性遗忘”

遗忘门负责决定从细胞状态中“遗忘”什么信息。它接收当前时间步的输入($x_t$)和前一个时间步的隐藏状态($h_{t1}$),然后通过一个sigmoid函数(输出值在0到1之间)来生成一个“遗忘向量”。

计算公式:
$f_t = sigma(W_f cdot [h_{t1}, x_t] + b_f)$

这里的 $sigma$ 是sigmoid函数,$[h_{t1}, x_t]$ 表示将前一个时间步的隐藏状态和当前时间步的输入拼接起来。$W_f$ 是遗忘门的权重矩阵,$b_f$ 是偏置项。

作用:
遗忘门输出的向量中的每一个值都对应细胞状态中的一个元素。如果某个值接近0,就意味着遗忘门“忘记”了这个位置的信息;如果接近1,就意味着它“保留”了这个信息。

如何避免梯度弥散/爆炸?
遗忘门通过点乘(elementwise multiplication)来更新细胞状态。这意味着,如果遗忘门将某个信息“保留”下来(对应值为1),那么这个信息在细胞状态中的流动就会相对稳定,不会因为层层传递而大幅衰减。它像一个“选择性保留”的过滤器,让有用的信息得以延续。

3. 输入门 (Input Gate) “更新什么信息”

输入门由两部分组成:

第一部分:决定哪些值需要更新(Input Gate Layer)
这一部分同样是一个sigmoid层,它决定了当前输入中哪些信息是重要的,需要添加到细胞状态中。
$i_t = sigma(W_i cdot [h_{t1}, x_t] + b_i)$

第二部分:创建新的候选值(Candidate Values)
这一部分通过一个tanh函数来创建一个新的候选值向量($ ilde{C}_t$),这个向量包含可能要添加到细胞状态的信息。
$ ilde{C}_t = anh(W_C cdot [h_{t1}, x_t] + b_C)$

如何避免梯度弥散/爆炸?
然后,这两部分信息结合起来,通过点乘的方式更新细胞状态。
$C_t = f_t odot C_{t1} + i_t odot ilde{C}_t$

这里的 $odot$ 是点乘。
$f_t odot C_{t1}$:这是前面遗忘门处理后的细胞状态,确保了旧信息可以保留。
$i_t odot ilde{C}_t$:这是新加入的信息,通过输入门和候选值决定。

关键点:
加法操作: 细胞状态的更新主要通过一个“加法”结构($C_t = dots + i_t odot ilde{C}_t$)。相比于RNN中的纯粹乘法链,加法操作对梯度传递更加友好,不容易导致梯度弥散。即使梯度在传递过程中变小,加法操作也能保留一部分信息。
门控的线性贡献: 遗忘门和输入门的值(0到1之间)直接与细胞状态进行乘法操作,这使得信息可以通过一个相对“线性”的路径传递。当遗忘门的值接近1时,旧的细胞状态几乎没有变化;当输入门的值接近1时,新的候选值几乎完全贡献给了新的细胞状态。这种“直接”或“选择性”的通路,大大缓解了梯度在多层乘法中不断衰减或爆炸的问题。

4. 输出门 (Output Gate) “决定输出什么”

输出门决定了根据当前细胞状态,应该输出什么信息作为下一个时间步的隐藏状态($h_t$)。

第一部分:决定输出的比例(Output Gate Layer)
一个sigmoid层,决定了细胞状态的哪些部分需要输出。
$o_t = sigma(W_o cdot [h_{t1}, x_t] + b_o)$

第二部分:经过tanh处理的细胞状态
将更新后的细胞状态 $C_t$ 通过tanh函数处理,使其值在1到1之间。
$ anh(C_t)$

如何避免梯度弥散/爆炸?
最后,将输出门的值与经过tanh处理的细胞状态进行点乘,得到隐藏状态 $h_t$。
$h_t = o_t odot anh(C_t)$

作用:
输出门就像一个“控制器”,它允许LSTM根据当前的输入和之前的记忆,有选择地输出部分信息。这使得LSTM在需要“关注”特定信息时,能够更有效地传递这些信息,而不需要将所有信息一股脑地传递出去。

总结 LSTM 如何应对梯度问题:

1. 独立的细胞状态通道: LSTM的核心是细胞状态(Cell State),它提供了一个“信息高速公路”,让信息可以直接流过,不受重复乘法操作的干扰。
2. 遗忘门(Forget Gate): 通过点乘的方式,允许LSTM选择性地“保留”旧信息,防止其在时间传递中完全消失(梯度弥散)。当遗忘门的值接近1时,梯度就能沿着细胞状态的路径顺利传递。
3. 输入门(Input Gate)与候选值: 输入门控制新信息的“流入”比例,而新的候选值是经过tanh激活的。细胞状态的更新通过一个“加法”结构($C_t = f_t odot C_{t1} + i_t odot ilde{C}_t$)。这个加法操作是关键,它使得即使前一时间步的梯度很小,新信息也能够被有效地“加”进来,从而缓解梯度弥散。
4. 输出门(Output Gate): 控制输出的“比例”,使得LSTM能够选择性地传递信息。这有助于在特定时间步“聚焦”有用的信息,从而在反向传播时,相关梯度也能得到有效的传递。

直观理解:

想象一下,传统的RNN就像一个不断重复讲故事的人,每次重复都会让故事的细节打折扣。而LSTM则像是有一个专门的“笔记本”来记录重要的情节,并且有“翻阅”、“记录”、“修改”的规则。

遗忘门: 就像在笔记本上划掉不重要的情节。
输入门: 就像在笔记本上添加新的重要情节。
输出门: 就像在讲述故事时,只挑笔记本上最相关的情节来复述。

通过这种精巧的门控设计,LSTM有效地绕过了传统RNN中梯度在时间步之间逐层相乘带来的瓶颈,使得模型能够学习到更长期的依赖关系,从而在处理长序列数据时表现出卓越的能力。这些门就像是智能的“守护者”,确保信息在传递过程中既能被有效保留,又能被适时更新和输出,从而将梯度信息稳定地传递到网络的更早期部分。

网友意见

user avatar

“LSTM 能解决梯度消失/梯度爆炸”是对 LSTM 的经典误解。这里我先给出几个粗线条的结论,详细的回答以后有时间了再扩展:

1、首先需要明确的是,RNN 中的梯度消失/梯度爆炸和普通的 MLP 或者深层 CNN 中梯度消失/梯度爆炸的含义不一样。MLP/CNN 中不同的层有不同的参数,各是各的梯度;而 RNN 中同样的权重在各个时间步共享,最终的梯度 g = 各个时间步的梯度 g_t 的和。

2、由 1 中所述的原因,RNN 中总的梯度是不会消失的。即便梯度越传越弱,那也只是远距离的梯度消失,由于近距离的梯度不会消失,所有梯度之和便不会消失。RNN 所谓梯度消失的真正含义是,梯度被近距离梯度主导,导致模型难以学到远距离的依赖关系。

3、LSTM 中梯度的传播有很多条路径, 这条路径上只有逐元素相乘和相加的操作,梯度流最稳定;但是其他路径(例如 )上梯度流与普通 RNN 类似,照样会发生相同的权重矩阵反复连乘。

4、LSTM 刚提出时没有遗忘门,或者说相当于 ,这时候在 直接相连的短路路径上, 可以无损地传递给 ,从而这条路径上的梯度畅通无阻,不会消失。类似于 ResNet 中的残差连接。

5、但是在其他路径上,LSTM 的梯度流和普通 RNN 没有太大区别,依然会爆炸或者消失。由于总的远距离梯度 = 各条路径的远距离梯度之和,即便其他远距离路径梯度消失了,只要保证有一条远距离路径(就是上面说的那条高速公路)梯度不消失,总的远距离梯度就不会消失(正常梯度 + 消失梯度 = 正常梯度)。因此 LSTM 通过改善一条路径上的梯度问题拯救了总体的远距离梯度

6、同样,因为总的远距离梯度 = 各条路径的远距离梯度之和,高速公路上梯度流比较稳定,但其他路径上梯度有可能爆炸,此时总的远距离梯度 = 正常梯度 + 爆炸梯度 = 爆炸梯度,因此 LSTM 仍然有可能发生梯度爆炸。不过,由于 LSTM 的其他路径非常崎岖,和普通 RNN 相比多经过了很多次激活函数(导数都小于 1),因此 LSTM 发生梯度爆炸的频率要低得多。实践中梯度爆炸一般通过梯度裁剪来解决。

7、对于现在常用的带遗忘门的 LSTM 来说,6 中的分析依然成立,而 5 分为两种情况:其一是遗忘门接近 1(例如模型初始化时会把 forget bias 设置成较大的正数,让遗忘门饱和),这时候远距离梯度不消失;其二是遗忘门接近 0,但这时模型是故意阻断梯度流的,这不是 bug 而是 feature(例如情感分析任务中有一条样本 “A,但是 B”,模型读到“但是”后选择把遗忘门设置成 0,遗忘掉内容 A,这是合理的)。当然,常常也存在 f 介于 [0, 1] 之间的情况,在这种情况下只能说 LSTM 改善(而非解决)了梯度消失的状况。

8、最后,别总是抓着梯度不放。梯度只是从反向的、优化的角度来看的,多从正面的、建模的角度想想 LSTM 有效性的原因。选择性、信息不变性都是很好的视角,比如看看这篇:r2rt.com/written-memori

类似的话题

  • 回答
    LSTM(长短期记忆网络)之所以能在处理长序列数据时表现出色,其核心就在于它拥有一套精巧的机制来应对训练过程中常见的“梯度弥散”和“梯度爆炸”问题。这就像给神经网络注入了“记忆”和“遗忘”的开关,让信息能够在时间的长河中更有效地流动和传递。我们先来聊聊为什么会有梯度弥散和梯度爆炸这两个“拦路虎”。梯.............
  • 回答
    好的,非常乐意为您提供一些关于 LSTM(长短期记忆)和 RNN(循环神经网络)的详细教程。这两个模型在处理序列数据方面至关重要,尤其是在自然语言处理、时间序列分析等领域。理解 RNN 和 LSTM 的关键在于理解它们如何克服传统神经网络在处理序列数据时的局限性。传统神经网络是前馈的,每个输入都独立.............
  • 回答
    在自然语言处理(NLP)领域,CNN(卷积神经网络)、RNN(循环神经网络,包括LSTM、GRU等变体)和最简单的全连接多层感知机(MLP)是三种非常基础且重要的模型结构。它们在处理文本数据时各有优势和劣势,理解这些差异对于选择合适的模型至关重要。下面我将详细地阐述这三者在NLP上的优劣: 1. 最.............
  • 回答
    哥们,研一你好!刚踏入学术圈,手里还有点懵,导师又给了个LSTM的任务,这感觉就像刚学做饭,菜都没认全,就有人让你做满汉全席一样,是不是有点慌?别急,这感觉我懂,当年我刚开始接触这些的时候,也是一头雾水。不过,LSTM这东西,虽然听起来高大上,但拆开了揉碎了,一点点来,其实没那么难。咱们先把脑子里的.............

本站所有内容均为互联网搜索引擎提供的公开搜索信息,本站不存储任何数据与内容,任何内容与数据均与本站无关,如有需要请联系相关搜索引擎包括但不限于百度google,bing,sogou

© 2025 tinynews.org All Rights Reserved. 百科问答小站 版权所有