From c6203932c594a3e466983207e02ede51f8edce7e Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Thu, 11 Jun 2026 01:47:23 -0400 Subject: [PATCH 01/25] einsum: runtime-gated attribution profiler (TA_EINSUM_INSTRUMENT) Buckets per einsum call: entry_fence / setup / commsplit+world / retile/make_array / contract+fence / harvest / local_kernel / teardown, keyed by branch (hadamard-reduction-local, generalized-subworld, generalized-inner-perm-recurse) and contraction annotation; dumped to stderr at exit. Zero overhead when disabled. Establishes the baseline attribution for replacing the per-Hadamard-tile sub-World decomposition with first-class general-product (h+e+c) support. PNO-CCSD c6h14/cc-pVDZ baseline (np=1, 3 CC iters): 17.9 s einsum-region, 1412 Hadamard slices = 1412 sub-Worlds; retile/make_array 25.5% + harvest 3.0% + teardown 4.2% non-numeric, contract+fence 30.5%. --- src/TiledArray/einsum/einsum_instrument.h | 227 ++++++++++++++++++++++ src/TiledArray/einsum/tiledarray.h | 94 ++++++--- 2 files changed, 295 insertions(+), 26 deletions(-) create mode 100644 src/TiledArray/einsum/einsum_instrument.h diff --git a/src/TiledArray/einsum/einsum_instrument.h b/src/TiledArray/einsum/einsum_instrument.h new file mode 100644 index 0000000000..215f6fcde4 --- /dev/null +++ b/src/TiledArray/einsum/einsum_instrument.h @@ -0,0 +1,227 @@ +/* + * This file is a part of TiledArray. + * Copyright (C) 2026 Virginia Tech + * + * einsum/einsum_instrument.h + * + * Lightweight, runtime-gated attribution profiler for the generalized + * batched-contraction einsum (Hadamard indices coexisting with + * external/contracted indices, including tensor-of-tensor operands). + * + * Goal: separate time spent in the *machinery* of the per-Hadamard-tile + * sub-World scheme (MPI_Comm_split, sub-World construction/teardown, + * make_array/retile, harvest, entry fence) from time spent in the actual + * numeric contraction. Enabled by setting TA_EINSUM_INSTRUMENT to any + * non-empty value other than "0"; zero (modulo a cached bool load) overhead + * when off. Results are dumped to std::cerr at static teardown. + */ + +#ifndef TILEDARRAY_EINSUM_EINSUM_INSTRUMENT_H__INCLUDED +#define TILEDARRAY_EINSUM_EINSUM_INSTRUMENT_H__INCLUDED + +#include "TiledArray/util/time.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace TiledArray::detail { + +/// Runtime gate for the einsum attribution profiler. Toggled from the +/// environment via TA_EINSUM_INSTRUMENT (any non-empty value other than "0" +/// enables). Mirrors einsum_hadamard_local_fastpath_disabled(). +inline bool einsum_instrument_enabled() { + static const bool flag = [] { + const char *e = std::getenv("TA_EINSUM_INSTRUMENT"); + return e != nullptr && e[0] != '\0' && std::string_view(e) != "0"; + }(); + return flag; +} + +/// Time buckets attributed per einsum call. NUMERICS is the only bucket that +/// is genuine flops; the rest is per-Hadamard-tile / per-call machinery. +enum class EinsumBucket : std::size_t { + EntryFence = 0, ///< the blocking world.gop.fence() at function entry + Setup, ///< index algebra, range maps, trange, inner-op build + CommSplitWorld, ///< MPI_Comm_split + sub-World construction + Retile, ///< make_array of the per-slice input sub-arrays + ContractFence, ///< expr-assign enqueue + per-slice sub-World fence + ///< (numerics for the sub-World path live here) + Harvest, ///< extracting completed result tiles from the sub-array + LocalKernel, ///< direct local contraction (arena / element-op gemm) + Teardown, ///< build_C_array + final sub-World fences + COUNT +}; + +inline const char *einsum_bucket_name(EinsumBucket b) { + switch (b) { + case EinsumBucket::EntryFence: + return "entry_fence"; + case EinsumBucket::Setup: + return "setup"; + case EinsumBucket::CommSplitWorld: + return "commsplit+world"; + case EinsumBucket::Retile: + return "retile/make_array"; + case EinsumBucket::ContractFence: + return "contract+fence"; + case EinsumBucket::Harvest: + return "harvest"; + case EinsumBucket::LocalKernel: + return "local_kernel"; + case EinsumBucket::Teardown: + return "teardown"; + default: + return "?"; + } +} + +struct EinsumProfileEntry { + std::uint64_t calls = 0; + std::uint64_t slices = 0; ///< total Hadamard slices iterated (owned) + std::uint64_t subworlds = 0; ///< total sub-Worlds constructed + std::uint64_t localslices = 0; ///< slices handled by a local kernel + std::array(EinsumBucket::COUNT)> ns{}; + + void merge(const EinsumProfileEntry &o) { + calls += o.calls; + slices += o.slices; + subworlds += o.subworlds; + localslices += o.localslices; + for (std::size_t k = 0; k < ns.size(); ++k) ns[k] += o.ns[k]; + } + std::int64_t total_ns() const { + std::int64_t t = 0; + for (auto v : ns) t += v; + return t; + } +}; + +/// Process-wide accumulator, keyed by " | ". Dumps a +/// sorted attribution table to std::cerr at static teardown when enabled. +class EinsumProfiler { + public: + static EinsumProfiler &instance() { + static EinsumProfiler p; + return p; + } + + void merge(const std::string &key, const EinsumProfileEntry &e) { + std::lock_guard g(mtx_); + by_key_[key].merge(e); + } + + ~EinsumProfiler() { + if (einsum_instrument_enabled()) dump(std::cerr); + } + + void dump(std::ostream &os) { + std::lock_guard g(mtx_); + if (by_key_.empty()) return; + // collect and sort by total time descending + std::vector> rows( + by_key_.begin(), by_key_.end()); + std::sort(rows.begin(), rows.end(), [](auto const &a, auto const &b) { + return a.second.total_ns() > b.second.total_ns(); + }); + std::int64_t grand = 0; + EinsumProfileEntry tot; + for (auto const &[k, e] : rows) { + grand += e.total_ns(); + tot.merge(e); + } + auto s = [](std::int64_t ns) { return ns / 1e9; }; + os << "\n================ TA einsum attribution (TA_EINSUM_INSTRUMENT) " + "================\n"; + os << "total einsum-region time: " << s(grand) << " s over " << tot.calls + << " calls, " << tot.slices << " slices, " << tot.subworlds + << " sub-Worlds, " << tot.localslices << " local slices\n"; + // aggregate bucket breakdown + os << "-- aggregate by bucket --\n"; + for (std::size_t k = 0; k < tot.ns.size(); ++k) { + if (tot.ns[k] == 0) continue; + os << " " << std::setw(18) << std::left + << einsum_bucket_name(static_cast(k)) << " " + << std::setw(9) << std::right << s(tot.ns[k]) << " s (" + << std::setw(5) << std::fixed << std::setprecision(1) + << (grand ? 100.0 * tot.ns[k] / grand : 0.0) << "%)\n"; + } + // per-key rows + os << "-- by contraction (branch | shape), sorted by total time --\n"; + for (auto const &[k, e] : rows) { + std::int64_t t = e.total_ns(); + os << " " << s(t) << " s x" << e.calls << " slices=" << e.slices + << " subW=" << e.subworlds << " local=" << e.localslices << "\n " + << k << "\n "; + for (std::size_t b = 0; b < e.ns.size(); ++b) { + if (e.ns[b] == 0) continue; + os << einsum_bucket_name(static_cast(b)) << "=" + << std::fixed << std::setprecision(1) + << (t ? 100.0 * e.ns[b] / t : 0.0) << "% "; + } + os << "\n"; + } + os << "============================================================" + "====================\n"; + os.flush(); + } + + private: + std::mutex mtx_; + std::map by_key_; +}; + +/// Per-call accumulator. Construct once near the top of an einsum call; its +/// destructor merges into the process-wide profiler. No-op when disabled. +struct EinsumCall { + bool active; + std::string label; ///< shape annotation (a;A * b;B -> c;C) + const char *branch = "?"; ///< which einsum branch handled this call + EinsumProfileEntry e; + + explicit EinsumCall(std::string lbl) + : active(einsum_instrument_enabled()), label(std::move(lbl)) { + if (active) e.calls = 1; + } + ~EinsumCall() { + if (active) + EinsumProfiler::instance().merge(std::string(branch) + " | " + label, e); + } + void add(EinsumBucket b, std::int64_t ns) { + if (active) e.ns[static_cast(b)] += ns; + } + void add_slices(std::uint64_t n) { + if (active) e.slices += n; + } + void add_subworld() { + if (active) ++e.subworlds; + } + void add_localslice() { + if (active) ++e.localslices; + } +}; + +/// RAII region timer; adds elapsed wall time to a bucket of an EinsumCall. +struct EinsumTimer { + EinsumCall *call; + EinsumBucket bucket; + bool on; + time_point t0; + EinsumTimer(EinsumCall &c, EinsumBucket b) + : call(&c), bucket(b), on(c.active), t0(on ? now() : time_point{}) {} + ~EinsumTimer() { + if (on) call->add(bucket, duration_in_ns(t0, now())); + } +}; + +} // namespace TiledArray::detail + +#endif // TILEDARRAY_EINSUM_EINSUM_INSTRUMENT_H__INCLUDED diff --git a/src/TiledArray/einsum/tiledarray.h b/src/TiledArray/einsum/tiledarray.h index 89fd0db4cf..099b19037f 100644 --- a/src/TiledArray/einsum/tiledarray.h +++ b/src/TiledArray/einsum/tiledarray.h @@ -3,6 +3,7 @@ #include "TiledArray/conversions/make_array.h" #include "TiledArray/dist_array.h" +#include "TiledArray/einsum/einsum_instrument.h" #include "TiledArray/einsum/index.h" #include "TiledArray/einsum/range.h" #include "TiledArray/expressions/fwd.h" @@ -468,7 +469,12 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, // blocking calls // TODO figure out why having free threads left after blocking MPI split // still not enough to ensure progress + const auto _ein_entry_t0 = + detail::einsum_instrument_enabled() ? now() : time_point{}; world.gop.fence(); + const std::int64_t _ein_entry_fence_ns = + detail::einsum_instrument_enabled() ? duration_in_ns(_ein_entry_t0, now()) + : 0; using ArrayA = std::remove_cv_t; using ArrayB = std::remove_cv_t; @@ -518,6 +524,12 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, "General product between inner tensors not supported"); } + // einsum attribution profiler (TA_EINSUM_INSTRUMENT); no-op when disabled + detail::EinsumCall _ein_call{(std::string)a + inner.a + " * " + + (std::string)b + inner.b + " -> " + + (std::string)c + inner.c}; + _ein_call.add(detail::EinsumBucket::EntryFence, _ein_entry_fence_ns); + if constexpr (DeNestFlag == DeNest::True) { static_assert(detail::nested_rank == detail::nested_rank && detail::nested_rank == 2); @@ -730,6 +742,9 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, if (!e) { // hadamard reduction + _ein_call.branch = "hadamard-reduction-local"; + const auto _ein_he_t0 = _ein_call.active ? now() : time_point{}; + auto &[A, B] = AB; TiledRange trange(range_map[i]); RangeProduct tiles; @@ -866,6 +881,9 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, C_local_tiles.emplace_back(std::move(c), std::move(tile)); } + _ein_call.add(detail::EinsumBucket::LocalKernel, + _ein_call.active ? duration_in_ns(_ein_he_t0, now()) : 0); + build_C_array(); return C.array; @@ -873,11 +891,15 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, // generalized contraction + _ein_call.branch = "generalized-subworld"; + const auto _ein_gen_t0 = _ein_call.active ? now() : time_point{}; + if constexpr (IsArrayToT) { if (inner.C != inner.h + inner.e) { // when inner tensor permutation is non-trivial (could be potentially // elided by extending this function (@c einsum) to take into account // of inner tensor's permutations) + _ein_call.branch = "generalized-inner-perm-recurse"; auto temp_annot = std::string(c) + ";" + std::string(inner.h + inner.e); ArrayC temp = einsum(tnsrExprA, tnsrExprB, Einsum::idx(temp_annot), world); @@ -899,13 +921,21 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, std::invoke(update_tr, std::get<0>(AB)); std::invoke(update_tr, std::get<1>(AB)); + _ein_call.add(detail::EinsumBucket::Setup, + _ein_call.active ? duration_in_ns(_ein_gen_t0, now()) : 0); + // iterates over tiles of hadamard indices for (Index h : H.tiles) { auto &[A, B] = AB; + _ein_call.add_slices(1); + const auto _ein_cs0 = _ein_call.active ? now() : time_point{}; auto own = A.own(h) || B.own(h); auto comm = madness::blocking_invoke(&SafeMPI::Intracomm::Split, world.mpi.comm(), own, world.rank()); worlds.push_back(std::make_unique(comm)); + _ein_call.add_subworld(); + _ein_call.add(detail::EinsumBucket::CommSplitWorld, + _ein_call.active ? duration_in_ns(_ein_cs0, now()) : 0); auto &owners = worlds.back(); if (!own) continue; size_t batch = 1; @@ -933,38 +963,50 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, *owners, term.ei_tiled_range, term.local_tiles.begin(), term.local_tiles.end(), replicated); }; - std::invoke(retile, std::get<0>(AB)); - std::invoke(retile, std::get<1>(AB)); - - C.ei(C.expr) = (A.ei(A.expr) * B.ei(B.expr)).set_world(*owners); - A.ei.defer_deleter_to_next_fence(); - B.ei.defer_deleter_to_next_fence(); - A.ei = ArrayA(); - B.ei = ArrayB(); - // why omitting this fence leads to deadlock? - owners->gop.fence(); - for (Index e : C.tiles) { - if (!C.ei.is_local(e)) continue; - if (C.ei.is_zero(e)) continue; - // TODO no need for immediate evaluation - auto tile = C.ei.find_local(e).get(); - assert(tile.nbatch() == batch); - const Permutation &P = C.permutation; - auto c = apply(P, h + e); - auto shape = C.array.trange().tile(c); - shape = apply_inverse(P, shape); - tile = tile.reshape(shape); - if (P) tile = tile.permute(P); - C_local_tiles.emplace_back(std::move(c), std::move(tile)); + { + detail::EinsumTimer _t(_ein_call, detail::EinsumBucket::Retile); + std::invoke(retile, std::get<0>(AB)); + std::invoke(retile, std::get<1>(AB)); + } + + { + detail::EinsumTimer _t(_ein_call, detail::EinsumBucket::ContractFence); + C.ei(C.expr) = (A.ei(A.expr) * B.ei(B.expr)).set_world(*owners); + A.ei.defer_deleter_to_next_fence(); + B.ei.defer_deleter_to_next_fence(); + A.ei = ArrayA(); + B.ei = ArrayB(); + // why omitting this fence leads to deadlock? + owners->gop.fence(); + } + { + detail::EinsumTimer _t(_ein_call, detail::EinsumBucket::Harvest); + for (Index e : C.tiles) { + if (!C.ei.is_local(e)) continue; + if (C.ei.is_zero(e)) continue; + // TODO no need for immediate evaluation + auto tile = C.ei.find_local(e).get(); + assert(tile.nbatch() == batch); + const Permutation &P = C.permutation; + auto c = apply(P, h + e); + auto shape = C.array.trange().tile(c); + shape = apply_inverse(P, shape); + tile = tile.reshape(shape); + if (P) tile = tile.permute(P); + C_local_tiles.emplace_back(std::move(c), std::move(tile)); + } } // mark for lazy deletion C.ei = ArrayC(); } - build_C_array(); + { + detail::EinsumTimer _t(_ein_call, detail::EinsumBucket::Teardown); + build_C_array(); - for (auto &w : worlds) { - w->gop.fence(); + for (auto &w : worlds) { + w->gop.fence(); + } } return C.array; From 7669213c2f13ed422f765dfb1b9a4fbfba8f6556 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Thu, 11 Jun 2026 02:08:40 -0400 Subject: [PATCH 02/25] expressions: classify + lay out general products (TensorProduct::General) Phase A of first-class general-product (fused + contracted + free indices) support in the expression layer (target: PNO-CC batched contractions, replacing einsum's per-Hadamard-tile sub-World decomposition): - compute_product_type(left, right, target) now returns TensorProduct::General when a shared index survives into the target (fused) alongside contracted and/or free indices, incl. the Hadamard-reduction case (args related by permutation, target drops indices). The 2-arg overload is unchanged (bottom-up convention: shared => contracted). - new GeneralPermutationOptimizer: canonical layouts left (h, e_A, c), right (h, c, e_B), result (h, e_A, e_B) -- the GEMM-canonical layout with fused indices prepended so a consuming batched-GEMM op can fold them into the tile batch dimension by reshape; exposes the h/c/e_A/e_B partition for engine consumption. Requires target indices (fused-vs-contracted is undecidable bottom-up); validates against implicit reductions. - BinaryEngine::init_indices_ and MultEngine/ScalMultEngine route General through the new optimizer; ContEngine::product_type() accessor admits General. - evaluation is gated with an informative exception (use TiledArray::einsum() meanwhile) until the batched-Summa DistEval lands (Phase B); previously such expressions misclassified as pure contractions and died in target-permutation resolution. - unit tests: classification, optimizer layouts/partitions/errors, and the end-to-end expression gate (tests/general_product.cpp). --- src/TiledArray/expressions/binary_engine.h | 17 +- src/TiledArray/expressions/cont_engine.h | 175 +++++++++++---------- src/TiledArray/expressions/mult_engine.h | 112 +++++++++++-- src/TiledArray/expressions/permopt.h | 162 +++++++++++++++++++ src/TiledArray/expressions/product.h | 32 +++- tests/CMakeLists.txt | 1 + tests/general_product.cpp | 170 ++++++++++++++++++++ 7 files changed, 568 insertions(+), 101 deletions(-) create mode 100644 tests/general_product.cpp diff --git a/src/TiledArray/expressions/binary_engine.h b/src/TiledArray/expressions/binary_engine.h index d7c4fda6e2..7d0e3ff0d2 100644 --- a/src/TiledArray/expressions/binary_engine.h +++ b/src/TiledArray/expressions/binary_engine.h @@ -100,13 +100,20 @@ class BinaryEngine : public ExprEngine { template void init_indices_(const BipartiteIndexList& target_indices = {}) { static_assert(OuterProductType == TensorProduct::Contraction || - OuterProductType == TensorProduct::Hadamard); + OuterProductType == TensorProduct::Hadamard || + OuterProductType == TensorProduct::General); + // N.B. a General product's layout depends on the target (the role of a + // shared index -- fused vs contracted -- is defined by which indices the + // target keeps), so OuterProductType == General requires nonempty + // target_indices (GeneralPermutationOptimizer throws otherwise). // prefer to permute the arg with fewest leaves to try to minimize the // number of possible permutations - using permopt_type = - std::conditional_t; + using permopt_type = std::conditional_t< + OuterProductType == TensorProduct::Contraction, + GEMMPermutationOptimizer, + std::conditional_t>; std::shared_ptr outer_opt, inner_opt; if (!target_indices) { diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index 76be6fc3c1..e723952a34 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -139,28 +139,33 @@ class ContEngine : public BinaryEngine { std::function arena_strided_dgemm_ce_e_tile_op_; ///< whole-tile ce+e strided DGEMM op - ///< (arena inner OUTER-PRODUCT under an - ///< outer contraction); null otherwise + ///< (arena inner OUTER-PRODUCT under + ///< an outer contraction); null + ///< otherwise std::function - arena_strided_dgemm_ce_ce_right_tile_op_; ///< whole-tile ce+ce strided DGEMM op - ///< (arena inner CONTRACTION under an - ///< outer contraction; right-external - ///< rides BLAS M, left-external rides - ///< an outer loop); null otherwise. - ///< Mutually exclusive with - ///< arena_strided_dgemm_ce_e_tile_op_ - ///< (disjoint num_contract_ranks() - ///< gates) + arena_strided_dgemm_ce_ce_right_tile_op_; ///< whole-tile ce+ce strided + ///< DGEMM op (arena inner + ///< CONTRACTION under an outer + ///< contraction; + ///< right-external rides BLAS + ///< M, left-external rides an + ///< outer loop); null + ///< otherwise. Mutually + ///< exclusive with + ///< arena_strided_dgemm_ce_e_tile_op_ + ///< (disjoint + ///< num_contract_ranks() + ///< gates) std::function arena_strided_dgemm_ce_ce_left_tile_op_; ///< whole-tile ce+ce strided - ///< DGEMM op, LEFT-clean mirror: - ///< left-external rides BLAS M, - ///< right-external rides an - ///< outer loop. Mutually - ///< exclusive with the ce_e and - ///< ce_ce_right ops. + ///< DGEMM op, LEFT-clean + ///< mirror: left-external rides + ///< BLAS M, right-external + ///< rides an outer loop. + ///< Mutually exclusive with the + ///< ce_e and ce_ce_right ops. using arena_plan_storage_t = TiledArray::detail::arena_plan_storage_t; @@ -186,9 +191,10 @@ class ContEngine : public BinaryEngine { TensorProduct product_type() const { TA_ASSERT(product_type_ != TensorProduct::Invalid); // init_indices() must initialize this - /// only Hadamard and contraction are supported now + /// only Hadamard, contraction, and general are supported now TA_ASSERT(product_type_ == TensorProduct::Hadamard || - product_type_ == TensorProduct::Contraction); + product_type_ == TensorProduct::Contraction || + product_type_ == TensorProduct::General); return product_type_; } @@ -359,9 +365,11 @@ class ContEngine : public BinaryEngine { if (this->arena_strided_dgemm_ce_e_tile_op_) op_.set_strided_oprod_op(this->arena_strided_dgemm_ce_e_tile_op_); if (this->arena_strided_dgemm_ce_ce_right_tile_op_) - op_.set_strided_oprod_op(this->arena_strided_dgemm_ce_ce_right_tile_op_); + op_.set_strided_oprod_op( + this->arena_strided_dgemm_ce_ce_right_tile_op_); if (this->arena_strided_dgemm_ce_ce_left_tile_op_) - op_.set_strided_oprod_op(this->arena_strided_dgemm_ce_ce_left_tile_op_); + op_.set_strided_oprod_op( + this->arena_strided_dgemm_ce_ce_left_tile_op_); } // Plan ownership transferred to op_; mark carrier slot empty so any // later use of arena_plan_ reads as "no plan" rather than moved-from. @@ -416,9 +424,11 @@ class ContEngine : public BinaryEngine { if (this->arena_strided_dgemm_ce_e_tile_op_) op_.set_strided_oprod_op(this->arena_strided_dgemm_ce_e_tile_op_); if (this->arena_strided_dgemm_ce_ce_right_tile_op_) - op_.set_strided_oprod_op(this->arena_strided_dgemm_ce_ce_right_tile_op_); + op_.set_strided_oprod_op( + this->arena_strided_dgemm_ce_ce_right_tile_op_); if (this->arena_strided_dgemm_ce_ce_left_tile_op_) - op_.set_strided_oprod_op(this->arena_strided_dgemm_ce_ce_left_tile_op_); + op_.set_strided_oprod_op( + this->arena_strided_dgemm_ce_ce_left_tile_op_); } // Plan ownership transferred to op_; mark carrier slot empty so any // later use of arena_plan_ reads as "no plan" rather than moved-from. @@ -791,23 +801,21 @@ class ContEngine : public BinaryEngine { // 3-operand predicate so a mixed-operand contraction (e.g. a // view/double result with a non-view or non-double operand, or // float/complex inner) stays on the generic per-cell path and - // never instantiates the double-view-only kernel (which would be - // a hard compile error rather than a graceful fallback). - if constexpr (TiledArray::is_tensor_view_v< - result_tile_element_type> && - TiledArray::is_tensor_view_v< - left_tile_element_type> && - TiledArray::is_tensor_view_v< - right_tile_element_type> && - std::is_same_v && - std::is_same_v< - typename left_tile_element_type::numeric_type, - double> && - std::is_same_v< - typename right_tile_element_type::numeric_type, - double>) { + // never instantiates the double-view-only kernel (which would + // be a hard compile error rather than a graceful fallback). + if constexpr ( + TiledArray::is_tensor_view_v && + TiledArray::is_tensor_view_v && + TiledArray::is_tensor_view_v && + std::is_same_v< + typename result_tile_element_type::numeric_type, + double> && + std::is_same_v< + typename left_tile_element_type::numeric_type, + double> && + std::is_same_v< + typename right_tile_element_type::numeric_type, + double>) { if (contrreduce_op.gemm_helper().num_contract_ranks() == 0 && !bool(inner(this->perm_))) { const scalar_type factor = this->factor_; @@ -827,38 +835,40 @@ class ContEngine : public BinaryEngine { }; } // ce+ce (hce+ce): inner CONTRACTION (num_contract_ranks() >= - // 1) under outer contraction. One operand inner must be a pure - // contraction vector; that side's outer-external rides BLAS M - // with one strided DGEMM per (batch, other-external, + // 1) under outer contraction. One operand inner must be a + // pure contraction vector; that side's outer-external rides + // BLAS M with one strided DGEMM per (batch, other-external, // outer-contraction) cell. Two orientations (right-clean -> // ce_ce_right, left-clean -> ce_ce_left); see the either-side // rule below. Sibling of the ce+e arm above (disjoint - // num_contract_ranks gate) so at most one strided op installs. + // num_contract_ranks gate) so at most one strided op + // installs. const auto& inner_gh = contrreduce_op.gemm_helper(); const bool inner_contraction = inner_gh.num_contract_ranks() >= 1; // STRIDED-APPLICABILITY RULE (matrix x matrix exclusion). // The ce+ce core assumes the RIGHT inner cell is a pure // contraction vector R[k,μ̃](a4) -- i.e. the right operand - // carries NO inner external. When BOTH operand inners carry an - // external (a genuine inner matrix x matrix, e.g. + // carries NO inner external. When BOTH operand inners carry + // an external (a genuine inner matrix x matrix, e.g. // C(m,n;μ,ν) = A(m,k;μ,κ) * B(k,n;κ,ν)), riding μ̃ into BLAS M // would need a two-level stride the kernel cannot represent: // the per-cell `clean` probe fails and the GEMV fallback then // silently contributes nothing (the result cell volume P*Q no // longer matches the left cell). Refuse the install so such - // shapes take the generic per-cell contraction path. The right - // inner-external rank is right_rank - num_contract_ranks; the - // supported (right-clean) shape has it == 0. - // EITHER-SIDE rule: an inner contraction is strided-castable - // iff at least ONE operand inner is a pure contraction vector - // (no inner external). right-clean -> ce_ce_right (ride the - // right-external into BLAS M); left-clean -> ce_ce_left (ride - // the left-external into BLAS M). When BOTH inners carry an - // external (a genuine inner matrix x matrix, e.g. - // C(m,n;μ,ν) = A(m,k;μ,κ) * B(k,n;κ,ν)) neither fires and the - // generic per-cell path runs. An operand's inner-external rank - // is its rank - num_contract_ranks; clean == 0. + // shapes take the generic per-cell contraction path. The + // right inner-external rank is right_rank - + // num_contract_ranks; the supported (right-clean) shape has + // it == 0. EITHER-SIDE rule: an inner contraction is + // strided-castable iff at least ONE operand inner is a pure + // contraction vector (no inner external). right-clean -> + // ce_ce_right (ride the right-external into BLAS M); + // left-clean -> ce_ce_left (ride the left-external into BLAS + // M). When BOTH inners carry an external (a genuine inner + // matrix x matrix, e.g. C(m,n;μ,ν) = A(m,k;μ,κ) * B(k,n;κ,ν)) + // neither fires and the generic per-cell path runs. An + // operand's inner-external rank is its rank - + // num_contract_ranks; clean == 0. const bool right_inner_clean = inner_gh.right_rank() == inner_gh.num_contract_ranks(); const bool left_inner_clean = @@ -872,26 +882,28 @@ class ContEngine : public BinaryEngine { // the ridden operand must carry an outer external to ride. const bool right_has_ext = outer_size(this->right_indices_) > oc; - const bool left_has_ext = outer_size(this->left_indices_) > oc; + const bool left_has_ext = + outer_size(this->left_indices_) > oc; // canonical inner orientation: identity == "no inner // transpose". right core assumes L=(a1,a4), R=(a4); left core // assumes L=(a4), R=(a4,b1). Either way BOTH inner permtypes - // must be identity and there must be no inner result perm. This - // gate is LOAD-BEARING for correctness. + // must be identity and there must be no inner result perm. + // This gate is LOAD-BEARING for correctness. const bool inner_canonical = this->left_inner_permtype_ == TiledArray::expressions::PermutationType::identity && this->right_inner_permtype_ == TiledArray::expressions::PermutationType::identity && !bool(inner(this->perm_)); - // RELAXED gate. The strided kernel can fold a matrix_transpose - // of the EXTERNAL-carrying operand into the inner GEMM op flag - // (zero-copy), because matrix_transpose is a contiguous - // two-block swap (permopt) so the cell still flattens cleanly. - // The CLEAN (pure contraction vector) side must stay identity, - // the result inner must not be permuted, and a `general` inner - // perm still falls back. right arm: left carries the external - // (may be T), right is the vector (id). left arm: mirror. + // RELAXED gate. The strided kernel can fold a + // matrix_transpose of the EXTERNAL-carrying operand into the + // inner GEMM op flag (zero-copy), because matrix_transpose is + // a contiguous two-block swap (permopt) so the cell still + // flattens cleanly. The CLEAN (pure contraction vector) side + // must stay identity, the result inner must not be permuted, + // and a `general` inner perm still falls back. right arm: + // left carries the external (may be T), right is the vector + // (id). left arm: mirror. auto inner_pt_ok = [](TiledArray::expressions::PermutationType p) { return p == TiledArray::expressions::PermutationType:: @@ -916,12 +928,13 @@ class ContEngine : public BinaryEngine { const scalar_type factor = this->factor_; const bool left_inner_T = this->left_inner_permtype_ == - TiledArray::expressions::PermutationType::matrix_transpose; + TiledArray::expressions::PermutationType:: + matrix_transpose; this->arena_strided_dgemm_ce_ce_right_tile_op_ = - [factor, left_inner_T]( - result_tile_type& Cc, const left_tile_type& Lt, - const right_tile_type& Rt, - const math::GemmHelper& gh) { + [factor, left_inner_T](result_tile_type& Cc, + const left_tile_type& Lt, + const right_tile_type& Rt, + const math::GemmHelper& gh) { math::blas::integer Mo = 0, No = 0, Ko = 0; gh.compute_matrix_sizes(Mo, No, Ko, Lt.range(), Rt.range()); @@ -935,12 +948,13 @@ class ContEngine : public BinaryEngine { const scalar_type factor = this->factor_; const bool right_inner_T = this->right_inner_permtype_ == - TiledArray::expressions::PermutationType::matrix_transpose; + TiledArray::expressions::PermutationType:: + matrix_transpose; this->arena_strided_dgemm_ce_ce_left_tile_op_ = - [factor, right_inner_T]( - result_tile_type& Cc, const left_tile_type& Lt, - const right_tile_type& Rt, - const math::GemmHelper& gh) { + [factor, right_inner_T](result_tile_type& Cc, + const left_tile_type& Lt, + const right_tile_type& Rt, + const math::GemmHelper& gh) { math::blas::integer Mo = 0, No = 0, Ko = 0; gh.compute_matrix_sizes(Mo, No, Ko, Lt.range(), Rt.range()); @@ -975,7 +989,8 @@ class ContEngine : public BinaryEngine { TiledArray::detail::strided_dgemm_log( left_has_ext ? "hce+e REVERTED -> by-cell (inner result perm)" - : "hc+e REVERTED -> by-cell (inner result perm)"); + : "hc+e REVERTED -> by-cell (inner result " + "perm)"); } else if (!inner_canonical) { // ce+ce candidate blocked by a non-canonical inner perm. // Break down WHICH operand/result perm is non-identity diff --git a/src/TiledArray/expressions/mult_engine.h b/src/TiledArray/expressions/mult_engine.h index f0942ee48d..6be2f6fb1d 100644 --- a/src/TiledArray/expressions/mult_engine.h +++ b/src/TiledArray/expressions/mult_engine.h @@ -260,7 +260,19 @@ class MultEngine : public ContEngine> { void perm_indices(const BipartiteIndexList& target_indices) { if (this->product_type() == TensorProduct::Contraction) ContEngine_::perm_indices(target_indices); - else { + else if (this->product_type() == TensorProduct::General) { + // mirror ContEngine_::perm_indices, but lay out via + // GeneralPermutationOptimizer (the target determines which shared + // indices are fused vs contracted) + if (!this->implicit_permute()) { + BinaryEngine_::template init_indices_( + target_indices); + if (BinaryEngine_::left_indices_ != BinaryEngine_::left_.indices()) + BinaryEngine_::left_.perm_indices(BinaryEngine_::left_indices_); + if (BinaryEngine_::right_indices_ != BinaryEngine_::right_.indices()) + BinaryEngine_::right_.perm_indices(BinaryEngine_::right_indices_); + } + } else { BinaryEngine_::perm_indices(target_indices); } } @@ -275,6 +287,40 @@ class MultEngine : public ContEngine> { // for the left and right, hence do target-neutral initialization BinaryEngine_::left_.init_indices(); BinaryEngine_::right_.init_indices(); + + // Validate that the (bottom-up resolved) child indices are consistent + // with the target: every outer index of each child must appear in the + // other child or in the target (as a free, fused, or contracted index). + // A violation usually means a *general* product (fused + contracted + + // free indices) appears at an INNER node of the expression tree, where + // the role of a shared index cannot be deduced bottom-up; e.g. in the + // THC-like g("p,q,r,s") = X("p,r1") * X("q,r1") * Z("r1,r2") * ... the + // index r1 is fused in X*X but contracted downstream, while the + // bottom-up convention contracts it in X*X, orphaning the r1 of Z. + // Resolving this requires pushing the needed-index sets down the + // expression tree; until then materialize such inner products into + // explicit intermediates, so that every general product appears as the + // root of its own assignment (where the target determines the index + // roles). + { + auto const& left_outer = outer(BinaryEngine_::left_.indices()); + auto const& right_outer = outer(BinaryEngine_::right_.indices()); + auto const& target_outer = outer(target_indices); + auto validate = [&](const IndexList& a, const IndexList& b) { + for (auto&& idx : a) + if (!b.count(idx) && !target_outer.count(idx)) + TA_EXCEPTION( + "MultEngine: an argument index appears in neither the other " + "argument nor the target. If a general product (fused + " + "contracted + free indices) appears at an inner node of the " + "expression tree, its index roles cannot be deduced " + "bottom-up; materialize it into an explicit intermediate so " + "that it appears as the root of its own assignment"); + }; + validate(left_outer, right_outer); + validate(right_outer, left_outer); + } + this->product_type_ = compute_product_type( outer(BinaryEngine_::left_.indices()), outer(BinaryEngine_::right_.indices()), outer(target_indices)); @@ -282,24 +328,27 @@ class MultEngine : public ContEngine> { inner(BinaryEngine_::left_.indices()), inner(BinaryEngine_::right_.indices()), inner(target_indices)); - // TODO support general products that involve fused, contracted, and free - // indices Example: in ijk * jkl -> ijl indices i and l are free, index k is - // contracted, and index j is fused - // N.B. Currently only 2 types of products are supported: - // - Hadamard product (in which all indices are fused), and, - // - pure contraction (>=1 contracted, 0 fused, >=1 free indices) - // For the ToT arguments only the Hadamard product is supported + if (this->inner_product_type_ == TensorProduct::General) + TA_EXCEPTION( + "MultEngine: general products (fused + contracted + free indices) " + "between the inner (nested) indices of tensors-of-tensors are not " + "supported"); // Check the *outer* indices to determine whether the arguments are - // - contracted, or - // - Hadamard-multiplied - // The latter is indicated by the equality (modulo permutation) of - // the outer left and right arg indices to the target indices. + // - Hadamard-multiplied (outer left, right, and target indices are all + // related by permutations), + // - contracted (no index appears in left, right, AND target), or + // - generally multiplied (fused, contracted, and free indices coexist, + // e.g. "b,i,j" * "b,j,k" -> "b,i,k"); the layout then depends on the + // target, which determines the role (fused vs contracted) of every + // shared index. // Only the outer indices matter here since the inner indices only encode // the tile op; the type of the tile op does not need to match the type of // the operation on the outer indices if (this->product_type() == TensorProduct::Hadamard) { BinaryEngine_::perm_indices(target_indices); + } else if (this->product_type() == TensorProduct::General) { + this->perm_indices(target_indices); } else { auto children_initialized = true; ContEngine_::init_indices(children_initialized); @@ -334,6 +383,16 @@ class MultEngine : public ContEngine> { /// for the result tensor. /// \param target_indices The target index list for the result tensor void init_struct(const BipartiteIndexList& target_indices) { + // TODO Phase B (batched Summa): evaluate general products natively. + // Until then this engine can classify and lay out a general product + // (see init_indices) but not evaluate it. + if (this->product_type() == TensorProduct::General) + TA_EXCEPTION( + "MultEngine: evaluation of general products (fused + contracted + " + "free indices, e.g. C(\"b,i,k\") = A(\"b,i,j\") * B(\"b,j,k\")) via " + "the expression layer is not yet implemented; use " + "TiledArray::einsum() instead"); + this->init_perm(target_indices); // for ContEngine_::init_struct need to initialize element op first @@ -576,7 +635,16 @@ class ScalMultEngine void perm_indices(const BipartiteIndexList& target_indices) { if (this->product_type() == TensorProduct::Contraction) ContEngine_::perm_indices(target_indices); - else { + else if (this->product_type() == TensorProduct::General) { + if (!this->implicit_permute()) { + BinaryEngine_::template init_indices_( + target_indices); + if (BinaryEngine_::left_indices_ != BinaryEngine_::left_.indices()) + BinaryEngine_::left_.perm_indices(BinaryEngine_::left_indices_); + if (BinaryEngine_::right_indices_ != BinaryEngine_::right_.indices()) + BinaryEngine_::right_.perm_indices(BinaryEngine_::right_indices_); + } + } else { BinaryEngine_::perm_indices(target_indices); } } @@ -595,6 +663,17 @@ class ScalMultEngine // since already initialized left and right arg indices assign the target // indices BinaryEngine_::perm_indices(target_indices); + } else if (this->product_type() == TensorProduct::General) { + // layout via GeneralPermutationOptimizer (the target determines which + // shared indices are fused vs contracted), then propagate to children + if (!this->implicit_permute()) { + BinaryEngine_::template init_indices_( + target_indices); + if (BinaryEngine_::left_indices_ != BinaryEngine_::left_.indices()) + BinaryEngine_::left_.perm_indices(BinaryEngine_::left_indices_); + if (BinaryEngine_::right_indices_ != BinaryEngine_::right_.indices()) + BinaryEngine_::right_.perm_indices(BinaryEngine_::right_indices_); + } } else { ContEngine_::init_indices(target_indices); } @@ -628,6 +707,13 @@ class ScalMultEngine /// for the result tensor. /// \param target_indices The target index list for the result tensor void init_struct(const BipartiteIndexList& target_indices) { + // TODO Phase B (batched Summa): evaluate general products natively + if (this->product_type() == TensorProduct::General) + TA_EXCEPTION( + "ScalMultEngine: evaluation of general products (fused + contracted " + "+ free indices) via the expression layer is not yet implemented; " + "use TiledArray::einsum() instead"); + this->init_perm(target_indices); // for ContEngine_::init_struct need to initialize element op first diff --git a/src/TiledArray/expressions/permopt.h b/src/TiledArray/expressions/permopt.h index 291604faa8..6f2371b3f0 100644 --- a/src/TiledArray/expressions/permopt.h +++ b/src/TiledArray/expressions/permopt.h @@ -588,6 +588,159 @@ class ScalePermutationOptimizer : public BinaryOpPermutationOptimizer { static IndexList null_indices_; }; +// clang-format off +/// Given target, left, and right index lists of a general product (fused + +/// free + contracted indices coexisting, TensorProduct::General; e.g. +/// "b,i,j" * "b,j,k" -> "b,i,k") computes canonical index lists for the +/// arguments and the result: +/// - left = (fused..., left-free..., contracted...) +/// - right = (fused..., contracted..., right-free...) +/// - result = (fused..., left-free..., right-free...) +/// i.e. the GEMM-canonical layout of GEMMPermutationOptimizer with the fused +/// (Hadamard) indices prepended to every operand, so a consuming batched-GEMM +/// op can fold the fused modes into the tile batch dimension with a zero-copy +/// reshape. The relative order of fused and free indices is taken from the +/// target, minimizing the final result permutation; contracted indices keep +/// the left argument's relative order. +/// +/// Unlike the pure-contraction case, the role of a shared index (fused vs +/// contracted) is determined by the target, so this optimizer cannot be +/// constructed without target indices. +/// +/// \note Argument permutations are not fused into GEMM transposes (permtype +/// is identity or general, never matrix_transpose): the canonical layout +/// requires the fused modes to lead, which a GEMM transpose flag cannot +/// express. Layout-fusion optimizations can be added later. +// clang-format on +class GeneralPermutationOptimizer : public BinaryOpPermutationOptimizer { + public: + GeneralPermutationOptimizer(const GeneralPermutationOptimizer&) = default; + GeneralPermutationOptimizer& operator=(const GeneralPermutationOptimizer&) = + default; + ~GeneralPermutationOptimizer() = default; + + /// The role of a shared index (fused vs contracted) is defined by the + /// target, so a general product cannot be laid out bottom-up. + GeneralPermutationOptimizer(const IndexList& left_indices, + const IndexList& right_indices, + const bool prefer_to_permute_left = true) + : BinaryOpPermutationOptimizer(left_indices, right_indices, + prefer_to_permute_left) { + TA_EXCEPTION( + "GeneralPermutationOptimizer requires target indices: the role of an " + "index shared by both arguments (fused vs contracted) cannot be " + "determined bottom-up"); + } + + GeneralPermutationOptimizer(const IndexList& result_indices, + const IndexList& left_indices, + const IndexList& right_indices, + const bool prefer_to_permute_left = true) + : BinaryOpPermutationOptimizer(result_indices, left_indices, + right_indices, prefer_to_permute_left) { + container::svector fused, contracted, ext_left, ext_right; + + // classify target indices; relative order within each class follows the + // target + for (auto&& idx : result_indices) { + const bool in_left = left_indices.count(idx); + const bool in_right = right_indices.count(idx); + if (in_left && in_right) + fused.push_back(idx); + else if (in_left) + ext_left.push_back(idx); + else if (in_right) + ext_right.push_back(idx); + else + TA_EXCEPTION( + "GeneralPermutationOptimizer: target index does not appear in " + "either argument"); + } + + // contracted = shared indices absent from the target; left's relative + // order. Also validate that no argument index is silently dropped (an + // implicit trace/reduction is not supported). + for (auto&& idx : left_indices) { + const bool in_right = right_indices.count(idx); + const bool in_target = result_indices.count(idx); + if (in_right && !in_target) contracted.push_back(idx); + if (!in_right && !in_target) + TA_EXCEPTION( + "GeneralPermutationOptimizer: left index appears in neither the " + "right argument nor the target (implicit reduction not " + "supported)"); + } + for (auto&& idx : right_indices) { + if (!left_indices.count(idx) && !result_indices.count(idx)) + TA_EXCEPTION( + "GeneralPermutationOptimizer: right index appears in neither the " + "left argument nor the target (implicit reduction not supported)"); + } + + fused_indices_ = IndexList(fused); + contracted_indices_ = IndexList(contracted); + left_external_indices_ = IndexList(ext_left); + right_external_indices_ = IndexList(ext_right); + + auto concat = + [](std::initializer_list*> + parts) { + container::svector v; + for (auto* p : parts) v.insert(v.end(), p->begin(), p->end()); + return IndexList(v); + }; + target_left_indices_ = concat({&fused, &ext_left, &contracted}); + target_right_indices_ = concat({&fused, &contracted, &ext_right}); + target_result_indices_ = concat({&fused, &ext_left, &ext_right}); + + left_permtype_ = (target_left_indices_ == left_indices) + ? PermutationType::identity + : PermutationType::general; + right_permtype_ = (target_right_indices_ == right_indices) + ? PermutationType::identity + : PermutationType::general; + } + + const IndexList& target_left_indices() const override final { + return target_left_indices_; + } + const IndexList& target_right_indices() const override final { + return target_right_indices_; + } + const IndexList& target_result_indices() const override final { + return target_result_indices_; + } + PermutationType left_permtype() const override final { + return left_permtype_; + } + PermutationType right_permtype() const override final { + return right_permtype_; + } + TensorProduct op_type() const override final { + return TensorProduct::General; + } + + /// \return the fused (Hadamard) indices, in target order + const IndexList& fused_indices() const { return fused_indices_; } + /// \return the contracted indices, in left-argument order + const IndexList& contracted_indices() const { return contracted_indices_; } + /// \return the left-argument free indices, in target order + const IndexList& left_external_indices() const { + return left_external_indices_; + } + /// \return the right-argument free indices, in target order + const IndexList& right_external_indices() const { + return right_external_indices_; + } + + private: + IndexList target_left_indices_, target_right_indices_, target_result_indices_; + IndexList fused_indices_, contracted_indices_, left_external_indices_, + right_external_indices_; + PermutationType left_permtype_ = PermutationType::general, + right_permtype_ = PermutationType::general; +}; + class NullBinaryOpPermutationOptimizer : public BinaryOpPermutationOptimizer { public: NullBinaryOpPermutationOptimizer(const NullBinaryOpPermutationOptimizer&) = @@ -646,6 +799,12 @@ inline std::shared_ptr make_permutation_optimizer( case TensorProduct::Contraction: return std::make_shared( left_indices, right_indices, prefer_to_permute_left); + case TensorProduct::General: + // a general product's layout depends on the target indices: the role of + // a shared index (fused vs contracted) cannot be determined bottom-up + TA_EXCEPTION( + "make_permutation_optimizer: a TensorProduct::General product " + "requires target indices (use the target-taking overload)"); case TensorProduct::Invalid: return std::make_shared( left_indices, right_indices, prefer_to_permute_left); @@ -668,6 +827,9 @@ inline std::shared_ptr make_permutation_optimizer( case TensorProduct::Contraction: return std::make_shared( target_indices, left_indices, right_indices, prefer_to_permute_left); + case TensorProduct::General: + return std::make_shared( + target_indices, left_indices, right_indices, prefer_to_permute_left); case TensorProduct::Invalid: return std::make_shared( target_indices, left_indices, right_indices, prefer_to_permute_left); diff --git a/src/TiledArray/expressions/product.h b/src/TiledArray/expressions/product.h index df2867a360..5e23de46f4 100644 --- a/src/TiledArray/expressions/product.h +++ b/src/TiledArray/expressions/product.h @@ -68,14 +68,40 @@ inline TensorProduct compute_product_type(const IndexList& left_indices, } /// computes the tensor product type corresponding to the left and right -/// argument indices, and validates against the target indices +/// argument indices, given the target indices +/// +/// Unlike the 2-argument overload, this can detect TensorProduct::General: +/// the target determines the role of each index, so an index shared by both +/// arguments is *fused* (Hadamard) if it survives into the target and +/// *contracted* if it does not. The 2-argument overload, lacking a target, +/// follows the bottom-up convention that every shared index is contracted. +/// \return +/// - TensorProduct::Hadamard if left, right, and target are all related by +/// permutations (fused indices only), +/// - TensorProduct::General if at least one shared index is fused (appears +/// in left, right, AND target) alongside contracted and/or free indices, +/// - TensorProduct::Contraction if no shared index is fused, +/// - else as the 2-argument overload. inline TensorProduct compute_product_type(const IndexList& left_indices, const IndexList& right_indices, const IndexList& target_indices) { auto result = compute_product_type(left_indices, right_indices); if (result == TensorProduct::Hadamard) { - TA_ASSERT(left_indices.is_permutation(target_indices)); - TA_ASSERT(right_indices.is_permutation(target_indices)); + // left ≅ right; pure Hadamard requires the target to keep every index. + // A target that omits some shared indices implies they are contracted + // (a Hadamard-reduction, e.g. "i,j" * "i,j" -> "i"): fused + contracted + // coexist => General. + if (!left_indices.is_permutation(target_indices)) + result = TensorProduct::General; + } else if (result == TensorProduct::Contraction) { + // an index of the target that appears in both arguments is fused, not + // contracted: fused + (free and/or contracted) => General. + for (auto&& idx : target_indices) { + if (left_indices.count(idx) && right_indices.count(idx)) { + result = TensorProduct::General; + break; + } + } } return result; } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index b217e22deb..eb0726ee6b 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -100,6 +100,7 @@ set(ta_test_src_files ta_test.cpp # t_tot_tot_contract_.cpp # tot_tot_tot_contract_.cpp einsum.cpp + general_product.cpp linalg.cpp cp.cpp btas.cpp diff --git a/tests/general_product.cpp b/tests/general_product.cpp new file mode 100644 index 0000000000..5d927843cd --- /dev/null +++ b/tests/general_product.cpp @@ -0,0 +1,170 @@ +/// Unit tests for general-product (fused + contracted + free indices) +/// classification and layout in the expression layer. + +#include "TiledArray/expressions/permopt.h" +#include "TiledArray/expressions/product.h" + +#include "tiledarray.h" +#include "unit_test_config.h" + +BOOST_AUTO_TEST_SUITE(general_product_suite, TA_UT_LABEL_SERIAL) + +namespace TA = TiledArray; +using TA::expressions::compute_product_type; +using TA::expressions::GeneralPermutationOptimizer; +using TA::expressions::IndexList; +using TA::expressions::PermutationType; +using TA::expressions::TensorProduct; + +BOOST_AUTO_TEST_CASE(classification_2arg_unchanged) { + // bottom-up (no target): shared indices are assumed contracted, General is + // never produced + BOOST_CHECK(compute_product_type(IndexList("i,j"), IndexList("j,i")) == + TensorProduct::Hadamard); + BOOST_CHECK(compute_product_type(IndexList("i,j"), IndexList("j,k")) == + TensorProduct::Contraction); + BOOST_CHECK(compute_product_type(IndexList("b,i,j"), IndexList("b,j,k")) == + TensorProduct::Contraction); + BOOST_CHECK(compute_product_type(IndexList{}, IndexList("i,j")) == + TensorProduct::Scale); +} + +BOOST_AUTO_TEST_CASE(classification_3arg) { + // pure Hadamard: all three related by permutation + BOOST_CHECK(compute_product_type(IndexList("i,j"), IndexList("j,i"), + IndexList("i,j")) == + TensorProduct::Hadamard); + // pure contraction: no index in all three + BOOST_CHECK(compute_product_type(IndexList("i,j"), IndexList("j,k"), + IndexList("i,k")) == + TensorProduct::Contraction); + // batched contraction: b fused, j contracted, i/k free + BOOST_CHECK(compute_product_type(IndexList("b,i,j"), IndexList("b,j,k"), + IndexList("b,i,k")) == + TensorProduct::General); + // the motivating TODO example: ijk * jkl -> ijl (j fused, k contracted) + BOOST_CHECK(compute_product_type(IndexList("i,j,k"), IndexList("j,k,l"), + IndexList("i,j,l")) == + TensorProduct::General); + // Hadamard-reduction: args are permutations of each other but the target + // drops an index => fused + contracted + BOOST_CHECK(compute_product_type(IndexList("i,j"), IndexList("i,j"), + IndexList("i")) == TensorProduct::General); + // batched outer product: b fused, no contracted, i/k free + BOOST_CHECK(compute_product_type(IndexList("b,i"), IndexList("b,k"), + IndexList("b,i,k")) == + TensorProduct::General); +} + +BOOST_AUTO_TEST_CASE(optimizer_canonical_layout) { + // C("b,i,k") = A("b,i,j") * B("b,j,k") + GeneralPermutationOptimizer opt(IndexList("b,i,k"), IndexList("b,i,j"), + IndexList("b,j,k")); + BOOST_CHECK(opt.op_type() == TensorProduct::General); + BOOST_CHECK_EQUAL(opt.fused_indices().string(), "b"); + BOOST_CHECK_EQUAL(opt.contracted_indices().string(), "j"); + BOOST_CHECK_EQUAL(opt.left_external_indices().string(), "i"); + BOOST_CHECK_EQUAL(opt.right_external_indices().string(), "k"); + // canonical layouts: left (h, eA, c), right (h, c, eB), result (h, eA, eB) + BOOST_CHECK_EQUAL(opt.target_left_indices().string(), "b,i,j"); + BOOST_CHECK_EQUAL(opt.target_right_indices().string(), "b,j,k"); + BOOST_CHECK_EQUAL(opt.target_result_indices().string(), "b,i,k"); + // both args already canonical + BOOST_CHECK(opt.left_permtype() == PermutationType::identity); + BOOST_CHECK(opt.right_permtype() == PermutationType::identity); +} + +BOOST_AUTO_TEST_CASE(optimizer_noncanonical_args) { + // C("k,b,i") = A("i,j,b") * B("k,j,b"): b fused, j contracted, i/k free; + // neither argument is in canonical layout, so both get general permtypes + GeneralPermutationOptimizer opt(IndexList("k,b,i"), IndexList("i,j,b"), + IndexList("k,j,b")); + BOOST_CHECK_EQUAL(opt.fused_indices().string(), "b"); + BOOST_CHECK_EQUAL(opt.contracted_indices().string(), "j"); + BOOST_CHECK_EQUAL(opt.left_external_indices().string(), "i"); + BOOST_CHECK_EQUAL(opt.right_external_indices().string(), "k"); + BOOST_CHECK_EQUAL(opt.target_left_indices().string(), "b,i,j"); + BOOST_CHECK_EQUAL(opt.target_right_indices().string(), "b,j,k"); + BOOST_CHECK_EQUAL(opt.target_result_indices().string(), "b,i,k"); + BOOST_CHECK(opt.left_permtype() == PermutationType::general); + BOOST_CHECK(opt.right_permtype() == PermutationType::general); +} + +BOOST_AUTO_TEST_CASE(optimizer_multiple_fused_target_order) { + // two fused indices; class order follows the target + GeneralPermutationOptimizer opt(IndexList("c,b,i,k"), IndexList("b,c,i,j"), + IndexList("c,j,b,k")); + BOOST_CHECK_EQUAL(opt.fused_indices().string(), "c,b"); + BOOST_CHECK_EQUAL(opt.target_left_indices().string(), "c,b,i,j"); + BOOST_CHECK_EQUAL(opt.target_right_indices().string(), "c,b,j,k"); + BOOST_CHECK_EQUAL(opt.target_result_indices().string(), "c,b,i,k"); +} + +BOOST_AUTO_TEST_CASE(optimizer_hadamard_reduction) { + // "i,j" * "i,j" -> "i": i fused, j contracted, no externals + GeneralPermutationOptimizer opt(IndexList("i"), IndexList("i,j"), + IndexList("i,j")); + BOOST_CHECK_EQUAL(opt.fused_indices().string(), "i"); + BOOST_CHECK_EQUAL(opt.contracted_indices().string(), "j"); + BOOST_CHECK(!opt.left_external_indices()); + BOOST_CHECK(!opt.right_external_indices()); + BOOST_CHECK_EQUAL(opt.target_left_indices().string(), "i,j"); + BOOST_CHECK_EQUAL(opt.target_right_indices().string(), "i,j"); + BOOST_CHECK_EQUAL(opt.target_result_indices().string(), "i"); +} + +BOOST_AUTO_TEST_CASE(optimizer_requires_target) { + // bottom-up construction must throw: fused-vs-contracted undecidable + BOOST_CHECK_THROW( + GeneralPermutationOptimizer(IndexList("b,i,j"), IndexList("b,j,k")), + TiledArray::Exception); +} + +BOOST_AUTO_TEST_CASE(optimizer_rejects_implicit_reduction) { + // left index "j" appears in neither right nor target + BOOST_CHECK_THROW(GeneralPermutationOptimizer( + IndexList("b,i"), IndexList("b,i,j"), IndexList("b,i")), + TiledArray::Exception); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_gated) { + // the expression layer now *classifies* general products correctly and + // reports that evaluation is not yet implemented (instead of misrouting + // to the contraction or Hadamard machinery) + auto& world = TA::get_default_world(); + TA::TiledRange tr{{0, 2, 4}, {0, 2, 4}, {0, 2, 4}}; + TA::TArrayD a(world, tr); + TA::TArrayD b(world, tr); + a.fill(1.0); + b.fill(1.0); + TA::TArrayD c; + BOOST_CHECK_THROW(c("b,i,k") = a("b,i,j") * b("b,j,k"), + TiledArray::Exception); + // pure contraction and pure Hadamard still work + BOOST_CHECK_NO_THROW(c("i,k") = a("b,i,j") * b("b,j,k")); + BOOST_CHECK_NO_THROW(c("b,i,j") = a("b,i,j") * b("b,i,j")); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_inner_node_gated) { + // THC-style reconstruction: + // g("p,q,r,s") = X("p,r1") * X("q,r1") * Z("r1,r2") * X("r,r2") * X("s,r2") + // r1 is fused in X("p,r1") * X("q,r1") but contracted downstream. The + // first product is an INNER node of the expression tree, where the role of + // r1 cannot be deduced bottom-up (the target reaches only the root); + // resolving this requires top-down index-set deduction (deferred). Until + // then: an informative error, not garbage (bottom-up, X*X would contract + // r1, orphaning the r1 of Z). + auto& world = TA::get_default_world(); + TA::TiledRange tr_x{{0, 2, 4}, {0, 3, 5}}; // orbital x auxiliary + TA::TiledRange tr_z{{0, 3, 5}, {0, 3, 5}}; // auxiliary x auxiliary + TA::TArrayD x(world, tr_x); + TA::TArrayD z(world, tr_z); + x.fill(1.0); + z.fill(1.0); + TA::TArrayD g; + BOOST_CHECK_THROW( + g("p,q,r,s") = x("p,r1") * x("q,r1") * z("r1,r2") * x("r,r2") * x("s,r2"), + TiledArray::Exception); +} + +BOOST_AUTO_TEST_SUITE_END() From 6e72dff355fcd86ea87b1a3ae2d33674833059c1 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Thu, 11 Jun 2026 02:34:41 -0400 Subject: [PATCH 03/25] dist_eval: generalize Summa to batched (fused-index) contractions Phase B step 1 of general-product support: Summa gains an optional slab count nh (default 1 = exactly the prior, unbatched behavior). For nh > 1 the operands and result carry the fused (Hadamard) modes as leading dimensions (left = (h,i,k), right = (h,k,j), result = (h,i,j)); the contraction runs as nh independent SUMMA slabs over ONE shared 2-d process grid and ONE task graph: - iteration space becomes steps s = h*k_ + k; the step-task chain, depth control, and sparse step iteration (iterate_{row,col,sparse}, skipped-range broadcasts) operate in step space - every argument/result tile ordinal is offset by its slab base; the owner of a tile is independent of h (block-cyclic phase restarts per slab), so broadcast roots and the 2-d grid logic are unchanged - per-step sparse broadcast groups are keyed by step (col: s, row: s + nsteps); the static dense groups use keys 2*nsteps, 2*nsteps+1; tile broadcast keys (global ordinals) are unique across slabs as-is - reduce tasks: one per local result tile per slab (reduce_tasks_[h*local_size + i*local_cols + j]); initialize/finalize loop slab-by-slab - sparse row/col masks take the slab index; get_tile owner computation mods out the slab No caller passes nh yet (that lands with the General-product ContEngine wiring); all existing suites pass unchanged. --- src/TiledArray/dist_eval/contraction_eval.h | 471 ++++++++++++-------- 1 file changed, 286 insertions(+), 185 deletions(-) diff --git a/src/TiledArray/dist_eval/contraction_eval.h b/src/TiledArray/dist_eval/contraction_eval.h index a747c0748b..e2b9c8c731 100644 --- a/src/TiledArray/dist_eval/contraction_eval.h +++ b/src/TiledArray/dist_eval/contraction_eval.h @@ -93,15 +93,33 @@ class Summa const ordinal_type k_; ///< Number of tiles in the inner dimension const ProcGrid proc_grid_; ///< Process grid for this contraction + // Batched (fused/Hadamard-index) dimension information. A batched + // contraction C(h,i,j) = sum_k A(h,i,k) B(h,k,j) is evaluated as nh_ + // independent SUMMA "slabs" sharing one process grid and one task graph: + // the iteration space is steps s = h*k_ + k, and every tile ordinal is + // offset by its slab base (h * {left,right,result}_slab_size_). For the + // ordinary contraction nh_ == 1 and all of this reduces to the unbatched + // arithmetic. + const ordinal_type nh_; ///< Number of fused (Hadamard) slabs + const ordinal_type nsteps_; ///< Total SUMMA steps = nh_ * k_ + const ordinal_type + left_slab_size_; ///< # of left tiles per slab (= rows*k tiles) + const ordinal_type + right_slab_size_; ///< # of right tiles per slab (= k*cols tiles) + const ordinal_type + result_slab_size_; ///< # of result tiles per slab (= rows*cols tiles) + // Contraction results ReducePairTask* reduce_tasks_; ///< A pointer to the reduction tasks // Constants used to iterate over columns and rows of left_ and right_, - // respectively. + // respectively. N.B. all are *slab-local* (slab h adds + // h * {left,right}_slab_size_ at the use sites). const ordinal_type left_start_local_; ///< The starting point of left column iterator ranges ///< (just add k for specific columns) const ordinal_type left_end_; ///< The end of the left column iterator ranges + ///< within a slab const ordinal_type left_stride_; ///< Stride for left column iterators const ordinal_type left_stride_local_; ///< Stride for local left column iterators @@ -109,6 +127,11 @@ class Summa const ordinal_type right_stride_local_; ///< stride for local right row iterators + /// \return the slab index of SUMMA step \p s + ordinal_type step_h(const ordinal_type s) const { return s / k_; } + /// \return the within-slab inner-dimension index of SUMMA step \p s + ordinal_type step_k(const ordinal_type s) const { return s % k_; } + typedef Future right_future; ///< Future to a right-hand argument tile typedef Future @@ -253,21 +276,26 @@ class Summa /// Row process group factory function - /// \param k The broadcast group index + /// \param s The SUMMA step (= slab index * k_ + broadcast group index) /// \return A row process group - madness::Group make_row_group(const ordinal_type k) const { + madness::Group make_row_group(const ordinal_type s) const { + const ordinal_type h = step_h(s); + const ordinal_type k = step_k(s); // Construct the sparse broadcast group - const ordinal_type right_begin_k = k * proc_grid_.cols(); + const ordinal_type right_begin_k = + h * right_slab_size_ + k * proc_grid_.cols(); const ordinal_type right_end_k = right_begin_k + proc_grid_.cols(); // make the row mask; using the same mask for all tiles avoids having to // compute mask for every tile and use of masked broadcasts - auto result_row_mask_k = make_row_mask(k); + auto result_row_mask_k = make_row_mask(h, k); // return empty group if I am not in this group, otherwise make a group + // N.B. group key = s + nsteps_ (unique across (h,k) and distinct from the + // column groups' keys), root flag = k (the within-slab cyclic owner) if (result_row_mask_k[proc_grid_.rank_col()]) return make_group(right_.shape(), result_row_mask_k, right_begin_k, right_end_k, right_stride_, proc_grid_.proc_cols(), k, - k_, [&](const ProcGrid::size_type col) { + s - k + nsteps_, [&](const ProcGrid::size_type col) { return proc_grid_.map_col(col); }); else @@ -276,18 +304,22 @@ class Summa /// Column process group factory function - /// \param k The broadcast group index + /// \param s The SUMMA step (= slab index * k_ + broadcast group index) /// \return A column process group - madness::Group make_col_group(const ordinal_type k) const { + madness::Group make_col_group(const ordinal_type s) const { + const ordinal_type h = step_h(s); + const ordinal_type k = step_k(s); // make the column mask; using the same mask for all tiles avoids having to // compute mask for every tile and use of masked broadcasts - auto result_col_mask_k = make_col_mask(k); + auto result_col_mask_k = make_col_mask(h, k); // return empty group if I am not in this group, otherwise make a group + // N.B. group key = s (unique across (h,k)), root flag = k if (result_col_mask_k[proc_grid_.rank_row()]) return make_group( - left_.shape(), result_col_mask_k, k, left_end_, left_stride_, - proc_grid_.proc_rows(), k, 0ul, + left_.shape(), result_col_mask_k, h * left_slab_size_ + k, + h * left_slab_size_ + left_end_, left_stride_, proc_grid_.proc_rows(), + k, s - k, [&](const ordinal_type row) { return proc_grid_.map_row(row); }); else return madness::Group(); @@ -299,7 +331,8 @@ class Summa /// \return a set object, if \code result[p] == true \endcode the process /// in column \c p of this row has at least 1 result tile for this \c /// k - std::vector make_row_mask(const ordinal_type k) const { + std::vector make_row_mask(const ordinal_type h, + const ordinal_type k) const { // "local" A[i][k] (i.e. for all i assigned to my row of processes) will // produce C[i][*] for each process in my row of the process grid determine // whether there are any nonzero C[i][*] located on that node @@ -320,13 +353,16 @@ class Summa const auto nj = proc_grid_.cols(); // number of tiles in contraction dim const auto nk = k_; + // slab bases + const auto left_base = h * left_slab_size_; + const auto result_base = h * result_slab_size_; // for each i assigned to my row of processes ... ordinal_type i_start, i_fence, i_stride; std::tie(i_start, i_fence, i_stride) = result_row_range(my_proc_row); const auto ik_stride = i_stride * nk; - for (ordinal_type i = i_start, ik = i_start * nk + k; i < i_fence; - i += i_stride, ik += ik_stride) { + for (ordinal_type i = i_start, ik = left_base + i_start * nk + k; + i < i_fence; i += i_stride, ik += ik_stride) { // ... such that A[i][k] exists ... if (!left_.shape().is_zero(ik)) { // ... the owner of А[i][k] is always in the group ... @@ -340,8 +376,8 @@ class Summa ordinal_type j_start, j_fence, j_stride; std::tie(j_start, j_fence, j_stride) = result_col_range(proc_col); const auto ij_stride = j_stride; - for (ordinal_type j = j_start, ij = i * nj + j_start; j < j_fence; - j += j_stride, ij += ij_stride) { + for (ordinal_type j = j_start, ij = result_base + i * nj + j_start; + j < j_fence; j += j_stride, ij += ij_stride) { // ... if any such C[i][j] exists, update the mask, and move // on to next process if (!result_shape.is_zero( @@ -364,7 +400,8 @@ class Summa /// \return a set object, if \code result[p] == true \endcode the process /// in row \c p of this column has at least 1 result tile for this /// \c k - std::vector make_col_mask(const ordinal_type k) const { + std::vector make_col_mask(const ordinal_type h, + const ordinal_type k) const { // "local" B[k][j] (i.e. for all j assigned to my column of processes) // will produce C[*][j] // for each process in my column of the process grid determine whether @@ -385,13 +422,16 @@ class Summa // number of tiles in col dim of the result const auto nj = proc_grid_.cols(); + // slab bases + const auto right_base = h * right_slab_size_; + const auto result_base = h * result_slab_size_; // for each j assigned to my column of processes ... ordinal_type j_start, j_fence, j_stride; std::tie(j_start, j_fence, j_stride) = result_col_range(my_proc_col); const auto kj_stride = j_stride; - for (ordinal_type j = j_start, kj = k * nj + j_start; j < j_fence; - j += j_stride, kj += kj_stride) { + for (ordinal_type j = j_start, kj = right_base + k * nj + j_start; + j < j_fence; j += j_stride, kj += kj_stride) { // ... such that B[k][j] exists ... if (!right_.shape().is_zero(kj)) { // ... the owner of B[k][j] is always in the group ... @@ -405,8 +445,8 @@ class Summa ordinal_type i_start, i_fence, i_stride; std::tie(i_start, i_fence, i_stride) = result_row_range(proc_row); const auto ij_stride = i_stride * nj; - for (ordinal_type i = i_start, ij = i_start * nj + j; i < i_fence; - i += i_stride, ij += ij_stride) { + for (ordinal_type i = i_start, ij = result_base + i_start * nj + j; + i < i_fence; i += i_stride, ij += ij_stride) { // ... if any such C[i][j] exists, update the mask, and move // on to next process if (!result_shape.is_zero( @@ -554,25 +594,27 @@ class Summa TA_ASSERT(vec.size() > 0ul); } - /// Collect non-zero tiles from column \c k of \c left_ + /// Collect non-zero tiles from column \c k of slab \c h of \c left_ - /// \param[in] k The column to be retrieved + /// \param[in] s The SUMMA step (slab * k_ + column index) /// \param[out] col The column vector that will hold the tiles - void get_col(const ordinal_type k, std::vector& col) const { + void get_col(const ordinal_type s, std::vector& col) const { + const ordinal_type base = step_h(s) * left_slab_size_; col.reserve(proc_grid_.local_rows()); - get_vector(left_, left_start_local_ + k, left_end_, left_stride_local_, - col); + get_vector(left_, base + left_start_local_ + step_k(s), base + left_end_, + left_stride_local_, col); } - /// Collect non-zero tiles from row \c k of \c right_ + /// Collect non-zero tiles from row \c k of slab \c h of \c right_ - /// \param[in] k The row to be retrieved + /// \param[in] s The SUMMA step (slab * k_ + row index) /// \param[out] row The row vector that will hold the tiles - void get_row(const ordinal_type k, std::vector& row) const { + void get_row(const ordinal_type s, std::vector& row) const { row.reserve(proc_grid_.local_cols()); - // Compute local iteration limits for row k of right_. - ordinal_type begin = k * proc_grid_.cols(); + // Compute local iteration limits for row k of slab h of right_. + ordinal_type begin = + step_h(s) * right_slab_size_ + step_k(s) * proc_grid_.cols(); const ordinal_type end = begin + proc_grid_.cols(); begin += proc_grid_.rank_col(); @@ -654,46 +696,53 @@ class Summa return group_root; } - /// Broadcast column \c k of \c left_ with a dense right-hand argument + /// Broadcast column \c k of slab \c h of \c left_ - /// \param[in] k The column of \c left_ to be broadcast + /// \param[in] s The SUMMA step (slab * k_ + column index) /// \param[out] col The vector that will hold the results of the broadcast - void bcast_col(const ordinal_type k, std::vector& col, + void bcast_col(const ordinal_type s, std::vector& col, const madness::Group& row_group) const { // broadcast if I'm part of the broadcast group if (!row_group.empty()) { - // Broadcast column k of left_. - ProcessID group_root = get_row_group_root(k, row_group); - bcast(left_start_local_ + k, left_stride_local_, row_group, group_root, - 0ul, col); + // Broadcast column k of slab h of left_. + ProcessID group_root = get_row_group_root(step_k(s), row_group); + bcast(step_h(s) * left_slab_size_ + left_start_local_ + step_k(s), + left_stride_local_, row_group, group_root, 0ul, col); } } - /// Broadcast row \c k of \c right_ with a dense left-hand argument + /// Broadcast row \c k of slab \c h of \c right_ - /// \param[in] k The row of \c right to be broadcast + /// \param[in] s The SUMMA step (slab * k_ + row index) /// \param[out] row The vector that will hold the results of the broadcast - void bcast_row(const ordinal_type k, std::vector& row, + void bcast_row(const ordinal_type s, std::vector& row, const madness::Group& col_group) const { // broadcast if I'm part of the broadcast group if (!col_group.empty()) { // Compute the group root process. - ProcessID group_root = get_col_group_root(k, col_group); + ProcessID group_root = get_col_group_root(step_k(s), col_group); - // Broadcast row k of right_. - bcast(k * proc_grid_.cols() + proc_grid_.rank_col(), right_stride_local_, - col_group, group_root, left_.size(), row); + // Broadcast row k of slab h of right_. + bcast(step_h(s) * right_slab_size_ + step_k(s) * proc_grid_.cols() + + proc_grid_.rank_col(), + right_stride_local_, col_group, group_root, left_.size(), row); } } - void bcast_col_range_task(ordinal_type k, const ordinal_type end) const { - // Compute the first local row of right + void bcast_col_range_task(ordinal_type s, const ordinal_type end) const { + // Iterate over the skipped steps for which this process column owns the + // broadcast root (i.e. within-slab k congruent to rank_col mod Pcols) const ordinal_type Pcols = proc_grid_.proc_cols(); - k += (Pcols - ((k + Pcols - proc_grid_.rank_col()) % Pcols)) % Pcols; - for (; k < end; k += Pcols) { - // Compute local iteration limits for column k of left_. - ordinal_type index = left_start_local_ + k; + for (; s < end; ++s) { + const ordinal_type k = step_k(s); + if (k % Pcols != static_cast(proc_grid_.rank_col())) + continue; + const ordinal_type left_base = step_h(s) * left_slab_size_; + + // Compute local iteration limits for column k of slab h of left_. + ordinal_type index = left_base + left_start_local_ + k; + const ordinal_type col_end = left_base + left_end_; // will create broadcast group only if needed bool have_group = false; @@ -702,13 +751,13 @@ class Summa bool do_broadcast; // Search column k of left for non-zero tiles - for (; index < left_end_; index += left_stride_local_) { + for (; index < col_end; index += left_stride_local_) { if (left_.shape().is_zero(index)) continue; // Construct broadcast group, if needed if (!have_group) { have_group = true; - row_group = make_row_group(k); + row_group = make_row_group(s); // broadcast if I am in this group and this group has others do_broadcast = !row_group.empty() && row_group.size() > 1; if (do_broadcast) group_root = get_row_group_root(k, row_group); @@ -733,14 +782,18 @@ class Summa } } - void bcast_row_range_task(ordinal_type k, const ordinal_type end) const { - // Compute the first local row of right + void bcast_row_range_task(ordinal_type s, const ordinal_type end) const { + // Iterate over the skipped steps for which this process row owns the + // broadcast root (i.e. within-slab k congruent to rank_row mod Prows) const ordinal_type Prows = proc_grid_.proc_rows(); - k += (Prows - ((k + Prows - proc_grid_.rank_row()) % Prows)) % Prows; - for (; k < end; k += Prows) { - // Compute local iteration limits for row k of right_. - ordinal_type index = k * proc_grid_.cols(); + for (; s < end; ++s) { + const ordinal_type k = step_k(s); + if (k % Prows != static_cast(proc_grid_.rank_row())) + continue; + + // Compute local iteration limits for row k of slab h of right_. + ordinal_type index = step_h(s) * right_slab_size_ + k * proc_grid_.cols(); const ordinal_type row_end = index + proc_grid_.cols(); index += proc_grid_.rank_col(); @@ -757,7 +810,7 @@ class Summa // Construct broadcast group if (!have_group) { have_group = true; - col_group = make_col_group(k); + col_group = make_col_group(s); // broadcast if I am in this group and this group has others do_broadcast = !col_group.empty() && col_group.size() > 1; if (do_broadcast) group_root = get_col_group_root(k, col_group); @@ -787,45 +840,48 @@ class Summa /// Find next non-zero row of \c right_ for a sparse shape - /// Starting at the k-th row of the right-hand argument, find the next row - /// that contains at least one non-zero tile. This search only checks for + /// Starting at SUMMA step \c s, find the next step whose right-hand row + /// contains at least one non-zero tile. This search only checks for /// non-zero tiles in this processes column. - /// \param k The first row to search - /// \return The first row, greater than or equal to \c k with non-zero - /// tiles, or \c k_ if none is found. - ordinal_type iterate_row(ordinal_type k) const { - // Iterate over k's until a non-zero tile is found or the end of the + /// \param s The first step to search + /// \return The first step, greater than or equal to \c s with non-zero + /// tiles, or \c nsteps_ if none is found. + ordinal_type iterate_row(ordinal_type s) const { + // Iterate over steps until a non-zero tile is found or the end of the // matrix is reached. - ordinal_type end = k * proc_grid_.cols(); - for (; k < k_; ++k) { - // Search for non-zero tiles in row k of right - ordinal_type i = end + proc_grid_.rank_col(); - end += proc_grid_.cols(); + for (; s < nsteps_; ++s) { + // Search for non-zero tiles in row k of slab h of right + ordinal_type i = + step_h(s) * right_slab_size_ + step_k(s) * proc_grid_.cols(); + const ordinal_type end = i + proc_grid_.cols(); + i += proc_grid_.rank_col(); for (; i < end; i += right_stride_local_) - if (!right_.shape().is_zero(i)) return k; + if (!right_.shape().is_zero(i)) return s; } - return k; + return s; } /// Find the next non-zero column of \c left_ for an arbitrary shape type - /// Starting at the k-th column of the left-hand argument, find the next - /// column that contains at least one non-zero tile. This search only + /// Starting at SUMMA step \c s, find the next step whose left-hand column + /// contains at least one non-zero tile. This search only /// checks for non-zero tiles in this process's row. - /// \param k The first column to test for non-zero tiles - /// \return The first column, greater than or equal to \c k, that contains - /// a non-zero tile. If no non-zero tile is not found, return \c k_. - ordinal_type iterate_col(ordinal_type k) const { - // Iterate over k's until a non-zero tile is found or the end of the + /// \param s The first step to test for non-zero tiles + /// \return The first step, greater than or equal to \c s, that contains + /// a non-zero tile. If no non-zero tile is not found, return \c nsteps_. + ordinal_type iterate_col(ordinal_type s) const { + // Iterate over steps until a non-zero tile is found or the end of the // matrix is reached. - for (; k < k_; ++k) - // Search row k for non-zero tiles - for (ordinal_type i = left_start_local_ + k; i < left_end_; - i += left_stride_local_) - if (!left_.shape().is_zero(i)) return k; + for (; s < nsteps_; ++s) { + // Search column k of slab h for non-zero tiles + const ordinal_type base = step_h(s) * left_slab_size_; + for (ordinal_type i = base + left_start_local_ + step_k(s); + i < base + left_end_; i += left_stride_local_) + if (!left_.shape().is_zero(i)) return s; + } - return k; + return s; } /// Find the next k where the left- and right-hand argument have non-zero @@ -839,9 +895,9 @@ class Summa /// \param k The first row/column to check /// \return The next k-th column and row of the left- and right-hand /// arguments, respectively, that both have non-zero tiles - ordinal_type iterate_sparse(const ordinal_type k) const { + ordinal_type iterate_sparse(const ordinal_type s) const { // Initial step for k_col and k_row. - ordinal_type k_col = iterate_col(k); + ordinal_type k_col = iterate_col(s); ordinal_type k_row = iterate_row(k_col); // Search for a row and column that both have non-zero tiles @@ -853,15 +909,15 @@ class Summa } } - if (k < k_row) { + if (s < k_row) { // Spawn a task to broadcast any local columns of left that were skipped TensorImpl_::world().taskq.add(shared_from_this(), - &Summa_::bcast_col_range_task, k, k_row, + &Summa_::bcast_col_range_task, s, k_row, madness::TaskAttributes::hipri()); // Spawn a task to broadcast any local rows of right that were skipped TensorImpl_::world().taskq.add(shared_from_this(), - &Summa_::bcast_row_range_task, k, k_col, + &Summa_::bcast_row_range_task, s, k_col, madness::TaskAttributes::hipri()); } @@ -907,9 +963,11 @@ class Summa return tile_count; } else { // Construct static broadcast groups for dense arguments - const madness::DistributedID col_did(DistEvalImpl_::id(), 0ul); + // (key space [0, 2*nsteps_) is reserved for the sparse per-step groups) + const madness::DistributedID col_did(DistEvalImpl_::id(), 2ul * nsteps_); col_group_ = proc_grid_.make_col_group(col_did); - const madness::DistributedID row_did(DistEvalImpl_::id(), k_); + const madness::DistributedID row_did(DistEvalImpl_::id(), + 2ul * nsteps_ + 1ul); row_group_ = proc_grid_.make_row_group(row_did); #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE @@ -926,12 +984,13 @@ class Summa printf(ss.str().c_str()); #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE - // Allocate memory for the reduce pair tasks. + // Allocate memory for the reduce pair tasks (one per local result tile + // per slab). std::allocator> alloc; - reduce_tasks_ = alloc.allocate(proc_grid_.local_size()); + reduce_tasks_ = alloc.allocate(nh_ * proc_grid_.local_size()); // Iterate over all local tiles - const ordinal_type n = proc_grid_.local_size(); + const ordinal_type n = nh_ * proc_grid_.local_size(); for (ordinal_type t = 0ul; t < n; ++t) { // Initialize the reduction task ReducePairTask* MADNESS_RESTRICT const reduce_task = @@ -944,7 +1003,7 @@ class Summa ); } - return proc_grid_.local_size(); + return n; } } @@ -959,46 +1018,53 @@ class Summa // fast return if there is no work to do if (k_ == 0) return 0; - // Allocate memory for the reduce pair tasks. + // Allocate memory for the reduce pair tasks (one per local result tile + // per slab). std::allocator> alloc; - reduce_tasks_ = alloc.allocate(proc_grid_.local_size()); + reduce_tasks_ = alloc.allocate(nh_ * proc_grid_.local_size()); // Initialize iteration variables - ordinal_type row_start = proc_grid_.rank_row() * proc_grid_.cols(); - ordinal_type row_end = row_start + proc_grid_.cols(); - row_start += proc_grid_.rank_col(); const ordinal_type col_stride = // The stride to iterate down a column proc_grid_.proc_rows() * proc_grid_.cols(); const ordinal_type row_stride = // The stride to iterate across a row proc_grid_.proc_cols(); - const ordinal_type end = TensorImpl_::size(); - // Iterate over all local tiles + // Iterate over all local tiles, slab by slab (the block-cyclic phase + // restarts at every slab: the owner of tile (h,i,j) does not depend on h) ordinal_type tile_count = 0ul; ReducePairTask* MADNESS_RESTRICT reduce_task = reduce_tasks_; - // this loops over result tiles arranged in block-cyclic order - // index = tile index (row major) - for (; row_start < end; row_start += col_stride, row_end += col_stride) { - for (ordinal_type index = row_start; index < row_end; - index += row_stride, ++reduce_task) { - // Initialize the reduction task - - // Skip zero tiles - if (!shape.is_zero(DistEvalImpl_::perm_index_to_target(index))) { + for (ordinal_type h = 0ul; h < nh_; ++h) { + const ordinal_type slab_base = h * result_slab_size_; + ordinal_type row_start = + slab_base + proc_grid_.rank_row() * proc_grid_.cols(); + ordinal_type row_end = row_start + proc_grid_.cols(); + row_start += proc_grid_.rank_col(); + const ordinal_type end = slab_base + result_slab_size_; + + // this loops over result tiles arranged in block-cyclic order + // index = tile index (row major) + for (; row_start < end; row_start += col_stride, row_end += col_stride) { + for (ordinal_type index = row_start; index < row_end; + index += row_stride, ++reduce_task) { + // Initialize the reduction task + + // Skip zero tiles + if (!shape.is_zero(DistEvalImpl_::perm_index_to_target(index))) { #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE - ss << index << " "; + ss << index << " "; #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE - new (reduce_task) ReducePairTask(TensorImpl_::world(), op_ + new (reduce_task) ReducePairTask(TensorImpl_::world(), op_ #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE - , - nullptr, index + , + nullptr, index #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE - ); - ++tile_count; - } else { - // Construct an empty task to represent zero tiles. - new (reduce_task) ReducePairTask(); + ); + ++tile_count; + } else { + // Construct an empty task to represent zero tiles. + new (reduce_task) ReducePairTask(); + } } } } @@ -1030,32 +1096,37 @@ class Summa /// Set the result tiles, destroy reduce tasks, and destroy broadcast groups void finalize(const DenseShape&) { // Initialize iteration variables - ordinal_type row_start = proc_grid_.rank_row() * proc_grid_.cols(); - ordinal_type row_end = row_start + proc_grid_.cols(); - row_start += proc_grid_.rank_col(); const ordinal_type col_stride = // The stride to iterate down a column proc_grid_.proc_rows() * proc_grid_.cols(); const ordinal_type row_stride = // The stride to iterate across a row proc_grid_.proc_cols(); - const ordinal_type end = TensorImpl_::size(); - - // Iterate over all local tiles - for (ReducePairTask* reduce_task = reduce_tasks_; row_start < end; - row_start += col_stride, row_end += col_stride) { - for (ordinal_type index = row_start; index < row_end; - index += row_stride, ++reduce_task) { - // Set the result tile - DistEvalImpl_::set_tile(DistEvalImpl_::perm_index_to_target(index), - reduce_task->submit()); - - // Destroy the reduce task - reduce_task->~ReducePairTask(); + + // Iterate over all local tiles, slab by slab + ReducePairTask* reduce_task = reduce_tasks_; + for (ordinal_type h = 0ul; h < nh_; ++h) { + const ordinal_type slab_base = h * result_slab_size_; + ordinal_type row_start = + slab_base + proc_grid_.rank_row() * proc_grid_.cols(); + ordinal_type row_end = row_start + proc_grid_.cols(); + row_start += proc_grid_.rank_col(); + const ordinal_type end = slab_base + result_slab_size_; + + for (; row_start < end; row_start += col_stride, row_end += col_stride) { + for (ordinal_type index = row_start; index < row_end; + index += row_stride, ++reduce_task) { + // Set the result tile + DistEvalImpl_::set_tile(DistEvalImpl_::perm_index_to_target(index), + reduce_task->submit()); + + // Destroy the reduce task + reduce_task->~ReducePairTask(); + } } } // Deallocate the memory for the reduce pair tasks. std::allocator>().deallocate( - reduce_tasks_, proc_grid_.local_size()); + reduce_tasks_, nh_ * proc_grid_.local_size()); } /// Set the result tiles and destroy reduce tasks @@ -1067,41 +1138,46 @@ class Summa #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_FINALIZE // Initialize iteration variables - ordinal_type row_start = proc_grid_.rank_row() * proc_grid_.cols(); - ordinal_type row_end = row_start + proc_grid_.cols(); - row_start += proc_grid_.rank_col(); const ordinal_type col_stride = // The stride to iterate down a column proc_grid_.proc_rows() * proc_grid_.cols(); const ordinal_type row_stride = // The stride to iterate across a row proc_grid_.proc_cols(); - const ordinal_type end = TensorImpl_::size(); - - // Iterate over all local tiles - for (ReducePairTask* reduce_task = reduce_tasks_; row_start < end; - row_start += col_stride, row_end += col_stride) { - for (ordinal_type index = row_start; index < row_end; - index += row_stride, ++reduce_task) { - // Compute the permuted index - const ordinal_type perm_index = - DistEvalImpl_::perm_index_to_target(index); - // Skip zero tiles - if (!shape.is_zero(perm_index)) { + // Iterate over all local tiles, slab by slab + ReducePairTask* reduce_task = reduce_tasks_; + for (ordinal_type h = 0ul; h < nh_; ++h) { + const ordinal_type slab_base = h * result_slab_size_; + ordinal_type row_start = + slab_base + proc_grid_.rank_row() * proc_grid_.cols(); + ordinal_type row_end = row_start + proc_grid_.cols(); + row_start += proc_grid_.rank_col(); + const ordinal_type end = slab_base + result_slab_size_; + + for (; row_start < end; row_start += col_stride, row_end += col_stride) { + for (ordinal_type index = row_start; index < row_end; + index += row_stride, ++reduce_task) { + // Compute the permuted index + const ordinal_type perm_index = + DistEvalImpl_::perm_index_to_target(index); + + // Skip zero tiles + if (!shape.is_zero(perm_index)) { #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_FINALIZE - ss << index << " "; + ss << index << " "; #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_FINALIZE - // Set the result tile - DistEvalImpl_::set_tile(perm_index, reduce_task->submit()); - } + // Set the result tile + DistEvalImpl_::set_tile(perm_index, reduce_task->submit()); + } - // Destroy the reduce task - reduce_task->~ReducePairTask(); + // Destroy the reduce task + reduce_task->~ReducePairTask(); + } } } // Deallocate the memory for the reduce pair tasks. std::allocator>().deallocate( - reduce_tasks_, proc_grid_.local_size()); + reduce_tasks_, nh_ * proc_grid_.local_size()); #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_FINALIZE ss << "}\n"; @@ -1149,15 +1225,19 @@ class Summa /// \param col A column of tiles from the left-hand argument /// \param row A row of tiles from the right-hand argument /// \param task The task that depends on tile contraction tasks - void contract(const DenseShape&, const ordinal_type, + void contract(const DenseShape&, const ordinal_type s, const std::vector& col, const std::vector& row, madness::TaskInterface* const task) { + // The reduce tasks of slab h occupy + // [h * local_size, (h+1) * local_size) + const ordinal_type slab_offset = step_h(s) * proc_grid_.local_size(); + // Iterate over the row for (ordinal_type i = 0ul; i < col.size(); ++i) { // Compute the local, result-tile offset const ordinal_type reduce_task_offset = - col[i].first * proc_grid_.local_cols(); + slab_offset + col[i].first * proc_grid_.local_cols(); // Iterate over columns for (ordinal_type j = 0ul; j < row.size(); ++j) { @@ -1182,15 +1262,19 @@ class Summa /// \param row A row of tiles from the right-hand argument /// \param task The task that depends on tile contraction tasks template - void contract(const Shape&, const ordinal_type, + void contract(const Shape&, const ordinal_type s, const std::vector& col, const std::vector& row, madness::TaskInterface* const task) { + // The reduce tasks of slab h occupy + // [h * local_size, (h+1) * local_size) + const ordinal_type slab_offset = step_h(s) * proc_grid_.local_size(); + // Iterate over the row for (ordinal_type i = 0ul; i < col.size(); ++i) { // Compute the local, result-tile offset const ordinal_type reduce_task_offset = - col[i].first * proc_grid_.local_cols(); + slab_offset + col[i].first * proc_grid_.local_cols(); // Iterate over columns for (ordinal_type j = 0ul; j < row.size(); ++j) { @@ -1369,7 +1453,7 @@ class Summa template void make_next_step_tasks(Derived* task, ordinal_type depth) { // Set the depth to be no greater than the maximum number steps - if (depth > owner_->k_) depth = owner_->k_; + if (depth > owner_->nsteps_) depth = owner_->nsteps_; // Spawn n=depth step tasks for (; depth > 0ul; --depth) { @@ -1390,7 +1474,7 @@ class Summa printf("step: start rank=%i k=%lu\n", owner_->world().rank(), k); #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_STEP - if (k < owner_->k_) { + if (k < owner_->nsteps_) { // Initialize next tail task and submit next task TA_ASSERT(next_step_task_); next_step_task_->tail_step_task_ = new Derived( @@ -1456,7 +1540,7 @@ class Summa public: DenseStepTask(const std::shared_ptr& owner, const ordinal_type depth) - : StepTask(owner, owner->k_ + 1ul), k_(0) { + : StepTask(owner, owner->nsteps_ + 1ul), k_(0) { StepTask::make_next_step_tasks(this, depth); StepTask::spawn_get_row_col_tasks(k_); } @@ -1464,7 +1548,7 @@ class Summa DenseStepTask(DenseStepTask* const parent, const int ndep) : StepTask(parent, ndep), k_(parent->k_ + 1ul) { // Spawn tasks to get k-th row and column tiles - if (k_ < owner_->k_) StepTask::spawn_get_row_col_tasks(k_); + if (k_ < owner_->nsteps_) StepTask::spawn_get_row_col_tasks(k_); } virtual ~DenseStepTask() {} @@ -1492,7 +1576,7 @@ class Summa k = owner_->iterate_sparse(k + offset); k_.set(k); - if (k < owner_->k_) { + if (k < owner_->nsteps_) { // NOTE: The order of task submissions is dependent on the order in // which we want the tasks to complete. @@ -1533,7 +1617,7 @@ class Summa SparseStepTask(SparseStepTask* const parent, const int ndep) : StepTask(parent, ndep) { - if (parent->k_.probe() && (parent->k_.get() >= owner_->k_)) { + if (parent->k_.probe() && (parent->k_.get() >= owner_->nsteps_)) { // Avoid running extra tasks if not needed. k_.set(parent->k_.get()); TA_ASSERT(ndep == @@ -1570,6 +1654,14 @@ class Summa /// \param k The number of tiles in the inner dimension /// \param proc_grid The process grid that defines the layout of the tiles /// during the contraction evaluation + /// \param nh The number of fused (Hadamard/batch) slabs; the default (1) + /// is the ordinary, unbatched contraction. For nh > 1 the + /// arguments and the result carry the fused modes as their + /// leading dimensions (left = (h,i,k), right = (h,k,j), + /// result = (h,i,j)), each slab is distributed over the same + /// 2-d process grid (i.e. the owner of a tile is independent of + /// h), and the contraction runs as nh independent SUMMA slabs + /// sharing one task graph with no inter-slab barriers. /// \note The trange, shape, and pmap refer to the final, /// permuted, state for the result, NOT to the result during /// the SUMMA evaluation. @@ -1578,7 +1670,8 @@ class Summa Summa(const left_type& left, const right_type& right, World& world, const trange_type trange, const shape_type& shape, const std::shared_ptr& pmap, const Perm& perm, - const op_type& op, const ordinal_type k, const ProcGrid& proc_grid) + const op_type& op, const ordinal_type k, const ProcGrid& proc_grid, + const ordinal_type nh = 1ul) : DistEvalImpl_(world, trange, shape, pmap, outer(perm)), left_(left), right_(right), @@ -1587,9 +1680,14 @@ class Summa col_group_(), k_(k), proc_grid_(proc_grid), + nh_(nh), + nsteps_(nh * k), + left_slab_size_(left.size() / nh), + right_slab_size_(right.size() / nh), + result_slab_size_(proc_grid.rows() * proc_grid.cols()), reduce_tasks_(NULL), left_start_local_(proc_grid_.rank_row() * k), - left_end_(left.size()), + left_end_(left.size() / nh), left_stride_(k), left_stride_local_(proc_grid.proc_rows() * k), right_stride_(1ul), @@ -1602,6 +1700,9 @@ class Summa right_ntiles_discarded_(0) #endif { + TA_ASSERT(nh_ > 0); + TA_ASSERT(left.size() % nh_ == 0); + TA_ASSERT(right.size() % nh_ == 0); } virtual ~Summa() {} @@ -1618,9 +1719,11 @@ class Summa const ordinal_type source_index = DistEvalImpl_::perm_index_to_source(i); - // Compute tile coordinate in tile grid - const ordinal_type tile_row = source_index / proc_grid_.cols(); - const ordinal_type tile_col = source_index % proc_grid_.cols(); + // Compute tile coordinate in tile grid (the owner of a tile is + // independent of its slab index) + const ordinal_type slab_index = source_index % result_slab_size_; + const ordinal_type tile_row = slab_index / proc_grid_.cols(); + const ordinal_type tile_col = slab_index % proc_grid_.cols(); // Compute process coordinate of tile in the process grid const ordinal_type proc_row = tile_row % proc_grid_.proc_rows(); const ordinal_type proc_col = tile_col % proc_grid_.proc_cols(); @@ -1729,12 +1832,11 @@ class Summa // watch out for the corner case: contraction over zero-volume range // producing nonzero-volume result ... in that case there is nothing to do // the appropriate initialization was performed in the initialize() method - if (k_ != 0) { + if (nsteps_ != 0) { // Construct the first SUMMA iteration task if (TensorImpl_::shape().is_dense()) { - // We cannot have more iterations than there are blocks in the k - // dimension - if (depth > k_) depth = k_; + // We cannot have more iterations than there are SUMMA steps + if (depth > nsteps_) depth = nsteps_; // Modify the number of concurrent iterations based on the available // memory. @@ -1761,9 +1863,8 @@ class Summa depth = float(depth) * (1.0f - 1.35638f * std::log2(frac_non_zero)) + 0.5f; - // We cannot have more iterations than there are blocks in the k - // dimension - if (depth > k_) depth = k_; + // We cannot have more iterations than there are SUMMA steps + if (depth > nsteps_) depth = nsteps_; // Modify the number of concurrent iterations based on the available // memory and sparsity of the argument tensors. @@ -1775,7 +1876,7 @@ class Summa TensorImpl_::world().taskq.add( new SparseStepTask(shared_from_this(), depth)); } - } // k_ != 0 + } // nsteps_ != 0 } #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_EVAL From 2705743ffcaba50092166431baa4bbb56d91b703 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Thu, 11 Jun 2026 02:55:10 -0400 Subject: [PATCH 04/25] expressions: evaluate dense general products via the batched Summa Phase B step 2: dense (DensePolicy) general products now evaluate natively in the expression layer, end-to-end: C("b,i,k") = A("b,i,j") * B("b,j,k") runs as ONE distributed batched Summa in one World -- no per-Hadamard-tile sub-Worlds. - SlabbedPmap: replicates a base pmap over a leading slab dimension (owner of a tile is independent of its fused-index slab), used for the SUMMA phase maps of the arguments and the result pmap - BatchedContractReduce: adapts a folded (fused-mode-free) ContractReduce to tiles carrying leading fused modes; folds them into the tile batch dimension by zero-copy reshape (modes lead => layout preserved), allocates the result with its full range up front, and lets Tensor::gemm's per-batch loop do the work; TA::Tensor tiles only - ContEngine: init_struct_general / make_trange_general / make_shape_general / init_distribution_general / make_dist_eval_general -- fused-mode-prefixed result structure, per-slab 2-d process grid, batched Summa construction (nh = product of fused-mode tile extents, K = per-slab contracted tile count) - MultEngine routes General to these; ScalMultEngine still gates - not yet supported (clear errors): block-sparse shapes (per-slab shape gemm TODO), tensors-of-tensors, targets interleaving fused and free modes Differential-tested against the einsum free function (norm(diff) <= 1e-10): multi-tile uneven dims, permuted argument layouts, batched outer product; np = 1, 2, 3, 4. All existing suites unchanged. --- src/TiledArray/expressions/cont_engine.h | 223 +++++++++++++++++- src/TiledArray/expressions/mult_engine.h | 21 +- src/TiledArray/pmap/slabbed_pmap.h | 116 +++++++++ .../tile_op/batched_contract_reduce.h | 203 ++++++++++++++++ tests/general_product.cpp | 136 ++++++++++- 5 files changed, 676 insertions(+), 23 deletions(-) create mode 100644 src/TiledArray/pmap/slabbed_pmap.h create mode 100644 src/TiledArray/tile_op/batched_contract_reduce.h diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index e723952a34..fe4fd937bb 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -29,9 +29,11 @@ #include #include #include +#include #include #include #include +#include #include #include @@ -172,7 +174,11 @@ class ContEngine : public BinaryEngine { TA_NO_UNIQUE_ADDRESS arena_plan_storage_t arena_plan_; TiledArray::detail::ProcGrid proc_grid_; ///< Process grid for the contraction - size_type K_ = 1; ///< Inner dimension size + size_type K_ = 1; ///< Inner dimension size (# of tiles, per slab for + ///< general products) + // General (fused + contracted + free indices) products only: + unsigned int n_fused_modes_ = 0; ///< # of leading fused (outer) modes + size_type n_slabs_ = 1; ///< # of fused-index tile slabs static unsigned int find(const BipartiteIndexList& indices, const std::string& index_label, unsigned int i, @@ -593,6 +599,221 @@ class ContEngine : public BinaryEngine { return dist_eval_type(pimpl); } + // == General (fused + contracted + free indices) product support ========== + // + // The canonical layouts produced by GeneralPermutationOptimizer carry the + // fused (Hadamard) modes as the leading modes of both arguments and the + // result: left = (h, e_A, c), right = (h, c, e_B), result = (h, e_A, e_B). + // The product is evaluated by the batched Summa: every fused-index tile + // slab is an independent SUMMA distributed over ONE shared 2-d process + // grid (the owner of a tile is independent of its slab index), and the + // within-tile fused extents are folded into the tile batch dimension by + // BatchedContractReduce. + + /// Initialize the result structure of a general product + + /// The general-product analogue of init_struct: builds the *folded* + /// (fused-mode-free) tile op and the fused-mode-prefixed result trange and + /// shape. + /// \param target_indices The target index list for the result tensor + void init_struct_general(const BipartiteIndexList& target_indices) { + if constexpr (TiledArray::detail::is_tensor_of_tensor_v) + TA_EXCEPTION( + "general products (fused + contracted + free indices) of " + "tensors-of-tensors via the expression layer are not yet " + "implemented; use TiledArray::einsum() instead"); + else { + // Initialize children + left_.init_struct(left_indices_); + right_.init_struct(right_indices_); + + // count the fused modes: the leading indices common to the canonical + // left, right, and result layouts + const auto& left_outer = outer(left_indices_); + const auto& right_outer = outer(right_indices_); + const auto& result_outer = outer(indices_); + unsigned int nh = 0u; + while (nh < result_outer.size() && nh < left_outer.size() && + nh < right_outer.size() && left_outer[nh] == result_outer[nh] && + right_outer[nh] == result_outer[nh]) + ++nh; + TA_ASSERT(nh > 0u); // else this is a pure contraction + n_fused_modes_ = nh; + + // initialize perm_; an interleaved target (a result permutation that + // mixes fused and free modes) is not yet supported -- the canonical + // result layout must equal the target + this->init_perm(target_indices); + if (outer(target_indices) != outer(indices_)) + TA_EXCEPTION( + "general products (fused + contracted + free indices): targets " + "that interleave fused and free indices are not yet supported; " + "reorder the result annotation to (fused..., left-free..., " + "right-free...)"); + + // the tile op operates on the folded (fused-mode-free) shapes + const auto left_op = to_cblas_op(left_outer_permtype_); + const auto right_op = to_cblas_op(right_outer_permtype_); + op_ = op_type(left_op, right_op, factor_, outer_size(indices_) - nh, + outer_size(left_indices_) - nh, + outer_size(right_indices_) - nh); + + trange_ = make_trange_general(); + shape_ = make_shape_general(); + + if (ExprEngine_::override_ptr_ && ExprEngine_::override_ptr_->shape) { + shape_ = shape_.mask(*ExprEngine_::override_ptr_->shape); + } + } + } + + /// Tiled range factory function for a general product + + /// \return The result tiled range: the fused mode ranges followed by the + /// left- and right-external mode ranges + trange_type make_trange_general() const { + const unsigned int nh = n_fused_modes_; + const unsigned int nc = op_.gemm_helper().num_contract_ranks(); + const unsigned int neA = op_.gemm_helper().left_rank() - nc; + const unsigned int neB = op_.gemm_helper().right_rank() - nc; + + typename trange_type::Ranges ranges(nh + neA + neB); + unsigned int i = 0ul; + for (unsigned int x = 0ul; x < nh + neA; ++x, ++i) + ranges[i] = left_.trange().data()[x]; + for (unsigned int x = nh + nc; x < nh + nc + neB; ++x, ++i) + ranges[i] = right_.trange().data()[x]; + +#ifndef NDEBUG + // the fused and contracted dimensions must have congruent tilings + for (unsigned int d = 0ul; d < nh; ++d) { + if (!is_congruent(left_.trange().data()[d], right_.trange().data()[d])) + TA_EXCEPTION( + "the fused dimensions of the left- and right-hand expressions " + "are not congruent"); + } + for (unsigned int l = nh + neA, r = nh; l < nh + neA + nc; ++l, ++r) { + if (!is_congruent(left_.trange().data()[l], right_.trange().data()[r])) + TA_EXCEPTION( + "the contracted dimensions of the left- and right-hand " + "expressions are not congruent"); + } +#endif // NDEBUG + + return trange_type(ranges.begin(), ranges.end()); + } + + /// Shape factory function for a general product + + /// \return The result shape + shape_type make_shape_general() const { + if constexpr (std::is_same_v) + return shape_type(); + else + // TODO support block-sparse general products: evaluate the shape + // slab-by-slab (per-slab SparseShape::gemm over the folded modes) + TA_EXCEPTION( + "block-sparse general products (fused + contracted + free indices) " + "via the expression layer are not yet implemented; use " + "TiledArray::einsum() instead"); + } + + /// Initialize the result distribution of a general product + + /// The 2-d process grid spans the external (free) modes only; the fused + /// modes are replicated over the grid (the owner of a tile is independent + /// of its slab index) via SlabbedPmap. + /// \param world The world where the result will be distributed + /// \param pmap The process map for the result tensor tiles + void init_distribution_general(World* world, + std::shared_ptr pmap) { + const unsigned int nh = n_fused_modes_; + const unsigned int nc = op_.gemm_helper().num_contract_ranks(); + const unsigned int neA = op_.gemm_helper().left_rank() - nc; + const unsigned int neB = op_.gemm_helper().right_rank() - nc; + + // Get pointers to the argument sizes + const auto* MADNESS_RESTRICT const left_tiles_size = + left_.trange().tiles_range().extent_data(); + const auto* MADNESS_RESTRICT const left_element_size = + left_.trange().elements_range().extent_data(); + const auto* MADNESS_RESTRICT const right_tiles_size = + right_.trange().tiles_range().extent_data(); + const auto* MADNESS_RESTRICT const right_element_size = + right_.trange().elements_range().extent_data(); + + // Compute the slab count and the fused sizes of the per-slab contraction + size_type M = 1ul, m = 1ul, N = 1ul, n = 1ul; + n_slabs_ = 1ul; + for (unsigned int i = 0u; i < nh; ++i) n_slabs_ *= left_tiles_size[i]; + for (unsigned int i = nh; i < nh + neA; ++i) { + M *= left_tiles_size[i]; + m *= left_element_size[i]; + } + for (unsigned int i = nh + neA; i < nh + neA + nc; ++i) + K_ *= left_tiles_size[i]; + for (unsigned int i = nh + nc; i < nh + nc + neB; ++i) { + N *= right_tiles_size[i]; + n *= right_element_size[i]; + } + + // corner case: zero-volume result ... easier to skip proc_grid_ + // construction alltogether + if (M == 0 || N == 0 || n_slabs_ == 0) { + left_.init_distribution(world, {}); + right_.init_distribution(world, {}); + ExprEngine_::init_distribution( + world, + (pmap ? pmap : policy::default_pmap(*world, n_slabs_ * M * N))); + } else { + // Construct the per-slab process grid. + proc_grid_ = TiledArray::detail::ProcGrid(*world, M, N, m, n); + + // Initialize children with slab-replicated SUMMA phase maps + left_.init_distribution( + world, std::make_shared( + *world, proc_grid_.make_row_phase_pmap(K_), n_slabs_)); + right_.init_distribution( + world, std::make_shared( + *world, proc_grid_.make_col_phase_pmap(K_), n_slabs_)); + + // Initialize the process map if not already defined + if (!pmap) + pmap = std::make_shared( + *world, proc_grid_.make_pmap(), n_slabs_); + ExprEngine_::init_distribution(world, pmap); + } + } + + /// Construct the distributed evaluator of a general product + + /// \return The batched-Summa distributed evaluator for this expression + dist_eval_type make_dist_eval_general() const { + if constexpr (TiledArray::detail::is_tensor_of_tensor_v) { + // unreachable: init_struct_general throws for tensors-of-tensors + TA_EXCEPTION( + "general products of tensors-of-tensors are not yet implemented"); + abort(); // unreachable + } else { + typedef TiledArray::detail::BatchedContractReduce + batched_op_type; + typedef TiledArray::detail::Summa + impl_type; + + typename left_type::dist_eval_type left = left_.make_dist_eval(); + typename right_type::dist_eval_type right = right_.make_dist_eval(); + + std::shared_ptr pimpl = std::make_shared( + left, right, *world_, trange_, shape_, pmap_, perm_, + batched_op_type(op_, n_fused_modes_), K_, proc_grid_, n_slabs_); + + return dist_eval_type(pimpl); + } + } + /// Expression identification tag /// \return An expression tag used to identify this expression diff --git a/src/TiledArray/expressions/mult_engine.h b/src/TiledArray/expressions/mult_engine.h index 6be2f6fb1d..cddf8e721c 100644 --- a/src/TiledArray/expressions/mult_engine.h +++ b/src/TiledArray/expressions/mult_engine.h @@ -383,15 +383,12 @@ class MultEngine : public ContEngine> { /// for the result tensor. /// \param target_indices The target index list for the result tensor void init_struct(const BipartiteIndexList& target_indices) { - // TODO Phase B (batched Summa): evaluate general products natively. - // Until then this engine can classify and lay out a general product - // (see init_indices) but not evaluate it. - if (this->product_type() == TensorProduct::General) - TA_EXCEPTION( - "MultEngine: evaluation of general products (fused + contracted + " - "free indices, e.g. C(\"b,i,k\") = A(\"b,i,j\") * B(\"b,j,k\")) via " - "the expression layer is not yet implemented; use " - "TiledArray::einsum() instead"); + if (this->product_type() == TensorProduct::General) { + // N.B. no inner tile op: ToT general products are rejected by + // init_struct_general + ContEngine_::init_struct_general(target_indices); + return; + } this->init_perm(target_indices); @@ -413,6 +410,8 @@ class MultEngine : public ContEngine> { std::shared_ptr pmap) { if (this->product_type() == TensorProduct::Contraction) ContEngine_::init_distribution(world, pmap); + else if (this->product_type() == TensorProduct::General) + ContEngine_::init_distribution_general(world, pmap); else BinaryEngine_::init_distribution(world, pmap); } @@ -423,6 +422,8 @@ class MultEngine : public ContEngine> { trange_type make_trange() const { if (this->product_type() == TensorProduct::Contraction) return ContEngine_::make_trange(); + else if (this->product_type() == TensorProduct::General) + return ContEngine_::make_trange_general(); else return BinaryEngine_::make_trange(); } @@ -544,6 +545,8 @@ class MultEngine : public ContEngine> { dist_eval_type make_dist_eval() const { if (this->product_type() == TensorProduct::Contraction) return ContEngine_::make_dist_eval(); + else if (this->product_type() == TensorProduct::General) + return ContEngine_::make_dist_eval_general(); else return BinaryEngine_::make_dist_eval(); } diff --git a/src/TiledArray/pmap/slabbed_pmap.h b/src/TiledArray/pmap/slabbed_pmap.h new file mode 100644 index 0000000000..17fe370751 --- /dev/null +++ b/src/TiledArray/pmap/slabbed_pmap.h @@ -0,0 +1,116 @@ +/* + * This file is a part of TiledArray. + * Copyright (C) 2026 Virginia Tech + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + * slabbed_pmap.h + * Jun 11, 2026 + * + */ + +#ifndef TILEDARRAY_PMAP_SLABBED_PMAP_H__INCLUDED +#define TILEDARRAY_PMAP_SLABBED_PMAP_H__INCLUDED + +#include + +namespace TiledArray { +namespace detail { + +/// Replicates a base process map over a leading "slab" dimension + +/// Consider a sequence of indices \f$ \{ o | o \in [0, N_{\rm slab} S) \} \f$ +/// organized into \f$ N_{\rm slab} \f$ contiguous slabs of \f$ S \f$ indices +/// each. SlabbedPmap maps index \f$ o \f$ to the process that a base map +/// (defined over a single slab, i.e. over \f$ [0, S) \f$) assigns to +/// \f$ o \% S \f$ -- i.e. the owner of an index is independent of its slab. +/// +/// This is the distribution of a *batched* contraction (see +/// TiledArray::detail::Summa): the operands and result carry the fused +/// (Hadamard/batch) modes as their leading dimensions, every slab is +/// distributed identically over the same 2-d process grid, and the slab +/// index never participates in inter-process communication patterns. +class SlabbedPmap : public Pmap { + protected: + // Import Pmap protected variables + using Pmap::local_size_; ///< The number of local tiles + using Pmap::procs_; ///< The number of processes + using Pmap::rank_; ///< The rank of this process + using Pmap::size_; ///< The number of tiles mapped among all processes + + private: + const std::shared_ptr base_; ///< The per-slab base map + const size_type slab_size_; ///< The number of indices per slab + const size_type nslabs_; ///< The number of slabs + + public: + typedef Pmap::size_type size_type; ///< Size type + + /// Construct a slab-replicated process map + + /// \param world The world where the tiles will be mapped + /// \param base The base process map, defined over one slab + /// \param nslabs The number of slabs + SlabbedPmap(World& world, std::shared_ptr base, + const size_type nslabs) + : Pmap(world, base->size() * nslabs), + base_(std::move(base)), + slab_size_(base_->size()), + nslabs_(nslabs) { + TA_ASSERT(base_); + TA_ASSERT(nslabs_ > 0ul); + this->local_size_ = base_->local_size() * nslabs_; + } + + virtual ~SlabbedPmap() {} + + /// \return the per-slab base map + const std::shared_ptr& base() const { return base_; } + /// \return the number of indices per slab + size_type slab_size() const { return slab_size_; } + /// \return the number of slabs + size_type nslabs() const { return nslabs_; } + + /// Maps \c tile to the process that owns it + + /// \param tile The tile to be queried + /// \return Process that logically owns \c tile + virtual size_type owner(const size_type tile) const { + TA_ASSERT(tile < size_); + return base_->owner(tile % slab_size_); + } + + /// Check that the tile is owned by this process + + /// \param tile The tile to be checked + /// \return \c true if \c tile is owned by this process, otherwise \c false + virtual bool is_local(const size_type tile) const { + return base_->is_local(tile % slab_size_); + } + + virtual bool known_local_size() const { return base_->known_local_size(); } + + virtual const_iterator begin() const { + return Iterator(*this, 0ul, size_, 0ul, /*checking=*/true); + } + virtual const_iterator end() const { + return Iterator(*this, 0ul, size_, size_, /*checking=*/true); + } + +}; // class SlabbedPmap + +} // namespace detail +} // namespace TiledArray + +#endif // TILEDARRAY_PMAP_SLABBED_PMAP_H__INCLUDED diff --git a/src/TiledArray/tile_op/batched_contract_reduce.h b/src/TiledArray/tile_op/batched_contract_reduce.h new file mode 100644 index 0000000000..51579bfc27 --- /dev/null +++ b/src/TiledArray/tile_op/batched_contract_reduce.h @@ -0,0 +1,203 @@ +/* + * This file is a part of TiledArray. + * Copyright (C) 2026 Virginia Tech + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + * tile_op/batched_contract_reduce.h + * Jun 11, 2026 + * + */ + +#ifndef TILEDARRAY_TILE_OP_BATCHED_CONTRACT_REDUCE_H__INCLUDED +#define TILEDARRAY_TILE_OP_BATCHED_CONTRACT_REDUCE_H__INCLUDED + +#include + +namespace TiledArray { +namespace detail { + +/// Batched (fused-index) tile contract/reduce operation + +/// Adapts a ContractReduce op built for the *folded* (fused-mode-free) tile +/// shapes to tiles that carry \c nfused leading fused (Hadamard/batch) modes +/// on both arguments and the result, i.e. evaluates +/// \f$ C_{h,e_A,e_B} \mathrel{+}= \sum_c A_{h,e_A,c} B_{h,c,e_B} \f$ +/// per tile. The fused modes are folded into the tile batch dimension by a +/// zero-copy reshape (they are leading, so the fold preserves layout) and the +/// per-batch GEMM loop of \c Tensor::gemm does the rest; the accumulated +/// result tile is allocated with its full (h, e_A, e_B) range up front, so no +/// unfold is ever needed. +/// +/// Models the ReducePairTask op concept (seed / combine / pair-accumulate / +/// finalize). The result permutation must be handled outside this op (the +/// wrapped op must be perm-free). +/// +/// \tparam Op the folded-shape contract/reduce op (ContractReduce +/// instantiated with the fused-mode-free ranks) +template +class BatchedContractReduce { + public: + typedef BatchedContractReduce BatchedContractReduce_; + typedef Op op_type; + typedef typename Op::result_type result_type; + typedef typename Op::first_argument_type first_argument_type; + typedef typename Op::second_argument_type second_argument_type; + + static_assert( + !TiledArray::detail::is_tensor_of_tensor_v, + "BatchedContractReduce does not (yet) support tensor-of-tensor tiles"); + + private: + op_type op_; ///< The folded-shape contract/reduce op + unsigned int nfused_ = 0; ///< The number of leading fused modes + + /// \return the range spanned by modes [nfused_, rank) of \p r, rebased to + /// zero lobounds (the folded view is a GEMM scratch view; only extents + /// matter) + template + Range_ fold_range(const Range_& r) const { + const auto* extent = r.extent_data(); + container::svector extents(extent + nfused_, + extent + r.rank()); + return Range_(extents); + } + + /// \return the number of fused elements (product of the extents of the + /// leading \c nfused_ modes) of \p r + template + std::size_t fused_volume(const Range_& r) const { + const auto* extent = r.extent_data(); + std::size_t n = 1ul; + for (unsigned int d = 0u; d < nfused_; ++d) n *= extent[d]; + return n; + } + + public: + BatchedContractReduce() = default; + BatchedContractReduce(const BatchedContractReduce_&) = default; + BatchedContractReduce(BatchedContractReduce_&&) = default; + ~BatchedContractReduce() = default; + BatchedContractReduce_& operator=(const BatchedContractReduce_&) = default; + BatchedContractReduce_& operator=(BatchedContractReduce_&&) = default; + + /// \param op the folded-shape contract/reduce op; must be perm-free (the + /// full-range result of this op cannot host the folded-rank result + /// permutation) + /// \param nfused the number of leading fused modes carried by both + /// arguments and the result + BatchedContractReduce(const op_type& op, const unsigned int nfused) + : op_(op), nfused_(nfused) { + TA_ASSERT(!op_.perm()); + } + + /// \return the wrapped folded-shape op + const op_type& op() const { return op_; } + /// \return the number of leading fused modes + unsigned int nfused() const { return nfused_; } + /// \return the GEMM helper of the wrapped (folded-shape) op + const TiledArray::math::GemmHelper& gemm_helper() const { + return op_.gemm_helper(); + } + + /// Create a new, empty result object + result_type operator()() const { return result_type(); } + + /// Post processing step (no result permutation supported) + result_type operator()(const result_type& temp) const { + using TiledArray::empty; + TA_ASSERT(!empty(temp)); + return temp; + } + + /// Reduce two result objects (both carry the full fused range) + void operator()(result_type& result, const result_type& arg) const { + op_(result, arg); + } + + /// Contract a pair of fused-mode-carrying tiles and add to the target tile + void operator()(result_type& result, const first_argument_type& left, + const second_argument_type& right) const { + // the fold relies on TA::Tensor's zero-copy reshape + batched GEMM + constexpr bool supported_tiles = + TiledArray::detail::is_ta_tensor_v && + TiledArray::detail::is_ta_tensor_v< + std::remove_cv_t>> && + TiledArray::detail::is_ta_tensor_v< + std::remove_cv_t>>; + if constexpr (!supported_tiles) { + TA_EXCEPTION( + "BatchedContractReduce supports only TiledArray::Tensor tiles"); + } else { + contract_pair(result, left, right); + } + } + + private: + /// The TA::Tensor implementation of the pair contraction + void contract_pair(result_type& result, const first_argument_type& left, + const second_argument_type& right) const { + using TiledArray::empty; + if (empty(left) || empty(right)) return; + + const auto& gh = op_.gemm_helper(); + const unsigned int nc = gh.num_contract_ranks(); + const unsigned int neA = gh.left_rank() - nc; + const unsigned int neB = gh.right_rank() - nc; + + // both args must carry the fused modes as their leading modes, with + // equal extents + TA_ASSERT(left.range().rank() == nfused_ + neA + nc); + TA_ASSERT(right.range().rank() == nfused_ + nc + neB); + TA_ASSERT(left.nbatch() == 1ul); + TA_ASSERT(right.nbatch() == 1ul); + const std::size_t batch = fused_volume(left.range()); + TA_ASSERT(batch == fused_volume(right.range())); + + // allocate the result with its full (h, e_A, e_B) range: fused + left + // external bounds from the left tile, right external bounds from the + // right tile. The data layout of the full-range result coincides with + // the folded (range = (e_A, e_B), nbatch = batch) layout because the + // fused modes lead. + if (empty(result)) { + using index1_type = typename result_type::range_type::index1_type; + container::svector lobounds, upbounds; + lobounds.reserve(nfused_ + neA + neB); + upbounds.reserve(nfused_ + neA + neB); + for (unsigned int d = 0u; d < nfused_ + neA; ++d) { + lobounds.push_back(left.range().lobound_data()[d]); + upbounds.push_back(left.range().upbound_data()[d]); + } + for (unsigned int d = nfused_ + nc; d < nfused_ + nc + neB; ++d) { + lobounds.push_back(right.range().lobound_data()[d]); + upbounds.push_back(right.range().upbound_data()[d]); + } + result = result_type(typename result_type::range_type(lobounds, upbounds), + typename result_type::value_type(0)); + } + + // folded, zero-copy views; the result view shares the result's buffer, + // so the batched GEMM accumulates in place + auto left_folded = left.reshape(fold_range(left.range()), batch); + auto right_folded = right.reshape(fold_range(right.range()), batch); + auto result_folded = result.reshape(fold_range(result.range()), batch); + op_(result_folded, left_folded, right_folded); + } + +}; // class BatchedContractReduce + +} // namespace detail +} // namespace TiledArray + +#endif // TILEDARRAY_TILE_OP_BATCHED_CONTRACT_REDUCE_H__INCLUDED diff --git a/tests/general_product.cpp b/tests/general_product.cpp index 5d927843cd..8d6dcaa8fa 100644 --- a/tests/general_product.cpp +++ b/tests/general_product.cpp @@ -1,6 +1,7 @@ /// Unit tests for general-product (fused + contracted + free indices) /// classification and layout in the expression layer. +#include "TiledArray/einsum/tiledarray.h" #include "TiledArray/expressions/permopt.h" #include "TiledArray/expressions/product.h" @@ -127,22 +128,90 @@ BOOST_AUTO_TEST_CASE(optimizer_rejects_implicit_reduction) { TiledArray::Exception); } -BOOST_AUTO_TEST_CASE(expression_general_product_gated) { - // the expression layer now *classifies* general products correctly and - // reports that evaluation is not yet implemented (instead of misrouting - // to the contraction or Hadamard machinery) +namespace { + +/// makes a dense array over \p tr filled with an index-dependent pattern +TA::TArrayD make_patterned_array(TA::World& world, const TA::TiledRange& tr, + const double seed) { + TA::TArrayD result(world, tr); + for (auto it = result.begin(); it != result.end(); ++it) { + auto tile = + TA::TArrayD::value_type(result.trange().make_tile_range(it.index())); + for (auto&& ix : tile.range()) { + double v = seed; + double scale = 1.0; + for (auto x : ix) { + v += scale * static_cast(x + 1); + scale *= 0.1; + } + tile[ix] = v; + } + *it = tile; + } + return result; +} + +/// \return the Frobenius norm of `lhs - rhs` +double diff_norm(TA::TArrayD& lhs, TA::TArrayD& rhs, const std::string& annot) { + TA::TArrayD diff; + diff(annot) = lhs(annot) - rhs(annot); + return diff(annot).norm().get(); +} + +} // namespace + +BOOST_AUTO_TEST_CASE(expression_general_product_dense) { + // dense general products evaluate via the batched Summa; differential-test + // against the einsum free function (the established implementation) auto& world = TA::get_default_world(); - TA::TiledRange tr{{0, 2, 4}, {0, 2, 4}, {0, 2, 4}}; - TA::TArrayD a(world, tr); - TA::TArrayD b(world, tr); - a.fill(1.0); - b.fill(1.0); + + // C("b,i,k") = A("b,i,j") * B("b,j,k"), uneven multi-tile dimensions + TA::TiledRange tr_a{{0, 2, 5}, {0, 3, 4}, {0, 2, 6, 7}}; // b, i, j + TA::TiledRange tr_b{{0, 2, 5}, {0, 2, 6, 7}, {0, 4, 5}}; // b, j, k + auto a = make_patterned_array(world, tr_a, 1.0); + auto b = make_patterned_array(world, tr_b, 2.0); + TA::TArrayD c; - BOOST_CHECK_THROW(c("b,i,k") = a("b,i,j") * b("b,j,k"), - TiledArray::Exception); + BOOST_REQUIRE_NO_THROW(c("b,i,k") = a("b,i,j") * b("b,j,k")); + auto c_ref = TA::einsum(a("b,i,j"), b("b,j,k"), "b,i,k"); + BOOST_CHECK_SMALL(diff_norm(c, c_ref, "b,i,k"), 1e-10); + // pure contraction and pure Hadamard still work - BOOST_CHECK_NO_THROW(c("i,k") = a("b,i,j") * b("b,j,k")); - BOOST_CHECK_NO_THROW(c("b,i,j") = a("b,i,j") * b("b,i,j")); + TA::TArrayD d; + BOOST_CHECK_NO_THROW(d("i,k") = a("b,i,j") * b("b,j,k")); + TA::TArrayD e; + BOOST_CHECK_NO_THROW(e("b,i,j") = a("b,i,j") * a("b,i,j")); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_dense_permuted_args) { + // non-canonical argument layouts: the engine permutes the args into the + // canonical (h, e_A, c) / (h, c, e_B) layouts + auto& world = TA::get_default_world(); + + TA::TiledRange tr_a{{0, 3, 4}, {0, 2, 6, 7}, {0, 2, 5}}; // i, j, b + TA::TiledRange tr_b{{0, 4, 5}, {0, 2, 6, 7}, {0, 2, 5}}; // k, j, b + auto a = make_patterned_array(world, tr_a, 1.0); + auto b = make_patterned_array(world, tr_b, 2.0); + + TA::TArrayD c; + BOOST_REQUIRE_NO_THROW(c("b,i,k") = a("i,j,b") * b("k,j,b")); + auto c_ref = TA::einsum(a("i,j,b"), b("k,j,b"), "b,i,k"); + BOOST_CHECK_SMALL(diff_norm(c, c_ref, "b,i,k"), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_dense_batched_outer) { + // batched outer product: fused + free, no contracted indices + auto& world = TA::get_default_world(); + + TA::TiledRange tr_a{{0, 2, 5}, {0, 3, 4}}; // b, i + TA::TiledRange tr_b{{0, 2, 5}, {0, 4, 5}}; // b, k + auto a = make_patterned_array(world, tr_a, 1.0); + auto b = make_patterned_array(world, tr_b, 2.0); + + TA::TArrayD c; + BOOST_REQUIRE_NO_THROW(c("b,i,k") = a("b,i") * b("b,k")); + auto c_ref = TA::einsum(a("b,i"), b("b,k"), "b,i,k"); + BOOST_CHECK_SMALL(diff_norm(c, c_ref, "b,i,k"), 1e-10); } BOOST_AUTO_TEST_CASE(expression_general_product_inner_node_gated) { @@ -167,4 +236,45 @@ BOOST_AUTO_TEST_CASE(expression_general_product_inner_node_gated) { TiledArray::Exception); } +BOOST_AUTO_TEST_CASE(expression_general_product_thc_intermediates) { + // the supported way to evaluate the THC factorization today: materialize + // each general product as the root of its own assignment (fused indices + // leading in the result annotation), differential-tested against einsum + auto& world = TA::get_default_world(); + TA::TiledRange tr_x{{0, 2, 4}, {0, 3, 5}}; // orbital x auxiliary + TA::TiledRange tr_z{{0, 3, 5}, {0, 3, 5}}; // auxiliary x auxiliary + auto x = make_patterned_array(world, tr_x, 1.0); + auto z = make_patterned_array(world, tr_z, 2.0); + + TA::TArrayD i1, i2, i3, g; + BOOST_REQUIRE_NO_THROW(i1("r1,p,q") = x("p,r1") * x("q,r1")); // general + BOOST_REQUIRE_NO_THROW(i2("p,q,r2") = i1("r1,p,q") * z("r1,r2")); + BOOST_REQUIRE_NO_THROW(i3("r2,p,q,r") = i2("p,q,r2") * x("r,r2")); // general + BOOST_REQUIRE_NO_THROW(g("p,q,r,s") = i3("r2,p,q,r") * x("s,r2")); + + // oracle: the same chain with the general products evaluated by einsum + auto i1_ref = TA::einsum(x("p,r1"), x("q,r1"), "r1,p,q"); + TA::TArrayD i2_ref; + i2_ref("p,q,r2") = i1_ref("r1,p,q") * z("r1,r2"); + auto i3_ref = TA::einsum(i2_ref("p,q,r2"), x("r,r2"), "r2,p,q,r"); + TA::TArrayD g_ref; + g_ref("p,q,r,s") = i3_ref("r2,p,q,r") * x("s,r2"); + + BOOST_CHECK_SMALL(diff_norm(g, g_ref, "p,q,r,s"), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_sparse_gated) { + // block-sparse general products are not implemented yet: must report + // clearly rather than compute garbage + auto& world = TA::get_default_world(); + TA::TiledRange tr{{0, 2, 4}, {0, 2, 4}, {0, 2, 4}}; + TA::TSpArrayD a(world, tr); + TA::TSpArrayD b(world, tr); + a.fill(1.0); + b.fill(1.0); + TA::TSpArrayD c; + BOOST_CHECK_THROW(c("b,i,k") = a("b,i,j") * b("b,j,k"), + TiledArray::Exception); +} + BOOST_AUTO_TEST_SUITE_END() From 8eb744c808d9495844aedc6d998699ce13af7210 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Thu, 11 Jun 2026 08:54:45 -0400 Subject: [PATCH 05/25] expressions: block-sparse general products Phase B3: SparsePolicy general products (fused + contracted + free indices) now evaluate natively in the expression layer. - SparseShape::gemm_batched(other, factor, gemm_helper, nfused): the batched analogue of gemm. The leading nfused modes of both shapes and the result are fused; each fused-index slab is contracted exactly as in gemm with the *folded* (fused-mode-free) GEMM helper. The contracted-mode size vector is slab-invariant (contracted modes follow the fused modes), the norm scaling loops extend over slabs naturally, and the per-slab norm GEMMs run as one batched Tensor::gemm via the zero-copy fused-modes-into-nbatch reshape; same hard-zero threshold pass and outer-product (k_rank == 0) handling as gemm. - ContEngine::make_shape_general routes SparseShape to it (the dense branch is unchanged); this removes the last block-sparse gate, so the batched Summa's sparse path ((h,k)-keyed masks, groups, and step iteration, landed earlier) is now reachable. - tests: block-sparse differential tests vs einsum (batched contraction and batched outer product, deterministic block-sparsity patterns); pass at np = 1, 2, 3, 4. --- src/TiledArray/expressions/cont_engine.h | 11 +- src/TiledArray/sparse_shape.h | 158 +++++++++++++++++++++++ tests/general_product.cpp | 81 ++++++++++-- 3 files changed, 233 insertions(+), 17 deletions(-) diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index fe4fd937bb..5d3c43a991 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -705,17 +705,14 @@ class ContEngine : public BinaryEngine { /// Shape factory function for a general product - /// \return The result shape + /// \return The result shape: the fused modes lead; each fused-index slab + /// is the shape-level contraction of the corresponding argument slabs shape_type make_shape_general() const { if constexpr (std::is_same_v) return shape_type(); else - // TODO support block-sparse general products: evaluate the shape - // slab-by-slab (per-slab SparseShape::gemm over the folded modes) - TA_EXCEPTION( - "block-sparse general products (fused + contracted + free indices) " - "via the expression layer are not yet implemented; use " - "TiledArray::einsum() instead"); + return left_.shape().gemm_batched(right_.shape(), factor_, + op_.gemm_helper(), n_fused_modes_); } /// Initialize the result distribution of a general product diff --git a/src/TiledArray/sparse_shape.h b/src/TiledArray/sparse_shape.h index ffeacebd5c..9bcd381462 100644 --- a/src/TiledArray/sparse_shape.h +++ b/src/TiledArray/sparse_shape.h @@ -1690,6 +1690,164 @@ class SparseShape { return gemm(other, factor, gemm_helper).perm(perm); } + /// Batched (fused-index) analogue of gemm + + /// Computes the shape of a *general* product (fused + contracted + free + /// indices): the leading \p nfused modes of both shapes and of the result + /// are fused (Hadamard/batch) modes, and each fused-index slab is + /// contracted exactly as in gemm using the *folded* (fused-mode-free) + /// \p gemm_helper. + /// \tparam Factor The scaling factor type + /// \param other The right-hand shape; its leading \p nfused modes must be + /// congruent with this shape's + /// \param factor The scaling factor + /// \param gemm_helper The folded (fused-mode-free) GEMM helper + /// \param nfused The number of leading fused modes + template + SparseShape_ gemm_batched(const SparseShape_& other, const Factor factor, + const math::GemmHelper& gemm_helper, + const unsigned int nfused) const { + TA_ASSERT(!tile_norms_.empty()); + TA_ASSERT(!other.tile_norms_.empty()); + + const value_type abs_factor = to_abs_factor(factor); + const value_type threshold = threshold_; + madness::AtomicInt zero_tile_count; + zero_tile_count = 0; + + // the number of fused-index slabs and the folded matrix sizes + using integer = TiledArray::math::blas::integer; + const auto* left_extent = tile_norms_.range().extent_data(); + const auto* right_extent = other.tile_norms_.range().extent_data(); + integer H = 1, M = 1, N = 1, K = 1; + for (unsigned int d = 0u; d < nfused; ++d) H *= left_extent[d]; + for (unsigned int i = gemm_helper.left_outer_begin(); + i < gemm_helper.left_outer_end(); ++i) + M *= left_extent[nfused + i]; + for (unsigned int i = gemm_helper.left_inner_begin(); + i < gemm_helper.left_inner_end(); ++i) + K *= left_extent[nfused + i]; + for (unsigned int i = gemm_helper.right_outer_begin(); + i < gemm_helper.right_outer_end(); ++i) + N *= right_extent[nfused + i]; + + // result size vectors: fused modes (from this), then the left and right + // outer modes + const unsigned int result_rank = nfused + gemm_helper.result_rank(); + std::shared_ptr result_size_vectors( + new vector_type[result_rank], std::default_delete()); + unsigned int x = 0ul; + for (unsigned int i = 0u; i < nfused; ++i, ++x) + result_size_vectors.get()[x] = size_vectors_.get()[i]; + for (unsigned int i = gemm_helper.left_outer_begin(); + i < gemm_helper.left_outer_end(); ++i, ++x) + result_size_vectors.get()[x] = size_vectors_.get()[nfused + i]; + for (unsigned int i = gemm_helper.right_outer_begin(); + i < gemm_helper.right_outer_end(); ++i, ++x) + result_size_vectors.get()[x] = other.size_vectors_.get()[nfused + i]; + + // the result norm tensor over (fused..., left outer..., right outer...) + using range_type = typename Tensor::range_type; + using index1_type = typename range_type::index1_type; + container::svector lobounds, upbounds; + lobounds.reserve(result_rank); + upbounds.reserve(result_rank); + for (unsigned int d = 0u; d < nfused; ++d) { + lobounds.push_back(tile_norms_.range().lobound_data()[d]); + upbounds.push_back(tile_norms_.range().upbound_data()[d]); + } + for (unsigned int i = gemm_helper.left_outer_begin(); + i < gemm_helper.left_outer_end(); ++i) { + lobounds.push_back(tile_norms_.range().lobound_data()[nfused + i]); + upbounds.push_back(tile_norms_.range().upbound_data()[nfused + i]); + } + for (unsigned int i = gemm_helper.right_outer_begin(); + i < gemm_helper.right_outer_end(); ++i) { + lobounds.push_back(other.tile_norms_.range().lobound_data()[nfused + i]); + upbounds.push_back(other.tile_norms_.range().upbound_data()[nfused + i]); + } + Tensor result_norms(range_type(lobounds, upbounds), 0); + + // the range spanned by modes [nfused, rank) of \p r, rebased to zero + // lobounds (scratch view for the slab-batched norm GEMM) + auto fold_range = [nfused](const range_type& r) { + const auto* extent = r.extent_data(); + container::svector extents(extent + nfused, + extent + r.rank()); + return range_type(extents); + }; + + // the number of inner ranks + const unsigned int k_rank = + gemm_helper.left_inner_end() - gemm_helper.left_inner_begin(); + + if (k_rank > 0u) { + // the contracted-mode size vector; identical for every slab since the + // contracted modes follow the fused modes + const vector_type k_sizes = recursive_outer_product( + size_vectors_.get() + nfused + gemm_helper.left_inner_begin(), k_rank, + [](const vector_type& size_vector) -> const vector_type& { + return size_vector; + }); + + // scale the rows of the left norms (k_sizes pattern repeats per slab) + Tensor left(tile_norms_.range()); + const size_type hmk = H * M * K; + auto left_op = [](const value_type left, const value_type right) { + return left * right; + }; + for (size_type i = 0ul; i < hmk; i += K) + math::vector_op(left_op, K, left.data() + i, tile_norms_.data() + i, + k_sizes.data()); + + // scale the rows of the right norms (within-slab row r has factor + // k_sizes[r % K]) + Tensor right(other.tile_norms_.range()); + const size_type hk = H * K; + for (size_type i = 0ul, r = 0ul; r < hk; i += N, ++r) { + const value_type factor = k_sizes[r % K]; + auto right_op = [=](const value_type arg) { return arg * factor; }; + math::vector_op(right_op, N, right.data() + i, + other.tile_norms_.data() + i); + } + + // slab-batched norm GEMM: fold the fused modes into the tensor batch + // dimension by zero-copy reshape; result_folded shares result_norms' + // buffer, so the accumulation lands in place + auto left_folded = left.reshape(fold_range(left.range()), H); + auto right_folded = right.reshape(fold_range(right.range()), H); + auto result_folded = + result_norms.reshape(fold_range(result_norms.range()), H); + result_folded.gemm(left_folded, right_folded, abs_factor, gemm_helper); + + // Hard zero tiles that are below the zero threshold. + result_norms.inplace_unary( + [threshold, &zero_tile_count](value_type& value) { + if (value < threshold) { + value = value_type(0); + ++zero_tile_count; + } + }); + } else { + // batched outer product: per-slab outer products of the norms + for (integer h = 0; h < H; ++h) + math::outer_fill(M, N, tile_norms_.data() + h * M, + other.tile_norms_.data() + h * N, + result_norms.data() + h * M * N, + [threshold, &zero_tile_count, abs_factor]( + const value_type left, const value_type right) { + value_type norm = left * right * abs_factor; + if (norm < threshold) { + norm = value_type(0); + ++zero_tile_count; + } + return norm; + }); + } + + return SparseShape_(result_norms, result_size_vectors, zero_tile_count); + } + template >>::type* = nullptr> diff --git a/tests/general_product.cpp b/tests/general_product.cpp index 8d6dcaa8fa..f284ac6cde 100644 --- a/tests/general_product.cpp +++ b/tests/general_product.cpp @@ -263,18 +263,79 @@ BOOST_AUTO_TEST_CASE(expression_general_product_thc_intermediates) { BOOST_CHECK_SMALL(diff_norm(g, g_ref, "p,q,r,s"), 1e-10); } -BOOST_AUTO_TEST_CASE(expression_general_product_sparse_gated) { - // block-sparse general products are not implemented yet: must report - // clearly rather than compute garbage +namespace { + +/// makes a block-sparse array over \p tr with an index-dependent fill and a +/// deterministic block-sparsity pattern (every \p zero_stride -th tile zero) +TA::TSpArrayD make_patterned_sparse_array(TA::World& world, + const TA::TiledRange& tr, + const double seed, + const std::size_t zero_stride) { + // shape: unit norm except every zero_stride-th tile + TA::Tensor norms(tr.tiles_range(), 1.0f); + for (std::size_t ord = 0; ord < norms.size(); ord += zero_stride) + norms.data()[ord] = 0.0f; + TA::SparseShape shape(norms, tr); + + TA::TSpArrayD result(world, tr, shape); + // iteration visits only local non-zero tiles + for (auto it = result.begin(); it != result.end(); ++it) { + auto tile = + TA::TSpArrayD::value_type(result.trange().make_tile_range(it.index())); + for (auto&& ix : tile.range()) { + double v = seed; + double scale = 1.0; + for (auto x : ix) { + v += scale * static_cast(x + 1); + scale *= 0.1; + } + tile[ix] = v; + } + *it = tile; + } + return result; +} + +/// \return the Frobenius norm of `lhs - rhs` (block-sparse) +double diff_norm_sp(TA::TSpArrayD& lhs, TA::TSpArrayD& rhs, + const std::string& annot) { + TA::TSpArrayD diff; + diff(annot) = lhs(annot) - rhs(annot); + return diff(annot).norm().get(); +} + +} // namespace + +BOOST_AUTO_TEST_CASE(expression_general_product_sparse) { + // block-sparse general products: the result shape is computed slab-by-slab + // (SparseShape::gemm_batched) and the batched Summa runs its sparse path + // ((h,k)-keyed masks/groups); differential-test against einsum auto& world = TA::get_default_world(); - TA::TiledRange tr{{0, 2, 4}, {0, 2, 4}, {0, 2, 4}}; - TA::TSpArrayD a(world, tr); - TA::TSpArrayD b(world, tr); - a.fill(1.0); - b.fill(1.0); + + TA::TiledRange tr_a{{0, 2, 5}, {0, 3, 4}, {0, 2, 6, 7}}; // b, i, j + TA::TiledRange tr_b{{0, 2, 5}, {0, 2, 6, 7}, {0, 4, 5}}; // b, j, k + auto a = make_patterned_sparse_array(world, tr_a, 1.0, 3); + auto b = make_patterned_sparse_array(world, tr_b, 2.0, 4); + TA::TSpArrayD c; - BOOST_CHECK_THROW(c("b,i,k") = a("b,i,j") * b("b,j,k"), - TiledArray::Exception); + BOOST_REQUIRE_NO_THROW(c("b,i,k") = a("b,i,j") * b("b,j,k")); + auto c_ref = TA::einsum(a("b,i,j"), b("b,j,k"), "b,i,k"); + BOOST_CHECK_SMALL(diff_norm_sp(c, c_ref, "b,i,k"), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_sparse_batched_outer) { + // block-sparse batched outer product (no contracted indices) + auto& world = TA::get_default_world(); + + TA::TiledRange tr_a{{0, 2, 5}, {0, 3, 4}}; // b, i + TA::TiledRange tr_b{{0, 2, 5}, {0, 4, 5}}; // b, k + auto a = make_patterned_sparse_array(world, tr_a, 1.0, 3); + auto b = make_patterned_sparse_array(world, tr_b, 2.0, 2); + + TA::TSpArrayD c; + BOOST_REQUIRE_NO_THROW(c("b,i,k") = a("b,i") * b("b,k")); + auto c_ref = TA::einsum(a("b,i"), b("b,k"), "b,i,k"); + BOOST_CHECK_SMALL(diff_norm_sp(c, c_ref, "b,i,k"), 1e-10); } BOOST_AUTO_TEST_SUITE_END() From 1b67f81c3cc6114c89853d9e8fa5c1ba92fdc8b7 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Thu, 11 Jun 2026 09:40:46 -0400 Subject: [PATCH 06/25] expressions: tensor-of-tensor general products Phase C: ToT general products (fused + contracted + free outer indices, nested inner product) now evaluate natively in the expression layer via the batched Summa. - the inner-tile-op builders (init_inner_tile_op and the owning-cell variant) now classify the outer regime via outer_product_uses_summa() (pure contraction OR general product): for both, the tile op is a ContractReduce consumed by a (batched) SUMMA, so the per-cell ops accumulate in place, no per-cell result permutation is applied, and the arena plans / strided-DGEMM ops install as for a pure contraction - the strided-DGEMM install gates derive the outer-contracted rank from the fused-mode-free outer sizes (n_fused_outer_modes() helper) - init_struct_general gains the ToT arm, mirroring init_struct: builds the folded-rank ContractReduce with the inner element op and the arena plan, installs the strided ce+e / ce+ce ops; a non-identity inner result permutation is gated (the batched op must be perm-free) - BatchedContractReduce now admits ToT tiles: the folded result is allocated by the wrapped op itself (engaging its tile-type-specific construction, e.g. the arena reserve) and unfolded by a zero-copy reshape; this also gives plain tiles the beta=0 first-accumulation fast path - MultEngine initializes the inner tile op before init_struct_general Differential-tested against einsum (inner Hadamard and inner contraction, owning cells) at np = 1, 2, 3; all existing suites unchanged. Arena (view-cell) general products compile via the same paths; their end-to-end validation comes with the mpqc/einsum cutover. --- src/TiledArray/expressions/cont_engine.h | 251 ++++++++++-------- src/TiledArray/expressions/mult_engine.h | 6 +- .../tile_op/batched_contract_reduce.h | 39 +-- tests/general_product.cpp | 88 ++++++ 4 files changed, 259 insertions(+), 125 deletions(-) diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index 5d3c43a991..91b704108c 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -204,6 +204,30 @@ class ContEngine : public BinaryEngine { return product_type_; } + /// \return true if the outer product is evaluated by a (batched) SUMMA, + /// i.e. the tile op is a ContractReduce (a pure contraction or a general + /// product); false for the elementwise (Hadamard) binary tile op + bool outer_product_uses_summa() const { + return product_type_ == TensorProduct::Contraction || + product_type_ == TensorProduct::General; + } + + /// \return for a general product, the number of leading fused (outer) + /// modes common to the canonical left, right, and result layouts + /// (GeneralPermutationOptimizer places them first); 0 for the other + /// product types + unsigned int n_fused_outer_modes() const { + if (product_type_ != TensorProduct::General) return 0u; + auto const& l = outer(left_indices_); + auto const& r = outer(right_indices_); + auto const& res = outer(indices_); + unsigned int nh = 0u; + while (nh < res.size() && nh < l.size() && nh < r.size() && + l[nh] == res[nh] && r[nh] == res[nh]) + ++nh; + return nh; + } + /// \return the inner product type TensorProduct inner_product_type() const { TA_ASSERT(inner_product_type_ != @@ -617,53 +641,80 @@ class ContEngine : public BinaryEngine { /// shape. /// \param target_indices The target index list for the result tensor void init_struct_general(const BipartiteIndexList& target_indices) { - if constexpr (TiledArray::detail::is_tensor_of_tensor_v) + // precondition checks (mirror init_struct) + if constexpr (TiledArray::detail::is_tensor_of_tensor_v) { + TA_ASSERT(element_nonreturn_op_); + // a view inner cell (e.g. ArenaTensor) cannot host a value-returning + // inner op, so element_return_op_ is intentionally left null for it + if constexpr (!TiledArray::is_tensor_view_v) + TA_ASSERT(element_return_op_); + } + + // Initialize children + left_.init_struct(left_indices_); + right_.init_struct(right_indices_); + + // count the fused modes: the leading indices common to the canonical + // left, right, and result layouts + const unsigned int nh = n_fused_outer_modes(); + TA_ASSERT(nh > 0u); // else this is a pure contraction + n_fused_modes_ = nh; + + // initialize perm_; an interleaved target (a result permutation that + // mixes fused and free modes) is not yet supported -- the canonical + // result layout must equal the target + this->init_perm(target_indices); + if (outer(target_indices) != outer(indices_)) TA_EXCEPTION( - "general products (fused + contracted + free indices) of " - "tensors-of-tensors via the expression layer are not yet " - "implemented; use TiledArray::einsum() instead"); - else { - // Initialize children - left_.init_struct(left_indices_); - right_.init_struct(right_indices_); - - // count the fused modes: the leading indices common to the canonical - // left, right, and result layouts - const auto& left_outer = outer(left_indices_); - const auto& right_outer = outer(right_indices_); - const auto& result_outer = outer(indices_); - unsigned int nh = 0u; - while (nh < result_outer.size() && nh < left_outer.size() && - nh < right_outer.size() && left_outer[nh] == result_outer[nh] && - right_outer[nh] == result_outer[nh]) - ++nh; - TA_ASSERT(nh > 0u); // else this is a pure contraction - n_fused_modes_ = nh; - - // initialize perm_; an interleaved target (a result permutation that - // mixes fused and free modes) is not yet supported -- the canonical - // result layout must equal the target - this->init_perm(target_indices); - if (outer(target_indices) != outer(indices_)) - TA_EXCEPTION( - "general products (fused + contracted + free indices): targets " - "that interleave fused and free indices are not yet supported; " - "reorder the result annotation to (fused..., left-free..., " - "right-free...)"); - - // the tile op operates on the folded (fused-mode-free) shapes - const auto left_op = to_cblas_op(left_outer_permtype_); - const auto right_op = to_cblas_op(right_outer_permtype_); + "general products (fused + contracted + free indices): targets " + "that interleave fused and free indices are not yet supported; " + "reorder the result annotation to (fused..., left-free..., " + "right-free...)"); + + // the tile op operates on the folded (fused-mode-free) shapes + const auto left_op = to_cblas_op(left_outer_permtype_); + const auto right_op = to_cblas_op(right_outer_permtype_); + if constexpr (!TiledArray::detail::is_tensor_of_tensor_v) { op_ = op_type(left_op, right_op, factor_, outer_size(indices_) - nh, outer_size(left_indices_) - nh, outer_size(right_indices_) - nh); + } else { + // the batched tile op must be perm-free (BatchedContractReduce cannot + // host the folded-rank result permutation); the outer perm is empty by + // the interleaved-target gate above, so only an explicit inner result + // permutation can require one + if (!implicit_permute_inner_ && bool(inner(perm_))) + TA_EXCEPTION( + "general products of tensors-of-tensors: a non-identity inner " + "result permutation is not yet supported; reorder the inner " + "annotation of the result"); + + // factor_ is absorbed into element_nonreturn_op_ + op_ = op_type(left_op, right_op, scalar_type(1), + outer_size(indices_) - nh, outer_size(left_indices_) - nh, + outer_size(right_indices_) - nh, BipartitePermutation{}, + this->element_nonreturn_op_, std::move(this->arena_plan_)); + // ce+e, ce+ce_right and ce+ce_left are mutually exclusive; at most one + // is non-null and only one install fires (see init_struct) + if (this->arena_strided_dgemm_ce_e_tile_op_) + op_.set_strided_oprod_op(this->arena_strided_dgemm_ce_e_tile_op_); + if (this->arena_strided_dgemm_ce_ce_right_tile_op_) + op_.set_strided_oprod_op( + this->arena_strided_dgemm_ce_ce_right_tile_op_); + if (this->arena_strided_dgemm_ce_ce_left_tile_op_) + op_.set_strided_oprod_op(this->arena_strided_dgemm_ce_ce_left_tile_op_); + // Plan ownership transferred to op_; mark carrier slot empty so any + // later use of arena_plan_ reads as "no plan" rather than moved-from. + if constexpr (!std::is_same_v) { + this->arena_plan_.reset(); + } + } - trange_ = make_trange_general(); - shape_ = make_shape_general(); + trange_ = make_trange_general(); + shape_ = make_shape_general(); - if (ExprEngine_::override_ptr_ && ExprEngine_::override_ptr_->shape) { - shape_ = shape_.mask(*ExprEngine_::override_ptr_->shape); - } + if (ExprEngine_::override_ptr_ && ExprEngine_::override_ptr_->shape) { + shape_ = shape_.mask(*ExprEngine_::override_ptr_->shape); } } @@ -786,29 +837,20 @@ class ContEngine : public BinaryEngine { /// \return The batched-Summa distributed evaluator for this expression dist_eval_type make_dist_eval_general() const { - if constexpr (TiledArray::detail::is_tensor_of_tensor_v) { - // unreachable: init_struct_general throws for tensors-of-tensors - TA_EXCEPTION( - "general products of tensors-of-tensors are not yet implemented"); - abort(); // unreachable - } else { - typedef TiledArray::detail::BatchedContractReduce - batched_op_type; - typedef TiledArray::detail::Summa - impl_type; - - typename left_type::dist_eval_type left = left_.make_dist_eval(); - typename right_type::dist_eval_type right = right_.make_dist_eval(); - - std::shared_ptr pimpl = std::make_shared( - left, right, *world_, trange_, shape_, pmap_, perm_, - batched_op_type(op_, n_fused_modes_), K_, proc_grid_, n_slabs_); - - return dist_eval_type(pimpl); - } + typedef TiledArray::detail::BatchedContractReduce batched_op_type; + typedef TiledArray::detail::Summa + impl_type; + + typename left_type::dist_eval_type left = left_.make_dist_eval(); + typename right_type::dist_eval_type right = right_.make_dist_eval(); + + std::shared_ptr pimpl = std::make_shared( + left, right, *world_, trange_, shape_, pmap_, perm_, + batched_op_type(op_, n_fused_modes_), K_, proc_grid_, n_slabs_); + + return dist_eval_type(pimpl); } /// Expression identification tag @@ -863,7 +905,7 @@ class ContEngine : public BinaryEngine { this->product_type() == TensorProduct::Hadamard) { // pure Hadamard: element_*_op_ left null } else if (inner_prod == TensorProduct::Hadamard && - this->product_type() == TensorProduct::Contraction) { + this->outer_product_uses_summa()) { // outer Contraction + inner Hadamard on view inner tiles. // Mirror the owning-tile path (init_inner_tile_op_owning_): the // SUMMA shapes each result cell from a non-empty left inner cell @@ -945,7 +987,7 @@ class ContEngine : public BinaryEngine { // result cell is pre-shaped [1] by the unit_range plan. result.data()[0] += static_cast(factor) * acc; }; - if (this->product_type() == TensorProduct::Contraction) { + if (this->outer_product_uses_summa()) { this->arena_plan_ = TiledArray::detail::make_contraction_arena_plan< result_tile_type, left_tile_type, right_tile_type>( @@ -992,7 +1034,7 @@ class ContEngine : public BinaryEngine { TiledArray::detail::make_fused_contraction_lambda< result_tile_element_type, left_tile_element_type, right_tile_element_type>(contrreduce_op); - if (this->product_type() == TensorProduct::Contraction) { + if (this->outer_product_uses_summa()) { // outer contraction: the SUMMA result is shaped from operand // inner cells by arena_plan_; op_'s post-processing permute // applies the (outer + inner) result permutation. @@ -1093,15 +1135,20 @@ class ContEngine : public BinaryEngine { inner_gh.left_rank() == inner_gh.num_contract_ranks(); // Derive the outer-contracted rank `oc` from the outer index // sizes (same helper used by the outer op when building op_). - const auto oc = (outer_size(this->left_indices_) + - outer_size(this->right_indices_) - - outer_size(this->indices_)) / + // for a general product the leading fused modes do not + // participate in the outer GEMM (they are folded into the + // tile batch dimension), so exclude them from the rank + // accounting + const auto nh = this->n_fused_outer_modes(); + const auto oc = (outer_size(this->left_indices_) - nh + + outer_size(this->right_indices_) - nh - + (outer_size(this->indices_) - nh)) / 2; // the ridden operand must carry an outer external to ride. const bool right_has_ext = - outer_size(this->right_indices_) > oc; + outer_size(this->right_indices_) - nh > oc; const bool left_has_ext = - outer_size(this->left_indices_) > oc; + outer_size(this->left_indices_) - nh > oc; // canonical inner orientation: identity == "no inner // transpose". right core assumes L=(a1,a4), R=(a4); left core // assumes L=(a4), R=(a4,b1). Either way BOTH inner permtypes @@ -1369,7 +1416,7 @@ class ContEngine : public BinaryEngine { TiledArray::detail::is_contraction_arena_tot_v< result_tile_type, left_tile_type, right_tile_type>; if constexpr (arena_eligible) { - if (this->product_type() == TensorProduct::Contraction) { + if (this->outer_product_uses_summa()) { this->arena_plan_ = TiledArray::detail::make_contraction_arena_plan< result_tile_type, left_tile_type, right_tile_type>( @@ -1388,8 +1435,7 @@ class ContEngine : public BinaryEngine { } else { this->element_nonreturn_op_ = [contrreduce_op, - permute_inner = - this->product_type() != TensorProduct::Contraction]( + permute_inner = !this->outer_product_uses_summa()]( result_tile_element_type& result, const left_tile_element_type& left, const right_tile_element_type& right) { @@ -1401,8 +1447,8 @@ class ContEngine : public BinaryEngine { } } else { this->element_nonreturn_op_ = - [contrreduce_op, permute_inner = this->product_type() != - TensorProduct::Contraction]( + [contrreduce_op, + permute_inner = !this->outer_product_uses_summa()]( result_tile_element_type& result, const left_tile_element_type& left, const right_tile_element_type& right) { @@ -1420,7 +1466,7 @@ class ContEngine : public BinaryEngine { // inner tile op depends on the outer op ... e.g. if outer op // is contract then inner must implement (ternary) multiply-add; // if the outer is hadamard then the inner is binary multiply - const auto outer_prod = this->product_type(); + const bool outer_uses_summa = this->outer_product_uses_summa(); if (this->factor_ == scalar_type{1}) { using base_op_type = TiledArray::detail::Mult { TiledArray::detail::is_contraction_arena_tot_v< result_tile_type, left_tile_type, right_tile_type>; if constexpr (arena_eligible_h_unit) { - if (this->product_type() == TensorProduct::Contraction) { + if (this->outer_product_uses_summa()) { this->arena_plan_ = TiledArray::detail::make_contraction_arena_plan< result_tile_type, left_tile_type, right_tile_type>( @@ -1455,15 +1501,13 @@ class ContEngine : public BinaryEngine { right_tile_element_type>(); } else { this->element_nonreturn_op_ = - [mult_op, outer_prod]( + [mult_op, outer_uses_summa]( result_tile_element_type& result, const left_tile_element_type& left, const right_tile_element_type& right) { - TA_ASSERT(outer_prod == TensorProduct::Hadamard || - outer_prod == TensorProduct::Contraction); - if (outer_prod == TensorProduct::Hadamard) + if (!outer_uses_summa) result = mult_op(left, right); - else { // outer_prod == TensorProduct::Contraction + else { // outer product evaluated by (batched) SUMMA // there is currently no fused MultAdd ternary Op, only // Add and Mult thus implement this as 2 separate steps // TODO optimize by implementing (ternary) MultAdd @@ -1478,14 +1522,13 @@ class ContEngine : public BinaryEngine { } } else { this->element_nonreturn_op_ = - [mult_op, outer_prod](result_tile_element_type& result, - const left_tile_element_type& left, - const right_tile_element_type& right) { - TA_ASSERT(outer_prod == TensorProduct::Hadamard || - outer_prod == TensorProduct::Contraction); - if (outer_prod == TensorProduct::Hadamard) + [mult_op, outer_uses_summa]( + result_tile_element_type& result, + const left_tile_element_type& left, + const right_tile_element_type& right) { + if (!outer_uses_summa) result = mult_op(left, right); - else { // outer_prod == TensorProduct::Contraction + else { // outer product evaluated by (batched) SUMMA // there is currently no fused MultAdd ternary Op, only // Add and Mult thus implement this as 2 separate steps // TODO optimize by implementing (ternary) MultAdd @@ -1515,7 +1558,7 @@ class ContEngine : public BinaryEngine { TiledArray::detail::is_contraction_arena_tot_v< result_tile_type, left_tile_type, right_tile_type>; if constexpr (arena_eligible_h_scaled) { - if (this->product_type() == TensorProduct::Contraction) { + if (this->outer_product_uses_summa()) { this->arena_plan_ = TiledArray::detail::make_contraction_arena_plan< result_tile_type, left_tile_type, right_tile_type>( @@ -1531,13 +1574,11 @@ class ContEngine : public BinaryEngine { right_tile_element_type>(this->factor_); } else { this->element_nonreturn_op_ = - [mult_op, outer_prod]( + [mult_op, outer_uses_summa]( result_tile_element_type& result, const left_tile_element_type& left, const right_tile_element_type& right) { - TA_ASSERT(outer_prod == TensorProduct::Hadamard || - outer_prod == TensorProduct::Contraction); - if (outer_prod == TensorProduct::Hadamard) + if (!outer_uses_summa) result = mult_op(left, right); else { // there is currently no fused MultAdd ternary Op, only @@ -1554,12 +1595,11 @@ class ContEngine : public BinaryEngine { } } else { this->element_nonreturn_op_ = - [mult_op, outer_prod](result_tile_element_type& result, - const left_tile_element_type& left, - const right_tile_element_type& right) { - TA_ASSERT(outer_prod == TensorProduct::Hadamard || - outer_prod == TensorProduct::Contraction); - if (outer_prod == TensorProduct::Hadamard) + [mult_op, outer_uses_summa]( + result_tile_element_type& result, + const left_tile_element_type& left, + const right_tile_element_type& right) { + if (!outer_uses_summa) result = mult_op(left, right); else { // there is currently no fused MultAdd ternary Op, only @@ -1594,7 +1634,7 @@ class ContEngine : public BinaryEngine { TiledArray::detail::is_contraction_arena_tot_v< result_tile_type, left_tile_type, right_tile_type>; if constexpr (arena_eligible_scale) { - if (this->product_type() == TensorProduct::Contraction) { + if (this->outer_product_uses_summa()) { // The inner perm handed to the plan must match how the inner // *result* permutation is applied for this result cell type -- // and the two cell types apply it in different places: @@ -1631,11 +1671,12 @@ class ContEngine : public BinaryEngine { auto fallback_op = [perm = !this->implicit_permute_inner_ ? inner(this->perm_) : Permutation{}, - outer_prod = this->product_type()]( + outer_uses_summa = + this->outer_product_uses_summa()]( result_tile_element_type& result, const left_tile_element_type& left, const right_tile_element_type& right) { - if (outer_prod == TensorProduct::Contraction) { + if (outer_uses_summa) { using TiledArray::axpy_to; if constexpr (tot_x_t) { if (perm) diff --git a/src/TiledArray/expressions/mult_engine.h b/src/TiledArray/expressions/mult_engine.h index cddf8e721c..d5db1bf093 100644 --- a/src/TiledArray/expressions/mult_engine.h +++ b/src/TiledArray/expressions/mult_engine.h @@ -384,8 +384,10 @@ class MultEngine : public ContEngine> { /// \param target_indices The target index list for the result tensor void init_struct(const BipartiteIndexList& target_indices) { if (this->product_type() == TensorProduct::General) { - // N.B. no inner tile op: ToT general products are rejected by - // init_struct_general + // the inner tile op (for tensors-of-tensors) must be initialized + // first; init_struct_general consumes element_nonreturn_op_ and the + // arena plan it builds + this->init_inner_tile_op(inner(target_indices)); ContEngine_::init_struct_general(target_indices); return; } diff --git a/src/TiledArray/tile_op/batched_contract_reduce.h b/src/TiledArray/tile_op/batched_contract_reduce.h index 51579bfc27..27e77420da 100644 --- a/src/TiledArray/tile_op/batched_contract_reduce.h +++ b/src/TiledArray/tile_op/batched_contract_reduce.h @@ -55,10 +55,6 @@ class BatchedContractReduce { typedef typename Op::first_argument_type first_argument_type; typedef typename Op::second_argument_type second_argument_type; - static_assert( - !TiledArray::detail::is_tensor_of_tensor_v, - "BatchedContractReduce does not (yet) support tensor-of-tensor tiles"); - private: op_type op_; ///< The folded-shape contract/reduce op unsigned int nfused_ = 0; ///< The number of leading fused modes @@ -165,12 +161,22 @@ class BatchedContractReduce { const std::size_t batch = fused_volume(left.range()); TA_ASSERT(batch == fused_volume(right.range())); - // allocate the result with its full (h, e_A, e_B) range: fused + left - // external bounds from the left tile, right external bounds from the - // right tile. The data layout of the full-range result coincides with - // the folded (range = (e_A, e_B), nbatch = batch) layout because the - // fused modes lead. + // folded, zero-copy argument views + auto left_folded = left.reshape(fold_range(left.range()), batch); + auto right_folded = right.reshape(fold_range(right.range()), batch); + if (empty(result)) { + // let the wrapped op allocate (and zero- or beta-0-initialize) the + // result in *folded* form -- this also engages its tile-type-specific + // result construction (e.g. the arena reserve for tensor-of-tensor + // tiles) -- then unfold by a zero-copy reshape: the data layout of the + // folded (range = (e_A, e_B), nbatch = batch) result coincides with + // the full (h, e_A, e_B) row-major layout because the fused modes + // lead. The full bounds: fused + left-external from the left tile, + // right-external from the right tile. + result_type result_folded; + op_(result_folded, left_folded, right_folded); + using index1_type = typename result_type::range_type::index1_type; container::svector lobounds, upbounds; lobounds.reserve(nfused_ + neA + neB); @@ -183,16 +189,13 @@ class BatchedContractReduce { lobounds.push_back(right.range().lobound_data()[d]); upbounds.push_back(right.range().upbound_data()[d]); } - result = result_type(typename result_type::range_type(lobounds, upbounds), - typename result_type::value_type(0)); + result = result_folded.reshape( + typename result_type::range_type(lobounds, upbounds)); + } else { + // accumulate through a folded, zero-copy view of the result + auto result_folded = result.reshape(fold_range(result.range()), batch); + op_(result_folded, left_folded, right_folded); } - - // folded, zero-copy views; the result view shares the result's buffer, - // so the batched GEMM accumulates in place - auto left_folded = left.reshape(fold_range(left.range()), batch); - auto right_folded = right.reshape(fold_range(right.range()), batch); - auto result_folded = result.reshape(fold_range(result.range()), batch); - op_(result_folded, left_folded, right_folded); } }; // class BatchedContractReduce diff --git a/tests/general_product.cpp b/tests/general_product.cpp index f284ac6cde..57b2ce5735 100644 --- a/tests/general_product.cpp +++ b/tests/general_product.cpp @@ -338,4 +338,92 @@ BOOST_AUTO_TEST_CASE(expression_general_product_sparse_batched_outer) { BOOST_CHECK_SMALL(diff_norm_sp(c, c_ref, "b,i,k"), 1e-10); } +namespace { + +using TArrayToT = + TA::DistArray>, TA::DensePolicy>; + +/// makes a dense ToT array over \p tr with cells of extents \p inner_extents +/// filled with an index-dependent pattern +TArrayToT make_patterned_tot_array( + TA::World& world, const TA::TiledRange& tr, + const std::vector& inner_extents, const double seed) { + TArrayToT result(world, tr); + for (auto it = result.begin(); it != result.end(); ++it) { + auto outer_range = result.trange().make_tile_range(it.index()); + TA::Tensor> tile(outer_range); + for (auto&& oix : outer_range) { + TA::Tensor cell{TA::Range(inner_extents)}; + for (auto&& iix : cell.range()) { + double v = seed; + double scale = 1.0; + for (auto x : oix) { + v += scale * static_cast(x + 1); + scale *= 0.1; + } + for (auto x : iix) { + v += scale * static_cast(x + 1); + scale *= 0.1; + } + cell[iix] = v; + } + tile[oix] = cell; + } + *it = tile; + } + return result; +} + +/// \return the max abs elementwise difference between two congruent dense ToT +/// arrays (replicated check: every rank fetches every tile) +double tot_max_abs_diff(const TArrayToT& lhs, const TArrayToT& rhs) { + double max_diff = 0.0; + const auto n = lhs.trange().tiles_range().volume(); + for (std::size_t ord = 0; ord < n; ++ord) { + auto lt = lhs.find(ord).get(); + auto rt = rhs.find(ord).get(); + for (std::size_t c = 0; c < lt.range().volume(); ++c) { + const auto& lc = lt.data()[c]; + const auto& rc = rt.data()[c]; + for (std::size_t e = 0; e < lc.range().volume(); ++e) + max_diff = std::max(max_diff, std::abs(lc.data()[e] - rc.data()[e])); + } + } + return max_diff; +} + +} // namespace + +BOOST_AUTO_TEST_CASE(expression_general_product_tot_inner_hadamard) { + // ToT general product, inner Hadamard: + // C("b,i,k;m") = A("b,i,j;m") * B("b,j,k;m") + // outer: b fused, j contracted, i/k free; inner: m fused + auto& world = TA::get_default_world(); + TA::TiledRange tr_a{{0, 2, 4}, {0, 2, 3}, {0, 2, 5}}; // b, i, j + TA::TiledRange tr_b{{0, 2, 4}, {0, 2, 5}, {0, 3, 4}}; // b, j, k + auto a = make_patterned_tot_array(world, tr_a, {3}, 1.0); + auto b = make_patterned_tot_array(world, tr_b, {3}, 2.0); + + TArrayToT c; + BOOST_REQUIRE_NO_THROW(c("b,i,k;m") = a("b,i,j;m") * b("b,j,k;m")); + auto c_ref = TA::einsum(a("b,i,j;m"), b("b,j,k;m"), "b,i,k;m"); + BOOST_CHECK_SMALL(tot_max_abs_diff(c, c_ref), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_tot_inner_contraction) { + // ToT general product, inner contraction: + // C("b,i,k;m,n") = A("b,i,j;m,c") * B("b,j,k;c,n") + // outer: b fused, j contracted, i/k free; inner: c contracted, m/n free + auto& world = TA::get_default_world(); + TA::TiledRange tr_a{{0, 2, 4}, {0, 2, 3}, {0, 2, 5}}; // b, i, j + TA::TiledRange tr_b{{0, 2, 4}, {0, 2, 5}, {0, 3, 4}}; // b, j, k + auto a = make_patterned_tot_array(world, tr_a, {3, 2}, 1.0); + auto b = make_patterned_tot_array(world, tr_b, {2, 4}, 2.0); + + TArrayToT c; + BOOST_REQUIRE_NO_THROW(c("b,i,k;m,n") = a("b,i,j;m,c") * b("b,j,k;c,n")); + auto c_ref = TA::einsum(a("b,i,j;m,c"), b("b,j,k;c,n"), "b,i,k;m,n"); + BOOST_CHECK_SMALL(tot_max_abs_diff(c, c_ref), 1e-10); +} + BOOST_AUTO_TEST_SUITE_END() From fe6703a39f9df933d9cb5f4c0149044c7c604b06 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Thu, 11 Jun 2026 10:51:14 -0400 Subject: [PATCH 07/25] einsum: opt-in expression route for general products + differential mode Phase D (partial): einsum can evaluate its generalized-contraction branch through the expression layer's native general-product support (one batched Summa in one World) instead of the legacy per-Hadamard-slab sub-World decomposition. The engine receives the canonical (fused..., left-free..., right-free...) result layout; arbitrary einsum targets are reached by a final permutation assignment. Three-way runtime control (detail::einsum_legacy_subworld / detail::einsum_differential, env TA_EINSUM_LEGACY_SUBWORLD / TA_EINSUM_DIFFERENTIAL): - legacy (DEFAULT for now, see below) - expression route (TA_EINSUM_LEGACY_SUBWORLD=0) - differential: evaluates BOTH routes per general product, compares squared norms, reports mismatching contractions (with annotations) to stderr, returns the legacy result. The legacy path is retained indefinitely as the reference implementation for such testing. Status: with the expression route, PNO-CCSD (c6h14/cc-pVDZ) runs with ZERO sub-Worlds (legacy: 1412) and the einsum-region attribution collapses from 17.9 s to 12.4 s of pure evaluation -- but the energy is WRONG (-238.09 vs -236.35). TA_EINSUM_DIFFERENTIAL isolates the mismatching shapes: (1) ToT x T (inner Scale) with a non-leading fused index and interleaved target, e.g. (i4,i1,mu;a) * (mu,i4,K) -> (i1,i4,K;a) (2) phantom-unit (denest-internal) general products, e.g. (mu,i1,i4;a) * (i1,i4,K;a,phantom) -> (mu,i1,K;phantom) Synthetic unit reproductions of (1) with fixed inner extents PASS, so the trigger involves CSV specifics (variable per-block inner extents and/or arena cell layout); under investigation. Until resolved the legacy path is the default and the expression route is opt-in. All unit suites green in both modes; einsum suites were also validated green against their reference data with the expression route as default before the flip. --- src/TiledArray/einsum/tiledarray.h | 90 ++++++++++++ tests/general_product.cpp | 214 +++++++++++++++++++++++++++++ 2 files changed, 304 insertions(+) diff --git a/src/TiledArray/einsum/tiledarray.h b/src/TiledArray/einsum/tiledarray.h index 099b19037f..8a8f7b85fa 100644 --- a/src/TiledArray/einsum/tiledarray.h +++ b/src/TiledArray/einsum/tiledarray.h @@ -14,10 +14,53 @@ #include +#include + namespace TiledArray { enum struct DeNest { True, False }; } +namespace TiledArray::detail { + +/// Runtime toggle for the legacy per-Hadamard-slab sub-World evaluation of +/// general products in einsum. +/// +/// einsum can route a general product (fused + contracted + free indices) +/// through the expression layer's native support (TensorProduct::General, +/// evaluated by the batched Summa: one task graph in one World, no per-slab +/// sub-Worlds) or through the legacy path (one MPI_Comm_split + sub-World + +/// fence per Hadamard slab). The legacy path is currently the DEFAULT: +/// the expression route has known mismatches on PNO-CC (CSV) workloads +/// (see TA_EINSUM_DIFFERENTIAL) that are under investigation. Set +/// TA_EINSUM_LEGACY_SUBWORLD=0 in the environment (or assign \c false to +/// the reference returned by this function) to opt into the expression +/// route. The legacy implementation is retained indefinitely as a reference +/// for differential testing. +inline bool &einsum_legacy_subworld() { + static bool flag = [] { + const char *e = std::getenv("TA_EINSUM_LEGACY_SUBWORLD"); + // default: legacy; any value other than "0" (incl. unset) keeps legacy + return e == nullptr || std::string_view(e) != "0"; + }(); + return flag; +} + +/// Differential-testing mode for einsum's general products: when enabled +/// (TA_EINSUM_DIFFERENTIAL set to a non-empty value other than "0"), every +/// general product is evaluated by BOTH the expression route and the legacy +/// sub-World route; the two results are compared (squared norm of the +/// difference) and mismatches are reported to stderr with the contraction +/// annotation. The legacy result is returned. +inline bool &einsum_differential() { + static bool flag = [] { + const char *e = std::getenv("TA_EINSUM_DIFFERENTIAL"); + return e != nullptr && e[0] != '\0' && std::string_view(e) != "0"; + }(); + return flag; +} + +} // namespace TiledArray::detail + namespace TiledArray::Einsum { using ::Einsum::index::small_vector; @@ -909,6 +952,35 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, } } + // Route the general product through the expression layer's native + // support (TensorProduct::General -> batched Summa: one task graph in + // one World, no per-slab sub-Worlds) unless the legacy path is forced + // (see detail::einsum_legacy_subworld). The engine requires the + // canonical (fused..., left-free..., right-free...) result layout; an + // arbitrary einsum target is reached by a final permutation assignment. + // N.B. the inner annotation is already canonical here (the non-trivial + // inner permutation case recursed above). + std::optional expr_route_result; + if (!detail::einsum_legacy_subworld() || detail::einsum_differential()) { + _ein_call.branch = "generalized-expression"; + detail::EinsumTimer _t(_ein_call, detail::EinsumBucket::LocalKernel); + + TensorOpIndices top(a, b, c); + const auto c_canon = top.ix_C_canon(); + ArrayC result; + result(std::string(c_canon) + inner.c) = tnsrExprA * tnsrExprB; + if (c_canon == c) { + expr_route_result = std::move(result); + } else { + ArrayC result_perm; + result_perm(std::string(c) + inner.c) = + result(std::string(c_canon) + inner.c); + expr_route_result = std::move(result_perm); + } + if (!detail::einsum_differential()) return std::move(*expr_route_result); + _ein_call.branch = "generalized-subworld"; + } + auto update_tr = [&e = std::as_const(e), &i = std::as_const(i), &range_map = std::as_const(range_map)](auto &term) { auto ei = (e + i & term.idx); @@ -1009,6 +1081,24 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, } } + if (expr_route_result) { + // differential mode: compare the expression route against the legacy + // result + const std::string c_annot = std::string(c) + inner.c; + ArrayC diff; + diff(c_annot) = (*expr_route_result)(c_annot)-C.array(c_annot); + const double d2 = diff(c_annot).squared_norm().get(); + const double ref2 = C.array(c_annot).squared_norm().get(); + if (!(d2 <= 1e-20 * std::max(ref2, 1.0))) { + if (world.rank() == 0) + std::cerr << "!! einsum DIFFERENTIAL MISMATCH: " + << (std::string)a + inner.a << " * " + << (std::string)b + inner.b << " -> " << c_annot + << " : diff2 = " << d2 << ", legacy2 = " << ref2 + << std::endl; + } + } + return C.array; } } diff --git a/tests/general_product.cpp b/tests/general_product.cpp index 57b2ce5735..c7198bb806 100644 --- a/tests/general_product.cpp +++ b/tests/general_product.cpp @@ -130,6 +130,17 @@ BOOST_AUTO_TEST_CASE(optimizer_rejects_implicit_reduction) { namespace { +/// forces the legacy sub-World einsum within a scope, so einsum-based +/// reference values remain an *independent* oracle for the expression route +/// (einsum itself routes general products through the expression layer now) +struct ForceLegacyEinsum { + bool prev_; + ForceLegacyEinsum() : prev_(TA::detail::einsum_legacy_subworld()) { + TA::detail::einsum_legacy_subworld() = true; + } + ~ForceLegacyEinsum() { TA::detail::einsum_legacy_subworld() = prev_; } +}; + /// makes a dense array over \p tr filled with an index-dependent pattern TA::TArrayD make_patterned_array(TA::World& world, const TA::TiledRange& tr, const double seed) { @@ -164,6 +175,7 @@ BOOST_AUTO_TEST_CASE(expression_general_product_dense) { // dense general products evaluate via the batched Summa; differential-test // against the einsum free function (the established implementation) auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent // C("b,i,k") = A("b,i,j") * B("b,j,k"), uneven multi-tile dimensions TA::TiledRange tr_a{{0, 2, 5}, {0, 3, 4}, {0, 2, 6, 7}}; // b, i, j @@ -187,6 +199,7 @@ BOOST_AUTO_TEST_CASE(expression_general_product_dense_permuted_args) { // non-canonical argument layouts: the engine permutes the args into the // canonical (h, e_A, c) / (h, c, e_B) layouts auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent TA::TiledRange tr_a{{0, 3, 4}, {0, 2, 6, 7}, {0, 2, 5}}; // i, j, b TA::TiledRange tr_b{{0, 4, 5}, {0, 2, 6, 7}, {0, 2, 5}}; // k, j, b @@ -202,6 +215,7 @@ BOOST_AUTO_TEST_CASE(expression_general_product_dense_permuted_args) { BOOST_AUTO_TEST_CASE(expression_general_product_dense_batched_outer) { // batched outer product: fused + free, no contracted indices auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent TA::TiledRange tr_a{{0, 2, 5}, {0, 3, 4}}; // b, i TA::TiledRange tr_b{{0, 2, 5}, {0, 4, 5}}; // b, k @@ -241,6 +255,7 @@ BOOST_AUTO_TEST_CASE(expression_general_product_thc_intermediates) { // each general product as the root of its own assignment (fused indices // leading in the result annotation), differential-tested against einsum auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent TA::TiledRange tr_x{{0, 2, 4}, {0, 3, 5}}; // orbital x auxiliary TA::TiledRange tr_z{{0, 3, 5}, {0, 3, 5}}; // auxiliary x auxiliary auto x = make_patterned_array(world, tr_x, 1.0); @@ -311,6 +326,7 @@ BOOST_AUTO_TEST_CASE(expression_general_product_sparse) { // (SparseShape::gemm_batched) and the batched Summa runs its sparse path // ((h,k)-keyed masks/groups); differential-test against einsum auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent TA::TiledRange tr_a{{0, 2, 5}, {0, 3, 4}, {0, 2, 6, 7}}; // b, i, j TA::TiledRange tr_b{{0, 2, 5}, {0, 2, 6, 7}, {0, 4, 5}}; // b, j, k @@ -326,6 +342,7 @@ BOOST_AUTO_TEST_CASE(expression_general_product_sparse) { BOOST_AUTO_TEST_CASE(expression_general_product_sparse_batched_outer) { // block-sparse batched outer product (no contracted indices) auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent TA::TiledRange tr_a{{0, 2, 5}, {0, 3, 4}}; // b, i TA::TiledRange tr_b{{0, 2, 5}, {0, 4, 5}}; // b, k @@ -399,6 +416,7 @@ BOOST_AUTO_TEST_CASE(expression_general_product_tot_inner_hadamard) { // C("b,i,k;m") = A("b,i,j;m") * B("b,j,k;m") // outer: b fused, j contracted, i/k free; inner: m fused auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent TA::TiledRange tr_a{{0, 2, 4}, {0, 2, 3}, {0, 2, 5}}; // b, i, j TA::TiledRange tr_b{{0, 2, 4}, {0, 2, 5}, {0, 3, 4}}; // b, j, k auto a = make_patterned_tot_array(world, tr_a, {3}, 1.0); @@ -415,6 +433,7 @@ BOOST_AUTO_TEST_CASE(expression_general_product_tot_inner_contraction) { // C("b,i,k;m,n") = A("b,i,j;m,c") * B("b,j,k;c,n") // outer: b fused, j contracted, i/k free; inner: c contracted, m/n free auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent TA::TiledRange tr_a{{0, 2, 4}, {0, 2, 3}, {0, 2, 5}}; // b, i, j TA::TiledRange tr_b{{0, 2, 4}, {0, 2, 5}, {0, 3, 4}}; // b, j, k auto a = make_patterned_tot_array(world, tr_a, {3, 2}, 1.0); @@ -426,4 +445,199 @@ BOOST_AUTO_TEST_CASE(expression_general_product_tot_inner_contraction) { BOOST_CHECK_SMALL(tot_max_abs_diff(c, c_ref), 1e-10); } +BOOST_AUTO_TEST_CASE(expression_general_product_tot_times_t) { + // mixed ToT x T general product (inner Scale): + // C("b,i,k;m") = A("b,i,j;m") * B("b,j,k") + auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent + TA::TiledRange tr_a{{0, 2, 4}, {0, 2, 3}, {0, 2, 5}}; // b, i, j + TA::TiledRange tr_b{{0, 2, 4}, {0, 2, 5}, {0, 3, 4}}; // b, j, k + auto a = make_patterned_tot_array(world, tr_a, {3}, 1.0); + auto b = make_patterned_array(world, tr_b, 2.0); + + TArrayToT c; + BOOST_REQUIRE_NO_THROW(c("b,i,k;m") = a("b,i,j;m") * b("b,j,k")); + auto c_ref = TA::einsum(a("b,i,j;m"), b("b,j,k"), "b,i,k;m"); + BOOST_CHECK_SMALL(tot_max_abs_diff(c, c_ref), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_t_times_tot) { + // mixed T x ToT general product (inner Scale): + // C("b,i,k;m") = A("b,i,j") * B("b,j,k;m") + auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent + TA::TiledRange tr_a{{0, 2, 4}, {0, 2, 3}, {0, 2, 5}}; // b, i, j + TA::TiledRange tr_b{{0, 2, 4}, {0, 2, 5}, {0, 3, 4}}; // b, j, k + auto a = make_patterned_array(world, tr_a, 1.0); + auto b = make_patterned_tot_array(world, tr_b, {3}, 2.0); + + TArrayToT c; + BOOST_REQUIRE_NO_THROW(c("b,i,k;m") = a("b,i,j") * b("b,j,k;m")); + auto c_ref = TA::einsum(a("b,i,j"), b("b,j,k;m"), "b,i,k;m"); + BOOST_CHECK_SMALL(tot_max_abs_diff(c, c_ref), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_tot_inner_outer_product) { + // the PNO-CC PPL building-block shape: ToT x ToT with an EMPTY right + // outer-external set and an inner OUTER-product: + // C("b1,b2,K;x,y") = A("b1,b2,m,K;x") * B("b1,b2,m;y") + // outer: b1,b2 fused, m contracted, K left-free, no right-free; + // inner: x (left) (x) y (right), no inner contraction + auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent + TA::TiledRange tr_a{{0, 2, 4}, {0, 2, 3}, {0, 2, 5}, {0, 3, 6}}; // b1,b2,m,K + TA::TiledRange tr_b{{0, 2, 4}, {0, 2, 3}, {0, 2, 5}}; // b1,b2,m + auto a = make_patterned_tot_array(world, tr_a, {3}, 1.0); + auto b = make_patterned_tot_array(world, tr_b, {2}, 2.0); + + TArrayToT c; + BOOST_REQUIRE_NO_THROW(c("b1,b2,K;x,y") = a("b1,b2,m,K;x") * b("b1,b2,m;y")); + auto c_ref = TA::einsum(a("b1,b2,m,K;x"), b("b1,b2,m;y"), "b1,b2,K;x,y"); + BOOST_CHECK_SMALL(tot_max_abs_diff(c, c_ref), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_arena_inner_outer_product) { + // same shape as expression_general_product_tot_inner_outer_product but + // with arena-backed (ArenaTensor view) inner cells -- the PNO-CC CSV + // representation; exercises the arena plans + strided ce+e kernel under + // the batched Summa + using ArenaInner = TA::ArenaTensor; + using ArenaOuter = TA::Tensor; + using ArenaArr = TA::DistArray; + + auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent + + TA::TiledRange tr_a{{0, 2, 4}, {0, 2, 3}, {0, 2, 5}, {0, 3, 6}}; // b1,b2,m,K + TA::TiledRange tr_b{{0, 2, 4}, {0, 2, 3}, {0, 2, 5}}; // b1,b2,m + constexpr long X = 3, Y = 2; + + auto fill = [](ArenaArr& arr, const TA::TiledRange& tr, const long n_in, + const double seed) { + arr = ArenaArr(arr.world(), tr); + arr.init_tiles([n_in, seed](const TA::Range& outer_range) { + ArenaOuter t = TA::detail::arena_outer_init( + outer_range, 1, + [n_in](std::size_t /*ord*/) { return TA::Range{n_in}; }); + std::size_t o = 0; + for (auto&& oix : t.range()) { + ArenaInner& cell = t.data()[o++]; + if (!cell) continue; + for (long e = 0; e < n_in; ++e) { + double v = seed + 0.01 * static_cast(e + 1); + double scale = 1.0; + for (auto x : oix) { + v += scale * static_cast(x + 1); + scale *= 0.1; + } + cell.data()[e] = v; + } + } + return t; + }); + }; + + ArenaArr a(world, tr_a), b(world, tr_b); + fill(a, tr_a, X, 1.0); + fill(b, tr_b, Y, 2.0); + world.gop.fence(); + + ArenaArr c; + BOOST_REQUIRE_NO_THROW(c("b1,b2,K;x,y") = a("b1,b2,m,K;x") * b("b1,b2,m;y")); + auto c_ref = TA::einsum(a("b1,b2,m,K;x"), b("b1,b2,m;y"), "b1,b2,K;x,y"); + + // elementwise comparison (replicated) + double max_diff = 0.0; + const auto n = c_ref.trange().tiles_range().volume(); + for (std::size_t ord = 0; ord < n; ++ord) { + auto lt = c.find(ord).get(); + auto rt = c_ref.find(ord).get(); + BOOST_REQUIRE_EQUAL(lt.range().volume(), rt.range().volume()); + for (std::size_t cc = 0; cc < lt.range().volume(); ++cc) { + const auto& lc = lt.data()[cc]; + const auto& rc = rt.data()[cc]; + BOOST_REQUIRE_EQUAL(lc.range().volume(), rc.range().volume()); + for (std::size_t e = 0; e < lc.range().volume(); ++e) + max_diff = std::max(max_diff, std::abs(lc.data()[e] - rc.data()[e])); + } + } + BOOST_CHECK_SMALL(max_diff, 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_tot_batched_outer) { + // batched outer product (no contracted outer index) of ToTs with an inner + // contraction, an interleaved target, and an empty right-external set: + // C("i2,i1,b1,b2;y") = A("b1,b2,i2,i1;x") * B("b1,b2;x,y") + // (the canonical layout is (b1,b2,i2,i1); einsum reaches the interleaved + // target by the final permutation assignment) + auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent + TA::TiledRange tr_a{{0, 2, 4}, {0, 2, 3}, {0, 2, 5}, {0, 2, 3}}; + TA::TiledRange tr_b{{0, 2, 4}, {0, 2, 3}}; + auto a = make_patterned_tot_array(world, tr_a, {3}, 1.0); + auto b = make_patterned_tot_array(world, tr_b, {3, 2}, 2.0); + + auto c = TA::einsum(a("b1,b2,i2,i1;x"), b("b1,b2;x,y"), "i2,i1,b1,b2;y"); + TiledArray::detail::einsum_legacy_subworld() = false; + auto c_new = TA::einsum(a("b1,b2,i2,i1;x"), b("b1,b2;x,y"), "i2,i1,b1,b2;y"); + TiledArray::detail::einsum_legacy_subworld() = true; + BOOST_CHECK_SMALL(tot_max_abs_diff(c_new, c), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_t_tot_batched_outer) { + // batched outer product of a plain tensor with a ToT (inner Scale), with + // an EMPTY left-external set: + // C("i2,i3,i1;x") = A("i3,i1") * B("i3,i1,i2;x") + // (fused i3,i1; no contracted; right-external i2) + auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent + TA::TiledRange tr_a{{0, 2, 4}, {0, 2, 3}}; + TA::TiledRange tr_b{{0, 2, 4}, {0, 2, 3}, {0, 2, 5}}; + auto a = make_patterned_array(world, tr_a, 1.0); + auto b = make_patterned_tot_array(world, tr_b, {3}, 2.0); + + auto c = TA::einsum(a("i3,i1"), b("i3,i1,i2;x"), "i2,i3,i1;x"); + TiledArray::detail::einsum_legacy_subworld() = false; + auto c_new = TA::einsum(a("i3,i1"), b("i3,i1,i2;x"), "i2,i3,i1;x"); + TiledArray::detail::einsum_legacy_subworld() = true; + BOOST_CHECK_SMALL(tot_max_abs_diff(c_new, c), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_tot_t_nonleading_fused) { + // reproduction of a CSV-CC mismatch shape: ToT x T (inner Scale) with a + // NON-leading fused index and an interleaved target: + // C("i1,i4,K;x") = A("i4,i1,m;x") * B("m,i4,K") + // (fused i4; contracted m; eA = i1; eB = K; canonical = i4,i1,K) + auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent + TA::TiledRange tr_a{{0, 2, 4}, {0, 2, 3}, {0, 2, 5}}; // i4, i1, m + TA::TiledRange tr_b{{0, 2, 5}, {0, 2, 4}, {0, 3, 6}}; // m, i4, K + auto a = make_patterned_tot_array(world, tr_a, {3}, 1.0); + auto b = make_patterned_array(world, tr_b, 2.0); + + auto c_ref = TA::einsum(a("i4,i1,m;x"), b("m,i4,K"), "i1,i4,K;x"); + TiledArray::detail::einsum_legacy_subworld() = false; + auto c_new = TA::einsum(a("i4,i1,m;x"), b("m,i4,K"), "i1,i4,K;x"); + TiledArray::detail::einsum_legacy_subworld() = true; + BOOST_CHECK_SMALL(tot_max_abs_diff(c_new, c_ref), 1e-10); +} + +BOOST_AUTO_TEST_CASE(einsum_expression_route_matches_legacy) { + // einsum routes general products through the expression layer by default; + // differential-check against the retained legacy sub-World path, with an + // interleaved (non-canonical) target reached by the final permutation + // assignment + auto& world = TA::get_default_world(); + TA::TiledRange tr_a{{0, 2, 5}, {0, 3, 4}, {0, 2, 6, 7}}; // b, i, j + TA::TiledRange tr_b{{0, 2, 5}, {0, 2, 6, 7}, {0, 4, 5}}; // b, j, k + auto a = make_patterned_array(world, tr_a, 1.0); + auto b = make_patterned_array(world, tr_b, 2.0); + + auto c_legacy = TA::einsum(a("b,i,j"), b("b,j,k"), "i,b,k"); + TA::detail::einsum_legacy_subworld() = false; + auto c_new = TA::einsum(a("b,i,j"), b("b,j,k"), "i,b,k"); + TA::detail::einsum_legacy_subworld() = true; + BOOST_CHECK_SMALL(diff_norm(c_new, c_legacy, "i,b,k"), 1e-10); +} + BOOST_AUTO_TEST_SUITE_END() From 0e0482052b4fbe0795c564a1acaa2f1be458b1dd Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Thu, 11 Jun 2026 12:11:17 -0400 Subject: [PATCH 08/25] tensor: fix dropped rows/cols in the ToT x scalar strided scale paths The GEMM-based ToT x scalar scale paths of Tensor::gemm (and the T x ToT mirror) probe each row (column) for cleanliness; the presence probe stops at the first ABSENT cell, leaving the probed inner size A == -1 when the leading cell is absent. The subsequent 'A <= 0 => empty row, nothing to do' shortcut then dropped the ENTIRE row's contributions even when later cells were present. Rows whose leading contracted cell is absent (common for screened tensor-of-tensors, e.g. PNO-CC CSV intermediates) silently lost their contraction. Fix: when the probe ends with A <= 0, scan the full row (column) for any present cell; only a fully absent row is skipped, anything else takes the per-cell AXPY fallback. Also guard the engine's scale fallback element op against absent cells. This bug predates the general-product work but was masked on the legacy einsum route, whose canonical operand layout fails the NoTranspose gate of the strided path; the expression route's GEMM-canonical layout exposed it. Found with TA_EINSUM_DIFFERENTIAL on c6h14 PNO-CCSD: the opt-in expression-route energy error drops from 1.7 Eh to 6.9e-7 (the small residual is a screening-semantics difference -- the legacy route hard-zeroes sub-threshold result tiles that the engine route keeps -- plus a small systematic difference in phantom-unit denest products, both under review). Adds the CSV-like reproduction test (arena view cells, SparsePolicy, variable inner extents, screened cells, non-leading fused index, interleaved target): expression route and einsum routes agree to 1e-10, deterministically. --- src/TiledArray/einsum/tiledarray.h | 45 +++ src/TiledArray/expressions/cont_engine.h | 2 + src/TiledArray/tensor/tensor.h | 25 +- .../tile_op/batched_contract_reduce.h | 12 +- tests/general_product.cpp | 378 ++++++++++++++++++ 5 files changed, 459 insertions(+), 3 deletions(-) diff --git a/src/TiledArray/einsum/tiledarray.h b/src/TiledArray/einsum/tiledarray.h index 8a8f7b85fa..3ac9cbf88f 100644 --- a/src/TiledArray/einsum/tiledarray.h +++ b/src/TiledArray/einsum/tiledarray.h @@ -1096,6 +1096,51 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, << (std::string)b + inner.b << " -> " << c_annot << " : diff2 = " << d2 << ", legacy2 = " << ref2 << std::endl; + // per-tile forensics: compare tile norms of the two routes + auto tile_norm2 = [](auto const &tile) -> double { + using TileT = + std::remove_cv_t>; + double n2 = 0; + if constexpr (TiledArray::detail::is_tensor_of_tensor_v) { + for (std::size_t o = 0; o < tile.range().volume(); ++o) { + auto const &cell = tile.data()[o]; + if (cell.empty()) continue; + for (std::size_t e = 0; e < cell.range().volume(); ++e) + n2 += std::abs(cell.data()[e]) * std::abs(cell.data()[e]); + } + } else { + for (std::size_t e = 0; e < tile.range().volume() * tile.nbatch(); + ++e) + n2 += std::abs(tile.data()[e]) * std::abs(tile.data()[e]); + } + return n2; + }; + const auto ntiles = C.array.trange().tiles_range().volume(); + std::size_t n_diff = 0, n_expr_zero = 0, n_legacy_zero = 0, + n_printed = 0; + for (std::size_t ord = 0; ord < ntiles; ++ord) { + const bool ez = expr_route_result->is_zero(ord); + const bool lz = C.array.is_zero(ord); + double en2 = + ez ? 0.0 : tile_norm2(expr_route_result->find(ord).get()); + double ln2 = lz ? 0.0 : tile_norm2(C.array.find(ord).get()); + const double dd = std::abs(en2 - ln2); + if (dd <= 1e-14 * std::max(std::max(en2, ln2), 1.0)) continue; + ++n_diff; + if (en2 == 0.0) ++n_expr_zero; + if (ln2 == 0.0) ++n_legacy_zero; + if (world.rank() == 0 && n_printed < 8) { + ++n_printed; + std::cerr << " tile " << ord << " (" + << C.array.trange().tiles_range().idx(ord) + << "): expr_n2 = " << en2 << ", legacy_n2 = " << ln2 + << std::endl; + } + } + if (world.rank() == 0) + std::cerr << " summary: " << n_diff << "/" << ntiles + << " tiles differ by norm; expr-zero " << n_expr_zero + << ", legacy-zero " << n_legacy_zero << std::endl; } } diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index 91b704108c..17c54318c6 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -1679,11 +1679,13 @@ class ContEngine : public BinaryEngine { if (outer_uses_summa) { using TiledArray::axpy_to; if constexpr (tot_x_t) { + if (left.empty()) return; // absent cell: no contribution if (perm) axpy_to(result, left, right, perm); else axpy_to(result, left, right); } else { + if (right.empty()) return; // absent cell: no contribution if (perm) axpy_to(result, right, left, perm); else diff --git a/src/TiledArray/tensor/tensor.h b/src/TiledArray/tensor/tensor.h index c623eaccb1..bb9d49223a 100644 --- a/src/TiledArray/tensor/tensor.h +++ b/src/TiledArray/tensor/tensor.h @@ -3374,7 +3374,19 @@ class Tensor { } detail::scale_phase_stop(detail::g_scale[0].check_str_ns, _scale_tcs); - if (A <= 0) continue; // empty row -> nothing to do + if (A <= 0) { + // The presence probe stops at the first empty cell, so + // A <= 0 only means no non-empty cell was seen BEFORE the + // probe stopped -- NOT that the row is empty. The row + // contributes nothing only if EVERY left cell is absent; + // else it must take the per-cell AXPY fallback (whose + // element op skips absent cells). + bool row_empty = true; + for (integer k = 0; row_empty && k != K; ++k) + if (!lc0[k].empty()) row_empty = false; + if (row_empty) continue; + clean = false; + } if (clean) { // result[m,n][a] += sum_k left[m,k][a] * right[k,n]. // Row-major gemm: C2(N x A) += right^T(N x K) * L2(K x A), @@ -3533,7 +3545,16 @@ class Tensor { } detail::scale_phase_stop(detail::g_scale[1].check_str_ns, _scale_tcs); - if (A <= 0) continue; + if (A <= 0) { + // see the ToT x scalar mirror above: A <= 0 does not imply + // an empty column when the probe stopped at the first + // absent cell + bool col_empty = true; + for (integer k = 0; col_empty && k != K; ++k) + if (!right_data[k * N + n].empty()) col_empty = false; + if (col_empty) continue; + clean = false; + } if (clean) { // C_n(M x A) += left(M x K) * B_n(K x A). Row-major gemm. if (detail::scale_gemm_timing_enabled()) { diff --git a/src/TiledArray/tile_op/batched_contract_reduce.h b/src/TiledArray/tile_op/batched_contract_reduce.h index 27e77420da..bd3fd19125 100644 --- a/src/TiledArray/tile_op/batched_contract_reduce.h +++ b/src/TiledArray/tile_op/batched_contract_reduce.h @@ -193,8 +193,18 @@ class BatchedContractReduce { typename result_type::range_type(lobounds, upbounds)); } else { // accumulate through a folded, zero-copy view of the result - auto result_folded = result.reshape(fold_range(result.range()), batch); + const auto full_range = result.range(); + auto result_folded = result.reshape(fold_range(full_range), batch); op_(result_folded, left_folded, right_folded); + // the wrapped op may REBIND the result instead of writing in place: + // the arena grow-to-cover path (a later K-panel touching inner cells + // that earlier panels left null, for contracted-dimension-sparse + // tensor-of-tensor operands) replaces the tile wholesale + // (arena_tot_grow_inplace ends with `result = std::move(grown)`). + // That rebinding lands on the folded view; propagate it to the + // full-range result, else every post-growth contribution is lost. + if (result_folded.data() != result.data()) + result = result_folded.reshape(full_range); } } diff --git a/tests/general_product.cpp b/tests/general_product.cpp index c7198bb806..331306a94d 100644 --- a/tests/general_product.cpp +++ b/tests/general_product.cpp @@ -622,6 +622,384 @@ BOOST_AUTO_TEST_CASE(expression_general_product_tot_t_nonleading_fused) { BOOST_CHECK_SMALL(tot_max_abs_diff(c_new, c_ref), 1e-10); } +BOOST_AUTO_TEST_CASE(arena_outer_permute_assignment) { + // hypothesis probe for the CSV mismatches: a pure OUTER permutation + // assignment of an arena-backed (ArenaTensor view cell) ToT array + using ArenaInner = TA::ArenaTensor; + using ArenaOuter = TA::Tensor; + using ArenaArr = TA::DistArray; + + auto& world = TA::get_default_world(); + TA::TiledRange tr{{0, 2, 4}, {0, 2, 3}, {0, 3, 6}}; // b1, b2, k + constexpr long X = 3; + + ArenaArr a(world, tr); + a.init_tiles([](const TA::Range& outer_range) { + ArenaOuter t = TA::detail::arena_outer_init( + outer_range, 1, [](std::size_t) { return TA::Range{X}; }); + std::size_t o = 0; + for (auto&& oix : t.range()) { + ArenaInner& cell = t.data()[o++]; + if (!cell) continue; + for (long e = 0; e < X; ++e) { + double v = 0.01 * static_cast(e + 1); + double scale = 1.0; + for (auto x : oix) { + v += scale * static_cast(x + 1); + scale *= 0.1; + } + cell.data()[e] = v; + } + } + return t; + }); + world.gop.fence(); + + ArenaArr b; + b("k,b1,b2;x") = a("b1,b2,k;x"); + + // verify element-by-element against the source + double max_diff = 0.0; + const auto& trb = b.trange(); + for (std::size_t ord = 0; ord < trb.tiles_range().volume(); ++ord) { + auto bt = b.find(ord).get(); + for (auto&& oix : bt.range()) { + const auto& bc = bt[oix]; + // source element index: (b1,b2,k) from (k,b1,b2) + std::array six{static_cast(oix[1]), + static_cast(oix[2]), + static_cast(oix[0])}; + double v0 = 0.0; + for (long e = 0; e < X; ++e) { + double v = 0.01 * static_cast(e + 1); + double scale = 1.0; + for (auto x : six) { + v += scale * static_cast(x + 1); + scale *= 0.1; + } + max_diff = std::max(max_diff, std::abs(bc.data()[e] - v)); + v0 = v; + } + (void)v0; + } + } + BOOST_CHECK_SMALL(max_diff, 1e-12); +} + +namespace { + +using TSpArrayToT = + TA::DistArray>, TA::SparsePolicy>; + +/// block-sparse ToT array with index-dependent fill, every zero_stride-th +/// tile zero +TSpArrayToT make_patterned_sparse_tot_array( + TA::World& world, const TA::TiledRange& tr, + const std::vector& inner_extents, const double seed, + const std::size_t zero_stride) { + TA::Tensor norms(tr.tiles_range(), 1.0f); + for (std::size_t ord = 0; ord < norms.size(); ord += zero_stride) + norms.data()[ord] = 0.0f; + TA::SparseShape shape(norms, tr); + + TSpArrayToT result(world, tr, shape); + for (auto it = result.begin(); it != result.end(); ++it) { + auto outer_range = result.trange().make_tile_range(it.index()); + TA::Tensor> tile(outer_range); + for (auto&& oix : outer_range) { + TA::Tensor cell{TA::Range(inner_extents)}; + for (auto&& iix : cell.range()) { + double v = seed; + double scale = 1.0; + for (auto x : oix) { + v += scale * static_cast(x + 1); + scale *= 0.1; + } + for (auto x : iix) { + v += scale * static_cast(x + 1); + scale *= 0.1; + } + cell[iix] = v; + } + tile[oix] = cell; + } + *it = tile; + } + return result; +} + +/// max abs elementwise diff of two congruent sparse ToT arrays +double sparse_tot_max_abs_diff(const TSpArrayToT& lhs, const TSpArrayToT& rhs) { + double max_diff = 0.0; + const auto n = lhs.trange().tiles_range().volume(); + for (std::size_t ord = 0; ord < n; ++ord) { + const bool lz = lhs.is_zero(ord); + const bool rz = rhs.is_zero(ord); + if (lz && rz) continue; + if (lz != rz) { + // one is a zero tile: the other must be numerically zero + auto t = (lz ? rhs : lhs).find(ord).get(); + for (std::size_t c = 0; c < t.range().volume(); ++c) { + const auto& cell = t.data()[c]; + if (cell.empty()) continue; + for (std::size_t e = 0; e < cell.range().volume(); ++e) + max_diff = std::max(max_diff, std::abs(cell.data()[e])); + } + continue; + } + auto lt = lhs.find(ord).get(); + auto rt = rhs.find(ord).get(); + for (std::size_t c = 0; c < lt.range().volume(); ++c) { + const auto& lc = lt.data()[c]; + const auto& rc = rt.data()[c]; + if (lc.empty() && rc.empty()) continue; + for (std::size_t e = 0; e < lc.range().volume(); ++e) + max_diff = std::max(max_diff, std::abs(lc.data()[e] - rc.data()[e])); + } + } + return max_diff; +} + +} // namespace + +BOOST_AUTO_TEST_CASE(expression_general_product_sparse_tot) { + // SPARSE-policy ToT general product -- the CSV-CC case (all other ToT + // tests are dense): exercises the batched Summa's sparse step iteration + // ((h,k) skipping) and sparse reducer gating with ToT tiles + auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent + TA::TiledRange tr_a{{0, 2, 4}, {0, 2, 3}, {0, 2, 5}}; // b, i, j + TA::TiledRange tr_b{{0, 2, 4}, {0, 2, 5}, {0, 3, 4}}; // b, j, k + auto a = make_patterned_sparse_tot_array(world, tr_a, {3}, 1.0, 3); + auto b = make_patterned_sparse_tot_array(world, tr_b, {2}, 2.0, 4); + + TSpArrayToT c; + BOOST_REQUIRE_NO_THROW(c("b,i,k;x,y") = a("b,i,j;x") * b("b,j,k;y")); + auto c_ref = TA::einsum(a("b,i,j;x"), b("b,j,k;y"), "b,i,k;x,y"); + BOOST_CHECK_SMALL(sparse_tot_max_abs_diff(c, c_ref), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_sparse_tot_skipped_steps) { + // STRUCTURED block-sparsity that zeroes entire (slab, k) panels, forcing + // the batched Summa's sparse step iteration to SKIP steps (and discard + // the skipped panels' tiles) -- the suspected CSV-CC failure mode + auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent + TA::TiledRange tr_a{{0, 2, 4, 6}, {0, 2, 3}, {0, 2, 5, 7}}; // b, i, j + TA::TiledRange tr_b{{0, 2, 4, 6}, {0, 2, 5, 7}, {0, 3, 4}}; // b, j, k + + // left: zero the whole (b=0, j=1) and (b=2, j=0) panels (all i) + TA::Tensor norms_a(tr_a.tiles_range(), 1.0f); + for (std::size_t i = 0; i < 2; ++i) { + norms_a[std::array{0, i, 1}] = 0.0f; + norms_a[std::array{2, i, 0}] = 0.0f; + } + // right: zero the whole (b=1, j=2) panel (all k) and (b=0, j=1) + TA::Tensor norms_b(tr_b.tiles_range(), 1.0f); + for (std::size_t k = 0; k < 2; ++k) { + norms_b[std::array{1, 2, k}] = 0.0f; + norms_b[std::array{0, 1, k}] = 0.0f; + } + TA::SparseShape shape_a(norms_a, tr_a); + TA::SparseShape shape_b(norms_b, tr_b); + + auto fill = [](TSpArrayToT& arr, const std::size_t n_in, const double seed) { + for (auto it = arr.begin(); it != arr.end(); ++it) { + auto outer_range = arr.trange().make_tile_range(it.index()); + TA::Tensor> tile(outer_range); + for (auto&& oix : outer_range) { + TA::Tensor cell{TA::Range(std::vector{n_in})}; + for (auto&& iix : cell.range()) { + double v = seed; + double scale = 1.0; + for (auto x : oix) { + v += scale * static_cast(x + 1); + scale *= 0.1; + } + for (auto x : iix) v += 0.001 * static_cast(x + 1); + cell[iix] = v; + } + tile[oix] = cell; + } + *it = tile; + } + }; + + TSpArrayToT a(world, tr_a, shape_a); + TSpArrayToT b(world, tr_b, shape_b); + fill(a, 3, 1.0); + fill(b, 2, 2.0); + world.gop.fence(); + + // run the expression route TWICE to also detect non-determinism + TSpArrayToT c1, c2; + BOOST_REQUIRE_NO_THROW(c1("b,i,k;x,y") = a("b,i,j;x") * b("b,j,k;y")); + BOOST_REQUIRE_NO_THROW(c2("b,i,k;x,y") = a("b,i,j;x") * b("b,j,k;y")); + auto c_ref = TA::einsum(a("b,i,j;x"), b("b,j,k;y"), "b,i,k;x,y"); + const double d_rep = sparse_tot_max_abs_diff(c1, c2); + const double d_ref = sparse_tot_max_abs_diff(c1, c_ref); + BOOST_CHECK_SMALL(d_rep, 1e-12); + BOOST_CHECK_SMALL(d_ref, 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_csv_like) { + // kitchen-sink reproduction of the CSV-CC mismatch shape + // (i2,i1,m;a) * (m,i2,K) -> (i1,i2,K;a): + // arena (view) inner cells, SparsePolicy, inner extent VARYING with the + // fused+external outer pair, EMPTY (screened) cells, non-leading fused + // index, interleaved target + using ArenaInner = TA::ArenaTensor; + using ArenaOuter = TA::Tensor; + using ArenaSpArr = TA::DistArray; + using PlainSpArr = TA::TSpArrayD; + + auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent + + TA::TiledRange tr_a{{0, 2, 4}, {0, 2, 3}, {0, 2, 5, 7}}; // i2, i1, m + TA::TiledRange tr_b{{0, 2, 5, 7}, {0, 2, 4}, {0, 3, 6}}; // m, i2, K + + // left: arena ToT, sparse (zero the (i2=1, i1=0, m=1) tile); inner extent + // depends on (i2 + i1); the (i2+i1+m) % 5 == 0 cells are empty (screened) + TA::Tensor norms_a(tr_a.tiles_range(), 1.0f); + norms_a[std::array{1, 0, 1}] = 0.0f; + ArenaSpArr a(world, tr_a, TA::SparseShape(norms_a, tr_a)); + a.init_tiles([](const TA::Range& outer_range) { + auto lo = outer_range.lobound_data(); + auto cell_range = [&outer_range, lo](std::size_t ord) { + auto oix = outer_range.idx(ord); + const long ext = 2 + (oix[0] + oix[1]) % 3; // varies with pair + if ((oix[0] + oix[1] + oix[2]) % 5 == 0) // screened cells + return TA::Range{}; + return TA::Range{ext}; + }; + ArenaOuter t = + TA::detail::arena_outer_init(outer_range, 1, cell_range); + std::size_t o = 0; + for (auto&& oix : t.range()) { + ArenaInner& cell = t.data()[o++]; + if (!cell) continue; + for (std::size_t e = 0; e < cell.range().volume(); ++e) { + double v = 1.0 + 0.01 * static_cast(e + 1); + double scale = 1.0; + for (auto x : oix) { + v += scale * static_cast(x + 1); + scale *= 0.1; + } + cell.data()[e] = v; + } + } + (void)lo; + return t; + }); + + // right: plain sparse (zero one m-panel tile) + TA::Tensor norms_b(tr_b.tiles_range(), 1.0f); + norms_b[std::array{1, 1, 0}] = 0.0f; + PlainSpArr b(world, tr_b, TA::SparseShape(norms_b, tr_b)); + for (auto it = b.begin(); it != b.end(); ++it) { + auto tile = PlainSpArr::value_type(b.trange().make_tile_range(it.index())); + for (auto&& ix : tile.range()) { + double v = 2.0; + double scale = 1.0; + for (auto x : ix) { + v += scale * static_cast(x + 1); + scale *= 0.1; + } + tile[ix] = v; + } + *it = tile; + } + world.gop.fence(); + + // expression route twice (determinism) + legacy einsum oracle + // mirror the CSV path exactly: through einsum (which canonicalizes the + // target and permutes at the end), new route vs the legacy oracle + auto c_ref = TA::einsum(a("i2,i1,m;a"), b("m,i2,K"), "i1,i2,K;a"); + TA::detail::einsum_legacy_subworld() = false; + ArenaSpArr c1, c2; + try { + c1 = TA::einsum(a("i2,i1,m;a"), b("m,i2,K"), "i1,i2,K;a"); + c2 = TA::einsum(a("i2,i1,m;a"), b("m,i2,K"), "i1,i2,K;a"); + } catch (std::exception& ex) { + std::cerr << "EXPRESSION ROUTE THREW: " << ex.what() << std::endl; + TA::detail::einsum_legacy_subworld() = true; + throw; + } + TA::detail::einsum_legacy_subworld() = true; + + // BISECT: same shape, CANONICAL target, direct expression (no einsum + // wrapper, no final permutation) vs the legacy oracle + { + ArenaSpArr c_canon_expr; + c_canon_expr("i2,i1,K;a") = a("i2,i1,m;a") * b("m,i2,K"); + auto c_canon_ref = TA::einsum(a("i2,i1,m;a"), b("m,i2,K"), "i2,i1,K;a"); + double md = 0.0; + const auto n = c_canon_ref.trange().tiles_range().volume(); + for (std::size_t ord = 0; ord < n; ++ord) { + if (c_canon_expr.is_zero(ord) && c_canon_ref.is_zero(ord)) continue; + if (c_canon_expr.is_zero(ord) != c_canon_ref.is_zero(ord)) { + md = std::max(md, 1e10); // zero-pattern mismatch marker + continue; + } + auto lt = c_canon_expr.find(ord).get(); + auto rt = c_canon_ref.find(ord).get(); + for (std::size_t cc = 0; cc < lt.range().volume(); ++cc) { + const auto& lc = lt.data()[cc]; + const auto& rc = rt.data()[cc]; + if (!lc && !rc) continue; + const std::size_t ne = lc ? lc.range().volume() : rc.range().volume(); + for (std::size_t e = 0; e < ne; ++e) { + const double lv = lc ? lc.data()[e] : 0.0; + const double rv = rc ? rc.data()[e] : 0.0; + md = std::max(md, std::abs(lv - rv)); + } + } + } + std::cerr << "CANONICAL-TARGET DIRECT-EXPR max_diff = " << md << std::endl; + BOOST_CHECK_SMALL(md, 1e-10); + } + + auto max_diff = [](const ArenaSpArr& lhs, const ArenaSpArr& rhs) { + double md = 0.0; + const auto n = lhs.trange().tiles_range().volume(); + for (std::size_t ord = 0; ord < n; ++ord) { + const bool lz = lhs.is_zero(ord); + const bool rz = rhs.is_zero(ord); + if (lz && rz) continue; + if (lz != rz) { + auto t = (lz ? rhs : lhs).find(ord).get(); + for (std::size_t c = 0; c < t.range().volume(); ++c) { + const auto& cell = t.data()[c]; + if (!cell) continue; + for (std::size_t e = 0; e < cell.range().volume(); ++e) + md = std::max(md, std::abs(cell.data()[e])); + } + continue; + } + auto lt = lhs.find(ord).get(); + auto rt = rhs.find(ord).get(); + for (std::size_t c = 0; c < lt.range().volume(); ++c) { + const auto& lc = lt.data()[c]; + const auto& rc = rt.data()[c]; + if (!lc && !rc) continue; + if (bool(lc) != bool(rc)) { + const auto& nz = lc ? lc : rc; + for (std::size_t e = 0; e < nz.range().volume(); ++e) + md = std::max(md, std::abs(nz.data()[e])); + continue; + } + for (std::size_t e = 0; e < lc.range().volume(); ++e) + md = std::max(md, std::abs(lc.data()[e] - rc.data()[e])); + } + } + return md; + }; + + BOOST_CHECK_SMALL(max_diff(c1, c2), 1e-12); + BOOST_CHECK_SMALL(max_diff(c1, c_ref), 1e-10); +} + BOOST_AUTO_TEST_CASE(einsum_expression_route_matches_legacy) { // einsum routes general products through the expression layer by default; // differential-check against the retained legacy sub-World path, with an From ef0066c01c8f22dcf731f25ec0344d47cfbce7d4 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Thu, 11 Jun 2026 14:19:46 -0400 Subject: [PATCH 09/25] einsum: expression route is now the default for general products With the strided-scale-path fix in place, the TA_EINSUM_DIFFERENTIAL audit of c6h14 PNO-CCSD shows the two routes agree except for: - sub-threshold result tiles that the legacy path implicitly hard-zeroes (its result shape derives from the harvested tile norms) while the expression route keeps them (standard estimate-derived contraction shape). Per the TA screening philosophy, norms are trusted as genuine and no implicit truncation is performed; users wanting the tighter shape call truncate() explicitly. - floating-point summation-order noise in tiny, heavily-cancelling tensors (absolute tile-norm^2 differences <= 1e-9, no structural pattern). Neither is a defect, so general products in einsum now default to the expression-layer evaluation (TensorProduct::General -> batched Summa: one task graph in one World, ZERO per-slab sub-Worlds). The legacy path remains available via TA_EINSUM_LEGACY_SUBWORLD (or detail::einsum_legacy_subworld()) as the reference implementation for differential testing. All suites pass with the new default (einsum suites against their reference data; the 2 pre-existing assign_subblock_block_base1 failures are unrelated); np = 1, 2, 3. --- src/TiledArray/einsum/tiledarray.h | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/TiledArray/einsum/tiledarray.h b/src/TiledArray/einsum/tiledarray.h index 3ac9cbf88f..1019e69ed6 100644 --- a/src/TiledArray/einsum/tiledarray.h +++ b/src/TiledArray/einsum/tiledarray.h @@ -29,18 +29,24 @@ namespace TiledArray::detail { /// through the expression layer's native support (TensorProduct::General, /// evaluated by the batched Summa: one task graph in one World, no per-slab /// sub-Worlds) or through the legacy path (one MPI_Comm_split + sub-World + -/// fence per Hadamard slab). The legacy path is currently the DEFAULT: -/// the expression route has known mismatches on PNO-CC (CSV) workloads -/// (see TA_EINSUM_DIFFERENTIAL) that are under investigation. Set -/// TA_EINSUM_LEGACY_SUBWORLD=0 in the environment (or assign \c false to -/// the reference returned by this function) to opt into the expression -/// route. The legacy implementation is retained indefinitely as a reference -/// for differential testing. +/// fence per Hadamard slab). The expression route is the DEFAULT; set +/// TA_EINSUM_LEGACY_SUBWORLD in the environment (any non-empty value other +/// than "0"), or assign \c true to the reference returned by this function, +/// to force the legacy path. The legacy implementation is retained +/// indefinitely as a reference for differential testing +/// (TA_EINSUM_DIFFERENTIAL). +/// +/// \note the two routes may legitimately differ on block-sparse data: the +/// legacy path derives the result shape from the harvested tile norms and +/// thus hard-zeroes sub-threshold result tiles, while the expression route +/// keeps them (its shape is the standard estimate-derived contraction +/// shape). Per the TA screening philosophy norms are trusted as genuine and +/// no implicit truncation is performed; call truncate() explicitly if the +/// tighter shape is desired. inline bool &einsum_legacy_subworld() { static bool flag = [] { const char *e = std::getenv("TA_EINSUM_LEGACY_SUBWORLD"); - // default: legacy; any value other than "0" (incl. unset) keeps legacy - return e == nullptr || std::string_view(e) != "0"; + return e != nullptr && e[0] != char(0) && std::string_view(e) != "0"; }(); return flag; } From 5659a35fdbbf7cba5e521f023ac24f0278f1bcd5 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Thu, 11 Jun 2026 18:43:26 -0400 Subject: [PATCH 10/25] address PR #562 review: include hygiene, K_ re-init, gemm_batched congruence checks - einsum/tiledarray.h, tile_op/batched_contract_reduce.h, pmap/slabbed_pmap.h: include what is used (, , util/vector.h, , ) instead of relying on transitive includes - cont_engine.h: re-initialize K_ in init_distribution_general() (defensive; engines are single-use, but mirrors the n_slabs_ reset) - sparse_shape.h: gemm_batched() now TA_ASSERTs that the argument ranks match the folded gemm ranks plus the fused modes and that the fused and contracted mode extents of the two shapes are congruent (the batched analogue of the checks GemmHelper::compute_matrix_sizes performs for plain gemm) --- src/TiledArray/einsum/tiledarray.h | 2 ++ src/TiledArray/expressions/cont_engine.h | 1 + src/TiledArray/pmap/slabbed_pmap.h | 3 +++ src/TiledArray/sparse_shape.h | 14 ++++++++++++++ src/TiledArray/tile_op/batched_contract_reduce.h | 1 + 5 files changed, 21 insertions(+) diff --git a/src/TiledArray/einsum/tiledarray.h b/src/TiledArray/einsum/tiledarray.h index 1019e69ed6..f21e83f519 100644 --- a/src/TiledArray/einsum/tiledarray.h +++ b/src/TiledArray/einsum/tiledarray.h @@ -14,7 +14,9 @@ #include +#include #include +#include namespace TiledArray { enum struct DeNest { True, False }; diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index 17c54318c6..dfab3ee664 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -793,6 +793,7 @@ class ContEngine : public BinaryEngine { // Compute the slab count and the fused sizes of the per-slab contraction size_type M = 1ul, m = 1ul, N = 1ul, n = 1ul; n_slabs_ = 1ul; + K_ = 1ul; for (unsigned int i = 0u; i < nh; ++i) n_slabs_ *= left_tiles_size[i]; for (unsigned int i = nh; i < nh + neA; ++i) { M *= left_tiles_size[i]; diff --git a/src/TiledArray/pmap/slabbed_pmap.h b/src/TiledArray/pmap/slabbed_pmap.h index 17fe370751..a345eccb6a 100644 --- a/src/TiledArray/pmap/slabbed_pmap.h +++ b/src/TiledArray/pmap/slabbed_pmap.h @@ -25,6 +25,9 @@ #include +#include +#include + namespace TiledArray { namespace detail { diff --git a/src/TiledArray/sparse_shape.h b/src/TiledArray/sparse_shape.h index 9bcd381462..7af41176c7 100644 --- a/src/TiledArray/sparse_shape.h +++ b/src/TiledArray/sparse_shape.h @@ -1719,6 +1719,20 @@ class SparseShape { using integer = TiledArray::math::blas::integer; const auto* left_extent = tile_norms_.range().extent_data(); const auto* right_extent = other.tile_norms_.range().extent_data(); + + // check that the ranks match the folded gemm ranks plus the fused modes, + // and that the fused and contracted mode extents of the two shapes are + // congruent + TA_ASSERT(tile_norms_.range().rank() == nfused + gemm_helper.left_rank()); + TA_ASSERT(other.tile_norms_.range().rank() == + nfused + gemm_helper.right_rank()); + for (unsigned int d = 0u; d < nfused; ++d) + TA_ASSERT(left_extent[d] == right_extent[d]); + for (unsigned int i = gemm_helper.left_inner_begin(), + j = gemm_helper.right_inner_begin(); + i < gemm_helper.left_inner_end(); ++i, ++j) + TA_ASSERT(left_extent[nfused + i] == right_extent[nfused + j]); + integer H = 1, M = 1, N = 1, K = 1; for (unsigned int d = 0u; d < nfused; ++d) H *= left_extent[d]; for (unsigned int i = gemm_helper.left_outer_begin(); diff --git a/src/TiledArray/tile_op/batched_contract_reduce.h b/src/TiledArray/tile_op/batched_contract_reduce.h index bd3fd19125..570a7398c8 100644 --- a/src/TiledArray/tile_op/batched_contract_reduce.h +++ b/src/TiledArray/tile_op/batched_contract_reduce.h @@ -24,6 +24,7 @@ #define TILEDARRAY_TILE_OP_BATCHED_CONTRACT_REDUCE_H__INCLUDED #include +#include namespace TiledArray { namespace detail { From 23fc09546e9148e0b0a35665d8e1f9ba1a14c69f Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Thu, 11 Jun 2026 18:43:47 -0400 Subject: [PATCH 11/25] tests: RAII-scope every einsum-route toggle in general_product_suite ScopedEinsumRoute restores the previous einsum_legacy_subworld() value on scope exit (ForceLegacyEinsum is now the legacy=true special case), so a throwing TA::einsum can no longer leak the toggle into later test cases. Also restores einsum_expression_route_matches_legacy to its intent: it was written when the legacy sub-World path was einsum's default, so after the default flip (ef0066c01) its "legacy" reference silently took the expression route (vacuous comparison) and the trailing manual toggle left the legacy path enabled for the rest of the test module. --- tests/general_product.cpp | 70 ++++++++++++++++++++++++--------------- 1 file changed, 43 insertions(+), 27 deletions(-) diff --git a/tests/general_product.cpp b/tests/general_product.cpp index 331306a94d..b0d622a31d 100644 --- a/tests/general_product.cpp +++ b/tests/general_product.cpp @@ -130,15 +130,22 @@ BOOST_AUTO_TEST_CASE(optimizer_rejects_implicit_reduction) { namespace { +/// sets the einsum legacy-subworld toggle within a scope; the previous value +/// is restored on scope exit (also on exceptions) +struct ScopedEinsumRoute { + bool prev_; + explicit ScopedEinsumRoute(const bool legacy) + : prev_(TA::detail::einsum_legacy_subworld()) { + TA::detail::einsum_legacy_subworld() = legacy; + } + ~ScopedEinsumRoute() { TA::detail::einsum_legacy_subworld() = prev_; } +}; + /// forces the legacy sub-World einsum within a scope, so einsum-based /// reference values remain an *independent* oracle for the expression route /// (einsum itself routes general products through the expression layer now) -struct ForceLegacyEinsum { - bool prev_; - ForceLegacyEinsum() : prev_(TA::detail::einsum_legacy_subworld()) { - TA::detail::einsum_legacy_subworld() = true; - } - ~ForceLegacyEinsum() { TA::detail::einsum_legacy_subworld() = prev_; } +struct ForceLegacyEinsum : ScopedEinsumRoute { + ForceLegacyEinsum() : ScopedEinsumRoute(true) {} }; /// makes a dense array over \p tr filled with an index-dependent pattern @@ -578,9 +585,11 @@ BOOST_AUTO_TEST_CASE(expression_general_product_tot_batched_outer) { auto b = make_patterned_tot_array(world, tr_b, {3, 2}, 2.0); auto c = TA::einsum(a("b1,b2,i2,i1;x"), b("b1,b2;x,y"), "i2,i1,b1,b2;y"); - TiledArray::detail::einsum_legacy_subworld() = false; - auto c_new = TA::einsum(a("b1,b2,i2,i1;x"), b("b1,b2;x,y"), "i2,i1,b1,b2;y"); - TiledArray::detail::einsum_legacy_subworld() = true; + decltype(c) c_new; + { + ScopedEinsumRoute expression_route(false); + c_new = TA::einsum(a("b1,b2,i2,i1;x"), b("b1,b2;x,y"), "i2,i1,b1,b2;y"); + } BOOST_CHECK_SMALL(tot_max_abs_diff(c_new, c), 1e-10); } @@ -597,9 +606,11 @@ BOOST_AUTO_TEST_CASE(expression_general_product_t_tot_batched_outer) { auto b = make_patterned_tot_array(world, tr_b, {3}, 2.0); auto c = TA::einsum(a("i3,i1"), b("i3,i1,i2;x"), "i2,i3,i1;x"); - TiledArray::detail::einsum_legacy_subworld() = false; - auto c_new = TA::einsum(a("i3,i1"), b("i3,i1,i2;x"), "i2,i3,i1;x"); - TiledArray::detail::einsum_legacy_subworld() = true; + decltype(c) c_new; + { + ScopedEinsumRoute expression_route(false); + c_new = TA::einsum(a("i3,i1"), b("i3,i1,i2;x"), "i2,i3,i1;x"); + } BOOST_CHECK_SMALL(tot_max_abs_diff(c_new, c), 1e-10); } @@ -616,9 +627,11 @@ BOOST_AUTO_TEST_CASE(expression_general_product_tot_t_nonleading_fused) { auto b = make_patterned_array(world, tr_b, 2.0); auto c_ref = TA::einsum(a("i4,i1,m;x"), b("m,i4,K"), "i1,i4,K;x"); - TiledArray::detail::einsum_legacy_subworld() = false; - auto c_new = TA::einsum(a("i4,i1,m;x"), b("m,i4,K"), "i1,i4,K;x"); - TiledArray::detail::einsum_legacy_subworld() = true; + decltype(c_ref) c_new; + { + ScopedEinsumRoute expression_route(false); + c_new = TA::einsum(a("i4,i1,m;x"), b("m,i4,K"), "i1,i4,K;x"); + } BOOST_CHECK_SMALL(tot_max_abs_diff(c_new, c_ref), 1e-10); } @@ -916,17 +929,17 @@ BOOST_AUTO_TEST_CASE(expression_general_product_csv_like) { // mirror the CSV path exactly: through einsum (which canonicalizes the // target and permutes at the end), new route vs the legacy oracle auto c_ref = TA::einsum(a("i2,i1,m;a"), b("m,i2,K"), "i1,i2,K;a"); - TA::detail::einsum_legacy_subworld() = false; ArenaSpArr c1, c2; - try { - c1 = TA::einsum(a("i2,i1,m;a"), b("m,i2,K"), "i1,i2,K;a"); - c2 = TA::einsum(a("i2,i1,m;a"), b("m,i2,K"), "i1,i2,K;a"); - } catch (std::exception& ex) { - std::cerr << "EXPRESSION ROUTE THREW: " << ex.what() << std::endl; - TA::detail::einsum_legacy_subworld() = true; - throw; + { + ScopedEinsumRoute expression_route(false); + try { + c1 = TA::einsum(a("i2,i1,m;a"), b("m,i2,K"), "i1,i2,K;a"); + c2 = TA::einsum(a("i2,i1,m;a"), b("m,i2,K"), "i1,i2,K;a"); + } catch (std::exception& ex) { + std::cerr << "EXPRESSION ROUTE THREW: " << ex.what() << std::endl; + throw; + } } - TA::detail::einsum_legacy_subworld() = true; // BISECT: same shape, CANONICAL target, direct expression (no einsum // wrapper, no final permutation) vs the legacy oracle @@ -1006,15 +1019,18 @@ BOOST_AUTO_TEST_CASE(einsum_expression_route_matches_legacy) { // interleaved (non-canonical) target reached by the final permutation // assignment auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent TA::TiledRange tr_a{{0, 2, 5}, {0, 3, 4}, {0, 2, 6, 7}}; // b, i, j TA::TiledRange tr_b{{0, 2, 5}, {0, 2, 6, 7}, {0, 4, 5}}; // b, j, k auto a = make_patterned_array(world, tr_a, 1.0); auto b = make_patterned_array(world, tr_b, 2.0); auto c_legacy = TA::einsum(a("b,i,j"), b("b,j,k"), "i,b,k"); - TA::detail::einsum_legacy_subworld() = false; - auto c_new = TA::einsum(a("b,i,j"), b("b,j,k"), "i,b,k"); - TA::detail::einsum_legacy_subworld() = true; + decltype(c_legacy) c_new; + { + ScopedEinsumRoute expression_route(false); + c_new = TA::einsum(a("b,i,j"), b("b,j,k"), "i,b,k"); + } BOOST_CHECK_SMALL(diff_norm(c_new, c_legacy, "i,b,k"), 1e-10); } From e5693a41820b12a830c05ba96a0918f4cc7e9138 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Thu, 11 Jun 2026 21:03:39 -0400 Subject: [PATCH 12/25] expressions: deduce inner-node index sets top-down from the assignment target (Phase E) An index shared by the two children of a product is fused iff the node's target carries it, contracted otherwise; an index neither the sibling nor the target carries is consumed within the child subtree and not demanded of it. available_indices() (per-subtree leaf-annotation union, valid before init) supplies the up-pass; each child dictates the ORDER of its demand via preferred_layout() (canonical (fused, left-free, right-free) for products, pass-through elsewhere). Expressions consumed without a target (reductions) retain the bottom-up contraction convention. --- src/TiledArray/expressions/binary_engine.h | 29 +++++ src/TiledArray/expressions/leaf_engine.h | 12 +++ src/TiledArray/expressions/mult_engine.h | 120 ++++++++++++++------- src/TiledArray/expressions/unary_engine.h | 10 ++ 4 files changed, 135 insertions(+), 36 deletions(-) diff --git a/src/TiledArray/expressions/binary_engine.h b/src/TiledArray/expressions/binary_engine.h index 7d0e3ff0d2..c3d2e69da5 100644 --- a/src/TiledArray/expressions/binary_engine.h +++ b/src/TiledArray/expressions/binary_engine.h @@ -229,6 +229,35 @@ class BinaryEngine : public ExprEngine { right_inner_permtype_ == PermutationType::general); } + /// \return the indices this subtree can supply (Phase E up-pass): the + /// first-occurrence-ordered union of the children's available indices, + /// outer and inner lists separately. Valid before init: it depends only on + /// the leaf annotations, never on resolved (post-init) index sets. + BipartiteIndexList available_indices() const { + auto union_ = [](auto const& a, auto const& b) { + container::svector r(a.begin(), a.end()); + for (auto&& idx : b) + if (!a.count(idx)) r.push_back(idx); + return r; + }; + auto const l = left_.available_indices(); + auto const r = right_.available_indices(); + auto const out = union_(outer(l), outer(r)); + auto const in = union_(inner(l), inner(r)); + return BipartiteIndexList(IndexList(out.begin(), out.end()), + IndexList(in.begin(), in.end())); + } + + /// \return the layout this subtree prefers for producing the index set of + /// \p demand: element-wise binary ops (Add/Subt) impose no layout of their + /// own (both children must produce the same set and are aligned by + /// permutation), so the demand is returned unchanged. MultEngine overrides + /// this with its canonical product layout. + const BipartiteIndexList& preferred_layout( + const BipartiteIndexList& demand) const { + return demand; + } + /// Initialize result tensor structure /// This function will initialize the permutation, tiled range, and shape diff --git a/src/TiledArray/expressions/leaf_engine.h b/src/TiledArray/expressions/leaf_engine.h index 8804989d6f..112581c71f 100644 --- a/src/TiledArray/expressions/leaf_engine.h +++ b/src/TiledArray/expressions/leaf_engine.h @@ -127,6 +127,18 @@ class LeafEngine : public ExprEngine { /// This function is a noop since the index list is fixed. void init_indices() {} + /// \return the indices this subtree can supply (Phase E up-pass): for a + /// leaf, simply its annotation (set at construction, valid before init) + const BipartiteIndexList& available_indices() const { return indices_; } + + /// \return the layout this subtree prefers for producing the index set of + /// \p demand: a leaf accepts any ordering (the consumer aligns against the + /// fixed annotation by permutation), so the demand is returned unchanged + const BipartiteIndexList& preferred_layout( + const BipartiteIndexList& demand) const { + return demand; + } + void init_distribution(World* world, const std::shared_ptr& pmap) { ExprEngine_::init_distribution(world, (pmap ? pmap : array_.pmap())); diff --git a/src/TiledArray/expressions/mult_engine.h b/src/TiledArray/expressions/mult_engine.h index d5db1bf093..3f74d4c7cd 100644 --- a/src/TiledArray/expressions/mult_engine.h +++ b/src/TiledArray/expressions/mult_engine.h @@ -281,44 +281,58 @@ class MultEngine : public ContEngine> { /// \param target_indices The target index list for this expression void init_indices(const BipartiteIndexList& target_indices) { - // to decide what type of product this is must initialize indices down - // the tree. - // N.B. since this may be a contraction we do not know the target indices - // for the left and right, hence do target-neutral initialization - BinaryEngine_::left_.init_indices(); - BinaryEngine_::right_.init_indices(); - - // Validate that the (bottom-up resolved) child indices are consistent - // with the target: every outer index of each child must appear in the - // other child or in the target (as a free, fused, or contracted index). - // A violation usually means a *general* product (fused + contracted + - // free indices) appears at an INNER node of the expression tree, where - // the role of a shared index cannot be deduced bottom-up; e.g. in the - // THC-like g("p,q,r,s") = X("p,r1") * X("q,r1") * Z("r1,r2") * ... the - // index r1 is fused in X*X but contracted downstream, while the - // bottom-up convention contracts it in X*X, orphaning the r1 of Z. - // Resolving this requires pushing the needed-index sets down the - // expression tree; until then materialize such inner products into - // explicit intermediates, so that every general product appears as the - // root of its own assignment (where the target determines the index - // roles). + // Deduce each child's index SET top-down (Phase E). The index set an + // inner node must produce depends on what its ancestors need: a shared + // index of the two children is fused iff this node's target carries it, + // contracted here otherwise; an index of one child that neither the + // sibling nor the target carries is consumed entirely within that + // child's subtree and is not demanded of it. + // Each child's demand is ordered target-kept-first (the indices this + // node's result keeps, in target order) followed by the + // contracted-here indices in the child's leaf-availability order -- + // fused indices thus lead, matching the canonical layouts of + // GeneralPermutationOptimizer (h leading) so the nbatch fold stays + // zero-copy. E.g. in the THC-like g("p,q,r,s") = X("p,r1") * X("q,r1") + // * Z("r1,r2") * ... the index r1 is demanded of X*X by its consumer + // (fused there), and contracted where Z meets the X*X subtree. + // N.B. expressions consumed WITHOUT a target (reductions, e.g. dot) + // take the no-target init_indices() overload below, which retains the + // bottom-up contraction convention -- general products under + // reductions remain unsupported. { - auto const& left_outer = outer(BinaryEngine_::left_.indices()); - auto const& right_outer = outer(BinaryEngine_::right_.indices()); - auto const& target_outer = outer(target_indices); - auto validate = [&](const IndexList& a, const IndexList& b) { - for (auto&& idx : a) - if (!b.count(idx) && !target_outer.count(idx)) - TA_EXCEPTION( - "MultEngine: an argument index appears in neither the other " - "argument nor the target. If a general product (fused + " - "contracted + free indices) appears at an inner node of the " - "expression tree, its index roles cannot be deduced " - "bottom-up; materialize it into an explicit intermediate so " - "that it appears as the root of its own assignment"); + auto const avail_l = BinaryEngine_::left_.available_indices(); + auto const avail_r = BinaryEngine_::right_.available_indices(); + auto demand = [](auto const& avail, auto const& sibling, + auto const& tgt) { + container::svector r; + for (auto&& idx : tgt) + if (avail.count(idx)) r.push_back(idx); + for (auto&& idx : avail) { + if (tgt.count(idx)) continue; + if (sibling.count(idx)) r.push_back(idx); + // else: the index is consumed entirely within the child's own + // subtree (contracted deeper down) -- not demanded here. A + // genuinely orphaned index (single occurrence, demanded nowhere = + // implicit trace, unsupported) surfaces as the leaf-level + // target-is-not-a-permutation error. + } + return r; }; - validate(left_outer, right_outer); - validate(right_outer, left_outer); + auto bipartite_demand = [&demand](auto const& avail, auto const& sib, + auto const& tgt) { + auto const out = demand(outer(avail), outer(sib), outer(tgt)); + auto const in = demand(inner(avail), inner(sib), inner(tgt)); + return BipartiteIndexList(IndexList(out.begin(), out.end()), + IndexList(in.begin(), in.end())); + }; + // each child dictates the ORDER of its demand (preferred_layout): a + // product child reorders to its canonical (h, eA, eB) layout -- a + // general product cannot host a result permutation, and contraction + // consumers absorb any child layout via the GEMM transpose forms + BinaryEngine_::left_.init_indices(BinaryEngine_::left_.preferred_layout( + bipartite_demand(avail_l, avail_r, target_indices))); + BinaryEngine_::right_.init_indices(BinaryEngine_::right_.preferred_layout( + bipartite_demand(avail_r, avail_l, target_indices))); } this->product_type_ = compute_product_type( @@ -377,6 +391,40 @@ class MultEngine : public ContEngine> { } } + /// \return the layout this product prefers for producing the index set of + /// \p demand: the canonical general-product result layout (h, eA, eB) -- + /// indices supplied by both children (fused here) lead, then indices + /// supplied by the left child only, then by the right child only, each + /// group in demand order. h-leading is required when this node resolves to + /// a general product (general results cannot host a result permutation and + /// the nbatch fold needs the fused modes leading); for a pure contraction + /// (empty h) the (eA, eB) order matches the GEMM result, and for a pure + /// Hadamard the demand order is preserved. + BipartiteIndexList preferred_layout(const BipartiteIndexList& demand) const { + auto const avail_l = BinaryEngine_::left_.available_indices(); + auto const avail_r = BinaryEngine_::right_.available_indices(); + auto canonical = [](auto const& d, auto const& al, auto const& ar) { + container::svector h, ea, eb; + for (auto&& idx : d) { + bool const inl = al.count(idx); + bool const inr = ar.count(idx); + if (inl && inr) + h.push_back(idx); + else if (inl) + ea.push_back(idx); + else + eb.push_back(idx); + } + h.insert(h.end(), ea.begin(), ea.end()); + h.insert(h.end(), eb.begin(), eb.end()); + return h; + }; + auto const out = canonical(outer(demand), outer(avail_l), outer(avail_r)); + auto const in = canonical(inner(demand), inner(avail_l), inner(avail_r)); + return BipartiteIndexList(IndexList(out.begin(), out.end()), + IndexList(in.begin(), in.end())); + } + /// Initialize result tensor structure /// This function will initialize the permutation, tiled range, and shape diff --git a/src/TiledArray/expressions/unary_engine.h b/src/TiledArray/expressions/unary_engine.h index 631fca8fed..2b34cc716c 100644 --- a/src/TiledArray/expressions/unary_engine.h +++ b/src/TiledArray/expressions/unary_engine.h @@ -120,6 +120,16 @@ class UnaryEngine : ExprEngine { indices_ = arg_.indices(); } + /// \return the indices this subtree can supply (Phase E up-pass): a unary + /// op supplies exactly what its argument supplies (valid before init) + decltype(auto) available_indices() const { return arg_.available_indices(); } + + /// \return the layout this subtree prefers for producing the index set of + /// \p demand: a unary op defers to its argument's preference + decltype(auto) preferred_layout(const BipartiteIndexList& demand) const { + return arg_.preferred_layout(demand); + } + /// Initialize result tensor structure /// This function will initialize the permutation, tiled range, and shape From e9980d186549a036244c4f6840680e3222502bf6 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Thu, 11 Jun 2026 21:03:52 -0400 Subject: [PATCH 13/25] expressions: general products honor arbitrary result layouts via a streaming re-permute A target that differs from the canonical (fused..., left-free..., right-free...) result layout cannot be folded into the batched tile op (BatchedContractReduce must be perm-free); evaluate canonically (Summa over a slab-replicated pmap) and re-permute to the target with a streaming UnaryEvalImpl. Honors the implicit-permute contract: when the consumer fuses the permutation into its own operation (transposed GEMM), only the tile ordinals/trange are remapped and contents stay canonical. Replaces the interleaved-target gate, enabling general products at inner expression-tree nodes and non-canonical root targets. --- src/TiledArray/expressions/cont_engine.h | 95 ++++++++++++++++++++---- 1 file changed, 81 insertions(+), 14 deletions(-) diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index dfab3ee664..d79d3a27db 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -27,6 +27,7 @@ #define TILEDARRAY_EXPRESSIONS_CONT_ENGINE_H__INCLUDED #include +#include #include #include #include @@ -36,6 +37,7 @@ #include #include #include +#include namespace TiledArray { namespace expressions { @@ -179,6 +181,10 @@ class ContEngine : public BinaryEngine { // General (fused + contracted + free indices) products only: unsigned int n_fused_modes_ = 0; ///< # of leading fused (outer) modes size_type n_slabs_ = 1; ///< # of fused-index tile slabs + bool general_repermute_ = false; ///< whether the target layout differs + ///< from the canonical result layout, so + ///< the evaluated result is re-permuted + ///< by a streaming unary eval static unsigned int find(const BipartiteIndexList& indices, const std::string& index_label, unsigned int i, @@ -660,16 +666,13 @@ class ContEngine : public BinaryEngine { TA_ASSERT(nh > 0u); // else this is a pure contraction n_fused_modes_ = nh; - // initialize perm_; an interleaved target (a result permutation that - // mixes fused and free modes) is not yet supported -- the canonical - // result layout must equal the target + // initialize perm_; a target that differs from the canonical (fused..., + // left-free..., right-free...) result layout cannot be folded into the + // batched tile op (BatchedContractReduce must be perm-free), so the + // product is evaluated in its canonical layout and re-permuted to the + // target by a streaming unary eval (see make_dist_eval_general) this->init_perm(target_indices); - if (outer(target_indices) != outer(indices_)) - TA_EXCEPTION( - "general products (fused + contracted + free indices): targets " - "that interleave fused and free indices are not yet supported; " - "reorder the result annotation to (fused..., left-free..., " - "right-free...)"); + general_repermute_ = (outer(target_indices) != outer(indices_)); // the tile op operates on the folded (fused-mode-free) shapes const auto left_op = to_cblas_op(left_outer_permtype_); @@ -712,6 +715,12 @@ class ContEngine : public BinaryEngine { trange_ = make_trange_general(); shape_ = make_shape_general(); + if (general_repermute_) { + // consumers see the target layout; the canonical structures are + // recomputed in make_dist_eval_general for the inner Summa + trange_ = outer(perm_) * trange_; + shape_ = shape_.perm(outer(perm_)); + } if (ExprEngine_::override_ptr_ && ExprEngine_::override_ptr_->shape) { shape_ = shape_.mask(*ExprEngine_::override_ptr_->shape); @@ -834,9 +843,32 @@ class ContEngine : public BinaryEngine { } } + /// Streaming tile re-permute op for general products whose target layout + /// differs from the canonical (fused..., free...) result layout: the + /// batched tile op must stay perm-free, so the consumer-side unary eval + /// applies the result permutation per tile instead + struct GeneralRepermuteOp { + typedef value_type result_type; + typedef value_type argument_type; + static constexpr bool is_consumable = false; + BipartitePermutation perm; + /// false when the consumer fuses the permutation into its own operation + /// (implicit permute, e.g. a transposed GEMM): then only the tile + /// ordinals/trange are remapped (by the host UnaryEvalImpl) and the tile + /// contents are delivered in the canonical layout + bool permute_contents = true; + result_type operator()(const argument_type& tile) const { + if (!permute_contents) return tile; + TiledArray::detail::Noop noop; + return noop(tile, perm); + } + }; + /// Construct the distributed evaluator of a general product - /// \return The batched-Summa distributed evaluator for this expression + /// \return The batched-Summa distributed evaluator for this expression, + /// wrapped in a streaming re-permute when the target layout differs from + /// the canonical result layout dist_eval_type make_dist_eval_general() const { typedef TiledArray::detail::BatchedContractReduce batched_op_type; typedef TiledArray::detail::Summa { typename left_type::dist_eval_type left = left_.make_dist_eval(); typename right_type::dist_eval_type right = right_.make_dist_eval(); - std::shared_ptr pimpl = std::make_shared( - left, right, *world_, trange_, shape_, pmap_, perm_, - batched_op_type(op_, n_fused_modes_), K_, proc_grid_, n_slabs_); + if (!general_repermute_) { + std::shared_ptr pimpl = std::make_shared( + left, right, *world_, trange_, shape_, pmap_, perm_, + batched_op_type(op_, n_fused_modes_), K_, proc_grid_, n_slabs_); + return dist_eval_type(pimpl); + } - return dist_eval_type(pimpl); + // evaluate in the canonical layout (Summa with perm-free op), then + // re-permute tiles to the target layout with a streaming unary eval; + // trange_/shape_ hold the target-layout structures (see + // init_struct_general), the canonical ones are recomputed here + auto const canonical_trange = make_trange_general(); + auto const canonical_shape = [this]() { + auto s = make_shape_general(); + if (ExprEngine_::override_ptr_ && ExprEngine_::override_ptr_->shape) { + // the consumer-supplied mask is expressed in the target layout + auto const inv_perm = outer(perm_).inv(); + s = s.mask(ExprEngine_::override_ptr_->shape->perm(inv_perm)); + } + return s; + }(); + // the inner Summa's result placement must be slab-replicated (the owner + // of a tile independent of its slab index), regardless of the + // (target-layout) pmap the consumer supplied for this node + auto canonical_pmap = std::make_shared( + *world_, proc_grid_.make_pmap(), n_slabs_); + std::shared_ptr pimpl = std::make_shared( + left, right, *world_, canonical_trange, canonical_shape, canonical_pmap, + BipartitePermutation{}, batched_op_type(op_, n_fused_modes_), K_, + proc_grid_, n_slabs_); + dist_eval_type canonical(pimpl); + + typedef TiledArray::detail::UnaryEvalImpl< + dist_eval_type, GeneralRepermuteOp, typename Derived::policy> + repermute_impl_type; + std::shared_ptr wrapper = + std::make_shared( + canonical, *world_, trange_, shape_, pmap_, perm_, + GeneralRepermuteOp{perm_, !this->implicit_permute_outer_}); + return dist_eval_type(wrapper); } /// Expression identification tag From 7c5d1c2b70fcff5b2221ac08a43b557a377ecbaa Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Thu, 11 Jun 2026 21:03:53 -0400 Subject: [PATCH 14/25] tests: inner-node general products (THC, depth-2, non-canonical root target) --- tests/general_product.cpp | 74 +++++++++++++++++++++++++++++++-------- 1 file changed, 59 insertions(+), 15 deletions(-) diff --git a/tests/general_product.cpp b/tests/general_product.cpp index b0d622a31d..8ce13debe5 100644 --- a/tests/general_product.cpp +++ b/tests/general_product.cpp @@ -235,26 +235,70 @@ BOOST_AUTO_TEST_CASE(expression_general_product_dense_batched_outer) { BOOST_CHECK_SMALL(diff_norm(c, c_ref, "b,i,k"), 1e-10); } -BOOST_AUTO_TEST_CASE(expression_general_product_inner_node_gated) { - // THC-style reconstruction: +BOOST_AUTO_TEST_CASE(expression_general_product_noncanonical_root_target) { + // a root-level general product with a NON-canonical target layout: the + // product evaluates canonically (r1,p,q) and is re-permuted to the target + // by the streaming unary eval + auto& world = TA::get_default_world(); + TA::TiledRange tr_x{{0, 2, 4}, {0, 3, 5}}; // orbital x auxiliary + auto x = make_patterned_array(world, tr_x, 1.0); + + TA::TArrayD w, i1, w_ref; + BOOST_REQUIRE_NO_THROW(w("p,q,r1") = x("p,r1") * x("q,r1")); + i1("r1,p,q") = x("p,r1") * x("q,r1"); // canonical evaluation + w_ref("p,q,r1") = i1("r1,p,q"); // plain permute assignment + + BOOST_CHECK_SMALL(diff_norm(w, w_ref, "p,q,r1"), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_inner_node_depth2) { + // minimal inner-node case: a general product (fused r1) feeding a + // contraction over r1 -- the general child evaluates canonically + // (r1,p,q) and is re-permuted on the fly to the consumer's GEMM layout + auto& world = TA::get_default_world(); + TA::TiledRange tr_x{{0, 2, 4}, {0, 3, 5}}; // orbital x auxiliary + TA::TiledRange tr_z{{0, 3, 5}, {0, 3, 5}}; // auxiliary x auxiliary + auto x = make_patterned_array(world, tr_x, 1.0); + auto z = make_patterned_array(world, tr_z, 2.0); + + TA::TArrayD w; + BOOST_REQUIRE_NO_THROW(w("p,q,r2") = (x("p,r1") * x("q,r1")) * z("r1,r2")); + + TA::TArrayD i1, w_ref; + i1("r1,p,q") = x("p,r1") * x("q,r1"); + w_ref("p,q,r2") = i1("r1,p,q") * z("r1,r2"); + + BOOST_CHECK_SMALL(diff_norm(w, w_ref, "p,q,r2"), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_inner_node_thc) { + // THC-style reconstruction in ONE expression: // g("p,q,r,s") = X("p,r1") * X("q,r1") * Z("r1,r2") * X("r,r2") * X("s,r2") - // r1 is fused in X("p,r1") * X("q,r1") but contracted downstream. The - // first product is an INNER node of the expression tree, where the role of - // r1 cannot be deduced bottom-up (the target reaches only the root); - // resolving this requires top-down index-set deduction (deferred). Until - // then: an informative error, not garbage (bottom-up, X*X would contract - // r1, orphaning the r1 of Z). + // r1 is fused in X("p,r1") * X("q,r1") but contracted downstream, so the + // first product is a general product at an INNER node of the (left-deep) + // expression tree. The top-down index-set deduction demands r1 of the X*X + // node (its consumer carries it) and contracts it where Z meets that + // subtree; higher up, r1 is dropped from the demand (consumed entirely + // within). Verified against the same chain staged through explicit + // intermediates (the pre-deduction recipe, itself differential-tested + // against einsum in expression_general_product_thc_intermediates). auto& world = TA::get_default_world(); TA::TiledRange tr_x{{0, 2, 4}, {0, 3, 5}}; // orbital x auxiliary TA::TiledRange tr_z{{0, 3, 5}, {0, 3, 5}}; // auxiliary x auxiliary - TA::TArrayD x(world, tr_x); - TA::TArrayD z(world, tr_z); - x.fill(1.0); - z.fill(1.0); + auto x = make_patterned_array(world, tr_x, 1.0); + auto z = make_patterned_array(world, tr_z, 2.0); + TA::TArrayD g; - BOOST_CHECK_THROW( - g("p,q,r,s") = x("p,r1") * x("q,r1") * z("r1,r2") * x("r,r2") * x("s,r2"), - TiledArray::Exception); + BOOST_REQUIRE_NO_THROW(g("p,q,r,s") = x("p,r1") * x("q,r1") * z("r1,r2") * + x("r,r2") * x("s,r2")); + + TA::TArrayD i1, i2, i3, g_ref; + i1("r1,p,q") = x("p,r1") * x("q,r1"); + i2("p,q,r2") = i1("r1,p,q") * z("r1,r2"); + i3("r2,p,q,r") = i2("p,q,r2") * x("r,r2"); + g_ref("p,q,r,s") = i3("r2,p,q,r") * x("s,r2"); + + BOOST_CHECK_SMALL(diff_norm(g, g_ref, "p,q,r,s"), 1e-10); } BOOST_AUTO_TEST_CASE(expression_general_product_thc_intermediates) { From f8f4090f54e1a5b98b01c867fc545b4c9f2f3807 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Fri, 12 Jun 2026 10:00:28 -0400 Subject: [PATCH 15/25] expressions: address PR review (Phase E) - GeneralRepermuteOp: store/apply only the OUTER result-layout permutation. The streaming re-permute wrapper exists to reorder a general product's outer (result) layout; inner (within-cell) permutation of ToT results is handled separately (init_struct_general / implicit_permute_inner_). Using the full bipartite perm_ would also permute inner cells -- a no-op when the inner perm is identity, but a latent double-apply if an inner perm is ever deferred to a downstream op. Pass outer(perm_) into the op; the host UnaryEvalImpl still receives full perm_ for ordinal/trange remap. - MultEngine down-pass: materialize each child demand as a named lvalue before preferred_layout(), which returns a reference to its argument for leaf/binary engines -- avoids a needlessly fragile bind-to-temporary. --- src/TiledArray/expressions/cont_engine.h | 9 +++++++-- src/TiledArray/expressions/mult_engine.h | 17 ++++++++++++----- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index d79d3a27db..2db08ea23a 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -851,7 +851,12 @@ class ContEngine : public BinaryEngine { typedef value_type result_type; typedef value_type argument_type; static constexpr bool is_consumable = false; - BipartitePermutation perm; + /// Only the *outer* (result-layout) permutation is applied here; inner + /// (within-cell) permutation of tensor-of-tensor results is handled + /// separately (see init_struct_general / implicit_permute_inner_), so this + /// op stores a plain outer Permutation to avoid accidentally permuting + /// inner contents. + Permutation perm; /// false when the consumer fuses the permutation into its own operation /// (implicit permute, e.g. a transposed GEMM): then only the tile /// ordinals/trange are remapped (by the host UnaryEvalImpl) and the tile @@ -917,7 +922,7 @@ class ContEngine : public BinaryEngine { std::shared_ptr wrapper = std::make_shared( canonical, *world_, trange_, shape_, pmap_, perm_, - GeneralRepermuteOp{perm_, !this->implicit_permute_outer_}); + GeneralRepermuteOp{outer(perm_), !this->implicit_permute_outer_}); return dist_eval_type(wrapper); } diff --git a/src/TiledArray/expressions/mult_engine.h b/src/TiledArray/expressions/mult_engine.h index 3f74d4c7cd..3e345b15e5 100644 --- a/src/TiledArray/expressions/mult_engine.h +++ b/src/TiledArray/expressions/mult_engine.h @@ -328,11 +328,18 @@ class MultEngine : public ContEngine> { // each child dictates the ORDER of its demand (preferred_layout): a // product child reorders to its canonical (h, eA, eB) layout -- a // general product cannot host a result permutation, and contraction - // consumers absorb any child layout via the GEMM transpose forms - BinaryEngine_::left_.init_indices(BinaryEngine_::left_.preferred_layout( - bipartite_demand(avail_l, avail_r, target_indices))); - BinaryEngine_::right_.init_indices(BinaryEngine_::right_.preferred_layout( - bipartite_demand(avail_r, avail_l, target_indices))); + // consumers absorb any child layout via the GEMM transpose forms. + // Materialize the demands as named lvalues: some preferred_layout() + // overloads (leaf/binary/unary) return a reference to their argument, so + // binding that to a temporary demand would be needlessly fragile. + const BipartiteIndexList left_demand = + bipartite_demand(avail_l, avail_r, target_indices); + const BipartiteIndexList right_demand = + bipartite_demand(avail_r, avail_l, target_indices); + BinaryEngine_::left_.init_indices( + BinaryEngine_::left_.preferred_layout(left_demand)); + BinaryEngine_::right_.init_indices( + BinaryEngine_::right_.preferred_layout(right_demand)); } this->product_type_ = compute_product_type( From 3cb0a3ba9467dfaf5f638371522567e343e1d946 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Thu, 11 Jun 2026 21:51:23 -0400 Subject: [PATCH 16/25] expressions: ScalMultEngine evaluates general products; share the tree-deduction down-pass The Phase E child-demand deduction moves from MultEngine into BinaryEngine::init_children_indices and ScalMultEngine adopts it, along with the full MultEngine routing for general products (inner_product_type_ classification + inner-General gate, init_struct_general, init_distribution_general, make_trange_general, make_dist_eval_general), replacing its use-einsum-instead exception. --- src/TiledArray/expressions/binary_engine.h | 46 +++++++++++ src/TiledArray/expressions/mult_engine.h | 94 ++++++++-------------- 2 files changed, 79 insertions(+), 61 deletions(-) diff --git a/src/TiledArray/expressions/binary_engine.h b/src/TiledArray/expressions/binary_engine.h index c3d2e69da5..856170803f 100644 --- a/src/TiledArray/expressions/binary_engine.h +++ b/src/TiledArray/expressions/binary_engine.h @@ -258,6 +258,52 @@ class BinaryEngine : public ExprEngine { return demand; } + /// Deduce each child's index SET top-down from this node's target and + /// initialize the children with the deduced demands (the product-engine + /// down-pass). The index set an inner node must produce depends on what + /// its ancestors need: a shared index of the two children is fused iff + /// this node's target carries it, contracted here otherwise; an index of + /// one child that neither the sibling nor the target carries is consumed + /// entirely within that child's subtree and is not demanded of it (a + /// genuinely orphaned index -- an implicit trace, unsupported -- surfaces + /// as the leaf-level target-is-not-a-permutation error). Each child's + /// demand is ordered target-kept-first followed by the contracted-here + /// indices in leaf-availability order, then reordered by the child's own + /// preferred_layout() (a product child reorders to its canonical + /// (fused, left-free, right-free) layout -- a general product cannot host + /// a result permutation, and contraction consumers absorb any child + /// layout via the GEMM transpose forms). + void init_children_indices(const BipartiteIndexList& target_indices) { + auto const avail_l = left_.available_indices(); + auto const avail_r = right_.available_indices(); + auto demand = [](auto const& avail, auto const& sibling, auto const& tgt) { + container::svector r; + for (auto&& idx : tgt) + if (avail.count(idx)) r.push_back(idx); + for (auto&& idx : avail) { + if (tgt.count(idx)) continue; + if (sibling.count(idx)) r.push_back(idx); + } + return r; + }; + auto bipartite_demand = [&demand](auto const& avail, auto const& sib, + auto const& tgt) { + auto const out = demand(outer(avail), outer(sib), outer(tgt)); + auto const in = demand(inner(avail), inner(sib), inner(tgt)); + return BipartiteIndexList(IndexList(out.begin(), out.end()), + IndexList(in.begin(), in.end())); + }; + // Materialize the demands as named lvalues: some preferred_layout() + // overloads (leaf/binary/unary) return a reference to their argument, so + // binding that to a temporary demand would be needlessly fragile. + const BipartiteIndexList left_demand = + bipartite_demand(avail_l, avail_r, target_indices); + const BipartiteIndexList right_demand = + bipartite_demand(avail_r, avail_l, target_indices); + left_.init_indices(left_.preferred_layout(left_demand)); + right_.init_indices(right_.preferred_layout(right_demand)); + } + /// Initialize result tensor structure /// This function will initialize the permutation, tiled range, and shape diff --git a/src/TiledArray/expressions/mult_engine.h b/src/TiledArray/expressions/mult_engine.h index 3e345b15e5..1ad68119c5 100644 --- a/src/TiledArray/expressions/mult_engine.h +++ b/src/TiledArray/expressions/mult_engine.h @@ -299,48 +299,7 @@ class MultEngine : public ContEngine> { // take the no-target init_indices() overload below, which retains the // bottom-up contraction convention -- general products under // reductions remain unsupported. - { - auto const avail_l = BinaryEngine_::left_.available_indices(); - auto const avail_r = BinaryEngine_::right_.available_indices(); - auto demand = [](auto const& avail, auto const& sibling, - auto const& tgt) { - container::svector r; - for (auto&& idx : tgt) - if (avail.count(idx)) r.push_back(idx); - for (auto&& idx : avail) { - if (tgt.count(idx)) continue; - if (sibling.count(idx)) r.push_back(idx); - // else: the index is consumed entirely within the child's own - // subtree (contracted deeper down) -- not demanded here. A - // genuinely orphaned index (single occurrence, demanded nowhere = - // implicit trace, unsupported) surfaces as the leaf-level - // target-is-not-a-permutation error. - } - return r; - }; - auto bipartite_demand = [&demand](auto const& avail, auto const& sib, - auto const& tgt) { - auto const out = demand(outer(avail), outer(sib), outer(tgt)); - auto const in = demand(inner(avail), inner(sib), inner(tgt)); - return BipartiteIndexList(IndexList(out.begin(), out.end()), - IndexList(in.begin(), in.end())); - }; - // each child dictates the ORDER of its demand (preferred_layout): a - // product child reorders to its canonical (h, eA, eB) layout -- a - // general product cannot host a result permutation, and contraction - // consumers absorb any child layout via the GEMM transpose forms. - // Materialize the demands as named lvalues: some preferred_layout() - // overloads (leaf/binary/unary) return a reference to their argument, so - // binding that to a temporary demand would be needlessly fragile. - const BipartiteIndexList left_demand = - bipartite_demand(avail_l, avail_r, target_indices); - const BipartiteIndexList right_demand = - bipartite_demand(avail_r, avail_l, target_indices); - BinaryEngine_::left_.init_indices( - BinaryEngine_::left_.preferred_layout(left_demand)); - BinaryEngine_::right_.init_indices( - BinaryEngine_::right_.preferred_layout(right_demand)); - } + BinaryEngine_::init_children_indices(target_indices); this->product_type_ = compute_product_type( outer(BinaryEngine_::left_.indices()), @@ -713,29 +672,34 @@ class ScalMultEngine /// \param target_indices The target index list for this expression void init_indices(const BipartiteIndexList& target_indices) { - BinaryEngine_::left_.init_indices(); - BinaryEngine_::right_.init_indices(); + // deduce the children's index sets top-down (see + // BinaryEngine::init_children_indices), then classify and route exactly + // as MultEngine does (the scalar factor does not affect index roles) + BinaryEngine_::init_children_indices(target_indices); + this->product_type_ = compute_product_type( outer(BinaryEngine_::left_.indices()), outer(BinaryEngine_::right_.indices()), outer(target_indices)); + this->inner_product_type_ = compute_product_type( + inner(BinaryEngine_::left_.indices()), + inner(BinaryEngine_::right_.indices()), inner(target_indices)); + + if (this->inner_product_type_ == TensorProduct::General) + TA_EXCEPTION( + "ScalMultEngine: general products (fused + contracted + free " + "indices) between the inner (nested) indices of tensors-of-tensors " + "are not supported"); if (this->product_type() == TensorProduct::Hadamard) { // since already initialized left and right arg indices assign the target // indices BinaryEngine_::perm_indices(target_indices); } else if (this->product_type() == TensorProduct::General) { - // layout via GeneralPermutationOptimizer (the target determines which - // shared indices are fused vs contracted), then propagate to children - if (!this->implicit_permute()) { - BinaryEngine_::template init_indices_( - target_indices); - if (BinaryEngine_::left_indices_ != BinaryEngine_::left_.indices()) - BinaryEngine_::left_.perm_indices(BinaryEngine_::left_indices_); - if (BinaryEngine_::right_indices_ != BinaryEngine_::right_.indices()) - BinaryEngine_::right_.perm_indices(BinaryEngine_::right_indices_); - } + this->perm_indices(target_indices); } else { - ContEngine_::init_indices(target_indices); + auto children_initialized = true; + ContEngine_::init_indices(children_initialized); + ContEngine_::perm_indices(target_indices); } } @@ -767,12 +731,14 @@ class ScalMultEngine /// for the result tensor. /// \param target_indices The target index list for the result tensor void init_struct(const BipartiteIndexList& target_indices) { - // TODO Phase B (batched Summa): evaluate general products natively - if (this->product_type() == TensorProduct::General) - TA_EXCEPTION( - "ScalMultEngine: evaluation of general products (fused + contracted " - "+ free indices) via the expression layer is not yet implemented; " - "use TiledArray::einsum() instead"); + if (this->product_type() == TensorProduct::General) { + // the inner tile op (for tensors-of-tensors) must be initialized + // first; init_struct_general consumes element_nonreturn_op_ and the + // arena plan it builds + this->init_inner_tile_op(inner(target_indices)); + ContEngine_::init_struct_general(target_indices); + return; + } this->init_perm(target_indices); @@ -794,6 +760,8 @@ class ScalMultEngine std::shared_ptr pmap) { if (this->product_type() == TensorProduct::Contraction) ContEngine_::init_distribution(world, pmap); + else if (this->product_type() == TensorProduct::General) + ContEngine_::init_distribution_general(world, pmap); else BinaryEngine_::init_distribution(world, pmap); } @@ -804,6 +772,8 @@ class ScalMultEngine dist_eval_type make_dist_eval() const { if (this->product_type() == TensorProduct::Contraction) return ContEngine_::make_dist_eval(); + else if (this->product_type() == TensorProduct::General) + return ContEngine_::make_dist_eval_general(); else return BinaryEngine_::make_dist_eval(); } @@ -814,6 +784,8 @@ class ScalMultEngine trange_type make_trange() const { if (this->product_type() == TensorProduct::Contraction) return ContEngine_::make_trange(); + else if (this->product_type() == TensorProduct::General) + return ContEngine_::make_trange_general(); else return BinaryEngine_::make_trange(); } From ce633adba2d39764863b88d79c488e4aca69a76f Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Thu, 11 Jun 2026 21:52:14 -0400 Subject: [PATCH 17/25] cont_engine: identity-tolerant inner-perm gate; absorb the scalar prefactor in inner-Scale ops The general-product ToT gate fired on a non-null but IDENTITY inner permutation (the bipartite perm is constructed whole when only the outer modes are re-permuted by the streaming wrapper); require a genuinely non-identity inner perm. The inner-Scale element ops (mixed T x ToT) never carried the expression-level scalar prefactor -- invisible while only MultEngine (factor == 1) reached them; the fallback op now absorbs factor_ and the factor-free fused arena ops are gated to factor == 1. --- src/TiledArray/expressions/cont_engine.h | 98 +++++++++++++----------- 1 file changed, 54 insertions(+), 44 deletions(-) diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index 2db08ea23a..8323d41259 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -683,10 +683,13 @@ class ContEngine : public BinaryEngine { outer_size(right_indices_) - nh); } else { // the batched tile op must be perm-free (BatchedContractReduce cannot - // host the folded-rank result permutation); the outer perm is empty by - // the interleaved-target gate above, so only an explicit inner result - // permutation can require one - if (!implicit_permute_inner_ && bool(inner(perm_))) + // host the folded-rank result permutation); the outer perm is handled + // by the streaming re-permute (general_repermute_), so only a genuine + // (non-identity) explicit inner result permutation requires one. N.B. + // perm_ may carry a non-null identity inner component when only the + // outer modes are permuted (the bipartite perm is constructed whole). + if (!implicit_permute_inner_ && bool(inner(perm_)) && + !inner(perm_).is_identity()) TA_EXCEPTION( "general products of tensors-of-tensors: a non-identity inner " "result permutation is not yet supported; reorder the inner " @@ -1707,7 +1710,11 @@ class ContEngine : public BinaryEngine { TiledArray::detail::is_contraction_arena_tot_v< result_tile_type, left_tile_type, right_tile_type>; if constexpr (arena_eligible_scale) { - if (this->outer_product_uses_summa()) { + // the fused arena scale ops are factor-free; a non-unit + // expression-level prefactor (ScalMult) takes the fallback op, + // which absorbs it + if (this->outer_product_uses_summa() && + this->factor_ == scalar_type(1)) { // The inner perm handed to the plan must match how the inner // *result* permutation is applied for this result cell type -- // and the two cell types apply it in different places: @@ -1741,45 +1748,48 @@ class ContEngine : public BinaryEngine { // cells. The Hadamard outer product is an assignment // `result = (perm ^ tot) * scalar`, which needs value-returning // `scale`; only owning inner cells support it. - auto fallback_op = [perm = !this->implicit_permute_inner_ - ? inner(this->perm_) - : Permutation{}, - outer_uses_summa = - this->outer_product_uses_summa()]( - result_tile_element_type& result, - const left_tile_element_type& left, - const right_tile_element_type& right) { - if (outer_uses_summa) { - using TiledArray::axpy_to; - if constexpr (tot_x_t) { - if (left.empty()) return; // absent cell: no contribution - if (perm) - axpy_to(result, left, right, perm); - else - axpy_to(result, left, right); - } else { - if (right.empty()) return; // absent cell: no contribution - if (perm) - axpy_to(result, right, left, perm); - else - axpy_to(result, right, left); - } - } else { - if constexpr (!TiledArray::is_tensor_view_v< - result_tile_element_type>) { - using TiledArray::scale; - if constexpr (tot_x_t) - result = perm ? scale(left, right, perm) : scale(left, right); - else - result = perm ? scale(right, left, perm) : scale(right, left); - } else { - TA_EXCEPTION( - "Tensor scale-inner Hadamard-outer product: a " - "view result cell cannot be value-assigned a fresh " - "scaled tensor"); - } - } - }; + // N.B. the expression-level scalar prefactor (factor_, != 1 for + // ScalMult expressions) multiplies the plain operand's element + auto fallback_op = + [perm = !this->implicit_permute_inner_ ? inner(this->perm_) + : Permutation{}, + outer_uses_summa = this->outer_product_uses_summa(), + factor = this->factor_](result_tile_element_type& result, + const left_tile_element_type& left, + const right_tile_element_type& right) { + if (outer_uses_summa) { + using TiledArray::axpy_to; + if constexpr (tot_x_t) { + if (left.empty()) return; // absent cell: no contribution + if (perm) + axpy_to(result, left, right * factor, perm); + else + axpy_to(result, left, right * factor); + } else { + if (right.empty()) return; // absent cell: no contribution + if (perm) + axpy_to(result, right, left * factor, perm); + else + axpy_to(result, right, left * factor); + } + } else { + if constexpr (!TiledArray::is_tensor_view_v< + result_tile_element_type>) { + using TiledArray::scale; + if constexpr (tot_x_t) + result = perm ? scale(left, right * factor, perm) + : scale(left, right * factor); + else + result = perm ? scale(right, left * factor, perm) + : scale(right, left * factor); + } else { + TA_EXCEPTION( + "Tensor scale-inner Hadamard-outer product: a " + "view result cell cannot be value-assigned a fresh " + "scaled tensor"); + } + } + }; if constexpr (arena_eligible_scale) { if (this->arena_plan_) { if constexpr (tot_x_t) From e3f8e368714ca0cd70f3f17556c09a9f6d8473d9 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Thu, 11 Jun 2026 21:52:14 -0400 Subject: [PATCH 18/25] tests: mixed T x ToT at inner tree nodes (depth-2 chains, inner general, scaled) --- tests/general_product.cpp | 69 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/tests/general_product.cpp b/tests/general_product.cpp index 8ce13debe5..498c60beec 100644 --- a/tests/general_product.cpp +++ b/tests/general_product.cpp @@ -528,6 +528,75 @@ BOOST_AUTO_TEST_CASE(expression_general_product_t_times_tot) { BOOST_CHECK_SMALL(tot_max_abs_diff(c, c_ref), 1e-10); } +BOOST_AUTO_TEST_CASE(expression_mixed_t_tot_depth2_chains) { + // mixed T/ToT products at INNER nodes of the expression tree: + // left-deep: w("i,k;x") = (s("i,j") * t("j,m")) * c("m,k;x") + // right-deep: w("i,k;x") = s("i,j") * (t("j,m") * c("m,k;x")) + // (plain contraction at every node, inner Scale where a plain factor + // meets the nested one) + auto& world = TA::get_default_world(); + TA::TiledRange tr_s{{0, 2, 4}, {0, 2, 5}}; // i, j + TA::TiledRange tr_t{{0, 2, 5}, {0, 3, 4}}; // j, m + TA::TiledRange tr_c{{0, 3, 4}, {0, 2, 3}}; // m, k + auto s = make_patterned_array(world, tr_s, 1.0); + auto t = make_patterned_array(world, tr_t, 2.0); + auto c = make_patterned_tot_array(world, tr_c, {3}, 3.0); + + // staged reference + TArrayToT i1, w_ref; + i1("j,k;x") = t("j,m") * c("m,k;x"); + w_ref("i,k;x") = s("i,j") * i1("j,k;x"); + + TArrayToT w_l, w_r; + BOOST_REQUIRE_NO_THROW(w_l("i,k;x") = (s("i,j") * t("j,m")) * c("m,k;x")); + BOOST_CHECK_SMALL(tot_max_abs_diff(w_l, w_ref), 1e-10); + BOOST_REQUIRE_NO_THROW(w_r("i,k;x") = s("i,j") * (t("j,m") * c("m,k;x"))); + BOOST_CHECK_SMALL(tot_max_abs_diff(w_r, w_ref), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_mixed_t_tot_inner_general) { + // a mixed T x ToT GENERAL product at an INNER node: + // w("i,j;x") = (g("b,i") * c("b,j;x")) * h("b") + // b is fused where g meets c (demanded by the h factor above) and + // contracted at the root + auto& world = TA::get_default_world(); + TA::TiledRange tr_g{{0, 3, 5}, {0, 2, 4}}; // b, i + TA::TiledRange tr_c{{0, 3, 5}, {0, 2, 3}}; // b, j + TA::TiledRange tr_h{{0, 3, 5}}; // b + auto g = make_patterned_array(world, tr_g, 1.0); + auto c = make_patterned_tot_array(world, tr_c, {3}, 2.0); + TA::TArrayD h(world, tr_h); + h.fill(1.5); + + TArrayToT i1, w_ref; + i1("b,i,j;x") = g("b,i") * c("b,j;x"); // depth-1 mixed general + w_ref("i,j;x") = i1("b,i,j;x") * h("b"); + + TArrayToT w; + try { + w("i,j;x") = (g("b,i") * c("b,j;x")) * h("b"); + } catch (std::exception& e) { + BOOST_FAIL(std::string("threw: ") + e.what()); + } + BOOST_CHECK_SMALL(tot_max_abs_diff(w, w_ref), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_mixed_t_tot_scaled) { + // scaled mixed T x ToT general product (ScalMultEngine path): + // w("b,i,k;x") = 2 * (a("b,i,j") * c("b,j,k;x")) + auto& world = TA::get_default_world(); + TA::TiledRange tr_a{{0, 2, 4}, {0, 2, 3}, {0, 2, 5}}; // b, i, j + TA::TiledRange tr_c{{0, 2, 4}, {0, 2, 5}, {0, 3, 4}}; // b, j, k + auto a = make_patterned_array(world, tr_a, 1.0); + auto c = make_patterned_tot_array(world, tr_c, {3}, 2.0); + + TArrayToT w_ref0, w_ref, w; + w_ref0("b,i,k;x") = a("b,i,j") * c("b,j,k;x"); + w_ref("b,i,k;x") = 2.0 * w_ref0("b,i,k;x"); + BOOST_REQUIRE_NO_THROW(w("b,i,k;x") = 2.0 * (a("b,i,j") * c("b,j,k;x"))); + BOOST_CHECK_SMALL(tot_max_abs_diff(w, w_ref), 1e-10); +} + BOOST_AUTO_TEST_CASE(expression_general_product_tot_inner_outer_product) { // the PNO-CC PPL building-block shape: ToT x ToT with an EMPTY right // outer-external set and an inner OUTER-product: From 07f5b9759c0a16a3045d87115b67cb24ec67cf2a Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Thu, 11 Jun 2026 22:37:53 -0400 Subject: [PATCH 19/25] tests: general products with block expressions (block operands; assignment into a block view) --- tests/general_product.cpp | 42 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/general_product.cpp b/tests/general_product.cpp index 498c60beec..f7372c4079 100644 --- a/tests/general_product.cpp +++ b/tests/general_product.cpp @@ -597,6 +597,48 @@ BOOST_AUTO_TEST_CASE(expression_mixed_t_tot_scaled) { BOOST_CHECK_SMALL(tot_max_abs_diff(w, w_ref), 1e-10); } +BOOST_AUTO_TEST_CASE(expression_general_product_block_operands) { + // general product of BLOCK views: C("b,i,k") = A.block * B.block with + // b fused, j contracted, i/k free; the blocks restrict b and i/k + auto& world = TA::get_default_world(); + TA::TiledRange tr{{0, 2, 4}, {0, 2, 4}, {0, 2, 4}}; // (b,i,j) / (b,j,k) + auto a = make_patterned_array(world, tr, 1.0); + auto t = make_patterned_array(world, tr, 2.0); + + // materialized blocks as the reference operands + TA::TArrayD ab, tb, w_ref, w; + ab("b,i,j") = a("b,i,j").block({0, 0, 0}, {1, 2, 2}); + tb("b,j,k") = t("b,j,k").block({0, 0, 0}, {1, 2, 1}); + w_ref("b,i,k") = ab("b,i,j") * tb("b,j,k"); + + BOOST_REQUIRE_NO_THROW(w("b,i,k") = a("b,i,j").block({0, 0, 0}, {1, 2, 2}) * + t("b,j,k").block({0, 0, 0}, {1, 2, 1})); + BOOST_CHECK_SMALL(diff_norm(w, w_ref, "b,i,k"), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_into_block) { + // general product assigned INTO a block view of the result: + // W.block = A * B (b fused, j contracted) + auto& world = TA::get_default_world(); + TA::TiledRange tr_a{{0, 2}, {0, 2}, {0, 2, 4}}; // b, i, j + TA::TiledRange tr_b{{0, 2}, {0, 2, 4}, {0, 2}}; // b, j, k + TA::TiledRange tr_w{{0, 2, 4}, {0, 2, 4}, {0, 2, 4}}; + auto a = make_patterned_array(world, tr_a, 1.0); + auto t = make_patterned_array(world, tr_b, 2.0); + + TA::TArrayD prod; + prod("b,i,k") = a("b,i,j") * t("b,j,k"); + + TA::TArrayD w(world, tr_w), w_ref(world, tr_w); + w.fill(0.0); + w_ref.fill(0.0); + w_ref("b,i,k").block({0, 0, 0}, {1, 1, 1}) = prod("b,i,k"); + + BOOST_REQUIRE_NO_THROW(w("b,i,k").block({0, 0, 0}, {1, 1, 1}) = + a("b,i,j") * t("b,j,k")); + BOOST_CHECK_SMALL(diff_norm(w, w_ref, "b,i,k"), 1e-10); +} + BOOST_AUTO_TEST_CASE(expression_general_product_tot_inner_outer_product) { // the PNO-CC PPL building-block shape: ToT x ToT with an EMPTY right // outer-external set and an inner OUTER-product: From 0abc2dca45443718039ee867dacaf9f18f2937b7 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Thu, 11 Jun 2026 23:40:44 -0400 Subject: [PATCH 20/25] expressions: gate no-external ToT general products; tests for sums under products, kitchen-sink, blocks in trees A ToT x ToT general product with no external (free) outer indices -- every outer index fused or contracted -- segfaulted in the folded GEMM; gate it with an informative error (einsum() evaluates this shape natively via its no-external regime). New tests: a SUM nested under a product with a general summand (the down-pass prunes summand-internal contraction indices from the sum's demand by construction); the kitchen-sink expression combining a THC-like batching index, a mixed T x ToT general product, a ToT x ToT general product with an inner outer-product, and a ScalMult prefactor; a block leaf under an inner general node; a re-permuted general product assigned into a block view; the no-external gate. --- src/TiledArray/expressions/cont_engine.h | 13 +++ tests/general_product.cpp | 108 +++++++++++++++++++++++ 2 files changed, 121 insertions(+) diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index 8323d41259..2c9cbbd7a8 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -666,6 +666,19 @@ class ContEngine : public BinaryEngine { TA_ASSERT(nh > 0u); // else this is a pure contraction n_fused_modes_ = nh; + // a general product of tensors-of-tensors with NO external (free) outer + // indices (every outer index fused or contracted, e.g. + // C("i,j;a,b") = A("x,i,j;a") * B("x,i,j;b")) is not supported by the + // batched tile op yet (its folded GEMM has no free modes); + // einsum() evaluates this shape natively + if constexpr (TiledArray::detail::is_tensor_of_tensor_v) { + if (outer_size(indices_) == nh) + TA_EXCEPTION( + "general products of tensors-of-tensors without external (free) " + "outer indices are not yet supported in the expression layer; " + "use TiledArray::einsum() for this contraction"); + } + // initialize perm_; a target that differs from the canonical (fused..., // left-free..., right-free...) result layout cannot be folded into the // batched tile op (BatchedContractReduce must be perm-free), so the diff --git a/tests/general_product.cpp b/tests/general_product.cpp index f7372c4079..2fd500311a 100644 --- a/tests/general_product.cpp +++ b/tests/general_product.cpp @@ -639,6 +639,114 @@ BOOST_AUTO_TEST_CASE(expression_general_product_into_block) { BOOST_CHECK_SMALL(diff_norm(w, w_ref, "b,i,k"), 1e-10); } +BOOST_AUTO_TEST_CASE(expression_general_product_block_in_tree) { + // a BLOCK leaf under an inner general node of a deeper tree: + // w("p,q,r2") = (x.block("p,r1") * y("q,r1")) * z("r1,r2") + // r1 is fused at the inner (general) node, contracted at the root + auto& world = TA::get_default_world(); + TA::TiledRange tr_x{{0, 2, 4}, {0, 3, 5}}; // p x r1 + TA::TiledRange tr_z{{0, 3, 5}, {0, 3, 5}}; // r1 x r2 + auto x = make_patterned_array(world, tr_x, 1.0); + auto y = make_patterned_array(world, tr_x, 1.5); + auto z = make_patterned_array(world, tr_z, 2.0); + + TA::TArrayD xb, i1, w_ref, w; + xb("p,r1") = x("p,r1").block({0, 0}, {1, 2}); + i1("r1,p,q") = xb("p,r1") * y("q,r1"); + w_ref("p,q,r2") = i1("r1,p,q") * z("r1,r2"); + + BOOST_REQUIRE_NO_THROW( + w("p,q,r2") = (x("p,r1").block({0, 0}, {1, 2}) * y("q,r1")) * z("r1,r2")); + BOOST_CHECK_SMALL(diff_norm(w, w_ref, "p,q,r2"), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_repermute_into_block) { + // a general product with a NON-canonical target layout (streaming + // re-permute) assigned INTO a block view of the result + auto& world = TA::get_default_world(); + TA::TiledRange tr_x{{0, 2}, {0, 3, 5}}; // p x r1 + TA::TiledRange tr_w{{0, 2, 4}, {0, 2, 4}, {0, 3, 5}}; // p, q, r1 + auto x = make_patterned_array(world, tr_x, 1.0); + auto y = make_patterned_array(world, tr_x, 1.5); + + TA::TArrayD i1, w(world, tr_w), w_ref(world, tr_w); + w.fill(0.0); + w_ref.fill(0.0); + i1("r1,p,q") = x("p,r1") * y("q,r1"); // canonical evaluation + w_ref("p,q,r1").block({0, 0, 0}, {1, 1, 2}) = i1("r1,p,q"); + + BOOST_REQUIRE_NO_THROW(w("p,q,r1").block({0, 0, 0}, {1, 1, 2}) = + x("p,r1") * y("q,r1")); + BOOST_CHECK_SMALL(diff_norm(w, w_ref, "p,q,r1"), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_sum_under_product) { + // a SUM nested under a product, with a general product as one summand: + // F("i,j") = A("x,i") * (B("x,k") * C("x,k,j") + D("x,j")) + // x is fused inside B*C (demanded by the sum's consumer), k is contracted + // within the summand (never escapes), and x is contracted at the root. + // The down-pass prunes k from the sum's demand automatically (it appears + // in neither the target nor A) and hands (j,x) to BOTH summands. + auto& world = TA::get_default_world(); + TA::TiledRange tr_a{{0, 3, 5}, {0, 2, 4}}; // x, i + TA::TiledRange tr_b{{0, 3, 5}, {0, 2, 3}}; // x, k + TA::TiledRange tr_c{{0, 3, 5}, {0, 2, 3}, {0, 2, 4}}; // x, k, j + TA::TiledRange tr_d{{0, 3, 5}, {0, 2, 4}}; // x, j + auto a = make_patterned_array(world, tr_a, 1.0); + auto b = make_patterned_array(world, tr_b, 2.0); + auto c = make_patterned_array(world, tr_c, 3.0); + auto d = make_patterned_array(world, tr_d, 4.0); + + TA::TArrayD i1, s, f_ref, f; + i1("x,j") = b("x,k") * c("x,k,j"); // general: x fused, k contracted + s("x,j") = i1("x,j") + d("x,j"); + f_ref("i,j") = a("x,i") * s("x,j"); + + BOOST_REQUIRE_NO_THROW(f("i,j") = + a("x,i") * (b("x,k") * c("x,k,j") + d("x,j"))); + BOOST_CHECK_SMALL(diff_norm(f, f_ref, "i,j"), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_kitchen_sink) { + // one expression combining a THC-like batching index (x: fused at the + // inner node, contracted at the root), a mixed T x ToT general product, + // a ToT x ToT general product with an inner outer-product, and a scalar + // prefactor (ScalMult at the root): + // W("i,j,m;a,b") = 2 * ((g("x,i") * cv("x,j;a")) * dv("x,i,m;b")) + auto& world = TA::get_default_world(); + TA::TiledRange tr_g{{0, 3, 5}, {0, 2, 4}}; // x, i + TA::TiledRange tr_cv{{0, 3, 5}, {0, 2, 3}}; // x, j + TA::TiledRange tr_dv{{0, 3, 5}, {0, 2, 4}, {0, 2, 3}}; // x, i, m + auto g = make_patterned_array(world, tr_g, 1.0); + auto cv = make_patterned_tot_array(world, tr_cv, {2}, 2.0); + auto dv = make_patterned_tot_array(world, tr_dv, {3}, 3.0); + + TArrayToT i1, i2, w_ref, w; + i1("x,i,j;a") = g("x,i") * cv("x,j;a"); // T x ToT general + i2("i,j,m;a,b") = i1("x,i,j;a") * dv("x,i,m;b"); // ToT x ToT, inner outer + w_ref("i,j,m;a,b") = 2.0 * i2("i,j,m;a,b"); + + BOOST_REQUIRE_NO_THROW(w("i,j,m;a,b") = + 2.0 * ((g("x,i") * cv("x,j;a")) * dv("x,i,m;b"))); + BOOST_CHECK_SMALL(tot_max_abs_diff(w, w_ref), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_tot_no_externals_gated) { + // a ToT x ToT general product with NO external (free) outer indices -- + // every outer index fused or contracted -- is not supported by the + // batched tile op (and used to segfault in the folded GEMM); the engine + // must reject it with an informative error (einsum() handles this shape + // natively via its no-external regime) + auto& world = TA::get_default_world(); + TA::TiledRange tr{{0, 3, 5}, {0, 2, 4}, {0, 2, 3}}; // x, i, j + auto a = make_patterned_tot_array(world, tr, {2}, 1.0); + auto b = make_patterned_tot_array(world, tr, {3}, 2.0); + + TArrayToT c; + BOOST_CHECK_THROW(c("i,j;a,b") = a("x,i,j;a") * b("x,i,j;b"), + TiledArray::Exception); +} + BOOST_AUTO_TEST_CASE(expression_general_product_tot_inner_outer_product) { // the PNO-CC PPL building-block shape: ToT x ToT with an EMPTY right // outer-external set and an inner OUTER-product: From 3f00dae7b0835959ff7f520579799638abe4f265 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Fri, 12 Jun 2026 00:18:58 -0400 Subject: [PATCH 21/25] expressions: native no-external general products via a synthetic unit left-external mode A general product whose every outer index is fused or contracted (e.g. C("i,j;a,b") = A("x,i,j;a") * B("x,i,j;b")) folds to a GEMM with no free modes, i.e. rank-0 tensors, which the tile kernels do not support (this shape used to segfault through wild stride reads). Evaluate it with a synthetic unit left-external mode instead: the folded product becomes (1,K) x (K) -> (1), the exact shape of the already-supported one-sided neB == 0 case. The unit mode lives only in the tile op's GemmHelper; tranges, shapes and tiles carry the true (external-free) ranks, and BatchedContractReduce / SparseShape::gemm_batched detect the synthetic mode from the one-rank mismatch and pad their folded views with a unit extent. Replaces the interim gate. Tests: dense ToT (incl. the no-external root fed by a general T x ToT inner node), plain dense (the Hadamard-reduction shape), and block-sparse (exercising the gemm_batched unit handling), all differential-tested against legacy einsum. --- src/TiledArray/expressions/cont_engine.h | 51 ++++++++------ src/TiledArray/sparse_shape.h | 68 ++++++++++++------- .../tile_op/batched_contract_reduce.h | 26 +++++-- tests/general_product.cpp | 58 ++++++++++++++-- 4 files changed, 144 insertions(+), 59 deletions(-) diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index 2c9cbbd7a8..fdf5b0da82 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -666,19 +666,6 @@ class ContEngine : public BinaryEngine { TA_ASSERT(nh > 0u); // else this is a pure contraction n_fused_modes_ = nh; - // a general product of tensors-of-tensors with NO external (free) outer - // indices (every outer index fused or contracted, e.g. - // C("i,j;a,b") = A("x,i,j;a") * B("x,i,j;b")) is not supported by the - // batched tile op yet (its folded GEMM has no free modes); - // einsum() evaluates this shape natively - if constexpr (TiledArray::detail::is_tensor_of_tensor_v) { - if (outer_size(indices_) == nh) - TA_EXCEPTION( - "general products of tensors-of-tensors without external (free) " - "outer indices are not yet supported in the expression layer; " - "use TiledArray::einsum() for this contraction"); - } - // initialize perm_; a target that differs from the canonical (fused..., // left-free..., right-free...) result layout cannot be folded into the // batched tile op (BatchedContractReduce must be perm-free), so the @@ -687,12 +674,27 @@ class ContEngine : public BinaryEngine { this->init_perm(target_indices); general_repermute_ = (outer(target_indices) != outer(indices_)); - // the tile op operates on the folded (fused-mode-free) shapes - const auto left_op = to_cblas_op(left_outer_permtype_); + // A product with NO external (free) outer indices (every outer index + // fused or contracted, e.g. C("i,j;a,b") = A("x,i,j;a") * B("x,i,j;b")) + // folds to a GEMM with no free modes, i.e. rank-0 tensors, which the + // tile kernels do not support. Evaluate it with a SYNTHETIC UNIT + // left-external mode instead: the folded product becomes + // (1,K) x (K) -> (1), the exact shape of the (supported) one-sided + // neB == 0 case. The unit mode lives only in the tile op's GemmHelper; + // tranges, shapes and tiles carry the true (external-free) ranks, and + // BatchedContractReduce / SparseShape::gemm_batched detect the + // synthetic mode from the one-rank mismatch and pad their folded views + // with a unit extent. + const unsigned int u = (outer_size(indices_) == nh) ? 1u : 0u; + + // the tile op operates on the folded (fused-mode-free) shapes; the + // synthetic unit mode leads the folded left operand, so it is NoTrans + const auto left_op = + u ? math::blas::NoTranspose : to_cblas_op(left_outer_permtype_); const auto right_op = to_cblas_op(right_outer_permtype_); if constexpr (!TiledArray::detail::is_tensor_of_tensor_v) { - op_ = op_type(left_op, right_op, factor_, outer_size(indices_) - nh, - outer_size(left_indices_) - nh, + op_ = op_type(left_op, right_op, factor_, outer_size(indices_) - nh + u, + outer_size(left_indices_) - nh + u, outer_size(right_indices_) - nh); } else { // the batched tile op must be perm-free (BatchedContractReduce cannot @@ -710,7 +712,8 @@ class ContEngine : public BinaryEngine { // factor_ is absorbed into element_nonreturn_op_ op_ = op_type(left_op, right_op, scalar_type(1), - outer_size(indices_) - nh, outer_size(left_indices_) - nh, + outer_size(indices_) - nh + u, + outer_size(left_indices_) - nh + u, outer_size(right_indices_) - nh, BipartitePermutation{}, this->element_nonreturn_op_, std::move(this->arena_plan_)); // ce+e, ce+ce_right and ce+ce_left are mutually exclusive; at most one @@ -750,7 +753,11 @@ class ContEngine : public BinaryEngine { trange_type make_trange_general() const { const unsigned int nh = n_fused_modes_; const unsigned int nc = op_.gemm_helper().num_contract_ranks(); - const unsigned int neA = op_.gemm_helper().left_rank() - nc; + // the no-external case carries a synthetic unit left-external mode in + // the GemmHelper only (see init_struct_general); the actual tranges do + // not have it + const unsigned int u = (outer_size(indices_) == n_fused_modes_) ? 1u : 0u; + const unsigned int neA = op_.gemm_helper().left_rank() - nc - u; const unsigned int neB = op_.gemm_helper().right_rank() - nc; typename trange_type::Ranges ranges(nh + neA + neB); @@ -802,7 +809,11 @@ class ContEngine : public BinaryEngine { std::shared_ptr pmap) { const unsigned int nh = n_fused_modes_; const unsigned int nc = op_.gemm_helper().num_contract_ranks(); - const unsigned int neA = op_.gemm_helper().left_rank() - nc; + // the no-external case carries a synthetic unit left-external mode in + // the GemmHelper only (see init_struct_general); the actual tranges do + // not have it + const unsigned int u = (outer_size(indices_) == nh) ? 1u : 0u; + const unsigned int neA = op_.gemm_helper().left_rank() - nc - u; const unsigned int neB = op_.gemm_helper().right_rank() - nc; // Get pointers to the argument sizes diff --git a/src/TiledArray/sparse_shape.h b/src/TiledArray/sparse_shape.h index 7af41176c7..5ab12a3dee 100644 --- a/src/TiledArray/sparse_shape.h +++ b/src/TiledArray/sparse_shape.h @@ -1720,10 +1720,19 @@ class SparseShape { const auto* left_extent = tile_norms_.range().extent_data(); const auto* right_extent = other.tile_norms_.range().extent_data(); + // a no-external product carries a SYNTHETIC unit left-external mode in + // the GemmHelper only (see ContEngine::init_struct_general); detect it + // from the one-rank mismatch with the actual norm tensor and pad the + // folded left/result views with a unit extent + const bool unit_external = + (tile_norms_.range().rank() + 1u == nfused + gemm_helper.left_rank()); + const unsigned int u = unit_external ? 1u : 0u; + // check that the ranks match the folded gemm ranks plus the fused modes, // and that the fused and contracted mode extents of the two shapes are // congruent - TA_ASSERT(tile_norms_.range().rank() == nfused + gemm_helper.left_rank()); + TA_ASSERT(tile_norms_.range().rank() + u == + nfused + gemm_helper.left_rank()); TA_ASSERT(other.tile_norms_.range().rank() == nfused + gemm_helper.right_rank()); for (unsigned int d = 0u; d < nfused; ++d) @@ -1731,31 +1740,34 @@ class SparseShape { for (unsigned int i = gemm_helper.left_inner_begin(), j = gemm_helper.right_inner_begin(); i < gemm_helper.left_inner_end(); ++i, ++j) - TA_ASSERT(left_extent[nfused + i] == right_extent[nfused + j]); + TA_ASSERT(left_extent[nfused + i - u] == right_extent[nfused + j]); integer H = 1, M = 1, N = 1, K = 1; for (unsigned int d = 0u; d < nfused; ++d) H *= left_extent[d]; - for (unsigned int i = gemm_helper.left_outer_begin(); - i < gemm_helper.left_outer_end(); ++i) - M *= left_extent[nfused + i]; + if (!unit_external) + for (unsigned int i = gemm_helper.left_outer_begin(); + i < gemm_helper.left_outer_end(); ++i) + M *= left_extent[nfused + i]; for (unsigned int i = gemm_helper.left_inner_begin(); i < gemm_helper.left_inner_end(); ++i) - K *= left_extent[nfused + i]; + K *= left_extent[nfused + i - u]; for (unsigned int i = gemm_helper.right_outer_begin(); i < gemm_helper.right_outer_end(); ++i) N *= right_extent[nfused + i]; // result size vectors: fused modes (from this), then the left and right - // outer modes - const unsigned int result_rank = nfused + gemm_helper.result_rank(); + // outer modes (the synthetic unit left-external mode is absent from the + // actual result) + const unsigned int result_rank = nfused + gemm_helper.result_rank() - u; std::shared_ptr result_size_vectors( new vector_type[result_rank], std::default_delete()); unsigned int x = 0ul; for (unsigned int i = 0u; i < nfused; ++i, ++x) result_size_vectors.get()[x] = size_vectors_.get()[i]; - for (unsigned int i = gemm_helper.left_outer_begin(); - i < gemm_helper.left_outer_end(); ++i, ++x) - result_size_vectors.get()[x] = size_vectors_.get()[nfused + i]; + if (!unit_external) + for (unsigned int i = gemm_helper.left_outer_begin(); + i < gemm_helper.left_outer_end(); ++i, ++x) + result_size_vectors.get()[x] = size_vectors_.get()[nfused + i]; for (unsigned int i = gemm_helper.right_outer_begin(); i < gemm_helper.right_outer_end(); ++i, ++x) result_size_vectors.get()[x] = other.size_vectors_.get()[nfused + i]; @@ -1770,11 +1782,12 @@ class SparseShape { lobounds.push_back(tile_norms_.range().lobound_data()[d]); upbounds.push_back(tile_norms_.range().upbound_data()[d]); } - for (unsigned int i = gemm_helper.left_outer_begin(); - i < gemm_helper.left_outer_end(); ++i) { - lobounds.push_back(tile_norms_.range().lobound_data()[nfused + i]); - upbounds.push_back(tile_norms_.range().upbound_data()[nfused + i]); - } + if (!unit_external) + for (unsigned int i = gemm_helper.left_outer_begin(); + i < gemm_helper.left_outer_end(); ++i) { + lobounds.push_back(tile_norms_.range().lobound_data()[nfused + i]); + upbounds.push_back(tile_norms_.range().upbound_data()[nfused + i]); + } for (unsigned int i = gemm_helper.right_outer_begin(); i < gemm_helper.right_outer_end(); ++i) { lobounds.push_back(other.tile_norms_.range().lobound_data()[nfused + i]); @@ -1784,10 +1797,13 @@ class SparseShape { // the range spanned by modes [nfused, rank) of \p r, rebased to zero // lobounds (scratch view for the slab-batched norm GEMM) - auto fold_range = [nfused](const range_type& r) { + auto fold_range = [nfused](const range_type& r, + const bool prepend_unit = false) { const auto* extent = r.extent_data(); - container::svector extents(extent + nfused, - extent + r.rank()); + container::svector extents; + extents.reserve(r.rank() - nfused + (prepend_unit ? 1u : 0u)); + if (prepend_unit) extents.push_back(1); + extents.insert(extents.end(), extent + nfused, extent + r.rank()); return range_type(extents); }; @@ -1797,10 +1813,11 @@ class SparseShape { if (k_rank > 0u) { // the contracted-mode size vector; identical for every slab since the - // contracted modes follow the fused modes + // contracted modes follow the fused modes (helper coordinates carry + // the synthetic unit mode, the actual size vectors do not) const vector_type k_sizes = recursive_outer_product( - size_vectors_.get() + nfused + gemm_helper.left_inner_begin(), k_rank, - [](const vector_type& size_vector) -> const vector_type& { + size_vectors_.get() + nfused + gemm_helper.left_inner_begin() - u, + k_rank, [](const vector_type& size_vector) -> const vector_type& { return size_vector; }); @@ -1828,10 +1845,11 @@ class SparseShape { // slab-batched norm GEMM: fold the fused modes into the tensor batch // dimension by zero-copy reshape; result_folded shares result_norms' // buffer, so the accumulation lands in place - auto left_folded = left.reshape(fold_range(left.range()), H); + auto left_folded = + left.reshape(fold_range(left.range(), unit_external), H); auto right_folded = right.reshape(fold_range(right.range()), H); - auto result_folded = - result_norms.reshape(fold_range(result_norms.range()), H); + auto result_folded = result_norms.reshape( + fold_range(result_norms.range(), unit_external), H); result_folded.gemm(left_folded, right_folded, abs_factor, gemm_helper); // Hard zero tiles that are below the zero threshold. diff --git a/src/TiledArray/tile_op/batched_contract_reduce.h b/src/TiledArray/tile_op/batched_contract_reduce.h index 570a7398c8..48b028e267 100644 --- a/src/TiledArray/tile_op/batched_contract_reduce.h +++ b/src/TiledArray/tile_op/batched_contract_reduce.h @@ -62,12 +62,16 @@ class BatchedContractReduce { /// \return the range spanned by modes [nfused_, rank) of \p r, rebased to /// zero lobounds (the folded view is a GEMM scratch view; only extents - /// matter) + /// matter); \p prepend_unit prepends a unit extent (the synthetic + /// left-external mode of a no-external product, see + /// ContEngine::init_struct_general) template - Range_ fold_range(const Range_& r) const { + Range_ fold_range(const Range_& r, const bool prepend_unit = false) const { const auto* extent = r.extent_data(); - container::svector extents(extent + nfused_, - extent + r.rank()); + container::svector extents; + extents.reserve(r.rank() - nfused_ + (prepend_unit ? 1u : 0u)); + if (prepend_unit) extents.push_back(1); + extents.insert(extents.end(), extent + nfused_, extent + r.rank()); return Range_(extents); } @@ -150,7 +154,13 @@ class BatchedContractReduce { const auto& gh = op_.gemm_helper(); const unsigned int nc = gh.num_contract_ranks(); - const unsigned int neA = gh.left_rank() - nc; + // a no-external product carries a SYNTHETIC unit left-external mode in + // the GemmHelper only (see ContEngine::init_struct_general); detect it + // from the one-rank mismatch with the actual left tile and pad the + // folded left/result views with a unit extent + const bool unit_external = + (left.range().rank() + 1u == nfused_ + gh.left_rank()); + const unsigned int neA = gh.left_rank() - nc - (unit_external ? 1u : 0u); const unsigned int neB = gh.right_rank() - nc; // both args must carry the fused modes as their leading modes, with @@ -163,7 +173,8 @@ class BatchedContractReduce { TA_ASSERT(batch == fused_volume(right.range())); // folded, zero-copy argument views - auto left_folded = left.reshape(fold_range(left.range()), batch); + auto left_folded = + left.reshape(fold_range(left.range(), unit_external), batch); auto right_folded = right.reshape(fold_range(right.range()), batch); if (empty(result)) { @@ -195,7 +206,8 @@ class BatchedContractReduce { } else { // accumulate through a folded, zero-copy view of the result const auto full_range = result.range(); - auto result_folded = result.reshape(fold_range(full_range), batch); + auto result_folded = + result.reshape(fold_range(full_range, unit_external), batch); op_(result_folded, left_folded, right_folded); // the wrapped op may REBIND the result instead of writing in place: // the arena grow-to-cover path (a later K-panel touching inner cells diff --git a/tests/general_product.cpp b/tests/general_product.cpp index 2fd500311a..58976f87dc 100644 --- a/tests/general_product.cpp +++ b/tests/general_product.cpp @@ -731,20 +731,64 @@ BOOST_AUTO_TEST_CASE(expression_general_kitchen_sink) { BOOST_CHECK_SMALL(tot_max_abs_diff(w, w_ref), 1e-10); } -BOOST_AUTO_TEST_CASE(expression_general_product_tot_no_externals_gated) { +BOOST_AUTO_TEST_CASE(expression_general_product_tot_no_externals) { // a ToT x ToT general product with NO external (free) outer indices -- - // every outer index fused or contracted -- is not supported by the - // batched tile op (and used to segfault in the folded GEMM); the engine - // must reject it with an informative error (einsum() handles this shape - // natively via its no-external regime) + // every outer index fused or contracted. The folded product has no free + // modes, so it is evaluated with a synthetic unit left-external mode + // carried by the tile op's GemmHelper only (this shape used to segfault). auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent TA::TiledRange tr{{0, 3, 5}, {0, 2, 4}, {0, 2, 3}}; // x, i, j auto a = make_patterned_tot_array(world, tr, {2}, 1.0); auto b = make_patterned_tot_array(world, tr, {3}, 2.0); TArrayToT c; - BOOST_CHECK_THROW(c("i,j;a,b") = a("x,i,j;a") * b("x,i,j;b"), - TiledArray::Exception); + BOOST_REQUIRE_NO_THROW(c("i,j;a,b") = a("x,i,j;a") * b("x,i,j;b")); + auto c_ref = TA::einsum(a("x,i,j;a"), b("x,i,j;b"), "i,j;a,b"); + BOOST_CHECK_SMALL(tot_max_abs_diff(c, c_ref), 1e-10); + + // the same no-external root with the left operand produced by a general + // T x ToT product at an INNER node (the original motivating expression) + TA::TiledRange tr_g{{0, 3, 5}, {0, 2, 4}}; // x, i + TA::TiledRange tr_cv{{0, 3, 5}, {0, 2, 3}}; // x, j + auto g = make_patterned_array(world, tr_g, 1.0); + auto cv = make_patterned_tot_array(world, tr_cv, {2}, 2.0); + TArrayToT i1, w, w_ref; + i1("x,i,j;a") = g("x,i") * cv("x,j;a"); + w_ref("i,j;a,b") = i1("x,i,j;a") * b("x,i,j;b"); + BOOST_REQUIRE_NO_THROW(w("i,j;a,b") = + (g("x,i") * cv("x,j;a")) * b("x,i,j;b")); + BOOST_CHECK_SMALL(tot_max_abs_diff(w, w_ref), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_no_externals) { + // the plain-tensor analogue of the no-external general product (the + // Hadamard-reduction shape): C("i") = A("i,j") * B("i,j") + auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent + TA::TiledRange tr{{0, 2, 4}, {0, 3, 5}}; // i, j + auto a = make_patterned_array(world, tr, 1.0); + auto b = make_patterned_array(world, tr, 2.0); + + TA::TArrayD c; + BOOST_REQUIRE_NO_THROW(c("i") = a("i,j") * b("i,j")); + auto c_ref = TA::einsum(a("i,j"), b("i,j"), "i"); + BOOST_CHECK_SMALL(diff_norm(c, c_ref, "i"), 1e-10); +} + +BOOST_AUTO_TEST_CASE(expression_general_product_sparse_no_externals) { + // block-sparse no-external general product: exercises the synthetic + // unit-mode handling in SparseShape::gemm_batched + auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; // keep the einsum reference independent + TA::TiledRange tr{{0, 2, 5}, {0, 3, 4}, {0, 2, 6, 7}}; // b, i, j + auto a = make_patterned_sparse_array(world, tr, 1.0, 3); + auto b = make_patterned_sparse_array(world, tr, 2.0, 4); + + TA::TSpArrayD c; + BOOST_REQUIRE_NO_THROW(c("b,i") = a("b,i,j") * b("b,i,j")); + auto c_ref = TA::einsum(a("b,i,j"), b("b,i,j"), "b,i"); + BOOST_CHECK_SMALL(diff_norm_sp(c, c_ref, "b,i"), 1e-10); } BOOST_AUTO_TEST_CASE(expression_general_product_tot_inner_outer_product) { From 5e1e75127c3877aac8bcb333cd2bd6d087ab456a Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Fri, 12 Jun 2026 09:32:03 -0400 Subject: [PATCH 22/25] proc_grid/pmap: rank-subset ProcGrid + 3-d (proc_h) SlabbedPmap Infrastructure for a 3-d (proc_h x proc_r x proc_c) batched-Summa grid that distributes the fused/batch (h, slab) dimension of a general product across process planes: - ProcGrid gains a rank-subset constructor (tagged rank_subset to avoid colliding with the same-arity test-only ctor) that builds a 2-d grid over a contiguous interval [rank_offset, rank_offset + nprocs) of the world's ranks; map_row/map_col and the row/col group factories emit world-correct ranks via the offset. The legacy full-world ctor is unchanged (offset 0). - SlabbedPmap gains a 3-d variant (proc_h, proc_h_stride): slab h belongs to plane h % proc_h of proc_h_stride contiguous ranks, and the per-slab base map's plane-local owners are offset by the slab's plane. The original 3-argument form (proc_h == 1, slab-replicated) is unchanged. --- src/TiledArray/pmap/slabbed_pmap.h | 74 ++++++++++++++++++++++++-- src/TiledArray/proc_grid.h | 85 +++++++++++++++++++++++++----- 2 files changed, 142 insertions(+), 17 deletions(-) diff --git a/src/TiledArray/pmap/slabbed_pmap.h b/src/TiledArray/pmap/slabbed_pmap.h index a345eccb6a..c4e6d9b4b5 100644 --- a/src/TiledArray/pmap/slabbed_pmap.h +++ b/src/TiledArray/pmap/slabbed_pmap.h @@ -44,6 +44,16 @@ namespace detail { /// (Hadamard/batch) modes as their leading dimensions, every slab is /// distributed identically over the same 2-d process grid, and the slab /// index never participates in inter-process communication patterns. +/// The 3-d-grid variant adds a process dimension \c proc_h along the slab +/// (h) axis: the world's first `proc_h * proc_h_stride` ranks are partitioned +/// into `proc_h` contiguous h-planes of `proc_h_stride` ranks each; slab +/// \f$ h \f$ belongs to plane \f$ h \% proc\_h \f$, and within its plane is +/// distributed by the base map, whose owners must then be PLANE-LOCAL ranks +/// in \f$ [0, proc\_h\_stride) \f$: the owner of index \f$ o \f$ is +/// \f$ (h \% proc\_h) \cdot proc\_h\_stride + base(o \% S) \f$. +/// `proc_h == 1` reduces to the slab-replicated map above (and is constructed +/// via the 2-argument-base constructor, which delegates locality to the base +/// map). class SlabbedPmap : public Pmap { protected: // Import Pmap protected variables @@ -56,6 +66,14 @@ class SlabbedPmap : public Pmap { const std::shared_ptr base_; ///< The per-slab base map const size_type slab_size_; ///< The number of indices per slab const size_type nslabs_; ///< The number of slabs + const size_type proc_h_ = 1; ///< Process-grid extent along the slab + ///< (h) axis (the number of h-planes) + const size_type proc_h_stride_ = 0; ///< Ranks per h-plane (0 = the whole + ///< world; base owners are world ranks + ///< and locality delegates to the base) + + /// \return whether this map distributes the slab axis over a process plane + bool hgrouped() const { return proc_h_stride_ != 0ul; } public: typedef Pmap::size_type size_type; ///< Size type @@ -76,6 +94,49 @@ class SlabbedPmap : public Pmap { this->local_size_ = base_->local_size() * nslabs_; } + /// Construct an h-grouped (3-d grid) slabbed process map + + /// \param world The world where the tiles will be mapped + /// \param base The base process map, defined over one slab; its owners + /// must be GROUP-LOCAL ranks in [0, proc_h_stride) + /// \param nslabs The number of slabs + /// \param proc_h The number of slab groups (slab h -> group h % proc_h) + /// \param proc_h_stride The number of (contiguous) world ranks per group + SlabbedPmap(World& world, std::shared_ptr base, + const size_type nslabs, const size_type proc_h, + const size_type proc_h_stride) + : Pmap(world, base->size() * nslabs), + base_(std::move(base)), + slab_size_(base_->size()), + nslabs_(nslabs), + proc_h_(proc_h), + proc_h_stride_(proc_h_stride) { + TA_ASSERT(base_); + TA_ASSERT(nslabs_ > 0ul); + TA_ASSERT(proc_h_ > 0ul); + TA_ASSERT(proc_h_stride_ > 0ul); + TA_ASSERT(proc_h_ * proc_h_stride_ <= size_type(world.size())); + + // this rank's group and group-local rank; ranks beyond the grouped + // prefix of the world own nothing + const size_type rank = world.rank(); + if (rank < proc_h_ * proc_h_stride_) { + const size_type my_group = rank / proc_h_stride_; + const size_type my_group_rank = rank % proc_h_stride_; + // count of my group's slabs: h in [0, nslabs) with h % proc_h == + // my_group + const size_type my_slabs = + (nslabs_ / proc_h_) + (my_group < (nslabs_ % proc_h_) ? 1u : 0u); + // count of slab indices the base assigns to my group-local rank + size_type base_local = 0ul; + for (size_type j = 0ul; j < slab_size_; ++j) + if (base_->owner(j) == my_group_rank) ++base_local; + this->local_size_ = my_slabs * base_local; + } else { + this->local_size_ = 0ul; + } + } + virtual ~SlabbedPmap() {} /// \return the per-slab base map @@ -84,6 +145,8 @@ class SlabbedPmap : public Pmap { size_type slab_size() const { return slab_size_; } /// \return the number of slabs size_type nslabs() const { return nslabs_; } + /// \return the number of slab (h) groups + size_type proc_h() const { return proc_h_; } /// Maps \c tile to the process that owns it @@ -91,7 +154,9 @@ class SlabbedPmap : public Pmap { /// \return Process that logically owns \c tile virtual size_type owner(const size_type tile) const { TA_ASSERT(tile < size_); - return base_->owner(tile % slab_size_); + if (!hgrouped()) return base_->owner(tile % slab_size_); + const size_type group = (tile / slab_size_) % proc_h_; + return group * proc_h_stride_ + base_->owner(tile % slab_size_); } /// Check that the tile is owned by this process @@ -99,10 +164,13 @@ class SlabbedPmap : public Pmap { /// \param tile The tile to be checked /// \return \c true if \c tile is owned by this process, otherwise \c false virtual bool is_local(const size_type tile) const { - return base_->is_local(tile % slab_size_); + if (!hgrouped()) return base_->is_local(tile % slab_size_); + return owner(tile) == size_type(rank_); } - virtual bool known_local_size() const { return base_->known_local_size(); } + virtual bool known_local_size() const { + return hgrouped() ? true : base_->known_local_size(); + } virtual const_iterator begin() const { return Iterator(*this, 0ul, size_, 0ul, /*checking=*/true); diff --git a/src/TiledArray/proc_grid.h b/src/TiledArray/proc_grid.h index cd15c1b73e..758e7fcefb 100644 --- a/src/TiledArray/proc_grid.h +++ b/src/TiledArray/proc_grid.h @@ -55,6 +55,14 @@ namespace detail { /// \f] /// where the positive, real root of \f$P_{\rm{row}}\f$ give the optimal /// optimal communication time. + +/// Tag to disambiguate the rank-subset ProcGrid constructor from the +/// (same-arity) test-only constructor. +struct rank_subset_t { + explicit rank_subset_t() = default; +}; +inline constexpr rank_subset_t rank_subset{}; + class ProcGrid { public: typedef uint_fast32_t size_type; @@ -71,9 +79,12 @@ class ProcGrid { ///< may be less than the number of processes in world. ProcessID rank_row_; ///< This process's row in the process grid ProcessID rank_col_; ///< This process's column in the process grid - size_type local_rows_; ///< The number of local element rows - size_type local_cols_; ///< The number of local element columns - size_type local_size_; ///< Number of local elements + ProcessID rank_offset_ = 0; ///< World rank of the grid's first process + ///< (nonzero for a grid over a contiguous + ///< subset of the world's ranks) + size_type local_rows_; ///< The number of local element rows + size_type local_cols_; ///< The number of local element columns + size_type local_size_; ///< Number of local elements /// Compute the number of process rows that minimizes communication @@ -180,14 +191,16 @@ class ProcGrid { proc_cols_ = 1u; proc_size_ = 1u; - // Set this process rank - rank_row_ = 0; - rank_col_ = 0; + if (rank < proc_size_) { + // Set this process rank + rank_row_ = 0; + rank_col_ = 0; - // Set local counts - local_rows_ = rows_; - local_cols_ = cols_; - local_size_ = size_; + // Set local counts + local_rows_ = rows_; + local_cols_ = cols_; + local_size_ = size_; + } } else if (size_ <= nprocs) { // Max one tile per process @@ -291,6 +304,48 @@ class ProcGrid { init(world_->rank(), world_->size(), row_size, col_size); } + /// Construct a process grid over a contiguous subset of the world's ranks + + /// The grid spans world ranks [rank_offset, rank_offset + nprocs); ranks + /// outside that interval construct a valid "not in the grid" instance + /// (zero local sizes, empty groups). Used by the h-grouped (3-d) batched + /// Summa, where each fused-index slab group runs its own 2-d grid. + /// \param world The world where the process grid will live + /// \param rank_offset The world rank of the grid's first process + /// \param nprocs The number of processes spanned by the grid + /// \param rows The number of tile rows + /// \param cols The number of tile columns + /// \param row_size The number of element rows + /// \param col_size The number of element columns + ProcGrid(World& world, rank_subset_t, const ProcessID rank_offset, + const size_type nprocs, const size_type rows, const size_type cols, + const std::size_t row_size, const std::size_t col_size) + : world_(&world), + rows_(rows), + cols_(cols), + size_(rows_ * cols_), + proc_rows_(0ul), + proc_cols_(0ul), + proc_size_(0ul), + rank_row_(-1), + rank_col_(-1), + rank_offset_(rank_offset), + local_rows_(0ul), + local_cols_(0ul), + local_size_(0ul) { + TA_ASSERT(rank_offset >= 0); + TA_ASSERT(nprocs >= 1u); + TA_ASSERT(rank_offset + nprocs <= size_type(world.size())); + const auto world_rank = world.rank(); + // out-of-grid ranks pass rank == nprocs, which every init() branch + // treats as "not in the grid" + const size_type rank = (world_rank >= rank_offset && + world_rank < rank_offset + ProcessID(nprocs)) + ? world_rank - rank_offset + : nprocs; + init(rank, nprocs, row_size, col_size); + } + #ifdef TILEDARRAY_ENABLE_TEST_PROC_GRID // Note: The following function is here for testing purposes only. It // has the same functionality as the default constructor above, except the @@ -350,6 +405,7 @@ class ProcGrid { proc_size_(other.proc_size_), rank_row_(other.rank_row_), rank_col_(other.rank_col_), + rank_offset_(other.rank_offset_), local_rows_(other.local_rows_), local_cols_(other.local_cols_), local_size_(other.local_size_) {} @@ -367,6 +423,7 @@ class ProcGrid { proc_size_ = other.proc_size_; rank_row_ = other.rank_row_; rank_col_ = other.rank_col_; + rank_offset_ = other.rank_offset_; local_rows_ = other.local_rows_; local_cols_ = other.local_cols_; local_size_ = other.local_size_; @@ -447,7 +504,7 @@ class ProcGrid { // Populate the row process list size_type p = rank_row_ * proc_cols_; const size_type row_end = p + proc_cols_; - for (; p < row_end; ++p) proc_list.push_back(p); + for (; p < row_end; ++p) proc_list.push_back(p + rank_offset_); // Construct the group group = madness::Group(*world_, proc_list, did); @@ -472,7 +529,7 @@ class ProcGrid { // Populate the column process list for (size_type p = rank_col_; p < proc_size_; p += proc_cols_) - proc_list.push_back(p); + proc_list.push_back(p + rank_offset_); // Construct the group if (proc_list.size() != 0) @@ -489,7 +546,7 @@ class ProcGrid { /// (row,rank_col) ProcessID map_row(const size_type row) const { TA_ASSERT(row < proc_rows_); - return rank_col_ + row * proc_cols_; + return rank_col_ + row * proc_cols_ + rank_offset_; } /// Map a column to the process in this process's row @@ -499,7 +556,7 @@ class ProcGrid { /// (rank_row,col) ProcessID map_col(const size_type col) const { TA_ASSERT(col < proc_cols_); - return rank_row_ * proc_cols_ + col; + return rank_row_ * proc_cols_ + col + rank_offset_; } /// Construct a cyclic process From 0ac088174bc58217e2eaffc02ac3987b822427bc Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Fri, 12 Jun 2026 09:32:46 -0400 Subject: [PATCH 23/25] summa: 3-d (proc_h) process grid for the batched Summa Distribute the fused/batch (h, slab) dimension of a general product over a third process-grid axis proc_h, recovering parallelism when the result is small (M*N result tiles < P ranks) -- most acutely no-external products (M=N=1, e.g. the PNO-CCSD PPL intermediate), where the 2-d grid otherwise degenerates to a single rank. The world's first proc_h * proc_h_stride ranks form proc_h h-planes of proc_h_stride = P/proc_h ranks; slab h is evaluated on plane h % proc_h, which runs an ordinary 2-d SUMMA over its own (offset) process grid. Slabs are communication-free (independent), so the surplus of ranks beyond one result-tile-per-rank is spent on this axis. Summa carries per-plane state (first_slab_, my_slabs_), restricts its slab iteration to the plane (next_step), indexes reduce tasks by plane-local slab ordinal (slab_ord), uses plane-unique dense broadcast keys, and computes the result-tile owner (result_tile_owner) as the within-plane cyclic owner shifted by the plane's world-rank offset -- matching set_tile's pmap-routed destination (the two disagreeing was a get_tile/set_tile owner mismatch that deadlocked cross-plane result transfers). proc_h == 1 reproduces the 2-d path exactly. ContEngine::init_distribution_general sizes proc_h by a greedy heuristic (spread ranks beyond min(P, M*N) over the slab axis, bounded by n_slabs) and builds the plane-local grid + 3-d operand/result pmaps. A TODO marks the principled co-optimization of proc_h with the 2-d aspect ratio from the h/left-external/right-external element extents and a memory bound. --- src/TiledArray/dist_eval/contraction_eval.h | 176 ++++++++++++++------ src/TiledArray/expressions/cont_engine.h | 104 +++++++++--- 2 files changed, 210 insertions(+), 70 deletions(-) diff --git a/src/TiledArray/dist_eval/contraction_eval.h b/src/TiledArray/dist_eval/contraction_eval.h index e2b9c8c731..fc5fa95895 100644 --- a/src/TiledArray/dist_eval/contraction_eval.h +++ b/src/TiledArray/dist_eval/contraction_eval.h @@ -127,11 +127,65 @@ class Summa const ordinal_type right_stride_local_; ///< stride for local right row iterators + // 3-d (h-grouped) grid information. The world's first + // proc_h_ * proc_h_stride ranks are partitioned into proc_h_ contiguous + // groups; slab h belongs to group h % proc_h_, and each group runs its + // own 2-d SUMMA grid (proc_grid_ is this rank's GROUP-LOCAL grid, + // constructed over the group's rank interval). proc_h_ == 1 is the + // ordinary shared-grid batched contraction. + const ordinal_type proc_h_; ///< Number of slab (h) groups + const ordinal_type + proc_h_stride_; ///< World ranks per slab group (the + ///< group of slab h spans world ranks + ///< [(h % proc_h_) * proc_h_stride_, ...)) + const ordinal_type first_slab_; ///< This rank's group's first slab (== its + ///< group index), or nh_ if this rank is + ///< in no group (idle for this eval) + const ordinal_type my_slabs_; ///< Number of slabs of this rank's group + + /// \return the world rank that owns result tile \p i: the within-group + /// owner (from the group-local process grid) shifted by the world-rank + /// offset of the group that owns \p i's slab. For proc_h_ == 1 the offset + /// is 0 and this is the ordinary cyclic owner. + ProcessID result_tile_owner(const ordinal_type i) const { + const ordinal_type source_index = DistEvalImpl_::perm_index_to_source(i); + // owner is independent of slab index *within a group* + const ordinal_type slab_index = source_index % result_slab_size_; + const ordinal_type tile_row = slab_index / proc_grid_.cols(); + const ordinal_type tile_col = slab_index % proc_grid_.cols(); + const ordinal_type proc_row = tile_row % proc_grid_.proc_rows(); + const ordinal_type proc_col = tile_col % proc_grid_.proc_cols(); + const ProcessID within_group = proc_row * proc_grid_.proc_cols() + proc_col; + // shift by the offset of the group that owns this tile's slab + const ordinal_type slab = source_index / result_slab_size_; + const ordinal_type group = (proc_h_ > 1ul) ? (slab % proc_h_) : 0ul; + return ProcessID(group * proc_h_stride_) + within_group; + } + /// \return the slab index of SUMMA step \p s ordinal_type step_h(const ordinal_type s) const { return s / k_; } /// \return the within-slab inner-dimension index of SUMMA step \p s ordinal_type step_k(const ordinal_type s) const { return s % k_; } + /// \return the smallest SUMMA step >= \p s that belongs to one of this + /// rank's group's slabs, or nsteps_ if there is none + ordinal_type next_step(ordinal_type s) const { + if (proc_h_ == 1ul) return std::min(s, nsteps_); + if (first_slab_ >= nh_) return nsteps_; // not in any group + while (s < nsteps_ && (step_h(s) % proc_h_) != first_slab_) + s = (step_h(s) + 1ul) * k_; // jump to the start of the next slab + return std::min(s, nsteps_); + } + + /// \return this rank's group-local ordinal of slab \p h (which must + /// belong to this rank's group) + ordinal_type slab_ord(const ordinal_type h) const { + return (h - first_slab_) / proc_h_; + } + + /// \return the number of SUMMA steps of this rank's group's slabs + ordinal_type my_steps() const { return my_slabs_ * k_; } + typedef Future right_future; ///< Future to a right-hand argument tile typedef Future @@ -225,7 +279,7 @@ class Summa /// \param end The end of the row or column range /// \param stride The row or column index stride /// \param k The broadcast group index - /// \param max_group_size The maximum number of processes in the result + /// \param max_proc_h_stride The maximum number of processes in the result /// group, which is equal to the number of process in this process row or /// column as defined by \c proc_grid_. /// \param key_offset The key that will be used to identify the process group @@ -238,21 +292,21 @@ class Summa const std::vector& process_mask, ordinal_type index, const ordinal_type end, const ordinal_type stride, - const ordinal_type max_group_size, + const ordinal_type max_proc_h_stride, const ordinal_type k, const ordinal_type key_offset, const ProcMap& proc_map) const { // Generate the list of processes in rank_row - std::vector proc_list(max_group_size, -1); + std::vector proc_list(max_proc_h_stride, -1); // Flag the root processes of the broadcast, which may not be included // by shape. - ordinal_type p = k % max_group_size; + ordinal_type p = k % max_proc_h_stride; proc_list[p] = proc_map(p); ordinal_type count = 1ul; // Flag all processes that have non-zero tiles - for (p = 0ul; (index < end) && (count < max_group_size); - index += stride, p = (p + 1u) % max_group_size) { + for (p = 0ul; (index < end) && (count < max_proc_h_stride); + index += stride, p = (p + 1u) % max_proc_h_stride) { if ((proc_list[p] != -1) || (shape.is_zero(index)) || !process_mask.at(p)) continue; @@ -734,7 +788,7 @@ class Summa // broadcast root (i.e. within-slab k congruent to rank_col mod Pcols) const ordinal_type Pcols = proc_grid_.proc_cols(); - for (; s < end; ++s) { + for (s = next_step(s); s < end; s = next_step(s + 1ul)) { const ordinal_type k = step_k(s); if (k % Pcols != static_cast(proc_grid_.rank_col())) continue; @@ -787,7 +841,7 @@ class Summa // broadcast root (i.e. within-slab k congruent to rank_row mod Prows) const ordinal_type Prows = proc_grid_.proc_rows(); - for (; s < end; ++s) { + for (s = next_step(s); s < end; s = next_step(s + 1ul)) { const ordinal_type k = step_k(s); if (k % Prows != static_cast(proc_grid_.rank_row())) continue; @@ -847,9 +901,9 @@ class Summa /// \return The first step, greater than or equal to \c s with non-zero /// tiles, or \c nsteps_ if none is found. ordinal_type iterate_row(ordinal_type s) const { - // Iterate over steps until a non-zero tile is found or the end of the - // matrix is reached. - for (; s < nsteps_; ++s) { + // Iterate over this rank's group's steps until a non-zero tile is found + // or the end of the matrix is reached. + for (s = next_step(s); s < nsteps_; s = next_step(s + 1ul)) { // Search for non-zero tiles in row k of slab h of right ordinal_type i = step_h(s) * right_slab_size_ + step_k(s) * proc_grid_.cols(); @@ -871,9 +925,9 @@ class Summa /// \return The first step, greater than or equal to \c s, that contains /// a non-zero tile. If no non-zero tile is not found, return \c nsteps_. ordinal_type iterate_col(ordinal_type s) const { - // Iterate over steps until a non-zero tile is found or the end of the - // matrix is reached. - for (; s < nsteps_; ++s) { + // Iterate over this rank's group's steps until a non-zero tile is found + // or the end of the matrix is reached. + for (s = next_step(s); s < nsteps_; s = next_step(s + 1ul)) { // Search column k of slab h for non-zero tiles const ordinal_type base = step_h(s) * left_slab_size_; for (ordinal_type i = base + left_start_local_ + step_k(s); @@ -963,11 +1017,17 @@ class Summa return tile_count; } else { // Construct static broadcast groups for dense arguments - // (key space [0, 2*nsteps_) is reserved for the sparse per-step groups) - const madness::DistributedID col_did(DistEvalImpl_::id(), 2ul * nsteps_); + // (key space [0, 2*nsteps_) is reserved for the sparse per-step groups, + // whose keys h*k_ and h*k_+nsteps_ are disjoint across h-groups; the + // two static keys are offset PAST that range and made group-unique so + // that two different groups' single-grid static groups never claim the + // same DistributedID with inconsistent membership) + const std::size_t static_key_base = 2ul * nsteps_ + 2ul * first_slab_; + const madness::DistributedID col_did(DistEvalImpl_::id(), + static_key_base); col_group_ = proc_grid_.make_col_group(col_did); const madness::DistributedID row_did(DistEvalImpl_::id(), - 2ul * nsteps_ + 1ul); + static_key_base + 1ul); row_group_ = proc_grid_.make_row_group(row_did); #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE @@ -985,12 +1045,12 @@ class Summa #endif // TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE // Allocate memory for the reduce pair tasks (one per local result tile - // per slab). + // per slab of this rank's group). std::allocator> alloc; - reduce_tasks_ = alloc.allocate(nh_ * proc_grid_.local_size()); + reduce_tasks_ = alloc.allocate(my_slabs_ * proc_grid_.local_size()); // Iterate over all local tiles - const ordinal_type n = nh_ * proc_grid_.local_size(); + const ordinal_type n = my_slabs_ * proc_grid_.local_size(); for (ordinal_type t = 0ul; t < n; ++t) { // Initialize the reduction task ReducePairTask* MADNESS_RESTRICT const reduce_task = @@ -1019,9 +1079,9 @@ class Summa if (k_ == 0) return 0; // Allocate memory for the reduce pair tasks (one per local result tile - // per slab). + // per slab of this rank's group). std::allocator> alloc; - reduce_tasks_ = alloc.allocate(nh_ * proc_grid_.local_size()); + reduce_tasks_ = alloc.allocate(my_slabs_ * proc_grid_.local_size()); // Initialize iteration variables const ordinal_type col_stride = // The stride to iterate down a column @@ -1030,10 +1090,11 @@ class Summa proc_grid_.proc_cols(); // Iterate over all local tiles, slab by slab (the block-cyclic phase - // restarts at every slab: the owner of tile (h,i,j) does not depend on h) + // restarts at every slab: within a group, the owner of tile (h,i,j) + // does not depend on h) ordinal_type tile_count = 0ul; ReducePairTask* MADNESS_RESTRICT reduce_task = reduce_tasks_; - for (ordinal_type h = 0ul; h < nh_; ++h) { + for (ordinal_type h = first_slab_; h < nh_; h += proc_h_) { const ordinal_type slab_base = h * result_slab_size_; ordinal_type row_start = slab_base + proc_grid_.rank_row() * proc_grid_.cols(); @@ -1103,7 +1164,7 @@ class Summa // Iterate over all local tiles, slab by slab ReducePairTask* reduce_task = reduce_tasks_; - for (ordinal_type h = 0ul; h < nh_; ++h) { + for (ordinal_type h = first_slab_; h < nh_; h += proc_h_) { const ordinal_type slab_base = h * result_slab_size_; ordinal_type row_start = slab_base + proc_grid_.rank_row() * proc_grid_.cols(); @@ -1126,7 +1187,7 @@ class Summa // Deallocate the memory for the reduce pair tasks. std::allocator>().deallocate( - reduce_tasks_, nh_ * proc_grid_.local_size()); + reduce_tasks_, my_slabs_ * proc_grid_.local_size()); } /// Set the result tiles and destroy reduce tasks @@ -1145,7 +1206,7 @@ class Summa // Iterate over all local tiles, slab by slab ReducePairTask* reduce_task = reduce_tasks_; - for (ordinal_type h = 0ul; h < nh_; ++h) { + for (ordinal_type h = first_slab_; h < nh_; h += proc_h_) { const ordinal_type slab_base = h * result_slab_size_; ordinal_type row_start = slab_base + proc_grid_.rank_row() * proc_grid_.cols(); @@ -1177,7 +1238,7 @@ class Summa } // Deallocate the memory for the reduce pair tasks. std::allocator>().deallocate( - reduce_tasks_, nh_ * proc_grid_.local_size()); + reduce_tasks_, my_slabs_ * proc_grid_.local_size()); #ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_FINALIZE ss << "}\n"; @@ -1229,9 +1290,10 @@ class Summa const std::vector& col, const std::vector& row, madness::TaskInterface* const task) { - // The reduce tasks of slab h occupy - // [h * local_size, (h+1) * local_size) - const ordinal_type slab_offset = step_h(s) * proc_grid_.local_size(); + // The reduce tasks of this group's slab h occupy + // [slab_ord(h) * local_size, (slab_ord(h)+1) * local_size) + const ordinal_type slab_offset = + slab_ord(step_h(s)) * proc_grid_.local_size(); // Iterate over the row for (ordinal_type i = 0ul; i < col.size(); ++i) { @@ -1266,9 +1328,10 @@ class Summa const std::vector& col, const std::vector& row, madness::TaskInterface* const task) { - // The reduce tasks of slab h occupy - // [h * local_size, (h+1) * local_size) - const ordinal_type slab_offset = step_h(s) * proc_grid_.local_size(); + // The reduce tasks of this group's slab h occupy + // [slab_ord(h) * local_size, (slab_ord(h)+1) * local_size) + const ordinal_type slab_offset = + slab_ord(step_h(s)) * proc_grid_.local_size(); // Iterate over the row for (ordinal_type i = 0ul; i < col.size(); ++i) { @@ -1540,13 +1603,14 @@ class Summa public: DenseStepTask(const std::shared_ptr& owner, const ordinal_type depth) - : StepTask(owner, owner->nsteps_ + 1ul), k_(0) { + : StepTask(owner, owner->my_steps() + 1ul), k_(owner->next_step(0ul)) { StepTask::make_next_step_tasks(this, depth); - StepTask::spawn_get_row_col_tasks(k_); + if (k_ < owner_->nsteps_) StepTask::spawn_get_row_col_tasks(k_); } DenseStepTask(DenseStepTask* const parent, const int ndep) - : StepTask(parent, ndep), k_(parent->k_ + 1ul) { + : StepTask(parent, ndep), + k_(parent->owner_->next_step(parent->k_ + 1ul)) { // Spawn tasks to get k-th row and column tiles if (k_ < owner_->nsteps_) StepTask::spawn_get_row_col_tasks(k_); } @@ -1671,7 +1735,8 @@ class Summa const trange_type trange, const shape_type& shape, const std::shared_ptr& pmap, const Perm& perm, const op_type& op, const ordinal_type k, const ProcGrid& proc_grid, - const ordinal_type nh = 1ul) + const ordinal_type nh = 1ul, const ordinal_type proc_h = 1ul, + const ordinal_type proc_h_stride = 0ul) : DistEvalImpl_(world, trange, shape, pmap, outer(perm)), left_(left), right_(right), @@ -1685,6 +1750,11 @@ class Summa left_slab_size_(left.size() / nh), right_slab_size_(right.size() / nh), result_slab_size_(proc_grid.rows() * proc_grid.cols()), + proc_h_(proc_h), + proc_h_stride_(proc_h_stride), + first_slab_(compute_first_slab(world, nh, proc_h, proc_h_stride)), + my_slabs_(first_slab_ < nh ? (nh - first_slab_ + proc_h - 1ul) / proc_h + : 0ul), reduce_tasks_(NULL), left_start_local_(proc_grid_.rank_row() * k), left_end_(left.size() / nh), @@ -1703,6 +1773,19 @@ class Summa TA_ASSERT(nh_ > 0); TA_ASSERT(left.size() % nh_ == 0); TA_ASSERT(right.size() % nh_ == 0); + TA_ASSERT(proc_h_ > 0); + TA_ASSERT(proc_h_ == 1ul || proc_h_stride > 0ul); + TA_ASSERT(proc_h_ <= nh_); + } + + /// \return this rank's group's first slab (== its group index), or + /// \p nh if this rank is outside the grouped rank interval + static ordinal_type compute_first_slab(World& world, const ordinal_type nh, + const ordinal_type proc_h, + const ordinal_type proc_h_stride) { + if (proc_h == 1ul) return 0ul; + const auto rank = ordinal_type(world.rank()); + return (rank < proc_h * proc_h_stride) ? (rank / proc_h_stride) : nh; } virtual ~Summa() {} @@ -1717,18 +1800,11 @@ class Summa TA_ASSERT(TensorImpl_::is_local(i)); TA_ASSERT(!TensorImpl_::is_zero(i)); - const ordinal_type source_index = DistEvalImpl_::perm_index_to_source(i); - - // Compute tile coordinate in tile grid (the owner of a tile is - // independent of its slab index) - const ordinal_type slab_index = source_index % result_slab_size_; - const ordinal_type tile_row = slab_index / proc_grid_.cols(); - const ordinal_type tile_col = slab_index % proc_grid_.cols(); - // Compute process coordinate of tile in the process grid - const ordinal_type proc_row = tile_row % proc_grid_.proc_rows(); - const ordinal_type proc_col = tile_col % proc_grid_.proc_cols(); - // Compute the process that owns tile - const ProcessID source = proc_row * proc_grid_.proc_cols() + proc_col; + // The process that owns tile i: the within-group cyclic owner shifted by + // the world-rank offset of the tile's slab group (see + // result_tile_owner). For proc_h_ == 1 this is the ordinary cyclic + // owner over the whole world. + const ProcessID source = result_tile_owner(i); const madness::DistributedID key(DistEvalImpl_::id(), i); return TensorImpl_::world().gop.template recv(source, key); diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index fdf5b0da82..90bd23ec8f 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -185,6 +185,9 @@ class ContEngine : public BinaryEngine { ///< from the canonical result layout, so ///< the evaluated result is re-permuted ///< by a streaming unary eval + size_type proc_h_ = 1; ///< process-grid extent along the slab (h) + ///< axis of the 3-d grid (# of h-planes) + size_type proc_h_stride_ = 0; ///< ranks per h-plane (0 = ungrouped 2-d) static unsigned int find(const BipartiteIndexList& indices, const std::string& index_label, unsigned int i, @@ -851,22 +854,78 @@ class ContEngine : public BinaryEngine { world, (pmap ? pmap : policy::default_pmap(*world, n_slabs_ * M * N))); } else { - // Construct the per-slab process grid. - proc_grid_ = TiledArray::detail::ProcGrid(*world, M, N, m, n); - - // Initialize children with slab-replicated SUMMA phase maps - left_.init_distribution( - world, std::make_shared( - *world, proc_grid_.make_row_phase_pmap(K_), n_slabs_)); - right_.init_distribution( - world, std::make_shared( - *world, proc_grid_.make_col_phase_pmap(K_), n_slabs_)); - - // Initialize the process map if not already defined - if (!pmap) - pmap = std::make_shared( - *world, proc_grid_.make_pmap(), n_slabs_); - ExprEngine_::init_distribution(world, pmap); + // Choose the process-grid extent proc_h_ along the slab (h) axis of + // the 3-d grid: ranks beyond one-per-result-tile are useless to a + // single slab's 2-d SUMMA, so the surplus is spread over the slab + // (communication-free) dimension instead -- slab h goes to plane + // h % proc_h_ of proc_h_stride_ = P / proc_h_ contiguous ranks (the + // division-remainder ranks idle for this evaluation). For a + // no-external product (M == N == 1) this degenerates to an effectively + // 1-d grid over the slabs; for an ordinary contraction (n_slabs_ == 1) + // it is a pure 2-d grid (proc_h_ == 1). + // TODO: co-optimize proc_h_ with the 2-d (proc_r, proc_c) aspect ratio + // using the h-, left-external-, and right-external-mode element + // extents (and a per-rank memory bound), rather than the current + // greedy tile-count heuristic. + const size_type P = world->size(); + proc_h_ = 1ul; + if (n_slabs_ > 1ul && P > 1ul) { + const size_type p2d_cap = std::min(P, M * N); + proc_h_ = std::min(n_slabs_, + std::max(1ul, P / p2d_cap)); + } + proc_h_stride_ = P / proc_h_; + + if (proc_h_ == 1ul) { + // Construct the per-slab process grid over the whole world. + proc_grid_ = TiledArray::detail::ProcGrid(*world, M, N, m, n); + + // Initialize children with slab-replicated SUMMA phase maps + left_.init_distribution( + world, std::make_shared( + *world, proc_grid_.make_row_phase_pmap(K_), n_slabs_)); + right_.init_distribution( + world, std::make_shared( + *world, proc_grid_.make_col_phase_pmap(K_), n_slabs_)); + + // Initialize the process map if not already defined + if (!pmap) + pmap = std::make_shared( + *world, proc_grid_.make_pmap(), n_slabs_); + ExprEngine_::init_distribution(world, pmap); + } else { + // Construct this rank's GROUP-LOCAL per-slab process grid (ranks + // outside the grouped prefix of the world construct a valid + // not-in-grid instance). The grid shape is a pure function of + // (proc_h_stride_, M, N, m, n), so it is congruent across all groups, + // and the CyclicPmap factories below emit GROUP-LOCAL owners in + // [0, proc_h_stride_) which the h-grouped SlabbedPmap offsets by each + // slab's group. + const size_type rank = world->rank(); + const bool in_groups = rank < proc_h_ * proc_h_stride_; + const ProcessID grid_offset = + in_groups ? ProcessID((rank / proc_h_stride_) * proc_h_stride_) + : ProcessID(0); + proc_grid_ = TiledArray::detail::ProcGrid( + *world, TiledArray::detail::rank_subset, grid_offset, + proc_h_stride_, M, N, m, n); + + left_.init_distribution( + world, std::make_shared( + *world, proc_grid_.make_row_phase_pmap(K_), n_slabs_, + proc_h_, proc_h_stride_)); + right_.init_distribution( + world, std::make_shared( + *world, proc_grid_.make_col_phase_pmap(K_), n_slabs_, + proc_h_, proc_h_stride_)); + + // Initialize the process map if not already defined + if (!pmap) + pmap = std::make_shared( + *world, proc_grid_.make_pmap(), n_slabs_, proc_h_, + proc_h_stride_); + ExprEngine_::init_distribution(world, pmap); + } } } @@ -914,7 +973,8 @@ class ContEngine : public BinaryEngine { if (!general_repermute_) { std::shared_ptr pimpl = std::make_shared( left, right, *world_, trange_, shape_, pmap_, perm_, - batched_op_type(op_, n_fused_modes_), K_, proc_grid_, n_slabs_); + batched_op_type(op_, n_fused_modes_), K_, proc_grid_, n_slabs_, + proc_h_, proc_h_stride_); return dist_eval_type(pimpl); } @@ -935,12 +995,16 @@ class ContEngine : public BinaryEngine { // the inner Summa's result placement must be slab-replicated (the owner // of a tile independent of its slab index), regardless of the // (target-layout) pmap the consumer supplied for this node - auto canonical_pmap = std::make_shared( - *world_, proc_grid_.make_pmap(), n_slabs_); + auto canonical_pmap = + proc_h_ == 1ul ? std::make_shared( + *world_, proc_grid_.make_pmap(), n_slabs_) + : std::make_shared( + *world_, proc_grid_.make_pmap(), n_slabs_, proc_h_, + proc_h_stride_); std::shared_ptr pimpl = std::make_shared( left, right, *world_, canonical_trange, canonical_shape, canonical_pmap, BipartitePermutation{}, batched_op_type(op_, n_fused_modes_), K_, - proc_grid_, n_slabs_); + proc_grid_, n_slabs_, proc_h_, proc_h_stride_); dist_eval_type canonical(pimpl); typedef TiledArray::detail::UnaryEvalImpl< From 7b4319c5f717c0d3a797c122520193abe54c8ea8 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Fri, 12 Jun 2026 09:33:00 -0400 Subject: [PATCH 24/25] tests: distributed (np>1) general-product suite Adds general_product_distributed_suite (UNLABELED, so the CI harness runs it at both np=1 and np=2; the existing general_product_suite is serial-labeled and never exercised the batched Summa across ranks). Seven differential cases vs the legacy sub-World einsum oracle: dense, sparse, mixed T x ToT, no-external (dense + ToT), the one-expression THC reconstruction, and dist_no_externals_3d_grid -- which engages the 3-d (proc_h > 1) grid and asserts the no-external result distributes across the h-planes rather than piling on one rank. --- tests/general_product.cpp | 141 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 141 insertions(+) diff --git a/tests/general_product.cpp b/tests/general_product.cpp index 58976f87dc..c9128b31d0 100644 --- a/tests/general_product.cpp +++ b/tests/general_product.cpp @@ -8,6 +8,8 @@ #include "tiledarray.h" #include "unit_test_config.h" +#include + BOOST_AUTO_TEST_SUITE(general_product_suite, TA_UT_LABEL_SERIAL) namespace TA = TiledArray; @@ -1342,3 +1344,142 @@ BOOST_AUTO_TEST_CASE(einsum_expression_route_matches_legacy) { } BOOST_AUTO_TEST_SUITE_END() + +// Distributed (np > 1) coverage of general products. Unlike +// general_product_suite (serial-labeled: classification/optimizer unit tests +// + np=1 evaluation), this suite carries NO label, so the CI harness runs it +// at BOTH np=1 and np=2 (see tests/CMakeLists.txt: np-1 excludes @distributed, +// np-2 excludes @serial). It exercises the batched Summa across ranks -- a +// path the serial suite never covered. Each case differential-tests the +// expression route against the legacy sub-World einsum oracle. +// +// Most cases evaluate via the ordinary 2-d (proc_h == 1) batched Summa; +// dist_no_externals_3d_grid exercises the 3-d (proc_h > 1) grid, where the +// heuristic spreads ranks over the slab dimension because M == N == 1. +BOOST_AUTO_TEST_SUITE(general_product_distributed_suite) + +// the fixtures/helpers live in general_product_suite's anonymous namespace +using general_product_suite::diff_norm; +using general_product_suite::diff_norm_sp; +using general_product_suite::ForceLegacyEinsum; +using general_product_suite::make_patterned_array; +using general_product_suite::make_patterned_sparse_array; +using general_product_suite::make_patterned_tot_array; +using general_product_suite::TArrayToT; +using general_product_suite::tot_max_abs_diff; + +BOOST_AUTO_TEST_CASE(dist_dense) { + // batched general product: b fused, j contracted, i/k free + auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; + TA::TiledRange tr_a{{0, 2, 4}, {0, 2, 3}, {0, 2, 5}}; // b, i, j + TA::TiledRange tr_b{{0, 2, 4}, {0, 2, 5}, {0, 3, 4}}; // b, j, k + auto a = make_patterned_array(world, tr_a, 1.0); + auto b = make_patterned_array(world, tr_b, 2.0); + TA::TArrayD c; + BOOST_REQUIRE_NO_THROW(c("b,i,k") = a("b,i,j") * b("b,j,k")); + auto c_ref = TA::einsum(a("b,i,j"), b("b,j,k"), "b,i,k"); + BOOST_CHECK_SMALL(diff_norm(c, c_ref, "b,i,k"), 1e-10); +} + +BOOST_AUTO_TEST_CASE(dist_sparse) { + // block-sparse batched general product (SparseShape::gemm_batched + + // the sparse (h,k)-keyed Summa groups, across ranks) + auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; + TA::TiledRange tr_a{{0, 2, 5}, {0, 3, 4}, {0, 2, 6, 7}}; // b, i, j + TA::TiledRange tr_b{{0, 2, 5}, {0, 2, 6, 7}, {0, 4, 5}}; // b, j, k + auto a = make_patterned_sparse_array(world, tr_a, 1.0, 3); + auto b = make_patterned_sparse_array(world, tr_b, 2.0, 4); + TA::TSpArrayD c; + BOOST_REQUIRE_NO_THROW(c("b,i,k") = a("b,i,j") * b("b,j,k")); + auto c_ref = TA::einsum(a("b,i,j"), b("b,j,k"), "b,i,k"); + BOOST_CHECK_SMALL(diff_norm_sp(c, c_ref, "b,i,k"), 1e-10); +} + +BOOST_AUTO_TEST_CASE(dist_tot_mixed) { + // mixed T x ToT general product (inner Scale) across ranks + auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; + TA::TiledRange tr_a{{0, 2, 4}, {0, 2, 3}, {0, 2, 5}}; // b, i, j + TA::TiledRange tr_b{{0, 2, 4}, {0, 2, 5}, {0, 3, 4}}; // b, j, k + auto a = make_patterned_array(world, tr_a, 1.0); + auto b = make_patterned_tot_array(world, tr_b, {3}, 2.0); + TArrayToT c; + BOOST_REQUIRE_NO_THROW(c("b,i,k;m") = a("b,i,j") * b("b,j,k;m")); + auto c_ref = TA::einsum(a("b,i,j"), b("b,j,k;m"), "b,i,k;m"); + BOOST_CHECK_SMALL(tot_max_abs_diff(c, c_ref), 1e-10); +} + +BOOST_AUTO_TEST_CASE(dist_no_externals_dense) { + // no-external general product (Hadamard reduction shape) across ranks; + // the synthetic unit left-external mode handling under distribution + auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; + TA::TiledRange tr{{0, 2, 4}, {0, 3, 5}}; // i, j + auto a = make_patterned_array(world, tr, 1.0); + auto b = make_patterned_array(world, tr, 2.0); + TA::TArrayD c; + BOOST_REQUIRE_NO_THROW(c("i") = a("i,j") * b("i,j")); + auto c_ref = TA::einsum(a("i,j"), b("i,j"), "i"); + BOOST_CHECK_SMALL(diff_norm(c, c_ref, "i"), 1e-10); +} + +BOOST_AUTO_TEST_CASE(dist_no_externals_tot) { + // no-external ToT general product across ranks + auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; + TA::TiledRange tr{{0, 3, 5}, {0, 2, 4}, {0, 2, 3}}; // x, i, j + auto a = make_patterned_tot_array(world, tr, {2}, 1.0); + auto b = make_patterned_tot_array(world, tr, {3}, 2.0); + TArrayToT c; + BOOST_REQUIRE_NO_THROW(c("i,j;a,b") = a("x,i,j;a") * b("x,i,j;b")); + auto c_ref = TA::einsum(a("x,i,j;a"), b("x,i,j;b"), "i,j;a,b"); + BOOST_CHECK_SMALL(tot_max_abs_diff(c, c_ref), 1e-10); +} + +BOOST_AUTO_TEST_CASE(dist_inner_node_thc) { + // THC reconstruction in one expression (general products at inner nodes) + // across ranks + auto& world = TA::get_default_world(); + TA::TiledRange tr_x{{0, 2, 4}, {0, 3, 5}}; // orbital x auxiliary + TA::TiledRange tr_z{{0, 3, 5}, {0, 3, 5}}; // auxiliary x auxiliary + auto x = make_patterned_array(world, tr_x, 1.0); + auto z = make_patterned_array(world, tr_z, 2.0); + TA::TArrayD g; + BOOST_REQUIRE_NO_THROW(g("p,q,r,s") = x("p,r1") * x("q,r1") * z("r1,r2") * + x("r,r2") * x("s,r2")); + TA::TArrayD i1, i2, i3, g_ref; + i1("r1,p,q") = x("p,r1") * x("q,r1"); + i2("p,q,r2") = i1("r1,p,q") * z("r1,r2"); + i3("r2,p,q,r") = i2("p,q,r2") * x("r,r2"); + g_ref("p,q,r,s") = i3("r2,p,q,r") * x("s,r2"); + BOOST_CHECK_SMALL(diff_norm(g, g_ref, "p,q,r,s"), 1e-10); +} + +BOOST_AUTO_TEST_CASE(dist_no_externals_3d_grid) { + // a no-external product with several slabs: at np>1 the heuristic engages + // the 3-d grid (M=N=1 ⇒ proc_h > 1, the ranks spread over the slab axis), + // so the result tiles distribute across the h-planes rather than piling on + // one rank. Exercises the cross-plane result-tile transfer. + auto& world = TA::get_default_world(); + ForceLegacyEinsum legacy_oracle; + TA::TiledRange tr{{0, 2, 4, 6}, {0, 3, 5}}; // i (3 tiles), j + auto a = make_patterned_array(world, tr, 1.0); + auto b = make_patterned_array(world, tr, 2.0); + TA::TArrayD c; + BOOST_REQUIRE_NO_THROW(c("i") = a("i,j") * b("i,j")); + auto c_ref = TA::einsum(a("i,j"), b("i,j"), "i"); + BOOST_CHECK_SMALL(diff_norm(c, c_ref, "i"), 1e-10); + + // at np>1 the no-external result must NOT all live on one rank (the + // degeneracy the 3-d grid fixes) + if (world.size() > 1) { + std::set owners; + for (std::size_t o = 0; o < c.trange().tiles_range().volume(); ++o) + owners.insert(c.pmap()->owner(o)); + BOOST_CHECK_GT(owners.size(), 1u); + } +} + +BOOST_AUTO_TEST_SUITE_END() From bf43a86c9c0f10d90b286b0d5c40f31cc8ac9c59 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Fri, 12 Jun 2026 15:15:33 -0400 Subject: [PATCH 25/25] summa: address review comments on the 3-d batched grid - contraction_eval: clamp the SUMMA step-task pipeline depth to my_steps() (this rank's group's step count) instead of nsteps_. In the 3-d (proc_h_ > 1) case my_steps() < nsteps_, so clamping to nsteps_ pre-spawned surplus step tasks that all resolved to the terminating step (k_ == nsteps_). No-op for the 2-d path (my_slabs_ == nh_). - cont_engine: keep proc_h_stride_ == 0 for the ungrouped 2-d case (proc_h_ == 1), matching the field's documented invariant; only the grouped (proc_h_ > 1) grid uses P / proc_h_. - general_product test: correct the distributed suite header comment -- dist_inner_node_thc validates against explicit binary intermediates, not the legacy einsum oracle. --- src/TiledArray/dist_eval/contraction_eval.h | 8 ++++++-- src/TiledArray/expressions/cont_engine.h | 5 ++++- tests/general_product.cpp | 8 ++++++-- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/TiledArray/dist_eval/contraction_eval.h b/src/TiledArray/dist_eval/contraction_eval.h index fc5fa95895..aef9b0e853 100644 --- a/src/TiledArray/dist_eval/contraction_eval.h +++ b/src/TiledArray/dist_eval/contraction_eval.h @@ -1515,8 +1515,12 @@ class Summa template void make_next_step_tasks(Derived* task, ordinal_type depth) { - // Set the depth to be no greater than the maximum number steps - if (depth > owner_->nsteps_) depth = owner_->nsteps_; + // Set the depth to be no greater than the number of SUMMA steps this + // rank's group actually executes. In the 2-d (proc_h_ == 1) case this is + // nsteps_ (my_slabs_ == nh_); in the 3-d (proc_h_ > 1) case my_steps() < + // nsteps_, and clamping to nsteps_ would pre-spawn surplus step tasks + // that all resolve to the terminating step (k_ == nsteps_). + if (depth > owner_->my_steps()) depth = owner_->my_steps(); // Spawn n=depth step tasks for (; depth > 0ul; --depth) { diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index 90bd23ec8f..1229b00a94 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -874,7 +874,10 @@ class ContEngine : public BinaryEngine { proc_h_ = std::min(n_slabs_, std::max(1ul, P / p2d_cap)); } - proc_h_stride_ = P / proc_h_; + // keep the invariant proc_h_ == 1 => proc_h_stride_ == 0 (the ungrouped + // 2-d case) so downstream logic can key off either field; for grouped + // grids it is the per-plane world-rank count P / proc_h_. + proc_h_stride_ = (proc_h_ == 1ul) ? 0ul : P / proc_h_; if (proc_h_ == 1ul) { // Construct the per-slab process grid over the whole world. diff --git a/tests/general_product.cpp b/tests/general_product.cpp index c9128b31d0..26316784d8 100644 --- a/tests/general_product.cpp +++ b/tests/general_product.cpp @@ -1350,8 +1350,12 @@ BOOST_AUTO_TEST_SUITE_END() // + np=1 evaluation), this suite carries NO label, so the CI harness runs it // at BOTH np=1 and np=2 (see tests/CMakeLists.txt: np-1 excludes @distributed, // np-2 excludes @serial). It exercises the batched Summa across ranks -- a -// path the serial suite never covered. Each case differential-tests the -// expression route against the legacy sub-World einsum oracle. +// path the serial suite never covered. Most cases differential-test the +// expression route against the legacy sub-World einsum oracle; the exception +// is dist_inner_node_thc, which checks a single fused expression against a +// reference built from explicit binary intermediates (the same expression +// machinery), validating that nested general products at inner nodes agree +// with the step-by-step evaluation. // // Most cases evaluate via the ordinary 2-d (proc_h == 1) batched Summa; // dist_no_externals_3d_grid exercises the 3-d (proc_h > 1) grid, where the