熵正则最优传输与 Sinkhorn

在最优传输代价上加入熵 / KL 罚项,得到唯一、处处为正的平滑耦合,并由 Sinkhorn 矩阵缩放高效求解。

Kantorovich 形式的最优传输是一个线性规划:在所有以 μ,ν\mu,\nu 为边缘的耦合里挑代价最小的那个。它的解往往「太硬」—— 质量被挤到极少数几对 (x,y)(x,y) 上,逼近一个确定性的 Monge 映射。熵正则做的事只有一句话:在原代价上再加一项相对熵罚,惩罚这种塌缩,逼着方案把质量摊开。换来的是三样东西 —— 唯一解、处处可微、以及一个能用矩阵缩放在几十步内算完的高效解法(Sinkhorn)。本页要讲清这三样从哪来,以及为什么这正是 CryoWGEN 训练时所采样的后验。

熵正则最优传输在 Kantorovich 传输代价上附加一个相对熵罚项,将原本的线性规划改写为强凸问题。罚项以参考耦合 κ\kappa 为基准、以正则强度 γ(可理解为「温度」:越大耦合越被抹平)为权重,在传输代价与耦合的弥散程度之间取得平衡,所得目标为

minπΠ(μ,ν)cdπ  +  γKL(πκ).\min_{\pi\in\Pi(\mu,\nu)} \int c\,d\pi \;+\; \gamma\,\mathrm{KL}(\pi\Vert\kappa).

逐项读:π\pi 是待求的耦合(运输时刻表),Π(μ,ν)\Pi(\mu,\nu) 是所有以 μ,ν\mu,\nu 为行/列边缘的耦合集合,cc 是单位传输代价(常取 xy2\lVert x-y\rVert^2),cdπ\int c\,d\pi 是该方案的总代价;KL(πκ)\mathrm{KL}(\pi\Vert\kappa) 度量 π\pi 偏离参考耦合 κ\kappa 的程度,γ>0\gamma>0 是它的权重。γ\gamma 越大,第二项越重,求解器越愿意为了「弥散」而牺牲「省代价」。

行边缘 μ列边缘 ν耦合 πᵢⱼ

每个格子按玻尔兹曼权重 exp(−cᵢⱼ/γ) 着色:γ 越小,质量越集中在代价最低的对角带,逼近硬性指派;γ 越大,权重被熵抹平,耦合弥散趋向独立乘积 μ⊗ν。

这里 π\pi 始终表示耦合、γ\gamma 表示温度(与 optimal-transport 全站约定一致)。参考耦合最常取独立乘积测度 κ=μν\kappa=\mu\otimes\nu,此时 KL(πμν)\mathrm{KL}(\pi\Vert\mu\otimes\nu) 衡量耦合 π\pi 偏离独立的程度(见 熵与 KL 散度),它只是一般式 γKL(πκ)\gamma\,\mathrm{KL}(\pi\Vert\kappa)κ=μν\kappa=\mu\otimes\nu 下的特例。强凸性保证最优解唯一,且具有显式的吉布斯–玻尔兹曼形式

π(x,y)    κ(x,y)ec(x,y)/γ   κ=μν   μ(x)ν(y)ec(x,y)/γ,\pi^\star(x,y)\;\propto\;\kappa(x,y)\,e^{-c(x,y)/\gamma} \;\xrightarrow{\ \kappa=\mu\otimes\nu\ }\; \mu(x)\,\nu(y)\,e^{-c(x,y)/\gamma},

即在参考耦合上以代价为势能、温度为 γ\gamma 的玻尔兹曼重加权。式中 ec(x,y)/γe^{-c(x,y)/\gamma} 把「代价低」翻译成「权重高」:代价每升高 γ\gamma,权重就掉一个 ee 倍,所以 γ\gamma 直接定下了这条衰减有多陡。

未正则的最优传输倾向于把质量集中在稀疏、近乎确定性的支撑上(极端情形为 Monge 映射)。KL 罚项惩罚这种塌缩:任意一对 (x,y)(x,y) 都被赋予严格为正的传输质量,于是最优方案处处弥散、平滑,而非锐利的指派。

熵项到底做了什么

熵罚远不止是一个计算技巧 —— 它在三个层面上改变了答案,值得分开来看。

它让问题良定、且求解廉价。 未正则的最优传输是个线性规划,解可能落在可行域多胞形的某个脆弱顶点上,求解又要 O(n3)O(n^3)。加上熵,目标变得严格凸:最优解因而唯一,并获得对角缩放结构 π=diag(a)Kdiag(b)\pi^\star=\mathrm{diag}(a)\,K\,\mathrm{diag}(b) —— 正是下文 Sinkhorn 迭代所利用的 —— 从而把线性规划压成几次矩阵–向量乘法(Cuturi 2013)。

它换来平滑,也就换来了能用的梯度。 既然没有哪一对会被赋予恰好为零的质量,传输方案便随代价或边缘的改变而平滑变化,对它们也可微。硬性指派是分段常数的 —— 梯度几乎处处为零、又在跳变处无定义 —— 根本没法拿来训练;正是熵正则后的方案,才让最优传输能当作网络里的损失。具体到 Cryo-ET:编码器一调参数、生成的重构 xx 就动一点,退化后的 TM(x)\mathcal{T}_M(x) 与代价 cc 也随之微动;只有当传输方案对这点微动有非零、连续的响应时,梯度才能一路传回编码器。

它还有实打实的物理含义 —— 薛定谔桥。 熵正则最优传输刻画的,恰是一团做扩散(布朗)运动的粒子从 μ\mu 演化到 ν\nu最可能路径:代价对应动能耗费,熵项则是这场扩散的热噪声,γ\gamma 就是它的温度。γ\gamma 之所以是「温度」而非比喻,正源于此。

直觉

γ\gamma 读作「传输被迫保留多少随机性」。γ0\gamma\to0 时噪声消失,退回锐利、确定的(Monge)传输 —— 答案只有一个;γ\gamma 一大,耦合便开始「对冲」,把质量摊到许多对上;推到极限,它索性忘掉代价,松弛为独立乘积 μν\mu\otimes\nu(熵最大,传输结构荡然无存)。熵项的全部价值,就在于让你能在「单一笃定的答案」与「弥散对冲的答案」之间拨档 —— 而下文一旦固定某张观测 yy,这正是「单一 MAP 重构」与「宽后验」之间的那个旋钮。

Sinkhorn 算法

求解依赖最优耦合的特殊结构。π\pi^\star 必可写为对角缩放 π=diag(a)Kdiag(b)\pi^\star=\mathrm{diag}(a)\,K\,\mathrm{diag}(b),其中核 Kij=ecij/γK_{ij}=e^{-c_{ij}/\gamma}Sinkhorn 算法交替更新两个缩放向量 a,ba,b,使行边缘与列边缘分别归为 μ\muν\nu。给定代价矩阵 CC、边缘 μ,ν\mu,\nu 与温度 γ\gamma

  1. 构造核 Kij=eCij/γK_{ij}=e^{-C_{ij}/\gamma},初始化 b1b\leftarrow\mathbf{1}
  2. 更新行缩放:aμ(Kb)a\leftarrow \mu \oslash (Kb)\oslash 为逐元素除)。
  3. 更新列缩放:bν(Ka)b\leftarrow \nu \oslash (K^{\top}a)
  4. 重复 2–3 直至边缘误差收敛。
  5. 返回 π=diag(a)Kdiag(b)\pi^\star=\mathrm{diag}(a)\,K\,\mathrm{diag}(b)

为什么这样两步交替就对?第 2 步把当前的行和强行掰到 μ\mu,第 3 步把列和强行掰到 ν\nu —— 各自单看都是「把一边的边缘一次性归位」的最省 KL 投影。两步轮流做,每次只破坏对方一点点,误差按几何速率收缩,于是迭代线性收敛到同时满足两个边缘的那个唯一耦合。每步仅需一次矩阵–向量乘法。由于全程为光滑可微运算,Sinkhorn 的输出可作为可微损失嵌入端到端训练。这里的缩放 a,ba,b 正是对偶势函数的指数 f=γlogaf=\gamma\log ag=γlogbg=\gamma\log b —— Sinkhorn 即是对光滑熵对偶做坐标上升。

深入

跑一遍最小的数。取两个 2 点分布 μ=ν=(12,12)\mu=\nu=(\tfrac12,\tfrac12),代价矩阵 C=\begin{psmallmatrix}0&1\\1&0\end{psmallmatrix}(对角是「同点搬运」零代价,非对角是「换位」单位代价),温度 γ=1\gamma=1。先构核 Kij=eCijK_{ij}=e^{-C_{ij}}K=\begin{psmallmatrix}1&e^{-1}\\e^{-1}&1\end{psmallmatrix}\approx\begin{psmallmatrix}1&0.368\\0.368&1\end{psmallmatrix}。这里对称性已经替我们把答案猜了出来 —— 行/列对称使 a=b=常数a=b=\text{常数},唯一满足边缘的解是 \pi^\star\approx\begin{psmallmatrix}0.365&0.135\\0.135&0.365\end{psmallmatrix}:大部分质量留在零代价的对角,少部分「漏」到换位上,每行每列恰好加到 0.50.5。换个温度看趋势:γ0\gamma\to0KK 的非对角项 e1/γ0e^{-1/\gamma}\to0,方案塌成纯对角 \begin{psmallmatrix}0.5&0\\0&0.5\end{psmallmatrix}(硬传输,质量全走零代价);γ\gamma\to\inftye1/γ1e^{-1/\gamma}\to1KK 趋于全 1,方案摊成 \begin{psmallmatrix}0.25&0.25\\0.25&0.25\end{psmallmatrix}=\mu\otimes\nu(完全忘掉代价)。γ=1\gamma=1 落在两者之间,正是熵罚「拨档」的可视化。

直觉

γ\gamma 扮演温度的角色。γ\gamma 越大,传输方案越被熵”抹平”,趋向参考耦合 κ\kappa(取 μν\mu\otimes\nu 时即独立乘积);γ0\gamma\to0 时玻尔兹曼权重退化为对最小代价的硬性选择,恢复未正则(硬)最优传输,但 Sinkhorn 迭代也随之变得病态:核里的 ec/γe^{-c/\gamma} 在小 γ\gamma 下要么爆下溢、要么数值悬殊,实践中因此改在对数域跑(log-domain Sinkhorn),用 log-sum-exp 代替直接相乘。

从熵正则代价到 Sinkhorn 散度

上面的熵正则代价虽方便,却是有偏的OTγ(μ,μ)0\mathrm{OT}_\gamma(\mu,\mu)\neq 0,于是最小化它的生成器会”摊得不够开”——γ\gamma 越大,熵代价把距离压向一个被抹平的度量。换句话说,连「一个分布到它自己」的代价都不为零,拿它当损失,最优点就不在「生成分布 = 数据分布」处,而被熵项往「更集中」的方向拽偏。修正办法是 Sinkhorn 散度,它减去两个自配项:

Sγ(μ,ν)=OTγ(μ,ν)12OTγ(μ,μ)12OTγ(ν,ν).\mathrm{S}_\gamma(\mu,\nu)=\mathrm{OT}_\gamma(\mu,\nu)-\tfrac12\,\mathrm{OT}_\gamma(\mu,\mu)-\tfrac12\,\mathrm{OT}_\gamma(\nu,\nu).

两个 12OTγ(,)-\tfrac12\,\mathrm{OT}_\gamma(\cdot,\cdot) 自配项正是把「分布到自己」那份残留代价扣掉,使总量在 μ=ν\mu=\nu 处归零。它满足 Sγ(μ,μ)=0\mathrm{S}_\gamma(\mu,\mu)=0Sγ0\mathrm{S}_\gamma\ge 0,并在最优传输(γ0\gamma\to0)与最大均值差异 MMD(γ\gamma\to\infty)之间插值(Genevay 等 2018;Feydy 等 2019)。真正用作生成模型损失的,正是这个去偏的 Sγ\mathrm{S}_\gamma,而非原始熵代价(即 Sinkhorn 自编码器的设定,Patrini 等 2020);而它在大 γ\gamma 下的 MMD 极限,恰是 Wasserstein 自编码器用来让聚合后验匹配先验的那个 MMD。

这正是 EVIA / CryoWGEN 的后验

上述玻尔兹曼耦合 π(y,x)κ(y,x)ec(y,TM(x))/γ\pi^\star(y,x)\propto\kappa(y,x)\,e^{-c(y,\mathcal{T}_M(x))/\gamma} 不只是一个求解技巧 —— 它就是 EVIACryoWGEN 在训练中采样的后验。固定一张观测 yy、读取条件切片 π(y)\pi^\star(\cdot\mid y),便得到该观测的逐图后验

q(xy)    κ(xy)ec(y,TM(x))/γ,q^\star(x\mid y)\;\propto\;\kappa(x\mid y)\,e^{-c(y,\mathcal{T}_M(x))/\gamma},

即在数据一致性参考耦合 κ\kappa 上、以退化失配 c(y,TM(x))=yTM(x)2c(y,\mathcal{T}_M(x))=\lVert y-\mathcal{T}_M(x)\rVert^2 为能量、以 γ\gamma 为温度做玻尔兹曼重加权。逐项读:xx 是候选的干净重构,TM(x)\mathcal{T}_M(x) 是把它用缺失楔形退化算子 TM\mathcal{T}_M 打回观测空间的结果,yTM(x)2\lVert y-\mathcal{T}_M(x)\rVert^2 衡量它退化后与实测 yy 差多少;xx 越能解释 yy,能量越低,后验给它的权重越高。这正是 CryoWGEN 在 E-step 所采样的分布 —— CryoWGEN-I 用 Monte-Carlo 重加权采样它,CryoWGEN-II 用 Langevin / SGLD 迭代采样它;编码器则学习其条件均值 E[q(xy)]\mathbb{E}[q(x\mid y)],等价于一个 Entropy-SGD 式的平滑点估计。

温度 γ\gamma 在此含义清晰。γ0\gamma\to0 时后验塌缩到 argmin\arg\min 能量处的单点,退回 CryoGEN-II / MAP 的硬传输 —— 每个观测只得一个确定性重构;γ>0\gamma>0 则保留一族重构,正好刻画缺失楔形带来的不确定性:同一张被破坏的 yy 本应对应许多个都说得通的干净体 xx。这把全站的四方法分类接到了同一根旋钮上 —— MAP 点估计(CryoGEN-I)、WAE/OT 稳定单解(CryoGEN-II)、EVIA Monte-Carlo(CryoWGEN-I)、EVIA Langevin 后验族(CryoWGEN-II)—— 它们的差别,很大程度上就是这个 γ\gamma(以及如何采样玻尔兹曼后验)的差别。

深入

为何最优条件分布恰取玻尔兹曼形式?固定观测 yy,把熵正则目标的条件部分写为对 q(y)q(\cdot\mid y) 的泛函

minq  {Eq[c(y,TM(x))]  +  γKL(qκ(y))}.\min_{q}\;\Big\{\mathbb{E}_{q}\big[c(y,\mathcal{T}_M(x))\big]\;+\;\gamma\,\mathrm{KL}\big(q\Vert\kappa(\cdot\mid y)\big)\Big\}.

对归一化约束 qdx=1\int q\,dx=1 引入拉格朗日乘子 λ\lambda,对 q(x)q(x) 求变分并令其为零:

c(y,TM(x))+γ(logq(x)κ(xy)+1)+λ=0    q(xy)κ(xy)ec(y,TM(x))/γ.c(y,\mathcal{T}_M(x))+\gamma\big(\log\tfrac{q(x)}{\kappa(x\mid y)}+1\big)+\lambda=0 \;\Longrightarrow\; q^\star(x\mid y)\propto\kappa(x\mid y)\,e^{-c(y,\mathcal{T}_M(x))/\gamma}.

这正是吉布斯变分原理(Donsker–Varadhan):在 KL 邻域约束下最小化期望代价,最优解必为对参考测度的玻尔兹曼重加权。γ\gamma 越小,权重越尖锐,后验越接近对代价的硬性 argmin —— 即 MAP。反过来,这也解释了为什么编码器学条件均值就够:玻尔兹曼后验在小 γ\gamma 下近似高斯钟形,其均值与众数几乎重合,于是「学均值」与「找 MAP」在这一极限下收敛到同一答案。

熵正则最优传输是 EVIAWasserstein 自编码器 中传输目标的基础,亦为 CryoWGEN 所用的分布匹配提供可微近似。把这个 Sinkhorn 目标套进编码器–解码器,就得到 Sinkhorn 自编码器(Patrini 等, 2020)—— 也正是 EVIA 的直接前身。

延伸阅读

← 最优传输