在边缘设备上部署混合注意力大模型:基于 TensorRT Edge-LLM 适配 Qwen3.5-0.8B 全记录(一)
在边缘设备上部署混合注意力大模型:基于 TensorRT Edge-LLM 适配 Qwen3.5-0.8B 全记录(一)
Qwen3.5 系列是通义千问推出的新一代混合架构语言模型,首次在小参数量模型中引入了 Linear Attention + Full Attention 的混合设计。本文记录了将 Qwen3.5-0.8B 适配到 NVIDIA TensorRT Edge-LLM 框架并部署至 DRIVE Orin 边缘平台的完整过程——从架构分析、Python 导出适配、CUDA 内核开发、数据类型排错,到最终在目标设备上跑通推理。
一、背景与动机
1.1 为什么选择 Qwen3.5-0.8B?
在自动驾驶和机器人等边缘 AI 场景中,我们需要一个参数量足够小(适合 Orin 的内存约束)且推理质量够用的语言模型。Qwen3.5-0.8B 以仅 0.8B 参数实现了令人印象深刻的推理能力,其关键在于一个创新的混合架构设计:
- 24 层 Decoder,其中 18 层使用 Linear Attention(GatedDeltaNet),6 层使用传统 Full Attention(GQA)
- 层排列模式:
[Linear, Linear, Linear, Full] × 6 - Linear Attention 层不需要 KV Cache,使用固定大小的递归状态,对内存极其友好
这意味着在一个 0.8B 模型中,75% 的层不需要随序列长度增长而扩展的 KV Cache,这对内存受限的边缘设备来说是一个巨大的优势。
1.2 TensorRT Edge-LLM 简介
TensorRT Edge-LLM 是 NVIDIA 面向嵌入式平台(Jetson / DRIVE)的高性能 C++ LLM 推理运行时。它的工作流分为三个阶段:
HuggingFace Model ──[Python 量化+导出]──▶ ONNX Model
│
(传输到边缘设备)
▼
──[C++ Engine Builder]──▶ TensorRT Engine
──[C++ Runtime]────────▶ 推理结果
项目的核心设计理念是:Python 工具链仅负责离线导出,设备端运行纯 C++,无 Python 依赖。这决定了适配一个新模型需要在 Python 端和 C++ 端同时做工作。
1.3 挑战概览
适配 Qwen3.5-0.8B 不是简单的「添加一个新模型名称到支持列表」。它带来了几个此前框架未覆盖的技术挑战:
| 挑战 | 根因 | 影响范围 |
|---|---|---|
| Linear Attention 机制 | 全新的 GatedDeltaNet 架构,框架中无对应实现 | Python 导出 + CUDA 内核 + C++ 插件 + Runtime 状态管理 |
| Head Dimension = 256 | Full Attention 层的 head_dim=256,框架仅支持 64/128 | FMHA/XQA 预编译 CUBIN + 内核元数据 + canImplement 检查 |
| ONNX 数据类型不匹配 | PyTorch 隐式类型提升在 ONNX 图中丢失 | TensorRT 构建阶段报错 |
| 混合状态管理 | 两种注意力层需要不同的状态(KV Cache vs Conv/Recurrent State) | C++ Runtime 内存分配与绑定 |
下面按实际工作的时间线逐一展开。
二、模型架构深度解析
2.1 Qwen3.5-0.8B 配置一览
从 config.json 中提取关键参数:
{
"num_hidden_layers": 24,
"hidden_size": 1024,
"intermediate_size": 3584,
"layer_types": ["linear_attention", "linear_attention", "linear_attention", "full_attention", ...],
"num_attention_heads": 8,
"num_key_value_heads": 2,
"head_dim": 256,
"partial_rotary_factor": 0.25,
"attn_output_gate": true,
"linear_num_key_heads": 16,
"linear_num_value_heads": 16,
"linear_key_head_dim": 128,
"linear_value_head_dim": 128,
"linear_conv_kernel_dim": 4
}
2.2 两种注意力层对比
| 属性 | Full Attention (GQA) | Linear Attention (GatedDeltaNet) |
|---|---|---|
| 层数 | 6(第 3, 7, 11, 15, 19, 23 层) | 18(其余层) |
| Q Heads | 8 | 16 |
| KV Heads | 2(GQA 4:1) | 16 |
| Head Dim | 256 | K=128, V=128 |
| 位置编码 | RoPE(仅前 64 维,partial_rotary_factor=0.25) | 无 |
| 输出门控 | ✅ sigmoid gate | ✅ gated RMSNorm |
| 状态类型 | KV Cache(随序列增长) | Conv State (固定) + Recurrent State (固定) |
| 状态大小 | 取决于 max_seq_len | Conv: (B, 6144, 4), Recurrent: (B, 16, 128, 128) |
2.3 Linear Attention 计算流程
GatedDeltaNet 的核心是一个门控递归规则(Gated Delta Rule),其计算流程如下:
输入 hidden_states
│
├── in_proj_qkv → mixed_qkv → Split → Q, K, V
├── in_proj_z → z (gate)
├── in_proj_a → a (alpha log)
└── in_proj_b → b (beta)
│
▼
Causal Conv1d (depthwise, kernel=4, 带 SiLU)
│
▼
Split QKV + L2 Norm (Q, K)
│
▼
Compute: g = -exp(A_log) * softplus(a + dt_bias)
beta = sigmoid(b)
│
▼
Recurrent Delta Rule:
state = exp(g) * state + beta * (K^T ⊗ V)
output = Q * state
│
▼
Gated RMSNorm (output × SiLU(z))
│
▼
out_proj → 输出
这个计算图中,Causal Conv1d 和 Recurrent Delta Rule 都涉及跨时间步的状态维护,在 Prefill(处理整个输入序列)和 Decode(逐 token 生成)两个阶段有不同的实现。
三、Python 导出管线适配
3.1 模型架构层注册
TensorRT Edge-LLM 的 Python 端使用统一的 EdgeLLMDecoderLayer 来封装每一层。关键的分发逻辑在 layers.py 中:
if "qwen" in config.model_type:
attention_module = Qwen2Attention(config, index)
self.mlp = Qwen2MLP(config)
由于 Qwen3.5 的 model_type 为 "qwen3_5",包含 "qwen" 子串,大部分已有逻辑可以复用。但需要新增对 layer_types 字段的判断:
- 如果当前层是
"full_attention"→ 使用EdgeLLMAttention(含 output gate 支持) - 如果当前层是
"linear_attention"→ 使用新增的EdgeLLMLinearAttention
3.2 Full Attention 层的 Output Gate
Qwen3.5 的 Full Attention 层有一个特殊设计:Q 投影的输出同时包含了 query 和 gate:
q_out = self.q_proj(hidden_states) # shape: (B, L, num_heads * head_dim * 2)
query_states, gate = torch.chunk(q_out, 2, dim=-1)
# 注意力计算...
attn_output = attention(query_states, key_states, value_states)
# 输出门控
attn_output = attn_output * torch.sigmoid(gate)
在注意力输出之后,gate 经过 sigmoid 激活后逐元素乘以注意力输出(代码中的 attn_output * torch.sigmoid(gate))。这让模型能够自适应地抑制或放大注意力结果中每个维度的信息,相当于一个可学习的"开关"——对当前输入不重要的特征可以被 gate 压低到接近零。
这种设计让模型可以学习「何时关注、何时忽略」注意力层的输出,提高了表达能力,同时在 ONNX 导出时引入了类型匹配问题(后文详述)。
3.3 LinearAttention 的 ONNX 导出策略
Linear Attention 的计算涉及复杂的递归操作,无法用标准 ONNX 算子高效表达。我们采用自定义 TensorRT 插件的方式:
- 在 Python 端实现
LinearAttentionPlugin的 ONNX symbolic 函数 - 在 ONNX 图中生成一个自定义算子节点
LinearAttentionPlugin - C++ 端实现对应的 TensorRT 插件来执行实际计算
# linear_attention_plugin.py
class LinearAttentionPluginOp(torch.autograd.Function):
@staticmethod
def symbolic(g, mixed_qkv, z, a, b, conv_state, recurrent_state,
context_lengths, conv1d_weight, A_log, dt_bias, norm_weight,
num_v_heads, num_k_heads, head_k_dim, head_v_dim, conv_kernel_dim):
return g.op("TRT::LinearAttentionPlugin",
mixed_qkv, z, a, b, conv_state, recurrent_state,
context_lengths, conv1d_weight, A_log, dt_bias, norm_weight,
# 插件属性
num_v_heads_i=num_v_heads,
num_k_heads_i=num_k_heads,
...)
插件的接口定义:
- 输入:11 个 tensor(mixed_qkv, z, a, b, conv_state, recurrent_state, context_lengths, conv1d_weight, A_log, dt_bias, norm_weight)+ 5 个 int 属性
- 输出:3 个 tensor(output, present_conv_state, present_recurrent_state)
3.4 量化与导出
对于 DRIVE Orin (SM87) 平台,推荐使用 INT4 AWQ 量化:
# Step 1: INT4 AWQ 量化
tensorrt-edgellm-quantize-llm \
--model_dir Qwen/Qwen3.5-0.8B \
--quantization int4_awq \
--output_dir qwen3.5-0.8b/quantized
# Step 2: 导出 ONNX
tensorrt-edgellm-export-llm \
--model_dir qwen3.5-0.8b/quantized \
--output_dir qwen3.5-0.8b/onnx
导出产物:
qwen3.5-0.8b/onnx/
├── model.onnx # 计算图(含自定义算子节点)
├── onnx_model.data # 权重数据
├── config.json # 模型配置(含 layer_types 等)
├── embedding.safetensors # 词嵌入表
├── tokenizer.json # 分词器
└── chat_template.jinja # 对话模板
四、ONNX → TensorRT 数据类型不匹配排错
这是整个适配过程中最隐蔽、最消耗时间的一个阶段。模型导出为 ONNX 后,尝试用 TensorRT 构建引擎时连续遇到了多个数据类型不匹配错误。
4.1 问题现象
[ERROR] /model/layers.3/self_attn/Mul_2: ElementWiseOperation PROD
must have same input types. But they are of types Float and Half.
TensorRT 对 ElementWise 操作要求两侧输入类型严格一致,而 PyTorch 在运行时会自动类型提升——这种行为在 ONNX 静态图中丢失了。
4.2 三种根因模式
经过系统性排查,我们总结出三种常见的类型不匹配模式:
模式 A:PyTorch 隐式类型提升
# PyTorch: sigmoid 内部提升到 FP32 计算,返回 FP32
# ONNX: Sigmoid 保留输入类型 (FP16→FP16)
attn_output = attn_output * torch.sigmoid(gate) # 一侧 FP32,一侧 FP16
模式 B:type_as() 在 ONNX 追踪中失效
# Qwen3_5RMSNorm 内部:
output = self._norm(x.float()) # FP32
return output.type_as(x) # ONNX 追踪可能丢失此 Cast
模式 C:自定义 Plugin 的输出类型推导错误
# LinearAttentionPlugin 输入为 HALF,但 ONNX 推导输出为 FLOAT
# 下游的 Int4GroupwiseGemmPlugin 要求 HALF 输入 → 构建失败
4.3 修复策略:显式 Cast
核心原则:不要依赖隐式类型推导,在关键边界处显式指定目标类型。
# 修复 1: Output Gate 乘法
attn_output = attn_output.to(torch.float16) * torch.sigmoid(gate).to(torch.float16)
# 修复 2: RMSNorm 输出后
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = hidden_states.to(torch.float16) # 确保下游 plugin 收到 HALF
# 修复 3: LinearAttentionPlugin 输出后
core_attn_out, conv_state, recurrent_state = linear_attention_plugin(...)
core_attn_out = core_attn_out.to(torch.float16) # 确保 out_proj 收到 HALF
4.4 经验总结
ONNX 导出 Checklist:
- ✅ ElementWise 运算前:确保两侧操作数类型一致,不依赖 PyTorch 自动广播
- ✅ Norm 层调用后:如果 Norm 内部有 FP32 提升,显式 cast 回目标类型
- ✅ 自定义 Plugin 输出后:显式 cast 为期望类型
- ✅ 避免
type_as():改用tensor.to(torch.float16) - ✅ 验证手段:使用 Netron 可视化 ONNX 图,确认每个节点输入输出的类型
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)