跳过正文
  1. Posts/

脉动阵列:逐 Cycle 拆解 TPU MXU 的矩阵乘法

·7 分钟
目录

TL;DR
#

MXU(Matrix Multiply Unit)是 TPU 做矩阵乘法的核心硬件。它的底层是一个 脉动阵列(systolic array)——权重预先加载到阵列中固定不动,激活值从左侧逐行流入,部分和从上往下逐层累加,流到底部时内积就算完了。

脉动阵列的精妙在于它的"斜切输入"设计——A 矩阵的同一行元素沿对角线错位进入阵列的不同行,而部分和从上往下流动的速度恰好和这个错位完美匹配。不需要任何复杂的调度逻辑,数据自然地在正确的时刻汇合到正确的 cell——这就是"脉动"(systolic)的含义,像心跳一样,数据以固定的节律在阵列中脉动前进。

本文通过模拟一个 4×4 的 MXU,逐 cycle 展示每个计算单元在做什么、数据从哪来到哪去,让你直观理解脉动阵列为什么能高效地做矩阵乘法。

MXU 在 TPU 中的位置
#

每代 TPU 芯片的 MXU 配置不同:

芯片TensorCore 数每个 TC 的 MXU 数MXU 尺寸
v5e 及之前11128×128
v6e12256×256
v7x22256×256

MXU 由脉动阵列中的乘法累加器(MAC)组成(参考)。256×256 的 MXU 意味着每个时钟周期可以完成 256×256×2 = 131072 次浮点运算——这就是 TPU 矩阵乘法吞吐量的来源。

那么问题来了:MXU 为什么能这么快?每一拍里它的每个计算单元在做什么?数据是怎么流过去的?

脉动阵列的核心规则
#

在开始逐 cycle 追踪之前,先理解每个计算单元(cell)的行为。规则非常简单:

每个 cell 存储一个固定的权重 w(来自矩阵 B),每个 cycle 执行:

  1. 从左邻居接收激活值 a
  2. 从上邻居接收部分和 p_in(第 0 行 p_in = 0
  3. 计算 p_out = p_in + a × w
  4. a 原样传给右邻居
  5. p_out 传给下邻居

两个数据流方向:激活值 a 向右流,部分和 p 向下流。就这么简单——每个 cell 只做一次乘加,然后把数据传走。

          p_in
           |
           v
  a --> [ w fixed ] --> a
        [p_in+a*w]
           |
           v
         p_out

4×4 完整模拟
#

输入矩阵
#

A (激活值,从左侧流入)        B (权重,固定在阵列中)

[ 1  2  1  2 ]                [ 1  0  1  0 ]
[ 2  1  2  1 ]                [ 0  1  0  1 ]
[ 1  1  1  1 ]                [ 1  1  0  0 ]
[ 2  2  2  2 ]                [ 0  0  1  1 ]

C = A × B (期望结果):
[ 2  3  3  4 ]
[ 4  3  3  2 ]
[ 2  2  2  2 ]
[ 4  4  4  4 ]

权重布局
#

矩阵 B 被预先加载到 4×4 阵列中,每个 cell 固定存储一个权重值:

         c0     c1     c2     c3
       ┌──────┬──────┬──────┬──────┐
  r0   │ w=1  │ w=0  │ w=1  │ w=0  │
       ├──────┼──────┼──────┼──────┤
  r1   │ w=0  │ w=1  │ w=0  │ w=1  │
       ├──────┼──────┼──────┼──────┤
  r2   │ w=1  │ w=1  │ w=0  │ w=0  │
       ├──────┼──────┼──────┼──────┤
  r3   │ w=0  │ w=0  │ w=1  │ w=1  │
       └──────┴──────┴──────┴──────┘

斜切输入:A 矩阵怎么喂进去
#

A 不是整行一起塞进去的。A 的元素按照对角线交错的方式进入各行——A[m][k] 在 cycle m+k 进入阵列的第 k 行左边界:

cy0cy1cy2cy3cy4cy5cy6cy7cy8cy9
r0A[0][0]=1A[1][0]=2A[2][0]=1A[3][0]=2······
r1·A[0][1]=2A[1][1]=1A[2][1]=1A[3][1]=2·····
r2··A[0][2]=1A[1][2]=2A[2][2]=1A[3][2]=2····
r3···A[0][3]=2A[1][3]=1A[2][3]=1A[3][3]=2···

注意对角线:A 的第 0 行 [1, 2, 1, 2] 分布在 (cy0, r0)(cy1, r1)(cy2, r2)(cy3, r3) 这条对角线上。它们不是同时进入阵列的,但在阵列内部会完美汇合——这就是脉动阵列的精妙之处。

逐 Cycle 追踪
#

下面用方格图展示每个 cycle 的阵列状态。每个格子里 a×w+p=R 的含义:

  • a = 从左方来的激活值
  • w = 格子里固定的权重
  • p = 从上方来的部分和(第 0 行固定为 0)
  • R = p + a×w,传给下方

左侧 >v 表示值 v 从左边界进入该行,· 表示该格子本 cycle 无数据。

核心数据流:上一个 cycle 中 cell[r][c] 算出的 R,会成为下一个 cycle 中 cell[r+1][c] 的 p。这就是部分和逐层向下传递的机制。


Cycle 0 — A[0][0]=1 进入 r0
#

         c0        c1        c2        c3
     ┌─────────┬─────────┬─────────┬─────────┐
>1 r0│ 1×1+0=1 │    ·    │    ·    │    ·    │
     ├─────────┼─────────┼─────────┼─────────┤
   r1│    ·    │    ·    │    ·    │    ·    │
     ├─────────┼─────────┼─────────┼─────────┤
   r2│    ·    │    ·    │    ·    │    ·    │
     ├─────────┼─────────┼─────────┼─────────┤
   r3│    ·    │    ·    │    ·    │    ·    │
     └─────────┴─────────┴─────────┴─────────┘

Cycle 1 — A[1][0]=2 进入 r0, A[0][1]=2 进入 r1
#

         c0        c1        c2        c3
     ┌─────────┬─────────┬─────────┬─────────┐
>2 r0│ 2×1+0=2 │ 1×0+0=0 │    ·    │    ·    │
     ├─────────┼─────────┼─────────┼─────────┤
>2 r1│ 2×0+1=1 │    ·    │    ·    │    ·    │
     ├─────────┼─────────┼─────────┼─────────┤
   r2│    ·    │    ·    │    ·    │    ·    │
     ├─────────┼─────────┼─────────┼─────────┤
   r3│    ·    │    ·    │    ·    │    ·    │
     └─────────┴─────────┴─────────┴─────────┘

r0c1 的 a=1 是上一拍 r0c0 向右传过来的。r1c0 的 +1 是上一拍 r0c0 的结果 1 向下传下来的。


Cycle 2 — A[2][0]=1 进入 r0, A[1][1]=1 进入 r1, A[0][2]=1 进入 r2
#

         c0        c1        c2        c3
     ┌─────────┬─────────┬─────────┬─────────┐
>1 r0│ 1×1+0=1 │ 2×0+0=0 │ 1×1+0=1 │    ·    │
     ├─────────┼─────────┼─────────┼─────────┤
>1 r1│ 1×0+2=2 │ 2×1+0=2 │    ·    │    ·    │
     ├─────────┼─────────┼─────────┼─────────┤
>1 r2│ 1×1+1=2 │    ·    │    ·    │    ·    │
     ├─────────┼─────────┼─────────┼─────────┤
   r3│    ·    │    ·    │    ·    │    ·    │
     └─────────┴─────────┴─────────┴─────────┘

活跃区域形成三角形,从左上角向右下角扩展——这就是"对角波前"。


Cycle 3 — 阵列满载,第一个结果出炉!
#

A[3][0]=2 进入 r0, A[2][1]=1 进入 r1, A[1][2]=2 进入 r2, A[0][3]=2 进入 r3

         c0        c1        c2        c3
     ┌─────────┬─────────┬─────────┬─────────┐
>2 r0│ 2×1+0=2 │ 1×0+0=0 │ 2×1+0=2 │ 1×0+0=0 │
     ├─────────┼─────────┼─────────┼─────────┤
>1 r1│ 1×0+1=1 │ 1×1+0=1 │ 2×0+1=1 │    ·    │
     ├─────────┼─────────┼─────────┼─────────┤
>2 r2│ 2×1+2=4 │ 1×1+2=3 │    ·    │    ·    │
     ├─────────┼─────────┼─────────┼─────────┤
>2 r3│ 2×0+2=2 │    ·    │    ·    │    ·    │
     └─────────┴─────────┴─────────┴─────────┘
      ★C[0][0]=2

追踪 C[0][0] 的完整生命线——A 的第 0 行 [1, 2, 1, 2] 的 4 个元素在 4 个 cycle 里分别进入 4 行,部分和逐层向下累加:

CycleCell激活 a权重 wp (来自上方)计算R ↓
0[0][0]A[0][0]=110 (顶部)1×1+01
1[1][0]A[0][1]=2012×0+11
2[2][0]A[0][2]=1111×1+12
3[3][0]A[0][3]=2022×0+22 = C[0][0]

验证:C[0][0] = 1×1 + 2×0 + 1×1 + 2×0 = 2 ✓


Cycle 4 — A[3][1]=2 进入 r1, A[2][2]=1 进入 r2, A[1][3]=1 进入 r3
#

         c0        c1        c2        c3
     ┌─────────┬─────────┬─────────┬─────────┐
   r0│    ·    │ 2×0+0=0 │ 1×1+0=1 │ 2×0+0=0 │
     ├─────────┼─────────┼─────────┼─────────┤
>2 r1│ 2×0+2=2 │ 1×1+0=1 │ 1×0+2=2 │ 2×1+0=2 │
     ├─────────┼─────────┼─────────┼─────────┤
>1 r2│ 1×1+1=2 │ 2×1+1=3 │ 1×0+1=1 │    ·    │
     ├─────────┼─────────┼─────────┼─────────┤
>1 r3│ 1×0+4=4 │ 2×0+3=3 │    ·    │    ·    │
     └─────────┴─────────┴─────────┴─────────┘
      ★C[1][0]=4 ★C[0][1]=3

两个结果同时从底部不同列流出。流水线开始产出。


Cycle 5 — A[3][2]=2 进入 r2, A[2][3]=1 进入 r3
#

         c0        c1        c2        c3
     ┌─────────┬─────────┬─────────┬─────────┐
   r0│    ·    │    ·    │ 2×1+0=2 │ 1×0+0=0 │
     ├─────────┼─────────┼─────────┼─────────┤
   r1│    ·    │ 2×1+0=2 │ 1×0+1=1 │ 1×1+0=1 │
     ├─────────┼─────────┼─────────┼─────────┤
>2 r2│ 2×1+2=4 │ 1×1+1=2 │ 2×0+2=2 │ 1×0+2=2 │
     ├─────────┼─────────┼─────────┼─────────┤
>1 r3│ 1×0+2=2 │ 1×0+3=3 │ 2×1+1=3 │    ·    │
     └─────────┴─────────┴─────────┴─────────┘
      ★C[2][0]=2 ★C[1][1]=3 ★C[0][2]=3

三个结果沿底部对角线同时涌出。


Cycle 6 — A[3][3]=2 进入 r3(最后一个元素), 四个结果同时流出!
#

         c0        c1        c2        c3
     ┌─────────┬─────────┬─────────┬─────────┐
   r0│    ·    │    ·    │    ·    │ 2×0+0=0 │
     ├─────────┼─────────┼─────────┼─────────┤
   r1│    ·    │    ·    │ 2×0+2=2 │ 1×1+0=1 │
     ├─────────┼─────────┼─────────┼─────────┤
   r2│    ·    │ 2×1+2=4 │ 1×0+1=1 │ 2×0+1=1 │
     ├─────────┼─────────┼─────────┼─────────┤
>2 r3│ 2×0+4=4 │ 1×0+2=2 │ 1×1+2=3 │ 2×1+2=4 │
     └─────────┴─────────┴─────────┴─────────┘
      ★C[3][0]=4 ★C[2][1]=2 ★C[1][2]=3 ★C[0][3]=4

吞吐量最高的一拍——4 个结果同时从底部 4 列流出。


Cycle 7 — 流水线开始排空
#

         c0        c1        c2        c3
     ┌─────────┬─────────┬─────────┬─────────┐
   r0│    ·    │    ·    │    ·    │    ·    │
     ├─────────┼─────────┼─────────┼─────────┤
   r1│    ·    │    ·    │    ·    │ 2×1+0=2 │
     ├─────────┼─────────┼─────────┼─────────┤
   r2│    ·    │    ·    │ 2×0+2=2 │ 1×0+1=1 │
     ├─────────┼─────────┼─────────┼─────────┤
   r3│    ·    │ 2×0+4=4 │ 1×1+1=2 │ 1×1+1=2 │
     └─────────┴─────────┴─────────┴─────────┘
               ★C[3][1]=4 ★C[2][2]=2 ★C[1][3]=2

Cycle 8
#

         c0        c1        c2        c3
     ┌─────────┬─────────┬─────────┬─────────┐
   r0│    ·    │    ·    │    ·    │    ·    │
     ├─────────┼─────────┼─────────┼─────────┤
   r1│    ·    │    ·    │    ·    │    ·    │
     ├─────────┼─────────┼─────────┼─────────┤
   r2│    ·    │    ·    │    ·    │ 2×0+2=2 │
     ├─────────┼─────────┼─────────┼─────────┤
   r3│    ·    │    ·    │ 2×1+2=4 │ 1×1+1=2 │
     └─────────┴─────────┴─────────┴─────────┘
                          ★C[3][2]=4 ★C[2][3]=2

Cycle 9 — 最后一个结果
#

         c0        c1        c2        c3
     ┌─────────┬─────────┬─────────┬─────────┐
   r0│    ·    │    ·    │    ·    │    ·    │
     ├─────────┼─────────┼─────────┼─────────┤
   r1│    ·    │    ·    │    ·    │    ·    │
     ├─────────┼─────────┼─────────┼─────────┤
   r2│    ·    │    ·    │    ·    │    ·    │
     ├─────────┼─────────┼─────────┼─────────┤
   r3│    ·    │    ·    │    ·    │ 2×1+2=4 │
     └─────────┴─────────┴─────────┴─────────┘
                                    ★C[3][3]=4

输出时序总表
#

C[m][n] 在 cycle (m + n + 3) 从 column n 底部流出:

          c0       c1       c2       c3
C[0][·]  cy3→2   cy4→3   cy5→3   cy6→4
C[1][·]  cy4→4   cy5→3   cy6→3   cy7→2
C[2][·]  cy5→2   cy6→2   cy7→2   cy8→2
C[3][·]  cy6→4   cy7→4   cy8→4   cy9→4

结果沿反对角线从左上到右下依次流出,每条对角线上的结果在同一个 cycle 同时产出。

流水线三阶段
#

回顾完整的 10 个 cycle(cycle 0–9),可以清晰地看到三个阶段:

Cycle 0–2:  填充期    数据还没流到底部,没有结果输出
Cycle 3–6:  满载期    每个 cycle 都有结果从底部流出
Cycle 6:    峰值     4 个结果同时输出(最高吞吐)
Cycle 7–9:  排空期    右下角的结果最后流出

总共 3N - 2 = 10 个 cycle 完成 4×4 矩阵乘法。N×N 阵列做 N×N matmul,延迟是 3N-2,但如果连续喂入多组数据(pipeline),稳态吞吐量是 每 cycle 产出 N 个结果

大矩阵怎么办:分块计算
#

实际的 MXU 是 128×128 或 256×256,但矩阵通常远大于此。以 [2048×1024] × [1024×4096] 为例:

TILE = 128  # MXU 尺寸
C = zeros(2048, 4096)

# 循环顺序:n → k → m(权重复用最优)
for n in range(0, 4096, TILE):        # 32 块,遍历输出列方向
    for k in range(0, 1024, TILE):    # 8 块,遍历收缩维

        # Step 1: 加载权重 tile(每个 (n,k) 组合只加载一次)
        B_tile = B[k:k+128, n:n+128]
        mxu.load_weights(B_tile)          # vmatpush,~128 cycles

        for m in range(0, 2048, TILE):    # 16 块,多组激活流过同一组权重
            # Step 2: 激活值流过 MXU
            A_tile = A[m:m+128, k:k+128]
            partial = mxu.stream_through(A_tile)

            # Step 3: 累加到对应的 C tile
            C[m:m+128, n:n+128] += partial

循环顺序选择 n → k → m 而不是 m → n → k,是因为 MXU 是 weight-stationary(权重固定)架构——每次 vmatpush 加载权重需要 ~128 cycles,代价很高。把 m 放在最内层,同一组权重被 M/TILE = 16 组 A tile 连续复用,权重加载的开销被摊薄 16 倍。如果 k 在最内层,每次 k 迭代都要换权重,同一组权重只被用 1 次。

每次权重加载还会带来流水线空拍(填充 + 排空各 ~127 cycles)。连续流过的 A tile 越多,空拍占比越低,MXU 利用率越高。

每次 MXU 调用就是我们上面模拟的过程——只不过是 128×128 而不是 4×4。

上一篇文章 的 LLO 分析中,我们看到的 vmatpushvmatmulvpop 指令序列,就是这个过程在真实硬件上的体现:

%v63 = vld [vmem:[%s55] sm:$0xff]        ← 从 VMEM 加载权重 tile
%64 = vmatpush.bf16.msra.mxu0 %v63       ← 把权重推入脉动阵列
...
%v93 = vld [vmem:[#allocation0] sm:$0xff] ← 加载激活值 tile
%94 = vmatmul.bf16.gmra.mxu0 %v93        ← 激活值流过阵列
%v95 = vpop.f32.mrf.mxu0                 ← 从底部取出结果

两个 MXU 怎么协作
#

v6e 和 v7x 的每个 TensorCore 有两个 MXU,它们可以沿不同维度拆分工作:

方式一:沿 K 维切分(内积拆半)

A = [A_left | A_right]    B = [B_top ]
                              [B_bot ]

MXU 0:  A_left  × B_top  = C_partial_0
MXU 1:  A_right × B_bot  = C_partial_1

C = C_partial_0 + C_partial_1   ← 两个部分和相加

方式二:沿 M 或 N 维切分(输出拆半)

MXU 0:  A_top × B = C_top      ← 算输出的上半部分
MXU 1:  A_bot × B = C_bot      ← 算输出的下半部分

C = [C_top]                     ← 直接拼接,无需相加
    [C_bot]

两个 MXU 就像两条并行的生产线,吞吐量直接翻倍。

相关文章