AI前沿|噢!在这停顿!加入暂停标记训练推理,模型效果居然变得更好
编者按:论文作者没有明确给出文中方法的生效的原理,但猜想可能是暂停标记引起了更宽的计算通路,这恰好和“从理论视角说明COT为何有效”这篇文章中的一个关键结论“更长的上下文带来更宽的计算通路可能是COT的关键”相互联系起来。
LM通过即刻连续生成一连串token的方式来生成回复:第(K + 1)个token是每层操作K个隐藏向量的结果,每个前序token一个向量。作者的问题是:如果我们让模型在输出第(K + 1)个标记前操作K+10个隐藏向量,会怎样?作者通过使用(可学习的)暂停标记(pause token)对语言模型进行训练和推理来实现该想法(该暂停标记的序列附加到输入前缀上)。然后再延迟提取模型的输出,直到看到最后一个暂停标记,从而允许模型在提交答案之前做额外的计算。作者在1B和130M参数的仅解码器模型上做了暂停训练的实验,在C4上做因果预训练,以及涵盖推理、问答、一般理解和事实回忆的下游任务。主要发现是,当模型经过延迟预训练并通过延迟微调时,推理时也用延迟的方式会对任务有所增益。对于1B模型,作者在8项任务上取得了进步,最显著的是SQuAD的QA任务的EM分数提高了18%,CommonSenseQA的EM分数提高了8%,GSM8k的推理任务的准确度提高了1%。作者希望这种延迟预测会成为广泛适用的新范式。
Transformer生成某个token时的计算量由先前出现的token数量所限制,尽管这种设计在最初非常自然,但作者认为从事后来看对于某些输入,第(K + 1)个token是否需要每层中的K+M个Transformer操作(M>0)?而每层任意约束K个操作无法满足这一要求。作者探索了一种使Transformer摆脱这种计算约束的方法,如figure1中是标准与暂停推理或暂停微调的对比,作者在模型输入前加入一些假的token(他们选的是一个可学习的暂停标记
尽管尚不清楚具体原理,作者也给出了猜想,比如Transformer可能可以利用延迟引起的“更广泛”的计算路径。但更一般地来说,模型只是简单跳过
那同时在训练和推理中注入延迟会实现什么效果呢?首先是第一段里提到的大幅性能提升,见figure3。另外,当只在下游下游微调阶段(在标准预训练模型上)引入暂停标记时,他们发现增益更少了,甚至出现性能明显下降。作者还进行了一些消融实验,见figure4:(a)发现向后添加标记比前置它们要更好,(b)对于下游任务,都有相应的最佳暂停标记数量,(c)当减少推理时暂停标记数量时,发现性能会温和下降。