Deformable Attention

Deformable Attention is an improved attention mechanism. Traditional CNNs and attention mechanisms usually extract features over a fixed grid or window, which limits the model's ability to effectively capture targets of varying scales and shapes. Deformable Attention essentially grafts deformable convolution onto ViT so that the model can better capture the shape and structure of targets.

[Figure: DAT compared with ViT, Swin Transformer, and DCN]

The figure above illustrates the characteristics of four networks:

  • In ViT, every Q has the same receptive field: all positions of the global feature map
  • Swin uses local (window) attention, so two Qs in different windows attend to different receptive-field regions
  • DCN learns offsets for the nine surrounding positions and then samples the corrected feature locations; in the figure both the red and the blue points number 9
  • DAT, proposed in this paper, combines ViT and DCN: all Qs share the same receptive field, but the sampled positions carry learned offsets. To reduce computational complexity, the attended features are also downsampled, so the figure shows 16 sampling points in total, reduced to 1/4 of the original

Formula Derivation

Vanilla ViT:

The block below is called a multi-head self-attention (MHSA) block:

$$
q = xW_q, \quad k = xW_k, \quad v = xW_v \\
z^{(m)} = \sigma\left( q^{(m)} {k^{(m)}}^{\top} / \sqrt{d} \right) v^{(m)}, \quad m = 1, \dots, M \\
z = \mathrm{Concat}\left(z^{(1)}, \dots, z^{(M)}\right) W_o
$$

where $x$ is the input feature map and $M$ is the number of attention heads; $z^{(m)}$ is the result of the $m$-th head of multi-head attention computed from $x$, $\sigma$ is the softmax function, and the final Concat line fuses the features of the $M$ heads.
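For reference, here is a minimal PyTorch sketch of the MHSA block above; the (batch, tokens, dim) layout and all names are illustrative, not taken from the paper's code:

```python
import torch
import torch.nn as nn


class MHSA(nn.Module):
    """Minimal multi-head self-attention matching the formula above."""
    def __init__(self, dim, heads=8):
        super().__init__()
        assert dim % heads == 0
        self.heads = heads
        self.d = dim // heads                      # per-head dimension d
        self.to_qkv = nn.Linear(dim, dim * 3)      # W_q, W_k, W_v fused
        self.to_out = nn.Linear(dim, dim)          # W_o

    def forward(self, x):                          # x: (b, n, dim)
        b, n, _ = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # split channels into M heads: (b, heads, n, d)
        q, k, v = (t.view(b, n, self.heads, self.d).transpose(1, 2) for t in qkv)
        attn = (q @ k.transpose(-2, -1) / self.d ** 0.5).softmax(dim=-1)
        z = (attn @ v).transpose(1, 2).reshape(b, n, -1)   # Concat over heads
        return self.to_out(z)
```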

Stacking MHSA blocks then proceeds as:

$$
z_l' = \mathrm{MHSA}(\mathrm{LN}(z_{l-1})) + z_{l-1} \\
z_l = \mathrm{MLP}(\mathrm{LN}(z_l')) + z_l'
$$
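A minimal sketch of this stacking rule, reusing the MHSA module from the previous sketch; the MLP expansion factor of 4 is the common ViT convention, assumed here rather than given in the text:

```python
class TransformerBlock(nn.Module):
    """Pre-norm block: z' = MHSA(LN(z)) + z, then z = MLP(LN(z')) + z'."""
    def __init__(self, dim, heads=8, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MHSA(dim, heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.GELU(),
            nn.Linear(dim * mlp_ratio, dim),
        )

    def forward(self, z):
        z = self.attn(self.norm1(z)) + z
        z = self.mlp(self.norm2(z)) + z
        return z
```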

Deformable Attention:

q=xWq,  offset=CNNblock(q)x~=F.grid_sample(x,offset+grid)k~=x~Wk,  v~=x~Wvz(m)=σ(q(m)k~(m)T/d+ϕ(B^;R))v~(m)q = xW_q,\ \ offset = CNNblock(q)\\ \tilde{x}=F.grid\_sample(x, offset+grid) \\ \tilde{k}=\tilde{x}W_k,\ \ \tilde{v}=\tilde{x}W_v \\ z^{(m)}=\sigma \left( q^{(m)}{\tilde{k}^{(m)}}^T / \sqrt{d} + \phi(\hat{B}; R) \right)\tilde{v}^{(m)}

where $\phi$ is the positional embedding. The reference grid itself can be produced by any simple coordinate function; for the bias term the paper uses a continuous positional bias (CPB), a small MLP over the relative coordinates between query positions and sampled key positions.
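The subtle step in the second line is F.grid_sample: it expects sampling coordinates normalized to $[-1, 1]$, in $(x, y)$ order on the last axis. A tiny sanity check of that convention (values chosen only for illustration):

```python
import torch
import torch.nn.functional as F

# a 1x1x3x3 "image" whose values equal their flat index
x = torch.arange(9.0).view(1, 1, 3, 3)

# one sampling location: the exact center of the image,
# given in normalized coordinates (x, y) in [-1, 1]
coords = torch.tensor([[[[0.0, 0.0]]]])          # (b, h_out, w_out, 2)

sampled = F.grid_sample(x, coords, mode='bilinear', align_corners=False)
print(sampled)   # tensor([[[[4.]]]]) -- the center pixel

# a learned offset simply shifts these coordinates before sampling:
# vgrid = grid + offset, then normalize to [-1, 1] as above
```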

Module

Implementation:

```python
import torch
import torch.nn.functional as F
import torch.nn as nn
import einops


# continuous positional bias (CPB): a small MLP mapping relative coordinates
# between query and sampled key positions to a per-head attention bias
class CPB(nn.Module):
    def __init__(self, dim, *, heads, offset_groups, depth):
        super().__init__()
        self.heads = heads
        self.offset_groups = offset_groups

        self.mlp = nn.ModuleList([])

        self.mlp.append(nn.Sequential(
            nn.Linear(2, dim),
            nn.ReLU()
        ))

        for _ in range(depth - 1):
            self.mlp.append(nn.Sequential(
                nn.Linear(dim, dim),
                nn.ReLU()
            ))

        self.mlp.append(nn.Linear(dim, heads // offset_groups))

    def forward(self, grid_q, grid_kv):
        # grid_q: (h, w, 2), shared by all samples; grid_kv: (b*g, h', w', 2)
        grid_q = einops.rearrange(grid_q, 'h w c -> 1 (h w) c')
        grid_kv = einops.rearrange(grid_kv, 'b h w c -> b (h w) c')
        # pairwise relative positions between every query and every sampled key
        pos = einops.rearrange(grid_q, 'b i c -> b i 1 c') - einops.rearrange(grid_kv, 'b j c -> b 1 j c')
        # log-spaced coordinates: sign(rel_pos) * log(abs(rel_pos) + 1)
        bias = torch.sign(pos) * torch.log(pos.abs() + 1)

        for layer in self.mlp:
            bias = layer(bias)

        # fold the offset-group dimension back out of the batch
        bias = einops.rearrange(bias, '(b g) i j o -> b (g o) i j', g=self.offset_groups)
        return bias


class DeformableAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head=64,
        heads=8,
        dropout=0.,
        downsample_factor=4,
        offset_kernel_size=6,
        group_queries=True,
        group_key_values=True
    ):
        super().__init__()

        # offsets come out of a tanh in [-1, 1] and are scaled up to
        # at most downsample_factor pixels
        self.offset_scale = downsample_factor
        assert offset_kernel_size >= downsample_factor
        # the padding below must be an integer
        assert (offset_kernel_size - downsample_factor) % 2 == 0

        offset_groups = heads
        inner_dim = dim_head * heads

        self.offset_groups = offset_groups
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.downsample_factor = downsample_factor

        offset_dims = inner_dim // offset_groups

        # one 2-channel offset per downsampled position, predicted per group
        self.to_offset = nn.Sequential(
            nn.Conv2d(offset_dims, offset_dims, offset_kernel_size, groups=offset_dims, stride=downsample_factor, padding=(offset_kernel_size - downsample_factor) // 2),
            nn.GELU(),
            nn.Conv2d(offset_dims, 2, 1, bias=False),
            nn.Tanh(),
        )

        self.rel_pos_bias = CPB(dim // 4, offset_groups=offset_groups, heads=heads, depth=2)
        self.dropout = nn.Dropout(dropout)
        self.to_q = nn.Conv2d(dim, inner_dim, 1, groups=offset_groups if group_queries else 1, bias=False)
        self.to_k = nn.Conv2d(dim, inner_dim, 1, groups=offset_groups if group_key_values else 1, bias=False)
        self.to_v = nn.Conv2d(dim, inner_dim, 1, groups=offset_groups if group_key_values else 1, bias=False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    @staticmethod
    def make_grid_like(x):
        # integer (row, col) reference grid with the spatial shape of x
        _, _, h, w = x.shape
        coord_h = torch.arange(h, device=x.device)
        coord_w = torch.arange(w, device=x.device)
        grid = torch.stack(torch.meshgrid(coord_h, coord_w, indexing='ij'), dim=0)
        grid.requires_grad = False  # the grid is a constant, no gradient needed
        grid = grid.type_as(x)
        return grid

    @staticmethod
    def normalize_grid(grid, unpack_dim=1):
        # rescale (row, col) coordinates to [-1, 1] and reorder to (x, y),
        # the convention F.grid_sample expects on the last dimension
        h, w = grid.shape[-2:]
        grid_h, grid_w = grid.unbind(dim=unpack_dim)
        grid_h = 2.0 * grid_h / max(h - 1, 1) - 1.0
        grid_w = 2.0 * grid_w / max(w - 1, 1) - 1.0
        return torch.stack((grid_w, grid_h), dim=unpack_dim)

    def forward(self, x, return_vgrid=False):
        b, _, h, w = x.shape
        # move the offset-group dimension from channels into the batch,
        # so each group is sampled with its own offsets
        group_c2b = lambda t: einops.rearrange(t, 'b (g d) ... -> (b g) d ...', g=self.offset_groups)

        q = self.to_q(x)
        grouped_queries = group_c2b(q)
        # tanh bounds offsets to [-1, 1]; scale up to downsample_factor pixels
        offsets = self.to_offset(grouped_queries) * self.offset_scale

        # deformed sampling positions: reference grid plus learned offsets
        grid = self.make_grid_like(offsets)
        vgrid = grid + offsets
        vgrid_scaled = self.normalize_grid(vgrid, 1).permute(0, 2, 3, 1)

        # bilinearly sample the key/value features at the deformed positions
        kv_feats = F.grid_sample(group_c2b(x), vgrid_scaled, mode='bilinear', align_corners=False)
        kv_feats = einops.rearrange(kv_feats, '(b g) d ... -> b (g d) ...', b=b)

        k = self.to_k(kv_feats)
        v = self.to_v(kv_feats)
        # (b, heads, tokens, dim_head) layout for attention
        q, k, v = map(lambda t: einops.rearrange(t, 'b (h d) ... -> b h (...) d', h=self.heads), (q, k, v))
        q = q * self.scale

        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)

        # continuous positional bias between query grid and deformed key grid
        grid = self.make_grid_like(x)
        grid_scaled = self.normalize_grid(grid, 0).permute(1, 2, 0)
        rel_pos_bias = self.rel_pos_bias(grid_scaled, vgrid_scaled)
        sim = sim + rel_pos_bias

        attn = sim.softmax(dim=-1)
        attn = self.dropout(attn)

        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = einops.rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w)
        out = self.to_out(out)

        if return_vgrid:
            return out, vgrid
        return out


if __name__ == '__main__':
    x = torch.randn(1, 64, 32, 48)
    model = DeformableAttention(dim=64, dim_head=64, heads=8, downsample_factor=4)
    out = model(x)
    print(out.shape)  # torch.Size([1, 64, 32, 48])
```
  • The code uses the grouped-query-attention trick: offsets are generated on a downsampled grid, and the downsampled $\tilde{x}$ is used to produce $k, v$ for an efficient implementation. In other words, each group of $downsample\_factor^2$ points shares a single key and value, while queries are not shared
  • To let every group sample differently, note that sampling already differs along the batch dimension, so any dimension that needs per-group sampling can simply be moved into the batch dimension. Since this operation is needed in several places, the code binds an anonymous function to a name, as above (einops patterns can even use the ellipsis ...); a shape walkthrough appears after this list:
```python
group_c2b = lambda t: einops.rearrange(t, 'b (g d) ... -> (b g) d ...', g=self.offset_groups)
```
  • map applies the same tensor transformation to several tensors concisely (here unpacking the head dimension back out of the channels):
```python
q, k, v = map(lambda t: einops.rearrange(t, 'b (h d) ... -> b h (...) d', h=self.heads), (q, k, v))
```
  • The generated grid never needs gradients from backpropagation, so setting grid.requires_grad = False keeps it out of the autograd graph and saves memory and compute (newly created tensors default to this anyway; the explicit assignment documents the intent)
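A minimal shape walkthrough of the group-to-batch trick referenced above; the sizes are arbitrary and only illustrate how the einops patterns move the group dimension:

```python
import torch
import einops

b, g, d, h, w = 2, 8, 64, 32, 48
t = torch.randn(b, g * d, h, w)            # channels hold g groups of d dims

# move groups into the batch so each group can be sampled independently
grouped = einops.rearrange(t, 'b (g d) h w -> (b g) d h w', g=g)
print(grouped.shape)                        # torch.Size([16, 64, 32, 48])

# and back again after sampling
merged = einops.rearrange(grouped, '(b g) d h w -> b (g d) h w', b=b)
print(merged.shape)                         # torch.Size([2, 512, 32, 48])
```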