sde

​ 上一篇文章中我们介绍了 score-based model 的基本概念,包括其如何对分布进行建模、如何从建模的分布中进行采样以及通过对分布进行扰动提高其建模精度的方式。在这篇文章中我们将介绍的是如何使用随机微分方程(也就是 SDE)进行 score-based 建模。

本文将score-based model和DDPM进行大一统。虽然SDE构造扩散模型是一个全新的框架,但是其本质还是score。

随机微分方程简介

​ 首先我们先介绍一些随机微分方程的基本知识。 我们首先举一个常微分方程(ODE)的例子,例如下面的一个常微分方程:

dxdt=f(x,t)ordx=f(x,t)dt\frac{d\boldsymbol x}{dt} = f(\boldsymbol x, t) \quad \text{or} \quad d\boldsymbol x = f(\boldsymbol x, t) dt

其中 f(x,t)f(x, t) 描述了 xx 随时间的变化趋势,这个常微分方程可以得到解析解:

x(t)=x(0)+0tf(x,τ)dτ\boldsymbol x(t) = \boldsymbol x(0) + \int_0^t f(\boldsymbol x, \tau) d\tau

然而在计算机中,只能通过类似前向欧拉法等数值方法来迭代得到数值近似解:

x(t+Δt)x(t)+f(x(t),t)Δt\boldsymbol x(t + \Delta t) \approx \boldsymbol x(t) + f(\boldsymbol x(t), t) \Delta t

常微分方程描述了一个确定性的过程,而对于非确定性的过程(比如从分布中采样),则需要使用随机微分方程(SDE)进行描述。随机微分方程相比于常微分方程只是在形式上多了一个高斯噪声:

dxdt=f(x,t)漂移系数+g(t)wt扩散系数ordxt=f(xt,t)dt+g(t)dwt\frac{d\boldsymbol x}{dt} = \underbrace{f(\boldsymbol x, t)}_{漂移系数} + \underbrace{g(t) \boldsymbol w_t}_{扩散系数} \quad \text{or} \quad dx_t = f(\boldsymbol x_t, t) dt + g(t) d\boldsymbol w_t

​ 其中 $ \omega_t $ 表示布朗运动(Brownian motion),也称为维纳过程(Wiener process)。它是描述随机扰动的核心元素,具有以下性质:

  1. 独立增量:任意时刻的增量 $ \omega_{t+\Delta t} - \omega_t $ 与之前的所有增量独立。
  2. 正态分布:增量服从均值为0、方差为 $ \Delta t $ 的正态分布,即 $ \omega_{t+\Delta t} - \omega_t \sim \mathcal{N}(0, \Delta t) $。
  3. 连续路径:$ \omega_t $ 的样本路径几乎处处连续,但不可导

dwtd\boldsymbol w_t 的积分对应伊藤积分(Ito integral),用于建模随机波动对系统的动态影响。在采样时 SDE 和 ODE 类似,也可以进行迭代采样:

x(t+Δt)x(t)+f(x(t),t)Δt+g(t)ΔtN(0,I)\boldsymbol x(t + \Delta t) \approx \boldsymbol x(t) + f(\boldsymbol x(t), t) \Delta t + g(t) \sqrt{\Delta t} \mathcal{N}(0, I)

由于采样过程中存在高斯噪声,进行多次采样会得到不同的轨迹,如下边右图中的一系列绿色折线所示

ode&sde

​ 我们不加证明地给出一个结论(数学家干的事情),一个满足以上条件的 SDE 过程的逆向过程也是一个 SDE,且正向与逆向过程满足条件:

dx=f(x,t)dt+g(t)dwdx=[f(x,t)g2(t)xlogpt(x)]dt+g(t)dw\begin{aligned} d\boldsymbol x &= f(\boldsymbol x,t)dt + g(t)d\boldsymbol w \\ d\boldsymbol x &= \left[ f(\boldsymbol x,t) - g^2(t) {\color{red} \nabla_x \log p_t(x)} \right] dt + g(t)d\boldsymbol w \end{aligned}

离散形式:

xn+1=xn+f(xn,tn)Δt+g(tn)Δwnxn1=xn+[f(xn,tn)g2(tn)xlogptn(xn)]Δt+g(tn)Δtϵn\begin{aligned} \boldsymbol x_{n+1} = \boldsymbol x_n &+ f(\boldsymbol x_n, t_n)\Delta t + g(t_n)\Delta \boldsymbol w_n \\ \boldsymbol x_{n-1} = \boldsymbol x_n &+ \left[ f(\boldsymbol x_n, t_n) - g^2(t_n) {\color{red} \nabla_x \log p_{t_n}(x_n)} \right]\Delta t + g(t_n)\sqrt{\Delta t}\boldsymbol \epsilon_n \end{aligned}

对于逆向过程,从 t=Tt=T 时刻到 t=0t=0 时刻,pt(x)p_t(\boldsymbol x) 表示 x(t)\boldsymbol x(t) 的概率密度函数(时刻 TT 就是随机噪声,时刻 00 就是真实数据分布)。红色的部分,在上节 score-based models 中,就是 score function,不过那里的 score function 的定义中概率分布不会随时间变化 s(x)=xlogp(x)s(\boldsymbol x) = \nabla_x\log p(\boldsymbol x) (其实 p(x)p(\boldsymbol x) 也是变化的,它受加噪强度的影响,因此也不是固定的)

​ 如果我们想从随机噪声(prior distribution)中采样出一个真实图片(data distribution),我们只需要知道逆向扩散过程的参数就行,对于 f(x,t)f(\boldsymbol x,t)g(t)g(t) 都是已知的,因为这几个参数在前向扩散过程中是人为设置的,因此我们只需要 xlogpt(x)\nabla_x \log p_t(\boldsymbol x) 就可以完全求解逆向过程了,因此 SDE 模型的重点就是估计 score

img

连续的加噪过程

​ 因此对于 SDE 模型,我们可以类似 score matching 的方式对模型进行训练,优化的目标为:

J(θ)=EtU(0,T)Ept(x)[λ(t)xlogpt(x)sθ(x,t)2]J(\theta) = \mathbb{E}_{t \in \mathcal{U}(0,T)} \mathbb{E}_{p_t(\boldsymbol x)} \left[ \lambda(t) \Vert \nabla_x \log p_t(\boldsymbol x) - s_\theta(\boldsymbol x, t) \Vert^2 \right]

SDE 与 DDPM

DDPM 的前向 SDE

​ 通过上文的介绍,我们可以发现用 SDE 描述的 score-based model 和扩散模型有很多相似之处。在 DDPM中,前向过程可以描述为以下形式:

xt=1βtxt1+βtϵt1,ϵt1N(0,I)\boldsymbol{x}_t=\sqrt{1-\beta_t}\boldsymbol{x}_{t-1}+\sqrt{\beta_t}\epsilon_{t-1},\quad\epsilon_{t-1}\sim\mathcal{N}(0,I)

这是一个离散的过程,t{0,1,,N}t\in\{0,1,\cdots,N\}。由于 SDE 是连续的,需要将 DDPM 也转变为连续的形式,为此可以将所有时间步都除以 TT,即 t{0,1N,,N1N,1}t\in\{0,\frac 1 N,\cdots,\frac{N-1}N,1\}T=0T=0 时刻是初始时刻,数据分布为真实分布,t=1t=1 时刻是数据完全变为噪声的时刻,当 TT\to\infty,DDPM 就变成了一个连续的过程。代入上式(Δt=1/N\Delta t= 1/N),可以得到:

x(t+Δt)=1β(t+Δt)x(t)+β(t+Δt)ϵ(t)=1β(t+Δt)Δtx(t)+β(t+Δt)Δtϵ(t)(112β(t+Δt)Δt)x(t)+β(t+Δt)Δtϵ(t)x(t)12β(t+Δt)Δtx(t)+β(t)Δtϵ(t)\begin{aligned} \boldsymbol{x}(t+\Delta t)=& \sqrt{1-\beta_{(t+\Delta t)}} \boldsymbol{x}(t)+\sqrt{\beta_{(t+\Delta t)}}\epsilon(t) \\ =& \sqrt{1-\beta(t+\Delta t)\Delta t}\boldsymbol{x}(t)+\sqrt{\beta(t+\Delta t)\Delta t}\epsilon(t) \\ \approx & (1-\frac 12\beta(t+\Delta t) \Delta t)\boldsymbol x(t) +\sqrt{\beta(t+\Delta t)\Delta t}\epsilon(t) \\ \approx& \boldsymbol x(t) - \frac 12\beta(t+\Delta t) \Delta t \boldsymbol x(t) + \sqrt{\beta(t)\Delta t}\epsilon(t) \end{aligned}

其中 β(t)=βt×N\beta(t) = \beta_t \times N写为差分形式:

x(t+Δt)x(t)=12β(t+Δt)Δtx(t)+β(t)Δtϵ(t)dx=12β(t)x(t)dt+β(t)dwf(x,t)=12β(t)x(t),g(t)=β(t)\boldsymbol x(t+\Delta t) - \boldsymbol x(t) = -\frac 12\beta(t+\Delta t) \Delta t \boldsymbol x(t) + \sqrt{\beta(t)\Delta t}\epsilon(t) \\ d\boldsymbol x = -\frac 12\beta(t)\boldsymbol x(t) dt + \sqrt{\beta(t)} d\boldsymbol w \\ f(\boldsymbol x, t) = -\frac12 \beta(t)\boldsymbol x(t), \quad g(t)=\sqrt{\beta(t)}

DDPM 的逆向 SDE

则逆向扩散过程的 SDE 形式为:

dx=[12β(t)x(t)β(t)sθ(t)]dt+β(t)dwd\boldsymbol x = \left[-\frac 12\beta(t)\boldsymbol x(t) -\beta(t)s_\theta(t) \right]dt + \sqrt{\beta(t)} d\boldsymbol w

其中 score function 表达式为:

sθ(t)=xtlogp(xtx0)=xt(xtμt)22σt2=2(xtμt)2σt2=xtμtσt2s_\theta(t)=\nabla_{\boldsymbol x_t}\log p(\boldsymbol x_t|\boldsymbol x_0)=-\nabla_{\boldsymbol x_t} \frac{(\boldsymbol x_t-\boldsymbol \mu_t)^2}{2\sigma_t^2}=-\frac{2(\boldsymbol x_t-\boldsymbol \mu_t)}{2\sigma_t^2}=-\frac{\boldsymbol x_t-\mu_t}{\sigma^2_t}

xt=αtˉx0+1αtˉϵμt=αtˉx0,σt=1αtˉ\boldsymbol x_t = \sqrt{\bar{\alpha_t}}\boldsymbol x_0 + \sqrt{1-\bar{\alpha_t}}\boldsymbol \epsilon \\ \Rightarrow \boldsymbol \mu_t = \sqrt{\bar{\alpha_t}}\boldsymbol x_0, \quad \boldsymbol\sigma_t = \sqrt{1-\bar{\alpha_t}}

s(t)=xtμtσt2=xtαtx01αtˉ=ϵ1αtˉs(t)=-\frac{\boldsymbol x_t-\mu_t}{\sigma^2_t} = -\frac{\boldsymbol x_t-\sqrt{\boldsymbol \alpha_t}\boldsymbol x_0}{1-\bar {\boldsymbol{\alpha_t}}} = -\frac{\boldsymbol \epsilon} {\sqrt{1-\bar {\boldsymbol \alpha_t}}}

即模型等效预测的 score 和等效预测的噪声有如下的等价关系

s_\theta(t) = -\frac{\boldsymbol \epsilon_\theta} {\sqrt{1-\bar {\boldsymbol{\alpha_t}}}} \\ s(t) = -\frac{\boldsymbol \epsilon} {\sqrt{1-\bar {\boldsymbol{\alpha_t}}}} \\\

回顾 DDPM 的采样过程:

x(t)=11β(t+1)(x(t+1)β(t+1)1αˉ(t+1)ϵ0)+β(t+1)z=11β(t+1)(x(t+1)β(t+1)sθ(t+1))+β(t+1)z(1+12β(t+1))(x(t+1)β(t+1)sθ(t+1))+β(t+1)z=x(t+1)+12β(t+1)x(t+1)(β(t+1)+12β2(t+1))sθ(t+1)+β(t+1)z=x(t+1)+f(x(t+1),t)g2(t+1)sθ(t+1)12β2(t+1)sθ(t+1)+β(t+1)zx(t+1)+f(x(t+1),t)g2(t+1)sθ(t+1)+g(t+1)z\begin{aligned} \boldsymbol x(t) =& \frac {1}{\sqrt{1-\beta(t+1)}}\left(\boldsymbol x(t+1) - \frac{\beta(t+1)}{\sqrt{1-\bar{\alpha}(t+1)}} \boldsymbol \epsilon_0\right) + \sqrt{\beta(t+1)}\boldsymbol z \\ =& \frac {1}{\sqrt{1-\beta(t+1)}}\left(\boldsymbol x(t+1) - \beta(t+1)s_\theta(t+1) \right) + \sqrt{\beta(t+1)}\boldsymbol z \\ \approx& \left(1+\frac 12 \beta(t+1) \right) \left(\boldsymbol x(t+1) - \beta(t+1)s_\theta(t+1) \right) + \sqrt{\beta(t+1)}\boldsymbol z \\ =& \boldsymbol x(t+1) +\frac 12 \beta(t+1) \boldsymbol x(t+1) - \left( \beta(t+1)+ \frac12 \beta^2(t+1) \right) s_\theta(t+1) + \sqrt{\beta(t+1)}\boldsymbol z \\ =& \boldsymbol x(t+1) + f(\boldsymbol x(t+1), t) - g^2(t+1)s_\theta(t+1) - \frac12 \beta^2(t+1) s_\theta(t+1) + \sqrt{\beta(t+1)}\boldsymbol z \\ \approx& \boldsymbol x(t+1) + f(\boldsymbol x(t+1), t) - g^2(t+1)s_\theta(t+1) + g(t+1) \boldsymbol z \end{aligned}

对比上面逆向 SDE 的离散形式,可以发现就是一种统一的形式:

xn1=xn+[f(xn,tn)g2(tn)xlogptn(xn)]Δt+g(tn)Δtϵn\boldsymbol x_{n-1} = \boldsymbol x_n + \left[ f(\boldsymbol x_n, t_n) - g^2(t_n) {\color{red} \nabla_x \log p_{t_n}(x_n)} \right]\Delta t + g(t_n)\sqrt{\Delta t}\boldsymbol \epsilon_n

推导到这里可以发现,

SDE 与 SMLD