F.grid_sample详解

F.grid_sample 计算细节

​ 该采样函数的接口声明如下:

1
torch.nn.functional.grid_sample(input, grid, mode=‘bilinear’, padding_mode=‘zeros’, align_corners=None)

函数内部先回将根据特征图 input 的大小将坐标规范化到 (1,1)(-1, 1),规范化坐标如下图所示,可以发现:特征图的坐标 (x,y)(x,y) 对应的是 (w,h)(w,h),这点和 Opencv 的规定相同,而不是按数学上的矩阵索引规律

归一化坐标

​ 因此使用 torch.meshgird 生成网格坐标的时候,要使用 indexing='xy' 或者 transpose() 方法(以两倍上采样为例):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
h = torch.linspace(-h/2 + 0.5, h/2 - 0.5, h) / (h/2)
w = torch.linspace(-w/2 + 0.5, w/2 - 0.5, w) / (w/2)
init_pos = torch.stack(torch.meshgrid(w, h, indexing='xy'))
print(init_pos)

```
init_pos:
tensor([[[[-0.7500, -0.7500],
[-0.2500, -0.7500],
[ 0.2500, -0.7500],
[ 0.7500, -0.7500]],

[[-0.7500, -0.2500],
[-0.2500, -0.2500],
[ 0.2500, -0.2500],
[ 0.7500, -0.2500]],

[[-0.7500, 0.2500],
[-0.2500, 0.2500],
[ 0.2500, 0.2500],
[ 0.7500, 0.2500]],

[[-0.7500, 0.7500],
[-0.2500, 0.7500],
[ 0.2500, 0.7500],
[ 0.7500, 0.7500]]]])
```

​ 上面计算过程又引出一个问题:h 和 w为什么是那样计算的?

格点和网点

​ 如上图所示,F.grid_sample(0,0)(0,0) 处坐标对应着图像的正中心(上右图中的中心红点),右上角的绿点的坐标对应着才是 (1,1)(-1, 1),如果我们要生成 target pixel 的坐标,那么图上的情形为 h = w = 8,(1,0)(-1,0) 区间内对应有 4 个点,归一化前每个点间隔为1,则归一化前右上角坐标为 (4,4)(-4, -4),最右上角 target source 坐标 (3.5,3.5)(-3.5,-3.5),归一化后为 (0.875,0.875)(-0.875,0.875),同样的方法可以解释上面的代码和结果

​ 注:padding_mode 参数会对结果造成影响,在 dysample 中选用 padding_mode='border'