• YOLO-Pose 是 Ultralytics 在 YOLO11/12(也兼容 YOLOv8)上扩展出的关键点 / 姿态估计专用模型,核心逻辑是在 YOLO 目标检测基础上增加 "关键点预测分支"—— 既检测目标(如手部)的边界框(bbox),又预测框内关键点的坐标 (x, y) 和可见性 (v)。做手势关键点检测(Hand Pose Estimation),研究框架应该是:

    • 任务定义–>数据体系–>模型体系–>训练策略–>评估体系–>部署与应用
  • 在目标检测网络中直接嵌入关键点预测分支,一次前向传播同时输出 "目标边界框 + 框内关键点坐标"。优势是精度与效率平衡较好(速度接近 Bottom-Up,精度接近 Top-Down),工程化落地简单(单模型部署,无需复杂后处理);代价是多目标密集场景下关键点精度略低于 Top-Down,对超小目标(如远距离手部)的检测需额外调优。

  • 在CV领域通常分成 四层:Level1 任务理解、Level2 数据与标注、Level3 模型训练、Level4 推理与应用。而 YOLO-pose 本质是:Object Detection + Keypoint Regression,即 bbox detection + keypoints regression。YOLO-pose论文提出直接预测关键点而不是 heatmap。

    • image
        └─> backbone
               └─> neck
                     └─> head
                            ├─ bbox
                            ├─ class
                            └─ keypoints
      
  • 姿态估计是识别图像中特定点(关键点)位置的任务。关键点可以代表对象的各个部分,例如关节、地标或其他独特特征,其位置通常表示为一组 2D [x, y] 或 3D [x, y, visible] 坐标。Hand Keypoints 数据集包含 26,768 张带关键点标注的手部图像,适用于训练 Ultralytics YOLO 等模型做姿态估计;标注由 Google MediaPipe 库生成,准确性和一致性较好,并且与 Ultralytics YOLO 的数据格式兼容。每只手总共有 21 个关键点:手腕、拇指(4 个点)、食指(4 个点)、中指(4 个点)、无名指(4 个点)、小指(4 个点)

    • 0  wrist
      thumb
      1  thumb_cmc
      2  thumb_mcp
      3  thumb_ip
      4  thumb_tip
      index
      5  index_mcp
      6  index_pip
      7  index_dip
      8  index_tip
      middle
      9  middle_mcp
      10 middle_pip
      11 middle_dip
      12 middle_tip
      ring
      13 ring_mcp
      14 ring_pip
      15 ring_dip
      16 ring_tip
      little
      17 little_mcp
      18 little_pip
      19 little_dip
      20 little_tip
      
    • 数据采集需要覆盖:lighting、background、skin color、pose、occlusion、camera angle、distance,规模大致 5k ~ 20k 张图像。自动标注可以用 MediaPipe Hands 或 SAM + 关键点提示(见下方示意脚本)。
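
    • 下面是一个基于 MediaPipe Hands 的自动标注简化示意脚本(输出类别固定为 0,bbox 外扩比例 margin 为假设值;自动标注结果仍需人工抽检):

      import cv2
      import mediapipe as mp
      
      def auto_label(img_path, txt_path, margin=0.05):
          """用 MediaPipe Hands 为单张图片生成 YOLO-pose 格式标注(坐标已归一化)"""
          img = cv2.imread(img_path)
          with mp.solutions.hands.Hands(static_image_mode=True, max_num_hands=2) as hands:
              res = hands.process(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
          if not res.multi_hand_landmarks:
              return
          lines = []
          for hand in res.multi_hand_landmarks:
              xs = [lm.x for lm in hand.landmark]   # MediaPipe 输出本身就是归一化坐标
              ys = [lm.y for lm in hand.landmark]
              x1, x2 = max(min(xs) - margin, 0), min(max(xs) + margin, 1)
              y1, y2 = max(min(ys) - margin, 0), min(max(ys) + margin, 1)
              box = [(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1]  # cx cy w h
              kpts = [v for lm in hand.landmark for v in (lm.x, lm.y, 2)]  # 21 点,可见性示意为 2
              lines.append(' '.join(['0'] + [f'{v:.6f}' for v in box + kpts]))
          with open(txt_path, 'w') as f:
              f.write('\n'.join(lines))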

  • 手势关键点检测项目的核心流程(基于 YOLO-Pose)

    • 核心环节 具体操作(YOLO-Pose 落地要点)
      数据准备 ① 采用 Hand-Keypoints 数据集(或自定义标注手部数据);② 按 YOLO 格式整理(图像 + txt 标注,txt 含 bbox + 21 个关键点坐标);③ 如原始标注为 COCO JSON,可用 JSON2YOLO 工具转换为 YOLO 格式。
      模型训练 ① 加载预训练 yolo11n-pose.pt(迁移学习,复用人体关键点特征);② 执行命令:yolo pose train data=hand-keypoints.yaml model=yolo11n-pose.pt epochs=100 imgsz=640;③ PoseTrainer 自动处理关键点损失(pose_loss/kobj_loss),无需手动定义损失函数。
      模型验证 yolo val pose data=hand-keypoints.yaml device=0,核心关注 mAP_pose50-95(关键点检测精度)与速度(CPU/GPU 推理耗时)。
      推理部署 ① 实时检测:yolo pose predict model=yolo11n-pose.pt source=0(摄像头实时手势检测);② 轻量化部署:导出 ONNX/TensorRT 格式,适配嵌入式设备(如 Jetson Nano)。

数据集构建与标注格式 (Data & Annotation)

  • 手势通常定义有 21 个关键点(如 MediaPipe Hands 标准)。YOLO-pose 要求数据集目录结构非常规整,请严格遵循:

    • hand_pose_dataset/
      ├── images/
      │   ├── train/  # 存放训练集图片 (jpg/png)
      │   └── val/    # 存放验证集图片
      └── labels/
          ├── train/  # 存放对应的标注文件 (.txt)
          └── val/
      
    • Ultralytics YOLO-pose 采用自定义的 TXT 格式,不直接使用 JSON。每一张图片对应一个同名的 .txt 文件。标注文件内容(一行代表一个手部目标):<class_id> <x_center> <y_center> <width> <height> <px1> <py1> <v1> <px2> <py2> <v2> ... <px21> <py21> <v21>。示例(假设只有 3 个关键点简化展示):0 0.5 0.5 0.3 0.4 0.45 0.45 2 0.55 0.45 2 0.5 0.6 1

    • class_id: 类别索引(手势检测通常设为 0,因为只有 “手” 这一个大类)。

    • x_center, y_center, width, height: 手部检测框(BBox)的坐标,必须归一化到 [0, 1] 之间。

    • px, py: 关键点坐标,同样需要归一化。

    • v (visibility): 关键点可见性标志。0: 关键点未标注(不存在),1: 关键点存在但被遮挡(不可见),2: 关键点存在且可见
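
    • 可以用一个简化的校验函数检查每行标注是否符合上述格式(假设坐标已归一化,21 个关键点时字段数应为 5 + 21×3 = 68;仅作自查示意):

      def check_label_line(line, nk=21):
          vals = [float(v) for v in line.split()]
          assert len(vals) == 5 + nk * 3, f"字段数应为 {5 + nk * 3},实际 {len(vals)}"
          cls, box, kpts = int(vals[0]), vals[1:5], vals[5:]
          assert all(0.0 <= v <= 1.0 for v in box), "bbox 坐标必须归一化到 [0, 1]"
          for i in range(nk):
              x, y, v = kpts[3 * i: 3 * i + 3]
              assert v in (0, 1, 2), f"第 {i} 个关键点可见性非法: {v}"
              if v > 0:
                  assert 0.0 <= x <= 1.0 and 0.0 <= y <= 1.0, f"第 {i} 个关键点坐标越界"
          return cls, box, kpts
      
      # 用文中的 3 关键点示例行自测(nk=3)
      check_label_line("0 0.5 0.5 0.3 0.4 0.45 0.45 2 0.55 0.45 2 0.5 0.6 1", nk=3)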

训练跑通项目 (Implementation)

  • 创建一个 hand_pose.yaml 文件来告诉模型数据在哪里、有多少个关键点。

    • # hand_pose.yaml
      path: /path/to/hand_pose_dataset  # 数据集根目录
      train: images/train
      val: images/val
      # 关键点配置
      kpt_shape: [21, 3]  # [关键点数量, 维度(2=xy, 3=xy+visibility)]
      flip_idx: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20] # 左右翻转时的关键点对应关系(用于数据增强);单手 21 点没有左右对称点对,保持恒等映射即可
      # 类别
      names:
        0: hand
      
  • 使用 Python API 进行训练,这是最灵活的方式:

    • from ultralytics import YOLO
      # 1. 加载模型
      # 推荐使用 yolov8n-pose.pt (Nano版,速度快) 或 yolov8s-pose.pt 进行迁移学习
      model = YOLO('yolov8s-pose.pt') 
      # 2. 训练
      results = model.train(
          data='hand_pose.yaml',
          epochs=100,
          imgsz=640,
          batch=32,
          device='0',  # 使用 GPU
          project='hand_pose_project',
          name='exp1'
      )
      # 3. 验证
      metrics = model.val()
      # 4. 推理
      source = 'test_video.mp4' # 或者是图片路径
      results = model(source, show=True, save=True)
      
    • YOLO-Pose 是一个 Anchor-Free 的单阶段模型,核心设计有两点:并行输出——在同一个检测头(Head)里同时输出目标检测框(BBox)和关键点坐标(Keypoints);解耦头——分类和回归任务分开处理,对关键点回归精度提升明显。

    • 训练时的损失由多项组成,理解它对调参至关重要:$Loss = L_{box} + L_{cls} + L_{kpts} + L_{obj}$

      • Lbox: 检测框损失 (CIoU Loss)。
      • Lcls: 类别损失 (Hand/Not Hand)。
      • Lkpts: 关键点损失。通常使用 OKS Loss (Object Keypoint Similarity) 或者简单的 MSE Loss。只有当 v>0 时,该点的损失才会被计算。
  • 手势关键点检测的核心指标是 OKS (Object Keypoint Similarity)。它类似于目标检测的 IoU,但考虑了关键点的标准差。简单来说,预测点和标注点离得越近,OKS 越高。公式如下:

    • $OKS = \dfrac{\sum_i \exp\left(-d_i^2 / (2 s^2 k_i^2)\right)\,\delta(v_i > 0)}{\sum_i \delta(v_i > 0)}$
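
    • 按上式可以写一个极简的 OKS 计算示意(numpy 实现;s² 用目标面积近似,k 为各关键点的归一化因子,具体取值为假设):

      import numpy as np
      
      def oks(pred, gt, vis, area, k):
          """pred/gt: (nk, 2) 关键点坐标; vis: (nk,) 可见性; area: 目标面积; k: (nk,) 归一化因子"""
          d2 = np.sum((pred - gt) ** 2, axis=1)          # 各关键点的平方距离 d_i^2
          mask = vis > 0                                  # 只统计有标注的点
          e = np.exp(-d2 / (2 * area * k ** 2 + 1e-9))
          return float((e * mask).sum() / max(mask.sum(), 1))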
  • 手部关节自由度高、姿态和视角变化大,默认的 YOLO 增强可能不够:

    • 务必开启:hsv_h、hsv_s、hsv_v(颜色抖动)和 degrees(旋转)。小心使用随机裁剪(Random Crop):它可能把手切掉,导致关键点丢失,建议使用 Safe Augmentation 或针对手部检测优化的裁剪策略。
    • flip、rotate、scale、mosaic、motion blur、brightness、noise
  • 即使手指被遮挡,只要在画面内,建议标注其估计位置并设为 v=1,这比直接设为 0 更有助于模型学习结构。输出结果中,通常会附带一个 keypoints 数组。你可以根据 bbox 的置信度过滤掉低质量的检测,再提取关键点坐标进行后续的手势分类(如握拳、比耶)。
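
    • 一个基于 Ultralytics 推理结果做过滤和简单规则分类的示意(权重路径、阈值与"指点"手势规则均为假设,仅演示如何取 bbox 置信度与关键点坐标):

      from ultralytics import YOLO
      
      model = YOLO('best.pt')                 # 训练好的手部 pose 权重(路径为示意)
      res = model('test.jpg')[0]
      for conf, kpts in zip(res.boxes.conf, res.keypoints.xy):
          if conf < 0.5:                      # 先按 bbox 置信度过滤低质量检测
              continue
          k = kpts.cpu().numpy()              # (21, 2) 像素坐标
          # 示意规则:食指指尖(8)高于其 MCP(5),其余指尖低于各自 MCP,判定为"指点"
          pointing = k[8][1] < k[5][1] and all(k[t][1] > k[t - 3][1] for t in (12, 16, 20))
          print('pointing' if pointing else 'other', float(conf))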

  • keypoint 顺序必须一致,最常见的 bug 是 index 顺序错位,导致 loss 不收敛(可用下面的小脚本把标注画回原图自查)。YOLO pose 需要 bbox,因为 keypoints 依附于 bbox。单目标输出维度:bbox 4、class 1、keypoints 21×3;head 输出:[batch, anchors, (4 + 1 + 63)]。
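
    • 排查 index 顺序问题最直接的办法是把标注画回原图、肉眼核对序号(简化自查脚本,文件路径为示意):

      import cv2
      
      def vis_label(img_path, txt_path, nk=21):
          """把 txt 标注里的关键点及其序号画到原图上,检查顺序是否与定义一致"""
          img = cv2.imread(img_path)
          h, w = img.shape[:2]
          for line in open(txt_path):
              vals = list(map(float, line.split()))
              kpts = vals[5:5 + nk * 3]
              for i in range(nk):
                  x, y, v = kpts[3 * i] * w, kpts[3 * i + 1] * h, kpts[3 * i + 2]
                  if v > 0:
                      cv2.circle(img, (int(x), int(y)), 3, (0, 255, 0), -1)
                      cv2.putText(img, str(i), (int(x) + 3, int(y) - 3),
                                  cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1)
          cv2.imwrite('vis_check.jpg', img)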

  • 很多算法工程师会犯的错误是沉迷改模型。其实真正重要的是 数据 > 训练策略 > 模型,三者对效果的贡献大致为:数据约 60%、训练策略约 25%、模型约 15%。

YOLO-pose 在 Ultralytics 中的整体源码架构

  • GitHub:ultralytics/ultralytics 核心目录:

    • ultralytics/
      │
      ├── models/
      │   ├── yolo/
      │   │   ├── detect
      │   │   ├── pose
      │   │   └── segment
      │
      ├── nn/
      │   ├── modules
      │   ├── tasks.py
      │
      ├── data/
      │   ├── dataset.py
      │   └── augment.py
      │
      ├── engine/
      │   ├── trainer.py
      │   ├── validator.py
      │   └── predictor.py
      │
      └── utils/
      
  • 这是 YOLO pose从数据 → 模型 → 训练 → 推理的完整流程

    •                 ┌─────────────────────┐
                      │  Dataset Loader     │
                      │ PoseDataset         │
                      └──────────┬──────────┘
                                 ▼
                       ┌─────────────────┐
                       │ Data Augment    │
                       │ augment.py      │
                       └────────┬────────┘
                                ▼
                      ┌──────────────────┐
                      │ Model Builder    │
                      │ tasks.py         │
                      └────────┬─────────┘
                               ▼
                     ┌─────────────────────┐
                     │ YOLO Backbone       │
                     │ CSP / Conv blocks   │
                     └────────┬────────────┘
                              ▼
                      ┌───────────────────┐
                      │ Neck (FPN/PAN)    │
                      └────────┬──────────┘
                               ▼
                      ┌────────────────────┐
                      │ Pose Head          │
                      │ PoseDetect         │
                      └────────┬───────────┘
                               ▼
                       ┌────────────────┐
                       │ Loss Builder   │
                       │ PoseLoss       │
                       └───────┬────────┘
                               ▼
                       ┌────────────────┐
                       │ Trainer        │
                       │ engine/        │
                       └───────┬────────┘
                               ▼
                        ┌───────────────┐
                        │ Predictor     │
                        └───────────────┘
      
    • 最关键的 5个文件

      功能 文件
      模型构建 nn/tasks.py
      pose head nn/modules/head.py
      dataset data/dataset.py
      loss utils/loss.py
      trainer engine/trainer.py
  • 真正创建 YOLO 模型的是:ultralytics/nn/tasks.py,核心类 class DetectionModel(BaseModel),pose 模型也是 detection model 的扩展。关键函数 def parse_model() 负责读取 yaml、构建网络,结构分为 backbone、neck、head。在 Ultralytics YOLO 中模型不是手写,而是由 YAML 自动生成。比如:

    • nc: 1
      kpt_shape: [21,3]
      
      backbone:
        - [-1,1,Conv,[64,3,2]]
        - [-1,3,C2f,[128]]
      
      head:
        - [-1,1,PoseDetect,[nc,kpt_shape]]
      
    • YAML–>list structure–>loop build layers–>nn.Module。YAML每行结构 [from, number, module, args],例如 [-1,1,Conv,[64,3,2]],表示 输入:上一层、重复:1、模块:Conv、参数:64,3,2。
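
    • 按 [from, number, module, args] 循环建层的思路可以用一个极简解析器示意(并非 Ultralytics parse_model 源码,仅支持 Conv、仅演示 YAML→nn.Module 的过程):

      import torch.nn as nn
      
      def build_from_cfg(cfg, in_ch=3):
          """cfg: [[from, number, module, args], ...];返回按配置堆叠的 nn.Sequential"""
          layers, ch = [], [in_ch]
          for f, n, m, args in cfg:
              c1 = ch[f]                       # from=-1 表示取上一层的输出通道
              if m != 'Conv':
                  raise NotImplementedError(m)
              c2, k, s = args
              block = nn.Sequential(*[
                  nn.Sequential(nn.Conv2d(c1 if i == 0 else c2, c2, k, s, k // 2),
                                nn.BatchNorm2d(c2), nn.SiLU())
                  for i in range(n)            # number 表示该模块重复次数
              ])
              layers.append(block)
              ch.append(c2)
          return nn.Sequential(*layers)
      
      # 对应文中 backbone 片段(把 C2f 简化成了重复 Conv)
      net = build_from_cfg([[-1, 1, 'Conv', [64, 3, 2]], [-1, 3, 'Conv', [128, 3, 1]]])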

  • YOLO-pose head 实现(最关键),核心类 class Pose(Detect)(下文简写为 PoseDetect),继承自 Detect。核心参数有 nc → 类别数、kpt_shape → 关键点形状;如果是 21 个关键点,关键点输出维度 nk = 21 × 3 = 63。Pose Head 输出结构:假设 1 class、21 keypoints,则输出 4 bbox + 1 cls + 63 keypoints,共 68 维,tensor:[B, anchors, 68]。

  • keypoint prediction 实现,关键代码逻辑 kpt = x[:, 5:]。x = [bbox + class + keypoints],拆分 bbox = x[:, :4]、cls = x[:, 4:5]、kpt = x[:, 5:],源码核心结构(简化):

    • class PoseDetect(Detect):
          def __init__(self, nc=1, kpt_shape=(21,3), ch=()):
              super().__init__(nc, ch)  # 初始化 detection head
              self.kpt_shape = kpt_shape
              self.nk = kpt_shape[0] * kpt_shape[1]
              # Head卷积结构:每个尺度特征图接一个 1×1 conv,输出 nk=63 个通道
              self.cv4 = nn.ModuleList(
                  nn.Conv2d(x, self.nk, 1) for x in ch
              )
      
    • 每个 feature map 都有一个1×1 conv。每个 feature map 都会预测 63 channels。Detect head 会创建 cv2 → bbox conv、cv3 → class conv。如果三层 feature:P3、P4、P5。就有三个 keypoint head。

  • loss计算,ultralytics/utils/loss.py。核心类 v8PoseLoss,loss 结构包含 box loss、cls loss、dfl loss、关键点位置 loss(pose_loss)与关键点置信度 loss(kobj_loss)。关键点位置损失只对 visibility > 0 的点计算(vis = gt_kpt[..., 2] > 0);Ultralytics 实现采用 OKS 形式的损失,简化理解时也可以看成 |pred - gt| 这类 L1 形式(见下方简化代码)。
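
    • 只按可见性掩码计算 L1 位置损失的简化示意(与 Ultralytics 实际的 OKS 形式损失在细节上不同,仅用于理解"只算 v>0 的点"):

      import torch
      
      def kpt_l1_loss(pred_kpt, gt_kpt):
          """pred_kpt/gt_kpt: (N, nk, 3),最后一维为 (x, y, v);仅对 v>0 的点计算 L1"""
          vis = gt_kpt[..., 2] > 0                       # 可见性掩码
          if vis.sum() == 0:
              return pred_kpt.sum() * 0                  # 无可见点时返回 0,保留梯度图
          diff = (pred_kpt[..., :2] - gt_kpt[..., :2]).abs()
          return (diff * vis.unsqueeze(-1)).sum() / (vis.sum() * 2)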

  • dataset pipeline,data/dataset.py。关键类是 YOLODataset(pose 任务复用该类并解析关键点,下文统称 PoseDataset),流程 image–>load label–>load keypoints–>augmentation–>tensor。

  • 数据增强,data/augment.py,关键函数RandomPerspective,同步变换 image、bbox、keypoints。

  • trainer结构,engine/trainer.py 定义 BaseTrainer;pose 任务的核心类 PoseTrainer 继承自 DetectionTrainer,最终继承 BaseTrainer。流程 train()–>build_dataset()–>build_model()–>forward–>loss–>optimizer。head 的 forward 过程(简化):

    • def forward(self, x):  # x = [P3, P4, P5]
          # 输入多尺度特征:P3 [B,C,80,80]、P4 [B,C,40,40]、P5 [B,C,20,20]
          z = []
          for i in range(self.nl):
              box = self.cv2[i](x[i])
              cls = self.cv3[i](x[i])
              kpt = self.cv4[i](x[i])
              y = torch.cat((box, cls, kpt), 1)
              # reshape 为 anchor 形式:[B, dims, H, W] -> [B, anchors, dims]
              bs, _, ny, nx = y.shape
              y = y.view(bs, -1, ny * nx)
              y = y.permute(0, 2, 1)
              z.append(y)
          return z

      
      
    • 每个 feature map 输出 bbox、class、keypoints 然后 concat。keypoint reshape kpt = kpt.view(bs, self.nk, -1).最终 reshape [B, 63, H*W],再转换为 [B, anchors, 63]

    • bbox prediction 计算来自 box = self.cv2[i](x[i]),输出 [B,4,H,W],代表 tx ty tw th。class prediction,cls = self.cv3[i](x[i]),输出 [B,nc,H,W]。keypoint prediction ,kpt = self.cv4[i](x[i]),输出 [B,63,H,W],表示 21 keypoints × (x,y,v)。concat y = torch.cat((box, cls, kpt),1),输出 [B, 4+nc+63 ,H,W]。

    • INPUT IMAGE
           │
           ▼
      DataLoader
      (PoseDataset)
           │
           ▼
      Augmentation
      (RandomPerspective)
           │
           ▼
      Image Tensor
      [B,3,H,W]
           │
           ▼
      Backbone
      (C2f + Conv)
           │
           ▼
      Feature Maps
      P3
      P4
      P5
           │
           ▼
      Neck
      (FPN + PAN)
           │
           ▼
      Multi-scale Features
           │
           ▼
      PoseDetect Head
           │
           ├── bbox prediction
           │
           ├── class prediction
           │
           └── keypoint prediction
                 │
                 ▼
      Tensor
      [B, anchors, 4 + 1 + 63]
           │
           ▼
      Loss Builder
      (v8PoseLoss)
           │
           ├── box_loss
           ├── cls_loss
           ├── dfl_loss
           └── kpt_loss
           │
           ▼
      Total Loss
           │
           ▼
      Backward
           │
           ▼
      Optimizer Update
      
  • 推理流程,engine/predictor.py。关键步骤 image–>preprocess–>model forward–>NMS–>decode keypoints

  • 关键点 decode:head 输出的是相对于网格的预测值(relative bbox、relative keypoints),需要先加上网格 offset、再乘以 stride,才能转换成 image coordinates:

    • kpt[..., 0::3] = (kpt[..., 0::3] * 2 + grid_x) * stride
      kpt[..., 1::3] = (kpt[..., 1::3] * 2 + grid_y) * stride
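
    • grid 与 stride 的构造可以按下面的单尺度版本示意(×2 的 decode 系数沿用上式写法,具体取值随 head 实现而定):

      import torch
      
      def decode_kpts(kpt, ny, nx, stride):
          """kpt: [B, anchors, 63],anchors = ny*nx;把网格相对预测还原为图像坐标"""
          gy, gx = torch.meshgrid(torch.arange(ny), torch.arange(nx), indexing='ij')
          grid_x = gx.reshape(1, -1, 1).float()          # [1, anchors, 1],与 21 个 x 通道广播
          grid_y = gy.reshape(1, -1, 1).float()
          kpt = kpt.clone()
          kpt[..., 0::3] = (kpt[..., 0::3] * 2 + grid_x) * stride
          kpt[..., 1::3] = (kpt[..., 1::3] * 2 + grid_y) * stride
          return kpt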
      
  • import torch
    import torch.nn as nn
    import torch.nn.functional as F
    # Backbone
    class Conv(nn.Module):
        def __init__(self,c1,c2,k=3,s=1):
            super().__init__()
            p = k // 2  # 自动计算padding
            self.conv = nn.Conv2d(c1,c2,k,s,p)
            self.bn = nn.BatchNorm2d(c2)
            self.act = nn.SiLU()
        def forward(self,x):
            return self.act(self.bn(self.conv(x)))
    class Backbone(nn.Module):
        def __init__(self):
            super().__init__()
            self.layer1 = Conv(3,32,3,2)
            self.layer2 = Conv(32,64,3,2)
            self.layer3 = Conv(64,128,3,2)
        def forward(self,x):
            x = self.layer1(x)
            p3 = self.layer2(x)
            p4 = self.layer3(p3)
            return [p3,p4]
    class SimpleNeck(nn.Module):
        def __init__(self):
            super().__init__()
            # 统一P3通道到128
            self.p3_conv = Conv(64, 128, 1)
            # 融合后处理
            self.fuse_conv = Conv(256, 128, 3)
        def forward(self, feats):
            p3, p4 = feats
            # 统一 P3 通道数:64 -> 128
            p3 = self.p3_conv(p3)
            # P4 分辨率比 P3 低一倍,先把 P4 上采样到 P3 尺度再融合
            p4_up = F.interpolate(p4, scale_factor=2, mode='nearest')
            # 通道维度拼接融合:128 + 128 = 256
            fused = torch.cat([p3, p4_up], dim=1)
            return self.fuse_conv(fused)
    # Pose Head
    class PoseHead(nn.Module):
        def __init__(self,nc=1,nk=21):
            super().__init__()
            self.nc = nc
            self.nk = nk*3
            self.box = nn.Conv2d(128,4,1)
            self.cls = nn.Conv2d(128,nc,1)
            self.kpt = nn.Conv2d(128,self.nk,1)
        def forward(self,x):
            box = self.box(x)
            cls = self.cls(x)
            kpt = self.kpt(x)
            out = torch.cat([box,cls,kpt],1)
            return out
    class MultiScalePoseHead(nn.Module):
        def __init__(self, nc=1, nk=21):
            super().__init__()
            self.nc = nc
            self.nk = nk * 3  # x, y, conf
            # P3头(64通道)
            self.p3_box = nn.Conv2d(64, 4, 1)
            self.p3_cls = nn.Conv2d(64, nc, 1)
            self.p3_kpt = nn.Conv2d(64, self.nk, 1)
            # P4头(128通道)
            self.p4_box = nn.Conv2d(128, 4, 1)
            self.p4_cls = nn.Conv2d(128, nc, 1)
            self.p4_kpt = nn.Conv2d(128, self.nk, 1)
        def forward(self, feats):
            p3, p4 = feats
            # P3输出
            p3_out = torch.cat([self.p3_box(p3), self.p3_cls(p3), self.p3_kpt(p3)], 1)
            # P4输出
            p4_out = torch.cat([self.p4_box(p4), self.p4_cls(p4), self.p4_kpt(p4)], 1)
            # 返回多尺度列表(损失计算时分别处理)
            return [p3_out, p4_out]
    # Model
    class MiniYOLOPose(nn.Module):
        def __init__(self, multiHead=True):
            super().__init__()
            self.multiHead = multiHead
            self.backbone = Backbone()
            if multiHead:
                self.head = MultiScalePoseHead()  # 多尺度头
            else:
                self.neck = SimpleNeck()
                self.head = PoseHead()
        def forward(self, x):
            feats = self.backbone(x)
            if self.multiHead:
                out = self.head(feats)
            else:
                fused = self.neck(feats)
                out = self.head(fused)
            return out
    # Test
    model = MiniYOLOPose()
    img = torch.randn(1, 3, 640, 640)
    out = model(img)
    # 多尺度头返回 [P3, P4] 两个尺度的输出列表;单头模式返回单个 tensor
    print([o.shape for o in out] if isinstance(out, list) else out.shape)
    
  • 数据加载代码 (utils/dataloader.py)

    • import os
      import cv2
      import numpy as np
      import torch
      from torch.utils.data import Dataset, DataLoader
      from torchvision import transforms
      import random
      class PoseDataset(Dataset):
          def __init__(self, img_dir, label_dir, img_size=640, nk=21, augment=True):
              """
              img_dir: 图片文件夹路径
              label_dir: 标注文件夹路径
              img_size: 输入图像尺寸
              nk: 关键点数量
              augment: 是否数据增强
              """
              self.img_dir = img_dir
              self.label_dir = label_dir
              self.img_size = img_size
              self.nk = nk
              self.augment = augment
              # 获取所有图片路径
              self.img_files = [f for f in os.listdir(img_dir) if f.endswith(('.jpg', '.png', '.jpeg'))]
              self.label_files = [os.path.splitext(f)[0] + '.txt' for f in self.img_files]
          def __len__(self):
              return len(self.img_files)
          def __getitem__(self, idx):
              # 读取图片
              img_path = os.path.join(self.img_dir, self.img_files[idx])
              img = cv2.imread(img_path)
              img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
              h, w = img.shape[:2]
              # 读取标注
              label_path = os.path.join(self.label_dir, self.label_files[idx])
              labels = self._load_labels(label_path, w, h)
              # 数据增强
              if self.augment:
                  img, labels = self._augment(img, labels)
              # 归一化到0-1
              img = img.astype(np.float32) / 255.0
              img = transforms.ToTensor()(img)
              return img, labels
          def _load_labels(self, label_path, img_w, img_h):
              """
              加载YOLO格式标注
              格式: class cx cy w h kpt1_x kpt1_y kpt1_vis kpt2_x kpt2_y kpt2_vis ...
              """
              labels = []
              if os.path.exists(label_path):
                  with open(label_path, 'r') as f:
                      for line in f:
                          data = list(map(float, line.strip().split()))
                          if len(data) < 5 + self.nk * 3:
                              continue
                          cls = data[0]
                          cx, cy, bw, bh = data[1:5]
                          kpts = np.array(data[5:]).reshape(-1, 3)  # (nk, 3)
                          # 转换为绝对坐标
                          cx, cy = cx * img_w, cy * img_h
                          bw, bh = bw * img_w, bh * img_h
                          kpts[:, 0] *= img_w  # kpt_x
                          kpts[:, 1] *= img_h  # kpt_y
                          # kpts[:, 2] 保持可见性 0/1/2
                          labels.append({
                              'cls': cls,
                              'box': np.array([cx, cy, bw, bh]),
                              'kpts': kpts
                          })
              return labels
          def _augment(self, img, labels):
              """简单数据增强"""
              # 随机水平翻转
              if random.random() > 0.5:
                  img = cv2.flip(img, 1)
                  h, w = img.shape[:2]
                  for label in labels:
                      label['box'][0] = w - label['box'][0]  # cx翻转
                      label['kpts'][:, 0] = w - label['kpts'][:, 0]  # kpt_x翻转
              
              # 随机缩放
              scale = random.uniform(0.8, 1.2)
              new_size = int(self.img_size * scale)
              img = cv2.resize(img, (new_size, new_size))
              for label in labels:
                  label['box'] *= scale
                  label['kpts'][:, :2] *= scale
              # 裁剪到目标尺寸
              img = cv2.resize(img, (self.img_size, self.img_size))
              scale_factor = self.img_size / new_size
              for label in labels:
                  label['box'] *= scale_factor
                  label['kpts'][:, :2] *= scale_factor
              return img, labels
      def collate_fn(batch):
          """自定义collate函数,处理不同数量的目标"""
          imgs = []
          targets = []
          
          for i, (img, labels) in enumerate(batch):
              imgs.append(img)
              for label in labels:
                  targets.append([
                      i,                      # 图片索引
                      label['cls'],           # 类别
                      *label['box'],          # cx, cy, w, h
                      *label['kpts'].flatten()  # 关键点展平
                  ])
          imgs = torch.stack(imgs, 0)
          targets = torch.tensor(targets, dtype=torch.float32) if targets else torch.zeros((0, 2 + 4 + 21 * 3))  # 每行: [batch_idx, cls, box(4), kpts(63)]
          return imgs, targets
      def create_dataloader(img_dir, label_dir, batch_size=16, img_size=640, nk=21, augment=True):
          dataset = PoseDataset(img_dir, label_dir, img_size, nk, augment)
          dataloader = DataLoader(
              dataset,
              batch_size=batch_size,
              shuffle=True,
              num_workers=4,
              collate_fn=collate_fn,
              pin_memory=True
          )
          return dataloader
      
  • 损失函数代码 (utils/loss.py)

    • import torch
      import torch.nn as nn
      import torch.nn.functional as F
      import math
      class YOLOPoseLoss(nn.Module):
          def __init__(self, nc=1, nk=21, box_weight=7.5, cls_weight=0.5, kpt_weight=0.05):
              """
              nc: 类别数
              nk: 关键点数量
              box_weight: 边界框损失权重
              cls_weight: 分类损失权重
              kpt_weight: 关键点损失权重
              """
              super().__init__()
              self.nc = nc
              self.nk = nk
              self.box_weight = box_weight
              self.cls_weight = cls_weight
              self.kpt_weight = kpt_weight
              # 边界框损失
              self.bce_cls = nn.BCEWithLogitsLoss(reduction='none')
              self.bce_kpt = nn.BCEWithLogitsLoss(reduction='none')
              # OKS阈值(用于关键点匹配)
              self.oks_sigmas = torch.ones(nk) * 0.025  # 可根据数据集调整
          def forward(self, preds, targets, img_size=640):
              """
              preds: 模型输出 [B, 4+nc+nk*3, H, W]
              targets: 标注 [N, 5+nk*3] (batch_idx, cls, cx, cy, w, h, kpts...)
              img_size: 输入图像尺寸
              """
              device = preds.device
              bs, _, h, w = preds.shape
              stride = img_size / h  # 假设单尺度,实际需根据P3/P4分别计算    
              # 解析预测值
              preds = preds.permute(0, 2, 3, 1).reshape(bs, -1, 4 + self.nc + self.nk * 3)
              box_pred = preds[..., :4]           # (B, N, 4)
              cls_pred = preds[..., 4:4+self.nc]  # (B, N, nc)
              kpt_pred = preds[..., 4+self.nc:]   # (B, N, nk*3)
              # 初始化损失
              loss_box = torch.tensor(0.0, device=device)
              loss_cls = torch.tensor(0.0, device=device)
              loss_kpt = torch.tensor(0.0, device=device)
              n_pos = 0
              if len(targets) > 0:
                  # 获取锚点网格
                  grid_y, grid_x = torch.meshgrid(torch.arange(h, device=device), 
                                                  torch.arange(w, device=device), indexing='ij')
                  anchors = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0).expand(bs, -1, -1, -1)
                  anchors = anchors.reshape(bs, -1, 2) * stride  # (B, H*W, 2)
                  # 匹配策略:简化版IoU匹配
                  for b in range(bs):
                      batch_targets = targets[targets[:, 0] == b]
                      if len(batch_targets) == 0:
                          continue
                      gt_boxes = batch_targets[:, 2:6]  # cx, cy, w, h
                      gt_kpts = batch_targets[:, 6:].reshape(-1, self.nk, 3)  # (N_gt, nk, 3)
                      gt_cls = batch_targets[:, 1]
                      # 计算预测框与GT的IoU(简化)
                      # 这里使用中心点距离作为匹配依据
                      pred_cx = box_pred[b, :, 0]
                      pred_cy = box_pred[b, :, 1]
                      matches = []
                      for i, gt in enumerate(gt_boxes):
                          # 找到最近的预测点
                          dist = (pred_cx - gt[0])**2 + (pred_cy - gt[1])**2
                          best_match = torch.argmin(dist)
                          matches.append((best_match, i))
                      # 计算损失
                      for pred_idx, gt_idx in matches:
                          n_pos += 1
                          # 1. 边界框损失 (CIoU简化版)
                          pred_box = box_pred[b, pred_idx]
                          gt_box = gt_boxes[gt_idx]
                          loss_box += self._box_loss(pred_box, gt_box)
                          # 2. 分类损失
                          cls_target = torch.zeros(self.nc, device=device)
                          cls_target[int(gt_cls[gt_idx])] = 1.0
                          loss_cls += self.bce_cls(cls_pred[b, pred_idx], cls_target).sum()
                          # 3. 关键点损失
                          pred_kpt = kpt_pred[b, pred_idx].reshape(self.nk, 3)
                          gt_kpt = gt_kpts[gt_idx]
                          loss_kpt += self._kpt_loss(pred_kpt, gt_kpt, stride)
                  # 平均损失
                  if n_pos > 0:
                      loss_box = loss_box / n_pos
                      loss_cls = loss_cls / n_pos
                      loss_kpt = loss_kpt / n_pos
              # 加权总损失
              total_loss = (self.box_weight * loss_box + 
                           self.cls_weight * loss_cls + 
                           self.kpt_weight * loss_kpt)
              return total_loss, {
                  'box': loss_box.item(),
                  'cls': loss_cls.item(),
                  'kpt': loss_kpt.item(),
                  'total': total_loss.item()
              }
          def _box_loss(self, pred, gt):
              """简化CIoU损失"""
              pred_cx, pred_cy, pred_w, pred_h = pred.chunk(4, -1)
              gt_cx, gt_cy, gt_w, gt_h = gt.chunk(4, -1)
              # 转换为x1y1x2y2
              pred_x1 = pred_cx - pred_w / 2
              pred_y1 = pred_cy - pred_h / 2
              pred_x2 = pred_cx + pred_w / 2
              pred_y2 = pred_cy + pred_h / 2
              gt_x1 = gt_cx - gt_w / 2
              gt_y1 = gt_cy - gt_h / 2
              gt_x2 = gt_cx + gt_w / 2
              gt_y2 = gt_cy + gt_h / 2
              # IoU计算
              inter_x1 = torch.max(pred_x1, gt_x1)
              inter_y1 = torch.max(pred_y1, gt_y1)
              inter_x2 = torch.min(pred_x2, gt_x2)
              inter_y2 = torch.min(pred_y2, gt_y2)
              inter_w = torch.clamp(inter_x2 - inter_x1, min=0)
              inter_h = torch.clamp(inter_y2 - inter_y1, min=0)
              inter_area = inter_w * inter_h
              pred_area = pred_w * pred_h
              gt_area = gt_w * gt_h
              union_area = pred_area + gt_area - inter_area
              iou = torch.clamp(inter_area / (union_area + 1e-6), min=0, max=1)
              # 中心点距离
              center_dist = (pred_cx - gt_cx)**2 + (pred_cy - gt_cy)**2
              # 对角线距离
              cw = torch.max(pred_x2, gt_x2) - torch.min(pred_x1, gt_x1)
              ch = torch.max(pred_y2, gt_y2) - torch.min(pred_y1, gt_y1)
              diag_dist = cw**2 + ch**2
              ciou = iou - center_dist / (diag_dist + 1e-6)
              return 1 - ciou
          
          def _kpt_loss(self, pred_kpt, gt_kpt, stride):
              """
              关键点损失 = 坐标L1损失 + 可见性BCE损失
              pred_kpt: (nk, 3) [x, y, conf]
              gt_kpt: (nk, 3) [x, y, vis]
              """
              # 坐标损失(只计算可见的关键点)
              vis_mask = gt_kpt[:, 2] > 0  # 可见性>0
              if vis_mask.sum() == 0:
                  return torch.tensor(0.0, device=pred_kpt.device)
              coord_loss = F.l1_loss(pred_kpt[vis_mask, :2], gt_kpt[vis_mask, :2], reduction='sum')
              # 可见性损失
              vis_target = (gt_kpt[:, 2] > 0).float()
              vis_loss = self.bce_kpt(pred_kpt[:, 2], vis_target).sum()
              return (coord_loss / stride + vis_loss) / self.nk
      
  • 训练脚本 (train.py)

    • import torch
      import torch.nn as nn
      from torch.optim import AdamW
      from torch.optim.lr_scheduler import CosineAnnealingLR
      from tqdm import tqdm
      import os
      from models.yolo_pose import MiniYOLOPose
      from utils.dataloader import create_dataloader
      from utils.loss import YOLOPoseLoss
      def train(
          img_dir='data/images',
          label_dir='data/labels',
          epochs=100,
          batch_size=16,
          img_size=640,
          lr=0.001,
          device='cuda',
          save_dir='weights'
      ):
          # 创建保存目录
          os.makedirs(save_dir, exist_ok=True)
          # 设备
          device = torch.device(device if torch.cuda.is_available() else 'cpu')
          print(f"Using device: {device}")
          # 模型
          model = MiniYOLOPose(multiHead=False).to(device)  # 简化版损失假设单尺度输出,这里使用单头模式
          # 数据加载
          train_loader = create_dataloader(
              img_dir, label_dir, 
              batch_size=batch_size, 
              img_size=img_size,
              augment=True
          )
          # 损失函数
          criterion = YOLOPoseLoss(nc=1, nk=21).to(device)
          # 优化器
          optimizer = AdamW(model.parameters(), lr=lr, weight_decay=0.0005)
          scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
          # 训练循环
          best_loss = float('inf')
          for epoch in range(epochs):
              model.train()
              epoch_loss = 0
              epoch_box_loss = 0
              epoch_cls_loss = 0
              epoch_kpt_loss = 0
              pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')
              for imgs, targets in pbar:
                  imgs = imgs.to(device)
                  targets = targets.to(device)
                  # 前向传播
                  preds = model(imgs)
                  # 计算损失
                  loss, loss_dict = criterion(preds, targets, img_size=img_size)
                  # 反向传播
                  optimizer.zero_grad()
                  loss.backward()
                  optimizer.step()
                  # 记录损失
                  epoch_loss += loss_dict['total']
                  epoch_box_loss += loss_dict['box']
                  epoch_cls_loss += loss_dict['cls']
                  epoch_kpt_loss += loss_dict['kpt']
                  pbar.set_postfix({
                      'loss': f"{loss_dict['total']:.4f}",
                      'box': f"{loss_dict['box']:.4f}",
                      'cls': f"{loss_dict['cls']:.4f}",
                      'kpt': f"{loss_dict['kpt']:.4f}"
                  })
              # 学习率更新
              scheduler.step()
              # 平均损失
              n_batches = len(train_loader)
              avg_loss = epoch_loss / n_batches
              avg_box = epoch_box_loss / n_batches
              avg_cls = epoch_cls_loss / n_batches
              avg_kpt = epoch_kpt_loss / n_batches
              print(f"\nEpoch {epoch+1} Summary:")
              print(f"  Total Loss: {avg_loss:.4f}")
              print(f"  Box Loss: {avg_box:.4f}")
              print(f"  Cls Loss: {avg_cls:.4f}")
              print(f"  Kpt Loss: {avg_kpt:.4f}")
              # 保存最佳模型
              if avg_loss < best_loss:
                  best_loss = avg_loss
                  torch.save({
                      'epoch': epoch,
                      'model_state_dict': model.state_dict(),
                      'optimizer_state_dict': optimizer.state_dict(),
                      'loss': avg_loss,
                  }, os.path.join(save_dir, 'best_pose.pt'))
                  print(f"Saved best model with loss {best_loss:.4f}")
              # 每10epoch保存一次
              if (epoch + 1) % 10 == 0:
                  torch.save({
                      'epoch': epoch,
                      'model_state_dict': model.state_dict(),
                      'optimizer_state_dict': optimizer.state_dict(),
                      'loss': avg_loss,
                  }, os.path.join(save_dir, f'pose_epoch_{epoch+1}.pt'))
          print("\n Training completed!")
      if __name__ == '__main__':
          train(
              img_dir='data/images',
              label_dir='data/labels',
              epochs=100,
              batch_size=16,
              img_size=640,
              lr=0.001,
              device='cuda',
              save_dir='weights'
          )
      
  • 评估指标代码 (utils/metrics.py)

    • import torch
      import numpy as np
      from typing import Tuple, List, Dict
      import warnings
      class PoseMetrics:
          """YOLO-Pose 评估指标计算器"""
          def __init__(self, nc=1, nk=21, iou_thresholds=None, oks_sigmas=None):
              """
              nc: 类别数
              nk: 关键点数量
              iou_thresholds: IoU阈值列表(用于mAP)
              oks_sigmas: 关键点归一化因子(用于OKS)
              """
              self.nc = nc
              self.nk = nk
              self.iou_thresholds = iou_thresholds or np.arange(0.5, 1.0, 0.05)
              # 关键点 OKS sigmas(可根据数据集调整)
              if oks_sigmas is None:
                  # COCO 对 17 个人体关键点各有不同 sigma;手部数据集没有官方值,
                  # 这里统一取一个经验值,并保证长度与 nk 一致
                  self.oks_sigmas = np.ones(nk) * 0.035
              else:
                  self.oks_sigmas = np.array(oks_sigmas)
          def compute_iou(self, box1: torch.Tensor, box2: torch.Tensor) -> torch.Tensor:
              """
              计算边界框IoU
              box1: (N, 4) [cx, cy, w, h]
              box2: (M, 4) [cx, cy, w, h]
              return: (N, M) IoU矩阵
              """
              # 转换为x1y1x2y2
              b1_x1 = box1[:, 0] - box1[:, 2] / 2
              b1_y1 = box1[:, 1] - box1[:, 3] / 2
              b1_x2 = box1[:, 0] + box1[:, 2] / 2
              b1_y2 = box1[:, 1] + box1[:, 3] / 2
              
              b2_x1 = box2[:, 0] - box2[:, 2] / 2
              b2_y1 = box2[:, 1] - box2[:, 3] / 2
              b2_x2 = box2[:, 0] + box2[:, 2] / 2
              b2_y2 = box2[:, 1] + box2[:, 3] / 2
              
              # 交集
              inter_x1 = torch.max(b1_x1[:, None], b2_x1[None, :])
              inter_y1 = torch.max(b1_y1[:, None], b2_y1[None, :])
              inter_x2 = torch.min(b1_x2[:, None], b2_x2[None, :])
              inter_y2 = torch.min(b1_y2[:, None], b2_y2[None, :])
              
              inter_w = torch.clamp(inter_x2 - inter_x1, min=0)
              inter_h = torch.clamp(inter_y2 - inter_y1, min=0)
              inter_area = inter_w * inter_h
              
              # 并集
              b1_area = box1[:, 2] * box1[:, 3]
              b2_area = box2[:, 2] * box2[:, 3]
              union_area = b1_area[:, None] + b2_area[None, :] - inter_area
              
              iou = inter_area / (union_area + 1e-6)
              return iou
          
          def compute_oks(self, pred_kpts: np.ndarray, gt_kpts: np.ndarray, 
                          gt_area: np.ndarray) -> np.ndarray:
              """
              计算OKS (Object Keypoint Similarity)
              pred_kpts: (N, nk, 3) [x, y, conf]
              gt_kpts: (M, nk, 3) [x, y, vis]
              gt_area: (M,) 目标面积
              return: (N, M) OKS矩阵
              """
              n_pred = pred_kpts.shape[0]
              n_gt = gt_kpts.shape[0]
              
              if n_pred == 0 or n_gt == 0:
                  return np.zeros((n_pred, n_gt))
              
              oks = np.zeros((n_pred, n_gt))
              
              for i in range(n_pred):
                  for j in range(n_gt):
                      # 关键点距离
                      d = pred_kpts[i, :, :2] - gt_kpts[j, :, :2]  # (nk, 2)
                      d_squared = np.sum(d ** 2, axis=1)  # (nk,)
                      
                      # 可见性掩码
                      vis = gt_kpts[j, :, 2] > 0
                      
                      if vis.sum() == 0:
                          oks[i, j] = 0
                          continue
                      
                      # OKS 公式:分母为 2·s²·k²,其中 s² 用目标面积近似
                      numerator = np.exp(-d_squared / (2 * gt_area[j] * self.oks_sigmas ** 2 + 1e-9))
                      oks[i, j] = np.sum(numerator * vis) / vis.sum()
              
              return oks
          
          def compute_ap(self, scores: np.ndarray, ious: np.ndarray) -> float:
              """
              计算单个类别的AP
              scores: (N,) 置信度分数
              ious: (N,) 与GT的IoU/OKS
              """
              if len(scores) == 0:
                  return 0.0
              
              # 按置信度排序
              sorted_idx = np.argsort(scores)[::-1]
              scores = scores[sorted_idx]
              ious = ious[sorted_idx]
              
              # 计算precision-recall
              n_gt = len(ious)
              tp = np.zeros(len(ious))
              fp = np.zeros(len(ious))
              
              matched = np.zeros(len(ious))
              for i in range(len(ious)):
                  if ious[i] >= 0.5:  # IoU阈值
                      if matched[i] == 0:
                          tp[i] = 1
                          matched[i] = 1
                      else:
                          fp[i] = 1
                  else:
                      fp[i] = 1
              
              # 累积
              tp = np.cumsum(tp)
              fp = np.cumsum(fp)
              
              recall = tp / (n_gt + 1e-6)
              precision = tp / (tp + fp + 1e-6)
              
              # 计算AP(11点插值)
              ap = 0.0
              for t in np.arange(0, 1.1, 0.1):
                  if np.sum(recall >= t) == 0:
                      p = 0
                  else:
                      p = np.max(precision[recall >= t])
                  ap += p / 11
              
              return ap
          
          def evaluate(self, pred_boxes: torch.Tensor, pred_kpts: torch.Tensor,
                       pred_scores: torch.Tensor, pred_cls: torch.Tensor,
                       gt_boxes: torch.Tensor, gt_kpts: torch.Tensor,
                       gt_cls: torch.Tensor, img_area: float) -> Dict:
              """
              单次评估
              所有输入都是单张图片的预测和标注
              """
              device = pred_boxes.device
              
              # 按类别分别评估
              results = {
                  'map50': 0.0,
                  'map50_95': 0.0,
                  'oks_map50': 0.0,
                  'oks_map50_95': 0.0,
              }
              
              for c in range(self.nc):
                  # 筛选当前类别
                  pred_mask = pred_cls == c
                  gt_mask = gt_cls == c
                  
                  pred_b = pred_boxes[pred_mask]
                  pred_k = pred_kpts[pred_mask]
                  pred_s = pred_scores[pred_mask]
                  
                  gt_b = gt_boxes[gt_mask]
                  gt_k = gt_kpts[gt_mask]
                  
                  if len(gt_b) == 0:
                      continue
                  
                  # 1. 检测mAP (基于边界框IoU)
                  if len(pred_b) > 0:
                      iou_matrix = self.compute_iou(pred_b, gt_b).cpu().numpy()
                      
                      # 对每个预测,找到最佳匹配GT
                      best_ious = np.max(iou_matrix, axis=1) if iou_matrix.shape[1] > 0 else np.zeros(len(pred_b))
                      
                      # mAP@50
                      ap50 = self.compute_ap(pred_s.cpu().numpy(), best_ious)
                      results['map50'] += ap50
                      
                      # mAP@50:95
                      ap_list = []
                      for thresh in self.iou_thresholds:
                          ious_thresh = (iou_matrix >= thresh).any(axis=1).astype(float)
                          ap = self.compute_ap(pred_s.cpu().numpy(), ious_thresh)
                          ap_list.append(ap)
                      results['map50_95'] += np.mean(ap_list)
                  
                  # 2. 姿态mAP (基于OKS)
                  if len(pred_b) > 0:
                      # 计算GT面积
                      gt_areas = (gt_b[:, 2] * gt_b[:, 3]).cpu().numpy()
                      
                      pred_kpts_np = pred_k.reshape(-1, self.nk, 3).cpu().numpy()
                      gt_kpts_np = gt_k.reshape(-1, self.nk, 3).cpu().numpy()
                      
                      oks_matrix = self.compute_oks(pred_kpts_np, gt_kpts_np, gt_areas)
                      
                      # 对每个预测,找到最佳匹配GT
                      best_oks = np.max(oks_matrix, axis=1) if oks_matrix.shape[1] > 0 else np.zeros(len(pred_b))
                      
                      # OKS-mAP@50
                      oks_ap50 = self.compute_ap(pred_s.cpu().numpy(), best_oks)
                      results['oks_map50'] += oks_ap50
                      
                      # OKS-mAP@50:95
                      oks_ap_list = []
                      for thresh in self.iou_thresholds:
                          oks_thresh = (oks_matrix >= thresh).any(axis=1).astype(float)
                          ap = self.compute_ap(pred_s.cpu().numpy(), oks_thresh)
                          oks_ap_list.append(ap)
                      results['oks_map50_95'] += np.mean(oks_ap_list)
              
              # 平均到各类别
              if self.nc > 0:
                  for k in results:
                      results[k] /= self.nc
              
              return results
      class PoseEvaluator:
          """完整数据集评估器"""
          def __init__(self, model, dataloader, nc=1, nk=21, device='cuda', img_size=640):
              self.model = model
              self.dataloader = dataloader
              self.nc = nc
              self.nk = nk
              self.device = device
              self.img_size = img_size
              self.metrics = PoseMetrics(nc=nc, nk=nk)
          
          def evaluate(self, conf_thresh=0.25, iou_thresh=0.6) -> Dict:
              """
              在整个验证集上评估
              """
              self.model.eval()
              
              all_preds = []
              all_gts = []
              
              with torch.no_grad():
                  for imgs, targets in self.dataloader:
                      imgs = imgs.to(self.device)
                      targets = targets.to(self.device)
                      # 模型推理
                      output = self.model(imgs)
                      # 解析预测
                      preds = self._parse_output(output, conf_thresh, iou_thresh)
                      # 解析标注
                      gts = self._parse_targets(targets, imgs.shape[0])
                      all_preds.extend(preds)
                      all_gts.extend(gts)
              
              # 汇总评估
              return self._aggregate_metrics(all_preds, all_gts)
          
          def _parse_output(self, output: torch.Tensor, conf_thresh: float, 
                            iou_thresh: float) -> List[Dict]:
              """解析模型输出为每张图片的预测"""
              bs, _, h, w = output.shape
              stride = self.img_size / h
              # 解析输出
              output = output.permute(0, 2, 3, 1).reshape(bs, -1, 4 + self.nc + self.nk * 3)
              
              box_pred = output[..., :4]
              cls_pred = output[..., 4:4+self.nc]
              kpt_pred = output[..., 4+self.nc:]
              # 置信度 = cls_score * box_conf (简化)
              cls_scores, cls_idx = torch.max(cls_pred, dim=-1)
              results = []
              for b in range(bs):
                  # NMS简化版(按置信度排序取top)
                  scores = cls_scores[b]
                  mask = scores > conf_thresh
                  if mask.sum() == 0:
                      results.append({
                          'boxes': torch.zeros((0, 4)),
                          'kpts': torch.zeros((0, self.nk, 3)),
                          'scores': torch.zeros(0),
                          'cls': torch.zeros(0, dtype=torch.long)
                      })
                      continue
                  
                  boxes = box_pred[b, mask] * stride
                  kpts = kpt_pred[b, mask].reshape(-1, self.nk, 3) * stride
                  kpts[..., 2] = torch.sigmoid(kpts[..., 2])  # 可见性置信度
                  scores = scores[mask]
                  cls = cls_idx[b, mask]
                  
                  # 简单NMS
                  keep = self._nms(boxes, scores, iou_thresh)
                  
                  results.append({
                      'boxes': boxes[keep],
                      'kpts': kpts[keep],
                      'scores': scores[keep],
                      'cls': cls[keep]
                  })
              
              return results
          
          def _parse_targets(self, targets: torch.Tensor, batch_size: int) -> List[Dict]:
              """解析标注为每张图片的GT"""
              results = []
              for b in range(batch_size):
                  batch_targets = targets[targets[:, 0] == b]
                  
                  if len(batch_targets) == 0:
                      results.append({
                          'boxes': torch.zeros((0, 4)),
                          'kpts': torch.zeros((0, self.nk, 3)),
                          'cls': torch.zeros(0, dtype=torch.long)
                      })
                      continue
                  
                  boxes = batch_targets[:, 2:6]
                  kpts = batch_targets[:, 6:].reshape(-1, self.nk, 3)
                  cls = batch_targets[:, 1].long()
                  
                  results.append({
                      'boxes': boxes,
                      'kpts': kpts,
                      'cls': cls
                  })
              
              return results
          
          def _nms(self, boxes: torch.Tensor, scores: torch.Tensor, 
                   iou_thresh: float) -> torch.Tensor:
              """简单NMS实现"""
              if len(boxes) == 0:
                  return torch.zeros(0, dtype=torch.long)
              
              x1 = boxes[:, 0] - boxes[:, 2] / 2
              y1 = boxes[:, 1] - boxes[:, 3] / 2
              x2 = boxes[:, 0] + boxes[:, 2] / 2
              y2 = boxes[:, 1] + boxes[:, 3] / 2
              
              areas = (x2 - x1) * (y2 - y1)
              order = scores.argsort(descending=True)
              
              keep = []
              while order.numel() > 0:
                  i = order[0]
                  keep.append(i)
                  
                  if order.numel() == 1:
                      break
                  
                  order = order[1:]
                  
                  xx1 = torch.max(x1[i], x1[order])
                  yy1 = torch.max(y1[i], y1[order])
                  xx2 = torch.min(x2[i], x2[order])
                  yy2 = torch.min(y2[i], y2[order])
                  
                  w = torch.clamp(xx2 - xx1, min=0)
                  h = torch.clamp(yy2 - yy1, min=0)
                  inter = w * h
                  
                  iou = inter / (areas[i] + areas[order] - inter + 1e-6)
                  order = order[iou <= iou_thresh]
              
              return torch.tensor(keep, dtype=torch.long)
          
          def _aggregate_metrics(self, all_preds: List[Dict], 
                                all_gts: List[Dict]) -> Dict:
              """汇总所有图片的评估结果"""
              map50_list = []
              map50_95_list = []
              oks_map50_list = []
              oks_map50_95_list = []
              
              for preds, gts in zip(all_preds, all_gts):
                  if len(gts['boxes']) == 0:
                      continue
                  
                  img_area = self.img_size ** 2
                  
                  result = self.metrics.evaluate(
                      preds['boxes'], preds['kpts'], preds['scores'], preds['cls'],
                      gts['boxes'], gts['kpts'], gts['cls'], img_area
                  )
                  
                  map50_list.append(result['map50'])
                  map50_95_list.append(result['map50_95'])
                  oks_map50_list.append(result['oks_map50'])
                  oks_map50_95_list.append(result['oks_map50_95'])
              
              return {
                  'mAP@50': np.mean(map50_list) if map50_list else 0.0,
                  'mAP@50:95': np.mean(map50_95_list) if map50_95_list else 0.0,
                  'OKS-mAP@50': np.mean(oks_map50_list) if oks_map50_list else 0.0,
                  'OKS-mAP@50:95': np.mean(oks_map50_95_list) if oks_map50_95_list else 0.0,
              }
      
  • 推理代码 (utils/predictor.py)

    • import torch
      import cv2
      import numpy as np
      from typing import List, Dict, Tuple
      import os
      class YOLOPosePredictor:
          """YOLO-Pose 推理器"""
          def __init__(self, model_path: str, nc=1, nk=21, img_size=640, 
                       conf_thresh=0.25, iou_thresh=0.6, device='cuda'):
              """
              model_path: 模型权重路径
              nc: 类别数
              nk: 关键点数量
              img_size: 输入尺寸
              conf_thresh: 置信度阈值
              iou_thresh: NMS IoU阈值
              """
              self.nc = nc
              self.nk = nk
              self.img_size = img_size
              self.conf_thresh = conf_thresh
              self.iou_thresh = iou_thresh
              self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
              
              # 加载模型
              from models.yolo_pose import MiniYOLOPose
              self.model = MiniYOLOPose()
              checkpoint = torch.load(model_path, map_location=self.device)
              self.model.load_state_dict(checkpoint['model_state_dict'])
              self.model.to(self.device)
              self.model.eval()
              print(f"Model loaded from {model_path}")
              print(f"Device: {self.device}")
          def preprocess(self, img: np.ndarray) -> Tuple[torch.Tensor, Dict]:
              """
              预处理图像
              return: tensor, meta_info (用于后处理)
              """
              h, w = img.shape[:2]
              # 保持宽高比resize
              scale = min(self.img_size / h, self.img_size / w)
              new_h, new_w = int(h * scale), int(w * scale)
              # resize
              img_resized = cv2.resize(img, (new_w, new_h))
              # padding到目标尺寸
              pad_h = (self.img_size - new_h) // 2
              pad_w = (self.img_size - new_w) // 2
              img_padded = cv2.copyMakeBorder(
                  img_resized, pad_h, self.img_size - new_h - pad_h,
                  pad_w, self.img_size - new_w - pad_w,
                  cv2.BORDER_CONSTANT, value=114
              )
              # 归一化
              img_norm = img_padded.astype(np.float32) / 255.0
              img_tensor = torch.from_numpy(img_norm).permute(2, 0, 1).unsqueeze(0)
              
              meta = {
                  'original_shape': (h, w),
                  'resize_shape': (new_h, new_w),
                  'scale': scale,
                  'pad': (pad_h, pad_w)
              }
              return img_tensor.to(self.device), meta
          def postprocess(self, output: torch.Tensor, meta: Dict) -> List[Dict]:
              """
              后处理:解析输出 + NMS + 坐标还原
              """
              bs, _, h, w = output.shape
              stride = self.img_size / h
              # 解析输出
              output = output.permute(0, 2, 3, 1).reshape(bs, -1, 4 + self.nc + self.nk * 3)
              box_pred = output[..., :4]
              cls_pred = output[..., 4:4+self.nc]
              kpt_pred = output[..., 4+self.nc:]
              # 置信度
              cls_scores, cls_idx = torch.max(cls_pred, dim=-1)
              results = []
              for b in range(bs):
                  scores = cls_scores[b]
                  mask = scores > self.conf_thresh
                  if mask.sum() == 0:
                      results.append([])
                      continue
                  boxes = box_pred[b, mask] * stride
                  kpts = kpt_pred[b, mask].reshape(-1, self.nk, 3) * stride
                  kpts[..., 2] = torch.sigmoid(kpts[..., 2])
                  scores = scores[mask]
                  cls = cls_idx[b, mask]
                  # NMS
                  keep = self._nms(boxes, scores, self.iou_thresh)
                  # 坐标还原到原图
                  boxes = self._restore_coords(boxes, meta)
                  kpts = self._restore_kpts(kpts, meta)
                  # 转换为numpy
                  det_results = []
                  for i in range(len(keep)):
                      idx = keep[i]
                      det_results.append({
                          'box': boxes[idx].cpu().numpy(),
                          'kpts': kpts[idx].cpu().numpy(),
                          'score': scores[idx].cpu().item(),
                          'cls': cls[idx].cpu().item()
                      })
                  
                  results.append(det_results)
              
              return results
          
          def _nms(self, boxes: torch.Tensor, scores: torch.Tensor, 
                   iou_thresh: float) -> torch.Tensor:
              """NMS"""
              if len(boxes) == 0:
                  return torch.zeros(0, dtype=torch.long)
              
              x1 = boxes[:, 0] - boxes[:, 2] / 2
              y1 = boxes[:, 1] - boxes[:, 3] / 2
              x2 = boxes[:, 0] + boxes[:, 2] / 2
              y2 = boxes[:, 1] + boxes[:, 3] / 2
              
              areas = (x2 - x1) * (y2 - y1)
              order = scores.argsort(descending=True)
              
              keep = []
              while order.numel() > 0:
                  i = order[0]
                  keep.append(i)
                  
                  if order.numel() == 1:
                      break
                  
                  order = order[1:]
                  
                  xx1 = torch.max(x1[i], x1[order])
                  yy1 = torch.max(y1[i], y1[order])
                  xx2 = torch.min(x2[i], x2[order])
                  yy2 = torch.min(y2[i], y2[order])
                  
                  w = torch.clamp(xx2 - xx1, min=0)
                  h = torch.clamp(yy2 - yy1, min=0)
                  inter = w * h
                  
                  iou = inter / (areas[i] + areas[order] - inter + 1e-6)
                  order = order[iou <= iou_thresh]
              
              return torch.tensor(keep, dtype=torch.long)
          
          def _restore_coords(self, boxes: torch.Tensor, meta: Dict) -> torch.Tensor:
              """还原边界框坐标到原图"""
              scale = meta['scale']
              pad_h, pad_w = meta['pad']
              
              # 去掉padding后,中心点和宽高都要除以缩放比例
              boxes[:, 0] = (boxes[:, 0] - pad_w) / scale  # cx
              boxes[:, 1] = (boxes[:, 1] - pad_h) / scale  # cy
              boxes[:, 2] /= scale  # w
              boxes[:, 3] /= scale  # h
              
              return boxes
          
          def _restore_kpts(self, kpts: torch.Tensor, meta: Dict) -> torch.Tensor:
              """还原关键点坐标到原图"""
              scale = meta['scale']
              pad_h, pad_w = meta['pad']
              
              kpts[..., 0] -= pad_w  # x
              kpts[..., 1] -= pad_h  # y
              kpts[..., :2] /= scale
              
              return kpts
          
          def predict(self, img_path: str) -> List[Dict]:
              """
              单张图片推理
              return: 检测结果列表
              """
              # 读取图片
              img = cv2.imread(img_path)
              img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
              
              # 预处理
              img_tensor, meta = self.preprocess(img)
              
              # 推理
              with torch.no_grad():
                  output = self.model(img_tensor)
              
              # 后处理
              results = self.postprocess(output, meta)
              
              return results[0]  # 单张图片
          
          def predict_batch(self, img_paths: List[str]) -> List[List[Dict]]:
              """
              批量推理
              return: 每张图片的检测结果
              """
              all_results = []
              
              for img_path in img_paths:
                  results = self.predict(img_path)
                  all_results.append(results)
              
              return all_results
          
          def predict_video(self, video_path: str, output_path: str, 
                           show=True, save=True):
              """
              视频推理
              """
              from utils.visualize import draw_pose
              
              cap = cv2.VideoCapture(video_path)
              fps = cap.get(cv2.CAP_PROP_FPS)
              w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
              h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
              
              if save:
                  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                  out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
              
              frame_idx = 0
              while True:
                  ret, frame = cap.read()
                  if not ret:
                      break
                  # BGR to RGB
                  img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                  # 预处理
                  img_tensor, meta = self.preprocess(img)
                  # 推理
                  with torch.no_grad():
                      output = self.model(img_tensor)
                  # 后处理
                  results = self.postprocess(output, meta)[0]
                  # 可视化
                  vis_img = draw_pose(frame, results, nk=self.nk)
                  if show:
                      cv2.imshow('YOLO-Pose', vis_img)
                      if cv2.waitKey(1) & 0xFF == ord('q'):
                          break
                  if save:
                      out.write(vis_img)
                  frame_idx += 1
                  print(f"\rProcessing frame {frame_idx}", end='')
              cap.release()
              if save:
                  out.release()
              cv2.destroyAllWindows()
              print("\nVideo processing completed!")
      
  • 针对手部 “小目标、多视角” 特点,增加随机裁剪(聚焦手部)、旋转(-30°~30°,模拟不同手势角度)、缩放(0.5x~1.5x,适配不同距离)、遮挡增强(随机遮挡指节,模拟握拳 / 交叉手);避免过度增强(如颜色抖动过大),防止丢失手部纹理特征。
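  • 下面给出一个增强思路的最小示意(假设使用 albumentations 做 bbox / 关键点同步变换;occlude_random_joint 为说明思路而虚构的辅助函数,尺寸与概率均为示例值),仅是草图,并非 Ultralytics 内置实现:

    • # 旋转(-30°~30°)与缩放(0.5x~1.5x):bbox 与 21 个关键点随图像同步变换
      import random
      import numpy as np
      import albumentations as A

      transform = A.Compose(
          [A.Affine(rotate=(-30, 30), scale=(0.5, 1.5), p=0.8)],
          bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]),
          keypoint_params=A.KeypointParams(format="xy", remove_invisible=False),
      )
      # 用法:out = transform(image=img, bboxes=[[xc, yc, w, h]], class_labels=[0], keypoints=kpts_xy)

      def occlude_random_joint(img: np.ndarray, kpts_xy, size: int = 20) -> np.ndarray:
          """随机遮挡某个指节附近的小方块,模拟握拳/手指交叉造成的遮挡(示意)"""
          x, y = random.choice(kpts_xy)
          x, y = int(x), int(y)
          h, w = img.shape[:2]
          x1, y1 = max(0, x - size // 2), max(0, y - size // 2)
          x2, y2 = min(w, x + size // 2), min(h, y + size // 2)
          img[y1:y2, x1:x2] = 114  # 用灰色填充遮挡块
          return img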

  • 静态手势识别(如点赞、比心):仅需 YOLO-Pose 输出 21 个关键点,再通过 SVM/MLP 分类关键点坐标;动态手势识别(如挥手、握拳→展开):需结合时序模型(如 LSTM、Transformer),将连续帧的关键点坐标作为输入,捕捉动作时序特征。
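  • 下面是手势分类头的最小示意(作为 YOLO-Pose 之外的下游模型;类别数、隐藏维度均为示例值,并假设输入关键点已做归一化):

    • # 静态手势用 MLP、动态手势用 LSTM 的最小示意
      import torch
      import torch.nn as nn

      class StaticGestureMLP(nn.Module):
          """静态手势分类:输入单帧 21×3 关键点,输出手势类别(如点赞、比心)"""
          def __init__(self, num_classes: int = 5):
              super().__init__()
              self.net = nn.Sequential(
                  nn.Linear(21 * 3, 128), nn.ReLU(),
                  nn.Linear(128, 64), nn.ReLU(),
                  nn.Linear(64, num_classes),
              )

          def forward(self, kpts: torch.Tensor) -> torch.Tensor:
              # kpts: (batch, 21, 3) -> (batch, 63)
              return self.net(kpts.flatten(1))

      class DynamicGestureLSTM(nn.Module):
          """动态手势分类:输入连续 T 帧关键点序列,捕捉时序特征(如挥手、握拳→展开)"""
          def __init__(self, num_classes: int = 5, hidden: int = 128):
              super().__init__()
              self.lstm = nn.LSTM(input_size=21 * 3, hidden_size=hidden, batch_first=True)
              self.fc = nn.Linear(hidden, num_classes)

          def forward(self, kpt_seq: torch.Tensor) -> torch.Tensor:
              # kpt_seq: (batch, T, 21, 3) -> (batch, T, 63)
              out, _ = self.lstm(kpt_seq.flatten(2))
              return self.fc(out[:, -1])  # 取最后一个时间步的隐状态做分类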

  • 适配手势场景的改动集中在两点:① 定义 kpt_shape=[21,3];② 适配 Pose 头的关键点输出维度。围绕高层 API(model.train/model.val/model.predict/model.export)的全链路调用流程如下(原为流程图,此处整理为文字版):

    1. 配置准备:数据集 yaml(hand-keypoints.yaml:nc=1、kpt_shape=[21,3])+ 模型 yaml(yolo11n-pose.yaml:Pose 层参数 [1, [21,3]]);
    2. 模型初始化:YOLO 高层 API 加载预训练权重并绑定 kpt_shape 配置,PoseModel.__init__ 构建 Backbone+Neck+Pose Head;
    3. 数据加载:build_dataset() 校验 kpt_shape 合法性,YOLODataset 解析图像与标注,输出图像 batch×3×640×640、标注 batch×n×67(67=4+21×3);
    4. 数据增强:mosaic/旋转/遮挡,关键点随图像同步变换;
    5. 模型训练:PoseTrainer.__init__ 初始化优化器/EMA/损失函数;前向传播输出 batch×67×20×20;PoseLoss.forward() 计算 box_loss+pose_loss+kobj_loss,反向传播更新权重;
    6. 验证环节:PoseValidator.val() 计算 mAP_pose50-95;
    7. 推理/导出:PosePredictor.predict() 解析输出并还原关键点坐标,输出 bbox+关键点;Exporter.export_onnx/trt 导出部署格式。

    • 明确kpt_shape=[21,3]从「配置文件」→「模型初始化」→「数据加载」→「训练损失计算」的全链路传递;标注核心输入输出维度(如batch×67×20×20,67=4+21×3),贴合手势场景;高层 API(如model.train())与底层核心接口(如PoseTrainer)的调用关系一目了然。
  • Ultralytics 框架的核心代码位于 ultralytics/ 目录下,YOLO-Pose 的定制化逻辑集中在 “任务特定类”(如 PoseTrainer/PoseValidator)和 “配置文件” 中。

核心环节源码剖析(从入口到落地)

  • 入口:统一 API 与配置解析。核心文件:

    • ultralytics/models/model.py:YOLO 类(所有任务的统一入口)。

    • ultralytics/cfg/__init__.py:配置解析函数(get_cfg、merge_cfg)。

    • ultralytics/cfg/default.yaml:全局默认超参。

    • 接口逻辑(以手势训练为例)

      • # 二次开发入口:Python API
        from ultralytics import YOLO
        # 1. 初始化YOLO类:加载模型配置+权重
        # 若自定义手势模型,可传入修改后的yolo11-pose.yaml(如kpt_shape改为[21,3])
        model = YOLO("yolo11n-pose.pt")  # 或 "/custom-yolo11-pose.yaml"
        # 2. 训练:传入超参(合并default.yaml+命令行/代码参数)
        model.train(
            data="hand-keypoints.yaml",  # 自定义数据集配置
            epochs=100,
            imgsz=960,  # 手势小目标适配
            device="0",
            batch=16
        )
        
      • YOLO.__init__:调用 _load 方法,若传入 .pt 权重则加载模型 + 配置;若传入 .yaml 则仅初始化模型结构。YOLO.train:调用 self._smart_load("trainer") 动态加载 PoseTrainer(根据模型任务类型自动选择),并传入合并后的配置。配置合并优先级:代码传入参数 > 数据集yaml > 模型yaml > default.yaml

  • 数据加载与增强:关键点标注解析与增强。核心文件:

    • ultralytics/data/dataset.py:YOLODataset 类(继承 torch.utils.data.Dataset),负责加载图像 + 解析关键点标注。

    • ultralytics/data/augmentations.py:Albumentations、Mosaic、MixUp 等增强类,支持关键点同步变换。

    • ultralytics/data/utils.py:verify_image_label 函数(验证标注合法性)。

    • 接口逻辑(手势数据加载流程)

      • 数据集配置解析:读取 hand-keypoints.yaml,获取 path(数据根目录)、train/val(图像路径)、kpt_shape(关键点数量,如 [21, 3])。
      • 标注解析:YOLODataset.get_labels() 读取 .txt 标注文件,格式为 class_id x_center y_center w h kpt1_x kpt1_y kpt1_v ... kpt21_x kpt21_y kpt21_v(v=0 未标注、v=1 存在但被遮挡、v=2 可见)。
      • 数据增强:YOLODataset.build_transforms() 构建增强流水线,关键点随图像同步变换(如旋转时,调用 augmentations.keypoints_format 调整坐标)。
    • 若标注格式非 YOLO 标准,可重写 YOLODataset.get_labels() 中的解析逻辑。在 augmentations.py 中添加新增强类(如针对手部的 “指节遮挡增强”),并在 YOLODataset.build_transforms() 中注册。
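    • 若原始标注是像素坐标(例如自动标注工具导出的 JSON),可以先按上文的 txt 格式做一次转换,再交给 YOLODataset 解析。下面是一个最小转换示意(函数名与入参组织为假设,请按实际标注结构调整):

      • # 把「像素坐标的手部标注」写成 YOLO-Pose 的一行 txt 标注
        def to_yolo_pose_line(img_w: int, img_h: int, bbox_xyxy, keypoints, class_id: int = 0) -> str:
            """bbox_xyxy: (x1, y1, x2, y2) 像素坐标;keypoints: 21 个 (x, y, v) 元组"""
            x1, y1, x2, y2 = bbox_xyxy
            xc = (x1 + x2) / 2 / img_w
            yc = (y1 + y2) / 2 / img_h
            w = (x2 - x1) / img_w
            h = (y2 - y1) / img_h
            parts = [str(class_id), f"{xc:.6f}", f"{yc:.6f}", f"{w:.6f}", f"{h:.6f}"]
            for kx, ky, v in keypoints:  # v: 0=未标注, 1=存在但被遮挡, 2=可见
                parts += [f"{kx / img_w:.6f}", f"{ky / img_h:.6f}", str(int(v))]
            return " ".join(parts)

        # 示例:单手、3 个关键点(真实数据应为 21 个)
        line = to_yolo_pose_line(640, 480, (100, 120, 260, 300),
                                 [(150, 280, 2), (180, 200, 2), (200, 150, 1)])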

  • 模型构建:PoseModel 网络结构。核心文件:

    • ultralytics/models/tasks.py:BaseModel(基类)、PoseModel(继承 BaseModel,带关键点检测头)。

    • ultralytics/nn/modules.py:C3k2(YOLO11 骨干模块)、Conv、SPPF 等基础组件。

    • ultralytics/cfg/models/yolo11-pose.yaml:模型结构配置。

    • # 模型配置
      nc: 1  # 类别数(手势检测通常仅1类:hand)
      kpt_shape: [21, 3]  # 关键点形状:[数量, 维度(x/y/可见性)],手势改为21
      scales: # 模型缩放系数(n/s/m/l/x)
        n: [0.25, 0.25, 1024]  # [depth_multiple, width_multiple, max_channels]
      
      # 网络结构(backbone + neck + head)
      backbone:
        [[-1, 1, Conv, [64, 3, 2]],  # 0-P1/2
         [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
         [-1, 2, C3k2, [256, True]], # 2
         [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
         [-1, 2, C3k2, [512, True]], # 4
         [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
         [-1, 2, C3k2, [512, True]], # 6
         [-1, 1, Conv, [1024, 3, 2]],# 7-P5/32
         [-1, 2, C3k2, [1024, True]],# 8
         [-1, 1, SPPF, [1024, 5]],   # 9
        ]
      head:
        [[-1, 1, Conv, [512, 1, 1]],
         [-1, 1, nn.Upsample, [None, 2, 'nearest']],
         [[-1, 6], 1, Concat, [1]],  # 拼接backbone第6层
         [-1, 2, C3k2, [512, False]],
         [-1, 1, Conv, [256, 1, 1]],
         [-1, 1, nn.Upsample, [None, 2, 'nearest']],
         [[-1, 4], 1, Concat, [1]],  # 拼接backbone第4层
         [-1, 2, C3k2, [256, False]], # 15 (P3/8-small)
         [-1, 1, Conv, [256, 3, 2]],
         [[-1, 12], 1, Concat, [1]], # 拼接head第12层
         [-1, 2, C3k2, [512, False]], # 18 (P4/16-medium)
         [-1, 1, Conv, [512, 3, 2]],
         [[-1, 9], 1, Concat, [1]],  # 拼接backbone第9层
         [-1, 2, C3k2, [1024, False]],# 21 (P5/32-large)
         [[15, 18, 21], 1, Pose, [nc, kpt_shape]],  # Pose检测头:输入3个尺度特征
        ]
      
    • PoseModel.__init__:解析 yolo11-pose.yaml,调用 parse_model 函数构建网络(根据配置列表动态拼接模块)。Pose 检测头(ultralytics/nn/modules.py)输出两组分支:检测分支给出边界框(reg)与类别(cls);关键点分支给出关键点坐标(kpt,形状 [batch, num_anchors, 21×3])与关键点可见性(kobj)。
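    • 为直观理解输出维度,可以用随机张量演示按本文「67=4+21×3」的简化约定拆分 bbox 通道与关键点通道(此处省略类别分支与坐标解码;8400 为 640 输入下三个尺度 anchor 数量之和,仅作示例):

      • # 输出维度拆分示意(非真实源码,仅演示通道含义)
        import torch

        batch, num_anchors = 2, 8400                  # 8400 = 80*80 + 40*40 + 20*20
        preds = torch.randn(batch, 67, num_anchors)   # 67 = 4(bbox) + 21*3(关键点)

        box_pred = preds[:, :4, :]                                      # (batch, 4, num_anchors)
        kpt_pred = preds[:, 4:, :].reshape(batch, 21, 3, num_anchors)   # (batch, 21, 3, num_anchors)
        print(box_pred.shape, kpt_pred.shape)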

  • 模型训练:PoseTrainer 训练循环

    • ultralytics/engine/trainer.py:BaseTrainer(基类,通用训练逻辑)、PoseTrainer(继承 BaseTrainer,关键点特定逻辑)。

    • ultralytics/utils/torch_utils.py:ModelEMA(指数移动平均)、init_seeds(随机种子)。

    • 训练器动态加载:YOLO._smart_load("trainer") 根据模型任务类型(task=pose)返回 PoseTrainer;若需自定义训练逻辑(如添加新的训练 trick),可继承 PoseTrainer 并重写 _train_epoch。损失计算:调用 PoseTrainer._get_loss() 返回 PoseLoss 实例,在 _train_epoch 中:

    • # 前向传播
      preds = self.model(batch["img"])
      # 计算损失:loss是总损失,loss_items是各分项损失(box/cls/pose/kobj)
      loss, loss_items = self.compute_loss(preds, batch)
      # 反向传播
      loss.backward()
      self.optimizer.step()
      self.ema.update(self.model)  # EMA更新
      
    • 训练是 YOLO-Pose 最复杂的环节,单独拆解该子流程,展示「数据→前向→损失→反向传播」的接口调用细节。

    • 训练子流程(原为 Mermaid 流程图,整理为文字版):

      1. 训练批次数据:图像 batch×3×640×640,标注 batch×n×67;
      2. PoseTrainer._train_epoch():单轮训练循环入口,调用模型前向预测;
      3. 输出解析:检测分支(bbox+cls)、关键点分支(21×3 坐标 + 可见性);
      4. PoseLoss.forward():依次计算 box_loss/cls_loss、pose_loss(21 个关键点坐标)、kobj_loss(可见性);总损失 = box_loss + pose_loss×12 + kobj_loss×1(权重在 default.yaml 中配置);
      5. loss.backward() 计算梯度,optimizer.step() 更新权重,EMA.update() 更新 EMA 模型;
      6. 每个 epoch(或按间隔)调用 PoseValidator 验证,并保存 best.pt/last.pt;
      7. 未达到 epochs 则回到第 1 步继续循环;达到后输出训练日志与损失曲线。
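    • 若不想改动源码,也可以借助 Ultralytics 的回调机制在训练循环中挂载自定义逻辑;下面是一个最小示意(回调事件名与 trainer 上可用的属性以实际版本文档为准):

      • # 通过回调在每个 epoch 结束时读取训练状态(示意)
        from ultralytics import YOLO

        def on_train_epoch_end(trainer):
            # trainer 即上文的 PoseTrainer 实例,可从中读取当前 epoch、分项损失等信息
            print(f"epoch {trainer.epoch} done, loss items: {trainer.loss_items}")

        model = YOLO("yolo11n-pose.pt")
        model.add_callback("on_train_epoch_end", on_train_epoch_end)
        model.train(data="hand-keypoints.yaml", epochs=100, imgsz=640)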
      class:output

  • 损失计算:PoseLoss 关键点损失

    • ultralytics/nn/loss.py:PoseLoss 类(继承 DetectionLoss,添加关键点损失)。

    • # 源码位置:ultralytics/nn/loss.py → PoseLoss(下为简化示意,省略了锚点分配等细节)
      import torch
      import torch.nn.functional as F

      class PoseLoss(DetectionLoss):
          def __init__(self, *args, **kwargs):
              super().__init__(*args, **kwargs)
              self.kpt_shape = self.model.model[-1].kpt_shape  # 获取[21,3]

          def forward(self, preds, batch):
              # 1. 先计算检测损失(box/cls/dfl,继承自DetectionLoss)
              loss, loss_items = super().forward(preds, batch)

              # 2. 计算关键点损失
              feats = preds[1] if isinstance(preds, tuple) else preds  # 提取特征
              nk, ndim = self.kpt_shape  # 21, 3
              kpt_pred = feats.view(*feats.shape[:-1], nk, ndim)  # 关键点预测 [..., 21, 3]
              kpt_target = batch["keypoints"]  # 关键点真值 [..., 21, 3]

              # 关键点坐标损失(Smooth L1,仅计算可见关键点 v>0)
              mask = kpt_target[..., 2] > 0  # 可见性掩码
              pose_loss = F.smooth_l1_loss(kpt_pred[..., :2][mask], kpt_target[..., :2][mask], beta=1.0)

              # 关键点可见性损失(BCE,若需预测可见性)
              kobj_loss = F.binary_cross_entropy_with_logits(...)  # 可选

              # 3. 合并损失
              loss += pose_loss * self.hyp["pose"]  # pose是损失权重,在default.yaml中配置
              loss_items = torch.cat((loss_items, pose_loss.detach().unsqueeze(0)))
              return loss, loss_items
      
    • 在 default.yaml 中可修改 pose: 12.0(关键点损失权重);若手势指尖易出错,可在 PoseLoss 中对指尖关键点单独加权;若需添加「关键点拓扑约束损失」(如手指长度比例,见下方示意),可在 PoseLoss.forward 中新增损失项。
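    • 下面是「手指骨长约束」这类拓扑损失的最小示意(并非 Ultralytics 自带的损失项;骨骼索引遵循 21 点手部拓扑,这里只列出拇指与食指两指作演示):

      • # 骨长约束损失(示意):约束预测骨骼长度接近标注骨骼长度
        import torch
        import torch.nn.functional as F

        HAND_BONES = [(0, 1), (1, 2), (2, 3), (3, 4),   # thumb
                      (0, 5), (5, 6), (6, 7), (7, 8)]   # index(完整拓扑可扩展到 20 条骨骼)

        def bone_length_loss(kpt_pred: torch.Tensor, kpt_target: torch.Tensor) -> torch.Tensor:
            """kpt_pred / kpt_target: (N, 21, 2) 归一化坐标"""
            i = torch.tensor([b[0] for b in HAND_BONES])
            j = torch.tensor([b[1] for b in HAND_BONES])
            pred_len = (kpt_pred[:, i] - kpt_pred[:, j]).norm(dim=-1)   # (N, num_bones)
            gt_len = (kpt_target[:, i] - kpt_target[:, j]).norm(dim=-1)
            return F.smooth_l1_loss(pred_len, gt_len)

        # 用法(示意):在 PoseLoss.forward 中 loss += bone_length_loss(kpt_pred_xy, kpt_gt_xy) * 权重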

  • 模型验证与评估:PoseValidator 与 mAP_pose

    • ultralytics/engine/validator.py:BaseValidator(基类)、PoseValidator(关键点验证逻辑)。
    • ultralytics/utils/metrics.py:PoseMetrics 类(计算 mAP_pose50-95)。
    • PoseValidator.__call__:遍历验证集,对每个 batch 执行前向传播→NMS 后处理→匹配真值与预测→调用 PoseMetrics 累积指标。
    • 核心指标:mAP_pose50-95(关键点检测平均精度)、mAP_pose50(匹配阈值为 0.5 时的精度;姿态任务中该阈值基于 OKS 计算,而非检测框 IoU,见下方示意)。
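    • OKS(Object Keypoint Similarity)的计算思路可以用几行代码说明(简化示意;sigmas 为各关键点的归一化常数,COCO 只提供了人体 17 点的官方值,手部 21 点需自行设定,此处用统一示例值):

      • # 单个目标的 OKS 计算(示意)
        import numpy as np

        def oks(pred_xy, gt_xy, gt_vis, area, sigmas=None):
            """pred_xy/gt_xy: (21, 2) 像素坐标;gt_vis: (21,) 可见性;area: 目标框面积"""
            k = len(gt_xy)
            sigmas = np.full(k, 0.05) if sigmas is None else np.asarray(sigmas)
            d2 = np.sum((np.asarray(pred_xy) - np.asarray(gt_xy)) ** 2, axis=-1)  # 逐点平方距离
            e = d2 / (2 * (2 * sigmas) ** 2 * (area + np.spacing(1)))             # 归一化误差
            vis = np.asarray(gt_vis) > 0
            return float(np.sum(np.exp(-e[vis])) / max(vis.sum(), 1))             # 仅统计可见点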
  • 模型保存与导出:Exporter

    • ultralytics/engine/exporter.py:Exporter 类(支持导出 ONNX/TensorRT/OpenVINO 等)。
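    • 导出的最小用法示意(权重路径为训练默认输出的示例路径;TensorRT 导出需目标设备已安装对应环境):

      • # 将训练好的手势模型导出为 ONNX / TensorRT
        from ultralytics import YOLO

        model = YOLO("runs/pose/train/weights/best.pt")   # 示例路径
        model.export(format="onnx", imgsz=640, opset=12)  # 生成同名 .onnx 文件
        # model.export(format="engine", half=True)        # Jetson 等嵌入式设备可导出 TensorRT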
  • 模型推理:PosePredictor

    • ultralytics/engine/predictor.py:BasePredictor(基类)、PosePredictor(关键点推理后处理)。

    • 推理子流程(原为 Mermaid 流程图,整理为文字版):

      1. 推理源:图像 / 视频 / 摄像头,输入为 H×W×3;
      2. 预处理:resize 到 imgsz(640)、归一化到 [0,1]、转 tensor(1×3×640×640);
      3. model.predict() 调用 PosePredictor,前向传播输出 1×67×20×20;
      4. NMS 过滤重复 bbox;
      5. 关键点坐标还原:从网格坐标映射回原图像素坐标;
      6. 输出 Results 对象:boxes(n×4)、keypoints(n×21×3);
      7. 可视化(绘制 bbox + 21 个关键点)或结果导出(JSON/CSV/标注图片)。

    • 推理是「输入→预处理→推理→后处理→输出」的线性流程;核心后处理步骤是「关键点坐标还原」(从归一化的网格坐标转回原图像素坐标)。
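    • 对应的高层 API 用法示意(图片路径为示例;高层接口已完成 NMS 与坐标还原,拿到的即原图尺度结果):

      • # 读取推理结果中的 bbox 与 21 个关键点
        from ultralytics import YOLO

        model = YOLO("yolo11n-pose.pt")
        results = model.predict(source="hand.jpg", imgsz=640, conf=0.25)

        r = results[0]
        print(r.boxes.xyxy.shape)      # (n, 4):每只手的边界框(原图像素坐标)
        print(r.keypoints.data.shape)  # (n, 21, 3):每只手的 21 个关键点(x, y, 置信度)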

  • 二次开发关键路径总结

    • | 二次开发需求 | 核心修改文件 / 类 | 关键逻辑点 |
      | --- | --- | --- |
      | 自定义手势数据集 | data/dataset.py(YOLODataset) | 重写 get_labels() 解析标注 |
      | 自定义数据增强 | data/augmentations.py | 添加新增强类,在 build_transforms 注册 |
      | 修改关键点数量(21→自定义) | cfg/models/yolo11-pose.yaml | 修改 kpt_shape |
      | 调整损失函数 | nn/loss.py(PoseLoss) | 重写 forward() 添加损失项 |
      | 自定义训练逻辑 | engine/trainer.py(继承 PoseTrainer) | 重写 _train_epoch() |
      | 自定义推理后处理 | engine/predictor.py(继承 PosePredictor) | 重写 postprocess() |
  • 针对手势 21 个关键点(kpt_shape=[21,3],3 代表 x/y/可见性),需先明确核心配置约定:数据集 yaml(如 hand-keypoints.yaml)指定 nc: 1(仅 “手部” 1 类)、kpt_shape: [21, 3]、数据集路径、类别名(names: ['hand']);模型 yaml(如 yolo11n-pose.yaml)的输出层 Pose 模块参数为 [nc, kpt_shape],即 [1, [21,3]]。核心流程整理如下(原为 Mermaid 流程图):

    • 核心流程(原为 Mermaid 流程图,整理为文字版):

      1. 配置阶段:编写 hand-keypoints.yaml(指定 kpt_shape=[21,3]),修改 yolo11n-pose.yaml(Pose 层参数适配 [1, [21,3]]);
      2. 数据加载:model.train() 底层调用 build_dataset,PoseTrainer.get_dataset 校验 kpt_shape 是否存在;数据增强(mosaic/旋转/缩放)后输出 tensor 化数据;
      3. 模型构建:YOLO("yolo11n-pose.pt") 加载预训练权重,PoseModel(nc=1, kpt_shape=[21,3]) 初始化网络,PoseTrainer.get_model 绑定 kpt_shape;
      4. 训练阶段:PoseTrainer.train 前向传播输出关键点预测,损失计算 box_loss+pose_loss+kobj_loss;
      5. 评估阶段:PoseValidator.val 输出 mAP_pose 等评估指标;
      6. 推理阶段:PosePredictor.predict 输出 bbox+21 个关键点。

  • 配置阶段

    • | 配置项 | 内容示例 | 作用 |
      | --- | --- | --- |
      | 数据集 yaml(hand-keypoints.yaml) | nc: 1、kpt_shape: [21, 3]、train: ./train/images、names: ['hand'] | 告诉模型:检测 1 类目标(手)、关键点数量 21、每个关键点含 x/y/可见性 3 个维度 |
      | 模型 yaml(yolo11n-pose.yaml) | 末尾层:[[16, 19, 22], 1, Pose, [1, [21, 3]]] | 定义 Pose 输出层适配 21 个关键点,nc=1(手部类别) |
  • 数据加载(接口:YOLO.train()/build_dataset

    • | 调用方式 | 输入类型 / 维度 | 输出类型 / 维度 | 核心逻辑 |
      | --- | --- | --- | --- |
      | 高层接口 | 图像文件:H×W×3(RGB);标注 txt:每行 class x_center y_center w h k1x k1y k1v ... k21x k21y k21v | 训练数据集对象:图像 tensor batch×3×imgsz×imgsz(imgsz 默认 640);标注 tensor batch×n_boxes×(4+21×3)(4=bbox 坐标,21×3=关键点) | 1. 读取图像并 resize 到 imgsz;2. 解析标注 txt,将坐标归一化到 [0,1];3. 数据增强(马赛克/旋转/遮挡);4. 打包成批次 tensor |
      | 底层接口(build_dataset) | data=hand-keypoints.yaml, mode='train' | 同高层输出 | 校验数据集是否包含 kpt_shape,否则抛出 KeyError |
  • 模型构建(接口:YOLO()/PoseModel()

    • | 调用方式 | 输入类型 / 维度 | 输出类型 / 维度 | 核心逻辑 |
      | --- | --- | --- | --- |
      | 高层接口 | YOLO("yolo11n-pose.pt")(预训练权重)或 YOLO("yolo11n-pose.yaml")(空模型) | PoseModel 实例:输入层 batch×3×imgsz×imgsz;输出层 batch×(1×4+21×3)×grid_h×grid_w(grid=imgsz/32,如 640→20) | 1. 加载预训练权重(复用人体关键点特征);2. 调用 PoseModel 初始化,绑定 kpt_shape=[21,3];3. 输出层维度适配:1 类 bbox(4 维)+ 21 个关键点(每点 3 维) |
      | 底层接口(PoseModel) | cfg=yolo11n-pose.yaml, nc=1, kpt_shape=[21,3] | 同高层输出 | 构建 Backbone(C2PSA)+ Neck(C3k2)+ Head(Pose 层),Pose 层输出关键点预测 |
  • 训练(接口:model.train()/PoseTrainer.train()

    • | 调用方式 | 输入类型 / 维度 | 输出类型 / 维度 | 核心逻辑 |
      | --- | --- | --- | --- |
      | 高层接口 | model.train(data='hand-keypoints.yaml', epochs=100, imgsz=640) | 训练结果对象:各分项损失(box_loss/pose_loss/kobj_loss)为标量;模型权重文件 best.pt/last.pt | 1. 调用 PoseTrainer.get_dataset 加载数据;2. 调用 PoseTrainer.get_model 构建模型;3. 前向传播:模型预测 bbox+关键点;4. 损失计算:box_loss(bbox 回归)、pose_loss(关键点坐标回归)、kobj_loss(关键点可见性);5. 反向传播更新权重 |
      | 底层接口(PoseTrainer) | 批次数据:图像 batch×3×640×640、标注 batch×n_boxes×(4+63) | 损失 tensor:标量 | 重写 DetectionTrainer,新增 pose_loss/kobj_loss 计算逻辑 |
  • 评估(接口:model.val()/PoseValidator.val()

    • | 调用方式 | 输入类型 / 维度 | 输出类型 / 维度 | 核心逻辑 |
      | --- | --- | --- | --- |
      | 高层接口 | model.val(data='hand-keypoints.yaml') | 评估指标字典:mAP_pose50-95(0~1)、mAP_pose50(0~1)、速度(ms/帧) | 1. 加载验证集数据;2. 对验证集做模型推理;3. 计算关键点检测精度(mAP):对比预测关键点与标注关键点的距离误差 |
      | 底层接口(PoseValidator) | 验证集数据 + 模型实例 | 同高层输出 | 重写 DetectionValidator,适配关键点 mAP 计算逻辑 |
  • 推理(接口:model.predict()/PosePredictor.predict()

    • | 调用方式 | 输入类型 / 维度 | 输出类型 / 维度 | 核心逻辑 |
      | --- | --- | --- | --- |
      | 高层接口 | 图像 H×W×3(RGB)/ 视频 / 摄像头流;model.predict(source='img.jpg', imgsz=640) | Results 对象:results[0].boxes.xyxy 为 n×4(n=检测到的手数,4=bbox 坐标);results[0].keypoints.data 为 n×21×3(21 个关键点,第 3 维为置信度,可据此判断可见性) | 1. 图像 resize 到 imgsz;2. 模型前向预测;3. NMS 过滤重复 bbox;4. 关键点坐标还原到原图尺寸 |
      | 底层接口(PosePredictor) | 图像 tensor:1×3×640×640 | 预测 tensor:1×(4+63)×20×20,解码后为 n×4(bbox)+ n×21×3(关键点) | 解析模型输出层 tensor,还原关键点坐标到原图尺度 |
  • 数据流转核心维度链

    • 原始输入(图像:H×W×3 + 标注:1×(4+63))
          ↓(resize+归一化)
      训练输入(batch×3×640×640 + batch×n_boxes×67)
          ↓(模型前向)
      模型输出(batch×67×20×20)→ 67=1×4(bbox)+21×3(关键点)、20=640/32
          ↓(解码+NMS)
      预测结果(n×4(bbox) +21×3(关键点))
          ↓(评估)
      mAP指标(标量)
      
    • 标注 txt 中的关键点坐标是「原图像素坐标」,加载时会归一化到 [0,1](除以图像宽高)并转换为 tensor;模型输出的关键点坐标是「网格相对坐标」,需乘以对应尺度的网格步长(8/16/32)还原到输入图(imgsz)坐标,再映射回原图像素尺度;pose_loss 计算时,会将预测关键点坐标与标注坐标(归一化后)做坐标回归损失(上文示意中为 Smooth L1,亦可用 MSE),kobj_loss 针对关键点可见性(v)做二分类交叉熵损失。
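    • 把「网格相对坐标 → 原图像素坐标」的两步还原写成最小示意如下(实际解码公式以 Pose 头源码为准,这里只演示 stride 还原与 letterbox 逆变换的思路):

      • # 关键点坐标还原示意(简化:忽略具体的偏移/缩放编码形式)
        def decode_kpt(kpt_dx, kpt_dy, grid_i, grid_j, stride, scale, pad_w, pad_h):
            # 1) 网格相对坐标 -> 输入图(如 640×640)像素坐标
            x_in = (kpt_dx + grid_j) * stride
            y_in = (kpt_dy + grid_i) * stride
            # 2) 去掉 letterbox padding 并除以缩放系数 -> 原图像素坐标
            return (x_in - pad_w) / scale, (y_in - pad_h) / scale

        # 例:stride=32 的 20×20 特征图上,第 (5, 7) 个网格、相对偏移 (0.4, 0.6)
        print(decode_kpt(0.4, 0.6, 5, 7, 32, scale=0.5, pad_w=0, pad_h=80))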
