生成式模型统一视角-Diffusion

生成式模型统一视角-Diffusion

Diffusion的优化思路

​ 我们先从整体的框架去了解 diffusion 在干什么,而不是按照论文的公式一步一步来,那样会被各种细节的公式给迷惑住而忘记了整体的模型设计思路,下面的行文中我们省略中间公式的推导

上一次我们推导到了 ELBO 使用的优化的最终形式为:

ELBO=Eq(x1x0)[logp(x0x1)]reconstruction item+DKL(q(xTx0)p(xT))prior matching item+t=2TEq(xTx0)[DKL(pθ(xt1xt)q(xt1xt,x0))]denoising item\text{ELBO}=\underbrace{\mathbb{E}_{q(\boldsymbol{x}_1|\boldsymbol{x}_0)}\left[\log p(\boldsymbol{x}_0|\boldsymbol{x}_1)\right]}_\text{reconstruction item} + \underbrace{D_{KL}(q(\boldsymbol{x}_T|\boldsymbol{x}_0)||p(\boldsymbol{x}_T))}_\text{prior matching item} + \sum_{t=2}^T\underbrace{\mathbb{E}_{q(\boldsymbol{x}_T|\boldsymbol{x}_0)}\left[D_{KL}\left( p_\theta(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) || q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,\boldsymbol{x}_0) \right)\right]}_\text{denoising item}

其中优化的重点在于 denoising item 这一项:

t=2TEq(xTx0)[DKL(pθ(xt1xt)q(xt1xt,x0))]\sum_{t=2}^T\mathbb{E}_{q(\boldsymbol{x}_T|\boldsymbol{x}_0)}\left[D_{KL}\left( p_\theta(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) || q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,\boldsymbol{x}_0) \right)\right]

其中 q(xt1xt,x0)q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,\boldsymbol{x}_0) 为真实去噪分布,这实际上是可以使用

推导出真实去噪分布:

​ 按照加噪过程的条件:

q(xtxt1)=N(xt;αtxt1,(1αt)I)q(x_t|x_{t-1})=\mathcal{N}(x_t; \sqrt{\alpha_t}x_{t-1}, (1-\alpha_t)I)

它的含义是,给定一个 xt1x_{t-1}xtx_t 的分布是完全随机的,类似于随机游走(这个从 SDE 的视角好解释),其中 αt\alpha_t 是超参数,它的值不是学习的,因此后向的去噪过程也可以使用这个参数

​ 使用重参数化的技巧,可以将 xtx_t 使用 x0x_0 表示:

xt=αˉtx0+1αˉtξ0ξ0N(ξ0;0,I)x_t = \sqrt{\bar \alpha_t}x_0 + \sqrt{1 - \bar \alpha_t}\xi_0 \\ \xi_0 \sim \mathcal{N}(\xi_0; 0, I)

​ 那么可以推导出真实的去噪分布:

q(xt1xt,x0)=q(xtxt1,x0)q(xt1x0)q(xt1x0)=N(xt1;αt(1αˉt+1)xt+αˉt1(1αt)x01αˉt,(1αˉt)(1αˉt+1)1αˉtI)define:μq(xt,x0)=αt(1αˉt+1)xt+αˉt1(1αt)x01αˉtdefine:Σq(t)=(1αˉt)(1αˉt+1)1αˉtIq(x_{t-1}|x_t,x_0) =\frac{q(x_t|x_{t-1}, x_0) q(x_{t-1}|x_0)}{q(x_{t-1}|x_0)} \\ = \mathcal{N}(x_{t-1}; \frac{\sqrt{\alpha_t}(1-\bar \alpha_{t+1})x_t + \sqrt{\bar \alpha_{t-1}}(1-\alpha_t)x_0}{1-\bar \alpha_t}, \frac{(1-\bar \alpha_t)(1-\bar \alpha_{t+1})}{1-\bar \alpha_t}I)\\ \text{define:} \quad \mu_q(x_t,x_0)=\frac{\sqrt{\alpha_t}(1-\bar \alpha_{t+1})x_t + \sqrt{\bar \alpha_{t-1}}(1-\alpha_t)x_0}{1-\bar \alpha_t} \\ \text{define:} \quad \Sigma_q(t) = \frac{(1-\bar \alpha_t)(1-\bar \alpha_{t+1})}{1-\bar \alpha_t}I

让模型拟合真实去噪分布:

​ 我们上面推导出了真实的去噪分布,现在我们需要设计模型 pθp_\theta 让他能够拟合这个分布,我们也让模型拟合的分布也设为高斯分布,且方差相同,注意 αt\alpha_t 可以是一个超参数,它不是通过学习得到的

Givenxt1q(xt1xt,x0)Letxt1pθ(xt1xt)=N(xt1;μθ,Σθ(t))Σθ(t)=Σq(t)\text{Given} \quad x_{t-1}\sim q(x_{t-1}|x_t,x_0) \\ \text{Let} \quad x_{t-1} \sim p_\theta(x_{t-1}|x_t) = \mathcal{N}(x_{t-1}; \mu_\theta, \Sigma_\theta(t))\\ \Sigma_\theta(t) = \Sigma_q(t)

  • 之所以能让方差相等,是因为 Σq(t)\Sigma_q(t) 中的表达式是超参数,我们在即使没有前向过程的时候是不知道的,但是 μq\mu_q 的解析式中含有 x0x_0,这一项是在没有前向过程的时候是不知道的,因此让方差相同而均值不同

​ 因此可以让 denoising item 的优化目标变为:

argminθDKL(q(xt1xt,x0)pθ(xt1xt))=argminθ12σq2(t)[μθμq22]\arg \min_\theta D_{KL}\left( q(x_{t-1}|x_t,x_0) || p_\theta(x_{t-1}|x_t) \right) \\ =\arg \min_\theta \frac{1}{2\sigma_q^2(t)}\left[ \Vert \mu_\theta - \mu_q \Vert_2^2 \right]

因此模型只用去拟合一个向量 μθ\mu_\theta,现在就引出了另外一个问题,如何设计模型,尽可能最大限度地利用模型的表达能力,让模型去拟合 μθ\mu_\theta 呢?

设定 μθ\mu_\theta

μq(xt,x0)=αt(1αˉt+1)xt+αˉt1(1αt)x01αˉt\mu_q(x_t,x_0)=\frac{\sqrt{\alpha_t}(1-\bar \alpha_{t+1})x_t + \sqrt{\bar \alpha_{t-1}}(1-\alpha_t)x_0}{1-\bar \alpha_t}

  • 首先,无论什么设定,我们需要注意的一点,由于 generative model 的目的是生成新的样本,因此我们需要一个可以只进行反向过程的模型,这意味着我们不能在时间 tt 时刻出现在时间 tt 之前的变量,也就是 x0,x1x_0,x_1\dots

设定一:直接去拟合 x0x_0

​ 我们使用 μq\mu_q 相同的格式,只去改变 x0x_0 项:

μθ=αt(1αˉt+1)xt+αˉt1(1αt)x^θ(xt,t)1αˉt\mu_\theta =\frac{\sqrt{\alpha_t}(1-\bar \alpha_{t+1})x_t + \sqrt{\bar \alpha_{t-1}}(1-\alpha_t) \textcolor{red}{\hat{x}_\theta(x_t,t)}}{1-\bar \alpha_t}

则在这种设定下,模型的输入为 xtx_t 和时间步 tt,让 μθ\mu_\theta 去拟合 μq\mu_q 等效为:

argminθ12σq2(t)αˉt1(1αt)2(1αˉt)2(x^θ(xt,t)x022)\arg \min_\theta \frac{1}{2\sigma_q^2(t)}\frac{\bar\alpha_{t-1}(1-\alpha_t)^2}{(1-\bar\alpha_t)^2 } \left( \Vert \hat x_\theta(x_t, t) - x_0\Vert_2^2 \right)

但是实际上没人会用这种设定方式,我们下面分析不使用它的原因

  • 很荒谬的一点是,我们的初衷本来是想通过 tt 步采样得到我们的 x0x_0,但是我们每一步去噪的时候,去拟合 μq\mu_q 的同时我们需要去预测一个 x0x_0,如果能够直接单步采样出很好的 x0x_0,那我们要多步采样还干嘛,直接使用 GAN 之类的模型就好了
  • 记住神经网络虽然具有万能逼近定理,但是不是所有的函数的拟合难度都是一样的,我们要尽可能去利用神经网络的表达能力,个人的经验是,神经网络能够很好地去拟合需要微调的量,例如 diffusion 的 noise 和 deformable convolution 的 offset 等

设定二:去拟合 ξ0\xi_0

​ 实际上,我们可以让 xtx_t 重参数化为:

xt=αˉtx0+1αˉtξ0ξ0N(ξ0;0,I)x_t = \sqrt{\bar \alpha_t}x_0 + \sqrt{1 - \bar \alpha_t}\xi_0 \\ \xi_0 \sim \mathcal{N}(\xi_0; 0, I)

​ 则 x0x_0 项可以用 xtx_tξ0\xi_0 表示:

x0=xt1αˉtξ0dˉtμq(xt,x0)=1dtxt1αt1αˉtdtξ0x_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\xi_0}{\sqrt{\bar d_t}} \\ \mu_q(x_t, x_0)=\frac{1}{\sqrt{d_t}}x_t - \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}\sqrt{d_t}}\xi_0

那么我们让模型也用相同的形式拟合:

μθ=μq(xt,x0)=1dtxt1αt1αˉtdtξˉθ(xt,t)\mu_\theta =\mu_q(x_t, x_0)=\frac{1}{\sqrt{d_t}}x_t - \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}\sqrt{d_t}} \textcolor{red}{\bar\xi_\theta(x_t, t)}

​ 因此优化目标可以变为:

argminθ12σq2(t)αˉt1(1αt)2(1αˉt)2(ξ^θ(xt,t)ξ022)\arg \min_\theta \frac{1}{2\sigma_q^2(t)}\frac{\bar\alpha_{t-1}(1-\alpha_t)^2}{(1-\bar\alpha_t)^2 } \left( \Vert \hat \xi_\theta(x_t, t) - \xi_0 \Vert_2^2 \right)

  • 到这里我们可能会好奇,为什么模型能够拟合一个随机变量?让模型去拟合一个随机变量的分布这个说法听起来十分抽象,但是实际上 ξ0\xi_0 的值在反向过程中是确定的,它是在正向过程中的时候采样得到的,因此确定,即使我们在训练的时候没有正向过程,我们仍然可以认为那张原图本身是存在的,只是模型没有正向过程去采样罢了

The sculpture is already complete within the marble block before I start my work. It is already there, I just have to chisel away the superfluous material.

​ ——Michelangelo

设定三:Tweedie 公式

​ 使用 Tweedie 公式,我们可以将上面的 x0x_0 写成:

x0=xt1αˉtlogp(x)dˉtx_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\nabla\log p(x) }{\sqrt{\bar d_t}}

这时候的模型就是基于分数的模型了,这个的原因将在后面的博客内解释;因此我们可以将 μq\mu_qμθ\mu_\theta 的解析式写为:

μq(xt,x0)=1dtxt1αt1αˉtdtlogp(x)μq(xt,x0)=1dtxt1αt1αˉtdtSθ(xt,t)\mu_q(x_t, x_0)=\frac{1}{\sqrt{d_t}}x_t - \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}\sqrt{d_t}}\nabla \log p(x) \\ \Rightarrow \mu_q(x_t, x_0)=\frac{1}{\sqrt{d_t}}x_t - \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}\sqrt{d_t}} \textcolor{red}{S_\theta(x_t,t)}

​ 因此优化目标变为:

argminθ12σq2(t)αˉt1(1αt)2(1αˉt)2(Sθ(xt,t)logp(x)22)\arg \min_\theta \frac{1}{2\sigma_q^2(t)}\frac{\bar\alpha_{t-1}(1-\alpha_t)^2}{(1-\bar\alpha_t)^2 } \left( \Vert S_\theta(x_t, t) - \nabla \log p(x) \Vert_2^2 \right)

计算细节:

​ 打 latex 太繁琐了,请查阅论文:Understanding Diffusion Models: A Unified Perspective

实现网络:

最终去噪网络的实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from torch import nn
import torch
import math


class Block(nn.Module):
def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
super().__init__()
self.time_mlp = nn.Linear(time_emb_dim, out_ch)
if up:
self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
else:
self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
self.bnorm1 = nn.BatchNorm2d(out_ch)
self.bnorm2 = nn.BatchNorm2d(out_ch)
self.relu = nn.ReLU()

def forward(
self, x, t,
):
h = self.bnorm1(self.relu(self.conv1(x)))
# Time embedding
time_emb = self.relu(self.time_mlp(t))
# 使用 position embedding 的方式进行 time embedding
time_emb = time_emb[(...,) + (None,) * 2]
h = h + time_emb
h = self.bnorm2(self.relu(self.conv2(h)))
# Down or Upsample
return self.transform(h)


class SinusoidalPositionEmbeddings(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim

def forward(self, time):
device = time.device
half_dim = self.dim // 2
embeddings = math.log(10000) / (half_dim - 1)
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
embeddings = time[:, None] * embeddings[None, :]
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
return embeddings


class SimpleUnet(nn.Module):
def __init__(self):
super().__init__()
image_channels = 3
down_channels = (64, 128, 256, 512, 1024)
up_channels = (1024, 512, 256, 128, 64)
out_dim = 3
time_emb_dim = 32

# Time embedding
self.time_mlp = nn.Sequential(
SinusoidalPositionEmbeddings(time_emb_dim),
nn.Linear(time_emb_dim, time_emb_dim),
nn.ReLU()
)

# Initial projection
self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

# Downsample
self.downs = nn.ModuleList(
[Block(down_channels[i], down_channels[i + 1], time_emb_dim) for i in range(len(down_channels) - 1)]
)
# Upsample
self.ups = nn.ModuleList(
[Block(up_channels[i], up_channels[i + 1], time_emb_dim, up=True) for i in range(len(up_channels) - 1)]
)

self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

def forward(self, x, timestep):
# Embedd time
t = self.time_mlp(timestep)
# Initial conv
x = self.conv0(x)

# Unet
residual_inputs = []
for down in self.downs:
x = down(x, t)
residual_inputs.append(x)
for up in self.ups:
residual_x = residual_inputs.pop()
x = torch.cat((x, residual_x), dim=1)
x = up(x, t)
return self.output(x)


if __name__ == "__main__":
model = SimpleUnet()
print("Num params: ", sum(p.numel() for p in model.parameters()))
img_size = 64
img = torch.randn((1, 3, img_size, img_size), device='cpu')
t = torch.tensor([4], device='cpu')
denoise = model(img, t)
print(denoise.shape)

训练代码如下:

​ 在训练扩散模型时,每次训练的输入并不是从一张图片出发,经过所有去噪步骤生成的一系列中间结果,而是将每一步加噪后的图像视为独立的样本进行训练。训练过程中会随机选择时间步 tt,并对每张图像进行加噪,生成带有噪声的图像 xtx_t 和对应的噪声 ϵ\epsilon,然后使用这些加噪后的图像及其对应的时间步 tt 来训练模型

  • 如果对一张图片进行所有去噪步骤并将其作为一个序列来训练,会导致训练过程非常冗长和复杂

  • 随机选择时间步的方法使得训练过程更加高效,因为每个批次的数据都是独立的,可以并行处理

  • 通过在每个时间步上独立训练,模型可以更稳定地学习去噪任务,避免了序列训练中可能出现的累积误差

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from forward_noising import forward_diffusion_sample
from unet import SimpleUnet
from dataloader import load_transformed_dataset
import torch.nn.functional as F
import torch
from torch.optim import Adam


def get_loss(model, x_0, t, device):
x_noisy, noise = forward_diffusion_sample(x_0, t, device)
noise_pred = model(x_noisy, t)
return F.mse_loss(noise, noise_pred)


if __name__ == "__main__":
model = SimpleUnet()
T = 300
BATCH_SIZE = 128
epochs = 100

dataloader = load_transformed_dataset(batch_size=BATCH_SIZE)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
optimizer = Adam(model.parameters(), lr=0.001)

for epoch in range(epochs):
for batch_idx, (batch, _) in enumerate(dataloader):
optimizer.zero_grad()

t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
loss = get_loss(model, batch, t, device=device)
loss.backward()
optimizer.step()


torch.save(model.state_dict(), "./trained_models/ddpm_mse_epochs_100.pth")