计算机视觉:语义分割的损失函数设计与类别不平衡处理
计算机视觉:语义分割的损失函数设计与类别不平衡处理

一、语义分割的"长尾困境":小目标决定整体精度
语义分割任务中,类别不平衡是影响模型精度的首要因素。以自动驾驶场景为例:道路和天空占据图像 70% 以上的面积,而行人、交通标志、路沿等小目标仅占 1-5%。使用标准交叉熵损失训练时,模型倾向于将大部分像素预测为"道路"或"天空"——因为即使漏掉所有小目标,整体像素准确率仍然很高。
更严重的是,小目标往往是最关键的目标。自动驾驶中漏检一个行人远比漏检一段路面危险。医学影像中,病灶区域通常只占图像的 0.1-2%,但恰恰是诊断的核心。标准损失函数无法区分"大类别少一个像素"和"小类别少一个像素"的严重程度差异,必须通过损失函数设计来显式处理类别不平衡问题。
二、损失函数的数学原理与设计策略
2.1 从标准交叉熵到加权交叉熵
标准交叉熵损失对所有像素一视同仁:
CE(p, y) = -Σ y_i * log(p_i)
加权交叉熵为每个类别分配不同权重:
WCE(p, y) = -Σ w_c * y_i * log(p_i)
权重 w_c 的计算方式通常为:w_c = 1 / freq_c(频率的倒数)或 w_c = median_freq / freq_c(中位数频率归一化)。但简单加权存在过矫正问题:极低频类别的权重可能过大,导致训练不稳定。
flowchart TD
A[语义分割损失函数] --> B[像素级损失]
A --> C[区域级损失]
A --> D[边界级损失]
B --> B1[加权交叉熵<br/>WCE]
B --> B2[Focal Loss<br/>降低易分类样本权重]
C --> C1[Dice Loss<br/>区域重叠度]
C --> C2[IoU Loss<br/>交并比]
D --> D1[Boundary Loss<br/>边界距离惩罚]
D --> D2[Lovász-Softmax<br/>IoU 的凸代理]
style B fill:#e1f5fe
style C fill:#fff3e0
style D fill:#e8f5e9
2.2 Focal Loss 的核心思想
Focal Loss 的设计动机是:大量像素属于"易分类"的大类别,它们对梯度的贡献占据主导地位,淹没了"难分类"小目标的梯度信号。
FL(p_t) = -α_t * (1 - p_t)^γ * log(p_t)
其中 p_t 是正确类别的预测概率,γ 是聚焦参数(通常取 2)。当 p_t 接近 1 时(易分类样本),(1 - p_t)^γ 趋近于 0,该样本的损失权重被大幅降低。当 p_t 较小时(难分类样本),权重接近 1,保留完整梯度信号。
2.3 Dice Loss 与 IoU Loss 的区域级优化
Dice Loss 基于区域重叠度,天然对小目标敏感:
Dice = 2 * |P ∩ G| / (|P| + |G|)
Dice_Loss = 1 - Dice
IoU Loss(Jaccard Loss)类似,但分母更严格:
IoU = |P ∩ G| / |P ∪ G|
IoU_Loss = 1 - IoU
两者的共同优势是:损失值与目标大小无关——一个小目标的 IoU 从 0.1 提升到 0.5,和一个大目标的 IoU 从 0.1 提升到 0.5,对损失的贡献相同。这天然缓解了类别不平衡问题。
三、生产级代码实现:组合损失函数与训练策略
3.1 组合损失函数
import torch
import torch.nn as nn
import torch.nn.functional as F
class CombinedSegLoss(nn.Module):
"""组合损失函数:WCE + Dice + Boundary
权重比例根据验证集调优,初始建议:
- WCE: 1.0(基础分类损失)
- Dice: 1.0(区域级平衡)
- Boundary: 0.5(边界精度增强)
"""
def __init__(
self,
num_classes: int,
class_weights: torch.Tensor,
focal_gamma: float = 2.0,
dice_smooth: float = 1.0,
boundary_weight: float = 0.5,
):
super().__init__()
self.num_classes = num_classes
self.class_weights = class_weights
self.focal_gamma = focal_gamma
self.dice_smooth = dice_smooth
self.boundary_weight = boundary_weight
def forward(
self,
logits: torch.Tensor, # (B, C, H, W)
targets: torch.Tensor, # (B, H, W)
boundaries: torch.Tensor = None, # (B, 1, H, W) 边界标注
):
# 1. 加权 Focal Loss
wce_loss = self._weighted_focal_loss(logits, targets)
# 2. Dice Loss
dice_loss = self._dice_loss(logits, targets)
# 3. Boundary Loss(可选)
boundary_loss = 0.0
if boundaries is not None:
boundary_loss = self._boundary_loss(logits, boundaries)
total = wce_loss + dice_loss + self.boundary_weight * boundary_loss
return total, {
"wce": wce_loss.item(),
"dice": dice_loss.item(),
"boundary": boundary_loss.item() if isinstance(boundary_loss, torch.Tensor) else boundary_loss,
}
def _weighted_focal_loss(self, logits, targets):
"""加权 Focal Loss"""
ce = F.cross_entropy(
logits, targets, weight=self.class_weights, reduction="none"
)
pt = torch.exp(-ce)
focal_weight = (1 - pt) ** self.focal_gamma
return (focal_weight * ce).mean()
def _dice_loss(self, logits, targets):
"""多类别 Dice Loss"""
probs = F.softmax(logits, dim=1)
targets_one_hot = F.one_hot(targets, self.num_classes)
targets_one_hot = targets_one_hot.permute(0, 3, 1, 2).float()
# 逐类别计算 Dice
dims = (0, 2, 3) # Batch, Height, Width 维度
intersection = (probs * targets_one_hot).sum(dim=dims)
cardinality = (probs + targets_one_hot).sum(dim=dims)
dice_score = (2.0 * intersection + self.dice_smooth) / (
cardinality + self.dice_smooth
)
# 对所有类别取平均(每个类别权重相同,天然平衡)
return 1.0 - dice_score.mean()
def _boundary_loss(self, logits, boundaries):
"""边界损失:增强分割边界的精度"""
probs = F.softmax(logits, dim=1)
# 取最大预测概率作为边界置信度
max_prob, _ = probs.max(dim=1, keepdim=True)
# 边界区域的交叉熵
boundary_mask = boundaries.float()
boundary_pixels = boundary_mask.sum()
if boundary_pixels < 1:
return torch.tensor(0.0, device=logits.device)
# 在边界区域计算预测与标注的差异
loss = F.binary_cross_entropy(max_prob, boundary_mask, reduction="none")
return (loss * boundary_mask).sum() / boundary_pixels
3.2 类别权重自动计算
import numpy as np
from collections import Counter
def compute_class_weights(
dataset,
num_classes: int,
method: str = "median_freq",
) -> torch.Tensor:
"""从数据集统计中自动计算类别权重
Args:
dataset: 数据集,每条数据包含 (image, mask) 对
num_classes: 类别数
method: 权重计算方法
- "inverse_freq": 频率倒数
- "median_freq": 中位数频率归一化
- "effective_num": 有效样本数(CB Loss)
"""
# 统计每个类别的像素数
class_pixel_counts = np.zeros(num_classes, dtype=np.float64)
for _, mask in dataset:
mask_np = np.array(mask)
counts = Counter(mask_np.flatten())
for cls_id, count in counts.items():
if 0 <= cls_id < num_classes:
class_pixel_counts[cls_id] += count
total_pixels = class_pixel_counts.sum()
class_frequencies = class_pixel_counts / total_pixels
if method == "inverse_freq":
# 简单频率倒数,上限裁剪防止极端权重
weights = 1.0 / (class_frequencies + 1e-6)
weights = np.clip(weights, 0.1, 50.0)
elif method == "median_freq":
# 中位数频率归一化
median_freq = np.median(class_frequencies[class_frequencies > 0])
weights = median_freq / (class_frequencies + 1e-6)
weights = np.clip(weights, 0.1, 50.0)
elif method == "effective_num":
# Class-Balanced Loss 的有效样本数
beta = 0.9999
effective_num = 1.0 - beta ** class_pixel_counts
weights = (1.0 - beta) / effective_num
weights = np.clip(weights, 0.1, 50.0)
else:
raise ValueError(f"未知权重计算方法: {method}")
# 归一化使权重均值为 1
weights = weights / weights.mean()
return torch.tensor(weights, dtype=torch.float32)
3.3 在线难例挖掘(Online Hard Example Mining)
class OHEMLoss(nn.Module):
"""在线难例挖掘:只对损失最高的像素反向传播"""
def __init__(
self,
base_loss: nn.Module,
keep_ratio: float = 0.25, # 保留 25% 最难像素
):
super().__init__()
self.base_loss = base_loss
self.keep_ratio = keep_ratio
def forward(self, logits, targets):
# 计算逐像素损失
pixel_losses = F.cross_entropy(
logits, targets, reduction="none"
) # (B, H, W)
# 展平并排序
flat_losses = pixel_losses.flatten()
num_keep = int(len(flat_losses) * self.keep_ratio)
# 保留损失最高的像素
topk_losses, _ = torch.topk(flat_losses, num_keep)
return topk_losses.mean()
四、损失函数设计的工程权衡
4.1 Dice Loss 的梯度不稳定
Dice Loss 在训练初期(预测接近随机时)梯度极大,可能导致训练不稳定。解决方案:前 N 个 epoch 使用 WCE 预热,再切换为 WCE + Dice 组合。或者使用 Dice 的变体——Soft Dice,在分子分母添加平滑项。
4.2 多损失组合的超参敏感性
WCE + Dice + Boundary 三项损失的权重比例对训练结果影响显著。权重设置不当可能导致某项损失主导训练,其他项形同虚设。建议策略:先单独训练每项损失确定各自的收敛范围,再按量级比例设置组合权重。
4.3 类别权重的数据依赖
自动计算的类别权重高度依赖训练集的类别分布。如果训练集与实际部署场景的分布不一致(如训练集道路占比 60%,实际场景道路占比 40%),权重设置会偏离最优。在分布差异大的场景中,需要在验证集上调整权重,而非直接使用训练集统计。
五、总结
语义分割的类别不平衡处理是损失函数设计的核心驱动力。三个关键策略:第一,使用加权 Focal Loss 降低易分类样本的梯度贡献,让模型聚焦于难分类的小目标;第二,组合 Dice Loss 提供区域级优化信号,天然平衡大小目标的损失贡献;第三,在线难例挖掘(OHEM)动态筛选困难像素,避免简单样本浪费训练资源。损失函数设计不是"选一个最好的",而是"组合多个互补的"——每种损失解决不同层面的问题,组合使用才能覆盖完整的优化空间。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐
所有评论(0)