swin_backbone

Swin Transformer

Swin Transformer is a general-purpose backbone for computer vision. Unlike a traditional convolutional neural network (CNN), it combines the strengths of CNNs and ViT: self-attention is computed within local regions, so the model attends to local detail while still being able to capture long-range dependencies.

Swin Transformer takes its name from its core component, the shifted-window (Shifted Windows) multi-head self-attention module. This design lets the model handle local information and global context effectively while keeping computation efficient.

Swin Transformer restricts the self-attention computation to local windows. This addresses ViT's problem that the cost of self-attention grows quadratically with image size: Swin Transformer's complexity is linear in the number of tokens, so it is lightweight while still performing very well.
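
For concreteness, with $h \times w$ patch tokens of dimension $C$ and window size $M$ (fixed at 7 in the paper), the per-block complexities given in the paper are

$$\Omega(\text{MSA}) = 4hwC^2 + 2(hw)^2C$$

$$\Omega(\text{W-MSA}) = 4hwC^2 + 2M^2hwC$$

Global MSA is quadratic in the number of tokens $hw$, while window attention is linear in $hw$ once $M$ is fixed.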

Network Design

[Figure: overview of the Swin Transformer architecture (Swin-T)]

The figure above gives an overview of the Swin Transformer architecture (the tiny variant, Swin-T).

  • The input image of size $H \times W \times 3$ is first split into non-overlapping patches by a Patch Partition module (as in ViT), giving $N \times (P^2 \times 3)$ patches; each patch of length $P^2 \times 3$ is treated as a patch token. The paper uses a patch size of $4\times 4$ rather than ViT's $16\times 16$, so after patch partition the feature map has size $\frac{H}{4}\times \frac{W}{4} \times 48$.
  • The Linear Embedding layer is simply a 1×1 convolution that projects the patch features to the embedding dimension.
  • The paper's naming of the Swin Transformer block is a bit confusing: both block diagrams on the right of the figure are called a Swin Transformer block, but in practice the two blocks together form one basic Swin unit, which is why the block counts in the diagram are always even. The block structure is described below.
  • To produce a hierarchical representation, the number of tokens is gradually reduced by Patch Merging layers as the network gets deeper (they also act as downsampling modules). The first patch merging layer concatenates each group of $2\times 2$ neighbouring patches, so the feature-map resolution becomes $\frac{H}{8}\times \frac{W}{8}$ and the token dimension grows by a factor of four, to $4C$; to keep the channel count from growing too fast, a linear layer then compresses it to $2C$. Patch merging can be implemented as follows:
class PatchMerging(nn.Module):
    def __init__(self, in_channels, out_channels, downscaling_factor):
        super().__init__()
        self.downscaling_factor = downscaling_factor
        self.patch_merge = nn.Unfold(kernel_size=downscaling_factor, stride=downscaling_factor, padding=0)
        self.linear = nn.Linear(in_channels * downscaling_factor ** 2, out_channels)

    def forward(self, x):
        b, c, h, w = x.shape
        new_h, new_w = h // self.downscaling_factor, w // self.downscaling_factor
        # unfold gathers each downscaling_factor x downscaling_factor patch into the channel dimension
        x = self.patch_merge(x).view(b, -1, new_h, new_w).permute(0, 2, 3, 1)
        # the linear projection could equivalently be a 1x1 nn.Conv2d
        x = self.linear(x)
        return x
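
A quick sanity check for the module above (this sketch assumes torch and the PatchMerging class are in scope): with downscaling_factor=4 it also plays the role of the stage-1 Patch Partition + Linear Embedding, mapping a 224×224 RGB image to a $56\times 56\times 96$ token map.

import torch

x = torch.randn(1, 3, 224, 224)                                      # (B, C, H, W) dummy image
merge = PatchMerging(in_channels=3, out_channels=96, downscaling_factor=4)
y = merge(x)
print(y.shape)                                                        # torch.Size([1, 56, 56, 96])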

[Figure: receptive-field growth across the Swin Transformer stages]

Swin Transformer Block

[Figure: regular (W-MSA) vs. shifted (SW-MSA) window partitions]

The window size is fixed at $7\times 7$ and self-attention is computed only within each window. If the windows never exchanged any information, a lot of context would be lost, so the paper designs two kinds of window attention: W-MSA (Window-based Multi-head Self-Attention) and SW-MSA (Shifted Window-based Multi-head Self-Attention). W-MSA is implemented exactly like standard ViT attention, just restricted to each window; the interesting part is SW-MSA.
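
To make W-MSA concrete, here is a minimal sketch of the partitioning step using einops (the shapes assume a stage-1 feature map of size $56\times 56\times 96$ and window_size = 7); attention is then ordinary multi-head self-attention computed independently over each group of $7\times 7 = 49$ tokens:

import torch
from einops import rearrange

window_size = 7
x = torch.randn(1, 56, 56, 96)   # (B, H, W, C) feature map after stage 1

# split H and W into (num_windows, window_size) and flatten each window into 49 tokens
windows = rearrange(x, 'b (nw_h w_h) (nw_w w_w) c -> b (nw_h nw_w) (w_h w_w) c',
                    w_h=window_size, w_w=window_size)
print(windows.shape)             # torch.Size([1, 64, 49, 96]): 8x8 windows, 49 tokens each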

Since the windows need to exchange information with each other, the window partition in the next block is shifted so that it overlaps the previous one. The figure above shows the difference between the shifted (overlapping) and regular (non-overlapping) partitions. This raises a new problem: the regions produced by the shifted partition have different sizes, so they can no longer be processed efficiently in one batch. Simply zero-padding the small windows up to the full window size still leaves 9 windows, more than double the 4 windows of the regular partition, so the cost remains high. Instead, the authors reassemble the pieces produced by the shifted partition so that they again form the same number of equally sized windows as the regular partition; the reassembly is visualized below:

[Figure: shifted window partition]

[Figure: windows reassembled (concatenated) after the cyclic shift]

The cyclic shift can be implemented as follows:

class CyclicShift(nn.Module):
    def __init__(self, displacement):
        super().__init__()
        self.displacement = displacement

    def forward(self, x):
        return torch.roll(x, shifts=(self.displacement, self.displacement), dims=(1, 2))
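
A small sanity check of the shift (a toy $4\times 4$ map and a displacement of 2 are used purely for illustration): torch.roll rotates the feature map cyclically along H and W, and the opposite shift restores the original tensor.

import torch

x = torch.arange(16.).reshape(1, 4, 4, 1)   # (B, H, W, C) toy feature map
shift = CyclicShift(-2)                     # forward shift, as used before SW-MSA
back = CyclicShift(2)                       # reverse shift, applied after attention

y = shift(x)
print(y[0, :, :, 0])                        # rows/columns 2 and 3 have wrapped around to the front
print(torch.equal(back(y), x))              # True: the reverse shift undoes it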

After this reassembly, a single composite window may contain patches that came from different original windows, and cross-attention should not be computed between patches that belong to different windows. The authors handle this neatly with masked attention, which prevents similarity scores from being used across different sub-windows. The mask is visualized below: the yellow regions correspond to a value of -100 (the code uses -inf), which becomes 0 after the softmax, i.e. those positions are effectively masked out:

[Figure: visualization of the attention masks for the shifted windows]

The colours in the mask figure can be interpreted as follows (using window 2 as an example):

[Figure: mask layout for window 2]

The mask creation can be implemented as follows:

def create_mask(window_size, displacement, upper_lower, left_right):
    mask = torch.zeros(window_size ** 2, window_size ** 2)

    if upper_lower:
        mask[-displacement * window_size:, :-displacement * window_size] = float('-inf')
        mask[:-displacement * window_size, -displacement * window_size:] = float('-inf')

    if left_right:
        mask = rearrange(mask, '(h1 w1) (h2 w2) -> h1 w1 h2 w2', h1=window_size, h2=window_size)
        mask[:, -displacement:, :, :-displacement] = float('-inf')
        mask[:, :-displacement, :, -displacement:] = float('-inf')
        mask = rearrange(mask, 'h1 w1 h2 w2 -> (h1 w1) (h2 w2)')

    return mask
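
To see the mask in action, build one for a small window (window_size=4 and displacement=2 are chosen purely for illustration) and note how the -inf entries turn into exactly 0 attention weight after the softmax:

import torch

mask = create_mask(window_size=4, displacement=2, upper_lower=True, left_right=False)
print(mask.shape)                     # torch.Size([16, 16]); -inf marks token pairs from different windows

scores = torch.zeros(16, 16) + mask   # pretend all raw attention logits are equal
attn = scores.softmax(dim=-1)
print(attn[0])                        # masked columns receive probability 0, the others 1/8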

The final WindowAttention module is implemented as follows:

class WindowAttention(nn.Module):
    def __init__(self, dim, heads, head_dim, shifted, window_size, relative_pos_embedding):
        super().__init__()
        inner_dim = head_dim * heads

        self.heads = heads
        self.scale = head_dim ** -0.5
        self.window_size = window_size
        self.relative_pos_embedding = relative_pos_embedding
        self.shifted = shifted

        if self.shifted:
            displacement = window_size // 2
            self.cyclic_shift = CyclicShift(-displacement)
            self.cyclic_back_shift = CyclicShift(displacement)
            self.upper_lower_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                             upper_lower=True, left_right=False), requires_grad=False)
            self.left_right_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                            upper_lower=False, left_right=True), requires_grad=False)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        if self.relative_pos_embedding:
            self.relative_indices = get_relative_distances(window_size) + window_size - 1
            self.pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1))
        else:
            self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2))

        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        if self.shifted:
            x = self.cyclic_shift(x)

        b, n_h, n_w, _, h = *x.shape, self.heads

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        nw_h = n_h // self.window_size
        nw_w = n_w // self.window_size

        q, k, v = map(
            lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d',
                                h=h, w_h=self.window_size, w_w=self.window_size), qkv)

        dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale

        if self.relative_pos_embedding:
            dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
        else:
            dots += self.pos_embedding

        if self.shifted:
            dots[:, :, -nw_w:] += self.upper_lower_mask
            dots[:, :, nw_w - 1::nw_w] += self.left_right_mask

        attn = dots.softmax(dim=-1)

        out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
        out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)',
                        h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w)
        out = self.to_out(out)

        if self.shifted:
            out = self.cyclic_back_shift(out)
        return out
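
A shape check for the module (the hyperparameters mirror stage 1 of Swin-T: dim 96, 3 heads of dimension 32, window size 7; the helpers create_mask and get_relative_distances from the full listing below are assumed to be in scope):

import torch

attn = WindowAttention(dim=96, heads=3, head_dim=32, shifted=True,
                       window_size=7, relative_pos_embedding=True)
x = torch.randn(1, 56, 56, 96)   # (B, H, W, C); H and W must be multiples of the window size
out = attn(x)
print(out.shape)                 # torch.Size([1, 56, 56, 96])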

Full Implementation

import torch
from torch import nn, einsum
import numpy as np
from einops import rearrange, repeat


class CyclicShift(nn.Module):
    def __init__(self, displacement):
        super().__init__()
        self.displacement = displacement

    def forward(self, x):
        return torch.roll(x, shifts=(self.displacement, self.displacement), dims=(1, 2))


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return self.net(x)


def create_mask(window_size, displacement, upper_lower, left_right):
    mask = torch.zeros(window_size ** 2, window_size ** 2)

    if upper_lower:
        # the lower-left part of window 2
        mask[-displacement * window_size:, :-displacement * window_size] = float('-inf')
        # the upper-right part of window 3
        mask[:-displacement * window_size, -displacement * window_size:] = float('-inf')

    if left_right:
        mask = rearrange(mask, '(h1 w1) (h2 w2) -> h1 w1 h2 w2', h1=window_size, h2=window_size)
        mask[:, -displacement:, :, :-displacement] = float('-inf')
        mask[:, :-displacement, :, -displacement:] = float('-inf')
        mask = rearrange(mask, 'h1 w1 h2 w2 -> (h1 w1) (h2 w2)')

    return mask


def get_relative_distances(window_size):
    indices = torch.tensor(np.array([[x, y] for x in range(window_size) for y in range(window_size)]))
    distances = indices[None, :, :] - indices[:, None, :]
    return distances


class WindowAttention(nn.Module):
    def __init__(self, dim, heads, head_dim, shifted, window_size, relative_pos_embedding):
        super().__init__()
        inner_dim = head_dim * heads

        self.heads = heads
        self.scale = head_dim ** -0.5
        self.window_size = window_size
        self.relative_pos_embedding = relative_pos_embedding
        self.shifted = shifted

        if self.shifted:
            displacement = window_size // 2
            self.cyclic_shift = CyclicShift(-displacement)
            self.cyclic_back_shift = CyclicShift(displacement)
            self.upper_lower_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                             upper_lower=True, left_right=False), requires_grad=False)
            self.left_right_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                            upper_lower=False, left_right=True), requires_grad=False)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        if self.relative_pos_embedding:
            self.relative_indices = get_relative_distances(window_size) + window_size - 1
            self.pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1))
        else:
            self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2))

        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        if self.shifted:
            x = self.cyclic_shift(x)

        b, n_h, n_w, _, h = *x.shape, self.heads

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        nw_h = n_h // self.window_size
        nw_w = n_w // self.window_size

        q, k, v = map(
            lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d',
                                h=h, w_h=self.window_size, w_w=self.window_size), qkv)

        dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale

        if self.relative_pos_embedding:
            dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
        else:
            dots += self.pos_embedding

        if self.shifted:
            dots[:, :, -nw_w:] += self.upper_lower_mask
            dots[:, :, nw_w - 1::nw_w] += self.left_right_mask

        attn = dots.softmax(dim=-1)

        out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
        out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)',
                        h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w)
        out = self.to_out(out)

        if self.shifted:
            out = self.cyclic_back_shift(out)
        return out


class SwinBlock(nn.Module):
    def __init__(self, dim, heads, head_dim, mlp_dim, shifted, window_size, relative_pos_embedding):
        super().__init__()
        self.attention_block = Residual(PreNorm(dim, WindowAttention(dim=dim,
                                                                     heads=heads,
                                                                     head_dim=head_dim,
                                                                     shifted=shifted,
                                                                     window_size=window_size,
                                                                     relative_pos_embedding=relative_pos_embedding)))
        self.mlp_block = Residual(PreNorm(dim, FeedForward(dim=dim, hidden_dim=mlp_dim)))

    def forward(self, x):
        x = self.attention_block(x)
        x = self.mlp_block(x)
        return x


class PatchMerging(nn.Module):
    def __init__(self, in_channels, out_channels, downscaling_factor):
        super().__init__()
        self.downscaling_factor = downscaling_factor
        self.patch_merge = nn.Unfold(kernel_size=downscaling_factor, stride=downscaling_factor, padding=0)
        self.linear = nn.Linear(in_channels * downscaling_factor ** 2, out_channels)

    def forward(self, x):
        b, c, h, w = x.shape
        new_h, new_w = h // self.downscaling_factor, w // self.downscaling_factor
        x = self.patch_merge(x).view(b, -1, new_h, new_w).permute(0, 2, 3, 1)
        x = self.linear(x)
        return x


class StageModule(nn.Module):
    def __init__(self, in_channels, hidden_dimension, layers, downscaling_factor, num_heads, head_dim, window_size,
                 relative_pos_embedding):
        super().__init__()
        assert layers % 2 == 0, 'Stage layers need to be divisible by 2 for regular and shifted block.'

        self.patch_partition = PatchMerging(in_channels=in_channels, out_channels=hidden_dimension,
                                            downscaling_factor=downscaling_factor)

        self.layers = nn.ModuleList([])
        for _ in range(layers // 2):
            self.layers.append(nn.ModuleList([
                SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
                          shifted=False, window_size=window_size, relative_pos_embedding=relative_pos_embedding),
                SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
                          shifted=True, window_size=window_size, relative_pos_embedding=relative_pos_embedding),
            ]))

    def forward(self, x):
        x = self.patch_partition(x)
        for regular_block, shifted_block in self.layers:
            x = regular_block(x)
            x = shifted_block(x)
        return x.permute(0, 3, 1, 2)


class SwinTransformer(nn.Module):
    def __init__(self, *, hidden_dim, layers, heads, channels=3, num_classes=1000, head_dim=32, window_size=7,
                 downscaling_factors=(4, 2, 2, 2), relative_pos_embedding=True):
        super().__init__()

        self.stage1 = StageModule(in_channels=channels, hidden_dimension=hidden_dim, layers=layers[0],
                                  downscaling_factor=downscaling_factors[0], num_heads=heads[0], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)
        self.stage2 = StageModule(in_channels=hidden_dim, hidden_dimension=hidden_dim * 2, layers=layers[1],
                                  downscaling_factor=downscaling_factors[1], num_heads=heads[1], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)
        self.stage3 = StageModule(in_channels=hidden_dim * 2, hidden_dimension=hidden_dim * 4, layers=layers[2],
                                  downscaling_factor=downscaling_factors[2], num_heads=heads[2], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)
        self.stage4 = StageModule(in_channels=hidden_dim * 4, hidden_dimension=hidden_dim * 8, layers=layers[3],
                                  downscaling_factor=downscaling_factors[3], num_heads=heads[3], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(hidden_dim * 8),
            nn.Linear(hidden_dim * 8, num_classes)
        )

    def forward(self, img):
        x = self.stage1(img)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = x.mean(dim=[2, 3])
        return self.mlp_head(x)
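
Finally, a minimal usage sketch. The configuration below corresponds to the Swin-T variant ($C = 96$, depths $(2, 2, 6, 2)$, heads $(3, 6, 12, 24)$); the helper name swin_t is just a convenience introduced here.

def swin_t(num_classes=1000, **kwargs):
    # Swin-T: embedding dim 96, depths (2, 2, 6, 2), heads (3, 6, 12, 24)
    return SwinTransformer(hidden_dim=96, layers=(2, 2, 6, 2), heads=(3, 6, 12, 24),
                           num_classes=num_classes, **kwargs)


if __name__ == '__main__':
    net = swin_t()
    img = torch.randn(1, 3, 224, 224)    # 224x224 input -> 56/28/14/7 feature maps over the 4 stages
    logits = net(img)
    print(logits.shape)                  # torch.Size([1, 1000])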