diff --git a/examples/link_prediction/graph_store/heterogeneous_training.py b/examples/link_prediction/graph_store/heterogeneous_training.py index 2be34e608..9b32807ca 100644 --- a/examples/link_prediction/graph_store/heterogeneous_training.py +++ b/examples/link_prediction/graph_store/heterogeneous_training.py @@ -211,6 +211,8 @@ def _setup_dataloaders( channel_size=sampling_worker_shared_channel_size, process_start_gap_seconds=process_start_gap_seconds, shuffle=shuffle, + # Labels as an AnchorLabels edge-list; see the AnchorLabels class docstring. + use_list_output=True, ) print(f"---Rank {rank} finished setting up main loader for split={split}") @@ -299,16 +301,15 @@ def _compute_loss( ).to(device) random_negative_batch_size = random_negative_data[labeled_node_type].batch_size - positive_idx: torch.Tensor = torch.cat(list(main_data.y_positive.values())).to( - device - ) - repeated_query_node_idx = query_node_idx.repeat_interleave( - torch.tensor([len(v) for v in main_data.y_positive.values()]).to(device) - ) + # main_data.y_positive is an AnchorLabels edge-list; read the co-indexed [E] + # label_index and query_node_idx[anchor_index] directly. See + # homogeneous_training._compute_loss for why this equals the historical dict read. + positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) # [E] + repeated_query_node_idx = query_node_idx[ + main_data.y_positive.anchor_index.to(device) # [E] + ] if hasattr(main_data, "y_negative"): - hard_negative_idx: torch.Tensor = torch.cat( - list(main_data.y_negative.values()) - ).to(device) + hard_negative_idx: torch.Tensor = main_data.y_negative.label_index.to(device) else: hard_negative_idx = torch.empty(0, dtype=torch.long).to(device) diff --git a/examples/link_prediction/graph_store/homogeneous_training.py b/examples/link_prediction/graph_store/homogeneous_training.py index 2d4c22788..8d4de7cc1 100644 --- a/examples/link_prediction/graph_store/homogeneous_training.py +++ b/examples/link_prediction/graph_store/homogeneous_training.py @@ -241,6 +241,8 @@ def _setup_dataloaders( channel_size=sampling_worker_shared_channel_size, process_start_gap_seconds=process_start_gap_seconds, shuffle=shuffle, + # Labels as an AnchorLabels edge-list; see the AnchorLabels class docstring. + use_list_output=True, ) logger.info(f"---Rank {rank} finished setting up main loader for split={split}") @@ -305,16 +307,15 @@ def _compute_loss( query_node_idx: torch.Tensor = torch.arange(main_data.batch_size).to(device) random_negative_batch_size = random_negative_data.batch_size - positive_idx: torch.Tensor = torch.cat(list(main_data.y_positive.values())).to( - device - ) - repeated_query_node_idx = query_node_idx.repeat_interleave( - torch.tensor([len(v) for v in main_data.y_positive.values()]).to(device) - ) + # main_data.y_positive is an AnchorLabels edge-list; read the co-indexed [E] + # label_index and query_node_idx[anchor_index] directly. See + # homogeneous_training._compute_loss for why this equals the historical dict read. + positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) # [E] + repeated_query_node_idx = query_node_idx[ + main_data.y_positive.anchor_index.to(device) # [E] + ] if hasattr(main_data, "y_negative"): - hard_negative_idx: torch.Tensor = torch.cat( - list(main_data.y_negative.values()) - ).to(device) + hard_negative_idx: torch.Tensor = main_data.y_negative.label_index.to(device) else: hard_negative_idx = torch.empty(0, dtype=torch.long).to(device) diff --git a/examples/link_prediction/heterogeneous_training.py b/examples/link_prediction/heterogeneous_training.py index 8ed672b7c..a60f72506 100644 --- a/examples/link_prediction/heterogeneous_training.py +++ b/examples/link_prediction/heterogeneous_training.py @@ -144,6 +144,8 @@ def _setup_dataloaders( # This is done so that each process on the current machine which initializes a `main_loader` doesn't compete for memory, causing potential OOM process_start_gap_seconds=process_start_gap_seconds, shuffle=shuffle, + # Labels as an AnchorLabels edge-list; see the AnchorLabels class docstring. + use_list_output=True, ) logger.info(f"---Rank {rank} finished setting up main loader") @@ -223,18 +225,15 @@ def _compute_loss( ).to(device) random_negative_batch_size = random_negative_data[labeled_node_type].batch_size - # main_data.y_positive is a dict[query_node_local_index: int, labeled_node_local_indices: torch.Tensor], even in the heterogeneous setting. - positive_idx: torch.Tensor = torch.cat(list(main_data.y_positive.values())).to( - device - ) - # We also extract a repeated query node index tensor which upsamples each query node based on the number of positives it has - repeated_query_node_idx = query_node_idx.repeat_interleave( - torch.tensor([len(v) for v in main_data.y_positive.values()]).to(device) - ) + # main_data.y_positive is an AnchorLabels edge-list; read the co-indexed [E] + # label_index and query_node_idx[anchor_index] directly. See + # homogeneous_training._compute_loss for why this equals the historical dict read. + positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) # [E] + repeated_query_node_idx = query_node_idx[ + main_data.y_positive.anchor_index.to(device) # [E] + ] if hasattr(main_data, "y_negative"): - hard_negative_idx: torch.Tensor = torch.cat( - list(main_data.y_negative.values()) - ).to(device) + hard_negative_idx: torch.Tensor = main_data.y_negative.label_index.to(device) else: hard_negative_idx = torch.empty(0, dtype=torch.long).to(device) diff --git a/examples/link_prediction/homogeneous_training.py b/examples/link_prediction/homogeneous_training.py index b95a77489..0b7ad0e4d 100644 --- a/examples/link_prediction/homogeneous_training.py +++ b/examples/link_prediction/homogeneous_training.py @@ -134,6 +134,8 @@ def _setup_dataloaders( # This is done so that each process on the current machine which initializes a `main_loader` doesn't compete for memory, causing potential OOM process_start_gap_seconds=process_start_gap_seconds, shuffle=shuffle, + # Labels as an AnchorLabels edge-list; see the AnchorLabels class docstring. + use_list_output=True, ) logger.info(f"---Rank {rank} finished setting up main loader") @@ -190,18 +192,21 @@ def _compute_loss( query_node_idx: torch.Tensor = torch.arange(main_data.batch_size).to(device) random_negative_batch_size = random_negative_data.batch_size - # main_data.y_positive is a dict[query_node_local_index: int, labeled_node_local_indices: torch.Tensor] - positive_idx: torch.Tensor = torch.cat(list(main_data.y_positive.values())).to( - device - ) - # We also extract a repeated query node index tensor which upsamples each query node based on the number of positives it has - repeated_query_node_idx = query_node_idx.repeat_interleave( - torch.tensor([len(v) for v in main_data.y_positive.values()]).to(device) - ) + # main_data.y_positive is an AnchorLabels edge-list (use_list_output=True), two + # co-indexed [E] tensors: label_index holds the local label node per + # (anchor, label) pair, anchor_index the matching local anchor row. Pairs are + # grouped by anchor, so reading label_index/anchor_index directly yields the same + # (query, label) pairs as the historical dict read (torch.cat(list(values())) + + # repeat_interleave over per-anchor lengths). The within-anchor label order may + # differ, but the contrastive loss is order-invariant, so the result is + # equivalent. (This is the canonical explanation; the other link-prediction + # examples point here rather than restating it.) + positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) # [E] + repeated_query_node_idx = query_node_idx[ + main_data.y_positive.anchor_index.to(device) # [E] + ] if hasattr(main_data, "y_negative"): - hard_negative_idx: torch.Tensor = torch.cat( - list(main_data.y_negative.values()) - ).to(device) + hard_negative_idx: torch.Tensor = main_data.y_negative.label_index.to(device) else: hard_negative_idx = torch.empty(0, dtype=torch.long).to(device) diff --git a/gigl/distributed/__init__.py b/gigl/distributed/__init__.py index f88b198cb..c5b351daf 100644 --- a/gigl/distributed/__init__.py +++ b/gigl/distributed/__init__.py @@ -3,6 +3,7 @@ """ __all__ = [ + "AnchorLabels", "DistABLPLoader", "DistNeighborLoader", "DistDataset", @@ -17,7 +18,7 @@ build_dataset, build_dataset_from_task_config_uri, ) -from gigl.distributed.dist_ablp_neighborloader import DistABLPLoader +from gigl.distributed.dist_ablp_neighborloader import AnchorLabels, DistABLPLoader from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.dist_partitioner import DistPartitioner diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 50f42f5a9..2b1b3d9d7 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -1,4 +1,5 @@ -from collections import abc, defaultdict +from collections import abc +from dataclasses import dataclass from itertools import count from typing import Optional, Union @@ -49,12 +50,298 @@ reverse_edge_type, select_label_edge_types, ) -from gigl.utils.data_splitters import get_labels_for_anchor_nodes +from gigl.utils.data_splitters import PADDING_NODE, get_labels_for_anchor_nodes from gigl.utils.sampling import ABLPInputNodes logger = Logger() +@dataclass(frozen=True) +class AnchorLabels: + """ABLP labels for one edge type, stored as a flat (anchor, label) edge list. + + Each anchor can carry a different number of labels, so the natural shape is + ragged. Rather than a ``dict[int, torch.Tensor]`` -- which needs padding to + batch and a Python loop to read -- we keep two co-indexed ``long`` tensors so + the loss can index straight into them with no per-anchor iteration. The data + example below makes the layout concrete. + + Order within an anchor is deliberately left unspecified. The ABLP contrastive + loss (:class:`gigl.nn.loss.RetrievalLoss`) scores every (anchor, label) pair + independently and sums, so permuting the labels of an anchor leaves the loss + unchanged. We therefore emit pairs in the order they fall out of the source + ``[N_anchors, M]`` tensor (row-major, padding removed) and never pay to sort. + The two index tensors are always co-indexed, which is the only invariant that + matters. + + Empty anchors contribute no pairs at all. ``num_anchors`` is carried + separately so :meth:`to_dict` can still emit a key for every anchor, even the + ones that matched nothing. + + Dimension vocabulary (used throughout the label-remap code): + + - ``N_anchors`` -- anchor rows (rows of the source padded label tensor). + - ``M`` -- padded label columns per anchor. + - ``N_nodes`` -- nodes in the supervision local->global map. + - ``K`` -- non-padding candidate labels (after dropping the ``-1`` pad, before + membership filtering; still includes globals absent from the subgraph). + - ``E`` -- surviving ``(anchor, label)`` pairs after membership filtering + (``E <= K``); the length of ``anchor_index`` / ``label_index``. + + The ragged dict and this edge list hold the same labels in different + containers (three anchors, two of which carry labels):: + + dict form {0: [3], 1: [5, 7], 2: []} + edge-list form anchor_index = [0, 1, 1] # [E] = 3 + label_index = [3, 5, 7] # [E] = 3 + num_anchors = 3 # anchor 2 contributes no pair + + Example:: + + >>> import torch + >>> labels = AnchorLabels( + ... anchor_index=torch.tensor([0, 1, 1], dtype=torch.long), + ... label_index=torch.tensor([3, 5, 7], dtype=torch.long), + ... num_anchors=3, + ... ) + >>> labels.to_dict() + {0: tensor([3]), 1: tensor([5, 7]), 2: tensor([], dtype=torch.int64)} + + Args: + anchor_index (torch.Tensor): ``[E]`` long tensor of local anchor rows. + label_index (torch.Tensor): ``[E]`` long tensor of local label node ids. + num_anchors (int): Total number of anchors ``N_anchors`` (rows of the + source padded label tensor), including anchors with no labels. + """ + + anchor_index: torch.Tensor + label_index: torch.Tensor + num_anchors: int + + def to_dict(self) -> dict[int, torch.Tensor]: + """Expand to the legacy ragged ``dict[int, torch.Tensor]`` form. + + Every anchor ``0..num_anchors-1`` receives a key; anchors with no labels + map to an empty ``long`` tensor on the same device as ``label_index``. + + ``anchor_index`` MUST be non-decreasing (grouped by anchor). This holds for + every :class:`AnchorLabels` produced by :func:`edge_list_set_labels`, whose + row-major construction already groups pairs by anchor. The split below + relies on that ordering and is not validated; passing an ungrouped + ``anchor_index`` would silently assign labels to the wrong anchors. + + Returns: + Mapping from anchor index to its 1-D ``long`` tensor of local label + node ids. + """ + counts = torch.bincount(self.anchor_index, minlength=self.num_anchors) + per_anchor = torch.split(self.label_index, counts.tolist()) + return {anchor: per_anchor[anchor] for anchor in range(self.num_anchors)} + + +def _membership_remap( + label_tensor: torch.Tensor, + sorted_node: torch.Tensor, + sort_perm: torch.Tensor, + to_device: torch.device, +) -> tuple[torch.Tensor, torch.Tensor, int]: + """Resolve one padded label tensor to a flat ``(anchor, label)`` pair stream. + + This is the actual global-id-to-local-index lookup; :class:`AnchorLabels` and + :func:`edge_list_set_labels` just package what it returns. (See + :class:`AnchorLabels` for the ``N_anchors`` / ``M`` / ``N_nodes`` / ``K`` / ``E`` + dimension vocabulary used below.) + + The lookup is a sorted-membership join: a single ``O(K log N_nodes)`` + ``searchsorted`` over the sorted node map resolves every candidate label at + once, so the loader remaps labels with no Python loop over anchors. + + Worked example (generic ids):: + + node map (local -> global): [40, 10, 30] # N_nodes = 3 + sorted_node = [10, 30, 40], sort_perm = [1, 2, 0] + label_tensor ([N_anchors=2, M=2], -1 = pad): + [[30, -1], + [40, 10]] + + step 0 flatten row-major, tag each entry with its anchor row, drop pad: + flat = [30, 40, 10] # [K] = 3 candidates + anchor_of_* = [ 0, 1, 1] # [K] + step 1 searchsorted(sorted_node, flat) -> [1, 2, 0] # [K] positions + step 2 keep exact members (sorted_node[pos] == flat): all 3 -> [E] + step 3 sort_perm[pos] -> local index: [2, 0, 1] # [E] + result anchor_index = [0, 1, 1], label_index = [2, 0, 1] # [E] = 3 + + Check by hand: g30 is local 2, g40 is local 0, g10 is local 1 -- matches. Here + every candidate is a member so ``K == E``; the two differ once a global id is + absent from the node map (step 2 drops it). The code names the step-3 output + ``local_index``; :class:`AnchorLabels` stores that same tensor as ``label_index``. + + Because ``anchor_of_*`` is built row-major (step 0) and every mask preserves + order, the result is already grouped by anchor (non-decreasing ``anchor_index``); + order *within* an anchor is unspecified by contract, since the loss does not care + (see :class:`AnchorLabels`). + + The lookup is only correct if ``sorted_node`` has unique values: + :func:`torch.searchsorted` returns the left-most equal position, so a repeated + global id would map every match to the same local index and silently drop the + rest. GiGL ``node`` maps are unique by construction (one entry per subgraph + node), so this holds in production; a cheap adjacent-difference check on the + already-sorted map raises ``ValueError`` if a caller violates it. + + Args: + label_tensor (torch.Tensor): ``[N_anchors, M]`` ``-1``-padded global + label ids. + sorted_node (torch.Tensor): ``[N_nodes]`` sorted values of the supervision + node map (the ``values`` half of ``torch.sort``). + sort_perm (torch.Tensor): ``[N_nodes]`` permutation from ``torch.sort`` + mapping sorted positions back to original local indices. + to_device (torch.device): Device for the returned index tensors. + + Returns: + Tuple ``(anchor_index, local_index, num_anchors)``. ``anchor_index`` and + ``local_index`` are co-indexed ``[E]`` tensors grouped by anchor (empty when + nothing matched); ``local_index`` becomes :attr:`AnchorLabels.label_index`. + ``num_anchors == N_anchors == label_tensor.size(0)``. + """ + num_anchors = int(label_tensor.size(0)) + num_nodes = int(sorted_node.size(0)) + empty = torch.empty(0, dtype=torch.long, device=to_device) + if num_anchors == 0: + return empty, empty, num_anchors + + num_labels = int(label_tensor.size(1)) + flat = label_tensor.reshape(-1) # [N_anchors * M] before the pad mask + # step 0 (see docstring example): tag each flattened entry with its anchor row. + # Build on the label tensor's device: `anchor_of_entry` is indexed below by + # `is_present` (derived from `label_tensor`). On GPU, a CPU arange would + # raise "indices should be either on cpu or on the same device as the indexed + # tensor". CPU-only unit tests cannot catch this; see the CUDA-gated test. + anchor_of_entry = torch.arange( + num_anchors, device=label_tensor.device + ).repeat_interleave(num_labels) # [N_anchors * M] before the pad mask + + # step 0 cont.: drop the -1 pad before any search so we never gather with a + # sentinel. This is the [N_anchors * M] -> [K] (candidate) reduction. + is_present = flat != PADDING_NODE + flat = flat[is_present] # [K] + anchor_of_entry = anchor_of_entry[is_present] # [K] + + if num_nodes == 0 or flat.numel() == 0: + return empty, empty, num_anchors + + # Precondition for step 1 (see docstring): `sorted_node` is already sorted, so + # uniqueness is equivalent to being strictly increasing -- a cheap adjacent- + # difference check, no re-sort. A duplicate global id would silently drop labels + # and corrupt the loss, so reject it. + # `sorted_node` has at least one entry here (`num_nodes == 0` returned above), so + # the empty-slice comparison is `True` for a 1-element map. + if not bool((sorted_node[1:] > sorted_node[:-1]).all()): + raise ValueError( + "vectorized label remap requires a unique node local->global map; " + "duplicate global ids break the searchsorted membership lookup." + ) + + # step 1 (see docstring example): position of each candidate id in sorted_node. + sorted_positions = torch.searchsorted(sorted_node, flat) # [K] + # searchsorted returns N_nodes for an id larger than every entry, which would + # gather out of bounds at step 2; clamp it back into range. + sorted_positions = sorted_positions.clamp_(max=num_nodes - 1) + + # step 2 (see docstring example): keep only true members (a neighboring entry + # in the sorted array is not a match). This is the [K] -> [E] filter. + is_exact_match = sorted_node[sorted_positions] == flat # [K] bool + + # step 3 (see docstring example): sorted position -> original local node index + # via sort_perm (becomes AnchorLabels.label_index). + local_index = sort_perm[sorted_positions][is_exact_match] # [E] + anchor_of_matched = anchor_of_entry[is_exact_match] # [E] + + # Result rows stay grouped by anchor (step 0 tagging was row-major and the masks + # preserve order); within-anchor order is unspecified -- the ABLP loss is + # order-invariant (see AnchorLabels). + return ( + anchor_of_matched.to(to_device).to(torch.long), + local_index.to(to_device).to(torch.long), + num_anchors, + ) + + +def edge_list_set_labels( + node_local_to_global_by_type: dict[NodeType, torch.Tensor], + positive_labels_by_edge_type: dict[EdgeType, torch.Tensor], + negative_labels_by_edge_type: dict[EdgeType, torch.Tensor], + to_device: torch.device, +) -> tuple[dict[EdgeType, AnchorLabels], dict[EdgeType, AnchorLabels]]: + """Remap ABLP labels from global ids to local indices, as dense edge lists. + + This is the loader's single label-remap entry point. Sampling hands back + labels as global node ids in ``[N_anchors, M]`` padded blocks; training needs + them as local indices into the sampled subgraph. For each edge type this + resolves that mapping and returns an :class:`AnchorLabels` edge list. Callers + wanting the ragged ``dict[int, torch.Tensor]`` instead can expand each result + with :meth:`AnchorLabels.to_dict`; the two are the same labels in a different + container, and the loss treats them identically. + + Positive and negative labels are remapped the same way, against the same + sorted node maps, so the work is shared via a memoized sort per supervision + node type. A zero-anchor tensor is skipped rather than emitted as an empty + entry: there are simply no anchors to label for that edge type in this batch, + so no key is emitted for it. + + Correctness rests on each ``node`` local->global map having unique global ids; + the membership lookup relies on it (see :func:`_membership_remap`). GiGL node + maps satisfy this by construction. + + Args: + node_local_to_global_by_type (dict[NodeType, torch.Tensor]): Per node + type, a ``[N_nodes]`` tensor whose ``i``-th entry is the global id of + local node ``i``. Global ids MUST be unique within each map. + positive_labels_by_edge_type (dict[EdgeType, torch.Tensor]): Per + positive-label edge type, a ``[N_anchors, M]`` ``-1``-padded tensor + of global label ids. + negative_labels_by_edge_type (dict[EdgeType, torch.Tensor]): As above, + for negative-label edge types. May be empty. + to_device (torch.device): Device for every output tensor. + + Returns: + Tuple ``(y_positive, y_negative)``, each a + ``dict[message_passing_edge_type, AnchorLabels]`` with no entry for an + edge type that had no anchors this batch. + """ + sorted_cache: dict[NodeType, tuple[torch.Tensor, torch.Tensor]] = {} + + def _sorted_for(node_type: NodeType) -> tuple[torch.Tensor, torch.Tensor]: + if node_type not in sorted_cache: + sorted_cache[node_type] = torch.sort( + node_local_to_global_by_type[node_type] + ) + return sorted_cache[node_type] + + def _remap( + labels_by_edge_type: dict[EdgeType, torch.Tensor], + ) -> dict[EdgeType, AnchorLabels]: + # Supervision edge types are (anchor_type, relation, supervision_type), so + # the supervision node type is index 2 (used below). + supervision_node_type_index = 2 + output: dict[EdgeType, AnchorLabels] = {} + for edge_type, label_tensor in labels_by_edge_type.items(): + # No anchors for this edge type this batch -> no key, not an empty one. + if label_tensor.size(0) == 0: + continue + sorted_node, sort_perm = _sorted_for(edge_type[supervision_node_type_index]) + # Remap globals -> locals via the sorted-membership join (see the + # labeled steps in _membership_remap). + output[label_edge_type_to_message_passing_edge_type(edge_type)] = ( + AnchorLabels( + *_membership_remap(label_tensor, sorted_node, sort_perm, to_device) + ) + ) + return output + + return _remap(positive_labels_by_edge_type), _remap(negative_labels_by_edge_type) + + class DistABLPLoader(BaseDistLoader): # Counts instantiations of this class, per process. # This is needed so we can generate unique worker key for each instance, for graph store mode. @@ -90,17 +377,20 @@ def __init__( local_process_rank: Optional[int] = None, # TODO: (svij) Deprecate this local_process_world_size: Optional[int] = None, # TODO: (svij) Deprecate this non_blocking_transfers: bool = True, + use_list_output: bool = False, ): """ Neighbor loader for Anchor Based Link Prediction (ABLP) tasks. - Note that for this class, the dataset must *always* be heterogeneous, - as we need separate edge types for positive and negative labels. + The dataset must *always* be heterogeneous here, since positive and + negative labels are carried as separate edge types. By default, the loader will return {py:class} `torch_geometric.data.HeteroData` (heterogeneous) objects, but will return a {py:class}`torch_geometric.data.Data` (homogeneous) object if the dataset is "labeled homogeneous". - The following fields may also be present: + The following fields may also be present (this describes the default + `use_list_output=False` shape; see `use_list_output` below for the + `AnchorLabels` edge-list alternative): - `y_positive`: `dict[int, torch.Tensor]` mapping from local anchor node id to a tensor of positive label node ids. - `y_negative`: (Optional) `dict[int, torch.Tensor]` mapping from local anchor node id to a tensor of negative @@ -138,6 +428,27 @@ def __init__( - `y_positive`: {(a, to, b): {0: torch.tensor([1])}, (a, to, c): {0: torch.tensor([2])}} - `y_negative`: {(a, to, b): {0: torch.tensor([3])}, (a, to, c): {0: torch.tensor([4])}} + With `use_list_output=True`, the labels arrive instead as an `AnchorLabels` + edge-list (or `dict[EdgeType, AnchorLabels]` for several supervision edge + types); see :class:`AnchorLabels` for its shape. The edge-list keeps the + ragged per-anchor labels as flat tensors the loss can index directly, with + no padding or per-anchor Python loop; within-anchor order is unspecified + but the ABLP loss is order-invariant, and `AnchorLabels.to_dict()` recovers + the dict form. + + When `use_list_output=True`, the same example graph above (sampling around + node `0`) produces this format instead: + - `y_positive`: AnchorLabels(anchor_index=torch.tensor([0]), label_index=torch.tensor([1]), num_anchors=1) + - `y_negative`: AnchorLabels(anchor_index=torch.tensor([0]), label_index=torch.tensor([2]), num_anchors=1) + + And, as above, the label fields are instead `dict[EdgeType, AnchorLabels]` if multiple supervision edge types are provided. + e.g. for supervision edge types (a, to, b) and (a, to, c): + - `y_positive`: {(a, to, b): AnchorLabels(anchor_index=torch.tensor([0]), label_index=torch.tensor([1]), num_anchors=1), (a, to, c): AnchorLabels(anchor_index=torch.tensor([0]), label_index=torch.tensor([2]), num_anchors=1)} + - `y_negative`: {(a, to, b): AnchorLabels(anchor_index=torch.tensor([0]), label_index=torch.tensor([3]), num_anchors=1), (a, to, c): AnchorLabels(anchor_index=torch.tensor([0]), label_index=torch.tensor([4]), num_anchors=1)} + + These hold the same labels as the `use_list_output=False` example, in a different container: + `AnchorLabels.to_dict()` on each value recovers the corresponding ragged dict (e.g. `{0: torch.tensor([1])}`). + Args: dataset (Union[DistDataset, RemoteDistDataset]): The dataset to sample from. If this is a `RemoteDistDataset`, then we are in "Graph Store" mode. @@ -213,11 +524,20 @@ def __init__( is used instead. See https://docs.pytorch.org/tutorials/intermediate/pinmem_nonblock.html for background on pinned memory and non-blocking transfers. + use_list_output (bool): Return labels as an ``AnchorLabels`` edge-list + (or ``dict[EdgeType, AnchorLabels]`` for multiple supervision edge + types) instead of the ragged ``dict[anchor_local_index, + torch.Tensor]``. The edge-list lets the loss read the co-indexed + ``y.label_index`` and ``query_idx[y.anchor_index]`` (both ``[E]``) + directly; see :class:`AnchorLabels` for the shape and the ``[E]`` + vocabulary. Defaults to ``False`` (the backward-compatible ragged + dict). """ # Set self._shutdowned right away, that way if we throw here, and __del__ is called, # then we can properly clean up and don't get extraneous error messages. self._shutdowned = True + self._use_list_output = use_list_output sampler_options = resolve_sampler_options(num_neighbors, sampler_options) @@ -752,20 +1072,41 @@ def _set_labels( positive_labels_by_label_edge_type: dict[EdgeType, torch.Tensor], negative_labels_by_label_edge_type: dict[EdgeType, torch.Tensor], ) -> Union[Data, HeteroData]: - """ - Sets the labels and relevant fields in the torch_geometric Data object, converting the global node ids for labels to their - local index. Removes inserted supervision edge type from the data variables, since this is an implementation detail and should not be - exposed in the final HeteroData/Data object. + """Attach ABLP labels to the collated graph, remapped to subgraph-local indices. + + This is the collation hook that turns the sampler's global-id labels into the + ``y_positive`` / ``y_negative`` fields downstream training reads. + + The actual remap is delegated to :func:`edge_list_set_labels` (the single + kernel): with ``use_list_output`` the labels are attached as an + :class:`AnchorLabels` edge list, otherwise expanded to the ragged + ``dict[anchor_local_index, torch.Tensor]`` via :meth:`AnchorLabels.to_dict`. + Both are the same labels in a different container. + + The supervision edge type is an internal sampling artifact, so it is stripped + before return and never appears on the output object. + + ``y_positive`` / ``y_negative`` collapse to a single value when there is one + supervision edge type, or a ``dict[EdgeType, ...]`` for several; see + :meth:`DistABLPLoader.__init__` for the full shape contract. + Args: - data (Union[Data, HeteroData]): Graph to provide labels for - positive_labels_by_label_edge_type (dict[EdgeType, torch.Tensor]): Dict[positive label edge type, label ID tensor], - where the ith row of the tensor corresponds to the ith anchor node ID. - negative_labels_by_label_edge_type (dict[EdgeType, torch.Tensor]): Dict[negative label edge type, label ID tensor], - where the ith row of the tensor corresponds to the ith anchor node ID. + data (Union[Data, HeteroData]): Graph to attach labels to. + positive_labels_by_label_edge_type (dict[EdgeType, torch.Tensor]): Per + positive-label edge type, a ``[N_anchors, M]`` tensor whose ``i``-th + row holds the global label ids of the ``i``-th anchor. + negative_labels_by_label_edge_type (dict[EdgeType, torch.Tensor]): As + above, for negative-label edge types. + Returns: - Union[Data, HeteroData]: torch_geometric HeteroData/Data object with the filtered edge fields and labels set as properties of the instance + Union[Data, HeteroData]: The same object with the supervision edge fields + stripped and ``y_positive`` (and ``y_negative`` when present) attached. + + Raises: + ValueError: If no positive labels are found in ``data``. """ - # shape [N], where N is the number of nodes in the subgraph, and local_node_to_global_node[i] gives the global node id for local node id `i` + # node_type_to_local_node_to_global_node[t][i]: global id of local node i; + # each value tensor is [N_nodes] for its node type. node_type_to_local_node_to_global_node: dict[NodeType, torch.Tensor] = {} if isinstance(data, HeteroData): for e_type in self._supervision_edge_types: @@ -775,48 +1116,25 @@ def _set_labels( node_type_to_local_node_to_global_node[DEFAULT_HOMOGENEOUS_NODE_TYPE] = ( data.node ) - output_positive_labels: dict[EdgeType, dict[int, torch.Tensor]] = defaultdict( - dict + # The edge-list kernel is the remap path; the ragged dict is one view of it + # (AnchorLabels.to_dict), expanded here when the caller wants the dict. Both + # forms feed an order-invariant contrastive loss, so the choice is purely + # about the consumer's preferred shape. + output_positive_labels, output_negative_labels = edge_list_set_labels( + node_local_to_global_by_type=node_type_to_local_node_to_global_node, + positive_labels_by_edge_type=positive_labels_by_label_edge_type, + negative_labels_by_edge_type=negative_labels_by_label_edge_type, + to_device=self.to_device, ) - output_negative_labels: dict[EdgeType, dict[int, torch.Tensor]] = defaultdict( - dict - ) - # We always have supervision edge types of the form (anchor_node_type, to, supervision_node_type) - # So we can index into the edge type accordingly. - edge_index = 2 - for edge_type, label_tensor in positive_labels_by_label_edge_type.items(): - for local_anchor_node_id in range(label_tensor.size(0)): - positive_mask = ( - node_type_to_local_node_to_global_node[ - edge_type[edge_index] - ].unsqueeze(1) - == label_tensor[local_anchor_node_id] - ) # shape [N, P], where N is the number of nodes and P is the number of positive labels for the current anchor node - - # Gets the indexes of the items in local_node_to_global_node which match any of the positive labels for the current anchor node - output_positive_labels[ - label_edge_type_to_message_passing_edge_type(edge_type) - ][local_anchor_node_id] = torch.nonzero(positive_mask)[:, 0].to( - self.to_device - ) - # Shape [X], where X is the number of indexes in the original local_node_to_global_node which match a node in the positive labels for the current anchor node - - for edge_type, label_tensor in negative_labels_by_label_edge_type.items(): - for local_anchor_node_id in range(label_tensor.size(0)): - negative_mask = ( - node_type_to_local_node_to_global_node[ - edge_type[edge_index] - ].unsqueeze(1) - == label_tensor[local_anchor_node_id] - ) # shape [N, M], where N is the number of nodes and M is the number of negative labels for the current anchor node - - # Gets the indexes of the items in local_node_to_global_node which match any of the negative labels for the current anchor node - output_negative_labels[ - label_edge_type_to_message_passing_edge_type(edge_type) - ][local_anchor_node_id] = torch.nonzero(negative_mask)[:, 0].to( - self.to_device - ) - # Shape [X], where X is the number of indexes in the original local_node_to_global_node which match a node in the negative labels for the current anchor node + if not self._use_list_output: + output_positive_labels = { + et: anchor_labels.to_dict() + for et, anchor_labels in output_positive_labels.items() + } + output_negative_labels = { + et: anchor_labels.to_dict() + for et, anchor_labels in output_negative_labels.items() + } if not output_positive_labels: raise ValueError("No positive labels were found in the data!") elif len(output_positive_labels) == 1: diff --git a/tests/unit/distributed/dist_ablp_neighborloader_test.py b/tests/unit/distributed/dist_ablp_neighborloader_test.py index 31d3d1cbc..e35803269 100644 --- a/tests/unit/distributed/dist_ablp_neighborloader_test.py +++ b/tests/unit/distributed/dist_ablp_neighborloader_test.py @@ -11,7 +11,7 @@ from torch_geometric.data import Data, HeteroData from gigl.distributed.dataset_factory import build_dataset -from gigl.distributed.dist_ablp_neighborloader import DistABLPLoader +from gigl.distributed.dist_ablp_neighborloader import AnchorLabels, DistABLPLoader from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.dist_partitioner import DistPartitioner from gigl.distributed.dist_range_partitioner import DistRangePartitioner @@ -416,6 +416,108 @@ def _run_distributed_ablp_neighbor_loader_multiple_supervision_edge_types( shutdown_rpc() +def _global_pair_set( + node: torch.Tensor, label_dict: dict[int, torch.Tensor] +) -> list[tuple[int, int]]: + """Flatten a per-anchor label dict to a sorted (global_anchor, global_label) + pair list for set-equality comparison. + + Pairs are collected in anchor-ascending order and sorted within each anchor + by global label value, so the result is canonical regardless of the order in + which the kernel emits an anchor's labels (which is unspecified by contract). + Multiplicity is preserved: duplicate label columns produce duplicate entries + in the list. + + Use ``collections.Counter`` or ``sorted(...)`` equality rather than + positional ``==`` if pair-level multiplicity without anchor grouping matters. + The current callers use ``list ==`` after sorting within anchor, which is + equivalent to Counter equality for these tests. + """ + pairs: list[tuple[int, int]] = [] + for local_anchor in sorted(label_dict.keys()): + global_anchor = int(node[local_anchor].item()) + for local_label in sorted(label_dict[local_anchor].tolist()): + pairs.append((global_anchor, int(node[local_label].item()))) + return pairs + + +def _collect_homogeneous_labels( + _, + return_dict, + use_list_output: bool, + dataset: DistDataset, + input_nodes: torch.Tensor, + batch_size: int, + has_negatives: bool, +): + """Child-side: run the loader, return sorted global-id pair sets. + + Local node indices differ run-to-run, so labels are translated back to + global ids via ``datum.node``. Pairs are sorted within each anchor (see + ``_global_pair_set``) for canonical set-equality comparison in the parent, + since the order in which an anchor's labels are emitted is unspecified -- the + ABLP loss is permutation-invariant over the pairs. + + When ``use_list_output`` is True the labels arrive as :class:`AnchorLabels`. + This branch ALSO asserts in-process that reading labels straight off the + edge-list agrees with reading them via the dict view: ``label_index`` must + equal the dict's ``torch.cat(values())``, and ``query_node_idx[anchor_index]`` + must equal a ``repeat_interleave`` of the anchors over their per-anchor label + counts. Both views come from the same edge-list, so any drift between the two + read patterns is a real bug, not just a reordering. + """ + create_test_process_group() + loader = DistABLPLoader( + dataset=dataset, + num_neighbors=[2, 2], + input_nodes=input_nodes, + batch_size=batch_size, + pin_memory_device=torch.device("cpu"), + use_list_output=use_list_output, + ) + positive_pairs: list[tuple[int, int]] = [] + negative_pairs: list[tuple[int, int]] = [] + for datum in loader: + assert isinstance(datum, Data) + node = datum.node + if use_list_output: + assert isinstance(datum.y_positive, AnchorLabels), ( + f"expected AnchorLabels, got {type(datum.y_positive)}" + ) + positive_dict = datum.y_positive.to_dict() + # The direct edge-list read must match the dict-view read within this + # single batch, EXACTLY (order included). + query_node_idx = torch.arange(datum.batch_size) + dict_view_label_idx = torch.cat( + [positive_dict[a] for a in range(datum.batch_size)] + ) + dict_view_repeated_query = query_node_idx.repeat_interleave( + torch.tensor([len(positive_dict[a]) for a in range(datum.batch_size)]) + ) + torch.testing.assert_close( + datum.y_positive.label_index, dict_view_label_idx + ) + torch.testing.assert_close( + query_node_idx[datum.y_positive.anchor_index], dict_view_repeated_query + ) + else: + positive_dict = datum.y_positive + positive_pairs.extend(_global_pair_set(node, positive_dict)) + if has_negatives: + if use_list_output: + assert isinstance(datum.y_negative, AnchorLabels) + negative_dict = datum.y_negative.to_dict() + else: + negative_dict = datum.y_negative + negative_pairs.extend(_global_pair_set(node, negative_dict)) + else: + assert not hasattr(datum, "y_negative"), ( + f"expected no negatives, got {getattr(datum, 'y_negative', None)}" + ) + return_dict[use_list_output] = (positive_pairs, negative_pairs) + shutdown_rpc() + + class DistABLPLoaderTest(TestCase): def tearDown(self): if torch.distributed.is_initialized(): @@ -556,6 +658,107 @@ def test_ablp_dataloader( ), ) + @parameterized.expand( + [ + param( + "positive and negative", + labeled_edges={ + _POSITIVE_EDGE_TYPE: torch.tensor([[10, 15], [15, 16]]), + _NEGATIVE_EDGE_TYPE: torch.tensor( + [[10, 10, 11, 15], [13, 16, 14, 17]] + ), + }, + input_nodes=torch.tensor([10, 15]), + batch_size=2, + has_negatives=True, + ), + param( + "positive only", + labeled_edges={_POSITIVE_EDGE_TYPE: torch.tensor([[10, 15], [15, 16]])}, + input_nodes=torch.tensor([10, 15]), + batch_size=2, + has_negatives=False, + ), + # Anchor 11 has message-passing edges (11 -> {13, 17}) but is the + # source of NO positive-label edge, so its positive-label row is + # all-padding and y_positive[11] is a guaranteed-empty tensor. This + # exercises the empty-anchor branch end-to-end for both outputs. + param( + "guaranteed empty positive anchor", + labeled_edges={ + _POSITIVE_EDGE_TYPE: torch.tensor([[10, 15], [15, 16]]), + _NEGATIVE_EDGE_TYPE: torch.tensor( + [[10, 10, 11, 15], [13, 16, 14, 17]] + ), + }, + input_nodes=torch.tensor([10, 11, 15]), + batch_size=3, + has_negatives=True, + ), + ] + ) + def test_use_list_output_matches_dict_output( + self, _, labeled_edges, input_nodes, batch_size, has_negatives + ): + """``use_list_output=True`` yields AnchorLabels whose ``.to_dict()`` produces + a per-anchor SET-equal result to the default dict output. + + Both paths draw from the same :func:`_membership_remap` column-visit pair + stream, so they are in fact identical; the test confirms no membership or + co-indexing regression. Sampling is deterministic (``shuffle`` defaults to + False), so the two loader runs emit batches in the same order and the sorted + pair sets are directly comparable. + + The child process additionally asserts that reading labels straight off the + edge-list equals reading them through the dict view (see + ``_collect_homogeneous_labels``): ``label_index`` equals the dict's + ``torch.cat(values())`` and ``query_node_idx[anchor_index]`` equals the + matching ``repeat_interleave`` over per-anchor counts. + """ + edge_index = { + DEFAULT_HOMOGENEOUS_EDGE_TYPE: torch.tensor( + [[10, 10, 11, 11, 15, 15, 16, 16], [11, 12, 13, 17, 13, 14, 12, 14]] + ), + } + edge_index.update(labeled_edges) + partition_output = PartitionOutput( + node_partition_book=to_heterogeneous_node(torch.zeros(18)), + edge_partition_book={ + e_type: torch.zeros(int(e_idx.max().item() + 1)) + for e_type, e_idx in edge_index.items() + }, + partitioned_edge_index={ + etype: GraphPartitionData( + edge_index=idx, edge_ids=torch.arange(idx.size(1)) + ) + for etype, idx in edge_index.items() + }, + partitioned_edge_features=None, + partitioned_node_features=None, + partitioned_negative_labels=None, + partitioned_positive_labels=None, + partitioned_node_labels=None, + ) + dataset = DistDataset(rank=0, world_size=1, edge_dir="out") + dataset.build(partition_output=partition_output) + + manager = mp.Manager() + return_dict = manager.dict() + for use_list_output in (False, True): + mp.spawn( + fn=_collect_homogeneous_labels, + args=( + return_dict, + use_list_output, + dataset, + input_nodes, + batch_size, + has_negatives, + ), + ) + self.assertEqual(return_dict[False][0], return_dict[True][0]) + self.assertEqual(return_dict[False][1], return_dict[True][1]) + def test_cora_supervised(self): create_test_process_group() cora_supervised_info = get_mocked_dataset_artifact_metadata()[ diff --git a/tests/unit/distributed/edge_list_set_labels_test.py b/tests/unit/distributed/edge_list_set_labels_test.py new file mode 100644 index 000000000..32eee2a97 --- /dev/null +++ b/tests/unit/distributed/edge_list_set_labels_test.py @@ -0,0 +1,375 @@ +"""Unit tests for the ABLP label-remap kernel and its edge-list container. + +These exercise the pure-tensor label-remap logic directly (no GLT, no +distributed runtime), so they run in-process without ``mp.spawn``. + +``edge_list_set_labels`` is the loader's label-remap path; it turns padded blocks +of global label ids into per-edge-type :class:`AnchorLabels` edge lists. We assert +against constructed expected values, and we check both the edge-list tensors and +their :meth:`AnchorLabels.to_dict` view, since the loader uses both forms. + +Within an anchor the kernel emits labels in column-visit order, which is left +unspecified by contract (the ABLP loss is order-invariant). We therefore pin +exact tensors only where the node map is sorted -- there column order coincides +with ascending-local order -- and otherwise compare per-anchor label *sets*. +""" + +import unittest + +import torch +from parameterized import param, parameterized +from torch_geometric.typing import EdgeType as PyGEdgeType + +from gigl.distributed.dist_ablp_neighborloader import ( + AnchorLabels, + edge_list_set_labels, +) +from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation +from gigl.types.graph import ( + DEFAULT_HOMOGENEOUS_EDGE_TYPE, + DEFAULT_HOMOGENEOUS_NODE_TYPE, + message_passing_to_negative_label, + message_passing_to_positive_label, +) +from tests.test_assets.test_case import TestCase + +_CPU = torch.device("cpu") +_USER = NodeType("user") +_STORY = NodeType("story") +_USER_TO_STORY = EdgeType(_USER, Relation("to"), _STORY) +_A = NodeType("a") +_B = NodeType("b") +_C = NodeType("c") +_A_TO_B = EdgeType(_A, Relation("to"), _B) +_A_TO_C = EdgeType(_A, Relation("to"), _C) + + +def _pos(edge_type: EdgeType) -> EdgeType: + return message_passing_to_positive_label(edge_type) + + +def _neg(edge_type: EdgeType) -> EdgeType: + return message_passing_to_negative_label(edge_type) + + +def _assert_dict_sets_equal( + actual: dict[PyGEdgeType, dict[int, torch.Tensor]], + expected: dict[PyGEdgeType, dict[int, torch.Tensor]], +) -> None: + """Assert per-anchor SET equality (with multiplicity) between two label dicts. + + Comparison is order-independent within an anchor (the kernel does not pin + within-anchor order) but still catches missing/extra labels and multiplicity + (duplicate label columns), and asserts each value tensor is ``long``. + """ + assert set(actual.keys()) == set(expected.keys()), ( + f"{set(actual.keys())} != {set(expected.keys())}" + ) + for edge_type, inner in expected.items(): + actual_inner = actual[edge_type] + assert set(actual_inner.keys()) == set(inner.keys()), ( + f"{edge_type}: {set(actual_inner.keys())} != {set(inner.keys())}" + ) + for anchor, expected_tensor in inner.items(): + got = actual_inner[anchor] + assert got.dtype == torch.long, f"{edge_type}[{anchor}] dtype {got.dtype}" + assert sorted(got.tolist()) == sorted(expected_tensor.tolist()), ( + f"{edge_type}[{anchor}]: got {got.tolist()}, " + f"expected {expected_tensor.tolist()} (as sets)" + ) + + +class AnchorLabelsTest(TestCase): + def test_to_dict_expands_empty_and_multi_label_anchors(self) -> None: + # 3 anchors: anchor 0 -> [5], anchor 1 -> [] (empty), anchor 2 -> [7, 8]. + # Every anchor must get a key, including the empty one in the middle. + labels = AnchorLabels( + anchor_index=torch.tensor([0, 2, 2], dtype=torch.long), + label_index=torch.tensor([5, 7, 8], dtype=torch.long), + num_anchors=3, + ) + as_dict = labels.to_dict() + self.assertEqual(set(as_dict.keys()), {0, 1, 2}) + torch.testing.assert_close(as_dict[0], torch.tensor([5], dtype=torch.long)) + torch.testing.assert_close(as_dict[1], torch.empty(0, dtype=torch.long)) + torch.testing.assert_close(as_dict[2], torch.tensor([7, 8], dtype=torch.long)) + + def test_to_dict_all_empty(self) -> None: + # No pairs at all: every anchor still gets an empty tensor. + labels = AnchorLabels( + anchor_index=torch.empty(0, dtype=torch.long), + label_index=torch.empty(0, dtype=torch.long), + num_anchors=2, + ) + as_dict = labels.to_dict() + self.assertEqual(set(as_dict.keys()), {0, 1}) + torch.testing.assert_close(as_dict[0], torch.empty(0, dtype=torch.long)) + torch.testing.assert_close(as_dict[1], torch.empty(0, dtype=torch.long)) + + +class EdgeListSetLabelsTest(TestCase): + @parameterized.expand( + [ + # Sorted node map -> exact tensors are pinned. Covers present labels, + # a fully-padded (empty) anchor, and an absent global id (also empty). + param( + "sorted_present_empty_and_padded", + node_map={_STORY: torch.tensor([10, 11, 12, 13, 14, 15, 16, 17])}, + positives={ + _pos(_USER_TO_STORY): torch.tensor( + [[15, -1], [15, 16], [-1, -1], [99, -1]], dtype=torch.long + ) + }, + negatives={}, + expected_positive={ + _USER_TO_STORY: { + 0: [5], + 1: [5, 6], + 2: [], + 3: [], + } + }, + expected_negative={}, + ), + # Duplicate label columns must keep multiplicity: a node matching two + # identical columns appears twice in its anchor's labels. + param( + "duplicate_label_columns_keep_multiplicity", + node_map={_STORY: torch.tensor([10, 11, 12, 13, 14, 15])}, + positives={ + _pos(_USER_TO_STORY): torch.tensor( + [[15, 15], [11, 11]], dtype=torch.long + ) + }, + negatives={}, + expected_positive={_USER_TO_STORY: {0: [5, 5], 1: [1, 1]}}, + expected_negative={}, + ), + # Real homogeneous keying: the loader keys the node map by + # DEFAULT_HOMOGENEOUS_NODE_TYPE and uses DEFAULT_HOMOGENEOUS_EDGE_TYPE. + # Exercise that exact keying, not just a custom NodeType. Node map is + # unsorted here, so we compare per-anchor sets. + param( + "default_homogeneous_keying", + node_map={ + DEFAULT_HOMOGENEOUS_NODE_TYPE: torch.tensor([20, 10, 30, 11, 15]) + }, + positives={ + message_passing_to_positive_label( + DEFAULT_HOMOGENEOUS_EDGE_TYPE + ): torch.tensor([[30, 10], [-1, -1]], dtype=torch.long) + }, + negatives={}, + # g30->local2, g10->local1 ; second anchor fully padded. + expected_positive={DEFAULT_HOMOGENEOUS_EDGE_TYPE: {0: [2, 1], 1: []}}, + expected_negative={}, + ), + # Negatives travel the same path and must remap independently of the + # positives. Sorted node map -> exact sets are unambiguous. + param( + "with_negatives", + node_map={_STORY: torch.tensor([10, 11, 12, 13, 14, 15, 16, 17])}, + positives={ + _pos(_USER_TO_STORY): torch.tensor([[15], [16]], dtype=torch.long) + }, + negatives={ + _neg(_USER_TO_STORY): torch.tensor( + [[13, 16], [17, -1]], dtype=torch.long + ) + }, + expected_positive={_USER_TO_STORY: {0: [5], 1: [6]}}, + expected_negative={_USER_TO_STORY: {0: [3, 6], 1: [7]}}, + ), + # Multiple supervision edge types: each gets its own key and remaps + # against its own supervision node type's map. + param( + "multi_edge_type", + node_map={ + _A: torch.tensor([10]), + _B: torch.tensor([11, 12, 13, 14, 20, 21]), + _C: torch.tensor([20, 21, 22, 23]), + }, + positives={ + _pos(_A_TO_B): torch.tensor([[13, 14]], dtype=torch.long), + _pos(_A_TO_C): torch.tensor([[22, 23]], dtype=torch.long), + }, + negatives={}, + expected_positive={ + _A_TO_B: {0: [2, 3]}, # g13->local2, g14->local3 in _B + _A_TO_C: {0: [2, 3]}, # g22->local2, g23->local3 in _C + }, + expected_negative={}, + ), + # Multiple edge types WITH negatives on each. + param( + "multi_edge_type_with_negatives", + node_map={ + _A: torch.tensor([10]), + _B: torch.tensor([11, 12, 13, 14, 15, 16]), + _C: torch.tensor([20, 21, 22, 23, 24, 25]), + }, + positives={ + _pos(_A_TO_B): torch.tensor([[13, 14]], dtype=torch.long), + _pos(_A_TO_C): torch.tensor([[22, 23]], dtype=torch.long), + }, + negatives={ + _neg(_A_TO_B): torch.tensor([[15, 16]], dtype=torch.long), + _neg(_A_TO_C): torch.tensor([[24, 25]], dtype=torch.long), + }, + expected_positive={ + _A_TO_B: {0: [2, 3]}, + _A_TO_C: {0: [2, 3]}, + }, + expected_negative={ + _A_TO_B: {0: [4, 5]}, # g15->local4, g16->local5 in _B + _A_TO_C: {0: [4, 5]}, # g24->local4, g25->local5 in _C + }, + ), + # Every anchor empty (one fully padded, one with only absent ids): + # the edge type still appears, with an empty tensor per anchor. + param( + "all_anchors_empty", + node_map={_STORY: torch.tensor([10, 11, 12])}, + positives={ + _pos(_USER_TO_STORY): torch.tensor( + [[-1, -1], [99, 98]], dtype=torch.long + ) + }, + negatives={}, + expected_positive={_USER_TO_STORY: {0: [], 1: []}}, + expected_negative={}, + ), + ] + ) + def test_to_dict_matches_constructed_expected( + self, + _, + node_map: dict[NodeType, torch.Tensor], + positives: dict[EdgeType, torch.Tensor], + negatives: dict[EdgeType, torch.Tensor], + expected_positive: dict[PyGEdgeType, dict[int, list[int]]], + expected_negative: dict[PyGEdgeType, dict[int, list[int]]], + ) -> None: + pos, neg = edge_list_set_labels( + node_local_to_global_by_type=node_map, + positive_labels_by_edge_type=positives, + negative_labels_by_edge_type=negatives, + to_device=_CPU, + ) + for labels in pos.values(): + self.assertIsInstance(labels, AnchorLabels) + for labels in neg.values(): + self.assertIsInstance(labels, AnchorLabels) + expected_pos_tensors = { + et: {a: torch.tensor(v, dtype=torch.long) for a, v in inner.items()} + for et, inner in expected_positive.items() + } + expected_neg_tensors = { + et: {a: torch.tensor(v, dtype=torch.long) for a, v in inner.items()} + for et, inner in expected_negative.items() + } + _assert_dict_sets_equal( + {et: labels.to_dict() for et, labels in pos.items()}, + expected_pos_tensors, + ) + _assert_dict_sets_equal( + {et: labels.to_dict() for et, labels in neg.items()}, + expected_neg_tensors, + ) + + def test_sorted_node_map_pins_exact_edge_list_tensors(self) -> None: + # With a sorted node map column-visit order coincides with ascending-local + # order, so the kernel's AnchorLabels tensors are fully determined: pin + # them exactly. This locks the (anchor_index, label_index) co-indexing the + # loss relies on. Covers present labels, a multi-label anchor, and a + # fully-padded (empty) anchor. + node_map = {_STORY: torch.tensor([10, 11, 12, 13, 14, 15, 16, 17])} + positives = { + _pos(_USER_TO_STORY): torch.tensor( + [[15, -1], [15, 16], [-1, -1]], dtype=torch.long + ) + } + pos, _ = edge_list_set_labels( + node_local_to_global_by_type=node_map, + positive_labels_by_edge_type=positives, + negative_labels_by_edge_type={}, + to_device=_CPU, + ) + labels = pos[_USER_TO_STORY] + self.assertEqual(labels.num_anchors, 3) + torch.testing.assert_close( + labels.anchor_index, torch.tensor([0, 1, 1], dtype=torch.long) + ) + torch.testing.assert_close( + labels.label_index, torch.tensor([5, 5, 6], dtype=torch.long) + ) + + def test_unsorted_node_map_correct_membership(self) -> None: + # The node map is UNSORTED, so torch.sort yields a non-identity sort_perm + # and the kernel must map sorted positions back through it to recover real + # local indices. We assert the per-anchor SET (within-anchor order is + # unspecified by contract): a regression that emitted sorted-position + # indices instead of sort_perm-mapped locals would produce {1, 3} here and + # fail, so this proves the local indices are real node indices. + # node = [15, 10, 16, 11] -> g15=local0, g10=local1, g16=local2, g11=local3. + # Labels [16, 15] -> SET {local0, local2} = {0, 2}. + node_map = {_STORY: torch.tensor([15, 10, 16, 11])} + positives = {_pos(_USER_TO_STORY): torch.tensor([[16, 15]], dtype=torch.long)} + pos, _ = edge_list_set_labels( + node_local_to_global_by_type=node_map, + positive_labels_by_edge_type=positives, + negative_labels_by_edge_type={}, + to_device=_CPU, + ) + self.assertEqual(sorted(pos[_USER_TO_STORY].to_dict()[0].tolist()), [0, 2]) + + def test_zero_anchor_tensor_yields_no_edge_type_key(self) -> None: + # A zero-anchor label tensor means there were no anchors for that edge + # type this batch, so it must NOT appear in the output at all (not as an + # empty entry). + node_map = {_STORY: torch.tensor([10, 11, 12])} + positives = {_pos(_USER_TO_STORY): torch.empty((0, 0), dtype=torch.long)} + pos, neg = edge_list_set_labels( + node_local_to_global_by_type=node_map, + positive_labels_by_edge_type=positives, + negative_labels_by_edge_type={}, + to_device=_CPU, + ) + self.assertEqual(pos, {}) + self.assertEqual(neg, {}) + + def test_device_placement_cpu(self) -> None: + # The output index tensors must land on the requested device. (CPU here; + # the CUDA counterpart lives in label_remap_cuda_device_test.py.) + node_map = {_STORY: torch.tensor([10, 11, 12, 13, 14, 15])} + positives = {_pos(_USER_TO_STORY): torch.tensor([[15], [11]], dtype=torch.long)} + pos, _ = edge_list_set_labels( + node_local_to_global_by_type=node_map, + positive_labels_by_edge_type=positives, + negative_labels_by_edge_type={}, + to_device=_CPU, + ) + labels = pos[_USER_TO_STORY] + self.assertEqual(labels.anchor_index.device.type, "cpu") + self.assertEqual(labels.label_index.device.type, "cpu") + self.assertEqual(labels.anchor_index.dtype, torch.long) + self.assertEqual(labels.label_index.dtype, torch.long) + + def test_duplicate_node_map_raises(self) -> None: + # The membership lookup requires unique global ids; a duplicate would + # silently drop a local index and corrupt the loss, so the kernel raises + # ``ValueError``. GiGL node maps are unique by construction, so this only + # guards misuse of the public kernel. + node_map = {_STORY: torch.tensor([10, 10, 11])} + positives = {_pos(_USER_TO_STORY): torch.tensor([[10, 11]], dtype=torch.long)} + with self.assertRaises(ValueError): + edge_list_set_labels( + node_local_to_global_by_type=node_map, + positive_labels_by_edge_type=positives, + negative_labels_by_edge_type={}, + to_device=_CPU, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/distributed/label_remap_cuda_device_test.py b/tests/unit/distributed/label_remap_cuda_device_test.py new file mode 100644 index 000000000..4a98a3d60 --- /dev/null +++ b/tests/unit/distributed/label_remap_cuda_device_test.py @@ -0,0 +1,84 @@ +"""CUDA device-placement regression test for the ABLP label-remap kernel. + +``edge_list_set_labels`` builds an internal ``anchor_of_entry`` index and then +selects it with a mask derived from the input ``label_tensor``. If that index is +created on CPU while ``label_tensor`` is on GPU, the masked select raises +``"indices should be either on cpu or on the same device as the indexed +tensor"``. CPU-only unit tests cannot observe this, so the bug only surfaces on a +real GPU training run. + +This test runs the kernel with all inputs on CUDA and asserts the result matches +the CPU result. It is skipped when no GPU is present (e.g. CPU CI); run it on a +CUDA host to guard the device placement. +""" + +import unittest + +import torch + +from gigl.distributed.dist_ablp_neighborloader import edge_list_set_labels +from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation +from gigl.types.graph import message_passing_to_positive_label +from tests.test_assets.test_case import TestCase + +_USER = NodeType("user") +_STORY = NodeType("story") +_USER_TO_STORY = EdgeType(_USER, Relation("to"), _STORY) + + +def _inputs(device: torch.device): + """A small case exercising every on-device code path. + + The node map is UNSORTED (so ``torch.sort`` yields a non-identity + ``sort_perm`` -- the gather through it must run on-device). Anchor 0 has a + DUPLICATE label column ([15, 15]) to exercise duplicate handling and + verify multiplicity is preserved (local 0 appears twice). Local layout: + g15->0, g10->1, g16->2, g11->3, g12->4. Anchor rows: [15, 15] -> local 0 + twice; [16, -1] -> local 2; [-1, -1] -> empty. + """ + node_map = {_STORY: torch.tensor([15, 10, 16, 11, 12], device=device)} + positives = { + message_passing_to_positive_label(_USER_TO_STORY): torch.tensor( + [[15, 15], [16, -1], [-1, -1]], dtype=torch.long, device=device + ) + } + return node_map, positives + + +@unittest.skipUnless(torch.cuda.is_available(), "requires a CUDA device") +class LabelRemapCudaDeviceTest(TestCase): + def test_edge_list_set_labels_cuda_matches_cpu(self) -> None: + cpu_node, cpu_pos = _inputs(torch.device("cpu")) + expected_pos, _ = edge_list_set_labels( + node_local_to_global_by_type=cpu_node, + positive_labels_by_edge_type=cpu_pos, + negative_labels_by_edge_type={}, + to_device=torch.device("cpu"), + ) + + cuda = torch.device("cuda") + cuda_node, cuda_pos = _inputs(cuda) + # Must not raise the CPU/GPU index mismatch. + got_pos, _ = edge_list_set_labels( + node_local_to_global_by_type=cuda_node, + positive_labels_by_edge_type=cuda_pos, + negative_labels_by_edge_type={}, + to_device=cuda, + ) + + self.assertEqual(set(got_pos.keys()), set(expected_pos.keys())) + for edge_type, expected_labels in expected_pos.items(): + got_labels = got_pos[edge_type] + # The output tensors must land on the requested CUDA device. + self.assertEqual(got_labels.anchor_index.device.type, "cuda") + self.assertEqual(got_labels.label_index.device.type, "cuda") + expected_dict = expected_labels.to_dict() + got_dict = got_labels.to_dict() + self.assertEqual(set(got_dict.keys()), set(expected_dict.keys())) + for anchor, expected_tensor in expected_dict.items(): + # Multiplicity (the duplicate [15, 15] column) must survive too. + torch.testing.assert_close(got_dict[anchor].cpu(), expected_tensor) + + +if __name__ == "__main__": + unittest.main()