diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index c96de48db5..c8ff05e1e2 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -717,6 +717,8 @@ jobs: runner: ${{ needs.org-member-pre-flight.outputs.runner_prefix }} - script: L1_Functional_Tests_PPO runner: ${{ needs.org-member-pre-flight.outputs.runner_prefix }} + - script: L1_Functional_Tests_SingleController + runner: ${{ needs.org-member-pre-flight.outputs.runner_prefix }} - script: L1_Functional_Tests_Eval runner: ${{ needs.org-member-pre-flight.outputs.runner_prefix }} - script: L1_Functional_Tests_Other_1 @@ -786,6 +788,8 @@ jobs: runner: ${{ vars.GB200_RUNNER }} - script: L1_Functional_Tests_PPO runner: ${{ vars.GB200_RUNNER }} + - script: L1_Functional_Tests_SingleController + runner: ${{ vars.GB200_RUNNER }} - script: L1_Functional_Tests_Eval runner: ${{ vars.GB200_RUNNER }} - script: L1_Functional_Tests_Other_1 @@ -856,6 +860,8 @@ jobs: runner: ${{ needs.org-member-pre-flight.outputs.runner_prefix }} - script: L1_Functional_Tests_PPO runner: ${{ needs.org-member-pre-flight.outputs.runner_prefix }} + - script: L1_Functional_Tests_SingleController + runner: ${{ needs.org-member-pre-flight.outputs.runner_prefix }} - script: L1_Functional_Tests_Eval runner: ${{ needs.org-member-pre-flight.outputs.runner_prefix }} - script: L1_Functional_Tests_Other_1 diff --git a/examples/configs/grpo_math_1B_single_controller.yaml b/examples/configs/grpo_math_1B_single_controller.yaml new file mode 100644 index 0000000000..b20eb9dc9f --- /dev/null +++ b/examples/configs/grpo_math_1B_single_controller.yaml @@ -0,0 +1,358 @@ +# GRPO via SingleController (async-RL) — mirrors grpo_math_1B.yaml with +# data_plane.enabled=true and a top-level async_rl: section holding the +# SC-specific runtime knobs. +grpo: + num_prompts_per_step: 32 + num_generations_per_prompt: 16 + max_rollout_turns: 1 + max_num_epochs: 1 + max_num_steps: 1000000 + normalize_rewards: true + use_leave_one_out_baseline: true + val_period: 10 + val_at_start: false + val_at_end: false + overlong_filtering: false + advantage_clip_low: null + advantage_clip_high: null + max_val_samples: 256 + val_batch_size: 256 + seed: 42 + use_dynamic_sampling: false + dynamic_sampling_max_gen_batches: 10 + batch_multiplier: 1 + reward_shaping: + enabled: false + overlong_buffer_length: 128 + overlong_buffer_penalty: 1 + max_response_length: ${policy.max_total_sequence_length} + stop_properly_penalty_coef: null + + adv_estimator: + name: "grpo" + normalize_rewards: ${grpo.normalize_rewards} + use_leave_one_out_baseline: ${grpo.use_leave_one_out_baseline} + minus_baseline: true + reward_scaling: + enabled: false + source_min: 0.0 + source_max: 1.0 + target_min: 0.0 + target_max: 1.0 + seq_logprob_error_threshold: null + invalid_tool_call_advantage: null + malformed_thinking_advantage: null + + async_grpo: + enabled: true + max_trajectory_age_steps: 1 + in_flight_weight_updates: false + recompute_kv_cache_after_weight_updates: false + +loss_fn: + reference_policy_kl_penalty: 0.01 + reference_policy_kl_type: "k3" + kl_input_clamp_value: 20.0 + kl_output_clamp_value: 10.0 + ratio_clip_min: 0.2 + ratio_clip_max: 0.2 + ratio_clip_c: null + use_on_policy_kl_approximation: false + use_importance_sampling_correction: false + truncated_importance_sampling_type: null + truncated_importance_sampling_ratio: null + truncated_importance_sampling_ratio_min: null + sequence_level_importance_ratios: false + token_level_loss: true + force_on_policy_ratio: false + use_kl_in_reward: false + disable_ppo_ratio: false + positive_example_nll_weight: 0.0 + +checkpointing: + enabled: true + checkpoint_dir: "results/grpo-single-controller" + metric_name: "val:accuracy" + higher_is_better: true + keep_top_k: 3 + save_period: 10 + checkpoint_must_save_by: null + model_save_format: "safetensors" + save_consolidated: false + save_optimizer: true + +policy: + model_name: "Qwen/Qwen2.5-1.5B" + tokenizer: + name: ${policy.model_name} + chat_template_kwargs: null + hf_config_overrides: {} + train_global_batch_size: 512 + train_micro_batch_size: 4 + generation_batch_size: 32 + logprob_batch_size: ${policy.train_micro_batch_size} + max_total_sequence_length: 512 + precision: "bfloat16" + logprob_chunk_size: null + offload_optimizer_for_logprob: false + + dtensor_cfg: + _v2: true + enabled: false + cpu_offload: False + sequence_parallel: false + activation_checkpointing: false + tensor_parallel_size: 1 + context_parallel_size: 1 + custom_parallel_plan: null + automodel_kwargs: {} + lora_cfg: + enabled: False + target_modules: [] + exclude_modules: [] + match_all_linear: true + dim: 8 + alpha: 32 + dropout: 0.0 + dropout_position: "post" + lora_A_init: "xavier" + use_triton: true + + megatron_cfg: + enabled: true + force_reconvert_from_hf: False + empty_unused_memory_level: 1 + activation_checkpointing: false + recompute_granularity: "full" + recompute_modules: null + tensor_model_parallel_size: 1 + expert_tensor_parallel_size: 1 + expert_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + num_layers_in_first_pipeline_stage: null + num_layers_in_last_pipeline_stage: null + context_parallel_size: 1 + pipeline_dtype: ${policy.precision} + sequence_parallel: false + freeze_moe_router: true + moe_router_dtype: "fp64" + moe_router_load_balancing_type: "none" + moe_router_bias_update_rate: 0.0 + moe_permute_fusion: true + apply_rope_fusion: True + bias_activation_fusion: True + defer_fp32_logits: False + moe_per_layer_logging: False + moe_enable_deepep: false + moe_token_dispatcher_type: "alltoall" + moe_shared_expert_overlap: false + gradient_accumulation_fusion: false + use_fused_weighted_squared_relu: false + peft: + enabled: false + target_modules: [] + exclude_modules: [] + dim: 8 + alpha: 32 + dropout: 0.0 + dropout_position: "post" + lora_A_init_method: "xavier" + lora_B_init_method: "zero" + a2a_experimental: false + lora_dtype: None + optimizer: + optimizer: "adam" + lr: 5.0e-6 + min_lr: 5.0e-7 + weight_decay: 0.01 + bf16: true + fp16: false + params_dtype: "float32" + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_eps: 1e-8 + sgd_momentum: 0.9 + use_distributed_optimizer: true + use_precision_aware_optimizer: true + clip_grad: ${policy.max_grad_norm} + optimizer_cpu_offload: false + optimizer_offload_fraction: 0.0 + scheduler: + start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + weight_decay_incr_style: "constant" + lr_decay_style: "constant" + lr_decay_iters: 1000 + lr_warmup_iters: 13 + lr_warmup_init: 5.0e-7 + distributed_data_parallel_config: + grad_reduce_in_fp32: false + overlap_grad_reduce: true + overlap_param_gather: true + use_custom_fsdp: false + data_parallel_sharding_strategy: "optim_grads_params" + fp8_cfg: + enabled: false + fp8: "e4m3" + fp8_recipe: "blockwise" + fp8_param: false + env_vars: null + + draft: + enabled: false + model_name: null + loss_weight: 0.1 + num_layers: null + aux_layer_indices: null + + dynamic_batching: + enabled: False + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + sequence_length_round: 64 + + sequence_packing: + enabled: True + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + algorithm: "modified_first_fit_decreasing" + sequence_length_round: 64 + + make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} + max_grad_norm: 1.0 + + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-6 + weight_decay: 0.01 + betas: [0.9, 0.999] + eps: 1e-8 + + scheduler: + - name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + start_factor: 0.1 + end_factor: 1.0 + total_iters: 50 + - name: "torch.optim.lr_scheduler.ConstantLR" + kwargs: + factor: 1.0 + total_iters: 10000000000 + - milestones: [50] + + generation: + port_range_low: 11001 + port_range_high: 15000 + backend: "vllm" + max_new_tokens: ${policy.max_total_sequence_length} + temperature: 1.0 + top_p: 1.0 + top_k: null + stop_token_ids: null + stop_strings: null + mcore_generation_config: + buffer_size_gb: 10 + num_cuda_graphs: 4 + block_size_tokens: 256 + use_cuda_graphs_for_non_decode_steps: true + enable_chunked_prefill: true + unified_memory_level: 0 + max_tokens: 16384 + vllm_cfg: + async_engine: true + precision: ${policy.precision} + kv_cache_dtype: "auto" + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + expert_parallel_size: 1 + gpu_memory_utilization: 0.6 + max_model_len: ${policy.max_total_sequence_length} + enforce_eager: False + use_tqdm: true + use_deep_gemm: False + num_last_layers_in_bf16: 0 + num_first_layers_in_bf16: 0 + enable_vllm_metrics_logger: true + vllm_metrics_logger_interval: 0.5 + vllm_kwargs: {} + colocated: + enabled: false + resources: + gpus_per_node: 1 + num_nodes: 1 + +data: + max_input_seq_length: ${policy.max_total_sequence_length} + shuffle: true + num_workers: 1 + use_multiple_dataloader: false + train: + dataset_name: OpenMathInstruct-2 + split_validation_size: 0.05 + seed: ${grpo.seed} + validation: null + default: + prompt_file: "examples/prompts/cot.txt" + system_prompt_file: null + processor: "math_hf_data_processor" + env_name: "math" + +env: + math: + num_workers: 8 + math_verify_impl: "hf_math_verify" + +logger: + log_dir: "logs" + num_val_samples_to_print: 0 + wandb_enabled: false + tensorboard_enabled: false + mlflow_enabled: false + swanlab_enabled: false + monitor_gpus: true + wandb: + project: "grpo-dev" + name: "grpo-single-controller-dev" + swanlab: + project: "grpo-dev" + name: "grpo-single-controller-dev" + tensorboard: {} + mlflow: + experiment_name: "grpo-dev" + run_name: "grpo-single-controller-dev" + tracking_uri: "http://localhost:5000" + gpu_monitoring: + collection_interval: 10 + flush_interval: 10 + +# TransferQueue data plane — required by the SingleController path. +data_plane: + enabled: true + impl: transfer_queue + backend: "simple" + storage_capacity: 1000000 + num_storage_units: 2 + claim_meta_poll_interval_s: 0.5 + global_segment_size: 549755813888 + local_buffer_size: 68719476736 + +# SC-specific async-RL runtime knobs. +# One training step consumes grpo.num_prompts_per_step prompt groups. +async_rl: + batch_selection_strategy: "strict_on_policy" # or "staleness_window" + max_weight_staleness_versions: 1 + min_prompt_groups_per_batch: 2 + max_inflight_prompts: 8 + # When over_sampling=false this must equal + # grpo.num_prompts_per_step * (max_weight_staleness_versions + 1). + max_buffered_rollouts: 8 + # True : over-generates and wastes rollouts that age past the staleness window; + # False: enforces per-weight-version dispatch quota. + over_sampling: true + +cluster: + gpus_per_node: 2 + num_nodes: 1 + master_port_range_low: 25000 + master_port_range_high: 28000 diff --git a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-2n8g-async-1off-single-controller.yaml b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-2n8g-async-1off-single-controller.yaml new file mode 100644 index 0000000000..ee3f8b97ac --- /dev/null +++ b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-2n8g-async-1off-single-controller.yaml @@ -0,0 +1,33 @@ +# SingleController variant of grpo-llama3.1-8b-instruct-2n8g-async-1off.yaml. +# Same training topology and async-1off strategy; the SC path routes +# everything through the TransferQueue data plane + SingleControllerActor. +# staleness_window + over_sampling=false replicate the per-version dispatch +# quota of the original async-grpo path. +defaults: ./performance/grpo-llama3.1-8b-instruct-2n8g-async-1off.yaml + +logger: + log_dir: logs/grpo-llama3.1-8b-instruct-2n8g-async-1off-sc + wandb: + name: grpo-llama3.1-8b-instruct-2n8g-async-1off-sc + +checkpointing: + checkpoint_dir: results/grpo-llama3.1-8b-instruct-2n8g-async-1off-sc + +# TransferQueue data plane is mandatory for the SingleController path. +data_plane: + enabled: true + +# SC async-RL runtime knobs. +async_rl: + batch_selection_strategy: staleness_window + # Matches grpo.async_grpo.max_trajectory_age_steps=1. + max_weight_staleness_versions: 1 + # One training step consumes grpo.num_prompts_per_step (=64) prompt groups. + min_prompt_groups_per_batch: 64 + max_inflight_prompts: ${grpo.num_prompts_per_step} # match grpo-llama3.1-8b-instruct-2n8g-async-1off + # over_sampling=false requires + # max_buffered_rollouts == grpo.num_prompts_per_step * (max_weight_staleness_versions + 1) + # 64 * (1 + 1) = 128 + max_buffered_rollouts: 128 + over_sampling: false + force_in_order: true diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-megatron-single-controller.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-megatron-single-controller.yaml new file mode 100644 index 0000000000..949bd31f11 --- /dev/null +++ b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-megatron-single-controller.yaml @@ -0,0 +1,49 @@ +# SingleController + Megatron variant of grpo-qwen2.5-math-1.5b-instruct-1n8g. +# SC currently only supports non-colocated generation; the 8 GPUs on the node +# are split 4 (train) + 4 (inference). The SC path routes everything through +# the TransferQueue data plane + SingleControllerActor. strict_on_policy +# auto-sets max_weight_staleness_versions=0 and over_sampling=False, enforcing +# a strict per-version dispatch quota. +defaults: ./grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml + +logger: + log_dir: logs/grpo-qwen2.5-math-1.5b-instruct-1n8g-megatron-sc + wandb: + name: grpo-qwen2.5-math-1.5b-instruct-1n8g-megatron-sc + +checkpointing: + checkpoint_dir: results/grpo-qwen2.5-math-1.5b-instruct-1n8g-megatron-sc + +# TransferQueue data plane is mandatory for the SingleController path. +data_plane: + enabled: true + +# SC async-RL runtime knobs. +async_rl: + batch_selection_strategy: strict_on_policy + # One training step consumes grpo.num_prompts_per_step (=32) prompt groups. + min_prompt_groups_per_batch: 32 + max_inflight_prompts: 64 + # strict_on_policy auto-sets max_weight_staleness_versions=0 and + # over_sampling=False. over_sampling=false requires + # max_buffered_rollouts == grpo.num_prompts_per_step * (max_weight_staleness_versions + 1) + # 32 * (0 + 1) = 32 + max_buffered_rollouts: 32 + over_sampling: false + +policy: + dtensor_cfg: + enabled: false + megatron_cfg: + enabled: true + scheduler: + lr_warmup_iters: 50 + generation: + vllm_cfg: + async_engine: true + colocated: + enabled: false + resources: + # 4 GPUs for inference; remaining 4 GPUs on the node go to training. + gpus_per_node: 4 + num_nodes: 1 diff --git a/examples/configs/recipes/llm/performance/grpo-llama3.1-8b-instruct-2n8g.yaml b/examples/configs/recipes/llm/performance/grpo-llama3.1-8b-instruct-2n8g.yaml index d965796558..6ffd52c96a 100644 --- a/examples/configs/recipes/llm/performance/grpo-llama3.1-8b-instruct-2n8g.yaml +++ b/examples/configs/recipes/llm/performance/grpo-llama3.1-8b-instruct-2n8g.yaml @@ -11,6 +11,7 @@ policy: model_name: meta-llama/Llama-3.1-8B-Instruct tokenizer: name: meta-llama/Llama-3.1-8B-Instruct + train_global_batch_size: 2048 train_micro_batch_size: 1 logprob_batch_size: 2 max_total_sequence_length: 4096 diff --git a/examples/nemo_gym/run_distillation_nemo_gym.py b/examples/nemo_gym/run_distillation_nemo_gym.py index 7af1e4de6d..1fdcab6c68 100644 --- a/examples/nemo_gym/run_distillation_nemo_gym.py +++ b/examples/nemo_gym/run_distillation_nemo_gym.py @@ -32,9 +32,7 @@ from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.data.utils import setup_response_data from nemo_rl.distributed.virtual_cluster import init_ray -from nemo_rl.environments.nemo_gym import ( - setup_nemo_gym_config, -) +from nemo_rl.environments.nemo_gym import setup_nemo_gym_config from nemo_rl.models.generation import configure_generation_config from nemo_rl.utils.config import ( load_config, diff --git a/examples/nemo_gym/run_grpo_nemo_gym.py b/examples/nemo_gym/run_grpo_nemo_gym.py index 25a2c18493..2557348a34 100644 --- a/examples/nemo_gym/run_grpo_nemo_gym.py +++ b/examples/nemo_gym/run_grpo_nemo_gym.py @@ -40,9 +40,7 @@ from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.data.utils import setup_response_data from nemo_rl.distributed.virtual_cluster import init_ray -from nemo_rl.environments.nemo_gym import ( - setup_nemo_gym_config, -) +from nemo_rl.environments.nemo_gym import setup_nemo_gym_config from nemo_rl.experience.rollouts import run_async_nemo_gym_rollout from nemo_rl.models.generation import configure_generation_config from nemo_rl.utils.config import ( diff --git a/examples/run_grpo_single_controller.py b/examples/run_grpo_single_controller.py new file mode 100644 index 0000000000..bda2fe5499 --- /dev/null +++ b/examples/run_grpo_single_controller.py @@ -0,0 +1,136 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Async GRPO launcher driven by the SingleController actor. + +Builds the full SC bundle driver-side via setup_single_controller and hands it +to SingleControllerActor. Mirrors run_grpo.py for config loading so the same YAML +files apply. data_plane.enabled=true is mandatory. +""" + +import argparse +import os +import pprint +import sys + +import ray +from omegaconf import OmegaConf + +from nemo_rl.algorithms.single_controller import SingleControllerActor +from nemo_rl.algorithms.single_controller_utils import ( + MasterConfig, + setup_single_controller, +) +from nemo_rl.algorithms.utils import get_tokenizer +from nemo_rl.distributed.virtual_cluster import init_ray +from nemo_rl.environments.nemo_gym import setup_nemo_gym_config +from nemo_rl.models.generation import configure_generation_config +from nemo_rl.utils.config import ( + load_config, + parse_hydra_overrides, + register_omegaconf_resolvers, +) +from nemo_rl.utils.logger import get_next_experiment_dir + +# Drop examples/ from sys.path so examples/nemo_gym/ (no __init__.py) doesn't +# shadow the real nemo_gym package as a namespace package. +current_dir = os.path.dirname(os.path.abspath(__file__)) +while current_dir in sys.path: + sys.path.remove(current_dir) + + +def parse_args() -> tuple[argparse.Namespace, list[str]]: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Run async GRPO training via SingleController" + ) + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + args, overrides = parser.parse_known_args() + return args, overrides + + +def main() -> None: + """Main entry point.""" + register_omegaconf_resolvers() + args, overrides = parse_args() + + if not args.config: + args.config = os.path.join( + os.path.dirname(__file__), + "configs", + "grpo_math_1B_single_controller.yaml", + ) + + config = load_config(args.config) + print(f"Loaded configuration from: {args.config}") + + if overrides: + print(f"Overrides: {overrides}") + config = parse_hydra_overrides(config, overrides) + + config = OmegaConf.to_container(config, resolve=True) + config = MasterConfig(**config) + print("Applied CLI overrides") + + dp_cfg = config.data_plane + if not dp_cfg.get("enabled", False): + raise ValueError( + "run_grpo_single_controller requires data_plane.enabled=true. " + "Use examples/run_grpo.py for the legacy / sync paths." + ) + + print("Final config:") + pprint.pprint(config) + + config.logger["log_dir"] = get_next_experiment_dir(config.logger["log_dir"]) + print(f"📊 Using log directory: {config.logger['log_dir']}") + if config.checkpointing["enabled"]: + print( + f"📊 Using checkpoint directory: {config.checkpointing['checkpoint_dir']}" + ) + + init_ray() + + tokenizer = get_tokenizer(config.policy["tokenizer"]) + assert config.policy["generation"] is not None, ( + "A generation config is required for SC-driven async GRPO" + ) + has_refit_draft_weights = bool(config.policy["draft"]["enabled"]) + config.policy["generation"] = configure_generation_config( + config.policy["generation"], + tokenizer, + has_refit_draft_weights=has_refit_draft_weights, + ) + + # NeMo-Gym specific config setup. + if bool(config.env.get("should_use_nemo_gym")): + setup_nemo_gym_config(config, tokenizer) + + bundle = setup_single_controller(config, tokenizer) + + print("🚀 Launching SingleControllerActor") + sc = SingleControllerActor.remote(master_config=config, bundle=bundle) + result = ray.get(sc.run.remote()) + print(f"SC run complete: {result}") + + # Drain env actors before vLLM shutdown to avoid race-condition 500s on + # in-flight requests. + for handle in bundle.env_handles.values(): + ray.get(handle.shutdown.remote()) + + +if __name__ == "__main__": + main() diff --git a/nemo_rl/algorithms/async_utils/replay_buffer.py b/nemo_rl/algorithms/async_utils/replay_buffer.py index 22939bf72b..cab8685de7 100644 --- a/nemo_rl/algorithms/async_utils/replay_buffer.py +++ b/nemo_rl/algorithms/async_utils/replay_buffer.py @@ -12,13 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import threading as _threading +import uuid from collections import Counter +from collections.abc import Mapping from typing import Any, Iterable, Optional import ray from nemo_rl.algorithms.async_utils.interfaces import ReplayBufferProtocol +from nemo_rl.data_plane import KVBatchMeta +from nemo_rl.experience.interfaces import PromptGroupRecord +from nemo_rl.experience.payload import pack_payload, record_to_train_batch # Classes with @ray.remote can't be inherited from, so we split the implementation out. @@ -551,93 +557,164 @@ class ReplayBuffer(ReplayBufferImpl): pass -# WIP: DO NOT USE - This class is WIP and may be changed without notice, please DO NOT USE it. -# Will be replaced by TQReplayBuffer once TQ is ready. -@ray.remote # pragma: no cover -class ReplayBufferNew(ReplayBufferImpl): - """Staleness-window replay buffer. - - -- WIP: DO NOT USE -- - This class is WIP and may be changed without notice, please DO NOT USE it. - - Differences from ReplayBuffer: - - _evict(): Stale rows (trainer_version - weight_version > max_staleness) are evicted - at the start of every sample() call. - - sample(): selects trajectories in freshest-first order (default) or FIFO order, - controlled by the sample_freshest_first flag, from whatever remains in the buffer - after eviction. - - TODO: remove when cleaning up - - max_age_steps won't be used in ReplayBufferNew; - - self.target_weight_versions won't be used in ReplayBufferNew and will be removed - when cleaning up. target_weight_versions gates generation on specific trainer steps, - which causes generation pauses; ReplayBufferNew intentionally avoids this. - - add this class to nemo_rl/algorithms/async_utils/__init__.py +class TQReplayBuffer: + """Meta cache + TQ writer with reserve-then-commit slot semantics. + + meta_list, weight_list, ready_list, _group_ids are parallel; a slot stays + ready=False until commit fills it. """ def __init__( - self, max_size: int, max_staleness: int, sample_freshest_first: bool = True + self, + dp_client: Any, + partition_id: str, + *, + pad_value_dict: Mapping[str, int], ): - super().__init__(max_size) - if max_staleness < 0: - raise ValueError(f"max_staleness must be non-negative, got {max_staleness}") - self.max_staleness = max_staleness - # will move to StalenessSampler when we implement it - self.sample_freshest_first = sample_freshest_first + self._dp_client = dp_client + self._partition_id = partition_id + self._pad_value_dict = dict(pad_value_dict) + self.meta_list: list[Optional[KVBatchMeta]] = [] + self.start_weight_list: list[int] = [] + self.end_weight_list: list[int] = [] + # Per-slot target training step (set when force_in_order=True, else None). + self.target_step_list: list[Optional[int]] = [] + self.ready_list: list[bool] = [] + self._group_ids: list[str] = [] + + def reserve( + self, + *, + weight_version: int, + target_step: Optional[int] = None, + group_id: Optional[str] = None, + ) -> str: + """Append an unready slot tagged with weight_version. - def _evict(self, current_weight_version: int) -> None: - """Evict rows where trainer_version - weight_version > max_staleness. + Args: + weight_version: Weight version stamped on the slot. + target_step: Training step this slot targets; only consulted by StalenessSampler.force_in_order. + group_id: Per-group sample_id prefix; defaults to a fresh uuid4. - Must be called with self._lock held. + Returns: + group_id used by the matching commit. """ - min_valid = current_weight_version - self.max_staleness - stale = [i for i, v in enumerate(self.trajectory_versions) if v < min_valid] - self._remove_indices(stale) - - def sample( + if group_id is None: + group_id = str(uuid.uuid4()) + self.meta_list.append(None) + self.start_weight_list.append(weight_version) + self.end_weight_list.append(-1) + self.target_step_list.append(target_step) + self.ready_list.append(False) + self._group_ids.append(group_id) + return group_id + + async def commit( self, - num_prompt_groups: int, - current_weight_version: int, - max_age_steps: int, - ) -> Optional[dict[str, Any]]: - """Sample num_prompt_groups trajectories, freshest-first. + group_id: str, + record: PromptGroupRecord, + start_weight_version: int, + end_weight_version: int, + ) -> KVBatchMeta: + """Tensorize record, write N rows to TQ, and mark the slot ready. - Will evict stale rows before sampling, so we will get [current_weight_version - self.max_staleness, current_weight_version] valid trajectories. + Args: + group_id: group_id returned by the matching reserve call. + record: PromptGroupRecord to tensorize. + start_weight_version: Weight version stamped on the slot before rollout. + The same as the one from reserve, passed again to avoid race condition when lookup. + end_weight_version: Weight version stamped on the slot after rollout. Returns: - Dictionary with 'trajectories' and 'avg_trajectory_age' keys, or None. + KVBatchMeta for the committed group. + + Raises: + ValueError: group_id has no live slot (removed or never reserved). """ - with self._lock: - self._evict(current_weight_version) + train_batch = record_to_train_batch(record, pad_value_dict=self._pad_value_dict) + sample_ids, fields, tags = pack_payload( + train_batch, weight_version=start_weight_version, group_id=group_id + ) + await self._call_dp( + "put_samples", + sample_ids=sample_ids, + partition_id=self._partition_id, + fields=fields, + tags=tags, + ) - if not self.trajectories: - return None + # mirrors kv_first_write + lengths = train_batch["input_lengths"] + meta = KVBatchMeta( + partition_id=self._partition_id, + task_name="train", + sample_ids=list(sample_ids), + fields=list(fields.keys()), + sequence_lengths=[int(s) for s in lengths.tolist()], + tags=[dict(t) for t in tags], + ) - all_indices = range(len(self.trajectory_versions)) - if self.sample_freshest_first: - all_indices = sorted( - all_indices, - key=lambda i: self.trajectory_versions[i], - reverse=True, - ) + idx = self._group_ids.index(group_id) + self.meta_list[idx] = meta + self.end_weight_list[idx] = end_weight_version + self.ready_list[idx] = True + return meta - if len(all_indices) < num_prompt_groups: - print( - f"Insufficient trajectories: have {len(all_indices)}, " - f"need {num_prompt_groups}. Waiting." - ) - return None + async def remove(self, idxs: list[int], remove_in_dp: bool) -> int: + """Drop entries at the given indices and optionally clear them from DataPlane. - selected = all_indices[:num_prompt_groups] - sampled_weights = [self.trajectory_versions[i] for i in selected] - avg_trajectory_age = current_weight_version - sum(sampled_weights) / len( - sampled_weights + Args: + idxs: Entry indices to drop. Must be within [0, size). + remove_in_dp: If True, also clear the dropped rows from DataPlane. + + Returns: + Number of group entries removed from the buffer. + """ + if len(idxs) == 0: + return 0 + + drop_idxs = sorted(idxs, reverse=True) + if drop_idxs[0] >= len(self.meta_list): + raise IndexError( + f"TQReplayBuffer.remove: indices out of range: {drop_idxs[0]}; " + f"size={len(self.meta_list)}" ) - sampled_items = [self.trajectories[i] for i in selected] - self._remove_indices(selected) + dropped_sample_ids: list[str] = [] + for i in drop_idxs: + meta = self.meta_list[i] + if meta is not None: + dropped_sample_ids.extend(meta.sample_ids) + del self.meta_list[i] + del self.start_weight_list[i] + del self.end_weight_list[i] + del self.target_step_list[i] + del self.ready_list[i] + del self._group_ids[i] + + if remove_in_dp: + await self._call_dp( + "clear_samples", + sample_ids=dropped_sample_ids, + partition_id=self._partition_id, + ) - return { - "trajectories": sampled_items, - "avg_trajectory_age": avg_trajectory_age, - } + return len(drop_idxs) + + def size(self) -> int: + """Return the number of prompt-group entries currently held.""" + return len(self.meta_list) + + def __len__(self) -> int: + return len(self.meta_list) + + async def _call_dp(self, method_name: str, **kwargs: Any) -> Any: + """Call a DataPlaneClient method, awaiting Ray remotes if needed.""" + method = getattr(self._dp_client, method_name) + remote = getattr(method, "remote", None) + if remote is not None: + return await remote(**kwargs) + result = method(**kwargs) + if asyncio.iscoroutine(result): + return await result + return result diff --git a/nemo_rl/algorithms/async_utils/staleness_sampler.py b/nemo_rl/algorithms/async_utils/staleness_sampler.py new file mode 100644 index 0000000000..144dabdd3d --- /dev/null +++ b/nemo_rl/algorithms/async_utils/staleness_sampler.py @@ -0,0 +1,160 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Prompt-group selection over a TQReplayBuffer.""" + +from nemo_rl.algorithms.async_utils.replay_buffer import TQReplayBuffer +from nemo_rl.data_plane import KVBatchMeta + + +class StalenessSampler: + """Pick complete prompt groups from a TQReplayBuffer. + + Args: + buffer: Shared TQReplayBuffer holding the candidate slots. + max_staleness_versions: Max weight-version gap a sample may have from the trainer. + sample_freshest_first: Prefer smallest lag when picking from the in-window set. + require_order: Take only from the oldest in-window weight_version and wait for its batch to fill. + force_in_order: Match each slot's target_step against current_train_weight, ignoring the window; mirrors legacy async_grpo target_weight semantics. + """ + + def __init__( + self, + buffer: TQReplayBuffer, + max_staleness_versions: int, + sample_freshest_first: bool = False, + require_order: bool = False, + force_in_order: bool = False, + ) -> None: + if max_staleness_versions < 0: + raise ValueError( + f"max_staleness_versions must be non-negative, got " + f"{max_staleness_versions}" + ) + if require_order and sample_freshest_first: + raise ValueError( + "require_order and sample_freshest_first are mutually exclusive" + ) + self._buffer = buffer + self.max_staleness_versions = max_staleness_versions + self.sample_freshest_first = sample_freshest_first + self.require_order = require_order + self.force_in_order = force_in_order + + async def select( + self, + *, + current_train_weight: int, + min_prompt_groups: int, + max_prompt_groups: int, + ) -> tuple[KVBatchMeta | None, int]: + """Concat up to max_prompt_groups eligible groups and drop them from the buffer. + + Eligibility = ready and weight in + [current_train_weight - max_staleness_versions, current_train_weight]. + DataPlane rows survive the local drop; caller clears them at step boundary. + + Args: + current_train_weight: Current trainer weight version. + min_prompt_groups: Minimum groups required; returns (None, 0) below this. + max_prompt_groups: Cap on groups returned when the threshold is met. + + Returns: + meta: Concatenated KVBatchMeta, or None if not enough groups. + num_groups: Number of prompt groups in meta; 0 when meta is None. + """ + if min_prompt_groups < 1: + raise ValueError(f"min_prompt_groups must be >= 1, got {min_prompt_groups}") + if max_prompt_groups < min_prompt_groups: + raise ValueError( + f"max_prompt_groups ({max_prompt_groups}) must be >= " + f"min_prompt_groups ({min_prompt_groups})" + ) + + if self.force_in_order: + # target_step exact match; staleness window ignored. + valid_idxs = [ + i + for i, target in enumerate(self._buffer.target_step_list) + if target == current_train_weight and self._buffer.ready_list[i] + ] + else: + min_valid_version = max( + 0, current_train_weight - self.max_staleness_versions + ) + if self.require_order: + in_window = [ + weight + for weight in self._buffer.start_weight_list + if min_valid_version <= weight <= current_train_weight + ] + if not in_window: + return None, 0 + target_version = min(in_window) + valid_idxs = [ + i + for i, weight in enumerate(self._buffer.start_weight_list) + if weight == target_version and self._buffer.ready_list[i] + ] + else: + valid_idxs = [ + i + for i, weight in enumerate(self._buffer.start_weight_list) + if min_valid_version <= weight <= current_train_weight + and self._buffer.ready_list[i] + ] + + if len(valid_idxs) < min_prompt_groups: + return None, 0 + + if self.sample_freshest_first: + valid_idxs.sort( + key=lambda i: ( + current_train_weight - self._buffer.start_weight_list[i], + i, + ) + ) + + requested_groups = min(len(valid_idxs), max_prompt_groups) + selected_idxs = valid_idxs[:requested_groups] + selected_metas = [self._buffer.meta_list[i] for i in selected_idxs] + + await self._buffer.remove(selected_idxs, remove_in_dp=False) + + return ( + selected_metas[0].concat(*selected_metas[1:]), # type: ignore + len(selected_idxs), + ) + + async def evict(self, *, current_train_weight: int) -> int: + """Drop groups whose weight falls below the staleness window. + + Future entries (weight > current_train_weight) are left alone. + + Args: + current_train_weight: Current trainer weight version; groups with + weight < current_train_weight - max_staleness_versions are dropped. + + Returns: + Number of group entries removed from the buffer. + """ + min_valid_version = max(0, current_train_weight - self.max_staleness_versions) + stale_idxs = [ + i + for i, weight in enumerate(self._buffer.start_weight_list) + if weight < min_valid_version + ] + if not stale_idxs: + return 0 + return await self._buffer.remove(stale_idxs, remove_in_dp=True) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 5435df757a..0d6453a6e7 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -23,7 +23,6 @@ import ray import torch from pydantic import BaseModel -from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy from torchdata.stateful_dataloader import StatefulDataLoader from transformers import AutoProcessor from transformers.tokenization_utils_base import PreTrainedTokenizerBase @@ -74,12 +73,7 @@ prepare_segment_topology, ) from nemo_rl.environments.interfaces import EnvironmentInterface -from nemo_rl.environments.nemo_gym import ( - NemoGym, - NemoGymConfig, - get_nemo_gym_uv_cache_dir, - get_nemo_gym_venv_dir, -) +from nemo_rl.environments.nemo_gym import spinup_nemo_gym_actor from nemo_rl.experience.rollouts import ( EffortLevelsConfig, run_async_multi_turn_rollout, @@ -460,60 +454,11 @@ def init_train_dataloader(dataset, suffix: str = ""): # spinup can overlap with vLLM model loading via deferred model load. enable_nemo_gym = _should_use_nemo_gym(master_config) nemo_gym_actor = None - if enable_nemo_gym: - nemo_gym_num_nodes = env_configs.get("nemo_gym", {}).get("num_gpu_nodes", 0) - ray_runtime_ctx = ray.get_runtime_context() - ray_cur_node_id = ray_runtime_ctx.get_node_id() - else: - nemo_gym_num_nodes = 0 - ray_cur_node_id = None def _spinup_nemo_gym(base_urls, model_name): """Spin up the NeMo Gym actor against the given generation server URLs.""" t0 = time.perf_counter() - nemo_gym_py_exec = get_actor_python_env("nemo_rl.environments.nemo_gym.NemoGym") - if nemo_gym_py_exec.startswith("uv"): - nemo_gym_py_exec = create_local_venv_on_each_node( - nemo_gym_py_exec, "nemo_rl.environments.nemo_gym.NemoGym" - ) - nemo_gym_dict = env_configs["nemo_gym"] - # NeMo-RL-side detection knobs are top-level NemoGymConfig fields - # (where the detector reads them), not part of Gym's global config. - invalid_tool_call_patterns = nemo_gym_dict.pop( - "invalid_tool_call_patterns", None - ) - thinking_tags = nemo_gym_dict.pop("thinking_tags", None) - # Pass prebuilt cache + venv dirs through the global config so the gym reuses - # image-baked venvs instead of rebuilding them. - uv_cache_dir = get_nemo_gym_uv_cache_dir() - if uv_cache_dir is not None: - nemo_gym_dict.setdefault("uv_cache_dir", uv_cache_dir) - uv_venv_dir = get_nemo_gym_venv_dir() - if uv_venv_dir is not None: - nemo_gym_dict.setdefault("uv_venv_dir", uv_venv_dir) - nemo_gym_cfg = NemoGymConfig( - model_name=model_name, - base_urls=base_urls, - invalid_tool_call_patterns=invalid_tool_call_patterns, - thinking_tags=thinking_tags, - initial_global_config_dict=nemo_gym_dict, - ) - nemo_gym_opts = {} - if nemo_gym_num_nodes: - nemo_gym_opts["scheduling_strategy"] = NodeAffinitySchedulingStrategy( - node_id=ray_cur_node_id, - soft=True, - ) - nemo_gym_opts["runtime_env"] = { - "py_executable": nemo_gym_py_exec, - "env_vars": { - **os.environ, - "VIRTUAL_ENV": nemo_gym_py_exec, - "UV_PROJECT_ENVIRONMENT": nemo_gym_py_exec, - }, - } - actor = NemoGym.options(**nemo_gym_opts).remote(nemo_gym_cfg) - ray.get(actor._spinup.remote()) + actor = spinup_nemo_gym_actor(env_configs, base_urls, model_name) return actor, time.perf_counter() - t0 total_nodes = cluster_config["num_nodes"] diff --git a/nemo_rl/algorithms/single_controller.py b/nemo_rl/algorithms/single_controller.py new file mode 100644 index 0000000000..c9057fc295 --- /dev/null +++ b/nemo_rl/algorithms/single_controller.py @@ -0,0 +1,665 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SingleController: asyncio orchestrator for the RL training loop. + +CPU-only Ray actor that runs three concurrent pumps and coordinates the +other actors via lightweight RPCs. SC sends control signals and reads +metadata only — model tensors still move through DataPlane or NCCL. + +Data flow: + _rollout_pump → rollout_manager.generate_and_push(prompt) + → TQReplayBuffer.reserve claims a slot at dispatch time; + run_rollout runs; TQReplayBuffer.commit tensorizes the + record, writes N training rows to TQ, and marks ready. + _train_pump → sampler.evict → buffer.remove (stale groups, with DP clear). + → sampler.select → drops chosen groups from buffer, returns + KVBatchMeta of K groups (or None); meta is already trainable. + → _advantage_pump (get → compute → put). + → trainer.train_on_meta. + → dp_client.clear_samples (trained groups; buffer already dropped). + _sync_weights → drain _inflight_rollouts → WeightSynchronizer.sync_weights. +""" + +from __future__ import annotations + +import asyncio +import time +from typing import Any, Optional, Union + +import ray +import torch + +from nemo_rl.algorithms.async_utils.staleness_sampler import StalenessSampler +from nemo_rl.algorithms.single_controller_utils.config import ( + AdvantageConfig, + MasterConfig, + WeightSyncConfig, +) +from nemo_rl.algorithms.single_controller_utils.setup import SingleControllerBundle +from nemo_rl.algorithms.single_controller_utils.utils import ( + aggregate_step_metrics, + fields_for_put, + reduce_advantage_pump_metrics, + squeeze_trailing_unit_dim, + tensor_field, +) +from nemo_rl.data.interfaces import DatumSpec +from nemo_rl.data_plane import KVBatchMeta +from nemo_rl.models.generation.sglang.sglang_generation import SGLangGeneration +from nemo_rl.models.generation.vllm import VllmGeneration +from nemo_rl.models.policy.tq_policy import TQPolicy +from nemo_rl.utils.logger import Logger +from nemo_rl.utils.timer import Timer + +Generation = Union[VllmGeneration, SGLangGeneration] + + +@ray.remote(num_cpus=1, num_gpus=0) # pragma: no cover +class SingleControllerActor: + """CPU-only Ray actor that orchestrates the RL training loop. + + Owns three concurrent asyncio tasks: + - _rollout_pump: dispatches prompts via RolloutManager; reserve+commit in + TQReplayBuffer preserves dispatch order + - _train_pump: evicts stale groups, samples a batch, trains, drops it + - _sync_weights: drain gate + weight synchronization + + All other actors are passive — they expose methods and wait to be called. + """ + + def __init__( + self, + master_config: MasterConfig, + bundle: SingleControllerBundle, + ) -> None: + """Initialize the SingleController actor. + + Args: + master_config: SC MasterConfig. + bundle: Pre-built bundle from setup_single_controller. Tests can + construct a bundle by hand (or with fakes) to bypass the real factories. + """ + self._advantage_cfg = AdvantageConfig() + self._weight_sync_cfg = WeightSyncConfig() + self._partition_id: str = bundle.partition_id + self._diagnostics: bool = False + + self._master_config = master_config + self._async_cfg = master_config.async_rl + self._dp_client = bundle.dp_client + self._gen: Generation = bundle.gen_handle + self._trainer: TQPolicy = bundle.trainer_handle + self._dataloader = bundle.dataloader + self._weight_synchronizer = bundle.weight_synchronizer + self._advantage_estimator = bundle.advantage_estimator + self._loss_fn = bundle.loss_fn + self._buffer = bundle.tq_buffer + self._rollout_manager = bundle.rollout_manager + # Rebind so writer and sampler share one buffer instance even + # when Ray deserializes rollout_manager and tq_buffer separately. + self._rollout_manager._tq_buffer = self._buffer + + # Built here, not on the driver: Logger backends (wandb/tb/...) hold + # _thread.lock that Ray can't cloudpickle into the actor. + self._logger = Logger(master_config.logger) # type: ignore + self._timer = Timer() + + # Pin clusters so RayVirtualCluster.__del__ doesn't remove the PGs. + self._train_cluster = bundle.train_cluster + self._inference_cluster = bundle.inference_cluster + + num_prompts_per_step = self._master_config.grpo["num_prompts_per_step"] + if num_prompts_per_step < self._async_cfg.min_prompt_groups_per_batch: + raise ValueError( + f"grpo.num_prompts_per_step ({num_prompts_per_step}) " + f"must be >= async_rl.min_prompt_groups_per_batch " + f"({self._async_cfg.min_prompt_groups_per_batch})" + ) + + if self._async_cfg.batch_selection_strategy == "strict_on_policy": + self._async_cfg.max_weight_staleness_versions = 0 + self._async_cfg.over_sampling = False + print( + "Using strict_on_policy, auto setting max_weight_staleness_versions to 0 and over_sampling to False.", + flush=True, + ) + + if not self._async_cfg.over_sampling: + expected_buffer = num_prompts_per_step * ( + self._async_cfg.max_weight_staleness_versions + 1 + ) + if self._async_cfg.max_buffered_rollouts != expected_buffer: + raise ValueError( + f"over_sampling=False requires max_buffered_rollouts " + f"({self._async_cfg.max_buffered_rollouts}) == " + f"num_prompts_per_step * (max_weight_staleness_versions + 1) " + f"({expected_buffer})" + ) + + if self._async_cfg.force_in_order and self._async_cfg.over_sampling: + raise ValueError( + "force_in_order=True requires over_sampling=False so that each " + "dispatched batch corresponds to exactly one target training step." + ) + + # SC split path does one optimizer.step per RL step. + # TODO: support multi-mini-step (legacy train() does gbs-sized + # mini-steps with shared prev_logprobs). + rl_step_samples = ( + num_prompts_per_step + * self._master_config.grpo["num_generations_per_prompt"] + ) + train_gbs = self._master_config.policy["train_global_batch_size"] + if rl_step_samples != train_gbs: + raise ValueError( + f"num_prompts_per_step * num_generations_per_prompt " + f"({rl_step_samples}) must equal policy.train_global_batch_size " + f"({train_gbs}) so that one RL step maps to exactly one " + f"optimizer.step. Multi-mini-step inside a single RL step is " + f"not supported on the SC split path." + ) + + self._sampler = StalenessSampler( + self._buffer, + max_staleness_versions=self._async_cfg.max_weight_staleness_versions, + require_order=not self._async_cfg.over_sampling, + force_in_order=self._async_cfg.force_in_order, + ) + + # ── asyncio state ────────────────────────────────────────────────── + # Gate: cleared during _sync_weights, set when generation may proceed + self._rollout_permitted: asyncio.Event = asyncio.Event() + self._rollout_permitted.set() + + # Count of in-flight generate_and_push calls + self._inflight_rollouts: int = 0 + + # Cancellation handles for in-flight rollout dispatches. + self._dispatched_rollouts: set[asyncio.Task[None]] = set() + + # over_sampling=False batch gate: farthest trainer_version covered by + # already-dispatched batches. + self._max_rollout_version: int = -1 + + # Backpressure valve: max unconsumed rollout groups allowed in DataPlane. + # Acquired before each rollout dispatch; released when the buffer + # drops a group (sampler.evict or post-train buffer.remove). + self._buffer_capacity: asyncio.Semaphore = asyncio.Semaphore( + self._async_cfg.max_buffered_rollouts + ) + + self._trainer_version: int = 0 + self._train_steps: int = 0 + self._step_consumed_sample_ids: list[str] = [] + self._step_log_dict: dict[str, list] = { + "rewards": [], + "masked_advantages": [], + "sequence_lengths": [], + } + + print( + f"SingleControllerActor: " + f"staleness_cap={self._async_cfg.max_weight_staleness_versions} " + f"buffer={self._async_cfg.max_buffered_rollouts} " + f"inflight={self._async_cfg.max_inflight_prompts} " + f"over_sampling={self._async_cfg.over_sampling} " + f"transport={self._weight_sync_cfg.transport}", + flush=True, + ) + + # ── public API ───────────────────────────────────────────────────────── + + async def run(self) -> dict[str, Any]: + """Main entry point. Runs until max_train_steps is reached.""" + # Synchronize weights before starting the pumps + await self._sync_weights() + + # Start the rollout and train pumps + rollout_task = asyncio.create_task(self._rollout_pump()) + train_task = asyncio.create_task(self._train_pump()) + + # Wait until the train pump is done + await train_task + self._logger.finish() + + # Cancel the rollout pump and any in-flight dispatches so we exit immediately. + rollout_task.cancel() + try: + await rollout_task + except asyncio.CancelledError: + pass + inflight = list(self._dispatched_rollouts) + for task in inflight: + task.cancel() + if inflight: + await asyncio.gather(*inflight, return_exceptions=True) + + return { + "train_steps": self._train_steps, + "trainer_version": self._trainer_version, + } + + async def ping(self) -> dict[str, Any]: + """Liveness check — returns immediately if event loop is running.""" + return { + "alive": True, + "trainer_version": self._trainer_version, + "train_steps": self._train_steps, + "inflight_rollouts": self._inflight_rollouts, + "rollout_permitted": self._rollout_permitted.is_set(), + } + + # ── internal helpers ─────────────────────────────────────────────────── + + async def _ray_get(self, obj_ref: Any) -> Any: + """Await a Ray ObjectRef without blocking the asyncio event loop.""" + return await obj_ref + + async def _call_dp(self, method_name: str, **kwargs) -> Any: + """Call a DataPlaneClient method or a Ray actor exposing that method.""" + method = getattr(self._dp_client, method_name) + remote = getattr(method, "remote", None) + if remote is not None: + return await self._ray_get(remote(**kwargs)) + result = method(**kwargs) + if asyncio.iscoroutine(result): + return await result + return result + + # ── the three pumps + advantage helper ──────────────────────────────── + + async def _rollout_pump(self) -> None: + """Continuously dispatch rollout tasks until cancellation. + + Per batch (over_sampling=False): + 0. Wait while _max_rollout_version >= trainer_version + max_staleness, + then claim the next step by incrementing _max_rollout_version. + + Per prompt: + 1. Acquire _buffer_capacity slot (backpressure) + 2. Acquire sem (cap concurrent in-flight rollouts) + 3. Wait for _rollout_permitted (paused during weight sync) + 4. Call rollout_manager.generate_and_push(prompt) — local async + RolloutManager reserves a slot, runs the rollout, then commits the + group via TQReplayBuffer (→ dp_client.put_samples + mark ready) + 5. Decrement _inflight_rollouts + """ + sem = asyncio.Semaphore(self._async_cfg.max_inflight_prompts) + over_sampling = self._async_cfg.over_sampling + max_staleness = self._async_cfg.max_weight_staleness_versions + force_in_order = self._async_cfg.force_in_order + print("rollout_pump: starting", flush=True) + + async def _dispatch_one_prompt( + prompt: DatumSpec, target_step: Optional[int] + ) -> None: + self._inflight_rollouts += 1 + try: + await self._rollout_manager.generate_and_push( + prompt, target_step=target_step + ) + if self._diagnostics: + content = "" + for i in range(len(prompt["message_log"])): + if prompt["message_log"][i]["role"] == "user": + content = prompt["message_log"][i]["content"] + break + print(f" rollout done for prompt='{content[:20]}...'", flush=True) + finally: + self._inflight_rollouts -= 1 + sem.release() + + max_epochs = self._master_config.grpo["max_num_epochs"] + epoch = 0 + while max_epochs is None or epoch < max_epochs: + for prompt_batch in self._dataloader: + # over_sampling=False: batch-level gate on max_rollout_version. + if not over_sampling: + while ( + self._max_rollout_version + >= self._trainer_version + max_staleness + ): + await asyncio.sleep(0.005) + self._max_rollout_version += 1 + + # target_step = batch dispatch index when force_in_order is on. + target_step = self._max_rollout_version if force_in_order else None + + for prompt_idx in range(prompt_batch.size): + prompt: DatumSpec = { # type: ignore + k: v[prompt_idx] for k, v in prompt_batch.items() + } + + # check if buffer is full + await self._buffer_capacity.acquire() + # check if inflight rollouts is full + await sem.acquire() + # wait for rollout to be permitted + await self._rollout_permitted.wait() + + # dispatch rollout + task = asyncio.create_task( + _dispatch_one_prompt(prompt, target_step) + ) + self._dispatched_rollouts.add(task) + task.add_done_callback(self._dispatched_rollouts.discard) + epoch += 1 + + print(f"rollout_pump: completed {epoch} epoch(s)", flush=True) + + async def _train_pump(self) -> None: + """Drain stale groups, sample, train, drop. + + Per step: + 1. sampler.evict drops stale groups from the buffer and clears their TQ rows. + 2. sampler.select returns K prompt groups (or None) and drops them from the + buffer; DP rows survive so the trainer can read them. Already trainable — + buffer wrote training-shaped rows at rollout time. + 3. _advantage_pump(train_meta). + 4. trainer.train_microbatch_from_meta + finish_train_step. + 5. dp_client.clear_samples on consumed sample_ids; release _buffer_capacity + per dropped group, then sync. + """ + adv_cfg = self._advantage_cfg + grpo_cfg = self._master_config.grpo + + # TODO: fix the compute_prev_logprobs and compute_reference_logprobs logic + compute_prev_logprobs = adv_cfg.policy_logprobs_field is not None + compute_reference_logprobs = adv_cfg.reference_logprobs_field is not None + + while self._train_steps < grpo_cfg["max_num_steps"]: + step_id = f"sc-step-{self._train_steps:06d}" + groups_dispatched = 0 + step_open = False + + with self._timer.time("total_step_time"): + while groups_dispatched < grpo_cfg["num_prompts_per_step"]: + await asyncio.sleep(0) + + # evict stale groups + evicted = await self._sampler.evict( + current_train_weight=self._trainer_version, + ) + if evicted: + print(f" evicted {evicted} stale prompt group(s)", flush=True) + for _ in range(evicted): + self._buffer_capacity.release() + + # Get train data + with self._timer.time("exposed_generation"): + max_prompt_groups = ( + grpo_cfg["num_prompts_per_step"] - groups_dispatched + ) + min_prompt_groups = min( + self._async_cfg.min_prompt_groups_per_batch, + max_prompt_groups, + ) + train_meta, num_groups = await self._sampler.select( + current_train_weight=self._trainer_version, + min_prompt_groups=min_prompt_groups, + max_prompt_groups=max_prompt_groups, + ) + + if train_meta is None: + await asyncio.sleep(0.05) + continue + + for _ in range(num_groups): + self._buffer_capacity.release() + + # Compute prev_logprobs / ref_logprobs + with self._timer.time("logprob_inference_prep"): + await asyncio.to_thread(self._trainer.prepare_for_lp_inference) + with self._timer.time("policy_and_reference_logprobs"): + if compute_prev_logprobs: + await asyncio.to_thread( + self._trainer.get_logprobs_from_meta, train_meta + ) + if compute_reference_logprobs: + await asyncio.to_thread( + self._trainer.get_reference_policy_logprobs_from_meta, + train_meta, + ) + + with self._timer.time("advantage_calculation"): + train_meta = await self._advantage_pump(train_meta) + + # Train + with self._timer.time("training_prep"): + await asyncio.to_thread(self._trainer.prepare_for_training) + with self._timer.time("policy_training"): + if not step_open: + await asyncio.to_thread( + self._trainer.begin_train_step, + step_id, + loss_fn=self._loss_fn, + ) + step_open = True + await asyncio.to_thread( + self._trainer.train_microbatch_from_meta, + step_id, + train_meta, + ) + + groups_dispatched += num_groups + self._step_consumed_sample_ids.extend(train_meta.sample_ids) + if train_meta.sequence_lengths: + self._step_log_dict["sequence_lengths"].extend( + int(s) for s in train_meta.sequence_lengths + ) + + if not step_open: + print( + "train_pump: rollout exhausted before any group ready", + flush=True, + ) + break + + with self._timer.time("policy_training"): + result = await asyncio.to_thread( + self._trainer.finish_train_step, step_id + ) + consumed_ids = list(self._step_consumed_sample_ids) + self._step_consumed_sample_ids = [] + await self._call_dp( + "clear_samples", + sample_ids=list(consumed_ids), + partition_id=self._partition_id, + ) + + step_metrics = aggregate_step_metrics(result) + step_metrics.update( + reduce_advantage_pump_metrics(**self._step_log_dict) + ) + self._step_log_dict = {k: [] for k in self._step_log_dict} + + self._trainer_version += 1 + self._train_steps += 1 + with self._timer.time("weight_sync"): + await self._sync_weights() + + timing_metrics: dict[str, float] = self._timer.get_timing_metrics( + reduction_op="sum" + ) # type: ignore + + total_time = timing_metrics.get("total_step_time", 0.0) + cluster_cfg = self._master_config.cluster + total_num_gpus = cluster_cfg["num_nodes"] * cluster_cfg["gpus_per_node"] + if total_time > 0 and "global_valid_toks" in step_metrics: + timing_metrics["valid_tokens_per_sec_per_gpu"] = ( + step_metrics["global_valid_toks"] / total_time / total_num_gpus + ) + + print("\n⏱️ Timing:") + print(f" • Total step time: {total_time:.2f}s") + for k, v in sorted( + timing_metrics.items(), key=lambda item: item[1], reverse=True + ): + if k == "total_step_time": + continue + percent = (v / total_time * 100) if total_time > 0 else 0.0 + print(f" • {k}: {v:.2f}s ({percent:.1f}%)") + + # TODO: checkpointing (save_period/top-k metric_name, + # policy.save_checkpoint, dataloader state, TQReplayBuffer state). + # TODO: per-step train_data jsonl dump, vllm metrics logger, + # histogram log, rollout_metrics, seq_logprob_error_metrics, + # pretty-print "Training Results" block, print_performance_metrics. + print(f"step_metrics={step_metrics}", flush=True) + self._logger.log_metrics( + step_metrics, step=self._train_steps, prefix="train" + ) + self._logger.log_metrics( + timing_metrics, step=self._train_steps, prefix="timing/train" + ) + self._timer.reset() + + # min sample version refers to the version each consumed sample was + # generated with; lag = current trainer version - oldest sample version. + min_sample_version = min(t["weight_version"] for t in train_meta.tags) # type: ignore + lag = self._trainer_version - min_sample_version + print( + f"train step {self._train_steps}/{grpo_cfg['max_num_steps']} " + f"trainer_v={self._trainer_version} " + f"lag={lag} " + f"batch_size={len(consumed_ids)}", + flush=True, + ) + + async def _sync_weights(self) -> None: + """Drain in-flight rollouts then synchronize weights. + + SC owns the drain gate (when to sync); WeightSynchronizer owns how. + + Flow: + 1. _rollout_permitted.clear() — no new dispatches + 2. drain _inflight_rollouts → 0 (5ms poll) + 3. weight_synchronizer.sync_weights(trainer_version) + 4. _rollout_permitted.set() — resume + """ + self._rollout_permitted.clear() + + # TODO: currently sync_weights is not implemented, comment out for now + # # Drain: wait for all in-flight rollouts to complete before NCCL + # # Critical: if GenWorker has queued calls when NCCL init is dispatched, + # # the init sits behind them — trainer blocks in rendezvous → deadlock + # drain_start = time.monotonic() + # while self._inflight_rollouts > 0: + # await asyncio.sleep(0.005) + + # drain_elapsed = time.monotonic() - drain_start + # print( + # f" _sync_weights: drained in {drain_elapsed:.3f}s, " + # f"syncing weights v{self._trainer_version}", + # flush=True, + # ) + + t0 = time.monotonic() + await asyncio.to_thread(self._weight_synchronizer.sync_weights) + elapsed = time.monotonic() - t0 + + print(f" _sync_weights: sync done in {elapsed:.3f}s", flush=True) + self._rollout_manager.set_weight_version(self._trainer_version) + self._rollout_permitted.set() + + async def _advantage_pump(self, meta: KVBatchMeta) -> KVBatchMeta: + """Fetch advantage inputs, compute advantages, and write them back. + + SC owns the prompt-group-scoped advantage stage because the selected + ``KVBatchMeta`` still contains complete prompt groups before trainer + DP sharding. Tensor payloads still move through DataPlane: SC fetches + only the configured advantage input columns and writes the computed + ``advantages`` column back under the same ``sample_ids``. + """ + if self._advantage_estimator is None: + return meta + adv_cfg = self._advantage_cfg + + data = await self._call_dp( + "get_samples", + sample_ids=meta.sample_ids, + partition_id=meta.partition_id, + select_fields=self._advantage_input_fields(), + ) + + prompt_ids = tensor_field(data, adv_cfg.prompt_ids_field) + rewards = squeeze_trailing_unit_dim( + tensor_field(data, adv_cfg.reward_field) + ).float() + self._step_log_dict["rewards"].append(rewards.detach()) + token_mask = tensor_field(data, adv_cfg.token_mask_field).float() + sample_mask = squeeze_trailing_unit_dim( + tensor_field(data, adv_cfg.sample_mask_field) + ).float() + mask = token_mask * sample_mask.unsqueeze(-1) + + repeated_batch: dict[str, torch.Tensor] = { + "total_reward": rewards, + } + for field_name in adv_cfg.repeated_batch_fields: + repeated_batch[field_name] = squeeze_trailing_unit_dim( + tensor_field(data, field_name) + ) + + kwargs: dict[str, torch.Tensor] = {} + if adv_cfg.policy_logprobs_field is not None: + kwargs["logprobs_policy"] = tensor_field( + data, + adv_cfg.policy_logprobs_field, + ) + if adv_cfg.reference_logprobs_field is not None: + kwargs["logprobs_reference"] = tensor_field( + data, + adv_cfg.reference_logprobs_field, + ) + + advantages = self._advantage_estimator.compute_advantage( + prompt_ids=prompt_ids, + rewards=rewards, + mask=mask, + repeated_batch=repeated_batch, + **kwargs, + ) + self._step_log_dict["masked_advantages"].append( + torch.masked_select(advantages.detach(), mask.bool()) + ) + + await self._call_dp( + "put_samples", + sample_ids=meta.sample_ids, + partition_id=meta.partition_id, + fields=fields_for_put( + meta, + {adv_cfg.output_field: advantages}, + ), + ) + return meta.with_fields([adv_cfg.output_field]) + + # ── utility helpers ──────────────────────────────────────────────────── + + def _advantage_input_fields(self) -> list[str]: + adv_cfg = self._advantage_cfg + fields = [ + adv_cfg.prompt_ids_field, + adv_cfg.reward_field, + adv_cfg.token_mask_field, + adv_cfg.sample_mask_field, + *adv_cfg.repeated_batch_fields, + ] + if adv_cfg.policy_logprobs_field is not None: + fields.append(adv_cfg.policy_logprobs_field) + if adv_cfg.reference_logprobs_field is not None: + fields.append(adv_cfg.reference_logprobs_field) + return list(dict.fromkeys(fields)) diff --git a/nemo_rl/algorithms/single_controller_utils/__init__.py b/nemo_rl/algorithms/single_controller_utils/__init__.py new file mode 100644 index 0000000000..a2c7001f92 --- /dev/null +++ b/nemo_rl/algorithms/single_controller_utils/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SingleController utilities: config schema + setup factories.""" + +from nemo_rl.algorithms.single_controller_utils.config import ( + AdvantageConfig, + AsyncRLConfig, + MasterConfig, + WeightSyncConfig, +) +from nemo_rl.algorithms.single_controller_utils.setup import ( + SingleControllerBundle, + setup_single_controller, +) + +__all__ = [ + "AdvantageConfig", + "AsyncRLConfig", + "MasterConfig", + "SingleControllerBundle", + "WeightSyncConfig", + "setup_single_controller", +] diff --git a/nemo_rl/algorithms/single_controller_utils/config.py b/nemo_rl/algorithms/single_controller_utils/config.py new file mode 100644 index 0000000000..9ffa122adc --- /dev/null +++ b/nemo_rl/algorithms/single_controller_utils/config.py @@ -0,0 +1,85 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Literal, Optional + +from pydantic import BaseModel, Field + +from nemo_rl.algorithms.grpo import GRPOConfig, GRPOLoggerConfig +from nemo_rl.algorithms.loss import ClippedPGLossConfig +from nemo_rl.data import DataConfig +from nemo_rl.data_plane.interfaces import DataPlaneConfig +from nemo_rl.distributed.virtual_cluster import ClusterConfig +from nemo_rl.models.policy import PolicyConfig +from nemo_rl.utils.checkpoint import CheckpointingConfig + +# ── User-facing SingleController configs ──────────────────────────────────── + + +class AsyncRLConfig(BaseModel, extra="allow"): + batch_selection_strategy: Literal[ + "strict_on_policy", + "staleness_window", + ] = "strict_on_policy" + # Sampler / on-policy enforcement. + max_weight_staleness_versions: int = 1 + min_prompt_groups_per_batch: int = 2 + # Pump concurrency caps. + max_inflight_prompts: int = 8 + max_buffered_rollouts: int = 8 + # True : over-generates and wastes rollouts that age past the staleness window; + # False: enforces per-weight-version dispatch quota. + over_sampling: bool = True + # Tag rollouts with their dispatch-time target step and require an exact + # match at sample time (legacy target_weight semantics). Requires + # over_sampling=False. + force_in_order: bool = False + + +class MasterConfig(BaseModel, extra="allow"): + policy: PolicyConfig + loss_fn: ClippedPGLossConfig + env: dict[str, Any] = Field(default_factory=dict) + data: DataConfig + grpo: GRPOConfig + logger: GRPOLoggerConfig + cluster: ClusterConfig + checkpointing: CheckpointingConfig + data_plane: DataPlaneConfig + async_rl: AsyncRLConfig = Field(default_factory=AsyncRLConfig) + + +# ── Internal SingleController configs ──────────────────────────────────── + + +@dataclass +class AdvantageConfig: + output_field: str = "advantages" + prompt_ids_field: str = "prompt_ids_for_adv" + reward_field: str = "total_reward" + token_mask_field: str = "token_mask" + sample_mask_field: str = "sample_mask" + repeated_batch_fields: list[str] = field(default_factory=list) + policy_logprobs_field: Optional[str] = "prev_logprobs" + reference_logprobs_field: Optional[str] = "reference_policy_logprobs" + + +@dataclass +class WeightSyncConfig: + transport: str = "stub" + nccl_addr: str = "127.0.0.1" + nccl_port: Optional[int] = None diff --git a/nemo_rl/algorithms/single_controller_utils/setup.py b/nemo_rl/algorithms/single_controller_utils/setup.py new file mode 100644 index 0000000000..abae6882c6 --- /dev/null +++ b/nemo_rl/algorithms/single_controller_utils/setup.py @@ -0,0 +1,402 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Driver-side factory for the SingleController (async-RL) training path. + +setup builds the full SingleControllerBundle on the driver and the caller passes it to +SingleControllerActor.remote. Everything lives on the driver because driver-side +TQPolicy owns the worker group directly — running this inside another Ray actor nests +runtime_envs and breaks Ray's resource resolution (see the PR #2692 follow-up). +""" + +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import Any, Optional + +from torchdata.stateful_dataloader import StatefulDataLoader +from transformers import AutoProcessor +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + +from nemo_rl.algorithms.async_utils.replay_buffer import TQReplayBuffer +from nemo_rl.algorithms.grpo import _create_advantage_estimator, _should_use_nemo_gym +from nemo_rl.algorithms.loss import ClippedPGLossFn +from nemo_rl.algorithms.loss.interfaces import LossFunction +from nemo_rl.algorithms.single_controller_utils.config import MasterConfig +from nemo_rl.algorithms.utils import set_seed +from nemo_rl.data.collate_fn import rl_collate_fn +from nemo_rl.data.utils import setup_response_data +from nemo_rl.data_plane import build_data_plane_client +from nemo_rl.distributed.virtual_cluster import RayVirtualCluster +from nemo_rl.environments.interfaces import EnvironmentInterface +from nemo_rl.environments.nemo_gym import spinup_nemo_gym_actor +from nemo_rl.experience.rollout_manager import RolloutManager +from nemo_rl.models.generation.sglang.sglang_generation import SGLangGeneration +from nemo_rl.models.generation.vllm import VllmGeneration +from nemo_rl.models.policy.tq_policy import TQPolicy +from nemo_rl.weight_sync import WeightSynchronizer, create_weight_synchronizer + + +@dataclass +class SingleControllerBundle: + """All inputs SingleControllerActor needs, built driver-side by setup_single_controller(). + + Passed as a single arg to SingleControllerActor.remote so the actor's __init__ does + no construction work — every heavy object is cloudpickled in. + """ + + gen_handle: Any + trainer_handle: Any # driver-side TQPolicy + env_handles: dict[str, EnvironmentInterface] + train_cluster: RayVirtualCluster + inference_cluster: RayVirtualCluster + dp_client: Any + dataloader: StatefulDataLoader + weight_synchronizer: WeightSynchronizer + advantage_estimator: Any + loss_fn: LossFunction + rollout_manager: RolloutManager + tq_buffer: TQReplayBuffer + partition_id: str + + +def _build_clusters( + master_config: MasterConfig, +) -> tuple[RayVirtualCluster, RayVirtualCluster]: + """Allocate train + inference clusters; one shared cluster when colocated.""" + cluster_config = master_config.cluster + generation_config = master_config.policy["generation"] + colocated = generation_config["colocated"]["enabled"] + backend = generation_config["backend"] + num_nodes = cluster_config["num_nodes"] + gpus_per_node = cluster_config["gpus_per_node"] + port_range_low = cluster_config.get("master_port_range_low") + port_range_high = cluster_config.get("master_port_range_high") + + if colocated: + # Policy + generation share GPUs — one cluster. + cluster = RayVirtualCluster( + name="sc_policy_cluster", + bundle_ct_per_node_list=[gpus_per_node] * num_nodes, + use_gpus=True, + num_gpus_per_node=gpus_per_node, + max_colocated_worker_groups=1 if backend == "megatron" else 2, + port_range_low=port_range_low, + port_range_high=port_range_high, + ) + return cluster, cluster + + # Non-colocated: split node into train + inference clusters. + assert backend != "megatron", ( + "Non-colocated inference is not supported for Megatron generation backends." + ) + inference_resources = generation_config["colocated"]["resources"] + inference_gpus_per_node = inference_resources["gpus_per_node"] + inference_nodes = inference_resources["num_nodes"] or 1 + if num_nodes == 1: + train_gpus_per_node = gpus_per_node - inference_gpus_per_node + train_nodes = 1 + assert train_gpus_per_node > 0, ( + f"Not enough GPUs for training: {gpus_per_node} - {inference_gpus_per_node} = {train_gpus_per_node}" + ) + else: + train_gpus_per_node = gpus_per_node + train_nodes = num_nodes - inference_nodes + assert train_nodes > 0, ( + f"train_nodes must be > 0: {num_nodes} - {inference_nodes} = {train_nodes}" + ) + + train_cluster = RayVirtualCluster( + name="sc_train_cluster", + bundle_ct_per_node_list=[train_gpus_per_node] * train_nodes, + use_gpus=True, + num_gpus_per_node=train_gpus_per_node, + max_colocated_worker_groups=1, + port_range_low=port_range_low, + port_range_high=port_range_high, + ) + inference_cluster = RayVirtualCluster( + name="sc_inference_cluster", + bundle_ct_per_node_list=[inference_gpus_per_node] * inference_nodes, + use_gpus=True, + num_gpus_per_node=inference_gpus_per_node, + max_colocated_worker_groups=1, + port_range_low=port_range_low, + port_range_high=port_range_high, + ) + return train_cluster, inference_cluster + + +def _build_generation( + inference_cluster: RayVirtualCluster, + master_config: MasterConfig, +): + """Spin up the generation backend (vLLM or SGLang).""" + generation_config = master_config.policy["generation"] + generation_config["model_name"] = master_config.policy["model_name"] + backend = generation_config["backend"] + if backend == "vllm": + generation_config["vllm_kwargs"]["hf_overrides"] = master_config.policy.get( + "hf_config_overrides", {} + ) + gen = VllmGeneration(cluster=inference_cluster, config=generation_config) + elif backend == "sglang": + generation_config["sglang_cfg"].setdefault( + "model_path", master_config.policy["model_name"] + ) + gen = SGLangGeneration(cluster=inference_cluster, config=generation_config) + else: + raise ValueError( + f"single_controller_utils.setup only supports vllm or sglang generation; got {backend!r}" + ) + gen.finish_generation() + return gen + + +def _build_trainer( + train_cluster: RayVirtualCluster, + master_config: MasterConfig, + tokenizer, + processor, +): + """Build the TQ-mediated trainer (driver-side TQPolicy). + + Driver-side on purpose: instantiating TQPolicy inside another Ray + actor nests runtime_envs and triggers Ray's + get_accelerator_ids_for_accelerator_resource IndexError. Keep this + here until PolicyTrainerActor (PR #2692) lands. + """ + loss_config = master_config.loss_fn + init_reference_model = loss_config.reference_policy_kl_penalty > 0 + return TQPolicy( + cluster=train_cluster, + config=master_config.policy, + tokenizer=tokenizer, + processor=processor, + weights_path=None, + optimizer_path=None, + init_optimizer=True, + init_reference_model=init_reference_model, + dp_cfg=master_config.data_plane, + ) + + +def _generation_max_seq_len(generation_config) -> int: + """Return the per-backend max sequence length. + + vllm uses vllm_cfg.max_model_len; sglang uses sglang_cfg.context_length; + megatron generation has no dedicated field and routes max_new_tokens + through as max_sequence_length on the inference worker. + """ + backend = generation_config["backend"] + if backend == "vllm": + return generation_config["vllm_cfg"]["max_model_len"] + if backend == "sglang": + return generation_config["sglang_cfg"]["context_length"] + if backend == "megatron": + return generation_config["max_new_tokens"] + raise ValueError(f"Unknown generation backend: {backend!r}") + + +def _clamp_max_num_steps( + master_config: MasterConfig, dataloader: StatefulDataLoader +) -> None: + """Clamp grpo.max_num_steps to max_num_epochs * len(dataloader).""" + grpo_config = master_config.grpo + max_num_epochs = grpo_config.get("max_num_epochs") + if max_num_epochs is None: + return + grpo_config["max_num_steps"] = min( + grpo_config["max_num_steps"], + max_num_epochs * len(dataloader), + ) + + +def _maybe_inject_megatron_train_iters( + master_config: MasterConfig, dataloader: StatefulDataLoader +) -> None: + """Set megatron_cfg.train_iters; must run before _build_trainer.""" + policy_config = master_config.policy + if not policy_config.get("megatron_cfg", {}).get("enabled", False): + return + grpo_config = master_config.grpo + policy_config["megatron_cfg"]["train_iters"] = min( + grpo_config["max_num_steps"], + grpo_config["max_num_epochs"] * len(dataloader), + ) + + +def setup_single_controller( + master_config: MasterConfig, + tokenizer: PreTrainedTokenizerBase, + *, + processor: Optional[AutoProcessor] = None, + partition_id: str = "rollout_data", +) -> SingleControllerBundle: + """Build the full SC bundle driver-side. + + Args: + master_config: SC MasterConfig. + tokenizer: Tokenizer used by the policy. + processor: Optional AutoProcessor for VLM paths. + partition_id: TQ partition the rollout writer + sampler share. + + Returns: + SingleControllerBundle ready to be passed to SingleControllerActor. + """ + dp_cfg = master_config.data_plane + if dp_cfg is None or not dp_cfg.get("enabled", False): + raise ValueError( + "single_controller_utils.setup requires " + "master_config.data_plane.enabled=True. The async-RL " + "SingleController path is built on the TransferQueue data plane." + ) + + data_config = master_config.data + grpo_config = master_config.grpo + generation_config = master_config.policy["generation"] + assert generation_config is not None, ( + "single_controller_utils.setup requires policy.generation in master_config" + ) + + if data_config["use_multiple_dataloader"]: + raise NotImplementedError( + "single_controller_utils does not support " + "data.use_multiple_dataloader=True yet." + ) + + set_seed(grpo_config["seed"]) + + # ========================== + # Setup Dataset & Environments + # ========================== + # TODO: add validate dataset wiring. + use_nemo_gym = _should_use_nemo_gym(master_config) + if use_nemo_gym: + # NeMo-Gym creates the env actor outside setup_response_data; we wire + # it in after generation is up (it needs the OpenAI server URLs). + dataset, _val_dataset = setup_response_data( + tokenizer, data_config, env_configs=None + ) + env_handles: dict[str, EnvironmentInterface] = {} + else: + dataset, _val_dataset, env_handles, _val_env_handles = setup_response_data( + tokenizer, data_config, env_configs=master_config.env + ) + dataloader = StatefulDataLoader( + dataset, + batch_size=grpo_config["num_prompts_per_step"], + shuffle=data_config["shuffle"], + collate_fn=rl_collate_fn, + drop_last=True, + num_workers=data_config["num_workers"], + ) + + _clamp_max_num_steps(master_config, dataloader) + _maybe_inject_megatron_train_iters(master_config, dataloader) + + # ========================== + # Setup Clusters & Workers + # ========================== + train_cluster, inference_cluster = _build_clusters(master_config) + colocated = generation_config["colocated"]["enabled"] + if colocated: + # Colocated: vLLM prefers a clean GPU at load time, so generation + # comes up before the policy. + generation = _build_generation(inference_cluster, master_config) + policy = _build_trainer(train_cluster, master_config, tokenizer, processor) + else: + # Non-colocated: generation + policy run on disjoint GPUs, so + # bring them up in parallel. + with ThreadPoolExecutor(max_workers=2) as executor: + gen_future = executor.submit( + _build_generation, inference_cluster, master_config + ) + policy_future = executor.submit( + _build_trainer, train_cluster, master_config, tokenizer, processor + ) + generation = gen_future.result() + policy = policy_future.result() + + # ========================== + # NeMo-Gym actor (after generation is up so OpenAI URLs are available) + # ========================== + if use_nemo_gym: + if generation_config["backend"] != "vllm": + raise NotImplementedError( + "SC NeMo-Gym integration currently supports the vllm backend " + f"only; got {generation_config['backend']!r}" + ) + env_handles["nemo_gym"] = spinup_nemo_gym_actor( + env_configs=master_config.env, + base_urls=generation.dp_openai_server_base_urls, + model_name=generation_config["model_name"], + ) + + # ========================== + # Setup Data Plane Client & Weight Sync + # ========================== + # Connect-only DP client; TQPolicy already bootstrapped the controller. + dp_client = build_data_plane_client(dp_cfg, bootstrap=False) + + backend = generation_config["backend"] + weight_synchronizer = create_weight_synchronizer( + policy=policy, + generation=generation, + generation_backend=backend, + colocated=colocated, + train_cluster=train_cluster, + inference_cluster=inference_cluster, + ) + weight_synchronizer.init_communicator() + + # ========================== + # Setup Algorithm + Rollout Wiring + # ========================== + advantage_estimator = _create_advantage_estimator(master_config) + loss_fn: LossFunction = ClippedPGLossFn(master_config.loss_fn) + + pad_id = int(getattr(tokenizer, "pad_token_id", 0) or 0) + tq_buffer = TQReplayBuffer( + dp_client, + partition_id=partition_id, + pad_value_dict={"token_ids": pad_id, "input_ids": pad_id}, + ) + rollout_manager = RolloutManager( + tokenizer=tokenizer, + env_handles=env_handles, + num_generations_per_prompt=grpo_config["num_generations_per_prompt"], + max_seq_len=_generation_max_seq_len(generation_config), + max_rollout_turns=grpo_config.get("max_rollout_turns"), + policy_generation=generation, + generation_config=generation_config, + use_nemo_gym=use_nemo_gym, + tq_buffer=tq_buffer, + ) + + return SingleControllerBundle( + gen_handle=generation, + trainer_handle=policy, + env_handles=env_handles, + train_cluster=train_cluster, + inference_cluster=inference_cluster, + dp_client=dp_client, + dataloader=dataloader, + weight_synchronizer=weight_synchronizer, + advantage_estimator=advantage_estimator, + loss_fn=loss_fn, + rollout_manager=rollout_manager, + tq_buffer=tq_buffer, + partition_id=partition_id, + ) diff --git a/nemo_rl/algorithms/single_controller_utils/utils.py b/nemo_rl/algorithms/single_controller_utils/utils.py new file mode 100644 index 0000000000..e202d99c4d --- /dev/null +++ b/nemo_rl/algorithms/single_controller_utils/utils.py @@ -0,0 +1,189 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers used by SingleControllerActor.""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import torch +from tensordict import TensorDict + +from nemo_rl.data_plane import KVBatchMeta + +# Reduction rules for all_mb_metrics. Mirror grpo.py / grpo_sync.py. +_MB_METRIC_MIN: frozenset[str] = frozenset( + {"probs_ratio_min", "probs_ratio_clamped_min"} +) +_MB_METRIC_MAX: frozenset[str] = frozenset( + {"probs_ratio_max", "probs_ratio_clamped_max"} +) +_MB_METRIC_MEAN: frozenset[str] = frozenset( + { + "lr", + "wd", + "reward", + "global_valid_seqs", + "global_valid_toks", + "mean_prompt_length", + } +) + + +def aggregate_step_metrics(train_result: dict[str, Any]) -> dict[str, Any]: + """Reduce per-microbatch metric lists into step-level scalars. + + Args: + train_result: Output of TQPolicy.finish_train_step. + + Returns: + Flat dict of step-level scalars ready for logging. + """ + metrics: dict[str, Any] = {} + loss = train_result.get("loss") + if isinstance(loss, torch.Tensor): + metrics["loss"] = loss.detach().mean().item() + elif loss is not None: + metrics["loss"] = float(loss) + grad_norm = train_result.get("grad_norm") + if isinstance(grad_norm, torch.Tensor): + metrics["grad_norm"] = grad_norm.detach().mean().item() + elif grad_norm is not None: + metrics["grad_norm"] = float(grad_norm) + if "total_flops" in train_result: + metrics["total_flops"] = float(train_result["total_flops"]) + if "num_ranks" in train_result: + metrics["num_ranks"] = int(train_result["num_ranks"]) + + # moe/mtp share the same reduction rules as all_mb_metrics in grpo.py. + mb: dict[str, list[Any]] = {} + if "moe_metrics" in train_result: + mb.update({f"moe/{k}": v for k, v in train_result["moe_metrics"].items()}) + if "mtp_metrics" in train_result: + mb.update({f"mtp/{k}": v for k, v in train_result["mtp_metrics"].items()}) + mb.update(train_result.get("all_mb_metrics", {})) + + for k, v in mb.items(): + if k in _MB_METRIC_MIN: + valid = [x for x in v if not np.isinf(x)] + metrics[k] = float(np.min(valid)) if valid else -1.0 + elif k in _MB_METRIC_MAX: + valid = [x for x in v if not np.isinf(x)] + metrics[k] = float(np.max(valid)) if valid else -1.0 + elif k in _MB_METRIC_MEAN: + metrics[k] = float(np.mean(v)) + else: + metrics[k] = float(np.sum(v)) + return metrics + + +def reduce_advantage_pump_metrics( + rewards: list[torch.Tensor], + masked_advantages: list[torch.Tensor], + sequence_lengths: list[int], +) -> dict[str, float]: + """Reduce per-step accumulators from _advantage_pump into step scalars. + + Args: + rewards: One tensor per advantage_pump call; each row a sample reward. + masked_advantages: Token-masked advantages, one tensor per call. + sequence_lengths: All input_lengths trained on this step. + + Returns: + Dict with reward, advantages/{mean,max,min}, total_num_tokens. + """ + out: dict[str, float] = {} + if rewards: + out["reward"] = float(torch.cat([r.flatten() for r in rewards]).mean()) + if masked_advantages: + cat = torch.cat([a.flatten() for a in masked_advantages]) + if cat.numel() > 0: + out["advantages/mean"] = float(cat.mean()) + out["advantages/max"] = float(cat.max()) + out["advantages/min"] = float(cat.min()) + else: + out["advantages/mean"] = 0.0 + out["advantages/max"] = 0.0 + out["advantages/min"] = 0.0 + if sequence_lengths: + out["total_num_tokens"] = float(sum(sequence_lengths)) + return out + + +def tensor_field(data: TensorDict, field_name: str) -> torch.Tensor: + """Read a tensor column from a TensorDict, depadding if nested. + + Args: + data: TensorDict returned by the data plane. + field_name: Column name to fetch. + + Returns: + Dense tensor (nested columns are padded with zeros). + """ + value = data[field_name] + if not isinstance(value, torch.Tensor): + raise TypeError(f"expected tensor field {field_name!r}; got {type(value)}") + if value.is_nested: + return torch.nested.to_padded_tensor(value, padding=0) + return value + + +def squeeze_trailing_unit_dim(value: torch.Tensor) -> torch.Tensor: + """Drop a trailing dim of size 1 if present. + + Args: + value: Input tensor. + + Returns: + Tensor without the trailing unit dim. + """ + if value.dim() >= 2 and value.shape[-1] == 1: + return value.squeeze(-1) + return value + + +def fields_for_put(meta: KVBatchMeta, fields: dict[str, torch.Tensor]) -> TensorDict: + """Pack tensors for DataPlane put, re-nesting jagged rows when needed. + + Args: + meta: Batch meta whose sequence_lengths drive the nesting. + fields: Field name to dense tensor. + + Returns: + TensorDict shaped for dp_client.put_samples. + """ + packed: dict[str, torch.Tensor] = {} + if meta.sequence_lengths is None: + for field_name, value in fields.items(): + packed[field_name] = value.detach().contiguous() + # pyrefly: ignore[bad-argument-type] + return TensorDict(packed, batch_size=[meta.size]) + + lengths = torch.tensor(meta.sequence_lengths, dtype=torch.long) + for field_name, value in fields.items(): + if value.dim() >= 2 and value.shape[1] == int(lengths.max().item()): + rows = [ + value[i, : int(lengths[i].item())].detach().contiguous() + for i in range(meta.size) + ] + packed[field_name] = torch.nested.as_nested_tensor( + rows, + layout=torch.jagged, + ) + else: + packed[field_name] = value.detach().contiguous() + # pyrefly: ignore[bad-argument-type] + return TensorDict(packed, batch_size=[meta.size]) diff --git a/nemo_rl/data_plane/interfaces.py b/nemo_rl/data_plane/interfaces.py index 6bdc5e940c..41a98f0c0e 100644 --- a/nemo_rl/data_plane/interfaces.py +++ b/nemo_rl/data_plane/interfaces.py @@ -219,6 +219,31 @@ def concat(self, *others: "KVBatchMeta") -> "KVBatchMeta": sample_ids=sample_ids, sequence_lengths=seq_lens, tags=tags ) + def drop(self, indices: "Sequence[int]") -> "KVBatchMeta | None": + """Complement of :meth:`subset`. Returns ``None`` when all rows are dropped.""" + dropped = set(indices) + keep = [i for i in range(self.size) if i not in dropped] + if not keep: + return None + return self.subset(keep) + + def with_fields(self, field_names: "Sequence[str]") -> "KVBatchMeta": + """Return a copy with ``field_names`` merged into ``fields`` (deduped, order-preserving).""" + merged = list(dict.fromkeys([*(self.fields or []), *field_names])) + return KVBatchMeta( + partition_id=self.partition_id, + task_name=self.task_name, + sample_ids=list(self.sample_ids), + fields=merged, + sequence_lengths=( + list(self.sequence_lengths) + if self.sequence_lengths is not None + else None + ), + extra_info=dict(self.extra_info or {}), + tags=[dict(tag) for tag in self.tags] if self.tags is not None else None, + ) + class DataPlaneClient(ABC): """Stable, swappable data-plane boundary. diff --git a/nemo_rl/data_plane/worker_mixin.py b/nemo_rl/data_plane/worker_mixin.py index 89c0cb622d..8546336590 100644 --- a/nemo_rl/data_plane/worker_mixin.py +++ b/nemo_rl/data_plane/worker_mixin.py @@ -516,3 +516,82 @@ def get_reference_policy_logprobs_presharded( tq_field="reference_policy_logprobs", ) del result + + # ── split-API entrypoints (SC async path) ────────────────────────────── + # + # The split path lets SingleController drive forward/backward per + # microbatch (or per pipeline-batch on Megatron) without stepping the + # optimizer until a full logical batch has accumulated. Backend + # methods (``begin_train_step``, ``train_microbatch``, + # ``finish_train_step``, ``abort_train_step``) own the train-step + # state machine; this mixin just gates them on TQ-presharded data. + + @wrap_with_nvtx_name("policy_worker/begin_train_step_presharded") + def begin_train_step_presharded( + self, + step_id: str, + loss_fn: Any, + gbs: Optional[int] = None, + mbs: Optional[int] = None, + ) -> None: + """Open a logical train step. No fetch — pure lifecycle. + + The backend stores ``step_id`` / ``loss_fn`` / ``gbs`` / ``mbs``, + clears gradients, and initialises accumulators for + ``local_valid_seqs`` / ``local_valid_toks`` and any per-microbatch + metrics. Optimizer state is untouched here. + """ + self.begin_train_step( # type: ignore[attr-defined] + step_id=step_id, + loss_fn=loss_fn, + gbs=gbs, + mbs=mbs, + ) + + @wrap_with_nvtx_name("policy_worker/train_microbatch_presharded") + def train_microbatch_presharded( + self, + step_id: str, + meta: "KVBatchMeta", + ) -> dict[str, Any]: + """Per-rank microbatch entrypoint. Fetch → packing prep → forward+backward. + + Gradients accumulate into ``.grad`` across calls; no + ``optimizer.step`` here. Returns per-microbatch metrics (loss, + local_valid_*); the backend folds them into the step accumulator + and the caller may surface them for diagnostics. + """ + data = self._fetch(meta) + data = self._attach_or_repack_pack_metadata(data, meta) + return self.train_microbatch( # type: ignore[attr-defined] + step_id=step_id, + data=data, + ) + + @wrap_with_nvtx_name("policy_worker/finish_train_step_presharded") + def finish_train_step_presharded( + self, + step_id: str, + ) -> dict[str, Any]: + """Close a logical train step. No fetch — pure lifecycle. + + Backend all-reduces accumulated ``local_valid_seqs/toks``, + rescales gradients to the final global normalization, runs grad + clip, steps the optimizer + scheduler, then zeros gradients. + Returns the aggregated step result (``loss``, ``grad_norm``, + ``all_mb_metrics``, …). + """ + return self.finish_train_step(step_id=step_id) # type: ignore[attr-defined] + + @wrap_with_nvtx_name("policy_worker/abort_train_step_presharded") + def abort_train_step_presharded( + self, + step_id: str, + ) -> None: + """Discard partial train-step state without stepping the optimizer. + + Used when SC decides the logical batch will not complete (e.g. + weight-sync triggered mid-step). Backend drops accumulators and + zeros gradients. + """ + self.abort_train_step(step_id=step_id) # type: ignore[attr-defined] diff --git a/nemo_rl/environments/nemo_gym.py b/nemo_rl/environments/nemo_gym.py index 611751af36..15f511da88 100644 --- a/nemo_rl/environments/nemo_gym.py +++ b/nemo_rl/environments/nemo_gym.py @@ -14,12 +14,14 @@ import os import subprocess from pathlib import Path -from typing import Any, Dict, List, NotRequired, TypedDict +from typing import Any, Dict, List, NotRequired, Optional, TypedDict import ray import torch +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy from transformers import PreTrainedTokenizerBase +from nemo_rl.distributed.ray_actor_environment_registry import get_actor_python_env from nemo_rl.distributed.virtual_cluster import ( DEFAULT_GYM_PORT_RANGE_HIGH, DEFAULT_GYM_PORT_RANGE_LOW, @@ -28,6 +30,7 @@ ) from nemo_rl.environments.interfaces import EnvironmentInterface from nemo_rl.utils.timer import Timer +from nemo_rl.utils.venvs import create_local_venv_on_each_node DEFAULT_INVALID_TOOL_CALL_PATTERNS = [ "", @@ -445,3 +448,60 @@ def setup_nemo_gym_config(config, tokenizer) -> None: # Stop strings or token ids are not supported generation_config["stop_strings"] = None generation_config["stop_token_ids"] = None + + +def spinup_nemo_gym_actor( + env_configs: dict[str, Any], + base_urls: list[Optional[str]], + model_name: str, +) -> Any: + """Spin up the NeMo-Gym actor against the given generation server URLs. + + When ``env_configs["nemo_gym"]["num_gpu_nodes"] > 0``, the actor is + scheduled with soft NodeAffinity to the current Ray node so its colocated + GPU resources land where the caller expects. + """ + nemo_gym_py_exec = get_actor_python_env("nemo_rl.environments.nemo_gym.NemoGym") + if nemo_gym_py_exec.startswith("uv"): + nemo_gym_py_exec = create_local_venv_on_each_node( + nemo_gym_py_exec, "nemo_rl.environments.nemo_gym.NemoGym" + ) + + nemo_gym_dict = env_configs["nemo_gym"] + # NeMo-RL-side detection knobs are top-level NemoGymConfig fields + # (where the detector reads them), not part of Gym's global config. + invalid_tool_call_patterns = nemo_gym_dict.pop("invalid_tool_call_patterns", None) + thinking_tags = nemo_gym_dict.pop("thinking_tags", None) + uv_cache_dir = get_nemo_gym_uv_cache_dir() + if uv_cache_dir is not None: + nemo_gym_dict.setdefault("uv_cache_dir", uv_cache_dir) + uv_venv_dir = get_nemo_gym_venv_dir() + if uv_venv_dir is not None: + nemo_gym_dict.setdefault("uv_venv_dir", uv_venv_dir) + + nemo_gym_cfg = NemoGymConfig( + model_name=model_name, + base_urls=base_urls, + invalid_tool_call_patterns=invalid_tool_call_patterns, + thinking_tags=thinking_tags, + initial_global_config_dict=nemo_gym_dict, + ) + + nemo_gym_opts: dict[str, Any] = {} + if nemo_gym_dict.get("num_gpu_nodes", 0): + nemo_gym_opts["scheduling_strategy"] = NodeAffinitySchedulingStrategy( + node_id=ray.get_runtime_context().get_node_id(), + soft=True, + ) + nemo_gym_opts["runtime_env"] = { + "py_executable": nemo_gym_py_exec, + "env_vars": { + **os.environ, + "VIRTUAL_ENV": nemo_gym_py_exec, + "UV_PROJECT_ENVIRONMENT": nemo_gym_py_exec, + }, + } + + actor = NemoGym.options(**nemo_gym_opts).remote(nemo_gym_cfg) + ray.get(actor._spinup.remote()) + return actor diff --git a/nemo_rl/experience/payload.py b/nemo_rl/experience/payload.py new file mode 100644 index 0000000000..bdf95ab290 --- /dev/null +++ b/nemo_rl/experience/payload.py @@ -0,0 +1,117 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Producer-side payload helpers for the async-RL TQ path.""" + +from collections.abc import Mapping +from typing import Any + +import numpy as np +import torch +from tensordict import TensorDict + +from nemo_rl.data_plane.codec import pack_jagged_fields +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.experience.interfaces import PromptGroupRecord + + +def record_to_train_batch( + record: PromptGroupRecord, + *, + pad_value_dict: Mapping[str, int], +) -> BatchedDataDict[Any]: + """Convert one prompt group's record into a packed BatchedDataDict of N rows. + + Args: + record: Rollout's PromptGroupRecord with N completions to flatten into rows. + pad_value_dict: Field-name → pad value used by batched_message_log_to_flat_message. + + Returns: + BatchedDataDict with input_ids, input_lengths, generation_logprobs, token_mask, + sample_mask, prompt_ids_for_adv, and total_reward. + """ + # Lazy imports: grpo and llm_message_utils transitively pull + # experience.rollouts, so importing at module top risks a cycle. + from nemo_rl.algorithms.grpo import ( + add_grpo_token_loss_masks_and_generation_logprobs, + extract_initial_prompt_messages, + ) + from nemo_rl.data.llm_message_utils import batched_message_log_to_flat_message + + completions = record.completions + n = len(completions) + assert n > 0, "PromptGroupRecord has no completions" + + message_logs = [c.message_log for c in completions] + prompt_token_count = sum(len(m["token_ids"]) for m in record.prompt) + prompt_lengths = torch.full((n,), prompt_token_count, dtype=torch.long) + + prompt_message_logs = extract_initial_prompt_messages(message_logs, prompt_lengths) + prompt_flat, _ = batched_message_log_to_flat_message( + prompt_message_logs, + pad_value_dict=dict(pad_value_dict), # type: ignore + ) + + add_grpo_token_loss_masks_and_generation_logprobs(message_logs) + flat, input_lengths = batched_message_log_to_flat_message( + message_logs, # type: ignore + pad_value_dict=dict(pad_value_dict), # type: ignore + ) + + total_reward = torch.tensor( + [float(c.reward) for c in completions], dtype=torch.float32 + ) + sample_mask = torch.ones(n, dtype=torch.float32) + + return BatchedDataDict[Any]( + { + "input_ids": flat["token_ids"], + "input_lengths": input_lengths, + "generation_logprobs": flat["generation_logprobs"], + "token_mask": flat["token_loss_mask"], + "sample_mask": sample_mask, + "prompt_ids_for_adv": prompt_flat["token_ids"], + "total_reward": total_reward, + } + ) + + +def pack_payload( + train_batch: Mapping[str, Any], + *, + weight_version: int, + group_id: str, +) -> tuple[list[str], TensorDict, list[dict[str, Any]]]: + """Pack a producer batch into (sample_ids, fields, tags) for put_samples. + + Args: + train_batch: Mapping with at least input_lengths plus the tensor/object fields to send. + weight_version: Trainer weight version stamped on every row's tag. + group_id: Per-group identifier used as the sample_id prefix; the caller owns uniqueness. + + Returns: + sample_ids of the form {group_id}_g{i}, a jagged-packed TensorDict, and per-row tags. + """ + lengths = train_batch["input_lengths"] + n = int(lengths.shape[0]) + tensor_fields: dict[str, torch.Tensor | np.ndarray] = { + k: v + for k, v in train_batch.items() + if isinstance(v, torch.Tensor) + or (isinstance(v, np.ndarray) and v.dtype == object) + } + fields_td = pack_jagged_fields(tensor_fields, lengths=lengths) + sample_ids = [f"{group_id}_g{i}" for i in range(n)] + tags = [{"weight_version": weight_version} for _ in range(n)] + return sample_ids, fields_td, tags diff --git a/nemo_rl/experience/rollout_manager.py b/nemo_rl/experience/rollout_manager.py index 4b6b5b9a85..f382ab7af1 100644 --- a/nemo_rl/experience/rollout_manager.py +++ b/nemo_rl/experience/rollout_manager.py @@ -21,7 +21,8 @@ from transformers import PreTrainedTokenizerBase from wandb import Table -from nemo_rl.data.interfaces import DatumSpec +from nemo_rl.algorithms.async_utils.replay_buffer import TQReplayBuffer +from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.environments.interfaces import EnvironmentInterface from nemo_rl.experience.interfaces import Completion, PromptGroupRecord @@ -50,7 +51,7 @@ class AsyncRolloutImpl: def __init__( self, tokenizer: TokenizerType, - task_to_env: dict[str, EnvironmentInterface], + env_handles: dict[str, EnvironmentInterface], num_generations_per_prompt: int, max_seq_len: int, policy_generation: GenerationInterface, @@ -58,7 +59,7 @@ def __init__( **kwargs: Any, ) -> None: self._tokenizer = tokenizer - self._task_to_env = task_to_env + self._env_handles = env_handles self._num_generations_per_prompt = num_generations_per_prompt self._max_seq_len = max_seq_len self._max_rollout_turns = max_rollout_turns @@ -191,7 +192,7 @@ async def _run_single_rollout( # step. In this case, need to wrap with asyncio.to_thread to make # this function yieldable. env_output = await asyncio.to_thread( - calculate_rewards, sample_batch, self._task_to_env + calculate_rewards, sample_batch, self._env_handles ) # Update reward and termination statistics @@ -387,15 +388,15 @@ class AsyncNemoGymRolloutImpl: def __init__( self, tokenizer: TokenizerType, - task_to_env: dict[str, EnvironmentInterface], + env_handles: dict[str, EnvironmentInterface], num_generations_per_prompt: int, max_seq_len: int, generation_config: GenerationConfig, - max_rollout_turns: Optional[int] = None, + max_rollout_turns: int, **kwargs: Any, ) -> None: self._tokenizer = tokenizer - self._task_to_env = task_to_env + self._env_handles = env_handles self._num_generations_per_prompt = num_generations_per_prompt self._max_seq_len = max_seq_len self._max_rollout_turns = max_rollout_turns @@ -417,7 +418,7 @@ async def run_rollout(self, input_sample: DatumSpec) -> PromptGroupRecord: timer.start(f"{timer_prefix}/total") rollout_inputs = self._build_inputs(input_sample) - completions, rollout_metrics = await self._run_rollouts( + completions, prompt_message_log, rollout_metrics = await self._run_rollouts( rollout_inputs, timer, timer_prefix ) @@ -426,7 +427,7 @@ async def run_rollout(self, input_sample: DatumSpec) -> PromptGroupRecord: return PromptGroupRecord( prompt_idx=input_sample["idx"], - prompt=input_sample["message_log"], + prompt=prompt_message_log, extra_env_info=input_sample["extra_env_info"], metadata={"task_name": "nemo_gym"}, completions=completions, @@ -442,8 +443,9 @@ def _validate_init_params(self) -> None: ) # Validate max_rollout_turns. - assert self._max_rollout_turns is None, ( - "`max_rollout_turns` is not supported in NeMo-Gym path!" + assert self._max_rollout_turns == 1, ( + "`max_rollout_turns` is not supported in NeMo-Gym path! " + "Please set `max_rollout_turns` to 1." ) def _build_inputs(self, input_sample: DatumSpec) -> list[dict]: @@ -476,15 +478,18 @@ def _build_inputs(self, input_sample: DatumSpec) -> list[dict]: async def _run_rollouts( self, inputs: list[dict], timer: Timer, timer_prefix: str - ) -> tuple[list[Completion], dict[str, Any]]: - """Dispatch rows to NeMo-Gym and return completions + metrics.""" - nemo_gym_env = self._task_to_env["nemo_gym"] + ) -> tuple[list[Completion], LLMMessageLogType, dict[str, Any]]: + """Dispatch rows to NeMo-Gym; return completions, prompt, and metrics.""" + nemo_gym_env = self._env_handles["nemo_gym"] # Run generation. with timer.time(f"{timer_prefix}/run_rollouts"): results, env_timing_metrics = await nemo_gym_env.run_rollouts.remote( inputs, self._tokenizer, timer_prefix ) + # All N rollouts share the same input prompt; tensorize one copy. + prompt_message_log = results[0]["input_message_log"] + _tensorize_by_key(prompt_message_log, "token_ids") # Convert results to completions. completions = [self._result_to_completion(r) for r in results] @@ -496,12 +501,11 @@ async def _run_rollouts( rollout_metrics.update(env_timing_metrics) - return completions, rollout_metrics + return completions, prompt_message_log, rollout_metrics def _result_to_completion(self, result: dict) -> Completion: """Convert one run_rollouts result dict into a Completion.""" # Tensorize token fields. - _tensorize_by_key(result["input_message_log"], "token_ids") _tensorize_by_key(result["message_log"], "token_ids") _tensorize_by_key( [m for m in result["message_log"] if m["role"] == "assistant"], @@ -581,18 +585,19 @@ def _compute_rollout_metrics( class RolloutManager: - """Factory that routes to AsyncRolloutImpl (native async) or AsyncNemoGymRolloutImpl (NeMo-Gym).""" + """Routes to AsyncRolloutImpl (native async) or AsyncNemoGymRolloutImpl (NeMo-Gym), and pushes results to a TQReplayBuffer.""" def __init__( self, tokenizer: TokenizerType, - task_to_env: dict[str, EnvironmentInterface], + env_handles: dict[str, EnvironmentInterface], num_generations_per_prompt: int, max_seq_len: int, max_rollout_turns: Optional[int] = None, policy_generation: Optional[GenerationInterface] = None, generation_config: Optional[GenerationConfig] = None, use_nemo_gym: bool = False, + tq_buffer: Optional[TQReplayBuffer] = None, ) -> None: assert num_generations_per_prompt >= 1, ( "num_generations_per_prompt must be >= 1" @@ -613,13 +618,52 @@ def __init__( self._impl: AsyncRolloutImpl | AsyncNemoGymRolloutImpl = rollout_cls( tokenizer=tokenizer, - task_to_env=task_to_env, + env_handles=env_handles, num_generations_per_prompt=num_generations_per_prompt, max_seq_len=max_seq_len, max_rollout_turns=max_rollout_turns, # type: ignore policy_generation=policy_generation, # type: ignore generation_config=generation_config, ) + self._tokenizer = tokenizer + self._num_generations_per_prompt = num_generations_per_prompt + self._tq_buffer = tq_buffer + self._weight_version: int = 0 + + def set_weight_version(self, version: int) -> None: + """Set the weight_version used for rollout tags. + + Args: + version: Trainer weight version to stamp on future rollout tags. + """ + self._weight_version = int(version) async def run_rollout(self, input_sample: DatumSpec) -> PromptGroupRecord: return await self._impl.run_rollout(input_sample) + + async def generate_and_push( + self, input_sample: DatumSpec, *, target_step: Optional[int] = None + ) -> None: + """Reserve a buffer slot, run one prompt's rollout, then commit the slot. + + Args: + input_sample: A single prompt (one DatumSpec entry). + target_step: Training step this rollout targets; stamped on the buffer slot for StalenessSampler.force_in_order. + """ + assert self._tq_buffer is not None, ( + "generate_and_push requires tq_buffer to be set at __init__" + ) + start_version = self._weight_version + group_id = self._tq_buffer.reserve( + weight_version=start_version, target_step=target_step + ) + + record = await self.run_rollout(input_sample) + end_version = self._weight_version + + await self._tq_buffer.commit( + group_id, + record, + start_weight_version=start_version, + end_weight_version=end_version, + ) diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index d68bf512bc..ce928cd9db 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -905,6 +905,9 @@ def shutdown(self) -> bool: try: # Use the worker group's shutdown method with the worker's cleanup method return self.worker_group.shutdown(cleanup_method="shutdown") + except ray.exceptions.RayActorError: + # Workers already dead (e.g., shut down via another handle to the same actors). + return True except Exception as e: print(f"Error during policy shutdown: {e}") return False diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index f1a5d705c0..6eea7497a3 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -1096,6 +1096,9 @@ def shutdown(self) -> bool: try: # Use the worker group's shutdown method with the worker's cleanup method return self.worker_group.shutdown(cleanup_method="shutdown") + except ray.exceptions.RayActorError: + # Workers already dead (e.g., shut down via another handle to the same actors). + return True except Exception as e: print(f"Error during policy shutdown: {e}") return False @@ -1103,12 +1106,11 @@ def shutdown(self) -> bool: def __del__(self) -> None: """Shuts down the worker groups when the object is deleted or is garbage collected. - This is an extra safety net in case the user forgets to call worker_group.shutdown() and the pointer to + This is an extra safety net in case the user forgets to call shutdown() and the pointer to the object is lost due to leaving a function scope. It's always recommended that the - user calls worker_group.shutdown(). + user calls shutdown(). """ - if hasattr(self, "worker_group"): - self.worker_group.shutdown(cleanup_method="shutdown") + self.shutdown() def start_gpu_profiling(self) -> None: """Start GPU profiling.""" diff --git a/nemo_rl/models/policy/tq_policy.py b/nemo_rl/models/policy/tq_policy.py index 1e4bc93ba0..058647ba0a 100644 --- a/nemo_rl/models/policy/tq_policy.py +++ b/nemo_rl/models/policy/tq_policy.py @@ -457,3 +457,128 @@ def train_from_meta( warnings.warn(f"Error getting theoretical flops: {e}") return aggregated_results + + # ── split-API fanout (SC async path) ─────────────────────────────────── + # + # Counterpart to :meth:`train_from_meta`, exposed to ``PolicyTrainerActor`` + # so :class:`SingleControllerActor` can stream microbatches without + # forcing a full-step optimizer.step on every dispatch. + # + # Lifecycle: + # begin_train_step — open step; broadcast loss_fn/gbs/mbs + # train_microbatch_from_meta (N×) — DP-sharded fwd/bwd, grads accumulate + # finish_train_step — all_reduce + opt.step + sched.step + # abort_train_step — drop accumulators, no opt.step + # + # ``train_from_meta`` is unchanged and remains the sync entrypoint. + + def begin_train_step( + self, + step_id: str, + loss_fn: LossFunction, + gbs: Optional[int] = None, + mbs: Optional[int] = None, + ) -> None: + """Open a logical train step on every worker.""" + batch_size = gbs or self.cfg["train_global_batch_size"] + micro_batch_size = mbs or self.cfg["train_micro_batch_size"] + if self.flops_tracker is not None: + self.flops_tracker.reset() + futures = self.worker_group.run_all_workers_single_data( + "begin_train_step_presharded", + step_id=step_id, + loss_fn=loss_fn, + gbs=batch_size, + mbs=micro_batch_size, + ) + ray.get(futures) + + def train_microbatch_from_meta( + self, + step_id: str, + meta: KVBatchMeta, + timer: Optional[Timer] = None, + ) -> dict[str, Any]: + """Dispatch one microbatch (DP-sharded) into an open train step. + + Mirrors the sharding logic of :meth:`train_from_meta` but without + a logical-batch sizing constraint: this routes ``meta`` to DP + ranks and runs forward+backward; gradients accumulate in + ``.grad``. The optimizer step happens at :meth:`finish_train_step`. + """ + self._stamp_pad_seqlen(meta) + spa, dba = self._packing_args("train_mb_tokens") + train_meta = replace( + meta, + fields=list(DP_TRAIN_FIELDS), + task_name="train", + ) + with timer.time("policy_training/shard_meta") if timer else nullcontext(): + dp_metas, _ = shard_meta_for_dp( + train_meta, + dp_world=self.sharding_annotations.get_axis_size("data_parallel"), + batch_size=None, + sequence_packing_args=spa, + dynamic_batching_args=dba, + ) + + if self.flops_tracker is not None: + for m in dp_metas: + self.flops_tracker.track_batch(list(m.sequence_lengths or [])) + + with ( + timer.time("policy_training/submit_microbatch_futures") + if timer + else nullcontext() + ): + futures = self.worker_group.run_all_workers_sharded_data( + "train_microbatch_presharded", + meta=dp_metas, + in_sharded_axes=["data_parallel"], + replicate_on_axes=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + output_is_replicated=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + common_kwargs={"step_id": step_id}, + ) + results = self.worker_group.get_all_worker_results(futures) + # Per-microbatch metrics: pass through DP-rank-0 by convention, + # backend may aggregate later if needed. Surface as-is for now. + return results[0] if results else {} + + def finish_train_step(self, step_id: str) -> dict[str, Any]: + """Close an open train step: all_reduce, rescale, optimizer.step. + + Aggregates per-rank step results into the same shape as + :meth:`train_from_meta` so callers don't have to special-case + the split path. + """ + futures = self.worker_group.run_all_workers_single_data( + "finish_train_step_presharded", + step_id=step_id, + ) + results = ray.get(futures) + aggregated_results = _aggregate_train_results(results) + + if self.flops_tracker is not None: + aggregated_results["total_flops"] = self.flops_tracker.total_flops + aggregated_results["num_ranks"] = self.worker_group.cluster.world_size() + + return aggregated_results + + def abort_train_step(self, step_id: str) -> None: + """Drop partial step state on every worker. No optimizer.step.""" + futures = self.worker_group.run_all_workers_single_data( + "abort_train_step_presharded", + step_id=step_id, + ) + ray.get(futures) + + if self.flops_tracker is not None: + self.flops_tracker.reset() diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 0eec29f04a..a13e494373 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -143,6 +143,11 @@ class MegatronPolicyWorkerImpl( AbstractPolicyWorker, ColocatablePolicyInterface, ): + # Holds the split-API train-step state between begin/finish or + # begin/abort; None when no step is open. Declared at class level so + # ``self._train_step_state = None`` after finish/abort type-checks. + _train_step_state: Optional[dict[str, Any]] = None + def __repr__(self): """Customizes the actor's prefix in the Ray logs. @@ -835,6 +840,389 @@ def _set_moe_grad_scale_func(self, func): if config is not None: config.moe_grad_scale_func = func + # ── split-API train-step state machine (SingleController async path) ── + # + # Mirrors the v1/v2 implementations, adapted for mcore. Key differences: + # + # 1. mcore DDP accumulates ``param.main_grad`` per backward and dispatches + # a cross-DP reduce when ``is_last_microbatch=True`` (one per + # ``forward_backward_func`` call). Naively chaining multiple + # ``forward_backward_func`` calls between ``optimizer.step()`` would + # over-count: each call's terminal reduce sums an already-reduced + # bucket again. We wrap every call in ``self.model.no_sync()`` so + # hooks accumulate locally only; one explicit ``start_grad_sync`` + + # ``finish_grad_sync`` at finish does the single true reduce. + # 2. PP>1: the pipeline scheduler invokes ``config.grad_sync_func`` + # directly on last-microbatch boundaries — this bypasses the + # ``no_sync`` gate. We null it for the duration of the step and + # restore at finish/abort. + # 3. Grad clip is bundled inside ``MegatronOptimizer.step()``; the 1/N + # rescale via ``self.model.scale_gradients(1/N)`` must run before + # ``optimizer.step()`` so the clip operates on the rescaled grad. + # 4. With ``calculate_per_token_loss=True`` + ``average_in_collective= + # False``, mcore's DDP sums (does not average) grads across DP, so + # no FSDP-style ``loss *= dp_size*cp_size`` cancellation is needed + # per microbatch. + + def _split_step_state_init( + self, + step_id: str, + loss_fn: LossFunction, + gbs: Optional[int], + mbs: Optional[int], + ) -> dict[str, Any]: + from nemo_rl.algorithms.loss.interfaces import LossType + + return { + "step_id": step_id, + "loss_fn": loss_fn, + "loss_type": getattr(loss_fn, "loss_type", LossType.TOKEN_LEVEL), + "gbs": gbs or self.cfg["train_global_batch_size"], + "mbs": mbs or self.cfg["train_micro_batch_size"], + "local_valid_seqs": torch.zeros((), dtype=torch.float64, device="cuda"), + "local_valid_toks": torch.zeros((), dtype=torch.float64, device="cuda"), + "all_mb_metrics": [], + "mb_losses": [], + "total_num_microbatches": 0, + # Saved across the step so we can restore at finish/abort. + "saved_grad_sync_func": None, + "saved_no_sync_func": None, + "no_sync_active": False, + } + + def _assert_step_open(self, step_id: str) -> dict[str, Any]: + state = getattr(self, "_train_step_state", None) + if state is None: + raise RuntimeError( + f"no train step open; begin_train_step({step_id!r}) must be called first" + ) + if state["step_id"] != step_id: + raise RuntimeError( + f"step_id mismatch: open step is {state['step_id']!r}, got {step_id!r}" + ) + return state + + @wrap_with_nvtx_name("megatron_policy_worker/begin_train_step") + def begin_train_step( + self, + step_id: str, + loss_fn: LossFunction, + gbs: Optional[int] = None, + mbs: Optional[int] = None, + ) -> None: + existing = getattr(self, "_train_step_state", None) + if existing is not None: + raise RuntimeError( + f"train step {existing['step_id']!r} is already open; " + f"call finish_train_step or abort_train_step before begin" + ) + # Match sync train() inference-state reset (line 332-340). + if hasattr(self.model, "inference_params"): + self.model.inference_params = None + for module in self.model.modules(): + if hasattr(module, "reset_inference_cache"): + module.reset_inference_cache() + if hasattr(module, "_inference_key_value_memory"): + module._inference_key_value_memory = None + + self.model.train() + self.model.zero_grad_buffer() + self.optimizer.zero_grad() + + state = self._split_step_state_init( + step_id=step_id, loss_fn=loss_fn, gbs=gbs, mbs=mbs + ) + + # Null both mcore hooks that would fire a mid-step DP reduce: + # grad_sync_func — PP scheduler's direct call on last-MB boundaries. + # no_sync_func — forward_backward_no_pipelining wraps inner MBs + # in it and runs the LAST MB OUTSIDE, leaking + # per_param_grad_ready_counts past our outer + # no_sync. Override to nullcontext so only the + # outer no_sync in train_microbatch governs + # is_last_microbatch. + # getattr-by-string keeps "config" out of __code__.co_names so + # cloudpickle doesn't grab torch.distributed.config. + model_config = getattr(self.model, "config", None) + if model_config is not None: + state["saved_grad_sync_func"] = getattr( + model_config, "grad_sync_func", None + ) + state["saved_no_sync_func"] = getattr(model_config, "no_sync_func", None) + model_config.grad_sync_func = None + model_config.no_sync_func = nullcontext + else: + state["saved_grad_sync_func"] = None + state["saved_no_sync_func"] = None + + self._train_step_state = state + + @wrap_with_nvtx_name("megatron_policy_worker/train_microbatch") + def train_microbatch( + self, + step_id: str, + data: BatchedDataDict[Any], + ) -> dict[str, Any]: + """One DP slice of data → one ``forward_backward_func`` invocation. + + Wrapped in ``self.model.no_sync()`` so the mcore DDP hooks + accumulate ``param.main_grad`` locally on each rank without + dispatching a per-call DP reduce. The single true reduce is done + explicitly in ``finish_train_step``. + """ + state = self._assert_step_open(step_id) + loss_fn = state["loss_fn"] + + # Accumulate local mask sums for the finish-time all_reduce. + # Inlined from process_global_batch (data.py:319-332) — we can't + # call process_global_batch directly because it eagerly all_reduces + # the local sums, which is exactly what we're trying to defer. + assert "sample_mask" in data, "sample_mask required on microbatch data" + sample_mask = data["sample_mask"] + call_local_seqs = torch.sum(sample_mask).to(torch.float64) + if "token_mask" in data: + token_mask = data["token_mask"] + call_local_toks = torch.sum( + token_mask[:, 1:] * sample_mask.unsqueeze(-1) + ).to(torch.float64) + else: + call_local_toks = call_local_seqs * data["input_ids"].shape[1] + + state["local_valid_seqs"] = state["local_valid_seqs"] + call_local_seqs + state["local_valid_toks"] = state["local_valid_toks"] + call_local_toks + + # Build the per-call iterator. Each ``train_microbatch_from_meta`` + # call carries one DP slice; the iterator subdivides into pipeline + # microbatches. + ( + data_iterator, + num_microbatches, + micro_batch_size, + seq_length, + padded_seq_length, + ) = get_microbatch_iterator( + data, + self.cfg, + state["mbs"], + straggler_timer=self.mcore_state.straggler_timer, + ) + state["total_num_microbatches"] += int(num_microbatches) + + loss_post_processor = LossPostProcessor( + loss_fn=loss_fn, + cfg=self.cfg, + num_microbatches=num_microbatches, + sampling_params=self.sampling_params, + draft_model=self.draft_model, + ) + + # Placeholder N=1: loss returns un-normalized sums. ``backward`` + # deposits raw ``d(sum)/dθ`` into ``param.main_grad`` via the DDP + # hooks. The 1/N rescale happens once at finish. + placeholder_n = torch.tensor(1.0, device="cuda") + + draft_enabled = "draft" in self.cfg and self.cfg["draft"]["enabled"] + + # The critical wrap: hooks fire (accumulate main_grad) but the + # per-call reduce dispatch is gated off. + with self.model.no_sync(): + rerun_state_machine = get_rerun_state_machine() + while rerun_state_machine.should_run_forward_backward(data_iterator): + losses_reduced = megatron_forward_backward( + model=self.model, + data_iterator=data_iterator, + num_microbatches=num_microbatches, + seq_length=padded_seq_length, + mbs=micro_batch_size, + post_processing_fn=loss_post_processor, + forward_only=False, + defer_fp32_logits=self.defer_fp32_logits, + global_valid_seqs=placeholder_n, + global_valid_toks=placeholder_n, + sampling_params=self.sampling_params, + straggler_timer=self.mcore_state.straggler_timer, + draft_model=self.draft_model, + enable_hidden_capture=draft_enabled, + use_linear_ce_fusion_loss=self.cfg["megatron_cfg"].get( + "use_linear_ce_fusion_loss", False + ), + ) + + if self.cfg["megatron_cfg"]["empty_unused_memory_level"] >= 1: + torch.cuda.empty_cache() + + # Collect per-mb metrics from the last PP stage; broadcast to all + # PP ranks so non-last-stage ranks have something to all_reduce + # against at finish. Metrics carry the N=1 placeholder for now — + # ``finish_train_step`` rescales by the true 1/N. + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + mb_metrics_collected = [] + for x in losses_reduced: + mb_metrics_collected.append(dict(x)) + else: + mb_metrics_collected = None + + mb_metrics_collected = broadcast_loss_metrics_from_last_stage( + mb_metrics_collected + ) + + for m in mb_metrics_collected: + state["all_mb_metrics"].append(m) + # ``loss`` key is the un-normalized per-mb scalar; collect for + # the global_loss aggregation at finish. + if "loss" in m: + state["mb_losses"].append(m["loss"]) + + return { + "local_valid_seqs_mb": float(call_local_seqs.item()), + "local_valid_toks_mb": float(call_local_toks.item()), + "num_pipeline_microbatches": int(num_microbatches), + } + + @wrap_with_nvtx_name("megatron_policy_worker/finish_train_step") + def finish_train_step(self, step_id: str) -> dict[str, Any]: + from nemo_rl.algorithms.loss.interfaces import LossType + + state = self._assert_step_open(step_id) + + # All-reduce accumulated mask sums across DP to recover true N. + to_reduce = torch.stack( + [state["local_valid_seqs"], state["local_valid_toks"]] + ).to(torch.float64) + torch.distributed.all_reduce( + to_reduce, group=parallel_state.get_data_parallel_group() + ) + global_valid_seqs = to_reduce[0] + global_valid_toks = to_reduce[1] + + if state["loss_type"] == LossType.TOKEN_LEVEL: + n_true = global_valid_toks + else: + n_true = global_valid_seqs + n_safe = n_true if n_true.item() > 0 else torch.tensor(1.0, device="cuda") + inv_n = float((1.0 / n_safe).item()) + + # Rescale all locally-accumulated gradients by 1/N. The reduce + # below sees the rescaled grads; for all_reduce the result is the + # global mean grad; for reduce_scatter (dist-opt) it's the shard. + # Either way, opt.step sees the right-normalized gradient. + self.model.scale_gradients(inv_n) + + # The ONE true cross-DP reduce for the entire step. + self.model.start_grad_sync() + self.model.finish_grad_sync() + + # opt.step clips internally (clip_grad config); operates on the + # already-rescaled grad. Returns (success, grad_norm, num_zeros). + update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step() + + pg_collection = get_pg_collection(self.model) + update_successful = logical_and_across_model_parallel_group( + update_successful, mp_group=pg_collection.mp + ) + grad_norm = reduce_max_stat_across_model_parallel_group( + grad_norm, mp_group=pg_collection.mp + ) + num_zeros_in_grad = reduce_max_stat_across_model_parallel_group( + num_zeros_in_grad, mp_group=pg_collection.mp + ) + + if self.cfg["megatron_cfg"]["empty_unused_memory_level"] >= 2: + torch.cuda.empty_cache() + + # Restore the mcore hooks we nulled in begin_train_step. + # See begin_train_step for why .config is accessed by string. + finish_model_config = getattr(self.model, "config", None) + if finish_model_config is not None: + finish_model_config.grad_sync_func = state["saved_grad_sync_func"] + finish_model_config.no_sync_func = state["saved_no_sync_func"] + + # Record the LR/WD before self.scheduler.step() is called. + curr_lr = self.scheduler.get_lr(self.optimizer.param_groups[0]) + curr_wd = self.scheduler.get_wd() + + # Scheduler increment matches sync path's ``increment=gbs``. + self.scheduler.step(increment=state["gbs"]) + + # Per-mb metrics were computed with N=1; rescale to match what the + # sync path produces. ``masked_mean`` is linear in 1/N so a single + # scalar multiply per metric recovers the normalized value. + rescaled_metrics: list[dict[str, Any]] = [] + global_valid_seqs_f = float(global_valid_seqs.item()) + global_valid_toks_f = float(global_valid_toks.item()) + + for m in state["all_mb_metrics"]: + out: dict[str, Any] = {} + for k, v in m.items(): + if "_min" in k or "_max" in k: + out[k] = v + elif isinstance(v, torch.Tensor): + out[k] = v.detach() * inv_n + else: + out[k] = v * inv_n + out["lr"] = curr_lr + out["wd"] = curr_wd + out["global_valid_seqs"] = global_valid_seqs_f + out["global_valid_toks"] = global_valid_toks_f + rescaled_metrics.append(out) + + # Scale per-mb losses by 1/N and reduce per-call sums. + scaled_losses = [lv * inv_n for lv in state["mb_losses"]] + losses_to_aggregate = [torch.tensor(scaled_losses).sum().item()] + + mb_metrics, global_loss = aggregate_training_statistics( + all_mb_metrics=rescaled_metrics, + losses=losses_to_aggregate, + data_parallel_group=parallel_state.get_data_parallel_group(), + ) + + metrics = { + "global_loss": global_loss.cpu(), + "rank": torch.distributed.get_rank(), + "gpu_name": torch.cuda.get_device_name(), + "model_dtype": self.dtype, + "all_mb_metrics": mb_metrics, + "grad_norm": torch.tensor([grad_norm]), + } + + # MoE aux-loss metrics: same convention as sync train() — scale + # by the total pipeline-microbatch count accumulated across all + # train_microbatch calls. + model_config = getattr(self.model, "config", None) + num_moe_experts = getattr(model_config, "num_moe_experts", None) + if num_moe_experts is not None and num_moe_experts > 1: + moe_loss_scale = 1.0 / max(1, state["total_num_microbatches"]) + moe_metrics = get_moe_metrics( + loss_scale=moe_loss_scale, + per_layer_logging=self.cfg["megatron_cfg"]["moe_per_layer_logging"], + ) + if moe_metrics: + metrics["moe_metrics"] = moe_metrics + + self._train_step_state = None + return metrics + + @wrap_with_nvtx_name("megatron_policy_worker/abort_train_step") + def abort_train_step(self, step_id: str) -> None: + state = getattr(self, "_train_step_state", None) + if state is None: + return + if state["step_id"] != step_id: + raise RuntimeError( + f"abort_train_step({step_id!r}) does not match open step " + f"{state['step_id']!r}" + ) + # Restore the mcore hooks we nulled in begin_train_step before + # zero_grad_buffer touches anything. + # See begin_train_step for why .config is accessed by string. + abort_model_config = getattr(self.model, "config", None) + if abort_model_config is not None: + abort_model_config.grad_sync_func = state["saved_grad_sync_func"] + abort_model_config.no_sync_func = state["saved_no_sync_func"] + self.model.zero_grad_buffer() + self.optimizer.zero_grad() + self._train_step_state = None + @wrap_with_nvtx_name("megatron_policy_worker/get_reference_policy_logprobs") def get_reference_policy_logprobs( self, diff --git a/nemo_rl/utils/logger.py b/nemo_rl/utils/logger.py index c00fa96b9b..caf973a3c5 100644 --- a/nemo_rl/utils/logger.py +++ b/nemo_rl/utils/logger.py @@ -393,6 +393,16 @@ def log_plot(self, figure: plt.Figure, step: int, name: str) -> None: """ self.run.log({name: figure}, step=step) + def finish(self) -> None: + """Flush queued metrics and close the wandb service. + + Required when the run lives inside a Ray actor: Ray tears the worker + down before wandb's atexit hook can drain the IPC queue to the service. + """ + if self.run is not None: + self.run.finish() + self.run = None + def log_histogram(self, histogram: list[Any], step: int, name: str) -> None: """Log histogram metrics to wandb. @@ -1038,6 +1048,13 @@ def log_hyperparams(self, params: Mapping[str, Any]) -> None: for logger in self.loggers: logger.log_hyperparams(params) + def finish(self) -> None: + """Flush and close backends that need explicit teardown (e.g. wandb).""" + for logger in self.loggers: + finish = getattr(logger, "finish", None) + if callable(finish): + finish() + def log_batched_dict_as_jsonl( self, to_log: BatchedDataDict[Any] | dict[str, Any], filename: str ) -> None: diff --git a/pyrefly.toml b/pyrefly.toml index bca7713e54..760945fc0c 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -45,12 +45,14 @@ project-includes = [ "nemo_rl/algorithms/async_utils/__init__.py", "nemo_rl/algorithms/async_utils/interfaces.py", "nemo_rl/algorithms/async_utils/replay_buffer.py", + "nemo_rl/algorithms/async_utils/staleness_sampler.py", "nemo_rl/algorithms/logits_sampling_utils.py", "nemo_rl/algorithms/loss/__init__.py", "nemo_rl/algorithms/loss/interfaces.py", "nemo_rl/algorithms/loss/utils.py", "nemo_rl/algorithms/opd.py", "nemo_rl/algorithms/reward_functions.py", + "nemo_rl/algorithms/single_controller.py", "nemo_rl/algorithms/utils.py", "nemo_rl/algorithms/x_token/__init__.py", "nemo_rl/algorithms/x_token/utils.py", @@ -164,6 +166,7 @@ project-includes = [ "nemo_rl/models/generation/vllm/vllm_backend.py", "nemo_rl/models/huggingface/__init__.py", "nemo_rl/models/megatron/__init__.py", + "nemo_rl/models/megatron/draft/__init__.py", "nemo_rl/models/policy/__init__.py", "nemo_rl/models/policy/interfaces.py", "nemo_rl/models/policy/utils.py", diff --git a/tests/functional/L1_Functional_Tests_SingleController.sh b/tests/functional/L1_Functional_Tests_SingleController.sh new file mode 100755 index 0000000000..2a6fad63ee --- /dev/null +++ b/tests/functional/L1_Functional_Tests_SingleController.sh @@ -0,0 +1,43 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/bin/bash +set -xeuo pipefail # Exit immediately if a command exits with a non-zero status + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +PROJECT_ROOT=$(realpath ${SCRIPT_DIR}/../..) + +cd ${PROJECT_ROOT} + +# run_test [fast] +# - "run_test fast " = always runs (both fast and full modes) +# - "run_test " = only runs in full mode; skipped when FAST=1 +run_test() { + if [[ "$1" == "fast" ]]; then + shift + time "$@" + elif [[ "${FAST:-0}" == "1" ]]; then + echo "FAST: Skipping: $*" + else + time "$@" + fi +} + +run_test fast uv run --no-sync bash ./tests/functional/grpo_dp_single_controller.sh +run_test fast uv run --no-sync bash ./tests/functional/grpo_async_gym_single_controller.sh + +cd ${PROJECT_ROOT}/tests +if compgen -G ".coverage*" > /dev/null; then + coverage combine .coverage* +fi diff --git a/tests/functional/grpo_async_gym_single_controller.sh b/tests/functional/grpo_async_gym_single_controller.sh new file mode 100755 index 0000000000..f4f283bfff --- /dev/null +++ b/tests/functional/grpo_async_gym_single_controller.sh @@ -0,0 +1,112 @@ +#!/bin/bash +# SingleController + NeMo-Gym e2e smoke. Mirrors grpo_async_gym.sh but +# routes everything through the SC path (TransferQueue data plane + +# SingleControllerActor) instead of async_grpo_train. + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) +# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log +CHECKPOINT_DIR=$EXP_DIR/checkpoints +DATA_DIR=$EXP_DIR/data +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR $CHECKPOINT_DIR $DATA_DIR + +# clean up checkpoint directory on exit +trap "rm -rf $CHECKPOINT_DIR" EXIT + +cd $PROJECT_ROOT + +# Follow nemo-gym instructions here to get this data: +# https://docs.nvidia.com/nemo/gym/0.1.0/tutorials/nemo-rl-grpo/setup.html#training-nemo-rl-grpo-setup +cd 3rdparty/Gym-workspace/Gym + +# We need HF_TOKEN to download the data from huggingface +if [[ ! -f env.yaml ]]; then + if [[ -z "${HF_TOKEN:-}" ]]; then + echo "[ERROR] HF_TOKEN is not set" + exit 1 + fi + echo "hf_token: $HF_TOKEN" >> env.yaml +fi + +uv run ng_prepare_data "+config_paths=[resources_servers/workplace_assistant/configs/workplace_assistant.yaml]" \ + +output_dirpath=data/workplace_assistant \ + +mode=train_preparation \ + +should_download=true \ + +data_source=huggingface +cd - + +# This trimming of the workplace assistant dataset is necessary b/c with all the tools the first prompt is >4000 tokens +# which will cause vllm to return nothing on the first prompt and crash RL. Since we want to keep this test short to +# smoke test, we trim all but the first tool +TRAIN_PATH=$DATA_DIR/workplace_assistant_train.jsonl +VALIDATION_PATH=$DATA_DIR/workplace_assistant_validation.jsonl +jq -c '.responses_create_params.tools |= (.[0:1])' 3rdparty/Gym-workspace/Gym/data/workplace_assistant/train.jsonl > $TRAIN_PATH +jq -c '.responses_create_params.tools |= (.[0:1])' 3rdparty/Gym-workspace/Gym/data/workplace_assistant/validation.jsonl > $VALIDATION_PATH + +uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ + $PROJECT_ROOT/examples/run_grpo_single_controller.py \ + --config $PROJECT_ROOT/examples/nemo_gym/grpo_qwen3_30ba3b_instruct.yaml \ + policy.model_name=Qwen/Qwen3-0.6B \ + policy.dtensor_cfg.enabled=false \ + policy.megatron_cfg.enabled=true \ + policy.megatron_cfg.tensor_model_parallel_size=1 \ + policy.megatron_cfg.pipeline_model_parallel_size=1 \ + policy.megatron_cfg.expert_model_parallel_size=1 \ + policy.megatron_cfg.context_parallel_size=1 \ + policy.megatron_cfg.sequence_parallel=false \ + policy.generation.vllm_cfg.tensor_parallel_size=1 \ + policy.generation.vllm_cfg.async_engine=true \ + policy.max_total_sequence_length=512 \ + policy.generation.colocated.enabled=false \ + policy.generation.colocated.resources.num_nodes=1 \ + policy.generation.colocated.resources.gpus_per_node=1 \ + grpo.num_prompts_per_step=4 \ + grpo.num_generations_per_prompt=2 \ + grpo.max_num_steps=10 \ + grpo.val_period=-1 \ + policy.train_global_batch_size=8 \ + policy.train_micro_batch_size=1 \ + cluster.gpus_per_node=2 \ + loss_fn.reference_policy_kl_penalty=0.01 \ + loss_fn.use_importance_sampling_correction=true \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + logger.monitor_gpus=true \ + checkpointing.enabled=false \ + data.train.data_path=$TRAIN_PATH \ + data.validation.data_path=$VALIDATION_PATH \ + ++data_plane.enabled=true \ + ++data_plane.impl=transfer_queue \ + ++data_plane.backend=simple \ + ++data_plane.storage_capacity=1000000 \ + ++data_plane.num_storage_units=2 \ + ++data_plane.claim_meta_poll_interval_s=0.5 \ + ++data_plane.global_segment_size=549755813888 \ + ++data_plane.local_buffer_size=68719476736 \ + ++async_rl.batch_selection_strategy=strict_on_policy \ + ++async_rl.max_weight_staleness_versions=0 \ + ++async_rl.min_prompt_groups_per_batch=4 \ + ++async_rl.max_inflight_prompts=4 \ + ++async_rl.max_buffered_rollouts=4 \ + $@ \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Observed to be between 0.8-1.3 +uv run tests/check_metrics.py $JSON_METRICS \ + 'median(data["train/gen_kl_error"]) < 1.3' \ + 'max(data["train/reward"]) > 0' diff --git a/tests/functional/grpo_dp_single_controller.sh b/tests/functional/grpo_dp_single_controller.sh new file mode 100755 index 0000000000..abd590cc47 --- /dev/null +++ b/tests/functional/grpo_dp_single_controller.sh @@ -0,0 +1,57 @@ +#!/bin/bash +# Lightweight e2e for examples/run_grpo_single_controller.py — exercises +# setup_handle + setup_single_controller_component + SingleControllerActor +# end-to-end. Same shape as tests/functional/grpo_dp_simple.sh (Qwen3-0.6B, +# 2 GPUs, a handful of steps); data_plane.enabled=true is mandatory for SC. + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) +# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR + +cd $PROJECT_ROOT +uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ + $PROJECT_ROOT/examples/run_grpo_single_controller.py \ + policy.model_name=Qwen/Qwen3-0.6B \ + grpo.num_prompts_per_step=2 \ + grpo.num_generations_per_prompt=4 \ + policy.train_global_batch_size=8 \ + policy.train_micro_batch_size=1 \ + cluster.gpus_per_node=2 \ + grpo.max_num_steps=2 \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + logger.monitor_gpus=false \ + checkpointing.enabled=false \ + data_plane.enabled=true \ + data_plane.impl=transfer_queue \ + data_plane.backend=simple \ + async_rl.batch_selection_strategy=strict_on_policy \ + async_rl.max_weight_staleness_versions=0 \ + async_rl.min_prompt_groups_per_batch=2 \ + async_rl.max_inflight_prompts=2 \ + async_rl.max_buffered_rollouts=2 \ + $@ \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +uv run tests/check_metrics.py $JSON_METRICS \ + 'max(data["train/gen_kl_error"]) < 0.002' \ + 'min(data["train/probs_ratio_clamped_min"]) > 0.79' \ + 'max(data["train/probs_ratio_clamped_min"]) < 1.21' \ + 'min(data["train/probs_ratio_clamped_max"]) > 0.79' \ + 'max(data["train/probs_ratio_clamped_max"]) < 1.21' diff --git a/tests/test_suites/llm/grpo-llama3.1-8b-instruct-2n8g-async-1off-single-controller.sh b/tests/test_suites/llm/grpo-llama3.1-8b-instruct-2n8g-async-1off-single-controller.sh new file mode 100755 index 0000000000..e4c254c78f --- /dev/null +++ b/tests/test_suites/llm/grpo-llama3.1-8b-instruct-2n8g-async-1off-single-controller.sh @@ -0,0 +1,42 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=2 +STEPS_PER_RUN=100 +MAX_STEPS=100 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=240 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo_single_controller.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=sc-yukih \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'median(data["train/token_mult_prob_error"]) < 1.1' \ + 'data["train/token_mult_prob_error"]["10"] < 1.1' + + # Clean up checkpoint directory after successful run to save space. + rm -rf "$CKPT_DIR" +fi diff --git a/tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.sh b/tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.sh index 353804958e..298b6e0e9d 100755 --- a/tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.sh +++ b/tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.sh @@ -17,13 +17,14 @@ cd $PROJECT_ROOT uv run examples/run_grpo.py \ --config $CONFIG_PATH \ grpo.max_num_steps=$MAX_STEPS \ + grpo.val_period=-1 \ logger.log_dir=$LOG_DIR \ logger.wandb_enabled=True \ - logger.wandb.project=nemo-rl \ + logger.wandb.project=sc-yukih \ logger.wandb.name=$EXP_NAME \ logger.monitor_gpus=True \ logger.tensorboard_enabled=True \ - checkpointing.enabled=True \ + checkpointing.enabled=false \ checkpointing.checkpoint_dir=$CKPT_DIR \ $@ \ 2>&1 | tee $RUN_LOG diff --git a/tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-megatron-single-controller.sh b/tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-megatron-single-controller.sh new file mode 100755 index 0000000000..978b8893ee --- /dev/null +++ b/tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-megatron-single-controller.sh @@ -0,0 +1,43 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=450 +MAX_STEPS=450 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=240 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo_single_controller.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=sc-yukih \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'median(data["train/token_mult_prob_error"]) < 1.1' \ + 'data["train/token_mult_prob_error"]["450"] < 1.1' \ + 'mean(data["timing/train/total_step_time"], 2) < 25' + + # Clean up checkpoint directory after successful run to save space. + rm -rf "$CKPT_DIR" +fi diff --git a/tests/test_suites/llm/performance/grpo-llama3.1-8b-instruct-2n8g-async-1off.sh b/tests/test_suites/llm/performance/grpo-llama3.1-8b-instruct-2n8g-async-1off.sh index f9f561e3dc..0555651b0a 100755 --- a/tests/test_suites/llm/performance/grpo-llama3.1-8b-instruct-2n8g-async-1off.sh +++ b/tests/test_suites/llm/performance/grpo-llama3.1-8b-instruct-2n8g-async-1off.sh @@ -4,10 +4,10 @@ source $SCRIPT_DIR/common.env # ===== BEGIN CONFIG ===== NUM_NODES=2 -STEPS_PER_RUN=10 -MAX_STEPS=10 +STEPS_PER_RUN=100 +MAX_STEPS=100 NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up -NUM_MINUTES=100 +NUM_MINUTES=240 # ===== END CONFIG ===== exit_if_max_steps_reached @@ -17,13 +17,14 @@ cd $PROJECT_ROOT uv run examples/run_grpo.py \ --config $CONFIG_PATH \ grpo.max_num_steps=$MAX_STEPS \ + grpo.val_period=-1 \ logger.log_dir=$LOG_DIR \ logger.wandb_enabled=True \ - logger.wandb.project=nemo-rl \ + logger.wandb.project=sc-yukih \ logger.wandb.name=$EXP_NAME \ logger.monitor_gpus=True \ logger.tensorboard_enabled=True \ - checkpointing.enabled=True \ + checkpointing.enabled=false \ checkpointing.checkpoint_dir=$CKPT_DIR \ $@ \ 2>&1 | tee $RUN_LOG diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index 2c2fa417dd..3653ae1a90 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -145,6 +145,10 @@ tests/test_suites/llm/grpo-nanov3-30BA3B-2n8g-fsdp2-tq_mooncake.sh tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora-tq_mooncake.sh tests/test_suites/llm/prorlv2-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v2-tq_mooncake.sh +# Single Contoller (SC) +tests/test_suites/llm/grpo-llama3.1-8b-instruct-2n8g-async-1off-single-controller.sh +tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-megatron-single-controller.sh + ######## # DAPO # ######## diff --git a/tests/unit/algorithms/test_async_utils.py b/tests/unit/algorithms/test_async_utils.py index b7cff7eecd..127fc3dd25 100644 --- a/tests/unit/algorithms/test_async_utils.py +++ b/tests/unit/algorithms/test_async_utils.py @@ -33,10 +33,7 @@ AsyncTrajectoryCollector, ReplayBuffer, ) -from nemo_rl.algorithms.async_utils.replay_buffer import ( - ReplayBufferImpl, - ReplayBufferNew, -) +from nemo_rl.algorithms.async_utils.replay_buffer import ReplayBufferImpl from nemo_rl.algorithms.grpo import ( MasterConfig, add_grpo_token_loss_masks_and_generation_logprobs, @@ -920,160 +917,6 @@ def test_replay_buffer_checkpoint_with_torch_save(self): ray.kill(buffer2) -class TestReplayBufferNew: - """Tests for ReplayBufferNew: staleness-window sampling via _evict + sample.""" - - def _make_traj(self, label: str) -> dict: - return {"batch": {"data": label}, "rollout_metrics": {}} - - def _add(self, buf, label: str, weight_version: int): - return ray.get( - buf.add.remote( - self._make_traj(label), - weight_version=weight_version, - target_weight_version=0, # unused in ReplayBufferNew - ) - ) - - def _sample(self, buf, num_groups: int, trainer_version: int): - return ray.get( - buf.sample.remote( - num_prompt_groups=num_groups, - current_weight_version=trainer_version, - max_age_steps=0, # unused in ReplayBufferNew - ) - ) - - # ------------------------------------------------------------------ - # Construction - # ------------------------------------------------------------------ - - def test_invalid_max_staleness_raises(self): - with pytest.raises(Exception): - buf = ReplayBufferNew.remote(max_size=10, max_staleness=-1) - ray.get(buf.size.remote()) - - # ------------------------------------------------------------------ - # _evict (via sample) - # ------------------------------------------------------------------ - - def test_stale_rows_evicted_before_sampling(self): - """Rows with age > max_staleness are removed before sample() selects.""" - buf = ReplayBufferNew.remote(max_size=10, max_staleness=2) - # age at trainer=4: gen_v=1 → 3 > 2 (stale), gen_v=3 → 1 ≤ 2 (valid) - self._add(buf, "stale", weight_version=1) - self._add(buf, "fresh", weight_version=3) - - result = self._sample(buf, num_groups=1, trainer_version=4) - - assert result is not None - assert result["trajectories"][0]["batch"]["data"] == "fresh" - assert ray.get(buf.size.remote()) == 0 # stale row also gone - ray.kill(buf) - - def test_all_stale_returns_none(self): - """sample() returns None when all rows are evicted as stale.""" - buf = ReplayBufferNew.remote(max_size=10, max_staleness=1) - self._add(buf, "a", weight_version=0) - self._add(buf, "b", weight_version=1) - - # trainer=5: both ages > 1 - result = self._sample(buf, num_groups=1, trainer_version=5) - - assert result is None - assert ray.get(buf.size.remote()) == 0 - ray.kill(buf) - - def test_eviction_frees_capacity(self): - """Evicting a stale row allows a subsequent add() to succeed.""" - buf = ReplayBufferNew.remote(max_size=1, max_staleness=1) - self._add(buf, "x", weight_version=1) - assert self._add(buf, "x", weight_version=1) == "full" - - # sample() at trainer=5 evicts the stale row (age 4 > 1) - self._sample(buf, num_groups=1, trainer_version=5) - - assert self._add(buf, "y", weight_version=4) == "success" - ray.kill(buf) - - def test_within_window_not_evicted(self): - """Rows whose age is within max_staleness are not evicted.""" - buf = ReplayBufferNew.remote(max_size=10, max_staleness=3) - self._add(buf, "x", weight_version=4) - - # trainer=6: age = 6 - 4 = 2 ≤ 3 → should survive - # should return None since there is only 1 row - result = self._sample(buf, num_groups=2, trainer_version=6) - assert result is None - - # this sample should still be there - assert ray.get(buf.size.remote()) == 1 - ray.kill(buf) - - # ------------------------------------------------------------------ - # sample() - # ------------------------------------------------------------------ - - @pytest.mark.parametrize("sample_freshest_first", [True, False]) - def test_sample_freshest_first(self, sample_freshest_first): - """sample() returns the freshest trajectories first.""" - buf = ReplayBufferNew.remote( - max_size=10, max_staleness=5, sample_freshest_first=sample_freshest_first - ) - for gen_v in [3, 4, 5]: - self._add(buf, f"v{gen_v}", weight_version=gen_v) - - result = self._sample(buf, num_groups=2, trainer_version=6) - - assert result is not None - data = [t["batch"]["data"] for t in result["trajectories"]] - if sample_freshest_first: - assert data == ["v5", "v4"] - else: - assert data == ["v3", "v4"] - ray.kill(buf) - - def test_sample_returns_none_when_insufficient(self): - """sample() returns None when fewer rows than requested remain after eviction.""" - buf = ReplayBufferNew.remote(max_size=10, max_staleness=5) - self._add(buf, "only", weight_version=1) - - result = self._sample(buf, num_groups=3, trainer_version=2) - - assert result is None - ray.kill(buf) - - def test_sample_returns_none_on_empty_buffer(self): - buf = ReplayBufferNew.remote(max_size=10, max_staleness=5) - result = self._sample(buf, num_groups=1, trainer_version=1) - assert result is None - ray.kill(buf) - - def test_sample_avg_trajectory_age(self): - """avg_trajectory_age is computed from the sampled generation versions.""" - buf = ReplayBufferNew.remote(max_size=10, max_staleness=5) - # freshest first: gen 8 (age 2), gen 6 (age 4) → avg = 3.0 - for gen_v in [6, 8]: - self._add(buf, f"v{gen_v}", weight_version=gen_v) - - result = self._sample(buf, num_groups=2, trainer_version=10) - - assert result is not None - assert abs(result["avg_trajectory_age"] - 3.0) < 1e-6 - ray.kill(buf) - - def test_sample_consumes_selected_rows(self): - """Rows returned by sample() are removed from the buffer.""" - buf = ReplayBufferNew.remote(max_size=10, max_staleness=5) - for gen_v in [1, 2, 3]: - self._add(buf, f"v{gen_v}", weight_version=gen_v) - - self._sample(buf, num_groups=2, trainer_version=4) - - assert ray.get(buf.size.remote()) == 1 - ray.kill(buf) - - class TestAsyncTrajectoryCollector: """Test cases for AsyncTrajectoryCollector.""" diff --git a/tests/unit/data_plane/test_kvbatchmeta.py b/tests/unit/data_plane/test_kvbatchmeta.py index 4774c44f1e..a8dc3bc822 100644 --- a/tests/unit/data_plane/test_kvbatchmeta.py +++ b/tests/unit/data_plane/test_kvbatchmeta.py @@ -147,6 +147,96 @@ def test_tags_travel_with_subset_slice_concat(): assert joined.tags == m.tags +def test_drop_removes_indices_and_keeps_remaining_rows(): + """Rows at ``indices`` disappear; survivors keep their tags/seqlens.""" + m = KVBatchMeta( + partition_id="p", + task_name="t", + sample_ids=["a", "b", "c", "d"], + sequence_lengths=[1, 2, 3, 4], + tags=[{"i": 0}, {"i": 1}, {"i": 2}, {"i": 3}], + ) + remaining = m.drop([1, 3]) + assert remaining.sample_ids == ["a", "c"] + assert remaining.sequence_lengths == [1, 3] + assert remaining.tags == [{"i": 0}, {"i": 2}] + + +def test_drop_empty_indices_returns_identical_content(): + """``drop([])`` returns a fresh copy with the same rows.""" + m = KVBatchMeta( + partition_id="p", + task_name="t", + sample_ids=["a", "b"], + sequence_lengths=[1, 2], + ) + out = m.drop([]) + assert out.sample_ids == m.sample_ids + assert out.sequence_lengths == m.sequence_lengths + assert out is not m + + +def test_drop_all_returns_none(): + """Dropping every row returns ``None`` (callers chain into ``is None`` branches).""" + m = KVBatchMeta( + partition_id="p", + task_name="t", + sample_ids=["a", "b"], + tags=[{"i": 0}, {"i": 1}], + ) + assert m.drop([0, 1]) is None + + +def test_drop_indices_can_be_unordered_or_duplicated(): + """Duplicate / out-of-order indices collapse via the internal set.""" + m = KVBatchMeta( + partition_id="p", + task_name="t", + sample_ids=["a", "b", "c"], + ) + assert m.drop([2, 0, 0]).sample_ids == ["b"] + + +def test_with_fields_appends_and_dedupes_preserving_order(): + """Merges into ``fields`` with order-preserving dedup.""" + m = KVBatchMeta( + partition_id="p", + task_name="t", + sample_ids=["a"], + fields=["input_ids", "advantages"], + ) + out = m.with_fields(["advantages", "logprobs"]) + assert out.fields == ["input_ids", "advantages", "logprobs"] + + +def test_with_fields_initializes_from_none(): + """Seeds ``fields`` from the argument when previously ``None``.""" + m = KVBatchMeta(partition_id="p", task_name="t", sample_ids=["a"]) + out = m.with_fields(["x", "y", "x"]) + assert out.fields == ["x", "y"] + + +def test_with_fields_returns_independent_copy(): + """Returned meta does not share mutable references with the source.""" + m = KVBatchMeta( + partition_id="p", + task_name="t", + sample_ids=["a", "b"], + sequence_lengths=[1, 2], + extra_info={"step": 0}, + tags=[{"x": 1}, {"x": 2}], + ) + out = m.with_fields(["z"]) + out.sample_ids.append("c") + out.sequence_lengths.append(3) # type: ignore[union-attr] + out.extra_info["step"] = 99 + out.tags[0]["x"] = 999 # type: ignore[index] + assert m.sample_ids == ["a", "b"] + assert m.sequence_lengths == [1, 2] + assert m.extra_info == {"step": 0} + assert m.tags == [{"x": 1}, {"x": 2}] + + def test_tags_none_when_either_side_missing_in_concat(): """``concat`` drops tags if either side has none — symmetric with the ``sequence_lengths`` behavior.""" diff --git a/tests/unit/experience/test_rollout_manager.py b/tests/unit/experience/test_rollout_manager.py new file mode 100644 index 0000000000..ee1f9921af --- /dev/null +++ b/tests/unit/experience/test_rollout_manager.py @@ -0,0 +1,795 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for RolloutManager. + +Two groups: + +* TestGenerateAndPushFlow — lightweight unit tests for the reserve→run→commit + flow in generate_and_push (no Ray/vLLM; fakes for impl + tq_buffer). +* AsyncRollout / AsyncNemoGymRollout tests — vLLM/Ray-backed end-to-end checks + for the underlying run_rollout paths (AsyncRolloutImpl / AsyncNemoGymRolloutImpl). +""" + +from __future__ import annotations + +import asyncio +import json +import tempfile +import uuid +from copy import deepcopy + +import pytest +import torch + +from nemo_rl.data.collate_fn import rl_collate_fn +from nemo_rl.data.datasets.response_datasets import NemoGymDataset +from nemo_rl.data.interfaces import DatumSpec +from nemo_rl.data.processors import nemo_gym_data_processor +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.experience.interfaces import Completion, PromptGroupRecord +from nemo_rl.experience.rollout_manager import RolloutManager +from nemo_rl.experience.rollouts import ( + run_async_multi_turn_rollout, + run_async_nemo_gym_rollout, +) + +# Fixtures shared with the heavyweight rollout tests. +from tests.unit.environments.test_nemo_gym import ( + cluster, # noqa: F401 + nemo_gym, # noqa: F401 + nemo_gym_sanity_test_data, # noqa: F401 + nemo_gym_tokenizer, # noqa: F401 + nemo_gym_vllm_generation, # noqa: F401 +) +from tests.unit.experience.test_rollouts import ( + initial_multi_step_calculator_batch, # noqa: F401 + multi_step_calculator_environment, # noqa: F401 + multi_step_setup_vllm_async, # noqa: F401 + rollout_cluster, # noqa: F401 + rollout_tokenizer, # noqa: F401 +) +from tests.unit.test_envs import MultiStepCalcMetadata + + +def _run(coro): + return asyncio.run(coro) + + +class _FakeBuffer: + """Minimal TQReplayBuffer stand-in that records reserve/commit calls.""" + + def __init__(self) -> None: + self.reserve_calls: list[int] = [] # weight_versions passed to reserve + self.commit_calls: list[tuple[str, object, int, int]] = [] + # reserve(weight_version=X) -> group_id; commit fills the slot. + self._slots: list[str] = [] + + def reserve(self, *, weight_version: int, group_id: str | None = None) -> str: + if group_id is None: + group_id = str(uuid.uuid4()) + self.reserve_calls.append(weight_version) + self._slots.append(group_id) + return group_id + + async def commit( + self, + group_id: str, + record, + start_weight_version: int, + end_weight_version: int, + ): + self.commit_calls.append( + (group_id, record, start_weight_version, end_weight_version) + ) + return record + + +class _FakeImpl: + """Stand-in for AsyncRolloutImpl that returns a sentinel record.""" + + def __init__(self, record="sentinel-record", on_run=None) -> None: + self._record = record + self._on_run = on_run + + async def run_rollout(self, input_sample): + if self._on_run is not None: + await self._on_run(input_sample) + return self._record + + +def _make_manager(buffer: _FakeBuffer, impl: _FakeImpl) -> RolloutManager: + """Build a RolloutManager without firing the real __init__.""" + mgr = object.__new__(RolloutManager) + mgr._impl = impl + mgr._tokenizer = None + mgr._num_generations_per_prompt = 1 + mgr._tq_buffer = buffer + mgr._weight_version = 0 + return mgr + + +class TestGenerateAndPushFlow: + def test_reserves_then_runs_then_commits(self): + events: list[str] = [] + buf = _FakeBuffer() + + async def _track_run(_sample): + events.append("run") + + impl = _FakeImpl(record="r0", on_run=_track_run) + mgr = _make_manager(buf, impl) + + # Wrap reserve/commit to log ordering. + original_reserve = buf.reserve + original_commit = buf.commit + + def _logged_reserve(**kwargs): + events.append("reserve") + return original_reserve(**kwargs) + + async def _logged_commit(*args, **kwargs): + events.append("commit") + return await original_commit(*args, **kwargs) + + buf.reserve = _logged_reserve # type: ignore[method-assign] + buf.commit = _logged_commit # type: ignore[method-assign] + + _run(mgr.generate_and_push({"prompt": "p"})) + + assert events == ["reserve", "run", "commit"] + assert buf.reserve_calls == [0] + assert len(buf.commit_calls) == 1 + gid, record, start_v, end_v = buf.commit_calls[0] + assert gid in buf._slots + assert record == "r0" + assert start_v == 0 + assert end_v == 0 + + def test_start_weight_version_pinned_at_reserve_time(self): + """If set_weight_version is called mid-rollout, start != end.""" + buf = _FakeBuffer() + + async def _bump_weight_mid_rollout(_sample): + # Simulate a sync_weights bump during the rollout. + mgr.set_weight_version(5) + + impl = _FakeImpl(record="r0", on_run=_bump_weight_mid_rollout) + mgr = _make_manager(buf, impl) + mgr.set_weight_version(3) + + _run(mgr.generate_and_push({"prompt": "p"})) + + # reserve happened before run_rollout → captured weight 3. + assert buf.reserve_calls == [3] + # commit's start is the same dispatch-time value; end reflects the post-rollout weight. + _, _, start_v, end_v = buf.commit_calls[0] + assert start_v == 3 + assert end_v == 5 + + def test_no_weight_change_means_start_equals_end(self): + buf = _FakeBuffer() + impl = _FakeImpl(record="r0") + mgr = _make_manager(buf, impl) + mgr.set_weight_version(7) + + _run(mgr.generate_and_push({"prompt": "p"})) + + _, _, start_v, end_v = buf.commit_calls[0] + assert start_v == 7 + assert end_v == 7 + + def test_concurrent_dispatch_preserves_reserve_order(self): + """Two concurrent generate_and_push calls must reserve before either commits. + + The contract: reserve order == dispatch order, even if rollouts finish + out of order. Slot order in the buffer reflects the order reserve was + called (not the order run_rollout completed). + """ + buf = _FakeBuffer() + + # First call's rollout blocks until second call has reserved. + first_reserved = asyncio.Event() + second_reserved = asyncio.Event() + + async def _first_run(_sample): + first_reserved.set() + await second_reserved.wait() + + async def _second_run(_sample): + # Second is dispatched only after first reserves, so by the time + # second's reserve fires, slots[0] == first's gid. + second_reserved.set() + + first_impl = _FakeImpl(record="r0", on_run=_first_run) + second_impl = _FakeImpl(record="r1", on_run=_second_run) + + first_mgr = _make_manager(buf, first_impl) + # Share buffer across two managers (mimics two dispatches from one pump). + second_mgr = object.__new__(RolloutManager) + second_mgr._impl = second_impl + second_mgr._tokenizer = None + second_mgr._num_generations_per_prompt = 1 + second_mgr._tq_buffer = buf + second_mgr._weight_version = 0 + + async def _drive(): + t1 = asyncio.create_task(first_mgr.generate_and_push({"prompt": "p1"})) + # Wait until first has reserved before kicking off second so the + # reserve ordering is deterministic. + await first_reserved.wait() + t2 = asyncio.create_task(second_mgr.generate_and_push({"prompt": "p2"})) + await asyncio.gather(t1, t2) + + _run(_drive()) + + # Slots in buffer == reserve order. + first_gid, second_gid = buf._slots + # Commit recorded both, in either order, but each maps to its own gid. + commit_gids = [c[0] for c in buf.commit_calls] + assert set(commit_gids) == {first_gid, second_gid} + assert buf.reserve_calls == [0, 0] + + def test_requires_tq_buffer(self): + mgr = _make_manager(_FakeBuffer(), _FakeImpl()) + mgr._tq_buffer = None + with pytest.raises(AssertionError, match="tq_buffer"): + _run(mgr.generate_and_push({"prompt": "p"})) + + +# --------------------------------------------------------------------------- +# Tests for RolloutManager +# --------------------------------------------------------------------------- + + +def test_rollout_manager_raises_without_impl_params(): + """RolloutManager raises AssertionError when required params are missing.""" + common = { + "tokenizer": None, + "env_handles": {}, + "num_generations_per_prompt": 1, + "max_seq_len": 1, + } + + with pytest.raises(AssertionError, match="num_generations_per_prompt must be >= 1"): + updated_common = common.copy() + updated_common["num_generations_per_prompt"] = 0 + RolloutManager(**updated_common, use_nemo_gym=False) + + with pytest.raises(AssertionError, match="policy_generation is required"): + RolloutManager(**common, use_nemo_gym=False) + + with pytest.raises(AssertionError, match="generation_config is required"): + RolloutManager(**common, use_nemo_gym=True) + + +# --------------------------------------------------------------------------- +# Tests for AsyncRolloutManager (native async path) +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="function") +def single_multi_step_calculator_input_sample(rollout_tokenizer): # noqa: F811 + """Returns a single DatumSpec prompt dict (problem 0) for AsyncRolloutManager tests.""" + problem_text = "(5 + 3) * 2" + expected_answer = 16.0 + max_steps = 5 + + tool_instructions = ( + "You have a calculator tool. To use it, respond with:\n" + "'[operand1, operand2, operation_name]'\n" + "The valid 'operation_name' values are exactly: 'sum', 'diff', 'prod', 'div'.\n" + "Example: [5, 3, sum]\n" + "You will receive the result of your calculation as ...\n" + "Use this result to make the next calculation if needed.\n" + "IMPORTANT: Only perform one calculation step (one tool call) before waiting for a result and making a new tool call.\n" + "IMPORTANT: Do not perform any other calculations or operations aside from the tool call and result. Doing so will result in failure.\n" + "To give the final answer, just output the number. numbers inside of don't count, so output just the final number yourself outside of this.\n" + "Example full output: [2, 4, sum]\n6.0\n[6, 6, diff]\n0.0 0\n(note how you have to output the final 0 outside of the tags)" + "------\n" + f"Solve: {problem_text}" + ) + + initial_prompt_content = rollout_tokenizer.apply_chat_template( + [{"role": "user", "content": tool_instructions}], + tokenize=False, + add_system_prompt=False, + add_generation_prompt=True, + add_special_tokens=False, + ) + tokenized_prompt = rollout_tokenizer( + initial_prompt_content, return_tensors="pt", add_special_tokens=False + )["input_ids"][0] + message_log = [ + { + "role": "user", + "content": initial_prompt_content, + "token_ids": tokenized_prompt, + } + ] + metadata = MultiStepCalcMetadata( + problem=problem_text, + expected_final_answer=expected_answer, + max_steps=max_steps, + current_step=0, + ) + return { + "message_log": message_log, + "extra_env_info": metadata, + "task_name": "multi_step_calculator_game", + "stop_strings": [""], + "idx": 0, + } + + +@pytest.mark.vllm +def test_async_rollout_manager( + multi_step_setup_vllm_async, # noqa: F811 + single_multi_step_calculator_input_sample, +): + """Standalone test for AsyncRolloutManager. + + Given 1 prompt with num_generations_per_prompt=N, asserts: + - output is a PromptGroupRecord with N Completion objects + - each Completion has a reward (float) and a non-empty message_log + - rollout_metrics has the expected keys with correct types + - completions hold independent (not aliased) message_log objects + """ + vllm_generation, tokenizer, env_handles, _, _ = multi_step_setup_vllm_async + input_sample = single_multi_step_calculator_input_sample + num_generations = 2 + max_seq_len = 1024 + max_rollout_turns = input_sample["extra_env_info"]["max_steps"] + 1 + + manager = RolloutManager( + use_nemo_gym=False, + tokenizer=tokenizer, + env_handles=env_handles, + num_generations_per_prompt=num_generations, + max_seq_len=max_seq_len, + max_rollout_turns=max_rollout_turns, + policy_generation=vllm_generation, + ) + + vllm_generation.prepare_for_generation() + record = asyncio.run(manager.run_rollout(input_sample)) + vllm_generation.finish_generation() + + assert isinstance(record, PromptGroupRecord) + assert len(record.completions) == num_generations, ( + f"Expected {num_generations} completions, got {len(record.completions)}" + ) + assert record.prompt_idx == input_sample["idx"] + + for i, completion in enumerate(record.completions): + assert isinstance(completion, Completion) + + # 1. message_log length + assert len(completion.message_log) >= 4, ( + f"Completion {i}: expected >= 4 messages, got {len(completion.message_log)}" + ) + + # 2. last assistant content + last_assistant = next( + (m for m in reversed(completion.message_log) if m["role"] == "assistant"), + None, + ) + assert last_assistant is not None, f"Completion {i}: no assistant message found" + assert last_assistant["content"].strip() == "16", ( + f"Completion {i}: last assistant content {last_assistant['content']!r} != '16'" + ) + + # 3. reward + assert completion.reward == 1.0, ( + f"Completion {i}: reward {completion.reward} != 1.0" + ) + + # completions must be independent objects + assert record.completions[0].message_log is not record.completions[1].message_log + + +@pytest.mark.vllm +def test_async_rollout_manager_truncation( + multi_step_setup_vllm_async, # noqa: F811 + single_multi_step_calculator_input_sample, +): + """Small max_seq_len forces truncation and truncation_rate=1.0.""" + vllm_generation, tokenizer, env_handles, _, _ = multi_step_setup_vllm_async + input_sample = single_multi_step_calculator_input_sample + num_generations = 2 + max_seq_len = 290 + max_rollout_turns = input_sample["extra_env_info"]["max_steps"] + 1 + + manager = RolloutManager( + use_nemo_gym=False, + tokenizer=tokenizer, + env_handles=env_handles, + num_generations_per_prompt=num_generations, + max_seq_len=max_seq_len, + max_rollout_turns=max_rollout_turns, + policy_generation=vllm_generation, + ) + vllm_generation.prepare_for_generation() + record = asyncio.run(manager.run_rollout(input_sample)) + vllm_generation.finish_generation() + + assert len(record.completions) == num_generations + assert all(c.truncated for c in record.completions) + assert record.rollout_metrics["truncation_rate"] == 1.0 + assert record.rollout_metrics["natural_termination_rate"] == 0.0 + + +@pytest.mark.vllm +def test_async_rollout_manager_matches_original( + multi_step_setup_vllm_async, # noqa: F811 + single_multi_step_calculator_input_sample, +): + """Comparison test: AsyncRolloutManager output is structurally equivalent to the original. + + Calls run_async_multi_turn_rollout with a batch of N identical prompts, + then calls AsyncRolloutManager with 1 prompt and N generations. + Asserts that both produce N results with matching message-log depth, rewards, + and rollout_metrics numeric values. + + TODO: remove this test together with run_async_multi_turn_rollout when the legacy path is deleted. + """ + vllm_generation, tokenizer, env_handles, _, _ = multi_step_setup_vllm_async + input_sample = single_multi_step_calculator_input_sample + num_generations = 2 + max_seq_len = 1024 + max_rollout_turns = input_sample["extra_env_info"]["max_steps"] + 1 + + # Build a batch of N identical prompts for the original function + batch = BatchedDataDict( + { + "message_log": [ + deepcopy(input_sample["message_log"]) for _ in range(num_generations) + ], + "extra_env_info": [ + deepcopy(input_sample["extra_env_info"]) for _ in range(num_generations) + ], + "task_name": [input_sample["task_name"]] * num_generations, + "stop_strings": [input_sample["stop_strings"]] * num_generations, + "idx": list(range(num_generations)), + "loss_multiplier": [1.0] * num_generations, + } + ) + + vllm_generation.prepare_for_generation() + original_batch, original_metrics = run_async_multi_turn_rollout( + policy_generation=vllm_generation, + input_batch=batch, + tokenizer=tokenizer, + task_to_env=env_handles, + max_seq_len=max_seq_len, + max_rollout_turns=max_rollout_turns, + ) + + manager = RolloutManager( + use_nemo_gym=False, + tokenizer=tokenizer, + env_handles=env_handles, + num_generations_per_prompt=num_generations, + max_seq_len=max_seq_len, + max_rollout_turns=max_rollout_turns, + policy_generation=vllm_generation, + ) + record = asyncio.run(manager.run_rollout(input_sample)) + vllm_generation.finish_generation() + + # Both should produce N results + assert len(original_batch["message_log"]) == num_generations + assert len(record.completions) == num_generations + + for i in range(num_generations): + orig_msg_log = original_batch["message_log"][i] + new_msg_log = record.completions[i].message_log + + # 1. message_log length matches + assert len(orig_msg_log) == len(new_msg_log), ( + f"Completion {i}: message_log length {len(new_msg_log)} != original {len(orig_msg_log)}" + ) + + # 2. last assistant content matches + def _last_assistant_content(msg_log): + for m in reversed(msg_log): + if m["role"] == "assistant": + return m.get("content", "") + return "" + + orig_last = _last_assistant_content(orig_msg_log) + new_last = _last_assistant_content(new_msg_log) + assert orig_last == new_last, ( + f"Completion {i}: last assistant content mismatch\n" + f" original: {orig_last!r}\n" + f" manager: {new_last!r}" + ) + + # 3. reward matches + orig_reward = original_batch["total_reward"][i].item() + new_reward = record.completions[i].reward + assert orig_reward == new_reward, ( + f"Completion {i}: reward mismatch — original {orig_reward}, manager {new_reward}" + ) + + # 4. rollout_metrics numeric values match (timing and histogram fields are excluded). + # The new impl emits slash-style keys (X/mean, X/max, X/min) via _calculate_single_metric; + # translate the legacy prefix-style keys before comparing. + def _translate_legacy_key(key: str) -> str: + if key == "avg_turns_per_sample": + return "turns_per_sample/mean" + if key == "max_turns_reached_rate": + return key + for prefix, suffix in (("mean_", "/mean"), ("max_", "/max"), ("min_", "/min")): + if key.startswith(prefix): + return f"{key[len(prefix) :]}{suffix}" + return key + + new_metrics = record.rollout_metrics + for key in original_metrics.keys(): + if key.startswith("timing/") or key.startswith("histogram/"): + continue + + new_key = _translate_legacy_key(key) + assert new_key in new_metrics, ( + f"rollout_metrics[{new_key!r}] missing from manager" + ) + + orig_val = original_metrics[key] + new_val = new_metrics[new_key] + + assert type(orig_val) == type(new_val), ( + f"rollout_metrics[{key!r}] type mismatch: {type(orig_val)} != {type(new_val)}" + ) + if not isinstance(orig_val, (bool, int, float)): + continue + + assert orig_val == pytest.approx(new_val), ( + f"rollout_metrics[{key!r}] mismatch — original {orig_val}, manager {new_val}" + ) + + +# --------------------------------------------------------------------------- +# Tests for AsyncNemoGymRolloutManager +# --------------------------------------------------------------------------- + + +@pytest.mark.nemo_gym +def test_async_nemo_gym_rollout_manager( + nemo_gym, # noqa: F811 + nemo_gym_vllm_generation, # noqa: F811 + nemo_gym_sanity_test_data, # noqa: F811 + nemo_gym_tokenizer, # noqa: F811 +): + """Standalone test for AsyncNemoGymRolloutManager. + + Given 1 prompt with num_generations_per_prompt=N, asserts: + - output is a PromptGroupRecord with N Completion objects + - each Completion has a reward (float) and a non-empty message_log + - completions hold independent message_log objects + + If the result here does not match, please check the following: + 1. Test data changed: re-run test_nemo_gym_sanity (tests/unit/environments/test_nemo_gym.py) + and use _write_actual_test_data output to refresh test_nemo_gym_sanity.json. + 2. Logic changed: inspect recent changes to AsyncNemoGymRolloutManager or the gym env. + """ + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for data in nemo_gym_sanity_test_data["input"]: + f.write(json.dumps(data) + "\n") + data_path = f.name + + dataset = NemoGymDataset(data_path) + examples = [ + nemo_gym_data_processor(dataset.dataset[idx], None, None, None, idx) + for idx in range(len(dataset.dataset)) + ] + input_batch: BatchedDataDict[DatumSpec] = rl_collate_fn(examples) + + # Use only the first prompt + single_prompt = { + "message_log": input_batch["message_log"][0], + "extra_env_info": input_batch["extra_env_info"][0], + "task_name": "nemo_gym", + "idx": 0, + "loss_multiplier": float(input_batch["loss_multiplier"][0]), + } + num_generations = 2 + + manager = RolloutManager( + use_nemo_gym=True, + tokenizer=nemo_gym_tokenizer, + env_handles={"nemo_gym": nemo_gym}, + num_generations_per_prompt=num_generations, + max_seq_len=nemo_gym_vllm_generation.cfg["vllm_cfg"]["max_model_len"], + generation_config=nemo_gym_vllm_generation.cfg, + ) + record = asyncio.run(manager.run_rollout(single_prompt)) + + assert isinstance(record, PromptGroupRecord) + assert len(record.completions) == num_generations, ( + f"Expected {num_generations} completions, got {len(record.completions)}" + ) + assert record.prompt_idx == 0 + + for i, completion in enumerate(record.completions): + assert isinstance(completion, Completion) + + # 1. message_log length + assert len(completion.message_log) == 2, ( + f"Completion {i}: expected 2 messages, got {len(completion.message_log)}" + ) + + # 2. last assistant token_ids + last_assistant = next( + (m for m in reversed(completion.message_log) if m["role"] == "assistant"), + None, + ) + assert last_assistant is not None, f"Completion {i}: no assistant message found" + assert torch.equal( + last_assistant["token_ids"], + torch.tensor([151667, 198, 32313, 11, 1077]), + ), ( + f"Completion {i}: last assistant token_ids {last_assistant['token_ids'].tolist()} " + f"!= [151667, 198, 32313, 11, 1077]" + ) + + # 3. reward + assert completion.reward == 0.0, ( + f"Completion {i}: reward {completion.reward} != 0.0" + ) + + # completions must be independent objects + assert record.completions[0].message_log is not record.completions[1].message_log + + +@pytest.mark.nemo_gym +def test_async_nemo_gym_rollout_manager_matches_original( + nemo_gym, # noqa: F811 + nemo_gym_vllm_generation, # noqa: F811 + nemo_gym_sanity_test_data, # noqa: F811 + nemo_gym_tokenizer, # noqa: F811 +): + """Comparison test: AsyncNemoGymRolloutManager output is structurally equivalent to the original. + + Calls run_async_nemo_gym_rollout with a batch of N identical rows, + then calls AsyncNemoGymRolloutManager with 1 prompt, N generations. + Asserts that both produce N results and rewards are in the same numeric domain. + + TODO: remove this test together with run_async_nemo_gym_rollout when the legacy path is deleted. + """ + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for data in nemo_gym_sanity_test_data["input"]: + f.write(json.dumps(data) + "\n") + data_path = f.name + + dataset = NemoGymDataset(data_path) + examples = [ + nemo_gym_data_processor(dataset.dataset[idx], None, None, None, idx) + for idx in range(len(dataset.dataset)) + ] + input_batch: BatchedDataDict[DatumSpec] = rl_collate_fn(examples) + + num_generations = 2 + single_prompt = { + "message_log": input_batch["message_log"][0], + "extra_env_info": input_batch["extra_env_info"][0], + "task_name": "nemo_gym", + "idx": 0, + "loss_multiplier": float(input_batch["loss_multiplier"][0]), + } + + # Build a batch of N identical rows for the original function + repeated_batch = BatchedDataDict( + { + "message_log": [ + deepcopy(input_batch["message_log"][0]) for _ in range(num_generations) + ], + "extra_env_info": [ + deepcopy(input_batch["extra_env_info"][0]) + for _ in range(num_generations) + ], + "loss_multiplier": input_batch["loss_multiplier"][0:1].repeat( + num_generations + ), + "idx": list(range(num_generations)), + "task_name": ["nemo_gym"] * num_generations, + } + ) + + original_result = run_async_nemo_gym_rollout( + policy_generation=nemo_gym_vllm_generation, + input_batch=repeated_batch, + tokenizer=nemo_gym_tokenizer, + task_to_env={"nemo_gym": nemo_gym}, + generation_config=nemo_gym_vllm_generation.cfg, + max_seq_len=nemo_gym_vllm_generation.cfg["vllm_cfg"]["max_model_len"], + max_rollout_turns=None, + ) + + manager = RolloutManager( + use_nemo_gym=True, + tokenizer=nemo_gym_tokenizer, + env_handles={"nemo_gym": nemo_gym}, + num_generations_per_prompt=num_generations, + max_seq_len=nemo_gym_vllm_generation.cfg["vllm_cfg"]["max_model_len"], + generation_config=nemo_gym_vllm_generation.cfg, + ) + record = asyncio.run(manager.run_rollout(single_prompt)) + + # Both should produce N completions + assert len(original_result.final_batch["message_log"]) == num_generations + assert len(record.completions) == num_generations + + for i in range(num_generations): + orig_msg_log = original_result.final_batch["message_log"][i] + new_msg_log = record.completions[i].message_log + + # 1. message_log length matches + assert len(orig_msg_log) == len(new_msg_log), ( + f"Completion {i}: message_log length {len(new_msg_log)} != original {len(orig_msg_log)}" + ) + + # 2. last assistant token_ids match + def _last_assistant_token_ids(msg_log): + for m in reversed(msg_log): + if m["role"] == "assistant": + return m.get("token_ids") + return None + + orig_token_ids = _last_assistant_token_ids(orig_msg_log) + new_token_ids = _last_assistant_token_ids(new_msg_log) + assert orig_token_ids is not None, ( + f"Completion {i}: no assistant message in original" + ) + assert new_token_ids is not None, ( + f"Completion {i}: no assistant message in manager" + ) + assert torch.equal(orig_token_ids, new_token_ids), ( + f"Completion {i}: last assistant token_ids mismatch\n" + f" original: {orig_token_ids.tolist()}\n" + f" manager: {new_token_ids.tolist()}" + ) + + # 3. reward matches + orig_reward = original_result.final_batch["total_reward"][i].item() + new_reward = record.completions[i].reward + assert orig_reward == new_reward, ( + f"Completion {i}: reward mismatch — original {orig_reward}, manager {new_reward}" + ) + + # 4. rollout_metrics numeric values match (timing and Table fields are excluded) + orig_metrics = original_result.rollout_metrics + new_metrics = record.rollout_metrics + for key in orig_metrics.keys(): + # Skip timing and full_result fields + if key.startswith("timing/") or key.endswith("/full_result"): + continue + + # Check that the key is present in the new metrics + assert key in new_metrics, f"rollout_metrics[{key!r}] missing from manager" + + orig_val = orig_metrics[key] + new_val = new_metrics[key] + + # Skip non-numeric fields + assert type(orig_val) == type(new_val), ( + f"rollout_metrics[{key!r}] type mismatch: {type(orig_val)} != {type(new_val)}" + ) + if not isinstance(orig_val, (bool, int, float)): + continue + + # Check equal + assert orig_val == pytest.approx(new_val), ( + f"rollout_metrics[{key!r}] mismatch — original {orig_val}, manager {new_val}" + ) diff --git a/tests/unit/experience/test_rollouts.py b/tests/unit/experience/test_rollouts.py index 5bd124d935..91cf854cde 100644 --- a/tests/unit/experience/test_rollouts.py +++ b/tests/unit/experience/test_rollouts.py @@ -37,8 +37,6 @@ SlidingPuzzleGameLogic, SlidingPuzzleMetadata, ) -from nemo_rl.experience.interfaces import Completion, PromptGroupRecord -from nemo_rl.experience.rollout_manager import RolloutManager from nemo_rl.experience.rollouts import ( _calculate_single_metric, generate_responses_async, @@ -1033,550 +1031,3 @@ def _standardize(d: dict) -> dict: 1. In nemo_rl/experience/rollouts.py::run_async_nemo_gym_rollout, the sampling params are passed appropriately 2. In nemo_rl/models/generation/vllm/vllm_worker_async.py::VllmAsyncGenerationWorker::_setup_vllm_server::create_chat_completion, the sampling params (like top_k) are set as appropriate """ - - -# --------------------------------------------------------------------------- -# Tests for RolloutManager -# --------------------------------------------------------------------------- - - -def test_rollout_manager_raises_without_impl_params(): - """RolloutManager raises AssertionError when required params are missing.""" - common = { - "tokenizer": None, - "task_to_env": {}, - "num_generations_per_prompt": 1, - "max_seq_len": 1, - } - - with pytest.raises(AssertionError, match="num_generations_per_prompt must be >= 1"): - updated_common = common.copy() - updated_common["num_generations_per_prompt"] = 0 - RolloutManager(**updated_common, use_nemo_gym=False) - - with pytest.raises(AssertionError, match="policy_generation is required"): - RolloutManager(**common, use_nemo_gym=False) - - with pytest.raises(AssertionError, match="generation_config is required"): - RolloutManager(**common, use_nemo_gym=True) - - -# --------------------------------------------------------------------------- -# Tests for AsyncRolloutManager (native async path) -# --------------------------------------------------------------------------- - - -@pytest.fixture(scope="function") -def single_multi_step_calculator_input_sample(rollout_tokenizer): - """Returns a single DatumSpec prompt dict (problem 0) for AsyncRolloutManager tests.""" - problem_text = "(5 + 3) * 2" - expected_answer = 16.0 - max_steps = 5 - - tool_instructions = ( - "You have a calculator tool. To use it, respond with:\n" - "'[operand1, operand2, operation_name]'\n" - "The valid 'operation_name' values are exactly: 'sum', 'diff', 'prod', 'div'.\n" - "Example: [5, 3, sum]\n" - "You will receive the result of your calculation as ...\n" - "Use this result to make the next calculation if needed.\n" - "IMPORTANT: Only perform one calculation step (one tool call) before waiting for a result and making a new tool call.\n" - "IMPORTANT: Do not perform any other calculations or operations aside from the tool call and result. Doing so will result in failure.\n" - "To give the final answer, just output the number. numbers inside of don't count, so output just the final number yourself outside of this.\n" - "Example full output: [2, 4, sum]\n6.0\n[6, 6, diff]\n0.0 0\n(note how you have to output the final 0 outside of the tags)" - "------\n" - f"Solve: {problem_text}" - ) - - initial_prompt_content = rollout_tokenizer.apply_chat_template( - [{"role": "user", "content": tool_instructions}], - tokenize=False, - add_system_prompt=False, - add_generation_prompt=True, - add_special_tokens=False, - ) - tokenized_prompt = rollout_tokenizer( - initial_prompt_content, return_tensors="pt", add_special_tokens=False - )["input_ids"][0] - message_log = [ - { - "role": "user", - "content": initial_prompt_content, - "token_ids": tokenized_prompt, - } - ] - metadata = MultiStepCalcMetadata( - problem=problem_text, - expected_final_answer=expected_answer, - max_steps=max_steps, - current_step=0, - ) - return { - "message_log": message_log, - "extra_env_info": metadata, - "task_name": "multi_step_calculator_game", - "stop_strings": [""], - "idx": 0, - } - - -@pytest.mark.vllm -def test_async_rollout_manager( - multi_step_setup_vllm_async, - single_multi_step_calculator_input_sample, -): - """Standalone test for AsyncRolloutManager. - - Given 1 prompt with num_generations_per_prompt=N, asserts: - - output is a PromptGroupRecord with N Completion objects - - each Completion has a reward (float) and a non-empty message_log - - rollout_metrics has the expected keys with correct types - - completions hold independent (not aliased) message_log objects - """ - vllm_generation, rollout_tokenizer, task_to_env, _, _ = multi_step_setup_vllm_async - input_sample = single_multi_step_calculator_input_sample - num_generations = 2 - max_seq_len = 1024 - max_rollout_turns = input_sample["extra_env_info"]["max_steps"] + 1 - - manager = RolloutManager( - use_nemo_gym=False, - tokenizer=rollout_tokenizer, - task_to_env=task_to_env, - num_generations_per_prompt=num_generations, - max_seq_len=max_seq_len, - max_rollout_turns=max_rollout_turns, - policy_generation=vllm_generation, - ) - - vllm_generation.prepare_for_generation() - record = asyncio.run(manager.run_rollout(input_sample)) - vllm_generation.finish_generation() - - assert isinstance(record, PromptGroupRecord) - assert len(record.completions) == num_generations, ( - f"Expected {num_generations} completions, got {len(record.completions)}" - ) - assert record.prompt_idx == input_sample["idx"] - - for i, completion in enumerate(record.completions): - assert isinstance(completion, Completion) - - # 1. message_log length - assert len(completion.message_log) >= 4, ( - f"Completion {i}: expected >= 4 messages, got {len(completion.message_log)}" - ) - - # 2. last assistant content - last_assistant = next( - (m for m in reversed(completion.message_log) if m["role"] == "assistant"), - None, - ) - assert last_assistant is not None, f"Completion {i}: no assistant message found" - assert last_assistant["content"].strip() == "16", ( - f"Completion {i}: last assistant content {last_assistant['content']!r} != '16'" - ) - - # 3. reward - assert completion.reward == 1.0, ( - f"Completion {i}: reward {completion.reward} != 1.0" - ) - - # completions must be independent objects - assert record.completions[0].message_log is not record.completions[1].message_log - - -@pytest.mark.vllm -def test_async_rollout_manager_truncation( - multi_step_setup_vllm_async, - single_multi_step_calculator_input_sample, -): - """Small max_seq_len forces truncation and truncation_rate=1.0.""" - vllm_generation, rollout_tokenizer, task_to_env, _, _ = multi_step_setup_vllm_async - input_sample = single_multi_step_calculator_input_sample - num_generations = 2 - max_seq_len = 290 - max_rollout_turns = input_sample["extra_env_info"]["max_steps"] + 1 - - manager = RolloutManager( - use_nemo_gym=False, - tokenizer=rollout_tokenizer, - task_to_env=task_to_env, - num_generations_per_prompt=num_generations, - max_seq_len=max_seq_len, - max_rollout_turns=max_rollout_turns, - policy_generation=vllm_generation, - ) - vllm_generation.prepare_for_generation() - record = asyncio.run(manager.run_rollout(input_sample)) - vllm_generation.finish_generation() - - assert len(record.completions) == num_generations - assert all(c.truncated for c in record.completions) - assert record.rollout_metrics["truncation_rate"] == 1.0 - assert record.rollout_metrics["natural_termination_rate"] == 0.0 - - -@pytest.mark.vllm -def test_async_rollout_manager_matches_original( - multi_step_setup_vllm_async, - single_multi_step_calculator_input_sample, -): - """Comparison test: AsyncRolloutManager output is structurally equivalent to the original. - - Calls run_async_multi_turn_rollout with a batch of N identical prompts, - then calls AsyncRolloutManager with 1 prompt and N generations. - Asserts that both produce N results with matching message-log depth, rewards, - and rollout_metrics numeric values. - - TODO: remove this test together with run_async_multi_turn_rollout when the legacy path is deleted. - """ - vllm_generation, rollout_tokenizer, task_to_env, _, _ = multi_step_setup_vllm_async - input_sample = single_multi_step_calculator_input_sample - num_generations = 2 - max_seq_len = 1024 - max_rollout_turns = input_sample["extra_env_info"]["max_steps"] + 1 - - # Build a batch of N identical prompts for the original function - batch = BatchedDataDict( - { - "message_log": [ - deepcopy(input_sample["message_log"]) for _ in range(num_generations) - ], - "extra_env_info": [ - deepcopy(input_sample["extra_env_info"]) for _ in range(num_generations) - ], - "task_name": [input_sample["task_name"]] * num_generations, - "stop_strings": [input_sample["stop_strings"]] * num_generations, - "idx": list(range(num_generations)), - "loss_multiplier": [1.0] * num_generations, - } - ) - - vllm_generation.prepare_for_generation() - original_batch, original_metrics = run_async_multi_turn_rollout( - policy_generation=vllm_generation, - input_batch=batch, - tokenizer=rollout_tokenizer, - task_to_env=task_to_env, - max_seq_len=max_seq_len, - max_rollout_turns=max_rollout_turns, - ) - - manager = RolloutManager( - use_nemo_gym=False, - tokenizer=rollout_tokenizer, - task_to_env=task_to_env, - num_generations_per_prompt=num_generations, - max_seq_len=max_seq_len, - max_rollout_turns=max_rollout_turns, - policy_generation=vllm_generation, - ) - record = asyncio.run(manager.run_rollout(input_sample)) - vllm_generation.finish_generation() - - # Both should produce N results - assert len(original_batch["message_log"]) == num_generations - assert len(record.completions) == num_generations - - for i in range(num_generations): - orig_msg_log = original_batch["message_log"][i] - new_msg_log = record.completions[i].message_log - - # 1. message_log length matches - assert len(orig_msg_log) == len(new_msg_log), ( - f"Completion {i}: message_log length {len(new_msg_log)} != original {len(orig_msg_log)}" - ) - - # 2. last assistant content matches - def _last_assistant_content(msg_log): - for m in reversed(msg_log): - if m["role"] == "assistant": - return m.get("content", "") - return "" - - orig_last = _last_assistant_content(orig_msg_log) - new_last = _last_assistant_content(new_msg_log) - assert orig_last == new_last, ( - f"Completion {i}: last assistant content mismatch\n" - f" original: {orig_last!r}\n" - f" manager: {new_last!r}" - ) - - # 3. reward matches - orig_reward = original_batch["total_reward"][i].item() - new_reward = record.completions[i].reward - assert orig_reward == new_reward, ( - f"Completion {i}: reward mismatch — original {orig_reward}, manager {new_reward}" - ) - - # 4. rollout_metrics numeric values match (timing and histogram fields are excluded). - # The new impl emits slash-style keys (X/mean, X/max, X/min) via _calculate_single_metric; - # translate the legacy prefix-style keys before comparing. - def _translate_legacy_key(key: str) -> str: - if key == "avg_turns_per_sample": - return "turns_per_sample/mean" - if key == "max_turns_reached_rate": - return key - for prefix, suffix in (("mean_", "/mean"), ("max_", "/max"), ("min_", "/min")): - if key.startswith(prefix): - return f"{key[len(prefix) :]}{suffix}" - return key - - new_metrics = record.rollout_metrics - for key in original_metrics.keys(): - if key.startswith("timing/") or key.startswith("histogram/"): - continue - - new_key = _translate_legacy_key(key) - assert new_key in new_metrics, ( - f"rollout_metrics[{new_key!r}] missing from manager" - ) - - orig_val = original_metrics[key] - new_val = new_metrics[new_key] - - assert type(orig_val) == type(new_val), ( - f"rollout_metrics[{key!r}] type mismatch: {type(orig_val)} != {type(new_val)}" - ) - if not isinstance(orig_val, (bool, int, float)): - continue - - assert orig_val == pytest.approx(new_val), ( - f"rollout_metrics[{key!r}] mismatch — original {orig_val}, manager {new_val}" - ) - - -# --------------------------------------------------------------------------- -# Tests for AsyncNemoGymRolloutManager -# --------------------------------------------------------------------------- - - -@pytest.mark.nemo_gym -def test_async_nemo_gym_rollout_manager( - nemo_gym, # noqa: F811 - nemo_gym_vllm_generation, # noqa: F811 - nemo_gym_sanity_test_data, # noqa: F811 - nemo_gym_tokenizer, # noqa: F811 -): - """Standalone test for AsyncNemoGymRolloutManager. - - Given 1 prompt with num_generations_per_prompt=N, asserts: - - output is a PromptGroupRecord with N Completion objects - - each Completion has a reward (float) and a non-empty message_log - - completions hold independent message_log objects - - If the result here does not match, please check the following: - 1. Test data changed: re-run test_nemo_gym_sanity (tests/unit/environments/test_nemo_gym.py) - and use _write_actual_test_data output to refresh test_nemo_gym_sanity.json. - 2. Logic changed: inspect recent changes to AsyncNemoGymRolloutManager or the gym env. - """ - with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: - for data in nemo_gym_sanity_test_data["input"]: - f.write(json.dumps(data) + "\n") - data_path = f.name - - dataset = NemoGymDataset(data_path) - examples = [ - nemo_gym_data_processor(dataset.dataset[idx], None, None, None, idx) - for idx in range(len(dataset.dataset)) - ] - input_batch: BatchedDataDict[DatumSpec] = rl_collate_fn(examples) - - # Use only the first prompt - single_prompt = { - "message_log": input_batch["message_log"][0], - "extra_env_info": input_batch["extra_env_info"][0], - "task_name": "nemo_gym", - "idx": 0, - "loss_multiplier": float(input_batch["loss_multiplier"][0]), - } - num_generations = 2 - - manager = RolloutManager( - use_nemo_gym=True, - tokenizer=nemo_gym_tokenizer, - task_to_env={"nemo_gym": nemo_gym}, - num_generations_per_prompt=num_generations, - max_seq_len=nemo_gym_vllm_generation.cfg["vllm_cfg"]["max_model_len"], - generation_config=nemo_gym_vllm_generation.cfg, - ) - record = asyncio.run(manager.run_rollout(single_prompt)) - - assert isinstance(record, PromptGroupRecord) - assert len(record.completions) == num_generations, ( - f"Expected {num_generations} completions, got {len(record.completions)}" - ) - assert record.prompt_idx == 0 - - for i, completion in enumerate(record.completions): - assert isinstance(completion, Completion) - - # 1. message_log length - assert len(completion.message_log) == 2, ( - f"Completion {i}: expected 2 messages, got {len(completion.message_log)}" - ) - - # 2. last assistant token_ids - last_assistant = next( - (m for m in reversed(completion.message_log) if m["role"] == "assistant"), - None, - ) - assert last_assistant is not None, f"Completion {i}: no assistant message found" - assert torch.equal( - last_assistant["token_ids"], - torch.tensor([151667, 198, 32313, 11, 1077]), - ), ( - f"Completion {i}: last assistant token_ids {last_assistant['token_ids'].tolist()} " - f"!= [151667, 198, 32313, 11, 1077]" - ) - - # 3. reward - assert completion.reward == 0.0, ( - f"Completion {i}: reward {completion.reward} != 0.0" - ) - - # completions must be independent objects - assert record.completions[0].message_log is not record.completions[1].message_log - - -@pytest.mark.nemo_gym -def test_async_nemo_gym_rollout_manager_matches_original( - nemo_gym, # noqa: F811 - nemo_gym_vllm_generation, # noqa: F811 - nemo_gym_sanity_test_data, # noqa: F811 - nemo_gym_tokenizer, # noqa: F811 -): - """Comparison test: AsyncNemoGymRolloutManager output is structurally equivalent to the original. - - Calls run_async_nemo_gym_rollout with a batch of N identical rows, - then calls AsyncNemoGymRolloutManager with 1 prompt, N generations. - Asserts that both produce N results and rewards are in the same numeric domain. - - TODO: remove this test together with run_async_nemo_gym_rollout when the legacy path is deleted. - """ - with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: - for data in nemo_gym_sanity_test_data["input"]: - f.write(json.dumps(data) + "\n") - data_path = f.name - - dataset = NemoGymDataset(data_path) - examples = [ - nemo_gym_data_processor(dataset.dataset[idx], None, None, None, idx) - for idx in range(len(dataset.dataset)) - ] - input_batch: BatchedDataDict[DatumSpec] = rl_collate_fn(examples) - - num_generations = 2 - single_prompt = { - "message_log": input_batch["message_log"][0], - "extra_env_info": input_batch["extra_env_info"][0], - "task_name": "nemo_gym", - "idx": 0, - "loss_multiplier": float(input_batch["loss_multiplier"][0]), - } - - # Build a batch of N identical rows for the original function - repeated_batch = BatchedDataDict( - { - "message_log": [ - deepcopy(input_batch["message_log"][0]) for _ in range(num_generations) - ], - "extra_env_info": [ - deepcopy(input_batch["extra_env_info"][0]) - for _ in range(num_generations) - ], - "loss_multiplier": input_batch["loss_multiplier"][0:1].repeat( - num_generations - ), - "idx": list(range(num_generations)), - "task_name": ["nemo_gym"] * num_generations, - } - ) - - original_result = run_async_nemo_gym_rollout( - policy_generation=nemo_gym_vllm_generation, - input_batch=repeated_batch, - tokenizer=nemo_gym_tokenizer, - task_to_env={"nemo_gym": nemo_gym}, - generation_config=nemo_gym_vllm_generation.cfg, - max_seq_len=nemo_gym_vllm_generation.cfg["vllm_cfg"]["max_model_len"], - max_rollout_turns=None, - ) - - manager = RolloutManager( - use_nemo_gym=True, - tokenizer=nemo_gym_tokenizer, - task_to_env={"nemo_gym": nemo_gym}, - num_generations_per_prompt=num_generations, - max_seq_len=nemo_gym_vllm_generation.cfg["vllm_cfg"]["max_model_len"], - generation_config=nemo_gym_vllm_generation.cfg, - ) - record = asyncio.run(manager.run_rollout(single_prompt)) - - # Both should produce N completions - assert len(original_result.final_batch["message_log"]) == num_generations - assert len(record.completions) == num_generations - - for i in range(num_generations): - orig_msg_log = original_result.final_batch["message_log"][i] - new_msg_log = record.completions[i].message_log - - # 1. message_log length matches - assert len(orig_msg_log) == len(new_msg_log), ( - f"Completion {i}: message_log length {len(new_msg_log)} != original {len(orig_msg_log)}" - ) - - # 2. last assistant token_ids match - def _last_assistant_token_ids(msg_log): - for m in reversed(msg_log): - if m["role"] == "assistant": - return m.get("token_ids") - return None - - orig_token_ids = _last_assistant_token_ids(orig_msg_log) - new_token_ids = _last_assistant_token_ids(new_msg_log) - assert orig_token_ids is not None, ( - f"Completion {i}: no assistant message in original" - ) - assert new_token_ids is not None, ( - f"Completion {i}: no assistant message in manager" - ) - assert torch.equal(orig_token_ids, new_token_ids), ( - f"Completion {i}: last assistant token_ids mismatch\n" - f" original: {orig_token_ids.tolist()}\n" - f" manager: {new_token_ids.tolist()}" - ) - - # 3. reward matches - orig_reward = original_result.final_batch["total_reward"][i].item() - new_reward = record.completions[i].reward - assert orig_reward == new_reward, ( - f"Completion {i}: reward mismatch — original {orig_reward}, manager {new_reward}" - ) - - # 4. rollout_metrics numeric values match (timing and Table fields are excluded) - orig_metrics = original_result.rollout_metrics - new_metrics = record.rollout_metrics - for key in orig_metrics.keys(): - # Skip timing and full_result fields - if key.startswith("timing/") or key.endswith("/full_result"): - continue - - # Check that the key is present in the new metrics - assert key in new_metrics, f"rollout_metrics[{key!r}] missing from manager" - - orig_val = orig_metrics[key] - new_val = new_metrics[key] - - # Skip non-numeric fields - assert type(orig_val) == type(new_val), ( - f"rollout_metrics[{key!r}] type mismatch: {type(orig_val)} != {type(new_val)}" - ) - if not isinstance(orig_val, (bool, int, float)): - continue - - # Check equal - assert orig_val == pytest.approx(new_val), ( - f"rollout_metrics[{key!r}] mismatch — original {orig_val}, manager {new_val}" - ) diff --git a/tests/unit/models/policy/test_megatron_split_state.py b/tests/unit/models/policy/test_megatron_split_state.py new file mode 100644 index 0000000000..fb50c839ef --- /dev/null +++ b/tests/unit/models/policy/test_megatron_split_state.py @@ -0,0 +1,638 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""CPU state-machine tests for MegatronPolicyWorkerImpl's split-API. + +These tests cover the lifecycle and call-order invariants — they do NOT +exercise real distributed comms, the mcore scheduler, or the optimizer. +Numerical equivalence vs sync ``train()`` lives in the GPU parity tests. + +The bugs these catch: + - silent gradient over-counting if ``model.no_sync()`` is not wrapped + around ``megatron_forward_backward`` (the mcore DDP hooks would + dispatch a per-call reduce, ADDING to an already-reduced bucket). + - PP>1 pipeline-schedule bypass if ``model.config.grad_sync_func`` is + not nulled for the step's duration. + - ``trainer_version`` advancing on abort. + - ``zero_grad_buffer`` not called at begin (mcore's contiguous grad + buffer leaks stale grads otherwise). + - off-by-one in ``total_num_microbatches`` (used to scale MoE aux-loss). +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +import torch + +# megatron.bridge is only available with the mcore extras. Without it the +# eager import of megatron_policy_worker (transitively imports megatron.bridge) +# fails at COLLECTION time on non-mcore shards, which then breaks every other +# test in that shard. importorskip stops collection cleanly here. +pytest.importorskip("megatron.bridge") + +# Eagerly import the worker module so ``unittest.mock.patch`` can resolve +# attributes on it via ``getattr``. Without this the patch path +# ``nemo_rl.models.policy.workers.megatron_policy_worker.`` fails +# at ``getattr(workers, "megatron_policy_worker")``. +import nemo_rl.models.policy.workers.megatron_policy_worker # noqa: E402,F401 + +pytestmark = pytest.mark.mcore + +# Module path of the worker under test +WORKER_MOD = "nemo_rl.models.policy.workers.megatron_policy_worker" + + +# ── Mock fabric ────────────────────────────────────────────────────────── + + +def _make_mock_model(): + """A mcore-DDP-shaped mock: exposes the methods + attributes the + split-API touches, plus an ``inference_params`` attribute and a + ``modules()`` that yields nothing (so the inference-cache reset loop + is a no-op).""" + model = MagicMock() + model.config = MagicMock() + model.config.grad_sync_func = "ORIGINAL_GRAD_SYNC_FUNC" # sentinel + model.config.num_moe_experts = None # disable MoE branch + # no_sync() is a context manager — return a MagicMock that supports + # __enter__/__exit__ so the `with self.model.no_sync():` block works. + model.no_sync = MagicMock( + return_value=MagicMock( + __enter__=MagicMock(return_value=None), + __exit__=MagicMock(return_value=False), + ) + ) + model.modules = MagicMock(return_value=iter([])) + model.inference_params = None + model.parameters = MagicMock( + return_value=iter([]) + ) # no params for the rescale loop + return model + + +def _make_worker(loss_type): + """Construct a MegatronPolicyWorkerImpl instance with all heavy + attributes mocked. Bypasses __init__ via ``object.__new__``.""" + # Lazy import so the module-level mcore imports happen inside the + # mcore-marked test process. + from nemo_rl.models.policy.workers.megatron_policy_worker import ( + MegatronPolicyWorkerImpl, + ) + + w = object.__new__(MegatronPolicyWorkerImpl) + w.model = _make_mock_model() + w.optimizer = MagicMock() + # MegatronOptimizer.step returns (success, grad_norm, num_zeros) + w.optimizer.step.return_value = (True, 0.5, 0) + w.optimizer.param_groups = [{"lr": 1e-4, "weight_decay": 0.01}] + w.scheduler = MagicMock() + w.scheduler.get_lr.return_value = 1e-4 + w.scheduler.get_wd.return_value = 0.01 + w.mcore_state = MagicMock() + w.mcore_state.straggler_timer = None + w.cfg = { + "train_global_batch_size": 32, + "train_micro_batch_size": 4, + "megatron_cfg": { + "empty_unused_memory_level": 0, + "moe_per_layer_logging": False, + "use_linear_ce_fusion_loss": False, + }, + } + w.dp_size = 2 + w.cp_size = 1 + w.sampling_params = None + w.draft_model = None + w.defer_fp32_logits = False + w.dtype = torch.float32 + w._is_reward_model = False + + # Stash a loss_fn with the requested loss_type for tests that need one. + w._test_loss_fn = MagicMock(loss_type=loss_type) + return w + + +@pytest.fixture +def mock_module_symbols(): + """Patch every module-level symbol that the split-API methods call + into. Yields a dict of name → mock for assertions.""" + # Make `aggregate_training_statistics` return ({}, scalar) — what the + # finish path expects. + agg_ret = ({"loss": [0.0]}, torch.tensor(0.5)) + + patches = { + "megatron_forward_backward": [ + {"loss": 0.5, "global_valid_seqs": 8.0, "global_valid_toks": 256.0} + ], + "get_microbatch_iterator": (iter([]), 2, 4, 16, 16), # 2 pipeline mbs per call + "LossPostProcessor": MagicMock(), + "broadcast_loss_metrics_from_last_stage": lambda m: m, + "get_pg_collection": MagicMock(mp=MagicMock()), + "logical_and_across_model_parallel_group": lambda v, mp_group: v, + "reduce_max_stat_across_model_parallel_group": lambda v, mp_group: v, + "aggregate_training_statistics": agg_ret, + "get_moe_metrics": MagicMock(return_value={}), + } + + with ( + patch( + f"{WORKER_MOD}.megatron_forward_backward", + return_value=patches["megatron_forward_backward"], + ) as mfb, + patch( + f"{WORKER_MOD}.get_microbatch_iterator", + return_value=patches["get_microbatch_iterator"], + ) as gmi, + patch( + f"{WORKER_MOD}.LossPostProcessor", return_value=patches["LossPostProcessor"] + ) as lpp, + patch( + f"{WORKER_MOD}.broadcast_loss_metrics_from_last_stage", + side_effect=patches["broadcast_loss_metrics_from_last_stage"], + ) as bcast, + patch( + f"{WORKER_MOD}.get_pg_collection", return_value=patches["get_pg_collection"] + ) as gpgc, + patch( + f"{WORKER_MOD}.logical_and_across_model_parallel_group", + side_effect=patches["logical_and_across_model_parallel_group"], + ) as land, + patch( + f"{WORKER_MOD}.reduce_max_stat_across_model_parallel_group", + side_effect=patches["reduce_max_stat_across_model_parallel_group"], + ) as rmax, + patch( + f"{WORKER_MOD}.aggregate_training_statistics", + return_value=patches["aggregate_training_statistics"], + ) as agg, + patch(f"{WORKER_MOD}.get_moe_metrics", return_value={}) as moe, + patch(f"{WORKER_MOD}.get_rerun_state_machine") as grsm, + patch(f"{WORKER_MOD}.parallel_state") as pstate, + patch("torch.distributed.all_reduce") as ar, + patch("torch.cuda.empty_cache") as cec, + patch("torch.cuda.get_device_name", return_value="H100"), + patch("torch.distributed.get_rank", return_value=0), + ): + # rerun state machine: fire forward+backward once per train_microbatch + rsm = MagicMock() + rsm.should_run_forward_backward.side_effect = [True, False] * 100 + grsm.return_value = rsm + + # parallel_state mocks + pstate.is_pipeline_last_stage.return_value = True + pstate.get_data_parallel_group.return_value = MagicMock() + + yield { + "mfb": mfb, + "gmi": gmi, + "lpp": lpp, + "bcast": bcast, + "gpgc": gpgc, + "land": land, + "rmax": rmax, + "agg": agg, + "moe": moe, + "grsm": grsm, + "pstate": pstate, + "all_reduce": ar, + "empty_cache": cec, + } + + +def _fake_batch(): + """A minimal BatchedDataDict-ish object the mask-sum block can read. + train_microbatch reads ``data["sample_mask"]``, ``data["token_mask"]``, + and (only as a fallback for the no-token-mask path) ``data["input_ids"]``.""" + # 8 samples, all valid (mask=1); 256 valid tokens each + sample_mask = torch.ones(8, dtype=torch.float32) + token_mask = torch.ones(8, 257, dtype=torch.float32) # token_mask[:, 1:] → 256 toks + input_ids = torch.zeros(8, 257, dtype=torch.long) + return { + "sample_mask": sample_mask, + "token_mask": token_mask, + "input_ids": input_ids, + } + + +# ── BEGIN ──────────────────────────────────────────────────────────────── + + +class TestBegin: + def test_opens_state(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("step-0", loss_fn=w._test_loss_fn, gbs=16, mbs=4) + assert w._train_step_state is not None + assert w._train_step_state["step_id"] == "step-0" + assert w._train_step_state["loss_type"] == LossType.TOKEN_LEVEL + assert w._train_step_state["gbs"] == 16 + assert w._train_step_state["mbs"] == 4 + assert w._train_step_state["total_num_microbatches"] == 0 + + def test_calls_zero_grad_and_zero_grad_buffer(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("step-0", loss_fn=w._test_loss_fn) + w.model.zero_grad_buffer.assert_called_once() + w.optimizer.zero_grad.assert_called_once() + w.model.train.assert_called_once() + + def test_saves_and_nulls_grad_sync_func(self, mock_module_symbols): + """The PP scheduler's direct reduce dispatch must be suppressed + for the duration of the step. Otherwise PP>1 silently corrupts + grads even when ``no_sync`` is set on the bucket groups.""" + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + assert w.model.config.grad_sync_func == "ORIGINAL_GRAD_SYNC_FUNC" + w.begin_train_step("step-0", loss_fn=w._test_loss_fn) + assert w.model.config.grad_sync_func is None + assert w._train_step_state["saved_grad_sync_func"] == "ORIGINAL_GRAD_SYNC_FUNC" + + def test_double_begin_raises(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("step-0", loss_fn=w._test_loss_fn) + with pytest.raises(RuntimeError, match="already open"): + w.begin_train_step("step-1", loss_fn=w._test_loss_fn) + + def test_uses_cfg_defaults_when_gbs_mbs_omitted(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("step-0", loss_fn=w._test_loss_fn) + assert w._train_step_state["gbs"] == w.cfg["train_global_batch_size"] + assert w._train_step_state["mbs"] == w.cfg["train_micro_batch_size"] + + +# ── _assert_step_open ──────────────────────────────────────────────────── + + +class TestAssertStepOpen: + def test_raises_when_no_step_open(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + with pytest.raises(RuntimeError, match="no train step open"): + w._assert_step_open("step-0") + + def test_raises_on_step_id_mismatch(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("step-correct", loss_fn=w._test_loss_fn) + with pytest.raises(RuntimeError, match="step_id mismatch"): + w._assert_step_open("step-WRONG") + + def test_train_microbatch_without_begin_raises(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + with pytest.raises(RuntimeError, match="no train step open"): + w.train_microbatch("step-0", _fake_batch()) + + def test_finish_without_begin_raises(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + with pytest.raises(RuntimeError, match="no train step open"): + w.finish_train_step("step-0") + + +# ── train_microbatch ───────────────────────────────────────────────────── + + +class TestTrainMicrobatch: + def test_wraps_forward_backward_in_no_sync(self, mock_module_symbols): + """The single most important assertion in this file. Without the + no_sync wrap, mcore DDP dispatches a per-call cross-DP reduce on + the partially-accumulated buffer — silently corrupting grads.""" + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + w.train_microbatch("s0", _fake_batch()) + # no_sync() must have been ENTERED (called as a context manager). + # MagicMock with __enter__/__exit__ records the __enter__ call. + ctx = w.model.no_sync.return_value + ctx.__enter__.assert_called() + ctx.__exit__.assert_called() + + def test_invokes_megatron_forward_backward_once(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + w.train_microbatch("s0", _fake_batch()) + assert mock_module_symbols["mfb"].call_count == 1 + + def test_passes_placeholder_n_one_to_loss(self, mock_module_symbols): + """The N=1 trick: loss must be called with global_valid_*=1 so it + returns un-normalized sums; finish does the 1/N rescale.""" + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + w.train_microbatch("s0", _fake_batch()) + kwargs = mock_module_symbols["mfb"].call_args.kwargs + # placeholder_n is a tensor(1.0) + assert "global_valid_seqs" in kwargs + assert "global_valid_toks" in kwargs + assert float(kwargs["global_valid_seqs"].item()) == pytest.approx(1.0) + assert float(kwargs["global_valid_toks"].item()) == pytest.approx(1.0) + + def test_accumulates_mask_sums_across_calls(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + # _fake_batch has sample_mask sum = 8, token_mask*sample_mask sum = 8*256 = 2048 + w.train_microbatch("s0", _fake_batch()) + assert float(w._train_step_state["local_valid_seqs"].item()) == pytest.approx( + 8.0 + ) + assert float(w._train_step_state["local_valid_toks"].item()) == pytest.approx( + 2048.0 + ) + w.train_microbatch("s0", _fake_batch()) + assert float(w._train_step_state["local_valid_seqs"].item()) == pytest.approx( + 16.0 + ) + assert float(w._train_step_state["local_valid_toks"].item()) == pytest.approx( + 4096.0 + ) + + def test_total_num_microbatches_accumulates(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + # get_microbatch_iterator mock returns num_microbatches=2 per call + w.train_microbatch("s0", _fake_batch()) + w.train_microbatch("s0", _fake_batch()) + w.train_microbatch("s0", _fake_batch()) + assert w._train_step_state["total_num_microbatches"] == 6 + + def test_does_not_call_optimizer_step(self, mock_module_symbols): + """trainer_version semantics: optimizer.step() must NOT fire + per train_microbatch — only at finish.""" + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + w.train_microbatch("s0", _fake_batch()) + w.train_microbatch("s0", _fake_batch()) + w.optimizer.step.assert_not_called() + + +# ── finish_train_step ──────────────────────────────────────────────────── + + +class TestFinish: + def _setup_open_step(self, mock_module_symbols, loss_type): + w = _make_worker(loss_type) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + w.train_microbatch("s0", _fake_batch()) + return w + + def test_rescales_grads_with_inv_n(self, mock_module_symbols): + """The 1/N rescale must happen ON the local main_grad BEFORE the + cross-DP reduce — otherwise the reduce sees un-rescaled sums.""" + from nemo_rl.algorithms.loss.interfaces import LossType + + w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL) + w.finish_train_step("s0") + # scale_gradients should have been called with some 1/N scalar < 1 + w.model.scale_gradients.assert_called_once() + arg = w.model.scale_gradients.call_args.args[0] + assert 0 < arg <= 1.0 + + def test_start_then_finish_grad_sync_called_after_rescale( + self, mock_module_symbols + ): + """Call order matters: scale_gradients -> start_grad_sync -> + finish_grad_sync -> optimizer.step.""" + from nemo_rl.algorithms.loss.interfaces import LossType + + w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL) + # Record call order via a shared list + order: list[str] = [] + w.model.scale_gradients.side_effect = lambda s: order.append("scale") + w.model.start_grad_sync.side_effect = lambda: order.append("start_sync") + w.model.finish_grad_sync.side_effect = lambda: order.append("finish_sync") + w.optimizer.step.side_effect = lambda: ( + order.append("opt_step") or (True, 0.5, 0) + ) + w.finish_train_step("s0") + assert order == ["scale", "start_sync", "finish_sync", "opt_step"] + + def test_picks_global_valid_toks_for_token_level_loss(self, mock_module_symbols): + """N selection: TOKEN_LEVEL → N = global_valid_toks (not seqs).""" + from nemo_rl.algorithms.loss.interfaces import LossType + + w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL) + w.finish_train_step("s0") + # local_valid_toks accumulated = 2048; with mocked all_reduce as no-op, + # global_valid_toks == 2048 → inv_n = 1/2048 + arg = w.model.scale_gradients.call_args.args[0] + assert arg == pytest.approx(1.0 / 2048.0, rel=1e-4) + + def test_picks_global_valid_seqs_for_sequence_level_loss(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = self._setup_open_step(mock_module_symbols, LossType.SEQUENCE_LEVEL) + w.finish_train_step("s0") + # local_valid_seqs = 8 → inv_n = 1/8 + arg = w.model.scale_gradients.call_args.args[0] + assert arg == pytest.approx(1.0 / 8.0, rel=1e-4) + + def test_restores_grad_sync_func(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL) + w.finish_train_step("s0") + assert w.model.config.grad_sync_func == "ORIGINAL_GRAD_SYNC_FUNC" + + def test_clears_train_step_state(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL) + w.finish_train_step("s0") + assert w._train_step_state is None + + def test_calls_scheduler_step_with_increment_gbs(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL) + w._train_step_state["gbs"] = 64 + w.finish_train_step("s0") + w.scheduler.step.assert_called_once_with(increment=64) + + def test_returns_metrics_dict(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL) + metrics = w.finish_train_step("s0") + for key in ( + "global_loss", + "rank", + "gpu_name", + "model_dtype", + "all_mb_metrics", + "grad_norm", + ): + assert key in metrics, f"missing {key!r}" + + def test_moe_branch_skipped_when_num_experts_is_none(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL) + w.model.config.num_moe_experts = None + metrics = w.finish_train_step("s0") + assert "moe_metrics" not in metrics + + def test_moe_branch_uses_total_num_microbatches_for_scale( + self, mock_module_symbols + ): + """MoE aux-loss scale must use the accumulated total, not the + per-call num_microbatches.""" + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + w.model.config.num_moe_experts = 4 + # Have get_moe_metrics return non-empty so the branch fires + mock_module_symbols["moe"].return_value = {"aux_loss": 0.1} + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + # 3 train_microbatch calls × 2 pipeline mbs each = 6 + for _ in range(3): + w.train_microbatch("s0", _fake_batch()) + w.finish_train_step("s0") + # get_moe_metrics receives loss_scale=1/6 + kwargs = mock_module_symbols["moe"].call_args.kwargs + assert kwargs["loss_scale"] == pytest.approx(1.0 / 6.0, rel=1e-6) + + +# ── abort_train_step ───────────────────────────────────────────────────── + + +class TestAbort: + def test_restores_grad_sync_func(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + w.abort_train_step("s0") + assert w.model.config.grad_sync_func == "ORIGINAL_GRAD_SYNC_FUNC" + + def test_zero_grad_buffer_and_zero_grad_called(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + w.model.zero_grad_buffer.reset_mock() + w.optimizer.zero_grad.reset_mock() + w.abort_train_step("s0") + w.model.zero_grad_buffer.assert_called_once() + w.optimizer.zero_grad.assert_called_once() + + def test_does_not_call_optimizer_step(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + w.train_microbatch("s0", _fake_batch()) + w.abort_train_step("s0") + w.optimizer.step.assert_not_called() + + def test_clears_train_step_state(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + w.abort_train_step("s0") + assert w._train_step_state is None + + def test_idempotent_with_no_open_step(self, mock_module_symbols): + """abort is a no-op when nothing is open.""" + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + # Should not raise + w.abort_train_step("s0") + assert getattr(w, "_train_step_state", None) is None + + def test_mismatched_step_id_raises(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + with pytest.raises(RuntimeError, match="does not match open step"): + w.abort_train_step("s-WRONG") + + def test_can_begin_new_step_after_abort(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + w.train_microbatch("s0", _fake_batch()) + w.abort_train_step("s0") + # New step opens cleanly + w.begin_train_step("s1", loss_fn=w._test_loss_fn) + assert w._train_step_state["step_id"] == "s1" + assert float(w._train_step_state["local_valid_seqs"].item()) == 0.0 + + +# ── grad_sync_func full lifecycle (integration of begin → finish/abort) ─ + + +class TestGradSyncFuncLifecycle: + def test_begin_finish_round_trip(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + sentinel = "MY_CUSTOM_GRAD_SYNC" + w.model.config.grad_sync_func = sentinel + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + assert w.model.config.grad_sync_func is None + w.train_microbatch("s0", _fake_batch()) + w.finish_train_step("s0") + assert w.model.config.grad_sync_func == sentinel + + def test_begin_abort_round_trip(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + sentinel = "MY_CUSTOM_GRAD_SYNC" + w.model.config.grad_sync_func = sentinel + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + assert w.model.config.grad_sync_func is None + w.abort_train_step("s0") + assert w.model.config.grad_sync_func == sentinel + + def test_handles_originally_none_grad_sync_func(self, mock_module_symbols): + """When PP=1 (or align_grad_reduce=False), grad_sync_func is None + to begin with. begin → finish must leave it as None.""" + from nemo_rl.algorithms.loss.interfaces import LossType + + w = _make_worker(LossType.TOKEN_LEVEL) + w.model.config.grad_sync_func = None + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + assert w.model.config.grad_sync_func is None + w.train_microbatch("s0", _fake_batch()) + w.finish_train_step("s0") + assert w.model.config.grad_sync_func is None diff --git a/tests/unit/single_controller/__init__.py b/tests/unit/single_controller/__init__.py new file mode 100644 index 0000000000..4fc25d0d3c --- /dev/null +++ b/tests/unit/single_controller/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/single_controller/test_rollout_pump.py b/tests/unit/single_controller/test_rollout_pump.py new file mode 100644 index 0000000000..b4d5f291b3 --- /dev/null +++ b/tests/unit/single_controller/test_rollout_pump.py @@ -0,0 +1,317 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end test: SC._rollout_pump writes the expected rows to TQ.""" + +from __future__ import annotations + +import time +from typing import Any + +import ray +import torch +from tensordict import TensorDict + +from nemo_rl.algorithms.async_utils.replay_buffer import TQReplayBuffer +from nemo_rl.algorithms.single_controller import SingleControllerActor +from nemo_rl.algorithms.single_controller_utils import ( + AsyncRLConfig, + MasterConfig, + SingleControllerBundle, +) +from nemo_rl.data_plane.adapters.noop import NoOpDataPlaneClient +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.experience.rollout_manager import RolloutManager + +# Reuse fixtures from the experience tests; same shape as test_async_rollout_manager. +from tests.unit.experience.test_rollout_manager import ( + single_multi_step_calculator_input_sample, # noqa: F401 +) +from tests.unit.experience.test_rollouts import ( + initial_multi_step_calculator_batch, # noqa: F401 + multi_step_calculator_environment, # noqa: F401 + multi_step_setup_vllm_async, # noqa: F401 + rollout_cluster, # noqa: F401 + rollout_tokenizer, # noqa: F401 +) + +_PARTITION_ID = "rollout_data" +# TQReplayBuffer.add tensorizes each PromptGroupRecord and writes +# ``generations_per_prompt`` training rows directly to TQ. +_BULK_FIELDS = [ + "input_ids", + "input_lengths", + "generation_logprobs", + "token_mask", + "sample_mask", + "prompt_ids_for_adv", + "total_reward", +] + + +@ray.remote(num_cpus=0) +class _TQActor: + """Ray-wrapped NoOpDataPlaneClient for cross-process TQ inspection.""" + + def __init__( + self, + partition_id: str, + fields: list[str], + num_samples: int, + consumer_tasks: list[str], + ) -> None: + self._client = NoOpDataPlaneClient() + self._client.register_partition( + partition_id=partition_id, + fields=list(fields), + num_samples=int(num_samples), + consumer_tasks=list(consumer_tasks), + ) + + def put_samples( + self, + sample_ids: list[str], + partition_id: str, + fields: TensorDict | None = None, + tags: list[dict[str, Any]] | None = None, + ) -> Any: + return self._client.put_samples( + sample_ids=sample_ids, + partition_id=partition_id, + fields=fields, + tags=tags, + ) + + def claim_meta(self, **kwargs: Any) -> Any: + return self._client.claim_meta(**kwargs) + + def get_samples( + self, + sample_ids: list[str], + partition_id: str, + select_fields: list[str], + ) -> TensorDict: + return self._client.get_samples( + sample_ids=sample_ids, + partition_id=partition_id, + select_fields=list(select_fields), + ) + + def get_tags( + self, partition_id: str, sample_ids: list[str] + ) -> list[dict[str, Any]]: + rec = self._client._partitions[partition_id] + return [dict(rec.tags.get(sid, {})) for sid in sample_ids] + + def peek_count(self, partition_id: str) -> int: + return len(self._client._partitions[partition_id].rows) + + +class _SyncDPAdapter: + """Sync DataPlaneClient over a Ray actor handle. Pads nested tensors before transport.""" + + def __init__(self, handle: Any) -> None: + self._handle = handle + + def put_samples( + self, + sample_ids: list[str], + partition_id: str, + fields: TensorDict | None = None, + tags: list[dict[str, Any]] | None = None, + ) -> Any: + if fields is not None: + fields = self._padded(fields) + return ray.get( + self._handle.put_samples.remote( + sample_ids=sample_ids, + partition_id=partition_id, + fields=fields, + tags=tags, + ) + ) + + @staticmethod + def _padded(td: TensorDict) -> TensorDict: + out: dict[str, torch.Tensor] = {} + for k in td.keys(): + v = td.get(k) + if isinstance(v, torch.Tensor) and v.is_nested: + v = torch.nested.to_padded_tensor(v, padding=0) + out[k] = v + return TensorDict(out, batch_size=td.batch_size) + + +def test_rollout_pump_writes_expected_tq_data( + multi_step_setup_vllm_async, # noqa: F811 + single_multi_step_calculator_input_sample, # noqa: F811 +): + """SC._rollout_pump writes max_rollout_prompts * num_generations rows to TQ with the expected fields and tags.""" + vllm_generation, tokenizer, env_handles, _, _ = multi_step_setup_vllm_async + input_sample = single_multi_step_calculator_input_sample + + num_generations = 2 + max_rollout_prompts = 2 + # TQReplayBuffer.add writes ``num_generations`` training rows per prompt. + expected_samples = max_rollout_prompts * num_generations + max_seq_len = 1024 + max_rollout_turns = input_sample["extra_env_info"]["max_steps"] + 1 + + tq_actor = _TQActor.remote( + partition_id=_PARTITION_ID, + fields=_BULK_FIELDS, + num_samples=expected_samples * 4, + consumer_tasks=["train"], + ) + dp_adapter = _SyncDPAdapter(tq_actor) + + mc = MasterConfig.model_construct( + grpo={ + "max_num_steps": 1, + "max_num_epochs": None, + "num_generations_per_prompt": num_generations, + }, + async_rl=AsyncRLConfig( + batch_selection_strategy="strict_on_policy", + max_weight_staleness_versions=0, + min_prompt_groups_per_batch=1, + max_inflight_prompts=max_rollout_prompts, + max_buffered_rollouts=max_rollout_prompts, + ), + ) + # Wrap each value in a single-element list so size==1 and v[0] returns the original field. + batched_sample = BatchedDataDict({k: [v] for k, v in input_sample.items()}) + dataloader = [batched_sample] * max_rollout_prompts + + tq_buffer = TQReplayBuffer( + dp_adapter, + partition_id=_PARTITION_ID, + pad_value_dict={"token_ids": int(tokenizer.pad_token_id or 0)}, + ) + rollout_manager = RolloutManager( + tokenizer=tokenizer, + env_handles=env_handles, + num_generations_per_prompt=num_generations, + max_seq_len=max_seq_len, + max_rollout_turns=max_rollout_turns, + policy_generation=vllm_generation, + use_nemo_gym=False, + tq_buffer=tq_buffer, + ) + bundle = SingleControllerBundle( + gen_handle=vllm_generation, + trainer_handle=object(), + env_handles=env_handles, + train_cluster=None, + inference_cluster=None, + dp_client=dp_adapter, + dataloader=dataloader, + weight_synchronizer=object(), + advantage_estimator=None, + loss_fn=None, + rollout_manager=rollout_manager, + tq_buffer=tq_buffer, + partition_id=_PARTITION_ID, + ) + ctrl = SingleControllerActor.remote(master_config=mc, bundle=bundle) + + vllm_generation.prepare_for_generation() + + # _rollout_pump runs until cancelled, so poll TQ then cancel. + pump_ref = ctrl._rollout_pump.remote() + deadline = time.monotonic() + 120.0 + while time.monotonic() < deadline: + if ray.get(tq_actor.peek_count.remote(_PARTITION_ID)) >= expected_samples: + break + time.sleep(0.5) + assert ray.get(tq_actor.peek_count.remote(_PARTITION_ID)) >= expected_samples, ( + "rollout_pump did not push expected_samples within timeout" + ) + ray.cancel(pump_ref) + try: + ray.get(pump_ref) + except (ray.exceptions.RayTaskError, ray.exceptions.TaskCancelledError): + pass + + vllm_generation.finish_generation() + + meta = ray.get( + tq_actor.claim_meta.remote( + partition_id=_PARTITION_ID, + task_name="train", + required_fields=_BULK_FIELDS, + batch_size=expected_samples * 4, + blocking=False, + timeout_s=0.0, + ) + ) + assert meta.size == expected_samples + + # pack_payload stamps sample_ids as ``{group_uuid}_g{i}``. + group_ids: set[str] = set() + for sid in meta.sample_ids: + head, _, tail = sid.rpartition("_g") + assert head and tail.isdigit(), f"unexpected sample_id: {sid}" + group_ids.add(head) + assert len(group_ids) == max_rollout_prompts + + bulk = ray.get( + tq_actor.get_samples.remote( + sample_ids=meta.sample_ids, + partition_id=_PARTITION_ID, + select_fields=_BULK_FIELDS, + ) + ) + assert set(bulk.keys()) >= set(_BULK_FIELDS), ( + f"missing bulk fields: {set(_BULK_FIELDS) - set(bulk.keys())}" + ) + + input_lengths = bulk["input_lengths"].long() + assert input_lengths.shape[0] == expected_samples + assert torch.all(input_lengths > 0) + assert torch.allclose( + bulk["sample_mask"].float(), + torch.ones(expected_samples, dtype=torch.float32), + ) + + # Same deterministic prompt as test_async_rollout_manager: the model + # solves the calculator task every time -> reward == 1.0 and decoded + # tail contains " 16". + rewards = bulk["total_reward"].float().flatten() + assert rewards.shape == (expected_samples,) + assert torch.allclose(rewards, torch.ones(expected_samples)), ( + f"expected all rewards == 1.0, got {rewards.tolist()}" + ) + + input_ids = bulk["input_ids"] + token_mask = bulk["token_mask"] + for i in range(expected_samples): + length = int(input_lengths[i]) + decoded = tokenizer.decode( + input_ids[i, :length].tolist(), skip_special_tokens=False + ) + assert " 16" in decoded[-64:], ( + f"sample {i}: decoded tail {decoded[-64:]!r} missing ' 16'" + ) + assert int(token_mask[i, :length].sum().item()) > 0, ( + f"sample {i}: token_mask has no assistant tokens" + ) + + tags = ray.get( + tq_actor.get_tags.remote(partition_id=_PARTITION_ID, sample_ids=meta.sample_ids) + ) + for tag in tags: + assert tag["weight_version"] == 0 + # Slim tag schema: weight_version is the only field producers stamp. + assert set(tag) == {"weight_version"} diff --git a/tests/unit/single_controller/test_single_controller_setup.py b/tests/unit/single_controller/test_single_controller_setup.py new file mode 100644 index 0000000000..1f526bd055 --- /dev/null +++ b/tests/unit/single_controller/test_single_controller_setup.py @@ -0,0 +1,294 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for setup_single_controller (factories monkey-patched).""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +import nemo_rl.algorithms.single_controller_utils.setup as sc_setup_mod +from nemo_rl.algorithms.loss import ClippedPGLossConfig +from nemo_rl.algorithms.single_controller_utils import ( + MasterConfig, + SingleControllerBundle, + setup_single_controller, +) + + +def _make_master_config( + *, + dp_enabled: bool = True, + use_multiple_dataloader: bool = False, + colocated: bool = True, + backend: str = "vllm", + megatron_enabled: bool = False, + env: dict | None = None, + max_num_steps: int = 100, + max_num_epochs: int = 1, + num_prompts_per_step: int = 4, +) -> MasterConfig: + """Build a partially-populated MasterConfig for unit tests. + + Cross-cutting components (cluster/checkpointing/...) are required by pydantic for + normal load but unused here — model_construct skips validation, and we hand-fill + only the dict-shaped fields setup reads. + """ + return MasterConfig.model_construct( + data_plane={"enabled": dp_enabled, "impl": "transfer_queue"}, + data={ + "use_multiple_dataloader": use_multiple_dataloader, + "shuffle": False, + "num_workers": 0, + "train": [{"env_name": "math"}], + }, + grpo={ + "max_num_steps": max_num_steps, + "max_num_epochs": max_num_epochs, + "num_prompts_per_step": num_prompts_per_step, + "num_generations_per_prompt": 2, + "max_rollout_turns": 1, + }, + policy={ + "max_total_sequence_length": 32, + "megatron_cfg": {"enabled": megatron_enabled}, + "generation": { + "backend": backend, + "colocated": {"enabled": colocated, "resources": {}}, + }, + }, + loss_fn=ClippedPGLossConfig(), + env=env if env is not None else {}, + ) + + +@pytest.fixture +def patched_factories(): + """Patch every external factory setup calls. + + Returns a dict of mocks keyed by name so individual tests can assert on call args + without re-importing the patch handles. + """ + fake_dataset = list(range(8)) + fake_dataloader = MagicMock(name="dataloader") + # len(dataloader) used by the Megatron train_iters injection. + fake_dataloader.__len__ = MagicMock(return_value=4) + fake_env_handles = {"math": MagicMock(name="math_env")} + + with ( + patch.object( + sc_setup_mod, + "setup_response_data", + return_value=(fake_dataset, None, fake_env_handles, {}), + ) as mock_setup_response, + patch.object( + sc_setup_mod, + "StatefulDataLoader", + return_value=fake_dataloader, + ) as mock_dataloader, + patch.object( + sc_setup_mod, + "_build_clusters", + return_value=( + MagicMock(name="train_cluster"), + MagicMock(name="inference_cluster"), + ), + ) as mock_clusters, + patch.object( + sc_setup_mod, "_build_generation", return_value=MagicMock(name="gen") + ) as mock_gen, + patch.object( + sc_setup_mod, "_build_trainer", return_value=MagicMock(name="policy") + ) as mock_trainer, + patch.object( + sc_setup_mod, + "build_data_plane_client", + return_value=MagicMock(name="dp_client"), + ) as mock_dp_client, + patch.object( + sc_setup_mod, + "create_weight_synchronizer", + return_value=MagicMock(name="weight_sync"), + ) as mock_weight_sync, + patch.object( + sc_setup_mod, + "_create_advantage_estimator", + return_value=MagicMock(name="adv"), + ) as mock_adv, + patch.object( + sc_setup_mod, "ClippedPGLossFn", return_value=MagicMock(name="loss_fn") + ) as mock_loss, + patch.object( + sc_setup_mod, + "_generation_max_seq_len", + return_value=32, + ), + ): + yield { + "setup_response_data": mock_setup_response, + "StatefulDataLoader": mock_dataloader, + "_build_clusters": mock_clusters, + "_build_generation": mock_gen, + "_build_trainer": mock_trainer, + "build_data_plane_client": mock_dp_client, + "create_weight_synchronizer": mock_weight_sync, + "_create_advantage_estimator": mock_adv, + "ClippedPGLossFn": mock_loss, + "dataloader": fake_dataloader, + "env_handles": fake_env_handles, + } + + +class TestSetup: + """setup arg validation + bundle assembly.""" + + def test_raises_when_data_plane_disabled(self): + mc = _make_master_config(dp_enabled=False) + with pytest.raises(ValueError, match="data_plane.enabled=True"): + setup_single_controller(mc, MagicMock()) + + def test_multiple_dataloader_not_supported(self): + mc = _make_master_config(use_multiple_dataloader=True) + with pytest.raises(NotImplementedError, match="use_multiple_dataloader"): + setup_single_controller(mc, MagicMock(pad_token_id=0)) + + def test_returns_bundle(self, patched_factories): + mc = _make_master_config(colocated=True) + tokenizer = MagicMock(pad_token_id=0) + + bundle = setup_single_controller(mc, tokenizer) + + assert isinstance(bundle, SingleControllerBundle) + assert bundle.gen_handle is patched_factories["_build_generation"].return_value + assert bundle.trainer_handle is patched_factories["_build_trainer"].return_value + assert bundle.env_handles is patched_factories["env_handles"] + assert ( + bundle.dp_client + is patched_factories["build_data_plane_client"].return_value + ) + assert bundle.dataloader is patched_factories["dataloader"] + assert bundle.weight_synchronizer is ( + patched_factories["create_weight_synchronizer"].return_value + ) + assert bundle.advantage_estimator is ( + patched_factories["_create_advantage_estimator"].return_value + ) + assert bundle.loss_fn is patched_factories["ClippedPGLossFn"].return_value + # tq_buffer + rollout_manager are constructed inline (not mocked). + assert bundle.tq_buffer is not None + assert bundle.rollout_manager is not None + # rollout_manager binds the same tq_buffer for the writer + sampler. + assert bundle.rollout_manager._tq_buffer is bundle.tq_buffer + # tq_buffer wires the dp_client + default partition. + assert bundle.tq_buffer._dp_client is bundle.dp_client + assert bundle.partition_id == "rollout_data" + assert bundle.tq_buffer._partition_id == "rollout_data" + + def test_env_handles_sourced_from_setup_response_data(self, patched_factories): + """setup_response_data receives master_config.env and supplies env handles.""" + math_env_cfg = {"some": "value"} + mc = _make_master_config(env={"math": math_env_cfg}) + + bundle = setup_single_controller(mc, MagicMock(pad_token_id=0)) + + _, call_kwargs = patched_factories["setup_response_data"].call_args + assert call_kwargs["env_configs"] == {"math": math_env_cfg} + assert bundle.env_handles is patched_factories["env_handles"] + + def test_weight_sync_factory_args(self, patched_factories): + """create_weight_synchronizer receives policy / generation / topology.""" + mc = _make_master_config(colocated=False, backend="vllm") + tokenizer = MagicMock(pad_token_id=0) + + setup_single_controller(mc, tokenizer) + + _, factory_kwargs = patched_factories["create_weight_synchronizer"].call_args + assert ( + factory_kwargs["policy"] is patched_factories["_build_trainer"].return_value + ) + assert ( + factory_kwargs["generation"] + is patched_factories["_build_generation"].return_value + ) + assert factory_kwargs["generation_backend"] == "vllm" + assert factory_kwargs["colocated"] is False + + def test_custom_partition_id(self, patched_factories): + mc = _make_master_config() + tokenizer = MagicMock(pad_token_id=7) + + bundle = setup_single_controller(mc, tokenizer, partition_id="custom_partition") + + assert bundle.partition_id == "custom_partition" + assert bundle.tq_buffer._partition_id == "custom_partition" + assert bundle.tq_buffer._pad_value_dict == { + "token_ids": 7, + "input_ids": 7, + } + + def test_max_num_steps_capped_by_self(self, patched_factories): + """grpo.max_num_steps stays put when smaller than max_num_epochs * len(dl).""" + mc = _make_master_config( + megatron_enabled=False, + max_num_steps=2, + max_num_epochs=1, + ) + # patched dataloader has len() == 4, so the min picks max_num_steps. + setup_single_controller(mc, MagicMock(pad_token_id=0)) + + assert mc.grpo["max_num_steps"] == 2 + + def test_max_num_steps_capped_by_dataloader_epochs(self, patched_factories): + """grpo.max_num_steps drops to max_num_epochs * len(dataloader) when smaller.""" + mc = _make_master_config( + megatron_enabled=False, + max_num_steps=1000, + max_num_epochs=2, + ) + # patched dataloader has len() == 4 → 2 * 4 = 8 < 1000. + setup_single_controller(mc, MagicMock(pad_token_id=0)) + + assert mc.grpo["max_num_steps"] == 8 + + def test_megatron_train_iters_capped_by_max_num_steps(self, patched_factories): + """train_iters = min(max_num_steps, max_num_epochs * len(dataloader)).""" + mc = _make_master_config( + megatron_enabled=True, + max_num_steps=2, + max_num_epochs=1, + ) + # patched dataloader has len() == 4, so the min picks max_num_steps. + setup_single_controller(mc, MagicMock(pad_token_id=0)) + + assert mc.policy["megatron_cfg"]["train_iters"] == 2 + + def test_megatron_train_iters_capped_by_dataloader_epochs(self, patched_factories): + """train_iters drops to max_num_epochs * len(dataloader) when smaller.""" + mc = _make_master_config( + megatron_enabled=True, + max_num_steps=1000, + max_num_epochs=2, + ) + # patched dataloader has len() == 4 → 2 * 4 = 8 < 1000. + setup_single_controller(mc, MagicMock(pad_token_id=0)) + + assert mc.policy["megatron_cfg"]["train_iters"] == 8 + + def test_megatron_train_iters_not_set_when_disabled(self, patched_factories): + mc = _make_master_config(megatron_enabled=False) + setup_single_controller(mc, MagicMock(pad_token_id=0)) + + assert "train_iters" not in mc.policy.get("megatron_cfg", {}) diff --git a/tests/unit/single_controller/test_staleness_sampler.py b/tests/unit/single_controller/test_staleness_sampler.py new file mode 100644 index 0000000000..ce15f34acf --- /dev/null +++ b/tests/unit/single_controller/test_staleness_sampler.py @@ -0,0 +1,466 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for StalenessSampler (pure filter over TQReplayBuffer state).""" + +from __future__ import annotations + +import asyncio + +import pytest + +from nemo_rl.algorithms.async_utils.staleness_sampler import StalenessSampler +from nemo_rl.data_plane import KVBatchMeta + + +class FakeBuffer: + """Minimal TQReplayBuffer surface used by StalenessSampler tests.""" + + def __init__(self, partition_id: str = "rollout_data") -> None: + self._partition_id = partition_id + self.meta_list: list[KVBatchMeta | None] = [] + self.start_weight_list: list[int] = [] + self.end_weight_list: list[int] = [] + self.ready_list: list[bool] = [] + self.remove_calls: list[tuple[list[int], bool]] = [] + + def add( + self, + group_id: str, + weight: int, + group_size: int = 1, + ready: bool = True, + end_weight: int | None = None, + ) -> KVBatchMeta: + sample_ids = [f"{group_id}_g{i}" for i in range(group_size)] + meta = KVBatchMeta( + partition_id=self._partition_id, + task_name=None, + sample_ids=sample_ids, + tags=[{"weight_version": weight, "group_id": group_id}] * group_size, + ) + self.meta_list.append(meta if ready else None) + self.start_weight_list.append(weight) + self.end_weight_list.append(weight if end_weight is None else end_weight) + self.ready_list.append(ready) + return meta + + async def remove(self, idxs: list[int], remove_in_dp: bool) -> int: + self.remove_calls.append((list(idxs), remove_in_dp)) + for i in sorted(idxs, reverse=True): + del self.meta_list[i] + del self.start_weight_list[i] + del self.end_weight_list[i] + del self.ready_list[i] + return len(idxs) + + +def _run(coro): + return asyncio.run(coro) + + +class TestStalenessSamplerSelect: + def test_select_returns_none_when_insufficient(self): + buf = FakeBuffer() + buf.add("g0", weight=5) + sampler = StalenessSampler(buf, max_staleness_versions=2) + + result = _run( + sampler.select( + current_train_weight=5, min_prompt_groups=2, max_prompt_groups=2 + ) + ) + assert result == (None, 0) + + def test_select_returns_none_on_empty_buffer(self): + buf = FakeBuffer() + sampler = StalenessSampler(buf, max_staleness_versions=2) + + result = _run( + sampler.select( + current_train_weight=5, min_prompt_groups=1, max_prompt_groups=1 + ) + ) + assert result == (None, 0) + + def test_select_filters_by_staleness_window(self): + buf = FakeBuffer() + # Weights 3, 4, 5, 2, 6 against trainer=5, max_staleness=2: + # lags = 2, 1, 0, 3 (stale), -1 (future) + for i, w in enumerate([3, 4, 5, 2, 6]): + buf.add(f"g{i}", weight=w) + sampler = StalenessSampler( + buf, max_staleness_versions=2, sample_freshest_first=True + ) + + selected, num_groups = _run( + sampler.select( + current_train_weight=5, min_prompt_groups=2, max_prompt_groups=2 + ) + ) + + assert selected is not None + # Freshest first → g2 (lag 0), g1 (lag 1) + assert selected.sample_ids == ["g2_g0", "g1_g0"] + assert num_groups == 2 + + def test_select_freshest_first_orders_by_lag(self): + buf = FakeBuffer() + for w in [3, 4, 5]: + buf.add(f"v{w}", weight=w) + sampler = StalenessSampler( + buf, max_staleness_versions=5, sample_freshest_first=True + ) + + selected, num_groups = _run( + sampler.select( + current_train_weight=6, min_prompt_groups=2, max_prompt_groups=2 + ) + ) + assert selected is not None + assert selected.sample_ids == ["v5_g0", "v4_g0"] + assert num_groups == 2 + + def test_select_fifo_orders_by_insertion(self): + buf = FakeBuffer() + for w in [3, 4, 5]: + buf.add(f"v{w}", weight=w) + sampler = StalenessSampler( + buf, max_staleness_versions=5, sample_freshest_first=False + ) + + selected, num_groups = _run( + sampler.select( + current_train_weight=6, min_prompt_groups=2, max_prompt_groups=2 + ) + ) + assert selected is not None + assert selected.sample_ids == ["v3_g0", "v4_g0"] + assert num_groups == 2 + + def test_select_skips_future_weight(self): + buf = FakeBuffer() + buf.add("now", weight=5) + buf.add("future", weight=7) + sampler = StalenessSampler(buf, max_staleness_versions=10) + + selected, num_groups = _run( + sampler.select( + current_train_weight=5, min_prompt_groups=1, max_prompt_groups=1 + ) + ) + + assert selected is not None + assert selected.sample_ids == ["now_g0"] + assert num_groups == 1 + + def test_select_concats_groups(self): + buf = FakeBuffer() + buf.add("g0", weight=5, group_size=2) + buf.add("g1", weight=5, group_size=2) + sampler = StalenessSampler(buf, max_staleness_versions=0) + + selected, num_groups = _run( + sampler.select( + current_train_weight=5, min_prompt_groups=2, max_prompt_groups=2 + ) + ) + + assert selected is not None + assert selected.sample_ids == [ + "g0_g0", + "g0_g1", + "g1_g0", + "g1_g1", + ] + # Two groups concatenated, each of size 2 → 4 sample_ids total. + assert num_groups == 2 + + def test_select_strict_on_policy_requires_exact_version(self): + buf = FakeBuffer() + for i, w in enumerate([4, 5, 5, 6]): + buf.add(f"g{i}", weight=w) + sampler = StalenessSampler(buf, max_staleness_versions=0) + + # 3 eligible (need weight=5), only have 2 + result = _run( + sampler.select( + current_train_weight=5, min_prompt_groups=3, max_prompt_groups=3 + ) + ) + assert result == (None, 0) + + # Buffer still intact: select with min=3 returned None without dropping anything. + selected, num_groups = _run( + sampler.select( + current_train_weight=5, min_prompt_groups=2, max_prompt_groups=2 + ) + ) + assert selected is not None + assert selected.sample_ids == ["g1_g0", "g2_g0"] + assert num_groups == 2 + + def test_select_drops_returned_entries_from_buffer(self): + buf = FakeBuffer() + for i, w in enumerate([5, 5, 5]): + buf.add(f"g{i}", weight=w) + sampler = StalenessSampler(buf, max_staleness_versions=0) + + first_meta, first_num_groups = _run( + sampler.select( + current_train_weight=5, min_prompt_groups=1, max_prompt_groups=1 + ) + ) + assert first_meta is not None + assert first_meta.sample_ids == ["g0_g0"] + assert first_num_groups == 1 + assert buf.start_weight_list == [5, 5] + # remove_in_dp=False; DP rows kept for trainer. + assert buf.remove_calls[-1][1] is False + + second_meta, second_num_groups = _run( + sampler.select( + current_train_weight=5, min_prompt_groups=1, max_prompt_groups=1 + ) + ) + assert second_meta is not None + assert second_meta.sample_ids == ["g1_g0"] + assert second_num_groups == 1 + + def test_select_rejects_zero_min_prompt_groups(self): + buf = FakeBuffer() + sampler = StalenessSampler(buf, max_staleness_versions=0) + with pytest.raises(ValueError): + _run( + sampler.select( + current_train_weight=0, min_prompt_groups=0, max_prompt_groups=0 + ) + ) + + def test_select_rejects_max_less_than_min(self): + buf = FakeBuffer() + for i in range(3): + buf.add(f"g{i}", weight=5) + sampler = StalenessSampler(buf, max_staleness_versions=0) + + with pytest.raises(ValueError): + _run( + sampler.select( + current_train_weight=5, min_prompt_groups=2, max_prompt_groups=1 + ) + ) + + def test_select_caps_at_max_prompt_groups(self): + buf = FakeBuffer() + for i in range(5): + buf.add(f"g{i}", weight=5) + sampler = StalenessSampler(buf, max_staleness_versions=0) + + selected, num_groups = _run( + sampler.select( + current_train_weight=5, min_prompt_groups=2, max_prompt_groups=3 + ) + ) + + assert selected is not None + # FIFO order; capped at max=3 even though 5 are eligible. + assert selected.sample_ids == ["g0_g0", "g1_g0", "g2_g0"] + assert num_groups == 3 + # The remaining two stay in the buffer. + assert buf.start_weight_list == [5, 5] + + def test_select_takes_all_available_when_between_min_and_max(self): + buf = FakeBuffer() + for i in range(3): + buf.add(f"g{i}", weight=5) + sampler = StalenessSampler(buf, max_staleness_versions=0) + + selected, num_groups = _run( + sampler.select( + current_train_weight=5, min_prompt_groups=2, max_prompt_groups=8 + ) + ) + + assert selected is not None + assert selected.sample_ids == ["g0_g0", "g1_g0", "g2_g0"] + assert num_groups == 3 + assert buf.start_weight_list == [] + + +class TestStalenessSamplerEvict: + def test_evict_removes_stale_groups(self): + buf = FakeBuffer() + # trainer=5, max_staleness=1 → lag >1 means stale (weights 0, 1, 2 stale; 4, 5 fresh) + for i, w in enumerate([0, 1, 4, 5, 2]): + buf.add(f"g{i}", weight=w) + sampler = StalenessSampler(buf, max_staleness_versions=1) + + dropped = _run(sampler.evict(current_train_weight=5)) + + assert dropped == 3 + assert buf.start_weight_list == [4, 5] + # Survivors' sample_ids + assert [m.sample_ids[0] for m in buf.meta_list] == ["g2_g0", "g3_g0"] + + def test_evict_returns_zero_when_nothing_stale(self): + buf = FakeBuffer() + for w in [4, 5]: + buf.add(f"v{w}", weight=w) + sampler = StalenessSampler(buf, max_staleness_versions=1) + + assert _run(sampler.evict(current_train_weight=5)) == 0 + assert buf.remove_calls == [] + + def test_evict_keeps_future_groups(self): + buf = FakeBuffer() + buf.add("future", weight=7) + sampler = StalenessSampler(buf, max_staleness_versions=0) + + assert _run(sampler.evict(current_train_weight=5)) == 0 + assert buf.start_weight_list == [7] + + def test_evict_drops_whole_group(self): + buf = FakeBuffer() + buf.add("stale", weight=1, group_size=4) + buf.add("fresh", weight=5, group_size=4) + sampler = StalenessSampler(buf, max_staleness_versions=1) + + dropped = _run(sampler.evict(current_train_weight=5)) + + assert dropped == 1 + assert buf.remove_calls == [([0], True)] + assert buf.start_weight_list == [5] + assert [m.sample_ids[0] for m in buf.meta_list] == ["fresh_g0"] + + +class TestStalenessSamplerInit: + def test_rejects_negative_max_staleness(self): + buf = FakeBuffer() + with pytest.raises(ValueError): + StalenessSampler(buf, max_staleness_versions=-1) + + def test_rejects_require_order_with_freshest_first(self): + buf = FakeBuffer() + with pytest.raises(ValueError): + StalenessSampler( + buf, + max_staleness_versions=0, + sample_freshest_first=True, + require_order=True, + ) + + +class TestStalenessSamplerReady: + def test_default_mode_skips_unready_slots(self): + buf = FakeBuffer() + buf.add("g0", weight=5, ready=False) + buf.add("g1", weight=5, ready=True) + sampler = StalenessSampler(buf, max_staleness_versions=0) + + selected, num_groups = _run( + sampler.select( + current_train_weight=5, min_prompt_groups=1, max_prompt_groups=1 + ) + ) + + assert selected is not None + assert selected.sample_ids == ["g1_g0"] + assert num_groups == 1 + + def test_default_mode_waits_when_too_few_ready(self): + buf = FakeBuffer() + buf.add("g0", weight=5, ready=False) + buf.add("g1", weight=5, ready=True) + sampler = StalenessSampler(buf, max_staleness_versions=0) + + result = _run( + sampler.select( + current_train_weight=5, min_prompt_groups=2, max_prompt_groups=2 + ) + ) + assert result == (None, 0) + + +class TestStalenessSamplerRequireOrder: + def test_consumes_oldest_batch_first(self): + buf = FakeBuffer() + # Two complete batches: v=4 then v=5; require_order must take v=4 first. + for i, w in enumerate((4, 4, 5, 5)): + buf.add(f"v{w}_{i}", weight=w) + sampler = StalenessSampler(buf, max_staleness_versions=1, require_order=True) + + selected, num_groups = _run( + sampler.select( + current_train_weight=5, min_prompt_groups=2, max_prompt_groups=2 + ) + ) + + assert selected is not None + # Insertion-order FIFO inside the oldest batch. + assert selected.sample_ids == ["v4_0_g0", "v4_1_g0"] + assert num_groups == 2 + assert buf.start_weight_list == [5, 5] + + def test_waits_when_oldest_batch_partially_ready(self): + buf = FakeBuffer() + # Oldest batch v=4 has 1 ready + 1 unready; v=5 batch is fully ready. + # require_order must NOT skip ahead to v=5. + buf.add("v4_a", weight=4, ready=True) + buf.add("v4_b", weight=4, ready=False) + buf.add("v5_a", weight=5, ready=True) + buf.add("v5_b", weight=5, ready=True) + sampler = StalenessSampler(buf, max_staleness_versions=1, require_order=True) + + result = _run( + sampler.select( + current_train_weight=5, min_prompt_groups=2, max_prompt_groups=2 + ) + ) + assert result == (None, 0) + # Buffer untouched: nothing removed. + assert buf.start_weight_list == [4, 4, 5, 5] + assert buf.ready_list == [True, False, True, True] + + def test_returns_none_when_oldest_batch_not_filled(self): + buf = FakeBuffer() + buf.add("v4_a", weight=4, ready=True) + # Only 1 ready in oldest batch; need 2. + sampler = StalenessSampler(buf, max_staleness_versions=1, require_order=True) + + result = _run( + sampler.select( + current_train_weight=5, min_prompt_groups=2, max_prompt_groups=2 + ) + ) + assert result == (None, 0) + + def test_ignores_future_versions_when_picking_target(self): + buf = FakeBuffer() + # Trainer at 5, staleness 1: window is [4, 5]; v=7 (future) must not + # become the oldest target. + buf.add("v7", weight=7, ready=True) + buf.add("v5_a", weight=5, ready=True) + buf.add("v5_b", weight=5, ready=True) + sampler = StalenessSampler(buf, max_staleness_versions=1, require_order=True) + + selected, num_groups = _run( + sampler.select( + current_train_weight=5, min_prompt_groups=2, max_prompt_groups=2 + ) + ) + + assert selected is not None + assert selected.sample_ids == ["v5_a_g0", "v5_b_g0"] + assert num_groups == 2 + assert buf.start_weight_list == [7] diff --git a/tests/unit/single_controller/test_tq_replay_buffer.py b/tests/unit/single_controller/test_tq_replay_buffer.py new file mode 100644 index 0000000000..55b45c9ece --- /dev/null +++ b/tests/unit/single_controller/test_tq_replay_buffer.py @@ -0,0 +1,324 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for TQReplayBuffer (plain SC-process buffer + TQ proxy).""" + +from __future__ import annotations + +import asyncio +from typing import Any + +import pytest +import torch + +import nemo_rl.algorithms.async_utils.replay_buffer as _replay_buffer_module +from nemo_rl.algorithms.async_utils.replay_buffer import TQReplayBuffer +from nemo_rl.data_plane import KVBatchMeta +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.experience.interfaces import PromptGroupRecord + +# Each record yields _N_GENS training rows. +_N_GENS = 2 + + +def _stub_record_to_train_batch( + record: PromptGroupRecord, *, pad_value_dict: Any +) -> BatchedDataDict[Any]: + del record, pad_value_dict + return BatchedDataDict[Any]( + { + "input_ids": torch.ones((_N_GENS, 3), dtype=torch.long), + "input_lengths": torch.full((_N_GENS,), 3, dtype=torch.long), + "total_reward": torch.zeros(_N_GENS, dtype=torch.float32), + } + ) + + +@pytest.fixture(autouse=True) +def _patch_converter(monkeypatch): + """Bypass the real ``record_to_train_batch`` so tests can use empty records.""" + monkeypatch.setattr( + _replay_buffer_module, + "record_to_train_batch", + _stub_record_to_train_batch, + ) + + +class FakeDataPlaneClient: + """Sync in-memory DataPlaneClient stub used by TQReplayBuffer tests.""" + + def __init__(self, partition_id: str = "rollout_data") -> None: + self._partition_id = partition_id + self._rows: dict[str, dict[str, Any]] = {} + self.put_calls: list[dict[str, Any]] = [] + self.clear_calls: list[list[str]] = [] + + def put_samples( + self, + sample_ids: list[str], + partition_id: str, + fields: Any = None, + tags: list[dict[str, Any]] | None = None, + ) -> KVBatchMeta: + assert partition_id == self._partition_id + self.put_calls.append( + { + "sample_ids": list(sample_ids), + "fields": fields, + "tags": [dict(t) for t in tags] if tags is not None else None, + } + ) + for i, sid in enumerate(sample_ids): + self._rows[sid] = { + "tag": dict(tags[i]) if tags is not None else {}, + } + return KVBatchMeta( + partition_id=partition_id, + task_name=None, + sample_ids=list(sample_ids), + fields=None, + tags=[dict(t) for t in tags] if tags is not None else None, + ) + + def clear_samples(self, sample_ids: list[str] | None, partition_id: str) -> None: + assert partition_id == self._partition_id + ids = list(sample_ids) if sample_ids is not None else list(self._rows) + self.clear_calls.append(list(ids)) + for sid in ids: + self._rows.pop(sid, None) + + def depth(self) -> int: + return len(self._rows) + + +def _run(coro): + return asyncio.run(coro) + + +def _make_record() -> PromptGroupRecord: + """Opaque PromptGroupRecord — converter is stubbed, so contents are unused.""" + return PromptGroupRecord( + prompt_idx=0, + prompt=[], + extra_env_info=None, + metadata={}, + completions=[], + rollout_metrics={}, + ) + + +def _make_buffer(dp: FakeDataPlaneClient) -> TQReplayBuffer: + return TQReplayBuffer( + dp, partition_id="rollout_data", pad_value_dict={"token_ids": 0} + ) + + +def _add_group( + buf: TQReplayBuffer, weight: int, end_weight: int | None = None +) -> KVBatchMeta: + if end_weight is None: + end_weight = weight + group_id = buf.reserve(weight_version=weight) + return _run( + buf.commit( + group_id, + _make_record(), + start_weight_version=weight, + end_weight_version=end_weight, + ) + ) + + +class TestTQReplayBufferReserveCommit: + def test_reserve_appends_placeholder_unready(self): + dp = FakeDataPlaneClient() + buf = _make_buffer(dp) + + group_id = buf.reserve(weight_version=3) + + assert isinstance(group_id, str) and group_id + assert buf.size() == 1 + assert buf.start_weight_list == [3] + assert buf.end_weight_list == [-1] + assert buf.ready_list == [False] + assert buf.meta_list == [None] + assert dp.depth() == 0 + assert dp.put_calls == [] + + def test_commit_writes_tq_then_fills_meta(self): + dp = FakeDataPlaneClient() + buf = _make_buffer(dp) + + group_id = buf.reserve(weight_version=3) + meta = _run( + buf.commit( + group_id, + _make_record(), + start_weight_version=3, + end_weight_version=4, + ) + ) + + # pack_payload stamps sample_ids as ``{group_uuid}_g{i}``. + assert len(meta.sample_ids) == _N_GENS + head, _, idx = meta.sample_ids[0].rpartition("_g") + assert head == group_id and idx == "0" + assert all(sid.startswith(group_id + "_g") for sid in meta.sample_ids) + assert dp.depth() == _N_GENS + assert buf.size() == 1 + assert buf.start_weight_list == [3] + assert buf.end_weight_list == [4] + assert buf.ready_list == [True] + assert buf.meta_list[0].sample_ids == meta.sample_ids + # TQ tag uses start_weight_version (dispatch time). + assert meta.tags == [{"weight_version": 3}] * _N_GENS + assert len(dp.put_calls) == 1 + + def test_commit_raises_for_unknown_group_id(self): + dp = FakeDataPlaneClient() + buf = _make_buffer(dp) + buf.reserve(weight_version=3) + + with pytest.raises(ValueError): + _run( + buf.commit( + "not-a-real-id", + _make_record(), + start_weight_version=3, + end_weight_version=3, + ) + ) + + def test_reserve_then_commit_preserves_dispatch_order(self): + """Reserve in dispatch order, commit out of order; insertion order holds.""" + dp = FakeDataPlaneClient() + buf = _make_buffer(dp) + + weights = (1, 2, 3) + gids = [buf.reserve(weight_version=w) for w in weights] + # Commit out of order: 2, 0, 1 — buffer order must still match reserve order. + for i in (2, 0, 1): + _run( + buf.commit( + gids[i], + _make_record(), + start_weight_version=weights[i], + end_weight_version=weights[i], + ) + ) + + assert buf.size() == 3 + assert buf.start_weight_list == [1, 2, 3] + assert buf.end_weight_list == [1, 2, 3] + assert buf.ready_list == [True, True, True] + # sample_id head equals reserved group_id at each slot. + for i, gid in enumerate(gids): + assert buf.meta_list[i] is not None + assert buf.meta_list[i].sample_ids[0].startswith(gid + "_g") + + def test_commit_appends_multiple_records_in_order(self): + dp = FakeDataPlaneClient() + buf = _make_buffer(dp) + + metas = [_add_group(buf, weight=w) for w in (1, 2, 3)] + + assert buf.size() == 3 + assert buf.start_weight_list == [1, 2, 3] + assert buf.end_weight_list == [1, 2, 3] + assert [m.sample_ids for m in buf.meta_list] == [ + list(metas[0].sample_ids), + list(metas[1].sample_ids), + list(metas[2].sample_ids), + ] + + +class TestTQReplayBufferRemove: + def test_remove_drops_indices_and_clears_dp_when_requested(self): + dp = FakeDataPlaneClient() + buf = _make_buffer(dp) + metas = [_add_group(buf, weight=g) for g in range(3)] + + n = _run(buf.remove([0, 2], remove_in_dp=True)) + + assert n == 2 + assert buf.size() == 1 + assert buf.start_weight_list == [1] + assert buf.end_weight_list == [1] + assert buf.meta_list[0].sample_ids == list(metas[1].sample_ids) + assert dp.depth() == _N_GENS + assert set(dp._rows) == set(metas[1].sample_ids) + + def test_remove_without_dp_keeps_rows(self): + dp = FakeDataPlaneClient() + buf = _make_buffer(dp) + metas = [_add_group(buf, weight=g) for g in range(2)] + + n = _run(buf.remove([0], remove_in_dp=False)) + + assert n == 1 + assert buf.size() == 1 + assert buf.start_weight_list == [1] + assert buf.end_weight_list == [1] + assert buf.meta_list[0].sample_ids == list(metas[1].sample_ids) + assert dp.clear_calls == [] + assert dp.depth() == 2 * _N_GENS + + def test_remove_rejects_out_of_range_before_mutating(self): + dp = FakeDataPlaneClient() + buf = _make_buffer(dp) + metas = [_add_group(buf, weight=g) for g in range(2)] + + with pytest.raises(IndexError, match=r"out of range: 5; size=2"): + _run(buf.remove([0, 5], remove_in_dp=True)) + + assert buf.size() == 2 + assert [m.sample_ids for m in buf.meta_list] == [ + list(metas[0].sample_ids), + list(metas[1].sample_ids), + ] + assert dp.depth() == 2 * _N_GENS + assert dp.clear_calls == [] + + def test_remove_empty_is_noop(self): + dp = FakeDataPlaneClient() + buf = _make_buffer(dp) + _add_group(buf, weight=0) + _add_group(buf, weight=0) + + n = _run(buf.remove([], remove_in_dp=True)) + + assert n == 0 + assert buf.size() == 2 + assert dp.depth() == 2 * _N_GENS + assert dp.clear_calls == [] + + +class TestTQReplayBufferSize: + def test_size_and_len(self): + dp = FakeDataPlaneClient() + buf = _make_buffer(dp) + assert buf.size() == 0 + assert len(buf) == 0 + + _add_group(buf, weight=0) + assert buf.size() == 1 + assert len(buf) == 1 + + _add_group(buf, weight=0) + assert buf.size() == 2 + assert len(buf) == 2 + + _run(buf.remove([0], remove_in_dp=True)) + assert buf.size() == 1 + assert len(buf) == 1