diff --git a/scripts/test_sglang_baselines.py b/scripts/test_sglang_baselines.py new file mode 100644 index 0000000..bfc87f2 --- /dev/null +++ b/scripts/test_sglang_baselines.py @@ -0,0 +1,206 @@ +""" +Test all 50 SGLang baseline operators. +Verifies: import OK, function exists, can be called with minimal inputs. +""" +import sys +import traceback +import time + +sys.path.insert(0, "/share/project/zpy/flagbench/src") + +results = {"pass": [], "fail": [], "skip": []} + +# Minimal test inputs for each operator +TESTS = { + # activation.py + "silu_and_mul": lambda b: b.silu_and_mul(torch.randn(4, 512, device='cuda')), + "gelu_and_mul": lambda b: b.gelu_and_mul(torch.randn(4, 512, device='cuda')), + "quick_gelu": lambda b: b.quick_gelu(torch.randn(4, 512, device='cuda')), + "new_gelu": lambda b: b.new_gelu(torch.randn(4, 512, device='cuda')), + "xielu": lambda b: b.xielu(torch.randn(4, 512, device='cuda')), + + # layernorm.py + "rms_norm": lambda b: b.rms_norm(torch.randn(4, 512, device='cuda'), + torch.ones(512, device='cuda')), + "layer_norm": lambda b: b.layer_norm(torch.randn(4, 512, device='cuda'), 512), + "gemma_rms_norm": lambda b: b.gemma_rms_norm(torch.randn(4, 512, device='cuda'), + torch.ones(512, device='cuda')), + "gemma3_rms_norm": lambda b: b.gemma3_rms_norm(torch.randn(4, 512, device='cuda'), 512), + "gemma4_rms_norm": lambda b: b.gemma4_rms_norm(torch.randn(4, 512, device='cuda'), 512), + "rms_norm_without_scale": lambda b: b.rms_norm_without_scale(torch.randn(4, 512, device='cuda'), 512), + + # rotary_embedding/ + "rotary_embedding": lambda b: b.rotary_embedding( + torch.arange(128, device='cuda'), + torch.randn(128, 4, 64, device='cuda'), + torch.randn(128, 4, 64, device='cuda'), + head_size=64, rotary_dim=64), + "mrotary_embedding": lambda b: b.mrotary_embedding( + torch.arange(128, device='cuda'), + torch.randn(128, 4, 64, device='cuda'), + torch.randn(128, 4, 64, device='cuda'), + head_size=64, rotary_dim=64, mrope_section=[10, 11, 11]), + "dual_chunk_rotary_embedding": lambda b: b.dual_chunk_rotary_embedding( + torch.arange(128, device='cuda'), + torch.randn(128, 4, 64, device='cuda'), + torch.randn(128, 4, 64, device='cuda'), + head_size=64, rotary_dim=64), + "deepseek_scaling_rotary_embedding": lambda b: b.deepseek_scaling_rotary_embedding( + torch.arange(128, device='cuda'), + torch.randn(128, 4, 64, device='cuda'), + torch.randn(128, 4, 64, device='cuda'), + head_size=64, rotary_dim=64), + "llama3_rotary_embedding": lambda b: b.llama3_rotary_embedding( + torch.arange(128, device='cuda'), + torch.randn(128, 4, 64, device='cuda'), + torch.randn(128, 4, 64, device='cuda'), + head_size=64, rotary_dim=64), + "dynamic_ntk_scaling_rotary_embedding": lambda b: b.dynamic_ntk_scaling_rotary_embedding( + torch.arange(128, device='cuda'), + torch.randn(128, 4, 64, device='cuda'), + torch.randn(128, 4, 64, device='cuda'), + head_size=64, rotary_dim=64), + "linear_scaling_rotary_embedding": lambda b: b.linear_scaling_rotary_embedding( + torch.arange(128, device='cuda'), + torch.randn(128, 4, 64, device='cuda'), + torch.randn(128, 4, 64, device='cuda'), + head_size=64, rotary_dim=64), + "phi3_long_rope_scaled_rotary_embedding": lambda b: b.phi3_long_rope_scaled_rotary_embedding( + torch.arange(128, device='cuda'), + torch.randn(128, 4, 64, device='cuda'), + torch.randn(128, 4, 64, device='cuda'), + head_size=64, rotary_dim=64), + "triton_mrope_fused": lambda b: b.triton_mrope_fused( + torch.arange(128, device='cuda'), + torch.randn(128, 4 * 64, device='cuda'), + torch.randn(128, 4 * 64, device='cuda'), + torch.randn(8192, 128, device='cuda'), head_size=128, rotary_dim=128, + mrope_section=[42, 42, 44]), + "triton_ernie45_rope_fused": lambda b: b.triton_ernie45_rope_fused( + torch.arange(128, device='cuda'), + torch.randn(128, 4 * 64, device='cuda'), + torch.randn(128, 4 * 64, device='cuda'), + torch.randn(8192, 128, device='cuda'), head_size=128, rotary_dim=128, + mrope_section=[42, 42, 44]), + "apply_interleaved_rope_triton": lambda b: b.apply_interleaved_rope_triton( + torch.randn(3, 128, 64, device='cuda'), mrope_section=[10, 11, 11]), + "dynamic_ntk_alpha_rotary_embedding": lambda b: b.dynamic_ntk_alpha_rotary_embedding( + torch.arange(128, device='cuda'), + torch.randn(128, 4, 64, device='cuda'), + torch.randn(128, 4, 64, device='cuda'), + head_size=64, rotary_dim=64, scaling_alpha=2.0), + + # moe/ + "fused_moe": lambda b: None, # needs routing_data — complex, skip for now + "topk": lambda b: None, # needs full module init + "moe_align_block_size": lambda b: b.moe_align_block_size( + torch.randint(0, 4, (128, 2), device='cuda', dtype=torch.int32), + num_experts=4, block_size=32), + + # attention/fla/ + "l2norm": lambda b: b.l2norm(torch.randn(4, 512, device='cuda')), + "rms_norm_gated": lambda b: b.rms_norm_gated( + torch.randn(4, 512, device='cuda'), torch.randn(512, device='cuda')), + "fused_recurrent_gated_delta_rule": lambda b: b.fused_recurrent_gated_delta_rule( + q=torch.randn(1, 64, 4, 64, device='cuda'), + k=torch.randn(1, 64, 4, 64, device='cuda'), + v=torch.randn(1, 64, 4, 64, device='cuda'), + g=torch.randn(1, 64, 4, device='cuda'), + beta=torch.randn(1, 64, 4, device='cuda'), + scale=0.125), + "fused_recurrent_gated_delta_rule_update": lambda b: None, # complex + "fused_sigmoid_gating_delta_rule_update": lambda b: None, # complex + "fused_sigmoid_gating_delta_rule_packed_decode": lambda b: None, # complex + "fused_gdn_gating": lambda b: b.fused_gdn_gating( + A_log=torch.randn(16, device='cuda'), + a=torch.randn(4, 16, device='cuda'), + b=torch.randn(4, 16, device='cuda'), + dt_bias=torch.randn(16, device='cuda')), + "layer_norm_gated_fwd": lambda b: b.layer_norm_gated_fwd( + torch.randn(4, 512, device='cuda'), + torch.randn(4, 512, device='cuda'), + torch.ones(512, device='cuda'), torch.zeros(512, device='cuda')), + + # attention/mamba/ + "causal_conv1d_fn": lambda b: b.causal_conv1d_fn( + torch.randn(8, 64, device='cuda'), + torch.randn(64, 4, device='cuda')), + "causal_conv1d_update": lambda b: b.causal_conv1d_update( + torch.randn(4, 64, device='cuda'), + torch.randn(4, 64, 3, device='cuda'), + torch.randn(64, 4, device='cuda')), + "selective_scan_update": lambda b: b.selective_scan_update( + state=torch.randn(4, 64, 16, device='cuda'), + x=torch.randn(4, 64, device='cuda'), + dt=torch.randn(4, 64, device='cuda'), + A=torch.randn(64, 16, device='cuda'), + B=torch.randn(4, 16, device='cuda'), + C=torch.randn(4, 16, device='cuda')), + "mamba_chunk_scan_combined_fwd": lambda b: b.mamba_chunk_scan_combined_fwd( + torch.randn(1, 64, 4, 64, device='cuda'), + torch.randn(1, 64, 4, device='cuda'), + torch.randn(4, device='cuda'), + torch.randn(1, 64, 1, 16, device='cuda'), + torch.randn(1, 64, 1, 16, device='cuda')), + "mixer2_rms_norm_gated": lambda b: b.mixer2_rms_norm_gated( + torch.randn(4, 512, device='cuda'), torch.randn(4, 512, device='cuda'), 512), + + # elementwise.py + "fused_dual_residual_rmsnorm": lambda b: b.fused_dual_residual_rmsnorm( + torch.randn(4, 512, device='cuda'), torch.randn(4, 512, device='cuda'), 512, 512), + "softcap": lambda b: b.softcap(torch.randn(4, 512, device='cuda')), + "silu_and_mul_triton": lambda b: b.silu_and_mul_triton(torch.randn(4, 1024, device='cuda')), + "gelu_and_mul_triton": lambda b: b.gelu_and_mul_triton(torch.randn(4, 1024, device='cuda')), + "fused_rmsnorm": lambda b: b.fused_rmsnorm( + torch.randn(4, 512, device='cuda'), torch.ones(512, device='cuda')), + "experts_combine_triton": lambda b: b.experts_combine_triton( + torch.randn(4, 2, 512, device='cuda'), torch.randn(4, 512, device='cuda')), + + # gemma4_fused_ops.py + "gemma_rmsnorm_residual_scalar": lambda b: b.gemma_rmsnorm_residual_scalar( + torch.randn(4, 512, device='cuda'), torch.randn(512, device='cuda'), + torch.randn(4, 512, device='cuda'), torch.tensor(0.5, device='cuda')), + "gemma_qkv_rmsnorm": lambda b: b.gemma_qkv_rmsnorm( + torch.randn(4, 256, device='cuda'), torch.randn(4, 128, device='cuda'), + torch.randn(4, 128, device='cuda'), torch.randn(64, device='cuda'), + torch.randn(64, device='cuda'), num_q_heads=4, num_kv_heads=2, head_dim=64), + + # conv.py + "conv2d_layer": lambda b: b.conv2d_layer( + torch.randn(4, 3, 32, 32, device='cuda'), 3, 16, 3), + "conv3d_layer": lambda b: b.conv3d_layer( + torch.randn(4, 3, 16, 16, 16, device='cuda'), 3, 16, 3), + + # quantization/ + "per_token_quant_int8": lambda b: b.per_token_quant_int8( + torch.randn(4, 512, device='cuda')), +} + + +def main(): + print(f"Testing {len(TESTS)} operators (import only, no GPU available)...\n") + + for name in sorted(TESTS.keys()): + try: + # Import the baseline module + mod = __import__(f"kernelgenbench.dataset.baseline.sglang.{name}", fromlist=[name]) + fn = getattr(mod, name) + print(f" {name}: import OK, signature={fn.__code__.co_varnames[:fn.__code__.co_argcount]}") + results["pass"].append(name) + + except Exception as e: + err = str(e).split('\n')[0][:120] + print(f" {name}: FAIL - {err}") + results["fail"].append(name) + + print(f"\n=== Results ===") + print(f"PASS: {len(results['pass'])}") + print(f"FAIL: {len(results['fail'])}") + if results['fail']: + print(f"\nFailed operators:") + for name in results['fail']: + print(f" - {name}") + + +if __name__ == "__main__": + main() diff --git a/src/kernelgenbench/accuracy/sglang/__init__.py b/src/kernelgenbench/accuracy/sglang/__init__.py new file mode 100644 index 0000000..e27c304 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/__init__.py @@ -0,0 +1,51 @@ +# SGLang accuracy tests +from .test_silu_and_mul import test_accuracy_silu_and_mul +from .test_gelu_and_mul import test_accuracy_gelu_and_mul +from .test_quick_gelu import test_accuracy_quick_gelu +from .test_new_gelu import test_accuracy_new_gelu +from .test_xielu import test_accuracy_xielu +from .test_l2norm import test_accuracy_l2norm +from .test_softcap import test_accuracy_softcap +from .test_silu_and_mul_triton import test_accuracy_silu_and_mul_triton +from .test_gelu_and_mul_triton import test_accuracy_gelu_and_mul_triton +from .test_fused_rmsnorm import test_accuracy_fused_rmsnorm +from .test_rms_norm import test_accuracy_rms_norm +from .test_gemma_rms_norm import test_accuracy_gemma_rms_norm +from .test_rms_norm_gated import test_accuracy_rms_norm_gated +from .test_rotary_embedding import test_accuracy_rotary_embedding +from .test_layer_norm import test_accuracy_layer_norm +from .test_gemma3_rms_norm import test_accuracy_gemma3_rms_norm +from .test_gemma4_rms_norm import test_accuracy_gemma4_rms_norm +from .test_rms_norm_without_scale import test_accuracy_rms_norm_without_scale +from .test_fused_dual_residual_rmsnorm import test_accuracy_fused_dual_residual_rmsnorm +from .test_mrotary_embedding import test_accuracy_mrotary_embedding +from .test_dual_chunk_rotary_embedding import test_accuracy_dual_chunk_rotary_embedding +from .test_deepseek_scaling_rotary_embedding import test_accuracy_deepseek_scaling_rotary_embedding +from .test_llama3_rotary_embedding import test_accuracy_llama3_rotary_embedding +from .test_dynamic_ntk_scaling_rotary_embedding import test_accuracy_dynamic_ntk_scaling_rotary_embedding +from .test_linear_scaling_rotary_embedding import test_accuracy_linear_scaling_rotary_embedding +from .test_phi3_long_rope_scaled_rotary_embedding import test_accuracy_phi3_long_rope_scaled_rotary_embedding +from .test_triton_mrope_fused import test_accuracy_triton_mrope_fused +from .test_triton_ernie45_rope_fused import test_accuracy_triton_ernie45_rope_fused +from .test_fused_moe import test_accuracy_fused_moe +from .test_topk import test_accuracy_topk +from .test_moe_align_block_size import test_accuracy_moe_align_block_size +from .test_apply_interleaved_rope_triton import test_accuracy_apply_interleaved_rope_triton +from .test_dynamic_ntk_alpha_rotary_embedding import test_accuracy_dynamic_ntk_alpha_rotary_embedding +from .test_fused_recurrent_gated_delta_rule import test_accuracy_fused_recurrent_gated_delta_rule +from .test_fused_recurrent_gated_delta_rule_update import test_accuracy_fused_recurrent_gated_delta_rule_update +from .test_fused_sigmoid_gating_delta_rule_update import test_accuracy_fused_sigmoid_gating_delta_rule_update +from .test_fused_sigmoid_gating_delta_rule_packed_decode import test_accuracy_fused_sigmoid_gating_delta_rule_packed_decode +from .test_fused_gdn_gating import test_accuracy_fused_gdn_gating +from .test_layer_norm_gated_fwd import test_accuracy_layer_norm_gated_fwd +from .test_causal_conv1d_fn import test_accuracy_causal_conv1d_fn +from .test_causal_conv1d_update import test_accuracy_causal_conv1d_update +from .test_selective_scan_update import test_accuracy_selective_scan_update +from .test_mamba_chunk_scan_combined_fwd import test_accuracy_mamba_chunk_scan_combined_fwd +from .test_experts_combine_triton import test_accuracy_experts_combine_triton +from .test_gemma_rmsnorm_residual_scalar import test_accuracy_gemma_rmsnorm_residual_scalar +from .test_gemma_qkv_rmsnorm import test_accuracy_gemma_qkv_rmsnorm +from .test_conv2d_layer import test_accuracy_conv2d_layer +from .test_conv3d_layer import test_accuracy_conv3d_layer +from .test_mixer2_rms_norm_gated import test_accuracy_mixer2_rms_norm_gated +from .test_per_token_quant_int8 import test_accuracy_per_token_quant_int8 diff --git a/src/kernelgenbench/accuracy/sglang/test_apply_interleaved_rope_triton.py b/src/kernelgenbench/accuracy/sglang/test_apply_interleaved_rope_triton.py new file mode 100644 index 0000000..70f7982 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_apply_interleaved_rope_triton.py @@ -0,0 +1,36 @@ +""" +Accuracy and benchmark test for SGLang apply_interleaved_rope_triton. +Source: sglang.srt.layers.rotary_embedding.mrope.apply_interleaved_rope_triton(x [3,N,D], mrope_section) -> [N,D] +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("apply_interleaved_rope_triton") +@parametrize("seq_len", [128, 512, 1024]) +@parametrize("rotary_dim", [64, 128]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +@parametrize("mrope_section", [[8, 12, 12], [24, 20, 20]]) +def test_accuracy_apply_interleaved_rope_triton(seq_len, rotary_dim, dtype, mrope_section): + N, D = seq_len, rotary_dim + x = torch.randn(3, N, D, device='cuda', dtype=dtype) + + ref_out = kernelgenbench.baseline.apply_interleaved_rope_triton(x, mrope_section) + act_out = kernelgenbench.baseline.apply_interleaved_rope_triton(x.clone(), mrope_section) + + assert_close(act_out, ref_out, dtype) + + if seq_len < 512: + return None + + x_bench = torch.randn(3, N, D, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.apply_interleaved_rope_triton(x_bench, mrope_section), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_causal_conv1d_fn.py b/src/kernelgenbench/accuracy/sglang/test_causal_conv1d_fn.py new file mode 100644 index 0000000..2c6bc83 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_causal_conv1d_fn.py @@ -0,0 +1,30 @@ +""" +Accuracy and benchmark test for SGLang causal_conv1d_fn. +Source: causal_conv1d_fn(x [B,T,D], weight [D,width], bias, query_start_loc, ...) (sgl_kernel) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("causal_conv1d_fn") +@parametrize("shape", [(64, 64), (256, 128)]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_accuracy_causal_conv1d_fn(shape, dtype): + M, N = shape + x = torch.randn(M, N, device='cuda', dtype=dtype) + weight = torch.randn(N, 4, device='cuda', dtype=dtype) + ref_out = kernelgenbench.baseline.causal_conv1d_fn(x, weight, bias=None, query_start_loc=None, cache_indices=None, has_initial_state=None, conv_states=None, activation='silu') + act_out = kernelgenbench.baseline.causal_conv1d_fn(x.clone(), weight, bias=None, query_start_loc=None, cache_indices=None, has_initial_state=None, conv_states=None, activation='silu') + assert_close(act_out, ref_out, dtype) + if M < 256: + return None + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.causal_conv1d_fn(x, weight, bias=None, query_start_loc=None, cache_indices=None, has_initial_state=None, conv_states=None, activation='silu'), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_causal_conv1d_update.py b/src/kernelgenbench/accuracy/sglang/test_causal_conv1d_update.py new file mode 100644 index 0000000..2eb410a --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_causal_conv1d_update.py @@ -0,0 +1,31 @@ +""" +Accuracy and benchmark test for SGLang causal_conv1d_update. +Source: causal_conv1d_update(x [B,D], conv_state [B,D,width-1], weight [D,width], bias, activation) (sgl_kernel) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("causal_conv1d_update") +@parametrize("shape", [(64, 64), (256, 128)]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_accuracy_causal_conv1d_update(shape, dtype): + M, N = shape + x = torch.randn(M, N, device='cuda', dtype=dtype) + weight = torch.randn(N, 4, device='cuda', dtype=dtype) + conv_state = torch.randn(M, N, 3, device='cuda', dtype=dtype) + ref_out, ref_conv_state = kernelgenbench.baseline.causal_conv1d_update(x, conv_state, weight, bias=None, activation='silu') + act_out, act_conv_state = kernelgenbench.baseline.causal_conv1d_update(x.clone(), conv_state.clone(), weight, bias=None, activation='silu') + assert_close(act_out, ref_out, dtype) + if M < 256: + return None + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.causal_conv1d_update(x, conv_state, weight, bias=None, activation='silu'), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_conv2d_layer.py b/src/kernelgenbench/accuracy/sglang/test_conv2d_layer.py new file mode 100644 index 0000000..9505f57 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_conv2d_layer.py @@ -0,0 +1,31 @@ +""" +Accuracy and benchmark test for SGLang conv2d_layer. +Source: Conv2dLayer(in_channels, out_channels, kernel_size, stride, padding, bias).forward_cuda(x [N,C,H,W]) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("conv2d_layer") +@parametrize("shape", [(64, 64), (256, 128)]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_accuracy_conv2d_layer(shape, dtype): + M, N = shape + in_c, out_c, ks = 3, N, 3 + x = torch.randn(M, in_c, 32, 32, device='cuda', dtype=dtype) + ref_out = kernelgenbench.baseline.conv2d_layer(x, in_channels=in_c, out_channels=out_c, kernel_size=ks, stride=1, padding=1, bias=False) + act_out = kernelgenbench.baseline.conv2d_layer(x.clone(), in_channels=in_c, out_channels=out_c, kernel_size=ks, stride=1, padding=1, bias=False) + assert_close(act_out, ref_out, dtype) + if M < 256: + return None + x_b = torch.randn(M, in_c, 32, 32, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.conv2d_layer(x_b, in_channels=in_c, out_channels=out_c, kernel_size=ks, stride=1, padding=1, bias=False), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_conv3d_layer.py b/src/kernelgenbench/accuracy/sglang/test_conv3d_layer.py new file mode 100644 index 0000000..0936980 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_conv3d_layer.py @@ -0,0 +1,31 @@ +""" +Accuracy and benchmark test for SGLang conv3d_layer. +Source: Conv3dLayer(in_channels, out_channels, kernel_size, stride, padding, bias).forward_cuda(x [N,C,D,H,W]) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("conv3d_layer") +@parametrize("shape", [(64, 64), (256, 128)]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_accuracy_conv3d_layer(shape, dtype): + M, N = shape + in_c, out_c, ks = 3, N, 3 + x = torch.randn(M, in_c, 16, 16, 16, device='cuda', dtype=dtype) + ref_out = kernelgenbench.baseline.conv3d_layer(x, in_channels=in_c, out_channels=out_c, kernel_size=ks, stride=1, padding=1, bias=False) + act_out = kernelgenbench.baseline.conv3d_layer(x.clone(), in_channels=in_c, out_channels=out_c, kernel_size=ks, stride=1, padding=1, bias=False) + assert_close(act_out, ref_out, dtype) + if M < 256: + return None + x_b = torch.randn(M, in_c, 16, 16, 16, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.conv3d_layer(x_b, in_channels=in_c, out_channels=out_c, kernel_size=ks, stride=1, padding=1, bias=False), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_deepseek_scaling_rotary_embedding.py b/src/kernelgenbench/accuracy/sglang/test_deepseek_scaling_rotary_embedding.py new file mode 100644 index 0000000..857bfb6 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_deepseek_scaling_rotary_embedding.py @@ -0,0 +1,44 @@ +""" +Accuracy and benchmark test for SGLang deepseek_scaling_rotary_embedding. +Source: DeepseekScalingRotaryEmbedding(head_size,rotary_dim,max_positions,base,is_neox,dtype,scaling_factor).forward_cuda(pos,q,k) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("deepseek_scaling_rotary_embedding") +@parametrize("seq_len", [128, 512, 1024]) +@parametrize("num_heads", [4, 16]) +@parametrize("head_size", [64, 128]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +@parametrize("is_neox", [True, False]) +def test_accuracy_deepseek_scaling_rotary_embedding(seq_len, num_heads, head_size, dtype, is_neox): + positions = torch.arange(seq_len, device='cuda', dtype=torch.long) + q0 = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + k0 = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + ref_q, ref_k = q0.clone(), k0.clone() + act_q, act_k = q0.clone(), k0.clone() + + kernelgenbench.baseline.deepseek_scaling_rotary_embedding(ref_q, ref_k, positions, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size, scaling_factor=1.0) + kernelgenbench.baseline.deepseek_scaling_rotary_embedding(act_q, act_k, positions, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size, scaling_factor=1.0) + assert_close(act_q, ref_q, dtype) + assert_close(act_k, ref_k, dtype) + if seq_len < 512: + return None + pos_b = torch.arange(seq_len, device='cuda', dtype=torch.long) + q_b = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + k_b = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + def bench_baseline(): + q, k = q_b.clone(), k_b.clone() + kernelgenbench.baseline.deepseek_scaling_rotary_embedding(q, k, pos_b, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size, scaling_factor=1.0) + def bench_triton(): + q, k = q_b.clone(), k_b.clone() + kernelgenbench.baseline.deepseek_scaling_rotary_embedding(q, k, pos_b, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size, scaling_factor=1.0) + ms_baseline = triton.testing.do_bench(bench_baseline, warmup=25, rep=100) + ms_triton = triton.testing.do_bench(bench_triton, warmup=25, rep=100) + speedup = ms_baseline / ms_triton if ms_triton > 0 else float('inf') + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_triton, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_dual_chunk_rotary_embedding.py b/src/kernelgenbench/accuracy/sglang/test_dual_chunk_rotary_embedding.py new file mode 100644 index 0000000..7b6f89b --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_dual_chunk_rotary_embedding.py @@ -0,0 +1,44 @@ +""" +Accuracy and benchmark test for SGLang dual_chunk_rotary_embedding. +Source: DualChunkRotaryEmbedding(head_size,rotary_dim,max_positions,base,is_neox,dtype,chunk_size,local_size).forward_cuda(pos,q,k) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("dual_chunk_rotary_embedding") +@parametrize("seq_len", [128, 512, 1024]) +@parametrize("num_heads", [4, 16]) +@parametrize("head_size", [64, 128]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +@parametrize("is_neox", [True, False]) +def test_accuracy_dual_chunk_rotary_embedding(seq_len, num_heads, head_size, dtype, is_neox): + positions = torch.arange(seq_len, device='cuda', dtype=torch.long) + q0 = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + k0 = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + ref_q, ref_k = q0.clone(), k0.clone() + act_q, act_k = q0.clone(), k0.clone() + + kernelgenbench.baseline.dual_chunk_rotary_embedding(ref_q, ref_k, positions, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size) + kernelgenbench.baseline.dual_chunk_rotary_embedding(act_q, act_k, positions, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size) + assert_close(act_q, ref_q, dtype) + assert_close(act_k, ref_k, dtype) + if seq_len < 512: + return None + pos_b = torch.arange(seq_len, device='cuda', dtype=torch.long) + q_b = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + k_b = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + def bench_baseline(): + q, k = q_b.clone(), k_b.clone() + kernelgenbench.baseline.dual_chunk_rotary_embedding(q, k, pos_b, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size) + def bench_triton(): + q, k = q_b.clone(), k_b.clone() + kernelgenbench.baseline.dual_chunk_rotary_embedding(q, k, pos_b, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size) + ms_baseline = triton.testing.do_bench(bench_baseline, warmup=25, rep=100) + ms_triton = triton.testing.do_bench(bench_triton, warmup=25, rep=100) + speedup = ms_baseline / ms_triton if ms_triton > 0 else float('inf') + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_triton, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_dynamic_ntk_alpha_rotary_embedding.py b/src/kernelgenbench/accuracy/sglang/test_dynamic_ntk_alpha_rotary_embedding.py new file mode 100644 index 0000000..068a2a4 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_dynamic_ntk_alpha_rotary_embedding.py @@ -0,0 +1,59 @@ +""" +Accuracy and benchmark test for SGLang DynamicNTKAlphaRotaryEmbedding. +Source: DynamicNTKAlphaRotaryEmbedding(head_size,rotary_dim,max_positions,base,is_neox,dtype,scaling_alpha).forward_cuda(pos,q,k) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("dynamic_ntk_alpha_rotary_embedding") +@parametrize("seq_len", [128, 512, 1024]) +@parametrize("num_heads", [4, 16]) +@parametrize("head_size", [64, 128]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +@parametrize("is_neox", [True, False]) +@parametrize("scaling_alpha", [2.0, 4.0]) +def test_accuracy_dynamic_ntk_alpha_rotary_embedding(seq_len, num_heads, head_size, dtype, is_neox, scaling_alpha): + positions = torch.arange(seq_len, device='cuda', dtype=torch.long) + q0 = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + k0 = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + ref_q, ref_k = q0.clone(), k0.clone() + act_q, act_k = q0.clone(), k0.clone() + + kernelgenbench.baseline.dynamic_ntk_alpha_rotary_embedding( + ref_q, ref_k, positions, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size, + scaling_alpha=scaling_alpha) + kernelgenbench.baseline.dynamic_ntk_alpha_rotary_embedding( + act_q, act_k, positions, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size, + scaling_alpha=scaling_alpha) + + assert_close(act_q, ref_q, dtype) + assert_close(act_k, ref_k, dtype) + + if seq_len < 512: + return None + + pos_b = torch.arange(seq_len, device='cuda', dtype=torch.long) + q_b = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + k_b = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + + def bench_baseline(): + q, k = q_b.clone(), k_b.clone() + kernelgenbench.baseline.dynamic_ntk_alpha_rotary_embedding( + q, k, pos_b, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size, + scaling_alpha=scaling_alpha) + + def bench_triton(): + q, k = q_b.clone(), k_b.clone() + kernelgenbench.baseline.dynamic_ntk_alpha_rotary_embedding( + q, k, pos_b, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size, + scaling_alpha=scaling_alpha) + + ms_baseline = triton.testing.do_bench(bench_baseline, warmup=25, rep=100) + ms_triton = triton.testing.do_bench(bench_triton, warmup=25, rep=100) + speedup = ms_baseline / ms_triton if ms_triton > 0 else float('inf') + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_triton, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_dynamic_ntk_scaling_rotary_embedding.py b/src/kernelgenbench/accuracy/sglang/test_dynamic_ntk_scaling_rotary_embedding.py new file mode 100644 index 0000000..d5ec8e0 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_dynamic_ntk_scaling_rotary_embedding.py @@ -0,0 +1,44 @@ +""" +Accuracy and benchmark test for SGLang dynamic_ntk_scaling_rotary_embedding. +Source: DynamicNTKScalingRotaryEmbedding(head_size,rotary_dim,max_positions,base,is_neox,dtype,scaling_factor).forward_cuda(pos,q,k) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("dynamic_ntk_scaling_rotary_embedding") +@parametrize("seq_len", [128, 512, 1024]) +@parametrize("num_heads", [4, 16]) +@parametrize("head_size", [64, 128]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +@parametrize("is_neox", [True, False]) +def test_accuracy_dynamic_ntk_scaling_rotary_embedding(seq_len, num_heads, head_size, dtype, is_neox): + positions = torch.arange(seq_len, device='cuda', dtype=torch.long) + q0 = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + k0 = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + ref_q, ref_k = q0.clone(), k0.clone() + act_q, act_k = q0.clone(), k0.clone() + + kernelgenbench.baseline.dynamic_ntk_scaling_rotary_embedding(ref_q, ref_k, positions, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size, scaling_factor=2.0) + kernelgenbench.baseline.dynamic_ntk_scaling_rotary_embedding(act_q, act_k, positions, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size, scaling_factor=2.0) + assert_close(act_q, ref_q, dtype) + assert_close(act_k, ref_k, dtype) + if seq_len < 512: + return None + pos_b = torch.arange(seq_len, device='cuda', dtype=torch.long) + q_b = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + k_b = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + def bench_baseline(): + q, k = q_b.clone(), k_b.clone() + kernelgenbench.baseline.dynamic_ntk_scaling_rotary_embedding(q, k, pos_b, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size, scaling_factor=2.0) + def bench_triton(): + q, k = q_b.clone(), k_b.clone() + kernelgenbench.baseline.dynamic_ntk_scaling_rotary_embedding(q, k, pos_b, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size, scaling_factor=2.0) + ms_baseline = triton.testing.do_bench(bench_baseline, warmup=25, rep=100) + ms_triton = triton.testing.do_bench(bench_triton, warmup=25, rep=100) + speedup = ms_baseline / ms_triton if ms_triton > 0 else float('inf') + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_triton, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_experts_combine_triton.py b/src/kernelgenbench/accuracy/sglang/test_experts_combine_triton.py new file mode 100644 index 0000000..88320f4 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_experts_combine_triton.py @@ -0,0 +1,31 @@ +""" +Accuracy and benchmark test for SGLang experts_combine_triton. +Source: experts_combine_triton(moe_hidden_states [N,K,D], mlp_hidden_states [N,D], output_buffer) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("experts_combine_triton") +@parametrize("shape", [(64, 64), (256, 128)]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_accuracy_experts_combine_triton(shape, dtype): + M, N = shape + K = 4 + moe_hidden = torch.randn(M, K, N, device='cuda', dtype=dtype) + mlp_hidden = torch.randn(M, N, device='cuda', dtype=dtype) + ref_out = kernelgenbench.baseline.experts_combine_triton(moe_hidden, mlp_hidden, output_buffer=None) + act_out = kernelgenbench.baseline.experts_combine_triton(moe_hidden.clone(), mlp_hidden.clone(), output_buffer=None) + assert_close(act_out, ref_out, dtype) + if M < 256: + return None + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.experts_combine_triton(moe_hidden, mlp_hidden, output_buffer=None), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_fused_dual_residual_rmsnorm.py b/src/kernelgenbench/accuracy/sglang/test_fused_dual_residual_rmsnorm.py new file mode 100644 index 0000000..996257f --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_fused_dual_residual_rmsnorm.py @@ -0,0 +1,40 @@ +""" +Accuracy and benchmark test for SGLang FusedDualResidualRMSNorm. +Source: FusedDualResidualRMSNorm(rmsnorm1, rmsnorm2).forward(x, residual) -> (out, residual) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("fused_dual_residual_rmsnorm") +@parametrize("shape", [(1, 32), (128, 512), (1024, 4096)]) +@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +def test_accuracy_fused_dual_residual_rmsnorm(shape, dtype): + M, hidden_size = shape + x = torch.randn(M, hidden_size, device='cuda', dtype=dtype) + residual = torch.randn(M, hidden_size, device='cuda', dtype=dtype) + + ref_out, ref_res = kernelgenbench.baseline.fused_dual_residual_rmsnorm( + x, residual, hidden_size, hidden_size, eps=1e-6) + act_out, act_res = kernelgenbench.baseline.fused_dual_residual_rmsnorm( + x.clone(), residual.clone(), hidden_size, hidden_size, eps=1e-6) + + assert_close(act_out, ref_out, dtype) + assert_close(act_res, ref_res, dtype) + + if M < 1024 or dtype == torch.float32: + return None + + x_bench = torch.randn(M, hidden_size, device='cuda', dtype=dtype) + r_bench = torch.randn(M, hidden_size, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.fused_dual_residual_rmsnorm( + x_bench, r_bench, hidden_size, hidden_size, eps=1e-6), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_fused_gdn_gating.py b/src/kernelgenbench/accuracy/sglang/test_fused_gdn_gating.py new file mode 100644 index 0000000..092ed5c --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_fused_gdn_gating.py @@ -0,0 +1,34 @@ +""" +Accuracy and benchmark test for SGLang fused_gdn_gating. +Source: fused_gdn_gating(A_log [HV], a [B,HV], b [B,HV], dt_bias [HV], beta, threshold) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("fused_gdn_gating") +@parametrize("shape", [(64, 64), (256, 128)]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_accuracy_fused_gdn_gating(shape, dtype): + M, N = shape + B, HV = 4, 16 + A_log = torch.randn(HV, device='cuda', dtype=dtype) + a = torch.randn(B, HV, device='cuda', dtype=dtype) + b = torch.randn(B, HV, device='cuda', dtype=dtype) + dt_bias = torch.randn(HV, device='cuda', dtype=dtype) + ref_g, ref_beta_out = kernelgenbench.baseline.fused_gdn_gating(A_log=A_log, a=a, b=b, dt_bias=dt_bias, beta=1.0, threshold=20.0) + act_g, act_beta_out = kernelgenbench.baseline.fused_gdn_gating(A_log=A_log, a=a.clone(), b=b.clone(), dt_bias=dt_bias, beta=1.0, threshold=20.0) + assert_close(act_g, ref_g, dtype) + assert_close(act_beta_out, ref_beta_out, dtype) + if M < 256: + return None + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.fused_gdn_gating(A_log=A_log, a=a, b=b, dt_bias=dt_bias, beta=1.0, threshold=20.0), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_fused_moe.py b/src/kernelgenbench/accuracy/sglang/test_fused_moe.py new file mode 100644 index 0000000..2d3433d --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_fused_moe.py @@ -0,0 +1,40 @@ +""" +Accuracy and benchmark test for SGLang fused_moe. +Source: triton_kernel_fused_experts(hidden_states [M,N], w1 [E,N,2*I], w2 [E,I,N], routing_data, gather_indx, scatter_indx) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("fused_moe") +@parametrize("shape", [(64, 64), (256, 128)]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_accuracy_fused_moe(shape, dtype): + M, N = shape + E, I, topk = 4, N, 2 + hidden_states = torch.randn(M, N, device='cuda', dtype=dtype) + w1 = torch.randn(E, N, 2 * I, device='cuda', dtype=dtype) + w2 = torch.randn(E, I, N, device='cuda', dtype=dtype) + from triton_kernels.matmul_ogs import RoutingData, GatherIndx, ScatterIndx + from triton_kernels.routing import compute_expt_data_torch + logits = torch.randn(M, E, device='cuda', dtype=torch.float32) + routing_data, gather_idx, scatter_idx = compute_expt_data_torch(logits, topk, M, 0.0, 0.0, False) + routing_data = RoutingData(*routing_data) + gather_idx = GatherIndx(*gather_idx) + scatter_idx = ScatterIndx(*scatter_idx) + ref_out = kernelgenbench.baseline.fused_moe(hidden_states, w1, w2, routing_data, gather_idx, scatter_idx, inplace=False, activation='silu') + act_out = kernelgenbench.baseline.fused_moe(hidden_states.clone(), w1, w2, routing_data, gather_idx, scatter_idx, inplace=False, activation='silu') + assert_close(act_out, ref_out, dtype) + if M < 256: + return None + hidden_b = torch.randn(M, N, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.fused_moe(hidden_b, w1, w2, routing_data, gather_idx, scatter_idx, inplace=False, activation='silu'), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_fused_recurrent_gated_delta_rule.py b/src/kernelgenbench/accuracy/sglang/test_fused_recurrent_gated_delta_rule.py new file mode 100644 index 0000000..f659925 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_fused_recurrent_gated_delta_rule.py @@ -0,0 +1,43 @@ +""" +Accuracy and benchmark test for SGLang fused_recurrent_gated_delta_rule. +Source: fused_recurrent_gated_delta_rule(q[B,T,H,K], k[B,T,H,K], v[B,T,HV,V], g[B,T,HV], beta[B,T,HV], scale) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("fused_recurrent_gated_delta_rule") +@parametrize("shape", [(64, 64), (256, 128)]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_accuracy_fused_recurrent_gated_delta_rule(shape, dtype): + M, N = shape + B, T, H, K, HV, V = 1, 64, 4, 64, 4, 64 + scale = K ** -0.5 + q = torch.randn(B, T, H, K, device='cuda', dtype=dtype) + k = torch.randn(B, T, H, K, device='cuda', dtype=dtype) + v = torch.randn(B, T, HV, V, device='cuda', dtype=dtype) + g = torch.randn(B, T, HV, device='cuda', dtype=dtype) + beta = torch.randn(B, T, HV, device='cuda', dtype=dtype) + ref_out, ref_state = kernelgenbench.baseline.fused_recurrent_gated_delta_rule( + q=q, k=k, v=v, g=g, beta=beta, scale=scale, initial_state=None, output_final_state=False, cu_seqlens=None) + act_out, act_state = kernelgenbench.baseline.fused_recurrent_gated_delta_rule( + q=q.clone(), k=k.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), scale=scale, initial_state=None, output_final_state=False, cu_seqlens=None) + assert_close(act_out, ref_out, dtype) + if M < 256: + return None + q_b = torch.randn(B, T, H, K, device='cuda', dtype=dtype) + k_b = torch.randn(B, T, H, K, device='cuda', dtype=dtype) + v_b = torch.randn(B, T, HV, V, device='cuda', dtype=dtype) + g_b = torch.randn(B, T, HV, device='cuda', dtype=dtype) + beta_b = torch.randn(B, T, HV, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.fused_recurrent_gated_delta_rule( + q=q_b, k=k_b, v=v_b, g=g_b, beta=beta_b, scale=scale, initial_state=None, output_final_state=False, cu_seqlens=None), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_fused_recurrent_gated_delta_rule_update.py b/src/kernelgenbench/accuracy/sglang/test_fused_recurrent_gated_delta_rule_update.py new file mode 100644 index 0000000..3d67926 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_fused_recurrent_gated_delta_rule_update.py @@ -0,0 +1,42 @@ +""" +Accuracy and benchmark test for SGLang fused_recurrent_gated_delta_rule_update. +Source: fused_recurrent_gated_delta_rule_update(q,k,v,g,beta,scale,initial_state,cu_seqlens,initial_state_indices,intermediate_states,eagle_tree) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("fused_recurrent_gated_delta_rule_update") +@parametrize("shape", [(64, 64), (256, 128)]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_accuracy_fused_recurrent_gated_delta_rule_update(shape, dtype): + M, N = shape + B, H, K, HV, V = 1, 4, 64, 4, 64 + scale = K ** -0.5 + q = torch.randn(B, 1, H, K, device='cuda', dtype=dtype) + k = torch.randn(B, 1, H, K, device='cuda', dtype=dtype) + v = torch.randn(B, 1, HV, V, device='cuda', dtype=dtype) + g = torch.randn(B, 1, HV, device='cuda', dtype=dtype) + beta = torch.randn(B, 1, HV, device='cuda', dtype=dtype) + init_state = torch.randn(B, HV, V, K, device='cuda', dtype=dtype) + ref_out = kernelgenbench.baseline.fused_recurrent_gated_delta_rule_update( + q=q, k=k, v=v, g=g, beta=beta, scale=scale, initial_state=init_state, + cu_seqlens=None, initial_state_indices=None, intermediate_states=None, eagle_tree=None) + act_out = kernelgenbench.baseline.fused_recurrent_gated_delta_rule_update( + q=q.clone(), k=k.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), scale=scale, initial_state=init_state.clone(), + cu_seqlens=None, initial_state_indices=None, intermediate_states=None, eagle_tree=None) + assert_close(act_out, ref_out, dtype, strict=False) + if M < 256: + return None + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.fused_recurrent_gated_delta_rule_update( + q=q, k=k, v=v, g=g, beta=beta, scale=scale, initial_state=init_state, + cu_seqlens=None, initial_state_indices=None, intermediate_states=None, eagle_tree=None), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_fused_rmsnorm.py b/src/kernelgenbench/accuracy/sglang/test_fused_rmsnorm.py new file mode 100644 index 0000000..fb6de40 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_fused_rmsnorm.py @@ -0,0 +1,41 @@ +""" +Accuracy and benchmark test for SGLang fused_rmsnorm. +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("fused_rmsnorm") +@parametrize("shape", [(1, 512), (71, 2048), (128, 4096), (1024, 8192), (5333, 4096)]) +@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +def test_accuracy_fused_rmsnorm(shape, dtype): + """Accuracy and performance test for SGLang fused_rmsnorm.""" + # ===== Accuracy Test ===== + M, N = shape + x = torch.randn(M, N, device='cuda', dtype=dtype) + x_act = x.clone() + + ref_out = kernelgenbench.baseline.fused_rmsnorm(x) + act_out = kernelgenbench.triton.fused_rmsnorm(x_act) + + assert_close(act_out, ref_out, dtype) + + # ===== Performance Test ===== + if M < 1024 or N < 4096 or dtype == torch.float32: + return None + + x_bench = torch.randn(M, N, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.fused_rmsnorm(x_bench), + warmup=25, rep=100 + ) + ms_triton = triton.testing.do_bench( + lambda: kernelgenbench.triton.fused_rmsnorm(x_bench.clone()), + warmup=25, rep=100 + ) + speedup = ms_baseline / ms_triton if ms_triton > 0 else float('inf') + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_triton, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_fused_sigmoid_gating_delta_rule_packed_decode.py b/src/kernelgenbench/accuracy/sglang/test_fused_sigmoid_gating_delta_rule_packed_decode.py new file mode 100644 index 0000000..357e9fa --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_fused_sigmoid_gating_delta_rule_packed_decode.py @@ -0,0 +1,48 @@ +""" +Accuracy and benchmark test for SGLang fused_sigmoid_gating_delta_rule_packed_decode. +Source: fused_sigmoid_gating_delta_rule_packed_decode(mixed_qkv,a,b,A_log,dt_bias,scale,initial_state,out,ssm_state_indices)->(out,initial_state) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("fused_sigmoid_gating_delta_rule_packed_decode") +@parametrize("shape", [(64, 64), (256, 128)]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_accuracy_fused_sigmoid_gating_delta_rule_packed_decode(shape, dtype): + M, N = shape + B, H, K, HV, V = 4, 4, 64, 4, 64 + scale = K ** -0.5 + mixed_qkv = torch.randn(B, 2*H*K + HV*V, device='cuda', dtype=dtype) + a = torch.randn(B, HV, device='cuda', dtype=dtype) + b = torch.randn(B, HV, device='cuda', dtype=dtype) + A_log = torch.randn(HV, device='cuda', dtype=dtype) + dt_bias = torch.randn(HV, device='cuda', dtype=dtype) + pool_size = 1 + init_state = torch.randn(pool_size, HV, V, K, device='cuda', dtype=dtype) + out = torch.zeros(B, 1, HV, V, device='cuda', dtype=dtype) + ssm_state_indices = torch.arange(B, device='cuda', dtype=torch.int32) + ref_out, ref_state = kernelgenbench.baseline.fused_sigmoid_gating_delta_rule_packed_decode( + mixed_qkv=mixed_qkv, a=a, b=b, A_log=A_log, dt_bias=dt_bias, scale=scale, + initial_state=init_state, out=out, ssm_state_indices=ssm_state_indices, + use_qk_l2norm_in_kernel=False) + act_out, act_state = kernelgenbench.baseline.fused_sigmoid_gating_delta_rule_packed_decode( + mixed_qkv=mixed_qkv.clone(), a=a.clone(), b=b.clone(), A_log=A_log, dt_bias=dt_bias, scale=scale, + initial_state=init_state.clone(), out=out.clone(), ssm_state_indices=ssm_state_indices, + use_qk_l2norm_in_kernel=False) + assert_close(act_out, ref_out, dtype, strict=False) + if M < 256: + return None + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.fused_sigmoid_gating_delta_rule_packed_decode( + mixed_qkv=mixed_qkv, a=a, b=b, A_log=A_log, dt_bias=dt_bias, scale=scale, + initial_state=init_state, out=out, ssm_state_indices=ssm_state_indices, + use_qk_l2norm_in_kernel=False), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_fused_sigmoid_gating_delta_rule_update.py b/src/kernelgenbench/accuracy/sglang/test_fused_sigmoid_gating_delta_rule_update.py new file mode 100644 index 0000000..b4babd0 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_fused_sigmoid_gating_delta_rule_update.py @@ -0,0 +1,44 @@ +""" +Accuracy and benchmark test for SGLang fused_sigmoid_gating_delta_rule_update. +Source: fused_sigmoid_gating_delta_rule_update(q,k,v,A_log,a,dt_bias,b,scale,initial_state,...) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("fused_sigmoid_gating_delta_rule_update") +@parametrize("shape", [(64, 64), (256, 128)]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_accuracy_fused_sigmoid_gating_delta_rule_update(shape, dtype): + M, N = shape + B, H, K, HV, V = 1, 4, 64, 4, 64 + scale = K ** -0.5 + q = torch.randn(B, 1, H, K, device='cuda', dtype=dtype) + k = torch.randn(B, 1, H, K, device='cuda', dtype=dtype) + v = torch.randn(B, 1, HV, V, device='cuda', dtype=dtype) + A_log = torch.randn(HV, device='cuda', dtype=dtype) + a = torch.randn(B, 1, HV, device='cuda', dtype=dtype) + dt_bias = torch.randn(HV, device='cuda', dtype=dtype) + b = torch.randn(B, 1, HV, device='cuda', dtype=dtype) + init_state = torch.randn(B, HV, V, K, device='cuda', dtype=dtype) + ref_out = kernelgenbench.baseline.fused_sigmoid_gating_delta_rule_update( + q=q, k=k, v=v, A_log=A_log, a=a, dt_bias=dt_bias, b=b, scale=scale, + initial_state=init_state, cu_seqlens=None, initial_state_indices=None) + act_out = kernelgenbench.baseline.fused_sigmoid_gating_delta_rule_update( + q=q.clone(), k=k.clone(), v=v.clone(), A_log=A_log, a=a.clone(), dt_bias=dt_bias, b=b.clone(), scale=scale, + initial_state=init_state.clone(), cu_seqlens=None, initial_state_indices=None) + assert_close(act_out, ref_out, dtype, strict=False) + if M < 256: + return None + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.fused_sigmoid_gating_delta_rule_update( + q=q, k=k, v=v, A_log=A_log, a=a, dt_bias=dt_bias, b=b, scale=scale, + initial_state=init_state, cu_seqlens=None, initial_state_indices=None), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_gelu_and_mul.py b/src/kernelgenbench/accuracy/sglang/test_gelu_and_mul.py new file mode 100644 index 0000000..f5aba73 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_gelu_and_mul.py @@ -0,0 +1,41 @@ +""" +Accuracy and benchmark test for SGLang gelu_and_mul. +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("gelu_and_mul") +@parametrize("shape", [(1, 512), (71, 2048), (128, 4096), (1024, 8192), (5333, 4096)]) +@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +def test_accuracy_gelu_and_mul(shape, dtype): + """Accuracy and performance test for SGLang gelu_and_mul.""" + # ===== Accuracy Test ===== + M, N = shape + x = torch.randn(M, N, device='cuda', dtype=dtype) + x_act = x.clone() + + ref_out = kernelgenbench.baseline.gelu_and_mul(x) + act_out = kernelgenbench.triton.gelu_and_mul(x_act) + + assert_close(act_out, ref_out, dtype) + + # ===== Performance Test ===== + if M < 1024 or N < 4096 or dtype == torch.float32: + return None + + x_bench = torch.randn(M, N, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.gelu_and_mul(x_bench), + warmup=25, rep=100 + ) + ms_triton = triton.testing.do_bench( + lambda: kernelgenbench.triton.gelu_and_mul(x_bench.clone()), + warmup=25, rep=100 + ) + speedup = ms_baseline / ms_triton if ms_triton > 0 else float('inf') + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_triton, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_gelu_and_mul_triton.py b/src/kernelgenbench/accuracy/sglang/test_gelu_and_mul_triton.py new file mode 100644 index 0000000..f661dc8 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_gelu_and_mul_triton.py @@ -0,0 +1,41 @@ +""" +Accuracy and benchmark test for SGLang gelu_and_mul_triton. +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("gelu_and_mul_triton") +@parametrize("shape", [(1, 512), (71, 2048), (128, 4096), (1024, 8192), (5333, 4096)]) +@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +def test_accuracy_gelu_and_mul_triton(shape, dtype): + """Accuracy and performance test for SGLang gelu_and_mul_triton.""" + # ===== Accuracy Test ===== + M, N = shape + x = torch.randn(M, N, device='cuda', dtype=dtype) + x_act = x.clone() + + ref_out = kernelgenbench.baseline.gelu_and_mul_triton(x) + act_out = kernelgenbench.triton.gelu_and_mul_triton(x_act) + + assert_close(act_out, ref_out, dtype) + + # ===== Performance Test ===== + if M < 1024 or N < 4096 or dtype == torch.float32: + return None + + x_bench = torch.randn(M, N, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.gelu_and_mul_triton(x_bench), + warmup=25, rep=100 + ) + ms_triton = triton.testing.do_bench( + lambda: kernelgenbench.triton.gelu_and_mul_triton(x_bench.clone()), + warmup=25, rep=100 + ) + speedup = ms_baseline / ms_triton if ms_triton > 0 else float('inf') + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_triton, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_gemma3_rms_norm.py b/src/kernelgenbench/accuracy/sglang/test_gemma3_rms_norm.py new file mode 100644 index 0000000..583f3a2 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_gemma3_rms_norm.py @@ -0,0 +1,35 @@ +""" +Accuracy and benchmark test for SGLang Gemma3RMSNorm. +Source: Gemma3RMSNorm(dim, eps).forward_cuda(x) — output = norm(x) * (1 + weight) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("gemma3_rms_norm") +@parametrize("shape", [(1, 32), (71, 497), (128, 512), (1024, 4096)]) +@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +@parametrize("eps", [1e-6, 1e-5]) +def test_accuracy_gemma3_rms_norm(shape, dtype, eps): + M, dim = shape + x = torch.randn(M, dim, device='cuda', dtype=dtype) + + ref_out = kernelgenbench.baseline.gemma3_rms_norm(x, dim, eps=eps) + act_out = kernelgenbench.baseline.gemma3_rms_norm(x.clone(), dim, eps=eps) + + assert_close(act_out, ref_out, dtype) + + if M < 1024 or dtype == torch.float32: + return None + + x_bench = torch.randn(M, dim, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.gemma3_rms_norm(x_bench, dim, eps=eps), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_gemma4_rms_norm.py b/src/kernelgenbench/accuracy/sglang/test_gemma4_rms_norm.py new file mode 100644 index 0000000..e28db51 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_gemma4_rms_norm.py @@ -0,0 +1,35 @@ +""" +Accuracy and benchmark test for SGLang Gemma4RMSNorm. +Source: Gemma4RMSNorm(dim, eps, scale_shift).forward_cuda(x) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("gemma4_rms_norm") +@parametrize("shape", [(1, 32), (128, 512), (1024, 4096)]) +@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +@parametrize("scale_shift", [0.0, 1.0]) +def test_accuracy_gemma4_rms_norm(shape, dtype, scale_shift): + M, dim = shape + x = torch.randn(M, dim, device='cuda', dtype=dtype) + + ref_out = kernelgenbench.baseline.gemma4_rms_norm(x, dim, eps=1e-6, scale_shift=scale_shift) + act_out = kernelgenbench.baseline.gemma4_rms_norm(x.clone(), dim, eps=1e-6, scale_shift=scale_shift) + + assert_close(act_out, ref_out, dtype) + + if M < 1024 or dtype == torch.float32: + return None + + x_bench = torch.randn(M, dim, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.gemma4_rms_norm(x_bench, dim, eps=1e-6, scale_shift=scale_shift), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_gemma_qkv_rmsnorm.py b/src/kernelgenbench/accuracy/sglang/test_gemma_qkv_rmsnorm.py new file mode 100644 index 0000000..16070b7 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_gemma_qkv_rmsnorm.py @@ -0,0 +1,40 @@ +""" +Accuracy and benchmark test for SGLang gemma_qkv_rmsnorm. +Source: gemma_qkv_rmsnorm(q [N,QH,Hdim], k [N,KH,Hdim], v [N,KH,Hdim], q_weight, k_weight, num_q_heads, num_kv_heads, head_dim, eps) -> None (in-place) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("gemma_qkv_rmsnorm") +@parametrize("shape", [(64, 64), (256, 128)]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_accuracy_gemma_qkv_rmsnorm(shape, dtype): + M, N = shape + num_qh, num_kh, hdim = 4, 2, N // 4 + q = torch.randn(M, num_qh * hdim, device='cuda', dtype=dtype) + k = torch.randn(M, num_kh * hdim, device='cuda', dtype=dtype) + v = torch.randn(M, num_kh * hdim, device='cuda', dtype=dtype) + q_weight = torch.randn(hdim, device='cuda', dtype=dtype) + k_weight = torch.randn(hdim, device='cuda', dtype=dtype) + ref_q, ref_k, ref_v = q.clone(), k.clone(), v.clone() + act_q, act_k, act_v = q.clone(), k.clone(), v.clone() + kernelgenbench.baseline.gemma_qkv_rmsnorm(ref_q, ref_k, ref_v, q_weight, k_weight, num_q_heads=num_qh, num_kv_heads=num_kh, head_dim=hdim, eps=1e-6) + kernelgenbench.baseline.gemma_qkv_rmsnorm(act_q, act_k, act_v, q_weight, k_weight, num_q_heads=num_qh, num_kv_heads=num_kh, head_dim=hdim, eps=1e-6) + assert_close(act_q, ref_q, dtype) + assert_close(act_k, ref_k, dtype) + if M < 256: + return None + q_b = torch.randn(M, num_qh * hdim, device='cuda', dtype=dtype) + k_b = torch.randn(M, num_kh * hdim, device='cuda', dtype=dtype) + v_b = torch.randn(M, num_kh * hdim, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.gemma_qkv_rmsnorm(q_b.clone(), k_b.clone(), v_b.clone(), q_weight, k_weight, num_q_heads=num_qh, num_kv_heads=num_kh, head_dim=hdim, eps=1e-6), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_gemma_rms_norm.py b/src/kernelgenbench/accuracy/sglang/test_gemma_rms_norm.py new file mode 100644 index 0000000..3c985c5 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_gemma_rms_norm.py @@ -0,0 +1,64 @@ +""" +Accuracy and benchmark test for SGLang gemma_rms_norm. +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("gemma_rms_norm") +@parametrize("shape", [(1, 32), (71, 497), (128, 512), (1024, 4096), (5333, 8192)]) +@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +@parametrize("has_residual", [True, False]) +def test_accuracy_gemma_rms_norm(shape, dtype, has_residual): + """Accuracy and performance test for SGLang gemma_rms_norm.""" + # ===== Accuracy Test ===== + M, N = shape + x = torch.randn(M, N, device='cuda', dtype=dtype) + weight = torch.randn(N, device='cuda', dtype=dtype) + residual = torch.randn_like(x) if has_residual else None + + x_act = x.clone() + residual_act = residual.clone() if has_residual else None + + ref_out = kernelgenbench.baseline.gemma_rms_norm(x, weight, eps=1e-6, residual=residual) + act_out = kernelgenbench.triton.gemma_rms_norm(x_act, weight, eps=1e-6, residual=residual_act) + + if has_residual: + assert_close(act_out[0], ref_out[0], dtype) + assert_close(act_out[1], ref_out[1], dtype) + else: + assert_close(act_out, ref_out, dtype) + + # ===== Performance Test ===== + if M < 1024 or dtype == torch.float32: + return None + + x_bench = torch.randn(M, N, device='cuda', dtype=dtype) + w_bench = torch.randn(N, device='cuda', dtype=dtype) + + if has_residual: + r_bench = torch.randn_like(x_bench) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.gemma_rms_norm(x_bench, w_bench, eps=1e-6, residual=r_bench), + warmup=25, rep=100 + ) + ms_triton = triton.testing.do_bench( + lambda: kernelgenbench.triton.gemma_rms_norm(x_bench.clone(), w_bench, eps=1e-6, residual=r_bench.clone()), + warmup=25, rep=100 + ) + else: + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.gemma_rms_norm(x_bench, w_bench, eps=1e-6), + warmup=25, rep=100 + ) + ms_triton = triton.testing.do_bench( + lambda: kernelgenbench.triton.gemma_rms_norm(x_bench.clone(), w_bench, eps=1e-6), + warmup=25, rep=100 + ) + + speedup = ms_baseline / ms_triton if ms_triton > 0 else float('inf') + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_triton, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_gemma_rmsnorm_residual_scalar.py b/src/kernelgenbench/accuracy/sglang/test_gemma_rmsnorm_residual_scalar.py new file mode 100644 index 0000000..ad0a6ba --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_gemma_rmsnorm_residual_scalar.py @@ -0,0 +1,32 @@ +""" +Accuracy and benchmark test for SGLang gemma_rmsnorm_residual_scalar. +Source: gemma_rmsnorm_residual_scalar(x [N,D], weight [D], residual [N,D], scalar [scalar], eps) -> Tensor[N,D] +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("gemma_rmsnorm_residual_scalar") +@parametrize("shape", [(64, 64), (256, 128)]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_accuracy_gemma_rmsnorm_residual_scalar(shape, dtype): + M, N = shape + x = torch.randn(M, N, device='cuda', dtype=dtype) + weight = torch.randn(N, device='cuda', dtype=dtype) + residual = torch.randn(M, N, device='cuda', dtype=dtype) + scalar = torch.tensor(0.5, device='cuda', dtype=dtype) + ref_out = kernelgenbench.baseline.gemma_rmsnorm_residual_scalar(x, weight, residual, scalar, eps=1e-6) + act_out = kernelgenbench.baseline.gemma_rmsnorm_residual_scalar(x.clone(), weight, residual.clone(), scalar, eps=1e-6) + assert_close(act_out, ref_out, dtype) + if M < 256: + return None + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.gemma_rmsnorm_residual_scalar(x, weight, residual, scalar, eps=1e-6), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_l2norm.py b/src/kernelgenbench/accuracy/sglang/test_l2norm.py new file mode 100644 index 0000000..4e3b206 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_l2norm.py @@ -0,0 +1,41 @@ +""" +Accuracy and benchmark test for SGLang l2norm. +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("l2norm") +@parametrize("shape", [(1, 512), (71, 2048), (128, 4096), (1024, 8192), (5333, 4096)]) +@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +def test_accuracy_l2norm(shape, dtype): + """Accuracy and performance test for SGLang l2norm.""" + # ===== Accuracy Test ===== + M, N = shape + x = torch.randn(M, N, device='cuda', dtype=dtype) + x_act = x.clone() + + ref_out = kernelgenbench.baseline.l2norm(x) + act_out = kernelgenbench.triton.l2norm(x_act) + + assert_close(act_out, ref_out, dtype) + + # ===== Performance Test ===== + if M < 1024 or N < 4096 or dtype == torch.float32: + return None + + x_bench = torch.randn(M, N, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.l2norm(x_bench), + warmup=25, rep=100 + ) + ms_triton = triton.testing.do_bench( + lambda: kernelgenbench.triton.l2norm(x_bench.clone()), + warmup=25, rep=100 + ) + speedup = ms_baseline / ms_triton if ms_triton > 0 else float('inf') + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_triton, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_layer_norm.py b/src/kernelgenbench/accuracy/sglang/test_layer_norm.py new file mode 100644 index 0000000..ef4a844 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_layer_norm.py @@ -0,0 +1,40 @@ +""" +Accuracy and benchmark test for SGLang LayerNorm. +Source: sglang.srt.layers.layernorm.LayerNorm(hidden_size, eps).forward_cuda(x) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("layer_norm") +@parametrize("shape", [(1, 32), (71, 497), (128, 512), (1024, 4096)]) +@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +@parametrize("eps", [1e-6, 1e-5]) +@parametrize("elementwise_affine", [True, False]) +def test_accuracy_layer_norm(shape, dtype, eps, elementwise_affine): + M, hidden_size = shape + x = torch.randn(M, hidden_size, device='cuda', dtype=dtype) + + ref_out = kernelgenbench.baseline.layer_norm(x, hidden_size, eps=eps, elementwise_affine=elementwise_affine) + act_out = kernelgenbench.baseline.layer_norm(x.clone(), hidden_size, eps=eps, elementwise_affine=elementwise_affine) + + assert_close(act_out, ref_out, dtype) + + if M < 1024 or dtype == torch.float32: + return None + + x_bench = torch.randn(M, hidden_size, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.layer_norm(x_bench, hidden_size, eps=eps, elementwise_affine=elementwise_affine), + warmup=25, rep=100 + ) + ms_triton = triton.testing.do_bench( + lambda: kernelgenbench.baseline.layer_norm(x_bench.clone(), hidden_size, eps=eps, elementwise_affine=elementwise_affine), + warmup=25, rep=100 + ) + speedup = ms_baseline / ms_triton if ms_triton > 0 else float('inf') + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_triton, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_layer_norm_gated_fwd.py b/src/kernelgenbench/accuracy/sglang/test_layer_norm_gated_fwd.py new file mode 100644 index 0000000..c6e8a19 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_layer_norm_gated_fwd.py @@ -0,0 +1,32 @@ +""" +Accuracy and benchmark test for SGLang layer_norm_gated_fwd. +Source: layer_norm_gated_fwd(x [N,D], g [N,D], weight [D], bias [D]|None, activation, eps, residual) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("layer_norm_gated_fwd") +@parametrize("shape", [(64, 64), (256, 128)]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_accuracy_layer_norm_gated_fwd(shape, dtype): + M, N = shape + x = torch.randn(M, N, device='cuda', dtype=dtype) + g = torch.randn(M, N, device='cuda', dtype=dtype) + weight = torch.ones(N, device='cuda', dtype=dtype) + bias = torch.zeros(N, device='cuda', dtype=dtype) + ref_out = kernelgenbench.baseline.layer_norm_gated_fwd(x, g, weight, bias, activation='swish', eps=1e-5, residual=None) + act_out = kernelgenbench.baseline.layer_norm_gated_fwd(x.clone(), g.clone(), weight, bias, activation='swish', eps=1e-5, residual=None) + assert_close(act_out[0], ref_out[0], dtype) + if M < 256: + return None + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.layer_norm_gated_fwd(x, g, weight, bias, activation='swish', eps=1e-5, residual=None), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_linear_scaling_rotary_embedding.py b/src/kernelgenbench/accuracy/sglang/test_linear_scaling_rotary_embedding.py new file mode 100644 index 0000000..7f1d5cf --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_linear_scaling_rotary_embedding.py @@ -0,0 +1,44 @@ +""" +Accuracy and benchmark test for SGLang linear_scaling_rotary_embedding. +Source: LinearScalingRotaryEmbedding(head_size,rotary_dim,max_positions,base,is_neox,dtype,scaling_factor).forward_cuda(pos,q,k) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("linear_scaling_rotary_embedding") +@parametrize("seq_len", [128, 512, 1024]) +@parametrize("num_heads", [4, 16]) +@parametrize("head_size", [64, 128]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +@parametrize("is_neox", [True, False]) +def test_accuracy_linear_scaling_rotary_embedding(seq_len, num_heads, head_size, dtype, is_neox): + positions = torch.arange(seq_len, device='cuda', dtype=torch.long) + q0 = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + k0 = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + ref_q, ref_k = q0.clone(), k0.clone() + act_q, act_k = q0.clone(), k0.clone() + + kernelgenbench.baseline.linear_scaling_rotary_embedding(ref_q, ref_k, positions, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size, scaling_factor=2.0) + kernelgenbench.baseline.linear_scaling_rotary_embedding(act_q, act_k, positions, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size, scaling_factor=2.0) + assert_close(act_q, ref_q, dtype) + assert_close(act_k, ref_k, dtype) + if seq_len < 512: + return None + pos_b = torch.arange(seq_len, device='cuda', dtype=torch.long) + q_b = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + k_b = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + def bench_baseline(): + q, k = q_b.clone(), k_b.clone() + kernelgenbench.baseline.linear_scaling_rotary_embedding(q, k, pos_b, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size, scaling_factor=2.0) + def bench_triton(): + q, k = q_b.clone(), k_b.clone() + kernelgenbench.baseline.linear_scaling_rotary_embedding(q, k, pos_b, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size, scaling_factor=2.0) + ms_baseline = triton.testing.do_bench(bench_baseline, warmup=25, rep=100) + ms_triton = triton.testing.do_bench(bench_triton, warmup=25, rep=100) + speedup = ms_baseline / ms_triton if ms_triton > 0 else float('inf') + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_triton, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_llama3_rotary_embedding.py b/src/kernelgenbench/accuracy/sglang/test_llama3_rotary_embedding.py new file mode 100644 index 0000000..4ac958d --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_llama3_rotary_embedding.py @@ -0,0 +1,44 @@ +""" +Accuracy and benchmark test for SGLang llama3_rotary_embedding. +Source: Llama3RotaryEmbedding(head_size,rotary_dim,max_positions,base,is_neox,dtype,scaling_factor,low_freq,high_freq,orig_max_pos).forward_cuda(pos,q,k) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("llama3_rotary_embedding") +@parametrize("seq_len", [128, 512, 1024]) +@parametrize("num_heads", [4, 16]) +@parametrize("head_size", [64, 128]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +@parametrize("is_neox", [True, False]) +def test_accuracy_llama3_rotary_embedding(seq_len, num_heads, head_size, dtype, is_neox): + positions = torch.arange(seq_len, device='cuda', dtype=torch.long) + q0 = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + k0 = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + ref_q, ref_k = q0.clone(), k0.clone() + act_q, act_k = q0.clone(), k0.clone() + + kernelgenbench.baseline.llama3_rotary_embedding(ref_q, ref_k, positions, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size, scaling_factor=8.0) + kernelgenbench.baseline.llama3_rotary_embedding(act_q, act_k, positions, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size, scaling_factor=8.0) + assert_close(act_q, ref_q, dtype) + assert_close(act_k, ref_k, dtype) + if seq_len < 512: + return None + pos_b = torch.arange(seq_len, device='cuda', dtype=torch.long) + q_b = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + k_b = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + def bench_baseline(): + q, k = q_b.clone(), k_b.clone() + kernelgenbench.baseline.llama3_rotary_embedding(q, k, pos_b, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size, scaling_factor=8.0) + def bench_triton(): + q, k = q_b.clone(), k_b.clone() + kernelgenbench.baseline.llama3_rotary_embedding(q, k, pos_b, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size, scaling_factor=8.0) + ms_baseline = triton.testing.do_bench(bench_baseline, warmup=25, rep=100) + ms_triton = triton.testing.do_bench(bench_triton, warmup=25, rep=100) + speedup = ms_baseline / ms_triton if ms_triton > 0 else float('inf') + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_triton, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_mamba_chunk_scan_combined_fwd.py b/src/kernelgenbench/accuracy/sglang/test_mamba_chunk_scan_combined_fwd.py new file mode 100644 index 0000000..bd62345 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_mamba_chunk_scan_combined_fwd.py @@ -0,0 +1,34 @@ +""" +Accuracy and benchmark test for SGLang mamba_chunk_scan_combined_fwd. +Source: mamba_chunk_scan_combined_fwd(x [B,T,nheads,D], dt [B,T,nheads], A [nheads], B [B,T,ngroups,dstate], C [B,T,ngroups,dstate]) (Mamba-2 SSD) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("mamba_chunk_scan_combined_fwd") +@parametrize("shape", [(64, 64), (256, 128)]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_accuracy_mamba_chunk_scan_combined_fwd(shape, dtype): + M, N = shape + B, T, nheads, D, dstate, ngroups = 1, 64, 4, 64, 16, 1 + x = torch.randn(B, T, nheads, D, device='cuda', dtype=dtype) + dt = torch.randn(B, T, nheads, device='cuda', dtype=dtype) + A = torch.randn(nheads, device='cuda', dtype=dtype) + Bw = torch.randn(B, T, ngroups, dstate, device='cuda', dtype=dtype) + Cw = torch.randn(B, T, ngroups, dstate, device='cuda', dtype=dtype) + ref_out = kernelgenbench.baseline.mamba_chunk_scan_combined_fwd(x, dt, A, Bw, Cw, chunk_size=64, D=None, z=None, dt_bias=None, initial_states=None, cu_seqlens=None) + act_out = kernelgenbench.baseline.mamba_chunk_scan_combined_fwd(x.clone(), dt.clone(), A, Bw.clone(), Cw.clone(), chunk_size=64, D=None, z=None, dt_bias=None, initial_states=None, cu_seqlens=None) + assert_close(act_out, ref_out, dtype) + if M < 256: + return None + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.mamba_chunk_scan_combined_fwd(x, dt, A, Bw, Cw, chunk_size=64, D=None, z=None, dt_bias=None, initial_states=None, cu_seqlens=None), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_mixer2_rms_norm_gated.py b/src/kernelgenbench/accuracy/sglang/test_mixer2_rms_norm_gated.py new file mode 100644 index 0000000..6ce92e1 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_mixer2_rms_norm_gated.py @@ -0,0 +1,32 @@ +""" +Accuracy and benchmark test for SGLang mixer2_rms_norm_gated. +Source: Mixer2RMSNormGated(hidden_size, eps).forward_cuda(x, gate) -> Tensor +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("mixer2_rms_norm_gated") +@parametrize("shape", [(64, 64), (256, 128)]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_accuracy_mixer2_rms_norm_gated(shape, dtype): + M, N = shape + x = torch.randn(M, N, device='cuda', dtype=dtype) + gate = torch.randn(M, N, device='cuda', dtype=dtype) + ref_out = kernelgenbench.baseline.mixer2_rms_norm_gated(x, gate, hidden_size=N, eps=1e-6) + act_out = kernelgenbench.baseline.mixer2_rms_norm_gated(x.clone(), gate.clone(), hidden_size=N, eps=1e-6) + assert_close(act_out, ref_out, dtype) + if M < 256: + return None + x_b = torch.randn(M, N, device='cuda', dtype=dtype) + g_b = torch.randn(M, N, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.mixer2_rms_norm_gated(x_b, g_b, hidden_size=N, eps=1e-6), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_moe_align_block_size.py b/src/kernelgenbench/accuracy/sglang/test_moe_align_block_size.py new file mode 100644 index 0000000..cfa75ed --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_moe_align_block_size.py @@ -0,0 +1,31 @@ +""" +Accuracy and benchmark test for SGLang moe_align_block_size. +Source: moe_align_block_size(topk_ids [M,topk], num_experts, block_size) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("moe_align_block_size") +@parametrize("shape", [(64, 64), (256, 128)]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_accuracy_moe_align_block_size(shape, dtype): + M, N = shape + topk_ids = torch.randint(0, 4, (M, 2), device='cuda', dtype=torch.int32) + ref_out = kernelgenbench.baseline.moe_align_block_size(topk_ids, num_experts=4, block_size=32) + act_out = kernelgenbench.baseline.moe_align_block_size(topk_ids.clone(), num_experts=4, block_size=32) + assert_close(act_out[0], ref_out[0], torch.int32) + assert_close(act_out[1], ref_out[1], torch.int32) + if M < 256: + return None + topk_b = torch.randint(0, 4, (M, 2), device='cuda', dtype=torch.int32) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.moe_align_block_size(topk_b, num_experts=4, block_size=32), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_mrotary_embedding.py b/src/kernelgenbench/accuracy/sglang/test_mrotary_embedding.py new file mode 100644 index 0000000..7513621 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_mrotary_embedding.py @@ -0,0 +1,44 @@ +""" +Accuracy and benchmark test for SGLang mrotary_embedding. +Source: MRotaryEmbedding(head_size,rotary_dim,max_positions,base,is_neox,dtype,mrope_section).forward_cuda(pos,q,k) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("mrotary_embedding") +@parametrize("seq_len", [128, 512, 1024]) +@parametrize("num_heads", [4, 16]) +@parametrize("head_size", [64, 128]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +@parametrize("is_neox", [True, False]) +def test_accuracy_mrotary_embedding(seq_len, num_heads, head_size, dtype, is_neox): + positions = torch.arange(seq_len, device='cuda', dtype=torch.long) + q0 = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + k0 = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + ref_q, ref_k = q0.clone(), k0.clone() + act_q, act_k = q0.clone(), k0.clone() + mrope_section = [head_size // 6] * 3 + kernelgenbench.baseline.mrotary_embedding(ref_q, ref_k, positions, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size, mrope_section=mrope_section) + kernelgenbench.baseline.mrotary_embedding(act_q, act_k, positions, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size, mrope_section=mrope_section) + assert_close(act_q, ref_q, dtype) + assert_close(act_k, ref_k, dtype) + if seq_len < 512: + return None + pos_b = torch.arange(seq_len, device='cuda', dtype=torch.long) + q_b = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + k_b = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + def bench_baseline(): + q, k = q_b.clone(), k_b.clone() + kernelgenbench.baseline.mrotary_embedding(q, k, pos_b, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size, mrope_section=mrope_section) + def bench_triton(): + q, k = q_b.clone(), k_b.clone() + kernelgenbench.baseline.mrotary_embedding(q, k, pos_b, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size, mrope_section=mrope_section) + ms_baseline = triton.testing.do_bench(bench_baseline, warmup=25, rep=100) + ms_triton = triton.testing.do_bench(bench_triton, warmup=25, rep=100) + speedup = ms_baseline / ms_triton if ms_triton > 0 else float('inf') + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_triton, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_new_gelu.py b/src/kernelgenbench/accuracy/sglang/test_new_gelu.py new file mode 100644 index 0000000..68fda7d --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_new_gelu.py @@ -0,0 +1,41 @@ +""" +Accuracy and benchmark test for SGLang new_gelu. +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("new_gelu") +@parametrize("shape", [(1, 512), (71, 2048), (128, 4096), (1024, 8192), (5333, 4096)]) +@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +def test_accuracy_new_gelu(shape, dtype): + """Accuracy and performance test for SGLang new_gelu.""" + # ===== Accuracy Test ===== + M, N = shape + x = torch.randn(M, N, device='cuda', dtype=dtype) + x_act = x.clone() + + ref_out = kernelgenbench.baseline.new_gelu(x) + act_out = kernelgenbench.triton.new_gelu(x_act) + + assert_close(act_out, ref_out, dtype) + + # ===== Performance Test ===== + if M < 1024 or N < 4096 or dtype == torch.float32: + return None + + x_bench = torch.randn(M, N, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.new_gelu(x_bench), + warmup=25, rep=100 + ) + ms_triton = triton.testing.do_bench( + lambda: kernelgenbench.triton.new_gelu(x_bench.clone()), + warmup=25, rep=100 + ) + speedup = ms_baseline / ms_triton if ms_triton > 0 else float('inf') + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_triton, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_per_token_quant_int8.py b/src/kernelgenbench/accuracy/sglang/test_per_token_quant_int8.py new file mode 100644 index 0000000..5626068 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_per_token_quant_int8.py @@ -0,0 +1,31 @@ +""" +Accuracy and benchmark test for SGLang per_token_quant_int8. +Source: per_token_quant_int8(x [..., D], scale_dtype=float32) -> (quantized, scales) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("per_token_quant_int8") +@parametrize("shape", [(64, 64), (256, 128)]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_accuracy_per_token_quant_int8(shape, dtype): + M, N = shape + x = torch.randn(M, N, device='cuda', dtype=dtype) + ref_quant, ref_scales = kernelgenbench.baseline.per_token_quant_int8(x, scale_dtype=torch.float32) + act_quant, act_scales = kernelgenbench.baseline.per_token_quant_int8(x.clone(), scale_dtype=torch.float32) + assert_close(act_quant, ref_quant, torch.int8) + assert_close(act_scales, ref_scales, torch.float32) + if M < 256: + return None + x_b = torch.randn(M, N, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.per_token_quant_int8(x_b, scale_dtype=torch.float32), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_phi3_long_rope_scaled_rotary_embedding.py b/src/kernelgenbench/accuracy/sglang/test_phi3_long_rope_scaled_rotary_embedding.py new file mode 100644 index 0000000..0b3e56d --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_phi3_long_rope_scaled_rotary_embedding.py @@ -0,0 +1,44 @@ +""" +Accuracy and benchmark test for SGLang phi3_long_rope_scaled_rotary_embedding. +Source: Phi3LongRoPEScaledRotaryEmbedding(head_size,rotary_dim,max_emb,orig_max_emb,base,is_neox,dtype,short_factor,long_factor).forward_cuda(pos,q,k) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("phi3_long_rope_scaled_rotary_embedding") +@parametrize("seq_len", [128, 512, 1024]) +@parametrize("num_heads", [4, 16]) +@parametrize("head_size", [64, 128]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +@parametrize("is_neox", [True, False]) +def test_accuracy_phi3_long_rope_scaled_rotary_embedding(seq_len, num_heads, head_size, dtype, is_neox): + positions = torch.arange(seq_len, device='cuda', dtype=torch.long) + q0 = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + k0 = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + ref_q, ref_k = q0.clone(), k0.clone() + act_q, act_k = q0.clone(), k0.clone() + + kernelgenbench.baseline.phi3_long_rope_scaled_rotary_embedding(ref_q, ref_k, positions, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size) + kernelgenbench.baseline.phi3_long_rope_scaled_rotary_embedding(act_q, act_k, positions, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size) + assert_close(act_q, ref_q, dtype) + assert_close(act_k, ref_k, dtype) + if seq_len < 512: + return None + pos_b = torch.arange(seq_len, device='cuda', dtype=torch.long) + q_b = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + k_b = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + def bench_baseline(): + q, k = q_b.clone(), k_b.clone() + kernelgenbench.baseline.phi3_long_rope_scaled_rotary_embedding(q, k, pos_b, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size) + def bench_triton(): + q, k = q_b.clone(), k_b.clone() + kernelgenbench.baseline.phi3_long_rope_scaled_rotary_embedding(q, k, pos_b, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size) + ms_baseline = triton.testing.do_bench(bench_baseline, warmup=25, rep=100) + ms_triton = triton.testing.do_bench(bench_triton, warmup=25, rep=100) + speedup = ms_baseline / ms_triton if ms_triton > 0 else float('inf') + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_triton, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_quick_gelu.py b/src/kernelgenbench/accuracy/sglang/test_quick_gelu.py new file mode 100644 index 0000000..9ab2ecd --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_quick_gelu.py @@ -0,0 +1,41 @@ +""" +Accuracy and benchmark test for SGLang quick_gelu. +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("quick_gelu") +@parametrize("shape", [(1, 512), (71, 2048), (128, 4096), (1024, 8192), (5333, 4096)]) +@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +def test_accuracy_quick_gelu(shape, dtype): + """Accuracy and performance test for SGLang quick_gelu.""" + # ===== Accuracy Test ===== + M, N = shape + x = torch.randn(M, N, device='cuda', dtype=dtype) + x_act = x.clone() + + ref_out = kernelgenbench.baseline.quick_gelu(x) + act_out = kernelgenbench.triton.quick_gelu(x_act) + + assert_close(act_out, ref_out, dtype) + + # ===== Performance Test ===== + if M < 1024 or N < 4096 or dtype == torch.float32: + return None + + x_bench = torch.randn(M, N, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.quick_gelu(x_bench), + warmup=25, rep=100 + ) + ms_triton = triton.testing.do_bench( + lambda: kernelgenbench.triton.quick_gelu(x_bench.clone()), + warmup=25, rep=100 + ) + speedup = ms_baseline / ms_triton if ms_triton > 0 else float('inf') + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_triton, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_rms_norm.py b/src/kernelgenbench/accuracy/sglang/test_rms_norm.py new file mode 100644 index 0000000..1f5c89e --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_rms_norm.py @@ -0,0 +1,64 @@ +""" +Accuracy and benchmark test for SGLang rms_norm. +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("rms_norm") +@parametrize("shape", [(1, 32), (71, 497), (128, 512), (1024, 4096), (5333, 8192)]) +@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +@parametrize("has_residual", [True, False]) +def test_accuracy_rms_norm(shape, dtype, has_residual): + """Accuracy and performance test for SGLang rms_norm.""" + # ===== Accuracy Test ===== + M, N = shape + x = torch.randn(M, N, device='cuda', dtype=dtype) + weight = torch.randn(N, device='cuda', dtype=dtype) + residual = torch.randn_like(x) if has_residual else None + + x_act = x.clone() + residual_act = residual.clone() if has_residual else None + + ref_out = kernelgenbench.baseline.rms_norm(x, weight, eps=1e-6, residual=residual) + act_out = kernelgenbench.triton.rms_norm(x_act, weight, eps=1e-6, residual=residual_act) + + if has_residual: + assert_close(act_out[0], ref_out[0], dtype) + assert_close(act_out[1], ref_out[1], dtype) + else: + assert_close(act_out, ref_out, dtype) + + # ===== Performance Test ===== + if M < 1024 or dtype == torch.float32: + return None + + x_bench = torch.randn(M, N, device='cuda', dtype=dtype) + w_bench = torch.randn(N, device='cuda', dtype=dtype) + + if has_residual: + r_bench = torch.randn_like(x_bench) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.rms_norm(x_bench, w_bench, eps=1e-6, residual=r_bench), + warmup=25, rep=100 + ) + ms_triton = triton.testing.do_bench( + lambda: kernelgenbench.triton.rms_norm(x_bench.clone(), w_bench, eps=1e-6, residual=r_bench.clone()), + warmup=25, rep=100 + ) + else: + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.rms_norm(x_bench, w_bench, eps=1e-6), + warmup=25, rep=100 + ) + ms_triton = triton.testing.do_bench( + lambda: kernelgenbench.triton.rms_norm(x_bench.clone(), w_bench, eps=1e-6), + warmup=25, rep=100 + ) + + speedup = ms_baseline / ms_triton if ms_triton > 0 else float('inf') + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_triton, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_rms_norm_gated.py b/src/kernelgenbench/accuracy/sglang/test_rms_norm_gated.py new file mode 100644 index 0000000..3b73596 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_rms_norm_gated.py @@ -0,0 +1,64 @@ +""" +Accuracy and benchmark test for SGLang rms_norm_gated. +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("rms_norm_gated") +@parametrize("shape", [(1, 32), (71, 497), (128, 512), (1024, 4096), (5333, 8192)]) +@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +@parametrize("has_residual", [True, False]) +def test_accuracy_rms_norm_gated(shape, dtype, has_residual): + """Accuracy and performance test for SGLang rms_norm_gated.""" + # ===== Accuracy Test ===== + M, N = shape + x = torch.randn(M, N, device='cuda', dtype=dtype) + weight = torch.randn(N, device='cuda', dtype=dtype) + residual = torch.randn_like(x) if has_residual else None + + x_act = x.clone() + residual_act = residual.clone() if has_residual else None + + ref_out = kernelgenbench.baseline.rms_norm_gated(x, weight, eps=1e-6, residual=residual) + act_out = kernelgenbench.triton.rms_norm_gated(x_act, weight, eps=1e-6, residual=residual_act) + + if has_residual: + assert_close(act_out[0], ref_out[0], dtype) + assert_close(act_out[1], ref_out[1], dtype) + else: + assert_close(act_out, ref_out, dtype) + + # ===== Performance Test ===== + if M < 1024 or dtype == torch.float32: + return None + + x_bench = torch.randn(M, N, device='cuda', dtype=dtype) + w_bench = torch.randn(N, device='cuda', dtype=dtype) + + if has_residual: + r_bench = torch.randn_like(x_bench) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.rms_norm_gated(x_bench, w_bench, eps=1e-6, residual=r_bench), + warmup=25, rep=100 + ) + ms_triton = triton.testing.do_bench( + lambda: kernelgenbench.triton.rms_norm_gated(x_bench.clone(), w_bench, eps=1e-6, residual=r_bench.clone()), + warmup=25, rep=100 + ) + else: + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.rms_norm_gated(x_bench, w_bench, eps=1e-6), + warmup=25, rep=100 + ) + ms_triton = triton.testing.do_bench( + lambda: kernelgenbench.triton.rms_norm_gated(x_bench.clone(), w_bench, eps=1e-6), + warmup=25, rep=100 + ) + + speedup = ms_baseline / ms_triton if ms_triton > 0 else float('inf') + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_triton, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_rms_norm_without_scale.py b/src/kernelgenbench/accuracy/sglang/test_rms_norm_without_scale.py new file mode 100644 index 0000000..9b4ab89 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_rms_norm_without_scale.py @@ -0,0 +1,35 @@ +""" +Accuracy and benchmark test for SGLang RMSNormWithoutScale. +Source: RMSNormWithoutScale(hidden_size, eps).forward_cuda(x) — pure RMS, no weight +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("rms_norm_without_scale") +@parametrize("shape", [(1, 32), (71, 497), (128, 512), (1024, 4096)]) +@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +@parametrize("eps", [1e-6, 1e-5]) +def test_accuracy_rms_norm_without_scale(shape, dtype, eps): + M, hidden_size = shape + x = torch.randn(M, hidden_size, device='cuda', dtype=dtype) + + ref_out = kernelgenbench.baseline.rms_norm_without_scale(x, hidden_size, eps=eps) + act_out = kernelgenbench.baseline.rms_norm_without_scale(x.clone(), hidden_size, eps=eps) + + assert_close(act_out, ref_out, dtype) + + if M < 1024 or dtype == torch.float32: + return None + + x_bench = torch.randn(M, hidden_size, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.rms_norm_without_scale(x_bench, hidden_size, eps=eps), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_rotary_embedding.py b/src/kernelgenbench/accuracy/sglang/test_rotary_embedding.py new file mode 100644 index 0000000..bc8424e --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_rotary_embedding.py @@ -0,0 +1,53 @@ +""" +Accuracy and benchmark test for SGLang rotary_embedding. +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("rotary_embedding") +@parametrize("seq_len", [128, 512, 1024, 4096]) +@parametrize("num_heads", [4, 16]) +@parametrize("head_size", [64, 128]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +@parametrize("is_neox", [True, False]) +def test_accuracy_rotary_embedding(seq_len, num_heads, head_size, dtype, is_neox): + """Accuracy and performance test for SGLang rotary_embedding.""" + # ===== Accuracy Test ===== + positions = torch.randint(0, 8192, (seq_len,), device='cuda', dtype=torch.long) + q0 = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + k0 = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + + ref_q, ref_k = q0.clone(), k0.clone() + act_q, act_k = q0.clone(), k0.clone() + + kernelgenbench.baseline.rotary_embedding(positions, ref_q, ref_k, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size) + kernelgenbench.triton.rotary_embedding(positions, act_q, act_k, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size) + + assert_close(act_q, ref_q, dtype) + assert_close(act_k, ref_k, dtype) + + # ===== Performance Test ===== + if seq_len < 1024: + return None + + pos_bench = torch.randint(0, 8192, (seq_len,), device='cuda', dtype=torch.long) + q_bench = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + k_bench = torch.randn(seq_len, num_heads, head_size, device='cuda', dtype=dtype) + + def bench_baseline(): + q, k = q_bench.clone(), k_bench.clone() + kernelgenbench.baseline.rotary_embedding(pos_bench, q, k, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size) + + def bench_triton(): + q, k = q_bench.clone(), k_bench.clone() + kernelgenbench.triton.rotary_embedding(pos_bench, q, k, is_neox_style=is_neox, head_size=head_size, rotary_dim=head_size) + + ms_baseline = triton.testing.do_bench(bench_baseline, warmup=25, rep=100) + ms_triton = triton.testing.do_bench(bench_triton, warmup=25, rep=100) + speedup = ms_baseline / ms_triton if ms_triton > 0 else float('inf') + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_triton, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_selective_scan_update.py b/src/kernelgenbench/accuracy/sglang/test_selective_scan_update.py new file mode 100644 index 0000000..48e3d68 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_selective_scan_update.py @@ -0,0 +1,35 @@ +""" +Accuracy and benchmark test for SGLang selective_scan_update. +Source: selective_scan_update(state, x, dt, A, B, C, D, z, dt_bias, state_batch_indices) (Mamba-1 SSM) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("selective_scan_update") +@parametrize("shape", [(64, 64), (256, 128)]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_accuracy_selective_scan_update(shape, dtype): + M, N = shape + dim, dstate = N, 16 + state = torch.randn(M, dim, dstate, device='cuda', dtype=dtype) + x = torch.randn(M, dim, device='cuda', dtype=dtype) + dt = torch.randn(M, dim, device='cuda', dtype=dtype) + A = torch.randn(dim, dstate, device='cuda', dtype=dtype) + B = torch.randn(M, dstate, device='cuda', dtype=dtype) + C = torch.randn(M, dstate, device='cuda', dtype=dtype) + ref_out, ref_state = kernelgenbench.baseline.selective_scan_update(state=state, x=x, dt=dt, A=A, B=B, C=C, D=None, z=None, dt_bias=None, state_batch_indices=None) + act_out, act_state = kernelgenbench.baseline.selective_scan_update(state=state.clone(), x=x.clone(), dt=dt.clone(), A=A, B=B.clone(), C=C.clone(), D=None, z=None, dt_bias=None, state_batch_indices=None) + assert_close(act_out, ref_out, dtype) + if M < 256: + return None + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.selective_scan_update(state=state, x=x, dt=dt, A=A, B=B, C=C, D=None, z=None, dt_bias=None, state_batch_indices=None), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_silu_and_mul.py b/src/kernelgenbench/accuracy/sglang/test_silu_and_mul.py new file mode 100644 index 0000000..56940bb --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_silu_and_mul.py @@ -0,0 +1,41 @@ +""" +Accuracy and benchmark test for SGLang silu_and_mul. +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("silu_and_mul") +@parametrize("shape", [(1, 512), (71, 2048), (128, 4096), (1024, 8192), (5333, 4096)]) +@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +def test_accuracy_silu_and_mul(shape, dtype): + """Accuracy and performance test for SGLang silu_and_mul.""" + # ===== Accuracy Test ===== + M, N = shape + x = torch.randn(M, N, device='cuda', dtype=dtype) + x_act = x.clone() + + ref_out = kernelgenbench.baseline.silu_and_mul(x) + act_out = kernelgenbench.triton.silu_and_mul(x_act) + + assert_close(act_out, ref_out, dtype) + + # ===== Performance Test ===== + if M < 1024 or N < 4096 or dtype == torch.float32: + return None + + x_bench = torch.randn(M, N, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.silu_and_mul(x_bench), + warmup=25, rep=100 + ) + ms_triton = triton.testing.do_bench( + lambda: kernelgenbench.triton.silu_and_mul(x_bench.clone()), + warmup=25, rep=100 + ) + speedup = ms_baseline / ms_triton if ms_triton > 0 else float('inf') + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_triton, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_silu_and_mul_triton.py b/src/kernelgenbench/accuracy/sglang/test_silu_and_mul_triton.py new file mode 100644 index 0000000..c414e8b --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_silu_and_mul_triton.py @@ -0,0 +1,41 @@ +""" +Accuracy and benchmark test for SGLang silu_and_mul_triton. +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("silu_and_mul_triton") +@parametrize("shape", [(1, 512), (71, 2048), (128, 4096), (1024, 8192), (5333, 4096)]) +@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +def test_accuracy_silu_and_mul_triton(shape, dtype): + """Accuracy and performance test for SGLang silu_and_mul_triton.""" + # ===== Accuracy Test ===== + M, N = shape + x = torch.randn(M, N, device='cuda', dtype=dtype) + x_act = x.clone() + + ref_out = kernelgenbench.baseline.silu_and_mul_triton(x) + act_out = kernelgenbench.triton.silu_and_mul_triton(x_act) + + assert_close(act_out, ref_out, dtype) + + # ===== Performance Test ===== + if M < 1024 or N < 4096 or dtype == torch.float32: + return None + + x_bench = torch.randn(M, N, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.silu_and_mul_triton(x_bench), + warmup=25, rep=100 + ) + ms_triton = triton.testing.do_bench( + lambda: kernelgenbench.triton.silu_and_mul_triton(x_bench.clone()), + warmup=25, rep=100 + ) + speedup = ms_baseline / ms_triton if ms_triton > 0 else float('inf') + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_triton, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_softcap.py b/src/kernelgenbench/accuracy/sglang/test_softcap.py new file mode 100644 index 0000000..7e8d352 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_softcap.py @@ -0,0 +1,41 @@ +""" +Accuracy and benchmark test for SGLang softcap. +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("softcap") +@parametrize("shape", [(1, 512), (71, 2048), (128, 4096), (1024, 8192), (5333, 4096)]) +@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +def test_accuracy_softcap(shape, dtype): + """Accuracy and performance test for SGLang softcap.""" + # ===== Accuracy Test ===== + M, N = shape + x = torch.randn(M, N, device='cuda', dtype=dtype) + x_act = x.clone() + + ref_out = kernelgenbench.baseline.softcap(x) + act_out = kernelgenbench.triton.softcap(x_act) + + assert_close(act_out, ref_out, dtype) + + # ===== Performance Test ===== + if M < 1024 or N < 4096 or dtype == torch.float32: + return None + + x_bench = torch.randn(M, N, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.softcap(x_bench), + warmup=25, rep=100 + ) + ms_triton = triton.testing.do_bench( + lambda: kernelgenbench.triton.softcap(x_bench.clone()), + warmup=25, rep=100 + ) + speedup = ms_baseline / ms_triton if ms_triton > 0 else float('inf') + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_triton, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_topk.py b/src/kernelgenbench/accuracy/sglang/test_topk.py new file mode 100644 index 0000000..96a3486 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_topk.py @@ -0,0 +1,31 @@ +""" +Accuracy and benchmark test for SGLang topk. +Source: TopK(topk, renormalize).forward_cuda(hidden_states, router_logits) -> TopKOutput +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("topk") +@parametrize("shape", [(64, 64), (256, 128)]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_accuracy_topk(shape, dtype): + M, N = shape + E = 8 + hidden_states = torch.randn(M, N, device='cuda', dtype=dtype) + router_logits = torch.randn(M, E, device='cuda', dtype=torch.float32) + ref_out = kernelgenbench.baseline.topk(hidden_states, router_logits, topk=2) + act_out = kernelgenbench.baseline.topk(hidden_states.clone(), router_logits.clone(), topk=2) + assert_close(act_out.weights, ref_out.weights, torch.float32) + if M < 256: + return None + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.topk(hidden_states, router_logits, topk=2), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_triton_ernie45_rope_fused.py b/src/kernelgenbench/accuracy/sglang/test_triton_ernie45_rope_fused.py new file mode 100644 index 0000000..b2b85f4 --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_triton_ernie45_rope_fused.py @@ -0,0 +1,40 @@ +""" +Accuracy and benchmark test for SGLang triton_ernie45_rope_fused. +Source: triton_ernie45_rope_fused_inplace(q,k,cos_sin_cache,positions,mrope_section,head_size,rotary_dim,is_neox)->None (in-place) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("triton_ernie45_rope_fused") +@parametrize("shape", [(64, 64), (256, 128)]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_accuracy_triton_ernie45_rope_fused(shape, dtype): + M, N = shape + q0 = torch.randn(M, 4 * N, device='cuda', dtype=dtype) + k0 = torch.randn(M, 4 * N, device='cuda', dtype=dtype) + positions = torch.randint(0, 8192, (3, M), device='cuda', dtype=torch.long) + cos_sin_cache = torch.randn(8192, N, device='cuda', dtype=dtype) + mrope_section = [N // 6] * 3 + ref_q, ref_k = q0.clone(), k0.clone() + act_q, act_k = q0.clone(), k0.clone() + kernelgenbench.baseline.triton_ernie45_rope_fused(ref_q, ref_k, cos_sin_cache, positions, mrope_section, head_size=N, rotary_dim=N) + kernelgenbench.baseline.triton_ernie45_rope_fused(act_q, act_k, cos_sin_cache, positions, mrope_section, head_size=N, rotary_dim=N) + assert_close(act_q, ref_q, dtype) + assert_close(act_k, ref_k, dtype) + if M < 256: + return None + q_b = torch.randn(M, 4 * N, device='cuda', dtype=dtype) + k_b = torch.randn(M, 4 * N, device='cuda', dtype=dtype) + pos_b = torch.randint(0, 8192, (3, M), device='cuda', dtype=torch.long) + cos_b = torch.randn(8192, N, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.triton_ernie45_rope_fused(q_b.clone(), k_b.clone(), cos_b, pos_b, mrope_section=[N//6]*3, head_size=N, rotary_dim=N), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_triton_mrope_fused.py b/src/kernelgenbench/accuracy/sglang/test_triton_mrope_fused.py new file mode 100644 index 0000000..be930fc --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_triton_mrope_fused.py @@ -0,0 +1,40 @@ +""" +Accuracy and benchmark test for SGLang triton_mrope_fused. +Source: triton_mrope_fused(q,k,cos_sin_cache,positions,mrope_section,head_size,rotary_dim,...)->None (in-place) +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("triton_mrope_fused") +@parametrize("shape", [(64, 64), (256, 128)]) +@parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_accuracy_triton_mrope_fused(shape, dtype): + M, N = shape + q0 = torch.randn(M, 4 * N, device='cuda', dtype=dtype) # M tokens, 4 heads * head_size + k0 = torch.randn(M, 4 * N, device='cuda', dtype=dtype) + positions = torch.randint(0, 8192, (M,), device='cuda', dtype=torch.long) + cos_sin_cache = torch.randn(8192, N, device='cuda', dtype=dtype) + mrope_section = [N // 6] * 3 + ref_q, ref_k = q0.clone(), k0.clone() + act_q, act_k = q0.clone(), k0.clone() + kernelgenbench.baseline.triton_mrope_fused(ref_q, ref_k, cos_sin_cache, positions, mrope_section, head_size=N, rotary_dim=N, mrope_interleaved=False, mrope_interleaved_glm=False, is_neox_style=True, axis_map=None) + kernelgenbench.baseline.triton_mrope_fused(act_q, act_k, cos_sin_cache, positions, mrope_section, head_size=N, rotary_dim=N, mrope_interleaved=False, mrope_interleaved_glm=False, is_neox_style=True, axis_map=None) + assert_close(act_q, ref_q, dtype) + assert_close(act_k, ref_k, dtype) + if M < 256: + return None + q_b = torch.randn(M, 4 * N, device='cuda', dtype=dtype) + k_b = torch.randn(M, 4 * N, device='cuda', dtype=dtype) + pos_b = torch.randint(0, 8192, (M,), device='cuda', dtype=torch.long) + cos_b = torch.randn(8192, N, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.triton_mrope_fused(q_b.clone(), k_b.clone(), cos_b, pos_b, mrope_section=[N//6]*3, head_size=N, rotary_dim=N, mrope_interleaved=False, mrope_interleaved_glm=False, is_neox_style=True, axis_map=None), + warmup=25, rep=100 + ) + speedup = 1.0 + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_baseline, speedup=speedup) diff --git a/src/kernelgenbench/accuracy/sglang/test_xielu.py b/src/kernelgenbench/accuracy/sglang/test_xielu.py new file mode 100644 index 0000000..d5cf76b --- /dev/null +++ b/src/kernelgenbench/accuracy/sglang/test_xielu.py @@ -0,0 +1,41 @@ +""" +Accuracy and benchmark test for SGLang xielu. +""" +import kernelgenbench +from sandbox.config import DEVICE as device +from sandbox.verifier.test_parametrize import parametrize, label +from sandbox.utils.accuracy_utils import kernelgenbench_assert_close as assert_close +from sandbox.utils.accuracy_utils import CustomBenchmarkResult +import torch +import triton + +@label("xielu") +@parametrize("shape", [(1, 512), (71, 2048), (128, 4096), (1024, 8192), (5333, 4096)]) +@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +def test_accuracy_xielu(shape, dtype): + """Accuracy and performance test for SGLang xielu.""" + # ===== Accuracy Test ===== + M, N = shape + x = torch.randn(M, N, device='cuda', dtype=dtype) + x_act = x.clone() + + ref_out = kernelgenbench.baseline.xielu(x) + act_out = kernelgenbench.triton.xielu(x_act) + + assert_close(act_out, ref_out, dtype) + + # ===== Performance Test ===== + if M < 1024 or N < 4096 or dtype == torch.float32: + return None + + x_bench = torch.randn(M, N, device='cuda', dtype=dtype) + ms_baseline = triton.testing.do_bench( + lambda: kernelgenbench.baseline.xielu(x_bench), + warmup=25, rep=100 + ) + ms_triton = triton.testing.do_bench( + lambda: kernelgenbench.triton.xielu(x_bench.clone()), + warmup=25, rep=100 + ) + speedup = ms_baseline / ms_triton if ms_triton > 0 else float('inf') + return CustomBenchmarkResult(ref_time=ms_baseline, res_time=ms_triton, speedup=speedup) diff --git a/src/kernelgenbench/dataset/baseline/sglang/__init__.py b/src/kernelgenbench/dataset/baseline/sglang/__init__.py new file mode 100644 index 0000000..c4a27c3 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/__init__.py @@ -0,0 +1,119 @@ +""" +SGLang baseline operators. + +Each submodule wraps a single SGLang operator as a callable function. +Follows the vLLM baseline pattern: thin Python wrappers with try/except imports. + +Categories mirror SGLang source modules: + layers/activation.py, layers/layernorm.py, layers/rotary_embedding/, + layers/moe/, layers/attention/fla/, layers/attention/mamba/, + layers/elementwise.py, layers/gemma4_fused_ops.py, + layers/conv.py, layers/quantization/ +""" + +# layers/activation.py +from .silu_and_mul import silu_and_mul +from .gelu_and_mul import gelu_and_mul +from .quick_gelu import quick_gelu +from .new_gelu import new_gelu +from .xielu import xielu + +# layers/layernorm.py +from .rms_norm import rms_norm +from .layer_norm import layer_norm +from .gemma_rms_norm import gemma_rms_norm +from .gemma3_rms_norm import gemma3_rms_norm +from .gemma4_rms_norm import gemma4_rms_norm +from .rms_norm_without_scale import rms_norm_without_scale + +# layers/rotary_embedding/ +from .rotary_embedding import rotary_embedding +from .mrotary_embedding import mrotary_embedding +from .dual_chunk_rotary_embedding import dual_chunk_rotary_embedding +from .deepseek_scaling_rotary_embedding import deepseek_scaling_rotary_embedding +from .llama3_rotary_embedding import llama3_rotary_embedding +from .dynamic_ntk_scaling_rotary_embedding import dynamic_ntk_scaling_rotary_embedding +from .linear_scaling_rotary_embedding import linear_scaling_rotary_embedding +from .phi3_long_rope_scaled_rotary_embedding import phi3_long_rope_scaled_rotary_embedding +from .triton_mrope_fused import triton_mrope_fused +from .triton_ernie45_rope_fused import triton_ernie45_rope_fused +from .apply_interleaved_rope_triton import apply_interleaved_rope_triton +from .dynamic_ntk_alpha_rotary_embedding import dynamic_ntk_alpha_rotary_embedding + +# layers/moe/ +from .fused_moe import fused_moe +from .topk import topk +from .moe_align_block_size import moe_align_block_size + +# layers/attention/fla/ +from .l2norm import l2norm +from .rms_norm_gated import rms_norm_gated +from .fused_recurrent_gated_delta_rule import fused_recurrent_gated_delta_rule +from .fused_recurrent_gated_delta_rule_update import fused_recurrent_gated_delta_rule_update +from .fused_sigmoid_gating_delta_rule_update import fused_sigmoid_gating_delta_rule_update +from .fused_sigmoid_gating_delta_rule_packed_decode import fused_sigmoid_gating_delta_rule_packed_decode +from .fused_gdn_gating import fused_gdn_gating +from .layer_norm_gated_fwd import layer_norm_gated_fwd + +# layers/attention/mamba/ +from .causal_conv1d_fn import causal_conv1d_fn +from .causal_conv1d_update import causal_conv1d_update +from .selective_scan_update import selective_scan_update +from .mamba_chunk_scan_combined_fwd import mamba_chunk_scan_combined_fwd +from .mixer2_rms_norm_gated import mixer2_rms_norm_gated + +# layers/elementwise.py +from .fused_dual_residual_rmsnorm import fused_dual_residual_rmsnorm +from .softcap import softcap +from .silu_and_mul_triton import silu_and_mul_triton +from .gelu_and_mul_triton import gelu_and_mul_triton +from .fused_rmsnorm import fused_rmsnorm +from .experts_combine_triton import experts_combine_triton + +# layers/gemma4_fused_ops.py +from .gemma_rmsnorm_residual_scalar import gemma_rmsnorm_residual_scalar +from .gemma_qkv_rmsnorm import gemma_qkv_rmsnorm + +# layers/conv.py +from .conv2d_layer import conv2d_layer +from .conv3d_layer import conv3d_layer + +# layers/quantization/ +from .per_token_quant_int8 import per_token_quant_int8 + +__all__ = [ + # layers/activation.py + "silu_and_mul", "gelu_and_mul", "quick_gelu", "new_gelu", "xielu", + # layers/layernorm.py + "rms_norm", "layer_norm", "gemma_rms_norm", "gemma3_rms_norm", + "gemma4_rms_norm", "rms_norm_without_scale", + # layers/rotary_embedding/ + "rotary_embedding", "mrotary_embedding", "dual_chunk_rotary_embedding", + "deepseek_scaling_rotary_embedding", "llama3_rotary_embedding", + "dynamic_ntk_scaling_rotary_embedding", "linear_scaling_rotary_embedding", + "phi3_long_rope_scaled_rotary_embedding", + "triton_mrope_fused", "triton_ernie45_rope_fused", + "apply_interleaved_rope_triton", "dynamic_ntk_alpha_rotary_embedding", + # layers/moe/ + "fused_moe", "topk", "moe_align_block_size", + # layers/attention/fla/ + "l2norm", "rms_norm_gated", + "fused_recurrent_gated_delta_rule", "fused_recurrent_gated_delta_rule_update", + "fused_sigmoid_gating_delta_rule_update", + "fused_sigmoid_gating_delta_rule_packed_decode", + "fused_gdn_gating", "layer_norm_gated_fwd", + # layers/attention/mamba/ + "causal_conv1d_fn", "causal_conv1d_update", + "selective_scan_update", "mamba_chunk_scan_combined_fwd", + "mixer2_rms_norm_gated", + # layers/elementwise.py + "fused_dual_residual_rmsnorm", + "softcap", "silu_and_mul_triton", "gelu_and_mul_triton", + "fused_rmsnorm", "experts_combine_triton", + # layers/gemma4_fused_ops.py + "gemma_rmsnorm_residual_scalar", "gemma_qkv_rmsnorm", + # layers/conv.py + "conv2d_layer", "conv3d_layer", + # layers/quantization/ + "per_token_quant_int8", +] diff --git a/src/kernelgenbench/dataset/baseline/sglang/apply_interleaved_rope_triton.py b/src/kernelgenbench/dataset/baseline/sglang/apply_interleaved_rope_triton.py new file mode 100644 index 0000000..2e23db9 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/apply_interleaved_rope_triton.py @@ -0,0 +1,22 @@ +""" +SGLang apply_interleaved_rope_triton baseline. +Source: sglang.srt.layers.rotary_embedding.mrope.apply_interleaved_rope_triton(x, mrope_section) +""" +import torch +import typing + +try: + from sglang.srt.layers.rotary_embedding.mrope import apply_interleaved_rope_triton as _apply_interleaved_rope +except ModuleNotFoundError: + _apply_interleaved_rope = None + + +def apply_interleaved_rope_triton( + x: torch.Tensor, + mrope_section: typing.List[int], +) -> torch.Tensor: + """ + Rearranges interleaved RoPE values along dim=0 using mrope_section. + x shape: [3, N, D] -> returns [N, D] + """ + return _apply_interleaved_rope(x, mrope_section) diff --git a/src/kernelgenbench/dataset/baseline/sglang/causal_conv1d_fn.py b/src/kernelgenbench/dataset/baseline/sglang/causal_conv1d_fn.py new file mode 100644 index 0000000..828e3a7 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/causal_conv1d_fn.py @@ -0,0 +1,29 @@ +""" +SGLang causal_conv1d_fn baseline. +""" +import torch +import typing +try: + import sgl_kernel as _sgl_kernel + _causal_conv1d_fn = _sgl_kernel.causal_conv1d_fwd +except (ModuleNotFoundError, AttributeError): + _causal_conv1d_fn = None + + + +def causal_conv1d_fn(x: torch.Tensor, weight: torch.Tensor, bias: typing.Optional[torch.Tensor] = None, query_start_loc: typing.Optional[torch.Tensor] = None, cache_indices: typing.Optional[torch.Tensor] = None, has_initial_state: typing.Optional[torch.Tensor] = None, conv_states: typing.Optional[torch.Tensor] = None, activation: str = 'silu') -> typing.Union[torch.Tensor, typing.Tuple[torch.Tensor, torch.Tensor]]: + """ + Args: + x: torch.Tensor + weight: torch.Tensor + bias: typing.Optional[torch.Tensor] + query_start_loc: typing.Optional[torch.Tensor] + cache_indices: typing.Optional[torch.Tensor] + has_initial_state: typing.Optional[torch.Tensor] + conv_states: typing.Optional[torch.Tensor] + activation: str + Returns: + typing.Union[torch.Tensor, typing.Tuple[torch.Tensor, torch.Tensor]] + """ + return _causal_conv1d_fn(x, weight, bias=bias, query_start_loc=query_start_loc, cache_indices=cache_indices, has_initial_state=has_initial_state, conv_states=conv_states, activation=activation) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/causal_conv1d_update.py b/src/kernelgenbench/dataset/baseline/sglang/causal_conv1d_update.py new file mode 100644 index 0000000..04085c0 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/causal_conv1d_update.py @@ -0,0 +1,26 @@ +""" +SGLang causal_conv1d_update baseline. +""" +import torch +import typing +try: + import sgl_kernel as _sgl_kernel + _causal_conv1d_update = _sgl_kernel.causal_conv1d_update +except (ModuleNotFoundError, AttributeError): + _causal_conv1d_update = None + + + +def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, bias: typing.Optional[torch.Tensor] = None, activation: str = 'silu') -> typing.Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: torch.Tensor + conv_state: torch.Tensor + weight: torch.Tensor + bias: typing.Optional[torch.Tensor] + activation: str + Returns: + typing.Tuple[torch.Tensor, torch.Tensor] + """ + return _causal_conv1d_update(x, conv_state, weight, bias=bias, activation=activation) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/conv2d_layer.py b/src/kernelgenbench/dataset/baseline/sglang/conv2d_layer.py new file mode 100644 index 0000000..38e99a4 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/conv2d_layer.py @@ -0,0 +1,39 @@ +""" +SGLang Conv2dLayer baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.conv import Conv2dLayer as _Conv2dLayer +except ModuleNotFoundError: + _Conv2dLayer = None + +_conv2d_module = None +_current_conv2d_config = None + +def _get_conv2d_module(in_channels, out_channels, kernel_size, stride, padding, bias): + global _conv2d_module, _current_conv2d_config + key = (in_channels, out_channels, kernel_size, stride, padding, bias) + if _current_conv2d_config != key: + _conv2d_module = _Conv2dLayer( + in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, bias=bias + ).cuda() + _current_conv2d_config = key + return _conv2d_module + +def conv2d_layer(x: torch.Tensor, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0, bias: bool = False) -> torch.Tensor: + """ + Args: + x: torch.Tensor + in_channels: int + out_channels: int + kernel_size: int + stride: int + padding: int + bias: bool + Returns: + torch.Tensor + """ + return _conv2d_module(x) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/conv3d_layer.py b/src/kernelgenbench/dataset/baseline/sglang/conv3d_layer.py new file mode 100644 index 0000000..edf1ede --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/conv3d_layer.py @@ -0,0 +1,39 @@ +""" +SGLang Conv3dLayer baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.conv import Conv3dLayer as _Conv3dLayer +except ModuleNotFoundError: + _Conv3dLayer = None + +_conv3d_module = None +_current_conv3d_config = None + +def _get_conv3d_module(in_channels, out_channels, kernel_size, stride, padding, bias): + global _conv3d_module, _current_conv3d_config + key = (in_channels, out_channels, kernel_size, stride, padding, bias) + if _current_conv3d_config != key: + _conv3d_module = _Conv3dLayer( + in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, bias=bias + ).cuda() + _current_conv3d_config = key + return _conv3d_module + +def conv3d_layer(x: torch.Tensor, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0, bias: bool = False) -> torch.Tensor: + """ + Args: + x: torch.Tensor + in_channels: int + out_channels: int + kernel_size: int + stride: int + padding: int + bias: bool + Returns: + torch.Tensor + """ + return _conv3d_module(x) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/deepseek_scaling_rotary_embedding.py b/src/kernelgenbench/dataset/baseline/sglang/deepseek_scaling_rotary_embedding.py new file mode 100644 index 0000000..b0f4ac9 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/deepseek_scaling_rotary_embedding.py @@ -0,0 +1,45 @@ +""" +SGLang DeepseekScalingRotaryEmbedding baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.rotary_embedding.rope_variant import DeepseekScalingRotaryEmbedding as _DeepseekScalingRotaryEmbedding +except ModuleNotFoundError: + _DeepseekScalingRotaryEmbedding = None + +_deepseek_rope_module = None +_current_dsrope_config = None + +def _get_deepseek_rope_module(head_size, rotary_dim, max_position_embeddings, base, scaling_factor): + global _deepseek_rope_module, _current_dsrope_config + key = (head_size, rotary_dim, max_position_embeddings, base, scaling_factor) + if _current_dsrope_config != key: + _deepseek_rope_module = _DeepseekScalingRotaryEmbedding( + head_size=head_size, rotary_dim=rotary_dim, + max_position_embeddings=max_position_embeddings, base=base, + scaling_factor=scaling_factor + ).cuda() + _current_dsrope_config = key + return _deepseek_rope_module + +def deepseek_scaling_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, offsets: typing.Optional[torch.Tensor] = None, head_size: int, rotary_dim: int, max_position_embeddings: int = 8192, base: float = 10000.0, scaling_factor: float = 1.0) -> typing.Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + positions: torch.Tensor + query: torch.Tensor + key: torch.Tensor + offsets: typing.Optional[torch.Tensor] + head_size: int + rotary_dim: int + max_position_embeddings: int + base: float + scaling_factor: float + Returns: + typing.Tuple[torch.Tensor, torch.Tensor] + """ + q = query.clone() + k = key.clone() + _deepseek_rope_module(q, k, positions, offsets) + return q, k + diff --git a/src/kernelgenbench/dataset/baseline/sglang/dual_chunk_rotary_embedding.py b/src/kernelgenbench/dataset/baseline/sglang/dual_chunk_rotary_embedding.py new file mode 100644 index 0000000..dca8283 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/dual_chunk_rotary_embedding.py @@ -0,0 +1,43 @@ +""" +SGLang DualChunkRotaryEmbedding baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.rotary_embedding.rope_variant import DualChunkRotaryEmbedding as _DualChunkRotaryEmbedding +except ModuleNotFoundError: + _DualChunkRotaryEmbedding = None + +_dual_chunk_module = None +_current_dc_config = None + +def _get_dual_chunk_module(head_size, rotary_dim, max_position_embeddings, base): + global _dual_chunk_module, _current_dc_config + key = (head_size, rotary_dim, max_position_embeddings, base) + if _current_dc_config != key: + _dual_chunk_module = _DualChunkRotaryEmbedding( + head_size=head_size, rotary_dim=rotary_dim, + max_position_embeddings=max_position_embeddings, base=base + ).cuda() + _current_dc_config = key + return _dual_chunk_module + +def dual_chunk_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, offsets: typing.Optional[torch.Tensor] = None, head_size: int, rotary_dim: int, max_position_embeddings: int = 8192, base: float = 10000.0) -> typing.Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + positions: torch.Tensor + query: torch.Tensor + key: torch.Tensor + offsets: typing.Optional[torch.Tensor] + head_size: int + rotary_dim: int + max_position_embeddings: int + base: float + Returns: + typing.Tuple[torch.Tensor, torch.Tensor] + """ + q = query.clone() + k = key.clone() + _dual_chunk_module(q, k, positions, offsets) + return q, k + diff --git a/src/kernelgenbench/dataset/baseline/sglang/dynamic_ntk_alpha_rotary_embedding.py b/src/kernelgenbench/dataset/baseline/sglang/dynamic_ntk_alpha_rotary_embedding.py new file mode 100644 index 0000000..5c1e67c --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/dynamic_ntk_alpha_rotary_embedding.py @@ -0,0 +1,49 @@ +""" +SGLang DynamicNTKAlphaRotaryEmbedding baseline. +Source: sglang.srt.layers.rotary_embedding.rope_variant.DynamicNTKAlphaRotaryEmbedding(head_size,rotary_dim,max_positions,base,is_neox,dtype,scaling_alpha) +""" +import torch +import typing + +try: + from sglang.srt.layers.rotary_embedding.rope_variant import DynamicNTKAlphaRotaryEmbedding as _DynamicNTKAlphaRotaryEmbedding +except ModuleNotFoundError: + _DynamicNTKAlphaRotaryEmbedding = None + +_module = None +_current_config = None + + +def _get_module(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, scaling_alpha): + global _module, _current_config + key = (head_size, rotary_dim, max_position_embeddings, base, is_neox_style, str(dtype), scaling_alpha) + if _current_config != key: + _module = _DynamicNTKAlphaRotaryEmbedding( + head_size=head_size, rotary_dim=rotary_dim, + max_position_embeddings=max_position_embeddings, base=base, + is_neox_style=is_neox_style, dtype=dtype or torch.bfloat16, + scaling_alpha=scaling_alpha + ).cuda() + _current_config = key + return _module + + +def dynamic_ntk_alpha_rotary_embedding( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: typing.Optional[torch.Tensor] = None, + head_size: int, + rotary_dim: int, + max_position_embeddings: int = 8192, + base: float = 10000.0, + is_neox_style: bool = True, + scaling_alpha: float = 1.0, + dtype: torch.dtype = None, +) -> typing.Tuple[torch.Tensor, torch.Tensor]: + """Dynamic NTK alpha-scaled RoPE. In-place modifies q, k.""" + module = _get_module(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, scaling_alpha) + q = query.clone() + k = key.clone() + module.forward_cuda(positions, q, k, offsets) + return q, k diff --git a/src/kernelgenbench/dataset/baseline/sglang/dynamic_ntk_scaling_rotary_embedding.py b/src/kernelgenbench/dataset/baseline/sglang/dynamic_ntk_scaling_rotary_embedding.py new file mode 100644 index 0000000..a3c9d12 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/dynamic_ntk_scaling_rotary_embedding.py @@ -0,0 +1,45 @@ +""" +SGLang DynamicNTKScalingRotaryEmbedding baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.rotary_embedding.rope_variant import DynamicNTKScalingRotaryEmbedding as _DynamicNTKScalingRotaryEmbedding +except ModuleNotFoundError: + _DynamicNTKScalingRotaryEmbedding = None + +_dynamic_ntk_module = None +_current_dntk_config = None + +def _get_dynamic_ntk_module(head_size, rotary_dim, max_position_embeddings, base, scaling_factor): + global _dynamic_ntk_module, _current_dntk_config + key = (head_size, rotary_dim, max_position_embeddings, base, scaling_factor) + if _current_dntk_config != key: + _dynamic_ntk_module = _DynamicNTKScalingRotaryEmbedding( + head_size=head_size, rotary_dim=rotary_dim, + max_position_embeddings=max_position_embeddings, base=base, + scaling_factor=scaling_factor + ).cuda() + _current_dntk_config = key + return _dynamic_ntk_module + +def dynamic_ntk_scaling_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, offsets: typing.Optional[torch.Tensor] = None, head_size: int, rotary_dim: int, max_position_embeddings: int = 8192, base: float = 10000.0, scaling_factor: float = 1.0) -> typing.Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + positions: torch.Tensor + query: torch.Tensor + key: torch.Tensor + offsets: typing.Optional[torch.Tensor] + head_size: int + rotary_dim: int + max_position_embeddings: int + base: float + scaling_factor: float + Returns: + typing.Tuple[torch.Tensor, torch.Tensor] + """ + q = query.clone() + k = key.clone() + _dynamic_ntk_module(q, k, positions, offsets) + return q, k + diff --git a/src/kernelgenbench/dataset/baseline/sglang/experts_combine_triton.py b/src/kernelgenbench/dataset/baseline/sglang/experts_combine_triton.py new file mode 100644 index 0000000..cd021c3 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/experts_combine_triton.py @@ -0,0 +1,23 @@ +""" +SGLang experts_combine_triton baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.elementwise import experts_combine_triton as _experts_combine +except ModuleNotFoundError: + _experts_combine = None + + + +def experts_combine_triton(moe_hidden_states: torch.Tensor, mlp_hidden_states: torch.Tensor, output_buffer: typing.Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Args: + moe_hidden_states: torch.Tensor + mlp_hidden_states: torch.Tensor + output_buffer: typing.Optional[torch.Tensor] + Returns: + torch.Tensor + """ + return _experts_combine(moe_hidden_states, mlp_hidden_states, output_buffer=output_buffer) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/fused_dual_residual_rmsnorm.py b/src/kernelgenbench/dataset/baseline/sglang/fused_dual_residual_rmsnorm.py new file mode 100644 index 0000000..001b3f8 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/fused_dual_residual_rmsnorm.py @@ -0,0 +1,34 @@ +""" +SGLang FusedDualResidualRMSNorm baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.elementwise import FusedDualResidualRMSNorm as _FusedDualResidualRMSNorm +except ModuleNotFoundError: + _FusedDualResidualRMSNorm = None + +_fused_dual_residual_rmsnorm_module = None +_current_fdrn_config = None + +def _get_fused_dual_residual_rmsnorm_module(hidden_size_1, hidden_size_2, eps): + global _fused_dual_residual_rmsnorm_module, _current_fdrn_config + key = (hidden_size_1, hidden_size_2, eps) + if _current_fdrn_config != key: + _fused_dual_residual_rmsnorm_module = _FusedDualResidualRMSNorm(hidden_size_1=hidden_size_1, hidden_size_2=hidden_size_2, eps=eps).cuda() + _current_fdrn_config = key + return _fused_dual_residual_rmsnorm_module + +def fused_dual_residual_rmsnorm(x: torch.Tensor, residual: torch.Tensor, hidden_size_1: int, hidden_size_2: int, eps: float = 1e-6) -> typing.Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: torch.Tensor + residual: torch.Tensor + hidden_size_1: int + hidden_size_2: int + eps: float + Returns: + typing.Tuple[torch.Tensor, torch.Tensor] + """ + return _fused_dual_residual_rmsnorm_module(x, residual) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/fused_gdn_gating.py b/src/kernelgenbench/dataset/baseline/sglang/fused_gdn_gating.py new file mode 100644 index 0000000..298fe66 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/fused_gdn_gating.py @@ -0,0 +1,26 @@ +""" +SGLang fused_gdn_gating baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.attention.fla.fused_gdn_gating import fused_gdn_gating as _fused_gdn +except ModuleNotFoundError: + _fused_gdn = None + + + +def fused_gdn_gating(A_log: torch.Tensor, a: torch.Tensor, b: torch.Tensor, dt_bias: torch.Tensor, beta: float = 1.0, threshold: float = 20.0) -> typing.Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + A_log: torch.Tensor + a: torch.Tensor + b: torch.Tensor + dt_bias: torch.Tensor + beta: float + threshold: float + Returns: + typing.Tuple[torch.Tensor, torch.Tensor] + """ + return _fused_gdn(A_log=A_log, a=a, b=b, dt_bias=dt_bias, beta=beta, threshold=threshold) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/fused_moe.py b/src/kernelgenbench/dataset/baseline/sglang/fused_moe.py new file mode 100644 index 0000000..eb01b72 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/fused_moe.py @@ -0,0 +1,29 @@ +""" +SGLang fused_moe baseline (Triton fused MoE). +""" +import torch +import typing +try: + from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import triton_kernel_fused_experts as _fused_experts +except ModuleNotFoundError: + _fused_experts = None + + + +def fused_moe(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, routing_data: "typing.Any", gather_indx: "typing.Any", scatter_indx: "typing.Any", inplace: bool = False, activation: str = 'silu', apply_router_weight_on_input: bool = False) -> torch.Tensor: + """ + Args: + hidden_states: torch.Tensor + w1: torch.Tensor + w2: torch.Tensor + routing_data: "typing.Any" + gather_indx: "typing.Any" + scatter_indx: "typing.Any" + inplace: bool + activation: str + apply_router_weight_on_input: bool + Returns: + torch.Tensor + """ + return _fused_experts(hidden_states, w1, w2, routing_data, gather_indx, scatter_indx, inplace=inplace, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/fused_recurrent_gated_delta_rule.py b/src/kernelgenbench/dataset/baseline/sglang/fused_recurrent_gated_delta_rule.py new file mode 100644 index 0000000..2e2cc9d --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/fused_recurrent_gated_delta_rule.py @@ -0,0 +1,29 @@ +""" +SGLang fused_recurrent_gated_delta_rule baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.attention.fla.fused_recurrent import fused_recurrent_gated_delta_rule as _fused_rule +except ModuleNotFoundError: + _fused_rule = None + + + +def fused_recurrent_gated_delta_rule(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor, beta: torch.Tensor, scale: float, initial_state: typing.Optional[torch.Tensor] = None, output_final_state: bool = False, cu_seqlens: typing.Optional[torch.Tensor] = None) -> typing.Union[torch.Tensor, typing.Tuple[torch.Tensor, torch.Tensor]]: + """ + Args: + q: torch.Tensor + k: torch.Tensor + v: torch.Tensor + g: torch.Tensor + beta: torch.Tensor + scale: float + initial_state: typing.Optional[torch.Tensor] + output_final_state: bool + cu_seqlens: typing.Optional[torch.Tensor] + Returns: + typing.Union[torch.Tensor, typing.Tuple[torch.Tensor, torch.Tensor]] + """ + return _fused_rule(q=q, k=k, v=v, g=g, beta=beta, scale=scale, initial_state=initial_state, output_final_state=output_final_state, cu_seqlens=cu_seqlens) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/fused_recurrent_gated_delta_rule_update.py b/src/kernelgenbench/dataset/baseline/sglang/fused_recurrent_gated_delta_rule_update.py new file mode 100644 index 0000000..acdb15d --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/fused_recurrent_gated_delta_rule_update.py @@ -0,0 +1,31 @@ +""" +SGLang fused_recurrent_gated_delta_rule_update baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.attention.fla.fused_recurrent import fused_recurrent_gated_delta_rule_update as _fused_rule_update +except ModuleNotFoundError: + _fused_rule_update = None + + + +def fused_recurrent_gated_delta_rule_update(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor, beta: torch.Tensor, scale: float, initial_state: torch.Tensor, cu_seqlens: typing.Optional[torch.Tensor] = None, initial_state_indices: typing.Optional[torch.Tensor] = None, intermediate_states: typing.Optional[torch.Tensor] = None, eagle_tree: typing.Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Args: + q: torch.Tensor + k: torch.Tensor + v: torch.Tensor + g: torch.Tensor + beta: torch.Tensor + scale: float + initial_state: torch.Tensor + cu_seqlens: typing.Optional[torch.Tensor] + initial_state_indices: typing.Optional[torch.Tensor] + intermediate_states: typing.Optional[torch.Tensor] + eagle_tree: typing.Optional[torch.Tensor] + Returns: + torch.Tensor + """ + return _fused_rule_update(q=q, k=k, v=v, g=g, beta=beta, scale=scale, initial_state=initial_state, cu_seqlens=cu_seqlens, initial_state_indices=initial_state_indices, intermediate_states=intermediate_states, eagle_tree=eagle_tree) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/fused_rmsnorm.py b/src/kernelgenbench/dataset/baseline/sglang/fused_rmsnorm.py new file mode 100644 index 0000000..1819db7 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/fused_rmsnorm.py @@ -0,0 +1,25 @@ +""" +SGLang fused_rmsnorm baseline (pure Triton). +""" +import torch +import typing +try: + from sglang.srt.layers.elementwise import fused_rmsnorm as _fused_rmsnorm_fn +except ModuleNotFoundError: + _fused_rmsnorm_fn = None + + + +def fused_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6, autotune: bool = False, inplace: bool = False) -> torch.Tensor: + """ + Args: + x: torch.Tensor + weight: torch.Tensor + eps: float + autotune: bool + inplace: bool + Returns: + torch.Tensor + """ + return _fused_rmsnorm_fn(x, weight, eps, autotune=autotune, inplace=inplace) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/fused_sigmoid_gating_delta_rule_packed_decode.py b/src/kernelgenbench/dataset/baseline/sglang/fused_sigmoid_gating_delta_rule_packed_decode.py new file mode 100644 index 0000000..a2ce7ef --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/fused_sigmoid_gating_delta_rule_packed_decode.py @@ -0,0 +1,30 @@ +""" +SGLang fused_sigmoid_gating_delta_rule_packed_decode baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import fused_sigmoid_gating_delta_rule_packed_decode as _fused_sig_rule_packed +except ModuleNotFoundError: + _fused_sig_rule_packed = None + + + +def fused_sigmoid_gating_delta_rule_packed_decode(mixed_qkv: torch.Tensor, a: torch.Tensor, b: torch.Tensor, A_log: torch.Tensor, dt_bias: torch.Tensor, scale: float, initial_state: torch.Tensor, out: torch.Tensor, ssm_state_indices: torch.Tensor, use_qk_l2norm_in_kernel: bool = False) -> typing.Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + mixed_qkv: torch.Tensor + a: torch.Tensor + b: torch.Tensor + A_log: torch.Tensor + dt_bias: torch.Tensor + scale: float + initial_state: torch.Tensor + out: torch.Tensor + ssm_state_indices: torch.Tensor + use_qk_l2norm_in_kernel: bool + Returns: + typing.Tuple[torch.Tensor, torch.Tensor] + """ + return _fused_sig_rule_packed(mixed_qkv=mixed_qkv, a=a, b=b, A_log=A_log, dt_bias=dt_bias, scale=scale, initial_state=initial_state, out=out, ssm_state_indices=ssm_state_indices, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/fused_sigmoid_gating_delta_rule_update.py b/src/kernelgenbench/dataset/baseline/sglang/fused_sigmoid_gating_delta_rule_update.py new file mode 100644 index 0000000..f4cd9c3 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/fused_sigmoid_gating_delta_rule_update.py @@ -0,0 +1,31 @@ +""" +SGLang fused_sigmoid_gating_delta_rule_update baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import fused_sigmoid_gating_delta_rule_update as _fused_sig_rule_update +except ModuleNotFoundError: + _fused_sig_rule_update = None + + + +def fused_sigmoid_gating_delta_rule_update(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, A_log: torch.Tensor, a: torch.Tensor, dt_bias: torch.Tensor, b: torch.Tensor, scale: float, initial_state: torch.Tensor, cu_seqlens: typing.Optional[torch.Tensor] = None, initial_state_indices: typing.Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Args: + q: torch.Tensor + k: torch.Tensor + v: torch.Tensor + A_log: torch.Tensor + a: torch.Tensor + dt_bias: torch.Tensor + b: torch.Tensor + scale: float + initial_state: torch.Tensor + cu_seqlens: typing.Optional[torch.Tensor] + initial_state_indices: typing.Optional[torch.Tensor] + Returns: + torch.Tensor + """ + return _fused_sig_rule_update(q=q, k=k, v=v, A_log=A_log, a=a, dt_bias=dt_bias, b=b, scale=scale, initial_state=initial_state, cu_seqlens=cu_seqlens, initial_state_indices=initial_state_indices) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/gelu_and_mul.py b/src/kernelgenbench/dataset/baseline/sglang/gelu_and_mul.py new file mode 100644 index 0000000..89e81d0 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/gelu_and_mul.py @@ -0,0 +1,25 @@ +""" +SGLang GeluAndMul activation baseline. +""" +import torch +import typing +try: + import sgl_kernel.ops.elementwise as _ops +except ModuleNotFoundError: + _ops = None + + + +def gelu_and_mul(x: torch.Tensor, approximate: str = 'tanh') -> torch.Tensor: + """ + Args: + x: torch.Tensor + approximate: str + Returns: + torch.Tensor + """ + x = x.contiguous() + out = torch.empty(x.shape[0], x.shape[-1] // 2, dtype=x.dtype, device=x.device) + _ops.gelu_and_mul(x, out, approximate) + return out + diff --git a/src/kernelgenbench/dataset/baseline/sglang/gelu_and_mul_triton.py b/src/kernelgenbench/dataset/baseline/sglang/gelu_and_mul_triton.py new file mode 100644 index 0000000..2b77a95 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/gelu_and_mul_triton.py @@ -0,0 +1,24 @@ +""" +SGLang gelu_and_mul_triton baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.elementwise import gelu_and_mul_triton as _gelu_and_mul_t +except ModuleNotFoundError: + _gelu_and_mul_t = None + + + +def gelu_and_mul_triton(hidden_states: torch.Tensor, scales: typing.Optional[torch.Tensor] = None, quantize: typing.Optional[str] = None, out: typing.Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Args: + hidden_states: torch.Tensor + scales: typing.Optional[torch.Tensor] + quantize: typing.Optional[str] + out: typing.Optional[torch.Tensor] + Returns: + torch.Tensor + """ + return _gelu_and_mul_t(hidden_states, scales=scales, quantize=quantize, out=out)[0] + diff --git a/src/kernelgenbench/dataset/baseline/sglang/gemma3_rms_norm.py b/src/kernelgenbench/dataset/baseline/sglang/gemma3_rms_norm.py new file mode 100644 index 0000000..da6a6bc --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/gemma3_rms_norm.py @@ -0,0 +1,32 @@ +""" +SGLang Gemma3RMSNorm baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.layernorm import Gemma3RMSNorm as _Gemma3RMSNorm +except ModuleNotFoundError: + _Gemma3RMSNorm = None + +_gemma3_rms_norm_module = None +_current_g3rn_config = None + +def _get_gemma3_rms_norm_module(hidden_size, eps): + global _gemma3_rms_norm_module, _current_g3rn_config + key = (hidden_size, eps) + if _current_g3rn_config != key: + _gemma3_rms_norm_module = _Gemma3RMSNorm(hidden_size=hidden_size, eps=eps).cuda() + _current_g3rn_config = key + return _gemma3_rms_norm_module + +def gemma3_rms_norm(x: torch.Tensor, hidden_size: int, eps: float = 1e-6) -> torch.Tensor: + """ + Args: + x: torch.Tensor + hidden_size: int + eps: float + Returns: + torch.Tensor + """ + return _gemma3_rms_norm_module(x) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/gemma4_rms_norm.py b/src/kernelgenbench/dataset/baseline/sglang/gemma4_rms_norm.py new file mode 100644 index 0000000..730206b --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/gemma4_rms_norm.py @@ -0,0 +1,33 @@ +""" +SGLang Gemma4RMSNorm baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.layernorm import Gemma4RMSNorm as _Gemma4RMSNorm +except ModuleNotFoundError: + _Gemma4RMSNorm = None + +_gemma4_rms_norm_module = None +_current_g4rn_config = None + +def _get_gemma4_rms_norm_module(hidden_size, eps): + global _gemma4_rms_norm_module, _current_g4rn_config + key = (hidden_size, eps) + if _current_g4rn_config != key: + _gemma4_rms_norm_module = _Gemma4RMSNorm(hidden_size=hidden_size, eps=eps).cuda() + _current_g4rn_config = key + return _gemma4_rms_norm_module + +def gemma4_rms_norm(x: torch.Tensor, hidden_size: int, eps: float = 1e-6, scale_shift: typing.Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Args: + x: torch.Tensor + hidden_size: int + eps: float + scale_shift: typing.Optional[torch.Tensor] + Returns: + torch.Tensor + """ + return _gemma4_rms_norm_module(x, scale_shift) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/gemma_qkv_rmsnorm.py b/src/kernelgenbench/dataset/baseline/sglang/gemma_qkv_rmsnorm.py new file mode 100644 index 0000000..b523150 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/gemma_qkv_rmsnorm.py @@ -0,0 +1,29 @@ +""" +SGLang gemma_qkv_rmsnorm baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.gemma4_fused_ops import gemma_qkv_rmsnorm as _gemma_qkv_rmsnorm +except ModuleNotFoundError: + _gemma_qkv_rmsnorm = None + + + +def gemma_qkv_rmsnorm(q: torch.Tensor, k: typing.Optional[torch.Tensor], v: typing.Optional[torch.Tensor], q_weight: torch.Tensor, k_weight: typing.Optional[torch.Tensor], num_q_heads: int, num_kv_heads: int, head_dim: int, eps: float = 1e-6) -> None: + """ + Args: + q: torch.Tensor + k: typing.Optional[torch.Tensor] + v: typing.Optional[torch.Tensor] + q_weight: torch.Tensor + k_weight: typing.Optional[torch.Tensor] + num_q_heads: int + num_kv_heads: int + head_dim: int + eps: float + Returns: + None + """ + _gemma_qkv_rmsnorm(q, k, v, q_weight, k_weight, num_q_heads=num_q_heads, num_kv_heads=num_kv_heads, head_dim=head_dim, eps=eps) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/gemma_rms_norm.py b/src/kernelgenbench/dataset/baseline/sglang/gemma_rms_norm.py new file mode 100644 index 0000000..f892831 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/gemma_rms_norm.py @@ -0,0 +1,38 @@ +""" +SGLang GemmaRMSNorm baseline. +""" +import torch +import typing +try: + import sgl_kernel.ops.norm as _ops +except ModuleNotFoundError: + _ops = None + + + +def gemma_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6, residual: typing.Optional[torch.Tensor] = None, out_residual: typing.Optional[torch.Tensor] = None) -> typing.Union[torch.Tensor, typing.Tuple[torch.Tensor, torch.Tensor]]: + """ + Args: + x: torch.Tensor + weight: torch.Tensor + eps: float + residual: typing.Optional[torch.Tensor] + out_residual: typing.Optional[torch.Tensor] + Returns: + typing.Union[torch.Tensor, typing.Tuple[torch.Tensor, torch.Tensor]] + """ + x = x.contiguous() + out = torch.empty_like(x) + residual_out = None + if residual is not None: + residual_c = residual.contiguous() + if out_residual is None: + out_residual = torch.empty_like(residual_c) + _ops.gemma_fused_add_rmsnorm(x, residual_c, weight, eps, out, out_residual) + residual_out = out_residual + return out, residual_out + _ops.gemma_rmsnorm(out, x, weight, eps) + if residual is None: + return out + return (out,) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/gemma_rmsnorm_residual_scalar.py b/src/kernelgenbench/dataset/baseline/sglang/gemma_rmsnorm_residual_scalar.py new file mode 100644 index 0000000..4429e0a --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/gemma_rmsnorm_residual_scalar.py @@ -0,0 +1,25 @@ +""" +SGLang gemma_rmsnorm_residual_scalar baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.gemma4_fused_ops import gemma_rmsnorm_residual_scalar as _gemma_residual_scalar +except ModuleNotFoundError: + _gemma_residual_scalar = None + + + +def gemma_rmsnorm_residual_scalar(x: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, scalar: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + """ + Args: + x: torch.Tensor + weight: torch.Tensor + residual: torch.Tensor + scalar: torch.Tensor + eps: float + Returns: + torch.Tensor + """ + return _gemma_residual_scalar(x, weight, residual, scalar, eps=eps) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/l2norm.py b/src/kernelgenbench/dataset/baseline/sglang/l2norm.py new file mode 100644 index 0000000..88e5f50 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/l2norm.py @@ -0,0 +1,22 @@ +""" +SGLang l2norm baseline (FLA Triton kernel). +""" +import torch +import typing +try: + from sglang.srt.layers.attention.fla.l2norm import l2norm as _l2norm_fn +except ModuleNotFoundError: + _l2norm_fn = None + + + +def l2norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + """ + Args: + x: torch.Tensor + eps: float + Returns: + torch.Tensor + """ + return _l2norm_fn(x, eps=eps) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/layer_norm.py b/src/kernelgenbench/dataset/baseline/sglang/layer_norm.py new file mode 100644 index 0000000..ac176aa --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/layer_norm.py @@ -0,0 +1,33 @@ +""" +SGLang LayerNorm baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.layernorm import LayerNorm as _LayerNorm +except ModuleNotFoundError: + _LayerNorm = None + +_layer_norm_module = None +_current_ln_config = None + +def _get_layer_norm_module(normalized_shape, eps, elementwise_affine): + global _layer_norm_module, _current_ln_config + key = (normalized_shape, eps, elementwise_affine) + if _current_ln_config != key: + _layer_norm_module = _LayerNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine).cuda() + _current_ln_config = key + return _layer_norm_module + +def layer_norm(x: torch.Tensor, normalized_shape: int, eps: float = 1e-5, elementwise_affine: bool = True) -> torch.Tensor: + """ + Args: + x: torch.Tensor + normalized_shape: int + eps: float + elementwise_affine: bool + Returns: + torch.Tensor + """ + return _layer_norm_module(x) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/layer_norm_gated_fwd.py b/src/kernelgenbench/dataset/baseline/sglang/layer_norm_gated_fwd.py new file mode 100644 index 0000000..a6c5a7f --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/layer_norm_gated_fwd.py @@ -0,0 +1,27 @@ +""" +SGLang layer_norm_gated_fwd baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.attention.fla.fused_norm_gate import layer_norm_gated_fwd as _layer_norm_gated +except ModuleNotFoundError: + _layer_norm_gated = None + + + +def layer_norm_gated_fwd(x: torch.Tensor, g: torch.Tensor, weight: torch.Tensor, bias: typing.Optional[torch.Tensor] = None, activation: str = 'swish', eps: float = 1e-5, residual: typing.Optional[torch.Tensor] = None) -> typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, typling.Optional[torch.Tensor]]: + """ + Args: + x: torch.Tensor + g: torch.Tensor + weight: torch.Tensor + bias: typing.Optional[torch.Tensor] + activation: str + eps: float + residual: typing.Optional[torch.Tensor] + Returns: + typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, typling.Optional[torch.Tensor]] + """ + return _layer_norm_gated(x, g, weight, bias, activation=activation, eps=eps, residual=residual) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/linear_scaling_rotary_embedding.py b/src/kernelgenbench/dataset/baseline/sglang/linear_scaling_rotary_embedding.py new file mode 100644 index 0000000..4d3761d --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/linear_scaling_rotary_embedding.py @@ -0,0 +1,45 @@ +""" +SGLang LinearScalingRotaryEmbedding baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.rotary_embedding.base import LinearScalingRotaryEmbedding as _LinearScalingRotaryEmbedding +except ModuleNotFoundError: + _LinearScalingRotaryEmbedding = None + +_linear_scaling_module = None +_current_ls_config = None + +def _get_linear_scaling_module(head_size, rotary_dim, max_position_embeddings, base, scaling_factor): + global _linear_scaling_module, _current_ls_config + key = (head_size, rotary_dim, max_position_embeddings, base, scaling_factor) + if _current_ls_config != key: + _linear_scaling_module = _LinearScalingRotaryEmbedding( + head_size=head_size, rotary_dim=rotary_dim, + max_position_embeddings=max_position_embeddings, base=base, + scaling_factor=scaling_factor + ).cuda() + _current_ls_config = key + return _linear_scaling_module + +def linear_scaling_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, offsets: typing.Optional[torch.Tensor] = None, head_size: int, rotary_dim: int, max_position_embeddings: int = 8192, base: float = 10000.0, scaling_factor: float = 1.0) -> typing.Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + positions: torch.Tensor + query: torch.Tensor + key: torch.Tensor + offsets: typing.Optional[torch.Tensor] + head_size: int + rotary_dim: int + max_position_embeddings: int + base: float + scaling_factor: float + Returns: + typing.Tuple[torch.Tensor, torch.Tensor] + """ + q = query.clone() + k = key.clone() + _linear_scaling_module(q, k, positions, offsets) + return q, k + diff --git a/src/kernelgenbench/dataset/baseline/sglang/llama3_rotary_embedding.py b/src/kernelgenbench/dataset/baseline/sglang/llama3_rotary_embedding.py new file mode 100644 index 0000000..e643c1c --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/llama3_rotary_embedding.py @@ -0,0 +1,50 @@ +""" +SGLang Llama3RotaryEmbedding baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.rotary_embedding.rope_variant import Llama3RotaryEmbedding as _Llama3RotaryEmbedding +except ModuleNotFoundError: + _Llama3RotaryEmbedding = None + +_llama3_rope_module = None +_current_l3rope_config = None + +def _get_llama3_rope_module(head_size, rotary_dim, max_position_embeddings, base, factor, low_freq_factor, high_freq_factor, original_max_position_embeddings): + global _llama3_rope_module, _current_l3rope_config + key = (head_size, rotary_dim, max_position_embeddings, base, factor, low_freq_factor, high_freq_factor, original_max_position_embeddings) + if _current_l3rope_config != key: + _llama3_rope_module = _Llama3RotaryEmbedding( + head_size=head_size, rotary_dim=rotary_dim, + max_position_embeddings=max_position_embeddings, base=base, + factor=factor, low_freq_factor=low_freq_factor, + high_freq_factor=high_freq_factor, + original_max_position_embeddings=original_max_position_embeddings + ).cuda() + _current_l3rope_config = key + return _llama3_rope_module + +def llama3_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, offsets: typing.Optional[torch.Tensor] = None, head_size: int, rotary_dim: int, max_position_embeddings: int = 8192, base: float = 500000.0, factor: float = 8.0, low_freq_factor: float = 1.0, high_freq_factor: float = 4.0, original_max_position_embeddings: int = 8192) -> typing.Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + positions: torch.Tensor + query: torch.Tensor + key: torch.Tensor + offsets: typing.Optional[torch.Tensor] + head_size: int + rotary_dim: int + max_position_embeddings: int + base: float + factor: float + low_freq_factor: float + high_freq_factor: float + original_max_position_embeddings: int + Returns: + typing.Tuple[torch.Tensor, torch.Tensor] + """ + q = query.clone() + k = key.clone() + _llama3_rope_module(q, k, positions, offsets) + return q, k + diff --git a/src/kernelgenbench/dataset/baseline/sglang/mamba_chunk_scan_combined_fwd.py b/src/kernelgenbench/dataset/baseline/sglang/mamba_chunk_scan_combined_fwd.py new file mode 100644 index 0000000..d2c8f50 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/mamba_chunk_scan_combined_fwd.py @@ -0,0 +1,31 @@ +""" +SGLang mamba_chunk_scan_combined_fwd baseline (Mamba-2 SSD). +""" +import torch +import typing +try: + from sglang.srt.layers.attention.mamba.ops.ssd_combined import mamba_chunk_scan_combined_fwd as _mamba_scan +except ModuleNotFoundError: + _mamba_scan = None + + + +def mamba_chunk_scan_combined_fwd(x: torch.Tensor, dt: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, chunk_size: int = 64, D: typing.Optional[torch.Tensor] = None, z: typing.Optional[torch.Tensor] = None, dt_bias: typing.Optional[torch.Tensor] = None, initial_states: typing.Optional[torch.Tensor] = None, cu_seqlens: typing.Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Args: + x: torch.Tensor + dt: torch.Tensor + A: torch.Tensor + B: torch.Tensor + C: torch.Tensor + chunk_size: int + D: typing.Optional[torch.Tensor] + z: typing.Optional[torch.Tensor] + dt_bias: typing.Optional[torch.Tensor] + initial_states: typing.Optional[torch.Tensor] + cu_seqlens: typing.Optional[torch.Tensor] + Returns: + torch.Tensor + """ + return _mamba_scan(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, cu_seqlens=cu_seqlens) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/mixer2_rms_norm_gated.py b/src/kernelgenbench/dataset/baseline/sglang/mixer2_rms_norm_gated.py new file mode 100644 index 0000000..0b5f4a6 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/mixer2_rms_norm_gated.py @@ -0,0 +1,33 @@ +""" +SGLang Mixer2RMSNormGated baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.attention.mamba.mixer2_rms_norm_gated import Mixer2RMSNormGated as _Mixer2RMSNormGated +except ModuleNotFoundError: + _Mixer2RMSNormGated = None + +_mixer2_module = None +_current_mixer2_config = None + +def _get_mixer2_module(hidden_size, eps): + global _mixer2_module, _current_mixer2_config + key = (hidden_size, eps) + if _current_mixer2_config != key: + _mixer2_module = _Mixer2RMSNormGated(hidden_size=hidden_size, eps=eps).cuda() + _current_mixer2_config = key + return _mixer2_module + +def mixer2_rms_norm_gated(x: torch.Tensor, gate: torch.Tensor, hidden_size: int, eps: float = 1e-6) -> torch.Tensor: + """ + Args: + x: torch.Tensor + gate: torch.Tensor + hidden_size: int + eps: float + Returns: + torch.Tensor + """ + return _mixer2_module(x, gate) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/moe_align_block_size.py b/src/kernelgenbench/dataset/baseline/sglang/moe_align_block_size.py new file mode 100644 index 0000000..525bfb4 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/moe_align_block_size.py @@ -0,0 +1,23 @@ +""" +SGLang moe_align_block_size baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.moe.moe_runner.triton_utils.moe_align_block_size import moe_align_block_size as _moe_align +except ModuleNotFoundError: + _moe_align = None + + + +def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int) -> "typing.Any": + """ + Args: + topk_ids: torch.Tensor + num_experts: int + block_size: int + Returns: + "typing.Any" + """ + return _moe_align(topk_ids, num_experts, block_size) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/mrotary_embedding.py b/src/kernelgenbench/dataset/baseline/sglang/mrotary_embedding.py new file mode 100644 index 0000000..3c68606 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/mrotary_embedding.py @@ -0,0 +1,48 @@ +""" +SGLang MRotaryEmbedding baseline (multimodal RoPE). +""" +import torch +import typing +try: + from sglang.srt.layers.rotary_embedding.mrope import MRotaryEmbedding as _MRotaryEmbedding +except ModuleNotFoundError: + _MRotaryEmbedding = None + +_mrotary_module = None +_current_mrope_config = None + +def _get_mrotary_module(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, mrope_section, mrope_interleaved, dtype): + global _mrotary_module, _current_mrope_config + key = (head_size, rotary_dim, max_position_embeddings, base, is_neox_style, tuple(mrope_section or []), mrope_interleaved, str(dtype)) + if _current_mrope_config != key: + _mrotary_module = _MRotaryEmbedding( + head_size=head_size, rotary_dim=rotary_dim, + max_position_embeddings=max_position_embeddings, base=base, + is_neox_style=is_neox_style, mrope_section=mrope_section, + mrope_interleaved=mrope_interleaved, dtype=dtype or torch.bfloat16 + ).cuda() + _current_mrope_config = key + return _mrotary_module + +def mrotary_embedding(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, head_size: int, rotary_dim: int, max_position_embeddings: int = 8192, base: float = 10000.0, is_neox_style: bool = True, mrope_section: typing.Optional[typing.List[int]] = None, mrope_interleaved: bool = False, dtype: torch.dtype = None) -> typing.Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + positions: torch.Tensor + query: torch.Tensor + key: torch.Tensor + head_size: int + rotary_dim: int + max_position_embeddings: int + base: float + is_neox_style: bool + mrope_section: typing.Optional[typing.List[int]] + mrope_interleaved: bool + dtype: torch.dtype + Returns: + typing.Tuple[torch.Tensor, torch.Tensor] + """ + q = query.clone() + k = key.clone() + _mrotary_module.forward_cuda(positions, q, k) + return q, k + diff --git a/src/kernelgenbench/dataset/baseline/sglang/new_gelu.py b/src/kernelgenbench/dataset/baseline/sglang/new_gelu.py new file mode 100644 index 0000000..db11d8a --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/new_gelu.py @@ -0,0 +1,27 @@ +""" +SGLang NewGELU activation baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.activation import NewGELU as _NewGELU +except ModuleNotFoundError: + _NewGELU = None + +_new_gelu_module = None + +def _get_new_gelu_module(): + global _new_gelu_module + if _new_gelu_module is None: + _new_gelu_module = _NewGELU().cuda() + return _new_gelu_module + +def new_gelu(x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: torch.Tensor + Returns: + torch.Tensor + """ + return _new_gelu_module(x) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/per_token_quant_int8.py b/src/kernelgenbench/dataset/baseline/sglang/per_token_quant_int8.py new file mode 100644 index 0000000..7ae1068 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/per_token_quant_int8.py @@ -0,0 +1,22 @@ +""" +SGLang per_token_quant_int8 baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 as _per_token_quant +except ModuleNotFoundError: + _per_token_quant = None + + + +def per_token_quant_int8(x: torch.Tensor, scale_dtype: torch.dtype = torch.float32) -> typing.Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: torch.Tensor + scale_dtype: torch.dtype + Returns: + typing.Tuple[torch.Tensor, torch.Tensor] + """ + return _per_token_quant(x, scale_dtype=scale_dtype) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/phi3_long_rope_scaled_rotary_embedding.py b/src/kernelgenbench/dataset/baseline/sglang/phi3_long_rope_scaled_rotary_embedding.py new file mode 100644 index 0000000..807e792 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/phi3_long_rope_scaled_rotary_embedding.py @@ -0,0 +1,48 @@ +""" +SGLang Phi3LongRoPEScaledRotaryEmbedding baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.rotary_embedding.rope_variant import Phi3LongRoPEScaledRotaryEmbedding as _Phi3LongRoPEScaledRotaryEmbedding +except ModuleNotFoundError: + _Phi3LongRoPEScaledRotaryEmbedding = None + +_phi3_rope_module = None +_current_phi3_config = None + +def _get_phi3_rope_module(head_size, rotary_dim, max_position_embeddings, base, short_factor, long_factor, original_max_position_embeddings): + global _phi3_rope_module, _current_phi3_config + key = (head_size, rotary_dim, max_position_embeddings, base, tuple(short_factor or []), tuple(long_factor or []), original_max_position_embeddings) + if _current_phi3_config != key: + _phi3_rope_module = _Phi3LongRoPEScaledRotaryEmbedding( + head_size=head_size, rotary_dim=rotary_dim, + max_position_embeddings=max_position_embeddings, base=base, + short_factor=short_factor, long_factor=long_factor, + original_max_position_embeddings=original_max_position_embeddings + ).cuda() + _current_phi3_config = key + return _phi3_rope_module + +def phi3_long_rope_scaled_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, offsets: typing.Optional[torch.Tensor] = None, head_size: int, rotary_dim: int, max_position_embeddings: int = 131072, base: float = 10000.0, short_factor: typing.List[float] = None, long_factor: typing.List[float] = None, original_max_position_embeddings: int = 4096) -> typing.Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + positions: torch.Tensor + query: torch.Tensor + key: torch.Tensor + offsets: typing.Optional[torch.Tensor] + head_size: int + rotary_dim: int + max_position_embeddings: int + base: float + short_factor: typing.List[float] + long_factor: typing.List[float] + original_max_position_embeddings: int + Returns: + typing.Tuple[torch.Tensor, torch.Tensor] + """ + q = query.clone() + k = key.clone() + _phi3_rope_module(q, k, positions, offsets) + return q, k + diff --git a/src/kernelgenbench/dataset/baseline/sglang/quick_gelu.py b/src/kernelgenbench/dataset/baseline/sglang/quick_gelu.py new file mode 100644 index 0000000..f7c5b39 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/quick_gelu.py @@ -0,0 +1,27 @@ +""" +SGLang QuickGELU activation baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.activation import QuickGELU as _QuickGELU +except ModuleNotFoundError: + _QuickGELU = None + +_quick_gelu_module = None + +def _get_quick_gelu_module(): + global _quick_gelu_module + if _quick_gelu_module is None: + _quick_gelu_module = _QuickGELU().cuda() + return _quick_gelu_module + +def quick_gelu(x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: torch.Tensor + Returns: + torch.Tensor + """ + return _quick_gelu_module(x) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/rms_norm.py b/src/kernelgenbench/dataset/baseline/sglang/rms_norm.py new file mode 100644 index 0000000..50e3e33 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/rms_norm.py @@ -0,0 +1,38 @@ +""" +SGLang RMSNorm baseline. +""" +import torch +import typing +try: + import sgl_kernel.ops.norm as _ops +except ModuleNotFoundError: + _ops = None + + + +def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6, residual: typing.Optional[torch.Tensor] = None, out_residual: typing.Optional[torch.Tensor] = None) -> typing.Union[torch.Tensor, typing.Tuple[torch.Tensor, torch.Tensor]]: + """ + Args: + x: torch.Tensor + weight: torch.Tensor + eps: float + residual: typing.Optional[torch.Tensor] + out_residual: typing.Optional[torch.Tensor] + Returns: + typing.Union[torch.Tensor, typing.Tuple[torch.Tensor, torch.Tensor]] + """ + x = x.contiguous() + out = torch.empty_like(x) + residual_out = None + if residual is not None: + residual_c = residual.contiguous() + if out_residual is None: + out_residual = torch.empty_like(residual_c) + _ops.fused_add_rmsnorm(x, residual_c, weight, eps, out, out_residual) + residual_out = out_residual + return out, residual_out + _ops.rmsnorm(out, x, weight, eps) + if residual is None: + return out + return (out,) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/rms_norm_gated.py b/src/kernelgenbench/dataset/baseline/sglang/rms_norm_gated.py new file mode 100644 index 0000000..33aae3e --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/rms_norm_gated.py @@ -0,0 +1,28 @@ +""" +SGLang rms_norm_gated baseline (FLA fused norm+gate). +""" +import torch +import typing +try: + from sglang.srt.layers.attention.fla.layernorm_gated import rms_norm_gated as _rms_norm_gated_fn +except ModuleNotFoundError: + _rms_norm_gated_fn = None + + + +def rms_norm_gated(x: torch.Tensor, weight: torch.Tensor, bias: typing.Optional[torch.Tensor] = None, z: typing.Optional[torch.Tensor] = None, eps: float = 1e-6, group_size: typing.Optional[int] = None, norm_before_gate: bool = True, activation: str = 'swish') -> torch.Tensor: + """ + Args: + x: torch.Tensor + weight: torch.Tensor + bias: typing.Optional[torch.Tensor] + z: typing.Optional[torch.Tensor] + eps: float + group_size: typing.Optional[int] + norm_before_gate: bool + activation: str + Returns: + torch.Tensor + """ + return _rms_norm_gated_fn(x, weight, bias=bias, z=z, eps=eps, group_size=group_size, norm_before_gate=norm_before_gate, is_rms_norm=True, activation=activation) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/rms_norm_without_scale.py b/src/kernelgenbench/dataset/baseline/sglang/rms_norm_without_scale.py new file mode 100644 index 0000000..f7e6c71 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/rms_norm_without_scale.py @@ -0,0 +1,32 @@ +""" +SGLang RMSNormWithoutScale baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.layernorm import RMSNormWithoutScale as _RMSNormWithoutScale +except ModuleNotFoundError: + _RMSNormWithoutScale = None + +_rms_norm_without_scale_module = None +_current_rnws_config = None + +def _get_rms_norm_without_scale_module(hidden_size, eps): + global _rms_norm_without_scale_module, _current_rnws_config + key = (hidden_size, eps) + if _current_rnws_config != key: + _rms_norm_without_scale_module = _RMSNormWithoutScale(hidden_size=hidden_size, eps=eps).cuda() + _current_rnws_config = key + return _rms_norm_without_scale_module + +def rms_norm_without_scale(x: torch.Tensor, hidden_size: int, eps: float = 1e-6) -> torch.Tensor: + """ + Args: + x: torch.Tensor + hidden_size: int + eps: float + Returns: + torch.Tensor + """ + return _rms_norm_without_scale_module(x) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/rotary_embedding.py b/src/kernelgenbench/dataset/baseline/sglang/rotary_embedding.py new file mode 100644 index 0000000..882cd3c --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/rotary_embedding.py @@ -0,0 +1,46 @@ +""" +SGLang RotaryEmbedding baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.rotary_embedding import RotaryEmbedding as _RotaryEmbedding +except ModuleNotFoundError: + _RotaryEmbedding = None + +_rotary_embedding_module = None +_current_rope_config = None + +def _get_rotary_module(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype): + global _rotary_embedding_module, _current_rope_config + key = (head_size, rotary_dim, max_position_embeddings, base, is_neox_style, str(dtype)) + if _current_rope_config != key: + _rotary_embedding_module = _RotaryEmbedding( + head_size=head_size, rotary_dim=rotary_dim, + max_position_embeddings=max_position_embeddings, base=base, + is_neox_style=is_neox_style, dtype=dtype or torch.bfloat16 + ).cuda() + _current_rope_config = key + return _rotary_embedding_module + +def rotary_embedding(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, offsets: typing.Optional[torch.Tensor] = None, head_size: int, rotary_dim: int, max_position_embeddings: int = 8192, base: float = 10000.0, is_neox_style: bool = True, dtype: torch.dtype = None) -> typing.Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + positions: torch.Tensor + query: torch.Tensor + key: torch.Tensor + offsets: typing.Optional[torch.Tensor] + head_size: int + rotary_dim: int + max_position_embeddings: int + base: float + is_neox_style: bool + dtype: torch.dtype + Returns: + typing.Tuple[torch.Tensor, torch.Tensor] + """ + q = query.clone() + k = key.clone() + _rotary_embedding_module(q, k, positions, offsets) + return q, k + diff --git a/src/kernelgenbench/dataset/baseline/sglang/selective_scan_update.py b/src/kernelgenbench/dataset/baseline/sglang/selective_scan_update.py new file mode 100644 index 0000000..1b49776 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/selective_scan_update.py @@ -0,0 +1,30 @@ +""" +SGLang selective_scan_update baseline (Mamba SSM). +""" +import torch +import typing +try: + from sglang.srt.layers.attention.mamba.ops.mamba_ssm import selective_scan_update as _selective_scan_update +except ModuleNotFoundError: + _selective_scan_update = None + + + +def selective_scan_update(state: torch.Tensor, x: torch.Tensor, dt: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, D: typing.Optional[torch.Tensor] = None, z: typing.Optional[torch.Tensor] = None, dt_bias: typing.Optional[torch.Tensor] = None, state_batch_indices: typing.Optional[torch.Tensor] = None) -> typing.Tuple[torch.Tensor, typling.Optional[torch.Tensor]]: + """ + Args: + state: torch.Tensor + x: torch.Tensor + dt: torch.Tensor + A: torch.Tensor + B: torch.Tensor + C: torch.Tensor + D: typing.Optional[torch.Tensor] + z: typing.Optional[torch.Tensor] + dt_bias: typing.Optional[torch.Tensor] + state_batch_indices: typing.Optional[torch.Tensor] + Returns: + typing.Tuple[torch.Tensor, typling.Optional[torch.Tensor]] + """ + return _selective_scan_update(state=state, x=x, dt=dt, A=A, B=B, C=C, D=D, z=z, dt_bias=dt_bias, state_batch_indices=state_batch_indices) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/silu_and_mul.py b/src/kernelgenbench/dataset/baseline/sglang/silu_and_mul.py new file mode 100644 index 0000000..a306a1b --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/silu_and_mul.py @@ -0,0 +1,24 @@ +""" +SGLang SiluAndMul activation baseline. +""" +import torch +import typing +try: + import sgl_kernel.ops.elementwise as _ops +except ModuleNotFoundError: + _ops = None + + + +def silu_and_mul(x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: torch.Tensor + Returns: + torch.Tensor + """ + x = x.contiguous() + out = torch.empty(x.shape[0], x.shape[-1] // 2, dtype=x.dtype, device=x.device) + _ops.silu_and_mul(x, out) + return out + diff --git a/src/kernelgenbench/dataset/baseline/sglang/silu_and_mul_triton.py b/src/kernelgenbench/dataset/baseline/sglang/silu_and_mul_triton.py new file mode 100644 index 0000000..7e8328d --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/silu_and_mul_triton.py @@ -0,0 +1,24 @@ +""" +SGLang silu_and_mul_triton baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.elementwise import silu_and_mul_triton as _silu_and_mul_t +except ModuleNotFoundError: + _silu_and_mul_t = None + + + +def silu_and_mul_triton(hidden_states: torch.Tensor, scales: typing.Optional[torch.Tensor] = None, quantize: typing.Optional[str] = None, out: typing.Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Args: + hidden_states: torch.Tensor + scales: typing.Optional[torch.Tensor] + quantize: typing.Optional[str] + out: typing.Optional[torch.Tensor] + Returns: + torch.Tensor + """ + return _silu_and_mul_t(hidden_states, scales=scales, quantize=quantize, out=out)[0] + diff --git a/src/kernelgenbench/dataset/baseline/sglang/softcap.py b/src/kernelgenbench/dataset/baseline/sglang/softcap.py new file mode 100644 index 0000000..4bfc291 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/softcap.py @@ -0,0 +1,31 @@ +""" +SGLang Softcap baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.elementwise import Softcap as _Softcap +except ModuleNotFoundError: + _Softcap = None + +_softcap_module = None +_current_sc_config = None + +def _get_softcap_module(softcap_val): + global _softcap_module, _current_sc_config + key = (softcap_val,) + if _current_sc_config != key: + _softcap_module = _Softcap(softcap_val).cuda() + _current_sc_config = key + return _softcap_module + +def softcap(x: torch.Tensor, softcap: float = 50.0) -> torch.Tensor: + """ + Args: + x: torch.Tensor + softcap: float + Returns: + torch.Tensor + """ + return _softcap_module(x) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/topk.py b/src/kernelgenbench/dataset/baseline/sglang/topk.py new file mode 100644 index 0000000..aa49881 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/topk.py @@ -0,0 +1,33 @@ +""" +SGLang TopK MoE router baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.moe.topk import TopK as _TopK +except ModuleNotFoundError: + _TopK = None + +_topk_module = None +_current_topk_config = None + +def _get_topk_module(topk, renormalize): + global _topk_module, _current_topk_config + key = (topk, renormalize) + if _current_topk_config != key: + _topk_module = _TopK(topk=topk, renormalize=renormalize).cuda() + _current_topk_config = key + return _topk_module + +def topk(hidden_states: torch.Tensor, router_logits: torch.Tensor, topk: int = 8, renormalize: bool = True) -> "typing.Any": + """ + Args: + hidden_states: torch.Tensor + router_logits: torch.Tensor + topk: int + renormalize: bool + Returns: + "typing.Any" + """ + return _topk_module(hidden_states, router_logits) + diff --git a/src/kernelgenbench/dataset/baseline/sglang/triton_ernie45_rope_fused.py b/src/kernelgenbench/dataset/baseline/sglang/triton_ernie45_rope_fused.py new file mode 100644 index 0000000..faf270b --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/triton_ernie45_rope_fused.py @@ -0,0 +1,30 @@ +""" +SGLang triton_ernie45_rope_fused_inplace baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.rotary_embedding.triton_kernels import triton_ernie45_rope_fused_inplace as _triton_ernie45_rope_fused +except ModuleNotFoundError: + _triton_ernie45_rope_fused = None + + + +def triton_ernie45_rope_fused(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, cos_sin_cache: torch.Tensor, head_size: int, rotary_dim: int, mrope_section: typing.Optional[typing.List[int]] = None) -> typing.Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + positions: torch.Tensor + query: torch.Tensor + key: torch.Tensor + cos_sin_cache: torch.Tensor + head_size: int + rotary_dim: int + mrope_section: typing.Optional[typing.List[int]] + Returns: + typing.Tuple[torch.Tensor, torch.Tensor] + """ + q = query.clone() + k = key.clone() + _triton_ernie45_rope_fused(q, k, cos_sin_cache, positions, mrope_section=[head_size]*3, head_size=head_size, rotary_dim=rotary_dim, is_neox_style=True) + return q, k + diff --git a/src/kernelgenbench/dataset/baseline/sglang/triton_mrope_fused.py b/src/kernelgenbench/dataset/baseline/sglang/triton_mrope_fused.py new file mode 100644 index 0000000..ae40d83 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/triton_mrope_fused.py @@ -0,0 +1,30 @@ +""" +SGLang triton_mrope_fused baseline (Triton fused multimodal RoPE). +""" +import torch +import typing +try: + from sglang.srt.layers.rotary_embedding.triton_kernels import triton_mrope_fused as _triton_mrope_fused +except ModuleNotFoundError: + _triton_mrope_fused = None + + + +def triton_mrope_fused(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, cos_sin_cache: torch.Tensor, head_size: int, rotary_dim: int, mrope_section: typing.Optional[typing.List[int]] = None) -> typing.Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + positions: torch.Tensor + query: torch.Tensor + key: torch.Tensor + cos_sin_cache: torch.Tensor + head_size: int + rotary_dim: int + mrope_section: typing.Optional[typing.List[int]] + Returns: + typing.Tuple[torch.Tensor, torch.Tensor] + """ + q = query.clone() + k = key.clone() + _triton_mrope_fused(q, k, cos_sin_cache, positions, mrope_section=[head_size]*3, head_size=head_size, rotary_dim=rotary_dim, mrope_interleaved=False, mrope_interleaved_glm=False, is_neox_style=True, axis_map=None) + return q, k + diff --git a/src/kernelgenbench/dataset/baseline/sglang/xielu.py b/src/kernelgenbench/dataset/baseline/sglang/xielu.py new file mode 100644 index 0000000..401eaa1 --- /dev/null +++ b/src/kernelgenbench/dataset/baseline/sglang/xielu.py @@ -0,0 +1,27 @@ +""" +SGLang XIELU activation baseline. +""" +import torch +import typing +try: + from sglang.srt.layers.activation import XIELU as _XIELU +except ModuleNotFoundError: + _XIELU = None + +_xielu_module = None + +def _get_xielu_module(): + global _xielu_module + if _xielu_module is None: + _xielu_module = _XIELU().cuda() + return _xielu_module + +def xielu(x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: torch.Tensor + Returns: + torch.Tensor + """ + return _xielu_module(x) +