einsum: local-slice fast path for batched (Hadamard) contractions #561

Closed
evaleev wants to merge 1 commit into master from evaleev/feature/einsum-hadamard-local-fastpath

evaleev (Member) commented Jun 11, 2026

Summary

The generalized batched-contraction einsum path — Hadamard indices coexisting with external/contracted indices, e.g. C(b,i,k) = A(b,i,j) * B(b,j,k) — previously ran one MPI_Comm_split + a fresh sub-World + a sub-World fence per Hadamard tile. That is O(#Hadamard-tiles) collectives, and the per-tile sub-World construction/teardown dominates wall time even at np=1.
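For reference, the contraction named above can be written as a plain loop nest. This is a minimal sketch, not TiledArray code: the name `batched_contract` and the dense row-major layout are assumptions made only for illustration. It shows that each value of the Hadamard index `b` selects an independent matrix product, while `j` is summed.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative only: C(b,i,k) = sum_j A(b,i,j) * B(b,j,k) on dense row-major
// buffers. `batched_contract` is a hypothetical name, not a TiledArray API.
std::vector<double> batched_contract(const std::vector<double>& a,
                                     const std::vector<double>& b,
                                     std::size_t nb, std::size_t ni,
                                     std::size_t nj, std::size_t nk) {
  assert(a.size() == nb * ni * nj);
  assert(b.size() == nb * nj * nk);
  std::vector<double> c(nb * ni * nk, 0.0);
  for (std::size_t s = 0; s < nb; ++s) {  // Hadamard (batch) index: not summed
    const double* as = a.data() + s * ni * nj;
    const double* bs = b.data() + s * nj * nk;
    double* cs = c.data() + s * ni * nk;
    for (std::size_t i = 0; i < ni; ++i)
      for (std::size_t j = 0; j < nj; ++j)  // contracted index: summed
        for (std::size_t k = 0; k < nk; ++k)
          cs[i * nk + k] += as[i * nj + j] * bs[j * nk + k];
  }
  return c;
}
```

Because the outer loop iterations share no data, a slice whose inputs sit on one rank needs no communication to compute its part of `C`.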

This PR adds a local-slice fast path: a Hadamard slice whose (single) input tiles are all owned by one rank is contracted locally on that rank with a direct Tensor::gemm — no comm-split, no sub-World, no make_array/DistEval, no fence. Whether a slice is local is decided purely from global pmap/trange metadata, so every rank reaches the same verdict and the comm-splits for the remaining (genuinely cross-rank) slices stay in lockstep. For batch-blocked data every slice is single-owner, so the whole batched contraction runs communication-free.

Slices that span ranks, multi-tile external/contracted dimensions, and tensor-of-tensor tiles fall back to the unchanged sub-World path.
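The locality verdict described above could be sketched as follows. The real pmap interface is not shown in this PR, so `blocked_owner`, `cyclic_owner`, and `local_slice_owner` are hypothetical stand-ins; the only property that matters is that the owner of a tile is a pure function of global metadata, so every rank computes the same answer without communicating.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>

// Hypothetical process maps over the batch (Hadamard) tile index. Both are
// pure functions of global metadata, so all ranks agree on the result.
int blocked_owner(std::size_t tile, std::size_t ntiles, int nranks) {
  assert(ntiles > 0 && nranks > 0 && tile < ntiles);
  return static_cast<int>(tile * static_cast<std::size_t>(nranks) / ntiles);
}

int cyclic_owner(std::size_t tile, int nranks) {
  assert(nranks > 0);
  return static_cast<int>(tile % static_cast<std::size_t>(nranks));
}

// Verdict for one Hadamard slice with a single A tile and a single B tile:
// the owning rank if both tiles live on the same rank, otherwise nullopt
// (such a slice would take the distributed fallback).
std::optional<int> local_slice_owner(int a_owner, int b_owner) {
  if (a_owner == b_owner) return a_owner;
  return std::nullopt;
}
```

With both operands batch-blocked the same way, every slice yields an owner; mixing a blocked map for one operand with a cyclic map for the other produces slices that span ranks.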

Why it's correct

  • The "is this slice local?" verdict comes only from pmap/trange metadata (identical on all ranks) → no divergence in the collective sequence for the distributed-fallback slices.
  • The local contraction C(e_A,e_B) = A(e_A,i) * B(e_B,i) is a canonical GemmHelper(NoTranspose, Transpose, ...); e = (a|b)-(a&b) orders A-externals before B-externals, matching gemm's (left_outer, right_outer) output, so the existing permute/harvest fixes up the layout unchanged.
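The second bullet can be made concrete with a naive reference kernel. This is a sketch under stated assumptions: `gemm_nt` is a hypothetical name, the externals of each operand are fused into a single row index, and the data is dense row-major. It is not the `Tensor::gemm` implementation; it only shows the (NoTranspose, Transpose) shape, where the contracted index is the trailing index of both operands and the output is ordered (A-externals, B-externals).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative only: C(m,n) = sum_p A(m,p) * B(n,p), all row-major.
// m plays the role of the fused A-externals, n the fused B-externals,
// and p the fused contracted indices.
std::vector<double> gemm_nt(const std::vector<double>& a,
                            const std::vector<double>& b,
                            std::size_t m, std::size_t n, std::size_t k) {
  assert(a.size() == m * k);
  assert(b.size() == n * k);
  std::vector<double> c(m * n, 0.0);
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t j = 0; j < n; ++j)
      for (std::size_t p = 0; p < k; ++p)
        c[i * n + j] += a[i * k + p] * b[j * k + p];  // B read as transposed
  return c;
}
```

For `A = [[1, 2], [3, 4]]` and `B = [[5, 6], [7, 8]]`, the result is `[[17, 23], [39, 53]]`, i.e. `A` times the transpose of `B`.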

Performance

Speedup over the legacy path for C(b,i,k) = A(b,i,j) * B(b,j,k), 64 batch tiles, flops held constant (examples/tot_bench/batched_contraction_attribution.cpp):

| Hadamard tiles | np=1 | np=2 |
|---|---|---|
| 1 | 1.36× | 1.10× |
| 16 | 4.3× | 2.2× |
| 64 | 12.9× | 5.3× |

The legacy path's overhead grew from ~2× to ~26× over the raw-BLAS flop floor as the batch was split into more tiles; with the fast path that per-Hadamard-tile machinery cost is gone and the overhead is flat in the number of Hadamard tiles.

Toggle

On by default. Set TA_EINSUM_HADAMARD_LOCAL_FASTPATH_DISABLED=1 (or flip detail::einsum_hadamard_local_fastpath_disabled()) to force the legacy sub-World path — kept as a safety valve and differential-correctness hook (mirrors regime_a_strided_disabled()).
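One way such a toggle could be wired is sketched below. The environment variable name is taken from this PR; everything else is an assumption: `env_flag` and `fastpath_disabled_sketch` are hypothetical names, and treating only the exact value `1` as "disabled" is an illustrative parse rule, not necessarily what `detail::einsum_hadamard_local_fastpath_disabled()` does.

```cpp
#include <cassert>
#include <cstdlib>
#include <cstring>

// Hypothetical helper: true only if the variable is set to exactly "1".
bool env_flag(const char* name) {
  assert(name != nullptr);
  const char* value = std::getenv(name);
  return value != nullptr && std::strcmp(value, "1") == 0;
}

// Hypothetical accessor: initialized once from the environment, and
// returned by reference so a test can flip it at run time.
bool& fastpath_disabled_sketch() {
  static bool disabled =
      env_flag("TA_EINSUM_HADAMARD_LOCAL_FASTPATH_DISABLED");
  return disabled;
}
```

Returning a reference lets a differential-correctness test run the same contraction with the flag set to `true` and then `false` in one process and compare the results.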

Testing

  • einsum + einsum_tot suites pass at np=1 and np=2 with both the new and legacy paths (Release build).
  • Full faithful (TA_ASSERT_THROW) suite passes at np=1.
  • The np=2 ASan debug build hits a pre-existing MADNESS comm-thread teardown SEGV that reproduces identically with the legacy path forced (i.e. master behavior) in unrelated suites (index_list, bipartite_index_list); unrelated to this change.

Notes / follow-ups

Out of scope here, but the benchmark also surfaces remaining headroom: the multi-tile-external case still uses the sub-World fallback (could be a local tiled-gemm loop); the small per-batch GEMMs run at a fraction of peak (batched/strided BLAS would lift the flop floor); and the entry world.gop.fence() hotfix may now be removable on the common path.

The generalized batched-contraction path -- Hadamard indices coexisting
with external/contracted indices, e.g. C(b,i,k) = A(b,i,j) * B(b,j,k) --
ran one MPI_Comm_split + a fresh sub-World + a sub-World fence per
Hadamard tile. That is O(#Hadamard-tiles) collectives, and the per-tile
sub-World construction/teardown dominates wall time even at np=1.

Add a fast path: a Hadamard slice whose (single) input tiles are all
owned by one rank is contracted locally on that rank with a direct
Tensor::gemm -- no comm-split, no sub-World, no make_array/DistEval, no
fence. Whether a slice is local is decided purely from global pmap/trange
metadata, so every rank reaches the same verdict and the comm-splits for
the remaining (genuinely cross-rank) slices stay in lockstep. For
batch-blocked data every slice is single-owner, so the whole batched
contraction runs communication-free.

Slices that span ranks, multi-tile external/contracted dimensions, and
tensor-of-tensor tiles fall back to the unchanged sub-World path.

Measured speedup over the legacy path (C(b,i,k) = A(b,i,j) * B(b,j,k),
64 batch tiles, flops held constant): 12.9x at np=1, 5.3x at np=2;
neutral for a single Hadamard tile. The per-Hadamard-tile machinery cost
-- which previously grew the overhead from 2x to 26x as the batch was
split into more tiles -- is gone; overhead vs the raw-BLAS flop floor is
now flat in the number of Hadamard tiles.

The fast path is on by default; set
TA_EINSUM_HADAMARD_LOCAL_FASTPATH_DISABLED=1 (or flip
detail::einsum_hadamard_local_fastpath_disabled()) to force the legacy
sub-World path as a safety valve / differential-correctness hook.

examples/tot_bench/batched_contraction_attribution.cpp is an attribution
benchmark for this case: legacy vs fast path vs a raw-BLAS flop floor,
with a constant-flops granularity sweep.

evaleev (Member, Author) commented Jun 11, 2026

This has no effect on CSV-CC; it only affects plain tensor products with batching indices.

evaleev (Member, Author) commented Jun 11, 2026

This will be superseded by proper general product support in the expression layer.

evaleev closed this Jun 11, 2026