CANN 端侧部署实战:模型转换与服务化

如何将训练好的模型快速部署到昇腾端侧设备?本文详解模型格式转换、端侧优化与服务化部署的完整流程。—

一、端侧部署概述

1.1 端侧部署的挑战

与数据中心训练不同,端侧部署面临独特的约束:算力受限(端侧设备如昇腾 310 的算力远低于 910)、内存有限功耗敏感延迟要求严苛。在昇腾 910 上能实时推理的模型,迁移到昇腾 310 可能需要进一步优化才能满足实时性要求。

端侧部署的核心挑战在于适配不同算子、减少内存占用、降低推理延迟。CANN 提供了完整的端侧工具链,从模型转换到推理引擎再到服务化部署,帮助开发者将训练模型高效部署到各种昇腾硬件。

1.2 端侧部署架构

┌──────────────────────────────────────────────────┐
│           CANN 端侧部署架构                       │
├──────────────────────────────────────────────────┤
│                                                  │
│  训练侧(数据中心)                               │
│  ┌─────────┐  ┌─────────┐  ┌─────────┐           │
│  │ PyTorch │→│ ONNX    │→│ ACL     │           │
│  │ / TF    │  │ 中间格式 │  │ .om 模型 │           │
│  └─────────┘  └─────────┘  └─────────┘           │
│       │              │            │              │
│       └──────────────┴────────────┘              │
│              模型转换(ATC)                      │
│                                                  │
│  端侧(昇腾设备)                                 │
│  ┌─────────┐  ┌─────────┐  ┌─────────┐           │
│  │ ACL    │→│ 引擎    │→│ 服务化  │           │
│  │ 推理   │  │ 调度    │  │ 部署    │           │
│  └─────────┘  └─────────┘  └─────────┘           │
│                                                  │
└──────────────────────────────────────────────────┘

1.3 端侧设备选型

设备 算力(FP16) 显存 功耗 适用场景
昇腾 310 22 TFLOPS 8 GB 7W 边缘推理、IPC、摄像头
昇腾 910 256 TFLOPS 32 GB 310W 数据中心训练/推理
昇腾 310P 70 TFLOPS 12 GB 25W 边缘服务器、自动驾驶

二、模型格式转换

2.1 支持的输入格式

CANN 支持多种主流训练框架的模型格式:

框架 输入格式 说明
PyTorch .pt, .pth, .onnx 需转换为 ONNX
TensorFlow .pb, SavedModel 冻结图或 SavedModel
ONNX .onnx 中间格式,推荐
MindSpore .mindir 华为自研框架

转换路径优先级

PyTorch → ONNX → .om(推荐路径)
TensorFlow → ONNX → .om
MindSpore → .mindir → .om

2.2 PyTorch 转 ONNX

PyTorch 模型需要先转换为 ONNX 格式,再通过 ATC 转为 .om:

8.1 及之前(简单 export):

import torch

model = MyModel()
model.eval()

# 简单 export(静态 shape)
torch.onnx.export(
    model,
    args=torch.randn(1, 3, 224, 224),
    f="model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes=None  # 静态 shape
)

8.2 新增(动态 shape 支持):

import torch

model = MyModel()
model.eval()

# 动态 shape export(支持变长输入)
torch.onnx.export(
    model,
    args=torch.randn(1, 3, 224, 224),
    f="model_dynamic.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size", 2: "height", 3: "width"},
        "output": {0: "batch_size"}
    },
    opset_version=13,  # 建议使用 13 以上
    keep_initializers_as_inputs=False
)

# 验证 ONNX 模型
import onnx
model_onnx = onnx.load("model_dynamic.onnx")
onnx.checker.check_model(model_onnx)
print(f"ONNX model validated successfully")

2.3 ONNX 模型简化

转换后的 ONNX 模型可能包含冗余节点,需要简化后再转换:

8.1 及之前(手动简化):

# 使用 onnx-simplifier(需手动调用)
python -m onnxsim model.onnx model_simplified.onnx

8.2 新增(ATC 内置简化):

# ATC 自动处理图简化
atc --model=model.onnx \
    --framework=5 \
    --output=model \
    --input_shape="input:1,3,224,224" \
    --input_format=NCHW \
    --soc_version=Ascend310 \
    --graph_optimize_mode=ENABLE  # 启用内置图优化

2.4 ATC 模型转换

ATC(Ascend Tensor Compiler)是 CANN 的模型转换核心工具:

基础转换命令

atc \
    --model=/path/to/model.onnx \
    --framework=5 \
    --output=/path/to/model \
    --input_shape="input:1,3,224,224" \
    --input_format=NCHW \
    --soc_version=Ascend310 \
    --log=INFO

进阶转换命令(带优化)

atc \
    --model=/path/to/model.onnx \
    --framework=5 \
    --output=/path/to/model \
    --input_shape="input:1,3,-1,-1" \
    --dynamic_dims="224,256;224,288;224,320" \
    --input_format=NCHW \
    --soc_version=Ascend310 \
    --precision_mode=allow_fp32_to_fp16 \
    --op_select_implmode=high_precision \
    --output_type=FP16 \
    --insert_batchnorm_mode=0 \
    --graph_optimize_mode=GP_ENHANCE \
    --enable_single_stream=1 \
    --log=INFO

关键参数说明

参数 说明 推荐值
precision_mode 精度模式 allow_fp32_to_fp16
op_select_implmode 算子实现模式 high_precision
graph_optimize_mode 图优化级别 GP_ENHANCE
enable_single_stream 单流优化 1(推理推荐)
dynamic_dims 动态 shape 候选 224,256;224,288

三、ACL 推理引擎

3.1 ACL 架构

ACL(Ascend Computing Language)是昇腾的推理运行时,提供了 C++ 和 Python 接口:

┌──────────────────────────────────────────────────┐
│              ACL 推理架构                        │
├──────────────────────────────────────────────────┤
│  ┌─────────┐  ┌─────────┐  ┌─────────┐          │
│  │ 模型    │→│ 引擎    │→│ 数据    │          │
│  │ 加载    │  │ 执行    │  │ 拷贝    │          │
│  └─────────┘  └─────────┘  └─────────┘          │
│       │            │            │                │
│  ┌────┴────────────┴────────────┴────┐          │
│  │        昇腾硬件抽象层               │          │
│  └────────────────────────────────────┘          │
└──────────────────────────────────────────────────┘

3.2 ACL C++ 推理

基础推理流程

#include "acl/acl.h"
#include "acl/ops/acl_dvpp.h"

// ACL 初始化
void InitACL() {
    aclError ret;
    
    // 初始化 Context
    ret = acl.init();
    ret = acl.set_device(0);  // 选择设备 0
    
    // 创建 Context
    aclrtContext context;
    ret = acl.rt.create_context(&context);
}

// 模型加载
aclmdlDescPtr modelDesc;
aclmdlDatasetPtr inputDataset;
aclmdlDatasetPtr outputDataset;

void LoadModel(const std::string& modelPath) {
    uint32_t modelId;
    size_t modelSize;
    
    // 读取模型文件
    std::ifstream file(modelPath, std::ios::binary);
    file.seekg(0, std::ios::end);
    modelSize = file.tellg();
    file.seekg(0, std::ios::beg);
    
    std::unique_ptr<char[]> modelData(new char[modelSize]);
    file.read(modelData.get(), modelSize);
    
    // 加载模型
    aclmdlLoadFromMem(modelData.get(), modelSize, &modelId);
    
    // 获取模型描述
    modelDesc = acl.mdl.create_desc();
    acl.mdl.get_desc(modelDesc, modelId);
}

// 推理执行
void Inference(void* inputData, void* outputData, size_t dataSize) {
    // 准备输入
    inputDataset = acl.mdl.create_dataset();
    aclmdlDataBufferPtr inputBuffer = acl.mdl.create_data_buffer(inputData, dataSize);
    acl.mdl.add_dataset_buffer(inputDataset, inputBuffer);
    
    // 准备输出
    outputDataset = acl.mdl.create_dataset();
    aclmdlDataBufferPtr outputBuffer = acl.mdl.create_data_buffer(outputData, outputSize);
    acl.mdl.add_dataset_buffer(outputDataset, outputBuffer);
    
    // 执行推理
    acl.mdl.execute(modelId, inputDataset, outputDataset);
    
    // 释放数据
    acl.mdl.destroy_data_buffer(inputBuffer);
    acl.mdl.destroy_data_buffer(outputBuffer);
    acl.mdl.destroy_dataset(inputDataset);
    acl.mdl.destroy_dataset(outputDataset);
}

3.3 ACL Python 推理

Python 接口更易用,适合快速部署:

8.1 及之前(acl弟弟接口):

import acl

def init_acl():
    ret = acl.init()
    ret = acl.set_device(0)
    context, ret = acl.rt.create_context()
    stream, ret = acl.rt.create_stream()
    return stream

def inference(model_path, input_data):
    # 加载模型
    model_id, ret = acl.mdl.load_from_file(model_path)
    
    # 准备输入
    input_dataset = acl.mdl.create_dataset()
    input_buffer = acl.mdl.create_data_buffer(input_data)
    acl.mdl.add_dataset_buffer(input_dataset, input_buffer)
    
    # 执行
    acl.mdl.execute(model_id, input_dataset, output_dataset)
    
    return output_data

8.2 新增(高层接口,简化调用):

from ascend.acl import AclModel

# 一行代码加载模型
model = AclModel(model_path="/path/to/model.om")

# 自动类型转换
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)

# 推理(自动管理输入输出)
output = model.predict(input_data)

# 输出即为 numpy 数组
print(f"Output shape: {output.shape}")

3.4 模型预热

首次推理会有初始化开销,需要预热:

from ascend.acl import AclModel

model = AclModel(model_path="/path/to/model.om")

# 预热(执行一次空推理)
warmup_input = np.random.randn(1, 3, 224, 224).astype(np.float32)
for _ in range(5):
    _ = model.predict(warmup_input)

print("Model warmed up, ready for real inference")

四、服务化部署

4.1 推理服务架构

端侧推理服务化需要考虑:并发处理、延迟优化、资源管理:

┌──────────────────────────────────────────────────┐
│           推理服务化架构                         │
├──────────────────────────────────────────────────┤
│  ┌─────────┐  ┌─────────┐  ┌─────────┐          │
│  │ HTTP    │→│ 请求    │→│ 模型    │          │
│  │ 服务    │  │ 队列    │  │ 池     │          │
│  └─────────┘  └─────────┘  └─────────┘          │
│       │            │            │                │
│  ┌────┴────────────┴────────────┴────┐          │
│  │        多模型并行推理               │          │
│  └────────────────────────────────────┘          │
└──────────────────────────────────────────────────┘

4.2 基于 Flask 的简单服务

基础 Flask 服务

from flask import Flask, request, jsonify
import numpy as np
from ascend.acl import AclModel
from threading import Lock

app = Flask(__name__)

# 模型实例(线程安全)
model_lock = Lock()
model = None

def load_model():
    global model
    model = AclModel(model_path="/path/to/model.om")
    # 预热
    warmup = np.random.randn(1, 3, 224, 224).astype(np.float32)
    for _ in range(3):
        model.predict(warmup)

@app.route('/predict', methods=['POST'])
def predict():
    if model is None:
        return jsonify({"error": "Model not loaded"}), 500
    
    # 解析输入
    data = request.json
    input_array = np.array(data['input'], dtype=np.float32)
    
    # 推理
    with model_lock:
        output = model.predict(input_array)
    
    # 返回结果
    return jsonify({
        "output": output.tolist(),
        "shape": output.shape
    })

@app.route('/health', methods=['GET'])
def health():
    return jsonify({"status": "ok", "model_loaded": model is not None})

if __name__ == "__main__":
    load_model()
    app.run(host="0.0.0.0", port=8080, threaded=True)

4.3 异步批处理服务

高并发场景下,异步批处理可显著提升吞吐:

8.1 及之前(同步单请求):

# 每个请求单独处理,吞吐低
def predict_sync(data):
    return model.predict(data)

8.2 新增(异步批处理):

import asyncio
from threading import Thread
from collections import deque

class AsyncBatchedInference:
    def __init__(self, model_path, max_batch_size=32, max_wait_ms=50):
        self.model = AclModel(model_path)
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.request_queue = deque()
        self.results = {}
        self.running = True
        
        # 启动批处理线程
        self.process_thread = Thread(target=self._process_loop)
        self.process_thread.start()
    
    def _process_loop(self):
        """后台批处理循环"""
        while self.running:
            batch = []
            start_time = time.time()
            
            # 收集请求直到 batch 满或超时
            while len(batch) < self.max_batch_size:
                if time.time() - start_time > self.max_wait_ms / 1000:
                    break
                
                # 从队列取请求(最多等 5ms)
                try:
                    request = self.request_queue.popleft()
                    batch.append(request)
                except IndexError:
                    time.sleep(0.005)
            
            if batch:
                # 批量推理
                inputs = np.stack([r['input'] for r in batch])
                outputs = self.model.predict(inputs)
                
                # 分发结果
                for i, request in enumerate(batch):
                    request['future'].set_result(outputs[i])
    
    async def predict_async(self, input_data):
        """异步预测接口"""
        future = asyncio.Future()
        
        self.request_queue.append({
            'input': input_data,
            'future': future
        })
        
        return await future

# 使用
inference_server = AsyncBatchedInference("/path/to/model.om")

async def handle_request(data):
    result = await inference_server.predict_async(data)
    return result

4.4 模型版本管理

生产环境中需要支持模型热更新:

class VersionedModelServer:
    def __init__(self, model_dir="/models"):
        self.model_dir = model_dir
        self.current_version = None
        self.models = {}
        self.lock = Lock()
    
    def load_version(self, version):
        model_path = f"{self.model_dir}/model_v{version}.om"
        self.models[version] = AclModel(model_path)
        with self.lock:
            if self.current_version is None:
                self.current_version = version
        print(f"Model version {version} loaded")
    
    def switch_version(self, version):
        if version not in self.models:
            self.load_version(version)
        
        with self.lock:
            old_version = self.current_version
            self.current_version = version
        
        # 旧版本在所有请求处理完后可释放
        return old_version
    
    def predict(self, input_data):
        with self.lock:
            model = self.models[self.current_version]
        return model.predict(input_data)

# 滚动更新示例
server = VersionedModelServer("/models")
server.load_version("1.0")
# 收到新版本后
old = server.switch_version("1.1")  # 切换到新版本
# 旧版本资源可后续释放

五、性能优化

5.1 延迟优化

端侧推理延迟是关键指标,需要从多个层面优化:

内存拷贝优化:避免 CPU 和 NPU 之间的不必要数据拷贝:

# 8.1 及之前(频繁拷贝)
def inference_old(input_data):
    # 输入从 CPU 拷到 NPU
    input_npu = input_data.npu()
    output = model.predict(input_npu)
    # 输出从 NPU 拷回 CPU
    return output.cpu()

# 8.2 新增(零拷贝)
def inference_new(input_data):
    # 使用 pinned memory 和异步拷贝
    input_npu = input_data.npu()
    output = model.predict(input_npu)
    return output  # 直接返回 NPU tensor

batch 处理优化:适当增大 batch 可提升吞吐:

# 单请求延迟 vs 批处理吞吐
for batch_size in [1, 4, 8, 16]:
    latencies = []
    for _ in range(100):
        start = time.time()
        # 合并 batch 推理
        output = model.predict(batch_input)
        latencies.append(time.time() - start)
    
    avg_latency = np.mean(latencies)
    throughput = batch_size / avg_latency
    print(f"Batch {batch_size}: latency={avg_latency:.3f}s, throughput={throughput:.1f} fps")

5.2 显存优化

端侧设备显存有限,需要精细管理:

# 使用 torch inference mode 减少显存占用
with torch.inference_mode():
    output = model(input)

# 释放中间张量
del intermediate_tensor

# 强制垃圾回收
import gc
gc.collect()
torch.npu.empty_cache()

动态 shape 优化:合理设置动态维度范围避免预分配过多显存:

# 设置动态 shape 范围而非列举所有可能
atc --input_shape="batch:1,3,-1,-1" \
    --dynamic_dims="224,224;320,320"

5.3 Profiling 与瓶颈分析

# 启用 Profiling
export ASCEND_PROFILING_ENABLE=1
export ASCEND_PROFILING_OUTPUT_PATH=/path/to/profiling

# 运行推理
python inference.py

# 生成报告
ascend-sn report --input /path/to/profiling --output /path/to/report

六、部署方案对比

方案 适用场景 优点 缺点
ACL C++ 延迟敏感 最低延迟 开发成本高
ACL Python 快速原型 易用 性能略低
Flask HTTP 简单服务 通用 延迟高
异步批处理 高吞吐 吞吐高 延迟波动
Triton 生产级 功能全 配置复杂

七、常见问题排查

问题 原因 解决方案
模型加载失败 .om 文件损坏 重新用 ATC 转换
推理超时 模型太复杂 简化模型或降低精度
显存不足 batch size 过大 减小 batch 或优化显存
输出为 NaN 输入范围异常 检查输入归一化
精度下降 量化过激 使用 higher precision_mode

相关仓库

  • ACL - 推理引擎 C++ 接口 https://gitee.com/ascend/ascend-cl
  • atc - 模型转换工具 https://gitee.com/ascend/atc
  • Flask - HTTP 服务框架 https://gitee.com/pallets/flask
Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐