Paper:https://arxiv.org/pdf/2401.15204v6

Code: https://github.com/albrateanu/LYT-Net

目录

0. 摘要

1. 引言

 2. 方法

2.1. CWD模块

2.2. MHSA模块

2.3. MSEF模块

2.4. 损失函数

3. 结果与讨论

 4. 消融实验

5. 结论

附【网络结构的Pytorch代码】:


0. 摘要

        本文提出了LYT-Net,这是一个新颖的、轻量的、transformer-based的低光照图像增强模型,它由几个层和可拆卸的块组成,包括我们的新块——Channel-Wise Denoiser (CWD)和Multi-Stage Squeeze & Excite Fusion (MSEF)——以及传统的Transformer块,Multi-Headed Self-Attention (MHSA)。我们采用双路径方法,将色度通道 U 和 V 和亮度通道 Y 视为单独的实体,以帮助模型更好地处理光照调整和损坏恢复。我们对已建立的LLIE数据集的综合评估表明,尽管它的复杂性较低,但我们的模型优于最近的LLIE方法。

1. 引言

        低光照图像增强(LLIE)是计算成像中一项重要且具有挑战性的任务。当图像在弱光条件下捕获时,它们的质量通常会恶化,从而导致细节和对比度的损失。这不仅使图像在视觉上不吸引人,而且会影响许多成像系统的性能。LLIE 的目标是提高这些图像的清晰度和对比度,同时还纠正黑暗环境中经常出现的失真,所有这些都不引入不需要的伪影或导致颜色不平衡。

        早期的LLIE方法[1]主要依靠频率分解[2]、[3]、[4]、直方图均衡化[5]、[6]、[7]和Retinex理论[8]、[9]、[10]、[11]、[12]。随着深度学习的快速发展,各种CNN架构[13],[14],[15],[16],[17],[18],[19],[20],[21],[22]已被证明优于传统的LLIE技术。基于Retinex理论,Retinex-Net[13]将Retinex分解与原始CNN架构相结合,而DiffRetinex[16]提出了一个生成框架,以进一步解决弱光引起的内容丢失和颜色偏差。生成对抗网络(GAN)[23]的发展为LLIE提供了新的视角,其中弱光图像作为输入,生成正常光对应图像。例如,启蒙gan [24] 使用单个生成器模型直接将低光图像转换为正常光版本,有效地在转换过程中同时使用全局和局部鉴别器。

        生成对抗网络(GAN)[23]的发展为LLIE提供了新的视角,其中弱光图像作为输入,生成正常光对应图像。例如,启蒙gan [24] 使用单个生成器模型直接将低光图像转换为正常光版本,有效地在转换过程中同时使用全局和局部鉴别器。

        最近,视觉transformer(ViTs)[25]在各种计算机视觉任务中表现出了显著的有效性[26],[27],[28],[29],[30],这主要是归功于自注意力 (SA) 机制。尽管取得了这些进展,但vit在低级(low-level)视觉任务中的应用仍未得到充分探索。在最近的文献[31]、[32]、[33]中只引入了一些基于LLIE(微光图像增强)-VIT的策略。Uformer [31] 基于经典的 UNet 架构,其中卷积层被替换为 Transformer 块,同时保持分层编码器-解码器结构和跳过连接。另一方面,Restormer [33] 引入了多 Dconv 头转置注意力 (MDTA) 块,取代了普通的多头自注意力。

        我们提出了一种新的轻量级基于transformer的方法,称为LYT-Net。与现有的基于transformer的方法不同,我们的方法专注于计算效率,同时仍然产生最先进的(SOTA)结果。具体来说,我们首先使用 YUV 颜色空间将色度与亮度分开。色度信息(通道U和V)最初通过专门的通道去噪器(CWD)块进行处理,在保持精细细节的同时减少了噪声。为了降低计算复杂度,亮度通道 Y 经历卷积和池化来提取特征,随后由传统的多头自注意力 (MHSA) 块增强。然后,通过一种新颖的多级SE融合 (MSEF) 块重新组合和处理增强通道。最后,色度通道U和V通道与亮度Y通道连接,并通过最后一组卷积层来生成恢复的图像。

        我们的方法对已建立的LLIE数据集进行了广泛的测试。定性和定量评估都表明我们的方法取得了极具竞争力的结果。图1展示了使用LOL数据集[13]评估的SOTA方法之间性能相对于复杂性的比较分析。可以看出,尽管它的设计轻量级,但我们的方法产生的结果不仅与最近更复杂的深度学习LLIE 技术的结果相当,而且通常效果更好。

 2. 方法

        在图 2 中,我们说明了 LYTNet 的整体架构,它由几层和可拆卸块组成,包括我们的新块——Channel-Wise Denoiser (CWD) 和多阶段 Squeeze & Excite Fusion (MSEF)——以及传统的 ViT 块,Multi-Headed Self-Attention (MHSA)。我们采用双路径方法,将色度和亮度视为单独的实体,以帮助模型更好地处理光照调整和损坏恢复。亮度通道Y经过卷积和池化提取特征,然后由MHSAblock增强。通过CWD块处理色度通道U和V,以减少噪声,同时保留细节。然后,通过MSEF块重新组合和处理增强的色度通道。最后,将色度 U、V 和亮度 Y 通道连接起来并通过最后一组卷积层来生成输出,从而产生高质量、增强的图像。

2.1. CWD模块

        CWD块采用u型网络和MHSA作为瓶颈,集成了卷积和注意力机制。它包括多个具有不同步幅和跳过连接的 conv3×3 层,促进了详细的特征捕获和去噪。

        它由一系列四个 conv3×3 层组成。第一个conv3×3 在特征提取方面步幅为 1。其他三个 conv3×3 层的步长为 2,有助于捕获不同尺度的特征。注意力瓶颈的集成使模型能够捕获长期依赖关系,然后是上采样层和跳过连接来重建并促进空间分辨率的恢复。

        这种方法允许我们在空间维度降低的特征图上应用MHSA,显著提高计算效率。此外,使用基于插值的上采样而不是转置卷积将 CWD 中的参数数量减少了一半以上,同时保留了性能。

2.2. MHSA模块

        在我们简化版的transformer架构中,输入特征F_{in} \in R^{H \times W \times C}首先通过无偏置全连接层线性投影到查询(Q)、键(K)和值(V)分量。线性投影保持原始输入维度。

        接下来,这些投影特征被分成 k 个头:

         其中每个头都以维度 d_k 独立运行。自注意力机制应用于每个头部,定义如下:

        最后,将所有头部的注意力输出连接起来,组合输出通过线性层将其投影回原始嵌入大小。输出标记 X_{out} 被重新整形回原始空间维度以形成输出特征 F_{out } \in \mathbb{R}^{H\times W \times C}

2.3. MSEF模块

        MSEF块增强了F_{in}的空间和通道特征。最初,F_{in}经历层归一化,然后是全局平均池化来捕获全局空间上下文和具有 ReLU 激活的缩减全连接层,产生减少的描述符 S_{reduced},如公式 (4)。然后,该描述符通过另一个具有 Tanh 激活的全连接层扩展到原始维度,从而产生S_{expanded}, 如公式 (5)。

         在融合输出中加入残差连接,生成最终的输出特征图F_{out},如式(6)所示。

2.4. 损失函数

         在我们的方法中,混合损失函数在有效地训练我们的模型方面起着关键作用。混合损失L如式(7)所示,其中α1到α5是用于平衡每个组成损失函数的超参数。

        我们模型中的混合损失结合了几个组件来提高图像质量和感知。平滑 L1 损失 LS 通过基于预测值和真实值之间的差异应用二次或线性惩罚来处理异常值。感知损失 LPerc 通过比较 VGG 提取的特征图来保持特征一致性。直方图损失LHist对齐预测图像和真实图像之间的像素强度分布。PSNR损失LPSNR通过惩罚均方误差来降低噪声,而颜色损失LColor通过最小化通道平均值的差异来确保颜色保真度。最后,多尺度SSIM损失LMS-SSIM通过在多个尺度上评估相似性来保持结构的完整性。总之,这些损失形成了一个综合策略,解决了图像增强的各个方面。

3. 结果与讨论

实现细节:LYT-Net 的实现利用了 TensorFlow 框架。ADAM 优化器 (β1 = 0.9 and β2 = 0.999) 用于超过 1000 个 epoch 的训练。初始学习率设置为 2×10−4,并在余弦退火计划后逐渐衰减到 1 × 10−6,有助于优化收敛并避免局部最小值。混合损失函数的超参数设置为:α1=0.06、α2=0.05、α3=0.5、α4=0.0083 和 α5=0.25。 LYT-Net 在 LOL 数据集的三个版本上进行训练和评估:LOL-v1、LOL-v2-real 和 LOL-v2-synthetic。LOLv1、LOL-v2-real 的相应训练/测试拆分为 485 : 15,LOL-v2-real 为 689 : 100,LOL-v2-synthetic 为 900 : 100。在训练期间,图像对进行随机增强,包括随机裁剪到 256 × 256 和随机翻转/旋转,以防止过度拟合。训练以 1 的批大小进行。评估指标包括 PSNR 和 SSIM 进行性能评估。

PS:官方的Gtihub代码里面同时也包含了Pytorch版本。

定量结果:将所提出的方法与 SOTA LLIE 技术进行比较,如表 I 所示,重点关注两个方面:LOL 数据集(LOLV1、LOL-v2-real、LOL-v2-synthetic)和模型复杂度的定量性能。如表 I 所示,LYT-Net 在 PSNR 和 SSIM 方面在所有版本的 LOL 数据集上始终优于当前的 SOTA 方法。此外,LYTNet 非常高效,只需要 3.49G FLOPS 且仅使用 0.045M 参数,这使得它比其他通常更复杂的 SOTA 方法具有显着优势。唯一的例外是 3DLUT[35],它在复杂性方面与我们的方法相当。然而,LYT-Net 在 PSNR 和 SSIM 中明显优于 3DLUT 方法。这种强大的性能和低复杂度的组合突出了 LYT-Net 的整体有效性。

定性结果: LOL数据集上的LYT-Net与SOTALLIE技术的定性评估如图3所示,LIME[38]上的图4所示。以前的方法,如KiND[17]和Restormer[33],表现出颜色失真问题,如图3所示。此外,几种算法(如MIRNet[20]和SNR-Net[22])往往会产生过度曝光或曝光不足的区域,在增强亮度的同时损害图像对比度。同样,图 4 表明 SRIE [39]、DeHz [40] 和 NPE [41] 导致对比度损失。一般来说,我们的LYT-Net在提高能见度和增强低对比度或光线较差的区域方面非常有效,同时在不引入斑点或伪影的情况下有效地消除噪声。

 4. 消融实验

        消融研究是在LOLV1数据集上进行的,使用PSNR作为定量指标,并评估CWD和MSEF块的影响。在 YUV 分解中,将 CWD 应用于 Y 通道(用作照明图)会导致照明伪影的保留,导致与池化操作和基于插值的上采样相比性能下降,从而平滑照明以获得更好的结果。然而,CWD增强了色度通道(U和V),在不引入噪声的情况下保留细节。此外,MSEF 块始终在所有 CWD 组合中提高性能,PSNR 分别提高了 0.16、0.24 和 0.26 dB,同时将参数计数提高了 546。

5. 结论

         我们引入了LYT-Net,这是一种创新的基于transformer的轻量级模型,用于增强低光照图像。我们的方法利用双路径框架,分别处理色度和亮度,以提高模型管理光照调整和恢复损坏区域的能力。LYT-Net 集成了多层和模块化块,包括两个独特的组件——ChannelWise Denoiser (CWD) 和多阶段 Squeeze & Excite Fusion (MSEF)——以及具有多头自注意力 (MHSA) 的传统视觉transformer (ViT) 块。全面的定性和定量分析表明,LYT-Net 在 PSNR 和 SSIM 方面在所有版本的 LOL 数据集上始终优于 SOTA 方法,同时保持了较高的计算效率。

附:

LYT-Net 网络结构的Pytorch代码,来自官方实现:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

class LayerNormalization(nn.Module):
    def __init__(self, dim):
        super(LayerNormalization, self).__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        # Rearrange the tensor for LayerNorm (B, C, H, W) to (B, H, W, C)
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        # Rearrange back to (B, C, H, W)
        return x.permute(0, 3, 1, 2)

class SEBlock(nn.Module):
    def __init__(self, input_channels, reduction_ratio=16):
        super(SEBlock, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(input_channels, input_channels // reduction_ratio)
        self.fc2 = nn.Linear(input_channels // reduction_ratio, input_channels)
        self._init_weights()

    def forward(self, x):
        batch_size, num_channels, _, _ = x.size()
        y = self.pool(x).reshape(batch_size, num_channels)
        y = F.relu(self.fc1(y))
        y = torch.tanh(self.fc2(y))
        y = y.reshape(batch_size, num_channels, 1, 1)
        return x * y
    
    def _init_weights(self):
        init.kaiming_uniform_(self.fc1.weight, a=0, mode='fan_in', nonlinearity='relu')
        init.kaiming_uniform_(self.fc2.weight, a=0, mode='fan_in', nonlinearity='relu')
        init.constant_(self.fc1.bias, 0)
        init.constant_(self.fc2.bias, 0)

class MSEFBlock(nn.Module):
    def __init__(self, filters):
        super(MSEFBlock, self).__init__()
        self.layer_norm = LayerNormalization(filters)
        self.depthwise_conv = nn.Conv2d(filters, filters, kernel_size=3, padding=1, groups=filters)
        self.se_attn = SEBlock(filters)
        self._init_weights()

    def forward(self, x):
        x_norm = self.layer_norm(x)
        x1 = self.depthwise_conv(x_norm)
        x2 = self.se_attn(x_norm)
        x_fused = x1 * x2
        x_out = x_fused + x
        return x_out
    
    def _init_weights(self):
        init.kaiming_uniform_(self.depthwise_conv.weight, a=0, mode='fan_in', nonlinearity='relu')
        init.constant_(self.depthwise_conv.bias, 0)

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_size, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        assert embed_size % num_heads == 0
        self.head_dim = embed_size // num_heads
        self.query_dense = nn.Linear(embed_size, embed_size)
        self.key_dense = nn.Linear(embed_size, embed_size)
        self.value_dense = nn.Linear(embed_size, embed_size)
        self.combine_heads = nn.Linear(embed_size, embed_size)
        self._init_weights()

    def split_heads(self, x, batch_size):
        x = x.reshape(batch_size, -1, self.num_heads, self.head_dim)
        return x.permute(0, 2, 1, 3)

    def forward(self, x):
        batch_size, _, height, width = x.size()
        x = x.reshape(batch_size, height * width, -1)

        query = self.split_heads(self.query_dense(x), batch_size)
        key = self.split_heads(self.key_dense(x), batch_size)
        value = self.split_heads(self.value_dense(x), batch_size)
        
        attention_weights = F.softmax(torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5), dim=-1)
        attention = torch.matmul(attention_weights, value)
        attention = attention.permute(0, 2, 1, 3).contiguous().reshape(batch_size, -1, self.embed_size)
        
        output = self.combine_heads(attention)
        
        return output.reshape(batch_size, height, width, self.embed_size).permute(0, 3, 1, 2)

    def _init_weights(self):
        init.xavier_uniform_(self.query_dense.weight)
        init.xavier_uniform_(self.key_dense.weight)
        init.xavier_uniform_(self.value_dense.weight)
        init.xavier_uniform_(self.combine_heads.weight)
        init.constant_(self.query_dense.bias, 0)
        init.constant_(self.key_dense.bias, 0)
        init.constant_(self.value_dense.bias, 0)
        init.constant_(self.combine_heads.bias, 0)

class Denoiser(nn.Module):
    def __init__(self, num_filters, kernel_size=3, activation='relu'):
        super(Denoiser, self).__init__()
        self.conv1 = nn.Conv2d(1, num_filters, kernel_size=kernel_size, padding=1)
        self.conv2 = nn.Conv2d(num_filters, num_filters, kernel_size=kernel_size, stride=2, padding=1)
        self.conv3 = nn.Conv2d(num_filters, num_filters, kernel_size=kernel_size, stride=2, padding=1)
        self.conv4 = nn.Conv2d(num_filters, num_filters, kernel_size=kernel_size, stride=2, padding=1)
        self.bottleneck = MultiHeadSelfAttention(embed_size=num_filters, num_heads=4)
        self.up2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.up3 = nn.Upsample(scale_factor=2, mode='nearest')
        self.up4 = nn.Upsample(scale_factor=2, mode='nearest')
        self.output_layer = nn.Conv2d(1, 1, kernel_size=kernel_size, padding=1)
        self.res_layer = nn.Conv2d(num_filters, 1, kernel_size=kernel_size, padding=1)
        self.activation = getattr(F, activation)
        self._init_weights()

    def forward(self, x):
        x1 = self.activation(self.conv1(x))
        x2 = self.activation(self.conv2(x1))
        x3 = self.activation(self.conv3(x2))
        x4 = self.activation(self.conv4(x3))
        x = self.bottleneck(x4)
        x = self.up4(x)
        x = self.up3(x + x3)
        x = self.up2(x + x2)
        x = x + x1
        x = self.res_layer(x)
        return torch.tanh(self.output_layer(x + x))
    
    def _init_weights(self):
        for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.output_layer, self.res_layer]:
            init.kaiming_uniform_(layer.weight, a=0, mode='fan_in', nonlinearity='relu')
            if layer.bias is not None:
                init.constant_(layer.bias, 0)

class LYT(nn.Module):
    def __init__(self, filters=32):
        super(LYT, self).__init__()
        self.process_y = self._create_processing_layers(filters)
        self.process_cb = self._create_processing_layers(filters)
        self.process_cr = self._create_processing_layers(filters)

        self.denoiser_cb = Denoiser(filters // 2)
        self.denoiser_cr = Denoiser(filters // 2)
        self.lum_pool = nn.MaxPool2d(8)
        self.lum_mhsa = MultiHeadSelfAttention(embed_size=filters, num_heads=4)
        self.lum_up = nn.Upsample(scale_factor=8, mode='nearest')
        self.lum_conv = nn.Conv2d(filters, filters, kernel_size=1, padding=0)
        self.ref_conv = nn.Conv2d(filters * 2, filters, kernel_size=1, padding=0)
        self.msef = MSEFBlock(filters)
        self.recombine = nn.Conv2d(filters * 2, filters, kernel_size=3, padding=1)
        self.final_adjustments = nn.Conv2d(filters, 3, kernel_size=3, padding=1)
        self._init_weights()

    def _create_processing_layers(self, filters):
        return nn.Sequential(
            nn.Conv2d(1, filters, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
    
    def _rgb_to_ycbcr(self, image):
        r, g, b = image[:, 0, :, :], image[:, 1, :, :], image[:, 2, :, :]
    
        y = 0.299 * r + 0.587 * g + 0.114 * b
        u = -0.14713 * r - 0.28886 * g + 0.436 * b + 0.5
        v = 0.615 * r - 0.51499 * g - 0.10001 * b + 0.5
        
        yuv = torch.stack((y, u, v), dim=1)
        return yuv

    def forward(self, inputs):
        ycbcr = self._rgb_to_ycbcr(inputs)
        y, cb, cr = torch.split(ycbcr, 1, dim=1)
        cb = self.denoiser_cb(cb) + cb
        cr = self.denoiser_cr(cr) + cr

        y_processed = self.process_y(y)
        cb_processed = self.process_cb(cb)
        cr_processed = self.process_cr(cr)

        ref = torch.cat([cb_processed, cr_processed], dim=1)
        lum = y_processed
        lum_1 = self.lum_pool(lum)
        lum_1 = self.lum_mhsa(lum_1)
        lum_1 = self.lum_up(lum_1)
        lum = lum + lum_1

        ref = self.ref_conv(ref)
        shortcut = ref
        ref = ref + 0.2 * self.lum_conv(lum)
        ref = self.msef(ref)
        ref = ref + shortcut

        recombined = self.recombine(torch.cat([ref, lum], dim=1))
        output = self.final_adjustments(recombined)
        return torch.sigmoid(output)
    
    def _init_weights(self):
        for module in self.children():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                init.kaiming_uniform_(module.weight, a=0, mode='fan_in', nonlinearity='relu')
                if module.bias is not None:
                    init.constant_(module.bias, 0)
                    

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐