classifier_free_guidance

在讨论 Classifier Guidance 时,我们看到它通过一个预训练的分类器 qϕ(yxt)q_\phi(\boldsymbol y|\boldsymbol x_t) 的梯度来引导 DDPM 的采样过程,从而实现条件生成。这种方法虽然有效,但依赖于一个额外的分类器模型,该分类器需要在带噪图像上进行训练,并且其性能直接影响最终生成效果。此外,分类器的训练数据和噪声分布可能与 DDPM 不完全匹配,导致指导效果并非最优。

例如,在 Classifier Guidance 中,我们通过调整均值 μ\boldsymbol\mu 来实现引导,形式为 μguided=μoriginal+Σg\boldsymbol\mu_{\text{guided}} = \boldsymbol\mu_{\text{original}} + \boldsymbol\Sigma \boldsymbol g,其中 g=xtlogqϕ(yxt)\boldsymbol g = \nabla_{\boldsymbol x_t} \log q_\phi(\boldsymbol y|\boldsymbol x_t) 是分类器梯度,Σ\boldsymbol\Sigma 是逆向过程的协方差。当 Σ\boldsymbol\Sigma 很小(例如在去噪过程的后期,噪声水平较低时),基于梯度的调整项 Σg\boldsymbol\Sigma \boldsymbol g 的作用可能会减弱甚至失效,限制了指导的强度和效果。

Classifier-Free Guidance (无分类器指导) 提出了一种更为简洁和统一的方法,它不再需要一个独立的分类器。核心思想是让扩散模型本身同时学习条件分布和无条件分布,并在采样时利用这两者之间的差异进行自我指导,直接对预测的噪声 ϵ\boldsymbol\epsilon 进行修改,从而避免了上述对 Σ\boldsymbol\Sigma 大小的依赖问题。

Classifier-Free Guidance 的核心思想

Classifier-Free Guidance 的核心在于训练一个能够同时处理条件输入 y\boldsymbol y 和无条件情况(通常用一个特殊的空条件 \emptyset 表示)的噪声预测网络 ϵθ(xt,t,y)\boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \boldsymbol y)

  1. 统一的噪声预测模型: 我们训练单个噪声预测网络 ϵθ(xt,t,y)\boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \boldsymbol y)。这个网络将条件 y\boldsymbol y 作为额外的输入。
  2. 训练策略: 在训练过程中,对于一部分训练样本(例如10-20%的比例),我们会随机地将真实的条件 y\boldsymbol y替换为一个特殊的“空”条件或“无条件”标记 \emptyset。这意味着模型在训练时,既要学会根据给定的 y\boldsymbol y 预测噪声(即学习 ϵθ(xt,t,y)\boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \boldsymbol y)),也要学会在没有明确条件时预测噪声(即学习 ϵθ(xt,t,)\boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \emptyset))。这两者共享同一套模型参数 θ\theta

指导原理与采样过程

Classifier-Free Guidance 的巧妙之处在于它如何在采样阶段利用同一个模型产生的条件预测和无条件预测来实现指导。其背后的思想与我们在 Classifier Guidance 中看到的基于分数(score)的调整类似。

回忆一下,条件分布的对数梯度(score)可以分解为:

xtlogpt(xty)=xtlogpt(xt)+xtlogpt(yxt)\nabla_{\boldsymbol x_t} \log p_t(\boldsymbol x_t|\boldsymbol y) = \nabla_{\boldsymbol x_t} \log p_t(\boldsymbol x_t) + \nabla_{\boldsymbol x_t} \log p_t(\boldsymbol y|\boldsymbol x_t)

suncond(xt)=xtlogpt(xt)s_{\text{uncond}}(\boldsymbol x_t) = \nabla_{\boldsymbol x_t} \log p_t(\boldsymbol x_t) 为无条件分数,g(xt,y)=xtlogpt(yxt)\boldsymbol g(\boldsymbol x_t, \boldsymbol y) = \nabla_{\boldsymbol x_t} \log p_t(\boldsymbol y|\boldsymbol x_t) 为条件项的梯度(在 Classifier Guidance 中由分类器提供)。目标引导分数是 sguided(xt,y)=suncond(xt)+g(xt,y)s_{\text{guided}}(\boldsymbol x_t, \boldsymbol y) = s_{\text{uncond}}(\boldsymbol x_t) + \boldsymbol g(\boldsymbol x_t, \boldsymbol y)

在 DDPM 中,无条件分数 suncond(xt)s_{\text{uncond}}(\boldsymbol x_t) 与模型预测的(无条件)噪声 ϵθ(xt,t,)\boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \emptyset) 相关,具体关系为 suncond(xt)=ϵθ(xt,t,)1αˉts_{\text{uncond}}(\boldsymbol x_t) = -\frac{\boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \emptyset)}{\sqrt{1-\bar{\alpha}_t}}
如果我们希望通过一个等效的“引导后”噪声 ϵ^eff\hat{\boldsymbol\epsilon}_{\text{eff}} 来表示 sguided(xt,y)s_{\text{guided}}(\boldsymbol x_t, \boldsymbol y),即 sguided(xt,y)=ϵ^eff1αˉts_{\text{guided}}(\boldsymbol x_t, \boldsymbol y) = -\frac{\hat{\boldsymbol\epsilon}_{\text{eff}}}{\sqrt{1-\bar{\alpha}_t}},那么:

ϵ^eff1αˉt=ϵθ(xt,t,)1αˉt+g(xt,y)-\frac{\hat{\boldsymbol\epsilon}_{\text{eff}}}{\sqrt{1-\bar{\alpha}_t}} = -\frac{\boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \emptyset)}{\sqrt{1-\bar{\alpha}_t}} + \boldsymbol g(\boldsymbol x_t, \boldsymbol y)

整理可得:

ϵ^eff=ϵθ(xt,t,)1αˉtg(xt,y)\hat{\boldsymbol\epsilon}_{\text{eff}} = \boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \emptyset) - \sqrt{1-\bar{\alpha}_t} \cdot \boldsymbol g(\boldsymbol x_t, \boldsymbol y)

这个推导表明,即使是基于分类器梯度 g(xt,y)\boldsymbol g(\boldsymbol x_t, \boldsymbol y) 的指导(传统上通过调整均值实现),其效果也可以等价地表示为对原始无条件噪声预测 ϵθ(xt,t,)\boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \emptyset) 的修正。

Classifier-Free Guidance 正是基于直接调整噪声预测的思想:
它不再依赖外部的分类器提供 g(xt,y)\boldsymbol g(\boldsymbol x_t, \boldsymbol y)。取而代之的是,通过训练策略,模型本身学会了:

  • ϵθ(xt,t,y)\boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \boldsymbol y):预测与条件分布 pt(xty)p_t(\boldsymbol x_t|\boldsymbol y) 相关的噪声。
  • ϵθ(xt,t,)\boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \emptyset):预测与无条件分布 pt(xt)p_t(\boldsymbol x_t) 相关的噪声。

然后,通过以下组合方式构造最终的引导噪声 ϵ^θ(xt,t,y)\hat{\boldsymbol\epsilon}_\theta(\boldsymbol x_t, t, \boldsymbol y)

ϵ^θ(xt,t,y)=ϵθ(xt,t,)+s(ϵθ(xt,t,y)ϵθ(xt,t,))\hat{\boldsymbol\epsilon}_\theta(\boldsymbol x_t, t, \boldsymbol y) = \boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \emptyset) + s \cdot (\boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \boldsymbol y) - \boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \emptyset))

其中:

  • ϵθ(xt,t,y)\boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \boldsymbol y) 是模型在给定当前噪声图像 xt\boldsymbol x_t、时间步 tt 和条件 y\boldsymbol y 下的噪声预测。
  • ϵθ(xt,t,)\boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \emptyset) 是模型在相同 xt\boldsymbol x_ttt 但使用空条件 \emptyset 下的噪声预测(即无条件预测)。
  • ss 是指导强度 (guidance scale),一个超参数。
    • s=0s=0 时,ϵ^θ=ϵθ(xt,t,)\hat{\boldsymbol\epsilon}_\theta = \boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \emptyset),模型进行无条件生成。
    • s=1s=1 时,ϵ^θ=ϵθ(xt,t,y)\hat{\boldsymbol\epsilon}_\theta = \boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \boldsymbol y),模型进行标准的条件生成(没有额外的“强调”)。
    • s>1s > 1 时,模型会更加强调条件 y\boldsymbol y 的特征。差值项 (ϵθ(xt,t,y)ϵθ(xt,t,))(\boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \boldsymbol y) - \boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \emptyset)) 可以被看作是模型学习到的、从无条件到条件的隐式“修正方向”,其作用类似于之前推导中的 1αˉtg(xt,y)-\sqrt{1-\bar{\alpha}_t} \cdot \boldsymbol g(\boldsymbol x_t, \boldsymbol y) 项(但带有不同的缩放因子 ss)。通过 ss 来放大这个修正量,可以更强地引导生成过程。

这个公式可以改写为:

ϵ^θ(xt,t,y)=(1s)ϵθ(xt,t,)+sϵθ(xt,t,y)\hat{\boldsymbol\epsilon}_\theta(\boldsymbol x_t, t, \boldsymbol y) = (1-s)\boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \emptyset) + s \cdot \boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \boldsymbol y)

这种形式更直观地显示了最终的噪声预测是如何由无条件预测和条件预测加权组合而成的。当 s>1s>1 时,它实际上是对条件预测的“外插”(extrapolation)。

采样步骤

在每个逆向采样步骤 tt(从 TT11),执行以下操作:

  1. 获取条件预测: 输入 xt,t,y\boldsymbol x_t, t, \boldsymbol y 到噪声预测网络,得到 ϵθ(xt,t,y)\boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \boldsymbol y)
  2. 获取无条件预测: 输入 xt,t,\boldsymbol x_t, t, \emptyset 到同一个噪声预测网络,得到 ϵθ(xt,t,)\boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \emptyset)
  3. 计算引导噪声: 使用上述公式计算 ϵ^θ(xt,t,y)\hat{\boldsymbol\epsilon}_\theta(\boldsymbol x_t, t, \boldsymbol y)
  4. 计算均值: 使用引导噪声 ϵ^θ(xt,t,y)\hat{\boldsymbol\epsilon}_\theta(\boldsymbol x_t, t, \boldsymbol y) 来计算去噪后的均值 μ^(xt,y,t)\hat{\boldsymbol\mu}(\boldsymbol x_t, \boldsymbol y, t),其方式与标准 DDPM 相同:

    μ^(xt,y,t)=1αt(xt1αt1αˉtϵ^θ(xt,t,y))\hat{\boldsymbol\mu}(\boldsymbol x_t, \boldsymbol y, t) = \frac{1}{\sqrt{\alpha_t}}\left(\boldsymbol x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\hat{\boldsymbol\epsilon}_\theta(\boldsymbol x_t, t, \boldsymbol y)\right)

  5. 采样 xt1\boldsymbol x_{t-1}: 从高斯分布 N(xt1;μ^(xt,y,t),σt2I)\mathcal{N}(\boldsymbol x_{t-1}; \hat{\boldsymbol\mu}(\boldsymbol x_t, \boldsymbol y, t), \sigma_t^2 \mathbf{I}) 中采样得到 xt1\boldsymbol x_{t-1}

算法实现要点

训练阶段

  1. 随机选择时间步 tt
  2. 从数据集中采样 (x0,y)(\boldsymbol x_0, \boldsymbol y)
  3. 以一定概率(例如 puncondp_{\text{uncond}},通常设为 0.1 到 0.2)将条件 y\boldsymbol y 替换为空条件 \emptyset
  4. 根据前向过程加噪:xt=αˉtx0+1αˉtϵ\boldsymbol x_t = \sqrt{\bar{\alpha}_t}\boldsymbol x_0 + \sqrt{1-\bar{\alpha}_t}\boldsymbol\epsilon,其中 ϵN(0,I)\boldsymbol\epsilon \sim \mathcal{N}(0, \mathbf{I})
  5. 训练噪声预测网络 ϵθ\boldsymbol\epsilon_\theta 最小化预测误差:Ex0,y,ϵ,tϵϵθ(xt,t,yused)2\mathbb{E}_{\boldsymbol x_0, \boldsymbol y, \boldsymbol\epsilon, t} ||\boldsymbol\epsilon - \boldsymbol\epsilon_\theta(\boldsymbol x_t, t, \boldsymbol y_{\text{used}})||^2,其中 yused\boldsymbol y_{\text{used}} 是可能被替换为 \emptyset 的条件。

生成阶段

  1. 从标准正态分布采样初始噪声 xTN(0,I)\boldsymbol x_T \sim \mathcal{N}(0, \mathbf{I})

  2. 对于 t=T,T1,,1t = T, T-1, \dots, 1
    a. 按照“指导原理与采样过程”中的步骤计算引导噪声 ϵ^θ(xt,t,y)\hat{\boldsymbol\epsilon}_\theta(\boldsymbol x_t, t, \boldsymbol y)
    b. 使用此引导噪声通过 DDPM 的标准逆向采样公式计算 xt1\boldsymbol x_{t-1} (如上述采样步骤4和5所示)。
    具体地,可以表示为:

    xt1=1αt(xt1αt1αˉtϵ^θ(xt,t,y))+σtzt\boldsymbol x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(\boldsymbol x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\hat{\boldsymbol\epsilon}_\theta(\boldsymbol x_t, t, \boldsymbol y)\right) + \sigma_t \boldsymbol z_t

    ​ 其中 ztN(0,I)\boldsymbol z_t \sim \mathcal{N}(0, \mathbf{I}) (如果 t>1t>1),zt=0\boldsymbol z_t = 0 (如果 t=1t=1)。σt2\sigma_t^2 是预定义的方差。

  3. 最终得到 x0\boldsymbol x_0 即为生成的条件样本。

Classifier-Free Guidance 的训练流程

Classifier-Free Guidance 的采样流程

Classifier-Free Guidance 的优势

相比于 Classifier Guidance,Classifier-Free Guidance 具有以下显著优势:

  1. 无需额外分类器: 最直接的好处是它不需要训练和维护一个独立的分类器模型。这简化了整个流程,减少了对外部组件的依赖。
  2. 更好的对齐性: 由于指导信号来源于扩散模型自身,条件信息和生成过程通常能更好地对齐。分类器可能在与 DDPM 不同的数据分布或噪声水平上训练,导致指导信号并非最优。Classifier-Free Guidance 通过让同一个模型学习条件和无条件行为,避免了这种潜在的不匹配。
  3. 实现简单: 一旦扩散模型按上述方式训练完成,采样时的指导实现非常直接,仅涉及两次模型前向传播和简单的算术组合。
  4. 通常效果更好: 实践表明,Classifier-Free Guidance 往往能产生更高质量、更符合条件的生成结果,尤其是在复杂条件下。

总结

Classifier-Free Guidance 是一种强大且流行的技术,用于在扩散模型中实现可控的条件生成。它通过修改模型的训练策略,使其能够同时进行条件和无条件预测,然后在采样时巧妙地结合这两种预测来强化条件信号,而无需依赖外部的分类器。这种方法不仅简化了条件生成流程,而且通常能带来更好的生成质量和条件一致性,已成为许多先进条件扩散模型的标准配置。