Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
a0fe5cb
refactor(ablp): extract per-anchor label loop into _loop_set_labels o…
Jun 25, 2026
00f4cdb
feat(ablp): add vectorized_set_labels kernel, equivalence-tested vs l…
Jun 25, 2026
fd3fda5
feat(ablp): add AnchorLabels edge-list container + edge_list_set_labe…
Jun 25, 2026
3101506
test(ablp): add CUDA device-placement regression for label-remap kernels
Jun 25, 2026
d51bd7a
feat(ablp): vectorized label remap always + use_list_output ctor flag
Jun 25, 2026
efa446d
docs(ablp): document vectorized remap, use_list_output, and AnchorLabels
Jun 25, 2026
8c91d66
docs(examples): consume AnchorLabels edge-list in homogeneous training
Jun 25, 2026
1f1c069
docs(examples): consume AnchorLabels edge-list in heterogeneous training
Jun 25, 2026
f60df78
docs(examples): consume AnchorLabels edge-list in graph-store homogen…
Jun 25, 2026
ee4adeb
docs(examples): consume AnchorLabels edge-list in graph-store heterog…
Jun 25, 2026
dab9f8f
feat(distributed): export AnchorLabels from gigl.distributed
Jun 25, 2026
bd97d03
style(ablp): apply ruff formatting to ABLP label-output examples and …
Jun 25, 2026
e6bd5b9
refactor(ablp): drop order-reproduction argsort; add readability refa…
Jun 25, 2026
594ddba
docs(examples): clarify AnchorLabels read is loss-equivalent, not ord…
Jun 25, 2026
641b24e
refactor(ablp): single AnchorLabels label-remap kernel + dict view
Jun 26, 2026
5a58a53
docs(ablp): why-first docstrings, worked examples, tensor-dim annotat…
Jun 26, 2026
b70e626
docs(ablp): worked example for use_list_output=True; harden node-map …
kmonte Jun 29, 2026
aa978ef
docs(ablp): scrub history-relative comments in label-remap code
kmonte Jun 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions examples/link_prediction/graph_store/heterogeneous_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ def _setup_dataloaders(
channel_size=sampling_worker_shared_channel_size,
process_start_gap_seconds=process_start_gap_seconds,
shuffle=shuffle,
# Labels as an AnchorLabels edge-list; see the AnchorLabels class docstring.
use_list_output=True,
)

print(f"---Rank {rank} finished setting up main loader for split={split}")
Expand Down Expand Up @@ -299,16 +301,15 @@ def _compute_loss(
).to(device)
random_negative_batch_size = random_negative_data[labeled_node_type].batch_size

positive_idx: torch.Tensor = torch.cat(list(main_data.y_positive.values())).to(
device
)
repeated_query_node_idx = query_node_idx.repeat_interleave(
torch.tensor([len(v) for v in main_data.y_positive.values()]).to(device)
)
# main_data.y_positive is an AnchorLabels edge-list; read the co-indexed [E]
# label_index and query_node_idx[anchor_index] directly. See _compute_loss in
# examples/link_prediction/homogeneous_training.py (the non-graph-store example)
# for why this equals the historical dict read.
positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) # [E]
repeated_query_node_idx = query_node_idx[
main_data.y_positive.anchor_index.to(device) # [E]
]
if hasattr(main_data, "y_negative"):
hard_negative_idx: torch.Tensor = torch.cat(
list(main_data.y_negative.values())
).to(device)
hard_negative_idx: torch.Tensor = main_data.y_negative.label_index.to(device)
else:
hard_negative_idx = torch.empty(0, dtype=torch.long).to(device)

Expand Down
19 changes: 10 additions & 9 deletions examples/link_prediction/graph_store/homogeneous_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ def _setup_dataloaders(
channel_size=sampling_worker_shared_channel_size,
process_start_gap_seconds=process_start_gap_seconds,
shuffle=shuffle,
# Labels as an AnchorLabels edge-list; see the AnchorLabels class docstring.
use_list_output=True,
)

logger.info(f"---Rank {rank} finished setting up main loader for split={split}")
Expand Down Expand Up @@ -305,16 +307,15 @@ def _compute_loss(
query_node_idx: torch.Tensor = torch.arange(main_data.batch_size).to(device)
random_negative_batch_size = random_negative_data.batch_size

positive_idx: torch.Tensor = torch.cat(list(main_data.y_positive.values())).to(
device
)
repeated_query_node_idx = query_node_idx.repeat_interleave(
torch.tensor([len(v) for v in main_data.y_positive.values()]).to(device)
)
# main_data.y_positive is an AnchorLabels edge-list; read the co-indexed [E]
# label_index and query_node_idx[anchor_index] directly. See _compute_loss in
# examples/link_prediction/homogeneous_training.py (the non-graph-store example,
# not this file) for why this equals the historical dict read.
positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) # [E]
repeated_query_node_idx = query_node_idx[
main_data.y_positive.anchor_index.to(device) # [E]
]
if hasattr(main_data, "y_negative"):
hard_negative_idx: torch.Tensor = torch.cat(
list(main_data.y_negative.values())
).to(device)
hard_negative_idx: torch.Tensor = main_data.y_negative.label_index.to(device)
else:
hard_negative_idx = torch.empty(0, dtype=torch.long).to(device)

Expand Down
21 changes: 10 additions & 11 deletions examples/link_prediction/heterogeneous_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ def _setup_dataloaders(
# This is done so that each process on the current machine which initializes a `main_loader` doesn't compete for memory, causing potential OOM
process_start_gap_seconds=process_start_gap_seconds,
shuffle=shuffle,
# Labels as an AnchorLabels edge-list; see the AnchorLabels class docstring.
use_list_output=True,
)

logger.info(f"---Rank {rank} finished setting up main loader")
Expand Down Expand Up @@ -223,18 +225,15 @@ def _compute_loss(
).to(device)
random_negative_batch_size = random_negative_data[labeled_node_type].batch_size

# main_data.y_positive is a dict[query_node_local_index: int, labeled_node_local_indices: torch.Tensor], even in the heterogeneous setting.
positive_idx: torch.Tensor = torch.cat(list(main_data.y_positive.values())).to(
device
)
# We also extract a repeated query node index tensor which upsamples each query node based on the number of positives it has
repeated_query_node_idx = query_node_idx.repeat_interleave(
torch.tensor([len(v) for v in main_data.y_positive.values()]).to(device)
)
# main_data.y_positive is an AnchorLabels edge-list; read the co-indexed [E]
# label_index and query_node_idx[anchor_index] directly. See _compute_loss in
# examples/link_prediction/homogeneous_training.py for why this equals the
# historical dict read.
positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) # [E]
repeated_query_node_idx = query_node_idx[
main_data.y_positive.anchor_index.to(device) # [E]
]
if hasattr(main_data, "y_negative"):
hard_negative_idx: torch.Tensor = torch.cat(
list(main_data.y_negative.values())
).to(device)
hard_negative_idx: torch.Tensor = main_data.y_negative.label_index.to(device)
else:
hard_negative_idx = torch.empty(0, dtype=torch.long).to(device)

Expand Down
27 changes: 16 additions & 11 deletions examples/link_prediction/homogeneous_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ def _setup_dataloaders(
# This is done so that each process on the current machine which initializes a `main_loader` doesn't compete for memory, causing potential OOM
process_start_gap_seconds=process_start_gap_seconds,
shuffle=shuffle,
# Labels as an AnchorLabels edge-list; see the AnchorLabels class docstring.
use_list_output=True,
)

logger.info(f"---Rank {rank} finished setting up main loader")
Expand Down Expand Up @@ -190,18 +192,21 @@ def _compute_loss(
query_node_idx: torch.Tensor = torch.arange(main_data.batch_size).to(device)
random_negative_batch_size = random_negative_data.batch_size

# main_data.y_positive is a dict[query_node_local_index: int, labeled_node_local_indices: torch.Tensor]
positive_idx: torch.Tensor = torch.cat(list(main_data.y_positive.values())).to(
device
)
# We also extract a repeated query node index tensor which upsamples each query node based on the number of positives it has
repeated_query_node_idx = query_node_idx.repeat_interleave(
torch.tensor([len(v) for v in main_data.y_positive.values()]).to(device)
)
# main_data.y_positive is an AnchorLabels edge-list (use_list_output=True), two
# co-indexed [E] tensors: label_index holds the local label node per
# (anchor, label) pair, anchor_index the matching local anchor row. Pairs are
# grouped by anchor, so reading label_index/anchor_index directly yields the same
# (query, label) pairs as the historical dict read (torch.cat(list(values())) +
# repeat_interleave over per-anchor lengths). The within-anchor label order may
# differ, but the contrastive loss is order-invariant, so the result is
# equivalent. (This is the canonical explanation; the other link-prediction
# examples point here rather than restating it.)
positive_idx: torch.Tensor = main_data.y_positive.label_index.to(device) # [E]
repeated_query_node_idx = query_node_idx[
main_data.y_positive.anchor_index.to(device) # [E]
]
if hasattr(main_data, "y_negative"):
hard_negative_idx: torch.Tensor = torch.cat(
list(main_data.y_negative.values())
).to(device)
hard_negative_idx: torch.Tensor = main_data.y_negative.label_index.to(device)
else:
hard_negative_idx = torch.empty(0, dtype=torch.long).to(device)

Expand Down
3 changes: 2 additions & 1 deletion gigl/distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

__all__ = [
"AnchorLabels",
"DistABLPLoader",
"DistNeighborLoader",
"DistDataset",
Expand All @@ -17,7 +18,7 @@
build_dataset,
build_dataset_from_task_config_uri,
)
from gigl.distributed.dist_ablp_neighborloader import DistABLPLoader
from gigl.distributed.dist_ablp_neighborloader import AnchorLabels, DistABLPLoader
from gigl.distributed.dist_context import DistributedContext
from gigl.distributed.dist_dataset import DistDataset
from gigl.distributed.dist_partitioner import DistPartitioner
Expand Down
Loading