expressions: mixed T x ToT products in arbitrary expression trees (Phase F) #564
Open
evaleev wants to merge 6 commits into
…e-deduction down-pass
The Phase E child-demand deduction moves from MultEngine into BinaryEngine::init_children_indices and ScalMultEngine adopts it, along with the full MultEngine routing for general products (inner_product_type_ classification + inner-General gate, init_struct_general, init_distribution_general, make_trange_general, make_dist_eval_general), replacing its use-einsum-instead exception.
…factor in inner-Scale ops
The general-product ToT gate fired on a non-null but IDENTITY inner permutation (the bipartite perm is constructed whole when only the outer modes are re-permuted by the streaming wrapper); require a genuinely non-identity inner perm. The inner-Scale element ops (mixed T x ToT) never carried the expression-level scalar prefactor, which was invisible while only MultEngine (factor == 1) reached them; the fallback op now absorbs factor_ and the factor-free fused arena ops are gated to factor == 1.
…nment into a block view)
…der products, kitchen-sink, blocks in trees
A ToT x ToT general product with no external (free) outer indices (every outer index fused or contracted) segfaulted in the folded GEMM; gate it with an informative error (einsum() evaluates this shape natively via its no-external regime).
New tests: a SUM nested under a product with a general summand (the down-pass prunes summand-internal contraction indices from the sum's demand by construction); the kitchen-sink expression combining a THC-like batching index, a mixed T x ToT general product, a ToT x ToT general product with an inner outer-product, and a ScalMult prefactor; a block leaf under an inner general node; a re-permuted general product assigned into a block view; the no-external gate.
… left-external mode
A general product whose every outer index is fused or contracted (e.g.
C("i,j;a,b") = A("x,i,j;a") * B("x,i,j;b")) folds to a GEMM with no
free modes, i.e. rank-0 tensors, which the tile kernels do not support
(this shape used to segfault through wild stride reads). Evaluate it
with a synthetic unit left-external mode instead: the folded product
becomes (1,K) x (K) -> (1), the exact shape of the already-supported
one-sided neB == 0 case. The unit mode lives only in the tile op's
GemmHelper; tranges, shapes and tiles carry the true (external-free)
ranks, and BatchedContractReduce / SparseShape::gemm_batched detect the
synthetic mode from the one-rank mismatch and pad their folded views
with a unit extent. Replaces the interim gate.
Tests: dense ToT (incl. the no-external root fed by a general T x ToT
inner node), plain dense (the Hadamard-reduction shape), and
block-sparse (exercising the gemm_batched unit handling), all
differential-tested against legacy einsum.
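
The fold described in this commit message can be illustrated on the plain-dense Hadamard-reduction shape, C("i") = A("i,j") * B("i,j"). The sketch below is self-contained and does not use TiledArray: `naive_gemm` and `hadamard_reduce` are hypothetical names invented for illustration, and the flat row-major layout is an assumption. It only shows why padding the left operand with a unit extent turns a product with no free modes into the (1,K) x (K) -> (1) shape.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a tile-level GEMM (not TiledArray's kernel):
// C(M,N) = A(M,K) * B(K,N), all row-major.
void naive_gemm(std::size_t M, std::size_t N, std::size_t K,
                const double* A, const double* B, double* C) {
  for (std::size_t m = 0; m < M; ++m) {
    for (std::size_t n = 0; n < N; ++n) {
      double acc = 0.0;
      for (std::size_t k = 0; k < K; ++k) acc += A[m * K + k] * B[k * N + n];
      C[m * N + n] = acc;
    }
  }
}

// C("i") = A("i,j") * B("i,j"): i is fused (a batch index), j is contracted,
// and no free (external) mode is left on either side.
// Each batch is evaluated as (1,K) x (K) -> (1): the left operand gets a
// synthetic unit extent (M = 1). In this flat illustration the absent right
// external mode is represented by N = 1.
std::vector<double> hadamard_reduce(std::size_t ni, std::size_t nj,
                                    const std::vector<double>& A,
                                    const std::vector<double>& B) {
  std::vector<double> C(ni, 0.0);
  for (std::size_t i = 0; i < ni; ++i) {
    naive_gemm(/*M=*/1, /*N=*/1, /*K=*/nj, &A[i * nj], &B[i * nj], &C[i]);
  }
  return C;
}
```

With A = {1,2,3,4,5,6} and B = {1,1,1,2,2,2} viewed as 2x3 arrays, `hadamard_reduce(2, 3, A, B)` returns {6, 30}. The result keeps its true rank (one mode, i); the unit extent exists only inside the per-batch GEMM call, which mirrors the statement that the unit mode lives only in the tile op.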
Stacked on #563 (which stacks on #562). Completes mixed plain-tensor x tensor-of-tensors support in the expression layer for arbitrary expression trees, plus native support for no-external general products.
What
- The Phase E child-demand deduction moves from MultEngine into `BinaryEngine::init_children_indices` (shared), and ScalMultEngine adopts it together with the full general-product routing (`init_struct_general`, `init_distribution_general`, `make_trange_general`, `make_dist_eval_general`, inner-product classification), replacing its "use einsum() instead" exception. `w("b,i,k;x") = 2.0 * (a("b,i,j") * c("b,j,k;x"))` now evaluates.
- … `w("i,j;x") = (g("b,i") * c("b,j;x")) * h("b")`.
- The inner-Scale element ops (mixed T x ToT) never carried the expression-level scalar prefactor; the fallback op now absorbs `factor_`, and the factor-free fused arena ops are gated to factor == 1 (scaled products take the fallback).
- A general product whose every outer index is fused or contracted (e.g. `C("i,j;a,b") = A("x,i,j;a") * B("x,i,j;b")`) folds to a GEMM with no free modes, i.e. rank-0 tensors, which the tile kernels do not support (this shape segfaulted through wild stride reads). It is now evaluated with a SYNTHETIC UNIT left-external mode: the folded product becomes `(1,K) x (K) -> (1)`, the exact shape of the already-supported one-sided neB == 0 case. The unit mode lives only in the tile op's GemmHelper; tranges, shapes and tiles carry the true (external-free) ranks, and `BatchedContractReduce` / `SparseShape::gemm_batched` detect the synthetic mode from the one-rank mismatch and pad their folded views with a unit extent.

Notable non-findings
- `(s("i,j") * t("j,m")) * c("m,k;x")` and `s("i,j") * (t("j,m") * c("m,k;x"))` already worked unchanged through the Phase E deduction (the empty-inner-demand convention for plain subtrees composes correctly).
- Sums nested under a product, e.g. `f("i,j") = a("x,i") * (b("x,k") * c("x,k,j") + d("x,j"))`, with a general product as a summand, work by construction: an Add's `available_indices()` is the leaf-union of its summands, and the parent's demand intersection prunes summand-internal contraction indices automatically.

Tests
- `expression_mixed_t_tot_depth2_chains` (both nesting orders), `expression_mixed_t_tot_inner_general`, `expression_mixed_t_tot_scaled`.
- `expression_general_sum_under_product`; `expression_general_kitchen_sink`: `w("i,j,m;a,b") = 2.0 * ((g("x,i") * cv("x,j;a")) * dv("x,i,m;b"))`, combining a THC-like batching index, a mixed T x ToT general product, a ToT x ToT general product with an inner outer-product, and a ScalMult prefactor.
- `expression_general_product_block_operands`, `expression_general_product_into_block`, `expression_general_product_block_in_tree`, `expression_general_product_repermute_into_block`.
- Dense ToT (incl. the no-external root fed by a general T x ToT inner node), plain dense (the Hadamard-reduction shape, `C("i") = A("i,j") * B("i,j")`), and block-sparse (exercising the `gemm_batched` unit handling), all differential-tested against legacy einsum.
- … `assign_subblock_block_base1` failures), tot suites: green.

Notes / still out of scope
- … `!e` regime ("hadamard-reduction-local", the arena kernel) handles them before the generalized-contraction dispatch and remains the right tool for distributed workloads. The engine's no-external path uses a degenerate 1x1 process grid (all result tiles on one rank), so it is correctness-first; unifying the einsum regime under the engine remains gated on a perf/distribution comparison (see the design doc's open decisions).
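
The pruning argument under "Notable non-findings" (an Add exposes the leaf-union of its summands, and the parent's demand intersection drops summand-internal contraction indices) can be sketched with plain index sets. This is one reading of that argument, not the engine's code: `unite`, `intersect`, `child_demand` and `sum_demand_example` are hypothetical helpers, and the rule "demand = child's available indices intersected with the parent's target plus the sibling's indices" is an assumption made for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <set>
#include <string>

using IndexSet = std::set<std::string>;

// Set union of two index sets.
IndexSet unite(const IndexSet& a, const IndexSet& b) {
  IndexSet out = a;
  out.insert(b.begin(), b.end());
  return out;
}

// Set intersection of two index sets.
IndexSet intersect(const IndexSet& a, const IndexSet& b) {
  IndexSet out;
  std::set_intersection(a.begin(), a.end(), b.begin(), b.end(),
                        std::inserter(out, out.begin()));
  return out;
}

// ASSUMED rule (hypothetical): a child is asked only for those of its
// available indices that the parent's target or the sibling can see.
IndexSet child_demand(const IndexSet& child_available,
                      const IndexSet& parent_target,
                      const IndexSet& sibling_available) {
  return intersect(child_available, unite(parent_target, sibling_available));
}

// f("i,j") = a("x,i") * (b("x,k") * c("x,k,j") + d("x,j"))
IndexSet sum_demand_example() {
  const IndexSet a{"x", "i"}, b{"x", "k"}, c{"x", "k", "j"}, d{"x", "j"};
  const IndexSet target{"i", "j"};
  // The sum's available indices: leaf-union over both summands = {j, k, x}.
  const IndexSet sum_available = unite(unite(b, c), d);
  // Demand placed on the sum by the parent product.
  return child_demand(sum_available, target, a);
}
```

Under these assumptions `sum_demand_example()` yields {j, x}: the index k, which is contracted inside the first summand, never appears in the parent's target or in the sibling `a`, so the intersection removes it without any special case for sums.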