diff --git a/encodings/sequence/src/array.rs b/encodings/sequence/src/array.rs index af8b146cfba..f4158929124 100644 --- a/encodings/sequence/src/array.rs +++ b/encodings/sequence/src/array.rs @@ -28,7 +28,6 @@ use vortex_array::dtype::PType; use vortex_array::expr::stats::Precision as StatPrecision; use vortex_array::expr::stats::Stat; use vortex_array::match_each_integer_ptype; -use vortex_array::match_each_native_ptype; use vortex_array::match_each_pvalue; use vortex_array::scalar::PValue; use vortex_array::scalar::Scalar; @@ -61,6 +60,8 @@ pub struct SequenceMetadata { base: Option, #[prost(message, tag = "2")] multiplier: Option, + #[prost(enumeration = "PType", optional, tag = "3")] + calculation_ptype: Option, } pub(super) const SLOT_NAMES: [&str; 0] = []; @@ -70,11 +71,16 @@ pub(super) const SLOT_NAMES: [&str; 0] = []; pub struct SequenceData { base: PValue, multiplier: PValue, + calculation_ptype: PType, } impl Display for SequenceData { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "base: {}, multiplier: {}", self.base, self.multiplier) + write!( + f, + "base: {}, multiplier: {}, calculation_ptype: {}", + self.base, self.multiplier, self.calculation_ptype + ) } } @@ -91,51 +97,52 @@ impl SequenceData { nullability: Nullability, length: usize, ) -> VortexResult { - Self::try_new( - base.into(), - multiplier.into(), - T::PTYPE, - nullability, - length, - ) + let dtype = DType::Primitive(T::PTYPE, nullability); + Self::try_new(base.into(), multiplier.into(), T::PTYPE, &dtype, length) } - /// Constructs a sequence array using two integer values (with the same ptype). + /// Constructs a sequence array using `calculation_ptype` for arithmetic. pub(crate) fn try_new( base: PValue, multiplier: PValue, - ptype: PType, - nullability: Nullability, + calculation_ptype: PType, + dtype: &DType, length: usize, ) -> VortexResult { - let dtype = DType::Primitive(ptype, nullability); - Self::validate(base, multiplier, &dtype, length)?; - let (base, multiplier) = Self::normalize(base, multiplier, ptype)?; + Self::validate(base, multiplier, calculation_ptype, dtype, length)?; + let (base, multiplier) = Self::normalize(base, multiplier, calculation_ptype)?; - Ok(unsafe { Self::new_unchecked(base, multiplier) }) + Ok(unsafe { Self::new_unchecked(base, multiplier, calculation_ptype) }) } pub fn validate( base: PValue, multiplier: PValue, - dtype: &DType, + calculation_ptype: PType, + output_dtype: &DType, length: usize, ) -> VortexResult<()> { - let DType::Primitive(ptype, _) = dtype else { + let DType::Primitive(output_ptype, _) = output_dtype else { vortex_bail!("only primitive dtypes are supported in SequenceArray currently"); }; - if !ptype.is_int() { - vortex_bail!("only integer ptype are supported in SequenceArray currently") + if !calculation_ptype.is_int() || !output_ptype.is_int() { + vortex_bail!("only integer ptypes are supported in SequenceArray currently") } vortex_ensure!(length > 0, "SequenceArray length must be greater than zero"); - Self::try_last(base, multiplier, *ptype, length).map_err(|e| { + let last = Self::try_last(base, multiplier, calculation_ptype, length).map_err(|e| { e.with_context(format!( "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ", )) })?; + match_each_integer_ptype!(*output_ptype, |P| { + base.cast::

()?; + last.cast::

()?; + VortexResult::Ok(()) + })?; + Ok(()) } @@ -153,14 +160,23 @@ impl SequenceData { /// # Safety /// /// The caller must ensure that: - /// - `base` and `multiplier` are both normalized to the same integer `ptype`. + /// - `base` and `multiplier` are both normalized to `calculation_ptype`. + /// - `calculation_ptype` is an integer type. /// - they are logically compatible with the outer dtype and len. - pub(crate) unsafe fn new_unchecked(base: PValue, multiplier: PValue) -> Self { - Self { base, multiplier } + pub(crate) unsafe fn new_unchecked( + base: PValue, + multiplier: PValue, + calculation_ptype: PType, + ) -> Self { + Self { + base, + multiplier, + calculation_ptype, + } } - pub fn ptype(&self) -> PType { - self.base.ptype() + pub fn calculation_ptype(&self) -> PType { + self.calculation_ptype } pub fn base(&self) -> PValue { @@ -171,6 +187,10 @@ impl SequenceData { self.multiplier } + pub(crate) fn cast_value(value: PValue, output_ptype: PType) -> VortexResult { + match_each_integer_ptype!(output_ptype, |O| { Ok(PValue::from(value.cast::()?)) }) + } + pub fn into_parts(self) -> SequenceDataParts { SequenceDataParts { base: self.base, @@ -200,7 +220,7 @@ impl SequenceData { } pub(crate) fn index_value(&self, idx: usize) -> PValue { - match_each_native_ptype!(self.ptype(), |P| { + match_each_integer_ptype!(self.calculation_ptype(), |P| { let base = self.base.cast::

().vortex_expect("must be able to cast"); let multiplier = self .multiplier @@ -217,12 +237,15 @@ impl ArrayHash for SequenceData { fn array_hash(&self, state: &mut H, _accuracy: EqMode) { self.base.hash(state); self.multiplier.hash(state); + self.calculation_ptype.hash(state); } } impl ArrayEq for SequenceData { fn array_eq(&self, other: &Self, _accuracy: EqMode) -> bool { - self.base == other.base && self.multiplier == other.multiplier + self.base == other.base + && self.multiplier == other.multiplier + && self.calculation_ptype == other.calculation_ptype } } @@ -244,7 +267,13 @@ impl VTable for Sequence { len: usize, _slots: &[Option], ) -> VortexResult<()> { - SequenceData::validate(data.base, data.multiplier, dtype, len) + SequenceData::validate( + data.base, + data.multiplier, + data.calculation_ptype, + dtype, + len, + ) } fn nbuffers(_array: ArrayView<'_, Self>) -> usize { @@ -266,6 +295,7 @@ impl VTable for Sequence { let metadata = SequenceMetadata { base: Some((&array.base()).into()), multiplier: Some((&array.multiplier()).into()), + calculation_ptype: Some(array.calculation_ptype() as i32), }; Ok(Some(metadata.encode_to_vec())) @@ -292,7 +322,11 @@ impl VTable for Sequence { ); let metadata = SequenceMetadata::decode(metadata)?; - let ptype = dtype.as_ptype(); + let calculation_ptype = metadata + .calculation_ptype + .map(|p| PType::try_from(p).map_err(|e| vortex_err!("invalid PType {p}: {e}"))) + .transpose()? + .unwrap_or_else(|| dtype.as_ptype()); // We go via Scalar to validate that the value is valid for the ptype. let base = Scalar::from_proto_value( @@ -300,7 +334,7 @@ impl VTable for Sequence { .base .as_ref() .ok_or_else(|| vortex_err!("base required"))?, - &DType::Primitive(ptype, NonNullable), + &DType::Primitive(calculation_ptype, NonNullable), session, )? .as_primitive() @@ -312,14 +346,14 @@ impl VTable for Sequence { .multiplier .as_ref() .ok_or_else(|| vortex_err!("multiplier required"))?, - &DType::Primitive(ptype, NonNullable), + &DType::Primitive(calculation_ptype, NonNullable), session, )? .as_primitive() .pvalue() .vortex_expect("sequence array multiplier should be a non-nullable primitive"); - let data = SequenceData::try_new(base, multiplier, ptype, dtype.nullability(), len)?; + let data = SequenceData::try_new(base, multiplier, calculation_ptype, dtype, len)?; Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data)) } @@ -355,10 +389,8 @@ impl OperationsVTable for Sequence { index: usize, _ctx: &mut ExecutionCtx, ) -> VortexResult { - Scalar::try_new( - array.dtype().clone(), - Some(ScalarValue::Primitive(array.index_value(index))), - ) + let value = SequenceData::cast_value(array.index_value(index), array.dtype().as_ptype())?; + Scalar::try_new(array.dtype().clone(), Some(ScalarValue::Primitive(value))) } } @@ -402,15 +434,16 @@ impl Sequence { pub(crate) unsafe fn new_unchecked( base: PValue, multiplier: PValue, - ptype: PType, + calculation_ptype: PType, + output_ptype: PType, nullability: Nullability, length: usize, ) -> SequenceArray { - let dtype = DType::Primitive(ptype, nullability); - let (base, multiplier) = SequenceData::normalize(base, multiplier, ptype) + let dtype = DType::Primitive(output_ptype, nullability); + let (base, multiplier) = SequenceData::normalize(base, multiplier, calculation_ptype) .vortex_expect("SequenceArray parts must be normalized to the target ptype"); let stats = Self::stats(multiplier); - let data = unsafe { SequenceData::new_unchecked(base, multiplier) }; + let data = unsafe { SequenceData::new_unchecked(base, multiplier, calculation_ptype) }; unsafe { Array::from_parts_unchecked(ArrayParts::new(Sequence, dtype, length, data)) } .with_stats_set(stats) } @@ -419,12 +452,13 @@ impl Sequence { pub fn try_new( base: PValue, multiplier: PValue, - ptype: PType, + calculation_ptype: PType, + output_ptype: PType, nullability: Nullability, length: usize, ) -> VortexResult { - let dtype = DType::Primitive(ptype, nullability); - let data = SequenceData::try_new(base, multiplier, ptype, nullability, length)?; + let dtype = DType::Primitive(output_ptype, nullability); + let data = SequenceData::try_new(base, multiplier, calculation_ptype, &dtype, length)?; let stats = Self::stats(data.multiplier()); Ok( unsafe { Array::from_parts_unchecked(ArrayParts::new(Sequence, dtype, length, data)) } diff --git a/encodings/sequence/src/compress.rs b/encodings/sequence/src/compress.rs index 9f5e62a5217..6bfd476d4ed 100644 --- a/encodings/sequence/src/compress.rs +++ b/encodings/sequence/src/compress.rs @@ -5,20 +5,22 @@ use std::ops::Add; use num_traits::CheckedAdd; use num_traits::CheckedSub; +use num_traits::cast::NumCast; use vortex_array::ArrayRef; use vortex_array::ArrayView; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::Primitive; use vortex_array::arrays::PrimitiveArray; +use vortex_array::dtype::IntegerPType; use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; use vortex_array::match_each_integer_ptype; -use vortex_array::match_each_native_ptype; use vortex_array::scalar::PValue; use vortex_array::validity::Validity; use vortex_buffer::BufferMut; use vortex_buffer::trusted_len::TrustedLen; +use vortex_error::VortexExpect; use vortex_error::VortexResult; use crate::Sequence; @@ -59,24 +61,32 @@ unsafe impl> TrustedLen for SequenceIter {} /// Decompresses a [`SequenceArray`] into a [`PrimitiveArray`]. #[inline] pub fn sequence_decompress(array: &SequenceArray) -> VortexResult { - fn decompress_inner( - base: P, - multiplier: P, + fn decompress_inner( + base: C, + multiplier: C, len: usize, nullability: Nullability, ) -> PrimitiveArray { - let values = BufferMut::from_trusted_len_iter(SequenceIter { - acc: base, - step: multiplier, - remaining: len, - }); + let values = BufferMut::from_trusted_len_iter( + SequenceIter { + acc: base, + step: multiplier, + remaining: len, + } + .map(|value| { + ::from(value) + .vortex_expect("validated sequence values must fit output ptype") + }), + ); PrimitiveArray::new(values, Validity::from(nullability)) } - let prim = match_each_native_ptype!(array.ptype(), |P| { - let base = array.base().cast::

()?; - let multiplier = array.multiplier().cast::

()?; - decompress_inner(base, multiplier, array.len(), array.dtype().nullability()) + let prim = match_each_integer_ptype!(array.calculation_ptype(), |C| { + let base = array.base().cast::()?; + let multiplier = array.multiplier().cast::()?; + match_each_integer_ptype!(array.dtype().as_ptype(), |O| { + decompress_inner::(base, multiplier, array.len(), array.dtype().nullability()) + }) }); Ok(prim.into_array()) } diff --git a/encodings/sequence/src/compute/cast.rs b/encodings/sequence/src/compute/cast.rs index e6d64fdf7c5..85a29b9a563 100644 --- a/encodings/sequence/src/compute/cast.rs +++ b/encodings/sequence/src/compute/cast.rs @@ -5,12 +5,8 @@ use vortex_array::ArrayRef; use vortex_array::ArrayView; use vortex_array::IntoArray; use vortex_array::dtype::DType; -use vortex_array::dtype::Nullability; -use vortex_array::scalar::Scalar; -use vortex_array::scalar::ScalarValue; use vortex_array::scalar_fn::fns::cast::CastReduce; use vortex_error::VortexResult; -use vortex_error::vortex_err; use crate::Sequence; impl CastReduce for Sequence { @@ -26,62 +22,22 @@ impl CastReduce for Sequence { return Ok(None); } - // Check if this is just a nullability change - if array.ptype() == *target_ptype && array.dtype().nullability() != *target_nullability { - // For SequenceArray, we can just create a new one with the same parameters - // but different nullability - return Ok(Some( - Sequence::try_new( - array.base(), - array.multiplier(), - *target_ptype, - *target_nullability, - array.len(), - )? - .into_array(), - )); - } - - // For type changes, we need to cast the base and multiplier - if array.ptype() != *target_ptype { - // Create scalars from PValues and cast them - let base_scalar = Scalar::try_new( - DType::Primitive(array.ptype(), Nullability::NonNullable), - Some(ScalarValue::Primitive(array.base())), - )?; - let multiplier_scalar = Scalar::try_new( - DType::Primitive(array.ptype(), Nullability::NonNullable), - Some(ScalarValue::Primitive(array.multiplier())), - )?; - - let new_base_scalar = - base_scalar.cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?; - let new_multiplier_scalar = multiplier_scalar - .cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?; - - // Extract PValues from the casted scalars - let new_base = new_base_scalar - .as_primitive() - .pvalue() - .ok_or_else(|| vortex_err!("Cast resulted in null base value"))?; - let new_multiplier = new_multiplier_scalar - .as_primitive() - .pvalue() - .ok_or_else(|| vortex_err!("Cast resulted in null multiplier value"))?; - - return Ok(Some( - Sequence::try_new( - new_base, - new_multiplier, - *target_ptype, - *target_nullability, - array.len(), - )? - .into_array(), - )); + if array.dtype() == dtype { + return Ok(None); } - Ok(None) + // try_new -> validate proves the produced values fit the target type. + Ok(Some( + Sequence::try_new( + array.base(), + array.multiplier(), + array.calculation_ptype(), + *target_ptype, + *target_nullability, + array.len(), + )? + .into_array(), + )) } } @@ -99,7 +55,10 @@ mod tests { use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; + use vortex_array::scalar::Scalar; + use vortex_array::scalar::ScalarValue; use vortex_array::session::ArraySession; + use vortex_error::VortexResult; use vortex_session::VortexSession; use crate::Sequence; @@ -189,14 +148,39 @@ mod tests { ); } + #[test] + fn test_cast_sequence_preserves_encoding_and_scalar_at_uses_output_dtype() -> VortexResult<()> { + // Cast the public dtype to u8 + let casted = Sequence::try_new_typed(100i32, -10i32, Nullability::NonNullable, 5)? + .into_array() + .cast(DType::Primitive(PType::U8, Nullability::NonNullable))?; + + let sequence = casted + .as_typed::() + .expect("integer sequence cast should preserve SequenceArray"); + assert_eq!(sequence.calculation_ptype(), PType::I32); + assert_eq!( + sequence.dtype(), + &DType::Primitive(PType::U8, Nullability::NonNullable) + ); + + let scalar = casted.execute_scalar(1, &mut SESSION.create_execution_ctx())?; + assert_eq!( + scalar, + Scalar::try_new( + DType::Primitive(PType::U8, Nullability::NonNullable), + Some(ScalarValue::from(90u8)), + )? + ); + + Ok(()) + } + #[rstest] #[case::i32(Sequence::try_new_typed(0i32, 1i32, Nullability::NonNullable, 5).unwrap())] #[case::u64(Sequence::try_new_typed(1000u64, 100u64, Nullability::NonNullable, 4).unwrap())] - // TODO(DK): SequenceArray does not actually conform. You cannot cast this array to u8 even - // though all its values are representable therein. - // - // #[case::negative_step(Sequence::try_new_typed(100i32, -10i32, Nullability::NonNullable, - // 5).unwrap())] + #[case::negative_step(Sequence::try_new_typed(100i32, -10i32, Nullability::NonNullable, + 5).unwrap())] #[case::single(Sequence::try_new_typed(42i64, 0i64, Nullability::NonNullable, 1).unwrap())] #[case::constant(Sequence::try_new_typed( 100i32, diff --git a/encodings/sequence/src/compute/filter.rs b/encodings/sequence/src/compute/filter.rs index 3ef24188fbf..bb15114f8b5 100644 --- a/encodings/sequence/src/compute/filter.rs +++ b/encodings/sequence/src/compute/filter.rs @@ -1,14 +1,15 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use num_traits::NumCast; use vortex_array::ArrayRef; use vortex_array::ArrayView; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::filter::FilterKernel; -use vortex_array::dtype::NativePType; -use vortex_array::match_each_native_ptype; +use vortex_array::dtype::IntegerPType; +use vortex_array::match_each_integer_ptype; use vortex_array::validity::Validity; use vortex_buffer::BufferMut; use vortex_error::VortexExpect; @@ -24,22 +25,30 @@ impl FilterKernel for Sequence { _ctx: &mut ExecutionCtx, ) -> VortexResult> { let validity = Validity::from(array.dtype().nullability()); - match_each_native_ptype!(array.ptype(), |P| { - let mul = array.multiplier().cast::

()?; - let base = array.base().cast::

()?; - Ok(Some(filter_impl(mul, base, mask, validity))) + match_each_integer_ptype!(array.calculation_ptype(), |C| { + let mul = array.multiplier().cast::()?; + let base = array.base().cast::()?; + match_each_integer_ptype!(array.dtype().as_ptype(), |O| { + Ok(Some(filter_impl::(mul, base, mask, validity))) + }) }) } } -fn filter_impl(mul: T, base: T, mask: &Mask, validity: Validity) -> ArrayRef { +fn filter_impl( + mul: C, + base: C, + mask: &Mask, + validity: Validity, +) -> ArrayRef { let mask_values = mask .values() .vortex_expect("FilterKernel precondition: mask is Mask::Values"); - let mut buffer = BufferMut::::with_capacity(mask_values.true_count()); + let mut buffer = BufferMut::::with_capacity(mask_values.true_count()); buffer.extend(mask_values.indices().iter().map(|&idx| { - let i = T::from_usize(idx).vortex_expect("all valid indices fit"); - base + i * mul + let i = C::from_usize(idx).vortex_expect("all valid indices fit"); + ::from(base + i * mul) + .vortex_expect("valid sequence values must fit output ptype") })); PrimitiveArray::new(buffer.freeze(), validity).into_array() } diff --git a/encodings/sequence/src/compute/min_max.rs b/encodings/sequence/src/compute/min_max.rs index 2ff8b2abd58..f856157533c 100644 --- a/encodings/sequence/src/compute/min_max.rs +++ b/encodings/sequence/src/compute/min_max.rs @@ -48,7 +48,9 @@ impl DynAggregateKernel for SequenceMinMaxKernel { } let base = seq.base(); - let last = SequenceData::try_last(base, seq.multiplier(), seq.ptype(), seq.len())?; + let output_ptype = seq.dtype().as_ptype(); + let last = + SequenceData::try_last(base, seq.multiplier(), seq.calculation_ptype(), seq.len())?; // Determine min and max based on multiplier direction. // For unsigned types, multiplier is always >= 0. @@ -65,7 +67,10 @@ impl DynAggregateKernel for SequenceMinMaxKernel { float: |_v| { unreachable!("float multiplier not supported for SequenceArray") } ); - let non_nullable_dtype = DType::Primitive(seq.ptype(), Nullability::NonNullable); + let min_pvalue = SequenceData::cast_value(min_pvalue, output_ptype)?; + let max_pvalue = SequenceData::cast_value(max_pvalue, output_ptype)?; + + let non_nullable_dtype = DType::Primitive(output_ptype, Nullability::NonNullable); let min_scalar = Scalar::try_new( non_nullable_dtype.clone(), Some(ScalarValue::Primitive(min_pvalue)), diff --git a/encodings/sequence/src/compute/slice.rs b/encodings/sequence/src/compute/slice.rs index c2b64b68ef4..941387b03da 100644 --- a/encodings/sequence/src/compute/slice.rs +++ b/encodings/sequence/src/compute/slice.rs @@ -19,7 +19,8 @@ impl SliceReduce for Sequence { Sequence::new_unchecked( array.index_value(range.start), array.multiplier(), - array.ptype(), + array.calculation_ptype(), + array.dtype().as_ptype(), array.dtype().nullability(), range.len(), ) diff --git a/encodings/sequence/src/compute/take.rs b/encodings/sequence/src/compute/take.rs index 6e1ebe263df..6a5469ed924 100644 --- a/encodings/sequence/src/compute/take.rs +++ b/encodings/sequence/src/compute/take.rs @@ -11,10 +11,8 @@ use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::dict::TakeExecute; use vortex_array::dtype::DType; use vortex_array::dtype::IntegerPType; -use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; use vortex_array::match_each_integer_ptype; -use vortex_array::match_each_native_ptype; use vortex_array::scalar::Scalar; use vortex_array::validity::Validity; use vortex_buffer::Buffer; @@ -26,9 +24,9 @@ use vortex_mask::Mask; use crate::Sequence; -fn take_inner( - mul: S, - base: S, +fn take_inner( + mul: C, + base: C, indices: &[T], indices_mask: Mask, result_nullability: Nullability, @@ -40,14 +38,15 @@ fn take_inner( if i.as_() >= len { vortex_panic!(OutOfBounds: i.as_(), 0, len); } - let i = ::from::(*i).vortex_expect("all indices fit"); - base + i * mul + let i = ::from::(*i).vortex_expect("all indices fit"); + ::from(base + i * mul) + .vortex_expect("validated sequence values must fit output ptype") })), Validity::from(result_nullability), ) .into_array(), AllOr::None => ConstantArray::new( - Scalar::null(DType::Primitive(S::PTYPE, Nullability::Nullable)), + Scalar::null(DType::Primitive(O::PTYPE, Nullability::Nullable)), indices.len(), ) .into_array(), @@ -60,10 +59,11 @@ fn take_inner( } let i = - ::from::(*i).vortex_expect("all valid indices fit"); - base + i * mul + ::from::(*i).vortex_expect("all valid indices fit"); + ::from(base + i * mul) + .vortex_expect("validated sequence values must fit output ptype") } else { - S::zero() + O::zero() } })); PrimitiveArray::new(buffer, Validity::from(b.clone())).into_array() @@ -71,6 +71,62 @@ fn take_inner( } } +fn take_with_output_ptype( + array: ArrayView<'_, Sequence>, + indices: &[T], + indices_mask: Mask, + result_nullability: Nullability, +) -> VortexResult { + let mul = array.multiplier().cast::()?; + let base = array.base().cast::()?; + Ok(take_inner::( + mul, + base, + indices, + indices_mask, + result_nullability, + array.len(), + )) +} + +fn take_with_calculation_ptype( + array: ArrayView<'_, Sequence>, + indices: &[T], + indices_mask: Mask, + result_nullability: Nullability, +) -> VortexResult { + match_each_integer_ptype!(array.dtype().as_ptype(), |O| { + take_with_output_ptype::(array, indices, indices_mask, result_nullability) + }) +} + +fn take_with_indices_ptype( + array: ArrayView<'_, Sequence>, + indices: &[T], + indices_mask: Mask, + result_nullability: Nullability, +) -> VortexResult { + match_each_integer_ptype!(array.calculation_ptype(), |C| { + take_with_calculation_ptype::(array, indices, indices_mask, result_nullability) + }) +} + +fn take_sequence( + array: ArrayView<'_, Sequence>, + indices: &PrimitiveArray, + indices_mask: Mask, + result_nullability: Nullability, +) -> VortexResult { + match_each_integer_ptype!(indices.ptype(), |T| { + take_with_indices_ptype::( + array, + indices.as_slice::(), + indices_mask, + result_nullability, + ) + }) +} + impl TakeExecute for Sequence { fn take( array: ArrayView<'_, Self>, @@ -81,21 +137,7 @@ impl TakeExecute for Sequence { let indices = indices.clone().execute::(ctx)?; let result_nullability = array.dtype().nullability() | indices.dtype().nullability(); - match_each_integer_ptype!(indices.ptype(), |T| { - let indices = indices.as_slice::(); - match_each_native_ptype!(array.ptype(), |S| { - let mul = array.multiplier().cast::()?; - let base = array.base().cast::()?; - Ok(Some(take_inner( - mul, - base, - indices, - mask, - result_nullability, - array.len(), - ))) - }) - }) + take_sequence(array, &indices, mask, result_nullability).map(Some) } } diff --git a/vortex-duckdb/src/e2e_test/vortex_scan_test.rs b/vortex-duckdb/src/e2e_test/vortex_scan_test.rs index e0705e07de0..c010d2f7c1b 100644 --- a/vortex-duckdb/src/e2e_test/vortex_scan_test.rs +++ b/vortex-duckdb/src/e2e_test/vortex_scan_test.rs @@ -843,6 +843,7 @@ async fn write_vortex_file_with_encodings() -> NamedTempFile { PValue::I64(0), PValue::I64(10), PType::I64, + PType::I64, Nullability::NonNullable, 5, ) diff --git a/vortex-layout/src/layouts/row_idx/mod.rs b/vortex-layout/src/layouts/row_idx/mod.rs index 3814c2f8093..3bd8987893a 100644 --- a/vortex-layout/src/layouts/row_idx/mod.rs +++ b/vortex-layout/src/layouts/row_idx/mod.rs @@ -275,6 +275,7 @@ fn idx_array(row_offset: u64, row_range: &Range) -> SequenceArray { PValue::U64(row_offset + row_range.start), PValue::U64(1), PType::U64, + PType::U64, NonNullable, usize::try_from(row_range.end - row_range.start) .vortex_expect("Row range length must fit in usize"),