Skip to content

feat: add sm120 support for DeepGEMM #324

Open
leavelet wants to merge 51 commits into
deepseek-ai:nv_devfrom
leavelet:sm120
Open

feat: add sm120 support for DeepGEMM #324
leavelet wants to merge 51 commits into
deepseek-ai:nv_devfrom
leavelet:sm120

Conversation

@leavelet

@leavelet leavelet commented May 1, 2026

Copy link
Copy Markdown

Added sm120 support for DeepGEMM, with performance on par with cuBLAS and outperforming the CUTLASS version in most cases.

sm120 support will be maintained in the nv_dev branch by NVIDIA DevTech APAC.

Performance metrics:

Dense GEMM

Precision Peak TFLOPS % of MMA Peak Notes
FP8 619 76% block-scaled UE8M0, BF16 output
FP4 1239 81% block-scaled UE8M0, gran_k=32, BF16 output
BF16 374 98.5% No scale factors
  • FP4 vs FP8 speedup: 1.6–1.9× on large shapes.
  • BF16 achieves the highest MMA utilization (98.5%) due to being more compute-bound.

Grouped GEMM (FP8 / FP4)

Kernel FP8 TFLOPS FP4 TFLOPS
K-grouped (EP64, 4 groups) 551 828
M-grouped contiguous (4g) 728 1372
M-grouped masked (6g) 633 1169
  • FP4 vs FP8 speedup: ~1.8× on M-grouped contiguous, ~1.7–1.8× on masked.
  • M-grouped contiguous exceeds dense TFLOPS due to larger effective M (~34K), giving better wave efficiency across SMs.

Einsum (Batched GEMM)

Variant FP8 TFLOPS BF16 TFLOPS
K-major B (bhr,hdr->bhd) 682 384
MN-major B (bhd,hdr->bhr) 664 ~410
Split-S reduction (bmk,bnk->mn) 188 (BW-limited)
  • MN-major B matches K-major performance (99–102%) via ldmatrix.trans.x2 + multi-atom BLOCK_N=128 optimization.

HC Prenorm (Fused GEMM + sqr_sum, TF32)

M K Splits TFLOPS GB/s
8192 28672 16 28.8 1242
4096 28672 16 25.1 1087
  • Inherently memory-bound (N=24). Split-K provides up to 7× speedup for small M.

MQA Logits (Attention, FP8)

Mode Peak TFLOPS Notes
Dense (ragged) 651 (80% peak) Warp-specialized, L2-cached KV
Paged 322 DRAM BW-limited (1.35 TB/s)

Co-authored by: @lucifer1004

@leavelet leavelet mentioned this pull request May 1, 2026
@leavelet leavelet changed the title [WIP] Feat: Add sm120 support for DeepGEMM [WIP] feat: Add sm120 support for DeepGEMM May 1, 2026
@leavelet leavelet marked this pull request as ready for review May 9, 2026 04:37
@linjiapro

Copy link
Copy Markdown

@leavelet this is nice, after it is merged into the nv-dev branch, should vllm-project/vllm#41834 merge too in order for vllm to be able to work with the branch.

@jasl

jasl commented May 10, 2026

Copy link
Copy Markdown
Contributor

I add my benchmark at vllm-project/vllm#41834 as references

jasl added a commit to jasl/tokenspeed that referenced this pull request May 12, 2026
Replace the hand-written CUDA FP8 GEMV kernel (previously gated to
tokens==1) with a port of the SM120 FP8 einsum kernel from upstream
DeepGEMM's WIP SM120 support (deepseek-ai/DeepGEMM#324, file
`deep_gemm/include/deep_gemm/impls/sm120_fp8_einsum.cuh`). The DeepGEMM
kernel implements exactly the `bhr,hdr->bhd` einsum DeepSeek V4 needs,
with per-thread-per-output-cell GEMV using fp8x4 vectorized loads and
the same block-128 fp32 scale recipe.

Removing the tokens==1 gate: the kernel handles all token counts that
the SM12x dispatch predicate accepts (tokens <= 16 today; larger token
batches will arrive once T1-α expands graph capture).

Microbench (DSv4-Flash decode shape, groups=8, hidden=2048, out=1024,
GPU idle):

  tokens=1  cuda 0.026ms  triton 0.020ms  speedup 0.72x (was 0.40x)
  tokens=2  cuda 0.027ms  triton 0.020ms  speedup 0.72x (was 0.73x*)
  tokens=8  cuda 0.075ms  triton 0.020ms  speedup 0.27x (was 0.21x*)

  * Triton-as-default after the previous tokens==1 hotfix.

The kernel's grid is `tokens * groups * (out/128)`, one block per
`(token, group, out_tile=128)` triple. Because each block reads its
weight tile independently, total weight reads scale linearly with
`num_tokens`. At graph bs=2 (today) this dominates: tokens<=2 is the
production shape and the 0.72x is a real net win against the previous
Triton-fallback default. At tokens=8 (future, post-T1-α) the kernel
loses ~2x to Triton's m=16 cooperative tile; we will revisit with a
multi-token tile design before T1-α exposes that shape to production.

Earlier hand-written attempts (one-cell-per-block, per-thread B=16
accumulator tile, 1-warp m16n8 MMA, 4-warp m16n32 cooperative MMA,
4-warp m16n128 MMA) are documented in
`docs/notes/2026-05-09-ds4-sm12x-rejected-experiments.md`. The MMA
designs hit either occupancy collapse (80 regs/thread) or insufficient
parallelism (64 blocks at decode shape vs Blackwell's 140 SMs), capping
out at ~0.51x. The DeepGEMM design wins at the production shape by
avoiding tensor cores entirely -- a per-thread GEMV with fp8x4
vectorization and L1/L2-friendly weight access fits the small-M decode
profile better than the m=16 MMA tile.

Attribution: kernel source ported under MIT license from upstream
DeepGEMM (Copyright (c) 2025 DeepSeek). Tokenspeed adaptations are the
tvm-ffi binding, stride/scale validation, and the SM12x dispatch
integration; the dot-product math is unchanged.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/tokenspeed that referenced this pull request May 12, 2026
Upstream PR lightseekorg#93 added a pre-flight DeepGEMM ``fp8_gemm_nt`` call to
``DeepseekV4Attention._compute_qr_kv``: on success it replaces the
reference FP8 linear path, on failure it logs a WARNING per layer and
falls back. DeepGEMM does not support SM120/SM121 yet (see PR
``deepseek-ai/DeepGEMM#324`` + ``reference_deepgemm_sm120`` memory),
so on the RTX Pro 6000 workstation every layer fires:

    DeepSeek V4 DeepGEMM FP8 linear failed; falling back to reference
    FP8 linear. reason=RuntimeError: Assertion error
    (csrc/apis/layout.hpp:59): Unknown SF transformation

The existing per-layer ``_deepseek_v4_deep_gemm_linear_disabled`` flag
already catches this for steady-state replay, but it costs one failed
call + one WARNING per layer at boot. Mirror the pattern used by
``_deepseek_v4_deepgemm_fp4_indexer_enabled_for_platform``: short-
circuit ``_deepseek_v4_get_fp8_linear_deep_gemm`` to ``None`` on SM12x
so the platform never tries the DeepGEMM path. Non-SM12x platforms
keep the new fast path.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request May 18, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request May 18, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request May 19, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request May 19, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request May 20, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
@Rachmanino

Copy link
Copy Markdown

nice work! may I ask the hardware for testing here is either 5090 or RTX6000pro?

jasl added a commit to jasl/vllm that referenced this pull request May 22, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request May 22, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
DoradusResearch pushed a commit to DoradusResearch/vllm that referenced this pull request May 23, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
leavelet and others added 12 commits May 26, 2026 05:43
Phase 1a: infrastructure + dense FP8 GEMM kernel for SM120a (CC 12.0).

Architecture: warp-level mma.sync with block-scaled UE8M0 scale factors,
B128 XOR swizzle, persistent scheduling, register-based epilogue.

New files:
- SM120 heuristics, JIT codegen, MMA PTX wrappers, ldmatrix/swizzle utils
- CUDA kernel with warp-specialized TMA/math pipeline (3-9 stages)

Modified files:
- Arch detection, compiler flags (-gencode for SM120a)
- API dispatch (arch_major == 12), SF layout transform
- Default recipe for SM120

Correctness: 8/8 shapes pass (diff < 0.001 cosine distance)
Performance: ~73 TFLOPS (baseline, optimization pending)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
… only

Drop the non-warp-specialized kernel path for SM120a (matching SM90/SM100
architecture), merging the warp-specialized implementation into the main
sm120_fp8_fp4_gemm_1d1d kernel. Add FP4 GEMM support using packed SMEM
with the mxf4nvf4 m16n8k64 MMA instruction.

Key changes:
- Consolidate: remove non-spec path, always BM=128/BK=128/384 threads
- FP4: packed 4-bit SMEM (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B), standard
  ldmatrix, uint16_t scale factors (scale_vec::2X), kKSteps=2 vs FP8's 4
- Heuristic: simplified to warp-spec only, correct SMEM sizing for FP4
- API: enable FP4 on SM120a (arch_major==12), add fp8_fp4_gemm_nt binding
- Fix SF hoist bug: hoist SFA/SFB independently for mixed gran_k configs

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Kernel: Add TMA descriptor runtime update (tensormap.replace) in producer
  loop for K-grouped group transitions, fix SF_K_ALIGNMENT to kGranKA*4,
  fix SMEM layout (pipeline data at offset 0 for B128 swizzle alignment,
  tensor map descriptors at end), fix epilogue bounds for multi-group output.

- MMA wrappers: Replace CUTLASS mma_sm120.hpp dependency with custom inline
  asm using "+f" read-write constraints for accumulator registers. Eliminates
  CUTLASS header dependency and gives explicit control over MMA operand
  encoding for both FP8 (m16n8k32) and FP4 (m16n8k64) block-scaled MMA.

- JIT launcher: Add sm120_k_grouped_fp8_fp4_gemm_1d1d() with proper TMA
  descriptor creation (first_k base, FP4-aware stride), SF TMA covering
  concatenated groups, CD TMA with num_groups outer dimension.

- API dispatch: Add arch_major==12 path in k_grouped_fp8_gemm_nt_contiguous,
  relax recipe assertion to support gran_k=32/128, add SM120 SF layout
  transform with auto-detection of transposed K-major scale factors.

- Tests: Add dedicated SM120 K-grouped test (7 configs including zero-K
  edge case), fix K-major selection for SM120 in generators, fix test
  dispatch for SM120 in test_fp8_fp4.py, update FP4 test with perf comparison.

Tested: Dense FP8 8/8, Dense FP4 10/10, K-grouped FP8 7/7 — all PASS.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Einsum support:
- Add GemmType::Batched to FP8/FP4 and BF16 kernels with 3D TMA load/store
- Add IndexType::SF_K for batched SF coordinate computation
- Add MN-major B support to BF16 kernel (scalar SMEM loads, single-atom constraint)
- BF16 bhr,hdr->bhd: 384 TFLOPS, FP8: 681 TFLOPS (batch=8, b=8192)

M-grouped BF16:
- Add contiguous and masked M-grouped BF16 GEMM launchers

HC prenorm TF32:
- New fused GEMM + sqr_sum kernel using mma.sync.m16n8k8 TF32 (226T peak)
- BF16 A -> FP32 cast with fused sqr_sum accumulation
- Atom-aware FP32 B fragment loading from K-major SMEM
- Split-K support for large K / small M shapes
- 24/24 test shapes PASS, ~1.1 TB/s bandwidth on large shapes

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Dense (ragged): 651 TFLOPS peak (80% of FP8 MMA peak 814T), 40/40 tests pass.
Paged (KV cache): 320 TFLOPS peak, 1.36 TB/s DRAM (91% of HBM BW), 8/8 tests pass.

Kernel design: warp-specialized mma.sync m16n8k32 FP8 (no block_scale).
8 math warps × 16 KV rows each = 128 BLOCK_KV. In-warp 2-shfl reduction
across 4 threads (lane%4) — only ~10 cycles, negligible vs MMA time.
Global stores are fire-and-forget on SM120a, so no epilogue warps needed.

Key parameters: block_qh=128, num_heads=64, head_dim=128, 2 Q stages,
3 KV stages, 84KB SMEM (83% of 101KB capacity).

Paged variant: 2 groups of 4 warps, SPLIT_KV=128, per-group KV pipeline.
Fixed metadata split_kv mismatch and register budget overflow (TMA regs
64→40 to stay within 65536 register limit).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- skip_head_mid: Add SM120 dispatch in attention.hpp for EpilogueHeadSplits.
  Fix three issues: TMA CD descriptor uses d.size(-1)/d.stride(-2) instead
  of n; kernel uses stride_d parameter for D row stride and bounds checks;
  TMA store coordinates apply epilogue N-index remapping.

- MN-major B: Fix kernel TMA coordinate for M-grouped BF16 with MN-major B.
  Group offset moves to outer=K coordinate (not inner=N) when kBKMajor=false.

- FP8 kernel stride_d: Add stride_d parameter to decouple D tensor stride
  from computation dimension n, enabling epilogue transforms that expand N.
Replace per-element scalar SMEM loads in the MN-major B path with
ldmatrix.sync.aligned.x2.m8n8.trans.shared.b16 which natively loads
column-major 8x8 BF16 matrices — directly producing MMA B fragments.

Performance: MN-major B improves from ~220T to ~290T (dense) and ~330T
(M-grouped), a 30-50% gain. Remaining gap vs K-major (400T) is due to
heuristic selecting BLOCK_N=32/64 vs 128 (single swizzle atom constraint).

Verified by micro benchmark b_bf16_4: fragment layout 32x32 lanes PASS,
MMA pipeline 4 K-steps accumulation PASS.
Remove single-atom BLOCK_N constraint for MN-major B. ldmatrix.trans
correctly handles multi-atom SMEM (verified by micro benchmark b_bf16_5).

MN-major B now achieves 99-102% of K-major performance (was 53% with
scalar loads, 80% with single-atom ldmatrix.trans).
…bhr->hdr

New BF16 bmk,bnk->mn reduction kernel with split-S atomicAdd to FP32 output
(188T peak, HBM BW limited). FP8 einsum dispatch for bhd,hdr->bhr and
bhd,bhr->hdr via .contiguous() to K-major. Fix batched epilogue stride
formula: replace single stride_d with stride_cd_m + stride_cd_batch to
support arbitrary D layouts ([batch,M,N] vs [M,batch,N]). Add kBKMajor
template parameter to FP8 kernel with verified scalar-load MN-major B path
(correct but 3x slower than K-major ldmatrix, kept for future optimization).
K-grouped TN: single .t().contiguous() transpose with constant-stride TMA
(kKGroupedConstantStride) — per-group only replaces addr+dim, not stride.
TN achieves 99-101% of NT performance. New PTX tensor_map_replace_global_dim_in_smem.

Paged MQA varlen: fix 4 kernel bugs in sm120_fp8_paged_mqa_logits.cuh:
- TMA Q coordinate: use atom_to_token_idx() instead of hardcoded *kNextNAtom
- Prefetch advance: use get_atom_advance() instead of hardcoded +1
- Math loop: conditional iteration count via is_paired_atom for unpaired atoms
- KV block idx: reset kv_block_idx_ptr=32 on q_atom change
New kernels using mma.sync m16n8k64 block-scaled FP4 (mxf4nvf4, scale_vec::2X).
Architecture: 8 math warps + 4 TMA warps, B64 swizzle, kKSteps=2.
Block-scaled MMA folds UE8M0 SF into computation — no post-MMA scale.

Dense FP4 MQA: 1022 TFLOPS peak (63% FP4 peak), 1.6x vs FP8.
Paged FP4 MQA: 707 TFLOPS peak (varlen), 566 TFLOPS (non-varlen next_n=2).
@brandonmmusic-max

Copy link
Copy Markdown

Your more generalized fix is the right call. You certainly have the holistic POV. If I’m trying to be helpful I. Debugging I try to be targeted instead if taking a flamethrower to other people’s work lol because there are probably structural assumptions in making that I don’t even realize. Like the smem sized per atom I wasn’t thinking from that perspective (and it was also late at night for me lol). I’ll put this through it paced though thanks for responding!

jasl added a commit to jasl/vllm that referenced this pull request Jun 2, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request Jun 3, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request Jun 3, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request Jun 3, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request Jun 3, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request Jun 3, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request Jun 3, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request Jun 4, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
leavelet added 2 commits June 4, 2026 06:18
Brings in 9 upstream nv_dev commits: paged-MQA metadata OOB fix (deepseek-ai#353),
IMA guard in the paged-MQA scheduler (deepseek-ai#342), the FP16-weights FP8 MQA
logits kernel, the mega-MoE refactor (deepseek-ai#347), and assorted updates.

Conflicts resolved (3 files):
- csrc/apis/attention.hpp (dense MQA dispatch): union of both sides — keep
  SM120 FP4 (sm120_fp4_mqa_logits) and SM120 FP8 (arch 12) arms, and add
  nv_dev's FP16-weights arm (sm100_fp8_mqa_logits_f16_weights, SM100-only,
  already guarded by the arch==10 host assert).
- deep_gemm/include/deep_gemm/scheduler/paged_mqa_logits.cuh: take nv_dev's
  refactored scheduler. SM120's one-line OOB guard is subsumed by nv_dev's
  reversed-metadata allocation (the deepseek-ai#342/deepseek-ai#353 fix). The refactor replaces
  get_num_kv/get_atom_advance with refresh_num_kv_and_advance + get_last_advance.
- tests/test_attention.py: keep SM120 small-S split-KV MQA shapes and add the
  upstream large-KV/32-head regime; next_n = (1..6) on SM120, (1,2,4,5,6) on
  SM100, (1,2,4) on SM90; take nv_dev's paged test body (memory caps + peak-mem
  reset).

Semantic fixes required by the merge (no textual conflict, would not compile/run
otherwise):
- sm120_fp8/fp4_paged_mqa_logits.cuh: migrate the scheduler call
  get_atom_advance(next_q_idx, batch_size) -> get_last_advance() to match
  nv_dev's refactored scheduler API (same migration nv_dev applied to the
  SM100 paged kernel).
- tests/test_attention.py: gate the FP16-weights path to arch 10. Upstream set
  FP16 weights unconditionally, which dispatches to the SM100-only kernel and
  trips the arch==10 host assert when test_attention runs on SM90/SM120.

Validated on H200 (sm90): clean build; test_bf16, test_fp8_fp4, test_attention
all pass. SM120 (arch 12) kernels are JIT and arch-gated, so they are neither
compiled nor run on H200 — they require an RTX 6000 run before the final PR.
The FP8/FP4 1D1D epilogue has two store paths: a swizzled TMA-store path
(kUseTMAStoreEpilogue, writes via tensor_map_cd) and a strided global-store
path (honors stride_cd_n). The TMA path's tensor map assumes a contiguous-N
output and ignores stride_cd_n, so it cannot express the AB-swap path's
transposed output (stride_cd_n = original-N stride != 1).

kUseTMAStoreEpilogue depends only on swizzle_cd_mode and BLOCK_N, where
swizzle_cd_mode = 128 once BLOCK_N * sizeof(cd) >= 128. The dense GEMM swap
caps M at 16 (BLOCK_N = 16 -> swizzle 0 -> strided store, correct). The BMM
swap (fp8_einsum) caps M at 32: with M in 17..32 and an FP32 output,
BLOCK_N = 32 -> swizzle_cd_mode = 128 -> the TMA-store path runs for a
transposed output and writes out of bounds.

Repro (RTX PRO 6000): fp8_einsum('bhr,hdr->bhd') with batch b = 32 and an FP32
output raises CUDA illegal memory access; b <= 16, or BF16 output (swizzle
stays 0), are unaffected. Not covered by the test suite, which uses BF16
output for the swapping expressions.

Fix: make the invariant explicit. Add GemmDesc::cd_n_contiguous; the SM120
swizzle_cd_mode heuristic requires it, and the dense + BMM swap launchers set
it false, so the swap always uses the strided-store epilogue. The swap path is
small-M and latency-bound, so the TMA-store epilogue is not a loss there.

Validated on RTX PRO 6000 (arch 12): the repro now passes (b=32 FP32
diff 7e-4); test_split_k_swap, test_einsum, test_fp8_fp4 still pass.
jasl added a commit to jasl/vllm that referenced this pull request Jun 4, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request Jun 5, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request Jun 5, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request Jun 5, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request Jun 5, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
…r changes

The SM120 MQA-logits work modified the shared launchers in two ways that broke
SM100 (and were latent on SM90), surfacing as test_attention failures on B300
(baseline=origin/nv_dev passes, sm120 HEAD fails). No SM100/SM90 device kernel
(.cuh) changed, so this is purely a host launch-config regression.

1. Dense MQA SMEM under-allocation (smxx_fp8_fp4_mqa_logits.hpp). The SMEM size
   dropped the `(num_math_threads / 128) * 2` mbarrier pairs that the SM90/SM100
   kernels still allocate (the SM120 kernel does not). On SM100 this under-sized
   the dynamic SMEM; compute-sanitizer reports an "Invalid __shared__ write" in
   sm100_fp8_mqa_logits, faulting (CUDA_ERROR_ILLEGAL_ADDRESS) at large configs
   such as seq_len_kv=130560. Restore the term for arch != 12.

2. Paged-MQA metadata capacity assert (smxx_fp8_fp4_paged_mqa_logits.hpp). An
   unconditional `smem_size <= SM120ArchSpec::smem_capacity` (99 KB) was added to
   the metadata launcher, which runs on every arch. On SM90/SM100 (228 KB) a
   large (varlen) batch's metadata SMEM legitimately exceeds 99 KB and falsely
   tripped the assert. Gate the capacity check to the running arch.

Verified on B300 (sm100): compute-sanitizer memcheck clean on the previously
out-of-bounds config (seq_len=510, seq_len_kv=130560); full test_attention
(dense + paged, NextN 1-6, FP8/FP4) passes. Fix (1) also corrects a latent
SMEM under-allocation on SM90.
jasl added a commit to jasl/vllm that referenced this pull request Jun 5, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request Jun 6, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
Strip changes that should not ship upstream and align comment style with the
existing codebase (terse, no decorative banners). No change to production
GEMM/attention behavior — only removes a dev-only bench path and trims comments.

- .gitignore: revert to upstream; personal ignores (docs_internal/, benchmarks/,
  ncu_reports/, internal/) moved to local .git/info/exclude.
- Remove dev-only sm120_fp8_gemm_bench API and its override_layout tile-override
  param (gemm.hpp, __init__.py, sm120_fp8_fp4_gemm_1d1d.hpp); config selection
  collapses to the standard get_best_config path.
- Trim changelog-/rationale-style block comments to single lines (shared rationale
  lives in commit history): MQA-logits SMEM, paged-MQA capacity assert,
  cd_n_contiguous, split-K SF alignment, latency model, einsum MN-major notes.
- De-duplicate the AB-swap "transposed output" rationale (kept once in config.hpp)
  and the UE8M0 SF-tile note.
- Remove decorative // ==== banner comments from sm120 .cuh (no other kernel uses them).
- Remove an unused num_waves local in the SM120 latency model (clears a
  -Wunused-variable warning).
- Remove tests/test_split_k_swap.py (standalone dev harness, not pytest-style).
@leavelet

leavelet commented Jun 7, 2026

Copy link
Copy Markdown
Author

@RayWang96 I have cleaned up the code and verified them on sm90, sm100 and sm120, the PR is ready for merge. Thanks!

jasl added a commit to jasl/vllm that referenced this pull request Jun 7, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
AliceChenyy added a commit to AliceChenyy/sglang that referenced this pull request Jun 7, 2026
…pGEMM@sm120)

Enable DeepGEMM grouped FP8×FP4 GEMM for MoE on SM120 (RTX 6000D/PRO 6000).
Requires leavelet/DeepGEMM@sm120 branch (deepseek-ai/DeepGEMM#324, not yet merged).

Changes:
- configurer.py: Allow SM120 only when SM120-compatible DeepGEMM is installed
  (checks for m_grouped_fp8_fp4_gemm_nt_contiguous availability)
- server_args.py: Auto-select deep_gemm MoE backend on SM120
- kernels.py: Add UE8M0 (power-of-2) FP8 quantization Triton kernel required
  by DeepGEMM's block-scaled dequantization on SM120
- deep_gemm.py: SM120 adaptations for DeepGEMM MoE runner:
  - TMA-aligned scale factors for grouped GEMM (hidden_states + down_input)
  - JIT EP activation fallback when hidden_dim/8 < num_experts (TP>=2)
  - In-place swiglu clamp replacing torch.chunk+cat (-7.4ms/step decode)
- fp8.py: Add .contiguous() before transform_sf_into_required_layout
- fp8_utils.py: Skip DeepGEMM dense FP8 linear on SM120 (bf16_gemm_nt unsupported)

Performance (TP=4, BS=1, RTX 6000D 85GB, ISL=8K):
  TTFT: 130ms (vs 400ms marlin, 3x faster)
  Decode ITL: 47ms (vs 41ms marlin, 15% slower)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
guqiong96 pushed a commit to guqiong96/Lvllmds4 that referenced this pull request Jun 8, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request Jun 9, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants