DistABLPLoader: vectorized label remap + optional AnchorLabels edge-list output#681
Draft
kmontemayor2-sc wants to merge 18 commits into
Draft
DistABLPLoader: vectorized label remap + optional AnchorLabels edge-list output#681kmontemayor2-sc wants to merge 18 commits into
kmontemayor2-sc wants to merge 18 commits into
Conversation
…racle Factor the inline per-anchor label-remap loop in DistABLPLoader._set_labels into a module-level function _loop_set_labels. The new function is a behavior-preserving extraction: _set_labels delegates to it, producing identical output. _loop_set_labels will serve as the equivalence oracle for the vectorized kernel added in the next task. Also imports PADDING_NODE from gigl.utils.data_splitters (used by the vectorized kernel in the next task) and adds the contract test file tests/unit/distributed/vectorized_set_labels_test.py. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…oop oracle Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ls kernel Add frozen dataclass AnchorLabels (anchor_index, label_index, num_anchors) with to_dict() bridge; thin wrapper _remap_one_label_tensor_edge_list over the shared _membership_remap; and edge_list_set_labels driver. Proves to_dict() reproduces _loop_set_labels bit-for-bit via parametrized equivalence tests (11 new tests, all classes pass). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…eous training Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…eneous training Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…test Line-wrapping only (ruff format) for the edge-list label reads in the four link-prediction training examples and the loader equivalence test. No logic change. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
292f4d4 to
bd97d03
Compare
…ctors
**Behavioral change:** Drop the composite-key stable argsort from
`_membership_remap` (the `composite_key = anchor_kept * (num_nodes + 1) +
local_idx` / `torch.argsort(composite_key, stable=True)` block). The pair
stream is now emitted in column-visit (row-major masked flatten) order rather
than ascending-local-index (torch.nonzero) order.
**Why this is safe:** `RetrievalLoss` (`gigl/nn/loss.py`) is
`CrossEntropyLoss(reduction="sum")` over a diagonal-targeted score matrix that
masks collisions by id VALUE, not position. It is invariant to any joint
permutation of `(anchor_index, label_index)` pairs. The only constraint is
co-indexing (pair k stays intact) and per-anchor grouping, both of which are
preserved. Within-anchor label order was an implementation artifact of the loop
oracle -- never a loss requirement.
**Why the dict path is unaffected:** `anchor_of_entry` is built as
`arange(N).repeat_interleave(M)` (row-major), so the masked flatten is already
non-decreasing in anchor_index. `bincount`/`split` requires only contiguous
grouping by anchor, which row-major flatten guarantees without any argsort.
**Readability refactors in `_membership_remap`:**
- Rename terse tensors: `valid` -> `is_present`, `found` -> `is_exact_match`,
`positions` -> `sorted_positions`, `local_idx` -> `local_index`,
`anchor_kept` -> `anchor_of_matched`
- Add numbered step comments explaining the searchsorted membership lookup
**Readability refactors in outer kernels:**
- Add `_remap_group` helper to collapse the ~40-line duplicated positive/
negative per-edge-type loops in both `vectorized_set_labels` and
`edge_list_set_labels` (uses TypeVar for generic return type)
- `_sorted_for` closure pattern preserved, inlined per-kernel (memoizes
torch.sort across pos/neg edge types of the same node type)
**Other improvements:**
- Add doctest to `AnchorLabels` showing a 2-anchor case + `to_dict()` round-trip
- Update `AnchorLabels` docstring: document column-visit order and
loss-permutation-invariance rationale
- Update all docstrings to drop "bit-for-bit"/"torch.nonzero order" language
**Test contract relaxation (tests remain non-vacuous):**
- `vectorized_set_labels_test.py` / `edge_list_set_labels_test.py`:
`_assert_label_dicts_equal` -> `_assert_label_dicts_set_equal` (uses
`sorted()` per anchor; still catches membership errors + multiplicity)
- `test_unsorted_node_map_exact_order` -> `test_unsorted_node_map_correct_membership`:
asserts SET {0, 2} instead of sequence [0, 2]; docstring explains column
vs ascending-local order difference
- `RemapOneEdgeListTest.test_unsorted_node_map_nontrivial_sort_perm` ->
`test_unsorted_node_map_correct_membership`: pins column order [2, 0] and
explains it is SET-equal to [0, 2]; verifies sort_perm mapping is correct
- `dist_ablp_neighborloader_test.py`: `_ordered_global_pairs` ->
`_global_pair_set` (sorts within anchor); `_collect_homogeneous_labels`
docstring updated; the in-process exact-tensor assertion
(`label_index == cat(dict values)`) is preserved -- both paths draw from the
same `_membership_remap` pair stream so they remain identical
- CUDA device test docstring: remove "stable-argsort tie-break" reference;
keep the duplicate [15,15] row (still tests duplicate handling + device placement)
Device fix preserved: `anchor_of_entry` is still built on `label_tensor.device`.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…er-identical After dropping the order-reproduction argsort, per-anchor label order is column-visit order, not the loop's nonzero order. The (query, label) pairs are unchanged and the contrastive loss is order-invariant, so the read is equivalent for training; the comments now say so instead of implying byte-identical order. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
486e670 to
d42d0df
Compare
GiGL main resolved ABLP labels with a per-anchor Python loop. This replaces it with one vectorized kernel and a single output type, plus a thin dict view for backward compatibility. - Resolve labels with a sorted-membership join (_membership_remap: sort the node map once, searchsorted the label ids, keep exact matches) instead of the O(N_anchors * M * N_nodes) per-anchor loop. Delete _loop_set_labels (no production callers) and the redundant dict-producing kernel. - Collapse the remap to two functions, _membership_remap and edge_list_set_labels; the dict path is just AnchorLabels.to_dict(), selected by use_list_output on DistABLPLoader (default False keeps the ragged dict). Drop the _remap_group callback indirection, the _LabelT TypeVar, and the vestigial supervision_edge_types parity param. - AnchorLabels stores labels as two parallel (anchor_index, label_index) tensors so the loss can index them directly; within-anchor order is unspecified because the ABLP contrastive loss is order-invariant over the pairs. - Make the __debug__ unique-node-map check a cheap adjacent-difference test on the already-sorted map (it ran torch.unique every batch, and GiGL is not launched with -O). - Tests assert against constructed expected values and per-anchor label SETS, not within-anchor order. Examples consume the edge-list directly. Behavior-preserving: per-anchor label sets are unchanged. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
d42d0df to
641b24e
Compare
…ions Comments and docstrings only -- no behavior change (verified: the AST with docstrings stripped is identical to the prior commit for every touched file). - Docstrings lead with why over what; the membership-remap algorithm is shown as a small worked example (node map + [N_anchors, M] padded label tensor -> pairs) with labeled steps instead of dense prose. - Inline comments reference those docstring steps rather than floating free. - Tensor dimensions annotated throughout (N_anchors / M / N_nodes / K / E), including the K (non-padding candidates) vs E (matched pairs) distinction. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…uniqueness guard - Add a use_list_output=True worked example to DistABLPLoader.__init__ that mirrors the use_list_output=False example (same graph, AnchorLabels values). - Document AnchorLabels.to_dict()'s non-decreasing anchor_index precondition. - Promote the _membership_remap duplicate-node-map guard from a __debug__ assert to an always-on ValueError so it still fires under `python -O`; update the test to expect ValueError. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Comments should describe the code as it is, not as a delta from a prior or never-committed version. Reword/remove comments that only parse if the reader knows the code's development history: - Drop the "Always-on (not __debug__) ... python -O" framing on the node-map uniqueness guard; keep the precondition + empty-slice safety rationale. - Fix a stale _membership_remap docstring line still describing a __debug__ assertion (it is now an always-on ValueError). - Remove "with no argsort" / "no argsort needed" (presupposed the dropped order-reproduction argsort); state the grouped-by-anchor property directly. - Reword "a sorted-membership join rather than a per-anchor scan ... trades a broadcast-compare for a search" to describe the algorithm and its complexity as-is. - Drop "rather than maintaining a second kernel" in _set_labels. - Same scrub in edge_list_set_labels_test.py (module docstring + guard test). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Reworks how
DistABLPLoaderremaps Anchor-Based Link Prediction (ABLP) labelsfrom global node ids to subgraph-local indices, and adds an optional edge-list
output format for those labels.
Internally, we see small but
Vectorized label remap. Label remapping now runs as a single
sorted-membership join (
searchsortedover the node map) instead of aper-anchor Python loop. The result is loss-equivalent to the previous output
(same
(anchor, label)pairs; within-anchor order is unspecified, and theABLP contrastive loss is order-invariant). We observe meaningful sampling/
collation performance improvements at production scale.
AnchorLabelsedge-list container +use_list_outputflag.DistABLPLoader(..., use_list_output=True)returns labels as anAnchorLabelsedge list (two co-indexed[E]tensors,anchor_index/label_index) that the loss can index directly with nopadding or per-anchor Python loop.
use_list_output=False(default) preservesthe existing ragged
dict[int, torch.Tensor]output, so this is backwardcompatible.
AnchorLabels.to_dict()recovers the dict form.Public API.
AnchorLabelsis exported fromgigl.distributed.Examples. The link-prediction training examples (homogeneous +
heterogeneous, colocated + graph-store) are updated to consume the
AnchorLabelsedge-list output.Tests
covering empty / fully-padded / duplicate-label / multi-edge-type cases.
use_list_output=Truevs=Falseproduce equivalent labels.ValueError.Notes
DistABLPLoaderdocstring andAnchorLabelsshape docs.