Guided Filtering
Guided Filtering是一种图像处理技术,它主要用于边缘保留平滑,即在去除噪声或细节的同时保持图像中的重要边缘,引导滤波(Guided Filtering)和双边滤波(BF)、最小二乘滤波(WLS)是三大边缘保持(Edge-perserving)滤波器。当然,引导滤波的功能不仅仅是边缘保持,只有当引导图是原图的时候,它就成了一个边缘保持滤波器
问题建模:
对于一个输入的图像p,通过引导图像I,经过滤波后得到输出图像q,其中p和1都是算法的输入。引导滤波定义了如下所示的一个线性滤波过程,对于i位置的像素点,得到的滤波输出是一个加权平均值:
qi=j∑Wij(I) pj
其中,i和j分别表示像素下标。 Wij 是只和引导图像I相关的滤波核。该滤波器相对于 p 是线性的。 导向滤波的一个重要假设是输出图像 q 和引导图像I在滤波窗口 wk 上存在局部线性关系:
qi=akIi+bk, ∀i∈wk
对于一个以 r 为半径的确定的窗口 wk ,(ak,bk) 也将是唯一确定的常量系数。这就保证了在一个局部区域里,如果引导图像 I 有一个边缘的时候,输出图像 q 也保持边缘不变,因为对于相邻的像素点而言,存在 ∇q=a∇I 。因此只要求解得到了系数 a,b 也就得到了输出 q。同时认为输入图像中非边缘区域又不平滑的地方视为噪声 n ,就有 qi=pi−ni 。最终的目标就是最小化这个噪声。对于每一个滤波窗口,该算法在最小二乘意义上的最优化可表示为:
mini∈wk∑(qi−pi)2⇒mini∈wk∑(akIi+bk−pi)2
最后,引入一个正则化参数 ϵ 避免 ak 过大,得到滤波窗口内的损失函数:
J(ak,bk)=i∈wk∑((akIi+bk−pi)2+ϵak2)
求解最优化过程(对参数求偏导):
ak∂J=i∈wk∑(2(akIi+bk−pi)Ii+2ϵak)=0bk∂J=i∈wk∑2(akIi+bk−pi)=0
最优化问题的解
为了简便表示,记:
pk=card(wk)1i∈wk∑pi, Ik=card(wk)1Ii, σk2=Ik2−Ik2
解得最优化问题的解为:
ak=σk2+ϵpkIk−pkIkbk=pk−akIk
边缘保持
对于该算法,当 I=p 时,即输入图像和引导图像是同一副图像时,该算法即成为一个边缘保持滤波器。同时,方程的解也可作如下表示:
ak=σk2+ϵσk2bk=(1−ak)pk
从中可以看出,ϵ 在这里相当于界定平滑区域和边缘区域的阈值
考虑以下两种情况:
- Case 1: 平坦区域。如果在某个滤波窗口内,该区域是相对平滑的,方差 σk2 将远远小于 ϵ。从而 a1≈0,bk≈pˉk 。相当于对该区域作均值滤波
- Case 2: 高方差区域。相反,如果该区域是边缘区域,方差很大, σk2 将远远大于 ϵ。从而 ak≈1,bk≈0 。相当于在区域保持原有梯度
代码实现:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
| import PIL import torch import torch.nn.functional as F import numpy as np import matplotlib.pyplot as plt def box_filter(x, r): """ Box filter implementation. """ ch = x.shape[1] kernel_size = 2 * r + 1 pad = torch.nn.ReplicationPad2d(r) padded_x = pad(x) kernel = torch.ones((ch, 1, kernel_size, kernel_size), device=x.device) output = F.conv2d(padded_x, kernel, padding=0, groups=ch) return output
class GuidedFilter(torch.nn.Module): def __init__(self, radius, eps=1e-8): super(GuidedFilter, self).__init__() self.radius = radius self.eps = eps
def forward(self, x, y): n_x = box_filter(x, self.radius) mean_x = n_x / (2 * self.radius + 1) ** 2 mean_y = box_filter(y, self.radius) / (2 * self.radius + 1) ** 2 corr_x = box_filter(x * x, self.radius) var_x = corr_x / (2 * self.radius + 1) ** 2 - mean_x * mean_x cov_xy = box_filter(x * y, self.radius) / (2 * self.radius + 1) ** 2 - mean_x * mean_y a = cov_xy / (var_x + self.eps) b = mean_y - a * mean_x mean_a = box_filter(a, self.radius) / (2 * self.radius + 1) ** 2 mean_b = box_filter(b, self.radius) / (2 * self.radius + 1) ** 2 return mean_a * x + mean_b
if __name__ == '__main__': input_image = PIL.Image.open('./cat.jpg').convert('RGB') np_image = np.array(input_image) input_image = torch.from_numpy(np_image).permute(2, 0, 1).unsqueeze(0).float() / 255.0 guide_image = input_image guided_filter = GuidedFilter(radius=5, eps=1e-4) filtered_image = guided_filter(guide_image, input_image) filtered_image = F.max_pool2d(filtered_image, kernel_size=2, stride=2, padding=1) filtered_image = filtered_image.squeeze(0).permute(1, 2, 0).numpy() fig, axes = plt.subplots(1, 2) axes[0].imshow(np_image) axes[0].set_title('Original Image') axes[1].imshow(filtered_image) axes[1].set_title('Filtered Image') plt.show()
print("Filtered image shape:", filtered_image.shape)
|
- 代码简洁之处就在于使用
box_filter
函数来简化计算,对照着上文 ak,bk 的计算公式即可理解