Optimal transport & the Wasserstein distance

Measuring the distance between probability distributions by the minimal cost of moving mass from one to the other.

Optimal transport measures how far apart two probability distributions are by the minimal cost of rearranging the mass of one into the other. Stated plainly: given two probability distributions $\mu$ (source) and $\nu$ (target), and a cost function $c(x,y)$ for moving one unit of mass from $x$ to $y$, find a way to move the mass that empties $\mu$ entirely, fills $\nu$ exactly, and incurs the least total cost. That minimal total cost is the “transport distance” between $\mu$ and $\nu$.

In one dimension with squared cost the optimum takes a particularly clean form: pairing the sorted source points with the sorted target points yields the monotone coupling, which is optimal. This special case admits a closed-form solution and makes the central idea, mass moving the minimal total distance, directly visible.

Figure (interactive): source $\mu$, target $\nu$, and the optimal matching between them. Under 1D squared cost the optimal coupling pairs the sorted source points with the sorted target points (the monotone matching). Any crossing pair strictly raises the total cost, so the connecting lines never cross; shifting the target translates the mass and the cost varies smoothly.

The Monge formulation seeks a deterministic map $T$ pushing $\mu$ onto $\nu$ at least cost,

$$\inf_{T_\# \mu = \nu}\int c\big(x, T(x)\big)\,d\mu(x),$$

symbol by symbol: $T$ is a function sending each source point $x$ to a single target point $T(x)$; the constraint $T_\#\mu=\nu$ (the “pushforward”) requires that the mass carried by $T$ lands exactly as $\nu$, i.e. for any region $A$ the source mass $\mu\big(T^{-1}(A)\big)$ that maps into $A$ equals $\nu(A)$; the integral sums the per-point cost $c\big(x,T(x)\big)$ weighted by the source distribution $\mu$; and $\inf$ takes the infimum over all maps $T$ meeting the constraint. The catch is that such a map need not exist: when the mass at one source point must be split among several targets, no single-valued function $T$ can do it. For squared cost with an absolutely continuous source (one with a density) and finite second moments, however, an optimal map always exists and is unique $\mu$-almost everywhere, by Brenier’s theorem, and it is the gradient of a convex function.
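The pushforward constraint and the splitting obstruction are easy to see with finitely many weighted points. A minimal numpy sketch, assuming discrete distributions stored as weight vectors and a map $T$ stored as an array of target indices; `pushforward` and `monge_cost` are illustrative names, not from any library.

```python
import numpy as np

def pushforward(weights, T, n_targets):
    """Discrete pushforward T_# mu: add each source weight to the bin of its image T[i]."""
    out = np.zeros(n_targets)
    np.add.at(out, T, weights)   # unbuffered add, so several sources hitting one target accumulate
    return out

def monge_cost(x, y, weights, T):
    """Integral of c(x, T(x)) d mu(x) with squared cost c(x, y) = (x - y)**2."""
    return float(np.sum(weights * (x - y[T]) ** 2))

# Four equal-weight sources and targets; T sends source i to target i (the sorted pairing).
x = np.array([0.0, 1.0, 4.0, 5.0])
y = np.array([2.0, 3.0, 6.0, 7.0])
mu = np.full(4, 0.25)
nu = np.full(4, 0.25)
T = np.array([0, 1, 2, 3])
feasible = np.allclose(pushforward(mu, T, 4), nu)   # checks T_# mu == nu
cost = monge_cost(x, y, mu, T)                       # 0.25 * 4 * 2**2 = 4.0

# A single point mass cannot be split: whatever T is, T_# delta puts all its mass in one bin,
# so it never equals a two-point target with weights (1/2, 1/2).
point_mass = np.array([1.0])
two_point = np.array([0.5, 0.5])
no_map = all(
    not np.allclose(pushforward(point_mass, np.array([j]), 2), two_point) for j in range(2)
)
```

The last check is the discrete version of “no single-valued function $T$ can do it”: a map can merge mass but never divide it.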

The Kantorovich relaxation sidesteps this existence obstacle. Instead of forcing each $x$ to a single destination, it replaces the map by a transport plan, a coupling $\pi$ with marginals $\mu$ and $\nu$:

$$\inf_{\pi \in \Pi(\mu,\nu)} \int c(x,y)\,d\pi(x,y).$$

A coupling $\pi(x,y)$ is a joint distribution on $(x,y)$ recording “how much mass moves from $x$ to $y$”; $\Pi(\mu,\nu)$ is the set of all couplings whose marginals are $\mu$ and $\nu$. The marginal constraints $\int\pi(x,y)\,dy=\mu(x)$ and $\int\pi(x,y)\,dx=\nu(y)$ formalize “empty $\mu$, fill $\nu$.” Every Monge map corresponds to a coupling concentrated on the graph $y=T(x)$, so the Kantorovich optimum is never larger than the Monge optimum. The key payoff: a coupling always exists (the product measure $\mu\otimes\nu$ is one), and the problem becomes a linear program over $\pi$, since the objective $\int c\,d\pi$ is linear in $\pi$ and so are the marginal constraints. In the discrete case $\pi$ is a nonnegative matrix whose row sums and column sums are pinned to $\mu$ and $\nu$; the feasible set is a polytope, and an optimum can always be found at a vertex (a sparse, near-deterministic assignment). (Site convention: $\pi$ always denotes the coupling; $\gamma$ is reserved for the “temperature” in entropic regularization.)
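The discrete case is small enough to check directly. A minimal numpy sketch, assuming four equal-weight points (the $\{0,1,4,5\}\to\{2,3,6,7\}$ example from the earth-moving intuition later on this page) and squared cost. The vertex enumeration is a toy solver that works only because, for $n$ equal weights, the vertices of $\Pi(\mu,\nu)$ are the permutation matrices scaled by $1/n$ (Birkhoff–von Neumann); it is not how practical LP solvers work.

```python
import itertools
import numpy as np

x = np.array([0.0, 1.0, 4.0, 5.0])
y = np.array([2.0, 3.0, 6.0, 7.0])
n = len(x)
mu = np.full(n, 1.0 / n)
nu = np.full(n, 1.0 / n)
C = (x[:, None] - y[None, :]) ** 2            # cost matrix C[i, j] = (x_i - y_j)^2

def is_coupling(P, mu, nu, tol=1e-12):
    """Nonnegative, rows sum to mu, columns sum to nu."""
    return bool(
        np.all(P >= -tol)
        and np.allclose(P.sum(axis=1), mu)
        and np.allclose(P.sum(axis=0), nu)
    )

# The product coupling mu (x) nu is always feasible, so Pi(mu, nu) is never empty.
P_indep = np.outer(mu, nu)
indep_cost = float(np.sum(C * P_indep))

# Check every vertex (scaled permutation matrix) and keep the cheapest.
best_cost, best_P = np.inf, None
for perm in itertools.permutations(range(n)):
    P = np.zeros((n, n))
    P[np.arange(n), perm] = 1.0 / n
    cost = float(np.sum(C * P))
    if cost < best_cost:
        best_cost, best_P = cost, P
```

Here the independent plan costs 12.5, while the best vertex is the sorted pairing with cost 4 and only $n$ nonzero entries: the sparse, deterministic assignment the paragraph describes.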

Every linear program has a dual. Here the dual trades “which plan to choose” for “how to price the transport”: introduce potentials $f(x)$, $g(y)$ and maximize $\int f\,d\mu+\int g\,d\nu$ subject to $f(x)+g(y)\le c(x,y)$. Read $f,g$ as a price list: the constraint says that no pair of prices may exceed the cost of moving the mass directly from $x$ to $y$, and under mild conditions on $c$ the dual optimum exactly equals the primal optimum (strong duality). For $W_1$ ($p=1$) the optimal potentials can be taken as $g=-f$ with $f$ 1-Lipschitz, so the dual collapses to a supremum over a single 1-Lipschitz potential. This form is developed in Kantorovich duality, and it is the one that turns $W_1$ into the Wasserstein-GAN objective (the discriminator is that 1-Lipschitz potential).
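Strong duality can be seen on a toy instance. A sketch, assuming the same four equal-weight points ($\{0,1,4,5\}$ to $\{2,3,6,7\}$) and linear cost $c(x,y)=|x-y|$. The potential $f(x)=-x$ is a hand-picked guess: by weak duality any feasible pair gives a lower bound on $W_1$, and this 1-Lipschitz choice happens to meet the primal value exactly.

```python
import numpy as np

x = np.array([0.0, 1.0, 4.0, 5.0])   # source support, weights 1/4 each
y = np.array([2.0, 3.0, 6.0, 7.0])   # target support, weights 1/4 each

# Primal: in 1-D with c(x, y) = |x - y| the sorted pairing is optimal, so W1 is the mean gap.
w1_primal = float(np.mean(np.abs(np.sort(x) - np.sort(y))))

def f(t):
    """A 1-Lipschitz potential on the source side."""
    return -t

def g(t):
    """The paired potential g = -f on the target side."""
    return -f(t)

# Dual objective: integral of f d mu plus integral of g d nu.
dual_value = float(np.mean(f(x)) + np.mean(g(y)))

# Dual feasibility: f(x_i) + g(y_j) <= |x_i - y_j| for every pair (i, j).
feasible = bool(np.all(f(x)[:, None] + g(y)[None, :] <= np.abs(x[:, None] - y[None, :])))
```

Both values come out to 2: the price list $f(x)=-x$, $g(y)=y$ charges exactly the displacement $y-x$, which never exceeds $|x-y|$.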

When the cost is a distance raised to the $p$-th power, $c(x,y)=\|x-y\|^p$, the optimal value defines the Wasserstein-$p$ distance:

$$W_p(\mu,\nu) = \left(\inf_{\pi\in\Pi(\mu,\nu)}\int \|x-y\|^p\,d\pi\right)^{1/p}.$$

The outer $1/p$ power is what makes $W_p$ satisfy the triangle inequality and so be a genuine metric. The common choices are $p=2$ (squared cost) and $p=1$ (linear cost, the WGAN case). Unlike the KL divergence, $W_p$ remains finite and meaningful even when the two distributions have disjoint supports (provided both have finite $p$-th moments), and it varies continuously as one distribution is shifted, a property that makes it attractive as a training loss.
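In one dimension the infimum can be evaluated without any solver. A minimal sketch, assuming two empirical distributions with the same number of equally weighted samples; `wasserstein_p_1d` is an illustrative name. For $p\ge1$ the cost $|x-y|^p$ is convex in $x-y$, so pairing order statistics is optimal.

```python
import numpy as np

def wasserstein_p_1d(x, y, p=2.0):
    """W_p between two equal-size, equal-weight 1-D samples.

    For p >= 1, pairing the sorted samples is optimal, so
    W_p^p is the mean of |x_(i) - y_(i)|**p over ranks i.
    """
    x = np.sort(np.asarray(x, dtype=float))
    y = np.sort(np.asarray(y, dtype=float))
    if x.shape != y.shape:
        raise ValueError("this sketch assumes equal sample sizes")
    return float(np.mean(np.abs(x - y) ** p) ** (1.0 / p))

src = [5.0, 0.0, 4.0, 1.0]   # order does not matter: the function sorts
tgt = [7.0, 2.0, 3.0, 6.0]
w1 = wasserstein_p_1d(src, tgt, p=1)
w2 = wasserstein_p_1d(src, tgt, p=2)
```

Both distances equal 2 here because every rank-matched pair is exactly 2 apart; unequal sample sizes or weights would need the quantile-function form instead.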

Intuition

A coupling $\pi(x,y)$ is a transportation schedule: it says how much mass moves from $x$ to $y$. Among all schedules that correctly empty $\mu$ and fill $\nu$, optimal transport selects the cheapest. Where KL asks “how surprising is one distribution under the other,” Wasserstein asks “how far does the mass physically have to move.”

Depth

Adding a relative-entropy penalty yields entropic optimal transport (Sinkhorn). Against a reference coupling $\kappa$, weighted by a temperature $\gamma>0$:

$$\min_{\pi\in\Pi(\mu,\nu)} \int c\,d\pi \;+\; \gamma\,\mathrm{KL}(\pi\Vert\kappa).$$

Taking $\kappa=\mu\otimes\nu$ (the independent product measure) recovers the standard Sinkhorn penalty. The optimal coupling then takes a Gibbs–Boltzmann form, $\pi^\star(x,y)=a(x)\,\kappa(x,y)\,e^{-c(x,y)/\gamma}\,b(y)$, where the scaling functions $a$ and $b$ are fixed by the two marginal constraints; it is strictly positive wherever $\kappa$ is. The entropy penalty prevents the plan from collapsing onto a sparse, near-deterministic support, so the resulting transport is smooth, and Sinkhorn’s iteration solves it efficiently by alternately rescaling rows (updating $a$) and columns (updating $b$). The full derivation and its Cryo-ET counterpart appear in entropic optimal transport.
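A minimal numpy sketch of the alternating rescalings, assuming discrete $\mu$, $\nu$ on the four-point example $\{0,1,4,5\}\to\{2,3,6,7\}$, squared cost, and the product reference $\kappa=\mu\otimes\nu$. `sinkhorn` is an illustrative helper, not a library function, and it works in the plain (non-log) domain, which can underflow when $\gamma$ is small relative to the costs.

```python
import numpy as np

def sinkhorn(mu, nu, C, gamma, max_iter=10_000, tol=1e-10):
    """Entropic OT with reference kappa = mu (x) nu, solved by alternating rescalings.

    Returns pi = diag(a) K diag(b) and the kernel K = kappa * exp(-C / gamma).
    """
    K = np.outer(mu, nu) * np.exp(-C / gamma)   # reference coupling times the Gibbs factor
    a = np.ones_like(mu)
    b = np.ones_like(nu)
    for _ in range(max_iter):
        a = mu / (K @ b)            # rescale rows so they sum to mu
        b = nu / (K.T @ a)          # rescale columns so they sum to nu (exact after this step)
        if np.max(np.abs(a * (K @ b) - mu)) < tol:   # remaining row-marginal error
            break
    P = a[:, None] * K * b[None, :]
    return P, K

x = np.array([0.0, 1.0, 4.0, 5.0])
y = np.array([2.0, 3.0, 6.0, 7.0])
mu = np.full(4, 0.25)
nu = np.full(4, 0.25)
C = (x[:, None] - y[None, :]) ** 2
P, K = sinkhorn(mu, nu, C, gamma=5.0)
transport_cost = float(np.sum(C * P))   # a feasible plan, so at least the exact OT value 4.0
```

Every entry of `P` is positive, and `P / K` factors as $a_i b_j$, which is the Gibbs form above. Shrinking `gamma` pushes `transport_cost` toward the unregularized value, at the price of slower convergence.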

Intuition: moving earth

Picture probability mass as piles of earth. $\mu$ is the current height of earth everywhere, $\nu$ the target profile we want to reach. Carrying a shovel of earth over some distance costs effort that grows with the distance; here the effort is the squared distance, and the total cost adds it up over all the earth moved. Optimal transport asks how to move it for the least total effort, hence the name “earth-mover’s distance” (strictly, that name usually refers to the linear-cost case $W_1$).

In one dimension with squared cost the answer is strikingly simple: sort both sets of piles from small to large, then pair them off by rank. Take four equal-weight units of earth with sources at $\{0,1,4,5\}$ and targets at $\{2,3,6,7\}$. Sorted pairing gives $0\to2,\;1\to3,\;4\to6,\;5\to7$, each unit moving $2$, for a mean squared cost of

$$\tfrac14\big(2^2+2^2+2^2+2^2\big)=4 .$$

Any “crossing” plan is worse: pairing $0\to3$ and $1\to2$ instead costs $3^2+1^2=10$ on those two units, already above the original $2^2+2^2=8$. The monotone (non-crossing) coupling is optimal because squared cost is submodular: for $x<x'$ and $y<y'$, $c(x,y)+c(x',y')<c(x,y')+c(x',y)$ (the difference is $-2(x'-x)(y'-y)$), so any crossing can be “uncrossed” to lower the cost. This is why one dimension admits a closed form, and it is the seed of the higher-dimensional intuition. The catch is that this shortcut holds only in 1-D: in higher dimensions there is no “sort,” and one must actually solve the linear program above, or fall back on the entropic or sliced approximations.
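The claim that rank pairing wins can be checked exhaustively on this example: with four units there are only $4!=24$ one-to-one plans. A small sketch using plain enumeration, which is only sensible for tiny $n$:

```python
import itertools
import numpy as np

src = np.array([0.0, 1.0, 4.0, 5.0])
tgt = np.array([2.0, 3.0, 6.0, 7.0])

def mean_sq_cost(perm):
    """Mean squared distance when source i is sent to target perm[i]."""
    return float(np.mean((src - tgt[list(perm)]) ** 2))

costs = {perm: mean_sq_cost(perm) for perm in itertools.permutations(range(4))}
best_perm = min(costs, key=costs.get)
sorted_cost = costs[(0, 1, 2, 3)]    # rank pairing 0->2, 1->3, 4->6, 5->7
crossed_cost = costs[(1, 0, 2, 3)]   # uncrossed pair swapped: 0->3, 1->2
```

The minimum is the identity permutation (rank pairing) at 4.0; the crossed plan costs 4.5, i.e. the extra $10-8=2$ spread over four units.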

The cost of computing: why 1-D is cheap and high-D is not

Comparing how hard the different routes are to compute traces out the motivation for the whole optimal-transport toolchain. One dimension has a closed form: sort, pair by rank, and the cost is dominated by the sort at $O(n\log n)$. But the exact Kantorovich linear program with $n$ source and $n$ target points has $n^2$ unknowns and $2n$ equality constraints (only $2n-1$ of them independent, since both marginals sum to one); an optimal vertex is sparse, with at most $2n-1$ nonzeros, yet network-simplex or auction solvers still cost roughly $O(n^3\log n)$, which is infeasible when $n$ is a whole dataset. This is exactly why the other two routes exist: entropic regularization softens the linear program into a strongly convex one that Sinkhorn solves with repeated matrix–vector products (each $O(n^2)$), while the sliced idea at the end of this page replaces one expensive high-D transport with many cheap 1-D ones. Deep learning almost never solves the exact Kantorovich LP directly; it uses one of these cheap, differentiable approximations.
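The counts in this paragraph can be read off the constraint matrix of the discrete LP. A minimal numpy sketch, assuming $\pi$ is flattened row by row into a vector of length $n^2$; `marginal_constraints` is an illustrative helper.

```python
import numpy as np

def marginal_constraints(n):
    """Equality-constraint matrix A with A @ pi.ravel() = [row sums; column sums]."""
    A = np.zeros((2 * n, n * n))
    for i in range(n):
        for j in range(n):
            k = i * n + j          # row-major index of pi[i, j]
            A[i, k] = 1.0          # pi[i, j] contributes to row sum i (source marginal)
            A[n + j, k] = 1.0      # ... and to column sum j (target marginal)
    return A

n = 5
A = marginal_constraints(n)
num_unknowns = A.shape[1]               # n**2 = 25
num_constraints = A.shape[0]            # 2n = 10
rank = int(np.linalg.matrix_rank(A))    # 2n - 1 = 9: one constraint is redundant
```

The rank deficit of one is why a basic (vertex) solution has at most $2n-1$ nonzero entries, even though the plan has $n^2$ slots.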

Versus KL: why it still works when supports are disjoint

Let two distributions have entirely disjoint support: $\mu$ a point mass at $x=0$, $\nu$ a point mass at $x=\delta$. Then $\mathrm{KL}(\mu\Vert\nu)=\infty$ for every $\delta\neq0$, no matter how small. KL cannot tell “barely apart” from “far apart,” so it offers no usable gradient (see entropy and KL divergence). Wasserstein instead gives $W_2(\mu,\nu)=|\delta|$: it shrinks continuously to zero as the gap between the two piles closes, and away from $\delta=0$ its derivative with respect to $\delta$ is $\pm1$, always pointing toward closing the gap.
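A numeric version of the point-mass example, assuming both distributions are histograms on a shared 1-D grid. `kl_discrete` and `w2_hist` are illustrative helpers; `w2_hist` uses the 1-D quantile formula $W_2^2=\int_0^1 |F_\mu^{-1}(u)-F_\nu^{-1}(u)|^2\,du$, approximated with a midpoint rule in $u$.

```python
import numpy as np

def kl_discrete(p, q):
    """KL(p || q) for histograms on a shared grid; infinite if p has mass where q has none."""
    p, q = np.asarray(p, float), np.asarray(q, float)
    support = p > 0
    if np.any(q[support] == 0):
        return np.inf
    return float(np.sum(p[support] * np.log(p[support] / q[support])))

def w2_hist(p, q, grid, m=1000):
    """W2 between histograms on a 1-D grid via their quantile (inverse CDF) functions."""
    u = (np.arange(m) + 0.5) / m                     # midpoints of (0, 1)
    last = len(grid) - 1
    ip = np.minimum(np.searchsorted(np.cumsum(p) / np.sum(p), u), last)
    iq = np.minimum(np.searchsorted(np.cumsum(q) / np.sum(q), u), last)
    return float(np.sqrt(np.mean((grid[ip] - grid[iq]) ** 2)))

grid = np.linspace(0.0, 1.0, 101)    # spacing 0.01
results = []
for shift in (1, 10, 50):            # delta = 0.01, 0.1, 0.5
    mu = np.zeros_like(grid)
    mu[0] = 1.0                      # point mass at 0
    nu = np.zeros_like(grid)
    nu[shift] = 1.0                  # point mass at delta
    results.append((grid[shift], kl_discrete(mu, nu), w2_hist(mu, nu, grid)))
```

Each entry of `results` pairs a gap $\delta$ with an infinite KL and a $W_2$ equal to $\delta$, so only the Wasserstein column tells the three cases apart.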

This difference bites hardest early in training. A generated distribution $q_x$ initially overlaps the data distribution $p_y$ almost nowhere, which is exactly the disjoint-support regime. A KL- or density-ratio-based adversarial loss (a discriminator) then either saturates or yields exploding or vanishing gradients, so training is unstable and prone to mode collapse; the optimal-transport cost instead decreases smoothly as the two distributions move geometrically closer, supplying an always-useful descent direction. This is precisely why CryoGEN-II, by replacing CryoGEN-I’s adversarial min–max with an OT loss, trains more stably.

The geometry of Wasserstein space

Because $W_2$ varies smoothly as mass moves, it turns the space of distributions into a geometry: one can interpolate between $\mu$ and $\nu$ by moving along the shortest path, not by fading one into the other. The naive linear (mixture) interpolant $(1-t)\,\mu + t\,\nu$ keeps both piles fixed and merely reweights them: $\mu$’s mass fades out where it sits while $\nu$’s fades in elsewhere, so the path passes through a bimodal blur that resembles neither endpoint. The displacement interpolant of McCann instead transports mass along straight lines. With $T$ the optimal Monge map from $\mu$ to $\nu$, push $\mu$ forward by the partial map that has travelled a fraction $t$ of the way:

$$\mu_t = \big((1-t)\,\mathrm{id} + t\,T\big)_{\#}\,\mu .$$

Symbol by symbol: $\mathrm{id}$ is the identity map ($x\mapsto x$, “stay put”), $T$ is “go all the way to the target,” and their convex combination $(1-t)\,\mathrm{id}+t\,T$ is the partial map that has “travelled a fraction $t$”; the subscript ${}_\#$ pushes $\mu$ forward through it to give the distribution $\mu_t$ at intermediate time $t\in[0,1]$, with endpoints $\mu_0=\mu$ and $\mu_1=\nu$. Each unit of mass slides at constant velocity from $x$ to $T(x)$, so a single bump stays a single bump: it glides across rather than dissolving and re-forming. This path is the geodesic of $W_2$, the shortest curve between the two distributions, and its length accumulates at a constant rate, $W_2(\mu,\mu_t) = t\,W_2(\mu,\nu)$.
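A 1-D sketch contrasting the two interpolants, assuming equal-weight samples so that the optimal map $T$ is the sorted pairing; the helper names are illustrative.

```python
import numpy as np

def w2_1d(x, y):
    """W2 between equal-size, equal-weight 1-D samples (sorted pairing)."""
    return float(np.sqrt(np.mean((np.sort(x) - np.sort(y)) ** 2)))

src = np.array([-1.0, 0.0, 1.0])    # mu: one bump around 0
tgt = np.array([9.0, 10.0, 11.0])   # nu: one bump around 10

def displacement(t):
    """Samples of mu_t = ((1 - t) id + t T)_# mu, with T the sorted (monotone) map."""
    return (1 - t) * np.sort(src) + t * np.sort(tgt)

def mixture(t):
    """Linear interpolant (1 - t) mu + t nu as (points, weights): both bumps stay put."""
    points = np.concatenate([src, tgt])
    weights = np.concatenate([np.full(3, (1 - t) / 3), np.full(3, t / 3)])
    return points, weights

half = displacement(0.5)                      # one bump, now centred at 5
pts, wts = mixture(0.5)
mix_mean = np.sum(wts * pts)
mix_spread = float(np.sqrt(np.sum(wts * (pts - mix_mean) ** 2)))   # std of the two-humped mixture
disp_spread = float(np.std(half))                                   # std of the single bump
geodesic_ok = all(
    np.isclose(w2_1d(src, displacement(t)), t * w2_1d(src, tgt)) for t in (0.0, 0.25, 0.5, 1.0)
)
```

At $t=1/2$ the displacement interpolant is $\{4,5,6\}$ with the original spread (about 0.82), while the mixture has a standard deviation above 5 because its mass still sits at both ends; `geodesic_ok` confirms $W_2(\mu,\mu_t)=t\,W_2(\mu,\nu)$ at several times.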

Figure (interactive): source $\mu$, target $\nu$, and the interpolant $\mu_t$ as $t$ runs from 0 to 1. The linear interpolant $(1-t)\mu + t\nu$ makes mass fade out on the left and reappear on the right, passing through a two-humped blur that resembles neither end. The displacement interpolant slides a single bump's mean from $\mu$ to $\nu$, staying unimodal: the geodesic (shortest path) of Wasserstein space, which averages positions instead of pictures.

Intuition

Linear interpolation teleports mass — it dims one pile and brightens another, so a halfway frame shows two half-height piles at once. Displacement interpolation walks the mass: at the halfway frame the single pile sits halfway between source and target. The first averages pictures; the second averages positions.

In Cryo-ET: moving what onto what

In Cryo-ET the two distributions CryoGEN-II transports are concrete. The target is the distribution of real observations $p_y$ (the many real tomograms). The source is the distribution of generated clean reconstructions $q_x$: the aggregate of the volumes $x$ the network outputs across observations, living in reconstruction space $\mathcal{X}$. The two live in different spaces, so the cost does not compare $x$ and $y$ directly; it first pushes a reconstruction back into observation space with the same missing-wedge degradation operator $\mathcal{T}_M$ used in CryoGEN, then takes the squared distance,

$$c\big(y,\mathcal{T}_M(x)\big)=\big\lVert y-\mathcal{T}_M(x)\big\rVert^2 ,$$

giving the transport cost

$$\mathcal{W}_c(p_y,q_x)=\inf_{\pi\in\Pi(p_y,q_x)}\mathbb{E}_{(y,x)\sim\pi}\big\lVert y-\mathcal{T}_M(x)\big\rVert^2 .$$

Symbol by symbol: $\mathcal{T}_M$ is the missing-wedge degradation operator (it turns a clean volume $x$ into what the microscope would actually record, with a wedge of Fourier information missing); $\lVert y-\mathcal{T}_M(x)\rVert^2$ measures how far a “degraded generated volume” sits from a “real observation”; $\pi\in\Pi(p_y,q_x)$ ranges over couplings whose marginals are the real $p_y$ and the generated $q_x$; the $\inf$ and the expectation together are the optimal-transport cost between these two whole distributions. Minimizing it forces the aggregate of “generated volumes, after degradation” to align with the aggregate of “real observations,” with the aim that the aggregated reconstructions $\int q(x\mid y)\,p(y)\,dy$ approach the clean prior $p(x)$ (matching after degradation constrains only what $\mathcal{T}_M$ keeps, so this is the intended effect rather than a strict equivalence). There is no ground-truth $x$ anywhere: since $p(x)$ is unavailable, supervision comes entirely from the real observations $p_y$ (CryoGEN’s $P_Y$ proxy). It is a single deterministic optimization that still returns one reconstruction per observation; upgrading it to a family of reconstructions is the job of CryoWGEN.
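A toy sketch of the cost structure, with loud assumptions: small 2-D images stand in for volumes, `degrade` is a hypothetical Fourier-mask stand-in for $\mathcal{T}_M$ (not CryoGEN’s operator), and the infimum over couplings is evaluated on a tiny minibatch by brute-force exact assignment rather than by whatever solver CryoGEN-II actually uses. All helper names are illustrative.

```python
import itertools
import numpy as np

def wedge_mask(n, max_tilt_deg=60.0):
    """Toy 2-D 'missing wedge': keep Fourier frequencies within max_tilt_deg of the kx axis."""
    k = np.fft.fftfreq(n)
    kx, kz = np.meshgrid(k, k, indexing="ij")
    angle = np.degrees(np.arctan2(np.abs(kz), np.abs(kx)))   # 0 deg along kx; DC gives 0
    return angle <= max_tilt_deg

def degrade(x, mask):
    """Hypothetical stand-in for T_M: zero the masked-out frequencies of each image in a batch."""
    return np.real(np.fft.ifft2(np.fft.fft2(x, axes=(-2, -1)) * mask, axes=(-2, -1)))

def minibatch_ot_cost(y, x, mask):
    """Exact OT between equal-size, equal-weight batches with c(y, T_M(x)) = ||y - T_M(x)||^2.

    With uniform weights an optimal coupling is a permutation, found here by brute force.
    """
    tx = degrade(x, mask)
    C = ((y[:, None] - tx[None, :]) ** 2).sum(axis=(-2, -1))   # C[i, j] = ||y_i - T_M(x_j)||^2
    n = len(y)
    return min(C[np.arange(n), list(p)].mean() for p in itertools.permutations(range(n)))

rng = np.random.default_rng(0)
mask = wedge_mask(8)
x_true = rng.normal(size=(4, 8, 8))             # hidden clean images, never paired with y
y_obs = degrade(x_true[[2, 0, 3, 1]], mask)     # observations arrive unpaired, in any order
cost_match = minibatch_ot_cost(y_obs, x_true, mask)                        # ~0
cost_other = minibatch_ot_cost(y_obs, rng.normal(size=(4, 8, 8)), mask)    # clearly > 0
```

Because the cost only ever compares observations with degraded candidates, it stays near zero for candidates that differ from `x_true` only inside the masked-out wedge; the unpaired, shuffled `y_obs` is still matched exactly, which is the point of a distribution-level (rather than per-image) objective.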

Intuition

One degradation operator $\mathcal{T}_M$, two uses. CryoGEN-I does per-image MAP: for each $y$ it solves separately for a single $x^*$ minimizing $\lVert\mathcal{T}_M(x)-y\rVert^2$ plus an energy prior, image by image. CryoGEN-II does global distribution matching: rather than aligning image by image, it requires the aggregate distribution $q_x$ of reconstructions, once degraded by $\mathcal{T}_M$, to be transported onto $p_y$, the cost measured by $\mathcal{W}_c$ above. The former is faithful to each image, the latter faithful to the statistics of the whole dataset, and the latter is less prone to artifacts that contradict the real distribution of structures.

Unregularized optimal transport is the foundation of CryoGEN-II’s global distribution matching. Adding an entropy penalty to this cost goes one step further, giving the Boltzmann posterior of CryoWGEN / EVIA and upgrading a single reconstruction into a family of reconstructions.

From distance to autoencoder

Turning the Wasserstein distance from a number to compute into a trainable objective is exactly the Wasserstein autoencoder (WAE) idea: an encoder compresses data into a latent space, a decoder maps it back, and the transport cost between the data distribution and the model distribution is minimized, which WAE rewrites as a reconstruction cost plus a penalty pulling the aggregated posterior toward the prior (Tolstikhin et al., 2018). The “aggregated posterior” here is the overall distribution of all data points after the encoder projects them into the latent space; requiring it to match the prior, rather than dragging each point’s posterior to the prior one by one as a VAE does, is the key shift WAE makes over a VAE, and it is the same idea as CryoGEN-II’s aggregate distribution matching in the previous section. This “optimal transport → autoencoder” thread is the shared skeleton of WAE, EVIA, and in turn CryoGEN-II and CryoWGEN. Its entropic-regularized version (see Sinkhorn) is the Sinkhorn autoencoder (Patrini et al., 2020); and in high dimensions, replacing one expensive high-dimensional transport with many cheap 1-D transports along random projections gives the sliced-Wasserstein autoencoder (Kolouri et al., 2019).

That last idea reuses the very closed form the opening demo already shows. Project both distributions onto a random direction $\theta$ (the projections $\theta_{\#}\mu$ and $\theta_{\#}\nu$ are now 1-D) and average the resulting one-dimensional distances over directions. This is the sliced-Wasserstein distance:

$$\mathrm{SW}_2^2(\mu,\nu) = \mathbb{E}_{\theta}\big[\,W_2^2(\theta_{\#}\mu,\;\theta_{\#}\nu)\,\big].$$

Symbol by symbol: $\theta$ is a random direction on the unit sphere; $\theta_\#\mu$ is $\mu$ projected along $\theta$ (each high-dimensional point $x$ replaced by the scalar $\langle\theta,x\rangle$), giving a 1-D distribution; $W_2^2(\theta_\#\mu,\theta_\#\nu)$ is the squared Wasserstein distance of that 1-D pair; and the outer expectation $\mathbb{E}_\theta$ averages over directions. Each slice is the 1-D problem from the top of this page: sort the projected source and target samples, pair by rank, and sum the squared gaps, an $O(n\log n)$ cost dominated entirely by the sort, with no linear program and no Sinkhorn iteration. Averaging over a handful of random $\theta$ gives a cheap, differentiable Monte Carlo estimate. $\mathrm{SW}_2$ is itself a genuine distance (Bonneel et al., 2015): symmetry and the triangle inequality carry over from the 1-D slices (the square root of an average of squared slice distances still obeys the triangle inequality, by Minkowski’s inequality), and it separates distributions because a distribution is determined by all of its 1-D projections.
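A minimal Monte Carlo sketch of the estimator, assuming two point clouds with the same number of equally weighted samples; `sliced_w2_squared` and its defaults are illustrative, not taken from Bonneel et al. or any library.

```python
import numpy as np

def sliced_w2_squared(X, Y, n_dirs=2000, rng=None):
    """Monte Carlo estimate of SW_2^2 between equal-size, equal-weight point clouds in R^d."""
    rng = np.random.default_rng(rng)
    d = X.shape[1]
    theta = rng.normal(size=(n_dirs, d))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)   # uniform directions on the sphere
    px = np.sort(X @ theta.T, axis=0)    # column k: sorted projections onto theta_k
    py = np.sort(Y @ theta.T, axis=0)
    # Mean over ranks gives W2^2 of each 1-D slice; mean over columns averages the slices.
    return float(np.mean((px - py) ** 2))

rng = np.random.default_rng(0)
X = rng.normal(size=(300, 3))
shift = np.array([1.0, 0.0, 0.0])
est = sliced_w2_squared(X, X + shift, n_dirs=5000, rng=1)
```

The translation gives a quick sanity check: shifting by $\mathbf v$ moves every projection by $\langle\theta,\mathbf v\rangle$ without changing the sort order, so each slice contributes exactly $\langle\theta,\mathbf v\rangle^2$ and $\mathrm{SW}_2^2=\lVert\mathbf v\rVert^2/d$, here $1/3$, which `est` should approach.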

The Wasserstein distance underpins the Wasserstein GAN and the Wasserstein autoencoder, and supplies the distribution-matching objective used in CryoWGEN and CryoGEN.
