Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
b83f760
feat(distributed): add GIGL_COLLATE_IMPL flag resolver for collate di…
Jun 23, 2026
64bb885
refactor(distributed): extract ablp label remap into _loop_set_labels…
Jun 23, 2026
894af30
feat(distributed): vectorize ablp label remap with searchsorted
Jun 23, 2026
b52dc0b
test(distributed): assert python vs vectorized collate label equivalence
Jun 23, 2026
226fe2b
chore(distributed): satisfy type/format gates for vectorized label remap
Jun 23, 2026
3586c08
test: add collate-equivalence comparison helper (homogeneous)
Jun 24, 2026
dd931bd
test: complete collate-equivalence helper (heterogeneous)
Jun 24, 2026
670bb43
test(distributed): add cross-impl batch-capture driver (D3)
Jun 24, 2026
6b712d0
fix(test): D3 driver — use COLLATE_IMPLS default, required test names…
Jun 24, 2026
4ff2135
test: ABLP homogeneous collate equivalence across GIGL_COLLATE_IMPL
Jun 24, 2026
52e03b7
test: ABLP heterogeneous + edge_dir in/out collate equivalence
Jun 24, 2026
be8cc46
Collate core: scaffold gigl_core.collate_core pybind11 extension
Jun 24, 2026
57e1cd3
Collate core: per-hop count padding helpers + C++ test target
Jun 24, 2026
5109d71
Collate core: homogeneous collate component-tensor builder
Jun 24, 2026
a3e77df
Collate core: heterogeneous collate (dict build, edge swap, padding)
Jun 24, 2026
59bc160
Collate core: pybind11 bindings for homogeneous/heterogeneous collate
Jun 24, 2026
00f8204
Collate core: dispatcher flag resolution + PyG assembly shim
Jun 24, 2026
e82b382
Collate core: dispatcher C++-path collate entry functions
Jun 24, 2026
e940c9f
Collate core: route both loaders' GLT body through GIGL_COLLATE_IMPL …
Jun 24, 2026
8d53d8b
Collate core: python-vs-cpp output equivalence tests (CORA/DBLP, edge…
Jun 24, 2026
3614073
feat(distributed): opt-in collate-vs-recv timing accumulator on BaseD…
Jun 24, 2026
826c967
feat(launcher): forward GIGL_COLLATE_IMPL env into Vertex AI worker spec
Jun 24, 2026
b7dd1fc
fix(distributed): lazy-import collate_core to avoid crash when extens…
Jun 24, 2026
dcf7e94
feat(ablp): add GIGL_ABLP_LABEL_FORMAT selector (dict|edge_list)
Jun 24, 2026
1e4d529
feat(ablp): add AnchorLabels dense edge-list container + per-tensor r…
Jun 24, 2026
8f2fd44
feat(ablp): add edge_list_set_labels kernel, parity-tested vs loop or…
Jun 24, 2026
509a1c7
feat(ablp): route GIGL_ABLP_LABEL_FORMAT=edge_list through _set_labels
Jun 24, 2026
6602187
feat(ablp): forward GIGL_ABLP_LABEL_FORMAT to sampling workers
Jun 24, 2026
18b0215
fix(orchestration): propagate loader-selection env vars to component …
Jun 24, 2026
1960cd1
fix(ablp): place anchor_of_entry on the label tensor's device
Jun 25, 2026
e320878
docs(distributed): tidy collate/label comments and docstrings
Jun 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 187 additions & 0 deletions gigl-core/core/collation/collate_core.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
#include "collate_core.h"

#include <torch/nn/functional/padding.h>

namespace {

// Fetch the tensor stored under `key`. Yields an empty optional when the message
// has no such entry, which is the C++ counterpart of Python's `key in msg` test.
std::optional<torch::Tensor> tryGet(const std::unordered_map<std::string, torch::Tensor>& msg, const std::string& key) {
    const auto entry = msg.find(key);
    return entry != msg.end() ? std::optional<torch::Tensor>(entry->second) : std::nullopt;
}

} // namespace

namespace gigl {
namespace collation {

// Zero-extends a 1-D per-hop count vector on the right until it has `targetLen`
// entries. A vector that already has `targetLen` entries is handed back untouched.
// Fails a TORCH_CHECK when `counts` is not 1-D or is longer than `targetLen`.
torch::Tensor padCount(const torch::Tensor& counts, int64_t targetLen) {
    namespace F = torch::nn::functional;

    TORCH_CHECK(counts.dim() == 1, "per-hop count tensor must be 1-D");
    const int64_t have = counts.size(0);
    TORCH_CHECK(have <= targetLen, "per-hop count length exceeds target");

    // Non-negative thanks to the check above.
    const int64_t missing = targetLen - have;
    if (missing != 0) {
        return F::pad(counts, F::PadFuncOptions({0, missing}));
    }
    return counts;
}

// Produces an all-zero 1-D tensor with `targetLen` entries; dtype and device come
// from `options`.
torch::Tensor zeroCount(int64_t targetLen, const torch::TensorOptions& options) {
    const std::vector<int64_t> shape{targetLen};
    return torch::zeros(shape, options);
}

// Packs the sampler tensors of a homogeneous batch into a HomogeneousCollateResult.
// `rows` and `cols` are stacked, in that order, into the edge index; every other
// argument is copied into the field of the same meaning without modification
// (absent optionals stay nullopt).
HomogeneousCollateResult collateHomogeneous(const torch::Tensor& ids,
                                            const torch::Tensor& rows,
                                            const torch::Tensor& cols,
                                            const std::optional<torch::Tensor>& eids,
                                            const std::optional<torch::Tensor>& nfeats,
                                            const std::optional<torch::Tensor>& efeats,
                                            const std::optional<torch::Tensor>& batch,
                                            const std::optional<torch::Tensor>& numSampledNodes,
                                            const std::optional<torch::Tensor>& numSampledEdges) {
    HomogeneousCollateResult out;

    // Graph structure.
    out.edgeIndex = torch::stack({rows, cols});
    out.node = ids;
    out.batch = batch;

    // Optional ids / features.
    out.eid = eids;
    out.x = nfeats;
    out.edgeAttr = efeats;

    // Per-hop counts are forwarded as received (no padding in the homogeneous path).
    out.numSampledNodes = numSampledNodes;
    out.numSampledEdges = numSampledEdges;

    return out;
}

HeterogeneousCollateResult collateHeterogeneous(
const std::unordered_map<std::string, torch::Tensor>& msg,
const std::vector<std::string>& nodeTypes,
const std::vector<std::pair<std::string, EdgeTypeArray>>& edgeTypeStrToRev,
const std::vector<EdgeTypeArray>& reversedEdgeTypes,
const std::string& inputType,
bool hasBatch,
int64_t batchSize) {
HeterogeneousCollateResult result;

// --- nodes (dist_loader.py:356-365) ---
for (const auto& ntype : nodeTypes) {
if (auto ids = tryGet(msg, ntype + ".ids")) {
result.node[ntype] = *ids;
}
if (auto nfeat = tryGet(msg, ntype + ".nfeats")) {
result.x[ntype] = *nfeat;
}
if (auto nsn = tryGet(msg, ntype + ".num_sampled_nodes")) {
result.numSampledNodes[ntype] = *nsn; // padded below
}
}

// --- edges + edge_dir swap (dist_loader.py:367-382) ---
EdgeTypeMap<torch::Tensor> rowDict;
EdgeTypeMap<torch::Tensor> colDict;
for (const auto& [etypeStr, revEtype] : edgeTypeStrToRev) {
auto rows = tryGet(msg, etypeStr + ".rows");
auto cols = tryGet(msg, etypeStr + ".cols");
if (rows && cols) {
// The edge index is reversed: row<-cols, col<-rows.
rowDict[revEtype] = *cols;
colDict[revEtype] = *rows;
}
if (auto eids = tryGet(msg, etypeStr + ".eids")) {
result.edge[revEtype] = *eids;
}
if (auto nse = tryGet(msg, etypeStr + ".num_sampled_edges")) {
result.numSampledEdges[revEtype] = *nse; // padded below
}
if (auto efeat = tryGet(msg, etypeStr + ".efeats")) {
result.edgeAttr[revEtype] = *efeat;
}
}

// --- batch (dist_loader.py:389-405); inputType is the anchor node type ---
// GiGL loaders are NODE sampling; only the {inputType: batch} entry is produced.
// batch_labels (nlabels) are not present for these loaders and are ignored here.
// GLT writes the "{inputType}.batch" key ONLY when output.batch is not None
// (dist_neighbor_sampler.py:781-783); when absent, GLT's NODE branch FALLS BACK to
// node_dict[inputType][:batch_size] (dist_loader.py:397-399). We must reproduce that
// fallback here, so the kernel takes batchSize and slices the anchor node ids.
if (hasBatch) {
if (auto b = tryGet(msg, inputType + ".batch")) {
result.batch[inputType] = *b;
} else {
// Slice the first batchSize anchor node ids (matches node_dict[inputType][:batch_size]).
auto nodeIt = result.node.find(inputType);
TORCH_CHECK(nodeIt != result.node.end(), "batch fallback requires anchor node ids for inputType");
result.batch[inputType] = nodeIt->second.slice(/*dim=*/0, /*start=*/0, /*end=*/batchSize);
}
}

// --- get_edge_index empty-fill (sampler/base.py:294-301) ---
// GLT fills absent edge types with torch.empty((2,0)).to(self.device), where
// self.device == to_device (dist_loader.py:417, sampler/base.py:299). Derive the
// device from any present edge tensor; if NO edges were sampled at all, fall back to
// any present NODE tensor's device (node tensors are on to_device after the dispatcher's
// _move_msg_to_device). A bare CPU default would diverge from GLT on CUDA when a batch
// has zero sampled edges.
torch::TensorOptions edgeOpts = torch::TensorOptions().dtype(torch::kInt64);
bool deviceFound = false;
for (const auto& [et, t] : rowDict) {
edgeOpts = torch::TensorOptions().dtype(torch::kInt64).device(t.device());
deviceFound = true;
break;
}
if (!deviceFound) {
for (const auto& [nt, t] : result.node) {
edgeOpts = torch::TensorOptions().dtype(torch::kInt64).device(t.device());
deviceFound = true;
break;
}
}
for (const auto& revEtype : reversedEdgeTypes) {
auto rIt = rowDict.find(revEtype);
if (rIt != rowDict.end()) {
result.edgeIndex[revEtype] = torch::stack({rIt->second, colDict.at(revEtype)});
} else {
result.edgeIndex[revEtype] = torch::empty({2, 0}, edgeOpts);
}
}

// --- num_sampled_edges padding (transform.py:70-90) ---
int64_t numHops = 0;
for (const auto& [et, t] : result.numSampledEdges) {
numHops = std::max<int64_t>(numHops, t.size(0));
}
for (const auto& revEtype : reversedEdgeTypes) {
auto edgeIndexDevice = result.edgeIndex.at(revEtype).device();
auto countOpts = torch::TensorOptions().dtype(torch::kInt64).device(edgeIndexDevice);
auto it = result.numSampledEdges.find(revEtype);
if (it == result.numSampledEdges.end()) {
result.numSampledEdges[revEtype] = zeroCount(numHops, countOpts);
} else {
result.numSampledEdges[revEtype] = padCount(it->second, numHops);
}
}

// --- num_sampled_nodes padding (transform.py:97-104) ---
// PyG iterates node types present in the sampler output's node dict.
for (const auto& ntype : nodeTypes) {
auto nodeIt = result.node.find(ntype);
if (nodeIt == result.node.end()) {
continue; // node type absent from this batch's node dict
}
auto nodeDevice = nodeIt->second.device();
auto countOpts = torch::TensorOptions().dtype(torch::kInt64).device(nodeDevice);
auto it = result.numSampledNodes.find(ntype);
if (it == result.numSampledNodes.end()) {
result.numSampledNodes[ntype] = zeroCount(numHops + 1, countOpts);
} else {
result.numSampledNodes[ntype] = padCount(it->second, numHops + 1);
}
}

return result;
}

} // namespace collation
} // namespace gigl
111 changes: 111 additions & 0 deletions gigl-core/core/collation/collate_core.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
#pragma once

// Generic collation kernel for distributed neighbor loaders.
//
// Reproduces the per-batch tensor wrangling that a distributed neighbor loader
// performs when turning a flat sampler message (per-type id/feature/edge tensors)
// into the component tensors of a graph batch: per-type dict assembly, the
// edge-direction row/col swap, empty-edge filling, and per-hop count padding.
//
// All input tensors are assumed to already reside on the target device; this
// kernel never issues device transfers. Pure C++ lives here and in collate_core.cpp;
// python_collate_core.cpp handles only Python<->C++ type conversion.

#include <array>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <torch/torch.h>

namespace gigl {
namespace collation {

// Right-pad a 1-D per-hop count tensor with zeros to `targetLen` on its own device.
// Mirrors torch.nn.functional.pad(t, (0, targetLen - t.size(0))).
// Precondition: counts.dim() == 1 and counts.size(0) <= targetLen. The implementation
// enforces both with TORCH_CHECK, so a too-long input fails instead of being cropped.
// When counts already has targetLen entries it is returned unchanged.
torch::Tensor padCount(const torch::Tensor& counts, int64_t targetLen);

// Build a length-`targetLen` zero count vector with the given options (dtype/device).
// Mirrors torch.tensor([0]*targetLen, device=...).
// `targetLen` is forwarded to torch::zeros as-is; no range check is done here.
torch::Tensor zeroCount(int64_t targetLen, const torch::TensorOptions& options);

// Component tensors for a homogeneous graph batch. Optionals are nullopt when the
// corresponding sampler tensor was absent; the Python shim maps these to None.
// Produced by collateHomogeneous, which copies its arguments into these fields and
// builds edgeIndex by stacking its `rows` and `cols` arguments.
struct HomogeneousCollateResult {
torch::Tensor node; // local->global ids
torch::Tensor edgeIndex; // [2, E]; stack([rows, cols])
std::optional<torch::Tensor> eid; // edge ids, when sampled
std::optional<torch::Tensor> x; // node features
std::optional<torch::Tensor> edgeAttr; // edge features
std::optional<torch::Tensor> batch; // forwarded from the `batch` argument
std::optional<torch::Tensor> numSampledNodes; // passed through verbatim
std::optional<torch::Tensor> numSampledEdges; // passed through verbatim
};

// Reproduces the GLT homogeneous collate body + to_data assembly (no padding).
// `rows`/`cols` are stacked verbatim into edgeIndex = stack([rows, cols]); the caller
// performs the edge_dir swap by choosing which sampler tensor to pass as each argument
// (mirrors SamplerOutput(ids, cols, rows, ...) at dist_loader.py:446).
// All tensors are assumed already on the target device; no transfers are issued.
// Every optional argument is copied into the matching result field unchanged, so an
// absent input stays nullopt in the result.
HomogeneousCollateResult collateHomogeneous(
const torch::Tensor& ids,
const torch::Tensor& rows,
const torch::Tensor& cols,
const std::optional<torch::Tensor>& eids,
const std::optional<torch::Tensor>& nfeats,
const std::optional<torch::Tensor>& efeats,
const std::optional<torch::Tensor>& batch,
const std::optional<torch::Tensor>& numSampledNodes,
const std::optional<torch::Tensor>& numSampledEdges);

// Edge type expressed as three strings (presumably source type, relation,
// destination type, as in PyG edge-type tuples - confirm against the caller).
using EdgeTypeArray = std::array<std::string, 3>;

// Hash functor for EdgeTypeArray keys: an FNV-1a style fold over the bytes of each
// field, with a separator step after every field.
//
// The separator is multiplied in as well as XOR-ed. XOR alone is self-inverse, so
// two separators around an empty field would cancel and, for example,
// ("ab", "", "c") and ("abc", "", "") would hash to the same value.
struct EdgeTypeArrayHash {
    std::size_t operator()(const EdgeTypeArray& e) const noexcept {
        constexpr std::size_t kPrime = static_cast<std::size_t>(1099511628211ULL);  // FNV prime
        constexpr std::size_t kFieldSeparator = static_cast<std::size_t>(0x9e3779b97f4a7c15ULL);

        std::size_t h = static_cast<std::size_t>(1469598103934665603ULL);  // FNV-1a basis
        for (const auto& s : e) {
            for (char c : s) {
                h ^= static_cast<std::size_t>(static_cast<unsigned char>(c));
                h *= kPrime;
            }
            // Close the field: XOR then multiply so the step depends on position.
            h ^= kFieldSeparator;
            h *= kPrime;
        }
        return h;
    }
};

// Unordered map from EdgeTypeArray to V, hashed with EdgeTypeArrayHash. No KeyEqual is
// supplied, so keys are compared with std::unordered_map's default equality.
template <typename V>
using EdgeTypeMap = std::unordered_map<EdgeTypeArray, V, EdgeTypeArrayHash>;

// Component tensors for a heterogeneous graph batch. Keys mirror GLT/PyG:
// node/x keyed by node-type string; edges keyed by the reversed EdgeTypeArray.
// Produced by collateHeterogeneous. Maps marked "only present types" hold an entry
// only for types whose tensor was found in the sampler message.
struct HeterogeneousCollateResult {
std::unordered_map<std::string, torch::Tensor> node; // ids; only present types
EdgeTypeMap<torch::Tensor> edgeIndex; // [2,E]; absent types filled [2,0]
EdgeTypeMap<torch::Tensor> edge; // eids; only present types
std::unordered_map<std::string, torch::Tensor> x; // nfeats; only present types
EdgeTypeMap<torch::Tensor> edgeAttr; // efeats; only present types
std::unordered_map<std::string, torch::Tensor> batch; // {inputType: tensor} or empty
std::unordered_map<std::string, torch::Tensor> numSampledNodes; // padded to numHops + 1
EdgeTypeMap<torch::Tensor> numSampledEdges; // padded to numHops
};

// Reproduces the GLT heterogeneous collate body (dist_loader.py:351-420) and the
// tensor-level parts of to_hetero_data (transform.py:60-115): per-type dict build,
// edge_dir row/col swap, get_edge_index empty-fill, and num_sampled_* padding.
// Metadata is assumed already stripped (GiGL strips #META keys upstream), so no
// metadata handling is performed. All tensors are assumed on the target device.
//
// msg               flat message keyed "<type>.<field>" (ids, nfeats, rows, cols, eids,
//                   efeats, num_sampled_nodes, num_sampled_edges, batch).
// nodeTypes         node types to look up in msg.
// edgeTypeStrToRev  (message key prefix, reversed edge type) pairs; outputs are keyed
//                   by the reversed edge type.
// reversedEdgeTypes edge types that always get an edgeIndex and numSampledEdges entry.
// inputType         anchor node type used for the batch entry.
// hasBatch          when false, the result's batch map is left empty.
//
// Fails a TORCH_CHECK if hasBatch is set, "<inputType>.batch" is missing, and no
// "<inputType>.ids" tensor is available for the fallback slice.
HeterogeneousCollateResult collateHeterogeneous(
const std::unordered_map<std::string, torch::Tensor>& msg,
const std::vector<std::string>& nodeTypes,
const std::vector<std::pair<std::string, EdgeTypeArray>>& edgeTypeStrToRev,
const std::vector<EdgeTypeArray>& reversedEdgeTypes,
const std::string& inputType,
bool hasBatch,
int64_t batchSize); // anchor-id slice length used when the ".batch" key is absent

} // namespace collation
} // namespace gigl
Loading