fix(gemma4_moe): re-tie lm_head to active embed_tokens on MoE path #2601
Achyuthan-S wants to merge 4 commits into
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Adds regression coverage and a small model fix to ensure Gemma4 MoE preserves Hugging Face-style weight tying between lm_head and the active embedding after the MoE backend swaps out the language model.
Changes:
- Re-tie lm_head.weight to the MoE backend’s embed_tokens.weight when tie_word_embeddings=True.
- Add CPU-only unit tests validating tied and untied behavior, including after initialize_weights() casting.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| tests/unit_tests/models/gemma4_moe/test_gemma4_moe_tied_weights.py | New unit tests to lock in correct tied/untied lm_head ↔ embedding behavior across initialization. |
| nemo_automodel/components/models/gemma4_moe/model.py | Restores lm_head weight tying after replacing the HF language model with the MoE backend. |
Comments suppressed due to low confidence (1)
nemo_automodel/components/models/gemma4_moe/model.py:859
buffer_device is a torch.device, but the code uses with buffer_device:, which is invalid (devices are not context managers). This is currently a runtime error on the MoE path when initialize_weights() reaches this block. Replace the context-manager usage with a valid device context (typically torch.cuda.device(index) when CUDA is available) or restructure the code to avoid needing a context manager at all and instead move tensors/ops explicitly to the target device.
else getattr(
if getattr(text_config, "tie_word_embeddings", False):
    self.lm_head.weight = self.model.language_model.embed_tokens.weight
This is intentional. HF's tie_weights() already ran in super().__init__() and tied lm_head to the original embed_tokens; the issue is that the MoE swap replaces language_model afterward, so re-running tie_weights() leans on the same get_input_embeddings() indirection this is working around (moe/parallelizer.py notes HF's tie_weights() can be incompatible with these custom models). The direct re-point mirrors llama's tie_weights(). Downstream save/load still goes through ensure_tied_lm_head(), which tries tie_weights() first and falls back to direct assignment, so the conventional path is covered.
I think the direct re-point is the right primitive here, but it would still be better to expose it through this class’s own tie_weights() override.
The reason is that AutoModel/checkpoint paths can call model.tie_weights() again after construction/load. We should make sure that call re-ties to the active MoE embedding, not whatever HF’s generic behavior would do. So the method can still use your exact direct assignment:
def tie_weights(self):
    text_config = self.config.text_config if hasattr(self.config, "text_config") else self.config
    if getattr(text_config, "tie_word_embeddings", False):
        self.lm_head.weight = self.model.language_model.embed_tokens.weight
Then after replacing language_model, call self.tie_weights().
This preserves the intentional direct assignment, but makes the public re-tie hook reliable for downstream load/checkpoint paths.
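To make the hook pattern concrete, here is a minimal sketch using plain-Python stand-ins instead of real PyTorch modules. The Toy* classes and the list-as-weight representation are assumptions for illustration only; object identity plays the role of shared tensor storage, and none of this is the actual Gemma4 code.

```python
from types import SimpleNamespace


class ToyEmbedding:
    """Stand-in for an embedding module; `weight` is just a list here."""

    def __init__(self, values):
        self.weight = list(values)


class ToyModel:
    """Hypothetical model whose backend is swapped after base construction."""

    def __init__(self, tie_word_embeddings=True):
        self.config = SimpleNamespace(tie_word_embeddings=tie_word_embeddings)
        # Base construction: original embedding plus a separate head.
        self.language_model = SimpleNamespace(embed_tokens=ToyEmbedding([0.0, 0.0]))
        self.lm_head = SimpleNamespace(weight=[9.0, 9.0])
        self.tie_weights()
        # Backend swap: the new embedding is a different object, so the tie
        # made above points at an orphan until tie_weights() runs again.
        self.language_model = SimpleNamespace(embed_tokens=ToyEmbedding([1.0, 2.0]))
        self.tie_weights()

    def tie_weights(self):
        # Direct re-point to whichever embedding is active right now.
        if getattr(self.config, "tie_word_embeddings", False):
            self.lm_head.weight = self.language_model.embed_tokens.weight


tied = ToyModel(tie_word_embeddings=True)
untied = ToyModel(tie_word_embeddings=False)

# Simulate a later step breaking the tie, then a caller invoking the public hook.
tied.lm_head.weight = [5.0, 5.0]
tied.tie_weights()
```

Because the re-point lives in tie_weights(), any later caller that invokes the hook gets the active embedding, and the untied configuration is left untouched.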
Done — pulled the direct re-point into tie_weights() and call it after the language_model swap. Added tests that break the tie and assert tie_weights() re-points to the active MoE embedding (and is a no-op when untied).
def test_tied_lm_head_survives_initialize_weights():
    """The tie set in __init__ must survive the bf16 cast in initialize_weights()."""
    model = _build(tie_word_embeddings=True)
    model.initialize_weights(dtype=torch.bfloat16, buffer_device=torch.device("cpu"))
torch.device has been usable as a context manager since PyTorch 2.0: with torch.device(...): sets the default device for factory calls, so with buffer_device: is valid, including on CPU. The test here passes buffer_device=torch.device("cpu") and runs through this exact MoE block; it's green on torch 2.10, so there's no TypeError on this path. initialize_weights() is also pre-existing and unchanged by this PR.
The MoE path replaces language_model after HF __init__, orphaning the lm_head<->embed_tokens tie that HF set up. Re-tie lm_head to the active embed_tokens when tie_word_embeddings is set (Gemma defaults True). Add CPU tied/untied tests.
Refs NVIDIA-NeMo#2512
Signed-off-by: Achyuthan-S <as21154@nyu.edu>
46282b9 to 4497260
Wrap the post-swap lm_head re-point in a tie_weights() override so AutoModel and ensure_tied_lm_head() re-tie to the active MoE embedding. Add hook tests for re-tie and untied no-op.
Signed-off-by: Achyuthan Sivasankar <achyuthan.sivasankar@gmail.com>
/ok to test b5c96b0
Hi @yuhezhang-ai, thank you very much for your support. Please feel free to assign me to interesting issues/features. I would love to contribute more!
Thank you, and thanks again for the quick fix here. The PR looks good to me.
Thanks @yuhezhang-ai, glad the Gemma4 piece looks good. I'll pick up the broader audit next: confirm HF defaults / checkpoint behavior per family, then land the reject-when-tie=True guard for the separate-head models (as we scoped on #2512). I'll post the per-family findings on the issue before opening a broad PR.
/claude review
/ok to test b5c96b0
    dtype=get_dtype(getattr(text_config, "torch_dtype", None), torch.bfloat16),
)

def tie_weights(self, *_args: object, **_kwargs: object) -> None:
Nit: every other model that implements tie_weights() also declares _tied_weights_keys (e.g. Llama, Qwen2, Mistral3, Baichuan all set _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}). The checkpoint utilities have fallbacks that cover this (including the hardcoded "model.language_model.embed_tokens.weight" candidate in _candidate_source_names()), so this isn't a correctness bug today — but adding the declaration would be consistent with the rest of the codebase and the onboarding checklist:
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
(as a class attribute on Gemma4ForConditionalGeneration). Pre-existing gap, so fine to address in a follow-up.
_tied_weights_keys is already inherited from HFGemma4ForConditionalGeneration as {'lm_head.weight': 'model.language_model.embed_tokens.weight'} — the mapping you suggested. Checkpoint utils read it via the dict branch in get_tied_lm_head_source_names, so the tied source resolves through the declaration rather than a fallback.
Llama/Qwen2 declare it explicitly because they're NeMo-native classes with no HF parent. Gemma4 subclasses the HF class that already sets it, so a re-declaration would be redundant. Happy to add an explicit class attribute for visibility if a maintainer prefers, but it's not needed for correctness.
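For readers less familiar with why a re-declaration is redundant, a small sketch of class-attribute inheritance follows. HFParent, CustomSubclass, and tied_source_for are hypothetical stand-ins, not the real HF class or NeMo's get_tied_lm_head_source_names; the lookup only mimics a dict-branch read of the declaration.

```python
class HFParent:
    """Stand-in for an upstream base class that declares the tied-key mapping."""

    _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}


class CustomSubclass(HFParent):
    """Subclass that does not redeclare the attribute."""


def tied_source_for(model_cls, target="lm_head.weight"):
    # Hypothetical lookup that handles only the dict form of the declaration.
    keys = getattr(model_cls, "_tied_weights_keys", None)
    if isinstance(keys, dict):
        return keys.get(target)
    return None


# The subclass body never sets the attribute, yet the lookup still resolves
# through normal attribute inheritance.
declared_on_subclass = "_tied_weights_keys" in vars(CustomSubclass)
resolved = tied_source_for(CustomSubclass)
```

An explicit class attribute on the subclass would change only where a reader finds the mapping, not what the lookup returns.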
"""
text_config = self.config.text_config if hasattr(self.config, "text_config") else self.config
if getattr(text_config, "tie_word_embeddings", False):
    self.lm_head.weight = self.model.language_model.embed_tokens.weight
One small correction after our discussion in the issue and checking Gemma4 HF behavior directly: for Gemma4ForConditionalGeneration, the controlling flag is the top-level Gemma4Config.tie_word_embeddings, not text_config.tie_word_embeddings.
I tested the conflicting cases under transformers 5.8.1:
- top-level True, text True -> tied
- top-level False, text False -> untied
- top-level False, text True -> untied
- top-level True, text False -> tied
So the current Gemma4 MoE fix is correct for the normal/default case because both flags are True, but the override should not check text_config first. Otherwise it can diverge from HF behavior when the two flags disagree.
I'd suggest changing the tie check to top-level-first, e.g.:
text_config = self.config.text_config if hasattr(self.config, "text_config") else self.config
if getattr(self.config, "tie_word_embeddings", getattr(text_config, "tie_word_embeddings", False)):
    self.lm_head.weight = self.model.language_model.embed_tokens.weight
And in the tests, set the top-level flag explicitly too:
config = Gemma4Config(
    text_config=_make_text_config(tie_word_embeddings=tie_word_embeddings),
    tie_word_embeddings=tie_word_embeddings,
)
That keeps the fix aligned with HF's actual tying behavior instead of baking in the text_config-first rule.
Oh I see, good catch. For the default case both flags are True, so the re-tie was correct, but you're right that it should follow the top-level Gemma4Config.tie_word_embeddings. Confirmed by construction on transformers 5.8.1 that all four combinations match your table. Switched the override to top-level-first (with a text_config fallback), set the top-level flag explicitly in the tests, and added a case with the flags disagreeing (top True/text False → tied, top False/text True → untied) to pin it. pytest tests/unit_tests/models/gemma4_moe/ -v → 7 passed.
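The top-level-first rule can be checked in isolation. The sketch below uses SimpleNamespace objects as hypothetical stand-ins for the composite config (no transformers import); should_tie and make_config are illustrative names, and the expected values are simply the four combinations from the table in this thread.

```python
from types import SimpleNamespace


def should_tie(config):
    """Top-level flag first; fall back to text_config only if it is absent."""
    text_config = getattr(config, "text_config", config)
    fallback = getattr(text_config, "tie_word_embeddings", False)
    return bool(getattr(config, "tie_word_embeddings", fallback))


def make_config(top, text):
    # Hypothetical stand-in for a composite config with a nested text config.
    return SimpleNamespace(
        tie_word_embeddings=top,
        text_config=SimpleNamespace(tie_word_embeddings=text),
    )


# Evaluate all four (top-level, text) combinations.
results = {
    (top, text): should_tie(make_config(top, text))
    for top in (True, False)
    for text in (True, False)
}

# A config with no top-level flag falls back to the nested one.
legacy = SimpleNamespace(text_config=SimpleNamespace(tie_word_embeddings=True))
legacy_result = should_tie(legacy)
```

The fallback only matters for configs that lack the top-level attribute entirely; whenever both flags exist, the nested value is ignored.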
Per review: HF Gemma4 ties on the top-level Gemma4Config.tie_word_embeddings regardless of the nested text_config (verified by construction on transformers 5.8.1). Read top-level first with a text_config fallback, set the top-level flag in tests, and add a case proving top-level wins when the two disagree.
Refs NVIDIA-NeMo#2512
Signed-off-by: Achyuthan Sivasankar <achyuthan.sivasankar@gmail.com>
/ok to test bed4c65
/claude review
Hi @yuhezhang-ai, I see you've already approved, thanks again. Just checking whether there's anything else needed from my side before this gets merged, or if we're all set from a review perspective.
Picking up the Gemma4 MoE part of #2512.
Problem
Gemma4ForConditionalGeneration lets HF's super().__init__() set up the model, which ties lm_head.weight to the text embed_tokens. The MoE path then replaces self.model.language_model with a fresh Gemma4MoETextModelBackend that has its own embed_tokens, so lm_head is left aliased to the old, orphaned embedding instead of the active one.
Gemma defaults to tie_word_embeddings=True, so this runs with an effectively untied head: lm_head and the live embedding drift apart during training, and the tied-head checkpoint guard from #2511 can't tell they're supposed to be tied (the two tensors don't share storage).
Fix
After the language_model swap, re-point lm_head.weight at the now-active embed_tokens.weight when tie_word_embeddings is set. The shared parameter survives the in-place bf16 cast in initialize_weights(). With the storage genuinely shared again, has_local_tied_lm_head() returns True, so the save path drops lm_head and the load path reconstructs it instead of letting a second copy diverge.
I left the state-dict adapter alone on purpose. Once the head is actually tied, loading a tied HF checkpoint (which only stores embed_tokens.weight) is already covered by the existing ensure_tied_lm_head() on the load path.
Tests
Added tests/unit_tests/models/gemma4_moe/test_gemma4_moe_tied_weights.py (CPU, reuses the tiny-config helpers from the existing rope test):
- lm_head / embed_tokens storage right after construction
- initialize_weights(bf16)
A full save → load → resume checkpoint test would be a reasonable follow-up. The per-family reject-guard work for the rest of the audit is tracked separately under #2512.
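As a rough illustration of why shared storage matters to the save and load paths, here is a toy sketch with plain dicts. The helpers is_tied, to_checkpoint, and from_checkpoint and the shortened key names are hypothetical, not the repository's has_local_tied_lm_head() or ensure_tied_lm_head(); object identity stands in for "the two tensors share storage".

```python
def is_tied(state, head_key="lm_head.weight", embed_key="embed_tokens.weight"):
    # Object identity stands in for shared tensor storage.
    return head_key in state and state[head_key] is state.get(embed_key)


def to_checkpoint(state):
    """Drop the head when it aliases the embedding, so only one copy is stored."""
    saved = dict(state)
    if is_tied(state):
        del saved["lm_head.weight"]
    return saved


def from_checkpoint(saved, tie_word_embeddings):
    """Rebuild the head as an alias of the embedding when it was dropped."""
    state = dict(saved)
    if tie_word_embeddings and "lm_head.weight" not in state:
        state["lm_head.weight"] = state["embed_tokens.weight"]
    return state


shared = [0.1, 0.2]
tied_state = {"embed_tokens.weight": shared, "lm_head.weight": shared}
# Equal values but distinct objects: the orphaned-tie situation.
orphaned_state = {"embed_tokens.weight": [0.1, 0.2], "lm_head.weight": [0.1, 0.2]}

tied_ckpt = to_checkpoint(tied_state)
orphaned_ckpt = to_checkpoint(orphaned_state)
restored = from_checkpoint(tied_ckpt, tie_word_embeddings=True)
```

In the orphaned case the identity check fails even though the values are equal, so a second copy of the head is written out, which is the divergence the fix is meant to prevent.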
cc @yuhezhang-ai