OpenCV从入门到精通 第五章:现代深度学习工程
目录
5.1.1 YOLOv8/v9/v10:端到端部署与TensorRT优化
5.1.2 RT-DETR与DINO:Transformer检测器实时化
5.1.3 多任务模型:YOLO-Pose/Seg/OBB统一架构
5.1.4 小目标检测:SAHI切片推理集成
5.2.1 SAM(Segment Anything):Prompt工程与实时优化
5.2.4 EfficientSAM/MobileSAM:边缘端大模型压缩
5.3.1 Stable Diffusion:UNet优化与Latent空间处理
5.3.2 ControlNet/OpenPose:条件生成管线搭建
5.3.3 图像超分:Real-ESRGAN/BSRGAN轻量级部署
5.3.4 视频增强:BasicVSR++与BasicVSR-IconVSR
第五章:现代深度学习工程
5.1 新一代检测器部署
5.1.1 YOLOv8/v9/v10:端到端部署与TensorRT优化
YOLO系列模型在实时目标检测领域持续演进,YOLOv8引入了anchor-free检测头和C2f模块,YOLOv9通过可编程梯度信息(PGI)缓解深度网络的信息瓶颈问题,而YOLOv10则实现了NMS-free训练与推理,将后处理完全融入模型前向传播。在部署层面,TensorRT通过层融合、精度校准和内核自动调优显著提升推理效率。YOLOv10在RTX 4090上可达到2039 qps的吞吐量(YOLOv10n)和0.49ms的延迟,相比YOLOv9-c的825 qps有显著优势(注意两者模型规模不同,该对比仅供量级参考)。TensorRT导出时需关注动态形状配置与批量大小权衡,固定输入尺寸通常获得更优性能,而动态形状需限制尺寸范围以避免内存碎片。
Python
#!/usr/bin/env python3
"""
Script: yolo_tensorrt_deploy.py
Content: YOLOv8/v9/v10 TensorRT端到端部署与优化
Usage:
1. 安装依赖: pip install ultralytics tensorrt pycuda onnx onnxsim
2. 导出TensorRT引擎: python yolo_tensorrt_deploy.py --export --weights yolov8n.pt --imgsz 640
3. 运行推理: python yolo_tensorrt_deploy.py --infer --engine yolov8n.engine --source image.jpg
4. 视频流推理: python yolo_tensorrt_deploy.py --infer --engine yolov8n.engine --source 0 --stream
"""
import os
import sys
import time
import argparse
import numpy as np
import cv2
import torch
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
from pathlib import Path
from ultralytics import YOLO
class TensorRTInferenceEngine:
    """
    TensorRT inference wrapper for YOLO engines with static or dynamic batch.

    Implementation details: page-locked host buffers and device buffers are
    allocated once, host->device / device->host copies and the launch are
    issued on one dedicated CUDA stream, and a list of images is letterboxed
    into a single NCHW batch.

    Per-image output layouts recognised by ``postprocess``:
      * YOLOv8/v9 raw head:  [4 + num_classes, num_anchors] -> cx, cy, w, h, class scores
      * YOLOv5-style head:   [num_boxes, 5 + num_classes]   -> cx, cy, w, h, objectness, class scores
      * YOLOv10 end-to-end:  [num_dets, 6]                  -> x1, y1, x2, y2, score, class_id
    Layout detection relies on ``num_classes``. With 1 or 2 classes the
    layouts can collide, so keep ``num_classes`` accurate.

    Only the first output tensor is bound. Engines with several outputs
    (e.g. two-headed DETR exports) need additional buffers.
    """

    def __init__(self, engine_path, max_batch_size=16, num_classes=80):
        self.logger = trt.Logger(trt.Logger.WARNING)
        self.engine_path = engine_path
        self.max_batch_size = max_batch_size
        self.num_classes = num_classes
        # Deserialize the engine and resolve I/O tensor names, shapes and dtypes
        self._load_engine()
        # Execution context plus the CUDA stream used for every copy and launch
        self.context = self.engine.create_execution_context()
        self.stream = cuda.Stream()
        # Host (page-locked) and device buffers, sized once for the largest allowed input
        self._allocate_buffers()
        # YOLO normalisation constants (pixel / 255, no mean subtraction); kept for reference
        self.input_mean = np.array([0.0, 0.0, 0.0], dtype=np.float32)
        self.input_std = np.array([255.0, 255.0, 255.0], dtype=np.float32)
        # Letterbox geometry recorded by the most recent preprocess() call
        self.scale_factors = []
        self.pad_offsets = []

    def _load_engine(self):
        """Deserialize the engine file and select its first input and first output tensor."""
        with open(self.engine_path, 'rb') as f:
            engine_bytes = f.read()
        # Keep the runtime referenced for as long as the engine it created is alive.
        self.runtime = trt.Runtime(self.logger)
        self.engine = self.runtime.deserialize_cuda_engine(engine_bytes)
        if self.engine is None:
            raise RuntimeError(f"Failed to load engine from {self.engine_path}")
        # Resolve I/O by tensor mode instead of assuming index 0 = input, index 1 = output.
        names = [self.engine.get_tensor_name(i) for i in range(self.engine.num_io_tensors)]
        inputs = [n for n in names if self.engine.get_tensor_mode(n) == trt.TensorIOMode.INPUT]
        outputs = [n for n in names if self.engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT]
        if not inputs or not outputs:
            raise RuntimeError(
                f"Engine {self.engine_path} must expose at least one input and one output tensor")
        self.input_name = inputs[0]
        self.output_name = outputs[0]
        self.input_shape = self.engine.get_tensor_shape(self.input_name)
        self.output_shape = self.engine.get_tensor_shape(self.output_name)
        # Host buffers must match the engine's binding precision (FP32 or FP16).
        self.input_dtype = trt.nptype(self.engine.get_tensor_dtype(self.input_name))
        self.output_dtype = trt.nptype(self.engine.get_tensor_dtype(self.output_name))

    def _is_dynamic_input(self):
        """Return True when any dimension of the engine input is dynamic (-1)."""
        return any(int(d) < 0 for d in self.input_shape)

    def _allocate_buffers(self):
        """
        Allocate page-locked host and device buffers for the largest supported input.

        Static engines get buffers of exactly their binding shapes. Dynamic
        engines are sized from optimization profile 0's maximum shape, with the
        batch dimension capped at ``max_batch_size``. (A plain volume of a shape
        containing -1 would be negative.)
        """
        if self._is_dynamic_input():
            max_shape = [int(d) for d in self.engine.get_tensor_profile_shape(self.input_name, 0)[2]]
            max_shape[0] = min(max_shape[0], self.max_batch_size)
            max_input_shape = tuple(max_shape)
            # Pinning the context to the max input shape lets TensorRT report the matching output shape.
            self.context.set_input_shape(self.input_name, max_input_shape)
        else:
            max_input_shape = tuple(int(d) for d in self.input_shape)
        max_output_shape = tuple(int(d) for d in self.context.get_tensor_shape(self.output_name))
        if any(d < 0 for d in max_output_shape):
            raise RuntimeError(f"Cannot size output buffer for data-dependent shape {max_output_shape}")
        # Largest batch the buffers can hold (equals the fixed batch for static engines)
        self.max_input_batch = max_input_shape[0]
        # Page-locked host memory enables truly asynchronous CPU<->GPU copies
        self.h_input = cuda.pagelocked_empty(trt.volume(max_input_shape), dtype=self.input_dtype)
        self.h_output = cuda.pagelocked_empty(trt.volume(max_output_shape), dtype=self.output_dtype)
        # Device memory
        self.d_input = cuda.mem_alloc(self.h_input.nbytes)
        self.d_output = cuda.mem_alloc(self.h_output.nbytes)
        # Bind tensor addresses to the execution context (required by execute_async_v3)
        self.context.set_tensor_address(self.input_name, int(self.d_input))
        self.context.set_tensor_address(self.output_name, int(self.d_output))

    def preprocess(self, images, input_size=(640, 640)):
        """
        Letterbox, normalise and stack images into one NCHW float32 batch.

        Each image is resized with its aspect ratio preserved, centred on a
        grey (114) canvas, converted BGR->RGB and scaled to [0, 1]. The scale and
        padding per image are stored for ``postprocess``.
        """
        batch_size = len(images)
        input_h, input_w = input_size
        processed = np.zeros((batch_size, 3, input_h, input_w), dtype=np.float32)
        self.scale_factors = []
        self.pad_offsets = []
        for i, img in enumerate(images):
            h, w = img.shape[:2]
            # Scale that fits the image inside the target while keeping aspect ratio
            scale = min(input_w / w, input_h / h)
            # max(1, ...) keeps extreme aspect ratios from producing a zero-sized resize
            new_w, new_h = max(1, int(w * scale)), max(1, int(h * scale))
            resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
            # Centre the resized image on a padded canvas
            canvas = np.full((input_h, input_w, 3), 114, dtype=np.uint8)
            pad_x = (input_w - new_w) // 2
            pad_y = (input_h - new_h) // 2
            canvas[pad_y:pad_y + new_h, pad_x:pad_x + new_w] = resized
            # BGR -> RGB, [0, 255] -> [0, 1]
            canvas = canvas[:, :, ::-1].astype(np.float32) / 255.0
            # HWC -> CHW
            processed[i] = canvas.transpose(2, 0, 1)
            # Geometry needed to map boxes back to the original image
            self.scale_factors.append(scale)
            self.pad_offsets.append((pad_x, pad_y))
        return processed

    def infer(self, input_batch):
        """
        Run the engine on a preprocessed NCHW batch.

        Returns:
            A copy of the first output tensor, reshaped to the shape TensorRT
            reports for this batch (the batch axis comes first).

        Raises:
            ValueError: if the batch shape does not match a static engine, or
                the batch exceeds the pre-allocated buffers.
        """
        input_batch = np.asarray(input_batch)
        if self._is_dynamic_input():
            self.context.set_input_shape(self.input_name, tuple(input_batch.shape))
        elif tuple(input_batch.shape) != tuple(int(d) for d in self.input_shape):
            raise ValueError(
                f"Input shape {tuple(input_batch.shape)} does not match static engine input "
                f"{tuple(int(d) for d in self.input_shape)}")
        flat = np.ascontiguousarray(input_batch, dtype=self.input_dtype).ravel()
        if flat.size > self.h_input.size:
            raise ValueError(
                f"Input of {flat.size} elements exceeds the allocated buffer ({self.h_input.size})")
        # Stage the input in page-locked memory
        self.h_input[:flat.size] = flat
        # Async H2D copy, launch and D2H copy, all on the same stream
        cuda.memcpy_htod_async(self.d_input, self.h_input, self.stream)
        self.context.execute_async_v3(stream_handle=self.stream.handle)
        cuda.memcpy_dtoh_async(self.h_output, self.d_output, self.stream)
        self.stream.synchronize()
        # Use the actual output shape (handles [B, 4+nc, anchors], [B, N, 5+nc] and [B, N, 6])
        out_shape = tuple(int(d) for d in self.context.get_tensor_shape(self.output_name))
        # Copy so the next infer() call cannot overwrite the returned array
        return self.h_output[:trt.volume(out_shape)].reshape(out_shape).copy()

    @staticmethod
    def _cxcywh_to_xyxy(boxes):
        """Convert an [N, 4] array of (cx, cy, w, h) boxes to (x1, y1, x2, y2)."""
        centers = boxes[:, :2]
        half_sizes = boxes[:, 2:4] / 2.0
        return np.concatenate([centers - half_sizes, centers + half_sizes], axis=1)

    def _decode_output(self, output, conf_thresh=0.25):
        """
        Decode one image's raw output into letterbox-space detections.

        Returns:
            (boxes_xyxy [N, 4] float32, class_scores [N], class_ids [N] int64, end_to_end)
            filtered to ``class_scores > conf_thresh``. ``end_to_end`` is True for
            the YOLOv10 [N, 6] layout, whose detections need no NMS.
        """
        out = np.asarray(output, dtype=np.float32)
        if out.ndim != 2:
            raise ValueError(f"Expected a 2-D per-image output, got shape {out.shape}")
        nc = self.num_classes
        # YOLOv8/v9 emit channels first: [4 + nc, anchors] -> [anchors, 4 + nc]
        if out.shape[0] == nc + 4 and out.shape[1] != nc + 4:
            out = out.T
        end_to_end = False
        if out.shape[1] == nc + 4:
            # YOLOv8/v9: no objectness, class scores are already final
            class_probs = out[:, 4:]
        elif out.shape[1] == nc + 5:
            # YOLOv5-style: class scores weighted by objectness
            class_probs = out[:, 5:] * out[:, 4:5]
        elif out.shape[1] == 6:
            end_to_end = True
        else:
            raise ValueError(f"Unrecognized output layout {out.shape} for num_classes={nc}")
        if end_to_end:
            boxes = out[:, :4]
            class_scores = out[:, 4]
            class_ids = out[:, 5].astype(np.int64)
        else:
            boxes = self._cxcywh_to_xyxy(out[:, :4])
            class_ids = class_probs.argmax(axis=1)
            class_scores = class_probs.max(axis=1)
        keep = class_scores > conf_thresh
        return (boxes[keep].astype(np.float32), class_scores[keep],
                class_ids[keep].astype(np.int64), end_to_end)

    @staticmethod
    def _nms(boxes_xyxy, scores, class_ids, conf_thresh, nms_thresh, agnostic=False):
        """Run OpenCV NMS on xyxy boxes and return the kept indices as an int array."""
        boxes = boxes_xyxy
        if not agnostic:
            # Class-aware NMS via the offset trick: each class is shifted into a
            # disjoint coordinate range so boxes of different classes never overlap.
            span = float(boxes.max() - boxes.min()) + 1.0
            boxes = boxes + class_ids.astype(np.float32)[:, None] * span
        # cv2.dnn.NMSBoxes expects (x, y, w, h) rectangles, not (x1, y1, x2, y2)
        xywh = np.concatenate([boxes[:, :2], boxes[:, 2:] - boxes[:, :2]], axis=1)
        indices = cv2.dnn.NMSBoxes(xywh.tolist(), scores.tolist(), conf_thresh, nms_thresh)
        # Older OpenCV returns () when nothing survives; normalise to a flat int array
        return np.array(indices, dtype=np.int64).reshape(-1)

    def postprocess(self, outputs, conf_thresh=0.25, nms_thresh=0.45, agnostic=False):
        """
        Decode, filter, NMS and un-letterbox a batch of raw engine outputs.

        YOLOv10 end-to-end outputs skip NMS; YOLOv8/v9/v5 layouts get class-aware
        NMS (set ``agnostic=True`` for class-agnostic NMS).

        Returns:
            One list per image of dicts with ``bbox`` ([x1, y1, x2, y2] in original
            image pixels), ``class_id`` and ``confidence``.
        """
        batch_results = []
        for b, output in enumerate(outputs):
            boxes, class_scores, class_ids, end_to_end = self._decode_output(output, conf_thresh)
            if len(boxes) == 0:
                batch_results.append([])
                continue
            if end_to_end:
                keep = np.arange(len(boxes))
            else:
                keep = self._nms(boxes, class_scores, class_ids, conf_thresh, nms_thresh, agnostic)
            # Inverse letterbox: remove padding, then undo the resize
            scale = self.scale_factors[b]
            pad_x, pad_y = self.pad_offsets[b]
            boxes[:, [0, 2]] = (boxes[:, [0, 2]] - pad_x) / scale
            boxes[:, [1, 3]] = (boxes[:, [1, 3]] - pad_y) / scale
            batch_results.append([
                {
                    'bbox': boxes[idx].tolist(),
                    'class_id': int(class_ids[idx]),
                    'confidence': float(class_scores[idx]),
                }
                for idx in keep
            ])
        return batch_results

    def _dummy_batch(self, batch_size):
        """Random input batch matching the engine's CHW dims (dynamic spatial dims fall back to 640)."""
        chw = tuple(int(d) if int(d) > 0 else 640 for d in tuple(self.input_shape)[1:])
        return np.random.randn(batch_size, *chw).astype(np.float32)

    def warmup(self, num_iterations=10):
        """Run a few inferences so first-call kernel selection does not skew timings."""
        warmup_batch = 1 if self._is_dynamic_input() else self.max_input_batch
        dummy_input = self._dummy_batch(warmup_batch)
        for _ in range(num_iterations):
            self.infer(dummy_input)

    def benchmark(self, batch_size=1, num_runs=100):
        """Measure latency distribution and throughput for one batch size (skipped if unsupported)."""
        if self._is_dynamic_input():
            supported = 1 <= batch_size <= self.max_input_batch
        else:
            supported = batch_size == self.max_input_batch
        if not supported:
            print(f"Batch Size: {batch_size} skipped (engine max batch: {self.max_input_batch})")
            return
        dummy_input = self._dummy_batch(batch_size)
        self.warmup(20)
        timings = []
        for _ in range(num_runs):
            start = time.perf_counter()
            self.infer(dummy_input)
            self.stream.synchronize()
            timings.append((time.perf_counter() - start) * 1000)  # ms
        timings = np.array(timings)
        print(f"Batch Size: {batch_size}")
        print(f"Mean Latency: {timings.mean():.2f}ms")
        print(f"P50: {np.percentile(timings, 50):.2f}ms")
        print(f"P99: {np.percentile(timings, 99):.2f}ms")
        print(f"Throughput: {1000.0 / timings.mean() * batch_size:.1f} fps")
def export_to_tensorrt(weights_path, imgsz=640, half=True, workspace=4, dynamic=False):
    """
    Export a YOLO checkpoint to a serialized TensorRT engine via ONNX.

    Args:
        weights_path: Ultralytics ``.pt`` checkpoint.
        imgsz: square input size for the ONNX export and the optimization profile.
        half: enable TensorRT FP16 kernels. The ONNX graph itself is exported in
            FP32, so the engine's input/output bindings stay FP32; TensorRT
            applies FP16 internally.
        workspace: builder workspace memory limit in GiB.
        dynamic: export a dynamic-batch model and add a 1/8/16 batch optimization profile.

    Returns:
        Path of the written ``.engine`` file (next to the weights).

    Raises:
        RuntimeError: if ONNX parsing or engine building fails.
    """
    model = YOLO(weights_path)
    # ONNX is the entry point for the TensorRT parser. Export in FP32 and let the
    # builder's FP16 flag handle precision.
    exported = model.export(
        format='onnx',
        imgsz=imgsz,
        half=False,
        simplify=True,
        opset=13,
        dynamic=dynamic
    )
    # Prefer the path Ultralytics reports; fall back to the conventional sibling file name.
    onnx_path = str(exported) if exported else str(Path(weights_path).with_suffix('.onnx'))

    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    # EXPLICIT_BATCH is required on TensorRT 8.x and deprecated (always on) in 10.x.
    flags = 0
    explicit_batch = getattr(trt.NetworkDefinitionCreationFlag, 'EXPLICIT_BATCH', None)
    if explicit_batch is not None:
        flags = 1 << int(explicit_batch)
    network = builder.create_network(flags)
    parser = trt.OnnxParser(network, logger)

    with open(onnx_path, 'rb') as f:
        if not parser.parse(f.read()):
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            raise RuntimeError("ONNX parsing failed")

    config = builder.create_builder_config()
    workspace_bytes = int(workspace * (1 << 30))  # GiB -> bytes
    # max_workspace_size is deprecated since TRT 8.4 and removed in TRT 10.
    if hasattr(config, 'set_memory_pool_limit'):
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
    else:
        config.max_workspace_size = workspace_bytes

    # FP16: good accuracy/performance trade-off on modern GPUs
    if half:
        config.set_flag(trt.BuilderFlag.FP16)
    # INT8 would additionally need a calibration dataset (omitted here)
    # config.set_flag(trt.BuilderFlag.INT8)

    if dynamic:
        # Use the parsed network's input name rather than assuming 'images'
        input_name = network.get_input(0).name
        profile = builder.create_optimization_profile()
        profile.set_shape(
            input_name,
            min=(1, 3, imgsz, imgsz),
            opt=(8, 3, imgsz, imgsz),
            max=(16, 3, imgsz, imgsz)
        )
        config.add_optimization_profile(profile)

    # build_engine() was removed in TRT 10; build_serialized_network() exists since TRT 8.0.
    serialized_engine = builder.build_serialized_network(network, config)
    if serialized_engine is None:
        raise RuntimeError("TensorRT engine build failed")

    engine_path = str(Path(weights_path).with_suffix('.engine'))
    with open(engine_path, 'wb') as f:
        f.write(serialized_engine)
    print(f"TensorRT engine saved to {engine_path}")
    return engine_path
def _run_stream_inference(engine, args):
    """Run batch-1 inference on a camera index or video file and display annotated frames."""
    source = int(args.source) if args.source.isdigit() else args.source
    cap = cv2.VideoCapture(source)
    if not cap.isOpened():
        raise RuntimeError(f"Cannot open video source: {args.source}")
    cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)  # keep only the latest frame to reduce latency
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            input_batch = engine.preprocess([frame], (args.imgsz, args.imgsz))
            outputs = engine.infer(input_batch)
            detections = engine.postprocess(outputs)
            for det in detections[0]:
                x1, y1, x2, y2 = map(int, det['bbox'])
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                cv2.putText(frame, f"{det['class_id']}:{det['confidence']:.2f}",
                            (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
            cv2.imshow('TensorRT Inference', frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the device and close windows even if inference raises
        cap.release()
        cv2.destroyAllWindows()


def _run_image_inference(engine, args):
    """Run inference on one image, print latency/count and write result.jpg."""
    img = cv2.imread(args.source)
    if img is None:
        raise FileNotFoundError(f"Cannot read image: {args.source}")
    input_batch = engine.preprocess([img], (args.imgsz, args.imgsz))
    start = time.perf_counter()
    outputs = engine.infer(input_batch)
    detections = engine.postprocess(outputs)
    latency = (time.perf_counter() - start) * 1000
    print(f"Inference latency: {latency:.2f}ms")
    print(f"Detections: {len(detections[0])}")
    for det in detections[0]:
        x1, y1, x2, y2 = map(int, det['bbox'])
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
    cv2.imwrite('result.jpg', img)


def main():
    """CLI entry point: export a TensorRT engine, benchmark it, or run image/stream inference."""
    parser = argparse.ArgumentParser(description='YOLO TensorRT Deployment')
    parser.add_argument('--export', action='store_true', help='Export model to TensorRT')
    parser.add_argument('--infer', action='store_true', help='Run inference')
    parser.add_argument('--weights', type=str, default='yolov8n.pt', help='PyTorch weights path')
    parser.add_argument('--engine', type=str, help='TensorRT engine path')
    parser.add_argument('--source', type=str, help='Image/video source')
    parser.add_argument('--imgsz', type=int, default=640, help='Input size')
    parser.add_argument('--stream', action='store_true', help='Video stream mode')
    parser.add_argument('--benchmark', action='store_true', help='Run benchmark')
    args = parser.parse_args()

    if args.export:
        export_to_tensorrt(args.weights, imgsz=args.imgsz)
    elif args.infer:
        if not args.engine:
            raise ValueError("--engine required for inference")
        engine = TensorRTInferenceEngine(args.engine)
        if args.benchmark:
            for bs in [1, 4, 8, 16]:
                engine.benchmark(batch_size=bs)
            return
        # Without this check, args.source.isdigit() fails with AttributeError on None
        if not args.source:
            raise ValueError("--source required for inference")
        if args.stream or args.source.isdigit():
            _run_stream_inference(engine, args)
        else:
            _run_image_inference(engine, args)
if __name__ == '__main__':
    # Run the CLI only when this file is executed as a script, not when it is imported.
    main()
5.1.2 RT-DETR与DINO:Transformer检测器实时化
RT-DETR(Real-Time Detection Transformer)是首个实现实时性能的端到端Transformer检测器,通过高效混合编码器(AIFI+CCFM)和IoU感知查询选择机制,在COCO数据集上达到53.0% AP@114FPS(RT-DETR-L,T4 GPU + TensorRT FP16),超越同期YOLO模型。其核心创新在于将Transformer的全局建模能力与CNN的局部特征提取相结合,通过解耦尺度内交互与跨尺度融合降低计算复杂度。RT-DETR支持灵活调整解码器层数以适应不同延迟要求,无需重新训练。DINO(DETR with Improved deNoising Anchor Boxes)则引入对比去噪训练、混合查询选择和"前瞻两次"(look forward twice)边框更新策略,进一步提升收敛速度与检测精度。
Python
#!/usr/bin/env python3
"""
Script: rt_detr_deploy.py
Content: RT-DETR实时Transformer检测器部署(ONNX Runtime与TensorRT双后端)
Usage:
1. 安装依赖: pip install rfdetr onnxruntime-gpu tensorrt
2. 导出模型: python rt_detr_deploy.py --export --backbone resnet50
3. ONNX推理: python rt_detr_deploy.py --infer --model rt_detr_r50.onnx --backend onnx
4. TensorRT推理: python rt_detr_deploy.py --infer --model rt_detr_r50.engine --backend tensorrt
"""
import os
import time
import argparse
import numpy as np
import cv2
from pathlib import Path
try:
from rfdetr import RFDETR
except ImportError:
RFDETR = None
class RTDETRInference:
    """
    RT-DETR inference engine with ONNX Runtime and TensorRT backends.

    Post-processing follows the DETR-family convention used by RT-DETR and DINO:
    per-class sigmoid scores (focal-loss training) and boxes predicted as
    normalized (cx, cy, w, h) relative to the network input.
    """

    def __init__(self, model_path, backend='onnx', device='cuda'):
        self.model_path = model_path
        self.backend = backend
        self.device = device
        if backend == 'onnx':
            self._init_onnx()
        elif backend == 'tensorrt':
            self._init_tensorrt()
        else:
            raise ValueError(f"Unsupported backend: {backend}")
        # COCO categories (80 classes, contiguous ids 0-79)
        self.class_names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
                            'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
                            'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
                            'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
                            'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
                            'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                            'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
                            'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
                            'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator',
                            'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']

    def _init_onnx(self):
        """Create an ONNX Runtime session (CUDA provider with CPU fallback)."""
        import onnxruntime as ort
        # Session options: intra-op threads and full graph optimisation
        sess_options = ort.SessionOptions()
        sess_options.intra_op_num_threads = 4
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        # CUDA execution provider configuration
        providers = [
            ('CUDAExecutionProvider', {
                'device_id': 0,
                'arena_extend_strategy': 'kNextPowerOfTwo',
                'gpu_mem_limit': 4 * 1024 * 1024 * 1024,  # 4GB
                'cudnn_conv_algo_search': 'EXHAUSTIVE',
                'do_copy_in_default_stream': True,
            }),
            'CPUExecutionProvider'
        ]
        self.session = ort.InferenceSession(
            self.model_path,
            sess_options,
            providers=providers
        )
        # Input/output metadata
        self.input_name = self.session.get_inputs()[0].name
        self.output_names = [o.name for o in self.session.get_outputs()]
        self.input_shape = self.session.get_inputs()[0].shape

    def _init_tensorrt(self):
        """Create a TensorRT engine by reusing the YOLO example's engine wrapper."""
        from yolo_tensorrt_deploy import TensorRTInferenceEngine
        self.engine = TensorRTInferenceEngine(self.model_path)

    def preprocess(self, image, input_size=(640, 640)):
        """
        Letterbox the image, normalise it and return a 1xCxHxW float32 tensor.

        Normalisation uses ImageNet mean/std (as in RF-DETR-style exports).
        NOTE(review): the official RT-DETR pipeline only rescales to [0, 1]
        without mean/std; confirm against the exported model.

        Returns:
            (tensor, (scale, pad_x, pad_y)) where the tuple describes the letterbox.
        """
        h, w = image.shape[:2]
        target_h, target_w = input_size
        # Aspect-preserving resize
        scale = min(target_w / w, target_h / h)
        new_w, new_h = int(w * scale), int(h * scale)
        resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        # Grey canvas with the resized image centred
        canvas = np.full((target_h, target_w, 3), 114, dtype=np.uint8)
        pad_x = (target_w - new_w) // 2
        pad_y = (target_h - new_h) // 2
        canvas[pad_y:pad_y + new_h, pad_x:pad_x + new_w] = resized
        # ImageNet normalisation
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        canvas = canvas[:, :, ::-1].astype(np.float32) / 255.0  # BGR->RGB
        canvas = (canvas - mean) / std
        canvas = canvas.transpose(2, 0, 1)  # HWC->CHW
        return np.expand_dims(canvas, 0).astype(np.float32), (scale, pad_x, pad_y)

    def postprocess(self, outputs, scale_info, score_thresh=0.5, input_size=(640, 640),
                    activation='sigmoid', normalized_boxes=True):
        """
        Decode DETR-style predictions for the first image of the batch.

        Args:
            outputs: either ``[logits, boxes]`` (shapes [B, Q, C] and [B, Q, 4]) or a
                single merged ``[B, Q, 4 + C]`` tensor wrapped in a sequence.
            scale_info: ``(scale, pad_x, pad_y)`` returned by ``preprocess``.
            score_thresh: keep queries whose best class score exceeds this.
            input_size: (H, W) of the network input; used to de-normalize boxes.
            activation: ``'sigmoid'`` (RT-DETR/DINO) or ``'softmax'``.
            normalized_boxes: set False if the export already emits pixel boxes.

        Returns:
            List of dicts with ``bbox`` ([x1, y1, x2, y2] in original image pixels),
            ``class_id``, ``class_name`` and ``score``.
        """
        scale, pad_x, pad_y = scale_info
        if len(outputs) == 2:
            logits, boxes = outputs
        else:
            # Merged output: [B, Q, 4 + C]
            predictions = np.asarray(outputs[0])
            logits = predictions[:, :, 4:]
            boxes = predictions[:, :, :4]
        logits = np.asarray(logits, dtype=np.float32)[0]
        boxes = np.asarray(boxes, dtype=np.float32)[0]

        if activation == 'sigmoid':
            probs = self._sigmoid(logits)
        elif activation == 'softmax':
            probs = self._softmax(logits)
        else:
            raise ValueError(f"Unsupported activation: {activation}")
        scores = probs.max(axis=1)
        classes = probs.argmax(axis=1)

        valid_mask = scores > score_thresh
        boxes = boxes[valid_mask]
        scores = scores[valid_mask]
        classes = classes[valid_mask]

        if normalized_boxes:
            # DETR heads predict boxes in [0, 1] relative to the network input
            input_h, input_w = input_size
            boxes = boxes * np.array([input_w, input_h, input_w, input_h], dtype=np.float32)

        # cxcywh -> xyxy, then undo padding and resize
        cx, cy, bw, bh = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
        xyxy = np.stack([
            (cx - bw / 2 - pad_x) / scale,
            (cy - bh / 2 - pad_y) / scale,
            (cx + bw / 2 - pad_x) / scale,
            (cy + bh / 2 - pad_y) / scale,
        ], axis=1)

        results = []
        for bbox, score, cls in zip(xyxy.tolist(), scores.tolist(), classes.tolist()):
            cls = int(cls)
            name = self.class_names[cls] if 0 <= cls < len(self.class_names) else str(cls)
            results.append({
                'bbox': bbox,
                'class_id': cls,
                'class_name': name,
                'score': float(score)
            })
        return results

    def _softmax(self, x):
        """Numerically stable softmax over the last axis."""
        exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

    def _sigmoid(self, x):
        """Overflow-free logistic function (tanh form)."""
        return 0.5 * (1.0 + np.tanh(0.5 * x))

    def infer(self, image):
        """End-to-end inference on one BGR image."""
        input_tensor, scale_info = self.preprocess(image)
        if self.backend == 'onnx':
            outputs = self.session.run(self.output_names, {self.input_name: input_tensor})
        else:
            # The TensorRT wrapper returns one array; wrap it so postprocess sees a
            # sequence of outputs (the merged-output path) instead of indexing into the batch.
            outputs = [self.engine.infer(input_tensor)]
        return self.postprocess(outputs, scale_info)

    def benchmark(self, num_runs=100):
        """Measure mean and P95 latency of the raw backend call."""
        dummy_input = np.random.randn(1, 3, 640, 640).astype(np.float32)
        # Warm-up
        for _ in range(10):
            if self.backend == 'onnx':
                self.session.run(self.output_names, {self.input_name: dummy_input})
            else:
                self.engine.infer(dummy_input)
        # Timed runs
        timings = []
        for _ in range(num_runs):
            start = time.perf_counter()
            if self.backend == 'onnx':
                self.session.run(self.output_names, {self.input_name: dummy_input})
            else:
                self.engine.infer(dummy_input)
            timings.append((time.perf_counter() - start) * 1000)
        print(f"Backend: {self.backend}")
        print(f"Mean Latency: {np.mean(timings):.2f}ms")
        print(f"P95 Latency: {np.percentile(timings, 95):.2f}ms")
def export_rt_detr(backbone='resnet50', output_dir='./'):
    """
    Export an RT-DETR-style model to ONNX (dynamic batch axis) and optionally simplify it.

    Args:
        backbone: backbone name forwarded to the ``RFDETR`` constructor.
        output_dir: directory for ``rt_detr_<backbone>.onnx`` (created if missing).

    Returns:
        Path of the exported ONNX file.

    Raises:
        ImportError: if the ``rfdetr`` package is not installed.
    """
    if RFDETR is None:
        raise ImportError("rfdetr package required. Install with: pip install rfdetr")
    # torch is not imported at module level in this script; without this import
    # torch.randn / torch.onnx.export below raise NameError.
    import torch

    # NOTE(review): assumes an RFDETR(backbone=..., pretrained=...) constructor;
    # confirm against the installed rfdetr version.
    model = RFDETR(backbone=backbone, pretrained=True)
    model.eval()

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f'rt_detr_{backbone}.onnx')
    dummy_input = torch.randn(1, 3, 640, 640)
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        opset_version=13,
        input_names=['images'],
        output_names=['logits', 'boxes'],
        dynamic_axes={
            'images': {0: 'batch_size'},
            'logits': {0: 'batch_size'},
            'boxes': {0: 'batch_size'}
        }
    )
    print(f"Exported to {output_path}")

    # Optional graph simplification; best-effort, the raw export is kept on failure.
    try:
        import onnx
        from onnxsim import simplify
        onnx_model = onnx.load(output_path)
        model_simp, check = simplify(onnx_model)
        if check:
            onnx.save(model_simp, output_path)
            print("ONNX model simplified")
    except Exception as e:
        print(f"Simplification skipped: {e}")
    return output_path
def main():
    """CLI entry point: export RT-DETR to ONNX, benchmark a backend, or run single-image inference."""
    parser = argparse.ArgumentParser(description='RT-DETR Deployment')
    parser.add_argument('--export', action='store_true')
    parser.add_argument('--infer', action='store_true')
    parser.add_argument('--backbone', default='resnet50', choices=['resnet50', 'resnet101'])
    parser.add_argument('--model', type=str, help='Model path')
    parser.add_argument('--backend', default='onnx', choices=['onnx', 'tensorrt'])
    parser.add_argument('--source', type=str, help='Image path')
    parser.add_argument('--benchmark', action='store_true')
    args = parser.parse_args()

    if args.export:
        export_rt_detr(args.backbone)
    elif args.infer:
        if not args.model:
            raise ValueError("--model required for inference")
        detector = RTDETRInference(args.model, args.backend)
        if args.benchmark:
            detector.benchmark()
            return
        if not args.source:
            raise ValueError("--source required for inference")
        image = cv2.imread(args.source)
        # cv2.imread returns None instead of raising on unreadable paths
        if image is None:
            raise FileNotFoundError(f"Cannot read image: {args.source}")
        results = detector.infer(image)
        # Visualisation
        for det in results:
            x1, y1, x2, y2 = map(int, det['bbox'])
            label = f"{det['class_name']}:{det['score']:.2f}"
            cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
        cv2.imwrite('rt_detr_result.jpg', image)
        print(f"Detected {len(results)} objects")
if __name__ == '__main__':
    # Script entry point; importing this module does not launch the CLI.
    main()
5.1.3 多任务模型:YOLO-Pose/Seg/OBB统一架构
多任务学习通过共享骨干网络与特征金字塔,同时输出检测、分割、姿态估计与定向边界框(OBB)预测。YOLO-SP等改进模型在单一输入上实现实例分割与姿态估计,相比独立模型速度提升65.3%,参数量减少22.2%。统一架构的关键在于任务特定检测头的设计与损失权重平衡,通过不确定性加权策略自动调整各任务损失的相对重要性。
Python
#!/usr/bin/env python3
"""
Script: yolo_multitask.py
Content: YOLO多任务模型(检测+分割+姿态+OBB)统一推理框架
Usage:
1. 安装依赖: pip install ultralytics opencv-python
2. 检测+分割: python yolo_multitask.py --task segment --weights yolov8n-seg.pt --source image.jpg
3. 姿态估计: python yolo_multitask.py --task pose --weights yolov8n-pose.pt --source image.jpg
4. OBB检测: python yolo_multitask.py --task obb --weights yolov8n-obb.pt --source image.jpg
5. 多任务并行: python yolo_multitask.py --task all --source image.jpg
"""
import os
import argparse
import numpy as np
import cv2
from pathlib import Path
from dataclasses import dataclass
from typing import List, Tuple, Optional
try:
from ultralytics import YOLO
except ImportError:
raise ImportError("Please install ultralytics: pip install ultralytics")
@dataclass
class DetectionResult:
    """One detection in a task-independent format.

    ``bbox`` is [x1, y1, x2, y2] for axis-aligned tasks and [cx, cy, w, h, angle]
    for OBB. The optional fields are only filled by the task that produces them.
    """
    bbox: np.ndarray
    confidence: float
    class_id: int
    class_name: str
    # Segmentation polygon (int32 point array), set by the 'segment' task
    mask: Optional[np.ndarray] = None
    # Pose keypoints, set by the 'pose' task
    keypoints: Optional[np.ndarray] = None
class YoloMultiTask:
    """
    Unified inference wrapper for YOLO detection, segmentation, pose and OBB models.

    One instance wraps one Ultralytics model for one task. Results are converted
    to ``DetectionResult`` objects and can be drawn with ``visualize``.
    """

    # Default weights and drawing colour (BGR) per task
    TASK_CONFIG = {
        'detect': {'model': 'yolov8n.pt', 'colors': (0, 255, 0)},
        'segment': {'model': 'yolov8n-seg.pt', 'colors': (255, 0, 0)},
        'pose': {'model': 'yolov8n-pose.pt', 'colors': (0, 0, 255)},
        'obb': {'model': 'yolov8n-obb.pt', 'colors': (255, 255, 0)}
    }

    # COCO 17-keypoint skeleton (1-based keypoint indices)
    SKELETON = [
        [16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13],
        [6, 7], [6, 8], [7, 9], [8, 10], [9, 11], [2, 3], [1, 2], [1, 3],
        [2, 4], [3, 5], [4, 6], [5, 7]
    ]

    def __init__(self, task='detect', weights=None, device='cuda', conf=0.25, iou=0.45):
        self.task = task
        self.device = device
        self.conf = conf
        self.iou = iou
        # Load the model (Ultralytics downloads the default weights if missing)
        if weights is None:
            weights = self.TASK_CONFIG[task]['model']
        self.model = YOLO(weights)
        self.model.to(device)
        # Class-id -> name mapping
        self.class_names = self.model.names
        # Warm up so the first real call is not slowed by lazy initialisation
        self._warmup()

    def _warmup(self):
        """Run a few predictions on random data to absorb first-call overhead."""
        dummy = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
        for _ in range(3):
            self.model.predict(dummy, verbose=False)

    def _select_result_data(self, results):
        """Return the per-detection container for this task.

        Ultralytics stores OBB detections in ``results.obb`` and leaves
        ``results.boxes`` as None; every other task uses ``results.boxes``.
        """
        if self.task == 'obb':
            return getattr(results, 'obb', None)
        return results.boxes

    def predict(self, image: np.ndarray, verbose=False) -> List[DetectionResult]:
        """
        Run inference on one image and convert the result to ``DetectionResult`` objects.

        Ultralytics performs pre/post-processing (letterbox, NMS) internally.
        """
        results = self.model.predict(
            image,
            conf=self.conf,
            iou=self.iou,
            device=self.device,
            verbose=verbose
        )[0]  # single image -> first result
        detections = []
        data = self._select_result_data(results)
        if data is None:
            return detections
        if self.task == 'obb':
            boxes = data.xyxyxyxy.cpu().numpy()  # [N, 4, 2] corner points
        else:
            boxes = data.xyxy.cpu().numpy()
        confs = data.conf.cpu().numpy()
        cls_ids = data.cls.cpu().numpy().astype(int)
        for i, (box, conf, cls_id) in enumerate(zip(boxes, confs, cls_ids)):
            det = DetectionResult(
                bbox=box,
                confidence=float(conf),
                class_id=int(cls_id),
                class_name=self.class_names[int(cls_id)]
            )
            # Task-specific outputs share the detection index i
            if self.task == 'segment' and results.masks is not None:
                det.mask = results.masks.xy[i].astype(np.int32)
            if self.task == 'pose' and results.keypoints is not None:
                det.keypoints = results.keypoints.xy[i].cpu().numpy()
            if self.task == 'obb':
                det.bbox = self._obb_to_rotated_rect(box)
            detections.append(det)
        return detections

    def _obb_to_rotated_rect(self, xyxyxyxy):
        """Convert OBB corner points to (cx, cy, w, h, angle) via cv2.minAreaRect."""
        points = xyxyxyxy.reshape(-1, 2).astype(np.float32)
        rect = cv2.minAreaRect(points)
        return np.array([rect[0][0], rect[0][1], rect[1][0], rect[1][1], rect[2]])

    def visualize(self, image: np.ndarray, detections: List[DetectionResult]) -> np.ndarray:
        """Draw detections on a copy of ``image`` using the task's configured colour."""
        vis_img = image.copy()
        color = self.TASK_CONFIG[self.task]['colors']
        for det in detections:
            if self.task == 'obb':
                self._draw_obb(vis_img, det.bbox, color, det.class_name, det.confidence)
            else:
                x1, y1, x2, y2 = map(int, det.bbox[:4])
                cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, 2)
                label = f"{det.class_name}:{det.confidence:.2f}"
                self._draw_label(vis_img, label, (x1, y1), color)
            if det.mask is not None and len(det.mask) > 0:
                self._draw_mask(vis_img, det.mask, color)
            if det.keypoints is not None:
                self._draw_pose(vis_img, det.keypoints)
        return vis_img

    def _draw_obb(self, img, obb_params, color, label, conf):
        """Draw a rotated box given as (cx, cy, w, h, angle) plus its label."""
        cx, cy, w, h, angle = obb_params
        rect = ((cx, cy), (w, h), angle)
        box = cv2.boxPoints(rect).astype(np.int32)
        cv2.polylines(img, [box], True, color, 2)
        label_text = f"{label}:{conf:.2f}"
        self._draw_label(img, label_text, (int(cx - w / 2), int(cy - h / 2)), color)

    def _draw_mask(self, img, mask_points, color, alpha=0.5):
        """Blend a filled polygon into ``img`` in place and outline it."""
        overlay = img.copy()
        cv2.fillPoly(overlay, [mask_points.reshape(-1, 1, 2)], color)
        cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0, img)
        cv2.polylines(img, [mask_points.reshape(-1, 1, 2)], True, color, 2)

    def _draw_pose(self, img, keypoints, radius=5):
        """Draw keypoints and skeleton links; points at (<=0, <=0) are treated as missing."""
        for x, y in keypoints:
            if x > 0 and y > 0:
                cv2.circle(img, (int(x), int(y)), radius, (0, 255, 255), -1)
                cv2.circle(img, (int(x), int(y)), radius, (0, 0, 0), 1)
        for connection in self.SKELETON:
            pt1_idx, pt2_idx = connection[0] - 1, connection[1] - 1
            if pt1_idx < len(keypoints) and pt2_idx < len(keypoints):
                pt1 = keypoints[pt1_idx]
                pt2 = keypoints[pt2_idx]
                if pt1[0] > 0 and pt1[1] > 0 and pt2[0] > 0 and pt2[1] > 0:
                    cv2.line(img, (int(pt1[0]), int(pt1[1])),
                             (int(pt2[0]), int(pt2[1])), (255, 0, 255), 2)

    def _draw_label(self, img, text, pos, color, font_scale=0.6):
        """Draw text on a filled background above ``pos``, kept inside the image's top/left edge."""
        (tw, th), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, 1)
        x, y = pos
        # Boxes touching the top/left border would otherwise place the label off-image
        x = max(int(x), 0)
        y = max(int(y), th + 10)
        cv2.rectangle(img, (x, y - th - 10), (x + tw, y), color, -1)
        cv2.putText(img, text, (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), 1)

    def benchmark(self, image_size=(640, 640), num_runs=100):
        """Report mean latency and throughput of ``predict`` on a random image."""
        dummy = np.random.randint(0, 255, (*image_size, 3), dtype=np.uint8)
        for _ in range(10):
            self.predict(dummy)
        import time
        timings = []
        for _ in range(num_runs):
            start = time.perf_counter()
            self.predict(dummy)
            timings.append((time.perf_counter() - start) * 1000)
        print(f"Task: {self.task}")
        print(f"Mean Latency: {np.mean(timings):.2f}ms")
        print(f"Throughput: {1000/np.mean(timings):.1f} FPS")
class MultiTaskPipeline:
    """
    Runs several single-task YOLO models on the same frame and overlays their drawings.

    Each task owns its own model; tasks are executed one after another on the
    configured device.
    """

    def __init__(self, tasks=None, device='cuda'):
        # None sentinel instead of a mutable list default shared across calls
        if tasks is None:
            tasks = ['detect', 'segment']
        self.tasks = tasks
        self.models = {}
        for task in tasks:
            self.models[task] = YoloMultiTask(task=task, device=device)
        print(f"Initialized pipeline with tasks: {tasks}")

    def process(self, image: np.ndarray) -> dict:
        """
        Run every task sequentially on ``image``.

        Returns:
            Mapping task name -> list of that task's detections. Sequential
            execution suits a single GPU; multi-GPU would need separate processes.
        """
        return {task_name: model.predict(image) for task_name, model in self.models.items()}

    def visualize_all(self, image: np.ndarray, results: dict) -> np.ndarray:
        """Overlay each task's visualisation onto one copy of ``image``."""
        vis_img = image.copy()
        # Per-task colours used while drawing
        colors = {
            'detect': (0, 255, 0),
            'segment': (255, 0, 0),
            'pose': (0, 0, 255),
            'obb': (255, 255, 0)
        }
        for task_name, detections in results.items():
            model = self.models.get(task_name)
            if model is None:
                continue
            # TASK_CONFIG is class-level state shared by every YoloMultiTask, so the
            # temporary colour must be restored even if drawing raises.
            task_config = model.TASK_CONFIG[task_name]
            original_color = task_config['colors']
            task_config['colors'] = colors.get(task_name, original_color)
            try:
                vis_img = model.visualize(vis_img, detections)
            finally:
                task_config['colors'] = original_color
        return vis_img
def _open_capture(source):
    """Open a camera index (digit string) or a video file; raise if it cannot be opened."""
    cap = cv2.VideoCapture(int(source) if source.isdigit() else source)
    if not cap.isOpened():
        raise RuntimeError(f"Cannot open video source: {source}")
    return cap


def _read_image(path):
    """Read an image with OpenCV, raising instead of returning None on failure."""
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"Cannot read image: {path}")
    return img


def _run_video(source, window_name, render):
    """Show ``render(frame)`` for each frame until the stream ends or 'q' is pressed."""
    cap = _open_capture(source)
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            cv2.imshow(window_name, render(frame))
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()


def main():
    """CLI entry point for single-task or combined multi-task YOLO inference."""
    parser = argparse.ArgumentParser(description='YOLO Multi-Task Inference')
    parser.add_argument('--task', default='detect',
                        choices=['detect', 'segment', 'pose', 'obb', 'all'])
    parser.add_argument('--weights', type=str, help='Custom weights path')
    parser.add_argument('--source', type=str, required=True, help='Image/video path')
    parser.add_argument('--conf', type=float, default=0.25, help='Confidence threshold')
    parser.add_argument('--iou', type=float, default=0.45, help='NMS IoU threshold')
    parser.add_argument('--device', default='cuda', help='Device (cuda/cpu)')
    parser.add_argument('--benchmark', action='store_true', help='Run benchmark')
    args = parser.parse_args()

    # Case-insensitive so '.MP4' files are treated as video too
    is_video = args.source.isdigit() or Path(args.source).suffix.lower() in ('.mp4', '.avi', '.mov')

    if args.task == 'all':
        pipeline = MultiTaskPipeline(['detect', 'segment', 'pose'], device=args.device)
        # The pipeline constructor has no threshold parameters; apply the CLI values here
        for task_model in pipeline.models.values():
            task_model.conf = args.conf
            task_model.iou = args.iou
        if is_video:
            _run_video(args.source, 'Multi-Task',
                       lambda frame: pipeline.visualize_all(frame, pipeline.process(frame)))
        else:
            img = _read_image(args.source)
            results = pipeline.process(img)
            vis = pipeline.visualize_all(img, results)
            cv2.imwrite('multitask_result.jpg', vis)
            print(f"Results saved to multitask_result.jpg")
        return

    # Single-task mode
    model = YoloMultiTask(
        task=args.task,
        weights=args.weights,
        device=args.device,
        conf=args.conf,
        iou=args.iou
    )
    if args.benchmark:
        model.benchmark()
        return
    if is_video:
        _run_video(args.source, f'YOLO-{args.task}',
                   lambda frame: model.visualize(frame, model.predict(frame)))
    else:
        img = _read_image(args.source)
        dets = model.predict(img)
        vis = model.visualize(img, dets)
        cv2.imwrite(f'{args.task}_result.jpg', vis)
        print(f"Detected {len(dets)} objects, saved to {args.task}_result.jpg")
if __name__ == '__main__':
    # Only start the CLI when invoked directly (python yolo_multitask.py ...).
    main()
5.1.4 小目标检测:SAHI切片推理集成
SAHI(Slicing Aided Hyper Inference)通过将高分辨率图像切分为重叠的小块进行独立推理,再合并结果,显著提升小目标检测精度。在农业图像分析中,经超参数优化后SAM 2的F2-score可从0.05提升至0.74。切片策略的关键参数包括切片尺寸(通常512-1024像素)、重叠比例(0.2-0.25)与置信度阈值调优。
Python
#!/usr/bin/env python3
"""
Script: sahi_inference.py
Content: SAHI切片推理集成,专用于小目标检测与高密度场景
Usage:
1. 安装依赖: pip install sahi ultralytics opencv-python
2. 标准推理: python sahi_inference.py --weights yolov8n.pt --source aerial.jpg
3. 切片推理: python sahi_inference.py --weights yolov8n.pt --source aerial.jpg --sahi --slice-size 640 --overlap 0.25
4. 批量处理: python sahi_inference.py --weights yolov8n.pt --source ./images/ --sahi --batch
"""
import os
import argparse
import time
import numpy as np
import cv2
from pathlib import Path
from typing import List, Tuple, Dict
from dataclasses import dataclass
try:
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from sahi.utils.yolov8 import download_yolov8m_model
except ImportError:
raise ImportError("Please install SAHI: pip install sahi")
try:
from ultralytics import YOLO
except ImportError:
YOLO = None
@dataclass
class SliceConfig:
    """Parameters for SAHI sliced inference.

    Slice sizes are in pixels; overlap ratios are fractions of the slice size.
    """
    # Slice geometry (pixels)
    slice_height: int = 640
    slice_width: int = 640
    # Fractional overlap between neighbouring slices
    overlap_height_ratio: float = 0.2
    overlap_width_ratio: float = 0.2
    # How overlapping slice predictions are merged: "NMS" or "GREEDYNMM"
    postprocess_type: str = "NMS"
    # Match metric used while merging: "IOU" or "IOS"
    postprocess_match_metric: str = "IOS"
    postprocess_match_threshold: float = 0.5
class SahiDetector:
"""
SAHI切片推理封装类,支持标准YOLO与SAHI增强两种模式
核心优化:自适应切片尺寸、批量切片预处理、结果融合策略
"""
def __init__(self, weights_path: str, config: SliceConfig = None,
             confidence_threshold: float = 0.25, device: str = 'cuda'):
    """Create the SAHI-wrapped detector and, if ultralytics is importable, a plain YOLO baseline.

    Args:
        weights_path: YOLO weights used by both the SAHI model and the baseline.
        config: slicing parameters; a default ``SliceConfig`` is used when omitted.
        confidence_threshold: minimum score passed to the SAHI detection model.
        device: inference device string, e.g. 'cuda' or 'cpu'.
    """
    self.weights_path = weights_path
    self.confidence_threshold = confidence_threshold
    self.device = device
    self.config = config or SliceConfig()

    # Detection model wrapped by SAHI for sliced inference
    self.detection_model = AutoDetectionModel.from_pretrained(
        model_type='yolov8',
        model_path=weights_path,
        confidence_threshold=confidence_threshold,
        device=device,
    )

    # Non-sliced baseline for comparison; absent when ultralytics failed to import
    if YOLO is None:
        return
    self.standard_model = YOLO(weights_path)
    self.standard_model.to(device)
def predict_standard(self, image: np.ndarray) -> List[Dict]:
    """Run plain (non-sliced) YOLO inference as the comparison baseline.

    Returns:
        List of dicts with ``bbox`` (xyxy ndarray), ``confidence``,
        ``class_id`` and ``class_name``.

    Raises:
        ImportError: when ultralytics is not installed.
    """
    if YOLO is None:
        raise ImportError("ultralytics required for standard prediction")
    result = self.standard_model(image, verbose=False)[0]
    boxes = result.boxes
    if boxes is None:
        return []
    names = self.standard_model.names
    return [
        {
            'bbox': xyxy.cpu().numpy(),
            'confidence': float(score),
            'class_id': int(label),
            'class_name': names[int(label)],
        }
        for xyxy, score, label in zip(boxes.xyxy, boxes.conf, boxes.cls)
    ]
def predict_sliced(self, image: np.ndarray) -> List[Dict]:
    """
    Run SAHI sliced inference: slice, predict per slice, then merge overlapping results.

    Args:
        image: BGR image as produced by OpenCV (same convention as ``predict_standard``).

    Returns:
        List of dicts with ``bbox`` (xyxy ndarray in full-image pixels),
        ``confidence``, ``class_id`` and ``class_name``.
    """
    # SAHI treats numpy inputs as RGB (its own file reader converts to RGB, and its
    # YOLO wrapper flips channels back before calling ultralytics). OpenCV frames are
    # BGR, so convert here; otherwise sliced inference sees swapped colour channels.
    if image.ndim == 3 and image.shape[2] == 3:
        sahi_input = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    else:
        sahi_input = image
    result = get_sliced_prediction(
        sahi_input,
        self.detection_model,
        slice_height=self.config.slice_height,
        slice_width=self.config.slice_width,
        overlap_height_ratio=self.config.overlap_height_ratio,
        overlap_width_ratio=self.config.overlap_width_ratio,
        postprocess_type=self.config.postprocess_type,
        postprocess_match_metric=self.config.postprocess_match_metric,
        postprocess_match_threshold=self.config.postprocess_match_threshold,
        verbose=0
    )
    # Convert SAHI ObjectPrediction objects to the same dict format as predict_standard
    return [
        {
            'bbox': np.array([
                pred.bbox.minx, pred.bbox.miny,
                pred.bbox.maxx, pred.bbox.maxy
            ]),
            'confidence': pred.score.value,
            'class_id': pred.category.id,
            'class_name': pred.category.name
        }
        for pred in result.object_prediction_list
    ]
def predict_adaptive(self, image: np.ndarray, min_object_size: int = 32) -> List[Dict]:
"""
自适应切片推理:根据图像中目标尺寸动态调整切片参数
策略:若图像分辨率过高或目标过小,启用切片;否则使用标准推理
"""
h, w = image.shape[:2]
# 启发式策略:若图像短边大于1280且预期小目标较多,启用SAHI
if min(h, w) > 1280:
# 根据图像尺寸自适应调整切片大小
self.config.slice_width = min(640, w // 4)
self.config.slice_height = min(640, h // 4)
return self.predict_sliced(image)
else:
return self.predict_standard(image)
def benchmark(self, image: np.ndarray, num_runs: int = 50) -> Dict:
"""对比基准测试:标准推理 vs SAHI切片推理"""
# 预热
for _ in range(5):
self.predict_standard(image)
self.predict_sliced(image)
# 标准推理测试
std_times = []
for _ in range(num_runs):
start = time.perf_counter()
std_dets = self.predict_standard(image)
std_times.append((time.perf_counter() - start) * 1000)
# SAHI推理测试
sahi_times = []
for _ in range(num_runs):
start = time.perf_counter()
sahi_dets = self.predict_sliced(image)
sahi_times.append((time.perf_counter() - start) * 1000)
return {
'standard': {
'mean_ms': np.mean(std_times),
'std_ms': np.std(std_times),
'num_detections': len(std_dets)
},
'sahi': {
'mean_ms': np.mean(sahi_times),
'std_ms': np.std(sahi_times),
'num_detections': len(sahi_dets)
},
'speed_overhead': (np.mean(sahi_times) / np.mean(std_times) - 1) * 100,
'detection_gain': (len(sahi_dets) - len(std_dets)) / max(len(std_dets), 1) * 100
}
def visualize_comparison(self, image: np.ndarray, save_path: str = None) -> np.ndarray:
"""并排可视化对比:标准推理 vs SAHI切片推理"""
std_dets = self.predict_standard(image)
sahi_dets = self.predict_sliced(image)
# 创建并排显示画布
h, w = image.shape[:2]
canvas = np.zeros((h, w * 2, 3), dtype=np.uint8)
canvas[:, :w] = image.copy()
canvas[:, w:] = image.copy()
# 绘制标准推理结果(左侧)
for det in std_dets:
x1, y1, x2, y2 = map(int, det['bbox'])
cv2.rectangle(canvas, (x1, y1), (x2, y2), (0, 0, 255), 2)
cv2.putText(canvas, f"{det['class_name']}", (x1, y1-5),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
cv2.putText(canvas, f"Standard: {len(std_dets)} objs", (10, 30),
cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
# 绘制SAHI推理结果(右侧)
for det in sahi_dets:
x1, y1, x2, y2 = map(int, det['bbox'])
x1, x2 = x1 + w, x2 + w # 偏移到右侧
cv2.rectangle(canvas, (x1, y1), (x2, y2), (0, 255, 0), 2)
cv2.putText(canvas, f"{det['class_name']}", (x1, y1-5),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
cv2.putText(canvas, f"SAHI: {len(sahi_dets)} objs", (w + 10, 30),
cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
if save_path:
cv2.imwrite(save_path, canvas)
return canvas
def grid_search_sahi_params(image_path: str, weights_path: str,
                            ground_truth: List[Dict] = None) -> SliceConfig:
    """
    Grid-search SAHI slicing hyper-parameters on a single image.

    Search space: slice size x overlap ratio x post-process match threshold.
    Without ground truth, the number of detections is used as a heuristic score,
    so the result needs manual verification. ``ground_truth`` is accepted for
    API compatibility but is not used yet.

    Returns:
        the SliceConfig with the highest score (the first one on ties).

    Raises:
        FileNotFoundError: if ``image_path`` cannot be read.
    """
    import itertools

    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Cannot read image: {image_path}")

    slice_sizes = [512, 640, 768]
    overlaps = [0.15, 0.2, 0.25]
    thresholds = [0.4, 0.5, 0.6]

    # Load the weights once and only swap the slicing config per trial;
    # building a SahiDetector per combination would reload the model every time.
    detector = SahiDetector(weights_path)
    best_config = None
    best_score = -1  # below any real score, so a config is chosen even if all find 0 objects
    print("Starting SAHI hyperparameter grid search...")
    for slice_size, overlap, thresh in itertools.product(slice_sizes, overlaps, thresholds):
        config = SliceConfig(
            slice_height=slice_size,
            slice_width=slice_size,
            overlap_height_ratio=overlap,
            overlap_width_ratio=overlap,
            postprocess_match_threshold=thresh
        )
        detector.config = config
        dets = detector.predict_sliced(image)
        score = len(dets)
        print(f"Size:{slice_size}, Overlap:{overlap}, Thresh:{thresh} -> {len(dets)} detections")
        if score > best_score:
            best_score = score
            best_config = config
    print(f"\nBest config: slice_size={best_config.slice_width}, "
          f"overlap={best_config.overlap_height_ratio}, threshold={best_config.postprocess_match_threshold}")
    return best_config
def main():
    """
    CLI entry point for SAHI sliced inference.

    Modes, checked in this order: --grid-search, --benchmark, --compare, and
    otherwise per-image inference over a single file or a directory.
    """
    parser = argparse.ArgumentParser(description='SAHI Sliced Inference for Small Object Detection')
    parser.add_argument('--weights', type=str, required=True, help='YOLO weights path')
    parser.add_argument('--source', type=str, required=True, help='Image or directory')
    parser.add_argument('--sahi', action='store_true', help='Enable SAHI sliced inference')
    parser.add_argument('--slice-size', type=int, default=640, help='Slice size')
    parser.add_argument('--overlap', type=float, default=0.2, help='Overlap ratio')
    parser.add_argument('--conf', type=float, default=0.25, help='Confidence threshold')
    parser.add_argument('--device', default='cuda', help='Device')
    parser.add_argument('--benchmark', action='store_true', help='Run benchmark')
    parser.add_argument('--compare', action='store_true', help='Visual comparison')
    parser.add_argument('--grid-search', action='store_true', help='Hyperparameter search')
    args = parser.parse_args()

    # Grid search builds its own detector; handle it before loading the model
    # here so the weights are not loaded twice.
    if args.grid_search:
        grid_search_sahi_params(args.source, args.weights)
        return

    # Single-image modes: fail fast on a bad path before the model is loaded.
    image = None
    if args.benchmark or args.compare:
        image = cv2.imread(args.source)
        if image is None:
            parser.error(f"cannot read image: {args.source}")

    config = SliceConfig(
        slice_height=args.slice_size,
        slice_width=args.slice_size,
        overlap_height_ratio=args.overlap,
        overlap_width_ratio=args.overlap
    )
    detector = SahiDetector(args.weights, config, args.conf, args.device)

    if args.benchmark:
        results = detector.benchmark(image)
        print("\nBenchmark Results:")
        print(f"Standard: {results['standard']['mean_ms']:.2f}ms, "
              f"{results['standard']['num_detections']} detections")
        print(f"SAHI: {results['sahi']['mean_ms']:.2f}ms, "
              f"{results['sahi']['num_detections']} detections")
        print(f"Speed overhead: {results['speed_overhead']:.1f}%")
        print(f"Detection gain: {results['detection_gain']:.1f}%")
        return

    if args.compare:
        vis = detector.visualize_comparison(image, 'sahi_comparison.jpg')
        cv2.imshow('Comparison (Red:Standard, Green:SAHI)', vis)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
        return

    # Single image or every image in a directory
    source_path = Path(args.source)
    if source_path.is_dir():
        image_paths = sorted(
            path for pattern in ('*.jpg', '*.jpeg', '*.png') for path in source_path.glob(pattern)
        )
    else:
        image_paths = [source_path]

    mode = "SAHI" if args.sahi else "Standard"
    predict = detector.predict_sliced if args.sahi else detector.predict_standard
    color = (0, 255, 0) if args.sahi else (0, 0, 255)
    for img_path in image_paths:
        image = cv2.imread(str(img_path))
        if image is None:
            print(f"Skipping unreadable image: {img_path}")
            continue
        detections = predict(image)
        vis_img = image.copy()
        for det in detections:
            x1, y1, x2, y2 = map(int, det['bbox'])
            cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, 2)
            cv2.putText(vis_img, f"{det['class_name']}:{det['confidence']:.2f}",
                        (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
        output_name = f"{img_path.stem}_{mode.lower()}_result.jpg"
        cv2.imwrite(output_name, vis_img)
        print(f"[{mode}] {img_path.name}: {len(detections)} objects -> {output_name}")


if __name__ == '__main__':
    main()
5.2 视觉大模型(VLM)集成
5.2.1 SAM(Segment Anything):Prompt工程与实时优化
SAM(Segment Anything Model)通过提示工程(prompt engineering)实现零样本分割,支持点、框、掩码等多种提示形式。其核心架构包含图像编码器(ViT)、提示编码器与轻量级掩码解码器。自动掩码生成器(AMG)通过点网格采样与超参数优化,在农业图像分析中可将F2-score从0.05提升至0.74。实时优化策略包括编码器量化、ONNX/TensorRT转换、批量掩码解码与嵌入缓存。
Python
#!/usr/bin/env python3
"""
Script: sam_optimized.py
Content: SAM实时优化部署与Prompt工程框架
Usage:
1. 安装依赖: pip install segment-anything opencv-python torch
2. 单点提示: python sam_optimized.py --image sample.jpg --point 500,375
3. 框提示: python sam_optimized.py --image sample.jpg --box 400,300,600,450
4. 自动掩码生成: python sam_optimized.py --image sample.jpg --auto --points-per-side 32
5. 视频流分割: python sam_optimized.py --video input.mp4 --point 960,540
"""
import os
import argparse
import time
import numpy as np
import cv2
import torch
import torch.nn.functional as F
from pathlib import Path
from typing import List, Tuple, Optional, Union
from dataclasses import dataclass
try:
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
except ImportError:
raise ImportError("Install segment-anything: pip install git+https://github.com/facebookresearch/segment-anything.git")
@dataclass
class PromptConfig:
    """Bundle of SAM prompts; any field left as None is not passed to the model."""
    # (N, 2) array of (x, y) point coordinates
    points: Optional[np.ndarray] = None
    # (N,) labels matching `points`: 1 = foreground, 0 = background
    point_labels: Optional[np.ndarray] = None
    # (4,) box in x1, y1, x2, y2 order
    box: Optional[np.ndarray] = None
    # mask (logits) from a previous prediction, used for iterative refinement
    mask_input: Optional[np.ndarray] = None
class OptimizedSAMPredictor:
    """
    SAM inference wrapper: point / box / mask prompts, interactive use, video,
    and automatic mask generation.

    Speed-ups:
      * the image embedding computed by set_image() is reused by every prompt
        until the next set_image() call;
      * on CUDA, encoding and decoding run under FP16 autocast. The weights are
        deliberately not converted with ``.half()``: SamPredictor builds its
        prompt tensors as float32, which a half-precision prompt encoder
        rejects with a dtype mismatch.
    """

    def __init__(self, model_type='vit_h', checkpoint=None, device='cuda'):
        """
        Args:
            model_type: key of ``sam_model_registry`` ('vit_h', 'vit_l' or 'vit_b').
            checkpoint: checkpoint path; downloaded to ~/.cache/sam when None.
            device: torch device string ('cuda', 'cuda:1', 'cpu', ...).
        """
        self.device = device
        self.model_type = model_type
        if checkpoint is None:
            checkpoint = self._download_checkpoint(model_type)
        sam = sam_model_registry[model_type](checkpoint=checkpoint)
        sam.to(device)
        # Mixed precision is applied per call via autocast (see _autocast)
        self.use_amp = str(device).startswith('cuda') and torch.cuda.is_available()
        self.predictor = SamPredictor(sam)
        self.image_encoded = False
        self.current_image = None
        self.current_image_format = 'BGR'
        # Automatic mask generator, cached together with the parameters it was built with
        self.mask_generator = None
        self._mask_generator_params = None

    def _autocast(self):
        """Return an FP16 autocast context on CUDA, or a no-op context otherwise."""
        import contextlib

        if self.use_amp:
            return torch.autocast(device_type='cuda', dtype=torch.float16)
        return contextlib.nullcontext()

    def _download_checkpoint(self, model_type):
        """
        Download the official checkpoint for ``model_type`` into ~/.cache/sam (once).

        Raises:
            KeyError: for a model type without a known download URL.
        """
        cache_dir = Path.home() / '.cache' / 'sam'
        cache_dir.mkdir(parents=True, exist_ok=True)
        urls = {
            'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
            'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
            'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth'
        }
        if model_type not in urls:
            raise KeyError(f"Unknown SAM model type {model_type!r}; expected one of {sorted(urls)}")
        checkpoint_path = cache_dir / f'sam_{model_type}.pth'
        if not checkpoint_path.exists():
            print(f"Downloading SAM {model_type} checkpoint...")
            torch.hub.download_url_to_file(urls[model_type], str(checkpoint_path))
        return str(checkpoint_path)

    def set_image(self, image: np.ndarray, precompute_embeds: bool = True,
                  image_format: str = 'BGR'):
        """
        Encode ``image`` once; subsequent predict() calls reuse the embedding.

        Args:
            image: HxWx3 uint8 image. The default format is BGR because every
                caller in this script reads images with OpenCV.
            precompute_embeds: kept for API compatibility; the embedding is always computed.
            image_format: 'BGR' or 'RGB', forwarded to SamPredictor.set_image so
                the channels are reordered to what the model expects.
        """
        with self._autocast():
            self.predictor.set_image(image, image_format=image_format)
        self.image_encoded = True
        self.current_image = image
        self.current_image_format = image_format

    def predict(self, config: PromptConfig, multimask_output: bool = True) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Segment the current image with any combination of point / box / mask prompts.

        Returns:
            (masks [N, H, W] bool, scores [N], low-resolution logits [N, h, w]).

        Raises:
            RuntimeError: if set_image() has not been called.
        """
        if not self.image_encoded:
            raise RuntimeError("Image not set. Call set_image() first.")
        with self._autocast():
            masks, scores, logits = self.predictor.predict(
                point_coords=config.points,
                point_labels=config.point_labels,
                box=config.box,
                mask_input=config.mask_input,
                multimask_output=multimask_output
            )
        return masks, scores, logits

    def predict_batch(self, configs: List[PromptConfig]) -> List[Tuple]:
        """Run several prompts against the same cached image embedding (single mask each)."""
        return [self.predict(config, multimask_output=False) for config in configs]

    def generate_masks_auto(self, points_per_side: int = 32,
                            pred_iou_thresh: float = 0.9,
                            stability_score_thresh: float = 0.95,
                            min_mask_region_area: int = 100) -> List[dict]:
        """
        Automatic mask generation (AMG) over the image passed to set_image().

        The generator is rebuilt whenever the hyper-parameters change, so each
        call really uses the values it was given.

        Raises:
            RuntimeError: if set_image() has not been called.
        """
        if self.current_image is None:
            raise RuntimeError("Image not set. Call set_image() first.")
        params = (points_per_side, pred_iou_thresh, stability_score_thresh, min_mask_region_area)
        if self.mask_generator is None or params != self._mask_generator_params:
            self.mask_generator = SamAutomaticMaskGenerator(
                model=self.predictor.model,
                points_per_side=points_per_side,
                pred_iou_thresh=pred_iou_thresh,
                stability_score_thresh=stability_score_thresh,
                crop_n_layers=1,
                crop_n_points_downscale_factor=2,
                min_mask_region_area=min_mask_region_area,
            )
            self._mask_generator_params = params
        # SamAutomaticMaskGenerator.generate expects an RGB image
        image = self.current_image
        if self.current_image_format == 'BGR':
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return self.mask_generator.generate(image)

    def interactive_segmentation(self, image: np.ndarray):
        """
        Interactive segmentation in an OpenCV window.

        Left click adds a foreground point, right click adds a background point,
        'c' clears all points and ESC quits. The mask is recomputed on every click.
        """
        self.set_image(image)
        self.prompt_points = []
        self.prompt_labels = []

        def mouse_callback(event, x, y, flags, param):
            if event == cv2.EVENT_LBUTTONDOWN:
                self.prompt_points.append([x, y])
                self.prompt_labels.append(1)  # foreground
                self._update_visualization()
            elif event == cv2.EVENT_RBUTTONDOWN:
                self.prompt_points.append([x, y])
                self.prompt_labels.append(0)  # background
                self._update_visualization()

        cv2.namedWindow('Interactive SAM')
        cv2.setMouseCallback('Interactive SAM', mouse_callback)
        self.vis_image = image.copy()
        print("左键点击:添加前景点 | 右键点击:添加背景点 | 按ESC退出")
        while True:
            cv2.imshow('Interactive SAM', self.vis_image)
            key = cv2.waitKey(1) & 0xFF
            if key == 27:  # ESC
                break
            if key == ord('c'):  # clear all prompts
                self.prompt_points = []
                self.prompt_labels = []
                self.vis_image = image.copy()
        cv2.destroyAllWindows()

    def _update_visualization(self):
        """Re-run prediction from the clicked points and redraw mask and points."""
        if not self.prompt_points:
            return
        config = PromptConfig(
            points=np.array(self.prompt_points),
            point_labels=np.array(self.prompt_labels)
        )
        masks, _, _ = self.predict(config, multimask_output=False)
        mask = masks[0]
        self.vis_image = self.current_image.copy()
        colored_mask = np.zeros_like(self.vis_image)
        colored_mask[mask] = [0, 255, 0]  # green overlay
        self.vis_image = cv2.addWeighted(self.vis_image, 0.6, colored_mask, 0.4, 0)
        for pt, label in zip(self.prompt_points, self.prompt_labels):
            color = (0, 255, 0) if label == 1 else (0, 0, 255)
            cv2.circle(self.vis_image, (int(pt[0]), int(pt[1])), 5, color, -1)

    def video_segmentation(self, video_path: str, prompt_config: PromptConfig,
                           output_path: str = None, every_n_frames: int = 1):
        """
        Segment a video with a fixed prompt.

        The image embedding is recomputed every ``every_n_frames`` frames; in
        between, the last embedding is reused. On re-encoded frames the previous
        logits are fed back as ``mask_input`` for temporal consistency. The
        caller's ``prompt_config`` is not modified. Press 'q' to stop early.

        Raises:
            IOError: if the video cannot be opened.
        """
        import copy

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise IOError(f"Cannot open video: {video_path}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps or fps <= 0:
            fps = 30.0  # some sources report 0 FPS; VideoWriter needs a positive rate
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        out = None
        if output_path:
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        every_n_frames = max(1, int(every_n_frames))
        prompt = copy.copy(prompt_config)  # mask propagation must not leak into the caller's config
        frame_count = 0
        prev_logits = None
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_count % every_n_frames == 0:
                    # Re-encode and feed the previous logits back as a mask prompt
                    self.set_image(frame)
                    if prev_logits is not None:
                        prompt.mask_input = prev_logits
                    masks, _, logits = self.predict(prompt, multimask_output=False)
                    prev_logits = logits
                else:
                    # Intermediate frame: reuse the cached embedding
                    masks, _, _ = self.predict(prompt, multimask_output=False)
                current_mask = masks[0]

                vis_frame = frame.copy()
                colored_mask = np.zeros_like(vis_frame)
                colored_mask[current_mask] = [0, 255, 0]
                vis_frame = cv2.addWeighted(vis_frame, 0.7, colored_mask, 0.3, 0)
                if prompt.points is not None:
                    for pt, label in zip(prompt.points, prompt.point_labels):
                        color = (0, 255, 0) if label == 1 else (0, 0, 255)
                        cv2.circle(vis_frame, (int(pt[0]), int(pt[1])), 8, color, 2)

                if out is not None:
                    out.write(vis_frame)
                cv2.imshow('SAM Video', vis_frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
                frame_count += 1
        finally:
            cap.release()
            if out is not None:
                out.release()
            cv2.destroyAllWindows()

    def export_to_onnx(self, output_path: str = 'sam.onnx'):
        """
        Placeholder for ONNX export; it only prints guidance.

        A real export would typically cover just the lightweight mask decoder and
        would need to handle dynamic shapes and prompt encoding.
        """
        print("ONNX export requires custom implementation for full SAM")
        print("Consider using ONNX Runtime for decoder only")
class SAMPromptEngine:
    """
    Helpers that generate SAM prompts automatically.

    Strategies: a uniform point grid, a simple high-frequency saliency heuristic,
    and boxes derived from existing masks.
    """

    def __init__(self):
        pass  # stateless; kept so SAMPromptEngine() keeps working

    def generate_grid_points(self, image_shape: Tuple[int, int],
                             points_per_side: int = 32) -> np.ndarray:
        """
        Return a (points_per_side**2, 2) int array of (x, y) points on a uniform grid.

        Points sit at grid-cell centres, like SAM's own point grid, so every point
        lies inside the image (0 <= x < w, 0 <= y < h).
        """
        h, w = image_shape[:2]
        n = int(points_per_side)
        if n <= 0:
            return np.empty((0, 2), dtype=int)
        xs = (np.arange(n) + 0.5) * (w / n)
        ys = (np.arange(n) + 0.5) * (h / n)
        xv, yv = np.meshgrid(xs, ys)
        points = np.stack([xv.ravel(), yv.ravel()], axis=1)
        return points.astype(int)

    def generate_saliency_points(self, image: np.ndarray,
                                 num_points: int = 16) -> np.ndarray:
        """
        Generate prompt points from a cheap saliency heuristic.

        The saliency map is |gray - GaussianBlur(gray)| (high-frequency detail). It
        is Otsu-thresholded, and the centroids of the ``num_points`` largest regions
        are returned. Falls back to a 4x4 grid when no region is found. Accepts
        BGR or single-channel 8-bit images.
        """
        gray = image if image.ndim == 2 else cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)
        saliency = cv2.absdiff(gray, blurred)
        _, thresh = cv2.threshold(saliency, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        points = []
        for cnt in sorted(contours, key=cv2.contourArea, reverse=True)[:num_points]:
            moments = cv2.moments(cnt)
            if moments["m00"] != 0:
                cx = int(moments["m10"] / moments["m00"])
                cy = int(moments["m01"] / moments["m00"])
                points.append([cx, cy])
        return np.array(points) if points else self.generate_grid_points(image.shape[:2], 4)

    def generate_box_from_mask(self, mask: np.ndarray, margin: float = 0.1) -> Optional[np.ndarray]:
        """
        Return the mask's bounding box [x1, y1, x2, y2], enlarged by ``margin``
        (a fraction of the box width/height) and clipped to the mask extent.

        Returns None for an empty mask.
        """
        ys, xs = np.nonzero(mask)
        if xs.size == 0:
            return None
        x1, x2 = xs.min(), xs.max()
        y1, y2 = ys.min(), ys.max()
        margin_x = int((x2 - x1) * margin)
        margin_y = int((y2 - y1) * margin)
        mask_h, mask_w = mask.shape[:2]
        return np.array([
            max(0, x1 - margin_x),
            max(0, y1 - margin_y),
            min(mask_w - 1, x2 + margin_x),
            min(mask_h - 1, y2 + margin_y)
        ])
def main():
    """
    CLI entry point for the SAM demo.

    Modes, resolved in this order: --interactive (with --image), --auto
    (with --image), --video, and single-image prompting (--image with --point
    and/or --box). The CLI is validated before the checkpoint is loaded.
    """
    parser = argparse.ArgumentParser(description='Optimized SAM Deployment')
    parser.add_argument('--image', type=str, help='Input image path')
    parser.add_argument('--video', type=str, help='Input video path')
    parser.add_argument('--checkpoint', type=str, help='SAM checkpoint path')
    parser.add_argument('--model-type', default='vit_h', choices=['vit_h', 'vit_l', 'vit_b'])
    parser.add_argument('--point', type=str, help='Point prompt (x,y)')
    parser.add_argument('--box', type=str, help='Box prompt (x1,y1,x2,y2)')
    parser.add_argument('--auto', action='store_true', help='Auto mask generation')
    parser.add_argument('--points-per-side', type=int, default=32)
    parser.add_argument('--interactive', action='store_true', help='Interactive mode')
    parser.add_argument('--output', type=str, help='Output path')
    parser.add_argument('--device', default='cuda', help='Device')
    args = parser.parse_args()

    def parse_ints(text, count, flag):
        """Parse 'a,b,...' into exactly ``count`` ints, or exit with a usage error."""
        try:
            values = [int(v) for v in text.split(',')]
        except ValueError:
            values = []
        if len(values) != count:
            parser.error(f"{flag} expects {count} comma-separated integers, got {text!r}")
        return values

    if args.interactive and args.image:
        mode = 'interactive'
    elif args.auto and args.image:
        mode = 'auto'
    elif args.video:
        mode = 'video'
    elif args.image:
        mode = 'image'
    else:
        parser.error("--image or --video is required")

    config = PromptConfig()
    if args.point:
        x, y = parse_ints(args.point, 2, '--point')
        config.points = np.array([[x, y]])
        config.point_labels = np.array([1])
    if args.box:
        config.box = np.array(parse_ints(args.box, 4, '--box'))
    if mode == 'image' and config.points is None and config.box is None:
        parser.error("--point and/or --box is required for single-image segmentation")

    image = None
    if mode != 'video':
        image = cv2.imread(args.image)
        if image is None:
            parser.error(f"cannot read image: {args.image}")

    # Load (and possibly download) the checkpoint only after the CLI is validated
    sam = OptimizedSAMPredictor(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        device=args.device
    )

    if mode == 'interactive':
        sam.interactive_segmentation(image)
    elif mode == 'auto':
        sam.set_image(image)
        masks = sam.generate_masks_auto(points_per_side=args.points_per_side)
        vis_image = image.copy()
        for mask_data in masks:
            mask = mask_data['segmentation']
            color = np.random.randint(0, 255, 3).tolist()
            colored_mask = np.zeros_like(vis_image)
            colored_mask[mask] = color
            vis_image = cv2.addWeighted(vis_image, 1.0, colored_mask, 0.5, 0)
        cv2.imwrite(args.output or 'sam_auto_masks.jpg', vis_image)
        print(f"Generated {len(masks)} masks")
    elif mode == 'video':
        sam.video_segmentation(args.video, config, args.output)
    else:
        sam.set_image(image)
        masks, scores, _ = sam.predict(config)
        best_mask = masks[scores.argmax()]
        vis_image = image.copy()
        colored_mask = np.zeros_like(vis_image)
        colored_mask[best_mask] = [0, 255, 0]
        vis_image = cv2.addWeighted(vis_image, 0.6, colored_mask, 0.4, 0)
        if config.points is not None:
            for pt, label in zip(config.points, config.point_labels):
                color = (0, 255, 0) if label == 1 else (0, 0, 255)
                cv2.circle(vis_image, (int(pt[0]), int(pt[1])), 8, color, -1)
        cv2.imwrite(args.output or 'sam_result.jpg', vis_image)
        print(f"Segmentation saved, score: {scores.max():.3f}")


if __name__ == '__main__':
    main()
5.2.2 CLIP/ALIGN:零样本分类与图文检索
CLIP通过对比学习在4亿图文对上训练,将图像与文本映射到共享嵌入空间,实现零样本分类与跨模态检索。ALIGN同样采用双编码器架构与对比损失(归一化softmax形式),但训练数据规模更大(约18亿对)、噪声也更多。部署时需关注文本编码缓存、批量图像编码,以及借助近似最近邻(ANN)索引加速检索。
Python
#!/usr/bin/env python3
"""
Script: clip_zero_shot.py
Content: CLIP零样本分类与图文检索系统部署
Usage:
1. 安装依赖: pip install transformers torch opencv-python faiss-cpu
2. 零样本分类: python clip_zero_shot.py --image cat.jpg --labels "cat,dog,bird"
3. 图像检索: python clip_zero_shot.py --index ./images/ --query "a red car"
4. 视频检索: python clip_zero_shot.py --video video.mp4 --query "person walking" --top-k 5
"""
import os
import argparse
import numpy as np
import cv2
import torch
import torch.nn.functional as F
from pathlib import Path
from typing import List, Tuple, Dict
from PIL import Image
import pickle
try:
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
except ImportError:
raise ImportError("Install transformers: pip install transformers")
try:
import faiss
FAISS_AVAILABLE = True
except ImportError:
FAISS_AVAILABLE = False
print("FAISS not available, using numpy for similarity search")
class CLIPZeroShotClassifier:
    """
    CLIP zero-shot classifier with dynamic label sets and batched inference.

    Label embeddings are cached per label set, the model runs with FP16 weights
    on CUDA devices, and images are accepted as OpenCV BGR arrays.
    """

    def __init__(self, model_name='openai/clip-vit-base-patch32', device='cuda'):
        """
        Args:
            model_name: Hugging Face id or local path of a CLIP checkpoint.
            device: torch device string; any 'cuda*' device enables FP16 weights.
        """
        self.device = device
        self.model = CLIPModel.from_pretrained(model_name).to(device)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
        if str(device).startswith('cuda'):
            self.model = self.model.half()
        self.model.eval()
        # cache_key -> L2-normalised text embeddings; one entry per distinct key
        self.text_embeddings = {}

    @staticmethod
    def _prompts_for(labels: List[str]) -> List[str]:
        """Wrap each label in the 'a photo of a {label}' prompt template."""
        return [f"a photo of a {label}" for label in labels]

    def encode_text(self, texts: List[str], cache_key: str = None) -> torch.Tensor:
        """
        Encode texts into L2-normalised embeddings, optionally cached under ``cache_key``.

        Raises:
            ValueError: if ``texts`` is empty and nothing is cached under ``cache_key``.
        """
        if cache_key and cache_key in self.text_embeddings:
            return self.text_embeddings[cache_key]
        if not texts:
            raise ValueError("encode_text() needs at least one text")
        inputs = self.tokenizer(
            texts,
            padding=True,
            return_tensors="pt",
            truncation=True
        ).to(self.device)
        with torch.no_grad():
            text_features = self.model.get_text_features(**inputs)
        text_features = F.normalize(text_features, dim=-1)
        if cache_key:
            self.text_embeddings[cache_key] = text_features
        return text_features

    def encode_image(self, images: List[np.ndarray]) -> torch.Tensor:
        """
        Encode a batch of OpenCV BGR images into L2-normalised embeddings.

        Raises:
            ValueError: if ``images`` is empty.
        """
        if not images:
            raise ValueError("encode_image() needs at least one image")
        pil_images = [Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) for img in images]
        # `padding` is a tokenizer option, so it is not passed for image-only input
        inputs = self.processor(images=pil_images, return_tensors="pt").to(self.device)
        with torch.no_grad():
            image_features = self.model.get_image_features(**inputs)
        return F.normalize(image_features, dim=-1)

    def _label_probabilities(self, image_features: torch.Tensor, labels: List[str]) -> np.ndarray:
        """Softmax over labels for each image embedding, as float32 [num_images, num_labels]."""
        prompts = self._prompts_for(labels)
        # Cache per label set, so repeated calls with the same labels skip the text encoder
        text_features = self.encode_text(prompts, cache_key='labels:' + '\x1f'.join(prompts))
        # Compute the logits/softmax in float32 even when the model runs in FP16
        logits = 100.0 * image_features.float() @ text_features.float().T
        return logits.softmax(dim=-1).cpu().numpy()

    def classify(self, image: np.ndarray, labels: List[str], top_k: int = 5) -> List[Tuple[str, float]]:
        """
        Zero-shot classification of one image.

        Returns:
            up to ``top_k`` (label, probability) pairs, most probable first.
        """
        probs = self._label_probabilities(self.encode_image([image]), labels)[0]
        order = np.argsort(probs)[::-1][:top_k]
        return [(labels[i], float(probs[i])) for i in order]

    def classify_batch(self, images: List[np.ndarray], labels: List[str]) -> np.ndarray:
        """Return the probability matrix [num_images, num_labels]; text is encoded once."""
        return self._label_probabilities(self.encode_image(images), labels)
class CLIPImageRetrieval:
    """
    CLIP image retrieval: index an image collection and search it by text or image.

    Uses an exact inner-product FAISS index when faiss is importable (cosine
    similarity, since embeddings are L2-normalised); otherwise a NumPy matrix.
    """

    IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png')

    def __init__(self, model_name='openai/clip-vit-base-patch32', device='cuda',
                 index_path: str = None, dim: int = None):
        """
        Args:
            model_name: CLIP checkpoint passed to CLIPZeroShotClassifier.
            device: torch device string.
            index_path: prefix of an index written by save_index(); it is loaded
                if ``<prefix>.meta`` or ``<prefix>.faiss`` exists.
            dim: embedding size; inferred from the model's ``projection_dim``
                (512 fallback) when None.
        """
        self.device = device
        self.classifier = CLIPZeroShotClassifier(model_name, device)
        if dim is None:
            dim = getattr(self.classifier.model.config, 'projection_dim', 512)
        self.dim = dim
        # Exact inner-product index; switch to IVF/HNSW for very large collections
        self.index = faiss.IndexFlatIP(dim) if FAISS_AVAILABLE else None
        self.vectors = np.zeros((0, dim), dtype='float32')  # NumPy fallback storage
        self.image_paths = []
        self.index_path = index_path
        # save_index() writes sidecar files, so check for those rather than the bare prefix
        if index_path and (os.path.exists(index_path + '.meta') or
                           os.path.exists(index_path + '.faiss')):
            self.load_index(index_path)

    def build_index(self, image_dir: str, batch_size: int = 32):
        """
        Recursively index .jpg/.jpeg/.png images (any case) under ``image_dir``.

        Images are encoded in batches; unreadable files are skipped. Paths are
        recorded only after their batch has been encoded, so paths and vectors
        stay aligned.
        """
        root = Path(image_dir)
        image_paths = sorted(
            p for p in root.rglob('*')
            if p.is_file() and p.suffix.lower() in self.IMAGE_SUFFIXES
        )
        print(f"Indexing {len(image_paths)} images...")
        new_vectors = []
        for i in range(0, len(image_paths), batch_size):
            batch_images = []
            batch_names = []
            for path in image_paths[i:i + batch_size]:
                img = cv2.imread(str(path))
                if img is not None:
                    batch_images.append(img)
                    batch_names.append(str(path))
            if not batch_images:
                continue
            features = self.classifier.encode_image(batch_images)
            features = features.float().cpu().numpy().astype('float32')
            self.image_paths.extend(batch_names)
            if FAISS_AVAILABLE:
                self.index.add(features)
            else:
                new_vectors.append(features)
            if (i // batch_size) % 10 == 0:
                print(f"Processed {min(i+batch_size, len(image_paths))}/{len(image_paths)}")
        if new_vectors:
            # Concatenate once at the end instead of re-copying the matrix per batch
            existing = np.asarray(self.vectors, dtype='float32').reshape(-1, self.dim)
            self.vectors = np.vstack([existing] + new_vectors)
        print("Indexing complete")

    @staticmethod
    def _numpy_top_k(vectors, query, top_k):
        """
        Brute-force inner-product search.

        Returns:
            (scores, indices) of the ``top_k`` most similar rows of ``vectors``
            in descending order; both empty for an empty index or top_k <= 0.
        """
        vectors = np.asarray(vectors, dtype=np.float32)
        if vectors.size == 0 or top_k <= 0:
            return np.zeros(0, dtype=np.float32), np.zeros(0, dtype=np.int64)
        query = np.asarray(query, dtype=np.float32).reshape(-1)
        # reshape keeps a 1-D similarity vector even when there is a single row
        similarities = vectors.reshape(len(vectors), -1) @ query
        order = np.argsort(-similarities, kind='stable')[:top_k]
        return similarities[order], order

    def _search(self, query: np.ndarray, top_k: int) -> List[Dict]:
        """Search with one embedding and map hits to {'path', 'score', 'index'} dicts."""
        query = np.ascontiguousarray(query, dtype='float32').reshape(1, -1)
        if FAISS_AVAILABLE:
            if self.index.ntotal == 0:
                return []
            scores, indices = self.index.search(query, top_k)
            scores, indices = scores[0], indices[0]
        else:
            scores, indices = self._numpy_top_k(self.vectors, query, top_k)
        # FAISS pads with index -1 when it has fewer than top_k vectors; drop those
        return [
            {'path': self.image_paths[idx], 'score': float(score), 'index': int(idx)}
            for score, idx in zip(scores, indices)
            if 0 <= idx < len(self.image_paths)
        ]

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """Text-to-image search; returns matching paths with similarity scores, best first."""
        text_features = self.classifier.encode_text([query])
        return self._search(text_features.float().cpu().numpy(), top_k)

    def search_by_image(self, query_image: np.ndarray, top_k: int = 5) -> List[Dict]:
        """Image-to-image search with an OpenCV BGR query image."""
        image_features = self.classifier.encode_image([query_image])
        return self._search(image_features.float().cpu().numpy(), top_k)

    def save_index(self, path: str):
        """Write ``<path>.faiss`` (when FAISS is used) and ``<path>.meta`` (pickled metadata)."""
        data = {
            'image_paths': self.image_paths,
            'vectors': None if FAISS_AVAILABLE else self.vectors
        }
        if FAISS_AVAILABLE:
            faiss.write_index(self.index, path + '.faiss')
        with open(path + '.meta', 'wb') as f:
            pickle.dump(data, f)
        print(f"Index saved to {path}")

    def load_index(self, path: str):
        """
        Load an index written by save_index().

        NOTE(security): the metadata is read with pickle, which can execute
        arbitrary code; only load index files from trusted sources.
        """
        if FAISS_AVAILABLE and os.path.exists(path + '.faiss'):
            self.index = faiss.read_index(path + '.faiss')
        if not os.path.exists(path + '.meta'):
            return
        with open(path + '.meta', 'rb') as f:
            data = pickle.load(f)
        self.image_paths = list(data['image_paths'])
        vectors = data.get('vectors')
        if vectors is not None:
            vectors = np.asarray(vectors, dtype='float32').reshape(-1, self.dim)
        if FAISS_AVAILABLE:
            # Index saved without FAISS: rebuild the FAISS index from the stored vectors
            if vectors is not None and self.index.ntotal == 0 and len(vectors):
                self.index.add(vectors)
        elif vectors is None:
            print("Warning: index was saved with FAISS; install faiss to search it")
            self.vectors = np.zeros((0, self.dim), dtype='float32')
        else:
            self.vectors = vectors
class VideoMomentRetrieval:
    """
    CLIP-based video moment retrieval: find timestamps that best match a text query.
    """

    def __init__(self, clip_classifier: "CLIPZeroShotClassifier", fps: int = 1):
        """
        Args:
            clip_classifier: encoder providing encode_image / encode_text.
            fps: number of frames sampled per second of video.
        """
        self.classifier = clip_classifier
        self.fps = fps

    def _encode_frames(self, frames: List[np.ndarray]) -> np.ndarray:
        """Encode a batch of BGR frames into a float32 feature matrix."""
        return self.classifier.encode_image(frames).float().cpu().numpy()

    def index_video(self, video_path: str, batch_size: int = 32) -> Tuple[np.ndarray, List[float]]:
        """
        Build a frame-level feature index.

        Frames are sampled every ``video_fps / self.fps`` frames (at least every
        frame) and encoded in batches of up to ``batch_size``.

        Returns:
            (features [num_samples, dim], timestamps in seconds). For a video with
            no readable frames: an empty (0, 0) array and an empty list.

        Raises:
            ValueError: if ``self.fps`` is not positive.
            IOError: if the video cannot be opened.
        """
        if self.fps <= 0:
            raise ValueError("fps must be positive")
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise IOError(f"Cannot open video: {video_path}")
        try:
            video_fps = cap.get(cv2.CAP_PROP_FPS)
            if not video_fps or video_fps <= 0:
                video_fps = 30.0  # assumed when the container does not report a frame rate
            # At least 1: sampling faster than the video rate would otherwise give `% 0`
            frame_interval = max(1, int(video_fps / self.fps))
            chunks = []
            timestamps = []
            batch = []
            frame_count = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_count % frame_interval == 0:
                    batch.append(frame)
                    timestamps.append(frame_count / video_fps)
                    if len(batch) >= batch_size:
                        chunks.append(self._encode_frames(batch))
                        batch = []
                frame_count += 1
            if batch:
                chunks.append(self._encode_frames(batch))
        finally:
            cap.release()
        if not chunks:
            return np.zeros((0, 0), dtype=np.float32), []
        return np.vstack(chunks), timestamps

    @staticmethod
    def _select_moments(similarities, timestamps, top_k: int, window_sec: float) -> List[Dict]:
        """
        Greedy temporal NMS over frame similarities.

        Up to ``top_k * 3`` candidate frames are visited in descending score order.
        A frame is kept only if no already kept moment lies within ``window_sec``
        seconds of it. Visiting by score, not by time, ensures the strongest moments
        win regardless of where they occur in the video.

        Returns:
            at most ``top_k`` dicts {'time', 'score', 'frame_idx'}, best first.
        """
        similarities = np.asarray(similarities, dtype=np.float64).reshape(-1)
        if similarities.size == 0 or top_k <= 0:
            return []
        candidates = np.argsort(-similarities, kind='stable')[:top_k * 3]
        moments = []
        for idx in candidates:
            time_sec = timestamps[idx]
            if any(abs(m['time'] - time_sec) < window_sec for m in moments):
                continue
            moments.append({
                'time': time_sec,
                'score': float(similarities[idx]),
                'frame_idx': int(idx)
            })
            if len(moments) >= top_k:
                break
        return moments

    def query_moments(self, video_path: str, query: str,
                      top_k: int = 5, window_sec: float = 2.0) -> List[Dict]:
        """
        Return the moments in ``video_path`` most similar to ``query``.

        ``window_sec``: minimum time separation between returned moments.
        ``frame_idx`` indexes the sampled frames, not the raw video frames.
        """
        features, timestamps = self.index_video(video_path)
        if not timestamps:
            return []
        text_feat = self.classifier.encode_text([query]).float().cpu().numpy()
        # reshape keeps a 1-D result even when only a single frame was sampled
        similarities = features @ text_feat.reshape(-1)
        return self._select_moments(similarities, timestamps, top_k, window_sec)
def main():
    """
    CLI entry point for CLIP zero-shot classification, image retrieval and
    video moment retrieval.
    """
    parser = argparse.ArgumentParser(description='CLIP Zero-Shot & Retrieval')
    parser.add_argument('--image', type=str, help='Input image')
    parser.add_argument('--labels', type=str, help='Comma-separated labels')
    parser.add_argument('--index', type=str, help='Image directory to index')
    parser.add_argument('--query', type=str, help='Text query')
    parser.add_argument('--video', type=str, help='Video path')
    parser.add_argument('--top-k', type=int, default=5)
    parser.add_argument('--model', default='openai/clip-vit-base-patch32')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--save-index', type=str, help='Save index path')
    parser.add_argument('--load-index', type=str, help='Load index path')
    args = parser.parse_args()

    if args.image and args.labels:
        # Zero-shot classification
        image = cv2.imread(args.image)
        if image is None:
            parser.error(f"cannot read image: {args.image}")
        labels = [name.strip() for name in args.labels.split(',') if name.strip()]
        if not labels:
            parser.error("--labels must contain at least one non-empty label")
        classifier = CLIPZeroShotClassifier(args.model, args.device)
        results = classifier.classify(image, labels, args.top_k)
        print("Classification Results:")
        for label, prob in results:
            print(f" {label}: {prob:.4f}")
    elif args.index:
        # Build (or load) the index, then optionally search it
        retrieval = CLIPImageRetrieval(args.model, args.device, args.load_index)
        if args.load_index:
            # Make sure the saved index is loaded even if the constructor did not pick it up
            if not retrieval.image_paths:
                retrieval.load_index(args.load_index)
        else:
            retrieval.build_index(args.index)
        if args.save_index:
            retrieval.save_index(args.save_index)
        if args.query:
            results = retrieval.search(args.query, args.top_k)
            print(f"\nQuery: '{args.query}'")
            for i, res in enumerate(results):
                print(f"{i+1}. {res['path']} (score: {res['score']:.4f})")
    elif args.video and args.query:
        # Video moment retrieval
        classifier = CLIPZeroShotClassifier(args.model, args.device)
        vmr = VideoMomentRetrieval(classifier)
        moments = vmr.query_moments(args.video, args.query, args.top_k)
        print(f"Query: '{args.query}'")
        print("Top moments:")
        for m in moments:
            print(f" {m['time']:.1f}s (score: {m['score']:.4f})")
    else:
        parser.print_help()


if __name__ == '__main__':
    main()
5.2.3 视觉Prompt工程:基于分割的开放词汇检测
开放词汇检测(Open-Vocabulary Detection, OVD)结合目标检测与视觉-语言模型,能够检测训练时未见过的类别。通过将检测框区域特征与文本嵌入对齐,OVD可动态扩展至新类别。Grounding DINO与GLIP等模型将短语定位(phrase grounding)与检测统一,支持基于任意文本描述的检测;下面的示例代码默认使用另一种开放词汇检测模型OWL-ViT(google/owlvit-base-patch32)。
Python
#!/usr/bin/env python3
"""
Script: open_vocabulary_detection.py
Content: 开放词汇检测与视觉Prompt工程
Usage:
1. 安装依赖: pip install transformers torch opencv-python groundingdino-py
2. 文本提示检测: python open_vocabulary_detection.py --image scene.jpg --text "red car . person with hat"
3. 交互式Prompt: python open_vocabulary_detection.py --image scene.jpg --interactive
"""
import argparse
import numpy as np
import cv2
import torch
from pathlib import Path
from typing import List, Dict, Tuple
try:
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
except ImportError:
raise ImportError("Install transformers: pip install transformers")
class OpenVocabularyDetector:
    """
    Open-vocabulary detector: finds objects described by arbitrary text queries.

    Written against OWL-ViT-style Hugging Face processors (the default
    checkpoint). These take one list of queries per image, and their
    post-processed ``labels`` are integer indices into that list. Other
    zero-shot detectors (e.g. Grounding DINO) may use a different
    post-processing signature; verify before switching checkpoints.
    """

    def __init__(self, model_name='google/owlvit-base-patch32', device='cuda'):
        self.device = device
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_name).to(device)
        self.model.eval()

    @staticmethod
    def _nms_per_label(boxes, scores, labels, iou_threshold: float) -> List[int]:
        """
        Greedy NMS applied independently per label (class-aware).

        Args:
            boxes: (N, 4) array of x1, y1, x2, y2.
            scores: (N,) array of scores.
            labels: N hashable labels; a box only suppresses boxes with an equal label.
            iou_threshold: boxes whose IoU with a kept box exceeds this are dropped.

        Returns:
            indices of the kept boxes, sorted by descending score.
        """
        boxes = np.asarray(boxes, dtype=np.float64).reshape(-1, 4)
        scores = np.asarray(scores, dtype=np.float64).reshape(-1)
        labels = list(labels)
        areas = (np.clip(boxes[:, 2] - boxes[:, 0], 0, None) *
                 np.clip(boxes[:, 3] - boxes[:, 1], 0, None))
        keep = []
        for label in dict.fromkeys(labels):
            order = np.array([i for i, lab in enumerate(labels) if lab == label], dtype=np.int64)
            order = order[np.argsort(-scores[order], kind='stable')]
            while order.size:
                best, rest = order[0], order[1:]
                keep.append(int(best))
                ix1 = np.maximum(boxes[best, 0], boxes[rest, 0])
                iy1 = np.maximum(boxes[best, 1], boxes[rest, 1])
                ix2 = np.minimum(boxes[best, 2], boxes[rest, 2])
                iy2 = np.minimum(boxes[best, 3], boxes[rest, 3])
                inter = np.clip(ix2 - ix1, 0, None) * np.clip(iy2 - iy1, 0, None)
                union = areas[best] + areas[rest] - inter
                iou = np.where(union > 0, inter / np.maximum(union, 1e-12), 0.0)
                order = rest[iou <= iou_threshold]
        keep.sort(key=lambda i: -scores[i])
        return keep

    def detect(self, image: np.ndarray, text_queries: List[str],
               threshold: float = 0.3, nms_threshold: float = 0.5) -> List[Dict]:
        """
        Run open-vocabulary detection on one image.

        Args:
            image: HxWx3 BGR image (OpenCV convention).
            text_queries: e.g. ["a red car", "a person wearing a hat"].
            threshold: minimum score kept by post-processing.
            nms_threshold: IoU above which a lower-scoring box with the same label is suppressed.

        Returns:
            list of {'bbox': [x1, y1, x2, y2], 'score': float, 'label': str}, best first.
        """
        queries = list(text_queries)
        if not queries:
            return []
        # Hugging Face image processors expect RGB input
        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # One image -> one list of queries (batch dimension of 1)
        inputs = self.processor(
            text=[queries],
            images=rgb,
            return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # (height, width) of the original image so boxes are returned in pixel coordinates
        target_sizes = torch.Tensor([image.shape[:2]]).to(self.device)
        results = self.processor.post_process_grounded_object_detection(
            outputs=outputs,
            target_sizes=target_sizes,
            threshold=threshold,
            text_labels=[queries]
        )[0]
        boxes = results["boxes"].cpu().numpy()
        scores = results["scores"].cpu().numpy()
        # "labels" are indices into this image's query list
        labels = [queries[int(i)] for i in results["labels"]]
        keep = self._nms_per_label(boxes, scores, labels, nms_threshold)
        return [
            {'bbox': boxes[i].tolist(), 'score': float(scores[i]), 'label': labels[i]}
            for i in keep
        ]

    def visualize(self, image: np.ndarray, detections: List[Dict]) -> np.ndarray:
        """Draw each detection with a random colour and a filled caption; returns a copy."""
        vis_img = image.copy()
        for det in detections:
            x1, y1, x2, y2 = map(int, det['bbox'])
            color = tuple(np.random.randint(0, 255, 3).tolist())
            cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, 2)
            label = f"{det['label']}: {det['score']:.2f}"
            (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
            # Keep the caption inside the image for boxes touching the top edge
            text_y = max(y1, th + 10)
            cv2.rectangle(vis_img, (x1, text_y - th - 10), (x1 + tw, text_y), color, -1)
            cv2.putText(vis_img, label, (x1, text_y - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)
        return vis_img
class VisualPromptEngine:
    """
    Prompt-engineering helper for open-vocabulary detection.

    Builds textual variants of a concept from fixed templates and picks the
    variant whose detections score highest on average.
    """

    def __init__(self):
        # Tried in this order; "{}" is replaced by the concept
        self.prompt_templates = [
            "a photo of a {}",
            "a {}",
            "a close-up of a {}",
            "a {} in the scene",
            "the {}",
        ]

    def generate_prompts(self, base_concept: str, num_variants: int = 5) -> List[str]:
        """Fill the first ``num_variants`` templates with ``base_concept``."""
        return [template.format(base_concept)
                for template in self.prompt_templates[:num_variants]]

    def optimize_prompt(self, detector, image: np.ndarray,
                        base_concept: str, iterations: int = 3) -> str:
        """
        Return the prompt variant with the highest mean detection score.

        Each variant is run through ``detector.detect`` with a 0.1 threshold.
        Variants with no detections are skipped, and ties keep the earlier
        variant. Falls back to ``base_concept`` itself if nothing beats a mean
        score of 0. ``iterations`` is accepted for API compatibility but is not used.
        """
        winner = base_concept
        winner_score = 0
        for candidate in self.generate_prompts(base_concept):
            hits = detector.detect(image, [candidate], threshold=0.1)
            if not hits:
                continue
            mean_score = np.mean([hit['score'] for hit in hits])
            if mean_score > winner_score:
                winner, winner_score = candidate, mean_score
        return winner
def main():
    """Command-line entry point for open-vocabulary detection.

    Without ``--interactive``, the comma-separated ``--text`` queries are
    detected once and the visualisation is written to ``--output``. With
    ``--interactive``, queries are read from stdin and each result is shown
    in a window until an empty line (or EOF) is entered.
    """
    parser = argparse.ArgumentParser(description='Open Vocabulary Detection')
    parser.add_argument('--image', type=str, required=True)
    parser.add_argument('--text', type=str, help='Text queries (comma separated)')
    parser.add_argument('--threshold', type=float, default=0.3)
    parser.add_argument('--model', default='google/owlvit-base-patch32')
    parser.add_argument('--interactive', action='store_true')
    parser.add_argument('--output', type=str, default='ovd_result.jpg')
    args = parser.parse_args()
    if not args.interactive and not args.text:
        parser.error('--text is required unless --interactive is used')

    def split_queries(raw):
        # Both modes use commas, as documented in the --text help string.
        return [q.strip() for q in raw.split(',') if q.strip()]

    # Validate the input before paying for model loading.
    image = cv2.imread(args.image)
    if image is None:
        parser.error(f'cannot read image: {args.image}')
    detector = OpenVocabularyDetector(args.model)

    if args.interactive:
        print("Interactive mode. Enter text queries (empty to quit):")
        while True:
            try:
                queries = split_queries(input("Query: "))
            except EOFError:
                break
            if not queries:
                break
            results = detector.detect(image, queries, args.threshold)
            print(f"Found {len(results)} objects")
            for r in results:
                print(f" - {r['label']}: {r['score']:.3f}")
            vis = detector.visualize(image, results)
            cv2.imshow('Detection', vis)
            cv2.waitKey(0)
        cv2.destroyAllWindows()
        return

    queries = split_queries(args.text)
    if not queries:
        parser.error('--text does not contain any query')
    results = detector.detect(image, queries, args.threshold)
    vis = detector.visualize(image, results)
    if not cv2.imwrite(args.output, vis):
        raise IOError(f"Failed to write {args.output}")
    print(f"Saved to {args.output}, detected {len(results)} objects")


if __name__ == '__main__':
    main()
5.2.4 EfficientSAM/MobileSAM:边缘端大模型压缩
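MobileSAM通过解耦蒸馏(decoupled distillation)将SAM的ViT-H图像编码器压缩为轻量级TinyViT编码器,在保持与原始SAM相近分割质量的同时大幅降低计算量,可在GPU上实时运行,在CPU上也能较流畅地推理。EfficientSAM则采用SAMI(SAM-leveraged Masked Image pretraining,以重建SAM图像编码器特征为目标的掩码图像预训练)得到的轻量级ViT编码器,再在SA-1B上微调,在COCO/LVIS零样本实例分割上明显优于其他快速SAM变体。部署优化包括INT8量化、CoreML/ONNX转换与注意力机制剪枝。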
Python
#!/usr/bin/env python3
"""
Script: mobile_sam_deploy.py
Content: MobileSAM/EfficientSAM边缘端部署与优化
Usage:
1. 安装依赖: pip install mobile-sam opencv-python torch
2. 标准推理: python mobile_sam_deploy.py --image input.jpg --point 500,375
3. ONNX导出: python mobile_sam_deploy.py --export-onnx --output mobile_sam.onnx
4. INT8量化: python mobile_sam_deploy.py --export-onnx --quantize --output mobile_sam_int8.onnx
"""
import argparse
import numpy as np
import cv2
import torch
import torch.nn as nn
from pathlib import Path
try:
from mobile_sam import sam_model_registry, SamPredictor
except ImportError:
raise ImportError("Install MobileSAM: pip install mobile-sam")
class MobileSAMOptimizer:
    """
    MobileSAM deployment helper.

    Provides checkpoint download, ONNX export of the prompt-encoder + mask
    decoder, optional dynamic INT8 quantization of that ONNX model, and a
    latency benchmark of prompt decoding.
    """

    CHECKPOINT_URL = "https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt"

    def __init__(self, model_type='vit_t', checkpoint=None, device='cpu'):
        self.device = device
        if checkpoint is None:
            checkpoint = self._download_checkpoint()
        sam = sam_model_registry[model_type](checkpoint=checkpoint)
        sam.to(device)
        sam.eval()  # inference only
        self.predictor = SamPredictor(sam)

    def _download_checkpoint(self):
        """Download the MobileSAM weights into ~/.cache/mobile_sam once; return the path.

        The file is downloaded under a temporary ``.part`` name and renamed only
        after completion, so an interrupted download is not mistaken for a
        valid checkpoint on the next run.
        """
        import urllib.request
        cache_dir = Path.home() / '.cache' / 'mobile_sam'
        cache_dir.mkdir(parents=True, exist_ok=True)
        checkpoint_path = cache_dir / 'mobile_sam.pth'
        if not checkpoint_path.exists():
            tmp_path = checkpoint_path.with_name(checkpoint_path.name + '.part')
            print("Downloading MobileSAM checkpoint...")
            try:
                urllib.request.urlretrieve(self.CHECKPOINT_URL, tmp_path)
                tmp_path.replace(checkpoint_path)
            finally:
                if tmp_path.exists():
                    tmp_path.unlink()
        return str(checkpoint_path)

    def export_onnx(self, output_path: str, quantize: bool = False,
                    opset_version: int = 17):
        """
        Export prompt encoding + mask decoding to ONNX (the image encoder is not exported).

        ``mask_decoder.forward`` takes image embeddings, positional encodings and
        prompt *embeddings*, so it cannot be exported directly with raw point
        prompts. ``SamOnnxModel`` (``mobile_sam.utils.onnx``) wraps prompt
        encoding, mask decoding and mask upscaling into one exportable module.
        The number of points is a dynamic axis.

        Inputs: image_embeddings, point_coords, point_labels, mask_input,
        has_mask_input, orig_im_size. Outputs: masks, iou_predictions,
        low_res_masks.
        """
        from mobile_sam.utils.onnx import SamOnnxModel

        sam = self.predictor.model
        onnx_model = SamOnnxModel(model=sam, return_single_mask=True)

        embed_dim = sam.prompt_encoder.embed_dim
        embed_size = sam.prompt_encoder.image_embedding_size
        mask_input_size = [4 * x for x in embed_size]
        dummy_inputs = {
            "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
            "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
            "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
            "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
            "has_mask_input": torch.tensor([1], dtype=torch.float),
            "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
        }
        dummy_inputs = {name: t.to(self.device) for name, t in dummy_inputs.items()}

        torch.onnx.export(
            onnx_model,
            tuple(dummy_inputs.values()),
            output_path,
            export_params=True,
            opset_version=opset_version,
            do_constant_folding=True,
            input_names=list(dummy_inputs.keys()),
            output_names=['masks', 'iou_predictions', 'low_res_masks'],
            dynamic_axes={
                'point_coords': {1: 'num_points'},
                'point_labels': {1: 'num_points'},
            },
        )
        print(f"Exported to {output_path}")
        if quantize:
            self._quantize_onnx(output_path)

    @staticmethod
    def _int8_output_path(model_path: str) -> str:
        """Return ``<stem>_int8<suffix>`` next to ``model_path`` ('.onnx' if it has no suffix)."""
        path = Path(model_path)
        return str(path.with_name(f"{path.stem}_int8{path.suffix or '.onnx'}"))

    def _quantize_onnx(self, model_path: str):
        """Dynamic INT8 weight quantization with onnxruntime.

        Writes ``<stem>_int8.onnx`` and returns its path, or returns None when
        ``onnxruntime.quantization`` is not installed.
        """
        try:
            from onnxruntime.quantization import quantize_dynamic, QuantType
        except ImportError:
            print("onnxruntime quantization not available")
            return None
        output_path = self._int8_output_path(model_path)
        quantize_dynamic(
            model_input=model_path,
            model_output=output_path,
            weight_type=QuantType.QInt8
        )
        print(f"INT8 quantization complete: {output_path}")
        return output_path

    def benchmark(self, image: np.ndarray, num_runs: int = 100, warmup: int = 10):
        """Time ``SamPredictor.predict`` with a single centre-point prompt.

        The image embedding is computed once by ``set_image`` beforehand, so
        the reported latency covers prompt decoding only.

        Returns:
            Mean latency in milliseconds.
        """
        import time
        if num_runs <= 0:
            raise ValueError("num_runs must be positive")
        self.predictor.set_image(image)
        point = np.array([[image.shape[1] // 2, image.shape[0] // 2]])
        label = np.array([1])

        for _ in range(warmup):
            self.predictor.predict(point_coords=point, point_labels=label)

        times = []
        for _ in range(num_runs):
            start = time.perf_counter()
            self.predictor.predict(point_coords=point, point_labels=label)
            times.append((time.perf_counter() - start) * 1000)

        mean_ms = float(np.mean(times))
        fps = 1000 / mean_ms if mean_ms > 0 else float('inf')
        print(f"Mean latency: {mean_ms:.2f}ms")
        print(f"FPS: {fps:.1f}")
        return mean_ms
def main():
    """CLI: export the decoder to ONNX, benchmark decoding, or segment from one point.

    Segmentation writes ``mobile_sam_result.jpg`` with the best-scoring mask
    blended in green.
    """
    parser = argparse.ArgumentParser(description='MobileSAM Deployment')
    parser.add_argument('--image', type=str, help='Input image')
    parser.add_argument('--point', type=str, help='Point prompt (x,y)')
    parser.add_argument('--export-onnx', action='store_true')
    parser.add_argument('--quantize', action='store_true')
    parser.add_argument('--output', type=str, default='mobile_sam.onnx')
    parser.add_argument('--benchmark', action='store_true')
    args = parser.parse_args()

    # Validate arguments and inputs before loading the model.
    image = None
    point = None
    if not args.export_onnx:
        if not args.image:
            parser.error('--image is required unless --export-onnx is given')
        if not args.benchmark:
            if not args.point:
                parser.error('--point x,y is required for segmentation')
            try:
                x, y = (int(v) for v in args.point.split(','))
            except ValueError:
                parser.error(f'--point must look like "x,y", got {args.point!r}')
            point = np.array([[x, y]])
        image = cv2.imread(args.image)
        if image is None:
            parser.error(f'cannot read image: {args.image}')

    optimizer = MobileSAMOptimizer()
    if args.export_onnx:
        optimizer.export_onnx(args.output, args.quantize)
        return
    if args.benchmark:
        optimizer.benchmark(image)
        return

    # cv2.imread returns BGR; SamPredictor.set_image defaults to RGB.
    optimizer.predictor.set_image(image, image_format='BGR')
    masks, scores, _ = optimizer.predictor.predict(
        point_coords=point,
        point_labels=np.array([1])
    )
    # Several candidate masks are returned; keep the highest-scoring one.
    mask = masks[int(np.argmax(scores))]
    vis = image.copy()
    vis[mask] = (vis[mask] * 0.5 + np.array([0, 255, 0]) * 0.5).astype(np.uint8)
    cv2.imwrite('mobile_sam_result.jpg', vis)
    print("Saved to mobile_sam_result.jpg")


if __name__ == '__main__':
    main()
5.3 生成式模型加速
5.3.1 Stable Diffusion:UNet优化与Latent空间处理
Stable Diffusion在压缩的潜在空间(Latent Space)中执行扩散过程:VAE将图像空间尺寸缩小8倍(例如512×512图像对应64×64×4的潜变量),相比像素空间显著降低计算复杂度。UNet通过交叉注意力机制注入文本条件,其计算瓶颈主要在注意力层与ResNet块。借助层融合、降低精度(FP16/INT8)与内核自动调优,TensorRT可将推理加速至多约4倍(具体取决于GPU与配置)。动态形状使同一引擎能够生成不同尺寸的图像,但需要在显存开销与灵活性之间权衡。
Python
#!/usr/bin/env python3
"""
Script: stable_diffusion_optimized.py
Content: Stable Diffusion TensorRT优化与Latent空间处理
Usage:
1. 安装依赖: pip install diffusers torch tensorrt
2. 导出引擎: python stable_diffusion_optimized.py --export --model stabilityai/stable-diffusion-2-1
3. 文生图: python stable_diffusion_optimized.py --prompt "a beautiful sunset" --steps 30
4. 图生图: python stable_diffusion_optimized.py --prompt "oil painting style" --init-image input.jpg --strength 0.7
"""
import argparse
import numpy as np
import torch
import torch.nn as nn
from pathlib import Path
from typing import Optional, Union
try:
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, DDIMScheduler
except ImportError:
raise ImportError("Install diffusers: pip install diffusers")
class OptimizedStableDiffusion:
    """
    Stable Diffusion inference wrapper with memory/speed optimisations.

    Enabled at construction:
    - DDIM scheduler
    - attention slicing
    - VAE slicing
    - on CUDA, xFormers memory-efficient attention when the optional package
      is usable

    ``export_tensorrt`` offers an experimental torch2trt conversion of the UNet.
    """

    def __init__(self, model_id='stabilityai/stable-diffusion-2-1', device='cuda'):
        self.device = device
        self.model_id = model_id
        on_cuda = str(device).startswith('cuda')

        # Text-to-image pipeline: FP16 on CUDA, FP32 elsewhere.
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if on_cuda else torch.float32,
            safety_checker=None,  # safety checker disabled for speed
            requires_safety_checker=False
        ).to(device)

        # DDIM reaches good quality in fewer steps.
        self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)

        # Memory optimisations.
        self.pipe.enable_attention_slicing(1)
        self.pipe.enable_vae_slicing()

        # Built lazily by img2img(); shares every module with self.pipe.
        self._img2img_pipe = None

        if on_cuda:
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
            except (ImportError, ValueError, RuntimeError) as exc:
                # xFormers is optional; attention slicing remains active without it.
                print(f"xFormers not available, using attention slicing: {exc}")

    def generate(self, prompt: str, negative_prompt: str = "",
                 num_inference_steps: int = 30, guidance_scale: float = 7.5,
                 height: int = 512, width: int = 512,
                 seed: Optional[int] = None) -> np.ndarray:
        """
        Text-to-image generation.

        Args:
            seed: when not None (0 included), sampling is reproducible.

        Returns:
            HxWx3 uint8 RGB array.
        """
        generator = None
        if seed is not None:
            generator = torch.Generator(self.device).manual_seed(seed)
        with torch.inference_mode():
            image = self.pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                height=height,
                width=width,
                generator=generator
            ).images[0]
        return np.array(image)

    def _get_img2img_pipe(self):
        """Build (once) an img2img pipeline that re-uses all modules of ``self.pipe``."""
        pipe = getattr(self, '_img2img_pipe', None)
        if pipe is None:
            pipe = StableDiffusionImg2ImgPipeline(
                **self.pipe.components,
                requires_safety_checker=False
            )
            self._img2img_pipe = pipe
        return pipe

    def img2img(self, init_image: np.ndarray, prompt: str,
                strength: float = 0.7, num_inference_steps: int = 30,
                negative_prompt: str = "", guidance_scale: float = 7.5,
                seed: Optional[int] = None) -> np.ndarray:
        """
        Image-to-image generation.

        Args:
            init_image: HxWx3 uint8 RGB image.
            strength: in [0, 1]; larger values change the input more.
            seed: when not None, sampling is reproducible.

        Returns:
            HxWx3 uint8 RGB array.
        """
        from PIL import Image

        init_pil = Image.fromarray(init_image).convert('RGB')
        generator = None
        if seed is not None:
            generator = torch.Generator(self.device).manual_seed(seed)
        with torch.inference_mode():
            result = self._get_img2img_pipe()(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=init_pil,
                strength=strength,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator
            ).images[0]
        return np.array(result)

    def export_tensorrt(self, output_dir: str = './trt_engines'):
        """
        Experimental: convert the UNet (only) to TensorRT with torch2trt in FP16.

        Requires a CUDA device and the torch2trt package. The process is slow
        and needs a lot of GPU memory.

        Example input shapes are read from the UNet config, so both SD 1.x
        (768-d text features) and SD 2.x (1024-d) models work. torch2trt may
        still reject some UNet operators; an ONNX + trtexec route is more
        robust.
        """
        from torch2trt import torch2trt

        if not str(self.device).startswith('cuda'):
            raise RuntimeError("TensorRT export requires a CUDA device")
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        unet = self.pipe.unet
        cfg = unet.config

        class _UNetSampleOnly(nn.Module):
            """Return a plain tensor; torch2trt cannot trace UNet2DConditionOutput."""

            def __init__(self, inner):
                super().__init__()
                self.inner = inner

            def forward(self, sample, timestep, encoder_hidden_states):
                return self.inner(sample, timestep, encoder_hidden_states,
                                  return_dict=False)[0]

        print("Exporting UNet to TensorRT...")
        size = cfg.sample_size
        latent_h, latent_w = (size, size) if isinstance(size, int) else tuple(size)
        sample = torch.randn(1, cfg.in_channels, latent_h, latent_w,
                             dtype=torch.float16, device=self.device)
        timestep = torch.tensor([999], dtype=torch.float16, device=self.device)
        encoder_hidden_states = torch.randn(
            1, self.pipe.tokenizer.model_max_length, cfg.cross_attention_dim,
            dtype=torch.float16, device=self.device
        )

        model_trt = torch2trt(
            _UNetSampleOnly(unet).eval(),
            [sample, timestep, encoder_hidden_states],
            fp16_mode=True,
            max_workspace_size=1 << 30
        )
        torch.save(model_trt.state_dict(), output_dir / 'unet_trt.pth')
        print(f"UNet TensorRT engine saved to {output_dir}")

    def latent_interpolation(self, prompt1: str, prompt2: str,
                             num_steps: int = 10, num_inference_steps: int = 30,
                             seed: Optional[int] = None,
                             height: int = 512, width: int = 512) -> list:
        """
        Interpolate linearly between the text embeddings of two prompts.

        A single initial latent is sampled and re-used for every frame, so
        only the text conditioning changes between frames; this is what makes
        the transition smooth.

        Returns:
            ``num_steps`` HxWx3 uint8 RGB arrays, from ``prompt1`` to ``prompt2``.
        """
        if num_steps < 1:
            raise ValueError("num_steps must be >= 1")

        tokenizer = self.pipe.tokenizer

        def encode(text):
            ids = tokenizer(
                text, padding="max_length", max_length=tokenizer.model_max_length,
                truncation=True, return_tensors="pt"
            ).input_ids.to(self.device)
            return self.pipe.text_encoder(ids)[0]

        with torch.no_grad():
            text_embeds1 = encode(prompt1)
            text_embeds2 = encode(prompt2)

        unet = self.pipe.unet
        vae_scale = self.pipe.vae_scale_factor
        generator = None
        if seed is not None:
            generator = torch.Generator(self.device).manual_seed(seed)
        # Shared starting noise, matching the UNet dtype (FP16 on CUDA, FP32 on CPU).
        base_latents = torch.randn(
            (1, unet.config.in_channels, height // vae_scale, width // vae_scale),
            generator=generator, device=self.device, dtype=unet.dtype
        )

        images = []
        for i in range(num_steps):
            alpha = i / (num_steps - 1) if num_steps > 1 else 0.0
            interp_embeds = (1 - alpha) * text_embeds1 + alpha * text_embeds2
            with torch.inference_mode():
                image = self.pipe(
                    prompt_embeds=interp_embeds,
                    latents=base_latents.clone(),
                    num_inference_steps=num_inference_steps,
                    height=height,
                    width=width
                ).images[0]
            images.append(np.array(image))
        return images
def main():
    """CLI for text-to-image, img2img, prompt interpolation and TensorRT export.

    Branch priority: ``--export`` > ``--interpolate`` > ``--init-image`` >
    plain text-to-image.
    """
    # cv2 is not imported at module level in this script. Import it once here:
    # importing it only inside some branches would make `cv2` an unbound local
    # in the other branches.
    import cv2

    parser = argparse.ArgumentParser(description='Optimized Stable Diffusion')
    parser.add_argument('--export', action='store_true', help='Export TensorRT engines')
    parser.add_argument('--prompt', type=str, help='Text prompt')
    parser.add_argument('--negative-prompt', type=str, default="", help='Negative prompt')
    parser.add_argument('--init-image', type=str, help='Initial image for img2img')
    parser.add_argument('--strength', type=float, default=0.7, help='Img2img strength')
    parser.add_argument('--steps', type=int, default=30, help='Inference steps')
    parser.add_argument('--seed', type=int, help='Random seed')
    parser.add_argument('--output', type=str, default='output.png', help='Output path')
    parser.add_argument('--interpolate', action='store_true', help='Latent interpolation')
    args = parser.parse_args()

    # Validate before the (expensive) model load.
    if not args.export and not args.prompt:
        parser.error('--prompt is required unless --export is given')
    prompts = None
    init_img = None
    if not args.export:
        if args.interpolate:
            prompts = [p.strip() for p in args.prompt.split('|')]
            if len(prompts) != 2 or not all(prompts):
                parser.error('--interpolate expects two prompts: "prompt A|prompt B"')
        elif args.init_image:
            init_img = cv2.imread(args.init_image)
            if init_img is None:
                parser.error(f'cannot read image: {args.init_image}')

    sd = OptimizedStableDiffusion()

    if args.export:
        sd.export_tensorrt()
        return

    if prompts is not None:
        images = sd.latent_interpolation(prompts[0], prompts[1])
        # Save all frames side by side as one strip.
        grid = np.hstack(images)
        cv2.imwrite(args.output, cv2.cvtColor(grid, cv2.COLOR_RGB2BGR))
    elif init_img is not None:
        init_rgb = cv2.cvtColor(init_img, cv2.COLOR_BGR2RGB)
        result = sd.img2img(init_rgb, args.prompt, args.strength, args.steps)
        cv2.imwrite(args.output, cv2.cvtColor(result, cv2.COLOR_RGB2BGR))
    else:
        result = sd.generate(
            args.prompt,
            args.negative_prompt,
            args.steps,
            seed=args.seed
        )
        cv2.imwrite(args.output, cv2.cvtColor(result, cv2.COLOR_RGB2BGR))
    print(f"Saved to {args.output}")


if __name__ == '__main__':
    main()
5.3.2 ControlNet/OpenPose:条件生成管线搭建
ControlNet锁定预训练扩散模型的参数,复制其UNet编码器部分作为可训练副本,并通过零卷积(zero convolution)与原网络相连,从而在不破坏原模型生成能力的前提下实现对生成过程的细粒度控制。OpenPose作为条件输入提供人体姿态骨架,引导生成与该姿态一致的人物图像。管线优化包括条件检测器缓存、多ControlNet组合与实时预览生成。
Python
#!/usr/bin/env python3
"""
Script: controlnet_pipeline.py
Content: ControlNet/OpenPose条件生成管线
Usage:
1. 安装依赖: pip install diffusers controlnet-aux opencv-python
2. Canny边缘: python controlnet_pipeline.py --image input.jpg --mode canny --prompt "modern building"
3. OpenPose: python controlnet_pipeline.py --image pose.jpg --mode openpose --prompt "dancing person"
4. 深度控制: python controlnet_pipeline.py --image room.jpg --mode depth --prompt "futuristic interior"
"""
import argparse
import numpy as np
import cv2
import torch
from PIL import Image
from pathlib import Path
try:
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from controlnet_aux import OpenposeDetector, CannyDetector, MidasDetector
except ImportError:
raise ImportError("Install: pip install diffusers controlnet-aux")
class ControlNetPipeline:
    """
    ControlNet conditional generation with Canny / depth / OpenPose conditions.

    - Condition detectors are created lazily and cached per mode.
    - The single-ControlNet pipeline is cached and rebuilt only when a
      different ControlNet checkpoint is needed ('pose' and 'openpose' share
      one).
    - Several ControlNets can be combined with ``multi_controlnet_generate``.
    """

    CONTROLNET_MODELS = {
        'canny': 'lllyasviel/sd-controlnet-canny',
        'depth': 'lllyasviel/sd-controlnet-depth',
        'pose': 'lllyasviel/sd-controlnet-openpose',
        'openpose': 'lllyasviel/sd-controlnet-openpose'
    }

    def __init__(self, base_model='runwayml/stable-diffusion-v1-5', device='cuda'):
        self.device = device
        self.base_model = base_model
        # Condition detectors, created on first use.
        self.detectors = {}
        self.pipe = None
        self.current_mode = None

    def _check_mode(self, mode: str):
        """Raise ValueError for modes without a configured ControlNet."""
        if mode not in self.CONTROLNET_MODELS:
            raise ValueError(f"Unknown mode: {mode}")

    def _torch_dtype(self):
        """FP16 on CUDA; FP32 elsewhere, where many half-precision kernels are missing."""
        return torch.float16 if str(self.device).startswith('cuda') else torch.float32

    def _enable_xformers(self, pipe):
        """Enable xFormers attention on CUDA when available; otherwise keep the default."""
        if not str(self.device).startswith('cuda'):
            return
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except (ImportError, ValueError, RuntimeError) as exc:
            print(f"xFormers not available, using default attention: {exc}")

    def get_detector(self, mode: str):
        """Return the cached condition detector for ``mode``, creating it on first use."""
        self._check_mode(mode)
        if mode not in self.detectors:
            if mode in ('openpose', 'pose'):
                self.detectors[mode] = OpenposeDetector.from_pretrained('lllyasviel/ControlNet')
            elif mode == 'canny':
                self.detectors[mode] = CannyDetector()
            elif mode == 'depth':
                self.detectors[mode] = MidasDetector.from_pretrained('lllyasviel/ControlNet')
        return self.detectors[mode]

    def prepare_condition(self, image: np.ndarray, mode: str,
                          low_threshold: int = 100,
                          high_threshold: int = 200) -> np.ndarray:
        """
        Build the condition image for ``mode`` from a BGR image.

        - canny: edge map (``low_threshold`` / ``high_threshold``)
        - openpose / pose: body skeleton including hands and face
        - depth: MiDaS depth map

        Returns:
            The condition as an RGB numpy array.

        Raises:
            ValueError: unknown ``mode``.
        """
        self._check_mode(mode)
        detector = self.get_detector(mode)
        pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        if mode == 'canny':
            condition = detector(pil_image, low_threshold, high_threshold)
        elif mode in ('openpose', 'pose'):
            condition = detector(pil_image, hand_and_face=True)
        else:  # depth
            condition = detector(pil_image)
        return np.array(condition)

    def generate(self, condition_image: np.ndarray, prompt: str,
                 mode: str = 'canny', negative_prompt: str = "",
                 num_steps: int = 30, guidance_scale: float = 9.0,
                 controlnet_conditioning_scale: float = 1.0) -> np.ndarray:
        """
        Generate an image conditioned on ``condition_image`` (RGB array).

        Returns:
            HxWx3 uint8 RGB array.
        """
        self._check_mode(mode)
        model_id = self.CONTROLNET_MODELS[mode]
        # Rebuild only when a different ControlNet checkpoint is required.
        if (getattr(self, 'pipe', None) is None
                or getattr(self, '_controlnet_id', None) != model_id):
            dtype = self._torch_dtype()
            controlnet = ControlNetModel.from_pretrained(
                model_id,
                torch_dtype=dtype
            ).to(self.device)
            self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
                self.base_model,
                controlnet=controlnet,
                torch_dtype=dtype,
                safety_checker=None
            ).to(self.device)
            self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
            self._enable_xformers(self.pipe)
            self._controlnet_id = model_id
        self.current_mode = mode

        condition_pil = Image.fromarray(condition_image)
        with torch.inference_mode():
            result = self.pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=condition_pil,
                num_inference_steps=num_steps,
                guidance_scale=guidance_scale,
                controlnet_conditioning_scale=controlnet_conditioning_scale
            ).images[0]
        return np.array(result)

    @staticmethod
    def _resolve_scales(modes, scales):
        """Return one conditioning scale per mode, in ``modes`` order.

        ``scales`` may be None (all 1.0), a dict keyed by mode (missing modes
        get 1.0), or a sequence already in ``modes`` order.
        """
        if scales is None:
            return [1.0] * len(modes)
        if isinstance(scales, dict):
            return [float(scales.get(mode, 1.0)) for mode in modes]
        resolved = [float(s) for s in scales]
        if len(resolved) != len(modes):
            raise ValueError(
                f"Expected {len(modes)} conditioning scales, got {len(resolved)}"
            )
        return resolved

    def multi_controlnet_generate(self, conditions: dict, prompt: str,
                                  scales: dict = None) -> np.ndarray:
        """
        Generate with several ControlNets combined (e.g. Canny + OpenPose).

        Args:
            conditions: ``{'canny': canny_img, 'openpose': pose_img}``
                (RGB arrays of equal size).
            scales: per-mode conditioning scales, as a dict keyed by mode or a
                list in ``conditions`` order; defaults to 1.0 each.

        Returns:
            HxWx3 uint8 RGB array.
        """
        if not conditions:
            raise ValueError("conditions must not be empty")
        modes = list(conditions)
        for mode in modes:
            self._check_mode(mode)
        # Resolve scales first so bad input fails before any model loading.
        resolved_scales = self._resolve_scales(modes, scales)

        dtype = self._torch_dtype()
        controlnets = [
            ControlNetModel.from_pretrained(
                self.CONTROLNET_MODELS[mode],
                torch_dtype=dtype
            ).to(self.device)
            for mode in modes
        ]
        images = [Image.fromarray(conditions[mode]) for mode in modes]

        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            self.base_model,
            controlnet=controlnets,
            torch_dtype=dtype,
            safety_checker=None
        ).to(self.device)
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        self._enable_xformers(pipe)

        with torch.inference_mode():
            result = pipe(
                prompt=prompt,
                image=images,
                controlnet_conditioning_scale=resolved_scales,
                num_inference_steps=30
            ).images[0]
        return np.array(result)
def main():
    """CLI: build a condition map from ``--image`` and generate an image from it.

    Also writes the condition map to ``condition_<mode>.png`` for inspection.
    """
    parser = argparse.ArgumentParser(description='ControlNet Pipeline')
    parser.add_argument('--image', type=str, required=True, help='Input image')
    parser.add_argument('--mode', type=str, default='canny',
                        choices=['canny', 'depth', 'openpose', 'pose'])
    parser.add_argument('--prompt', type=str, required=True, help='Generation prompt')
    parser.add_argument('--negative', type=str, default='low quality, blurry')
    parser.add_argument('--steps', type=int, default=30)
    parser.add_argument('--scale', type=float, default=1.0, help='Control strength')
    parser.add_argument('--output', type=str, default='controlnet_output.png')
    args = parser.parse_args()

    image = cv2.imread(args.image)
    if image is None:
        parser.error(f'cannot read image: {args.image}')

    pipeline = ControlNetPipeline()

    print(f"Preparing {args.mode} condition...")
    condition = pipeline.prepare_condition(image, args.mode)

    # prepare_condition returns RGB; cv2.imwrite expects BGR.
    condition_bgr = (cv2.cvtColor(condition, cv2.COLOR_RGB2BGR)
                     if condition.ndim == 3 else condition)
    cv2.imwrite(f'condition_{args.mode}.png', condition_bgr)

    print("Generating image...")
    result = pipeline.generate(
        condition,
        args.prompt,
        args.mode,
        args.negative,
        args.steps,
        controlnet_conditioning_scale=args.scale
    )

    result_bgr = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
    cv2.imwrite(args.output, result_bgr)
    print(f"Saved to {args.output}")


if __name__ == '__main__':
    main()
5.3.3 图像超分:Real-ESRGAN/BSRGAN轻量级部署
Real-ESRGAN通过高阶退化模型模拟真实世界的复杂退化过程,相比ESRGAN在真实图像上表现更优。BSRGAN针对盲超分辨率设计了随机打乱退化顺序的实用退化模型,覆盖更广泛的退化空间。部署优化包括NCNN转换(移动端)、TensorRT加速,以及用于处理高分辨率图像的瓦片化(tile-based)推理。
Python
#!/usr/bin/env python3
"""
Script: super_resolution_deploy.py
Content: Real-ESRGAN/BSRGAN轻量级部署与优化
Usage:
1. 安装依赖: pip install realesrgan basicsr ncnn
2. 单张超分: python super_resolution_deploy.py --input lowres.jpg --model RealESRGAN_x4plus
3. 瓦片推理: python super_resolution_deploy.py --input large.jpg --tile 512 --tile-pad 32
4. NCNN转换: python super_resolution_deploy.py --export-ncnn --model RealESRGAN_x4plus
"""
import argparse
import numpy as np
import cv2
import torch
from pathlib import Path
from typing import Tuple
try:
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
except ImportError:
raise ImportError("Install: pip install realesrgan basicsr")
class SuperResolutionEngine:
    """
    Super-resolution engine built on Real-ESRGAN (``RealESRGANer``).

    Optimisations:
    - tile-based inference (built into RealESRGANer when ``tile > 0``, or
      via ``enhance_tiled``)
    - FP16 on CUDA
    - ONNX export for NCNN conversion
    """

    MODEL_CONFIGS = {
        'RealESRGAN_x4plus': {
            'num_block': 23,
            'num_feat': 64,
            'num_grow_ch': 32,
            'scale': 4
        },
        'RealESRGAN_x2plus': {
            'num_block': 23,
            'num_feat': 64,
            'num_grow_ch': 32,
            'scale': 2
        },
        'RealESRGAN_x4plus_anime_6B': {
            'num_block': 6,
            'num_feat': 64,
            'num_grow_ch': 32,
            'scale': 4
        }
    }

    MODEL_URLS = {
        'RealESRGAN_x4plus': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
        'RealESRGAN_x2plus': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
        'RealESRGAN_x4plus_anime_6B': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth'
    }

    def __init__(self, model_name='RealESRGAN_x4plus', device='cuda', tile=0, tile_pad=10):
        if model_name not in self.MODEL_CONFIGS:
            # Resolve once, so the config, the download URL and the cache
            # file name all refer to the same model.
            print(f"Unknown model {model_name!r}, falling back to RealESRGAN_x4plus")
            model_name = 'RealESRGAN_x4plus'
        self.model_name = model_name
        self.device = device
        self.tile = tile
        self.tile_pad = tile_pad
        self._init_model()

    def _init_model(self):
        """Build the RRDBNet for ``self.model_name`` and wrap it in a RealESRGANer."""
        config = self.MODEL_CONFIGS[self.model_name]
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=config['num_feat'],
            num_block=config['num_block'],
            num_grow_ch=config['num_grow_ch'],
            scale=config['scale']
        )
        model_path = self._download_model()
        self.upsampler = RealESRGANer(
            scale=config['scale'],
            model_path=model_path,
            model=model,
            tile=self.tile,
            tile_pad=self.tile_pad,
            pre_pad=0,
            half=str(self.device).startswith('cuda'),
            device=self.device
        )

    def _download_model(self):
        """Download the weights into ~/.cache/realesrgan once and return the local path.

        The file is downloaded under a temporary ``.part`` name and renamed
        only when complete, so an interrupted download is retried next time.
        """
        import urllib.request
        cache_dir = Path.home() / '.cache' / 'realesrgan'
        cache_dir.mkdir(parents=True, exist_ok=True)
        model_path = cache_dir / f"{self.model_name}.pth"
        if not model_path.exists():
            url = self.MODEL_URLS.get(self.model_name, self.MODEL_URLS['RealESRGAN_x4plus'])
            tmp_path = model_path.with_name(model_path.name + '.part')
            print(f"Downloading {self.model_name}...")
            try:
                urllib.request.urlretrieve(url, tmp_path)
                tmp_path.replace(model_path)
            finally:
                if tmp_path.exists():
                    tmp_path.unlink()
        return str(model_path)

    def enhance(self, image: np.ndarray, outscale: float = 4.0) -> Tuple[np.ndarray, str]:
        """
        Super-resolve ``image`` (BGR, as RealESRGANer expects).

        Returns:
            ``(output, img_mode)`` as returned by ``RealESRGANer.enhance``;
            ``img_mode`` is the colour-mode string it reports (e.g. 'RGB').
        """
        output, img_mode = self.upsampler.enhance(image, outscale=outscale)
        return output, img_mode

    def enhance_tiled(self, image: np.ndarray, tile_size: int = 512,
                      outscale: int = 4) -> np.ndarray:
        """
        Tile-by-tile super-resolution for very large images.

        Each ``tile_size`` core region is cut out with ``self.tile_pad``
        pixels of context on *every* side. The tile is upscaled and only the
        upscaled core is pasted back, so tile borders see real neighbouring
        pixels and seams are avoided.

        Note: when the engine was built with ``tile > 0``, RealESRGANer already
        tiles internally; use ``tile=0`` together with this method to avoid
        tiling twice.

        Args:
            tile_size: core tile size in input pixels (> 0).
            outscale: integer output scale passed to ``enhance``.

        Returns:
            Array of shape ``(H*outscale, W*outscale, ...)`` with the dtype and
            channel layout of the per-tile outputs.
        """
        if tile_size <= 0:
            raise ValueError("tile_size must be positive")
        scale = int(outscale)
        if scale < 1 or scale != outscale:
            raise ValueError("outscale must be a positive integer for tiled inference")

        h, w = image.shape[:2]
        pad = max(0, int(self.tile_pad))
        output = None
        for y0 in range(0, h, tile_size):
            y1 = min(y0 + tile_size, h)
            py0, py1 = max(y0 - pad, 0), min(y1 + pad, h)
            for x0 in range(0, w, tile_size):
                x1 = min(x0 + tile_size, w)
                px0, px1 = max(x0 - pad, 0), min(x1 + pad, w)

                tile_output, _ = self.enhance(image[py0:py1, px0:px1], outscale=scale)
                if output is None:
                    output = np.zeros((h * scale, w * scale) + tile_output.shape[2:],
                                      dtype=tile_output.dtype)

                # Crop the upscaled core out of the padded tile output.
                cy0 = (y0 - py0) * scale
                cx0 = (x0 - px0) * scale
                core = tile_output[cy0:cy0 + (y1 - y0) * scale,
                                   cx0:cx0 + (x1 - x0) * scale]
                output[y0 * scale:y1 * scale, x0 * scale:x1 * scale] = core

        if output is None:  # empty input image
            output = np.zeros((h * scale, w * scale) + image.shape[2:], dtype=image.dtype)
        return output

    def export_ncnn(self, output_dir: str):
        """
        Export the network as a simplified ONNX model into ``output_dir``.

        The result is ready for NCNN conversion (``onnx2ncnn`` / ``pnnx``). An
        FP32 CPU copy of the network is exported, so this also works when
        inference runs in FP16 on CUDA. Height and width are dynamic axes.

        Returns:
            Path of the ONNX file, or None if ``onnx``/``onnxsim`` are missing.
        """
        try:
            import onnx
            from onnxsim import simplify
        except ImportError:
            print("NCNN export requires onnx and onnxsim (pip install onnx onnxsim)")
            return None
        import copy

        out_dir = Path(output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        onnx_path = out_dir / f"{self.model_name}.onnx"

        net = copy.deepcopy(self.upsampler.model).float().cpu().eval()
        dummy_input = torch.randn(1, 3, 64, 64)
        torch.onnx.export(
            net,
            dummy_input,
            str(onnx_path),
            input_names=['input'],
            output_names=['output'],
            opset_version=11,
            dynamic_axes={
                'input': {2: 'height', 3: 'width'},
                'output': {2: 'height', 3: 'width'},
            }
        )

        model_simp, check = simplify(onnx.load(str(onnx_path)))
        if check:
            onnx.save(model_simp, str(onnx_path))
        else:
            print("onnxsim validation failed; keeping the unsimplified model")

        print(f"ONNX exported to {onnx_path}")
        print(f"Use ncnn tools to convert: onnx2ncnn {onnx_path} model.param model.bin")
        return str(onnx_path)

    def benchmark(self, image_size: Tuple[int, int] = (256, 256), num_runs: int = 50):
        """Time ``enhance`` on a random uint8 image; returns mean latency in ms."""
        import time
        if num_runs <= 0:
            raise ValueError("num_runs must be positive")
        dummy = np.random.randint(0, 255, (*image_size, 3), dtype=np.uint8)

        for _ in range(5):  # warm-up
            self.enhance(dummy)

        times = []
        for _ in range(num_runs):
            start = time.perf_counter()
            self.enhance(dummy)
            times.append((time.perf_counter() - start) * 1000)

        mean_ms = float(np.mean(times))
        throughput = 1000 / mean_ms if mean_ms > 0 else float('inf')
        print(f"Mean latency: {mean_ms:.2f}ms")
        print(f"Throughput: {throughput:.2f} images/sec")
        return mean_ms
def main():
    """CLI: super-resolve an image, benchmark the model, or export ONNX for NCNN."""
    parser = argparse.ArgumentParser(description='Super Resolution Deployment')
    parser.add_argument('--input', type=str, help='Input image')
    parser.add_argument('--model', default='RealESRGAN_x4plus',
                        choices=['RealESRGAN_x4plus', 'RealESRGAN_x2plus', 'RealESRGAN_x4plus_anime_6B'])
    parser.add_argument('--output', type=str, default='sr_output.png')
    parser.add_argument('--outscale', type=float, default=4.0)
    parser.add_argument('--tile', type=int, default=0, help='Tile size (0=disabled)')
    parser.add_argument('--tile-pad', type=int, default=10)
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--export-ncnn', action='store_true')
    parser.add_argument('--benchmark', action='store_true')
    args = parser.parse_args()

    # Validate before loading the model.
    if not (args.export_ncnn or args.benchmark or args.input):
        parser.error('nothing to do: pass --input, --benchmark or --export-ncnn')
    image = None
    if not args.export_ncnn and not args.benchmark:
        image = cv2.imread(args.input)
        if image is None:
            parser.error(f'cannot read image: {args.input}')

    engine = SuperResolutionEngine(
        args.model,
        args.device,
        args.tile,
        args.tile_pad
    )

    if args.export_ncnn:
        engine.export_ncnn('./ncnn_models')
    elif args.benchmark:
        engine.benchmark()
    else:
        # With --tile > 0 the engine's RealESRGANer already runs padded,
        # tile-by-tile inference, so a single enhance() call is tiled and also
        # honours --outscale (wrapping it in enhance_tiled would tile twice).
        result, _ = engine.enhance(image, args.outscale)
        if not cv2.imwrite(args.output, result):
            raise IOError(f"Failed to write {args.output}")
        print(f"Saved to {args.output}")


if __name__ == '__main__':
    main()
5.3.4 视频增强:BasicVSR++与BasicVSR-IconVSR
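BasicVSR++通过二阶网格传播(second-order grid propagation)与光流引导的可变形对齐(flow-guided deformable alignment)改进长程时序建模;IconVSR(与BasicVSR在同一论文中提出)在BasicVSR基础上引入信息重填充(information-refill)机制与耦合传播,以减少误差累积。部署优化包括帧间特征缓存、光流估计轻量化与时空瓦片处理。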
#!/usr/bin/env python3
"""
Script: video_enhancement.py
Content: BasicVSR++/IconVSR视频超分辨率部署
Usage:
1. 安装依赖: pip install mmcv-full mmedit
2. 视频超分: python video_enhancement.py --input lowres_video.mp4 --output hires_video.mp4 --model basicvsr_pp
3. 滑动窗口: python video_enhancement.py --input video.mp4 --window 30 --stride 15
"""
import argparse
import numpy as np
import cv2
import torch
from pathlib import Path
from typing import List
try:
from mmedit.apis import init_model, restoration_video_inference
except ImportError:
raise ImportError("Install: pip install mmcv-full mmedit")
class VideoEnhancementEngine:
    """
    4x video super-resolution with BasicVSR++ / BasicVSR / IconVSR (mmedit).

    Long videos are processed in sliding windows. Each window re-uses the
    last ``window_size // 2`` frames of the previous window as temporal
    context, and only frames not yet written are sent to the output, so every
    input frame appears exactly once.

    Note: the ``config`` paths are relative and assume an mmediting source
    checkout as the working directory.
    """

    # Output/input size ratio of the configured models.
    SCALE = 4

    MODEL_CONFIGS = {
        'basicvsr_pp': {
            'config': 'configs/restorers/basicvsr_plusplus/basicvsr_plusplus_c64n7_8x1_600k_reds4.py',
            'checkpoint': 'https://download.openmmlab.com/mmediting/restorers/basicvsr_plusplus/basicvsr_plusplus_c64n7_8x1_600k_reds4_20210217-db622b2f.pth'
        },
        'basicvsr': {
            'config': 'configs/restorers/basicvsr/basicvsr_reds4.py',
            'checkpoint': 'https://download.openmmlab.com/mmediting/restorers/basicvsr/basicvsr_reds4_20120409-0e599677.pth'
        },
        'iconvsr': {
            'config': 'configs/restorers/iconvsr/iconvsr_reds.py',
            'checkpoint': 'https://download.openmmlab.com/mmediting/restorers/iconvsr/iconvsr_reds_20210413-92ba1d2a.pth'
        }
    }

    def __init__(self, model_name='basicvsr_pp', device='cuda'):
        self.model_name = model_name
        self.device = device
        config = self.MODEL_CONFIGS[model_name]
        self.model = init_model(config['config'], config['checkpoint'], device=device)

    @staticmethod
    def _sliding_windows(frames, window_size: int, overlap: int):
        """Group an iterable of frames into overlapping windows.

        Yields ``(window, num_context)`` pairs, where the first ``num_context``
        frames of ``window`` already appeared (as new frames) in the previous
        window. Concatenating ``window[num_context:]`` over all windows
        reproduces the input exactly once, including a final partial window.
        """
        if not 0 <= overlap < window_size:
            raise ValueError("require 0 <= overlap < window_size")
        buffer = []
        num_context = 0
        for frame in frames:
            buffer.append(frame)
            if len(buffer) >= window_size:
                yield list(buffer), num_context
                buffer = buffer[len(buffer) - overlap:] if overlap > 0 else []
                num_context = len(buffer)
        if len(buffer) > num_context:  # trailing frames not yet emitted
            yield list(buffer), num_context

    def enhance_video(self, input_path: str, output_path: str,
                      window_size: int = None, max_seq_len: int = 100):
        """
        Super-resolve a whole video and write it as MP4 (mp4v).

        Args:
            window_size: frames per window; None uses
                ``min(total_frames, max_seq_len)`` (or ``max_seq_len`` when
                the frame count is unknown). Larger windows need more memory.
            max_seq_len: upper bound for the automatic window size.
        """
        cap = cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise IOError(f"Cannot open video: {input_path}")
        fps = cap.get(cv2.CAP_PROP_FPS) or 25.0  # some containers report 0
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps,
                              (width * self.SCALE, height * self.SCALE))
        if not out.isOpened():
            cap.release()
            raise IOError(f"Cannot open video writer: {output_path}")

        if window_size is None:
            window_size = min(total_frames, max_seq_len) if total_frames > 0 else max_seq_len
        window_size = max(1, int(window_size))
        # Context frames shared with the previous window (always < window_size).
        overlap = window_size // 2

        def read_frames():
            while True:
                ok, frame = cap.read()
                if not ok:
                    return
                yield frame

        written = 0
        try:
            for window, num_context in self._sliding_windows(read_frames(), window_size, overlap):
                enhanced = self._process_window(window)
                # The context frames were already written with the previous window.
                for frame in enhanced[num_context:]:
                    out.write(frame)
                written += len(window) - num_context
                print(f"Processed {written}/{total_frames} frames")
        finally:
            cap.release()
            out.release()
        print(f"Enhanced video saved to {output_path}")

    def _process_window(self, frames: List[np.ndarray]) -> List[np.ndarray]:
        """
        Super-resolve a list of BGR frames; returns upscaled BGR uint8 frames.
        """
        # [T, H, W, C] RGB in [0, 1] -> [1, T, C, H, W]
        frames_rgb = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames]
        tensor = torch.from_numpy(np.stack(frames_rgb)).float() / 255.0
        tensor = tensor.permute(0, 3, 1, 2).unsqueeze(0).to(self.device)

        with torch.no_grad():
            # restoration_video_inference() takes a frame directory / video
            # path, not a tensor; call the restorer directly in test mode,
            # which returns a dict with an 'output' tensor.
            output = self.model(lq=tensor, test_mode=True)['output']

        output = output.squeeze(0).permute(0, 2, 3, 1).float().cpu().numpy()
        output = np.clip(output * 255.0, 0, 255).round().astype(np.uint8)
        return [cv2.cvtColor(f, cv2.COLOR_RGB2BGR) for f in output]

    def enhance_frame_at_index(self, video_path: str, frame_idx: int,
                               temporal_radius: int = 3) -> np.ndarray:
        """
        Super-resolve one frame using up to ``temporal_radius`` neighbours on each side.

        Near the start or end of the video fewer neighbours are available; the
        requested frame is still returned.

        Raises:
            ValueError: ``frame_idx`` is negative or cannot be read.
        """
        if frame_idx < 0:
            raise ValueError("frame_idx must be non-negative")
        start_idx = max(0, frame_idx - temporal_radius)
        end_idx = frame_idx + temporal_radius + 1

        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            cap.set(cv2.CAP_PROP_POS_FRAMES, start_idx)
            for _ in range(end_idx - start_idx):
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(frame)
        finally:
            cap.release()

        # Position of the requested frame inside the (possibly truncated) window.
        target = frame_idx - start_idx
        if target >= len(frames):
            raise ValueError(f"Frame {frame_idx} could not be read from {video_path}")
        return self._process_window(frames)[target]
def main():
    """Command-line entry point: super-resolve a whole video with the chosen VSR model."""
    cli = argparse.ArgumentParser(description='Video Enhancement')
    cli.add_argument('--input', type=str, required=True, help='Input video')
    cli.add_argument('--output', type=str, default='enhanced_video.mp4')
    cli.add_argument('--model', default='basicvsr_pp',
                     choices=['basicvsr_pp', 'basicvsr', 'iconvsr'])
    cli.add_argument('--window', type=int, help='Sliding window size')
    cli.add_argument('--device', default='cuda')
    opts = cli.parse_args()

    enhancer = VideoEnhancementEngine(opts.model, opts.device)
    enhancer.enhance_video(opts.input, opts.output, opts.window)


if __name__ == '__main__':
    main()
5.4 MLOps与模型管理
5.4.1 模型版本控制:DVC与MLflow集成
DVC(Data Version Control)扩展Git以管理大型模型文件与数据集,MLflow提供实验追踪与模型注册功能。集成方案由MLflow记录超参数与指标,由DVC管理模型权重与数据的版本,二者结合实现可复现的机器学习工作流。
Python
#!/usr/bin/env python3
"""
Script: mlops_model_management.py
Content: DVC与MLflow集成的模型版本控制系统
Usage:
1. 安装依赖: pip install mlflow dvc
2. 初始化: python mlops_model_management.py --init
3. 训练并记录: python mlops_model_management.py --train --experiment yolov8_experiment
4. 模型注册: python mlops_model_management.py --register --model-name yolov8_production --run-id <id>
5. 版本切换: python mlops_model_management.py --checkout --version v1.0
"""
import os
import argparse
import json
import shutil
from pathlib import Path
from datetime import datetime
try:
import mlflow
import mlflow.pytorch
except ImportError:
raise ImportError("Install MLflow: pip install mlflow")
try:
import dvc.api
import dvc.repo
except ImportError:
raise ImportError("Install DVC: pip install dvc")
class ModelVersionManager:
    """
    Model version management combining MLflow (parameters, metrics, model
    registry) with DVC (large files: datasets and model weights).

    git/dvc commands are executed with argument lists (no shell), so paths
    and tags are never interpreted by a shell.
    """

    def __init__(self, tracking_uri='http://localhost:5000', experiment_name='default'):
        mlflow.set_tracking_uri(tracking_uri)
        mlflow.set_experiment(experiment_name)
        self.experiment_name = experiment_name

    @staticmethod
    def _run(cmd) -> bool:
        """Run ``cmd`` (a list of arguments) and return True on exit status 0.

        Failures, including a missing executable, are reported on stdout rather
        than raised, which suits the best-effort style of this CLI.
        """
        import subprocess
        cmd = [str(part) for part in cmd]
        try:
            completed = subprocess.run(cmd, check=False)
        except FileNotFoundError:
            print(f"Command not found: {cmd[0]}")
            return False
        if completed.returncode != 0:
            print(f"Command failed ({completed.returncode}): {' '.join(cmd)}")
        return completed.returncode == 0

    def init_repository(self):
        """Initialise a DVC repository (and stage its files in git) if needed."""
        if Path('.dvc').exists():
            print("DVC already initialized")
            return
        if self._run(['dvc', 'init']):
            self._run(['git', 'add', '.dvc', '.dvcignore'])
            print("DVC repository initialized")

    def track_dataset(self, data_path: str, remote_name: str = 'storage',
                      remote_url: str = '/tmp/dvc-storage'):
        """Track ``data_path`` with DVC and stage the resulting pointer files in git.

        ``remote_name`` is registered as the default remote at ``remote_url``;
        that step fails harmlessly when the remote already exists.
        """
        self._run(['dvc', 'remote', 'add', '-d', remote_name, remote_url])
        if not self._run(['dvc', 'add', data_path]):
            print(f"dvc add failed for {data_path}")
            return
        # dvc writes the .dvc file and a .gitignore next to the tracked path.
        gitignore = Path(data_path).parent / '.gitignore'
        self._run(['git', 'add', f'{data_path}.dvc', str(gitignore)])
        print(f"Dataset {data_path} tracked by DVC")

    def log_experiment(self, model, metrics: dict, params: dict,
                       artifacts: list = None, model_name: str = None):
        """
        Log params/metrics to a new MLflow run.

        When ``model_name`` is given, the weights are saved to
        ``models/<model_name>`` and tracked by DVC. The DVC pointer file is
        logged under the ``model/`` artifact path, so the default
        ``runs:/<run_id>/model`` URI used by ``register_model`` resolves.

        Returns:
            The MLflow run id.
        """
        import subprocess
        with mlflow.start_run() as run:
            mlflow.log_params(params)
            mlflow.log_metrics(metrics)

            # Record the code version when running inside a git checkout.
            try:
                git_commit = subprocess.check_output(
                    ['git', 'rev-parse', 'HEAD'], stderr=subprocess.DEVNULL
                ).decode().strip()
                mlflow.set_tag('git_commit', git_commit)
            except (subprocess.CalledProcessError, FileNotFoundError):
                pass  # not a git repository or git unavailable: the tag is optional

            if model_name:
                model_path = Path('models') / model_name
                model_path.parent.mkdir(parents=True, exist_ok=True)
                if hasattr(model, 'save'):
                    model.save(str(model_path))
                else:
                    import torch  # not imported at module level in this script
                    torch.save(model.state_dict(), model_path)
                if self._run(['dvc', 'add', str(model_path)]):
                    mlflow.log_artifact(f'{model_path}.dvc', artifact_path='model')

            for artifact in artifacts or []:
                mlflow.log_artifact(artifact)

            print(f"Experiment logged. Run ID: {run.info.run_id}")
            return run.info.run_id

    def register_model(self, model_name: str, run_id: str,
                       model_uri: str = None, stage: str = 'Staging'):
        """
        Register a run's model in the MLflow Model Registry and move it to ``stage``.

        Returns:
            The new model version.
        """
        if model_uri is None:
            model_uri = f'runs:/{run_id}/model'
        mv = mlflow.register_model(model_uri, model_name)
        # Stage: Staging / Production / Archived
        client = mlflow.tracking.MlflowClient()
        client.transition_model_version_stage(
            name=model_name,
            version=mv.version,
            stage=stage
        )
        print(f"Model {model_name} version {mv.version} registered as {stage}")
        return mv.version

    def checkout_version(self, version_tag: str):
        """
        Check out ``version_tag`` with git, then sync DVC-tracked data to it.

        Returns:
            True when both steps succeed. DVC is not touched if git fails.

        Raises:
            ValueError: ``version_tag`` starts with '-' (would be read as a git option).
        """
        if version_tag.startswith('-'):
            raise ValueError(f"Invalid version tag: {version_tag!r}")
        if not self._run(['git', 'checkout', version_tag]):
            print(f"git checkout {version_tag} failed; DVC data left untouched")
            return False
        ok = self._run(['dvc', 'checkout'])
        if ok:
            print(f"Checked out version {version_tag}")
        return ok

    def get_production_model(self, model_name: str):
        """Download the latest Production version of ``model_name``; return the local path."""
        client = mlflow.tracking.MlflowClient()
        versions = client.get_latest_versions(model_name, stages=['Production'])
        if not versions:
            raise ValueError(f"No production version found for {model_name}")
        version = versions[0]
        model_uri = f'models:/{model_name}/Production'
        local_path = mlflow.artifacts.download_artifacts(model_uri)
        print(f"Loaded production model {model_name} version {version.version}")
        return local_path

    def compare_runs(self, run_ids: list):
        """Return params, metrics and start time for each run id, in order."""
        client = mlflow.tracking.MlflowClient()
        comparison = []
        for run_id in run_ids:
            run = client.get_run(run_id)
            comparison.append({
                'run_id': run_id,
                'params': run.data.params,
                'metrics': run.data.metrics,
                'start_time': run.info.start_time
            })
        return comparison
def main():
    """Command-line entry point for the MLOps model-management demo.

    At most one action runs, checked in this order: --init, --train,
    --register, --checkout.
    """
    cli = argparse.ArgumentParser(description='MLOps Model Management')
    cli.add_argument('--init', action='store_true', help='Initialize repository')
    cli.add_argument('--train', action='store_true', help='Run training and log')
    cli.add_argument('--register', action='store_true', help='Register model')
    cli.add_argument('--checkout', action='store_true', help='Checkout version')
    cli.add_argument('--experiment', type=str, default='default')
    cli.add_argument('--model-name', type=str)
    cli.add_argument('--run-id', type=str)
    cli.add_argument('--version', type=str)
    cli.add_argument('--tracking-uri', default='http://localhost:5000')
    opts = cli.parse_args()

    manager = ModelVersionManager(opts.tracking_uri, opts.experiment)

    if opts.init:
        manager.init_repository()
        return

    if opts.train:
        # Demo "training": a tiny linear layer plus hard-coded metrics/params.
        import torch
        import torch.nn as nn
        demo_model = nn.Linear(10, 2)
        demo_metrics = {'accuracy': 0.95, 'loss': 0.1, 'f1_score': 0.94}
        demo_params = {'learning_rate': 0.001, 'batch_size': 32, 'epochs': 100}
        run_id = manager.log_experiment(
            demo_model, demo_metrics, demo_params,
            model_name=opts.model_name or 'demo_model'
        )
        print(f"Training completed. Run ID: {run_id}")
        return

    if opts.register:
        if not (opts.run_id and opts.model_name):
            raise ValueError("--run-id and --model-name required")
        manager.register_model(opts.model_name, opts.run_id)
        return

    if opts.checkout:
        if not opts.version:
            raise ValueError("--version required")
        manager.checkout_version(opts.version)


if __name__ == '__main__':
    main()
5.4.2 动态批量推理:Batching策略与吞吐优化
动态批处理通过聚合多个独立请求形成批次,提高GPU利用率与吞吐量。关键参数包括最大批次大小、最大等待延迟与动态填充策略。Triton Inference Server提供内置的动态批处理调度器,支持优先级排队与模型流水线。
Python
#!/usr/bin/env python3
"""
Script: dynamic_batching.py
Content: 动态批量推理引擎与吞吐优化
Usage:
1. 标准推理: python dynamic_batching.py --mode static --input ./images/
2. 动态批处理: python dynamic_batching.py --mode dynamic --max-batch-size 8 --max-latency-ms 50
3. 基准测试: python dynamic_batching.py --benchmark --batch-sizes 1,2,4,8,16
"""
import os
import time
import argparse
import numpy as np
import cv2
import torch
from pathlib import Path
from typing import List, Callable, Optional
from dataclasses import dataclass
from collections import deque
import threading
from queue import Queue
@dataclass
class InferenceRequest:
    """A single queued inference request.

    Positional field order: id, image, callback, timestamp, priority.
    The batcher compares ``timestamp`` against ``time.perf_counter()``, so it
    should be taken from that clock. A ``priority`` greater than 0 puts the
    request at the front of the queue.
    """
    id: str                # caller-chosen request identifier
    image: np.ndarray      # raw input image
    callback: Callable     # invoked as callback(result, latency_ms)
    timestamp: float       # submission time, from time.perf_counter()
    priority: int = 0      # 0 = normal, 1 = high
class DynamicBatcher:
    """
    Dynamic batching scheduler for inference requests.

    A background worker pulls requests off a queue and groups them into
    batches. A batch is dispatched as soon as it holds ``max_batch_size``
    requests, or once its oldest request has waited ``max_latency_ms`` since
    submission, whichever comes first. High-priority requests
    (``priority > 0``) are inserted at the front of the queue. The batch
    dimension can optionally be padded to a multiple of ``pad_to_multiple_of``
    for models that require it.

    Each request's callback is invoked as ``callback(result, latency_ms)``.
    ``latency_ms`` is the end-to-end time from the request's ``timestamp``
    (queue wait plus batch processing).

    Args:
        inference_fn: callable applied to the preprocessed batch tensor.
        max_batch_size: maximum number of requests per batch.
        max_latency_ms: maximum queueing delay for the oldest request in a batch.
        pad_to_multiple_of: pad the batch dimension to this multiple (1 = off).
        device: device the batch tensor is moved to. Defaults to ``'cuda'``
            when CUDA is available, otherwise ``'cpu'``.
    """

    def __init__(self,
                 inference_fn: Callable,
                 max_batch_size: int = 8,
                 max_latency_ms: float = 50.0,
                 pad_to_multiple_of: int = 1,
                 device: Optional[str] = None):
        self.inference_fn = inference_fn
        self.max_batch_size = max_batch_size
        self.max_latency_ms = max_latency_ms
        self.pad_to_multiple_of = pad_to_multiple_of
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        # Pending requests; high priority at the left end.
        self.request_queue = deque()
        self.queue_lock = threading.Lock()
        # Result queue (not used internally; kept for callers).
        self.result_queue = Queue()
        # Control flags
        self.running = False
        self.worker_thread = None

    def start(self):
        """Start the batching worker thread."""
        self.running = True
        self.worker_thread = threading.Thread(target=self._batch_worker)
        self.worker_thread.start()

    def stop(self):
        """Stop the worker after it has processed every request still queued."""
        self.running = False
        if self.worker_thread:
            self.worker_thread.join()

    def submit(self, request: "InferenceRequest"):
        """Enqueue a request; ``priority > 0`` goes to the front of the queue."""
        with self.queue_lock:
            if request.priority > 0:
                self.request_queue.appendleft(request)
            else:
                self.request_queue.append(request)

    def _batch_worker(self):
        """Worker loop that collects batches, runs them, and fires the callbacks.

        Exits only once ``stop()`` has been called AND the queue is empty. This
        means requests submitted before ``stop()`` still get their callbacks.
        """
        while True:
            batch = self._collect_batch()
            if not batch:
                if not self.running:
                    break
                time.sleep(0.001)  # avoid busy waiting on an empty queue
                continue
            try:
                results = self._process_batch(batch)
            except Exception as exc:
                # Keep the worker alive: a dead worker would silently stall
                # every later request.
                print(f"Batch inference failed ({len(batch)} requests): {exc}")
                continue
            finished = time.perf_counter()
            for req, result in zip(batch, results):
                # End-to-end latency of this request (queue wait + processing).
                req.callback(result, (finished - req.timestamp) * 1000)

    def _collect_batch(self) -> List["InferenceRequest"]:
        """Gather up to ``max_batch_size`` requests for one batch.

        Returns ``[]`` immediately when the queue is empty. Otherwise it takes
        queued requests until the batch is full. If the queue runs dry first,
        it polls in 1 ms steps for more until the oldest request in the batch
        has waited ``max_latency_ms`` (measured from its ``timestamp``).
        """
        batch = []
        deadline = None
        while len(batch) < self.max_batch_size:
            with self.queue_lock:
                req = self.request_queue.popleft() if self.request_queue else None
            if req is not None:
                batch.append(req)
                req_deadline = req.timestamp + self.max_latency_ms / 1000.0
                # Priority requests may be newer than later ones: track the oldest.
                if deadline is None or req_deadline < deadline:
                    deadline = req_deadline
                continue
            if not batch:
                break
            remaining = deadline - time.perf_counter()
            if remaining <= 0:
                break
            time.sleep(min(remaining, 0.001))
        return batch

    def _process_batch(self, requests: List["InferenceRequest"]):
        """Run one batch through ``inference_fn`` and split the output per request."""
        images = [req.image for req in requests]
        processed = self._preprocess_batch(images)
        with torch.no_grad():
            outputs = self.inference_fn(processed)
        return self._postprocess_split(outputs, len(requests))

    def _preprocess_batch(self, images: List[np.ndarray]) -> torch.Tensor:
        """Letterbox each image to 640x640, scale to [0, 1], and stack into an NCHW float batch.

        Assumes 3-channel HxWxC uint8 images (e.g. from ``cv2.imread``); the
        padding value is 114. The batch is padded with zero images to a
        multiple of ``pad_to_multiple_of`` and moved to ``self.device``.
        """
        target_w, target_h = 640, 640
        batch_tensors = []
        for img in images:
            h, w = img.shape[:2]
            scale = min(target_w / w, target_h / h)
            new_w, new_h = int(w * scale), int(h * scale)
            resized = cv2.resize(img, (new_w, new_h))
            canvas = np.full((target_h, target_w, 3), 114, dtype=np.uint8)
            pad_x = (target_w - new_w) // 2
            pad_y = (target_h - new_h) // 2
            canvas[pad_y:pad_y + new_h, pad_x:pad_x + new_w] = resized
            tensor = torch.from_numpy(canvas).permute(2, 0, 1).float() / 255.0
            batch_tensors.append(tensor)
        batch = torch.stack(batch_tensors)
        if self.pad_to_multiple_of > 1:
            pad_size = (-len(batch)) % self.pad_to_multiple_of
            if pad_size > 0:
                dummy = torch.zeros(pad_size, *batch.shape[1:])
                batch = torch.cat([batch, dummy], dim=0)
        return batch.to(self.device)

    def _postprocess_split(self, outputs, num_requests: int):
        """Split batched outputs into one result per request (padding rows are dropped)."""
        if isinstance(outputs, torch.Tensor):
            return [outputs[i:i + 1] for i in range(num_requests)]
        if isinstance(outputs, (list, tuple)):
            return [tuple(o[i] for o in outputs) for i in range(num_requests)]
        return [outputs] * num_requests
class ThroughputBenchmark:
    """
    Throughput benchmark comparing static batching against dynamic batching.

    Args:
        model: inference callable. ``benchmark_static`` calls it directly with a
            Python list of raw images. ``benchmark_dynamic`` hands it to a
            ``DynamicBatcher``, which calls it on preprocessed batch tensors.
    """

    def __init__(self, model):
        self.model = model
        self.results = []

    @staticmethod
    def _sync():
        """Wait for queued CUDA work so timings include GPU time (no-op without CUDA)."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def benchmark_static(self, images: List[np.ndarray], batch_size: int):
        """Time fixed-size batches; the last batch is padded with ``images[0]``.

        Returns:
            dict with ``mode``, ``batch_size``, ``throughput`` (images/s) and
            ``mean_latency`` (ms per batch). Both numbers are 0.0 for empty input.
        """
        timings = []
        for i in range(0, len(images), batch_size):
            batch = images[i:i + batch_size]  # slice copy: padding won't touch `images`
            while len(batch) < batch_size:
                batch.append(images[0])
            start = time.perf_counter()
            with torch.no_grad():
                _ = self.model(batch)
            self._sync()
            timings.append((time.perf_counter() - start) * 1000)
        total_s = sum(timings) / 1000
        throughput = len(images) / total_s if total_s > 0 else 0.0
        return {
            'mode': 'static',
            'batch_size': batch_size,
            'throughput': throughput,
            'mean_latency': float(np.mean(timings)) if timings else 0.0
        }

    def benchmark_dynamic(self, images: List[np.ndarray],
                          max_batch: int, max_latency_ms: float):
        """Submit all images to a ``DynamicBatcher`` at once and time until all complete.

        Waits at most 60 s. ``throughput`` counts only completed requests, and
        ``completed`` reports how many finished (fewer than ``len(images)``
        means the wait timed out).
        """
        batcher = DynamicBatcher(
            self.model,
            max_batch_size=max_batch,
            max_latency_ms=max_latency_ms
        )
        latencies = []
        all_done = threading.Event()
        if not images:
            all_done.set()

        def callback(result, latency):
            # Called from the single batcher worker thread.
            latencies.append(latency)
            if len(latencies) >= len(images):
                all_done.set()

        batcher.start()
        try:
            start_time = time.perf_counter()
            for i, img in enumerate(images):
                batcher.submit(InferenceRequest(
                    id=f'req_{i}',
                    image=img,
                    callback=callback,
                    timestamp=time.perf_counter()
                ))
            all_done.wait(timeout=60)
            total_time = time.perf_counter() - start_time
        finally:
            batcher.stop()

        completed = len(latencies)
        return {
            'mode': 'dynamic',
            'max_batch': max_batch,
            'max_latency_ms': max_latency_ms,
            'throughput': completed / total_time if total_time > 0 else 0.0,
            'mean_latency': float(np.mean(latencies)) if latencies else 0,
            'completed': completed
        }
def main():
    """CLI entry point: benchmark batching strategies, or run batched inference on a directory.

    The model is a placeholder that sleeps 10 ms per call; replace it with a
    real model. CUDA is used when available, otherwise everything runs on CPU.
    """
    parser = argparse.ArgumentParser(description='Dynamic Batching')
    parser.add_argument('--mode', choices=['static', 'dynamic'], default='dynamic')
    parser.add_argument('--max-batch-size', type=int, default=8)
    parser.add_argument('--max-latency-ms', type=float, default=50.0)
    parser.add_argument('--input', type=str, help='Input image directory')
    parser.add_argument('--benchmark', action='store_true')
    parser.add_argument('--batch-sizes', type=str, default='1,2,4,8')
    args = parser.parse_args()

    use_cuda = torch.cuda.is_available()

    # Placeholder model (replace with a real model in practice).
    class DummyModel(torch.nn.Module):
        def forward(self, x):
            if use_cuda:
                torch.cuda.synchronize()
            time.sleep(0.01)  # simulate 10 ms of compute
            return x

    model = DummyModel().cuda() if use_cuda else DummyModel()

    if args.benchmark:
        dummy_images = [np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
                        for _ in range(100)]
        benchmark = ThroughputBenchmark(model)
        print("Benchmarking...")
        for bs in [int(x) for x in args.batch_sizes.split(',')]:
            result = benchmark.benchmark_static(dummy_images, bs)
            print(f"Static BS={bs}: {result['throughput']:.1f} imgs/sec")
        for max_lat in [10, 25, 50]:
            result = benchmark.benchmark_dynamic(
                dummy_images, args.max_batch_size, max_lat
            )
            print(f"Dynamic (max_lat={max_lat}ms): {result['throughput']:.1f} imgs/sec")
        return

    if not args.input:
        parser.print_help()
        return

    # Load images, skipping files OpenCV cannot decode (imread returns None).
    images = []
    for path in sorted(Path(args.input).glob('*.jpg')):
        img = cv2.imread(str(path))
        if img is None:
            print(f"Skipping unreadable image: {path}")
            continue
        images.append(img)

    if args.mode == 'dynamic':
        batcher = DynamicBatcher(model, args.max_batch_size, args.max_latency_ms)
        all_done = threading.Event()
        received = [0]
        count_lock = threading.Lock()
        if not images:
            all_done.set()

        def print_callback(result, latency):
            print(f"Result received, latency: {latency:.2f}ms")
            with count_lock:
                received[0] += 1
                if received[0] >= len(images):
                    all_done.set()

        batcher.start()
        try:
            for i, img in enumerate(images):
                batcher.submit(InferenceRequest(
                    id=f'img_{i}',
                    image=img,
                    callback=print_callback,
                    timestamp=time.perf_counter()
                ))
            # Wait for every result rather than sleeping a fixed amount of time.
            if not all_done.wait(timeout=60):
                print(f"Timed out with {received[0]}/{len(images)} results received")
        finally:
            batcher.stop()
    else:
        # Static batching: fixed-size slices fed straight to the model.
        with torch.no_grad():
            for i in range(0, len(images), args.max_batch_size):
                model(images[i:i + args.max_batch_size])


if __name__ == '__main__':
    main()
5.4.3 A/B测试框架:影子模式与灰度发布
影子模式(Shadow Mode)将生产流量复制到候选模型而不影响用户,用于无风险验证。灰度发布(Canary Release)逐步将流量从旧模型切换到新模型,支持快速回滚。实现需关注流量复制、延迟监控与一致性校验。
Python
#!/usr/bin/env python3
"""
Script: ab_testing_framework.py
Content: A/B测试框架,支持影子模式与灰度发布
Usage:
1. 影子模式: python ab_testing_framework.py --mode shadow --model-a prod.pt --model-b candidate.pt --traffic-ratio 0.1
2. 灰度发布: python ab_testing_framework.py --mode canary --model-a prod.pt --model-b candidate.pt --canary-percent 10
3. 对比分析: python ab_testing_framework.py --analyze --logs shadow_logs.jsonl
"""
import os
import json
import time
import hashlib
import argparse
import numpy as np
import cv2
import torch
from pathlib import Path
from typing import Dict, List, Callable, Optional
from dataclasses import dataclass, asdict
from datetime import datetime
import threading
from collections import defaultdict
@dataclass
class InferenceResult:
    """Record of one model's inference on one request.

    Positional field order: model_name, request_id, predictions, latency_ms,
    timestamp, metadata.
    """
    model_name: str          # which model produced the result
    request_id: str          # identifier of the originating request
    predictions: List[Dict]  # one dict per prediction
    latency_ms: float        # inference time in milliseconds
    timestamp: str           # presumably an ISO-8601 string; confirm with the producer
    metadata: Dict           # free-form extra information
class ShadowModeTester:
    """
    Shadow-mode tester that mirrors a share of production traffic to a candidate model.

    The production model (``model_a``) answers every request synchronously.
    For requests selected by ``should_shadow``, the candidate (``model_b``) runs
    on a background thread. A comparison record is then appended as one JSON
    line to ``log_path``, containing both latencies, the box-overlap
    consistency and both raw results. The candidate's output never reaches
    the caller, and shadow failures are logged to stdout and dropped.

    Args:
        model_a: production model callable ``image -> result``.
        model_b: candidate model callable ``image -> result``.
        traffic_ratio: fraction (0-1) of requests mirrored, at 1% granularity.
        log_path: JSON-lines log file, opened in append mode.
    """

    def __init__(self,
                 model_a: Callable,
                 model_b: Callable,
                 traffic_ratio: float = 0.1,
                 log_path: str = 'shadow_logs.jsonl'):
        self.model_a = model_a
        self.model_b = model_b
        self.traffic_ratio = traffic_ratio
        self.log_path = log_path
        self.logger = open(log_path, 'a', encoding='utf-8')
        self.lock = threading.Lock()
        # Shadow threads that may still be running; close() joins them.
        self._threads = []

    def should_shadow(self, request_id: str) -> bool:
        """Decide deterministically whether ``request_id`` is mirrored to the candidate."""
        # MD5 of the ID: the same request ID always gets the same decision.
        hash_val = int(hashlib.md5(request_id.encode()).hexdigest(), 16)
        return (hash_val % 100) < (self.traffic_ratio * 100)

    def process(self, request_id: str, image: np.ndarray) -> Dict:
        """
        Serve one request from the production model, and mirror it to the candidate if selected.

        Returns:
            ``{'request_id', 'result', 'latency_ms', 'shadow': False}``. ``result``
            and ``latency_ms`` always come from ``model_a``.
        """
        start = time.perf_counter()
        result_a = self.model_a(image)
        latency_a = (time.perf_counter() - start) * 1000
        response = {
            'request_id': request_id,
            'result': result_a,
            'latency_ms': latency_a,
            'shadow': False
        }
        if self.should_shadow(request_id):
            worker = threading.Thread(
                target=self._shadow_inference,
                args=(request_id, image, result_a, latency_a)
            )
            with self.lock:
                # Drop finished threads, then register and start the new one
                # while holding the lock so close() can never miss it.
                self._threads = [t for t in self._threads if t.is_alive()]
                self._threads.append(worker)
                worker.start()
        return response

    def _shadow_inference(self, request_id: str, image: np.ndarray, result_a,
                          latency_a: float = None):
        """Run the candidate, compare it against ``result_a``, and append one JSON log line."""
        try:
            start = time.perf_counter()
            result_b = self.model_b(image)
            latency_b = (time.perf_counter() - start) * 1000
            log_entry = {
                'timestamp': datetime.now().isoformat(),
                'request_id': request_id,
                'model_a_latency': latency_a,
                'model_b_latency': latency_b,
                'consistency': self._compute_consistency(result_a, result_b),
                'result_a': result_a,
                'result_b': result_b
            }
            line = json.dumps(log_entry, default=self._json_default)
        except Exception as exc:
            # Shadow traffic must never affect production: report and drop.
            print(f"Shadow inference failed for {request_id}: {exc}")
            return
        with self.lock:
            if not self.logger.closed:
                self.logger.write(line + '\n')
                self.logger.flush()

    @staticmethod
    def _json_default(obj):
        """JSON fallback for array-like results (objects exposing ``tolist()``)."""
        if hasattr(obj, 'tolist'):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    @staticmethod
    def _extract_box(det):
        """Return ``[x1, y1, x2, y2]`` from a dict with ``'bbox'`` or from a sequence's first 4 items."""
        if isinstance(det, dict):
            return det.get('bbox', [])
        try:
            return list(det)[:4]
        except TypeError:
            return []

    def _compute_consistency(self, result_a, result_b) -> float:
        """Compute the mean, over A's detections, of the best IoU against any of B's detections.

        Detections may be dicts with a ``'bbox'`` key or rows such as those
        from YOLO ``boxes.data.tolist()`` (``x1, y1, x2, y2, ...``).
        Returns 0.0 when either result is empty or None.
        """
        if result_a is None or result_b is None or len(result_a) == 0 or len(result_b) == 0:
            return 0.0
        boxes_b = [self._extract_box(det) for det in result_b]
        ious = [
            max((self._compute_iou(self._extract_box(det), box_b) for box_b in boxes_b), default=0.0)
            for det in result_a
        ]
        return float(np.mean(ious)) if ious else 0.0

    def _compute_iou(self, box_a, box_b):
        """Compute the IoU of two ``[x1, y1, x2, y2]`` boxes (0.0 if either has fewer than 4 values)."""
        if len(box_a) < 4 or len(box_b) < 4:
            return 0.0
        x1 = max(box_a[0], box_b[0])
        y1 = max(box_a[1], box_b[1])
        x2 = min(box_a[2], box_b[2])
        y2 = min(box_a[3], box_b[3])
        inter = max(0, x2 - x1) * max(0, y2 - y1)
        area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
        area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
        return inter / (area_a + area_b - inter + 1e-6)

    def close(self):
        """Wait for pending shadow threads, then close the log file."""
        with self.lock:
            pending = list(self._threads)
        for thread in pending:
            thread.join()
        with self.lock:
            self.logger.close()
class CanaryReleaser:
    """
    Canary release controller that gradually shifts traffic from model A (old) to model B (new).

    Each request ID is MD5-hashed into one of 100 buckets, and buckets below
    ``current_percent`` go to B. For a fixed percentage, a given request ID
    therefore always hits the same model. ``promote`` raises the percentage by
    ``step_percent`` (capped at ``max_percent``) unless B's error rate exceeds
    ``error_threshold``. ``rollback`` sends all traffic back to A.

    Args:
        model_a: old model callable ``image -> result``.
        model_b: new model callable ``image -> result``.
        initial_percent: starting share of traffic (0-100) routed to B.
        max_percent: upper bound for ``current_percent``.
        step_percent: increment applied by each successful ``promote``.
        error_threshold: maximum tolerated error rate (0-1) for B.
        min_samples: the error-rate health check is only applied once B has
            served more than this many requests. Below that, ``promote``
            proceeds unchecked.
    """

    def __init__(self,
                 model_a: Callable,
                 model_b: Callable,
                 initial_percent: float = 5.0,
                 max_percent: float = 100.0,
                 step_percent: float = 5.0,
                 error_threshold: float = 0.05,
                 min_samples: int = 100):
        self.model_a = model_a
        self.model_b = model_b
        self.current_percent = initial_percent
        self.max_percent = max_percent
        self.step_percent = step_percent
        self.error_threshold = error_threshold
        self.min_samples = min_samples
        # NOTE: the latency lists grow for the lifetime of the releaser.
        self.stats = {
            'a': {'requests': 0, 'errors': 0, 'latency': []},
            'b': {'requests': 0, 'errors': 0, 'latency': []}
        }
        self.lock = threading.Lock()

    def route(self, request_id: str, image: np.ndarray) -> Dict:
        """
        Route one request to model A or B according to the current canary percentage.

        Model exceptions are caught and counted as errors.

        Returns:
            ``{'result', 'model': 'a'|'b', 'latency_ms', 'error'}``. ``result``
            is None on error.
        """
        hash_val = int(hashlib.md5(request_id.encode()).hexdigest(), 16)
        with self.lock:
            percent = self.current_percent
        use_b = (hash_val % 100) < percent
        model = self.model_b if use_b else self.model_a
        model_key = 'b' if use_b else 'a'

        start = time.perf_counter()
        try:
            result = model(image)
            error = False
        except Exception as e:
            result = None
            error = True
            print(f"Error in model {model_key}: {e}")
        latency = (time.perf_counter() - start) * 1000

        with self.lock:
            bucket = self.stats[model_key]
            bucket['requests'] += 1
            if error:
                bucket['errors'] += 1
            bucket['latency'].append(latency)

        return {
            'result': result,
            'model': model_key,
            'latency_ms': latency,
            'error': error
        }

    def promote(self):
        """Raise the canary percentage by one step if B passes the health check.

        Returns:
            True if promoted, False if B's error rate exceeded ``error_threshold``.
        """
        with self.lock:
            b_stats = self.stats['b']
            if b_stats['requests'] > self.min_samples:
                error_rate = b_stats['errors'] / b_stats['requests']
                if error_rate > self.error_threshold:
                    print(f"Health check failed: error rate {error_rate:.2%}")
                    return False
            self.current_percent = min(self.current_percent + self.step_percent, self.max_percent)
            new_percent = self.current_percent
        print(f"Promoted to {new_percent}%")
        return True

    def rollback(self):
        """Send all traffic back to model A (canary percentage 0)."""
        with self.lock:
            self.current_percent = 0
        print("Rolled back to 0%")

    def get_stats(self) -> Dict:
        """Return per-model request count, error rate, and mean/p99 latency (ms)."""
        with self.lock:
            stats = {}
            for key in ['a', 'b']:
                data = self.stats[key]
                stats[key] = {
                    'requests': data['requests'],
                    'error_rate': data['errors'] / max(data['requests'], 1),
                    'mean_latency': np.mean(data['latency']) if data['latency'] else 0,
                    'p99_latency': np.percentile(data['latency'], 99) if data['latency'] else 0
                }
            return stats
class ABTestAnalyzer:
    """
    Offline analyzer for shadow-mode logs.

    Reads a JSON-lines log in which each object describes one request with
    ``model_a_latency`` and ``model_b_latency`` (and optionally
    ``consistency``). It summarizes the latency difference between the two
    models. Both latencies are measured on the same request, so the samples
    are paired and compared with a paired t-test.
    """

    def __init__(self, log_path: str):
        self.log_path = log_path
        self.data = self._load_logs()

    def _load_logs(self) -> List[Dict]:
        """Parse the JSON-lines file, skipping blank lines."""
        logs = []
        with open(self.log_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    logs.append(json.loads(line))
        return logs

    def analyze(self) -> Dict:
        """Summarize the A/B latency comparison.

        Only entries with numeric latencies for both models are used. All
        returned values are plain Python types, so the dict is JSON-serializable.
        With fewer than two samples the t statistic and p-value are NaN and
        ``significant`` is False. With no samples the statistics are None.
        """
        paired = [
            log for log in self.data
            if isinstance(log.get('model_a_latency'), (int, float))
            and isinstance(log.get('model_b_latency'), (int, float))
        ]
        if not paired:
            return {
                'sample_size': 0,
                'latency_a_mean': None,
                'latency_b_mean': None,
                'latency_diff': None,
                'latency_diff_percent': None,
                'consistency_mean': None,
                't_statistic': None,
                'p_value': None,
                'significant': False
            }

        latencies_a = np.array([log['model_a_latency'] for log in paired], dtype=float)
        latencies_b = np.array([log['model_b_latency'] for log in paired], dtype=float)
        consistencies = [float(log.get('consistency') or 0.0) for log in paired]
        mean_a = float(latencies_a.mean())
        mean_b = float(latencies_b.mean())

        if len(paired) >= 2:
            from scipy import stats
            # Paired test: each row holds both models' latency for the same request.
            t_stat, p_value = stats.ttest_rel(latencies_a, latencies_b)
            t_stat, p_value = float(t_stat), float(p_value)
        else:
            t_stat = p_value = float('nan')

        return {
            'sample_size': len(paired),
            'latency_a_mean': mean_a,
            'latency_b_mean': mean_b,
            'latency_diff': mean_b - mean_a,
            'latency_diff_percent': (mean_b - mean_a) / mean_a * 100 if mean_a else None,
            'consistency_mean': float(np.mean(consistencies)),
            't_statistic': t_stat,
            'p_value': p_value,
            'significant': bool(p_value < 0.05)  # False when p_value is NaN
        }
def main():
    """CLI entry point: shadow-mode test, canary rollout, or offline log analysis.

    ``--analyze`` is accepted as a shortcut for ``--mode analyze`` (the form
    used in the script's usage text).
    """
    parser = argparse.ArgumentParser(description='A/B Testing Framework')
    parser.add_argument('--mode', choices=['shadow', 'canary', 'analyze'])
    parser.add_argument('--analyze', action='store_true', help='Shortcut for --mode analyze')
    parser.add_argument('--model-a', type=str, help='Model A path (production)')
    parser.add_argument('--model-b', type=str, help='Model B path (candidate)')
    parser.add_argument('--traffic-ratio', type=float, default=0.1)
    parser.add_argument('--canary-percent', type=float, default=10.0)
    parser.add_argument('--logs', type=str, default='shadow_logs.jsonl')
    parser.add_argument('--input', type=str, help='Input image for testing')
    args = parser.parse_args()

    mode = 'analyze' if args.analyze else args.mode
    if mode is None:
        parser.error('specify --mode {shadow,canary,analyze} or --analyze')

    if mode == 'analyze':
        def _jsonable(obj):
            # Unwrap numpy scalars (e.g. np.bool_), which json cannot encode.
            if hasattr(obj, 'item'):
                return obj.item()
            raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

        analyzer = ABTestAnalyzer(args.logs)
        results = analyzer.analyze()
        print(json.dumps(results, indent=2, default=_jsonable))
        return

    if not (args.model_a and args.model_b):
        parser.error('--model-a and --model-b are required for shadow and canary modes')

    image = None
    if args.input:
        image = cv2.imread(args.input)
        if image is None:
            parser.error(f'cannot read input image: {args.input}')

    # Load models (example uses YOLO).
    from ultralytics import YOLO
    model_a = YOLO(args.model_a)
    model_b = YOLO(args.model_b)
    # In shadow mode, model B can be called from several shadow threads at once.
    b_lock = threading.Lock()

    def predict_a(img):
        return model_a(img)[0].boxes.data.tolist()

    def predict_b(img):
        with b_lock:
            return model_b(img)[0].boxes.data.tolist()

    if image is None:
        print('No --input image given: models loaded, no traffic simulated')

    if mode == 'shadow':
        tester = ShadowModeTester(predict_a, predict_b, args.traffic_ratio, args.logs)
        try:
            # Simulated production traffic
            if image is not None:
                for i in range(100):
                    request_id = f'req_{i}_{time.time()}'
                    result = tester.process(request_id, image)
                    print(f"Request {request_id}: latency={result['latency_ms']:.2f}ms")
                    time.sleep(0.1)
        finally:
            tester.close()

    elif mode == 'canary':
        releaser = CanaryReleaser(predict_a, predict_b, initial_percent=args.canary_percent)
        if image is not None:
            for i in range(200):
                request_id = f'req_{i}'
                releaser.route(request_id, image)
                if i > 0 and i % 50 == 0:
                    stats = releaser.get_stats()
                    print(f"\nStats at iteration {i}:")
                    print(f"  Model A: {stats['a']['requests']} req, "
                          f"{stats['a']['error_rate']:.2%} error, "
                          f"{stats['a']['mean_latency']:.2f}ms")
                    print(f"  Model B: {stats['b']['requests']} req, "
                          f"{stats['b']['error_rate']:.2%} error, "
                          f"{stats['b']['mean_latency']:.2f}ms")
                    # Automatic promotion / rollback
                    if stats['b']['error_rate'] < 0.01:
                        releaser.promote()
                    else:
                        releaser.rollback()
                        break


if __name__ == '__main__':
    main()
5.4.4 模型热更新:无停机部署与权重切换
模型热更新(Hot Swapping)允许在不重启服务的情况下切换模型权重,实现零停机部署。关键技术包括双缓冲模型实例、原子指针交换与请求排空(draining)。实现需确保新模型加载完成后再切换流量,并优雅处理进行中的请求。
Python
#!/usr/bin/env python3
"""
Script: model_hot_swap.py
Content: 模型热更新系统,支持无停机部署与权重切换
Usage:
1. 启动服务: python model_hot_swap.py --serve --model models/yolov8n.pt --port 8000
2. 热更新: python model_hot_swap.py --update --new-model models/yolov8s.pt --endpoint http://localhost:8000
3. 状态检查: curl http://localhost:8000/health
"""
import os
import time
import json
import argparse
import threading
import numpy as np
import cv2
import torch
from pathlib import Path
from typing import Optional, Dict
from http.server import HTTPServer, BaseHTTPRequestHandler
from queue import Queue, Empty
import io
import base64
class HotSwappableModel:
    """
    Hot-swappable model wrapper with an active slot and a standby slot.

    ``predict`` always uses the currently active model. ``load_standby`` loads
    a replacement into the standby slot while the active model keeps serving.
    ``swap`` then makes it active under a lock. Requests already running on
    the old model finish on it, and the old model is released only after
    those in-flight requests have drained.

    Args:
        initial_model_path: path passed to ``model_loader`` for the first model;
            also used as the initial version label.
        model_loader: callable ``path -> model``; the model must be callable
            as ``model(image)``.
    """

    def __init__(self, initial_model_path: str, model_loader: callable):
        self.model_loader = model_loader
        self.active_model = model_loader(initial_model_path)
        self.standby_model = None
        self.standby_version = None
        self.switch_lock = threading.RLock()
        # Notified whenever a model's in-flight request count drops to zero.
        self._drained = threading.Condition(self.switch_lock)
        # id(model) -> number of predict() calls currently running on it.
        self._inflight = {}
        # Serializes load_standby() so concurrent updates cannot interleave.
        self._load_lock = threading.Lock()
        self.request_count = 0
        self.active_version = initial_model_path

    def predict(self, image: np.ndarray) -> Dict:
        """Run inference on the currently active model.

        Returns:
            ``{'result', 'version', 'latency_ms', 'request_count'}``, where
            ``request_count`` is this request's sequence number.
        """
        with self.switch_lock:
            model = self.active_model
            version = self.active_version
            self.request_count += 1
            request_number = self.request_count
            key = id(model)
            self._inflight[key] = self._inflight.get(key, 0) + 1
        try:
            start = time.perf_counter()
            result = model(image)
            latency = (time.perf_counter() - start) * 1000
        finally:
            # Release the in-flight slot even if inference raised.
            with self._drained:
                remaining = self._inflight[key] - 1
                if remaining:
                    self._inflight[key] = remaining
                else:
                    del self._inflight[key]
                    self._drained.notify_all()
        return {
            'result': result,
            'version': version,
            'latency_ms': latency,
            'request_count': request_number
        }

    def load_standby(self, new_model_path: str) -> bool:
        """
        Load a new model into the standby slot without interrupting service.

        Returns:
            True on success. False if ``model_loader`` raised; the failure is
            printed and the standby slot is left unchanged.
        """
        with self._load_lock:
            print(f"Loading standby model: {new_model_path}")
            try:
                model = self.model_loader(new_model_path)
            except Exception as e:
                print(f"Failed to load standby model: {e}")
                return False
            with self.switch_lock:
                self.standby_model = model
                self.standby_version = new_model_path
            print("Standby model loaded successfully")
            return True

    def swap(self) -> bool:
        """
        Atomically make the standby model active.

        In-flight requests keep using the old model; a background thread
        releases it once they have drained.

        Returns:
            False if no standby model is loaded, otherwise True.
        """
        with self.switch_lock:
            if self.standby_model is None:
                print("No standby model loaded")
                return False
            old_model = self.active_model
            self.active_model = self.standby_model
            self.active_version = self.standby_version
            self.standby_model = None
            self.standby_version = None
            new_version = self.active_version
        threading.Thread(target=self._cleanup_model, args=(old_model,), daemon=True).start()
        print(f"Model swapped to: {new_version}")
        return True

    def _cleanup_model(self, model, timeout: float = 30.0):
        """Release a retired model once no request is still running on it.

        Waits up to ``timeout`` seconds for in-flight ``predict`` calls on
        ``model`` to finish, drops this reference, and empties the CUDA cache.
        Other references (e.g. this thread's own arguments) may delay the
        actual release of memory.
        """
        key = id(model)
        with self._drained:
            drained = self._drained.wait_for(lambda: key not in self._inflight, timeout=timeout)
        if not drained:
            print("Timed out waiting for in-flight requests on the old model")
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("Old model cleaned up")

    def get_status(self) -> Dict:
        """Return the active version, whether a standby is loaded, total requests, and CUDA memory."""
        with self.switch_lock:
            status = {
                'active_version': self.active_version,
                'standby_loaded': self.standby_model is not None,
                'total_requests': self.request_count,
            }
        status['memory_allocated'] = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
        return status
class ModelUpdateHandler(BaseHTTPRequestHandler):
    """
    HTTP request handler for the hot-swap service.

    Routes:
        GET  /health   -> status dict from ``server.model_wrapper.get_status()``
        POST /predict  -> JSON body ``{"image": <base64-encoded image file>}``
        POST /update   -> JSON body ``{"model_path": <path>}``; loads, then swaps

    Expects the owning server to set ``self.server.model_wrapper``. Malformed
    requests get 400; model failures get 500.

    SECURITY NOTE(review): /update makes the server load an arbitrary path
    chosen by any client that can reach the port. Model checkpoints such as
    PyTorch ``.pt`` files are typically pickle-based, so loading an untrusted
    file can execute code. Restrict access (bind to localhost, authenticate,
    or whitelist paths) before exposing this endpoint.
    """

    def do_GET(self):
        if self.path == '/health':
            self._send_json(200, self.server.model_wrapper.get_status())
        else:
            self._send_error(404, "Not found")

    def do_POST(self):
        if self.path == '/predict':
            self._handle_predict()
        elif self.path == '/update':
            self._handle_update()
        else:
            self._send_error(404, "Not found")

    def _read_json_body(self) -> dict:
        """Read the request body and decode it as a JSON object.

        Raises:
            ValueError: missing/invalid Content-Length, empty body, invalid JSON,
                or a JSON value that is not an object.
        """
        length = int(self.headers.get('Content-Length') or 0)
        if length <= 0:
            raise ValueError("empty request body")
        data = json.loads(self.rfile.read(length))
        if not isinstance(data, dict):
            raise ValueError("JSON object expected")
        return data

    def _handle_predict(self):
        """Decode a base64 image from the JSON body and return the wrapper's prediction."""
        try:
            data = self._read_json_body()
            img_bytes = base64.b64decode(data['image'])
            image = cv2.imdecode(np.frombuffer(img_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
        except Exception as e:  # any decoding failure is a client error
            self._send_error(400, f"Invalid request: {e}")
            return
        if image is None:
            self._send_error(400, "Could not decode image")
            return
        try:
            result = self.server.model_wrapper.predict(image)
        except Exception as e:
            self._send_error(500, str(e))
            return
        self._send_json(200, result)

    def _handle_update(self):
        """Load ``model_path`` into the standby slot (synchronously), then swap it in."""
        try:
            data = self._read_json_body()
        except ValueError as e:
            self._send_error(400, f"Invalid request: {e}")
            return
        new_model_path = data.get('model_path')
        if not new_model_path:
            self._send_error(400, "model_path required")
            return
        wrapper = self.server.model_wrapper
        if not wrapper.load_standby(new_model_path):
            self._send_error(500, "Failed to load model")
            return
        swapped = wrapper.swap()
        self._send_json(200, {'success': swapped, 'new_version': new_model_path})

    def _send_json(self, code: int, data: dict):
        """Send ``data`` as a JSON response.

        The body is serialized BEFORE any header is written, so an
        unserializable payload becomes a clean 500 instead of a half-written
        response.
        """
        try:
            body = json.dumps(data).encode()
        except (TypeError, ValueError) as e:
            code = 500
            body = json.dumps({'error': f"Response not JSON-serializable: {e}"}).encode()
        self.send_response(code)
        self.send_header('Content-Type', 'application/json')
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _send_error(self, code: int, message: str):
        """Send ``{"error": message}`` with status ``code``."""
        self._send_json(code, {'error': message})

    def log_message(self, format, *args):
        # Silence per-request access logging.
        pass
class HotSwapServer:
    """
    Model hot-swap service.

    Serves ``/predict``, ``/health`` and ``/update`` over HTTP through
    ``ModelUpdateHandler``. A background thread watches the ``models``
    directory and swaps to any ``*.pt`` file modified after startup.

    Requests are handled on separate threads (``ThreadingHTTPServer``), so a
    slow ``/update`` (model load) does not block ``/predict`` traffic. Each
    YOLO model serializes its own inference calls with a lock, because a
    single Ultralytics model instance is not safe to share across threads.
    Predictions are returned as ``boxes.data.tolist()`` rows, which assumes a
    detection model.
    """

    def __init__(self, model_path: str, host='0.0.0.0', port=8000):
        from http.server import ThreadingHTTPServer
        self.host = host
        self.port = port

        def load_yolo(path):
            from ultralytics import YOLO
            yolo = YOLO(path)
            infer_lock = threading.Lock()

            def infer(image):
                with infer_lock:
                    results = yolo(image, verbose=False)
                # Plain lists keep /predict responses JSON-serializable.
                return results[0].boxes.data.tolist() if results else []

            return infer

        self.model_wrapper = HotSwappableModel(model_path, load_yolo)
        self.server = ThreadingHTTPServer((host, port), ModelUpdateHandler)
        self.server.daemon_threads = True
        self.server.model_wrapper = self.model_wrapper

    def start(self):
        """Start the directory watcher and serve HTTP until interrupted."""
        print(f"Server started at http://{self.host}:{self.port}")
        print(f"Initial model: {self.model_wrapper.active_version}")
        self.monitor_thread = threading.Thread(target=self._monitor, daemon=True)
        self.monitor_thread.start()
        try:
            self.server.serve_forever()
        except KeyboardInterrupt:
            print("\nShutting down...")
        finally:
            self.server.server_close()

    @staticmethod
    def _latest_model(watch_dir: Path):
        """Return ``(path, mtime)`` of the newest ``*.pt`` in ``watch_dir``, or ``(None, 0.0)``."""
        if not watch_dir.exists():
            return None, 0.0
        candidates = [(p.stat().st_mtime, p) for p in watch_dir.glob('*.pt')]
        if not candidates:
            return None, 0.0
        mtime, path = max(candidates, key=lambda c: c[0])
        return path, mtime

    def _monitor(self, watch_dir: str = 'models', interval: float = 10.0):
        """Poll ``watch_dir`` every ``interval`` seconds and hot-swap to newly modified models.

        Files already present at startup are treated as the baseline, so the
        model chosen on the command line is not replaced on the first poll.
        Filesystem errors (e.g. a file deleted mid-scan) are reported and the
        watcher keeps running.
        """
        watch_path = Path(watch_dir)
        try:
            last_mtime = self._latest_model(watch_path)[1]
        except OSError:
            last_mtime = 0.0
        while True:
            time.sleep(interval)
            try:
                latest, mtime = self._latest_model(watch_path)
            except OSError as e:
                print(f"Model watcher error: {e}")
                continue
            if latest is None or mtime <= last_mtime:
                continue
            last_mtime = mtime
            if str(latest) == self.model_wrapper.active_version:
                continue
            print(f"Detected new model: {latest}")
            if self.model_wrapper.load_standby(str(latest)):
                self.model_wrapper.swap()
def main():
    """CLI entry point: run the hot-swap server, or ask a running server to swap models.

    ``--model`` is only needed with ``--serve``, and ``--new-model`` only with
    ``--update``. This lets the documented
    ``--update --new-model ... --endpoint ...`` form work.
    """
    parser = argparse.ArgumentParser(description='Model Hot Swap')
    parser.add_argument('--serve', action='store_true', help='Start server')
    parser.add_argument('--model', type=str, help='Initial model path (required with --serve)')
    parser.add_argument('--host', default='0.0.0.0')
    parser.add_argument('--port', type=int, default=8000)
    parser.add_argument('--update', action='store_true', help='Trigger update')
    parser.add_argument('--new-model', type=str, help='New model path for update')
    parser.add_argument('--endpoint', type=str, default='http://localhost:8000')
    args = parser.parse_args()

    if args.serve:
        if not args.model:
            parser.error('--model is required with --serve')
        server = HotSwapServer(args.model, args.host, args.port)
        server.start()
    elif args.update:
        if not args.new_model:
            parser.error('--new-model is required with --update')
        import urllib.request
        import urllib.error
        data = json.dumps({'model_path': args.new_model}).encode()
        req = urllib.request.Request(
            f"{args.endpoint}/update",
            data=data,
            headers={'Content-Type': 'application/json'},
            method='POST'
        )
        try:
            with urllib.request.urlopen(req) as response:
                print(response.read().decode())
        except urllib.error.HTTPError as e:
            # HTTPError is a URLError subclass, so it must be caught first.
            print(f"Error: {e.code} - {e.read().decode()}")
        except urllib.error.URLError as e:
            print(f"Error: cannot reach {args.endpoint}: {e.reason}")
    else:
        parser.print_help()


if __name__ == '__main__':
    main()
以上代码构成了一套现代深度学习工程实践的参考实现,涵盖从模型部署优化到生产环境运维的全链路技术。每个脚本均包含使用说明与实现细节,可作为实际项目开发的起点;投入生产前,仍需结合具体环境进行测试与加固。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)