Skip to content

RLHF 中的 PPO 代码拆解

1. RLHF 三阶段

RLHF 最突出的应用之一是使大语言模型能够与复杂的人类价值观对齐, 让大语言模型 (LLM) 变得更靠谱、更精准、更合乎伦理。

根据 OpenAI 的思路, RLHF 分为三步:

1.1 Supervised Fine-Tuning(SFT)

在步骤 SFT 中,采用有监督的方式对预训练的语言模型进行微调。这又被称为行为克隆(Behavioral Cloning,简称 BC),即直接使用专家的行为数据(例如,专家在特定情况下采取的动作)来训练模型。在行为克隆中,模型的目标是尽可能地复制专家的行为,而不是尝试优化某种奖励函数,所以它可能无法处理那些专家数据中没有覆盖到的情况,因为它完全依赖于专家的行为数据。

1.2 Reward Model (RM)

RM,奖励模型(Reward Model)的目标是训练一个模型来适应人类的偏好。在这个阶段,首先从提示库中进行采样,并使用大型语言模型生成多个响应。然后,人工对这些响应进行排名,根据这些排名训练一个奖励模型。

奖励模型的目标是学习人类对于不同响应的偏好,并将这些偏好编码到模型中。这样,奖励模型可以用来为模型生成的新响应打分,从而在后续的训练中引导模型生成更符合人类偏好的内容。这种方式不仅能帮助模型处理训练数据中未覆盖的情况,也能减少模型生成不确定或模棱两可的回答,从而打破行为克隆的影响。

1.3 RL & Policy Optimization (RLHF)

RLHF 通过引入奖励信号来调整模型的行为,使模型生成的内容更符合人类的偏好。具体来说,在训练过程中,通过最大化预期奖励来调整模型的策略,使模型在选择行为时更倾向于选择可以得到更高奖励的行为。

在这个阶段中,我们首先使用在第一阶段训练的有监督微调模型和第二阶段训练的奖励模型来生成一个初始的策略。然后,我们使用 PPO 算法来调整这个策略,使模型在生成内容时更考虑人类的偏好。通过这个阶段的训练,模型不仅可以理解人类的语言,还可以理解人类的偏好,并生成更符合人类偏好的内容。

2. Reward 模型训练

在强化学习阶段,用到的 Reward Model 和 Critic Model 都使用同一个模型初始化,因此在训练 reward 模型的过程中,也是在训练 Critic Model。

Reward Model 相较于原始的 SFT Model,在后面加上了一个 value head, value head 是一个 Linear,输入维度为模型的 hidden_dim,输出维度为 1,输出表示模型预测每一字符获取的得分。

奖励模型的输入是 Prompt+Answer 的形式,让模型学会对 Prompt+Answer 进行打分。奖励模型最后一层隐藏层的输出维度为(B,L,D),通过一个 D x 1 的全连接层将维度变为(B, L),在 L 这个维度上,第 i 个位置的数据表示:从第 i 个位置到最后一个位置输出所能获得的奖励分值的累加和(和 DQN 里边的 Q 值一个意义)。

python
# huggingface 模型返回值是个list,第0位是模型最后输出的hideen state
hidden_states = transformer_outputs[0]
# v_head为Dx1的全连接网络对最后一维压缩
rewards = self.v_head(hidden_states).squeeze(-1)

对于奖励模型来说,目标是给一个句子进行打分,按理说每个句子对应一个分值就行了,但是目前对于长度为 L 的句子,奖励模型输出了 L 个值。DeepSpeed-Chat 中使用最后一个字符的得分作为整个 Response 的得分(当然也可以使用整个句子中每个字符的平均分作为整体的得分)。

奖励模型训练优化采用 Pair Wiss Loss,即同时输入模型关于同一个问题的两个回答,让模型学会这两个句子哪个分高哪个分低。之所以如此训练是因为,在给奖励模型进行数据标注的过程中,给同一个问题的不同回答量化的打分比较难,但是对他们进行排序相对简单,代码如下:

python
# 同一个batch里边的句子需要等长,短句后边会被padding
# [divergence_ind:end_ind] 索引了padding前一个位置的输出分值
# chosen_reward 是同一个句子pair里分数高的句子,r_truncated_reward是句子pair里分数低的句子
c_truncated_reward = chosen_reward[divergence_ind:end_ind]
r_truncated_reward = rejected_reward[divergence_ind:end_ind]

Pair wise loss 代码如下,如果给 pair 里边好的句子打分高(c_truncated_reward),坏的句子(r_truncated_reward)打分低, loss 就会小。

python
loss += -torch.log(torch.sigmoid(c_truncated_reward - r_truncated_reward)).mean()

在训练强化学习的过程中,会用到 Reward Model (Critic Model,再次提醒, Critic Model 和 Reward Model 是同一个模型的两个副本)的推理过程,通过调用 Forward_value 实现,具体代码如下,返回的值中有两种值, values 表示每个位置 i,从第 i 个位置到最后一个位置的奖励累加值,供强化学习过程中 Critic Model 使用;“chosen_end_scores ”指的是对每个 Prompt+Answer 的打分,供 Reward Model 使用。

3. RLHF 的整体架构

PPO 是一种用于训练强化学习模型的算法。它可以用于调整语言模型,使得模型生成的结果更符合人类的偏好。具体来说,过程可以分为三个阶段:

  • Rollout and Evaluation:

    • 在这个阶段,我们从 Prompt 库里抽样,使用语言模型生成 Response. 将 Prompt 和 Response 一起输入奖励模型(Reward Model, RM)得到奖励得分。
  • Make experience:

    • 在这个阶段,我们收集了一系列的“经验”,即模型的行为和对应的奖励。这些经验包括了模型生成的 response 以及对应的奖励得分。这些经验将被用于下一步的优化过程。
  • Optimization:

    • 在这个阶段,我们使用收集到的经验来更新模型的参数。具体来说,我们使用 PPO 算法来调整模型的参数,使得模型生成的 Response 的奖励得分能够增加。PPO 算法的一个关键特性是它尝试保持模型的行为不会发生太大的改变,这有助于保证模型的稳定性。

通过这三个阶段的微调,我们可以使得语言模型的输出更符合我们的期望,例如更有创造性,更符合人类的偏好等。

RLHF 基于 A2C 方法,这一步包含了四个模型: Actor Model, Ref Model, Reward Model 和 Critic Model。

  • Actor Model:由 SFT 之后的模型初始化而来。作为策略(policy)模型,用于接收上文,做出动作,预测下一个字符。学习完毕之后,我们最终使用的就是这个模型。

  • Reference Model:和 Actor Model 同样初始化自 SFT Model,训练过程中冻结参数,用于和 Actor Model 做对比,保证模型不要偏离原始 SFT Model 太多。

  • Reward Model:作为环境(Env),训练过程中冻结参数,针对每一个状态,给出奖励分数。

  • Critic Model:由 Reward Model 初始化而来,用于近似价值函数,输入为状态 S,估计当前状态的价值 V。

img

4. 代码拆解

4.1 Rollout

在强化学习中, Rollout 是指在给定的策略下模拟环境的过程。在 PPO 中, Rollout 的过程对应于根据当前的语言模型(策略)生成文本(轨迹)。

这个过程依赖于在 prompt 库中抽取的一个 batch 的数据 Batch Prompt 和当前的语言模型 LM。

语言模型接收一个 prompt 作为输入,并生成一个 Response。这些 Response 就构成了我们的"轨迹"。

输入: Batch Prompt, Actor LM

输出: Prompt+Response

python
def _generate_sequence(self, prompts, mask, step):
  max_min_length = self.max_answer_seq_len + prompts.shape[1]
  # This has been added due to a probability/nan error that happens after
  # meta-llama/Llama-2-7b-hf enabled do_sample:
  # https://huggingface.co/meta-llama/Llama-2-7b-hf/commit/6fdf2e60f86ff2481f2241aaee459f85b5b0bbb9
  if self.actor_model.module.config.model_type == "llama":
    kwargs = dict(do_sample=False)
  else:
    kwargs = dict()
    with torch.no_grad():
      seq = self.actor_model.module.generate(
        prompts,
        attention_mask=mask,
        max_length=max_min_length,
        pad_token_id=self.tokenizer.pad_token_id,
        synced_gpus=self.z3_enabled,
        **kwargs)
      return seq

4.2 Evaluate

Evaluate 是在强化学习中对生成的轨迹(在我们的例子中就是文本)进行评估的步骤。在 PPO 中,这个评估过程由一个 RM 模型来完成,来为每一对 Prompt+Response 产生一个标量奖励值,这个值表示生成的轨迹的好坏,优化过程会试图最大化这个值。

输入输出

输入: Prompt+Response、RM

输出: Reward

python
reward_score = self.reward_model.forward_value(
  seq, attention_mask,
  prompt_length=self.prompt_length)['chosen_end_scores'].detach(
)

4.3 Old Policy Sampling

这个步骤是 make experience 的过程,计算并存储旧策略的概率、价值等值,来为后面更新的过程服务。

4.3.1 Old Logprobs

这个步骤中,我们从“旧的”策略,即在这个 batch 数据中初始的 LM (initial actor)中计算每个 token 在旧的策略下的概率 Old Logprobs。

这个步骤的重要性在于,我们在优化策略的时候,需要比较新旧策略下动作的概率,以此来更新我们的策略。因此,我们需要存储旧的策略的动作概率作为参考。

之所以要比较这个概率是为了算一个叫 ratio 的值,用这个值更新策略梯度,能限制更新率、

4.3.2 Old Values

Old Values 的含义是旧策略中每个时间步(每个 token 的预测结果)的价值,这个值由 critic 网络进行预测, critic 网络就是 actor 上加几个线性层能够给每个 token 预测一个值。需要这个值的原因是 advantage 的计算依赖于 Old Values。

4.3.3 Ref Logprobs

Ref Logprobs 的含义是最最原始的 LM 对于每个时间步的概率预测,计算这个值的目的是限制 actor 的更新,防止其偏离原始模型太远。

python
def generate_experience(self, prompts, mask, step):
    self.eval()
    seq = self._generate_sequence(prompts, mask, step)
    self.train()
    pad_token_id = self.tokenizer.pad_token_id
    attention_mask = seq.not_equal(pad_token_id).long()
    with torch.no_grad():
        output = self.actor_model(seq, attention_mask=attention_mask)
        output_ref = self.Ref Model(seq, attention_mask=attention_mask)
        reward_score = self.reward_model.forward_value(
            seq, attention_mask,
            prompt_length=self.prompt_length)['chosen_end_scores'].detach(
            )
        values = self.critic_model.forward_value(
            seq, attention_mask, return_value_only=True).detach()[:, :-1]

    logits = output.logits
    logits_ref = output_ref.logits

    return {
        'prompts': prompts,
        'logprobs': gather_log_probs(logits[:, :-1, :], seq[:, 1:]),
        'ref_logprobs': gather_log_probs(logits_ref[:, :-1, :], seq[:,
                                                                    1:]),
        'value': values,
        'rewards': reward_score,
        'input_ids': seq,
        "attention_mask": attention_mask
    }

def gather_log_probs(logits, labels):
	log_probs = F.log_softmax(logits, dim=-1)
  log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
 	return log_probs_labels.squeeze(-1)

4.4 KL Penalty

在 PPO 实现中, KL Penalty 是在模型优化过程中添加的一个惩罚项,用于保证经过强化学习后的模型(新策略 actor)不会过于偏离原始预训练模型(ref model)。

具体来说,首先使用微调过程中的模型(新策略 actor)和预训练模型(ref model)来计算序列中每个词的对数概率。然后,我们计算两个模型输出之间的 Kullback-Leibler (KL) 散度,这是一种衡量两个概率分布差异的方法。该 KL 散度被用作一个额外的奖励信号,并作为优化过程中的惩罚项,用于确保微调后的模型生成的响应不会偏离太远于预训练模型。这样可以保证模型在微调的过程中不会丢失预训练模型学习到的有用的知识和模式。

在图中的 KL Penalty 步骤中,我们会在 reward 上增加这个 kl 惩罚项来实现这个过程。

输入: Ref Logprobs、Old Logprobs、Reward

输出: Token Reward

python
def compute_rewards(self, prompts, log_probs, ref_log_probs, reward_score,
                    action_mask):
  # 计算 KL 散度,即模型的对数概率与参考模型的对数概率之间的差值
  kl_divergence_estimate = log_probs - ref_log_probs
  # 计算非得分奖励,即 KL 散度乘以 KL 控制器值的负值
  non_score_reward = -self.kl_ctl * kl_divergence_estimate

  # 复制非得分奖励为新的奖励
  rewards = kl_divergence_estimate
  # response 的起始位置和结束位置
  start = prompts.shape[1] - 1
  ends = start + action_mask[:, start:].sum(1) + 1
  # 对 reward 进行裁剪
  reward_clip = torch.clamp(reward_score, -self.clip_reward_value,
                            self.clip_reward_value)
  batch_size = log_probs.shape[0]
  for j in range(batch_size):
    # 对于最后一个非mask部分的token,其奖励是偏好模型的得分加上 KL 散度
    rewards[j, start:ends[j]][-1] += reward_clip[j]
  return rewards

4.5 Generalized Advantage Estimation (GAE)

在强化学习中,我们会使用一个 advantage 的概念,来衡量每个时间步动作的价值,他的本质含义是当前采样到的动作(生成的文本)的价值比平均的数学期望价值高的部分。

而在 PPO 中,我们一般使用 GAE 来进行 advantage 的计算。GAE 是一种多步优势估计方法。它通过引入一个权衡参数λ,在单步 TD 误差和多步 TD 误差之间进行权衡,从而减小估计的方差,提高学习的稳定性。

GAE 的主要目标是希望找到一种策略,使得从当前状态开始,采取该策略能够获得的未来奖励最大, GAE 使用了一种名为 TD 误差的概念,这是一种预测未来奖励的方法。然后, GAE 将这些 TD 误差组合成一个加权和,权重由一个衰减因子λ决定。当λ=0 时, GAE 就退化为普通的优势函数估计;当λ=1 时, GAE 就变成了一种名为"蒙特卡洛"的方法。总的来说, GAE 的本质就是把优势估计为后续时间步 TD 误差的加权和。

输入: Token Reward、Old Values

输出: Advantages、Returns

python
def get_advantages_and_returns(self, values, rewards, start):
  lastgaelam = 0
  advantages_reversed = []
  length = rewards.size()[-1]
  # 从后往前遍历整个生成的序列
  for t in reversed(range(start, length)):
    # 计算下一个状态的价值,如果当前状态已经是最后一个状态,则下一个状态的价值为0
    nextvalues = values[:, t + 1] if t < length - 1 else 0.0
    # 计算 δ,它是奖励加上衰减后的下一个状态的价值,然后减去当前状态的价值
    delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
    # 使用 δ 更新 lastgaelam,这是 GAE 公式的一部分
    lastgaelam = delta + self.gamma * self.lam * lastgaelam
    # 将计算的优势值添加到优势值列表中
    advantages_reversed.append(lastgaelam)

  # 将优势值列表反向并转换为张量
  advantages = torch.stack(advantages_reversed[::-1], dim=1)
  # 计算回报值,它是优势值加上状态值
  returns = advantages + values[:, start:]
  return advantages.detach(), returns

4.6 New Policy Sampling

New Policy Sampling 是 PPO 算法中的一个关键步骤。在 PPO 中,策略优化的过程涉及到两个策略:一个是"旧的"策略,这是我们在开始每次优化迭代时使用的策略,另一个是"新的"策略,这是我们在优化过程中不断更新的策略。

New Policy Sampling 就是在新的策略(更新后的 actor)下对轨迹(文本)计算概率的过程。这个信息会被用于计算"Actor Loss",也就是策略梯度的损失。在我们的步骤中, Old Logprobs 是一次性一个 batch 的数据计算的,这是因为在一个 batch 中旧策略都是不变的;而 New Logprobs 是一个 mini batch 计算一次,这是因为新策略每个 mini batch 变一次。

此外这个步骤还会输出 New Values 和 Logits 分别用于 critic loss 和 entropy loss 的计算。

输入输出

输入: Ref Model、Actor、Critic

输出: New Logprobs、New Values、Logits

python
### process the new outputs
batch = {'input_ids': seq, "attention_mask": attention_mask}
actor_prob = self.actor_model(**batch, use_cache=False).logits
actor_log_prob = gather_log_probs(actor_prob[:, :-1, :], seq[:, 1:])
value = self.critic_model.forward_value(**batch,
                                        return_value_only=True,
                                        use_cache=False)[:, :-1]

4.7 Critic Loss

在 Actor-Critic 强化学习算法框架中, Critic 模型的任务是估计状态的价值函数,也就是预测从当前状态开始,通过遵循某个策略,期望能得到的总回报。Critic 的训练目标是最小化它的预测价值与实际回报之间的差距。

Critic Loss 通常通过均方误差(Mean Squared Error, MSE)来计算。对于每一个状态,我们都有一个由 Critic 预测出的预期回报值 V(s),以及一个真实的回报值 G(returns)。Critic Loss 就是这两个值之间差的平方。在一个批量的数据中, Critic Loss 是所有状态的这个差的平方的平均值。公式如下: Critic Loss=E[(V(s)G)2]

其中 E[] 表示期望值,V(s) 是 Critic 对状态 s(这个时间步的 token)的价值预测(New Values),G 是真实的回报值(Returns)。

通过最小化 Critic Loss, Critic 的预测能力会逐渐提升。因为 Critic 的预测结果会被用来估计每个行动的优势(Advantage),这个优势值又会被用来计算策略的更新(Actor Loss)。

输入: New Values、Old_values、Returns

输出:梯度更新

python
def critic_loss_fn(self, values, old_values, returns, mask):
    ## value loss
    # 将价值函数的预测值裁剪到一个范围内
    values_clipped = torch.clamp(
        values,
        old_values - self.cliprange_value,
        old_values + self.cliprange_value,
    )
    if self.compute_fp32_loss:
        values = values.float()
        values_clipped = values_clipped.float()
    # 计算裁剪前和裁剪后的价值函数损失
    vf_loss1 = (values - returns)**2
    vf_loss2 = (values_clipped - returns)**2
    # 最终的价值函数损失是裁剪前和裁剪后损失的最大值的平均值的一半
    vf_loss = 0.5 * torch.sum(
        torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()
    return vf_loss

代码的作用是将 values 裁剪到一个范围内,这个范围是由 old_values - cliprange_value 和 old_values + cliprange_value 确定的,其中 old_values 是初始的价值函数预测值,目的是为了避免 value 的变化太快。

4.8 Actor Loss

在深度强化学习中,我们通常有两个主要的组成部分: Actor 和 Critic。Actor 是策略,它决定文本会被怎么样生成。Critic 则是我们的价值函数估计器,它预测我们从当前状态开始,如果遵循当前的策略,能够得到的未来回报。

Actor Loss 是我们用来优化 Actor 的损失函数。它的计算通常基于优势函数,优势函数表示在给定的状态下采取某个行动比遵循当前策略的期望回报要好多少。

在 PPO 中,我们使用一种称为 Importance Sampling 的技术来计算 Actor Loss。我们比较了在旧策略和新策略下行动的概率(Old Logprobs, New Logprobs),然后将这个比值(也就是 Importance Sampling 的权重)与优势函数 Advantages 相乘,得到了对 Actor Loss 的一个估计。

PPO 的 Actor loss 如下:

L(s,a,θk,θ)=min(πθ(a|s)πθk(a|s)Aπθk(s,a),clip(πθ(a|s)πθk(a|s),1ϵ,1+ϵ)Aπθk(s,a))

rθ=πθ(a|s)πθk(a|s)是新旧策略的比率。Aπθk(s,a)是优势函数, clip 是剪裁函数,它将πθ(a|s)πθk(a|s)限制在[1ϵ,1+ϵ] 范围内,ϵ 是一个超参数,通常设置为 0.1 或 0.2。

这个损失函数的目标是最大化策略的期望回报,同时限制新旧策略之间的差异。当新旧策略的比率 rθ超出[1ϵ,1+ϵ] 范围时,剪裁函数会限制其影响,防止策略更新过大。

输入: Old Logprobs, New Logprobs、Advantages

输出:梯度更新

python
def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
    ## policy gradient loss
    # 计算新旧策略下概率的比值
    log_ratio = (logprobs - old_logprobs) * mask
    ratio = torch.exp(log_ratio)
    # 计算未截断的策略梯度损失
    pg_loss1 = -advantages * ratio
    # 计算截断的策略梯度损失
    pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,
                                            1.0 + self.cliprange)
    # 选择两者中较大的作为最终的策略梯度损失
    pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
    return pg_loss

Reference:

  • 图解大模型 RLHF 系列之:人人都能看懂的 PPO 原理与源码解读

Maintained by Robin