Mamba和State Space Models详解
- 1. Transformer的问题
- 1.1 Transformers的核心组件
- 1.2 一份带有训练的Blessing……
- 1.3 还有带推理的Curse!
- 1.4 RNN是解决方案吗?
- 2. 状态空间模型(SSM)
- 2.1 什么是状态空间?
- 2.2 什么是状态空间模型?
- 2.3 从连续信号到离散信号
- 2.4 循环表示
- 2.5 卷积表示
- 2.7 三种表示方式的对比
- 2.8 矩阵A的重要性
- 3. Mamba——一种选择性SSM
- 3.1 它试图解决什么问题?
- 3.2 有选择地保留信息
- 3.3 扫描操作
- 3.4 硬件感知算法
- 3.5 Mamba Block
- 4. 结论
- 5. 资源
Transformer架构一直是大型语言模型(LLM)取得成功的重要组成部分。从开源模型如Mistral,到闭源模型如ChatGPT,如今几乎所有正在使用的大型语言模型都采用了这一架构。
为进一步提升大语言模型的性能,人们开发了新的架构,这些架构甚至可能超越Transformer架构。其中一种方法便是Mamba,一种状态空间模型。
Mamba 是在论文《Mamba:具有选择性状态空间的线性时间序列建模》中提出的。您可以在其仓库中找到官方实现和模型检查点。
在本文中,我将从语言建模的角度介绍状态空间模型这一领域,并逐一探讨相关概念,以帮助大家建立对该领域的直观理解。随后,我们将深入探讨 Mamba 如何可能对 Transformer 架构构成挑战。
作为一份可视化指南,将助您更好地理解 Mamba 和状态空间模型!
1. Transformer的问题
为了说明Mamba为何是一种如此有趣的架构,我们先来简要回顾一下Transformer,并探讨其一个缺点。
Transformer 将任何文本输入视为由token组成的sequence。
Transformer的一个主要优势在于,无论它接收到何种输入,都能回溯序列中任意先前的标记,以获取其表示。
1.1 Transformers的核心组件
请记住,Transformer 由两种结构组成:一组用于表示文本的编码器块,以及一组用于生成文本的解码器块。这两种结构结合起来,可用于多项任务,包括翻译。
我们可以采用这种结构,仅使用解码器来构建生成模型。这种基于Transformer的模型——生成式预训练Transformer(GPT),利用解码器模块来完成对输入文本的生成。
让我们来看看它是如何运作的!
1.2 一份带有训练的Blessing……
一个解码器块由两个主要组成部分构成:掩码自注意力机制,后接前馈神经网络。
自注意力机制是这些模型能够取得如此出色效果的一个重要原因。它使得模型能够以快速的训练速度,对整个序列进行无压缩的全局视角分析。
那么,它是如何工作的呢?
它会创建一个矩阵,将每个token与之前出现过的每一个token进行比较。矩阵中的权重由token对之间的相关性决定。
在训练过程中,这个矩阵是一次性生成的。在计算“name”与“is”之间的注意力之前,无需先计算“My”与“name”之间的注意力。
这使得并行化成为可能,从而极大地加快了训练速度!
1.3 还有带推理的Curse!
然而,存在一个缺陷:在生成下一个token时,即使我们已经生成了一些tokens,仍需重新计算entire sequence的注意力。
为长度为L的序列生成tokens大约需要L²次计算,如果序列长度增加,这可能会带来高昂的成本。
这种需要重新计算整个序列的做法,是Transformer架构的一个主要瓶颈。
让我们来看看一种“经典”技术——循环神经网络——是如何解决这一推理速度慢的问题的。
1.4 RNN是解决方案吗?
循环神经网络(RNN)是一种基于序列的网络。在序列的每个时间步,它都会接收两个输入:时间步t的输入,以及前一时间步t-1的隐藏状态,从而生成下一个隐藏状态并预测输出。
RNN具有一种循环机制,能够将信息从上一个时间步传递到下一个时间步。我们可以对这一过程进行“展开”可视化,以使其更加清晰明了。
在生成输出时,RNN只需考虑前一个隐藏状态和当前输入。这避免了像Transformer那样重新计算所有先前的隐藏状态。换句话说,RNN的推理速度非常快,因为它与序列长度呈线性 scaling!理论上,它的上下文长度甚至可以是无限的。
为了说明这一点,让我们将RNN应用于我们之前使用过的输入文本。
每个隐藏状态都是所有先前隐藏状态的聚合,通常是一种压缩后的视图。然而,这里存在一个问题……
请注意,当生成名字“Maarten”时,最后一个隐藏状态已不再包含关于单词“Hello”的信息。由于RNN每次只考虑一个先前状态,因此随着时间推移,它们往往会逐渐遗忘信息。
RNN这种序列化的特性也带来了另一个问题:训练无法并行进行,因为必须按顺序逐个步骤地完成。
与Transformer相比,RNN的问题恰恰相反!它的推理速度极快,但却无法并行化。
我们能否找到一种架构,既能像Transformer那样实现训练的并行化,又能在推理时保持与序列长度呈线性 scaling 的性能?
当然可以!这正是Mamba所提出的方案。不过,在深入探讨其架构之前,让我们先来了解一下状态空间模型的世界吧。
2. 状态空间模型(SSM)
状态空间模型(SSM)与Transformer和RNN一样,能够处理信息序列,这些序列既可以是文本,也可以是信号。在本节中,我们将介绍状态空间模型的基本概念,以及它们如何应用于文本数据。
2.1 什么是状态空间?
状态空间包含完整描述一个系统所需的最少变量。它是一种通过定义系统的可能状态来对问题进行数学建模的方法。
让我们来简化一下这个概念。想象一下,我们正在迷宫中导航。“状态空间”就是一张描绘所有可能位置(状态)的地图。地图上的每个点都代表迷宫中的一个独特位置,并附带具体信息,比如你距离出口有多远。
“状态空间表示”是对这张地图的一种简化描述。它展示了你当前所处的位置(当前状态),你下一步可能前往的地方(未来可能的状态),以及哪些变化会带你进入下一个状态(向右或向左移动)。
尽管状态空间模型使用方程和矩阵来追踪这种行为,但这其实只是一种用来追踪你当前所处位置、可前往的地点以及如何抵达这些地点的方法。
在我们的示例中,描述状态的变量——即X和Y坐标,以及到出口的距离——可以被表示为“state vectors”。
听起来很熟悉吧?这是因为语言模型中的嵌入或向量也常被用来描述输入序列的“state”。例如,你当前位置的向量(状态向量)可能看起来有点像这样:
在神经网络中,系统的“state”通常指其隐藏状态;而在大型语言模型的背景下,这正是生成新token时最重要的方面之一。
2.2 什么是状态空间模型?
状态空间模型是用来描述这些状态表示,并根据某些输入预测其下一状态的模型。
传统上,在时刻t,SSMs:
- 将输入序列x ( t ) x(t)x(t)——(例如,在迷宫中向左下方移动)
- 映射到潜在状态表示h ( t ) h(t)h(t)——(例如,到出口的距离以及x/y坐标);
- 并由此推导出预测输出序列y ( t ) y(t)y(t)——(例如,再次向左移动以更早到达出口)。
然而,它并非使用离散序列(例如向左移动一次),而是以连续序列为输入,并预测输出序列。
SSM假设,诸如物体在三维空间中运动之类的动态系统,可以通过两个方程从其在时刻t tt的状态进行预测。
h ′ ( t ) = A h ( t ) + B x ( t ) h'(t)=Ah(t)+Bx(t)h′(t)=Ah(t)+Bx(t)
y ( t ) = C h ( t ) + D x ( t ) y(t)=Ch(t)+Dx(t)y(t)=Ch(t)+Dx(t)
通过求解这些方程,我们假设能够揭示出统计规律,从而根据观测数据(输入序列和先前状态)预测系统的状态。
它的目标是找到这种状态表示h ( t ) h(t)h(t),以便我们能够从输入序列过渡到输出序列。
这两个方程是状态空间模型的核心。
在本指南中,将多次引用这两个方程。为了使它们更易于理解,我们对它们进行了颜色编码,以便您能快速查阅。
状态方程描述了状态如何变化(通过矩阵A),以及输入如何影响状态(通过矩阵B)。
正如我们之前所见,h ( t ) h(t)h(t)表示在任意时刻 t 的隐状态表示,x ( t ) x(t)x(t)则表示某种输入。
输出方程描述了状态如何通过矩阵 C 转换为输出,以及输入如何通过矩阵 D 影响输出。
注意:矩阵A、B、C和D通常也被称为参数,因为它们是可学习的。
将这两个方程可视化,我们得到以下架构:
让我们逐步了解这一通用技术,以理解这些矩阵如何影响学习过程。
假设我们有一个输入信号x ( t ) x(t)x(t),该信号首先会与矩阵B相乘,而矩阵B描述了输入如何影响系统。
更新后的状态(类似于神经网络的隐藏状态)是一个潜在空间,其中包含了环境的核心“知识”。我们将该状态与矩阵A相乘,矩阵A描述了所有内部状态之间的连接方式,从而表征了系统的内在动力学。
正如你可能注意到的,矩阵A是在创建状态表示之前应用的,并在状态表示更新之后进行更新。
然后,我们使用矩阵C来描述状态如何转化为输出。
最后,我们可以利用矩阵D实现从输入到输出的直接信号传递。这种连接方式通常也被称为**“跳跃连接”**。
由于矩阵D类似于跳跃连接,因此SSM通常被视为不带跳跃连接的以下形式:
回到简化的视角,我们现在可以将矩阵A、B和C作为SSM的核心来重点研究。
可以表示为:
可以像之前那样,更新原始方程(并添加一些鲜艳的色彩),以表明每个矩阵的用途。
这两条方程共同旨在根据观测数据预测系统的状态。由于输入预计为连续信号,状态空间模型的主要表示形式是连续时间表示。
2.3 从连续信号到离散信号
如果你处理的是连续信号,那么寻找状态表示h(t)在理论上会面临很大挑战。此外,由于我们通常处理的是离散输入(比如文本序列),因此希望对模型进行离散化处理。
为此,我们采用了零阶保持技术(Zero-order hold technique)。其工作原理如下:首先,每当我们接收到一个离散信号时,便将其值保持下来,直至接收到新的离散信号。这一过程生成了一种连续信号,可供SSM使用:
我们保持该值的时间长短由一个名为步长∆(step size)的新可学习参数来表示。它代表了输入的分辨率。
现在我们有了输入的连续信号,就可以生成一个连续的输出,并仅按输入的时间步长对数值进行采样。
这些采样值就是我们的离散化输出!
从数学上讲,我们可以按如下方式应用零阶保持器:
它们共同使我们能够从连续的SSM过渡到离散的SSM,这种离散SSM采用一种新的表述形式:不再是函数到函数的映射x(t) → y(t),而变成了序列到序列的映射xₖ → yₖ。
在这里,矩阵A和B现在代表了模型的离散化参数。
我们用k代替t来表示离散化的时间步长,以便在提及连续与离散状态空间模型时更加清晰。
注意:在训练期间仍保存的是矩阵A的连续形式,而非离散化后的版本。在训练过程中,连续表示会被离散化。
2.4 循环表示
我们所采用的离散化状态空间模型使我们能够以特定的时间步长来表述问题,而非连续信号。正如我们之前在循环神经网络中所见,这种递归方法在此处非常有用。
如果我们考虑离散时间步长而非连续信号,便可以重新以时间步长来表述该问题:
在每个时间步,我们计算当前输入(Bxₖ)如何影响先前状态(Ahₖ₋₁),然后计算预测输出(Chₖ)。
这种表示方式或许已经显得有些熟悉了!我们可以像之前处理RNN那样,以同样的方式来处理它。
我们可以将其展开(或展开成)如下:
请注意,我们如何能够利用这种离散化版本,并基于RNN的底层方法论来实现。
2.5 卷积表示
我们还可以用卷积来表示SSM。还记得在经典的图像识别任务中,我们如何应用滤波器(内核)以提取聚合特征吗:
由于我们处理的是文本而非图像,因此我们需要采用一维视角:
运用来自不同领域的技术,能打造出一条有趣的流水线。用于表示此“滤波器(filter)”的内核源自SSM公式:
让我们来探讨一下这个内核在实际中的工作原理。与卷积类似,我们也可以利用SSM内核遍历每组标记,并计算输出:
这也说明了填充可能对输出产生的影响。我调整了填充的顺序以改善可视化效果,但通常会在句子末尾应用填充。
下一步,内核会向右移动一次,以执行计算的下一步:
在最后一步,我们可以看到内核的全部效果:
将SSM表示为卷积的一大优势在于,它能够像卷积神经网络(CNN)一样并行训练。然而,由于核大小固定,其推理速度不如循环神经网络(RNN)那样快速且不受限制。
2.7 三种表示方式的对比
这三种表示方法——连续、循环和卷积——各自都具有一系列不同的优缺点:
有趣的是,我们现在实现了基于循环SSM的高效推理,以及基于卷积SSM的可并行训练。
借助这些表示方法,我们可采用一种巧妙的技巧:根据任务的不同选择相应的表示。在训练过程中,我们使用可并行化的卷积表示;而在推理阶段,则采用高效的循环表示。
该模型被称为线性状态空间层Linear State-Space Layer, LSSL。
这些表示法具有一项重要特性,即线性时不变性(Linear Time Invariance, LTI)。所谓线性时不变性,是指SSM的参数A、B和C在所有时间步长上均保持固定。这意味着,对于SSM生成的每一个标记,矩阵A、B和C都是相同的。
换句话说,无论你向SSM输入何种序列,A、B和C的值始终保持不变。我们所采用的是一种静态表示,它并不具备内容感知能力。
在探讨Mamba如何解决这一问题之前,让我们先来了解一下拼图的最后一块——矩阵A。
2.8 矩阵A的重要性
可以说,SSM公式中最重要的方面之一就是矩阵A。正如我们之前在循环表示中所看到的,它会捕捉前一状态的信息,以构建新状态。
本质上,矩阵A生成隐藏状态:
因此,创建矩阵A的差异可能就在于:是仅记住前几个标记,还是捕捉迄今为止我们所见过的每一个标记。尤其是在循环表示的背景下,因为这种表示方式只回看之前的状态。
那么,我们该如何以一种能够保留大容量记忆(上下文大小)的方式创建矩阵A呢?我们使用“饥饿的河马(Hungry Hungry Hippo)”!或者称作HiPPO3,用于高阶多项式投影算子。
HiPPO试图将其迄今为止所见的所有输入信号压缩成一个系数向量。它使用矩阵A构建一种状态表示,能够很好地捕捉近期的标记,并对较早的标记进行衰减。其公式可表示如下:
假设我们有一个方阵A,这将给我们:
使用HiPPO构建矩阵A被证明比将其初始化为随机矩阵要好得多。因此,与旧信号(初始标记)相比,它能更准确地重建较新的信号(近期标记)。HiPPO矩阵的核心思想在于,它能够生成一种记忆其历史的隐藏状态。
从数学上讲,它通过追踪Legendre多项式的系数来实现这一功能,从而能够近似表示此前的所有历史。
随后,HiPPO被应用于我们之前所见到的循环和卷积表示,以处理长距离依赖关系。其成果便是结构化状态空间序列模型(Structured State Space for Sequences, S4),一类能够高效处理长序列的SSM。它由三个部分组成:
- 状态空间模型
- 用于处理长距离依赖关系的HiPPO
- 用于构建循环和卷积表示的离散化方法
这类SSM具有多种优势,具体取决于您选择的表示方式(循环型 vs. 卷积型)。它还能通过基于HiPPO矩阵,高效地处理长文本序列并有效存储记忆。
注意:如果想深入了解如何计算HiPPO矩阵并自行构建S4模型的技术细节,我强烈建议您仔细阅读《带注释的S4》。
3. Mamba——一种选择性SSM
至此,我们终于涵盖了理解Mamba独特之处所需的所有基础知识。状态空间模型可用于对文本序列建模,但仍然存在一些我们希望避免的缺点。
在本节中,我们将详细介绍Mamba的两大主要贡献:
- 一种选择性扫描算法,使模型能够过滤(不相关)信息;
- 一种面向硬件的算法,通过并行扫描、内核融合和重新计算,实现(中间)结果的高效存储。
这两者共同构成了选择性SSM或S6模型,这些模型可像自注意力机制一样,用于构建Mamba模块。
在深入探讨这两项主要贡献之前,让我们先了解一下它们为何必不可少。
3.1 它试图解决什么问题?
状态空间模型,甚至包括S4(结构化状态空间模型),在语言建模与生成中至关重要的某些任务上表现欠佳,尤其是对特定输入进行聚焦或忽略的能力。我们可以通过两个合成任务来说明这一点,即选择性复制(selective copying)和归纳头任务(induction heads)。
在选择性复制任务中,状态空间模型的目标是按顺序复制输入中的部分内容并将其输出:
然而,(递归/卷积)SSM 在这项任务中表现不佳,因为它具有线性时不变特性。正如我们之前所见,SSM 生成的每个标记所对应的矩阵 A、B 和 C 均相同。
因此,SSM 无法进行内容感知推理,因为它会因矩阵 A、B 和 C 固定不变而对每个标记一视同仁。这构成了一个问题,因为我们希望 SSM 能够针对输入(提示)进行推理。
SSM表现不佳的第二个任务是归纳头,其目标是重现输入中所发现的模式:
在上述示例中,我们实际上是在进行一次性提示,试图“教会”模型在每个“Q:”之后都给出“A:”的回复。然而,由于SSM是时间不变的,它无法从历史中选择哪些先前的标记进行调用。
让我们以矩阵B为例来说明这一点。无论输入x为何,矩阵B始终保持完全不变,因此它与x无关:
同样,无论输入如何变化,A和C也始终保持不变。这体现了我们迄今所见的SSM的静态特性。
相比之下,这些任务对Transformer来说相对容易,因为它们会根据输入序列动态地调整注意力。它们能够有选择地“关注”或“留意”序列的不同部分。
SSM在这些任务上表现不佳,凸显了时不变SSM所存在的根本问题:矩阵A、B和C的静态特性导致其在内容感知方面存在缺陷。
3.2 有选择地保留信息
SSM的递归表示会创建一个状态较小且非常高效的模型,因为它能够压缩整个历史信息。然而,与通过注意力矩阵完全不压缩历史信息的Transformer模型相比,它的表现要逊色得多。
Mamba的目标是兼得两者之长:既拥有小型状态,又具备与Transformer模型相当的强大能力。
正如上文所暗示的,它是通过将数据有选择地压缩到状态中来实现这一功能的。当你输入一个句子时,其中往往包含一些诸如停用词之类的词语,这些词语本身并没有太多实际意义。
为了有选择地压缩信息,我们需要使参数依赖于输入。为此,我们首先来探讨一下在训练过程中SSM中输入与输出的维度:
在结构化状态空间模型(S4)中,矩阵A、B和C与输入无关,因为它们的维度N和D是静态的,不会发生变化。
相反,Mamba 通过引入输入的序列长度和批次大小,使矩阵 B 和 C 乃至步长 ∆ 都依赖于输入:
这意味着,对于每个输入标记,我们现在都有不同的B和C矩阵,从而解决了内容感知问题!
注意:矩阵A保持不变,因为我们希望状态本身保持静态,但其受外界影响的方式(通过B和C)则应是动态的。
它们共同选择性地决定在隐藏状态中保留什么、忽略什么,因为此时它们已开始依赖输入。步长∆越小,就越倾向于忽略特定的词语,而更多地利用先前的上下文;步长∆越大,则越注重输入的词语,而非上下文。
3.3 扫描操作
由于这些矩阵现在是动态的,因此无法再采用卷积表示法进行计算,因为卷积表示法假定核是固定的。我们只能使用递归表示法,从而失去卷积带来的并行化优势。
为了实现并行化,让我们来探讨一下如何通过递归方式计算输出:
每个状态都是前一状态(乘以A)与当前输入(乘以B)之和。这种操作称为扫描运算,可轻松通过for循环来实现。相比之下,由于每个状态的计算都依赖于前一状态,因此并行化似乎是不可能的。然而,Mamba借助并行扫描算法,使这一目标成为可能。它利用结合律,假定操作的顺序并不重要。因此,我们可以分段计算序列,并逐步将各部分合并起来:
动态矩阵B和C与并行扫描算法共同构成了选择性扫描算法,以体现递归表示所具有的动态性和快速性。
3.4 硬件感知算法
近期GPU的一个缺点是,其小型但高效能的SRAM与大型但效率稍低的DRAM之间的传输(I/O)速度有限。频繁地在SRAM和DRAM之间复制数据,已成为一个瓶颈。
Mamba与Flash Attention一样,试图减少我们从DRAM到SRAM以及反向访问的次数。它通过**内核融合(kernel fusion)**实现这一目标,使模型能够避免写入中间结果,并持续进行计算,直至任务完成。
我们可以通过可视化Mamba的基础架构,查看DRAM和SRAM分配的具体实例:
在此,以下内容被融合为一个内核:
- 步长为∆的离散化步骤
- 选择性扫描算法
- 与C的乘法运算
硬件感知算法的最后一部分是重新计算(recomputation)。中间状态不会被保存,但对反向传播计算梯度却是必需的。因此,作者在反向传播过程中重新计算这些中间状态。尽管这看似效率不高,但实际上其成本远低于从速度较慢的DRAM中读取所有这些中间状态。我们现已涵盖了其架构的所有组成部分,相关示意图见其文章中的以下图片:
图:选择性SSM。来源:Gu, Albert, and Tri Dao. “Mamba: Linear-time sequence modeling with selective state spaces.” arXiv preprint arXiv:2312.00752 (2023).
这种架构通常被称为选择性SSM或S6模型,因为它本质上是采用选择性扫描算法计算得出的S4模型。
3.5 Mamba Block
我们迄今所探讨的这种选择性SSM可以像我们在解码器块中表示自注意力机制那样,以一个模块的形式实现。
与解码器类似,我们可以堆叠多个Mamba模块,并将它们的输出用作下一个Mamba模块的输入:它首先通过线性投影来扩展输入嵌入(input embeddings)。随后,在应用选择性SSM之前,会先进行卷积操作,以防止对各个标记进行独立计算。
The Mamba Block
选择性SSM具有以下特性:
- 通过离散化创建的递归SSM
- 对矩阵A进行HiPPO初始化,以捕捉长程依赖关系
- 选择性扫描算法,用于有选择地压缩信息
- 面向硬件的算法,以加速计算
在查看代码实现时,我们可以进一步扩展这一架构,并探讨一个端到端示例的具体模样:
请注意一些变化,例如引入了归一化层以及用于选择输出标记的Softmax函数。
将所有这些整合起来,我们便实现了快速的推理与训练,甚至还能支持无限长度的上下文。借助这种架构,作者发现其性能不仅与同等规模的Transformer模型相当,有时甚至超越了后者!
4. 结论
至此,我们关于状态空间模型以及采用选择性状态空间模型的惊艳Mamba架构的旅程便告一段落。希望本文能帮助您更好地理解状态空间模型——尤其是Mamba——的潜力。谁又知道,未来它是否会取代Transformer呢?但眼下,看到如此迥异的架构获得应有的关注,实在令人振奋!
5. 资源
希望这能为你提供一个易于理解的Mamba与状态空间模型入门介绍。如果你想深入学习,我推荐以下资源:
- The Annotated S4 is a JAX implementation and guide through the S4 model and is highly advised!
- A great YouTube video introducing Mamba by building it up through foundational papers.
- The Mamba repository with checkpoints on Hugging Face.
- An amazing series of blog posts (1, 2, 3) that introduces the S4 model.
- The Mamba No. 5 (A Little Bit Of…) blog post is a great next step to dive into more technical details about Mamba but still from an amazingly intuitive perspective.
- And of course, the Mamba paper! It was even used for DNA modeling and speech generation.
1 Gu, Albert, and Tri Dao. “Mamba: Linear-time sequence modeling with selective state spaces.” arXiv preprint arXiv:2312.00752 (2023).
2 Gu, Albert, et al. “Combining recurrent, convolutional, and continuous-time models with linear state space layers.” Advances in neural information processing systems 34 (2021): 572-585.
3 Gu, Albert, et al. “Hippo: Recurrent memory with optimal polynomial projections.” Advances in neural information processing systems 33 (2020): 1474-1487.
4 Voelker, Aaron, Ivana Kajić, and Chris Eliasmith. “Legendre memory units: Continuous-time representation in recurrent neural networks.” Advances in neural information processing systems 32 (2019).
5 Gu, Albert, Karan Goel, and Christopher Ré. “Efficiently modeling long sequences with structured state spaces.” arXiv preprint arXiv:2111.00396 (2021).