ReFT:通过强化微调提升推理能力
代码: https://github.com/lqtrung1998/mwp_ReFT
摘要
提升大型语言模型(LLMs)推理能力的一种常见方法是使用链式思维(CoT)注释进行监督微调(SFT)。然而,这种方法在泛化能力上表现有限,因为训练仅依赖于单一的 CoT 数据。例如,在数学问题求解中,每个问题通常只有一个注释的推理路径。显然,模型从多个注释的推理路径中学习效果会更好。为此,我们提出了一种简单而有效的方法——强化微调(ReFT),以增强 LLMs 在推理中的泛化能力,并以数学问题求解为例进行验证。ReFT 首先通过 SFT 进行预热,然后利用在线强化学习(RL),特别是 PPO 算法,进一步微调模型。在此过程中,模型自动采样大量推理路径,并根据真实答案自然推导出奖励。在 GSM8K、MathQA 和 SVAMP 数据集上的实验表明, ReFT 显著优于 SFT,并且通过结合推理时策略(如多数投票和重新排序)可以进一步提升性能。值得注意的是, ReFT 仅通过与 SFT 相同的训练问题进行学习,无需依赖额外或增强的训练数据,展示了其优越的泛化能力。
1. 引言
当前解决数学问题的前沿方法(Luo et al., 2023; Wang et al., 2023)主要采用监督微调(SFT)来训练模型,使用链式思维(CoT)注释(Wei et al., 2022)。如图 1 所示, CoT 注释概述了解决数学问题的中间推理步骤。
通常,训练数据中每个问题只有一个 CoT 注释,即一个正确的推理路径,用于 SFT。然而,我们发现这可能导致 SFT 模型的泛化能力较弱。事实上,同一个问题往往存在多个有效的 CoT 注释(Cobbe et al., 2021; Zhang et al., 2023),这表明需要一种更强大的微调方法。为此,我们提出了强化微调(ReFT)(图 1 底部)。
ReFT 从预热阶段开始,进行一到两个 epoch 的监督微调(SFT)(图 1,阴影框)。这一阶段使模型具备生成正确数学问题响应的基本能力(Cobbe et al., 2021)。随后, ReFT 通过在线强化学习(RL)算法(Schulman et al., 2017),特别是 PPO 算法,进一步优化模型。通过这种方式, ReFT 能够采样多个正确的推理路径或 CoT 注释,并从中学习(图 2,右)。
由于训练数据包含真实答案,因此在训练 PPO 时可以直接从这些答案中推导出奖励,无需单独训练奖励模型。相比之下, RLHF (Ziegler et al., 2019)必须依赖从人类标注数据中学习的奖励模型。
在预热阶段, ReFT 通过监督学习获得了一定的准确性。在 RL 阶段, ReFT 通过采样各种 CoT 推理路径进一步增强了其能力。通过这种方式, ReFT 获得了比 SFT 更丰富的监督信号,从而在数学问题求解中显著提高了泛化能力(Cobbe et al., 2021; Zhang et al., 2023)。值得注意的是, ReFT 仅通过与 SFT 相同的训练问题进行学习,无需依赖额外或增强的训练数据。事实上, ReFT 与这些数据工程并不冲突,并且可以无缝结合。
我们的贡献如下:
- 我们提出了一种新的微调方法——强化微调(ReFT),利用强化学习解决数学问题。与传统的监督微调相比, ReFT 在相同数据集上表现出更强的泛化能力。
- 我们在两个基础模型 CodeLLAMA (Roziere et al., 2023)和 Galactica (Taylor et al., 2022)上进行了广泛的实验,使用了三个标准数据集: GSM8K (Cobbe et al., 2021)、MathQA (Amini et al., 2019)和 SVAMP (Patel et al., 2021)。实验涵盖了自然语言和基于程序的 CoT,展示了 ReFT 显著提高的性能和泛化能力。
- 我们还展示了 ReFT 在推理时受益于多数投票(Wang et al., 2023)和奖励模型重新排序(Lightman et al., 2023),进一步提升了其性能。
2. 相关工作
2.1 数学问题求解
最近的研究主要集中在 CoT 提示设计和数据工程上。大多数研究试图使 CoT 全面且细致,以呈现逐步推理的解决方案(Wei et al., 2022; Gao et al., 2023; Zhou et al., 2023)。Gao et al.(2023)进一步提出使用 Python 程序作为 CoT 提示,展示了更准确的推理步骤,并显著优于自然语言 CoT (Wei et al., 2022)。Zhou et al.(2023)引入了一种提示方法,生成代码以验证 GPT-4 (OpenAI, 2023)的中间推理步骤,从而在 GSM8K (Cobbe et al., 2021)和 MATH (Hendrycks et al., 2021)数据集上实现了最先进的性能。另一项工作则集中在提高 CoT 的质量(Wang et al., 2023; Lightman et al., 2023)和增加 CoT 数据的数量(Wang et al., 2023; Yue et al., 2023),这些数据通常来自 OpenAI 的 ChatGPT (GPT-3.5-turbo)或 GPT-4。
2.2 强化学习
我们的工作与最近将 PPO (Schulman et al., 2017)应用于自然语言处理以对齐人类偏好的研究密切相关(Ziegler et al., 2019)。此后,提出了几种训练算法以有效提高对齐效果,包括直接偏好优化(DPO)(Rafailov et al., 2023)、身份偏好优化(IPO)(Azar et al., 2023)和 Kahneman-Tversky 优化(KTO)(Ethayarajh et al., 2023)。除了对齐目的外,我们旨在采用强化学习作为一种微调范式,以提高传统监督微调的性能。
特别是对于解决数学问题, Uesato et al.(2022)和 Lightman et al.(2023)训练了一个基于结果或基于过程的奖励模型,以执行重新排序(Cobbe et al., 20212021),从而在 SFT 和多数投票(Wang et al., 2023)上实现了更好的性能。虽然我们的方法旨在提高策略本身的性能,但这些奖励模型重新排序方法可以轻松集成到生成的策略模型中。
3. 方法
在本工作中,我们专注于自然语言 CoT(N-CoT)(Wei et al., 2022)(图 1)和基于程序的 CoT(P-CoT)(Gao et al., 2023),使用 Python。Gao et al.(2023)提出了基于程序的 CoT 用于数学问题求解。我们可以简单地执行程序以获得答案。为了确保清晰并避免歧义,我们使用术语 N-CoT 和 P-CoT 分别表示自然语言和基于程序的 CoT。
3.1 强化微调
提出的强化微调(ReFT)过程包括两个阶段:预热阶段和强化学习阶段。整体算法如算法 1 所示。
预热:在此阶段,策略在由"(问题, CoT)"元组组成的数据集上进行几个 epoch 的微调:(<eos>表示生成过程终止。CoT
其中
当生成的动作是<eos> Token 时,结果状态
强化学习:在此阶段,策略通过在线自学习的形式提高其性能,使用由(问题,答案)元组组成的数据集:(
这种部分奖励可以帮助减少从稀疏奖励中学习的影响(Riedmiller et al., 2018; Trott et al., 2019)。此外,遵循 Zheng et al.(2023),我们的总奖励是奖励函数得分和学习的 RL 策略与初始策略之间的 Kullback-Leibler (KL)散度(Kullback and Leibler, 1951)之和,乘以系数因子
广义优势估计(Schulman et al., 2018)用于优势计算:
其中时间差分(TD)定义为
终止状态值
最后,策略和价值目标可以写成以下两个方程
其中
其中
4. 实验
4.1 数据集
我们在三个数学问题数据集上进行了实验: GSM8K (Cobbe et al., 20212021)、SVAMP (Patel et al., 2021)和 MathQA (Amini et al., 2019)。对于 GSM8K 和 SVAMP,答案的格式是数值。在 MathQA 中,答案格式是多项选择列表(即 ABCD)。表 1 展示了所有数据集的统计信息。我们使用 GPT-3.5-turbo 进行少样本提示(Wei et al., 2022; Gao et al., 2023)以获得 N-CoT 和 P-CoT 注释。N-CoT 和 P-CoT 注释的获取遵循 Jie et al.(2023)。我们还在 MathQA 的数值版本(Jie and Lu, 2023)上进行了额外的实验,其中答案格式也是数值。这些实验用于展示我们对 MathQA 上潜在奖励黑客现象的假设(Skalse et al., 2022)(§4.4)。
4.2 基线
我们将 ReFT 与 SFT 和自训练(Xie et al., 2020; Amini et al., 2022)基线进行比较。SFT 简单地在训练数据上微调语言模型。自训练方法的实验确保了相对公平的比较,因为这些方法共享从模型生成的样本用于训练的机制。
我们实现了离线自训练(Offline-ST)(He et al., 2020)和在线自训练(Online-ST)(Hoi et al., 2021)。Offline-ST 方法与专家迭代(Anthony et al., 2017; Uesato et al., 2022; Zelikman et al., 2022)类似。我们首先使用早期检查点的 SFT 检查点采样 CoT,并根据真实答案验证它们。我们仅保留那些具有正确答案的专家样本。我们在原始训练数据和专家样本的组合上执行 SFT。
Online-ST 方法与 ReFT 非常相似。与 ReFT 一样, Online-ST 具有相同的预热过程。之后,我们使用即时生成的样本进行持续训练。在每个训练步骤中,模型首先为一个批次采样 CoT,并仅保留那些具有正确答案的样本。生成的批次包括采样和真实 CoT。然后,我们使用监督微调目标
4.3 实验设置
我们在两个基础模型上进行了实验: Galactica-6.7B (Taylor et al., 2022)和 CodeLLAMA-7B (Roziere et al., 2023)。这两个模型在数学求解方面表现出色,并且在最近的推理任务文献中常被采用(Yue et al., 2023; Luo et al., 2023)。
除了与基线的比较外,我们还在 GSM8K 上应用了常见的技术,多数投票(Wang et al., 2023)和奖励模型重新排序(Lightman et al., 2023)。
超参数:在所有实验中,训练使用 8 个 A100-80GB GPU,使用 DeepSpeed (Rajbhandari et al., 2020; Rasley et al., 2020) Zero 阶段 2 和 Hugging Face Accelerate (Gugger et al., 2022)。在 ReFT 的预热阶段,我们使用 AdamW (Loshchilov and Hutter, 2017)优化器,预热比例为 10% 。批量大小为 48,学习率为 1e-5。最大长度设置为 1024。预热阶段的 epoch 数在所有设置中为 2,除了在 MathQA
对于 SFT 基线,我们训练模型 40 个 epoch,并选择性能最佳的检查点。这个 epoch 数被选择为足够大,以确保 SFT 收敛。对于 Offline-ST 基线,我们使用 ReFT 预热阶段的检查点采样 CoT。使用生成温度为 1.0 和最大长度为 1024,我们为每个问题采样 100 个 CoT,并仅保留那些具有正确答案的样本。遵循 Singh et al.(2023),我们然后将 CoT 子采样为每个问题 10 个随机唯一的 CoT,以平衡问题的难度。微调的 epoch 数设置为 20,这足够大以确保训练收敛。如§4.2 所述, Online-ST 基线试图模仿与 ReFT 相同的设置。我们具有相同的预热过程,并且超参数设置大致与 ReFT 相同。
奖励模型重新排序:遵循 Cobbe et al.(2021); Uesato et al.(2022),我们训练一个奖励模型(RM)来确定 CoT 的正确性。为了构建 RM 训练数据,我们使用预热阶段的模型并执行采样,以获得训练集中每个问题的 100 个 CoT。CoT 被去重,并且可以通过将提取的答案与真实答案进行比较来获得二进制标签。
作为一种常见做法,奖励模型是一个语言模型,从最佳 SFT 检查点初始化(Cobbe et al., 2021; Ouyang et al., 2022)。与基于结果的奖励模型(ORM)(Uesato et al., 2022)类似,奖励模型被训练为预测一个二进制标签,指示"正确"或"错误"的解决方案。一旦输入通过奖励模型,分类将在最后一个 Token 的隐藏状态上进行线性分类。最后,选择具有最高"正确"分数的候选解决方案作为最终答案。我们使用批量大小为 24、最大长度为 700 和线性学习率计划训练 RM 模型 3 个 epoch,预热期为 10% ,最大学习率为 1e-6。
评估:我们在所有数据集上报告 N-CoT 和 P-CoT 的值准确性。对于多数投票和重新排序(表 4),我们采样 100 个 CoT 进行评估。在投票中,选择具有多数计数的有效答案作为计算准确性的最终答案。在重新排序中,我们选择具有最高分数的 CoT 并提取答案。
4.4 结果
ReFT 优于 SFT:表 2 比较了基线和提出的 ReFT 在 GSM8K、SVAMP 和 MathQA 数据集上的性能。我们可以观察到, ReFT 在除 MathQA N-CoT 外的所有数据集上始终优于 SFT。具体来说,我们在 GSM8K N-CoT 和 P-CoT 上分别比 SFT 提高了近 10 分和 12 分。平均而言,我们在所有数据集上使用 CodeLLAMA 在 N-CoT 和 P-CoT 上分别提高了 6.7 分和 7.4 分。值得注意的是, ReFT 中没有使用额外的注释或奖励模型。如此强劲的结果展示了 ReFT 的强大泛化能力(见分析§5),并为进一步探索训练数据与强化学习的潜力提供了巨大的空间(Lu et al., 2023)。
离线自训练包括从初始策略中采样数据进行微调。我们可以看到,这个简单的基线可以比 SFT 提高性能(He et al., 2020; Gulcehre et al., 2023),但改进远远落后于 ReFT。这种比较表明,"探索"在 ReFT 中对于获得良好性能至关重要。尽管在线自训练在 Galactica 上取得了一些改进,但平均而言仍远远落后于 ReFT。这一结果表明,错误实例对于指导模型进行更好的探索也非常重要。与自训练的比较还表明,提出的具有在线采样和强化学习的方法优于标准的数据增强方法。
MathQA 的奖励黑客:我们对 MathQA 上的负面结果的调查表明, ReFT 在多选题训练期间遭受奖励黑客(Skalse et al., 2022)。图 3 展示了采样解决方案如何产生"不准确的奖励",这使得 RL 训练受到影响。正如我们所看到的,采样的 CoT 获得了一个错误的答案"172",这不是"18"和"22"乘积的一半。然而,最终的推理步骤仍然预测选项"C"作为最终答案,因为模型总是从{A, B, C, D, E}中预测一个选项,而不管中间 CoT 的正确性。因此,这种误导性的 CoT 将获得正奖励"1",并误导模型将其视为正确的 CoT。潜在的奖励黑客现象严重干扰了模型训练(Everitt et al., 2021)。这也是我们选择具有较长预热步骤的检查点用于 MathQA N-CoT 以减少奖励黑客效应的原因。
为了进一步展示多选题的负面影响,我们在 MathQA 变体上进行了实验,由 Jie and Lu (2023)提出, MathQA 上的奖励黑客效应。然而,开发一个可靠的基于过程的奖励模型是昂贵的,并且需要广泛的手动注释推理步骤。认识到这些挑战,我们认为控制奖励黑客及其分析是未来工作中需要解决的重要问题。
多数投票和重新排序使 ReFT 受益:遵循 Wang et al.(2023b); Uesato et al.(2022); Lightman et al.(2023),我们还执行了多数投票和奖励模型重新排序,以展示 ReFT 可以从这些常见技术中受益。具体来说,我们从 SFT 和 ReFT 策略中执行采样。我们为每个问题采样 100 个 CoT 解决方案,并使用§4.3 中描述的奖励模型执行重新排序。表 4 中的结果表明, ReFT 通过奖励模型重新排序在 GSM8K 上始终表现最佳。ReFT + Voting 在所有设置中平均比 SFT + Voting 高出 8.6 分。ReFT 与重新排序相比, SFT 与重新排序相比,平均高出 3 分以上。
与现有的开源方法(Luo et al., 2023; Wang et al., 2023; Yue et al., 2023)(表 4 底部)相比,我们最好的 P-CoT 变体在 GSM8K 上实现了最佳性能,准确率为 81.2。此外,这些方法主要包括从 ChatGPT 生成的额外数据,并在微调期间进行蒸馏。相比之下,我们通过挖掘现有训练数据的潜力并推动策略性能的极限来改进策略本身。我们在表 4 中报告的最佳结果,即 CodeLLAMA + ReFT + Reranking with P-CoT 设置,甚至超过了 GPT-3.5-turbo。然而,我们使用了一个只有 7B 大小的模型获得了这一结果。
小模型实验:直观上,探索可能会导致小语言模型的不完美演示。我们在 Galactica-125M、Codeparrot-small 和 Codegen-350M 上使用 P-CoT 数据进行了实验。表 5 展示了 SFT 和 ReFT 之间的性能比较。令人惊讶的是, ReFT 在三个数据集上仍然优于 SFT。这些改进展示了 ReFT 在探索合理程序时的鲁棒性。
消融研究:我们使用 CodeLLAMA 在 GSM8K P-CoT 上进行了消融研究(表 6)。如果没有部分奖励, ReFT 获得的准确率较低,为 74.4,但仍然比 SFT 好得多。如§3.1 所述,这种部分奖励可以帮助减少训练期间稀疏奖励的影响(Trott et al., 2019)。此外,如果我们将 KL 系数
5. 分析
泛化:图 4 展示了使用 CodeLLAMA 作为基础模型在 GSM8K P-CoT 上训练的 ReFT 的平均奖励、评估准确性和 KL 散度。SFT 在接近 40 个 epoch 时收敛并过拟合。然而,我们可以看到 ReFT 策略在 40 个 epoch 时的平均奖励约为 80% 到 90% ,并且值准确性也在增加。此外,我们可以看到 KL 散度(图 4 (c))在开始时非常大,然后保持在 0 到 10 之间的合理值。稳定的 KL 散度表明我们的策略在包含适当程序的空间内进行探索。潜在的强化学习机制大大提高了 ReFT 的泛化能力(Brown et al., 2020)。
定性评估:我们进行了人工评估,以定性评估 SFT 模型、预热检查点和 ReFT 模型的输出。评估使用 50 个问题,并采样 GSM8K 测试集中所有三个模型都能正确解决的解决方案。我们要求四位不同的注释者根据以下标准对推理路径进行评分,每个标准评分从 0 到 1。
- 逻辑:评估导致答案的逻辑是否正确。
- 命名:评估变量是否传达了适当且合理的语义。
- 紧凑性:评估推理路径是否包含冗余信息。
完美得分为 3 表示在这三个维度上表现良好。为了确保评估的公正性和忠实性,我们严格遵循以下设置:(1)每个推理路径的来源(来自 SFT、预热或 ReFT)被匿名化,以防止注释者偏见。(2)四位不同的注释者负责不同的样本部分。
如表 7 所示,尽管总体得分非常接近,但 ReFT 的表现略优于 SFT,并且优于预热变体。请注意, SFT 本质上是训练来学习真实答案的,因此很可能获得高分。这种比较分析强调了 ReFT 在生成准确且语义一致的推理路径方面的鲁棒性。
ReFT 何时超越 SFT?:为了进一步研究 ReFT 和 SFT 之间的关系,我们使用不同数量的 SFT 预热步骤进行 ReFT 训练。图 5 展示了不同 ReFT 变体与 SFT 的值准确性。具体来说,如果预热步骤为 3,则意味着策略从第 3 个 epoch 的 SFT 检查点初始化。我们可以看到,所有 ReFT 策略的性能在预热后立即下降,直到训练 epoch 达到 8 左右。因为共享价值模型中的线性层是随机初始化的,可能需要几个 epoch 来调整分布。从第 30 个 epoch 开始, SFT 收敛,所有 ReFT 变体仍在改进。我们还可以看到,所有变体都显著优于 SFT,并且没有明显的特定 ReFT 变体的优势。
结论
我们引入了强化微调(ReFT)作为一种新的方法,用于微调模型以解决数学问题。与 SFT 相比, ReFT 通过探索多个 CoT 注释来优化不可微的目标,而不是依赖于单个注释。
通过在三个数据集上使用两个基础模型进行广泛的实验,我们证明了 ReFT 在性能和泛化能力方面优于 SFT。此外,我们展示了使用 ReFT 训练的模型与多数投票(Wang et al., 2023)和奖励模型重新排序(Cobbe et al., 2021; Uesato et al., 2022)等技术的兼容性。
此外, ReFT 在数学问题求解方面表现出优于几个公开可用的开源模型的性能。这证明了 ReFT 方法的有效性和实用价值。
6. 未来工作
我们首次尝试将强化学习,特别是 PPO 算法(Schulman et al., 2017),应用于 LLMs 的微调以解决数学问题。我们的未来工作包括利用离线强化学习技术(Levine et al., 2020; Gulcehre et al., 2023),开发一种无预热方法以提高训练效率和性能,从而缩小与重新排序方法的差距。此外, Lightman et al.(2023)建议,训练有素的基于过程的奖励模型(PRM)可以显著提高性能。因此,值得探索在强化学习训练中实现基于过程的奖励。最后,由于 ReFT 是一种多功能方法,我们打算将其应用于更一般的推理任务,其中推理可以通过 CoT 形式化。