feat: add sm120 support for DeepGEMM #324
|
@leavelet this is nice. After it is merged into the nv-dev branch, should vllm-project/vllm#41834 be merged too so that vLLM can work with the branch? |
|
I added my benchmark to vllm-project/vllm#41834 as a reference. |
Replace the hand-written CUDA FP8 GEMV kernel (previously gated to tokens==1) with a port of the SM120 FP8 einsum kernel from upstream DeepGEMM's WIP SM120 support (deepseek-ai/DeepGEMM#324, file `deep_gemm/include/deep_gemm/impls/sm120_fp8_einsum.cuh`). The DeepGEMM kernel implements exactly the `bhr,hdr->bhd` einsum DeepSeek V4 needs, with per-thread-per-output-cell GEMV using fp8x4 vectorized loads and the same block-128 fp32 scale recipe.

Removing the tokens==1 gate: the kernel handles all token counts that the SM12x dispatch predicate accepts (tokens <= 16 today; larger token batches will arrive once T1-α expands graph capture).

Microbench (DSv4-Flash decode shape, groups=8, hidden=2048, out=1024, GPU idle):

  tokens=1  cuda 0.026ms  triton 0.020ms  speedup 0.72x (was 0.40x)
  tokens=2  cuda 0.027ms  triton 0.020ms  speedup 0.72x (was 0.73x*)
  tokens=8  cuda 0.075ms  triton 0.020ms  speedup 0.27x (was 0.21x*)

  * Triton-as-default after the previous tokens==1 hotfix.

The kernel's grid is `tokens * groups * (out/128)`, one block per `(token, group, out_tile=128)` triple. Because each block reads its weight tile independently, total weight reads scale linearly with `num_tokens`. At graph bs=2 (today) this dominates: tokens<=2 is the production shape and the 0.72x is a real net win against the previous Triton-fallback default. At tokens=8 (future, post-T1-α) the kernel loses ~2x to Triton's m=16 cooperative tile; we will revisit with a multi-token tile design before T1-α exposes that shape to production.

Earlier hand-written attempts (one-cell-per-block, per-thread B=16 accumulator tile, 1-warp m16n8 MMA, 4-warp m16n32 cooperative MMA, 4-warp m16n128 MMA) are documented in `docs/notes/2026-05-09-ds4-sm12x-rejected-experiments.md`. The MMA designs hit either occupancy collapse (80 regs/thread) or insufficient parallelism (64 blocks at decode shape vs Blackwell's 140 SMs), capping out at ~0.51x. The DeepGEMM design wins at the production shape by avoiding tensor cores entirely -- a per-thread GEMV with fp8x4 vectorization and L1/L2-friendly weight access fits the small-M decode profile better than the m=16 MMA tile.

Attribution: kernel source ported under MIT license from upstream DeepGEMM (Copyright (c) 2025 DeepSeek). Tokenspeed adaptations are the tvm-ffi binding, stride/scale validation, and the SM12x dispatch integration; the dot-product math is unchanged.

Signed-off-by: jasl <jasl9187@hotmail.com>
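The grid formula and the linear growth of weight traffic described in this commit message can be checked with a small host-side sketch. The function names are hypothetical, and the byte count assumes each block streams its full `out_tile x hidden` FP8 weight tile (1 byte per element) exactly once, ignoring scale factors and any L2 reuse between blocks.

```python
def grid_blocks(tokens: int, groups: int, out: int, out_tile: int = 128) -> int:
    """Number of blocks: one per (token, group, out_tile) triple."""
    if out % out_tile != 0:
        raise ValueError("this sketch assumes out is a multiple of out_tile")
    return tokens * groups * (out // out_tile)


def weight_bytes_read(tokens: int, groups: int, out: int, hidden: int,
                      out_tile: int = 128, bytes_per_weight: int = 1) -> int:
    """Total weight bytes streamed, assuming no reuse across blocks."""
    # Assumption: every block reads its own out_tile x hidden tile independently.
    tile_bytes = out_tile * hidden * bytes_per_weight
    return grid_blocks(tokens, groups, out, out_tile) * tile_bytes


if __name__ == "__main__":
    # Decode shape quoted in the microbench: groups=8, hidden=2048, out=1024.
    for tokens in (1, 2, 8):
        blocks = grid_blocks(tokens, groups=8, out=1024)
        traffic = weight_bytes_read(tokens, groups=8, out=1024, hidden=2048)
        print(f"tokens={tokens}: blocks={blocks}, weight bytes={traffic}")
```

Under these assumptions the weight traffic at tokens=8 is eight times the traffic at tokens=1, which is consistent with the kernel falling behind a tile design that shares weights across tokens.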
Upstream PR lightseekorg#93 added a pre-flight DeepGEMM ``fp8_gemm_nt`` call to ``DeepseekV4Attention._compute_qr_kv``: on success it replaces the reference FP8 linear path, on failure it logs a WARNING per layer and falls back. DeepGEMM does not support SM120/SM121 yet (see PR ``deepseek-ai/DeepGEMM#324`` + ``reference_deepgemm_sm120`` memory), so on the RTX Pro 6000 workstation every layer fires:

    DeepSeek V4 DeepGEMM FP8 linear failed; falling back to reference FP8 linear. reason=RuntimeError: Assertion error (csrc/apis/layout.hpp:59): Unknown SF transformation

The existing per-layer ``_deepseek_v4_deep_gemm_linear_disabled`` flag already catches this for steady-state replay, but it costs one failed call + one WARNING per layer at boot. Mirror the pattern used by ``_deepseek_v4_deepgemm_fp4_indexer_enabled_for_platform``: short-circuit ``_deepseek_v4_get_fp8_linear_deep_gemm`` to ``None`` on SM12x so the platform never tries the DeepGEMM path. Non-SM12x platforms keep the new fast path.

Signed-off-by: jasl <jasl9187@hotmail.com>
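The short-circuit pattern this commit describes might look roughly like the sketch below. The names `is_sm12x` and `get_fp8_linear_deep_gemm` and the `loader` callback are hypothetical stand-ins, not the real vLLM functions; the only assumption about hardware is that SM120/SM121 report compute capability major version 12.

```python
from typing import Callable, Optional, Tuple


def is_sm12x(capability: Tuple[int, int]) -> bool:
    """True for compute capability 12.x (SM120, SM121)."""
    return capability[0] == 12


def get_fp8_linear_deep_gemm(
    capability: Tuple[int, int],
    loader: Callable[[], Callable],
) -> Optional[Callable]:
    """Return the DeepGEMM FP8 linear callable, or None when the platform
    is known not to support it (hypothetical illustration)."""
    if is_sm12x(capability):
        # Skip the pre-flight call entirely: no failed call, no WARNING.
        return None
    return loader()


if __name__ == "__main__":
    calls = []

    def fake_loader():
        calls.append("loaded")
        return lambda x: x

    print(get_fp8_linear_deep_gemm((12, 0), fake_loader))  # gated platform
    print(get_fp8_linear_deep_gemm((9, 0), fake_loader) is not None)
    print(calls)
```

Deciding on the platform before the first call avoids the per-layer failure and log line at boot, while the existing per-layer disabled flag can stay in place as a fallback for other failure modes.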
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
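As a rough illustration of why the wider tile reduces CTA count, the sketch below assumes the fallback launches one CTA per `BLOCK_M x BLOCK_N` tile of an `M x N` logits matrix. That launch shape is an assumption, and the example sizes are made up. The helper `relative_reduction` simply restates the TTFT figures quoted above as fractions.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division."""
    return -(-a // b)


def cta_count(m: int, n: int, block_m: int, block_n: int = 128) -> int:
    """CTAs needed to tile an m x n output (assumed 2D tile grid)."""
    return ceil_div(m, block_m) * ceil_div(n, block_n)


def relative_reduction(before: float, after: float) -> float:
    """Fractional improvement going from `before` to `after`."""
    return (before - after) / before


if __name__ == "__main__":
    # Hypothetical problem size, chosen only to show the halving.
    m, n = 1024, 4096
    print(cta_count(m, n, block_m=8), cta_count(m, n, block_m=16))

    # Mean TTFT pairs (seconds) from the commit message, at C=1, 2, 4.
    for before, after in ((36.541, 33.264), (56.902, 49.199), (96.317, 82.181)):
        print(f"{relative_reduction(before, after):.1%}")
```

Doubling BLOCK_M halves the number of row tiles whenever M is a multiple of 16, so the CTA count drops by the same factor under this tiling assumption.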
|
Nice work! May I ask whether the hardware used for testing here is a 5090 or an RTX 6000 Pro? |
Phase 1a: infrastructure + dense FP8 GEMM kernel for SM120a (CC 12.0).

Architecture: warp-level mma.sync with block-scaled UE8M0 scale factors, B128 XOR swizzle, persistent scheduling, register-based epilogue.

New files:
- SM120 heuristics, JIT codegen, MMA PTX wrappers, ldmatrix/swizzle utils
- CUDA kernel with warp-specialized TMA/math pipeline (3-9 stages)

Modified files:
- Arch detection, compiler flags (-gencode for SM120a)
- API dispatch (arch_major == 12), SF layout transform
- Default recipe for SM120

Correctness: 8/8 shapes pass (diff < 0.001 cosine distance)
Performance: ~73 TFLOPS (baseline, optimization pending)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
… only Drop the non-warp-specialized kernel path for SM120a (matching SM90/SM100 architecture), merging the warp-specialized implementation into the main sm120_fp8_fp4_gemm_1d1d kernel. Add FP4 GEMM support using packed SMEM with the mxf4nvf4 m16n8k64 MMA instruction. Key changes: - Consolidate: remove non-spec path, always BM=128/BK=128/384 threads - FP4: packed 4-bit SMEM (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B), standard ldmatrix, uint16_t scale factors (scale_vec::2X), kKSteps=2 vs FP8's 4 - Heuristic: simplified to warp-spec only, correct SMEM sizing for FP4 - API: enable FP4 on SM120a (arch_major==12), add fp8_fp4_gemm_nt binding - Fix SF hoist bug: hoist SFA/SFB independently for mixed gran_k configs Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Kernel: Add TMA descriptor runtime update (tensormap.replace) in producer loop for K-grouped group transitions, fix SF_K_ALIGNMENT to kGranKA*4, fix SMEM layout (pipeline data at offset 0 for B128 swizzle alignment, tensor map descriptors at end), fix epilogue bounds for multi-group output. - MMA wrappers: Replace CUTLASS mma_sm120.hpp dependency with custom inline asm using "+f" read-write constraints for accumulator registers. Eliminates CUTLASS header dependency and gives explicit control over MMA operand encoding for both FP8 (m16n8k32) and FP4 (m16n8k64) block-scaled MMA. - JIT launcher: Add sm120_k_grouped_fp8_fp4_gemm_1d1d() with proper TMA descriptor creation (first_k base, FP4-aware stride), SF TMA covering concatenated groups, CD TMA with num_groups outer dimension. - API dispatch: Add arch_major==12 path in k_grouped_fp8_gemm_nt_contiguous, relax recipe assertion to support gran_k=32/128, add SM120 SF layout transform with auto-detection of transposed K-major scale factors. - Tests: Add dedicated SM120 K-grouped test (7 configs including zero-K edge case), fix K-major selection for SM120 in generators, fix test dispatch for SM120 in test_fp8_fp4.py, update FP4 test with perf comparison. Tested: Dense FP8 8/8, Dense FP4 10/10, K-grouped FP8 7/7 — all PASS. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Einsum support: - Add GemmType::Batched to FP8/FP4 and BF16 kernels with 3D TMA load/store - Add IndexType::SF_K for batched SF coordinate computation - Add MN-major B support to BF16 kernel (scalar SMEM loads, single-atom constraint) - BF16 bhr,hdr->bhd: 384 TFLOPS, FP8: 681 TFLOPS (batch=8, b=8192) M-grouped BF16: - Add contiguous and masked M-grouped BF16 GEMM launchers HC prenorm TF32: - New fused GEMM + sqr_sum kernel using mma.sync.m16n8k8 TF32 (226T peak) - BF16 A -> FP32 cast with fused sqr_sum accumulation - Atom-aware FP32 B fragment loading from K-major SMEM - Split-K support for large K / small M shapes - 24/24 test shapes PASS, ~1.1 TB/s bandwidth on large shapes Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Dense (ragged): 651 TFLOPS peak (80% of FP8 MMA peak 814T), 40/40 tests pass. Paged (KV cache): 320 TFLOPS peak, 1.36 TB/s DRAM (91% of HBM BW), 8/8 tests pass. Kernel design: warp-specialized mma.sync m16n8k32 FP8 (no block_scale). 8 math warps × 16 KV rows each = 128 BLOCK_KV. In-warp 2-shfl reduction across 4 threads (lane%4) — only ~10 cycles, negligible vs MMA time. Global stores are fire-and-forget on SM120a, so no epilogue warps needed. Key parameters: block_qh=128, num_heads=64, head_dim=128, 2 Q stages, 3 KV stages, 84KB SMEM (83% of 101KB capacity). Paged variant: 2 groups of 4 warps, SPLIT_KV=128, per-group KV pipeline. Fixed metadata split_kv mismatch and register budget overflow (TMA regs 64→40 to stay within 65536 register limit). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- skip_head_mid: Add SM120 dispatch in attention.hpp for EpilogueHeadSplits. Fix three issues: TMA CD descriptor uses d.size(-1)/d.stride(-2) instead of n; kernel uses stride_d parameter for D row stride and bounds checks; TMA store coordinates apply epilogue N-index remapping. - MN-major B: Fix kernel TMA coordinate for M-grouped BF16 with MN-major B. Group offset moves to outer=K coordinate (not inner=N) when kBKMajor=false. - FP8 kernel stride_d: Add stride_d parameter to decouple D tensor stride from computation dimension n, enabling epilogue transforms that expand N.
Replace per-element scalar SMEM loads in the MN-major B path with ldmatrix.sync.aligned.x2.m8n8.trans.shared.b16 which natively loads column-major 8x8 BF16 matrices — directly producing MMA B fragments. Performance: MN-major B improves from ~220T to ~290T (dense) and ~330T (M-grouped), a 30-50% gain. Remaining gap vs K-major (400T) is due to heuristic selecting BLOCK_N=32/64 vs 128 (single swizzle atom constraint). Verified by micro benchmark b_bf16_4: fragment layout 32x32 lanes PASS, MMA pipeline 4 K-steps accumulation PASS.
Remove single-atom BLOCK_N constraint for MN-major B. ldmatrix.trans correctly handles multi-atom SMEM (verified by micro benchmark b_bf16_5). MN-major B now achieves 99-102% of K-major performance (was 53% with scalar loads, 80% with single-atom ldmatrix.trans).
…bhr->hdr New BF16 bmk,bnk->mn reduction kernel with split-S atomicAdd to FP32 output (188T peak, HBM BW limited). FP8 einsum dispatch for bhd,hdr->bhr and bhd,bhr->hdr via .contiguous() to K-major. Fix batched epilogue stride formula: replace single stride_d with stride_cd_m + stride_cd_batch to support arbitrary D layouts ([batch,M,N] vs [M,batch,N]). Add kBKMajor template parameter to FP8 kernel with verified scalar-load MN-major B path (correct but 3x slower than K-major ldmatrix, kept for future optimization).
K-grouped TN: single .t().contiguous() transpose with constant-stride TMA (kKGroupedConstantStride) — per-group only replaces addr+dim, not stride. TN achieves 99-101% of NT performance. New PTX tensor_map_replace_global_dim_in_smem. Paged MQA varlen: fix 4 kernel bugs in sm120_fp8_paged_mqa_logits.cuh: - TMA Q coordinate: use atom_to_token_idx() instead of hardcoded *kNextNAtom - Prefetch advance: use get_atom_advance() instead of hardcoded +1 - Math loop: conditional iteration count via is_paired_atom for unpaired atoms - KV block idx: reset kv_block_idx_ptr=32 on q_atom change
New kernels using mma.sync m16n8k64 block-scaled FP4 (mxf4nvf4, scale_vec::2X). Architecture: 8 math warps + 4 TMA warps, B64 swizzle, kKSteps=2. Block-scaled MMA folds UE8M0 SF into computation — no post-MMA scale. Dense FP4 MQA: 1022 TFLOPS peak (63% FP4 peak), 1.6x vs FP8. Paged FP4 MQA: 707 TFLOPS peak (varlen), 566 TFLOPS (non-varlen next_n=2).
|
Your more generalized fix is the right call. You certainly have the holistic POV. When I'm trying to be helpful in debugging, I try to be targeted instead of taking a flamethrower to other people's work lol, because there are probably structural assumptions I'm making that I don't even realize. Like the SMEM sized per atom: I wasn't thinking from that perspective (and it was also late at night for me lol). I'll put this through its paces though. Thanks for responding! |
Brings in 9 upstream nv_dev commits: paged-MQA metadata OOB fix (deepseek-ai#353), IMA guard in the paged-MQA scheduler (deepseek-ai#342), the FP16-weights FP8 MQA logits kernel, the mega-MoE refactor (deepseek-ai#347), and assorted updates. Conflicts resolved (3 files): - csrc/apis/attention.hpp (dense MQA dispatch): union of both sides — keep SM120 FP4 (sm120_fp4_mqa_logits) and SM120 FP8 (arch 12) arms, and add nv_dev's FP16-weights arm (sm100_fp8_mqa_logits_f16_weights, SM100-only, already guarded by the arch==10 host assert). - deep_gemm/include/deep_gemm/scheduler/paged_mqa_logits.cuh: take nv_dev's refactored scheduler. SM120's one-line OOB guard is subsumed by nv_dev's reversed-metadata allocation (the deepseek-ai#342/deepseek-ai#353 fix). The refactor replaces get_num_kv/get_atom_advance with refresh_num_kv_and_advance + get_last_advance. - tests/test_attention.py: keep SM120 small-S split-KV MQA shapes and add the upstream large-KV/32-head regime; next_n = (1..6) on SM120, (1,2,4,5,6) on SM100, (1,2,4) on SM90; take nv_dev's paged test body (memory caps + peak-mem reset). Semantic fixes required by the merge (no textual conflict, would not compile/run otherwise): - sm120_fp8/fp4_paged_mqa_logits.cuh: migrate the scheduler call get_atom_advance(next_q_idx, batch_size) -> get_last_advance() to match nv_dev's refactored scheduler API (same migration nv_dev applied to the SM100 paged kernel). - tests/test_attention.py: gate the FP16-weights path to arch 10. Upstream set FP16 weights unconditionally, which dispatches to the SM100-only kernel and trips the arch==10 host assert when test_attention runs on SM90/SM120. Validated on H200 (sm90): clean build; test_bf16, test_fp8_fp4, test_attention all pass. SM120 (arch 12) kernels are JIT and arch-gated, so they are neither compiled nor run on H200 — they require an RTX 6000 run before the final PR.
The FP8/FP4 1D1D epilogue has two store paths: a swizzled TMA-store path
(kUseTMAStoreEpilogue, writes via tensor_map_cd) and a strided global-store
path (honors stride_cd_n). The TMA path's tensor map assumes a contiguous-N
output and ignores stride_cd_n, so it cannot express the AB-swap path's
transposed output (stride_cd_n = original-N stride != 1).
kUseTMAStoreEpilogue depends only on swizzle_cd_mode and BLOCK_N, where
swizzle_cd_mode = 128 once BLOCK_N * sizeof(cd) >= 128. The dense GEMM swap
caps M at 16 (BLOCK_N = 16 -> swizzle 0 -> strided store, correct). The BMM
swap (fp8_einsum) caps M at 32: with M in 17..32 and an FP32 output,
BLOCK_N = 32 -> swizzle_cd_mode = 128 -> the TMA-store path runs for a
transposed output and writes out of bounds.
Repro (RTX PRO 6000): fp8_einsum('bhr,hdr->bhd') with batch b = 32 and an FP32
output raises CUDA illegal memory access; b <= 16, or BF16 output (swizzle
stays 0), are unaffected. Not covered by the test suite, which uses BF16
output for the swapping expressions.
Fix: make the invariant explicit. Add GemmDesc::cd_n_contiguous; the SM120
swizzle_cd_mode heuristic requires it, and the dense + BMM swap launchers set
it false, so the swap always uses the strided-store epilogue. The swap path is
small-M and latency-bound, so the TMA-store epilogue is not a loss there.
Validated on RTX PRO 6000 (arch 12): the repro now passes (b=32 FP32
diff 7e-4); test_split_k_swap, test_einsum, test_fp8_fp4 still pass.
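The failure condition and the fix can be sketched as a small model of the heuristic. This is a simplification: it keeps only the two swizzle outcomes named in this message (0 and 128) and treats "TMA-store epilogue" as "swizzle is non-zero". The function names are hypothetical; only `cd_n_contiguous` comes from the commit.

```python
def swizzle_cd_mode(block_n: int, cd_elem_bytes: int,
                    cd_n_contiguous: bool = True) -> int:
    """Simplified model: 128 once BLOCK_N * sizeof(cd) >= 128, else 0.

    With the fix, a non-contiguous-N output (the AB-swap case) always
    gets swizzle 0 so the strided-store epilogue is used.
    """
    if not cd_n_contiguous:
        return 0
    return 128 if block_n * cd_elem_bytes >= 128 else 0


def uses_tma_store(block_n: int, cd_elem_bytes: int,
                   cd_n_contiguous: bool = True) -> bool:
    # Assumption for this sketch: TMA store is chosen iff swizzle != 0.
    return swizzle_cd_mode(block_n, cd_elem_bytes, cd_n_contiguous) != 0


if __name__ == "__main__":
    FP32, BF16 = 4, 2
    print(uses_tma_store(16, FP32))  # dense swap, M capped at 16
    print(uses_tma_store(32, FP32))  # BMM swap, M in 17..32: the bug
    print(uses_tma_store(32, BF16))  # BF16 output stays on strided store
    print(uses_tma_store(32, FP32, cd_n_contiguous=False))  # after the fix
```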
…r changes

The SM120 MQA-logits work modified the shared launchers in two ways that broke SM100 (and were latent on SM90), surfacing as test_attention failures on B300 (baseline=origin/nv_dev passes, sm120 HEAD fails). No SM100/SM90 device kernel (.cuh) changed, so this is purely a host launch-config regression.

1. Dense MQA SMEM under-allocation (smxx_fp8_fp4_mqa_logits.hpp). The SMEM size dropped the `(num_math_threads / 128) * 2` mbarrier pairs that the SM90/SM100 kernels still allocate (the SM120 kernel does not). On SM100 this under-sized the dynamic SMEM; compute-sanitizer reports an "Invalid __shared__ write" in sm100_fp8_mqa_logits, faulting (CUDA_ERROR_ILLEGAL_ADDRESS) at large configs such as seq_len_kv=130560. Restore the term for arch != 12.

2. Paged-MQA metadata capacity assert (smxx_fp8_fp4_paged_mqa_logits.hpp). An unconditional `smem_size <= SM120ArchSpec::smem_capacity` (99 KB) was added to the metadata launcher, which runs on every arch. On SM90/SM100 (228 KB) a large (varlen) batch's metadata SMEM legitimately exceeds 99 KB and falsely tripped the assert. Gate the capacity check to the running arch.

Verified on B300 (sm100): compute-sanitizer memcheck clean on the previously out-of-bounds config (seq_len=510, seq_len_kv=130560); full test_attention (dense + paged, NextN 1-6, FP8/FP4) passes. Fix (1) also corrects a latent SMEM under-allocation on SM90.
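Both host-side fixes amount to making a launch-config value depend on the running architecture. A minimal sketch, with hypothetical function names and with the quoted capacities interpreted as KiB (an assumption), could look like this:

```python
# Shared-memory capacity per arch major, from the figures quoted above.
# Assumption: "KB" in the message means KiB (1024 bytes).
SMEM_CAPACITY_BYTES = {
    9: 228 * 1024,   # SM90
    10: 228 * 1024,  # SM100
    12: 99 * 1024,   # SM120
}


def extra_mbarrier_pairs(num_math_threads: int, arch_major: int) -> int:
    """Fix 1: the (num_math_threads / 128) * 2 term applies only when
    arch != 12; the SM120 kernel does not allocate these barriers."""
    if arch_major == 12:
        return 0
    return (num_math_threads // 128) * 2


def fits_metadata_smem(smem_size: int, arch_major: int) -> bool:
    """Fix 2: compare against the capacity of the running arch,
    not unconditionally against the SM120 limit."""
    return smem_size <= SMEM_CAPACITY_BYTES[arch_major]


if __name__ == "__main__":
    big_batch = 120 * 1024  # hypothetical metadata size above 99 KiB
    print(fits_metadata_smem(big_batch, arch_major=10))
    print(fits_metadata_smem(big_batch, arch_major=12))
    print(extra_mbarrier_pairs(512, arch_major=10),
          extra_mbarrier_pairs(512, arch_major=12))
```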
Strip changes that should not ship upstream and align comment style with the existing codebase (terse, no decorative banners). No change to production GEMM/attention behavior: this only removes a dev-only bench path and trims comments.

- .gitignore: revert to upstream; personal ignores (docs_internal/, benchmarks/, ncu_reports/, internal/) moved to local .git/info/exclude.
- Remove dev-only sm120_fp8_gemm_bench API and its override_layout tile-override param (gemm.hpp, __init__.py, sm120_fp8_fp4_gemm_1d1d.hpp); config selection collapses to the standard get_best_config path.
- Trim changelog-/rationale-style block comments to single lines (shared rationale lives in commit history): MQA-logits SMEM, paged-MQA capacity assert, cd_n_contiguous, split-K SF alignment, latency model, einsum MN-major notes.
- De-duplicate the AB-swap "transposed output" rationale (kept once in config.hpp) and the UE8M0 SF-tile note.
- Remove decorative // ==== banner comments from sm120 .cuh (no other kernel uses them).
- Remove an unused num_waves local in the SM120 latency model (clears a -Wunused-variable warning).
- Remove tests/test_split_k_swap.py (standalone dev harness, not pytest-style).
|
@RayWang96 I have cleaned up the code and verified it on sm90, sm100, and sm120; the PR is ready to merge. Thanks! |
…pGEMM@sm120)

Enable DeepGEMM grouped FP8×FP4 GEMM for MoE on SM120 (RTX 6000D/PRO 6000). Requires leavelet/DeepGEMM@sm120 branch (deepseek-ai/DeepGEMM#324, not yet merged).

Changes:
- configurer.py: Allow SM120 only when SM120-compatible DeepGEMM is installed (checks for m_grouped_fp8_fp4_gemm_nt_contiguous availability)
- server_args.py: Auto-select deep_gemm MoE backend on SM120
- kernels.py: Add UE8M0 (power-of-2) FP8 quantization Triton kernel required by DeepGEMM's block-scaled dequantization on SM120
- deep_gemm.py: SM120 adaptations for DeepGEMM MoE runner:
  - TMA-aligned scale factors for grouped GEMM (hidden_states + down_input)
  - JIT EP activation fallback when hidden_dim/8 < num_experts (TP>=2)
  - In-place swiglu clamp replacing torch.chunk+cat (-7.4ms/step decode)
- fp8.py: Add .contiguous() before transform_sf_into_required_layout
- fp8_utils.py: Skip DeepGEMM dense FP8 linear on SM120 (bf16_gemm_nt unsupported)

Performance (TP=4, BS=1, RTX 6000D 85GB, ISL=8K):
- TTFT: 130ms (vs 400ms marlin, 3x faster)
- Decode ITL: 47ms (vs 41ms marlin, 15% slower)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
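The UE8M0 (power-of-2) scale mentioned for kernels.py can be illustrated in plain Python. This is a sketch under stated assumptions, not the Triton kernel from this commit: it assumes the scale is rounded up to the next power of two so that scaled values stay within the FP8 E4M3 range (maximum finite magnitude 448), it assumes all-zero blocks get a unit scale, and the function names are hypothetical.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn magnitude


def ceil_pow2_exponent(x: float) -> int:
    """Smallest integer e with 2**e >= x, for x > 0 (exact, via frexp)."""
    mantissa, exp = math.frexp(x)  # x == mantissa * 2**exp, 0.5 <= mantissa < 1
    return exp - 1 if mantissa == 0.5 else exp


def ue8m0_scale(block: list[float]) -> float:
    """Power-of-two scale such that every value / scale fits in FP8 E4M3."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 1.0  # assumption: all-zero blocks use a unit scale
    return 2.0 ** ceil_pow2_exponent(amax / FP8_E4M3_MAX)


def quantize_block(block: list[float]) -> tuple[list[float], float]:
    """Return (scaled values, scale) for one quantization block."""
    scale = ue8m0_scale(block)
    # A real kernel would cast to FP8 here; this sketch keeps Python floats.
    return [v / scale for v in block], scale


if __name__ == "__main__":
    scaled, scale = quantize_block([1000.0, -3.0, 0.5])
    print(scale, max(abs(v) for v in scaled))
```

Because the scale is an exact power of two, it can be stored as an 8-bit exponent with no mantissa, and dequantization only shifts the exponent of each value.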
Added sm120 support for DeepGEMM, with performance on par with cuBLAS and outperforming the CUTLASS version in most cases.
sm120 support will be maintained in the `nv_dev` branch by NVIDIA DevTech APAC.

Performance metrics:
Dense GEMM
Grouped GEMM (FP8 / FP4)
Einsum (Batched GEMM)
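For orientation, the first einsum pattern in this section, `bhr,hdr->bhd`, can be written as a plain-Python reference. This is a readability sketch with a hypothetical function name, not the kernel's implementation; it uses nested lists and ignores FP8 storage and block scaling entirely.

```python
def einsum_bhr_hdr_to_bhd(x, w):
    """Compute out[b][h][d] = sum over r of x[b][h][r] * w[h][d][r]."""
    out = []
    for xb in x:                   # b: batch / token index
        rows = []
        for xh, wh in zip(xb, w):  # h: head / group index, shared by x and w
            # One dot product over r for every output column d.
            rows.append([sum(a * c for a, c in zip(xh, wd)) for wd in wh])
        out.append(rows)
    return out


if __name__ == "__main__":
    x = [[[1, 2], [3, 4]]]                      # shape (b=1, h=2, r=2)
    w = [[[1, 0], [0, 1], [1, 1]],
         [[2, 0], [0, 2], [1, -1]]]             # shape (h=2, d=3, r=2)
    print(einsum_bhr_hdr_to_bhd(x, w))
```

Each head `h` has its own weight matrix, so the operation is a batch of independent matrix-vector products rather than one large GEMM.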
Einsum patterns covered: `bhr,hdr->bhd`, `bhd,hdr->bhr`, `bmk,bnk->mn`
`ldmatrix.trans.x2` + multi-atom `BLOCK_N=128` optimization.
HC Prenorm (Fused GEMM + sqr_sum, TF32)
MQA Logits (Attention, FP8)
Co-authored by: @lucifer1004