MNN 模型支持:Qwen3.5
引言
Qwen3.5 于除夕夜正式开源,MNN 团队迅速完成全面适配与优化,并同步发布可在端侧高效部署的 Qwen3.5 MNN 模型。

Qwen3.5 是通义千问系列的最新一代大语言模型。从 Qwen 到 Qwen2、Qwen2.5、Qwen3,每一代的升级主要集中在训练数据、模型规模和微调策略上,核心架构始终保持在标准 Transformer Decoder 的框架内。而 Qwen3.5 则是这一系列中架构变化最大的一次升级,几乎在每个关键模块上都引入了颠覆性设计:
- 注意力机制:引入混合注意力(Hybrid Attention),交替使用标准 Attention 与线性注意力 DeltaNet
- 归一化层:引入 Gated RMSNorm,同时应用于线性注意力层和标准 Attention 层
- MoE 架构:新增 Shared Expert(共享专家)机制
- 模型形态:默认以多模态视觉语言模型(VLM)发布
- 位置编码:新增部分旋转(Partial Rotary),沿用 Interleaved M-RoPE
- 词表:从 151936 扩展到 248320
- 推测解码:内置 MTP(Multi-Token Prediction)支持
本文将详细介绍 MNN 为支持 Qwen3.5 所做的核心适配工作与技术实现。
Qwen3.5 架构变化概览
| 特性 | Qwen3 / Qwen3-VL | Qwen3.5 |
|---|---|---|
| 注意力机制 | 全 Transformer Attention | 混合 Attention(3:1 Linear + Full) |
| 归一化 | RMSNorm | RMSNorm + Gated RMSNorm |
| MoE | Expert 路由 | Expert 路由 + Shared Expert |
| 默认模态 | 纯文本 / VL 分开发布 | 默认多模态(VLM) |
| RoPE | 标准 / Interleaved M-RoPE | 部分旋转 + Interleaved M-RoPE |
| 词表大小 | 151,936 | 248,320 |
| 推测解码 | - | MTP(Multi-Token Prediction) |
一、DeltaNet:线性注意力机制
1.1 与标准 Attention 的对比
标准 Transformer Attention 和 DeltaNet 线性注意力本质上都是序列建模机制,但在计算模式和资源消耗上有根本区别:
| 标准 Attention | DeltaNet 线性注意力 | |
|---|---|---|
| 核心操作 | Softmax(QK^T/√d)V | 循环状态矩阵递推 |
| 每步推理复杂度 | O(L),需遍历全部历史 KV | O(1),仅更新固定大小状态 |
| 内存开销 | KV Cache 随序列长度线性增长 | 固定大小状态 [H, d_k, d_v] |
| 位置编码 | 依赖 RoPE | 因果卷积隐式编码 |
| 信息访问 | 可精确访问任意历史位置 | 通过压缩记忆间接访问 |
| 适合场景 | 需要精确长距离依赖 | 信息可被渐进压缩的场景 |
Qwen3.5 以 3:1 的比例交替组合这两种注意力——每 4 层中有 3 层使用 DeltaNet 线性注意力,1 层使用标准 Full Attention。这一设计在大幅降低推理成本的同时,依靠间隔插入的标准 Attention 层保持对关键信息的精确捕获能力。
1.2 算法原理
DeltaNet 的核心是 Gated Delta Rule,它维护一个 Key-Value 记忆矩阵 S ∈ ℝ^{d_k × d_v},在每个时间步通过以下五步更新:
输入: qkv [B, D, L], gate [B, L, H], beta [B, L, H], conv_weight [D, 1, K]
Step 1: Depthwise Conv1D + SiLU
将 qkv 通过带状态的因果卷积和 SiLU 激活函数
Step 2: Split Q/K/V + GQA 扩展
从卷积输出中分离 Q、K、V,并处理 GQA(Grouped-Query Attention)
Step 3: 可选 L2 归一化
对 Q 和 K 进行 per-head L2 normalization
Step 4: Scale Q
Q = Q / √d_k
Step 5: Gated Delta Rule 递推(核心)
对每个时间步 t:
S_t = S_{t-1} × exp(gate_t) // 门控衰减旧记忆
v_pred = S_t^T @ k_t // 用当前 key 预测 value
delta = beta_t × (v_t - v_pred) // 计算预测误差
S_t = S_t + k_t @ delta^T // 外积更新记忆
o_t = S_t^T @ q_t // 用 query 查询记忆
其中,gate(衰减因子)和 beta(学习率)控制记忆的遗忘与更新强度,使模型在保持线性复杂度的同时,具备对上下文信息的精细控制能力。
值得注意的是,gate 的计算经过精心参数化以确保始终为负值(保证记忆衰减):
beta = torch.sigmoid(b) # ∈ (0, 1)
gate = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) # 始终 ≤ 0
其中 A_log 和 dt_bias 均为可学习参数。exp(A_log) 确保正数,softplus(...) 确保非负,前面的负号使 gate 始终 ≤ 0,从而在 exp(gate) 后得到 ∈ (0, 1] 的衰减系数。
1.3 反向 GQA
Qwen3.5 的 Linear Attention 中有一个独特的设计:V head 数量大于 K/Q head 数量(如 7B 模型中 num_v_heads=32, num_k_heads=16)。这与标准 GQA(Grouped-Query Attention)中 K/V heads 少于 Q heads 的做法恰好相反。在计算前,K 和 Q 通过 repeat_interleave 扩展到与 V 相同的 head 数量:
if self.num_v_heads > self.num_k_heads:
factor = self.num_v_heads // self.num_k_heads
query = query.repeat_interleave(factor, dim=2)
key = key.repeat_interleave(factor, dim=2)
这意味着循环状态 S 的维度为 [B, num_v_heads, d_k, d_v],每个 V head 拥有独立的记忆矩阵。更多的 V heads 能够让模型在输出端拥有更丰富的表达能力,同时保持 K/Q 的计算开销较低。
1.4 MNN 实现
为支持 DeltaNet,MNN 新增了 OpType_LinearAttention(ID=305)算子,在 FlatBuffers Schema 中新增 LinearAttentionParam 参数表:
table LinearAttentionParam {
attn_type: string; // "gated_delta_rule"
num_k_heads: int; // K/Q head 数量
num_v_heads: int; // V head 数量
head_k_dim: int; // 每个 K head 的维度
head_v_dim: int; // 每个 V head 的维度
use_qk_l2norm: bool; // 是否对 Q/K 做 L2 归一化
}
该算子在四个后端均有完整实现:
-
CPU 后端(
CPULinearAttention.cpp):利用多线程并行,通道级并行计算 Conv1D+SiLU,head 级并行计算 Gated Delta Rule,并使用 MNN 内置的MNNSiLu、MNNScaleAndAddBiasScalar、MNNComputeMatMulForE_1等优化函数加速核心计算。 -
Metal 后端(
MetalLinearAttention.mm):实现三个 Metal Compute Kernel——linear_attn_conv_silu、linear_attn_conv_state_update、linear_attn_gated_delta_rule,通过R"metal(...)"内联 Shader 编译,并使用 Pipeline 缓存机制避免重复编译。 -
OpenCL 后端(
LinearAttentionBufExecution.cpp):三个 OpenCL Kernel 以.cl文件形式编写,通过 Codegen 脚本自动转换为 C++ 源码嵌入编译。全部使用float精度确保递推数值稳定性。 -
Vulkan 后端(
VulkanLinearAttention.cpp):三个 GLSL Compute Shader 编译为 SPIR-V,通过VulkanShaderMap注册,遵循 Vulkan 后端的 Descriptor Set + Uniform Buffer 标准模式。
四个后端共享相同的算法逻辑,均需管理两个持久状态(Persistent State):
- Conv State
[B, D, K-1]:因果卷积的滑动窗口历史 - Recurrent State
[B, H, d_k, d_v]:Delta Rule 的 Key-Value 记忆矩阵
这些状态使用 Backend::STATIC 分配,在首次推理时零初始化,跨 Decode Step 持久保持。
1.5 混合注意力的 Decoder 结构
Qwen3.5 采用 3 层 Linear Attention + 1 层 Full Attention 交替的混合架构。在 Decoder 中,每一层根据原始模型的配置自动选择注意力类型:
class Decoder:
def __init__(self, ...):
if hasattr(self, 'self_attn'):
self.self_attn = Attention(...) # 标准 KV Cache Attention
self.layer_type = 'full_attention'
if hasattr(self, 'linear_attn'):
self.self_attn = LinearAttention(...) # DeltaNet 线性注意力
self.layer_type = 'linear_attention'
两种层在推理时的差异:
- Full Attention 层:需要 RoPE 位置编码和 Attention Mask,维护 KV Cache。
- Linear Attention 层:不使用 RoPE(位置信息由因果卷积隐式编码),不需要 KV Cache,仅维护固定大小的循环状态。
这意味 Qwen3.5 的 KV Cache 大小仅为纯 Transformer 模型的约 1/4,显著降低了长序列推理的内存开销。
二、Gated RMSNorm
Qwen3.5 引入了 Gated RMSNorm,这是标准 RMSNorm 的扩展,额外接收一个门控信号,在归一化后与 SiLU 激活的门控值相乘:
class RMSNorm:
def forward(self, hidden_states, gate=None):
# 标准 RMSNorm
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + eps)
hidden_states = weight * hidden_states
# Gated:将归一化结果与 SiLU(gate) 相乘
if gate is not None:
hidden_states = hidden_states * F.silu(gate)
return hidden_states
Gated RMSNorm 在 Qwen3.5 中同时应用于两种注意力层:
在 Linear Attention 层中,gate 来源于独立的 in_proj_z 投影:
z = self.in_proj_z(hidden_states) # [B, L, value_dim]
attn_out = self.norm(attn_out, z) # Gated RMSNorm: norm(x) * silu(z)
output = self.out_proj(attn_out)
在 Full Attention 层中,gate 来源于 q_proj 的扩展输出。当 q_proj 的输出维度为 2 * num_heads * head_dim 时,输出被拆分为 query 和 gate 两部分:
query_states = self.q_proj(hidden_states)
if self.q_proj.out_features == 2 * self.num_heads * self.head_dim:
query_states, gate = torch.split(reshaped, self.head_dim, dim=-1)
...
# Attention 计算后,输出乘以 sigmoid(gate)
attn_output = attn_output * torch.sigmoid(gate)
需要注意两种层使用了不同的门控激活函数:Linear Attention 层使用 SiLU(norm(x) * silu(z)),Full Attention 层使用 Sigmoid(attn_output * sigmoid(gate))。SiLU 的输出范围为 (-0.28, +∞),允许轻微抑制;Sigmoid 的输出严格在 (0, 1) 之间,纯粹起缩放作用。
三、MoE 架构变化:Shared Expert
3.1 变化内容
Qwen3.5 的 MoE(Mixture of Experts)层新增了 Shared Expert 机制。与传统 MoE 每个 token 仅路由到 Top-k 个专家不同,Shared Expert 是一个所有 token 都会经过的公共专家,其输出与路由专家的输出相加:
class Mlp:
def forward(self, hidden_states):
# Shared Expert: 所有 token 共享的基础能力
shared_expert_output = F.sigmoid(self.shared_expert_gate(x)) * self.shared_expert(x)
# Routed Experts: top-k 路由的稀疏专家
router_logits = self.gate(x)
routing_weights = F.softmax(router_logits, dim=-1)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # 归一化
expert_output = self.custom_moe(x, routing_weights, selected_experts)
# 最终输出 = 路由专家输出 + 共享专家输出
final_output = expert_output + shared_expert_output
return final_output
3.2 Shared Expert 的意义
- 保底能力:Shared Expert 确保每个 token 至少获得一组完整的 FFN 计算,避免因路由不当导致信息损失。
- 知识共享:通用知识存储在 Shared Expert 中,专业知识由路由专家处理,形成分层知识结构。
- 训练稳定性:减少了 MoE 训练中常见的路由崩塌问题。
3.3 MNN 适配
在 MNN 的导出流程中,Shared Expert 作为一个独立的线性层导出,其输出通过 shared_expert_gate(Sigmoid 门控)调制后与 MoE 的输出相加。model_mapper.py 中新增了对应的映射关系:
qwen3_5_moe_mlp = {
'num_experts': 'experts.num_experts',
'top_k': 'gate.top_k',
'gate': 'gate',
'experts': 'experts',
'shared_expert_gate': 'shared_expert_gate', # 新增
'shared_expert': 'shared_expert' # 新增
}
同时,MoE 模块的 Expert 权重格式从 ModuleList 改为 PackedMLP 格式——所有专家的权重存储在统一的 3D 张量 [num_experts, ...] 中,按索引切片即可(gate_up_proj.data[i] 直接使用,无需 .transpose(0, 1)),简化了导出流程。此外,Qwen3.5 MoE 的路由权重在 Top-k 选择后会进行 归一化(routing_weights /= routing_weights.sum()),确保各专家贡献之和为 1。
四、默认多模态:Vision 原生支持
4.1 架构特点
与 Qwen3 将纯文本和视觉语言模型分开发布不同,Qwen3.5 默认以多模态视觉语言模型(VLM) 形态发布。这意味着:
- 模型配置使用
text_config嵌套结构(如text_config.hidden_size),区分文本和视觉配置。 - 模型路径以
model.language_model.layers而非model.layers访问语言模型层。 - Embedding 层位于
model.language_model.embed_tokens。 - Visual Encoder 位于
model.visual。
4.2 视觉编码器(Qwen3_5Vision)
Qwen3.5 的视觉编码器与 Qwen3-VL 基本一致,继承了其 ViT + Merger 的架构设计、Interleaved M-RoPE 位置编码方案、可学习位置编码(pos_embed + 双线性插值),以及动态分辨率处理能力。主要区别在于:
- 移除 Deepstack:Qwen3-VL 引入的 Deepstack 特性(
deepstack_merger_list,用于对视觉 token 进行多尺度深层融合)在 Qwen3.5 中被移除,简化了视觉编码器结构。
4.3 Interleaved M-RoPE 在纯文本场景的适配
Qwen3.5 继承了 Qwen3-VL 引入的 Interleaved M-RoPE 位置编码方案。该方案将时间 (T)、高度 (H)、宽度 (W) 三个维度的频率分量交替排列,而非早期 M-RoPE 沿特征维度拼接的方式:
# 传统 M-RoPE: [T_freq | H_freq | W_freq] (拼接)
# Interleaved M-RoPE: [T0, H0, W0, T1, H1, W1, ...] (交错)
freq_idx = torch.arange(0, 3 * half_rotary).reshape(3, 1, half_rotary)
mrope_reindex = apply_interleaved_mrope(freq_idx, mrope_section).flatten()
在视觉处理中,T/H/W 三个维度各自携带真实的时空位置信息。而对于纯文本推理,由于文本没有空间维度,MNN 推理引擎将同一个位置 ID 复制为三份(T=H=W),通过 is_mrope 配置标识自动处理:
// 纯文本推理:三维位置 ID 均设为相同的序列位置
if (mConfig->is_mrope()) {
positionIds = _Input({3, seq_len}, NCHW, halide_type_of<int>());
auto ptr = positionIds->writeMap<int>();
for (int i = 0; i < seq_len; i++) {
ptr[0 * seq_len + i] = i + all_seq_len; // T = pos
ptr[1 * seq_len + i] = i + all_seq_len; // H = pos
ptr[2 * seq_len + i] = i + all_seq_len; // W = pos
}
}
4.4 部分旋转(Partial Rotary)
Qwen3.5 引入了 partial_rotary_factor 参数,仅对 head_dim 中的部分维度应用旋转位置编码,其余维度保持不变直接通过:
if 'partial_rotary_factor' in config.rope_parameters:
self.partial_rotary_factor = config.rope_parameters['partial_rotary_factor']
self.rotary_dim = int(self.rotary_dim * self.partial_rotary_factor)
在 apply_rotary_pos 时,这一机制与 Phi 系列模型采用相同的方式——将特征向量拆分为两部分,仅对前 rotary_dim 个维度应用旋转,剩余维度原样拼接:
def phi_rotary_pos(self, x, cos, sin):
x, x_pass = x[..., :self.rotary_dim], x[..., self.rotary_dim:]
x = (x * cos) + (rotate_half(x) * sin)
return torch.cat((x, x_pass), dim=-1)
这与 LLaMA 等模型对全部维度应用旋转的方式不同。部分旋转让模型在保留位置信息编码能力的同时,保持一部分“位置无关”的特征维度。Qwen3.5 的 model_type 在 RoPE 处理中被归入 Phi 的分支:
if self.model_type in ['phi-msft', 'qwen3_5', 'qwen3_5_moe']:
return self.phi_rotary_pos(x, cos, sin)
需要注意的是,部分旋转仅应用于 Full Attention 层。LinearAttention 层不使用 RoPE,位置信息完全由因果卷积提供。
五、Tokenizer 升级
Qwen3.5 对词表进行了大幅扩展,从 Qwen 系列一直沿用的 151,936 词表扩展到 248,320,增幅超过 63%。更大的词表能够更好地覆盖多语言和专业领域的 Token 分布,减少 Tokenization 时的碎片化,从而在相同文本下产生更短的 Token 序列,提升推理效率。
六、MTP (Multi-Token Prediction)
Qwen3.5 的配置中包含 mtp_num_hidden_layers: 1,表明模型内置了 MTP(Multi-Token Prediction,多 Token 预测)能力。MTP 是一种推测解码(Speculative Decoding)技术——在标准 LLM 每步仅预测下一个 Token 的基础上,额外的 MTP 层并行预测后续多个 Token,从而在一次前向传播中生成多个候选 Token,显著提升推理吞吐量。
MTP 模块的结构是一个轻量级的 Decoder 层,接收主模型最后一层的隐状态和当前 Token 的 Embedding 作为输入:
class MTP:
def forward(self, input_embeds, hidden_states, attention_mask, position_ids, past_key_values):
# 对隐状态和 token embedding 分别归一化后融合
input_embeds = self.token_layernorm(input_embeds)
hidden_states = self.hidden_layernorm(hidden_states)
hidden_states = self.input_proj(torch.cat([hidden_states, input_embeds], dim=-1))
# 一个完整的 Decoder 层(Self-Attention + MLP)
hidden_states = self.self_attn(hidden_states, ...)
hidden_states = self.mlp(hidden_states)
hidden_states = self.final_layernorm(hidden_states)
# 共享主模型的 lm_head 产生预测
logits = self.lm_head(hidden_states)
return logits
MNN 的导出框架已为 MTP 搭建了完整的导出通道(export_mtp),支持将 MTP 层导出为独立的 ONNX 模型并转换为 MNN 格式。MTP 层与主模型共享 lm_head 权重,同时拥有独立的 KV Cache。
七、MNN 框架层面的改进
Qwen3.5 的适配过程中,MNN 的 LLM 导出框架也做了重要重构:
- 移除
past_key_values显式传递:KV Cache 管理从模型外部传入改为模型内部自管理(self.past_key_value),简化了导出和推理流程。 - Attention 输出简化:Decoder 的
forward返回值从(hidden_states, present_key_value)简化为hidden_states。 - 统一 Fused Attention 导出:所有模型类型均启用
export_fused_attn,不再区分是否为 MoE 模型。 - FusedAttention 增加
kv_cache属性:区分标准 Attention 层(需要 KV Cache)和 LinearAttention 层(不需要 KV Cache)。
八、快速上手
导出模型
# 导出 Qwen3.5 模型
python llmexport.py --path /path/to/model --export mnn --hqq
推理
# 使用 MNN LLM 引擎推理
./llm_demo /path/to/exported/model/config.json prompt.txt
总结
Qwen3 到 Qwen3.5 是通义千问系列迄今为止架构变化最大的一次升级。MNN 对 Qwen3.5 的支持涵盖了该模型的所有核心创新:
- DeltaNet 线性注意力:新增
LinearAttention算子,在 CPU、Metal、OpenCL、Vulkan 四个后端实现,3:1 的 Linear + Full 混合架构将 KV Cache 降至 1/4。 - Gated RMSNorm:扩展 RMSNorm 支持门控调制,同时应用于线性注意力层和标准 Attention 层。
- MoE Shared Expert:支持共享专家与路由专家的混合计算。
- 原生多模态:适配 Qwen3.5 默认 VLM 架构(基于 Qwen3-VL,移除 Deepstack),支持 Interleaved M-RoPE 和部分旋转位置编码。
- Tokenizer 升级:支持从 151,936 到 248,320 的扩展词表。
- MTP 多 Token 预测:通过自带 MTP 头支持投机采样,加速生成。
这些工作使 MNN 能够在移动端和嵌入式设备上高效部署 Qwen3.5 系列模型。
资源链接
- Github Commit: https://github.com/alibaba/MNN/commit/1fdc7d3a70a26d836ade965b6cefb4812c9b8004
- HuggingFace: https://huggingface.co/collections/taobao-mnn/qwen35-mnn
- ModelScope: https://modelscope.cn/collections/MNN/Qwen35-MNN
欢迎开发者体验、反馈与贡献,共同推动端侧大模型推理技术的发展。
Enjoy Reading This Article?
Here are some more articles you might like to read next: