TL;DR#
MXU(Matrix Multiply Unit)是 TPU 做矩阵乘法的核心硬件。它的底层是一个 脉动阵列(systolic array)——权重预先加载到阵列中固定不动,激活值从左侧逐行流入,部分和从上往下逐层累加,流到底部时内积就算完了。
脉动阵列的精妙在于它的"斜切输入"设计——A 矩阵的同一行元素沿对角线错位进入阵列的不同行,而部分和从上往下流动的速度恰好和这个错位完美匹配。不需要任何复杂的调度逻辑,数据自然地在正确的时刻汇合到正确的 cell——这就是"脉动"(systolic)的含义,像心跳一样,数据以固定的节律在阵列中脉动前进。
本文通过模拟一个 4×4 的 MXU,逐 cycle 展示每个计算单元在做什么、数据从哪来到哪去,让你直观理解脉动阵列为什么能高效地做矩阵乘法。
MXU 在 TPU 中的位置#
每代 TPU 芯片的 MXU 配置不同:
| 芯片 | TensorCore 数 | 每个 TC 的 MXU 数 | MXU 尺寸 |
|---|---|---|---|
| v5e 及之前 | 1 | 1 | 128×128 |
| v6e | 1 | 2 | 256×256 |
| v7x | 2 | 2 | 256×256 |
MXU 由脉动阵列中的乘法累加器(MAC)组成(参考)。256×256 的 MXU 意味着每个时钟周期可以完成 256×256×2 = 131072 次浮点运算——这就是 TPU 矩阵乘法吞吐量的来源。
那么问题来了:MXU 为什么能这么快?每一拍里它的每个计算单元在做什么?数据是怎么流过去的?
脉动阵列的核心规则#
在开始逐 cycle 追踪之前,先理解每个计算单元(cell)的行为。规则非常简单:
每个 cell 存储一个固定的权重 w(来自矩阵 B),每个 cycle 执行:
- 从左邻居接收激活值
a - 从上邻居接收部分和
p_in(第 0 行p_in = 0) - 计算
p_out = p_in + a × w - 把
a原样传给右邻居 - 把
p_out传给下邻居
两个数据流方向:激活值 a 向右流,部分和 p 向下流。就这么简单——每个 cell 只做一次乘加,然后把数据传走。
p_in
|
v
a --> [ w fixed ] --> a
[p_in+a*w]
|
v
p_out4×4 完整模拟#
输入矩阵#
A (激活值,从左侧流入) B (权重,固定在阵列中)
[ 1 2 1 2 ] [ 1 0 1 0 ]
[ 2 1 2 1 ] [ 0 1 0 1 ]
[ 1 1 1 1 ] [ 1 1 0 0 ]
[ 2 2 2 2 ] [ 0 0 1 1 ]
C = A × B (期望结果):
[ 2 3 3 4 ]
[ 4 3 3 2 ]
[ 2 2 2 2 ]
[ 4 4 4 4 ]权重布局#
矩阵 B 被预先加载到 4×4 阵列中,每个 cell 固定存储一个权重值:
c0 c1 c2 c3
┌──────┬──────┬──────┬──────┐
r0 │ w=1 │ w=0 │ w=1 │ w=0 │
├──────┼──────┼──────┼──────┤
r1 │ w=0 │ w=1 │ w=0 │ w=1 │
├──────┼──────┼──────┼──────┤
r2 │ w=1 │ w=1 │ w=0 │ w=0 │
├──────┼──────┼──────┼──────┤
r3 │ w=0 │ w=0 │ w=1 │ w=1 │
└──────┴──────┴──────┴──────┘斜切输入:A 矩阵怎么喂进去#
A 不是整行一起塞进去的。A 的元素按照对角线交错的方式进入各行——A[m][k] 在 cycle m+k 进入阵列的第 k 行左边界:
| cy0 | cy1 | cy2 | cy3 | cy4 | cy5 | cy6 | cy7 | cy8 | cy9 | |
|---|---|---|---|---|---|---|---|---|---|---|
| r0 | A[0][0]=1 | A[1][0]=2 | A[2][0]=1 | A[3][0]=2 | · | · | · | · | · | · |
| r1 | · | A[0][1]=2 | A[1][1]=1 | A[2][1]=1 | A[3][1]=2 | · | · | · | · | · |
| r2 | · | · | A[0][2]=1 | A[1][2]=2 | A[2][2]=1 | A[3][2]=2 | · | · | · | · |
| r3 | · | · | · | A[0][3]=2 | A[1][3]=1 | A[2][3]=1 | A[3][3]=2 | · | · | · |
注意对角线:A 的第 0 行 [1, 2, 1, 2] 分布在 (cy0, r0) → (cy1, r1) → (cy2, r2) → (cy3, r3) 这条对角线上。它们不是同时进入阵列的,但在阵列内部会完美汇合——这就是脉动阵列的精妙之处。
逐 Cycle 追踪#
下面用方格图展示每个 cycle 的阵列状态。每个格子里 a×w+p=R 的含义:
- a = 从左方来的激活值
- w = 格子里固定的权重
- p = 从上方来的部分和(第 0 行固定为 0)
- R = p + a×w,传给下方
左侧 >v 表示值 v 从左边界进入该行,· 表示该格子本 cycle 无数据。
核心数据流:上一个 cycle 中 cell[r][c] 算出的 R,会成为下一个 cycle 中 cell[r+1][c] 的 p。这就是部分和逐层向下传递的机制。
Cycle 0 — A[0][0]=1 进入 r0#
c0 c1 c2 c3
┌─────────┬─────────┬─────────┬─────────┐
>1 r0│ 1×1+0=1 │ · │ · │ · │
├─────────┼─────────┼─────────┼─────────┤
r1│ · │ · │ · │ · │
├─────────┼─────────┼─────────┼─────────┤
r2│ · │ · │ · │ · │
├─────────┼─────────┼─────────┼─────────┤
r3│ · │ · │ · │ · │
└─────────┴─────────┴─────────┴─────────┘Cycle 1 — A[1][0]=2 进入 r0, A[0][1]=2 进入 r1#
c0 c1 c2 c3
┌─────────┬─────────┬─────────┬─────────┐
>2 r0│ 2×1+0=2 │ 1×0+0=0 │ · │ · │
├─────────┼─────────┼─────────┼─────────┤
>2 r1│ 2×0+1=1 │ · │ · │ · │
├─────────┼─────────┼─────────┼─────────┤
r2│ · │ · │ · │ · │
├─────────┼─────────┼─────────┼─────────┤
r3│ · │ · │ · │ · │
└─────────┴─────────┴─────────┴─────────┘r0c1 的 a=1 是上一拍 r0c0 向右传过来的。r1c0 的 +1 是上一拍 r0c0 的结果 1 向下传下来的。
Cycle 2 — A[2][0]=1 进入 r0, A[1][1]=1 进入 r1, A[0][2]=1 进入 r2#
c0 c1 c2 c3
┌─────────┬─────────┬─────────┬─────────┐
>1 r0│ 1×1+0=1 │ 2×0+0=0 │ 1×1+0=1 │ · │
├─────────┼─────────┼─────────┼─────────┤
>1 r1│ 1×0+2=2 │ 2×1+0=2 │ · │ · │
├─────────┼─────────┼─────────┼─────────┤
>1 r2│ 1×1+1=2 │ · │ · │ · │
├─────────┼─────────┼─────────┼─────────┤
r3│ · │ · │ · │ · │
└─────────┴─────────┴─────────┴─────────┘活跃区域形成三角形,从左上角向右下角扩展——这就是"对角波前"。
Cycle 3 — 阵列满载,第一个结果出炉!#
A[3][0]=2 进入 r0, A[2][1]=1 进入 r1, A[1][2]=2 进入 r2, A[0][3]=2 进入 r3
c0 c1 c2 c3
┌─────────┬─────────┬─────────┬─────────┐
>2 r0│ 2×1+0=2 │ 1×0+0=0 │ 2×1+0=2 │ 1×0+0=0 │
├─────────┼─────────┼─────────┼─────────┤
>1 r1│ 1×0+1=1 │ 1×1+0=1 │ 2×0+1=1 │ · │
├─────────┼─────────┼─────────┼─────────┤
>2 r2│ 2×1+2=4 │ 1×1+2=3 │ · │ · │
├─────────┼─────────┼─────────┼─────────┤
>2 r3│ 2×0+2=2 │ · │ · │ · │
└─────────┴─────────┴─────────┴─────────┘
★C[0][0]=2追踪 C[0][0] 的完整生命线——A 的第 0 行 [1, 2, 1, 2] 的 4 个元素在 4 个 cycle 里分别进入 4 行,部分和逐层向下累加:
| Cycle | Cell | 激活 a | 权重 w | p (来自上方) | 计算 | R ↓ |
|---|---|---|---|---|---|---|
| 0 | [0][0] | A[0][0]=1 | 1 | 0 (顶部) | 1×1+0 | 1 |
| 1 | [1][0] | A[0][1]=2 | 0 | 1 | 2×0+1 | 1 |
| 2 | [2][0] | A[0][2]=1 | 1 | 1 | 1×1+1 | 2 |
| 3 | [3][0] | A[0][3]=2 | 0 | 2 | 2×0+2 | 2 = C[0][0] |
验证:C[0][0] = 1×1 + 2×0 + 1×1 + 2×0 = 2 ✓
Cycle 4 — A[3][1]=2 进入 r1, A[2][2]=1 进入 r2, A[1][3]=1 进入 r3#
c0 c1 c2 c3
┌─────────┬─────────┬─────────┬─────────┐
r0│ · │ 2×0+0=0 │ 1×1+0=1 │ 2×0+0=0 │
├─────────┼─────────┼─────────┼─────────┤
>2 r1│ 2×0+2=2 │ 1×1+0=1 │ 1×0+2=2 │ 2×1+0=2 │
├─────────┼─────────┼─────────┼─────────┤
>1 r2│ 1×1+1=2 │ 2×1+1=3 │ 1×0+1=1 │ · │
├─────────┼─────────┼─────────┼─────────┤
>1 r3│ 1×0+4=4 │ 2×0+3=3 │ · │ · │
└─────────┴─────────┴─────────┴─────────┘
★C[1][0]=4 ★C[0][1]=3两个结果同时从底部不同列流出。流水线开始产出。
Cycle 5 — A[3][2]=2 进入 r2, A[2][3]=1 进入 r3#
c0 c1 c2 c3
┌─────────┬─────────┬─────────┬─────────┐
r0│ · │ · │ 2×1+0=2 │ 1×0+0=0 │
├─────────┼─────────┼─────────┼─────────┤
r1│ · │ 2×1+0=2 │ 1×0+1=1 │ 1×1+0=1 │
├─────────┼─────────┼─────────┼─────────┤
>2 r2│ 2×1+2=4 │ 1×1+1=2 │ 2×0+2=2 │ 1×0+2=2 │
├─────────┼─────────┼─────────┼─────────┤
>1 r3│ 1×0+2=2 │ 1×0+3=3 │ 2×1+1=3 │ · │
└─────────┴─────────┴─────────┴─────────┘
★C[2][0]=2 ★C[1][1]=3 ★C[0][2]=3三个结果沿底部对角线同时涌出。
Cycle 6 — A[3][3]=2 进入 r3(最后一个元素), 四个结果同时流出!#
c0 c1 c2 c3
┌─────────┬─────────┬─────────┬─────────┐
r0│ · │ · │ · │ 2×0+0=0 │
├─────────┼─────────┼─────────┼─────────┤
r1│ · │ · │ 2×0+2=2 │ 1×1+0=1 │
├─────────┼─────────┼─────────┼─────────┤
r2│ · │ 2×1+2=4 │ 1×0+1=1 │ 2×0+1=1 │
├─────────┼─────────┼─────────┼─────────┤
>2 r3│ 2×0+4=4 │ 1×0+2=2 │ 1×1+2=3 │ 2×1+2=4 │
└─────────┴─────────┴─────────┴─────────┘
★C[3][0]=4 ★C[2][1]=2 ★C[1][2]=3 ★C[0][3]=4吞吐量最高的一拍——4 个结果同时从底部 4 列流出。
Cycle 7 — 流水线开始排空#
c0 c1 c2 c3
┌─────────┬─────────┬─────────┬─────────┐
r0│ · │ · │ · │ · │
├─────────┼─────────┼─────────┼─────────┤
r1│ · │ · │ · │ 2×1+0=2 │
├─────────┼─────────┼─────────┼─────────┤
r2│ · │ · │ 2×0+2=2 │ 1×0+1=1 │
├─────────┼─────────┼─────────┼─────────┤
r3│ · │ 2×0+4=4 │ 1×1+1=2 │ 1×1+1=2 │
└─────────┴─────────┴─────────┴─────────┘
★C[3][1]=4 ★C[2][2]=2 ★C[1][3]=2Cycle 8#
c0 c1 c2 c3
┌─────────┬─────────┬─────────┬─────────┐
r0│ · │ · │ · │ · │
├─────────┼─────────┼─────────┼─────────┤
r1│ · │ · │ · │ · │
├─────────┼─────────┼─────────┼─────────┤
r2│ · │ · │ · │ 2×0+2=2 │
├─────────┼─────────┼─────────┼─────────┤
r3│ · │ · │ 2×1+2=4 │ 1×1+1=2 │
└─────────┴─────────┴─────────┴─────────┘
★C[3][2]=4 ★C[2][3]=2Cycle 9 — 最后一个结果#
c0 c1 c2 c3
┌─────────┬─────────┬─────────┬─────────┐
r0│ · │ · │ · │ · │
├─────────┼─────────┼─────────┼─────────┤
r1│ · │ · │ · │ · │
├─────────┼─────────┼─────────┼─────────┤
r2│ · │ · │ · │ · │
├─────────┼─────────┼─────────┼─────────┤
r3│ · │ · │ · │ 2×1+2=4 │
└─────────┴─────────┴─────────┴─────────┘
★C[3][3]=4输出时序总表#
C[m][n] 在 cycle (m + n + 3) 从 column n 底部流出:
c0 c1 c2 c3
C[0][·] cy3→2 cy4→3 cy5→3 cy6→4
C[1][·] cy4→4 cy5→3 cy6→3 cy7→2
C[2][·] cy5→2 cy6→2 cy7→2 cy8→2
C[3][·] cy6→4 cy7→4 cy8→4 cy9→4结果沿反对角线从左上到右下依次流出,每条对角线上的结果在同一个 cycle 同时产出。
流水线三阶段#
回顾完整的 10 个 cycle(cycle 0–9),可以清晰地看到三个阶段:
Cycle 0–2: 填充期 数据还没流到底部,没有结果输出
Cycle 3–6: 满载期 每个 cycle 都有结果从底部流出
Cycle 6: 峰值 4 个结果同时输出(最高吞吐)
Cycle 7–9: 排空期 右下角的结果最后流出总共 3N - 2 = 10 个 cycle 完成 4×4 矩阵乘法。N×N 阵列做 N×N matmul,延迟是 3N-2,但如果连续喂入多组数据(pipeline),稳态吞吐量是 每 cycle 产出 N 个结果。
大矩阵怎么办:分块计算#
实际的 MXU 是 128×128 或 256×256,但矩阵通常远大于此。以 [2048×1024] × [1024×4096] 为例:
TILE = 128 # MXU 尺寸
C = zeros(2048, 4096)
# 循环顺序:n → k → m(权重复用最优)
for n in range(0, 4096, TILE): # 32 块,遍历输出列方向
for k in range(0, 1024, TILE): # 8 块,遍历收缩维
# Step 1: 加载权重 tile(每个 (n,k) 组合只加载一次)
B_tile = B[k:k+128, n:n+128]
mxu.load_weights(B_tile) # vmatpush,~128 cycles
for m in range(0, 2048, TILE): # 16 块,多组激活流过同一组权重
# Step 2: 激活值流过 MXU
A_tile = A[m:m+128, k:k+128]
partial = mxu.stream_through(A_tile)
# Step 3: 累加到对应的 C tile
C[m:m+128, n:n+128] += partial循环顺序选择 n → k → m 而不是 m → n → k,是因为 MXU 是 weight-stationary(权重固定)架构——每次 vmatpush 加载权重需要 ~128 cycles,代价很高。把 m 放在最内层,同一组权重被 M/TILE = 16 组 A tile 连续复用,权重加载的开销被摊薄 16 倍。如果 k 在最内层,每次 k 迭代都要换权重,同一组权重只被用 1 次。
每次权重加载还会带来流水线空拍(填充 + 排空各 ~127 cycles)。连续流过的 A tile 越多,空拍占比越低,MXU 利用率越高。
每次 MXU 调用就是我们上面模拟的过程——只不过是 128×128 而不是 4×4。
在 上一篇文章 的 LLO 分析中,我们看到的 vmatpush → vmatmul → vpop 指令序列,就是这个过程在真实硬件上的体现:
%v63 = vld [vmem:[%s55] sm:$0xff] ← 从 VMEM 加载权重 tile
%64 = vmatpush.bf16.msra.mxu0 %v63 ← 把权重推入脉动阵列
...
%v93 = vld [vmem:[#allocation0] sm:$0xff] ← 加载激活值 tile
%94 = vmatmul.bf16.gmra.mxu0 %v93 ← 激活值流过阵列
%v95 = vpop.f32.mrf.mxu0 ← 从底部取出结果两个 MXU 怎么协作#
v6e 和 v7x 的每个 TensorCore 有两个 MXU,它们可以沿不同维度拆分工作:
方式一:沿 K 维切分(内积拆半)
A = [A_left | A_right] B = [B_top ]
[B_bot ]
MXU 0: A_left × B_top = C_partial_0
MXU 1: A_right × B_bot = C_partial_1
C = C_partial_0 + C_partial_1 ← 两个部分和相加方式二:沿 M 或 N 维切分(输出拆半)
MXU 0: A_top × B = C_top ← 算输出的上半部分
MXU 1: A_bot × B = C_bot ← 算输出的下半部分
C = [C_top] ← 直接拼接,无需相加
[C_bot]两个 MXU 就像两条并行的生产线,吞吐量直接翻倍。