diff --git a/encodings/runend/src/compute/min_max.rs b/encodings/runend/src/compute/min_max.rs index 2a65d9de2bf..4dd39edc5f3 100644 --- a/encodings/runend/src/compute/min_max.rs +++ b/encodings/runend/src/compute/min_max.rs @@ -28,16 +28,16 @@ impl DynAggregateKernel for RunEndMinMaxKernel { batch: &ArrayRef, ctx: &mut ExecutionCtx, ) -> VortexResult> { - if !aggregate_fn.is::() { + let Some(options) = aggregate_fn.as_opt::() else { return Ok(None); - } + }; let Some(run_end) = batch.as_opt::() else { return Ok(None); }; let struct_dtype = make_minmax_dtype(batch.dtype()); - match min_max(run_end.values(), ctx)? { + match min_max(run_end.values(), ctx, *options)? { Some(result) => Ok(Some(Scalar::struct_( struct_dtype, vec![result.min, result.max], diff --git a/encodings/sparse/benches/sparse_pushdown.rs b/encodings/sparse/benches/sparse_pushdown.rs index d34758e7cd8..1f9b6e565bc 100644 --- a/encodings/sparse/benches/sparse_pushdown.rs +++ b/encodings/sparse/benches/sparse_pushdown.rs @@ -19,6 +19,7 @@ use vortex_array::Canonical; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; +use vortex_array::aggregate_fn::SkipNansOptions; use vortex_array::aggregate_fn::fns::is_constant::is_constant; use vortex_array::aggregate_fn::fns::min_max::min_max; use vortex_array::aggregate_fn::fns::null_count::null_count; @@ -106,7 +107,9 @@ fn sparse_min_max(bencher: Bencher) { bencher .with_inputs(|| (make_sparse(40_000, false), SESSION.create_execution_ctx())) .bench_values(|(array, mut ctx)| { - divan::black_box(min_max(&array, &mut ctx).vortex_expect("min_max")) + divan::black_box( + min_max(&array, &mut ctx, SkipNansOptions::default()).vortex_expect("min_max"), + ) }); } diff --git a/encodings/sparse/src/compute/min_max.rs b/encodings/sparse/src/compute/min_max.rs index cf89bc28da6..0162063b1e7 100644 --- a/encodings/sparse/src/compute/min_max.rs +++ b/encodings/sparse/src/compute/min_max.rs @@ -7,7 +7,6 @@ use 
vortex_array::IntoArray; use vortex_array::aggregate_fn::Accumulator; use vortex_array::aggregate_fn::AggregateFnRef; use vortex_array::aggregate_fn::DynAccumulator; -use vortex_array::aggregate_fn::EmptyOptions; use vortex_array::aggregate_fn::fns::min_max::MinMax; use vortex_array::aggregate_fn::kernels::DynAggregateKernel; use vortex_array::arrays::ConstantArray; @@ -32,9 +31,9 @@ impl DynAggregateKernel for SparseMinMaxKernel { batch: &ArrayRef, ctx: &mut ExecutionCtx, ) -> VortexResult> { - if !aggregate_fn.is::() { + let Some(options) = aggregate_fn.as_opt::() else { return Ok(None); - } + }; let Some(sparse) = batch.as_opt::() else { return Ok(None); @@ -42,7 +41,7 @@ impl DynAggregateKernel for SparseMinMaxKernel { let patches = sparse.patches(); - let mut acc = Accumulator::try_new(MinMax, EmptyOptions, batch.dtype().clone())?; + let mut acc = Accumulator::try_new(MinMax, *options, batch.dtype().clone())?; if !patches.values().is_empty() { acc.accumulate(patches.values(), ctx)?; @@ -66,6 +65,7 @@ mod tests { use rstest::rstest; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; + use vortex_array::aggregate_fn::SkipNansOptions; use vortex_array::aggregate_fn::fns::min_max::MinMaxResult; use vortex_array::aggregate_fn::fns::min_max::min_max; use vortex_array::scalar::Scalar; @@ -100,10 +100,18 @@ mod tests { #[case(Sparse::try_new(buffer![0u64, 1, 2].into_array(), buffer![7i32, 3, 9].into_array(), 3, Scalar::from(99i32)).unwrap())] fn min_max_matches_canonical(#[case] array: SparseArray) { let arr = array.into_array(); - let kernel: Option = - min_max(&arr, &mut SESSION.create_execution_ctx()).unwrap(); - let canonical: Option = - min_max(&arr, &mut CANONICAL_SESSION.create_execution_ctx()).unwrap(); + let kernel: Option = min_max( + &arr, + &mut SESSION.create_execution_ctx(), + SkipNansOptions::default(), + ) + .unwrap(); + let canonical: Option = min_max( + &arr, + &mut CANONICAL_SESSION.create_execution_ctx(), + 
SkipNansOptions::default(), + ) + .unwrap(); assert_eq!(kernel, canonical); } } diff --git a/encodings/sparse/src/compute/sum.rs b/encodings/sparse/src/compute/sum.rs index a47884e6f0f..8f85124b3b5 100644 --- a/encodings/sparse/src/compute/sum.rs +++ b/encodings/sparse/src/compute/sum.rs @@ -7,7 +7,6 @@ use vortex_array::IntoArray; use vortex_array::aggregate_fn::Accumulator; use vortex_array::aggregate_fn::AggregateFnRef; use vortex_array::aggregate_fn::DynAccumulator; -use vortex_array::aggregate_fn::EmptyOptions; use vortex_array::aggregate_fn::fns::sum::Sum; use vortex_array::aggregate_fn::kernels::DynAggregateKernel; use vortex_array::arrays::ConstantArray; @@ -34,9 +33,9 @@ impl DynAggregateKernel for SparseSumKernel { batch: &ArrayRef, ctx: &mut ExecutionCtx, ) -> VortexResult> { - if !aggregate_fn.is::() { + let Some(options) = aggregate_fn.as_opt::() else { return Ok(None); - } + }; let Some(sparse) = batch.as_opt::() else { return Ok(None); @@ -47,8 +46,8 @@ impl DynAggregateKernel for SparseSumKernel { // Build a fresh Sum accumulator over the array dtype and fold in the fill and patch // contributions. The accumulator's existing semantics (checked overflow → null - // partial) are preserved. - let mut acc = Accumulator::try_new(Sum, EmptyOptions, batch.dtype().clone())?; + // partial, NaN handling per the options) are preserved. 
+ let mut acc = Accumulator::try_new(Sum, *options, batch.dtype().clone())?; if n_fill > 0 { let fill_array = ConstantArray::new(sparse.fill_scalar().clone(), n_fill).into_array(); diff --git a/fuzz/src/array/min_max.rs b/fuzz/src/array/min_max.rs index 5600fe27568..2cb22a4f5ae 100644 --- a/fuzz/src/array/min_max.rs +++ b/fuzz/src/array/min_max.rs @@ -4,6 +4,7 @@ use vortex_array::Canonical; use vortex_array::ExecutionCtx; use vortex_array::IntoArray as _; +use vortex_array::aggregate_fn::SkipNansOptions; use vortex_array::aggregate_fn::fns::min_max::MinMaxResult; use vortex_array::aggregate_fn::fns::min_max::min_max; use vortex_error::VortexResult; @@ -13,5 +14,5 @@ pub fn min_max_canonical_array( canonical: Canonical, ctx: &mut ExecutionCtx, ) -> VortexResult> { - min_max(&canonical.into_array(), ctx) + min_max(&canonical.into_array(), ctx, SkipNansOptions::default()) } diff --git a/fuzz/src/array/mod.rs b/fuzz/src/array/mod.rs index e19a38bae86..304991444cd 100644 --- a/fuzz/src/array/mod.rs +++ b/fuzz/src/array/mod.rs @@ -43,6 +43,7 @@ use vortex_array::ArrayRef; use vortex_array::Canonical; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; +use vortex_array::aggregate_fn::SkipNansOptions; use vortex_array::aggregate_fn::fns::all_non_distinct::all_non_distinct; use vortex_array::aggregate_fn::fns::min_max::MinMaxResult; use vortex_array::aggregate_fn::fns::min_max::min_max; @@ -667,7 +668,7 @@ pub fn run_fuzz_action(fuzz_action: FuzzArrayAction) -> VortexFuzzResult { assert_scalar_eq(&expected.scalar(), &sum_result, i)?; } Action::MinMax => { - let min_max_result = min_max(&current_array, &mut ctx) + let min_max_result = min_max(&current_array, &mut ctx, SkipNansOptions::default()) .vortex_expect("min_max operation should succeed in fuzz test"); assert_min_max_eq(expected.min_max().as_ref(), min_max_result.as_ref(), i)?; } diff --git a/vortex-array/benches/aggregate_grouped.rs b/vortex-array/benches/aggregate_grouped.rs index b067314c1d9..72f496a7857 
100644 --- a/vortex-array/benches/aggregate_grouped.rs +++ b/vortex-array/benches/aggregate_grouped.rs @@ -14,8 +14,8 @@ use vortex_array::LEGACY_SESSION; use vortex_array::VortexSessionExecute; use vortex_array::aggregate_fn::AggregateFnVTable; use vortex_array::aggregate_fn::DynGroupedAccumulator; -use vortex_array::aggregate_fn::EmptyOptions; use vortex_array::aggregate_fn::GroupedAccumulator; +use vortex_array::aggregate_fn::SkipNansOptions; use vortex_array::aggregate_fn::fns::count::Count; use vortex_array::aggregate_fn::fns::sum::Sum; use vortex_array::arrays::ListViewArray; @@ -149,10 +149,14 @@ fn list_element_dtype(list_view: &ArrayRef) -> DType { fn grouped_accumulator(list_view: &ArrayRef, vtable: V) -> ArrayRef where - V: AggregateFnVTable + Clone, + V: AggregateFnVTable + Clone, { - let mut acc = - GroupedAccumulator::try_new(vtable, EmptyOptions, list_element_dtype(list_view)).unwrap(); + let mut acc = GroupedAccumulator::try_new( + vtable, + SkipNansOptions::default(), + list_element_dtype(list_view), + ) + .unwrap(); acc.accumulate_list(list_view, &mut LEGACY_SESSION.create_execution_ctx()) .unwrap(); divan::black_box(acc.finish().unwrap()) diff --git a/vortex-array/src/aggregate_fn/accumulator.rs b/vortex-array/src/aggregate_fn/accumulator.rs index 4e1aff762ad..459bbd24f88 100644 --- a/vortex-array/src/aggregate_fn/accumulator.rs +++ b/vortex-array/src/aggregate_fn/accumulator.rs @@ -274,7 +274,7 @@ mod tests { use crate::aggregate_fn::AggregateFnRef; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; - use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::combined::Combined; use crate::aggregate_fn::combined::PairOptions; use crate::aggregate_fn::fns::mean::Mean; @@ -348,7 +348,7 @@ mod tests { let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); Accumulator::try_new( Mean::combined(), - PairOptions(EmptyOptions, EmptyOptions), + 
PairOptions(SkipNansOptions::default(), SkipNansOptions::default()), dtype, ) } diff --git a/vortex-array/src/aggregate_fn/fns/bounded_max/mod.rs b/vortex-array/src/aggregate_fn/fns/bounded_max/mod.rs index 37188f25797..fd648fd7949 100644 --- a/vortex-array/src/aggregate_fn/fns/bounded_max/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/bounded_max/mod.rs @@ -20,7 +20,7 @@ use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnRef; use crate::aggregate_fn::AggregateFnSatisfaction; use crate::aggregate_fn::AggregateFnVTable; -use crate::aggregate_fn::EmptyOptions; +use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::max::Max; use crate::aggregate_fn::fns::min_max::MinMax; use crate::aggregate_fn::fns::min_max::min_max; @@ -136,7 +136,11 @@ impl AggregateFnVTable for BoundedMax { }; } - if requested.is::() { + // The stored bound skips NaNs, so it cannot stand in for a NaN-including maximum. + if requested + .as_opt::() + .is_some_and(|options| options.skip_nans) + { AggregateFnSatisfaction::Approximate } else { AggregateFnSatisfaction::No @@ -192,7 +196,7 @@ impl AggregateFnVTable for BoundedMax { Columnar::Canonical(canonical) => canonical.clone().into_array(), Columnar::Constant(constant) => constant.clone().into_array(), }; - let Some(result) = min_max(&array, ctx)? else { + let Some(result) = min_max(&array, ctx, SkipNansOptions::default())? else { return Ok(()); }; match truncate_max(result.max, partial.max_bytes.get())? 
{ @@ -213,7 +217,7 @@ impl AggregateFnVTable for BoundedMax { fn supported_dtype<'a>(_options: &BoundedMaxOptions, input_dtype: &'a DType) -> Option<&'a DType> { MinMax - .return_dtype(&EmptyOptions, input_dtype) + .return_dtype(&SkipNansOptions::default(), input_dtype) .map(|_| input_dtype) } @@ -253,7 +257,7 @@ mod tests { use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::AggregateFnVTableExt; use crate::aggregate_fn::DynAccumulator; - use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::bounded_max::BoundedMax; use crate::aggregate_fn::fns::bounded_max::BoundedMaxOptions; use crate::aggregate_fn::fns::max::Max; @@ -416,15 +420,23 @@ mod tests { AggregateFnSatisfaction::No ); assert_eq!( - stored.can_satisfy(&Max.bind(EmptyOptions)), + stored.can_satisfy(&Max.bind(SkipNansOptions::default())), AggregateFnSatisfaction::Approximate ); assert_eq!( - Max.bind(EmptyOptions).can_satisfy(&stored), + stored.can_satisfy(&Max.bind(SkipNansOptions::include())), + AggregateFnSatisfaction::No + ); + assert_eq!( + Max.bind(SkipNansOptions::include()).can_satisfy(&stored), + AggregateFnSatisfaction::No + ); + assert_eq!( + Max.bind(SkipNansOptions::default()).can_satisfy(&stored), AggregateFnSatisfaction::Approximate ); assert_eq!( - stored.can_satisfy(&Min.bind(EmptyOptions)), + stored.can_satisfy(&Min.bind(SkipNansOptions::default())), AggregateFnSatisfaction::No ); } diff --git a/vortex-array/src/aggregate_fn/fns/bounded_min/mod.rs b/vortex-array/src/aggregate_fn/fns/bounded_min/mod.rs index 661b7fe1227..971db20cb5a 100644 --- a/vortex-array/src/aggregate_fn/fns/bounded_min/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/bounded_min/mod.rs @@ -20,7 +20,7 @@ use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnRef; use crate::aggregate_fn::AggregateFnSatisfaction; use crate::aggregate_fn::AggregateFnVTable; -use crate::aggregate_fn::EmptyOptions; +use 
crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::min::Min; use crate::aggregate_fn::fns::min_max::MinMax; use crate::aggregate_fn::fns::min_max::min_max; @@ -126,7 +126,11 @@ impl AggregateFnVTable for BoundedMin { }; } - if requested.is::() { + // The stored bound skips NaNs, so it cannot stand in for a NaN-including minimum. + if requested + .as_opt::() + .is_some_and(|options| options.skip_nans) + { AggregateFnSatisfaction::Approximate } else { AggregateFnSatisfaction::No @@ -182,7 +186,7 @@ impl AggregateFnVTable for BoundedMin { Columnar::Canonical(canonical) => canonical.clone().into_array(), Columnar::Constant(constant) => constant.clone().into_array(), }; - let Some(result) = min_max(&array, ctx)? else { + let Some(result) = min_max(&array, ctx, SkipNansOptions::default())? else { return Ok(()); }; if let Some(bound) = truncate_min(result.min, partial.max_bytes.get())? { @@ -202,7 +206,7 @@ impl AggregateFnVTable for BoundedMin { fn supported_dtype<'a>(_options: &BoundedMinOptions, input_dtype: &'a DType) -> Option<&'a DType> { MinMax - .return_dtype(&EmptyOptions, input_dtype) + .return_dtype(&SkipNansOptions::default(), input_dtype) .map(|_| input_dtype) } @@ -241,7 +245,7 @@ mod tests { use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::AggregateFnVTableExt; use crate::aggregate_fn::DynAccumulator; - use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::bounded_min::BoundedMin; use crate::aggregate_fn::fns::bounded_min::BoundedMinOptions; use crate::aggregate_fn::fns::max::Max; @@ -350,15 +354,23 @@ mod tests { AggregateFnSatisfaction::No ); assert_eq!( - stored.can_satisfy(&Min.bind(EmptyOptions)), + stored.can_satisfy(&Min.bind(SkipNansOptions::default())), AggregateFnSatisfaction::Approximate ); assert_eq!( - Min.bind(EmptyOptions).can_satisfy(&stored), + stored.can_satisfy(&Min.bind(SkipNansOptions::include())), + AggregateFnSatisfaction::No + ); + 
assert_eq!( + Min.bind(SkipNansOptions::include()).can_satisfy(&stored), + AggregateFnSatisfaction::No + ); + assert_eq!( + Min.bind(SkipNansOptions::default()).can_satisfy(&stored), AggregateFnSatisfaction::Approximate ); assert_eq!( - stored.can_satisfy(&Max.bind(EmptyOptions)), + stored.can_satisfy(&Max.bind(SkipNansOptions::default())), AggregateFnSatisfaction::No ); } diff --git a/vortex-array/src/aggregate_fn/fns/count/grouped.rs b/vortex-array/src/aggregate_fn/fns/count/grouped.rs index fb94489dde0..90f046c3045 100644 --- a/vortex-array/src/aggregate_fn/fns/count/grouped.rs +++ b/vortex-array/src/aggregate_fn/fns/count/grouped.rs @@ -27,7 +27,12 @@ impl DynGroupedAggregateKernel for CountGroupedKernel { groups: &GroupedArray, ctx: &mut ExecutionCtx, ) -> VortexResult> { - if !aggregate_fn.is::() { + let Some(options) = aggregate_fn.as_opt::() else { + return Ok(None); + }; + // NaN-skipping counts over floats must inspect the element values, which this + // validity-only kernel cannot do; fall back to the per-group accumulator path. + if options.skip_nans && groups.elements().dtype().is_float() { return Ok(None); } try_grouped_count(groups, ctx) @@ -90,8 +95,8 @@ mod tests { use crate::LEGACY_SESSION; use crate::VortexSessionExecute; use crate::aggregate_fn::DynGroupedAccumulator; - use crate::aggregate_fn::EmptyOptions; use crate::aggregate_fn::GroupedAccumulator; + use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::count::Count; use crate::arrays::FixedSizeListArray; use crate::arrays::ListViewArray; @@ -106,7 +111,8 @@ mod tests { /// Run a grouped count through the accumulator. 
fn grouped_count_actual(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult { - let mut acc = GroupedAccumulator::try_new(Count, EmptyOptions, elem_dtype.clone())?; + let mut acc = + GroupedAccumulator::try_new(Count, SkipNansOptions::default(), elem_dtype.clone())?; acc.accumulate_list(groups, &mut LEGACY_SESSION.create_execution_ctx())?; acc.finish() } @@ -200,6 +206,28 @@ mod tests { Ok(()) } + #[test] + fn fixed_size_counts_float_nans() -> VortexResult<()> { + let elements = + PrimitiveArray::from_option_iter([Some(1.0f64), Some(f64::NAN), None, Some(2.0)]) + .into_array(); + let elem_dtype = DType::Primitive(PType::F64, Nullable); + let groups = + FixedSizeListArray::try_new(elements, 2, Validity::NonNullable, 2)?.into_array(); + + // NaNs are excluded by default and counted otherwise. + let actual = grouped_count_actual(&groups, &elem_dtype)?; + let expected = PrimitiveArray::new(buffer![1u64, 1], Validity::NonNullable).into_array(); + assert_arrays_eq!(&actual, &expected); + + let mut acc = GroupedAccumulator::try_new(Count, SkipNansOptions::include(), elem_dtype)?; + acc.accumulate_list(&groups, &mut LEGACY_SESSION.create_execution_ctx())?; + let actual = acc.finish()?; + let expected = PrimitiveArray::new(buffer![2u64, 1], Validity::NonNullable).into_array(); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + #[test] fn fixed_size_counts_with_nulls() -> VortexResult<()> { let elements = diff --git a/vortex-array/src/aggregate_fn/fns/count/mod.rs b/vortex-array/src/aggregate_fn/fns/count/mod.rs index 1fe984fb099..d109352b860 100644 --- a/vortex-array/src/aggregate_fn/fns/count/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/count/mod.rs @@ -11,7 +11,8 @@ use crate::Columnar; use crate::ExecutionCtx; use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnVTable; -use crate::aggregate_fn::EmptyOptions; +use crate::aggregate_fn::SkipNansOptions; +use crate::aggregate_fn::fns::nan_count::nan_count; use crate::dtype::DType; use 
crate::dtype::Nullability; use crate::dtype::PType; @@ -21,12 +22,23 @@ use crate::scalar::Scalar; /// /// Applies to all types. Returns a `u64` count. /// The identity value is zero. +/// +/// For float inputs, NaN handling is controlled by [`SkipNansOptions`]: with `skip_nans` (the +/// default) NaN values are treated as missing and excluded from the count, otherwise they are +/// counted like any other non-null value. #[derive(Clone, Debug)] pub struct Count; +/// Partial accumulator state for the count aggregate. +pub struct CountPartial { + count: u64, + /// Whether NaN values must be excluded from the count (float input with `skip_nans`). + exclude_nans: bool, +} + impl AggregateFnVTable for Count { - type Options = EmptyOptions; - type Partial = u64; + type Options = SkipNansOptions; + type Partial = CountPartial; fn id(&self) -> AggregateFnId { AggregateFnId::new("vortex.count") @@ -46,10 +58,13 @@ impl AggregateFnVTable for Count { fn empty_partial( &self, - _options: &Self::Options, - _input_dtype: &DType, + options: &Self::Options, + input_dtype: &DType, ) -> VortexResult { - Ok(0u64) + Ok(CountPartial { + count: 0, + exclude_nans: options.skip_nans && input_dtype.is_float(), + }) } fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> { @@ -57,16 +72,16 @@ impl AggregateFnVTable for Count { .as_primitive() .typed_value::() .vortex_expect("count partial should not be null"); - *partial += val; + partial.count += val; Ok(()) } fn to_scalar(&self, partial: &Self::Partial) -> VortexResult { - Ok(Scalar::primitive(*partial, Nullability::NonNullable)) + Ok(Scalar::primitive(partial.count, Nullability::NonNullable)) } fn reset(&self, partial: &mut Self::Partial) { - *partial = 0; + partial.count = 0; } #[inline] @@ -80,7 +95,12 @@ impl AggregateFnVTable for Count { batch: &ArrayRef, ctx: &mut ExecutionCtx, ) -> VortexResult { - *state += batch.valid_count(ctx)? as u64; + let mut count = batch.valid_count(ctx)? 
as u64; + if state.exclude_nans { + // `nan_count` shortcircuits on an exact `Stat::NaNCount` before scanning the batch. + count -= nan_count(batch, ctx)? as u64; + } + state.count += count; Ok(true) } @@ -104,19 +124,21 @@ impl AggregateFnVTable for Count { #[cfg(test)] mod tests { + use std::sync::LazyLock; + use vortex_buffer::buffer; use vortex_error::VortexExpect; use vortex_error::VortexResult; + use vortex_session::VortexSession; use crate::ArrayRef; use crate::ExecutionCtx; use crate::IntoArray; - use crate::LEGACY_SESSION; use crate::VortexSessionExecute; use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; - use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::count::Count; use crate::arrays::ChunkedArray; use crate::arrays::ConstantArray; @@ -124,11 +146,17 @@ mod tests { use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType; + use crate::expr::stats::Precision; + use crate::expr::stats::Stat; use crate::scalar::Scalar; + use crate::scalar::ScalarValue; use crate::validity::Validity; + static SESSION: LazyLock = LazyLock::new(vortex_array::array_session); + pub fn count(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { - let mut acc = Accumulator::try_new(Count, EmptyOptions, array.dtype().clone())?; + let mut acc = + Accumulator::try_new(Count, SkipNansOptions::default(), array.dtype().clone())?; acc.accumulate(array, ctx)?; let result = acc.finish()?; @@ -144,7 +172,7 @@ mod tests { fn count_all_valid() -> VortexResult<()> { let array = PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5], Validity::NonNullable).into_array(); - let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mut ctx = SESSION.create_execution_ctx(); assert_eq!(count(&array, &mut ctx)?, 5); Ok(()) } @@ -153,7 +181,7 @@ mod tests { fn count_with_nulls() -> VortexResult<()> { let array = 
PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5)]) .into_array(); - let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mut ctx = SESSION.create_execution_ctx(); assert_eq!(count(&array, &mut ctx)?, 3); Ok(()) } @@ -161,7 +189,7 @@ mod tests { #[test] fn count_all_null() -> VortexResult<()> { let array = PrimitiveArray::from_option_iter::([None, None, None]).into_array(); - let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mut ctx = SESSION.create_execution_ctx(); assert_eq!(count(&array, &mut ctx)?, 0); Ok(()) } @@ -169,7 +197,7 @@ mod tests { #[test] fn count_empty() -> VortexResult<()> { let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Count, EmptyOptions, dtype)?; + let mut acc = Accumulator::try_new(Count, SkipNansOptions::default(), dtype)?; let result = acc.finish()?; assert_eq!(result.as_primitive().typed_value::(), Some(0)); Ok(()) @@ -177,9 +205,9 @@ mod tests { #[test] fn count_multi_batch() -> VortexResult<()> { - let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mut ctx = SESSION.create_execution_ctx(); let dtype = DType::Primitive(PType::I32, Nullability::Nullable); - let mut acc = Accumulator::try_new(Count, EmptyOptions, dtype)?; + let mut acc = Accumulator::try_new(Count, SkipNansOptions::default(), dtype)?; let batch1 = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]).into_array(); acc.accumulate(&batch1, &mut ctx)?; @@ -194,9 +222,9 @@ mod tests { #[test] fn count_finish_resets_state() -> VortexResult<()> { - let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mut ctx = SESSION.create_execution_ctx(); let dtype = DType::Primitive(PType::I32, Nullability::Nullable); - let mut acc = Accumulator::try_new(Count, EmptyOptions, dtype)?; + let mut acc = Accumulator::try_new(Count, SkipNansOptions::default(), dtype)?; let batch1 = PrimitiveArray::from_option_iter([Some(1i32), None]).into_array(); acc.accumulate(&batch1, &mut 
ctx)?; @@ -213,7 +241,7 @@ mod tests { #[test] fn count_state_merge() -> VortexResult<()> { let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut state = Count.empty_partial(&EmptyOptions, &dtype)?; + let mut state = Count.empty_partial(&SkipNansOptions::default(), &dtype)?; let scalar1 = Scalar::primitive(5u64, Nullability::NonNullable); Count.combine_partials(&mut state, scalar1)?; @@ -227,10 +255,73 @@ mod tests { Ok(()) } + fn count_with_options( + array: &ArrayRef, + ctx: &mut ExecutionCtx, + options: SkipNansOptions, + ) -> VortexResult { + let mut acc = Accumulator::try_new(Count, options, array.dtype().clone())?; + acc.accumulate(array, ctx)?; + Ok(acc + .finish()? + .as_primitive() + .typed_value::() + .vortex_expect("count result should not be null")) + } + + #[test] + fn count_float_excludes_nans_by_default() -> VortexResult<()> { + let array = + PrimitiveArray::from_option_iter([Some(1.0f64), Some(f64::NAN), None, Some(3.0)]) + .into_array(); + let mut ctx = SESSION.create_execution_ctx(); + assert_eq!(count(&array, &mut ctx)?, 2); + Ok(()) + } + + #[test] + fn count_float_includes_nans_when_not_skipping() -> VortexResult<()> { + let array = + PrimitiveArray::from_option_iter([Some(1.0f64), Some(f64::NAN), None, Some(3.0)]) + .into_array(); + let mut ctx = SESSION.create_execution_ctx(); + assert_eq!( + count_with_options(&array, &mut ctx, SkipNansOptions::include())?, + 3 + ); + Ok(()) + } + + #[test] + fn count_float_shortcircuits_on_exact_nan_count_stat() -> VortexResult<()> { + // The array has no NaNs; a planted exact NaNCount stat proves the count is derived from + // the stat rather than a scan. 
+ let array = + PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0, 4.0], Validity::NonNullable).into_array(); + array + .statistics() + .set(Stat::NaNCount, Precision::Exact(ScalarValue::from(3u64))); + let mut ctx = SESSION.create_execution_ctx(); + assert_eq!(count(&array, &mut ctx)?, 1); + Ok(()) + } + + #[test] + fn count_constant_nan() -> VortexResult<()> { + let array = ConstantArray::new(f64::NAN, 5).into_array(); + let mut ctx = SESSION.create_execution_ctx(); + assert_eq!(count(&array, &mut ctx)?, 0); + assert_eq!( + count_with_options(&array, &mut ctx, SkipNansOptions::include())?, + 5 + ); + Ok(()) + } + #[test] fn count_constant_non_null() -> VortexResult<()> { let array = ConstantArray::new(42i32, 10); - let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mut ctx = SESSION.create_execution_ctx(); assert_eq!(count(&array.into_array(), &mut ctx)?, 10); Ok(()) } @@ -241,7 +332,7 @@ mod tests { Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)), 10, ); - let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mut ctx = SESSION.create_execution_ctx(); assert_eq!(count(&array.into_array(), &mut ctx)?, 0); Ok(()) } @@ -252,7 +343,7 @@ mod tests { let chunk2 = PrimitiveArray::from_option_iter([None, Some(5i32), None]); let dtype = chunk1.dtype().clone(); let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?; - let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mut ctx = SESSION.create_execution_ctx(); assert_eq!(count(&chunked.into_array(), &mut ctx)?, 3); Ok(()) } diff --git a/vortex-array/src/aggregate_fn/fns/max/mod.rs b/vortex-array/src/aggregate_fn/fns/max/mod.rs index 755e1a4ed35..bc67b7a862b 100644 --- a/vortex-array/src/aggregate_fn/fns/max/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/max/mod.rs @@ -12,15 +12,24 @@ use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnRef; use crate::aggregate_fn::AggregateFnSatisfaction; use 
crate::aggregate_fn::AggregateFnVTable; -use crate::aggregate_fn::EmptyOptions; +use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::bounded_max::BoundedMax; use crate::aggregate_fn::fns::min_max::MinMax; use crate::aggregate_fn::fns::min_max::min_max; +use crate::aggregate_fn::fns::min_max::nan_scalar; +use crate::aggregate_fn::fns::min_max::scalar_is_nan; use crate::dtype::DType; +use crate::expr::stats::Precision; +use crate::expr::stats::Stat; +use crate::expr::stats::StatsProvider; +use crate::expr::stats::StatsProviderExt; use crate::partial_ord::partial_max; use crate::scalar::Scalar; /// Compute the maximum non-null value of an array. +/// +/// NaN handling for float inputs is controlled by [`SkipNansOptions`]: with `skip_nans` (the +/// default) NaN values are ignored, otherwise any NaN value poisons the maximum to NaN. #[derive(Clone, Debug)] pub struct Max; @@ -28,6 +37,7 @@ pub struct Max; pub struct MaxPartial { max: Option, element_dtype: DType, + skip_nans: bool, } impl MaxPartial { @@ -36,15 +46,32 @@ impl MaxPartial { return; } + // NaN scalars are incomparable under `partial_max`; they poison the maximum when NaNs + // participate, and are dropped when they are skipped. 
+ if scalar_is_nan(&max) || self.is_poisoned() { + if !self.skip_nans { + self.poison(); + } + return; + } + self.max = Some(match self.max.take() { Some(current) => partial_max(max, current).vortex_expect("incomparable max scalars"), None => max, }); } + + fn poison(&mut self) { + self.max = Some(nan_scalar(&self.element_dtype)); + } + + fn is_poisoned(&self) -> bool { + self.max.as_ref().is_some_and(scalar_is_nan) + } } impl AggregateFnVTable for Max { - type Options = EmptyOptions; + type Options = SkipNansOptions; type Partial = MaxPartial; fn id(&self) -> AggregateFnId { @@ -55,20 +82,24 @@ impl AggregateFnVTable for Max { Ok(None) } - fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option { + fn return_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option { MinMax - .return_dtype(&EmptyOptions, input_dtype) + .return_dtype(options, input_dtype) .map(|_| input_dtype.as_nullable()) } fn can_satisfy( &self, - _options: &Self::Options, + options: &Self::Options, requested: &AggregateFnRef, ) -> AggregateFnSatisfaction { - if requested.is::() { + if requested + .as_opt::() + .is_some_and(|other| other == options) + { AggregateFnSatisfaction::Exact - } else if requested.is::() { + } else if requested.is::() && options.skip_nans { + // A NaN-including maximum may be NaN, which is not a usable upper bound. AggregateFnSatisfaction::Approximate } else { AggregateFnSatisfaction::No @@ -81,12 +112,13 @@ impl AggregateFnVTable for Max { fn empty_partial( &self, - _options: &Self::Options, + options: &Self::Options, input_dtype: &DType, ) -> VortexResult { Ok(MaxPartial { max: None, element_dtype: input_dtype.clone(), + skip_nans: options.skip_nans, }) } @@ -107,8 +139,38 @@ impl AggregateFnVTable for Max { partial.max = None; } - fn is_saturated(&self, _partial: &Self::Partial) -> bool { - false + fn is_saturated(&self, partial: &Self::Partial) -> bool { + // A poisoned NaN-including maximum is fully determined. 
+ partial.is_poisoned() + } + + fn try_accumulate( + &self, + partial: &mut Self::Partial, + batch: &ArrayRef, + _ctx: &mut ExecutionCtx, + ) -> VortexResult { + // NaN-aware shortcircuits only apply to the NaN-including float maximum; everything else + // takes the default dispatch path. + if partial.skip_nans || !partial.element_dtype.is_float() { + return Ok(false); + } + match batch.statistics().get_as::(Stat::NaNCount) { + Precision::Exact(0) => { + // NaN-free batch: the cached NaN-skipping maximum (if any) is valid. `to_scalar` + // re-casts to the result dtype, so the cached scalar can merge as-is. + if let Some(max) = batch.statistics().get(Stat::Max).as_exact() { + partial.merge(max); + return Ok(true); + } + Ok(false) + } + Precision::Exact(_) => { + partial.poison(); + Ok(true) + } + _ => Ok(false), + } } fn accumulate( @@ -123,7 +185,10 @@ impl AggregateFnVTable for Max { Columnar::Canonical(canonical) => canonical.clone().into_array(), Columnar::Constant(constant) => constant.clone().into_array(), }; - if let Some(result) = min_max(&array, ctx)? { + let options = SkipNansOptions { + skip_nans: partial.skip_nans, + }; + if let Some(result) = min_max(&array, ctx, options)? 
{ partial.merge(result.max); } Ok(()) @@ -148,7 +213,7 @@ mod tests { use crate::VortexSessionExecute; use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::DynAccumulator; - use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::max::Max; use crate::arrays::PrimitiveArray; use crate::dtype::DType; @@ -164,7 +229,7 @@ mod tests { fn max_aggregate_fn() -> VortexResult<()> { let mut ctx = LEGACY_SESSION.create_execution_ctx(); let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Max, EmptyOptions, dtype)?; + let mut acc = Accumulator::try_new(Max, SkipNansOptions::default(), dtype)?; let batch1 = PrimitiveArray::new(buffer![10i32, 20, 5], Validity::NonNullable).into_array(); acc.accumulate(&batch1, &mut ctx)?; @@ -182,7 +247,7 @@ mod tests { #[test] fn max_empty_group_returns_null() -> VortexResult<()> { let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Max, EmptyOptions, dtype)?; + let mut acc = Accumulator::try_new(Max, SkipNansOptions::default(), dtype)?; assert_eq!( acc.finish()?, @@ -191,6 +256,70 @@ mod tests { Ok(()) } + #[test] + fn max_with_nan_not_skipping() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); + let mut acc = Accumulator::try_new(Max, SkipNansOptions::include(), dtype)?; + + let batch = PrimitiveArray::new(buffer![1.0f64, f64::NAN, -5.0], Validity::NonNullable) + .into_array(); + acc.accumulate(&batch, &mut ctx)?; + assert!(acc.is_saturated()); + + let result = acc.finish()?; + assert!( + result + .as_primitive() + .typed_value::() + .is_some_and(f64::is_nan) + ); + Ok(()) + } + + #[test] + fn max_not_skipping_shortcircuits_on_exact_nan_count_stat() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + // The array has no NaNs; a planted exact NaNCount 
stat proves the poisoning came from + // the stat rather than a scan. + let batch = PrimitiveArray::new(buffer![1.0f64, 2.0], Validity::NonNullable).into_array(); + batch + .statistics() + .set(Stat::NaNCount, Precision::Exact(ScalarValue::from(1u64))); + let mut acc = Accumulator::try_new(Max, SkipNansOptions::include(), batch.dtype().clone())?; + acc.accumulate(&batch, &mut ctx)?; + let result = acc.finish()?; + assert!( + result + .as_primitive() + .typed_value::() + .is_some_and(f64::is_nan) + ); + Ok(()) + } + + #[test] + fn max_nan_including_nullable_cached_stat() -> VortexResult<()> { + // A nullable float array's cached Max stat is reconstructed as a nullable scalar. The + // NaN-including shortcircuit merges it as-is; `to_scalar` re-casts to the result dtype. + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let array = + PrimitiveArray::from_option_iter([Some(1.0f64), Some(2.0), Some(3.0)]).into_array(); + array + .statistics() + .set(Stat::NaNCount, Precision::Exact(ScalarValue::from(0u64))); + array + .statistics() + .set(Stat::Max, Precision::Exact(ScalarValue::from(3.0f64))); + let mut acc = Accumulator::try_new(Max, SkipNansOptions::include(), array.dtype().clone())?; + acc.accumulate(&array, &mut ctx)?; + assert_eq!( + acc.finish()?, + Scalar::primitive(3.0f64, Nullability::Nullable) + ); + Ok(()) + } + #[test] fn max_casts_nonnullable_legacy_stat_to_nullable_partial() -> VortexResult<()> { let mut ctx = LEGACY_SESSION.create_execution_ctx(); @@ -198,7 +327,7 @@ mod tests { batch .statistics() .set(Stat::Max, Precision::Exact(ScalarValue::from(25i32))); - let mut acc = Accumulator::try_new(Max, EmptyOptions, batch.dtype().clone())?; + let mut acc = Accumulator::try_new(Max, SkipNansOptions::default(), batch.dtype().clone())?; acc.accumulate(&batch, &mut ctx)?; diff --git a/vortex-array/src/aggregate_fn/fns/mean/mod.rs b/vortex-array/src/aggregate_fn/fns/mean/mod.rs index 17fb2616d44..3338ad96084 100644 --- 
a/vortex-array/src/aggregate_fn/fns/mean/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/mean/mod.rs @@ -10,7 +10,7 @@ use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; -use crate::aggregate_fn::EmptyOptions; +use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::combined::BinaryCombined; use crate::aggregate_fn::combined::Combined; use crate::aggregate_fn::combined::CombinedOptions; @@ -30,7 +30,7 @@ use crate::scalar_fn::fns::operators::Operator; pub fn mean(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { let mut acc = Accumulator::try_new( Mean::combined(), - PairOptions(EmptyOptions, EmptyOptions), + PairOptions(SkipNansOptions::default(), SkipNansOptions::default()), array.dtype().clone(), )?; acc.accumulate(array, ctx)?; @@ -231,6 +231,34 @@ mod tests { Ok(()) } + #[test] + fn mean_skips_nans_by_default() -> VortexResult<()> { + // NaNs are excluded from both the sum and the count. 
+ let array = + PrimitiveArray::new(buffer![1.0f64, f64::NAN, 3.0], Validity::NonNullable).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array, &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(2.0)); + Ok(()) + } + + #[test] + fn mean_with_nan_not_skipping() -> VortexResult<()> { + let array = + PrimitiveArray::new(buffer![1.0f64, f64::NAN, 3.0], Validity::NonNullable).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let keep_nans = SkipNansOptions::include(); + let mut acc = Accumulator::try_new( + Mean::combined(), + PairOptions(keep_nans, keep_nans), + array.dtype().clone(), + )?; + acc.accumulate(&array, &mut ctx)?; + let result = acc.finish()?; + assert!(result.as_primitive().as_::().is_some_and(f64::is_nan)); + Ok(()) + } + #[test] fn mean_all_null_returns_nan() -> VortexResult<()> { let array = PrimitiveArray::from_option_iter::([None, None, None]).into_array(); @@ -246,7 +274,7 @@ mod tests { let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); let mut acc = Accumulator::try_new( Mean::combined(), - PairOptions(EmptyOptions, EmptyOptions), + PairOptions(SkipNansOptions::default(), SkipNansOptions::default()), dtype, )?; diff --git a/vortex-array/src/aggregate_fn/fns/min/mod.rs b/vortex-array/src/aggregate_fn/fns/min/mod.rs index b176bbaf742..2063f99affb 100644 --- a/vortex-array/src/aggregate_fn/fns/min/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/min/mod.rs @@ -12,15 +12,24 @@ use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnRef; use crate::aggregate_fn::AggregateFnSatisfaction; use crate::aggregate_fn::AggregateFnVTable; -use crate::aggregate_fn::EmptyOptions; +use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::bounded_min::BoundedMin; use crate::aggregate_fn::fns::min_max::MinMax; use crate::aggregate_fn::fns::min_max::min_max; +use crate::aggregate_fn::fns::min_max::nan_scalar; +use 
crate::aggregate_fn::fns::min_max::scalar_is_nan; use crate::dtype::DType; +use crate::expr::stats::Precision; +use crate::expr::stats::Stat; +use crate::expr::stats::StatsProvider; +use crate::expr::stats::StatsProviderExt; use crate::partial_ord::partial_min; use crate::scalar::Scalar; /// Compute the minimum non-null value of an array. +/// +/// NaN handling for float inputs is controlled by [`SkipNansOptions`]: with `skip_nans` (the +/// default) NaN values are ignored, otherwise any NaN value poisons the minimum to NaN. #[derive(Clone, Debug)] pub struct Min; @@ -28,6 +37,7 @@ pub struct Min; pub struct MinPartial { min: Option, element_dtype: DType, + skip_nans: bool, } impl MinPartial { @@ -36,15 +46,32 @@ impl MinPartial { return; } + // NaN scalars are incomparable under `partial_min`; they poison the minimum when NaNs + // participate, and are dropped when they are skipped. + if scalar_is_nan(&min) || self.is_poisoned() { + if !self.skip_nans { + self.poison(); + } + return; + } + self.min = Some(match self.min.take() { Some(current) => partial_min(min, current).vortex_expect("incomparable min scalars"), None => min, }); } + + fn poison(&mut self) { + self.min = Some(nan_scalar(&self.element_dtype)); + } + + fn is_poisoned(&self) -> bool { + self.min.as_ref().is_some_and(scalar_is_nan) + } } impl AggregateFnVTable for Min { - type Options = EmptyOptions; + type Options = SkipNansOptions; type Partial = MinPartial; fn id(&self) -> AggregateFnId { @@ -55,20 +82,24 @@ impl AggregateFnVTable for Min { Ok(None) } - fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option { + fn return_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option { MinMax - .return_dtype(&EmptyOptions, input_dtype) + .return_dtype(options, input_dtype) .map(|_| input_dtype.as_nullable()) } fn can_satisfy( &self, - _options: &Self::Options, + options: &Self::Options, requested: &AggregateFnRef, ) -> AggregateFnSatisfaction { - if requested.is::() { 
+ if requested + .as_opt::() + .is_some_and(|other| other == options) + { AggregateFnSatisfaction::Exact - } else if requested.is::() { + } else if requested.is::() && options.skip_nans { + // A NaN-including minimum may be NaN, which is not a usable lower bound. AggregateFnSatisfaction::Approximate } else { AggregateFnSatisfaction::No @@ -81,12 +112,13 @@ impl AggregateFnVTable for Min { fn empty_partial( &self, - _options: &Self::Options, + options: &Self::Options, input_dtype: &DType, ) -> VortexResult { Ok(MinPartial { min: None, element_dtype: input_dtype.clone(), + skip_nans: options.skip_nans, }) } @@ -107,8 +139,38 @@ impl AggregateFnVTable for Min { partial.min = None; } - fn is_saturated(&self, _partial: &Self::Partial) -> bool { - false + fn is_saturated(&self, partial: &Self::Partial) -> bool { + // A poisoned NaN-including minimum is fully determined. + partial.is_poisoned() + } + + fn try_accumulate( + &self, + partial: &mut Self::Partial, + batch: &ArrayRef, + _ctx: &mut ExecutionCtx, + ) -> VortexResult { + // NaN-aware shortcircuits only apply to the NaN-including float minimum; everything else + // takes the default dispatch path. + if partial.skip_nans || !partial.element_dtype.is_float() { + return Ok(false); + } + match batch.statistics().get_as::(Stat::NaNCount) { + Precision::Exact(0) => { + // NaN-free batch: the cached NaN-skipping minimum (if any) is valid. `to_scalar` + // re-casts to the result dtype, so the cached scalar can merge as-is. + if let Some(min) = batch.statistics().get(Stat::Min).as_exact() { + partial.merge(min); + return Ok(true); + } + Ok(false) + } + Precision::Exact(_) => { + partial.poison(); + Ok(true) + } + _ => Ok(false), + } } fn accumulate( @@ -123,7 +185,10 @@ impl AggregateFnVTable for Min { Columnar::Canonical(canonical) => canonical.clone().into_array(), Columnar::Constant(constant) => constant.clone().into_array(), }; - if let Some(result) = min_max(&array, ctx)? 
{ + let options = SkipNansOptions { + skip_nans: partial.skip_nans, + }; + if let Some(result) = min_max(&array, ctx, options)? { partial.merge(result.min); } Ok(()) @@ -148,7 +213,7 @@ mod tests { use crate::VortexSessionExecute; use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::DynAccumulator; - use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::min::Min; use crate::arrays::PrimitiveArray; use crate::dtype::DType; @@ -164,7 +229,7 @@ mod tests { fn min_aggregate_fn() -> VortexResult<()> { let mut ctx = LEGACY_SESSION.create_execution_ctx(); let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Min, EmptyOptions, dtype)?; + let mut acc = Accumulator::try_new(Min, SkipNansOptions::default(), dtype)?; let batch1 = PrimitiveArray::new(buffer![10i32, 20, 5], Validity::NonNullable).into_array(); acc.accumulate(&batch1, &mut ctx)?; @@ -182,7 +247,7 @@ mod tests { #[test] fn min_empty_group_returns_null() -> VortexResult<()> { let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Min, EmptyOptions, dtype)?; + let mut acc = Accumulator::try_new(Min, SkipNansOptions::default(), dtype)?; assert_eq!( acc.finish()?, @@ -191,6 +256,70 @@ mod tests { Ok(()) } + #[test] + fn min_with_nan_not_skipping() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); + let mut acc = Accumulator::try_new(Min, SkipNansOptions::include(), dtype)?; + + let batch = PrimitiveArray::new(buffer![1.0f64, f64::NAN, -5.0], Validity::NonNullable) + .into_array(); + acc.accumulate(&batch, &mut ctx)?; + assert!(acc.is_saturated()); + + let result = acc.finish()?; + assert!( + result + .as_primitive() + .typed_value::() + .is_some_and(f64::is_nan) + ); + Ok(()) + } + + #[test] + fn min_not_skipping_shortcircuits_on_exact_nan_count_stat() -> 
VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + // The array has no NaNs; a planted exact NaNCount stat proves the poisoning came from + // the stat rather than a scan. + let batch = PrimitiveArray::new(buffer![1.0f64, 2.0], Validity::NonNullable).into_array(); + batch + .statistics() + .set(Stat::NaNCount, Precision::Exact(ScalarValue::from(1u64))); + let mut acc = Accumulator::try_new(Min, SkipNansOptions::include(), batch.dtype().clone())?; + acc.accumulate(&batch, &mut ctx)?; + let result = acc.finish()?; + assert!( + result + .as_primitive() + .typed_value::() + .is_some_and(f64::is_nan) + ); + Ok(()) + } + + #[test] + fn min_nan_including_nullable_cached_stat() -> VortexResult<()> { + // A nullable float array's cached Min stat is reconstructed as a nullable scalar. The + // NaN-including shortcircuit merges it as-is; `to_scalar` re-casts to the result dtype. + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let array = + PrimitiveArray::from_option_iter([Some(1.0f64), Some(2.0), Some(3.0)]).into_array(); + array + .statistics() + .set(Stat::NaNCount, Precision::Exact(ScalarValue::from(0u64))); + array + .statistics() + .set(Stat::Min, Precision::Exact(ScalarValue::from(1.0f64))); + let mut acc = Accumulator::try_new(Min, SkipNansOptions::include(), array.dtype().clone())?; + acc.accumulate(&array, &mut ctx)?; + assert_eq!( + acc.finish()?, + Scalar::primitive(1.0f64, Nullability::Nullable) + ); + Ok(()) + } + #[test] fn min_casts_nonnullable_legacy_stat_to_nullable_partial() -> VortexResult<()> { let mut ctx = LEGACY_SESSION.create_execution_ctx(); @@ -198,7 +327,7 @@ mod tests { batch .statistics() .set(Stat::Min, Precision::Exact(ScalarValue::from(3i32))); - let mut acc = Accumulator::try_new(Min, EmptyOptions, batch.dtype().clone())?; + let mut acc = Accumulator::try_new(Min, SkipNansOptions::default(), batch.dtype().clone())?; acc.accumulate(&batch, &mut ctx)?; diff --git 
a/vortex-array/src/aggregate_fn/fns/min_max/extension.rs b/vortex-array/src/aggregate_fn/fns/min_max/extension.rs index 60962dd008e..b366edb7b5f 100644 --- a/vortex-array/src/aggregate_fn/fns/min_max/extension.rs +++ b/vortex-array/src/aggregate_fn/fns/min_max/extension.rs @@ -7,6 +7,7 @@ use super::MinMaxPartial; use super::MinMaxResult; use super::min_max; use crate::ExecutionCtx; +use crate::aggregate_fn::SkipNansOptions; use crate::arrays::ExtensionArray; use crate::arrays::extension::ExtensionArrayExt; use crate::dtype::Nullability; @@ -18,11 +19,12 @@ pub(super) fn accumulate_extension( ctx: &mut ExecutionCtx, ) -> VortexResult<()> { let non_nullable_ext_dtype = array.ext_dtype().with_nullability(Nullability::NonNullable); - let local = - min_max(array.storage_array(), ctx)?.map(|MinMaxResult { min, max }| MinMaxResult { + let local = min_max(array.storage_array(), ctx, SkipNansOptions::default())?.map( + |MinMaxResult { min, max }| MinMaxResult { min: Scalar::extension_ref(non_nullable_ext_dtype.clone(), min), max: Scalar::extension_ref(non_nullable_ext_dtype, max), - }); + }, + ); partial.merge(local); Ok(()) } diff --git a/vortex-array/src/aggregate_fn/fns/min_max/mod.rs b/vortex-array/src/aggregate_fn/fns/min_max/mod.rs index 540b5608e28..e43e595435a 100644 --- a/vortex-array/src/aggregate_fn/fns/min_max/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/min_max/mod.rs @@ -12,6 +12,7 @@ use std::sync::LazyLock; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; +use vortex_error::vortex_panic; use self::bool::accumulate_bool; use self::decimal::accumulate_decimal; @@ -26,14 +27,17 @@ use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; -use crate::aggregate_fn::EmptyOptions; +use crate::aggregate_fn::SkipNansOptions; use crate::dtype::DType; use crate::dtype::FieldNames; use crate::dtype::Nullability; +use 
crate::dtype::PType; use crate::dtype::StructFields; +use crate::dtype::half::f16; use crate::expr::stats::Precision; use crate::expr::stats::Stat; use crate::expr::stats::StatsProvider; +use crate::expr::stats::StatsProviderExt; use crate::partial_ord::partial_max; use crate::partial_ord::partial_min; use crate::scalar::Scalar; @@ -42,9 +46,40 @@ static NAMES: LazyLock = LazyLock::new(|| FieldNames::from(["min", " /// The minimum and maximum non-null values of an array, or `None` if there are no non-null values. /// +/// NaN handling for float inputs is controlled by [`SkipNansOptions`]: with `skip_nans` (the +/// default) NaN values are ignored and the cached `Stat::Min`/`Stat::Max` statistics are consulted +/// and updated. With `skip_nans=false`, any NaN value in a float array poisons both extrema to +/// NaN; an exact `Stat::NaNCount` statistic shortcircuits the NaN scan in either direction. +/// /// The result scalars have the non-nullable version of the array dtype. /// This will update the stats set of the array as a side effect. -pub fn min_max(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult> { +pub fn min_max( + array: &ArrayRef, + ctx: &mut ExecutionCtx, + options: SkipNansOptions, +) -> VortexResult> { + if !options.skip_nans && array.dtype().is_float() { + match array.statistics().get_as::(Stat::NaNCount) { + // NaN-free: identical to the NaN-skipping path below, including its stat caching. + Precision::Exact(0) => {} + // At least one NaN value poisons both extrema. + Precision::Exact(_) => return Ok(Some(nan_minmax_result(array.dtype()))), + _ => { + if array.is_empty() || array.valid_count(ctx)? == 0 { + return Ok(None); + } + // Compute with NaN-including options; the NaN-skipping `Stat::Min`/`Stat::Max` + // caches are neither read nor written. 
+ let mut acc = Accumulator::try_new(MinMax, options, array.dtype().clone())?; + acc.accumulate(array, ctx)?; + return MinMaxResult::from_scalar(acc.finish()?); + } + } + } + + // NaN-skipping path. Also reached for NaN-free not-skipping float arrays and all non-float + // arrays, where `skip_nans` has no effect. + // Short-circuit using cached array statistics. let cached_min = array.statistics().get(Stat::Min).as_exact(); let cached_max = array.statistics().get(Stat::Max).as_exact(); @@ -67,7 +102,7 @@ pub fn min_max(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult. - let mut acc = Accumulator::try_new(MinMax, EmptyOptions, array.dtype().clone())?; + let mut acc = Accumulator::try_new(MinMax, SkipNansOptions::default(), array.dtype().clone())?; acc.accumulate(array, ctx)?; let result_scalar = acc.finish()?; let result = MinMaxResult::from_scalar(result_scalar)?; @@ -89,6 +124,30 @@ pub fn min_max(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult MinMaxResult { + let nan = nan_scalar(dtype); + MinMaxResult { + min: nan.clone(), + max: nan, + } +} + +/// A non-nullable NaN scalar of the float `dtype`. +pub(crate) fn nan_scalar(dtype: &DType) -> Scalar { + match dtype.as_ptype() { + PType::F16 => Scalar::primitive(f16::NAN, Nullability::NonNullable), + PType::F32 => Scalar::primitive(f32::NAN, Nullability::NonNullable), + PType::F64 => Scalar::primitive(f64::NAN, Nullability::NonNullable), + _ => vortex_panic!("NaN scalar requested for non-float dtype {dtype}"), + } +} + +/// Whether a scalar holds a primitive float NaN value. +pub(crate) fn scalar_is_nan(scalar: &Scalar) -> bool { + scalar.as_primitive_opt().is_some_and(|p| p.is_nan()) +} + /// The minimum and maximum non-null values of an array. #[derive(Debug, Clone, PartialEq, Eq)] pub struct MinMaxResult { @@ -119,6 +178,9 @@ impl MinMaxResult { /// /// Returns a nullable struct scalar `{min: T, max: T}` where `T` is the non-nullable input dtype. 
/// The struct is null when the array is empty or all-null. +/// +/// NaN handling for float inputs is controlled by [`SkipNansOptions`]: with `skip_nans` (the +/// default) NaN values are ignored, otherwise any NaN value poisons both extrema to NaN. #[derive(Clone, Debug)] pub struct MinMax; @@ -127,6 +189,7 @@ pub struct MinMaxPartial { min: Option, max: Option, element_dtype: DType, + skip_nans: bool, } impl MinMaxPartial { @@ -136,6 +199,16 @@ impl MinMaxPartial { return; }; + // NaN scalars are incomparable under `partial_min`/`partial_max`, so they are handled + // explicitly: a NaN extremum poisons the partial state when NaNs participate, and is + // dropped when they are skipped. + if scalar_is_nan(&min) || scalar_is_nan(&max) || self.is_poisoned() { + if !self.skip_nans { + self.poison(); + } + return; + } + self.min = Some(match self.min.take() { Some(current) => partial_min(min, current).vortex_expect("incomparable min scalars"), None => min, @@ -146,6 +219,18 @@ impl MinMaxPartial { None => max, }); } + + /// Poison the partial state to `{min: NaN, max: NaN}`. + fn poison(&mut self) { + let nan = nan_scalar(&self.element_dtype); + self.min = Some(nan.clone()); + self.max = Some(nan); + } + + /// Whether the partial state is poisoned to NaN. + fn is_poisoned(&self) -> bool { + self.min.as_ref().is_some_and(scalar_is_nan) + } } /// Creates the struct dtype `{min: T, max: T}` (nullable) used for min/max aggregate results. 
@@ -194,7 +279,7 @@ fn minmax_compute_supported_dtype(input_dtype: &DType) -> bool { } impl AggregateFnVTable for MinMax { - type Options = EmptyOptions; + type Options = SkipNansOptions; type Partial = MinMaxPartial; fn id(&self) -> AggregateFnId { @@ -215,13 +300,14 @@ impl AggregateFnVTable for MinMax { fn empty_partial( &self, - _options: &Self::Options, + options: &Self::Options, input_dtype: &DType, ) -> VortexResult { Ok(MinMaxPartial { min: None, max: None, element_dtype: input_dtype.clone(), + skip_nans: options.skip_nans, }) } @@ -245,8 +331,46 @@ impl AggregateFnVTable for MinMax { } #[inline] - fn is_saturated(&self, _partial: &Self::Partial) -> bool { - false + fn is_saturated(&self, partial: &Self::Partial) -> bool { + // A poisoned NaN-including min/max is fully determined. + partial.is_poisoned() + } + + fn try_accumulate( + &self, + partial: &mut Self::Partial, + batch: &ArrayRef, + _ctx: &mut ExecutionCtx, + ) -> VortexResult { + // NaN-aware shortcircuits only apply to NaN-including float min/max; everything else + // takes the default dispatch path. + if partial.skip_nans || !partial.element_dtype.is_float() { + return Ok(false); + } + match batch.statistics().get_as::(Stat::NaNCount) { + Precision::Exact(0) => { + // NaN-free batch: the cached NaN-skipping extrema (if any) are valid. + let cached_min = batch.statistics().get(Stat::Min).as_exact(); + let cached_max = batch.statistics().get(Stat::Max).as_exact(); + if let Some((min, max)) = cached_min.zip(cached_max) { + // Cached float stats carry the (possibly nullable) array dtype; `to_scalar` + // builds a struct with non-nullable fields, so normalise here. + let non_nullable_dtype = partial.element_dtype.as_nonnullable(); + partial.merge(Some(MinMaxResult { + min: min.cast(&non_nullable_dtype)?, + max: max.cast(&non_nullable_dtype)?, + })); + return Ok(true); + } + Ok(false) + } + Precision::Exact(_) => { + // At least one NaN value poisons both extrema without scanning the batch. 
+ partial.poison(); + Ok(true) + } + _ => Ok(false), + } } fn accumulate( @@ -261,8 +385,11 @@ impl AggregateFnVTable for MinMax { if scalar.is_null() { return Ok(()); } - // Skip NaN float constants - if scalar.as_primitive_opt().is_some_and(|p| p.is_nan()) { + // NaN float constants are skipped or poison the extrema, per the options. + if scalar_is_nan(scalar) { + if !partial.skip_nans { + partial.poison(); + } return Ok(()); } let non_nullable_dtype = scalar.dtype().as_nonnullable(); @@ -302,19 +429,20 @@ impl AggregateFnVTable for MinMax { #[cfg(test)] mod tests { use std::sync::Arc; + use std::sync::LazyLock; use vortex_buffer::BitBuffer; use vortex_buffer::buffer; use vortex_error::VortexExpect; use vortex_error::VortexResult; + use vortex_session::VortexSession; use crate::IntoArray as _; - use crate::LEGACY_SESSION; use crate::VortexSessionExecute; use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; - use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::min_max::MinMax; use crate::aggregate_fn::fns::min_max::MinMaxResult; use crate::aggregate_fn::fns::min_max::make_minmax_dtype; @@ -332,17 +460,21 @@ mod tests { use crate::dtype::DecimalDType; use crate::dtype::Nullability; use crate::dtype::PType; + use crate::expr::stats::Precision; + use crate::expr::stats::Stat; use crate::scalar::DecimalValue; use crate::scalar::Scalar; use crate::scalar::ScalarValue; use crate::validity::Validity; + static SESSION: LazyLock = LazyLock::new(vortex_array::array_session); + #[test] fn test_prim_min_max() -> VortexResult<()> { let p = PrimitiveArray::new(buffer![1, 2, 3], Validity::NonNullable).into_array(); - let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mut ctx = SESSION.create_execution_ctx(); assert_eq!( - min_max(&p, &mut ctx)?, + min_max(&p, &mut ctx, SkipNansOptions::default())?, Some(MinMaxResult { min: 1.into(), max: 
3.into() @@ -366,9 +498,9 @@ mod tests { Some(7), ]) .into_array(); - let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mut ctx = SESSION.create_execution_ctx(); assert_eq!( - min_max(&p, &mut ctx)?, + min_max(&p, &mut ctx, SkipNansOptions::default())?, Some(MinMaxResult { min: 1.into(), max: 9.into() @@ -379,7 +511,7 @@ mod tests { #[test] fn test_bool_min_max() -> VortexResult<()> { - let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mut ctx = SESSION.create_execution_ctx(); let all_true = BoolArray::new( BitBuffer::from([true, true, true].as_slice()), @@ -387,7 +519,7 @@ mod tests { ) .into_array(); assert_eq!( - min_max(&all_true, &mut ctx)?, + min_max(&all_true, &mut ctx, SkipNansOptions::default())?, Some(MinMaxResult { min: true.into(), max: true.into() @@ -400,7 +532,7 @@ mod tests { ) .into_array(); assert_eq!( - min_max(&all_false, &mut ctx)?, + min_max(&all_false, &mut ctx, SkipNansOptions::default())?, Some(MinMaxResult { min: false.into(), max: false.into() @@ -413,7 +545,7 @@ mod tests { ) .into_array(); assert_eq!( - min_max(&mixed, &mut ctx)?, + min_max(&mixed, &mut ctx, SkipNansOptions::default())?, Some(MinMaxResult { min: false.into(), max: true.into() @@ -425,8 +557,8 @@ mod tests { #[test] fn test_null_array() -> VortexResult<()> { let p = NullArray::new(1).into_array(); - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - assert_eq!(min_max(&p, &mut ctx)?, None); + let mut ctx = SESSION.create_execution_ctx(); + assert_eq!(min_max(&p, &mut ctx, SkipNansOptions::default())?, None); Ok(()) } @@ -436,8 +568,9 @@ mod tests { buffer![f32::NAN, -f32::NAN, -1.0, 1.0], Validity::NonNullable, ); - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - let result = min_max(&array.into_array(), &mut ctx)?.vortex_expect("should have result"); + let mut ctx = SESSION.create_execution_ctx(); + let result = min_max(&array.into_array(), &mut ctx, SkipNansOptions::default())? 
+ .vortex_expect("should have result"); assert_eq!(f32::try_from(&result.min)?, -1.0); assert_eq!(f32::try_from(&result.max)?, 1.0); Ok(()) @@ -449,8 +582,9 @@ mod tests { buffer![f32::INFINITY, f32::NEG_INFINITY, -1.0, 1.0], Validity::NonNullable, ); - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - let result = min_max(&array.into_array(), &mut ctx)?.vortex_expect("should have result"); + let mut ctx = SESSION.create_execution_ctx(); + let result = min_max(&array.into_array(), &mut ctx, SkipNansOptions::default())? + .vortex_expect("should have result"); assert_eq!(f32::try_from(&result.min)?, f32::NEG_INFINITY); assert_eq!(f32::try_from(&result.max)?, f32::INFINITY); Ok(()) @@ -458,9 +592,9 @@ mod tests { #[test] fn test_multi_batch() -> VortexResult<()> { - let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mut ctx = SESSION.create_execution_ctx(); let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut acc = Accumulator::try_new(MinMax, EmptyOptions, dtype)?; + let mut acc = Accumulator::try_new(MinMax, SkipNansOptions::default(), dtype)?; let batch1 = PrimitiveArray::new(buffer![10i32, 20, 5], Validity::NonNullable).into_array(); acc.accumulate(&batch1, &mut ctx)?; @@ -476,9 +610,9 @@ mod tests { #[test] fn test_finish_resets_state() -> VortexResult<()> { - let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mut ctx = SESSION.create_execution_ctx(); let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut acc = Accumulator::try_new(MinMax, EmptyOptions, dtype)?; + let mut acc = Accumulator::try_new(MinMax, SkipNansOptions::default(), dtype)?; let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); acc.accumulate(&batch1, &mut ctx)?; @@ -497,7 +631,7 @@ mod tests { #[test] fn test_state_merge() -> VortexResult<()> { let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut state = MinMax.empty_partial(&EmptyOptions, &dtype)?; + let mut state = 
MinMax.empty_partial(&SkipNansOptions::default(), &dtype)?; let struct_dtype = make_minmax_dtype(&dtype); let scalar1 = Scalar::struct_( @@ -520,19 +654,136 @@ mod tests { fn test_constant_nan() -> VortexResult<()> { let scalar = Scalar::primitive(f16::NAN, Nullability::NonNullable); let array = ConstantArray::new(scalar, 2).into_array(); - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - assert_eq!(min_max(&array, &mut ctx)?, None); + let mut ctx = SESSION.create_execution_ctx(); + assert_eq!(min_max(&array, &mut ctx, SkipNansOptions::default())?, None); + Ok(()) + } + + const KEEP_NANS: SkipNansOptions = SkipNansOptions::include(); + + fn assert_poisoned(result: Option) -> VortexResult<()> { + let result = result.vortex_expect("should have result"); + assert!(f64::try_from(&result.min.cast(&result.min.dtype().as_nullable())?)?.is_nan()); + assert!(f64::try_from(&result.max.cast(&result.max.dtype().as_nullable())?)?.is_nan()); + Ok(()) + } + + #[test] + fn test_prim_nan_not_skipping() -> VortexResult<()> { + let array = PrimitiveArray::new( + buffer![f32::NAN, -f32::NAN, -1.0, 1.0], + Validity::NonNullable, + ) + .into_array(); + let mut ctx = SESSION.create_execution_ctx(); + assert_poisoned(min_max(&array, &mut ctx, KEEP_NANS)?) + } + + #[test] + fn test_prim_no_nan_not_skipping() -> VortexResult<()> { + let array = + PrimitiveArray::new(buffer![3.0f32, -1.0, 1.0], Validity::NonNullable).into_array(); + let mut ctx = SESSION.create_execution_ctx(); + let result = min_max(&array, &mut ctx, KEEP_NANS)?.vortex_expect("should have result"); + assert_eq!(f32::try_from(&result.min)?, -1.0); + assert_eq!(f32::try_from(&result.max)?, 3.0); + Ok(()) + } + + #[test] + fn test_constant_nan_not_skipping() -> VortexResult<()> { + let scalar = Scalar::primitive(f64::NAN, Nullability::NonNullable); + let array = ConstantArray::new(scalar, 2).into_array(); + let mut ctx = SESSION.create_execution_ctx(); + assert_poisoned(min_max(&array, &mut ctx, KEEP_NANS)?) 
+ } + + #[test] + fn test_not_skipping_shortcircuits_on_exact_nan_count_stat() -> VortexResult<()> { + // The array has no NaNs; a planted exact NaNCount stat proves the poisoning came from + // the stat rather than a scan. + let array = + PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0], Validity::NonNullable).into_array(); + array + .statistics() + .set(Stat::NaNCount, Precision::Exact(ScalarValue::from(2u64))); + let mut ctx = SESSION.create_execution_ctx(); + assert_poisoned(min_max(&array, &mut ctx, KEEP_NANS)?) + } + + #[test] + fn test_not_skipping_uses_cached_stats_when_nan_free() -> VortexResult<()> { + // With an exact NaNCount of zero, the planted exact Min/Max stats are usable as-is. + let array = + PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0], Validity::NonNullable).into_array(); + array + .statistics() + .set(Stat::NaNCount, Precision::Exact(ScalarValue::from(0u64))); + array + .statistics() + .set(Stat::Min, Precision::Exact(ScalarValue::from(-10.0f64))); + array + .statistics() + .set(Stat::Max, Precision::Exact(ScalarValue::from(10.0f64))); + let mut ctx = SESSION.create_execution_ctx(); + let result = min_max(&array, &mut ctx, KEEP_NANS)?.vortex_expect("should have result"); + assert_eq!(f64::try_from(&result.min)?, -10.0); + assert_eq!(f64::try_from(&result.max)?, 10.0); + Ok(()) + } + + #[test] + fn test_accumulator_nan_including_nullable_cached_stats() -> VortexResult<()> { + // A nullable float array's cached Min/Max stats are reconstructed as nullable scalars. + // The NaN-including accumulator shortcircuit must normalise them to the non-nullable + // struct field dtype before building the result scalar. 
+ let mut ctx = SESSION.create_execution_ctx(); + let array = + PrimitiveArray::from_option_iter([Some(1.0f64), Some(2.0), Some(3.0)]).into_array(); + array + .statistics() + .set(Stat::NaNCount, Precision::Exact(ScalarValue::from(0u64))); + array + .statistics() + .set(Stat::Min, Precision::Exact(ScalarValue::from(1.0f64))); + array + .statistics() + .set(Stat::Max, Precision::Exact(ScalarValue::from(3.0f64))); + + let mut acc = Accumulator::try_new(MinMax, KEEP_NANS, array.dtype().clone())?; + acc.accumulate(&array, &mut ctx)?; + let result = MinMaxResult::from_scalar(acc.finish()?)?.vortex_expect("should have result"); + assert_eq!(f64::try_from(&result.min)?, 1.0); + assert_eq!(f64::try_from(&result.max)?, 3.0); Ok(()) } + #[test] + fn test_multi_batch_nan_poisoning() -> VortexResult<()> { + let mut ctx = SESSION.create_execution_ctx(); + let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); + let mut acc = Accumulator::try_new(MinMax, KEEP_NANS, dtype)?; + + let batch1 = PrimitiveArray::new(buffer![1.0f64, 2.0], Validity::NonNullable).into_array(); + acc.accumulate(&batch1, &mut ctx)?; + assert!(!acc.is_saturated()); + + let batch2 = PrimitiveArray::new(buffer![f64::NAN], Validity::NonNullable).into_array(); + acc.accumulate(&batch2, &mut ctx)?; + assert!(acc.is_saturated()); + + assert_poisoned(MinMaxResult::from_scalar(acc.finish()?)?) + } + #[test] fn test_chunked() -> VortexResult<()> { let chunk1 = PrimitiveArray::from_option_iter([Some(5i32), None, Some(1)]); let chunk2 = PrimitiveArray::from_option_iter([Some(10i32), Some(3), None]); let dtype = chunk1.dtype().clone(); let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?; - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - let result = min_max(&chunked.into_array(), &mut ctx)?.vortex_expect("should have result"); + let mut ctx = SESSION.create_execution_ctx(); + let result = min_max(&chunked.into_array(), &mut ctx, SkipNansOptions::default())? 
+ .vortex_expect("should have result"); assert_eq!(result.min, Scalar::from(1i32)); assert_eq!(result.max, Scalar::from(10i32)); Ok(()) @@ -541,8 +792,11 @@ mod tests { #[test] fn test_all_null() -> VortexResult<()> { let p = PrimitiveArray::from_option_iter::([None, None, None]); - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - assert_eq!(min_max(&p.into_array(), &mut ctx)?, None); + let mut ctx = SESSION.create_execution_ctx(); + assert_eq!( + min_max(&p.into_array(), &mut ctx, SkipNansOptions::default())?, + None + ); Ok(()) } @@ -557,8 +811,9 @@ mod tests { ], DType::Utf8(Nullability::Nullable), ); - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - let result = min_max(&array.into_array(), &mut ctx)?.vortex_expect("should have result"); + let mut ctx = SESSION.create_execution_ctx(); + let result = min_max(&array.into_array(), &mut ctx, SkipNansOptions::default())? + .vortex_expect("should have result"); assert_eq!( result.min, Scalar::utf8("hello world", Nullability::NonNullable) @@ -580,8 +835,9 @@ mod tests { DecimalDType::new(4, 2), Validity::from_iter([true, false, true]), ); - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - let result = min_max(&decimal.into_array(), &mut ctx)?.vortex_expect("should have result"); + let mut ctx = SESSION.create_execution_ctx(); + let result = min_max(&decimal.into_array(), &mut ctx, SkipNansOptions::default())? 
+ .vortex_expect("should have result"); let non_nullable_dtype = DType::Decimal(DecimalDType::new(4, 2), Nullability::NonNullable); let expected_min = Scalar::try_new( @@ -605,18 +861,18 @@ mod tests { DType::FixedSizeList(Arc::new(element_dtype), 1, Nullability::Nullable); assert_eq!( - MinMax.return_dtype(&EmptyOptions, &list_dtype), + MinMax.return_dtype(&SkipNansOptions::default(), &list_dtype), Some(make_minmax_dtype(&list_dtype)) ); assert_eq!( - MinMax.return_dtype(&EmptyOptions, &fixed_size_list_dtype), + MinMax.return_dtype(&SkipNansOptions::default(), &fixed_size_list_dtype), Some(make_minmax_dtype(&fixed_size_list_dtype)) ); } #[test] fn list_and_fixed_size_list_min_max_returns_none() -> VortexResult<()> { - let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mut ctx = SESSION.create_execution_ctx(); let list_array = ListArray::try_new( buffer![1i32, 2, 3].into_array(), @@ -624,7 +880,10 @@ mod tests { Validity::NonNullable, )? .into_array(); - assert_eq!(min_max(&list_array, &mut ctx)?, None); + assert_eq!( + min_max(&list_array, &mut ctx, SkipNansOptions::default())?, + None + ); let fixed_size_list_array = FixedSizeListArray::try_new( buffer![1i32, 2, 3, 4].into_array(), @@ -633,7 +892,10 @@ mod tests { 2, )? 
.into_array(); - assert_eq!(min_max(&fixed_size_list_array, &mut ctx)?, None); + assert_eq!( + min_max(&fixed_size_list_array, &mut ctx, SkipNansOptions::default())?, + None + ); Ok(()) } @@ -642,11 +904,12 @@ mod tests { #[test] fn test_bool_with_nulls() -> VortexResult<()> { - let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mut ctx = SESSION.create_execution_ctx(); let result = min_max( &BoolArray::from_iter(vec![Some(true), Some(true), None, None]).into_array(), &mut ctx, + SkipNansOptions::default(), )?; assert_eq!( result, @@ -659,6 +922,7 @@ mod tests { let result = min_max( &BoolArray::from_iter(vec![None, Some(true), Some(true)]).into_array(), &mut ctx, + SkipNansOptions::default(), )?; assert_eq!( result, @@ -671,6 +935,7 @@ mod tests { let result = min_max( &BoolArray::from_iter(vec![None, Some(true), Some(true), None]).into_array(), &mut ctx, + SkipNansOptions::default(), )?; assert_eq!( result, @@ -683,6 +948,7 @@ mod tests { let result = min_max( &BoolArray::from_iter(vec![Some(false), Some(false), None, None]).into_array(), &mut ctx, + SkipNansOptions::default(), )?; assert_eq!( result, @@ -701,7 +967,7 @@ mod tests { /// partial state. #[test] fn test_bool_chunked_with_empty_chunk() -> VortexResult<()> { - let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mut ctx = SESSION.create_execution_ctx(); let empty = BoolArray::new(BitBuffer::from([].as_slice()), Validity::NonNullable); let chunk1 = BoolArray::new( @@ -717,7 +983,7 @@ mod tests { DType::Bool(Nullability::NonNullable), )?; - let result = min_max(&chunked.into_array(), &mut ctx)?; + let result = min_max(&chunked.into_array(), &mut ctx, SkipNansOptions::default())?; assert_eq!( result, Some(MinMaxResult { @@ -736,7 +1002,7 @@ mod tests { /// running min/max. Empty chunks are now skipped during chunked aggregation. 
#[test] fn test_chunked_with_empty_constant_chunk() -> VortexResult<()> { - let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mut ctx = SESSION.create_execution_ctx(); let empty = ConstantArray::new(Scalar::primitive(u32::MAX, Nullability::NonNullable), 0) .into_array(); @@ -748,7 +1014,7 @@ mod tests { )?; assert_eq!( - min_max(&chunked.into_array(), &mut ctx)?, + min_max(&chunked.into_array(), &mut ctx, SkipNansOptions::default())?, Some(MinMaxResult { min: Scalar::primitive(0u32, Nullability::NonNullable), max: Scalar::primitive(7631471u32, Nullability::NonNullable), @@ -763,8 +1029,11 @@ mod tests { vec![Option::<&str>::None, None, None], DType::Utf8(Nullability::Nullable), ); - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - assert_eq!(min_max(&array.into_array(), &mut ctx)?, None); + let mut ctx = SESSION.create_execution_ctx(); + assert_eq!( + min_max(&array.into_array(), &mut ctx, SkipNansOptions::default())?, + None + ); Ok(()) } } diff --git a/vortex-array/src/aggregate_fn/fns/min_max/primitive.rs b/vortex-array/src/aggregate_fn/fns/min_max/primitive.rs index 06b9749678d..7ca97712477 100644 --- a/vortex-array/src/aggregate_fn/fns/min_max/primitive.rs +++ b/vortex-array/src/aggregate_fn/fns/min_max/primitive.rs @@ -20,8 +20,9 @@ pub(super) fn accumulate_primitive( p: &PrimitiveArray, ctx: &mut ExecutionCtx, ) -> VortexResult<()> { + let skip_nans = partial.skip_nans; match_each_native_ptype!(p.ptype(), |T| { - let local = compute_min_max_with_validity::(p, ctx)?; + let local = compute_min_max_with_validity::(p, ctx, skip_nans)?; partial.merge(local); Ok(()) }) @@ -30,6 +31,7 @@ pub(super) fn accumulate_primitive( fn compute_min_max_with_validity( array: &PrimitiveArray, ctx: &mut ExecutionCtx, + skip_nans: bool, ) -> VortexResult> where T: NativePType, @@ -48,7 +50,7 @@ where if T::PTYPE.is_int() { integer_min_max_raw(slice).map(min_max_result) } else { - compute_min_max(slice.iter()) + compute_min_max(slice.iter(), skip_nans) } } 
Mask::AllFalse(_) => None, @@ -73,6 +75,7 @@ where v.slices() .iter() .flat_map(|&(start, end)| slice[start..end].iter()), + skip_nans, ) } } @@ -110,15 +113,29 @@ where } } -fn compute_min_max<'a, T>(iter: impl Iterator) -> Option +fn compute_min_max<'a, T>( + iter: impl Iterator, + skip_nans: bool, +) -> Option where T: NativePType, PValue: From, { - match iter - .filter(|v| !v.is_nan()) - .minmax_by(|a, b| a.total_compare(**b)) - { + if skip_nans { + minmax_by_total_order(iter.filter(|v| !v.is_nan())) + } else { + // Compute extrema under the total order (where NaNs sort to the ends) and let the + // partial's merge poison the result if either end is NaN. + minmax_by_total_order(iter) + } +} + +fn minmax_by_total_order<'a, T>(iter: impl Iterator) -> Option +where + T: NativePType, + PValue: From, +{ + match iter.minmax_by(|a, b| a.total_compare(**b)) { itertools::MinMaxResult::NoElements => None, itertools::MinMaxResult::OneElement(&x) => { let scalar = Scalar::primitive(x, NonNullable); diff --git a/vortex-array/src/aggregate_fn/fns/sum/bool.rs b/vortex-array/src/aggregate_fn/fns/sum/bool.rs index 7728d3f64af..787693cfc38 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/bool.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/bool.rs @@ -41,7 +41,7 @@ mod tests { use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; - use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::sum::Sum; use crate::aggregate_fn::fns::sum::sum; use crate::arrays::BoolArray; @@ -108,7 +108,7 @@ mod tests { #[test] fn sum_bool_empty_produces_zero() -> VortexResult<()> { let dtype = DType::Bool(Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?; + let mut acc = Accumulator::try_new(Sum, SkipNansOptions::default(), dtype)?; let result = acc.finish()?; assert_eq!(result.as_primitive().typed_value::(), Some(0)); Ok(()) @@ -118,7 
+118,7 @@ mod tests { fn sum_bool_finish_resets_state() -> VortexResult<()> { let mut ctx = LEGACY_SESSION.create_execution_ctx(); let dtype = DType::Bool(Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?; + let mut acc = Accumulator::try_new(Sum, SkipNansOptions::default(), dtype)?; let batch1: BoolArray = [true, true, false].into_iter().collect(); acc.accumulate(&batch1.into_array(), &mut ctx)?; @@ -135,7 +135,10 @@ mod tests { #[test] fn sum_bool_return_dtype() -> VortexResult<()> { let dtype = Sum - .return_dtype(&EmptyOptions, &DType::Bool(Nullability::NonNullable)) + .return_dtype( + &SkipNansOptions::default(), + &DType::Bool(Nullability::NonNullable), + ) .unwrap(); assert_eq!(dtype, DType::Primitive(PType::U64, Nullability::Nullable)); Ok(()) diff --git a/vortex-array/src/aggregate_fn/fns/sum/decimal.rs b/vortex-array/src/aggregate_fn/fns/sum/decimal.rs index e37b671eb46..73a2beb850f 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/decimal.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/decimal.rs @@ -115,7 +115,7 @@ mod tests { use crate::LEGACY_SESSION; use crate::VortexSessionExecute; use crate::aggregate_fn::AggregateFnVTable; - use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::sum::Sum; use crate::aggregate_fn::fns::sum::sum; use crate::arrays::DecimalArray; @@ -357,7 +357,7 @@ mod tests { // Native type for precision 14 is I64 (max precision 18), so 14 < 18. // Use combine_partials to push state near (but under) 10^14. let input_dtype = DType::Decimal(DecimalDType::new(4, 0), Nullability::NonNullable); - let mut state = Sum.empty_partial(&EmptyOptions, &input_dtype)?; + let mut state = Sum.empty_partial(&SkipNansOptions::default(), &input_dtype)?; let near_limit = Scalar::decimal( DecimalValue::from(99_999_999_999_990i64), @@ -387,7 +387,7 @@ mod tests { // i256 arithmetic does not overflow. 
This tests the precision-based // saturation path in combine_partials. let input_dtype = DType::Decimal(DecimalDType::new(4, 0), Nullability::NonNullable); - let mut state = Sum.empty_partial(&EmptyOptions, &input_dtype)?; + let mut state = Sum.empty_partial(&SkipNansOptions::default(), &input_dtype)?; let near_limit = Scalar::decimal( DecimalValue::from(99_999_999_999_999i64), @@ -414,7 +414,7 @@ mod tests { fn sum_decimal_precision_overflow_negative() -> VortexResult<()> { // Same setup but with negative values: sum reaches -10^14. let input_dtype = DType::Decimal(DecimalDType::new(4, 0), Nullability::NonNullable); - let mut state = Sum.empty_partial(&EmptyOptions, &input_dtype)?; + let mut state = Sum.empty_partial(&SkipNansOptions::default(), &input_dtype)?; let near_limit = Scalar::decimal( DecimalValue::from(-99_999_999_999_999i64), @@ -446,7 +446,7 @@ mod tests { // a real array that pushes it over. let input_dtype = DType::Decimal(DecimalDType::new(27, 0), Nullability::NonNullable); let return_dtype = DecimalDType::new(37, 0); - let mut state = Sum.empty_partial(&EmptyOptions, &input_dtype)?; + let mut state = Sum.empty_partial(&SkipNansOptions::default(), &input_dtype)?; // Set state to 10^37 - 1 via combine_partials. 
let near_limit_val: i128 = 10i128.pow(37) - 1; diff --git a/vortex-array/src/aggregate_fn/fns/sum/grouped.rs b/vortex-array/src/aggregate_fn/fns/sum/grouped.rs index 6f00cce7fdb..7c7b8aaa730 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/grouped.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/grouped.rs @@ -32,10 +32,10 @@ impl DynGroupedAggregateKernel for PrimitiveGroupedSumEncodingKernel { groups: &GroupedArray, ctx: &mut ExecutionCtx, ) -> VortexResult> { - if !aggregate_fn.is::() { + let Some(options) = aggregate_fn.as_opt::() else { return Ok(None); - } - try_grouped_sum(groups, ctx) + }; + try_grouped_sum(groups, ctx, options.skip_nans) } } @@ -48,6 +48,7 @@ impl DynGroupedAggregateKernel for PrimitiveGroupedSumEncodingKernel { pub(super) fn try_grouped_sum( groups: &GroupedArray, ctx: &mut ExecutionCtx, + skip_nans: bool, ) -> VortexResult> { if !groups.elements().is::() { return Ok(None); @@ -61,6 +62,7 @@ pub(super) fn try_grouped_sum( &group_ranges, &group_validity, ctx, + skip_nans, )?)) } @@ -70,6 +72,7 @@ fn grouped_sum( group_ranges: &GroupRanges, group_validity: &Mask, ctx: &mut ExecutionCtx, + skip_nans: bool, ) -> VortexResult { let elem_mask = elements .as_ref() @@ -91,7 +94,7 @@ fn grouped_sum( floating: |T| { let values = elements.as_slice::(); collect_sums::(values, group_ranges, group_validity, &elem_mask, all_valid, - |acc, slice| { sum_float_all(acc, slice); false }) + |acc, slice| { sum_float_all(acc, slice, skip_nans); false }) } ); @@ -159,8 +162,8 @@ mod tests { use crate::LEGACY_SESSION; use crate::VortexSessionExecute; use crate::aggregate_fn::DynGroupedAccumulator; - use crate::aggregate_fn::EmptyOptions; use crate::aggregate_fn::GroupedAccumulator; + use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::sum::Sum; use crate::aggregate_fn::fns::sum::sum; use crate::arrays::FixedSizeListArray; @@ -176,7 +179,8 @@ mod tests { /// Run a grouped sum through the accumulator. 
fn grouped_sum_actual(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult { - let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype.clone())?; + let mut acc = + GroupedAccumulator::try_new(Sum, SkipNansOptions::default(), elem_dtype.clone())?; acc.accumulate_list(groups, &mut LEGACY_SESSION.create_execution_ctx())?; acc.finish() } @@ -193,7 +197,7 @@ mod tests { let mut ctx = LEGACY_SESSION.create_execution_ctx(); let sum_dtype = Sum - .partial_dtype(&EmptyOptions, elem_dtype) + .partial_dtype(&SkipNansOptions::default(), elem_dtype) .expect("sum partial dtype"); let mut builder = builder_with_capacity(&sum_dtype, ranges.len()); for (i, &(offset, size)) in ranges.iter().enumerate() { @@ -347,6 +351,29 @@ mod tests { Ok(()) } + #[test] + fn listview_float_nan_not_skipping() -> VortexResult<()> { + let elements = PrimitiveArray::new( + buffer![1.0f64, f64::NAN, 2.0, 3.0, 4.0], + Validity::NonNullable, + ) + .into_array(); + let elem_dtype = DType::Primitive(PType::F64, NonNullable); + let groups = listview(elements, &[(0, 3), (3, 2)], &[true, true])?; + + let mut acc = GroupedAccumulator::try_new(Sum, SkipNansOptions::include(), elem_dtype)?; + acc.accumulate_list(&groups, &mut LEGACY_SESSION.create_execution_ctx())?; + let actual = acc.finish()?; + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + // Group 0 contains a NaN -> NaN sum; group 1 sums normally. + let g0 = actual.execute_scalar(0, &mut ctx)?; + assert!(g0.as_primitive().typed_value::().unwrap().is_nan()); + let g1 = actual.execute_scalar(1, &mut ctx)?; + assert_eq!(g1.as_primitive().typed_value::(), Some(7.0)); + Ok(()) + } + #[test] fn fixed_size_overflow_and_nan() -> VortexResult<()> { // FixedSize path: first group overflows -> null sum, second sums normally. 
diff --git a/vortex-array/src/aggregate_fn/fns/sum/mod.rs b/vortex-array/src/aggregate_fn/fns/sum/mod.rs index f48397798d9..11759027549 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/mod.rs @@ -25,7 +25,7 @@ use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; -use crate::aggregate_fn::EmptyOptions; +use crate::aggregate_fn::SkipNansOptions; use crate::dtype::DType; use crate::dtype::DecimalDType; use crate::dtype::MAX_PRECISION; @@ -34,6 +34,7 @@ use crate::dtype::PType; use crate::expr::stats::Precision; use crate::expr::stats::Stat; use crate::expr::stats::StatsProvider; +use crate::expr::stats::StatsProviderExt; use crate::scalar::DecimalValue; use crate::scalar::Scalar; @@ -48,7 +49,7 @@ pub fn sum(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { // Compute using Accumulator. // TODO(ngates): we may want to wrap this three-step dance up into an extension crate maybe. - let mut acc = Accumulator::try_new(Sum, EmptyOptions, array.dtype().clone())?; + let mut acc = Accumulator::try_new(Sum, SkipNansOptions::default(), array.dtype().clone())?; acc.accumulate(array, ctx)?; let result = acc.finish()?; @@ -64,11 +65,14 @@ pub fn sum(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { /// /// If the sum overflows, a null scalar will be returned. /// If the array is all-invalid, the sum will be zero. +/// +/// NaN handling for float inputs is controlled by [`SkipNansOptions`]: with `skip_nans` (the +/// default) NaN values contribute nothing, otherwise any NaN value poisons the sum to NaN. 
#[derive(Clone, Debug)] pub struct Sum; impl AggregateFnVTable for Sum { - type Options = EmptyOptions; + type Options = SkipNansOptions; type Partial = SumPartial; fn id(&self) -> AggregateFnId { @@ -130,6 +134,7 @@ impl AggregateFnVTable for Sum { Ok(SumPartial { return_dtype, current: Some(initial), + skip_nans: options.skip_nans, }) } @@ -214,6 +219,43 @@ impl AggregateFnVTable for Sum { } } + fn try_accumulate( + &self, + partial: &mut Self::Partial, + batch: &ArrayRef, + _ctx: &mut ExecutionCtx, + ) -> VortexResult { + // NaN-aware shortcircuits only apply to NaN-including float sums; everything else takes + // the default dispatch path. + if partial.skip_nans || !matches!(partial.current, Some(SumState::Float(_))) { + return Ok(false); + } + match batch.statistics().get_as::(Stat::NaNCount) { + Precision::Exact(0) => { + // NaN-free batch: the cached NaN-skipping sum (if any) equals the + // NaN-including sum. + if let Precision::Exact(sum) = batch.statistics().get(Stat::Sum) { + let sum = if sum.dtype() == &partial.return_dtype { + sum + } else { + sum.cast(&partial.return_dtype)? + }; + self.combine_partials(partial, sum)?; + return Ok(true); + } + Ok(false) + } + Precision::Exact(_) => { + // At least one NaN value: the sum is NaN without scanning the batch. + if let Some(SumState::Float(acc)) = partial.current.as_mut() { + *acc = f64::NAN; + } + Ok(true) + } + _ => Ok(false), + } + } + fn accumulate( &self, partial: &mut Self::Partial, @@ -222,12 +264,17 @@ impl AggregateFnVTable for Sum { ) -> VortexResult<()> { // Constants compute scalar * len and combine via combine_partials. if let Columnar::Constant(c) = batch { + // NaN constants are treated as missing when skipping NaNs. + if partial.skip_nans && c.scalar().as_primitive_opt().is_some_and(|p| p.is_nan()) { + return Ok(()); + } if let Some(product) = multiply_constant(c.scalar(), c.len(), &partial.return_dtype)? 
{ self.combine_partials(partial, product)?; } return Ok(()); } + let skip_nans = partial.skip_nans; let mut inner = match partial.current.take() { Some(inner) => inner, None => return Ok(()), @@ -235,7 +282,7 @@ impl AggregateFnVTable for Sum { let result = match batch { Columnar::Canonical(c) => match c { - Canonical::Primitive(p) => accumulate_primitive(&mut inner, p, ctx), + Canonical::Primitive(p) => accumulate_primitive(&mut inner, p, ctx, skip_nans), Canonical::Bool(b) => accumulate_bool(&mut inner, b, ctx), Canonical::Decimal(d) => accumulate_decimal(&mut inner, d, ctx), _ => vortex_bail!("Unsupported canonical type for sum: {}", batch.dtype()), @@ -269,6 +316,8 @@ pub struct SumPartial { return_dtype: DType, /// The current accumulated state, or `None` if saturated (checked overflow). current: Option, + /// Whether NaN values in float inputs are skipped. + skip_nans: bool, } /// The accumulated sum value. @@ -338,8 +387,8 @@ mod tests { use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; use crate::aggregate_fn::DynGroupedAccumulator; - use crate::aggregate_fn::EmptyOptions; use crate::aggregate_fn::GroupedAccumulator; + use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::sum::Sum; use crate::aggregate_fn::fns::sum::sum; use crate::arrays::BoolArray; @@ -424,7 +473,7 @@ mod tests { fn sum_multi_batch() -> VortexResult<()> { let mut ctx = LEGACY_SESSION.create_execution_ctx(); let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?; + let mut acc = Accumulator::try_new(Sum, SkipNansOptions::default(), dtype)?; let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); acc.accumulate(&batch1, &mut ctx)?; @@ -441,7 +490,7 @@ mod tests { fn sum_finish_resets_state() -> VortexResult<()> { let mut ctx = LEGACY_SESSION.create_execution_ctx(); let dtype = DType::Primitive(PType::I32, 
Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?; + let mut acc = Accumulator::try_new(Sum, SkipNansOptions::default(), dtype)?; let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); acc.accumulate(&batch1, &mut ctx)?; @@ -460,7 +509,7 @@ mod tests { #[test] fn sum_state_merge() -> VortexResult<()> { let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut state = Sum.empty_partial(&EmptyOptions, &dtype)?; + let mut state = Sum.empty_partial(&SkipNansOptions::default(), &dtype)?; let scalar1 = Scalar::primitive(100i64, Nullable); Sum.combine_partials(&mut state, scalar1)?; @@ -513,7 +562,8 @@ mod tests { // Grouped sum tests fn run_grouped_sum(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult { - let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype.clone())?; + let mut acc = + GroupedAccumulator::try_new(Sum, SkipNansOptions::default(), elem_dtype.clone())?; acc.accumulate_list(groups, &mut LEGACY_SESSION.create_execution_ctx())?; acc.finish() } @@ -596,7 +646,7 @@ mod tests { fn grouped_sum_finish_resets() -> VortexResult<()> { let mut ctx = LEGACY_SESSION.create_execution_ctx(); let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype)?; + let mut acc = GroupedAccumulator::try_new(Sum, SkipNansOptions::default(), elem_dtype)?; let elements1 = PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array(); diff --git a/vortex-array/src/aggregate_fn/fns/sum/primitive.rs b/vortex-array/src/aggregate_fn/fns/sum/primitive.rs index df7d929d896..da632557d34 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/primitive.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/primitive.rs @@ -25,16 +25,21 @@ pub(super) fn accumulate_primitive( inner: &mut SumState, p: &PrimitiveArray, ctx: &mut ExecutionCtx, + skip_nans: bool, ) -> VortexResult { let mask = 
p.as_ref().validity()?.execute_mask(p.as_ref().len(), ctx)?; match mask.slices() { AllOr::None => Ok(false), - AllOr::All => accumulate_primitive_all(inner, p), - AllOr::Some(slices) => accumulate_primitive_valid(inner, p, slices), + AllOr::All => accumulate_primitive_all(inner, p, skip_nans), + AllOr::Some(slices) => accumulate_primitive_valid(inner, p, slices, skip_nans), } } -fn accumulate_primitive_all(inner: &mut SumState, p: &PrimitiveArray) -> VortexResult { +fn accumulate_primitive_all( + inner: &mut SumState, + p: &PrimitiveArray, + skip_nans: bool, +) -> VortexResult { match inner { SumState::Unsigned(acc) => match_each_native_ptype!(p.ptype(), unsigned: |T| { Ok(sum_unsigned_all(acc, p.as_slice::())) }, @@ -50,7 +55,7 @@ fn accumulate_primitive_all(inner: &mut SumState, p: &PrimitiveArray) -> VortexR unsigned: |_T| { vortex_panic!("float sum state with unsigned input") }, signed: |_T| { vortex_panic!("float sum state with signed input") }, floating: |T| { - sum_float_all(acc, p.as_slice::()); + sum_float_all(acc, p.as_slice::(), skip_nans); Ok(false) } ), @@ -58,11 +63,18 @@ fn accumulate_primitive_all(inner: &mut SumState, p: &PrimitiveArray) -> VortexR } } -/// Sum the non-NaN values of a float slice into an `f64` accumulator. NaNs are skipped to match the -/// scalar `sum` semantics. Floats cannot overflow the accumulator, so this never reports saturation. -pub(super) fn sum_float_all(acc: &mut f64, slice: &[T]) { - for &v in slice { - if !v.is_nan() { +/// Sum the values of a float slice into an `f64` accumulator. When `skip_nans` is set, NaN values +/// are skipped to match the scalar `sum` semantics; otherwise any NaN poisons the accumulator to +/// NaN. Floats cannot overflow the accumulator, so this never reports saturation. 
+pub(super) fn sum_float_all(acc: &mut f64, slice: &[T], skip_nans: bool) { + if skip_nans { + for &v in slice { + if !v.is_nan() { + *acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64"); + } + } + } else { + for &v in slice { *acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64"); } } @@ -122,6 +134,7 @@ fn accumulate_primitive_valid( inner: &mut SumState, p: &PrimitiveArray, slices: &[(usize, usize)], + skip_nans: bool, ) -> VortexResult { match inner { SumState::Unsigned(acc) => match_each_native_ptype!(p.ptype(), @@ -156,7 +169,7 @@ fn accumulate_primitive_valid( floating: |T| { let values = p.as_slice::(); for &(start, end) in slices { - sum_float_all(acc, &values[start..end]); + sum_float_all(acc, &values[start..end], skip_nans); } Ok(false) } @@ -175,15 +188,19 @@ mod tests { use crate::VortexSessionExecute; use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::DynAccumulator; - use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::sum::Sum; use crate::aggregate_fn::fns::sum::sum; + use crate::arrays::ConstantArray; use crate::arrays::PrimitiveArray; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::Nullability::Nullable; use crate::dtype::PType; + use crate::expr::stats::Precision; + use crate::expr::stats::Stat; use crate::scalar::Scalar; + use crate::scalar::ScalarValue; use crate::validity::Validity; #[test] @@ -274,7 +291,7 @@ mod tests { #[test] fn sum_empty_produces_zero() -> VortexResult<()> { let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?; + let mut acc = Accumulator::try_new(Sum, SkipNansOptions::default(), dtype)?; let result = acc.finish()?; assert_eq!(result.as_primitive().typed_value::(), Some(0)); Ok(()) @@ -283,7 +300,7 @@ mod tests { #[test] fn sum_empty_f64_produces_zero() -> VortexResult<()> { let dtype = DType::Primitive(PType::F64, 
Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?; + let mut acc = Accumulator::try_new(Sum, SkipNansOptions::default(), dtype)?; let result = acc.finish()?; assert_eq!(result.as_primitive().typed_value::(), Some(0.0)); Ok(()) @@ -328,6 +345,70 @@ mod tests { Ok(()) } + /// Sum an array with explicit [`SkipNansOptions`] (test-only helper). + fn sum_with_options(arr: &crate::ArrayRef, options: SkipNansOptions) -> VortexResult { + let mut acc = Accumulator::try_new(Sum, options, arr.dtype().clone())?; + acc.accumulate(arr, &mut LEGACY_SESSION.create_execution_ctx())?; + acc.finish() + } + + #[test] + fn sum_f64_with_nan_not_skipping() -> VortexResult<()> { + let arr = + PrimitiveArray::new(buffer![1.0f64, f64::NAN, 2.0], Validity::NonNullable).into_array(); + let result = sum_with_options(&arr, SkipNansOptions::include())?; + assert!(result.as_primitive().typed_value::().unwrap().is_nan()); + Ok(()) + } + + #[test] + fn sum_f64_without_nan_not_skipping() -> VortexResult<()> { + let arr = + PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0], Validity::NonNullable).into_array(); + let result = sum_with_options(&arr, SkipNansOptions::include())?; + assert_eq!(result.as_primitive().typed_value::(), Some(6.0)); + Ok(()) + } + + #[test] + fn sum_not_skipping_shortcircuits_on_exact_nan_count_stat() -> VortexResult<()> { + // The array has no NaNs; a planted exact NaNCount stat proves the NaN poisoning came + // from the stat rather than a scan. + let arr = + PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0], Validity::NonNullable).into_array(); + arr.statistics() + .set(Stat::NaNCount, Precision::Exact(ScalarValue::from(1u64))); + let result = sum_with_options(&arr, SkipNansOptions::include())?; + assert!(result.as_primitive().typed_value::().unwrap().is_nan()); + Ok(()) + } + + #[test] + fn sum_not_skipping_uses_cached_sum_when_nan_free() -> VortexResult<()> { + // With an exact NaNCount of zero, the planted exact Sum stat is usable as-is. 
+ let arr = + PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0], Validity::NonNullable).into_array(); + arr.statistics() + .set(Stat::NaNCount, Precision::Exact(ScalarValue::from(0u64))); + arr.statistics() + .set(Stat::Sum, Precision::Exact(ScalarValue::from(42.0f64))); + let result = sum_with_options(&arr, SkipNansOptions::include())?; + assert_eq!(result.as_primitive().typed_value::(), Some(42.0)); + Ok(()) + } + + #[test] + fn sum_constant_nan() -> VortexResult<()> { + let arr = ConstantArray::new(f64::NAN, 4).into_array(); + // NaN constants are skipped by default and poison the sum otherwise. + let result = sum_with_options(&arr, SkipNansOptions::default())?; + assert_eq!(result.as_primitive().typed_value::(), Some(0.0)); + + let result = sum_with_options(&arr, SkipNansOptions::include())?; + assert!(result.as_primitive().typed_value::().unwrap().is_nan()); + Ok(()) + } + #[test] fn sum_f64_with_infinity() -> VortexResult<()> { let batch = PrimitiveArray::new( @@ -341,7 +422,7 @@ mod tests { let mut acc = Accumulator::try_new( Sum, - EmptyOptions, + SkipNansOptions::default(), DType::Primitive(PType::F64, Nullability::NonNullable), )?; acc.accumulate(&batch, &mut LEGACY_SESSION.create_execution_ctx())?; @@ -360,7 +441,7 @@ mod tests { #[test] fn sum_checked_overflow_is_saturated() -> VortexResult<()> { let dtype = DType::Primitive(PType::I64, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Sum, EmptyOptions, dtype)?; + let mut acc = Accumulator::try_new(Sum, SkipNansOptions::default(), dtype)?; assert!(!acc.is_saturated()); let batch = diff --git a/vortex-array/src/aggregate_fn/vtable.rs b/vortex-array/src/aggregate_fn/vtable.rs index 28b91d45166..dc7fb773193 100644 --- a/vortex-array/src/aggregate_fn/vtable.rs +++ b/vortex-array/src/aggregate_fn/vtable.rs @@ -167,6 +167,57 @@ impl Display for EmptyOptions { } } +/// Options for aggregate functions over primitive numeric inputs, controlling how NaN values in +/// floating-point arrays are handled. 
+/// +/// When `skip_nans` is `true` (the default), NaN values are treated as missing: they contribute +/// nothing to `sum`/`min`/`max`/`mean` and are excluded from `count`. +/// +/// When `skip_nans` is `false`, NaN values participate in the aggregate: `count` includes them, +/// while any NaN value poisons the result of `sum`/`min`/`max`/`mean` to NaN. +/// +/// The option has no effect on non-float inputs. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct SkipNansOptions { + /// Whether NaN values are skipped (treated as missing) during aggregation. + pub skip_nans: bool, +} + +impl SkipNansOptions { + /// Options that skip NaN values, treating them as missing during aggregation. + /// + /// This is the default configuration; see [`SkipNansOptions::include`] for the NaN-including + /// variant. + pub const fn skip() -> Self { + Self { skip_nans: true } + } + + /// Options that include NaN values in the aggregate: `count` counts them, while any NaN + /// poisons the result of `sum`/`min`/`max`/`mean` to NaN. + /// + /// See [`SkipNansOptions::skip`] for the default NaN-skipping variant. + pub const fn include() -> Self { + Self { skip_nans: false } + } +} + +impl Default for SkipNansOptions { + fn default() -> Self { + Self::skip() + } +} + +impl Display for SkipNansOptions { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + // Only the non-default configuration is displayed, so that aggregates with default + // options render identically to their pre-options form, e.g. `vortex.sum()`. + if !self.skip_nans { + write!(f, "skip_nans=false")?; + } + Ok(()) + } +} + /// Factory functions for aggregate vtables. pub trait AggregateFnVTableExt: AggregateFnVTable { /// Bind this vtable with the given options into an [`AggregateFnRef`]. 
diff --git a/vortex-array/src/arrays/chunked/compute/aggregate.rs b/vortex-array/src/arrays/chunked/compute/aggregate.rs index 52c3db5cf2f..eef619d6d45 100644 --- a/vortex-array/src/arrays/chunked/compute/aggregate.rs +++ b/vortex-array/src/arrays/chunked/compute/aggregate.rs @@ -47,7 +47,7 @@ mod tests { use crate::VortexSessionExecute; use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::DynAccumulator; - use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::sum::Sum; use crate::arrays::BoolArray; use crate::arrays::ChunkedArray; @@ -59,7 +59,7 @@ mod tests { fn run_sum(batch: &crate::ArrayRef) -> VortexResult { let mut ctx = LEGACY_SESSION.create_execution_ctx(); - let mut acc = Accumulator::try_new(Sum, EmptyOptions, batch.dtype().clone())?; + let mut acc = Accumulator::try_new(Sum, SkipNansOptions::default(), batch.dtype().clone())?; acc.accumulate(batch, &mut ctx)?; acc.finish() } diff --git a/vortex-array/src/arrays/dict/compute/min_max.rs b/vortex-array/src/arrays/dict/compute/min_max.rs index 8910ab12a53..417d724dfa7 100644 --- a/vortex-array/src/arrays/dict/compute/min_max.rs +++ b/vortex-array/src/arrays/dict/compute/min_max.rs @@ -30,9 +30,9 @@ impl DynAggregateKernel for DictMinMaxKernel { batch: &ArrayRef, ctx: &mut ExecutionCtx, ) -> VortexResult> { - if !aggregate_fn.is::() { + let Some(options) = aggregate_fn.as_opt::() else { return Ok(None); - } + }; let Some(dict) = batch.as_opt::() else { return Ok(None); @@ -42,13 +42,13 @@ impl DynAggregateKernel for DictMinMaxKernel { let result = if dict.has_all_values_referenced() { // All values are referenced, compute min/max directly on the values array. - min_max(dict.values(), ctx)? + min_max(dict.values(), ctx, *options)? } else { // Filter to only referenced values, then compute min/max. 
let referenced_mask = dict.compute_referenced_values_mask(true)?; let mask = Mask::from(referenced_mask); let filtered_values = dict.values().filter(mask)?; - min_max(&filtered_values, ctx)? + min_max(&filtered_values, ctx, *options)? }; match result { @@ -70,6 +70,7 @@ mod tests { use crate::ArrayRef; use crate::IntoArray; use crate::VortexSessionExecute; + use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::min_max::min_max; use crate::arrays::DictArray; use crate::arrays::PrimitiveArray; @@ -79,7 +80,10 @@ mod tests { fn assert_min_max(array: &ArrayRef, expected: Option<(i32, i32)>) -> VortexResult<()> { let mut ctx = SESSION.create_execution_ctx(); - match (min_max(array, &mut ctx)?, expected) { + match ( + min_max(array, &mut ctx, SkipNansOptions::default())?, + expected, + ) { (Some(result), Some((expected_min, expected_max))) => { assert_eq!(i32::try_from(&result.min)?, expected_min); assert_eq!(i32::try_from(&result.max)?, expected_max); diff --git a/vortex-array/src/arrays/list/array.rs b/vortex-array/src/arrays/list/array.rs index 7788511aa25..e9e9e84394f 100644 --- a/vortex-array/src/arrays/list/array.rs +++ b/vortex-array/src/arrays/list/array.rs @@ -20,6 +20,7 @@ use crate::ExecutionCtx; use crate::IntoArray; use crate::LEGACY_SESSION; use crate::VortexSessionExecute; +use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::min_max::min_max; use crate::array::Array; use crate::array::ArrayParts; @@ -212,7 +213,7 @@ impl ListData { // Validate that offsets min is non-negative, and max does not exceed the length of // the elements array. - if let Some(min_max) = min_max(offsets, &mut ctx)? { + if let Some(min_max) = min_max(offsets, &mut ctx, SkipNansOptions::default())? 
{ match_each_integer_ptype!(offsets_ptype, |P| { #[allow(clippy::absurd_extreme_comparisons, unused_comparisons)] { diff --git a/vortex-array/src/arrays/listview/array.rs b/vortex-array/src/arrays/listview/array.rs index e5b6f9aaf99..8decd997071 100644 --- a/vortex-array/src/arrays/listview/array.rs +++ b/vortex-array/src/arrays/listview/array.rs @@ -22,6 +22,7 @@ use crate::LEGACY_SESSION; #[expect(deprecated)] use crate::ToCanonical as _; use crate::VortexSessionExecute; +use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::min_max::min_max; use crate::array::Array; use crate::array::ArrayParts; @@ -567,12 +568,16 @@ pub trait ListViewArrayExt: TypedArrayRef { }); let offsets = self.offsets().cast(wide_dtype.clone())?; let sizes = self.sizes().cast(wide_dtype)?; - let end = min_max(&offsets.binary(sizes, Operator::Add)?, ctx)? - .vortex_expect("non-empty array must report a min/max") - .max - .as_primitive() - .as_::() - .vortex_expect("max `offset + size` must fit in a usize"); + let end = min_max( + &offsets.binary(sizes, Operator::Add)?, + ctx, + SkipNansOptions::default(), + )? + .vortex_expect("non-empty array must report a min/max") + .max + .as_primitive() + .as_::() + .vortex_expect("max `offset + size` must fit in a usize"); Ok((start, end)) } diff --git a/vortex-array/src/arrays/primitive/array/mod.rs b/vortex-array/src/arrays/primitive/array/mod.rs index e22fc9a4336..1ea27369c6c 100644 --- a/vortex-array/src/arrays/primitive/array/mod.rs +++ b/vortex-array/src/arrays/primitive/array/mod.rs @@ -42,6 +42,7 @@ pub use patch::chunk_range; pub use patch::patch_chunk; use crate::ArrayRef; +use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::min_max::min_max; use crate::array::child_to_validity; use crate::array::validity_to_child; @@ -154,7 +155,7 @@ pub trait PrimitiveArrayExt: TypedArrayRef { return Ok(self.to_owned()); } - let Some(min_max) = min_max(self.as_ref(), ctx)? 
else { + let Some(min_max) = min_max(self.as_ref(), ctx, SkipNansOptions::default())? else { return Ok(PrimitiveArray::new( Buffer::::zeroed(self.len()), self.validity(), diff --git a/vortex-array/src/arrays/primitive/compute/cast.rs b/vortex-array/src/arrays/primitive/compute/cast.rs index 10c0b8d6eba..87f1deac23e 100644 --- a/vortex-array/src/arrays/primitive/compute/cast.rs +++ b/vortex-array/src/arrays/primitive/compute/cast.rs @@ -213,10 +213,14 @@ fn values_fit_in( if !compute { return false; } - aggregate_fn::fns::min_max::min_max(array.array(), ctx) - .ok() - .flatten() - .is_none_or(|mm| mm.min.cast(&target_dtype).is_ok() && mm.max.cast(&target_dtype).is_ok()) + aggregate_fn::fns::min_max::min_max( + array.array(), + ctx, + aggregate_fn::SkipNansOptions::default(), + ) + .ok() + .flatten() + .is_none_or(|mm| mm.min.cast(&target_dtype).is_ok() && mm.max.cast(&target_dtype).is_ok()) } /// Cached-only check: returns `Some(fits)` if both `Min` and `Max` are present as `Exact` in the diff --git a/vortex-array/src/compute/conformance/cast.rs b/vortex-array/src/compute/conformance/cast.rs index be5142d8f53..f52a7483e58 100644 --- a/vortex-array/src/compute/conformance/cast.rs +++ b/vortex-array/src/compute/conformance/cast.rs @@ -10,6 +10,7 @@ use crate::IntoArray; use crate::LEGACY_SESSION; use crate::RecursiveCanonical; use crate::VortexSessionExecute; +use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::min_max::MinMaxResult; use crate::aggregate_fn::fns::min_max::min_max; use crate::builtins::ArrayBuiltins; @@ -248,8 +249,8 @@ fn fits(value: &Scalar, ptype: PType) -> bool { fn test_cast_to_primitive(array: &ArrayRef, target_ptype: PType, test_round_trip: bool) { let mut ctx = LEGACY_SESSION.create_execution_ctx(); - let maybe_min_max = - min_max(array, &mut ctx).vortex_expect("cast should succeed in conformance test"); + let maybe_min_max = min_max(array, &mut ctx, SkipNansOptions::default()) + .vortex_expect("cast should succeed in 
conformance test"); if let Some(MinMaxResult { min, max }) = maybe_min_max && (!fits(&min, target_ptype) || !fits(&max, target_ptype)) diff --git a/vortex-array/src/compute/conformance/consistency.rs b/vortex-array/src/compute/conformance/consistency.rs index 8ae7101c6db..e9b7b134e09 100644 --- a/vortex-array/src/compute/conformance/consistency.rs +++ b/vortex-array/src/compute/conformance/consistency.rs @@ -1017,6 +1017,7 @@ fn test_boolean_demorgan_consistency(array: &ArrayRef) { /// Aggregate operations on sliced arrays must produce correct results /// regardless of the underlying encoding's offset handling. fn test_slice_aggregate_consistency(array: &ArrayRef) { + use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::min_max::min_max; use crate::aggregate_fn::fns::nan_count::nan_count; use crate::aggregate_fn::fns::sum::sum; @@ -1075,8 +1076,8 @@ fn test_slice_aggregate_consistency(array: &ArrayRef) { // Test min_max if let (Ok(slice_minmax), Ok(canonical_minmax)) = ( - min_max(&sliced, &mut ctx), - min_max(&canonical_sliced, &mut ctx), + min_max(&sliced, &mut ctx, SkipNansOptions::default()), + min_max(&canonical_sliced, &mut ctx, SkipNansOptions::default()), ) { match (slice_minmax, canonical_minmax) { (Some(s_result), Some(c_result)) => { diff --git a/vortex-array/src/expr/stats/mod.rs b/vortex-array/src/expr/stats/mod.rs index cd5da7811e9..a546e140eb2 100644 --- a/vortex-array/src/expr/stats/mod.rs +++ b/vortex-array/src/expr/stats/mod.rs @@ -28,6 +28,7 @@ use crate::aggregate_fn::AggregateFnRef; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::AggregateFnVTableExt; use crate::aggregate_fn::EmptyOptions; +use crate::aggregate_fn::SkipNansOptions; #[derive( Debug, @@ -186,17 +187,20 @@ impl Stat { .return_dtype(&EmptyOptions, data_type); } Self::Sum => { - return aggregate_fn::fns::sum::Sum.return_dtype(&EmptyOptions, data_type); + // Statistics follow NaN-skipping semantics; request it explicitly. 
+ return aggregate_fn::fns::sum::Sum + .return_dtype(&SkipNansOptions::skip(), data_type); } }) } /// Return the built-in aggregate function corresponding to this statistic, if one exists. pub fn aggregate_fn(&self) -> Option { + // Statistics follow NaN-skipping semantics; request it explicitly rather than the default. Some(match self { - Self::Max => aggregate_fn::fns::max::Max.bind(EmptyOptions), - Self::Min => aggregate_fn::fns::min::Min.bind(EmptyOptions), - Self::Sum => aggregate_fn::fns::sum::Sum.bind(EmptyOptions), + Self::Max => aggregate_fn::fns::max::Max.bind(SkipNansOptions::skip()), + Self::Min => aggregate_fn::fns::min::Min.bind(SkipNansOptions::skip()), + Self::Sum => aggregate_fn::fns::sum::Sum.bind(SkipNansOptions::skip()), Self::NullCount => aggregate_fn::fns::null_count::NullCount.bind(EmptyOptions), Self::NaNCount => aggregate_fn::fns::nan_count::NanCount.bind(EmptyOptions), Self::UncompressedSizeInBytes => { @@ -208,9 +212,12 @@ impl Stat { } /// Return the statistic represented by `aggregate_fn`, if it has a legacy stat slot. + /// + /// Min/max/sum statistics skip NaN values, so NaN-including configurations of those + /// aggregates have no stat slot. 
pub fn from_aggregate_fn(aggregate_fn: &AggregateFnRef) -> Option { - if aggregate_fn.is::() { - return Some(Self::Sum); + if let Some(options) = aggregate_fn.as_opt::() { + return options.skip_nans.then_some(Self::Sum); } if aggregate_fn.is::() { return Some(Self::NaNCount); @@ -218,11 +225,11 @@ impl Stat { if aggregate_fn.is::() { return Some(Self::NullCount); } - if aggregate_fn.is::() { - return Some(Self::Min); + if let Some(options) = aggregate_fn.as_opt::() { + return options.skip_nans.then_some(Self::Min); } - if aggregate_fn.is::() { - return Some(Self::Max); + if let Some(options) = aggregate_fn.as_opt::() { + return options.skip_nans.then_some(Self::Max); } if aggregate_fn .is::() diff --git a/vortex-array/src/stats/array.rs b/vortex-array/src/stats/array.rs index 7dc7a42603a..ff8c62b777b 100644 --- a/vortex-array/src/stats/array.rs +++ b/vortex-array/src/stats/array.rs @@ -16,6 +16,7 @@ use super::StatsSet; use super::StatsSetIntoIter; use super::TypedStatsSetRef; use crate::ArrayRef; +use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::is_constant::is_constant; use crate::aggregate_fn::fns::is_sorted::is_sorted; use crate::aggregate_fn::fns::is_sorted::is_strict_sorted; @@ -162,8 +163,10 @@ impl StatsSetRef<'_> { } Ok(match stat { - Stat::Min => min_max(self.dyn_array_ref, ctx)?.map(|MinMaxResult { min, max: _ }| min), - Stat::Max => min_max(self.dyn_array_ref, ctx)?.map(|MinMaxResult { min: _, max }| max), + Stat::Min => min_max(self.dyn_array_ref, ctx, SkipNansOptions::default())? + .map(|MinMaxResult { min, max: _ }| min), + Stat::Max => min_max(self.dyn_array_ref, ctx, SkipNansOptions::default())? 
+ .map(|MinMaxResult { min: _, max }| max), Stat::Sum => { Stat::Sum .dtype(self.dyn_array_ref.dtype()) diff --git a/vortex-array/src/stats/expr.rs b/vortex-array/src/stats/expr.rs index 58f0c93c508..3701676104a 100644 --- a/vortex-array/src/stats/expr.rs +++ b/vortex-array/src/stats/expr.rs @@ -6,6 +6,7 @@ use crate::aggregate_fn::AggregateFnRef; use crate::aggregate_fn::AggregateFnVTableExt; use crate::aggregate_fn::EmptyOptions; +use crate::aggregate_fn::SkipNansOptions; use crate::aggregate_fn::fns::all_nan::AllNan; use crate::aggregate_fn::fns::all_non_nan::AllNonNan; use crate::aggregate_fn::fns::all_non_null::AllNonNull; @@ -29,12 +30,14 @@ pub fn stat(expr: Expression, aggregate_fn: AggregateFnRef) -> Expression { /// Creates `stat(expr, min_max)`, returning a nullable `{ min, max }` struct statistic. pub fn min_max(expr: Expression) -> Expression { - stat(expr, MinMax.bind(EmptyOptions)) + // Statistics follow NaN-skipping semantics; request it explicitly rather than via the default. + stat(expr, MinMax.bind(SkipNansOptions::skip())) } /// Creates `stat(expr, sum)`, returning a nullable sum statistic. pub fn sum(expr: Expression) -> Expression { - stat(expr, Sum.bind(EmptyOptions)) + // Statistics follow NaN-skipping semantics; request it explicitly rather than via the default. + stat(expr, Sum.bind(SkipNansOptions::skip())) } /// Creates `stat(expr, null_count)`, returning a nullable null-count statistic.