AI前沿|重用提示词的状态加速推理
在此论文中作者提出Prompt Cache,提示缓存,通过在不同LLM提示之间重用注意力状态来加速模型推理。很多提示都有相同的文本片段,如系统消息,提示模版(工具调用的语句)和为上下文提供的文档(法律分析和教育领域等)。作者的想法是通过在推理服务器上预先计算和存储这些频繁出现的文本片段的注意力状态,当这些片段出现在用户提示中时,就可以有效地重用它们。提示缓存使用一种范式Schema来显式定义此类可重用文本段,称为提示模块。该模式确保注意力状态重用期间的位置准确性,并为用户提供在提示中访问缓存状态的接口。使用原型实现,作者评估了多个LLM的Prompt Cache。结果显示提示缓存显着减少了首次token出现的时延,特别是对于较长的提示,例如基于文档的问答和推荐。改进从基于GPU的推理的8倍到基于CPU的推理的60倍,该方法能同时保持输出精度,且无需修改模型参数。
Prompt Cache基于KV Cache构建,通过使注意力状态重用模块化,将注意力状态重用从单个提示扩展到多个提示。在该方法中,经常重用的文本片段被单独预先计算并存储在内存中。当这样的“缓存”片段出现在输入提示中时,系统使用内存中预先计算的键值注意力状态,而非重计算它们。因此,仅对未缓存的文本片段需要注意计算。figure1清晰说明了完全自回归生成、KV缓存和提示缓存间的区别。作者注意到,随着缓存段大小的增加,性能优势变得更明显,毕竟注意力状态的计算开销与输入序列大小成二次方,而提示缓存的的存储开销呈线性放缩。
在跨提示重用注意力状态时有两个难点:1.由于transformer的位置编码,注意力状态也是位置相关的,因此文本片段的状态只有在片段在相同位置出现时才能重用。2.当某个文本片段的注意力状态可能已经被缓存下来时,该系统必须能高效识别出该片段。为解决问题,提示缓存组合了两个想法,第一个是使用提示标记语言(PML,见figure2)使提示的结构变得明确。PML使可重用的文本段明确为模块,即提示模块。因为每个提示模块都可以分配有唯一的位置ID,它就不仅解决了上面的第二个问题,而且为解决第一个问题做铺垫。第二个想法是作者实证发现,LLM可以对具有不连续位置ID的注意力状态进行操作,这意味着我们可以提取不同的注意力状态片段并将它们连接起来以形成新的含义。他们利用这点使用户能根据需要选择提示模块,甚至在运行时更新提示模块。
总之,使用提示缓存时用户用PML编写提示,目的是可以基于提示模块重用注意力状态。重要的是,他们必须从范式中推出提示,该范式也是用PML编写的。figure2展示了一个基于示例方式的提示示例,这个例子中用PML在范式与提示中都构造了可重用模块,提示模块可以具有参数(如行程计划),而提示可以在该模块中给具体参数赋值,比如给duration参数3days。接着当提示缓存收到提示时,它首先处理其范式并为提示模块计算注意力状态。它对提示中的提示模块以及从同一范式派生的其他提示重用这些状态。当遇到新的提示时,模型可以检索缓存了的注意力状态来导入提示模块,并计算其中参数和和新的文本片段,并最后拼接起来产生整个提示的注意力状态。
最后,这种方法能在任何支撑KV Cache的Transformer架构上使用,作者在此构造的原型能在Llama2、Falcon和MPT上使用,他们考虑了GPU和CPU两种内存。时延对比可以看figure3和4,精度对比可以看table1。