【Ultralytics】「7」深度解析 统一模型接口 Model 类:训练、推理、导出的入口枢纽
Model 是 Ultralytics YOLO 框架的门面类(Facade),继承自 torch.nn.Module,为 7 种模型变体(YOLO、YOLOWorld、YOLOE、RTDETR、SAM、FastSAM、NAS)提供统一的训练、验证、推理、导出、跟踪和调优接口。它不是一个简单的包装器——而是一个策略调度中枢,通过 _smart_load + task_map 二元机制,在运行时根据任务类型动态选择正确的 Trainer/Predictor/Validator/Model 实现,同时通过多层配置合并策略确保默认参数、用户覆盖和方法级默认值的优先级正确。
类继承体系:从抽象基座到具体变体
Model 类作为抽象基座,定义了所有模型共享的行为契约。每个具体子类只需通过 task_map 属性声明自己的任务→实现映射,即可获得完整的训练/推理/导出能力。这种设计使得新增模型类型只需实现一个子类和一个字典。
上图中,实线三角箭头表示继承关系。Model 的 task_map 是一个标记为 raise NotImplementedError 的抽象属性,强制所有具体子类必须提供任务映射。值得注意的是,NAS 和 FastSAM 的 task_map 只包含 predictor 和 validator,不支持训练——它们是纯推理型模型。
Sources: model.py, yolo/model.py, rtdetr/model.py, sam/model.py, nas/model.py, fastsam/model.py
初始化流程:模型来源的六路分发
Model 的 __init__ 方法是一个六路分发器,根据传入的 model 参数类型,走完全不同的初始化路径。理解这个分发逻辑是掌握 Model 类行为的关键。
属性初始化快照
无论走哪条路径,Model 都会先初始化以下核心状态属性为空值或默认值:
| 属性 | 初始值 | 用途 |
|---|---|---|
callbacks |
callbacks.get_default_callbacks() |
事件回调字典 |
predictor |
None |
预测器实例(惰性创建) |
model |
None |
底层 PyTorch 模型 |
trainer |
None |
训练器实例 |
ckpt |
{} |
检查点数据 |
cfg |
None |
YAML 配置路径 |
ckpt_path |
None |
检查点文件路径 |
overrides |
{} |
配置覆盖字典 |
metrics |
None |
验证/训练指标 |
session |
None |
HUB 会话 |
task |
构造参数 | 任务类型 |
model_name |
None |
模型名称 |
初始化的最后一步是 del self.training——这不是 bug,而是有意为之。因为 Model 继承自 torch.nn.Module,后者有 self.training 布尔属性。删除它后,通过 __getattr__ 的代理机制,对 model.training 的访问会直接透传到底层模型,避免了状态不同步的问题。
Sources: model.py
_new 与 _load:两条构建路径的对比
_new(cfg):从 YAML 配置构建新模型
当传入 .yaml 或 .yml 文件时,_new 方法通过 yaml_model_load 解析配置字典,再通过 guess_model_task 从配置中推断任务类型(如果未显式指定),最后调用 self._smart_load("model") 获取任务对应的模型类(如 DetectionModel),实例化底层 PyTorch 模型。
关键步骤:yaml_model_load 会将 YAML 路径中的规模标识符(如 yolo26n 中的 n)提取为 cfg["scale"],供 parse_model 在构建网络时使用来缩放通道数。
_load(weights):从权重文件加载模型
当传入 .pt 或其他权重文件时,_load 处理两种子情况:
| 权重类型 | 处理方式 | 赋值 |
|---|---|---|
.pt 文件 |
load_checkpoint() 反序列化 |
self.model = PyTorch模型, self.ckpt = 检查点字典 |
其他格式(如 .onnx) |
仅记录路径,不加载 | self.model = 路径字符串, self.ckpt = None |
对于 .pt 文件,load_checkpoint 函数会完成 EMA 模型提取、参数合并、任务推断、模型融合评估等一系列操作。对于远程权重(URL),会自动下载到本地 SETTINGS["weights_dir"] 目录。
任务推断的四级降级策略
guess_model_task 函数实现了一个四级降级策略来推断模型的任务类型:
| 优先级 | 推断来源 | 示例 |
|---|---|---|
| 1 | 配置字典的 head 末尾模块名 |
cfg["head"][-1][-2] 包含 “segment” → "segment" |
| 2 | PyTorch 模型的 args["task"] 或 yaml |
从模型属性中读取 |
| 3 | PyTorch 模型的模块类型 | 遍历 model.modules() 检测头类型 |
| 4 | 文件名后缀 | -seg → "segment", -cls → "classify" |
| 兜底 | 默认值 | "detect" |
Sources: tasks.py
_smart_load:策略模式的核心调度
_smart_load 是 Model 类架构的心脏。它将 [task][key] 的二维查找委托给子类的 task_map 属性,实现了策略模式的动态调度:
def _smart_load(self, key: str):
try:
return self.task_map[self.task][key]
except Exception as e:
name = self.__class__.__name__
mode = inspect.stack()[1][3] # 获取调用方函数名
raise NotImplementedError(
f"'{name}' model does not support '{mode}' mode for '{self.task}' task."
) from e
key 的取值及其被调用的场景:
| key 值 | 调用场景 | 返回类型 |
|---|---|---|
"model" |
_new() 中构建新模型 |
模型类(如 DetectionModel) |
"trainer" |
train() 中创建训练器 |
BaseTrainer 子类 |
"validator" |
val() 中创建验证器 |
BaseValidator 子类 |
"predictor" |
predict() 中创建预测器 |
BasePredictor 子类 |
错误信息中通过 inspect.stack()[1][3] 获取调用方函数名(如 train、predict),使得错误信息能精确地告知用户"哪个操作不支持哪个任务",而非暴露内部实现细节。
Sources: model.py
配置合并的三层覆盖策略
Model 类中所有主要方法(train、val、predict、export)都遵循统一的三层配置合并模式:
args = {**self.overrides, **custom, **kwargs, "mode": "xxx"}
合并顺序(右侧覆盖左侧):
self.overrides → custom(方法默认值)→ kwargs(用户参数)→ mode(强制覆盖)
| 层级 | 来源 | 含义 |
|---|---|---|
| 1 | self.overrides |
模型初始化时确定的参数(如 model、task) |
| 2 | custom |
当前方法的方法级默认值(如 predict 的 conf=0.25) |
| 3 | kwargs |
用户调用时传入的参数,优先级最高 |
| 4 | "mode" |
强制覆盖操作模式(如 "train"、"predict") |
这种设计确保了:用户参数永远优先,方法默认值次之,模型初始化参数兜底。以 predict 方法为例,它的 custom 默认值包括 conf=0.25(置信度阈值)、batch=1、save=is_cli(仅在 CLI 模式下保存)和 rect=True(矩形推理)。
Sources: model.py, model.py, model.py, model.py
核心方法详解
predict():惰性预测器与流式推理
predict 方法实现了预测器惰性创建和设备感知重建两个关键模式。预测器仅在首次调用或设备变更时创建:
if not self.predictor or self.predictor.args.device != args.get("device", ...):
# 首次调用或设备变更 → 创建新预测器
self.predictor = (predictor or self._smart_load("predictor"))(overrides=args, _callbacks=self.callbacks)
self.predictor.setup_model(model=self.model, verbose=is_cli)
else:
# 已有预测器 → 仅更新参数
self.predictor.args = get_cfg(self.predictor.args, args)
对于 CLI 调用和 Python API 调用,最终走不同的执行路径:CLI 调用 predictor.predict_cli(source),Python API 调用 predictor(source, stream)。此外,当 source 为 None 时,会使用默认的 ASSETS 图片(OBB 任务使用 boats.jpg),并输出警告。
Sources: model.py
train():HUB 集成与检查点恢复
train 是 Model 类中最复杂的方法。它处理以下场景:
- HUB 训练覆盖:如果存在活跃的 HUB session,本地参数被忽略,全部使用 HUB 传来的训练参数
- 断点续训:当
resume=True时,检查当前检查点是否包含有效的epoch和optimizer状态 - 权重传递优化:对于非续训场景,直接将已加载的模型传递给 trainer,避免重新下载远程权重
- 训练后状态更新:训练完成后,重新加载 best/last 检查点,重置
overrides和metrics
训练完成后,Model 实例的状态会被完全刷新——self.model 指向新训练的最佳模型,self.ckpt 包含最新检查点数据,self.metrics 携带验证指标。
Sources: model.py
export():格式无关的导出委托
export 方法通过 Exporter 类处理所有导出逻辑,自身仅负责参数组装和调用:
args = {**self.overrides, **custom, **kwargs, "mode": "export"}
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
方法级默认值会将 device 重置为 None(避免多 GPU 错误),batch 设为 1,data 设为 None,并将 imgsz 继承自模型参数。具体的导出格式支持(ONNX、TensorRT、OpenVINO 等 16 种格式)由 Exporter 类内部处理,详见 模型导出引擎:支持 ONNX、TensorRT、OpenVINO 等 16 种格式。
Sources: model.py
val():轻量级验证委托
val 方法是最简洁的核心方法之一。它创建验证器、运行验证、保存指标并返回:
validator = (validator or self._smart_load("validator"))(args=args, _callbacks=self.callbacks)
validator(model=self.model)
self.metrics = validator.metrics
return validator.metrics
注意 validator(model=self.model) 的调用方式——验证器是可调用对象,底层会完成数据加载、推理、指标计算的全部流程。
Sources: model.py
track():跟踪器的动态注册
track 是 predict 的增强版,在调用 predict 之前完成两件事:
- 跟踪器注册:首次调用时通过
register_tracker注入 ByteTrack/BoTSORT - 参数预设:设置
conf=0.1(ByteTrack 需要低置信度输入)和batch=1(视频跟踪要求逐帧处理)
跟踪功能本质上是"predict + 后处理跟踪",详见 目标跟踪器:ByteTrack 与 BoTSORT 集成。
Sources: model.py
embed():特征嵌入提取
embed 是 predict 的语义化包装。当用户不指定 embed 参数时,默认提取模型倒数第二层的输出作为图像嵌入:
if not kwargs.get("embed"):
kwargs["embed"] = [len(self.model.model) - 2]
return self.predict(source, stream, **kwargs)
Sources: model.py
YOLO 子类的自动路由机制
YOLO 类在 __init__ 中实现了一个三级自动路由,根据模型文件名自动切换到合适的子类:
这种类自变(self-mutation) 模式通过 self.__class__ = type(new_instance) 和 self.__dict__ = new_instance.__dict__ 实现,使得 YOLO("yolov8s-world.pt") 返回的实例实际上是 YOLOWorld 类型——尽管用户代码中的变量类型标注仍然是 YOLO。
Sources: yolo/model.py
属性代理与设备管理
__getattr__:透明的属性委托
Model 重写了 __getattr__,将对 Model 实例的属性访问透传到底层模型:
def __getattr__(self, name):
return self._modules["model"] if name == "model" else getattr(self.model, name)
这意味着 model.stride、model.names、model.yaml 等属性实际上来自底层 PyTorch 模型。唯一例外是 name == "model" 时返回 torch.nn.Module 的 _modules["model"],这确保了 self.model 访问能正确获取子模块注册表中的模型对象。
_apply:设备迁移的安全网
_apply 重写确保了模型在设备迁移(.cuda()、.cpu()、.half())时的一致性:
def _apply(self, fn):
self._check_is_pytorch_model()
self = super()._apply(fn)
self.predictor = None # 重置预测器(设备可能已变更)
self.overrides["device"] = self.device # 同步设备信息
return self
每次设备迁移都会强制重置预测器,因为预测器可能缓存了与特定设备绑定的张量。这是避免设备不匹配错误的预防性设计。
计算属性一览
| 属性 | 类型 | 行为 |
|---|---|---|
names |
dict[int, str] |
从底层模型获取类名,经过 check_class_names 校验 |
device |
torch.device |
从模型参数推断当前设备,非 nn.Module 返回 None |
transforms |
object | None |
底层模型的预处理变换 |
模型变体的能力矩阵
不同子类通过 task_map 声明的能力差异显著。下表总结了各变体支持的操作矩阵:
| 模型变体 | 任务 | train | val | predict | export | track | tune | 特殊能力 |
|---|---|---|---|---|---|---|---|---|
| YOLO | detect/segment/classify/pose/obb | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 5 任务全覆盖 |
| YOLOWorld | detect | ✅ | ✅ | ✅ | ✅ | - | - | set_classes() 开放词汇 |
| YOLOE | detect/segment | ✅ | ✅ | ✅ | ✅ | - | - | 视觉/文本提示嵌入 |
| RTDETR | detect | ✅ | ✅ | ✅ | ✅ | - | - | Transformer 架构 |
| SAM | segment | - | - | ✅ | - | - | - | 点/框提示分割 |
| FastSAM | segment | - | ✅ | ✅ | - | - | - | 实时 SAM |
| NAS | detect | - | ✅ | ✅ | - | - | - | 神经架构搜索 |
注:✅ 表示通过 task_map 提供了完整实现;- 表示 task_map 中未提供对应类,调用时会抛出 NotImplementedError。
Sources: yolo/model.py, rtdetr/model.py, sam/model.py, nas/model.py, fastsam/model.py
回调系统:事件驱动的扩展点
Model 类内建了完整的回调管理机制,允许用户在模型生命周期的关键节点注入自定义逻辑:
# 注册回调
model.add_callback("on_train_start", lambda trainer: print("训练开始!"))
# 清除特定事件的所有回调
model.clear_callback("on_train_start")
# 重置所有回调到默认状态
model.reset_callbacks()
回调字典在初始化时通过 callbacks.get_default_callbacks() 填充默认回调(如 TensorBoard 日志、WandB 同步等),随后传递给 Trainer/Validator/Predictor 使用。reset_callbacks 的实现值得注意——它将每个事件重置为只包含 default_callbacks[event][0](第一个默认回调),而非完全清空。
检查点参数重置机制
_reset_ckpt_args 是一个容易被忽视但至关重要的私有方法。当从检查点加载模型时,它将参数字典过滤到最小集合:
include = {"imgsz", "data", "task", "single_cls"}
return {k: v for k, v in args.items() if k in include}
这确保了训练时的临时参数(如 epochs、batch、optimizer 等)不会泄漏到后续操作中。比如加载一个训练了 300 个 epoch 的检查点后调用 predict(),不会意外地触发 300 轮训练配置。
Sources: model.py
完整 API 速查表
| 方法 | 签名 | 返回值 | PyTorch 专属 | 说明 |
|---|---|---|---|---|
predict |
predict(source, stream, **kwargs) |
list[Results] |
❌ | 推理预测 |
__call__ |
__call__(source, stream, **kwargs) |
list[Results] |
❌ | predict 的别名 |
embed |
embed(source, stream, **kwargs) |
list[Tensor] |
❌ | 特征嵌入 |
track |
track(source, stream, persist, **kwargs) |
list[Results] |
❌ | 目标跟踪 |
train |
train(trainer, **kwargs) |
Metrics | None |
✅ | 模型训练 |
val |
val(validator, **kwargs) |
Metrics |
❌ | 模型验证 |
export |
export(**kwargs) |
str |
✅ | 模型导出 |
tune |
tune(use_ray, iterations, **kwargs) |
ResultGrid | None |
✅ | 超参调优 |
benchmark |
benchmark(data, format, **kwargs) |
DataFrame |
✅ | 性能基准 |
info |
info(detailed, verbose, imgsz) |
tuple |
✅ | 模型信息 |
fuse |
fuse() |
Model |
✅ | BN 融合 |
load |
load(weights) |
Model |
✅ | 加载权重 |
save |
save(filename) |
- | ✅ | 保存检查点 |
reset_weights |
reset_weights() |
Model |
✅ | 重置权重 |
eval |
eval() |
Model |
✅ | 评估模式 |
add_callback |
add_callback(event, func) |
- | ❌ | 注册回调 |
clear_callback |
clear_callback(event) |
- | ❌ | 清除回调 |
reset_callbacks |
reset_callbacks() |
- | ❌ | 重置回调 |
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)