MNN 支持 Gemma4 全系列:PLE、KV共享与多模态
一、背景
Google 最新发布的 Gemma4 是一个架构创新密度很高的系列模型。从端侧 E2B 到桌面级 31B,四种规格覆盖了从手机到服务器的全部场景。更有意思的是,这四种规格不只是简单的参数缩放——它们使用了不同的架构机制组合,给推理引擎的适配带来了不少挑战。
本文记录 MNN 适配 Gemma4 全系列的过程,重点拆解三个最有特色的机制:PLE(Per-Layer Embeddings)、KV Cache 共享和多模态支持,并提供实际的导出测评数据和使用建议。
二、从 Gemma3 到 Gemma4:架构演进
Gemma4 并非直接从 Gemma3 跳跃而来——中间还有一个重要的过渡版本 Gemma3n。Gemma3n 是 Google 专为端侧设备设计的多模态模型,首次引入了 PLE、KV 共享等关键机制(详见我之前的 Gemma3n 架构分析)。Gemma4 在继承这些机制的基础上,进一步扩展和演进。
三代模型的架构演进:
| 特性 | Gemma3 | Gemma3n | Gemma4 |
|---|---|---|---|
| 注意力机制 | 混合注意力 | 混合注意力 | 混合注意力 |
| KV Cache | 每层独立 | KV 共享(10层) | KV 共享(18-20层) |
| Embedding | 共享 embedding | PLE(查表直注入) | PLE(查表+投影融合) |
| QKV Norm | Q/K Norm | Q/K/V Norm | Q/K/V Norm + attn_scale=1.0 |
| V 投影 | 独立 v_proj | 独立 v_proj | 部分模型 K=V(无 v_proj) |
| KV Head | 统一 | 统一 | 异构 KV Head |
| 视觉编码器 | SigLIP ViT | MobileNetV5 | SigLIP ViT + Pooler |
| 音频支持 | 无 | Conformer | USM Conformer |
| MoE | 无 | 无 | 26B-A4B:128 专家 Top-8 |
| 独有机制 | — | AltUp、Laurel Block、激活稀疏性 | K=V、异构KV Head、双RoPE |
可以看到,Gemma4 的演进路线:
- 继承自 Gemma3n:PLE、KV 共享、混合注意力、QKV Norm、多模态原生支持
- Gemma4 新增:PLE 投影路径、更大规模的 KV 共享、K=V 机制、异构 KV Head、Decoder 级 MoE、双 RoPE(不同 head_dim + theta)
- Gemma3n 独有(Gemma4 未继承):AltUp(多路径并行学习)、Laurel Block(低秩增强残差)、95% 激活稀疏性
2.1 整体架构
2.2 四种规格详览
| 模型 | Hidden | Heads/KV Heads | Layers | Head Dim | 特殊机制 |
|---|---|---|---|---|---|
| E2B | 1536 | 8/1 | 35 | S:256/F:512 | KV共享(20层)、PLE |
| E4B | 2560 | 8/2 | 42 | S:256/F:512 | KV共享(18层)、PLE |
| 31B | 5376 | 32/16(F:4) | 60 | S:256/F:512 | K=V、异构KV Head、PLE |
| 26B-A4B | 2816 | 16/8(F:2) | 30 | S:256/F:512 | K=V、异构KV Head、MoE(128专家)、PLE |
所有模型共有的特征:
- 混合注意力:Sliding Attention + Full Attention 交替排列
- 双 RoPE:Sliding 层使用 head_dim=256 + theta=10000,Full 层使用 head_dim=512 + theta=1000000
- QKV 三 Norm:
q_norm、k_norm、v_norm,注意力缩放固定为 1.0 - PLE:每层接收专属的 token embedding 调制信号
- 原生多模态:内置视觉和音频编码器
值得注意的是,E2B/E4B 和 31B/26B-A4B 的设计思路截然不同:前者通过 KV 共享减少存储、用更多的层去分摊任务;后者通过 K=V + 异构 KV Head 压缩参数量,配合更少但更大的层来建模。
三、PLE:每一层看到不同的 token
3.1 从 Gemma3n 到 Gemma4 的 PLE 演进
PLE(Per-Layer Embeddings)最早在 Gemma3n 中引入,其核心思想与 STEM(Scaling Transformers with Embedding Modules)、RWKV-8 的 DeepEmbed 以及华为的 MoLE 方案一脉相承:通过增加 Embedding 参数量来提升模型效果,而 Embedding 参数可以存放在 Flash 上不占用运行内存——接近”免费的午餐”。
Gemma3n 的 PLE 实现比较直接——只有一条路径:用 token ID 在 embed_tokens_per_layer 表中查找,reshape 后按层切分,通过门控乘法注入每层 Decoder。
Gemma4 在此基础上增加了第二条路径:将主 embedding 通过一个 Linear + Scale + Norm 投影到 PLE 空间,与 PLE 查表结果融合后再注入。这使得 PLE 不仅包含 token 级的静态信息,还融入了主 embedding 经过投影后的上下文相关信号。
3.2 Gemma4 的 PLE 计算流程
Gemma4 的 PLE 计算涉及两条路径的融合(相比 Gemma3n 多了左侧的投影路径):
关键步骤解读:
- 查表:
embed_tokens_per_layer是一个大表,维度为[vocab_size, num_layers × ple_dim],用 token ID 查找后 reshape 为[seq, num_layers, ple_dim],每个 token 在每一层都有独立的 embedding - 投影:用
per_layer_model_projection(Linear)将主 embedding 从 hidden_size 投影到 ple_dim 空间,缩放系数为hidden_size^{-0.5},再做 RMSNorm 归一化 - 融合:两条路径的输出相加,乘以
2^{-0.5}作为最终的 per-layer input
3.3 门控注入:PLE 如何进入 Decoder
每层 Decoder 通过门控乘法将 per_layer_input 注入隐状态,而不是简单的加法。这保证了模型可以学习”在这一层需要多少 PLE 信号”:
# 每层 Decoder 内部
residual = hidden_states
# 门控路径:hidden_states 决定"开门多少"
hidden_states = per_layer_input_gate(hidden_states) # Linear
hidden_states = act_fn(hidden_states) # 激活函数
# 逐元素乘法:门控信号 × PLE 调制
hidden_states = hidden_states * per_layer_input # 门控乘法
# 投影回 hidden_size + 归一化
hidden_states = per_layer_projection(hidden_states) # Linear
hidden_states = post_per_layer_input_norm(hidden_states)
# 残差连接
hidden_states = residual + hidden_states
门控设计的好处是:如果某一层不需要 PLE 信号,门控可以学习为接近 0,不影响原有的信息流;如果需要强调某些 token 的层级特异性,门控就会打开。这比简单的加法更灵活,也更容易训练。
3.4 导出策略:PLE Embedding 的量化
PLE Embedding 表体积不小——vocab_size(262144) × num_layers × ple_dim。以 E4B(42 层)为例,bf16 下约 1.3GB,占整个模型不小的比例。
我们支持用 --embed_bit 参数将 PLE Embedding 量化到 int4/int8,复用主 embedding 已有的量化/反量化路径:
# bf16(默认)
python llmexport.py --path model_path --export mnn
# int4 量化,PLE embedding 体积缩小约 4x
python llmexport.py --path model_path --export mnn --embed_bit 4
# int8 量化
python llmexport.py --path model_path --export mnn --embed_bit 8
量化格式与主 embedding 完全一致——使用分块量化(default block_size=64),存储 [weight, alpha] 二进制格式,配置通过 ple_quant = [offset, weight_size, alpha_size, quant_bit, quant_block] 传递给 C++ 推理端。
C++ 端的 DiskEmbedding 类同时支持 bf16 和量化格式的 PLE 加载,推理时 embedding() 阶段同时完成主 embedding 和 PLE embedding 的查表:
// PLE embedding lookup(C++)
if (mPleEmbedding && (!mPleInput.get() || seq_len == 1)) {
mPleInput = _Input({1, seq_len, ple_dim}, NCHW);
mPleEmbedding->embedding(input_ids, mPleInput->writeMap<float>());
if (ple_scale != 1.0f) {
mPleInput = mPleInput * _Scalar<float>(ple_scale);
}
}
分块 prefill 时,PLE 需要跟 input embedding 一起按 block 切分,确保每个 block 的 PLE 与 token 对齐。
四、KV Cache 共享:用 2 层的 KV 撑 20 层的推理
4.1 机制解析
KV 共享同样继承自 Gemma3n。在 Gemma3n-E2B(30 层)中,最后 10 层(层 20-29)共享层 18 和层 19 的 KV,苹果的端侧 3B 模型也使用了类似方案。
Gemma4 将 KV 共享的规模大幅扩展——E2B 从 10 层共享扩展到 20 层共享,共享比例从 33% 提升到 57%。
以 Gemma4 E2B(35 层)为例:
- 层 0-14:每层独立计算和存储自己的 KV Cache(15 层独立 KV)
- 层 15-34:不再计算和存储自己的 KV,而是直接复用层 13 或层 14 的 KV Cache
具体的共享拓扑:
独立 KV 层(layers 0-14):每层计算并存储自己的 K/V
Source 层(提供 KV 给他人复用):
Layer 13 (sliding_attention) → 被 16 个层复用
Layer 14 (full_attention) → 被 4 个层复用
Shared 层(借用 source 层的 KV):
15→13, 16→13, 17→13, 18→13, 19→14,
20→13, 21→13, 22→13, 23→13, 24→14,
25→13, 26→13, 27→13, 28→13, 29→14,
30→13, 31→13, 32→13, 33→13, 34→14
规律非常清晰:每 5 层中 4 层共享层 13(sliding attention),1 层共享层 14(full attention)。共享层依然有自己的 Q 投影和 O 投影,只是 KV 不再独立计算。
这意味着 35 层模型只需要 15 份独立的 KV Cache(层 0-14),而不是 35 份。KV 内存直降 57%。
E4B 的情况类似——42 层中有 18 层共享 KV,同样只需要 24 份独立 KV Cache,KV 内存降低约 43%。
4.2 为什么选层 13 和 14
这并非随机选择。层 13 是 sliding attention 层,层 14 是 full attention 层。共享的 20 层中:
- Sliding attention 层(窗口注意力)共享层 13 的 KV——因为窗口注意力本身只关注局部上下文,共享一份局部 KV 就够了
- Full attention 层(全局注意力)共享层 14 的 KV——全局注意力需要看到完整的历史 KV
这种”分类型共享”的设计比简单的”所有层共享同一份 KV”更精细,保留了混合注意力的局部/全局分工。
4.3 实现细节
HuggingFace 模型在每个 Attention 上标注了三个属性:
-
is_kv_shared_layer:是否是”借用方” -
kv_shared_layer_index:借用哪一层的 KV -
store_full_length_kv:是否是”提供方”
Python 导出时,这些属性传递给 FusedAttention 自定义算子,写入 ONNX 图:
kv_shared_idx = self.kv_shared_layer_index if self.is_kv_shared_layer else -1
self.fused_attn = FusedAttention(
...,
layer_index=layer_id,
kv_shared_layer_index=kv_shared_idx
)
C++ 推理端,CPUAttention 通过一个层级注册表(kv_registry)管理 KV 共享关系:
// 创建 Attention 时注册到 kv_registry
if (layerIndex >= 0 && meta) {
meta->kv_registry[layerIndex] = attn;
}
// 共享层在 forward 时直接引用 source 层的 KV Cache
if (mKVSharedLayerIndex >= 0
&& mMeta->kv_registry[mKVSharedLayerIndex] != nullptr) {
auto sourceKV = static_cast<CPUAttention*>(
mMeta->kv_registry[mKVSharedLayerIndex])->getKVCacheManager();
// 使用 sourceKV 而非自己的 KV Cache
}
为了支持 KV 共享,我们还把原来散布在 4 个文件中的 KVMeta 结构定义统一到了 source/core/KVMeta.hpp,新增 kv_registry(层级注册表)和 kv_shared_map(名称映射表)两个字段。同时给 FlatBuffers 的 AttentionParam 新增了 layer_index 和 kv_shared_layer_index 两个字段。
4.4 内存节省量化分析
以 E2B 模型为例(head_dim=256, kv_heads=1, fp16 KV Cache):
| 指标 | 无 KV 共享 | KV 共享 | 节省 |
|---|---|---|---|
| 独立 KV 层数 | 35 | 15 | -57% |
| 1K token KV Cache | ~17.5 MB | ~7.5 MB | -57% |
| 4K token KV Cache | ~70 MB | ~30 MB | -57% |
| 16K token KV Cache | ~280 MB | ~120 MB | -57% |
对于 E2B 这样定位手机端侧的模型来说,KV Cache 的 57% 缩减意义重大——直接决定了在 4GB 内存的设备上能处理多长的上下文。
五、多模态:视觉与音频
Gemma3n 就已经是原生多模态模型(文本+视觉+音频),Gemma4 延续了这一设计,但视觉编码器架构发生了重大变化——从 Gemma3n 的 MobileNetV5(轻量级卷积网络)切换到了 SigLIP ViT(Vision Transformer),音频编码器则延续了 Conformer 架构。
5.1 视觉编码器
Gemma3n 使用 MobileNetV5 作为视觉编码器,输出 256 个 soft token(vision_soft_tokens_per_image=256),hidden_size=2048。Gemma4 则切换到 ViT + Pooler + Embedder 三段式架构,引入了 2D position_ids 和 3×3 pooling:
Image → Patchify (16×16) → ViT Encoder → Pooler (3×3) → RMSNorm + Linear → image_embeds
Step 1: Patchify
图像按 16×16 切 patch,但在切之前会做两件事:
- 保持原始宽高比(不做强制 resize 到正方形)
- 宽高对齐到 48 像素(
patch_size(16) × pooling_kernel_size(3)),确保后续 pooling 整除
Step 2: ViT Encoder
标准 Vision Transformer,但使用 2D position_ids(行列坐标)而非 1D 序号。这意味着同一行的 patch 共享行坐标,同一列的共享列坐标,比 1D 序号更好地保留了空间结构信息。
Step 3: Pooler
3×3 pooling 将 patch 数压缩 9 倍。例如一张 672×480 的图片产生 42×30=1260 个 patch,pooling 后变成 14×10=140 个 soft token。默认最大输出为 280 个 soft token。
Step 4: Embedder
RMSNorm + Linear 将视觉特征投影到文本 hidden_size 空间,使图像 token 可以与文本 token 在同一个 Decoder 中处理。
导出时手动展开了编码器 forward,避免 HF 代码中 unfold 和动态 mask 的 ONNX trace 问题。C++ 端的视觉处理同样手动实现了 patchify 和 position_ids 生成,并处理了 padding 对齐:
// C++ 端视觉处理核心参数
int patch_size = 16;
int pooling_kernel_size = 3;
int default_output_length = 280; // 最大 soft token 数
int max_patches = default_output_length * pooling_kernel_size * pooling_kernel_size; // 2520
5.2 音频编码器
Gemma3n 和 Gemma4 的音频编码器都基于 Conformer 架构,Gemma4 使用的是 USM(Universal Speech Model) 变体,核心差异在于 chunked attention 的引入——音频被分成固定长度的 chunk,每个 chunk 只关注自身和前后 context 窗口内的帧。
Audio → Mel Spectrogram → Subsample Conv → Conformer Encoder → RMSNorm + Linear → audio_embeds
Mel 特征提取采用 USM 格式:
- 采样率:16000 Hz
- 帧长:320 点(20ms)
- 帧移:160 点(10ms)
- FFT 点数:512
- Mel 滤波器:128 维
- 频率刻度:HTK
Conformer Encoder 是 USM 的标准结构:
FeedForward1 → Attention (chunked) → LightConv1D → FeedForward2
其中 Attention 使用分块滑窗——音频帧被切成等长 chunk,每个 chunk 只关注自身以及左右各若干帧的 context。这避免了对长音频做全局 attention 导致的 O(n²) 开销。
ONNX 导出中最大的挑战是 HF 实现中使用了大量 unfold 操作来构建 chunked attention 的滑动窗口。unfold 在 ONNX trace 时会产生不可控的动态维度。我们实现了 Gemma4AudioExportModel,用 index-gather 替代 unfold:
# 用 gather 替代 unfold 构建 chunk context
offsets = torch.arange(context_size)
block_starts = torch.arange(num_blocks) * chunk_size
indices = (block_starts.unsqueeze(1) + offsets.unsqueeze(0)).reshape(-1)
result = padded[:, indices].reshape(B, num_blocks, context_size, H, D)
5.3 多模态与 PLE 的交互
多模态输入时,PLE 需要特殊处理:图像/音频位置的 token 不能使用对应的多模态 token ID 查 PLE embedding 表。因为 <image_token> 和 <audio_token> 这些特殊 token 在 PLE 表中的 embedding 没有经过有意义的训练,直接使用会引入噪声。
解决方案是将多模态位置的 token ID 替换为 pad_token_id:
ple_ids = input_ids.clone()
for attr in ['image_token_id', 'audio_token_id', 'video_token_id']:
token_id = getattr(config, attr, None)
if isinstance(token_id, int):
ple_ids[ple_ids == token_id] = pad_token_id
主 embedding 的投影输入也做同样处理——多模态位置使用 pad embedding 而非视觉/音频 embedding 来计算 PLE 投影。这保证了 PLE 在多模态场景下的数值稳定性。
六、其他架构适配
6.1 QKV 三 Norm + attn_scale=1.0
Gemma4 的注意力模块有一个独特设计:不仅对 Q 和 K 做 RMSNorm(这在 Gemma3 中就有),还对 V 也做了 RMSNorm(v_norm)。同时,注意力缩放因子固定为 1.0,不再使用传统的 1/√head_dim。
这背后的逻辑是:三个 Norm 已经将 Q、K、V 的范数控制在稳定范围内,注意力 score 的数值就自然稳定了,不再需要额外的缩放。这也是为什么 Gemma4 的 head_dim 可以做到 512(Full Attention)——传统做法下 head_dim=512 意味着 1/√512 ≈ 0.044 的极小缩放因子,容易导致梯度消失。
6.2 K=V 与异构 KV Head(31B / 26B-A4B)
31B 和 26B-A4B 的 Full Attention 层没有独立的 v_proj——K 和 V 源自同一个线性投影 k_proj,之后 K 走 k_norm + RoPE,V 走 v_norm(不经过 RoPE)。导出时在 k_proj 之后 clone 一份给 Value:
# 自动检测 K=V
if not hasattr(attn, 'v_proj'):
k_eq_v = True
# k_proj output → clone → k_norm + RoPE (for K)
# → v_norm (for V, no RoPE)
这两个模型还有异构 KV Head:Sliding 层和 Full 层使用不同的 KV Head 数。以 31B 为例,Sliding 层 16 个 KV Head,Full 层仅 4 个 KV Head。我们从 k_proj.out_features // head_dim 自动推断,不需要在配置中硬编码。
6.3 Decoder 级 MoE(26B-A4B)
26B-A4B 是 Gemma4 中唯一的 MoE 模型,采用”稠密 MLP 与 128 专家 Top-8 并行“的设计——两个分支独立处理 input,各自做 LayerNorm 后相加:
hidden_states
├─→ dense MLP → post_feedforward_layernorm_1 ─→ +
└─→ MoE(128 experts, top-8) → post_feedforward_layernorm_2 ─→ +
│
residual add
这与 Qwen3.5 的 Shared Expert 设计思路类似——稠密 MLP 提供”保底能力”,MoE 专家提供”专业能力”。但 Gemma4 的 MoE 规模更大:128 个专家、Top-8 路由(Qwen3.5 MoE 是 64 专家 Top-8)。
HuggingFace 原始模型中专家权重是 3D Parameter([num_experts, hidden, intermediate]),导出时拆解为 ModuleList,复用 MNN 已有的 MoE 自定义算子管线。路由权重经过 softmax → topk → normalize → per_expert_scale 的标准流程。
6.4 双 RoPE 与部分旋转
Gemma4 的 Sliding 和 Full Attention 层不仅 head_dim 不同(256 vs 512),连 RoPE 的 theta 也不同:
| 层类型 | head_dim | rope_theta | 含义 |
|---|---|---|---|
| Sliding | 256 | 10,000 | 短距离、高频、局部编码 |
| Full | 512 | 1,000,000 | 长距离、低频、全局编码 |
Full Attention 层还支持 partial_rotary_factor(部分旋转),仅对 head_dim 中的一部分维度应用旋转编码,其余维度保持位置无关。实现上复用了 Phi 系列模型的分支:
if self.model_type in ['phi-msft', 'gemma4']:
return self.phi_rotary_pos(x, cos, sin) # 部分旋转
6.5 CUDA OOM 修复
大模型导出时,量化阶段会将权重搬到 GPU 加速。原来的实现只在 weight.cuda() 外套了 try/except 捕获 OOM,但实际 OOM 往往发生在后续的计算步骤中(如 (q_weight * multipliers).sum(axis=1))。
修复方案是将 GPU 量化拆分为独立的 _quant_on_device() 函数,整个 GPU 计算路径包裹在 try/except 中:
def quant(weight, quant_bit, quant_block, symmetric, awq, hqq):
if torch.cuda.is_available():
try:
torch.cuda.empty_cache()
return _quant_on_device(weight.cuda(), ...)
except torch.cuda.OutOfMemoryError:
torch.cuda.empty_cache() # 释放 GPU 内存,fallback 到 CPU
return _quant_on_device(weight, ...) # CPU fallback
这样大部分权重仍在 GPU 上快速完成量化,只有超大权重(如 PLE Embedding)自动回退到 CPU,兼顾速度和稳定性。
七、导出与使用
# E2B 模型导出(int4 量化,PLE int4 量化)
python llmexport.py --path gemma-4-E2B-it --export mnn --hqq --embed_bit 4
# E4B 模型导出
python llmexport.py --path gemma-4-E4B-it --export mnn --hqq --embed_bit 4
# 推理
./llm_demo /path/to/exported/model/config.json prompt.txt
八、E2B / E4B 端侧性能测评
测试环境:Apple M3 Pro 芯片 macOS 设备,MNN 最新版本,HQQ 4-bit 量化 + PLE int4 量化,CPU 后端,4 线程。
8.1 转换大小
| 模型 | LLM 权重 | PLE Embedding | 视觉编码器 | 音频编码器 | 总大小 |
|---|---|---|---|---|---|
| E2B | 1.3 GB | 1.4 GB | 215 MB | 563 MB | 3.5 GB |
| E4B | 2.7 GB | 1.6 GB | 216 MB | 575 MB | 5.2 GB |
值得注意的是,PLE Embedding 即使在 int4 量化后仍然占据了相当大的比例——E2B 的 PLE 甚至比 LLM 权重本身还大。这是 PLE 架构”用 Flash 存储换运行内存”设计思路的直接体现:这些参数从磁盘按需加载,不常驻运行内存。
音频编码器(USM Conformer)在两个模型间大小几乎相同(~570 MB),视觉编码器(SigLIP ViT)同样如此(~215 MB),说明多模态编码器是共享架构、独立于 LLM 主干缩放的。
8.2 推理速度
注意:Gemma4 对激活精度要求较高,必须使用 Normal(fp32)精度才能保证输出质量(详见 8.5 节精度分析)。以下数据均在 Normal 精度下测得。作为对比,Qwen3.5 在 Low(fp16)精度下即可正常工作。
在 CPU(Apple Silicon M3 Pro)后端、4 线程、Normal 精度、KV Cache 开启的条件下:
| 模型 | 模型加载时间 | Prefill 速度 (提示词处理) | Decode 速度 (生成) |
|---|---|---|---|
| E2B | ~1.3 s | ~186 tok/s | ~35 tok/s |
| E4B | ~2.1 s | ~102 tok/s | ~24 tok/s |
解读:
- E2B 的 prefill 速度达到 186 tok/s,decode 约 35 tok/s,对于端侧模型来说流畅度良好。KV 共享(57% KV Cache 缩减)在这里功不可没——更少的 KV Cache 意味着更少的内存带宽消耗。
- E4B 由于 hidden_size 从 1536 增加到 2560、层数从 35 增加到 42,整体速度约为 E2B 的 60-70%,但 24 tok/s 的 decode 速度仍然可以提供可接受的流式输出体验。
- 由于 Gemma4 必须使用 fp32 激活(而 Qwen3.5 只需 fp16),速度差距会进一步拉大。对比同级别的 Qwen3.5 系列(M3 Pro + 4线程 + fp16),Qwen3.5-2B 的 decode 为 70 tok/s(E2B 的 2 倍),Qwen3.5-4B 为 60 tok/s(E4B 的 2.5 倍)。
8.3 问题测试
为了评估端侧模型的实际能力,我们用与 Qwen3.5 评测相同的题目进行测试。测试使用 Normal(fp32)精度。
逻辑陷阱测试
1. 洗车逻辑陷阱
题目:”距离我 30 米有家洗车店,我是开车去洗好还是走路去好?”
- E2B:从便利性、体力、天气等多角度做了详细的利弊分析,条理清晰、没有重复,但未能识破题目的核心陷阱——洗车必须把车带过去,所以只能开车。
- E4B:同样做了非常详细的分析,列出了开车和走路各自的优缺点,甚至给出了”快速决策树”,但同样未识破陷阱,没有意识到”洗车需要车”这个前提。
2. Strawberry 有几个字母 r?
- E2B:回答”2 个 r”。错误(正确答案是 3 个)。
- E4B:回答”3 个 r”。正确。
3. 树上有 10 只鸟,猎人开枪打死 1 只,树上还剩几只?
- E2B:直接回答”9 只”,用数学 10-1=9 解题,未能理解枪声会吓跑其他鸟的常识逻辑。
- E4B:给出了两种理解——”如果作为数学题是 9 只,如果作为脑筋急转弯是 0 只”,展现了较好的语境理解能力,但最终倾向 9 只。部分正确。
4. 鲁迅和周树人是什么关系?
- E2B:正确回答——”鲁迅就是周树人,周树人是鲁迅的本名”,并补充了生卒年份等信息。
- E4B:正确回答——”鲁迅和周树人是同一个人,周树人是鲁迅本名,鲁迅是他的笔名”,简洁准确。
通用能力测试
1. 自我介绍
- E2B:简洁准确——”我是一个大型语言模型,由 Google DeepMind 开发。”
- E4B:自称”Gemma,由 Google 训练”,并描述了自己的功能,完整流畅。
2. 数学计算:(17 × 23) + (45 ÷ 9) - 18 = ?(正确答案:378)
- E2B:正确。使用分配律展开 17×23=391,45÷9=5,最终 391+5-18=378,过程清晰。
- E4B:正确。同样给出了详细的分步计算,结果 378,并附上了步骤回顾。
3. 代码生成:Python 判断质数
- E2B:正确完整。生成了优化版的
is_prime函数,包含偶数快速排除和只检查奇数的优化,附带测试用例。 - E4B:正确完整。同样生成了带完整注释和测试用例的
is_prime函数,代码质量高。
4. 中英翻译
- E2B:提供了三个翻译版本,核心翻译准确流畅:”Deep learning is changing the way we interact with computers, from speech recognition to autonomous driving; its applications are ubiquitous.” 输出完整无重复。
- E4B:同样给出三个翻译选项和关键词汇解释,输出完整流畅。
测试总结
| 测试项 | E2B | E4B |
|---|---|---|
| 洗车逻辑 | △ 分析详细但未识破陷阱 | △ 分析详细但未识破陷阱 |
| Strawberry | ✗ 回答 2 个 | ✓ 正确(3个) |
| 打鸟常识 | ✗ 回答 9 只 | △ 提到 0 只但倾向 9 只 |
| 鲁迅=周树人 | ✓ 正确 | ✓ 正确 |
| 自我介绍 | ✓ 简洁准确 | ✓ 完整流畅 |
| 数学计算 | ✓ 正确(378) | ✓ 正确(378) |
| 代码生成 | ✓ 正确完整 | ✓ 正确完整 |
| 翻译 | ✓ 质量好 | ✓ 质量好 |
| 通过率 | 5/8 | 6/8 |
关键发现:
- E4B 整体略优于 E2B,尤其在 Strawberry 计数和打鸟常识等需要更强推理能力的题目上,更大的模型容量带来了可见的提升。
- 两个模型在基础任务上表现出色——自我介绍、数学计算、代码生成、翻译均通过,输出完整流畅,没有重复循环问题。
- 逻辑陷阱仍然是短板,洗车题两个模型都未能识破”洗车必须带车”的核心前提,这与端侧小模型的常识推理能力限制有关。
- 翻译是两个模型的共同强项,英文输出质量尤为出色,这与 Gemma 系列以英文为主的训练数据分布一致。
8.4 与 Qwen3.5 同级别模型横向对比
Gemma4 E2B / E4B 分别对标 Qwen3.5-2B / Qwen3.5-4B,同为端侧定位的小参数量多模态模型。以下数据均来自同一台 Apple M3 Pro 设备、CPU 4 线程、HQQ 4-bit 量化的测试环境。注意:Gemma4 使用 Normal(fp32)精度,Qwen3.5 使用 Low(fp16)精度——两者均为各自能正常工作的最低精度配置。
模型规格对比
| 对比项 | Gemma4 E2B | Qwen3.5-2B | Gemma4 E4B | Qwen3.5-4B |
|---|---|---|---|---|
| 层数 | 35 | 24 | 42 | 32 |
| 隐藏维度 | 1536 | 2048 | 2560 | 2560 |
| 注意力机制 | 混合(Sliding+Full) | 混合(Linear+Full) | 混合(Sliding+Full) | 混合(Linear+Full) |
| KV Cache 优化 | KV 共享(57%) | Linear Attn(75%无KV) | KV 共享(43%) | Linear Attn(75%无KV) |
| 所需激活精度 | fp32 | fp16 | fp32 | fp16 |
| Embedding | PLE(每层独立) | 共享 Embedding | PLE(每层独立) | 共享 Embedding |
| 视觉 | SigLIP ViT | 内置视觉 | SigLIP ViT | 内置视觉 |
| 音频 | USM Conformer | 无 | USM Conformer | 无 |
两个系列的端侧 KV 优化思路截然不同:Qwen3.5 通过 75% Linear Attention 直接省掉大部分层的 KV Cache,而 Gemma4 通过层间 KV 共享减少独立 KV 份数。Gemma4 独有的 PLE 和音频模态使其在模型体积上显著更大。此外,Gemma4 对激活精度的要求更高(fp32 vs fp16),这进一步影响了推理速度和内存占用。
导出大小对比
| 模型 | 量化后总大小 | 纯 LLM 权重 | 额外开销 |
|---|---|---|---|
| Gemma4 E2B | 3.5 GB | 1.3 GB | PLE 1.4GB + 视觉 215MB + 音频 563MB |
| Qwen3.5-2B | 1.37 GB | ~1.37 GB | 视觉权重含在总大小内 |
| Gemma4 E4B | 5.2 GB | 2.7 GB | PLE 1.6GB + 视觉 216MB + 音频 575MB |
| Qwen3.5-4B | 2.59 GB | ~2.59 GB | 视觉权重含在总大小内 |
Gemma4 E2B 的总大小是 Qwen3.5-2B 的 2.6 倍,E4B 是 Qwen3.5-4B 的 2.0 倍。这主要由三部分额外开销构成:PLE Embedding(~1.4-1.6 GB)、音频编码器(~570 MB)和独立的视觉编码器权重(~215 MB)。如果不需要音频能力,PLE 仍然是不可忽略的体积增量。
推理速度对比
| 模型 | 激活精度 | Prefill (tok/s) | Decode (tok/s) | 首 Token 延迟 |
|---|---|---|---|---|
| Gemma4 E2B | fp32 | ~186 | ~35 | ~1.3 s |
| Qwen3.5-2B | fp16 | ~300 | ~70 | ~0.9 s |
| Gemma4 E4B | fp32 | ~102 | ~24 | ~2.1 s |
| Qwen3.5-4B | fp16 | ~250 | ~60 | ~1.1 s |
Qwen3.5 在推理速度上全面领先:
- 2B 级别:Qwen3.5-2B 的 decode 速度(70 tok/s)是 Gemma4 E2B(35 tok/s)的 2.0 倍
- 4B 级别:Qwen3.5-4B 的 decode 速度(60 tok/s)是 Gemma4 E4B(24 tok/s)的 2.5 倍
Gemma4 更慢的原因:(1) 需要 fp32 激活而 Qwen3.5 只需 fp16,这是最大的速度差距来源;(2) 层数更多(35/42 层 vs 24/32 层),意味着更多的矩阵运算;(3) 双 RoPE(Sliding head_dim=256 + Full head_dim=512)比 Qwen3.5 的统一 RoPE 计算量更大;(4) PLE 的门控注入每层都有额外的 Linear + 激活 + 乘法开销;(5) QKV 三 Norm 比 Qwen3.5 的 QK Norm 多一次归一化。
功能质量对比
| 测试项 | Gemma4 E2B (fp32) | Qwen3.5-2B (fp16) | Gemma4 E4B (fp32) | Qwen3.5-4B (fp16) |
|---|---|---|---|---|
| 洗车逻辑 | △ 详细分析未识破 | △ 结论对理由错 | △ 详细分析未识破 | ✗ 建议走路 |
| Strawberry | ✗ 回答 2 个 | ✗ 死循环 | ✓ 正确(3个) | ✓ 正确(3个) |
| 打鸟常识 | ✗ 回答 9 只 | ✗ 回答 9 只 | △ 提到 0 只但倾向 9 只 | ✓ 正确(0只) |
| 鲁迅=周树人 | ✓ 正确 | ✗ 完全错误 | ✓ 正确 | ✓ 正确 |
| 自我介绍 | ✓ | ✓ | ✓ | ✓ |
| 数学计算 | ✓ 正确(378) | △ | ✓ 正确(378) | ✓ |
| 代码生成 | ✓ 正确完整 | △ | ✓ 正确完整 | ✓ |
| 翻译 | ✓ 质量好 | △ | ✓ 质量好 | ✓ |
| 通过率 | 5/8 | 1/8 | 6/8 | 7/8 |
对比结论:
-
同为 fp16 时 Qwen3.5 优势明显:Qwen3.5 在 fp16 精度下即可正常工作且表现良好,而 Gemma4 在 fp16 下会出现灾难性的重复循环(详见 8.5 节),必须回退到 fp32 才能获得可用的输出质量。仅从”同精度可用性”角度看,Qwen3.5 的架构对低精度更友好。
-
即使 Gemma4 用 fp32,语言能力仍略逊一筹:4B 级别上 E4B(6/8)接近但未超过 Qwen3.5-4B(7/8);2B 级别上 E2B(5/8)虽大幅领先 Qwen3.5-2B(1/8),但 Qwen3.5-2B 本身定位偏低,这一优势参考价值有限。整体而言,纯语言任务上 Qwen3.5 的小模型系列表现更好。
-
速度和体积差距显著:Gemma4 依赖 fp32 激活,decode 速度仅为 Qwen3.5(fp16)的 40-50%;加上 PLE 和音频编码器,模型体积是 Qwen3.5 的 2-2.6 倍。
-
Gemma4 的差异化优势在多模态:Gemma4 独有原生音频理解能力,这是 Qwen3.5 端侧系列不具备的。如果应用场景涉及语音交互或音频分析,Gemma4 是目前端侧唯一的选择。纯文本场景下,Qwen3.5 在速度、体积和语言质量上综合更优。
8.5 精度敏感性:为什么 Gemma4 必须用 fp32
前文所有测试均在 Normal(fp32)精度下进行。这并非随意选择——Gemma4 在 Low(fp16)精度下会出现严重的 token 级重复循环,几乎所有任务都无法正常完成。
fp16 下的灾难性退化
将 config.json 中的 precision 设为 "low"(fp16)后,两个模型的表现:
| 测试项 | E2B (fp16) | E4B (fp16) |
|---|---|---|
| 洗车逻辑 | ✗ 重复循环 | ✗ 重复循环 |
| Strawberry | ✗ 回答 2 个 | ✗ 回答 1 个 |
| 打鸟常识 | ✗ 识别陷阱后重复循环 | ✗ 无法理解中文输入 |
| 鲁迅=周树人 | ✗ 完全错误+重复循环 | ✗ 重复循环 |
| 自我介绍 | ✓ | △ 简繁混杂不完整 |
| 数学计算 | ✗ 切换到印地语+重复 | ✗ 未算出结果 |
| 代码生成 | △ 有语法错误+重复 | ✗ 未完成+重复 |
| 翻译 | ✓ 但末尾轻微重复 | △ 翻译好但术语部分重复 |
| 通过率 | 2/8 | 0/8 |
对比 fp32 下 E2B 5/8、E4B 6/8 的通过率,差距是毁灭性的。fp16 下 E4B 甚至不如 E2B,多次出现语言混杂(简繁混杂、中英切换、甚至冒出印地语)和完全无法理解中文输入的情况。
fp16 速度更快,但不可用
| 模型 | 精度 | Prefill (tok/s) | Decode (tok/s) |
|---|---|---|---|
| E2B | fp16 | ~229 | ~44 |
| E2B | fp32 | ~186 | ~35 |
| E4B | fp16 | ~114 | ~24 |
| E4B | fp32 | ~102 | ~24 |
fp16 确实快 10-20%,但输出质量的崩塌使这个速度优势毫无意义。
原因分析:为什么 Gemma4 对精度敏感而 Qwen3.5 不敏感
Qwen3.5 在 fp16 下即可正常工作,不存在类似的精度退化。两者的关键架构差异解释了这一现象:
-
attn_scale=1.0:Gemma4 不做传统的
1/√head_dim注意力缩放,完全依赖 QKV 三 Norm 来稳定 attention score。fp16 下 Norm 的精度损失会导致 attention score 范围失控。Qwen3.5 使用标准的1/√head_dim缩放,对 Norm 精度的依赖更低。 -
层数更深:E2B 35 层、E4B 42 层,比 Qwen3.5 同级别多 11-10 层。fp16 的误差在更多层中累积,最终导致注意力分布退化为重复模式。
-
PLE 门控注入:每层额外的 Linear → 激活 → 逐元素乘法路径引入了更多的浮点运算环节,每个环节都会在 fp16 下损失精度。
-
双 RoPE:Full Attention 层使用 head_dim=512 + theta=1000000 的大尺度旋转编码,在 fp16 下容易出现三角函数计算的精度问题。
结论:Gemma4 对激活精度的要求显著高于 Qwen3.5,这是其架构设计的固有特性。部署时必须使用 Normal(fp32)精度,config.json 中设置 "precision": "normal"。
九、各版本选型建议
E2B:极致端侧,手机优先
- 定位:4-8GB 内存设备的首选模型
- 优势:KV 共享带来极低内存占用;35 层架构在端侧参数量下仍保持不错的推理能力
- 劣势:hidden_size 仅 1536,知识容量有限
- 推荐场景:手机助手、离线问答、IoT 设备、嵌入式应用
E4B:桌面端侧,平衡之选
- 定位:8-16GB 内存设备的主力模型
- 优势:相比 E2B 显著提升的模型能力,42 层 + hidden=2560 的深度和宽度组合带来更好的推理质量
- 劣势:需要更多内存,不适合低端手机
- 推荐场景:平板电脑、笔记本、桌面端 AI 助手、本地多模态分析
31B:服务器级,高质量推理
- 定位:16GB+ 内存的高性能部署
- 优势:60 层、hidden=5376 提供接近旗舰模型的推理质量;K=V 机制减少了参数量
- 劣势:资源需求高,端侧不可行
- 推荐场景:服务端推理、企业级应用、专业级多模态理解
26B-A4B:MoE 架构,高效推理
- 定位:26B 总参数、仅 4B 激活参数
- 优势:128 专家 Top-8 MoE 设计实现了”大模型知识、小模型算力”的效果;每次推理只激活约 4B 参数
- 劣势:模型文件体积大(所有专家权重都需要存储);MoE 路由开销
- 推荐场景:服务端需要高质量但低推理延迟的场景、多任务并发
十、总结
| 机制 | E2B | E4B | 31B | 26B-A4B |
|---|---|---|---|---|
| 混合注意力 + 双RoPE | ✓ | ✓ | ✓ | ✓ |
| PLE(每层调制信号) | ✓ | ✓ | ✓ | ✓ |
| QKV 三 Norm | ✓ | ✓ | ✓ | ✓ |
| KV 共享 | ✓(20层) | ✓(18层) | — | — |
| K=V (无 v_proj) | — | — | ✓ | ✓ |
| 异构 KV Head | — | — | ✓ | ✓ |
| MoE (128专家 top-8) | — | — | — | ✓ |
| 视觉 | ✓ | ✓ | ✓ | ✓ |
| 音频 | ✓ | ✓ | ✓ | ✓ |
Gemma4 的适配共涉及 34 个文件、+1816/-401 行代码。设计原则是从权重形状自动推断而非硬编码——k_eq_v 通过检测 v_proj、num_key_value_heads 从 k_proj 推断、MoE 检测 experts 属性——一套代码自动适配四种规格。
三个核心创新的适配总结:
- PLE:两路查表 + 门控注入,支持 bf16/int4/int8 量化导出,C++ DiskEmbedding 统一加载
- KV 共享:层级注册表 + FlatBuffers 扩展,E2B 节省 57% KV 内存,使端侧长上下文成为可能
- 多模态:视觉(ViT + 3×3 Pooler)+ 音频(USM Conformer),unfold→gather 解决 ONNX trace 问题
值得一提的是,本次 Gemma4 全系列的适配代码 100% 由 AI 生成,基于我们内部研发的 s-coder 平台(MNN LLM支持 Agent)完成。人工参与集中在关键节点验证、debug 思路分析和部分实现方案的引导上——从架构分析、需求理解、代码生成、调试修复到最终验证,Agent 全程自主完成了 34 个文件、2000+ 行代码的编写。这也是我们在大型推理引擎项目上验证 AI Coding 落地效果的一次实践。
资源链接
- Github Issue 需求: https://github.com/alibaba/MNN/issues/4340
- Github 代码实现: https://github.com/alibaba/MNN/commit/ba76938ec0935a6bf7f01f348d975f9c8d8b5816
- HuggingFace 模型合集: https://huggingface.co/collections/taobao-mnn/gemma-4-mnn
- ModelScope 模型合集: https://modelscope.cn/collections/MNN/gemma-4-MNN
- Gemma3n 架构分析: https://zhuanlan.zhihu.com/p/1929221256969945147
Enjoy Reading This Article?
Here are some more articles you might like to read next: