[{"content":" 关于我 # 浙江大学（本科）/ 清华大学（硕士），电子与通信工程。先后在美团和 MiniMax 负责 AI Infra，专注大规模模型训练与推理基础设施。\n技术方向 # 大规模分布式训练（Megatron-LM、MaxText） 推理优化与部署（SGLang、vLLM） GPU/TPU Kernel 开发（CUDA、Triton、Pallas） 训练基础设施与 MLOps ","date":"2026年4月4日","externalUrl":null,"permalink":"/","section":"","summary":"","title":"","type":"page"},{"content":"","date":"2026年4月4日","externalUrl":null,"permalink":"/categories/","section":"Categories","summary":"","title":"Categories","type":"categories"},{"content":"","date":"2026年4月4日","externalUrl":null,"permalink":"/tags/fp8/","section":"Tags","summary":"","title":"FP8","type":"tags"},{"content":"基于 Pallas GMM FP8 blockwise 量化内核开发的实战问答。\nQ1: 优化目标与初始挑战 # 问： GMM FP8 内核优化的出发点是什么？最初 FP8 相比 BF16 性能表现如何？瓶颈在哪里？\n答： 目标是以 blockwise 方式实现 FP8 GMM，对齐 Megatron 的量化方案。初始 FP8 kernel 比 BF16 更慢，根本原因有两个：\nScale 校正开销 — FP8 blockwise 量化需要在 matmul 前后做 scale 乘除（反量化），这些 VPU 运算未能与 MXU 计算有效 overlap Baseline 强制 tk == block_size == 128 — 量化 block 大小是 128，baseline 直接把 K-tile 也限死为 128，导致 MXU 每次计算量太小 因果链：tk 被限死为 128 → 每次 MXU 计算量太小 → MXU 空闲等待 → 吞吐量低\nQ2: Blockwise 量化粒度 # 问： block_size=128 具体是什么含义？LHS 和 RHS 的量化 block 形状有什么区别？\n答：\nLHS（激活） RHS（权重） Block 形状 1 × 128 128 × 128 Scale 粒度 每行每 128 元素一个 scale 每 128×128 块一个 scale 设计原因 激活是动态的，逐行量化适应每个 token 的数值范围 权重是静态的，大 block 减少 scale 存储开销，可预计算 这种非对称设计与 DeepSeek 的实现对齐，是 FP8 训练的标准做法。\nQ3: 子通道循环（Subchannel Loop） # 问： Baseline 的 tk == block_size 约束是怎么被突破的？子通道循环的核心思路是什么？\n答： Baseline 的 tk == block_size 是过于保守的设计选择（每个 tile 恰好对应一个量化 block），子通道循环将 tiling 与量化解耦：\nBaseline: tk == block_size == 128（一个 tile = 一个量化 block） Subchannel: tk = block_size × num_subchannels（一个 tile 包含多个量化 block） 当 tk=2048, block_size=128 时，一个 K-tile 有 16 个 subchannel，内层循环独立处理每个 subchannel：取 scale → 反量化 → matmul → 累加。MXU 一次处理更大的 K 维度，利用率提升。\n延伸：编译器如何利用更大的 tk？ # Pallas kernel 由 Mosaic 编译器（不是 XLA）编译为 LLO 格式的 VLIW 指令：\nPallas kernel → JAX trace → Jaxpr → Mosaic 编译器 → LLO (VLIW) → TPU 执行 更大的 tk 意味着更多的 subchannel 迭代，编译器获得更大的指令窗口，可以：\n循环展开：看到更多指令，发现并行机会 指令打包：在 VLIW bundle 中将 VPU（scale 运算）和 MXU（matmul）并行执行 DMA 预取：提前加载下一个 subchannel 数据，与当前计算 overlap 注意：Pallas kernel 对 XLA 是不透明的 custom_call，XLA 不优化 kernel 内部。Mosaic 才是编译 kernel body 的编译器。\n延伸：VLIW 越短 = 性能越好？ # 不一定。 VLIW bundle 数量不是可靠的性能指标。反例：将 tm 从 256 增至 512，VLIW 指令数增加，但性能更好——因为每条 MXU 指令处理的数据量翻倍。关键不是 bundle 数量，而是每个 bundle 内各 slot（MXU/VPU/DMA/SFU）的并行利用率。\nQ4: 分阶段 Tiling # 问： 为什么 fwd、bwd、tgmm 需要不同的 tiling 配置？\n答： 本质原因：归约维度不同。\nfwd: C[M,N] += A[M,K] × B[K,N] → 归约维度是 K tgmm: dW[K,N] += X^T[K,M] × dY[M,N] → 归约维度是 M 同一个物理维度 M，在 fwd 中是输出维度（并行铺开），在 tgmm 中变成归约维度（逐步加载、累加、转置）。归约维度决定了每步迭代的 VMEM 占用模式。\ntgmm 中 tm=256 为什么导致寄存器压力？ # fwd（归约步长 tk=128）：每步加载 A[tm, 128] + B[128, tn]，tm=256 只影响累加器大小，每步搬运量由 tk 决定，不大。\ntgmm（归约步长 tm）：如果 tm=256，每步需要同时持有：\n数据 大小 (tm=256, tk=128) X 原始 256 × 128 = 32K elements X 转置副本 128 × 256 = 32K elements dY 256 × tn 累加器 128 × tn TPU VMEM 内转置需要同时保留原始和转置副本。类比：fwd 像\u0026quot;每次拿 128 张牌算，结果放 256 个格子\u0026quot;；tgmm 用 tm=256 像\u0026quot;每次抓 256 张牌还要翻转一摞\u0026quot;——手上东西太多。所以 tgmm 需要降低 tm。\n生产最优 tiling 验证了这一点：\nwi (fwd): (256, 2048, 1024), sk=16 wo (fwd): (256, 512, 2048), sk=4 tgmm: 更小的 tm 以降低 VMEM 压力 Q5: 零反向量化 # 问： 为什么反向传播可以跳过 FP8 量化用 bf16 直接算，但前向不行？\n答： 核心是量化开销的收益比不同：\n前向 反向 权重量化 离线预计算，开销被摊销 激活和梯度都需要在线量化，每步都算 scale 量化代价 scale 运算少，可被 overlap 两个输入都要在线量化，VPU 开销翻倍 精度敏感度 前向对量化噪声较鲁棒 梯度本身有噪声，量化误差累积影响训练稳定性 反向跳过量化的本质：在线量化开销大、精度风险高，收益不足以覆盖代价。 实验验证去掉反向量化后精度提升且性能不降反升。\nQ6: 生产形状推导 # 问： 生产环境下 GMM 的 M、K、N、G 分别是多少？怎么推导的？Dev shape 和生产 shape 是什么关系？\n答：\nwi 层（gate/up projection）：\nM = per_device_batch × seq_len × top_k = 2 × 4096 × 8 = 65536 K = EMB_DIM = 2048 N = MOE_MLP_DIM = 512 G = NUM_EXPERTS = 256 wo 层（down projection）： M=65536, K=512, N=2048, G=256（K 和 N 互换）\nDev shape 的来源：\n生产 shape: M=65536, G=256 EP=8 分片后每设备: M=65536/8=8192, G=256/8=32 Dev shape: M=8192, G=32 ← 直接取 EP=8 的单分片 Dev shape 是故意设为 EP=8 分片后的形状，确保开发阶段优化的 tiling 可以直接用于生产。\nQ7: HLO 分析与性能归因 # 问： FP8 多了 scale 运算，为什么 HLO 算子数和 BF16 一样？FP8 超越 BF16 的性能增益来自哪里？\n答：\n为什么 HLO 算子数一样？ # Scale 运算在 Pallas kernel 内部。HLO 层面看到的只是 custom_call 节点——不管 kernel 里做了什么，从 HLO 角度就是一个算子。差异全部被封装在 kernel 内部，HLO 看不到。\nFP8 性能增益来源 # 两个因素共同贡献：\nMXU 硬件加速 — TPU v7x 的 MXU 对 FP8×FP8 有约 1.81x 吞吐加速（微基准实测：1662 TFLOPS vs BF16 的 918 TFLOPS），仅当两个操作数都是 FP8 时生效 HBM 带宽减半 — FP8 数据量是 BF16 的一半，搬运更快 纯 matmul MXU 加速是 1.81x，但 GMM 实际只有 1.07x-1.24x，差距来自 GMM 的额外开销（scale 运算、group 索引、量化/反量化）。\n注意：FP8 不减少 FLOPs。FLOPs 是运算次数，由矩阵维度决定（2×M×K×N），与精度无关。FP8 减少的是完成这些 FLOPs 所需的内存带宽和计算时间。\nQ8: 进化优化策略 # 问： 为什么用进化算法搜索 tiling？增长趋势说明了什么？\n答：\n为什么用进化算法？ # 搜索空间太大，网格搜索不可行。 tiling 参数包括 3 个阶段 × 每阶段 3 个维度 + subchannel 数，组合是天文数字：\n网格搜索: 10^9 = 10亿种组合（假设每维度 10 个候选）→ 不可行 进化搜索: 每轮 ~5 个变体 × 8 轮 = ~40 次评估 → 达到 2.85x 进化算法通过 mutation + selection 高效收敛：基于当前最优解生成有方向的变体（不是随机改数字），保留最快的，淘汰慢的。LLM 驱动进一步提升变异质量——能理解\u0026quot;为什么这个配置快\u0026quot;。\n注意：进化算法不是剪枝搜索。剪枝是排除法（系统地砍掉子空间），进化是生成法（基于好解生成新解），本质是有方向的随机采样，采样分布逐渐收敛到最优区域。\n增长趋势分析 # Round 1-2: 1.015x → 2.294x （跳跃式增长 → baseline 离最优很远） Round 3-5: 2.335x → 2.651x （稳步增长 → 逐步逼近） Round 6-8: 2.769x → 2.847x （收敛放缓 → 边际递减） 纯 tiling 调整就能 2.85x，说明原始 kernel 的瓶颈不在算法逻辑，而在调度和数据搬运。\nQ9: 精度对齐方法论 # 问： 如何系统性地验证 FP8 kernel 的计算精度？\n答：\n9.1 对齐目标与比较对象 # 精度验证分两层：\n对比 目的 预期差异 FP8 kernel vs FP8 CPU reference 验证 kernel 实现正确性（tiling、subchannel 没引入 bug） 应极小，仅浮点累加顺序差异 FP8 kernel vs BF16 kernel 量化本身带来的精度损失 可接受范围内即可 第一层是正确性验证，第二层是量化影响评估。两者不能混为一谈——如果直接拿 FP8 kernel 对比 BF16，分不清误差来自 kernel 实现还是量化本身。\n9.2 精度指标体系 # 绝对/相对误差（atol / rtol） # 最基础的指标，逐元素比较：\n# allclose 判定：|actual - expected| \u0026lt;= atol + rtol * |expected| np.testing.assert_allclose(kernel_out, ref_out, atol=1e-3, rtol=1e-2) 选取依据：\nFP8 kernel vs FP8 CPU ref：atol=1e-3, rtol=1e-2（严格，只允许累加顺序差异） FP8 kernel vs BF16 kernel：atol=1e-1, rtol=5e-2（宽松，量化本身有精度损失） 局限性：atol/rtol 是逐元素的，无法反映整体分布的偏移趋势。\n信噪比（SNR / SQNR） # 衡量信号与量化噪声的比值，反映整体精度质量：\nSQNR = 10 × log10(||signal||² / ||noise||²) = 10 × log10(||ref||² / ||ref - kernel||²) SQNR \u0026gt; 30 dB：精度良好，量化噪声远小于信号 SQNR 20-30 dB：可接受，需关注是否影响训练收敛 SQNR \u0026lt; 20 dB：精度有风险，需排查 SNR 的优势在于它是一个全局标量指标，不受个别 outlier 元素影响，能直观反映\u0026quot;量化后的结果还有多少有效信息\u0026quot;。\nULP（Unit in the Last Place）误差 # 衡量浮点表示层面的精度距离：\nULP 误差 = |actual - expected| / ULP(expected) 其中 ULP(x) 是 x 所在浮点区间的最小精度单位。ULP 误差 = 1 意味着结果与参考值相差恰好 1 个最小精度位。\nULP 的意义：\n与数值大小无关 — 不像 atol 对大值宽松、对小值严格；ULP 在整个数值范围内给出均匀的精度评估 直接对应硬件精度极限 — FP8 E4M3 的 ULP 比 BF16 大得多，ULP 误差能直接反映\u0026quot;在该精度下还能做到多好\u0026quot; 典型阈值：kernel vs CPU ref 应在 1-2 ULP 以内（仅累加顺序差异），kernel vs BF16 可允许更大的 ULP 差异 指标的配合使用 # atol/rtol → 逐元素守门，快速发现 bug（某个元素算错） SNR/SQNR → 全局质量评估，判断整体可用性 ULP → 精度归因，区分\u0026#34;算法误差\u0026#34;vs\u0026#34;精度极限\u0026#34; 三者结合才能完整评估：atol 过了但 SNR 低，说明多数元素偏差不大但整体噪声偏高；atol 没过但 ULP 在 1-2 以内，说明是浮点精度极限导致的，不是 bug。\n9.3 分层验证策略 # 精度验证不能只看单次 matmul 输出，需要分层递进：\nLayer 1: 单算子级\n单次 GMM 调用，固定输入，对比 kernel vs CPU ref 覆盖所有 shape 组合（wi/wo）和所有 tiling 配置 指标：atol + ULP Layer 2: 单层级（MoE Layer）\n完整的 MoE 前向 + 反向，包含 router、expert 选择、combine 验证梯度是否正确回传 指标：atol + SNR Layer 3: 多步训练\n跑 100-1000 步训练，对比 FP8 和 BF16 的 loss 曲线 关注 loss 是否发散、收敛速度是否一致 指标：loss 曲线对比、梯度 SNR 随步数的变化趋势 Layer 4: 端到端验证\n长时间训练（数千步以上），确认 FP8 不影响最终模型质量 这一层不再看逐元素精度，而是看训练指标（loss、eval metric） 9.4 关闭精度检查的问题 # 进化优化后期将 atol 设为 1e10（完全无约束）不合理。进化算法可能朝精度不可用的方向优化，浪费搜索预算。更合理的做法是放宽但不关闭：\n前期: atol=1e-3, SNR\u0026gt;30dB （严格，确保基本正确） 中期: atol=1e-1, SNR\u0026gt;20dB （放宽，允许更多探索） 后期: atol=1e-3, SNR\u0026gt;30dB （收紧，验证最终候选） Q10: 端到端实施路线 # 问： 在新 TPU 平台上为新模型实现 FP8 GMM kernel，从需求到交付的关键步骤？\n答：\nStep 0: 了解新平台硬件规格\nMXU 对 FP8 的吞吐（跑微基准确认硬件加速倍数） VMEM 大小（决定 tiling 上限） HBM 带宽（判断 kernel 是 compute-bound 还是 memory-bound） Step 1: 确定生产 shape\n从训练配置推导：per_device_batch、seq_len、top_k → M 从模型配置推导：EMB_DIM → K，MOE_MLP_DIM → N，NUM_EXPERTS/EP → G 确定 dev shape（取 EP 分片后的单卡 shape） Step 2: 实现 FP8 CPU reference\n精度正确的参考实现，作为 kernel 对齐目标 Step 3: 设计 block_size 与量化方案\n确定 LHS/RHS 的 block 形状和 scale 粒度 决定是否采用零反向量化 Step 4: 实现 BF16 baseline kernel\n性能基准线，没有 baseline 无法量化 FP8 收益 Step 5: 实现 FP8 kernel 并优化\n实现子通道循环（解耦 tk 与 block_size） 分阶段搜索 tiling：fwd/bwd/tgmm 独立优化 借助 profiling + LLO 分析 + 进化算法探索 全程保持精度对齐 CPU reference Step 6: 集成与端到端验证\n集成到训练框架，验证端到端性能提升 对比训练 loss 曲线，确保收敛性不受影响 ","date":"2026年4月4日","externalUrl":null,"permalink":"/posts/gemm-kernel-optimization-10-questions/","section":"博客","summary":"基于 Pallas GMM FP8 blockwise 量化内核开发的实战问答，涵盖量化粒度、子通道循环、分阶段 tiling、编译链路、精度对齐方法论等核心话题。","title":"GEMM Kernel 性能优化十问","type":"posts"},{"content":"","date":"2026年4月4日","externalUrl":null,"permalink":"/tags/kernel/","section":"Tags","summary":"","title":"Kernel","type":"tags"},{"content":"","date":"2026年4月4日","externalUrl":null,"permalink":"/tags/moe/","section":"Tags","summary":"","title":"MoE","type":"tags"},{"content":"","date":"2026年4月4日","externalUrl":null,"permalink":"/tags/pallas/","section":"Tags","summary":"","title":"Pallas","type":"tags"},{"content":"","date":"2026年4月4日","externalUrl":null,"permalink":"/tags/","section":"Tags","summary":"","title":"Tags","type":"tags"},{"content":"","date":"2026年4月4日","externalUrl":null,"permalink":"/tags/tpu/","section":"Tags","summary":"","title":"TPU","type":"tags"},{"content":"","date":"2026年4月4日","externalUrl":null,"permalink":"/posts/","section":"博客","summary":"","title":"博客","type":"posts"},{"content":"","date":"2026年4月4日","externalUrl":null,"permalink":"/categories/%E6%8A%80%E6%9C%AF/","section":"Categories","summary":"","title":"技术","type":"categories"},{"content":"","date":"2026年3月20日","externalUrl":null,"permalink":"/tags/jax/","section":"Tags","summary":"","title":"JAX","type":"tags"},{"content":" 项目简介 # 为 SGLang 实现的 JAX 后端，使其能够在 TPU 上进行高效推理。\n技术栈 # Python / JAX TPU SGLang ","date":"2026年3月20日","externalUrl":"https://github.com/sgl-project/sglang-jax","permalink":"/projects/sglang-jax/","section":"项目","summary":"SGLang 的 JAX 后端实现，支持在 TPU 上运行高效推理。","title":"SGLang-JAX","type":"projects"},{"content":"","date":"2026年3月20日","externalUrl":null,"permalink":"/tags/%E6%8E%A8%E7%90%86/","section":"Tags","summary":"","title":"推理","type":"tags"},{"content":"","date":"2026年3月20日","externalUrl":null,"permalink":"/categories/%E9%A1%B9%E7%9B%AE/","section":"Categories","summary":"","title":"项目","type":"categories"},{"content":"","date":"2026年3月20日","externalUrl":null,"permalink":"/projects/","section":"项目","summary":"","title":"项目","type":"projects"},{"content":"","date":"2025年10月29日","externalUrl":null,"permalink":"/tags/sglang/","section":"Tags","summary":"","title":"SGLang","type":"tags"},{"content":"","date":"2025年10月29日","externalUrl":"https://www.lmsys.org/blog/2025-10-29-sglang-jax/","permalink":"/posts/sglang-jax-tpu-inference/","section":"博客","summary":"SGLang-JAX 是基于 JAX 和 XLA 构建的开源推理引擎，支持 continuous batching 和 speculative decoding，在 TPU 上实现高效原生推理。","title":"SGLang-JAX: An Open-Source Solution for Native TPU Inference","type":"posts"},{"content":"","externalUrl":null,"permalink":"/series/","section":"Series","summary":"","title":"Series","type":"series"}]