diff --git a/Cargo.lock b/Cargo.lock index fdca3237b71d0..1cb3fe0ecb0d0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1879,6 +1879,7 @@ dependencies = [ "itertools 0.14.0", "libc", "log", + "num-traits", "object_store", "parquet", "rand 0.9.4", diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 740d4e45b8d05..1eb23089a4021 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -77,6 +77,7 @@ indexmap = { workspace = true } itertools = { workspace = true } libc = "0.2.185" log = { workspace = true } +num-traits = { workspace = true } object_store = { workspace = true, optional = true } parquet = { workspace = true, optional = true, default-features = true } recursive = { workspace = true, optional = true } diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 8a8a47b3bb50b..bba7f77b89c36 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -92,6 +92,7 @@ use arrow::util::display::{ArrayFormatter, FormatOptions, array_value_to_string} use cache::{get_or_create_cached_key_array, get_or_create_cached_null_array}; use chrono::{Duration, NaiveDate}; use half::f16; +use num_traits::ToPrimitive; pub use struct_builder::ScalarStructBuilder; const SECONDS_PER_DAY: i64 = 86_400; @@ -2585,63 +2586,107 @@ impl ScalarValue { /// distance is greater than [`usize::MAX`]. If the type is a float, then the distance will be /// rounded to the nearest integer. /// - /// /// Note: the datatype itself must support subtraction. pub fn distance(&self, other: &ScalarValue) -> Option { + self.distance_u64(other) + .and_then(|d| usize::try_from(d).ok()) + } + + /// Helper to convert a rounded float distance to u64, returning None if it exceeds u64::MAX, is negative, or is not finite. + fn rounded_float_distance_u64(diff: f64) -> Option { + if diff.is_finite() && diff >= 0.0 && diff < u64::MAX as f64 { + Some(diff as u64) + } else { + None + } + } + + /// Absolute distance between two numeric values (of the same type). This method will return + /// None if either one of the arguments are null. It might also return None if the resulting + /// distance is greater than [`u64::MAX`]. If the type is a float, then the distance will be + /// rounded to the nearest integer. + /// + /// Note: the datatype itself must support subtraction. + pub fn distance_u64(&self, other: &ScalarValue) -> Option { match (self, other) { - (Self::Int8(Some(l)), Self::Int8(Some(r))) => Some(l.abs_diff(*r) as _), - (Self::Int16(Some(l)), Self::Int16(Some(r))) => Some(l.abs_diff(*r) as _), - (Self::Int32(Some(l)), Self::Int32(Some(r))) => Some(l.abs_diff(*r) as _), - (Self::Int64(Some(l)), Self::Int64(Some(r))) => Some(l.abs_diff(*r) as _), - (Self::UInt8(Some(l)), Self::UInt8(Some(r))) => Some(l.abs_diff(*r) as _), - (Self::UInt16(Some(l)), Self::UInt16(Some(r))) => Some(l.abs_diff(*r) as _), - (Self::UInt32(Some(l)), Self::UInt32(Some(r))) => Some(l.abs_diff(*r) as _), - (Self::UInt64(Some(l)), Self::UInt64(Some(r))) => Some(l.abs_diff(*r) as _), + (Self::Int8(Some(l)), Self::Int8(Some(r))) => Some(l.abs_diff(*r) as u64), + (Self::Int16(Some(l)), Self::Int16(Some(r))) => Some(l.abs_diff(*r) as u64), + (Self::Int32(Some(l)), Self::Int32(Some(r))) => Some(l.abs_diff(*r) as u64), + (Self::Int64(Some(l)), Self::Int64(Some(r))) => Some(l.abs_diff(*r)), + (Self::UInt8(Some(l)), Self::UInt8(Some(r))) => Some(l.abs_diff(*r) as u64), + (Self::UInt16(Some(l)), Self::UInt16(Some(r))) => Some(l.abs_diff(*r) as u64), + (Self::UInt32(Some(l)), Self::UInt32(Some(r))) => Some(l.abs_diff(*r) as u64), + (Self::UInt64(Some(l)), Self::UInt64(Some(r))) => Some(l.abs_diff(*r)), // TODO: we might want to look into supporting ceil/floor here for floats. (Self::Float16(Some(l)), Self::Float16(Some(r))) => { - Some((f16::to_f32(*l) - f16::to_f32(*r)).abs().round() as _) + let diff = (f16::to_f32(*l) - f16::to_f32(*r)).abs().round(); + Self::rounded_float_distance_u64(diff as f64) } (Self::Float32(Some(l)), Self::Float32(Some(r))) => { - Some((l - r).abs().round() as _) + let diff = (l - r).abs().round(); + Self::rounded_float_distance_u64(diff as f64) } (Self::Float64(Some(l)), Self::Float64(Some(r))) => { - Some((l - r).abs().round() as _) + let diff = (l - r).abs().round(); + Self::rounded_float_distance_u64(diff) } - (Self::Date32(Some(l)), Self::Date32(Some(r))) => Some(l.abs_diff(*r) as _), - (Self::Date64(Some(l)), Self::Date64(Some(r))) => Some(l.abs_diff(*r) as _), + (Self::Date32(Some(l)), Self::Date32(Some(r))) => Some(l.abs_diff(*r) as u64), + (Self::Date64(Some(l)), Self::Date64(Some(r))) => Some(l.abs_diff(*r)), // Timestamp values are stored as epoch ticks regardless of timezone // annotation, so the distance is tz-independent (tz is display metadata). (Self::TimestampSecond(Some(l), _), Self::TimestampSecond(Some(r), _)) => { - Some(l.abs_diff(*r) as _) + Some(l.abs_diff(*r)) } ( Self::TimestampMillisecond(Some(l), _), Self::TimestampMillisecond(Some(r), _), - ) => Some(l.abs_diff(*r) as _), + ) => Some(l.abs_diff(*r)), ( Self::TimestampMicrosecond(Some(l), _), Self::TimestampMicrosecond(Some(r), _), - ) => Some(l.abs_diff(*r) as _), + ) => Some(l.abs_diff(*r)), ( Self::TimestampNanosecond(Some(l), _), Self::TimestampNanosecond(Some(r), _), - ) => Some(l.abs_diff(*r) as _), + ) => Some(l.abs_diff(*r)), + ( + Self::Decimal32(Some(l), _, lscale), + Self::Decimal32(Some(r), _, rscale), + ) => { + // In order to be aligned with PartialOrd we only + // check for equal scale, ignoring precision + if lscale == rscale { + Some(l.abs_diff(*r) as u64) + } else { + None + } + } + ( + Self::Decimal64(Some(l), _, lscale), + Self::Decimal64(Some(r), _, rscale), + ) => { + if lscale == rscale { + Some(l.abs_diff(*r)) + } else { + None + } + } ( - Self::Decimal128(Some(l), lprecision, lscale), - Self::Decimal128(Some(r), rprecision, rscale), + Self::Decimal128(Some(l), _, lscale), + Self::Decimal128(Some(r), _, rscale), ) => { - if lprecision == rprecision && lscale == rscale { - l.checked_sub(*r)?.checked_abs()?.to_usize() + if lscale == rscale { + l.checked_sub(*r)?.checked_abs()?.to_u64() } else { None } } ( - Self::Decimal256(Some(l), lprecision, lscale), - Self::Decimal256(Some(r), rprecision, rscale), + Self::Decimal256(Some(l), _, lscale), + Self::Decimal256(Some(r), _, rscale), ) => { - if lprecision == rprecision && lscale == rscale { - l.checked_sub(*r)?.checked_abs()?.to_usize() + if lscale == rscale { + l.checked_sub(*r)?.checked_abs()?.to_u64() } else { None } @@ -9444,8 +9489,8 @@ mod tests { ), ]; for (lhs, rhs, expected) in cases.iter() { - let distance = lhs.distance(rhs).unwrap(); - assert_eq!(distance, *expected); + let distance = lhs.distance_u64(rhs).unwrap(); + assert_eq!(distance, *expected as u64); } } @@ -9462,7 +9507,7 @@ mod tests { ), ]; for (lhs, rhs) in cases.iter() { - let distance = lhs.distance(rhs); + let distance = lhs.distance_u64(rhs); assert!(distance.is_none(), "{lhs} vs {rhs}"); } } @@ -9508,13 +9553,9 @@ mod tests { ScalarValue::Decimal128(Some(123), 5, 5), ScalarValue::Decimal128(Some(120), 5, 3), ), - ( - ScalarValue::Decimal128(Some(123), 5, 5), - ScalarValue::Decimal128(Some(120), 3, 5), - ), ( ScalarValue::Decimal256(Some(123.into()), 5, 5), - ScalarValue::Decimal256(Some(120.into()), 3, 5), + ScalarValue::Decimal256(Some(120.into()), 5, 3), ), // Distance 2 * 2^50 is larger than usize ( @@ -9536,11 +9577,124 @@ mod tests { ), ]; for (lhs, rhs) in cases { - let distance = lhs.distance(&rhs); + let distance = lhs.distance_u64(&rhs); assert!(distance.is_none()); } } + #[test] + fn test_scalar_distance_u64_boundaries() { + // 1. Full-domain integer ranges + // i64::MIN to i64::MAX -> distance is u64::MAX + let lhs = ScalarValue::Int64(Some(i64::MIN)); + let rhs = ScalarValue::Int64(Some(i64::MAX)); + assert_eq!(lhs.distance_u64(&rhs), Some(u64::MAX)); + assert_eq!(rhs.distance_u64(&lhs), Some(u64::MAX)); + + // u64::MIN to u64::MAX -> distance is u64::MAX + let lhs = ScalarValue::UInt64(Some(u64::MIN)); + let rhs = ScalarValue::UInt64(Some(u64::MAX)); + assert_eq!(lhs.distance_u64(&rhs), Some(u64::MAX)); + assert_eq!(rhs.distance_u64(&lhs), Some(u64::MAX)); + + // 2. Decimal128 overflow edges (around u64::MAX) + // distance equal to u64::MAX fits + let lhs = ScalarValue::Decimal128(Some(0), 20, 0); + let rhs = ScalarValue::Decimal128(Some(u64::MAX as i128), 20, 0); + assert_eq!(lhs.distance_u64(&rhs), Some(u64::MAX)); + + // distance greater than u64::MAX overflows + let lhs = ScalarValue::Decimal128(Some(0), 20, 0); + let rhs = ScalarValue::Decimal128(Some(u64::MAX as i128 + 1), 20, 0); + assert_eq!(lhs.distance_u64(&rhs), None); + + // 3. Decimal256 overflow edges (around u64::MAX) + // distance equal to u64::MAX fits + let lhs = ScalarValue::Decimal256(Some(i256::from_parts(0, 0)), 20, 0); + let rhs = + ScalarValue::Decimal256(Some(i256::from_parts(u64::MAX as u128, 0)), 20, 0); + assert_eq!(lhs.distance_u64(&rhs), Some(u64::MAX)); + + // distance greater than u64::MAX overflows + let lhs = ScalarValue::Decimal256(Some(i256::from_parts(0, 0)), 20, 0); + let rhs = ScalarValue::Decimal256( + Some(i256::from_parts(u64::MAX as u128 + 1, 0)), + 20, + 0, + ); + assert_eq!(lhs.distance_u64(&rhs), None); + + // 4. Float64 overflow edges (around u64::MAX) + let lhs = ScalarValue::Float64(Some(0.0)); + let val: f64 = 18446744073709500000.0; + let rhs = ScalarValue::Float64(Some(val)); + assert_eq!(lhs.distance_u64(&rhs), Some(18446744073709500416)); + + // float value > u64::MAX overflows + let rhs = ScalarValue::Float64(Some(1.9e19)); + assert_eq!(lhs.distance_u64(&rhs), None); + + // exact 2^64 boundary (18446744073709551616.0) is greater than u64::MAX, so it should return None + let exact_2_64_f64 = ScalarValue::Float64(Some(18446744073709551616.0)); + assert_eq!(lhs.distance_u64(&exact_2_64_f64), None); + + // exact 2^64 boundary as Float32 should also return None + let lhs_f32 = ScalarValue::Float32(Some(0.0)); + let exact_2_64_f32 = ScalarValue::Float32(Some(18446744073709551616.0)); + assert_eq!(lhs_f32.distance_u64(&exact_2_64_f32), None); + + // largest float32 value below 2^64 (2^64 - 2^41 = 18446741874686296064.0) should fit + let below_2_64_f32 = ScalarValue::Float32(Some(18446741874686296064.0)); + assert_eq!( + lhs_f32.distance_u64(&below_2_64_f32), + Some(18446741874686296064) + ); + + // Inf, NegInf, NaN + let inf = ScalarValue::Float64(Some(f64::INFINITY)); + let neg_inf = ScalarValue::Float64(Some(f64::NEG_INFINITY)); + let nan = ScalarValue::Float64(Some(f64::NAN)); + assert_eq!(lhs.distance_u64(&inf), None); + assert_eq!(lhs.distance_u64(&neg_inf), None); + assert_eq!(lhs.distance_u64(&nan), None); + + let inf_f32 = ScalarValue::Float32(Some(f32::INFINITY)); + let neg_inf_f32 = ScalarValue::Float32(Some(f32::NEG_INFINITY)); + let nan_f32 = ScalarValue::Float32(Some(f32::NAN)); + assert_eq!(lhs_f32.distance_u64(&inf_f32), None); + assert_eq!(lhs_f32.distance_u64(&neg_inf_f32), None); + assert_eq!(lhs_f32.distance_u64(&nan_f32), None); + + let lhs_f16 = ScalarValue::Float16(Some(f16::ZERO)); + let inf_f16 = ScalarValue::Float16(Some(f16::INFINITY)); + let neg_inf_f16 = ScalarValue::Float16(Some(f16::NEG_INFINITY)); + let nan_f16 = ScalarValue::Float16(Some(f16::NAN)); + assert_eq!(lhs_f16.distance_u64(&inf_f16), None); + assert_eq!(lhs_f16.distance_u64(&neg_inf_f16), None); + assert_eq!(lhs_f16.distance_u64(&nan_f16), None); + + // 5. Date and Timestamp boundaries + // Date32: i32::MIN to i32::MAX + let lhs = ScalarValue::Date32(Some(i32::MIN)); + let rhs = ScalarValue::Date32(Some(i32::MAX)); + assert_eq!(lhs.distance_u64(&rhs), Some(u32::MAX as u64)); + + // TimestampSecond: i64::MIN to i64::MAX + let lhs = ScalarValue::TimestampSecond(Some(i64::MIN), None); + let rhs = ScalarValue::TimestampSecond(Some(i64::MAX), None); + assert_eq!(lhs.distance_u64(&rhs), Some(u64::MAX)); + + // 6. Decimal scale matching (ignoring precision) + let lhs = ScalarValue::Decimal128(Some(100), 10, 2); + let rhs = ScalarValue::Decimal128(Some(150), 15, 2); + assert_eq!(lhs.distance_u64(&rhs), Some(50)); + assert_eq!(rhs.distance_u64(&lhs), Some(50)); + + let lhs = ScalarValue::Decimal128(Some(100), 10, 2); + let rhs = ScalarValue::Decimal128(Some(150), 10, 3); + assert_eq!(lhs.distance_u64(&rhs), None); + } + #[test] fn test_scalar_interval_negate() { let cases = [ diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index a64d5e00ee6df..b704a70002d81 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -829,8 +829,8 @@ pub fn estimate_ndv_with_overlap( let right_min = right.min_value.get_value()?; let right_max = right.max_value.get_value()?; - let range_left = left_max.distance(left_min)?; - let range_right = right_max.distance(right_min)?; + let range_left = left_max.distance_u64(left_min)?; + let range_right = right_max.distance_u64(right_min)?; // Constant columns (range == 0) can't use the proportional overlap // formula below, so check interval overlap directly instead. @@ -859,7 +859,7 @@ pub fn estimate_ndv_with_overlap( return Some(ndv_left + ndv_right); } - let overlap_range = overlap_max.distance(overlap_min)? as f64; + let overlap_range = overlap_max.distance_u64(overlap_min)? as f64; let overlap_left = overlap_range / range_left as f64; let overlap_right = overlap_range / range_right as f64; diff --git a/datafusion/expr-common/src/interval_arithmetic.rs b/datafusion/expr-common/src/interval_arithmetic.rs index 51858be538f5a..68541e1e6b32c 100644 --- a/datafusion/expr-common/src/interval_arithmetic.rs +++ b/datafusion/expr-common/src/interval_arithmetic.rs @@ -910,10 +910,16 @@ impl Interval { if data_type.is_integer() || matches!( data_type, - DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _) + DataType::Date32 + | DataType::Date64 + | DataType::Timestamp(_, _) + | DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) ) { - self.upper.distance(&self.lower).map(|diff| diff as u64) + self.upper.distance_u64(&self.lower) } else if data_type.is_floating() { // Negative numbers are sorted in the reverse order. To // always have a positive difference after the subtraction, @@ -4157,6 +4163,13 @@ mod tests { ScalarValue::TimestampNanosecond(Some(2_000_000_000), None), )?; assert_eq!(interval.cardinality().unwrap(), 1_000_000_001); + + // Decimal types + let interval = Interval::try_new( + ScalarValue::Decimal128(Some(100), 10, 2), + ScalarValue::Decimal128(Some(110), 10, 2), + )?; + assert_eq!(interval.cardinality().unwrap(), 11); Ok(()) }