classifier_guidance

Diffusion-based image generators can be overly diverse: sampling is hard to control and often yields images with poor realism (and hence poor FID scores). One possible remedy is conditional generation, i.e. replacing $p(x)$ with $p(x|y)$, turning a DDPM into a conditional DDPM.

A conditional generative model generates content $\boldsymbol x$ given a condition $\boldsymbol y$; in probabilistic terms, it models $p(\boldsymbol x|\boldsymbol y)$.

Classifier Guidance

Note our requirements for guidance:

  • Do not change the training process; we only modify the sampling procedure
  • Guide sampling in an efficient way, rather than bolting a simple classifier onto the end, generating many images, and keeping only the best one

We write the conditioned sampling process as $\hat{q}$. Note that we do not change the forward process ($p$ is untouched, so there is no $\hat{p}$); only the sampling process needs modifying:

$$
q(\boldsymbol x_{t-1}|\boldsymbol x_t) \rightarrow \hat{q}(\boldsymbol x_{t-1}|\boldsymbol x_t, \boldsymbol y) = \frac{\hat{q}(\boldsymbol x_{t-1}|\boldsymbol x_t)\,\hat{q}(\boldsymbol y|\boldsymbol x_{t-1}, \boldsymbol x_t)}{\hat{q}(\boldsymbol y|\boldsymbol x_t)}
$$

At first glance, none of the three probabilities on the right-hand side is known, so we cannot proceed directly. But recall that we want a training-free scheme, so we are free to define:

$$
\begin{aligned}
\hat{q}(\boldsymbol x_{t+1}|\boldsymbol x_t, \boldsymbol y) &= q(\boldsymbol x_{t+1}|\boldsymbol x_t) \\
\hat{q}(\boldsymbol x_0) &= q(\boldsymbol x_0) \\
\hat{q}(\boldsymbol x_{1:T}|\boldsymbol x_0, \boldsymbol y) &= \prod_{t=1}^T \hat{q}(\boldsymbol x_t|\boldsymbol x_{t-1}, \boldsymbol y)
\end{aligned}
$$

The new sampling distribution is:

$$
\begin{aligned}
\hat{q}(\boldsymbol x_{t-1}|\boldsymbol x_t) &= \frac{\hat{q}(\boldsymbol x_t|\boldsymbol x_{t-1})\,\hat{q}(\boldsymbol x_{t-1})}{\hat{q}(\boldsymbol x_t)} \\
\hat{q}(\boldsymbol x_t|\boldsymbol x_{t-1}) &= \int_{\boldsymbol y} \hat{q}(\boldsymbol x_t, \boldsymbol y|\boldsymbol x_{t-1})\, d\boldsymbol y \\
&= \int_{\boldsymbol y} \hat{q}(\boldsymbol x_t|\boldsymbol x_{t-1}, \boldsymbol y)\,\hat{q}(\boldsymbol y|\boldsymbol x_{t-1})\, d\boldsymbol y \\
&= q(\boldsymbol x_t|\boldsymbol x_{t-1})\int_{\boldsymbol y} \hat{q}(\boldsymbol y|\boldsymbol x_{t-1})\, d\boldsymbol y \\
&= q(\boldsymbol x_t|\boldsymbol x_{t-1})
\end{aligned}
$$
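As a quick sanity check of this step, the claim that marginalizing $\boldsymbol y$ out of the assumed forward kernel recovers the unconditional kernel can be verified numerically on a toy discrete chain (the state sizes and distributions below are arbitrary illustrations, not from the source):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy discrete setting: 3 states for x, 2 classes for y.
n_x, n_y = 3, 2

# Unconditional forward kernel q(x_t | x_{t-1}): rows sum to 1.
q_fwd = rng.random((n_x, n_x))
q_fwd /= q_fwd.sum(axis=1, keepdims=True)

# Arbitrary "classifier-like" distribution q_hat(y | x_{t-1}).
q_y_given_prev = rng.random((n_x, n_y))
q_y_given_prev /= q_y_given_prev.sum(axis=1, keepdims=True)

# Defining assumption: q_hat(x_t | x_{t-1}, y) = q(x_t | x_{t-1}) for every y.
q_hat_fwd_given_y = np.stack([q_fwd] * n_y, axis=1)  # shape (x_{t-1}, y, x_t)

# Marginalize y out: q_hat(x_t|x_{t-1}) = sum_y q_hat(x_t|x_{t-1},y) q_hat(y|x_{t-1}).
q_hat_fwd = np.einsum("pyt,py->pt", q_hat_fwd_given_y, q_y_given_prev)

# The conditioned forward kernel collapses back to the unconditional one.
print("q_hat(x_t|x_{t-1}) == q(x_t|x_{t-1}):", np.allclose(q_hat_fwd, q_fwd))
```

The check is deliberately trivial: since the conditional kernel ignores $\boldsymbol y$, the sum over $\boldsymbol y$ just multiplies $q(\boldsymbol x_t|\boldsymbol x_{t-1})$ by 1.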

Next we compute $\hat{q}(\boldsymbol x_t)$:

$$
\begin{aligned}
\hat{q}(\boldsymbol x_t) &= \int_{\boldsymbol x_{0:t-1}} \hat{q}(\boldsymbol x_{0:t})\, d\boldsymbol x_{0:t-1} = \int_{\boldsymbol x_{0:t-1}} \hat{q}(\boldsymbol x_0)\,\hat{q}(\boldsymbol x_{1:t}|\boldsymbol x_0)\, d\boldsymbol x_{0:t-1} \\
\hat{q}(\boldsymbol x_{1:t}|\boldsymbol x_0) &= \int_{\boldsymbol y} \hat{q}(\boldsymbol x_{1:t}, \boldsymbol y|\boldsymbol x_0)\, d\boldsymbol y = \int_{\boldsymbol y} \hat{q}(\boldsymbol y|\boldsymbol x_0)\,\hat{q}(\boldsymbol x_{1:t}|\boldsymbol x_0, \boldsymbol y)\, d\boldsymbol y \\
&= \int_{\boldsymbol y} \hat{q}(\boldsymbol y|\boldsymbol x_0)\prod_{i=1}^t \hat{q}(\boldsymbol x_i|\boldsymbol x_{i-1}, \boldsymbol y)\, d\boldsymbol y \\
&= \int_{\boldsymbol y} \hat{q}(\boldsymbol y|\boldsymbol x_0)\prod_{i=1}^t q(\boldsymbol x_i|\boldsymbol x_{i-1})\, d\boldsymbol y \\
&= \int_{\boldsymbol y} \hat{q}(\boldsymbol y|\boldsymbol x_0)\, q(\boldsymbol x_{1:t}|\boldsymbol x_0)\, d\boldsymbol y = q(\boldsymbol x_{1:t}|\boldsymbol x_0)
\end{aligned}
$$

Remarkably, $\hat{q}(\boldsymbol x_{1:t}|\boldsymbol x_0)=q(\boldsymbol x_{1:t}|\boldsymbol x_0)$, so:

$$
\begin{aligned}
\hat{q}(\boldsymbol x_t) &= \int_{\boldsymbol x_{0:t-1}} \hat{q}(\boldsymbol x_0)\,\hat{q}(\boldsymbol x_{1:t}|\boldsymbol x_0)\, d\boldsymbol x_{0:t-1} \\
&= \int_{\boldsymbol x_{0:t-1}} q(\boldsymbol x_0)\, q(\boldsymbol x_{1:t}|\boldsymbol x_0)\, d\boldsymbol x_{0:t-1} \\
&= \int_{\boldsymbol x_{0:t-1}} q(\boldsymbol x_{0:t})\, d\boldsymbol x_{0:t-1} = q(\boldsymbol x_t)
\end{aligned}
$$
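This marginal identity can also be checked numerically: run a toy discrete chain for a few steps, once unconditionally and once with an auxiliary $\boldsymbol y$ drawn from an arbitrary $\hat{q}(\boldsymbol y|\boldsymbol x_0)$, and compare the marginals (sizes and distributions here are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(1)
n_x, n_y, T = 4, 3, 5

# q(x_0) and the unconditional forward kernel q(x_t | x_{t-1}).
q0 = rng.random(n_x); q0 /= q0.sum()
K = rng.random((n_x, n_x)); K /= K.sum(axis=1, keepdims=True)

# Unconditional marginals: apply the kernel T times to q(x_0).
q_t = q0.copy()
for _ in range(T):
    q_t = q_t @ K

# Conditioned process: draw y from an arbitrary q_hat(y | x_0), but the
# transitions q_hat(x_t | x_{t-1}, y) = q(x_t | x_{t-1}) ignore y entirely,
# so the joint over (x, y) evolves with the same kernel K.
q_y_given_x0 = rng.random((n_x, n_y))
q_y_given_x0 /= q_y_given_x0.sum(axis=1, keepdims=True)
joint = q0[:, None] * q_y_given_x0            # q_hat(x_0, y), shape (n_x, n_y)
for _ in range(T):
    joint = np.einsum("xy,xz->zy", joint, K)  # push x forward, keep y fixed

q_hat_t = joint.sum(axis=1)                   # marginalize y: q_hat(x_t)
print("q_hat(x_t) == q(x_t):", np.allclose(q_hat_t, q_t))
```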

Finally:

$$
\begin{aligned}
\hat{q}(\boldsymbol y|\boldsymbol x_{t-1}, \boldsymbol x_t) &= \frac{\hat{q}(\boldsymbol x_t|\boldsymbol x_{t-1}, \boldsymbol y)\,\hat{q}(\boldsymbol y|\boldsymbol x_{t-1})}{\hat{q}(\boldsymbol x_t|\boldsymbol x_{t-1})} \\
&= \hat{q}(\boldsymbol x_t|\boldsymbol x_{t-1})\,\frac{\hat{q}(\boldsymbol y|\boldsymbol x_{t-1})}{\hat{q}(\boldsymbol x_t|\boldsymbol x_{t-1})} \\
&= \hat{q}(\boldsymbol y|\boldsymbol x_{t-1})
\end{aligned}
$$
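This conditional independence ($\boldsymbol y \perp \boldsymbol x_t$ given $\boldsymbol x_{t-1}$) can likewise be confirmed on a toy discrete joint built from the defining assumptions (all distributions below are arbitrary illustrations):

```python
import numpy as np

rng = np.random.default_rng(2)
n_x, n_y = 3, 2

# Arbitrary q_hat(x_{t-1}), classifier q_hat(y | x_{t-1}), and a forward
# kernel q_hat(x_t | x_{t-1}, y) = q(x_t | x_{t-1}) that does not depend on y.
p_prev = rng.random(n_x); p_prev /= p_prev.sum()
p_y = rng.random((n_x, n_y)); p_y /= p_y.sum(axis=1, keepdims=True)
K = rng.random((n_x, n_x)); K /= K.sum(axis=1, keepdims=True)

# Full joint q_hat(x_{t-1}, y, x_t) = q_hat(x_{t-1}) q_hat(y|x_{t-1}) q(x_t|x_{t-1}).
joint = p_prev[:, None, None] * p_y[:, :, None] * K[:, None, :]

# q_hat(y | x_{t-1}, x_t), computed from the joint by normalizing over y ...
p_y_given_both = joint / joint.sum(axis=1, keepdims=True)

# ... equals q_hat(y | x_{t-1}) for every value of x_t.
print("q_hat(y|x_{t-1}, x_t) == q_hat(y|x_{t-1}):",
      np.allclose(p_y_given_both, p_y[:, :, None]))
```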

This yields the final sampling objective:

$$
\begin{aligned}
\hat{q}(\boldsymbol x_{t-1}|\boldsymbol x_t, \boldsymbol y) &= \frac{\hat{q}(\boldsymbol x_{t-1}|\boldsymbol x_t)\,\hat{q}(\boldsymbol y|\boldsymbol x_{t-1}, \boldsymbol x_t)}{\hat{q}(\boldsymbol y|\boldsymbol x_t)} \\
&= C \times q(\boldsymbol x_{t-1}|\boldsymbol x_t) \times \hat{q}(\boldsymbol y|\boldsymbol x_{t-1})
\end{aligned}
$$

Here $C = 1/\hat{q}(\boldsymbol y|\boldsymbol x_t)$ does not depend on $\boldsymbol x_{t-1}$, so it acts as a normalization constant. The sampling therefore hinges on the factors $q(\boldsymbol x_{t-1}|\boldsymbol x_t)$ and $\hat{q}(\boldsymbol y|\boldsymbol x_{t-1})$: the former is exactly the DDPM reverse step, and the latter is a classifier. We thus need two models: a DDPM predicting $q_\theta(\boldsymbol x_{t-1}|\boldsymbol x_t)$, and a classifier predicting $q_\phi(\boldsymbol y|\boldsymbol x_{t-1})$.
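The derivation above stops at the product form. In Dhariwal and Nichol's paper, for a Gaussian reverse kernel $\mathcal N(\mu, \Sigma)$, this product is sampled approximately by shifting the mean by $s\,\Sigma\,\nabla_x \log p_\phi(y|x)$ evaluated at $\boldsymbol x_t$, where $s$ is a guidance scale. A minimal 1-D sketch of that guided step; the toy logistic classifier and all parameter values are illustrative assumptions, not from the source:

```python
import numpy as np

rng = np.random.default_rng(3)

def classifier_grad(x, w=2.0):
    # Gradient of log p(y=1|x) for a toy classifier p(y=1|x) = sigmoid(w*x):
    # d/dx log sigmoid(w*x) = w * (1 - sigmoid(w*x)).  With a neural
    # classifier this gradient would come from autograd instead.
    return w * (1.0 - 1.0 / (1.0 + np.exp(-w * x)))

def guided_reverse_step(x_t, mu, sigma2, scale=1.0):
    """One classifier-guided reverse step: approximate sampling from
    q(x_{t-1}|x_t) * q_hat(y|x_{t-1}) by shifting the Gaussian mean
    with scale * sigma^2 * grad log p(y|x), evaluated at x_t."""
    mu_guided = mu + scale * sigma2 * classifier_grad(x_t)
    return mu_guided + np.sqrt(sigma2) * rng.standard_normal()

# Toy usage: guidance pushes samples toward the region the classifier
# assigns to y = 1 (positive x for w > 0); unguided samples center at mu = 0.
x_t, mu, sigma2 = 0.0, 0.0, 0.25
samples = np.array(
    [guided_reverse_step(x_t, mu, sigma2, scale=4.0) for _ in range(2000)]
)
print("guided sample mean:", samples.mean())  # shifted above the unguided mean 0
```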

Classifier Free Guidance

Paper link: Classifier-Free Diffusion Guidance