热力图可视化、特征图可视化、gradcam 图像分类模型 vgg resnet densenet mobilenet 数据集cifar10 mini-imagenet 不同网络结构可视化结果 (子图请点击放大后在看,缩略图看不到全貌) 只限pytorch代码

在深度学习领域,理解模型如何做出决策至关重要。图像分类模型如VGG、ResNet、DenseNet和MobileNet在CIFAR - 10和Mini - ImageNet数据集上表现出色,但探究它们内部工作机制能让我们进一步优化和改进模型。本文将通过热力图可视化、特征图可视化以及GradCAM技术,利用PyTorch代码展示不同网络结构的可视化结果。

1. 数据集加载

首先,我们需要加载CIFAR - 10或Mini - ImageNet数据集。以CIFAR - 10为例:

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=4,
                         shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = DataLoader(testset, batch_size=4,
                        shuffle=False)

这里我们对图像进行了尺寸调整、转换为张量并归一化处理。DataLoader则用于按批次加载数据。

2. 模型加载

以VGG16为例加载预训练模型:

import torchvision.models as models
import torch

vgg16 = models.vgg16(pretrained=True)
vgg16.eval()

models.vgg16(pretrained=True)加载了在ImageNet上预训练的VGG16模型,并通过eval()设置为评估模式。

3. 特征图可视化

特征图可视化能让我们看到模型不同层对图像特征的提取情况。以下代码获取VGG16某一层的特征图:

from torchvision.utils import make_grid
import matplotlib.pyplot as plt

# 获取某一层的输出
layer_name = 'features.29'
hooked_features = None
def hook(module, input, output):
    global hooked_features
    hooked_features = output

vgg16._modules.get(layer_name).register_forward_hook(hook)

images, labels = next(iter(testloader))
_ = vgg16(images)

feature_maps = hooked_features.squeeze()
grid = make_grid(feature_maps[:16], nrow=4)
plt.imshow(grid.permute(1, 2, 0))
plt.show()

这段代码定义了一个钩子函数hook,用于捕获features.29层的输出。获取特征图后,使用make_grid将部分特征图拼成一个网格并展示。

4. 热力图可视化与GradCAM

GradCAM(Gradient - weighted Class Activation Mapping)用于生成热力图,指示模型在做出决策时关注图像的哪些区域。

import torch.nn.functional as F

def gradcam(model, input_image, target_layer):
    model.eval()
    input_image = input_image.requires_grad_(True)
    features = None
    def hook(module, input, output):
        nonlocal features
        features = output

    target_layer.register_forward_hook(hook)
    output = model(input_image)
    pred = output.argmax(dim = 1)
    one_hot = torch.zeros_like(output)
    one_hot[0][pred[0]] = 1
    model.zero_grad()
    output.backward(gradient = one_hot)
    gradients = input_image.grad

    pooled_gradients = F.adaptive_avg_pool2d(gradients, (1, 1))
    features = features.detach()
    for i in range(features.shape[1]):
        features[:, i, :, :] *= pooled_gradients[:, i, :, :]

    heatmap = features.mean(dim = 1).squeeze()
    heatmap = F.relu(heatmap)
    heatmap = heatmap / heatmap.max()

    return heatmap

# 使用GradCAM获取热力图
layer_to_visualize = vgg16.features[-1]
image, _ = next(iter(testloader))
heatmap = gradcam(vgg16, image, layer_to_visualize)

plt.imshow(heatmap, cmap='jet', alpha=0.5)
image = image.squeeze().permute(1, 2, 0)
image = (image * torch.tensor((0.5, 0.5, 0.5)) + torch.tensor((0.5, 0.5, 0.5))).clamp(0, 1)
plt.imshow(image)
plt.show()

gradcam函数中,我们首先注册了一个钩子获取目标层的特征,然后通过反向传播计算梯度。对梯度进行池化后与特征图加权求和,生成热力图。最后将热力图与原始图像叠加展示。

5. 不同网络结构可视化结果对比

对于ResNet、DenseNet和MobileNet,只需更改模型加载部分代码,例如加载ResNet18:

resnet18 = models.resnet18(pretrained=True)
resnet18.eval()

然后重复上述特征图可视化和GradCAM的步骤,就可以得到不同网络结构的可视化结果。通过对比不同网络在同一数据集图像上的可视化结果,可以发现:

  • VGG网络的特征图相对较为清晰,反映出其卷积层对纹理等特征的有效提取,但热力图可能较为分散,说明其决策依据相对广泛。
  • ResNet由于跳跃连接,特征图保留了更多底层信息,热力图可能更集中在关键物体区域,显示出对物体主体的关注。
  • DenseNet的密集连接使得特征传递更高效,特征图可能包含丰富且复杂的信息,热力图或许能更精准地定位物体。
  • MobileNet作为轻量级网络,特征图可能相对简单,但热力图也能抓住主要判别区域,体现其在资源受限情况下的有效性。

(子图请点击放大后在看,缩略图看不到全貌)通过这些可视化技术,我们能更好地理解不同图像分类网络结构的特点,为模型选择和优化提供依据。

热力图可视化、特征图可视化、gradcam 图像分类模型 vgg resnet densenet mobilenet 数据集cifar10 mini-imagenet 不同网络结构可视化结果 (子图请点击放大后在看,缩略图看不到全貌) 只限pytorch代码

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐