pp006.Mistral 7B

 

AI前沿|Mistral 7B文章挂arxiv了

前几天引起关注的Mistral 7B,其公司才创建了几个月,全欧洲的AI热钱都在往这里流,创立几周后种子轮就融了1.13亿美元,阵容也不错,奔着欧洲地区的OpenAI去的。他们开源的模型Mistral 7B声称在所有评估榜单上比Llama2-13B都好,并且超过了34B的Llama1的推理、数学和代码生成能力(见figure4和table2),最近论文被挂到arxiv上了。主要用到的技术是分组查询注意力GQA(用于更快推理)和滑窗注意力SWA(用于高效处理任意长度序列),他们也提供了一个微调后能跟随指令的模型Mistral 7B - Instruct,在人类和自动评估指标上都超过了Llama 2 13B Chat模型,见table3。模型发布的协议是Apache 2.0。

大家对这个系列讨论最多的是滑窗注意力SWA,认为是Mistral这么强的关键(当然更高的数据质量也是原因之一),不过SWA很早就在别的论文里提出来了,考虑到还是有人不知道,还是简单写一下。见figure1,原始的注意力相较于序列长度是平方关系,且内存消耗随着token数量线性增长,在推理时该种注意力也会引起较高的时延和更小的吞吐。为缓解该问题,作者使用了SWA,每一个token可以最多注意到来自先前层的W个token(图例中是3),当然在滑窗之外的token实际还是能影响到下一个词的预测。因为在每一层,信息可以向前传递W token,因此在k层个注意力层后,信息可以向前传递达到k x W个token。在作者的实际模型中他们使用的W是4096,而层数是32(见table1的参数),因此理论上注意力的跨度超过131k tokens,而实际上,对于一个token长度为16k和W=4096的情况下,作者对FlashAttention与xFormers继续魔改,可以得到相比于原始注意力基线两倍的速度。

作者在架构细节里还提到另外两个技术:分别是Rolling Buffer Cache和Pre-fill与Chunking。前者滚动缓冲区缓存可以限制缓存大小,将缓存大小固定在W,时间步i的的键和值存储在缓存的位置 i mod W 中。当位置i大于W时,缓存中过去的值将被覆盖,且缓存大小停止增加。见figure2中W=4的说明。对于32k token的序列长度,这将缓存内存使用量在不影响模型质量的情况下减少了8倍;后者预填充和分块。生成序列时,需要逐一预测token,因为每个token都以前一个token为条件。但prompt是预先知道的,因此可以用prompt预先填充(k,v)缓存。如果prompt非常大,便可以将其分成更小的块,并用每个块预填充缓存。为此,我们可以选择窗口大小作为块大小。因此,对于每个块,我们需要计算对缓存和块的注意力。见figure3,序列被分成了三部分,对于当前的分块,其注意自己的方式用的因果掩码,对于缓存的分块用的是SWA,对于过去的分块并不注意因为脱离了滑窗。