title: classifier_guidance
date: 2025-06-02 20:06:57
tags:
由于基于 diffusion 的图像生成模型多样性过高(不可控,而且容易出现很多真实性很差的图片,FID 指标太高),可能的一种解决办法就是 conditional generation,即将 $$p(x)\rightarrow p(x|y)$$,让一个 DDPM 变为一个 conditional DDPM。
条件生成模型是只指在条件 y \boldsymbol y y (condition)下,根据条件生成内容 x \boldsymbol x x ,概率表示为 p ( x ∣ y ) p(\boldsymbol x|\boldsymbol y) p ( x ∣ y ) ,
Classifier Guidance
注意我们对 guidance 的要求是:
不影响训练过程,即我们只在采样过程做文章
尝试用一种高效的方法进行指导采样,而不是最后接一个简单的分类器,让模型生成很多图片最终选一个最好的图片
我们用 q ^ \hat{q} q ^ 表述加入条件 y \boldsymbol y y 后的(逆向)采样过程。我们的目标是修改无条件的逆向采样 q ( x t − 1 ∣ x t ) q(\boldsymbol x_{t-1}|\boldsymbol x_t) q ( x t − 1 ∣ x t ) 得到条件采样 q ^ ( x t − 1 ∣ x t , y ) \hat{q}(\boldsymbol x_{t-1}|\boldsymbol x_t, \boldsymbol y) q ^ ( x t − 1 ∣ x t , y ) 。利用贝叶斯定理来引入条件 y \boldsymbol y y :
q ^ ( x t − 1 ∣ x t , y ) = q ^ ( x t − 1 ∣ x t ) q ^ ( y ∣ x t − 1 , x t ) q ^ ( y ∣ x t ) \hat{q}(\boldsymbol x_{t-1}|\boldsymbol x_t, \boldsymbol y) = \frac{\hat{q}(\boldsymbol x_{t-1}|\boldsymbol x_t)\hat{q} (\boldsymbol y| \boldsymbol x_{t-1}, \boldsymbol x_t)}{\hat{q}(\boldsymbol y|\boldsymbol x_t)}
q ^ ( x t − 1 ∣ x t , y ) = q ^ ( y ∣ x t ) q ^ ( x t − 1 ∣ x t ) q ^ ( y ∣ x t − 1 , x t )
这个公式的右边包含三个概率项:
q ^ ( x t − 1 ∣ x t ) \hat{q}(\boldsymbol x_{t-1}|\boldsymbol x_t) q ^ ( x t − 1 ∣ x t ) : 在给定 x t \boldsymbol x_t x t 的情况下,x t − 1 \boldsymbol x_{t-1} x t − 1 的(条件化后的)先验分布。我们希望它与原始的 q ( x t − 1 ∣ x t ) q(\boldsymbol x_{t-1}|\boldsymbol x_t) q ( x t − 1 ∣ x t ) 相关。
q ^ ( y ∣ x t − 1 , x t ) \hat{q} (\boldsymbol y| \boldsymbol x_{t-1}, \boldsymbol x_t) q ^ ( y ∣ x t − 1 , x t ) : 给定 x t \boldsymbol x_t x t 和 x t − 1 \boldsymbol x_{t-1} x t − 1 后,条件 y \boldsymbol y y 的似然。这部分将由分类器提供指导。
q ^ ( y ∣ x t ) \hat{q}(\boldsymbol y|\boldsymbol x_t) q ^ ( y ∣ x t ) : 证据项 (evidence),作为归一化常数。
看上去这些新的 q ^ \hat{q} q ^ 分布我们都不知道,无法进行下一步求解。但是注意我们需要的是一个 training-free 的模型,即不改变原始 DDPM 的训练。为此,我们做出如下规定和假设:
核心假设 (Training-Free):
条件化的前向过程不变: q ^ ( x t ∣ x t − 1 , y ) = q ( x t ∣ x t − 1 ) \hat{q}(\boldsymbol x_{t}|\boldsymbol x_{t-1}, \boldsymbol y) = q(\boldsymbol x_{t}|\boldsymbol x_{t-1}) q ^ ( x t ∣ x t − 1 , y ) = q ( x t ∣ x t − 1 ) 。这意味着给定 x t − 1 \boldsymbol x_{t-1} x t − 1 ,x t \boldsymbol x_t x t 的产生(加噪)过程不直接依赖于外部条件 y \boldsymbol y y 。这是保持原始扩散模型不变的关键。
初始分布不变: q ^ ( x 0 ) = q ( x 0 ) \hat{q}(\boldsymbol x_0) = q(\boldsymbol x_0) q ^ ( x 0 ) = q ( x 0 ) 。
Markov property:整个条件化的前向轨迹满足: q ^ ( x 1 : T ∣ x 0 , y ) = ∏ t = 1 T q ^ ( x t ∣ x t − 1 , y ) = ∏ t = 1 T q ( x t ∣ x t − 1 ) = q ( x 1 : T ∣ x 0 ) \hat{q}(\boldsymbol x_{1:T}|\boldsymbol x_0, \boldsymbol y) = \prod_{t=1}^T \hat{q}(\boldsymbol x_t|\boldsymbol x_{t-1}, \boldsymbol y) = \prod_{t=1}^T q(\boldsymbol x_t|\boldsymbol x_{t-1}) = q(\boldsymbol x_{1:T}|\boldsymbol x_0) q ^ ( x 1 : T ∣ x 0 , y ) = ∏ t = 1 T q ^ ( x t ∣ x t − 1 , y ) = ∏ t = 1 T q ( x t ∣ x t − 1 ) = q ( x 1 : T ∣ x 0 ) 。
接下来,我们基于这些假设推导上面三个概率的具体形式
q ^ ( x t ∣ x t − 1 ) \hat{q}(\boldsymbol x_t|\boldsymbol x_{t-1}) q ^ ( x t ∣ x t − 1 ) 求解
这一步的目标是证明,即使我们考虑了条件 y \boldsymbol y y 的存在,当我们对 y \boldsymbol y y 进行积分(边缘化)后,从 x t − 1 \boldsymbol x_{t-1} x t − 1 到 x t \boldsymbol x_t x t 的单步前向转移概率与原始的无条件转移概率 q ( x t ∣ x t − 1 ) q(\boldsymbol x_t|\boldsymbol x_{t-1}) q ( x t ∣ x t − 1 ) 是一致的。这有助于简化后续分析,表明模型的前向“骨架”在整体上未变:
q ^ ( x t ∣ x t − 1 ) = ∫ y q ^ ( x t , y ∣ x t − 1 ) d y = ∫ y q ^ ( x t ∣ x t − 1 , y ) q ^ ( y ∣ x t − 1 ) d y = ∫ y q ( x t ∣ x t − 1 ) q ^ ( y ∣ x t − 1 ) d y = q ( x t ∣ x t − 1 ) ∫ y q ^ ( y ∣ x t − 1 ) d y = q ( x t ∣ x t − 1 ) \begin{aligned}
\hat{q}(\boldsymbol x_t|\boldsymbol x_{t-1}) &= \int_\boldsymbol y \hat{q}(\boldsymbol x_t, \boldsymbol y|\boldsymbol x_{t-1}) d\boldsymbol y \\
&= \int_\boldsymbol y \hat{q}(\boldsymbol x_t|\boldsymbol x_{t-1}, \boldsymbol y) \hat{q}(\boldsymbol y|\boldsymbol x_{t-1}) d\boldsymbol y \\
&= \int_\boldsymbol y q(\boldsymbol x_t|\boldsymbol x_{t-1}) \hat{q}(\boldsymbol y|\boldsymbol x_{t-1}) d\boldsymbol y \\
&= q(\boldsymbol x_t|\boldsymbol x_{t-1}) \int_\boldsymbol y \hat{q}(\boldsymbol y|\boldsymbol x_{t-1}) d\boldsymbol y \\
&= q(\boldsymbol x_t|\boldsymbol x_{t-1})
\end{aligned}
q ^ ( x t ∣ x t − 1 ) = ∫ y q ^ ( x t , y ∣ x t − 1 ) d y = ∫ y q ^ ( x t ∣ x t − 1 , y ) q ^ ( y ∣ x t − 1 ) d y = ∫ y q ( x t ∣ x t − 1 ) q ^ ( y ∣ x t − 1 ) d y = q ( x t ∣ x t − 1 ) ∫ y q ^ ( y ∣ x t − 1 ) d y = q ( x t ∣ x t − 1 )
这表明,在边缘化条件 y \boldsymbol y y 后,条件模型的前向转移概率与原始模型的相同
q ^ ( x t ) \hat{q}(\boldsymbol x_t) q ^ ( x t ) 求解
这一步的目标是证明在所有时间步和初始状态上进行积分后,任意时刻 t t t 的边缘分布 q ^ ( x t ) \hat{q}(\boldsymbol x_t) q ^ ( x t ) 与原始的 q ( x t ) q(\boldsymbol x_t) q ( x t ) 相同。这意味着尽管我们引入了条件生成的机制,但从宏观上看,数据在各个噪声水平的(无条件)分布特性保持不变
q ^ ( x t ) = ∫ x 0 : t − 1 q ^ ( x 0 : t ) d x 0 : t − 1 = ∫ x 0 : t − 1 q ^ ( x 0 ) q ^ ( x 1 : t ∣ x 0 ) d x 0 : t − 1 \begin{aligned}
\hat{q}(\boldsymbol x_t) &= \int_{\boldsymbol x_{0:t-1}} \hat{q}(\boldsymbol{x}_{0:t})d\boldsymbol{x}_{0:t-1} \\
&= \int_{\boldsymbol x_{0:t-1}} \hat{q}(\boldsymbol x_0)\hat{q}(\boldsymbol{x}_{1:t}|\boldsymbol x_0)d\boldsymbol{x}_{0:t-1}
\end{aligned}
q ^ ( x t ) = ∫ x 0 : t − 1 q ^ ( x 0 : t ) d x 0 : t − 1 = ∫ x 0 : t − 1 q ^ ( x 0 ) q ^ ( x 1 : t ∣ x 0 ) d x 0 : t − 1
为此需要求解 q ^ ( x 1 : t ∣ x 0 ) \hat{q}(\boldsymbol{x}_{1:t}|\boldsymbol x_0) q ^ ( x 1 : t ∣ x 0 ) :
q ^ ( x 1 : t ∣ x 0 ) = ∫ y q ^ ( x 1 : t , y ∣ x 0 ) d y = ∫ y q ^ ( y ∣ x 0 ) q ^ ( x 1 : t ∣ x 0 , y ) d y = ∫ y q ^ ( y ∣ x 0 ) ∏ i = 1 t q ^ ( x i ∣ x i − 1 , y ) d y ( Markov property ) = ∫ y q ^ ( y ∣ x 0 ) ∏ i = 1 t q ( x i ∣ x i − 1 ) d y = ( ∏ i = 1 t q ( x i ∣ x i − 1 ) ) ∫ y q ^ ( y ∣ x 0 ) d y = q ( x 1 : t ∣ x 0 ) ( ∫ y q ^ ( y ∣ x 0 ) d y = 1 ) \begin{aligned}
\hat{q}(\boldsymbol{x}_{1:t}|\boldsymbol x_0)&= \int_\boldsymbol y \hat{q}(\boldsymbol x_{1:t}, \boldsymbol y|\boldsymbol x_0) d\boldsymbol y= \int_\boldsymbol y \hat{q}(\boldsymbol y|\boldsymbol x_0)\hat{q}(\boldsymbol x_{1:t}|\boldsymbol x_0, \boldsymbol y) d\boldsymbol y\\
&= \int_\boldsymbol y \hat{q}(\boldsymbol y|\boldsymbol x_0)\prod_{i=1}^t \hat{q}(\boldsymbol x_i|\boldsymbol x_{i-1}, \boldsymbol y) d\boldsymbol y \quad (\text{Markov property}) \\
&= \int_\boldsymbol y \hat{q}(\boldsymbol y|\boldsymbol x_0)\prod_{i=1}^t q(\boldsymbol x_i|\boldsymbol x_{i-1}) d\boldsymbol y \\
&= \left(\prod_{i=1}^t q(\boldsymbol x_i|\boldsymbol x_{i-1})\right) \int_\boldsymbol y \hat{q}(\boldsymbol y|\boldsymbol x_0) d\boldsymbol y \\
&= q(\boldsymbol x_{1:t}|\boldsymbol x_0) \quad (\int_\boldsymbol y \hat{q}(\boldsymbol y|\boldsymbol x_0) d\boldsymbol y = 1)
\end{aligned}
q ^ ( x 1 : t ∣ x 0 ) = ∫ y q ^ ( x 1 : t , y ∣ x 0 ) d y = ∫ y q ^ ( y ∣ x 0 ) q ^ ( x 1 : t ∣ x 0 , y ) d y = ∫ y q ^ ( y ∣ x 0 ) i = 1 ∏ t q ^ ( x i ∣ x i − 1 , y ) d y ( Markov property ) = ∫ y q ^ ( y ∣ x 0 ) i = 1 ∏ t q ( x i ∣ x i − 1 ) d y = ( i = 1 ∏ t q ( x i ∣ x i − 1 ) ) ∫ y q ^ ( y ∣ x 0 ) d y = q ( x 1 : t ∣ x 0 ) ( ∫ y q ^ ( y ∣ x 0 ) d y = 1 )
这证明了 q ^ ( x 1 : t ∣ x 0 ) = q ( x 1 : t ∣ x 0 ) \hat{q}(\boldsymbol{x}_{1:t}|\boldsymbol x_0)=q(\boldsymbol x_{1:t}|\boldsymbol x_0) q ^ ( x 1 : t ∣ x 0 ) = q ( x 1 : t ∣ x 0 ) ,那么接下来:
q ^ ( x t ) = ∫ x 0 : t − 1 q ^ ( x 0 ) q ^ ( x 1 : t ∣ x 0 ) d x 0 : t − 1 = ∫ x 0 : t − 1 q ( x 0 ) q ( x 1 : t ∣ x 0 ) d x 0 : t − 1 = ∫ x 0 : t − 1 q ( x 0 : t ) d x 0 : t − 1 = q ( x t ) \begin{aligned}
\hat{q}(\boldsymbol x_t)
&= \int_{\boldsymbol x_{0:t-1}} \hat{q}(\boldsymbol x_0)\hat{q}(\boldsymbol{x}_{1:t}|\boldsymbol x_0)d\boldsymbol{x}_{0:t-1} \\
&= \int_{\boldsymbol x_{0:t-1}}q(\boldsymbol x_0)q(\boldsymbol{x}_{1:t}|\boldsymbol x_0)d\boldsymbol{x}_{0:t-1} \\
&= \int_{\boldsymbol x_{0:t-1}} q(\boldsymbol x_{0:t})d\boldsymbol{x}_{0:t-1} = q(\boldsymbol x_t)
\end{aligned}
q ^ ( x t ) = ∫ x 0 : t − 1 q ^ ( x 0 ) q ^ ( x 1 : t ∣ x 0 ) d x 0 : t − 1 = ∫ x 0 : t − 1 q ( x 0 ) q ( x 1 : t ∣ x 0 ) d x 0 : t − 1 = ∫ x 0 : t − 1 q ( x 0 : t ) d x 0 : t − 1 = q ( x t )
这表明任意时刻 t t t 的边缘分布 q ^ ( x t ) \hat{q}(\boldsymbol x_t) q ^ ( x t ) 与原始的 q ( x t ) q(\boldsymbol x_t) q ( x t ) 相同。这再次确认了我们的 training-free 假设下,模型的基础特性得以保持。
q ^ ( y ∣ x t − 1 , x t ) \hat{q} (\boldsymbol y| \boldsymbol x_{t-1}, \boldsymbol x_t) q ^ ( y ∣ x t − 1 , x t ) 求解
这是贝叶斯公式中的关键似然项之一。我们希望简化它,理想情况下将其与一个易于建模的分类器联系起来
q ^ ( y ∣ x t − 1 , x t ) = q ^ ( x t ∣ x t − 1 , y ) q ^ ( y ∣ x t − 1 ) q ^ ( x t ∣ x t − 1 ) = q ( x t ∣ x t − 1 ) q ^ ( y ∣ x t − 1 ) q ( x t ∣ x t − 1 ) = q ^ ( y ∣ x t − 1 ) \begin{aligned}
\hat{q} (\boldsymbol y| \boldsymbol x_{t-1}, \boldsymbol x_t) &= \frac{\hat{q} (\boldsymbol x_t| \boldsymbol x_{t-1}, \boldsymbol y)\hat{q} (\boldsymbol y| \boldsymbol x_{t-1})}{\hat{q} (\boldsymbol x_t| \boldsymbol x_{t-1})} \\
&= \frac{q (\boldsymbol x_t| \boldsymbol x_{t-1})\hat{q} (\boldsymbol y| \boldsymbol x_{t-1})}{q (\boldsymbol x_t| \boldsymbol x_{t-1})} = \hat{q} (\boldsymbol y| \boldsymbol x_{t-1})
\end{aligned}
q ^ ( y ∣ x t − 1 , x t ) = q ^ ( x t ∣ x t − 1 ) q ^ ( x t ∣ x t − 1 , y ) q ^ ( y ∣ x t − 1 ) = q ( x t ∣ x t − 1 ) q ( x t ∣ x t − 1 ) q ^ ( y ∣ x t − 1 ) = q ^ ( y ∣ x t − 1 )
这个推导表明 q ^ ( y ∣ x t − 1 , x t ) = q ^ ( y ∣ x t − 1 ) \hat{q} (\boldsymbol y| \boldsymbol x_{t-1}, \boldsymbol x_t) = \hat{q} (\boldsymbol y| \boldsymbol x_{t-1}) q ^ ( y ∣ x t − 1 , x t ) = q ^ ( y ∣ x t − 1 ) 。这意味着在给定 x t − 1 \boldsymbol x_{t-1} x t − 1 的情况下,x t \boldsymbol x_t x t 对于确定 y \boldsymbol y y 不提供额外信息(即 y ↔ x t − 1 ↔ x t y \leftrightarrow \boldsymbol x_{t-1} \leftrightarrow \boldsymbol x_t y ↔ x t − 1 ↔ x t 形成马尔可夫链)。这一项 q ^ ( y ∣ x t − 1 ) \hat{q} (\boldsymbol y| \boldsymbol x_{t-1}) q ^ ( y ∣ x t − 1 ) 将由一个分类器 p ϕ ( y ∣ x t − 1 ) p_\phi(\boldsymbol y|\boldsymbol x_{t-1}) p ϕ ( y ∣ x t − 1 ) 来建模。
则最终的优化(采样)目标:
现在我们将所有简化的部分代回到最初的贝叶斯公式
q ^ ( x t − 1 ∣ x t , y ) = q ^ ( x t − 1 ∣ x t ) q ^ ( y ∣ x t − 1 , x t ) q ^ ( y ∣ x t ) \hat{q}(\boldsymbol x_{t-1}|\boldsymbol x_t, \boldsymbol y) = \frac{\hat{q}(\boldsymbol x_{t-1}|\boldsymbol x_t)\hat{q} (\boldsymbol y| \boldsymbol x_{t-1}, \boldsymbol x_t)}{\hat{q}(\boldsymbol y|\boldsymbol x_t)}
q ^ ( x t − 1 ∣ x t , y ) = q ^ ( y ∣ x t ) q ^ ( x t − 1 ∣ x t ) q ^ ( y ∣ x t − 1 , x t )
根据我们的推导和假设:
q ^ ( x t − 1 ∣ x t ) ≈ q ( x t − 1 ∣ x t ) \hat{q}(\boldsymbol x_{t-1}|\boldsymbol x_t) \approx q(\boldsymbol x_{t-1}|\boldsymbol x_t) q ^ ( x t − 1 ∣ x t ) ≈ q ( x t − 1 ∣ x t ) : 我们近似认为条件化的先验与原始DDPM的逆向步骤 q ( x t − 1 ∣ x t ) q(\boldsymbol x_{t-1}|\boldsymbol x_t) q ( x t − 1 ∣ x t ) (由 q θ ( x t − 1 ∣ x t ) q_\theta(\boldsymbol x_{t-1}|\boldsymbol x_t) q θ ( x t − 1 ∣ x t ) 建模)相同。这是指导的核心所在,即在原始逆向步骤基础上进行调整。
q ^ ( y ∣ x t − 1 , x t ) = q ^ ( y ∣ x t − 1 ) \hat{q} (\boldsymbol y| \boldsymbol x_{t-1}, \boldsymbol x_t) = \hat{q} (\boldsymbol y| \boldsymbol x_{t-1}) q ^ ( y ∣ x t − 1 , x t ) = q ^ ( y ∣ x t − 1 ) : 由分类器 q ϕ ( y ∣ x t − 1 ) q_\phi(\boldsymbol y|\boldsymbol x_{t-1}) q ϕ ( y ∣ x t − 1 ) 给出。
分母 q ^ ( y ∣ x t ) \hat{q}(\boldsymbol y|\boldsymbol x_t) q ^ ( y ∣ x t ) 作为归一化常数 1 / C 1/C 1 / C 。
因此,采样目标可以写成:
q ^ ( x t − 1 ∣ x t , y ) ∝ q ( x t − 1 ∣ x t ) × q ϕ ( y ∣ x t − 1 ) \hat{q}(\boldsymbol x_{t-1}|\boldsymbol x_t, \boldsymbol y) \propto q(\boldsymbol x_{t-1}|\boldsymbol x_t)\times q_\phi(\boldsymbol y|\boldsymbol x_{t-1})
q ^ ( x t − 1 ∣ x t , y ) ∝ q ( x t − 1 ∣ x t ) × q ϕ ( y ∣ x t − 1 )
这意味着条件逆向采样正比于原始 DDPM 的逆向采样 q ( x t − 1 ∣ x t ) q(\boldsymbol x_{t-1}|\boldsymbol x_t) q ( x t − 1 ∣ x t ) 和分类器在 x t − 1 \boldsymbol x_{t-1} x t − 1 处预测条件 y \boldsymbol y y 的概率 q ϕ ( y ∣ x t − 1 ) q_\phi(\boldsymbol y|\boldsymbol x_{t-1}) q ϕ ( y ∣ x t − 1 ) 的乘积。其中 q ( x t − 1 ∣ x t ) q(\boldsymbol x_{t-1}|\boldsymbol x_t) q ( x t − 1 ∣ x t ) 就是 DDPM 部分的逆向采样, q ϕ ( y ∣ x t − 1 ) q_\phi(\boldsymbol y|\boldsymbol x_{t-1}) q ϕ ( y ∣ x t − 1 ) 部分就是一个分类器 。因此如果需要模型预测两个部分,一个模型就是 DDPM,预测 q θ ( x t − 1 ∣ x t ) q_\theta(\boldsymbol x_{t-1}|\boldsymbol x_t) q θ ( x t − 1 ∣ x t ) (通过预测噪声 ϵ θ ( x t , t ) \boldsymbol\epsilon_\theta(\boldsymbol x_t, t) ϵ θ ( x t , t ) ),另一个模型需要一个分类器,预测 q ϕ ( y ∣ x t ) q_\phi(\boldsymbol y|\boldsymbol x_t) q ϕ ( y ∣ x t )
Score-based Guidance Trick (分类器指导技巧)
前面我们推导出理想的条件逆向采样满足(或者说,我们希望构造的条件逆向采样过程):
q ^ ( x t − 1 ∣ x t , y ) ∝ q ( x t − 1 ∣ x t ) × q ϕ ( y ∣ x t − 1 ) \hat{q}(\boldsymbol x_{t-1}|\boldsymbol x_t, \boldsymbol y) \propto q(\boldsymbol x_{t-1}|\boldsymbol x_t)\times q_\phi(\boldsymbol y|\boldsymbol x_{t-1})
q ^ ( x t − 1 ∣ x t , y ) ∝ q ( x t − 1 ∣ x t ) × q ϕ ( y ∣ x t − 1 )
这里的 q ( x t − 1 ∣ x t ) q(\boldsymbol x_{t-1}|\boldsymbol x_t) q ( x t − 1 ∣ x t ) 是原始 DDPM 的逆向(去噪)步骤,而 q ϕ ( y ∣ x t − 1 ) q_\phi(\boldsymbol y|\boldsymbol x_{t-1}) q ϕ ( y ∣ x t − 1 ) 是一个分类器在给定 x t − 1 \boldsymbol x_{t-1} x t − 1 时预测条件 y \boldsymbol y y 的概率。
然而,这个公式直接应用存在一个核心问题:在 t t t 时刻计算(采样) x t − 1 \boldsymbol x_{t-1} x t − 1 时,我们需要 q ϕ ( y ∣ x t − 1 ) q_\phi(\boldsymbol y|\boldsymbol x_{t-1}) q ϕ ( y ∣ x t − 1 ) 的值,但 x t − 1 \boldsymbol x_{t-1} x t − 1 本身是未知的,是我们正要采样的目标。这就构成了一个循环依赖,在实际中我们转而依赖一个在 x t \boldsymbol x_t x t 上操作的分类器。
为了解决这个问题并实现有效的指导,Classifier Guidance 采用了以下关键技巧:
分类器作用于当前状态 x t \boldsymbol x_t x t :我们不再试图使用依赖于未来(待采样)状态 x t − 1 \boldsymbol x_{t-1} x t − 1 的分类器 q ϕ ( y ∣ x t − 1 ) q_\phi(\boldsymbol y|\boldsymbol x_{t-1}) q ϕ ( y ∣ x t − 1 ) 。取而代之,我们训练或使用一个分类器 q ϕ ( y ∣ x t , t ) q_\phi(\boldsymbol y|\boldsymbol x_t, t) q ϕ ( y ∣ x t , t ) ,该分类器根据当前 的噪声图像 x t \boldsymbol x_t x t 和可选的时间步 t t t 来预测条件 y \boldsymbol y y 的概率。这有效地打破了循环依赖,因为在采样 x t − 1 \boldsymbol x_{t-1} x t − 1 的时刻,x t \boldsymbol x_t x t 是已知的。
通过梯度引导均值 (Mean Perturbation via Gradient):原始 DDPM 的逆向采样步骤 p θ ( x t − 1 ∣ x t ) p_\theta(\boldsymbol x_{t-1}|\boldsymbol x_t) p θ ( x t − 1 ∣ x t ) (模型对 q ( x t − 1 ∣ x t ) q(\boldsymbol x_{t-1}|\boldsymbol x_t) q ( x t − 1 ∣ x t ) 的近似) 通常被建模为一个高斯分布 N ( x t − 1 ; μ θ ( x t , t ) , σ t 2 I ) \mathcal{N}(\boldsymbol x_{t-1}; \boldsymbol\mu_\theta(\boldsymbol x_t, t), \sigma_t^2 \boldsymbol{I}) N ( x t − 1 ; μ θ ( x t , t ) , σ t 2 I ) 。其均值 μ θ ( x t , t ) \boldsymbol\mu_\theta(\boldsymbol x_t, t) μ θ ( x t , t ) 由 DDPM 的噪声预测网络 ϵ θ ( x t , t ) \boldsymbol\epsilon_\theta(\boldsymbol x_t, t) ϵ θ ( x t , t ) 决定:
μ θ ( x t , t ) = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) \boldsymbol\mu_\theta(\boldsymbol x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(\boldsymbol x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\boldsymbol\epsilon_\theta(\boldsymbol x_t, t)\right)
μ θ ( x t , t ) = α t 1 ( x t − 1 − α ˉ t 1 − α t ϵ θ ( x t , t ) )
方差 σ t 2 \sigma_t^2 σ t 2 通常是预定义的,例如 σ t 2 = β ~ t = 1 − α ˉ t − 1 1 − α ˉ t β t \sigma_t^2 = \tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t σ t 2 = β ~ t = 1 − α ˉ t 1 − α ˉ t − 1 β t 或者 σ t 2 = β t \sigma_t^2 = \beta_t σ t 2 = β t 。
Classifier Guidance 的核心思想是利用分类器 q ϕ ( y ∣ x t , t ) q_\phi(\boldsymbol y|\boldsymbol x_t, t) q ϕ ( y ∣ x t , t ) 的对数概率的梯度来调整(或“扰动”)这个均值 μ θ ( x t , t ) \boldsymbol\mu_\theta(\boldsymbol x_t, t) μ θ ( x t , t ) ,从而将采样过程“引导”向更可能满足条件 y \boldsymbol y y 的 x t − 1 \boldsymbol x_{t-1} x t − 1
我们可以考虑以下推导。这对应了后验分布 $ q(\boldsymbol x_{t-1}|\boldsymbol x_t, \boldsymbol y) $ 的一种建模思路。通过贝叶斯定理(并关注对数概率):
log q ( x t − 1 ∣ x t , y ) = log p θ ( x t ∣ x t − 1 ) + log p θ ( y ∣ x t ) − log Z ( x t , y ) \log q(\boldsymbol x_{t-1}|\boldsymbol x_t, \boldsymbol y) = \log p_\theta(\boldsymbol x_t|\boldsymbol x_{t-1}) + \log p_\theta(\boldsymbol y|\boldsymbol x_t) - \log Z(\boldsymbol x_t, \boldsymbol y)
log q ( x t − 1 ∣ x t , y ) = log p θ ( x t ∣ x t − 1 ) + log p θ ( y ∣ x t ) − log Z ( x t , y )
其中 Z ( x t , y ) Z(\boldsymbol x_t, \boldsymbol y) Z ( x t , y ) 是归一化常数。我们主要关注前两项的近似展开。我们知道 $ p_\theta(\boldsymbol x_t|\boldsymbol x_{t-1}) $ (给定 x t − 1 \boldsymbol x_{t-1} x t − 1 时 x t \boldsymbol x_t x t 的分布) 是高斯分布:
p θ ( x t ∣ x t − 1 ) = N ( x t ; μ ( x t − 1 ) , Σ ( x t − 1 ) ) p_\theta(\boldsymbol x_t|\boldsymbol x_{t-1}) = \mathcal{N}(\boldsymbol x_t; \boldsymbol\mu(\boldsymbol x_{t-1}), \boldsymbol\Sigma(\boldsymbol x_{t-1}))
p θ ( x t ∣ x t − 1 ) = N ( x t ; μ ( x t − 1 ) , Σ ( x t − 1 ) )
忽略常数项,其对数概率为:
log p θ ( x t ∣ x t − 1 ) ≈ − 1 2 ( x t − μ ) T Σ − 1 ( x t − μ ) \log p_\theta(\boldsymbol x_t|\boldsymbol x_{t-1}) \approx -\frac{1}{2}(\boldsymbol x_t - \boldsymbol\mu)^T \boldsymbol\Sigma^{-1}(\boldsymbol x_t - \boldsymbol\mu)
log p θ ( x t ∣ x t − 1 ) ≈ − 2 1 ( x t − μ ) T Σ − 1 ( x t − μ )
对于 log p θ ( y ∣ x t ) \log p_\theta(\boldsymbol y|\boldsymbol x_t) log p θ ( y ∣ x t ) (观测模型),当噪声较小(即 Σ \boldsymbol\Sigma Σ 较小,使得 x t \boldsymbol x_t x t 集中在真实值 μ \boldsymbol\mu μ 附近时),可在 x t = μ \boldsymbol x_t = \boldsymbol\mu x t = μ 处进行一阶泰勒展开:
log p θ ( y ∣ x t ) ≈ log p θ ( y ∣ μ ) + ( x t − μ ) T ∇ x t log p θ ( y ∣ x t ) ∣ x t = μ \log p_\theta(\boldsymbol y|\boldsymbol x_t) \approx \log p_\theta(\boldsymbol y|\boldsymbol\mu) + (\boldsymbol x_t - \boldsymbol\mu)^T \nabla_{\boldsymbol x_t} \log p_\theta(\boldsymbol y|\boldsymbol x_t)\big|_{\boldsymbol x_t=\boldsymbol\mu}
log p θ ( y ∣ x t ) ≈ log p θ ( y ∣ μ ) + ( x t − μ ) T ∇ x t log p θ ( y ∣ x t ) ∣ ∣ ∣ x t = μ
记梯度项为 g = ∇ x t log p θ ( y ∣ x t ) ∣ x t = μ \boldsymbol g = \nabla_{\boldsymbol x_t} \log p_\theta(\boldsymbol y|\boldsymbol x_t)\big|_{\boldsymbol x_t=\boldsymbol\mu} g = ∇ x t log p θ ( y ∣ x t ) ∣ ∣ ∣ x t = μ ,则近似为:
log p θ ( y ∣ x t ) ≈ const + ( x t − μ ) T g \log p_\theta(\boldsymbol y|\boldsymbol x_t) \approx \text{const} + (\boldsymbol x_t - \boldsymbol\mu)^T \boldsymbol g
log p θ ( y ∣ x t ) ≈ const + ( x t − μ ) T g
将两项合并到 log q ( x t − 1 ∣ x t , y ) \log q(\boldsymbol x_{t-1}|\boldsymbol x_t, \boldsymbol y) log q ( x t − 1 ∣ x t , y ) 中(此时我们将 μ \boldsymbol\mu μ 和 Σ \boldsymbol\Sigma Σ 视为给定 x t − 1 \boldsymbol x_{t-1} x t − 1 下的参数,而 x t \boldsymbol x_t x t 是变量),忽略常数项:
log q ( ⋅ ∣ ⋅ , y ) ≈ − 1 2 ( x t − μ ) T Σ − 1 ( x t − μ ) + ( x t − μ ) T g \log q(\cdot|\cdot, \boldsymbol y) \approx -\frac{1}{2}(\boldsymbol x_t - \boldsymbol\mu)^T \boldsymbol\Sigma^{-1}(\boldsymbol x_t - \boldsymbol\mu) + (\boldsymbol x_t - \boldsymbol\mu)^T \boldsymbol g
log q ( ⋅ ∣ ⋅ , y ) ≈ − 2 1 ( x t − μ ) T Σ − 1 ( x t − μ ) + ( x t − μ ) T g
通过配方法(Completing the Square)整理上式:
− 1 2 ( x t − μ ) T Σ − 1 ( x t − μ ) + ( x t − μ ) T g = − 1 2 [ ( x t − μ ) T Σ − 1 ( x t − μ ) − 2 ( x t − μ ) T g ] = − 1 2 [ ( x t − μ ) T Σ − 1 ( x t − μ ) − 2 ( x t − μ ) T Σ − 1 ( Σ g ) ] = − 1 2 ( ( x t − μ ) − Σ g ) T Σ − 1 ( ( x t − μ ) − Σ g ) + 1 2 ( Σ g ) T Σ − 1 ( Σ g ) = − 1 2 ( x t − ( μ + Σ g ) ) T Σ − 1 ( x t − ( μ + Σ g ) ) + const log q ( x t − 1 ∣ x t , y ) ∼ N ( x t − 1 ; μ + Σ , Σ 2 ) \begin{aligned}
&-\frac{1}{2}(\boldsymbol x_t - \boldsymbol\mu)^T \boldsymbol\Sigma^{-1}(\boldsymbol x_t - \boldsymbol\mu) + (\boldsymbol x_t - \boldsymbol\mu)^T \boldsymbol g \\
&= -\frac{1}{2} \left[ (\boldsymbol x_t - \boldsymbol\mu)^T \boldsymbol\Sigma^{-1}(\boldsymbol x_t - \boldsymbol\mu) - 2(\boldsymbol x_t - \boldsymbol\mu)^T \boldsymbol g \right] \\
&= -\frac{1}{2} \left[ (\boldsymbol x_t - \boldsymbol\mu)^T \boldsymbol\Sigma^{-1}(\boldsymbol x_t - \boldsymbol\mu) - 2(\boldsymbol x_t - \boldsymbol\mu)^T \boldsymbol\Sigma^{-1} (\boldsymbol\Sigma \boldsymbol g) \right] \\
&= -\frac{1}{2} \left( (\boldsymbol x_t - \boldsymbol\mu) - \boldsymbol\Sigma \boldsymbol g \right)^T \boldsymbol\Sigma^{-1} \left( (\boldsymbol x_t - \boldsymbol\mu) - \boldsymbol\Sigma \boldsymbol g \right) + \frac{1}{2} (\boldsymbol\Sigma \boldsymbol g)^T \boldsymbol\Sigma^{-1} (\boldsymbol\Sigma \boldsymbol g) \\
&= -\frac{1}{2}(\boldsymbol x_t - (\boldsymbol\mu + \boldsymbol\Sigma \boldsymbol g))^T \boldsymbol\Sigma^{-1}(\boldsymbol x_t - (\boldsymbol\mu + \boldsymbol\Sigma \boldsymbol g)) + \text{const} \\
&\log q(\boldsymbol x_{t-1}|\boldsymbol x_t, \boldsymbol y) \sim \mathcal{N}(\boldsymbol x_{t-1}; \boldsymbol\mu + \boldsymbol \Sigma, \boldsymbol \Sigma^2)
\end{aligned}
− 2 1 ( x t − μ ) T Σ − 1 ( x t − μ ) + ( x t − μ ) T g = − 2 1 [ ( x t − μ ) T Σ − 1 ( x t − μ ) − 2 ( x t − μ ) T g ] = − 2 1 [ ( x t − μ ) T Σ − 1 ( x t − μ ) − 2 ( x t − μ ) T Σ − 1 ( Σ g ) ] = − 2 1 ( ( x t − μ ) − Σ g ) T Σ − 1 ( ( x t − μ ) − Σ g ) + 2 1 ( Σ g ) T Σ − 1 ( Σ g ) = − 2 1 ( x t − ( μ + Σ g ) ) T Σ − 1 ( x t − ( μ + Σ g ) ) + const log q ( x t − 1 ∣ x t , y ) ∼ N ( x t − 1 ; μ + Σ , Σ 2 )
这表明,在上述近似下,后验分布(或与 x t \boldsymbol x_t x t 相关的条件分布)仍然是一个高斯分布,但其均值从 μ \boldsymbol\mu μ 移至了 μ post = μ + Σ g \boldsymbol\mu_{\text{post}} = \boldsymbol\mu + \boldsymbol\Sigma \boldsymbol g μ post = μ + Σ g 。
如果一个高斯分布的对数概率叠加上一个(近似)线性项 ( v − μ o l d ) T g (\boldsymbol v - \boldsymbol\mu_{old})^T \boldsymbol g ( v − μ o l d ) T g (或者直接是 v T g ′ \boldsymbol v^T \boldsymbol g' v T g ′ ),其均值会从 μ o l d \boldsymbol\mu_{old} μ o l d 平移到 μ o l d + Σ g \boldsymbol\mu_{old} + \boldsymbol\Sigma \boldsymbol g μ o l d + Σ g (或 μ o l d + Σ g ′ \boldsymbol\mu_{old} + \boldsymbol\Sigma \boldsymbol g' μ o l d + Σ g ′ )
将此原理应用于我们对 x t − 1 \boldsymbol x_{t-1} x t − 1 的采样均值 μ θ ( x t , t ) \boldsymbol\mu_\theta(\boldsymbol x_t, t) μ θ ( x t , t ) 的修正:原始均值为 μ θ ( x t , t ) \boldsymbol\mu_\theta(\boldsymbol x_t, t) μ θ ( x t , t ) ,对应的(逆向过程)方差为 σ t 2 I \sigma_t^2 \boldsymbol{I} σ t 2 I ,指导梯度为 ∇ x t log q ϕ ( y ∣ x t , t ) \nabla_{\boldsymbol x_t} \log q_\phi(\boldsymbol y|\boldsymbol x_t, t) ∇ x t log q ϕ ( y ∣ x t , t ) 。因此,修正后的均值形式上就体现为 μ θ ( x t , t ) + σ t 2 I ⋅ ∇ x t log q ϕ ( y ∣ x t , t ) \boldsymbol\mu_\theta(\boldsymbol x_t, t) + \sigma_t^2 \boldsymbol{I} \cdot \nabla_{\boldsymbol x_t} \log q_\phi(\boldsymbol y|\boldsymbol x_t, t) μ θ ( x t , t ) + σ t 2 I ⋅ ∇ x t log q ϕ ( y ∣ x t , t ) 。引入一个指导强度超参数 s s s ,我们便得到最终的均值更新规则。
修改后的均值 μ ^ ( x t , y , t ) \hat{\boldsymbol\mu}(\boldsymbol x_t, \boldsymbol y, t) μ ^ ( x t , y , t ) 计算如下:
μ ^ ( x t , y , t ) = μ θ ( x t , t ) + s ⋅ σ t 2 ∇ x t log q ϕ ( y ∣ x t , t ) \hat{\boldsymbol\mu}(\boldsymbol x_t, \boldsymbol y, t) = \boldsymbol\mu_\theta(\boldsymbol x_t, t) + s \cdot \sigma_t^2 \nabla_{\boldsymbol x_t} \log q_\phi(\boldsymbol y|\boldsymbol x_t, t)
μ ^ ( x t , y , t ) = μ θ ( x t , t ) + s ⋅ σ t 2 ∇ x t log q ϕ ( y ∣ x t , t )
其中:
∇ x t log q ϕ ( y ∣ x t , t ) \nabla_{\boldsymbol x_t} \log q_\phi(\boldsymbol y|\boldsymbol x_t, t) ∇ x t log q ϕ ( y ∣ x t , t ) 是分类器对数似然关于其输入 x t \boldsymbol x_t x t 的梯度。这个梯度向明了在输入空间中能够使分类器最快地增加对类别 y \boldsymbol y y 置信度的方向。
s s s 是指导强度 (guidance scale) 或指导权重 (guidance weight),一个正的超参数。它控制了分类器指导的强度。当 s = 0 s=0 s = 0 时,模型退化为无条件生成。较大的 s s s 会使生成结果更贴合条件 y \boldsymbol y y ,但可能会牺牲生成样本的多样性或导致过度锐化等问题。
σ t 2 \sigma_t^2 σ t 2 (即 p θ ( x t − 1 ∣ x t ) p_\theta(\boldsymbol x_{t-1}|\boldsymbol x_t) p θ ( x t − 1 ∣ x t ) 的方差) 乘以梯度项,用于根据当前噪声水平调整指导的幅度。这确保了指导效果与扩散模型的内在尺度相匹配。
新的条件采样步骤
因此,在每个逆向采样步骤 t t t (从 T T T 到 1 1 1 ),我们执行以下操作:
a. 使用 DDPM 模型 ϵ θ ( x t , t ) \boldsymbol\epsilon_\theta(\boldsymbol x_t, t) ϵ θ ( x t , t ) 计算原始均值 μ θ ( x t , t ) \boldsymbol\mu_\theta(\boldsymbol x_t, t) μ θ ( x t , t ) 。
b. 计算分类器 q ϕ ( y ∣ x t , t ) q_\phi(\boldsymbol y|\boldsymbol x_t, t) q ϕ ( y ∣ x t , t ) 的对数似然梯度 ∇ x t log q ϕ ( y ∣ x t , t ) \nabla_{\boldsymbol x_t} \log q_\phi(\boldsymbol y|\boldsymbol x_t, t) ∇ x t log q ϕ ( y ∣ x t , t ) 。
c. 根据上述公式计算修正后的均值 μ ^ ( x t , y , t ) \hat{\boldsymbol\mu}(\boldsymbol x_t, \boldsymbol y, t) μ ^ ( x t , y , t ) 。
d. 从以下高斯分布中采样得到 x t − 1 \boldsymbol x_{t-1} x t − 1 :
x t − 1 ∼ N ( x t − 1 ; μ ^ ( x t , y , t ) , σ t 2 I ) \boldsymbol x_{t-1} \sim \mathcal{N}(\boldsymbol x_{t-1}; \hat{\boldsymbol\mu}(\boldsymbol x_t, \boldsymbol y, t), \sigma_t^2 \mathbf{I})
x t − 1 ∼ N ( x t − 1 ; μ ^ ( x t , y , t ) , σ t 2 I )
与Score Matching (分数匹配) 的联系:
这个指导技巧与基于分数的生成模型 (Score-based Generative Models) 的联系。在这些模型中,条件生成的关键在于估计条件概率分布 p t ( x t ∣ y ) p_t(\boldsymbol x_t|\boldsymbol y) p t ( x t ∣ y ) 的对数梯度,即分数函数 ∇ x t log p t ( x t ∣ y ) \nabla_{\boldsymbol x_t} \log p_t(\boldsymbol x_t|\boldsymbol y) ∇ x t log p t ( x t ∣ y ) 。
根据贝叶斯定理,p t ( x t ∣ y ) = p t ( y ∣ x t ) p t ( x t ) p t ( y ) p_t(\boldsymbol x_t|\boldsymbol y) = \frac{p_t(\boldsymbol y|\boldsymbol x_t) p_t(\boldsymbol x_t)}{p_t(\boldsymbol y)} p t ( x t ∣ y ) = p t ( y ) p t ( y ∣ x t ) p t ( x t ) 。两边取对数并对 x t \boldsymbol x_t x t 求梯度,得到:
∇ x t log p t ( x t ∣ y ) = ∇ x t log p t ( x t ) ⏟ 原始分数 + ∇ x t log p t ( y ∣ x t ) ⏟ 条件分数 \nabla_{\boldsymbol x_t} \log p_t(\boldsymbol x_t|\boldsymbol y) = \underbrace{\nabla_{\boldsymbol x_t} \log p_t(\boldsymbol x_t)}_{\text{原始分数}} + \underbrace{\nabla_{\boldsymbol x_t} \log p_t(\boldsymbol y|\boldsymbol x_t)}_{\text{条件分数}}
∇ x t log p t ( x t ∣ y ) = 原始分数 ∇ x t log p t ( x t ) + 条件分数 ∇ x t log p t ( y ∣ x t )
这表明条件分布的分数等于无条件分布的分数与似然函数(由分类器给出)的分数之和。
原始分数: ∇ x t log p t ( x t ) \nabla_{\boldsymbol x_t} \log p_t(\boldsymbol x_t) ∇ x t log p t ( x t ) :DDPM 中的噪声预测模型 ϵ θ ( x t , t ) \boldsymbol\epsilon_\theta(\boldsymbol x_t, t) ϵ θ ( x t , t ) 实际上是在估计这个分数的一个缩放版本。具体来说,s θ ( x t , t ) = − ϵ θ ( x t , t ) 1 − α ˉ t \boldsymbol s_\theta(\boldsymbol x_t, t) = -\frac{\boldsymbol\epsilon_\theta(\boldsymbol x_t, t)}{\sqrt{1-\bar{\alpha}_t}} s θ ( x t , t ) = − 1 − α ˉ t ϵ θ ( x t , t ) 是对 q t ( x t ) q_t(\boldsymbol x_t) q t ( x t ) (即 p t ( x t ) p_t(\boldsymbol x_t) p t ( x t ) 的真实对应) 分数 ∇ x t log q t ( x t ) \nabla_{\boldsymbol x_t} \log q_t(\boldsymbol x_t) ∇ x t log q t ( x t ) 的估计。DDPM 的均值计算公式 μ θ ( x t , t ) \boldsymbol\mu_\theta(\boldsymbol x_t, t) μ θ ( x t , t ) 隐式地使用了这个分数来指导去噪。
条件分数: ∇ x t log p t ( y ∣ x t ) \nabla_{\boldsymbol x_t} \log p_t(\boldsymbol y|\boldsymbol x_t) ∇ x t log p t ( y ∣ x t ) :这正是由分类器 q ϕ ( y ∣ x t , t ) q_\phi(\boldsymbol y|\boldsymbol x_t, t) q ϕ ( y ∣ x t , t ) 提供的梯度项 ∇ x t log q ϕ ( y ∣ x t , t ) \nabla_{\boldsymbol x_t} \log q_\phi(\boldsymbol y|\boldsymbol x_t, t) ∇ x t log q ϕ ( y ∣ x t , t ) 。
因此,通过调整均值 μ θ ( x t , t ) \boldsymbol\mu_\theta(\boldsymbol x_t, t) μ θ ( x t , t ) 来得到 μ ^ ( x t , y , t ) \hat{\boldsymbol\mu}(\boldsymbol x_t, \boldsymbol y, t) μ ^ ( x t , y , t ) ,我们实际上是在有效地将无条件模型的去噪方向(由 ϵ θ \boldsymbol\epsilon_\theta ϵ θ 决定)与分类器指示的使样本更符合条件 y \boldsymbol y y 的方向进行线性组合
这种方法非常巧妙,因为它避免了从头开始训练一个复杂的条件扩散模型。我们只需要一个预训练的无条件 DDPM 和一个在(可能带噪的)数据上训练的分类器,就可以实现高质量、可控的条件图像生成。这体现了模块化和组合优化的思想