pp003.Think before you speak: Training Language Models With Pause Tokens

 

AI前沿|噢!在这停顿!加入暂停标记训练推理,模型效果居然变得更好

编者按:论文作者没有明确给出文中方法的生效的原理,但猜想可能是暂停标记引起了更宽的计算通路,这恰好和“从理论视角说明COT为何有效”这篇文章中的一个关键结论“更长的上下文带来更宽的计算通路可能是COT的关键”相互联系起来。

LM通过即刻连续生成一连串token的方式来生成回复:第(K + 1)个token是每层操作K个隐藏向量的结果,每个前序token一个向量。作者的问题是:如果我们让模型在输出第(K + 1)个标记前操作K+10个隐藏向量,会怎样?作者通过使用(可学习的)暂停标记(pause token)对语言模型进行训练和推理来实现该想法(该暂停标记的序列附加到输入前缀上)。然后再延迟提取模型的输出,直到看到最后一个暂停标记,从而允许模型在提交答案之前做额外的计算。作者在1B和130M参数的仅解码器模型上做了暂停训练的实验,在C4上做因果预训练,以及涵盖推理、问答、一般理解和事实回忆的下游任务。主要发现是,当模型经过延迟预训练并通过延迟微调时,推理时也用延迟的方式会对任务有所增益。对于1B模型,作者在8项任务上取得了进步,最显著的是SQuAD的QA任务的EM分数提高了18%,CommonSenseQA的EM分数提高了8%,GSM8k的推理任务的准确度提高了1%。作者希望这种延迟预测会成为广泛适用的新范式。

Transformer生成某个token时的计算量由先前出现的token数量所限制,尽管这种设计在最初非常自然,但作者认为从事后来看对于某些输入,第(K + 1)个token是否需要每层中的K+M个Transformer操作(M>0)?而每层任意约束K个操作无法满足这一要求。作者探索了一种使Transformer摆脱这种计算约束的方法,如figure1中是标准与暂停推理或暂停微调的对比,作者在模型输入前加入一些假的token(他们选的是一个可学习的暂停标记)来延迟模型输出,输出时也先忽略输出直到最后一个暂停标记出现,在后处理抽取真实输出。关键的是这种延迟不仅在推理中插入,也包括在微调或预训练中(见figure2,包含更多的训练细节)。

尽管尚不清楚具体原理,作者也给出了猜想,比如Transformer可能可以利用延迟引起的“更广泛”的计算路径。但更一般地来说,模型只是简单跳过标记引入的延迟。毕竟,该标记在推理时既不提供任何附加信息,也没有足够多的新参数(除了单个标记的少数嵌入参数)可以对训练数据中的任何附加信息进行编码。更糟的是,这些无信息的标记可能会淹没信息信号并损害模型。该问题部分答案在一些文献中被探索过,比如为了理解COT的提升来自于哪,有的人用句号("...")的形式加入虚拟的思维,但仅限于推理时,据推测,现成的模型可能还未学会利用推理时间延迟提供的新计算路径。另一些人用预先设置的虚拟标记进行学习,以添加内存(而非扩展计算)。但他们仅在目标任务上使用这些标记进行训练,只观察到很小的性能提升。

那同时在训练和推理中注入延迟会实现什么效果呢?首先是第一段里提到的大幅性能提升,见figure3。另外,当只在下游下游微调阶段(在标准预训练模型上)引入暂停标记时,他们发现增益更少了,甚至出现性能明显下降。作者还进行了一些消融实验,见figure4:(a)发现向后添加标记比前置它们要更好,(b)对于下游任务,都有相应的最佳暂停标记数量,(c)当减少推理时暂停标记数量时,发现性能会温和下降。