From a0fe5cb6129a4cb78f5229364e236eb4d9dc6c46 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 25 Jun 2026 04:03:21 +0000 Subject: [PATCH 01/18] refactor(ablp): extract per-anchor label loop into _loop_set_labels oracle Factor the inline per-anchor label-remap loop in DistABLPLoader._set_labels into a module-level function _loop_set_labels. The new function is a behavior-preserving extraction: _set_labels delegates to it, producing identical output. _loop_set_labels will serve as the equivalence oracle for the vectorized kernel added in the next task. Also imports PADDING_NODE from gigl.utils.data_splitters (used by the vectorized kernel in the next task) and adds the contract test file tests/unit/distributed/vectorized_set_labels_test.py. Co-Authored-By: Claude Opus 4.8 (1M context) --- gigl/distributed/dist_ablp_neighborloader.py | 121 ++++++++++++------ .../distributed/vectorized_set_labels_test.py | 111 ++++++++++++++++ 2 files changed, 190 insertions(+), 42 deletions(-) create mode 100644 tests/unit/distributed/vectorized_set_labels_test.py diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 50f42f5a9..cafededb5 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -49,12 +49,84 @@ reverse_edge_type, select_label_edge_types, ) -from gigl.utils.data_splitters import get_labels_for_anchor_nodes +from gigl.utils.data_splitters import PADDING_NODE, get_labels_for_anchor_nodes from gigl.utils.sampling import ABLPInputNodes logger = Logger() +def _loop_set_labels( + node_local_to_global_by_type: dict[NodeType, torch.Tensor], + positive_labels_by_edge_type: dict[EdgeType, torch.Tensor], + negative_labels_by_edge_type: dict[EdgeType, torch.Tensor], + supervision_edge_types: list[EdgeType], + to_device: torch.device, +) -> tuple[ + dict[EdgeType, dict[int, torch.Tensor]], + dict[EdgeType, dict[int, torch.Tensor]], +]: + """Per-anchor (loop) label remap from global label ids to local node indices. + + Reference implementation retained as the equivalence oracle for + :func:`vectorized_set_labels`. The production path uses the vectorized + kernel; this loop is exercised only by tests. + + For each label edge type and each anchor row of its ``[N_anchors, M]`` + ``-1``-padded label tensor, emits the ascending local indices into the + supervision node type's ``node`` map whose global id appears in that row, in + :func:`torch.nonzero` multiplicity. + + Args: + node_local_to_global_by_type (dict[NodeType, torch.Tensor]): Per node + type, a ``[N]`` tensor whose ``i``-th entry is the global id of + local node ``i``. + positive_labels_by_edge_type (dict[EdgeType, torch.Tensor]): Per + positive-label edge type, a ``[N_anchors, M]`` ``-1``-padded tensor + of global label ids. + negative_labels_by_edge_type (dict[EdgeType, torch.Tensor]): As above, + for negative-label edge types. May be empty. + supervision_edge_types (list[EdgeType]): Supervision edge types (unused + here; accepted for signature parity with the vectorized kernel). + to_device (torch.device): Device for every output tensor. + + Returns: + Tuple ``(y_positive, y_negative)``, each a + ``dict[message_passing_edge_type, dict[anchor_index, local_index_tensor]]`` + with an entry for every anchor index ``0..N_anchors-1``. + """ + del supervision_edge_types # Parity with vectorized_set_labels; not needed. + output_positive_labels: dict[EdgeType, dict[int, torch.Tensor]] = defaultdict(dict) + output_negative_labels: dict[EdgeType, dict[int, torch.Tensor]] = defaultdict(dict) + # Supervision edge types are (anchor_node_type, to, supervision_node_type), + # so the supervision node type is at index 2. + edge_index = 2 + for edge_type, label_tensor in positive_labels_by_edge_type.items(): + message_passing_edge_type = label_edge_type_to_message_passing_edge_type( + edge_type + ) + supervision_node_map = node_local_to_global_by_type[edge_type[edge_index]] + for local_anchor_node_id in range(label_tensor.size(0)): + positive_mask = ( + supervision_node_map.unsqueeze(1) == label_tensor[local_anchor_node_id] + ) + output_positive_labels[message_passing_edge_type][local_anchor_node_id] = ( + torch.nonzero(positive_mask)[:, 0].to(to_device) + ) + for edge_type, label_tensor in negative_labels_by_edge_type.items(): + message_passing_edge_type = label_edge_type_to_message_passing_edge_type( + edge_type + ) + supervision_node_map = node_local_to_global_by_type[edge_type[edge_index]] + for local_anchor_node_id in range(label_tensor.size(0)): + negative_mask = ( + supervision_node_map.unsqueeze(1) == label_tensor[local_anchor_node_id] + ) + output_negative_labels[message_passing_edge_type][local_anchor_node_id] = ( + torch.nonzero(negative_mask)[:, 0].to(to_device) + ) + return dict(output_positive_labels), dict(output_negative_labels) + + class DistABLPLoader(BaseDistLoader): # Counts instantiations of this class, per process. # This is needed so we can generate unique worker key for each instance, for graph store mode. @@ -775,48 +847,13 @@ def _set_labels( node_type_to_local_node_to_global_node[DEFAULT_HOMOGENEOUS_NODE_TYPE] = ( data.node ) - output_positive_labels: dict[EdgeType, dict[int, torch.Tensor]] = defaultdict( - dict - ) - output_negative_labels: dict[EdgeType, dict[int, torch.Tensor]] = defaultdict( - dict + output_positive_labels, output_negative_labels = _loop_set_labels( + node_local_to_global_by_type=node_type_to_local_node_to_global_node, + positive_labels_by_edge_type=positive_labels_by_label_edge_type, + negative_labels_by_edge_type=negative_labels_by_label_edge_type, + supervision_edge_types=self._supervision_edge_types, + to_device=self.to_device, ) - # We always have supervision edge types of the form (anchor_node_type, to, supervision_node_type) - # So we can index into the edge type accordingly. - edge_index = 2 - for edge_type, label_tensor in positive_labels_by_label_edge_type.items(): - for local_anchor_node_id in range(label_tensor.size(0)): - positive_mask = ( - node_type_to_local_node_to_global_node[ - edge_type[edge_index] - ].unsqueeze(1) - == label_tensor[local_anchor_node_id] - ) # shape [N, P], where N is the number of nodes and P is the number of positive labels for the current anchor node - - # Gets the indexes of the items in local_node_to_global_node which match any of the positive labels for the current anchor node - output_positive_labels[ - label_edge_type_to_message_passing_edge_type(edge_type) - ][local_anchor_node_id] = torch.nonzero(positive_mask)[:, 0].to( - self.to_device - ) - # Shape [X], where X is the number of indexes in the original local_node_to_global_node which match a node in the positive labels for the current anchor node - - for edge_type, label_tensor in negative_labels_by_label_edge_type.items(): - for local_anchor_node_id in range(label_tensor.size(0)): - negative_mask = ( - node_type_to_local_node_to_global_node[ - edge_type[edge_index] - ].unsqueeze(1) - == label_tensor[local_anchor_node_id] - ) # shape [N, M], where N is the number of nodes and M is the number of negative labels for the current anchor node - - # Gets the indexes of the items in local_node_to_global_node which match any of the negative labels for the current anchor node - output_negative_labels[ - label_edge_type_to_message_passing_edge_type(edge_type) - ][local_anchor_node_id] = torch.nonzero(negative_mask)[:, 0].to( - self.to_device - ) - # Shape [X], where X is the number of indexes in the original local_node_to_global_node which match a node in the negative labels for the current anchor node if not output_positive_labels: raise ValueError("No positive labels were found in the data!") elif len(output_positive_labels) == 1: diff --git a/tests/unit/distributed/vectorized_set_labels_test.py b/tests/unit/distributed/vectorized_set_labels_test.py new file mode 100644 index 000000000..a0a6c5e3e --- /dev/null +++ b/tests/unit/distributed/vectorized_set_labels_test.py @@ -0,0 +1,111 @@ +"""Unit tests for the ABLP label-remap loop oracle and vectorized kernel. + +These exercise the pure-tensor label-remap logic directly (no GLT, no +distributed runtime), so they run in-process without ``mp.spawn``. +""" + +import unittest + +import torch +from parameterized import param, parameterized + +from gigl.distributed.dist_ablp_neighborloader import ( + _loop_set_labels, + vectorized_set_labels, +) +from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation +from gigl.types.graph import ( + DEFAULT_HOMOGENEOUS_EDGE_TYPE, + DEFAULT_HOMOGENEOUS_NODE_TYPE, + message_passing_to_negative_label, + message_passing_to_positive_label, +) +from tests.test_assets.test_case import TestCase + +_CPU = torch.device("cpu") +_USER = NodeType("user") +_STORY = NodeType("story") +_USER_TO_STORY = EdgeType(_USER, Relation("to"), _STORY) + +_A = NodeType("a") +_B = NodeType("b") +_C = NodeType("c") +_A_TO_B = EdgeType(_A, Relation("to"), _B) +_A_TO_C = EdgeType(_A, Relation("to"), _C) + + +def _pos(edge_type: EdgeType) -> EdgeType: + return message_passing_to_positive_label(edge_type) + + +def _neg(edge_type: EdgeType) -> EdgeType: + return message_passing_to_negative_label(edge_type) + + +def _assert_label_dicts_equal( + actual: dict[EdgeType, dict[int, torch.Tensor]], + expected: dict[EdgeType, dict[int, torch.Tensor]], +) -> None: + assert set(actual.keys()) == set(expected.keys()), ( + f"{set(actual.keys())} != {set(expected.keys())}" + ) + for edge_type, inner in expected.items(): + actual_inner = actual[edge_type] + assert set(actual_inner.keys()) == set(inner.keys()), ( + f"{edge_type}: {set(actual_inner.keys())} != {set(inner.keys())}" + ) + for anchor, expected_tensor in inner.items(): + got = actual_inner[anchor] + assert got.dtype == torch.long, f"{edge_type}[{anchor}] dtype {got.dtype}" + torch.testing.assert_close(got, expected_tensor) + + +class LoopSetLabelsContractTest(TestCase): + def test_homogeneous_with_empty_and_padded_anchors(self) -> None: + # node holds global ids; index = local id. Supervision node type is + # _STORY (edge_type[2] of the positive-label edge type). + node = torch.tensor([10, 11, 12, 13, 14, 15, 16, 17]) + node_map = {_STORY: node} + # anchor 0 -> global 15 (local 5); anchor 1 -> {15,16} (local 5,6); + # anchor 2 -> fully padded (empty); anchor 3 -> global 99 (absent -> empty). + positives = { + _pos(_USER_TO_STORY): torch.tensor( + [[15, -1], [15, 16], [-1, -1], [99, -1]], dtype=torch.long + ) + } + y_pos, y_neg = _loop_set_labels( + node_local_to_global_by_type=node_map, + positive_labels_by_edge_type=positives, + negative_labels_by_edge_type={}, + supervision_edge_types=[_USER_TO_STORY], + to_device=_CPU, + ) + expected = { + _USER_TO_STORY: { + 0: torch.tensor([5]), + 1: torch.tensor([5, 6]), + 2: torch.tensor([], dtype=torch.long), + 3: torch.tensor([], dtype=torch.long), + } + } + _assert_label_dicts_equal(y_pos, expected) + self.assertEqual(y_neg, {}) + + def test_duplicate_label_columns_preserve_multiplicity(self) -> None: + # torch.nonzero over [N, M] yields one row index per matching column, + # so a node matching two identical label columns appears twice. + node = torch.tensor([10, 11, 12, 13, 14, 15]) + node_map = {_STORY: node} + positives = {_pos(_USER_TO_STORY): torch.tensor([[15, 15]], dtype=torch.long)} + y_pos, _ = _loop_set_labels( + node_local_to_global_by_type=node_map, + positive_labels_by_edge_type=positives, + negative_labels_by_edge_type={}, + supervision_edge_types=[_USER_TO_STORY], + to_device=_CPU, + ) + torch.testing.assert_close(y_pos[_USER_TO_STORY][0], torch.tensor([5, 5])) + + +if __name__ == "__main__": + unittest.main() From 00f4cdb2c358548bf26340a654968ae778e40d2f Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 25 Jun 2026 04:09:27 +0000 Subject: [PATCH 02/18] feat(ablp): add vectorized_set_labels kernel, equivalence-tested vs loop oracle Co-Authored-By: Claude Opus 4.8 (1M context) --- gigl/distributed/dist_ablp_neighborloader.py | 209 ++++++++++++++++++ .../distributed/vectorized_set_labels_test.py | 184 +++++++++++++++ 2 files changed, 393 insertions(+) diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index cafededb5..8eeb107d7 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -127,6 +127,215 @@ def _loop_set_labels( return dict(output_positive_labels), dict(output_negative_labels) +def _membership_remap( + label_tensor: torch.Tensor, + sorted_node: torch.Tensor, + sort_perm: torch.Tensor, + to_device: torch.device, +) -> tuple[torch.Tensor, torch.Tensor, int]: + """Shared per-tensor membership core for the vectorized label kernels. + + Remaps one ``[N_anchors, M]`` ``-1``-padded label tensor of global ids to the + flat ``(anchor_index, local_index)`` pair stream shared by both the dict and + edge-list builders, in :func:`torch.nonzero` order: ascending local index per + anchor, ties broken by ascending label column. The two callers + (:func:`_remap_one_label_tensor`, :func:`_remap_one_label_tensor_edge_list`) + differ only in how they package this pair stream, so the searchsorted + membership logic lives here once. + + Precondition (REQUIRED for correctness): ``sorted_node`` must have UNIQUE + values (the node map is unique local->global). :func:`torch.searchsorted` + returns the left-most equal position, so a duplicate global id would collapse + multiple local indices to one and diverge from the loop oracle. GiGL ``node`` + maps guarantee uniqueness; the check is asserted only under ``__debug__`` to + keep the hot path zero-cost (and is a no-op under ``python -O``). + + Args: + label_tensor (torch.Tensor): ``[N_anchors, M]`` ``-1``-padded global + label ids. + sorted_node (torch.Tensor): ``torch.sort`` of the supervision node map. + sort_perm (torch.Tensor): Permutation from ``torch.sort`` mapping sorted + positions back to original local indices. + to_device (torch.device): Device for the returned index tensors. + + Returns: + Tuple ``(anchor_index, local_index, num_anchors)``. ``anchor_index`` and + ``local_index`` are equal-length 1-D ``long`` tensors on ``to_device`` in + :func:`torch.nonzero` order (empty when nothing matched); + ``num_anchors == label_tensor.size(0)``. + """ + num_anchors = int(label_tensor.size(0)) + num_nodes = int(sorted_node.size(0)) + empty = torch.empty(0, dtype=torch.long, device=to_device) + if num_anchors == 0: + return empty, empty, num_anchors + + num_labels = int(label_tensor.size(1)) + flat = label_tensor.reshape(-1) + # Create on the label tensor's device: it is indexed below by `valid` + # (derived from `label_tensor`), so on GPU a CPU arange would raise + # "indices should be either on cpu or on the same device as the indexed + # tensor". CPU-only unit tests cannot catch this; see the CUDA-gated test. + anchor_of_entry = torch.arange( + num_anchors, device=label_tensor.device + ).repeat_interleave(num_labels) + + # Mask the padding sentinel BEFORE any search so we never gather with -1. + valid = flat != PADDING_NODE + flat = flat[valid] + anchor_of_entry = anchor_of_entry[valid] + + if num_nodes == 0 or flat.numel() == 0: + return empty, empty, num_anchors + + if __debug__: + assert int(torch.unique(sorted_node).numel()) == num_nodes, ( + "vectorized label remap requires a unique node local->global map; " + "duplicate global ids break the searchsorted membership lookup." + ) + positions = torch.searchsorted(sorted_node, flat) + positions = positions.clamp_(max=num_nodes - 1) + found = sorted_node[positions] == flat + local_idx = sort_perm[positions][found] + anchor_kept = anchor_of_entry[found] + + # Order within each anchor must match torch.nonzero over [N, M]: ascending + # local index, ties broken by ascending label column. searchsorted visits + # entries in (anchor, column) order, so a stable sort on a composite key + # (anchor primary, local index secondary) reproduces it. + composite_key = anchor_kept * (num_nodes + 1) + local_idx + order = torch.argsort(composite_key, stable=True) + return ( + anchor_kept[order].to(to_device).to(torch.long), + local_idx[order].to(to_device).to(torch.long), + num_anchors, + ) + + +def _remap_one_label_tensor( + label_tensor: torch.Tensor, + sorted_node: torch.Tensor, + sort_perm: torch.Tensor, + to_device: torch.device, +) -> dict[int, torch.Tensor]: + """Vectorized remap of one ``[N_anchors, M]`` padded label tensor to a dict. + + Thin wrapper over :func:`_membership_remap`: splits the shared + ``(anchor_index, local_index)`` pair stream into a per-anchor dict. For each + anchor row, the value is the ascending local indices into the original + (pre-sort) node order whose global id appears in that row, in + :func:`torch.nonzero` multiplicity. + + Args: + label_tensor (torch.Tensor): ``[N_anchors, M]`` ``-1``-padded global + label ids. + sorted_node (torch.Tensor): ``torch.sort`` of the supervision node map. + sort_perm (torch.Tensor): Permutation from ``torch.sort`` mapping sorted + positions back to original local indices. + to_device (torch.device): Device for every output tensor. + + Returns: + Mapping from anchor index ``0..N_anchors-1`` to a 1-D ``long`` tensor of + local indices (empty where the row matched nothing). + """ + anchor_index, local_idx, num_anchors = _membership_remap( + label_tensor, sorted_node, sort_perm, to_device + ) + # Defensive: `vectorized_set_labels` already `continue`s past zero-anchor + # tensors (to match the loop's defaultdict, which never creates the outer + # key), so this branch is unreachable from that caller. Kept so the helper is + # self-consistent for any external caller. + if num_anchors == 0: + return {} + counts = torch.bincount(anchor_index, minlength=num_anchors) + per_anchor = torch.split(local_idx, counts.tolist()) + return {anchor: per_anchor[anchor] for anchor in range(num_anchors)} + + +def vectorized_set_labels( + node_local_to_global_by_type: dict[NodeType, torch.Tensor], + positive_labels_by_edge_type: dict[EdgeType, torch.Tensor], + negative_labels_by_edge_type: dict[EdgeType, torch.Tensor], + supervision_edge_types: list[EdgeType], + to_device: torch.device, +) -> tuple[ + dict[EdgeType, dict[int, torch.Tensor]], + dict[EdgeType, dict[int, torch.Tensor]], +]: + """Vectorized label remap from global label ids to local node indices. + + Drop-in replacement for the per-anchor loop in :func:`_loop_set_labels`, + producing bit-for-bit identical ragged output without a per-anchor Python + loop. + + For each label edge type and each anchor row of its ``[N_anchors, M]`` + ``-1``-padded label tensor, emits the ascending local indices into the + supervision node type's ``node`` map whose global id appears in that row, in + :func:`torch.nonzero` multiplicity. The padding sentinel + (:data:`gigl.utils.data_splitters.PADDING_NODE`) is masked before any search, + so it is never used as a lookup key. Every anchor index ``0..N_anchors-1`` + receives a key; anchors with no in-subgraph labels map to an empty ``long`` + tensor. + + Precondition (REQUIRED for correctness): each ``node`` local->global map in + ``node_local_to_global_by_type`` MUST contain UNIQUE global ids. The + ``torch.searchsorted`` membership lookup returns the LEFT-MOST matching sorted + position; a repeated global id would resolve every match to a single local + index, dropping the duplicate and silently diverging from the loop. GiGL + ``node`` maps satisfy this by construction. + + Args: + node_local_to_global_by_type (dict[NodeType, torch.Tensor]): Per node + type, a ``[N]`` tensor whose ``i``-th entry is the global id of + local node ``i``. Global ids MUST be unique within each map. + positive_labels_by_edge_type (dict[EdgeType, torch.Tensor]): Per + positive-label edge type, a ``[N_anchors, M]`` ``-1``-padded tensor + of global label ids. + negative_labels_by_edge_type (dict[EdgeType, torch.Tensor]): As above, + for negative-label edge types. May be empty. + supervision_edge_types (list[EdgeType]): Supervision edge types (unused + here; accepted for signature parity with the loop reference). + to_device (torch.device): Device for every output tensor. + + Returns: + Tuple ``(y_positive, y_negative)``, each a + ``dict[message_passing_edge_type, dict[anchor_index, local_index_tensor]]`` + with an entry for every anchor index ``0..N_anchors-1``. + """ + del supervision_edge_types # Accepted for signature parity; not needed here. + edge_index = 2 # Supervision edge types are (anchor, to, supervision). + sorted_cache: dict[NodeType, tuple[torch.Tensor, torch.Tensor]] = {} + + def _sorted_for(node_type: NodeType) -> tuple[torch.Tensor, torch.Tensor]: + if node_type not in sorted_cache: + sorted_cache[node_type] = torch.sort( + node_local_to_global_by_type[node_type] + ) + return sorted_cache[node_type] + + output_positive_labels: dict[EdgeType, dict[int, torch.Tensor]] = {} + for edge_type, label_tensor in positive_labels_by_edge_type.items(): + # Match the loop's defaultdict: a zero-anchor tensor produces NO outer + # key (the loop's per-anchor body never runs). + if label_tensor.size(0) == 0: + continue + sorted_node, sort_perm = _sorted_for(edge_type[edge_index]) + output_positive_labels[ + label_edge_type_to_message_passing_edge_type(edge_type) + ] = _remap_one_label_tensor(label_tensor, sorted_node, sort_perm, to_device) + + output_negative_labels: dict[EdgeType, dict[int, torch.Tensor]] = {} + for edge_type, label_tensor in negative_labels_by_edge_type.items(): + if label_tensor.size(0) == 0: + continue + sorted_node, sort_perm = _sorted_for(edge_type[edge_index]) + output_negative_labels[ + label_edge_type_to_message_passing_edge_type(edge_type) + ] = _remap_one_label_tensor(label_tensor, sorted_node, sort_perm, to_device) + + return output_positive_labels, output_negative_labels + + class DistABLPLoader(BaseDistLoader): # Counts instantiations of this class, per process. # This is needed so we can generate unique worker key for each instance, for graph store mode. diff --git a/tests/unit/distributed/vectorized_set_labels_test.py b/tests/unit/distributed/vectorized_set_labels_test.py index a0a6c5e3e..6bd6c2a0b 100644 --- a/tests/unit/distributed/vectorized_set_labels_test.py +++ b/tests/unit/distributed/vectorized_set_labels_test.py @@ -107,5 +107,189 @@ def test_duplicate_label_columns_preserve_multiplicity(self) -> None: torch.testing.assert_close(y_pos[_USER_TO_STORY][0], torch.tensor([5, 5])) +class VectorizedSetLabelsEquivalenceTest(TestCase): + @parameterized.expand( + [ + param( + "homogeneous_present_empty_and_padded", + node_map={_STORY: torch.tensor([10, 11, 12, 13, 14, 15, 16, 17])}, + positives={ + _pos(_USER_TO_STORY): torch.tensor( + [[15, -1], [15, 16], [-1, -1], [99, -1]], dtype=torch.long + ) + }, + negatives={}, + supervision_edge_types=[_USER_TO_STORY], + ), + param( + "homogeneous_duplicate_labels", + node_map={_STORY: torch.tensor([10, 11, 12, 13, 14, 15])}, + positives={ + _pos(_USER_TO_STORY): torch.tensor( + [[15, 15], [11, 11]], dtype=torch.long + ) + }, + negatives={}, + supervision_edge_types=[_USER_TO_STORY], + ), + # UNSORTED node map + reversed label columns. With a sorted map + # `sort_perm` is the identity and a broken port that emits sorted (or + # global) order would still pass; here `torch.sort` permutes nontrivially + # (15->local0, 10->local1, 16->local2, 11->local3), and the label row is + # given high-id-first ([16, 15]). The loop oracle emits ascending LOCAL + # index regardless of column order, i.e. local 0 (g15) before local 2 + # (g16) -> [0, 2]; a kernel that forgot to map through `sort_perm`, or + # that preserved column order, would diverge. + param( + "unsorted_node_map_reversed_columns", + node_map={_STORY: torch.tensor([15, 10, 16, 11])}, + positives={ + _pos(_USER_TO_STORY): torch.tensor([[16, 15]], dtype=torch.long) + }, + negatives={}, + supervision_edge_types=[_USER_TO_STORY], + ), + # Real homogeneous keying: the loader keys the node map by + # DEFAULT_HOMOGENEOUS_NODE_TYPE and uses DEFAULT_HOMOGENEOUS_EDGE_TYPE + # for the supervision edge type (see _set_labels). Exercise that exact + # keying at the kernel level, not just a custom NodeType. + param( + "default_homogeneous_keying", + node_map={ + DEFAULT_HOMOGENEOUS_NODE_TYPE: torch.tensor([20, 10, 30, 11, 15]) + }, + positives={ + message_passing_to_positive_label( + DEFAULT_HOMOGENEOUS_EDGE_TYPE + ): torch.tensor([[30, 10], [-1, -1]], dtype=torch.long) + }, + negatives={}, + supervision_edge_types=[DEFAULT_HOMOGENEOUS_EDGE_TYPE], + ), + param( + "homogeneous_with_negatives", + node_map={_STORY: torch.tensor([10, 11, 12, 13, 14, 15, 16, 17])}, + positives={ + _pos(_USER_TO_STORY): torch.tensor([[15], [16]], dtype=torch.long) + }, + negatives={ + _neg(_USER_TO_STORY): torch.tensor( + [[13, 16], [17, -1]], dtype=torch.long + ) + }, + supervision_edge_types=[_USER_TO_STORY], + ), + param( + "heterogeneous_multi_edge_type", + node_map={ + _A: torch.tensor([10]), + _B: torch.tensor([11, 12, 13, 14, 20, 21]), + _C: torch.tensor([20, 21, 22, 23]), + }, + positives={ + _pos(_A_TO_B): torch.tensor([[13, 14]], dtype=torch.long), + _pos(_A_TO_C): torch.tensor([[22, 23]], dtype=torch.long), + }, + negatives={}, + supervision_edge_types=[_A_TO_B, _A_TO_C], + ), + param( + "heterogeneous_multi_edge_type_with_negatives", + node_map={ + _A: torch.tensor([10]), + _B: torch.tensor([11, 12, 13, 14, 15, 16]), + _C: torch.tensor([20, 21, 22, 23, 24, 25]), + }, + positives={ + _pos(_A_TO_B): torch.tensor([[13, 14]], dtype=torch.long), + _pos(_A_TO_C): torch.tensor([[22, 23]], dtype=torch.long), + }, + negatives={ + _neg(_A_TO_B): torch.tensor([[15, 16]], dtype=torch.long), + _neg(_A_TO_C): torch.tensor([[24, 25]], dtype=torch.long), + }, + supervision_edge_types=[_A_TO_B, _A_TO_C], + ), + param( + "all_anchors_empty", + node_map={_STORY: torch.tensor([10, 11, 12])}, + positives={ + _pos(_USER_TO_STORY): torch.tensor( + [[-1, -1], [99, 98]], dtype=torch.long + ) + }, + negatives={}, + supervision_edge_types=[_USER_TO_STORY], + ), + param( + "zero_anchors", + node_map={_STORY: torch.tensor([10, 11, 12])}, + positives={_pos(_USER_TO_STORY): torch.empty((0, 0), dtype=torch.long)}, + negatives={}, + supervision_edge_types=[_USER_TO_STORY], + ), + ] + ) + def test_matches_loop( + self, + _, + node_map: dict[NodeType, torch.Tensor], + positives: dict[EdgeType, torch.Tensor], + negatives: dict[EdgeType, torch.Tensor], + supervision_edge_types: list[EdgeType], + ) -> None: + loop_pos, loop_neg = _loop_set_labels( + node_local_to_global_by_type=node_map, + positive_labels_by_edge_type=positives, + negative_labels_by_edge_type=negatives, + supervision_edge_types=supervision_edge_types, + to_device=_CPU, + ) + vec_pos, vec_neg = vectorized_set_labels( + node_local_to_global_by_type=node_map, + positive_labels_by_edge_type=positives, + negative_labels_by_edge_type=negatives, + supervision_edge_types=supervision_edge_types, + to_device=_CPU, + ) + _assert_label_dicts_equal(vec_pos, loop_pos) + _assert_label_dicts_equal(vec_neg, loop_neg) + + def test_unsorted_node_map_exact_order(self) -> None: + # Belt-and-suspenders on top of the parameterized case: assert the EXACT + # tensor (not just equality-to-loop) so the expected ascending-local order + # is visible and a regression to sorted/column order is unmistakable. + node_map = {_STORY: torch.tensor([15, 10, 16, 11])} + positives = {_pos(_USER_TO_STORY): torch.tensor([[16, 15]], dtype=torch.long)} + vec_pos, _ = vectorized_set_labels( + node_local_to_global_by_type=node_map, + positive_labels_by_edge_type=positives, + negative_labels_by_edge_type={}, + supervision_edge_types=[_USER_TO_STORY], + to_device=_CPU, + ) + # g15 is local 0, g16 is local 2 -> ascending local order is [0, 2]. + torch.testing.assert_close( + vec_pos[_USER_TO_STORY][0], torch.tensor([0, 2], dtype=torch.long) + ) + + def test_duplicate_node_map_raises_assertion(self) -> None: + # NOTE: the uniqueness check is gated on `__debug__`, so this assertion is + # a no-op under `python -O` / `PYTHONOPTIMIZE`. GiGL node maps are unique by + # construction (each local index is a distinct subgraph node), so the guard + # exists only to catch misuse; the test asserts the guard fires under the + # default (non-optimized) interpreter used by the test suite. + node_map = {_STORY: torch.tensor([10, 10, 11])} + positives = {_pos(_USER_TO_STORY): torch.tensor([[10, 11]], dtype=torch.long)} + with self.assertRaises(AssertionError): + vectorized_set_labels( + node_local_to_global_by_type=node_map, + positive_labels_by_edge_type=positives, + negative_labels_by_edge_type={}, + supervision_edge_types=[_USER_TO_STORY], + to_device=_CPU, + ) + + if __name__ == "__main__": unittest.main() From fd3fda5a2e58eae40b19b918bfe48969eac7bb50 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 25 Jun 2026 04:17:08 +0000 Subject: [PATCH 03/18] feat(ablp): add AnchorLabels edge-list container + edge_list_set_labels kernel Add frozen dataclass AnchorLabels (anchor_index, label_index, num_anchors) with to_dict() bridge; thin wrapper _remap_one_label_tensor_edge_list over the shared _membership_remap; and edge_list_set_labels driver. Proves to_dict() reproduces _loop_set_labels bit-for-bit via parametrized equivalence tests (11 new tests, all classes pass). Co-Authored-By: Claude Opus 4.8 (1M context) --- gigl/distributed/dist_ablp_neighborloader.py | 152 +++++++++++ .../distributed/edge_list_set_labels_test.py | 252 ++++++++++++++++++ 2 files changed, 404 insertions(+) create mode 100644 tests/unit/distributed/edge_list_set_labels_test.py diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 8eeb107d7..cc76f46e8 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -1,4 +1,5 @@ from collections import abc, defaultdict +from dataclasses import dataclass from itertools import count from typing import Optional, Union @@ -55,6 +56,49 @@ logger = Logger() +@dataclass(frozen=True) +class AnchorLabels: + """Dense edge-list ABLP labels for one label edge type. + + Replaces the ragged per-anchor ``dict[int, torch.Tensor]`` with two parallel + 1-D ``long`` tensors. Pair ``k`` asserts that local anchor row + ``anchor_index[k]`` has label node ``label_index[k]``. + + Pairs are ordered ascending by anchor, ties broken ascending by label + column -- identical to ``torch.nonzero`` over the ``[N_anchors, M]`` padded + label tensor, so concatenating the legacy dict's values in anchor order + yields ``label_index`` and the matching anchor repeats yield + ``anchor_index``. + + Anchors with no in-subgraph labels contribute zero pairs; ``num_anchors`` + records the full anchor count so empty anchors remain recoverable. + + Args: + anchor_index (torch.Tensor): ``[E]`` long tensor of local anchor rows. + label_index (torch.Tensor): ``[E]`` long tensor of local label node ids. + num_anchors (int): Total number of anchors ``N`` (rows of the source + padded label tensor). + """ + + anchor_index: torch.Tensor + label_index: torch.Tensor + num_anchors: int + + def to_dict(self) -> dict[int, torch.Tensor]: + """Expand to the legacy ragged ``dict[int, torch.Tensor]`` form. + + Every anchor ``0..num_anchors-1`` receives a key; anchors with no labels + map to an empty ``long`` tensor on the same device as ``label_index``. + + Returns: + Mapping from anchor index to its 1-D ``long`` tensor of local label + node ids. + """ + counts = torch.bincount(self.anchor_index, minlength=self.num_anchors) + per_anchor = torch.split(self.label_index, counts.tolist()) + return {anchor: per_anchor[anchor] for anchor in range(self.num_anchors)} + + def _loop_set_labels( node_local_to_global_by_type: dict[NodeType, torch.Tensor], positive_labels_by_edge_type: dict[EdgeType, torch.Tensor], @@ -336,6 +380,114 @@ def _sorted_for(node_type: NodeType) -> tuple[torch.Tensor, torch.Tensor]: return output_positive_labels, output_negative_labels +def _remap_one_label_tensor_edge_list( + label_tensor: torch.Tensor, + sorted_node: torch.Tensor, + sort_perm: torch.Tensor, + to_device: torch.device, +) -> AnchorLabels: + """Vectorized edge-list remap of one ``[N_anchors, M]`` padded label tensor. + + Thin wrapper over the shared :func:`_membership_remap`: wraps the returned + ``(anchor_index, local_index)`` pair stream directly into a dense + :class:`AnchorLabels`. This is strictly less work than the dict builder (no + ``torch.bincount``/``torch.split``, no per-anchor Python comprehension); + ``_membership_remap`` already moved both tensors to ``to_device`` as ``long``. + + Args: + label_tensor (torch.Tensor): ``[N_anchors, M]`` ``-1``-padded global + label ids. + sorted_node (torch.Tensor): ``torch.sort`` of the supervision node map. + sort_perm (torch.Tensor): Permutation from ``torch.sort`` mapping sorted + positions back to original local indices. + to_device (torch.device): Device for the output tensors. + + Returns: + AnchorLabels with ``anchor_index``/``label_index`` in + :func:`torch.nonzero` order and ``num_anchors == label_tensor.size(0)``. + """ + anchor_index, label_index, num_anchors = _membership_remap( + label_tensor, sorted_node, sort_perm, to_device + ) + return AnchorLabels( + anchor_index=anchor_index, + label_index=label_index, + num_anchors=num_anchors, + ) + + +def edge_list_set_labels( + node_local_to_global_by_type: dict[NodeType, torch.Tensor], + positive_labels_by_edge_type: dict[EdgeType, torch.Tensor], + negative_labels_by_edge_type: dict[EdgeType, torch.Tensor], + supervision_edge_types: list[EdgeType], + to_device: torch.device, +) -> tuple[dict[EdgeType, AnchorLabels], dict[EdgeType, AnchorLabels]]: + """Dense edge-list label remap from global label ids to local node indices. + + Drop-in alternative to :func:`vectorized_set_labels` that emits, per label + edge type, an :class:`AnchorLabels` dense edge-list instead of a ragged + ``dict[int, torch.Tensor]``. Membership semantics are identical; expanding + each result via :meth:`AnchorLabels.to_dict` reproduces + :func:`_loop_set_labels` bit-for-bit. + + Precondition (REQUIRED): each ``node`` local->global map must contain UNIQUE + global ids -- see :func:`vectorized_set_labels` for the rationale. + + Args: + node_local_to_global_by_type (dict[NodeType, torch.Tensor]): Per node + type, a ``[N]`` tensor whose ``i``-th entry is the global id of + local node ``i``. Global ids MUST be unique within each map. + positive_labels_by_edge_type (dict[EdgeType, torch.Tensor]): Per + positive-label edge type, a ``[N_anchors, M]`` ``-1``-padded tensor + of global label ids. + negative_labels_by_edge_type (dict[EdgeType, torch.Tensor]): As above, + for negative-label edge types. May be empty. + supervision_edge_types (list[EdgeType]): Accepted for signature parity + with the other kernels; unused here. + to_device (torch.device): Device for every output tensor. + + Returns: + Tuple ``(y_positive, y_negative)``, each a + ``dict[message_passing_edge_type, AnchorLabels]`` with NO entry for a + zero-anchor label tensor (matching the loop's defaultdict). + """ + del supervision_edge_types # Accepted for signature parity; not needed here. + edge_index = 2 # Supervision edge types are (anchor, to, supervision). + sorted_cache: dict[NodeType, tuple[torch.Tensor, torch.Tensor]] = {} + + def _sorted_for(node_type: NodeType) -> tuple[torch.Tensor, torch.Tensor]: + if node_type not in sorted_cache: + sorted_cache[node_type] = torch.sort( + node_local_to_global_by_type[node_type] + ) + return sorted_cache[node_type] + + output_positive_labels: dict[EdgeType, AnchorLabels] = {} + for edge_type, label_tensor in positive_labels_by_edge_type.items(): + if label_tensor.size(0) == 0: + continue + sorted_node, sort_perm = _sorted_for(edge_type[edge_index]) + output_positive_labels[ + label_edge_type_to_message_passing_edge_type(edge_type) + ] = _remap_one_label_tensor_edge_list( + label_tensor, sorted_node, sort_perm, to_device + ) + + output_negative_labels: dict[EdgeType, AnchorLabels] = {} + for edge_type, label_tensor in negative_labels_by_edge_type.items(): + if label_tensor.size(0) == 0: + continue + sorted_node, sort_perm = _sorted_for(edge_type[edge_index]) + output_negative_labels[ + label_edge_type_to_message_passing_edge_type(edge_type) + ] = _remap_one_label_tensor_edge_list( + label_tensor, sorted_node, sort_perm, to_device + ) + + return output_positive_labels, output_negative_labels + + class DistABLPLoader(BaseDistLoader): # Counts instantiations of this class, per process. # This is needed so we can generate unique worker key for each instance, for graph store mode. diff --git a/tests/unit/distributed/edge_list_set_labels_test.py b/tests/unit/distributed/edge_list_set_labels_test.py new file mode 100644 index 000000000..578838505 --- /dev/null +++ b/tests/unit/distributed/edge_list_set_labels_test.py @@ -0,0 +1,252 @@ +"""Unit tests for the dense edge-list ABLP label container and kernel. + +These exercise the pure-tensor label-remap logic directly (no GLT, no +distributed runtime), so they run in-process without ``mp.spawn``. +""" + +import unittest + +import torch +from parameterized import param, parameterized +from torch_geometric.typing import EdgeType as PyGEdgeType + +from gigl.distributed.dist_ablp_neighborloader import ( + AnchorLabels, + _loop_set_labels, + _remap_one_label_tensor_edge_list, + edge_list_set_labels, +) +from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation +from gigl.types.graph import ( + message_passing_to_negative_label, + message_passing_to_positive_label, +) +from tests.test_assets.test_case import TestCase + +_CPU = torch.device("cpu") +_USER = NodeType("user") +_STORY = NodeType("story") +_USER_TO_STORY = EdgeType(_USER, Relation("to"), _STORY) +_A = NodeType("a") +_B = NodeType("b") +_C = NodeType("c") +_A_TO_B = EdgeType(_A, Relation("to"), _B) +_A_TO_C = EdgeType(_A, Relation("to"), _C) + + +def _pos(edge_type: EdgeType) -> EdgeType: + return message_passing_to_positive_label(edge_type) + + +def _neg(edge_type: EdgeType) -> EdgeType: + return message_passing_to_negative_label(edge_type) + + +def _dict_from_edge_list( + edge_list_by_type: dict[PyGEdgeType, AnchorLabels], +) -> dict[PyGEdgeType, dict[int, torch.Tensor]]: + return {et: labels.to_dict() for et, labels in edge_list_by_type.items()} + + +def _assert_label_dicts_equal( + actual: dict[PyGEdgeType, dict[int, torch.Tensor]], + expected: dict[PyGEdgeType, dict[int, torch.Tensor]], +) -> None: + assert set(actual.keys()) == set(expected.keys()), ( + f"{set(actual.keys())} != {set(expected.keys())}" + ) + for edge_type, inner in expected.items(): + actual_inner = actual[edge_type] + assert set(actual_inner.keys()) == set(inner.keys()) + for anchor, expected_tensor in inner.items(): + got = actual_inner[anchor] + assert got.dtype == torch.long + torch.testing.assert_close(got, expected_tensor) + + +class AnchorLabelsTest(TestCase): + def test_to_dict_round_trips_empty_and_multi(self) -> None: + # 3 anchors: anchor 0 -> [5], anchor 1 -> [] (empty), anchor 2 -> [7, 8]. + labels = AnchorLabels( + anchor_index=torch.tensor([0, 2, 2], dtype=torch.long), + label_index=torch.tensor([5, 7, 8], dtype=torch.long), + num_anchors=3, + ) + as_dict = labels.to_dict() + self.assertEqual(set(as_dict.keys()), {0, 1, 2}) + torch.testing.assert_close(as_dict[0], torch.tensor([5], dtype=torch.long)) + torch.testing.assert_close(as_dict[1], torch.empty(0, dtype=torch.long)) + torch.testing.assert_close(as_dict[2], torch.tensor([7, 8], dtype=torch.long)) + + +class RemapOneEdgeListTest(TestCase): + def test_matches_nonzero_order_with_padding_and_empty(self) -> None: + node = torch.tensor([10, 11, 12, 13, 14, 15, 16, 17]) + sorted_node, sort_perm = torch.sort(node) + # anchor 0: [15, -1] -> local 5 ; anchor 1: [15, 16] -> local 5,6 ; + # anchor 2: [-1, -1] -> empty ; anchor 3: [99, -1] -> empty (99 absent). + label_tensor = torch.tensor( + [[15, -1], [15, 16], [-1, -1], [99, -1]], dtype=torch.long + ) + result = _remap_one_label_tensor_edge_list( + label_tensor, sorted_node, sort_perm, _CPU + ) + self.assertEqual(result.num_anchors, 4) + torch.testing.assert_close( + result.anchor_index, torch.tensor([0, 1, 1], dtype=torch.long) + ) + torch.testing.assert_close( + result.label_index, torch.tensor([5, 5, 6], dtype=torch.long) + ) + + def test_unsorted_node_map_nontrivial_sort_perm(self) -> None: + # node map is UNSORTED, so torch.sort yields a non-identity sort_perm. + # node[i]=global id of local i: g15->local0, g10->local1, g16->local2, + # g11->local3. The label row is high-id-first ([16, 15]); the edge-list + # must emit ascending LOCAL index (g15=local0 before g16=local2), proving + # the result is mapped through sort_perm and is not in column or sorted + # order. A port that dropped the sort_perm gather would emit [0, 2] mapped + # to the wrong locals (or sorted-position indices 1 and 3) and fail here. + node = torch.tensor([15, 10, 16, 11]) + sorted_node, sort_perm = torch.sort(node) + label_tensor = torch.tensor([[16, 15]], dtype=torch.long) + result = _remap_one_label_tensor_edge_list( + label_tensor, sorted_node, sort_perm, _CPU + ) + self.assertEqual(result.num_anchors, 1) + torch.testing.assert_close( + result.anchor_index, torch.tensor([0, 0], dtype=torch.long) + ) + torch.testing.assert_close( + result.label_index, torch.tensor([0, 2], dtype=torch.long) + ) + + +class EdgeListSetLabelsEquivalenceTest(TestCase): + @parameterized.expand( + [ + param( + "homogeneous_present_empty_and_padded", + node_map={_STORY: torch.tensor([10, 11, 12, 13, 14, 15, 16, 17])}, + positives={ + _pos(_USER_TO_STORY): torch.tensor( + [[15, -1], [15, 16], [-1, -1], [99, -1]], dtype=torch.long + ) + }, + negatives={}, + supervision_edge_types=[_USER_TO_STORY], + ), + param( + "homogeneous_duplicate_labels", + node_map={_STORY: torch.tensor([10, 11, 12, 13, 14, 15])}, + positives={ + _pos(_USER_TO_STORY): torch.tensor( + [[15, 15], [11, 11]], dtype=torch.long + ) + }, + negatives={}, + supervision_edge_types=[_USER_TO_STORY], + ), + # UNSORTED node map + reversed label columns -- non-identity sort_perm + # (see the vectorized test for the rationale). Guards against a port + # that emits sorted/column order instead of ascending local index. + param( + "unsorted_node_map_reversed_columns", + node_map={_STORY: torch.tensor([15, 10, 16, 11])}, + positives={ + _pos(_USER_TO_STORY): torch.tensor([[16, 15]], dtype=torch.long) + }, + negatives={}, + supervision_edge_types=[_USER_TO_STORY], + ), + param( + "homogeneous_with_negatives", + node_map={_STORY: torch.tensor([10, 11, 12, 13, 14, 15, 16, 17])}, + positives={ + _pos(_USER_TO_STORY): torch.tensor([[15], [16]], dtype=torch.long) + }, + negatives={ + _neg(_USER_TO_STORY): torch.tensor( + [[13, 16], [17, -1]], dtype=torch.long + ) + }, + supervision_edge_types=[_USER_TO_STORY], + ), + param( + "heterogeneous_multi_edge_type", + node_map={ + _A: torch.tensor([10]), + _B: torch.tensor([11, 12, 13, 14, 20, 21]), + _C: torch.tensor([20, 21, 22, 23]), + }, + positives={ + _pos(_A_TO_B): torch.tensor([[13, 14]], dtype=torch.long), + _pos(_A_TO_C): torch.tensor([[22, 23]], dtype=torch.long), + }, + negatives={}, + supervision_edge_types=[_A_TO_B, _A_TO_C], + ), + param( + "all_anchors_empty", + node_map={_STORY: torch.tensor([10, 11, 12])}, + positives={ + _pos(_USER_TO_STORY): torch.tensor( + [[-1, -1], [99, 98]], dtype=torch.long + ) + }, + negatives={}, + supervision_edge_types=[_USER_TO_STORY], + ), + param( + "zero_anchors", + node_map={_STORY: torch.tensor([10, 11, 12])}, + positives={_pos(_USER_TO_STORY): torch.empty((0, 0), dtype=torch.long)}, + negatives={}, + supervision_edge_types=[_USER_TO_STORY], + ), + ] + ) + def test_edge_list_to_dict_matches_loop( + self, + _, + node_map: dict[NodeType, torch.Tensor], + positives: dict[EdgeType, torch.Tensor], + negatives: dict[EdgeType, torch.Tensor], + supervision_edge_types: list[EdgeType], + ) -> None: + loop_pos, loop_neg = _loop_set_labels( + node_local_to_global_by_type=node_map, + positive_labels_by_edge_type=positives, + negative_labels_by_edge_type=negatives, + supervision_edge_types=supervision_edge_types, + to_device=_CPU, + ) + el_pos, el_neg = edge_list_set_labels( + node_local_to_global_by_type=node_map, + positive_labels_by_edge_type=positives, + negative_labels_by_edge_type=negatives, + supervision_edge_types=supervision_edge_types, + to_device=_CPU, + ) + _assert_label_dicts_equal(_dict_from_edge_list(el_pos), loop_pos) + _assert_label_dicts_equal(_dict_from_edge_list(el_neg), loop_neg) + + def test_duplicate_node_map_raises_assertion(self) -> None: + # NOTE: the uniqueness check is gated on `__debug__`, so this is a no-op + # under `python -O` / `PYTHONOPTIMIZE`. GiGL node maps are unique by + # construction; the guard catches misuse only. The test asserts it fires + # under the default (non-optimized) interpreter used by the test suite. + node_map = {_STORY: torch.tensor([10, 10, 11])} + positives = {_pos(_USER_TO_STORY): torch.tensor([[10, 11]], dtype=torch.long)} + with self.assertRaises(AssertionError): + edge_list_set_labels( + node_local_to_global_by_type=node_map, + positive_labels_by_edge_type=positives, + negative_labels_by_edge_type={}, + supervision_edge_types=[_USER_TO_STORY], + to_device=_CPU, + ) + + +if __name__ == "__main__": + unittest.main() From 310150609829809e6acaee40b5cfe3b3f68b4bc1 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 25 Jun 2026 04:21:45 +0000 Subject: [PATCH 04/18] test(ablp): add CUDA device-placement regression for label-remap kernels Co-Authored-By: Claude Opus 4.8 (1M context) --- .../label_remap_cuda_device_test.py | 115 ++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 tests/unit/distributed/label_remap_cuda_device_test.py diff --git a/tests/unit/distributed/label_remap_cuda_device_test.py b/tests/unit/distributed/label_remap_cuda_device_test.py new file mode 100644 index 000000000..396243d7e --- /dev/null +++ b/tests/unit/distributed/label_remap_cuda_device_test.py @@ -0,0 +1,115 @@ +"""CUDA device-placement regression test for the ABLP label-remap kernels. + +``vectorized_set_labels`` and ``edge_list_set_labels`` build an internal +``anchor_of_entry`` index and then select it with a mask derived from the input +``label_tensor``. If that index is created on CPU while ``label_tensor`` is on +GPU, the masked select raises ``"indices should be either on cpu or on the same +device as the indexed tensor"``. CPU-only unit tests cannot observe this, so the +bug only surfaces on a real GPU training run. + +These tests run the kernels with all inputs on CUDA and assert the result equals +the CPU result. They are skipped when no GPU is present (e.g. CPU CI); run them +on a CUDA host to guard the device placement. +""" + +import unittest + +import torch + +from gigl.distributed.dist_ablp_neighborloader import ( + edge_list_set_labels, + vectorized_set_labels, +) +from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation +from gigl.types.graph import message_passing_to_positive_label +from tests.test_assets.test_case import TestCase + +_USER = NodeType("user") +_STORY = NodeType("story") +_USER_TO_STORY = EdgeType(_USER, Relation("to"), _STORY) + + +def _inputs(device: torch.device): + """A small case exercising every on-device code path. + + The node map is UNSORTED (so ``torch.sort`` yields a non-identity + ``sort_perm`` -- the gather through it must run on-device), and anchor 0 has + a DUPLICATE label column ([15, 15]) so the stable-argsort tie-break over the + composite key runs on-device. Local layout: g15->0, g10->1, g16->2, g11->3, + g12->4. Anchor rows: [15, 15] -> local 0 twice; [16, -1] -> local 2; + [-1, -1] -> empty. + """ + node_map = {_STORY: torch.tensor([15, 10, 16, 11, 12], device=device)} + positives = { + message_passing_to_positive_label(_USER_TO_STORY): torch.tensor( + [[15, 15], [16, -1], [-1, -1]], dtype=torch.long, device=device + ) + } + return node_map, positives + + +@unittest.skipUnless(torch.cuda.is_available(), "requires a CUDA device") +class LabelRemapCudaDeviceTest(TestCase): + def test_vectorized_set_labels_cuda_matches_cpu(self) -> None: + cpu_node, cpu_pos = _inputs(torch.device("cpu")) + expected_pos, _ = vectorized_set_labels( + node_local_to_global_by_type=cpu_node, + positive_labels_by_edge_type=cpu_pos, + negative_labels_by_edge_type={}, + supervision_edge_types=[_USER_TO_STORY], + to_device=torch.device("cpu"), + ) + + cuda = torch.device("cuda") + cuda_node, cuda_pos = _inputs(cuda) + # Must not raise the CPU/GPU index mismatch. + got_pos, _ = vectorized_set_labels( + node_local_to_global_by_type=cuda_node, + positive_labels_by_edge_type=cuda_pos, + negative_labels_by_edge_type={}, + supervision_edge_types=[_USER_TO_STORY], + to_device=cuda, + ) + + self.assertEqual(set(got_pos.keys()), set(expected_pos.keys())) + for edge_type, inner in expected_pos.items(): + got_inner = got_pos[edge_type] + self.assertEqual(set(got_inner.keys()), set(inner.keys())) + for anchor, expected_tensor in inner.items(): + got = got_inner[anchor] + self.assertEqual(got.device.type, "cuda") + torch.testing.assert_close(got.cpu(), expected_tensor) + + def test_edge_list_set_labels_cuda_matches_cpu(self) -> None: + cpu_node, cpu_pos = _inputs(torch.device("cpu")) + expected_pos, _ = edge_list_set_labels( + node_local_to_global_by_type=cpu_node, + positive_labels_by_edge_type=cpu_pos, + negative_labels_by_edge_type={}, + supervision_edge_types=[_USER_TO_STORY], + to_device=torch.device("cpu"), + ) + + cuda = torch.device("cuda") + cuda_node, cuda_pos = _inputs(cuda) + got_pos, _ = edge_list_set_labels( + node_local_to_global_by_type=cuda_node, + positive_labels_by_edge_type=cuda_pos, + negative_labels_by_edge_type={}, + supervision_edge_types=[_USER_TO_STORY], + to_device=cuda, + ) + + self.assertEqual(set(got_pos.keys()), set(expected_pos.keys())) + for edge_type, expected_labels in expected_pos.items(): + got_labels = got_pos[edge_type] + self.assertEqual(got_labels.anchor_index.device.type, "cuda") + expected_dict = expected_labels.to_dict() + got_dict = got_labels.to_dict() + self.assertEqual(set(got_dict.keys()), set(expected_dict.keys())) + for anchor, expected_tensor in expected_dict.items(): + torch.testing.assert_close(got_dict[anchor].cpu(), expected_tensor) + + +if __name__ == "__main__": + unittest.main() From d51bd7a01c9b70a2a72eb9167d56b93611874ec6 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 25 Jun 2026 04:47:18 +0000 Subject: [PATCH 05/18] feat(ablp): vectorized label remap always + use_list_output ctor flag Co-Authored-By: Claude Opus 4.8 (1M context) --- gigl/distributed/dist_ablp_neighborloader.py | 18 +- .../dist_ablp_neighborloader_test.py | 195 +++++++++++++++++- 2 files changed, 211 insertions(+), 2 deletions(-) diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index cc76f46e8..c83efbc40 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -523,6 +523,7 @@ def __init__( local_process_rank: Optional[int] = None, # TODO: (svij) Deprecate this local_process_world_size: Optional[int] = None, # TODO: (svij) Deprecate this non_blocking_transfers: bool = True, + use_list_output: bool = False, ): """ Neighbor loader for Anchor Based Link Prediction (ABLP) tasks. @@ -646,11 +647,19 @@ def __init__( is used instead. See https://docs.pytorch.org/tutorials/intermediate/pinmem_nonblock.html for background on pinned memory and non-blocking transfers. + use_list_output (bool): If False (default), ``y_positive`` and + ``y_negative`` are the legacy ragged ``dict[int, torch.Tensor]`` + per-anchor dicts. If True, they are :class:`AnchorLabels` dense + edge-list objects (single supervision edge type) or + ``dict[EdgeType, AnchorLabels]`` (multiple supervision edge + types), which expose ``anchor_index``, ``label_index``, and + ``num_anchors`` directly and provide a ``.to_dict()`` conversion. """ # Set self._shutdowned right away, that way if we throw here, and __del__ is called, # then we can properly clean up and don't get extraneous error messages. self._shutdowned = True + self._use_list_output = use_list_output sampler_options = resolve_sampler_options(num_neighbors, sampler_options) @@ -1208,7 +1217,14 @@ def _set_labels( node_type_to_local_node_to_global_node[DEFAULT_HOMOGENEOUS_NODE_TYPE] = ( data.node ) - output_positive_labels, output_negative_labels = _loop_set_labels( + # Vectorized remap is the production path: bit-for-bit equivalent to the + # _loop_set_labels oracle, GPU-safe, and faster than the per-anchor loop + # (the gap grows with batch size). When use_list_output is set, emit the + # dense AnchorLabels edge-list instead of the ragged per-anchor dict. + label_remap = ( + edge_list_set_labels if self._use_list_output else vectorized_set_labels + ) + output_positive_labels, output_negative_labels = label_remap( node_local_to_global_by_type=node_type_to_local_node_to_global_node, positive_labels_by_edge_type=positive_labels_by_label_edge_type, negative_labels_by_edge_type=negative_labels_by_label_edge_type, diff --git a/tests/unit/distributed/dist_ablp_neighborloader_test.py b/tests/unit/distributed/dist_ablp_neighborloader_test.py index 31d3d1cbc..6ef3be5b6 100644 --- a/tests/unit/distributed/dist_ablp_neighborloader_test.py +++ b/tests/unit/distributed/dist_ablp_neighborloader_test.py @@ -11,7 +11,7 @@ from torch_geometric.data import Data, HeteroData from gigl.distributed.dataset_factory import build_dataset -from gigl.distributed.dist_ablp_neighborloader import DistABLPLoader +from gigl.distributed.dist_ablp_neighborloader import AnchorLabels, DistABLPLoader from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.dist_partitioner import DistPartitioner from gigl.distributed.dist_range_partitioner import DistRangePartitioner @@ -416,6 +416,103 @@ def _run_distributed_ablp_neighbor_loader_multiple_supervision_edge_types( shutdown_rpc() +def _ordered_global_pairs( + node: torch.Tensor, label_dict: dict[int, torch.Tensor] +) -> list[tuple[int, int]]: + """Flatten a per-anchor label dict to an ORDERED (global_anchor, global_label) + pair stream. + + Iterates anchors in ascending key order and labels in their stored order + WITHOUT sorting, so a change in pair order (e.g. a kernel that emitted + columns or sorted-position order) changes the returned list. This is the + silent-mistraining mode a per-anchor sorted-set comparison would miss. + """ + pairs: list[tuple[int, int]] = [] + for local_anchor in sorted(label_dict.keys()): + global_anchor = int(node[local_anchor].item()) + for local_label in label_dict[local_anchor].tolist(): + pairs.append((global_anchor, int(node[local_label].item()))) + return pairs + + +def _collect_homogeneous_labels( + _, + return_dict, + use_list_output: bool, + dataset: DistDataset, + input_nodes: torch.Tensor, + batch_size: int, + has_negatives: bool, +): + """Child-side: run the loader, return the ORDERED global-id pair streams. + + Local node indices differ run-to-run, so labels are translated back to + global ids via ``datum.node``. The streams preserve pair ORDER (see + ``_ordered_global_pairs``) so dict-vs-edge-list equality in the parent + catches an order regression, not just a set regression. + + When ``use_list_output`` is True the labels arrive as :class:`AnchorLabels`. + This branch ALSO asserts in-process that the exact tensors the example + training loss reads from the edge-list match the legacy dict read: the + edge-list ``label_index`` must equal the dict's ``torch.cat(values())`` and + ``query_node_idx[anchor_index]`` must equal the legacy + ``repeat_interleave`` over per-anchor lengths. A drift here is exactly the + example-training bug we are guarding against. + """ + create_test_process_group() + loader = DistABLPLoader( + dataset=dataset, + num_neighbors=[2, 2], + input_nodes=input_nodes, + batch_size=batch_size, + pin_memory_device=torch.device("cpu"), + use_list_output=use_list_output, + ) + positive_pairs: list[tuple[int, int]] = [] + negative_pairs: list[tuple[int, int]] = [] + for datum in loader: + assert isinstance(datum, Data) + node = datum.node + if use_list_output: + assert isinstance(datum.y_positive, AnchorLabels), ( + f"expected AnchorLabels, got {type(datum.y_positive)}" + ) + positive_dict = datum.y_positive.to_dict() + # Direct check that the example-training read matches the legacy read, + # within this single batch, EXACTLY (order included). + query_node_idx = torch.arange(datum.batch_size) + legacy_positive_idx = torch.cat( + [positive_dict[a] for a in range(datum.batch_size)] + ) + legacy_repeated_query = query_node_idx.repeat_interleave( + torch.tensor( + [len(positive_dict[a]) for a in range(datum.batch_size)] + ) + ) + torch.testing.assert_close( + datum.y_positive.label_index, legacy_positive_idx + ) + torch.testing.assert_close( + query_node_idx[datum.y_positive.anchor_index], legacy_repeated_query + ) + else: + positive_dict = datum.y_positive + positive_pairs.extend(_ordered_global_pairs(node, positive_dict)) + if has_negatives: + if use_list_output: + assert isinstance(datum.y_negative, AnchorLabels) + negative_dict = datum.y_negative.to_dict() + else: + negative_dict = datum.y_negative + negative_pairs.extend(_ordered_global_pairs(node, negative_dict)) + else: + assert not hasattr(datum, "y_negative"), ( + f"expected no negatives, got {getattr(datum, 'y_negative', None)}" + ) + return_dict[use_list_output] = (positive_pairs, negative_pairs) + shutdown_rpc() + + class DistABLPLoaderTest(TestCase): def tearDown(self): if torch.distributed.is_initialized(): @@ -556,6 +653,102 @@ def test_ablp_dataloader( ), ) + @parameterized.expand( + [ + param( + "positive and negative", + labeled_edges={ + _POSITIVE_EDGE_TYPE: torch.tensor([[10, 15], [15, 16]]), + _NEGATIVE_EDGE_TYPE: torch.tensor( + [[10, 10, 11, 15], [13, 16, 14, 17]] + ), + }, + input_nodes=torch.tensor([10, 15]), + batch_size=2, + has_negatives=True, + ), + param( + "positive only", + labeled_edges={_POSITIVE_EDGE_TYPE: torch.tensor([[10, 15], [15, 16]])}, + input_nodes=torch.tensor([10, 15]), + batch_size=2, + has_negatives=False, + ), + # Anchor 11 has message-passing edges (11 -> {13, 17}) but is the + # source of NO positive-label edge, so its positive-label row is + # all-padding and y_positive[11] is a guaranteed-empty tensor. This + # exercises the empty-anchor branch end-to-end for both outputs. + param( + "guaranteed empty positive anchor", + labeled_edges={ + _POSITIVE_EDGE_TYPE: torch.tensor([[10, 15], [15, 16]]), + _NEGATIVE_EDGE_TYPE: torch.tensor( + [[10, 10, 11, 15], [13, 16, 14, 17]] + ), + }, + input_nodes=torch.tensor([10, 11, 15]), + batch_size=3, + has_negatives=True, + ), + ] + ) + def test_use_list_output_matches_dict_output( + self, _, labeled_edges, input_nodes, batch_size, has_negatives + ): + """``use_list_output=True`` yields AnchorLabels whose ``.to_dict()`` matches + the default dict output as an ORDERED (global_anchor, global_label) pair + stream -- not merely a per-anchor set, so a pair-order regression fails. + + Sampling is deterministic here (``shuffle`` defaults to False and the + input is fixed), so the two loader runs emit batches in the same order + and the streams are directly comparable. The child process additionally + asserts the exact tensors the example-training loss reads from the + edge-list equal the legacy dict read (see ``_collect_homogeneous_labels``). + """ + edge_index = { + DEFAULT_HOMOGENEOUS_EDGE_TYPE: torch.tensor( + [[10, 10, 11, 11, 15, 15, 16, 16], [11, 12, 13, 17, 13, 14, 12, 14]] + ), + } + edge_index.update(labeled_edges) + partition_output = PartitionOutput( + node_partition_book=to_heterogeneous_node(torch.zeros(18)), + edge_partition_book={ + e_type: torch.zeros(int(e_idx.max().item() + 1)) + for e_type, e_idx in edge_index.items() + }, + partitioned_edge_index={ + etype: GraphPartitionData( + edge_index=idx, edge_ids=torch.arange(idx.size(1)) + ) + for etype, idx in edge_index.items() + }, + partitioned_edge_features=None, + partitioned_node_features=None, + partitioned_negative_labels=None, + partitioned_positive_labels=None, + partitioned_node_labels=None, + ) + dataset = DistDataset(rank=0, world_size=1, edge_dir="out") + dataset.build(partition_output=partition_output) + + manager = mp.Manager() + return_dict = manager.dict() + for use_list_output in (False, True): + mp.spawn( + fn=_collect_homogeneous_labels, + args=( + return_dict, + use_list_output, + dataset, + input_nodes, + batch_size, + has_negatives, + ), + ) + self.assertEqual(return_dict[False][0], return_dict[True][0]) + self.assertEqual(return_dict[False][1], return_dict[True][1]) + def test_cora_supervised(self): create_test_process_group() cora_supervised_info = get_mocked_dataset_artifact_metadata()[ From efa446d45b26b743fc09cb059b89441a2c1f5d28 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 25 Jun 2026 04:53:05 +0000 Subject: [PATCH 06/18] docs(ablp): document vectorized remap, use_list_output, and AnchorLabels Co-Authored-By: Claude Opus 4.8 (1M context) --- gigl/distributed/dist_ablp_neighborloader.py | 31 +++++++++++++++----- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index c83efbc40..c2f17aea2 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -572,6 +572,20 @@ def __init__( - `y_positive`: {(a, to, b): {0: torch.tensor([1])}, (a, to, c): {0: torch.tensor([2])}} - `y_negative`: {(a, to, b): {0: torch.tensor([3])}, (a, to, c): {0: torch.tensor([4])}} + Label remapping (the conversion of global label ids to the local + indices stored in `y_positive`/`y_negative`) is vectorized internally; + the output is identical to the historical per-anchor implementation. + + When `use_list_output=True`, `y_positive` and `y_negative` are instead a + dense edge-list `AnchorLabels` (single supervision edge type) or + `dict[EdgeType, AnchorLabels]` (multiple). An `AnchorLabels` holds + `anchor_index` ([E] long), `label_index` ([E] long), and `num_anchors` + (int): pair `k` means local anchor row `anchor_index[k]` has local label + node `label_index[k]`. Pairs are ordered ascending by anchor (ties by + ascending label column). `AnchorLabels.to_dict()` reproduces the ragged + `dict[int, torch.Tensor]` form above. With `use_list_output=False` + (default) the output is the ragged dict, fully backward-compatible. + Args: dataset (Union[DistDataset, RemoteDistDataset]): The dataset to sample from. If this is a `RemoteDistDataset`, then we are in "Graph Store" mode. @@ -647,13 +661,16 @@ def __init__( is used instead. See https://docs.pytorch.org/tutorials/intermediate/pinmem_nonblock.html for background on pinned memory and non-blocking transfers. - use_list_output (bool): If False (default), ``y_positive`` and - ``y_negative`` are the legacy ragged ``dict[int, torch.Tensor]`` - per-anchor dicts. If True, they are :class:`AnchorLabels` dense - edge-list objects (single supervision edge type) or - ``dict[EdgeType, AnchorLabels]`` (multiple supervision edge - types), which expose ``anchor_index``, ``label_index``, and - ``num_anchors`` directly and provide a ``.to_dict()`` conversion. + use_list_output (bool): If True, return labels as a dense + ``AnchorLabels`` edge-list (or ``dict[EdgeType, AnchorLabels]`` + for multiple supervision edge types) instead of the ragged + ``dict[anchor_local_index, torch.Tensor]``. The edge-list form + lets the loss read labels without a per-anchor Python loop + (``y.label_index`` and ``query_idx[y.anchor_index]`` instead of + ``torch.cat(list(y.values()))`` and a ``repeat_interleave`` over + per-anchor lengths). ``AnchorLabels.to_dict()`` recovers the + ragged form. Defaults to ``False`` (ragged dict; fully + backward-compatible). """ # Set self._shutdowned right away, that way if we throw here, and __del__ is called, From 8c91d6640ac5ee1f9d9f3b55f4eff92435e965b1 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 25 Jun 2026 04:56:21 +0000 Subject: [PATCH 07/18] docs(examples): consume AnchorLabels edge-list in homogeneous training Co-Authored-By: Claude Opus 4.8 (1M context) --- .../link_prediction/homogeneous_training.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/link_prediction/homogeneous_training.py b/examples/link_prediction/homogeneous_training.py index b95a77489..dc17a119e 100644 --- a/examples/link_prediction/homogeneous_training.py +++ b/examples/link_prediction/homogeneous_training.py @@ -134,6 +134,9 @@ def _setup_dataloaders( # This is done so that each process on the current machine which initializes a `main_loader` doesn't compete for memory, causing potential OOM process_start_gap_seconds=process_start_gap_seconds, shuffle=shuffle, + # Return labels as a dense AnchorLabels edge-list so the loss reads + # anchor/label indices directly without a per-anchor Python loop. + use_list_output=True, ) logger.info(f"---Rank {rank} finished setting up main loader") @@ -190,18 +193,15 @@ def _compute_loss( query_node_idx: torch.Tensor = torch.arange(main_data.batch_size).to(device) random_negative_batch_size = random_negative_data.batch_size - # main_data.y_positive is a dict[query_node_local_index: int, labeled_node_local_indices: torch.Tensor] - positive_idx: torch.Tensor = torch.cat(list(main_data.y_positive.values())).to( - device - ) - # We also extract a repeated query node index tensor which upsamples each query node based on the number of positives it has - repeated_query_node_idx = query_node_idx.repeat_interleave( - torch.tensor([len(v) for v in main_data.y_positive.values()]).to(device) - ) + # main_data.y_positive is an AnchorLabels edge-list (use_list_output=True). + # label_index holds the local label node per (anchor, label) pair; anchor_index + # holds the matching local anchor row. Pairs are ordered ascending by anchor, + # so this is equivalent to the historical dict read + # (torch.cat(list(values())) + repeat_interleave over per-anchor lengths). + positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) + repeated_query_node_idx = query_node_idx[main_data.y_positive.anchor_index.to(device)] if hasattr(main_data, "y_negative"): - hard_negative_idx: torch.Tensor = torch.cat( - list(main_data.y_negative.values()) - ).to(device) + hard_negative_idx: torch.Tensor = main_data.y_negative.label_index.to(device) else: hard_negative_idx = torch.empty(0, dtype=torch.long).to(device) From 1f1c06916a43a21a90ab00b62ba400f1d14f856c Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 25 Jun 2026 04:56:43 +0000 Subject: [PATCH 08/18] docs(examples): consume AnchorLabels edge-list in heterogeneous training Co-Authored-By: Claude Opus 4.8 (1M context) --- .../link_prediction/heterogeneous_training.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/link_prediction/heterogeneous_training.py b/examples/link_prediction/heterogeneous_training.py index 8ed672b7c..175cc5eed 100644 --- a/examples/link_prediction/heterogeneous_training.py +++ b/examples/link_prediction/heterogeneous_training.py @@ -144,6 +144,9 @@ def _setup_dataloaders( # This is done so that each process on the current machine which initializes a `main_loader` doesn't compete for memory, causing potential OOM process_start_gap_seconds=process_start_gap_seconds, shuffle=shuffle, + # Return labels as a dense AnchorLabels edge-list so the loss reads + # anchor/label indices directly without a per-anchor Python loop. + use_list_output=True, ) logger.info(f"---Rank {rank} finished setting up main loader") @@ -223,18 +226,15 @@ def _compute_loss( ).to(device) random_negative_batch_size = random_negative_data[labeled_node_type].batch_size - # main_data.y_positive is a dict[query_node_local_index: int, labeled_node_local_indices: torch.Tensor], even in the heterogeneous setting. - positive_idx: torch.Tensor = torch.cat(list(main_data.y_positive.values())).to( - device - ) - # We also extract a repeated query node index tensor which upsamples each query node based on the number of positives it has - repeated_query_node_idx = query_node_idx.repeat_interleave( - torch.tensor([len(v) for v in main_data.y_positive.values()]).to(device) - ) + # main_data.y_positive is an AnchorLabels edge-list (use_list_output=True), + # even in the heterogeneous setting (single supervision edge type per loss + # call). label_index holds the local label node per (anchor, label) pair; + # anchor_index holds the matching local anchor row. Pairs are ordered + # ascending by anchor, so this is equivalent to the historical dict read. + positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) + repeated_query_node_idx = query_node_idx[main_data.y_positive.anchor_index.to(device)] if hasattr(main_data, "y_negative"): - hard_negative_idx: torch.Tensor = torch.cat( - list(main_data.y_negative.values()) - ).to(device) + hard_negative_idx: torch.Tensor = main_data.y_negative.label_index.to(device) else: hard_negative_idx = torch.empty(0, dtype=torch.long).to(device) From f60df78413e13b28a499f9bb9d9f5ba5389f08e4 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 25 Jun 2026 04:57:01 +0000 Subject: [PATCH 09/18] docs(examples): consume AnchorLabels edge-list in graph-store homogeneous training Co-Authored-By: Claude Opus 4.8 (1M context) --- .../graph_store/homogeneous_training.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py index 2d4c22788..3f54a9bb1 100644 --- a/examples/link_prediction/graph_store/homogeneous_training.py +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -241,6 +241,9 @@ def _setup_dataloaders( channel_size=sampling_worker_shared_channel_size, process_start_gap_seconds=process_start_gap_seconds, shuffle=shuffle, + # Return labels as a dense AnchorLabels edge-list so the loss reads + # anchor/label indices directly without a per-anchor Python loop. + use_list_output=True, ) logger.info(f"---Rank {rank} finished setting up main loader for split={split}") @@ -305,16 +308,13 @@ def _compute_loss( query_node_idx: torch.Tensor = torch.arange(main_data.batch_size).to(device) random_negative_batch_size = random_negative_data.batch_size - positive_idx: torch.Tensor = torch.cat(list(main_data.y_positive.values())).to( - device - ) - repeated_query_node_idx = query_node_idx.repeat_interleave( - torch.tensor([len(v) for v in main_data.y_positive.values()]).to(device) - ) + # main_data.y_positive is an AnchorLabels edge-list (use_list_output=True). + # Pairs are ordered ascending by anchor, so reading label_index/anchor_index + # directly is equivalent to the historical dict read. + positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) + repeated_query_node_idx = query_node_idx[main_data.y_positive.anchor_index.to(device)] if hasattr(main_data, "y_negative"): - hard_negative_idx: torch.Tensor = torch.cat( - list(main_data.y_negative.values()) - ).to(device) + hard_negative_idx: torch.Tensor = main_data.y_negative.label_index.to(device) else: hard_negative_idx = torch.empty(0, dtype=torch.long).to(device) From ee4adeb82fe57c33002fc089a06428e1dd19ab97 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 25 Jun 2026 04:57:26 +0000 Subject: [PATCH 10/18] docs(examples): consume AnchorLabels edge-list in graph-store heterogeneous training Co-Authored-By: Claude Opus 4.8 (1M context) --- .../graph_store/heterogeneous_training.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_training.py b/examples/link_prediction/graph_store/heterogeneous_training.py index 2be34e608..876f4f03e 100644 --- a/examples/link_prediction/graph_store/heterogeneous_training.py +++ b/examples/link_prediction/graph_store/heterogeneous_training.py @@ -211,6 +211,9 @@ def _setup_dataloaders( channel_size=sampling_worker_shared_channel_size, process_start_gap_seconds=process_start_gap_seconds, shuffle=shuffle, + # Return labels as a dense AnchorLabels edge-list so the loss reads + # anchor/label indices directly without a per-anchor Python loop. + use_list_output=True, ) print(f"---Rank {rank} finished setting up main loader for split={split}") @@ -299,16 +302,13 @@ def _compute_loss( ).to(device) random_negative_batch_size = random_negative_data[labeled_node_type].batch_size - positive_idx: torch.Tensor = torch.cat(list(main_data.y_positive.values())).to( - device - ) - repeated_query_node_idx = query_node_idx.repeat_interleave( - torch.tensor([len(v) for v in main_data.y_positive.values()]).to(device) - ) + # main_data.y_positive is an AnchorLabels edge-list (use_list_output=True). + # Pairs are ordered ascending by anchor, so reading label_index/anchor_index + # directly is equivalent to the historical dict read. + positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) + repeated_query_node_idx = query_node_idx[main_data.y_positive.anchor_index.to(device)] if hasattr(main_data, "y_negative"): - hard_negative_idx: torch.Tensor = torch.cat( - list(main_data.y_negative.values()) - ).to(device) + hard_negative_idx: torch.Tensor = main_data.y_negative.label_index.to(device) else: hard_negative_idx = torch.empty(0, dtype=torch.long).to(device) From dab9f8fb61733308cec24f28b9f55e33aac63908 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 25 Jun 2026 05:01:03 +0000 Subject: [PATCH 11/18] feat(distributed): export AnchorLabels from gigl.distributed Co-Authored-By: Claude Opus 4.8 (1M context) --- gigl/distributed/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gigl/distributed/__init__.py b/gigl/distributed/__init__.py index f88b198cb..c5b351daf 100644 --- a/gigl/distributed/__init__.py +++ b/gigl/distributed/__init__.py @@ -3,6 +3,7 @@ """ __all__ = [ + "AnchorLabels", "DistABLPLoader", "DistNeighborLoader", "DistDataset", @@ -17,7 +18,7 @@ build_dataset, build_dataset_from_task_config_uri, ) -from gigl.distributed.dist_ablp_neighborloader import DistABLPLoader +from gigl.distributed.dist_ablp_neighborloader import AnchorLabels, DistABLPLoader from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.dist_partitioner import DistPartitioner From bd97d03b3f403849a9a251732c399422acf1bdb0 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 25 Jun 2026 05:03:58 +0000 Subject: [PATCH 12/18] style(ablp): apply ruff formatting to ABLP label-output examples and test Line-wrapping only (ruff format) for the edge-list label reads in the four link-prediction training examples and the loader equivalence test. No logic change. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../link_prediction/graph_store/heterogeneous_training.py | 4 +++- examples/link_prediction/graph_store/homogeneous_training.py | 4 +++- examples/link_prediction/heterogeneous_training.py | 4 +++- examples/link_prediction/homogeneous_training.py | 4 +++- tests/unit/distributed/dist_ablp_neighborloader_test.py | 4 +--- 5 files changed, 13 insertions(+), 7 deletions(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_training.py b/examples/link_prediction/graph_store/heterogeneous_training.py index 876f4f03e..3670c5aec 100644 --- a/examples/link_prediction/graph_store/heterogeneous_training.py +++ b/examples/link_prediction/graph_store/heterogeneous_training.py @@ -306,7 +306,9 @@ def _compute_loss( # Pairs are ordered ascending by anchor, so reading label_index/anchor_index # directly is equivalent to the historical dict read. positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) - repeated_query_node_idx = query_node_idx[main_data.y_positive.anchor_index.to(device)] + repeated_query_node_idx = query_node_idx[ + main_data.y_positive.anchor_index.to(device) + ] if hasattr(main_data, "y_negative"): hard_negative_idx: torch.Tensor = main_data.y_negative.label_index.to(device) else: diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py index 3f54a9bb1..3f5ef62fd 100644 --- a/examples/link_prediction/graph_store/homogeneous_training.py +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -312,7 +312,9 @@ def _compute_loss( # Pairs are ordered ascending by anchor, so reading label_index/anchor_index # directly is equivalent to the historical dict read. positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) - repeated_query_node_idx = query_node_idx[main_data.y_positive.anchor_index.to(device)] + repeated_query_node_idx = query_node_idx[ + main_data.y_positive.anchor_index.to(device) + ] if hasattr(main_data, "y_negative"): hard_negative_idx: torch.Tensor = main_data.y_negative.label_index.to(device) else: diff --git a/examples/link_prediction/heterogeneous_training.py b/examples/link_prediction/heterogeneous_training.py index 175cc5eed..e3370a506 100644 --- a/examples/link_prediction/heterogeneous_training.py +++ b/examples/link_prediction/heterogeneous_training.py @@ -232,7 +232,9 @@ def _compute_loss( # anchor_index holds the matching local anchor row. Pairs are ordered # ascending by anchor, so this is equivalent to the historical dict read. positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) - repeated_query_node_idx = query_node_idx[main_data.y_positive.anchor_index.to(device)] + repeated_query_node_idx = query_node_idx[ + main_data.y_positive.anchor_index.to(device) + ] if hasattr(main_data, "y_negative"): hard_negative_idx: torch.Tensor = main_data.y_negative.label_index.to(device) else: diff --git a/examples/link_prediction/homogeneous_training.py b/examples/link_prediction/homogeneous_training.py index dc17a119e..2fc8719f3 100644 --- a/examples/link_prediction/homogeneous_training.py +++ b/examples/link_prediction/homogeneous_training.py @@ -199,7 +199,9 @@ def _compute_loss( # so this is equivalent to the historical dict read # (torch.cat(list(values())) + repeat_interleave over per-anchor lengths). positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) - repeated_query_node_idx = query_node_idx[main_data.y_positive.anchor_index.to(device)] + repeated_query_node_idx = query_node_idx[ + main_data.y_positive.anchor_index.to(device) + ] if hasattr(main_data, "y_negative"): hard_negative_idx: torch.Tensor = main_data.y_negative.label_index.to(device) else: diff --git a/tests/unit/distributed/dist_ablp_neighborloader_test.py b/tests/unit/distributed/dist_ablp_neighborloader_test.py index 6ef3be5b6..c4aee9294 100644 --- a/tests/unit/distributed/dist_ablp_neighborloader_test.py +++ b/tests/unit/distributed/dist_ablp_neighborloader_test.py @@ -485,9 +485,7 @@ def _collect_homogeneous_labels( [positive_dict[a] for a in range(datum.batch_size)] ) legacy_repeated_query = query_node_idx.repeat_interleave( - torch.tensor( - [len(positive_dict[a]) for a in range(datum.batch_size)] - ) + torch.tensor([len(positive_dict[a]) for a in range(datum.batch_size)]) ) torch.testing.assert_close( datum.y_positive.label_index, legacy_positive_idx From e6bd5b9eaaaf9b34e87456010a8ae6743813c09b Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 25 Jun 2026 16:18:53 +0000 Subject: [PATCH 13/18] refactor(ablp): drop order-reproduction argsort; add readability refactors **Behavioral change:** Drop the composite-key stable argsort from `_membership_remap` (the `composite_key = anchor_kept * (num_nodes + 1) + local_idx` / `torch.argsort(composite_key, stable=True)` block). The pair stream is now emitted in column-visit (row-major masked flatten) order rather than ascending-local-index (torch.nonzero) order. **Why this is safe:** `RetrievalLoss` (`gigl/nn/loss.py`) is `CrossEntropyLoss(reduction="sum")` over a diagonal-targeted score matrix that masks collisions by id VALUE, not position. It is invariant to any joint permutation of `(anchor_index, label_index)` pairs. The only constraint is co-indexing (pair k stays intact) and per-anchor grouping, both of which are preserved. Within-anchor label order was an implementation artifact of the loop oracle -- never a loss requirement. **Why the dict path is unaffected:** `anchor_of_entry` is built as `arange(N).repeat_interleave(M)` (row-major), so the masked flatten is already non-decreasing in anchor_index. `bincount`/`split` requires only contiguous grouping by anchor, which row-major flatten guarantees without any argsort. **Readability refactors in `_membership_remap`:** - Rename terse tensors: `valid` -> `is_present`, `found` -> `is_exact_match`, `positions` -> `sorted_positions`, `local_idx` -> `local_index`, `anchor_kept` -> `anchor_of_matched` - Add numbered step comments explaining the searchsorted membership lookup **Readability refactors in outer kernels:** - Add `_remap_group` helper to collapse the ~40-line duplicated positive/ negative per-edge-type loops in both `vectorized_set_labels` and `edge_list_set_labels` (uses TypeVar for generic return type) - `_sorted_for` closure pattern preserved, inlined per-kernel (memoizes torch.sort across pos/neg edge types of the same node type) **Other improvements:** - Add doctest to `AnchorLabels` showing a 2-anchor case + `to_dict()` round-trip - Update `AnchorLabels` docstring: document column-visit order and loss-permutation-invariance rationale - Update all docstrings to drop "bit-for-bit"/"torch.nonzero order" language **Test contract relaxation (tests remain non-vacuous):** - `vectorized_set_labels_test.py` / `edge_list_set_labels_test.py`: `_assert_label_dicts_equal` -> `_assert_label_dicts_set_equal` (uses `sorted()` per anchor; still catches membership errors + multiplicity) - `test_unsorted_node_map_exact_order` -> `test_unsorted_node_map_correct_membership`: asserts SET {0, 2} instead of sequence [0, 2]; docstring explains column vs ascending-local order difference - `RemapOneEdgeListTest.test_unsorted_node_map_nontrivial_sort_perm` -> `test_unsorted_node_map_correct_membership`: pins column order [2, 0] and explains it is SET-equal to [0, 2]; verifies sort_perm mapping is correct - `dist_ablp_neighborloader_test.py`: `_ordered_global_pairs` -> `_global_pair_set` (sorts within anchor); `_collect_homogeneous_labels` docstring updated; the in-process exact-tensor assertion (`label_index == cat(dict values)`) is preserved -- both paths draw from the same `_membership_remap` pair stream so they remain identical - CUDA device test docstring: remove "stable-argsort tie-break" reference; keep the duplicate [15,15] row (still tests duplicate handling + device placement) Device fix preserved: `anchor_of_entry` is still built on `label_tensor.device`. Co-Authored-By: Claude Sonnet 4.6 --- gigl/distributed/dist_ablp_neighborloader.py | 279 +++++++++++------- .../dist_ablp_neighborloader_test.py | 66 +++-- .../distributed/edge_list_set_labels_test.py | 53 ++-- .../label_remap_cuda_device_test.py | 10 +- .../distributed/vectorized_set_labels_test.py | 54 +++- 5 files changed, 296 insertions(+), 166 deletions(-) diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index c2f17aea2..e51b4ccf5 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -1,7 +1,7 @@ from collections import abc, defaultdict from dataclasses import dataclass from itertools import count -from typing import Optional, Union +from typing import Optional, TypeVar, Union import torch from graphlearn_torch.channel import SampleMessage @@ -64,15 +64,33 @@ class AnchorLabels: 1-D ``long`` tensors. Pair ``k`` asserts that local anchor row ``anchor_index[k]`` has label node ``label_index[k]``. - Pairs are ordered ascending by anchor, ties broken ascending by label - column -- identical to ``torch.nonzero`` over the ``[N_anchors, M]`` padded - label tensor, so concatenating the legacy dict's values in anchor order - yields ``label_index`` and the matching anchor repeats yield - ``anchor_index``. + Pairs are emitted in ``(anchor, column)`` order from the source + ``[N_anchors, M]`` padded label tensor (row-major masked flatten). Within an + anchor, label order is **unspecified** -- the ABLP contrastive loss + (:class:`gigl.nn.loss.RetrievalLoss`) is permutation-invariant over the pair + stream (``CrossEntropyLoss(reduction="sum")`` over a diagonal-targeted score + matrix with value-based collision masks), so order carries no meaning. + ``anchor_index[k]`` and ``label_index[k]`` remain co-indexed. Anchors with no in-subgraph labels contribute zero pairs; ``num_anchors`` records the full anchor count so empty anchors remain recoverable. + Example:: + + >>> import torch + >>> labels = AnchorLabels( + ... anchor_index=torch.tensor([0, 1, 1], dtype=torch.long), + ... label_index=torch.tensor([3, 5, 7], dtype=torch.long), + ... num_anchors=3, + ... ) + >>> d = labels.to_dict() + >>> d[0].tolist() + [3] + >>> d[1].tolist() + [5, 7] + >>> d[2].tolist() + [] + Args: anchor_index (torch.Tensor): ``[E]`` long tensor of local anchor rows. label_index (torch.Tensor): ``[E]`` long tensor of local label node ids. @@ -181,18 +199,28 @@ def _membership_remap( Remaps one ``[N_anchors, M]`` ``-1``-padded label tensor of global ids to the flat ``(anchor_index, local_index)`` pair stream shared by both the dict and - edge-list builders, in :func:`torch.nonzero` order: ascending local index per - anchor, ties broken by ascending label column. The two callers - (:func:`_remap_one_label_tensor`, :func:`_remap_one_label_tensor_edge_list`) - differ only in how they package this pair stream, so the searchsorted - membership logic lives here once. + edge-list builders. Pairs are emitted in ``(anchor, column)`` order (the + natural row-major order of the masked flatten). Within an anchor, label order + is **unspecified** -- the ABLP contrastive loss is permutation-invariant over + the pair stream, so order carries no meaning; see :class:`AnchorLabels`. + + The two callers (:func:`_remap_one_label_tensor`, + :func:`_remap_one_label_tensor_edge_list`) differ only in how they package this + pair stream, so the searchsorted membership logic lives here once. + + The pair stream is already non-decreasing in ``anchor_index`` because + ``anchor_of_entry`` is built as ``arange(N).repeat_interleave(M)`` (row-major) + and the ``is_present`` mask preserves order within each anchor. This means the + dict builder's ``bincount``/``split`` is correct without any additional sorting: + it only requires contiguous grouping by anchor, which row-major flatten + guarantees. Precondition (REQUIRED for correctness): ``sorted_node`` must have UNIQUE values (the node map is unique local->global). :func:`torch.searchsorted` returns the left-most equal position, so a duplicate global id would collapse - multiple local indices to one and diverge from the loop oracle. GiGL ``node`` - maps guarantee uniqueness; the check is asserted only under ``__debug__`` to - keep the hot path zero-cost (and is a no-op under ``python -O``). + multiple local indices to one. GiGL ``node`` maps guarantee uniqueness; the + check is asserted only under ``__debug__`` to keep the hot path zero-cost (and + is a no-op under ``python -O``). Args: label_tensor (torch.Tensor): ``[N_anchors, M]`` ``-1``-padded global @@ -205,7 +233,7 @@ def _membership_remap( Returns: Tuple ``(anchor_index, local_index, num_anchors)``. ``anchor_index`` and ``local_index`` are equal-length 1-D ``long`` tensors on ``to_device`` in - :func:`torch.nonzero` order (empty when nothing matched); + ``(anchor, column)`` order (empty when nothing matched); ``num_anchors == label_tensor.size(0)``. """ num_anchors = int(label_tensor.size(0)) @@ -216,18 +244,18 @@ def _membership_remap( num_labels = int(label_tensor.size(1)) flat = label_tensor.reshape(-1) - # Create on the label tensor's device: it is indexed below by `valid` - # (derived from `label_tensor`), so on GPU a CPU arange would raise - # "indices should be either on cpu or on the same device as the indexed - # tensor". CPU-only unit tests cannot catch this; see the CUDA-gated test. + # Build on the label tensor's device: `anchor_of_entry` is indexed below by + # `is_present` (derived from `label_tensor`). On GPU, a CPU arange would + # raise "indices should be either on cpu or on the same device as the indexed + # tensor". CPU-only unit tests cannot catch this; see the CUDA-gated test. anchor_of_entry = torch.arange( num_anchors, device=label_tensor.device ).repeat_interleave(num_labels) # Mask the padding sentinel BEFORE any search so we never gather with -1. - valid = flat != PADDING_NODE - flat = flat[valid] - anchor_of_entry = anchor_of_entry[valid] + is_present = flat != PADDING_NODE + flat = flat[is_present] + anchor_of_entry = anchor_of_entry[is_present] if num_nodes == 0 or flat.numel() == 0: return empty, empty, num_anchors @@ -237,21 +265,30 @@ def _membership_remap( "vectorized label remap requires a unique node local->global map; " "duplicate global ids break the searchsorted membership lookup." ) - positions = torch.searchsorted(sorted_node, flat) - positions = positions.clamp_(max=num_nodes - 1) - found = sorted_node[positions] == flat - local_idx = sort_perm[positions][found] - anchor_kept = anchor_of_entry[found] - - # Order within each anchor must match torch.nonzero over [N, M]: ascending - # local index, ties broken by ascending label column. searchsorted visits - # entries in (anchor, column) order, so a stable sort on a composite key - # (anchor primary, local index secondary) reproduces it. - composite_key = anchor_kept * (num_nodes + 1) + local_idx - order = torch.argsort(composite_key, stable=True) + + # 1. Locate each label id in the sorted node map: searchsorted returns the + # insertion point, so `sorted_positions[i]` is the candidate index in + # `sorted_node` where `flat[i]` would be inserted to keep order sorted. + sorted_positions = torch.searchsorted(sorted_node, flat) + sorted_positions = sorted_positions.clamp_(max=num_nodes - 1) + + # 2. Keep only exact matches (drop global ids absent from the subgraph). + # `sorted_node[sorted_positions] == flat` is True iff flat[i] is actually + # in the node map (not just a neighboring element in the sorted array). + is_exact_match = sorted_node[sorted_positions] == flat + + # 3. Map sorted position -> original local index via sort_perm: sort_perm[j] + # is the local node index whose global id landed at sorted position j. + local_index = sort_perm[sorted_positions][is_exact_match] + anchor_of_matched = anchor_of_entry[is_exact_match] + + # Pairs are now in (anchor, column) order -- non-decreasing in anchor_index + # because anchor_of_entry is row-major and the masks preserve order. No + # argsort is needed: within-anchor label order is unspecified by contract + # (the ABLP loss is permutation-invariant; see AnchorLabels docstring). return ( - anchor_kept[order].to(to_device).to(torch.long), - local_idx[order].to(to_device).to(torch.long), + anchor_of_matched.to(to_device).to(torch.long), + local_index.to(to_device).to(torch.long), num_anchors, ) @@ -265,10 +302,14 @@ def _remap_one_label_tensor( """Vectorized remap of one ``[N_anchors, M]`` padded label tensor to a dict. Thin wrapper over :func:`_membership_remap`: splits the shared - ``(anchor_index, local_index)`` pair stream into a per-anchor dict. For each - anchor row, the value is the ascending local indices into the original - (pre-sort) node order whose global id appears in that row, in - :func:`torch.nonzero` multiplicity. + ``(anchor_index, local_index)`` pair stream into a per-anchor dict via + ``torch.bincount``/``torch.split``. The pair stream from + :func:`_membership_remap` is non-decreasing in ``anchor_index`` (row-major + masked flatten), so ``bincount``/``split`` is correct without additional + sorting. For each anchor row, the value is the set of local indices into the + original (pre-sort) node order whose global id appears in that row, in + column-visit order (matching multiplicity -- duplicate label columns produce + repeated local entries). Args: label_tensor (torch.Tensor): ``[N_anchors, M]`` ``-1``-padded global @@ -296,6 +337,54 @@ def _remap_one_label_tensor( return {anchor: per_anchor[anchor] for anchor in range(num_anchors)} +_LabelT = TypeVar("_LabelT") + + +def _remap_group( + labels_by_edge_type: dict[EdgeType, torch.Tensor], + sorted_for: abc.Callable[[NodeType], tuple[torch.Tensor, torch.Tensor]], + remap_one: abc.Callable[ + [torch.Tensor, torch.Tensor, torch.Tensor, torch.device], + _LabelT, + ], + to_device: torch.device, +) -> dict[EdgeType, _LabelT]: + """Remap one group (positive or negative) of label tensors. + + Iterates ``labels_by_edge_type``, skips zero-anchor tensors (to match the + loop oracle's ``defaultdict`` which never creates an outer key for an empty + anchor set), and delegates each tensor to ``remap_one``. + + Args: + labels_by_edge_type: Per label edge type, a ``[N_anchors, M]`` + ``-1``-padded global label tensor. + sorted_for: Callable returning ``(sorted_values, sort_perm)`` for a + given supervision node type (typically a memoized closure). + remap_one: Per-tensor remap callable with signature + ``(label_tensor, sorted_node, sort_perm, to_device) -> _LabelT``. + Either :func:`_remap_one_label_tensor` (dict output) or + :func:`_remap_one_label_tensor_edge_list` (:class:`AnchorLabels`). + to_device: Device for all output tensors. + + Returns: + ``dict[message_passing_edge_type, _LabelT]`` with no entry for + zero-anchor label tensors. + """ + # Supervision edge types are (anchor_node_type, relation, supervision_node_type). + supervision_node_type_index = 2 + output: dict[EdgeType, _LabelT] = {} + for edge_type, label_tensor in labels_by_edge_type.items(): + # Match the loop oracle's defaultdict: a zero-anchor tensor produces no + # outer key (the loop's per-anchor body never runs for an empty tensor). + if label_tensor.size(0) == 0: + continue + sorted_node, sort_perm = sorted_for(edge_type[supervision_node_type_index]) + output[label_edge_type_to_message_passing_edge_type(edge_type)] = remap_one( + label_tensor, sorted_node, sort_perm, to_device + ) + return output + + def vectorized_set_labels( node_local_to_global_by_type: dict[NodeType, torch.Tensor], positive_labels_by_edge_type: dict[EdgeType, torch.Tensor], @@ -309,24 +398,26 @@ def vectorized_set_labels( """Vectorized label remap from global label ids to local node indices. Drop-in replacement for the per-anchor loop in :func:`_loop_set_labels`, - producing bit-for-bit identical ragged output without a per-anchor Python - loop. + producing set-equivalent ragged output without a per-anchor Python loop. + Equivalent to the loop oracle up to within-anchor label order; the ABLP + contrastive loss is permutation-invariant over the pair stream, so the + output is interchangeable with the oracle's for all production consumers. For each label edge type and each anchor row of its ``[N_anchors, M]`` - ``-1``-padded label tensor, emits the ascending local indices into the - supervision node type's ``node`` map whose global id appears in that row, in - :func:`torch.nonzero` multiplicity. The padding sentinel - (:data:`gigl.utils.data_splitters.PADDING_NODE`) is masked before any search, - so it is never used as a lookup key. Every anchor index ``0..N_anchors-1`` - receives a key; anchors with no in-subgraph labels map to an empty ``long`` - tensor. + ``-1``-padded label tensor, emits the set of local indices into the + supervision node type's ``node`` map whose global id appears in that row, + in column-visit order with matching multiplicity (duplicate label columns + produce repeated local entries). The padding sentinel + (:data:`gigl.utils.data_splitters.PADDING_NODE`) is masked before any + search, so it is never used as a lookup key. Every anchor index + ``0..N_anchors-1`` receives a key; anchors with no in-subgraph labels map + to an empty ``long`` tensor. Precondition (REQUIRED for correctness): each ``node`` local->global map in ``node_local_to_global_by_type`` MUST contain UNIQUE global ids. The ``torch.searchsorted`` membership lookup returns the LEFT-MOST matching sorted position; a repeated global id would resolve every match to a single local - index, dropping the duplicate and silently diverging from the loop. GiGL - ``node`` maps satisfy this by construction. + index, dropping the duplicate. GiGL ``node`` maps satisfy this by construction. Args: node_local_to_global_by_type (dict[NodeType, torch.Tensor]): Per node @@ -347,7 +438,6 @@ def vectorized_set_labels( with an entry for every anchor index ``0..N_anchors-1``. """ del supervision_edge_types # Accepted for signature parity; not needed here. - edge_index = 2 # Supervision edge types are (anchor, to, supervision). sorted_cache: dict[NodeType, tuple[torch.Tensor, torch.Tensor]] = {} def _sorted_for(node_type: NodeType) -> tuple[torch.Tensor, torch.Tensor]: @@ -357,27 +447,20 @@ def _sorted_for(node_type: NodeType) -> tuple[torch.Tensor, torch.Tensor]: ) return sorted_cache[node_type] - output_positive_labels: dict[EdgeType, dict[int, torch.Tensor]] = {} - for edge_type, label_tensor in positive_labels_by_edge_type.items(): - # Match the loop's defaultdict: a zero-anchor tensor produces NO outer - # key (the loop's per-anchor body never runs). - if label_tensor.size(0) == 0: - continue - sorted_node, sort_perm = _sorted_for(edge_type[edge_index]) - output_positive_labels[ - label_edge_type_to_message_passing_edge_type(edge_type) - ] = _remap_one_label_tensor(label_tensor, sorted_node, sort_perm, to_device) - - output_negative_labels: dict[EdgeType, dict[int, torch.Tensor]] = {} - for edge_type, label_tensor in negative_labels_by_edge_type.items(): - if label_tensor.size(0) == 0: - continue - sorted_node, sort_perm = _sorted_for(edge_type[edge_index]) - output_negative_labels[ - label_edge_type_to_message_passing_edge_type(edge_type) - ] = _remap_one_label_tensor(label_tensor, sorted_node, sort_perm, to_device) - - return output_positive_labels, output_negative_labels + return ( + _remap_group( + positive_labels_by_edge_type, + _sorted_for, + _remap_one_label_tensor, + to_device, + ), + _remap_group( + negative_labels_by_edge_type, + _sorted_for, + _remap_one_label_tensor, + to_device, + ), + ) def _remap_one_label_tensor_edge_list( @@ -394,6 +477,9 @@ def _remap_one_label_tensor_edge_list( ``torch.bincount``/``torch.split``, no per-anchor Python comprehension); ``_membership_remap`` already moved both tensors to ``to_device`` as ``long``. + Pairs are in ``(anchor, column)`` order; within an anchor, label order is + unspecified (see :class:`AnchorLabels`). + Args: label_tensor (torch.Tensor): ``[N_anchors, M]`` ``-1``-padded global label ids. @@ -403,8 +489,8 @@ def _remap_one_label_tensor_edge_list( to_device (torch.device): Device for the output tensors. Returns: - AnchorLabels with ``anchor_index``/``label_index`` in - :func:`torch.nonzero` order and ``num_anchors == label_tensor.size(0)``. + :class:`AnchorLabels` with ``anchor_index``/``label_index`` in + ``(anchor, column)`` order and ``num_anchors == label_tensor.size(0)``. """ anchor_index, label_index, num_anchors = _membership_remap( label_tensor, sorted_node, sort_perm, to_device @@ -428,8 +514,9 @@ def edge_list_set_labels( Drop-in alternative to :func:`vectorized_set_labels` that emits, per label edge type, an :class:`AnchorLabels` dense edge-list instead of a ragged ``dict[int, torch.Tensor]``. Membership semantics are identical; expanding - each result via :meth:`AnchorLabels.to_dict` reproduces - :func:`_loop_set_labels` bit-for-bit. + each result via :meth:`AnchorLabels.to_dict` produces a set-equivalent dict + (same per-anchor label sets as :func:`_loop_set_labels`, in column-visit + order rather than ascending-local order -- permutation-invariant for the loss). Precondition (REQUIRED): each ``node`` local->global map must contain UNIQUE global ids -- see :func:`vectorized_set_labels` for the rationale. @@ -453,7 +540,6 @@ def edge_list_set_labels( zero-anchor label tensor (matching the loop's defaultdict). """ del supervision_edge_types # Accepted for signature parity; not needed here. - edge_index = 2 # Supervision edge types are (anchor, to, supervision). sorted_cache: dict[NodeType, tuple[torch.Tensor, torch.Tensor]] = {} def _sorted_for(node_type: NodeType) -> tuple[torch.Tensor, torch.Tensor]: @@ -463,29 +549,20 @@ def _sorted_for(node_type: NodeType) -> tuple[torch.Tensor, torch.Tensor]: ) return sorted_cache[node_type] - output_positive_labels: dict[EdgeType, AnchorLabels] = {} - for edge_type, label_tensor in positive_labels_by_edge_type.items(): - if label_tensor.size(0) == 0: - continue - sorted_node, sort_perm = _sorted_for(edge_type[edge_index]) - output_positive_labels[ - label_edge_type_to_message_passing_edge_type(edge_type) - ] = _remap_one_label_tensor_edge_list( - label_tensor, sorted_node, sort_perm, to_device - ) - - output_negative_labels: dict[EdgeType, AnchorLabels] = {} - for edge_type, label_tensor in negative_labels_by_edge_type.items(): - if label_tensor.size(0) == 0: - continue - sorted_node, sort_perm = _sorted_for(edge_type[edge_index]) - output_negative_labels[ - label_edge_type_to_message_passing_edge_type(edge_type) - ] = _remap_one_label_tensor_edge_list( - label_tensor, sorted_node, sort_perm, to_device - ) - - return output_positive_labels, output_negative_labels + return ( + _remap_group( + positive_labels_by_edge_type, + _sorted_for, + _remap_one_label_tensor_edge_list, + to_device, + ), + _remap_group( + negative_labels_by_edge_type, + _sorted_for, + _remap_one_label_tensor_edge_list, + to_device, + ), + ) class DistABLPLoader(BaseDistLoader): diff --git a/tests/unit/distributed/dist_ablp_neighborloader_test.py b/tests/unit/distributed/dist_ablp_neighborloader_test.py index c4aee9294..ce37c3144 100644 --- a/tests/unit/distributed/dist_ablp_neighborloader_test.py +++ b/tests/unit/distributed/dist_ablp_neighborloader_test.py @@ -416,21 +416,27 @@ def _run_distributed_ablp_neighbor_loader_multiple_supervision_edge_types( shutdown_rpc() -def _ordered_global_pairs( +def _global_pair_set( node: torch.Tensor, label_dict: dict[int, torch.Tensor] ) -> list[tuple[int, int]]: - """Flatten a per-anchor label dict to an ORDERED (global_anchor, global_label) - pair stream. - - Iterates anchors in ascending key order and labels in their stored order - WITHOUT sorting, so a change in pair order (e.g. a kernel that emitted - columns or sorted-position order) changes the returned list. This is the - silent-mistraining mode a per-anchor sorted-set comparison would miss. + """Flatten a per-anchor label dict to a sorted (global_anchor, global_label) + pair list for set-equality comparison. + + Pairs are collected in anchor-ascending order and sorted within each anchor + by global label value, so the result is canonical regardless of the + within-anchor label order emitted by the kernel (column-visit order) vs the + loop oracle (ascending-local order). Multiplicity is preserved: duplicate + label columns produce duplicate entries in the list. + + Use ``collections.Counter`` or ``sorted(...)`` equality rather than + positional ``==`` if pair-level multiplicity without anchor grouping matters. + The current callers use ``list ==`` after sorting within anchor, which is + equivalent to Counter equality for these tests. """ pairs: list[tuple[int, int]] = [] for local_anchor in sorted(label_dict.keys()): global_anchor = int(node[local_anchor].item()) - for local_label in label_dict[local_anchor].tolist(): + for local_label in sorted(label_dict[local_anchor].tolist()): pairs.append((global_anchor, int(node[local_label].item()))) return pairs @@ -444,20 +450,23 @@ def _collect_homogeneous_labels( batch_size: int, has_negatives: bool, ): - """Child-side: run the loader, return the ORDERED global-id pair streams. + """Child-side: run the loader, return sorted global-id pair sets. Local node indices differ run-to-run, so labels are translated back to - global ids via ``datum.node``. The streams preserve pair ORDER (see - ``_ordered_global_pairs``) so dict-vs-edge-list equality in the parent - catches an order regression, not just a set regression. + global ids via ``datum.node``. Pairs are sorted within each anchor (see + ``_global_pair_set``) for canonical set-equality comparison in the parent: + the vectorized kernel emits column-visit order while the loop oracle emits + ascending-local order; both are correct since the ABLP loss is + permutation-invariant over the pair stream. When ``use_list_output`` is True the labels arrive as :class:`AnchorLabels`. This branch ALSO asserts in-process that the exact tensors the example training loss reads from the edge-list match the legacy dict read: the edge-list ``label_index`` must equal the dict's ``torch.cat(values())`` and ``query_node_idx[anchor_index]`` must equal the legacy - ``repeat_interleave`` over per-anchor lengths. A drift here is exactly the - example-training bug we are guarding against. + ``repeat_interleave`` over per-anchor lengths. Both paths draw from the same + :func:`_membership_remap` pair stream (column-visit order), so they are + identical -- a drift here is exactly the example-training bug we guard against. """ create_test_process_group() loader = DistABLPLoader( @@ -495,14 +504,14 @@ def _collect_homogeneous_labels( ) else: positive_dict = datum.y_positive - positive_pairs.extend(_ordered_global_pairs(node, positive_dict)) + positive_pairs.extend(_global_pair_set(node, positive_dict)) if has_negatives: if use_list_output: assert isinstance(datum.y_negative, AnchorLabels) negative_dict = datum.y_negative.to_dict() else: negative_dict = datum.y_negative - negative_pairs.extend(_ordered_global_pairs(node, negative_dict)) + negative_pairs.extend(_global_pair_set(node, negative_dict)) else: assert not hasattr(datum, "y_negative"), ( f"expected no negatives, got {getattr(datum, 'y_negative', None)}" @@ -693,15 +702,20 @@ def test_ablp_dataloader( def test_use_list_output_matches_dict_output( self, _, labeled_edges, input_nodes, batch_size, has_negatives ): - """``use_list_output=True`` yields AnchorLabels whose ``.to_dict()`` matches - the default dict output as an ORDERED (global_anchor, global_label) pair - stream -- not merely a per-anchor set, so a pair-order regression fails. - - Sampling is deterministic here (``shuffle`` defaults to False and the - input is fixed), so the two loader runs emit batches in the same order - and the streams are directly comparable. The child process additionally - asserts the exact tensors the example-training loss reads from the - edge-list equal the legacy dict read (see ``_collect_homogeneous_labels``). + """``use_list_output=True`` yields AnchorLabels whose ``.to_dict()`` produces + a per-anchor SET-equal result to the default dict output. + + Both paths draw from the same :func:`_membership_remap` column-visit pair + stream, so they are in fact identical; the test confirms no membership or + co-indexing regression. Sampling is deterministic (``shuffle`` defaults to + False), so the two loader runs emit batches in the same order and the sorted + pair sets are directly comparable. + + The child process additionally asserts the exact tensors the example-training + loss reads from the edge-list equal the legacy dict read (see + ``_collect_homogeneous_labels``): ``label_index`` equals the dict's + ``torch.cat(values())`` and ``query_node_idx[anchor_index]`` equals the + legacy ``repeat_interleave``. """ edge_index = { DEFAULT_HOMOGENEOUS_EDGE_TYPE: torch.tensor( diff --git a/tests/unit/distributed/edge_list_set_labels_test.py b/tests/unit/distributed/edge_list_set_labels_test.py index 578838505..432accf12 100644 --- a/tests/unit/distributed/edge_list_set_labels_test.py +++ b/tests/unit/distributed/edge_list_set_labels_test.py @@ -48,10 +48,18 @@ def _dict_from_edge_list( return {et: labels.to_dict() for et, labels in edge_list_by_type.items()} -def _assert_label_dicts_equal( +def _assert_label_dicts_set_equal( actual: dict[PyGEdgeType, dict[int, torch.Tensor]], expected: dict[PyGEdgeType, dict[int, torch.Tensor]], ) -> None: + """Assert per-anchor SET equality (with multiplicity) between two label dicts. + + The edge-list kernel emits pairs in column-visit order; the loop oracle emits + them in ascending-local order. Both are valid: the ABLP contrastive loss is + permutation-invariant over the pair stream. This helper normalises order via + ``sorted()`` so it catches membership and multiplicity errors while remaining + invariant to within-anchor permutations. + """ assert set(actual.keys()) == set(expected.keys()), ( f"{set(actual.keys())} != {set(expected.keys())}" ) @@ -61,7 +69,10 @@ def _assert_label_dicts_equal( for anchor, expected_tensor in inner.items(): got = actual_inner[anchor] assert got.dtype == torch.long - torch.testing.assert_close(got, expected_tensor) + assert sorted(got.tolist()) == sorted(expected_tensor.tolist()), ( + f"{edge_type}[{anchor}]: got {got.tolist()}, " + f"expected {expected_tensor.tolist()} (as sets)" + ) class AnchorLabelsTest(TestCase): @@ -80,11 +91,14 @@ def test_to_dict_round_trips_empty_and_multi(self) -> None: class RemapOneEdgeListTest(TestCase): - def test_matches_nonzero_order_with_padding_and_empty(self) -> None: + def test_column_order_with_padding_and_empty(self) -> None: + # When the node map is sorted, column order and ascending-local order + # coincide, so this verifies exact output. + # anchor 0: [15, -1] -> local 5 ; anchor 1: [15, 16] -> local 5, 6 + # (column order == ascending-local here because node is already sorted); + # anchor 2: [-1, -1] -> empty ; anchor 3: [99, -1] -> empty (99 absent). node = torch.tensor([10, 11, 12, 13, 14, 15, 16, 17]) sorted_node, sort_perm = torch.sort(node) - # anchor 0: [15, -1] -> local 5 ; anchor 1: [15, 16] -> local 5,6 ; - # anchor 2: [-1, -1] -> empty ; anchor 3: [99, -1] -> empty (99 absent). label_tensor = torch.tensor( [[15, -1], [15, 16], [-1, -1], [99, -1]], dtype=torch.long ) @@ -99,14 +113,14 @@ def test_matches_nonzero_order_with_padding_and_empty(self) -> None: result.label_index, torch.tensor([5, 5, 6], dtype=torch.long) ) - def test_unsorted_node_map_nontrivial_sort_perm(self) -> None: + def test_unsorted_node_map_correct_membership(self) -> None: # node map is UNSORTED, so torch.sort yields a non-identity sort_perm. - # node[i]=global id of local i: g15->local0, g10->local1, g16->local2, - # g11->local3. The label row is high-id-first ([16, 15]); the edge-list - # must emit ascending LOCAL index (g15=local0 before g16=local2), proving - # the result is mapped through sort_perm and is not in column or sorted - # order. A port that dropped the sort_perm gather would emit [0, 2] mapped - # to the wrong locals (or sorted-position indices 1 and 3) and fail here. + # node[i] = global id of local i: g15->local0, g10->local1, g16->local2, + # g11->local3. The label row is high-id-first ([16, 15]); the kernel emits + # pairs in column order (g16=local2 first, g15=local0 second) -> [2, 0]. + # The loop oracle would emit ascending-local order -> [0, 2]. Both are + # SET-equal to {0, 2}. A port that dropped the sort_perm gather would emit + # sorted-position indices (1 and 3) and fail the membership check. node = torch.tensor([15, 10, 16, 11]) sorted_node, sort_perm = torch.sort(node) label_tensor = torch.tensor([[16, 15]], dtype=torch.long) @@ -117,8 +131,9 @@ def test_unsorted_node_map_nontrivial_sort_perm(self) -> None: torch.testing.assert_close( result.anchor_index, torch.tensor([0, 0], dtype=torch.long) ) + # Column order: g16 (local2) first, g15 (local0) second. torch.testing.assert_close( - result.label_index, torch.tensor([0, 2], dtype=torch.long) + result.label_index, torch.tensor([2, 0], dtype=torch.long) ) @@ -147,9 +162,11 @@ class EdgeListSetLabelsEquivalenceTest(TestCase): negatives={}, supervision_edge_types=[_USER_TO_STORY], ), - # UNSORTED node map + reversed label columns -- non-identity sort_perm - # (see the vectorized test for the rationale). Guards against a port - # that emits sorted/column order instead of ascending local index. + # UNSORTED node map + reversed label columns -- non-identity sort_perm. + # The kernel emits column order (g16=local2 first, g15=local0 second); + # the loop oracle emits ascending-local order (local0, local2). + # The per-anchor SET is {0, 2} for both, guarding against a port that + # emits sorted-position indices (1, 3) instead of sort_perm-mapped locals. param( "unsorted_node_map_reversed_columns", node_map={_STORY: torch.tensor([15, 10, 16, 11])}, @@ -228,8 +245,8 @@ def test_edge_list_to_dict_matches_loop( supervision_edge_types=supervision_edge_types, to_device=_CPU, ) - _assert_label_dicts_equal(_dict_from_edge_list(el_pos), loop_pos) - _assert_label_dicts_equal(_dict_from_edge_list(el_neg), loop_neg) + _assert_label_dicts_set_equal(_dict_from_edge_list(el_pos), loop_pos) + _assert_label_dicts_set_equal(_dict_from_edge_list(el_neg), loop_neg) def test_duplicate_node_map_raises_assertion(self) -> None: # NOTE: the uniqueness check is gated on `__debug__`, so this is a no-op diff --git a/tests/unit/distributed/label_remap_cuda_device_test.py b/tests/unit/distributed/label_remap_cuda_device_test.py index 396243d7e..5bdb22fa5 100644 --- a/tests/unit/distributed/label_remap_cuda_device_test.py +++ b/tests/unit/distributed/label_remap_cuda_device_test.py @@ -33,11 +33,11 @@ def _inputs(device: torch.device): """A small case exercising every on-device code path. The node map is UNSORTED (so ``torch.sort`` yields a non-identity - ``sort_perm`` -- the gather through it must run on-device), and anchor 0 has - a DUPLICATE label column ([15, 15]) so the stable-argsort tie-break over the - composite key runs on-device. Local layout: g15->0, g10->1, g16->2, g11->3, - g12->4. Anchor rows: [15, 15] -> local 0 twice; [16, -1] -> local 2; - [-1, -1] -> empty. + ``sort_perm`` -- the gather through it must run on-device). Anchor 0 has a + DUPLICATE label column ([15, 15]) to exercise duplicate handling and + verify multiplicity is preserved (local 0 appears twice). Local layout: + g15->0, g10->1, g16->2, g11->3, g12->4. Anchor rows: [15, 15] -> local 0 + twice; [16, -1] -> local 2; [-1, -1] -> empty. """ node_map = {_STORY: torch.tensor([15, 10, 16, 11, 12], device=device)} positives = { diff --git a/tests/unit/distributed/vectorized_set_labels_test.py b/tests/unit/distributed/vectorized_set_labels_test.py index 6bd6c2a0b..e4b894d9d 100644 --- a/tests/unit/distributed/vectorized_set_labels_test.py +++ b/tests/unit/distributed/vectorized_set_labels_test.py @@ -42,10 +42,23 @@ def _neg(edge_type: EdgeType) -> EdgeType: return message_passing_to_negative_label(edge_type) -def _assert_label_dicts_equal( +def _assert_label_dicts_set_equal( actual: dict[EdgeType, dict[int, torch.Tensor]], expected: dict[EdgeType, dict[int, torch.Tensor]], ) -> None: + """Assert per-anchor SET equality (with multiplicity) between two label dicts. + + The vectorized kernel emits pairs in column-visit order; the loop oracle emits + them in ascending-local order. These differ when a row's local indices are not + monotone in column order (e.g. an unsorted node map with reversed label columns). + Both are valid: the ABLP contrastive loss is permutation-invariant over the pair + stream (``CrossEntropyLoss(reduction="sum")`` with value-based collision masks), + so within-anchor label order carries no meaning. + + This helper uses ``sorted(...)`` to normalise order before comparing, so it + catches membership errors and multiplicity errors (duplicate labels) while + remaining invariant to within-anchor permutations. + """ assert set(actual.keys()) == set(expected.keys()), ( f"{set(actual.keys())} != {set(expected.keys())}" ) @@ -57,7 +70,12 @@ def _assert_label_dicts_equal( for anchor, expected_tensor in inner.items(): got = actual_inner[anchor] assert got.dtype == torch.long, f"{edge_type}[{anchor}] dtype {got.dtype}" - torch.testing.assert_close(got, expected_tensor) + # Sort both tensors so the comparison is order-independent but still + # catches missing/extra entries and multiplicity (duplicate labels). + assert sorted(got.tolist()) == sorted(expected_tensor.tolist()), ( + f"{edge_type}[{anchor}]: got {got.tolist()}, " + f"expected {expected_tensor.tolist()} (as sets)" + ) class LoopSetLabelsContractTest(TestCase): @@ -88,7 +106,7 @@ def test_homogeneous_with_empty_and_padded_anchors(self) -> None: 3: torch.tensor([], dtype=torch.long), } } - _assert_label_dicts_equal(y_pos, expected) + _assert_label_dicts_set_equal(y_pos, expected) self.assertEqual(y_neg, {}) def test_duplicate_label_columns_preserve_multiplicity(self) -> None: @@ -136,10 +154,11 @@ class VectorizedSetLabelsEquivalenceTest(TestCase): # `sort_perm` is the identity and a broken port that emits sorted (or # global) order would still pass; here `torch.sort` permutes nontrivially # (15->local0, 10->local1, 16->local2, 11->local3), and the label row is - # given high-id-first ([16, 15]). The loop oracle emits ascending LOCAL - # index regardless of column order, i.e. local 0 (g15) before local 2 - # (g16) -> [0, 2]; a kernel that forgot to map through `sort_perm`, or - # that preserved column order, would diverge. + # given high-id-first ([16, 15]). The kernel emits pairs in column-visit + # order (g16=local2 first, then g15=local0), while the loop oracle emits + # ascending-local order (local0, local2). The SET is {0, 2} for both -- + # a kernel that forgot to map through `sort_perm` would emit the wrong + # locals (e.g. sorted positions 1 and 3) and fail the set check. param( "unsorted_node_map_reversed_columns", node_map={_STORY: torch.tensor([15, 10, 16, 11])}, @@ -252,13 +271,16 @@ def test_matches_loop( supervision_edge_types=supervision_edge_types, to_device=_CPU, ) - _assert_label_dicts_equal(vec_pos, loop_pos) - _assert_label_dicts_equal(vec_neg, loop_neg) + _assert_label_dicts_set_equal(vec_pos, loop_pos) + _assert_label_dicts_set_equal(vec_neg, loop_neg) - def test_unsorted_node_map_exact_order(self) -> None: + def test_unsorted_node_map_correct_membership(self) -> None: # Belt-and-suspenders on top of the parameterized case: assert the EXACT - # tensor (not just equality-to-loop) so the expected ascending-local order - # is visible and a regression to sorted/column order is unmistakable. + # local-index SET for the unsorted-node-map case so a regression to + # wrong locals (e.g. sorted-position indices instead of sort_perm-mapped + # locals) is unmistakable, even if within-anchor order is not pinned. + # Layout: node = [15, 10, 16, 11] -> g15=local0, g10=local1, g16=local2, + # g11=local3. Labels: [16, 15] -> expected SET {local0, local2} = {0, 2}. node_map = {_STORY: torch.tensor([15, 10, 16, 11])} positives = {_pos(_USER_TO_STORY): torch.tensor([[16, 15]], dtype=torch.long)} vec_pos, _ = vectorized_set_labels( @@ -268,10 +290,10 @@ def test_unsorted_node_map_exact_order(self) -> None: supervision_edge_types=[_USER_TO_STORY], to_device=_CPU, ) - # g15 is local 0, g16 is local 2 -> ascending local order is [0, 2]. - torch.testing.assert_close( - vec_pos[_USER_TO_STORY][0], torch.tensor([0, 2], dtype=torch.long) - ) + # The kernel emits column order (g16 first, g15 second) -> [2, 0]. + # The loop oracle emits ascending-local order -> [0, 2]. + # Both are SET-equal to {0, 2}; only membership and co-indexing matter. + assert sorted(vec_pos[_USER_TO_STORY][0].tolist()) == [0, 2] def test_duplicate_node_map_raises_assertion(self) -> None: # NOTE: the uniqueness check is gated on `__debug__`, so this assertion is From 594ddba3ea76edebc75a93fa8f68de0966955ba5 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 25 Jun 2026 16:33:17 +0000 Subject: [PATCH 14/18] docs(examples): clarify AnchorLabels read is loss-equivalent, not order-identical After dropping the order-reproduction argsort, per-anchor label order is column-visit order, not the loop's nonzero order. The (query, label) pairs are unchanged and the contrastive loss is order-invariant, so the read is equivalent for training; the comments now say so instead of implying byte-identical order. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../link_prediction/graph_store/heterogeneous_training.py | 5 +++-- .../link_prediction/graph_store/homogeneous_training.py | 5 +++-- examples/link_prediction/heterogeneous_training.py | 5 +++-- examples/link_prediction/homogeneous_training.py | 8 +++++--- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_training.py b/examples/link_prediction/graph_store/heterogeneous_training.py index 3670c5aec..fb215192f 100644 --- a/examples/link_prediction/graph_store/heterogeneous_training.py +++ b/examples/link_prediction/graph_store/heterogeneous_training.py @@ -303,8 +303,9 @@ def _compute_loss( random_negative_batch_size = random_negative_data[labeled_node_type].batch_size # main_data.y_positive is an AnchorLabels edge-list (use_list_output=True). - # Pairs are ordered ascending by anchor, so reading label_index/anchor_index - # directly is equivalent to the historical dict read. + # Pairs are grouped by anchor, so reading label_index/anchor_index directly + # yields the same (query, label) pairs as the historical dict read; the + # within-anchor order may differ, but the loss is order-invariant. positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) repeated_query_node_idx = query_node_idx[ main_data.y_positive.anchor_index.to(device) diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py index 3f5ef62fd..710cb8211 100644 --- a/examples/link_prediction/graph_store/homogeneous_training.py +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -309,8 +309,9 @@ def _compute_loss( random_negative_batch_size = random_negative_data.batch_size # main_data.y_positive is an AnchorLabels edge-list (use_list_output=True). - # Pairs are ordered ascending by anchor, so reading label_index/anchor_index - # directly is equivalent to the historical dict read. + # Pairs are grouped by anchor, so reading label_index/anchor_index directly + # yields the same (query, label) pairs as the historical dict read; the + # within-anchor order may differ, but the loss is order-invariant. positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) repeated_query_node_idx = query_node_idx[ main_data.y_positive.anchor_index.to(device) diff --git a/examples/link_prediction/heterogeneous_training.py b/examples/link_prediction/heterogeneous_training.py index e3370a506..6e9416dbb 100644 --- a/examples/link_prediction/heterogeneous_training.py +++ b/examples/link_prediction/heterogeneous_training.py @@ -229,8 +229,9 @@ def _compute_loss( # main_data.y_positive is an AnchorLabels edge-list (use_list_output=True), # even in the heterogeneous setting (single supervision edge type per loss # call). label_index holds the local label node per (anchor, label) pair; - # anchor_index holds the matching local anchor row. Pairs are ordered - # ascending by anchor, so this is equivalent to the historical dict read. + # anchor_index holds the matching local anchor row. Pairs are grouped by + # anchor, so this yields the same (query, label) pairs as the historical dict + # read; the within-anchor order may differ, but the loss is order-invariant. positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) repeated_query_node_idx = query_node_idx[ main_data.y_positive.anchor_index.to(device) diff --git a/examples/link_prediction/homogeneous_training.py b/examples/link_prediction/homogeneous_training.py index 2fc8719f3..4d3af134a 100644 --- a/examples/link_prediction/homogeneous_training.py +++ b/examples/link_prediction/homogeneous_training.py @@ -195,9 +195,11 @@ def _compute_loss( # main_data.y_positive is an AnchorLabels edge-list (use_list_output=True). # label_index holds the local label node per (anchor, label) pair; anchor_index - # holds the matching local anchor row. Pairs are ordered ascending by anchor, - # so this is equivalent to the historical dict read - # (torch.cat(list(values())) + repeat_interleave over per-anchor lengths). + # holds the matching local anchor row. Pairs are grouped by anchor, so reading + # label_index/anchor_index directly yields the same (query, label) pairs as the + # historical dict read (torch.cat(list(values())) + repeat_interleave over + # per-anchor lengths). The within-anchor label order may differ, but the + # contrastive loss is order-invariant, so the result is equivalent. positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) repeated_query_node_idx = query_node_idx[ main_data.y_positive.anchor_index.to(device) From 641b24e54b365016297b77ccae1aa2bbd2a28dbe Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 26 Jun 2026 17:20:17 +0000 Subject: [PATCH 15/18] refactor(ablp): single AnchorLabels label-remap kernel + dict view GiGL main resolved ABLP labels with a per-anchor Python loop. This replaces it with one vectorized kernel and a single output type, plus a thin dict view for backward compatibility. - Resolve labels with a sorted-membership join (_membership_remap: sort the node map once, searchsorted the label ids, keep exact matches) instead of the O(N_anchors * M * N_nodes) per-anchor loop. Delete _loop_set_labels (no production callers) and the redundant dict-producing kernel. - Collapse the remap to two functions, _membership_remap and edge_list_set_labels; the dict path is just AnchorLabels.to_dict(), selected by use_list_output on DistABLPLoader (default False keeps the ragged dict). Drop the _remap_group callback indirection, the _LabelT TypeVar, and the vestigial supervision_edge_types parity param. - AnchorLabels stores labels as two parallel (anchor_index, label_index) tensors so the loss can index them directly; within-anchor order is unspecified because the ABLP contrastive loss is order-invariant over the pairs. - Make the __debug__ unique-node-map check a cheap adjacent-difference test on the already-sorted map (it ran torch.unique every batch, and GiGL is not launched with -O). - Tests assert against constructed expected values and per-anchor label SETS, not within-anchor order. Examples consume the edge-list directly. Behavior-preserving: per-anchor label sets are unchanged. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../graph_store/heterogeneous_training.py | 9 +- .../graph_store/homogeneous_training.py | 9 +- .../link_prediction/heterogeneous_training.py | 11 +- .../link_prediction/homogeneous_training.py | 3 +- gigl/distributed/dist_ablp_neighborloader.py | 533 +++++------------- .../dist_ablp_neighborloader_test.py | 46 +- .../distributed/edge_list_set_labels_test.py | 314 +++++++---- .../label_remap_cuda_device_test.py | 65 +-- .../distributed/vectorized_set_labels_test.py | 317 ----------- 9 files changed, 391 insertions(+), 916 deletions(-) delete mode 100644 tests/unit/distributed/vectorized_set_labels_test.py diff --git a/examples/link_prediction/graph_store/heterogeneous_training.py b/examples/link_prediction/graph_store/heterogeneous_training.py index fb215192f..913de8053 100644 --- a/examples/link_prediction/graph_store/heterogeneous_training.py +++ b/examples/link_prediction/graph_store/heterogeneous_training.py @@ -211,8 +211,7 @@ def _setup_dataloaders( channel_size=sampling_worker_shared_channel_size, process_start_gap_seconds=process_start_gap_seconds, shuffle=shuffle, - # Return labels as a dense AnchorLabels edge-list so the loss reads - # anchor/label indices directly without a per-anchor Python loop. + # Labels as an AnchorLabels edge-list; see the AnchorLabels class docstring. use_list_output=True, ) @@ -302,10 +301,8 @@ def _compute_loss( ).to(device) random_negative_batch_size = random_negative_data[labeled_node_type].batch_size - # main_data.y_positive is an AnchorLabels edge-list (use_list_output=True). - # Pairs are grouped by anchor, so reading label_index/anchor_index directly - # yields the same (query, label) pairs as the historical dict read; the - # within-anchor order may differ, but the loss is order-invariant. + # main_data.y_positive is an AnchorLabels edge-list; read label_index and + # query_node_idx[anchor_index] directly (see the AnchorLabels class docstring). positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) repeated_query_node_idx = query_node_idx[ main_data.y_positive.anchor_index.to(device) diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py index 710cb8211..706f425c6 100644 --- a/examples/link_prediction/graph_store/homogeneous_training.py +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -241,8 +241,7 @@ def _setup_dataloaders( channel_size=sampling_worker_shared_channel_size, process_start_gap_seconds=process_start_gap_seconds, shuffle=shuffle, - # Return labels as a dense AnchorLabels edge-list so the loss reads - # anchor/label indices directly without a per-anchor Python loop. + # Labels as an AnchorLabels edge-list; see the AnchorLabels class docstring. use_list_output=True, ) @@ -308,10 +307,8 @@ def _compute_loss( query_node_idx: torch.Tensor = torch.arange(main_data.batch_size).to(device) random_negative_batch_size = random_negative_data.batch_size - # main_data.y_positive is an AnchorLabels edge-list (use_list_output=True). - # Pairs are grouped by anchor, so reading label_index/anchor_index directly - # yields the same (query, label) pairs as the historical dict read; the - # within-anchor order may differ, but the loss is order-invariant. + # main_data.y_positive is an AnchorLabels edge-list; read label_index and + # query_node_idx[anchor_index] directly (see the AnchorLabels class docstring). positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) repeated_query_node_idx = query_node_idx[ main_data.y_positive.anchor_index.to(device) diff --git a/examples/link_prediction/heterogeneous_training.py b/examples/link_prediction/heterogeneous_training.py index 6e9416dbb..7a1522169 100644 --- a/examples/link_prediction/heterogeneous_training.py +++ b/examples/link_prediction/heterogeneous_training.py @@ -144,8 +144,7 @@ def _setup_dataloaders( # This is done so that each process on the current machine which initializes a `main_loader` doesn't compete for memory, causing potential OOM process_start_gap_seconds=process_start_gap_seconds, shuffle=shuffle, - # Return labels as a dense AnchorLabels edge-list so the loss reads - # anchor/label indices directly without a per-anchor Python loop. + # Labels as an AnchorLabels edge-list; see the AnchorLabels class docstring. use_list_output=True, ) @@ -226,12 +225,8 @@ def _compute_loss( ).to(device) random_negative_batch_size = random_negative_data[labeled_node_type].batch_size - # main_data.y_positive is an AnchorLabels edge-list (use_list_output=True), - # even in the heterogeneous setting (single supervision edge type per loss - # call). label_index holds the local label node per (anchor, label) pair; - # anchor_index holds the matching local anchor row. Pairs are grouped by - # anchor, so this yields the same (query, label) pairs as the historical dict - # read; the within-anchor order may differ, but the loss is order-invariant. + # main_data.y_positive is an AnchorLabels edge-list; read label_index and + # query_node_idx[anchor_index] directly (see homogeneous_training._compute_loss). positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) repeated_query_node_idx = query_node_idx[ main_data.y_positive.anchor_index.to(device) diff --git a/examples/link_prediction/homogeneous_training.py b/examples/link_prediction/homogeneous_training.py index 4d3af134a..f8de44910 100644 --- a/examples/link_prediction/homogeneous_training.py +++ b/examples/link_prediction/homogeneous_training.py @@ -134,8 +134,7 @@ def _setup_dataloaders( # This is done so that each process on the current machine which initializes a `main_loader` doesn't compete for memory, causing potential OOM process_start_gap_seconds=process_start_gap_seconds, shuffle=shuffle, - # Return labels as a dense AnchorLabels edge-list so the loss reads - # anchor/label indices directly without a per-anchor Python loop. + # Labels as an AnchorLabels edge-list; see the AnchorLabels class docstring. use_list_output=True, ) diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index e51b4ccf5..33945f662 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -1,7 +1,7 @@ -from collections import abc, defaultdict +from collections import abc from dataclasses import dataclass from itertools import count -from typing import Optional, TypeVar, Union +from typing import Optional, Union import torch from graphlearn_torch.channel import SampleMessage @@ -58,22 +58,26 @@ @dataclass(frozen=True) class AnchorLabels: - """Dense edge-list ABLP labels for one label edge type. - - Replaces the ragged per-anchor ``dict[int, torch.Tensor]`` with two parallel - 1-D ``long`` tensors. Pair ``k`` asserts that local anchor row - ``anchor_index[k]`` has label node ``label_index[k]``. - - Pairs are emitted in ``(anchor, column)`` order from the source - ``[N_anchors, M]`` padded label tensor (row-major masked flatten). Within an - anchor, label order is **unspecified** -- the ABLP contrastive loss - (:class:`gigl.nn.loss.RetrievalLoss`) is permutation-invariant over the pair - stream (``CrossEntropyLoss(reduction="sum")`` over a diagonal-targeted score - matrix with value-based collision masks), so order carries no meaning. - ``anchor_index[k]`` and ``label_index[k]`` remain co-indexed. - - Anchors with no in-subgraph labels contribute zero pairs; ``num_anchors`` - records the full anchor count so empty anchors remain recoverable. + """ABLP labels for one edge type, stored as a flat (anchor, label) edge list. + + Each anchor can carry a different number of labels, so the natural shape is + ragged. Rather than a ``dict[int, torch.Tensor]`` -- which needs padding to + batch and a Python loop to read -- we keep two parallel ``long`` tensors: + pair ``k`` says local anchor ``anchor_index[k]`` is labeled by local node + ``label_index[k]``. Downstream the loss can index straight into these tensors + with no per-anchor iteration. + + Order within an anchor is deliberately left unspecified. The ABLP contrastive + loss (:class:`gigl.nn.loss.RetrievalLoss`) scores every (anchor, label) pair + independently and sums, so permuting the labels of an anchor leaves the loss + unchanged. We therefore emit pairs in the order they fall out of the source + ``[N_anchors, M]`` tensor (row-major, padding removed) and never pay to sort. + The two index tensors are always co-indexed, which is the only invariant that + matters. + + Empty anchors contribute no pairs at all. ``num_anchors`` is carried + separately so :meth:`to_dict` can still emit a key for every anchor, even the + ones that matched nothing. Example:: @@ -83,19 +87,14 @@ class AnchorLabels: ... label_index=torch.tensor([3, 5, 7], dtype=torch.long), ... num_anchors=3, ... ) - >>> d = labels.to_dict() - >>> d[0].tolist() - [3] - >>> d[1].tolist() - [5, 7] - >>> d[2].tolist() - [] + >>> labels.to_dict() + {0: tensor([3]), 1: tensor([5, 7]), 2: tensor([], dtype=torch.int64)} Args: anchor_index (torch.Tensor): ``[E]`` long tensor of local anchor rows. label_index (torch.Tensor): ``[E]`` long tensor of local label node ids. num_anchors (int): Total number of anchors ``N`` (rows of the source - padded label tensor). + padded label tensor), including anchors with no labels. """ anchor_index: torch.Tensor @@ -117,110 +116,39 @@ def to_dict(self) -> dict[int, torch.Tensor]: return {anchor: per_anchor[anchor] for anchor in range(self.num_anchors)} -def _loop_set_labels( - node_local_to_global_by_type: dict[NodeType, torch.Tensor], - positive_labels_by_edge_type: dict[EdgeType, torch.Tensor], - negative_labels_by_edge_type: dict[EdgeType, torch.Tensor], - supervision_edge_types: list[EdgeType], - to_device: torch.device, -) -> tuple[ - dict[EdgeType, dict[int, torch.Tensor]], - dict[EdgeType, dict[int, torch.Tensor]], -]: - """Per-anchor (loop) label remap from global label ids to local node indices. - - Reference implementation retained as the equivalence oracle for - :func:`vectorized_set_labels`. The production path uses the vectorized - kernel; this loop is exercised only by tests. - - For each label edge type and each anchor row of its ``[N_anchors, M]`` - ``-1``-padded label tensor, emits the ascending local indices into the - supervision node type's ``node`` map whose global id appears in that row, in - :func:`torch.nonzero` multiplicity. - - Args: - node_local_to_global_by_type (dict[NodeType, torch.Tensor]): Per node - type, a ``[N]`` tensor whose ``i``-th entry is the global id of - local node ``i``. - positive_labels_by_edge_type (dict[EdgeType, torch.Tensor]): Per - positive-label edge type, a ``[N_anchors, M]`` ``-1``-padded tensor - of global label ids. - negative_labels_by_edge_type (dict[EdgeType, torch.Tensor]): As above, - for negative-label edge types. May be empty. - supervision_edge_types (list[EdgeType]): Supervision edge types (unused - here; accepted for signature parity with the vectorized kernel). - to_device (torch.device): Device for every output tensor. - - Returns: - Tuple ``(y_positive, y_negative)``, each a - ``dict[message_passing_edge_type, dict[anchor_index, local_index_tensor]]`` - with an entry for every anchor index ``0..N_anchors-1``. - """ - del supervision_edge_types # Parity with vectorized_set_labels; not needed. - output_positive_labels: dict[EdgeType, dict[int, torch.Tensor]] = defaultdict(dict) - output_negative_labels: dict[EdgeType, dict[int, torch.Tensor]] = defaultdict(dict) - # Supervision edge types are (anchor_node_type, to, supervision_node_type), - # so the supervision node type is at index 2. - edge_index = 2 - for edge_type, label_tensor in positive_labels_by_edge_type.items(): - message_passing_edge_type = label_edge_type_to_message_passing_edge_type( - edge_type - ) - supervision_node_map = node_local_to_global_by_type[edge_type[edge_index]] - for local_anchor_node_id in range(label_tensor.size(0)): - positive_mask = ( - supervision_node_map.unsqueeze(1) == label_tensor[local_anchor_node_id] - ) - output_positive_labels[message_passing_edge_type][local_anchor_node_id] = ( - torch.nonzero(positive_mask)[:, 0].to(to_device) - ) - for edge_type, label_tensor in negative_labels_by_edge_type.items(): - message_passing_edge_type = label_edge_type_to_message_passing_edge_type( - edge_type - ) - supervision_node_map = node_local_to_global_by_type[edge_type[edge_index]] - for local_anchor_node_id in range(label_tensor.size(0)): - negative_mask = ( - supervision_node_map.unsqueeze(1) == label_tensor[local_anchor_node_id] - ) - output_negative_labels[message_passing_edge_type][local_anchor_node_id] = ( - torch.nonzero(negative_mask)[:, 0].to(to_device) - ) - return dict(output_positive_labels), dict(output_negative_labels) - - def _membership_remap( label_tensor: torch.Tensor, sorted_node: torch.Tensor, sort_perm: torch.Tensor, to_device: torch.device, ) -> tuple[torch.Tensor, torch.Tensor, int]: - """Shared per-tensor membership core for the vectorized label kernels. - - Remaps one ``[N_anchors, M]`` ``-1``-padded label tensor of global ids to the - flat ``(anchor_index, local_index)`` pair stream shared by both the dict and - edge-list builders. Pairs are emitted in ``(anchor, column)`` order (the - natural row-major order of the masked flatten). Within an anchor, label order - is **unspecified** -- the ABLP contrastive loss is permutation-invariant over - the pair stream, so order carries no meaning; see :class:`AnchorLabels`. - - The two callers (:func:`_remap_one_label_tensor`, - :func:`_remap_one_label_tensor_edge_list`) differ only in how they package this - pair stream, so the searchsorted membership logic lives here once. - - The pair stream is already non-decreasing in ``anchor_index`` because - ``anchor_of_entry`` is built as ``arange(N).repeat_interleave(M)`` (row-major) - and the ``is_present`` mask preserves order within each anchor. This means the - dict builder's ``bincount``/``split`` is correct without any additional sorting: - it only requires contiguous grouping by anchor, which row-major flatten - guarantees. - - Precondition (REQUIRED for correctness): ``sorted_node`` must have UNIQUE - values (the node map is unique local->global). :func:`torch.searchsorted` - returns the left-most equal position, so a duplicate global id would collapse - multiple local indices to one. GiGL ``node`` maps guarantee uniqueness; the - check is asserted only under ``__debug__`` to keep the hot path zero-cost (and - is a no-op under ``python -O``). + """Resolve one padded label tensor to a flat ``(anchor, label)`` pair stream. + + This is where the actual global-id-to-local-index lookup happens; everything + above it just packages the result. Given one ``[N_anchors, M]`` block of + ``-1``-padded global label ids, it returns the matched pairs as two parallel + index tensors plus the anchor count, the raw form :class:`AnchorLabels` wraps. + + The lookup is a sorted-membership join rather than a per-anchor scan: sort the + node map once, ``searchsorted`` every label id into it, and keep only exact + hits. The sort permutation then carries each hit back to its original local + index. This trades an ``O(N_anchors * M * N_nodes)`` broadcast-compare for a + single ``O(E log N_nodes)`` search, which is what lets the loader remap labels + without a Python loop over anchors. + + Pairs come out grouped by anchor for free: ``anchor_of_entry`` is built + row-major (``arange(N).repeat_interleave(M)``) and every mask preserves that + order, so the stream is non-decreasing in ``anchor_index`` without any sort. + Callers can group by anchor with a plain ``bincount``/``split``. Order *within* + an anchor is left as it falls out of the columns -- unspecified by contract, + since the loss does not care (see :class:`AnchorLabels`). + + The lookup is only correct if ``sorted_node`` has unique values: + :func:`torch.searchsorted` returns the left-most equal position, so a repeated + global id would map every match to the same local index and silently drop the + rest. GiGL ``node`` maps are unique by construction (one entry per subgraph + node), so this holds in production; the ``__debug__`` assertion guards against + misuse with a cheap adjacent-difference check on the already-sorted map. Args: label_tensor (torch.Tensor): ``[N_anchors, M]`` ``-1``-padded global @@ -232,8 +160,8 @@ def _membership_remap( Returns: Tuple ``(anchor_index, local_index, num_anchors)``. ``anchor_index`` and - ``local_index`` are equal-length 1-D ``long`` tensors on ``to_device`` in - ``(anchor, column)`` order (empty when nothing matched); + ``local_index`` are equal-length 1-D ``long`` tensors on ``to_device``, + grouped by anchor (empty when nothing matched); ``num_anchors == label_tensor.size(0)``. """ num_anchors = int(label_tensor.size(0)) @@ -261,7 +189,9 @@ def _membership_remap( return empty, empty, num_anchors if __debug__: - assert int(torch.unique(sorted_node).numel()) == num_nodes, ( + # `sorted_node` is already sorted, so uniqueness is equivalent to being + # strictly increasing -- a cheap adjacent-difference check, no re-sort. + assert bool((sorted_node[1:] > sorted_node[:-1]).all()), ( "vectorized label remap requires a unique node local->global map; " "duplicate global ids break the searchsorted membership lookup." ) @@ -293,233 +223,31 @@ def _membership_remap( ) -def _remap_one_label_tensor( - label_tensor: torch.Tensor, - sorted_node: torch.Tensor, - sort_perm: torch.Tensor, - to_device: torch.device, -) -> dict[int, torch.Tensor]: - """Vectorized remap of one ``[N_anchors, M]`` padded label tensor to a dict. - - Thin wrapper over :func:`_membership_remap`: splits the shared - ``(anchor_index, local_index)`` pair stream into a per-anchor dict via - ``torch.bincount``/``torch.split``. The pair stream from - :func:`_membership_remap` is non-decreasing in ``anchor_index`` (row-major - masked flatten), so ``bincount``/``split`` is correct without additional - sorting. For each anchor row, the value is the set of local indices into the - original (pre-sort) node order whose global id appears in that row, in - column-visit order (matching multiplicity -- duplicate label columns produce - repeated local entries). - - Args: - label_tensor (torch.Tensor): ``[N_anchors, M]`` ``-1``-padded global - label ids. - sorted_node (torch.Tensor): ``torch.sort`` of the supervision node map. - sort_perm (torch.Tensor): Permutation from ``torch.sort`` mapping sorted - positions back to original local indices. - to_device (torch.device): Device for every output tensor. - - Returns: - Mapping from anchor index ``0..N_anchors-1`` to a 1-D ``long`` tensor of - local indices (empty where the row matched nothing). - """ - anchor_index, local_idx, num_anchors = _membership_remap( - label_tensor, sorted_node, sort_perm, to_device - ) - # Defensive: `vectorized_set_labels` already `continue`s past zero-anchor - # tensors (to match the loop's defaultdict, which never creates the outer - # key), so this branch is unreachable from that caller. Kept so the helper is - # self-consistent for any external caller. - if num_anchors == 0: - return {} - counts = torch.bincount(anchor_index, minlength=num_anchors) - per_anchor = torch.split(local_idx, counts.tolist()) - return {anchor: per_anchor[anchor] for anchor in range(num_anchors)} - - -_LabelT = TypeVar("_LabelT") - - -def _remap_group( - labels_by_edge_type: dict[EdgeType, torch.Tensor], - sorted_for: abc.Callable[[NodeType], tuple[torch.Tensor, torch.Tensor]], - remap_one: abc.Callable[ - [torch.Tensor, torch.Tensor, torch.Tensor, torch.device], - _LabelT, - ], - to_device: torch.device, -) -> dict[EdgeType, _LabelT]: - """Remap one group (positive or negative) of label tensors. - - Iterates ``labels_by_edge_type``, skips zero-anchor tensors (to match the - loop oracle's ``defaultdict`` which never creates an outer key for an empty - anchor set), and delegates each tensor to ``remap_one``. - - Args: - labels_by_edge_type: Per label edge type, a ``[N_anchors, M]`` - ``-1``-padded global label tensor. - sorted_for: Callable returning ``(sorted_values, sort_perm)`` for a - given supervision node type (typically a memoized closure). - remap_one: Per-tensor remap callable with signature - ``(label_tensor, sorted_node, sort_perm, to_device) -> _LabelT``. - Either :func:`_remap_one_label_tensor` (dict output) or - :func:`_remap_one_label_tensor_edge_list` (:class:`AnchorLabels`). - to_device: Device for all output tensors. - - Returns: - ``dict[message_passing_edge_type, _LabelT]`` with no entry for - zero-anchor label tensors. - """ - # Supervision edge types are (anchor_node_type, relation, supervision_node_type). - supervision_node_type_index = 2 - output: dict[EdgeType, _LabelT] = {} - for edge_type, label_tensor in labels_by_edge_type.items(): - # Match the loop oracle's defaultdict: a zero-anchor tensor produces no - # outer key (the loop's per-anchor body never runs for an empty tensor). - if label_tensor.size(0) == 0: - continue - sorted_node, sort_perm = sorted_for(edge_type[supervision_node_type_index]) - output[label_edge_type_to_message_passing_edge_type(edge_type)] = remap_one( - label_tensor, sorted_node, sort_perm, to_device - ) - return output - - -def vectorized_set_labels( - node_local_to_global_by_type: dict[NodeType, torch.Tensor], - positive_labels_by_edge_type: dict[EdgeType, torch.Tensor], - negative_labels_by_edge_type: dict[EdgeType, torch.Tensor], - supervision_edge_types: list[EdgeType], - to_device: torch.device, -) -> tuple[ - dict[EdgeType, dict[int, torch.Tensor]], - dict[EdgeType, dict[int, torch.Tensor]], -]: - """Vectorized label remap from global label ids to local node indices. - - Drop-in replacement for the per-anchor loop in :func:`_loop_set_labels`, - producing set-equivalent ragged output without a per-anchor Python loop. - Equivalent to the loop oracle up to within-anchor label order; the ABLP - contrastive loss is permutation-invariant over the pair stream, so the - output is interchangeable with the oracle's for all production consumers. - - For each label edge type and each anchor row of its ``[N_anchors, M]`` - ``-1``-padded label tensor, emits the set of local indices into the - supervision node type's ``node`` map whose global id appears in that row, - in column-visit order with matching multiplicity (duplicate label columns - produce repeated local entries). The padding sentinel - (:data:`gigl.utils.data_splitters.PADDING_NODE`) is masked before any - search, so it is never used as a lookup key. Every anchor index - ``0..N_anchors-1`` receives a key; anchors with no in-subgraph labels map - to an empty ``long`` tensor. - - Precondition (REQUIRED for correctness): each ``node`` local->global map in - ``node_local_to_global_by_type`` MUST contain UNIQUE global ids. The - ``torch.searchsorted`` membership lookup returns the LEFT-MOST matching sorted - position; a repeated global id would resolve every match to a single local - index, dropping the duplicate. GiGL ``node`` maps satisfy this by construction. - - Args: - node_local_to_global_by_type (dict[NodeType, torch.Tensor]): Per node - type, a ``[N]`` tensor whose ``i``-th entry is the global id of - local node ``i``. Global ids MUST be unique within each map. - positive_labels_by_edge_type (dict[EdgeType, torch.Tensor]): Per - positive-label edge type, a ``[N_anchors, M]`` ``-1``-padded tensor - of global label ids. - negative_labels_by_edge_type (dict[EdgeType, torch.Tensor]): As above, - for negative-label edge types. May be empty. - supervision_edge_types (list[EdgeType]): Supervision edge types (unused - here; accepted for signature parity with the loop reference). - to_device (torch.device): Device for every output tensor. - - Returns: - Tuple ``(y_positive, y_negative)``, each a - ``dict[message_passing_edge_type, dict[anchor_index, local_index_tensor]]`` - with an entry for every anchor index ``0..N_anchors-1``. - """ - del supervision_edge_types # Accepted for signature parity; not needed here. - sorted_cache: dict[NodeType, tuple[torch.Tensor, torch.Tensor]] = {} - - def _sorted_for(node_type: NodeType) -> tuple[torch.Tensor, torch.Tensor]: - if node_type not in sorted_cache: - sorted_cache[node_type] = torch.sort( - node_local_to_global_by_type[node_type] - ) - return sorted_cache[node_type] - - return ( - _remap_group( - positive_labels_by_edge_type, - _sorted_for, - _remap_one_label_tensor, - to_device, - ), - _remap_group( - negative_labels_by_edge_type, - _sorted_for, - _remap_one_label_tensor, - to_device, - ), - ) - - -def _remap_one_label_tensor_edge_list( - label_tensor: torch.Tensor, - sorted_node: torch.Tensor, - sort_perm: torch.Tensor, - to_device: torch.device, -) -> AnchorLabels: - """Vectorized edge-list remap of one ``[N_anchors, M]`` padded label tensor. - - Thin wrapper over the shared :func:`_membership_remap`: wraps the returned - ``(anchor_index, local_index)`` pair stream directly into a dense - :class:`AnchorLabels`. This is strictly less work than the dict builder (no - ``torch.bincount``/``torch.split``, no per-anchor Python comprehension); - ``_membership_remap`` already moved both tensors to ``to_device`` as ``long``. - - Pairs are in ``(anchor, column)`` order; within an anchor, label order is - unspecified (see :class:`AnchorLabels`). - - Args: - label_tensor (torch.Tensor): ``[N_anchors, M]`` ``-1``-padded global - label ids. - sorted_node (torch.Tensor): ``torch.sort`` of the supervision node map. - sort_perm (torch.Tensor): Permutation from ``torch.sort`` mapping sorted - positions back to original local indices. - to_device (torch.device): Device for the output tensors. - - Returns: - :class:`AnchorLabels` with ``anchor_index``/``label_index`` in - ``(anchor, column)`` order and ``num_anchors == label_tensor.size(0)``. - """ - anchor_index, label_index, num_anchors = _membership_remap( - label_tensor, sorted_node, sort_perm, to_device - ) - return AnchorLabels( - anchor_index=anchor_index, - label_index=label_index, - num_anchors=num_anchors, - ) - - def edge_list_set_labels( node_local_to_global_by_type: dict[NodeType, torch.Tensor], positive_labels_by_edge_type: dict[EdgeType, torch.Tensor], negative_labels_by_edge_type: dict[EdgeType, torch.Tensor], - supervision_edge_types: list[EdgeType], to_device: torch.device, ) -> tuple[dict[EdgeType, AnchorLabels], dict[EdgeType, AnchorLabels]]: - """Dense edge-list label remap from global label ids to local node indices. - - Drop-in alternative to :func:`vectorized_set_labels` that emits, per label - edge type, an :class:`AnchorLabels` dense edge-list instead of a ragged - ``dict[int, torch.Tensor]``. Membership semantics are identical; expanding - each result via :meth:`AnchorLabels.to_dict` produces a set-equivalent dict - (same per-anchor label sets as :func:`_loop_set_labels`, in column-visit - order rather than ascending-local order -- permutation-invariant for the loss). - - Precondition (REQUIRED): each ``node`` local->global map must contain UNIQUE - global ids -- see :func:`vectorized_set_labels` for the rationale. + """Remap ABLP labels from global ids to local indices, as dense edge lists. + + This is the loader's single label-remap entry point. Sampling hands back + labels as global node ids in ``[N_anchors, M]`` padded blocks; training needs + them as local indices into the sampled subgraph. For each edge type this + resolves that mapping and returns an :class:`AnchorLabels` edge list. Callers + wanting the ragged ``dict[int, torch.Tensor]`` instead can expand each result + with :meth:`AnchorLabels.to_dict`; the two are the same labels in a different + container, and the loss treats them identically. + + Positive and negative labels are remapped the same way, against the same + sorted node maps, so the work is shared via a memoized sort per supervision + node type. A zero-anchor tensor is skipped rather than emitted as an empty + entry: there are simply no anchors to label for that edge type in this batch, + so no key is emitted for it. + + Correctness rests on each ``node`` local->global map having unique global ids; + the membership lookup relies on it (see :func:`_membership_remap`). GiGL node + maps satisfy this by construction. Args: node_local_to_global_by_type (dict[NodeType, torch.Tensor]): Per node @@ -530,16 +258,13 @@ def edge_list_set_labels( of global label ids. negative_labels_by_edge_type (dict[EdgeType, torch.Tensor]): As above, for negative-label edge types. May be empty. - supervision_edge_types (list[EdgeType]): Accepted for signature parity - with the other kernels; unused here. to_device (torch.device): Device for every output tensor. Returns: Tuple ``(y_positive, y_negative)``, each a - ``dict[message_passing_edge_type, AnchorLabels]`` with NO entry for a - zero-anchor label tensor (matching the loop's defaultdict). + ``dict[message_passing_edge_type, AnchorLabels]`` with no entry for an + edge type that had no anchors this batch. """ - del supervision_edge_types # Accepted for signature parity; not needed here. sorted_cache: dict[NodeType, tuple[torch.Tensor, torch.Tensor]] = {} def _sorted_for(node_type: NodeType) -> tuple[torch.Tensor, torch.Tensor]: @@ -549,20 +274,25 @@ def _sorted_for(node_type: NodeType) -> tuple[torch.Tensor, torch.Tensor]: ) return sorted_cache[node_type] - return ( - _remap_group( - positive_labels_by_edge_type, - _sorted_for, - _remap_one_label_tensor_edge_list, - to_device, - ), - _remap_group( - negative_labels_by_edge_type, - _sorted_for, - _remap_one_label_tensor_edge_list, - to_device, - ), - ) + def _remap( + labels_by_edge_type: dict[EdgeType, torch.Tensor], + ) -> dict[EdgeType, AnchorLabels]: + # Supervision edge types are (anchor_type, relation, supervision_type). + supervision_node_type_index = 2 + output: dict[EdgeType, AnchorLabels] = {} + for edge_type, label_tensor in labels_by_edge_type.items(): + # No anchors for this edge type this batch -> no key, not an empty one. + if label_tensor.size(0) == 0: + continue + sorted_node, sort_perm = _sorted_for(edge_type[supervision_node_type_index]) + output[label_edge_type_to_message_passing_edge_type(edge_type)] = ( + AnchorLabels( + *_membership_remap(label_tensor, sorted_node, sort_perm, to_device) + ) + ) + return output + + return _remap(positive_labels_by_edge_type), _remap(negative_labels_by_edge_type) class DistABLPLoader(BaseDistLoader): @@ -605,13 +335,15 @@ def __init__( """ Neighbor loader for Anchor Based Link Prediction (ABLP) tasks. - Note that for this class, the dataset must *always* be heterogeneous, - as we need separate edge types for positive and negative labels. + The dataset must *always* be heterogeneous here, since positive and + negative labels are carried as separate edge types. By default, the loader will return {py:class} `torch_geometric.data.HeteroData` (heterogeneous) objects, but will return a {py:class}`torch_geometric.data.Data` (homogeneous) object if the dataset is "labeled homogeneous". - The following fields may also be present: + The following fields may also be present (this describes the default + `use_list_output=False` shape; see `use_list_output` below for the + `AnchorLabels` edge-list alternative): - `y_positive`: `dict[int, torch.Tensor]` mapping from local anchor node id to a tensor of positive label node ids. - `y_negative`: (Optional) `dict[int, torch.Tensor]` mapping from local anchor node id to a tensor of negative @@ -649,19 +381,13 @@ def __init__( - `y_positive`: {(a, to, b): {0: torch.tensor([1])}, (a, to, c): {0: torch.tensor([2])}} - `y_negative`: {(a, to, b): {0: torch.tensor([3])}, (a, to, c): {0: torch.tensor([4])}} - Label remapping (the conversion of global label ids to the local - indices stored in `y_positive`/`y_negative`) is vectorized internally; - the output is identical to the historical per-anchor implementation. - - When `use_list_output=True`, `y_positive` and `y_negative` are instead a - dense edge-list `AnchorLabels` (single supervision edge type) or - `dict[EdgeType, AnchorLabels]` (multiple). An `AnchorLabels` holds - `anchor_index` ([E] long), `label_index` ([E] long), and `num_anchors` - (int): pair `k` means local anchor row `anchor_index[k]` has local label - node `label_index[k]`. Pairs are ordered ascending by anchor (ties by - ascending label column). `AnchorLabels.to_dict()` reproduces the ragged - `dict[int, torch.Tensor]` form above. With `use_list_output=False` - (default) the output is the ragged dict, fully backward-compatible. + With `use_list_output=True`, the labels arrive instead as an `AnchorLabels` + edge-list (or `dict[EdgeType, AnchorLabels]` for several supervision edge + types); see :class:`AnchorLabels` for its shape. The edge-list keeps the + ragged per-anchor labels as flat tensors the loss can index directly, with + no padding or per-anchor Python loop; within-anchor order is unspecified + but the ABLP loss is order-invariant, and `AnchorLabels.to_dict()` recovers + the dict form. Args: dataset (Union[DistDataset, RemoteDistDataset]): The dataset to sample from. @@ -738,16 +464,13 @@ def __init__( is used instead. See https://docs.pytorch.org/tutorials/intermediate/pinmem_nonblock.html for background on pinned memory and non-blocking transfers. - use_list_output (bool): If True, return labels as a dense - ``AnchorLabels`` edge-list (or ``dict[EdgeType, AnchorLabels]`` - for multiple supervision edge types) instead of the ragged - ``dict[anchor_local_index, torch.Tensor]``. The edge-list form - lets the loss read labels without a per-anchor Python loop - (``y.label_index`` and ``query_idx[y.anchor_index]`` instead of - ``torch.cat(list(y.values()))`` and a ``repeat_interleave`` over - per-anchor lengths). ``AnchorLabels.to_dict()`` recovers the - ragged form. Defaults to ``False`` (ragged dict; fully - backward-compatible). + use_list_output (bool): Return labels as an ``AnchorLabels`` edge-list + (or ``dict[EdgeType, AnchorLabels]`` for multiple supervision edge + types) instead of the ragged ``dict[anchor_local_index, + torch.Tensor]``; see :class:`AnchorLabels` for the shape. The + edge-list lets the loss read ``y.label_index`` and + ``query_idx[y.anchor_index]`` directly. Defaults to ``False`` (the + backward-compatible ragged dict). """ # Set self._shutdowned right away, that way if we throw here, and __del__ is called, @@ -1311,20 +1034,26 @@ def _set_labels( node_type_to_local_node_to_global_node[DEFAULT_HOMOGENEOUS_NODE_TYPE] = ( data.node ) - # Vectorized remap is the production path: bit-for-bit equivalent to the - # _loop_set_labels oracle, GPU-safe, and faster than the per-anchor loop - # (the gap grows with batch size). When use_list_output is set, emit the - # dense AnchorLabels edge-list instead of the ragged per-anchor dict. - label_remap = ( - edge_list_set_labels if self._use_list_output else vectorized_set_labels - ) - output_positive_labels, output_negative_labels = label_remap( + # The edge-list kernel is the single remap path; the ragged dict is just + # one view of it (AnchorLabels.to_dict), so when the caller wants the dict + # we expand here rather than maintaining a second kernel. Both forms feed + # an order-invariant contrastive loss, so the choice is purely about the + # consumer's preferred shape. + output_positive_labels, output_negative_labels = edge_list_set_labels( node_local_to_global_by_type=node_type_to_local_node_to_global_node, positive_labels_by_edge_type=positive_labels_by_label_edge_type, negative_labels_by_edge_type=negative_labels_by_label_edge_type, - supervision_edge_types=self._supervision_edge_types, to_device=self.to_device, ) + if not self._use_list_output: + output_positive_labels = { + et: anchor_labels.to_dict() + for et, anchor_labels in output_positive_labels.items() + } + output_negative_labels = { + et: anchor_labels.to_dict() + for et, anchor_labels in output_negative_labels.items() + } if not output_positive_labels: raise ValueError("No positive labels were found in the data!") elif len(output_positive_labels) == 1: diff --git a/tests/unit/distributed/dist_ablp_neighborloader_test.py b/tests/unit/distributed/dist_ablp_neighborloader_test.py index ce37c3144..e35803269 100644 --- a/tests/unit/distributed/dist_ablp_neighborloader_test.py +++ b/tests/unit/distributed/dist_ablp_neighborloader_test.py @@ -423,10 +423,10 @@ def _global_pair_set( pair list for set-equality comparison. Pairs are collected in anchor-ascending order and sorted within each anchor - by global label value, so the result is canonical regardless of the - within-anchor label order emitted by the kernel (column-visit order) vs the - loop oracle (ascending-local order). Multiplicity is preserved: duplicate - label columns produce duplicate entries in the list. + by global label value, so the result is canonical regardless of the order in + which the kernel emits an anchor's labels (which is unspecified by contract). + Multiplicity is preserved: duplicate label columns produce duplicate entries + in the list. Use ``collections.Counter`` or ``sorted(...)`` equality rather than positional ``==`` if pair-level multiplicity without anchor grouping matters. @@ -454,19 +454,17 @@ def _collect_homogeneous_labels( Local node indices differ run-to-run, so labels are translated back to global ids via ``datum.node``. Pairs are sorted within each anchor (see - ``_global_pair_set``) for canonical set-equality comparison in the parent: - the vectorized kernel emits column-visit order while the loop oracle emits - ascending-local order; both are correct since the ABLP loss is - permutation-invariant over the pair stream. + ``_global_pair_set``) for canonical set-equality comparison in the parent, + since the order in which an anchor's labels are emitted is unspecified -- the + ABLP loss is permutation-invariant over the pairs. When ``use_list_output`` is True the labels arrive as :class:`AnchorLabels`. - This branch ALSO asserts in-process that the exact tensors the example - training loss reads from the edge-list match the legacy dict read: the - edge-list ``label_index`` must equal the dict's ``torch.cat(values())`` and - ``query_node_idx[anchor_index]`` must equal the legacy - ``repeat_interleave`` over per-anchor lengths. Both paths draw from the same - :func:`_membership_remap` pair stream (column-visit order), so they are - identical -- a drift here is exactly the example-training bug we guard against. + This branch ALSO asserts in-process that reading labels straight off the + edge-list agrees with reading them via the dict view: ``label_index`` must + equal the dict's ``torch.cat(values())``, and ``query_node_idx[anchor_index]`` + must equal a ``repeat_interleave`` of the anchors over their per-anchor label + counts. Both views come from the same edge-list, so any drift between the two + read patterns is a real bug, not just a reordering. """ create_test_process_group() loader = DistABLPLoader( @@ -487,20 +485,20 @@ def _collect_homogeneous_labels( f"expected AnchorLabels, got {type(datum.y_positive)}" ) positive_dict = datum.y_positive.to_dict() - # Direct check that the example-training read matches the legacy read, - # within this single batch, EXACTLY (order included). + # The direct edge-list read must match the dict-view read within this + # single batch, EXACTLY (order included). query_node_idx = torch.arange(datum.batch_size) - legacy_positive_idx = torch.cat( + dict_view_label_idx = torch.cat( [positive_dict[a] for a in range(datum.batch_size)] ) - legacy_repeated_query = query_node_idx.repeat_interleave( + dict_view_repeated_query = query_node_idx.repeat_interleave( torch.tensor([len(positive_dict[a]) for a in range(datum.batch_size)]) ) torch.testing.assert_close( - datum.y_positive.label_index, legacy_positive_idx + datum.y_positive.label_index, dict_view_label_idx ) torch.testing.assert_close( - query_node_idx[datum.y_positive.anchor_index], legacy_repeated_query + query_node_idx[datum.y_positive.anchor_index], dict_view_repeated_query ) else: positive_dict = datum.y_positive @@ -711,11 +709,11 @@ def test_use_list_output_matches_dict_output( False), so the two loader runs emit batches in the same order and the sorted pair sets are directly comparable. - The child process additionally asserts the exact tensors the example-training - loss reads from the edge-list equal the legacy dict read (see + The child process additionally asserts that reading labels straight off the + edge-list equals reading them through the dict view (see ``_collect_homogeneous_labels``): ``label_index`` equals the dict's ``torch.cat(values())`` and ``query_node_idx[anchor_index]`` equals the - legacy ``repeat_interleave``. + matching ``repeat_interleave`` over per-anchor counts. """ edge_index = { DEFAULT_HOMOGENEOUS_EDGE_TYPE: torch.tensor( diff --git a/tests/unit/distributed/edge_list_set_labels_test.py b/tests/unit/distributed/edge_list_set_labels_test.py index 432accf12..f57831980 100644 --- a/tests/unit/distributed/edge_list_set_labels_test.py +++ b/tests/unit/distributed/edge_list_set_labels_test.py @@ -1,7 +1,18 @@ -"""Unit tests for the dense edge-list ABLP label container and kernel. +"""Unit tests for the ABLP label-remap kernel and its edge-list container. These exercise the pure-tensor label-remap logic directly (no GLT, no distributed runtime), so they run in-process without ``mp.spawn``. + +``edge_list_set_labels`` is the loader's single label-remap path; it turns padded +blocks of global label ids into per-edge-type :class:`AnchorLabels` edge lists. +We assert against constructed expected values rather than a second reference +implementation, and we check both the edge-list tensors and their +:meth:`AnchorLabels.to_dict` view, since the loader uses both forms. + +Within an anchor the kernel emits labels in column-visit order, which is left +unspecified by contract (the ABLP loss is order-invariant). We therefore pin +exact tensors only where the node map is sorted -- there column order coincides +with ascending-local order -- and otherwise compare per-anchor label *sets*. """ import unittest @@ -12,12 +23,12 @@ from gigl.distributed.dist_ablp_neighborloader import ( AnchorLabels, - _loop_set_labels, - _remap_one_label_tensor_edge_list, edge_list_set_labels, ) from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation from gigl.types.graph import ( + DEFAULT_HOMOGENEOUS_EDGE_TYPE, + DEFAULT_HOMOGENEOUS_NODE_TYPE, message_passing_to_negative_label, message_passing_to_positive_label, ) @@ -42,33 +53,27 @@ def _neg(edge_type: EdgeType) -> EdgeType: return message_passing_to_negative_label(edge_type) -def _dict_from_edge_list( - edge_list_by_type: dict[PyGEdgeType, AnchorLabels], -) -> dict[PyGEdgeType, dict[int, torch.Tensor]]: - return {et: labels.to_dict() for et, labels in edge_list_by_type.items()} - - -def _assert_label_dicts_set_equal( +def _assert_dict_sets_equal( actual: dict[PyGEdgeType, dict[int, torch.Tensor]], expected: dict[PyGEdgeType, dict[int, torch.Tensor]], ) -> None: """Assert per-anchor SET equality (with multiplicity) between two label dicts. - The edge-list kernel emits pairs in column-visit order; the loop oracle emits - them in ascending-local order. Both are valid: the ABLP contrastive loss is - permutation-invariant over the pair stream. This helper normalises order via - ``sorted()`` so it catches membership and multiplicity errors while remaining - invariant to within-anchor permutations. + Comparison is order-independent within an anchor (the kernel does not pin + within-anchor order) but still catches missing/extra labels and multiplicity + (duplicate label columns), and asserts each value tensor is ``long``. """ assert set(actual.keys()) == set(expected.keys()), ( f"{set(actual.keys())} != {set(expected.keys())}" ) for edge_type, inner in expected.items(): actual_inner = actual[edge_type] - assert set(actual_inner.keys()) == set(inner.keys()) + assert set(actual_inner.keys()) == set(inner.keys()), ( + f"{edge_type}: {set(actual_inner.keys())} != {set(inner.keys())}" + ) for anchor, expected_tensor in inner.items(): got = actual_inner[anchor] - assert got.dtype == torch.long + assert got.dtype == torch.long, f"{edge_type}[{anchor}] dtype {got.dtype}" assert sorted(got.tolist()) == sorted(expected_tensor.tolist()), ( f"{edge_type}[{anchor}]: got {got.tolist()}, " f"expected {expected_tensor.tolist()} (as sets)" @@ -76,8 +81,9 @@ def _assert_label_dicts_set_equal( class AnchorLabelsTest(TestCase): - def test_to_dict_round_trips_empty_and_multi(self) -> None: + def test_to_dict_expands_empty_and_multi_label_anchors(self) -> None: # 3 anchors: anchor 0 -> [5], anchor 1 -> [] (empty), anchor 2 -> [7, 8]. + # Every anchor must get a key, including the empty one in the middle. labels = AnchorLabels( anchor_index=torch.tensor([0, 2, 2], dtype=torch.long), label_index=torch.tensor([5, 7, 8], dtype=torch.long), @@ -89,59 +95,26 @@ def test_to_dict_round_trips_empty_and_multi(self) -> None: torch.testing.assert_close(as_dict[1], torch.empty(0, dtype=torch.long)) torch.testing.assert_close(as_dict[2], torch.tensor([7, 8], dtype=torch.long)) - -class RemapOneEdgeListTest(TestCase): - def test_column_order_with_padding_and_empty(self) -> None: - # When the node map is sorted, column order and ascending-local order - # coincide, so this verifies exact output. - # anchor 0: [15, -1] -> local 5 ; anchor 1: [15, 16] -> local 5, 6 - # (column order == ascending-local here because node is already sorted); - # anchor 2: [-1, -1] -> empty ; anchor 3: [99, -1] -> empty (99 absent). - node = torch.tensor([10, 11, 12, 13, 14, 15, 16, 17]) - sorted_node, sort_perm = torch.sort(node) - label_tensor = torch.tensor( - [[15, -1], [15, 16], [-1, -1], [99, -1]], dtype=torch.long - ) - result = _remap_one_label_tensor_edge_list( - label_tensor, sorted_node, sort_perm, _CPU - ) - self.assertEqual(result.num_anchors, 4) - torch.testing.assert_close( - result.anchor_index, torch.tensor([0, 1, 1], dtype=torch.long) - ) - torch.testing.assert_close( - result.label_index, torch.tensor([5, 5, 6], dtype=torch.long) - ) - - def test_unsorted_node_map_correct_membership(self) -> None: - # node map is UNSORTED, so torch.sort yields a non-identity sort_perm. - # node[i] = global id of local i: g15->local0, g10->local1, g16->local2, - # g11->local3. The label row is high-id-first ([16, 15]); the kernel emits - # pairs in column order (g16=local2 first, g15=local0 second) -> [2, 0]. - # The loop oracle would emit ascending-local order -> [0, 2]. Both are - # SET-equal to {0, 2}. A port that dropped the sort_perm gather would emit - # sorted-position indices (1 and 3) and fail the membership check. - node = torch.tensor([15, 10, 16, 11]) - sorted_node, sort_perm = torch.sort(node) - label_tensor = torch.tensor([[16, 15]], dtype=torch.long) - result = _remap_one_label_tensor_edge_list( - label_tensor, sorted_node, sort_perm, _CPU - ) - self.assertEqual(result.num_anchors, 1) - torch.testing.assert_close( - result.anchor_index, torch.tensor([0, 0], dtype=torch.long) - ) - # Column order: g16 (local2) first, g15 (local0) second. - torch.testing.assert_close( - result.label_index, torch.tensor([2, 0], dtype=torch.long) + def test_to_dict_all_empty(self) -> None: + # No pairs at all: every anchor still gets an empty tensor. + labels = AnchorLabels( + anchor_index=torch.empty(0, dtype=torch.long), + label_index=torch.empty(0, dtype=torch.long), + num_anchors=2, ) + as_dict = labels.to_dict() + self.assertEqual(set(as_dict.keys()), {0, 1}) + torch.testing.assert_close(as_dict[0], torch.empty(0, dtype=torch.long)) + torch.testing.assert_close(as_dict[1], torch.empty(0, dtype=torch.long)) -class EdgeListSetLabelsEquivalenceTest(TestCase): +class EdgeListSetLabelsTest(TestCase): @parameterized.expand( [ + # Sorted node map -> exact tensors are pinned. Covers present labels, + # a fully-padded (empty) anchor, and an absent global id (also empty). param( - "homogeneous_present_empty_and_padded", + "sorted_present_empty_and_padded", node_map={_STORY: torch.tensor([10, 11, 12, 13, 14, 15, 16, 17])}, positives={ _pos(_USER_TO_STORY): torch.tensor( @@ -149,10 +122,20 @@ class EdgeListSetLabelsEquivalenceTest(TestCase): ) }, negatives={}, - supervision_edge_types=[_USER_TO_STORY], + expected_positive={ + _USER_TO_STORY: { + 0: [5], + 1: [5, 6], + 2: [], + 3: [], + } + }, + expected_negative={}, ), + # Duplicate label columns must keep multiplicity: a node matching two + # identical columns appears twice in its anchor's labels. param( - "homogeneous_duplicate_labels", + "duplicate_label_columns_keep_multiplicity", node_map={_STORY: torch.tensor([10, 11, 12, 13, 14, 15])}, positives={ _pos(_USER_TO_STORY): torch.tensor( @@ -160,24 +143,32 @@ class EdgeListSetLabelsEquivalenceTest(TestCase): ) }, negatives={}, - supervision_edge_types=[_USER_TO_STORY], + expected_positive={_USER_TO_STORY: {0: [5, 5], 1: [1, 1]}}, + expected_negative={}, ), - # UNSORTED node map + reversed label columns -- non-identity sort_perm. - # The kernel emits column order (g16=local2 first, g15=local0 second); - # the loop oracle emits ascending-local order (local0, local2). - # The per-anchor SET is {0, 2} for both, guarding against a port that - # emits sorted-position indices (1, 3) instead of sort_perm-mapped locals. + # Real homogeneous keying: the loader keys the node map by + # DEFAULT_HOMOGENEOUS_NODE_TYPE and uses DEFAULT_HOMOGENEOUS_EDGE_TYPE. + # Exercise that exact keying, not just a custom NodeType. Node map is + # unsorted here, so we compare per-anchor sets. param( - "unsorted_node_map_reversed_columns", - node_map={_STORY: torch.tensor([15, 10, 16, 11])}, + "default_homogeneous_keying", + node_map={ + DEFAULT_HOMOGENEOUS_NODE_TYPE: torch.tensor([20, 10, 30, 11, 15]) + }, positives={ - _pos(_USER_TO_STORY): torch.tensor([[16, 15]], dtype=torch.long) + message_passing_to_positive_label( + DEFAULT_HOMOGENEOUS_EDGE_TYPE + ): torch.tensor([[30, 10], [-1, -1]], dtype=torch.long) }, negatives={}, - supervision_edge_types=[_USER_TO_STORY], + # g30->local2, g10->local1 ; second anchor fully padded. + expected_positive={DEFAULT_HOMOGENEOUS_EDGE_TYPE: {0: [2, 1], 1: []}}, + expected_negative={}, ), + # Negatives travel the same path and must remap independently of the + # positives. Sorted node map -> exact sets are unambiguous. param( - "homogeneous_with_negatives", + "with_negatives", node_map={_STORY: torch.tensor([10, 11, 12, 13, 14, 15, 16, 17])}, positives={ _pos(_USER_TO_STORY): torch.tensor([[15], [16]], dtype=torch.long) @@ -187,10 +178,13 @@ class EdgeListSetLabelsEquivalenceTest(TestCase): [[13, 16], [17, -1]], dtype=torch.long ) }, - supervision_edge_types=[_USER_TO_STORY], + expected_positive={_USER_TO_STORY: {0: [5], 1: [6]}}, + expected_negative={_USER_TO_STORY: {0: [3, 6], 1: [7]}}, ), + # Multiple supervision edge types: each gets its own key and remaps + # against its own supervision node type's map. param( - "heterogeneous_multi_edge_type", + "multi_edge_type", node_map={ _A: torch.tensor([10]), _B: torch.tensor([11, 12, 13, 14, 20, 21]), @@ -201,8 +195,39 @@ class EdgeListSetLabelsEquivalenceTest(TestCase): _pos(_A_TO_C): torch.tensor([[22, 23]], dtype=torch.long), }, negatives={}, - supervision_edge_types=[_A_TO_B, _A_TO_C], + expected_positive={ + _A_TO_B: {0: [2, 3]}, # g13->local2, g14->local3 in _B + _A_TO_C: {0: [2, 3]}, # g22->local2, g23->local3 in _C + }, + expected_negative={}, ), + # Multiple edge types WITH negatives on each. + param( + "multi_edge_type_with_negatives", + node_map={ + _A: torch.tensor([10]), + _B: torch.tensor([11, 12, 13, 14, 15, 16]), + _C: torch.tensor([20, 21, 22, 23, 24, 25]), + }, + positives={ + _pos(_A_TO_B): torch.tensor([[13, 14]], dtype=torch.long), + _pos(_A_TO_C): torch.tensor([[22, 23]], dtype=torch.long), + }, + negatives={ + _neg(_A_TO_B): torch.tensor([[15, 16]], dtype=torch.long), + _neg(_A_TO_C): torch.tensor([[24, 25]], dtype=torch.long), + }, + expected_positive={ + _A_TO_B: {0: [2, 3]}, + _A_TO_C: {0: [2, 3]}, + }, + expected_negative={ + _A_TO_B: {0: [4, 5]}, # g15->local4, g16->local5 in _B + _A_TO_C: {0: [4, 5]}, # g24->local4, g25->local5 in _C + }, + ), + # Every anchor empty (one fully padded, one with only absent ids): + # the edge type still appears, with an empty tensor per anchor. param( "all_anchors_empty", node_map={_STORY: torch.tensor([10, 11, 12])}, @@ -212,47 +237,131 @@ class EdgeListSetLabelsEquivalenceTest(TestCase): ) }, negatives={}, - supervision_edge_types=[_USER_TO_STORY], - ), - param( - "zero_anchors", - node_map={_STORY: torch.tensor([10, 11, 12])}, - positives={_pos(_USER_TO_STORY): torch.empty((0, 0), dtype=torch.long)}, - negatives={}, - supervision_edge_types=[_USER_TO_STORY], + expected_positive={_USER_TO_STORY: {0: [], 1: []}}, + expected_negative={}, ), ] ) - def test_edge_list_to_dict_matches_loop( + def test_to_dict_matches_constructed_expected( self, _, node_map: dict[NodeType, torch.Tensor], positives: dict[EdgeType, torch.Tensor], negatives: dict[EdgeType, torch.Tensor], - supervision_edge_types: list[EdgeType], + expected_positive: dict[PyGEdgeType, dict[int, list[int]]], + expected_negative: dict[PyGEdgeType, dict[int, list[int]]], ) -> None: - loop_pos, loop_neg = _loop_set_labels( + pos, neg = edge_list_set_labels( node_local_to_global_by_type=node_map, positive_labels_by_edge_type=positives, negative_labels_by_edge_type=negatives, - supervision_edge_types=supervision_edge_types, to_device=_CPU, ) - el_pos, el_neg = edge_list_set_labels( + for labels in pos.values(): + self.assertIsInstance(labels, AnchorLabels) + for labels in neg.values(): + self.assertIsInstance(labels, AnchorLabels) + expected_pos_tensors = { + et: {a: torch.tensor(v, dtype=torch.long) for a, v in inner.items()} + for et, inner in expected_positive.items() + } + expected_neg_tensors = { + et: {a: torch.tensor(v, dtype=torch.long) for a, v in inner.items()} + for et, inner in expected_negative.items() + } + _assert_dict_sets_equal( + {et: labels.to_dict() for et, labels in pos.items()}, + expected_pos_tensors, + ) + _assert_dict_sets_equal( + {et: labels.to_dict() for et, labels in neg.items()}, + expected_neg_tensors, + ) + + def test_sorted_node_map_pins_exact_edge_list_tensors(self) -> None: + # With a sorted node map column-visit order coincides with ascending-local + # order, so the kernel's AnchorLabels tensors are fully determined: pin + # them exactly. This locks the (anchor_index, label_index) co-indexing the + # loss relies on. Covers present labels, a multi-label anchor, and a + # fully-padded (empty) anchor. + node_map = {_STORY: torch.tensor([10, 11, 12, 13, 14, 15, 16, 17])} + positives = { + _pos(_USER_TO_STORY): torch.tensor( + [[15, -1], [15, 16], [-1, -1]], dtype=torch.long + ) + } + pos, _ = edge_list_set_labels( node_local_to_global_by_type=node_map, positive_labels_by_edge_type=positives, - negative_labels_by_edge_type=negatives, - supervision_edge_types=supervision_edge_types, + negative_labels_by_edge_type={}, + to_device=_CPU, + ) + labels = pos[_USER_TO_STORY] + self.assertEqual(labels.num_anchors, 3) + torch.testing.assert_close( + labels.anchor_index, torch.tensor([0, 1, 1], dtype=torch.long) + ) + torch.testing.assert_close( + labels.label_index, torch.tensor([5, 5, 6], dtype=torch.long) + ) + + def test_unsorted_node_map_correct_membership(self) -> None: + # The node map is UNSORTED, so torch.sort yields a non-identity sort_perm + # and the kernel must map sorted positions back through it to recover real + # local indices. We assert the per-anchor SET (within-anchor order is + # unspecified by contract): a regression that emitted sorted-position + # indices instead of sort_perm-mapped locals would produce {1, 3} here and + # fail, so this proves the local indices are real node indices. + # node = [15, 10, 16, 11] -> g15=local0, g10=local1, g16=local2, g11=local3. + # Labels [16, 15] -> SET {local0, local2} = {0, 2}. + node_map = {_STORY: torch.tensor([15, 10, 16, 11])} + positives = {_pos(_USER_TO_STORY): torch.tensor([[16, 15]], dtype=torch.long)} + pos, _ = edge_list_set_labels( + node_local_to_global_by_type=node_map, + positive_labels_by_edge_type=positives, + negative_labels_by_edge_type={}, + to_device=_CPU, + ) + self.assertEqual(sorted(pos[_USER_TO_STORY].to_dict()[0].tolist()), [0, 2]) + + def test_zero_anchor_tensor_yields_no_edge_type_key(self) -> None: + # A zero-anchor label tensor means there were no anchors for that edge + # type this batch, so it must NOT appear in the output at all (not as an + # empty entry). + node_map = {_STORY: torch.tensor([10, 11, 12])} + positives = {_pos(_USER_TO_STORY): torch.empty((0, 0), dtype=torch.long)} + pos, neg = edge_list_set_labels( + node_local_to_global_by_type=node_map, + positive_labels_by_edge_type=positives, + negative_labels_by_edge_type={}, + to_device=_CPU, + ) + self.assertEqual(pos, {}) + self.assertEqual(neg, {}) + + def test_device_placement_cpu(self) -> None: + # The output index tensors must land on the requested device. (CPU here; + # the CUDA counterpart lives in label_remap_cuda_device_test.py.) + node_map = {_STORY: torch.tensor([10, 11, 12, 13, 14, 15])} + positives = {_pos(_USER_TO_STORY): torch.tensor([[15], [11]], dtype=torch.long)} + pos, _ = edge_list_set_labels( + node_local_to_global_by_type=node_map, + positive_labels_by_edge_type=positives, + negative_labels_by_edge_type={}, to_device=_CPU, ) - _assert_label_dicts_set_equal(_dict_from_edge_list(el_pos), loop_pos) - _assert_label_dicts_set_equal(_dict_from_edge_list(el_neg), loop_neg) + labels = pos[_USER_TO_STORY] + self.assertEqual(labels.anchor_index.device.type, "cpu") + self.assertEqual(labels.label_index.device.type, "cpu") + self.assertEqual(labels.anchor_index.dtype, torch.long) + self.assertEqual(labels.label_index.dtype, torch.long) def test_duplicate_node_map_raises_assertion(self) -> None: - # NOTE: the uniqueness check is gated on `__debug__`, so this is a no-op - # under `python -O` / `PYTHONOPTIMIZE`. GiGL node maps are unique by - # construction; the guard catches misuse only. The test asserts it fires - # under the default (non-optimized) interpreter used by the test suite. + # The membership lookup requires unique global ids; a duplicate would + # silently drop a local index. The guard is gated on ``__debug__`` (a + # no-op under ``python -O``), and GiGL node maps are unique by + # construction, so this only catches misuse. Assert it fires under the + # default (non-optimized) interpreter the suite runs on. node_map = {_STORY: torch.tensor([10, 10, 11])} positives = {_pos(_USER_TO_STORY): torch.tensor([[10, 11]], dtype=torch.long)} with self.assertRaises(AssertionError): @@ -260,7 +369,6 @@ def test_duplicate_node_map_raises_assertion(self) -> None: node_local_to_global_by_type=node_map, positive_labels_by_edge_type=positives, negative_labels_by_edge_type={}, - supervision_edge_types=[_USER_TO_STORY], to_device=_CPU, ) diff --git a/tests/unit/distributed/label_remap_cuda_device_test.py b/tests/unit/distributed/label_remap_cuda_device_test.py index 5bdb22fa5..4a98a3d60 100644 --- a/tests/unit/distributed/label_remap_cuda_device_test.py +++ b/tests/unit/distributed/label_remap_cuda_device_test.py @@ -1,25 +1,22 @@ -"""CUDA device-placement regression test for the ABLP label-remap kernels. - -``vectorized_set_labels`` and ``edge_list_set_labels`` build an internal -``anchor_of_entry`` index and then select it with a mask derived from the input -``label_tensor``. If that index is created on CPU while ``label_tensor`` is on -GPU, the masked select raises ``"indices should be either on cpu or on the same -device as the indexed tensor"``. CPU-only unit tests cannot observe this, so the -bug only surfaces on a real GPU training run. - -These tests run the kernels with all inputs on CUDA and assert the result equals -the CPU result. They are skipped when no GPU is present (e.g. CPU CI); run them -on a CUDA host to guard the device placement. +"""CUDA device-placement regression test for the ABLP label-remap kernel. + +``edge_list_set_labels`` builds an internal ``anchor_of_entry`` index and then +selects it with a mask derived from the input ``label_tensor``. If that index is +created on CPU while ``label_tensor`` is on GPU, the masked select raises +``"indices should be either on cpu or on the same device as the indexed +tensor"``. CPU-only unit tests cannot observe this, so the bug only surfaces on a +real GPU training run. + +This test runs the kernel with all inputs on CUDA and asserts the result matches +the CPU result. It is skipped when no GPU is present (e.g. CPU CI); run it on a +CUDA host to guard the device placement. """ import unittest import torch -from gigl.distributed.dist_ablp_neighborloader import ( - edge_list_set_labels, - vectorized_set_labels, -) +from gigl.distributed.dist_ablp_neighborloader import edge_list_set_labels from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation from gigl.types.graph import message_passing_to_positive_label from tests.test_assets.test_case import TestCase @@ -50,64 +47,36 @@ def _inputs(device: torch.device): @unittest.skipUnless(torch.cuda.is_available(), "requires a CUDA device") class LabelRemapCudaDeviceTest(TestCase): - def test_vectorized_set_labels_cuda_matches_cpu(self) -> None: - cpu_node, cpu_pos = _inputs(torch.device("cpu")) - expected_pos, _ = vectorized_set_labels( - node_local_to_global_by_type=cpu_node, - positive_labels_by_edge_type=cpu_pos, - negative_labels_by_edge_type={}, - supervision_edge_types=[_USER_TO_STORY], - to_device=torch.device("cpu"), - ) - - cuda = torch.device("cuda") - cuda_node, cuda_pos = _inputs(cuda) - # Must not raise the CPU/GPU index mismatch. - got_pos, _ = vectorized_set_labels( - node_local_to_global_by_type=cuda_node, - positive_labels_by_edge_type=cuda_pos, - negative_labels_by_edge_type={}, - supervision_edge_types=[_USER_TO_STORY], - to_device=cuda, - ) - - self.assertEqual(set(got_pos.keys()), set(expected_pos.keys())) - for edge_type, inner in expected_pos.items(): - got_inner = got_pos[edge_type] - self.assertEqual(set(got_inner.keys()), set(inner.keys())) - for anchor, expected_tensor in inner.items(): - got = got_inner[anchor] - self.assertEqual(got.device.type, "cuda") - torch.testing.assert_close(got.cpu(), expected_tensor) - def test_edge_list_set_labels_cuda_matches_cpu(self) -> None: cpu_node, cpu_pos = _inputs(torch.device("cpu")) expected_pos, _ = edge_list_set_labels( node_local_to_global_by_type=cpu_node, positive_labels_by_edge_type=cpu_pos, negative_labels_by_edge_type={}, - supervision_edge_types=[_USER_TO_STORY], to_device=torch.device("cpu"), ) cuda = torch.device("cuda") cuda_node, cuda_pos = _inputs(cuda) + # Must not raise the CPU/GPU index mismatch. got_pos, _ = edge_list_set_labels( node_local_to_global_by_type=cuda_node, positive_labels_by_edge_type=cuda_pos, negative_labels_by_edge_type={}, - supervision_edge_types=[_USER_TO_STORY], to_device=cuda, ) self.assertEqual(set(got_pos.keys()), set(expected_pos.keys())) for edge_type, expected_labels in expected_pos.items(): got_labels = got_pos[edge_type] + # The output tensors must land on the requested CUDA device. self.assertEqual(got_labels.anchor_index.device.type, "cuda") + self.assertEqual(got_labels.label_index.device.type, "cuda") expected_dict = expected_labels.to_dict() got_dict = got_labels.to_dict() self.assertEqual(set(got_dict.keys()), set(expected_dict.keys())) for anchor, expected_tensor in expected_dict.items(): + # Multiplicity (the duplicate [15, 15] column) must survive too. torch.testing.assert_close(got_dict[anchor].cpu(), expected_tensor) diff --git a/tests/unit/distributed/vectorized_set_labels_test.py b/tests/unit/distributed/vectorized_set_labels_test.py deleted file mode 100644 index e4b894d9d..000000000 --- a/tests/unit/distributed/vectorized_set_labels_test.py +++ /dev/null @@ -1,317 +0,0 @@ -"""Unit tests for the ABLP label-remap loop oracle and vectorized kernel. - -These exercise the pure-tensor label-remap logic directly (no GLT, no -distributed runtime), so they run in-process without ``mp.spawn``. -""" - -import unittest - -import torch -from parameterized import param, parameterized - -from gigl.distributed.dist_ablp_neighborloader import ( - _loop_set_labels, - vectorized_set_labels, -) -from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation -from gigl.types.graph import ( - DEFAULT_HOMOGENEOUS_EDGE_TYPE, - DEFAULT_HOMOGENEOUS_NODE_TYPE, - message_passing_to_negative_label, - message_passing_to_positive_label, -) -from tests.test_assets.test_case import TestCase - -_CPU = torch.device("cpu") -_USER = NodeType("user") -_STORY = NodeType("story") -_USER_TO_STORY = EdgeType(_USER, Relation("to"), _STORY) - -_A = NodeType("a") -_B = NodeType("b") -_C = NodeType("c") -_A_TO_B = EdgeType(_A, Relation("to"), _B) -_A_TO_C = EdgeType(_A, Relation("to"), _C) - - -def _pos(edge_type: EdgeType) -> EdgeType: - return message_passing_to_positive_label(edge_type) - - -def _neg(edge_type: EdgeType) -> EdgeType: - return message_passing_to_negative_label(edge_type) - - -def _assert_label_dicts_set_equal( - actual: dict[EdgeType, dict[int, torch.Tensor]], - expected: dict[EdgeType, dict[int, torch.Tensor]], -) -> None: - """Assert per-anchor SET equality (with multiplicity) between two label dicts. - - The vectorized kernel emits pairs in column-visit order; the loop oracle emits - them in ascending-local order. These differ when a row's local indices are not - monotone in column order (e.g. an unsorted node map with reversed label columns). - Both are valid: the ABLP contrastive loss is permutation-invariant over the pair - stream (``CrossEntropyLoss(reduction="sum")`` with value-based collision masks), - so within-anchor label order carries no meaning. - - This helper uses ``sorted(...)`` to normalise order before comparing, so it - catches membership errors and multiplicity errors (duplicate labels) while - remaining invariant to within-anchor permutations. - """ - assert set(actual.keys()) == set(expected.keys()), ( - f"{set(actual.keys())} != {set(expected.keys())}" - ) - for edge_type, inner in expected.items(): - actual_inner = actual[edge_type] - assert set(actual_inner.keys()) == set(inner.keys()), ( - f"{edge_type}: {set(actual_inner.keys())} != {set(inner.keys())}" - ) - for anchor, expected_tensor in inner.items(): - got = actual_inner[anchor] - assert got.dtype == torch.long, f"{edge_type}[{anchor}] dtype {got.dtype}" - # Sort both tensors so the comparison is order-independent but still - # catches missing/extra entries and multiplicity (duplicate labels). - assert sorted(got.tolist()) == sorted(expected_tensor.tolist()), ( - f"{edge_type}[{anchor}]: got {got.tolist()}, " - f"expected {expected_tensor.tolist()} (as sets)" - ) - - -class LoopSetLabelsContractTest(TestCase): - def test_homogeneous_with_empty_and_padded_anchors(self) -> None: - # node holds global ids; index = local id. Supervision node type is - # _STORY (edge_type[2] of the positive-label edge type). - node = torch.tensor([10, 11, 12, 13, 14, 15, 16, 17]) - node_map = {_STORY: node} - # anchor 0 -> global 15 (local 5); anchor 1 -> {15,16} (local 5,6); - # anchor 2 -> fully padded (empty); anchor 3 -> global 99 (absent -> empty). - positives = { - _pos(_USER_TO_STORY): torch.tensor( - [[15, -1], [15, 16], [-1, -1], [99, -1]], dtype=torch.long - ) - } - y_pos, y_neg = _loop_set_labels( - node_local_to_global_by_type=node_map, - positive_labels_by_edge_type=positives, - negative_labels_by_edge_type={}, - supervision_edge_types=[_USER_TO_STORY], - to_device=_CPU, - ) - expected = { - _USER_TO_STORY: { - 0: torch.tensor([5]), - 1: torch.tensor([5, 6]), - 2: torch.tensor([], dtype=torch.long), - 3: torch.tensor([], dtype=torch.long), - } - } - _assert_label_dicts_set_equal(y_pos, expected) - self.assertEqual(y_neg, {}) - - def test_duplicate_label_columns_preserve_multiplicity(self) -> None: - # torch.nonzero over [N, M] yields one row index per matching column, - # so a node matching two identical label columns appears twice. - node = torch.tensor([10, 11, 12, 13, 14, 15]) - node_map = {_STORY: node} - positives = {_pos(_USER_TO_STORY): torch.tensor([[15, 15]], dtype=torch.long)} - y_pos, _ = _loop_set_labels( - node_local_to_global_by_type=node_map, - positive_labels_by_edge_type=positives, - negative_labels_by_edge_type={}, - supervision_edge_types=[_USER_TO_STORY], - to_device=_CPU, - ) - torch.testing.assert_close(y_pos[_USER_TO_STORY][0], torch.tensor([5, 5])) - - -class VectorizedSetLabelsEquivalenceTest(TestCase): - @parameterized.expand( - [ - param( - "homogeneous_present_empty_and_padded", - node_map={_STORY: torch.tensor([10, 11, 12, 13, 14, 15, 16, 17])}, - positives={ - _pos(_USER_TO_STORY): torch.tensor( - [[15, -1], [15, 16], [-1, -1], [99, -1]], dtype=torch.long - ) - }, - negatives={}, - supervision_edge_types=[_USER_TO_STORY], - ), - param( - "homogeneous_duplicate_labels", - node_map={_STORY: torch.tensor([10, 11, 12, 13, 14, 15])}, - positives={ - _pos(_USER_TO_STORY): torch.tensor( - [[15, 15], [11, 11]], dtype=torch.long - ) - }, - negatives={}, - supervision_edge_types=[_USER_TO_STORY], - ), - # UNSORTED node map + reversed label columns. With a sorted map - # `sort_perm` is the identity and a broken port that emits sorted (or - # global) order would still pass; here `torch.sort` permutes nontrivially - # (15->local0, 10->local1, 16->local2, 11->local3), and the label row is - # given high-id-first ([16, 15]). The kernel emits pairs in column-visit - # order (g16=local2 first, then g15=local0), while the loop oracle emits - # ascending-local order (local0, local2). The SET is {0, 2} for both -- - # a kernel that forgot to map through `sort_perm` would emit the wrong - # locals (e.g. sorted positions 1 and 3) and fail the set check. - param( - "unsorted_node_map_reversed_columns", - node_map={_STORY: torch.tensor([15, 10, 16, 11])}, - positives={ - _pos(_USER_TO_STORY): torch.tensor([[16, 15]], dtype=torch.long) - }, - negatives={}, - supervision_edge_types=[_USER_TO_STORY], - ), - # Real homogeneous keying: the loader keys the node map by - # DEFAULT_HOMOGENEOUS_NODE_TYPE and uses DEFAULT_HOMOGENEOUS_EDGE_TYPE - # for the supervision edge type (see _set_labels). Exercise that exact - # keying at the kernel level, not just a custom NodeType. - param( - "default_homogeneous_keying", - node_map={ - DEFAULT_HOMOGENEOUS_NODE_TYPE: torch.tensor([20, 10, 30, 11, 15]) - }, - positives={ - message_passing_to_positive_label( - DEFAULT_HOMOGENEOUS_EDGE_TYPE - ): torch.tensor([[30, 10], [-1, -1]], dtype=torch.long) - }, - negatives={}, - supervision_edge_types=[DEFAULT_HOMOGENEOUS_EDGE_TYPE], - ), - param( - "homogeneous_with_negatives", - node_map={_STORY: torch.tensor([10, 11, 12, 13, 14, 15, 16, 17])}, - positives={ - _pos(_USER_TO_STORY): torch.tensor([[15], [16]], dtype=torch.long) - }, - negatives={ - _neg(_USER_TO_STORY): torch.tensor( - [[13, 16], [17, -1]], dtype=torch.long - ) - }, - supervision_edge_types=[_USER_TO_STORY], - ), - param( - "heterogeneous_multi_edge_type", - node_map={ - _A: torch.tensor([10]), - _B: torch.tensor([11, 12, 13, 14, 20, 21]), - _C: torch.tensor([20, 21, 22, 23]), - }, - positives={ - _pos(_A_TO_B): torch.tensor([[13, 14]], dtype=torch.long), - _pos(_A_TO_C): torch.tensor([[22, 23]], dtype=torch.long), - }, - negatives={}, - supervision_edge_types=[_A_TO_B, _A_TO_C], - ), - param( - "heterogeneous_multi_edge_type_with_negatives", - node_map={ - _A: torch.tensor([10]), - _B: torch.tensor([11, 12, 13, 14, 15, 16]), - _C: torch.tensor([20, 21, 22, 23, 24, 25]), - }, - positives={ - _pos(_A_TO_B): torch.tensor([[13, 14]], dtype=torch.long), - _pos(_A_TO_C): torch.tensor([[22, 23]], dtype=torch.long), - }, - negatives={ - _neg(_A_TO_B): torch.tensor([[15, 16]], dtype=torch.long), - _neg(_A_TO_C): torch.tensor([[24, 25]], dtype=torch.long), - }, - supervision_edge_types=[_A_TO_B, _A_TO_C], - ), - param( - "all_anchors_empty", - node_map={_STORY: torch.tensor([10, 11, 12])}, - positives={ - _pos(_USER_TO_STORY): torch.tensor( - [[-1, -1], [99, 98]], dtype=torch.long - ) - }, - negatives={}, - supervision_edge_types=[_USER_TO_STORY], - ), - param( - "zero_anchors", - node_map={_STORY: torch.tensor([10, 11, 12])}, - positives={_pos(_USER_TO_STORY): torch.empty((0, 0), dtype=torch.long)}, - negatives={}, - supervision_edge_types=[_USER_TO_STORY], - ), - ] - ) - def test_matches_loop( - self, - _, - node_map: dict[NodeType, torch.Tensor], - positives: dict[EdgeType, torch.Tensor], - negatives: dict[EdgeType, torch.Tensor], - supervision_edge_types: list[EdgeType], - ) -> None: - loop_pos, loop_neg = _loop_set_labels( - node_local_to_global_by_type=node_map, - positive_labels_by_edge_type=positives, - negative_labels_by_edge_type=negatives, - supervision_edge_types=supervision_edge_types, - to_device=_CPU, - ) - vec_pos, vec_neg = vectorized_set_labels( - node_local_to_global_by_type=node_map, - positive_labels_by_edge_type=positives, - negative_labels_by_edge_type=negatives, - supervision_edge_types=supervision_edge_types, - to_device=_CPU, - ) - _assert_label_dicts_set_equal(vec_pos, loop_pos) - _assert_label_dicts_set_equal(vec_neg, loop_neg) - - def test_unsorted_node_map_correct_membership(self) -> None: - # Belt-and-suspenders on top of the parameterized case: assert the EXACT - # local-index SET for the unsorted-node-map case so a regression to - # wrong locals (e.g. sorted-position indices instead of sort_perm-mapped - # locals) is unmistakable, even if within-anchor order is not pinned. - # Layout: node = [15, 10, 16, 11] -> g15=local0, g10=local1, g16=local2, - # g11=local3. Labels: [16, 15] -> expected SET {local0, local2} = {0, 2}. - node_map = {_STORY: torch.tensor([15, 10, 16, 11])} - positives = {_pos(_USER_TO_STORY): torch.tensor([[16, 15]], dtype=torch.long)} - vec_pos, _ = vectorized_set_labels( - node_local_to_global_by_type=node_map, - positive_labels_by_edge_type=positives, - negative_labels_by_edge_type={}, - supervision_edge_types=[_USER_TO_STORY], - to_device=_CPU, - ) - # The kernel emits column order (g16 first, g15 second) -> [2, 0]. - # The loop oracle emits ascending-local order -> [0, 2]. - # Both are SET-equal to {0, 2}; only membership and co-indexing matter. - assert sorted(vec_pos[_USER_TO_STORY][0].tolist()) == [0, 2] - - def test_duplicate_node_map_raises_assertion(self) -> None: - # NOTE: the uniqueness check is gated on `__debug__`, so this assertion is - # a no-op under `python -O` / `PYTHONOPTIMIZE`. GiGL node maps are unique by - # construction (each local index is a distinct subgraph node), so the guard - # exists only to catch misuse; the test asserts the guard fires under the - # default (non-optimized) interpreter used by the test suite. - node_map = {_STORY: torch.tensor([10, 10, 11])} - positives = {_pos(_USER_TO_STORY): torch.tensor([[10, 11]], dtype=torch.long)} - with self.assertRaises(AssertionError): - vectorized_set_labels( - node_local_to_global_by_type=node_map, - positive_labels_by_edge_type=positives, - negative_labels_by_edge_type={}, - supervision_edge_types=[_USER_TO_STORY], - to_device=_CPU, - ) - - -if __name__ == "__main__": - unittest.main() From 5a58a53e52959110fe5a157d492cd9daac7d529a Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Fri, 26 Jun 2026 21:47:29 +0000 Subject: [PATCH 16/18] docs(ablp): why-first docstrings, worked examples, tensor-dim annotations Comments and docstrings only -- no behavior change (verified: the AST with docstrings stripped is identical to the prior commit for every touched file). - Docstrings lead with why over what; the membership-remap algorithm is shown as a small worked example (node map + [N_anchors, M] padded label tensor -> pairs) with labeled steps instead of dense prose. - Inline comments reference those docstring steps rather than floating free. - Tensor dimensions annotated throughout (N_anchors / M / N_nodes / K / E), including the K (non-padding candidates) vs E (matched pairs) distinction. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../graph_store/heterogeneous_training.py | 9 +- .../graph_store/homogeneous_training.py | 9 +- .../link_prediction/heterogeneous_training.py | 9 +- .../link_prediction/homogeneous_training.py | 20 +- gigl/distributed/dist_ablp_neighborloader.py | 201 ++++++++++++------ 5 files changed, 157 insertions(+), 91 deletions(-) diff --git a/examples/link_prediction/graph_store/heterogeneous_training.py b/examples/link_prediction/graph_store/heterogeneous_training.py index 913de8053..9b32807ca 100644 --- a/examples/link_prediction/graph_store/heterogeneous_training.py +++ b/examples/link_prediction/graph_store/heterogeneous_training.py @@ -301,11 +301,12 @@ def _compute_loss( ).to(device) random_negative_batch_size = random_negative_data[labeled_node_type].batch_size - # main_data.y_positive is an AnchorLabels edge-list; read label_index and - # query_node_idx[anchor_index] directly (see the AnchorLabels class docstring). - positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) + # main_data.y_positive is an AnchorLabels edge-list; read the co-indexed [E] + # label_index and query_node_idx[anchor_index] directly. See + # homogeneous_training._compute_loss for why this equals the historical dict read. + positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) # [E] repeated_query_node_idx = query_node_idx[ - main_data.y_positive.anchor_index.to(device) + main_data.y_positive.anchor_index.to(device) # [E] ] if hasattr(main_data, "y_negative"): hard_negative_idx: torch.Tensor = main_data.y_negative.label_index.to(device) diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py index 706f425c6..8d4de7cc1 100644 --- a/examples/link_prediction/graph_store/homogeneous_training.py +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -307,11 +307,12 @@ def _compute_loss( query_node_idx: torch.Tensor = torch.arange(main_data.batch_size).to(device) random_negative_batch_size = random_negative_data.batch_size - # main_data.y_positive is an AnchorLabels edge-list; read label_index and - # query_node_idx[anchor_index] directly (see the AnchorLabels class docstring). - positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) + # main_data.y_positive is an AnchorLabels edge-list; read the co-indexed [E] + # label_index and query_node_idx[anchor_index] directly. See + # homogeneous_training._compute_loss for why this equals the historical dict read. + positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) # [E] repeated_query_node_idx = query_node_idx[ - main_data.y_positive.anchor_index.to(device) + main_data.y_positive.anchor_index.to(device) # [E] ] if hasattr(main_data, "y_negative"): hard_negative_idx: torch.Tensor = main_data.y_negative.label_index.to(device) diff --git a/examples/link_prediction/heterogeneous_training.py b/examples/link_prediction/heterogeneous_training.py index 7a1522169..a60f72506 100644 --- a/examples/link_prediction/heterogeneous_training.py +++ b/examples/link_prediction/heterogeneous_training.py @@ -225,11 +225,12 @@ def _compute_loss( ).to(device) random_negative_batch_size = random_negative_data[labeled_node_type].batch_size - # main_data.y_positive is an AnchorLabels edge-list; read label_index and - # query_node_idx[anchor_index] directly (see homogeneous_training._compute_loss). - positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) + # main_data.y_positive is an AnchorLabels edge-list; read the co-indexed [E] + # label_index and query_node_idx[anchor_index] directly. See + # homogeneous_training._compute_loss for why this equals the historical dict read. + positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) # [E] repeated_query_node_idx = query_node_idx[ - main_data.y_positive.anchor_index.to(device) + main_data.y_positive.anchor_index.to(device) # [E] ] if hasattr(main_data, "y_negative"): hard_negative_idx: torch.Tensor = main_data.y_negative.label_index.to(device) diff --git a/examples/link_prediction/homogeneous_training.py b/examples/link_prediction/homogeneous_training.py index f8de44910..0b7ad0e4d 100644 --- a/examples/link_prediction/homogeneous_training.py +++ b/examples/link_prediction/homogeneous_training.py @@ -192,16 +192,18 @@ def _compute_loss( query_node_idx: torch.Tensor = torch.arange(main_data.batch_size).to(device) random_negative_batch_size = random_negative_data.batch_size - # main_data.y_positive is an AnchorLabels edge-list (use_list_output=True). - # label_index holds the local label node per (anchor, label) pair; anchor_index - # holds the matching local anchor row. Pairs are grouped by anchor, so reading - # label_index/anchor_index directly yields the same (query, label) pairs as the - # historical dict read (torch.cat(list(values())) + repeat_interleave over - # per-anchor lengths). The within-anchor label order may differ, but the - # contrastive loss is order-invariant, so the result is equivalent. - positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) + # main_data.y_positive is an AnchorLabels edge-list (use_list_output=True), two + # co-indexed [E] tensors: label_index holds the local label node per + # (anchor, label) pair, anchor_index the matching local anchor row. Pairs are + # grouped by anchor, so reading label_index/anchor_index directly yields the same + # (query, label) pairs as the historical dict read (torch.cat(list(values())) + + # repeat_interleave over per-anchor lengths). The within-anchor label order may + # differ, but the contrastive loss is order-invariant, so the result is + # equivalent. (This is the canonical explanation; the other link-prediction + # examples point here rather than restating it.) + positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) # [E] repeated_query_node_idx = query_node_idx[ - main_data.y_positive.anchor_index.to(device) + main_data.y_positive.anchor_index.to(device) # [E] ] if hasattr(main_data, "y_negative"): hard_negative_idx: torch.Tensor = main_data.y_negative.label_index.to(device) diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 33945f662..c3d1c7d54 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -62,10 +62,9 @@ class AnchorLabels: Each anchor can carry a different number of labels, so the natural shape is ragged. Rather than a ``dict[int, torch.Tensor]`` -- which needs padding to - batch and a Python loop to read -- we keep two parallel ``long`` tensors: - pair ``k`` says local anchor ``anchor_index[k]`` is labeled by local node - ``label_index[k]``. Downstream the loss can index straight into these tensors - with no per-anchor iteration. + batch and a Python loop to read -- we keep two co-indexed ``long`` tensors so + the loss can index straight into them with no per-anchor iteration. The data + example below makes the layout concrete. Order within an anchor is deliberately left unspecified. The ABLP contrastive loss (:class:`gigl.nn.loss.RetrievalLoss`) scores every (anchor, label) pair @@ -79,6 +78,24 @@ class AnchorLabels: separately so :meth:`to_dict` can still emit a key for every anchor, even the ones that matched nothing. + Dimension vocabulary (used throughout the label-remap code): + + - ``N_anchors`` -- anchor rows (rows of the source padded label tensor). + - ``M`` -- padded label columns per anchor. + - ``N_nodes`` -- nodes in the supervision local->global map. + - ``K`` -- non-padding candidate labels (after dropping the ``-1`` pad, before + membership filtering; still includes globals absent from the subgraph). + - ``E`` -- surviving ``(anchor, label)`` pairs after membership filtering + (``E <= K``); the length of ``anchor_index`` / ``label_index``. + + The ragged dict and this edge list hold the same labels in different + containers (three anchors, two of which carry labels):: + + dict form {0: [3], 1: [5, 7], 2: []} + edge-list form anchor_index = [0, 1, 1] # [E] = 3 + label_index = [3, 5, 7] # [E] = 3 + num_anchors = 3 # anchor 2 contributes no pair + Example:: >>> import torch @@ -93,8 +110,8 @@ class AnchorLabels: Args: anchor_index (torch.Tensor): ``[E]`` long tensor of local anchor rows. label_index (torch.Tensor): ``[E]`` long tensor of local label node ids. - num_anchors (int): Total number of anchors ``N`` (rows of the source - padded label tensor), including anchors with no labels. + num_anchors (int): Total number of anchors ``N_anchors`` (rows of the + source padded label tensor), including anchors with no labels. """ anchor_index: torch.Tensor @@ -124,24 +141,41 @@ def _membership_remap( ) -> tuple[torch.Tensor, torch.Tensor, int]: """Resolve one padded label tensor to a flat ``(anchor, label)`` pair stream. - This is where the actual global-id-to-local-index lookup happens; everything - above it just packages the result. Given one ``[N_anchors, M]`` block of - ``-1``-padded global label ids, it returns the matched pairs as two parallel - index tensors plus the anchor count, the raw form :class:`AnchorLabels` wraps. - - The lookup is a sorted-membership join rather than a per-anchor scan: sort the - node map once, ``searchsorted`` every label id into it, and keep only exact - hits. The sort permutation then carries each hit back to its original local - index. This trades an ``O(N_anchors * M * N_nodes)`` broadcast-compare for a - single ``O(E log N_nodes)`` search, which is what lets the loader remap labels - without a Python loop over anchors. - - Pairs come out grouped by anchor for free: ``anchor_of_entry`` is built - row-major (``arange(N).repeat_interleave(M)``) and every mask preserves that - order, so the stream is non-decreasing in ``anchor_index`` without any sort. - Callers can group by anchor with a plain ``bincount``/``split``. Order *within* - an anchor is left as it falls out of the columns -- unspecified by contract, - since the loss does not care (see :class:`AnchorLabels`). + This is the actual global-id-to-local-index lookup; :class:`AnchorLabels` and + :func:`edge_list_set_labels` just package what it returns. (See + :class:`AnchorLabels` for the ``N_anchors`` / ``M`` / ``N_nodes`` / ``K`` / ``E`` + dimension vocabulary used below.) + + The lookup is a sorted-membership join rather than a per-anchor scan, which is + what lets the loader remap labels without a Python loop over anchors: it trades + an ``O(N_anchors * M * N_nodes)`` broadcast-compare for a single + ``O(K log N_nodes)`` search. + + Worked example (generic ids):: + + node map (local -> global): [40, 10, 30] # N_nodes = 3 + sorted_node = [10, 30, 40], sort_perm = [1, 2, 0] + label_tensor ([N_anchors=2, M=2], -1 = pad): + [[30, -1], + [40, 10]] + + step 0 flatten row-major, tag each entry with its anchor row, drop pad: + flat = [30, 40, 10] # [K] = 3 candidates + anchor_of_* = [ 0, 1, 1] # [K] + step 1 searchsorted(sorted_node, flat) -> [1, 2, 0] # [K] positions + step 2 keep exact members (sorted_node[pos] == flat): all 3 -> [E] + step 3 sort_perm[pos] -> local index: [2, 0, 1] # [E] + result anchor_index = [0, 1, 1], label_index = [2, 0, 1] # [E] = 3 + + Check by hand: g30 is local 2, g40 is local 0, g10 is local 1 -- matches. Here + every candidate is a member so ``K == E``; the two differ once a global id is + absent from the node map (step 2 drops it). The code names the step-3 output + ``local_index``; :class:`AnchorLabels` stores that same tensor as ``label_index``. + + Because ``anchor_of_*`` is built row-major (step 0) and every mask preserves + order, the result is already grouped by anchor (non-decreasing ``anchor_index``) + with no argsort; order *within* an anchor is unspecified by contract, since the + loss does not care (see :class:`AnchorLabels`). The lookup is only correct if ``sorted_node`` has unique values: :func:`torch.searchsorted` returns the left-most equal position, so a repeated @@ -153,16 +187,17 @@ def _membership_remap( Args: label_tensor (torch.Tensor): ``[N_anchors, M]`` ``-1``-padded global label ids. - sorted_node (torch.Tensor): ``torch.sort`` of the supervision node map. - sort_perm (torch.Tensor): Permutation from ``torch.sort`` mapping sorted - positions back to original local indices. + sorted_node (torch.Tensor): ``[N_nodes]`` sorted values of the supervision + node map (the ``values`` half of ``torch.sort``). + sort_perm (torch.Tensor): ``[N_nodes]`` permutation from ``torch.sort`` + mapping sorted positions back to original local indices. to_device (torch.device): Device for the returned index tensors. Returns: Tuple ``(anchor_index, local_index, num_anchors)``. ``anchor_index`` and - ``local_index`` are equal-length 1-D ``long`` tensors on ``to_device``, - grouped by anchor (empty when nothing matched); - ``num_anchors == label_tensor.size(0)``. + ``local_index`` are co-indexed ``[E]`` tensors grouped by anchor (empty when + nothing matched); ``local_index`` becomes :attr:`AnchorLabels.label_index`. + ``num_anchors == N_anchors == label_tensor.size(0)``. """ num_anchors = int(label_tensor.size(0)) num_nodes = int(sorted_node.size(0)) @@ -171,51 +206,52 @@ def _membership_remap( return empty, empty, num_anchors num_labels = int(label_tensor.size(1)) - flat = label_tensor.reshape(-1) + flat = label_tensor.reshape(-1) # [N_anchors * M] before the pad mask + # step 0 (see docstring example): tag each flattened entry with its anchor row. # Build on the label tensor's device: `anchor_of_entry` is indexed below by # `is_present` (derived from `label_tensor`). On GPU, a CPU arange would # raise "indices should be either on cpu or on the same device as the indexed # tensor". CPU-only unit tests cannot catch this; see the CUDA-gated test. anchor_of_entry = torch.arange( num_anchors, device=label_tensor.device - ).repeat_interleave(num_labels) + ).repeat_interleave(num_labels) # [N_anchors * M] before the pad mask - # Mask the padding sentinel BEFORE any search so we never gather with -1. + # step 0 cont.: drop the -1 pad before any search so we never gather with a + # sentinel. This is the [N_anchors * M] -> [K] (candidate) reduction. is_present = flat != PADDING_NODE - flat = flat[is_present] - anchor_of_entry = anchor_of_entry[is_present] + flat = flat[is_present] # [K] + anchor_of_entry = anchor_of_entry[is_present] # [K] if num_nodes == 0 or flat.numel() == 0: return empty, empty, num_anchors if __debug__: - # `sorted_node` is already sorted, so uniqueness is equivalent to being - # strictly increasing -- a cheap adjacent-difference check, no re-sort. + # Precondition for step 1 (see docstring): `sorted_node` is already sorted, + # so uniqueness is equivalent to being strictly increasing -- a cheap + # adjacent-difference check, no re-sort. assert bool((sorted_node[1:] > sorted_node[:-1]).all()), ( "vectorized label remap requires a unique node local->global map; " "duplicate global ids break the searchsorted membership lookup." ) - # 1. Locate each label id in the sorted node map: searchsorted returns the - # insertion point, so `sorted_positions[i]` is the candidate index in - # `sorted_node` where `flat[i]` would be inserted to keep order sorted. - sorted_positions = torch.searchsorted(sorted_node, flat) + # step 1 (see docstring example): position of each candidate id in sorted_node. + sorted_positions = torch.searchsorted(sorted_node, flat) # [K] + # searchsorted returns N_nodes for an id larger than every entry, which would + # gather out of bounds at step 2; clamp it back into range. sorted_positions = sorted_positions.clamp_(max=num_nodes - 1) - # 2. Keep only exact matches (drop global ids absent from the subgraph). - # `sorted_node[sorted_positions] == flat` is True iff flat[i] is actually - # in the node map (not just a neighboring element in the sorted array). - is_exact_match = sorted_node[sorted_positions] == flat + # step 2 (see docstring example): keep only true members (a neighboring entry + # in the sorted array is not a match). This is the [K] -> [E] filter. + is_exact_match = sorted_node[sorted_positions] == flat # [K] bool - # 3. Map sorted position -> original local index via sort_perm: sort_perm[j] - # is the local node index whose global id landed at sorted position j. - local_index = sort_perm[sorted_positions][is_exact_match] - anchor_of_matched = anchor_of_entry[is_exact_match] + # step 3 (see docstring example): sorted position -> original local node index + # via sort_perm (becomes AnchorLabels.label_index). + local_index = sort_perm[sorted_positions][is_exact_match] # [E] + anchor_of_matched = anchor_of_entry[is_exact_match] # [E] - # Pairs are now in (anchor, column) order -- non-decreasing in anchor_index - # because anchor_of_entry is row-major and the masks preserve order. No - # argsort is needed: within-anchor label order is unspecified by contract - # (the ABLP loss is permutation-invariant; see AnchorLabels docstring). + # Result rows stay grouped by anchor (step 0 tagging was row-major and the masks + # preserve order), so no argsort is needed; within-anchor order is unspecified -- + # the ABLP loss is order-invariant (see AnchorLabels). return ( anchor_of_matched.to(to_device).to(torch.long), local_index.to(to_device).to(torch.long), @@ -251,7 +287,7 @@ def edge_list_set_labels( Args: node_local_to_global_by_type (dict[NodeType, torch.Tensor]): Per node - type, a ``[N]`` tensor whose ``i``-th entry is the global id of + type, a ``[N_nodes]`` tensor whose ``i``-th entry is the global id of local node ``i``. Global ids MUST be unique within each map. positive_labels_by_edge_type (dict[EdgeType, torch.Tensor]): Per positive-label edge type, a ``[N_anchors, M]`` ``-1``-padded tensor @@ -277,7 +313,8 @@ def _sorted_for(node_type: NodeType) -> tuple[torch.Tensor, torch.Tensor]: def _remap( labels_by_edge_type: dict[EdgeType, torch.Tensor], ) -> dict[EdgeType, AnchorLabels]: - # Supervision edge types are (anchor_type, relation, supervision_type). + # Supervision edge types are (anchor_type, relation, supervision_type), so + # the supervision node type is index 2 (used below). supervision_node_type_index = 2 output: dict[EdgeType, AnchorLabels] = {} for edge_type, label_tensor in labels_by_edge_type.items(): @@ -285,6 +322,8 @@ def _remap( if label_tensor.size(0) == 0: continue sorted_node, sort_perm = _sorted_for(edge_type[supervision_node_type_index]) + # Remap globals -> locals via the sorted-membership join (see the + # labeled steps in _membership_remap). output[label_edge_type_to_message_passing_edge_type(edge_type)] = ( AnchorLabels( *_membership_remap(label_tensor, sorted_node, sort_perm, to_device) @@ -467,10 +506,11 @@ def __init__( use_list_output (bool): Return labels as an ``AnchorLabels`` edge-list (or ``dict[EdgeType, AnchorLabels]`` for multiple supervision edge types) instead of the ragged ``dict[anchor_local_index, - torch.Tensor]``; see :class:`AnchorLabels` for the shape. The - edge-list lets the loss read ``y.label_index`` and - ``query_idx[y.anchor_index]`` directly. Defaults to ``False`` (the - backward-compatible ragged dict). + torch.Tensor]``. The edge-list lets the loss read the co-indexed + ``y.label_index`` and ``query_idx[y.anchor_index]`` (both ``[E]``) + directly; see :class:`AnchorLabels` for the shape and the ``[E]`` + vocabulary. Defaults to ``False`` (the backward-compatible ragged + dict). """ # Set self._shutdowned right away, that way if we throw here, and __del__ is called, @@ -1011,20 +1051,41 @@ def _set_labels( positive_labels_by_label_edge_type: dict[EdgeType, torch.Tensor], negative_labels_by_label_edge_type: dict[EdgeType, torch.Tensor], ) -> Union[Data, HeteroData]: - """ - Sets the labels and relevant fields in the torch_geometric Data object, converting the global node ids for labels to their - local index. Removes inserted supervision edge type from the data variables, since this is an implementation detail and should not be - exposed in the final HeteroData/Data object. + """Attach ABLP labels to the collated graph, remapped to subgraph-local indices. + + This is the collation hook that turns the sampler's global-id labels into the + ``y_positive`` / ``y_negative`` fields downstream training reads. + + The actual remap is delegated to :func:`edge_list_set_labels` (the single + kernel): with ``use_list_output`` the labels are attached as an + :class:`AnchorLabels` edge list, otherwise expanded to the ragged + ``dict[anchor_local_index, torch.Tensor]`` via :meth:`AnchorLabels.to_dict`. + Both are the same labels in a different container. + + The supervision edge type is an internal sampling artifact, so it is stripped + before return and never appears on the output object. + + ``y_positive`` / ``y_negative`` collapse to a single value when there is one + supervision edge type, or a ``dict[EdgeType, ...]`` for several; see + :meth:`DistABLPLoader.__init__` for the full shape contract. + Args: - data (Union[Data, HeteroData]): Graph to provide labels for - positive_labels_by_label_edge_type (dict[EdgeType, torch.Tensor]): Dict[positive label edge type, label ID tensor], - where the ith row of the tensor corresponds to the ith anchor node ID. - negative_labels_by_label_edge_type (dict[EdgeType, torch.Tensor]): Dict[negative label edge type, label ID tensor], - where the ith row of the tensor corresponds to the ith anchor node ID. + data (Union[Data, HeteroData]): Graph to attach labels to. + positive_labels_by_label_edge_type (dict[EdgeType, torch.Tensor]): Per + positive-label edge type, a ``[N_anchors, M]`` tensor whose ``i``-th + row holds the global label ids of the ``i``-th anchor. + negative_labels_by_label_edge_type (dict[EdgeType, torch.Tensor]): As + above, for negative-label edge types. + Returns: - Union[Data, HeteroData]: torch_geometric HeteroData/Data object with the filtered edge fields and labels set as properties of the instance + Union[Data, HeteroData]: The same object with the supervision edge fields + stripped and ``y_positive`` (and ``y_negative`` when present) attached. + + Raises: + ValueError: If no positive labels are found in ``data``. """ - # shape [N], where N is the number of nodes in the subgraph, and local_node_to_global_node[i] gives the global node id for local node id `i` + # node_type_to_local_node_to_global_node[t][i]: global id of local node i; + # each value tensor is [N_nodes] for its node type. node_type_to_local_node_to_global_node: dict[NodeType, torch.Tensor] = {} if isinstance(data, HeteroData): for e_type in self._supervision_edge_types: From b70e626459a016c870be34d201ee40cb40f1e0a9 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Mon, 29 Jun 2026 18:04:05 +0000 Subject: [PATCH 17/18] docs(ablp): worked example for use_list_output=True; harden node-map uniqueness guard - Add a use_list_output=True worked example to DistABLPLoader.__init__ that mirrors the use_list_output=False example (same graph, AnchorLabels values). - Document AnchorLabels.to_dict()'s non-decreasing anchor_index precondition. - Promote the _membership_remap duplicate-node-map guard from a __debug__ assert to an always-on ValueError so it still fires under `python -O`; update the test to expect ValueError. Co-Authored-By: Claude Opus 4.8 (1M context) --- gigl/distributed/dist_ablp_neighborloader.py | 32 ++++++++++++++++--- .../distributed/edge_list_set_labels_test.py | 12 +++---- 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index c3d1c7d54..18e497186 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -124,6 +124,12 @@ def to_dict(self) -> dict[int, torch.Tensor]: Every anchor ``0..num_anchors-1`` receives a key; anchors with no labels map to an empty ``long`` tensor on the same device as ``label_index``. + ``anchor_index`` MUST be non-decreasing (grouped by anchor). This holds for + every :class:`AnchorLabels` produced by :func:`edge_list_set_labels`, whose + row-major construction already groups pairs by anchor. The split below + relies on that ordering and is not validated; passing an ungrouped + ``anchor_index`` would silently assign labels to the wrong anchors. + Returns: Mapping from anchor index to its 1-D ``long`` tensor of local label node ids. @@ -225,11 +231,14 @@ def _membership_remap( if num_nodes == 0 or flat.numel() == 0: return empty, empty, num_anchors - if __debug__: - # Precondition for step 1 (see docstring): `sorted_node` is already sorted, - # so uniqueness is equivalent to being strictly increasing -- a cheap - # adjacent-difference check, no re-sort. - assert bool((sorted_node[1:] > sorted_node[:-1]).all()), ( + # Precondition for step 1 (see docstring): `sorted_node` is already sorted, so + # uniqueness is equivalent to being strictly increasing -- a cheap adjacent- + # difference check, no re-sort. Always-on (not `__debug__`) so it still fires + # under `python -O`: a duplicate would silently drop labels and corrupt the loss. + # `sorted_node` has at least one entry here (`num_nodes == 0` returned above), so + # the empty-slice comparison is `True` for a 1-element map. + if not bool((sorted_node[1:] > sorted_node[:-1]).all()): + raise ValueError( "vectorized label remap requires a unique node local->global map; " "duplicate global ids break the searchsorted membership lookup." ) @@ -428,6 +437,19 @@ def __init__( but the ABLP loss is order-invariant, and `AnchorLabels.to_dict()` recovers the dict form. + When `use_list_output=True`, the same example graph above (sampling around + node `0`) produces this format instead: + - `y_positive`: AnchorLabels(anchor_index=torch.tensor([0]), label_index=torch.tensor([1]), num_anchors=1) + - `y_negative`: AnchorLabels(anchor_index=torch.tensor([0]), label_index=torch.tensor([2]), num_anchors=1) + + And, as above, the label fields are instead `dict[EdgeType, AnchorLabels]` if multiple supervision edge types are provided. + e.g. for supervision edge types (a, to, b) and (a, to, c): + - `y_positive`: {(a, to, b): AnchorLabels(anchor_index=torch.tensor([0]), label_index=torch.tensor([1]), num_anchors=1), (a, to, c): AnchorLabels(anchor_index=torch.tensor([0]), label_index=torch.tensor([2]), num_anchors=1)} + - `y_negative`: {(a, to, b): AnchorLabels(anchor_index=torch.tensor([0]), label_index=torch.tensor([3]), num_anchors=1), (a, to, c): AnchorLabels(anchor_index=torch.tensor([0]), label_index=torch.tensor([4]), num_anchors=1)} + + These hold the same labels as the `use_list_output=False` example, in a different container: + `AnchorLabels.to_dict()` on each value recovers the corresponding ragged dict (e.g. `{0: torch.tensor([1])}`). + Args: dataset (Union[DistDataset, RemoteDistDataset]): The dataset to sample from. If this is a `RemoteDistDataset`, then we are in "Graph Store" mode. diff --git a/tests/unit/distributed/edge_list_set_labels_test.py b/tests/unit/distributed/edge_list_set_labels_test.py index f57831980..8cea2e807 100644 --- a/tests/unit/distributed/edge_list_set_labels_test.py +++ b/tests/unit/distributed/edge_list_set_labels_test.py @@ -356,15 +356,15 @@ def test_device_placement_cpu(self) -> None: self.assertEqual(labels.anchor_index.dtype, torch.long) self.assertEqual(labels.label_index.dtype, torch.long) - def test_duplicate_node_map_raises_assertion(self) -> None: + def test_duplicate_node_map_raises(self) -> None: # The membership lookup requires unique global ids; a duplicate would - # silently drop a local index. The guard is gated on ``__debug__`` (a - # no-op under ``python -O``), and GiGL node maps are unique by - # construction, so this only catches misuse. Assert it fires under the - # default (non-optimized) interpreter the suite runs on. + # silently drop a local index and corrupt the loss. The guard is an + # always-on ``ValueError`` (not ``__debug__``), so it fires even under + # ``python -O``. GiGL node maps are unique by construction, so this only + # catches misuse of the now-public kernel. node_map = {_STORY: torch.tensor([10, 10, 11])} positives = {_pos(_USER_TO_STORY): torch.tensor([[10, 11]], dtype=torch.long)} - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): edge_list_set_labels( node_local_to_global_by_type=node_map, positive_labels_by_edge_type=positives, From aa978efae0cac647dc4597410d8dd34155dd0eb8 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Mon, 29 Jun 2026 18:12:31 +0000 Subject: [PATCH 18/18] docs(ablp): scrub history-relative comments in label-remap code Comments should describe the code as it is, not as a delta from a prior or never-committed version. Reword/remove comments that only parse if the reader knows the code's development history: - Drop the "Always-on (not __debug__) ... python -O" framing on the node-map uniqueness guard; keep the precondition + empty-slice safety rationale. - Fix a stale _membership_remap docstring line still describing a __debug__ assertion (it is now an always-on ValueError). - Remove "with no argsort" / "no argsort needed" (presupposed the dropped order-reproduction argsort); state the grouped-by-anchor property directly. - Reword "a sorted-membership join rather than a per-anchor scan ... trades a broadcast-compare for a search" to describe the algorithm and its complexity as-is. - Drop "rather than maintaining a second kernel" in _set_labels. - Same scrub in edge_list_set_labels_test.py (module docstring + guard test). Co-Authored-By: Claude Opus 4.8 (1M context) --- gigl/distributed/dist_ablp_neighborloader.py | 34 +++++++++---------- .../distributed/edge_list_set_labels_test.py | 16 ++++----- 2 files changed, 23 insertions(+), 27 deletions(-) diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 18e497186..2b1b3d9d7 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -152,10 +152,9 @@ def _membership_remap( :class:`AnchorLabels` for the ``N_anchors`` / ``M`` / ``N_nodes`` / ``K`` / ``E`` dimension vocabulary used below.) - The lookup is a sorted-membership join rather than a per-anchor scan, which is - what lets the loader remap labels without a Python loop over anchors: it trades - an ``O(N_anchors * M * N_nodes)`` broadcast-compare for a single - ``O(K log N_nodes)`` search. + The lookup is a sorted-membership join: a single ``O(K log N_nodes)`` + ``searchsorted`` over the sorted node map resolves every candidate label at + once, so the loader remaps labels with no Python loop over anchors. Worked example (generic ids):: @@ -179,16 +178,16 @@ def _membership_remap( ``local_index``; :class:`AnchorLabels` stores that same tensor as ``label_index``. Because ``anchor_of_*`` is built row-major (step 0) and every mask preserves - order, the result is already grouped by anchor (non-decreasing ``anchor_index``) - with no argsort; order *within* an anchor is unspecified by contract, since the - loss does not care (see :class:`AnchorLabels`). + order, the result is already grouped by anchor (non-decreasing ``anchor_index``); + order *within* an anchor is unspecified by contract, since the loss does not care + (see :class:`AnchorLabels`). The lookup is only correct if ``sorted_node`` has unique values: :func:`torch.searchsorted` returns the left-most equal position, so a repeated global id would map every match to the same local index and silently drop the rest. GiGL ``node`` maps are unique by construction (one entry per subgraph - node), so this holds in production; the ``__debug__`` assertion guards against - misuse with a cheap adjacent-difference check on the already-sorted map. + node), so this holds in production; a cheap adjacent-difference check on the + already-sorted map raises ``ValueError`` if a caller violates it. Args: label_tensor (torch.Tensor): ``[N_anchors, M]`` ``-1``-padded global @@ -233,8 +232,8 @@ def _membership_remap( # Precondition for step 1 (see docstring): `sorted_node` is already sorted, so # uniqueness is equivalent to being strictly increasing -- a cheap adjacent- - # difference check, no re-sort. Always-on (not `__debug__`) so it still fires - # under `python -O`: a duplicate would silently drop labels and corrupt the loss. + # difference check, no re-sort. A duplicate global id would silently drop labels + # and corrupt the loss, so reject it. # `sorted_node` has at least one entry here (`num_nodes == 0` returned above), so # the empty-slice comparison is `True` for a 1-element map. if not bool((sorted_node[1:] > sorted_node[:-1]).all()): @@ -259,8 +258,8 @@ def _membership_remap( anchor_of_matched = anchor_of_entry[is_exact_match] # [E] # Result rows stay grouped by anchor (step 0 tagging was row-major and the masks - # preserve order), so no argsort is needed; within-anchor order is unspecified -- - # the ABLP loss is order-invariant (see AnchorLabels). + # preserve order); within-anchor order is unspecified -- the ABLP loss is + # order-invariant (see AnchorLabels). return ( anchor_of_matched.to(to_device).to(torch.long), local_index.to(to_device).to(torch.long), @@ -1117,11 +1116,10 @@ def _set_labels( node_type_to_local_node_to_global_node[DEFAULT_HOMOGENEOUS_NODE_TYPE] = ( data.node ) - # The edge-list kernel is the single remap path; the ragged dict is just - # one view of it (AnchorLabels.to_dict), so when the caller wants the dict - # we expand here rather than maintaining a second kernel. Both forms feed - # an order-invariant contrastive loss, so the choice is purely about the - # consumer's preferred shape. + # The edge-list kernel is the remap path; the ragged dict is one view of it + # (AnchorLabels.to_dict), expanded here when the caller wants the dict. Both + # forms feed an order-invariant contrastive loss, so the choice is purely + # about the consumer's preferred shape. output_positive_labels, output_negative_labels = edge_list_set_labels( node_local_to_global_by_type=node_type_to_local_node_to_global_node, positive_labels_by_edge_type=positive_labels_by_label_edge_type, diff --git a/tests/unit/distributed/edge_list_set_labels_test.py b/tests/unit/distributed/edge_list_set_labels_test.py index 8cea2e807..32eee2a97 100644 --- a/tests/unit/distributed/edge_list_set_labels_test.py +++ b/tests/unit/distributed/edge_list_set_labels_test.py @@ -3,11 +3,10 @@ These exercise the pure-tensor label-remap logic directly (no GLT, no distributed runtime), so they run in-process without ``mp.spawn``. -``edge_list_set_labels`` is the loader's single label-remap path; it turns padded -blocks of global label ids into per-edge-type :class:`AnchorLabels` edge lists. -We assert against constructed expected values rather than a second reference -implementation, and we check both the edge-list tensors and their -:meth:`AnchorLabels.to_dict` view, since the loader uses both forms. +``edge_list_set_labels`` is the loader's label-remap path; it turns padded blocks +of global label ids into per-edge-type :class:`AnchorLabels` edge lists. We assert +against constructed expected values, and we check both the edge-list tensors and +their :meth:`AnchorLabels.to_dict` view, since the loader uses both forms. Within an anchor the kernel emits labels in column-visit order, which is left unspecified by contract (the ABLP loss is order-invariant). We therefore pin @@ -358,10 +357,9 @@ def test_device_placement_cpu(self) -> None: def test_duplicate_node_map_raises(self) -> None: # The membership lookup requires unique global ids; a duplicate would - # silently drop a local index and corrupt the loss. The guard is an - # always-on ``ValueError`` (not ``__debug__``), so it fires even under - # ``python -O``. GiGL node maps are unique by construction, so this only - # catches misuse of the now-public kernel. + # silently drop a local index and corrupt the loss, so the kernel raises + # ``ValueError``. GiGL node maps are unique by construction, so this only + # guards misuse of the public kernel. node_map = {_STORY: torch.tensor([10, 10, 11])} positives = {_pos(_USER_TO_STORY): torch.tensor([[10, 11]], dtype=torch.long)} with self.assertRaises(ValueError):