OT for generative modeling 2 — Wasserstein gradients and drifting models

Topic: Machine Learning

Date:

Diffusion and flow matching split the generative problem into two phases: at training time, learn a vector field; at inference time, integrate an ODE or SDE through that field to produce a sample. The integration is expensive, and a great deal of recent work has gone into compressing it: distillation, consistency models, progressive reduction.

Recent work on Drifting Models (Deng et al., 2026) defines a “drifting field” $V(x; p, q)$ that tells each generated sample which direction to move, requires antisymmetry ($V(x; p, q) = -V(x; q, p)$) so that the field vanishes at equilibrium ($V(x; p, p) = 0$), and trains the generator to chase its own drifted targets.

We currently have a mechanical interpretation — driving the loss to zero produces desirable behavior. But equipped with the Wasserstein machinery we built in part 01, we can say something much sharper. The antisymmetric drifting field is the Wasserstein gradient of a distribution discrepancy; the training dynamics execute gradient descent on the manifold $(\mathcal{P}_2(\mathbb{R}^d), W_2)$; and the regression loss against drifted targets is, through the Wasserstein bridge, a gradient step on a proper statistical divergence. The whole paradigm admits a precise connection to maximum likelihood.

For those looking for novel content, the main results of this post are as follows:

  1. Statistical interpretation of Gaussian drifting: we show that drifting with a Gaussian kernel implements Wasserstein gradient descent on the reverse, mode-seeking KL divergence $\mathrm{KL}(\tilde q \,\|\, \tilde p)$ between KDE-smoothed distributions. The stop-grad loss implements gradient pullback from sample space to parameter space.
  2. Maximum likelihood modification: we derive the drifting field for the maximum-likelihood (forward KL) objective $\mathrm{KL}(\tilde p \,\|\, \tilde q)$. The changes to the current paradigm are minimal: reweight the drift by the density ratio $\tilde p/\tilde q$ and use the Gaussian (instead of Laplace) kernel. The resulting drifting field is notably not antisymmetric.

Contents

Formulation

The drifting models paradigm consists of:

  • a drifting field that tells generated samples which direction to move
  • a training loop that chases drifted targets

In this generative paradigm, we’re given samples from the data distribution $p_{\mathrm{data}}$, we consider an initial noise distribution $p_z$, and a parametric generator $g_\theta : \mathcal{Z} \to \mathcal{X}$. Denote the pushforward measure by $(g_\theta)_\# p_z$. To be consistent with the paper, we abbreviate $p = p_{\mathrm{data}}$ and $q = (g_\theta)_\# p_z$.

The work considers general antisymmetric drift fields $V(x; p, q) = -V(x; q, p)$; each field is a vector field on the sample space $\mathbb{R}^d$. The training loop consists of iteratively minimizing

$$\mathcal{L}(\theta) \;=\; \mathbb{E}_{z \sim p_z}\, \Big\| g_\theta(z) - \mathrm{sg}\big[\, g_\theta(z) + V\big(g_\theta(z); p, q\big) \big] \Big\|^2 .$$

Note that, because of the stop-grad, $\nabla_\theta \mathcal{L}(\theta) = -2\,\mathbb{E}_z\big[\partial_\theta g_\theta(z)^\top\, V(g_\theta(z); p, q)\big]$. Further note that if $V(x) = -\nabla_x \phi(x)$ for some potential $\phi$ held fixed by the stop-grad, then

$$\nabla_\theta \mathcal{L}(\theta) \;=\; 2\,\mathbb{E}_z\big[\partial_\theta g_\theta(z)^\top\, \nabla_x \phi(g_\theta(z))\big] \;=\; 2\,\nabla_\theta\, \mathbb{E}_z\big[\phi(g_\theta(z))\big].$$

The stop-grad loss exactly implements the pullback of the gradient: if $V$ is a gradient field on sample space $\mathbb{R}^d$, then $\partial_\theta g_\theta(z)^\top V(g_\theta(z))$ is its pullback to parameter space.
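To make the pullback concrete, here is a minimal NumPy sketch (the 1-D shift generator, the stand-in drift $V$, and all variable names are illustrative assumptions, not the paper's): for $g_\theta(z) = \theta + z$ the Jacobian is $1$, so the gradient of the stop-grad regression loss reduces to minus the mean drift, scaled by the drift step size.

```python
import numpy as np

rng = np.random.default_rng(0)

def g(theta, z):
    # toy generator: a learned shift of the noise (hypothetical, 1-D)
    return theta + z

theta = 1.5
z = rng.normal(size=8)
V = np.sin(g(theta, z))            # stand-in drift evaluated at the samples
eps = 0.1                          # drift step size (arbitrary)

# drifted target, treated as a constant by the stop-grad
target = g(theta, z) + eps * V

# analytic gradient of 0.5 * mean || g_theta(z) - target ||^2 w.r.t. theta;
# since d g / d theta = 1, the pullback is just the mean residual
grad_theta = np.mean(g(theta, z) - target)   # equals -eps * mean(V)
```

Applying a gradient step with this `grad_theta` moves the generator's output along the drift, exactly as the pullback argument predicts.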

The drifting field

Let’s consider the authors’ choice of the drifting field

Definition 1 (canonical drifting field).

Consider the following antisymmetric drifting field, evaluated at a point $x$ in sample space:

$$V(x; p, q) \;=\; \mathbb{E}_{y \sim p}\big[w_p(x, y)\,(y - x)\big] \;-\; \mathbb{E}_{y' \sim q}\big[w_q(x, y')\,(y' - x)\big].$$

The authors chose the Laplace kernel $k(x, y) = \exp(-\|x - y\|/h)$.

Over a finite batch, the per-point normalization translates to softmax weighting:

$$w_p(x, y_i) \;=\; \frac{k(x, y_i)}{\sum_j k(x, y_j)} \;=\; \operatorname{softmax}_i\!\big(-\|x - y_i\|/h\big),$$

and similarly for $w_q$ over the model samples.
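A minimal NumPy sketch of this construction, assuming batched softmax-normalized Laplace weights as described above (function names, batch sizes, and bandwidth are mine, not the paper's):

```python
import numpy as np

def laplace_weights(x, pts, h=1.0):
    # softmax weights from the Laplace kernel k(x, y) = exp(-||x - y|| / h)
    logk = -np.linalg.norm(x - pts, axis=1) / h
    w = np.exp(logk - logk.max())      # numerically stabilized softmax
    return w / w.sum()

def drifting_field(x, data, model, h=1.0):
    # V(x; p, q): softmax-weighted attraction to data minus repulsion from model
    attract = laplace_weights(x, data, h) @ (data - x)
    repel = laplace_weights(x, model, h) @ (model - x)
    return attract - repel

rng = np.random.default_rng(0)
data = rng.normal(2.0, 1.0, size=(64, 2))    # stand-in data batch
model = rng.normal(0.0, 1.0, size=(64, 2))   # stand-in model batch
x = np.zeros(2)                               # query point

V = drifting_field(x, data, model)
V_swapped = drifting_field(x, model, data)    # roles of p and q exchanged
```

Swapping the roles of data and model samples flips the sign of the field, and evaluating it with identical batches returns exactly zero — the antisymmetry and equilibrium properties from the formulation.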

Wasserstein Gradient Flow

We take a step back to develop the theory of Wasserstein gradient flow (and return to capital letters): given a probability distribution $\rho$, we can assign some loss / preference to it by some functional $\mathcal{F} : \mathcal{P}_2(\mathbb{R}^d) \to \mathbb{R}$. What happens to the probability distribution as we try to minimize $\mathcal{F}$ by gradient descent?

The Kullback-Leibler functional

Fixing the data distribution $p$, maximizing the likelihood of the data under the model $q_\theta$ is equivalent to minimizing the KL divergence:

$$\arg\max_\theta\; \mathbb{E}_{x \sim p}\big[\log q_\theta(x)\big] \;=\; \arg\min_\theta\; \mathrm{KL}(p \,\|\, q_\theta).$$

For more interesting properties of KL divergence, see these notes. From an SGD perspective, minimizing KL is equivalent to maximizing likelihood when the empirical samples are i.i.d. from $p$:

$$\mathrm{KL}(p \,\|\, q_\theta) \;=\; -\,\mathbb{E}_{x \sim p}\big[\log q_\theta(x)\big] + \mathrm{const} \;\approx\; -\frac{1}{n}\sum_{i=1}^{n} \log q_\theta(x_i) + \mathrm{const}.$$

Gradients on manifolds

I like to interpret differential geometry as the “lifting” of Euclidean constructs into locally Euclidean manifolds. Gradients are no different. In Euclidean space, given a curve $x(t)$ and a smooth function $f$, the chain rule yields

$$\frac{d}{dt}\, f(x(t)) \;=\; \big\langle \nabla f(x(t)),\, \dot x(t) \big\rangle.$$

Lifting the inner product to the manifold metric, we can use this to define gradients on manifolds:

Definition 2 (Wasserstein gradients).

Given a scalar function $\mathcal{F} : \mathcal{P}_2(\mathbb{R}^d) \to \mathbb{R}$, the gradient of $\mathcal{F}$ at the point $\rho$ is the unique tangent vector $\operatorname{grad}_{W_2} \mathcal{F}(\rho)$ such that, for any curve $\rho_t$ with $\rho_0 = \rho$ and tangent velocity $v_0$ at $t = 0$, we have

$$\frac{d}{dt}\Big|_{t=0} \mathcal{F}(\rho_t) \;=\; \big\langle \operatorname{grad}_{W_2} \mathcal{F}(\rho),\, v_0 \big\rangle_{W_2} \;=\; \int \big\langle \operatorname{grad}_{W_2} \mathcal{F}(\rho)(x),\, v_0(x) \big\rangle\, \rho(x)\, dx,$$

where $\langle \cdot, \cdot \rangle_{W_2}$ is our familiar Wasserstein metric on the tangent space at $\rho$, and the inner product in the second equality is the familiar Euclidean metric, after we expanded the definition of the Wasserstein metric.

Now, we’re equipped to state a major result in Otto calculus. We’ll prove it shortly.

Theorem 1 (fundamental theorem of Otto calculus).

Given a probability functional $\mathcal{F} : \mathcal{P}_2(\mathbb{R}^d) \to \mathbb{R}$, its Wasserstein gradient can be computed as

$$\operatorname{grad}_{W_2} \mathcal{F}(\rho) \;=\; \nabla_x\, \frac{\delta \mathcal{F}}{\delta \rho}(x).$$

The theorem should look fairly intuitive: on the RHS, we compute the pointwise derivative of $\mathcal{F}$ w.r.t. the point density $\rho(x)$ and use this scalar field as a potential. The theorem tells us that the direction of steepest $\mathcal{F}$-ascent is the (spatial) gradient of this functional-derivative potential.

Several remarks are in order:

  1. Despite appearing like a definition, this is a theorem! The general differential-geometry gradient exists, but it does not usually admit such an easy closed form.
  2. The expression $\frac{\delta \mathcal{F}}{\delta \rho}$ is a scalar function on the sample space, usually known as the functional derivative. Its values are the point-wise derivatives of $\mathcal{F}$ w.r.t. $\rho(x)$.

Again, the functional derivative is a scalar function on the sample space. Its definition is best demonstrated by two useful examples:

Example.

For the entropy functional $\mathcal{H}(\rho) = -\int \rho \log \rho \; dx$, applying the product rule yields:

$$\frac{\delta \mathcal{H}}{\delta \rho}(x) \;=\; -\log \rho(x) - 1.$$

The KL functional has two arguments:

$$\mathrm{KL}(p \,\|\, q) \;=\; \int p(x) \log \frac{p(x)}{q(x)}\; dx.$$

Taking the functional derivative w.r.t. the data distribution $p$, we treat $q$ as a constant:

$$\frac{\delta\, \mathrm{KL}(p \,\|\, q)}{\delta p}(x) \;=\; \log \frac{p(x)}{q(x)} + 1.$$

Similarly w.r.t. $q$:

$$\frac{\delta\, \mathrm{KL}(p \,\|\, q)}{\delta q}(x) \;=\; -\frac{p(x)}{q(x)}.$$
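The last derivative is easy to check numerically: perturb $q$ along an arbitrary direction $\varphi$ on a grid and compare the finite-difference directional derivative of $\mathrm{KL}(p\,\|\,q)$ with $\int (-p/q)\,\varphi\,dx$. A sketch (the grid, test densities, and perturbation are arbitrary choices of mine):

```python
import numpy as np

x = np.linspace(-6, 6, 2001)
dx = x[1] - x[0]

def gauss(x, mu, sig):
    return np.exp(-0.5 * ((x - mu) / sig) ** 2) / (sig * np.sqrt(2 * np.pi))

p = gauss(x, 0.0, 1.0)   # "data" density
q = gauss(x, 0.5, 1.2)   # "model" density

def kl(p, q):
    # KL(p || q) discretized on the grid
    return np.sum(p * np.log(p / q)) * dx

# perturb q along an arbitrary direction phi; the directional derivative
# should match the inner product with the functional derivative -p/q
phi = np.sin(x) * gauss(x, 0.0, 2.0)
eps = 1e-6
numeric = (kl(p, q + eps * phi) - kl(p, q - eps * phi)) / (2 * eps)
analytic = np.sum(-(p / q) * phi) * dx
```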

Example (applying Otto’s theorem to entropy).

From above, $\frac{\delta \mathcal{H}}{\delta \rho} = -\log \rho - 1$. Applying the theorem:

$$\operatorname{grad}_{W_2} \mathcal{H}(\rho) \;=\; \nabla\big(-\log \rho - 1\big) \;=\; -\nabla \log \rho.$$

The Wasserstein gradient of entropy is the negative score. Gradient ascent on entropy has velocity $v = -\nabla \log \rho$. Plugging into the continuity equation:

$$\partial_t \rho \;=\; -\nabla \cdot (\rho v) \;=\; \nabla \cdot (\rho\, \nabla \log \rho) \;=\; \nabla \cdot (\nabla \rho) \;=\; \Delta \rho.$$

This is the heat equation: heat diffusion is Wasserstein gradient ascent of entropy.
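This identity can be verified with finite differences: for a Gaussian density on a grid, the continuity-equation update $-\nabla\cdot(\rho v)$ with $v = -\nabla \log \rho$ should match $\Delta \rho$ at interior points. A quick NumPy sketch (grid and density are arbitrary choices):

```python
import numpy as np

x = np.linspace(-5, 5, 4001)
dx = x[1] - x[0]
rho = np.exp(-0.5 * x**2) / np.sqrt(2 * np.pi)   # Gaussian density on a grid

v = -np.gradient(np.log(rho), dx)          # entropy-ascent velocity: -score
drho_dt = -np.gradient(rho * v, dx)        # continuity equation: -div(rho v)
laplacian = np.gradient(np.gradient(rho, dx), dx)  # heat-equation RHS
```

The two arrays agree at interior grid points up to discretization error, confirming that entropy ascent and heat diffusion are the same flow.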

Example (applying Otto’s theorem to forward KL).

Apply the theorem to the second argument $q$ in $\mathrm{KL}(p \,\|\, q)$. From above, $\frac{\delta \mathrm{KL}}{\delta q} = -p/q$. Applying the theorem:

$$\operatorname{grad}_{W_2}\, \mathrm{KL}(p \,\|\, q) \;=\; \nabla\!\left(-\frac{p}{q}\right).$$

Gradient descent velocity: $v = \nabla\dfrac{p}{q}$.

Example (applying Otto’s theorem to reverse KL).

Apply the theorem to the first argument $q$ in $\mathrm{KL}(q \,\|\, p)$. The functional derivative is

$$\frac{\delta\, \mathrm{KL}(q \,\|\, p)}{\delta q}(x) \;=\; \log \frac{q(x)}{p(x)} + 1.$$

Applying the theorem:

$$\operatorname{grad}_{W_2}\, \mathrm{KL}(q \,\|\, p) \;=\; \nabla \log \frac{q}{p} \;=\; \nabla \log q - \nabla \log p.$$

The Wasserstein gradient is a score difference. Gradient descent velocity: $v = \nabla \log p - \nabla \log q$. Particles flow in the direction where the data score exceeds the model score. This is exactly the drifting field’s structure — but it requires $\nabla \log p$, the data score. With empirical samples (Diracs), this is undefined: the Dirac trap that we return to in the next section.

Proving Otto’s theorem

We need to show that $\nabla \frac{\delta \mathcal{F}}{\delta \rho}$ satisfies the gradient definition for all test tangent vectors.

By the gradient definition, $\operatorname{grad}_{W_2} \mathcal{F}(\rho)$ is the unique tangent vector satisfying

$$\frac{d}{dt}\Big|_{t=0} \mathcal{F}(\rho_t) \;=\; \int \big\langle \operatorname{grad}_{W_2} \mathcal{F}(\rho)(x),\, w(x) \big\rangle\, \rho(x)\, dx$$

for all test velocities $w$, where $\rho_t$ is the perturbation of $\rho$ induced by flowing along $w$. By the continuity equation, an infinitesimal flow along $w$ produces perturbation $\delta\rho = -\nabla \cdot (\rho w)$. By the definition of the functional derivative,

$$\frac{d}{dt}\Big|_{t=0} \mathcal{F}(\rho_t) \;=\; \int \frac{\delta \mathcal{F}}{\delta \rho}\; \delta\rho\; dx \;=\; -\int \frac{\delta \mathcal{F}}{\delta \rho}\; \nabla \cdot (\rho w)\; dx.$$

Applying the divergence theorem (the boundary term vanishes since $\rho$ and $w$ decay at infinity):

$$-\int \frac{\delta \mathcal{F}}{\delta \rho}\; \nabla \cdot (\rho w)\; dx \;=\; \int \Big\langle \nabla \frac{\delta \mathcal{F}}{\delta \rho}(x),\, w(x) \Big\rangle\, \rho(x)\, dx,$$

so $\nabla \frac{\delta \mathcal{F}}{\delta \rho}$ is indeed the Wasserstein gradient. $\square$

Statistical interpretation of drifting

We now connect the drifting field to Wasserstein gradient flow and ask: what functional is being minimized, and what does this have to do with maximum likelihood?

Recall the antisymmetry property $V(x; p, q) = -V(x; q, p)$. If you have carefully followed all of the previous examples, this should remind you of the Wasserstein gradient of reverse KL, $\nabla \log q - \nabla \log p$, which likewise flips sign when the two distributions swap roles. Let’s explore this.

Gaussian kernel smoothing implements Reverse KL

The reverse KL example above gave us the velocity $\nabla \log p - \nabla \log q$ — a score difference matching the drifting field’s attraction-minus-repulsion structure. But with empirical Diracs, $\nabla \log p$ is undefined: the Dirac trap. Diffusion models resolve this by convolving with noise. Drifting models resolve it differently — via implicit kernel density estimation (KDE). Let’s see how this works; define the KDE-smoothed distributions:

Definition 3 (KDE-smoothed distributions).

Given empirical distributions $\hat p = \frac{1}{n}\sum_{i} \delta_{y_i}$ and $\hat q = \frac{1}{m}\sum_{j} \delta_{x_j}$, and the Gaussian kernel $k(x, y) = \exp\!\big(-\|x - y\|^2/2h^2\big)$, define the smoothed distributions $\tilde p = \hat p * k$ and $\tilde q = \hat q * k$ (normalized to integrate to one).

Note that $\tilde p(x) \propto \sum_i k(x, y_i)$ and $\tilde q(x) \propto \sum_j k(x, x_j)$. The per-point normalization constants in the drifting field are exactly these KDE densities.

Now consider the reverse KL between the smoothed distributions, $\mathcal{F}(q) = \mathrm{KL}(\tilde q \,\|\, \tilde p)$. From the reverse KL example above, the Wasserstein gradient descent velocity is:

$$v(x) \;=\; \nabla \log \tilde p(x) - \nabla \log \tilde q(x).$$

We expand each score using the log-derivative trick $\nabla \log f = \nabla f / f$. For the smoothed data score:

$$\nabla \log \tilde p(x) \;=\; \frac{\sum_i \nabla_x k(x, y_i)}{\sum_i k(x, y_i)} \;=\; \sum_i \frac{k(x, y_i)}{\sum_j k(x, y_j)}\, \nabla_x \log k(x, y_i).$$

For the Gaussian kernel, $\nabla_x \log k(x, y) = (y - x)/h^2$. Absorbing the $1/h^2$ into the learning rate:

$$\nabla \log \tilde p(x) \;\propto\; \sum_i w_p(x, y_i)\,(y_i - x).$$

This is precisely the data-attraction field. The same calculation with $\tilde q$ yields $\sum_j w_q(x, x_j)\,(x_j - x)$, the model-repulsion field. Therefore:

Theorem 2 (drifting field as Wasserstein gradient).

The canonical drifting field, evaluated with the Gaussian kernel, is the negative Wasserstein gradient of the reverse KL between KDE-smoothed distributions (up to the $h^2$ factor absorbed into the learning rate):

$$V(x; p, q) \;=\; -\,h^2\, \operatorname{grad}_{W_2}\, \mathrm{KL}(\tilde q \,\|\, \tilde p)(x) \;=\; h^2\,\big(\nabla \log \tilde p(x) - \nabla \log \tilde q(x)\big).$$

Each training step executes the Jacobian pullback of this Wasserstein gradient descent step w.r.t. $\theta$.
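The score expansion behind this theorem is easy to verify numerically: for a Gaussian KDE, the softmax-weighted attraction $\frac{1}{h^2}\sum_i w_i (y_i - x)$ should agree with a finite-difference gradient of $\log \tilde p$. A sketch (points and bandwidth are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(1)
ys = rng.normal(size=(32, 2))   # data points defining the KDE
h = 0.7

def log_kde(x):
    # log of the unnormalized Gaussian KDE: logsumexp of -||x - y_i||^2 / (2 h^2)
    a = -np.sum((x - ys) ** 2, axis=1) / (2 * h * h)
    m = a.max()
    return m + np.log(np.exp(a - m).sum())

x = np.array([0.3, -0.2])

# closed form: softmax-weighted attraction, scaled by 1 / h^2
a = -np.sum((x - ys) ** 2, axis=1) / (2 * h * h)
w = np.exp(a - a.max())
w /= w.sum()
score_closed = (w @ (ys - x)) / (h * h)

# finite-difference gradient of log p_tilde at x
eps = 1e-5
score_fd = np.array([
    (log_kde(x + eps * e) - log_kde(x - eps * e)) / (2 * eps)
    for e in np.eye(2)
])
```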

Remark (the Laplace deviation).

The derivation above assumes a Gaussian kernel throughout. The actual paper specifies a Laplace kernel $k(x, y) = \exp(-\|x - y\|/h)$, whose true spatial gradient is

$$\nabla_x \log k(x, y) \;=\; \frac{1}{h}\, \frac{y - x}{\|y - x\|}.$$

This is a unit-direction pull of constant magnitude $1/h$. The paper’s drifting field uses Laplace scalar weights but Gaussian-style vector differences $(y - x)$ — a deliberate hybrid. This might have been an empirically successful choice because Laplace weights have heavier tails than Gaussian ones, preventing kernel starvation in high dimensions.

Implementing forward KL

The reverse KL functional is mode-seeking: particles collapse into well-defined modes and ignore the rest, because the cost of generating samples where the data density is zero is infinite. Under the maximum-likelihood principle, we typically prefer the forward KL $\mathrm{KL}(p \,\|\, q)$, which is mass-covering: the model is forced to stretch over the entire support of the data.

Apply Otto calculus to $\mathcal{F}(q) = \mathrm{KL}(\tilde p \,\|\, \tilde q)$. The functional derivative w.r.t. the flowing model is $-\tilde p/\tilde q$ (from the forward KL example above). The Wasserstein gradient descent velocity is:

$$v(x) \;=\; \nabla \frac{\tilde p}{\tilde q}(x) \;=\; \frac{\tilde p(x)}{\tilde q(x)}\,\big(\nabla \log \tilde p(x) - \nabla \log \tilde q(x)\big).$$

The term in parentheses is exactly the reverse KL velocity — up to the $h^2$ scaling, the canonical drifting field $V$:

Theorem 3 (forward KL via density ratio scaling).

$$v_{\mathrm{fwd}}(x) \;=\; \frac{\tilde p(x)}{\tilde q(x)}\; v_{\mathrm{rev}}(x) \;\propto\; \frac{\tilde p(x)}{\tilde q(x)}\; V(x; p, q).$$

Since $\tilde p(x)$ and $\tilde q(x)$ are already computed to evaluate the drifting field — they are its per-point normalizers — the density ratio is free.

Proposition: maximum likelihood drifting

Combining the results above, we can state concretely what a maximum-likelihood variant of drifting looks like.

Theorem 4 (MLE drifting field).

The Wasserstein gradient descent velocity for the forward KL between KDE-smoothed distributions, using the Gaussian kernel, is

$$V_{\mathrm{MLE}}(x) \;=\; \frac{\tilde p(x)}{\tilde q(x)}\; V(x; p, q),$$

where $V(x; p, q)$ is the canonical drifting field evaluated with Gaussian kernel weights.
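A small NumPy sketch of this field (kernel bandwidth, batch sizes, and function names are my illustrative choices): scale the Gaussian-weighted canonical field by the KDE density ratio, then check that equilibrium is preserved while antisymmetry is broken.

```python
import numpy as np

def gauss_kde_parts(x, pts, h):
    # unnormalized Gaussian KDE density at x, plus softmax weights over pts
    a = -np.sum((x - pts) ** 2, axis=1) / (2 * h * h)
    dens = np.exp(a).mean()
    w = np.exp(a - a.max())
    return dens, w / w.sum()

def mle_drift(x, data, model, h=1.0):
    p_tilde, wp = gauss_kde_parts(x, data, h)
    q_tilde, wq = gauss_kde_parts(x, model, h)
    V = wp @ (data - x) - wq @ (model - x)   # canonical field, Gaussian weights
    return (p_tilde / q_tilde) * V           # density-ratio scaling (forward KL)

rng = np.random.default_rng(2)
data = rng.normal(1.0, 1.0, size=(48, 2))
model = rng.normal(-1.0, 1.0, size=(48, 2))
x = np.array([0.2, 0.1])

V_pq = mle_drift(x, data, model)
V_qp = mle_drift(x, model, data)   # swapped arguments: no longer -V_pq
V_eq = mle_drift(x, data, data)    # equilibrium: field vanishes exactly
```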

What does this mean in practice? The changes to the existing training protocol are minimal. Here is the current Deng et al. procedure:

Existing protocol (Deng et al.):

  1. Sample a batch of data points $\{y_i\}$ and generate model outputs $x_j = g_\theta(z_j)$.
  2. For each model sample $x_j$, compute the Laplace kernel values $k(x_j, y_i)$ and $k(x_j, x_{j'})$ against all data and model samples respectively. Normalize to obtain the softmax weights $w_p$, $w_q$, and evaluate the drifting field $V(x_j)$.
  3. Form drifted targets $\tilde x_j = x_j + V(x_j)$.
  4. Minimize $\sum_j \big\| g_\theta(z_j) - \mathrm{sg}(\tilde x_j) \big\|^2$.

MLE modification (two changes):

  1. Same.
  2. Same, but replace the Laplace kernel with the Gaussian kernel $k(x, y) = \exp\!\big(-\|x - y\|^2/2h^2\big)$.
  3. Form drifted targets $\tilde x_j = x_j + \frac{\tilde p(x_j)}{\tilde q(x_j)}\, V(x_j)$. That is, scale the drifting field by the density ratio.
  4. Same.
|  | Deng et al. | MLE modification |
| --- | --- | --- |
| Kernel | Laplace | Gaussian |
| Drifting field | $V(x)$ | $\dfrac{\tilde p(x)}{\tilde q(x)}\, V(x)$ |
| Antisymmetric? | Yes | No |
| Functional minimized | Reverse KL (mode-seeking) | Forward KL (mass-covering, MLE) |
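Putting the protocol together, here is a toy 1-D training loop following the four steps above (everything here — batch sizes, bandwidth, learning rate, drift step size, and the shift generator $g_\theta(z) = \theta + z$ — is an illustrative assumption of mine, not the paper's configuration):

```python
import numpy as np

rng = np.random.default_rng(0)

def laplace_softmax(x, pts, h):
    # rows = query points; softmax over Laplace kernel values exp(-|x - y| / h)
    logk = -np.abs(x[:, None] - pts[None, :]) / h
    logk -= logk.max(axis=1, keepdims=True)
    w = np.exp(logk)
    return w / w.sum(axis=1, keepdims=True)

theta, lr, eps, h = -2.0, 0.5, 0.5, 1.0
for _ in range(200):
    y = rng.normal(2.0, 0.5, size=64)        # step 1: data batch
    z = rng.normal(0.0, 0.5, size=64)
    x = theta + z                            # step 1: model samples g_theta(z)

    wp = laplace_softmax(x, y, h)            # step 2: attraction weights
    wq = laplace_softmax(x, x, h)            # step 2: repulsion weights
    V = (wp * (y[None, :] - x[:, None])).sum(axis=1) \
        - (wq * (x[None, :] - x[:, None])).sum(axis=1)

    # steps 3-4: drifted targets are constants (stop-grad); for a shift
    # generator the regression gradient step reduces to a step along mean(V)
    theta += lr * eps * np.mean(V)
```

With a shift generator the regression step collapses to a scalar update along the mean drift; the learned shift $\theta$ converges near the data mean at $2$.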

Several qualitative differences are worth noting:

Antisymmetry is lost. The density ratio breaks the symmetry: $\frac{\tilde p}{\tilde q}\, V(x; p, q) \neq -\frac{\tilde q}{\tilde p}\, V(x; q, p)$ in general. But equilibrium is preserved — when $q = p$, we have $\tilde p/\tilde q = 1$ and $V(x; p, p) = 0$, so the field still vanishes.

Mode-covering vs mode-seeking. This is the main qualitative shift. Forward KL penalizes the model for assigning low density where data is present: dropped modes are actively hunted. The density ratio acts as a spatially varying learning rate — particles in under-represented regions ($\tilde p$ large, $\tilde q$ small) take larger steps, while particles in over-crowded regions slow down.
