title: classifier_guidance
date: 2025-06-02 20:06:57
tags:

由于基于 diffusion 的图像生成模型多样性过高(不可控,而且容易出现很多真实性很差的图片,FID 指标太高),可能的一种解决办法就是 conditional generation,即将 $$p(x)\rightarrow p(x|y)$$,让一个 DDPM 变为一个 conditional DDPM。

条件生成模型是只指在条件 y\boldsymbol y(condition)下,根据条件生成内容 x\boldsymbol x,概率表示为 p(xy)p(\boldsymbol x|\boldsymbol y)

Classifier Guidance

注意我们对 guidance 的要求是:

  • 不影响训练过程,即我们只在采样过程做文章
  • 尝试用一种高效的方法进行指导采样,而不是最后接一个简单的分类器,让模型生成很多图片最终选一个最好的图片

我们用 q^\hat{q} 表述加入条件 y\boldsymbol y 后的(逆向)采样过程。我们的目标是修改无条件的逆向采样 q(xt1xt)q(\boldsymbol x_{t-1}|\boldsymbol x_t) 得到条件采样 q^(xt1xt,y)\hat{q}(\boldsymbol x_{t-1}|\boldsymbol x_t, \boldsymbol y)。利用贝叶斯定理来引入条件 y\boldsymbol y

q^(xt1xt,y)=q^(xt1xt)q^(yxt1,xt)q^(yxt)\hat{q}(\boldsymbol x_{t-1}|\boldsymbol x_t, \boldsymbol y) = \frac{\hat{q}(\boldsymbol x_{t-1}|\boldsymbol x_t)\hat{q} (\boldsymbol y| \boldsymbol x_{t-1}, \boldsymbol x_t)}{\hat{q}(\boldsymbol y|\boldsymbol x_t)}

这个公式的右边包含三个概率项:

  1. q^(xt1xt)\hat{q}(\boldsymbol x_{t-1}|\boldsymbol x_t): 在给定 xt\boldsymbol x_t 的情况下,xt1\boldsymbol x_{t-1} 的(条件化后的)先验分布。我们希望它与原始的 q(xt1xt)q(\boldsymbol x_{t-1}|\boldsymbol x_t) 相关。
  2. q^(yxt1,xt)\hat{q} (\boldsymbol y| \boldsymbol x_{t-1}, \boldsymbol x_t): 给定 xt\boldsymbol x_txt1\boldsymbol x_{t-1} 后,条件 y\boldsymbol y 的似然。这部分将由分类器提供指导。
  3. q^(yxt)\hat{q}(\boldsymbol y|\boldsymbol x_t): 证据项 (evidence),作为归一化常数。

看上去这些新的 q^\hat{q} 分布我们都不知道,无法进行下一步求解。但是注意我们需要的是一个 training-free 的模型,即不改变原始 DDPM 的训练。为此,我们做出如下规定和假设:

核心假设 (Training-Free):

  1. 条件化的前向过程不变: q^(xtxt1,y)=q(xtxt1)\hat{q}(\boldsymbol x_{t}|\boldsymbol x_{t-1}, \boldsymbol y) = q(\boldsymbol x_{t}|\boldsymbol x_{t-1})。这意味着给定 xt1\boldsymbol x_{t-1}xt\boldsymbol x_t 的产生(加噪)过程不直接依赖于外部条件 y\boldsymbol y。这是保持原始扩散模型不变的关键。
  2. 初始分布不变: q^(x0)=q(x0)\hat{q}(\boldsymbol x_0) = q(\boldsymbol x_0)
  3. Markov property:整个条件化的前向轨迹满足: q^(x1:Tx0,y)=t=1Tq^(xtxt1,y)=t=1Tq(xtxt1)=q(x1:Tx0)\hat{q}(\boldsymbol x_{1:T}|\boldsymbol x_0, \boldsymbol y) = \prod_{t=1}^T \hat{q}(\boldsymbol x_t|\boldsymbol x_{t-1}, \boldsymbol y) = \prod_{t=1}^T q(\boldsymbol x_t|\boldsymbol x_{t-1}) = q(\boldsymbol x_{1:T}|\boldsymbol x_0)

接下来,我们基于这些假设推导上面三个概率的具体形式

q^(xtxt1)\hat{q}(\boldsymbol x_t|\boldsymbol x_{t-1}) 求解

这一步的目标是证明,即使我们考虑了条件 y\boldsymbol y 的存在,当我们对 y\boldsymbol y 进行积分(边缘化)后,从 xt1\boldsymbol x_{t-1}xt\boldsymbol x_t 的单步前向转移概率与原始的无条件转移概率 q(xtxt1)q(\boldsymbol x_t|\boldsymbol x_{t-1}) 是一致的。这有助于简化后续分析,表明模型的前向“骨架”在整体上未变:

q^(xtxt1)=yq^(xt,yxt1)dy=yq^(xtxt1,y)q^(yxt1)dy=yq(xtxt1)q^(yxt1)dy=q(xtxt1)yq^(yxt1)dy=q(xtxt1)\begin{aligned} \hat{q}(\boldsymbol x_t|\boldsymbol x_{t-1}) &= \int_\boldsymbol y \hat{q}(\boldsymbol x_t, \boldsymbol y|\boldsymbol x_{t-1}) d\boldsymbol y \\ &= \int_\boldsymbol y \hat{q}(\boldsymbol x_t|\boldsymbol x_{t-1}, \boldsymbol y) \hat{q}(\boldsymbol y|\boldsymbol x_{t-1}) d\boldsymbol y \\ &= \int_\boldsymbol y q(\boldsymbol x_t|\boldsymbol x_{t-1}) \hat{q}(\boldsymbol y|\boldsymbol x_{t-1}) d\boldsymbol y \\ &= q(\boldsymbol x_t|\boldsymbol x_{t-1}) \int_\boldsymbol y \hat{q}(\boldsymbol y|\boldsymbol x_{t-1}) d\boldsymbol y \\ &= q(\boldsymbol x_t|\boldsymbol x_{t-1}) \end{aligned}

这表明,在边缘化条件 y\boldsymbol y 后,条件模型的前向转移概率与原始模型的相同

q^(xt)\hat{q}(\boldsymbol x_t) 求解

这一步的目标是证明在所有时间步和初始状态上进行积分后,任意时刻 tt 的边缘分布 q^(xt)\hat{q}(\boldsymbol x_t) 与原始的 q(xt)q(\boldsymbol x_t) 相同。这意味着尽管我们引入了条件生成的机制,但从宏观上看,数据在各个噪声水平的(无条件)分布特性保持不变

q^(xt)=x0:t1q^(x0:t)dx0:t1=x0:t1q^(x0)q^(x1:tx0)dx0:t1\begin{aligned} \hat{q}(\boldsymbol x_t) &= \int_{\boldsymbol x_{0:t-1}} \hat{q}(\boldsymbol{x}_{0:t})d\boldsymbol{x}_{0:t-1} \\ &= \int_{\boldsymbol x_{0:t-1}} \hat{q}(\boldsymbol x_0)\hat{q}(\boldsymbol{x}_{1:t}|\boldsymbol x_0)d\boldsymbol{x}_{0:t-1} \end{aligned}

为此需要求解 q^(x1:tx0)\hat{q}(\boldsymbol{x}_{1:t}|\boldsymbol x_0)

q^(x1:tx0)=yq^(x1:t,yx0)dy=yq^(yx0)q^(x1:tx0,y)dy=yq^(yx0)i=1tq^(xixi1,y)dy(Markov property)=yq^(yx0)i=1tq(xixi1)dy=(i=1tq(xixi1))yq^(yx0)dy=q(x1:tx0)(yq^(yx0)dy=1)\begin{aligned} \hat{q}(\boldsymbol{x}_{1:t}|\boldsymbol x_0)&= \int_\boldsymbol y \hat{q}(\boldsymbol x_{1:t}, \boldsymbol y|\boldsymbol x_0) d\boldsymbol y= \int_\boldsymbol y \hat{q}(\boldsymbol y|\boldsymbol x_0)\hat{q}(\boldsymbol x_{1:t}|\boldsymbol x_0, \boldsymbol y) d\boldsymbol y\\ &= \int_\boldsymbol y \hat{q}(\boldsymbol y|\boldsymbol x_0)\prod_{i=1}^t \hat{q}(\boldsymbol x_i|\boldsymbol x_{i-1}, \boldsymbol y) d\boldsymbol y \quad (\text{Markov property}) \\ &= \int_\boldsymbol y \hat{q}(\boldsymbol y|\boldsymbol x_0)\prod_{i=1}^t q(\boldsymbol x_i|\boldsymbol x_{i-1}) d\boldsymbol y \\ &= \left(\prod_{i=1}^t q(\boldsymbol x_i|\boldsymbol x_{i-1})\right) \int_\boldsymbol y \hat{q}(\boldsymbol y|\boldsymbol x_0) d\boldsymbol y \\ &= q(\boldsymbol x_{1:t}|\boldsymbol x_0) \quad (\int_\boldsymbol y \hat{q}(\boldsymbol y|\boldsymbol x_0) d\boldsymbol y = 1) \end{aligned}

这证明了 q^(x1:tx0)=q(x1:tx0)\hat{q}(\boldsymbol{x}_{1:t}|\boldsymbol x_0)=q(\boldsymbol x_{1:t}|\boldsymbol x_0),那么接下来:

q^(xt)=x0:t1q^(x0)q^(x1:tx0)dx0:t1=x0:t1q(x0)q(x1:tx0)dx0:t1=x0:t1q(x0:t)dx0:t1=q(xt)\begin{aligned} \hat{q}(\boldsymbol x_t) &= \int_{\boldsymbol x_{0:t-1}} \hat{q}(\boldsymbol x_0)\hat{q}(\boldsymbol{x}_{1:t}|\boldsymbol x_0)d\boldsymbol{x}_{0:t-1} \\ &= \int_{\boldsymbol x_{0:t-1}}q(\boldsymbol x_0)q(\boldsymbol{x}_{1:t}|\boldsymbol x_0)d\boldsymbol{x}_{0:t-1} \\ &= \int_{\boldsymbol x_{0:t-1}} q(\boldsymbol x_{0:t})d\boldsymbol{x}_{0:t-1} = q(\boldsymbol x_t) \end{aligned}

这表明任意时刻 tt 的边缘分布 q^(xt)\hat{q}(\boldsymbol x_t) 与原始的 q(xt)q(\boldsymbol x_t) 相同。这再次确认了我们的 training-free 假设下,模型的基础特性得以保持。

q^(yxt1,xt)\hat{q} (\boldsymbol y| \boldsymbol x_{t-1}, \boldsymbol x_t) 求解

这是贝叶斯公式中的关键似然项之一。我们希望简化它,理想情况下将其与一个易于建模的分类器联系起来

q^(yxt1,xt)=q^(xtxt1,y)q^(yxt1)q^(xtxt1)=q(xtxt1)q^(yxt1)q(xtxt1)=q^(yxt1)\begin{aligned} \hat{q} (\boldsymbol y| \boldsymbol x_{t-1}, \boldsymbol x_t) &= \frac{\hat{q} (\boldsymbol x_t| \boldsymbol x_{t-1}, \boldsymbol y)\hat{q} (\boldsymbol y| \boldsymbol x_{t-1})}{\hat{q} (\boldsymbol x_t| \boldsymbol x_{t-1})} \\ &= \frac{q (\boldsymbol x_t| \boldsymbol x_{t-1})\hat{q} (\boldsymbol y| \boldsymbol x_{t-1})}{q (\boldsymbol x_t| \boldsymbol x_{t-1})} = \hat{q} (\boldsymbol y| \boldsymbol x_{t-1}) \end{aligned}

这个推导表明 q^(yxt1,xt)=q^(yxt1)\hat{q} (\boldsymbol y| \boldsymbol x_{t-1}, \boldsymbol x_t) = \hat{q} (\boldsymbol y| \boldsymbol x_{t-1})。这意味着在给定 xt1\boldsymbol x_{t-1} 的情况下,xt\boldsymbol x_t 对于确定 y\boldsymbol y 不提供额外信息(即 yxt1xty \leftrightarrow \boldsymbol x_{t-1} \leftrightarrow \boldsymbol x_t 形成马尔可夫链)。这一项 q^(yxt1)\hat{q} (\boldsymbol y| \boldsymbol x_{t-1}) 将由一个分类器 pϕ(yxt1)p_\phi(\boldsymbol y|\boldsymbol x_{t-1}) 来建模。

则最终的优化(采样)目标:

现在我们将所有简化的部分代回到最初的贝叶斯公式

q^(xt1xt,y)=q^(xt1xt)q^(yxt1,xt)q^(yxt)\hat{q}(\boldsymbol x_{t-1}|\boldsymbol x_t, \boldsymbol y) = \frac{\hat{q}(\boldsymbol x_{t-1}|\boldsymbol x_t)\hat{q} (\boldsymbol y| \boldsymbol x_{t-1}, \boldsymbol x_t)}{\hat{q}(\boldsymbol y|\boldsymbol x_t)}

根据我们的推导和假设:

  • q^(xt1xt)q(xt1xt)\hat{q}(\boldsymbol x_{t-1}|\boldsymbol x_t) \approx q(\boldsymbol x_{t-1}|\boldsymbol x_t): 我们近似认为条件化的先验与原始DDPM的逆向步骤 q(xt1xt)q(\boldsymbol x_{t-1}|\boldsymbol x_t)(由 qθ(xt1xt)q_\theta(\boldsymbol x_{t-1}|\boldsymbol x_t) 建模)相同。这是指导的核心所在,即在原始逆向步骤基础上进行调整。
  • q^(yxt1,xt)=q^(yxt1)\hat{q} (\boldsymbol y| \boldsymbol x_{t-1}, \boldsymbol x_t) = \hat{q} (\boldsymbol y| \boldsymbol x_{t-1}): 由分类器 qϕ(yxt1)q_\phi(\boldsymbol y|\boldsymbol x_{t-1}) 给出。
  • 分母 q^(yxt)\hat{q}(\boldsymbol y|\boldsymbol x_t) 作为归一化常数 1/C1/C

因此,采样目标可以写成:

q^(xt1xt,y)q(xt1xt)×qϕ(yxt1)\hat{q}(\boldsymbol x_{t-1}|\boldsymbol x_t, \boldsymbol y) \propto q(\boldsymbol x_{t-1}|\boldsymbol x_t)\times q_\phi(\boldsymbol y|\boldsymbol x_{t-1})

这意味着条件逆向采样正比于原始 DDPM 的逆向采样 q(xt1xt)q(\boldsymbol x_{t-1}|\boldsymbol x_t) 和分类器在 xt1\boldsymbol x_{t-1} 处预测条件 y\boldsymbol y 的概率 qϕ(yxt1)q_\phi(\boldsymbol y|\boldsymbol x_{t-1}) 的乘积。其中 q(xt1xt)q(\boldsymbol x_{t-1}|\boldsymbol x_t) 就是 DDPM 部分的逆向采样, qϕ(yxt1)q_\phi(\boldsymbol y|\boldsymbol x_{t-1}) 部分就是一个分类器。因此如果需要模型预测两个部分,一个模型就是 DDPM,预测 qθ(xt1xt)q_\theta(\boldsymbol x_{t-1}|\boldsymbol x_t) (通过预测噪声 ϵθ(xt,t)\boldsymbol\epsilon_\theta(\boldsymbol x_t, t)),另一个模型需要一个分类器,预测 qϕ(yxt)q_\phi(\boldsymbol y|\boldsymbol x_t)

Score-based Guidance Trick (分类器指导技巧)

前面我们推导出理想的条件逆向采样满足(或者说,我们希望构造的条件逆向采样过程):

q^(xt1xt,y)q(xt1xt)×qϕ(yxt1)\hat{q}(\boldsymbol x_{t-1}|\boldsymbol x_t, \boldsymbol y) \propto q(\boldsymbol x_{t-1}|\boldsymbol x_t)\times q_\phi(\boldsymbol y|\boldsymbol x_{t-1})

这里的 q(xt1xt)q(\boldsymbol x_{t-1}|\boldsymbol x_t) 是原始 DDPM 的逆向(去噪)步骤,而 qϕ(yxt1)q_\phi(\boldsymbol y|\boldsymbol x_{t-1}) 是一个分类器在给定 xt1\boldsymbol x_{t-1} 时预测条件 y\boldsymbol y 的概率。

然而,这个公式直接应用存在一个核心问题:在 tt 时刻计算(采样) xt1\boldsymbol x_{t-1} 时,我们需要 qϕ(yxt1)q_\phi(\boldsymbol y|\boldsymbol x_{t-1}) 的值,但 xt1\boldsymbol x_{t-1} 本身是未知的,是我们正要采样的目标。这就构成了一个循环依赖,在实际中我们转而依赖一个在 xt\boldsymbol x_t 上操作的分类器。

​ 为了解决这个问题并实现有效的指导,Classifier Guidance 采用了以下关键技巧:

  1. 分类器作用于当前状态 xt\boldsymbol x_t:我们不再试图使用依赖于未来(待采样)状态 xt1\boldsymbol x_{t-1} 的分类器 qϕ(yxt1)q_\phi(\boldsymbol y|\boldsymbol x_{t-1})。取而代之,我们训练或使用一个分类器 qϕ(yxt,t)q_\phi(\boldsymbol y|\boldsymbol x_t, t),该分类器根据当前的噪声图像 xt\boldsymbol x_t 和可选的时间步 tt 来预测条件 y\boldsymbol y 的概率。这有效地打破了循环依赖,因为在采样 xt1\boldsymbol x_{t-1} 的时刻,xt\boldsymbol x_t 是已知的。

  2. 通过梯度引导均值 (Mean Perturbation via Gradient):原始 DDPM 的逆向采样步骤 pθ(xt1xt)p_\theta(\boldsymbol x_{t-1}|\boldsymbol x_t) (模型对 q(xt1xt)q(\boldsymbol x_{t-1}|\boldsymbol x_t) 的近似) 通常被建模为一个高斯分布 N(xt1;μθ(xt,t),σt2I)\mathcal{N}(\boldsymbol x_{t-1}; \boldsymbol\mu_\theta(\boldsymbol x_t, t), \sigma_t^2 \boldsymbol{I})。其均值 μθ(xt,t)\boldsymbol\mu_\theta(\boldsymbol x_t, t) 由 DDPM 的噪声预测网络 ϵθ(xt,t)\boldsymbol\epsilon_\theta(\boldsymbol x_t, t) 决定:

    μθ(xt,t)=1αt(xt1αt1αˉtϵθ(xt,t))\boldsymbol\mu_\theta(\boldsymbol x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(\boldsymbol x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\boldsymbol\epsilon_\theta(\boldsymbol x_t, t)\right)

    方差 σt2\sigma_t^2 通常是预定义的,例如 σt2=β~t=1αˉt11αˉtβt\sigma_t^2 = \tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t 或者 σt2=βt\sigma_t^2 = \beta_t

Classifier Guidance 的核心思想是利用分类器 qϕ(yxt,t)q_\phi(\boldsymbol y|\boldsymbol x_t, t) 的对数概率的梯度来调整(或“扰动”)这个均值 μθ(xt,t)\boldsymbol\mu_\theta(\boldsymbol x_t, t),从而将采样过程“引导”向更可能满足条件 y\boldsymbol yxt1\boldsymbol x_{t-1}

​ 我们可以考虑以下推导。这对应了后验分布 $ q(\boldsymbol x_{t-1}|\boldsymbol x_t, \boldsymbol y) $ 的一种建模思路。通过贝叶斯定理(并关注对数概率):

logq(xt1xt,y)=logpθ(xtxt1)+logpθ(yxt)logZ(xt,y)\log q(\boldsymbol x_{t-1}|\boldsymbol x_t, \boldsymbol y) = \log p_\theta(\boldsymbol x_t|\boldsymbol x_{t-1}) + \log p_\theta(\boldsymbol y|\boldsymbol x_t) - \log Z(\boldsymbol x_t, \boldsymbol y)

其中 Z(xt,y)Z(\boldsymbol x_t, \boldsymbol y) 是归一化常数。我们主要关注前两项的近似展开。我们知道 $ p_\theta(\boldsymbol x_t|\boldsymbol x_{t-1}) $ (给定 xt1\boldsymbol x_{t-1}xt\boldsymbol x_t 的分布) 是高斯分布:

pθ(xtxt1)=N(xt;μ(xt1),Σ(xt1))p_\theta(\boldsymbol x_t|\boldsymbol x_{t-1}) = \mathcal{N}(\boldsymbol x_t; \boldsymbol\mu(\boldsymbol x_{t-1}), \boldsymbol\Sigma(\boldsymbol x_{t-1}))

忽略常数项,其对数概率为:

logpθ(xtxt1)12(xtμ)TΣ1(xtμ)\log p_\theta(\boldsymbol x_t|\boldsymbol x_{t-1}) \approx -\frac{1}{2}(\boldsymbol x_t - \boldsymbol\mu)^T \boldsymbol\Sigma^{-1}(\boldsymbol x_t - \boldsymbol\mu)

​ 对于 logpθ(yxt)\log p_\theta(\boldsymbol y|\boldsymbol x_t) (观测模型),当噪声较小(即 Σ\boldsymbol\Sigma 较小,使得 xt\boldsymbol x_t 集中在真实值 μ\boldsymbol\mu 附近时),可在 xt=μ\boldsymbol x_t = \boldsymbol\mu 处进行一阶泰勒展开:

logpθ(yxt)logpθ(yμ)+(xtμ)Txtlogpθ(yxt)xt=μ\log p_\theta(\boldsymbol y|\boldsymbol x_t) \approx \log p_\theta(\boldsymbol y|\boldsymbol\mu) + (\boldsymbol x_t - \boldsymbol\mu)^T \nabla_{\boldsymbol x_t} \log p_\theta(\boldsymbol y|\boldsymbol x_t)\big|_{\boldsymbol x_t=\boldsymbol\mu}

记梯度项为 g=xtlogpθ(yxt)xt=μ\boldsymbol g = \nabla_{\boldsymbol x_t} \log p_\theta(\boldsymbol y|\boldsymbol x_t)\big|_{\boldsymbol x_t=\boldsymbol\mu},则近似为:

logpθ(yxt)const+(xtμ)Tg\log p_\theta(\boldsymbol y|\boldsymbol x_t) \approx \text{const} + (\boldsymbol x_t - \boldsymbol\mu)^T \boldsymbol g

将两项合并到 logq(xt1xt,y)\log q(\boldsymbol x_{t-1}|\boldsymbol x_t, \boldsymbol y) 中(此时我们将 μ\boldsymbol\muΣ\boldsymbol\Sigma 视为给定 xt1\boldsymbol x_{t-1} 下的参数,而 xt\boldsymbol x_t 是变量),忽略常数项:

logq(,y)12(xtμ)TΣ1(xtμ)+(xtμ)Tg\log q(\cdot|\cdot, \boldsymbol y) \approx -\frac{1}{2}(\boldsymbol x_t - \boldsymbol\mu)^T \boldsymbol\Sigma^{-1}(\boldsymbol x_t - \boldsymbol\mu) + (\boldsymbol x_t - \boldsymbol\mu)^T \boldsymbol g

通过配方法(Completing the Square)整理上式:

12(xtμ)TΣ1(xtμ)+(xtμ)Tg=12[(xtμ)TΣ1(xtμ)2(xtμ)Tg]=12[(xtμ)TΣ1(xtμ)2(xtμ)TΣ1(Σg)]=12((xtμ)Σg)TΣ1((xtμ)Σg)+12(Σg)TΣ1(Σg)=12(xt(μ+Σg))TΣ1(xt(μ+Σg))+constlogq(xt1xt,y)N(xt1;μ+Σ,Σ2)\begin{aligned} &-\frac{1}{2}(\boldsymbol x_t - \boldsymbol\mu)^T \boldsymbol\Sigma^{-1}(\boldsymbol x_t - \boldsymbol\mu) + (\boldsymbol x_t - \boldsymbol\mu)^T \boldsymbol g \\ &= -\frac{1}{2} \left[ (\boldsymbol x_t - \boldsymbol\mu)^T \boldsymbol\Sigma^{-1}(\boldsymbol x_t - \boldsymbol\mu) - 2(\boldsymbol x_t - \boldsymbol\mu)^T \boldsymbol g \right] \\ &= -\frac{1}{2} \left[ (\boldsymbol x_t - \boldsymbol\mu)^T \boldsymbol\Sigma^{-1}(\boldsymbol x_t - \boldsymbol\mu) - 2(\boldsymbol x_t - \boldsymbol\mu)^T \boldsymbol\Sigma^{-1} (\boldsymbol\Sigma \boldsymbol g) \right] \\ &= -\frac{1}{2} \left( (\boldsymbol x_t - \boldsymbol\mu) - \boldsymbol\Sigma \boldsymbol g \right)^T \boldsymbol\Sigma^{-1} \left( (\boldsymbol x_t - \boldsymbol\mu) - \boldsymbol\Sigma \boldsymbol g \right) + \frac{1}{2} (\boldsymbol\Sigma \boldsymbol g)^T \boldsymbol\Sigma^{-1} (\boldsymbol\Sigma \boldsymbol g) \\ &= -\frac{1}{2}(\boldsymbol x_t - (\boldsymbol\mu + \boldsymbol\Sigma \boldsymbol g))^T \boldsymbol\Sigma^{-1}(\boldsymbol x_t - (\boldsymbol\mu + \boldsymbol\Sigma \boldsymbol g)) + \text{const} \\ &\log q(\boldsymbol x_{t-1}|\boldsymbol x_t, \boldsymbol y) \sim \mathcal{N}(\boldsymbol x_{t-1}; \boldsymbol\mu + \boldsymbol \Sigma, \boldsymbol \Sigma^2) \end{aligned}

​ 这表明,在上述近似下,后验分布(或与 xt\boldsymbol x_t 相关的条件分布)仍然是一个高斯分布,但其均值从 μ\boldsymbol\mu 移至了 μpost=μ+Σg\boldsymbol\mu_{\text{post}} = \boldsymbol\mu + \boldsymbol\Sigma \boldsymbol g

如果一个高斯分布的对数概率叠加上一个(近似)线性项 (vμold)Tg(\boldsymbol v - \boldsymbol\mu_{old})^T \boldsymbol g(或者直接是 vTg\boldsymbol v^T \boldsymbol g'),其均值会从 μold\boldsymbol\mu_{old} 平移到 μold+Σg\boldsymbol\mu_{old} + \boldsymbol\Sigma \boldsymbol g(或 μold+Σg\boldsymbol\mu_{old} + \boldsymbol\Sigma \boldsymbol g'

​ 将此原理应用于我们对 xt1\boldsymbol x_{t-1} 的采样均值 μθ(xt,t)\boldsymbol\mu_\theta(\boldsymbol x_t, t) 的修正:原始均值为 μθ(xt,t)\boldsymbol\mu_\theta(\boldsymbol x_t, t),对应的(逆向过程)方差为 σt2I\sigma_t^2 \boldsymbol{I},指导梯度为 xtlogqϕ(yxt,t)\nabla_{\boldsymbol x_t} \log q_\phi(\boldsymbol y|\boldsymbol x_t, t)。因此,修正后的均值形式上就体现为 μθ(xt,t)+σt2Ixtlogqϕ(yxt,t)\boldsymbol\mu_\theta(\boldsymbol x_t, t) + \sigma_t^2 \boldsymbol{I} \cdot \nabla_{\boldsymbol x_t} \log q_\phi(\boldsymbol y|\boldsymbol x_t, t)。引入一个指导强度超参数 ss,我们便得到最终的均值更新规则。

修改后的均值 μ^(xt,y,t)\hat{\boldsymbol\mu}(\boldsymbol x_t, \boldsymbol y, t) 计算如下:

μ^(xt,y,t)=μθ(xt,t)+sσt2xtlogqϕ(yxt,t)\hat{\boldsymbol\mu}(\boldsymbol x_t, \boldsymbol y, t) = \boldsymbol\mu_\theta(\boldsymbol x_t, t) + s \cdot \sigma_t^2 \nabla_{\boldsymbol x_t} \log q_\phi(\boldsymbol y|\boldsymbol x_t, t)

其中:

  • xtlogqϕ(yxt,t)\nabla_{\boldsymbol x_t} \log q_\phi(\boldsymbol y|\boldsymbol x_t, t) 是分类器对数似然关于其输入 xt\boldsymbol x_t 的梯度。这个梯度向明了在输入空间中能够使分类器最快地增加对类别 y\boldsymbol y 置信度的方向。
  • ss 是指导强度 (guidance scale) 或指导权重 (guidance weight),一个正的超参数。它控制了分类器指导的强度。当 s=0s=0 时,模型退化为无条件生成。较大的 ss 会使生成结果更贴合条件 y\boldsymbol y,但可能会牺牲生成样本的多样性或导致过度锐化等问题。
  • σt2\sigma_t^2 (即 pθ(xt1xt)p_\theta(\boldsymbol x_{t-1}|\boldsymbol x_t) 的方差) 乘以梯度项,用于根据当前噪声水平调整指导的幅度。这确保了指导效果与扩散模型的内在尺度相匹配。

新的条件采样步骤

因此,在每个逆向采样步骤 tt(从 TT11),我们执行以下操作:
a. 使用 DDPM 模型 ϵθ(xt,t)\boldsymbol\epsilon_\theta(\boldsymbol x_t, t) 计算原始均值 μθ(xt,t)\boldsymbol\mu_\theta(\boldsymbol x_t, t)
b. 计算分类器 qϕ(yxt,t)q_\phi(\boldsymbol y|\boldsymbol x_t, t) 的对数似然梯度 xtlogqϕ(yxt,t)\nabla_{\boldsymbol x_t} \log q_\phi(\boldsymbol y|\boldsymbol x_t, t)
c. 根据上述公式计算修正后的均值 μ^(xt,y,t)\hat{\boldsymbol\mu}(\boldsymbol x_t, \boldsymbol y, t)
d. 从以下高斯分布中采样得到 xt1\boldsymbol x_{t-1}

xt1N(xt1;μ^(xt,y,t),σt2I)\boldsymbol x_{t-1} \sim \mathcal{N}(\boldsymbol x_{t-1}; \hat{\boldsymbol\mu}(\boldsymbol x_t, \boldsymbol y, t), \sigma_t^2 \mathbf{I})

image-20250603004953608

与Score Matching (分数匹配) 的联系:

​ 这个指导技巧与基于分数的生成模型 (Score-based Generative Models) 的联系。在这些模型中,条件生成的关键在于估计条件概率分布 pt(xty)p_t(\boldsymbol x_t|\boldsymbol y) 的对数梯度,即分数函数 xtlogpt(xty)\nabla_{\boldsymbol x_t} \log p_t(\boldsymbol x_t|\boldsymbol y)

​ 根据贝叶斯定理,pt(xty)=pt(yxt)pt(xt)pt(y)p_t(\boldsymbol x_t|\boldsymbol y) = \frac{p_t(\boldsymbol y|\boldsymbol x_t) p_t(\boldsymbol x_t)}{p_t(\boldsymbol y)}。两边取对数并对 xt\boldsymbol x_t 求梯度,得到:

xtlogpt(xty)=xtlogpt(xt)原始分数+xtlogpt(yxt)条件分数\nabla_{\boldsymbol x_t} \log p_t(\boldsymbol x_t|\boldsymbol y) = \underbrace{\nabla_{\boldsymbol x_t} \log p_t(\boldsymbol x_t)}_{\text{原始分数}} + \underbrace{\nabla_{\boldsymbol x_t} \log p_t(\boldsymbol y|\boldsymbol x_t)}_{\text{条件分数}}

这表明条件分布的分数等于无条件分布的分数与似然函数(由分类器给出)的分数之和。

  • 原始分数: xtlogpt(xt)\nabla_{\boldsymbol x_t} \log p_t(\boldsymbol x_t):DDPM 中的噪声预测模型 ϵθ(xt,t)\boldsymbol\epsilon_\theta(\boldsymbol x_t, t) 实际上是在估计这个分数的一个缩放版本。具体来说,sθ(xt,t)=ϵθ(xt,t)1αˉt\boldsymbol s_\theta(\boldsymbol x_t, t) = -\frac{\boldsymbol\epsilon_\theta(\boldsymbol x_t, t)}{\sqrt{1-\bar{\alpha}_t}} 是对 qt(xt)q_t(\boldsymbol x_t) (即 pt(xt)p_t(\boldsymbol x_t) 的真实对应) 分数 xtlogqt(xt)\nabla_{\boldsymbol x_t} \log q_t(\boldsymbol x_t) 的估计。DDPM 的均值计算公式 μθ(xt,t)\boldsymbol\mu_\theta(\boldsymbol x_t, t) 隐式地使用了这个分数来指导去噪。

  • 条件分数: xtlogpt(yxt)\nabla_{\boldsymbol x_t} \log p_t(\boldsymbol y|\boldsymbol x_t):这正是由分类器 qϕ(yxt,t)q_\phi(\boldsymbol y|\boldsymbol x_t, t) 提供的梯度项 xtlogqϕ(yxt,t)\nabla_{\boldsymbol x_t} \log q_\phi(\boldsymbol y|\boldsymbol x_t, t)

因此,通过调整均值 μθ(xt,t)\boldsymbol\mu_\theta(\boldsymbol x_t, t) 来得到 μ^(xt,y,t)\hat{\boldsymbol\mu}(\boldsymbol x_t, \boldsymbol y, t),我们实际上是在有效地将无条件模型的去噪方向(由 ϵθ\boldsymbol\epsilon_\theta 决定)与分类器指示的使样本更符合条件 y\boldsymbol y 的方向进行线性组合

​ 这种方法非常巧妙,因为它避免了从头开始训练一个复杂的条件扩散模型。我们只需要一个预训练的无条件 DDPM 和一个在(可能带噪的)数据上训练的分类器,就可以实现高质量、可控的条件图像生成。这体现了模块化和组合优化的思想