diff --git a/conf.py b/conf.py index 269618878..4670096b7 100644 --- a/conf.py +++ b/conf.py @@ -31,6 +31,7 @@ "sphinx_design", # Needed by themes "myst_nb", # Support for rendering Jupyter Notebooks: https://myst-nb.readthedocs.io/en/v1.2.0/ "sphinx_copybutton", # Support for copying code snippets: https://sphinx-copybutton.readthedocs.io/ + "sphinxcontrib.mermaid", # Support for Mermaid diagrams in MyST and GitHub Markdown. ] autoapi_type = 'python' @@ -61,6 +62,7 @@ myst_enable_extensions = [ "html_image", # Convert tags in markdown files; https://myst-parser.readthedocs.io/en/latest/syntax/optional.html#html-images ] +myst_fence_as_directive = ["mermaid"] include_patterns = [ "docs/**", diff --git a/docs/user_guide/index.rst b/docs/user_guide/index.rst index e8b2950fa..3b6e91bea 100644 --- a/docs/user_guide/index.rst +++ b/docs/user_guide/index.rst @@ -20,6 +20,8 @@ Welcome to the GiGL User Guide. This guide provides detailed documentation to he overview/what_is_gigl overview/architecture + overview/in_memory_subgraph_sampling + overview/graph_store_conversion .. toctree:: diff --git a/docs/user_guide/overview/graph_store_conversion.md b/docs/user_guide/overview/graph_store_conversion.md new file mode 100644 index 000000000..0015a7fdb --- /dev/null +++ b/docs/user_guide/overview/graph_store_conversion.md @@ -0,0 +1,341 @@ +# Converting Colocated Loops To Graph Store + +This guide is for training and inference loops that already use GiGL's in-memory sampling path with `DistDataset`, +`DistNeighborLoader`, or `DistABLPLoader`. + +The goal is not to rewrite model code. The goal is to move graph ownership out of the compute machines and into a +storage cluster, then let compute processes sample from that graph over RPC. + +The link prediction examples under `examples/link_prediction` are useful references, but treat them as examples of the +pattern rather than a checklist to diff line by line. 
+ +## Mental Model + +Colocated mode puts graph storage and model compute on the same machines: + +```mermaid +flowchart LR + subgraph colocated["Colocated cluster"] + machine0["Machine 0<br/>DistDataset partition<br/>Sampling workers<br/>GPU processes"] + machine1["Machine 1<br/>DistDataset partition<br/>Sampling workers<br/>GPU processes"] + end +``` + +Graph store mode separates those roles: + +```mermaid +flowchart LR + subgraph storage["Storage cluster"] + storage0["Storage 0<br/>DistDataset partition<br/>Graph-store server"] + storage1["Storage 1<br/>DistDataset partition<br/>Graph-store server"] + storageN["Storage N<br/>DistDataset partition<br/>Graph-store server"] + end + + subgraph compute["Compute cluster"] + compute0["Compute 0<br/>GPU processes<br/>RemoteDistDataset clients"] + compute1["Compute 1<br/>GPU processes<br/>RemoteDistDataset clients"] + end + + compute0 <-->|RPC sampling and feature fetches| storage0 + compute0 <-->|RPC sampling and feature fetches| storage1 + compute1 <-->|RPC sampling and feature fetches| storage1 + compute1 <-->|RPC sampling and feature fetches| storageN +``` + +`RemoteDistDataset` is the important name to read carefully. It is not a remote copy of the graph and it is not built +the same way as `DistDataset`. It is a client handle that lets a compute process talk to storage nodes. + +## The Main Ordering Difference + +In colocated mode, the compute machine builds a real `DistDataset` before spawning local GPU processes: + +```mermaid +sequenceDiagram + participant Parent as Compute parent + participant Dataset as DistDataset + participant Worker as Per-GPU worker + + Parent->>Parent: Initialize temporary distributed group + Parent->>Parent: Discover rank, world size, master IP, and ports + Parent->>Dataset: build_dataset_from_task_config_uri(...) + Dataset-->>Parent: Local graph partition + Parent->>Worker: spawn(...) + Worker->>Dataset: Create loaders from local DistDataset + Worker->>Worker: Run model code +``` + +In graph store mode, the storage cluster builds the real `DistDataset`. Compute workers only create a +`RemoteDistDataset` after the graph-store cluster and the per-process client setup exist: + +```mermaid +sequenceDiagram + participant Storage as Storage process + participant Server as Graph-store server + participant Parent as Compute parent + participant Worker as Compute worker + participant Remote as RemoteDistDataset + + Storage->>Storage: Build storage-side DistDataset partition + Storage->>Server: Start server sessions + Parent->>Parent: Initialize global graph-store group + Parent->>Parent: get_graph_store_info() + Parent->>Parent: Destroy temporary global group + Parent->>Worker: spawn(...) 
+ Worker->>Worker: init_compute_process(local_rank, cluster_info) + Worker->>Remote: RemoteDistDataset(cluster_info, local_rank) + Remote->>Server: Fetch loader input over RPC + Server-->>Remote: Server-keyed input + Worker->>Remote: Create loaders + Worker->>Worker: Run model code +``` + +That timing is the core migration point. `DistDataset` is constructed once per graph-owning machine. `RemoteDistDataset` +is constructed inside each compute process, after `init_compute_process(...)` has connected that process to the storage +servers and initialized the compute process group. + +## Compute-Side Shape + +A graph-store compute parent usually only needs the cluster topology before spawning workers: + +```python +torch.distributed.init_process_group(backend="gloo") +cluster_info = get_graph_store_info() +torch.distributed.destroy_process_group() + +mp.spawn( + fn=run_worker, + args=(cluster_info, ...), + nprocs=local_world_size, +) +``` + +The worker then joins the graph-store client setup and creates the remote dataset handle: + +```python +def run_worker(local_rank: int, cluster_info: GraphStoreInfo, ...) -> None: + init_compute_process(local_rank, cluster_info) + dataset = RemoteDistDataset(cluster_info, local_rank) + + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() +``` + +Use `rank` and `world_size` from `torch.distributed` after `init_compute_process(...)`. At that point they describe the +compute process group, including all per-GPU processes across compute machines. 
+ +## Data Access Changes + +In colocated mode, loader input is usually read directly from local dataset fields: + +```mermaid +flowchart LR + train["dataset.train_node_ids"] --> ablp["DistABLPLoader"] + val["dataset.val_node_ids"] --> ablp + test["dataset.test_node_ids"] --> ablp + nodes["dataset.node_ids"] --> neighbor["DistNeighborLoader"] +``` + +In graph-store mode, compute processes fetch the input from storage: + +```mermaid +flowchart LR + fetchTrain["dataset.fetch_ablp_input(split='train', ...)"] --> remoteAblp["DistABLPLoader"] + fetchVal["dataset.fetch_ablp_input(split='val', ...)"] --> remoteAblp + fetchTest["dataset.fetch_ablp_input(split='test', ...)"] --> remoteAblp + fetchNodes["dataset.fetch_node_ids(...)"] --> remoteNeighbor["DistNeighborLoader"] +``` + +For anchor-based link prediction, fetch the main batch input from storage: + +```python +ablp_input = dataset.fetch_ablp_input( + split="train", + rank=torch.distributed.get_rank(), + world_size=torch.distributed.get_world_size(), + anchor_node_type=anchor_node_type, # heterogeneous only + supervision_edge_type=supervision_edge_type, # heterogeneous only +) + +main_loader = DistABLPLoader( + dataset=dataset, + num_neighbors=num_neighbors, + input_nodes=ablp_input, + batch_size=main_batch_size, + num_workers=sampling_workers_per_process, +) +``` + +For random negatives or inference seed nodes, fetch node IDs from storage: + +```python +node_ids = dataset.fetch_node_ids( + rank=torch.distributed.get_rank(), + world_size=torch.distributed.get_world_size(), + node_type=node_type, # heterogeneous only +) + +loader = DistNeighborLoader( + dataset=dataset, + num_neighbors=num_neighbors, + input_nodes=(node_type, node_ids), # use node_ids directly for homogeneous graphs + batch_size=batch_size, + num_workers=sampling_workers_per_process, +) +``` + +The loader input shape changes because storage nodes own the graph partitions: + +| Loader | Colocated input | Graph-store input | +| -------------------- | 
-------------------------------- | ------------------------------------------------------ | +| `DistABLPLoader` | `Tensor` or `(NodeType, Tensor)` | `dict[int, ABLPInputNodes]` | +| `DistNeighborLoader` | `Tensor` or `(NodeType, Tensor)` | `dict[int, Tensor]` or `(NodeType, dict[int, Tensor])` | + +The `int` keys in graph-store inputs are storage ranks. Empty tensors for a storage rank are normal; they mean that this +compute rank has no input assigned to that storage server. + +## Training Notes + +Training needs storage-side splits. The storage process has to build its `DistDataset` with the splitter that creates +the train, validation, and test node IDs. If it does not, calls such as `fetch_ablp_input(split="train", ...)` have +nothing to fetch. + +The rest of the training loop should look familiar: + +```mermaid +flowchart TD + ablpInput["fetch ABLP input"] --> ablpLoader["DistABLPLoader"] + randomInput["fetch random-negative input"] --> neighborLoader["DistNeighborLoader"] + ablpLoader --> batch["main batch + random negative batch"] + neighborLoader --> batch + batch --> model["model + loss"] +``` + +Model construction, loss calculation, optimizer steps, validation, metric writing, and model saving usually do not need +graph-store-specific behavior. The main exceptions are places that assumed a colocated machine rank or assumed the +dataset object already existed before local worker processes were spawned. + +For heterogeneous ABLP, be explicit about edge direction. If the storage splitter is configured with inbound sampling, +the supervision edge type passed to the splitter may need to be the reverse of the user-facing prediction edge type. +This is a graph direction issue, not a graph-store-specific model issue. + +## Inference Notes + +Inference usually does not need train, validation, or test splits. The compute worker fetches the target node IDs from +storage and passes them to `DistNeighborLoader`. 
+ +```mermaid +flowchart LR + fetchInference["fetch_node_ids(node_type=...)"] --> inferenceLoader["DistNeighborLoader"] + inferenceLoader --> forward["model forward"] + forward --> export["embedding or prediction export"] +``` + +Model loading, batching, embedding export, and BigQuery loading can usually stay the same. Watch for any code that uses +the old colocated machine rank to decide which process writes shared outputs. In graph-store compute workers, +`torch.distributed.get_rank() == 0` is the usual lead-process check after `init_compute_process(...)`. + +## Storage And Resource Config + +Graph store jobs run two entrypoints: one for compute and one for storage. + +```mermaid +flowchart TD + taskConfig["Task config"] --> computeCommand["trainerConfig.command<br/>or inferencerConfig.command"] + taskConfig --> storageCommand["graphStoreStorageConfig.command"] + computeCommand --> computeLoop["Compute-side training<br/>or inference loop"] + storageCommand --> storageLoop["Storage entrypoint<br/>Build DistDataset<br/>Start server sessions"] +``` + +For training, the storage args need enough information to create splits: + +```yaml +trainerConfig: + command: python -m my_package.training_graph_store + graphStoreStorageConfig: + command: python -m my_package.storage_main + storageArgs: + sample_edge_direction: "in" + splitter_cls_path: "gigl.utils.data_splitters.DistNodeAnchorLinkSplitter" + splitter_kwargs: >- + { + "sampling_direction": "in", + "should_convert_labels_to_edges": True, + "num_val": 0.1, + "num_test": 0.1 + } + num_server_sessions: "1" +``` + +For inference, the storage args are often simpler because inference can query all target nodes: + +```yaml +inferencerConfig: + command: python -m my_package.inference_graph_store + graphStoreStorageConfig: + command: python -m my_package.storage_main + storageArgs: + sample_edge_direction: "in" + num_server_sessions: "1" +``` + +The resource config also changes from a single Vertex AI pool to separate graph-store and compute pools: + +```yaml +trainer_resource_config: + vertex_ai_graph_store_trainer_config: + graph_store_pool: + machine_type: n2-highmem-32 + gpu_type: ACCELERATOR_TYPE_UNSPECIFIED + gpu_limit: 0 + num_replicas: 2 + compute_pool: + machine_type: n1-standard-16 + gpu_type: NVIDIA_TESLA_T4 + gpu_limit: 2 + num_replicas: 2 +``` + +Use the graph-store pool for memory-heavy graph serving. Use the compute pool for model work, usually with GPUs. If you +need to override how many compute processes run per compute machine, use `compute_cluster_local_world_size` in the +graph-store resource config. + +`num_server_sessions` should match how many separate compute process groups will connect to storage over the lifetime of +the job. Training commonly uses one session. Heterogeneous inference may use one session per node type if the inference +loop runs node types sequentially with separate spawned process groups. 
+ +## Shutdown Order + +Shut down loaders before shutting down the compute process: + +```mermaid +sequenceDiagram + participant Worker as Compute worker + participant Loader as Dist loaders + participant Storage as Storage servers + participant Client as Graph-store client + + Worker->>Loader: loader.shutdown() + Loader->>Storage: destroy sampling channels + Storage-->>Loader: channel teardown complete + Worker->>Client: shutdown_compute_process() + Client->>Client: shutdown RPC client + Client->>Client: destroy compute process group +``` + +Loader shutdown tears down server-side sampling channels. `shutdown_compute_process()` tears down the graph-store client +and the compute process group. Calling them in the opposite order leaves storage with live sampling state and can make +shutdown slow or noisy. + +## Where To Look Next + +The link prediction examples show the pattern in real code: + +- Colocated training and inference: + [`examples/link_prediction`](https://github.com/Snapchat/GiGL/tree/main/examples/link_prediction) +- Graph-store training and inference: + [`examples/link_prediction/graph_store`](https://github.com/Snapchat/GiGL/tree/main/examples/link_prediction/graph_store) +- Graph-store task and resource configs: + [`examples/link_prediction/graph_store/configs`](https://github.com/Snapchat/GiGL/tree/main/examples/link_prediction/graph_store/configs) + +The storage entrypoint in `examples/link_prediction/graph_store/storage_main.py` is also useful if you need to write a +custom storage command. 
diff --git a/docs/user_guide/overview/in_memory_subgraph_sampling.md b/docs/user_guide/overview/in_memory_subgraph_sampling.md index becc28c3b..e88f8c587 100644 --- a/docs/user_guide/overview/in_memory_subgraph_sampling.md +++ b/docs/user_guide/overview/in_memory_subgraph_sampling.md @@ -29,10 +29,15 @@ The main abstractions used by the in-memory path are: - {py:class}`gigl.distributed.dist_dataset.DistDataset` for colocated sampling, where each machine stores its graph partition locally. +- {py:class}`gigl.distributed.graph_store.remote_dist_dataset.RemoteDistDataset` for graph-store sampling, where storage + nodes own graph partitions and compute processes sample from them over RPC. - {py:class}`gigl.distributed.distributed_neighborloader.DistNeighborLoader` for standard neighborhood sampling. - {py:class}`gigl.distributed.dist_ablp_neighborloader.DistABLPLoader` for anchor-based link prediction batches with positives and negatives. +If you already have a colocated in-memory loop and want to move graph storage onto separate machines, see +[Converting Colocated Loops To Graph Store](graph_store_conversion.md). + ## Example Implementations The link prediction examples are the clearest reference implementations for the current runtime: diff --git a/examples/link_prediction/README.md b/examples/link_prediction/README.md index cd730f595..afcd053a8 100644 --- a/examples/link_prediction/README.md +++ b/examples/link_prediction/README.md @@ -23,6 +23,9 @@ are example inference and training loops for the DBLP dataset. The DBLP dataset You can follow along with [dblp.ipynb](./dblp.ipynb) to run an e2e GiGL pipeline on the DBLP dataset. It will guide you through running each component: `config_populator` -> `data_preprocessor` -> `trainer` -> `inferencer` +For guidance on adapting an existing colocated in-memory loop to graph-store mode, see the +[graph-store conversion guide](../../docs/user_guide/overview/graph_store_conversion.md). 
+ ```{toctree} :maxdepth: 2 :hidden: diff --git a/pyproject.toml b/pyproject.toml index b22e1f0e6..3fe88014c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -136,7 +136,7 @@ docs = [ "sphinx-rtd-theme==2.0.0", "pydata-sphinx-theme==0.16.1", "sphinx-tabs==3.4.5", - + "sphinxcontrib-mermaid==2.0.2", ] typing-stubs = [ "pandas-stubs==2.2.2.240807", diff --git a/uv.lock b/uv.lock index 1909a0793..121e14b2f 100644 --- a/uv.lock +++ b/uv.lock @@ -794,6 +794,7 @@ dev = [ { name = "sphinx-hoverxref" }, { name = "sphinx-rtd-theme" }, { name = "sphinx-tabs" }, + { name = "sphinxcontrib-mermaid" }, { name = "ty" }, { name = "types-psutil" }, { name = "types-pyyaml" }, @@ -816,6 +817,7 @@ docs = [ { name = "sphinx-hoverxref" }, { name = "sphinx-rtd-theme" }, { name = "sphinx-tabs" }, + { name = "sphinxcontrib-mermaid" }, ] gigl-core-build-backend = [ { name = "pybind11" }, @@ -920,6 +922,7 @@ dev = [ { name = "sphinx-hoverxref", specifier = "==1.3.0" }, { name = "sphinx-rtd-theme", specifier = "==2.0.0" }, { name = "sphinx-tabs", specifier = "==3.4.5" }, + { name = "sphinxcontrib-mermaid", specifier = "==2.0.2" }, { name = "ty", specifier = "==0.0.31" }, { name = "types-psutil", specifier = "==7.0.0.20250401" }, { name = "types-pyyaml", specifier = "~=6.0.12" }, @@ -942,6 +945,7 @@ docs = [ { name = "sphinx-hoverxref", specifier = "==1.3.0" }, { name = "sphinx-rtd-theme", specifier = "==2.0.0" }, { name = "sphinx-tabs", specifier = "==3.4.5" }, + { name = "sphinxcontrib-mermaid", specifier = "==2.0.2" }, ] gigl-core-build-backend = [ { name = "pybind11", specifier = ">=2.12" }, @@ -3810,6 +3814,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c2/42/4c8646762ee83602e3fb3fbe774c2fac12f317deb0b5dbeeedd2d3ba4b77/sphinxcontrib_jsmath-1.0.1-py2.py3-none-any.whl", hash = "sha256:2ec2eaebfb78f3f2078e73666b1415417a116cc848b72e5172e596c871103178", size = 5071, upload-time = "2019-01-21T16:10:14.333Z" }, ] +[[package]] +name = "sphinxcontrib-mermaid" +version = 
"2.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jinja2" }, + { name = "pyyaml" }, + { name = "sphinx" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/19/75/3a1cc926da8c563c58ddc124a7b3fe5ccadcae96c96e3a6f8ac3653a210a/sphinxcontrib_mermaid-2.0.2.tar.gz", hash = "sha256:f09576c78ca93fa0e3034fd9c45aaffa7c44ab449de9c43b8b8d262afe52bc66", size = 19265, upload-time = "2026-05-05T13:59:02.959Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/16/8d/93be7e0f7fa915a576859b3bfac7a7baa3303181c44d7db7eefbd3e8a69f/sphinxcontrib_mermaid-2.0.2-py3-none-any.whl", hash = "sha256:d862e514991279fb4816302c5cfe167d2557bf3ce7125ae0cb47dac80a0f46ce", size = 14094, upload-time = "2026-05-05T13:59:01.585Z" }, +] + [[package]] name = "sphinxcontrib-qthelp" version = "2.0.0"