diff --git a/docs/fern/versions/nightly.yml b/docs/fern/versions/nightly.yml
index 4a24ccc28f..2eb7fec83b 100644
--- a/docs/fern/versions/nightly.yml
+++ b/docs/fern/versions/nightly.yml
@@ -169,6 +169,8 @@ navigation:
path: ../../model-coverage/llm/tencent/hy-mt2.mdx
- page: "MiMo-V2-Flash"
path: ../../model-coverage/llm/xiaomimimo/mimo-v2-flash.mdx
+ - page: "MiMo-V2.5-Pro"
+ path: ../../model-coverage/llm/xiaomimimo/mimo-v2-5-pro.mdx
- page: "Ling 2.0"
path: ../../model-coverage/llm/inclusionai/ling-2.mdx
- section: "Vision Language Models"
diff --git a/docs/model-coverage/llm/xiaomimimo/mimo-v2-5-pro.mdx b/docs/model-coverage/llm/xiaomimimo/mimo-v2-5-pro.mdx
new file mode 100644
index 0000000000..867ca25ca9
--- /dev/null
+++ b/docs/model-coverage/llm/xiaomimimo/mimo-v2-5-pro.mdx
@@ -0,0 +1,93 @@
+---
+title: "MiMo-V2.5-Pro"
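+description: "Fine-tune Xiaomi's MiMo-V2.5-Pro hybrid-attention Mixture-of-Experts language model with NeMo AutoModel."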
+---
+[MiMo-V2.5-Pro](https://huggingface.co/XiaomiMiMo/MiMo-V2.5-Pro) is Xiaomi's hybrid attention Mixture-of-Experts language model. It alternates full and sliding-window attention layers, uses a `sigmoid_with_bias` router with group-limited expert routing, and ships as an FP8 HF checkpoint.
+
+
+
+| | |
+|---|---|
+| **Task** | Text Generation (MoE, hybrid attention) |
+| **Architecture** | `MiMoV2ForCausalLM` |
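+| **Parameters** | Mixture-of-Experts: hundreds of billions of total parameters, with a much smaller subset active per token (see the [model card](https://huggingface.co/XiaomiMiMo/MiMo-V2.5-Pro) for exact counts) |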
+| **HF Org** | [XiaomiMiMo](https://huggingface.co/XiaomiMiMo) |
+
+
+
+## Available Models
+
+- **MiMo-V2.5-Pro**: hybrid full/sliding-window attention with FP8 weights.
+
+## Architecture
+
+- `MiMoV2ForCausalLM`
+- Hybrid attention: full-attention and sliding-window layers both use `MiMoV2Attention`; sliding-window layers take the `MiMoV2Attention(is_swa=True)` path.
+- MoE blocks use `nemo_automodel.components.moe.layers.MoE` with `score_func="sigmoid_with_bias"` and `gate_precision=fp32`.
+- FP8 round-trip in `MiMoV2StateDictAdapter` covers the bulk of attention/expert weights; layer norms, the gate, `lm_head`, and `embed_tokens` stay in bf16 per `NON_QUANTIZED_KEY_PATTERNS`.
+
+## Example HF Models
+
+| Model | HF ID |
+|---|---|
+| MiMo-V2.5-Pro | [`XiaomiMiMo/MiMo-V2.5-Pro`](https://huggingface.co/XiaomiMiMo/MiMo-V2.5-Pro) |
+
+## Example Recipes
+
+| Recipe | Description |
+|---|---|
+| [mimo_v25_pro_hellaswag.yaml](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml) | SFT: MiMo-V2.5-Pro on HellaSwag |
+
+## Try with NeMo AutoModel
+
+**1. Install** ([full instructions](/get-started/installation)):
+
+```bash
+pip install nemo-automodel
+```
+
+**2. Clone the repo** to get the example recipes:
+
+```bash
+git clone https://github.com/NVIDIA-NeMo/Automodel.git
+cd Automodel
+```
+
+**3. Run the recipe** from inside the repo:
+
+```bash
+automodel --nproc-per-node=8 examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml
+```
+**Alternatively, run inside the NeMo AutoModel container:**
+
+**1. Pull the container** and mount a checkpoint directory:
+
+```bash
+docker run --gpus all -it --rm \
+ --shm-size=8g \
+ -v $(pwd)/checkpoints:/opt/Automodel/checkpoints \
+ nvcr.io/nvidia/nemo-automodel:26.02.00
+```
+
+**2. Navigate to the AutoModel directory**:
+
+```bash
+cd /opt/Automodel
+```
+
+**3. Run the recipe**:
+
+```bash
+automodel --nproc-per-node=8 examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml
+```
+
+
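+The example recipe is sized for 16 nodes × 8 GPUs (`pp_size: 4`, `ep_size: 32`); adjust the `distributed` section before running on a smaller cluster. See the [Installation Guide](/get-started/installation) and [LLM Fine-Tuning Guide](/recipes-e2e-examples/sft-peft).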
+
+## Fine-Tuning
+
+See the [LLM Fine-Tuning Guide](/recipes-e2e-examples/sft-peft).
+
+## Hugging Face Model Cards
+
+- [XiaomiMiMo/MiMo-V2.5-Pro](https://huggingface.co/XiaomiMiMo/MiMo-V2.5-Pro)
diff --git a/examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml b/examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml
new file mode 100644
index 0000000000..ac521dcdbe
--- /dev/null
+++ b/examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml
@@ -0,0 +1,134 @@
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
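+# Sized for 16 H100 nodes (128 GPUs). Run on every node, adding the multi-node flags your cluster needs (e.g. --nnodes 16 plus rendezvous settings):
+#   torchrun --nproc-per-node 8 -m nemo_automodel.cli.app examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml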
+
+recipe: TrainFinetuneRecipeForNextTokenPrediction
+
+seed: 1234
+
+step_scheduler:
+ global_batch_size: 256
+ local_batch_size: 8
+ ckpt_every_steps: 25
+ val_every_steps: 500
+ num_epochs: 1
+ max_steps: 100
+
+distributed:
+ strategy: fsdp2
+ tp_size: 1
+ cp_size: 1
+ pp_size: 4
+ ep_size: 32
+
+ sequence_parallel: false
+ activation_checkpointing: true
+
+ pipeline:
+ pp_schedule: interleaved1f1b
+ pp_microbatch_size: 1
+ layers_per_stage: 2
+ round_virtual_stages_to_pp_multiple: down
+ scale_grads_in_schedule: false
+ patch_inner_model: false
+ patch_causal_lm_model: false
+
+ moe:
+ reshard_after_forward: false
+ wrap_outer_model: false
+
+dist_env:
+ backend: nccl
+ timeout_minutes: 30
+
+model:
+ _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_config
+ config:
+ _target_: nemo_automodel.components.models.mimo_v25.config.MiMoV2Config.from_pretrained
+ pretrained_model_name_or_path: XiaomiMiMo/MiMo-V2.5-Pro
+ name_or_path: XiaomiMiMo/MiMo-V2.5-Pro
+ trust_remote_code: false
+ load_base_model: true
+ backend:
+ _target_: nemo_automodel.components.models.common.BackendConfig
+ attn: sdpa
+ linear: torch
+ rms_norm: torch_fp32
+ rope_fusion: false
+ dispatcher: deepep
+ experts: torch_mm
+ gate_precision: float32
+ enable_hf_state_dict_adapter: true
+ enable_fsdp_optimizations: true
+
+checkpoint:
+ enabled: true
+ checkpoint_dir: checkpoints/mimo_v25_pro
+ model_save_format: safetensors
+ save_consolidated: false
+ dequantize_base_checkpoint: true
+
+loss_fn:
+ _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy
+
+dataset:
+ _target_: nemo_automodel.components.datasets.llm.hellaswag.HellaSwag
+ path_or_dataset: rowan/hellaswag
+ split: train
+ tokenizer:
+ _target_: transformers.AutoTokenizer.from_pretrained
+ pretrained_model_name_or_path: XiaomiMiMo/MiMo-V2.5-Pro
+ trust_remote_code: true
+
+packed_sequence:
+ packed_sequence_size: 0
+
+dataloader:
+ _target_: torchdata.stateful_dataloader.StatefulDataLoader
+ collate_fn:
+ _target_: nemo_automodel.components.datasets.utils.default_collater
+ pad_seq_len_divisible: 64
+ shuffle: true
+
+validation_dataset:
+ _target_: nemo_automodel.components.datasets.llm.hellaswag.HellaSwag
+ path_or_dataset: rowan/hellaswag
+ split: validation
+ num_samples_limit: 64
+ tokenizer:
+ _target_: transformers.AutoTokenizer.from_pretrained
+ pretrained_model_name_or_path: XiaomiMiMo/MiMo-V2.5-Pro
+ trust_remote_code: true
+
+validation_dataloader:
+ _target_: torchdata.stateful_dataloader.StatefulDataLoader
+ collate_fn:
+ _target_: nemo_automodel.components.datasets.utils.default_collater
+ pad_seq_len_divisible: 64
+ shuffle: false
+ drop_last: true
+
+optimizer:
+  _target_: torch.optim.AdamW
+  betas: [0.9, 0.95]
+  eps: 1.0e-8  # explicit ".0": YAML 1.1 loaders (e.g. PyYAML) resolve a bare 1e-8 as a string
+  lr: 1.0e-5  # same reason as eps
+  weight_decay: 0.1
+
+wandb:
+ project: automodel-mimo-v25-pro
+ name: mimo_v25_pro_hellaswag_16n
+ mode: online
diff --git a/nemo_automodel/_transformers/registry.py b/nemo_automodel/_transformers/registry.py
index 765b81611a..ea129cdae3 100644
--- a/nemo_automodel/_transformers/registry.py
+++ b/nemo_automodel/_transformers/registry.py
@@ -140,6 +140,10 @@
"MiMoV2FlashForCausalLM",
("nemo_automodel.components.models.mimo_v2_flash.model", "MiMoV2FlashForCausalLM"),
),
+ (
+ "MiMoV2ForCausalLM",
+ ("nemo_automodel.components.models.mimo_v25.model", "MiMoV2ForCausalLM"),
+ ),
(
"Ministral3ForCausalLM",
("nemo_automodel.components.models.mistral3.model", "Ministral3ForCausalLM"),
@@ -282,6 +286,7 @@
"kimi_vl": ("nemo_automodel.components.models.kimivl.model", "KimiVLConfig"),
"llavaonevision1_5": ("nemo_automodel.components.models.llava_onevision.model", "Llavaonevision1_5Config"),
"mimo_v2_flash": ("nemo_automodel.components.models.mimo_v2_flash.config", "MiMoV2FlashConfig"),
+ "mimo_v2": ("nemo_automodel.components.models.mimo_v25.config", "MiMoV2Config"),
"minimax_m3_vl": ("nemo_automodel.components.models.minimax_m3_vl.config", "MiniMaxM3VLConfig"),
"mistral4": ("nemo_automodel.components.models.mistral4.configuration", "Mistral4Config"),
"step3p5v": ("nemo_automodel.components.models.step3p7.configuration_step3p7", "Step3p5VConfig"),
diff --git a/nemo_automodel/components/models/mimo_v25/__init__.py b/nemo_automodel/components/models/mimo_v25/__init__.py
new file mode 100644
index 0000000000..01fca469cf
--- /dev/null
+++ b/nemo_automodel/components/models/mimo_v25/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nemo_automodel.components.models.mimo_v25.config import MiMoV2Config
+from nemo_automodel.components.models.mimo_v25.model import MiMoV2ForCausalLM, MiMoV2Model
+
+__all__ = ["MiMoV2Config", "MiMoV2ForCausalLM", "MiMoV2Model"]
diff --git a/nemo_automodel/components/models/mimo_v25/config.py b/nemo_automodel/components/models/mimo_v25/config.py
new file mode 100644
index 0000000000..a1445216d5
--- /dev/null
+++ b/nemo_automodel/components/models/mimo_v25/config.py
@@ -0,0 +1,174 @@
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from transformers.configuration_utils import PretrainedConfig
+
+_MIMOV2_ATTENTION_PROJECTION_LAYOUTS = {"split", "fused_qkv"}
+
+
+class MiMoV2Config(PretrainedConfig):
+    """Configuration for XiaomiMiMo/MiMo-V2.5-Pro.
+
+    Constructor arguments are stored as attributes of the same name after the
+    following fallbacks are applied:
+
+    * ``rope_scaling`` falls back to a ``rope_parameters`` keyword argument.
+    * ``attention_projection_layout`` of ``None`` is treated as ``"split"``.
+    * ``num_key_value_heads`` falls back to ``num_attention_heads``.
+    * ``rms_norm_eps`` falls back to ``layernorm_epsilon``.
+    * ``head_dim`` falls back to ``hidden_size // num_attention_heads`` and
+      ``v_head_dim`` falls back to ``head_dim``.
+    * ``swa_num_attention_heads``, ``swa_num_key_value_heads``, ``swa_head_dim``
+      and ``swa_rope_theta`` fall back to their non-``swa`` counterparts;
+      ``swa_v_head_dim`` falls back to ``swa_head_dim``.
+    * ``sliding_window`` and ``sliding_window_size`` fill in for each other.
+    * ``hybrid_layer_pattern`` has one entry per layer. When only
+      ``hybrid_block_size`` is given, every ``hybrid_block_size``-th layer is 0
+      and all others are 1; when neither is given it is all zeros.
+    * ``moe_intermediate_size`` falls back to ``intermediate_size``.
+    * ``moe_layer_freq`` has one entry per layer. An int ``k`` is expanded to
+      per-layer booleans (true where ``k > 0`` and the layer index is a
+      multiple of ``k``); ``None`` becomes all ``False``.
+
+    Raises:
+        ValueError: if the projection layout is not ``"split"``/``"fused_qkv"``,
+            if attention heads are not divisible by key/value heads (regular or
+            ``swa``), or if ``hybrid_layer_pattern`` / ``moe_layer_freq`` do not
+            have ``num_hidden_layers`` entries.
+    """
+
+    model_type = "mimo_v2"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    # Lets callers read/write `num_local_experts` as an alias of `n_routed_experts`.
+    attribute_map = {
+        "num_local_experts": "n_routed_experts",
+    }
+
+    def __init__(
+        self,
+        vocab_size: int = 151936,
+        hidden_size: int = 4096,
+        intermediate_size: int = 22016,
+        num_hidden_layers: int = 32,
+        num_attention_heads: int = 32,
+        num_key_value_heads: int = 32,
+        hidden_act: str = "silu",
+        max_position_embeddings: int = 32768,
+        initializer_range: float = 0.02,
+        layernorm_epsilon: float = 1e-6,
+        rms_norm_eps: float | None = None,
+        use_cache: bool = True,
+        tie_word_embeddings: bool = False,
+        rope_theta: float = 10000.0,
+        rope_scaling: dict | None = None,
+        attention_dropout: float = 0.0,
+        attention_bias: bool = False,
+        attention_value_scale: float | None = None,
+        head_dim: int | None = None,
+        v_head_dim: int | None = None,
+        swa_num_attention_heads: int | None = None,
+        swa_num_key_value_heads: int | None = None,
+        swa_head_dim: int | None = None,
+        swa_v_head_dim: int | None = None,
+        swa_rope_theta: float | None = None,
+        sliding_window: int | None = None,
+        sliding_window_size: int | None = None,
+        attention_chunk_size: int | None = None,
+        add_full_attention_sink_bias: bool = False,
+        add_swa_attention_sink_bias: bool = False,
+        hybrid_block_size: int | None = None,
+        hybrid_layer_pattern: list[int] | None = None,
+        partial_rotary_factor: float = 1.0,
+        n_routed_experts: int | None = None,
+        n_shared_experts: int | None = None,
+        moe_intermediate_size: int | None = None,
+        num_experts_per_tok: int | None = None,
+        routed_scaling_factor: float | None = None,
+        scoring_func: str = "sigmoid",
+        topk_method: str = "noaux_tc",
+        n_group: int | None = None,
+        topk_group: int | None = None,
+        norm_topk_prob: bool = True,
+        moe_layer_freq: list[int] | None = None,
+        attention_projection_layout: str = "split",
+        torch_dtype: str = "bfloat16",
+        **kwargs,
+    ):
+        # `rope_parameters` is removed from kwargs here, so it is not forwarded
+        # to PretrainedConfig.__init__; it only serves as a rope_scaling fallback.
+        rope_parameters = kwargs.pop("rope_parameters", None)
+        if rope_scaling is None and rope_parameters is not None:
+            rope_scaling = rope_parameters
+
+        # An explicit None is accepted and means the default layout.
+        if attention_projection_layout is None:
+            attention_projection_layout = "split"
+        if attention_projection_layout not in _MIMOV2_ATTENTION_PROJECTION_LAYOUTS:
+            raise ValueError(f"Unsupported MiMoV2 attention projection layout: {attention_projection_layout}")
+
+        self.attention_projection_layout = attention_projection_layout
+
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+        if num_attention_heads % num_key_value_heads != 0:
+            raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        # Both epsilon names are kept; rms_norm_eps mirrors layernorm_epsilon unless given.
+        self.layernorm_epsilon = layernorm_epsilon
+        self.rms_norm_eps = layernorm_epsilon if rms_norm_eps is None else rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_dropout = attention_dropout
+        self.attention_bias = attention_bias
+        self.attention_value_scale = attention_value_scale
+
+        # Full-attention head geometry, then the sliding-window ("swa") variants,
+        # which default to the full-attention values.
+        self.head_dim = head_dim if head_dim is not None else hidden_size // num_attention_heads
+        self.v_head_dim = v_head_dim if v_head_dim is not None else self.head_dim
+        self.swa_num_attention_heads = (
+            swa_num_attention_heads if swa_num_attention_heads is not None else num_attention_heads
+        )
+        self.swa_num_key_value_heads = (
+            swa_num_key_value_heads if swa_num_key_value_heads is not None else num_key_value_heads
+        )
+        if self.swa_num_attention_heads % self.swa_num_key_value_heads != 0:
+            raise ValueError("swa_num_attention_heads must be divisible by swa_num_key_value_heads")
+        self.swa_head_dim = swa_head_dim if swa_head_dim is not None else self.head_dim
+        # Note: defaults to swa_head_dim, not to v_head_dim.
+        self.swa_v_head_dim = swa_v_head_dim if swa_v_head_dim is not None else self.swa_head_dim
+        self.swa_rope_theta = swa_rope_theta if swa_rope_theta is not None else rope_theta
+
+        # Either spelling of the window size may be supplied; both attributes end up set.
+        if sliding_window is None:
+            sliding_window = sliding_window_size
+        self.sliding_window = sliding_window
+        self.sliding_window_size = sliding_window_size if sliding_window_size is not None else sliding_window
+        self.attention_chunk_size = attention_chunk_size
+        self.add_full_attention_sink_bias = add_full_attention_sink_bias
+        self.add_swa_attention_sink_bias = add_swa_attention_sink_bias
+
+        # Per-layer pattern: an explicit list wins; otherwise derive it from
+        # hybrid_block_size (0 on every hybrid_block_size-th layer, 1 elsewhere),
+        # or fall back to all zeros.
+        if hybrid_block_size is not None and hybrid_layer_pattern is None:
+            hybrid_layer_pattern = [0 if ((i + 1) % hybrid_block_size == 0) else 1 for i in range(num_hidden_layers)]
+        elif hybrid_layer_pattern is None:
+            hybrid_layer_pattern = [0] * num_hidden_layers
+        if len(hybrid_layer_pattern) != num_hidden_layers:
+            raise ValueError("hybrid_layer_pattern length must match num_hidden_layers")
+        self.hybrid_block_size = hybrid_block_size
+        self.hybrid_layer_pattern = hybrid_layer_pattern
+
+        self.partial_rotary_factor = partial_rotary_factor
+
+        self.n_routed_experts = n_routed_experts
+        self.n_shared_experts = n_shared_experts
+        self.moe_intermediate_size = moe_intermediate_size if moe_intermediate_size is not None else intermediate_size
+        self.num_experts_per_tok = num_experts_per_tok
+        self.routed_scaling_factor = routed_scaling_factor
+        self.scoring_func = scoring_func
+        self.topk_method = topk_method
+        self.n_group = n_group
+        self.topk_group = topk_group
+        self.norm_topk_prob = norm_topk_prob
+        # Normalize moe_layer_freq to one entry per layer: an int is expanded to
+        # booleans, None means no layer is flagged, a list is kept as given.
+        if isinstance(moe_layer_freq, int):
+            moe_layer_freq = [moe_layer_freq > 0 and i % moe_layer_freq == 0 for i in range(num_hidden_layers)]
+        elif moe_layer_freq is None:
+            moe_layer_freq = [False] * num_hidden_layers
+        if len(moe_layer_freq) != num_hidden_layers:
+            raise ValueError("moe_layer_freq length must match num_hidden_layers")
+        self.moe_layer_freq = moe_layer_freq
+
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+        # Assign after super().__init__() so our string value wins over any
+        # dtype conversion done by PretrainedConfig.
+        self.torch_dtype = torch_dtype
diff --git a/nemo_automodel/components/models/mimo_v25/model.py b/nemo_automodel/components/models/mimo_v25/model.py
new file mode 100644
index 0000000000..3c363c0e94
--- /dev/null
+++ b/nemo_automodel/components/models/mimo_v25/model.py
@@ -0,0 +1,740 @@
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
+# Portions copyright 2026 Xiaomi Corporation.
+# Portions copyright 2026 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from copy import copy
+from dataclasses import dataclass
+from typing import Any, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+
+from nemo_automodel.components.models.common import (
+ BackendConfig,
+ initialize_linear_module,
+ initialize_rms_norm_module,
+)
+from nemo_automodel.components.models.common.hf_checkpointing_mixin import HFCheckpointingMixin
+from nemo_automodel.components.models.common.utils import (
+ _has_dtensor_params,
+ cast_model_to_dtype,
+ compute_lm_head_logits,
+)
+from nemo_automodel.components.models.mimo_v25.config import MiMoV2Config
+from nemo_automodel.components.moe.config import MoEConfig
+from nemo_automodel.components.moe.fsdp_mixin import MoEFSDPSyncMixin
+from nemo_automodel.components.moe.layers import MLP, MoE
+from nemo_automodel.shared.utils import dtype_from_str as get_dtype
+
+
+def _convert_bool_4d_mask_to_additive(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    """Turn a 4-D boolean mask into an additive one; return any other mask unchanged.
+
+    Positions where ``mask`` is True become 0 and positions where it is False
+    become the most negative finite value of ``dtype``.
+    """
+    is_bool_4d = mask.ndim == 4 and mask.dtype == torch.bool
+    if not is_bool_4d:
+        return mask
+    lowest = torch.finfo(dtype).min
+    zeros = torch.zeros(mask.shape, dtype=dtype, device=mask.device)
+    return zeros.masked_fill(mask.logical_not(), lowest)
+
+
+def _fallback_additive_mask(
+    batch_size: int,
+    seq_len: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    attention_mask: torch.Tensor | None = None,
+    sliding_window: int | None = None,
+) -> torch.Tensor:
+    """Build a ``(batch_size, 1, seq_len, seq_len)`` additive causal mask from scratch.
+
+    Allowed positions hold 0; future keys (and, when ``sliding_window`` is a
+    positive int, keys at distance ``>= sliding_window`` behind the query) hold
+    the most negative finite value of ``dtype``. A 2-D ``attention_mask`` adds
+    that same value on every key column where the mask is 0.
+    """
+    lowest = torch.finfo(dtype).min
+    positions = torch.arange(seq_len, device=device)
+    query_pos = positions.unsqueeze(1)
+    key_pos = positions.unsqueeze(0)
+
+    blocked = key_pos > query_pos
+    if sliding_window is not None and sliding_window > 0:
+        blocked = blocked | ((query_pos - key_pos) >= sliding_window)
+
+    base = torch.zeros((seq_len, seq_len), dtype=dtype, device=device).masked_fill(blocked, lowest)
+    base = base[None, None].expand(batch_size, 1, seq_len, seq_len).contiguous()
+    if attention_mask is None or attention_mask.ndim != 2:
+        return base
+
+    keep = attention_mask.to(dtype=dtype, device=device)
+    padding_penalty = (1.0 - keep)[:, None, None, :] * lowest
+    return base + padding_penalty
+
+
+def _ensure_additive_mask(
+    mask: torch.Tensor | None,
+    *,
+    batch_size: int,
+    seq_len: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    attention_mask: torch.Tensor | None,
+    sliding_window: int | None,
+) -> torch.Tensor:
+    """Return an additive attention mask whatever form ``mask`` arrives in.
+
+    A tensor ``mask`` goes through ``_convert_bool_4d_mask_to_additive``; anything
+    else (including ``None``) is replaced by ``_fallback_additive_mask`` built
+    from the remaining arguments.
+    """
+    if isinstance(mask, torch.Tensor):
+        return _convert_bool_4d_mask_to_additive(mask, dtype)
+    return _fallback_additive_mask(batch_size, seq_len, dtype, device, attention_mask, sliding_window)
+
+
+def _derive_padding_mask(attention_mask: torch.Tensor) -> torch.Tensor:
+    """Derive a boolean padding mask (True where padded) from an attention mask.
+
+    * 2-D input: True where the mask equals 0.
+    * 4-D input: read from the diagonal of head 0; a boolean mask is negated,
+      any other dtype is treated as padded where the diagonal is non-zero.
+    * Any other rank: the logical negation of the mask cast to bool.
+    """
+    rank = attention_mask.ndim
+    if rank == 2:
+        return attention_mask == 0
+    if rank != 4:
+        return attention_mask.bool().logical_not()
+
+    self_visibility = torch.diagonal(attention_mask[:, 0], dim1=-2, dim2=-1)
+    if attention_mask.dtype == torch.bool:
+        return self_visibility.logical_not()
+    return self_visibility != 0
+
+
+def rotate_half(x: torch.Tensor) -> torch.Tensor:
+    """Swap the two halves of the last dimension, negating the half that moves to the front."""
+    midpoint = x.shape[-1] // 2
+    front = x[..., :midpoint]
+    back = x[..., midpoint:]
+    return torch.cat((back.neg(), front), dim=-1)
+
+
+def apply_rotary_pos_emb(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    position_ids: Optional[torch.Tensor] = None,
+    unsqueeze_dim: int = 1,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Apply rotary position embeddings to ``q`` and ``k``.
+
+    ``cos`` and ``sin`` get an extra axis at ``unsqueeze_dim`` so they broadcast
+    against ``q``/``k``. ``position_ids`` is accepted but not used.
+
+    Returns:
+        The rotated ``(q, k)`` pair.
+    """
+    cos_b = cos.unsqueeze(unsqueeze_dim)
+    sin_b = sin.unsqueeze(unsqueeze_dim)
+    q_rotated = (q * cos_b) + (rotate_half(q) * sin_b)
+    k_rotated = (k * cos_b) + (rotate_half(k) * sin_b)
+    return q_rotated, k_rotated
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """Repeat each key/value head ``n_rep`` times along the head axis.
+
+    Takes a ``(batch, num_key_value_heads, seq_len, head_dim)`` tensor and
+    returns ``(batch, num_key_value_heads * n_rep, seq_len, head_dim)``, with the
+    copies of each head adjacent. With ``n_rep == 1`` the input is returned as is.
+    """
+    # Unpack first so a non-4-D input fails here even when n_rep == 1.
+    bsz, kv_heads, seq_len, dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
+    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
+
+
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    sinks: Optional[torch.Tensor] = None,
+    **kwargs: Any,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Plain matmul/softmax attention with optional per-head attention sinks.
+
+    Args:
+        module: Attention module; supplies ``num_key_value_groups`` (how often
+            each key/value head is repeated) and ``training`` (dropout mode).
+        query: Query states; multiplied against the transposed keys on the last
+            two axes.
+        key: Key states, ``(batch, kv_heads, seq_len, head_dim)`` as required by
+            ``repeat_kv``.
+        value: Value states, same layout as ``key``.
+        attention_mask: Optional additive mask; cropped on its last axis to the
+            key length before being added to the scores.
+        scaling: Factor applied to the raw query-key scores.
+        dropout: Dropout probability applied to the attention probabilities.
+        sinks: Optional per-head sink logits. When given, one extra logit per
+            head is appended before the softmax and its probability column is
+            dropped afterwards, so it only absorbs probability mass.
+        **kwargs: Ignored.
+
+    Returns:
+        ``(attn_output, probs)``: the attention output with axes 1 and 2
+        swapped and made contiguous, and the post-dropout probabilities
+        (without the sink column).
+    """
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]]
+
+    if sinks is not None:
+        # Read the values from the `sinks` argument itself. The original looked
+        # them up on `module.attention_sink_bias`, which ignored the argument and
+        # failed for any module lacking that attribute.
+        sink_bias = sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
+        attn_weights = torch.cat([attn_weights, sink_bias.to(attn_weights.dtype)], dim=-1)
+
+    # Subtract the row maximum before the softmax for numerical stability.
+    attn_weights = attn_weights - attn_weights.max(dim=-1, keepdim=True).values
+    probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+
+    if sinks is not None:
+        # Drop the sink column: it took part in normalization but has no value vector.
+        probs = probs[..., :-1]
+
+    probs = F.dropout(probs, p=dropout, training=module.training)
+    attn_output = torch.matmul(probs, value_states)
+    return attn_output.transpose(1, 2).contiguous(), probs
+
+
+class MiMoV2Attention(nn.Module):
+    """MiMoV2 hybrid attention (full or sliding-window).
+
+    One class serves both layer kinds. ``is_swa`` selects which set of head
+    counts and head dimensions is read from the config (the ``swa_*`` values or
+    the plain ones) and whether ``sliding_window`` is recorded. Queries, keys
+    and values come either from three separate projections (``"split"``) or
+    from one fused projection (``"fused_qkv"``), and attention itself is
+    computed by ``eager_attention_forward``.
+    """
+
+    def __init__(
+        self,
+        config: MiMoV2Config,
+        is_swa: bool,
+        layer_idx: int,
+        projection_layout: str,
+        backend: BackendConfig,
+        dtype: torch.dtype,
+    ):
+        """Build the projections for one attention layer.
+
+        Args:
+            config: Model configuration supplying head counts, head dims and flags.
+            is_swa: True for a sliding-window layer, False for a full-attention layer.
+            layer_idx: Index of the owning decoder layer (stored, not otherwise used here).
+            projection_layout: ``"split"`` or ``"fused_qkv"``.
+            backend: Backend configuration; ``backend.linear`` selects the linear implementation.
+            dtype: Parameter dtype for the projections.
+
+        Raises:
+            ValueError: if ``projection_layout`` is unsupported or the rotary
+                dimension derived from ``partial_rotary_factor`` is odd.
+        """
+        super().__init__()
+        if projection_layout not in {"split", "fused_qkv"}:
+            raise ValueError(f"Unsupported MiMoV2 attention projection layout: {projection_layout}")
+
+        self.layer_idx = layer_idx
+        self.is_swa = is_swa
+        self.is_causal = True
+        self.projection_layout = projection_layout
+
+        # Fallbacks used only when the config object lacks the corresponding attribute.
+        default_head_dim = config.hidden_size // config.num_attention_heads
+        default_v_head_dim = getattr(config, "v_head_dim", default_head_dim)
+
+        if is_swa:
+            self.head_dim = getattr(config, "swa_head_dim", getattr(config, "head_dim", default_head_dim))
+            self.v_head_dim = getattr(config, "swa_v_head_dim", default_v_head_dim)
+            self.num_attention_heads = getattr(config, "swa_num_attention_heads", config.num_attention_heads)
+            self.num_key_value_heads = getattr(config, "swa_num_key_value_heads", config.num_key_value_heads)
+        else:
+            self.head_dim = getattr(config, "head_dim", default_head_dim)
+            self.v_head_dim = getattr(config, "v_head_dim", self.head_dim)
+            self.num_attention_heads = config.num_attention_heads
+            self.num_key_value_heads = config.num_key_value_heads
+
+        # Number of leading channels per head that receive rotary embeddings;
+        # the remaining head_dim - rope_dim channels are passed through unrotated.
+        self.rope_dim = int(self.head_dim * getattr(config, "partial_rotary_factor", 1.0))
+        if self.rope_dim % 2 != 0:
+            raise ValueError(
+                f"MiMoV2 rotary dimension must be even, got {self.rope_dim} from "
+                f"head_dim={self.head_dim} and partial_rotary_factor={getattr(config, 'partial_rotary_factor', 1.0)}"
+            )
+
+        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
+        self.attention_dropout = getattr(config, "attention_dropout", 0.0)
+        self.scaling = self.head_dim**-0.5
+        # Recorded for sliding-window layers only; this class does not build the mask itself.
+        self.sliding_window = getattr(config, "sliding_window", None) if is_swa else None
+
+        # Flattened projection widths. Values may use a different per-head width
+        # (v_head_dim) than queries/keys (head_dim), so o_proj consumes
+        # num_attention_heads * v_head_dim features.
+        self.q_size = self.num_attention_heads * self.head_dim
+        self.k_size = self.num_key_value_heads * self.head_dim
+        self.v_size = self.num_key_value_heads * self.v_head_dim
+        self.o_hidden_size = self.num_attention_heads * self.v_head_dim
+        self.v_scale = getattr(config, "attention_value_scale", None)
+
+        # Optional per-head sink logit, enabled separately for full and
+        # sliding-window layers. Frozen (requires_grad=False) and created with
+        # torch.empty, i.e. uninitialized here — presumably filled when a
+        # checkpoint is loaded; TODO confirm.
+        self.attention_sink_bias = (
+            nn.Parameter(torch.empty(self.num_attention_heads), requires_grad=False)
+            if (
+                (getattr(config, "add_full_attention_sink_bias", False) and not is_swa)
+                or (getattr(config, "add_swa_attention_sink_bias", False) and is_swa)
+            )
+            else None
+        )
+
+        attention_bias = getattr(config, "attention_bias", False)
+        if self.projection_layout == "fused_qkv":
+            # Single projection whose output is later split into q, k and v.
+            self.qkv_proj = initialize_linear_module(
+                backend.linear,
+                config.hidden_size,
+                self.q_size + self.k_size + self.v_size,
+                bias=attention_bias,
+                dtype=dtype,
+            )
+        else:
+            self.q_proj = initialize_linear_module(
+                backend.linear, config.hidden_size, self.q_size, bias=attention_bias, dtype=dtype
+            )
+            self.k_proj = initialize_linear_module(
+                backend.linear, config.hidden_size, self.k_size, bias=attention_bias, dtype=dtype
+            )
+            self.v_proj = initialize_linear_module(
+                backend.linear, config.hidden_size, self.v_size, bias=attention_bias, dtype=dtype
+            )
+        # The output projection never has a bias, regardless of attention_bias.
+        self.o_proj = initialize_linear_module(
+            backend.linear, self.o_hidden_size, config.hidden_size, bias=False, dtype=dtype
+        )
+
+    def _forward_attention(
+        self,
+        query_states: torch.Tensor,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        input_shape: torch.Size,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        """Apply value scaling and partial RoPE, run attention, and project the output.
+
+        ``input_shape`` is the leading shape of the original hidden states (all
+        axes but the last); the attention output is reshaped back to it before
+        ``o_proj``.
+        """
+        if self.v_scale is not None:
+            value_states = value_states * self.v_scale
+
+        cos, sin = position_embeddings
+        # Rotate only the first rope_dim channels of each head, then re-attach the rest.
+        query_rope, query_nope = query_states.split([self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
+        key_rope, key_nope = key_states.split([self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
+        query_rope, key_rope = apply_rotary_pos_emb(query_rope, key_rope, cos, sin)
+        query_states = torch.cat([query_rope, query_nope], dim=-1)
+        key_states = torch.cat([key_rope, key_nope], dim=-1)
+
+        attn_output, _ = eager_attention_forward(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            scaling=self.scaling,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            sinks=self.attention_sink_bias,
+        )
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        return self.o_proj(attn_output)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        **kwargs: Any,
+    ) -> torch.Tensor:
+        """Project ``hidden_states`` to q/k/v, reshape per head, and run attention.
+
+        Extra keyword arguments are accepted and discarded.
+        """
+        del kwargs
+        input_shape = hidden_states.shape[:-1]
+
+        if self.projection_layout == "fused_qkv":
+            qkv = self.qkv_proj(hidden_states)
+            query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
+        else:
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+
+        # Split the feature axis into (heads, per-head dim) and move heads to axis 1.
+        query_states = query_states.view(*input_shape, self.num_attention_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(*input_shape, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(*input_shape, self.num_key_value_heads, self.v_head_dim).transpose(1, 2)
+        return self._forward_attention(
+            query_states, key_states, value_states, input_shape, position_embeddings, attention_mask
+        )
+
+
+class MiMoV2RotaryEmbedding(nn.Module):
+    """Produces the rotary ``(cos, sin)`` tables for one attention variant.
+
+    The module works on a shallow copy of the config. With ``is_swa=True`` the
+    copy's ``rope_theta`` and ``head_dim`` are overridden with the ``swa_*``
+    values so the inverse frequencies match the sliding-window layers.
+    """
+
+    inv_freq: torch.Tensor
+
+    def __init__(self, config: MiMoV2Config, is_swa: bool, device: Optional[torch.device] = None):
+        super().__init__()
+        # Accept both the "rope_type" and the legacy "type" key; anything else is "default".
+        self.rope_type = (
+            config.rope_scaling.get("rope_type", config.rope_scaling.get("type", "default"))
+            if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict)
+            else "default"
+        )
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        # Shallow-copy the config (and its rope_parameters dict) so the overrides
+        # below do not write through to the shared config object.
+        self.config = copy(config)
+        self.config.rope_parameters = copy(getattr(config, "rope_parameters", None) or {})
+        if is_swa:
+            self.config.rope_theta = getattr(config, "swa_rope_theta", config.rope_theta)
+            self.config.head_dim = getattr(config, "swa_head_dim", getattr(config, "head_dim", None))
+            # Only patched when rope_parameters is non-empty; an empty dict is left as is.
+            if self.config.rope_parameters:
+                self.config.rope_parameters["rope_theta"] = self.config.rope_theta
+
+        self.rope_init_fn = (
+            self.compute_default_rope_parameters if self.rope_type == "default" else ROPE_INIT_FUNCTIONS[self.rope_type]
+        )
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        # Non-persistent: inv_freq is recomputed from the config, not stored in state dicts.
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    @staticmethod
+    def compute_default_rope_parameters(
+        config: MiMoV2Config,
+        device: Optional[torch.device] = None,
+        seq_len: Optional[int] = None,
+        layer_type: Optional[str] = None,
+    ) -> tuple[torch.Tensor, float]:
+        """Compute unscaled RoPE inverse frequencies.
+
+        ``seq_len`` is accepted but not used. When ``layer_type`` is given,
+        ``config.rope_parameters`` is indexed by it first.
+
+        Returns:
+            ``(inv_freq, 1.0)`` — the inverse frequencies and an attention
+            scaling factor of 1.0.
+
+        Raises:
+            ValueError: if ``int(head_dim * partial_rotary_factor)`` is odd.
+        """
+        # NOTE(review): relies on config.standardize_rope_params() leaving a
+        # "rope_theta" entry in config.rope_parameters — confirm against the
+        # installed transformers version.
+        config.standardize_rope_params()
+        rope_parameters = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters
+        base = rope_parameters["rope_theta"]
+        # Read from rope_parameters (default 1.0), not from config.partial_rotary_factor.
+        partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)
+        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+        dim = int(head_dim * partial_rotary_factor)
+        if dim % 2 != 0:
+            raise ValueError(
+                f"MiMoV2 rotary dimension must be even, got {dim} from "
+                f"head_dim={head_dim} and partial_rotary_factor={partial_rotary_factor}"
+            )
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+        )
+        return inv_freq, 1.0
+
+    @torch.no_grad()
+    @dynamic_rope_update
+    def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Return ``(cos, sin)`` for ``position_ids``.
+
+        ``x`` is used only for its device and dtype; the results are cast to
+        ``x.dtype`` after being computed in float32 with autocast disabled.
+        """
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        position_ids_expanded = position_ids[:, None, :].float()
+        # autocast is given "cpu" when the device type is "mps".
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            # Duplicate the frequencies so the table spans both halves used by rotate_half.
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class MiMoV2DecoderLayer(nn.Module):
+    """One pre-norm decoder block: attention, then an MoE or dense MLP, each with a residual.
+
+    ``config.hybrid_layer_pattern[layer_idx] == 1`` makes this a sliding-window
+    layer, any other value a full-attention layer. The feed-forward part is an
+    ``MoE`` when ``config.n_routed_experts`` is set and
+    ``config.moe_layer_freq[layer_idx]`` is truthy, otherwise a dense ``MLP``.
+    """
+
+    def __init__(
+        self,
+        config: MiMoV2Config,
+        layer_idx: int,
+        moe_config: MoEConfig,
+        backend: BackendConfig,
+    ):
+        """Build the attention, feed-forward and normalization sub-modules.
+
+        Args:
+            config: Model configuration.
+            layer_idx: Position of this layer; indexes the per-layer config lists.
+            moe_config: Configuration passed to ``MoE`` for MoE layers.
+            backend: Backend configuration selecting linear / RMSNorm implementations.
+        """
+        super().__init__()
+        dtype = get_dtype(config.torch_dtype, torch.bfloat16)
+        is_swa_layer = config.hybrid_layer_pattern[layer_idx] == 1
+        self.attention_type = "sliding_attention" if is_swa_layer else "full_attention"
+        self.self_attn = MiMoV2Attention(
+            config,
+            is_swa=is_swa_layer,
+            layer_idx=layer_idx,
+            projection_layout=config.attention_projection_layout,
+            backend=backend,
+            dtype=dtype,
+        )
+        is_moe_layer = getattr(config, "n_routed_experts", None) is not None and config.moe_layer_freq[layer_idx]
+        self.mlp = (
+            MoE(moe_config, backend)
+            if is_moe_layer
+            else MLP(config.hidden_size, config.intermediate_size, backend.linear, dtype=dtype)
+        )
+        self.input_layernorm = initialize_rms_norm_module(
+            backend.rms_norm, config.hidden_size, eps=config.layernorm_epsilon, dtype=dtype
+        )
+        self.post_attention_layernorm = initialize_rms_norm_module(
+            backend.rms_norm, config.hidden_size, eps=config.layernorm_epsilon, dtype=dtype
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        padding_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Run the block on ``hidden_states``.
+
+        Args:
+            hidden_states: Input activations.
+            attention_mask: Optional mask forwarded to the attention module.
+            position_embeddings: ``(cos, sin)`` pair forwarded to the attention module.
+            padding_mask: Optional mask of padded tokens. Forwarded to the MoE
+                block on MoE layers; unused on dense-MLP layers.
+
+        Returns:
+            The block output, same leading shape as ``hidden_states``.
+        """
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_embeddings=position_embeddings,
+        )
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        if isinstance(self.mlp, MoE):
+            # Previously padding_mask was accepted but silently dropped, so the
+            # MoE block never learned which tokens were padding.
+            # NOTE(review): relies on MoE.forward(x, padding_mask=None, ...) in
+            # nemo_automodel.components.moe.layers — confirm the signature.
+            hidden_states = self.mlp(hidden_states, padding_mask)
+        else:
+            hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+class MiMoV2Model(nn.Module):
+ def __init__(self, config: MiMoV2Config, moe_config: MoEConfig, backend: BackendConfig):
+ super().__init__()
+ self.config = config
+ self.backend = backend
+
+ if backend.gate_precision is None:
+ backend.gate_precision = torch.float32
+
+ dtype = get_dtype(config.torch_dtype, torch.bfloat16)
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype)
+ self.layers = nn.ModuleDict(
+ {str(i): MiMoV2DecoderLayer(config, i, moe_config, backend) for i in range(config.num_hidden_layers)}
+ )
+ self.norm = initialize_rms_norm_module(
+ backend.rms_norm, config.hidden_size, eps=config.layernorm_epsilon, dtype=dtype
+ )
+ self.rotary_emb = MiMoV2RotaryEmbedding(config=config, is_swa=False)
+ self.swa_rotary_emb = MiMoV2RotaryEmbedding(config=config, is_swa=True)
+ self.has_sliding_layers = any(p == 1 for p in config.hybrid_layer_pattern)
+
+ def _build_causal_mask_mapping(
+ self,
+ inputs_embeds: torch.Tensor,
+ attention_mask: torch.Tensor | dict[str, torch.Tensor] | None,
+ position_ids: torch.Tensor,
+ cache_position: torch.Tensor,
+ ) -> dict[str, torch.Tensor]:
+ batch_size, seq_len = inputs_embeds.shape[:2]
+ dtype = inputs_embeds.dtype
+ device = inputs_embeds.device
+
+ if isinstance(attention_mask, dict):
+ return {
+ "full_attention": _ensure_additive_mask(
+ attention_mask.get("full_attention"),
+ batch_size=batch_size,
+ seq_len=seq_len,
+ dtype=dtype,
+ device=device,
+ attention_mask=None,
+ sliding_window=None,
+ ),
+ "sliding_attention": _ensure_additive_mask(
+ attention_mask.get("sliding_attention", attention_mask.get("sliding_window_attention")),
+ batch_size=batch_size,
+ seq_len=seq_len,
+ dtype=dtype,
+ device=device,
+ attention_mask=None,
+ sliding_window=self.config.sliding_window,
+ ),
+ }
+
+ mask_kwargs = dict(
+ config=self.config,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=None,
+ position_ids=position_ids,
+ )
+ pad_mask = attention_mask if isinstance(attention_mask, torch.Tensor) else None
+ return {
+ "full_attention": _ensure_additive_mask(
+ create_causal_mask(**mask_kwargs),
+ batch_size=batch_size,
+ seq_len=seq_len,
+ dtype=dtype,
+ device=device,
+ attention_mask=pad_mask,
+ sliding_window=None,
+ ),
+ "sliding_attention": _ensure_additive_mask(
+ create_sliding_window_causal_mask(**mask_kwargs) if self.has_sliding_layers else None,
+ batch_size=batch_size,
+ seq_len=seq_len,
+ dtype=dtype,
+ device=device,
+ attention_mask=pad_mask,
+ sliding_window=self.config.sliding_window,
+ ),
+ }
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ *,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ padding_mask: Optional[torch.Tensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Any,
+ ) -> torch.Tensor:
+ del kwargs
+ if inputs_embeds is None:
+ if input_ids is None:
+ raise ValueError("input_ids or inputs_embeds must be provided")
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ if padding_mask is None and isinstance(attention_mask, torch.Tensor):
+ padding_mask = _derive_padding_mask(attention_mask)
+
+ causal_mask_mapping = self._build_causal_mask_mapping(
+ inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ cache_position=cache_position,
+ )
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+ swa_position_embeddings = self.swa_rotary_emb(hidden_states, position_ids)
+
+ for layer in self.layers.values():
+ layer_position_embeddings = (
+ swa_position_embeddings if layer.attention_type == "sliding_attention" else position_embeddings
+ )
+ hidden_states = layer(
+ hidden_states,
+ attention_mask=causal_mask_mapping[layer.attention_type],
+ position_embeddings=layer_position_embeddings,
+ padding_mask=padding_mask,
+ )
+
+ return self.norm(hidden_states)
+
+ @torch.no_grad()
+ def init_weights(self, buffer_device: Optional[torch.device] = None) -> None:
+ buffer_device = buffer_device or torch.device(f"cuda:{torch.cuda.current_device()}")
+ with buffer_device:
+ nn.init.normal_(self.embed_tokens.weight)
+ self.norm.reset_parameters()
+ for layer in self.layers.values():
+ layer.init_weights(buffer_device)
+
+
+class MiMoV2ForCausalLM(HFCheckpointingMixin, nn.Module, MoEFSDPSyncMixin):
+ """NeMo AutoModel causal LM wrapper for MiMo-V2.5-Pro."""
+
+ _keep_in_fp32_modules_strict = ["mlp.gate.e_score_correction_bias", "attention_sink_bias"]
+ _pp_keep_self_forward = True
+ _skip_init_weights_on_load = True
+
+ @dataclass(frozen=True)
+ class ModelCapabilities:
+ """Declared parallelism capabilities for this model class."""
+
+ supports_tp: bool = False
+ supports_cp: bool = False
+ supports_pp: bool = True
+ supports_ep: bool = True
+
+ @classmethod
+ def from_config(
+ cls,
+ config: MiMoV2Config,
+ moe_config: MoEConfig | None = None,
+ backend: BackendConfig | None = None,
+ **kwargs,
+ ) -> MiMoV2ForCausalLM:
+ return cls(config, moe_config, backend, **kwargs)
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: str,
+ *model_args,
+ **kwargs,
+ ) -> MiMoV2ForCausalLM:
+ config = MiMoV2Config.from_pretrained(pretrained_model_name_or_path)
+ return cls.from_config(config, *model_args, **kwargs)
+
+ def __init__(
+ self,
+ config: MiMoV2Config,
+ moe_config: MoEConfig | None = None,
+ backend: BackendConfig | None = None,
+ **kwargs,
+ ):
+ super().__init__()
+ self.config = config
+ self.backend = backend or BackendConfig()
+ moe_overrides = kwargs.pop("moe_overrides", None)
+
+ dtype = get_dtype(config.torch_dtype, torch.bfloat16)
+ moe_defaults = dict(
+ dim=config.hidden_size,
+ inter_dim=config.intermediate_size,
+ moe_inter_dim=config.moe_intermediate_size,
+ n_routed_experts=int(config.n_routed_experts or 0),
+ n_shared_experts=int(config.n_shared_experts or 0),
+ n_activated_experts=config.num_experts_per_tok,
+ n_expert_groups=config.n_group,
+ n_limited_groups=config.topk_group,
+ train_gate=True,
+ gate_bias_update_factor=0.0,
+ score_func="sigmoid_with_bias" if config.scoring_func == "sigmoid" else config.scoring_func,
+ route_scale=config.routed_scaling_factor if config.routed_scaling_factor is not None else 1.0,
+ aux_loss_coeff=0.0,
+ norm_topk_prob=config.norm_topk_prob,
+ router_bias=False,
+ expert_bias=False,
+ expert_activation="swiglu",
+ softmax_before_topk=False,
+ force_e_score_correction_bias=True,
+ dtype=dtype,
+ )
+ if moe_overrides:
+ moe_defaults.update(moe_overrides)
+ resolved_moe_config = moe_config or MoEConfig(**moe_defaults)
+
+ self.model = MiMoV2Model(config, resolved_moe_config, self.backend)
+ self.lm_head = initialize_linear_module(
+ self.backend.linear,
+ config.hidden_size,
+ config.vocab_size,
+ bias=False,
+ dtype=dtype,
+ )
+
+ if self.backend.enable_hf_state_dict_adapter:
+ from nemo_automodel.components.models.mimo_v25.state_dict_adapter import MiMoV2StateDictAdapter
+
+ self.state_dict_adapter = MiMoV2StateDictAdapter(
+ self.config,
+ self.model.moe_config if hasattr(self.model, "moe_config") else resolved_moe_config,
+ self.backend,
+ dtype=dtype,
+ )
+
+ def get_input_embeddings(self) -> nn.Embedding:
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value: nn.Embedding) -> None:
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self) -> nn.Linear:
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
+ self.lm_head = new_embeddings
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ *,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ padding_mask: Optional[torch.Tensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ output_hidden_states: Optional[bool] = None,
+ **kwargs: Any,
+ ) -> CausalLMOutputWithPast:
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else getattr(self.config, "output_hidden_states", False)
+ )
+ hidden = self.model(
+ input_ids=input_ids,
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ padding_mask=padding_mask,
+ **kwargs,
+ )
+ return compute_lm_head_logits(self.lm_head, hidden, logits_to_keep, output_hidden_states=output_hidden_states)
+
+ def customize_pipeline_stage_modules(
+ self,
+ module_names_per_stage: list[list[str]],
+ *,
+ layers_prefix: str,
+ text_model: Optional[nn.Module] = None,
+ ) -> list[list[str]]:
+ """Keep the SWA rotary embedding on every PP stage."""
+ text_model = text_model or self.model
+ stage_modules = [list(modules) for modules in module_names_per_stage]
+ if getattr(text_model, "swa_rotary_emb", None) is not None:
+ fqn = f"{layers_prefix}swa_rotary_emb"
+ for modules in stage_modules:
+ if fqn not in modules:
+ modules.append(fqn)
+ return stage_modules
+
+ @torch.no_grad()
+ def initialize_weights(
+ self,
+ buffer_device: Optional[torch.device] = None,
+ dtype: torch.dtype = torch.bfloat16,
+ ) -> None:
+ buffer_device = buffer_device or torch.device(f"cuda:{torch.cuda.current_device()}")
+ with buffer_device:
+ self.model.init_weights(buffer_device)
+ final_out_std = self.config.hidden_size**-0.5
+ cutoff_factor = 3
+ nn.init.trunc_normal_(
+ self.lm_head.weight,
+ mean=0.0,
+ std=final_out_std,
+ a=-cutoff_factor * final_out_std,
+ b=cutoff_factor * final_out_std,
+ )
+ if _has_dtensor_params(self):
+ return
+ cast_model_to_dtype(self, dtype)
+
+
+# Module-level alias for the causal-LM class defined above (presumably looked up by a
+# model registry - verify against the loader).
+ModelClass = MiMoV2ForCausalLM
diff --git a/nemo_automodel/components/models/mimo_v25/state_dict_adapter.py b/nemo_automodel/components/models/mimo_v25/state_dict_adapter.py
new file mode 100644
index 0000000000..8493404e8a
--- /dev/null
+++ b/nemo_automodel/components/models/mimo_v25/state_dict_adapter.py
@@ -0,0 +1,160 @@
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import logging
+import re
+from typing import Any
+
+import torch
+from torch.distributed.device_mesh import DeviceMesh
+
+from nemo_automodel.components.checkpoint.state_dict_adapter import StateDictAdapter
+from nemo_automodel.components.models.common import BackendConfig
+from nemo_automodel.components.models.deepseek_v3.state_dict_adapter import (
+ create_scale_inv_for_weight,
+ dequantize_from_fp8,
+)
+from nemo_automodel.components.moe.config import MoEConfig
+from nemo_automodel.components.moe.state_dict_mixin import MoESplitExpertsStateDictMixin
+
+logger = logging.getLogger(__name__)
+
+# Substrings of parameter names that `_should_quantize_key` excludes from FP8 quantization.
+# Matching is by substring (`pattern in key`), so "norm.weight" already covers the two
+# layernorm entries listed before it; they are kept for explicitness.
+NON_QUANTIZED_KEY_PATTERNS = [
+ "input_layernorm.weight",
+ "post_attention_layernorm.weight",
+ "norm.weight",
+ "lm_head.weight",
+ "embed_tokens.weight",
+ "mlp.gate.weight",
+ "self_attn.o_proj.weight",
+ "attention_sink_bias",
+]
+
+
+def _should_quantize_key(key: str) -> bool:
+ if not key.endswith(".weight"):
+ return False
+ return not any(pattern in key for pattern in NON_QUANTIZED_KEY_PATTERNS)
+
+
+class MiMoV2StateDictAdapter(MoESplitExpertsStateDictMixin, StateDictAdapter):
+ """Convert MiMo-V2.5-Pro HF checkpoints to Automodel's grouped MoE layout.
+
+ HF stores routed experts as split per-expert projections:
+ ``mlp.experts.{E}.{gate,up,down}_proj.weight``. Automodel groups those
+ into ``gate_and_up_projs`` and ``down_projs`` so EP can shard experts
+ without materialising every expert on every rank.
+
+ MiMo-V2.5-Pro uses fused QKV (``self_attn.qkv_proj.weight``), so attention
+ projection keys require no renaming — they pass through unchanged.
+ """
+
+ def __init__(
+ self,
+ config: Any,
+ moe_config: MoEConfig,
+ backend: BackendConfig,
+ dtype: torch.dtype = torch.bfloat16,
+ ):
+ self.config = config
+ self.moe_config = moe_config
+ self.backend = backend
+ self.dtype = dtype
+ self._uses_model_prefix = True
+
+ def from_hf(
+ self,
+ hf_state_dict: dict[str, Any],
+ device_mesh: DeviceMesh | None = None,
+ **kwargs,
+ ) -> dict[str, Any]:
+ del kwargs
+ for key in hf_state_dict.keys():
+ if ".mlp.experts." in key and key.endswith(".weight"):
+ self._uses_model_prefix = key.startswith("model.")
+ break
+ hf_state_dict = self._dequantize(hf_state_dict)
+ return self._from_hf_w_merged_experts(hf_state_dict, device_mesh)
+
+ def to_hf(
+ self,
+ state_dict: dict[str, Any],
+ exclude_key_regex: str | None = None,
+ quantization: bool = False,
+ **kwargs,
+ ) -> dict[str, Any]:
+ """Convert Automodel state_dict to the HF MiMo-V2.5-Pro layout.
+
+ Note: The ``quantization`` parameter is accepted for interface
+ compatibility but is **ignored**. MiMo-V2.5-Pro is distributed as an
+ FP8 HF checkpoint, so this adapter always emits FP8 weights plus
+ ``_scale_inv`` companions for keys that match ``_should_quantize_key``,
+ regardless of the caller's preference.
+ """
+ hf_state_dict: dict[str, Any] = {}
+ for fqn, tensor in state_dict.items():
+ for key, value in self.convert_single_tensor_to_hf(
+ fqn,
+ tensor,
+ exclude_key_regex=exclude_key_regex,
+ quantization=quantization,
+ **kwargs,
+ ):
+ hf_state_dict[key] = value
+ return hf_state_dict
+
+ def convert_single_tensor_to_hf(self, fqn: str, tensor: Any, **kwargs) -> list[tuple[str, Any]]:
+ exclude_key_regex = kwargs.get("exclude_key_regex", None)
+
+ expert_result = self._convert_single_merged_expert_to_hf_split_experts(fqn, tensor, **kwargs)
+ result = expert_result if expert_result is not None else [(fqn, tensor)]
+
+ if exclude_key_regex:
+ result = [(key, value) for key, value in result if not re.match(exclude_key_regex, key)]
+
+ quantized_result: list[tuple[str, Any]] = []
+ for key, value in result:
+ if _should_quantize_key(key):
+ quantized = value.to(dtype=torch.float8_e4m3fn)
+ quantized_result.append((key, quantized))
+ quantized_result.append((key + "_scale_inv", create_scale_inv_for_weight(quantized)))
+ else:
+ quantized_result.append((key, value))
+ return quantized_result
+
+ def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]:
+ scale_inv_keys: list[str] = []
+ dequantized_count = 0
+ for key in list(state_dict.keys()):
+ if not key.endswith(".weight"):
+ continue
+ scale_key = key + "_scale_inv"
+ if scale_key not in state_dict:
+ continue
+ state_dict[key] = dequantize_from_fp8(
+ state_dict[key],
+ state_dict[scale_key],
+ dtype=self.dtype,
+ name=key,
+ )
+ scale_inv_keys.append(scale_key)
+ dequantized_count += 1
+
+ for key in scale_inv_keys:
+ state_dict.pop(key, None)
+
+ logger.debug("[MiMo V2.5-Pro FP8 Dequant] Dequantized %s weights", dequantized_count)
+ return state_dict
diff --git a/pyproject.toml b/pyproject.toml
index c7c187d182..bcc73d5785 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -406,6 +406,8 @@ convention = "google"
"nemo_automodel/components/models/llama_bidirectional/export_onnx.py" = ["D103"]
"nemo_automodel/components/models/llava_onevision/rice_vit.py" = ["D101", "D103"]
"nemo_automodel/components/models/llava_onevision/state_dict_adapter.py" = ["D101"]
+"nemo_automodel/components/models/mimo_v25/model.py" = ["D101", "D103"]
+"nemo_automodel/components/models/mimo_v25/state_dict_adapter.py" = ["D101"]
"nemo_automodel/components/models/minimax_m2/model.py" = ["D101"]
"nemo_automodel/components/models/minimax_m2/state_dict_adapter.py" = ["D103"]
"nemo_automodel/components/models/mistral3/model.py" = ["D101", "D103"]