Model 是 Ultralytics YOLO 框架的门面类(Facade),继承自 torch.nn.Module,为 7 种模型变体(YOLO、YOLOWorld、YOLOE、RTDETR、SAM、FastSAM、NAS)提供统一的训练、验证、推理、导出、跟踪和调优接口。它不是一个简单的包装器——而是一个策略调度中枢,通过 _smart_load + task_map 二元机制,在运行时根据任务类型动态选择正确的 Trainer/Predictor/Validator/Model 实现,同时通过多层配置合并策略确保默认参数、用户覆盖和方法级默认值的优先级正确。

Sources: model.py, init.py

类继承体系:从抽象基座到具体变体

Model 类作为抽象基座,定义了所有模型共享的行为契约。每个具体子类只需通过 task_map 属性声明自己的任务→实现映射,即可获得完整的训练/推理/导出能力。这种设计使得新增模型类型只需实现一个子类和一个字典。

torch_nn_Module

+forward()

+parameters()

+to(device)

Model

+predictor: BasePredictor

+model: torch.nn.Module

+trainer: BaseTrainer

+ckpt: dict

+overrides: dict

+task: str

+callbacks: dict

+task_map* dict

+init(model, task, verbose)

+predict(source, stream, **kwargs) : list<Results>

+train(trainer, **kwargs) : Metrics

+val(validator, **kwargs) : Metrics

+export(**kwargs) : str

+track(source, stream, persist, **kwargs) : list<Results>

+tune(use_ray, iterations, **kwargs)

+benchmark(data, format, **kwargs) : DataFrame

+embed(source, stream, **kwargs) : list

+fuse() : Model

+info(detailed, verbose, imgsz) : tuple

+save(filename)

+load(weights) : Model

+reset_weights() : Model

+_smart_load(key) : class

YOLO

+task_map: dict

-detect→DetectionModel/Trainer/Validator/Predictor

-segment→SegmentationModel/Trainer/Validator/Predictor

-classify→ClassificationModel/Trainer/Validator/Predictor

-pose→PoseModel/Trainer/Validator/Predictor

-obb→OBBModel/Trainer/Validator/Predictor

YOLOWorld

+task_map: dict

-detect→WorldModel/WorldTrainer

+set_classes(classes)

YOLOE

+task_map: dict

-detect→YOLOEModel/YOLOETrainer

-segment→YOLOESegModel/YOLOESegTrainer

+set_vocab(vocab, names)

+set_classes(classes, embeddings)

RTDETR

+task_map: dict

-detect→RTDETRDetectionModel/Trainer/Validator/Predictor

SAM

+task_map: dict

-segment→Predictor/SAM2Predictor/SAM3Predictor

+predict(source, bboxes, points, labels)

FastSAM

+task_map: dict

-segment→FastSAMPredictor/Validator

NAS

+task_map: dict

-detect→NASPredictor/Validator

task_map 为抽象属性\n子类必须实现

上图中,实线三角箭头表示继承关系。Modeltask_map 是一个标记为 raise NotImplementedError 的抽象属性,强制所有具体子类必须提供任务映射。值得注意的是,NAS 和 FastSAM 的 task_map 只包含 predictorvalidator,不支持训练——它们是纯推理型模型。

Sources: model.py, yolo/model.py, rtdetr/model.py, sam/model.py, nas/model.py, fastsam/model.py

初始化流程:模型来源的六路分发

Model 的 __init__ 方法是一个六路分发器,根据传入的 model 参数类型,走完全不同的初始化路径。理解这个分发逻辑是掌握 Model 类行为的关键。

.yaml/.yml

其他

Model.__init__(model, task, verbose)

model 是 Model 实例?

直接复制 __dict__
浅拷贝返回

初始化所有属性为空值

is_hub_model?

HUBTrainingSession
下载模型文件

is_triton_model?

保存 URL 为 model
设置 task=detect
直接返回

文件扩展名?

_new(cfg)
从 YAML 构建新模型

_load(weights)
从检查点加载

删除 self.training
委托给 self.model.training

属性初始化快照

无论走哪条路径,Model 都会先初始化以下核心状态属性为空值或默认值:

属性 初始值 用途
callbacks callbacks.get_default_callbacks() 事件回调字典
predictor None 预测器实例(惰性创建)
model None 底层 PyTorch 模型
trainer None 训练器实例
ckpt {} 检查点数据
cfg None YAML 配置路径
ckpt_path None 检查点文件路径
overrides {} 配置覆盖字典
metrics None 验证/训练指标
session None HUB 会话
task 构造参数 任务类型
model_name None 模型名称

初始化的最后一步是 del self.training——这不是 bug,而是有意为之。因为 Model 继承自 torch.nn.Module,后者有 self.training 布尔属性。删除它后,通过 __getattr__ 的代理机制,对 model.training 的访问会直接透传到底层模型,避免了状态不同步的问题。

Sources: model.py

_new_load:两条构建路径的对比

_new(cfg):从 YAML 配置构建新模型

当传入 .yaml.yml 文件时,_new 方法通过 yaml_model_load 解析配置字典,再通过 guess_model_task 从配置中推断任务类型(如果未显式指定),最后调用 self._smart_load("model") 获取任务对应的模型类(如 DetectionModel),实例化底层 PyTorch 模型。

关键步骤:yaml_model_load 会将 YAML 路径中的规模标识符(如 yolo26n 中的 n)提取为 cfg["scale"],供 parse_model 在构建网络时使用来缩放通道数。

Sources: model.py, tasks.py

_load(weights):从权重文件加载模型

当传入 .pt 或其他权重文件时,_load 处理两种子情况:

权重类型 处理方式 赋值
.pt 文件 load_checkpoint() 反序列化 self.model = PyTorch模型, self.ckpt = 检查点字典
其他格式(如 .onnx 仅记录路径,不加载 self.model = 路径字符串, self.ckpt = None

对于 .pt 文件,load_checkpoint 函数会完成 EMA 模型提取、参数合并、任务推断、模型融合评估等一系列操作。对于远程权重(URL),会自动下载到本地 SETTINGS["weights_dir"] 目录。

Sources: model.py, tasks.py

任务推断的四级降级策略

guess_model_task 函数实现了一个四级降级策略来推断模型的任务类型:

优先级 推断来源 示例
1 配置字典的 head 末尾模块名 cfg["head"][-1][-2] 包含 “segment” → "segment"
2 PyTorch 模型的 args["task"]yaml 从模型属性中读取
3 PyTorch 模型的模块类型 遍历 model.modules() 检测头类型
4 文件名后缀 -seg"segment", -cls"classify"
兜底 默认值 "detect"

Sources: tasks.py

_smart_load:策略模式的核心调度

_smart_load 是 Model 类架构的心脏。它将 [task][key] 的二维查找委托给子类的 task_map 属性,实现了策略模式的动态调度:

def _smart_load(self, key: str):
    try:
        return self.task_map[self.task][key]
    except Exception as e:
        name = self.__class__.__name__
        mode = inspect.stack()[1][3]  # 获取调用方函数名
        raise NotImplementedError(
            f"'{name}' model does not support '{mode}' mode for '{self.task}' task."
        ) from e

key 的取值及其被调用的场景:

key 值 调用场景 返回类型
"model" _new() 中构建新模型 模型类(如 DetectionModel
"trainer" train() 中创建训练器 BaseTrainer 子类
"validator" val() 中创建验证器 BaseValidator 子类
"predictor" predict() 中创建预测器 BasePredictor 子类

错误信息中通过 inspect.stack()[1][3] 获取调用方函数名(如 trainpredict),使得错误信息能精确地告知用户"哪个操作不支持哪个任务",而非暴露内部实现细节。

Sources: model.py

配置合并的三层覆盖策略

Model 类中所有主要方法(trainvalpredictexport)都遵循统一的三层配置合并模式

args = {**self.overrides, **custom, **kwargs, "mode": "xxx"}

合并顺序(右侧覆盖左侧):

self.overrides → custom(方法默认值)→ kwargs(用户参数)→ mode(强制覆盖)
层级 来源 含义
1 self.overrides 模型初始化时确定的参数(如 modeltask
2 custom 当前方法的方法级默认值(如 predictconf=0.25
3 kwargs 用户调用时传入的参数,优先级最高
4 "mode" 强制覆盖操作模式(如 "train""predict"

这种设计确保了:用户参数永远优先,方法默认值次之,模型初始化参数兜底。以 predict 方法为例,它的 custom 默认值包括 conf=0.25(置信度阈值)、batch=1save=is_cli(仅在 CLI 模式下保存)和 rect=True(矩形推理)。

Sources: model.py, model.py, model.py, model.py

核心方法详解

predict():惰性预测器与流式推理

predict 方法实现了预测器惰性创建设备感知重建两个关键模式。预测器仅在首次调用或设备变更时创建:

if not self.predictor or self.predictor.args.device != args.get("device", ...):
    # 首次调用或设备变更 → 创建新预测器
    self.predictor = (predictor or self._smart_load("predictor"))(overrides=args, _callbacks=self.callbacks)
    self.predictor.setup_model(model=self.model, verbose=is_cli)
else:
    # 已有预测器 → 仅更新参数
    self.predictor.args = get_cfg(self.predictor.args, args)

对于 CLI 调用和 Python API 调用,最终走不同的执行路径:CLI 调用 predictor.predict_cli(source),Python API 调用 predictor(source, stream)。此外,当 sourceNone 时,会使用默认的 ASSETS 图片(OBB 任务使用 boats.jpg),并输出警告。

Sources: model.py

train():HUB 集成与检查点恢复

train 是 Model 类中最复杂的方法。它处理以下场景:

  1. HUB 训练覆盖:如果存在活跃的 HUB session,本地参数被忽略,全部使用 HUB 传来的训练参数
  2. 断点续训:当 resume=True 时,检查当前检查点是否包含有效的 epochoptimizer 状态
  3. 权重传递优化:对于非续训场景,直接将已加载的模型传递给 trainer,避免重新下载远程权重
  4. 训练后状态更新:训练完成后,重新加载 best/last 检查点,重置 overridesmetrics

训练完成后,Model 实例的状态会被完全刷新——self.model 指向新训练的最佳模型,self.ckpt 包含最新检查点数据,self.metrics 携带验证指标。

Sources: model.py

export():格式无关的导出委托

export 方法通过 Exporter 类处理所有导出逻辑,自身仅负责参数组装和调用:

args = {**self.overrides, **custom, **kwargs, "mode": "export"}
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)

方法级默认值会将 device 重置为 None(避免多 GPU 错误),batch 设为 1,data 设为 None,并将 imgsz 继承自模型参数。具体的导出格式支持(ONNX、TensorRT、OpenVINO 等 16 种格式)由 Exporter 类内部处理,详见 模型导出引擎:支持 ONNX、TensorRT、OpenVINO 等 16 种格式

Sources: model.py

val():轻量级验证委托

val 方法是最简洁的核心方法之一。它创建验证器、运行验证、保存指标并返回:

validator = (validator or self._smart_load("validator"))(args=args, _callbacks=self.callbacks)
validator(model=self.model)
self.metrics = validator.metrics
return validator.metrics

注意 validator(model=self.model) 的调用方式——验证器是可调用对象,底层会完成数据加载、推理、指标计算的全部流程。

Sources: model.py

track():跟踪器的动态注册

trackpredict 的增强版,在调用 predict 之前完成两件事:

  1. 跟踪器注册:首次调用时通过 register_tracker 注入 ByteTrack/BoTSORT
  2. 参数预设:设置 conf=0.1(ByteTrack 需要低置信度输入)和 batch=1(视频跟踪要求逐帧处理)

跟踪功能本质上是"predict + 后处理跟踪",详见 目标跟踪器:ByteTrack 与 BoTSORT 集成

Sources: model.py

embed():特征嵌入提取

embedpredict 的语义化包装。当用户不指定 embed 参数时,默认提取模型倒数第二层的输出作为图像嵌入:

if not kwargs.get("embed"):
    kwargs["embed"] = [len(self.model.model) - 2]
return self.predict(source, stream, **kwargs)

Sources: model.py

YOLO 子类的自动路由机制

YOLO 类在 __init__ 中实现了一个三级自动路由,根据模型文件名自动切换到合适的子类:

YOLO.__init__(model)

文件名包含 -world?

替换为 YOLOWorld 实例
self.__class__ = YOLOWorld

文件名包含 yoloe?

替换为 YOLOE 实例
self.__class__ = YOLOE

调用 Model.__init__

检测头是 RTDETR?

替换为 RTDETR 实例

保持 YOLO 实例

这种类自变(self-mutation) 模式通过 self.__class__ = type(new_instance)self.__dict__ = new_instance.__dict__ 实现,使得 YOLO("yolov8s-world.pt") 返回的实例实际上是 YOLOWorld 类型——尽管用户代码中的变量类型标注仍然是 YOLO

Sources: yolo/model.py

属性代理与设备管理

__getattr__:透明的属性委托

Model 重写了 __getattr__,将对 Model 实例的属性访问透传到底层模型

def __getattr__(self, name):
    return self._modules["model"] if name == "model" else getattr(self.model, name)

这意味着 model.stridemodel.namesmodel.yaml 等属性实际上来自底层 PyTorch 模型。唯一例外是 name == "model" 时返回 torch.nn.Module_modules["model"],这确保了 self.model 访问能正确获取子模块注册表中的模型对象。

_apply:设备迁移的安全网

_apply 重写确保了模型在设备迁移(.cuda().cpu().half())时的一致性:

def _apply(self, fn):
    self._check_is_pytorch_model()
    self = super()._apply(fn)
    self.predictor = None  # 重置预测器(设备可能已变更)
    self.overrides["device"] = self.device  # 同步设备信息
    return self

每次设备迁移都会强制重置预测器,因为预测器可能缓存了与特定设备绑定的张量。这是避免设备不匹配错误的预防性设计。

计算属性一览

属性 类型 行为
names dict[int, str] 从底层模型获取类名,经过 check_class_names 校验
device torch.device 从模型参数推断当前设备,非 nn.Module 返回 None
transforms object | None 底层模型的预处理变换

Sources: model.py, model.py

模型变体的能力矩阵

不同子类通过 task_map 声明的能力差异显著。下表总结了各变体支持的操作矩阵:

模型变体 任务 train val predict export track tune 特殊能力
YOLO detect/segment/classify/pose/obb 5 任务全覆盖
YOLOWorld detect - - set_classes() 开放词汇
YOLOE detect/segment - - 视觉/文本提示嵌入
RTDETR detect - - Transformer 架构
SAM segment - - - - - 点/框提示分割
FastSAM segment - - - - 实时 SAM
NAS detect - - - - 神经架构搜索

注:✅ 表示通过 task_map 提供了完整实现;- 表示 task_map 中未提供对应类,调用时会抛出 NotImplementedError

Sources: yolo/model.py, rtdetr/model.py, sam/model.py, nas/model.py, fastsam/model.py

回调系统:事件驱动的扩展点

Model 类内建了完整的回调管理机制,允许用户在模型生命周期的关键节点注入自定义逻辑:

# 注册回调
model.add_callback("on_train_start", lambda trainer: print("训练开始!"))

# 清除特定事件的所有回调
model.clear_callback("on_train_start")

# 重置所有回调到默认状态
model.reset_callbacks()

回调字典在初始化时通过 callbacks.get_default_callbacks() 填充默认回调(如 TensorBoard 日志、WandB 同步等),随后传递给 Trainer/Validator/Predictor 使用。reset_callbacks 的实现值得注意——它将每个事件重置为只包含 default_callbacks[event][0](第一个默认回调),而非完全清空。

Sources: model.py, model.py

检查点参数重置机制

_reset_ckpt_args 是一个容易被忽视但至关重要的私有方法。当从检查点加载模型时,它将参数字典过滤到最小集合

include = {"imgsz", "data", "task", "single_cls"}
return {k: v for k, v in args.items() if k in include}

这确保了训练时的临时参数(如 epochsbatchoptimizer 等)不会泄漏到后续操作中。比如加载一个训练了 300 个 epoch 的检查点后调用 predict(),不会意外地触发 300 轮训练配置。

Sources: model.py

完整 API 速查表

方法 签名 返回值 PyTorch 专属 说明
predict predict(source, stream, **kwargs) list[Results] 推理预测
__call__ __call__(source, stream, **kwargs) list[Results] predict 的别名
embed embed(source, stream, **kwargs) list[Tensor] 特征嵌入
track track(source, stream, persist, **kwargs) list[Results] 目标跟踪
train train(trainer, **kwargs) Metrics | None 模型训练
val val(validator, **kwargs) Metrics 模型验证
export export(**kwargs) str 模型导出
tune tune(use_ray, iterations, **kwargs) ResultGrid | None 超参调优
benchmark benchmark(data, format, **kwargs) DataFrame 性能基准
info info(detailed, verbose, imgsz) tuple 模型信息
fuse fuse() Model BN 融合
load load(weights) Model 加载权重
save save(filename) - 保存检查点
reset_weights reset_weights() Model 重置权重
eval eval() Model 评估模式
add_callback add_callback(event, func) - 注册回调
clear_callback clear_callback(event) - 清除回调
reset_callbacks reset_callbacks() - 重置回调

Sources: model.py, model.py

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐