feat: Support Chunked Linear Fusion for GRPO #2833
Signed-off-by: pengdurice <pengduhit@gmail.com>
74be50f to 5cc6a69
yuki-97 left a comment
@pengdurice thanks for adding this to GRPO as well, overall LGTM and left some minor comments.
    # (0% accuracy / gibberish) for several Qwen3 MoE configs, see vLLM issues
    # #34892, #37591, #37758. Force the reference Triton MoE backend instead.
    # (Set before the per-recipe env_vars loop so a yaml can still override.)
    env_vars["VLLM_USE_FLASHINFER_MOE_FP16"] = "0"
Can we move the two env_vars to examples/configs/recipes/llm/grpo-qwen3-30ba3b-2n8g-megatron_chunked_linear_ce_loss.yaml so that other configs won't be affected?
yeah, good catch, changed!
    if (
        config["megatron_cfg"].get("use_linear_ce_fusion_loss", False)
        and overlap_param_gather
    ):
        warnings.warn(
            "Disabling overlap_param_gather because linear CE fusion loss is enabled: "
            "the fused forward bypasses output_layer.forward() and is incompatible "
            "with the distributed-optimizer param-gather prefetch chain."
        )
        overlap_param_gather = False
can we just assert False when using use_linear_ce_fusion_loss + overlap_param_gather?
    ]
    # Linear CE fusion forces overlap_param_gather off (see
    # create_megatron_config), so there is no forward pre-hook to disable.
    and not config["megatron_cfg"].get("use_linear_ce_fusion_loss", False)
This check seems removable once this is done: https://github.com/NVIDIA-NeMo/RL/pull/2833/changes#r3421595432
    use_linear_ce_fusion = policy_config["megatron_cfg"]["enabled"] and policy_config[
        "megatron_cfg"
    ].get("use_linear_ce_fusion_loss", False)
nit
    - use_linear_ce_fusion = policy_config["megatron_cfg"]["enabled"] and policy_config[
    -     "megatron_cfg"
    - ].get("use_linear_ce_fusion_loss", False)
    + use_linear_ce_fusion = (
    +     policy_config["megatron_cfg"]["enabled"]
    +     and policy_config["megatron_cfg"]["use_linear_ce_fusion_loss"]
    + )
Thanks for this. Since the name needs to change, I didn't sign off and apply this suggestion directly, but thank you for pointing it out ;)
    # [batch, seq_len, vocab_size] logit tensor. Reduces peak memory and extends
    # the maximum trainable sequence length. Not compatible with context
    # parallelism, sequence packing, or top-k/top-p training-time filtering.
    use_linear_ce_fusion_loss: false
ce_fusion_loss feels a bit confusing here since GRPO is not using CE.
Is there a better name for it? The new name could also be applied to SFT and DPO for consistency.
Yes, I changed it to use_fused_linear_logprobs, which is more general. Does that make sense? DPO also uses logprobs in the same way as GRPO, so I changed it there as well. LMK if that makes sense ;-) Thank you for pointing this out!
What does this PR do ?
Support Linear CE Loss Fusion for GRPO.
On top of #2036 (SFT) and #2139 (DPO), this PR extends the chunked linear
cross-entropy fusion loss to GRPO (and the other policy-gradient variants that
share ClippedPGLossFn: PPO, REINFORCE/RLOO, GSPO, Dr.GRPO, etc.).
Optimizations
Chunked Linear Cross-Entropy Fusion Loss
During standard GRPO training the model materializes a full logit tensor of
shape [batch_size, seq_length, vocab_size] for the policy forward-backward
pass and for the previous-/reference-policy logprob computations. This can
cause out-of-memory (OOM) errors for long sequences or large vocabularies. The
chunked linear cross-entropy fusion loss avoids this by computing the
per-token log probabilities directly from the hidden states: it chunks the
sequence dimension, projects each chunk to logits on the fly, gathers the
realized-token log probabilities, and discards the logits before moving to the
next chunk.
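The chunking described above can be illustrated with a minimal pure-Python sketch. It is not the code from this PR: the function names (`project`, `full_logprobs`, `chunked_logprobs`) and the `chunk_size` parameter are hypothetical, plain lists stand in for GPU tensors, and `targets` is assumed to be already aligned to the next token at each position.

```python
import math


def log_softmax(row):
    # Numerically stable log-softmax over one row of logits.
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def project(hidden, weight):
    # hidden: [hidden_dim]; weight: [vocab_size][hidden_dim] -> logits [vocab_size]
    return [sum(h * w for h, w in zip(hidden, w_row)) for w_row in weight]


def full_logprobs(hidden_states, weight, targets):
    # Baseline: materialize every logit row ([seq_len][vocab_size]) before gathering.
    logits = [project(h, weight) for h in hidden_states]
    return [log_softmax(row)[t] for row, t in zip(logits, targets)]


def chunked_logprobs(hidden_states, weight, targets, chunk_size):
    # Sketch of the chunked path: only chunk_size logit rows exist at any time.
    out = []
    for start in range(0, len(hidden_states), chunk_size):
        chunk = hidden_states[start:start + chunk_size]
        chunk_targets = targets[start:start + chunk_size]
        chunk_logits = [project(h, weight) for h in chunk]  # projected on the fly
        out.extend(log_softmax(row)[t] for row, t in zip(chunk_logits, chunk_targets))
        # chunk_logits is rebound on the next iteration, so the old rows can be freed.
    return out


if __name__ == "__main__":
    hidden_states = [[0.1, -0.2], [0.5, 0.4], [-0.3, 0.7]]  # 3 tokens, hidden_dim 2
    weight = [[0.2, 0.1], [-0.5, 0.3], [0.7, -0.2]]         # vocab_size 3
    targets = [2, 0, 1]
    print(chunked_logprobs(hidden_states, weight, targets, chunk_size=2))
```

In this toy version the per-token results match the full-logit baseline because each token's log-softmax depends only on its own logit row; the saving is purely in how many rows are alive at once.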
Benefits:
tensor from GPU memory.
policy logprob computations (the worker get_logprobs path already honors the flag).
Key changes
- ClippedPGLossFn.__init__ now accepts use_linear_ce_fusion. The generic prepare_loss_input already treats the model output as precomputed next-token logprobs for any LossInputType.LOGPROB loss whose use_linear_ce_fusion attribute is set, so no per-loss branching was needed.
- nemo_rl/algorithms/grpo.py reads policy.megatron_cfg.use_linear_ce_fusion_loss, passes it into ClippedPGLossFn, and adds guardrails (see Constraints below).
- nemo_rl/models/megatron/setup.py: when fusion is enabled, force overlap_param_gather off (with a warning) and skip the forward-pre-hook-disable path. The fused forward runs the decoder but reads output_layer.weight directly instead of calling output_layer.forward(), which breaks Megatron's distributed-optimizer param-gather prefetch chain (assert self.param_gather_handle is None in param_and_grad_buffer.py). This is most easily triggered on MoE models with expert parallelism.
- nemo_rl/distributed/model_utils.py: the patched GPT forward now accepts and forwards the newer Megatron-Core is_spec_decode kwarg (only when provided), keeping the monkey-patch in sync with the current GPTModel.forward signature while staying compatible with older ones.
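The "forward the kwarg only when provided" pattern from the last item can be sketched as follows. This is an illustration rather than the patched forward in nemo_rl/distributed/model_utils.py: `make_patched_forward`, `old_forward`, and `new_forward` are hypothetical stand-ins for the monkey-patch and for older and newer `GPTModel.forward` signatures.

```python
def make_patched_forward(original_forward):
    # Wrap a forward function so that is_spec_decode is passed through only
    # when the caller actually supplied it.
    def patched_forward(input_ids, is_spec_decode=None, **kwargs):
        if is_spec_decode is not None:
            # Newer callers pass the flag explicitly; forward it unchanged.
            kwargs["is_spec_decode"] = is_spec_decode
        # Older callers never pass it, so older signatures never see the kwarg.
        return original_forward(input_ids, **kwargs)

    return patched_forward


def old_forward(input_ids, labels=None):
    # Stand-in for a forward signature that predates is_spec_decode.
    return ("old", input_ids, labels)


def new_forward(input_ids, labels=None, is_spec_decode=False):
    # Stand-in for a forward signature that accepts is_spec_decode.
    return ("new", input_ids, labels, is_spec_decode)


if __name__ == "__main__":
    print(make_patched_forward(old_forward)([1, 2, 3]))
    print(make_patched_forward(new_forward)([1, 2, 3], is_spec_decode=True))
```

The design choice is that the wrapper never injects a default value for the new kwarg, so an older underlying signature only fails if a caller explicitly asks for a feature it does not support.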
Constraints
Enforced in nemo_rl/algorithms/grpo.py / setup.py:
policy.megatron_cfg.enabled: true).
whole sequence and would mix tokens across packed boundaries).
top_k: null, top_p: 1.0), since the fused path gathers logprobs from unfiltered logits.
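A guardrail of this kind could look roughly like the sketch below. It is an assumption-laden illustration, not the checks added in this PR: the helper name `find_fusion_conflicts` is hypothetical, and so are the config locations used for sequence packing (`sequence_packing.enabled`) and for sampling (`generation.top_k`, `generation.top_p`). Only `megatron_cfg.enabled`, `use_linear_ce_fusion_loss`, `top_k`, and `top_p` are names taken from the PR text.

```python
def find_fusion_conflicts(policy_config):
    # Return a list of human-readable conflicts; an empty list means the
    # fusion flag is either off or compatible with the rest of the config.
    megatron_cfg = policy_config.get("megatron_cfg", {})
    if not megatron_cfg.get("use_linear_ce_fusion_loss", False):
        return []

    conflicts = []
    if not megatron_cfg.get("enabled", False):
        conflicts.append("fusion requires megatron_cfg.enabled: true")

    # Hypothetical key: where sequence packing is toggled is an assumption.
    if policy_config.get("sequence_packing", {}).get("enabled", False):
        conflicts.append("fusion is not compatible with sequence packing")

    # Hypothetical key: the location of top_k/top_p is an assumption.
    generation = policy_config.get("generation", {})
    if generation.get("top_k") is not None or generation.get("top_p", 1.0) != 1.0:
        conflicts.append("fusion requires top_k: null and top_p: 1.0")

    return conflicts


if __name__ == "__main__":
    config = {
        "megatron_cfg": {"enabled": True, "use_linear_ce_fusion_loss": True},
        "generation": {"top_k": 50, "top_p": 1.0},
    }
    for problem in find_fusion_conflicts(config):
        print(problem)
```

Collecting all conflicts before failing lets a caller report every incompatible setting at once, whether it then raises an error or only warns.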
Issues
NA
Tests
Unit test: test_megatron_grpo_linear_ce_fusion_agreement in tests/unit/models/policy/test_megatron_worker.py compares the standard ClippedPGLossFn loss against the fusion-enabled loss on a tiny Qwen2 model (rtol/atol = 1e-2), mirroring the SFT/DPO agreement tests.
Nightly functional test: added tests/test_suites/llm/grpo-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.sh examples/configs/recipes/llm/ and registered it in tests/test_suites/nightly.txt.
Large-scale A/B convergence check: GRPO on Qwen3-30B-A3B (MoE,
Megatron TP=2 × EP=8 on 2×8 H200, vLLM TP=4 colocated, seq 4096), 50 steps,
fusion ON vs OFF, two seeds each. Configs are identical except the fusion
flag.
deterministic from the seed and the optimizer has not yet diverged),
confirming the fused logprobs/gradients match the standard path.
Generation KL Error) is ~0.002 and step-for-step identical between fusion and baseline.
fusion runs differ by more than fusion-vs-baseline does), and the four
reward/loss trajectories interleave with no systematic gap. No
instability/NaN; all runs completed 50/50 steps.
Usage
Enable megatron_cfg and add the two flags below to your GRPO config:
Before your PR is "Ready for review"
Pre checks:
Additional Information
docs/guides/grpo.md.
and is algorithm-agnostic; this PR is mostly wiring + guardrails + the
distributed-optimizer overlap_param_gather fix surfaced by MoE + EP.