STEM:用“Embedding”替代“up_proj”

背景:端侧推理的 FFN 痛点

在端侧 LLM 推理场景中,由于上下文通常不会拉得特别长,相比于 Attention 模块,FFN 实际上才是性能的隐形杀手

在典型的 Transformer 架构中,FFN 占据了约 2/3 的参数量,产生大量的 FLOPs 以及 Decode 阶段的访存带宽压力。对于内存受限的手机或端侧设备来说,FFN 的大矩阵乘法是最大的瓶颈。我一直在关注是否有更适合端侧的稀疏化方案。

最近看到 CMU 和 Meta AI 的 STEM: Scaling Transformers with Embedding Modules (Arxiv: 2601.10639) 比较有意思。


STEM 的核心思路

STEM(Scaling Transformers with Embedding Modules)不仅是一个新架构,更像是一种“反向操作”。

核心改动: 它直接移除了 FFN 层中的 up_proj 矩阵(升维操作),将其替换为基于 Token ID 的 Embedding 查表。

1. 公式对比

标准的 SwiGLU FFN 计算: \(FFN(x) = (x W_{up}) \odot \text{Sigmoid}(x W_{gate}) W_{down}\)

STEM 的 FFN 计算: \(FFN_{STEM}(x, t) = \text{Embedding}(E, t) \odot \text{Sigmoid}(x W_{gate}) W_{down}\) 其中 $t$ 是原始输入的 Token ID,$E$ 是该层独立的 Embedding 表。

2. 代码实现

看起来非常简单,把计算量最大的升维操作变成了 $O(1)$ 的查表:

class STEM_FFN(nn.Module):
    def __init__(self, vocab_size, hidden_dim, intermediate_dim):
        super().__init__()
        # 1. 把原来的 Linear(hidden_dim, intermediate_dim) 换成了 Embedding
        # 注意:每一层都有独立的 Embedding 表
        self.up_embedding = nn.Embedding(vocab_size, intermediate_dim)

        # 2. Gate 和 Down 依然是 Dense Linear,保留上下文能力
        self.gate_proj = nn.Linear(hidden_dim, intermediate_dim, bias=False)
        self.down_proj = nn.Linear(intermediate_dim, hidden_dim, bias=False)

    def forward(self, x, input_ids):
        # x: [batch, seq, hidden]
        # input_ids: [batch, seq] (原始输入的token id)

        # 不再计算 x @ W_up,而是直接查表
        up_state = self.up_embedding(input_ids)
        gate_state = F.sigmoid(self.gate_proj(x))
        return self.down_proj(up_state * gate_state)
    )

隐忧与亮点

1. 隐忧:大词表怎么办?(The Large Vocabulary Problem)

看到上面的代码,第一反应通常是:参数量会不会爆炸? 像 Qwen 这种 vocab_size 高达 15w 的模型,如果每一层都存一个完整的 [150000, 11008] 的 Embedding 表,显存根本塞不下。

解决方案: 论文提出使用 Low-Rank Factorization(低秩分解)。 不直接存大表,而是将其拆解为两个小矩阵(例如先查出一个 512 维的小向量,再投影回 11008 维)。 虽然这引入了一次额外的小矩阵计算(约占原 FFN 计算量的 3%),但能将参数量压缩 20 倍以上,兼顾了低计算量与可控的存储体积。

2. 亮点

除了效率,STEM 架构还带来了两个有趣的特性:

  • 长上下文能力暴涨(Long Context): 传统模型的 FFN 参数是固定的“大锅饭”,长文中容易出现知识干扰。而 STEM 随着序列变长,被激活的 Embedding 参数变多,模型的“有效记忆容量”是动态增加的。Meta 的实验显示其在长窗口任务(如 Needle In A Haystack)上表现优异。
  • 极致的可解释性与知识注入: 想更新模型关于“Apple”的知识?你不需要微调整个网络,只需要修改该 Token 对应的 Embedding 条目。这种精确的知识编辑(Knowledge Editing) 能力是传统 Dense 模型难以具备的。

核心优势:Offloading 到闪存

STEM 相比原始 FFN,最大的工程价值在于极致的动静分离

目前主流的 MoE 虽然也有稀疏性,但由于 router 的结果依赖 Attention 之后的计算,我们无法提前预知会用到哪个专家,这导致必须把所有 Expert 放在内存中(或者忍受极高的延迟)。

而 STEM 完全不同:

  1. up_embedding 是只读的查表。
  2. 索引是 Token ID:这意味着我们在推理的一开始就知道需要读取哪些参数。

因此,我们可以利用异步预取(Prefetch) 技术,将这部分参数(约占 FFN 总参数的 1/3)直接存储在 闪存(Flash/SSD) 中,仅在计算需要时加载到内存。这对于内存首先得端侧设备是巨大的利好。


思考:为什么它能 work?(对比 Pre-gated MoE)

用“原始 token”做索引的方案,很容易让人联想到 2023 年的 Pre-gated MoE (Arxiv: 2308.12066)。

Pre-gated MoE 试图在 Attention 之前路由,结果普遍效果不好。原因在于 Hidden State 还没融合上下文,Router 处于“盲选”状态。

那为什么 STEM 甚至直接用 Token ID 这种更原始的信息做索引,反而有效?

根本原因在于“硬路由”与“软门控”的区别:

  • Pre-gated MoE (硬路由 Failure): 它是做选择题。Router 没看上下文,把“Apple(公司)”分给了“水果专家”,后面怎么算都是错的。这种信息截断是致命的。

  • STEM (软门控 Success): 它是做过滤题。 STEM 的 Embedding 向量利用高维空间的叠加性,在同一个向量里存储了“Apple”作为水果、公司、甚至唱片的所有特征。 关键在于保留的 Gate。这个 Gate 是作用在经过 Attention 后的 $x$ 上的!它拥有完美的上下文信息。

    流程是:

    1. Embedding: 把“Apple”的所有含义一股脑全拿出来。
    2. Gate: 看着上下文说,“这里讨论的是 iPhone,只要科技属性”。
    3. 相乘: 抑制水果属性,激活科技属性。

结论:STEM 没有在索引阶段做决策,它只是负责检索全量信息,把决策权留给了后面的 Gate。


总结

虽然将矩阵乘法替换为查表是否能在所有任务中保持泛化能力还有待大规模验证。但从端侧推理的角度看,STEM 提供了一种以存储换计算、以 IO 换显存的绝佳思路。如果后续开源模型能验证其 scaling law,确实是一个不错的优化方案。




Enjoy Reading This Article?

Here are some more articles you might like to read next:

  • jinja.cpp:为什么我要手写一个 Jinja2 编译器
  • LLM Super Weight 实测:剪枝降智与量化思考
  • MNN支持Eagle3
  • LLM训练实战手册
  • MNN模型支持:Qwen3-VL