Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions tests/unit/data_plane/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared TQ data-plane client fixtures.

Mooncake's C++ engine keeps a process-global mount registry that survives
Python-level ``close()`` (upstream ``transfer_queue.close()`` also leaves
``mooncake_master`` running). Re-initializing the client in the same
pytest worker process leaks stale segment endpoints; the next
``batch_upsert_from`` then routes to a dead endpoint from a prior init
and returns ``TRANSFER_FAIL`` (-800). Production never hits this — the
driver bootstraps once and workers attach via ``bootstrap=False``.

Session-scoping the underlying clients here mirrors production: exactly
one mooncake init per pytest worker, period. Tests must use distinct
``partition_id`` values (seqpack-eq / dynbatch-eq / nopack-eq /
smoke / smoke-backend / smoke-1d / obj-backend / mix-e2e today).
"""

from __future__ import annotations

import pytest

from nemo_rl.data_plane import build_data_plane_client

from ._rollout_shapes import mooncake_available


def _make_tq_cfg(backend: str) -> dict:
return {
"enabled": True,
"impl": "transfer_queue",
"backend": backend,
"storage_capacity": 1024,
"num_storage_units": 1,
"claim_meta_poll_interval_s": 0.5,
"global_segment_size": 8589934592, # 8 GiB — sized for CI host RAM
"local_buffer_size": 1073741824, # 1 GiB
}


# Ray is started by the parent autouse ``init_ray_cluster`` fixture in
# ``tests/unit/conftest.py`` — no explicit init needed here.


@pytest.fixture(scope="session")
def _session_tq_client_simple():
client = build_data_plane_client(_make_tq_cfg("simple"))
yield client
client.close()


@pytest.fixture(scope="session")
def _session_tq_client_mooncake_cpu():
if not mooncake_available():
pytest.skip(
"mooncake not installed — skipping mooncake_cpu "
"(set NEMO_RL_REQUIRE_MOONCAKE=1 to fail loud)"
)
client = build_data_plane_client(_make_tq_cfg("mooncake_cpu"))
yield client
client.close()


@pytest.fixture
def tq_client(_session_tq_client_simple):
"""One simple-backend client shared across the pytest session."""
return _session_tq_client_simple


@pytest.fixture(
params=["simple", "mooncake_cpu"],
ids=["simple", "mooncake_cpu"],
)
def tq_client_backends(request):
"""Parametrized over [simple, mooncake_cpu] backends.

Each variant returns the session-scoped client for that backend, so
the mooncake_cpu client is initialized at most once per pytest worker
(see module docstring).
"""
return request.getfixturevalue(f"_session_tq_client_{request.param}")
84 changes: 13 additions & 71 deletions tests/unit/data_plane/test_seqpack_equivalence.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,11 @@
pytest.importorskip("ray")
transfer_queue = pytest.importorskip("transfer_queue") # noqa: F841

from nemo_rl.data_plane import build_data_plane_client, materialize # noqa: E402
from nemo_rl.data_plane import materialize # noqa: E402
from nemo_rl.distributed.batched_data_dict import BatchedDataDict # noqa: E402

from ._rollout_shapes import mooncake_available

# Ray is initialized once by the parent autouse fixture
# ``tests/unit/conftest.py::init_ray_cluster`` (mirrors production: NeMo-RL
# inits Ray at startup; the data plane attaches on top). Each test just
# builds a TQ client on the shared Ray and closes it on teardown.
# The parametrized ``tq_client_backends`` fixture (simple + mooncake_cpu)
# is provided by ``tests/unit/data_plane/conftest.py``.


# Mirror of the seed-field set in nemo_rl/algorithms/grpo_sync.py.
Expand All @@ -67,60 +63,6 @@
"sample_mask",
)

# ── loud-skip helpers ─────────────────────────────────────────────────────────

# ── fixtures ──────────────────────────────────────────────────────────────────


def _make_tq_cfg(backend: str) -> dict:
# DataPlaneConfig requires the full schema (see interfaces.py); the
# adapter dereferences ``claim_meta_poll_interval_s`` at construction
# so missing it short-circuits the fixture before any test runs.
# ``global_segment_size`` / ``local_buffer_size`` only matter for
# ``mooncake_cpu`` but are required for schema conformance.
return {
"enabled": True,
"impl": "transfer_queue",
"backend": backend,
"storage_capacity": 1024,
"num_storage_units": 1,
"claim_meta_poll_interval_s": 0.5,
"global_segment_size": 8589934592, # 8 GiB — sized for CI host RAM, not prod
"local_buffer_size": 1073741824, # 1 GiB
}


@pytest.fixture(
scope="module",
params=["simple", "mooncake_cpu"],
ids=["simple", "mooncake_cpu"],
)
def tq_client(request):
"""Parametrized fixture over simple and mooncake_cpu backends.

mooncake_cpu is skipped when the mooncake wheel is not installed.
Set NEMO_RL_REQUIRE_MOONCAKE=1 to promote the skip to a loud failure.

Module-scoped so the mooncake_master + Transfer Engine survive across
the test cases in this file: each test uses its own ``partition_id``
("seqpack-eq" / "dynbatch-eq" / "nopack-eq") so no cross-test data
leak is possible, and reusing one client avoids the close→re-init
race in mooncake's C++ mount registry (upstream ``transfer_queue``
leaks the master process on close; the C++ engine then keeps stale
endpoint references that 404 against the next run's fresh master).

Relies on parent autouse ``init_ray_cluster`` for the Ray runtime.
"""
backend = request.param
if backend == "mooncake_cpu" and not mooncake_available():
pytest.skip(
"mooncake not installed — skipping mooncake_cpu seqpack equivalence "
"(set NEMO_RL_REQUIRE_MOONCAKE=1 to fail loud)"
)
client = build_data_plane_client(_make_tq_cfg(backend))
yield client
client.close()


def _make_fake_train_data(
n_samples: int = 64,
Expand Down Expand Up @@ -156,7 +98,7 @@ def _make_fake_train_data(


def _round_trip_shards_through_tq(
tq_client,
client,
pre_shards: list,
partition_id: str,
) -> list[BatchedDataDict]:
Expand All @@ -167,7 +109,7 @@ def _round_trip_shards_through_tq(
fetches its slice and attaches ``extra_info`` packing metadata.
"""
n_total = sum(int(s["sample_mask"].shape[0]) for s in pre_shards)
tq_client.register_partition(
client.register_partition(
partition_id=partition_id,
fields=list(_DP_SEED_FIELDS),
num_samples=n_total,
Expand All @@ -186,12 +128,12 @@ def _round_trip_shards_through_tq(
{f: shard[f].detach().contiguous() for f in names},
batch_size=[n],
)
tq_client.put_samples(
client.put_samples(
sample_ids=keys,
partition_id=partition_id,
fields=fields,
)
td_back = tq_client.get_samples(
td_back = client.get_samples(
sample_ids=keys,
partition_id=partition_id,
select_fields=list(names),
Expand Down Expand Up @@ -236,7 +178,7 @@ def _assert_shards_byte_equal(legacy, recovered, *, expect_metadata: bool) -> No
)


def test_seqpack_legacy_equals_tq(tq_client):
def test_seqpack_legacy_equals_tq(tq_client_backends):
"""Sequence packing: legacy shards == TQ-roundtripped shards (byte-level)."""
DP_WORLD = 4
GBS = 64
Expand All @@ -260,14 +202,14 @@ def test_seqpack_legacy_equals_tq(tq_client):
sequence_packing_args=spa,
)
recovered = _round_trip_shards_through_tq(
tq_client,
tq_client_backends,
tq_pre_shards,
partition_id="seqpack-eq",
)
_assert_shards_byte_equal(legacy_shards, recovered, expect_metadata=True)


def test_dynbatch_legacy_equals_tq(tq_client):
def test_dynbatch_legacy_equals_tq(tq_client_backends):
"""Dynamic batching: same equivalence claim as seqpack."""
DP_WORLD = 4
GBS = 64
Expand All @@ -290,14 +232,14 @@ def test_dynbatch_legacy_equals_tq(tq_client):
dynamic_batching_args=dba,
)
recovered = _round_trip_shards_through_tq(
tq_client,
tq_client_backends,
tq_pre_shards,
partition_id="dynbatch-eq",
)
_assert_shards_byte_equal(legacy_shards, recovered, expect_metadata=True)


def test_no_packing_legacy_equals_tq(tq_client):
def test_no_packing_legacy_equals_tq(tq_client_backends):
"""Sanity: even without packing/dynbatch the transport should be lossless."""
DP_WORLD = 4
GBS = 64
Expand All @@ -306,7 +248,7 @@ def test_no_packing_legacy_equals_tq(tq_client):
legacy_shards = data.shard_by_batch_size(DP_WORLD, batch_size=GBS)
tq_pre_shards = data.shard_by_batch_size(DP_WORLD, batch_size=GBS)
recovered = _round_trip_shards_through_tq(
tq_client,
tq_client_backends,
tq_pre_shards,
partition_id="nopack-eq",
)
Expand Down
73 changes: 3 additions & 70 deletions tests/unit/data_plane/test_tq_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,11 @@

transfer_queue = pytest.importorskip("transfer_queue") # noqa: F841

from nemo_rl.data_plane import build_data_plane_client
from nemo_rl.data_plane.column_io import kv_first_write, read_columns
from nemo_rl.data_plane.interfaces import KVBatchMeta
from nemo_rl.data_plane.schema import DP_TRAIN_FIELDS
from nemo_rl.distributed.batched_data_dict import BatchedDataDict

from ._rollout_shapes import mooncake_available

# ── loud-skip helpers ─────────────────────────────────────────────────────────

# ── fixtures ──────────────────────────────────────────────────────────────────


def test_register_partition_uses_unique_schema_warmup_key(monkeypatch) -> None:
from nemo_rl.data_plane.adapters import transfer_queue as tq_adapter
Expand Down Expand Up @@ -89,69 +82,9 @@ def fake_clear(**kwargs):
]


@pytest.fixture
def tq_client():
import ray

if not ray.is_initialized():
ray.init(local_mode=False, include_dashboard=False)

client = build_data_plane_client(
{
"enabled": True,
"impl": "transfer_queue",
"backend": "simple",
"storage_capacity": 1024,
"num_storage_units": 1,
"claim_meta_poll_interval_s": 0.5,
"global_segment_size": 8589934592, # 8 GiB (only read by mooncake_cpu)
"local_buffer_size": 1073741824, # 1 GiB (only read by mooncake_cpu)
}
)
yield client
client.close()


@pytest.fixture(
scope="module",
params=["simple", "mooncake_cpu"],
ids=["simple", "mooncake_cpu"],
)
def tq_client_backends(request):
"""Parametrized fixture over simple and mooncake_cpu backends.

mooncake_cpu is skipped when the mooncake wheel is not installed.
Set NEMO_RL_REQUIRE_MOONCAKE=1 to promote the skip to a loud failure.

Module-scoped to dodge mooncake's close→re-init race (stale C++ mount
registry); safe because each test uses a distinct ``partition_id``.
"""
backend = request.param
if backend == "mooncake_cpu" and not mooncake_available():
pytest.skip(
"mooncake not installed — skipping mooncake_cpu backend "
"(set NEMO_RL_REQUIRE_MOONCAKE=1 to fail loud)"
)

import ray

if not ray.is_initialized():
ray.init(local_mode=False, include_dashboard=False)

client = build_data_plane_client(
{
"enabled": True,
"impl": "transfer_queue",
"backend": backend,
"storage_capacity": 1024,
"num_storage_units": 1,
"claim_meta_poll_interval_s": 0.5,
"global_segment_size": 8589934592, # 8 GiB
"local_buffer_size": 1073741824, # 1 GiB
}
)
yield client
client.close()
# ``tq_client`` (simple) and ``tq_client_backends`` (parametrized over
# simple + mooncake_cpu) are session-scoped fixtures provided by
# ``tests/unit/data_plane/conftest.py``. See that file for the rationale.


def test_smoke_round_trip(tq_client) -> None:
Expand Down
Loading