Skip to content

Add SM90 FP8 MegaMoE split kernel#352

Open
AichenF wants to merge 18 commits into
deepseek-ai:mainfrom
AichenF:megamoe_sm90
Open

Add SM90 FP8 MegaMoE split kernel#352
AichenF wants to merge 18 commits into
deepseek-ai:mainfrom
AichenF:megamoe_sm90

Conversation

@AichenF

@AichenF AichenF commented Jun 3, 2026

Copy link
Copy Markdown

co-author: @jychen21

Summary

  • Adds an SM90 (Hopper) FP8 MegaMoE implementation alongside the existing SM100 fp8_fp4_mega_moe. The kernel is split into two cooperating launches(_l1_impl + _l2_impl), each tuned independently per phase. Activation SF layout matches the DeepEP convention (per-128 K float), so SF tensors can be physically shared between DeepEP and this kernel.
  • Speedup over DeepEP baseline on H20: 1.15× – 2.04× (see table below).
  • This branch is built on top of the initial SM90 MegaMoE port in feat: MegaMOE adaptation for SM90 #323 — thanks to the authors of that PR for the groundwork.

What's new

  • API: deep_gemm.fp8_mega_moe(...) and transform_weights_for_mega_moe_sm90(...). Same SymmBuffer as SM100; slice dtype is arch-aware (float SF on SM90, packed UE8M0 int on SM100).
  • Kernel: deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh — one templated _core instantiated twice via MegaMoEPhasePolicy<Linear1|Linear2>, no cross-phase if-constexpr branching at use-sites.
  • Scheduler: mega_moe.cuh gains kClusterSize / kL{1,2}NMajorSchedule template params for_each_phase_block; SM100 fused path unchanged (cluster_size=2 default preserved).
  • Selector: csrc/jit_kernels/heuristics/mega_moe.hpp — shape-keyed candidate enumeration with H20 / H200 empirical hints.
  • Tests / bench: tests/test_mega_moe_sm90.py (smoke / heuristic / shape / edge / stress layers, calc_diff < 0.01 against BF16/FP32 reference); tests/bench_mega_moe_sm90.py.

Generic touches

  • deep_gemm/include/deep_gemm/common/math.cuh: added mul2() portable shim (__fmul2_rn is sm_100+ only).
  • deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh: kClusterSize template parameter; existing 2-CTA path preserved as the default.
  • deep_gemm/include/deep_gemm/comm/barrier.cuh: gate NVLink-barrier timeout printf behind DG_NVLINK_BARRIER_TIMEOUT_PRINTF; use DG_TRAP_ONLY_DEVICE_ASSERT.
  • csrc/jit/{compiler,handle}.hpp: DG_JIT_WITH_DEVICE_DEBUG (cuda-gdb DWARF) and DG_JIT_FORCE_LEGACY_LOAD (old-driver cuModuleLoad fallback) escape hatches.
  • deep_gemm/testing/bench.py: configurable L2-flush size (DG_BENCH_FLUSH_L2_BYTES); optional split-kernel timing breakdown (DG_SM90_MOE_REPORT_SPLIT_KERNELS).

Benchmark — H20-3e, 8 ranks, hidden=7168, ih=2048, num_experts=256, topk=8

M This branch (μs) DeepEP baseline (μs) Speedup
8 783.5 ~1520 1.94×
16 942.5 ~1624 1.72×
32 848.8 ~1650 1.94×
64 881.3 ~1647 1.87×
128 819.7 ~1670 2.04×
256 1141.6 ~1740 1.52×
260 1384.3 ~1730 1.25×
512 1958.4 ~2545 1.30×
819 2819.4 ~3356 1.19×
1024 3286.0 ~4221 1.29×
2048 6024.0 ~7322 1.22×
4096 11347.0 ~13254 1.17×
8192 22039.0 ~25396 1.15×

M ≤ 128 are 10-run medians; M ≥ 256 are from the full sweep. Numbers for M ≤ 128 are 10-run medians because at that scale per-call time is only 700–950 μs, and the result is dominated by routing-draw noise, idle-SM under-utilisation, and roughly M-independent launch/barrier overhead that doesn't amortise away.

@Rachmanino

Copy link
Copy Markdown

Hi authors, thanks for your great work! I wonder whether you've compared your split version versus the original one in #323. And if so, could you please illustrate the where the performance gain comes from? Thx

@YanyunDuan

Copy link
Copy Markdown

Does MegaMoE currently support SM120, or is support still under development?

@AichenF

AichenF commented Jun 8, 2026

Copy link
Copy Markdown
Author

Does MegaMoE currently support SM120, or is support still under development?

no plan to support sm120

@Mikezhang001

Copy link
Copy Markdown

Here are the benchmark results tested on an 8x H800 setup:

💻 Environment & Configurations

  • Hardware: 8 × NVIDIA H800 (SM90, Hopper)
  • Settings:
    • ranks = 8, hidden = 7168, intermediate_hidden = 2048, num_experts = 256
    • num_topk = 8, masked_ratio = 0, fast_math = 1, num_tests = 20

📊 End-to-End Latency Comparison (μs)

M (tokens/rank) H800 (μs) PR H20-3e (μs) DeepEP baseline (μs) H800 / PR H20 H800 vs DeepEP baseline (Speedup)
8 613.1 783.5 ~1520 0.78× 2.48×
16 730.9 942.5 ~1624 0.78× 2.22×
32 657.8 848.8 ~1650 0.77× 2.51×
64 614.6 881.3 ~1647 0.70× 2.68×
128 678.5 819.7 ~1670 0.83× 2.46×
256 877.3 1141.6 ~1740 0.77× 1.98×
260 963.9 1384.3 ~1730 0.70× 1.79×
512 1212.3 1958.4 ~2545 0.62× 2.10×
819 1729.9 2819.4 ~3356 0.61× 1.94×
1024 1939.8 3286.0 ~4221 0.59× 2.18×
2048 3517.0 6024.0 ~7322 0.58× 2.08×
4096 6458.0 11347.0 ~13254 0.57× 2.05×
8192 12147.0 22039.0 ~25396 0.55× 2.09×

Summary: The H800 environment achieves a significant 1.79× - 2.68× speedup over the DeepEP baseline across various batch sizes.

@mpdfdfl

mpdfdfl commented Jun 8, 2026

Copy link
Copy Markdown

The baseline seems a bit different from the one used on SM100 — looks like it isn't using DeepEP-v2? Here are my test results on H200:

For reference, my run on H200 (8 ranks, hidden=7168, ih=2048, num_experts=256, topk=8):

M (tokens) This branch (μs) TFLOPS DeepEP baseline (μs) Speedup
8 403.9 14.8 456.1 1.13×
16 422.6 27.9 497.8 1.18×
32 514.4 39.2 507.6 0.99×
64 443.8 90.9 526.0 1.19×
128 490.2 184.2 562.4 1.15×
256 615.2 303.4 642.5 1.04×
260 623.0 287.8 645.1 1.04×
512 806.9 452.1 866.2 1.07×
819 1106.0 517.9 1142.3 1.03×
1024 1222.3 594.8 1399.9 1.15×
2048 2088.1 690.6 2492.8 1.19×
4096 3866.0 743.2 4710.2 1.22×
8192 7492.0 768.4 9157.7 1.22×

EP_DISABLE_GIN=1 python3 tests/bench_mega_moe_sm90.py
--num-processes 8
--hidden 7168 --intermediate-hidden 2048
--num-experts 256 --num-topk 8
--batches 8 16 32 64 128 256 260 512 819 1024 2048 4096 8192
--baseline

bench_mega_moe_sm90.py

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants