diff --git a/docs/fern/versions/nightly.yml b/docs/fern/versions/nightly.yml index 4a24ccc28f..2eb7fec83b 100644 --- a/docs/fern/versions/nightly.yml +++ b/docs/fern/versions/nightly.yml @@ -169,6 +169,8 @@ navigation: path: ../../model-coverage/llm/tencent/hy-mt2.mdx - page: "MiMo-V2-Flash" path: ../../model-coverage/llm/xiaomimimo/mimo-v2-flash.mdx + - page: "MiMo-V2.5-Pro" + path: ../../model-coverage/llm/xiaomimimo/mimo-v2-5-pro.mdx - page: "Ling 2.0" path: ../../model-coverage/llm/inclusionai/ling-2.mdx - section: "Vision Language Models" diff --git a/docs/model-coverage/llm/xiaomimimo/mimo-v2-5-pro.mdx b/docs/model-coverage/llm/xiaomimimo/mimo-v2-5-pro.mdx new file mode 100644 index 0000000000..867ca25ca9 --- /dev/null +++ b/docs/model-coverage/llm/xiaomimimo/mimo-v2-5-pro.mdx @@ -0,0 +1,93 @@ +--- +title: "MiMo-V2.5-Pro" +description: "" +--- +[MiMo-V2.5-Pro](https://huggingface.co/XiaomiMiMo/MiMo-V2.5-Pro) is Xiaomi's hybrid attention Mixture-of-Experts language model. It alternates full and sliding-window attention layers, uses a `sigmoid_with_bias` router with group-limited expert routing, and ships as an FP8 HF checkpoint. + + + +| | | +|---|---| +| **Task** | Text Generation (MoE, hybrid attention) | +| **Architecture** | `MiMoV2ForCausalLM` | +| **Parameters** | Approx. several hundred B total / much smaller active | +| **HF Org** | [XiaomiMiMo](https://huggingface.co/XiaomiMiMo) | + + + +## Available Models + +- **MiMo-V2.5-Pro**: hybrid full/sliding-window attention with FP8 weights. + +## Architecture + +- `MiMoV2ForCausalLM` +- Sliding-window attention using the `MiMoV2Attention(is_swa=True)` path. +- MoE blocks use `nemo_automodel.components.moe.layers.MoE` with `score_func="sigmoid_with_bias"` and `gate_precision=fp32`. +- FP8 round-trip in `MiMoV2StateDictAdapter` covers the bulk of attention/expert weights; layer norms, the gate, `lm_head`, and `embed_tokens` stay in bf16 per `NON_QUANTIZED_KEY_PATTERNS`. + +## Example HF Models + +| Model | HF ID | +|---|---| +| MiMo-V2.5-Pro | [`XiaomiMiMo/MiMo-V2.5-Pro`](https://huggingface.co/XiaomiMiMo/MiMo-V2.5-Pro) | + +## Example Recipes + +| Recipe | Description | +|---|---| +| [mimo_v25_pro_hellaswag.yaml](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml) | SFT: MiMo-V2.5-Pro on HellaSwag | + +## Try with NeMo AutoModel + +**1. Install** ([full instructions](/get-started/installation)): + +```bash +pip install nemo-automodel +``` + +**2. Clone the repo** to get the example recipes: + +```bash +git clone https://github.com/NVIDIA-NeMo/Automodel.git +cd Automodel +``` + +**3. Run the recipe** from inside the repo: + +```bash +automodel --nproc-per-node=8 examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml +``` + + +**1. Pull the container** and mount a checkpoint directory: + +```bash +docker run --gpus all -it --rm \ + --shm-size=8g \ + -v $(pwd)/checkpoints:/opt/Automodel/checkpoints \ + nvcr.io/nvidia/nemo-automodel:26.02.00 +``` + +**2. Navigate to the AutoModel directory**: + +```bash +cd /opt/Automodel +``` + +**3. Run the recipe**: + +```bash +automodel --nproc-per-node=8 examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml +``` + + +See the [Installation Guide](/get-started/installation) and [LLM Fine-Tuning Guide](/recipes-e2e-examples/sft-peft). + +## Fine-Tuning + +See the [LLM Fine-Tuning Guide](/recipes-e2e-examples/sft-peft). + +## Hugging Face Model Cards + +- [XiaomiMiMo/MiMo-V2.5-Pro](https://huggingface.co/XiaomiMiMo/MiMo-V2.5-Pro) diff --git a/examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml b/examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml new file mode 100644 index 0000000000..ac521dcdbe --- /dev/null +++ b/examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml @@ -0,0 +1,134 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# 16 H100 nodes (128 GPUs): +# torchrun --nproc-per-node 8 -m nemo_automodel.cli.app examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml + +recipe: TrainFinetuneRecipeForNextTokenPrediction + +seed: 1234 + +step_scheduler: + global_batch_size: 256 + local_batch_size: 8 + ckpt_every_steps: 25 + val_every_steps: 500 + num_epochs: 1 + max_steps: 100 + +distributed: + strategy: fsdp2 + tp_size: 1 + cp_size: 1 + pp_size: 4 + ep_size: 32 + + sequence_parallel: false + activation_checkpointing: true + + pipeline: + pp_schedule: interleaved1f1b + pp_microbatch_size: 1 + layers_per_stage: 2 + round_virtual_stages_to_pp_multiple: down + scale_grads_in_schedule: false + patch_inner_model: false + patch_causal_lm_model: false + + moe: + reshard_after_forward: false + wrap_outer_model: false + +dist_env: + backend: nccl + timeout_minutes: 30 + +model: + _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_config + config: + _target_: nemo_automodel.components.models.mimo_v25.config.MiMoV2Config.from_pretrained + pretrained_model_name_or_path: XiaomiMiMo/MiMo-V2.5-Pro + name_or_path: XiaomiMiMo/MiMo-V2.5-Pro + trust_remote_code: false + load_base_model: true + backend: + _target_: nemo_automodel.components.models.common.BackendConfig + attn: sdpa + linear: torch + rms_norm: torch_fp32 + rope_fusion: false + dispatcher: deepep + experts: torch_mm + gate_precision: float32 + enable_hf_state_dict_adapter: true + enable_fsdp_optimizations: true + +checkpoint: + enabled: true + checkpoint_dir: checkpoints/mimo_v25_pro + model_save_format: safetensors + save_consolidated: false + dequantize_base_checkpoint: true + +loss_fn: + _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy + +dataset: + _target_: nemo_automodel.components.datasets.llm.hellaswag.HellaSwag + path_or_dataset: rowan/hellaswag + split: train + tokenizer: + _target_: transformers.AutoTokenizer.from_pretrained + pretrained_model_name_or_path: XiaomiMiMo/MiMo-V2.5-Pro + trust_remote_code: true + +packed_sequence: + packed_sequence_size: 0 + +dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + collate_fn: + _target_: nemo_automodel.components.datasets.utils.default_collater + pad_seq_len_divisible: 64 + shuffle: true + +validation_dataset: + _target_: nemo_automodel.components.datasets.llm.hellaswag.HellaSwag + path_or_dataset: rowan/hellaswag + split: validation + num_samples_limit: 64 + tokenizer: + _target_: transformers.AutoTokenizer.from_pretrained + pretrained_model_name_or_path: XiaomiMiMo/MiMo-V2.5-Pro + trust_remote_code: true + +validation_dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + collate_fn: + _target_: nemo_automodel.components.datasets.utils.default_collater + pad_seq_len_divisible: 64 + shuffle: false + drop_last: true + +optimizer: + _target_: torch.optim.AdamW + betas: [0.9, 0.95] + eps: 1e-8 + lr: 1e-5 + weight_decay: 0.1 + +wandb: + project: automodel-mimo-v25-pro + name: mimo_v25_pro_hellaswag_16n + mode: online diff --git a/nemo_automodel/_transformers/registry.py b/nemo_automodel/_transformers/registry.py index 765b81611a..ea129cdae3 100644 --- a/nemo_automodel/_transformers/registry.py +++ b/nemo_automodel/_transformers/registry.py @@ -140,6 +140,10 @@ "MiMoV2FlashForCausalLM", ("nemo_automodel.components.models.mimo_v2_flash.model", "MiMoV2FlashForCausalLM"), ), + ( + "MiMoV2ForCausalLM", + ("nemo_automodel.components.models.mimo_v25.model", "MiMoV2ForCausalLM"), + ), ( "Ministral3ForCausalLM", ("nemo_automodel.components.models.mistral3.model", "Ministral3ForCausalLM"), @@ -282,6 +286,7 @@ "kimi_vl": ("nemo_automodel.components.models.kimivl.model", "KimiVLConfig"), "llavaonevision1_5": ("nemo_automodel.components.models.llava_onevision.model", "Llavaonevision1_5Config"), "mimo_v2_flash": ("nemo_automodel.components.models.mimo_v2_flash.config", "MiMoV2FlashConfig"), + "mimo_v2": ("nemo_automodel.components.models.mimo_v25.config", "MiMoV2Config"), "minimax_m3_vl": ("nemo_automodel.components.models.minimax_m3_vl.config", "MiniMaxM3VLConfig"), "mistral4": ("nemo_automodel.components.models.mistral4.configuration", "Mistral4Config"), "step3p5v": ("nemo_automodel.components.models.step3p7.configuration_step3p7", "Step3p5VConfig"), diff --git a/nemo_automodel/components/models/mimo_v25/__init__.py b/nemo_automodel/components/models/mimo_v25/__init__.py new file mode 100644 index 0000000000..01fca469cf --- /dev/null +++ b/nemo_automodel/components/models/mimo_v25/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_automodel.components.models.mimo_v25.config import MiMoV2Config +from nemo_automodel.components.models.mimo_v25.model import MiMoV2ForCausalLM, MiMoV2Model + +__all__ = ["MiMoV2Config", "MiMoV2ForCausalLM", "MiMoV2Model"] diff --git a/nemo_automodel/components/models/mimo_v25/config.py b/nemo_automodel/components/models/mimo_v25/config.py new file mode 100644 index 0000000000..a1445216d5 --- /dev/null +++ b/nemo_automodel/components/models/mimo_v25/config.py @@ -0,0 +1,174 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from transformers.configuration_utils import PretrainedConfig + +_MIMOV2_ATTENTION_PROJECTION_LAYOUTS = {"split", "fused_qkv"} + + +class MiMoV2Config(PretrainedConfig): + """Configuration for XiaomiMiMo/MiMo-V2.5-Pro.""" + + model_type = "mimo_v2" + keys_to_ignore_at_inference = ["past_key_values"] + + attribute_map = { + "num_local_experts": "n_routed_experts", + } + + def __init__( + self, + vocab_size: int = 151936, + hidden_size: int = 4096, + intermediate_size: int = 22016, + num_hidden_layers: int = 32, + num_attention_heads: int = 32, + num_key_value_heads: int = 32, + hidden_act: str = "silu", + max_position_embeddings: int = 32768, + initializer_range: float = 0.02, + layernorm_epsilon: float = 1e-6, + rms_norm_eps: float | None = None, + use_cache: bool = True, + tie_word_embeddings: bool = False, + rope_theta: float = 10000.0, + rope_scaling: dict | None = None, + attention_dropout: float = 0.0, + attention_bias: bool = False, + attention_value_scale: float | None = None, + head_dim: int | None = None, + v_head_dim: int | None = None, + swa_num_attention_heads: int | None = None, + swa_num_key_value_heads: int | None = None, + swa_head_dim: int | None = None, + swa_v_head_dim: int | None = None, + swa_rope_theta: float | None = None, + sliding_window: int | None = None, + sliding_window_size: int | None = None, + attention_chunk_size: int | None = None, + add_full_attention_sink_bias: bool = False, + add_swa_attention_sink_bias: bool = False, + hybrid_block_size: int | None = None, + hybrid_layer_pattern: list[int] | None = None, + partial_rotary_factor: float = 1.0, + n_routed_experts: int | None = None, + n_shared_experts: int | None = None, + moe_intermediate_size: int | None = None, + num_experts_per_tok: int | None = None, + routed_scaling_factor: float | None = None, + scoring_func: str = "sigmoid", + topk_method: str = "noaux_tc", + n_group: int | None = None, + topk_group: int | None = None, + norm_topk_prob: bool = True, + moe_layer_freq: list[int] | None = None, + attention_projection_layout: str = "split", + torch_dtype: str = "bfloat16", + **kwargs, + ): + rope_parameters = kwargs.pop("rope_parameters", None) + if rope_scaling is None and rope_parameters is not None: + rope_scaling = rope_parameters + + if attention_projection_layout is None: + attention_projection_layout = "split" + if attention_projection_layout not in _MIMOV2_ATTENTION_PROJECTION_LAYOUTS: + raise ValueError(f"Unsupported MiMoV2 attention projection layout: {attention_projection_layout}") + + self.attention_projection_layout = attention_projection_layout + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + if num_attention_heads % num_key_value_heads != 0: + raise ValueError("num_attention_heads must be divisible by num_key_value_heads") + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.layernorm_epsilon = layernorm_epsilon + self.rms_norm_eps = layernorm_epsilon if rms_norm_eps is None else rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + self.attention_value_scale = attention_value_scale + + self.head_dim = head_dim if head_dim is not None else hidden_size // num_attention_heads + self.v_head_dim = v_head_dim if v_head_dim is not None else self.head_dim + self.swa_num_attention_heads = ( + swa_num_attention_heads if swa_num_attention_heads is not None else num_attention_heads + ) + self.swa_num_key_value_heads = ( + swa_num_key_value_heads if swa_num_key_value_heads is not None else num_key_value_heads + ) + if self.swa_num_attention_heads % self.swa_num_key_value_heads != 0: + raise ValueError("swa_num_attention_heads must be divisible by swa_num_key_value_heads") + self.swa_head_dim = swa_head_dim if swa_head_dim is not None else self.head_dim + self.swa_v_head_dim = swa_v_head_dim if swa_v_head_dim is not None else self.swa_head_dim + self.swa_rope_theta = swa_rope_theta if swa_rope_theta is not None else rope_theta + + if sliding_window is None: + sliding_window = sliding_window_size + self.sliding_window = sliding_window + self.sliding_window_size = sliding_window_size if sliding_window_size is not None else sliding_window + self.attention_chunk_size = attention_chunk_size + self.add_full_attention_sink_bias = add_full_attention_sink_bias + self.add_swa_attention_sink_bias = add_swa_attention_sink_bias + + if hybrid_block_size is not None and hybrid_layer_pattern is None: + hybrid_layer_pattern = [0 if ((i + 1) % hybrid_block_size == 0) else 1 for i in range(num_hidden_layers)] + elif hybrid_layer_pattern is None: + hybrid_layer_pattern = [0] * num_hidden_layers + if len(hybrid_layer_pattern) != num_hidden_layers: + raise ValueError("hybrid_layer_pattern length must match num_hidden_layers") + self.hybrid_block_size = hybrid_block_size + self.hybrid_layer_pattern = hybrid_layer_pattern + + self.partial_rotary_factor = partial_rotary_factor + + self.n_routed_experts = n_routed_experts + self.n_shared_experts = n_shared_experts + self.moe_intermediate_size = moe_intermediate_size if moe_intermediate_size is not None else intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.routed_scaling_factor = routed_scaling_factor + self.scoring_func = scoring_func + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.norm_topk_prob = norm_topk_prob + if isinstance(moe_layer_freq, int): + moe_layer_freq = [moe_layer_freq > 0 and i % moe_layer_freq == 0 for i in range(num_hidden_layers)] + elif moe_layer_freq is None: + moe_layer_freq = [False] * num_hidden_layers + if len(moe_layer_freq) != num_hidden_layers: + raise ValueError("moe_layer_freq length must match num_hidden_layers") + self.moe_layer_freq = moe_layer_freq + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + # Assign after super().__init__() so our string value wins over any + # dtype conversion done by PretrainedConfig. + self.torch_dtype = torch_dtype diff --git a/nemo_automodel/components/models/mimo_v25/model.py b/nemo_automodel/components/models/mimo_v25/model.py new file mode 100644 index 0000000000..3c363c0e94 --- /dev/null +++ b/nemo_automodel/components/models/mimo_v25/model.py @@ -0,0 +1,740 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# Portions copyright 2026 Xiaomi Corporation. +# Portions copyright 2026 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from copy import copy +from dataclasses import dataclass +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update + +from nemo_automodel.components.models.common import ( + BackendConfig, + initialize_linear_module, + initialize_rms_norm_module, +) +from nemo_automodel.components.models.common.hf_checkpointing_mixin import HFCheckpointingMixin +from nemo_automodel.components.models.common.utils import ( + _has_dtensor_params, + cast_model_to_dtype, + compute_lm_head_logits, +) +from nemo_automodel.components.models.mimo_v25.config import MiMoV2Config +from nemo_automodel.components.moe.config import MoEConfig +from nemo_automodel.components.moe.fsdp_mixin import MoEFSDPSyncMixin +from nemo_automodel.components.moe.layers import MLP, MoE +from nemo_automodel.shared.utils import dtype_from_str as get_dtype + + +def _convert_bool_4d_mask_to_additive(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + if mask.ndim != 4 or mask.dtype != torch.bool: + return mask + additive = torch.zeros(mask.shape, dtype=dtype, device=mask.device) + return additive.masked_fill(~mask, torch.finfo(dtype).min) + + +def _fallback_additive_mask( + batch_size: int, + seq_len: int, + dtype: torch.dtype, + device: torch.device, + attention_mask: torch.Tensor | None = None, + sliding_window: int | None = None, +) -> torch.Tensor: + min_val = torch.finfo(dtype).min + idx = torch.arange(seq_len, device=device) + masked = idx.unsqueeze(0) > idx.unsqueeze(1) + if sliding_window is not None and sliding_window > 0: + masked = masked | ((idx.unsqueeze(1) - idx.unsqueeze(0)) >= sliding_window) + additive = torch.zeros((seq_len, seq_len), dtype=dtype, device=device).masked_fill(masked, min_val) + additive = additive.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, seq_len, seq_len).contiguous() + if attention_mask is not None and attention_mask.ndim == 2: + pad = (1.0 - attention_mask.to(dtype=dtype, device=device)).unsqueeze(1).unsqueeze(2) * min_val + additive = additive + pad + return additive + + +def _ensure_additive_mask( + mask: torch.Tensor | None, + *, + batch_size: int, + seq_len: int, + dtype: torch.dtype, + device: torch.device, + attention_mask: torch.Tensor | None, + sliding_window: int | None, +) -> torch.Tensor: + if mask is None or not isinstance(mask, torch.Tensor): + return _fallback_additive_mask(batch_size, seq_len, dtype, device, attention_mask, sliding_window) + return _convert_bool_4d_mask_to_additive(mask, dtype) + + +def _derive_padding_mask(attention_mask: torch.Tensor) -> torch.Tensor: + if attention_mask.ndim == 2: + return attention_mask == 0 + if attention_mask.ndim == 4: + diagonal = torch.diagonal(attention_mask[:, 0], dim1=-2, dim2=-1) + return diagonal.logical_not() if attention_mask.dtype == torch.bool else diagonal != 0 + return attention_mask.bool().logical_not() + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + unsqueeze_dim: int = 1, +) -> tuple[torch.Tensor, torch.Tensor]: + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + sinks: Optional[torch.Tensor] = None, + **kwargs: Any, +) -> tuple[torch.Tensor, torch.Tensor]: + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]] + + if sinks is not None: + sink_bias = module.attention_sink_bias.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) + attn_weights = torch.cat([attn_weights, sink_bias.to(attn_weights.dtype)], dim=-1) + + attn_weights = attn_weights - attn_weights.max(dim=-1, keepdim=True).values + probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if sinks is not None: + probs = probs[..., :-1] + + probs = F.dropout(probs, p=dropout, training=module.training) + attn_output = torch.matmul(probs, value_states) + return attn_output.transpose(1, 2).contiguous(), probs + + +class MiMoV2Attention(nn.Module): + """MiMoV2 hybrid attention (full or sliding-window).""" + + def __init__( + self, + config: MiMoV2Config, + is_swa: bool, + layer_idx: int, + projection_layout: str, + backend: BackendConfig, + dtype: torch.dtype, + ): + super().__init__() + if projection_layout not in {"split", "fused_qkv"}: + raise ValueError(f"Unsupported MiMoV2 attention projection layout: {projection_layout}") + + self.layer_idx = layer_idx + self.is_swa = is_swa + self.is_causal = True + self.projection_layout = projection_layout + + default_head_dim = config.hidden_size // config.num_attention_heads + default_v_head_dim = getattr(config, "v_head_dim", default_head_dim) + + if is_swa: + self.head_dim = getattr(config, "swa_head_dim", getattr(config, "head_dim", default_head_dim)) + self.v_head_dim = getattr(config, "swa_v_head_dim", default_v_head_dim) + self.num_attention_heads = getattr(config, "swa_num_attention_heads", config.num_attention_heads) + self.num_key_value_heads = getattr(config, "swa_num_key_value_heads", config.num_key_value_heads) + else: + self.head_dim = getattr(config, "head_dim", default_head_dim) + self.v_head_dim = getattr(config, "v_head_dim", self.head_dim) + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + + self.rope_dim = int(self.head_dim * getattr(config, "partial_rotary_factor", 1.0)) + if self.rope_dim % 2 != 0: + raise ValueError( + f"MiMoV2 rotary dimension must be even, got {self.rope_dim} from " + f"head_dim={self.head_dim} and partial_rotary_factor={getattr(config, 'partial_rotary_factor', 1.0)}" + ) + + self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads + self.attention_dropout = getattr(config, "attention_dropout", 0.0) + self.scaling = self.head_dim**-0.5 + self.sliding_window = getattr(config, "sliding_window", None) if is_swa else None + + self.q_size = self.num_attention_heads * self.head_dim + self.k_size = self.num_key_value_heads * self.head_dim + self.v_size = self.num_key_value_heads * self.v_head_dim + self.o_hidden_size = self.num_attention_heads * self.v_head_dim + self.v_scale = getattr(config, "attention_value_scale", None) + + self.attention_sink_bias = ( + nn.Parameter(torch.empty(self.num_attention_heads), requires_grad=False) + if ( + (getattr(config, "add_full_attention_sink_bias", False) and not is_swa) + or (getattr(config, "add_swa_attention_sink_bias", False) and is_swa) + ) + else None + ) + + attention_bias = getattr(config, "attention_bias", False) + if self.projection_layout == "fused_qkv": + self.qkv_proj = initialize_linear_module( + backend.linear, + config.hidden_size, + self.q_size + self.k_size + self.v_size, + bias=attention_bias, + dtype=dtype, + ) + else: + self.q_proj = initialize_linear_module( + backend.linear, config.hidden_size, self.q_size, bias=attention_bias, dtype=dtype + ) + self.k_proj = initialize_linear_module( + backend.linear, config.hidden_size, self.k_size, bias=attention_bias, dtype=dtype + ) + self.v_proj = initialize_linear_module( + backend.linear, config.hidden_size, self.v_size, bias=attention_bias, dtype=dtype + ) + self.o_proj = initialize_linear_module( + backend.linear, self.o_hidden_size, config.hidden_size, bias=False, dtype=dtype + ) + + def _forward_attention( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + input_shape: torch.Size, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + ) -> torch.Tensor: + if self.v_scale is not None: + value_states = value_states * self.v_scale + + cos, sin = position_embeddings + query_rope, query_nope = query_states.split([self.rope_dim, self.head_dim - self.rope_dim], dim=-1) + key_rope, key_nope = key_states.split([self.rope_dim, self.head_dim - self.rope_dim], dim=-1) + query_rope, key_rope = apply_rotary_pos_emb(query_rope, key_rope, cos, sin) + query_states = torch.cat([query_rope, query_nope], dim=-1) + key_states = torch.cat([key_rope, key_nope], dim=-1) + + attn_output, _ = eager_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + sinks=self.attention_sink_bias, + ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + return self.o_proj(attn_output) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + **kwargs: Any, + ) -> torch.Tensor: + del kwargs + input_shape = hidden_states.shape[:-1] + + if self.projection_layout == "fused_qkv": + qkv = self.qkv_proj(hidden_states) + query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(*input_shape, self.num_attention_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(*input_shape, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(*input_shape, self.num_key_value_heads, self.v_head_dim).transpose(1, 2) + return self._forward_attention( + query_states, key_states, value_states, input_shape, position_embeddings, attention_mask + ) + + +class MiMoV2RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor + + def __init__(self, config: MiMoV2Config, is_swa: bool, device: Optional[torch.device] = None): + super().__init__() + self.rope_type = ( + config.rope_scaling.get("rope_type", config.rope_scaling.get("type", "default")) + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict) + else "default" + ) + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = copy(config) + self.config.rope_parameters = copy(getattr(config, "rope_parameters", None) or {}) + if is_swa: + self.config.rope_theta = getattr(config, "swa_rope_theta", config.rope_theta) + self.config.head_dim = getattr(config, "swa_head_dim", getattr(config, "head_dim", None)) + if self.config.rope_parameters: + self.config.rope_parameters["rope_theta"] = self.config.rope_theta + + self.rope_init_fn = ( + self.compute_default_rope_parameters if self.rope_type == "default" else ROPE_INIT_FUNCTIONS[self.rope_type] + ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @staticmethod + def compute_default_rope_parameters( + config: MiMoV2Config, + device: Optional[torch.device] = None, + seq_len: Optional[int] = None, + layer_type: Optional[str] = None, + ) -> tuple[torch.Tensor, float]: + config.standardize_rope_params() + rope_parameters = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters + base = rope_parameters["rope_theta"] + partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + if dim % 2 != 0: + raise ValueError( + f"MiMoV2 rotary dimension must be even, got {dim} from " + f"head_dim={head_dim} and partial_rotary_factor={partial_rotary_factor}" + ) + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, 1.0 + + @torch.no_grad() + @dynamic_rope_update + def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class MiMoV2DecoderLayer(nn.Module): + def __init__( + self, + config: MiMoV2Config, + layer_idx: int, + moe_config: MoEConfig, + backend: BackendConfig, + ): + super().__init__() + dtype = get_dtype(config.torch_dtype, torch.bfloat16) + is_swa_layer = config.hybrid_layer_pattern[layer_idx] == 1 + self.attention_type = "sliding_attention" if is_swa_layer else "full_attention" + self.self_attn = MiMoV2Attention( + config, + is_swa=is_swa_layer, + layer_idx=layer_idx, + projection_layout=config.attention_projection_layout, + backend=backend, + dtype=dtype, + ) + is_moe_layer = getattr(config, "n_routed_experts", None) is not None and config.moe_layer_freq[layer_idx] + self.mlp = ( + MoE(moe_config, backend) + if is_moe_layer + else MLP(config.hidden_size, config.intermediate_size, backend.linear, dtype=dtype) + ) + self.input_layernorm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.layernorm_epsilon, dtype=dtype + ) + self.post_attention_layernorm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.layernorm_epsilon, dtype=dtype + ) + + def forward( + self, + hidden_states: torch.Tensor, + *, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class MiMoV2Model(nn.Module): + def __init__(self, config: MiMoV2Config, moe_config: MoEConfig, backend: BackendConfig): + super().__init__() + self.config = config + self.backend = backend + + if backend.gate_precision is None: + backend.gate_precision = torch.float32 + + dtype = get_dtype(config.torch_dtype, torch.bfloat16) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype) + self.layers = nn.ModuleDict( + {str(i): MiMoV2DecoderLayer(config, i, moe_config, backend) for i in range(config.num_hidden_layers)} + ) + self.norm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.layernorm_epsilon, dtype=dtype + ) + self.rotary_emb = MiMoV2RotaryEmbedding(config=config, is_swa=False) + self.swa_rotary_emb = MiMoV2RotaryEmbedding(config=config, is_swa=True) + self.has_sliding_layers = any(p == 1 for p in config.hybrid_layer_pattern) + + def _build_causal_mask_mapping( + self, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor | dict[str, torch.Tensor] | None, + position_ids: torch.Tensor, + cache_position: torch.Tensor, + ) -> dict[str, torch.Tensor]: + batch_size, seq_len = inputs_embeds.shape[:2] + dtype = inputs_embeds.dtype + device = inputs_embeds.device + + if isinstance(attention_mask, dict): + return { + "full_attention": _ensure_additive_mask( + attention_mask.get("full_attention"), + batch_size=batch_size, + seq_len=seq_len, + dtype=dtype, + device=device, + attention_mask=None, + sliding_window=None, + ), + "sliding_attention": _ensure_additive_mask( + attention_mask.get("sliding_attention", attention_mask.get("sliding_window_attention")), + batch_size=batch_size, + seq_len=seq_len, + dtype=dtype, + device=device, + attention_mask=None, + sliding_window=self.config.sliding_window, + ), + } + + mask_kwargs = dict( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=None, + position_ids=position_ids, + ) + pad_mask = attention_mask if isinstance(attention_mask, torch.Tensor) else None + return { + "full_attention": _ensure_additive_mask( + create_causal_mask(**mask_kwargs), + batch_size=batch_size, + seq_len=seq_len, + dtype=dtype, + device=device, + attention_mask=pad_mask, + sliding_window=None, + ), + "sliding_attention": _ensure_additive_mask( + create_sliding_window_causal_mask(**mask_kwargs) if self.has_sliding_layers else None, + batch_size=batch_size, + seq_len=seq_len, + dtype=dtype, + device=device, + attention_mask=pad_mask, + sliding_window=self.config.sliding_window, + ), + } + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + *, + inputs_embeds: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Any, + ) -> torch.Tensor: + del kwargs + if inputs_embeds is None: + if input_ids is None: + raise ValueError("input_ids or inputs_embeds must be provided") + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + if padding_mask is None and isinstance(attention_mask, torch.Tensor): + padding_mask = _derive_padding_mask(attention_mask) + + causal_mask_mapping = self._build_causal_mask_mapping( + inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + cache_position=cache_position, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + swa_position_embeddings = self.swa_rotary_emb(hidden_states, position_ids) + + for layer in self.layers.values(): + layer_position_embeddings = ( + swa_position_embeddings if layer.attention_type == "sliding_attention" else position_embeddings + ) + hidden_states = layer( + hidden_states, + attention_mask=causal_mask_mapping[layer.attention_type], + position_embeddings=layer_position_embeddings, + padding_mask=padding_mask, + ) + + return self.norm(hidden_states) + + @torch.no_grad() + def init_weights(self, buffer_device: Optional[torch.device] = None) -> None: + buffer_device = buffer_device or torch.device(f"cuda:{torch.cuda.current_device()}") + with buffer_device: + nn.init.normal_(self.embed_tokens.weight) + self.norm.reset_parameters() + for layer in self.layers.values(): + layer.init_weights(buffer_device) + + +class MiMoV2ForCausalLM(HFCheckpointingMixin, nn.Module, MoEFSDPSyncMixin): + """NeMo AutoModel causal LM wrapper for MiMo-V2.5-Pro.""" + + _keep_in_fp32_modules_strict = ["mlp.gate.e_score_correction_bias", "attention_sink_bias"] + _pp_keep_self_forward = True + _skip_init_weights_on_load = True + + @dataclass(frozen=True) + class ModelCapabilities: + """Declared parallelism capabilities for this model class.""" + + supports_tp: bool = False + supports_cp: bool = False + supports_pp: bool = True + supports_ep: bool = True + + @classmethod + def from_config( + cls, + config: MiMoV2Config, + moe_config: MoEConfig | None = None, + backend: BackendConfig | None = None, + **kwargs, + ) -> MiMoV2ForCausalLM: + return cls(config, moe_config, backend, **kwargs) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + *model_args, + **kwargs, + ) -> MiMoV2ForCausalLM: + config = MiMoV2Config.from_pretrained(pretrained_model_name_or_path) + return cls.from_config(config, *model_args, **kwargs) + + def __init__( + self, + config: MiMoV2Config, + moe_config: MoEConfig | None = None, + backend: BackendConfig | None = None, + **kwargs, + ): + super().__init__() + self.config = config + self.backend = backend or BackendConfig() + moe_overrides = kwargs.pop("moe_overrides", None) + + dtype = get_dtype(config.torch_dtype, torch.bfloat16) + moe_defaults = dict( + dim=config.hidden_size, + inter_dim=config.intermediate_size, + moe_inter_dim=config.moe_intermediate_size, + n_routed_experts=int(config.n_routed_experts or 0), + n_shared_experts=int(config.n_shared_experts or 0), + n_activated_experts=config.num_experts_per_tok, + n_expert_groups=config.n_group, + n_limited_groups=config.topk_group, + train_gate=True, + gate_bias_update_factor=0.0, + score_func="sigmoid_with_bias" if config.scoring_func == "sigmoid" else config.scoring_func, + route_scale=config.routed_scaling_factor if config.routed_scaling_factor is not None else 1.0, + aux_loss_coeff=0.0, + norm_topk_prob=config.norm_topk_prob, + router_bias=False, + expert_bias=False, + expert_activation="swiglu", + softmax_before_topk=False, + force_e_score_correction_bias=True, + dtype=dtype, + ) + if moe_overrides: + moe_defaults.update(moe_overrides) + resolved_moe_config = moe_config or MoEConfig(**moe_defaults) + + self.model = MiMoV2Model(config, resolved_moe_config, self.backend) + self.lm_head = initialize_linear_module( + self.backend.linear, + config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + ) + + if self.backend.enable_hf_state_dict_adapter: + from nemo_automodel.components.models.mimo_v25.state_dict_adapter import MiMoV2StateDictAdapter + + self.state_dict_adapter = MiMoV2StateDictAdapter( + self.config, + self.model.moe_config if hasattr(self.model, "moe_config") else resolved_moe_config, + self.backend, + dtype=dtype, + ) + + def get_input_embeddings(self) -> nn.Embedding: + return self.model.embed_tokens + + def set_input_embeddings(self, value: nn.Embedding) -> None: + self.model.embed_tokens = value + + def get_output_embeddings(self) -> nn.Linear: + return self.lm_head + + def set_output_embeddings(self, new_embeddings: nn.Linear) -> None: + self.lm_head = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + *, + inputs_embeds: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + output_hidden_states: Optional[bool] = None, + **kwargs: Any, + ) -> CausalLMOutputWithPast: + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else getattr(self.config, "output_hidden_states", False) + ) + hidden = self.model( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + attention_mask=attention_mask, + padding_mask=padding_mask, + **kwargs, + ) + return compute_lm_head_logits(self.lm_head, hidden, logits_to_keep, output_hidden_states=output_hidden_states) + + def customize_pipeline_stage_modules( + self, + module_names_per_stage: list[list[str]], + *, + layers_prefix: str, + text_model: Optional[nn.Module] = None, + ) -> list[list[str]]: + """Keep the SWA rotary embedding on every PP stage.""" + text_model = text_model or self.model + stage_modules = [list(modules) for modules in module_names_per_stage] + if getattr(text_model, "swa_rotary_emb", None) is not None: + fqn = f"{layers_prefix}swa_rotary_emb" + for modules in stage_modules: + if fqn not in modules: + modules.append(fqn) + return stage_modules + + @torch.no_grad() + def initialize_weights( + self, + buffer_device: Optional[torch.device] = None, + dtype: torch.dtype = torch.bfloat16, + ) -> None: + buffer_device = buffer_device or torch.device(f"cuda:{torch.cuda.current_device()}") + with buffer_device: + self.model.init_weights(buffer_device) + final_out_std = self.config.hidden_size**-0.5 + cutoff_factor = 3 + nn.init.trunc_normal_( + self.lm_head.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + if _has_dtensor_params(self): + return + cast_model_to_dtype(self, dtype) + + +ModelClass = MiMoV2ForCausalLM diff --git a/nemo_automodel/components/models/mimo_v25/state_dict_adapter.py b/nemo_automodel/components/models/mimo_v25/state_dict_adapter.py new file mode 100644 index 0000000000..8493404e8a --- /dev/null +++ b/nemo_automodel/components/models/mimo_v25/state_dict_adapter.py @@ -0,0 +1,160 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +import re +from typing import Any + +import torch +from torch.distributed.device_mesh import DeviceMesh + +from nemo_automodel.components.checkpoint.state_dict_adapter import StateDictAdapter +from nemo_automodel.components.models.common import BackendConfig +from nemo_automodel.components.models.deepseek_v3.state_dict_adapter import ( + create_scale_inv_for_weight, + dequantize_from_fp8, +) +from nemo_automodel.components.moe.config import MoEConfig +from nemo_automodel.components.moe.state_dict_mixin import MoESplitExpertsStateDictMixin + +logger = logging.getLogger(__name__) + +NON_QUANTIZED_KEY_PATTERNS = [ + "input_layernorm.weight", + "post_attention_layernorm.weight", + "norm.weight", + "lm_head.weight", + "embed_tokens.weight", + "mlp.gate.weight", + "self_attn.o_proj.weight", + "attention_sink_bias", +] + + +def _should_quantize_key(key: str) -> bool: + if not key.endswith(".weight"): + return False + return not any(pattern in key for pattern in NON_QUANTIZED_KEY_PATTERNS) + + +class MiMoV2StateDictAdapter(MoESplitExpertsStateDictMixin, StateDictAdapter): + """Convert MiMo-V2.5-Pro HF checkpoints to Automodel's grouped MoE layout. + + HF stores routed experts as split per-expert projections: + ``mlp.experts.{E}.{gate,up,down}_proj.weight``. Automodel groups those + into ``gate_and_up_projs`` and ``down_projs`` so EP can shard experts + without materialising every expert on every rank. + + MiMo-V2.5-Pro uses fused QKV (``self_attn.qkv_proj.weight``), so attention + projection keys require no renaming — they pass through unchanged. + """ + + def __init__( + self, + config: Any, + moe_config: MoEConfig, + backend: BackendConfig, + dtype: torch.dtype = torch.bfloat16, + ): + self.config = config + self.moe_config = moe_config + self.backend = backend + self.dtype = dtype + self._uses_model_prefix = True + + def from_hf( + self, + hf_state_dict: dict[str, Any], + device_mesh: DeviceMesh | None = None, + **kwargs, + ) -> dict[str, Any]: + del kwargs + for key in hf_state_dict.keys(): + if ".mlp.experts." in key and key.endswith(".weight"): + self._uses_model_prefix = key.startswith("model.") + break + hf_state_dict = self._dequantize(hf_state_dict) + return self._from_hf_w_merged_experts(hf_state_dict, device_mesh) + + def to_hf( + self, + state_dict: dict[str, Any], + exclude_key_regex: str | None = None, + quantization: bool = False, + **kwargs, + ) -> dict[str, Any]: + """Convert Automodel state_dict to the HF MiMo-V2.5-Pro layout. + + Note: The ``quantization`` parameter is accepted for interface + compatibility but is **ignored**. MiMo-V2.5-Pro is distributed as an + FP8 HF checkpoint, so this adapter always emits FP8 weights plus + ``_scale_inv`` companions for keys that match ``_should_quantize_key``, + regardless of the caller's preference. + """ + hf_state_dict: dict[str, Any] = {} + for fqn, tensor in state_dict.items(): + for key, value in self.convert_single_tensor_to_hf( + fqn, + tensor, + exclude_key_regex=exclude_key_regex, + quantization=quantization, + **kwargs, + ): + hf_state_dict[key] = value + return hf_state_dict + + def convert_single_tensor_to_hf(self, fqn: str, tensor: Any, **kwargs) -> list[tuple[str, Any]]: + exclude_key_regex = kwargs.get("exclude_key_regex", None) + + expert_result = self._convert_single_merged_expert_to_hf_split_experts(fqn, tensor, **kwargs) + result = expert_result if expert_result is not None else [(fqn, tensor)] + + if exclude_key_regex: + result = [(key, value) for key, value in result if not re.match(exclude_key_regex, key)] + + quantized_result: list[tuple[str, Any]] = [] + for key, value in result: + if _should_quantize_key(key): + quantized = value.to(dtype=torch.float8_e4m3fn) + quantized_result.append((key, quantized)) + quantized_result.append((key + "_scale_inv", create_scale_inv_for_weight(quantized))) + else: + quantized_result.append((key, value)) + return quantized_result + + def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]: + scale_inv_keys: list[str] = [] + dequantized_count = 0 + for key in list(state_dict.keys()): + if not key.endswith(".weight"): + continue + scale_key = key + "_scale_inv" + if scale_key not in state_dict: + continue + state_dict[key] = dequantize_from_fp8( + state_dict[key], + state_dict[scale_key], + dtype=self.dtype, + name=key, + ) + scale_inv_keys.append(scale_key) + dequantized_count += 1 + + for key in scale_inv_keys: + state_dict.pop(key, None) + + logger.debug("[MiMo V2.5-Pro FP8 Dequant] Dequantized %s weights", dequantized_count) + return state_dict diff --git a/pyproject.toml b/pyproject.toml index c7c187d182..bcc73d5785 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -406,6 +406,8 @@ convention = "google" "nemo_automodel/components/models/llama_bidirectional/export_onnx.py" = ["D103"] "nemo_automodel/components/models/llava_onevision/rice_vit.py" = ["D101", "D103"] "nemo_automodel/components/models/llava_onevision/state_dict_adapter.py" = ["D101"] +"nemo_automodel/components/models/mimo_v25/model.py" = ["D101", "D103"] +"nemo_automodel/components/models/mimo_v25/state_dict_adapter.py" = ["D101"] "nemo_automodel/components/models/minimax_m2/model.py" = ["D101"] "nemo_automodel/components/models/minimax_m2/state_dict_adapter.py" = ["D103"] "nemo_automodel/components/models/mistral3/model.py" = ["D101", "D103"]