Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 206 additions & 0 deletions scripts/test_sglang_baselines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
"""
Test all 50 SGLang baseline operators.
Verifies: import OK, function exists, can be called with minimal inputs.
"""
import sys
import traceback
import time

sys.path.insert(0, "/share/project/zpy/flagbench/src")

results = {"pass": [], "fail": [], "skip": []}

# Minimal test inputs for each operator
TESTS = {
# activation.py
"silu_and_mul": lambda b: b.silu_and_mul(torch.randn(4, 512, device='cuda')),
"gelu_and_mul": lambda b: b.gelu_and_mul(torch.randn(4, 512, device='cuda')),
"quick_gelu": lambda b: b.quick_gelu(torch.randn(4, 512, device='cuda')),
"new_gelu": lambda b: b.new_gelu(torch.randn(4, 512, device='cuda')),
"xielu": lambda b: b.xielu(torch.randn(4, 512, device='cuda')),

# layernorm.py
"rms_norm": lambda b: b.rms_norm(torch.randn(4, 512, device='cuda'),
torch.ones(512, device='cuda')),
"layer_norm": lambda b: b.layer_norm(torch.randn(4, 512, device='cuda'), 512),
"gemma_rms_norm": lambda b: b.gemma_rms_norm(torch.randn(4, 512, device='cuda'),
torch.ones(512, device='cuda')),
"gemma3_rms_norm": lambda b: b.gemma3_rms_norm(torch.randn(4, 512, device='cuda'), 512),
"gemma4_rms_norm": lambda b: b.gemma4_rms_norm(torch.randn(4, 512, device='cuda'), 512),
"rms_norm_without_scale": lambda b: b.rms_norm_without_scale(torch.randn(4, 512, device='cuda'), 512),

# rotary_embedding/
"rotary_embedding": lambda b: b.rotary_embedding(
torch.arange(128, device='cuda'),
torch.randn(128, 4, 64, device='cuda'),
torch.randn(128, 4, 64, device='cuda'),
head_size=64, rotary_dim=64),
"mrotary_embedding": lambda b: b.mrotary_embedding(
torch.arange(128, device='cuda'),
torch.randn(128, 4, 64, device='cuda'),
torch.randn(128, 4, 64, device='cuda'),
head_size=64, rotary_dim=64, mrope_section=[10, 11, 11]),
"dual_chunk_rotary_embedding": lambda b: b.dual_chunk_rotary_embedding(
torch.arange(128, device='cuda'),
torch.randn(128, 4, 64, device='cuda'),
torch.randn(128, 4, 64, device='cuda'),
head_size=64, rotary_dim=64),
"deepseek_scaling_rotary_embedding": lambda b: b.deepseek_scaling_rotary_embedding(
torch.arange(128, device='cuda'),
torch.randn(128, 4, 64, device='cuda'),
torch.randn(128, 4, 64, device='cuda'),
head_size=64, rotary_dim=64),
"llama3_rotary_embedding": lambda b: b.llama3_rotary_embedding(
torch.arange(128, device='cuda'),
torch.randn(128, 4, 64, device='cuda'),
torch.randn(128, 4, 64, device='cuda'),
head_size=64, rotary_dim=64),
"dynamic_ntk_scaling_rotary_embedding": lambda b: b.dynamic_ntk_scaling_rotary_embedding(
torch.arange(128, device='cuda'),
torch.randn(128, 4, 64, device='cuda'),
torch.randn(128, 4, 64, device='cuda'),
head_size=64, rotary_dim=64),
"linear_scaling_rotary_embedding": lambda b: b.linear_scaling_rotary_embedding(
torch.arange(128, device='cuda'),
torch.randn(128, 4, 64, device='cuda'),
torch.randn(128, 4, 64, device='cuda'),
head_size=64, rotary_dim=64),
"phi3_long_rope_scaled_rotary_embedding": lambda b: b.phi3_long_rope_scaled_rotary_embedding(
torch.arange(128, device='cuda'),
torch.randn(128, 4, 64, device='cuda'),
torch.randn(128, 4, 64, device='cuda'),
head_size=64, rotary_dim=64),
"triton_mrope_fused": lambda b: b.triton_mrope_fused(
torch.arange(128, device='cuda'),
torch.randn(128, 4 * 64, device='cuda'),
torch.randn(128, 4 * 64, device='cuda'),
torch.randn(8192, 128, device='cuda'), head_size=128, rotary_dim=128,
mrope_section=[42, 42, 44]),
"triton_ernie45_rope_fused": lambda b: b.triton_ernie45_rope_fused(
torch.arange(128, device='cuda'),
torch.randn(128, 4 * 64, device='cuda'),
torch.randn(128, 4 * 64, device='cuda'),
torch.randn(8192, 128, device='cuda'), head_size=128, rotary_dim=128,
mrope_section=[42, 42, 44]),
"apply_interleaved_rope_triton": lambda b: b.apply_interleaved_rope_triton(
torch.randn(3, 128, 64, device='cuda'), mrope_section=[10, 11, 11]),
"dynamic_ntk_alpha_rotary_embedding": lambda b: b.dynamic_ntk_alpha_rotary_embedding(
torch.arange(128, device='cuda'),
torch.randn(128, 4, 64, device='cuda'),
torch.randn(128, 4, 64, device='cuda'),
head_size=64, rotary_dim=64, scaling_alpha=2.0),

# moe/
"fused_moe": lambda b: None, # needs routing_data — complex, skip for now
"topk": lambda b: None, # needs full module init
"moe_align_block_size": lambda b: b.moe_align_block_size(
torch.randint(0, 4, (128, 2), device='cuda', dtype=torch.int32),
num_experts=4, block_size=32),

# attention/fla/
"l2norm": lambda b: b.l2norm(torch.randn(4, 512, device='cuda')),
"rms_norm_gated": lambda b: b.rms_norm_gated(
torch.randn(4, 512, device='cuda'), torch.randn(512, device='cuda')),
"fused_recurrent_gated_delta_rule": lambda b: b.fused_recurrent_gated_delta_rule(
q=torch.randn(1, 64, 4, 64, device='cuda'),
k=torch.randn(1, 64, 4, 64, device='cuda'),
v=torch.randn(1, 64, 4, 64, device='cuda'),
g=torch.randn(1, 64, 4, device='cuda'),
beta=torch.randn(1, 64, 4, device='cuda'),
scale=0.125),
"fused_recurrent_gated_delta_rule_update": lambda b: None, # complex
"fused_sigmoid_gating_delta_rule_update": lambda b: None, # complex
"fused_sigmoid_gating_delta_rule_packed_decode": lambda b: None, # complex
"fused_gdn_gating": lambda b: b.fused_gdn_gating(
A_log=torch.randn(16, device='cuda'),
a=torch.randn(4, 16, device='cuda'),
b=torch.randn(4, 16, device='cuda'),
dt_bias=torch.randn(16, device='cuda')),
"layer_norm_gated_fwd": lambda b: b.layer_norm_gated_fwd(
torch.randn(4, 512, device='cuda'),
torch.randn(4, 512, device='cuda'),
torch.ones(512, device='cuda'), torch.zeros(512, device='cuda')),

# attention/mamba/
"causal_conv1d_fn": lambda b: b.causal_conv1d_fn(
torch.randn(8, 64, device='cuda'),
torch.randn(64, 4, device='cuda')),
"causal_conv1d_update": lambda b: b.causal_conv1d_update(
torch.randn(4, 64, device='cuda'),
torch.randn(4, 64, 3, device='cuda'),
torch.randn(64, 4, device='cuda')),
"selective_scan_update": lambda b: b.selective_scan_update(
state=torch.randn(4, 64, 16, device='cuda'),
x=torch.randn(4, 64, device='cuda'),
dt=torch.randn(4, 64, device='cuda'),
A=torch.randn(64, 16, device='cuda'),
B=torch.randn(4, 16, device='cuda'),
C=torch.randn(4, 16, device='cuda')),
"mamba_chunk_scan_combined_fwd": lambda b: b.mamba_chunk_scan_combined_fwd(
torch.randn(1, 64, 4, 64, device='cuda'),
torch.randn(1, 64, 4, device='cuda'),
torch.randn(4, device='cuda'),
torch.randn(1, 64, 1, 16, device='cuda'),
torch.randn(1, 64, 1, 16, device='cuda')),
"mixer2_rms_norm_gated": lambda b: b.mixer2_rms_norm_gated(
torch.randn(4, 512, device='cuda'), torch.randn(4, 512, device='cuda'), 512),

# elementwise.py
"fused_dual_residual_rmsnorm": lambda b: b.fused_dual_residual_rmsnorm(
torch.randn(4, 512, device='cuda'), torch.randn(4, 512, device='cuda'), 512, 512),
"softcap": lambda b: b.softcap(torch.randn(4, 512, device='cuda')),
"silu_and_mul_triton": lambda b: b.silu_and_mul_triton(torch.randn(4, 1024, device='cuda')),
"gelu_and_mul_triton": lambda b: b.gelu_and_mul_triton(torch.randn(4, 1024, device='cuda')),
"fused_rmsnorm": lambda b: b.fused_rmsnorm(
torch.randn(4, 512, device='cuda'), torch.ones(512, device='cuda')),
"experts_combine_triton": lambda b: b.experts_combine_triton(
torch.randn(4, 2, 512, device='cuda'), torch.randn(4, 512, device='cuda')),

# gemma4_fused_ops.py
"gemma_rmsnorm_residual_scalar": lambda b: b.gemma_rmsnorm_residual_scalar(
torch.randn(4, 512, device='cuda'), torch.randn(512, device='cuda'),
torch.randn(4, 512, device='cuda'), torch.tensor(0.5, device='cuda')),
"gemma_qkv_rmsnorm": lambda b: b.gemma_qkv_rmsnorm(
torch.randn(4, 256, device='cuda'), torch.randn(4, 128, device='cuda'),
torch.randn(4, 128, device='cuda'), torch.randn(64, device='cuda'),
torch.randn(64, device='cuda'), num_q_heads=4, num_kv_heads=2, head_dim=64),

# conv.py
"conv2d_layer": lambda b: b.conv2d_layer(
torch.randn(4, 3, 32, 32, device='cuda'), 3, 16, 3),
"conv3d_layer": lambda b: b.conv3d_layer(
torch.randn(4, 3, 16, 16, 16, device='cuda'), 3, 16, 3),

# quantization/
"per_token_quant_int8": lambda b: b.per_token_quant_int8(
torch.randn(4, 512, device='cuda')),
}


def main():
print(f"Testing {len(TESTS)} operators (import only, no GPU available)...\n")

for name in sorted(TESTS.keys()):
try:
# Import the baseline module
mod = __import__(f"kernelgenbench.dataset.baseline.sglang.{name}", fromlist=[name])
fn = getattr(mod, name)
print(f" {name}: import OK, signature={fn.__code__.co_varnames[:fn.__code__.co_argcount]}")
results["pass"].append(name)

except Exception as e:
err = str(e).split('\n')[0][:120]
print(f" {name}: FAIL - {err}")
results["fail"].append(name)

print(f"\n=== Results ===")
print(f"PASS: {len(results['pass'])}")
print(f"FAIL: {len(results['fail'])}")
if results['fail']:
print(f"\nFailed operators:")
for name in results['fail']:
print(f" - {name}")


if __name__ == "__main__":
main()
51 changes: 51 additions & 0 deletions src/kernelgenbench/accuracy/sglang/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# SGLang accuracy tests
from .test_silu_and_mul import test_accuracy_silu_and_mul
from .test_gelu_and_mul import test_accuracy_gelu_and_mul
from .test_quick_gelu import test_accuracy_quick_gelu
from .test_new_gelu import test_accuracy_new_gelu
from .test_xielu import test_accuracy_xielu
from .test_l2norm import test_accuracy_l2norm
from .test_softcap import test_accuracy_softcap
from .test_silu_and_mul_triton import test_accuracy_silu_and_mul_triton
from .test_gelu_and_mul_triton import test_accuracy_gelu_and_mul_triton
from .test_fused_rmsnorm import test_accuracy_fused_rmsnorm
from .test_rms_norm import test_accuracy_rms_norm
from .test_gemma_rms_norm import test_accuracy_gemma_rms_norm
from .test_rms_norm_gated import test_accuracy_rms_norm_gated
from .test_rotary_embedding import test_accuracy_rotary_embedding
from .test_layer_norm import test_accuracy_layer_norm
from .test_gemma3_rms_norm import test_accuracy_gemma3_rms_norm
from .test_gemma4_rms_norm import test_accuracy_gemma4_rms_norm
from .test_rms_norm_without_scale import test_accuracy_rms_norm_without_scale
from .test_fused_dual_residual_rmsnorm import test_accuracy_fused_dual_residual_rmsnorm
from .test_mrotary_embedding import test_accuracy_mrotary_embedding
from .test_dual_chunk_rotary_embedding import test_accuracy_dual_chunk_rotary_embedding
from .test_deepseek_scaling_rotary_embedding import test_accuracy_deepseek_scaling_rotary_embedding
from .test_llama3_rotary_embedding import test_accuracy_llama3_rotary_embedding
from .test_dynamic_ntk_scaling_rotary_embedding import test_accuracy_dynamic_ntk_scaling_rotary_embedding
from .test_linear_scaling_rotary_embedding import test_accuracy_linear_scaling_rotary_embedding
from .test_phi3_long_rope_scaled_rotary_embedding import test_accuracy_phi3_long_rope_scaled_rotary_embedding
from .test_triton_mrope_fused import test_accuracy_triton_mrope_fused
from .test_triton_ernie45_rope_fused import test_accuracy_triton_ernie45_rope_fused
from .test_fused_moe import test_accuracy_fused_moe
from .test_topk import test_accuracy_topk
from .test_moe_align_block_size import test_accuracy_moe_align_block_size
from .test_apply_interleaved_rope_triton import test_accuracy_apply_interleaved_rope_triton
from .test_dynamic_ntk_alpha_rotary_embedding import test_accuracy_dynamic_ntk_alpha_rotary_embedding
from .test_fused_recurrent_gated_delta_rule import test_accuracy_fused_recurrent_gated_delta_rule
from .test_fused_recurrent_gated_delta_rule_update import test_accuracy_fused_recurrent_gated_delta_rule_update
from .test_fused_sigmoid_gating_delta_rule_update import test_accuracy_fused_sigmoid_gating_delta_rule_update
from .test_fused_sigmoid_gating_delta_rule_packed_decode import test_accuracy_fused_sigmoid_gating_delta_rule_packed_decode
from .test_fused_gdn_gating import test_accuracy_fused_gdn_gating
from .test_layer_norm_gated_fwd import test_accuracy_layer_norm_gated_fwd
from .test_causal_conv1d_fn import test_accuracy_causal_conv1d_fn
from .test_causal_conv1d_update import test_accuracy_causal_conv1d_update
from .test_selective_scan_update import test_accuracy_selective_scan_update
from .test_mamba_chunk_scan_combined_fwd import test_accuracy_mamba_chunk_scan_combined_fwd
from .test_experts_combine_triton import test_accuracy_experts_combine_triton
from .test_gemma_rmsnorm_residual_scalar import test_accuracy_gemma_rmsnorm_residual_scalar
from .test_gemma_qkv_rmsnorm import test_accuracy_gemma_qkv_rmsnorm
from .test_conv2d_layer import test_accuracy_conv2d_layer
from .test_conv3d_layer import test_accuracy_conv3d_layer
from .test_mixer2_rms_norm_gated import test_accuracy_mixer2_rms_norm_gated
from .test_per_token_quant_int8 import test_accuracy_per_token_quant_int8
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""
Accuracy and benchmark test for SGLang apply_interleaved_rope_triton.
Source: sglang.srt.layers.rotary_embedding.mrope.apply_interleaved_rope_triton(x [3,N,D], mrope_section) -> [N,D]
"""
import kernelgenbench
from sandbox.config import DEVICE as device
from sandbox.verifier.test_parametrize import parametrize, label
from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close
from sandbox.utils.accuracy_utils import CustomBenchmarkResult
import torch
import triton

@label("apply_interleaved_rope_triton")
@parametrize("seq_len", [128, 512, 1024])
@parametrize("rotary_dim", [64, 128])
@parametrize("dtype", [torch.float16, torch.bfloat16])
@parametrize("mrope_section", [[8, 12, 12], [24, 20, 20]])
def test_accuracy_apply_interleaved_rope_triton(seq_len, rotary_dim, dtype, mrope_section):
N, D = seq_len, rotary_dim
x = torch.randn(3, N, D, device='cuda', dtype=dtype)

ref_out = kernelgenbench.baseline.apply_interleaved_rope_triton(x, mrope_section)
act_out = kernelgenbench.baseline.apply_interleaved_rope_triton(x.clone(), mrope_section)

assert_close(act_out, ref_out, dtype)

if seq_len < 512:
return None

x_bench = torch.randn(3, N, D, device='cuda', dtype=dtype)
ms_baseline = triton.testing.do_bench(
lambda: kernelgenbench.baseline.apply_interleaved_rope_triton(x_bench, mrope_section),
warmup=25, rep=100
)
speedup = 1.0
return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup)
30 changes: 30 additions & 0 deletions src/kernelgenbench/accuracy/sglang/test_causal_conv1d_fn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""
Accuracy and benchmark test for SGLang causal_conv1d_fn.
Source: causal_conv1d_fn(x [B,T,D], weight [D,width], bias, query_start_loc, ...) (sgl_kernel)
"""
import kernelgenbench
from sandbox.config import DEVICE as device
from sandbox.verifier.test_parametrize import parametrize, label
from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close
from sandbox.utils.accuracy_utils import CustomBenchmarkResult
import torch
import triton

@label("causal_conv1d_fn")
@parametrize("shape", [(64, 64), (256, 128)])
@parametrize("dtype", [torch.float16, torch.bfloat16])
def test_accuracy_causal_conv1d_fn(shape, dtype):
M, N = shape
x = torch.randn(M, N, device='cuda', dtype=dtype)
weight = torch.randn(N, 4, device='cuda', dtype=dtype)
ref_out = kernelgenbench.baseline.causal_conv1d_fn(x, weight, bias=None, query_start_loc=None, cache_indices=None, has_initial_state=None, conv_states=None, activation='silu')
act_out = kernelgenbench.baseline.causal_conv1d_fn(x.clone(), weight, bias=None, query_start_loc=None, cache_indices=None, has_initial_state=None, conv_states=None, activation='silu')
assert_close(act_out, ref_out, dtype)
if M < 256:
return None
ms_baseline = triton.testing.do_bench(
lambda: kernelgenbench.baseline.causal_conv1d_fn(x, weight, bias=None, query_start_loc=None, cache_indices=None, has_initial_state=None, conv_states=None, activation='silu'),
warmup=25, rep=100
)
speedup = 1.0
return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup)
31 changes: 31 additions & 0 deletions src/kernelgenbench/accuracy/sglang/test_causal_conv1d_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
Accuracy and benchmark test for SGLang causal_conv1d_update.
Source: causal_conv1d_update(x [B,D], conv_state [B,D,width-1], weight [D,width], bias, activation) (sgl_kernel)
"""
import kernelgenbench
from sandbox.config import DEVICE as device
from sandbox.verifier.test_parametrize import parametrize, label
from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close
from sandbox.utils.accuracy_utils import CustomBenchmarkResult
import torch
import triton

@label("causal_conv1d_update")
@parametrize("shape", [(64, 64), (256, 128)])
@parametrize("dtype", [torch.float16, torch.bfloat16])
def test_accuracy_causal_conv1d_update(shape, dtype):
M, N = shape
x = torch.randn(M, N, device='cuda', dtype=dtype)
weight = torch.randn(N, 4, device='cuda', dtype=dtype)
conv_state = torch.randn(M, N, 3, device='cuda', dtype=dtype)
ref_out, ref_conv_state = kernelgenbench.baseline.causal_conv1d_update(x, conv_state, weight, bias=None, activation='silu')
act_out, act_conv_state = kernelgenbench.baseline.causal_conv1d_update(x.clone(), conv_state.clone(), weight, bias=None, activation='silu')
assert_close(act_out, ref_out, dtype)
if M < 256:
return None
ms_baseline = triton.testing.do_bench(
lambda: kernelgenbench.baseline.causal_conv1d_update(x, conv_state, weight, bias=None, activation='silu'),
warmup=25, rep=100
)
speedup = 1.0
return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup)
31 changes: 31 additions & 0 deletions src/kernelgenbench/accuracy/sglang/test_conv2d_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
Accuracy and benchmark test for SGLang conv2d_layer.
Source: Conv2dLayer(in_channels, out_channels, kernel_size, stride, padding, bias).forward_cuda(x [N,C,H,W])
"""
import kernelgenbench
from sandbox.config import DEVICE as device
from sandbox.verifier.test_parametrize import parametrize, label
from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close
from sandbox.utils.accuracy_utils import CustomBenchmarkResult
import torch
import triton

@label("conv2d_layer")
@parametrize("shape", [(64, 64), (256, 128)])
@parametrize("dtype", [torch.float16, torch.bfloat16])
def test_accuracy_conv2d_layer(shape, dtype):
M, N = shape
in_c, out_c, ks = 3, N, 3
x = torch.randn(M, in_c, 32, 32, device='cuda', dtype=dtype)
ref_out = kernelgenbench.baseline.conv2d_layer(x, in_channels=in_c, out_channels=out_c, kernel_size=ks, stride=1, padding=1, bias=False)
act_out = kernelgenbench.baseline.conv2d_layer(x.clone(), in_channels=in_c, out_channels=out_c, kernel_size=ks, stride=1, padding=1, bias=False)
assert_close(act_out, ref_out, dtype)
if M < 256:
return None
x_b = torch.randn(M, in_c, 32, 32, device='cuda', dtype=dtype)
ms_baseline = triton.testing.do_bench(
lambda: kernelgenbench.baseline.conv2d_layer(x_b, in_channels=in_c, out_channels=out_c, kernel_size=ks, stride=1, padding=1, bias=False),
warmup=25, rep=100
)
speedup = 1.0
return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup)
Loading
Loading