CNN算法实战系列04 | ResDenseNet融合实现算法实战与解析
·
- 🍨 本文为🔗365天深度学习训练营中的学习记录博客
- 🍖 原作者:K同学啊
基于 ResNet50V2 和 DenseNet121 的学习,本周探索两种架构的融合可能性,
构建一个兼具两者优势的新模型框架 ResDenseNet,并用同一图像分类任务验证效果。
一、前置知识
1、知识总结(融合核心思想)
| 对比项 | ResNet50V2 | DenseNet121 | ResDenseNet (本文) |
| --- | --- | --- | --- |
| 核心思想 | 残差连接 (add) | 密集连接 (concat) | 内部 concat + 外部 add |
| 参数量 | ~25.6M | ~7.0M | ~2.0M (torchsummary 实测 1,986,691) |
| 特征复用 | 弱 (跨层求和) | 强 (保留全部) | 强 (密集复用 + 压缩精炼) |
| 梯度回传 | 残差路径直通 | concat 路径 | 残差直通 + 密集辅助 |
| 激活顺序 | Pre-Activation | Post-Activation | Pre-Activation |
| 数据增强 | 无 | 翻转+颜色抖动 | 翻转+颜色抖动 |
| 学习率策略 | 固定 lr | 余弦退火 | 余弦退火 |
| 正则化 | 无 | 标签平滑 | 标签平滑 + weight_decay |

ResDenseNet 成功融合了 ResNet 和 DenseNet 的核心优势:
ResDenseNet 成功融合了 ResNet 和 DenseNet 的核心优势:
- Dense Residual Block 在保持参数高效的同时,实现了丰富的特征复用
- 残差连接 确保了深层网络的梯度回传畅通
- Pre-Activation 设计(BN→ReLU→Conv)让 BN 直接规范化每个卷积的输入;块输出处残差相加之后不再接激活,恒等捷径保持"干净",梯度回传更顺畅
- 整体架构在参数量、特征利用率和梯度稳定性之间取得了良好的平衡
2、模型架构图
运行下方代码查看 ResDenseNet 的总体架构和核心模块 Dense Residual Block 的详细结构。
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch
import matplotlib.patches as mpatches
# SimHei renders CJK glyphs; unicode_minus=False keeps '-' readable with that font.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Left panel: whole-network pipeline. Right panel: inside of one Dense Residual Block.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 11))
fig.suptitle('ResDenseNet Architectural Design', fontsize=18, fontweight='bold', y=0.98)


def draw_box(ax, cx, cy, w, h, text, color, fs=8.5):
    """Add a rounded w x h box centred at (cx, cy) to `ax`, with bold centred `text`."""
    box = FancyBboxPatch((cx - w/2, cy - h/2), w, h,
                         boxstyle="round,pad=0.1",
                         facecolor=color, edgecolor='#333', linewidth=1.5)
    ax.add_patch(box)
    ax.text(cx, cy, text, ha='center', va='center', fontsize=fs, fontweight='bold')


# ==================== Left panel: overall architecture ====================
ax1.set_xlim(0, 10); ax1.set_ylim(2, 22)
ax1.axis('off')
ax1.set_title('Architecture', fontsize=14, fontweight='bold')
# (center_y, width, height, label, fill colour), listed top to bottom.
layers = [
    (21.0, 6.0, 0.9, 'Input (3 x 224 x 224)', '#B3D9FF'),
    (19.5, 6.5, 1.2, 'Stem: Conv7x7 -> BN -> ReLU -> MaxPool\n-> 64 x 56 x 56', '#FFF2CC'),
    (17.8, 6.0, 0.9, 'Stage 1: DRB x 3 (64 ch, 56x56)', '#C8E6C9'),
    (16.3, 6.5, 1.0, 'Transition: Conv1x1 + AvgPool -> 128 x 28 x 28', '#FFF2CC'),
    (14.8, 6.0, 0.9, 'Stage 2: DRB x 4 (128 ch, 28x28)', '#C8E6C9'),
    (13.3, 6.5, 1.0, 'Transition: Conv1x1 + AvgPool -> 256 x 14 x 14', '#FFF2CC'),
    (11.8, 6.0, 0.9, 'Stage 3: DRB x 6 (256 ch, 14x14)', '#C8E6C9'),
    (10.3, 6.5, 1.0, 'Transition: Conv1x1 + AvgPool -> 512 x 7 x 7', '#FFF2CC'),
    (8.8, 6.0, 0.9, 'Stage 4: DRB x 4 (512 ch, 7x7)', '#C8E6C9'),
    (7.1, 5.5, 0.9, 'BN -> ReLU -> GAP -> FC(3)', '#FFCDD2'),
]
for (y, w, h, text, color) in layers:
    draw_box(ax1, 5, y, w, h, text, color)
# Arrow from the bottom edge of each box to the top edge of the next one.
for i in range(len(layers)-1):
    y_from = layers[i][0] - layers[i][2]/2
    y_to = layers[i+1][0] + layers[i+1][2]/2
    ax1.annotate('', xy=(5, y_to), xytext=(5, y_from),
                 arrowprops=dict(arrowstyle='->', color='#555', lw=1.5))
ax1.legend(handles=[
    mpatches.Patch(color='#C8E6C9', label='Dense Residual Block (DRB)'),
    mpatches.Patch(color='#FFF2CC', label='Transition Layer'),
    mpatches.Patch(color='#FFCDD2', label='Classification Head'),
], loc='lower center', fontsize=9, framealpha=0.9)

# ==================== Right panel: DRB internals ====================
ax2.set_xlim(0, 10); ax2.set_ylim(4, 22)
ax2.axis('off')
ax2.set_title('Dense Residual Block (DRB) Detailed Structure', fontsize=14, fontweight='bold')
draw_box(ax2, 3.5, 21, 4.5, 0.8, 'Input (C, H, W)', '#B3D9FF', fs=9)
for i, y in enumerate([19.2, 17.4, 15.6]):
    draw_box(ax2, 3.5, y, 4.5, 1.0,
             f'DenseLayer {i+1}\nBN->ReLU->1x1->BN->ReLU->3x3', '#FFF9C4', fs=8)
    ax2.text(6.3, y, 'concat', fontsize=8, color='#1565C0',
             fontstyle='italic', va='center', fontweight='bold')
ax2.text(3.5, 14.4, '...', ha='center', fontsize=16, fontweight='bold')
draw_box(ax2, 3.5, 13.2, 4.5, 1.0,
         'DenseLayer L\nBN->ReLU->1x1->BN->ReLU->3x3', '#FFF9C4', fs=8)
ax2.text(6.3, 13.2, 'concat', fontsize=8, color='#1565C0',
         fontstyle='italic', va='center', fontweight='bold')
# The compress stage in DenseResidualBlock is BN -> ReLU -> 1x1 conv, so the label
# names all three ops rather than only the convolution.
draw_box(ax2, 3.5, 11.4, 4.5, 0.8, 'BN->ReLU->Conv1x1 Compress\n(C+L*k -> C)', '#FFCDD2', fs=8)
# Summation node where the residual shortcut rejoins the main path.
add_x, add_y, add_r = 3.5, 9.7, 0.4
circle = plt.Circle((add_x, add_y), add_r, facecolor='#C8E6C9', edgecolor='#333', lw=2)
ax2.add_patch(circle)
ax2.text(add_x, add_y, '+', ha='center', va='center', fontsize=16, fontweight='bold')
draw_box(ax2, 3.5, 8.2, 4.5, 0.8, 'Output (C, H, W)', '#B3D9FF', fs=9)
# Main-path arrows as (y_from, y_to): each runs from one element's bottom edge to
# the next element's top edge.
arrow_pairs = [
    (21-0.4, 19.7), (18.7, 17.9), (16.9, 16.1),
    (15.1, 14.7), (14.1, 13.7), (12.7, 11.8),
    (11.0, 10.1), (9.3, 8.6)
]
for (yf, yt) in arrow_pairs:
    ax2.annotate('', xy=(3.5, yt), xytext=(3.5, yf),
                 arrowprops=dict(arrowstyle='->', color='#555', lw=1.3))
# Residual shortcut (red dashed): input right edge -> out to x=7.8 -> down -> back to the "+" node.
ax2.annotate('', xy=(7.8, add_y), xytext=(7.8, 21),
             arrowprops=dict(arrowstyle='->', color='#D32F2F', lw=2.5, linestyle='--'))
ax2.annotate('', xy=(7.8, 21), xytext=(5.75, 21),
             arrowprops=dict(arrowstyle='->', color='#D32F2F', lw=2.5, linestyle='--'))
# End exactly on the circle's right edge so the arrow visibly reaches the "+" node.
ax2.annotate('', xy=(add_x + add_r, add_y), xytext=(7.8, add_y),
             arrowprops=dict(arrowstyle='->', color='#D32F2F', lw=2.5, linestyle='--'))
ax2.text(9.0, 15.5, 'Residual\nShortcut\n(Identity)', ha='center', va='center',
         fontsize=10, color='#D32F2F', fontweight='bold',
         bbox=dict(boxstyle='round,pad=0.3', facecolor='white',
                   edgecolor='#D32F2F', alpha=0.9))
ax2.text(5, 5.5,
         'DenseNet feature reuse (concat) + ResNet gradient highway (add)\n'
         'Internal dense -> rich features | External residual -> gradient flow',
         ha='center', va='center', fontsize=10, color='#1B5E20', fontweight='bold',
         bbox=dict(boxstyle='round,pad=0.5', facecolor='#E8F5E9',
                   edgecolor='#2E7D32', lw=2, alpha=0.9))
plt.tight_layout()
plt.show()

二、代码实现
1、准备工作
1.1 设置GPU
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
import os, copy, warnings
import numpy as np
# Keep notebook output free of library warning noise.
warnings.filterwarnings("ignore")
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device
device(type='cuda')
1.2 导入数据(增强策略)
# ImageFolder root: one sub-directory per class.
data_dir = './data/day01'
# Evaluation pipeline: deterministic resize + ImageNet normalisation, no augmentation.
test_transforms = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Training pipeline: same resize/normalisation plus light augmentation
# (random horizontal flip, mild brightness/contrast/saturation jitter).
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# The full dataset is loaded with the training transform.
total_data = datasets.ImageFolder(data_dir, transform=train_transforms)
total_data
Dataset ImageFolder
Number of datapoints: 1661
Root location: ./data/day01
StandardTransform
Transform: Compose(
Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
RandomHorizontalFlip(p=0.5)
ColorJitter(brightness=(0.8, 1.2), contrast=(0.8, 1.2), saturation=(0.9, 1.1), hue=None)
ToTensor()
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
# Show the class-folder-name -> label-index mapping that ImageFolder assigned.
total_data.class_to_idx
{'0Normal': 0, '2Mild': 1, '4Severe': 2}
1.3 划分数据集
# 80/20 train/test split. A seeded generator makes the split reproducible, so the
# held-out set used to pick and re-evaluate the best model is stable across runs.
train_size = int(0.8 * len(total_data))
test_size = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    total_data, [train_size, test_size],
    generator=torch.Generator().manual_seed(42))
# Both Subsets initially wrap total_data, which uses train_transforms. Re-point the
# test subset at an un-augmented ImageFolder over the same directory; ImageFolder
# enumerates files in sorted order, so the subset indices still refer to the same images.
# (train_dataset keeps total_data, which already applies train_transforms.)
test_dataset.dataset = datasets.ImageFolder(data_dir, transform=test_transforms)
batch_size = 8
train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                       shuffle=True, num_workers=0)
test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                      num_workers=0)
# Peek at one batch to confirm tensor shapes and label dtype.
X, y = next(iter(test_dl))
print("Shape of X [N, C, H, W]: ", X.shape)
print("Shape of y: ", y.shape, y.dtype)
Shape of X [N, C, H, W]: torch.Size([8, 3, 224, 224])
Shape of y: torch.Size([8]) torch.int64
2、搭建 ResDenseNet 模型
ResDenseNet 设计理念
Dense Residual Block (DRB) 是核心创新模块,融合了两个网络的优势:
- DenseNet 的优势:每层与前面所有层直接拼接 → 特征复用最大化、参数效率高
- ResNet 的优势:跨层残差连接 → 梯度回传无阻碍、可以训练极深网络
- 二者融合 → Dense Residual Block:
  - 内部:密集连接 (concat) ← 来自 DenseNet
  - 外部:残差连接 (add) ← 来自 ResNet
具体设计:
- 每个 DRB 内部有 L 个 DenseLayer,每层输出 growth_rate 个新特征并 concat
- 所有 DenseLayer 执行完后,经 BN→ReLU→1×1 卷积,将 C + L×growth_rate 个通道压缩回原始通道数 C
- 最后将压缩结果与输入相加(残差连接)
- Transition 层负责下采样和通道扩展
class DenseLayer(nn.Module):
    """Bottleneck dense layer with pre-activation ordering.

    Path: BN -> ReLU -> 1x1 conv (to bn_size * growth_rate channels)
          -> BN -> ReLU -> 3x3 conv (to growth_rate channels, padding 1, stride 1).

    The growth_rate new feature maps are concatenated onto the input along the
    channel axis, so the output has in_channels + growth_rate channels and the
    same spatial size as the input.
    """

    def __init__(self, in_channels, growth_rate, bn_size=4):
        super().__init__()
        bottleneck = bn_size * growth_rate
        # Attribute names are kept as-is: they are the state_dict keys of saved weights.
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, bottleneck,
                               kernel_size=1, stride=1, bias=False)
        self.bn2 = nn.BatchNorm2d(bottleneck)
        self.conv2 = nn.Conv2d(bottleneck, growth_rate,
                               kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        squeezed = self.conv1(F.relu(self.bn1(x)))
        new_features = self.conv2(F.relu(self.bn2(squeezed)))
        # DenseNet-style: keep every incoming channel and append the new ones.
        return torch.cat((x, new_features), dim=1)
class DenseResidualBlock(nn.Module):
    """Dense Residual Block (DRB): dense connectivity inside, identity shortcut outside.

    Inside: `num_layers` DenseLayers are applied in sequence. Each receives the
    concatenation of the block input and all earlier layers' outputs, so the
    channel count grows by `growth_rate` per layer.
    Outside: BN -> ReLU -> 1x1 conv maps the
    in_channels + num_layers * growth_rate channels back to in_channels. The
    result is then added to the block input, so the output shape equals the input shape.
    """

    def __init__(self, in_channels, num_layers, growth_rate, bn_size=4):
        super().__init__()
        # Layer k sees the input plus k earlier growth_rate-channel outputs.
        self.dense_layers = nn.ModuleList(
            [DenseLayer(in_channels + k * growth_rate, growth_rate, bn_size)
             for k in range(num_layers)]
        )
        widened = in_channels + num_layers * growth_rate
        # Projection back to the input width so the residual addition is shape-compatible.
        self.compress = nn.Sequential(
            nn.BatchNorm2d(widened),
            nn.ReLU(inplace=True),
            nn.Conv2d(widened, in_channels, kernel_size=1, stride=1, bias=False),
        )

    def forward(self, x):
        stacked = x
        for dense_layer in self.dense_layers:
            stacked = dense_layer(stacked)
        residual = self.compress(stacked)
        # ResNet-style identity shortcut; no activation after the sum.
        return residual + x
class Transition(nn.Sequential):
    """Down-sampling transition between stages.

    BN -> ReLU -> 1x1 conv (in_channels -> out_channels) -> 2x2 average pool
    with stride 2, which halves height and width.
    The sub-module names 'bn', 'relu', 'conv' and 'pool' become the state_dict keys.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        named_layers = (
            ('bn', nn.BatchNorm2d(in_channels)),
            ('relu', nn.ReLU(inplace=True)),
            ('conv', nn.Conv2d(in_channels, out_channels,
                               kernel_size=1, stride=1, bias=False)),
            ('pool', nn.AvgPool2d(kernel_size=2, stride=2)),
        )
        for layer_name, layer in named_layers:
            self.add_module(layer_name, layer)
class ResDenseNet(nn.Module):
    """ResDenseNet: dense connectivity inside each block, residual shortcuts around it.

    Layout: Stem -> [DRB -> Transition] x (len(block_config) - 1) -> DRB
            -> BN -> ReLU -> global average pool -> Linear.

    Args:
        growth_rate: channels each DenseLayer appends (default 32).
        block_config: DenseLayers per stage. The default (3, 4, 6, 4) is lighter
            than DenseNet121's (6, 12, 24, 16).
        num_init_features: channels produced by the stem.
        bn_size: bottleneck width multiplier passed to every DenseLayer.
        num_classes: number of output logits.

    Each Transition doubles the channel count and halves the spatial size.
    """

    def __init__(self, growth_rate=32, block_config=(3, 4, 6, 4),
                 num_init_features=64, bn_size=4, num_classes=1000):
        super().__init__()
        # Stem: 7x7 stride-2 conv, then 3x3 stride-2 max-pool (4x spatial reduction).
        self.stem = nn.Sequential(
            nn.Conv2d(3, num_init_features, kernel_size=7, stride=2,
                      padding=3, bias=False),
            nn.BatchNorm2d(num_init_features),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # Register both containers before filling them, so that self.modules()
        # visits all stages first and then all transitions.
        self.stages = nn.ModuleList()
        self.transitions = nn.ModuleList()
        width = num_init_features
        final_stage = len(block_config) - 1
        for stage_idx, depth in enumerate(block_config):
            self.stages.append(
                DenseResidualBlock(width, depth, growth_rate, bn_size)
            )
            if stage_idx == final_stage:
                continue  # no down-sampling after the last stage
            self.transitions.append(Transition(width, width * 2))
            width *= 2
        self.final_bn = nn.BatchNorm2d(width)
        self.classifier = nn.Linear(width, num_classes)
        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise parameters.

        Convs get Kaiming-normal (fan_out, ReLU gain). BatchNorm gets weight 1 and
        bias 0. Linear gets N(0, 0.01) weights and zero bias.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.01)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        out = self.stem(x)
        for stage_idx, stage in enumerate(self.stages):
            out = stage(out)
            if stage_idx < len(self.transitions):
                out = self.transitions[stage_idx](out)
        # Final pre-activation pair, then global average pooling to one vector per sample.
        out = F.relu(self.final_bn(out), inplace=True)
        out = torch.flatten(F.adaptive_avg_pool2d(out, 1), start_dim=1)
        return self.classifier(out)
# One output logit per class folder (3 classes in this dataset), placed on the active device.
model = ResDenseNet(num_classes=3)
model = model.to(device)
model
ResDenseNet(
(stem): Sequential(
(0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)
(stages): ModuleList(
(0): DenseResidualBlock(
(dense_layers): ModuleList(
(0): DenseLayer(
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(1): DenseLayer(
(bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(2): DenseLayer(
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
)
(compress): Sequential(
(0): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(1): ReLU(inplace=True)
(2): Conv2d(160, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(1): DenseResidualBlock(
(dense_layers): ModuleList(
(0): DenseLayer(
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(1): DenseLayer(
(bn1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(2): DenseLayer(
(bn1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(3): DenseLayer(
(bn1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
)
(compress): Sequential(
(0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(1): ReLU(inplace=True)
(2): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(2): DenseResidualBlock(
(dense_layers): ModuleList(
(0): DenseLayer(
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(1): DenseLayer(
(bn1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(2): DenseLayer(
(bn1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(3): DenseLayer(
(bn1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(4): DenseLayer(
(bn1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(5): DenseLayer(
(bn1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
)
(compress): Sequential(
(0): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(1): ReLU(inplace=True)
(2): Conv2d(448, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(3): DenseResidualBlock(
(dense_layers): ModuleList(
(0): DenseLayer(
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(1): DenseLayer(
(bn1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(2): DenseLayer(
(bn1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(3): DenseLayer(
(bn1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
)
(compress): Sequential(
(0): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(1): ReLU(inplace=True)
(2): Conv2d(640, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
)
(transitions): ModuleList(
(0): Transition(
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
(1): Transition(
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
(2): Transition(
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
)
(final_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(classifier): Linear(in_features=512, out_features=3, bias=True)
)
2.1 查看模型详情
import torchsummary as summary
# torchsummary.summary defaults to device="cuda". Pass the active device type
# ("cuda" or "cpu") so the summary also works on the CPU fallback chosen above.
summary.summary(model, (3, 224, 224), device=device.type)
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 112, 112] 9,408
BatchNorm2d-2 [-1, 64, 112, 112] 128
ReLU-3 [-1, 64, 112, 112] 0
MaxPool2d-4 [-1, 64, 56, 56] 0
BatchNorm2d-5 [-1, 64, 56, 56] 128
Conv2d-6 [-1, 128, 56, 56] 8,192
BatchNorm2d-7 [-1, 128, 56, 56] 256
Conv2d-8 [-1, 32, 56, 56] 36,864
DenseLayer-9 [-1, 96, 56, 56] 0
BatchNorm2d-10 [-1, 96, 56, 56] 192
Conv2d-11 [-1, 128, 56, 56] 12,288
BatchNorm2d-12 [-1, 128, 56, 56] 256
Conv2d-13 [-1, 32, 56, 56] 36,864
DenseLayer-14 [-1, 128, 56, 56] 0
BatchNorm2d-15 [-1, 128, 56, 56] 256
Conv2d-16 [-1, 128, 56, 56] 16,384
BatchNorm2d-17 [-1, 128, 56, 56] 256
Conv2d-18 [-1, 32, 56, 56] 36,864
DenseLayer-19 [-1, 160, 56, 56] 0
BatchNorm2d-20 [-1, 160, 56, 56] 320
ReLU-21 [-1, 160, 56, 56] 0
Conv2d-22 [-1, 64, 56, 56] 10,240
DenseResidualBlock-23 [-1, 64, 56, 56] 0
BatchNorm2d-24 [-1, 64, 56, 56] 128
ReLU-25 [-1, 64, 56, 56] 0
Conv2d-26 [-1, 128, 56, 56] 8,192
AvgPool2d-27 [-1, 128, 28, 28] 0
BatchNorm2d-28 [-1, 128, 28, 28] 256
Conv2d-29 [-1, 128, 28, 28] 16,384
BatchNorm2d-30 [-1, 128, 28, 28] 256
Conv2d-31 [-1, 32, 28, 28] 36,864
DenseLayer-32 [-1, 160, 28, 28] 0
BatchNorm2d-33 [-1, 160, 28, 28] 320
Conv2d-34 [-1, 128, 28, 28] 20,480
BatchNorm2d-35 [-1, 128, 28, 28] 256
Conv2d-36 [-1, 32, 28, 28] 36,864
DenseLayer-37 [-1, 192, 28, 28] 0
BatchNorm2d-38 [-1, 192, 28, 28] 384
Conv2d-39 [-1, 128, 28, 28] 24,576
BatchNorm2d-40 [-1, 128, 28, 28] 256
Conv2d-41 [-1, 32, 28, 28] 36,864
DenseLayer-42 [-1, 224, 28, 28] 0
BatchNorm2d-43 [-1, 224, 28, 28] 448
Conv2d-44 [-1, 128, 28, 28] 28,672
BatchNorm2d-45 [-1, 128, 28, 28] 256
Conv2d-46 [-1, 32, 28, 28] 36,864
DenseLayer-47 [-1, 256, 28, 28] 0
BatchNorm2d-48 [-1, 256, 28, 28] 512
ReLU-49 [-1, 256, 28, 28] 0
Conv2d-50 [-1, 128, 28, 28] 32,768
DenseResidualBlock-51 [-1, 128, 28, 28] 0
BatchNorm2d-52 [-1, 128, 28, 28] 256
ReLU-53 [-1, 128, 28, 28] 0
Conv2d-54 [-1, 256, 28, 28] 32,768
AvgPool2d-55 [-1, 256, 14, 14] 0
BatchNorm2d-56 [-1, 256, 14, 14] 512
Conv2d-57 [-1, 128, 14, 14] 32,768
BatchNorm2d-58 [-1, 128, 14, 14] 256
Conv2d-59 [-1, 32, 14, 14] 36,864
DenseLayer-60 [-1, 288, 14, 14] 0
BatchNorm2d-61 [-1, 288, 14, 14] 576
Conv2d-62 [-1, 128, 14, 14] 36,864
BatchNorm2d-63 [-1, 128, 14, 14] 256
Conv2d-64 [-1, 32, 14, 14] 36,864
DenseLayer-65 [-1, 320, 14, 14] 0
BatchNorm2d-66 [-1, 320, 14, 14] 640
Conv2d-67 [-1, 128, 14, 14] 40,960
BatchNorm2d-68 [-1, 128, 14, 14] 256
Conv2d-69 [-1, 32, 14, 14] 36,864
DenseLayer-70 [-1, 352, 14, 14] 0
BatchNorm2d-71 [-1, 352, 14, 14] 704
Conv2d-72 [-1, 128, 14, 14] 45,056
BatchNorm2d-73 [-1, 128, 14, 14] 256
Conv2d-74 [-1, 32, 14, 14] 36,864
DenseLayer-75 [-1, 384, 14, 14] 0
BatchNorm2d-76 [-1, 384, 14, 14] 768
Conv2d-77 [-1, 128, 14, 14] 49,152
BatchNorm2d-78 [-1, 128, 14, 14] 256
Conv2d-79 [-1, 32, 14, 14] 36,864
DenseLayer-80 [-1, 416, 14, 14] 0
BatchNorm2d-81 [-1, 416, 14, 14] 832
Conv2d-82 [-1, 128, 14, 14] 53,248
BatchNorm2d-83 [-1, 128, 14, 14] 256
Conv2d-84 [-1, 32, 14, 14] 36,864
DenseLayer-85 [-1, 448, 14, 14] 0
BatchNorm2d-86 [-1, 448, 14, 14] 896
ReLU-87 [-1, 448, 14, 14] 0
Conv2d-88 [-1, 256, 14, 14] 114,688
DenseResidualBlock-89 [-1, 256, 14, 14] 0
BatchNorm2d-90 [-1, 256, 14, 14] 512
ReLU-91 [-1, 256, 14, 14] 0
Conv2d-92 [-1, 512, 14, 14] 131,072
AvgPool2d-93 [-1, 512, 7, 7] 0
BatchNorm2d-94 [-1, 512, 7, 7] 1,024
Conv2d-95 [-1, 128, 7, 7] 65,536
BatchNorm2d-96 [-1, 128, 7, 7] 256
Conv2d-97 [-1, 32, 7, 7] 36,864
DenseLayer-98 [-1, 544, 7, 7] 0
BatchNorm2d-99 [-1, 544, 7, 7] 1,088
Conv2d-100 [-1, 128, 7, 7] 69,632
BatchNorm2d-101 [-1, 128, 7, 7] 256
Conv2d-102 [-1, 32, 7, 7] 36,864
DenseLayer-103 [-1, 576, 7, 7] 0
BatchNorm2d-104 [-1, 576, 7, 7] 1,152
Conv2d-105 [-1, 128, 7, 7] 73,728
BatchNorm2d-106 [-1, 128, 7, 7] 256
Conv2d-107 [-1, 32, 7, 7] 36,864
DenseLayer-108 [-1, 608, 7, 7] 0
BatchNorm2d-109 [-1, 608, 7, 7] 1,216
Conv2d-110 [-1, 128, 7, 7] 77,824
BatchNorm2d-111 [-1, 128, 7, 7] 256
Conv2d-112 [-1, 32, 7, 7] 36,864
DenseLayer-113 [-1, 640, 7, 7] 0
BatchNorm2d-114 [-1, 640, 7, 7] 1,280
ReLU-115 [-1, 640, 7, 7] 0
Conv2d-116 [-1, 512, 7, 7] 327,680
DenseResidualBlock-117 [-1, 512, 7, 7] 0
BatchNorm2d-118 [-1, 512, 7, 7] 1,024
Linear-119 [-1, 3] 1,539
================================================================
Total params: 1,986,691
Trainable params: 1,986,691
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 114.32
Params size (MB): 7.58
Estimated Total Size (MB): 122.47
----------------------------------------------------------------
3、训练模型
3.1 编写训练函数
def train(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over `dataloader`.

    Each batch is moved to the global `device`, then goes through a forward pass,
    backward pass and optimizer step. The caller is responsible for putting
    `model` into training mode.

    Returns:
        (accuracy, mean_loss). accuracy = correct predictions / len(dataset);
        mean_loss = sum of per-batch losses / number of batches.
    """
    n_samples = len(dataloader.dataset)
    n_batches = len(dataloader)
    correct = 0
    loss_sum = 0
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        logits = model(inputs)
        batch_loss = loss_fn(logits, labels)
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        # `logits` come from before this step's update (same as the reported loss).
        correct += (logits.argmax(1) == labels).type(torch.float).sum().item()
        loss_sum += batch_loss.item()
    return correct / n_samples, loss_sum / n_batches
3.2 编写测试函数
def test(dataloader, model, loss_fn):
    """Evaluate `model` on `dataloader` without gradient tracking.

    The model is switched to eval mode for the pass, so BatchNorm uses its running
    statistics and test batches do not update them. The caller's previous
    train/eval mode is restored afterwards, even if an exception is raised.
    Batches are moved to the global `device`.

    Returns:
        (accuracy, mean_loss). accuracy = correct predictions / len(dataset);
        mean_loss = sum of per-batch losses / number of batches.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, test_acc = 0, 0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for imgs, target in dataloader:
                imgs, target = imgs.to(device), target.to(device)
                target_pred = model(imgs)
                loss = loss_fn(target_pred, target)
                test_loss += loss.item()
                test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()
    finally:
        model.train(was_training)
    test_acc /= size
    test_loss /= num_batches
    return test_acc, test_loss
3.3 正式训练
训练策略(综合 J2/J3 最优实践):
- AdamW 优化器 + 权重衰减 (weight_decay=1e-4)
- 标签平滑 (label_smoothing=0.1) 防止过拟合
- 余弦退火 学习率调度 (CosineAnnealingLR)
- Kaiming 权重初始化
# Training recipe: AdamW with decoupled weight decay, label-smoothed
# cross-entropy, and a cosine-annealed learning rate.
epochs = 10
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
# T_max follows `epochs` so the cosine schedule always spans the whole run
# (scheduler.step() is called once per epoch below).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)
train_loss = []
train_acc = []
test_loss = []
test_acc = []
best_acc = 0
# Seed best_model with the starting weights so it is always bound, even if no
# epoch beats best_acc (otherwise torch.save below would raise NameError).
best_model = copy.deepcopy(model).eval()
for epoch in range(epochs):
    # LR actually used for this epoch's updates; read it before
    # scheduler.step() advances it to next epoch's value.
    lr = optimizer.param_groups[0]['lr']
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
    scheduler.step()  # cosine annealing: one step per epoch
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
    # Keep a frozen copy of the weights with the best held-out accuracy so far.
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, '
                'Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss,
                          epoch_test_acc*100, epoch_test_loss, lr))
# Persist the best weights (state_dict only) for the evaluation step.
PATH = './model/day04_resdensenet_best_model.pth'
os.makedirs(os.path.dirname(PATH), exist_ok=True)
torch.save(best_model.state_dict(), PATH)
print('Done')
Epoch: 1, Train_acc:63.6%, Train_loss:0.920, Test_acc:32.7%, Test_loss:1.586, Lr:9.76E-04
Epoch: 2, Train_acc:67.3%, Train_loss:0.867, Test_acc:69.4%, Test_loss:0.870, Lr:9.05E-04
Epoch: 3, Train_acc:69.1%, Train_loss:0.855, Test_acc:63.7%, Test_loss:0.977, Lr:7.94E-04
Epoch: 4, Train_acc:73.0%, Train_loss:0.799, Test_acc:80.2%, Test_loss:0.765, Lr:6.55E-04
Epoch: 5, Train_acc:74.9%, Train_loss:0.744, Test_acc:83.8%, Test_loss:0.651, Lr:5.01E-04
Epoch: 6, Train_acc:76.0%, Train_loss:0.731, Test_acc:62.2%, Test_loss:1.192, Lr:3.46E-04
Epoch: 7, Train_acc:80.0%, Train_loss:0.674, Test_acc:81.4%, Test_loss:0.663, Lr:2.07E-04
Epoch: 8, Train_acc:80.0%, Train_loss:0.667, Test_acc:77.2%, Test_loss:0.760, Lr:9.64E-05
Epoch: 9, Train_acc:82.8%, Train_loss:0.613, Test_acc:85.3%, Test_loss:0.542, Lr:2.54E-05
Epoch:10, Train_acc:84.1%, Train_loss:0.593, Test_acc:85.0%, Test_loss:0.557, Lr:1.00E-06
Done
4、结果可视化
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 100
from datetime import datetime
current_time = datetime.now()  # stamped under the accuracy plot to date the run
# The training log numbers epochs from 1, so plot them from 1 as well. Deriving
# the range from the recorded history keeps x and y the same length even if
# training stopped before `epochs` iterations.
epochs_range = range(1, len(train_acc) + 1)
plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('ResDenseNet - Training and Validation Accuracy')
plt.xlabel(current_time)
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('ResDenseNet - Training and Validation Loss')
plt.show()

5、模型评估
# Reload the best checkpoint from disk (weights only) and re-evaluate it on the test split.
best_model.load_state_dict(torch.load(PATH, map_location=device, weights_only=True))
# Put BatchNorm into inference mode explicitly instead of relying on whatever
# mode the in-memory copy happened to have when it was deep-copied.
best_model.eval()
epoch_test_acc, epoch_test_loss = test(test_dl, best_model, loss_fn)
print(f'ResDenseNet Best Test Accuracy: {epoch_test_acc*100:.1f}%')
print(f'ResDenseNet Best Test Loss: {epoch_test_loss:.4f}')
ResDenseNet Best Test Accuracy: 85.3%
ResDenseNet Best Test Loss: 0.5423
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)