diff --git a/conf.py b/conf.py
index 269618878..4670096b7 100644
--- a/conf.py
+++ b/conf.py
@@ -31,6 +31,7 @@
"sphinx_design", # Needed by themes
"myst_nb", # Support for rendering Jupyter Notebooks: https://myst-nb.readthedocs.io/en/v1.2.0/
"sphinx_copybutton", # Support for copying code snippets: https://sphinx-copybutton.readthedocs.io/
+ "sphinxcontrib.mermaid", # Support for Mermaid diagrams in MyST and GitHub Markdown.
]
autoapi_type = 'python'
@@ -61,6 +62,7 @@
myst_enable_extensions = [
"html_image", # Convert
tags in markdown files; https://myst-parser.readthedocs.io/en/latest/syntax/optional.html#html-images
]
+myst_fence_as_directive = ["mermaid"]
include_patterns = [
"docs/**",
diff --git a/docs/user_guide/index.rst b/docs/user_guide/index.rst
index e8b2950fa..3b6e91bea 100644
--- a/docs/user_guide/index.rst
+++ b/docs/user_guide/index.rst
@@ -20,6 +20,8 @@ Welcome to the GiGL User Guide. This guide provides detailed documentation to he
overview/what_is_gigl
overview/architecture
+ overview/in_memory_subgraph_sampling
+ overview/graph_store_conversion
.. toctree::
diff --git a/docs/user_guide/overview/graph_store_conversion.md b/docs/user_guide/overview/graph_store_conversion.md
new file mode 100644
index 000000000..0015a7fdb
--- /dev/null
+++ b/docs/user_guide/overview/graph_store_conversion.md
@@ -0,0 +1,341 @@
+# Converting Colocated Loops To Graph Store
+
+This guide is for training and inference loops that already use GiGL's in-memory sampling path with `DistDataset`,
+`DistNeighborLoader`, or `DistABLPLoader`.
+
+The goal is not to rewrite model code. The goal is to move graph ownership out of the compute machines and into a
+storage cluster, then let compute processes sample from that graph over RPC.
+
+The link prediction examples under `examples/link_prediction` are useful references, but treat them as examples of the
+pattern rather than a checklist to diff line by line.
+
+## Mental Model
+
+Colocated mode puts graph storage and model compute on the same machines:
+
+```mermaid
+flowchart LR
+ subgraph colocated["Colocated cluster"]
+ machine0["Machine 0
DistDataset partition
Sampling workers
GPU processes"]
+ machine1["Machine 1
DistDataset partition
Sampling workers
GPU processes"]
+ end
+```
+
+Graph store mode separates those roles:
+
+```mermaid
+flowchart LR
+ subgraph storage["Storage cluster"]
+ storage0["Storage 0
DistDataset partition
Graph-store server"]
+ storage1["Storage 1
DistDataset partition
Graph-store server"]
+ storageN["Storage N
DistDataset partition
Graph-store server"]
+ end
+
+ subgraph compute["Compute cluster"]
+ compute0["Compute 0
GPU processes
RemoteDistDataset clients"]
+ compute1["Compute 1
GPU processes
RemoteDistDataset clients"]
+ end
+
+ compute0 <-->|RPC sampling and feature fetches| storage0
+ compute0 <-->|RPC sampling and feature fetches| storage1
+ compute1 <-->|RPC sampling and feature fetches| storage1
+ compute1 <-->|RPC sampling and feature fetches| storageN
+```
+
+`RemoteDistDataset` is the important name to read carefully. It is not a remote copy of the graph and it is not built
+the same way as `DistDataset`. It is a client handle that lets a compute process talk to storage nodes.
+
+## The Main Ordering Difference
+
+In colocated mode, the compute machine builds a real `DistDataset` before spawning local GPU processes:
+
+```mermaid
+sequenceDiagram
+ participant Parent as Compute parent
+ participant Dataset as DistDataset
+ participant Worker as Per-GPU worker
+
+ Parent->>Parent: Initialize temporary distributed group
+ Parent->>Parent: Discover rank, world size, master IP, and ports
+ Parent->>Dataset: build_dataset_from_task_config_uri(...)
+ Dataset-->>Parent: Local graph partition
+ Parent->>Worker: spawn(...)
+ Worker->>Dataset: Create loaders from local DistDataset
+ Worker->>Worker: Run model code
+```
+
+In graph store mode, the storage cluster builds the real `DistDataset`. Compute workers only create a
+`RemoteDistDataset` after the graph-store cluster and the per-process client setup exist:
+
+```mermaid
+sequenceDiagram
+ participant Storage as Storage process
+ participant Server as Graph-store server
+ participant Parent as Compute parent
+ participant Worker as Compute worker
+ participant Remote as RemoteDistDataset
+
+ Storage->>Storage: Build storage-side DistDataset partition
+ Storage->>Server: Start server sessions
+ Parent->>Parent: Initialize global graph-store group
+ Parent->>Parent: get_graph_store_info()
+ Parent->>Parent: Destroy temporary global group
+ Parent->>Worker: spawn(...)
+ Worker->>Worker: init_compute_process(local_rank, cluster_info)
+ Worker->>Remote: RemoteDistDataset(cluster_info, local_rank)
+ Remote->>Server: Fetch loader input over RPC
+ Server-->>Remote: Server-keyed input
+ Worker->>Remote: Create loaders
+ Worker->>Worker: Run model code
+```
+
+That timing is the core migration point. `DistDataset` is constructed once per graph-owning machine. `RemoteDistDataset`
+is constructed inside each compute process, after `init_compute_process(...)` has connected that process to the storage
+servers and initialized the compute process group.
+
+## Compute-Side Shape
+
+A graph-store compute parent usually only needs the cluster topology before spawning workers:
+
+```python
+torch.distributed.init_process_group(backend="gloo")
+cluster_info = get_graph_store_info()
+torch.distributed.destroy_process_group()
+
+mp.spawn(
+ fn=run_worker,
+ args=(cluster_info, ...),
+ nprocs=local_world_size,
+)
+```
+
+The worker then joins the graph-store client setup and creates the remote dataset handle:
+
+```python
+def run_worker(local_rank: int, cluster_info: GraphStoreInfo, ...) -> None:
+ init_compute_process(local_rank, cluster_info)
+ dataset = RemoteDistDataset(cluster_info, local_rank)
+
+ rank = torch.distributed.get_rank()
+ world_size = torch.distributed.get_world_size()
+```
+
+Use `rank` and `world_size` from `torch.distributed` after `init_compute_process(...)`. At that point they describe the
+compute process group, including all per-GPU processes across compute machines.
+
+## Data Access Changes
+
+In colocated mode, loader input is usually read directly from local dataset fields:
+
+```mermaid
+flowchart LR
+ train["dataset.train_node_ids"] --> ablp["DistABLPLoader"]
+ val["dataset.val_node_ids"] --> ablp
+ test["dataset.test_node_ids"] --> ablp
+ nodes["dataset.node_ids"] --> neighbor["DistNeighborLoader"]
+```
+
+In graph-store mode, compute processes fetch the input from storage:
+
+```mermaid
+flowchart LR
+ fetchTrain["dataset.fetch_ablp_input(split='train', ...)"] --> remoteAblp["DistABLPLoader"]
+ fetchVal["dataset.fetch_ablp_input(split='val', ...)"] --> remoteAblp
+ fetchTest["dataset.fetch_ablp_input(split='test', ...)"] --> remoteAblp
+ fetchNodes["dataset.fetch_node_ids(...)"] --> remoteNeighbor["DistNeighborLoader"]
+```
+
+For anchor-based link prediction, fetch the main batch input from storage:
+
+```python
+ablp_input = dataset.fetch_ablp_input(
+ split="train",
+ rank=torch.distributed.get_rank(),
+ world_size=torch.distributed.get_world_size(),
+ anchor_node_type=anchor_node_type, # heterogeneous only
+ supervision_edge_type=supervision_edge_type, # heterogeneous only
+)
+
+main_loader = DistABLPLoader(
+ dataset=dataset,
+ num_neighbors=num_neighbors,
+ input_nodes=ablp_input,
+ batch_size=main_batch_size,
+ num_workers=sampling_workers_per_process,
+)
+```
+
+For random negatives or inference seed nodes, fetch node IDs from storage:
+
+```python
+node_ids = dataset.fetch_node_ids(
+ rank=torch.distributed.get_rank(),
+ world_size=torch.distributed.get_world_size(),
+ node_type=node_type, # heterogeneous only
+)
+
+loader = DistNeighborLoader(
+ dataset=dataset,
+ num_neighbors=num_neighbors,
+ input_nodes=(node_type, node_ids), # use node_ids directly for homogeneous graphs
+ batch_size=batch_size,
+ num_workers=sampling_workers_per_process,
+)
+```
+
+The loader input shape changes because storage nodes own the graph partitions:
+
+| Loader | Colocated input | Graph-store input |
+| -------------------- | -------------------------------- | ------------------------------------------------------ |
+| `DistABLPLoader` | `Tensor` or `(NodeType, Tensor)` | `dict[int, ABLPInputNodes]` |
+| `DistNeighborLoader` | `Tensor` or `(NodeType, Tensor)` | `dict[int, Tensor]` or `(NodeType, dict[int, Tensor])` |
+
+The `int` keys in graph-store inputs are storage ranks. Empty tensors for a storage rank are normal; they mean that this
+compute rank has no input assigned to that storage server.
+
+## Training Notes
+
+Training needs storage-side splits. The storage process has to build its `DistDataset` with the splitter that creates
+the train, validation, and test node IDs. If it does not, calls such as `fetch_ablp_input(split="train", ...)` have
+nothing to fetch.
+
+The rest of the training loop should look familiar:
+
+```mermaid
+flowchart TD
+ ablpInput["fetch ABLP input"] --> ablpLoader["DistABLPLoader"]
+ randomInput["fetch random-negative input"] --> neighborLoader["DistNeighborLoader"]
+ ablpLoader --> batch["main batch + random negative batch"]
+ neighborLoader --> batch
+ batch --> model["model + loss"]
+```
+
+Model construction, loss calculation, optimizer steps, validation, metric writing, and model saving usually do not need
+graph-store-specific behavior. The main exceptions are places that assumed a colocated machine rank or assumed the
+dataset object already existed before local worker processes were spawned.
+
+For heterogeneous ABLP, be explicit about edge direction. If the storage splitter is configured with inbound sampling,
+the supervision edge type passed to the splitter may need to be the reverse of the user-facing prediction edge type.
+This is a graph direction issue, not a graph-store-specific model issue.
+
+## Inference Notes
+
+Inference usually does not need train, validation, or test splits. The compute worker fetches the target node IDs from
+storage and passes them to `DistNeighborLoader`.
+
+```mermaid
+flowchart LR
+ fetchInference["fetch_node_ids(node_type=...)"] --> inferenceLoader["DistNeighborLoader"]
+ inferenceLoader --> forward["model forward"]
+ forward --> export["embedding or prediction export"]
+```
+
+Model loading, batching, embedding export, and BigQuery loading can usually stay the same. Watch for any code that uses
+the old colocated machine rank to decide which process writes shared outputs. In graph-store compute workers,
+`torch.distributed.get_rank() == 0` is the usual lead-process check after `init_compute_process(...)`.
+
+## Storage And Resource Config
+
+Graph store jobs run two entrypoints: one for compute and one for storage.
+
+```mermaid
+flowchart TD
+ taskConfig["Task config"] --> computeCommand["trainerConfig.command
or inferencerConfig.command"]
+ taskConfig --> storageCommand["graphStoreStorageConfig.command"]
+ computeCommand --> computeLoop["Compute-side training
or inference loop"]
+ storageCommand --> storageLoop["Storage entrypoint
Build DistDataset
Start server sessions"]
+```
+
+For training, the storage args need enough information to create splits:
+
+```yaml
+trainerConfig:
+ command: python -m my_package.training_graph_store
+ graphStoreStorageConfig:
+ command: python -m my_package.storage_main
+ storageArgs:
+ sample_edge_direction: "in"
+ splitter_cls_path: "gigl.utils.data_splitters.DistNodeAnchorLinkSplitter"
+ splitter_kwargs: >-
+ {
+ "sampling_direction": "in",
+ "should_convert_labels_to_edges": True,
+ "num_val": 0.1,
+ "num_test": 0.1
+ }
+ num_server_sessions: "1"
+```
+
+For inference, the storage args are often simpler because inference can query all target nodes:
+
+```yaml
+inferencerConfig:
+ command: python -m my_package.inference_graph_store
+ graphStoreStorageConfig:
+ command: python -m my_package.storage_main
+ storageArgs:
+ sample_edge_direction: "in"
+ num_server_sessions: "1"
+```
+
+The resource config also changes from a single Vertex AI pool to separate graph-store and compute pools:
+
+```yaml
+trainer_resource_config:
+ vertex_ai_graph_store_trainer_config:
+ graph_store_pool:
+ machine_type: n2-highmem-32
+ gpu_type: ACCELERATOR_TYPE_UNSPECIFIED
+ gpu_limit: 0
+ num_replicas: 2
+ compute_pool:
+ machine_type: n1-standard-16
+ gpu_type: NVIDIA_TESLA_T4
+ gpu_limit: 2
+ num_replicas: 2
+```
+
+Use the graph-store pool for memory-heavy graph serving. Use the compute pool for model work, usually with GPUs. If you
+need to override how many compute processes run per compute machine, use `compute_cluster_local_world_size` in the
+graph-store resource config.
+
+`num_server_sessions` should match how many separate compute process groups will connect to storage over the lifetime of
+the job. Training commonly uses one session. Heterogeneous inference may use one session per node type if the inference
+loop runs node types sequentially with separate spawned process groups.
+
+## Shutdown Order
+
+Shut down loaders before shutting down the compute process:
+
+```mermaid
+sequenceDiagram
+ participant Worker as Compute worker
+ participant Loader as Dist loaders
+ participant Storage as Storage servers
+ participant Client as Graph-store client
+
+ Worker->>Loader: loader.shutdown()
+ Loader->>Storage: destroy sampling channels
+ Storage-->>Loader: channel teardown complete
+ Worker->>Client: shutdown_compute_process()
+ Client->>Client: shutdown RPC client
+ Client->>Client: destroy compute process group
+```
+
+Loader shutdown tears down server-side sampling channels. `shutdown_compute_process()` tears down the graph-store client
+and the compute process group. Calling them in the opposite order leaves storage with live sampling state and can make
+shutdown slow or noisy.
+
+## Where To Look Next
+
+The link prediction examples show the pattern in real code:
+
+- Colocated training and inference:
+ [`examples/link_prediction`](https://github.com/Snapchat/GiGL/tree/main/examples/link_prediction)
+- Graph-store training and inference:
+ [`examples/link_prediction/graph_store`](https://github.com/Snapchat/GiGL/tree/main/examples/link_prediction/graph_store)
+- Graph-store task and resource configs:
+ [`examples/link_prediction/graph_store/configs`](https://github.com/Snapchat/GiGL/tree/main/examples/link_prediction/graph_store/configs)
+
+The storage entrypoint in `examples/link_prediction/graph_store/storage_main.py` is also useful if you need to write a
+custom storage command.
diff --git a/docs/user_guide/overview/in_memory_subgraph_sampling.md b/docs/user_guide/overview/in_memory_subgraph_sampling.md
index becc28c3b..e88f8c587 100644
--- a/docs/user_guide/overview/in_memory_subgraph_sampling.md
+++ b/docs/user_guide/overview/in_memory_subgraph_sampling.md
@@ -29,10 +29,15 @@ The main abstractions used by the in-memory path are:
- {py:class}`gigl.distributed.dist_dataset.DistDataset` for colocated sampling, where each machine stores its graph
partition locally.
+- {py:class}`gigl.distributed.graph_store.remote_dist_dataset.RemoteDistDataset` for graph-store sampling, where storage
+ nodes own graph partitions and compute processes sample from them over RPC.
- {py:class}`gigl.distributed.distributed_neighborloader.DistNeighborLoader` for standard neighborhood sampling.
- {py:class}`gigl.distributed.dist_ablp_neighborloader.DistABLPLoader` for anchor-based link prediction batches with
positives and negatives.
+If you already have a colocated in-memory loop and want to move graph storage onto separate machines, see
+[Converting Colocated Loops To Graph Store](graph_store_conversion.md).
+
## Example Implementations
The link prediction examples are the clearest reference implementations for the current runtime:
diff --git a/examples/link_prediction/README.md b/examples/link_prediction/README.md
index cd730f595..afcd053a8 100644
--- a/examples/link_prediction/README.md
+++ b/examples/link_prediction/README.md
@@ -23,6 +23,9 @@ are example inference and training loops for the DBLP dataset. The DBLP dataset
You can follow along with [dblp.ipynb](./dblp.ipynb) to run an e2e GiGL pipeline on the DBLP dataset. It will guide you
through running each component: `config_populator` -> `data_preprocessor` -> `trainer` -> `inferencer`
+For guidance on adapting an existing colocated in-memory loop to graph-store mode, see the
+[graph-store conversion guide](../../docs/user_guide/overview/graph_store_conversion.md).
+
```{toctree}
:maxdepth: 2
:hidden:
diff --git a/pyproject.toml b/pyproject.toml
index b22e1f0e6..3fe88014c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -136,7 +136,7 @@ docs = [
"sphinx-rtd-theme==2.0.0",
"pydata-sphinx-theme==0.16.1",
"sphinx-tabs==3.4.5",
-
+ "sphinxcontrib-mermaid==2.0.2",
]
typing-stubs = [
"pandas-stubs==2.2.2.240807",
diff --git a/uv.lock b/uv.lock
index 1909a0793..121e14b2f 100644
--- a/uv.lock
+++ b/uv.lock
@@ -794,6 +794,7 @@ dev = [
{ name = "sphinx-hoverxref" },
{ name = "sphinx-rtd-theme" },
{ name = "sphinx-tabs" },
+ { name = "sphinxcontrib-mermaid" },
{ name = "ty" },
{ name = "types-psutil" },
{ name = "types-pyyaml" },
@@ -816,6 +817,7 @@ docs = [
{ name = "sphinx-hoverxref" },
{ name = "sphinx-rtd-theme" },
{ name = "sphinx-tabs" },
+ { name = "sphinxcontrib-mermaid" },
]
gigl-core-build-backend = [
{ name = "pybind11" },
@@ -920,6 +922,7 @@ dev = [
{ name = "sphinx-hoverxref", specifier = "==1.3.0" },
{ name = "sphinx-rtd-theme", specifier = "==2.0.0" },
{ name = "sphinx-tabs", specifier = "==3.4.5" },
+ { name = "sphinxcontrib-mermaid", specifier = "==2.0.2" },
{ name = "ty", specifier = "==0.0.31" },
{ name = "types-psutil", specifier = "==7.0.0.20250401" },
{ name = "types-pyyaml", specifier = "~=6.0.12" },
@@ -942,6 +945,7 @@ docs = [
{ name = "sphinx-hoverxref", specifier = "==1.3.0" },
{ name = "sphinx-rtd-theme", specifier = "==2.0.0" },
{ name = "sphinx-tabs", specifier = "==3.4.5" },
+ { name = "sphinxcontrib-mermaid", specifier = "==2.0.2" },
]
gigl-core-build-backend = [
{ name = "pybind11", specifier = ">=2.12" },
@@ -3810,6 +3814,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c2/42/4c8646762ee83602e3fb3fbe774c2fac12f317deb0b5dbeeedd2d3ba4b77/sphinxcontrib_jsmath-1.0.1-py2.py3-none-any.whl", hash = "sha256:2ec2eaebfb78f3f2078e73666b1415417a116cc848b72e5172e596c871103178", size = 5071, upload-time = "2019-01-21T16:10:14.333Z" },
]
+[[package]]
+name = "sphinxcontrib-mermaid"
+version = "2.0.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "jinja2" },
+ { name = "pyyaml" },
+ { name = "sphinx" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/19/75/3a1cc926da8c563c58ddc124a7b3fe5ccadcae96c96e3a6f8ac3653a210a/sphinxcontrib_mermaid-2.0.2.tar.gz", hash = "sha256:f09576c78ca93fa0e3034fd9c45aaffa7c44ab449de9c43b8b8d262afe52bc66", size = 19265, upload-time = "2026-05-05T13:59:02.959Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/16/8d/93be7e0f7fa915a576859b3bfac7a7baa3303181c44d7db7eefbd3e8a69f/sphinxcontrib_mermaid-2.0.2-py3-none-any.whl", hash = "sha256:d862e514991279fb4816302c5cfe167d2557bf3ce7125ae0cb47dac80a0f46ce", size = 14094, upload-time = "2026-05-05T13:59:01.585Z" },
+]
+
[[package]]
name = "sphinxcontrib-qthelp"
version = "2.0.0"