This post collects loss functions commonly used in the Low Light Image Enhancement (LLIE) task, together with the corresponding code. Several of them also apply to other low-level vision tasks.

L1 Loss

L1 loss is used widely across image tasks. Also known as Mean Absolute Error (MAE), it is one of the most common loss functions in deep learning and statistical regression. It measures the error as the absolute difference between the prediction and the ground truth:

$$
\mathcal{L}_{L1} = \frac{1}{N} \sum_{i=1}^{N} \left| y_i - \hat{y}_i \right|
$$

import torch.nn as nn

Loss = nn.L1Loss()

# Or wrap it in a small module:
class L1_loss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss = nn.L1Loss()

    def forward(self, light_map, normal_img):
        loss = self.loss(light_map, normal_img)
        return loss
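
A minimal usage sketch of the wrapper above; the tensor names `enhanced` and `gt` are placeholders for the enhanced output and the normal-light reference:

import torch

criterion = L1_loss()
enhanced = torch.rand(1, 3, 256, 256)   # placeholder network output
gt = torch.rand(1, 3, 256, 256)         # placeholder ground-truth image
loss = criterion(enhanced, gt)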

L2 Loss

Like the L1 loss, the L2 loss is also very common, although it is used less often in low-light enhancement. It measures the error as the squared difference between the prediction and the ground truth:

$$
\mathcal{L}_{L2} = \frac{1}{N} \sum_{i=1}^{N} \left( y_i - \hat{y}_i \right)^2
$$

Advantages

  • Smooth convergence: squaring makes small errors contribute less and large errors more, and the gradient shrinks as the error shrinks, which helps optimization.
  • Nice mathematical properties: continuous and differentiable, so optimizers converge easily.

Disadvantages

  • Sensitive to outliers: squaring amplifies the influence of any outlier, so the model is easily pulled towards outlier points (see the small numeric check after the code below).
  • Tends to blur (in image tasks): minimizing the overall mean squared error pushes predictions towards the "average", which can produce smooth or blurry results in image enhancement and generation.
import torch.nn as nn
criterion = nn.MSELoss()
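
A small self-contained check of the outlier-sensitivity claim above; the numbers are made up purely for illustration:

import torch
import torch.nn as nn

pred = torch.tensor([0.0, 0.0, 0.0, 0.0])
target = torch.tensor([0.1, 0.1, 0.1, 4.0])   # last element is an outlier

print(nn.L1Loss()(pred, target))    # tensor(1.0750) -- grows linearly with the outlier
print(nn.MSELoss()(pred, target))   # tensor(4.0075) -- dominated by the squared outlier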

VGG19 Loss — Perceptual Loss

Perceptual loss is a loss function commonly used in deep-learning-based image style transfer. Compared with the traditional mean squared error (MSE) loss, it focuses on perceptual quality and matches human judgements of image quality more closely.
It measures the difference between two images through a pretrained neural network, usually a convolutional neural network (CNN) trained on a large-scale dataset, which can extract high-level features. For example, the convolutional layers of VGG-19 capture texture and structural information, while its fully connected layers capture semantic information.
To compute it, the input image and the target image are each passed through the pretrained network to obtain feature representations, and the loss is the Euclidean or Manhattan distance between those features. The goal is to minimize the distance between the two images in feature space.
The code below instantiates a VGG19 network; VGGLoss feeds the enhanced image and the normal-light image through it and computes a weighted L1/L2 loss over the features of each stage.

import torch
import torch.nn as nn
import torchvision


class VGG19(torch.nn.Module):
    def __init__(self, requires_grad=False):
        super().__init__()
        vgg_pretrained_features = torchvision.models.vgg19(
            weights=torchvision.models.VGG19_Weights.IMAGENET1K_V1).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        # Split VGG19 into five stages; each slice ends right after a ReLU
        # (relu1_1, relu2_1, relu3_1, relu4_1, relu5_1).
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out


class VGGLoss(nn.Module):
    def __init__(self, loss_weight=1.0, criterion='l1', reduction='mean'):
        super(VGGLoss, self).__init__()
        self.vgg = VGG19().cuda()
        if reduction not in ['none', 'mean', 'sum']:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             "Supported ones are: ['none', 'mean', 'sum']")

        if criterion == 'l1':
            self.criterion = nn.L1Loss(reduction=reduction)
        elif criterion == 'l2':
            self.criterion = nn.MSELoss(reduction=reduction)
        else:
            raise NotImplementedError('Unsupported criterion loss')

        # Deeper features receive larger weights.
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
        self.weight = loss_weight

    def forward(self, x, y):
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i in range(len(x_vgg)):
            # detach() stops gradients from flowing into the reference branch
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        return self.weight * loss
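
A minimal usage sketch of this VGGLoss; the tensors are placeholders, and a CUDA device is assumed because the network is moved to the GPU in __init__:

perceptual = VGGLoss(loss_weight=1.0, criterion='l1')
enhanced = torch.rand(1, 3, 256, 256).cuda()    # placeholder enhanced output
reference = torch.rand(1, 3, 256, 256).cuda()   # placeholder normal-light image
loss = perceptual(enhanced, reference)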

An alternative implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class MeanShift(nn.Conv2d):
    """A fixed 1x1 convolution that normalizes the input with the ImageNet mean/std."""
    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        for p in self.parameters():
            p.requires_grad = False


class VGGLoss(nn.Module):
    def __init__(self, conv_index='54', rgb_range=1):
        super(VGGLoss, self).__init__()
        vgg_features = models.vgg19(pretrained=True).features
        modules = [m for m in vgg_features]
        # '22' keeps features up to conv2_2, '54' keeps features up to conv5_4
        if conv_index == '22':
            self.vgg = nn.Sequential(*modules[:8])
        elif conv_index == '54':
            self.vgg = nn.Sequential(*modules[:35])
        self.vgg.cuda()

        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = MeanShift(rgb_range, vgg_mean, vgg_std).cuda()
        for p in self.vgg.parameters():
            p.requires_grad = False

    def forward(self, sr, hr):
        def _forward(x):
            x = self.sub_mean(x)
            x = self.vgg(x)
            return x

        vgg_sr = _forward(sr)
        with torch.no_grad():
            vgg_hr = _forward(hr.detach())

        loss = F.mse_loss(vgg_sr, vgg_hr)

        return loss
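
The main design choice in this version is which VGG stage to compare: conv_index='22' keeps only the shallow layers (texture and edges), while '54' keeps the deep layers (semantics). A minimal usage sketch with placeholder tensors (CUDA assumed, as above):

vgg_criterion = VGGLoss(conv_index='54', rgb_range=1)
enhanced = torch.rand(1, 3, 256, 256).cuda()    # placeholder enhanced output
reference = torch.rand(1, 3, 256, 256).cuda()   # placeholder normal-light image
loss = vgg_criterion(enhanced, reference)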

Compared with plain L1/L2 losses, the perceptual loss has the following characteristics:

| Aspect | Pixel-level loss (L1/L2) | Perceptual loss |
| --- | --- | --- |
| Comparison space | raw pixel space | deep feature space |
| Focus | exact pixel-value matching | semantic content matching |
| Result | can look blurry / unnatural | more natural, closer to human perception |
| Computational cost | low | higher (requires a network forward pass) |

Typically the perceptual loss is computed with an L1 or L2 distance on the features, where $\phi_l$ denotes the features extracted at layer $l$ and $\lambda_l$ its weight:

$$
\mathcal{L}_{\text{perceptual}}^{L1} = \sum_{l} \lambda_l \left\| \phi_l(I_{\text{gen}}) - \phi_l(I_{\text{ref}}) \right\|_1
$$

$$
\mathcal{L}_{\text{perceptual}}^{L2} = \sum_{l} \lambda_l \left\| \phi_l(I_{\text{gen}}) - \phi_l(I_{\text{ref}}) \right\|_2^2
$$

Edge Loss

Edge loss is common in reconstruction tasks such as low-light enhancement, super-resolution and denoising. Its goal is to make the edge structure of the predicted image match that of the ground-truth image as closely as possible, which helps avoid blurry results.

import torch
import torch.nn as nn
import torch.nn.functional as F


class EdgeLoss(nn.Module):
    def __init__(self):
        super(EdgeLoss, self).__init__()
        # Sobel kernels for horizontal and vertical gradients
        self.sobel_kernel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                                           dtype=torch.float32).view(1, 1, 3, 3)
        self.sobel_kernel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                                           dtype=torch.float32).view(1, 1, 3, 3)

    def forward(self, x, x_hat):
        self.sobel_kernel_x = self.sobel_kernel_x.to(x.device)
        self.sobel_kernel_y = self.sobel_kernel_y.to(x.device)
        grad_x = torch.zeros_like(x)
        grad_x_hat = torch.zeros_like(x_hat)
        # Compute the gradient magnitude channel by channel
        for c in range(x.shape[1]):
            grad_x_c_x = F.conv2d(x[:, c:c+1, :, :], self.sobel_kernel_x, padding=1)
            grad_x_c_y = F.conv2d(x[:, c:c+1, :, :], self.sobel_kernel_y, padding=1)
            grad_x[:, c:c+1, :, :] = torch.sqrt(grad_x_c_x ** 2 + grad_x_c_y ** 2 + 1e-6)
            grad_x_hat_c_x = F.conv2d(x_hat[:, c:c+1, :, :], self.sobel_kernel_x, padding=1)
            grad_x_hat_c_y = F.conv2d(x_hat[:, c:c+1, :, :], self.sobel_kernel_y, padding=1)
            grad_x_hat[:, c:c+1, :, :] = torch.sqrt(grad_x_hat_c_x ** 2 + grad_x_hat_c_y ** 2 + 1e-6)
        loss = F.mse_loss(grad_x, grad_x_hat)
        return loss
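
A minimal usage sketch with placeholder tensors; EdgeLoss compares Sobel gradient magnitudes, so the inputs just need to be image batches of the same shape:

edge_criterion = EdgeLoss()
enhanced = torch.rand(1, 3, 256, 256)
gt = torch.rand(1, 3, 256, 256)
loss = edge_criterion(gt, enhanced)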

SSIM Loss

The SSIM index lies between 0 and 1, and larger values mean better results, so the SSIM loss is simply:

$$
\mathcal{L}_{SSIM} = 1 - \mathrm{SSIM}(x, \hat{x})
$$

The theory and detailed computation of SSIM are not covered here.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class SSIM_loss(nn.Module):
    def __init__(self, window_size=11, sigma=1.5, data_range=1.0, channel=1):
        super(SSIM_loss, self).__init__()
        self.window_size = window_size
        self.sigma = sigma
        self.data_range = data_range
        self.channel = channel
        self.gaussian_kernel = self._create_gaussian_kernel(window_size, sigma)

    def _create_gaussian_kernel(self, window_size, sigma):
        # 1D Gaussian, normalized, then expanded to a 2D separable kernel
        gauss = torch.Tensor([np.exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2))
                              for x in range(window_size)])
        gauss = gauss / gauss.sum()
        kernel = torch.outer(gauss, gauss)
        return kernel.view(1, 1, window_size, window_size).repeat(self.channel, 1, 1, 1)

    def ssim(self, img1, img2):
        C1 = (0.01 * self.data_range) ** 2
        C2 = (0.03 * self.data_range) ** 2
        kernel = self.gaussian_kernel.to(img1.device)
        # Local means, variances and covariance via Gaussian filtering
        mu1 = F.conv2d(img1, kernel, padding=0, groups=self.channel)
        mu2 = F.conv2d(img2, kernel, padding=0, groups=self.channel)
        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2
        sigma1_sq = F.conv2d(img1 * img1, kernel, padding=0, groups=self.channel) - mu1_sq
        sigma2_sq = F.conv2d(img2 * img2, kernel, padding=0, groups=self.channel) - mu2_sq
        sigma12 = F.conv2d(img1 * img2, kernel, padding=0, groups=self.channel) - mu1_mu2
        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
                   ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

        return ssim_map.mean()

    def forward(self, img1, img2):
        if img1.size() != img2.size():
            raise ValueError(f"Input images must have the same dimensions. "
                             f"Got {img1.size()} and {img2.size()}")
        # Rebuild the kernel if the channel count changed
        if self.channel != img1.shape[1]:
            self.channel = img1.shape[1]
            self.gaussian_kernel = self._create_gaussian_kernel(self.window_size, self.sigma)
        ssim_val = self.ssim(img1, img2)
        return 1 - ssim_val
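
A minimal usage sketch with placeholder tensors, assumed to lie in [0, 1] so that data_range=1.0 is appropriate:

ssim_criterion = SSIM_loss(window_size=11, data_range=1.0, channel=3)
enhanced = torch.rand(1, 3, 256, 256)
gt = torch.rand(1, 3, 256, 256)
loss = ssim_criterion(enhanced, gt)   # returns 1 - SSIM, so lower is better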

Frequency-Domain Loss

As frequency-domain processing has been introduced into low-light enhancement, frequency-domain losses have come into use as well. The main goal is to make the energy distribution of the prediction match that of the ground truth in the frequency domain, e.g. by comparing the amplitude spectra of their Fourier transforms, as in the code below.

import torch
import torch.nn as nn


class FrequencyLoss(nn.Module):
    def __init__(self, loss_weight=0.01, criterion='l1', reduction='mean'):
        super(FrequencyLoss, self).__init__()
        self.loss_weight = loss_weight
        self.reduction = reduction
        if criterion == 'l1':
            self.criterion = nn.L1Loss(reduction=reduction)
        elif criterion == 'l2':
            self.criterion = nn.MSELoss(reduction=reduction)
        else:
            raise NotImplementedError('Unsupported criterion loss')

    def forward(self, pred, target, weight=None, **kwargs):
        # Compare the amplitude spectra of prediction and target
        pred_freq = self.get_fft_amplitude(pred)
        target_freq = self.get_fft_amplitude(target)
        return self.loss_weight * self.criterion(pred_freq, target_freq)

    def get_fft_amplitude(self, inp):
        inp_freq = torch.fft.rfft2(inp, norm='backward')
        amp = torch.abs(inp_freq)
        return amp
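
A minimal usage sketch with placeholder tensors:

freq_criterion = FrequencyLoss(loss_weight=0.01, criterion='l1')
enhanced = torch.rand(1, 3, 256, 256)
gt = torch.rand(1, 3, 256, 256)
loss = freq_criterion(enhanced, gt)   # L1 distance between FFT amplitude spectra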

Charbonnier Loss

It appears in "Fast and Accurate Image Super-Resolution with Deep Laplacian Pyramid Networks".
Compared with plain L1, it adds a small constant ε inside the square root, giving a smooth, everywhere-differentiable approximation of the L1 loss: thanks to ε the loss is well behaved around zero, while large errors are still penalized roughly linearly, keeping L1's robustness. With ε = 10⁻⁶, matching the code below:

$$
\mathcal{L}_{Char} = \frac{1}{N} \sum_{i=1}^{N} \sqrt{ \left( y_i - \hat{y}_i \right)^2 + \epsilon }
$$

import torch


class L1_Charbonnier_loss(torch.nn.Module):
    """L1 Charbonnier loss."""
    def __init__(self):
        super(L1_Charbonnier_loss, self).__init__()
        self.eps = 1e-6

    def forward(self, X, Y):
        diff = torch.add(X, -Y)
        error = torch.sqrt(diff * diff + self.eps)
        loss = torch.mean(error)
        return loss
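
A minimal usage sketch with placeholder tensors:

char_criterion = L1_Charbonnier_loss()
enhanced = torch.rand(1, 3, 256, 256)
gt = torch.rand(1, 3, 256, 256)
loss = char_criterion(enhanced, gt)   # behaves like L1 but is smooth near zero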

Color Constancy Loss

The color constancy loss described here is one of the losses used in the Zero-DCE model, so it is not a supervised loss but a zero-shot (no-reference) one. It is common in low-light enhancement and image restoration, where it constrains the enhanced image to keep a reasonable color balance and prevents obvious color casts (e.g. the whole image turning reddish or greenish). It is based on the gray-world assumption: in a natural image, the mean values of the R/G/B channels should be close to each other. The loss penalizes differences between the channel means so that the colors of the enhanced result look more natural.
Computation:

  • First compute the mean of each of the R, G and B channels over the spatial dimensions: $\mu_R, \mu_G, \mu_B$.
  • Then compute the pairwise channel differences, e.g. $D_{RG} = (\mu_R - \mu_G)^2$, and likewise $D_{RB}$ and $D_{GB}$.

  • Finally compute the color loss over the batch:

$$
\mathcal{L}_{color} = \frac{1}{B} \sum_{b=1}^{B} \sqrt{ D_{RG}^{2} + D_{RB}^{2} + D_{GB}^{2} }
$$

import torch
import torch.nn as nn


class L_color(nn.Module):
    def __init__(self):
        super(L_color, self).__init__()

    def forward(self, x):
        b, c, h, w = x.shape
        # Per-channel spatial means (gray-world assumption: they should be close)
        mean_rgb = torch.mean(x, [2, 3], keepdim=True)
        mr, mg, mb = torch.split(mean_rgb, 1, dim=1)
        Drg = torch.pow(mr - mg, 2)
        Drb = torch.pow(mr - mb, 2)
        Dgb = torch.pow(mb - mg, 2)
        k = torch.pow(torch.pow(Drg, 2) + torch.pow(Drb, 2) + torch.pow(Dgb, 2), 0.5)
        return k.mean()
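
A minimal usage sketch with a placeholder tensor; note that this loss takes only the enhanced image, no reference is needed:

color_criterion = L_color()
enhanced = torch.rand(4, 3, 256, 256)   # batch of enhanced RGB images
loss = color_criterion(enhanced)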

In the end: the code above comes mainly from the open-source implementations of mainstream low-light enhancement models. Since there are many of them and the loss functions they use overlap heavily, individual sources are not cited here; sincere thanks go to all of those authors.