生成式模型统一视角-Diffusion
Diffusion的优化思路
我们先从整体的框架去了解 diffusion 在干什么,而不是按照论文的公式一步一步来,那样会被各种细节的公式给迷惑住而忘记了整体的模型设计思路,下面的行文中我们省略中间公式的推导
上一次我们推导到了 ELBO 使用的优化的最终形式为:
ELBO = E q ( x 1 ∣ x 0 ) [ log p ( x 0 ∣ x 1 ) ] ⏟ reconstruction item + D K L ( q ( x T ∣ x 0 ) ∣ ∣ p ( x T ) ) ⏟ prior matching item + ∑ t = 2 T E q ( x T ∣ x 0 ) [ D K L ( p θ ( x t − 1 ∣ x t ) ∣ ∣ q ( x t − 1 ∣ x t , x 0 ) ) ] ⏟ denoising item \text{ELBO}=\underbrace{\mathbb{E}_{q(\boldsymbol{x}_1|\boldsymbol{x}_0)}\left[\log p(\boldsymbol{x}_0|\boldsymbol{x}_1)\right]}_\text{reconstruction item} +
\underbrace{D_{KL}(q(\boldsymbol{x}_T|\boldsymbol{x}_0)||p(\boldsymbol{x}_T))}_\text{prior matching item} +
\sum_{t=2}^T\underbrace{\mathbb{E}_{q(\boldsymbol{x}_T|\boldsymbol{x}_0)}\left[D_{KL}\left( p_\theta(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) || q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,\boldsymbol{x}_0) \right)\right]}_\text{denoising item}
ELBO = reconstruction item E q ( x 1 ∣ x 0 ) [ log p ( x 0 ∣ x 1 ) ] + prior matching item D K L ( q ( x T ∣ x 0 ) ∣ ∣ p ( x T ) ) + t = 2 ∑ T denoising item E q ( x T ∣ x 0 ) [ D K L ( p θ ( x t − 1 ∣ x t ) ∣ ∣ q ( x t − 1 ∣ x t , x 0 ) ) ]
其中优化的重点在于 denoising item 这一项:
∑ t = 2 T E q ( x T ∣ x 0 ) [ D K L ( p θ ( x t − 1 ∣ x t ) ∣ ∣ q ( x t − 1 ∣ x t , x 0 ) ) ] \sum_{t=2}^T\mathbb{E}_{q(\boldsymbol{x}_T|\boldsymbol{x}_0)}\left[D_{KL}\left( p_\theta(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) || q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,\boldsymbol{x}_0) \right)\right]
t = 2 ∑ T E q ( x T ∣ x 0 ) [ D K L ( p θ ( x t − 1 ∣ x t ) ∣ ∣ q ( x t − 1 ∣ x t , x 0 ) ) ]
其中 q ( x t − 1 ∣ x t , x 0 ) q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,\boldsymbol{x}_0) q ( x t − 1 ∣ x t , x 0 ) 为真实去噪分布,这实际上是可以使用
推导出真实去噪分布:
按照加噪过程的条件:
q ( x t ∣ x t − 1 ) = N ( x t ; α t x t − 1 , ( 1 − α t ) I ) q(x_t|x_{t-1})=\mathcal{N}(x_t; \sqrt{\alpha_t}x_{t-1}, (1-\alpha_t)I)
q ( x t ∣ x t − 1 ) = N ( x t ; α t x t − 1 , ( 1 − α t ) I )
它的含义是,给定一个 x t − 1 x_{t-1} x t − 1 ,x t x_t x t 的分布是完全随机的,类似于随机游走(这个从 SDE 的视角好解释),其中 α t \alpha_t α t 是超参数,它的值不是学习的,因此后向的去噪过程也可以使用这个参数
使用重参数化的技巧,可以将 x t x_t x t 使用 x 0 x_0 x 0 表示:
x t = α ˉ t x 0 + 1 − α ˉ t ξ 0 ξ 0 ∼ N ( ξ 0 ; 0 , I ) x_t = \sqrt{\bar \alpha_t}x_0 + \sqrt{1 - \bar \alpha_t}\xi_0 \\
\xi_0 \sim \mathcal{N}(\xi_0; 0, I)
x t = α ˉ t x 0 + 1 − α ˉ t ξ 0 ξ 0 ∼ N ( ξ 0 ; 0 , I )
那么可以推导出真实的去噪分布:
q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) q ( x t − 1 ∣ x 0 ) = N ( x t − 1 ; α t ( 1 − α ˉ t + 1 ) x t + α ˉ t − 1 ( 1 − α t ) x 0 1 − α ˉ t , ( 1 − α ˉ t ) ( 1 − α ˉ t + 1 ) 1 − α ˉ t I ) define: μ q ( x t , x 0 ) = α t ( 1 − α ˉ t + 1 ) x t + α ˉ t − 1 ( 1 − α t ) x 0 1 − α ˉ t define: Σ q ( t ) = ( 1 − α ˉ t ) ( 1 − α ˉ t + 1 ) 1 − α ˉ t I q(x_{t-1}|x_t,x_0) =\frac{q(x_t|x_{t-1}, x_0) q(x_{t-1}|x_0)}{q(x_{t-1}|x_0)} \\
= \mathcal{N}(x_{t-1}; \frac{\sqrt{\alpha_t}(1-\bar \alpha_{t+1})x_t + \sqrt{\bar \alpha_{t-1}}(1-\alpha_t)x_0}{1-\bar \alpha_t}, \frac{(1-\bar \alpha_t)(1-\bar \alpha_{t+1})}{1-\bar \alpha_t}I)\\
\text{define:} \quad \mu_q(x_t,x_0)=\frac{\sqrt{\alpha_t}(1-\bar \alpha_{t+1})x_t + \sqrt{\bar \alpha_{t-1}}(1-\alpha_t)x_0}{1-\bar \alpha_t} \\
\text{define:} \quad \Sigma_q(t) = \frac{(1-\bar \alpha_t)(1-\bar \alpha_{t+1})}{1-\bar \alpha_t}I
q ( x t − 1 ∣ x t , x 0 ) = q ( x t − 1 ∣ x 0 ) q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) = N ( x t − 1 ; 1 − α ˉ t α t ( 1 − α ˉ t + 1 ) x t + α ˉ t − 1 ( 1 − α t ) x 0 , 1 − α ˉ t ( 1 − α ˉ t ) ( 1 − α ˉ t + 1 ) I ) define: μ q ( x t , x 0 ) = 1 − α ˉ t α t ( 1 − α ˉ t + 1 ) x t + α ˉ t − 1 ( 1 − α t ) x 0 define: Σ q ( t ) = 1 − α ˉ t ( 1 − α ˉ t ) ( 1 − α ˉ t + 1 ) I
让模型拟合真实去噪分布:
我们上面推导出了真实的去噪分布,现在我们需要设计模型 p θ p_\theta p θ 让他能够拟合这个分布,我们也让模型拟合的分布也设为高斯分布,且方差相同 ,注意 α t \alpha_t α t 可以是一个超参数,它不是通过学习得到的
Given x t − 1 ∼ q ( x t − 1 ∣ x t , x 0 ) Let x t − 1 ∼ p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ , Σ θ ( t ) ) Σ θ ( t ) = Σ q ( t ) \text{Given} \quad x_{t-1}\sim q(x_{t-1}|x_t,x_0) \\
\text{Let} \quad x_{t-1} \sim p_\theta(x_{t-1}|x_t) = \mathcal{N}(x_{t-1}; \mu_\theta, \Sigma_\theta(t))\\
\Sigma_\theta(t) = \Sigma_q(t)
Given x t − 1 ∼ q ( x t − 1 ∣ x t , x 0 ) Let x t − 1 ∼ p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ , Σ θ ( t ) ) Σ θ ( t ) = Σ q ( t )
之所以能让方差相等,是因为 Σ q ( t ) \Sigma_q(t) Σ q ( t ) 中的表达式是超参数,我们在即使没有前向过程的时候是不知道的,但是 μ q \mu_q μ q 的解析式中含有 x 0 x_0 x 0 ,这一项是在没有前向过程的时候是不知道的,因此让方差相同而均值不同
因此可以让 denoising item 的优化目标变为:
arg min θ D K L ( q ( x t − 1 ∣ x t , x 0 ) ∣ ∣ p θ ( x t − 1 ∣ x t ) ) = arg min θ 1 2 σ q 2 ( t ) [ ∥ μ θ − μ q ∥ 2 2 ] \arg \min_\theta D_{KL}\left( q(x_{t-1}|x_t,x_0) || p_\theta(x_{t-1}|x_t) \right) \\
=\arg \min_\theta \frac{1}{2\sigma_q^2(t)}\left[ \Vert \mu_\theta - \mu_q \Vert_2^2 \right]
arg θ min D K L ( q ( x t − 1 ∣ x t , x 0 ) ∣ ∣ p θ ( x t − 1 ∣ x t ) ) = arg θ min 2 σ q 2 ( t ) 1 [ ∥ μ θ − μ q ∥ 2 2 ]
因此模型只用去拟合一个向量 μ θ \mu_\theta μ θ ,现在就引出了另外一个问题,如何设计模型,尽可能最大限度地利用模型的表达能力,让模型去拟合 μ θ \mu_\theta μ θ 呢?
设定 μ θ \mu_\theta μ θ :
μ q ( x t , x 0 ) = α t ( 1 − α ˉ t + 1 ) x t + α ˉ t − 1 ( 1 − α t ) x 0 1 − α ˉ t \mu_q(x_t,x_0)=\frac{\sqrt{\alpha_t}(1-\bar \alpha_{t+1})x_t + \sqrt{\bar \alpha_{t-1}}(1-\alpha_t)x_0}{1-\bar \alpha_t}
μ q ( x t , x 0 ) = 1 − α ˉ t α t ( 1 − α ˉ t + 1 ) x t + α ˉ t − 1 ( 1 − α t ) x 0
首先,无论什么设定,我们需要注意的一点,由于 generative model 的目的是生成新的样本,因此我们需要一个可以只进行反向过程的模型,这意味着我们不能在时间 t t t 时刻出现在时间 t t t 之前的变量,也就是 x 0 , x 1 … x_0,x_1\dots x 0 , x 1 …
设定一:直接去拟合 x 0 x_0 x 0
我们使用 μ q \mu_q μ q 相同的格式,只去改变 x 0 x_0 x 0 项:
μ θ = α t ( 1 − α ˉ t + 1 ) x t + α ˉ t − 1 ( 1 − α t ) x ^ θ ( x t , t ) 1 − α ˉ t \mu_\theta =\frac{\sqrt{\alpha_t}(1-\bar \alpha_{t+1})x_t + \sqrt{\bar \alpha_{t-1}}(1-\alpha_t) \textcolor{red}{\hat{x}_\theta(x_t,t)}}{1-\bar \alpha_t}
μ θ = 1 − α ˉ t α t ( 1 − α ˉ t + 1 ) x t + α ˉ t − 1 ( 1 − α t ) x ^ θ ( x t , t )
则在这种设定下,模型的输入为 x t x_t x t 和时间步 t t t ,让 μ θ \mu_\theta μ θ 去拟合 μ q \mu_q μ q 等效为:
arg min θ 1 2 σ q 2 ( t ) α ˉ t − 1 ( 1 − α t ) 2 ( 1 − α ˉ t ) 2 ( ∥ x ^ θ ( x t , t ) − x 0 ∥ 2 2 ) \arg \min_\theta \frac{1}{2\sigma_q^2(t)}\frac{\bar\alpha_{t-1}(1-\alpha_t)^2}{(1-\bar\alpha_t)^2 } \left( \Vert \hat x_\theta(x_t, t) - x_0\Vert_2^2 \right)
arg θ min 2 σ q 2 ( t ) 1 ( 1 − α ˉ t ) 2 α ˉ t − 1 ( 1 − α t ) 2 ( ∥ x ^ θ ( x t , t ) − x 0 ∥ 2 2 )
但是实际上没人会用这种设定方式,我们下面分析不使用它的原因
很荒谬的一点是,我们的初衷本来是想通过 t t t 步采样得到我们的 x 0 x_0 x 0 ,但是我们每一步去噪的时候,去拟合 μ q \mu_q μ q 的同时我们需要去预测一个 x 0 x_0 x 0 ,如果能够直接单步采样出很好的 x 0 x_0 x 0 ,那我们要多步采样还干嘛,直接使用 GAN 之类的模型就好了
记住神经网络虽然具有万能逼近定理,但是不是所有的函数的拟合难度都是一样的,我们要尽可能去利用神经网络的表达能力,个人的经验是,神经网络能够很好地去拟合需要微调的量,例如 diffusion 的 noise 和 deformable convolution 的 offset 等
设定二:去拟合 ξ 0 \xi_0 ξ 0
实际上,我们可以让 x t x_t x t 重参数化为:
x t = α ˉ t x 0 + 1 − α ˉ t ξ 0 ξ 0 ∼ N ( ξ 0 ; 0 , I ) x_t = \sqrt{\bar \alpha_t}x_0 + \sqrt{1 - \bar \alpha_t}\xi_0 \\
\xi_0 \sim \mathcal{N}(\xi_0; 0, I)
x t = α ˉ t x 0 + 1 − α ˉ t ξ 0 ξ 0 ∼ N ( ξ 0 ; 0 , I )
则 x 0 x_0 x 0 项可以用 x t x_t x t 和 ξ 0 \xi_0 ξ 0 表示:
x 0 = x t − 1 − α ˉ t ξ 0 d ˉ t μ q ( x t , x 0 ) = 1 d t x t − 1 − α t 1 − α ˉ t d t ξ 0 x_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\xi_0}{\sqrt{\bar d_t}} \\
\mu_q(x_t, x_0)=\frac{1}{\sqrt{d_t}}x_t - \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}\sqrt{d_t}}\xi_0
x 0 = d ˉ t x t − 1 − α ˉ t ξ 0 μ q ( x t , x 0 ) = d t 1 x t − 1 − α ˉ t d t 1 − α t ξ 0
那么我们让模型也用相同的形式拟合:
μ θ = μ q ( x t , x 0 ) = 1 d t x t − 1 − α t 1 − α ˉ t d t ξ ˉ θ ( x t , t ) \mu_\theta =\mu_q(x_t, x_0)=\frac{1}{\sqrt{d_t}}x_t - \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}\sqrt{d_t}} \textcolor{red}{\bar\xi_\theta(x_t, t)}
μ θ = μ q ( x t , x 0 ) = d t 1 x t − 1 − α ˉ t d t 1 − α t ξ ˉ θ ( x t , t )
因此优化目标可以变为:
arg min θ 1 2 σ q 2 ( t ) α ˉ t − 1 ( 1 − α t ) 2 ( 1 − α ˉ t ) 2 ( ∥ ξ ^ θ ( x t , t ) − ξ 0 ∥ 2 2 ) \arg \min_\theta \frac{1}{2\sigma_q^2(t)}\frac{\bar\alpha_{t-1}(1-\alpha_t)^2}{(1-\bar\alpha_t)^2 } \left( \Vert \hat \xi_\theta(x_t, t) - \xi_0 \Vert_2^2 \right)
arg θ min 2 σ q 2 ( t ) 1 ( 1 − α ˉ t ) 2 α ˉ t − 1 ( 1 − α t ) 2 ( ∥ ξ ^ θ ( x t , t ) − ξ 0 ∥ 2 2 )
到这里我们可能会好奇,为什么模型能够拟合一个随机变量? 让模型去拟合一个随机变量的分布这个说法听起来十分抽象,但是实际上 ξ 0 \xi_0 ξ 0 的值在反向过程中是确定的,它是在正向过程中的时候采样得到的,因此确定,即使我们在训练的时候没有正向过程,我们仍然可以认为那张原图本身是存在的,只是模型没有正向过程去采样罢了
The sculpture is already complete within the marble block before I start my work. It is already there, I just have to chisel away the superfluous material.
——Michelangelo
设定三:Tweedie 公式
使用 Tweedie 公式,我们可以将上面的 x 0 x_0 x 0 写成:
x 0 = x t − 1 − α ˉ t ∇ log p ( x ) d ˉ t x_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\nabla\log p(x) }{\sqrt{\bar d_t}}
x 0 = d ˉ t x t − 1 − α ˉ t ∇ log p ( x )
这时候的模型就是基于分数的模型了,这个的原因将在后面的博客内解释;因此我们可以将 μ q \mu_q μ q 和 μ θ \mu_\theta μ θ 的解析式写为:
μ q ( x t , x 0 ) = 1 d t x t − 1 − α t 1 − α ˉ t d t ∇ log p ( x ) ⇒ μ q ( x t , x 0 ) = 1 d t x t − 1 − α t 1 − α ˉ t d t S θ ( x t , t ) \mu_q(x_t, x_0)=\frac{1}{\sqrt{d_t}}x_t - \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}\sqrt{d_t}}\nabla \log p(x) \\
\Rightarrow \mu_q(x_t, x_0)=\frac{1}{\sqrt{d_t}}x_t - \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}\sqrt{d_t}} \textcolor{red}{S_\theta(x_t,t)}
μ q ( x t , x 0 ) = d t 1 x t − 1 − α ˉ t d t 1 − α t ∇ log p ( x ) ⇒ μ q ( x t , x 0 ) = d t 1 x t − 1 − α ˉ t d t 1 − α t S θ ( x t , t )
因此优化目标变为:
arg min θ 1 2 σ q 2 ( t ) α ˉ t − 1 ( 1 − α t ) 2 ( 1 − α ˉ t ) 2 ( ∥ S θ ( x t , t ) − ∇ log p ( x ) ∥ 2 2 ) \arg \min_\theta \frac{1}{2\sigma_q^2(t)}\frac{\bar\alpha_{t-1}(1-\alpha_t)^2}{(1-\bar\alpha_t)^2 } \left( \Vert S_\theta(x_t, t) - \nabla \log p(x) \Vert_2^2 \right)
arg θ min 2 σ q 2 ( t ) 1 ( 1 − α ˉ t ) 2 α ˉ t − 1 ( 1 − α t ) 2 ( ∥ S θ ( x t , t ) − ∇ log p ( x ) ∥ 2 2 )
计算细节:
打 latex 太繁琐了,请查阅论文:Understanding Diffusion Models: A Unified Perspective
实现网络:
最终去噪网络的实现如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 from torch import nnimport torchimport mathclass Block (nn.Module): def __init__ (self, in_ch, out_ch, time_emb_dim, up=False ): super ().__init__() self.time_mlp = nn.Linear(time_emb_dim, out_ch) if up: self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3 , padding=1 ) self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4 , 2 , 1 ) else : self.conv1 = nn.Conv2d(in_ch, out_ch, 3 , padding=1 ) self.transform = nn.Conv2d(out_ch, out_ch, 4 , 2 , 1 ) self.conv2 = nn.Conv2d(out_ch, out_ch, 3 , padding=1 ) self.bnorm1 = nn.BatchNorm2d(out_ch) self.bnorm2 = nn.BatchNorm2d(out_ch) self.relu = nn.ReLU() def forward ( self, x, t, ): h = self.bnorm1(self.relu(self.conv1(x))) time_emb = self.relu(self.time_mlp(t)) time_emb = time_emb[(...,) + (None ,) * 2 ] h = h + time_emb h = self.bnorm2(self.relu(self.conv2(h))) return self.transform(h)class SinusoidalPositionEmbeddings (nn.Module): def __init__ (self, dim ): super ().__init__() self.dim = dim def forward (self, time ): device = time.device half_dim = self.dim // 2 embeddings = math.log(10000 ) / (half_dim - 1 ) embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) embeddings = time[:, None ] * embeddings[None , :] embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1 ) return embeddingsclass SimpleUnet (nn.Module): def __init__ (self ): super ().__init__() image_channels = 3 down_channels = (64 , 128 , 256 , 512 , 1024 ) up_channels = (1024 , 512 , 256 , 128 , 64 ) out_dim = 3 time_emb_dim = 32 self.time_mlp = nn.Sequential( SinusoidalPositionEmbeddings(time_emb_dim), nn.Linear(time_emb_dim, time_emb_dim), nn.ReLU() ) self.conv0 = nn.Conv2d(image_channels, down_channels[0 ], 3 , padding=1 ) self.downs = nn.ModuleList( [Block(down_channels[i], down_channels[i + 1 ], time_emb_dim) for i in range (len (down_channels) - 1 )] ) self.ups = nn.ModuleList( [Block(up_channels[i], up_channels[i + 1 ], time_emb_dim, up=True ) for i in range (len (up_channels) - 1 )] ) self.output = nn.Conv2d(up_channels[-1 ], out_dim, 1 ) def forward (self, x, timestep ): t = self.time_mlp(timestep) x = self.conv0(x) residual_inputs = [] for down in self.downs: x = down(x, t) residual_inputs.append(x) for up in self.ups: residual_x = residual_inputs.pop() x = torch.cat((x, residual_x), dim=1 ) x = up(x, t) return self.output(x)if __name__ == "__main__" : model = SimpleUnet() print ("Num params: " , sum (p.numel() for p in model.parameters())) img_size = 64 img = torch.randn((1 , 3 , img_size, img_size), device='cpu' ) t = torch.tensor([4 ], device='cpu' ) denoise = model(img, t) print (denoise.shape)
训练代码如下:
在训练扩散模型时,每次训练的输入并不是从一张图片出发,经过所有去噪步骤生成的一系列中间结果,而是将每一步加噪后的图像视为独立的样本进行训练。训练过程中会随机选择时间步 t t t ,并对每张图像进行加噪,生成带有噪声的图像 x t x_t x t 和对应的噪声 ϵ \epsilon ϵ ,然后使用这些加噪后的图像及其对应的时间步 t t t 来训练模型
如果对一张图片进行所有去噪步骤并将其作为一个序列来训练,会导致训练过程非常冗长和复杂
随机选择时间步的方法使得训练过程更加高效,因为每个批次的数据都是独立的,可以并行处理
通过在每个时间步上独立训练,模型可以更稳定地学习去噪任务,避免了序列训练中可能出现的累积误差
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 from forward_noising import forward_diffusion_samplefrom unet import SimpleUnetfrom dataloader import load_transformed_datasetimport torch.nn.functional as Fimport torchfrom torch.optim import Adamdef get_loss (model, x_0, t, device ): x_noisy, noise = forward_diffusion_sample(x_0, t, device) noise_pred = model(x_noisy, t) return F.mse_loss(noise, noise_pred)if __name__ == "__main__" : model = SimpleUnet() T = 300 BATCH_SIZE = 128 epochs = 100 dataloader = load_transformed_dataset(batch_size=BATCH_SIZE) device = "cuda" if torch.cuda.is_available() else "cpu" model.to(device) optimizer = Adam(model.parameters(), lr=0.001 ) for epoch in range (epochs): for batch_idx, (batch, _) in enumerate (dataloader): optimizer.zero_grad() t = torch.randint(0 , T, (BATCH_SIZE,), device=device).long() loss = get_loss(model, batch, t, device=device) loss.backward() optimizer.step() torch.save(model.state_dict(), "./trained_models/ddpm_mse_epochs_100.pth" )