feat: Support Chunked Linear Fusion for GRPO #2833

Open
pengdurice wants to merge 9 commits into NVIDIA-NeMo:main from pengdurice:peng-linear-ce-loss-grpo-v1-clean

Conversation

@pengdurice

Contributor

What does this PR do?

Support Linear CE Loss Fusion for GRPO.

On top of #2036 (SFT) and #2139 (DPO), this PR extends the chunked linear
cross-entropy fusion loss to GRPO (and the other policy-gradient variants that
share ClippedPGLossFn: PPO, REINFORCE/RLOO, GSPO, Dr.GRPO, etc.).

Optimizations

Chunked Linear Cross-Entropy Fusion Loss

During standard GRPO training the model materializes a full logit tensor of
shape [batch_size, seq_length, vocab_size] for the policy forward-backward
pass and for the previous-/reference-policy logprob computations. This can
cause out-of-memory (OOM) errors for long sequences or large vocabularies. The
chunked linear cross-entropy fusion loss avoids this by computing the
per-token log probabilities directly from the hidden states: it chunks the
sequence dimension, projects each chunk to logits on the fly, gathers the
realized-token log probabilities, and discards the logits before moving to the
next chunk.
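The chunking idea described above can be sketched in plain Python. This is a minimal illustration, not the NeMo RL implementation: the function names are hypothetical, the "tensors" are nested lists, and the real path runs on GPU tensors with Megatron parallelism. It only shows that projecting one chunk at a time and gathering the realized-token log probabilities gives the same result as materializing all logits first.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax over one row of logits."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def project(hidden_vec, weight):
    """Project one hidden state to logits; weight has shape [vocab][hidden]."""
    return [sum(h * w for h, w in zip(hidden_vec, w_row)) for w_row in weight]


def full_token_logprobs(hidden, weight, targets):
    """Reference path: build all [seq][vocab] logits, then gather."""
    logits = [project(h, weight) for h in hidden]
    return [log_softmax(row)[tok] for row, tok in zip(logits, targets)]


def chunked_token_logprobs(hidden, weight, targets, chunk_size):
    """Chunked path: at most chunk_size rows of logits exist at any time."""
    out = []
    for start in range(0, len(hidden), chunk_size):
        stop = start + chunk_size
        # Project only this chunk of the sequence to logits.
        chunk_logits = [project(h, weight) for h in hidden[start:stop]]
        # Keep only the log probability of each realized token.
        for row, tok in zip(chunk_logits, targets[start:stop]):
            out.append(log_softmax(row)[tok])
        # Drop the chunk's logits before moving to the next chunk.
        del chunk_logits
    return out
```

Only the per-token log probabilities (one float per position) are kept across chunks, which is why peak memory scales with the chunk size rather than the sequence length.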

Benefits:

  • Extends the maximum trainable sequence length by eliminating the large logit
    tensor from GPU memory.
  • Applies to the training forward-backward pass and the previous/reference
    policy logprob computations (the worker get_logprobs path already honors the
    flag).
  • Produces numerically equivalent loss/reward dynamics to the standard path.
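As a rough illustration of the memory benefit, the arithmetic below compares the size of a full logit tensor with the size of one chunk's logits. The batch size, vocabulary size, and fp32 element size are assumptions chosen for the example (only the sequence length of 4096 and chunk size of 256 appear elsewhere in this description), and the estimate counts the logit tensor alone, ignoring activations saved for backward and softmax intermediates.

```python
def logits_bytes(num_tokens: int, vocab_size: int, bytes_per_element: int = 4) -> int:
    """Bytes needed to hold a [num_tokens, vocab_size] logit tensor."""
    return num_tokens * vocab_size * bytes_per_element


# Illustrative values; batch_size and vocab_size are assumptions.
batch_size, seq_length, vocab_size = 1, 4096, 150_000
chunk_size = 256

full = logits_bytes(batch_size * seq_length, vocab_size)
chunked = logits_bytes(batch_size * chunk_size, vocab_size)

print(f"full logits : {full / 2**30:.2f} GiB")
print(f"one chunk   : {chunked / 2**20:.1f} MiB")
print(f"reduction   : {full // chunked}x")
```

Under these assumptions the reduction factor is simply seq_length / chunk_size, which is the trade-off behind the chunk-size setting: smaller chunks hold fewer logits at once, at the cost of more projection calls.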

Key changes

  1. ClippedPGLossFn.__init__ now accepts use_linear_ce_fusion. The generic
    prepare_loss_input already treats the model output as precomputed
    next-token logprobs for any LossInputType.LOGPROB loss whose
    use_linear_ce_fusion attribute is set, so no per-loss branching was needed.
  2. nemo_rl/algorithms/grpo.py reads
    policy.megatron_cfg.use_linear_ce_fusion_loss, passes it into
    ClippedPGLossFn, and adds guardrails (see Constraints below).
  3. nemo_rl/models/megatron/setup.py: when fusion is enabled, force
    overlap_param_gather off (with a warning) and skip the
    forward-pre-hook-disable path. The fused forward runs the decoder but reads
    output_layer.weight directly instead of calling output_layer.forward(),
    which breaks Megatron's distributed-optimizer param-gather prefetch chain
    (assert self.param_gather_handle is None in param_and_grad_buffer.py).
    This is most easily triggered on MoE models with expert parallelism.
  4. nemo_rl/distributed/model_utils.py: the patched GPT forward now accepts and
    forwards the newer Megatron-Core is_spec_decode kwarg (only when provided),
    keeping the monkey-patch in sync with the current GPTModel.forward
    signature while staying compatible with older ones.
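The "forward the kwarg only when provided" pattern from item 4 can be sketched as follows. The wrapper and the two stand-in forward functions are hypothetical and stand in for the patched GPT forward; the point is only that omitting the keyword entirely keeps older signatures working.

```python
from typing import Any, Callable, Optional


def call_forward(
    forward_fn: Callable[..., Any],
    input_ids: Any,
    is_spec_decode: Optional[bool] = None,
) -> Any:
    """Call forward_fn, passing is_spec_decode only if the caller supplied it."""
    kwargs = {}
    if is_spec_decode is not None:
        # Newer signatures accept this kwarg; older ones never see it.
        kwargs["is_spec_decode"] = is_spec_decode
    return forward_fn(input_ids, **kwargs)


def old_forward(input_ids):
    """Stand-in for an older forward signature without the kwarg."""
    return ("old", input_ids)


def new_forward(input_ids, is_spec_decode=False):
    """Stand-in for a newer forward signature that accepts the kwarg."""
    return ("new", input_ids, is_spec_decode)
```

With this shape, a caller that never sets is_spec_decode works against both signatures, and a caller that does set it only needs the newer one.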

Constraints

Enforced in nemo_rl/algorithms/grpo.py / setup.py:

  • Megatron backend only (policy.megatron_cfg.enabled: true).
  • Context parallelism must be 1.
  • Sequence packing must be disabled (the fused forward rolls labels over the
    whole sequence and would mix tokens across packed boundaries).
  • Top-k/top-p training-time filtering must be off (top_k: null, top_p: 1.0),
    since the fused path gathers logprobs from unfiltered logits.
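The constraints above could be checked with a small validator along these lines. This is a sketch, not the guardrail code in grpo.py: the function name is hypothetical, the context-parallel key name (context_parallel_size) is an assumption, and the flag name follows this description.

```python
from typing import Any, Dict, List


def fusion_constraint_violations(policy_cfg: Dict[str, Any]) -> List[str]:
    """Return human-readable constraint violations; empty if fusion is off or valid."""
    megatron = policy_cfg.get("megatron_cfg", {})
    if not megatron.get("use_linear_ce_fusion_loss", False):
        return []  # fusion disabled, nothing to enforce

    problems = []
    if not megatron.get("enabled", False):
        problems.append("fusion requires the Megatron backend (megatron_cfg.enabled: true)")
    # NOTE: the key name below is an assumption made for this sketch.
    if megatron.get("context_parallel_size", 1) != 1:
        problems.append("fusion requires context parallelism of 1")
    if policy_cfg.get("sequence_packing", {}).get("enabled", False):
        problems.append("fusion requires sequence packing to be disabled")
    generation = policy_cfg.get("generation", {})
    if generation.get("top_k") is not None or generation.get("top_p", 1.0) != 1.0:
        problems.append("fusion requires top_k: null and top_p: 1.0")
    return problems


example_cfg = {
    "megatron_cfg": {"enabled": True, "use_linear_ce_fusion_loss": True},
    "sequence_packing": {"enabled": True},
    "generation": {"top_k": None, "top_p": 1.0},
}
for problem in fusion_constraint_violations(example_cfg):
    print(problem)
```

Collecting all violations before failing lets a user fix a config in one pass instead of hitting one error per run.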

Issues

NA

Tests

  1. Unit test: test_megatron_grpo_linear_ce_fusion_agreement in
    tests/unit/models/policy/test_megatron_worker.py compares the standard
    ClippedPGLossFn loss against the fusion-enabled loss on a tiny Qwen2 model
    (rtol/atol = 1e-2), mirroring the SFT/DPO agreement tests.

  2. Nightly functional test: added
    tests/test_suites/llm/grpo-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.sh
    and a recipe under examples/configs/recipes/llm/, and registered it in
    tests/test_suites/nightly.txt.
  3. Large-scale A/B convergence check: GRPO on Qwen3-30B-A3B (MoE,
    Megatron TP=2 × EP=8 on 2×8 H200, vLLM TP=4 colocated, seq 4096), 50 steps,
    fusion ON vs OFF, two seeds each. Configs are identical except the fusion
    flag.

    run   fusion   mean reward (50 steps)   gen-KL error
    A1    on       0.523                    ~0.002
    A2    on       0.524                    ~0.002
    B1    off      0.518                    ~0.002
    B2    off      0.518                    ~0.002

    • Step 1 loss/reward are identical across all runs (generation is
      deterministic from the seed and the optimizer has not yet diverged),
      confirming the fused logprobs/gradients match the standard path.
    • The train-vs-inference logprob error (Generation KL Error) is ~0.002 and
      step-for-step identical between fusion and baseline.
    • Aggregate reward differs by <1% (within run-to-run RL noise — the two
      fusion runs differ by more than fusion-vs-baseline does), and the four
      reward/loss trajectories interleave with no systematic gap. No
      instability/NaN; all runs completed 50/50 steps.

Usage

Enable megatron_cfg and add the two flags below to your GRPO config:

policy:
  megatron_cfg:
    enabled: true
    use_linear_ce_fusion_loss: true
    linear_ce_fusion_chunk_size: 256  # tokens per chunk; smaller = less memory, larger = more throughput
  # required with fusion:
  sequence_packing:
    enabled: false
  generation:
    top_k: null
    top_p: 1.0

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally?
  • Did you add or update any necessary documentation? (docs/guides/grpo.md)

Additional Information

  • Docs: added a "Chunked Linear Cross-Entropy Fusion Loss" subsection to
    docs/guides/grpo.md.
  • The bulk of the fusion machinery is shared infrastructure introduced in #2036
    ("feat: Add chunked linear ce loss function from hidden states")
    and is algorithm-agnostic; this PR is mostly wiring + guardrails + the
    distributed-optimizer overlap_param_gather fix surfaced by MoE + EP.

@copy-pr-bot

copy-pr-bot Bot commented Jun 16, 2026


This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@github-actions github-actions Bot added the Documentation and community-request labels Jun 16, 2026
Signed-off-by: pengdurice <pengduhit@gmail.com>
Signed-off-by: pengdurice <pengduhit@gmail.com>
Signed-off-by: pengdurice <pengduhit@gmail.com>
Signed-off-by: pengdurice <pengduhit@gmail.com>
Signed-off-by: pengdurice <pengduhit@gmail.com>
Signed-off-by: pengdurice <pengduhit@gmail.com>
@pengdurice pengdurice force-pushed the peng-linear-ce-loss-grpo-v1-clean branch from 74be50f to 5cc6a69 Compare June 16, 2026 04:35
@pengdurice pengdurice changed the title Support Linear CE Loss Fusion for GRPO feat: Support Linear CE Loss Fusion for GRPO Jun 16, 2026
@pengdurice pengdurice marked this pull request as ready for review June 16, 2026 04:35
@pengdurice pengdurice requested review from a team as code owners June 16, 2026 04:35

@yuki-97 yuki-97 left a comment

Contributor

@pengdurice thanks for adding this to GRPO as well. Overall LGTM; I left some minor comments.

# (0% accuracy / gibberish) for several Qwen3 MoE configs, see vLLM issues
# #34892, #37591, #37758. Force the reference Triton MoE backend instead.
# (Set before the per-recipe env_vars loop so a yaml can still override.)
env_vars["VLLM_USE_FLASHINFER_MOE_FP16"] = "0"

Contributor

Can we move the two env_vars to examples/configs/recipes/llm/grpo-qwen3-30ba3b-2n8g-megatron_chunked_linear_ce_loss.yaml so that other configs won't be affected?

Contributor Author

yeah, good catch, changed!

Comment thread nemo_rl/models/megatron/setup.py Outdated
Comment on lines +907 to +916
if (
    config["megatron_cfg"].get("use_linear_ce_fusion_loss", False)
    and overlap_param_gather
):
    warnings.warn(
        "Disabling overlap_param_gather because linear CE fusion loss is enabled: "
        "the fused forward bypasses output_layer.forward() and is incompatible "
        "with the distributed-optimizer param-gather prefetch chain."
    )
    overlap_param_gather = False

Contributor

Can we just assert False when use_linear_ce_fusion_loss and overlap_param_gather are both enabled?

Comment thread nemo_rl/models/megatron/setup.py Outdated
]
# Linear CE fusion forces overlap_param_gather off (see
# create_megatron_config), so there is no forward pre-hook to disable.
and not config["megatron_cfg"].get("use_linear_ce_fusion_loss", False)

Contributor

This seems removable once the change requested in https://github.com/NVIDIA-NeMo/RL/pull/2833/changes#r3421595432 is done.

Contributor Author

sg, removed!

Comment thread nemo_rl/algorithms/grpo.py Outdated
Comment on lines +419 to +421
use_linear_ce_fusion = policy_config["megatron_cfg"]["enabled"] and policy_config[
    "megatron_cfg"
].get("use_linear_ce_fusion_loss", False)

Contributor

nit

Suggested change
- use_linear_ce_fusion = policy_config["megatron_cfg"]["enabled"] and policy_config[
-     "megatron_cfg"
- ].get("use_linear_ce_fusion_loss", False)
+ use_linear_ce_fusion = (
+     policy_config["megatron_cfg"]["enabled"]
+     and policy_config["megatron_cfg"]["use_linear_ce_fusion_loss"]
+ )

Contributor Author

Thanks for this. Since the name needs to change, I didn't sign off and apply this suggestion directly, but thank you for pointing it out ;)

Comment thread examples/configs/grpo_math_1B.yaml Outdated
# [batch, seq_len, vocab_size] logit tensor. Reduces peak memory and extends
# the maximum trainable sequence length. Not compatible with context
# parallelism, sequence packing, or top-k/top-p training-time filtering.
use_linear_ce_fusion_loss: false

@yuki-97 yuki-97 Jun 16, 2026

Contributor

ce_fusion_loss feels a bit confusing here since GRPO does not use CE.
Is there a better name for it? We could also apply the new name to SFT and DPO to keep them consistent.

Contributor Author

Yes, I changed it to use_fused_linear_logprobs, which is more general. Does that make sense? DPO also uses logprobs in a way similar to GRPO, so I changed the name there as well. LMK if that makes sense ;-) Thank you for pointing this out!

@svcnvidia-nemo-ci svcnvidia-nemo-ci added the waiting-on-customer label Jun 16, 2026
Signed-off-by: pengdurice <pengduhit@gmail.com>
@pengdurice pengdurice requested a review from a team as a code owner June 17, 2026 17:27
@pengdurice pengdurice changed the title feat: Support Linear CE Loss Fusion for GRPO feat: Support Chunked Linear Fusion for GRPO Jun 17, 2026
Signed-off-by: pengdurice <pengduhit@gmail.com>
Signed-off-by: pengdurice <pengduhit@gmail.com>
@svcnvidia-nemo-ci svcnvidia-nemo-ci added the waiting-on-maintainers label and removed the waiting-on-customer label Jun 18, 2026

Labels

community-request, Documentation, waiting-on-maintainers

3 participants