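Building on the earlier study of ResNet50V2 and DenseNet121, this week explores whether the two architectures can be fused:
build a new model framework, ResDenseNet, that combines the strengths of both, and validate it on the same image classification task.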

I. Background

1. Summary (core idea of the fusion)

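Item              | ResNet50V2                   | DenseNet121                     | ResDenseNet (this post)
------------------|------------------------------|---------------------------------|----------------------------------------------
Core idea         | Residual connection (add)    | Dense connection (concat)       | Internal concat + external add
Parameters        | ~25.6M                       | ~7.0M                           | ~2.0M
Feature reuse     | Weak (cross-layer summation) | Strong (keeps all features)     | Strong (dense reuse + compression/refinement)
Gradient flow     | Direct residual path         | concat path                     | Direct residual path + dense auxiliary paths
Activation order  | Pre-Activation               | Pre-Activation (BN->ReLU->Conv) | Pre-Activation
Data augmentation |                              | Flip + color jitter             | Flip + color jitter
LR schedule       | Fixed lr                     | Cosine annealing                | Cosine annealing
Regularization    |                              | Label smoothing                 | Label smoothing + weight_decay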

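ResDenseNet combines the core strengths of ResNet and DenseNet:

  1. The Dense Residual Block reuses features densely while staying parameter-efficient
  2. The residual connection keeps an identity path open for gradients through the deep network
  3. The Pre-Activation design (BN -> ReLU -> Conv) normalizes the input of every convolution
  4. Overall, the architecture balances parameter count, feature utilization, and gradient stability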
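
The "internal concat + external add" idea can be written compactly. The notation below is introduced here purely for illustration and is not taken from the post: H_i stands for the i-th dense layer, G for the compression step (BN -> ReLU -> 1x1 Conv), square brackets for channel-wise concatenation, and L for the number of dense layers in one block.

```latex
\begin{aligned}
f_0 &= x, \\
f_i &= \bigl[\, f_{i-1},\; H_i(f_{i-1}) \,\bigr], \qquad i = 1, \dots, L, \\
y   &= x + G(f_L), \\
\frac{\partial y}{\partial x} &= I + \frac{\partial G(f_L)}{\partial x}.
\end{aligned}
```

If x has C channels and every H_i emits k new channels, then f_i has C + i·k channels and G maps C + L·k channels back to C. The identity term I in the last line is the residual shortcut: whatever happens inside the dense path, one gradient route reaches x unchanged.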

2. Model architecture diagram

Run the code below to draw the overall ResDenseNet architecture and the detailed structure of its core module, the Dense Residual Block.

import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch
import matplotlib.patches as mpatches

plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 11))
fig.suptitle('ResDenseNet Architectural Design', fontsize=18, fontweight='bold', y=0.98)

def draw_box(ax, cx, cy, w, h, text, color, fs=8.5):
    box = FancyBboxPatch((cx - w/2, cy - h/2), w, h,
                          boxstyle="round,pad=0.1",
                          facecolor=color, edgecolor='#333', linewidth=1.5)
    ax.add_patch(box)
    ax.text(cx, cy, text, ha='center', va='center', fontsize=fs, fontweight='bold')

# ==================== Left panel: overall architecture ====================
ax1.set_xlim(0, 10); ax1.set_ylim(2, 22)
ax1.axis('off')
ax1.set_title('Architecture', fontsize=14, fontweight='bold')

layers = [
    (21.0, 6.0, 0.9, 'Input (3 x 224 x 224)', '#B3D9FF'),
    (19.5, 6.5, 1.2, 'Stem: Conv7x7 -> BN -> ReLU -> MaxPool\n-> 64 x 56 x 56', '#FFF2CC'),
    # Each stage is a single DRB; the count is the number of dense layers inside it
    (17.8, 6.0, 0.9, 'Stage 1:  DRB (3 layers)    (64 ch, 56x56)', '#C8E6C9'),
    (16.3, 6.5, 1.0, 'Transition: Conv1x1 + AvgPool -> 128 x 28 x 28', '#FFF2CC'),
    (14.8, 6.0, 0.9, 'Stage 2:  DRB (4 layers)    (128 ch, 28x28)', '#C8E6C9'),
    (13.3, 6.5, 1.0, 'Transition: Conv1x1 + AvgPool -> 256 x 14 x 14', '#FFF2CC'),
    (11.8, 6.0, 0.9, 'Stage 3:  DRB (6 layers)    (256 ch, 14x14)', '#C8E6C9'),
    (10.3, 6.5, 1.0, 'Transition: Conv1x1 + AvgPool -> 512 x 7 x 7', '#FFF2CC'),
    (8.8,  6.0, 0.9, 'Stage 4:  DRB (4 layers)    (512 ch, 7x7)', '#C8E6C9'),
    (7.1,  5.5, 0.9, 'BN -> ReLU -> GAP -> FC(3)', '#FFCDD2'),
]

for (y, w, h, text, color) in layers:
    draw_box(ax1, 5, y, w, h, text, color)

for i in range(len(layers)-1):
    y_from = layers[i][0] - layers[i][2]/2
    y_to   = layers[i+1][0] + layers[i+1][2]/2
    ax1.annotate('', xy=(5, y_to), xytext=(5, y_from),
                 arrowprops=dict(arrowstyle='->', color='#555', lw=1.5))

ax1.legend(handles=[
    mpatches.Patch(color='#C8E6C9', label='Dense Residual Block (DRB)'),
    mpatches.Patch(color='#FFF2CC', label='Transition Layer'),
    mpatches.Patch(color='#FFCDD2', label='Classification Head'),
], loc='lower center', fontsize=9, framealpha=0.9)

# ==================== Right panel: DRB detailed structure ====================
ax2.set_xlim(0, 10); ax2.set_ylim(4, 22)
ax2.axis('off')
ax2.set_title('Dense Residual Block (DRB) Detailed Structure', fontsize=14, fontweight='bold')

draw_box(ax2, 3.5, 21, 4.5, 0.8, 'Input (C, H, W)', '#B3D9FF', fs=9)

for i, y in enumerate([19.2, 17.4, 15.6]):
    draw_box(ax2, 3.5, y, 4.5, 1.0,
             f'DenseLayer {i+1}\nBN->ReLU->1x1->BN->ReLU->3x3', '#FFF9C4', fs=8)
    ax2.text(6.3, y, 'concat', fontsize=8, color='#1565C0',
             fontstyle='italic', va='center', fontweight='bold')

ax2.text(3.5, 14.4, '...', ha='center', fontsize=16, fontweight='bold')

draw_box(ax2, 3.5, 13.2, 4.5, 1.0,
         'DenseLayer L\nBN->ReLU->1x1->BN->ReLU->3x3', '#FFF9C4', fs=8)
ax2.text(6.3, 13.2, 'concat', fontsize=8, color='#1565C0',
         fontstyle='italic', va='center', fontweight='bold')

draw_box(ax2, 3.5, 11.4, 4.0, 0.8, 'Conv 1x1 Compress\n(C+L*k -> C)', '#FFCDD2', fs=8)

circle = plt.Circle((3.5, 9.7), 0.4, facecolor='#C8E6C9', edgecolor='#333', lw=2)
ax2.add_patch(circle)
ax2.text(3.5, 9.7, '+', ha='center', va='center', fontsize=16, fontweight='bold')

draw_box(ax2, 3.5, 8.2, 4.5, 0.8, 'Output (C, H, W)', '#B3D9FF', fs=9)

# Main-path arrows
arrow_pairs = [
    (21-0.4, 19.7), (18.7, 17.9), (16.9, 16.1),
    (15.1, 14.7), (14.1, 13.7), (12.7, 11.8),
    (11.0, 10.1), (9.3, 8.6)
]
for (yf, yt) in arrow_pairs:
    ax2.annotate('', xy=(3.5, yt), xytext=(3.5, yf),
                 arrowprops=dict(arrowstyle='->', color='#555', lw=1.3))

# Residual shortcut arrows (red, dashed)
ax2.annotate('', xy=(7.8, 9.7), xytext=(7.8, 21),
             arrowprops=dict(arrowstyle='->', color='#D32F2F', lw=2.5, linestyle='--'))
ax2.annotate('', xy=(7.8, 21), xytext=(5.75, 21),
             arrowprops=dict(arrowstyle='->', color='#D32F2F', lw=2.5, linestyle='--'))
ax2.annotate('', xy=(4.4, 9.7), xytext=(7.8, 9.7),
             arrowprops=dict(arrowstyle='->', color='#D32F2F', lw=2.5, linestyle='--'))

ax2.text(9.0, 15.5, 'Residual\nShortcut\n(Identity)', ha='center', va='center',
         fontsize=10, color='#D32F2F', fontweight='bold',
         bbox=dict(boxstyle='round,pad=0.3', facecolor='white',
                  edgecolor='#D32F2F', alpha=0.9))

ax2.text(5, 5.5,
         'DenseNet feature reuse (concat) + ResNet gradient highway (add)\n'
         'Internal dense -> rich features | External residual -> gradient flow',
         ha='center', va='center', fontsize=10, color='#1B5E20', fontweight='bold',
         bbox=dict(boxstyle='round,pad=0.5', facecolor='#E8F5E9',
                  edgecolor='#2E7D32', lw=2, alpha=0.9))

plt.tight_layout()
plt.show()

II. Implementation

1. Preparation

1.1 Set up the GPU

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
import os, copy, warnings
import numpy as np
warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')

1.2 Load the data (augmentation strategy)

data_dir = './data/day01'

train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

test_transforms = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

total_data = datasets.ImageFolder(data_dir, transform=train_transforms)
total_data
Dataset ImageFolder
    Number of datapoints: 1661
    Root location: ./data/day01
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
               RandomHorizontalFlip(p=0.5)
               ColorJitter(brightness=(0.8, 1.2), contrast=(0.8, 1.2), saturation=(0.9, 1.1), hue=None)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )
total_data.class_to_idx
{'0Normal': 0, '2Mild': 1, '4Severe': 2}

1.3 Split the dataset

train_size = int(0.8 * len(total_data))
test_size  = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])

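# Evaluate the test split without augmentation. Each Subset keeps its own `.dataset`
# reference, so the two assignments below do not affect each other.
test_dataset.dataset = datasets.ImageFolder(data_dir, transform=test_transforms)
train_dataset.dataset = datasets.ImageFolder(data_dir, transform=train_transforms)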

batch_size = 8

train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                       shuffle=True, num_workers=0)
test_dl  = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                      num_workers=0)

for X, y in test_dl:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break
Shape of X [N, C, H, W]:  torch.Size([8, 3, 224, 224])
Shape of y:  torch.Size([8]) torch.int64
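
As a quick sanity check on the split, the sizes and batch counts can be worked out by hand. This is a small sketch using only the numbers reported above (1661 images, an 80/20 split, batch size 8); it assumes the DataLoader keeps the final partial batch, which is the default behaviour.

```python
import math

num_images = 1661   # "Number of datapoints" reported by ImageFolder above
batch_size = 8

train_size = int(0.8 * num_images)      # same expression as in the split code
test_size = num_images - train_size

# With the final partial batch kept, the batch count is a ceiling division
train_batches = math.ceil(train_size / batch_size)
test_batches = math.ceil(test_size / batch_size)
last_test_batch = test_size - (test_batches - 1) * batch_size

print(train_size, test_size)            # 1328 333
print(train_batches, test_batches)      # 166 42
print(last_test_batch)                  # 5
```

The training split divides evenly into batches of 8, while the last test batch holds only 5 images.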

2. Build the ResDenseNet model

ResDenseNet design rationale

The Dense Residual Block (DRB) is the core new module; it combines the strengths of both networks:

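DenseNet strengths:                     ResNet strengths:
  concat with all preceding layers        cross-layer residual connection
  → maximal feature reuse                 → unobstructed gradient flow
  → parameter-efficient                   → can train very deep networks
       |                                      |
       +---------------  fusion  -------------+
                        |
                Dense Residual Block
        Inside:  dense connections (concat)  ← from DenseNet
        Outside: residual connection (add)   ← from ResNet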

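Design details:

  • Each DRB contains L DenseLayers; every layer outputs growth_rate new feature maps and concatenates them onto its input
  • After all DenseLayers have run, a 1×1 convolution compresses the result back to the original channel count
  • The compressed result is then added to the block input (residual connection)
  • Transition layers handle downsampling and channel expansion
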
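
The channel bookkeeping implied by the list above can be traced with plain arithmetic. The helper below is a hypothetical name introduced for this sketch (it is not part of the model code) and assumes the post's defaults of growth_rate=32 and block_config=(3, 4, 6, 4).

```python
def drb_channel_trace(in_channels, num_layers, growth_rate=32):
    """Return (input width of each dense layer, width after concat, width after compression)."""
    # Dense layer i sees the block input plus the i * growth_rate channels added so far
    layer_inputs = [in_channels + i * growth_rate for i in range(num_layers)]
    concat_width = in_channels + num_layers * growth_rate
    # The 1x1 compression maps back to in_channels so the residual add is shape-compatible
    return layer_inputs, concat_width, in_channels

for channels, layers in zip((64, 128, 256, 512), (3, 4, 6, 4)):
    print(channels, layers, drb_channel_trace(channels, layers))
```

For the first stage this gives dense-layer inputs of 64, 96 and 128 channels, a concatenated width of 160, and a compressed width of 64.
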
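class DenseLayer(nn.Module):
    """Dense layer: Pre-Activation bottleneck (BN->ReLU->1x1Conv->BN->ReLU->3x3Conv)

    Follows the Pre-Activation design of ResNetV2: BN and ReLU come before each convolution.
    The bottleneck first maps the input to bn_size * growth_rate channels with a 1x1 convolution,
    then produces growth_rate new feature maps with a 3x3 convolution.
    """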
    def __init__(self, in_channels, growth_rate, bn_size=4):
        super(DenseLayer, self).__init__()
        # Pre-activation: BN -> ReLU -> Conv
        self.bn1   = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, bn_size * growth_rate,
                               kernel_size=1, stride=1, bias=False)
        self.bn2   = nn.BatchNorm2d(bn_size * growth_rate)
        self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate,
                               kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(out)))
        return torch.cat([x, out], dim=1)  # DenseNet style: concatenate along channels


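class DenseResidualBlock(nn.Module):
    """Dense Residual Block (DRB): dense connections inside, residual connection outside

    The core module of ResDenseNet:
    1. Inside: several DenseLayers concatenated densely -> feature reuse (DenseNet)
    2. Outside: the compressed output is added to the input -> direct gradient path (ResNet)
    """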
    def __init__(self, in_channels, num_layers, growth_rate, bn_size=4):
        super(DenseResidualBlock, self).__init__()

        # Dense layers
        self.dense_layers = nn.ModuleList()
        for i in range(num_layers):
            self.dense_layers.append(
                DenseLayer(in_channels + i * growth_rate, growth_rate, bn_size)
            )

        # Compression: map the concatenated features back to the original channel count
        total_channels = in_channels + num_layers * growth_rate
        self.compress = nn.Sequential(
            nn.BatchNorm2d(total_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(total_channels, in_channels, kernel_size=1, stride=1, bias=False)
        )

    def forward(self, x):
        # Dense path
        features = x
        for layer in self.dense_layers:
            features = layer(features)
        # Compression + residual connection
        out = self.compress(features) + x   # ResNet style: residual add
        return out

class Transition(nn.Sequential):
    """Transition layer: BN->ReLU->1x1Conv->AvgPool"""
    def __init__(self, in_channels, out_channels):
        super(Transition, self).__init__()
        self.add_module('bn',   nn.BatchNorm2d(in_channels))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(in_channels, out_channels,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))

class ResDenseNet(nn.Module):
    """ResDenseNet: combines ResNet residual connections with DenseNet dense connections

    Architecture: Stem -> [DRB -> Transition] x 3 -> DRB -> BN -> ReLU -> GAP -> FC

    block_config = (3, 4, 6, 4)  lighter than DenseNet121's (6, 12, 24, 16)
    growth_rate  = 32            each dense layer produces 32 new feature maps
    """
    def __init__(self, growth_rate=32, block_config=(3, 4, 6, 4),
                 num_init_features=64, bn_size=4, num_classes=1000):
        super(ResDenseNet, self).__init__()

        # ===== Stem =====
        self.stem = nn.Sequential(
            nn.Conv2d(3, num_init_features, kernel_size=7, stride=2,
                      padding=3, bias=False),
            nn.BatchNorm2d(num_init_features),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        # ===== Stages =====
        self.stages = nn.ModuleList()
        self.transitions = nn.ModuleList()

        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            # Dense Residual Block
            self.stages.append(
                DenseResidualBlock(num_features, num_layers, growth_rate, bn_size)
            )
            # Transition (none after the last stage)
            if i != len(block_config) - 1:
                next_features = num_features * 2
                self.transitions.append(
                    Transition(num_features, next_features)
                )
                num_features = next_features

        # ===== Final BN + Classifier =====
        self.final_bn   = nn.BatchNorm2d(num_features)
        self.classifier = nn.Linear(num_features, num_classes)

        # ===== Weight initialization (Kaiming) =====
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.stem(x)
        for i, stage in enumerate(self.stages):
            x = stage(x)
            if i < len(self.transitions):
                x = self.transitions[i](x)
        x = F.relu(self.final_bn(x), inplace=True)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

model = ResDenseNet(num_classes=3).to(device)
model
ResDenseNet(
  (stem): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (stages): ModuleList(
    (0): DenseResidualBlock(
      (dense_layers): ModuleList(
        (0): DenseLayer(
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): DenseLayer(
          (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): DenseLayer(
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (compress): Sequential(
        (0): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU(inplace=True)
        (2): Conv2d(160, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (1): DenseResidualBlock(
      (dense_layers): ModuleList(
        (0): DenseLayer(
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): DenseLayer(
          (bn1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): DenseLayer(
          (bn1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (3): DenseLayer(
          (bn1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (compress): Sequential(
        (0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU(inplace=True)
        (2): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (2): DenseResidualBlock(
      (dense_layers): ModuleList(
        (0): DenseLayer(
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): DenseLayer(
          (bn1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): DenseLayer(
          (bn1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (3): DenseLayer(
          (bn1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (4): DenseLayer(
          (bn1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (5): DenseLayer(
          (bn1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (compress): Sequential(
        (0): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU(inplace=True)
        (2): Conv2d(448, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (3): DenseResidualBlock(
      (dense_layers): ModuleList(
        (0): DenseLayer(
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): DenseLayer(
          (bn1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): DenseLayer(
          (bn1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (3): DenseLayer(
          (bn1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (compress): Sequential(
        (0): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU(inplace=True)
        (2): Conv2d(640, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
  )
  (transitions): ModuleList(
    (0): Transition(
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    )
    (1): Transition(
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    )
    (2): Transition(
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    )
  )
  (final_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (classifier): Linear(in_features=512, out_features=3, bias=True)
)
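
The spatial sizes quoted for each stage follow from the kernel, stride and padding values printed above. The sketch below applies the standard floor-mode output-size formula to a 224 x 224 input; the helper name and the layer list are written out by hand for illustration and are not read from the model object.

```python
def out_size(size, kernel, stride, padding=0):
    """Output length along one spatial axis for a conv or pooling layer (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1

# (name, kernel, stride, padding), copied from the printed modules above
downsampling_layers = [
    ("stem conv 7x7", 7, 2, 3),
    ("stem maxpool", 3, 2, 1),
    ("transition 0 avgpool", 2, 2, 0),
    ("transition 1 avgpool", 2, 2, 0),
    ("transition 2 avgpool", 2, 2, 0),
]

current = 224
trace = [("input", current)]
for name, kernel, stride, padding in downsampling_layers:
    current = out_size(current, kernel, stride, padding)
    trace.append((name, current))

for name, side in trace:
    print(f"{name:22s} {side} x {side}")
```

The dense layers and compression convolutions do not change the spatial size (1x1 kernels, or 3x3 with padding 1 and stride 1), so only the five layers listed here shrink the feature map, ending at 7 x 7.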

2.1 Inspect the model

import torchsummary as summary
summary.summary(model, (3, 224, 224))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
       BatchNorm2d-5           [-1, 64, 56, 56]             128
            Conv2d-6          [-1, 128, 56, 56]           8,192
       BatchNorm2d-7          [-1, 128, 56, 56]             256
            Conv2d-8           [-1, 32, 56, 56]          36,864
        DenseLayer-9           [-1, 96, 56, 56]               0
      BatchNorm2d-10           [-1, 96, 56, 56]             192
           Conv2d-11          [-1, 128, 56, 56]          12,288
      BatchNorm2d-12          [-1, 128, 56, 56]             256
           Conv2d-13           [-1, 32, 56, 56]          36,864
       DenseLayer-14          [-1, 128, 56, 56]               0
      BatchNorm2d-15          [-1, 128, 56, 56]             256
           Conv2d-16          [-1, 128, 56, 56]          16,384
      BatchNorm2d-17          [-1, 128, 56, 56]             256
           Conv2d-18           [-1, 32, 56, 56]          36,864
       DenseLayer-19          [-1, 160, 56, 56]               0
      BatchNorm2d-20          [-1, 160, 56, 56]             320
             ReLU-21          [-1, 160, 56, 56]               0
           Conv2d-22           [-1, 64, 56, 56]          10,240
DenseResidualBlock-23           [-1, 64, 56, 56]               0
      BatchNorm2d-24           [-1, 64, 56, 56]             128
             ReLU-25           [-1, 64, 56, 56]               0
           Conv2d-26          [-1, 128, 56, 56]           8,192
        AvgPool2d-27          [-1, 128, 28, 28]               0
      BatchNorm2d-28          [-1, 128, 28, 28]             256
           Conv2d-29          [-1, 128, 28, 28]          16,384
      BatchNorm2d-30          [-1, 128, 28, 28]             256
           Conv2d-31           [-1, 32, 28, 28]          36,864
       DenseLayer-32          [-1, 160, 28, 28]               0
      BatchNorm2d-33          [-1, 160, 28, 28]             320
           Conv2d-34          [-1, 128, 28, 28]          20,480
      BatchNorm2d-35          [-1, 128, 28, 28]             256
           Conv2d-36           [-1, 32, 28, 28]          36,864
       DenseLayer-37          [-1, 192, 28, 28]               0
      BatchNorm2d-38          [-1, 192, 28, 28]             384
           Conv2d-39          [-1, 128, 28, 28]          24,576
      BatchNorm2d-40          [-1, 128, 28, 28]             256
           Conv2d-41           [-1, 32, 28, 28]          36,864
       DenseLayer-42          [-1, 224, 28, 28]               0
      BatchNorm2d-43          [-1, 224, 28, 28]             448
           Conv2d-44          [-1, 128, 28, 28]          28,672
      BatchNorm2d-45          [-1, 128, 28, 28]             256
           Conv2d-46           [-1, 32, 28, 28]          36,864
       DenseLayer-47          [-1, 256, 28, 28]               0
      BatchNorm2d-48          [-1, 256, 28, 28]             512
             ReLU-49          [-1, 256, 28, 28]               0
           Conv2d-50          [-1, 128, 28, 28]          32,768
DenseResidualBlock-51          [-1, 128, 28, 28]               0
      BatchNorm2d-52          [-1, 128, 28, 28]             256
             ReLU-53          [-1, 128, 28, 28]               0
           Conv2d-54          [-1, 256, 28, 28]          32,768
        AvgPool2d-55          [-1, 256, 14, 14]               0
      BatchNorm2d-56          [-1, 256, 14, 14]             512
           Conv2d-57          [-1, 128, 14, 14]          32,768
      BatchNorm2d-58          [-1, 128, 14, 14]             256
           Conv2d-59           [-1, 32, 14, 14]          36,864
       DenseLayer-60          [-1, 288, 14, 14]               0
      BatchNorm2d-61          [-1, 288, 14, 14]             576
           Conv2d-62          [-1, 128, 14, 14]          36,864
      BatchNorm2d-63          [-1, 128, 14, 14]             256
           Conv2d-64           [-1, 32, 14, 14]          36,864
       DenseLayer-65          [-1, 320, 14, 14]               0
      BatchNorm2d-66          [-1, 320, 14, 14]             640
           Conv2d-67          [-1, 128, 14, 14]          40,960
      BatchNorm2d-68          [-1, 128, 14, 14]             256
           Conv2d-69           [-1, 32, 14, 14]          36,864
       DenseLayer-70          [-1, 352, 14, 14]               0
      BatchNorm2d-71          [-1, 352, 14, 14]             704
           Conv2d-72          [-1, 128, 14, 14]          45,056
      BatchNorm2d-73          [-1, 128, 14, 14]             256
           Conv2d-74           [-1, 32, 14, 14]          36,864
       DenseLayer-75          [-1, 384, 14, 14]               0
      BatchNorm2d-76          [-1, 384, 14, 14]             768
           Conv2d-77          [-1, 128, 14, 14]          49,152
      BatchNorm2d-78          [-1, 128, 14, 14]             256
           Conv2d-79           [-1, 32, 14, 14]          36,864
       DenseLayer-80          [-1, 416, 14, 14]               0
      BatchNorm2d-81          [-1, 416, 14, 14]             832
           Conv2d-82          [-1, 128, 14, 14]          53,248
      BatchNorm2d-83          [-1, 128, 14, 14]             256
           Conv2d-84           [-1, 32, 14, 14]          36,864
       DenseLayer-85          [-1, 448, 14, 14]               0
      BatchNorm2d-86          [-1, 448, 14, 14]             896
             ReLU-87          [-1, 448, 14, 14]               0
           Conv2d-88          [-1, 256, 14, 14]         114,688
DenseResidualBlock-89          [-1, 256, 14, 14]               0
      BatchNorm2d-90          [-1, 256, 14, 14]             512
             ReLU-91          [-1, 256, 14, 14]               0
           Conv2d-92          [-1, 512, 14, 14]         131,072
        AvgPool2d-93            [-1, 512, 7, 7]               0
      BatchNorm2d-94            [-1, 512, 7, 7]           1,024
           Conv2d-95            [-1, 128, 7, 7]          65,536
      BatchNorm2d-96            [-1, 128, 7, 7]             256
           Conv2d-97             [-1, 32, 7, 7]          36,864
       DenseLayer-98            [-1, 544, 7, 7]               0
      BatchNorm2d-99            [-1, 544, 7, 7]           1,088
          Conv2d-100            [-1, 128, 7, 7]          69,632
     BatchNorm2d-101            [-1, 128, 7, 7]             256
          Conv2d-102             [-1, 32, 7, 7]          36,864
      DenseLayer-103            [-1, 576, 7, 7]               0
     BatchNorm2d-104            [-1, 576, 7, 7]           1,152
          Conv2d-105            [-1, 128, 7, 7]          73,728
     BatchNorm2d-106            [-1, 128, 7, 7]             256
          Conv2d-107             [-1, 32, 7, 7]          36,864
      DenseLayer-108            [-1, 608, 7, 7]               0
     BatchNorm2d-109            [-1, 608, 7, 7]           1,216
          Conv2d-110            [-1, 128, 7, 7]          77,824
     BatchNorm2d-111            [-1, 128, 7, 7]             256
          Conv2d-112             [-1, 32, 7, 7]          36,864
      DenseLayer-113            [-1, 640, 7, 7]               0
     BatchNorm2d-114            [-1, 640, 7, 7]           1,280
            ReLU-115            [-1, 640, 7, 7]               0
          Conv2d-116            [-1, 512, 7, 7]         327,680
DenseResidualBlock-117            [-1, 512, 7, 7]               0
     BatchNorm2d-118            [-1, 512, 7, 7]           1,024
          Linear-119                    [-1, 3]           1,539
================================================================
Total params: 1,986,691
Trainable params: 1,986,691
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 114.32
Params size (MB): 7.58
Estimated Total Size (MB): 122.47
----------------------------------------------------------------
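
The reported total can be cross-checked without PyTorch by counting parameters from the layer definitions. This is a sketch with helper names invented for the purpose; it assumes what the model code states: bias-free convolutions, two learnable parameters per BatchNorm channel, bn_size=4, growth_rate=32, and num_classes=3.

```python
def conv_params(c_in, c_out, k):
    return c_in * c_out * k * k              # every Conv2d in the model uses bias=False

def bn_params(channels):
    return 2 * channels                      # learnable weight and bias per channel

def dense_layer_params(c_in, growth=32, bn_size=4):
    mid = bn_size * growth
    return (bn_params(c_in) + conv_params(c_in, mid, 1)
            + bn_params(mid) + conv_params(mid, growth, 3))

def drb_params(channels, num_layers, growth=32):
    dense = sum(dense_layer_params(channels + i * growth, growth)
                for i in range(num_layers))
    total = channels + num_layers * growth   # width after all concatenations
    return dense + bn_params(total) + conv_params(total, channels, 1)

def resdensenet_params(block_config=(3, 4, 6, 4), init=64, growth=32, num_classes=3):
    count = conv_params(3, init, 7) + bn_params(init)              # stem
    channels = init
    for i, num_layers in enumerate(block_config):
        count += drb_params(channels, num_layers, growth)
        if i != len(block_config) - 1:                             # transition doubles the width
            count += bn_params(channels) + conv_params(channels, 2 * channels, 1)
            channels *= 2
    count += bn_params(channels)                                   # final BN
    count += channels * num_classes + num_classes                  # linear classifier
    return count

total = resdensenet_params()
print(total)                              # 1986691
print(round(total * 4 / 1024 ** 2, 2))    # 7.58 (MB at 4 bytes per float32 parameter)
```

Both numbers agree with the "Total params" and "Params size (MB)" lines in the summary above.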

3. Train the model

3.1 Training function

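def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch and return (accuracy, mean batch loss)."""
    size = len(dataloader.dataset)   # number of training samples
    num_batches = len(dataloader)    # number of batches per epoch

    train_loss, train_acc = 0, 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # Forward pass
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate the number of correct predictions and the batch loss
        train_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc  /= size          # fraction of correctly classified samples
    train_loss /= num_batches   # mean of the per-batch losses

    return train_acc, train_loss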
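
Note that the training function divides the accuracy by the number of samples but the loss by the number of batches. When the last batch is smaller than the others, the mean of the batch means differs slightly from the per-sample mean. A minimal sketch with made-up numbers (the batch sizes and losses below are hypothetical):

```python
# Hypothetical epoch: (batch_size, mean loss of that batch)
batches = [(4, 0.5), (2, 1.0)]

# What train() reports: the plain average of the per-batch means
mean_of_batch_means = sum(loss for _, loss in batches) / len(batches)

# Per-sample average: weight every batch mean by its batch size
total_samples = sum(n for n, _ in batches)
per_sample_mean = sum(n * loss for n, loss in batches) / total_samples

print(mean_of_batch_means)        # 0.75
print(round(per_sample_mean, 4))  # 0.6667
```

With equally sized batches the two values coincide, so the difference only matters when the final batch is short.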

3.2 Writing the test function

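def test(dataloader, model, loss_fn):
    """Evaluate the model and return (accuracy, mean batch loss)."""
    size        = len(dataloader.dataset)   # number of test samples
    num_batches = len(dataloader)           # number of batches
    test_loss, test_acc = 0, 0

    # Gradients are not needed for evaluation; disabling them saves memory and time
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            target_pred = model(imgs)
            loss        = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc  /= size          # fraction of correctly classified samples
    test_loss /= num_batches   # mean of the per-batch losses

    return test_acc, test_loss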

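3.3 Running the training

Training strategy (combining the best practices from J2/J3):

  • AdamW optimizer with weight decay (weight_decay=1e-4)
  • Label smoothing (label_smoothing=0.1) to reduce overfitting
  • Cosine-annealing learning-rate schedule (CosineAnnealingLR)
  • Kaiming weight initialization
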
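
To make the label-smoothing term concrete, here is a small pure-Python sketch of the smoothed cross-entropy for a single sample. It assumes the usual formulation in which the one-hot target is mixed with a uniform distribution over the K classes, which is how the `label_smoothing` argument of `nn.CrossEntropyLoss` is documented. The function name and the logits are illustrative.

```python
import math

def smoothed_cross_entropy(logits, target, eps=0.1):
    """Cross-entropy against the target (1 - eps) * one_hot + eps / K."""
    k = len(logits)
    # Numerically stable log-softmax
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    log_probs = [v - log_z for v in logits]
    # Smoothed target distribution: most mass on the true class, eps spread uniformly
    q = [eps / k + (1.0 - eps) * (1.0 if i == target else 0.0) for i in range(k)]
    return -sum(qi * lp for qi, lp in zip(q, log_probs))

logits = [4.0, 0.0, 0.0]  # a confident, correct 3-class prediction
plain = smoothed_cross_entropy(logits, target=0, eps=0.0)
smooth = smoothed_cross_entropy(logits, target=0, eps=0.1)
print(f"plain: {plain:.4f}  smoothed: {smooth:.4f}")
```

With smoothing, the loss no longer approaches zero as the correct logit grows, which discourages overconfident predictions.
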
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
loss_fn   = nn.CrossEntropyLoss(label_smoothing=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)

epochs = 10

train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []

best_acc = 0

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)

    scheduler.step()  # advance the cosine-annealing schedule once per epoch

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    if epoch_test_acc > best_acc:
        best_acc   = epoch_test_acc
        best_model = copy.deepcopy(model)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # Read the learning rate after scheduler.step(), i.e. the value the next epoch will use
    lr = optimizer.param_groups[0]['lr']

    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, '
                'Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss,
                          epoch_test_acc*100, epoch_test_loss, lr))

# Save the weights of the best model
PATH = './model/day04_resdensenet_best_model.pth'
os.makedirs(os.path.dirname(PATH), exist_ok=True)
torch.save(best_model.state_dict(), PATH)

print('Done')
Epoch: 1, Train_acc:63.6%, Train_loss:0.920, Test_acc:32.7%, Test_loss:1.586, Lr:9.76E-04
Epoch: 2, Train_acc:67.3%, Train_loss:0.867, Test_acc:69.4%, Test_loss:0.870, Lr:9.05E-04
Epoch: 3, Train_acc:69.1%, Train_loss:0.855, Test_acc:63.7%, Test_loss:0.977, Lr:7.94E-04
Epoch: 4, Train_acc:73.0%, Train_loss:0.799, Test_acc:80.2%, Test_loss:0.765, Lr:6.55E-04
Epoch: 5, Train_acc:74.9%, Train_loss:0.744, Test_acc:83.8%, Test_loss:0.651, Lr:5.01E-04
Epoch: 6, Train_acc:76.0%, Train_loss:0.731, Test_acc:62.2%, Test_loss:1.192, Lr:3.46E-04
Epoch: 7, Train_acc:80.0%, Train_loss:0.674, Test_acc:81.4%, Test_loss:0.663, Lr:2.07E-04
Epoch: 8, Train_acc:80.0%, Train_loss:0.667, Test_acc:77.2%, Test_loss:0.760, Lr:9.64E-05
Epoch: 9, Train_acc:82.8%, Train_loss:0.613, Test_acc:85.3%, Test_loss:0.542, Lr:2.54E-05
Epoch:10, Train_acc:84.1%, Train_loss:0.593, Test_acc:85.0%, Test_loss:0.557, Lr:1.00E-06
Done
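
The Lr column in the log above follows the closed-form cosine-annealing rule, lr_t = eta_min + (lr_0 - eta_min) * (1 + cos(pi * t / T_max)) / 2. The sketch below evaluates that formula directly, without PyTorch, using the settings from the training code (lr=1e-3, T_max=10, eta_min=1e-6). The function name is illustrative.

```python
import math

def cosine_lr(step, base_lr=1e-3, eta_min=1e-6, t_max=10):
    """Learning rate after `step` scheduler updates (closed-form cosine annealing)."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2

for step in (1, 2, 9, 10):
    print(f"after epoch {step:2d}: {cosine_lr(step):.2E}")
# after epoch  1: 9.76E-04
# after epoch  2: 9.05E-04
# after epoch  9: 2.54E-05
# after epoch 10: 1.00E-06
```

Because T_max equals the number of epochs, the schedule reaches eta_min exactly at the end of training, so the last epochs fine-tune with a very small step size.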

4. Visualizing the results

import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif']    = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi']         = 100

from datetime import datetime
current_time = datetime.now()

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc,  label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('ResDenseNet - Training and Test Accuracy')
plt.xlabel(current_time)

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss,  label='Test Loss')
plt.legend(loc='upper right')
plt.title('ResDenseNet - Training and Test Loss')
plt.show()

5. Model evaluation

best_model.load_state_dict(torch.load(PATH, map_location=device, weights_only=True))
epoch_test_acc, epoch_test_loss = test(test_dl, best_model, loss_fn)
print(f'ResDenseNet Best Test Accuracy: {epoch_test_acc*100:.1f}%')
print(f'ResDenseNet Best Test Loss:     {epoch_test_loss:.4f}')
ResDenseNet Best Test Accuracy: 85.3%
ResDenseNet Best Test Loss:     0.5423