Skip to content

feat: MegaMOE adaptation for SM90#323

Open
qiushixiaoyu wants to merge 1 commit into
deepseek-ai:mainfrom
qiushixiaoyu:main
Open

feat: MegaMOE adaptation for SM90#323
qiushixiaoyu wants to merge 1 commit into
deepseek-ai:mainfrom
qiushixiaoyu:main

Conversation

@qiushixiaoyu

@qiushixiaoyu qiushixiaoyu commented Apr 30, 2026

Copy link
Copy Markdown

Summary

Add SM90/Hopper FP8 MegaMoE decode support to DeepGEMM.

This change introduces a fused SM90 MegaMoE kernel path for FP8 weights and FP8 activations, including the C++/CUDA implementation, JIT entry point, Python API wrapper, Hopper-specific weight transform, scheduling heuristics, and a unified Hopper accuracy/performance test script.

Main Changes

  • Added the SM90 FP8 MegaMoE fused kernel implementation:

    • deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh
    • csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp
  • Extended MegaMoE API bindings with fp8_mega_moe for SM90.

  • Added Hopper-specific Python entry points:

    • deep_gemm.fp8_mega_moe
    • deep_gemm.transform_weights_for_mega_moe_sm90
  • Added SM90 MegaMoE scheduling/config heuristics.

  • Updated MegaMoE symmetric buffer handling for SM90 FP32 scale-factor layouts.

  • Added tests/test_mega_moe_hopper.py, covering:

    • Layered accuracy checks
    • Fused-kernel benchmark mode
    • Normal DeepEP + grouped-FP8 baseline comparison
    • Low-latency DeepEP baseline comparison

DeepSeekV4Flash(8 card H20)

batch MegaMoE us normal us normal/MegaMoE LL us LL/MegaMoE
1 188.0 505.0 2.686x 250.1 1.330x
2 239.0 615.8 2.576x 299.9 1.255x
4 361.1 828.1 2.293x 416.0 1.152x
8 432.2 950.4 2.199x 494.8 1.145x
16 482.2 1074.1 2.227x 536.5 1.112x
32 499.0 1087.6 2.180x 540.4 1.083x
64 507.6 1081.6 2.131x 550.9 1.085x
128 511.0 1079.8 2.113x 562.5 1.101x
256 528.4 1178.6 2.231x 625.8 1.184x
512 874.1 1238.4 1.417x - -
1024 1638.1 2101.2 1.283x - -
2048 2856.6 3564.9 1.248x - -
4096 5172.0 6378.5 1.233x - -
8192 9784.6 11959.2 1.222x - -

DeepSeekV4Pro(8 card H20)

batch MegaMoE us normal us normal/MegaMoE LL us LL/MegaMoE
1 410.2 910.2 2.219x 493.9 1.204x
2 560.8 1296.6 2.312x 669.5 1.194x
4 843.0 1861.5 2.208x 986.0 1.170x
8 1257.1 2736.1 2.176x 1383.0 1.100x
16 1507.6 3179.8 2.109x 1628.9 1.080x
32 1550.4 3298.0 2.127x 1708.5 1.102x
64 1554.4 3281.2 2.111x 1711.8 1.101x
128 1584.0 3334.6 2.105x 1736.9 1.097x
256 1613.5 3381.5 2.096x 1805.6 1.119x
512 3014.2 3466.1 1.150x - -
1024 4775.0 5373.1 1.125x - -
2048 7838.6 8723.8 1.113x - -
4096 13666.4 15115.5 1.106x - -
8192 25470.8 28148.6 1.105x - -

Benchmark:DeepSeekV4Flash CP8/EP8

Metric MegaMoE off MegaMoE on
GSM8K accuracy 0.975 0.970
GSM8K invalid 0.000 0.000

SLO-Compliant Total Throughput

Input/Output off on Improvement
3500/1500 4766.11 5143.19 +7.9%
32000/1500 11760.82 15176.28 +29.0%
128000/1500 13714.90 16044.73 +17.0%

Max Throughput

Input/Output off on Improvement
3500/1500 4708.56 5002.54 +6.2%
32000/1500 13177.93 16434.02 +24.7%
128000/1500 15888.99 19490.54 +22.7%

@LyricZhao Could you help review this PR when you have time? If the patch is too large for one PR, I’m happy to split it into smaller parts following your preference.

@qinqinwo

qinqinwo commented May 7, 2026

Copy link
Copy Markdown

Do you have benchmark data?

@qiushixiaoyu

qiushixiaoyu commented May 11, 2026

Copy link
Copy Markdown
Author

Do you have benchmark data?

I’m testing the benefits of DeepSeek V4 Flash on H20, and I’ll share the data soon.

@Stone749990226

Copy link
Copy Markdown

看起来效果并不理想:
CUDA_VISIBLE_DEVICES=0,2,3,4 python tests/test_mega_moe_hopper.py --num-processes 4 --num-max-tokens-per-rank 8192 --num-tokens 8192 --hidden 7168 --intermediate-hidden 3072 --num-experts 384 --num-topk 6 --num-bench-tests 30
Config (H200 fused mega-MoE):

Tokens: 8192/8192
Hidden: 7168, Intermediate: 3072
Experts: 6/384 (per-rank: 96)
Activation SF: fused L2 per-64 UE8M0, baseline L2 per-128 UE8M0 (SM90 grouped GEMM constraint)
Buffer: 4.268 GiB

Performance:

[fused] EP 3/4 | 245 TFLOPS | overlap: 246 TFLOPS, HBM 292 GB/s, NVL 40 GB/s | 26617 us, reduction: 126.5 us
[fused] EP 0/4 | 243 TFLOPS | overlap: 244 TFLOPS, HBM 292 GB/s, NVL 40 GB/s | 26617 us, reduction: 126.5 us
[fused] EP 2/4 | 244 TFLOPS | overlap: 245 TFLOPS, HBM 292 GB/s, NVL 40 GB/s | 26619 us, reduction: 126.5 us
[fused] EP 1/4 | 244 TFLOPS | overlap: 246 TFLOPS, HBM 292 GB/s, NVL 40 GB/s | 26618 us, reduction: 126.5 us
[baseline] EP 2/4 | 675 TFLOPS | HBM 802 GB/s, NVL 111 GB/s | 9638 us | t_baseline/t_fused = 0.36x (baseline 更快)
[baseline] EP 3/4 | 676 TFLOPS | HBM 802 GB/s, NVL 111 GB/s | 9640 us | t_baseline/t_fused = 0.36x (baseline 更快)
[baseline] EP 1/4 | 675 TFLOPS | HBM 802 GB/s, NVL 111 GB/s | 9642 us | t_baseline/t_fused = 0.36x (baseline 更快)
[baseline] EP 0/4 | 669 TFLOPS | HBM 800 GB/s, NVL 110 GB/s | 9653 us | t_baseline/t_fused = 0.36x (baseline 更快)

"""
H200 (SM90 / Hopper) mega-MoE: fused kernel + 同管线 baseline 性能对比。

结构对齐 tests/test_mega_moe.py(B 系列 SM100 FP4 路径),但所有路径都换成 H200 FP8:
  * fused:调用 `deep_gemm.fp8_mega_moe`(kernel symbol `sm90_fp8_mega_moe_impl`),
           使用 `transform_weights_for_mega_moe_sm90` 处理过的权重 + SymmBuffer。
  * baseline:DeepEP dispatch + 2 个 grouped FP8 GEMM + Triton SwiGLU + DeepEP combine,
              使用未变换的权重。由于当前 SM90 grouped GEMM 只支持 L2 activation
              per-128-K SFA,而 fused SM90 mega-MoE 的 L1 epilogue 为避免跨 CTA
              同步使用 per-64-K SFA,所以该 baseline 是同管线 legacy 参照,
              不是 bitwise apples-to-apples correctness oracle。
  * 性能输出涵盖:TFLOPS / overlap TFLOPS / HBM GB/s / NVL GB/s / fused us /
                  reduction us / `t_baseline / t_fused` legacy 比。
"""

import deep_ep
import argparse
import math
import os
import random
import torch
import torch.distributed as dist
import triton
import triton.language as tl
from typing import Tuple

import deep_gemm
from deep_gemm.utils import per_token_cast_to_fp8
from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather
from deep_gemm.testing import bench_kineto


# 与 deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh 中模板入口同名,
# bench_kineto 用它从 trace 里挑出 fused mega-MoE 的 GPU 段
SM90_KERNEL_NAME = "sm90_fp8_mega_moe_impl"


# FP8 e4m3fn 的最大可表示值,量化时用 amax / 448 作为 scale 基准
FP8_E4M3_MAX = 448.0
# 新版 Triton(>= 3.x)强制:jit 内核读到的 Python 全局必须是 tl.constexpr 实例,
# 否则编译期 NameError。宿主 Python 侧仍用上面的普通 float 做 torch 运算。
_FP8_E4M3_MAX_TL = tl.constexpr(448.0)
L1_ACT_SF_GRAN = 128
FUSED_L2_ACT_SF_GRAN = 64
BASELINE_L2_ACT_SF_GRAN = 128
WEIGHT_SF_GRAN_MN = 128
WEIGHT_SF_GRAN_K = 128


# ============================================================================
# 模块 1:Triton SwiGLU + FP8 量化内核
# ----------------------------------------------------------------------------
# baseline 的 L2 仍走 DeepGEMM SM90 grouped FP8 GEMM,所以 activation SFA 只能按
# per-128-K 输入;但 scale 数值采用 fused epilogue 同款 UE8M0/power-of-two 规则,
# 避免再额外引入 exact-FP32-scale 差异。
# 输入  x        : (M, 2*H) bf16,内层是 [gate_part | up_part]
# 输入  topk_w   : (M,)     fp32,可选
# 输出  y        : (M, H)   fp8_e4m3fn
# 输出  y_sf     : (M, H/BLOCK_K) fp32 行主序
# ============================================================================


@triton.jit
def _swiglu_apply_weight_to_fp8_kernel(
    x_ptr,
    topk_w_ptr,
    y_ptr,
    y_sf_ptr,
    M,
    H,  # 运行时形状
    stride_xm,
    stride_xn,  # x: (M, 2H) 的 stride
    stride_ym,
    stride_yn,  # y: (M, H)  的 stride
    stride_sfm,
    stride_sfk,  # y_sf: (M, H/BLOCK_K) 的 stride
    clamp_value,  # 当 HAS_CLAMP=False 时这个参数无意义
    HAS_TOPK: tl.constexpr,
    HAS_CLAMP: tl.constexpr,
    USE_UE8M0_SCALE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,  # = num_per_channels
):
    # 一个 program 处理 (BLOCK_M 个 token) × (第 pid_k 个 K-block 的 BLOCK_K 列)
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)

    # 行索引:本 program 负责 [pid_m*BLOCK_M, pid_m*BLOCK_M+BLOCK_M)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # 当前 K-block 内的列索引(在 H 维度,不是 2H)
    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    mask_m = offs_m < M

    # ---- 1) 载入 gate(x 的前半段 [0, H))和 up(x 的后半段 [H, 2H))----
    # 注意 stride_xn 是元素 stride(一般 == 1),但 H + offs_k 偏移是按"元素"算的
    gate_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xn
    up_ptrs = x_ptr + offs_m[:, None] * stride_xm + (H + offs_k[None, :]) * stride_xn
    gate = tl.load(gate_ptrs, mask=mask_m[:, None], other=0.0).to(tl.float32)
    up = tl.load(up_ptrs, mask=mask_m[:, None], other=0.0).to(tl.float32)

    # ---- 2) 可选 clamp(参考 tilelang 实现:gate 单边 max,up 双边)----
    if HAS_CLAMP:
        gate = tl.minimum(gate, clamp_value)
        up = tl.minimum(tl.maximum(up, -clamp_value), clamp_value)

    # ---- 3) SwiGLU:silu(gate) * up = gate * sigmoid(gate) * up(全程 FP32 累计)----
    y = gate * tl.sigmoid(gate) * up

    # ---- 4) 可选 MoE 权重缩放(per-token 标量)----
    if HAS_TOPK:
        w = tl.load(topk_w_ptr + offs_m, mask=mask_m, other=1.0)
        y = y * w[:, None]

    # ---- 5) 当前 K-block 内每行 absmax → scale ----
    amax = tl.max(tl.abs(y), axis=1)  # (BLOCK_M,)
    sf = tl.maximum(amax / _FP8_E4M3_MAX_TL, 1.0e-30)
    if USE_UE8M0_SCALE:
        # 对齐 deep_gemm/common/math.cuh::get_e4m3_sf_and_sf_inv:
        # scale = 2 ** ceil(log2(amax / 448)).
        sf = tl.exp2(tl.ceil(tl.log2(sf)))

    # ---- 6) 量化为 FP8 e4m3fn ----
    y_fp8 = (y / sf[:, None]).to(tl.float8e4nv)

    # ---- 7) 写回 y 和 sf ----
    y_ptrs = y_ptr + offs_m[:, None] * stride_ym + offs_k[None, :] * stride_yn
    tl.store(y_ptrs, y_fp8, mask=mask_m[:, None])

    sf_ptrs = y_sf_ptr + offs_m * stride_sfm + pid_k * stride_sfk
    tl.store(sf_ptrs, sf, mask=mask_m)


def swiglu_apply_weight_to_fp8_triton(
    x: torch.Tensor,
    topk_weights: torch.Tensor | None,
    clamp_value: float | None = None,
    num_per_channels: int = BASELINE_L2_ACT_SF_GRAN,
    use_ue8m0_scale: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """SwiGLU + FP8 量化。语义等价于 PyTorch reference:
    gate, up = x[:, :H], x[:, H:]
    y = silu(gate.clamp(max=c)) * up.clamp(-c, c) * topk_w
    y_sf = y.view(M, H/np, np).abs().amax(-1) / 448
    if use_ue8m0_scale: y_sf = ceil_to_power_of_2(y_sf)
    y_fp8 = (y / y_sf.unsqueeze(-1)).to(fp8)
    """
    assert x.is_cuda and x.dtype == torch.bfloat16
    assert x.is_contiguous(), "当前实现假设 x 是 contiguous 的,避免 stride 计算错位"
    M, two_H = x.shape
    H = two_H // 2
    assert H % num_per_channels == 0, f"H={H} 必须是 {num_per_channels} 的整数倍"

    y = torch.empty((M, H), dtype=torch.float8_e4m3fn, device=x.device)
    y_sf = torch.empty((M, H // num_per_channels), dtype=torch.float32, device=x.device)

    # BLOCK_M 取 16:内核每个 program 处理 16 个 token × 128 列,寄存器压力小、容易调
    BLOCK_M = 16
    grid = (triton.cdiv(M, BLOCK_M), H // num_per_channels)

    # HAS_TOPK=False 时仍要传一个有效指针(Triton 不允许 nullptr),用 x 占位
    topk_ptr = topk_weights if topk_weights is not None else x

    _swiglu_apply_weight_to_fp8_kernel[grid](
        x,
        topk_ptr,
        y,
        y_sf,
        M,
        H,
        x.stride(0),
        x.stride(1),
        y.stride(0),
        y.stride(1),
        y_sf.stride(0),
        y_sf.stride(1),
        float(clamp_value) if clamp_value is not None else 0.0,
        HAS_TOPK=topk_weights is not None,
        HAS_CLAMP=clamp_value is not None,
        USE_UE8M0_SCALE=use_ue8m0_scale,
        BLOCK_M=BLOCK_M,
        BLOCK_K=num_per_channels,
    )
    return y, y_sf


# ============================================================================
# 模块 2:grouped weight 的 (128, 128) FP8 块量化
# ----------------------------------------------------------------------------
# m_grouped_fp8_gemm_nt_contiguous 在 SM90 上对 weight 的输入约定:
#   每 (128, 128) 子块共享一个 FP32 SF,K 是 SF 的内层连续维(K-major)。
# 与 SM100 FP4 路径的差异:
#   * 不需要 deep_gemm.transform_sf_into_required_layout
#   * SF 是 FP32,不是 UE8M0 packed
# ============================================================================


def _quantize_grouped_fp8_block_128_128(
    w: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """(G, N, K) bf16 → (G, N, K) fp8_e4m3fn + (G, N//128, K//128) fp32 SF。"""
    g, n, k = w.shape
    assert n % 128 == 0 and k % 128 == 0, f"weight 的 N={n}, K={k} 都必须是 128 的倍数"

    # 把 (N, K) 切成 (N/128, 128, K/128, 128),最后一维和倒数第三维就是 128×128 子块内部
    w_view = w.view(g, n // 128, 128, k // 128, 128).float()

    # 子块内 absmax → scale = amax / 448,clamp(1e-4) 避免全 0 子块
    amax = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4)  # (G, N/128, K/128)
    sf = amax / FP8_E4M3_MAX

    # 量化:每个元素除以所属子块的 sf 后转 FP8
    # sf 形状 (G, N/128, K/128),需在 N-内 (axis -3) 和 K-内 (axis -1) 都补维度
    w_fp8 = (w_view / sf.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn)
    return w_fp8.view(g, n, k).contiguous(), sf.contiguous()


# ============================================================================
# 模块 3:尝试导入 deep_ep(用于 dispatch / combine)
# ============================================================================


def _import_deep_ep():
    try:
        import deep_ep

        return deep_ep
    except Exception as ex:
        dist_print(f"Failed to import deep_ep: {ex}", once_in_node=True)
        return None


# ============================================================================
# 模块 4:CUDA event 中位数测时(避开对 tilelang.do_bench 的依赖)
# ============================================================================


def _bench_cuda_events(
    fn, num_warmup: int = 5, num_repeat: int = 20, l2_flush_gb: float = 8.0
) -> float:
    """返回 fn 的中位数耗时(秒)。"""
    for _ in range(num_warmup):
        fn()
    torch.cuda.synchronize()
    times_ms = []
    for _ in range(num_repeat):
        # L2 flush,避免重复访问命中 cache 让测时偏低
        if l2_flush_gb > 0:
            free_bytes, _ = torch.cuda.mem_get_info()
            flush_bytes = min(int(l2_flush_gb * 1e9), int(free_bytes * 0.5))
            if flush_bytes >= 4:
                torch.empty(flush_bytes // 4, dtype=torch.int, device="cuda").zero_()
        s = torch.cuda.Event(enable_timing=True)
        e = torch.cuda.Event(enable_timing=True)
        s.record()
        fn()
        e.record()
        e.synchronize()
        times_ms.append(s.elapsed_time(e))
    times_ms.sort()
    return times_ms[len(times_ms) // 2] / 1e3


# ============================================================================
# 模块 5:test() 主入口 — 在每个 rank 上跑一遍 baseline
# ============================================================================


def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
    # 初始化分布式:rank_idx 是全局 rank,group 是默认 NCCL group
    rank_idx, num_ranks, group = init_dist(local_rank, num_local_ranks)
    torch.manual_seed(rank_idx)
    random.seed(rank_idx)

    # 形状参数(与 test_mega_moe.py 同名同义)
    num_max_tokens_per_rank = args.num_max_tokens_per_rank
    num_tokens = args.num_tokens if args.num_tokens > 0 else num_max_tokens_per_rank
    hidden, intermediate_hidden = args.hidden, args.intermediate_hidden
    num_experts, num_topk = args.num_experts, args.num_topk
    num_experts_per_rank = num_experts // num_ranks
    assert num_tokens <= num_max_tokens_per_rank
    assert num_experts % num_ranks == 0, (
        f"num_experts={num_experts} 必须能被 num_ranks={num_ranks} 整除"
    )

    # SM90 fused kernel 的形状约束(来自 csrc/apis/mega.hpp::fp8_mega_moe):
    #   * H、IH 必须是 128 的倍数(L1 input per-128-K SF + block-(128,128) weight SF)
    #   * IH/64 ≤ 64 → IH ≤ 4096(l2_arrival_mask 是 uint64,每 bit 对应 64 列)
    assert hidden % 128 == 0
    assert intermediate_hidden % 128 == 0
    assert intermediate_hidden // 64 <= 64, (
        f"SM90 fused kernel 要求 intermediate_hidden <= 4096, 当前 {intermediate_hidden}"
    )

    # ---- 创建 BF16 输入:token 与两层 weight ----
    # x: 每 rank 本地 num_tokens 个 token,每个 token hidden 维
    x_bf16 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
    # L1 weight: 每个 expert 把 hidden → 2*intermediate_hidden(gate 和 up 拼一起)
    l1_weights_bf16 = torch.randn(
        (num_experts_per_rank, intermediate_hidden * 2, hidden),
        dtype=torch.bfloat16,
        device="cuda",
    )
    # L2 weight: 每个 expert 把 intermediate_hidden → hidden
    l2_weights_bf16 = torch.randn(
        (num_experts_per_rank, hidden, intermediate_hidden),
        dtype=torch.bfloat16,
        device="cuda",
    )

    # 路由:scores → topk_idx (M, K) + topk_weights (M, K)
    scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device="cuda")
    topk_weights, topk_idx = torch.topk(
        scores, num_topk, dim=-1, largest=True, sorted=False
    )

    # 累计接收统计:fused 与 baseline 各持一份避免相互覆盖
    cum_stats_fused = torch.zeros(
        (num_experts_per_rank,), dtype=torch.int, device="cuda"
    )
    cum_stats_baseline = cum_stats_fused.clone()

    # ---- BF16 → FP8 量化 ----
    # x_fp8 是元组:(token_fp8 (M, hidden), token_sf (M, hidden//128) fp32 行主序)
    # 注意 use_ue8m0=False, use_packed_ue8m0=False:SM90 不接受 UE8M0 packed SF
    x_fp8 = per_token_cast_to_fp8(
        x_bf16, use_ue8m0=False, gran_k=128, use_packed_ue8m0=False
    )

    # weight 量化:(G, N, K) bf16 → ((G, N, K) fp8 e4m3fn, (G, N//128, K//128) fp32 SF)
    # baseline(DeepEP grouped GEMM)直接用这两个未变换的元组
    l1_weights = _quantize_grouped_fp8_block_128_128(l1_weights_bf16)
    l2_weights = _quantize_grouped_fp8_block_128_128(l2_weights_bf16)

    # fused 路径:FP8 weight 上做 gate/up gran-8 N-轴 interleave;SF 不变
    transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe_sm90(
        l1_weights, l2_weights
    )

    # SwiGLU clamp:finite → 传给 fused/triton;inf → None(关闭 clamp,与 SM90 fused 一致)
    clamp_arg = args.activation_clamp if math.isfinite(args.activation_clamp) else None

    # ---- DeepGEMM grouped GEMM 的 M 维 alignment(baseline 走 DeepEP 时也用这个)----
    alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout()
    deep_gemm.set_mk_alignment_for_contiguous_layout(alignment)

    # ---- 分配 fused 的 SymmBuffer 与输出 buffer ----
    sym_buffer = deep_gemm.get_symm_buffer_for_mega_moe(
        group,
        num_experts,
        num_max_tokens_per_rank,
        num_topk,
        hidden,
        intermediate_hidden,
    )
    y_fused = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")

    def run_fused():
        # NOTE: 跟 SM100 test_mega_moe.py 的处理一致 —— DG_COMM_KERNEL_DEBUG=1 时
        # kernel 出口会把 sym_buffer 整块清零,所以每次都要重新拷输入
        sym_buffer.x[:num_tokens].copy_(x_fp8[0])
        sym_buffer.x_sf[:num_tokens].copy_(x_fp8[1])
        sym_buffer.topk_idx[:num_tokens].copy_(topk_idx)
        sym_buffer.topk_weights[:num_tokens].copy_(topk_weights)

        deep_gemm.fp8_mega_moe(
            y_fused,
            transformed_l1,
            transformed_l2,
            sym_buffer,
            cumulative_local_expert_recv_stats=cum_stats_fused,
            recipe=(128, 128, 128),
            activation="swiglu",
            activation_clamp=clamp_arg,
            fast_math=bool(args.fast_math),
        )
        return y_fused

    # ---- 分配 DeepEP buffer(baseline 用)----
    deep_ep = _import_deep_ep()
    ep_buffer = None
    if deep_ep is not None:
        ep_buffer = deep_ep.ElasticBuffer(
            group,
            num_max_tokens_per_rank=num_max_tokens_per_rank,
            hidden=hidden,
            num_topk=num_topk,
            use_fp8_dispatch=True,
            explicitly_destroy=True,
            allow_multiple_reduction=False,
        )

    # ----------------------------------------------------------------
    # baseline 主体:dispatch → L1 GEMM → SwiGLU+量化 → L2 GEMM → combine
    # 与 fused 用同一份 (FP8 weight, FP32 block-(128,128) SF) —— 但是 **未变换**
    # 的版本(baseline grouped GEMM 不需要 gate/up interleave)
    # ----------------------------------------------------------------
    def run_baseline():
        recv_x, _, recv_topk_weights, handle, _ = ep_buffer.dispatch(
            x_fp8,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            cumulative_local_expert_recv_stats=cum_stats_baseline,
            num_experts=num_experts,
            expert_alignment=alignment,
            do_cpu_sync=False,
            do_handle_copy=False,
            do_expand=True,
            use_tma_aligned_col_major_sf=False,  # SM90: row-major float SF
        )
        n = recv_x[0].size(0)

        # L1 GEMM:FP8 token @ FP8 W1 → BF16 中间激活 (gate||up 拼接)
        l1_y = torch.empty(
            (n, intermediate_hidden * 2), dtype=torch.bfloat16, device="cuda"
        )
        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
            recv_x,
            l1_weights,
            l1_y,
            handle.psum_num_recv_tokens_per_expert,
            use_psum_layout=True,
            disable_ue8m0_cast=True,
        )

        # Triton SwiGLU + FP8 量化(含 topk 权重乘法)
        # 注意:fused SM90 mega-MoE 的 L2 activation SFA 是 per-64-K;
        # 当前 DeepGEMM SM90 grouped GEMM 只支持 per-128-K SFA,所以性能 baseline
        # 只能用 per-128-K,但 scale 数值采用 fused 同款 UE8M0/power-of-two。
        l1_y = swiglu_apply_weight_to_fp8_triton(
            x=l1_y,
            topk_weights=recv_topk_weights,
            clamp_value=clamp_arg,
            num_per_channels=BASELINE_L2_ACT_SF_GRAN,
            use_ue8m0_scale=True,
        )

        # L2 GEMM:FP8 中间激活 @ FP8 W2 → BF16
        l2_y = torch.empty((n, hidden), dtype=torch.bfloat16, device="cuda")
        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
            l1_y,
            l2_weights,
            l2_y,
            handle.psum_num_recv_tokens_per_expert,
            use_psum_layout=True,
            disable_ue8m0_cast=True,
        )

        # DeepEP combine:把每个 token 在 topk 个 expert 上的输出汇聚回源 rank
        return ep_buffer.combine(l2_y, handle=handle)[0]

    # ---- 打印 config ----
    dist_print("Config (H200 fused mega-MoE):", once_in_node=True)
    dist_print(f" > Tokens: {num_tokens}/{num_max_tokens_per_rank}", once_in_node=True)
    dist_print(
        f" > Hidden: {hidden}, Intermediate: {intermediate_hidden}", once_in_node=True
    )
    dist_print(
        f" > Experts: {num_topk}/{num_experts} (per-rank: {num_experts_per_rank})",
        once_in_node=True,
    )
    dist_print(
        f" > Activation SF: fused L2 per-{FUSED_L2_ACT_SF_GRAN} UE8M0, "
        f"baseline L2 per-{BASELINE_L2_ACT_SF_GRAN} UE8M0 "
        f"(SM90 grouped GEMM constraint)",
        once_in_node=True,
    )
    dist_print(
        f" > Buffer: {sym_buffer.buffer.nbytes / 2**30:.3f} GiB", once_in_node=True
    )
    dist_print(once_in_node=True)

    # ---- 跑一次确保不报错(fused + 可选 baseline)----
    y = run_fused()
    assert y.shape == (num_tokens, hidden) and y.dtype == torch.bfloat16, (
        f"fused 输出 shape/dtype 异常: shape={y.shape}, dtype={y.dtype}"
    )
    if ep_buffer is not None:
        out_b = run_baseline()
        assert out_b.shape == (num_tokens, hidden) and out_b.dtype == torch.bfloat16, (
            f"baseline 输出 shape/dtype 异常: shape={out_b.shape}, dtype={out_b.dtype}"
        )
        if args.check_output_diff:
            diff = (y.float() - out_b.float()).abs()
            denom = out_b.float().abs().mean().clamp_min(1e-12)
            dist_print(
                "Output diff (fused vs legacy-per128 baseline):", once_in_node=True
            )
            dist_print(
                f" > max_abs={diff.max().item():.6e}, "
                f"mean_abs={diff.mean().item():.6e}, "
                f"mean_abs/mean_ref={diff.mean().div(denom).item():.6e}",
                once_in_node=True,
            )
            dist_print(once_in_node=True)

    # ---- 统计本 rank 实际接收的 token 数与触达的 expert 数 ----
    # 把所有 rank 的 topk_idx 收齐,再把不落在本 rank 持有 expert 范围内的条目
    # 标成 -1;剩下的非 -1 条目数即"被路由进本 rank 的 (token, slot) 总数"。
    gathered_topk_idx = uneven_all_gather(topk_idx, group=group)
    gathered_topk_idx[
        (gathered_topk_idx < rank_idx * num_experts_per_rank)
        | (gathered_topk_idx >= (rank_idx + 1) * num_experts_per_rank)
    ] = -1
    num_recv_tokens = int((gathered_topk_idx != -1).sum().item())
    num_touched_experts = max(torch.unique(gathered_topk_idx.flatten()).numel() - 1, 0)

    # ---- benchmark ----
    # fused:bench_kineto 抓 sm90_fp8_mega_moe_impl 的 GPU 段(不含 host overhead)
    t_fused = bench_kineto(
        run_fused,
        SM90_KERNEL_NAME,
        num_tests=args.num_bench_tests,
        barrier=lambda: ep_buffer.barrier(use_comm_stream=False)
        if ep_buffer is not None
        else dist.barrier(),
        trace_path=(
            f"{args.dump_profile_traces}/mega_moe_hopper_rank{rank_idx}.json"
            if args.dump_profile_traces
            else None
        ),
    )
    # baseline:cuda events 中位数(tilelang.do_bench 在 H200 不一定有,统一用 events)
    t_baseline = (
        _bench_cuda_events(
            run_baseline,
            num_warmup=args.num_warmup,
            num_repeat=args.num_repeat,
            l2_flush_gb=args.l2_flush_gb,
        )
        if ep_buffer is not None
        else 0.0
    )

    def safe_div(a, b):
        return float("nan") if b == 0 else a / b

    # 端到端 TFLOPS:3 个 matmul(L1 gate、L1 up、L2),每个 2*M*N*K,M=num_recv_tokens
    tflops = safe_div(
        2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_fused
    )

    # HBM 字节估算(SM90: weight 是 FP8 = 1B/elem,与 SM100 FP4=0.5B 不同)
    l1_weight_bytes = num_touched_experts * intermediate_hidden * 2 * hidden
    l2_weight_bytes = num_touched_experts * hidden * intermediate_hidden
    l1_weight_sf_bytes = (
        num_touched_experts
        * (intermediate_hidden * 2 // WEIGHT_SF_GRAN_MN)
        * (hidden // WEIGHT_SF_GRAN_K)
        * 4
    )
    l2_weight_sf_bytes = (
        num_touched_experts
        * (hidden // WEIGHT_SF_GRAN_MN)
        * (intermediate_hidden // WEIGHT_SF_GRAN_K)
        * 4
    )
    l1_input_sf_bytes = num_recv_tokens * (hidden // L1_ACT_SF_GRAN) * 4
    l2_act_sf_bytes = (
        num_recv_tokens * (intermediate_hidden // FUSED_L2_ACT_SF_GRAN) * 4
    )
    num_hbm_bytes = (
        l1_weight_bytes
        + l2_weight_bytes  # weights (FP8)
        + l1_weight_sf_bytes
        + l2_weight_sf_bytes  # weight SF (FP32)
        + num_recv_tokens * hidden
        + l1_input_sf_bytes  # L1 输入读 (FP8 + SF)
        + num_recv_tokens * intermediate_hidden
        + l2_act_sf_bytes  # L1 输出写 (FP8 + SF)
        + num_recv_tokens * intermediate_hidden
        + l2_act_sf_bytes  # L2 输入读 (FP8 + SF)
        + num_recv_tokens * hidden * 2  # L2 输出写 (BF16)
    )
    hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_fused)

    # NVLink 字节:dispatch 拉 token + input SF + topk weight,combine 写回 BF16
    num_nvlink_bytes = num_recv_tokens * (hidden + hidden // 32 + 4 + hidden * 2)
    nvlink_gbs = safe_div(num_nvlink_bytes / 1e9, t_fused)

    # combine reduction 串行下界(解析估计;6.5e12 = HBM 串行 reduction 经验吞吐 B/s)
    t_reduction = num_tokens * hidden * 2 * (1 + num_topk) / 6.5e12

    # overlap 校正:扣掉 fused 中无法重叠的串行 reduction 段后估计稳态吞吐
    approx_factor = t_fused / max(t_fused - t_reduction, 1e-12)

    # baseline 用同一份 FLOPs / HBM 字节,时间换成 t_baseline
    tflops_baseline = safe_div(
        2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_baseline
    )
    hbm_gbs_baseline = safe_div(num_hbm_bytes / 1e9, t_baseline)
    nvlink_gbs_baseline = safe_div(num_nvlink_bytes / 1e9, t_baseline)

    dist_print("Performance:", once_in_node=True)
    dist_print(
        f" > [fused]    EP {rank_idx:2}/{num_ranks} | "
        f"{tflops:4.0f} TFLOPS | "
        f"overlap: {tflops * approx_factor:4.0f} TFLOPS, "
        f"HBM {hbm_gbs * approx_factor:4.0f} GB/s, "
        f"NVL {nvlink_gbs * approx_factor:3.0f} GB/s | "
        f"{t_fused * 1e6:6.0f} us, "
        f"reduction: {t_reduction * 1e6:5.1f} us"
    )
    if ep_buffer is not None:
        speedup = safe_div(t_baseline, t_fused)
        dist_print(
            f" > [baseline] EP {rank_idx:2}/{num_ranks} | "
            f"{tflops_baseline:4.0f} TFLOPS | "
            f"               HBM {hbm_gbs_baseline:4.0f} GB/s, "
            f"NVL {nvlink_gbs_baseline:3.0f} GB/s | "
            f"{t_baseline * 1e6:6.0f} us | "
            f"t_baseline/t_fused = {speedup:.2f}x "
            f"({'fused 更快' if speedup > 1 else 'baseline 更快'})"
        )
    else:
        dist_print(" > [baseline] (no baseline: deep_ep unavailable)", once_in_node=True)

    # ---- 清理 ----
    dist.barrier()
    sym_buffer.destroy()
    if ep_buffer is not None:
        ep_buffer.destroy()
    dist.destroy_process_group()


# ============================================================================
# 模块 6:argparse + spawn
# ============================================================================

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="H200 mega-MoE: fused (deep_gemm.fp8_mega_moe) vs DeepEP+grouped-FP8 baseline"
    )

    # 资源
    parser.add_argument(
        "--num-processes", type=int, default=8, help="spawn 出来的进程数(一卡一进程)"
    )

    # 模型形状
    # 注:SM90 fused kernel 要求 intermediate_hidden ≤ 4096
    parser.add_argument("--num-max-tokens-per-rank", type=int, default=8192)
    parser.add_argument(
        "--num-tokens",
        type=int,
        default=0,
        help="per-rank 实际 token 数;0 表示用 num-max-tokens-per-rank",
    )
    parser.add_argument("--hidden", type=int, default=7168)
    parser.add_argument(
        "--intermediate-hidden",
        type=int,
        default=3072,
        help="中间层维度(≤ 4096,受 SM90 l2_arrival_mask 约束)",
    )
    parser.add_argument(
        "--activation-clamp",
        type=float,
        default=10.0,
        help="SwiGLU 前对 gate/up 的 clamp 阈值;传 inf 表示关闭",
    )
    parser.add_argument("--num-experts", type=int, default=384)
    parser.add_argument("--num-topk", type=int, default=6)
    parser.add_argument(
        "--fast-math",
        type=int,
        default=1,
        help="fused 内 SwiGLU 是否启用 fast-math(0/1)",
    )

    # 测时
    parser.add_argument(
        "--num-bench-tests",
        type=int,
        default=30,
        help="bench_kineto 抓 fused 时的迭代数",
    )
    parser.add_argument(
        "--num-warmup", type=int, default=5, help="baseline cuda events warmup"
    )
    parser.add_argument(
        "--num-repeat", type=int, default=20, help="baseline cuda events 测时迭代"
    )
    parser.add_argument(
        "--l2-flush-gb",
        type=float,
        default=8.0,
        help="baseline event 测时前用于 flush L2 的临时写入大小;0 表示关闭",
    )
    parser.add_argument(
        "--check-output-diff",
        type=int,
        default=0,
        help="非 0 时打印 fused 与 legacy-per128 baseline 的输出差异(预期非 bitwise)",
    )
    parser.add_argument(
        "--dump-profile-traces",
        type=str,
        default="",
        help="非空时把 fused 的 Chrome trace 写到该目录(每 rank 一份)",
    )

    args = parser.parse_args()

    if args.dump_profile_traces:
        os.makedirs(args.dump_profile_traces, exist_ok=True)

    # 多进程启动:每个进程对应一个 GPU;test() 内部用 init_dist 建 NCCL group
    torch.multiprocessing.spawn(
        test, args=(args.num_processes, args), nprocs=args.num_processes
    )

@MikeFang-dev

Copy link
Copy Markdown

看起来效果并不理想: CUDA_VISIBLE_DEVICES=0,2,3,4 python tests/test_mega_moe_hopper.py --num-processes 4 --num-max-tokens-per-rank 8192 --num-tokens 8192 --hidden 7168 --intermediate-hidden 3072 --num-experts 384 --num-topk 6 --num-bench-tests 30 Config (H200 fused mega-MoE):

Tokens: 8192/8192
Hidden: 7168, Intermediate: 3072
Experts: 6/384 (per-rank: 96)
Activation SF: fused L2 per-64 UE8M0, baseline L2 per-128 UE8M0 (SM90 grouped GEMM constraint)
Buffer: 4.268 GiB

Performance:

[fused] EP 3/4 | 245 TFLOPS | overlap: 246 TFLOPS, HBM 292 GB/s, NVL 40 GB/s | 26617 us, reduction: 126.5 us
[fused] EP 0/4 | 243 TFLOPS | overlap: 244 TFLOPS, HBM 292 GB/s, NVL 40 GB/s | 26617 us, reduction: 126.5 us
[fused] EP 2/4 | 244 TFLOPS | overlap: 245 TFLOPS, HBM 292 GB/s, NVL 40 GB/s | 26619 us, reduction: 126.5 us
[fused] EP 1/4 | 244 TFLOPS | overlap: 246 TFLOPS, HBM 292 GB/s, NVL 40 GB/s | 26618 us, reduction: 126.5 us
[baseline] EP 2/4 | 675 TFLOPS | HBM 802 GB/s, NVL 111 GB/s | 9638 us | t_baseline/t_fused = 0.36x (baseline 更快)
[baseline] EP 3/4 | 676 TFLOPS | HBM 802 GB/s, NVL 111 GB/s | 9640 us | t_baseline/t_fused = 0.36x (baseline 更快)
[baseline] EP 1/4 | 675 TFLOPS | HBM 802 GB/s, NVL 111 GB/s | 9642 us | t_baseline/t_fused = 0.36x (baseline 更快)
[baseline] EP 0/4 | 669 TFLOPS | HBM 800 GB/s, NVL 110 GB/s | 9653 us | t_baseline/t_fused = 0.36x (baseline 更快)

"""
H200 (SM90 / Hopper) mega-MoE: fused kernel + 同管线 baseline 性能对比。

结构对齐 tests/test_mega_moe.py(B 系列 SM100 FP4 路径),但所有路径都换成 H200 FP8:
  * fused:调用 `deep_gemm.fp8_mega_moe`(kernel symbol `sm90_fp8_mega_moe_impl`),
           使用 `transform_weights_for_mega_moe_sm90` 处理过的权重 + SymmBuffer。
  * baseline:DeepEP dispatch + 2 个 grouped FP8 GEMM + Triton SwiGLU + DeepEP combine,
              使用未变换的权重。由于当前 SM90 grouped GEMM 只支持 L2 activation
              per-128-K SFA,而 fused SM90 mega-MoE 的 L1 epilogue 为避免跨 CTA
              同步使用 per-64-K SFA,所以该 baseline 是同管线 legacy 参照,
              不是 bitwise apples-to-apples correctness oracle。
  * 性能输出涵盖:TFLOPS / overlap TFLOPS / HBM GB/s / NVL GB/s / fused us /
                  reduction us / `t_baseline / t_fused` legacy 比。
"""

import deep_ep
import argparse
import math
import os
import random
import torch
import torch.distributed as dist
import triton
import triton.language as tl
from typing import Tuple

import deep_gemm
from deep_gemm.utils import per_token_cast_to_fp8
from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather
from deep_gemm.testing import bench_kineto


# 与 deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh 中模板入口同名,
# bench_kineto 用它从 trace 里挑出 fused mega-MoE 的 GPU 段
SM90_KERNEL_NAME = "sm90_fp8_mega_moe_impl"


# FP8 e4m3fn 的最大可表示值,量化时用 amax / 448 作为 scale 基准
FP8_E4M3_MAX = 448.0
# 新版 Triton(>= 3.x)强制:jit 内核读到的 Python 全局必须是 tl.constexpr 实例,
# 否则编译期 NameError。宿主 Python 侧仍用上面的普通 float 做 torch 运算。
_FP8_E4M3_MAX_TL = tl.constexpr(448.0)
L1_ACT_SF_GRAN = 128
FUSED_L2_ACT_SF_GRAN = 64
BASELINE_L2_ACT_SF_GRAN = 128
WEIGHT_SF_GRAN_MN = 128
WEIGHT_SF_GRAN_K = 128


# ============================================================================
# 模块 1:Triton SwiGLU + FP8 量化内核
# ----------------------------------------------------------------------------
# baseline 的 L2 仍走 DeepGEMM SM90 grouped FP8 GEMM,所以 activation SFA 只能按
# per-128-K 输入;但 scale 数值采用 fused epilogue 同款 UE8M0/power-of-two 规则,
# 避免再额外引入 exact-FP32-scale 差异。
# 输入  x        : (M, 2*H) bf16,内层是 [gate_part | up_part]
# 输入  topk_w   : (M,)     fp32,可选
# 输出  y        : (M, H)   fp8_e4m3fn
# 输出  y_sf     : (M, H/BLOCK_K) fp32 行主序
# ============================================================================


@triton.jit
def _swiglu_apply_weight_to_fp8_kernel(
    x_ptr,
    topk_w_ptr,
    y_ptr,
    y_sf_ptr,
    M,
    H,  # 运行时形状
    stride_xm,
    stride_xn,  # x: (M, 2H) 的 stride
    stride_ym,
    stride_yn,  # y: (M, H)  的 stride
    stride_sfm,
    stride_sfk,  # y_sf: (M, H/BLOCK_K) 的 stride
    clamp_value,  # 当 HAS_CLAMP=False 时这个参数无意义
    HAS_TOPK: tl.constexpr,
    HAS_CLAMP: tl.constexpr,
    USE_UE8M0_SCALE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,  # = num_per_channels
):
    # 一个 program 处理 (BLOCK_M 个 token) × (第 pid_k 个 K-block 的 BLOCK_K 列)
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)

    # 行索引:本 program 负责 [pid_m*BLOCK_M, pid_m*BLOCK_M+BLOCK_M)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # 当前 K-block 内的列索引(在 H 维度,不是 2H)
    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    mask_m = offs_m < M

    # ---- 1) 载入 gate(x 的前半段 [0, H))和 up(x 的后半段 [H, 2H))----
    # 注意 stride_xn 是元素 stride(一般 == 1),但 H + offs_k 偏移是按"元素"算的
    gate_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xn
    up_ptrs = x_ptr + offs_m[:, None] * stride_xm + (H + offs_k[None, :]) * stride_xn
    gate = tl.load(gate_ptrs, mask=mask_m[:, None], other=0.0).to(tl.float32)
    up = tl.load(up_ptrs, mask=mask_m[:, None], other=0.0).to(tl.float32)

    # ---- 2) 可选 clamp(参考 tilelang 实现:gate 单边 max,up 双边)----
    if HAS_CLAMP:
        gate = tl.minimum(gate, clamp_value)
        up = tl.minimum(tl.maximum(up, -clamp_value), clamp_value)

    # ---- 3) SwiGLU:silu(gate) * up = gate * sigmoid(gate) * up(全程 FP32 累计)----
    y = gate * tl.sigmoid(gate) * up

    # ---- 4) 可选 MoE 权重缩放(per-token 标量)----
    if HAS_TOPK:
        w = tl.load(topk_w_ptr + offs_m, mask=mask_m, other=1.0)
        y = y * w[:, None]

    # ---- 5) 当前 K-block 内每行 absmax → scale ----
    amax = tl.max(tl.abs(y), axis=1)  # (BLOCK_M,)
    sf = tl.maximum(amax / _FP8_E4M3_MAX_TL, 1.0e-30)
    if USE_UE8M0_SCALE:
        # 对齐 deep_gemm/common/math.cuh::get_e4m3_sf_and_sf_inv:
        # scale = 2 ** ceil(log2(amax / 448)).
        sf = tl.exp2(tl.ceil(tl.log2(sf)))

    # ---- 6) 量化为 FP8 e4m3fn ----
    y_fp8 = (y / sf[:, None]).to(tl.float8e4nv)

    # ---- 7) 写回 y 和 sf ----
    y_ptrs = y_ptr + offs_m[:, None] * stride_ym + offs_k[None, :] * stride_yn
    tl.store(y_ptrs, y_fp8, mask=mask_m[:, None])

    sf_ptrs = y_sf_ptr + offs_m * stride_sfm + pid_k * stride_sfk
    tl.store(sf_ptrs, sf, mask=mask_m)


def swiglu_apply_weight_to_fp8_triton(
    x: torch.Tensor,
    topk_weights: torch.Tensor | None,
    clamp_value: float | None = None,
    num_per_channels: int = BASELINE_L2_ACT_SF_GRAN,
    use_ue8m0_scale: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """SwiGLU + FP8 量化。语义等价于 PyTorch reference:
    gate, up = x[:, :H], x[:, H:]
    y = silu(gate.clamp(max=c)) * up.clamp(-c, c) * topk_w
    y_sf = y.view(M, H/np, np).abs().amax(-1) / 448
    if use_ue8m0_scale: y_sf = ceil_to_power_of_2(y_sf)
    y_fp8 = (y / y_sf.unsqueeze(-1)).to(fp8)
    """
    assert x.is_cuda and x.dtype == torch.bfloat16
    assert x.is_contiguous(), "当前实现假设 x 是 contiguous 的,避免 stride 计算错位"
    M, two_H = x.shape
    H = two_H // 2
    assert H % num_per_channels == 0, f"H={H} 必须是 {num_per_channels} 的整数倍"

    y = torch.empty((M, H), dtype=torch.float8_e4m3fn, device=x.device)
    y_sf = torch.empty((M, H // num_per_channels), dtype=torch.float32, device=x.device)

    # BLOCK_M 取 16:内核每个 program 处理 16 个 token × 128 列,寄存器压力小、容易调
    BLOCK_M = 16
    grid = (triton.cdiv(M, BLOCK_M), H // num_per_channels)

    # HAS_TOPK=False 时仍要传一个有效指针(Triton 不允许 nullptr),用 x 占位
    topk_ptr = topk_weights if topk_weights is not None else x

    _swiglu_apply_weight_to_fp8_kernel[grid](
        x,
        topk_ptr,
        y,
        y_sf,
        M,
        H,
        x.stride(0),
        x.stride(1),
        y.stride(0),
        y.stride(1),
        y_sf.stride(0),
        y_sf.stride(1),
        float(clamp_value) if clamp_value is not None else 0.0,
        HAS_TOPK=topk_weights is not None,
        HAS_CLAMP=clamp_value is not None,
        USE_UE8M0_SCALE=use_ue8m0_scale,
        BLOCK_M=BLOCK_M,
        BLOCK_K=num_per_channels,
    )
    return y, y_sf


# ============================================================================
# 模块 2:grouped weight 的 (128, 128) FP8 块量化
# ----------------------------------------------------------------------------
# m_grouped_fp8_gemm_nt_contiguous 在 SM90 上对 weight 的输入约定:
#   每 (128, 128) 子块共享一个 FP32 SF,K 是 SF 的内层连续维(K-major)。
# 与 SM100 FP4 路径的差异:
#   * 不需要 deep_gemm.transform_sf_into_required_layout
#   * SF 是 FP32,不是 UE8M0 packed
# ============================================================================


def _quantize_grouped_fp8_block_128_128(
    w: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """(G, N, K) bf16 → (G, N, K) fp8_e4m3fn + (G, N//128, K//128) fp32 SF。"""
    g, n, k = w.shape
    assert n % 128 == 0 and k % 128 == 0, f"weight 的 N={n}, K={k} 都必须是 128 的倍数"

    # 把 (N, K) 切成 (N/128, 128, K/128, 128),最后一维和倒数第三维就是 128×128 子块内部
    w_view = w.view(g, n // 128, 128, k // 128, 128).float()

    # 子块内 absmax → scale = amax / 448,clamp(1e-4) 避免全 0 子块
    amax = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4)  # (G, N/128, K/128)
    sf = amax / FP8_E4M3_MAX

    # 量化:每个元素除以所属子块的 sf 后转 FP8
    # sf 形状 (G, N/128, K/128),需在 N-内 (axis -3) 和 K-内 (axis -1) 都补维度
    w_fp8 = (w_view / sf.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn)
    return w_fp8.view(g, n, k).contiguous(), sf.contiguous()


# ============================================================================
# 模块 3:尝试导入 deep_ep(用于 dispatch / combine)
# ============================================================================


def _import_deep_ep():
    try:
        import deep_ep

        return deep_ep
    except Exception as ex:
        dist_print(f"Failed to import deep_ep: {ex}", once_in_node=True)
        return None


# ============================================================================
# 模块 4:CUDA event 中位数测时(避开对 tilelang.do_bench 的依赖)
# ============================================================================


def _bench_cuda_events(
    fn, num_warmup: int = 5, num_repeat: int = 20, l2_flush_gb: float = 8.0
) -> float:
    """返回 fn 的中位数耗时(秒)。"""
    for _ in range(num_warmup):
        fn()
    torch.cuda.synchronize()
    times_ms = []
    for _ in range(num_repeat):
        # L2 flush,避免重复访问命中 cache 让测时偏低
        if l2_flush_gb > 0:
            free_bytes, _ = torch.cuda.mem_get_info()
            flush_bytes = min(int(l2_flush_gb * 1e9), int(free_bytes * 0.5))
            if flush_bytes >= 4:
                torch.empty(flush_bytes // 4, dtype=torch.int, device="cuda").zero_()
        s = torch.cuda.Event(enable_timing=True)
        e = torch.cuda.Event(enable_timing=True)
        s.record()
        fn()
        e.record()
        e.synchronize()
        times_ms.append(s.elapsed_time(e))
    times_ms.sort()
    return times_ms[len(times_ms) // 2] / 1e3


# ============================================================================
# 模块 5:test() 主入口 — 在每个 rank 上跑一遍 baseline
# ============================================================================


def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
    # 初始化分布式:rank_idx 是全局 rank,group 是默认 NCCL group
    rank_idx, num_ranks, group = init_dist(local_rank, num_local_ranks)
    torch.manual_seed(rank_idx)
    random.seed(rank_idx)

    # 形状参数(与 test_mega_moe.py 同名同义)
    num_max_tokens_per_rank = args.num_max_tokens_per_rank
    num_tokens = args.num_tokens if args.num_tokens > 0 else num_max_tokens_per_rank
    hidden, intermediate_hidden = args.hidden, args.intermediate_hidden
    num_experts, num_topk = args.num_experts, args.num_topk
    num_experts_per_rank = num_experts // num_ranks
    assert num_tokens <= num_max_tokens_per_rank
    assert num_experts % num_ranks == 0, (
        f"num_experts={num_experts} 必须能被 num_ranks={num_ranks} 整除"
    )

    # SM90 fused kernel 的形状约束(来自 csrc/apis/mega.hpp::fp8_mega_moe):
    #   * H、IH 必须是 128 的倍数(L1 input per-128-K SF + block-(128,128) weight SF)
    #   * IH/64 ≤ 64 → IH ≤ 4096(l2_arrival_mask 是 uint64,每 bit 对应 64 列)
    assert hidden % 128 == 0
    assert intermediate_hidden % 128 == 0
    assert intermediate_hidden // 64 <= 64, (
        f"SM90 fused kernel 要求 intermediate_hidden <= 4096, 当前 {intermediate_hidden}"
    )

    # ---- 创建 BF16 输入:token 与两层 weight ----
    # x: 每 rank 本地 num_tokens 个 token,每个 token hidden 维
    x_bf16 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
    # L1 weight: 每个 expert 把 hidden → 2*intermediate_hidden(gate 和 up 拼一起)
    l1_weights_bf16 = torch.randn(
        (num_experts_per_rank, intermediate_hidden * 2, hidden),
        dtype=torch.bfloat16,
        device="cuda",
    )
    # L2 weight: 每个 expert 把 intermediate_hidden → hidden
    l2_weights_bf16 = torch.randn(
        (num_experts_per_rank, hidden, intermediate_hidden),
        dtype=torch.bfloat16,
        device="cuda",
    )

    # 路由:scores → topk_idx (M, K) + topk_weights (M, K)
    scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device="cuda")
    topk_weights, topk_idx = torch.topk(
        scores, num_topk, dim=-1, largest=True, sorted=False
    )

    # 累计接收统计:fused 与 baseline 各持一份避免相互覆盖
    cum_stats_fused = torch.zeros(
        (num_experts_per_rank,), dtype=torch.int, device="cuda"
    )
    cum_stats_baseline = cum_stats_fused.clone()

    # ---- BF16 → FP8 量化 ----
    # x_fp8 是元组:(token_fp8 (M, hidden), token_sf (M, hidden//128) fp32 行主序)
    # 注意 use_ue8m0=False, use_packed_ue8m0=False:SM90 不接受 UE8M0 packed SF
    x_fp8 = per_token_cast_to_fp8(
        x_bf16, use_ue8m0=False, gran_k=128, use_packed_ue8m0=False
    )

    # weight 量化:(G, N, K) bf16 → ((G, N, K) fp8 e4m3fn, (G, N//128, K//128) fp32 SF)
    # baseline(DeepEP grouped GEMM)直接用这两个未变换的元组
    l1_weights = _quantize_grouped_fp8_block_128_128(l1_weights_bf16)
    l2_weights = _quantize_grouped_fp8_block_128_128(l2_weights_bf16)

    # fused 路径:FP8 weight 上做 gate/up gran-8 N-轴 interleave;SF 不变
    transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe_sm90(
        l1_weights, l2_weights
    )

    # SwiGLU clamp:finite → 传给 fused/triton;inf → None(关闭 clamp,与 SM90 fused 一致)
    clamp_arg = args.activation_clamp if math.isfinite(args.activation_clamp) else None

    # ---- DeepGEMM grouped GEMM 的 M 维 alignment(baseline 走 DeepEP 时也用这个)----
    alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout()
    deep_gemm.set_mk_alignment_for_contiguous_layout(alignment)

    # ---- 分配 fused 的 SymmBuffer 与输出 buffer ----
    sym_buffer = deep_gemm.get_symm_buffer_for_mega_moe(
        group,
        num_experts,
        num_max_tokens_per_rank,
        num_topk,
        hidden,
        intermediate_hidden,
    )
    y_fused = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")

    def run_fused():
        # NOTE: 跟 SM100 test_mega_moe.py 的处理一致 —— DG_COMM_KERNEL_DEBUG=1 时
        # kernel 出口会把 sym_buffer 整块清零,所以每次都要重新拷输入
        sym_buffer.x[:num_tokens].copy_(x_fp8[0])
        sym_buffer.x_sf[:num_tokens].copy_(x_fp8[1])
        sym_buffer.topk_idx[:num_tokens].copy_(topk_idx)
        sym_buffer.topk_weights[:num_tokens].copy_(topk_weights)

        deep_gemm.fp8_mega_moe(
            y_fused,
            transformed_l1,
            transformed_l2,
            sym_buffer,
            cumulative_local_expert_recv_stats=cum_stats_fused,
            recipe=(128, 128, 128),
            activation="swiglu",
            activation_clamp=clamp_arg,
            fast_math=bool(args.fast_math),
        )
        return y_fused

    # ---- 分配 DeepEP buffer(baseline 用)----
    deep_ep = _import_deep_ep()
    ep_buffer = None
    if deep_ep is not None:
        ep_buffer = deep_ep.ElasticBuffer(
            group,
            num_max_tokens_per_rank=num_max_tokens_per_rank,
            hidden=hidden,
            num_topk=num_topk,
            use_fp8_dispatch=True,
            explicitly_destroy=True,
            allow_multiple_reduction=False,
        )

    # ----------------------------------------------------------------
    # baseline 主体:dispatch → L1 GEMM → SwiGLU+量化 → L2 GEMM → combine
    # 与 fused 用同一份 (FP8 weight, FP32 block-(128,128) SF) —— 但是 **未变换**
    # 的版本(baseline grouped GEMM 不需要 gate/up interleave)
    # ----------------------------------------------------------------
    def run_baseline():
        recv_x, _, recv_topk_weights, handle, _ = ep_buffer.dispatch(
            x_fp8,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            cumulative_local_expert_recv_stats=cum_stats_baseline,
            num_experts=num_experts,
            expert_alignment=alignment,
            do_cpu_sync=False,
            do_handle_copy=False,
            do_expand=True,
            use_tma_aligned_col_major_sf=False,  # SM90: row-major float SF
        )
        n = recv_x[0].size(0)

        # L1 GEMM:FP8 token @ FP8 W1 → BF16 中间激活 (gate||up 拼接)
        l1_y = torch.empty(
            (n, intermediate_hidden * 2), dtype=torch.bfloat16, device="cuda"
        )
        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
            recv_x,
            l1_weights,
            l1_y,
            handle.psum_num_recv_tokens_per_expert,
            use_psum_layout=True,
            disable_ue8m0_cast=True,
        )

        # Triton SwiGLU + FP8 量化(含 topk 权重乘法)
        # 注意:fused SM90 mega-MoE 的 L2 activation SFA 是 per-64-K;
        # 当前 DeepGEMM SM90 grouped GEMM 只支持 per-128-K SFA,所以性能 baseline
        # 只能用 per-128-K,但 scale 数值采用 fused 同款 UE8M0/power-of-two。
        l1_y = swiglu_apply_weight_to_fp8_triton(
            x=l1_y,
            topk_weights=recv_topk_weights,
            clamp_value=clamp_arg,
            num_per_channels=BASELINE_L2_ACT_SF_GRAN,
            use_ue8m0_scale=True,
        )

        # L2 GEMM:FP8 中间激活 @ FP8 W2 → BF16
        l2_y = torch.empty((n, hidden), dtype=torch.bfloat16, device="cuda")
        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
            l1_y,
            l2_weights,
            l2_y,
            handle.psum_num_recv_tokens_per_expert,
            use_psum_layout=True,
            disable_ue8m0_cast=True,
        )

        # DeepEP combine:把每个 token 在 topk 个 expert 上的输出汇聚回源 rank
        return ep_buffer.combine(l2_y, handle=handle)[0]

    # ---- 打印 config ----
    dist_print("Config (H200 fused mega-MoE):", once_in_node=True)
    dist_print(f" > Tokens: {num_tokens}/{num_max_tokens_per_rank}", once_in_node=True)
    dist_print(
        f" > Hidden: {hidden}, Intermediate: {intermediate_hidden}", once_in_node=True
    )
    dist_print(
        f" > Experts: {num_topk}/{num_experts} (per-rank: {num_experts_per_rank})",
        once_in_node=True,
    )
    dist_print(
        f" > Activation SF: fused L2 per-{FUSED_L2_ACT_SF_GRAN} UE8M0, "
        f"baseline L2 per-{BASELINE_L2_ACT_SF_GRAN} UE8M0 "
        f"(SM90 grouped GEMM constraint)",
        once_in_node=True,
    )
    dist_print(
        f" > Buffer: {sym_buffer.buffer.nbytes / 2**30:.3f} GiB", once_in_node=True
    )
    dist_print(once_in_node=True)

    # ---- 跑一次确保不报错(fused + 可选 baseline)----
    y = run_fused()
    assert y.shape == (num_tokens, hidden) and y.dtype == torch.bfloat16, (
        f"fused 输出 shape/dtype 异常: shape={y.shape}, dtype={y.dtype}"
    )
    if ep_buffer is not None:
        out_b = run_baseline()
        assert out_b.shape == (num_tokens, hidden) and out_b.dtype == torch.bfloat16, (
            f"baseline 输出 shape/dtype 异常: shape={out_b.shape}, dtype={out_b.dtype}"
        )
        if args.check_output_diff:
            diff = (y.float() - out_b.float()).abs()
            denom = out_b.float().abs().mean().clamp_min(1e-12)
            dist_print(
                "Output diff (fused vs legacy-per128 baseline):", once_in_node=True
            )
            dist_print(
                f" > max_abs={diff.max().item():.6e}, "
                f"mean_abs={diff.mean().item():.6e}, "
                f"mean_abs/mean_ref={diff.mean().div(denom).item():.6e}",
                once_in_node=True,
            )
            dist_print(once_in_node=True)

    # ---- 统计本 rank 实际接收的 token 数与触达的 expert 数 ----
    # 把所有 rank 的 topk_idx 收齐,再把不落在本 rank 持有 expert 范围内的条目
    # 标成 -1;剩下的非 -1 条目数即"被路由进本 rank 的 (token, slot) 总数"。
    gathered_topk_idx = uneven_all_gather(topk_idx, group=group)
    gathered_topk_idx[
        (gathered_topk_idx < rank_idx * num_experts_per_rank)
        | (gathered_topk_idx >= (rank_idx + 1) * num_experts_per_rank)
    ] = -1
    num_recv_tokens = int((gathered_topk_idx != -1).sum().item())
    num_touched_experts = max(torch.unique(gathered_topk_idx.flatten()).numel() - 1, 0)

    # ---- benchmark ----
    # fused:bench_kineto 抓 sm90_fp8_mega_moe_impl 的 GPU 段(不含 host overhead)
    t_fused = bench_kineto(
        run_fused,
        SM90_KERNEL_NAME,
        num_tests=args.num_bench_tests,
        barrier=lambda: ep_buffer.barrier(use_comm_stream=False)
        if ep_buffer is not None
        else dist.barrier(),
        trace_path=(
            f"{args.dump_profile_traces}/mega_moe_hopper_rank{rank_idx}.json"
            if args.dump_profile_traces
            else None
        ),
    )
    # baseline:cuda events 中位数(tilelang.do_bench 在 H200 不一定有,统一用 events)
    t_baseline = (
        _bench_cuda_events(
            run_baseline,
            num_warmup=args.num_warmup,
            num_repeat=args.num_repeat,
            l2_flush_gb=args.l2_flush_gb,
        )
        if ep_buffer is not None
        else 0.0
    )

    def safe_div(a, b):
        return float("nan") if b == 0 else a / b

    # 端到端 TFLOPS:3 个 matmul(L1 gate、L1 up、L2),每个 2*M*N*K,M=num_recv_tokens
    tflops = safe_div(
        2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_fused
    )

    # HBM 字节估算(SM90: weight 是 FP8 = 1B/elem,与 SM100 FP4=0.5B 不同)
    l1_weight_bytes = num_touched_experts * intermediate_hidden * 2 * hidden
    l2_weight_bytes = num_touched_experts * hidden * intermediate_hidden
    l1_weight_sf_bytes = (
        num_touched_experts
        * (intermediate_hidden * 2 // WEIGHT_SF_GRAN_MN)
        * (hidden // WEIGHT_SF_GRAN_K)
        * 4
    )
    l2_weight_sf_bytes = (
        num_touched_experts
        * (hidden // WEIGHT_SF_GRAN_MN)
        * (intermediate_hidden // WEIGHT_SF_GRAN_K)
        * 4
    )
    l1_input_sf_bytes = num_recv_tokens * (hidden // L1_ACT_SF_GRAN) * 4
    l2_act_sf_bytes = (
        num_recv_tokens * (intermediate_hidden // FUSED_L2_ACT_SF_GRAN) * 4
    )
    num_hbm_bytes = (
        l1_weight_bytes
        + l2_weight_bytes  # weights (FP8)
        + l1_weight_sf_bytes
        + l2_weight_sf_bytes  # weight SF (FP32)
        + num_recv_tokens * hidden
        + l1_input_sf_bytes  # L1 输入读 (FP8 + SF)
        + num_recv_tokens * intermediate_hidden
        + l2_act_sf_bytes  # L1 输出写 (FP8 + SF)
        + num_recv_tokens * intermediate_hidden
        + l2_act_sf_bytes  # L2 输入读 (FP8 + SF)
        + num_recv_tokens * hidden * 2  # L2 输出写 (BF16)
    )
    hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_fused)

    # NVLink 字节:dispatch 拉 token + input SF + topk weight,combine 写回 BF16
    num_nvlink_bytes = num_recv_tokens * (hidden + hidden // 32 + 4 + hidden * 2)
    nvlink_gbs = safe_div(num_nvlink_bytes / 1e9, t_fused)

    # combine reduction 串行下界(解析估计;6.5e12 = HBM 串行 reduction 经验吞吐 B/s)
    t_reduction = num_tokens * hidden * 2 * (1 + num_topk) / 6.5e12

    # overlap 校正:扣掉 fused 中无法重叠的串行 reduction 段后估计稳态吞吐
    approx_factor = t_fused / max(t_fused - t_reduction, 1e-12)

    # baseline 用同一份 FLOPs / HBM 字节,时间换成 t_baseline
    tflops_baseline = safe_div(
        2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_baseline
    )
    hbm_gbs_baseline = safe_div(num_hbm_bytes / 1e9, t_baseline)
    nvlink_gbs_baseline = safe_div(num_nvlink_bytes / 1e9, t_baseline)

    dist_print("Performance:", once_in_node=True)
    dist_print(
        f" > [fused]    EP {rank_idx:2}/{num_ranks} | "
        f"{tflops:4.0f} TFLOPS | "
        f"overlap: {tflops * approx_factor:4.0f} TFLOPS, "
        f"HBM {hbm_gbs * approx_factor:4.0f} GB/s, "
        f"NVL {nvlink_gbs * approx_factor:3.0f} GB/s | "
        f"{t_fused * 1e6:6.0f} us, "
        f"reduction: {t_reduction * 1e6:5.1f} us"
    )
    if ep_buffer is not None:
        speedup = safe_div(t_baseline, t_fused)
        dist_print(
            f" > [baseline] EP {rank_idx:2}/{num_ranks} | "
            f"{tflops_baseline:4.0f} TFLOPS | "
            f"               HBM {hbm_gbs_baseline:4.0f} GB/s, "
            f"NVL {nvlink_gbs_baseline:3.0f} GB/s | "
            f"{t_baseline * 1e6:6.0f} us | "
            f"t_baseline/t_fused = {speedup:.2f}x "
            f"({'fused 更快' if speedup > 1 else 'baseline 更快'})"
        )
    else:
        dist_print(" > [baseline] (no baseline: deep_ep unavailable)", once_in_node=True)

    # ---- 清理 ----
    dist.barrier()
    sym_buffer.destroy()
    if ep_buffer is not None:
        ep_buffer.destroy()
    dist.destroy_process_group()


# ============================================================================
# 模块 6:argparse + spawn
# ============================================================================

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="H200 mega-MoE: fused (deep_gemm.fp8_mega_moe) vs DeepEP+grouped-FP8 baseline"
    )

    # 资源
    parser.add_argument(
        "--num-processes", type=int, default=8, help="spawn 出来的进程数(一卡一进程)"
    )

    # 模型形状
    # 注:SM90 fused kernel 要求 intermediate_hidden ≤ 4096
    parser.add_argument("--num-max-tokens-per-rank", type=int, default=8192)
    parser.add_argument(
        "--num-tokens",
        type=int,
        default=0,
        help="per-rank 实际 token 数;0 表示用 num-max-tokens-per-rank",
    )
    parser.add_argument("--hidden", type=int, default=7168)
    parser.add_argument(
        "--intermediate-hidden",
        type=int,
        default=3072,
        help="中间层维度(≤ 4096,受 SM90 l2_arrival_mask 约束)",
    )
    parser.add_argument(
        "--activation-clamp",
        type=float,
        default=10.0,
        help="SwiGLU 前对 gate/up 的 clamp 阈值;传 inf 表示关闭",
    )
    parser.add_argument("--num-experts", type=int, default=384)
    parser.add_argument("--num-topk", type=int, default=6)
    parser.add_argument(
        "--fast-math",
        type=int,
        default=1,
        help="fused 内 SwiGLU 是否启用 fast-math(0/1)",
    )

    # 测时
    parser.add_argument(
        "--num-bench-tests",
        type=int,
        default=30,
        help="bench_kineto 抓 fused 时的迭代数",
    )
    parser.add_argument(
        "--num-warmup", type=int, default=5, help="baseline cuda events warmup"
    )
    parser.add_argument(
        "--num-repeat", type=int, default=20, help="baseline cuda events 测时迭代"
    )
    parser.add_argument(
        "--l2-flush-gb",
        type=float,
        default=8.0,
        help="baseline event 测时前用于 flush L2 的临时写入大小;0 表示关闭",
    )
    parser.add_argument(
        "--check-output-diff",
        type=int,
        default=0,
        help="非 0 时打印 fused 与 legacy-per128 baseline 的输出差异(预期非 bitwise)",
    )
    parser.add_argument(
        "--dump-profile-traces",
        type=str,
        default="",
        help="非空时把 fused 的 Chrome trace 写到该目录(每 rank 一份)",
    )

    args = parser.parse_args()

    if args.dump_profile_traces:
        os.makedirs(args.dump_profile_traces, exist_ok=True)

    # 多进程启动:每个进程对应一个 GPU;test() 内部用 init_dist 建 NCCL group
    torch.multiprocessing.spawn(
        test, args=(args.num_processes, args), nprocs=args.num_processes
    )

deepseek-ai/DeepEP#629 没有RDMA 的8卡需要依赖这个PR 才能跑ElasticBuffer接口,另外EP4 是不是太小了?整体应该还是bound 在HBM 读取上面了,看不到megamoe 的收益。

@Stone749990226

Copy link
Copy Markdown

你跑过B300的ncu报告吗?我跑出来的B300的报告SM和Memory利用率非常低,不知道是不是跑错了,感觉有点奇怪呢。#336

@Stone749990226

Copy link
Copy Markdown

你跑过B300的ncu报告吗?我跑出来的B300的报告SM和Memory利用率非常低,不知道是不是跑错了,感觉有点奇怪呢。#336

就是官方的Mega MoE的kernel的B300的报告,我在8卡上跑的

@qiushixiaoyu

Copy link
Copy Markdown
Author

@Stone749990226 我只有H20环境

@qinqinwo

Copy link
Copy Markdown

export PYTHONPATH=/workspace/DeepGEMM:/workspace/DeepEP:${PYTHONPATH:-} export LD_LIBRARY_PATH=/usr/local/lib/python3.12/dist-packages/tvm_ffi/lib:${LD_LIBRARY_PATH:-} python3 tests/test_mega_moe_hopper.py --num-processes 8 --num-max-tokens-per-rank --num-tokens --hidden 4096 --intermediate-hidden 2048 --num-experts 256 --num-topk 6 --num-bench-tests 5 --num-warmup 2 --num-repeat 5 --l2-flush-gb 0 --run-baseline

Batch Fused avg us Baseline avg us Baseline / Fused Fused TFLOPS Baseline TFLOPS Fused HBM GB/s Baseline HBM GB/s Status
1 183.4 327.6 1.787 1.6 1.0 755.1 422.8 ok
2 263.0 380.4 1.446 2.1 1.5 1005.5 695.6 ok
4 406.1 497.4 1.225 3.0 2.4 1070.5 873.6 ok
8 497.1 546.1 1.099 4.8 4.5 1293.1 1177.2 ok
16 566.0 641.2 1.133 8.4 7.4 1376.8 1214.6 ok
32 576.0 651.0 1.130 16.8 14.8 1404.6 1242.4 ok
64 592.5 653.2 1.103 32.8 29.6 1371.9 1242.5 ok
128 597.9 680.1 1.138 64.9 56.9 1370.9 1202.9 ok
512 1144.0 1220.9 1.067 135.9 126.6 752.1 702.0 ok
1024 1989.5 2189.1 1.100 156.0 141.1 458.8 415.0 ok
4096 6949.8 6913.9 0.995 179.0 179.0 176.0 176.0 ok
8192 13514.9 13343.6 0.987 184.2 185.4 121.2 122.2 ok

请问,这是用最新代码跑的吗?baseline是和sm100上一样,deepep v2+deepgemm完全没任何overlap的吗?我之前跑出来的结果如下:

Model Tokens/rank Backend Time us TFLOPS HBM GB/s NVL GB/s Speedup PR316 time us Time ratio Status
flash 1 pr323-sm90-fp8 360 1 490 0   56.5 6.372 ok
flash 512 pr323-sm90-fp8 2027 76 423 19   146.5 13.836 ok
flash 8192 pr323-sm90-fp8 22957 107 70 26   1283.1 17.892 ok
flash 32768 pr323-sm90-fp8 88802 111 46 27   4855.5 18.289 ok
pro 1 pr323-sm90-fp8 602 1 659 0   108.1 5.569 ok
pro 512 pr323-sm90-fp8 4203 97 776 16   369.6 11.372 ok
pro 8192 pr323-sm90-fp8 58415 111 78 18   2818.5 20.726 ok
pro 32768 pr323-sm90-fp8 220394 118 39 19   10655.2 20.684 ok

@usernamehaha2022

Copy link
Copy Markdown

在我们的测试里,nvcc编译出来这个版本的寄存器spill很多,对比deepepv2+deepgemm non-overlap:
image
会比baseline差。
我们尝试手动解决了一下,并对计算warpgroup做了修改,拿4卡H800测试的效果:
--num-processes 4 --num-experts 128 --hidden 3072 on 4*H800
image
看上去性能合理

@usernamehaha2022

Copy link
Copy Markdown

在我们的测试里,nvcc编译出来这个版本的寄存器spill很多,对比deepepv2+deepgemm non-overlap: image 会比baseline差。 我们尝试手动解决了一下,并对计算warpgroup做了修改,拿4卡H800测试的效果: --num-processes 4 --num-experts 128 --hidden 3072 on 4*H800 image 看上去性能合理

test_mega_moe_sm90.py
对应的测试在这个文件。如果有人对我们的优化感兴趣可以一起讨论🤔

@leiyin22

Copy link
Copy Markdown

在我们的测试里,nvcc编译出来这个版本的寄存器spill很多,对比deepepv2+deepgemm non-overlap: image 会比baseline差。 我们尝试手动解决了一下,并对计算warpgroup做了修改,拿4卡H800测试的效果: --num-processes 4 --num-experts 128 --hidden 3072 on 4*H800 image 看上去性能合理

在我们的测试里,nvcc编译出来这个版本的寄存器spill很多,对比deepepv2+deepgemm non-overlap: image 会比baseline差。 我们尝试手动解决了一下,并对计算warpgroup做了修改,拿4卡H800测试的效果: --num-processes 4 --num-experts 128 --hidden 3072 on 4*H800 image 看上去性能合理

test_mega_moe_sm90.py 对应的测试在这个文件。如果有人对我们的优化感兴趣可以一起讨论🤔

源码有吗?

@foobar2023xx

Copy link
Copy Markdown

在我们的测试里,nvcc编译出来这个版本的寄存器spill很多,对比deepepv2+deepgemm non-overlap: image 会比baseline差。 我们尝试手动解决了一下,并对计算warpgroup做了修改,拿4卡H800测试的效果: --num-processes 4 --num-experts 128 --hidden 3072 on 4*H800 image 看上去性能合理

我在H800上测试的时候,没有观察到明显的寄存器spill,请问你们是基于当前最新版本测试的吗?

Running NVCC command: cd /tmp/dg_sm90_spill_v1/tmp && /usr/local/cuda/bin/nvcc /tmp/dg_sm90_spill_v1/tmp/13989-a1ffc508-d2e5b710-f4640373/kernel.cu -cubin -o /tmp/dg_sm90_spill_v1/tmp/13989-a1ffc508-d2e5b710-f4640373/kernel.cubin -std=c++20 --diag-suppress=39,161,174,177,186,940 --ptxas-options=--register-usage-level=10 --ptxas-options=--verbose,--warn-on-local-memory-usage -I/workspace/DeepGEMM/deep_gemm/include --gpu-architecture=sm_90a --compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi -O3 --expt-relaxed-constexpr --expt-extended-lambda
ptxas info    : (C7510) Potential Performance Loss: wgmma.mma_async instructions are serialized due to wgmma pipeline crossing function boundary at a function call in the function '_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj0ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj1ELj48ELj40ELj208ELj0ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT16_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_'
ptxas warning : Local memory used for function '_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj0ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj1ELj48ELj40ELj208ELj0ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT16_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_', size of stack frame: 56 bytes
ptxas info    : 474 bytes gmem
ptxas info    : Compiling entry function '_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj0ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj1ELj48ELj40ELj208ELj0ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT16_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_' for 'sm_90a'
ptxas info    : Function properties for _ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj0ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj1ELj48ELj40ELj208ELj0ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT16_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_
    56 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 168 registers, used 16 barriers, 56 bytes cumulative stack size
ptxas info    : Compile time = 582.487 ms

@usernamehaha2022

Copy link
Copy Markdown

在我们的测试里,nvcc编译出来这个版本的寄存器spill很多,对比deepepv2+deepgemm non-overlap: image 会比baseline差。 我们尝试手动解决了一下,并对计算warpgroup做了修改,拿4卡H800测试的效果: --num-processes 4 --num-experts 128 --hidden 3072 on 4*H800 image 看上去性能合理

我在H800上测试的时候,没有观察到明显的寄存器spill,请问你们是基于当前最新版本测试的吗?

Running NVCC command: cd /tmp/dg_sm90_spill_v1/tmp && /usr/local/cuda/bin/nvcc /tmp/dg_sm90_spill_v1/tmp/13989-a1ffc508-d2e5b710-f4640373/kernel.cu -cubin -o /tmp/dg_sm90_spill_v1/tmp/13989-a1ffc508-d2e5b710-f4640373/kernel.cubin -std=c++20 --diag-suppress=39,161,174,177,186,940 --ptxas-options=--register-usage-level=10 --ptxas-options=--verbose,--warn-on-local-memory-usage -I/workspace/DeepGEMM/deep_gemm/include --gpu-architecture=sm_90a --compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi -O3 --expt-relaxed-constexpr --expt-extended-lambda
ptxas info    : (C7510) Potential Performance Loss: wgmma.mma_async instructions are serialized due to wgmma pipeline crossing function boundary at a function call in the function '_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj0ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj1ELj48ELj40ELj208ELj0ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT16_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_'
ptxas warning : Local memory used for function '_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj0ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj1ELj48ELj40ELj208ELj0ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT16_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_', size of stack frame: 56 bytes
ptxas info    : 474 bytes gmem
ptxas info    : Compiling entry function '_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj0ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj1ELj48ELj40ELj208ELj0ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT16_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_' for 'sm_90a'
ptxas info    : Function properties for _ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj0ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj1ELj48ELj40ELj208ELj0ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT16_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_
    56 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 168 registers, used 16 barriers, 56 bytes cumulative stack size
ptxas info    : Compile time = 582.487 ms

你的测试脚本和参数是什么

@foobar2023xx

Copy link
Copy Markdown

在我们的测试里,nvcc编译出来这个版本的寄存器spill很多,对比deepepv2+deepgemm non-overlap: image 会比baseline差。 我们尝试手动解决了一下,并对计算warpgroup做了修改,拿4卡H800测试的效果: --num-processes 4 --num-experts 128 --hidden 3072 on 4*H800 image 看上去性能合理

我在H800上测试的时候,没有观察到明显的寄存器spill,请问你们是基于当前最新版本测试的吗?

Running NVCC command: cd /tmp/dg_sm90_spill_v1/tmp && /usr/local/cuda/bin/nvcc /tmp/dg_sm90_spill_v1/tmp/13989-a1ffc508-d2e5b710-f4640373/kernel.cu -cubin -o /tmp/dg_sm90_spill_v1/tmp/13989-a1ffc508-d2e5b710-f4640373/kernel.cubin -std=c++20 --diag-suppress=39,161,174,177,186,940 --ptxas-options=--register-usage-level=10 --ptxas-options=--verbose,--warn-on-local-memory-usage -I/workspace/DeepGEMM/deep_gemm/include --gpu-architecture=sm_90a --compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi -O3 --expt-relaxed-constexpr --expt-extended-lambda
ptxas info    : (C7510) Potential Performance Loss: wgmma.mma_async instructions are serialized due to wgmma pipeline crossing function boundary at a function call in the function '_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj0ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj1ELj48ELj40ELj208ELj0ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT16_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_'
ptxas warning : Local memory used for function '_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj0ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj1ELj48ELj40ELj208ELj0ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT16_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_', size of stack frame: 56 bytes
ptxas info    : 474 bytes gmem
ptxas info    : Compiling entry function '_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj0ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj1ELj48ELj40ELj208ELj0ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT16_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_' for 'sm_90a'
ptxas info    : Function properties for _ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj0ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj1ELj48ELj40ELj208ELj0ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT16_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_
    56 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 168 registers, used 16 barriers, 56 bytes cumulative stack size
ptxas info    : Compile time = 582.487 ms

你的测试脚本和参数是什么

测试脚本是
bench_mega_moe_sm90.py

DG_JIT_PTXAS_VERBOSE=1 \
DG_JIT_PRINT_COMPILER_COMMAND=1 \
python tests/bench_mega_moe_sm90.py \
  --num-processes 4 \
  --num-max-tokens-per-rank 8192 \
  --num-tokens 8192 \
  --hidden 4096 \
  --intermediate-hidden 2048 \
  --num-experts 256 \
  --num-topk 6

@usernamehaha2022

usernamehaha2022 commented May 25, 2026 via email

Copy link
Copy Markdown

@engineer1109

Copy link
Copy Markdown

H200 跑下来看起来 收益是负面的。。。

@Stone749990226

Copy link
Copy Markdown

H200 跑下来看起来 收益是负面的。。。

我也是这个结果,只有Deepep+GEMM的2/3

@usernamehaha2022

Copy link
Copy Markdown

https://github.com/usernamehaha2022/DeepGEMM/tree/add_hopper_mega DG_JIT_PTXAS_VERBOSE=1 DG_JIT_PRINT_COMPILER_COMMAND=1 python tests/bench_mega_moe_sm90.py --num-processes 4 --num-max-tokens-per-rank 8192 --num-tokens 8192 --hidden 7168 --intermediate-hidden 4096 --num-experts 256 --num-topk 6 基于lz的提交做了一些修改,大家可以测一下
image

@engineer1109

Copy link
Copy Markdown

https://github.com/usernamehaha2022/DeepGEMM/tree/add_hopper_mega DG_JIT_PTXAS_VERBOSE=1 DG_JIT_PRINT_COMPILER_COMMAND=1 python tests/bench_mega_moe_sm90.py --num-processes 4 --num-max-tokens-per-rank 8192 --num-tokens 8192 --hidden 7168 --intermediate-hidden 4096 --num-experts 256 --num-topk 6 基于lz的提交做了一些修改,大家可以测一下 image

This looks better than PR.

@engineer1109

Copy link
Copy Markdown

https://github.com/usernamehaha2022/DeepGEMM/tree/add_hopper_mega DG_JIT_PTXAS_VERBOSE=1 DG_JIT_PRINT_COMPILER_COMMAND=1 python tests/bench_mega_moe_sm90.py --num-processes 4 --num-max-tokens-per-rank 8192 --num-tokens 8192 --hidden 7168 --intermediate-hidden 4096 --num-experts 256 --num-topk 6 基于lz的提交做了一些修改,大家可以测一下 image

但是MFU 也只有0.37的水准,比B卡的还低很多,还有哪些瓶颈?

@usernamehaha2022

Copy link
Copy Markdown

https://github.com/usernamehaha2022/DeepGEMM/tree/add_hopper_mega DG_JIT_PTXAS_VERBOSE=1 DG_JIT_PRINT_COMPILER_COMMAND=1 python tests/bench_mega_moe_sm90.py --num-processes 4 --num-max-tokens-per-rank 8192 --num-tokens 8192 --hidden 7168 --intermediate-hidden 4096 --num-experts 256 --num-topk 6 基于lz的提交做了一些修改,大家可以测一下 image

但是MFU 也只有0.37的水准,比B卡的还低很多,还有哪些瓶颈?

最大的瓶颈就是H卡没有TMEM。所以做一做低延迟还是可以,我理解比sbo会好一点

@engineer1109

Copy link
Copy Markdown

https://github.com/usernamehaha2022/DeepGEMM/tree/add_hopper_mega DG_JIT_PTXAS_VERBOSE=1 DG_JIT_PRINT_COMPILER_COMMAND=1 python tests/bench_mega_moe_sm90.py --num-processes 4 --num-max-tokens-per-rank 8192 --num-tokens 8192 --hidden 7168 --intermediate-hidden 4096 --num-experts 256 --num-topk 6 基于lz的提交做了一些修改,大家可以测一下 image

但是MFU 也只有0.37的水准,比B卡的还低很多,还有哪些瓶颈?

最大的瓶颈就是H卡没有TMEM。所以做一做低延迟还是可以,我理解比sbo会好一点

H卡缺寄存器是没办法的事,并且sfa和sfb也没有硬件加速,是CUDA core算的。如果 我砍掉sfa和sfb,直接fp8/fp8计算(总比mxfp8/mxfp4精度高)。是不是还能再腾出一点MFU。

@usernamehaha2022

Copy link
Copy Markdown

https://github.com/usernamehaha2022/DeepGEMM/tree/add_hopper_mega DG_JIT_PTXAS_VERBOSE=1 DG_JIT_PRINT_COMPILER_COMMAND=1 python tests/bench_mega_moe_sm90.py --num-processes 4 --num-max-tokens-per-rank 8192 --num-tokens 8192 --hidden 7168 --intermediate-hidden 4096 --num-experts 256 --num-topk 6 基于lz的提交做了一些修改,大家可以测一下 image

但是MFU 也只有0.37的水准,比B卡的还低很多,还有哪些瓶颈?

最大的瓶颈就是H卡没有TMEM。所以做一做低延迟还是可以,我理解比sbo会好一点

H卡缺寄存器是没办法的事,并且sfa和sfb也没有硬件加速,是CUDA core算的。如果 我砍掉sfa和sfb,直接fp8/fp8计算(总比mxfp8/mxfp4精度高)。是不是还能再腾出一点MFU。

我觉得可以试试。另外
image
就这个图来说目前的stall还挺多的 我现在还没有细看是哪里导致流水线停滞

@engineer1109

Copy link
Copy Markdown

https://github.com/usernamehaha2022/DeepGEMM/tree/add_hopper_mega DG_JIT_PTXAS_VERBOSE=1 DG_JIT_PRINT_COMPILER_COMMAND=1 python tests/bench_mega_moe_sm90.py --num-processes 4 --num-max-tokens-per-rank 8192 --num-tokens 8192 --hidden 7168 --intermediate-hidden 4096 --num-experts 256 --num-topk 6 基于lz的提交做了一些修改,大家可以测一下 image

但是MFU 也只有0.37的水准,比B卡的还低很多,还有哪些瓶颈?

最大的瓶颈就是H卡没有TMEM。所以做一做低延迟还是可以,我理解比sbo会好一点

H卡缺寄存器是没办法的事,并且sfa和sfb也没有硬件加速,是CUDA core算的。如果 我砍掉sfa和sfb,直接fp8/fp8计算(总比mxfp8/mxfp4精度高)。是不是还能再腾出一点MFU。

我觉得可以试试。另外 image 就这个图来说目前的stall还挺多的 我现在还没有细看是哪里导致流水线停滞

DSV4参数 immediate_size = 3072 hidden_size = 7168 experts=384 tokens=448~640之间有一个性能低谷,tokens=648 相比 tokens=640 MFU直接暴涨 50%,感觉schedule有问题。

@usernamehaha2022

Copy link
Copy Markdown

https://github.com/usernamehaha2022/DeepGEMM/tree/add_hopper_mega DG_JIT_PTXAS_VERBOSE=1 DG_JIT_PRINT_COMPILER_COMMAND=1 python tests/bench_mega_moe_sm90.py --num-processes 4 --num-max-tokens-per-rank 8192 --num-tokens 8192 --hidden 7168 --intermediate-hidden 4096 --num-experts 256 --num-topk 6 基于lz的提交做了一些修改,大家可以测一下 image

但是MFU 也只有0.37的水准,比B卡的还低很多,还有哪些瓶颈?

最大的瓶颈就是H卡没有TMEM。所以做一做低延迟还是可以,我理解比sbo会好一点

H卡缺寄存器是没办法的事,并且sfa和sfb也没有硬件加速,是CUDA core算的。如果 我砍掉sfa和sfb,直接fp8/fp8计算(总比mxfp8/mxfp4精度高)。是不是还能再腾出一点MFU。

我觉得可以试试。另外 image 就这个图来说目前的stall还挺多的 我现在还没有细看是哪里导致流水线停滞

DSV4参数 immediate_size = 3072 hidden_size = 7168 experts=384 tokens=448~640之间有一个性能低谷,tokens=648 相比 tokens=640 MFU直接暴涨 50%,感觉schedule有问题。

有可能,现在token少的时候确实性能一般。这个你有时间看吗?我目前还在看token多的时候tma不是瓶颈,L2 wait更多的问题。我的commit有一个简单的profiler: export DG_PHASE_PROFILING

@engineer1109

Copy link
Copy Markdown

https://github.com/usernamehaha2022/DeepGEMM/tree/add_hopper_mega DG_JIT_PTXAS_VERBOSE=1 DG_JIT_PRINT_COMPILER_COMMAND=1 python tests/bench_mega_moe_sm90.py --num-processes 4 --num-max-tokens-per-rank 8192 --num-tokens 8192 --hidden 7168 --intermediate-hidden 4096 --num-experts 256 --num-topk 6 基于lz的提交做了一些修改,大家可以测一下 image

但是MFU 也只有0.37的水准,比B卡的还低很多,还有哪些瓶颈?

最大的瓶颈就是H卡没有TMEM。所以做一做低延迟还是可以,我理解比sbo会好一点

H卡缺寄存器是没办法的事,并且sfa和sfb也没有硬件加速,是CUDA core算的。如果 我砍掉sfa和sfb,直接fp8/fp8计算(总比mxfp8/mxfp4精度高)。是不是还能再腾出一点MFU。

我觉得可以试试。另外 image 就这个图来说目前的stall还挺多的 我现在还没有细看是哪里导致流水线停滞

DSV4参数 immediate_size = 3072 hidden_size = 7168 experts=384 tokens=448~640之间有一个性能低谷,tokens=648 相比 tokens=640 MFU直接暴涨 50%,感觉schedule有问题。

有可能,现在token少的时候确实性能一般。这个你有时间看吗?我目前还在看token多的时候tma不是瓶颈,L2 wait更多的问题。我的commit有一个简单的profiler: export DG_PHASE_PROFILING

tokens_per_expert 分界线改成56 可以解决 PRO参数下 性能低谷问题。 感觉还是block_m 64还是128的调度问题。

@Stone749990226

Copy link
Copy Markdown

cluster_size改为2利用TMA的Cluster多播特性一次搬运是否能加快速度呢?收益会很多吗?我看目前Cluster为1

@qiushixiaoyu

Copy link
Copy Markdown
Author

cluster_size改为2利用TMA的Cluster多播特性一次搬运是否能加快速度呢?收益会很多吗?我看目前Cluster为1

试过,没什么收益

Squash local SM90 MegaMoE decode work on top of deepseek-ai/main. Includes SM90 FP8 MegaMoE kernels, decode split-N tuning, low-latency benchmark support, and bn256 L2 counter enablement.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants