CANN 端侧部署实战:模型转换与服务化
CANN 端侧部署实战:模型转换与服务化
如何将训练好的模型快速部署到昇腾端侧设备?本文详解模型格式转换、端侧优化与服务化部署的完整流程。—
一、端侧部署概述
1.1 端侧部署的挑战
与数据中心训练不同,端侧部署面临独特的约束:算力受限(端侧设备如昇腾 310 的算力远低于 910)、内存有限、功耗敏感、延迟要求严苛。在昇腾 910 上能实时推理的模型,迁移到昇腾 310 可能需要进一步优化才能满足实时性要求。
端侧部署的核心挑战在于适配不同算子、减少内存占用、降低推理延迟。CANN 提供了完整的端侧工具链,从模型转换到推理引擎再到服务化部署,帮助开发者将训练模型高效部署到各种昇腾硬件。
1.2 端侧部署架构
┌──────────────────────────────────────────────────┐
│ CANN 端侧部署架构 │
├──────────────────────────────────────────────────┤
│ │
│ 训练侧(数据中心) │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ PyTorch │→│ ONNX │→│ ACL │ │
│ │ / TF │ │ 中间格式 │ │ .om 模型 │ │
│ └─────────┘ └─────────┘ └─────────┘ │
│ │ │ │ │
│ └──────────────┴────────────┘ │
│ 模型转换(ATC) │
│ │
│ 端侧(昇腾设备) │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ ACL │→│ 引擎 │→│ 服务化 │ │
│ │ 推理 │ │ 调度 │ │ 部署 │ │
│ └─────────┘ └─────────┘ └─────────┘ │
│ │
└──────────────────────────────────────────────────┘
1.3 端侧设备选型
| 设备 | 算力(FP16) | 显存 | 功耗 | 适用场景 |
|---|---|---|---|---|
| 昇腾 310 | 22 TFLOPS | 8 GB | 7W | 边缘推理、IPC、摄像头 |
| 昇腾 910 | 256 TFLOPS | 32 GB | 310W | 数据中心训练/推理 |
| 昇腾 310P | 70 TFLOPS | 12 GB | 25W | 边缘服务器、自动驾驶 |
二、模型格式转换
2.1 支持的输入格式
CANN 支持多种主流训练框架的模型格式:
| 框架 | 输入格式 | 说明 |
|---|---|---|
| PyTorch | .pt, .pth, .onnx | 需转换为 ONNX |
| TensorFlow | .pb, SavedModel | 冻结图或 SavedModel |
| ONNX | .onnx | 中间格式,推荐 |
| MindSpore | .mindir | 华为自研框架 |
转换路径优先级:
PyTorch → ONNX → .om(推荐路径)
TensorFlow → ONNX → .om
MindSpore → .mindir → .om
2.2 PyTorch 转 ONNX
PyTorch 模型需要先转换为 ONNX 格式,再通过 ATC 转为 .om:
8.1 及之前(简单 export):
import torch
model = MyModel()
model.eval()
# 简单 export(静态 shape)
torch.onnx.export(
model,
args=torch.randn(1, 3, 224, 224),
f="model.onnx",
input_names=["input"],
output_names=["output"],
dynamic_axes=None # 静态 shape
)
8.2 新增(动态 shape 支持):
import torch
model = MyModel()
model.eval()
# 动态 shape export(支持变长输入)
torch.onnx.export(
model,
args=torch.randn(1, 3, 224, 224),
f="model_dynamic.onnx",
input_names=["input"],
output_names=["output"],
dynamic_axes={
"input": {0: "batch_size", 2: "height", 3: "width"},
"output": {0: "batch_size"}
},
opset_version=13, # 建议使用 13 以上
keep_initializers_as_inputs=False
)
# 验证 ONNX 模型
import onnx
model_onnx = onnx.load("model_dynamic.onnx")
onnx.checker.check_model(model_onnx)
print(f"ONNX model validated successfully")
2.3 ONNX 模型简化
转换后的 ONNX 模型可能包含冗余节点,需要简化后再转换:
8.1 及之前(手动简化):
# 使用 onnx-simplifier(需手动调用)
python -m onnxsim model.onnx model_simplified.onnx
8.2 新增(ATC 内置简化):
# ATC 自动处理图简化
atc --model=model.onnx \
--framework=5 \
--output=model \
--input_shape="input:1,3,224,224" \
--input_format=NCHW \
--soc_version=Ascend310 \
--graph_optimize_mode=ENABLE # 启用内置图优化
2.4 ATC 模型转换
ATC(Ascend Tensor Compiler)是 CANN 的模型转换核心工具:
基础转换命令:
atc \
--model=/path/to/model.onnx \
--framework=5 \
--output=/path/to/model \
--input_shape="input:1,3,224,224" \
--input_format=NCHW \
--soc_version=Ascend310 \
--log=INFO
进阶转换命令(带优化):
atc \
--model=/path/to/model.onnx \
--framework=5 \
--output=/path/to/model \
--input_shape="input:1,3,-1,-1" \
--dynamic_dims="224,256;224,288;224,320" \
--input_format=NCHW \
--soc_version=Ascend310 \
--precision_mode=allow_fp32_to_fp16 \
--op_select_implmode=high_precision \
--output_type=FP16 \
--insert_batchnorm_mode=0 \
--graph_optimize_mode=GP_ENHANCE \
--enable_single_stream=1 \
--log=INFO
关键参数说明:
| 参数 | 说明 | 推荐值 |
|---|---|---|
| precision_mode | 精度模式 | allow_fp32_to_fp16 |
| op_select_implmode | 算子实现模式 | high_precision |
| graph_optimize_mode | 图优化级别 | GP_ENHANCE |
| enable_single_stream | 单流优化 | 1(推理推荐) |
| dynamic_dims | 动态 shape 候选 | 224,256;224,288 |
三、ACL 推理引擎
3.1 ACL 架构
ACL(Ascend Computing Language)是昇腾的推理运行时,提供了 C++ 和 Python 接口:
┌──────────────────────────────────────────────────┐
│ ACL 推理架构 │
├──────────────────────────────────────────────────┤
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ 模型 │→│ 引擎 │→│ 数据 │ │
│ │ 加载 │ │ 执行 │ │ 拷贝 │ │
│ └─────────┘ └─────────┘ └─────────┘ │
│ │ │ │ │
│ ┌────┴────────────┴────────────┴────┐ │
│ │ 昇腾硬件抽象层 │ │
│ └────────────────────────────────────┘ │
└──────────────────────────────────────────────────┘
3.2 ACL C++ 推理
基础推理流程:
#include "acl/acl.h"
#include "acl/ops/acl_dvpp.h"
// ACL 初始化
void InitACL() {
aclError ret;
// 初始化 Context
ret = acl.init();
ret = acl.set_device(0); // 选择设备 0
// 创建 Context
aclrtContext context;
ret = acl.rt.create_context(&context);
}
// 模型加载
aclmdlDescPtr modelDesc;
aclmdlDatasetPtr inputDataset;
aclmdlDatasetPtr outputDataset;
void LoadModel(const std::string& modelPath) {
uint32_t modelId;
size_t modelSize;
// 读取模型文件
std::ifstream file(modelPath, std::ios::binary);
file.seekg(0, std::ios::end);
modelSize = file.tellg();
file.seekg(0, std::ios::beg);
std::unique_ptr<char[]> modelData(new char[modelSize]);
file.read(modelData.get(), modelSize);
// 加载模型
aclmdlLoadFromMem(modelData.get(), modelSize, &modelId);
// 获取模型描述
modelDesc = acl.mdl.create_desc();
acl.mdl.get_desc(modelDesc, modelId);
}
// 推理执行
void Inference(void* inputData, void* outputData, size_t dataSize) {
// 准备输入
inputDataset = acl.mdl.create_dataset();
aclmdlDataBufferPtr inputBuffer = acl.mdl.create_data_buffer(inputData, dataSize);
acl.mdl.add_dataset_buffer(inputDataset, inputBuffer);
// 准备输出
outputDataset = acl.mdl.create_dataset();
aclmdlDataBufferPtr outputBuffer = acl.mdl.create_data_buffer(outputData, outputSize);
acl.mdl.add_dataset_buffer(outputDataset, outputBuffer);
// 执行推理
acl.mdl.execute(modelId, inputDataset, outputDataset);
// 释放数据
acl.mdl.destroy_data_buffer(inputBuffer);
acl.mdl.destroy_data_buffer(outputBuffer);
acl.mdl.destroy_dataset(inputDataset);
acl.mdl.destroy_dataset(outputDataset);
}
3.3 ACL Python 推理
Python 接口更易用,适合快速部署:
8.1 及之前(acl弟弟接口):
import acl
def init_acl():
ret = acl.init()
ret = acl.set_device(0)
context, ret = acl.rt.create_context()
stream, ret = acl.rt.create_stream()
return stream
def inference(model_path, input_data):
# 加载模型
model_id, ret = acl.mdl.load_from_file(model_path)
# 准备输入
input_dataset = acl.mdl.create_dataset()
input_buffer = acl.mdl.create_data_buffer(input_data)
acl.mdl.add_dataset_buffer(input_dataset, input_buffer)
# 执行
acl.mdl.execute(model_id, input_dataset, output_dataset)
return output_data
8.2 新增(高层接口,简化调用):
from ascend.acl import AclModel
# 一行代码加载模型
model = AclModel(model_path="/path/to/model.om")
# 自动类型转换
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
# 推理(自动管理输入输出)
output = model.predict(input_data)
# 输出即为 numpy 数组
print(f"Output shape: {output.shape}")
3.4 模型预热
首次推理会有初始化开销,需要预热:
from ascend.acl import AclModel
model = AclModel(model_path="/path/to/model.om")
# 预热(执行一次空推理)
warmup_input = np.random.randn(1, 3, 224, 224).astype(np.float32)
for _ in range(5):
_ = model.predict(warmup_input)
print("Model warmed up, ready for real inference")
四、服务化部署
4.1 推理服务架构
端侧推理服务化需要考虑:并发处理、延迟优化、资源管理:
┌──────────────────────────────────────────────────┐
│ 推理服务化架构 │
├──────────────────────────────────────────────────┤
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ HTTP │→│ 请求 │→│ 模型 │ │
│ │ 服务 │ │ 队列 │ │ 池 │ │
│ └─────────┘ └─────────┘ └─────────┘ │
│ │ │ │ │
│ ┌────┴────────────┴────────────┴────┐ │
│ │ 多模型并行推理 │ │
│ └────────────────────────────────────┘ │
└──────────────────────────────────────────────────┘
4.2 基于 Flask 的简单服务
基础 Flask 服务:
from flask import Flask, request, jsonify
import numpy as np
from ascend.acl import AclModel
from threading import Lock
app = Flask(__name__)
# 模型实例(线程安全)
model_lock = Lock()
model = None
def load_model():
global model
model = AclModel(model_path="/path/to/model.om")
# 预热
warmup = np.random.randn(1, 3, 224, 224).astype(np.float32)
for _ in range(3):
model.predict(warmup)
@app.route('/predict', methods=['POST'])
def predict():
if model is None:
return jsonify({"error": "Model not loaded"}), 500
# 解析输入
data = request.json
input_array = np.array(data['input'], dtype=np.float32)
# 推理
with model_lock:
output = model.predict(input_array)
# 返回结果
return jsonify({
"output": output.tolist(),
"shape": output.shape
})
@app.route('/health', methods=['GET'])
def health():
return jsonify({"status": "ok", "model_loaded": model is not None})
if __name__ == "__main__":
load_model()
app.run(host="0.0.0.0", port=8080, threaded=True)
4.3 异步批处理服务
高并发场景下,异步批处理可显著提升吞吐:
8.1 及之前(同步单请求):
# 每个请求单独处理,吞吐低
def predict_sync(data):
return model.predict(data)
8.2 新增(异步批处理):
import asyncio
from threading import Thread
from collections import deque
class AsyncBatchedInference:
def __init__(self, model_path, max_batch_size=32, max_wait_ms=50):
self.model = AclModel(model_path)
self.max_batch_size = max_batch_size
self.max_wait_ms = max_wait_ms
self.request_queue = deque()
self.results = {}
self.running = True
# 启动批处理线程
self.process_thread = Thread(target=self._process_loop)
self.process_thread.start()
def _process_loop(self):
"""后台批处理循环"""
while self.running:
batch = []
start_time = time.time()
# 收集请求直到 batch 满或超时
while len(batch) < self.max_batch_size:
if time.time() - start_time > self.max_wait_ms / 1000:
break
# 从队列取请求(最多等 5ms)
try:
request = self.request_queue.popleft()
batch.append(request)
except IndexError:
time.sleep(0.005)
if batch:
# 批量推理
inputs = np.stack([r['input'] for r in batch])
outputs = self.model.predict(inputs)
# 分发结果
for i, request in enumerate(batch):
request['future'].set_result(outputs[i])
async def predict_async(self, input_data):
"""异步预测接口"""
future = asyncio.Future()
self.request_queue.append({
'input': input_data,
'future': future
})
return await future
# 使用
inference_server = AsyncBatchedInference("/path/to/model.om")
async def handle_request(data):
result = await inference_server.predict_async(data)
return result
4.4 模型版本管理
生产环境中需要支持模型热更新:
class VersionedModelServer:
def __init__(self, model_dir="/models"):
self.model_dir = model_dir
self.current_version = None
self.models = {}
self.lock = Lock()
def load_version(self, version):
model_path = f"{self.model_dir}/model_v{version}.om"
self.models[version] = AclModel(model_path)
with self.lock:
if self.current_version is None:
self.current_version = version
print(f"Model version {version} loaded")
def switch_version(self, version):
if version not in self.models:
self.load_version(version)
with self.lock:
old_version = self.current_version
self.current_version = version
# 旧版本在所有请求处理完后可释放
return old_version
def predict(self, input_data):
with self.lock:
model = self.models[self.current_version]
return model.predict(input_data)
# 滚动更新示例
server = VersionedModelServer("/models")
server.load_version("1.0")
# 收到新版本后
old = server.switch_version("1.1") # 切换到新版本
# 旧版本资源可后续释放
五、性能优化
5.1 延迟优化
端侧推理延迟是关键指标,需要从多个层面优化:
内存拷贝优化:避免 CPU 和 NPU 之间的不必要数据拷贝:
# 8.1 及之前(频繁拷贝)
def inference_old(input_data):
# 输入从 CPU 拷到 NPU
input_npu = input_data.npu()
output = model.predict(input_npu)
# 输出从 NPU 拷回 CPU
return output.cpu()
# 8.2 新增(零拷贝)
def inference_new(input_data):
# 使用 pinned memory 和异步拷贝
input_npu = input_data.npu()
output = model.predict(input_npu)
return output # 直接返回 NPU tensor
batch 处理优化:适当增大 batch 可提升吞吐:
# 单请求延迟 vs 批处理吞吐
for batch_size in [1, 4, 8, 16]:
latencies = []
for _ in range(100):
start = time.time()
# 合并 batch 推理
output = model.predict(batch_input)
latencies.append(time.time() - start)
avg_latency = np.mean(latencies)
throughput = batch_size / avg_latency
print(f"Batch {batch_size}: latency={avg_latency:.3f}s, throughput={throughput:.1f} fps")
5.2 显存优化
端侧设备显存有限,需要精细管理:
# 使用 torch inference mode 减少显存占用
with torch.inference_mode():
output = model(input)
# 释放中间张量
del intermediate_tensor
# 强制垃圾回收
import gc
gc.collect()
torch.npu.empty_cache()
动态 shape 优化:合理设置动态维度范围避免预分配过多显存:
# 设置动态 shape 范围而非列举所有可能
atc --input_shape="batch:1,3,-1,-1" \
--dynamic_dims="224,224;320,320"
5.3 Profiling 与瓶颈分析
# 启用 Profiling
export ASCEND_PROFILING_ENABLE=1
export ASCEND_PROFILING_OUTPUT_PATH=/path/to/profiling
# 运行推理
python inference.py
# 生成报告
ascend-sn report --input /path/to/profiling --output /path/to/report
六、部署方案对比
| 方案 | 适用场景 | 优点 | 缺点 |
|---|---|---|---|
| ACL C++ | 延迟敏感 | 最低延迟 | 开发成本高 |
| ACL Python | 快速原型 | 易用 | 性能略低 |
| Flask HTTP | 简单服务 | 通用 | 延迟高 |
| 异步批处理 | 高吞吐 | 吞吐高 | 延迟波动 |
| Triton | 生产级 | 功能全 | 配置复杂 |
七、常见问题排查
| 问题 | 原因 | 解决方案 |
|---|---|---|
| 模型加载失败 | .om 文件损坏 | 重新用 ATC 转换 |
| 推理超时 | 模型太复杂 | 简化模型或降低精度 |
| 显存不足 | batch size 过大 | 减小 batch 或优化显存 |
| 输出为 NaN | 输入范围异常 | 检查输入归一化 |
| 精度下降 | 量化过激 | 使用 higher precision_mode |
相关仓库
- ACL - 推理引擎 C++ 接口 https://gitee.com/ascend/ascend-cl
- atc - 模型转换工具 https://gitee.com/ascend/atc
- Flask - HTTP 服务框架 https://gitee.com/pallets/flask
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)