Rollout likelihood generalization of maximum likelihood reinforcement learning
Topic: Machine Learning
Date:
I’ve always wondered what happens if one applies RL to supervisable tasks. For example, given a binary classification task with pairs $(x_k, y_k)$, maximize accuracy as the reward.
No one does this, presumably for good reasons; how does such a model compare to a normal model trained using cross-entropy?
Turns out that a highly impressive recent paper from CMU goes exactly down this rabbit hole, and comes up with actionable, principled insights. They point out that correctness-based RL is optimizing

$$J_{\mathrm{RL}}(\theta) = \mathbb{E}_{(x,y)}\big[p_\theta(x,y)\big],$$

where $p_\theta(x,y)$ is the per-prompt pass rate, rather than the maximum-likelihood objective $\mathbb{E}_{(x,y)}\big[\log p_\theta(x,y)\big]$, whose gradient upweights low-pass-rate inputs by a factor of $1/p_\theta(x,y)$.
What’s in the paper
- A compute-indexed family of truncated objectives interpolating between RL and ML,
- A simple unbiased estimator whose expected gradient matches the truncated objective,
- Experiments demonstrating strong Pareto improvements over common baselines (e.g. GRPO/RLOO), including up to ~20x rollout-scaling efficiency gains in their reasoning setups.
What’s new here
Luckily for our understand-by-practice purposes, the authors focus on the binary, discrete-reward setting. This post unpacks the following qualitative behaviors of MaxRL:
- the ML objective $\log p$ is sharper than direct (expected-reward) objectives, admitting a natural bounded truncation that interpolates RL -> ML with rollout budget,
- fixing a prompt/sample, it upweights the most successful rollouts (soft-max / log-sum-exp behavior),
- marginalizing over rollouts, it upweights the most difficult prompts via inverse-probability ($1/p$) reweighting.
We also put forward a generalization that abstracts at the level of per-rollout likelihood, admitting application to e.g. regression tasks. We expect this generalization to be highly applicable to regression RL tasks with low signal-to-noise ratio.
Math is cheap, just show me the code
Contents
- Preamble / ramble on MLE
- Formulation, notation
- Direct RL as a Maximum Likelihood approximation
- Gradient estimators
- Pseudocode: generalized MaxRL
Preamble / ramble on MLE
Maximum likelihood estimation (MLE) is a near-axiomatic principle: given data and a parametric family, choose parameters that maximize the probability (likelihood) of observed data.
- Classification is MLE: data-label pairs $(x_k, y_k)$, model $p_\theta(y \mid x)$; minimizing cross-entropy maximizes the likelihood.
- Regression is MLE: MSE corresponds to Gaussian-noise MLE; L1 corresponds to Laplace-noise MLE.
- VAEs are MLE-ish too: maximize a tractable lower bound on log-likelihood, while tightening that lower bound.
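To make the regression bullet concrete, here is a minimal sketch (plain Python, toy numbers of my own) checking that the Gaussian negative log-likelihood is exactly half-MSE plus an additive constant:

```python
import math

def gaussian_nll(y, mu, sigma=1.0):
    # Negative log-likelihood of y under N(mu, sigma^2)
    return 0.5 * math.log(2 * math.pi * sigma**2) + (y - mu) ** 2 / (2 * sigma**2)

ys, mus = [0.3, -1.2, 2.0], [0.1, -1.0, 1.5]
nll = sum(gaussian_nll(y, m) for y, m in zip(ys, mus))
mse_half = sum((y - m) ** 2 for y, m in zip(ys, mus)) / 2
const = 1.5 * math.log(2 * math.pi)  # 3 examples x 0.5 * log(2*pi)
assert abs(nll - (mse_half + const)) < 1e-12
```

Since the constant does not depend on the predictions, minimizing MSE and minimizing the Gaussian NLL pick the same parameters.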
Given the ubiquity of MLE, it’s surprising that Maximum-likelihood RL (MaxRL) is this recent.
Why do we supervise with cross-entropy?
Cross-entropy minimizes $-\log p_\theta(y \mid x)$ per example, and its gradient carries a $1/p_\theta(y \mid x)$ factor. So difficult cases (small $p_\theta(y \mid x)$) receive the largest updates.
Maximum likelihood upweights difficult samples aggressively, updating the model on the frontier of its understanding.
Addendum: I would not over-interpret this as the principle itself. It’s a side-effect interpretation. The principle is still MLE. One example of a canonical interpretation of MLE comes from an information-theory perspective, where MLE is a compression objective that minimizes expected code length under the model. In this lens, each example contributes $-\log_2 p_\theta(y \mid x)$ bits of code length.
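A toy illustration of the compression lens (numbers mine): an example the model assigns probability $1/2$ costs one bit of code length, and probability $2^{-10}$ costs ten:

```python
import math

def code_length_bits(p):
    # Ideal code length of an event the model assigns probability p
    return -math.log2(p)

assert code_length_bits(0.5) == 1.0
assert code_length_bits(2 ** -10) == 10.0
```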
Formulation, notation
Let’s start in the standard LLM-RL latent-generation setting.
- Data: pairs $(x, y) \sim \mathcal{D}$.
- Policy/sequence model: $\pi_\theta(z \mid x)$, where $z$ is a rollout.
- Rollout evaluation: decode/postprocess $z$ into a prediction $\hat{y}(z)$; compare to target $y$.
In binary reasoning with a verifier, the typical objective is

$$J_{\mathrm{RL}}(\theta) = \mathbb{E}_{(x,y)}\,\mathbb{E}_{z \sim \pi_\theta(\cdot \mid x)}\big[\mathbf{1}[\hat{y}(z) = y]\big].$$
So far this is standard RL setup. Now the useful abstraction.
- A rollout $z$ defines a predictive distribution over labels/targets.
- Define the per-rollout likelihood $\ell(y, z) \in [0, 1]$: the probability the rollout assigns to the target.
- Marginalize over rollouts: $p_\theta(y \mid x) = \mathbb{E}_{z \sim \pi_\theta(\cdot \mid x)}\big[\ell(y, z)\big]$.
- Score function: $\nabla_\theta \log \pi_\theta(z \mid x)$.

This score vector is almost always the only intermediate through which the policy parameters $\theta$ enter the gradient estimators below.
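To make the abstraction concrete, a minimal sketch with hypothetical stand-ins `sample_rollout` and `likelihood` (a toy policy that guesses the right label 30% of the time): the marginal $p_\theta(y \mid x)$ is just a Monte Carlo average of per-rollout likelihoods:

```python
import random

random.seed(0)

def sample_rollout(x):
    # Toy "policy": predicts label 1 with probability 0.3, else 0
    return 1 if random.random() < 0.3 else 0

def likelihood(y, z):
    # Binary per-rollout likelihood: 1 if the rollout's prediction matches y
    return 1.0 if z == y else 0.0

x, y = "prompt", 1
N = 10_000
p_hat = sum(likelihood(y, sample_rollout(x)) for _ in range(N)) / N
assert abs(p_hat - 0.3) < 0.02  # Monte Carlo estimate of the pass rate p
```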
Direct RL as a Maximum Likelihood approximation
The maximum likelihood objective (for our latent-generation likelihood) is

$$J_{\mathrm{ML}}(\theta) = \mathbb{E}_{(x,y)}\big[\log p_\theta(y \mid x)\big] = \mathbb{E}_{(x,y)}\Big[\log \mathbb{E}_{z \sim \pi_\theta(\cdot \mid x)}\,\ell(y, z)\Big].$$
This is the exact analogue of cross-entropy: the log sits outside the inner expectation because the model’s likelihood of the target is itself a marginal over rollouts.
Binary reasoning setup
In the binary reasoning setup, the typical reward is the expected pass rate:

$$J_{\mathrm{RL}} = \mathbb{E}_{(x,y)}\big[p_\theta(y \mid x)\big], \qquad \ell(y, z) = \mathbf{1}[\hat{y}(z) = y].$$
Max-likelihood wants instead:

$$J_{\mathrm{ML}} = \mathbb{E}_{(x,y)}\big[\log p_\theta(y \mid x)\big].$$
This is exactly the paper’s core distinction: RL maximizes $\mathbb{E}[p]$, maximum likelihood maximizes $\mathbb{E}[\log p]$.
(Also: yes, this is just cross-entropy on the marginal predictive distribution $p_\theta(y \mid x)$.)
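A two-prompt toy comparison (numbers mine): the RL objective weights every prompt’s $\nabla_\theta p$ equally, while the ML gradient scales it by $1/p$, so the hard prompt here is upweighted 90x relative to the easy one:

```python
pass_rates = {"easy": 0.9, "hard": 0.01}

# d/dp of p is 1: RL treats both prompts' gradient contributions equally
rl_weight = {k: 1.0 for k in pass_rates}
# d/dp of log p is 1/p: ML upweights the low-pass-rate prompt
ml_weight = {k: 1.0 / p for k, p in pass_rates.items()}

assert abs(ml_weight["hard"] / ml_weight["easy"] - 90.0) < 1e-6
```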
Continuous regression
Specialize to a Gaussian noise model with fixed variance $\sigma^2$: each rollout $z$ yields a point prediction $\hat{y}(z)$, with per-rollout likelihood $\ell(y, z) \propto \exp\!\big(-(y - \hat{y}(z))^2 / 2\sigma^2\big)$.
Then

$$p_\theta(y \mid x) \propto \mathbb{E}_{z}\Big[\exp\!\big(-(y - \hat{y}(z))^2 / 2\sigma^2\big)\Big].$$
An intuitive per-rollout reward is negative MSE: $r(y, z) = -(y - \hat{y}(z))^2$.
A direct analogue would average these rollout rewards: $J_{\mathrm{RL}} = \mathbb{E}_{(x,y)}\,\mathbb{E}_z\big[-(y - \hat{y}(z))^2\big]$.
Maximum likelihood instead optimizes

$$J_{\mathrm{ML}} = \mathbb{E}_{(x,y)}\Big[\log \mathbb{E}_z \exp\!\big(-(y - \hat{y}(z))^2 / 2\sigma^2\big)\Big]$$

(up to constants): a log-sum-exp, rather than a simple average, over per-rollout MSE-derived terms. So within each sample, the objective reinforces the most successful trajectories; we’ll recover this same structure later in the general form.
(If you want this to literally satisfy $\ell(y, z) \in [0, 1]$, drop the Gaussian normalizing constant; the unnormalized kernel already lies in $[0, 1]$.)
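A small numeric illustration (toy numbers mine) of the log-sum-exp structure: with one near-perfect rollout among nine bad ones, the ML-style objective tracks the best rollout, while the averaged reward is dragged down by the failures:

```python
import math

def logmeanexp(vals):
    # Numerically stable log of the mean of exp(vals)
    m = max(vals)
    return m + math.log(sum(math.exp(v - m) for v in vals) / len(vals))

# Per-rollout terms -(y - yhat)^2 / (2 sigma^2), with sigma = 1:
# one near-perfect rollout, nine bad ones
terms = [-0.005] + [-50.0] * 9

avg = sum(terms) / len(terms)  # direct-RL-style average, dominated by failures
ml = logmeanexp(terms)         # ML-style log-mean-exp, tracks the best rollout
assert ml > -3 and avg < -40
```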
Putting them together: Jensen and Taylor
From here on, fix a pair $(x, y)$ and abbreviate $p = p_\theta(y \mid x)$ and $q = 1 - p$.
For $p \in (0, 1]$, expand

$$\log p = \log(1 - q) = -\sum_{k=1}^{\infty} \frac{q^k}{k}.$$
Note that this expansion is about $p = 1$ (i.e. $q = 0$): low-order truncations are accurate on easy prompts and only recover the full $1/p$ sharpness in the limit.
Truncating to order $T$ gives

$$J_T(p) = -\sum_{k=1}^{T} \frac{q^k}{k}.$$

- $T = 1$: $J_1 = p - 1$, that is RL / pass-rate training up to an additive constant.
- $T \to \infty$: $J_\infty = \log p$, that is exact maximum likelihood.
Differentiating:

$$\frac{\mathrm{d}J_T}{\mathrm{d}p} = \sum_{k=1}^{T} q^{k-1} = \sum_{k=0}^{T-1} q^k =: w_T(p) = \frac{1 - q^T}{p}.$$
As $T \to \infty$, $w_T(p) \to 1/p$ for $p > 0$: the exact ML gradient weight.
Therefore

$$\nabla_\theta J_T = \mathbb{E}_{(x,y)}\Big[w_T(p)\,\mathbb{E}_{z}\big[\ell(y, z)\,\nabla_\theta \log \pi_\theta(z \mid x)\big]\Big].$$
At this point we can already read off the two qualitative behaviors:
Remark 1 (per-prompt): inside the inner expectation, each rollout’s score is weighted by its likelihood $\ell(y, z)$.
- Binary case: only successful rollouts contribute ($\ell \in \{0, 1\}$).
- Gaussian regression: $\ell \propto \exp\!\big(-(y - \hat{y}(z))^2 / 2\sigma^2\big)$, so the best rollouts dominate.
Remark 2 (across prompts): the scalar multiplier $w_T(p)$ grows toward $1/p$, so low-pass-rate (difficult) prompts are upweighted.
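A quick numeric check (a sketch, not from the paper) that the truncated weight climbs from the RL weight at $T = 1$ to the ML weight $1/p$ as $T$ grows:

```python
def w_T(p, T):
    # Truncated ML gradient weight: sum_{k=0}^{T-1} (1-p)^k = (1 - (1-p)^T) / p
    return sum((1 - p) ** k for k in range(T))

p = 0.1
assert w_T(p, 1) == 1.0                    # T = 1: plain RL / pass-rate weighting
assert abs(w_T(p, 300) - 1 / p) < 1e-8     # large T: approaches 1/p = 10
```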
Gradient estimators
Here, you say: “all’s great! There’s this beautiful theory and a highly principled objective. How do we go about optimizing it?”
The MaxRL paper proposes an unbiased estimator for the truncated maximum-likelihood binary objective.
Below we recap the proof and give a per-rollout-likelihood generalization.
Binary case
In the binary setting, $\ell(y, z) = r(z) \in \{0, 1\}$ and $p$ is the pass rate; sample $N$ rollouts $z_1, \dots, z_N$ with rewards $r_j$.
- REINFORCE (pass@1) uses

$$\hat{g} = \frac{1}{N} \sum_{j=1}^{N} r_j\, \nabla_\theta \log \pi_\theta(z_j \mid x),$$

which is unbiased for $\nabla_\theta p$, i.e. the $T = 1$ truncation.
- MaxRL’s key estimator is: average scores over successful trajectories only,

$$\hat{g} = \frac{1}{|S|} \sum_{j \in S} \nabla_\theta \log \pi_\theta(z_j \mid x), \qquad S = \{j : r_j = 1\},$$

with $\hat{g} = 0$ when $S = \emptyset$. By symmetry, conditioned on $|S| \geq 1$ the average of successful scores has expectation $\nabla_\theta p / p$; multiplying by $\mathbb{P}(|S| \geq 1) = 1 - (1-p)^N$ gives $\mathbb{E}[\hat{g}] = w_N(p)\,\nabla_\theta p$. It’s an unbiased ML estimator at truncation order $T = N$, conditioning on the prompt not failing across the whole rollout budget.
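The claim can be sanity-checked on its scalar skeleton: replace every score vector by $1$, so the success-only average reduces to the indicator of at least one success, whose mean should be $p\,w_N(p) = 1 - (1-p)^N$. A small simulation (toy parameters mine):

```python
import random

random.seed(1)

def success_only_avg(rewards):
    # Scalar skeleton of the MaxRL estimator: average "scores" (all 1.0 here)
    # over successful rollouts; 0 when no rollout succeeds.
    s = [1.0 for r in rewards if r == 1]
    return sum(s) / len(s) if s else 0.0

p, N, trials = 0.2, 8, 200_000
est = sum(
    success_only_avg([1 if random.random() < p else 0 for _ in range(N)])
    for _ in range(trials)
) / trials
target = 1 - (1 - p) ** N  # = p * w_N(p), the order-N truncated weight times p
assert abs(est - target) < 0.01
```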
Generalization
Now we drop the assumption $\ell \in \{0, 1\}$:
- $\ell_j = \ell(y, z_j) \in [0, 1]$ is a per-rollout likelihood/reward signal,
- $p = \mathbb{E}_z[\ell(y, z)]$, $q = 1 - p$,
- the target gradient is $w_T(p)\,\mathbb{E}_z\big[\ell(y, z)\,\nabla_\theta \log \pi_\theta(z \mid x)\big]$.
We want to estimate the last quantity using $N$ i.i.d. rollouts $z_1, \dots, z_N$.
- We can estimate $w_T(p)$ and $\mathbb{E}_z[\ell\, \nabla_\theta \log \pi_\theta]$ unbiasedly from samples.
- But their product $w_T(p)\,\mathbb{E}_z[\ell\, \nabla_\theta \log \pi_\theta]$ is a product of unknowns.
- Plugging in sample estimates introduces bias, because the two factors are correlated.
The key here is the leave-one-out trick. Use one sample for the score term, and the remaining $N - 1$ samples, independent of it, to estimate $w_T(p)$.
Formally, suppose we have an estimator subroutine $\widehat{w}_T$ with $\mathbb{E}\big[\widehat{w}_T(\ell_1, \dots, \ell_{N-1})\big] = w_T(p)$. Define

$$\hat{g} = \frac{1}{N} \sum_{j=1}^{N} \widehat{w}_T(\ell_{-j})\, \ell_j\, \nabla_\theta \log \pi_\theta(z_j \mid x),$$

where $\ell_{-j}$ collects all likelihoods except the $j$-th.
We know that $\ell_{-j}$ is independent of the $j$-th rollout, so the expectation factorizes: $\mathbb{E}[\hat{g}] = w_T(p)\,\mathbb{E}_z[\ell\, \nabla_\theta \log \pi_\theta]$, as desired.
In addition, this is a particularly neat expression because it is, again, a simple reweighted average of the rollout scores! To reduce variance, we can subtract any constant baseline (since $\mathbb{E}_z[\nabla_\theta \log \pi_\theta(z \mid x)] = 0$).
It remains to solve the problem: given i.i.d. samples $\ell_1, \dots, \ell_{N-1} \in [0, 1]$ with mean $p$, estimate $w_T(p)$ without bias.
- Taking $q_i = 1 - \ell_i$, it suffices to obtain unbiased estimators for $q^k = (1 - p)^k$ for all $k \leq T - 1$.
- Fortunately, this turns out to be a textbook problem with a known MVUE (Minimum Variance Unbiased Estimator) using U-statistics: average $\prod_{m=1}^{k} q_{i_m}$ over all $k$-subsets of the samples, i.e. $e_k(q_1, \dots, q_{N-1}) / \binom{N-1}{k}$ with $e_k$ the elementary symmetric polynomial (you’re welcome to look it up; drop a comment!).
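A minimal sketch of why the U-statistic matters, for $k = 2$ with Bernoulli samples (toy parameters mine): the plug-in $\bar{q}^{\,2}$ overshoots $(1 - p)^2$ by $\mathrm{Var}(\bar{q})$, while the pairwise average $\binom{M}{2}^{-1}\sum_{i<j} q_i q_j$ is exactly unbiased:

```python
import random
from itertools import combinations

random.seed(2)

def plugin_q2(qs):
    # Biased plug-in: (sample mean)^2 has expectation q^2 + Var(qbar)
    qbar = sum(qs) / len(qs)
    return qbar ** 2

def ustat_q2(qs):
    # U-statistic: average q_i * q_j over distinct pairs, unbiased for q^2
    pairs = list(combinations(qs, 2))
    return sum(a * b for a, b in pairs) / len(pairs)

p, M, trials = 0.5, 4, 100_000
plug, u = 0.0, 0.0
for _ in range(trials):
    qs = [1.0 if random.random() > p else 0.0 for _ in range(M)]  # q_i = 1 - l_i
    plug += plugin_q2(qs) / trials
    u += ustat_q2(qs) / trials

truth = (1 - p) ** 2  # 0.25
assert abs(u - truth) < 0.005   # U-statistic: unbiased
assert plug - truth > 0.04      # plug-in: biased up by Var(qbar) = 1/16
```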
Pseudocode: generalized MaxRL
```python
# Generalized truncated maximum-likelihood RL for latent-generation likelihoods
#
# Inputs:
# - model m_theta(z | x)
# - per-rollout likelihood l(y, z) in [0, 1] (can be scaled)
# - truncation order T
# - number of rollouts N >= T (T=1 corresponds to REINFORCE)
import math
import torch


def Estimate_wT(l_loo, T):
    """
    Args:
        l_loo: 1D tensor of length N-1 with values in [0,1], detached;
            all samples from the same (x,y). loo for leave-one-out.
        T: truncation order.
    Returns:
        Scalar unbiased estimate of w_T(p) = sum_{k=0}^{T-1}(1-p)^k, via the
        U-statistic E[e_k(q) / C(M,k)] = (1-p)^k with q_i = 1 - l_i.
    """
    q = 1.0 - l_loo
    M = q.numel()
    K = min(T - 1, M)
    # e[k] = k-th elementary symmetric polynomial of q, built incrementally
    e = torch.zeros(K + 1, dtype=q.dtype)
    e[0] = 1.0
    for qi in q:
        for k in range(K, 0, -1):
            e[k] = e[k] + qi * e[k - 1]
    w_hat = sum(e[k] / math.comb(M, k) for k in range(K + 1))
    return w_hat


def maxrl_step(batch, model, optimizer, T, N, baseline_const=0.0):
    # batch: list of (x, y) samples
    optimizer.zero_grad()
    total_loss = 0.0
    for (x, y) in batch:
        # 1) sample rollouts and log-probabilities
        z_j, logp_j = model.sample_with_logprob(x, N)  # logp_j: [N]
        # 2) per-rollout likelihoods (no grad); likelihood_l implements l(y, z)
        with torch.no_grad():
            l_j = likelihood_l(y, z_j)  # [N], in [0,1]
        # 3) leave-one-out truncation weights
        omega_j = []
        for j in range(N):
            l_loo = torch.cat([l_j[:j], l_j[j + 1:]])
            omega_j.append(Estimate_wT(l_loo, T))
        omega_j = torch.stack(omega_j)  # [N]
        # 4) subtract constant baseline; build equivalent loss (one backward pass)
        with torch.no_grad():
            adv_j = omega_j * l_j - baseline_const
        # Equivalent loss: its gradient is -mean_j(adv_j * grad logp_j)
        loss_x = -(adv_j * logp_j).mean()
        total_loss = total_loss + loss_x
    total_loss = total_loss / len(batch)
    total_loss.backward()
    optimizer.step()
```