Flow Models 详解
生成模型概览
在探索模型结构的多样性时,主流的生成模型大致可分为几类:GAN (生成对抗网络)、VAE (变分自编码器)、Flow (流模型)、Diffusion (扩散模型)以及 AR (自回归模型)等。它们各有特点,致力于从不同角度解决数据生成的问题。
基于流的生成模型 (Flow-based Generative Model)
首先需要回顾一下生成模型的核心目标:最大化观测数据的对数似然 。简单来说,我们希望模型学习到的数据分布 p θ ( x ) p_\theta(\boldsymbol x) p θ ( x ) 尽可能地接近真实数据的分布 p d a t a ( x ) p_{data}(\boldsymbol x) p d a t a ( x ) 。
然而,直接对高维复杂数据的对数似然 log p θ ( x ) \log p_\theta(\boldsymbol x) log p θ ( x ) 进行建模和优化通常非常困难。因此,像 VAE 和扩散模型等方法,往往通过优化对数似然的一个下界 (Evidence Lower Bound, ELBO) 来间接达到目的。虽然这在实践中取得了成功,但优化下界终究与直接优化似然本身存在一定的差距。
基于流 (Flow-based) 的模型则提供了一条直接优化对数似然的路径。 其核心思想在于构建一个从真实数据分布到某个简单、易于处理的先验分布(例如标准正态分布)的可逆映射 (双射)。借助数学中的变量替换定理 (Change of Variables Theorem) ,这个可逆映射可以将复杂数据分布下的概率密度计算,转换为在简单先验分布下的概率密度计算 ,从而实现对数据对数似然的直接建模和优化
具体来说,我们的优化目标是最大化训练数据集 { x 1 , x 2 , … , x m } \{\boldsymbol x_1,\boldsymbol x_2,\ldots,\boldsymbol x_m \} { x 1 , x 2 , … , x m } (假设其独立同分布采样自 p d a t a p_{data} p d a t a ) 的总对数似然:
J ( θ ) = max θ ∑ i = 1 m log p θ ( x i ) J(\theta)=\max_\theta\sum_{i=1}^m \log p_\theta(\boldsymbol x^i)
J ( θ ) = θ max i = 1 ∑ m log p θ ( x i )
现在,我们引入隐变量 z \boldsymbol z z ,其服从一个我们已知的简单先验分布 π ( z ) \pi(\boldsymbol z) π ( z ) (例如 N ( 0 , I ) \mathcal{N}(\boldsymbol 0, \boldsymbol I) N ( 0 , I ) )。我们希望学习一个可逆的映射函数 G θ : Z → X G_\theta: \mathcal{Z} \to \mathcal{X} G θ : Z → X ,它可以将隐变量 z \boldsymbol z z 映射到数据空间,即 x = G θ ( z ) \boldsymbol x = G_\theta(\boldsymbol z) x = G θ ( z ) 。相应地,其逆映射为 G θ − 1 : X → Z G_\theta^{-1}: \mathcal{X} \to \mathcal{Z} G θ − 1 : X → Z ,即 z = G θ − 1 ( x ) \boldsymbol z = G_\theta^{-1}(\boldsymbol x) z = G θ − 1 ( x ) 。
根据变量替换定理,由 G θ G_\theta G θ 定义的数据分布 p θ ( x ) p_\theta(\boldsymbol x) p θ ( x ) 可以通过隐变量的概率密度 π ( z ) \pi(\boldsymbol z) π ( z ) 表示为:
p θ ( x ) = π ( G θ − 1 ( x ) ) ⋅ ∣ det ( ∂ G θ − 1 ( x ) ∂ x ) ∣ p_\theta(\boldsymbol x) = \pi(G_\theta^{-1}(\boldsymbol x)) \cdot \left| \det \left( \frac{\partial G_\theta^{-1}(\boldsymbol x)}{\partial \boldsymbol x} \right) \right|
p θ ( x ) = π ( G θ − 1 ( x ) ) ⋅ ∣ ∣ ∣ ∣ ∣ det ( ∂ x ∂ G θ − 1 ( x ) ) ∣ ∣ ∣ ∣ ∣
其中,∂ G θ − 1 ( x ) ∂ x \frac{\partial G_\theta^{-1}(\boldsymbol x)}{\partial \boldsymbol x} ∂ x ∂ G θ − 1 ( x ) 是逆映射 G θ − 1 G_\theta^{-1} G θ − 1 在点 x \boldsymbol x x 处的雅可比矩阵 (Jacobian Matrix),而 ∣ det ( … ) ∣ \left| \det \left( \dots \right) \right| ∣ det ( … ) ∣ 表示其行列式的绝对值。为了简洁,我们将其记为 ∣ det J G θ − 1 ( x ) ∣ \left| \det J_{G_\theta^{-1}}(\boldsymbol x) \right| ∣ ∣ ∣ ∣ det J G θ − 1 ( x ) ∣ ∣ ∣ ∣ 。于是,单一样本的对数似然 log p θ ( x ) \log p_\theta(\boldsymbol x) log p θ ( x ) 可以写为:
log p θ ( x ) = log π ( G θ − 1 ( x ) ) + log ∣ det J G θ − 1 ( x ) ∣ \log p_\theta(\boldsymbol x) = \log \pi(G_\theta^{-1}(\boldsymbol x)) + \log \left| \det J_{G_\theta^{-1}}(\boldsymbol x) \right|
log p θ ( x ) = log π ( G θ − 1 ( x ) ) + log ∣ ∣ ∣ ∣ det J G θ − 1 ( x ) ∣ ∣ ∣ ∣
将此表达式代入我们最初的优化目标 J ( θ ) J(\theta) J ( θ ) ,得到:
J ( θ ) = max θ ∑ i = 1 m [ log π ( G θ − 1 ( x i ) ) + log ∣ det J G θ − 1 ( x i ) ∣ ] J(\theta) = \max_\theta \sum_{i=1}^m \left[ \log \pi(G_\theta^{-1}(\boldsymbol x^i)) + \log \left| \det J_{G_\theta^{-1}}(\boldsymbol x^i) \right| \right]
J ( θ ) = θ max i = 1 ∑ m [ log π ( G θ − 1 ( x i ) ) + log ∣ ∣ ∣ ∣ det J G θ − 1 ( x i ) ∣ ∣ ∣ ∣ ]
这个公式是流模型训练的核心。它由两部分组成:第一项 log π ( G θ − 1 ( x i ) ) \log \pi(G_\theta^{-1}(\boldsymbol x^i)) log π ( G θ − 1 ( x i ) ) 鼓励模型将数据点 x i \boldsymbol x^i x i 映射到先验分布 π \pi π 下具有较高概率的隐变量;第二项 log ∣ det J G θ − 1 ( x i ) ∣ \log \left| \det J_{G_\theta^{-1}}(\boldsymbol x^i) \right| log ∣ ∣ ∣ ∣ det J G θ − 1 ( x i ) ∣ ∣ ∣ ∣ 则是雅可比行列式的对数,它衡量了从数据空间到隐空间的映射过程中发生的体积变化
可逆变换的数学本质
流模型成功的关键在于精心设计的可逆变换函数 f : X → Z f:\mathcal{X}\to \mathcal{Z} f : X → Z (在上述讨论中对应 G θ − 1 G_\theta^{-1} G θ − 1 )。通过这个变换及其雅可比行列式,我们可以精确地转换概率密度:
log p X ( x ) = log p Z ( f ( x ) ) + log ∣ det J f ( x ) ∣ \log p_{\mathcal{X}}(\boldsymbol x) = \log p_{\mathcal{Z}}(f(\boldsymbol x)) + \log\left|\det J_f(\boldsymbol x)\right|
log p X ( x ) = log p Z ( f ( x ) ) + log ∣ det J f ( x ) ∣
这里有几个关键点需要强调:
维度保持特性:与 VAE 等模型可能将数据压缩到低维隐空间不同,流模型中的可逆变换通常要求输入 X \mathcal{X} X 和输出 Z \mathcal{Z} Z 的维度严格一致。这既是一个约束,也使得模型能够保留数据的全部信息。
参数共享的映射:Flow 模型的核心是学习一个可逆映射 f θ : X ↔ Z f_\theta: \mathcal{X} \leftrightarrow \mathcal{Z} f θ : X ↔ Z 。无论是从数据空间 X \mathcal{X} X 到隐空间 Z \mathcal{Z} Z 的编码过程(即 f θ f_\theta f θ ,对应前文的 G θ − 1 G_\theta^{-1} G θ − 1 ),还是从隐空间 Z \mathcal{Z} Z 到数据空间 X \mathcal{X} X 的生成过程(即 f θ − 1 f_\theta^{-1} f θ − 1 ,对应前文的 G θ G_\theta G θ ),都由同一组参数 θ \theta θ 控制。因此,不像 VAE 那样需要分别训练编码器和解码器网络,Flow 模型通过优化上述对数似然来学习这单一可逆映射的参数
一个重要的观察是:在流模型中,隐空间 Z \mathcal{Z} Z 的维度通常与原始数据空间 X \mathcal{X} X 的维度保持一致。 如果我们随意设计一般的可逆变换 f k f_k f k ,计算其雅可比行列式 det J f k \det J_{f_k} det J f k 的复杂度可能是 O ( D 3 ) \mathcal{O}(D^3) O ( D 3 ) (D D D 是数据维度),这对于高维数据(如图像)是难以承受的。
因此,对变换 f k f_k f k 的设计有如下关键要求:
必须可逆 :这是流模型的根本。
雅可比行列式易于计算 :这是模型实用性的保证。通常希望计算复杂度为 O ( D ) \mathcal{O}(D) O ( D ) 或更低。这引导我们设计具有特定结构的变换,例如那些雅可比矩阵是三角矩阵的变换。
模型
前向过程
反向过程
Normalizing Flow
通过显式的可学习变换将样本分布变换为标准高斯分布
从标准高斯分布采样,并通过上述变换的逆变换得到生成的样本
Diffusion Model
通过不可学习的 schedule 对样本进行加噪,多次加噪变换为标准高斯分布
从标准高斯分布采样,通过模型隐式地学习反向过程的噪声,去噪得到生成样本
归一化流 (Normalizing Flow)
“归一化流”这一名称强调了模型将复杂数据分布“归一化”为一个标准、简单的目标分布(通常是标准正态分布)的过程。这与 VAE 中强制后验分布逼近标准正态先验有相似之处
由于单个可逆变换 G G G 的表达能力可能有限(特别是为了保证可逆性和雅可比行列式易算性而引入的结构约束),实践中我们通常将多个简单可逆变换(称为“流层”或“仿射耦合层”等)串联起来,形成一个更强大、更具表达能力的复合变换:
x → f 1 z 1 → f 2 z 2 → … z K − 1 → f K z K = z \boldsymbol x \xrightarrow{f_1} \boldsymbol z_1 \xrightarrow{f_2} \boldsymbol z_2 \xrightarrow{\dots} \boldsymbol z_{K-1} \xrightarrow{f_K} \boldsymbol z_K = \boldsymbol z
x f 1 z 1 f 2 z 2 … z K − 1 f K z K = z
其中 x \boldsymbol x x 是输入数据,z \boldsymbol z z 是最终的隐变量,每一个 f k f_k f k 都是一个可逆变换。
对于这样的复合变换,根据链式法则,总的雅可比行列式是对每一层雅可比行列式的连乘。因此,总的对数似然贡献也是各层对数雅可比行列式之和:
log p X ( x ) = log p Z ( z ) + ∑ k = 1 K log ∣ det J f k ( z k − 1 ) ∣ \log p_{\mathcal{X}}(\boldsymbol x) = \log p_{\mathcal{Z}}(\boldsymbol z) + \sum_{k=1}^K \log \left| \det J_{f_k}(\boldsymbol z_{k-1}) \right|
log p X ( x ) = log p Z ( z ) + k = 1 ∑ K log ∣ det J f k ( z k − 1 ) ∣
其中 z 0 = x \boldsymbol z_0 = \boldsymbol x z 0 = x ,z k = f k ( z k − 1 ) \boldsymbol z_k = f_k(\boldsymbol z_{k-1}) z k = f k ( z k − 1 ) 。
此时,优化目标变为:
J ( θ ) = max θ ∑ i = 1 m [ log π ( z i ) + ∑ k = 1 K log ∣ det J f k ( z k − 1 i ) ∣ ] J(\theta) = \max_\theta \sum_{i=1}^m \left[ \log \pi(\boldsymbol z^i) + \sum_{k=1}^K \log \left| \det J_{f_k}(\boldsymbol z_{k-1}^i) \right| \right]
J ( θ ) = θ max i = 1 ∑ m [ log π ( z i ) + k = 1 ∑ K log ∣ ∣ ∣ det J f k ( z k − 1 i ) ∣ ∣ ∣ ]
其中 z i \boldsymbol z^i z i 是样本 x i \boldsymbol x^i x i 经过整个流变换序列后的最终隐表示
Coupling Blocks:仿射耦合层 (Affine Coupling Layer)
为了满足 flow 模型可逆且雅可比矩阵便于计算的要求,有多种巧妙的层设计,其中仿射耦合层 是 RealNVP、NICE 和 Glow 等模型的核心组件。
其基本思想是将输入向量 x \boldsymbol x x 分成两部分,例如 x = ( x A , x B ) \boldsymbol x = (\boldsymbol x_A, \boldsymbol x_B) x = ( x A , x B ) 。变换时,一部分 (x A \boldsymbol x_A x A ) 保持不变,而另一部分 (x B \boldsymbol x_B x B ) 则通过一个仿射变换进行更新,该仿射变换的参数 (尺度 s s s 和平移 t t t ) 由保持不变的那部分 (x A \boldsymbol x_A x A ) 计算得出。
具体地,对于一个从 x \boldsymbol x x 到 y \boldsymbol y y 的耦合层变换 G G G :
将输入 x \boldsymbol x x 沿某个维度(例如通道维度)切分为两半:x A , x B \boldsymbol x_A, \boldsymbol x_B x A , x B 。
第一部分保持不变:y A = x A \boldsymbol y_A = \boldsymbol x_A y A = x A 。
第二部分经过仿射变换:y B = x B ⊙ exp ( s ( x A ) ) + t ( x A ) \boldsymbol y_B = \boldsymbol x_B \odot \exp(s(\boldsymbol x_A)) + t(\boldsymbol x_A) y B = x B ⊙ exp ( s ( x A ) ) + t ( x A ) 。
其中 ⊙ \odot ⊙ 表示逐元素相乘。尺度参数 s s s 和平移参数 t t t 都是通过神经网络(例如几层全连接或卷积层)作用于 x A \boldsymbol x_A x A 得到的。exp ( s ( x A ) ) \exp(s(\boldsymbol x_A)) exp ( s ( x A ) ) 确保尺度因子为正。
这个变换的雅可比矩阵 J G ( x ) J_G(\boldsymbol x) J G ( x ) 具有如下形式(假设 x A \boldsymbol x_A x A 是前 d d d 维,x B \boldsymbol x_B x B 是后 D − d D-d D − d 维):
J f ( x ) = [ I d 0 d × ( D − d ) ∂ y B ∂ x A d i a g ( exp ( s ( x A ) ) ) ] \boldsymbol{J}_f(\boldsymbol x) =
\begin{bmatrix}
\mathbb{I}_d & \boldsymbol{0}_{d\times(D-d)} \\
\frac{\partial\boldsymbol{y}_B} {\partial\boldsymbol{x}_A} & \mathrm{diag}(\exp(s(\boldsymbol{x}_A)))
\end{bmatrix}
J f ( x ) = [ I d ∂ x A ∂ y B 0 d × ( D − d ) d i a g ( exp ( s ( x A ) ) ) ]
这是一个下三角矩阵(或上三角,取决于分割和更新的顺序),其行列式就是对角线元素的乘积:
det ( J G ( x ) ) = ∏ j exp ( s j ( x A ) ) = exp ( ∑ j s j ( x A ) ) \det(J_G(\boldsymbol x)) = \prod_j \exp(s_j(\boldsymbol x_A)) = \exp\left(\sum_j s_j(\boldsymbol x_A)\right)
det ( J G ( x ) ) = j ∏ exp ( s j ( x A ) ) = exp ( j ∑ s j ( x A ) )
因此,对数雅可比行列式可以非常高效地计算:
log ∣ det ( J f ( x ) ) ∣ = ∑ j s j ( x A ) \log |\det(J_f(\boldsymbol x))| = \sum_j s_j(\boldsymbol x_A)
log ∣ det ( J f ( x ) ) ∣ = j ∑ s j ( x A )
这个变换的逆变换也容易计算:
x A = y A \boldsymbol x_A = \boldsymbol y_A x A = y A
x B = ( y B − t ( y A ) ) ⊙ exp ( − s ( y A ) ) \boldsymbol x_B = (\boldsymbol y_B - t(\boldsymbol y_A)) \odot \exp(-s(\boldsymbol y_A)) x B = ( y B − t ( y A ) ) ⊙ exp ( − s ( y A ) )
为了让所有维度都能得到更新,通常会交替地将不同部分的维度作为 x A \boldsymbol x_A x A (保持不变的部分)。例如,在一个耦合层中,前一半维度不变,后一半更新;在下一个耦合层中,后一半维度不变,前一半更新(通过一个固定的排列操作,如翻转,来实现)。
通过堆叠多个这样的耦合层,并可能在它们之间加入维度重排(如1x1卷积或固定置换),模型可以学习到非常复杂和灵活的数据变换。
Autoregressive Flows:自回归流
自回归流 (Autoregressive Flow) 是另一类重要的流模型,在自回归流中,数据点 x = ( x 1 , x 2 , … , x D ) \boldsymbol x = (x_1, x_2, \ldots, x_D) x = ( x 1 , x 2 , … , x D ) 的每个维度 x i x_i x i 到对应隐变量 z i z_i z i (或者反过来) 的变换,都依赖于 x \boldsymbol x x (或 z \boldsymbol z z ) 的前 i − 1 i-1 i − 1 个维度
对于从 z \boldsymbol z z 到 x \boldsymbol x x 的变换 G G G :
x i = τ ( z i ; h i ( z < i ) ) x_i = \tau(z_i; \boldsymbol{h}_i(\boldsymbol z_{<i}))
x i = τ ( z i ; h i ( z < i ) )
其中 τ \tau τ 是一个关于 z i z_i z i 的可逆标量函数,其参数(例如仿射变换中的尺度 α i \alpha_i α i 和偏置 β i \beta_i β i )由一个条件网络 h i \boldsymbol{h}_i h i 根据 z < i = ( z 1 , … , z i − 1 ) \boldsymbol z_{<i} = (z_1, \ldots, z_{i-1}) z < i = ( z 1 , … , z i − 1 ) 计算得到。这意味着 x i x_i x i 的生成依赖于 z i z_i z i 以及所有在 i i i 之前的隐变量 z 1 , … , z i − 1 z_1, \ldots, z_{i-1} z 1 , … , z i − 1 。
这种结构的关键优势在于其雅可比矩阵的特性。对于上述从 z \boldsymbol z z 到 x \boldsymbol x x 的变换 G G G ,其雅可比矩阵 J G ( z ) = ∂ x ∂ z J_G(\boldsymbol z) = \frac{\partial \boldsymbol x}{\partial \boldsymbol z} J G ( z ) = ∂ z ∂ x 是一个下三角矩阵:
J f ( z ) = ( ∂ x 1 ∂ z 1 0 ⋯ 0 ∂ x 2 ∂ z 1 ∂ x 2 ∂ z 2 ⋯ 0 ⋮ ⋮ ⋱ ⋮ ∂ x D ∂ z 1 ∂ x D ∂ z 2 ⋯ ∂ x D ∂ z D ) J_f(\boldsymbol z) =
\begin{pmatrix}
\frac{\partial x_1}{\partial z_1} & 0 & \cdots & 0 \\
\frac{\partial x_2}{\partial z_1} & \frac{\partial x_2}{\partial z_2} & \cdots & 0 \\
\vdots & \vdots & \ddots & \vdots \\
\frac{\partial x_D}{\partial z_1} & \frac{\partial x_D}{\partial z_2} & \cdots & \frac{\partial x_D}{\partial z_D}
\end{pmatrix}
J f ( z ) = ⎝ ⎜ ⎜ ⎜ ⎜ ⎛ ∂ z 1 ∂ x 1 ∂ z 1 ∂ x 2 ⋮ ∂ z 1 ∂ x D 0 ∂ z 2 ∂ x 2 ⋮ ∂ z 2 ∂ x D ⋯ ⋯ ⋱ ⋯ 0 0 ⋮ ∂ z D ∂ x D ⎠ ⎟ ⎟ ⎟ ⎟ ⎞
这是因为 x i x_i x i 的计算只依赖于 z 1 , … , z i z_1, \ldots, z_i z 1 , … , z i ,而不依赖于 z j z_{j} z j 其中 j > i j>i j > i 。因此,∂ x i ∂ z j = 0 \frac{\partial x_i}{\partial z_j} = 0 ∂ z j ∂ x i = 0 对于 j > i j>i j > i 。
三角矩阵的行列式就是其对角线元素的乘积:
det ( J G ( z ) ) = ∏ i = 1 D ∂ x i ∂ z i \det(J_G(\boldsymbol z)) = \prod_{i=1}^D \frac{\partial x_i}{\partial z_i}
det ( J G ( z ) ) = i = 1 ∏ D ∂ z i ∂ x i
如果变换 τ \tau τ 是一个仿射变换,例如 x i = α i ( z < i ) z i + β i ( z < i ) x_i = \alpha_i(\boldsymbol z_{<i}) z_i + \beta_i(\boldsymbol z_{<i}) x i = α i ( z < i ) z i + β i ( z < i ) ,那么对角线元素就是尺度参数 α i ( z < i ) \alpha_i(\boldsymbol z_{<i}) α i ( z < i ) 。因此,对数雅可比行列式可以简单地计算为:
log ∣ det ( J f ( z ) ) ∣ = ∑ i = 1 D log ∣ α i ( z < i ) ∣ \log |\det(J_f(\boldsymbol z))| = \sum_{i=1}^D \log |\alpha_i(\boldsymbol z_{<i})|
log ∣ det ( J f ( z ) ) ∣ = i = 1 ∑ D log ∣ α i ( z < i ) ∣
条件网络 h i \boldsymbol{h}_i h i 通常使用能够有效处理序列依赖性的架构,如循环神经网络 (RNN) 或更常见的 Masked Autoencoder for Distribution Estimation (MADE) 以及 PixelCNN/WaveNet 中的掩码卷积。
采样 (生成):从 p ( z ) p(\boldsymbol z) p ( z ) 采样得到 z \boldsymbol z z ,然后计算 x \boldsymbol x x 。这个过程是高效的,因为给定 z \boldsymbol z z ,所有的 x i x_i x i 都可以并行计算(如果 h i \boldsymbol{h}_i h i 是一个掩码自编码器作用于 z \boldsymbol z z )
似然计算 (训练):计算 p ( x ) p(\boldsymbol x) p ( x ) 需要先得到 z = f − 1 ( x ) \boldsymbol z = f^{-1}(\boldsymbol x) z = f − 1 ( x ) 。z i = τ − 1 ( x i ; h i ( z < i ) ) z_i = \tau^{-1}(x_i; \boldsymbol{h}_i(\boldsymbol z_{<i})) z i = τ − 1 ( x i ; h i ( z < i ) ) 的计算是串行的,因为计算 z i z_i z i 需要 z 1 , … , z i − 1 z_1, \ldots, z_{i-1} z 1 , … , z i − 1 。这使得似然评估相对较慢
Residual Flows:残差流
残差流 (Residual Flow, ResFlow) 提供了一种构建可逆变换的替代方法,它借鉴了深度残差网络 (ResNet) 的思想。其核心是将变换 f : X → Y f: \mathcal{X} \to \mathcal{Y} f : X → Y 定义为一个残差块的形式:
y = f ( x ) = x + g ( x ; θ ) \boldsymbol y = f(\boldsymbol x) = \boldsymbol x + g(\boldsymbol x; \theta)
y = f ( x ) = x + g ( x ; θ )
其中 g g g 是一个神经网络(例如一个标准的 ResNet 块),θ \theta θ 是其参数
可逆性与逆变换
与耦合层或自回归流不同,上述形式的残差变换通常没有解析的逆函数 f − 1 f^{-1} f − 1 。然而,如果函数 g g g 满足特定的 Lipschitz 连续性条件,可以保证 G G G 是可逆的,并且可以通过不动点迭代来计算其逆变换。
具体来说,如果 g g g 关于 x \boldsymbol x x 是 Lipschitz 连续的,且其 Lipschitz 常数 Lip ( g ) < 1 \text{Lip}(g) < 1 Lip ( g ) < 1 (压缩映射),则 f ( x ) = x + g ( x ) f(\boldsymbol x) = \boldsymbol x + g(\boldsymbol x) f ( x ) = x + g ( x ) 是可逆的。其逆变换 x = f − 1 ( y ) \boldsymbol x = f^{-1}(\boldsymbol y) x = f − 1 ( y ) 可以通过以下不动点迭代求解:
x k + 1 = y − g ( x k ; θ ) \boldsymbol x_{k+1} = \boldsymbol y - g(\boldsymbol x_k; \theta)
x k + 1 = y − g ( x k ; θ )
从某个初始值 x 0 \boldsymbol x_0 x 0 (例如 x 0 = y \boldsymbol x_0 = \boldsymbol y x 0 = y ) 开始迭代,序列 x k \boldsymbol x_k x k 会收敛到真实的逆 f − 1 ( y ) f^{-1}(\boldsymbol y) f − 1 ( y ) 。在实践中,这个迭代过程会执行固定的步数或直到满足某个收敛准则。L i p ( g ) < 1 \mathrm{Lip}(g) < 1 L i p ( g ) < 1 的条件可以通过对 g g g 的权重进行谱归一化等技术来近似或强制满足
雅可比行列式的计算
残差变换的雅可比矩阵为 J G ( x ) = I + J g ( x ) J_G(\boldsymbol x) = \boldsymbol I + J_g(\boldsymbol x) J G ( x ) = I + J g ( x ) ,其中 J g ( x ) J_g(\boldsymbol x) J g ( x ) 是 g ( x ) g(\boldsymbol x) g ( x ) 关于 x \boldsymbol x x 的雅可比矩阵。直接计算 det ( I + J g ( x ) ) \det(\boldsymbol I + J_g(\boldsymbol x)) det ( I + J g ( x ) ) 的复杂度通常是 O ( D 3 ) \mathcal{O}(D^3) O ( D 3 ) ,这对于高维数据是不可接受的。
ResFlow 的一个关键贡献是采用了一种无偏的随机估计方法来计算对数雅可比行列式 log ∣ det ( J G ( x ) ) ∣ \log |\det(J_G(\boldsymbol x))| log ∣ det ( J G ( x ) ) ∣ ,而无需显式计算整个雅可比矩阵。这通常基于以下公式和 Hutchinso 迹估计:
log ∣ det ( I + J g ( x ) ) ∣ = T r ( log ( I + J g ( x ) ) ) \log |\det(\boldsymbol I + J_g(\boldsymbol x))| = \mathrm{Tr}(\log(\boldsymbol I + J_g(\boldsymbol x)))
log ∣ det ( I + J g ( x ) ) ∣ = T r ( log ( I + J g ( x ) ) )
其中 T r ( ⋅ ) \mathrm{Tr}(\cdot) T r ( ⋅ ) 表示矩阵的迹。log ( I + J g ( x ) ) \log(\boldsymbol I + J_g(\boldsymbol x)) log ( I + J g ( x ) ) 可以通过其泰勒级数展开来近似(前提是 J g ( x ) J_g(\boldsymbol x) J g ( x ) 的特征值满足一定条件,这与 L i p ( g ) < 1 \mathrm{Lip}(g)<1 L i p ( g ) < 1 相关):
log ( I + A ) = A − A 2 2 + A 3 3 − ⋯ = ∑ k = 1 ∞ ( − 1 ) k + 1 k A k \log(\boldsymbol I + A) = A - \frac{A^2}{2} + \frac{A^3}{3} - \dots = \sum_{k=1}^\infty \frac{(-1)^{k+1}}{k} A^k
log ( I + A ) = A − 2 A 2 + 3 A 3 − ⋯ = k = 1 ∑ ∞ k ( − 1 ) k + 1 A k
然后,Hutchinson迹估计器被用来估计迹:
Hutchinson 迹估计:
对于任意矩阵 M M M ,T r ( M ) = E v ∼ p ( v ) [ v T M v ] \mathrm{Tr}(M) = \mathbb{E}_{\boldsymbol v \sim p(\boldsymbol v)}[\boldsymbol v^T M \boldsymbol v] T r ( M ) = E v ∼ p ( v ) [ v T M v ] ,其中 p ( v ) p(\boldsymbol v) p ( v ) 是一个均值为0、协方差为单位阵的分布(例如,每个元素独立从 Rademacher 分布 { ± 1 } \{\pm 1\} { ± 1 } 或标准正态分布中采样)。
因此,对数行列式可以估计为:
log ∣ det ( J f ( x ) ) ∣ ≈ E v ∼ p ( v ) [ v T ( ∑ k = 1 K ( − 1 ) k + 1 k ( J g ( x ) ) k ) v ] \log |\det(J_f(\boldsymbol x))| \approx \mathbb{E}_{\boldsymbol v \sim p(\boldsymbol v)}\left[\boldsymbol v^T \left(\sum_{k=1}^K \frac{(-1)^{k+1}}{k} (J_g(\boldsymbol x))^k\right) \boldsymbol v\right]
log ∣ det ( J f ( x ) ) ∣ ≈ E v ∼ p ( v ) [ v T ( k = 1 ∑ K k ( − 1 ) k + 1 ( J g ( x ) ) k ) v ]
其中 K K K 是截断的级数项数。每一项 v T ( J g ( x ) ) k v \boldsymbol v^T (J_g(\boldsymbol x))^k \boldsymbol v v T ( J g ( x ) ) k v 都可以通过 k k k 次向量-雅可比积 (vector-Jacobian products, VJPs) 或雅可比-向量积 (Jacobian-vector products, JVPs) 高效计算,而无需实例化完整的雅可比矩阵 J g ( x ) J_g(\boldsymbol x) J g ( x ) 。这使得计算复杂度大致为 O ( K ⋅ D ⋅ C g ) \mathcal{O}(K \cdot D \cdot C_g) O ( K ⋅ D ⋅ C g ) ,其中 C g C_g C g 是计算 g g g 的一次前向/反向传播的代价
g g g 函数的选择非常灵活,可以直接使用标准的深度学习模块(如ResNet块)。相比结构受限的耦合层和自回归层,理论上可以构建更具表达能力的变换(但是严格满足 L i p ( g ) < 1 \mathrm{Lip}(g) < 1 L i p ( g ) < 1 可能比较困难,或者可能过度约束模型表达能力)
对数雅可比行列式的估计会给损失函数和梯度带来噪声/方差,可能影响训练的稳定性和收敛速度。逆变换依赖迭代求解,可能需要较多计算步骤,尤其是在生成新样本时
训练与推断
训练 (Training)
流模型的训练遵循严格的极大似然准则。目标是最小化负对数似然 (Negative Log-Likelihood, NLL):
L f l o w = − E x ∼ p d a t a [ log p X ( x ) ] = − E x ∼ p d a t a [ log p Z ( f θ ( x ) ) + ∑ k = 1 K log ∣ det J f k ( z k − 1 ) ∣ ] \mathcal{L}_{\mathrm{flow}} = -\mathbb{E}_{\boldsymbol x \sim p_{\mathrm{data}}}\left[\log p_{\mathcal{X}}(\boldsymbol x)\right] = -\mathbb{E}_{\boldsymbol x \sim p_{\mathrm{data}}}\left[\log p_{\mathcal{Z}}(f_{\theta}(\boldsymbol x)) + \sum_{k=1}^K \log\left|\det J_{f_k}(\boldsymbol z_{k-1})\right|\right]
L f l o w = − E x ∼ p d a t a [ log p X ( x ) ] = − E x ∼ p d a t a [ log p Z ( f θ ( x ) ) + k = 1 ∑ K log ∣ det J f k ( z k − 1 ) ∣ ]
其中 f θ ( x ) f_\theta(\boldsymbol x) f θ ( x ) 是整个流变换序列作用于 x \boldsymbol x x 后得到的最终隐变量 z K \boldsymbol z_K z K ,而 f k f_k f k 是序列中的第 k k k 个变换,J f k J_{f_k} J f k 是其雅可比矩阵,z k − 1 \boldsymbol z_{k-1} z k − 1 是第 k k k 个变换的输入。
这个损失函数包含两个核心组成部分:
先验匹配项 : log p Z ( f θ ( x ) ) \log p_{\mathcal{Z}}(f_{\theta}(\boldsymbol x)) log p Z ( f θ ( x ) ) 。这一项驱动模型学习到的隐空间分布 f θ ( x ) f_{\theta}(\boldsymbol x) f θ ( x ) 尽可能逼近我们预设的简单先验分布 p Z p_{\mathcal{Z}} p Z (例如,标准正态分布)。
流形校正项 (体积变化项) : ∑ k = 1 K log ∣ det J f k ( z k − 1 ) ∣ \sum_{k=1}^K \log\left|\det J_{f_k}(\boldsymbol z_{k-1})\right| ∑ k = 1 K log ∣ det J f k ( z k − 1 ) ∣ 。这一项是所有变换层对数雅可比行列式的总和,它精确地衡量了从数据空间到隐空间的映射过程中,局部体积是如何变化的。
通过梯度下降等优化算法最小化 L f l o w \mathcal{L}_{\mathrm{flow}} L f l o w ,我们可以学习到变换函数 f θ f_\theta f θ 的参数。
推断/生成 (Inference/Generation)
训练完成后,我们可以利用学习到的可逆变换进行多种操作:
密度估计 : 对于一个新的数据点 x n e w \boldsymbol x_{new} x n e w ,我们可以通过 f θ ( x n e w ) f_\theta(\boldsymbol x_{new}) f θ ( x n e w ) 计算其在隐空间的表示 z n e w \boldsymbol z_{new} z n e w ,并利用公式 log p X ( x n e w ) = log p Z ( z n e w ) + ∑ log ∣ det J f k ∣ \log p_{\mathcal{X}}(\boldsymbol x_{new}) = \log p_{\mathcal{Z}}(\boldsymbol z_{new}) + \sum \log|\det J_{f_k}| log p X ( x n e w ) = log p Z ( z n e w ) + ∑ log ∣ det J f k ∣ 来精确计算其对数似然值。这对于异常检测等任务非常有用。
数据生成 : 要生成新的样本,我们首先从先验分布 p Z p_{\mathcal{Z}} p Z 中采样一个隐变量 z s a m p l e \boldsymbol z_{sample} z s a m p l e ,然后通过逆变换 f θ − 1 = f K − 1 ∘ ⋯ ∘ f 1 − 1 f_\theta^{-1} = f_K^{-1} \circ \dots \circ f_1^{-1} f θ − 1 = f K − 1 ∘ ⋯ ∘ f 1 − 1 将其映射回数据空间,得到新的样本 x s a m p l e = f θ − 1 ( z s a m p l e ) \boldsymbol x_{sample} = f_\theta^{-1}(\boldsymbol z_{sample}) x s a m p l e = f θ − 1 ( z s a m p l e ) 。
以下是一个使用 PyTorch 实现的简化版 RealNVP (一种基于耦合层的流模型) 的示例代码,用于二维数据分布的学习:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 import torchimport torch.nn as nnimport torch.optim as optimimport numpy as npimport matplotlib.pyplot as pltimport sklearn.datasets class CouplingLayer (nn.Module): def __init__ (self, input_dim, hidden_dim, parity ): super ().__init__() self.parity = parity dim_half = input_dim // 2 self.s_net = nn.Sequential( nn.Linear(dim_half, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, dim_half) ) self.t_net = nn.Sequential( nn.Linear(dim_half, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, dim_half) ) def forward (self, x, reverse=False ): xa, xb = torch.chunk(x, 2 , dim=-1 ) if self.parity: xa, xb = xb, xa s = self.s_net(xa) t = self.t_net(xa) if reverse: yb = (xb - t) * torch.exp(-s) log_det_jacobian = -s.sum (dim=-1 ) else : yb = xb * torch.exp(s) + t log_det_jacobian = s.sum (dim=-1 ) if self.parity: y = torch.cat([yb, xa], dim=-1 ) else : y = torch.cat([xa, yb], dim=-1 ) return y, log_det_jacobian def inverse (self, y ): return self.forward(y, reverse=True )class NormalizingFlow (nn.Module): def __init__ (self, input_dim=2 , hidden_dim=256 , num_layers=6 ): super ().__init__() self.layers = nn.ModuleList() for i in range (num_layers): self.layers.append(CouplingLayer(input_dim, hidden_dim, i % 2 == 0 )) def forward (self, x ): log_det_sum = torch.zeros(x.shape[0 ], device=x.device) for layer in self.layers: x, log_det = layer(x) log_det_sum += log_det return x, log_det_sum def inverse (self, z ): log_det_sum = torch.zeros(z.shape[0 ], device=z.device) for layer in reversed (self.layers): z, log_det = layer.inverse(z) log_det_sum += log_det return z, log_det_sumdef sample_data (n_samples=1024 ): data, _ = sklearn.datasets.make_moons(n_samples=n_samples, noise=0.05 ) return torch.tensor(data, dtype=torch.float32)def loss_function (z, log_det_jacobian, prior ): log_likelihood = prior.log_prob(z).sum (dim=-1 ) + log_det_jacobian return -log_likelihood.mean()def train_flow (dim=2 , num_epochs=10000 , batch_size=512 , lr=1e-3 ): flow_model = NormalizingFlow(input_dim=dim, num_layers=8 , hidden_dim=128 ) prior = torch.distributions.Normal(torch.zeros(dim), torch.ones(dim)) optimizer = optim.Adam(flow_model.parameters(), lr=lr) print ("开始训练 Flow 模型..." ) for epoch in range (num_epochs): data = sample_data(batch_size) optimizer.zero_grad() z, log_det_jacobian = flow_model(data) loss = loss_function(z, log_det_jacobian, prior) loss.backward() torch.nn.utils.clip_grad_norm_(flow_model.parameters(), 1.0 ) optimizer.step() if (epoch + 1 ) % 500 == 0 : print (f"Epoch [{epoch+1 } /{num_epochs} ], Loss: {loss.item():.4 f} " ) print ("训练完成!" ) return flow_model, priordef visualize_results (flow_model, prior, data_samples, num_generated_samples=1000 ): flow_model.eval () plt.figure(figsize=(18 , 6 )) plt.subplot(1 , 3 , 1 ) plt.scatter(data_samples[:, 0 ], data_samples[:, 1 ], s=10 , alpha=0.5 , c='blue' ) plt.title("Original Data Distribution (Moons)" ) plt.xlabel("x1" ) plt.ylabel("x2" ) plt.xlim(-2 , 3 ) plt.ylim(-1.5 , 2 ) with torch.no_grad(): z_transformed, _ = flow_model(data_samples) z_transformed = z_transformed.numpy() plt.subplot(1 , 3 , 2 ) plt.scatter(z_transformed[:, 0 ], z_transformed[:, 1 ], s=10 , alpha=0.5 , c='green' ) plt.title("Data Mapped to Latent Space (Z)" ) plt.xlabel("z1" ) plt.ylabel("z2" ) xx, yy = np.meshgrid(np.linspace(-3 , 3 , 100 ), np.linspace(-3 , 3 , 100 )) zz_prior = np.exp(-0.5 * (xx**2 + yy**2 )) / (2 * np.pi) plt.contour(xx, yy, zz_prior, levels=5 , alpha=0.3 , cmap='gray' ) plt.xlim(-4 , 4 ) plt.ylim(-4 , 4 ) with torch.no_grad(): z_samples = prior.sample((num_generated_samples,)) x_generated, _ = flow_model.inverse(z_samples) x_generated = x_generated.numpy() plt.subplot(1 , 3 , 3 ) plt.scatter(x_generated[:, 0 ], x_generated[:, 1 ], s=10 , alpha=0.5 , c='red' ) plt.title("Generated Data from Latent Samples" ) plt.xlabel("x1" ) plt.ylabel("x2" ) plt.xlim(-2 , 3 ) plt.ylim(-1.5 , 2 ) plt.tight_layout() plt.show()if __name__ == '__main__' : trained_model, model_prior = train_flow(dim=2 , num_epochs=10000 ) original_data_for_plot = sample_data(1000 ) visualize_results(trained_model, model_prior, original_data_for_plot)