计算机视觉:语义分割的损失函数设计与类别不平衡处理

cover

一、语义分割的"长尾困境":小目标决定整体精度

语义分割任务中,类别不平衡是影响模型精度的首要因素。以自动驾驶场景为例:道路和天空占据图像 70% 以上的面积,而行人、交通标志、路沿等小目标仅占 1-5%。使用标准交叉熵损失训练时,模型倾向于将大部分像素预测为"道路"或"天空"——因为即使漏掉所有小目标,整体像素准确率仍然很高。

更严重的是,小目标往往是最关键的目标。自动驾驶中漏检一个行人远比漏检一段路面危险。医学影像中,病灶区域通常只占图像的 0.1-2%,但恰恰是诊断的核心。标准损失函数无法区分"大类别少一个像素"和"小类别少一个像素"的严重程度差异,必须通过损失函数设计来显式处理类别不平衡问题。

二、损失函数的数学原理与设计策略

2.1 从标准交叉熵到加权交叉熵

标准交叉熵损失对所有像素一视同仁:

CE(p, y) = -Σ y_i * log(p_i)

加权交叉熵为每个类别分配不同权重:

WCE(p, y) = -Σ w_c * y_i * log(p_i)

权重 w_c 的计算方式通常为:w_c = 1 / freq_c(频率的倒数)或 w_c = median_freq / freq_c(中位数频率归一化)。但简单加权存在过矫正问题:极低频类别的权重可能过大,导致训练不稳定。

flowchart TD
    A[语义分割损失函数] --> B[像素级损失]
    A --> C[区域级损失]
    A --> D[边界级损失]

    B --> B1[加权交叉熵<br/>WCE]
    B --> B2[Focal Loss<br/>降低易分类样本权重]

    C --> C1[Dice Loss<br/>区域重叠度]
    C --> C2[IoU Loss<br/>交并比]

    D --> D1[Boundary Loss<br/>边界距离惩罚]
    D --> D2[Lovász-Softmax<br/>IoU 的凸代理]

    style B fill:#e1f5fe
    style C fill:#fff3e0
    style D fill:#e8f5e9

2.2 Focal Loss 的核心思想

Focal Loss 的设计动机是:大量像素属于"易分类"的大类别,它们对梯度的贡献占据主导地位,淹没了"难分类"小目标的梯度信号。

FL(p_t) = -α_t * (1 - p_t)^γ * log(p_t)

其中 p_t 是正确类别的预测概率,γ 是聚焦参数(通常取 2)。当 p_t 接近 1 时(易分类样本),(1 - p_t)^γ 趋近于 0,该样本的损失权重被大幅降低。当 p_t 较小时(难分类样本),权重接近 1,保留完整梯度信号。

2.3 Dice Loss 与 IoU Loss 的区域级优化

Dice Loss 基于区域重叠度,天然对小目标敏感:

Dice = 2 * |P ∩ G| / (|P| + |G|)
Dice_Loss = 1 - Dice

IoU Loss(Jaccard Loss)类似,但分母更严格:

IoU = |P ∩ G| / |P ∪ G|
IoU_Loss = 1 - IoU

两者的共同优势是:损失值与目标大小无关——一个小目标的 IoU 从 0.1 提升到 0.5,和一个大目标的 IoU 从 0.1 提升到 0.5,对损失的贡献相同。这天然缓解了类别不平衡问题。

三、生产级代码实现:组合损失函数与训练策略

3.1 组合损失函数

import torch
import torch.nn as nn
import torch.nn.functional as F

class CombinedSegLoss(nn.Module):
    """组合损失函数:WCE + Dice + Boundary

    权重比例根据验证集调优,初始建议:
    - WCE: 1.0(基础分类损失)
    - Dice: 1.0(区域级平衡)
    - Boundary: 0.5(边界精度增强)
    """

    def __init__(
        self,
        num_classes: int,
        class_weights: torch.Tensor,
        focal_gamma: float = 2.0,
        dice_smooth: float = 1.0,
        boundary_weight: float = 0.5,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.class_weights = class_weights
        self.focal_gamma = focal_gamma
        self.dice_smooth = dice_smooth
        self.boundary_weight = boundary_weight

    def forward(
        self,
        logits: torch.Tensor,       # (B, C, H, W)
        targets: torch.Tensor,       # (B, H, W)
        boundaries: torch.Tensor = None,  # (B, 1, H, W) 边界标注
    ):
        # 1. 加权 Focal Loss
        wce_loss = self._weighted_focal_loss(logits, targets)

        # 2. Dice Loss
        dice_loss = self._dice_loss(logits, targets)

        # 3. Boundary Loss(可选)
        boundary_loss = 0.0
        if boundaries is not None:
            boundary_loss = self._boundary_loss(logits, boundaries)

        total = wce_loss + dice_loss + self.boundary_weight * boundary_loss

        return total, {
            "wce": wce_loss.item(),
            "dice": dice_loss.item(),
            "boundary": boundary_loss.item() if isinstance(boundary_loss, torch.Tensor) else boundary_loss,
        }

    def _weighted_focal_loss(self, logits, targets):
        """加权 Focal Loss"""
        ce = F.cross_entropy(
            logits, targets, weight=self.class_weights, reduction="none"
        )
        pt = torch.exp(-ce)
        focal_weight = (1 - pt) ** self.focal_gamma
        return (focal_weight * ce).mean()

    def _dice_loss(self, logits, targets):
        """多类别 Dice Loss"""
        probs = F.softmax(logits, dim=1)
        targets_one_hot = F.one_hot(targets, self.num_classes)
        targets_one_hot = targets_one_hot.permute(0, 3, 1, 2).float()

        # 逐类别计算 Dice
        dims = (0, 2, 3)  # Batch, Height, Width 维度
        intersection = (probs * targets_one_hot).sum(dim=dims)
        cardinality = (probs + targets_one_hot).sum(dim=dims)

        dice_score = (2.0 * intersection + self.dice_smooth) / (
            cardinality + self.dice_smooth
        )

        # 对所有类别取平均(每个类别权重相同,天然平衡)
        return 1.0 - dice_score.mean()

    def _boundary_loss(self, logits, boundaries):
        """边界损失:增强分割边界的精度"""
        probs = F.softmax(logits, dim=1)
        # 取最大预测概率作为边界置信度
        max_prob, _ = probs.max(dim=1, keepdim=True)

        # 边界区域的交叉熵
        boundary_mask = boundaries.float()
        boundary_pixels = boundary_mask.sum()

        if boundary_pixels < 1:
            return torch.tensor(0.0, device=logits.device)

        # 在边界区域计算预测与标注的差异
        loss = F.binary_cross_entropy(max_prob, boundary_mask, reduction="none")
        return (loss * boundary_mask).sum() / boundary_pixels

3.2 类别权重自动计算

import numpy as np
from collections import Counter

def compute_class_weights(
    dataset,
    num_classes: int,
    method: str = "median_freq",
) -> torch.Tensor:
    """从数据集统计中自动计算类别权重

    Args:
        dataset: 数据集,每条数据包含 (image, mask) 对
        num_classes: 类别数
        method: 权重计算方法
            - "inverse_freq": 频率倒数
            - "median_freq": 中位数频率归一化
            - "effective_num": 有效样本数(CB Loss)
    """
    # 统计每个类别的像素数
    class_pixel_counts = np.zeros(num_classes, dtype=np.float64)

    for _, mask in dataset:
        mask_np = np.array(mask)
        counts = Counter(mask_np.flatten())
        for cls_id, count in counts.items():
            if 0 <= cls_id < num_classes:
                class_pixel_counts[cls_id] += count

    total_pixels = class_pixel_counts.sum()
    class_frequencies = class_pixel_counts / total_pixels

    if method == "inverse_freq":
        # 简单频率倒数,上限裁剪防止极端权重
        weights = 1.0 / (class_frequencies + 1e-6)
        weights = np.clip(weights, 0.1, 50.0)

    elif method == "median_freq":
        # 中位数频率归一化
        median_freq = np.median(class_frequencies[class_frequencies > 0])
        weights = median_freq / (class_frequencies + 1e-6)
        weights = np.clip(weights, 0.1, 50.0)

    elif method == "effective_num":
        # Class-Balanced Loss 的有效样本数
        beta = 0.9999
        effective_num = 1.0 - beta ** class_pixel_counts
        weights = (1.0 - beta) / effective_num
        weights = np.clip(weights, 0.1, 50.0)

    else:
        raise ValueError(f"未知权重计算方法: {method}")

    # 归一化使权重均值为 1
    weights = weights / weights.mean()

    return torch.tensor(weights, dtype=torch.float32)

3.3 在线难例挖掘(Online Hard Example Mining)

class OHEMLoss(nn.Module):
    """在线难例挖掘:只对损失最高的像素反向传播"""

    def __init__(
        self,
        base_loss: nn.Module,
        keep_ratio: float = 0.25,  # 保留 25% 最难像素
    ):
        super().__init__()
        self.base_loss = base_loss
        self.keep_ratio = keep_ratio

    def forward(self, logits, targets):
        # 计算逐像素损失
        pixel_losses = F.cross_entropy(
            logits, targets, reduction="none"
        )  # (B, H, W)

        # 展平并排序
        flat_losses = pixel_losses.flatten()
        num_keep = int(len(flat_losses) * self.keep_ratio)

        # 保留损失最高的像素
        topk_losses, _ = torch.topk(flat_losses, num_keep)

        return topk_losses.mean()

四、损失函数设计的工程权衡

4.1 Dice Loss 的梯度不稳定

Dice Loss 在训练初期(预测接近随机时)梯度极大,可能导致训练不稳定。解决方案:前 N 个 epoch 使用 WCE 预热,再切换为 WCE + Dice 组合。或者使用 Dice 的变体——Soft Dice,在分子分母添加平滑项。

4.2 多损失组合的超参敏感性

WCE + Dice + Boundary 三项损失的权重比例对训练结果影响显著。权重设置不当可能导致某项损失主导训练,其他项形同虚设。建议策略:先单独训练每项损失确定各自的收敛范围,再按量级比例设置组合权重。

4.3 类别权重的数据依赖

自动计算的类别权重高度依赖训练集的类别分布。如果训练集与实际部署场景的分布不一致(如训练集道路占比 60%,实际场景道路占比 40%),权重设置会偏离最优。在分布差异大的场景中,需要在验证集上调整权重,而非直接使用训练集统计。

五、总结

语义分割的类别不平衡处理是损失函数设计的核心驱动力。三个关键策略:第一,使用加权 Focal Loss 降低易分类样本的梯度贡献,让模型聚焦于难分类的小目标;第二,组合 Dice Loss 提供区域级优化信号,天然平衡大小目标的损失贡献;第三,在线难例挖掘(OHEM)动态筛选困难像素,避免简单样本浪费训练资源。损失函数设计不是"选一个最好的",而是"组合多个互补的"——每种损失解决不同层面的问题,组合使用才能覆盖完整的优化空间。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐