
Rollout likelihood generalization of maximum likelihood reinforcement learning

Topic: Machine Learning

Date:

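I’ve always wondered what happens if one applies RL to supervisable tasks. For example, given a binary classification task $(x_k, y_k)$, maximize accuracy as the reward:

$$\max_\theta \; \frac{1}{n}\sum_{k=1}^{n} \Pr_{\hat y \sim m_\theta(\cdot\mid x_k)}\big[\hat y = y_k\big].$$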

No one does this, presumably for good reasons; how does such a model compare to a normal model trained using cross-entropy?

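It turns out that a highly impressive recent paper from CMU goes down exactly this rabbit hole, and comes up with actionable, principled insights. They point out that correctness-based RL is optimizing

$$J_{\text{RL}}(\theta) = \mathbb{E}_{(x,y)}\big[p_\theta(y\mid x)\big],$$

where $p_\theta(y\mid x)$ is the success probability (“pass rate”), while the likelihood principle suggests optimizing

$$J_{\text{ML}}(\theta) = \mathbb{E}_{(x,y)}\big[\log p_\theta(y\mid x)\big],$$

whose gradient upweights low-pass-rate inputs by a factor $1/p_\theta(y\mid x)$, since $\nabla_\theta \log p_\theta = \nabla_\theta p_\theta / p_\theta$.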

What’s in the paper

  • A compute-indexed family of truncated objectives interpolating between RL and ML,
  • A simple unbiased estimator whose expected gradient matches the truncated objective,
  • Experiments demonstrating strong Pareto improvements over common baselines (e.g. GRPO/RLOO), including gains of up to ~20x in rollout-scaling efficiency in their reasoning setups.

What’s new here

Luckily for our understand-by-practice purposes, the authors focus on the binary, discrete-reward setting. This post unpacks the following qualitative behavior of MaxRL:

  • $\log p$ is sharper than the direct objective $p$, admitting a natural bounded truncation that interpolates RL -> ML with the rollout budget,
  • for a fixed prompt/sample, MaxRL upweights the most successful rollouts (soft-max / log-sum-exp behavior),
  • marginalizing over rollouts, MaxRL upweights the most difficult prompts via inverse-probability reweighting.

We also put forward a generalization that works at the level of per-rollout likelihoods, so it also applies to, e.g., regression tasks. We expect it to be particularly useful for regression RL tasks with a low signal-to-noise ratio.

Math is cheap, just show me the code


Preamble / ramble on MLE

Maximum likelihood estimation (MLE) is a near-axiomatic principle: given data and a parametric family, choose parameters that maximize the probability (likelihood) of observed data.

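  • Classification is MLE: data-label pairs $(x_k, y_k)$, model $m_\theta(y \mid x)$; maximize $\sum_k \log m_\theta(y_k \mid x_k)$, i.e. minimize cross-entropy.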
  • Regression is MLE: MSE corresponds to Gaussian noise MLE; L1 corresponds to Laplace noise MLE.
  • VAEs are MLE-ish too: maximize a tractable lower bound on log-likelihood, while tightening that lower bound.
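The regression bullet can be checked numerically with a small sketch. The targets and predictions are made-up numbers, and the noise scales σ and b are held fixed at 1. The average Gaussian negative log-likelihood equals a constant plus MSE/2, and the average Laplace one equals a constant plus the mean absolute error, so each pair shares its minimizer.

```python
import math

def gaussian_nll(y, y_hat, sigma=1.0):
    """-log N(y; y_hat, sigma^2)."""
    return 0.5 * math.log(2 * math.pi * sigma**2) + (y - y_hat) ** 2 / (2 * sigma**2)

def laplace_nll(y, y_hat, b=1.0):
    """-log Laplace(y; y_hat, b)."""
    return math.log(2 * b) + abs(y - y_hat) / b

# Illustrative targets and predictions (hypothetical numbers)
ys = [1.0, -0.5, 2.0]
preds = [0.8, 0.0, 3.0]
n = len(ys)

mse = sum((y - p) ** 2 for y, p in zip(ys, preds)) / n
mae = sum(abs(y - p) for y, p in zip(ys, preds)) / n
mean_gauss_nll = sum(gaussian_nll(y, p) for y, p in zip(ys, preds)) / n
mean_laplace_nll = sum(laplace_nll(y, p) for y, p in zip(ys, preds)) / n

# With sigma = b = 1: Gaussian NLL = const + MSE / 2, Laplace NLL = const + MAE
print(mean_gauss_nll, 0.5 * math.log(2 * math.pi) + mse / 2)
print(mean_laplace_nll, math.log(2) + mae)
```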

Given the ubiquity of MLE, it’s surprising that Maximum-likelihood RL (MaxRL) is this recent.

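Why do we supervise with cross-entropy $-\log p_\theta(y\mid x)$ rather than the accuracy-like $p_\theta(y\mid x)$? One mechanistic lens is the gradient:

$$\nabla_\theta \log p_\theta(y\mid x) = \frac{1}{p_\theta(y\mid x)}\,\nabla_\theta p_\theta(y\mid x).$$

So difficult cases (small $p_\theta$) get amplified by $1/p_\theta$.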

Maximum likelihood upweights difficult samples aggressively, updating the model on the frontier of its understanding.
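Here is a minimal illustration of this gradient lens. It assumes a binary logistic model where $p = \sigma(s)$ is the probability of the correct label and $s$ is its logit; this is a toy choice, not from the source. We have $\partial p/\partial s = p(1-p)$ while $\partial \log p/\partial s = 1 - p$. The cross-entropy gradient is therefore exactly $1/p$ times larger, and it does not vanish on hard examples.

```python
import math

def sigmoid(s):
    return 1.0 / (1.0 + math.exp(-s))

def grads(logit):
    """Gradients w.r.t. the logit of the correct class in a binary logistic model."""
    p = sigmoid(logit)        # probability assigned to the correct label
    d_p = p * (1.0 - p)       # d p / d logit       (accuracy-like objective)
    d_logp = 1.0 - p          # d log p / d logit   (cross-entropy objective)
    return p, d_p, d_logp

for logit in [-6.0, -2.0, 0.0, 2.0, 6.0]:
    p, d_p, d_logp = grads(logit)
    # the ratio column equals 1 / p
    print(f"p={p:.4f}  dp={d_p:.4f}  dlogp={d_logp:.4f}  ratio={d_logp / d_p:.1f}")
```

For the hardest example ($s = -6$, $p \approx 0.0025$), the accuracy-like gradient is nearly zero, while the cross-entropy gradient stays close to 1.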

Addendum: I would not over-interpret this as the principle itself; it's a side-effect interpretation, and the principle is still MLE. One canonical interpretation of MLE comes from information theory: MLE is a compression objective that minimizes the expected code length of the data under the model. In this lens, each example contributes $-\log p_\theta(y\mid x)$ nats of surprise. Hard examples are expensive in description length, so ML naturally prioritizes optimizing them.

Formulation, notation

Let’s start in the standard LLM-RL latent-generation setting.

  • Data: $(x, y) \sim \mathcal{D}$.
  • Policy/sequence model: $m_\theta(z \mid x)$, where $z$ is a rollout.
  • Rollout evaluation: decode/postprocess $z$ into a prediction $\hat y(z)$; compare to the target $y$.

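In binary reasoning with a verifier, the typical objective is

$$J_{\text{RL}}(\theta) = \mathbb{E}_{(x,y)\sim\mathcal{D}}\;\mathbb{E}_{z\sim m_\theta(\cdot\mid x)}\big[\mathbb{1}[\hat y(z) = y]\big].$$

So far this is the standard RL setup. Now for the useful abstraction.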

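  • A rollout $z$ defines a predictive distribution over labels/targets $y$.
  • Define the per-rollout likelihood $\ell(y, z)$: the probability (or density) that this predictive distribution assigns to the target $y$.
  • Marginalize over rollouts: $p_\theta(y \mid x) = \mathbb{E}_{z\sim m_\theta(\cdot\mid x)}\big[\ell(y, z)\big]$.
  • Score function: $S_\theta(z) = \nabla_\theta \log m_\theta(z \mid x)$.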

This score vector is almost always the only intermediate through which the policy parameters enter policy-gradient objectives; most algorithms differ mainly in how they reweight or shift these score vectors. Understanding this latent-variable formulation already gets you 60% of the way to understanding this post. Also note that the formulation subsumes supervised maximum likelihood when sampling is trivial.
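To make the abstraction concrete, here is a toy sketch in which a softmax over four possible rollouts stands in for the sequence model $m_\theta(z\mid x)$. The logits `theta` and likelihoods `ell` are hypothetical. The sketch computes the marginal likelihood $p_\theta(y\mid x) = \mathbb{E}_z[\ell(y,z)]$ both exactly and by Monte Carlo over sampled rollouts. It also checks that the score $S_\theta(z)$ has zero mean under the policy.

```python
import math
import random

def softmax(logits):
    mx = max(logits)
    exps = [math.exp(v - mx) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def marginal_likelihood(theta, ell):
    """p_theta(y|x) = E_{z ~ m_theta(.|x)}[ell(y, z)] for a categorical policy over rollouts."""
    return sum(m_z * l_z for m_z, l_z in zip(softmax(theta), ell))

def score(theta, z):
    """S(z) = grad_theta log m_theta(z|x) for softmax logits: one_hot(z) - softmax(theta)."""
    m = softmax(theta)
    return [(1.0 if i == z else 0.0) - m[i] for i in range(len(theta))]

# Hypothetical toy: 4 possible rollouts for one fixed (x, y)
theta = [0.5, -1.0, 0.0, 1.0]   # policy logits
ell = [0.9, 0.1, 0.0, 0.4]      # per-rollout likelihoods ell(y, z)

p_exact = marginal_likelihood(theta, ell)
rng = random.Random(0)
zs = rng.choices(range(len(theta)), weights=softmax(theta), k=20_000)
p_mc = sum(ell[z] for z in zs) / len(zs)   # Monte Carlo estimate from sampled rollouts

# The score has zero mean under the policy: sum_z m(z) S(z) = 0
m = softmax(theta)
mean_score = [sum(m[z] * score(theta, z)[i] for z in range(len(theta))) for i in range(len(theta))]
print(p_exact, p_mc, mean_score)
```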

Direct RL as a Maximum Likelihood approximation

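The maximum likelihood objective (for our latent-generation likelihood) is

$$J_{\text{ML}}(\theta) = \mathbb{E}_{(x,y)\sim\mathcal{D}}\big[\log p_\theta(y\mid x)\big] = \mathbb{E}_{(x,y)\sim\mathcal{D}}\Big[\log \mathbb{E}_{z\sim m_\theta(\cdot\mid x)}\big[\ell(y,z)\big]\Big].$$

This is the exact analogue of cross-entropy: the log is outside the expectation because $p_\theta(y\mid x)$ is a marginal likelihood.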

Binary reasoning setup

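In the binary reasoning setup, $\ell(y,z) = \mathbb{1}[\hat y(z) = y]$ and the typical reward is the expected pass rate:

$$J_{\text{RL}}(\theta) = \mathbb{E}_{(x,y)}\big[p_\theta(y\mid x)\big], \qquad p_\theta(y\mid x) = \Pr_{z\sim m_\theta(\cdot\mid x)}\big[\hat y(z) = y\big].$$

Max-likelihood wants instead:

$$J_{\text{ML}}(\theta) = \mathbb{E}_{(x,y)}\big[\log p_\theta(y\mid x)\big].$$

This is exactly the paper’s core distinction: RL maximizes $\mathbb{E}[p]$, ML maximizes $\mathbb{E}[\log p]$.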

(Also: yes, calling the 0/1 indicator $\ell(y,z) = \mathbb{1}[\hat y(z) = y]$ a “likelihood” is a little silly-looking, but that’s kind of the point: we’re doing maximum likelihood on a Bernoulli observation whose success probability is induced by a non-differentiable latent generator.)

Continuous regression

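Specialize to a Gaussian noise model with mean $\hat y(z)$ and variance $\sigma^2$:

$$\ell(y,z) = \mathcal{N}\big(y;\, \hat y(z), \sigma^2\big) = \frac{1}{\sqrt{2\pi\sigma^2}}\exp\!\Big(-\frac{(y-\hat y(z))^2}{2\sigma^2}\Big).$$

Then

$$p_\theta(y\mid x) = \mathbb{E}_{z\sim m_\theta(\cdot\mid x)}\Big[\frac{1}{\sqrt{2\pi\sigma^2}}\exp\!\Big(-\frac{(y-\hat y(z))^2}{2\sigma^2}\Big)\Big].$$

An intuitive per-rollout reward is negative MSE:

$$r(z) = -\big(y - \hat y(z)\big)^2.$$

A direct analogue would average these rollout rewards:

$$J_{\text{direct}}(\theta) = \mathbb{E}_{z\sim m_\theta(\cdot\mid x)}\Big[-\big(y - \hat y(z)\big)^2\Big].$$

Maximum likelihood instead optimizes

$$\log p_\theta(y\mid x) = \log \mathbb{E}_{z\sim m_\theta(\cdot\mid x)}\Big[\exp\!\Big(-\frac{(y-\hat y(z))^2}{2\sigma^2}\Big)\Big]$$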

(up to constants): a log-sum-exp, rather than a simple average, over per-rollout MSE-derived terms. So within each sample, the objective reinforces the most successful trajectories; we’ll recover this same structure later in the general form.

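(If you want this to literally satisfy $\ell(y,z) \le 1$, so that $p_\theta \in (0, 1]$ and the Maclaurin series below applies automatically, just pick a scale so the peak density is $1$, i.e. $\ell(y,z) = \exp\!\big(-(y-\hat y(z))^2/(2\sigma^2)\big)$. Multiplying $\ell$ by a $\theta$-independent constant only adds a constant to $\log p_\theta$, hence doesn’t change the ML gradient.)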
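Here is a small numerical sketch of the difference. It uses hypothetical rollout predictions, a target of $y = 0$, and $\sigma = 0.5$. The likelihood is peak-normalized, so $\log p$ is a log-mean-exp of $-(y-\hat y_j)^2/(2\sigma^2)$ over the sampled rollouts. Under the direct objective, four mediocre rollouts score best. Under the ML objective, one accurate rollout among bad ones scores best, and that rollout receives almost all of the gradient weight (the softmax of the per-rollout log-likelihoods).

```python
import math

def direct_objective(y, preds):
    """Average per-rollout reward: mean_j of -(y - y_hat_j)^2."""
    return sum(-(y - yh) ** 2 for yh in preds) / len(preds)

def log_terms(y, preds, sigma):
    """Per-rollout log-likelihoods a_j = -(y - y_hat_j)^2 / (2 sigma^2), peak-normalized Gaussian."""
    return [-(y - yh) ** 2 / (2 * sigma**2) for yh in preds]

def ml_objective(y, preds, sigma=1.0):
    """log mean_j exp(a_j): log of the Monte Carlo marginal likelihood, computed stably."""
    a = log_terms(y, preds, sigma)
    mx = max(a)
    return mx + math.log(sum(math.exp(v - mx) for v in a) / len(a))

def ml_rollout_weights(y, preds, sigma=1.0):
    """d ml_objective / d a_j = softmax(a)_j: each rollout's share of the gradient."""
    a = log_terms(y, preds, sigma)
    mx = max(a)
    e = [math.exp(v - mx) for v in a]
    total = sum(e)
    return [v / total for v in e]

y, sigma = 0.0, 0.5
one_good = [0.05, 3.0, 3.0, 3.0]       # one accurate rollout, three bad ones
all_mediocre = [1.5, 1.5, 1.5, 1.5]    # four equally mediocre rollouts

print(direct_objective(y, one_good), direct_objective(y, all_mediocre))
print(ml_objective(y, one_good, sigma), ml_objective(y, all_mediocre, sigma))
print(ml_rollout_weights(y, one_good, sigma))
```

For large σ, the log-mean-exp approaches a rescaled average, so the two objectives start to agree. In this toy, with σ = 1, both already prefer the mediocre batch.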

Putting them together: Jensen and Taylor

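From here on, fix $(x, y)$ and suppress dependencies to reduce clutter:

$$p = p_\theta(y\mid x) = \mathbb{E}_{z\sim m_\theta}\big[\ell(z)\big], \qquad \ell(z) = \ell(y, z), \qquad S(z) = \nabla_\theta \log m_\theta(z\mid x).$$

For $p \in (0, 1]$, Taylor-expanding $\log p$ about $p = 1$ yields

$$\log p = \log\big(1 - (1-p)\big) = -\sum_{k=1}^{\infty}\frac{(1-p)^k}{k}.$$

Note that this expansion is about $p = 1$ (success), so a smaller $p$ means a larger gap between the first-order (direct) and full (maximum-likelihood) objectives. This gives another perspective on why harder tasks see larger improvements.

Truncating at order $T$ gives the compute-indexed MaxRL objective

$$J_T(p) = -\sum_{k=1}^{T}\frac{(1-p)^k}{k}.$$

  • $T = 1$: $J_1 = -(1-p) = p - 1$, that is, RL / pass-rate training up to an additive constant.
  • $T \to \infty$: $J_\infty = \log p$, that is, exact maximum likelihood.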
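Here is a short sketch of the truncated objective $J_T(p) = -\sum_{k=1}^{T}(1-p)^k/k$; the function name `J_T` is our own. At $T = 1$ it gives $p - 1$. As $T$ grows, the value approaches $\log p$: slowly for a hard prompt ($p = 0.1$), and almost immediately for an easy one ($p = 0.9$).

```python
import math

def J_T(p, T):
    """Truncated Maclaurin series of log p in u = 1 - p: -sum_{k=1}^T u^k / k."""
    u = 1.0 - p
    return -sum(u**k / k for k in range(1, T + 1))

for p in (0.9, 0.5, 0.1):
    values = "  ".join(f"{J_T(p, T):8.4f}" for T in (1, 2, 4, 16, 256))
    print(f"p={p}: J_T for T=1,2,4,16,256: {values}   log p = {math.log(p):8.4f}")
```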

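Differentiating:

$$\nabla_\theta J_T = \sum_{k=1}^{T}(1-p)^{k-1}\,\nabla_\theta p = w_T(p)\,\nabla_\theta p, \qquad w_T(p) = \sum_{k=0}^{T-1}(1-p)^k = \frac{1-(1-p)^T}{p}.$$

As $T \to \infty$, $w_T(p) \to 1/p$, recovering ML’s inverse-probability reweighting. Now use the log-derivative trick on $\nabla_\theta p$:

$$\nabla_\theta p = \nabla_\theta \sum_z m_\theta(z\mid x)\,\ell(z) = \mathbb{E}_{z\sim m_\theta}\big[\ell(z)\,\nabla_\theta \log m_\theta(z\mid x)\big] = \mathbb{E}_{z\sim m_\theta}\big[\ell(z)\,S(z)\big].$$

Therefore

$$\nabla_\theta J_T = w_T(p)\,\mathbb{E}_{z\sim m_\theta}\big[\ell(z)\,S(z)\big].$$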
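The gradient identity can be checked numerically with a sketch. It again uses a softmax over four rollouts as a hypothetical stand-in for the policy, so $S(z) = e_z - \operatorname{softmax}(\theta)$, where $e_z$ is the one-hot vector for rollout $z$. The analytic $w_T(p)\,\mathbb{E}_z[\ell(z)S(z)]$ is compared against central finite differences of $J_T(p_\theta)$. The sketch also checks the closed form $w_T(p) = (1-(1-p)^T)/p$ and prints the multiplier for a hard prompt versus an easy one.

```python
import math

def softmax(logits):
    mx = max(logits)
    exps = [math.exp(v - mx) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def pass_rate(theta, ell):
    """p = E_{z ~ softmax(theta)}[ell(z)]."""
    return sum(m_z * l_z for m_z, l_z in zip(softmax(theta), ell))

def J_T(p, T):
    return -sum((1.0 - p) ** k / k for k in range(1, T + 1))

def w_T(p, T):
    """sum_{k=0}^{T-1} (1-p)^k, the truncated inverse-pass-rate weight."""
    return sum((1.0 - p) ** k for k in range(T))

def grad_J_T(theta, ell, T):
    """Analytic gradient: w_T(p) * E_z[ell(z) S(z)] with S(z) = e_z - softmax(theta)."""
    m = softmax(theta)
    p = pass_rate(theta, ell)
    grad_p = [0.0] * len(theta)
    for z, (m_z, l_z) in enumerate(zip(m, ell)):
        for i in range(len(theta)):
            grad_p[i] += m_z * l_z * ((1.0 if i == z else 0.0) - m[i])
    return [w_T(p, T) * g for g in grad_p]

def finite_diff(theta, ell, T, h=1e-6):
    """Central finite differences of J_T(p_theta) with respect to each logit."""
    grads = []
    for i in range(len(theta)):
        up = list(theta)
        down = list(theta)
        up[i] += h
        down[i] -= h
        grads.append((J_T(pass_rate(up, ell), T) - J_T(pass_rate(down, ell), T)) / (2 * h))
    return grads

theta = [0.5, -1.0, 0.0, 1.0]   # hypothetical logits over 4 rollouts
ell = [0.9, 0.1, 0.0, 0.4]      # hypothetical per-rollout likelihoods
for T in (1, 3, 50):
    print(T, grad_J_T(theta, ell, T), finite_diff(theta, ell, T))

# Hard prompts (small p) get a larger multiplier than easy ones
print(w_T(0.05, 8), w_T(0.9, 8))
```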

At this point we can already read off the two qualitative behaviors:

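Remark 1 (per-prompt): inside $\mathbb{E}_z[\ell(z)\,S(z)]$, trajectories with larger likelihood $\ell(z)$ get more weight.

  • Binary case: only successful rollouts contribute ($\ell(z) \in \{0, 1\}$).
  • Gaussian regression: $\ell(z) \propto \exp\!\big(-(y-\hat y(z))^2/(2\sigma^2)\big)$ decays exponentially in the squared error, so the best rollouts dominate.

Remark 2 (across prompts): the scalar multiplier $w_T(p)$ increases when $p$ is small (hard prompts). In the ML limit $T \to \infty$, $w_T(p) \to 1/p$ reproduces the inverse-pass-rate reweighting.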

Gradient estimators

Here you might say: “All’s great! There’s this beautiful theory and a highly principled objective. How do we actually go about optimizing it?”

The MaxRL paper proposes an unbiased estimator for the truncated maximum-likelihood binary objective.

Below we recap the proof and give a per-rollout-likelihood generalization that recovers the binary construction and extends to likelihood-based objectives such as Gaussian regression.

Binary case

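In the binary setting $\ell(z) \in \{0, 1\}$, draw $N$ trajectories $z_1, \dots, z_N \overset{\text{iid}}{\sim} m_\theta(\cdot\mid x)$ and define

$$r_i = \ell(z_i), \qquad S_i = S(z_i), \qquad K = \sum_{i=1}^{N} r_i.$$

  • REINFORCE (pass@1) uses

$$\hat g_{\text{RF}} = \frac{1}{N}\sum_{i=1}^{N} r_i\,S_i,$$

which is unbiased for $\nabla_\theta p$.

  • MaxRL’s key estimator averages scores over successful trajectories only:

$$\hat g_{\text{MaxRL}} = \begin{cases} \dfrac{1}{K}\displaystyle\sum_{i:\,r_i = 1} S_i, & K \ge 1,\\[4pt] 0, & K = 0.\end{cases}$$

Conditioning on $K \ge 1$, the successful trajectories are iid draws from the success-conditioned distribution $m_\theta(z\mid x)\,\ell(z)/p$, so

$$\mathbb{E}\big[\hat g_{\text{MaxRL}} \mid K \ge 1\big] = \frac{\mathbb{E}_z\big[\ell(z)\,S(z)\big]}{p} = \frac{\nabla_\theta p}{p} = \nabla_\theta \log p.$$

It’s an unbiased ML gradient estimator, conditional on $K \ge 1$! The bias comes from the event $K = 0$, which happens with probability $(1-p)^N$. Hence

$$\mathbb{E}\big[\hat g_{\text{MaxRL}}\big] = \big(1-(1-p)^N\big)\,\nabla_\theta \log p = \frac{1-(1-p)^N}{p}\,\nabla_\theta p,$$

which is exactly the gradient of the truncated objective at order $T = N$.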
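This claim can be checked exactly on a tiny toy. In this sketch, the logits, the binary outcomes `ell`, and $N = 3$ are all hypothetical. Enumerating all $4^3$ rollout triples from a softmax policy gives the exact expectation of each estimator. REINFORCE should come out equal to $\nabla_\theta p$, and the success-averaged MaxRL estimator equal to $\frac{1-(1-p)^N}{p}\nabla_\theta p$.

```python
import itertools
import math

def softmax(logits):
    mx = max(logits)
    exps = [math.exp(v - mx) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def score(theta, z):
    """S(z) = one_hot(z) - softmax(theta) for a categorical policy with logits theta."""
    m = softmax(theta)
    return [(1.0 if i == z else 0.0) - m[i] for i in range(len(theta))]

def reinforce(theta, ell, zs):
    """(1/N) sum_i r_i S_i."""
    g = [0.0] * len(theta)
    for z in zs:
        for i, s in enumerate(score(theta, z)):
            g[i] += ell[z] * s / len(zs)
    return g

def maxrl_binary(theta, ell, zs):
    """(1/K) sum of S_i over successful rollouts; the zero vector if K = 0."""
    successes = [z for z in zs if ell[z] == 1]
    g = [0.0] * len(theta)
    for z in successes:
        for i, s in enumerate(score(theta, z)):
            g[i] += s / len(successes)
    return g

def exact_mean(estimator, theta, ell, N):
    """Exact expectation over all N-tuples of iid rollouts (feasible only for tiny toys)."""
    m = softmax(theta)
    mean = [0.0] * len(theta)
    for zs in itertools.product(range(len(theta)), repeat=N):
        prob = math.prod(m[z] for z in zs)
        for i, g in enumerate(estimator(theta, ell, zs)):
            mean[i] += prob * g
    return mean

theta = [0.3, -0.7, 1.2, -2.0]   # hypothetical logits over 4 rollouts
ell = [0, 1, 0, 1]               # binary verifier outcome per rollout
N = 3

m = softmax(theta)
d = len(theta)
p = sum(m_z * l_z for m_z, l_z in zip(m, ell))
grad_p = [sum(m[z] * ell[z] * score(theta, z)[i] for z in range(d)) for i in range(d)]
predicted = [(1 - (1 - p) ** N) / p * g for g in grad_p]   # (1 - (1-p)^N) * grad log p

print("pass rate p =", p)
print("E[REINFORCE] =", exact_mean(reinforce, theta, ell, N), " grad p =", grad_p)
print("E[MaxRL]     =", exact_mean(maxrl_binary, theta, ell, N), " predicted =", predicted)
```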

Generalization

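Now we drop the assumption $\ell(z) \in \{0, 1\}$ and keep only what the derivation above actually used:

  • $\ell(z) \in [0, 1]$ is a per-rollout likelihood/reward signal,
  • $p = \mathbb{E}_{z\sim m_\theta}\big[\ell(z)\big]$,
  • $\nabla_\theta p = \mathbb{E}_{z\sim m_\theta}\big[\ell(z)\,S(z)\big]$,
  • $\nabla_\theta J_T = w_T(p)\,\nabla_\theta p$, with $w_T(p) = \sum_{k=0}^{T-1}(1-p)^k$.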

We want to estimate the last quantity from $N$ iid rollout samples. The obstacle is subtle but standard:

  • We can estimate $w_T(p)$ and $\nabla_\theta p$ unbiasedly from samples.
  • But $w_T(p)\,\nabla_\theta p$ is a product of unknowns.
  • Plugging in both sample estimates, computed from the same rollouts, introduces bias because the two estimates are correlated.

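The key here is the leave-one-out trick: use one sample $z_j$ to estimate $\nabla_\theta p$, use the remaining $N-1$ samples to estimate $w_T(p)$, and average over all $j$. The expectation of the product factorizes because the samples are independent (given $x$).

Formally, suppose we have an estimator subroutine $\hat W_T(\ell_{-j})$, built from $\ell_{-j} = \big(\ell(z_i)\big)_{i \ne j}$ only, such that

$$\mathbb{E}\big[\hat W_T(\ell_{-j})\big] = w_T(p).$$

Since $\ell(z_j)\,S(z_j)$ is an unbiased estimator of $\nabla_\theta p$ and is independent of $\hat W_T(\ell_{-j})$, the leave-one-out product $\hat W_T(\ell_{-j})\,\ell(z_j)\,S(z_j)$ is unbiased for $w_T(p)\,\nabla_\theta p$. Averaging over $j$ gives the final estimator:

$$\hat g_T = \frac{1}{N}\sum_{j=1}^{N}\omega_j\,\ell(z_j)\,S(z_j), \qquad \omega_j = \hat W_T(\ell_{-j}).$$

In addition, this is a particularly neat expression because it is, again, a simple reweighted average of the rollout scores! To reduce variance, we can subtract any constant baseline $b$, i.e. use $\omega_j\,\ell(z_j) - b$ as the weight, since $\mathbb{E}_{z\sim m_\theta}\big[S(z)\big] = 0$.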
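An exact toy computation shows why leave-one-out matters. This sketch makes several hypothetical choices. It takes $T = 2$, so $w_2(p) = 2 - p$, which has the unbiased estimate $2 - \bar\ell$. It draws binary $\ell_i \sim \text{Bernoulli}(p)$. It replaces the score with the constant 1, so that only the weighting is tested and the target becomes $w_2(p)\,p$. Reusing all samples for the weight biases the result, while the leave-one-out version is exact.

```python
import itertools

def exact_mean(estimator, p, N):
    """Exact expectation of estimator(l_1..l_N) over l_i iid Bernoulli(p)."""
    total = 0.0
    for ls in itertools.product((0, 1), repeat=N):
        k = sum(ls)
        total += p**k * (1 - p) ** (N - k) * estimator(ls)
    return total

def plug_in(ls):
    """Estimate w_2 = 2 - p from ALL samples, then reuse the same samples' l_j."""
    N = len(ls)
    w_hat = 2.0 - sum(ls) / N
    return sum(w_hat * l for l in ls) / N

def leave_one_out(ls):
    """For rollout j, estimate w_2 only from the other N - 1 samples."""
    N = len(ls)
    total = 0.0
    for j in range(N):
        rest = ls[:j] + ls[j + 1:]
        total += (2.0 - sum(rest) / (N - 1)) * ls[j]
    return total / N

p, N = 0.3, 4
target = (2.0 - p) * p   # w_2(p) * p: the weighting applied to E[l]
print("target       ", target)
print("plug-in      ", exact_mean(plug_in, p, N))
print("leave-one-out", exact_mean(leave_one_out, p, N))
```

With $p = 0.3$ and $N = 4$, the target is $0.51$. The plug-in version has expectation $2p - \big(p^2 + p(1-p)/N\big) = 0.4575$, while leave-one-out recovers $0.51$.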

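It remains to solve the following problem: given iid samples $\ell_1, \dots, \ell_M \in [0, 1]$ with mean $p$ (here $M = N - 1$), estimate $w_T(p) = \sum_{k=0}^{T-1}(1-p)^k$ unbiasedly.

  • Taking $q_i = 1 - \ell_i$, so that $\mathbb{E}[q_i] = 1 - p$, it suffices (by linearity) to obtain unbiased estimators of $(1-p)^k$ for all $k \le T - 1$.
  • Fortunately, this turns out to be a textbook problem with a known MVUE (minimum-variance unbiased estimator) based on U-statistics (something something elementary symmetric polynomials; you’re welcome to look it up, or drop a comment!).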
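One way to fill in the U-statistic is sketched below. It is our own implementation under the assumptions stated in its docstrings, not code from the paper. With $q_i = 1 - \ell_i$, the average of $\prod_{i\in A} q_i$ over all $k$-subsets $A$ of the $M$ samples equals $e_k(q)/\binom{M}{k}$, where $e_k$ is the $k$-th elementary symmetric polynomial. This average is unbiased for $(1-p)^k$ whenever $k \le M$. The function `estimate_wT` could play the role of the pseudocode’s `Estimate_wT`, working on plain Python lists instead of tensors. The binary check confirms that with $T = N$, the leave-one-out weights reduce to $N/K$ on successes, which is exactly the MaxRL average over successful rollouts.

```python
import itertools
import math

def elementary_symmetric(q, k_max):
    """e_0, ..., e_{k_max} of the values q via the standard one-pass recurrence."""
    e = [1.0] + [0.0] * k_max
    for qi in q:
        for k in range(k_max, 0, -1):   # high-to-low so each qi enters a product at most once
            e[k] += qi * e[k - 1]
    return e

def estimate_wT(l_loo, T):
    """Unbiased estimate of w_T(p) = sum_{k<T} (1-p)^k from iid l_i in [0, 1] with mean p.

    (1-p)^k is estimated by e_k(q) / C(M, k), q_i = 1 - l_i: the average of
    prod(q_i) over all k-subsets of the M samples. Requires M >= T - 1.
    """
    M = len(l_loo)
    if M < T - 1:
        raise ValueError("need at least T - 1 samples, i.e. N >= T rollouts")
    e = elementary_symmetric([1.0 - l for l in l_loo], T - 1)
    return sum(e[k] / math.comb(M, k) for k in range(T))

def maxrl_advantages(l, T):
    """omega_j * l_j for each rollout j, with omega_j estimated from the other rollouts."""
    return [estimate_wT(l[:j] + l[j + 1:], T) * l[j] for j in range(len(l))]

# Binary sanity check: with T = N = 5 and K = 2 successes,
# successes get 2.5 (= N / K) and failures get 0.0.
l_binary = [1.0, 0.0, 0.0, 1.0, 0.0]
print(maxrl_advantages(l_binary, T=5))

# Exact unbiasedness check for a 3-point likelihood distribution (hypothetical numbers).
support, probs, M, T = [0.0, 0.5, 1.0], [0.2, 0.5, 0.3], 3, 3
p = sum(s * w for s, w in zip(support, probs))
mean_est = sum(
    math.prod(probs[i] for i in idx) * estimate_wT([support[i] for i in idx], T)
    for idx in itertools.product(range(len(support)), repeat=M)
)
print(mean_est, sum((1 - p) ** k for k in range(T)))
```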

Pseudocode: generalized MaxRL

```python
# Generalized truncated maximum-likelihood RL for latent-generation likelihoods
#
# Inputs:
#   - model m_theta(z | x)
#   - per-rollout likelihood l(y, z) in [0, 1]   (can be scaled)
#   - truncation order T
#   - number of rollouts N >= T (T=1 corresponds to REINFORCE)
#
# `model.sample_with_logprob` and `likelihood_l` are placeholders for your own
# sampler and per-rollout likelihood; they are not library APIs.

import torch


def Estimate_wT(l_loo, T):
    """
    Args:
        l_loo: 1D tensor of length N-1 with values in [0,1], detached;
              all samples from the same (x,y). loo for leave-one-out.
        T: truncation order.

    Returns:
        Scalar unbiased estimate of w_T(p) = sum_{k=0}^{T-1}(1-p)^k.
    """
    ...
    return w_hat


def maxrl_step(batch, model, optimizer, T, N, baseline_const=0.0):
    # batch: list of (x, y) samples
    optimizer.zero_grad()
    total_loss = 0.0

    for (x, y) in batch:
        # 1) sample N rollouts and their log-probabilities (logp_j keeps the graph)
        z_j, logp_j = model.sample_with_logprob(x, N)  # logp_j: [N]

        # 2) per-rollout likelihoods (no grad)
        with torch.no_grad():
            l_j = likelihood_l(y, z_j)  # [N], in [0,1]

        # 3) leave-one-out weights: omega_j is built from every rollout except j
        omega_j = torch.stack([
            torch.as_tensor(
                Estimate_wT(torch.cat([l_j[:j], l_j[j + 1:]]), T),
                dtype=l_j.dtype, device=l_j.device,
            )
            for j in range(N)
        ])  # [N]

        # 4) advantages with a constant baseline (still unbiased since E[grad log m] = 0)
        adv_j = (omega_j * l_j - baseline_const).detach()

        # Surrogate loss whose gradient is -(1/N) * sum_j adv_j * grad logp_j
        loss_x = -(adv_j * logp_j).mean()
        total_loss = total_loss + loss_x

    total_loss = total_loss / len(batch)
    total_loss.backward()
    optimizer.step()
```
