Wasserstein Autoencoder (WAE)
An autoencoder derived from the optimal-transport view, matching the aggregated posterior to the prior with an MMD or adversarial penalty rather than a per-sample KL.
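A Wasserstein autoencoder (WAE) is a generative autoencoder derived from the optimal-transport view of distribution matching. It minimizes a reconstruction cost together with a divergence between the aggregated (marginal) posterior and the prior:

$$\min_{G,\;q(z\mid x)}\; \mathbb{E}_{x\sim p_X}\,\mathbb{E}_{z\sim q(z\mid x)}\big[c\big(x, G(z)\big)\big] \;+\; \lambda\,\mathcal{D}\big(q(z),\, p(z)\big), \qquad q(z) = \int q(z\mid x)\,p_X(x)\,dx.$$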
Figure (interactive): the aggregated posterior q(z) matched to the prior by a fixed Gaussian-kernel MMD, with no discriminator, as the penalty weight λ varies; raising λ lowers the MMD as q(z) spreads to cover the prior while individual codes stay informative. A companion AAE figure performs the same match with an adversarial discriminator.
State the whole picture up front: a WAE does exactly two things. First, encode each input to a latent code and decode it back, reconstructing as faithfully as possible. Second, make the cloud of all latent codes taken together conform, in its overall shape, to a fixed simple prior, usually a standard Gaussian. The second job concerns only the overall outline of that cloud, not where any particular input's code lands, and that single distinction is what separates it from the VAE. The VAE asks every input's code distribution to move toward the prior on its own; the WAE asks only that the codes look like the prior on average. Swapping a per-point constraint for a whole-cloud one looks minor, but it largely decides whether samples come out blurry or sharp.
Picture two clouds of points in latent space: the codes the encoder actually produces, and samples from the target prior. You want them to overlap. The MMD variant of the WAE trains no referee at all; it writes down a closed-form formula, a fixed-kernel maximum mean discrepancy (MMD), that measures the distance between the two clouds directly over sample pairs. By contrast, the adversarial autoencoder performs the same match but hires a trained adversarial discriminator as referee. Dropping the referee buys stability and reproducibility; the price is that the kernel carries a fixed, built-in notion of “close,” so its bandwidth has to be chosen well.
Read the symbols one at a time. $\mathbb{E}_{x\sim p_X}$ averages over the data distribution $p_X$, i.e. over all training samples $x$; $q(z\mid x)$ is the encoder, giving for each $x$ a distribution over codes $z$ (which may collapse to a point); $c\big(x, G(z)\big)$ is the transport cost between the decoded $G(z)$ and the original $x$, often the squared error $\lVert x - G(z)\rVert^2$; $G$ is the decoder. Together, the first term says one thing: average reconstruction error. $\mathcal{D}\big(q(z), p(z)\big)$ is a divergence between the aggregated posterior $q(z)$ and the prior $p(z)$, weighted by $\lambda$. The trailing integral defines $q(z)$: it mixes each input's code distribution $q(z\mid x)$, weighted by how often $x$ occurs in the data, giving the marginal shape of the whole latent cloud once you forget which code came from which input. The form arises from the optimal-transport distance between the data distribution and the model: when the decoder is deterministic, the coupling between $p_X$ and the generated distribution $p_G$ factors through the latent space, reducing the transport problem to reconstruction plus a single constraint on $q(z)$.
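The trailing integral can be read as a recipe for sampling $q(z)$: draw an input from the data, then draw its code. A minimal Python sketch makes this ancestral sampling concrete; the toy 1-D dataset and the two encoders are hypothetical illustrations, not taken from the source.

```python
import random
import statistics


def sample_aggregated_posterior(data, encode, n, rng):
    """Draw n codes from q(z) = ∫ q(z|x) p_X(x) dx by ancestral sampling:
    pick a training input x from the empirical data distribution p_X,
    then draw z ~ q(z|x) from the encoder."""
    return [encode(rng.choice(data), rng) for _ in range(n)]


# Hypothetical toy encoders, for illustration only.
def deterministic_encoder(x, rng):
    # rng is unused; kept so both encoders share one signature.
    return 0.5 * x  # q(z|x) collapses to a point mass at 0.5 * x


def gaussian_encoder(x, rng):
    return rng.gauss(0.5 * x, 0.1)  # q(z|x) = N(0.5 * x, 0.1^2)


rng = random.Random(0)
data = [-2.0, 0.0, 2.0]
point_codes = sample_aggregated_posterior(data, deterministic_encoder, 1000, rng)
noisy_codes = sample_aggregated_posterior(data, gaussian_encoder, 20000, rng)

print(sorted(set(point_codes)))  # [-1.0, 0.0, 1.0]: a mixture of three point masses
print(round(statistics.fmean(noisy_codes), 2))  # mean of a mixture of three narrow Gaussians
```

With a deterministic encoder the aggregate is just a mixture of point masses; only that mixture as a whole is compared with $p(z)$, never any single $q(z\mid x)$.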
That constraint is what separates WAE from the variational autoencoder. The VAE penalizes the KL divergence of each per-sample posterior $q(z\mid x)$ against the prior $p(z)$; the WAE only requires that the aggregate $q(z)$, the marginal latent distribution obtained after averaging over the data, match $p(z)$. Individual inputs are therefore free to map to narrow, well-separated regions, and the encoder may even be deterministic. Removing the per-sample penalty relaxes a force that, in VAEs, tends to blur reconstructions, so WAEs often produce sharper samples.
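A small worked example shows how far the per-sample KL and the aggregate constraint can disagree. The construction is an illustrative assumption, not from the source: a 1-D Gaussian encoder $q(z\mid x) = \mathcal{N}(\mu_x, \sigma^2)$ whose means are themselves spread as $\mathcal{N}(0, 1-\sigma^2)$. Its aggregate $q(z)$ is then exactly the standard-normal prior, so any aggregate divergence is zero, yet the VAE's average per-sample KL works out to $\tfrac12\big(\sigma^2 + (1-\sigma^2) - 1 - \ln\sigma^2\big) = \ln(1/\sigma)$, which grows without bound as the codes sharpen.

```python
import math
import random


def kl_to_standard_normal(mu, sigma):
    """Closed-form KL( N(mu, sigma^2) || N(0, 1) ) for a 1-D Gaussian posterior."""
    return 0.5 * (sigma**2 + mu**2 - 1.0 - math.log(sigma**2))


def average_per_sample_kl(sigma, n=100_000, seed=0):
    """Monte Carlo average of the VAE's per-sample KL for the toy encoder
    q(z|x) = N(mu_x, sigma^2) with mu_x ~ N(0, 1 - sigma^2).
    The aggregate q(z) of this encoder is exactly N(0, 1), the prior."""
    rng = random.Random(seed)
    spread = math.sqrt(1.0 - sigma**2)
    total = sum(kl_to_standard_normal(rng.gauss(0.0, spread), sigma) for _ in range(n))
    return total / n


for sigma in (0.5, 0.1, 0.01):
    # Monte Carlo estimate printed next to the closed form ln(1 / sigma).
    print(sigma, round(average_per_sample_kl(sigma), 3), round(math.log(1.0 / sigma), 3))
```

The VAE would push hard against these sharp codes even though their aggregate already sits exactly on the prior; an aggregate-only penalty sees nothing to correct.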
Aggregate matching is also what makes sampling from the prior generative: after training, a latent code drawn from $p(z)$ passes through the decoder to yield a new sample. When $q(z)$ is close to $p(z)$, such draws fall within latent regions the decoder has seen, avoiding the “holes” (regions of the prior with no encoded support) that a loosely enforced match tends to leave. The divergence weight $\lambda$ trades reconstruction fidelity against how tightly the latent distribution conforms to the prior: a small $\lambda$ gives crisp reconstructions but leaves holes, so prior samples decode to artifacts; a large $\lambda$ pins the latent distribution to the prior but starts to drag reconstruction quality down.
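The “holes” can be made tangible with a crude diagnostic. The function below is a hypothetical helper, not part of the WAE method: it draws from the prior and counts how many draws have no encoded code within a chosen radius. Codes that occupy only a small region of latent space, which a weakly enforced match permits, leave most prior draws uncovered; codes spread like the prior leave few holes. The radius, the 2-D latent space and the sample sizes are arbitrary assumptions.

```python
import math
import random


def hole_fraction(codes, prior_draws, radius):
    """Fraction of prior draws with no encoded code within `radius`.

    A rough proxy for 'holes': regions of the prior with no encoded support,
    where the decoder was never trained and prior samples may decode badly.
    """
    uncovered = 0
    for z in prior_draws:
        nearest = min(math.dist(z, c) for c in codes)
        if nearest > radius:
            uncovered += 1
    return uncovered / len(prior_draws)


rng = random.Random(0)


def gaussian_cloud(n, scale):
    """n points in 2-D drawn from N(0, scale^2 I)."""
    return [(rng.gauss(0.0, scale), rng.gauss(0.0, scale)) for _ in range(n)]


prior_draws = gaussian_cloud(300, 1.0)
spread_codes = gaussian_cloud(300, 1.0)     # q(z) roughly matches the prior
collapsed_codes = gaussian_cloud(300, 0.1)  # q(z) occupies a small blob near the origin

print(hole_fraction(spread_codes, prior_draws, radius=0.5))     # small
print(hole_fraction(collapsed_codes, prior_draws, radius=0.5))  # most prior draws uncovered
```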
The optimal-transport derivation
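The objective is not posited but derived. The transport cost between the data distribution $p_X$ and the model $p_G$ (the distribution of $G(z)$ for $z\sim p(z)$, i.e. the decoder applied to the prior) is an infimum over couplings with those two marginals,

$$W_c(p_X, p_G) = \inf_{\Gamma \in \mathcal{P}(X\sim p_X,\; Y\sim p_G)} \mathbb{E}_{(X,Y)\sim\Gamma}\big[c(X, Y)\big].$$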
Here $\Gamma$ is a joint “transport plan”: $\mathcal{P}(X\sim p_X, Y\sim p_G)$ is the set of all joint distributions whose two marginals are $p_X$ and $p_G$, $\mathbb{E}_{(X,Y)\sim\Gamma}[c(X,Y)]$ is the average cost of moving $X$ to $Y$ under that plan, and the infimum ranges over all plans to find the cheapest. When the decoder is a deterministic map, every $Y$ in the model is the image $G(z)$ of some latent $z$, so the coupling can be expressed through the latent variable rather than over the full product of $X$ and $Y$. The transport problem then factors through the latent space: minimizing over couplings reduces to minimizing reconstruction cost over conditional encoders $q(z\mid x)$, subject to the single marginal constraint that the aggregated posterior $q(z)$ equal the prior $p(z)$:

$$W_c(p_X, p_G) = \inf_{q(z\mid x):\; q(z) = p(z)} \mathbb{E}_{x\sim p_X}\,\mathbb{E}_{z\sim q(z\mid x)}\big[c\big(x, G(z)\big)\big].$$

Relaxing that hard constraint to the penalty $\lambda\,\mathcal{D}\big(q(z), p(z)\big)$ yields the WAE objective. The factorization itself is exact, not an approximation; the only relaxation is trading the constraint for a penalty.
It is worth spelling out what this step saves. Optimizing the coupling between pairs of images directly means searching the full product space, whose size explodes with the pixel count. The latent factorization compresses that search into “encoder + decoder + one latent-space constraint” — the first two being an ordinary autoencoder forward pass, the last a distribution match in a low-dimensional (latent) space. Pushing the cost from transport in image space down to matching in latent space is exactly why a WAE is trainable.
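Both the infimum over couplings and the blow-up that the factorization avoids show up on a toy problem. For two uniform empirical distributions with the same number $n$ of points, an optimal coupling can be taken to be a permutation (a consequence of the Birkhoff–von Neumann theorem), so a brute-force search over all $n!$ one-to-one matchings computes $W_c$ exactly. This Python sketch is an illustration under those assumptions, not something a WAE computes; the factorial count is the point.

```python
import itertools
import math


def ot_cost_uniform(xs, ys, cost):
    """Exact W_c between two uniform empirical distributions of equal size n.

    Each point carries mass 1/n. An optimal coupling can be taken to be a
    permutation, so minimizing over all n! one-to-one matchings gives the
    infimum over transport plans. Feasible only for tiny n.
    """
    n = len(xs)
    if n != len(ys):
        raise ValueError("xs and ys must have the same number of points")
    best = min(
        sum(cost(x, ys[j]) for x, j in zip(xs, perm))
        for perm in itertools.permutations(range(n))
    )
    return best / n


def squared_error(x, y):
    return (x - y) ** 2


xs = [0.0, 1.0, 3.0]
ys = [2.0, 0.5, 4.0]
print(ot_cost_uniform(xs, ys, squared_error))  # 0.75: matches 0->0.5, 1->2, 3->4
print(math.factorial(10), math.factorial(20))  # candidate plans for n = 10 and n = 20
```

Even twenty points already admit about $2.4\times 10^{18}$ matchings, and images are points in a space of thousands of dimensions; replacing this search with an autoencoder pass plus a latent-space match is what keeps training feasible.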
Deterministic versus stochastic encoders, and the source of sharpness
Because only the aggregate is constrained, the encoder may be deterministic (a point map) without violating the objective; a stochastic encoder is permitted but not required. The VAE, by contrast, depends on a stochastic per-sample posterior whose KL-to-prior term acts on every $x$ separately. That per-sample force inflates each posterior toward the prior and overlaps the codes of distinct inputs, and the decoder, asked to reconstruct from these overlapping codes, hedges by producing blurred averages. Matching only the aggregate removes that per-sample pressure: distinct inputs may keep separated, even non-overlapping codes so long as their aggregate $q(z)$ matches $p(z)$, leaving the decoder free to commit to sharp reconstructions. This is the mechanism behind the WAE's typically sharper samples.
A small picture nails down where the blur comes from. Take two very different inputs $x_1$ and $x_2$. The VAE's per-sample KL inflates both $q(z\mid x_1)$ and $q(z\mid x_2)$ toward the prior's center, and their high-probability regions begin to overlap; for a $z$ that lands in the overlap, the decoder is asked to produce both $x_1$ and $x_2$, and the squared-error-optimal compromise is their average, which looks like blur. The WAE applies no such per-sample force, so the codes of $x_1$ and $x_2$ can sit far apart without overlapping; each code maps to a single target, the decoder need not hedge, and the result is sharp.
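The claim that the average is the squared-error-optimal compromise can be checked directly. In this minimal sketch, the four-pixel “images” and the 50/50 split of posterior mass are illustrative assumptions: a single code $z$ must serve two inputs, and the output minimizing the expected squared error is the weighted mean of the targets, a flat gray that resembles neither input.

```python
def optimal_output(targets, weights):
    """Output y minimizing sum_i w_i * ||t_i - y||^2: the weighted mean of the targets."""
    total = sum(weights)
    dim = len(targets[0])
    return [sum(w * t[d] for t, w in zip(targets, weights)) / total for d in range(dim)]


def expected_sq_error(y, targets, weights):
    """Weighted squared error of output y against each candidate target."""
    return sum(
        w * sum((ti - yi) ** 2 for ti, yi in zip(t, y)) for t, w in zip(targets, weights)
    )


x1 = [1.0, 0.0, 1.0, 0.0]  # toy 4-pixel "image"
x2 = [0.0, 1.0, 0.0, 1.0]  # a very different toy "image"
weights = [0.5, 0.5]       # z is equally plausible as a code of x1 and of x2

blurred = optimal_output([x1, x2], weights)
print(blurred)                                        # [0.5, 0.5, 0.5, 0.5]
print(expected_sq_error(blurred, [x1, x2], weights))  # 1.0
print(expected_sq_error(x1, [x1, x2], weights))       # 2.0: committing to x1 costs more
```

Under squared error, committing to either sharp image doubles the expected cost, so a decoder trained on overlapping codes is pushed toward the gray average.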
Two estimators of $\mathcal{D}$ are standard. WAE-MMD uses a maximum mean discrepancy with a fixed kernel, a closed-form, sample-based penalty with no extra network. Given a kernel $k$, the MMD maps both clouds of samples into a reproducing-kernel Hilbert space and compares their mean embeddings (writing $q$, $p$ for $q(z)$, $p(z)$):

$$\mathrm{MMD}_k^2(q, p) = \mathbb{E}_{z,z'\sim p}\big[k(z,z')\big] + \mathbb{E}_{z,z'\sim q}\big[k(z,z')\big] - 2\,\mathbb{E}_{z\sim p,\;\tilde z\sim q}\big[k(z,\tilde z)\big],$$

each term estimated as a mean over minibatch sample pairs; for a characteristic kernel it vanishes exactly when the two distributions coincide. A common choice is the inverse-multiquadratic (IMQ) kernel $k(z,z') = C/(C + \lVert z - z'\rVert^2)$, whose heavier tails keep supplying gradient for outlier codes far from the prior samples, where a Gaussian kernel's response has already decayed to almost nothing. WAE-GAN trains a discriminator in latent space to estimate the divergence adversarially, which is more flexible but reintroduces minimax instability: when the discriminator lags the moving $q(z)$, the gradient handed to the encoder is biased.
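The WAE-MMD penalty is short enough to write out in full. The Python sketch below assumes a deterministic encoder, the squared-error cost, the IMQ kernel with a hand-picked constant $C$ (passed as `scale`), and the unbiased (U-statistic) form of the minibatch estimate; the function names are hypothetical. It shows the penalty separating a code cloud that matches the prior from one that is shifted away from it.

```python
import math
import random


def imq_kernel(a, b, scale):
    """Inverse-multiquadratic kernel k(a, b) = C / (C + ||a - b||^2), with C = scale."""
    return scale / (scale + math.dist(a, b) ** 2)


def mmd2_unbiased(qz, pz, scale):
    """Unbiased minibatch estimate of MMD_k^2 between codes qz ~ q(z) and prior draws pz ~ p(z).
    Within-sample terms skip the i == j pairs; the cross term uses all pairs."""
    n, m = len(qz), len(pz)
    k_qq = sum(imq_kernel(qz[i], qz[j], scale) for i in range(n) for j in range(n) if i != j)
    k_pp = sum(imq_kernel(pz[i], pz[j], scale) for i in range(m) for j in range(m) if i != j)
    k_qp = sum(imq_kernel(a, b, scale) for a in qz for b in pz)
    return k_qq / (n * (n - 1)) + k_pp / (m * (m - 1)) - 2.0 * k_qp / (n * m)


def wae_mmd_objective(batch, encode, decode, prior_draws, lam, scale):
    """Minibatch WAE-MMD objective: mean ||x - G(z)||^2 plus lam * MMD^2(q, p)."""
    codes = [encode(x) for x in batch]
    recon = sum(math.dist(x, decode(z)) ** 2 for x, z in zip(batch, codes)) / len(batch)
    return recon + lam * mmd2_unbiased(codes, prior_draws, scale)


rng = random.Random(0)
prior = [(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(200)]
matched = [(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(200)]
shifted = [(rng.gauss(3, 1), rng.gauss(3, 1)) for _ in range(200)]

print(round(mmd2_unbiased(matched, prior, scale=2.0), 4))  # near 0; can be slightly negative
print(round(mmd2_unbiased(shifted, prior, scale=2.0), 4))  # clearly positive
```

In a real model `encode` and `decode` would be networks and the objective would be minimized by gradient descent; the estimator itself stays the same. Choosing `scale` is the bandwidth question raised earlier.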
The adversarial variant is closely related to the adversarial autoencoder, which also matches an aggregated posterior to a prior through a discriminator. Both can be read as special cases of the broader transport-based matching that EVIA develops with entropic optimal transport.
Where this lands in Cryo-ET
Cryo-ET reconstruction has no ground-truth volume to imitate, only tomograms scarred by the missing wedge. The usable supervision is a prior on what real structures look like, and “matching a distribution to a prior” is exactly the problem the WAE solves in latent space. Swapping a discriminator for a closed-form, transport-based match is the same step CryoGEN-II takes beyond CryoGEN-I: CryoGEN-I uses a discriminator as a point-estimate restorer and inherits the minimax instability described above for WAE-GAN; CryoGEN-II follows the WAE/OT route, trading the moving referee for an optimal-transport objective that yields a single, stable answer. Carrying aggregate matching one step further, to the entropic transport of EVIA, underlies CryoWGEN, which no longer returns a single restored volume but expands it into a posterior family, exposing exactly which details the missing wedge leaves undetermined.