From bc8ad5fc5c37c2d92f315ea335373fd18e8f42fb Mon Sep 17 00:00:00 2001 From: Simar Malhotra Date: Wed, 10 Jun 2026 15:13:59 -0400 Subject: [PATCH 1/8] feat(support mimo v2.5)- download the config file from hf and modify to match expected arch Signed-off-by: Simar Malhotra --- .../components/models/mimo_v25/__init__.py | 17 ++ .../components/models/mimo_v25/config.py | 174 ++++++++++++++++++ 2 files changed, 191 insertions(+) create mode 100644 nemo_automodel/components/models/mimo_v25/__init__.py create mode 100644 nemo_automodel/components/models/mimo_v25/config.py diff --git a/nemo_automodel/components/models/mimo_v25/__init__.py b/nemo_automodel/components/models/mimo_v25/__init__.py new file mode 100644 index 0000000000..4608ab1c06 --- /dev/null +++ b/nemo_automodel/components/models/mimo_v25/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_automodel.components.models.mimo_v25.config import MiMoV2Config + +__all__ = ["MiMoV2Config"] diff --git a/nemo_automodel/components/models/mimo_v25/config.py b/nemo_automodel/components/models/mimo_v25/config.py new file mode 100644 index 0000000000..a1445216d5 --- /dev/null +++ b/nemo_automodel/components/models/mimo_v25/config.py @@ -0,0 +1,174 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from transformers.configuration_utils import PretrainedConfig + +_MIMOV2_ATTENTION_PROJECTION_LAYOUTS = {"split", "fused_qkv"} + + +class MiMoV2Config(PretrainedConfig): + """Configuration for XiaomiMiMo/MiMo-V2.5-Pro.""" + + model_type = "mimo_v2" + keys_to_ignore_at_inference = ["past_key_values"] + + attribute_map = { + "num_local_experts": "n_routed_experts", + } + + def __init__( + self, + vocab_size: int = 151936, + hidden_size: int = 4096, + intermediate_size: int = 22016, + num_hidden_layers: int = 32, + num_attention_heads: int = 32, + num_key_value_heads: int = 32, + hidden_act: str = "silu", + max_position_embeddings: int = 32768, + initializer_range: float = 0.02, + layernorm_epsilon: float = 1e-6, + rms_norm_eps: float | None = None, + use_cache: bool = True, + tie_word_embeddings: bool = False, + rope_theta: float = 10000.0, + rope_scaling: dict | None = None, + attention_dropout: float = 0.0, + attention_bias: bool = False, + attention_value_scale: float | None = None, + head_dim: int | None = None, + v_head_dim: int | None = None, + swa_num_attention_heads: int | None = None, + swa_num_key_value_heads: int | None = None, + swa_head_dim: int | None = None, + swa_v_head_dim: int | None = None, + swa_rope_theta: float | None = None, + sliding_window: int | None = None, + sliding_window_size: int | None = None, + attention_chunk_size: int | None = None, + add_full_attention_sink_bias: bool = False, + add_swa_attention_sink_bias: bool = False, + hybrid_block_size: int | None = None, + hybrid_layer_pattern: list[int] | None = None, + partial_rotary_factor: float = 1.0, + n_routed_experts: int | None = None, + n_shared_experts: int | None = None, + moe_intermediate_size: int | None = None, + num_experts_per_tok: int | None = None, + routed_scaling_factor: float | None = None, + scoring_func: str = "sigmoid", + topk_method: str = "noaux_tc", + n_group: int | None = None, + topk_group: int | None = None, + norm_topk_prob: bool = True, + moe_layer_freq: list[int] | None = None, + attention_projection_layout: str = "split", + torch_dtype: str = "bfloat16", + **kwargs, + ): + rope_parameters = kwargs.pop("rope_parameters", None) + if rope_scaling is None and rope_parameters is not None: + rope_scaling = rope_parameters + + if attention_projection_layout is None: + attention_projection_layout = "split" + if attention_projection_layout not in _MIMOV2_ATTENTION_PROJECTION_LAYOUTS: + raise ValueError(f"Unsupported MiMoV2 attention projection layout: {attention_projection_layout}") + + self.attention_projection_layout = attention_projection_layout + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + if num_attention_heads % num_key_value_heads != 0: + raise ValueError("num_attention_heads must be divisible by num_key_value_heads") + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.layernorm_epsilon = layernorm_epsilon + self.rms_norm_eps = layernorm_epsilon if rms_norm_eps is None else rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + self.attention_value_scale = attention_value_scale + + self.head_dim = head_dim if head_dim is not None else hidden_size // num_attention_heads + self.v_head_dim = v_head_dim if v_head_dim is not None else self.head_dim + self.swa_num_attention_heads = ( + swa_num_attention_heads if swa_num_attention_heads is not None else num_attention_heads + ) + self.swa_num_key_value_heads = ( + swa_num_key_value_heads if swa_num_key_value_heads is not None else num_key_value_heads + ) + if self.swa_num_attention_heads % self.swa_num_key_value_heads != 0: + raise ValueError("swa_num_attention_heads must be divisible by swa_num_key_value_heads") + self.swa_head_dim = swa_head_dim if swa_head_dim is not None else self.head_dim + self.swa_v_head_dim = swa_v_head_dim if swa_v_head_dim is not None else self.swa_head_dim + self.swa_rope_theta = swa_rope_theta if swa_rope_theta is not None else rope_theta + + if sliding_window is None: + sliding_window = sliding_window_size + self.sliding_window = sliding_window + self.sliding_window_size = sliding_window_size if sliding_window_size is not None else sliding_window + self.attention_chunk_size = attention_chunk_size + self.add_full_attention_sink_bias = add_full_attention_sink_bias + self.add_swa_attention_sink_bias = add_swa_attention_sink_bias + + if hybrid_block_size is not None and hybrid_layer_pattern is None: + hybrid_layer_pattern = [0 if ((i + 1) % hybrid_block_size == 0) else 1 for i in range(num_hidden_layers)] + elif hybrid_layer_pattern is None: + hybrid_layer_pattern = [0] * num_hidden_layers + if len(hybrid_layer_pattern) != num_hidden_layers: + raise ValueError("hybrid_layer_pattern length must match num_hidden_layers") + self.hybrid_block_size = hybrid_block_size + self.hybrid_layer_pattern = hybrid_layer_pattern + + self.partial_rotary_factor = partial_rotary_factor + + self.n_routed_experts = n_routed_experts + self.n_shared_experts = n_shared_experts + self.moe_intermediate_size = moe_intermediate_size if moe_intermediate_size is not None else intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.routed_scaling_factor = routed_scaling_factor + self.scoring_func = scoring_func + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.norm_topk_prob = norm_topk_prob + if isinstance(moe_layer_freq, int): + moe_layer_freq = [moe_layer_freq > 0 and i % moe_layer_freq == 0 for i in range(num_hidden_layers)] + elif moe_layer_freq is None: + moe_layer_freq = [False] * num_hidden_layers + if len(moe_layer_freq) != num_hidden_layers: + raise ValueError("moe_layer_freq length must match num_hidden_layers") + self.moe_layer_freq = moe_layer_freq + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + # Assign after super().__init__() so our string value wins over any + # dtype conversion done by PretrainedConfig. + self.torch_dtype = torch_dtype From 37825fd6ddeeb2399d42dd6f4a9a52d2c0a4cd80 Mon Sep 17 00:00:00 2001 From: Simar Malhotra Date: Wed, 10 Jun 2026 15:30:04 -0400 Subject: [PATCH 2/8] feat(mimo v2.5)- remove native imports and use NeMo instead Signed-off-by: Simar Malhotra --- .../components/models/mimo_v25/model.py | 610 ++++++++++++++++++ 1 file changed, 610 insertions(+) create mode 100644 nemo_automodel/components/models/mimo_v25/model.py diff --git a/nemo_automodel/components/models/mimo_v25/model.py b/nemo_automodel/components/models/mimo_v25/model.py new file mode 100644 index 0000000000..df68c19959 --- /dev/null +++ b/nemo_automodel/components/models/mimo_v25/model.py @@ -0,0 +1,610 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# Portions copyright 2026 Xiaomi Corporation. +# Portions copyright 2026 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from copy import copy +from typing import Callable, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.integrations import use_kernel_forward_from_hub +from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs, can_return_tuple, logging + +from nemo_automodel.components.models.common import BackendConfig, initialize_rms_norm_module +from nemo_automodel.components.models.mimo_v25.config import MiMoV2Config +from nemo_automodel.components.moe.config import MoEConfig +from nemo_automodel.components.moe.layers import MLP, MoE +from nemo_automodel.shared.utils import dtype_from_str as get_dtype + + +logger = logging.get_logger(__name__) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies rotary position embedding to query and key tensors.""" + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + sinks: Optional[torch.Tensor] = None, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + if sinks is not None: + sinks = module.attention_sink_bias.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) + attn_weights = torch.cat([attn_weights, sinks], dim=-1) + + attn_weights = attn_weights - attn_weights.max(dim=-1, keepdim=True).values + probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if sinks is not None: + probs = probs[..., :-1] + + attn_weights = nn.functional.dropout(probs, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +@use_kernel_forward_from_hub("RMSNorm") +class MiMoV2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class MiMoV2Attention(nn.Module): + """MiMoV2 attention. + + `projection_layout` only controls how checkpoint weights are named and + stored: Flash uses separate q/k/v projections, while Pro uses fused qkv. + The attention computation after projection is shared. + """ + + def __init__(self, config, is_swa: bool, layer_idx: int, projection_layout: str = "split"): + super().__init__() + if projection_layout not in {"split", "fused_qkv"}: + raise ValueError(f"Unsupported MiMoV2 attention projection layout: {projection_layout}") + + self.config = config + self.layer_idx = layer_idx + self.is_swa = is_swa + self.is_causal = True + self.projection_layout = projection_layout + + default_head_dim = config.hidden_size // config.num_attention_heads + default_v_head_dim = getattr(config, "v_head_dim", default_head_dim) + + if is_swa: + self.head_dim = getattr(config, "swa_head_dim", getattr(config, "head_dim", default_head_dim)) + self.v_head_dim = getattr(config, "swa_v_head_dim", default_v_head_dim) + self.num_attention_heads = getattr(config, "swa_num_attention_heads", config.num_attention_heads) + self.num_key_value_heads = getattr(config, "swa_num_key_value_heads", config.num_key_value_heads) + else: + self.head_dim = getattr(config, "head_dim", default_head_dim) + self.v_head_dim = getattr(config, "v_head_dim", self.head_dim) + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + + self.rope_dim = int(self.head_dim * getattr(config, "partial_rotary_factor", 1.0)) + if self.rope_dim % 2 != 0: + raise ValueError( + f"MiMoV2 rotary dimension must be even, got {self.rope_dim} from " + f"head_dim={self.head_dim} and partial_rotary_factor={getattr(config, 'partial_rotary_factor', 1.0)}" + ) + self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads + self.attention_dropout = getattr(config, "attention_dropout", 0.0) + self.scaling = self.head_dim**-0.5 + self.sliding_window = getattr(config, "sliding_window", None) if is_swa else None + self.q_size = self.num_attention_heads * self.head_dim + self.k_size = self.num_key_value_heads * self.head_dim + self.v_size = self.num_key_value_heads * self.v_head_dim + self.o_hidden_size = self.num_attention_heads * self.v_head_dim + self.v_scale = getattr(config, "attention_value_scale", None) + self.attention_sink_bias = ( + nn.Parameter(torch.empty(self.num_attention_heads), requires_grad=False) + if ( + (getattr(config, "add_full_attention_sink_bias", False) and not is_swa) + or (getattr(config, "add_swa_attention_sink_bias", False) and is_swa) + ) + else None + ) + + attention_bias = getattr(config, "attention_bias", False) + if self.projection_layout == "fused_qkv": + self.qkv_proj = nn.Linear( + config.hidden_size, + self.q_size + self.k_size + self.v_size, + bias=attention_bias, + ) + else: + self.q_proj = nn.Linear(config.hidden_size, self.q_size, bias=attention_bias) + self.k_proj = nn.Linear(config.hidden_size, self.k_size, bias=attention_bias) + self.v_proj = nn.Linear(config.hidden_size, self.v_size, bias=attention_bias) + self.o_proj = nn.Linear(self.o_hidden_size, config.hidden_size, bias=False) + + def _forward_attention( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + input_shape: torch.Size, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.v_scale is not None: + value_states = value_states * self.v_scale + + cos, sin = position_embeddings + query_rope, query_nope = query_states.split([self.rope_dim, self.head_dim - self.rope_dim], dim=-1) + key_rope, key_nope = key_states.split([self.rope_dim, self.head_dim - self.rope_dim], dim=-1) + query_rope, key_rope = apply_rotary_pos_emb(query_rope, key_rope, cos, sin) + query_states = torch.cat([query_rope, query_nope], dim=-1) + key_states = torch.cat([key_rope, key_nope], dim=-1) + + if past_key_values is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attn_implementation = self.config._attn_implementation + if attn_implementation is not None and attn_implementation.startswith("paged|"): + raise ValueError( + "MiMoV2 remote code does not support paged attention cache. " + "Please use eager, sdpa, flex_attention, or flash_attention_2." + ) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + attn_implementation, eager_attention_forward + ) + if self.attention_sink_bias is not None and attn_implementation == "sdpa": + logger.warning_once( + "MiMoV2 attention sink bias is not supported by SDPA; falling back to eager attention for correctness." + ) + attention_interface = eager_attention_forward + + attention_kwargs = { + "dropout": 0.0 if not self.training else self.attention_dropout, + "scaling": self.scaling, + "position_ids": position_ids, + "is_causal": self.is_causal, + } + if attention_interface is eager_attention_forward: + attention_kwargs["sinks"] = self.attention_sink_bias + else: + if self.attention_sink_bias is not None: + attention_kwargs["s_aux"] = self.attention_sink_bias + if self.sliding_window is not None: + attention_kwargs["sliding_window"] = self.sliding_window + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + **attention_kwargs, + ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + + if self.projection_layout == "fused_qkv": + qkv_states = self.qkv_proj(hidden_states) + query_states, key_states, value_states = qkv_states.split([self.q_size, self.k_size, self.v_size], dim=-1) + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(*input_shape, self.num_attention_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(*input_shape, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(*input_shape, self.num_key_value_heads, self.v_head_dim).transpose(1, 2) + return self._forward_attention( + query_states, + key_states, + value_states, + input_shape, + position_embeddings, + attention_mask, + past_key_values=past_key_values, + cache_position=cache_position, + position_ids=position_ids, + ) + + +class MiMoV2DecoderLayer(nn.Module): + def __init__( + self, + config: MiMoV2Config, + layer_idx: int, + moe_config: MoEConfig, + backend: BackendConfig, + ): + super().__init__() + dtype = get_dtype(config.torch_dtype, torch.bfloat16) + is_swa_layer = config.hybrid_layer_pattern[layer_idx] == 1 + self.attention_type = "sliding_window_attention" if is_swa_layer else "full_attention" + self.self_attn = MiMoV2Attention( + config, is_swa_layer, layer_idx, projection_layout=config.attention_projection_layout + ) + is_moe_layer = getattr(config, "n_routed_experts", None) is not None and config.moe_layer_freq[layer_idx] + self.mlp = ( + MoE(moe_config, backend) + if is_moe_layer + else MLP(config.hidden_size, config.intermediate_size, backend.linear, dtype=dtype) + ) + self.input_layernorm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.layernorm_epsilon, dtype=dtype + ) + self.post_attention_layernorm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.layernorm_epsilon, dtype=dtype + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class MiMoV2RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor + + def __init__(self, config, is_swa: bool, device=None): + super().__init__() + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type", "default")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = copy(config) + self.config.rope_parameters = copy(getattr(config, "rope_parameters", None) or {}) + if is_swa: + self.config.rope_theta = getattr(config, "swa_rope_theta", config.rope_theta) + self.config.head_dim = getattr(config, "swa_head_dim", getattr(config, "head_dim", None)) + if self.config.rope_parameters: + self.config.rope_parameters["rope_theta"] = self.config.rope_theta + self.rope_init_fn = ( + self.compute_default_rope_parameters + if self.rope_type == "default" + else ROPE_INIT_FUNCTIONS[self.rope_type] + ) + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @staticmethod + def compute_default_rope_parameters(config, device=None, seq_len=None, layer_type=None): + config.standardize_rope_params() + rope_parameters = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters + base = rope_parameters["rope_theta"] + partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + if dim % 2 != 0: + raise ValueError( + f"MiMoV2 rotary dimension must be even, got {dim} from " + f"head_dim={head_dim} and partial_rotary_factor={partial_rotary_factor}" + ) + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, 1.0 + + @torch.no_grad() + @dynamic_rope_update + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class MiMoV2Model(PreTrainedModel): + config_class = MiMoV2Config + attention_projection_layout = "split" + + def __init__(self, config): + super().__init__(config) + self.attention_projection_layout = getattr( + config, "attention_projection_layout", self.attention_projection_layout + ) + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList( + [ + MiMoV2DecoderLayer( + config, + layer_idx, + attention_projection_layout=self.attention_projection_layout, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = MiMoV2RMSNorm(config.hidden_size, eps=config.layernorm_epsilon) + self.rotary_emb = MiMoV2RotaryEmbedding(config=config, is_swa=False) + self.swa_rotary_emb = MiMoV2RotaryEmbedding(config=config, is_swa=True) + self.has_sliding_layers = any(pattern == 1 for pattern in config.hybrid_layer_pattern) + self.config.layer_types = [ + "sliding_attention" if config.hybrid_layer_pattern[i] == 1 else "full_attention" + for i in range(config.num_hidden_layers) + ] + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + if not isinstance(causal_mask_mapping := attention_mask, dict): + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + if self.has_sliding_layers: + if getattr(self.config, "sliding_window", None) is None: + raise ValueError("MiMoV2 config `sliding_window` must be set when hybrid_layer_pattern uses SWA.") + causal_mask_mapping["sliding_window_attention"] = create_sliding_window_causal_mask(**mask_kwargs) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + swa_position_embeddings = self.swa_rotary_emb(hidden_states, position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_embeddings=position_embeddings + if decoder_layer.attention_type == "full_attention" + else swa_position_embeddings, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + +class MiMoV2ForCausalLM(PreTrainedModel, GenerationMixin): + config_class = MiMoV2Config + model_class = MiMoV2Model + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + _keys_to_ignore_on_load_unexpected = [ + r"model\.(swa_)?rotary_emb\.inv_freq", + r"model\.layers\.\d+\.self_attn\.rotary_emb\.inv_freq", + r"model\.layers\.\d+\.self_attn\.rotary_emb\.(cos_cached|sin_cached)", + r"model\.mtp\..*", + ] + + def __init__(self, config): + super().__init__(config) + self.model = self.model_class(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "MiMoV2Attention", + "MiMoV2DecoderLayer", + "MiMoV2ForCausalLM", + "MiMoV2MLP", + "MiMoV2MoE", + "MiMoV2MoEGate", + "MiMoV2Model", + "MiMoV2RMSNorm", + "MiMoV2RotaryEmbedding", +] From 7e77c0d79b8d6510fe04c90ac266eb16460c4ac6 Mon Sep 17 00:00:00 2001 From: Simar Malhotra Date: Wed, 10 Jun 2026 15:49:35 -0400 Subject: [PATCH 3/8] feat(mimo_v25): add NeMo model and config for MiMo-V2.5-Pro Signed-off-by: Simar Malhotra --- .../components/models/mimo_v25/__init__.py | 3 +- .../components/models/mimo_v25/model.py | 1350 +++++++++-------- pyproject.toml | 2 + 3 files changed, 744 insertions(+), 611 deletions(-) diff --git a/nemo_automodel/components/models/mimo_v25/__init__.py b/nemo_automodel/components/models/mimo_v25/__init__.py index 4608ab1c06..01fca469cf 100644 --- a/nemo_automodel/components/models/mimo_v25/__init__.py +++ b/nemo_automodel/components/models/mimo_v25/__init__.py @@ -13,5 +13,6 @@ # limitations under the License. from nemo_automodel.components.models.mimo_v25.config import MiMoV2Config +from nemo_automodel.components.models.mimo_v25.model import MiMoV2ForCausalLM, MiMoV2Model -__all__ = ["MiMoV2Config"] +__all__ = ["MiMoV2Config", "MiMoV2ForCausalLM", "MiMoV2Model"] diff --git a/nemo_automodel/components/models/mimo_v25/model.py b/nemo_automodel/components/models/mimo_v25/model.py index df68c19959..3c363c0e94 100644 --- a/nemo_automodel/components/models/mimo_v25/model.py +++ b/nemo_automodel/components/models/mimo_v25/model.py @@ -1,610 +1,740 @@ -# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. -# Portions copyright 2026 Xiaomi Corporation. -# Portions copyright 2026 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from copy import copy -from typing import Callable, Optional, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from transformers.cache_utils import Cache, DynamicCache -from transformers.generation import GenerationMixin -from transformers.integrations import use_kernel_forward_from_hub -from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from transformers.processing_utils import Unpack -from transformers.utils import TransformersKwargs, can_return_tuple, logging - -from nemo_automodel.components.models.common import BackendConfig, initialize_rms_norm_module -from nemo_automodel.components.models.mimo_v25.config import MiMoV2Config -from nemo_automodel.components.moe.config import MoEConfig -from nemo_automodel.components.moe.layers import MLP, MoE -from nemo_automodel.shared.utils import dtype_from_str as get_dtype - - -logger = logging.get_logger(__name__) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies rotary position embedding to query and key tensors.""" - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - sinks: Optional[torch.Tensor] = None, - **kwargs, -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - if sinks is not None: - sinks = module.attention_sink_bias.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) - attn_weights = torch.cat([attn_weights, sinks], dim=-1) - - attn_weights = attn_weights - attn_weights.max(dim=-1, keepdim=True).values - probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - - if sinks is not None: - probs = probs[..., :-1] - - attn_weights = nn.functional.dropout(probs, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output, attn_weights - - -@use_kernel_forward_from_hub("RMSNorm") -class MiMoV2RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -class MiMoV2Attention(nn.Module): - """MiMoV2 attention. - - `projection_layout` only controls how checkpoint weights are named and - stored: Flash uses separate q/k/v projections, while Pro uses fused qkv. - The attention computation after projection is shared. - """ - - def __init__(self, config, is_swa: bool, layer_idx: int, projection_layout: str = "split"): - super().__init__() - if projection_layout not in {"split", "fused_qkv"}: - raise ValueError(f"Unsupported MiMoV2 attention projection layout: {projection_layout}") - - self.config = config - self.layer_idx = layer_idx - self.is_swa = is_swa - self.is_causal = True - self.projection_layout = projection_layout - - default_head_dim = config.hidden_size // config.num_attention_heads - default_v_head_dim = getattr(config, "v_head_dim", default_head_dim) - - if is_swa: - self.head_dim = getattr(config, "swa_head_dim", getattr(config, "head_dim", default_head_dim)) - self.v_head_dim = getattr(config, "swa_v_head_dim", default_v_head_dim) - self.num_attention_heads = getattr(config, "swa_num_attention_heads", config.num_attention_heads) - self.num_key_value_heads = getattr(config, "swa_num_key_value_heads", config.num_key_value_heads) - else: - self.head_dim = getattr(config, "head_dim", default_head_dim) - self.v_head_dim = getattr(config, "v_head_dim", self.head_dim) - self.num_attention_heads = config.num_attention_heads - self.num_key_value_heads = config.num_key_value_heads - - self.rope_dim = int(self.head_dim * getattr(config, "partial_rotary_factor", 1.0)) - if self.rope_dim % 2 != 0: - raise ValueError( - f"MiMoV2 rotary dimension must be even, got {self.rope_dim} from " - f"head_dim={self.head_dim} and partial_rotary_factor={getattr(config, 'partial_rotary_factor', 1.0)}" - ) - self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads - self.attention_dropout = getattr(config, "attention_dropout", 0.0) - self.scaling = self.head_dim**-0.5 - self.sliding_window = getattr(config, "sliding_window", None) if is_swa else None - self.q_size = self.num_attention_heads * self.head_dim - self.k_size = self.num_key_value_heads * self.head_dim - self.v_size = self.num_key_value_heads * self.v_head_dim - self.o_hidden_size = self.num_attention_heads * self.v_head_dim - self.v_scale = getattr(config, "attention_value_scale", None) - self.attention_sink_bias = ( - nn.Parameter(torch.empty(self.num_attention_heads), requires_grad=False) - if ( - (getattr(config, "add_full_attention_sink_bias", False) and not is_swa) - or (getattr(config, "add_swa_attention_sink_bias", False) and is_swa) - ) - else None - ) - - attention_bias = getattr(config, "attention_bias", False) - if self.projection_layout == "fused_qkv": - self.qkv_proj = nn.Linear( - config.hidden_size, - self.q_size + self.k_size + self.v_size, - bias=attention_bias, - ) - else: - self.q_proj = nn.Linear(config.hidden_size, self.q_size, bias=attention_bias) - self.k_proj = nn.Linear(config.hidden_size, self.k_size, bias=attention_bias) - self.v_proj = nn.Linear(config.hidden_size, self.v_size, bias=attention_bias) - self.o_proj = nn.Linear(self.o_hidden_size, config.hidden_size, bias=False) - - def _forward_attention( - self, - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - input_shape: torch.Size, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor], - past_key_values: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if self.v_scale is not None: - value_states = value_states * self.v_scale - - cos, sin = position_embeddings - query_rope, query_nope = query_states.split([self.rope_dim, self.head_dim - self.rope_dim], dim=-1) - key_rope, key_nope = key_states.split([self.rope_dim, self.head_dim - self.rope_dim], dim=-1) - query_rope, key_rope = apply_rotary_pos_emb(query_rope, key_rope, cos, sin) - query_states = torch.cat([query_rope, query_nope], dim=-1) - key_states = torch.cat([key_rope, key_nope], dim=-1) - - if past_key_values is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attn_implementation = self.config._attn_implementation - if attn_implementation is not None and attn_implementation.startswith("paged|"): - raise ValueError( - "MiMoV2 remote code does not support paged attention cache. " - "Please use eager, sdpa, flex_attention, or flash_attention_2." - ) - - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( - attn_implementation, eager_attention_forward - ) - if self.attention_sink_bias is not None and attn_implementation == "sdpa": - logger.warning_once( - "MiMoV2 attention sink bias is not supported by SDPA; falling back to eager attention for correctness." - ) - attention_interface = eager_attention_forward - - attention_kwargs = { - "dropout": 0.0 if not self.training else self.attention_dropout, - "scaling": self.scaling, - "position_ids": position_ids, - "is_causal": self.is_causal, - } - if attention_interface is eager_attention_forward: - attention_kwargs["sinks"] = self.attention_sink_bias - else: - if self.attention_sink_bias is not None: - attention_kwargs["s_aux"] = self.attention_sink_bias - if self.sliding_window is not None: - attention_kwargs["sliding_window"] = self.sliding_window - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - **attention_kwargs, - ) - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor], - past_key_values: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - **kwargs: Unpack[TransformersKwargs], - ) -> tuple[torch.Tensor, torch.Tensor]: - input_shape = hidden_states.shape[:-1] - - if self.projection_layout == "fused_qkv": - qkv_states = self.qkv_proj(hidden_states) - query_states, key_states, value_states = qkv_states.split([self.q_size, self.k_size, self.v_size], dim=-1) - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(*input_shape, self.num_attention_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(*input_shape, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(*input_shape, self.num_key_value_heads, self.v_head_dim).transpose(1, 2) - return self._forward_attention( - query_states, - key_states, - value_states, - input_shape, - position_embeddings, - attention_mask, - past_key_values=past_key_values, - cache_position=cache_position, - position_ids=position_ids, - ) - - -class MiMoV2DecoderLayer(nn.Module): - def __init__( - self, - config: MiMoV2Config, - layer_idx: int, - moe_config: MoEConfig, - backend: BackendConfig, - ): - super().__init__() - dtype = get_dtype(config.torch_dtype, torch.bfloat16) - is_swa_layer = config.hybrid_layer_pattern[layer_idx] == 1 - self.attention_type = "sliding_window_attention" if is_swa_layer else "full_attention" - self.self_attn = MiMoV2Attention( - config, is_swa_layer, layer_idx, projection_layout=config.attention_projection_layout - ) - is_moe_layer = getattr(config, "n_routed_experts", None) is not None and config.moe_layer_freq[layer_idx] - self.mlp = ( - MoE(moe_config, backend) - if is_moe_layer - else MLP(config.hidden_size, config.intermediate_size, backend.linear, dtype=dtype) - ) - self.input_layernorm = initialize_rms_norm_module( - backend.rms_norm, config.hidden_size, eps=config.layernorm_epsilon, dtype=dtype - ) - self.post_attention_layernorm = initialize_rms_norm_module( - backend.rms_norm, config.hidden_size, eps=config.layernorm_epsilon, dtype=dtype - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs: Unpack[TransformersKwargs], - ) -> torch.Tensor: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - hidden_states, _ = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - return hidden_states - - -class MiMoV2RotaryEmbedding(nn.Module): - inv_freq: torch.Tensor - - def __init__(self, config, is_swa: bool, device=None): - super().__init__() - if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type", "default")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = copy(config) - self.config.rope_parameters = copy(getattr(config, "rope_parameters", None) or {}) - if is_swa: - self.config.rope_theta = getattr(config, "swa_rope_theta", config.rope_theta) - self.config.head_dim = getattr(config, "swa_head_dim", getattr(config, "head_dim", None)) - if self.config.rope_parameters: - self.config.rope_parameters["rope_theta"] = self.config.rope_theta - self.rope_init_fn = ( - self.compute_default_rope_parameters - if self.rope_type == "default" - else ROPE_INIT_FUNCTIONS[self.rope_type] - ) - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - @staticmethod - def compute_default_rope_parameters(config, device=None, seq_len=None, layer_type=None): - config.standardize_rope_params() - rope_parameters = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters - base = rope_parameters["rope_theta"] - partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0) - head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads - dim = int(head_dim * partial_rotary_factor) - if dim % 2 != 0: - raise ValueError( - f"MiMoV2 rotary dimension must be even, got {dim} from " - f"head_dim={head_dim} and partial_rotary_factor={partial_rotary_factor}" - ) - inv_freq = 1.0 / ( - base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) - ) - return inv_freq, 1.0 - - @torch.no_grad() - @dynamic_rope_update - def forward(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) - position_ids_expanded = position_ids[:, None, :].float() - - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -class MiMoV2Model(PreTrainedModel): - config_class = MiMoV2Config - attention_projection_layout = "split" - - def __init__(self, config): - super().__init__(config) - self.attention_projection_layout = getattr( - config, "attention_projection_layout", self.attention_projection_layout - ) - self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) - self.layers = nn.ModuleList( - [ - MiMoV2DecoderLayer( - config, - layer_idx, - attention_projection_layout=self.attention_projection_layout, - ) - for layer_idx in range(config.num_hidden_layers) - ] - ) - self.norm = MiMoV2RMSNorm(config.hidden_size, eps=config.layernorm_epsilon) - self.rotary_emb = MiMoV2RotaryEmbedding(config=config, is_swa=False) - self.swa_rotary_emb = MiMoV2RotaryEmbedding(config=config, is_swa=True) - self.has_sliding_layers = any(pattern == 1 for pattern in config.hybrid_layer_pattern) - self.config.layer_types = [ - "sliding_attention" if config.hybrid_layer_pattern[i] == 1 else "full_attention" - for i in range(config.num_hidden_layers) - ] - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[TransformersKwargs], - ) -> BaseModelOutputWithPast: - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - use_cache = use_cache if use_cache is not None else self.config.use_cache - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if use_cache and past_key_values is None: - past_key_values = DynamicCache(config=self.config) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - if not isinstance(causal_mask_mapping := attention_mask, dict): - mask_kwargs = { - "config": self.config, - "input_embeds": inputs_embeds, - "attention_mask": attention_mask, - "cache_position": cache_position, - "past_key_values": past_key_values, - "position_ids": position_ids, - } - causal_mask_mapping = { - "full_attention": create_causal_mask(**mask_kwargs), - } - if self.has_sliding_layers: - if getattr(self.config, "sliding_window", None) is None: - raise ValueError("MiMoV2 config `sliding_window` must be set when hybrid_layer_pattern uses SWA.") - causal_mask_mapping["sliding_window_attention"] = create_sliding_window_causal_mask(**mask_kwargs) - - hidden_states = inputs_embeds - position_embeddings = self.rotary_emb(hidden_states, position_ids) - swa_position_embeddings = self.swa_rotary_emb(hidden_states, position_ids) - - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - hidden_states = decoder_layer( - hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_embeddings=position_embeddings - if decoder_layer.attention_type == "full_attention" - else swa_position_embeddings, - position_ids=position_ids, - past_key_values=past_key_values, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = self.norm(hidden_states) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values if use_cache else None, - ) - - -class MiMoV2ForCausalLM(PreTrainedModel, GenerationMixin): - config_class = MiMoV2Config - model_class = MiMoV2Model - _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} - _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - _keys_to_ignore_on_load_unexpected = [ - r"model\.(swa_)?rotary_emb\.inv_freq", - r"model\.layers\.\d+\.self_attn\.rotary_emb\.inv_freq", - r"model\.layers\.\d+\.self_attn\.rotary_emb\.(cos_cached|sin_cached)", - r"model\.mtp\..*", - ] - - def __init__(self, config): - super().__init__(config) - self.model = self.model_class(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - @can_return_tuple - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: - outputs: BaseModelOutputWithPast = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -__all__ = [ - "MiMoV2Attention", - "MiMoV2DecoderLayer", - "MiMoV2ForCausalLM", - "MiMoV2MLP", - "MiMoV2MoE", - "MiMoV2MoEGate", - "MiMoV2Model", - "MiMoV2RMSNorm", - "MiMoV2RotaryEmbedding", -] +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# Portions copyright 2026 Xiaomi Corporation. +# Portions copyright 2026 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from copy import copy +from dataclasses import dataclass +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update + +from nemo_automodel.components.models.common import ( + BackendConfig, + initialize_linear_module, + initialize_rms_norm_module, +) +from nemo_automodel.components.models.common.hf_checkpointing_mixin import HFCheckpointingMixin +from nemo_automodel.components.models.common.utils import ( + _has_dtensor_params, + cast_model_to_dtype, + compute_lm_head_logits, +) +from nemo_automodel.components.models.mimo_v25.config import MiMoV2Config +from nemo_automodel.components.moe.config import MoEConfig +from nemo_automodel.components.moe.fsdp_mixin import MoEFSDPSyncMixin +from nemo_automodel.components.moe.layers import MLP, MoE +from nemo_automodel.shared.utils import dtype_from_str as get_dtype + + +def _convert_bool_4d_mask_to_additive(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + if mask.ndim != 4 or mask.dtype != torch.bool: + return mask + additive = torch.zeros(mask.shape, dtype=dtype, device=mask.device) + return additive.masked_fill(~mask, torch.finfo(dtype).min) + + +def _fallback_additive_mask( + batch_size: int, + seq_len: int, + dtype: torch.dtype, + device: torch.device, + attention_mask: torch.Tensor | None = None, + sliding_window: int | None = None, +) -> torch.Tensor: + min_val = torch.finfo(dtype).min + idx = torch.arange(seq_len, device=device) + masked = idx.unsqueeze(0) > idx.unsqueeze(1) + if sliding_window is not None and sliding_window > 0: + masked = masked | ((idx.unsqueeze(1) - idx.unsqueeze(0)) >= sliding_window) + additive = torch.zeros((seq_len, seq_len), dtype=dtype, device=device).masked_fill(masked, min_val) + additive = additive.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, seq_len, seq_len).contiguous() + if attention_mask is not None and attention_mask.ndim == 2: + pad = (1.0 - attention_mask.to(dtype=dtype, device=device)).unsqueeze(1).unsqueeze(2) * min_val + additive = additive + pad + return additive + + +def _ensure_additive_mask( + mask: torch.Tensor | None, + *, + batch_size: int, + seq_len: int, + dtype: torch.dtype, + device: torch.device, + attention_mask: torch.Tensor | None, + sliding_window: int | None, +) -> torch.Tensor: + if mask is None or not isinstance(mask, torch.Tensor): + return _fallback_additive_mask(batch_size, seq_len, dtype, device, attention_mask, sliding_window) + return _convert_bool_4d_mask_to_additive(mask, dtype) + + +def _derive_padding_mask(attention_mask: torch.Tensor) -> torch.Tensor: + if attention_mask.ndim == 2: + return attention_mask == 0 + if attention_mask.ndim == 4: + diagonal = torch.diagonal(attention_mask[:, 0], dim1=-2, dim2=-1) + return diagonal.logical_not() if attention_mask.dtype == torch.bool else diagonal != 0 + return attention_mask.bool().logical_not() + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + unsqueeze_dim: int = 1, +) -> tuple[torch.Tensor, torch.Tensor]: + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + sinks: Optional[torch.Tensor] = None, + **kwargs: Any, +) -> tuple[torch.Tensor, torch.Tensor]: + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]] + + if sinks is not None: + sink_bias = module.attention_sink_bias.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) + attn_weights = torch.cat([attn_weights, sink_bias.to(attn_weights.dtype)], dim=-1) + + attn_weights = attn_weights - attn_weights.max(dim=-1, keepdim=True).values + probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if sinks is not None: + probs = probs[..., :-1] + + probs = F.dropout(probs, p=dropout, training=module.training) + attn_output = torch.matmul(probs, value_states) + return attn_output.transpose(1, 2).contiguous(), probs + + +class MiMoV2Attention(nn.Module): + """MiMoV2 hybrid attention (full or sliding-window).""" + + def __init__( + self, + config: MiMoV2Config, + is_swa: bool, + layer_idx: int, + projection_layout: str, + backend: BackendConfig, + dtype: torch.dtype, + ): + super().__init__() + if projection_layout not in {"split", "fused_qkv"}: + raise ValueError(f"Unsupported MiMoV2 attention projection layout: {projection_layout}") + + self.layer_idx = layer_idx + self.is_swa = is_swa + self.is_causal = True + self.projection_layout = projection_layout + + default_head_dim = config.hidden_size // config.num_attention_heads + default_v_head_dim = getattr(config, "v_head_dim", default_head_dim) + + if is_swa: + self.head_dim = getattr(config, "swa_head_dim", getattr(config, "head_dim", default_head_dim)) + self.v_head_dim = getattr(config, "swa_v_head_dim", default_v_head_dim) + self.num_attention_heads = getattr(config, "swa_num_attention_heads", config.num_attention_heads) + self.num_key_value_heads = getattr(config, "swa_num_key_value_heads", config.num_key_value_heads) + else: + self.head_dim = getattr(config, "head_dim", default_head_dim) + self.v_head_dim = getattr(config, "v_head_dim", self.head_dim) + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + + self.rope_dim = int(self.head_dim * getattr(config, "partial_rotary_factor", 1.0)) + if self.rope_dim % 2 != 0: + raise ValueError( + f"MiMoV2 rotary dimension must be even, got {self.rope_dim} from " + f"head_dim={self.head_dim} and partial_rotary_factor={getattr(config, 'partial_rotary_factor', 1.0)}" + ) + + self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads + self.attention_dropout = getattr(config, "attention_dropout", 0.0) + self.scaling = self.head_dim**-0.5 + self.sliding_window = getattr(config, "sliding_window", None) if is_swa else None + + self.q_size = self.num_attention_heads * self.head_dim + self.k_size = self.num_key_value_heads * self.head_dim + self.v_size = self.num_key_value_heads * self.v_head_dim + self.o_hidden_size = self.num_attention_heads * self.v_head_dim + self.v_scale = getattr(config, "attention_value_scale", None) + + self.attention_sink_bias = ( + nn.Parameter(torch.empty(self.num_attention_heads), requires_grad=False) + if ( + (getattr(config, "add_full_attention_sink_bias", False) and not is_swa) + or (getattr(config, "add_swa_attention_sink_bias", False) and is_swa) + ) + else None + ) + + attention_bias = getattr(config, "attention_bias", False) + if self.projection_layout == "fused_qkv": + self.qkv_proj = initialize_linear_module( + backend.linear, + config.hidden_size, + self.q_size + self.k_size + self.v_size, + bias=attention_bias, + dtype=dtype, + ) + else: + self.q_proj = initialize_linear_module( + backend.linear, config.hidden_size, self.q_size, bias=attention_bias, dtype=dtype + ) + self.k_proj = initialize_linear_module( + backend.linear, config.hidden_size, self.k_size, bias=attention_bias, dtype=dtype + ) + self.v_proj = initialize_linear_module( + backend.linear, config.hidden_size, self.v_size, bias=attention_bias, dtype=dtype + ) + self.o_proj = initialize_linear_module( + backend.linear, self.o_hidden_size, config.hidden_size, bias=False, dtype=dtype + ) + + def _forward_attention( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + input_shape: torch.Size, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + ) -> torch.Tensor: + if self.v_scale is not None: + value_states = value_states * self.v_scale + + cos, sin = position_embeddings + query_rope, query_nope = query_states.split([self.rope_dim, self.head_dim - self.rope_dim], dim=-1) + key_rope, key_nope = key_states.split([self.rope_dim, self.head_dim - self.rope_dim], dim=-1) + query_rope, key_rope = apply_rotary_pos_emb(query_rope, key_rope, cos, sin) + query_states = torch.cat([query_rope, query_nope], dim=-1) + key_states = torch.cat([key_rope, key_nope], dim=-1) + + attn_output, _ = eager_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + sinks=self.attention_sink_bias, + ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + return self.o_proj(attn_output) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + **kwargs: Any, + ) -> torch.Tensor: + del kwargs + input_shape = hidden_states.shape[:-1] + + if self.projection_layout == "fused_qkv": + qkv = self.qkv_proj(hidden_states) + query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(*input_shape, self.num_attention_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(*input_shape, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(*input_shape, self.num_key_value_heads, self.v_head_dim).transpose(1, 2) + return self._forward_attention( + query_states, key_states, value_states, input_shape, position_embeddings, attention_mask + ) + + +class MiMoV2RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor + + def __init__(self, config: MiMoV2Config, is_swa: bool, device: Optional[torch.device] = None): + super().__init__() + self.rope_type = ( + config.rope_scaling.get("rope_type", config.rope_scaling.get("type", "default")) + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict) + else "default" + ) + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = copy(config) + self.config.rope_parameters = copy(getattr(config, "rope_parameters", None) or {}) + if is_swa: + self.config.rope_theta = getattr(config, "swa_rope_theta", config.rope_theta) + self.config.head_dim = getattr(config, "swa_head_dim", getattr(config, "head_dim", None)) + if self.config.rope_parameters: + self.config.rope_parameters["rope_theta"] = self.config.rope_theta + + self.rope_init_fn = ( + self.compute_default_rope_parameters if self.rope_type == "default" else ROPE_INIT_FUNCTIONS[self.rope_type] + ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @staticmethod + def compute_default_rope_parameters( + config: MiMoV2Config, + device: Optional[torch.device] = None, + seq_len: Optional[int] = None, + layer_type: Optional[str] = None, + ) -> tuple[torch.Tensor, float]: + config.standardize_rope_params() + rope_parameters = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters + base = rope_parameters["rope_theta"] + partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + if dim % 2 != 0: + raise ValueError( + f"MiMoV2 rotary dimension must be even, got {dim} from " + f"head_dim={head_dim} and partial_rotary_factor={partial_rotary_factor}" + ) + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, 1.0 + + @torch.no_grad() + @dynamic_rope_update + def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class MiMoV2DecoderLayer(nn.Module): + def __init__( + self, + config: MiMoV2Config, + layer_idx: int, + moe_config: MoEConfig, + backend: BackendConfig, + ): + super().__init__() + dtype = get_dtype(config.torch_dtype, torch.bfloat16) + is_swa_layer = config.hybrid_layer_pattern[layer_idx] == 1 + self.attention_type = "sliding_attention" if is_swa_layer else "full_attention" + self.self_attn = MiMoV2Attention( + config, + is_swa=is_swa_layer, + layer_idx=layer_idx, + projection_layout=config.attention_projection_layout, + backend=backend, + dtype=dtype, + ) + is_moe_layer = getattr(config, "n_routed_experts", None) is not None and config.moe_layer_freq[layer_idx] + self.mlp = ( + MoE(moe_config, backend) + if is_moe_layer + else MLP(config.hidden_size, config.intermediate_size, backend.linear, dtype=dtype) + ) + self.input_layernorm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.layernorm_epsilon, dtype=dtype + ) + self.post_attention_layernorm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.layernorm_epsilon, dtype=dtype + ) + + def forward( + self, + hidden_states: torch.Tensor, + *, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class MiMoV2Model(nn.Module): + def __init__(self, config: MiMoV2Config, moe_config: MoEConfig, backend: BackendConfig): + super().__init__() + self.config = config + self.backend = backend + + if backend.gate_precision is None: + backend.gate_precision = torch.float32 + + dtype = get_dtype(config.torch_dtype, torch.bfloat16) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, dtype=dtype) + self.layers = nn.ModuleDict( + {str(i): MiMoV2DecoderLayer(config, i, moe_config, backend) for i in range(config.num_hidden_layers)} + ) + self.norm = initialize_rms_norm_module( + backend.rms_norm, config.hidden_size, eps=config.layernorm_epsilon, dtype=dtype + ) + self.rotary_emb = MiMoV2RotaryEmbedding(config=config, is_swa=False) + self.swa_rotary_emb = MiMoV2RotaryEmbedding(config=config, is_swa=True) + self.has_sliding_layers = any(p == 1 for p in config.hybrid_layer_pattern) + + def _build_causal_mask_mapping( + self, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor | dict[str, torch.Tensor] | None, + position_ids: torch.Tensor, + cache_position: torch.Tensor, + ) -> dict[str, torch.Tensor]: + batch_size, seq_len = inputs_embeds.shape[:2] + dtype = inputs_embeds.dtype + device = inputs_embeds.device + + if isinstance(attention_mask, dict): + return { + "full_attention": _ensure_additive_mask( + attention_mask.get("full_attention"), + batch_size=batch_size, + seq_len=seq_len, + dtype=dtype, + device=device, + attention_mask=None, + sliding_window=None, + ), + "sliding_attention": _ensure_additive_mask( + attention_mask.get("sliding_attention", attention_mask.get("sliding_window_attention")), + batch_size=batch_size, + seq_len=seq_len, + dtype=dtype, + device=device, + attention_mask=None, + sliding_window=self.config.sliding_window, + ), + } + + mask_kwargs = dict( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=None, + position_ids=position_ids, + ) + pad_mask = attention_mask if isinstance(attention_mask, torch.Tensor) else None + return { + "full_attention": _ensure_additive_mask( + create_causal_mask(**mask_kwargs), + batch_size=batch_size, + seq_len=seq_len, + dtype=dtype, + device=device, + attention_mask=pad_mask, + sliding_window=None, + ), + "sliding_attention": _ensure_additive_mask( + create_sliding_window_causal_mask(**mask_kwargs) if self.has_sliding_layers else None, + batch_size=batch_size, + seq_len=seq_len, + dtype=dtype, + device=device, + attention_mask=pad_mask, + sliding_window=self.config.sliding_window, + ), + } + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + *, + inputs_embeds: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Any, + ) -> torch.Tensor: + del kwargs + if inputs_embeds is None: + if input_ids is None: + raise ValueError("input_ids or inputs_embeds must be provided") + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + if padding_mask is None and isinstance(attention_mask, torch.Tensor): + padding_mask = _derive_padding_mask(attention_mask) + + causal_mask_mapping = self._build_causal_mask_mapping( + inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + cache_position=cache_position, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + swa_position_embeddings = self.swa_rotary_emb(hidden_states, position_ids) + + for layer in self.layers.values(): + layer_position_embeddings = ( + swa_position_embeddings if layer.attention_type == "sliding_attention" else position_embeddings + ) + hidden_states = layer( + hidden_states, + attention_mask=causal_mask_mapping[layer.attention_type], + position_embeddings=layer_position_embeddings, + padding_mask=padding_mask, + ) + + return self.norm(hidden_states) + + @torch.no_grad() + def init_weights(self, buffer_device: Optional[torch.device] = None) -> None: + buffer_device = buffer_device or torch.device(f"cuda:{torch.cuda.current_device()}") + with buffer_device: + nn.init.normal_(self.embed_tokens.weight) + self.norm.reset_parameters() + for layer in self.layers.values(): + layer.init_weights(buffer_device) + + +class MiMoV2ForCausalLM(HFCheckpointingMixin, nn.Module, MoEFSDPSyncMixin): + """NeMo AutoModel causal LM wrapper for MiMo-V2.5-Pro.""" + + _keep_in_fp32_modules_strict = ["mlp.gate.e_score_correction_bias", "attention_sink_bias"] + _pp_keep_self_forward = True + _skip_init_weights_on_load = True + + @dataclass(frozen=True) + class ModelCapabilities: + """Declared parallelism capabilities for this model class.""" + + supports_tp: bool = False + supports_cp: bool = False + supports_pp: bool = True + supports_ep: bool = True + + @classmethod + def from_config( + cls, + config: MiMoV2Config, + moe_config: MoEConfig | None = None, + backend: BackendConfig | None = None, + **kwargs, + ) -> MiMoV2ForCausalLM: + return cls(config, moe_config, backend, **kwargs) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + *model_args, + **kwargs, + ) -> MiMoV2ForCausalLM: + config = MiMoV2Config.from_pretrained(pretrained_model_name_or_path) + return cls.from_config(config, *model_args, **kwargs) + + def __init__( + self, + config: MiMoV2Config, + moe_config: MoEConfig | None = None, + backend: BackendConfig | None = None, + **kwargs, + ): + super().__init__() + self.config = config + self.backend = backend or BackendConfig() + moe_overrides = kwargs.pop("moe_overrides", None) + + dtype = get_dtype(config.torch_dtype, torch.bfloat16) + moe_defaults = dict( + dim=config.hidden_size, + inter_dim=config.intermediate_size, + moe_inter_dim=config.moe_intermediate_size, + n_routed_experts=int(config.n_routed_experts or 0), + n_shared_experts=int(config.n_shared_experts or 0), + n_activated_experts=config.num_experts_per_tok, + n_expert_groups=config.n_group, + n_limited_groups=config.topk_group, + train_gate=True, + gate_bias_update_factor=0.0, + score_func="sigmoid_with_bias" if config.scoring_func == "sigmoid" else config.scoring_func, + route_scale=config.routed_scaling_factor if config.routed_scaling_factor is not None else 1.0, + aux_loss_coeff=0.0, + norm_topk_prob=config.norm_topk_prob, + router_bias=False, + expert_bias=False, + expert_activation="swiglu", + softmax_before_topk=False, + force_e_score_correction_bias=True, + dtype=dtype, + ) + if moe_overrides: + moe_defaults.update(moe_overrides) + resolved_moe_config = moe_config or MoEConfig(**moe_defaults) + + self.model = MiMoV2Model(config, resolved_moe_config, self.backend) + self.lm_head = initialize_linear_module( + self.backend.linear, + config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + ) + + if self.backend.enable_hf_state_dict_adapter: + from nemo_automodel.components.models.mimo_v25.state_dict_adapter import MiMoV2StateDictAdapter + + self.state_dict_adapter = MiMoV2StateDictAdapter( + self.config, + self.model.moe_config if hasattr(self.model, "moe_config") else resolved_moe_config, + self.backend, + dtype=dtype, + ) + + def get_input_embeddings(self) -> nn.Embedding: + return self.model.embed_tokens + + def set_input_embeddings(self, value: nn.Embedding) -> None: + self.model.embed_tokens = value + + def get_output_embeddings(self) -> nn.Linear: + return self.lm_head + + def set_output_embeddings(self, new_embeddings: nn.Linear) -> None: + self.lm_head = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + *, + inputs_embeds: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + output_hidden_states: Optional[bool] = None, + **kwargs: Any, + ) -> CausalLMOutputWithPast: + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else getattr(self.config, "output_hidden_states", False) + ) + hidden = self.model( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + attention_mask=attention_mask, + padding_mask=padding_mask, + **kwargs, + ) + return compute_lm_head_logits(self.lm_head, hidden, logits_to_keep, output_hidden_states=output_hidden_states) + + def customize_pipeline_stage_modules( + self, + module_names_per_stage: list[list[str]], + *, + layers_prefix: str, + text_model: Optional[nn.Module] = None, + ) -> list[list[str]]: + """Keep the SWA rotary embedding on every PP stage.""" + text_model = text_model or self.model + stage_modules = [list(modules) for modules in module_names_per_stage] + if getattr(text_model, "swa_rotary_emb", None) is not None: + fqn = f"{layers_prefix}swa_rotary_emb" + for modules in stage_modules: + if fqn not in modules: + modules.append(fqn) + return stage_modules + + @torch.no_grad() + def initialize_weights( + self, + buffer_device: Optional[torch.device] = None, + dtype: torch.dtype = torch.bfloat16, + ) -> None: + buffer_device = buffer_device or torch.device(f"cuda:{torch.cuda.current_device()}") + with buffer_device: + self.model.init_weights(buffer_device) + final_out_std = self.config.hidden_size**-0.5 + cutoff_factor = 3 + nn.init.trunc_normal_( + self.lm_head.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + if _has_dtensor_params(self): + return + cast_model_to_dtype(self, dtype) + + +ModelClass = MiMoV2ForCausalLM diff --git a/pyproject.toml b/pyproject.toml index c7c187d182..bcc73d5785 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -406,6 +406,8 @@ convention = "google" "nemo_automodel/components/models/llama_bidirectional/export_onnx.py" = ["D103"] "nemo_automodel/components/models/llava_onevision/rice_vit.py" = ["D101", "D103"] "nemo_automodel/components/models/llava_onevision/state_dict_adapter.py" = ["D101"] +"nemo_automodel/components/models/mimo_v25/model.py" = ["D101", "D103"] +"nemo_automodel/components/models/mimo_v25/state_dict_adapter.py" = ["D101"] "nemo_automodel/components/models/minimax_m2/model.py" = ["D101"] "nemo_automodel/components/models/minimax_m2/state_dict_adapter.py" = ["D103"] "nemo_automodel/components/models/mistral3/model.py" = ["D101", "D103"] From 2fd7b2b3e5835a8e21a574b770ba21537723b1e0 Mon Sep 17 00:00:00 2001 From: Simar Malhotra Date: Wed, 10 Jun 2026 16:02:04 -0400 Subject: [PATCH 4/8] feat(mimo_v25): add state dict adapter and registry entries for MiMo-V2.5-Pro Implements MiMoV2StateDictAdapter (MoeSplitExpertsStateDictMixin + StateDictAdapter) for XiaomiMiMo/MiMo-V2.5-Pro: - from_hf: FP8 dequantisation (weight + _scale_inv pairs) followed by per-expert weight merging via _from_hf_w_merged_experts; fused QKV keys (self_attn.qkv_proj.weight) pass through unchanged since HF and NeMo use the same name - to_hf: splits merged expert tensors back to per-expert layout and re-quantises eligible weights to float8_e4m3fn with scale_inv companions; NON_QUANTIZED_KEY_PATTERNS matches the V2-Flash precedent (norms, embeddings, lm_head, router gate, o_proj, attention_sink_bias) - Registers MiMoV2ForCausalLM in MODEL_ARCH_MAPPING and mimo_v2 in _CUSTOM_CONFIG_REGISTRATIONS so NeMoAutoModelForCausalLM can resolve the model from an HF config Smoke-tested end-to-end on CPU with a tiny MiMo-V2.5-Pro config (4 layers, fused QKV, mixed full/SWA attention, MoE layers): imports, registry lookup, model instantiation, adapter attachment, and a forward pass all pass cleanly. Signed-off-by: Simar Signed-off-by: Simar Malhotra --- nemo_automodel/_transformers/registry.py | 5 + .../models/mimo_v25/state_dict_adapter.py | 160 ++++++++++++++++++ 2 files changed, 165 insertions(+) create mode 100644 nemo_automodel/components/models/mimo_v25/state_dict_adapter.py diff --git a/nemo_automodel/_transformers/registry.py b/nemo_automodel/_transformers/registry.py index 765b81611a..ea129cdae3 100644 --- a/nemo_automodel/_transformers/registry.py +++ b/nemo_automodel/_transformers/registry.py @@ -140,6 +140,10 @@ "MiMoV2FlashForCausalLM", ("nemo_automodel.components.models.mimo_v2_flash.model", "MiMoV2FlashForCausalLM"), ), + ( + "MiMoV2ForCausalLM", + ("nemo_automodel.components.models.mimo_v25.model", "MiMoV2ForCausalLM"), + ), ( "Ministral3ForCausalLM", ("nemo_automodel.components.models.mistral3.model", "Ministral3ForCausalLM"), @@ -282,6 +286,7 @@ "kimi_vl": ("nemo_automodel.components.models.kimivl.model", "KimiVLConfig"), "llavaonevision1_5": ("nemo_automodel.components.models.llava_onevision.model", "Llavaonevision1_5Config"), "mimo_v2_flash": ("nemo_automodel.components.models.mimo_v2_flash.config", "MiMoV2FlashConfig"), + "mimo_v2": ("nemo_automodel.components.models.mimo_v25.config", "MiMoV2Config"), "minimax_m3_vl": ("nemo_automodel.components.models.minimax_m3_vl.config", "MiniMaxM3VLConfig"), "mistral4": ("nemo_automodel.components.models.mistral4.configuration", "Mistral4Config"), "step3p5v": ("nemo_automodel.components.models.step3p7.configuration_step3p7", "Step3p5VConfig"), diff --git a/nemo_automodel/components/models/mimo_v25/state_dict_adapter.py b/nemo_automodel/components/models/mimo_v25/state_dict_adapter.py new file mode 100644 index 0000000000..8493404e8a --- /dev/null +++ b/nemo_automodel/components/models/mimo_v25/state_dict_adapter.py @@ -0,0 +1,160 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +import re +from typing import Any + +import torch +from torch.distributed.device_mesh import DeviceMesh + +from nemo_automodel.components.checkpoint.state_dict_adapter import StateDictAdapter +from nemo_automodel.components.models.common import BackendConfig +from nemo_automodel.components.models.deepseek_v3.state_dict_adapter import ( + create_scale_inv_for_weight, + dequantize_from_fp8, +) +from nemo_automodel.components.moe.config import MoEConfig +from nemo_automodel.components.moe.state_dict_mixin import MoESplitExpertsStateDictMixin + +logger = logging.getLogger(__name__) + +NON_QUANTIZED_KEY_PATTERNS = [ + "input_layernorm.weight", + "post_attention_layernorm.weight", + "norm.weight", + "lm_head.weight", + "embed_tokens.weight", + "mlp.gate.weight", + "self_attn.o_proj.weight", + "attention_sink_bias", +] + + +def _should_quantize_key(key: str) -> bool: + if not key.endswith(".weight"): + return False + return not any(pattern in key for pattern in NON_QUANTIZED_KEY_PATTERNS) + + +class MiMoV2StateDictAdapter(MoESplitExpertsStateDictMixin, StateDictAdapter): + """Convert MiMo-V2.5-Pro HF checkpoints to Automodel's grouped MoE layout. + + HF stores routed experts as split per-expert projections: + ``mlp.experts.{E}.{gate,up,down}_proj.weight``. Automodel groups those + into ``gate_and_up_projs`` and ``down_projs`` so EP can shard experts + without materialising every expert on every rank. + + MiMo-V2.5-Pro uses fused QKV (``self_attn.qkv_proj.weight``), so attention + projection keys require no renaming — they pass through unchanged. + """ + + def __init__( + self, + config: Any, + moe_config: MoEConfig, + backend: BackendConfig, + dtype: torch.dtype = torch.bfloat16, + ): + self.config = config + self.moe_config = moe_config + self.backend = backend + self.dtype = dtype + self._uses_model_prefix = True + + def from_hf( + self, + hf_state_dict: dict[str, Any], + device_mesh: DeviceMesh | None = None, + **kwargs, + ) -> dict[str, Any]: + del kwargs + for key in hf_state_dict.keys(): + if ".mlp.experts." in key and key.endswith(".weight"): + self._uses_model_prefix = key.startswith("model.") + break + hf_state_dict = self._dequantize(hf_state_dict) + return self._from_hf_w_merged_experts(hf_state_dict, device_mesh) + + def to_hf( + self, + state_dict: dict[str, Any], + exclude_key_regex: str | None = None, + quantization: bool = False, + **kwargs, + ) -> dict[str, Any]: + """Convert Automodel state_dict to the HF MiMo-V2.5-Pro layout. + + Note: The ``quantization`` parameter is accepted for interface + compatibility but is **ignored**. MiMo-V2.5-Pro is distributed as an + FP8 HF checkpoint, so this adapter always emits FP8 weights plus + ``_scale_inv`` companions for keys that match ``_should_quantize_key``, + regardless of the caller's preference. + """ + hf_state_dict: dict[str, Any] = {} + for fqn, tensor in state_dict.items(): + for key, value in self.convert_single_tensor_to_hf( + fqn, + tensor, + exclude_key_regex=exclude_key_regex, + quantization=quantization, + **kwargs, + ): + hf_state_dict[key] = value + return hf_state_dict + + def convert_single_tensor_to_hf(self, fqn: str, tensor: Any, **kwargs) -> list[tuple[str, Any]]: + exclude_key_regex = kwargs.get("exclude_key_regex", None) + + expert_result = self._convert_single_merged_expert_to_hf_split_experts(fqn, tensor, **kwargs) + result = expert_result if expert_result is not None else [(fqn, tensor)] + + if exclude_key_regex: + result = [(key, value) for key, value in result if not re.match(exclude_key_regex, key)] + + quantized_result: list[tuple[str, Any]] = [] + for key, value in result: + if _should_quantize_key(key): + quantized = value.to(dtype=torch.float8_e4m3fn) + quantized_result.append((key, quantized)) + quantized_result.append((key + "_scale_inv", create_scale_inv_for_weight(quantized))) + else: + quantized_result.append((key, value)) + return quantized_result + + def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]: + scale_inv_keys: list[str] = [] + dequantized_count = 0 + for key in list(state_dict.keys()): + if not key.endswith(".weight"): + continue + scale_key = key + "_scale_inv" + if scale_key not in state_dict: + continue + state_dict[key] = dequantize_from_fp8( + state_dict[key], + state_dict[scale_key], + dtype=self.dtype, + name=key, + ) + scale_inv_keys.append(scale_key) + dequantized_count += 1 + + for key in scale_inv_keys: + state_dict.pop(key, None) + + logger.debug("[MiMo V2.5-Pro FP8 Dequant] Dequantized %s weights", dequantized_count) + return state_dict From c698570d411876f8d652d0133c7ad4a7eee7b90a Mon Sep 17 00:00:00 2001 From: Simar Malhotra Date: Wed, 10 Jun 2026 21:18:20 -0400 Subject: [PATCH 5/8] feat(mimo_v25): add hellaswag finetune recipe for MiMo-V2.5-Pro Adds examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml: - 16-node (128 H100) recipe using pp_size=4, ep_size=32 matching the declared ModelCapabilities (supports_pp=True, supports_ep=True) - dequantize_base_checkpoint=true to handle the FP8 base checkpoint via MiMoV2StateDictAdapter before training - Same hyperparameters (lr=1e-5, AdamW, max_steps=100) and dataset splits as the MiMo-V2-Flash hellaswag recipe Signed-off-by: Simar Signed-off-by: Simar Malhotra --- .../mimo_v25/mimo_v25_pro_hellaswag.yaml | 134 ++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml diff --git a/examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml b/examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml new file mode 100644 index 0000000000..ac521dcdbe --- /dev/null +++ b/examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml @@ -0,0 +1,134 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# 16 H100 nodes (128 GPUs): +# torchrun --nproc-per-node 8 -m nemo_automodel.cli.app examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml + +recipe: TrainFinetuneRecipeForNextTokenPrediction + +seed: 1234 + +step_scheduler: + global_batch_size: 256 + local_batch_size: 8 + ckpt_every_steps: 25 + val_every_steps: 500 + num_epochs: 1 + max_steps: 100 + +distributed: + strategy: fsdp2 + tp_size: 1 + cp_size: 1 + pp_size: 4 + ep_size: 32 + + sequence_parallel: false + activation_checkpointing: true + + pipeline: + pp_schedule: interleaved1f1b + pp_microbatch_size: 1 + layers_per_stage: 2 + round_virtual_stages_to_pp_multiple: down + scale_grads_in_schedule: false + patch_inner_model: false + patch_causal_lm_model: false + + moe: + reshard_after_forward: false + wrap_outer_model: false + +dist_env: + backend: nccl + timeout_minutes: 30 + +model: + _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_config + config: + _target_: nemo_automodel.components.models.mimo_v25.config.MiMoV2Config.from_pretrained + pretrained_model_name_or_path: XiaomiMiMo/MiMo-V2.5-Pro + name_or_path: XiaomiMiMo/MiMo-V2.5-Pro + trust_remote_code: false + load_base_model: true + backend: + _target_: nemo_automodel.components.models.common.BackendConfig + attn: sdpa + linear: torch + rms_norm: torch_fp32 + rope_fusion: false + dispatcher: deepep + experts: torch_mm + gate_precision: float32 + enable_hf_state_dict_adapter: true + enable_fsdp_optimizations: true + +checkpoint: + enabled: true + checkpoint_dir: checkpoints/mimo_v25_pro + model_save_format: safetensors + save_consolidated: false + dequantize_base_checkpoint: true + +loss_fn: + _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy + +dataset: + _target_: nemo_automodel.components.datasets.llm.hellaswag.HellaSwag + path_or_dataset: rowan/hellaswag + split: train + tokenizer: + _target_: transformers.AutoTokenizer.from_pretrained + pretrained_model_name_or_path: XiaomiMiMo/MiMo-V2.5-Pro + trust_remote_code: true + +packed_sequence: + packed_sequence_size: 0 + +dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + collate_fn: + _target_: nemo_automodel.components.datasets.utils.default_collater + pad_seq_len_divisible: 64 + shuffle: true + +validation_dataset: + _target_: nemo_automodel.components.datasets.llm.hellaswag.HellaSwag + path_or_dataset: rowan/hellaswag + split: validation + num_samples_limit: 64 + tokenizer: + _target_: transformers.AutoTokenizer.from_pretrained + pretrained_model_name_or_path: XiaomiMiMo/MiMo-V2.5-Pro + trust_remote_code: true + +validation_dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + collate_fn: + _target_: nemo_automodel.components.datasets.utils.default_collater + pad_seq_len_divisible: 64 + shuffle: false + drop_last: true + +optimizer: + _target_: torch.optim.AdamW + betas: [0.9, 0.95] + eps: 1e-8 + lr: 1e-5 + weight_decay: 0.1 + +wandb: + project: automodel-mimo-v25-pro + name: mimo_v25_pro_hellaswag_16n + mode: online From 85f851fe46b75c7e933fd52714b2846463ab5217 Mon Sep 17 00:00:00 2001 From: Simar Malhotra Date: Tue, 16 Jun 2026 08:38:23 -0400 Subject: [PATCH 6/8] docs(mimo_v25): add model coverage page for MiMo-V2.5-Pro Adds docs/model-coverage/llm/xiaomimimo/mimo-v2-5-pro.mdx so that test_every_registered_arch_has_model_coverage_doc passes for the newly registered MiMoV2ForCausalLM architecture. Signed-off-by: Simar Signed-off-by: Simar Malhotra --- .../llm/xiaomimimo/mimo-v2-5-pro.mdx | 92 +++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 docs/model-coverage/llm/xiaomimimo/mimo-v2-5-pro.mdx diff --git a/docs/model-coverage/llm/xiaomimimo/mimo-v2-5-pro.mdx b/docs/model-coverage/llm/xiaomimimo/mimo-v2-5-pro.mdx new file mode 100644 index 0000000000..7008ebc8eb --- /dev/null +++ b/docs/model-coverage/llm/xiaomimimo/mimo-v2-5-pro.mdx @@ -0,0 +1,92 @@ +--- +title: "MiMo-V2.5-Pro" +description: "" +--- +[MiMo-V2.5-Pro](https://huggingface.co/XiaomiMiMo/MiMo-V2.5-Pro) is Xiaomi's hybrid attention Mixture-of-Experts language model. It alternates full and sliding-window attention layers, uses a `sigmoid_with_bias` router with group-limited expert routing, and ships as an FP8 HF checkpoint. + + + +| | | +|---|---| +| **Task** | Text Generation (MoE, hybrid attention) | +| **Architecture** | `MiMoV2ForCausalLM` | +| **HF Org** | [XiaomiMiMo](https://huggingface.co/XiaomiMiMo) | + + + +## Available Models + +- **MiMo-V2.5-Pro**: hybrid full/sliding-window attention with FP8 weights. + +## Architecture + +- `MiMoV2ForCausalLM` +- Sliding-window attention via the `MiMoV2Attention(is_swa=True)` path. +- MoE blocks use `nemo_automodel.components.moe.layers.MoE` with `score_func="sigmoid_with_bias"` and `gate_precision=fp32`. +- FP8 round-trip in `MiMoV2StateDictAdapter` covers the bulk of attention/expert weights; layer norms, the gate, `lm_head`, and `embed_tokens` stay in bf16 per `NON_QUANTIZED_KEY_PATTERNS`. + +## Example HF Models + +| Model | HF ID | +|---|---| +| MiMo-V2.5-Pro | [`XiaomiMiMo/MiMo-V2.5-Pro`](https://huggingface.co/XiaomiMiMo/MiMo-V2.5-Pro) | + +## Example Recipes + +| Recipe | Description | +|---|---| +| [mimo_v25_pro_hellaswag.yaml](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml) | SFT — MiMo-V2.5-Pro on HellaSwag | + +## Try with NeMo AutoModel + +**1. Install** ([full instructions](/get-started/installation)): + +```bash +pip install nemo-automodel +``` + +**2. Clone the repo** to get the example recipes: + +```bash +git clone https://github.com/NVIDIA-NeMo/Automodel.git +cd Automodel +``` + +**3. Run the recipe** from inside the repo: + +```bash +automodel --nproc-per-node=8 examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml +``` + + +**1. Pull the container** and mount a checkpoint directory: + +```bash +docker run --gpus all -it --rm \ + --shm-size=8g \ + -v $(pwd)/checkpoints:/opt/Automodel/checkpoints \ + nvcr.io/nvidia/nemo-automodel:26.02.00 +``` + +**2. Navigate to the AutoModel directory**: + +```bash +cd /opt/Automodel +``` + +**3. Run the recipe**: + +```bash +automodel --nproc-per-node=8 examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml +``` + + +See the [Installation Guide](/get-started/installation) and [LLM Fine-Tuning Guide](/recipes-e2e-examples/sft-peft). + +## Fine-Tuning + +See the [LLM Fine-Tuning Guide](/recipes-e2e-examples/sft-peft). + +## Hugging Face Model Cards + +- [XiaomiMiMo/MiMo-V2.5-Pro](https://huggingface.co/XiaomiMiMo/MiMo-V2.5-Pro) From f74ad26f2f9ca05a43ba6055a167b610189f4b42 Mon Sep 17 00:00:00 2001 From: Simar Malhotra Date: Tue, 16 Jun 2026 08:41:33 -0400 Subject: [PATCH 7/8] docs(mimo_v25): add MiMo-V2.5-Pro to sidebar nav and fix info table - Adds MiMo-V2.5-Pro page entry to docs/fern/versions/nightly.yml so the page appears in the sidebar alongside MiMo-V2-Flash - Adds missing Parameters row to the Info table to match the MiMo-V2-Flash page format Signed-off-by: Simar Signed-off-by: Simar Malhotra --- docs/fern/versions/nightly.yml | 2 ++ docs/model-coverage/llm/xiaomimimo/mimo-v2-5-pro.mdx | 1 + 2 files changed, 3 insertions(+) diff --git a/docs/fern/versions/nightly.yml b/docs/fern/versions/nightly.yml index 4a24ccc28f..2eb7fec83b 100644 --- a/docs/fern/versions/nightly.yml +++ b/docs/fern/versions/nightly.yml @@ -169,6 +169,8 @@ navigation: path: ../../model-coverage/llm/tencent/hy-mt2.mdx - page: "MiMo-V2-Flash" path: ../../model-coverage/llm/xiaomimimo/mimo-v2-flash.mdx + - page: "MiMo-V2.5-Pro" + path: ../../model-coverage/llm/xiaomimimo/mimo-v2-5-pro.mdx - page: "Ling 2.0" path: ../../model-coverage/llm/inclusionai/ling-2.mdx - section: "Vision Language Models" diff --git a/docs/model-coverage/llm/xiaomimimo/mimo-v2-5-pro.mdx b/docs/model-coverage/llm/xiaomimimo/mimo-v2-5-pro.mdx index 7008ebc8eb..410c149dc5 100644 --- a/docs/model-coverage/llm/xiaomimimo/mimo-v2-5-pro.mdx +++ b/docs/model-coverage/llm/xiaomimimo/mimo-v2-5-pro.mdx @@ -10,6 +10,7 @@ description: "" |---|---| | **Task** | Text Generation (MoE, hybrid attention) | | **Architecture** | `MiMoV2ForCausalLM` | +| **Parameters** | Approx. several hundred B total / much smaller active | | **HF Org** | [XiaomiMiMo](https://huggingface.co/XiaomiMiMo) | From abe8433b5f2b35c67af154e7924018fe77642b67 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Date: Tue, 16 Jun 2026 15:46:12 -0700 Subject: [PATCH 8/8] Apply suggestions from code review Co-authored-by: jgerh <163925524+jgerh@users.noreply.github.com> --- docs/model-coverage/llm/xiaomimimo/mimo-v2-5-pro.mdx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/model-coverage/llm/xiaomimimo/mimo-v2-5-pro.mdx b/docs/model-coverage/llm/xiaomimimo/mimo-v2-5-pro.mdx index 410c149dc5..867ca25ca9 100644 --- a/docs/model-coverage/llm/xiaomimimo/mimo-v2-5-pro.mdx +++ b/docs/model-coverage/llm/xiaomimimo/mimo-v2-5-pro.mdx @@ -22,7 +22,7 @@ description: "" ## Architecture - `MiMoV2ForCausalLM` -- Sliding-window attention via the `MiMoV2Attention(is_swa=True)` path. +- Sliding-window attention using the `MiMoV2Attention(is_swa=True)` path. - MoE blocks use `nemo_automodel.components.moe.layers.MoE` with `score_func="sigmoid_with_bias"` and `gate_precision=fp32`. - FP8 round-trip in `MiMoV2StateDictAdapter` covers the bulk of attention/expert weights; layer norms, the gate, `lm_head`, and `embed_tokens` stay in bf16 per `NON_QUANTIZED_KEY_PATTERNS`. @@ -36,7 +36,7 @@ description: "" | Recipe | Description | |---|---| -| [mimo_v25_pro_hellaswag.yaml](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml) | SFT — MiMo-V2.5-Pro on HellaSwag | +| [mimo_v25_pro_hellaswag.yaml](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/llm_finetune/mimo_v25/mimo_v25_pro_hellaswag.yaml) | SFT: MiMo-V2.5-Pro on HellaSwag | ## Try with NeMo AutoModel