Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions datafusion/common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ indexmap = { workspace = true }
itertools = { workspace = true }
libc = "0.2.185"
log = { workspace = true }
num-traits = { workspace = true }
object_store = { workspace = true, optional = true }
parquet = { workspace = true, optional = true, default-features = true }
recursive = { workspace = true, optional = true }
Expand Down
224 changes: 189 additions & 35 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ use arrow::util::display::{ArrayFormatter, FormatOptions, array_value_to_string}
use cache::{get_or_create_cached_key_array, get_or_create_cached_null_array};
use chrono::{Duration, NaiveDate};
use half::f16;
use num_traits::ToPrimitive;
pub use struct_builder::ScalarStructBuilder;

const SECONDS_PER_DAY: i64 = 86_400;
Expand Down Expand Up @@ -2585,63 +2586,107 @@ impl ScalarValue {
/// distance is greater than [`usize::MAX`]. If the type is a float, then the distance will be
/// rounded to the nearest integer.
///
///
/// Note: the datatype itself must support subtraction.
pub fn distance(&self, other: &ScalarValue) -> Option<usize> {
self.distance_u64(other)
.and_then(|d| usize::try_from(d).ok())
}

/// Helper to convert a rounded float distance to u64, returning None if it exceeds u64::MAX, is negative, or is not finite.
fn rounded_float_distance_u64(diff: f64) -> Option<u64> {
if diff.is_finite() && diff >= 0.0 && diff < u64::MAX as f64 {
Some(diff as u64)
} else {
None
}
}

/// Absolute distance between two numeric values (of the same type). This method will return
/// None if either one of the arguments are null. It might also return None if the resulting
/// distance is greater than [`u64::MAX`]. If the type is a float, then the distance will be
/// rounded to the nearest integer.
///
/// Note: the datatype itself must support subtraction.
pub fn distance_u64(&self, other: &ScalarValue) -> Option<u64> {
match (self, other) {
(Self::Int8(Some(l)), Self::Int8(Some(r))) => Some(l.abs_diff(*r) as _),
(Self::Int16(Some(l)), Self::Int16(Some(r))) => Some(l.abs_diff(*r) as _),
(Self::Int32(Some(l)), Self::Int32(Some(r))) => Some(l.abs_diff(*r) as _),
(Self::Int64(Some(l)), Self::Int64(Some(r))) => Some(l.abs_diff(*r) as _),
(Self::UInt8(Some(l)), Self::UInt8(Some(r))) => Some(l.abs_diff(*r) as _),
(Self::UInt16(Some(l)), Self::UInt16(Some(r))) => Some(l.abs_diff(*r) as _),
(Self::UInt32(Some(l)), Self::UInt32(Some(r))) => Some(l.abs_diff(*r) as _),
(Self::UInt64(Some(l)), Self::UInt64(Some(r))) => Some(l.abs_diff(*r) as _),
(Self::Int8(Some(l)), Self::Int8(Some(r))) => Some(l.abs_diff(*r) as u64),
(Self::Int16(Some(l)), Self::Int16(Some(r))) => Some(l.abs_diff(*r) as u64),
(Self::Int32(Some(l)), Self::Int32(Some(r))) => Some(l.abs_diff(*r) as u64),
(Self::Int64(Some(l)), Self::Int64(Some(r))) => Some(l.abs_diff(*r)),
(Self::UInt8(Some(l)), Self::UInt8(Some(r))) => Some(l.abs_diff(*r) as u64),
(Self::UInt16(Some(l)), Self::UInt16(Some(r))) => Some(l.abs_diff(*r) as u64),
(Self::UInt32(Some(l)), Self::UInt32(Some(r))) => Some(l.abs_diff(*r) as u64),
(Self::UInt64(Some(l)), Self::UInt64(Some(r))) => Some(l.abs_diff(*r)),
// TODO: we might want to look into supporting ceil/floor here for floats.
(Self::Float16(Some(l)), Self::Float16(Some(r))) => {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-blocking thought: the float arms all repeat the same pattern of taking a rounded absolute diff and then doing the checked conversion. Once the boundary check is fixed, it might be worth using a small helper like fn rounded_float_distance_u64(diff: f64) -> Option<u64> and calling it from Float16, Float32, and Float64, with f16 and f32 widened to f64. That would keep the overflow invariant in one place and make future drift less likely.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in c356c99

Since this is now a helper, I am also checking for >= 0

Some((f16::to_f32(*l) - f16::to_f32(*r)).abs().round() as _)
let diff = (f16::to_f32(*l) - f16::to_f32(*r)).abs().round();
Self::rounded_float_distance_u64(diff as f64)
}
(Self::Float32(Some(l)), Self::Float32(Some(r))) => {
Some((l - r).abs().round() as _)
let diff = (l - r).abs().round();
Self::rounded_float_distance_u64(diff as f64)
}
(Self::Float64(Some(l)), Self::Float64(Some(r))) => {
Some((l - r).abs().round() as _)
let diff = (l - r).abs().round();
Self::rounded_float_distance_u64(diff)
}
(Self::Date32(Some(l)), Self::Date32(Some(r))) => Some(l.abs_diff(*r) as _),
(Self::Date64(Some(l)), Self::Date64(Some(r))) => Some(l.abs_diff(*r) as _),
(Self::Date32(Some(l)), Self::Date32(Some(r))) => Some(l.abs_diff(*r) as u64),
(Self::Date64(Some(l)), Self::Date64(Some(r))) => Some(l.abs_diff(*r)),
// Timestamp values are stored as epoch ticks regardless of timezone
// annotation, so the distance is tz-independent (tz is display metadata).
(Self::TimestampSecond(Some(l), _), Self::TimestampSecond(Some(r), _)) => {
Some(l.abs_diff(*r) as _)
Some(l.abs_diff(*r))
}
(
Self::TimestampMillisecond(Some(l), _),
Self::TimestampMillisecond(Some(r), _),
) => Some(l.abs_diff(*r) as _),
) => Some(l.abs_diff(*r)),
(
Self::TimestampMicrosecond(Some(l), _),
Self::TimestampMicrosecond(Some(r), _),
) => Some(l.abs_diff(*r) as _),
) => Some(l.abs_diff(*r)),
(
Self::TimestampNanosecond(Some(l), _),
Self::TimestampNanosecond(Some(r), _),
) => Some(l.abs_diff(*r) as _),
) => Some(l.abs_diff(*r)),
(
Self::Decimal32(Some(l), _, lscale),
Self::Decimal32(Some(r), _, rscale),
) => {
// In order to be aligned with PartialOrd we only
// check for equal scale, ignoring precision
if lscale == rscale {
Some(l.abs_diff(*r) as u64)
} else {
None
}
}
(
Self::Decimal64(Some(l), _, lscale),
Self::Decimal64(Some(r), _, rscale),
) => {
if lscale == rscale {
Some(l.abs_diff(*r))
} else {
None
}
}
(
Self::Decimal128(Some(l), lprecision, lscale),
Self::Decimal128(Some(r), rprecision, rscale),
Self::Decimal128(Some(l), _, lscale),
Self::Decimal128(Some(r), _, rscale),
) => {
if lprecision == rprecision && lscale == rscale {
l.checked_sub(*r)?.checked_abs()?.to_usize()
if lscale == rscale {
l.checked_sub(*r)?.checked_abs()?.to_u64()
} else {
None
}
}
(
Self::Decimal256(Some(l), lprecision, lscale),
Self::Decimal256(Some(r), rprecision, rscale),
Self::Decimal256(Some(l), _, lscale),
Self::Decimal256(Some(r), _, rscale),
) => {
if lprecision == rprecision && lscale == rscale {
l.checked_sub(*r)?.checked_abs()?.to_usize()
if lscale == rscale {
l.checked_sub(*r)?.checked_abs()?.to_u64()
} else {
None
}
Expand Down Expand Up @@ -9444,8 +9489,8 @@ mod tests {
),
];
for (lhs, rhs, expected) in cases.iter() {
let distance = lhs.distance(rhs).unwrap();
assert_eq!(distance, *expected);
let distance = lhs.distance_u64(rhs).unwrap();
assert_eq!(distance, *expected as u64);
}
}

Expand All @@ -9462,7 +9507,7 @@ mod tests {
),
];
for (lhs, rhs) in cases.iter() {
let distance = lhs.distance(rhs);
let distance = lhs.distance_u64(rhs);
assert!(distance.is_none(), "{lhs} vs {rhs}");
}
}
Expand Down Expand Up @@ -9508,13 +9553,9 @@ mod tests {
ScalarValue::Decimal128(Some(123), 5, 5),
ScalarValue::Decimal128(Some(120), 5, 3),
),
(
ScalarValue::Decimal128(Some(123), 5, 5),
ScalarValue::Decimal128(Some(120), 3, 5),
),
(
ScalarValue::Decimal256(Some(123.into()), 5, 5),
ScalarValue::Decimal256(Some(120.into()), 3, 5),
ScalarValue::Decimal256(Some(120.into()), 5, 3),
),
// Distance 2 * 2^50 is larger than usize
(
Expand All @@ -9536,11 +9577,124 @@ mod tests {
),
];
for (lhs, rhs) in cases {
let distance = lhs.distance(&rhs);
let distance = lhs.distance_u64(&rhs);
assert!(distance.is_none());
}
}

#[test]
fn test_scalar_distance_u64_boundaries() {
// 1. Full-domain integer ranges
// i64::MIN to i64::MAX -> distance is u64::MAX
let lhs = ScalarValue::Int64(Some(i64::MIN));
let rhs = ScalarValue::Int64(Some(i64::MAX));
assert_eq!(lhs.distance_u64(&rhs), Some(u64::MAX));
assert_eq!(rhs.distance_u64(&lhs), Some(u64::MAX));

// u64::MIN to u64::MAX -> distance is u64::MAX
let lhs = ScalarValue::UInt64(Some(u64::MIN));
let rhs = ScalarValue::UInt64(Some(u64::MAX));
assert_eq!(lhs.distance_u64(&rhs), Some(u64::MAX));
assert_eq!(rhs.distance_u64(&lhs), Some(u64::MAX));

// 2. Decimal128 overflow edges (around u64::MAX)
// distance equal to u64::MAX fits
let lhs = ScalarValue::Decimal128(Some(0), 20, 0);
let rhs = ScalarValue::Decimal128(Some(u64::MAX as i128), 20, 0);
assert_eq!(lhs.distance_u64(&rhs), Some(u64::MAX));

// distance greater than u64::MAX overflows
let lhs = ScalarValue::Decimal128(Some(0), 20, 0);
let rhs = ScalarValue::Decimal128(Some(u64::MAX as i128 + 1), 20, 0);
assert_eq!(lhs.distance_u64(&rhs), None);

// 3. Decimal256 overflow edges (around u64::MAX)
// distance equal to u64::MAX fits
let lhs = ScalarValue::Decimal256(Some(i256::from_parts(0, 0)), 20, 0);
let rhs =
ScalarValue::Decimal256(Some(i256::from_parts(u64::MAX as u128, 0)), 20, 0);
assert_eq!(lhs.distance_u64(&rhs), Some(u64::MAX));

// distance greater than u64::MAX overflows
let lhs = ScalarValue::Decimal256(Some(i256::from_parts(0, 0)), 20, 0);
let rhs = ScalarValue::Decimal256(
Some(i256::from_parts(u64::MAX as u128 + 1, 0)),
20,
0,
);
assert_eq!(lhs.distance_u64(&rhs), None);

// 4. Float64 overflow edges (around u64::MAX)
let lhs = ScalarValue::Float64(Some(0.0));
let val: f64 = 18446744073709500000.0;
let rhs = ScalarValue::Float64(Some(val));
assert_eq!(lhs.distance_u64(&rhs), Some(18446744073709500416));

// float value > u64::MAX overflows
let rhs = ScalarValue::Float64(Some(1.9e19));
assert_eq!(lhs.distance_u64(&rhs), None);

// exact 2^64 boundary (18446744073709551616.0) is greater than u64::MAX, so it should return None
let exact_2_64_f64 = ScalarValue::Float64(Some(18446744073709551616.0));
assert_eq!(lhs.distance_u64(&exact_2_64_f64), None);

// exact 2^64 boundary as Float32 should also return None
let lhs_f32 = ScalarValue::Float32(Some(0.0));
let exact_2_64_f32 = ScalarValue::Float32(Some(18446744073709551616.0));
assert_eq!(lhs_f32.distance_u64(&exact_2_64_f32), None);

// largest float32 value below 2^64 (2^64 - 2^41 = 18446741874686296064.0) should fit
let below_2_64_f32 = ScalarValue::Float32(Some(18446741874686296064.0));
assert_eq!(
lhs_f32.distance_u64(&below_2_64_f32),
Some(18446741874686296064)
);

// Inf, NegInf, NaN
let inf = ScalarValue::Float64(Some(f64::INFINITY));
let neg_inf = ScalarValue::Float64(Some(f64::NEG_INFINITY));
let nan = ScalarValue::Float64(Some(f64::NAN));
assert_eq!(lhs.distance_u64(&inf), None);
assert_eq!(lhs.distance_u64(&neg_inf), None);
assert_eq!(lhs.distance_u64(&nan), None);

let inf_f32 = ScalarValue::Float32(Some(f32::INFINITY));
let neg_inf_f32 = ScalarValue::Float32(Some(f32::NEG_INFINITY));
let nan_f32 = ScalarValue::Float32(Some(f32::NAN));
assert_eq!(lhs_f32.distance_u64(&inf_f32), None);
assert_eq!(lhs_f32.distance_u64(&neg_inf_f32), None);
assert_eq!(lhs_f32.distance_u64(&nan_f32), None);

let lhs_f16 = ScalarValue::Float16(Some(f16::ZERO));
let inf_f16 = ScalarValue::Float16(Some(f16::INFINITY));
let neg_inf_f16 = ScalarValue::Float16(Some(f16::NEG_INFINITY));
let nan_f16 = ScalarValue::Float16(Some(f16::NAN));
assert_eq!(lhs_f16.distance_u64(&inf_f16), None);
assert_eq!(lhs_f16.distance_u64(&neg_inf_f16), None);
assert_eq!(lhs_f16.distance_u64(&nan_f16), None);

// 5. Date and Timestamp boundaries
// Date32: i32::MIN to i32::MAX
let lhs = ScalarValue::Date32(Some(i32::MIN));
let rhs = ScalarValue::Date32(Some(i32::MAX));
assert_eq!(lhs.distance_u64(&rhs), Some(u32::MAX as u64));

// TimestampSecond: i64::MIN to i64::MAX
let lhs = ScalarValue::TimestampSecond(Some(i64::MIN), None);
let rhs = ScalarValue::TimestampSecond(Some(i64::MAX), None);
assert_eq!(lhs.distance_u64(&rhs), Some(u64::MAX));

// 6. Decimal scale matching (ignoring precision)
let lhs = ScalarValue::Decimal128(Some(100), 10, 2);
let rhs = ScalarValue::Decimal128(Some(150), 15, 2);
assert_eq!(lhs.distance_u64(&rhs), Some(50));
assert_eq!(rhs.distance_u64(&lhs), Some(50));

let lhs = ScalarValue::Decimal128(Some(100), 10, 2);
let rhs = ScalarValue::Decimal128(Some(150), 10, 3);
assert_eq!(lhs.distance_u64(&rhs), None);
}

#[test]
fn test_scalar_interval_negate() {
let cases = [
Expand Down
6 changes: 3 additions & 3 deletions datafusion/common/src/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -829,8 +829,8 @@ pub fn estimate_ndv_with_overlap(
let right_min = right.min_value.get_value()?;
let right_max = right.max_value.get_value()?;

let range_left = left_max.distance(left_min)?;
let range_right = right_max.distance(right_min)?;
let range_left = left_max.distance_u64(left_min)?;
let range_right = right_max.distance_u64(right_min)?;

// Constant columns (range == 0) can't use the proportional overlap
// formula below, so check interval overlap directly instead.
Expand Down Expand Up @@ -859,7 +859,7 @@ pub fn estimate_ndv_with_overlap(
return Some(ndv_left + ndv_right);
}

let overlap_range = overlap_max.distance(overlap_min)? as f64;
let overlap_range = overlap_max.distance_u64(overlap_min)? as f64;

let overlap_left = overlap_range / range_left as f64;
let overlap_right = overlap_range / range_right as f64;
Expand Down
17 changes: 15 additions & 2 deletions datafusion/expr-common/src/interval_arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -910,10 +910,16 @@ impl Interval {
if data_type.is_integer()
|| matches!(
data_type,
DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _)
DataType::Date32
| DataType::Date64
| DataType::Timestamp(_, _)
| DataType::Decimal32(_, _)
| DataType::Decimal64(_, _)
| DataType::Decimal128(_, _)
| DataType::Decimal256(_, _)
)
{
self.upper.distance(&self.lower).map(|diff| diff as u64)
self.upper.distance_u64(&self.lower)
} else if data_type.is_floating() {
// Negative numbers are sorted in the reverse order. To
// always have a positive difference after the subtraction,
Expand Down Expand Up @@ -4157,6 +4163,13 @@ mod tests {
ScalarValue::TimestampNanosecond(Some(2_000_000_000), None),
)?;
assert_eq!(interval.cardinality().unwrap(), 1_000_000_001);

// Decimal types
let interval = Interval::try_new(
ScalarValue::Decimal128(Some(100), 10, 2),
ScalarValue::Decimal128(Some(110), 10, 2),
)?;
assert_eq!(interval.cardinality().unwrap(), 11);
Ok(())
}

Expand Down
Loading