Skip to content

Commit

Permalink
Introduce Scalar type for ColumnarValue (#12536)
Browse files Browse the repository at this point in the history
* Introduce `Scalar` type for ColumnarValue

* Add constructor constraints for `Scalar`

* Add rustdoc for `Scalar`

* Add TODO note on `ColumnarValue::cast_to`

* Add more `Scalar` rustdoc
  • Loading branch information
notfilippo authored Oct 1, 2024
1 parent 23d7fff commit 454db7e
Show file tree
Hide file tree
Showing 86 changed files with 1,300 additions and 1,036 deletions.
62 changes: 31 additions & 31 deletions datafusion-examples/examples/advanced_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ impl ScalarUDFImpl for PowUdf {
// function, but we check again to make sure
assert_eq!(args.len(), 2);
let (base, exp) = (&args[0], &args[1]);
assert_eq!(base.data_type(), DataType::Float64);
assert_eq!(exp.data_type(), DataType::Float64);
assert_eq!(base.data_type(), &DataType::Float64);
assert_eq!(exp.data_type(), &DataType::Float64);

match (base, exp) {
// For demonstration purposes we also implement the scalar / scalar
Expand All @@ -108,28 +108,31 @@ impl ScalarUDFImpl for PowUdf {
// the DataFusion expression simplification logic will often invoke
// this path once during planning, and simply use the result during
// execution.
(
ColumnarValue::Scalar(ScalarValue::Float64(base)),
ColumnarValue::Scalar(ScalarValue::Float64(exp)),
) => {
// compute the output. Note DataFusion treats `None` as NULL.
let res = match (base, exp) {
(Some(base), Some(exp)) => Some(base.powf(*exp)),
// one or both arguments were NULL
_ => None,
};
Ok(ColumnarValue::Scalar(ScalarValue::from(res)))
(ColumnarValue::Scalar(base), ColumnarValue::Scalar(exp)) => {
match (base.value(), exp.value()) {
(ScalarValue::Float64(base), ScalarValue::Float64(exp)) => {
// compute the output. Note DataFusion treats `None` as NULL.
let res = match (base, exp) {
(Some(base), Some(exp)) => Some(base.powf(*exp)),
// one or both arguments were NULL
_ => None,
};
Ok(ColumnarValue::from(ScalarValue::from(res)))
}
_ => {
internal_err!("Invalid argument types to pow function")
}
}
}
// special case if the exponent is a constant
(
ColumnarValue::Array(base_array),
ColumnarValue::Scalar(ScalarValue::Float64(exp)),
) => {
let result_array = match exp {
(ColumnarValue::Array(base_array), ColumnarValue::Scalar(exp)) => {
let result_array = match exp.value() {
// a ^ null = null
None => new_null_array(base_array.data_type(), base_array.len()),
ScalarValue::Float64(None) => {
new_null_array(base_array.data_type(), base_array.len())
}
// a ^ exp
Some(exp) => {
ScalarValue::Float64(Some(exp)) => {
// DataFusion has ensured both arguments are Float64:
let base_array = base_array.as_primitive::<Float64Type>();
// calculate the result for every row. The `unary`
Expand All @@ -139,24 +142,25 @@ impl ScalarUDFImpl for PowUdf {
compute::unary(base_array, |base| base.powf(*exp));
Arc::new(res)
}
_ => return internal_err!("Invalid argument types to pow function"),
};
Ok(ColumnarValue::Array(result_array))
}

// special case if the base is a constant (note this code is quite
// similar to the previous case, so we omit comments)
(
ColumnarValue::Scalar(ScalarValue::Float64(base)),
ColumnarValue::Array(exp_array),
) => {
let res = match base {
None => new_null_array(exp_array.data_type(), exp_array.len()),
Some(base) => {
(ColumnarValue::Scalar(base), ColumnarValue::Array(exp_array)) => {
let res = match base.value() {
ScalarValue::Float64(None) => {
new_null_array(exp_array.data_type(), exp_array.len())
}
ScalarValue::Float64(Some(base)) => {
let exp_array = exp_array.as_primitive::<Float64Type>();
let res: Float64Array =
compute::unary(exp_array, |exp| base.powf(exp));
Arc::new(res)
}
_ => return internal_err!("Invalid argument types to pow function"),
};
Ok(ColumnarValue::Array(res))
}
Expand All @@ -169,10 +173,6 @@ impl ScalarUDFImpl for PowUdf {
)?;
Ok(ColumnarValue::Array(Arc::new(res)))
}
// if the types were not float, it is a bug in DataFusion
_ => {
internal_err!("Invalid argument types to pow function")
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion-examples/examples/optimizer_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ impl ScalarUDFImpl for MyEq {
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
// this example simply returns "true" which is not what a real
// implementation would do.
Ok(ColumnarValue::Scalar(ScalarValue::from(true)))
Ok(ColumnarValue::from(ScalarValue::from(true)))
}
}

Expand Down
24 changes: 23 additions & 1 deletion datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use crate::hash_utils::create_hashes;
use crate::utils::{
array_into_fixed_size_list_array, array_into_large_list_array, array_into_list_array,
};
use arrow::compute::kernels::numeric::*;
use arrow::compute::kernels::{self, numeric::*};
use arrow::util::display::{array_value_to_string, ArrayFormatter, FormatOptions};
use arrow::{
array::*,
Expand Down Expand Up @@ -1704,6 +1704,18 @@ impl ScalarValue {
Some(sv) => sv.data_type(),
};

Self::iter_to_array_of_type(scalars, &data_type)
}

/// Same as [`Self::iter_to_array`] but the target `data_type` can be
/// manually specified instead of being implicitly derived from the type of
/// the first value of the iterator.
pub fn iter_to_array_of_type(
scalars: impl IntoIterator<Item = ScalarValue>,
data_type: &DataType,
) -> Result<ArrayRef> {
let mut scalars = scalars.into_iter().peekable();

/// Creates an array of $ARRAY_TY by unpacking values of
/// SCALAR_TY for primitive types
macro_rules! build_array_primitive {
Expand Down Expand Up @@ -2179,6 +2191,16 @@ impl ScalarValue {
Arc::new(array_into_large_list_array(values))
}

pub fn to_array_of_size_and_type(
&self,
size: usize,
target_type: &DataType,
) -> Result<ArrayRef> {
let array = self.to_array_of_size(size)?;
let cast_array = kernels::cast::cast(&array, target_type)?;
Ok(cast_array)
}

/// Converts a scalar value into an array of `size` rows.
///
/// # Errors
Expand Down
18 changes: 10 additions & 8 deletions datafusion/core/src/physical_optimizer/pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -687,14 +687,16 @@ impl BoolVecBuilder {
ColumnarValue::Array(array) => {
self.combine_array(array.as_boolean());
}
ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))) => {
// False means all containers can not pass the predicate
self.inner = vec![false; self.inner.len()];
}
_ => {
// Null or true means the rows in container may pass this
// conjunct so we can't prune any containers based on that
}
ColumnarValue::Scalar(scalar) => match scalar.value() {
ScalarValue::Boolean(Some(false)) => {
// False means all containers can not pass the predicate
self.inner = vec![false; self.inner.len()];
}
_ => {
// Null or true means the rows in container may pass this
// conjunct so we can't prune any containers based on that
}
},
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ impl ScalarUDFImpl for Simple0ArgsScalarUDF {
}

fn invoke_no_args(&self, _number_rows: usize) -> Result<ColumnarValue> {
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(100))))
Ok(ColumnarValue::from(ScalarValue::Int32(Some(100))))
}
}

Expand Down Expand Up @@ -323,7 +323,7 @@ async fn scalar_udf_override_built_in_scalar_function() -> Result<()> {
vec![DataType::Int32],
DataType::Int32,
Volatility::Immutable,
Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(1))))),
Arc::new(move |_| Ok(ColumnarValue::from(ScalarValue::Int32(Some(1))))),
));

// Make sure that the UDF is used instead of the built-in function
Expand Down Expand Up @@ -669,7 +669,10 @@ impl ScalarUDFImpl for TakeUDF {
// The actual implementation
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let take_idx = match &args[2] {
ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) if v < &2 => *v as usize,
ColumnarValue::Scalar(scalar) => match scalar.value() {
ScalarValue::Int64(Some(v)) if v < &2 => *v as usize,
_ => unreachable!(),
},
_ => unreachable!(),
};
match &args[take_idx] {
Expand Down Expand Up @@ -1070,19 +1073,20 @@ impl ScalarUDFImpl for MyRegexUdf {

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
match args {
[ColumnarValue::Scalar(ScalarValue::Utf8(value))] => {
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(
self.matches(value.as_deref()),
)))
}
[ColumnarValue::Scalar(scalar)] => match scalar.value() {
ScalarValue::Utf8(value) => Ok(ColumnarValue::from(
ScalarValue::Boolean(self.matches(value.as_deref())),
)),
_ => exec_err!("regex_udf only accepts a Utf8 arguments"),
},
[ColumnarValue::Array(values)] => {
let mut builder = BooleanBuilder::with_capacity(values.len());
for value in values.as_string::<i32>() {
builder.append_option(self.matches(value))
}
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
_ => exec_err!("regex_udf only accepts a Utf8 arguments"),
_ => unreachable!(),
}
}

Expand Down
27 changes: 16 additions & 11 deletions datafusion/expr-common/src/columnar_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ use datafusion_common::format::DEFAULT_CAST_OPTIONS;
use datafusion_common::{internal_err, Result, ScalarValue};
use std::sync::Arc;

use crate::scalar::Scalar;

/// The result of evaluating an expression.
///
/// [`ColumnarValue::Scalar`] represents a single value repeated any number of
Expand Down Expand Up @@ -89,7 +91,7 @@ pub enum ColumnarValue {
/// Array of values
Array(ArrayRef),
/// A single value
Scalar(ScalarValue),
Scalar(Scalar),
}

impl From<ArrayRef> for ColumnarValue {
Expand All @@ -100,14 +102,14 @@ impl From<ArrayRef> for ColumnarValue {

impl From<ScalarValue> for ColumnarValue {
fn from(value: ScalarValue) -> Self {
ColumnarValue::Scalar(value)
ColumnarValue::Scalar(value.into())
}
}

impl ColumnarValue {
pub fn data_type(&self) -> DataType {
pub fn data_type(&self) -> &DataType {
match self {
ColumnarValue::Array(array_value) => array_value.data_type().clone(),
ColumnarValue::Array(array_value) => array_value.data_type(),
ColumnarValue::Scalar(scalar_value) => scalar_value.data_type(),
}
}
Expand Down Expand Up @@ -195,9 +197,12 @@ impl ColumnarValue {
kernels::cast::cast_with_options(array, cast_type, &cast_options)?,
)),
ColumnarValue::Scalar(scalar) => {
// TODO(@notfilippo, logical vs physical): if `scalar.data_type` is *logically equivalent*
// to `cast_type` then skip the kernel cast and only change the `data_type` of the scalar.

let scalar_array =
if cast_type == &DataType::Timestamp(TimeUnit::Nanosecond, None) {
if let ScalarValue::Float64(Some(float_ts)) = scalar {
if let ScalarValue::Float64(Some(float_ts)) = scalar.value() {
ScalarValue::Int64(Some(
(float_ts * 1_000_000_000_f64).trunc() as i64,
))
Expand All @@ -213,7 +218,7 @@ impl ColumnarValue {
cast_type,
&cast_options,
)?;
let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?;
let cast_scalar = Scalar::try_from_array(&cast_array, 0)?;
Ok(ColumnarValue::Scalar(cast_scalar))
}
}
Expand Down Expand Up @@ -250,7 +255,7 @@ mod tests {
TestCase {
input: vec![
ColumnarValue::Array(make_array(1, 3)),
ColumnarValue::Scalar(ScalarValue::Int32(Some(100))),
ColumnarValue::from(ScalarValue::Int32(Some(100))),
],
expected: vec![
make_array(1, 3),
Expand All @@ -260,7 +265,7 @@ mod tests {
// scalar and array
TestCase {
input: vec![
ColumnarValue::Scalar(ScalarValue::Int32(Some(100))),
ColumnarValue::from(ScalarValue::Int32(Some(100))),
ColumnarValue::Array(make_array(1, 3)),
],
expected: vec![
Expand All @@ -271,9 +276,9 @@ mod tests {
// multiple scalars and array
TestCase {
input: vec![
ColumnarValue::Scalar(ScalarValue::Int32(Some(100))),
ColumnarValue::from(ScalarValue::Int32(Some(100))),
ColumnarValue::Array(make_array(1, 3)),
ColumnarValue::Scalar(ScalarValue::Int32(Some(200))),
ColumnarValue::from(ScalarValue::Int32(Some(200))),
],
expected: vec![
make_array(100, 3), // scalar is expanded
Expand Down Expand Up @@ -306,7 +311,7 @@ mod tests {
fn values_to_arrays_mixed_length_and_scalar() {
ColumnarValue::values_to_arrays(&[
ColumnarValue::Array(make_array(1, 3)),
ColumnarValue::Scalar(ScalarValue::Int32(Some(100))),
ColumnarValue::from(ScalarValue::Int32(Some(100))),
ColumnarValue::Array(make_array(2, 7)),
])
.unwrap();
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr-common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub mod columnar_value;
pub mod groups_accumulator;
pub mod interval_arithmetic;
pub mod operator;
pub mod scalar;
pub mod signature;
pub mod sort_properties;
pub mod type_coercion;
Loading

0 comments on commit 454db7e

Please sign in to comment.