Wasserstein Autoencoder (WAE)

An autoencoder derived from the optimal-transport view, matching the aggregated posterior to the prior with an MMD or adversarial penalty rather than a per-sample KL.

A Wasserstein autoencoder (WAE) is a generative autoencoder derived from the optimal-transport view of distribution matching. It minimizes a reconstruction cost together with a divergence between the aggregated (marginal) posterior and the prior:

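[Figure: WAE architecture. Input x → encoder Q(z | x) → latent code z → decoder G(z) → reconstruction, scored by the reconstruction cost c(x, x̂). The aggregated posterior q_Z (averaged over all x) is matched to the prior p_Z (e.g. standard normal) by an MMD or adversarial penalty; matching acts on the aggregate, not per sample.]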

[Interactive demo: how the aggregated posterior q(z) matches the prior p(z) as the matching strength is adjusted, using a kernel-MMD penalty (no discriminator, closed-form divergence).]

The WAE measures the gap between the aggregated posterior $q(z)$ and the prior with the MMD of a fixed Gaussian kernel: a closed-form quantity over sample pairs, with no discriminator anywhere. Raise the penalty weight $\lambda$ and the MMD falls as $q(z)$ spreads to cover the prior, while individual codes stay informative, so samples are sharper. The AAE demo shows the same match done by an adversarial discriminator instead.

The whole picture up front: a WAE does exactly two things. First, it encodes each input $x$ to a latent code $z$ and decodes it back, reconstructing as faithfully as possible. Second, it makes the cloud of all latent codes, taken together, conform in its overall shape to a fixed simple prior, usually a standard Gaussian. The second job concerns only the overall outline of that cloud, not where any particular $x$ lands, and that single distinction is what separates the WAE from the VAE. The VAE pulls every input's code distribution toward the prior individually; the WAE asks only that the codes, pooled over the data, be distributed like the prior. Swapping a per-point constraint for a whole-cloud one looks minor, but it largely decides whether samples come out blurry or sharp.

Intuition

Picture two clouds of points in latent space: the codes the encoder actually produces, and samples from the target prior. You want them to overlap. In its MMD form, the WAE does not train any referee; it writes down a closed-form formula, a fixed-kernel maximum mean discrepancy (MMD), that measures the distance between the two clouds directly over sample pairs, with no discriminator anywhere. By contrast, the adversarial autoencoder performs the same match but hires a trained adversarial discriminator as referee. Dropping the referee buys stability and reproducibility; the price is that the kernel carries a fixed built-in notion of “close,” so its bandwidth has to be chosen well.

$$\min_{Q}\; \mathbb{E}_{p_X}\,\mathbb{E}_{Q(z\mid x)}\!\big[c\big(x, G(z)\big)\big] + \lambda\,\mathcal{D}\!\big(q_Z,\,p_Z\big), \qquad q_Z(z)=\int Q(z\mid x)\,p_X(x)\,dx.$$

Read the symbols one at a time. $\mathbb{E}_{p_X}$ averages over the data distribution, i.e. over all training samples $x$; $Q(z\mid x)$ is the encoder, giving for each $x$ a distribution over codes $z$ (which may collapse to a point); $c(x, G(z))$ is the transport cost between the decoded $G(z)$ and the original $x$, often the squared error $\|x-G(z)\|^2$; $G$ is the decoder. Taken together, the first term says one thing: average reconstruction error. $\mathcal{D}(q_Z, p_Z)$ is a divergence between the aggregated posterior $q_Z$ and the prior $p_Z$, weighted by $\lambda$. The trailing integral defines $q_Z$: it mixes each input’s code distribution $Q(z\mid x)$, weighted by how often $x$ occurs in the data, $p_X(x)$. The result is the marginal shape of the whole latent cloud once you forget which code came from which input. The form arises from the optimal-transport distance between the data distribution and the model: when the decoder is deterministic, the coupling between $p_X$ and the generated distribution factors through the latent space, reducing the transport problem to reconstruction plus a single constraint on $q_Z$.
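The integral defining the aggregated posterior is easiest to read as a two-step sampling recipe: draw an input from the data, then draw a code from that input's encoder distribution. The sketch below illustrates this under stated assumptions; `toy_encode`, its fixed linear map, and the constant noise scale of 0.1 are hypothetical stand-ins for a trained encoder, not part of any particular WAE implementation.

```python
import numpy as np

def sample_aggregated_posterior(data, encode, rng, n_samples):
    """Draw z ~ q_Z by ancestral sampling: x ~ p_X, then z ~ Q(z | x)."""
    idx = rng.integers(0, len(data), size=n_samples)  # x ~ empirical p_X
    x = data[idx]
    mean, std = encode(x)                             # parameters of Q(z | x)
    return mean + std * rng.standard_normal(mean.shape)

# Hypothetical toy encoder: a fixed linear map plus a small constant noise scale.
def toy_encode(x):
    W = np.array([[1.0, 0.0], [0.0, 1.0], [0.5, -0.5]])  # 3-d input to 2-d latent
    mean = x @ W
    return mean, np.full_like(mean, 0.1)

rng = np.random.default_rng(0)
data = rng.standard_normal((1000, 3))   # stand-in dataset
z = sample_aggregated_posterior(data, toy_encode, rng, n_samples=512)
```

Setting the returned noise scale to zero gives a deterministic encoder; the recipe is unchanged, because only the pooled samples `z` enter the divergence term.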

That constraint is what separates WAE from the variational autoencoder. The VAE penalizes the KL divergence of each per-sample posterior $q_\phi(z\mid x)$ against the prior; the WAE only requires that the aggregate $q_Z$, the marginal latent distribution obtained after averaging over the data, match $p_Z$. Different inputs may therefore map to overlapping or even identical regions, and the encoder may be deterministic. Removing the per-sample penalty relaxes a force that, in VAEs, tends to blur reconstructions, so WAEs often produce sharper samples.
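A one-dimensional worked case may make the gap between the two constraints concrete. The setup is an assumption chosen for convenience, not taken from the text: the prior is $\mathcal{N}(0,1)$, every per-sample posterior is $\mathcal{N}(\mu(x), \sigma^2)$ with a shared $\sigma < 1$, and the means $\mu(x)$ are distributed as $\mathcal{N}(0, 1-\sigma^2)$ across the data.

```latex
q_Z \;=\; \mathcal{N}(0,\,1-\sigma^2) * \mathcal{N}(0,\,\sigma^2) \;=\; \mathcal{N}(0,1) \;=\; p_Z
\quad\Longrightarrow\quad \mathcal{D}(q_Z, p_Z) = 0,

\mathbb{E}_{p_X}\,\mathrm{KL}\big(\mathcal{N}(\mu(x),\sigma^2)\,\big\|\,\mathcal{N}(0,1)\big)
\;=\; \tfrac{1}{2}\Big(\sigma^2 + \mathbb{E}\big[\mu(x)^2\big] - 1 - \log\sigma^2\Big)
\;=\; \tfrac{1}{2}\Big(\sigma^2 + (1-\sigma^2) - 1 - \log\sigma^2\Big)
\;=\; -\log\sigma \;>\; 0 .
```

In this toy case the aggregate penalty is exactly zero for every $\sigma$, while the averaged per-sample KL grows without bound as $\sigma \to 0$, that is, as the encoder becomes deterministic.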

Aggregate matching is also what makes sampling from the prior generative: after training, a latent code drawn from $p_Z$ passes through the decoder to yield a new sample. When $q_Z$ is close to $p_Z$, such draws fall within latent regions the decoder has seen, avoiding the “holes” (regions of the prior with no encoded support) that a too-loose aggregate match leaves behind. The divergence weight $\lambda$ trades reconstruction fidelity against how tightly the latent distribution conforms to the prior: small $\lambda$ gives crisp reconstructions but leaves holes, so prior samples decode to artifacts; large $\lambda$ pins the latent to the prior but starts to drag reconstruction down.

The optimal-transport derivation

The objective is not posited but derived. The transport cost between the data distribution $p_X$ and the model $p_G$ (the decoder applied to the prior) is an infimum over couplings $\Gamma$ with those two marginals,

$$W_c(p_X, p_G) = \inf_{\Gamma\in\mathcal{P}(p_X, p_G)} \mathbb{E}_{(x, y)\sim\Gamma}\big[c(x, y)\big].$$

Here $\Gamma$ is a joint “transport plan”: $\mathcal{P}(p_X, p_G)$ is the set of all joint distributions whose two marginals are $p_X$ and $p_G$, $\mathbb{E}_{(x,y)\sim\Gamma}[c(x,y)]$ is the average cost of moving $x$ to $y$ under that plan, and the $\inf$ ranges over all plans to find the cheapest. When the decoder $G$ is a deterministic map, every $y$ in the model is the image $G(z)$ of some latent $z$, so the coupling can be expressed through the latent variable rather than over the full product of $x$ and $y$. The transport problem then factors through the latent space: minimizing over couplings reduces to minimizing reconstruction cost $c(x, G(z))$ over conditional encoders $Q(z\mid x)$, subject to the single marginal constraint that the aggregated posterior $q_Z$ equal the prior $p_Z$. Relaxing that hard constraint to a penalty $\lambda\,\mathcal{D}(q_Z, p_Z)$ yields the WAE objective. With the hard constraint in place the reformulation is exact; the approximation enters only when the constraint is relaxed to a penalty.
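Stated compactly, and assuming a deterministic decoder as in the paragraph above, the reformulation and its relaxation read roughly as follows. The subscript $Q:\, q_Z = p_Z$ is notation introduced here for the constrained infimum.

```latex
W_c(p_X, p_G)
\;=\; \inf_{Q:\; q_Z = p_Z}\; \mathbb{E}_{p_X}\,\mathbb{E}_{Q(z\mid x)}\big[c\big(x, G(z)\big)\big]
\qquad \text{(exact, hard constraint)}

\min_{Q}\; \mathbb{E}_{p_X}\,\mathbb{E}_{Q(z\mid x)}\big[c\big(x, G(z)\big)\big] + \lambda\,\mathcal{D}(q_Z, p_Z)
\qquad \text{(relaxed, penalty form)}
```

The first line replaces the search over couplings $\Gamma$ by a search over encoders; the second trades the equality constraint for a penalty weighted by $\lambda$.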

It is worth spelling out what this step saves. Optimizing the coupling $\Gamma$ between pairs of images directly means searching the full $x\times y$ product space, whose size explodes with the pixel count. The latent factorization compresses that search into “encoder + decoder + one latent-space constraint”: the first two are an ordinary autoencoder forward pass, the last a distribution match in a low-dimensional (latent) space. Pushing the cost from transport in image space down to matching in latent space is exactly why a WAE is trainable.

Deterministic versus stochastic encoders, and the source of sharpness

Because only the aggregate $q_Z$ is constrained, the encoder $Q(z\mid x)$ may be deterministic, a point map $z=Q(x)$, without violating the objective; a stochastic encoder is permitted but not required. The VAE, by contrast, depends on a stochastic per-sample posterior whose KL-to-prior term acts on every $x$ separately. That per-sample force inflates each posterior toward the prior and overlaps the codes of distinct inputs, and the decoder, asked to reconstruct from these overlapping codes, hedges by producing blurred averages. Matching only the aggregate removes that per-sample pressure: distinct inputs may keep separated, even non-overlapping codes so long as their average matches $p_Z$, leaving the decoder free to commit to sharp reconstructions. This is the mechanism behind the WAE’s typically sharper samples.

A small picture nails down where the blur comes from. Take two very different inputs $x_1$ and $x_2$. The VAE’s per-sample KL inflates both $q(z\mid x_1)$ and $q(z\mid x_2)$ toward the prior’s center, and their high-probability regions begin to overlap; for a $z$ that lands in the overlap, the decoder is asked to produce both $x_1$ and $x_2$, and the squared-error-optimal compromise is their average, which visually is blur. The WAE applies no such per-sample force, so $Q(x_1)$ and $Q(x_2)$ can sit far apart and not overlap, each $z$ maps to a single target, the decoder need not hedge, and the result is sharp.
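The two-input picture can be checked numerically. The sketch below assumes a one-dimensional latent, two equally likely inputs, and Gaussian code distributions with a shared width; the function name `optimal_decoder` and all numbers are illustrative. Under squared error, the best output at a given $z$ is the average of the targets weighted by how likely each input is to have produced that $z$.

```python
import numpy as np

def optimal_decoder(z, means, std, targets):
    """Squared-error-optimal output at latent z when inputs share Gaussian codes.

    With equally likely inputs x_i and code distributions N(means[i], std^2),
    the minimizer of the expected squared error is the responsibility-weighted
    average of the targets x_i.
    """
    log_w = -0.5 * ((z - means) / std) ** 2   # log q(z | x_i) up to a shared constant
    w = np.exp(log_w - log_w.max())
    w /= w.sum()                              # responsibilities p(x_i | z)
    return w @ targets

targets = np.array([[0.0, 0.0], [1.0, 1.0]])  # two very different "images" x_1, x_2
means = np.array([-1.0, 1.0])                 # their code centers

# Wide, overlapping codes: z = 0 is equally claimed by both inputs.
blurred = optimal_decoder(0.0, means, std=1.0, targets=targets)
# Narrow, separated codes: z = -1 belongs almost entirely to x_1.
sharp = optimal_decoder(-1.0, means, std=0.1, targets=targets)
```

With overlapping codes the output at the midpoint is the plain average of the two targets; with separated codes the output collapses onto a single target.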

Depth

Two estimators of $\mathcal{D}(q_Z, p_Z)$ are standard. WAE-MMD uses a maximum mean discrepancy with a fixed kernel: a closed-form, sample-based penalty with no extra network. Given a kernel $k$, the MMD maps both clouds of samples into a reproducing-kernel Hilbert space and compares their mean embeddings: $\mathrm{MMD}^2 = \mathbb{E}_{z,z'\sim q_Z}k(z,z') + \mathbb{E}_{z,z'\sim p_Z}k(z,z') - 2\,\mathbb{E}_{z\sim q_Z,\, z'\sim p_Z}k(z,z')$, each term estimated as a mean over minibatch sample pairs, vanishing when the two distributions coincide. A common choice is the inverse-multiquadratic kernel $k(z,z')=C/(C+\|z-z'\|^2)$, whose heavier tails make it more sensitive to outlier codes far from the origin than a Gaussian kernel. WAE-GAN trains a discriminator in latent space to estimate the divergence adversarially, which is more flexible but reintroduces minimax instability: when the discriminator lags the moving $q_Z$, the gradient handed to the encoder is biased.
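A minimal NumPy sketch of the MMD penalty with the inverse-multiquadratic kernel is given below. It uses the unbiased estimator, which excludes self-pairs from the within-sample terms. The function names, the batch sizes, and the bandwidth choice `C = 2 * d` (the expected squared distance between two standard-normal draws in `d` dimensions) are assumptions made for illustration, not a prescribed setting.

```python
import numpy as np

def imq_kernel(a, b, C):
    """Inverse-multiquadratic kernel k(z, z') = C / (C + ||z - z'||^2)."""
    sq_dist = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return C / (C + sq_dist)

def mmd2_unbiased(z_q, z_p, C):
    """Unbiased minibatch estimate of MMD^2 between codes z_q and prior draws z_p."""
    n, m = len(z_q), len(z_p)
    k_qq = imq_kernel(z_q, z_q, C)
    k_pp = imq_kernel(z_p, z_p, C)
    k_qp = imq_kernel(z_q, z_p, C)
    # Drop the diagonal (self-pairs) so within-sample terms average over distinct pairs.
    term_qq = (k_qq.sum() - np.trace(k_qq)) / (n * (n - 1))
    term_pp = (k_pp.sum() - np.trace(k_pp)) / (m * (m - 1))
    return term_qq + term_pp - 2.0 * k_qp.mean()

rng = np.random.default_rng(0)
d = 8
C = 2.0 * d                                    # assumed bandwidth heuristic
prior = rng.standard_normal((256, d))          # samples from p_Z
matched = rng.standard_normal((256, d))        # codes already distributed like the prior
shifted = rng.standard_normal((256, d)) + 2.0  # codes displaced from the prior

mmd_matched = mmd2_unbiased(matched, prior, C)
mmd_shifted = mmd2_unbiased(shifted, prior, C)
```

For matched clouds the estimate fluctuates around zero (and can be slightly negative, since it is unbiased rather than non-negative); for the shifted cloud it is clearly positive, which is the signal the encoder is trained to reduce.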

The adversarial variant is closely related to the adversarial autoencoder, which also matches an aggregated posterior to a prior through a discriminator. Both can be read as special cases of the broader transport-based matching that EVIA develops with entropic optimal transport.

Where this lands in Cryo-ET

Cryo-ET reconstruction has no ground-truth volume to imitate, only tomograms scarred by the missing wedge. The usable supervision is a prior on what real structures look like, and “matching a distribution to a prior” is exactly the problem the WAE solves in latent space. Swapping a discriminator for a closed-form, transport-based match is the same step CryoGEN-II takes past CryoGEN-I: CryoGEN-I uses a discriminator as a point-estimate restorer and inherits the minimax instability described under Depth; CryoGEN-II follows the WAE/OT route, trading the moving referee for a stable optimal-transport objective that yields a single answer. Carrying aggregate matching one step further, to the entropic transport of EVIA, underlies CryoWGEN, which no longer returns a single restored volume but expands it into a posterior family, exposing exactly which details the missing wedge leaves undetermined.
