feat: MegaMOE adaptation for SM90#323
Conversation
|
Do you have benchmark data? |
I’m testing the benefits of DeepSeek V4 Flash on H20, and I’ll share the data soon. |
|
看起来效果并不理想:
Performance:
"""
H200 (SM90 / Hopper) mega-MoE: fused kernel + 同管线 baseline 性能对比。
结构对齐 tests/test_mega_moe.py(B 系列 SM100 FP4 路径),但所有路径都换成 H200 FP8:
* fused:调用 `deep_gemm.fp8_mega_moe`(kernel symbol `sm90_fp8_mega_moe_impl`),
使用 `transform_weights_for_mega_moe_sm90` 处理过的权重 + SymmBuffer。
* baseline:DeepEP dispatch + 2 个 grouped FP8 GEMM + Triton SwiGLU + DeepEP combine,
使用未变换的权重。由于当前 SM90 grouped GEMM 只支持 L2 activation
per-128-K SFA,而 fused SM90 mega-MoE 的 L1 epilogue 为避免跨 CTA
同步使用 per-64-K SFA,所以该 baseline 是同管线 legacy 参照,
不是 bitwise apples-to-apples correctness oracle。
* 性能输出涵盖:TFLOPS / overlap TFLOPS / HBM GB/s / NVL GB/s / fused us /
reduction us / `t_baseline / t_fused` legacy 比。
"""
import deep_ep
import argparse
import math
import os
import random
import torch
import torch.distributed as dist
import triton
import triton.language as tl
from typing import Tuple
import deep_gemm
from deep_gemm.utils import per_token_cast_to_fp8
from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather
from deep_gemm.testing import bench_kineto
# 与 deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh 中模板入口同名,
# bench_kineto 用它从 trace 里挑出 fused mega-MoE 的 GPU 段
SM90_KERNEL_NAME = "sm90_fp8_mega_moe_impl"
# FP8 e4m3fn 的最大可表示值,量化时用 amax / 448 作为 scale 基准
FP8_E4M3_MAX = 448.0
# 新版 Triton(>= 3.x)强制:jit 内核读到的 Python 全局必须是 tl.constexpr 实例,
# 否则编译期 NameError。宿主 Python 侧仍用上面的普通 float 做 torch 运算。
_FP8_E4M3_MAX_TL = tl.constexpr(448.0)
L1_ACT_SF_GRAN = 128
FUSED_L2_ACT_SF_GRAN = 64
BASELINE_L2_ACT_SF_GRAN = 128
WEIGHT_SF_GRAN_MN = 128
WEIGHT_SF_GRAN_K = 128
# ============================================================================
# 模块 1:Triton SwiGLU + FP8 量化内核
# ----------------------------------------------------------------------------
# baseline 的 L2 仍走 DeepGEMM SM90 grouped FP8 GEMM,所以 activation SFA 只能按
# per-128-K 输入;但 scale 数值采用 fused epilogue 同款 UE8M0/power-of-two 规则,
# 避免再额外引入 exact-FP32-scale 差异。
# 输入 x : (M, 2*H) bf16,内层是 [gate_part | up_part]
# 输入 topk_w : (M,) fp32,可选
# 输出 y : (M, H) fp8_e4m3fn
# 输出 y_sf : (M, H/BLOCK_K) fp32 行主序
# ============================================================================
@triton.jit
def _swiglu_apply_weight_to_fp8_kernel(
x_ptr,
topk_w_ptr,
y_ptr,
y_sf_ptr,
M,
H, # 运行时形状
stride_xm,
stride_xn, # x: (M, 2H) 的 stride
stride_ym,
stride_yn, # y: (M, H) 的 stride
stride_sfm,
stride_sfk, # y_sf: (M, H/BLOCK_K) 的 stride
clamp_value, # 当 HAS_CLAMP=False 时这个参数无意义
HAS_TOPK: tl.constexpr,
HAS_CLAMP: tl.constexpr,
USE_UE8M0_SCALE: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_K: tl.constexpr, # = num_per_channels
):
# 一个 program 处理 (BLOCK_M 个 token) × (第 pid_k 个 K-block 的 BLOCK_K 列)
pid_m = tl.program_id(0)
pid_k = tl.program_id(1)
# 行索引:本 program 负责 [pid_m*BLOCK_M, pid_m*BLOCK_M+BLOCK_M)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
# 当前 K-block 内的列索引(在 H 维度,不是 2H)
offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
mask_m = offs_m < M
# ---- 1) 载入 gate(x 的前半段 [0, H))和 up(x 的后半段 [H, 2H))----
# 注意 stride_xn 是元素 stride(一般 == 1),但 H + offs_k 偏移是按"元素"算的
gate_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xn
up_ptrs = x_ptr + offs_m[:, None] * stride_xm + (H + offs_k[None, :]) * stride_xn
gate = tl.load(gate_ptrs, mask=mask_m[:, None], other=0.0).to(tl.float32)
up = tl.load(up_ptrs, mask=mask_m[:, None], other=0.0).to(tl.float32)
# ---- 2) 可选 clamp(参考 tilelang 实现:gate 单边 max,up 双边)----
if HAS_CLAMP:
gate = tl.minimum(gate, clamp_value)
up = tl.minimum(tl.maximum(up, -clamp_value), clamp_value)
# ---- 3) SwiGLU:silu(gate) * up = gate * sigmoid(gate) * up(全程 FP32 累计)----
y = gate * tl.sigmoid(gate) * up
# ---- 4) 可选 MoE 权重缩放(per-token 标量)----
if HAS_TOPK:
w = tl.load(topk_w_ptr + offs_m, mask=mask_m, other=1.0)
y = y * w[:, None]
# ---- 5) 当前 K-block 内每行 absmax → scale ----
amax = tl.max(tl.abs(y), axis=1) # (BLOCK_M,)
sf = tl.maximum(amax / _FP8_E4M3_MAX_TL, 1.0e-30)
if USE_UE8M0_SCALE:
# 对齐 deep_gemm/common/math.cuh::get_e4m3_sf_and_sf_inv:
# scale = 2 ** ceil(log2(amax / 448)).
sf = tl.exp2(tl.ceil(tl.log2(sf)))
# ---- 6) 量化为 FP8 e4m3fn ----
y_fp8 = (y / sf[:, None]).to(tl.float8e4nv)
# ---- 7) 写回 y 和 sf ----
y_ptrs = y_ptr + offs_m[:, None] * stride_ym + offs_k[None, :] * stride_yn
tl.store(y_ptrs, y_fp8, mask=mask_m[:, None])
sf_ptrs = y_sf_ptr + offs_m * stride_sfm + pid_k * stride_sfk
tl.store(sf_ptrs, sf, mask=mask_m)
def swiglu_apply_weight_to_fp8_triton(
x: torch.Tensor,
topk_weights: torch.Tensor | None,
clamp_value: float | None = None,
num_per_channels: int = BASELINE_L2_ACT_SF_GRAN,
use_ue8m0_scale: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""SwiGLU + FP8 量化。语义等价于 PyTorch reference:
gate, up = x[:, :H], x[:, H:]
y = silu(gate.clamp(max=c)) * up.clamp(-c, c) * topk_w
y_sf = y.view(M, H/np, np).abs().amax(-1) / 448
if use_ue8m0_scale: y_sf = ceil_to_power_of_2(y_sf)
y_fp8 = (y / y_sf.unsqueeze(-1)).to(fp8)
"""
assert x.is_cuda and x.dtype == torch.bfloat16
assert x.is_contiguous(), "当前实现假设 x 是 contiguous 的,避免 stride 计算错位"
M, two_H = x.shape
H = two_H // 2
assert H % num_per_channels == 0, f"H={H} 必须是 {num_per_channels} 的整数倍"
y = torch.empty((M, H), dtype=torch.float8_e4m3fn, device=x.device)
y_sf = torch.empty((M, H // num_per_channels), dtype=torch.float32, device=x.device)
# BLOCK_M 取 16:内核每个 program 处理 16 个 token × 128 列,寄存器压力小、容易调
BLOCK_M = 16
grid = (triton.cdiv(M, BLOCK_M), H // num_per_channels)
# HAS_TOPK=False 时仍要传一个有效指针(Triton 不允许 nullptr),用 x 占位
topk_ptr = topk_weights if topk_weights is not None else x
_swiglu_apply_weight_to_fp8_kernel[grid](
x,
topk_ptr,
y,
y_sf,
M,
H,
x.stride(0),
x.stride(1),
y.stride(0),
y.stride(1),
y_sf.stride(0),
y_sf.stride(1),
float(clamp_value) if clamp_value is not None else 0.0,
HAS_TOPK=topk_weights is not None,
HAS_CLAMP=clamp_value is not None,
USE_UE8M0_SCALE=use_ue8m0_scale,
BLOCK_M=BLOCK_M,
BLOCK_K=num_per_channels,
)
return y, y_sf
# ============================================================================
# 模块 2:grouped weight 的 (128, 128) FP8 块量化
# ----------------------------------------------------------------------------
# m_grouped_fp8_gemm_nt_contiguous 在 SM90 上对 weight 的输入约定:
# 每 (128, 128) 子块共享一个 FP32 SF,K 是 SF 的内层连续维(K-major)。
# 与 SM100 FP4 路径的差异:
# * 不需要 deep_gemm.transform_sf_into_required_layout
# * SF 是 FP32,不是 UE8M0 packed
# ============================================================================
def _quantize_grouped_fp8_block_128_128(
w: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""(G, N, K) bf16 → (G, N, K) fp8_e4m3fn + (G, N//128, K//128) fp32 SF。"""
g, n, k = w.shape
assert n % 128 == 0 and k % 128 == 0, f"weight 的 N={n}, K={k} 都必须是 128 的倍数"
# 把 (N, K) 切成 (N/128, 128, K/128, 128),最后一维和倒数第三维就是 128×128 子块内部
w_view = w.view(g, n // 128, 128, k // 128, 128).float()
# 子块内 absmax → scale = amax / 448,clamp(1e-4) 避免全 0 子块
amax = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4) # (G, N/128, K/128)
sf = amax / FP8_E4M3_MAX
# 量化:每个元素除以所属子块的 sf 后转 FP8
# sf 形状 (G, N/128, K/128),需在 N-内 (axis -3) 和 K-内 (axis -1) 都补维度
w_fp8 = (w_view / sf.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn)
return w_fp8.view(g, n, k).contiguous(), sf.contiguous()
# ============================================================================
# 模块 3:尝试导入 deep_ep(用于 dispatch / combine)
# ============================================================================
def _import_deep_ep():
try:
import deep_ep
return deep_ep
except Exception as ex:
dist_print(f"Failed to import deep_ep: {ex}", once_in_node=True)
return None
# ============================================================================
# 模块 4:CUDA event 中位数测时(避开对 tilelang.do_bench 的依赖)
# ============================================================================
def _bench_cuda_events(
fn, num_warmup: int = 5, num_repeat: int = 20, l2_flush_gb: float = 8.0
) -> float:
"""返回 fn 的中位数耗时(秒)。"""
for _ in range(num_warmup):
fn()
torch.cuda.synchronize()
times_ms = []
for _ in range(num_repeat):
# L2 flush,避免重复访问命中 cache 让测时偏低
if l2_flush_gb > 0:
free_bytes, _ = torch.cuda.mem_get_info()
flush_bytes = min(int(l2_flush_gb * 1e9), int(free_bytes * 0.5))
if flush_bytes >= 4:
torch.empty(flush_bytes // 4, dtype=torch.int, device="cuda").zero_()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
fn()
e.record()
e.synchronize()
times_ms.append(s.elapsed_time(e))
times_ms.sort()
return times_ms[len(times_ms) // 2] / 1e3
# ============================================================================
# 模块 5:test() 主入口 — 在每个 rank 上跑一遍 baseline
# ============================================================================
def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
# 初始化分布式:rank_idx 是全局 rank,group 是默认 NCCL group
rank_idx, num_ranks, group = init_dist(local_rank, num_local_ranks)
torch.manual_seed(rank_idx)
random.seed(rank_idx)
# 形状参数(与 test_mega_moe.py 同名同义)
num_max_tokens_per_rank = args.num_max_tokens_per_rank
num_tokens = args.num_tokens if args.num_tokens > 0 else num_max_tokens_per_rank
hidden, intermediate_hidden = args.hidden, args.intermediate_hidden
num_experts, num_topk = args.num_experts, args.num_topk
num_experts_per_rank = num_experts // num_ranks
assert num_tokens <= num_max_tokens_per_rank
assert num_experts % num_ranks == 0, (
f"num_experts={num_experts} 必须能被 num_ranks={num_ranks} 整除"
)
# SM90 fused kernel 的形状约束(来自 csrc/apis/mega.hpp::fp8_mega_moe):
# * H、IH 必须是 128 的倍数(L1 input per-128-K SF + block-(128,128) weight SF)
# * IH/64 ≤ 64 → IH ≤ 4096(l2_arrival_mask 是 uint64,每 bit 对应 64 列)
assert hidden % 128 == 0
assert intermediate_hidden % 128 == 0
assert intermediate_hidden // 64 <= 64, (
f"SM90 fused kernel 要求 intermediate_hidden <= 4096, 当前 {intermediate_hidden}"
)
# ---- 创建 BF16 输入:token 与两层 weight ----
# x: 每 rank 本地 num_tokens 个 token,每个 token hidden 维
x_bf16 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
# L1 weight: 每个 expert 把 hidden → 2*intermediate_hidden(gate 和 up 拼一起)
l1_weights_bf16 = torch.randn(
(num_experts_per_rank, intermediate_hidden * 2, hidden),
dtype=torch.bfloat16,
device="cuda",
)
# L2 weight: 每个 expert 把 intermediate_hidden → hidden
l2_weights_bf16 = torch.randn(
(num_experts_per_rank, hidden, intermediate_hidden),
dtype=torch.bfloat16,
device="cuda",
)
# 路由:scores → topk_idx (M, K) + topk_weights (M, K)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device="cuda")
topk_weights, topk_idx = torch.topk(
scores, num_topk, dim=-1, largest=True, sorted=False
)
# 累计接收统计:fused 与 baseline 各持一份避免相互覆盖
cum_stats_fused = torch.zeros(
(num_experts_per_rank,), dtype=torch.int, device="cuda"
)
cum_stats_baseline = cum_stats_fused.clone()
# ---- BF16 → FP8 量化 ----
# x_fp8 是元组:(token_fp8 (M, hidden), token_sf (M, hidden//128) fp32 行主序)
# 注意 use_ue8m0=False, use_packed_ue8m0=False:SM90 不接受 UE8M0 packed SF
x_fp8 = per_token_cast_to_fp8(
x_bf16, use_ue8m0=False, gran_k=128, use_packed_ue8m0=False
)
# weight 量化:(G, N, K) bf16 → ((G, N, K) fp8 e4m3fn, (G, N//128, K//128) fp32 SF)
# baseline(DeepEP grouped GEMM)直接用这两个未变换的元组
l1_weights = _quantize_grouped_fp8_block_128_128(l1_weights_bf16)
l2_weights = _quantize_grouped_fp8_block_128_128(l2_weights_bf16)
# fused 路径:FP8 weight 上做 gate/up gran-8 N-轴 interleave;SF 不变
transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe_sm90(
l1_weights, l2_weights
)
# SwiGLU clamp:finite → 传给 fused/triton;inf → None(关闭 clamp,与 SM90 fused 一致)
clamp_arg = args.activation_clamp if math.isfinite(args.activation_clamp) else None
# ---- DeepGEMM grouped GEMM 的 M 维 alignment(baseline 走 DeepEP 时也用这个)----
alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout()
deep_gemm.set_mk_alignment_for_contiguous_layout(alignment)
# ---- 分配 fused 的 SymmBuffer 与输出 buffer ----
sym_buffer = deep_gemm.get_symm_buffer_for_mega_moe(
group,
num_experts,
num_max_tokens_per_rank,
num_topk,
hidden,
intermediate_hidden,
)
y_fused = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
def run_fused():
# NOTE: 跟 SM100 test_mega_moe.py 的处理一致 —— DG_COMM_KERNEL_DEBUG=1 时
# kernel 出口会把 sym_buffer 整块清零,所以每次都要重新拷输入
sym_buffer.x[:num_tokens].copy_(x_fp8[0])
sym_buffer.x_sf[:num_tokens].copy_(x_fp8[1])
sym_buffer.topk_idx[:num_tokens].copy_(topk_idx)
sym_buffer.topk_weights[:num_tokens].copy_(topk_weights)
deep_gemm.fp8_mega_moe(
y_fused,
transformed_l1,
transformed_l2,
sym_buffer,
cumulative_local_expert_recv_stats=cum_stats_fused,
recipe=(128, 128, 128),
activation="swiglu",
activation_clamp=clamp_arg,
fast_math=bool(args.fast_math),
)
return y_fused
# ---- 分配 DeepEP buffer(baseline 用)----
deep_ep = _import_deep_ep()
ep_buffer = None
if deep_ep is not None:
ep_buffer = deep_ep.ElasticBuffer(
group,
num_max_tokens_per_rank=num_max_tokens_per_rank,
hidden=hidden,
num_topk=num_topk,
use_fp8_dispatch=True,
explicitly_destroy=True,
allow_multiple_reduction=False,
)
# ----------------------------------------------------------------
# baseline 主体:dispatch → L1 GEMM → SwiGLU+量化 → L2 GEMM → combine
# 与 fused 用同一份 (FP8 weight, FP32 block-(128,128) SF) —— 但是 **未变换**
# 的版本(baseline grouped GEMM 不需要 gate/up interleave)
# ----------------------------------------------------------------
def run_baseline():
recv_x, _, recv_topk_weights, handle, _ = ep_buffer.dispatch(
x_fp8,
topk_idx=topk_idx,
topk_weights=topk_weights,
cumulative_local_expert_recv_stats=cum_stats_baseline,
num_experts=num_experts,
expert_alignment=alignment,
do_cpu_sync=False,
do_handle_copy=False,
do_expand=True,
use_tma_aligned_col_major_sf=False, # SM90: row-major float SF
)
n = recv_x[0].size(0)
# L1 GEMM:FP8 token @ FP8 W1 → BF16 中间激活 (gate||up 拼接)
l1_y = torch.empty(
(n, intermediate_hidden * 2), dtype=torch.bfloat16, device="cuda"
)
deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
recv_x,
l1_weights,
l1_y,
handle.psum_num_recv_tokens_per_expert,
use_psum_layout=True,
disable_ue8m0_cast=True,
)
# Triton SwiGLU + FP8 量化(含 topk 权重乘法)
# 注意:fused SM90 mega-MoE 的 L2 activation SFA 是 per-64-K;
# 当前 DeepGEMM SM90 grouped GEMM 只支持 per-128-K SFA,所以性能 baseline
# 只能用 per-128-K,但 scale 数值采用 fused 同款 UE8M0/power-of-two。
l1_y = swiglu_apply_weight_to_fp8_triton(
x=l1_y,
topk_weights=recv_topk_weights,
clamp_value=clamp_arg,
num_per_channels=BASELINE_L2_ACT_SF_GRAN,
use_ue8m0_scale=True,
)
# L2 GEMM:FP8 中间激活 @ FP8 W2 → BF16
l2_y = torch.empty((n, hidden), dtype=torch.bfloat16, device="cuda")
deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
l1_y,
l2_weights,
l2_y,
handle.psum_num_recv_tokens_per_expert,
use_psum_layout=True,
disable_ue8m0_cast=True,
)
# DeepEP combine:把每个 token 在 topk 个 expert 上的输出汇聚回源 rank
return ep_buffer.combine(l2_y, handle=handle)[0]
# ---- 打印 config ----
dist_print("Config (H200 fused mega-MoE):", once_in_node=True)
dist_print(f" > Tokens: {num_tokens}/{num_max_tokens_per_rank}", once_in_node=True)
dist_print(
f" > Hidden: {hidden}, Intermediate: {intermediate_hidden}", once_in_node=True
)
dist_print(
f" > Experts: {num_topk}/{num_experts} (per-rank: {num_experts_per_rank})",
once_in_node=True,
)
dist_print(
f" > Activation SF: fused L2 per-{FUSED_L2_ACT_SF_GRAN} UE8M0, "
f"baseline L2 per-{BASELINE_L2_ACT_SF_GRAN} UE8M0 "
f"(SM90 grouped GEMM constraint)",
once_in_node=True,
)
dist_print(
f" > Buffer: {sym_buffer.buffer.nbytes / 2**30:.3f} GiB", once_in_node=True
)
dist_print(once_in_node=True)
# ---- 跑一次确保不报错(fused + 可选 baseline)----
y = run_fused()
assert y.shape == (num_tokens, hidden) and y.dtype == torch.bfloat16, (
f"fused 输出 shape/dtype 异常: shape={y.shape}, dtype={y.dtype}"
)
if ep_buffer is not None:
out_b = run_baseline()
assert out_b.shape == (num_tokens, hidden) and out_b.dtype == torch.bfloat16, (
f"baseline 输出 shape/dtype 异常: shape={out_b.shape}, dtype={out_b.dtype}"
)
if args.check_output_diff:
diff = (y.float() - out_b.float()).abs()
denom = out_b.float().abs().mean().clamp_min(1e-12)
dist_print(
"Output diff (fused vs legacy-per128 baseline):", once_in_node=True
)
dist_print(
f" > max_abs={diff.max().item():.6e}, "
f"mean_abs={diff.mean().item():.6e}, "
f"mean_abs/mean_ref={diff.mean().div(denom).item():.6e}",
once_in_node=True,
)
dist_print(once_in_node=True)
# ---- 统计本 rank 实际接收的 token 数与触达的 expert 数 ----
# 把所有 rank 的 topk_idx 收齐,再把不落在本 rank 持有 expert 范围内的条目
# 标成 -1;剩下的非 -1 条目数即"被路由进本 rank 的 (token, slot) 总数"。
gathered_topk_idx = uneven_all_gather(topk_idx, group=group)
gathered_topk_idx[
(gathered_topk_idx < rank_idx * num_experts_per_rank)
| (gathered_topk_idx >= (rank_idx + 1) * num_experts_per_rank)
] = -1
num_recv_tokens = int((gathered_topk_idx != -1).sum().item())
num_touched_experts = max(torch.unique(gathered_topk_idx.flatten()).numel() - 1, 0)
# ---- benchmark ----
# fused:bench_kineto 抓 sm90_fp8_mega_moe_impl 的 GPU 段(不含 host overhead)
t_fused = bench_kineto(
run_fused,
SM90_KERNEL_NAME,
num_tests=args.num_bench_tests,
barrier=lambda: ep_buffer.barrier(use_comm_stream=False)
if ep_buffer is not None
else dist.barrier(),
trace_path=(
f"{args.dump_profile_traces}/mega_moe_hopper_rank{rank_idx}.json"
if args.dump_profile_traces
else None
),
)
# baseline:cuda events 中位数(tilelang.do_bench 在 H200 不一定有,统一用 events)
t_baseline = (
_bench_cuda_events(
run_baseline,
num_warmup=args.num_warmup,
num_repeat=args.num_repeat,
l2_flush_gb=args.l2_flush_gb,
)
if ep_buffer is not None
else 0.0
)
def safe_div(a, b):
return float("nan") if b == 0 else a / b
# 端到端 TFLOPS:3 个 matmul(L1 gate、L1 up、L2),每个 2*M*N*K,M=num_recv_tokens
tflops = safe_div(
2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_fused
)
# HBM 字节估算(SM90: weight 是 FP8 = 1B/elem,与 SM100 FP4=0.5B 不同)
l1_weight_bytes = num_touched_experts * intermediate_hidden * 2 * hidden
l2_weight_bytes = num_touched_experts * hidden * intermediate_hidden
l1_weight_sf_bytes = (
num_touched_experts
* (intermediate_hidden * 2 // WEIGHT_SF_GRAN_MN)
* (hidden // WEIGHT_SF_GRAN_K)
* 4
)
l2_weight_sf_bytes = (
num_touched_experts
* (hidden // WEIGHT_SF_GRAN_MN)
* (intermediate_hidden // WEIGHT_SF_GRAN_K)
* 4
)
l1_input_sf_bytes = num_recv_tokens * (hidden // L1_ACT_SF_GRAN) * 4
l2_act_sf_bytes = (
num_recv_tokens * (intermediate_hidden // FUSED_L2_ACT_SF_GRAN) * 4
)
num_hbm_bytes = (
l1_weight_bytes
+ l2_weight_bytes # weights (FP8)
+ l1_weight_sf_bytes
+ l2_weight_sf_bytes # weight SF (FP32)
+ num_recv_tokens * hidden
+ l1_input_sf_bytes # L1 输入读 (FP8 + SF)
+ num_recv_tokens * intermediate_hidden
+ l2_act_sf_bytes # L1 输出写 (FP8 + SF)
+ num_recv_tokens * intermediate_hidden
+ l2_act_sf_bytes # L2 输入读 (FP8 + SF)
+ num_recv_tokens * hidden * 2 # L2 输出写 (BF16)
)
hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_fused)
# NVLink 字节:dispatch 拉 token + input SF + topk weight,combine 写回 BF16
num_nvlink_bytes = num_recv_tokens * (hidden + hidden // 32 + 4 + hidden * 2)
nvlink_gbs = safe_div(num_nvlink_bytes / 1e9, t_fused)
# combine reduction 串行下界(解析估计;6.5e12 = HBM 串行 reduction 经验吞吐 B/s)
t_reduction = num_tokens * hidden * 2 * (1 + num_topk) / 6.5e12
# overlap 校正:扣掉 fused 中无法重叠的串行 reduction 段后估计稳态吞吐
approx_factor = t_fused / max(t_fused - t_reduction, 1e-12)
# baseline 用同一份 FLOPs / HBM 字节,时间换成 t_baseline
tflops_baseline = safe_div(
2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_baseline
)
hbm_gbs_baseline = safe_div(num_hbm_bytes / 1e9, t_baseline)
nvlink_gbs_baseline = safe_div(num_nvlink_bytes / 1e9, t_baseline)
dist_print("Performance:", once_in_node=True)
dist_print(
f" > [fused] EP {rank_idx:2}/{num_ranks} | "
f"{tflops:4.0f} TFLOPS | "
f"overlap: {tflops * approx_factor:4.0f} TFLOPS, "
f"HBM {hbm_gbs * approx_factor:4.0f} GB/s, "
f"NVL {nvlink_gbs * approx_factor:3.0f} GB/s | "
f"{t_fused * 1e6:6.0f} us, "
f"reduction: {t_reduction * 1e6:5.1f} us"
)
if ep_buffer is not None:
speedup = safe_div(t_baseline, t_fused)
dist_print(
f" > [baseline] EP {rank_idx:2}/{num_ranks} | "
f"{tflops_baseline:4.0f} TFLOPS | "
f" HBM {hbm_gbs_baseline:4.0f} GB/s, "
f"NVL {nvlink_gbs_baseline:3.0f} GB/s | "
f"{t_baseline * 1e6:6.0f} us | "
f"t_baseline/t_fused = {speedup:.2f}x "
f"({'fused 更快' if speedup > 1 else 'baseline 更快'})"
)
else:
dist_print(" > [baseline] (no baseline: deep_ep unavailable)", once_in_node=True)
# ---- 清理 ----
dist.barrier()
sym_buffer.destroy()
if ep_buffer is not None:
ep_buffer.destroy()
dist.destroy_process_group()
# ============================================================================
# 模块 6:argparse + spawn
# ============================================================================
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="H200 mega-MoE: fused (deep_gemm.fp8_mega_moe) vs DeepEP+grouped-FP8 baseline"
)
# 资源
parser.add_argument(
"--num-processes", type=int, default=8, help="spawn 出来的进程数(一卡一进程)"
)
# 模型形状
# 注:SM90 fused kernel 要求 intermediate_hidden ≤ 4096
parser.add_argument("--num-max-tokens-per-rank", type=int, default=8192)
parser.add_argument(
"--num-tokens",
type=int,
default=0,
help="per-rank 实际 token 数;0 表示用 num-max-tokens-per-rank",
)
parser.add_argument("--hidden", type=int, default=7168)
parser.add_argument(
"--intermediate-hidden",
type=int,
default=3072,
help="中间层维度(≤ 4096,受 SM90 l2_arrival_mask 约束)",
)
parser.add_argument(
"--activation-clamp",
type=float,
default=10.0,
help="SwiGLU 前对 gate/up 的 clamp 阈值;传 inf 表示关闭",
)
parser.add_argument("--num-experts", type=int, default=384)
parser.add_argument("--num-topk", type=int, default=6)
parser.add_argument(
"--fast-math",
type=int,
default=1,
help="fused 内 SwiGLU 是否启用 fast-math(0/1)",
)
# 测时
parser.add_argument(
"--num-bench-tests",
type=int,
default=30,
help="bench_kineto 抓 fused 时的迭代数",
)
parser.add_argument(
"--num-warmup", type=int, default=5, help="baseline cuda events warmup"
)
parser.add_argument(
"--num-repeat", type=int, default=20, help="baseline cuda events 测时迭代"
)
parser.add_argument(
"--l2-flush-gb",
type=float,
default=8.0,
help="baseline event 测时前用于 flush L2 的临时写入大小;0 表示关闭",
)
parser.add_argument(
"--check-output-diff",
type=int,
default=0,
help="非 0 时打印 fused 与 legacy-per128 baseline 的输出差异(预期非 bitwise)",
)
parser.add_argument(
"--dump-profile-traces",
type=str,
default="",
help="非空时把 fused 的 Chrome trace 写到该目录(每 rank 一份)",
)
args = parser.parse_args()
if args.dump_profile_traces:
os.makedirs(args.dump_profile_traces, exist_ok=True)
# 多进程启动:每个进程对应一个 GPU;test() 内部用 init_dist 建 NCCL group
torch.multiprocessing.spawn(
test, args=(args.num_processes, args), nprocs=args.num_processes
) |
deepseek-ai/DeepEP#629 没有RDMA 的8卡需要依赖这个PR 才能跑ElasticBuffer接口,另外EP4 是不是太小了?整体应该还是bound 在HBM 读取上面了,看不到megamoe 的收益。 |
|
你跑过B300的ncu报告吗?我跑出来的B300的报告SM和Memory利用率非常低,不知道是不是跑错了,感觉有点奇怪呢。#336 |
就是官方的Mega MoE的kernel的B300的报告,我在8卡上跑的 |
|
@Stone749990226 我只有H20环境 |
请问,这是用最新代码跑的吗?baseline是和sm100上一样,deepep v2+deepgemm完全没任何overlap的吗?我之前跑出来的结果如下:
|
test_mega_moe_sm90.py |
源码有吗? |
测试脚本是 |
|
按照最新的测试跑了一下,确实没有spill,不过相比H800
1900+TFLOPS的峰值tflops比较低的现象还是存在的(379 TFLOPS)。这里的C7510我记得是由于vprinf导致的,但是解了之后提升不大。 Config:
Tokens: 8192/8192
Hidden: 4096
Intermediate: 2048
Experts: 6/256
Buffer: 2.507 GiB
Warning: please use at least NVCC 12.9 for the best DeepGEMM performance
Warning: please use at least NVCC 12.9 for the best DeepGEMM performance
Warning: please use at least NVCC 12.9 for the best DeepGEMM performance
Warning: please use at least NVCC 12.9 for the best DeepGEMM performance
Running NVCC command: cd /root/.deep_gemm/tmp && /usr/local/cuda/bin/nvcc
/root/.deep_gemm/tmp/3743189-40fbd983-7dc7f586-d638089b/kernel.cu -cubin -o
/root/.deep_gemm/tmp/3743189-40fbd983-7dc7f586-d638089b/kernel.cubin
-std=c++20 --diag-suppress=39,161,174,177,186,940
--ptxas-options=--register-usage-level=10
--ptxas-options=--verbose,--warn-on-local-memory-usage
-I/workspace/qiushixiaoyu-deepgemm/DeepGEMM/deep_gemm/include
--gpu-architecture=sm_90a
--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi
-O3 --expt-relaxed-constexpr --expt-extended-lambda
Running NVCC command: cd /root/.deep_gemm/tmp && /usr/local/cuda/bin/nvcc
/root/.deep_gemm/tmp/3743187-8f593979-ce0bfdbd-e1ee5922/kernel.cu -cubin -o
/root/.deep_gemm/tmp/3743187-8f593979-ce0bfdbd-e1ee5922/kernel.cubin
-std=c++20 --diag-suppress=39,161,174,177,186,940
--ptxas-options=--register-usage-level=10
--ptxas-options=--verbose,--warn-on-local-memory-usage
-I/workspace/qiushixiaoyu-deepgemm/DeepGEMM/deep_gemm/include
--gpu-architecture=sm_90a
--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi
-O3 --expt-relaxed-constexpr --expt-extended-lambda
Running NVCC command: cd /root/.deep_gemm/tmp && /usr/local/cuda/bin/nvcc
/root/.deep_gemm/tmp/3743188-7615d764-76bb6375-b809a2fe/kernel.cu -cubin -o
/root/.deep_gemm/tmp/3743188-7615d764-76bb6375-b809a2fe/kernel.cubin
-std=c++20 --diag-suppress=39,161,174,177,186,940
--ptxas-options=--register-usage-level=10
--ptxas-options=--verbose,--warn-on-local-memory-usage
-I/workspace/qiushixiaoyu-deepgemm/DeepGEMM/deep_gemm/include
--gpu-architecture=sm_90a
--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi
-O3 --expt-relaxed-constexpr --expt-extended-lambda
Running NVCC command: cd /root/.deep_gemm/tmp && /usr/local/cuda/bin/nvcc
/root/.deep_gemm/tmp/3743186-397a9b0a-45b0e0be-7ec9317a/kernel.cu -cubin -o
/root/.deep_gemm/tmp/3743186-397a9b0a-45b0e0be-7ec9317a/kernel.cubin
-std=c++20 --diag-suppress=39,161,174,177,186,940
--ptxas-options=--register-usage-level=10
--ptxas-options=--verbose,--warn-on-local-memory-usage
-I/workspace/qiushixiaoyu-deepgemm/DeepGEMM/deep_gemm/include
--gpu-architecture=sm_90a
--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi
-O3 --expt-relaxed-constexpr --expt-extended-lambda
ptxas info : (C7510) Potential Performance Loss: wgmma.mma_async
instructions are serialized due to wgmma pipeline crossing function
boundary at a function call in the function
'_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_'
ptxas warning : Local memory used for function
'_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_',
size of stack frame: 56 bytes
ptxas info : 517 bytes gmem
ptxas info : Compiling entry function
'_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_'
for 'sm_90a'
ptxas info : Function properties for
_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_
56 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 168 registers, used 16 barriers, 56 bytes cumulative
stack size
ptxas info : Compile time = 315.840 ms
ptxas info : (C7510) Potential Performance Loss: wgmma.mma_async
instructions are serialized due to wgmma pipeline crossing function
boundary at a function call in the function
'_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_'
ptxas warning : Local memory used for function
'_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_',
size of stack frame: 56 bytes
ptxas info : 517 bytes gmem
ptxas info : Compiling entry function
'_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_'
for 'sm_90a'
ptxas info : Function properties for
_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_
56 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 168 registers, used 16 barriers, 56 bytes cumulative
stack size
ptxas info : Compile time = 307.963 ms
ptxas info : (C7510) Potential Performance Loss: wgmma.mma_async
instructions are serialized due to wgmma pipeline crossing function
boundary at a function call in the function
'_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_'
ptxas warning : Local memory used for function
'_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_',
size of stack frame: 56 bytes
ptxas info : 517 bytes gmem
ptxas info : Compiling entry function
'_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_'
for 'sm_90a'
ptxas info : Function properties for
_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_
56 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 168 registers, used 16 barriers, 56 bytes cumulative
stack size
ptxas info : Compile time = 308.712 ms
ptxas info : (C7510) Potential Performance Loss: wgmma.mma_async
instructions are serialized due to wgmma pipeline crossing function
boundary at a function call in the function
'_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_'
ptxas warning : Local memory used for function
'_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_',
size of stack frame: 56 bytes
ptxas info : 517 bytes gmem
ptxas info : Compiling entry function
'_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_'
for 'sm_90a'
ptxas info : Function properties for
_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT15_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_
56 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 168 registers, used 16 barriers, 56 bytes cumulative
stack size
ptxas info : Compile time = 325.611 ms
Performance:
EP: 0/4 | 375 TFLOPS | overlap: 379 TFLOPS, HBM 370 GB/s, NVL 92
GB/s | 6605 us, reduction: 72.3 us | 0.00x legacy
foobar2023xx ***@***.***> 于2026年5月25日周一 11:12写道:
… *foobar2023xx* left a comment (deepseek-ai/DeepGEMM#323)
<#323 (comment)>
在我们的测试里,nvcc编译出来这个版本的寄存器spill很多,对比deepepv2+deepgemm non-overlap: [image:
image]
<https://private-user-images.githubusercontent.com/199557527/595469978-c8551ec8-eac4-49c0-b41b-7ea9d7532f48.png?jwt=eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Nzk0Mzk4MjksIm5iZiI6MTc3OTQzOTUyOSwicGF0aCI6Ii8xOTk1NTc1MjcvNTk1NDY5OTc4LWM4NTUxZWM4LWVhYzQtNDljMC1iNDFiLTdlYTlkNzUzMmY0OC5wbmc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjYwNTIyJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI2MDUyMlQwODQ1MjlaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT1jMzQ0ZDZkM2MyYWZmMDkzMmJjZTlhYzdhMWExZTgxZTAzZjM5MzlmYjM0ZDk2NzMyOWZiZWNhMDk5ZWY1ODRmJlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZyZXNwb25zZS1jb250ZW50LXR5cGU9aW1hZ2UlMkZwbmcifQ.fuwLWBUSUuRpl04rM4JIfpv6YfaydLq7LG04NYmb_Jc>
会比baseline差。 我们尝试手动解决了一下,并对计算warpgroup做了修改,拿4卡H800测试的效果: --num-processes 4
--num-experts 128 --hidden 3072 on 4*H800 [image: image]
<https://private-user-images.githubusercontent.com/199557527/595470385-62183cd2-a858-4569-a73f-aca982a0a8bd.png?jwt=eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Nzk0Mzk4MjksIm5iZiI6MTc3OTQzOTUyOSwicGF0aCI6Ii8xOTk1NTc1MjcvNTk1NDcwMzg1LTYyMTgzY2QyLWE4NTgtNDU2OS1hNzNmLWFjYTk4MmEwYThiZC5wbmc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjYwNTIyJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI2MDUyMlQwODQ1MjlaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT1kNDQzMmY4OTFkNzNkNjQ3NGYzY2Y3Y2E3ZGMxNWNiM2ZiMzQwMzUwNGU0NDk3NjQyMjI4MzU2NTZiY2I1ZGFjJlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZyZXNwb25zZS1jb250ZW50LXR5cGU9aW1hZ2UlMkZwbmcifQ.emAemklJJR0mz482dNzDEAf_9Q_7B57qQ-cmv6gV03w>
看上去性能合理
我在H800上测试的时候,没有观察到明显的寄存器spill,请问你们是基于当前最新版本测试的吗?
Running NVCC command: cd /tmp/dg_sm90_spill_v1/tmp && /usr/local/cuda/bin/nvcc /tmp/dg_sm90_spill_v1/tmp/13989-a1ffc508-d2e5b710-f4640373/kernel.cu -cubin -o /tmp/dg_sm90_spill_v1/tmp/13989-a1ffc508-d2e5b710-f4640373/kernel.cubin -std=c++20 --diag-suppress=39,161,174,177,186,940 --ptxas-options=--register-usage-level=10 --ptxas-options=--verbose,--warn-on-local-memory-usage -I/workspace/DeepGEMM/deep_gemm/include --gpu-architecture=sm_90a --compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi -O3 --expt-relaxed-constexpr --expt-extended-lambda
ptxas info : (C7510) Potential Performance Loss: wgmma.mma_async instructions are serialized due to wgmma pipeline crossing function boundary at a function call in the function '_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj0ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj1ELj48ELj40ELj208ELj0ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT16_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_'
ptxas warning : Local memory used for function '_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj0ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj1ELj48ELj40ELj208ELj0ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT16_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_', size of stack frame: 56 bytes
ptxas info : 474 bytes gmem
ptxas info : Compiling entry function '_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj0ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj1ELj48ELj40ELj208ELj0ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT16_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_' for 'sm_90a'
ptxas info : Function properties for _ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj0ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj1ELj48ELj40ELj208ELj0ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT16_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_
56 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 168 registers, used 16 barriers, 56 bytes cumulative stack size
ptxas info : Compile time = 582.487 ms
你的测试脚本和参数是什么
测试脚本是
bench_mega_moe_sm90.py
<https://github.com/user-attachments/files/28204259/bench_mega_moe_sm90.py>
DG_JIT_PTXAS_VERBOSE=1 \
DG_JIT_PRINT_COMPILER_COMMAND=1 \
python tests/bench_mega_moe_sm90.py \
--num-processes 4 \
--num-max-tokens-per-rank 8192 \
--num-tokens 8192 \
--hidden 4096 \
--intermediate-hidden 2048 \
--num-experts 256 \
--num-topk 6
—
Reply to this email directly, view it on GitHub
<#323 (comment)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/BPSQDF2J5JJ7TX5SBDGY4DD44O2ZZAVCNFSM6AAAAACYLY4UHGVHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHM2DKMZRGE4TIOJVGQ>
.
Triage notifications on the go with GitHub Mobile for iOS
<https://apps.apple.com/app/apple-store/id1477376905?ct=notification-email&mt=8&pt=524675>
or Android
<https://play.google.com/store/apps/details?id=com.github.android&referrer=utm_campaign%3Dnotification-email%26utm_medium%3Demail%26utm_source%3Dgithub>.
You are receiving this because you commented.Message ID:
***@***.***>
|
|
H200 跑下来看起来 收益是负面的。。。 |
我也是这个结果,只有Deepep+GEMM的2/3 |
|
https://github.com/usernamehaha2022/DeepGEMM/tree/add_hopper_mega DG_JIT_PTXAS_VERBOSE=1 DG_JIT_PRINT_COMPILER_COMMAND=1 python tests/bench_mega_moe_sm90.py --num-processes 4 --num-max-tokens-per-rank 8192 --num-tokens 8192 --hidden 7168 --intermediate-hidden 4096 --num-experts 256 --num-topk 6 基于lz的提交做了一些修改,大家可以测一下 |
This looks better than PR. |
但是MFU 也只有0.37的水准,比B卡的还低很多,还有哪些瓶颈? |
最大的瓶颈就是H卡没有TMEM。所以做一做低延迟还是可以,我理解比sbo会好一点 |
H卡缺寄存器是没办法的事,并且sfa和sfb也没有硬件加速,是CUDA core算的。如果 我砍掉sfa和sfb,直接fp8/fp8计算(总比mxfp8/mxfp4精度高)。是不是还能再腾出一点MFU。 |
|
DSV4参数 immediate_size = 3072 hidden_size = 7168 experts=384 tokens=448~640之间有一个性能低谷,tokens=648 相比 tokens=640 MFU直接暴涨 50%,感觉schedule有问题。 |
有可能,现在token少的时候确实性能一般。这个你有时间看吗?我目前还在看token多的时候tma不是瓶颈,L2 wait更多的问题。我的commit有一个简单的profiler: export DG_PHASE_PROFILING |
tokens_per_expert 分界线改成56 可以解决 PRO参数下 性能低谷问题。 感觉还是block_m 64还是128的调度问题。 |
|
cluster_size改为2利用TMA的Cluster多播特性一次搬运是否能加快速度呢?收益会很多吗?我看目前Cluster为1 |
试过,没什么收益 |
7216e50 to
40b508b
Compare
Squash local SM90 MegaMoE decode work on top of deepseek-ai/main. Includes SM90 FP8 MegaMoE kernels, decode split-N tuning, low-latency benchmark support, and bn256 L2 counter enablement.













Summary
Add SM90/Hopper FP8 MegaMoE decode support to DeepGEMM.
This change introduces a fused SM90 MegaMoE kernel path for FP8 weights and FP8 activations, including the C++/CUDA implementation, JIT entry point, Python API wrapper, Hopper-specific weight transform, scheduling heuristics, and a unified Hopper accuracy/performance test script.
Main Changes
Added the SM90 FP8 MegaMoE fused kernel implementation:
deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuhcsrc/jit_kernels/impls/sm90_fp8_mega_moe.hppExtended MegaMoE API bindings with
fp8_mega_moefor SM90.Added Hopper-specific Python entry points:
deep_gemm.fp8_mega_moedeep_gemm.transform_weights_for_mega_moe_sm90Added SM90 MegaMoE scheduling/config heuristics.
Updated MegaMoE symmetric buffer handling for SM90 FP32 scale-factor layouts.
Added
tests/test_mega_moe_hopper.py, covering:DeepSeekV4Flash(8 card H20)
DeepSeekV4Pro(8 card H20)
Benchmark:DeepSeekV4Flash CP8/EP8
SLO-Compliant Total Throughput
Max Throughput
@LyricZhao Could you help review this PR when you have time? If the patch is too large for one PR, I’m happy to split it into smaller parts following your preference.