From d8e4e92daf7f20eef9af6919a8061192f7505043 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 16 Oct 2024 16:15:27 -0600 Subject: [PATCH 001/110] docs: Add documentation about conventional commits (#12971) * add documentation about conventional commits * prettier --- docs/source/contributor-guide/index.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docs/source/contributor-guide/index.md b/docs/source/contributor-guide/index.md index 79a929879833..4645fe5c8804 100644 --- a/docs/source/contributor-guide/index.md +++ b/docs/source/contributor-guide/index.md @@ -116,6 +116,20 @@ If you are concerned that a larger design will be lost in a string of small PRs, Note all commits in a PR are squashed when merged to the `main` branch so there is one commit per PR after merge. +## Conventional Commits & Labeling PRs + +We generate change logs for each release using an automated process that will categorize PRs based on the title +and/or the GitHub labels attached to the PR. + +We follow the [Conventional Commits] specification to categorize PRs based on the title. This most often simply means +looking for titles starting with prefixes such as `fix:`, `feat:`, `docs:`, or `chore:`. We do not enforce this +convention but encourage its use if you want your PR to feature in the correct section of the changelog. + +The change log generator will also look at GitHub labels such as `bug`, `enhancement`, or `api change`, and labels +do take priority over the conventional commit approach, allowing maintainers to re-categorize PRs after they have been merged. + +[conventional commits]: https://www.conventionalcommits.org/en/v1.0.0/ + # Reviewing Pull Requests Some helpful links: From 3d1d28d287d6584668bde510908f65ebe262d22e Mon Sep 17 00:00:00 2001 From: peasee <98815791+peasee@users.noreply.github.com> Date: Thu, 17 Oct 2024 21:33:40 +1000 Subject: [PATCH 002/110] fix: Add Int32 type override for Dialects (#12916) * fix: Add Int32 type override for Dialects * fix: Dialect builder with_int32_cast_dtype: * test: Fix with_int32 test --- datafusion/sql/src/unparser/dialect.rs | 25 +++++++++++++++++++++ datafusion/sql/src/unparser/expr.rs | 30 +++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index aef3b0dfabbc..cfc28f2c499f 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -86,6 +86,12 @@ pub trait Dialect: Send + Sync { ast::DataType::BigInt(None) } + /// The SQL type to use for Arrow Int32 unparsing + /// Most dialects use Integer, but some, like MySQL, require SIGNED + fn int32_cast_dtype(&self) -> ast::DataType { + ast::DataType::Integer(None) + } + /// The SQL type to use for Timestamp unparsing /// Most dialects use Timestamp, but some, like MySQL, require Datetime /// Some dialects like Dremio does not support WithTimeZone and requires always Timestamp @@ -282,6 +288,10 @@ impl Dialect for MySqlDialect { ast::DataType::Custom(ObjectName(vec![Ident::new("SIGNED")]), vec![]) } + fn int32_cast_dtype(&self) -> ast::DataType { + ast::DataType::Custom(ObjectName(vec![Ident::new("SIGNED")]), vec![]) + } + fn timestamp_cast_dtype( &self, _time_unit: &TimeUnit, @@ -347,6 +357,7 @@ pub struct CustomDialect { large_utf8_cast_dtype: ast::DataType, date_field_extract_style: DateFieldExtractStyle, int64_cast_dtype: ast::DataType, + int32_cast_dtype: ast::DataType, timestamp_cast_dtype: ast::DataType, timestamp_tz_cast_dtype: ast::DataType, date32_cast_dtype: sqlparser::ast::DataType, @@ -365,6 +376,7 @@ impl Default for CustomDialect { large_utf8_cast_dtype: ast::DataType::Text, date_field_extract_style: DateFieldExtractStyle::DatePart, int64_cast_dtype: ast::DataType::BigInt(None), + int32_cast_dtype: ast::DataType::Integer(None), timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None), timestamp_tz_cast_dtype: ast::DataType::Timestamp( None, @@ -424,6 +436,10 @@ impl Dialect for CustomDialect { self.int64_cast_dtype.clone() } + fn int32_cast_dtype(&self) -> ast::DataType { + self.int32_cast_dtype.clone() + } + fn timestamp_cast_dtype( &self, _time_unit: &TimeUnit, @@ -482,6 +498,7 @@ pub struct CustomDialectBuilder { large_utf8_cast_dtype: ast::DataType, date_field_extract_style: DateFieldExtractStyle, int64_cast_dtype: ast::DataType, + int32_cast_dtype: ast::DataType, timestamp_cast_dtype: ast::DataType, timestamp_tz_cast_dtype: ast::DataType, date32_cast_dtype: ast::DataType, @@ -506,6 +523,7 @@ impl CustomDialectBuilder { large_utf8_cast_dtype: ast::DataType::Text, date_field_extract_style: DateFieldExtractStyle::DatePart, int64_cast_dtype: ast::DataType::BigInt(None), + int32_cast_dtype: ast::DataType::Integer(None), timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None), timestamp_tz_cast_dtype: ast::DataType::Timestamp( None, @@ -527,6 +545,7 @@ impl CustomDialectBuilder { large_utf8_cast_dtype: self.large_utf8_cast_dtype, date_field_extract_style: self.date_field_extract_style, int64_cast_dtype: self.int64_cast_dtype, + int32_cast_dtype: self.int32_cast_dtype, timestamp_cast_dtype: self.timestamp_cast_dtype, timestamp_tz_cast_dtype: self.timestamp_tz_cast_dtype, date32_cast_dtype: self.date32_cast_dtype, @@ -604,6 +623,12 @@ impl CustomDialectBuilder { self } + /// Customize the dialect with a specific SQL type for Int32 casting: Integer, SIGNED, etc. + pub fn with_int32_cast_dtype(mut self, int32_cast_dtype: ast::DataType) -> Self { + self.int32_cast_dtype = int32_cast_dtype; + self + } + /// Customize the dialect with a specific SQL type for Timestamp casting: Timestamp, Datetime, etc. pub fn with_timestamp_cast_dtype( mut self, diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index b7491d1f88ce..1be5aa68bfba 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1352,7 +1352,7 @@ impl Unparser<'_> { DataType::Boolean => Ok(ast::DataType::Bool), DataType::Int8 => Ok(ast::DataType::TinyInt(None)), DataType::Int16 => Ok(ast::DataType::SmallInt(None)), - DataType::Int32 => Ok(ast::DataType::Integer(None)), + DataType::Int32 => Ok(self.dialect.int32_cast_dtype()), DataType::Int64 => Ok(self.dialect.int64_cast_dtype()), DataType::UInt8 => Ok(ast::DataType::UnsignedTinyInt(None)), DataType::UInt16 => Ok(ast::DataType::UnsignedSmallInt(None)), @@ -2253,6 +2253,34 @@ mod tests { Ok(()) } + #[test] + fn custom_dialect_with_int32_cast_dtype() -> Result<()> { + let default_dialect = CustomDialectBuilder::new().build(); + let mysql_dialect = CustomDialectBuilder::new() + .with_int32_cast_dtype(ast::DataType::Custom( + ObjectName(vec![Ident::new("SIGNED")]), + vec![], + )) + .build(); + + for (dialect, identifier) in + [(default_dialect, "INTEGER"), (mysql_dialect, "SIGNED")] + { + let unparser = Unparser::new(&dialect); + let expr = Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: DataType::Int32, + }); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + let expected = format!(r#"CAST(a AS {identifier})"#); + + assert_eq!(actual, expected); + } + Ok(()) + } + #[test] fn custom_dialect_with_timestamp_cast_dtype() -> Result<()> { let default_dialect = CustomDialectBuilder::new().build(); From 0e2023d044eef862dec210a30051fd8dd9430f00 Mon Sep 17 00:00:00 2001 From: zhuliquan Date: Thu, 17 Oct 2024 22:09:54 +0800 Subject: [PATCH 003/110] fix: using simple string match replace regex match for contains udf (#12931) * fix: using simple string match replace regex match * doc: update doc of contains * test: add case for contains udf --------- Co-authored-by: zhuliquan --- datafusion/functions/src/string/contains.rs | 55 +++++++++++++-------- 1 file changed, 34 insertions(+), 21 deletions(-) diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index 0f75731aa1c3..86f1eda03342 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -16,8 +16,8 @@ // under the License. use crate::utils::make_scalar_function; -use arrow::array::{Array, ArrayRef, AsArray, GenericStringArray, StringViewArray}; -use arrow::compute::regexp_is_match; +use arrow::array::{Array, ArrayRef, AsArray}; +use arrow::compute::contains as arrow_contains; use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Boolean, LargeUtf8, Utf8, Utf8View}; use datafusion_common::exec_err; @@ -102,40 +102,25 @@ fn get_contains_doc() -> &'static Documentation { }) } -/// use regexp_is_match_utf8_scalar to do the calculation for contains +/// use `arrow::compute::contains` to do the calculation for contains pub fn contains(args: &[ArrayRef]) -> Result { match (args[0].data_type(), args[1].data_type()) { (Utf8View, Utf8View) => { let mod_str = args[0].as_string_view(); let match_str = args[1].as_string_view(); - let res = regexp_is_match::< - StringViewArray, - StringViewArray, - GenericStringArray, - >(mod_str, match_str, None)?; - + let res = arrow_contains(mod_str, match_str)?; Ok(Arc::new(res) as ArrayRef) } (Utf8, Utf8) => { let mod_str = args[0].as_string::(); let match_str = args[1].as_string::(); - let res = regexp_is_match::< - GenericStringArray, - GenericStringArray, - GenericStringArray, - >(mod_str, match_str, None)?; - + let res = arrow_contains(mod_str, match_str)?; Ok(Arc::new(res) as ArrayRef) } (LargeUtf8, LargeUtf8) => { let mod_str = args[0].as_string::(); let match_str = args[1].as_string::(); - let res = regexp_is_match::< - GenericStringArray, - GenericStringArray, - GenericStringArray, - >(mod_str, match_str, None)?; - + let res = arrow_contains(mod_str, match_str)?; Ok(Arc::new(res) as ArrayRef) } other => { @@ -143,3 +128,31 @@ pub fn contains(args: &[ArrayRef]) -> Result { } } } + +#[cfg(test)] +mod test { + use super::ContainsFunc; + use arrow::array::{BooleanArray, StringArray}; + use datafusion_common::ScalarValue; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use std::sync::Arc; + + #[test] + fn test_contains_udf() { + let udf = ContainsFunc::new(); + let array = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("xxx?()"), + Some("yyy?()"), + ]))); + let scalar = ColumnarValue::Scalar(ScalarValue::Utf8(Some("x?(".to_string()))); + let actual = udf.invoke(&[array, scalar]).unwrap(); + let expect = ColumnarValue::Array(Arc::new(BooleanArray::from(vec![ + Some(true), + Some(false), + ]))); + assert_eq!( + *actual.into_array(2).unwrap(), + *expect.into_array(2).unwrap() + ); + } +} From 56946b4d5df89f6ac3f07e06591e909aa2942e4e Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Thu, 17 Oct 2024 16:10:14 +0200 Subject: [PATCH 004/110] Increase minimum supported Rust version (MSRV) to 1.79 (#12962) Current goal is to support four last stable versions or versions for 4 months whichever is lower. Given 1.78.0 was released on: 2 May, 2024, it does not need to be supported. --- .github/workflows/rust.yml | 4 ++-- Cargo.toml | 2 +- datafusion-cli/Cargo.toml | 2 +- datafusion-cli/Dockerfile | 2 +- datafusion/core/Cargo.toml | 2 +- datafusion/proto-common/Cargo.toml | 2 +- datafusion/proto-common/gen/Cargo.toml | 2 +- datafusion/proto/Cargo.toml | 2 +- datafusion/proto/gen/Cargo.toml | 2 +- datafusion/substrait/Cargo.toml | 2 +- 10 files changed, 11 insertions(+), 11 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 4527d047e4c0..39b7b2b17857 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -574,9 +574,9 @@ jobs: # # To reproduce: # 1. Install the version of Rust that is failing. Example: - # rustup install 1.78.0 + # rustup install 1.79.0 # 2. Run the command that failed with that version. Example: - # cargo +1.78.0 check -p datafusion + # cargo +1.79.0 check -p datafusion # # To resolve, either: # 1. Change your code to use older Rust features, diff --git a/Cargo.toml b/Cargo.toml index 448607257ca1..2c142c87c892 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,7 +58,7 @@ homepage = "https://datafusion.apache.org" license = "Apache-2.0" readme = "README.md" repository = "https://github.com/apache/datafusion" -rust-version = "1.78" +rust-version = "1.79" version = "42.0.0" [workspace.dependencies] diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index b86dbd2a3802..fe929495aae6 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -26,7 +26,7 @@ license = "Apache-2.0" homepage = "https://datafusion.apache.org" repository = "https://github.com/apache/datafusion" # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.78" +rust-version = "1.79" readme = "README.md" [dependencies] diff --git a/datafusion-cli/Dockerfile b/datafusion-cli/Dockerfile index 7adead64db57..79c24f6baf3e 100644 --- a/datafusion-cli/Dockerfile +++ b/datafusion-cli/Dockerfile @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -FROM rust:1.78-bookworm AS builder +FROM rust:1.79-bookworm AS builder COPY . /usr/src/datafusion COPY ./datafusion /usr/src/datafusion/datafusion diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 28d0d136bd05..8c4ad80e2924 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -30,7 +30,7 @@ authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version and fails with # "Unable to find key 'package.rust-version' (or 'package.metadata.msrv') in 'arrow-datafusion/Cargo.toml'" # https://github.com/foresterre/cargo-msrv/issues/590 -rust-version = "1.78" +rust-version = "1.79" [lints] workspace = true diff --git a/datafusion/proto-common/Cargo.toml b/datafusion/proto-common/Cargo.toml index 5051c8f9322f..6c53e1b1ced0 100644 --- a/datafusion/proto-common/Cargo.toml +++ b/datafusion/proto-common/Cargo.toml @@ -26,7 +26,7 @@ homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } authors = { workspace = true } -rust-version = "1.78" +rust-version = "1.79" # Exclude proto files so crates.io consumers don't need protoc exclude = ["*.proto"] diff --git a/datafusion/proto-common/gen/Cargo.toml b/datafusion/proto-common/gen/Cargo.toml index 0914669f82fa..6e5783f467a7 100644 --- a/datafusion/proto-common/gen/Cargo.toml +++ b/datafusion/proto-common/gen/Cargo.toml @@ -20,7 +20,7 @@ name = "gen-common" description = "Code generation for proto" version = "0.1.0" edition = { workspace = true } -rust-version = "1.78" +rust-version = "1.79" authors = { workspace = true } homepage = { workspace = true } repository = { workspace = true } diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index d65c6ccaa660..3ffe5e3e76e7 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -27,7 +27,7 @@ repository = { workspace = true } license = { workspace = true } authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.78" +rust-version = "1.79" # Exclude proto files so crates.io consumers don't need protoc exclude = ["*.proto"] diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml index ea28ac86e8df..aee8fac4a120 100644 --- a/datafusion/proto/gen/Cargo.toml +++ b/datafusion/proto/gen/Cargo.toml @@ -20,7 +20,7 @@ name = "gen" description = "Code generation for proto" version = "0.1.0" edition = { workspace = true } -rust-version = "1.78" +rust-version = "1.79" authors = { workspace = true } homepage = { workspace = true } repository = { workspace = true } diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index 6f8f81401f3b..41755018284e 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -26,7 +26,7 @@ repository = { workspace = true } license = { workspace = true } authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.78" +rust-version = "1.79" [lints] workspace = true From e63abe78f54cdbbba7ed92d65400525eeae59b71 Mon Sep 17 00:00:00 2001 From: Tornike Gurgenidze Date: Thu, 17 Oct 2024 21:21:49 +0400 Subject: [PATCH 005/110] feat(substrait): add set operations to consumer, update substrait to `0.45.0` (#12863) * feat(substait): add set operations to consumer * add missing intersect all test, change distinct to is_all * upgrade substrait crate to 0.45 --- datafusion/substrait/Cargo.toml | 2 +- .../substrait/src/logical_plan/consumer.rs | 120 +++++++++++-- .../substrait/src/logical_plan/producer.rs | 12 +- .../tests/cases/roundtrip_logical_plan.rs | 66 +++++++ .../intersect_multiset.substrait.json | 166 ++++++++++++++++++ .../intersect_multiset_all.substrait.json | 166 ++++++++++++++++++ .../intersect_primary.substrait.json | 166 ++++++++++++++++++ .../test_plans/minus_primary.substrait.json | 166 ++++++++++++++++++ .../minus_primary_all.substrait.json | 166 ++++++++++++++++++ .../test_plans/union_distinct.substrait.json | 118 +++++++++++++ datafusion/substrait/tests/utils.rs | 1 + 11 files changed, 1136 insertions(+), 13 deletions(-) create mode 100644 datafusion/substrait/tests/testdata/test_plans/intersect_multiset.substrait.json create mode 100644 datafusion/substrait/tests/testdata/test_plans/intersect_multiset_all.substrait.json create mode 100644 datafusion/substrait/tests/testdata/test_plans/intersect_primary.substrait.json create mode 100644 datafusion/substrait/tests/testdata/test_plans/minus_primary.substrait.json create mode 100644 datafusion/substrait/tests/testdata/test_plans/minus_primary_all.substrait.json create mode 100644 datafusion/substrait/tests/testdata/test_plans/union_distinct.substrait.json diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index 41755018284e..b0aa6acf3c7c 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -41,7 +41,7 @@ object_store = { workspace = true } pbjson-types = "0.7" # TODO use workspace version prost = "0.13" -substrait = { version = "0.42", features = ["serde"] } +substrait = { version = "0.45", features = ["serde"] } url = { workspace = true } [dev-dependencies] diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index c727f784ee01..4af02858e65a 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -196,6 +196,65 @@ fn split_eq_and_noneq_join_predicate_with_nulls_equality( (accum_join_keys, nulls_equal_nulls, join_filter) } +async fn union_rels( + rels: &[Rel], + ctx: &SessionContext, + extensions: &Extensions, + is_all: bool, +) -> Result { + let mut union_builder = Ok(LogicalPlanBuilder::from( + from_substrait_rel(ctx, &rels[0], extensions).await?, + )); + for input in &rels[1..] { + let rel_plan = from_substrait_rel(ctx, input, extensions).await?; + + union_builder = if is_all { + union_builder?.union(rel_plan) + } else { + union_builder?.union_distinct(rel_plan) + }; + } + union_builder?.build() +} + +async fn intersect_rels( + rels: &[Rel], + ctx: &SessionContext, + extensions: &Extensions, + is_all: bool, +) -> Result { + let mut rel = from_substrait_rel(ctx, &rels[0], extensions).await?; + + for input in &rels[1..] { + rel = LogicalPlanBuilder::intersect( + rel, + from_substrait_rel(ctx, input, extensions).await?, + is_all, + )? + } + + Ok(rel) +} + +async fn except_rels( + rels: &[Rel], + ctx: &SessionContext, + extensions: &Extensions, + is_all: bool, +) -> Result { + let mut rel = from_substrait_rel(ctx, &rels[0], extensions).await?; + + for input in &rels[1..] { + rel = LogicalPlanBuilder::except( + rel, + from_substrait_rel(ctx, input, extensions).await?, + is_all, + )? + } + + Ok(rel) +} + /// Convert Substrait Plan to DataFusion LogicalPlan pub async fn from_substrait_plan( ctx: &SessionContext, @@ -494,6 +553,7 @@ fn make_renamed_schema( } /// Convert Substrait Rel to DataFusion DataFrame +#[allow(deprecated)] #[async_recursion] pub async fn from_substrait_rel( ctx: &SessionContext, @@ -877,27 +937,65 @@ pub async fn from_substrait_rel( Ok(set_op) => match set_op { set_rel::SetOp::UnionAll => { if !set.inputs.is_empty() { - let mut union_builder = Ok(LogicalPlanBuilder::from( - from_substrait_rel(ctx, &set.inputs[0], extensions).await?, - )); - for input in &set.inputs[1..] { - union_builder = union_builder? - .union(from_substrait_rel(ctx, input, extensions).await?); - } - union_builder?.build() + union_rels(&set.inputs, ctx, extensions, true).await + } else { + not_impl_err!("Union relation requires at least one input") + } + } + set_rel::SetOp::UnionDistinct => { + if !set.inputs.is_empty() { + union_rels(&set.inputs, ctx, extensions, false).await } else { not_impl_err!("Union relation requires at least one input") } } set_rel::SetOp::IntersectionPrimary => { - if set.inputs.len() == 2 { + if set.inputs.len() >= 2 { LogicalPlanBuilder::intersect( from_substrait_rel(ctx, &set.inputs[0], extensions).await?, - from_substrait_rel(ctx, &set.inputs[1], extensions).await?, + union_rels(&set.inputs[1..], ctx, extensions, true).await?, false, ) } else { - not_impl_err!("Primary Intersect relation with more than two inputs isn't supported") + not_impl_err!( + "Primary Intersect relation requires at least two inputs" + ) + } + } + set_rel::SetOp::IntersectionMultiset => { + if set.inputs.len() >= 2 { + intersect_rels(&set.inputs, ctx, extensions, false).await + } else { + not_impl_err!( + "Multiset Intersect relation requires at least two inputs" + ) + } + } + set_rel::SetOp::IntersectionMultisetAll => { + if set.inputs.len() >= 2 { + intersect_rels(&set.inputs, ctx, extensions, true).await + } else { + not_impl_err!( + "MultisetAll Intersect relation requires at least two inputs" + ) + } + } + set_rel::SetOp::MinusPrimary => { + if set.inputs.len() >= 2 { + except_rels(&set.inputs, ctx, extensions, false).await + } else { + not_impl_err!( + "Primary Minus relation requires at least two inputs" + ) + } + } + set_rel::SetOp::MinusPrimaryAll => { + if set.inputs.len() >= 2 { + except_rels(&set.inputs, ctx, extensions, true).await + } else { + not_impl_err!( + "PrimaryAll Minus relation requires at least two inputs" + ) } } _ => not_impl_err!("Unsupported set operator: {set_op:?}"), diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 1165ce13d236..0e1375a8e0ea 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -172,6 +172,7 @@ pub fn to_substrait_extended_expr( } /// Convert DataFusion LogicalPlan to Substrait Rel +#[allow(deprecated)] pub fn to_substrait_rel( plan: &LogicalPlan, ctx: &SessionContext, @@ -227,6 +228,7 @@ pub fn to_substrait_rel( advanced_extension: None, read_type: Some(ReadType::VirtualTable(VirtualTable { values: vec![], + expressions: vec![], })), }))), })) @@ -263,7 +265,10 @@ pub fn to_substrait_rel( best_effort_filter: None, projection: None, advanced_extension: None, - read_type: Some(ReadType::VirtualTable(VirtualTable { values })), + read_type: Some(ReadType::VirtualTable(VirtualTable { + values, + expressions: vec![], + })), }))), })) } @@ -359,6 +364,7 @@ pub fn to_substrait_rel( rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { common: None, input: Some(input), + grouping_expressions: vec![], groupings, measures, advanced_extension: None, @@ -377,8 +383,10 @@ pub fn to_substrait_rel( rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { common: None, input: Some(input), + grouping_expressions: vec![], groupings: vec![Grouping { grouping_expressions: grouping, + expression_references: vec![], }], measures: vec![], advanced_extension: None, @@ -764,6 +772,7 @@ pub fn operator_to_name(op: Operator) -> &'static str { } } +#[allow(deprecated)] pub fn parse_flat_grouping_exprs( ctx: &SessionContext, exprs: &[Expr], @@ -776,6 +785,7 @@ pub fn parse_flat_grouping_exprs( .collect::>>()?; Ok(Grouping { grouping_expressions, + expression_references: vec![], }) } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index ae87dad0153e..23ac601a44ec 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -687,6 +687,72 @@ async fn simple_intersect_consume() -> Result<()> { .await } +#[tokio::test] +async fn primary_intersect_consume() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/intersect_primary.substrait.json"); + + assert_substrait_sql( + proto_plan, + "SELECT a FROM data INTERSECT (SELECT a FROM data2 UNION ALL SELECT a FROM data2)", + ) + .await +} + +#[tokio::test] +async fn multiset_intersect_consume() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/intersect_multiset.substrait.json"); + + assert_substrait_sql( + proto_plan, + "SELECT a FROM data INTERSECT SELECT a FROM data2 INTERSECT SELECT a FROM data2", + ) + .await +} + +#[tokio::test] +async fn multiset_intersect_all_consume() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/intersect_multiset_all.substrait.json"); + + assert_substrait_sql( + proto_plan, + "SELECT a FROM data INTERSECT ALL SELECT a FROM data2 INTERSECT ALL SELECT a FROM data2", + ) + .await +} + +#[tokio::test] +async fn primary_except_consume() -> Result<()> { + let proto_plan = read_json("tests/testdata/test_plans/minus_primary.substrait.json"); + + assert_substrait_sql( + proto_plan, + "SELECT a FROM data EXCEPT SELECT a FROM data2 EXCEPT SELECT a FROM data2", + ) + .await +} + +#[tokio::test] +async fn primary_except_all_consume() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/minus_primary_all.substrait.json"); + + assert_substrait_sql( + proto_plan, + "SELECT a FROM data EXCEPT ALL SELECT a FROM data2 EXCEPT ALL SELECT a FROM data2", + ) + .await +} + +#[tokio::test] +async fn union_distinct_consume() -> Result<()> { + let proto_plan = read_json("tests/testdata/test_plans/union_distinct.substrait.json"); + + assert_substrait_sql(proto_plan, "SELECT a FROM data UNION SELECT a FROM data2").await +} + #[tokio::test] async fn simple_intersect_table_reuse() -> Result<()> { // Substrait does currently NOT maintain the alias of the tables. diff --git a/datafusion/substrait/tests/testdata/test_plans/intersect_multiset.substrait.json b/datafusion/substrait/tests/testdata/test_plans/intersect_multiset.substrait.json new file mode 100644 index 000000000000..8ff69bd82c3a --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/intersect_multiset.substrait.json @@ -0,0 +1,166 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_INTERSECTION_MULTISET" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } + } \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/intersect_multiset_all.substrait.json b/datafusion/substrait/tests/testdata/test_plans/intersect_multiset_all.substrait.json new file mode 100644 index 000000000000..56daf6ed46f4 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/intersect_multiset_all.substrait.json @@ -0,0 +1,166 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_INTERSECTION_MULTISET_ALL" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } + } \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/intersect_primary.substrait.json b/datafusion/substrait/tests/testdata/test_plans/intersect_primary.substrait.json new file mode 100644 index 000000000000..229dd7251705 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/intersect_primary.substrait.json @@ -0,0 +1,166 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_INTERSECTION_PRIMARY" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } + } \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/minus_primary.substrait.json b/datafusion/substrait/tests/testdata/test_plans/minus_primary.substrait.json new file mode 100644 index 000000000000..33b0e2ab8c80 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/minus_primary.substrait.json @@ -0,0 +1,166 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_MINUS_PRIMARY" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } + } \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/minus_primary_all.substrait.json b/datafusion/substrait/tests/testdata/test_plans/minus_primary_all.substrait.json new file mode 100644 index 000000000000..229f78ab5bf6 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/minus_primary_all.substrait.json @@ -0,0 +1,166 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_MINUS_PRIMARY_ALL" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } + } \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/union_distinct.substrait.json b/datafusion/substrait/tests/testdata/test_plans/union_distinct.substrait.json new file mode 100644 index 000000000000..e8b02749660d --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/union_distinct.substrait.json @@ -0,0 +1,118 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_UNION_DISTINCT" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } +} \ No newline at end of file diff --git a/datafusion/substrait/tests/utils.rs b/datafusion/substrait/tests/utils.rs index 9f63b74ef0fc..00cbfb0c412c 100644 --- a/datafusion/substrait/tests/utils.rs +++ b/datafusion/substrait/tests/utils.rs @@ -147,6 +147,7 @@ pub mod test { Ok(()) } + #[allow(deprecated)] fn collect_schemas_from_rel(&mut self, rel: &Rel) -> Result<()> { let rel_type = rel .rel_type From 1ba1e539b01bbfc7f9001423cfe1ff0015a99db7 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Fri, 18 Oct 2024 01:21:59 +0800 Subject: [PATCH 006/110] Unparse `SubqueryAlias` without projections to SQL (#12896) * change pub function comment to doc * unparse subquery alias without projections * fix tests * rollback the empty line * rollback the empty line * exclude the table_scan with pushdown case * fmt and clippy * simplify the ast to string and remove unused debug code --- datafusion/sql/src/unparser/plan.rs | 64 ++++++----- datafusion/sql/src/unparser/rewrite.rs | 93 ++++++++-------- datafusion/sql/tests/cases/plan_to_sql.rs | 124 +++++++++++++++++----- 3 files changed, 184 insertions(+), 97 deletions(-) diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index d150f0e532c6..9b4818b98cb0 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -15,19 +15,6 @@ // specific language governing permissions and limitations // under the License. -use crate::unparser::utils::unproject_agg_exprs; -use datafusion_common::{ - internal_err, not_impl_err, - tree_node::{TransformedResult, TreeNode}, - Column, DataFusionError, Result, TableReference, -}; -use datafusion_expr::{ - expr::Alias, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan, - LogicalPlanBuilder, Projection, SortExpr, -}; -use sqlparser::ast::{self, Ident, SetExpr}; -use std::sync::Arc; - use super::{ ast::{ BuilderError, DerivedRelationBuilder, QueryBuilder, RelationBuilder, @@ -44,6 +31,18 @@ use super::{ }, Unparser, }; +use crate::unparser::utils::unproject_agg_exprs; +use datafusion_common::{ + internal_err, not_impl_err, + tree_node::{TransformedResult, TreeNode}, + Column, DataFusionError, Result, TableReference, +}; +use datafusion_expr::{ + expr::Alias, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan, + LogicalPlanBuilder, Projection, SortExpr, TableScan, +}; +use sqlparser::ast::{self, Ident, SetExpr}; +use std::sync::Arc; /// Convert a DataFusion [`LogicalPlan`] to [`ast::Statement`] /// @@ -249,12 +248,9 @@ impl Unparser<'_> { ) -> Result<()> { match plan { LogicalPlan::TableScan(scan) => { - if scan.projection.is_some() - || !scan.filters.is_empty() - || scan.fetch.is_some() + if let Some(unparsed_table_scan) = + Self::unparse_table_scan_pushdown(plan, None)? { - let unparsed_table_scan = - Self::unparse_table_scan_pushdown(plan, None)?; return self.select_to_sql_recursively( &unparsed_table_scan, query, @@ -498,10 +494,18 @@ impl Unparser<'_> { LogicalPlan::SubqueryAlias(plan_alias) => { let (plan, mut columns) = subquery_alias_inner_query_and_columns(plan_alias); - let plan = Self::unparse_table_scan_pushdown( + let unparsed_table_scan = Self::unparse_table_scan_pushdown( plan, Some(plan_alias.alias.clone()), )?; + // if the child plan is a TableScan with pushdown operations, we don't need to + // create an additional subquery for it + if !select.already_projected() && unparsed_table_scan.is_none() { + select.projection(vec![ast::SelectItem::Wildcard( + ast::WildcardAdditionalOptions::default(), + )]); + } + let plan = unparsed_table_scan.unwrap_or_else(|| plan.clone()); if !columns.is_empty() && !self.dialect.supports_column_alias_in_table_alias() { @@ -582,12 +586,21 @@ impl Unparser<'_> { } } + fn is_scan_with_pushdown(scan: &TableScan) -> bool { + scan.projection.is_some() || !scan.filters.is_empty() || scan.fetch.is_some() + } + + /// Try to unparse a table scan with pushdown operations into a new subquery plan. + /// If the table scan is without any pushdown operations, return None. fn unparse_table_scan_pushdown( plan: &LogicalPlan, alias: Option, - ) -> Result { + ) -> Result> { match plan { LogicalPlan::TableScan(table_scan) => { + if !Self::is_scan_with_pushdown(table_scan) { + return Ok(None); + } let mut filter_alias_rewriter = alias.as_ref().map(|alias_name| TableAliasRewriter { table_schema: table_scan.source.schema(), @@ -648,18 +661,15 @@ impl Unparser<'_> { builder = builder.limit(0, Some(fetch))?; } - builder.build() + Ok(Some(builder.build()?)) } LogicalPlan::SubqueryAlias(subquery_alias) => { - let new_plan = Self::unparse_table_scan_pushdown( + Self::unparse_table_scan_pushdown( &subquery_alias.input, Some(subquery_alias.alias.clone()), - )?; - LogicalPlanBuilder::from(new_plan) - .alias(subquery_alias.alias.clone())? - .build() + ) } - _ => Ok(plan.clone()), + _ => Ok(None), } } diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index 304a02f037e6..3049df9396cb 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -101,25 +101,25 @@ fn rewrite_sort_expr_for_union(exprs: Vec) -> Result> { Ok(sort_exprs) } -// Rewrite logic plan for query that order by columns are not in projections -// Plan before rewrite: -// -// Projection: j1.j1_string, j2.j2_string -// Sort: j1.j1_id DESC NULLS FIRST, j2.j2_id DESC NULLS FIRST -// Projection: j1.j1_string, j2.j2_string, j1.j1_id, j2.j2_id -// Inner Join: Filter: j1.j1_id = j2.j2_id -// TableScan: j1 -// TableScan: j2 -// -// Plan after rewrite -// -// Sort: j1.j1_id DESC NULLS FIRST, j2.j2_id DESC NULLS FIRST -// Projection: j1.j1_string, j2.j2_string -// Inner Join: Filter: j1.j1_id = j2.j2_id -// TableScan: j1 -// TableScan: j2 -// -// This prevents the original plan generate query with derived table but missing alias. +/// Rewrite logic plan for query that order by columns are not in projections +/// Plan before rewrite: +/// +/// Projection: j1.j1_string, j2.j2_string +/// Sort: j1.j1_id DESC NULLS FIRST, j2.j2_id DESC NULLS FIRST +/// Projection: j1.j1_string, j2.j2_string, j1.j1_id, j2.j2_id +/// Inner Join: Filter: j1.j1_id = j2.j2_id +/// TableScan: j1 +/// TableScan: j2 +/// +/// Plan after rewrite +/// +/// Sort: j1.j1_id DESC NULLS FIRST, j2.j2_id DESC NULLS FIRST +/// Projection: j1.j1_string, j2.j2_string +/// Inner Join: Filter: j1.j1_id = j2.j2_id +/// TableScan: j1 +/// TableScan: j2 +/// +/// This prevents the original plan generate query with derived table but missing alias. pub(super) fn rewrite_plan_for_sort_on_non_projected_fields( p: &Projection, ) -> Option { @@ -191,33 +191,33 @@ pub(super) fn rewrite_plan_for_sort_on_non_projected_fields( } } -// This logic is to work out the columns and inner query for SubqueryAlias plan for both types of -// subquery -// - `(SELECT column_a as a from table) AS A` -// - `(SELECT column_a from table) AS A (a)` -// -// A roundtrip example for table alias with columns -// -// query: SELECT id FROM (SELECT j1_id from j1) AS c (id) -// -// LogicPlan: -// Projection: c.id -// SubqueryAlias: c -// Projection: j1.j1_id AS id -// Projection: j1.j1_id -// TableScan: j1 -// -// Before introducing this logic, the unparsed query would be `SELECT c.id FROM (SELECT j1.j1_id AS -// id FROM (SELECT j1.j1_id FROM j1)) AS c`. -// The query is invalid as `j1.j1_id` is not a valid identifier in the derived table -// `(SELECT j1.j1_id FROM j1)` -// -// With this logic, the unparsed query will be: -// `SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c (id)` -// -// Caveat: this won't handle the case like `select * from (select 1, 2) AS a (b, c)` -// as the parser gives a wrong plan which has mismatch `Int(1)` types: Literal and -// Column in the Projections. Once the parser side is fixed, this logic should work +/// This logic is to work out the columns and inner query for SubqueryAlias plan for both types of +/// subquery +/// - `(SELECT column_a as a from table) AS A` +/// - `(SELECT column_a from table) AS A (a)` +/// +/// A roundtrip example for table alias with columns +/// +/// query: SELECT id FROM (SELECT j1_id from j1) AS c (id) +/// +/// LogicPlan: +/// Projection: c.id +/// SubqueryAlias: c +/// Projection: j1.j1_id AS id +/// Projection: j1.j1_id +/// TableScan: j1 +/// +/// Before introducing this logic, the unparsed query would be `SELECT c.id FROM (SELECT j1.j1_id AS +/// id FROM (SELECT j1.j1_id FROM j1)) AS c`. +/// The query is invalid as `j1.j1_id` is not a valid identifier in the derived table +/// `(SELECT j1.j1_id FROM j1)` +/// +/// With this logic, the unparsed query will be: +/// `SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c (id)` +/// +/// Caveat: this won't handle the case like `select * from (select 1, 2) AS a (b, c)` +/// as the parser gives a wrong plan which has mismatch `Int(1)` types: Literal and +/// Column in the Projections. Once the parser side is fixed, this logic should work pub(super) fn subquery_alias_inner_query_and_columns( subquery_alias: &datafusion_expr::SubqueryAlias, ) -> (&LogicalPlan, Vec) { @@ -330,6 +330,7 @@ fn find_projection(logical_plan: &LogicalPlan) -> Option<&Projection> { _ => None, } } + /// A `TreeNodeRewriter` implementation that rewrites `Expr::Column` expressions by /// replacing the column's name with an alias if the column exists in the provided schema. /// diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index aff9f99c8cd3..e4e5d6a92964 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -71,7 +71,7 @@ fn roundtrip_expr() { let ast = expr_to_sql(&expr)?; - Ok(format!("{}", ast)) + Ok(ast.to_string()) }; for (table, query, expected) in tests { @@ -192,7 +192,7 @@ fn roundtrip_statement() -> Result<()> { let roundtrip_statement = plan_to_sql(&plan)?; - let actual = format!("{}", &roundtrip_statement); + let actual = &roundtrip_statement.to_string(); println!("roundtrip sql: {actual}"); println!("plan {}", plan.display_indent()); @@ -224,7 +224,7 @@ fn roundtrip_crossjoin() -> Result<()> { let roundtrip_statement = plan_to_sql(&plan)?; - let actual = format!("{}", &roundtrip_statement); + let actual = &roundtrip_statement.to_string(); println!("roundtrip sql: {actual}"); println!("plan {}", plan.display_indent()); @@ -237,7 +237,7 @@ fn roundtrip_crossjoin() -> Result<()> { \n TableScan: j1\ \n TableScan: j2"; - assert_eq!(format!("{plan_roundtrip}"), expected); + assert_eq!(plan_roundtrip.to_string(), expected); Ok(()) } @@ -478,7 +478,7 @@ fn roundtrip_statement_with_dialect() -> Result<()> { let unparser = Unparser::new(&*query.unparser_dialect); let roundtrip_statement = unparser.plan_to_sql(&plan)?; - let actual = format!("{}", &roundtrip_statement); + let actual = &roundtrip_statement.to_string(); println!("roundtrip sql: {actual}"); println!("plan {}", plan.display_indent()); @@ -508,7 +508,7 @@ Projection: unnest_placeholder(unnest_table.struct_col).field1, unnest_placehold Projection: unnest_table.struct_col AS unnest_placeholder(unnest_table.struct_col), unnest_table.array_col AS unnest_placeholder(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col TableScan: unnest_table"#.trim_start(); - assert_eq!(format!("{plan}"), expected); + assert_eq!(plan.to_string(), expected); Ok(()) } @@ -528,7 +528,7 @@ fn test_table_references_in_plan_to_sql() { .unwrap(); let sql = plan_to_sql(&plan).unwrap(); - assert_eq!(format!("{}", sql), expected_sql) + assert_eq!(sql.to_string(), expected_sql) } test( @@ -558,7 +558,7 @@ fn test_table_scan_with_no_projection_in_plan_to_sql() { .build() .unwrap(); let sql = plan_to_sql(&plan).unwrap(); - assert_eq!(format!("{}", sql), expected_sql) + assert_eq!(sql.to_string(), expected_sql) } test( @@ -667,27 +667,103 @@ where } #[test] -fn test_table_scan_pushdown() -> Result<()> { +fn test_table_scan_alias() -> Result<()> { let schema = Schema::new(vec![ Field::new("id", DataType::Utf8, false), Field::new("age", DataType::Utf8, false), ]); + let plan = table_scan(Some("t1"), &schema, None)? + .project(vec![col("id")])? + .alias("a")? + .build()?; + let sql = plan_to_sql(&plan)?; + assert_eq!(sql.to_string(), "SELECT * FROM (SELECT t1.id FROM t1) AS a"); + + let plan = table_scan(Some("t1"), &schema, None)? + .project(vec![col("id")])? + .alias("a")? + .build()?; + + let sql = plan_to_sql(&plan)?; + assert_eq!(sql.to_string(), "SELECT * FROM (SELECT t1.id FROM t1) AS a"); + + let plan = table_scan(Some("t1"), &schema, None)? + .filter(col("id").gt(lit(5)))? + .project(vec![col("id")])? + .alias("a")? + .build()?; + let sql = plan_to_sql(&plan)?; + assert_eq!( + sql.to_string(), + "SELECT * FROM (SELECT t1.id FROM t1 WHERE (t1.id > 5)) AS a" + ); + + let table_scan_with_two_filter = table_scan_with_filters( + Some("t1"), + &schema, + None, + vec![col("id").gt(lit(1)), col("age").lt(lit(2))], + )? + .project(vec![col("id")])? + .alias("a")? + .build()?; + let table_scan_with_two_filter = plan_to_sql(&table_scan_with_two_filter)?; + assert_eq!( + table_scan_with_two_filter.to_string(), + "SELECT * FROM (SELECT t1.id FROM t1 WHERE ((t1.id > 1) AND (t1.age < 2))) AS a" + ); + + let table_scan_with_fetch = + table_scan_with_filter_and_fetch(Some("t1"), &schema, None, vec![], Some(10))? + .project(vec![col("id")])? + .alias("a")? + .build()?; + let table_scan_with_fetch = plan_to_sql(&table_scan_with_fetch)?; + assert_eq!( + table_scan_with_fetch.to_string(), + "SELECT * FROM (SELECT t1.id FROM (SELECT * FROM t1 LIMIT 10)) AS a" + ); + + let table_scan_with_pushdown_all = table_scan_with_filter_and_fetch( + Some("t1"), + &schema, + Some(vec![0, 1]), + vec![col("id").gt(lit(1))], + Some(10), + )? + .project(vec![col("id")])? + .alias("a")? + .build()?; + let table_scan_with_pushdown_all = plan_to_sql(&table_scan_with_pushdown_all)?; + assert_eq!( + table_scan_with_pushdown_all.to_string(), + "SELECT * FROM (SELECT t1.id FROM (SELECT t1.id, t1.age FROM t1 WHERE (t1.id > 1) LIMIT 10)) AS a" + ); + Ok(()) +} + +#[test] +fn test_table_scan_pushdown() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("age", DataType::Utf8, false), + ]); let scan_with_projection = table_scan(Some("t1"), &schema, Some(vec![0, 1]))?.build()?; let scan_with_projection = plan_to_sql(&scan_with_projection)?; assert_eq!( - format!("{}", scan_with_projection), + scan_with_projection.to_string(), "SELECT t1.id, t1.age FROM t1" ); let scan_with_projection = table_scan(Some("t1"), &schema, Some(vec![1]))?.build()?; let scan_with_projection = plan_to_sql(&scan_with_projection)?; - assert_eq!(format!("{}", scan_with_projection), "SELECT t1.age FROM t1"); + assert_eq!(scan_with_projection.to_string(), "SELECT t1.age FROM t1"); let scan_with_no_projection = table_scan(Some("t1"), &schema, None)?.build()?; let scan_with_no_projection = plan_to_sql(&scan_with_no_projection)?; - assert_eq!(format!("{}", scan_with_no_projection), "SELECT * FROM t1"); + assert_eq!(scan_with_no_projection.to_string(), "SELECT * FROM t1"); let table_scan_with_projection_alias = table_scan(Some("t1"), &schema, Some(vec![0, 1]))? @@ -696,7 +772,7 @@ fn test_table_scan_pushdown() -> Result<()> { let table_scan_with_projection_alias = plan_to_sql(&table_scan_with_projection_alias)?; assert_eq!( - format!("{}", table_scan_with_projection_alias), + table_scan_with_projection_alias.to_string(), "SELECT ta.id, ta.age FROM t1 AS ta" ); @@ -707,7 +783,7 @@ fn test_table_scan_pushdown() -> Result<()> { let table_scan_with_projection_alias = plan_to_sql(&table_scan_with_projection_alias)?; assert_eq!( - format!("{}", table_scan_with_projection_alias), + table_scan_with_projection_alias.to_string(), "SELECT ta.age FROM t1 AS ta" ); @@ -717,7 +793,7 @@ fn test_table_scan_pushdown() -> Result<()> { let table_scan_with_no_projection_alias = plan_to_sql(&table_scan_with_no_projection_alias)?; assert_eq!( - format!("{}", table_scan_with_no_projection_alias), + table_scan_with_no_projection_alias.to_string(), "SELECT * FROM t1 AS ta" ); @@ -729,7 +805,7 @@ fn test_table_scan_pushdown() -> Result<()> { let query_from_table_scan_with_projection = plan_to_sql(&query_from_table_scan_with_projection)?; assert_eq!( - format!("{}", query_from_table_scan_with_projection), + query_from_table_scan_with_projection.to_string(), "SELECT * FROM (SELECT t1.id, t1.age FROM t1)" ); @@ -742,7 +818,7 @@ fn test_table_scan_pushdown() -> Result<()> { .build()?; let table_scan_with_filter = plan_to_sql(&table_scan_with_filter)?; assert_eq!( - format!("{}", table_scan_with_filter), + table_scan_with_filter.to_string(), "SELECT * FROM t1 WHERE (t1.id > t1.age)" ); @@ -755,7 +831,7 @@ fn test_table_scan_pushdown() -> Result<()> { .build()?; let table_scan_with_two_filter = plan_to_sql(&table_scan_with_two_filter)?; assert_eq!( - format!("{}", table_scan_with_two_filter), + table_scan_with_two_filter.to_string(), "SELECT * FROM t1 WHERE ((t1.id > 1) AND (t1.age < 2))" ); @@ -769,7 +845,7 @@ fn test_table_scan_pushdown() -> Result<()> { .build()?; let table_scan_with_filter_alias = plan_to_sql(&table_scan_with_filter_alias)?; assert_eq!( - format!("{}", table_scan_with_filter_alias), + table_scan_with_filter_alias.to_string(), "SELECT * FROM t1 AS ta WHERE (ta.id > ta.age)" ); @@ -783,7 +859,7 @@ fn test_table_scan_pushdown() -> Result<()> { let table_scan_with_projection_and_filter = plan_to_sql(&table_scan_with_projection_and_filter)?; assert_eq!( - format!("{}", table_scan_with_projection_and_filter), + table_scan_with_projection_and_filter.to_string(), "SELECT t1.id, t1.age FROM t1 WHERE (t1.id > t1.age)" ); @@ -797,7 +873,7 @@ fn test_table_scan_pushdown() -> Result<()> { let table_scan_with_projection_and_filter = plan_to_sql(&table_scan_with_projection_and_filter)?; assert_eq!( - format!("{}", table_scan_with_projection_and_filter), + table_scan_with_projection_and_filter.to_string(), "SELECT t1.age FROM t1 WHERE (t1.id > t1.age)" ); @@ -806,7 +882,7 @@ fn test_table_scan_pushdown() -> Result<()> { .build()?; let table_scan_with_inline_fetch = plan_to_sql(&table_scan_with_inline_fetch)?; assert_eq!( - format!("{}", table_scan_with_inline_fetch), + table_scan_with_inline_fetch.to_string(), "SELECT * FROM t1 LIMIT 10" ); @@ -821,7 +897,7 @@ fn test_table_scan_pushdown() -> Result<()> { let table_scan_with_projection_and_inline_fetch = plan_to_sql(&table_scan_with_projection_and_inline_fetch)?; assert_eq!( - format!("{}", table_scan_with_projection_and_inline_fetch), + table_scan_with_projection_and_inline_fetch.to_string(), "SELECT t1.id, t1.age FROM t1 LIMIT 10" ); @@ -835,7 +911,7 @@ fn test_table_scan_pushdown() -> Result<()> { .build()?; let table_scan_with_all = plan_to_sql(&table_scan_with_all)?; assert_eq!( - format!("{}", table_scan_with_all), + table_scan_with_all.to_string(), "SELECT t1.id, t1.age FROM t1 WHERE (t1.id > t1.age) LIMIT 10" ); Ok(()) From b098893a34f83f1a1df290168377d7622938b3f5 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Thu, 17 Oct 2024 19:22:18 +0200 Subject: [PATCH 007/110] Fix 2 bugs related to push down partition filters (#12902) * Report errors in partition filters This patch fixes 2 bugs. Errors in partition filters are ignored and that we allow partitions filters be push down for unpartition tables but we never evaluate such filters. The first bug is fixed by reporting errors for partition filters and only evaluating the filters we allowed as partition filters in `supports_filters_pushdown`. The second bug is fixed by only allowing partition filters to be pushed down when we have partition columns. * Update datafusion/sqllogictest/test_files/errors.slt --- datafusion/core/src/dataframe/mod.rs | 4 +- .../core/src/datasource/listing/helpers.rs | 36 +++++----- .../core/src/datasource/listing/table.rs | 69 ++++++++++--------- .../sqllogictest/test_files/arrow_files.slt | 5 ++ datafusion/sqllogictest/test_files/errors.slt | 4 ++ 5 files changed, 65 insertions(+), 53 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 67e2a4780d06..8a0829cd5e4b 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -2987,9 +2987,7 @@ mod tests { JoinType::Inner, Some(Expr::Literal(ScalarValue::Null)), )?; - let expected_plan = "CrossJoin:\ - \n TableScan: a projection=[c1], full_filters=[Boolean(NULL)]\ - \n TableScan: b projection=[c1]"; + let expected_plan = "EmptyRelation"; assert_eq!(expected_plan, format!("{}", join.into_optimized_plan()?)); // JOIN ON expression must be boolean type diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 72d7277d6ae2..47012f777ad1 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -24,6 +24,7 @@ use std::sync::Arc; use super::ListingTableUrl; use super::PartitionedFile; use crate::execution::context::SessionState; +use datafusion_common::internal_err; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{BinaryExpr, Operator}; @@ -285,25 +286,20 @@ async fn prune_partitions( let props = ExecutionProps::new(); // Applies `filter` to `batch` returning `None` on error - let do_filter = |filter| -> Option { - let expr = create_physical_expr(filter, &df_schema, &props).ok()?; - expr.evaluate(&batch) - .ok()? - .into_array(partitions.len()) - .ok() + let do_filter = |filter| -> Result { + let expr = create_physical_expr(filter, &df_schema, &props)?; + expr.evaluate(&batch)?.into_array(partitions.len()) }; - //.Compute the conjunction of the filters, ignoring errors + //.Compute the conjunction of the filters let mask = filters .iter() - .fold(None, |acc, filter| match (acc, do_filter(filter)) { - (Some(a), Some(b)) => Some(and(&a, b.as_boolean()).unwrap_or(a)), - (None, Some(r)) => Some(r.as_boolean().clone()), - (r, None) => r, - }); + .map(|f| do_filter(f).map(|a| a.as_boolean().clone())) + .reduce(|a, b| Ok(and(&a?, &b?)?)); let mask = match mask { - Some(mask) => mask, + Some(Ok(mask)) => mask, + Some(Err(err)) => return Err(err), None => return Ok(partitions), }; @@ -401,8 +397,8 @@ fn evaluate_partition_prefix<'a>( /// Discover the partitions on the given path and prune out files /// that belong to irrelevant partitions using `filters` expressions. -/// `filters` might contain expressions that can be resolved only at the -/// file level (e.g. Parquet row group pruning). +/// `filters` should only contain expressions that can be evaluated +/// using only the partition columns. pub async fn pruned_partition_list<'a>( ctx: &'a SessionState, store: &'a dyn ObjectStore, @@ -413,6 +409,12 @@ pub async fn pruned_partition_list<'a>( ) -> Result>> { // if no partition col => simply list all the files if partition_cols.is_empty() { + if !filters.is_empty() { + return internal_err!( + "Got partition filters for unpartitioned table {}", + table_path + ); + } return Ok(Box::pin( table_path .list_all_files(ctx, store, file_extension) @@ -631,13 +633,11 @@ mod tests { ]); let filter1 = Expr::eq(col("part1"), lit("p1v2")); let filter2 = Expr::eq(col("part2"), lit("p2v1")); - // filter3 cannot be resolved at partition pruning - let filter3 = Expr::eq(col("part2"), col("other")); let pruned = pruned_partition_list( &state, store.as_ref(), &ListingTableUrl::parse("file:///tablepath/").unwrap(), - &[filter1, filter2, filter3], + &[filter1, filter2], ".parquet", &[ (String::from("part1"), DataType::Utf8), diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index a9c6aec17537..1e9f06c20b47 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -782,6 +782,16 @@ impl ListingTable { } } +// Expressions can be used for parttion pruning if they can be evaluated using +// only the partiton columns and there are partition columns. +fn can_be_evaluted_for_partition_pruning( + partition_column_names: &[&str], + expr: &Expr, +) -> bool { + !partition_column_names.is_empty() + && expr_applicable_for_cols(partition_column_names, expr) +} + #[async_trait] impl TableProvider for ListingTable { fn as_any(&self) -> &dyn Any { @@ -807,10 +817,28 @@ impl TableProvider for ListingTable { filters: &[Expr], limit: Option, ) -> Result> { + // extract types of partition columns + let table_partition_cols = self + .options + .table_partition_cols + .iter() + .map(|col| Ok(self.table_schema.field_with_name(&col.0)?.clone())) + .collect::>>()?; + + let table_partition_col_names = table_partition_cols + .iter() + .map(|field| field.name().as_str()) + .collect::>(); + // If the filters can be resolved using only partition cols, there is no need to + // pushdown it to TableScan, otherwise, `unhandled` pruning predicates will be generated + let (partition_filters, filters): (Vec<_>, Vec<_>) = + filters.iter().cloned().partition(|filter| { + can_be_evaluted_for_partition_pruning(&table_partition_col_names, filter) + }); // TODO (https://github.com/apache/datafusion/issues/11600) remove downcast_ref from here? let session_state = state.as_any().downcast_ref::().unwrap(); let (mut partitioned_file_lists, statistics) = self - .list_files_for_scan(session_state, filters, limit) + .list_files_for_scan(session_state, &partition_filters, limit) .await?; // if no files need to be read, return an `EmptyExec` @@ -846,28 +874,6 @@ impl TableProvider for ListingTable { None => {} // no ordering required }; - // extract types of partition columns - let table_partition_cols = self - .options - .table_partition_cols - .iter() - .map(|col| Ok(self.table_schema.field_with_name(&col.0)?.clone())) - .collect::>>()?; - - // If the filters can be resolved using only partition cols, there is no need to - // pushdown it to TableScan, otherwise, `unhandled` pruning predicates will be generated - let table_partition_col_names = table_partition_cols - .iter() - .map(|field| field.name().as_str()) - .collect::>(); - let filters = filters - .iter() - .filter(|filter| { - !expr_applicable_for_cols(&table_partition_col_names, filter) - }) - .cloned() - .collect::>(); - let filters = conjunction(filters.to_vec()) .map(|expr| -> Result<_> { // NOTE: Use the table schema (NOT file schema) here because `expr` may contain references to partition columns. @@ -908,18 +914,17 @@ impl TableProvider for ListingTable { &self, filters: &[&Expr], ) -> Result> { + let partition_column_names = self + .options + .table_partition_cols + .iter() + .map(|col| col.0.as_str()) + .collect::>(); filters .iter() .map(|filter| { - if expr_applicable_for_cols( - &self - .options - .table_partition_cols - .iter() - .map(|col| col.0.as_str()) - .collect::>(), - filter, - ) { + if can_be_evaluted_for_partition_pruning(&partition_column_names, filter) + { // if filter can be handled by partition pruning, it is exact return Ok(TableProviderFilterPushDown::Exact); } diff --git a/datafusion/sqllogictest/test_files/arrow_files.slt b/datafusion/sqllogictest/test_files/arrow_files.slt index e66ba7477fc4..e73acc384cb3 100644 --- a/datafusion/sqllogictest/test_files/arrow_files.slt +++ b/datafusion/sqllogictest/test_files/arrow_files.slt @@ -118,3 +118,8 @@ EXPLAIN SELECT f0 FROM arrow_partitioned WHERE part = 456 ---- logical_plan TableScan: arrow_partitioned projection=[f0], full_filters=[arrow_partitioned.part = Int32(456)] physical_plan ArrowExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/partitioned_table_arrow/part=456/data.arrow]]}, projection=[f0] + + +# Errors in partition filters should be reported +query error Divide by zero error +SELECT f0 FROM arrow_partitioned WHERE CASE WHEN true THEN 1 / 0 ELSE part END = 1; diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index ce0947525344..da46a7e5e679 100644 --- a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -133,3 +133,7 @@ create table foo as values (1), ('foo'); query error No function matches select 1 group by substr(''); + +# Error in filter should be reported +query error Divide by zero +SELECT c2 from aggregate_test_100 where CASE WHEN true THEN 1 / 0 ELSE 0 END = 1; From 54bd26ed12f854b87d20f0d70ac64c02fcd5150f Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Thu, 17 Oct 2024 19:22:40 +0200 Subject: [PATCH 008/110] Move TableConstraint to Constraints conversion (#12953) Reduce datafusion-common dependency on sqlparser --- .../common/src/functional_dependencies.rs | 73 +----------------- datafusion/sql/src/statement.rs | 77 ++++++++++++++++++- 2 files changed, 74 insertions(+), 76 deletions(-) diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs index 90f4e6e7e3d1..ed9a68c19536 100644 --- a/datafusion/common/src/functional_dependencies.rs +++ b/datafusion/common/src/functional_dependencies.rs @@ -23,11 +23,8 @@ use std::fmt::{Display, Formatter}; use std::ops::Deref; use std::vec::IntoIter; -use crate::error::_plan_err; use crate::utils::{merge_and_order_indices, set_difference}; -use crate::{DFSchema, DFSchemaRef, DataFusionError, JoinType, Result}; - -use sqlparser::ast::TableConstraint; +use crate::{DFSchema, JoinType}; /// This object defines a constraint on a table. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] @@ -60,74 +57,6 @@ impl Constraints { Self { inner: constraints } } - /// Convert each `TableConstraint` to corresponding `Constraint` - pub fn new_from_table_constraints( - constraints: &[TableConstraint], - df_schema: &DFSchemaRef, - ) -> Result { - let constraints = constraints - .iter() - .map(|c: &TableConstraint| match c { - TableConstraint::Unique { name, columns, .. } => { - let field_names = df_schema.field_names(); - // Get unique constraint indices in the schema: - let indices = columns - .iter() - .map(|u| { - let idx = field_names - .iter() - .position(|item| *item == u.value) - .ok_or_else(|| { - let name = name - .as_ref() - .map(|name| format!("with name '{name}' ")) - .unwrap_or("".to_string()); - DataFusionError::Execution( - format!("Column for unique constraint {}not found in schema: {}", name,u.value) - ) - })?; - Ok(idx) - }) - .collect::>>()?; - Ok(Constraint::Unique(indices)) - } - TableConstraint::PrimaryKey { columns, .. } => { - let field_names = df_schema.field_names(); - // Get primary key indices in the schema: - let indices = columns - .iter() - .map(|pk| { - let idx = field_names - .iter() - .position(|item| *item == pk.value) - .ok_or_else(|| { - DataFusionError::Execution(format!( - "Column for primary key not found in schema: {}", - pk.value - )) - })?; - Ok(idx) - }) - .collect::>>()?; - Ok(Constraint::PrimaryKey(indices)) - } - TableConstraint::ForeignKey { .. } => { - _plan_err!("Foreign key constraints are not currently supported") - } - TableConstraint::Check { .. } => { - _plan_err!("Check constraints are not currently supported") - } - TableConstraint::Index { .. } => { - _plan_err!("Indexes are not currently supported") - } - TableConstraint::FulltextOrSpatial { .. } => { - _plan_err!("Indexes are not currently supported") - } - }) - .collect::>>()?; - Ok(Constraints::new_unverified(constraints)) - } - /// Check whether constraints is empty pub fn is_empty(&self) -> bool { self.inner.is_empty() diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index edb4316db1e0..4109f1371187 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -30,10 +30,11 @@ use crate::planner::{ use crate::utils::normalize_ident; use arrow_schema::{DataType, Fields}; +use datafusion_common::error::_plan_err; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ exec_err, not_impl_err, plan_datafusion_err, plan_err, schema_err, - unqualified_field_not_found, Column, Constraints, DFSchema, DFSchemaRef, + unqualified_field_not_found, Column, Constraint, Constraints, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, SchemaError, SchemaReference, TableReference, ToDFSchema, }; @@ -427,7 +428,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan }; - let constraints = Constraints::new_from_table_constraints( + let constraints = Self::new_constraint_from_table_constraints( &all_constraints, plan.schema(), )?; @@ -452,7 +453,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema, }; let plan = LogicalPlan::EmptyRelation(plan); - let constraints = Constraints::new_from_table_constraints( + let constraints = Self::new_constraint_from_table_constraints( &all_constraints, plan.schema(), )?; @@ -1242,7 +1243,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let name = self.object_name_to_table_reference(name)?; let constraints = - Constraints::new_from_table_constraints(&all_constraints, &df_schema)?; + Self::new_constraint_from_table_constraints(&all_constraints, &df_schema)?; Ok(LogicalPlan::Ddl(DdlStatement::CreateExternalTable( PlanCreateExternalTable { schema: df_schema, @@ -1262,6 +1263,74 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))) } + /// Convert each `TableConstraint` to corresponding `Constraint` + fn new_constraint_from_table_constraints( + constraints: &[TableConstraint], + df_schema: &DFSchemaRef, + ) -> Result { + let constraints = constraints + .iter() + .map(|c: &TableConstraint| match c { + TableConstraint::Unique { name, columns, .. } => { + let field_names = df_schema.field_names(); + // Get unique constraint indices in the schema: + let indices = columns + .iter() + .map(|u| { + let idx = field_names + .iter() + .position(|item| *item == u.value) + .ok_or_else(|| { + let name = name + .as_ref() + .map(|name| format!("with name '{name}' ")) + .unwrap_or("".to_string()); + DataFusionError::Execution( + format!("Column for unique constraint {}not found in schema: {}", name,u.value) + ) + })?; + Ok(idx) + }) + .collect::>>()?; + Ok(Constraint::Unique(indices)) + } + TableConstraint::PrimaryKey { columns, .. } => { + let field_names = df_schema.field_names(); + // Get primary key indices in the schema: + let indices = columns + .iter() + .map(|pk| { + let idx = field_names + .iter() + .position(|item| *item == pk.value) + .ok_or_else(|| { + DataFusionError::Execution(format!( + "Column for primary key not found in schema: {}", + pk.value + )) + })?; + Ok(idx) + }) + .collect::>>()?; + Ok(Constraint::PrimaryKey(indices)) + } + TableConstraint::ForeignKey { .. } => { + _plan_err!("Foreign key constraints are not currently supported") + } + TableConstraint::Check { .. } => { + _plan_err!("Check constraints are not currently supported") + } + TableConstraint::Index { .. } => { + _plan_err!("Indexes are not currently supported") + } + TableConstraint::FulltextOrSpatial { .. } => { + _plan_err!("Indexes are not currently supported") + } + }) + .collect::>>()?; + Ok(Constraints::new_unverified(constraints)) + } + fn parse_options_map( &self, options: Vec<(String, Value)>, From ccfe020a9a98203d7d37d1431e351be8d4418f63 Mon Sep 17 00:00:00 2001 From: Jonathan Chen <86070045+jonathanc-n@users.noreply.github.com> Date: Thu, 17 Oct 2024 13:23:04 -0400 Subject: [PATCH 009/110] Added current_timestamp alias (#12958) * Add current_timestamp * ft fix? * fmt fix --- datafusion/functions/src/datetime/now.rs | 6 ++++++ datafusion/sqllogictest/test_files/timestamps.slt | 5 +++++ 2 files changed, 11 insertions(+) diff --git a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index 74eb5aea4255..690008d97212 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -28,6 +28,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} #[derive(Debug)] pub struct NowFunc { signature: Signature, + aliases: Vec, } impl Default for NowFunc { @@ -40,6 +41,7 @@ impl NowFunc { pub fn new() -> Self { Self { signature: Signature::uniform(0, vec![], Volatility::Stable), + aliases: vec!["current_timestamp".to_string()], } } } @@ -85,6 +87,10 @@ impl ScalarUDFImpl for NowFunc { ))) } + fn aliases(&self) -> &[String] { + &self.aliases + } + fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool { false } diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index a680e0db522d..d866ec8c94dd 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -84,6 +84,11 @@ select case when current_time() = (now()::bigint % 86400000000000)::time then 'O ---- OK +query B +select now() = current_timestamp; +---- +true + ########## ## Timestamp Handling Tests ########## From ad273cab8bf300a704baf005df072bb980645e51 Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Thu, 17 Oct 2024 10:23:29 -0700 Subject: [PATCH 010/110] Improve unparsing for `ORDER BY`, `UNION`, Windows functions with Aggregation (#12946) * Improve unparsing for ORDER BY with Aggregation functions (#38) * Improve UNION unparsing (#39) * Scalar functions in ORDER BY unparsing support (#41) * Improve unparsing for complex Window functions with Aggregation (#42) * WindowFunction order_by should respect `supports_nulls_first_in_sort` dialect setting (#43) * Fix plan_to_sql * Improve --- datafusion/sql/src/unparser/expr.rs | 10 +--- datafusion/sql/src/unparser/plan.rs | 42 +++++++++----- datafusion/sql/src/unparser/utils.rs | 69 ++++++++++++++++++----- datafusion/sql/tests/cases/plan_to_sql.rs | 63 ++++++++++++++++++++- 4 files changed, 148 insertions(+), 36 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 1be5aa68bfba..8864c97bb1ff 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -76,11 +76,6 @@ pub fn expr_to_sql(expr: &Expr) -> Result { unparser.expr_to_sql(expr) } -pub fn sort_to_sql(sort: &Sort) -> Result { - let unparser = Unparser::default(); - unparser.sort_to_sql(sort) -} - const LOWEST: &BinaryOperator = &BinaryOperator::Or; // Closest precedence we have to IS operator is BitwiseAnd (any other) in PG docs // (https://www.postgresql.org/docs/7.2/sql-precedence.html) @@ -229,9 +224,10 @@ impl Unparser<'_> { ast::WindowFrameUnits::Groups } }; - let order_by: Vec = order_by + + let order_by = order_by .iter() - .map(sort_to_sql) + .map(|sort_expr| self.sort_to_sql(sort_expr)) .collect::>>()?; let start_bound = self.convert_bound(&window_frame.start_bound)?; diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 9b4818b98cb0..c22400f1faa1 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -27,7 +27,7 @@ use super::{ }, utils::{ find_agg_node_within_select, find_window_nodes_within_select, - unproject_window_exprs, + unproject_sort_expr, unproject_window_exprs, }, Unparser, }; @@ -352,19 +352,30 @@ impl Unparser<'_> { if select.already_projected() { return self.derive(plan, relation); } - if let Some(query_ref) = query { - if let Some(fetch) = sort.fetch { - query_ref.limit(Some(ast::Expr::Value(ast::Value::Number( - fetch.to_string(), - false, - )))); - } - query_ref.order_by(self.sorts_to_sql(sort.expr.clone())?); - } else { + let Some(query_ref) = query else { return internal_err!( "Sort operator only valid in a statement context." ); - } + }; + + if let Some(fetch) = sort.fetch { + query_ref.limit(Some(ast::Expr::Value(ast::Value::Number( + fetch.to_string(), + false, + )))); + }; + + let agg = find_agg_node_within_select(plan, select.already_projected()); + // unproject sort expressions + let sort_exprs: Vec = sort + .expr + .iter() + .map(|sort_expr| { + unproject_sort_expr(sort_expr, agg, sort.input.as_ref()) + }) + .collect::>>()?; + + query_ref.order_by(self.sorts_to_sql(&sort_exprs)?); self.select_to_sql_recursively( sort.input.as_ref(), @@ -402,7 +413,7 @@ impl Unparser<'_> { .collect::>>()?; if let Some(sort_expr) = &on.sort_expr { if let Some(query_ref) = query { - query_ref.order_by(self.sorts_to_sql(sort_expr.clone())?); + query_ref.order_by(self.sorts_to_sql(sort_expr)?); } else { return internal_err!( "Sort operator only valid in a statement context." @@ -546,6 +557,11 @@ impl Unparser<'_> { ); } + // Covers cases where the UNION is a subquery and the projection is at the top level + if select.already_projected() { + return self.derive(plan, relation); + } + let input_exprs: Vec = union .inputs .iter() @@ -691,7 +707,7 @@ impl Unparser<'_> { } } - fn sorts_to_sql(&self, sort_exprs: Vec) -> Result> { + fn sorts_to_sql(&self, sort_exprs: &[SortExpr]) -> Result> { sort_exprs .iter() .map(|sort_expr| self.sort_to_sql(sort_expr)) diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index e8c4eca569b1..5e3a3aa600b6 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -20,10 +20,11 @@ use std::cmp::Ordering; use datafusion_common::{ internal_err, tree_node::{Transformed, TreeNode}, - Column, DataFusionError, Result, ScalarValue, + Column, Result, ScalarValue, }; use datafusion_expr::{ - utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Window, + utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Projection, SortExpr, + Window, }; use sqlparser::ast; @@ -118,21 +119,11 @@ pub(crate) fn unproject_agg_exprs( if let Expr::Column(c) = sub_expr { if let Some(unprojected_expr) = find_agg_expr(agg, &c)? { Ok(Transformed::yes(unprojected_expr.clone())) - } else if let Some(mut unprojected_expr) = + } else if let Some(unprojected_expr) = windows.and_then(|w| find_window_expr(w, &c.name).cloned()) { - if let Expr::WindowFunction(func) = &mut unprojected_expr { - // Window function can contain an aggregation column, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected - func.args.iter_mut().try_for_each(|arg| { - if let Expr::Column(c) = arg { - if let Some(expr) = find_agg_expr(agg, c)? { - *arg = expr.clone(); - } - } - Ok::<(), DataFusionError>(()) - })?; - } - Ok(Transformed::yes(unprojected_expr)) + // Window function can contain an aggregation columns, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected + return Ok(Transformed::yes(unproject_agg_exprs(&unprojected_expr, agg, None)?)); } else { internal_err!( "Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name @@ -200,6 +191,54 @@ fn find_window_expr<'a>( .find(|expr| expr.schema_name().to_string() == column_name) } +/// Transforms a Column expression into the actual expression from aggregation or projection if found. +/// This is required because if an ORDER BY expression is present in an Aggregate or Select, it is replaced +/// with a Column expression (e.g., "sum(catalog_returns.cr_net_loss)"). We need to transform it back to +/// the actual expression, such as sum("catalog_returns"."cr_net_loss"). +pub(crate) fn unproject_sort_expr( + sort_expr: &SortExpr, + agg: Option<&Aggregate>, + input: &LogicalPlan, +) -> Result { + let mut sort_expr = sort_expr.clone(); + + // Remove alias if present, because ORDER BY cannot use aliases + if let Expr::Alias(alias) = &sort_expr.expr { + sort_expr.expr = *alias.expr.clone(); + } + + let Expr::Column(ref col_ref) = sort_expr.expr else { + return Ok(sort_expr); + }; + + if col_ref.relation.is_some() { + return Ok(sort_expr); + }; + + // In case of aggregation there could be columns containing aggregation functions we need to unproject + if let Some(agg) = agg { + if agg.schema.is_column_from_schema(col_ref) { + let new_expr = unproject_agg_exprs(&sort_expr.expr, agg, None)?; + sort_expr.expr = new_expr; + return Ok(sort_expr); + } + } + + // If SELECT and ORDER BY contain the same expression with a scalar function, the ORDER BY expression will + // be replaced by a Column expression (e.g., "substr(customer.c_last_name, Int64(0), Int64(5))"), and we need + // to transform it back to the actual expression. + if let LogicalPlan::Projection(Projection { expr, schema, .. }) = input { + if let Ok(idx) = schema.index_of_column(col_ref) { + if let Some(Expr::ScalarFunction(scalar_fn)) = expr.get(idx) { + sort_expr.expr = Expr::ScalarFunction(scalar_fn.clone()); + } + } + return Ok(sort_expr); + } + + Ok(sort_expr) +} + /// Converts a date_part function to SQL, tailoring it to the supported date field extraction style. pub(crate) fn date_part_to_sql( unparser: &Unparser, diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index e4e5d6a92964..74abdf075f23 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -22,6 +22,9 @@ use arrow_schema::*; use datafusion_common::{DFSchema, Result, TableReference}; use datafusion_expr::test::function_stub::{count_udaf, max_udaf, min_udaf, sum_udaf}; use datafusion_expr::{col, lit, table_scan, wildcard, LogicalPlanBuilder}; +use datafusion_functions::unicode; +use datafusion_functions_aggregate::grouping::grouping_udaf; +use datafusion_functions_window::rank::rank_udwf; use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_sql::unparser::dialect::{ DefaultDialect as UnparserDefaultDialect, Dialect as UnparserDialect, @@ -139,6 +142,13 @@ fn roundtrip_statement() -> Result<()> { SELECT j2_string as string FROM j2 ORDER BY string DESC LIMIT 10"#, + r#"SELECT col1, id FROM ( + SELECT j1_string AS col1, j1_id AS id FROM j1 + UNION ALL + SELECT j2_string AS col1, j2_id AS id FROM j2 + UNION ALL + SELECT j3_string AS col1, j3_id AS id FROM j3 + ) AS subquery GROUP BY col1, id ORDER BY col1 ASC, id ASC"#, "SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), first_name from person", @@ -657,7 +667,12 @@ where .unwrap(); let context = MockContextProvider { - state: MockSessionState::default(), + state: MockSessionState::default() + .with_aggregate_function(sum_udaf()) + .with_aggregate_function(max_udaf()) + .with_aggregate_function(grouping_udaf()) + .with_window_function(rank_udwf()) + .with_scalar_function(Arc::new(unicode::substr().as_ref().clone())), }; let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); @@ -969,3 +984,49 @@ fn test_with_offset0() { fn test_with_offset95() { sql_round_trip(MySqlDialect {}, "select 1 offset 95", "SELECT 1 OFFSET 95"); } + +#[test] +fn test_order_by_to_sql() { + // order by aggregation function + sql_round_trip( + GenericDialect {}, + r#"SELECT id, first_name, SUM(id) FROM person GROUP BY id, first_name ORDER BY SUM(id) ASC, first_name DESC, id, first_name LIMIT 10"#, + r#"SELECT person.id, person.first_name, sum(person.id) FROM person GROUP BY person.id, person.first_name ORDER BY sum(person.id) ASC NULLS LAST, person.first_name DESC NULLS FIRST, person.id ASC NULLS LAST, person.first_name ASC NULLS LAST LIMIT 10"#, + ); + + // order by aggregation function alias + sql_round_trip( + GenericDialect {}, + r#"SELECT id, first_name, SUM(id) as total_sum FROM person GROUP BY id, first_name ORDER BY total_sum ASC, first_name DESC, id, first_name LIMIT 10"#, + r#"SELECT person.id, person.first_name, sum(person.id) AS total_sum FROM person GROUP BY person.id, person.first_name ORDER BY total_sum ASC NULLS LAST, person.first_name DESC NULLS FIRST, person.id ASC NULLS LAST, person.first_name ASC NULLS LAST LIMIT 10"#, + ); + + // order by scalar function from projection + sql_round_trip( + GenericDialect {}, + r#"SELECT id, first_name, substr(first_name,0,5) FROM person ORDER BY id, substr(first_name,0,5)"#, + r#"SELECT person.id, person.first_name, substr(person.first_name, 0, 5) FROM person ORDER BY person.id ASC NULLS LAST, substr(person.first_name, 0, 5) ASC NULLS LAST"#, + ); +} + +#[test] +fn test_aggregation_to_sql() { + sql_round_trip( + GenericDialect {}, + r#"SELECT id, first_name, + SUM(id) AS total_sum, + SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum, + MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total, + rank() OVER (PARTITION BY grouping(id) + grouping(age), CASE WHEN grouping(age) = 0 THEN id END ORDER BY sum(id) DESC) AS rank_within_parent_1, + rank() OVER (PARTITION BY grouping(age) + grouping(id), CASE WHEN (CAST(grouping(age) AS BIGINT) = 0) THEN id END ORDER BY sum(id) DESC) AS rank_within_parent_2 + FROM person + GROUP BY id, first_name;"#, + r#"SELECT person.id, person.first_name, +sum(person.id) AS total_sum, sum(person.id) OVER (PARTITION BY person.first_name ROWS BETWEEN '5' PRECEDING AND '2' FOLLOWING) AS moving_sum, +max(sum(person.id)) OVER (PARTITION BY person.first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total, +rank() OVER (PARTITION BY (grouping(person.id) + grouping(person.age)), CASE WHEN (grouping(person.age) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_1, +rank() OVER (PARTITION BY (grouping(person.age) + grouping(person.id)), CASE WHEN (CAST(grouping(person.age) AS BIGINT) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_2 +FROM person +GROUP BY person.id, person.first_name"#.replace("\n", " ").as_str(), + ); +} From 0ed369e925ae8856e36b166bfcea8601019c6967 Mon Sep 17 00:00:00 2001 From: Georgi Krastev Date: Thu, 17 Oct 2024 20:24:07 +0300 Subject: [PATCH 011/110] Handle one-element array return value in ScalarFunctionExpr (#12965) This was done in #12922 only for math functions. We now generalize this fallback to all scalar UDFs. --- datafusion/expr-common/src/columnar_value.rs | 11 ----------- datafusion/functions/src/macros.rs | 12 ++++++------ .../physical-expr/src/scalar_function.rs | 18 +++++++++++++++--- 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/datafusion/expr-common/src/columnar_value.rs b/datafusion/expr-common/src/columnar_value.rs index 1ee90eb4b4a8..57056d0806a7 100644 --- a/datafusion/expr-common/src/columnar_value.rs +++ b/datafusion/expr-common/src/columnar_value.rs @@ -217,17 +217,6 @@ impl ColumnarValue { } } } - - /// Converts an [`ArrayRef`] to a [`ColumnarValue`] based on the supplied arguments. - /// This is useful for scalar UDF implementations to fulfil their contract: - /// if all arguments are scalar values, the result should also be a scalar value. - pub fn from_args_and_result(args: &[Self], result: ArrayRef) -> Result { - if result.len() == 1 && args.iter().all(|arg| matches!(arg, Self::Scalar(_))) { - Ok(Self::Scalar(ScalarValue::try_from_array(&result, 0)?)) - } else { - Ok(Self::Array(result)) - } - } } #[cfg(test)] diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index 85ffaa868f24..744a0189125c 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -228,8 +228,8 @@ macro_rules! make_math_unary_udf { $EVALUATE_BOUNDS(inputs) } - fn invoke(&self, col_args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(col_args)?; + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(args)?; let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => { Arc::new(make_function_scalar_inputs_return_type!( @@ -257,7 +257,7 @@ macro_rules! make_math_unary_udf { } }; - ColumnarValue::from_args_and_result(col_args, arr) + Ok(ColumnarValue::Array(arr)) } fn documentation(&self) -> Option<&Documentation> { @@ -344,8 +344,8 @@ macro_rules! make_math_binary_udf { $OUTPUT_ORDERING(input) } - fn invoke(&self, col_args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(col_args)?; + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(args)?; let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => Arc::new(make_function_inputs2!( &args[0], @@ -372,7 +372,7 @@ macro_rules! make_math_binary_udf { } }; - ColumnarValue::from_args_and_result(col_args, arr) + Ok(ColumnarValue::Array(arr)) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 130c335d1c95..4d3db96ceb3c 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -39,7 +39,8 @@ use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::{internal_err, DFSchema, Result}; +use arrow_array::Array; +use datafusion_common::{internal_err, DFSchema, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; @@ -147,8 +148,19 @@ impl PhysicalExpr for ScalarFunctionExpr { if let ColumnarValue::Array(array) = &output { if array.len() != batch.num_rows() { - return internal_err!("UDF returned a different number of rows than expected. Expected: {}, Got: {}", - batch.num_rows(), array.len()); + // If the arguments are a non-empty slice of scalar values, we can assume that + // returning a one-element array is equivalent to returning a scalar. + let preserve_scalar = array.len() == 1 + && !inputs.is_empty() + && inputs + .iter() + .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); + return if preserve_scalar { + ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar) + } else { + internal_err!("UDF returned a different number of rows than expected. Expected: {}, Got: {}", + batch.num_rows(), array.len()) + }; } } Ok(output) From f718fe2270cb4bf9e3a933b2351d28c62216060c Mon Sep 17 00:00:00 2001 From: Jagdish Parihar Date: Thu, 17 Oct 2024 22:54:28 +0530 Subject: [PATCH 012/110] Migrate datetime documentation to code (#12966) * added code docs for the datetime functions * removed old docs for time and date functions * fixed description for to_unixtime() * removed todo comments * fix merge --------- Co-authored-by: Andrew Lamb --- .../functions/src/datetime/current_date.rs | 27 +- .../functions/src/datetime/current_time.rs | 30 +- datafusion/functions/src/datetime/date_bin.rs | 43 +- .../functions/src/datetime/date_part.rs | 45 +- .../functions/src/datetime/date_trunc.rs | 39 +- .../functions/src/datetime/from_unixtime.rs | 29 +- .../functions/src/datetime/make_date.rs | 48 +- datafusion/functions/src/datetime/now.rs | 29 +- datafusion/functions/src/datetime/to_char.rs | 41 +- .../functions/src/datetime/to_local_time.rs | 73 ++- .../functions/src/datetime/to_timestamp.rs | 214 ++++++- .../functions/src/datetime/to_unixtime.rs | 50 +- .../source/user-guide/sql/scalar_functions.md | 605 ++---------------- .../user-guide/sql/scalar_functions_new.md | 489 ++++++++++++++ 14 files changed, 1177 insertions(+), 585 deletions(-) diff --git a/datafusion/functions/src/datetime/current_date.rs b/datafusion/functions/src/datetime/current_date.rs index 8b180ff41b91..24046611a71f 100644 --- a/datafusion/functions/src/datetime/current_date.rs +++ b/datafusion/functions/src/datetime/current_date.rs @@ -22,8 +22,12 @@ use arrow::datatypes::DataType::Date32; use chrono::{Datelike, NaiveDate}; use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, +}; +use std::sync::OnceLock; #[derive(Debug)] pub struct CurrentDateFunc { @@ -95,4 +99,25 @@ impl ScalarUDFImpl for CurrentDateFunc { ScalarValue::Date32(days), ))) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_current_date_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_current_date_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description(r#" +Returns the current UTC date. + +The `current_date()` return value is determined at query time and will return the same date, no matter when in the query plan the function executes. +"#) + .with_syntax_example("current_date()") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/datetime/current_time.rs b/datafusion/functions/src/datetime/current_time.rs index 803759d4e904..4122b54b07e8 100644 --- a/datafusion/functions/src/datetime/current_time.rs +++ b/datafusion/functions/src/datetime/current_time.rs @@ -15,15 +15,18 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - use arrow::datatypes::DataType; use arrow::datatypes::DataType::Time64; use arrow::datatypes::TimeUnit::Nanosecond; +use std::any::Any; +use std::sync::OnceLock; use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct CurrentTimeFunc { @@ -84,4 +87,25 @@ impl ScalarUDFImpl for CurrentTimeFunc { ScalarValue::Time64Nanosecond(nano), ))) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_current_time_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_current_time_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description(r#" +Returns the current UTC time. + +The `current_time()` return value is determined at query time and will return the same time, no matter when in the query plan the function executes. +"#) + .with_syntax_example("current_time()") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index 997f1a36ad04..e335c4e097f7 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -16,7 +16,7 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::temporal_conversions::NANOSECONDS; use arrow::array::types::{ @@ -35,10 +35,11 @@ use datafusion_common::{exec_err, not_impl_err, plan_err, Result, ScalarValue}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, }; use chrono::{DateTime, Datelike, Duration, Months, TimeDelta, Utc}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; #[derive(Debug)] pub struct DateBinFunc { @@ -163,6 +164,44 @@ impl ScalarUDFImpl for DateBinFunc { Ok(SortProperties::Unordered) } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_date_bin_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_date_bin_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description(r#" +Calculates time intervals and returns the start of the interval nearest to the specified timestamp. Use `date_bin` to downsample time series data by grouping rows into time-based "bins" or "windows" and applying an aggregate or selector function to each window. + +For example, if you "bin" or "window" data into 15 minute intervals, an input timestamp of `2023-01-01T18:18:18Z` will be updated to the start time of the 15 minute bin it is in: `2023-01-01T18:15:00Z`. +"#) + .with_syntax_example("date_bin(interval, expression, origin-timestamp)") + .with_argument("interval", "Bin interval.") + .with_argument("expression", "Time expression to operate on. Can be a constant, column, or function.") + .with_argument("origin-timestamp", "Optional. Starting point used to determine bin boundaries. If not specified defaults 1970-01-01T00:00:00Z (the UNIX epoch in UTC). + +The following intervals are supported: + +- nanoseconds +- microseconds +- milliseconds +- seconds +- minutes +- hours +- days +- weeks +- months +- years +- century +") + .build() + .unwrap() + }) } enum Interval { diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index b6a9a1c7e9db..3fefa5051376 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -17,7 +17,7 @@ use std::any::Any; use std::str::FromStr; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::{Array, ArrayRef, Float64Array}; use arrow::compute::kernels::cast_utils::IntervalUnit; @@ -37,9 +37,10 @@ use datafusion_common::cast::{ as_timestamp_nanosecond_array, as_timestamp_second_array, }; use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, }; #[derive(Debug)] @@ -217,6 +218,46 @@ impl ScalarUDFImpl for DatePartFunc { fn aliases(&self) -> &[String] { &self.aliases } + fn documentation(&self) -> Option<&Documentation> { + Some(get_date_part_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_date_part_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Returns the specified part of the date as an integer.") + .with_syntax_example("date_part(part, expression)") + .with_argument( + "part", + r#"Part of the date to return. The following date parts are supported: + + - year + - quarter (emits value in inclusive range [1, 4] based on which quartile of the year the date is in) + - month + - week (week of the year) + - day (day of the month) + - hour + - minute + - second + - millisecond + - microsecond + - nanosecond + - dow (day of the week) + - doy (day of the year) + - epoch (seconds since Unix epoch) +"#, + ) + .with_argument( + "expression", + "Time expression to operate on. Can be a constant, column, or function.", + ) + .build() + .unwrap() + }) } /// Invoke [`date_part`] and cast the result to Float64 diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index f4786b16685f..4808f020e0ca 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -18,7 +18,7 @@ use std::any::Any; use std::ops::{Add, Sub}; use std::str::FromStr; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::temporal_conversions::{ as_datetime_with_timezone, timestamp_ns_to_datetime, @@ -36,12 +36,13 @@ use datafusion_common::{exec_err, plan_err, DataFusionError, Result, ScalarValue use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, }; use chrono::{ DateTime, Datelike, Duration, LocalResult, NaiveDateTime, Offset, TimeDelta, Timelike, }; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; #[derive(Debug)] pub struct DateTruncFunc { @@ -241,6 +242,40 @@ impl ScalarUDFImpl for DateTruncFunc { Ok(SortProperties::Unordered) } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_date_trunc_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_date_trunc_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Truncates a timestamp value to a specified precision.") + .with_syntax_example("date_trunc(precision, expression)") + .with_argument( + "precision", + r#"Time precision to truncate to. The following precisions are supported: + + - year / YEAR + - quarter / QUARTER + - month / MONTH + - week / WEEK + - day / DAY + - hour / HOUR + - minute / MINUTE + - second / SECOND +"#, + ) + .with_argument( + "expression", + "Time expression to operate on. Can be a constant, column, or function.", + ) + .build() + .unwrap() + }) } fn _date_trunc_coarse(granularity: &str, value: Option) -> Result> diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index d36ebe735ee7..84aa9feec654 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -15,14 +15,17 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Int64, Timestamp}; use arrow::datatypes::TimeUnit::Second; +use std::any::Any; +use std::sync::OnceLock; use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct FromUnixtimeFunc { @@ -78,4 +81,24 @@ impl ScalarUDFImpl for FromUnixtimeFunc { } } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_from_unixtime_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_from_unixtime_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`) return the corresponding timestamp.") + .with_syntax_example("from_unixtime(expression)") + .with_argument( + "expression", + "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ) + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index ded7b454f9eb..78bd7c63a412 100644 --- a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -16,7 +16,7 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::builder::PrimitiveBuilder; use arrow::array::cast::AsArray; @@ -27,7 +27,10 @@ use arrow::datatypes::DataType::{Date32, Int32, Int64, UInt32, UInt64, Utf8, Utf use chrono::prelude::*; use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct MakeDateFunc { @@ -148,6 +151,47 @@ impl ScalarUDFImpl for MakeDateFunc { Ok(value) } + fn documentation(&self) -> Option<&Documentation> { + Some(get_make_date_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_make_date_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Make a date from year/month/day component parts.") + .with_syntax_example("make_date(year, month, day)") + .with_argument( + "year", + " Year to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators.", ) + .with_argument( + "month", + "Month to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators.", + ) + .with_argument("day", "Day to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators.") + .with_sql_example(r#"```sql +> select make_date(2023, 1, 31); ++-------------------------------------------+ +| make_date(Int64(2023),Int64(1),Int64(31)) | ++-------------------------------------------+ +| 2023-01-31 | ++-------------------------------------------+ +> select make_date('2023', '01', '31'); ++-----------------------------------------------+ +| make_date(Utf8("2023"),Utf8("01"),Utf8("31")) | ++-----------------------------------------------+ +| 2023-01-31 | ++-----------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/make_date.rs) +"#) + .build() + .unwrap() + }) } /// Converts the year/month/day fields to an `i32` representing the days from diff --git a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index 690008d97212..c13bbfb18105 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -15,15 +15,18 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - use arrow::datatypes::DataType; use arrow::datatypes::DataType::Timestamp; use arrow::datatypes::TimeUnit::Nanosecond; +use std::any::Any; +use std::sync::OnceLock; use datafusion_common::{internal_err, ExprSchema, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct NowFunc { @@ -86,6 +89,9 @@ impl ScalarUDFImpl for NowFunc { ScalarValue::TimestampNanosecond(now_ts, Some("+00:00".into())), ))) } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_unixtime_doc()) + } fn aliases(&self) -> &[String] { &self.aliases @@ -95,3 +101,20 @@ impl ScalarUDFImpl for NowFunc { false } } + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_to_unixtime_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description(r#" +Returns the current UTC timestamp. + +The `now()` return value is determined at query time and will return the same timestamp, no matter when in the query plan the function executes. +"#) + .with_syntax_example("now()") + .build() + .unwrap() + }) +} diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index f2e5af978ca0..430dcedd92cf 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -16,7 +16,7 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::cast::AsArray; use arrow::array::{new_null_array, Array, ArrayRef, StringArray}; @@ -29,9 +29,10 @@ use arrow::error::ArrowError; use arrow::util::display::{ArrayFormatter, DurationFormat, FormatOptions}; use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, }; #[derive(Debug)] @@ -137,6 +138,42 @@ impl ScalarUDFImpl for ToCharFunc { fn aliases(&self) -> &[String] { &self.aliases } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_char_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_to_char_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Returns a string representation of a date, time, timestamp or duration based on a [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html). Unlike the PostgreSQL equivalent of this function numerical formatting is not supported.") + .with_syntax_example("to_char(expression, format)") + .with_argument( + "expression", + " Expression to operate on. Can be a constant, column, or function that results in a date, time, timestamp or duration." + ) + .with_argument( + "format", + "A [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) string to use to convert the expression.", + ) + .with_argument("day", "Day to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators.") + .with_sql_example(r#"```sql +> select to_char('2023-03-01'::date, '%d-%m-%Y'); ++----------------------------------------------+ +| to_char(Utf8("2023-03-01"),Utf8("%d-%m-%Y")) | ++----------------------------------------------+ +| 01-03-2023 | ++----------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_char.rs) +"#) + .build() + .unwrap() + }) } fn _build_format_options<'a>( diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index 0e33da14547e..7646137ce656 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -17,7 +17,7 @@ use std::any::Any; use std::ops::Add; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::timezone::Tz; use arrow::array::{Array, ArrayRef, PrimitiveBuilder}; @@ -31,7 +31,10 @@ use arrow::datatypes::{ use chrono::{DateTime, MappedLocalTime, Offset, TimeDelta, TimeZone, Utc}; use datafusion_common::cast::as_primitive_array; use datafusion_common::{exec_err, plan_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; /// A UDF function that converts a timezone-aware timestamp to local time (with no offset or /// timezone information). In other words, this function strips off the timezone from the timestamp, @@ -351,6 +354,72 @@ impl ScalarUDFImpl for ToLocalTimeFunc { _ => plan_err!("The to_local_time function can only accept Timestamp as the arg got {first_arg}"), } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_local_time_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_to_local_time_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Converts a timestamp with a timezone to a timestamp without a timezone (with no offset or timezone information). This function handles daylight saving time changes.") + .with_syntax_example("to_local_time(expression)") + .with_argument( + "expression", + "Time expression to operate on. Can be a constant, column, or function." + ) + .with_sql_example(r#"```sql +> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp); ++---------------------------------------------+ +| to_local_time(Utf8("2024-04-01T00:00:20Z")) | ++---------------------------------------------+ +| 2024-04-01T00:00:20 | ++---------------------------------------------+ + +> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels'); ++---------------------------------------------+ +| to_local_time(Utf8("2024-04-01T00:00:20Z")) | ++---------------------------------------------+ +| 2024-04-01T00:00:20 | ++---------------------------------------------+ + +> SELECT + time, + arrow_typeof(time) as type, + to_local_time(time) as to_local_time, + arrow_typeof(to_local_time(time)) as to_local_time_type +FROM ( + SELECT '2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels' AS time +); ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ +| time | type | to_local_time | to_local_time_type | ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ +| 2024-04-01T00:00:20+02:00 | Timestamp(Nanosecond, Some("Europe/Brussels")) | 2024-04-01T00:00:20 | Timestamp(Nanosecond, None) | ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ + +# combine `to_local_time()` with `date_bin()` to bin on boundaries in the timezone rather +# than UTC boundaries + +> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AS date_bin; ++---------------------+ +| date_bin | ++---------------------+ +| 2024-04-01T00:00:00 | ++---------------------+ + +> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AT TIME ZONE 'Europe/Brussels' AS date_bin_with_timezone; ++---------------------------+ +| date_bin_with_timezone | ++---------------------------+ +| 2024-04-01T00:00:00+02:00 | ++---------------------------+ +```"#) + .build() + .unwrap() + }) } #[cfg(test)] diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index b17c9a005d1f..9479e25fe61f 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -16,7 +16,7 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::datatypes::DataType::*; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; @@ -25,10 +25,12 @@ use arrow::datatypes::{ TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; -use datafusion_common::{exec_err, Result, ScalarType}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - use crate::datetime::common::*; +use datafusion_common::{exec_err, Result, ScalarType}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct ToTimestampFunc { @@ -182,6 +184,50 @@ impl ScalarUDFImpl for ToTimestampFunc { } } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_timestamp_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_to_timestamp_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description(r#" +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00Z`). Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats] are provided. Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. + +Note: `to_timestamp` returns `Timestamp(Nanosecond)`. The supported range for integer input is between `-9223372037` and `9223372036`. Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. Please use `to_timestamp_seconds` for the input outside of supported bounds. +"#) + .with_syntax_example("to_timestamp(expression[, ..., format_n])") + .with_argument( + "expression", + "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ) + .with_argument( + "format_n", + "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned.", + ) + .with_sql_example(r#"```sql +> select to_timestamp('2023-01-31T09:26:56.123456789-05:00'); ++-----------------------------------------------------------+ +| to_timestamp(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++-----------------------------------------------------------+ +| 2023-01-31T14:26:56.123456789 | ++-----------------------------------------------------------+ +> select to_timestamp('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++--------------------------------------------------------------------------------------------------------+ +| to_timestamp(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++--------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123456789 | ++--------------------------------------------------------------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +"#) + .build() + .unwrap() + }) } impl ScalarUDFImpl for ToTimestampSecondsFunc { @@ -230,6 +276,46 @@ impl ScalarUDFImpl for ToTimestampSecondsFunc { } } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_timestamp_seconds_doc()) + } +} + +static TO_TIMESTAMP_SECONDS_DOC: OnceLock = OnceLock::new(); + +fn get_to_timestamp_seconds_doc() -> &'static Documentation { + TO_TIMESTAMP_SECONDS_DOC.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp.") + .with_syntax_example("to_timestamp_seconds(expression[, ..., format_n])") + .with_argument( + "expression", + "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ) + .with_argument( + "format_n", + "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned.", + ) + .with_sql_example(r#"```sql +> select to_timestamp_seconds('2023-01-31T09:26:56.123456789-05:00'); ++-------------------------------------------------------------------+ +| to_timestamp_seconds(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++-------------------------------------------------------------------+ +| 2023-01-31T14:26:56 | ++-------------------------------------------------------------------+ +> select to_timestamp_seconds('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++----------------------------------------------------------------------------------------------------------------+ +| to_timestamp_seconds(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++----------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00 | ++----------------------------------------------------------------------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +"#) + .build() + .unwrap() + }) } impl ScalarUDFImpl for ToTimestampMillisFunc { @@ -280,6 +366,46 @@ impl ScalarUDFImpl for ToTimestampMillisFunc { } } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_timestamp_millis_doc()) + } +} + +static TO_TIMESTAMP_MILLIS_DOC: OnceLock = OnceLock::new(); + +fn get_to_timestamp_millis_doc() -> &'static Documentation { + crate::datetime::to_timestamp::TO_TIMESTAMP_MILLIS_DOC.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. Integers and unsigned integers are interpreted as milliseconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp.") + .with_syntax_example("to_timestamp_millis(expression[, ..., format_n])") + .with_argument( + "expression", + "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ) + .with_argument( + "format_n", + "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned.", + ) + .with_sql_example(r#"```sql +> select to_timestamp_millis('2023-01-31T09:26:56.123456789-05:00'); ++------------------------------------------------------------------+ +| to_timestamp_millis(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++------------------------------------------------------------------+ +| 2023-01-31T14:26:56.123 | ++------------------------------------------------------------------+ +> select to_timestamp_millis('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++---------------------------------------------------------------------------------------------------------------+ +| to_timestamp_millis(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++---------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123 | ++---------------------------------------------------------------------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +"#) + .build() + .unwrap() + }) } impl ScalarUDFImpl for ToTimestampMicrosFunc { @@ -330,6 +456,46 @@ impl ScalarUDFImpl for ToTimestampMicrosFunc { } } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_timestamp_micros_doc()) + } +} + +static TO_TIMESTAMP_MICROS_DOC: OnceLock = OnceLock::new(); + +fn get_to_timestamp_micros_doc() -> &'static Documentation { + TO_TIMESTAMP_MICROS_DOC.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as microseconds since the unix epoch (`1970-01-01T00:00:00Z`) Returns the corresponding timestamp.") + .with_syntax_example("to_timestamp_micros(expression[, ..., format_n])") + .with_argument( + "expression", + "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ) + .with_argument( + "format_n", + "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned.", + ) + .with_sql_example(r#"```sql +> select to_timestamp_micros('2023-01-31T09:26:56.123456789-05:00'); ++------------------------------------------------------------------+ +| to_timestamp_micros(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++------------------------------------------------------------------+ +| 2023-01-31T14:26:56.123456 | ++------------------------------------------------------------------+ +> select to_timestamp_micros('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++---------------------------------------------------------------------------------------------------------------+ +| to_timestamp_micros(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++---------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123456 | ++---------------------------------------------------------------------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +"#) + .build() + .unwrap() + }) } impl ScalarUDFImpl for ToTimestampNanosFunc { @@ -380,6 +546,46 @@ impl ScalarUDFImpl for ToTimestampNanosFunc { } } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_timestamp_nanos_doc()) + } +} + +static TO_TIMESTAMP_NANOS_DOC: OnceLock = OnceLock::new(); + +fn get_to_timestamp_nanos_doc() -> &'static Documentation { + TO_TIMESTAMP_NANOS_DOC.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp.") + .with_syntax_example("to_timestamp_nanos(expression[, ..., format_n])") + .with_argument( + "expression", + "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ) + .with_argument( + "format_n", + "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned.", + ) + .with_sql_example(r#"```sql +> select to_timestamp_nanos('2023-01-31T09:26:56.123456789-05:00'); ++-----------------------------------------------------------------+ +| to_timestamp_nanos(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++-----------------------------------------------------------------+ +| 2023-01-31T14:26:56.123456789 | ++-----------------------------------------------------------------+ +> select to_timestamp_nanos('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++--------------------------------------------------------------------------------------------------------------+ +| to_timestamp_nanos(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++--------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123456789 | ++---------------------------------------------------------------------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +"#) + .build() + .unwrap() + }) } /// Returns the return type for the to_timestamp_* function, preserving diff --git a/datafusion/functions/src/datetime/to_unixtime.rs b/datafusion/functions/src/datetime/to_unixtime.rs index 396dadccb4b3..10f0f87a4ab1 100644 --- a/datafusion/functions/src/datetime/to_unixtime.rs +++ b/datafusion/functions/src/datetime/to_unixtime.rs @@ -15,15 +15,17 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - use arrow::datatypes::{DataType, TimeUnit}; +use std::any::Any; +use std::sync::OnceLock; +use super::to_timestamp::ToTimestampSecondsFunc; use crate::datetime::common::*; use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - -use super::to_timestamp::ToTimestampSecondsFunc; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct ToUnixtimeFunc { @@ -86,4 +88,42 @@ impl ScalarUDFImpl for ToUnixtimeFunc { } } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_unixtime_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_to_unixtime_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Converts a value to seconds since the unix epoch (`1970-01-01T00:00:00Z`). Supports strings, dates, timestamps and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided.") + .with_syntax_example("to_unixtime(expression[, ..., format_n])") + .with_argument( + "expression", + "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ).with_argument( + "format_n", + "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned.") + .with_sql_example(r#" +```sql +> select to_unixtime('2020-09-08T12:00:00+00:00'); ++------------------------------------------------+ +| to_unixtime(Utf8("2020-09-08T12:00:00+00:00")) | ++------------------------------------------------+ +| 1599566400 | ++------------------------------------------------+ +> select to_unixtime('01-14-2023 01:01:30+05:30', '%q', '%d-%m-%Y %H/%M/%S', '%+', '%m-%d-%Y %H:%M:%S%#z'); ++-----------------------------------------------------------------------------------------------------------------------------+ +| to_unixtime(Utf8("01-14-2023 01:01:30+05:30"),Utf8("%q"),Utf8("%d-%m-%Y %H/%M/%S"),Utf8("%+"),Utf8("%m-%d-%Y %H:%M:%S%#z")) | ++-----------------------------------------------------------------------------------------------------------------------------+ +| 1673638290 | ++-----------------------------------------------------------------------------------------------------------------------------+ +``` +"#) + .build() + .unwrap() + }) } diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index d1f816898d93..547ea108080e 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -51,175 +51,7 @@ position(substr in origstr) ## Time and Date Functions -- [now](#now) -- [current_date](#current_date) -- [current_time](#current_time) -- [date_bin](#date_bin) -- [date_trunc](#date_trunc) -- [datetrunc](#datetrunc) -- [date_part](#date_part) -- [datepart](#datepart) - [extract](#extract) -- [today](#today) -- [make_date](#make_date) -- [to_char](#to_char) -- [to_local_time](#to_local_time) -- [to_timestamp](#to_timestamp) -- [to_timestamp_millis](#to_timestamp_millis) -- [to_timestamp_micros](#to_timestamp_micros) -- [to_timestamp_seconds](#to_timestamp_seconds) -- [to_timestamp_nanos](#to_timestamp_nanos) -- [from_unixtime](#from_unixtime) -- [to_unixtime](#to_unixtime) - -### `now` - -Returns the current UTC timestamp. - -The `now()` return value is determined at query time and will return the same timestamp, -no matter when in the query plan the function executes. - -``` -now() -``` - -### `current_date` - -Returns the current UTC date. - -The `current_date()` return value is determined at query time and will return the same date, -no matter when in the query plan the function executes. - -``` -current_date() -``` - -#### Aliases - -- today - -### `today` - -_Alias of [current_date](#current_date)._ - -### `current_time` - -Returns the current UTC time. - -The `current_time()` return value is determined at query time and will return the same time, -no matter when in the query plan the function executes. - -``` -current_time() -``` - -### `date_bin` - -Calculates time intervals and returns the start of the interval nearest to the specified timestamp. -Use `date_bin` to downsample time series data by grouping rows into time-based "bins" or "windows" -and applying an aggregate or selector function to each window. - -For example, if you "bin" or "window" data into 15 minute intervals, an input -timestamp of `2023-01-01T18:18:18Z` will be updated to the start time of the 15 -minute bin it is in: `2023-01-01T18:15:00Z`. - -``` -date_bin(interval, expression, origin-timestamp) -``` - -#### Arguments - -- **interval**: Bin interval. -- **expression**: Time expression to operate on. - Can be a constant, column, or function. -- **origin-timestamp**: Optional. Starting point used to determine bin boundaries. If not specified - defaults `1970-01-01T00:00:00Z` (the UNIX epoch in UTC). - -The following intervals are supported: - -- nanoseconds -- microseconds -- milliseconds -- seconds -- minutes -- hours -- days -- weeks -- months -- years -- century - -### `date_trunc` - -Truncates a timestamp value to a specified precision. - -``` -date_trunc(precision, expression) -``` - -#### Arguments - -- **precision**: Time precision to truncate to. - The following precisions are supported: - - - year / YEAR - - quarter / QUARTER - - month / MONTH - - week / WEEK - - day / DAY - - hour / HOUR - - minute / MINUTE - - second / SECOND - -- **expression**: Time expression to operate on. - Can be a constant, column, or function. - -#### Aliases - -- datetrunc - -### `datetrunc` - -_Alias of [date_trunc](#date_trunc)._ - -### `date_part` - -Returns the specified part of the date as an integer. - -``` -date_part(part, expression) -``` - -#### Arguments - -- **part**: Part of the date to return. - The following date parts are supported: - - - year - - quarter _(emits value in inclusive range [1, 4] based on which quartile of the year the date is in)_ - - month - - week _(week of the year)_ - - day _(day of the month)_ - - hour - - minute - - second - - millisecond - - microsecond - - nanosecond - - dow _(day of the week)_ - - doy _(day of the year)_ - - epoch _(seconds since Unix epoch)_ - -- **expression**: Time expression to operate on. - Can be a constant, column, or function. - -#### Aliases - -- datepart - -### `datepart` - -_Alias of [date_part](#date_part)._ ### `extract` @@ -238,394 +70,10 @@ date_part('day', '2024-04-13'::date) See [date_part](#date_part). -### `make_date` - -Make a date from year/month/day component parts. - -``` -make_date(year, month, day) -``` - -#### Arguments - -- **year**: Year to use when making the date. - Can be a constant, column or function, and any combination of arithmetic operators. -- **month**: Month to use when making the date. - Can be a constant, column or function, and any combination of arithmetic operators. -- **day**: Day to use when making the date. - Can be a constant, column or function, and any combination of arithmetic operators. - -#### Example - -``` -> select make_date(2023, 1, 31); -+-------------------------------------------+ -| make_date(Int64(2023),Int64(1),Int64(31)) | -+-------------------------------------------+ -| 2023-01-31 | -+-------------------------------------------+ -> select make_date('2023', '01', '31'); -+-----------------------------------------------+ -| make_date(Utf8("2023"),Utf8("01"),Utf8("31")) | -+-----------------------------------------------+ -| 2023-01-31 | -+-----------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/make_date.rs) - -### `to_char` - -Returns a string representation of a date, time, timestamp or duration based -on a [Chrono format]. Unlike the PostgreSQL equivalent of this function -numerical formatting is not supported. - -``` -to_char(expression, format) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function that results in a - date, time, timestamp or duration. -- **format**: A [Chrono format] string to use to convert the expression. - -#### Example - -``` -> select to_char('2023-03-01'::date, '%d-%m-%Y'); -+----------------------------------------------+ -| to_char(Utf8("2023-03-01"),Utf8("%d-%m-%Y")) | -+----------------------------------------------+ -| 01-03-2023 | -+----------------------------------------------+ -``` - -Additional examples can be found [here] - -[here]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_char.rs - -#### Aliases - -- date_format - -### `to_local_time` - -Converts a timestamp with a timezone to a timestamp without a timezone (with no offset or -timezone information). This function handles daylight saving time changes. - -``` -to_local_time(expression) -``` - -#### Arguments - -- **expression**: Time expression to operate on. Can be a constant, column, or function. - -#### Example - -``` -> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp); -+---------------------------------------------+ -| to_local_time(Utf8("2024-04-01T00:00:20Z")) | -+---------------------------------------------+ -| 2024-04-01T00:00:20 | -+---------------------------------------------+ - -> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels'); -+---------------------------------------------+ -| to_local_time(Utf8("2024-04-01T00:00:20Z")) | -+---------------------------------------------+ -| 2024-04-01T00:00:20 | -+---------------------------------------------+ - -> SELECT - time, - arrow_typeof(time) as type, - to_local_time(time) as to_local_time, - arrow_typeof(to_local_time(time)) as to_local_time_type -FROM ( - SELECT '2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels' AS time -); -+---------------------------+------------------------------------------------+---------------------+-----------------------------+ -| time | type | to_local_time | to_local_time_type | -+---------------------------+------------------------------------------------+---------------------+-----------------------------+ -| 2024-04-01T00:00:20+02:00 | Timestamp(Nanosecond, Some("Europe/Brussels")) | 2024-04-01T00:00:20 | Timestamp(Nanosecond, None) | -+---------------------------+------------------------------------------------+---------------------+-----------------------------+ - -# combine `to_local_time()` with `date_bin()` to bin on boundaries in the timezone rather -# than UTC boundaries - -> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AS date_bin; -+---------------------+ -| date_bin | -+---------------------+ -| 2024-04-01T00:00:00 | -+---------------------+ - -> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AT TIME ZONE 'Europe/Brussels' AS date_bin_with_timezone; -+---------------------------+ -| date_bin_with_timezone | -+---------------------------+ -| 2024-04-01T00:00:00+02:00 | -+---------------------------+ -``` - -### `to_timestamp` - -Converts a value to a timestamp (`YYYY-MM-DDT00:00:00Z`). -Supports strings, integer, unsigned integer, and double types as input. -Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats] are provided. -Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). -Returns the corresponding timestamp. - -Note: `to_timestamp` returns `Timestamp(Nanosecond)`. The supported range for integer input is between `-9223372037` and `9223372036`. -Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. Please use `to_timestamp_seconds` -for the input outside of supported bounds. - -``` -to_timestamp(expression[, ..., format_n]) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format] strings to use to parse the expression. Formats will be tried in the order - they appear with the first successful one being returned. If none of the formats successfully parse the expression - an error will be returned. - -[chrono format]: https://docs.rs/chrono/latest/chrono/format/strftime/index.html - -#### Example - -``` -> select to_timestamp('2023-01-31T09:26:56.123456789-05:00'); -+-----------------------------------------------------------+ -| to_timestamp(Utf8("2023-01-31T09:26:56.123456789-05:00")) | -+-----------------------------------------------------------+ -| 2023-01-31T14:26:56.123456789 | -+-----------------------------------------------------------+ -> select to_timestamp('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); -+--------------------------------------------------------------------------------------------------------+ -| to_timestamp(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | -+--------------------------------------------------------------------------------------------------------+ -| 2023-05-17T03:59:00.123456789 | -+--------------------------------------------------------------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) - -### `to_timestamp_millis` - -Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). -Supports strings, integer, and unsigned integer types as input. -Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format]s are provided. -Integers and unsigned integers are interpreted as milliseconds since the unix epoch (`1970-01-01T00:00:00Z`). -Returns the corresponding timestamp. - -``` -to_timestamp_millis(expression[, ..., format_n]) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format] strings to use to parse the expression. Formats will be tried in the order - they appear with the first successful one being returned. If none of the formats successfully parse the expression - an error will be returned. - -#### Example - -``` -> select to_timestamp_millis('2023-01-31T09:26:56.123456789-05:00'); -+------------------------------------------------------------------+ -| to_timestamp_millis(Utf8("2023-01-31T09:26:56.123456789-05:00")) | -+------------------------------------------------------------------+ -| 2023-01-31T14:26:56.123 | -+------------------------------------------------------------------+ -> select to_timestamp_millis('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); -+---------------------------------------------------------------------------------------------------------------+ -| to_timestamp_millis(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | -+---------------------------------------------------------------------------------------------------------------+ -| 2023-05-17T03:59:00.123 | -+---------------------------------------------------------------------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) - -### `to_timestamp_micros` - -Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000Z`). -Supports strings, integer, and unsigned integer types as input. -Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format]s are provided. -Integers and unsigned integers are interpreted as microseconds since the unix epoch (`1970-01-01T00:00:00Z`) -Returns the corresponding timestamp. - -``` -to_timestamp_micros(expression[, ..., format_n]) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format] strings to use to parse the expression. Formats will be tried in the order - they appear with the first successful one being returned. If none of the formats successfully parse the expression - an error will be returned. - -#### Example - -``` -> select to_timestamp_micros('2023-01-31T09:26:56.123456789-05:00'); -+------------------------------------------------------------------+ -| to_timestamp_micros(Utf8("2023-01-31T09:26:56.123456789-05:00")) | -+------------------------------------------------------------------+ -| 2023-01-31T14:26:56.123456 | -+------------------------------------------------------------------+ -> select to_timestamp_micros('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); -+---------------------------------------------------------------------------------------------------------------+ -| to_timestamp_micros(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | -+---------------------------------------------------------------------------------------------------------------+ -| 2023-05-17T03:59:00.123456 | -+---------------------------------------------------------------------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) - -### `to_timestamp_nanos` - -Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000000Z`). -Supports strings, integer, and unsigned integer types as input. -Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format]s are provided. -Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`). -Returns the corresponding timestamp. - -``` -to_timestamp_nanos(expression[, ..., format_n]) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format] strings to use to parse the expression. Formats will be tried in the order - they appear with the first successful one being returned. If none of the formats successfully parse the expression - an error will be returned. - -#### Example - -``` -> select to_timestamp_nanos('2023-01-31T09:26:56.123456789-05:00'); -+-----------------------------------------------------------------+ -| to_timestamp_nanos(Utf8("2023-01-31T09:26:56.123456789-05:00")) | -+-----------------------------------------------------------------+ -| 2023-01-31T14:26:56.123456789 | -+-----------------------------------------------------------------+ -> select to_timestamp_nanos('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); -+--------------------------------------------------------------------------------------------------------------+ -| to_timestamp_nanos(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | -+--------------------------------------------------------------------------------------------------------------+ -| 2023-05-17T03:59:00.123456789 | -+---------------------------------------------------------------------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) - -### `to_timestamp_seconds` - -Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). -Supports strings, integer, and unsigned integer types as input. -Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format]s are provided. -Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). -Returns the corresponding timestamp. - -``` -to_timestamp_seconds(expression[, ..., format_n]) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format] strings to use to parse the expression. Formats will be tried in the order - they appear with the first successful one being returned. If none of the formats successfully parse the expression - an error will be returned. - -#### Example - -``` -> select to_timestamp_seconds('2023-01-31T09:26:56.123456789-05:00'); -+-------------------------------------------------------------------+ -| to_timestamp_seconds(Utf8("2023-01-31T09:26:56.123456789-05:00")) | -+-------------------------------------------------------------------+ -| 2023-01-31T14:26:56 | -+-------------------------------------------------------------------+ -> select to_timestamp_seconds('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); -+----------------------------------------------------------------------------------------------------------------+ -| to_timestamp_seconds(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | -+----------------------------------------------------------------------------------------------------------------+ -| 2023-05-17T03:59:00 | -+----------------------------------------------------------------------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) - -### `from_unixtime` - -Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). -Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`) -return the corresponding timestamp. - -``` -from_unixtime(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `to_unixtime` - -Converts a value to seconds since the unix epoch (`1970-01-01T00:00:00Z`). -Supports strings, dates, timestamps and double types as input. -Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats] are provided. - -``` -to_unixtime(expression[, ..., format_n]) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format] strings to use to parse the expression. Formats will be tried in the order - they appear with the first successful one being returned. If none of the formats successfully parse the expression - an error will be returned. - -#### Example - -``` -> select to_unixtime('2020-09-08T12:00:00+00:00'); -+------------------------------------------------+ -| to_unixtime(Utf8("2020-09-08T12:00:00+00:00")) | -+------------------------------------------------+ -| 1599566400 | -+------------------------------------------------+ -> select to_unixtime('01-14-2023 01:01:30+05:30', '%q', '%d-%m-%Y %H/%M/%S', '%+', '%m-%d-%Y %H:%M:%S%#z'); -+-----------------------------------------------------------------------------------------------------------------------------+ -| to_unixtime(Utf8("01-14-2023 01:01:30+05:30"),Utf8("%q"),Utf8("%d-%m-%Y %H/%M/%S"),Utf8("%+"),Utf8("%m-%d-%Y %H:%M:%S%#z")) | -+-----------------------------------------------------------------------------------------------------------------------------+ -| 1673638290 | -+-----------------------------------------------------------------------------------------------------------------------------+ -``` - ## Array Functions - [unnest](#unnest) +- [range](#range) ### `unnest` @@ -669,11 +117,60 @@ Transforms an array into rows. +-----------------------------------+ ``` +### `range` + +Returns an Arrow array between start and stop with step. `SELECT range(2, 10, 3) -> [2, 5, 8]` or +`SELECT range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH);` + +The range start..end contains all values with start <= x < end. It is empty if start >= end. + +Step can not be 0 (then the range will be nonsense.). + +Note that when the required range is a number, it accepts (stop), (start, stop), and (start, stop, step) as parameters, +but when the required range is a date or timestamp, it must be 3 non-NULL parameters. +For example, + +``` +SELECT range(3); +SELECT range(1,5); +SELECT range(1,5,1); +``` + +are allowed in number ranges + +but in date and timestamp ranges, only + +``` +SELECT range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH); +SELECT range(TIMESTAMP '1992-09-01', TIMESTAMP '1993-03-01', INTERVAL '1' MONTH); +``` + +is allowed, and + +``` +SELECT range(DATE '1992-09-01', DATE '1993-03-01', NULL); +SELECT range(NULL, DATE '1993-03-01', INTERVAL '1' MONTH); +SELECT range(DATE '1992-09-01', NULL, INTERVAL '1' MONTH); +``` + +are not allowed + +#### Arguments + +- **start**: start of the range. Ints, timestamps, dates or string types that can be coerced to Date32 are supported. +- **end**: end of the range (not included). Type must be the same as start. +- **step**: increase by step (can not be 0). Steps less than a day are supported only for timestamp ranges. + +#### Aliases + +- generate_series + ## Struct Functions - [unnest](#unnest-struct) -For more struct functions see the new documentation [`here`](https://datafusion.apache.org/user-guide/sql/scalar_functions_new.html) +For more struct functions see the new documentation [ +`here`](https://datafusion.apache.org/user-guide/sql/scalar_functions_new.html) ### `unnest (struct)` diff --git a/docs/source/user-guide/sql/scalar_functions_new.md b/docs/source/user-guide/sql/scalar_functions_new.md index 0a073db543b0..8f1e30f1fa53 100644 --- a/docs/source/user-guide/sql/scalar_functions_new.md +++ b/docs/source/user-guide/sql/scalar_functions_new.md @@ -1800,7 +1800,239 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo ## Time and Date Functions +- [current_date](#current_date) +- [current_time](#current_time) +- [date_bin](#date_bin) +- [date_format](#date_format) +- [date_part](#date_part) +- [date_trunc](#date_trunc) +- [datepart](#datepart) +- [datetrunc](#datetrunc) +- [from_unixtime](#from_unixtime) +- [make_date](#make_date) +- [now](#now) +- [to_char](#to_char) - [to_date](#to_date) +- [to_local_time](#to_local_time) +- [to_timestamp](#to_timestamp) +- [to_timestamp_micros](#to_timestamp_micros) +- [to_timestamp_millis](#to_timestamp_millis) +- [to_timestamp_nanos](#to_timestamp_nanos) +- [to_timestamp_seconds](#to_timestamp_seconds) +- [to_unixtime](#to_unixtime) +- [today](#today) + +### `current_date` + +Returns the current UTC date. + +The `current_date()` return value is determined at query time and will return the same date, no matter when in the query plan the function executes. + +``` +current_date() +``` + +#### Aliases + +- today + +### `current_time` + +Returns the current UTC time. + +The `current_time()` return value is determined at query time and will return the same time, no matter when in the query plan the function executes. + +``` +current_time() +``` + +### `date_bin` + +Calculates time intervals and returns the start of the interval nearest to the specified timestamp. Use `date_bin` to downsample time series data by grouping rows into time-based "bins" or "windows" and applying an aggregate or selector function to each window. + +For example, if you "bin" or "window" data into 15 minute intervals, an input timestamp of `2023-01-01T18:18:18Z` will be updated to the start time of the 15 minute bin it is in: `2023-01-01T18:15:00Z`. + +``` +date_bin(interval, expression, origin-timestamp) +``` + +#### Arguments + +- **interval**: Bin interval. +- **expression**: Time expression to operate on. Can be a constant, column, or function. +- **origin-timestamp**: Optional. Starting point used to determine bin boundaries. If not specified defaults 1970-01-01T00:00:00Z (the UNIX epoch in UTC). + +The following intervals are supported: + +- nanoseconds +- microseconds +- milliseconds +- seconds +- minutes +- hours +- days +- weeks +- months +- years +- century + +### `date_format` + +_Alias of [to_char](#to_char)._ + +### `date_part` + +Returns the specified part of the date as an integer. + +``` +date_part(part, expression) +``` + +#### Arguments + +- **part**: Part of the date to return. The following date parts are supported: + + - year + - quarter (emits value in inclusive range [1, 4] based on which quartile of the year the date is in) + - month + - week (week of the year) + - day (day of the month) + - hour + - minute + - second + - millisecond + - microsecond + - nanosecond + - dow (day of the week) + - doy (day of the year) + - epoch (seconds since Unix epoch) + +- **expression**: Time expression to operate on. Can be a constant, column, or function. + +#### Aliases + +- datepart + +### `date_trunc` + +Truncates a timestamp value to a specified precision. + +``` +date_trunc(precision, expression) +``` + +#### Arguments + +- **precision**: Time precision to truncate to. The following precisions are supported: + + - year / YEAR + - quarter / QUARTER + - month / MONTH + - week / WEEK + - day / DAY + - hour / HOUR + - minute / MINUTE + - second / SECOND + +- **expression**: Time expression to operate on. Can be a constant, column, or function. + +#### Aliases + +- datetrunc + +### `datepart` + +_Alias of [date_part](#date_part)._ + +### `datetrunc` + +_Alias of [date_trunc](#date_trunc)._ + +### `from_unixtime` + +Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`) return the corresponding timestamp. + +``` +from_unixtime(expression) +``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. + +### `make_date` + +Make a date from year/month/day component parts. + +``` +make_date(year, month, day) +``` + +#### Arguments + +- **year**: Year to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators. +- **month**: Month to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators. +- **day**: Day to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators. + +#### Example + +```sql +> select make_date(2023, 1, 31); ++-------------------------------------------+ +| make_date(Int64(2023),Int64(1),Int64(31)) | ++-------------------------------------------+ +| 2023-01-31 | ++-------------------------------------------+ +> select make_date('2023', '01', '31'); ++-----------------------------------------------+ +| make_date(Utf8("2023"),Utf8("01"),Utf8("31")) | ++-----------------------------------------------+ +| 2023-01-31 | ++-----------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/make_date.rs) + +### `now` + +Returns the current UTC timestamp. + +The `now()` return value is determined at query time and will return the same timestamp, no matter when in the query plan the function executes. + +``` +now() +``` + +### `to_char` + +Returns a string representation of a date, time, timestamp or duration based on a [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html). Unlike the PostgreSQL equivalent of this function numerical formatting is not supported. + +``` +to_char(expression, format) +``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function that results in a date, time, timestamp or duration. +- **format**: A [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) string to use to convert the expression. +- **day**: Day to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators. + +#### Example + +```sql +> select to_char('2023-03-01'::date, '%d-%m-%Y'); ++----------------------------------------------+ +| to_char(Utf8("2023-03-01"),Utf8("%d-%m-%Y")) | ++----------------------------------------------+ +| 01-03-2023 | ++----------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_char.rs) + +#### Aliases + +- date_format ### `to_date` @@ -1842,6 +2074,263 @@ to_date('2017-05-31', '%Y-%m-%d') Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_date.rs) +### `to_local_time` + +Converts a timestamp with a timezone to a timestamp without a timezone (with no offset or timezone information). This function handles daylight saving time changes. + +``` +to_local_time(expression) +``` + +#### Arguments + +- **expression**: Time expression to operate on. Can be a constant, column, or function. + +#### Example + +```sql +> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp); ++---------------------------------------------+ +| to_local_time(Utf8("2024-04-01T00:00:20Z")) | ++---------------------------------------------+ +| 2024-04-01T00:00:20 | ++---------------------------------------------+ + +> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels'); ++---------------------------------------------+ +| to_local_time(Utf8("2024-04-01T00:00:20Z")) | ++---------------------------------------------+ +| 2024-04-01T00:00:20 | ++---------------------------------------------+ + +> SELECT + time, + arrow_typeof(time) as type, + to_local_time(time) as to_local_time, + arrow_typeof(to_local_time(time)) as to_local_time_type +FROM ( + SELECT '2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels' AS time +); ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ +| time | type | to_local_time | to_local_time_type | ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ +| 2024-04-01T00:00:20+02:00 | Timestamp(Nanosecond, Some("Europe/Brussels")) | 2024-04-01T00:00:20 | Timestamp(Nanosecond, None) | ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ + +# combine `to_local_time()` with `date_bin()` to bin on boundaries in the timezone rather +# than UTC boundaries + +> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AS date_bin; ++---------------------+ +| date_bin | ++---------------------+ +| 2024-04-01T00:00:00 | ++---------------------+ + +> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AT TIME ZONE 'Europe/Brussels' AS date_bin_with_timezone; ++---------------------------+ +| date_bin_with_timezone | ++---------------------------+ +| 2024-04-01T00:00:00+02:00 | ++---------------------------+ +``` + +### `to_timestamp` + +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00Z`). Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats] are provided. Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. + +Note: `to_timestamp` returns `Timestamp(Nanosecond)`. The supported range for integer input is between `-9223372037` and `9223372036`. Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. Please use `to_timestamp_seconds` for the input outside of supported bounds. + +``` +to_timestamp(expression[, ..., format_n]) +``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. + +#### Example + +```sql +> select to_timestamp('2023-01-31T09:26:56.123456789-05:00'); ++-----------------------------------------------------------+ +| to_timestamp(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++-----------------------------------------------------------+ +| 2023-01-31T14:26:56.123456789 | ++-----------------------------------------------------------+ +> select to_timestamp('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++--------------------------------------------------------------------------------------------------------+ +| to_timestamp(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++--------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123456789 | ++--------------------------------------------------------------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) + +### `to_timestamp_micros` + +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as microseconds since the unix epoch (`1970-01-01T00:00:00Z`) Returns the corresponding timestamp. + +``` +to_timestamp_micros(expression[, ..., format_n]) +``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. + +#### Example + +```sql +> select to_timestamp_micros('2023-01-31T09:26:56.123456789-05:00'); ++------------------------------------------------------------------+ +| to_timestamp_micros(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++------------------------------------------------------------------+ +| 2023-01-31T14:26:56.123456 | ++------------------------------------------------------------------+ +> select to_timestamp_micros('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++---------------------------------------------------------------------------------------------------------------+ +| to_timestamp_micros(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++---------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123456 | ++---------------------------------------------------------------------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) + +### `to_timestamp_millis` + +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. Integers and unsigned integers are interpreted as milliseconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. + +``` +to_timestamp_millis(expression[, ..., format_n]) +``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. + +#### Example + +```sql +> select to_timestamp_millis('2023-01-31T09:26:56.123456789-05:00'); ++------------------------------------------------------------------+ +| to_timestamp_millis(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++------------------------------------------------------------------+ +| 2023-01-31T14:26:56.123 | ++------------------------------------------------------------------+ +> select to_timestamp_millis('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++---------------------------------------------------------------------------------------------------------------+ +| to_timestamp_millis(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++---------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123 | ++---------------------------------------------------------------------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) + +### `to_timestamp_nanos` + +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. + +``` +to_timestamp_nanos(expression[, ..., format_n]) +``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. + +#### Example + +```sql +> select to_timestamp_nanos('2023-01-31T09:26:56.123456789-05:00'); ++-----------------------------------------------------------------+ +| to_timestamp_nanos(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++-----------------------------------------------------------------+ +| 2023-01-31T14:26:56.123456789 | ++-----------------------------------------------------------------+ +> select to_timestamp_nanos('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++--------------------------------------------------------------------------------------------------------------+ +| to_timestamp_nanos(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++--------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123456789 | ++---------------------------------------------------------------------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) + +### `to_timestamp_seconds` + +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. + +``` +to_timestamp_seconds(expression[, ..., format_n]) +``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. + +#### Example + +```sql +> select to_timestamp_seconds('2023-01-31T09:26:56.123456789-05:00'); ++-------------------------------------------------------------------+ +| to_timestamp_seconds(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++-------------------------------------------------------------------+ +| 2023-01-31T14:26:56 | ++-------------------------------------------------------------------+ +> select to_timestamp_seconds('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++----------------------------------------------------------------------------------------------------------------+ +| to_timestamp_seconds(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++----------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00 | ++----------------------------------------------------------------------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) + +### `to_unixtime` + +Converts a value to seconds since the unix epoch (`1970-01-01T00:00:00Z`). Supports strings, dates, timestamps and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. + +``` +to_unixtime(expression[, ..., format_n]) +``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. + +#### Example + +```sql +> select to_unixtime('2020-09-08T12:00:00+00:00'); ++------------------------------------------------+ +| to_unixtime(Utf8("2020-09-08T12:00:00+00:00")) | ++------------------------------------------------+ +| 1599566400 | ++------------------------------------------------+ +> select to_unixtime('01-14-2023 01:01:30+05:30', '%q', '%d-%m-%Y %H/%M/%S', '%+', '%m-%d-%Y %H:%M:%S%#z'); ++-----------------------------------------------------------------------------------------------------------------------------+ +| to_unixtime(Utf8("01-14-2023 01:01:30+05:30"),Utf8("%q"),Utf8("%d-%m-%Y %H/%M/%S"),Utf8("%+"),Utf8("%m-%d-%Y %H:%M:%S%#z")) | ++-----------------------------------------------------------------------------------------------------------------------------+ +| 1673638290 | ++-----------------------------------------------------------------------------------------------------------------------------+ +``` + +### `today` + +_Alias of [current_date](#current_date)._ + ## Array Functions - [array_any_value](#array_any_value) From 700b07fd64b96e3f66ef01dce13dcef7c8588437 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 17 Oct 2024 20:12:37 -0400 Subject: [PATCH 013/110] Fix CI / regenerate functions (#12991) --- docs/source/user-guide/sql/scalar_functions_new.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/source/user-guide/sql/scalar_functions_new.md b/docs/source/user-guide/sql/scalar_functions_new.md index 8f1e30f1fa53..ffc2b680b5c5 100644 --- a/docs/source/user-guide/sql/scalar_functions_new.md +++ b/docs/source/user-guide/sql/scalar_functions_new.md @@ -1802,6 +1802,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo - [current_date](#current_date) - [current_time](#current_time) +- [current_timestamp](#current_timestamp) - [date_bin](#date_bin) - [date_format](#date_format) - [date_part](#date_part) @@ -1846,6 +1847,10 @@ The `current_time()` return value is determined at query time and will return th current_time() ``` +### `current_timestamp` + +_Alias of [now](#now)._ + ### `date_bin` Calculates time intervals and returns the start of the interval nearest to the specified timestamp. Use `date_bin` to downsample time series data by grouping rows into time-based "bins" or "windows" and applying an aggregate or selector function to each window. @@ -2003,6 +2008,10 @@ The `now()` return value is determined at query time and will return the same ti now() ``` +#### Aliases + +- current_timestamp + ### `to_char` Returns a string representation of a date, time, timestamp or duration based on a [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html). Unlike the PostgreSQL equivalent of this function numerical formatting is not supported. From efe5708978a480d11d5406a7d7df76d73e15c5d7 Mon Sep 17 00:00:00 2001 From: jcsherin Date: Fri, 18 Oct 2024 16:56:41 +0530 Subject: [PATCH 014/110] Convert `BuiltInWindowFunction::{Lead, Lag}` to a user defined window function (#12857) * Move `lead-lag` to `functions-window` package * Builds with warnings * Adds `PartitionEvaluatorArgs` * Extracts `shift_offset` from input expressions * Computes shift offset * Get default value from input expression * Implements `partition_evaluator` * Fixes compiler warnings * Comments out failing tests * Fixes `cargo test` errors and warnings * Minor: taplo formatting * Delete code * Define `lead`, `lag` user-defined window functions * Fixes `cargo build` errors * Export udwf and expression public APIs * Mark result field as nullable * Delete `return_type` tests for `lead` and `lag` * Disables test: window function case insensitive * Fixes: lowercase name in logical plan * Reverts to old methods for computing `shift_offset`, `default_value` * Implements expression reversal * Fixes: lowercase name in logical plans * Fixes: doc test compilation errors Fixes: doc test build errors * Temporarily quite clippy errors * Fixes proto defintion * Minor: fixes formatting * Fixes: doc tests * Uses macro for defining `lag_udwf()` and `leag_udwf()` * Fixes: window fuzz test cases * Copies doc comments verbatim from `BuiltInWindowFunction` enum * Deletes from window function case insensitive test * Deletes `BuiltInWindowFunction` expression APIs * Delete from `create_built_in_window_expr` * Deletes proto serialization * Delete from `BuiltInWindowFunction` enum * Deletes test for finding built-in window function * Fixes build errors + deletes redundant code * Deletes more code * Delete unnecessary structs * Refactors shift offset computation * Passes range unit test * Fixes: clippy::get-first error * Rewrite unit tests for WindowUDF * Fixes: unit test for lag with default value * Consistent input expressions and data types in unit tests * Minor: fixes formatting * Restore original helper method for unit tests * Revert "Refactors shift offset computation" This reverts commit 000ceb76409e66230f9c5017a30fa3c9bb1e6575. * Moves helper functions into `functions-window-common` package * Uses common helper functions in `{lead, lag}` * Minor: formatting * Revert "Moves helper functions into `functions-window-common` package" This reverts commit ab8a83c9c11ca3a245278f6f300438feaacb0978. * Moves common functions to utils * Minor: formatting fixes * Update lowercase names in explain output * Adds doc for `lead()` and `lag()` expression functions * Add doc for `WindowShiftKind::shift_offset` * Remove `arrow` dev dependency * Minor: formatting * Update inner doc comment * Serialize 1 or more window function arguments * Adds logical plan roundtrip test cases * Refactor: readability of unit tests * Minor: rename variable bindings * Minor: copy edit * Revert "Remove `arrow` dev dependency" This reverts commit 3eb09856c8ec4ddce20472deee2df590c2fd3f35. * Move null argument handling helper to utils * Disable failing sqllogic tests for handling NULL input * Revert "Disable failing sqllogic tests for handling NULL input" This reverts commit 270a2030637012d549c001e973a0a1bb6b3d4dd0. * Fixes: incorrect NULL handling in `lead`/`lag` window function * Adds more tests cases --------- Co-authored-by: Andrew Lamb --- datafusion-cli/Cargo.lock | 1 + .../core/tests/fuzz_cases/window_fuzz.rs | 13 +- .../expr/src/built_in_window_function.rs | 32 +- datafusion/expr/src/expr.rs | 38 -- datafusion/expr/src/udwf.rs | 23 + datafusion/expr/src/window_function.rs | 34 -- .../functions-window-common/src/expr.rs | 64 +++ datafusion/functions-window-common/src/lib.rs | 1 + datafusion/functions-window/Cargo.toml | 1 + .../src}/lead_lag.rs | 392 ++++++++++++------ datafusion/functions-window/src/lib.rs | 8 + datafusion/functions-window/src/utils.rs | 53 +++ .../physical-expr/src/expressions/mod.rs | 1 - datafusion/physical-expr/src/window/mod.rs | 1 - datafusion/physical-plan/src/windows/mod.rs | 88 +--- datafusion/proto/proto/datafusion.proto | 6 +- datafusion/proto/src/generated/pbjson.rs | 30 +- datafusion/proto/src/generated/prost.rs | 14 +- .../proto/src/logical_plan/from_proto.rs | 17 +- datafusion/proto/src/logical_plan/to_proto.rs | 14 +- .../proto/src/physical_plan/to_proto.rs | 20 - .../tests/cases/roundtrip_logical_plan.rs | 12 +- datafusion/sqllogictest/test_files/union.slt | 8 +- datafusion/sqllogictest/test_files/window.slt | 56 ++- 24 files changed, 520 insertions(+), 407 deletions(-) create mode 100644 datafusion/functions-window-common/src/expr.rs rename datafusion/{physical-expr/src/window => functions-window/src}/lead_lag.rs (59%) create mode 100644 datafusion/functions-window/src/utils.rs diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index aa64e14fca8e..dfd07a7658ff 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1445,6 +1445,7 @@ dependencies = [ "datafusion-common", "datafusion-expr", "datafusion-functions-window-common", + "datafusion-physical-expr", "datafusion-physical-expr-common", "log", "paste", diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 4a33334770a0..d649919f1b6a 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -45,6 +45,7 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use test_utils::add_empty_batches; use datafusion::functions_window::row_number::row_number_udwf; +use datafusion_functions_window::lead_lag::{lag_udwf, lead_udwf}; use datafusion_functions_window::rank::{dense_rank_udwf, rank_udwf}; use hashbrown::HashMap; use rand::distributions::Alphanumeric; @@ -197,7 +198,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { // ) ( // Window function - WindowFunctionDefinition::BuiltInWindowFunction(BuiltInWindowFunction::Lag), + WindowFunctionDefinition::WindowUDF(lag_udwf()), // its name "LAG", // no argument @@ -211,7 +212,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { // ) ( // Window function - WindowFunctionDefinition::BuiltInWindowFunction(BuiltInWindowFunction::Lead), + WindowFunctionDefinition::WindowUDF(lead_udwf()), // its name "LEAD", // no argument @@ -393,9 +394,7 @@ fn get_random_function( window_fn_map.insert( "lead", ( - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::Lead, - ), + WindowFunctionDefinition::WindowUDF(lead_udwf()), vec![ arg.clone(), lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), @@ -406,9 +405,7 @@ fn get_random_function( window_fn_map.insert( "lag", ( - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::Lag, - ), + WindowFunctionDefinition::WindowUDF(lag_udwf()), vec![ arg.clone(), lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), diff --git a/datafusion/expr/src/built_in_window_function.rs b/datafusion/expr/src/built_in_window_function.rs index 6a30080fb38b..2c70a07a4e15 100644 --- a/datafusion/expr/src/built_in_window_function.rs +++ b/datafusion/expr/src/built_in_window_function.rs @@ -22,7 +22,7 @@ use std::str::FromStr; use crate::type_coercion::functions::data_types; use crate::utils; -use crate::{Signature, TypeSignature, Volatility}; +use crate::{Signature, Volatility}; use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; use arrow::datatypes::DataType; @@ -44,17 +44,7 @@ pub enum BuiltInWindowFunction { CumeDist, /// Integer ranging from 1 to the argument value, dividing the partition as equally as possible Ntile, - /// Returns value evaluated at the row that is offset rows before the current row within the partition; - /// If there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. - /// If omitted, offset defaults to 1 and default to null - Lag, - /// Returns value evaluated at the row that is offset rows after the current row within the partition; - /// If there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. - /// If omitted, offset defaults to 1 and default to null - Lead, - /// Returns value evaluated at the row that is the first row of the window frame + /// returns value evaluated at the row that is the first row of the window frame FirstValue, /// Returns value evaluated at the row that is the last row of the window frame LastValue, @@ -68,8 +58,6 @@ impl BuiltInWindowFunction { match self { CumeDist => "CUME_DIST", Ntile => "NTILE", - Lag => "LAG", - Lead => "LEAD", FirstValue => "first_value", LastValue => "last_value", NthValue => "NTH_VALUE", @@ -83,8 +71,6 @@ impl FromStr for BuiltInWindowFunction { Ok(match name.to_uppercase().as_str() { "CUME_DIST" => BuiltInWindowFunction::CumeDist, "NTILE" => BuiltInWindowFunction::Ntile, - "LAG" => BuiltInWindowFunction::Lag, - "LEAD" => BuiltInWindowFunction::Lead, "FIRST_VALUE" => BuiltInWindowFunction::FirstValue, "LAST_VALUE" => BuiltInWindowFunction::LastValue, "NTH_VALUE" => BuiltInWindowFunction::NthValue, @@ -117,9 +103,7 @@ impl BuiltInWindowFunction { match self { BuiltInWindowFunction::Ntile => Ok(DataType::UInt64), BuiltInWindowFunction::CumeDist => Ok(DataType::Float64), - BuiltInWindowFunction::Lag - | BuiltInWindowFunction::Lead - | BuiltInWindowFunction::FirstValue + BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue | BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()), } @@ -130,16 +114,6 @@ impl BuiltInWindowFunction { // Note: The physical expression must accept the type returned by this function or the execution panics. match self { BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable), - BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => { - Signature::one_of( - vec![ - TypeSignature::Any(1), - TypeSignature::Any(2), - TypeSignature::Any(3), - ], - Volatility::Immutable, - ) - } BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => { Signature::any(1, Volatility::Immutable) } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 3e692189e488..f3f71a87278b 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -2560,30 +2560,6 @@ mod test { Ok(()) } - #[test] - fn test_lead_return_type() -> Result<()> { - let fun = find_df_window_func("lead").unwrap(); - let observed = fun.return_type(&[DataType::Utf8], &[true], "")?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64], &[true], "")?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_lag_return_type() -> Result<()> { - let fun = find_df_window_func("lag").unwrap(); - let observed = fun.return_type(&[DataType::Utf8], &[true], "")?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64], &[true], "")?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - #[test] fn test_nth_value_return_type() -> Result<()> { let fun = find_df_window_func("nth_value").unwrap(); @@ -2621,8 +2597,6 @@ mod test { let names = vec![ "cume_dist", "ntile", - "lag", - "lead", "first_value", "last_value", "nth_value", @@ -2660,18 +2634,6 @@ mod test { built_in_window_function::BuiltInWindowFunction::LastValue )) ); - assert_eq!( - find_df_window_func("LAG"), - Some(WindowFunctionDefinition::BuiltInWindowFunction( - built_in_window_function::BuiltInWindowFunction::Lag - )) - ); - assert_eq!( - find_df_window_func("LEAD"), - Some(WindowFunctionDefinition::BuiltInWindowFunction( - built_in_window_function::BuiltInWindowFunction::Lead - )) - ); assert_eq!(find_df_window_func("not_exist"), None) } diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 6d8f2be97e02..6ab94c1e841a 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -34,8 +34,10 @@ use crate::{ Signature, }; use datafusion_common::{not_impl_err, Result}; +use datafusion_functions_window_common::expr::ExpressionArgs; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; /// Logical representation of a user-defined window function (UDWF) /// A UDWF is different from a UDF in that it is stateful across batches. @@ -149,6 +151,12 @@ impl WindowUDF { self.inner.simplify() } + /// Expressions that are passed to the [`PartitionEvaluator`]. + /// + /// See [`WindowUDFImpl::expressions`] for more details. + pub fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { + self.inner.expressions(expr_args) + } /// Return a `PartitionEvaluator` for evaluating this window function pub fn partition_evaluator_factory( &self, @@ -302,6 +310,14 @@ pub trait WindowUDFImpl: Debug + Send + Sync { /// types are accepted and the function's Volatility. fn signature(&self) -> &Signature; + /// Returns the expressions that are passed to the [`PartitionEvaluator`]. + fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { + expr_args + .input_exprs() + .first() + .map_or(vec![], |expr| vec![Arc::clone(expr)]) + } + /// Invoke the function, returning the [`PartitionEvaluator`] instance fn partition_evaluator( &self, @@ -480,6 +496,13 @@ impl WindowUDFImpl for AliasedWindowUDFImpl { self.inner.signature() } + fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { + expr_args + .input_exprs() + .first() + .map_or(vec![], |expr| vec![Arc::clone(expr)]) + } + fn partition_evaluator( &self, partition_evaluator_args: PartitionEvaluatorArgs, diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index 7ac6fb7d167c..3e1870c59c15 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::ScalarValue; - use crate::{expr::WindowFunction, BuiltInWindowFunction, Expr, Literal}; /// Create an expression to represent the `cume_dist` window function @@ -29,38 +27,6 @@ pub fn ntile(arg: Expr) -> Expr { Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::Ntile, vec![arg])) } -/// Create an expression to represent the `lag` window function -pub fn lag( - arg: Expr, - shift_offset: Option, - default_value: Option, -) -> Expr { - let shift_offset_lit = shift_offset - .map(|v| v.lit()) - .unwrap_or(ScalarValue::Null.lit()); - let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); - Expr::WindowFunction(WindowFunction::new( - BuiltInWindowFunction::Lag, - vec![arg, shift_offset_lit, default_lit], - )) -} - -/// Create an expression to represent the `lead` window function -pub fn lead( - arg: Expr, - shift_offset: Option, - default_value: Option, -) -> Expr { - let shift_offset_lit = shift_offset - .map(|v| v.lit()) - .unwrap_or(ScalarValue::Null.lit()); - let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); - Expr::WindowFunction(WindowFunction::new( - BuiltInWindowFunction::Lead, - vec![arg, shift_offset_lit, default_lit], - )) -} - /// Create an expression to represent the `nth_value` window function pub fn nth_value(arg: Expr, n: i64) -> Expr { Expr::WindowFunction(WindowFunction::new( diff --git a/datafusion/functions-window-common/src/expr.rs b/datafusion/functions-window-common/src/expr.rs new file mode 100644 index 000000000000..1d99fe7acf15 --- /dev/null +++ b/datafusion/functions-window-common/src/expr.rs @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::arrow::datatypes::DataType; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::sync::Arc; + +/// Arguments passed to user-defined window function +#[derive(Debug, Default)] +pub struct ExpressionArgs<'a> { + /// The expressions passed as arguments to the user-defined window + /// function. + input_exprs: &'a [Arc], + /// The corresponding data types of expressions passed as arguments + /// to the user-defined window function. + input_types: &'a [DataType], +} + +impl<'a> ExpressionArgs<'a> { + /// Create an instance of [`ExpressionArgs`]. + /// + /// # Arguments + /// + /// * `input_exprs` - The expressions passed as arguments + /// to the user-defined window function. + /// * `input_types` - The data types corresponding to the + /// arguments to the user-defined window function. + /// + pub fn new( + input_exprs: &'a [Arc], + input_types: &'a [DataType], + ) -> Self { + Self { + input_exprs, + input_types, + } + } + + /// Returns the expressions passed as arguments to the user-defined + /// window function. + pub fn input_exprs(&self) -> &'a [Arc] { + self.input_exprs + } + + /// Returns the [`DataType`]s corresponding to the input expressions + /// to the user-defined window function. + pub fn input_types(&self) -> &'a [DataType] { + self.input_types + } +} diff --git a/datafusion/functions-window-common/src/lib.rs b/datafusion/functions-window-common/src/lib.rs index 53f9eb1c9ac6..da8d096da562 100644 --- a/datafusion/functions-window-common/src/lib.rs +++ b/datafusion/functions-window-common/src/lib.rs @@ -18,5 +18,6 @@ //! Common user-defined window functionality for [DataFusion] //! //! [DataFusion]: +pub mod expr; pub mod field; pub mod partition; diff --git a/datafusion/functions-window/Cargo.toml b/datafusion/functions-window/Cargo.toml index 952e5720c77c..262c21fcec65 100644 --- a/datafusion/functions-window/Cargo.toml +++ b/datafusion/functions-window/Cargo.toml @@ -41,6 +41,7 @@ path = "src/lib.rs" datafusion-common = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions-window-common = { workspace = true } +datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } log = { workspace = true } paste = "1.0.15" diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/functions-window/src/lead_lag.rs similarity index 59% rename from datafusion/physical-expr/src/window/lead_lag.rs rename to datafusion/functions-window/src/lead_lag.rs index 1656b7c3033a..f81521099751 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/functions-window/src/lead_lag.rs @@ -15,125 +15,275 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expression for `lead` and `lag` that can evaluated -//! at runtime during query execution -use crate::window::BuiltInWindowFunctionExpr; -use crate::PhysicalExpr; -use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; -use arrow_array::Array; +//! `lead` and `lag` window function implementations + +use crate::utils::{get_scalar_value_from_args, get_signed_integer}; +use datafusion_common::arrow::array::ArrayRef; +use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::Field; use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::PartitionEvaluator; +use datafusion_expr::{ + Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature, Volatility, + WindowUDFImpl, +}; +use datafusion_functions_window_common::expr::ExpressionArgs; +use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::any::Any; use std::cmp::min; use std::collections::VecDeque; use std::ops::{Neg, Range}; use std::sync::Arc; -/// window shift expression +get_or_init_udwf!( + Lag, + lag, + "Returns the row value that precedes the current row by a specified \ + offset within partition. If no such row exists, then returns the \ + default value.", + WindowShift::lag +); +get_or_init_udwf!( + Lead, + lead, + "Returns the value from a row that follows the current row by a \ + specified offset within the partition. If no such row exists, then \ + returns the default value.", + WindowShift::lead +); + +/// Create an expression to represent the `lag` window function +/// +/// returns value evaluated at the row that is offset rows before the current row within the partition; +/// if there is no such row, instead return default (which must be of the same type as value). +/// Both offset and default are evaluated with respect to the current row. +/// If omitted, offset defaults to 1 and default to null +pub fn lag( + arg: datafusion_expr::Expr, + shift_offset: Option, + default_value: Option, +) -> datafusion_expr::Expr { + let shift_offset_lit = shift_offset + .map(|v| v.lit()) + .unwrap_or(ScalarValue::Null.lit()); + let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); + + lag_udwf().call(vec![arg, shift_offset_lit, default_lit]) +} + +/// Create an expression to represent the `lead` window function +/// +/// returns value evaluated at the row that is offset rows after the current row within the partition; +/// if there is no such row, instead return default (which must be of the same type as value). +/// Both offset and default are evaluated with respect to the current row. +/// If omitted, offset defaults to 1 and default to null +pub fn lead( + arg: datafusion_expr::Expr, + shift_offset: Option, + default_value: Option, +) -> datafusion_expr::Expr { + let shift_offset_lit = shift_offset + .map(|v| v.lit()) + .unwrap_or(ScalarValue::Null.lit()); + let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); + + lead_udwf().call(vec![arg, shift_offset_lit, default_lit]) +} + #[derive(Debug)] -pub struct WindowShift { - name: String, - /// Output data type - data_type: DataType, - shift_offset: i64, - expr: Arc, - default_value: ScalarValue, - ignore_nulls: bool, +enum WindowShiftKind { + Lag, + Lead, } -impl WindowShift { - /// Get shift_offset of window shift expression - pub fn get_shift_offset(&self) -> i64 { - self.shift_offset +impl WindowShiftKind { + fn name(&self) -> &'static str { + match self { + WindowShiftKind::Lag => "lag", + WindowShiftKind::Lead => "lead", + } } - /// Get the default_value for window shift expression. - pub fn get_default_value(&self) -> ScalarValue { - self.default_value.clone() + /// In [`WindowShiftEvaluator`] a positive offset is used to signal + /// computation of `lag()`. So here we negate the input offset + /// value when computing `lead()`. + fn shift_offset(&self, value: Option) -> i64 { + match self { + WindowShiftKind::Lag => value.unwrap_or(1), + WindowShiftKind::Lead => value.map(|v| v.neg()).unwrap_or(-1), + } } } -/// lead() window function -pub fn lead( - name: String, - data_type: DataType, - expr: Arc, - shift_offset: Option, - default_value: ScalarValue, - ignore_nulls: bool, -) -> WindowShift { - WindowShift { - name, - data_type, - shift_offset: shift_offset.map(|v| v.neg()).unwrap_or(-1), - expr, - default_value, - ignore_nulls, - } +/// window shift expression +#[derive(Debug)] +pub struct WindowShift { + signature: Signature, + kind: WindowShiftKind, } -/// lag() window function -pub fn lag( - name: String, - data_type: DataType, - expr: Arc, - shift_offset: Option, - default_value: ScalarValue, - ignore_nulls: bool, -) -> WindowShift { - WindowShift { - name, - data_type, - shift_offset: shift_offset.unwrap_or(1), - expr, - default_value, - ignore_nulls, +impl WindowShift { + fn new(kind: WindowShiftKind) -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Any(1), + TypeSignature::Any(2), + TypeSignature::Any(3), + ], + Volatility::Immutable, + ), + kind, + } + } + + pub fn lag() -> Self { + Self::new(WindowShiftKind::Lag) + } + + pub fn lead() -> Self { + Self::new(WindowShiftKind::Lead) } } -impl BuiltInWindowFunctionExpr for WindowShift { - /// Return a reference to Any that can be used for downcasting +impl WindowUDFImpl for WindowShift { fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - let nullable = true; - Ok(Field::new(&self.name, self.data_type.clone(), nullable)) + fn name(&self) -> &str { + self.kind.name() } - fn expressions(&self) -> Vec> { - vec![Arc::clone(&self.expr)] + fn signature(&self) -> &Signature { + &self.signature } - fn name(&self) -> &str { - &self.name + /// Handles the case where `NULL` expression is passed as an + /// argument to `lead`/`lag`. The type is refined depending + /// on the default value argument. + /// + /// For more details see: + fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { + parse_expr(expr_args.input_exprs(), expr_args.input_types()) + .into_iter() + .collect::>() } - fn create_evaluator(&self) -> Result> { + fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + let shift_offset = + get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)? + .map(get_signed_integer) + .map_or(Ok(None), |v| v.map(Some)) + .map(|n| self.kind.shift_offset(n)) + .map(|offset| { + if partition_evaluator_args.is_reversed() { + -offset + } else { + offset + } + })?; + let default_value = parse_default_value( + partition_evaluator_args.input_exprs(), + partition_evaluator_args.input_types(), + )?; + Ok(Box::new(WindowShiftEvaluator { - shift_offset: self.shift_offset, - default_value: self.default_value.clone(), - ignore_nulls: self.ignore_nulls, + shift_offset, + default_value, + ignore_nulls: partition_evaluator_args.ignore_nulls(), non_null_offsets: VecDeque::new(), })) } - fn reverse_expr(&self) -> Option> { - Some(Arc::new(Self { - name: self.name.clone(), - data_type: self.data_type.clone(), - shift_offset: -self.shift_offset, - expr: Arc::clone(&self.expr), - default_value: self.default_value.clone(), - ignore_nulls: self.ignore_nulls, - })) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + let return_type = parse_expr_type(field_args.input_types())?; + + Ok(Field::new(field_args.name(), return_type, true)) } + + fn reverse_expr(&self) -> ReversedUDWF { + match self.kind { + WindowShiftKind::Lag => ReversedUDWF::Reversed(lag_udwf()), + WindowShiftKind::Lead => ReversedUDWF::Reversed(lead_udwf()), + } + } +} + +/// When `lead`/`lag` is evaluated on a `NULL` expression we attempt to +/// refine it by matching it with the type of the default value. +/// +/// For e.g. in `lead(NULL, 1, false)` the generic `ScalarValue::Null` +/// is refined into `ScalarValue::Boolean(None)`. Only the type is +/// refined, the expression value remains `NULL`. +/// +/// When the window function is evaluated with `NULL` expression +/// this guarantees that the type matches with that of the default +/// value. +/// +/// For more details see: +fn parse_expr( + input_exprs: &[Arc], + input_types: &[DataType], +) -> Result> { + assert!(!input_exprs.is_empty()); + assert!(!input_types.is_empty()); + + let expr = Arc::clone(input_exprs.first().unwrap()); + let expr_type = input_types.first().unwrap(); + + // Handles the most common case where NULL is unexpected + if !expr_type.is_null() { + return Ok(expr); + } + + let default_value = get_scalar_value_from_args(input_exprs, 2)?; + default_value.map_or(Ok(expr), |value| { + ScalarValue::try_from(&value.data_type()).map(|v| { + Arc::new(datafusion_physical_expr::expressions::Literal::new(v)) + as Arc + }) + }) +} + +/// Returns the data type of the default value(if provided) when the +/// expression is `NULL`. +/// +/// Otherwise, returns the expression type unchanged. +fn parse_expr_type(input_types: &[DataType]) -> Result { + assert!(!input_types.is_empty()); + let expr_type = input_types.first().unwrap_or(&DataType::Null); + + // Handles the most common case where NULL is unexpected + if !expr_type.is_null() { + return Ok(expr_type.clone()); + } + + let default_value_type = input_types.get(2).unwrap_or(&DataType::Null); + Ok(default_value_type.clone()) +} + +/// Handles type coercion and null value refinement for default value +/// argument depending on the data type of the input expression. +fn parse_default_value( + input_exprs: &[Arc], + input_types: &[DataType], +) -> Result { + let expr_type = parse_expr_type(input_types)?; + let unparsed = get_scalar_value_from_args(input_exprs, 2)?; + + unparsed + .filter(|v| !v.data_type().is_null()) + .map(|v| v.cast_to(&expr_type)) + .unwrap_or(ScalarValue::try_from(expr_type)) } #[derive(Debug)] -pub(crate) struct WindowShiftEvaluator { +struct WindowShiftEvaluator { shift_offset: i64, default_value: ScalarValue, ignore_nulls: bool, @@ -205,7 +355,7 @@ fn shift_with_default_value( offset: i64, default_value: &ScalarValue, ) -> Result { - use arrow::compute::concat; + use datafusion_common::arrow::compute::concat; let value_len = array.len() as i64; if offset == 0 { @@ -402,19 +552,22 @@ impl PartitionEvaluator for WindowShiftEvaluator { #[cfg(test)] mod tests { use super::*; - use crate::expressions::Column; - use arrow::{array::*, datatypes::*}; + use arrow::array::*; use datafusion_common::cast::as_int32_array; - - fn test_i32_result(expr: WindowShift, expected: Int32Array) -> Result<()> { + use datafusion_physical_expr::expressions::{Column, Literal}; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + + fn test_i32_result( + expr: WindowShift, + partition_evaluator_args: PartitionEvaluatorArgs, + expected: Int32Array, + ) -> Result<()> { let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8])); let values = vec![arr]; - let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]); - let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?; - let values = expr.evaluate_args(&batch)?; + let num_rows = values.len(); let result = expr - .create_evaluator()? - .evaluate_all(&values, batch.num_rows())?; + .partition_evaluator(partition_evaluator_args)? + .evaluate_all(&values, num_rows)?; let result = as_int32_array(&result)?; assert_eq!(expected, *result); Ok(()) @@ -466,16 +619,12 @@ mod tests { } #[test] - fn lead_lag_window_shift() -> Result<()> { + fn test_lead_window_shift() -> Result<()> { + let expr = Arc::new(Column::new("c3", 0)) as Arc; + test_i32_result( - lead( - "lead".to_owned(), - DataType::Int32, - Arc::new(Column::new("c3", 0)), - None, - ScalarValue::Null.cast_to(&DataType::Int32)?, - false, - ), + WindowShift::lead(), + PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), [ Some(-2), Some(3), @@ -488,17 +637,16 @@ mod tests { ] .iter() .collect::(), - )?; + ) + } + + #[test] + fn test_lag_window_shift() -> Result<()> { + let expr = Arc::new(Column::new("c3", 0)) as Arc; test_i32_result( - lag( - "lead".to_owned(), - DataType::Int32, - Arc::new(Column::new("c3", 0)), - None, - ScalarValue::Null.cast_to(&DataType::Int32)?, - false, - ), + WindowShift::lag(), + PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), [ None, Some(1), @@ -511,17 +659,24 @@ mod tests { ] .iter() .collect::(), - )?; + ) + } + + #[test] + fn test_lag_with_default() -> Result<()> { + let expr = Arc::new(Column::new("c3", 0)) as Arc; + let shift_offset = + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; + let default_value = Arc::new(Literal::new(ScalarValue::Int32(Some(100)))) + as Arc; + + let input_exprs = &[expr, shift_offset, default_value]; + let input_types: &[DataType] = + &[DataType::Int32, DataType::Int32, DataType::Int32]; test_i32_result( - lag( - "lead".to_owned(), - DataType::Int32, - Arc::new(Column::new("c3", 0)), - None, - ScalarValue::Int32(Some(100)), - false, - ), + WindowShift::lag(), + PartitionEvaluatorArgs::new(input_exprs, input_types, false, false), [ Some(100), Some(1), @@ -534,7 +689,6 @@ mod tests { ] .iter() .collect::(), - )?; - Ok(()) + ) } } diff --git a/datafusion/functions-window/src/lib.rs b/datafusion/functions-window/src/lib.rs index ef624e13e61c..5a2aafa2892e 100644 --- a/datafusion/functions-window/src/lib.rs +++ b/datafusion/functions-window/src/lib.rs @@ -31,11 +31,17 @@ use datafusion_expr::WindowUDF; #[macro_use] pub mod macros; + +pub mod lead_lag; + pub mod rank; pub mod row_number; +mod utils; /// Fluent-style API for creating `Expr`s pub mod expr_fn { + pub use super::lead_lag::lag; + pub use super::lead_lag::lead; pub use super::rank::{dense_rank, percent_rank, rank}; pub use super::row_number::row_number; } @@ -44,6 +50,8 @@ pub mod expr_fn { pub fn all_default_window_functions() -> Vec> { vec![ row_number::row_number_udwf(), + lead_lag::lead_udwf(), + lead_lag::lag_udwf(), rank::rank_udwf(), rank::dense_rank_udwf(), rank::percent_rank_udwf(), diff --git a/datafusion/functions-window/src/utils.rs b/datafusion/functions-window/src/utils.rs new file mode 100644 index 000000000000..69f68aa78f2c --- /dev/null +++ b/datafusion/functions-window/src/utils.rs @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_physical_expr::expressions::Literal; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::sync::Arc; + +pub(crate) fn get_signed_integer(value: ScalarValue) -> Result { + if value.is_null() { + return Ok(0); + } + + if !value.data_type().is_integer() { + return exec_err!("Expected an integer value"); + } + + value.cast_to(&DataType::Int64)?.try_into() +} + +pub(crate) fn get_scalar_value_from_args( + args: &[Arc], + index: usize, +) -> Result> { + Ok(if let Some(field) = args.get(index) { + let tmp = field + .as_any() + .downcast_ref::() + .ok_or_else(|| DataFusionError::NotImplemented( + format!("There is only support Literal types for field at idx: {index} in Window Function"), + ))? + .value() + .clone(); + Some(tmp) + } else { + None + }) +} diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index e07e11e43199..54b8aafdb4da 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -36,7 +36,6 @@ mod unknown_column; /// Module with some convenient methods used in expression building pub use crate::aggregate::stats::StatsType; pub use crate::window::cume_dist::{cume_dist, CumeDist}; -pub use crate::window::lead_lag::{lag, lead, WindowShift}; pub use crate::window::nth_value::NthValue; pub use crate::window::ntile::Ntile; pub use crate::PhysicalSortExpr; diff --git a/datafusion/physical-expr/src/window/mod.rs b/datafusion/physical-expr/src/window/mod.rs index 938bdac50f97..c0fe3c2933a7 100644 --- a/datafusion/physical-expr/src/window/mod.rs +++ b/datafusion/physical-expr/src/window/mod.rs @@ -19,7 +19,6 @@ mod aggregate; mod built_in; mod built_in_window_function_expr; pub(crate) mod cume_dist; -pub(crate) mod lead_lag; pub(crate) mod nth_value; pub(crate) mod ntile; mod sliding_aggregate; diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index e6a773f6b1ea..adf61f27bc6f 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -21,7 +21,7 @@ use std::borrow::Borrow; use std::sync::Arc; use crate::{ - expressions::{cume_dist, lag, lead, Literal, NthValue, Ntile, PhysicalSortExpr}, + expressions::{cume_dist, Literal, NthValue, Ntile, PhysicalSortExpr}, ExecutionPlan, ExecutionPlanProperties, InputOrderMode, PhysicalExpr, }; @@ -48,6 +48,7 @@ mod utils; mod window_agg_exec; pub use bounded_window_agg_exec::BoundedWindowAggExec; +use datafusion_functions_window_common::expr::ExpressionArgs; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use datafusion_physical_expr::expressions::Column; @@ -206,52 +207,6 @@ fn get_unsigned_integer(value: ScalarValue) -> Result { value.cast_to(&DataType::UInt64)?.try_into() } -fn get_casted_value( - default_value: Option, - dtype: &DataType, -) -> Result { - match default_value { - Some(v) if !v.data_type().is_null() => v.cast_to(dtype), - // If None or Null datatype - _ => ScalarValue::try_from(dtype), - } -} - -/// Rewrites the NULL expression (1st argument) with an expression -/// which is the same data type as the default value (3rd argument). -/// Also rewrites the return type with the same data type as the -/// default value. -/// -/// If a default value is not provided, or it is NULL the original -/// expression (1st argument) and return type is returned without -/// any modifications. -fn rewrite_null_expr_and_data_type( - args: &[Arc], - expr_type: &DataType, -) -> Result<(Arc, DataType)> { - assert!(!args.is_empty()); - let expr = Arc::clone(&args[0]); - - // The input expression and the return is type is unchanged - // when the input expression is not NULL. - if !expr_type.is_null() { - return Ok((expr, expr_type.clone())); - } - - get_scalar_value_from_args(args, 2)? - .and_then(|value| { - ScalarValue::try_from(value.data_type().clone()) - .map(|sv| { - Ok(( - Arc::new(Literal::new(sv)) as Arc, - value.data_type().clone(), - )) - }) - .ok() - }) - .unwrap_or(Ok((expr, expr_type.clone()))) -} - fn create_built_in_window_expr( fun: &BuiltInWindowFunction, args: &[Arc], @@ -286,42 +241,6 @@ fn create_built_in_window_expr( Arc::new(Ntile::new(name, n as u64, out_data_type)) } } - BuiltInWindowFunction::Lag => { - // rewrite NULL expression and the return datatype - let (arg, out_data_type) = - rewrite_null_expr_and_data_type(args, out_data_type)?; - let shift_offset = get_scalar_value_from_args(args, 1)? - .map(get_signed_integer) - .map_or(Ok(None), |v| v.map(Some))?; - let default_value = - get_casted_value(get_scalar_value_from_args(args, 2)?, &out_data_type)?; - Arc::new(lag( - name, - default_value.data_type().clone(), - arg, - shift_offset, - default_value, - ignore_nulls, - )) - } - BuiltInWindowFunction::Lead => { - // rewrite NULL expression and the return datatype - let (arg, out_data_type) = - rewrite_null_expr_and_data_type(args, out_data_type)?; - let shift_offset = get_scalar_value_from_args(args, 1)? - .map(get_signed_integer) - .map_or(Ok(None), |v| v.map(Some))?; - let default_value = - get_casted_value(get_scalar_value_from_args(args, 2)?, &out_data_type)?; - Arc::new(lead( - name, - default_value.data_type().clone(), - arg, - shift_offset, - default_value, - ignore_nulls, - )) - } BuiltInWindowFunction::NthValue => { let arg = Arc::clone(&args[0]); let n = get_signed_integer( @@ -415,7 +334,8 @@ impl BuiltInWindowFunctionExpr for WindowUDFExpr { } fn expressions(&self) -> Vec> { - self.args.clone() + self.fun + .expressions(ExpressionArgs::new(&self.args, &self.input_types)) } fn create_evaluator(&self) -> Result> { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 5256f7473c95..9964ab498fb1 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -515,8 +515,8 @@ enum BuiltInWindowFunction { // PERCENT_RANK = 3; CUME_DIST = 4; NTILE = 5; - LAG = 6; - LEAD = 7; + // LAG = 6; + // LEAD = 7; FIRST_VALUE = 8; LAST_VALUE = 9; NTH_VALUE = 10; @@ -528,7 +528,7 @@ message WindowExprNode { string udaf = 3; string udwf = 9; } - LogicalExprNode expr = 4; + repeated LogicalExprNode exprs = 4; repeated LogicalExprNode partition_by = 5; repeated SortExprNode order_by = 6; // repeated LogicalExprNode filter = 7; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index e876008e853f..4417d1149681 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -1664,8 +1664,6 @@ impl serde::Serialize for BuiltInWindowFunction { Self::Unspecified => "UNSPECIFIED", Self::CumeDist => "CUME_DIST", Self::Ntile => "NTILE", - Self::Lag => "LAG", - Self::Lead => "LEAD", Self::FirstValue => "FIRST_VALUE", Self::LastValue => "LAST_VALUE", Self::NthValue => "NTH_VALUE", @@ -1683,8 +1681,6 @@ impl<'de> serde::Deserialize<'de> for BuiltInWindowFunction { "UNSPECIFIED", "CUME_DIST", "NTILE", - "LAG", - "LEAD", "FIRST_VALUE", "LAST_VALUE", "NTH_VALUE", @@ -1731,8 +1727,6 @@ impl<'de> serde::Deserialize<'de> for BuiltInWindowFunction { "UNSPECIFIED" => Ok(BuiltInWindowFunction::Unspecified), "CUME_DIST" => Ok(BuiltInWindowFunction::CumeDist), "NTILE" => Ok(BuiltInWindowFunction::Ntile), - "LAG" => Ok(BuiltInWindowFunction::Lag), - "LEAD" => Ok(BuiltInWindowFunction::Lead), "FIRST_VALUE" => Ok(BuiltInWindowFunction::FirstValue), "LAST_VALUE" => Ok(BuiltInWindowFunction::LastValue), "NTH_VALUE" => Ok(BuiltInWindowFunction::NthValue), @@ -21157,7 +21151,7 @@ impl serde::Serialize for WindowExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if !self.exprs.is_empty() { len += 1; } if !self.partition_by.is_empty() { @@ -21176,8 +21170,8 @@ impl serde::Serialize for WindowExprNode { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.WindowExprNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + if !self.exprs.is_empty() { + struct_ser.serialize_field("exprs", &self.exprs)?; } if !self.partition_by.is_empty() { struct_ser.serialize_field("partitionBy", &self.partition_by)?; @@ -21218,7 +21212,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "exprs", "partition_by", "partitionBy", "order_by", @@ -21235,7 +21229,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + Exprs, PartitionBy, OrderBy, WindowFrame, @@ -21264,7 +21258,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "exprs" => Ok(GeneratedField::Exprs), "partitionBy" | "partition_by" => Ok(GeneratedField::PartitionBy), "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), "windowFrame" | "window_frame" => Ok(GeneratedField::WindowFrame), @@ -21291,7 +21285,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut exprs__ = None; let mut partition_by__ = None; let mut order_by__ = None; let mut window_frame__ = None; @@ -21299,11 +21293,11 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { let mut window_function__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Exprs => { + if exprs__.is_some() { + return Err(serde::de::Error::duplicate_field("exprs")); } - expr__ = map_.next_value()?; + exprs__ = Some(map_.next_value()?); } GeneratedField::PartitionBy => { if partition_by__.is_some() { @@ -21352,7 +21346,7 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { } } Ok(WindowExprNode { - expr: expr__, + exprs: exprs__.unwrap_or_default(), partition_by: partition_by__.unwrap_or_default(), order_by: order_by__.unwrap_or_default(), window_frame: window_frame__, diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 2aa14f7e80b0..d3fe031a48c9 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -538,7 +538,7 @@ pub mod logical_expr_node { TryCast(::prost::alloc::boxed::Box), /// window expressions #[prost(message, tag = "18")] - WindowExpr(::prost::alloc::boxed::Box), + WindowExpr(super::WindowExprNode), /// AggregateUDF expressions #[prost(message, tag = "19")] AggregateUdfExpr(::prost::alloc::boxed::Box), @@ -735,8 +735,8 @@ pub struct ScalarUdfExprNode { } #[derive(Clone, PartialEq, ::prost::Message)] pub struct WindowExprNode { - #[prost(message, optional, boxed, tag = "4")] - pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "4")] + pub exprs: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "5")] pub partition_by: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "6")] @@ -1828,8 +1828,8 @@ pub enum BuiltInWindowFunction { /// PERCENT_RANK = 3; CumeDist = 4, Ntile = 5, - Lag = 6, - Lead = 7, + /// LAG = 6; + /// LEAD = 7; FirstValue = 8, LastValue = 9, NthValue = 10, @@ -1844,8 +1844,6 @@ impl BuiltInWindowFunction { Self::Unspecified => "UNSPECIFIED", Self::CumeDist => "CUME_DIST", Self::Ntile => "NTILE", - Self::Lag => "LAG", - Self::Lead => "LEAD", Self::FirstValue => "FIRST_VALUE", Self::LastValue => "LAST_VALUE", Self::NthValue => "NTH_VALUE", @@ -1857,8 +1855,6 @@ impl BuiltInWindowFunction { "UNSPECIFIED" => Some(Self::Unspecified), "CUME_DIST" => Some(Self::CumeDist), "NTILE" => Some(Self::Ntile), - "LAG" => Some(Self::Lag), - "LEAD" => Some(Self::Lead), "FIRST_VALUE" => Some(Self::FirstValue), "LAST_VALUE" => Some(Self::LastValue), "NTH_VALUE" => Some(Self::NthValue), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 32e1b68203ce..20d007048a00 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -142,8 +142,6 @@ impl From for BuiltInWindowFunction { fn from(built_in_function: protobuf::BuiltInWindowFunction) -> Self { match built_in_function { protobuf::BuiltInWindowFunction::Unspecified => todo!(), - protobuf::BuiltInWindowFunction::Lag => Self::Lag, - protobuf::BuiltInWindowFunction::Lead => Self::Lead, protobuf::BuiltInWindowFunction::FirstValue => Self::FirstValue, protobuf::BuiltInWindowFunction::CumeDist => Self::CumeDist, protobuf::BuiltInWindowFunction::Ntile => Self::Ntile, @@ -286,10 +284,7 @@ pub fn parse_expr( .map_err(|_| Error::unknown("BuiltInWindowFunction", *i))? .into(); - let args = - parse_optional_expr(expr.expr.as_deref(), registry, codec)? - .map(|e| vec![e]) - .unwrap_or_else(Vec::new); + let args = parse_exprs(&expr.exprs, registry, codec)?; Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::BuiltInWindowFunction( @@ -309,10 +304,7 @@ pub fn parse_expr( None => registry.udaf(udaf_name)?, }; - let args = - parse_optional_expr(expr.expr.as_deref(), registry, codec)? - .map(|e| vec![e]) - .unwrap_or_else(Vec::new); + let args = parse_exprs(&expr.exprs, registry, codec)?; Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::AggregateUDF(udaf_function), args, @@ -329,10 +321,7 @@ pub fn parse_expr( None => registry.udwf(udwf_name)?, }; - let args = - parse_optional_expr(expr.expr.as_deref(), registry, codec)? - .map(|e| vec![e]) - .unwrap_or_else(Vec::new); + let args = parse_exprs(&expr.exprs, registry, codec)?; Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::WindowUDF(udwf_function), args, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 07823b422f71..15fec3a8b2a8 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -119,8 +119,6 @@ impl From<&BuiltInWindowFunction> for protobuf::BuiltInWindowFunction { BuiltInWindowFunction::NthValue => Self::NthValue, BuiltInWindowFunction::Ntile => Self::Ntile, BuiltInWindowFunction::CumeDist => Self::CumeDist, - BuiltInWindowFunction::Lag => Self::Lag, - BuiltInWindowFunction::Lead => Self::Lead, } } } @@ -333,25 +331,19 @@ pub fn serialize_expr( ) } }; - let arg_expr: Option> = if !args.is_empty() { - let arg = &args[0]; - Some(Box::new(serialize_expr(arg, codec)?)) - } else { - None - }; let partition_by = serialize_exprs(partition_by, codec)?; let order_by = serialize_sorts(order_by, codec)?; let window_frame: Option = Some(window_frame.try_into()?); - let window_expr = Box::new(protobuf::WindowExprNode { - expr: arg_expr, + let window_expr = protobuf::WindowExprNode { + exprs: serialize_exprs(args, codec)?, window_function: Some(window_function), partition_by, order_by, window_frame, fun_definition, - }); + }; protobuf::LogicalExprNode { expr_type: Some(ExprType::WindowExpr(window_expr)), } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 85d4fe8a16d0..6072baca688c 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -25,7 +25,6 @@ use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, InListExpr, IsNotNullExpr, IsNullExpr, Literal, NegativeExpr, NotExpr, NthValue, Ntile, TryCastExpr, - WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -119,25 +118,6 @@ pub fn serialize_physical_window_expr( )))), ); protobuf::BuiltInWindowFunction::Ntile - } else if let Some(window_shift_expr) = - built_in_fn_expr.downcast_ref::() - { - args.insert( - 1, - Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some( - window_shift_expr.get_shift_offset(), - )))), - ); - args.insert( - 2, - Arc::new(Literal::new(window_shift_expr.get_default_value())), - ); - - if window_shift_expr.get_shift_offset() >= 0 { - protobuf::BuiltInWindowFunction::Lag - } else { - protobuf::BuiltInWindowFunction::Lead - } } else if let Some(nth_value_expr) = built_in_fn_expr.downcast_ref::() { match nth_value_expr.get_kind() { NthValueKind::First => protobuf::BuiltInWindowFunction::FirstValue, diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index ffa8fc1eefe9..c017395d979f 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -47,8 +47,10 @@ use datafusion::functions_aggregate::expr_fn::{ }; use datafusion::functions_aggregate::min_max::max_udaf; use datafusion::functions_nested::map::map; -use datafusion::functions_window::rank::{dense_rank, percent_rank, rank, rank_udwf}; -use datafusion::functions_window::row_number::row_number; +use datafusion::functions_window::expr_fn::{ + dense_rank, lag, lead, percent_rank, rank, row_number, +}; +use datafusion::functions_window::rank::rank_udwf; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::TableOptions; @@ -942,6 +944,12 @@ async fn roundtrip_expr_api() -> Result<()> { rank(), dense_rank(), percent_rank(), + lead(col("b"), None, None), + lead(col("b"), Some(2), None), + lead(col("b"), Some(2), Some(ScalarValue::from(100))), + lag(col("b"), None, None), + lag(col("b"), Some(2), None), + lag(col("b"), Some(2), Some(ScalarValue::from(100))), nth_value(col("b"), 1, vec![]), nth_value( col("b"), diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index a3d0ff4383ae..fb7afdda2ea8 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -503,9 +503,9 @@ logical_plan 12)----Projection: Int64(1) AS cnt 13)------Limit: skip=0, fetch=3 14)--------EmptyRelation -15)----Projection: LEAD(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS cnt +15)----Projection: lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS cnt 16)------Limit: skip=0, fetch=3 -17)--------WindowAggr: windowExpr=[[LEAD(b.c1, Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +17)--------WindowAggr: windowExpr=[[lead(b.c1, Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] 18)----------SubqueryAlias: b 19)------------Projection: Int64(1) AS c1 20)--------------EmptyRelation @@ -528,8 +528,8 @@ physical_plan 16)------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c13], has_header=true 17)------ProjectionExec: expr=[1 as cnt] 18)--------PlaceholderRowExec -19)------ProjectionExec: expr=[LEAD(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as cnt] -20)--------BoundedWindowAggExec: wdw=[LEAD(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "LEAD(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +19)------ProjectionExec: expr=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as cnt] +20)--------BoundedWindowAggExec: wdw=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] 21)----------ProjectionExec: expr=[1 as c1] 22)------------PlaceholderRowExec diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 1b612f921262..b3f2786d3dba 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -1376,16 +1376,16 @@ EXPLAIN SELECT LIMIT 5 ---- logical_plan -01)Projection: aggregate_test_100.c9, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS fv1, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS fv2, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS lag1, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS lead1, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2 +01)Projection: aggregate_test_100.c9, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS fv1, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS fv2, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS lag1, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS lead1, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2 02)--Limit: skip=0, fetch=5 -03)----WindowAggr: windowExpr=[[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, LAG(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, LEAD(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -04)------WindowAggr: windowExpr=[[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, LAG(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] +03)----WindowAggr: windowExpr=[[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, lag(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, lead(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------WindowAggr: windowExpr=[[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, lag(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, lead(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] 05)--------TableScan: aggregate_test_100 projection=[c9] physical_plan -01)ProjectionExec: expr=[c9@0 as c9, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as fv1, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as fv2, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as lag1, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as lag2, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as lead1, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as lead2] +01)ProjectionExec: expr=[c9@0 as c9, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as fv1, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as fv2, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as lag1, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as lag2, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as lead1, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as lead2] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)), is_causal: false }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] -04)------BoundedWindowAggExec: wdw=[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)), is_causal: false }, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] 05)--------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false] 06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true @@ -2636,15 +2636,15 @@ EXPLAIN SELECT ---- logical_plan 01)Sort: annotated_data_finite.ts DESC NULLS FIRST, fetch=5 -02)--Projection: annotated_data_finite.ts, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv2, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rn1, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rn2, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rank1, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rank2, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS dense_rank1, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS dense_rank2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lag1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lead1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lagr1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lagr2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS leadr1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS leadr2 -03)----WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] -04)------WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] +02)--Projection: annotated_data_finite.ts, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv2, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rn1, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rn2, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rank1, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rank2, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS dense_rank1, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS dense_rank2, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lag1, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lead1, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr2, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lagr1, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lagr2, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS leadr1, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS leadr2 +03)----WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, lag(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, lag(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, lead(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, lead(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] +04)------WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, lag(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, lag(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, lead(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, lead(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] 05)--------TableScan: annotated_data_finite projection=[ts, inc_col] physical_plan 01)SortExec: TopK(fetch=5), expr=[ts@0 DESC], preserve_partitioning=[false] -02)--ProjectionExec: expr=[ts@0 as ts, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as fv1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as fv2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@12 as lv1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@14 as nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@15 as nv2, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@16 as rn1, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@17 as rn2, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as rank1, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as rank2, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@20 as dense_rank1, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@21 as dense_rank2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@22 as lag1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@23 as lag2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@24 as lead1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@25 as lead2, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as fvr1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as fvr2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as lvr1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as lvr2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@6 as lagr1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@7 as lagr2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@8 as leadr1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@9 as leadr2] -03)----BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -04)------BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }], mode=[Sorted] +02)--ProjectionExec: expr=[ts@0 as ts, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as fv1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as fv2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@12 as lv1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@14 as nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@15 as nv2, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@16 as rn1, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@17 as rn2, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as rank1, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as rank2, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@20 as dense_rank1, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@21 as dense_rank2, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@22 as lag1, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@23 as lag2, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@24 as lead1, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@25 as lead2, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as fvr1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as fvr2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as lvr1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as lvr2, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@6 as lagr1, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@7 as lagr2, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@8 as leadr1, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@9 as leadr2] +03)----BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }], mode=[Sorted] 05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true query IIIIIIIIIIIIIIIIIIIIIIIII @@ -4971,6 +4971,26 @@ SELECT LAG(NULL, 1, false) OVER () FROM t1; ---- false +query B +SELECT LEAD(NULL, 0, true) OVER () FROM t1; +---- +NULL + +query B +SELECT LAG(NULL, 0, true) OVER () FROM t1; +---- +NULL + +query B +SELECT LEAD(NULL, 1, true) OVER () FROM t1; +---- +true + +query B +SELECT LAG(NULL, 1, true) OVER () FROM t1; +---- +true + statement ok insert into t1 values (2); @@ -4986,6 +5006,18 @@ SELECT LAG(NULL, 1, false) OVER () FROM t1; false NULL +query B +SELECT LEAD(NULL, 1, true) OVER () FROM t1; +---- +NULL +true + +query B +SELECT LAG(NULL, 1, true) OVER () FROM t1; +---- +true +NULL + statement ok DROP TABLE t1; From 24148bd65fdf61fba340b69dc87a7920850cb19f Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Fri, 18 Oct 2024 13:28:03 +0200 Subject: [PATCH 015/110] Add links to new_constraint_from_table_constraints doc (#12995) --- datafusion/sql/src/statement.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 4109f1371187..60e3413b836f 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -1263,7 +1263,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))) } - /// Convert each `TableConstraint` to corresponding `Constraint` + /// Convert each [TableConstraint] to corresponding [Constraint] fn new_constraint_from_table_constraints( constraints: &[TableConstraint], df_schema: &DFSchemaRef, From 87e931c976a7aa24cecaa9bf3658b42bba12a51e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alihan=20=C3=87elikcan?= Date: Fri, 18 Oct 2024 14:34:42 +0300 Subject: [PATCH 016/110] Split output batches of joins that do not respect batch size (#12969) * Add BatchSplitter to joins that do not respect batch size * Group relevant imports * Update configs.md * Update SQL logic tests for config * Review * Use PrimitiveBuilder for PrimitiveArray concatenation * Fix into_builder() bug * Apply suggestions from code review Co-authored-by: Andrew Lamb * Update config docs * Format * Update config SQL Logic Test --------- Co-authored-by: Mehmet Ozan Kabak Co-authored-by: Andrew Lamb --- datafusion/common/src/config.rs | 26 +- datafusion/execution/src/config.rs | 14 + .../physical-plan/src/joins/cross_join.rs | 84 +++-- .../physical-plan/src/joins/hash_join.rs | 2 +- .../src/joins/nested_loop_join.rs | 356 ++++++++++++------ .../src/joins/stream_join_utils.rs | 83 ++-- .../src/joins/symmetric_hash_join.rs | 252 +++++++------ datafusion/physical-plan/src/joins/utils.rs | 220 +++++++++-- .../test_files/information_schema.slt | 2 + docs/source/user-guide/configs.md | 1 + 10 files changed, 709 insertions(+), 331 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 1e1c5d5424b0..47ffe0b1c66b 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -338,6 +338,12 @@ config_namespace! { /// if the source of statistics is accurate. /// We plan to make this the default in the future. pub use_row_number_estimates_to_optimize_partitioning: bool, default = false + + /// Should DataFusion enforce batch size in joins or not. By default, + /// DataFusion will not enforce batch size in joins. Enforcing batch size + /// in joins can reduce memory usage when joining large + /// tables with a highly-selective join filter, but is also slightly slower. + pub enforce_batch_size_in_joins: bool, default = false } } @@ -1222,16 +1228,18 @@ impl ConfigField for TableOptions { fn set(&mut self, key: &str, value: &str) -> Result<()> { // Extensions are handled in the public `ConfigOptions::set` let (key, rem) = key.split_once('.').unwrap_or((key, "")); - let Some(format) = &self.current_format else { - return _config_err!("Specify a format for TableOptions"); - }; match key { - "format" => match format { - #[cfg(feature = "parquet")] - ConfigFileType::PARQUET => self.parquet.set(rem, value), - ConfigFileType::CSV => self.csv.set(rem, value), - ConfigFileType::JSON => self.json.set(rem, value), - }, + "format" => { + let Some(format) = &self.current_format else { + return _config_err!("Specify a format for TableOptions"); + }; + match format { + #[cfg(feature = "parquet")] + ConfigFileType::PARQUET => self.parquet.set(rem, value), + ConfigFileType::CSV => self.csv.set(rem, value), + ConfigFileType::JSON => self.json.set(rem, value), + } + } _ => _config_err!("Config value \"{key}\" not found on TableOptions"), } } diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index cede75d21ca4..53646dc5b468 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -432,6 +432,20 @@ impl SessionConfig { self } + /// Enables or disables the enforcement of batch size in joins + pub fn with_enforce_batch_size_in_joins( + mut self, + enforce_batch_size_in_joins: bool, + ) -> Self { + self.options.execution.enforce_batch_size_in_joins = enforce_batch_size_in_joins; + self + } + + /// Returns true if the joins will be enforced to output batches of the configured size + pub fn enforce_batch_size_in_joins(&self) -> bool { + self.options.execution.enforce_batch_size_in_joins + } + /// Convert configuration options to name-value pairs with values /// converted to strings. /// diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index a70645f3d6c0..8f2bef56da76 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -19,7 +19,8 @@ //! and producing batches in parallel for the right partitions use super::utils::{ - adjust_right_output_partitioning, BuildProbeJoinMetrics, OnceAsync, OnceFut, + adjust_right_output_partitioning, BatchSplitter, BatchTransformer, + BuildProbeJoinMetrics, NoopBatchTransformer, OnceAsync, OnceFut, StatefulStreamResult, }; use crate::coalesce_partitions::CoalescePartitionsExec; @@ -86,6 +87,7 @@ impl CrossJoinExec { let schema = Arc::new(Schema::new(all_columns).with_metadata(metadata)); let cache = Self::compute_properties(&left, &right, Arc::clone(&schema)); + CrossJoinExec { left, right, @@ -246,6 +248,10 @@ impl ExecutionPlan for CrossJoinExec { let reservation = MemoryConsumer::new("CrossJoinExec").register(context.memory_pool()); + let batch_size = context.session_config().batch_size(); + let enforce_batch_size_in_joins = + context.session_config().enforce_batch_size_in_joins(); + let left_fut = self.left_fut.once(|| { load_left_input( Arc::clone(&self.left), @@ -255,15 +261,29 @@ impl ExecutionPlan for CrossJoinExec { ) }); - Ok(Box::pin(CrossJoinStream { - schema: Arc::clone(&self.schema), - left_fut, - right: stream, - left_index: 0, - join_metrics, - state: CrossJoinStreamState::WaitBuildSide, - left_data: RecordBatch::new_empty(self.left().schema()), - })) + if enforce_batch_size_in_joins { + Ok(Box::pin(CrossJoinStream { + schema: Arc::clone(&self.schema), + left_fut, + right: stream, + left_index: 0, + join_metrics, + state: CrossJoinStreamState::WaitBuildSide, + left_data: RecordBatch::new_empty(self.left().schema()), + batch_transformer: BatchSplitter::new(batch_size), + })) + } else { + Ok(Box::pin(CrossJoinStream { + schema: Arc::clone(&self.schema), + left_fut, + right: stream, + left_index: 0, + join_metrics, + state: CrossJoinStreamState::WaitBuildSide, + left_data: RecordBatch::new_empty(self.left().schema()), + batch_transformer: NoopBatchTransformer::new(), + })) + } } fn statistics(&self) -> Result { @@ -319,7 +339,7 @@ fn stats_cartesian_product( } /// A stream that issues [RecordBatch]es as they arrive from the right of the join. -struct CrossJoinStream { +struct CrossJoinStream { /// Input schema schema: Arc, /// Future for data from left side @@ -334,9 +354,11 @@ struct CrossJoinStream { state: CrossJoinStreamState, /// Left data left_data: RecordBatch, + /// Batch transformer + batch_transformer: T, } -impl RecordBatchStream for CrossJoinStream { +impl RecordBatchStream for CrossJoinStream { fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } @@ -390,7 +412,7 @@ fn build_batch( } #[async_trait] -impl Stream for CrossJoinStream { +impl Stream for CrossJoinStream { type Item = Result; fn poll_next( @@ -401,7 +423,7 @@ impl Stream for CrossJoinStream { } } -impl CrossJoinStream { +impl CrossJoinStream { /// Separate implementation function that unpins the [`CrossJoinStream`] so /// that partial borrows work correctly fn poll_next_impl( @@ -470,21 +492,33 @@ impl CrossJoinStream { fn build_batches(&mut self) -> Result>> { let right_batch = self.state.try_as_record_batch()?; if self.left_index < self.left_data.num_rows() { - let join_timer = self.join_metrics.join_time.timer(); - let result = - build_batch(self.left_index, right_batch, &self.left_data, &self.schema); - join_timer.done(); - - if let Ok(ref batch) = result { - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); + match self.batch_transformer.next() { + None => { + let join_timer = self.join_metrics.join_time.timer(); + let result = build_batch( + self.left_index, + right_batch, + &self.left_data, + &self.schema, + ); + join_timer.done(); + + self.batch_transformer.set_batch(result?); + } + Some((batch, last)) => { + if last { + self.left_index += 1; + } + + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + return Ok(StatefulStreamResult::Ready(Some(batch))); + } } - self.left_index += 1; - result.map(|r| StatefulStreamResult::Ready(Some(r))) } else { self.state = CrossJoinStreamState::FetchProbeBatch; - Ok(StatefulStreamResult::Continue) } + Ok(StatefulStreamResult::Continue) } } diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 74a45a7e4761..3b730c01291c 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -1438,7 +1438,7 @@ impl HashJoinStream { index_alignment_range_start..index_alignment_range_end, self.join_type, self.right_side_ordered, - ); + )?; let result = build_batch_from_indices( &self.schema, diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 6068e7526316..358ff02473a6 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -25,7 +25,10 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::task::Poll; -use super::utils::{asymmetric_join_output_partitioning, need_produce_result_in_final}; +use super::utils::{ + asymmetric_join_output_partitioning, need_produce_result_in_final, BatchSplitter, + BatchTransformer, NoopBatchTransformer, StatefulStreamResult, +}; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::joins::utils::{ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, @@ -35,8 +38,8 @@ use crate::joins::utils::{ }; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::{ - execution_mode_from_children, DisplayAs, DisplayFormatType, Distribution, - ExecutionMode, ExecutionPlan, ExecutionPlanProperties, PlanProperties, + execution_mode_from_children, handle_state, DisplayAs, DisplayFormatType, + Distribution, ExecutionMode, ExecutionPlan, ExecutionPlanProperties, PlanProperties, RecordBatchStream, SendableRecordBatchStream, }; @@ -45,7 +48,9 @@ use arrow::compute::concat_batches; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow::util::bit_util; -use datafusion_common::{exec_datafusion_err, JoinSide, Result, Statistics}; +use datafusion_common::{ + exec_datafusion_err, internal_err, JoinSide, Result, Statistics, +}; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; use datafusion_expr::JoinType; @@ -230,10 +235,11 @@ impl NestedLoopJoinExec { asymmetric_join_output_partitioning(left, right, &join_type); // Determine execution mode: - let mut mode = execution_mode_from_children([left, right]); - if mode.is_unbounded() { - mode = ExecutionMode::PipelineBreaking; - } + let mode = if left.execution_mode().is_unbounded() { + ExecutionMode::PipelineBreaking + } else { + execution_mode_from_children([left, right]) + }; PlanProperties::new(eq_properties, output_partitioning, mode) } @@ -345,6 +351,10 @@ impl ExecutionPlan for NestedLoopJoinExec { ) }); + let batch_size = context.session_config().batch_size(); + let enforce_batch_size_in_joins = + context.session_config().enforce_batch_size_in_joins(); + let outer_table = self.right.execute(partition, context)?; let indices_cache = (UInt64Array::new_null(0), UInt32Array::new_null(0)); @@ -352,18 +362,38 @@ impl ExecutionPlan for NestedLoopJoinExec { // Right side has an order and it is maintained during operation. let right_side_ordered = self.maintains_input_order()[1] && self.right.output_ordering().is_some(); - Ok(Box::pin(NestedLoopJoinStream { - schema: Arc::clone(&self.schema), - filter: self.filter.clone(), - join_type: self.join_type, - outer_table, - inner_table, - is_exhausted: false, - column_indices: self.column_indices.clone(), - join_metrics, - indices_cache, - right_side_ordered, - })) + + if enforce_batch_size_in_joins { + Ok(Box::pin(NestedLoopJoinStream { + schema: Arc::clone(&self.schema), + filter: self.filter.clone(), + join_type: self.join_type, + outer_table, + inner_table, + column_indices: self.column_indices.clone(), + join_metrics, + indices_cache, + right_side_ordered, + state: NestedLoopJoinStreamState::WaitBuildSide, + batch_transformer: BatchSplitter::new(batch_size), + left_data: None, + })) + } else { + Ok(Box::pin(NestedLoopJoinStream { + schema: Arc::clone(&self.schema), + filter: self.filter.clone(), + join_type: self.join_type, + outer_table, + inner_table, + column_indices: self.column_indices.clone(), + join_metrics, + indices_cache, + right_side_ordered, + state: NestedLoopJoinStreamState::WaitBuildSide, + batch_transformer: NoopBatchTransformer::new(), + left_data: None, + })) + } } fn metrics(&self) -> Option { @@ -442,8 +472,37 @@ async fn collect_left_input( )) } +/// This enumeration represents various states of the nested loop join algorithm. +#[derive(Debug, Clone)] +enum NestedLoopJoinStreamState { + /// The initial state, indicating that build-side data not collected yet + WaitBuildSide, + /// Indicates that build-side has been collected, and stream is ready for + /// fetching probe-side + FetchProbeBatch, + /// Indicates that a non-empty batch has been fetched from probe-side, and + /// is ready to be processed + ProcessProbeBatch(RecordBatch), + /// Indicates that probe-side has been fully processed + ExhaustedProbeSide, + /// Indicates that NestedLoopJoinStream execution is completed + Completed, +} + +impl NestedLoopJoinStreamState { + /// Tries to extract a `ProcessProbeBatchState` from the + /// `NestedLoopJoinStreamState` enum. Returns an error if state is not + /// `ProcessProbeBatchState`. + fn try_as_process_probe_batch(&mut self) -> Result<&RecordBatch> { + match self { + NestedLoopJoinStreamState::ProcessProbeBatch(state) => Ok(state), + _ => internal_err!("Expected join stream in ProcessProbeBatch state"), + } + } +} + /// A stream that issues [RecordBatch]es as they arrive from the right of the join. -struct NestedLoopJoinStream { +struct NestedLoopJoinStream { /// Input schema schema: Arc, /// join filter @@ -454,8 +513,6 @@ struct NestedLoopJoinStream { outer_table: SendableRecordBatchStream, /// the inner table data of the nested loop join inner_table: OnceFut, - /// There is nothing to process anymore and left side is processed in case of full join - is_exhausted: bool, /// Information of index and left / right placement of columns column_indices: Vec, // TODO: support null aware equal @@ -466,6 +523,12 @@ struct NestedLoopJoinStream { indices_cache: (UInt64Array, UInt32Array), /// Whether the right side is ordered right_side_ordered: bool, + /// Current state of the stream + state: NestedLoopJoinStreamState, + /// Transforms the output batch before returning. + batch_transformer: T, + /// Result of the left data future + left_data: Option>, } /// Creates a Cartesian product of two input batches, preserving the order of the right batch, @@ -544,107 +607,164 @@ fn build_join_indices( } } -impl NestedLoopJoinStream { +impl NestedLoopJoinStream { fn poll_next_impl( &mut self, cx: &mut std::task::Context<'_>, ) -> Poll>> { - // all left row + loop { + return match self.state { + NestedLoopJoinStreamState::WaitBuildSide => { + handle_state!(ready!(self.collect_build_side(cx))) + } + NestedLoopJoinStreamState::FetchProbeBatch => { + handle_state!(ready!(self.fetch_probe_batch(cx))) + } + NestedLoopJoinStreamState::ProcessProbeBatch(_) => { + handle_state!(self.process_probe_batch()) + } + NestedLoopJoinStreamState::ExhaustedProbeSide => { + handle_state!(self.process_unmatched_build_batch()) + } + NestedLoopJoinStreamState::Completed => Poll::Ready(None), + }; + } + } + + fn collect_build_side( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { let build_timer = self.join_metrics.build_time.timer(); - let left_data = match ready!(self.inner_table.get_shared(cx)) { - Ok(data) => data, - Err(e) => return Poll::Ready(Some(Err(e))), - }; + // build hash table from left (build) side, if not yet done + self.left_data = Some(ready!(self.inner_table.get_shared(cx))?); build_timer.done(); - // Get or initialize visited_left_side bitmap if required by join type + self.state = NestedLoopJoinStreamState::FetchProbeBatch; + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Fetches next batch from probe-side + /// + /// If a non-empty batch has been fetched, updates state to + /// `ProcessProbeBatchState`, otherwise updates state to `ExhaustedProbeSide`. + fn fetch_probe_batch( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + match ready!(self.outer_table.poll_next_unpin(cx)) { + None => { + self.state = NestedLoopJoinStreamState::ExhaustedProbeSide; + } + Some(Ok(right_batch)) => { + self.state = NestedLoopJoinStreamState::ProcessProbeBatch(right_batch); + } + Some(Err(err)) => return Poll::Ready(Err(err)), + }; + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Joins current probe batch with build-side data and produces batch with + /// matched output, updates state to `FetchProbeBatch`. + fn process_probe_batch( + &mut self, + ) -> Result>> { + let Some(left_data) = self.left_data.clone() else { + return internal_err!( + "Expected left_data to be Some in ProcessProbeBatch state" + ); + }; let visited_left_side = left_data.bitmap(); + let batch = self.state.try_as_process_probe_batch()?; + + match self.batch_transformer.next() { + None => { + // Setting up timer & updating input metrics + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + let timer = self.join_metrics.join_time.timer(); + + let result = join_left_and_right_batch( + left_data.batch(), + batch, + self.join_type, + self.filter.as_ref(), + &self.column_indices, + &self.schema, + visited_left_side, + &mut self.indices_cache, + self.right_side_ordered, + ); + timer.done(); + + self.batch_transformer.set_batch(result?); + Ok(StatefulStreamResult::Continue) + } + Some((batch, last)) => { + if last { + self.state = NestedLoopJoinStreamState::FetchProbeBatch; + } - // Check is_exhausted before polling the outer_table, such that when the outer table - // does not support `FusedStream`, Self will not poll it again - if self.is_exhausted { - return Poll::Ready(None); + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + Ok(StatefulStreamResult::Ready(Some(batch))) + } } + } - self.outer_table - .poll_next_unpin(cx) - .map(|maybe_batch| match maybe_batch { - Some(Ok(right_batch)) => { - // Setting up timer & updating input metrics - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(right_batch.num_rows()); - let timer = self.join_metrics.join_time.timer(); - - let result = join_left_and_right_batch( - left_data.batch(), - &right_batch, - self.join_type, - self.filter.as_ref(), - &self.column_indices, - &self.schema, - visited_left_side, - &mut self.indices_cache, - self.right_side_ordered, - ); - - // Recording time & updating output metrics - if let Ok(batch) = &result { - timer.done(); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - - Some(result) - } - Some(err) => Some(err), - None => { - if need_produce_result_in_final(self.join_type) { - // At this stage `visited_left_side` won't be updated, so it's - // safe to report about probe completion. - // - // Setting `is_exhausted` / returning None will prevent from - // multiple calls of `report_probe_completed()` - if !left_data.report_probe_completed() { - self.is_exhausted = true; - return None; - }; - - // Only setting up timer, input is exhausted - let timer = self.join_metrics.join_time.timer(); - // use the global left bitmap to produce the left indices and right indices - let (left_side, right_side) = - get_final_indices_from_shared_bitmap( - visited_left_side, - self.join_type, - ); - let empty_right_batch = - RecordBatch::new_empty(self.outer_table.schema()); - // use the left and right indices to produce the batch result - let result = build_batch_from_indices( - &self.schema, - left_data.batch(), - &empty_right_batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - self.is_exhausted = true; - - // Recording time & updating output metrics - if let Ok(batch) = &result { - timer.done(); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - - Some(result) - } else { - // end of the join loop - None - } - } - }) + /// Processes unmatched build-side rows for certain join types and produces + /// output batch, updates state to `Completed`. + fn process_unmatched_build_batch( + &mut self, + ) -> Result>> { + let Some(left_data) = self.left_data.clone() else { + return internal_err!( + "Expected left_data to be Some in ExhaustedProbeSide state" + ); + }; + let visited_left_side = left_data.bitmap(); + if need_produce_result_in_final(self.join_type) { + // At this stage `visited_left_side` won't be updated, so it's + // safe to report about probe completion. + // + // Setting `is_exhausted` / returning None will prevent from + // multiple calls of `report_probe_completed()` + if !left_data.report_probe_completed() { + self.state = NestedLoopJoinStreamState::Completed; + return Ok(StatefulStreamResult::Ready(None)); + }; + + // Only setting up timer, input is exhausted + let timer = self.join_metrics.join_time.timer(); + // use the global left bitmap to produce the left indices and right indices + let (left_side, right_side) = + get_final_indices_from_shared_bitmap(visited_left_side, self.join_type); + let empty_right_batch = RecordBatch::new_empty(self.outer_table.schema()); + // use the left and right indices to produce the batch result + let result = build_batch_from_indices( + &self.schema, + left_data.batch(), + &empty_right_batch, + &left_side, + &right_side, + &self.column_indices, + JoinSide::Left, + ); + self.state = NestedLoopJoinStreamState::Completed; + + // Recording time + if result.is_ok() { + timer.done(); + } + + Ok(StatefulStreamResult::Ready(Some(result?))) + } else { + // end of the join loop + self.state = NestedLoopJoinStreamState::Completed; + Ok(StatefulStreamResult::Ready(None)) + } } } @@ -684,7 +804,7 @@ fn join_left_and_right_batch( 0..right_batch.num_rows(), join_type, right_side_ordered, - ); + )?; build_batch_from_indices( schema, @@ -705,7 +825,7 @@ fn get_final_indices_from_shared_bitmap( get_final_indices_from_bit_map(&bitmap, join_type) } -impl Stream for NestedLoopJoinStream { +impl Stream for NestedLoopJoinStream { type Item = Result; fn poll_next( @@ -716,14 +836,14 @@ impl Stream for NestedLoopJoinStream { } } -impl RecordBatchStream for NestedLoopJoinStream { +impl RecordBatchStream for NestedLoopJoinStream { fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } } #[cfg(test)] -mod tests { +pub(crate) mod tests { use super::*; use crate::{ common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec, @@ -850,7 +970,7 @@ mod tests { JoinFilter::new(filter_expression, column_indices, intermediate_schema) } - async fn multi_partitioned_join_collect( + pub(crate) async fn multi_partitioned_join_collect( left: Arc, right: Arc, join_type: &JoinType, diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index ba9384aef1a6..bddd152341da 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -31,8 +31,7 @@ use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder}; use arrow_schema::{Schema, SchemaRef}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ - arrow_datafusion_err, plan_datafusion_err, DataFusionError, JoinSide, Result, - ScalarValue, + arrow_datafusion_err, DataFusionError, JoinSide, Result, ScalarValue, }; use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::expressions::Column; @@ -369,34 +368,40 @@ impl SortedFilterExpr { filter_expr: Arc, filter_schema: &Schema, ) -> Result { - let dt = &filter_expr.data_type(filter_schema)?; + let dt = filter_expr.data_type(filter_schema)?; Ok(Self { origin_sorted_expr, filter_expr, - interval: Interval::make_unbounded(dt)?, + interval: Interval::make_unbounded(&dt)?, node_index: 0, }) } + /// Get origin expr information pub fn origin_sorted_expr(&self) -> &PhysicalSortExpr { &self.origin_sorted_expr } + /// Get filter expr information pub fn filter_expr(&self) -> &Arc { &self.filter_expr } + /// Get interval information pub fn interval(&self) -> &Interval { &self.interval } + /// Sets interval pub fn set_interval(&mut self, interval: Interval) { self.interval = interval; } + /// Node index in ExprIntervalGraph pub fn node_index(&self) -> usize { self.node_index } + /// Node index setter in ExprIntervalGraph pub fn set_node_index(&mut self, node_index: usize) { self.node_index = node_index; @@ -409,41 +414,45 @@ impl SortedFilterExpr { /// on the first or the last value of the expression in `build_input_buffer` /// and `probe_batch`. /// -/// # Arguments +/// # Parameters /// /// * `build_input_buffer` - The [RecordBatch] on the build side of the join. /// * `build_sorted_filter_expr` - Build side [SortedFilterExpr] to update. /// * `probe_batch` - The `RecordBatch` on the probe side of the join. /// * `probe_sorted_filter_expr` - Probe side `SortedFilterExpr` to update. /// -/// ### Note -/// ```text +/// ## Note /// -/// Interval arithmetic is used to calculate viable join ranges for build-side -/// pruning. This is done by first creating an interval for join filter values in -/// the build side of the join, which spans [-∞, FV] or [FV, ∞] depending on the -/// ordering (descending/ascending) of the filter expression. Here, FV denotes the -/// first value on the build side. This range is then compared with the probe side -/// interval, which either spans [-∞, LV] or [LV, ∞] depending on the ordering -/// (ascending/descending) of the probe side. Here, LV denotes the last value on -/// the probe side. +/// Utilizing interval arithmetic, this function computes feasible join intervals +/// on the pruning side by evaluating the prospective value ranges that might +/// emerge in subsequent data batches from the enforcer side. This is done by +/// first creating an interval for join filter values in the pruning side of the +/// join, which spans `[-∞, FV]` or `[FV, ∞]` depending on the ordering (descending/ +/// ascending) of the filter expression. Here, `FV` denotes the first value on the +/// pruning side. This range is then compared with the enforcer side interval, +/// which either spans `[-∞, LV]` or `[LV, ∞]` depending on the ordering (ascending/ +/// descending) of the probe side. Here, `LV` denotes the last value on the enforcer +/// side. /// /// As a concrete example, consider the following query: /// +/// ```text /// SELECT * FROM left_table, right_table /// WHERE /// left_key = right_key AND /// a > b - 3 AND /// a < b + 10 +/// ``` /// -/// where columns "a" and "b" come from tables "left_table" and "right_table", +/// where columns `a` and `b` come from tables `left_table` and `right_table`, /// respectively. When a new `RecordBatch` arrives at the right side, the -/// condition a > b - 3 will possibly indicate a prunable range for the left +/// condition `a > b - 3` will possibly indicate a prunable range for the left /// side. Conversely, when a new `RecordBatch` arrives at the left side, the -/// condition a < b + 10 will possibly indicate prunability for the right side. -/// Let’s inspect what happens when a new RecordBatch` arrives at the right +/// condition `a < b + 10` will possibly indicate prunability for the right side. +/// Let’s inspect what happens when a new `RecordBatch` arrives at the right /// side (i.e. when the left side is the build side): /// +/// ```text /// Build Probe /// +-------+ +-------+ /// | a | z | | b | y | @@ -456,13 +465,13 @@ impl SortedFilterExpr { /// |+--|--+| |+--|--+| /// | 7 | 1 | | 6 | 3 | /// +-------+ +-------+ +/// ``` /// /// In this case, the interval representing viable (i.e. joinable) values for -/// column "a" is [1, ∞], and the interval representing possible future values -/// for column "b" is [6, ∞]. With these intervals at hand, we next calculate +/// column `a` is `[1, ∞]`, and the interval representing possible future values +/// for column `b` is `[6, ∞]`. With these intervals at hand, we next calculate /// intervals for the whole filter expression and propagate join constraint by /// traversing the expression graph. -/// ``` pub fn calculate_filter_expr_intervals( build_input_buffer: &RecordBatch, build_sorted_filter_expr: &mut SortedFilterExpr, @@ -710,13 +719,21 @@ fn update_sorted_exprs_with_node_indices( } } -/// Prepares and sorts expressions based on a given filter, left and right execution plans, and sort expressions. +/// Prepares and sorts expressions based on a given filter, left and right schemas, +/// and sort expressions. /// -/// # Arguments +/// This function prepares sorted filter expressions for both the left and right +/// sides of a join operation. It first builds the filter order for each side +/// based on the provided `ExecutionPlan`. If both sides have valid sorted filter +/// expressions, the function then constructs an expression interval graph and +/// updates the sorted expressions with node indices. The final sorted filter +/// expressions for both sides are then returned. +/// +/// # Parameters /// /// * `filter` - The join filter to base the sorting on. -/// * `left` - The left execution plan. -/// * `right` - The right execution plan. +/// * `left` - The `ExecutionPlan` for the left side of the join. +/// * `right` - The `ExecutionPlan` for the right side of the join. /// * `left_sort_exprs` - The expressions to sort on the left side. /// * `right_sort_exprs` - The expressions to sort on the right side. /// @@ -730,9 +747,11 @@ pub fn prepare_sorted_exprs( left_sort_exprs: &[PhysicalSortExpr], right_sort_exprs: &[PhysicalSortExpr], ) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> { - // Build the filter order for the left side - let err = || plan_datafusion_err!("Filter does not include the child order"); + let err = || { + datafusion_common::plan_datafusion_err!("Filter does not include the child order") + }; + // Build the filter order for the left side: let left_temp_sorted_filter_expr = build_filter_input_order( JoinSide::Left, filter, @@ -741,7 +760,7 @@ pub fn prepare_sorted_exprs( )? .ok_or_else(err)?; - // Build the filter order for the right side + // Build the filter order for the right side: let right_temp_sorted_filter_expr = build_filter_input_order( JoinSide::Right, filter, @@ -952,15 +971,15 @@ pub mod tests { let filter_expr = complicated_filter(&intermediate_schema)?; let column_indices = vec![ ColumnIndex { - index: 0, + index: left_schema.index_of("la1")?, side: JoinSide::Left, }, ColumnIndex { - index: 4, + index: left_schema.index_of("la2")?, side: JoinSide::Left, }, ColumnIndex { - index: 0, + index: right_schema.index_of("ra1")?, side: JoinSide::Right, }, ]; diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index ac718a95e9f4..70ada3892aea 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -32,7 +32,6 @@ use std::task::{Context, Poll}; use std::vec; use crate::common::SharedMemoryReservation; -use crate::handle_state; use crate::joins::hash_join::{equal_rows_arr, update_hash}; use crate::joins::stream_join_utils::{ calculate_filter_expr_intervals, combine_two_batches, @@ -42,8 +41,9 @@ use crate::joins::stream_join_utils::{ }; use crate::joins::utils::{ apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, - check_join_is_valid, symmetric_join_output_partitioning, ColumnIndex, JoinFilter, - JoinHashMapType, JoinOn, JoinOnRef, StatefulStreamResult, + check_join_is_valid, symmetric_join_output_partitioning, BatchSplitter, + BatchTransformer, ColumnIndex, JoinFilter, JoinHashMapType, JoinOn, JoinOnRef, + NoopBatchTransformer, StatefulStreamResult, }; use crate::{ execution_mode_from_children, @@ -465,23 +465,27 @@ impl ExecutionPlan for SymmetricHashJoinExec { consider using RepartitionExec" ); } - // If `filter_state` and `filter` are both present, then calculate sorted filter expressions - // for both sides, and build an expression graph. - let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = - match (&self.left_sort_exprs, &self.right_sort_exprs, &self.filter) { - (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => { - let (left, right, graph) = prepare_sorted_exprs( - filter, - &self.left, - &self.right, - left_sort_exprs, - right_sort_exprs, - )?; - (Some(left), Some(right), Some(graph)) - } - // If `filter_state` or `filter` is not present, then return None for all three values: - _ => (None, None, None), - }; + // If `filter_state` and `filter` are both present, then calculate sorted + // filter expressions for both sides, and build an expression graph. + let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = match ( + self.left_sort_exprs(), + self.right_sort_exprs(), + &self.filter, + ) { + (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => { + let (left, right, graph) = prepare_sorted_exprs( + filter, + &self.left, + &self.right, + left_sort_exprs, + right_sort_exprs, + )?; + (Some(left), Some(right), Some(graph)) + } + // If `filter_state` or `filter` is not present, then return None + // for all three values: + _ => (None, None, None), + }; let (on_left, on_right) = self.on.iter().cloned().unzip(); @@ -494,6 +498,10 @@ impl ExecutionPlan for SymmetricHashJoinExec { let right_stream = self.right.execute(partition, Arc::clone(&context))?; + let batch_size = context.session_config().batch_size(); + let enforce_batch_size_in_joins = + context.session_config().enforce_batch_size_in_joins(); + let reservation = Arc::new(Mutex::new( MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]")) .register(context.memory_pool()), @@ -502,29 +510,52 @@ impl ExecutionPlan for SymmetricHashJoinExec { reservation.lock().try_grow(g.size())?; } - Ok(Box::pin(SymmetricHashJoinStream { - left_stream, - right_stream, - schema: self.schema(), - filter: self.filter.clone(), - join_type: self.join_type, - random_state: self.random_state.clone(), - left: left_side_joiner, - right: right_side_joiner, - column_indices: self.column_indices.clone(), - metrics: StreamJoinMetrics::new(partition, &self.metrics), - graph, - left_sorted_filter_expr, - right_sorted_filter_expr, - null_equals_null: self.null_equals_null, - state: SHJStreamState::PullRight, - reservation, - })) + if enforce_batch_size_in_joins { + Ok(Box::pin(SymmetricHashJoinStream { + left_stream, + right_stream, + schema: self.schema(), + filter: self.filter.clone(), + join_type: self.join_type, + random_state: self.random_state.clone(), + left: left_side_joiner, + right: right_side_joiner, + column_indices: self.column_indices.clone(), + metrics: StreamJoinMetrics::new(partition, &self.metrics), + graph, + left_sorted_filter_expr, + right_sorted_filter_expr, + null_equals_null: self.null_equals_null, + state: SHJStreamState::PullRight, + reservation, + batch_transformer: BatchSplitter::new(batch_size), + })) + } else { + Ok(Box::pin(SymmetricHashJoinStream { + left_stream, + right_stream, + schema: self.schema(), + filter: self.filter.clone(), + join_type: self.join_type, + random_state: self.random_state.clone(), + left: left_side_joiner, + right: right_side_joiner, + column_indices: self.column_indices.clone(), + metrics: StreamJoinMetrics::new(partition, &self.metrics), + graph, + left_sorted_filter_expr, + right_sorted_filter_expr, + null_equals_null: self.null_equals_null, + state: SHJStreamState::PullRight, + reservation, + batch_transformer: NoopBatchTransformer::new(), + })) + } } } /// A stream that issues [RecordBatch]es as they arrive from the right of the join. -struct SymmetricHashJoinStream { +struct SymmetricHashJoinStream { /// Input streams left_stream: SendableRecordBatchStream, right_stream: SendableRecordBatchStream, @@ -556,15 +587,19 @@ struct SymmetricHashJoinStream { reservation: SharedMemoryReservation, /// State machine for input execution state: SHJStreamState, + /// Transforms the output batch before returning. + batch_transformer: T, } -impl RecordBatchStream for SymmetricHashJoinStream { +impl RecordBatchStream + for SymmetricHashJoinStream +{ fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } } -impl Stream for SymmetricHashJoinStream { +impl Stream for SymmetricHashJoinStream { type Item = Result; fn poll_next( @@ -1140,7 +1175,7 @@ impl OneSideHashJoiner { /// - Transition to `BothExhausted { final_result: true }`: /// - Occurs in `prepare_for_final_results_after_exhaustion` when both streams are /// exhausted, indicating completion of processing and availability of final results. -impl SymmetricHashJoinStream { +impl SymmetricHashJoinStream { /// Implements the main polling logic for the join stream. /// /// This method continuously checks the state of the join stream and @@ -1159,26 +1194,45 @@ impl SymmetricHashJoinStream { cx: &mut Context<'_>, ) -> Poll>> { loop { - return match self.state() { - SHJStreamState::PullRight => { - handle_state!(ready!(self.fetch_next_from_right_stream(cx))) - } - SHJStreamState::PullLeft => { - handle_state!(ready!(self.fetch_next_from_left_stream(cx))) + match self.batch_transformer.next() { + None => { + let result = match self.state() { + SHJStreamState::PullRight => { + ready!(self.fetch_next_from_right_stream(cx)) + } + SHJStreamState::PullLeft => { + ready!(self.fetch_next_from_left_stream(cx)) + } + SHJStreamState::RightExhausted => { + ready!(self.handle_right_stream_end(cx)) + } + SHJStreamState::LeftExhausted => { + ready!(self.handle_left_stream_end(cx)) + } + SHJStreamState::BothExhausted { + final_result: false, + } => self.prepare_for_final_results_after_exhaustion(), + SHJStreamState::BothExhausted { final_result: true } => { + return Poll::Ready(None); + } + }; + + match result? { + StatefulStreamResult::Ready(None) => { + return Poll::Ready(None); + } + StatefulStreamResult::Ready(Some(batch)) => { + self.batch_transformer.set_batch(batch); + } + _ => {} + } } - SHJStreamState::RightExhausted => { - handle_state!(ready!(self.handle_right_stream_end(cx))) - } - SHJStreamState::LeftExhausted => { - handle_state!(ready!(self.handle_left_stream_end(cx))) - } - SHJStreamState::BothExhausted { - final_result: false, - } => { - handle_state!(self.prepare_for_final_results_after_exhaustion()) + Some((batch, _)) => { + self.metrics.output_batches.add(1); + self.metrics.output_rows.add(batch.num_rows()); + return Poll::Ready(Some(Ok(batch))); } - SHJStreamState::BothExhausted { final_result: true } => Poll::Ready(None), - }; + } } } /// Asynchronously pulls the next batch from the right stream. @@ -1384,11 +1438,8 @@ impl SymmetricHashJoinStream { // Combine the left and right results: let result = combine_two_batches(&self.schema, left_result, right_result)?; - // Update the metrics and return the result: - if let Some(batch) = &result { - // Update the metrics: - self.metrics.output_batches.add(1); - self.metrics.output_rows.add(batch.num_rows()); + // Return the result: + if result.is_some() { return Ok(StatefulStreamResult::Ready(result)); } Ok(StatefulStreamResult::Continue) @@ -1523,11 +1574,6 @@ impl SymmetricHashJoinStream { let capacity = self.size(); self.metrics.stream_memory_usage.set(capacity); self.reservation.lock().try_resize(capacity)?; - // Update the metrics if we have a batch; otherwise, continue the loop. - if let Some(batch) = &result { - self.metrics.output_batches.add(1); - self.metrics.output_rows.add(batch.num_rows()); - } Ok(result) } } @@ -1716,15 +1762,15 @@ mod tests { let filter_expr = complicated_filter(&intermediate_schema)?; let column_indices = vec![ ColumnIndex { - index: 0, + index: left_schema.index_of("la1")?, side: JoinSide::Left, }, ColumnIndex { - index: 4, + index: left_schema.index_of("la2")?, side: JoinSide::Left, }, ColumnIndex { - index: 0, + index: right_schema.index_of("ra1")?, side: JoinSide::Right, }, ]; @@ -1771,10 +1817,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -1825,10 +1868,7 @@ mod tests { let (left, right) = create_memory_table(left_partition, right_partition, vec![], vec![])?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -1877,10 +1917,7 @@ mod tests { let (left, right) = create_memory_table(left_partition, right_partition, vec![], vec![])?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; experiment(left, right, None, join_type, on, task_ctx).await?; Ok(()) } @@ -1926,10 +1963,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -1987,10 +2021,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -2048,10 +2079,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -2111,10 +2139,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -2170,10 +2195,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("0", DataType::Int32, true), @@ -2237,10 +2259,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("0", DataType::Int32, true), @@ -2296,10 +2315,7 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let left_sorted = vec![PhysicalSortExpr { expr: col("lt1", left_schema)?, options: SortOptions { @@ -2380,10 +2396,7 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let left_sorted = vec![PhysicalSortExpr { expr: col("li1", left_schema)?, options: SortOptions { @@ -2473,10 +2486,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Float64, true), diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 89f3feaf07be..c520e4271416 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -546,15 +546,16 @@ pub struct ColumnIndex { pub side: JoinSide, } -/// Filter applied before join output +/// Filter applied before join output. Fields are crate-public to allow +/// downstream implementations to experiment with custom joins. #[derive(Debug, Clone)] pub struct JoinFilter { /// Filter expression - expression: Arc, + pub(crate) expression: Arc, /// Column indices required to construct intermediate batch for filtering - column_indices: Vec, + pub(crate) column_indices: Vec, /// Physical schema of intermediate batch - schema: Schema, + pub(crate) schema: Schema, } impl JoinFilter { @@ -1280,15 +1281,15 @@ pub(crate) fn adjust_indices_by_join_type( adjust_range: Range, join_type: JoinType, preserve_order_for_right: bool, -) -> (UInt64Array, UInt32Array) { +) -> Result<(UInt64Array, UInt32Array)> { match join_type { JoinType::Inner => { // matched - (left_indices, right_indices) + Ok((left_indices, right_indices)) } JoinType::Left => { // matched - (left_indices, right_indices) + Ok((left_indices, right_indices)) // unmatched left row will be produced in the end of loop, and it has been set in the left visited bitmap } JoinType::Right => { @@ -1307,22 +1308,22 @@ pub(crate) fn adjust_indices_by_join_type( // need to remove the duplicated record in the right side let right_indices = get_semi_indices(adjust_range, &right_indices); // the left_indices will not be used later for the `right semi` join - (left_indices, right_indices) + Ok((left_indices, right_indices)) } JoinType::RightAnti => { // need to remove the duplicated record in the right side // get the anti index for the right side let right_indices = get_anti_indices(adjust_range, &right_indices); // the left_indices will not be used later for the `right anti` join - (left_indices, right_indices) + Ok((left_indices, right_indices)) } JoinType::LeftSemi | JoinType::LeftAnti => { // matched or unmatched left row will be produced in the end of loop // When visit the right batch, we can output the matched left row and don't need to wait the end of loop - ( + Ok(( UInt64Array::from_iter_values(vec![]), UInt32Array::from_iter_values(vec![]), - ) + )) } } } @@ -1347,27 +1348,64 @@ pub(crate) fn append_right_indices( right_indices: UInt32Array, adjust_range: Range, preserve_order_for_right: bool, -) -> (UInt64Array, UInt32Array) { +) -> Result<(UInt64Array, UInt32Array)> { if preserve_order_for_right { - append_probe_indices_in_order(left_indices, right_indices, adjust_range) + Ok(append_probe_indices_in_order( + left_indices, + right_indices, + adjust_range, + )) } else { let right_unmatched_indices = get_anti_indices(adjust_range, &right_indices); if right_unmatched_indices.is_empty() { - (left_indices, right_indices) + Ok((left_indices, right_indices)) } else { - let unmatched_size = right_unmatched_indices.len(); + // `into_builder()` can fail here when there is nothing to be filtered and + // left_indices or right_indices has the same reference to the cached indices. + // In that case, we use a slower alternative. + // the new left indices: left_indices + null array + let mut new_left_indices_builder = + left_indices.into_builder().unwrap_or_else(|left_indices| { + let mut builder = UInt64Builder::with_capacity( + left_indices.len() + right_unmatched_indices.len(), + ); + debug_assert_eq!( + left_indices.null_count(), + 0, + "expected left indices to have no nulls" + ); + builder.append_slice(left_indices.values()); + builder + }); + new_left_indices_builder.append_nulls(right_unmatched_indices.len()); + let new_left_indices = UInt64Array::from(new_left_indices_builder.finish()); + // the new right indices: right_indices + right_unmatched_indices - let new_left_indices = left_indices - .iter() - .chain(std::iter::repeat(None).take(unmatched_size)) - .collect(); - let new_right_indices = right_indices - .iter() - .chain(right_unmatched_indices.iter()) - .collect(); - (new_left_indices, new_right_indices) + let mut new_right_indices_builder = right_indices + .into_builder() + .unwrap_or_else(|right_indices| { + let mut builder = UInt32Builder::with_capacity( + right_indices.len() + right_unmatched_indices.len(), + ); + debug_assert_eq!( + right_indices.null_count(), + 0, + "expected right indices to have no nulls" + ); + builder.append_slice(right_indices.values()); + builder + }); + debug_assert_eq!( + right_unmatched_indices.null_count(), + 0, + "expected right unmatched indices to have no nulls" + ); + new_right_indices_builder.append_slice(right_unmatched_indices.values()); + let new_right_indices = UInt32Array::from(new_right_indices_builder.finish()); + + Ok((new_left_indices, new_right_indices)) } } } @@ -1635,6 +1673,91 @@ pub(crate) fn asymmetric_join_output_partitioning( } } +/// Trait for incrementally generating Join output. +/// +/// This trait is used to limit some join outputs +/// so it does not produce single large batches +pub(crate) trait BatchTransformer: Debug + Clone { + /// Sets the next `RecordBatch` to be processed. + fn set_batch(&mut self, batch: RecordBatch); + + /// Retrieves the next `RecordBatch` from the transformer. + /// Returns `None` if all batches have been produced. + /// The boolean flag indicates whether the batch is the last one. + fn next(&mut self) -> Option<(RecordBatch, bool)>; +} + +#[derive(Debug, Clone)] +/// A batch transformer that does nothing. +pub(crate) struct NoopBatchTransformer { + /// RecordBatch to be processed + batch: Option, +} + +impl NoopBatchTransformer { + pub fn new() -> Self { + Self { batch: None } + } +} + +impl BatchTransformer for NoopBatchTransformer { + fn set_batch(&mut self, batch: RecordBatch) { + self.batch = Some(batch); + } + + fn next(&mut self) -> Option<(RecordBatch, bool)> { + self.batch.take().map(|batch| (batch, true)) + } +} + +#[derive(Debug, Clone)] +/// Splits large batches into smaller batches with a maximum number of rows. +pub(crate) struct BatchSplitter { + /// RecordBatch to be split + batch: Option, + /// Maximum number of rows in a split batch + batch_size: usize, + /// Current row index + row_index: usize, +} + +impl BatchSplitter { + /// Creates a new `BatchSplitter` with the specified batch size. + pub(crate) fn new(batch_size: usize) -> Self { + Self { + batch: None, + batch_size, + row_index: 0, + } + } +} + +impl BatchTransformer for BatchSplitter { + fn set_batch(&mut self, batch: RecordBatch) { + self.batch = Some(batch); + self.row_index = 0; + } + + fn next(&mut self) -> Option<(RecordBatch, bool)> { + let Some(batch) = &self.batch else { + return None; + }; + + let remaining_rows = batch.num_rows() - self.row_index; + let rows_to_slice = remaining_rows.min(self.batch_size); + let sliced_batch = batch.slice(self.row_index, rows_to_slice); + self.row_index += rows_to_slice; + + let mut last = false; + if self.row_index >= batch.num_rows() { + self.batch = None; + last = true; + } + + Some((sliced_batch, last)) + } +} + #[cfg(test)] mod tests { use std::pin::Pin; @@ -1643,11 +1766,13 @@ mod tests { use arrow::datatypes::{DataType, Fields}; use arrow::error::{ArrowError, Result as ArrowResult}; + use arrow_array::Int32Array; use arrow_schema::SortOptions; - use datafusion_common::stats::Precision::{Absent, Exact, Inexact}; use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue}; + use rstest::rstest; + fn check( left: &[Column], right: &[Column], @@ -2554,4 +2679,49 @@ mod tests { Ok(()) } + + fn create_test_batch(num_rows: usize) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let data = Arc::new(Int32Array::from_iter_values(0..num_rows as i32)); + RecordBatch::try_new(schema, vec![data]).unwrap() + } + + fn assert_split_batches( + batches: Vec<(RecordBatch, bool)>, + batch_size: usize, + num_rows: usize, + ) { + let mut row_count = 0; + for (batch, last) in batches.into_iter() { + assert_eq!(batch.num_rows(), (num_rows - row_count).min(batch_size)); + let column = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..batch.num_rows() { + assert_eq!(column.value(i), i as i32 + row_count as i32); + } + row_count += batch.num_rows(); + assert_eq!(last, row_count == num_rows); + } + } + + #[rstest] + #[test] + fn test_batch_splitter( + #[values(1, 3, 11)] batch_size: usize, + #[values(1, 6, 50)] num_rows: usize, + ) { + let mut splitter = BatchSplitter::new(batch_size); + splitter.set_batch(create_test_batch(num_rows)); + + let mut batches = Vec::with_capacity(num_rows.div_ceil(batch_size)); + while let Some(batch) = splitter.next() { + batches.push(batch); + } + + assert!(splitter.next().is_none()); + assert_split_batches(batches, batch_size, num_rows); + } } diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 7acdf25b6596..57bf029a63c1 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -173,6 +173,7 @@ datafusion.execution.batch_size 8192 datafusion.execution.coalesce_batches true datafusion.execution.collect_statistics false datafusion.execution.enable_recursive_ctes true +datafusion.execution.enforce_batch_size_in_joins false datafusion.execution.keep_partition_by_columns false datafusion.execution.listing_table_ignore_subdirectory true datafusion.execution.max_buffered_batches_per_output_file 2 @@ -263,6 +264,7 @@ datafusion.execution.batch_size 8192 Default batch size while creating new batch datafusion.execution.coalesce_batches true When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting datafusion.execution.collect_statistics false Should DataFusion collect statistics after listing files datafusion.execution.enable_recursive_ctes true Should DataFusion support recursive CTEs +datafusion.execution.enforce_batch_size_in_joins false Should DataFusion enforce batch size in joins or not. By default, DataFusion will not enforce batch size in joins. Enforcing batch size in joins can reduce memory usage when joining large tables with a highly-selective join filter, but is also slightly slower. datafusion.execution.keep_partition_by_columns false Should DataFusion keep the columns used for partition_by in the output RecordBatches datafusion.execution.listing_table_ignore_subdirectory true Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). datafusion.execution.max_buffered_batches_per_output_file 2 This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index f34d148f092f..c61a7b673334 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -91,6 +91,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.skip_partial_aggregation_probe_ratio_threshold | 0.8 | Aggregation ratio (number of distinct groups / number of input rows) threshold for skipping partial aggregation. If the value is greater then partial aggregation will skip aggregation for further input | | datafusion.execution.skip_partial_aggregation_probe_rows_threshold | 100000 | Number of input rows partial aggregation partition should process, before aggregation ratio check and trying to switch to skipping aggregation mode | | datafusion.execution.use_row_number_estimates_to_optimize_partitioning | false | Should DataFusion use row number estimates at the input to decide whether increasing parallelism is beneficial or not. By default, only exact row numbers (not estimates) are used for this decision. Setting this flag to `true` will likely produce better plans. if the source of statistics is accurate. We plan to make this the default in the future. | +| datafusion.execution.enforce_batch_size_in_joins | false | Should DataFusion enforce batch size in joins or not. By default, DataFusion will not enforce batch size in joins. Enforcing batch size in joins can reduce memory usage when joining large tables with a highly-selective join filter, but is also slightly slower. | | datafusion.optimizer.enable_distinct_aggregation_soft_limit | true | When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. | | datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores | | datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible | From e9435a920ed84a1956b23e7ab6d13fe833cce3eb Mon Sep 17 00:00:00 2001 From: yi wang <48236141+my-vegetable-has-exploded@users.noreply.github.com> Date: Sat, 19 Oct 2024 00:52:23 +0800 Subject: [PATCH 017/110] =?UTF-8?q?Fix=EF=BC=9Afix=20HashJoin=20projection?= =?UTF-8?q?=20swap=20(#12967)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * swap_hash_join works with joins with projections * use non swapped hash join's projection * clean up * fix hashjoin projection swap. * assert hashjoinexec. * Update datafusion/core/src/physical_optimizer/join_selection.rs Co-authored-by: Eduard Karacharov * fix clippy. --------- Co-authored-by: Onur Satici Co-authored-by: Eduard Karacharov --- .../src/physical_optimizer/join_selection.rs | 31 ++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 499fb9cbbcf0..dfaa7dbb8910 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -183,13 +183,15 @@ pub fn swap_hash_join( partition_mode, hash_join.null_equals_null(), )?; + // In case of anti / semi joins or if there is embedded projection in HashJoinExec, output column order is preserved, no need to add projection again if matches!( hash_join.join_type(), JoinType::LeftSemi | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti - ) { + ) || hash_join.projection.is_some() + { Ok(Arc::new(new_join)) } else { // TODO avoid adding ProjectionExec again and again, only adding Final Projection @@ -1287,6 +1289,33 @@ mod tests_statistical { ); } + #[tokio::test] + async fn test_hash_join_swap_on_joins_with_projections() -> Result<()> { + let (big, small) = create_big_and_small(); + let join = Arc::new(HashJoinExec::try_new( + Arc::clone(&big), + Arc::clone(&small), + vec![( + Arc::new(Column::new_with_schema("big_col", &big.schema())?), + Arc::new(Column::new_with_schema("small_col", &small.schema())?), + )], + None, + &JoinType::Inner, + Some(vec![1]), + PartitionMode::Partitioned, + false, + )?); + let swapped = swap_hash_join(&join.clone(), PartitionMode::Partitioned) + .expect("swap_hash_join must support joins with projections"); + let swapped_join = swapped.as_any().downcast_ref::().expect( + "ProjectionExec won't be added above if HashJoinExec contains embedded projection", + ); + assert_eq!(swapped_join.projection, Some(vec![0_usize])); + assert_eq!(swapped.schema().fields.len(), 1); + assert_eq!(swapped.schema().fields[0].name(), "small_col"); + Ok(()) + } + #[tokio::test] async fn test_swap_reverting_projection() { let left_schema = Schema::new(vec![ From 97f7491ed62ed7643b8b466237fd1ceb19a54431 Mon Sep 17 00:00:00 2001 From: Tornike Gurgenidze Date: Fri, 18 Oct 2024 23:06:45 +0400 Subject: [PATCH 018/110] refactor(substrait): refactor ReadRel consumer (#12983) --- .../substrait/src/logical_plan/consumer.rs | 181 +++++++++--------- 1 file changed, 87 insertions(+), 94 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 4af02858e65a..08e54166d39a 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -794,60 +794,61 @@ pub async fn from_substrait_rel( let (left, right) = requalify_sides_if_needed(left, right)?; left.cross_join(right.build()?)?.build() } - Some(RelType::Read(read)) => match &read.as_ref().read_type { - Some(ReadType::NamedTable(nt)) => { - let named_struct = read.base_schema.as_ref().ok_or_else(|| { - substrait_datafusion_err!("No base schema provided for Named Table") - })?; + Some(RelType::Read(read)) => { + fn read_with_schema( + df: DataFrame, + schema: DFSchema, + projection: &Option, + ) -> Result { + ensure_schema_compatability(df.schema().to_owned(), schema.clone())?; - let table_reference = match nt.names.len() { - 0 => { - return plan_err!("No table name found in NamedTable"); - } - 1 => TableReference::Bare { - table: nt.names[0].clone().into(), - }, - 2 => TableReference::Partial { - schema: nt.names[0].clone().into(), - table: nt.names[1].clone().into(), - }, - _ => TableReference::Full { - catalog: nt.names[0].clone().into(), - schema: nt.names[1].clone().into(), - table: nt.names[2].clone().into(), - }, - }; + let schema = apply_masking(schema, projection)?; - let t = ctx.table(table_reference.clone()).await?; + apply_projection(df, schema) + } - let substrait_schema = - from_substrait_named_struct(named_struct, extensions)? - .replace_qualifier(table_reference); + let named_struct = read.base_schema.as_ref().ok_or_else(|| { + substrait_datafusion_err!("No base schema provided for Read Relation") + })?; - ensure_schema_compatability( - t.schema().to_owned(), - substrait_schema.clone(), - )?; + let substrait_schema = from_substrait_named_struct(named_struct, extensions)?; - let substrait_schema = apply_masking(substrait_schema, &read.projection)?; + match &read.as_ref().read_type { + Some(ReadType::NamedTable(nt)) => { + let table_reference = match nt.names.len() { + 0 => { + return plan_err!("No table name found in NamedTable"); + } + 1 => TableReference::Bare { + table: nt.names[0].clone().into(), + }, + 2 => TableReference::Partial { + schema: nt.names[0].clone().into(), + table: nt.names[1].clone().into(), + }, + _ => TableReference::Full { + catalog: nt.names[0].clone().into(), + schema: nt.names[1].clone().into(), + table: nt.names[2].clone().into(), + }, + }; - apply_projection(t, substrait_schema) - } - Some(ReadType::VirtualTable(vt)) => { - let base_schema = read.base_schema.as_ref().ok_or_else(|| { - substrait_datafusion_err!("No base schema provided for Virtual Table") - })?; + let t = ctx.table(table_reference.clone()).await?; - let schema = from_substrait_named_struct(base_schema, extensions)?; + let substrait_schema = + substrait_schema.replace_qualifier(table_reference); - if vt.values.is_empty() { - return Ok(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: DFSchemaRef::new(schema), - })); + read_with_schema(t, substrait_schema, &read.projection) } + Some(ReadType::VirtualTable(vt)) => { + if vt.values.is_empty() { + return Ok(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: DFSchemaRef::new(substrait_schema), + })); + } - let values = vt + let values = vt .values .iter() .map(|row| { @@ -860,79 +861,71 @@ pub async fn from_substrait_rel( Ok(Expr::Literal(from_substrait_literal( lit, extensions, - &base_schema.names, + &named_struct.names, &mut name_idx, )?)) }) .collect::>()?; - if name_idx != base_schema.names.len() { + if name_idx != named_struct.names.len() { return substrait_err!( "Names list must match exactly to nested schema, but found {} uses for {} names", name_idx, - base_schema.names.len() + named_struct.names.len() ); } Ok(lits) }) .collect::>()?; - Ok(LogicalPlan::Values(Values { - schema: DFSchemaRef::new(schema), - values, - })) - } - Some(ReadType::LocalFiles(lf)) => { - let named_struct = read.base_schema.as_ref().ok_or_else(|| { - substrait_datafusion_err!("No base schema provided for LocalFiles") - })?; - - fn extract_filename(name: &str) -> Option { - let corrected_url = - if name.starts_with("file://") && !name.starts_with("file:///") { + Ok(LogicalPlan::Values(Values { + schema: DFSchemaRef::new(substrait_schema), + values, + })) + } + Some(ReadType::LocalFiles(lf)) => { + fn extract_filename(name: &str) -> Option { + let corrected_url = if name.starts_with("file://") + && !name.starts_with("file:///") + { name.replacen("file://", "file:///", 1) } else { name.to_string() }; - Url::parse(&corrected_url).ok().and_then(|url| { - let path = url.path(); - std::path::Path::new(path) - .file_name() - .map(|filename| filename.to_string_lossy().to_string()) - }) - } - - // we could use the file name to check the original table provider - // TODO: currently does not support multiple local files - let filename: Option = - lf.items.first().and_then(|x| match x.path_type.as_ref() { - Some(UriFile(name)) => extract_filename(name), - _ => None, - }); - - if lf.items.len() > 1 || filename.is_none() { - return not_impl_err!("Only single file reads are supported"); - } - let name = filename.unwrap(); - // directly use unwrap here since we could determine it is a valid one - let table_reference = TableReference::Bare { table: name.into() }; - let t = ctx.table(table_reference.clone()).await?; + Url::parse(&corrected_url).ok().and_then(|url| { + let path = url.path(); + std::path::Path::new(path) + .file_name() + .map(|filename| filename.to_string_lossy().to_string()) + }) + } - let substrait_schema = - from_substrait_named_struct(named_struct, extensions)? - .replace_qualifier(table_reference); + // we could use the file name to check the original table provider + // TODO: currently does not support multiple local files + let filename: Option = + lf.items.first().and_then(|x| match x.path_type.as_ref() { + Some(UriFile(name)) => extract_filename(name), + _ => None, + }); - ensure_schema_compatability( - t.schema().to_owned(), - substrait_schema.clone(), - )?; + if lf.items.len() > 1 || filename.is_none() { + return not_impl_err!("Only single file reads are supported"); + } + let name = filename.unwrap(); + // directly use unwrap here since we could determine it is a valid one + let table_reference = TableReference::Bare { table: name.into() }; + let t = ctx.table(table_reference.clone()).await?; - let substrait_schema = apply_masking(substrait_schema, &read.projection)?; + let substrait_schema = + substrait_schema.replace_qualifier(table_reference); - apply_projection(t, substrait_schema) + read_with_schema(t, substrait_schema, &read.projection) + } + _ => { + not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type) + } } - _ => not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type), - }, + } Some(RelType::Set(set)) => match set_rel::SetOp::try_from(set.op) { Ok(set_op) => match set_op { set_rel::SetOp::UnionAll => { From 42f906072a3000d005b8ced97654aaec2828a878 Mon Sep 17 00:00:00 2001 From: Tornike Gurgenidze Date: Fri, 18 Oct 2024 23:06:58 +0400 Subject: [PATCH 019/110] feat(substrait): add wildcard handling to producer (#12987) * feat(substrait): add wildcard expand rule in producer * add comment describing need for ExpandWildcardRule --- .../substrait/src/logical_plan/producer.rs | 10 +++++- .../tests/cases/roundtrip_logical_plan.rs | 34 ++++++++++++++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 0e1375a8e0ea..7504a287c055 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +use datafusion::config::ConfigOptions; +use datafusion::optimizer::analyzer::expand_wildcard_rule::ExpandWildcardRule; +use datafusion::optimizer::AnalyzerRule; use std::sync::Arc; use substrait::proto::expression_reference::ExprType; @@ -103,9 +106,14 @@ pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result Result<()> { #[tokio::test] async fn wildcard_select() -> Result<()> { - roundtrip("SELECT * FROM data").await + assert_expected_plan_unoptimized( + "SELECT * FROM data", + "Projection: data.a, data.b, data.c, data.d, data.e, data.f\ + \n TableScan: data", + true, + ) + .await } #[tokio::test] @@ -1174,6 +1180,32 @@ async fn verify_post_join_filter_value(proto: Box) -> Result<()> { Ok(()) } +async fn assert_expected_plan_unoptimized( + sql: &str, + expected_plan_str: &str, + assert_schema: bool, +) -> Result<()> { + let ctx = create_context().await?; + let df = ctx.sql(sql).await?; + let plan = df.into_unoptimized_plan(); + let proto = to_substrait_plan(&plan, &ctx)?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; + + println!("{plan}"); + println!("{plan2}"); + + println!("{proto:?}"); + + if assert_schema { + assert_eq!(plan.schema(), plan2.schema()); + } + + let plan2str = format!("{plan2}"); + assert_eq!(expected_plan_str, &plan2str); + + Ok(()) +} + async fn assert_expected_plan( sql: &str, expected_plan_str: &str, From 3405234836be98860ce1516ed2263c163ada5535 Mon Sep 17 00:00:00 2001 From: Oleks V Date: Fri, 18 Oct 2024 12:26:48 -0700 Subject: [PATCH 020/110] Move SMJ join filtered part out of join_output stage. LeftOuter, LeftSemi (#12764) * WIP: move filtered join out of join_output stage * WIP: move filtered join out of join_output stage * WIP: move filtered join out of join_output stage * cleanup * cleanup * Move Left/LeftAnti filtered SMJ join out of join partial stage * Move Left/LeftAnti filtered SMJ join out of join partial stage * Address comments --- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 12 +- .../src/joins/sort_merge_join.rs | 1095 ++++++++++++----- .../test_files/sort_merge_join.slt | 478 +++---- 3 files changed, 1061 insertions(+), 524 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 96aa1be181f5..2eab45256dbb 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -125,8 +125,6 @@ async fn test_left_join_1k() { } #[tokio::test] -// flaky for HjSmj case -// https://github.com/apache/datafusion/issues/12359 async fn test_left_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -134,7 +132,7 @@ async fn test_left_join_1k_filtered() { JoinType::Left, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::NljHj], false) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -229,6 +227,7 @@ async fn test_anti_join_1k() { #[tokio::test] // flaky for HjSmj case, giving 1 rows difference sometimes // https://github.com/apache/datafusion/issues/11555 +#[ignore] async fn test_anti_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -515,14 +514,11 @@ impl JoinFuzzTestCase { "input2", ); - if join_tests.contains(&JoinTestType::NljHj) - && join_tests.contains(&JoinTestType::NljHj) - && nlj_rows != hj_rows - { + if join_tests.contains(&JoinTestType::NljHj) && nlj_rows != hj_rows { println!("=============== HashJoinExec =================="); hj_formatted_sorted.iter().for_each(|s| println!("{}", s)); println!("=============== NestedLoopJoinExec =================="); - smj_formatted_sorted.iter().for_each(|s| println!("{}", s)); + nlj_formatted_sorted.iter().for_each(|s| println!("{}", s)); Self::save_partitioned_batches_as_parquet( &nlj_collected, diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 2118c1a5266f..5e77becd1c5e 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -29,18 +29,17 @@ use std::io::BufReader; use std::mem; use std::ops::Range; use std::pin::Pin; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering::Relaxed; use std::sync::Arc; use std::task::{Context, Poll}; use arrow::array::*; -use arrow::compute::{self, concat_batches, take, SortOptions}; +use arrow::compute::{self, concat_batches, filter_record_batch, take, SortOptions}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; use arrow::ipc::reader::FileReader; use arrow_array::types::UInt64Type; -use futures::{Stream, StreamExt}; -use hashbrown::HashSet; - use datafusion_common::{ exec_err, internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result, @@ -52,6 +51,8 @@ use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; use datafusion_physical_expr_common::sort_expr::LexRequirement; +use futures::{Stream, StreamExt}; +use hashbrown::HashSet; use crate::expressions::PhysicalSortExpr; use crate::joins::utils::{ @@ -687,7 +688,7 @@ struct SMJStream { /// optional join filter pub filter: Option, /// Staging output array builders - pub output_record_batches: Vec, + pub output_record_batches: JoinedRecordBatches, /// Staging output size, including output batches and staging joined results. /// Increased when we put rows into buffer and decreased after we actually output batches. /// Used to trigger output when sufficient rows are ready @@ -702,6 +703,22 @@ struct SMJStream { pub reservation: MemoryReservation, /// Runtime env pub runtime_env: Arc, + /// A unique number for each batch + pub streamed_batch_counter: AtomicUsize, +} + +/// Joined batches with attached join filter information +struct JoinedRecordBatches { + /// Joined batches. Each batch is already joined columns from left and right sources + pub batches: Vec, + /// Filter match mask for each row(matched/non-matched) + pub filter_mask: BooleanBuilder, + /// Row indices to glue together rows in `batches` and `filter_mask` + pub row_indices: UInt64Builder, + /// Which unique batch id the row belongs to + /// It is necessary to differentiate rows that are distributed the way when they point to the same + /// row index but in not the same batches + pub batch_ids: Vec, } impl RecordBatchStream for SMJStream { @@ -710,6 +727,82 @@ impl RecordBatchStream for SMJStream { } } +#[inline(always)] +fn last_index_for_row( + row_index: usize, + indices: &UInt64Array, + ids: &[usize], + indices_len: usize, +) -> bool { + row_index == indices_len - 1 + || ids[row_index] != ids[row_index + 1] + || indices.value(row_index) != indices.value(row_index + 1) +} + +// Returns a corrected boolean bitmask for the given join type +// Values in the corrected bitmask can be: true, false, null +// `true` - the row found its match and sent to the output +// `null` - the row ignored, no output +// `false` - the row sent as NULL joined row +fn get_corrected_filter_mask( + join_type: JoinType, + indices: &UInt64Array, + ids: &[usize], + filter_mask: &BooleanArray, + expected_size: usize, +) -> Option { + let streamed_indices_length = indices.len(); + let mut corrected_mask: BooleanBuilder = + BooleanBuilder::with_capacity(streamed_indices_length); + let mut seen_true = false; + + match join_type { + JoinType::Left => { + for i in 0..streamed_indices_length { + let last_index = + last_index_for_row(i, indices, ids, streamed_indices_length); + if filter_mask.value(i) { + seen_true = true; + corrected_mask.append_value(true); + } else if seen_true || !filter_mask.value(i) && !last_index { + corrected_mask.append_null(); // to be ignored and not set to output + } else { + corrected_mask.append_value(false); // to be converted to null joined row + } + + if last_index { + seen_true = false; + } + } + + // Generate null joined rows for records which have no matching join key + let null_matched = expected_size - corrected_mask.len(); + corrected_mask.extend(vec![Some(false); null_matched]); + Some(corrected_mask.finish()) + } + JoinType::LeftSemi => { + for i in 0..streamed_indices_length { + let last_index = + last_index_for_row(i, indices, ids, streamed_indices_length); + if filter_mask.value(i) && !seen_true { + seen_true = true; + corrected_mask.append_value(true); + } else { + corrected_mask.append_null(); // to be ignored and not set to output + } + + if last_index { + seen_true = false; + } + } + + Some(corrected_mask.finish()) + } + // Only outer joins needs to keep track of processed rows and apply corrected filter mask + _ => None, + } +} + impl Stream for SMJStream { type Item = Result; @@ -719,7 +812,6 @@ impl Stream for SMJStream { ) -> Poll> { let join_time = self.join_metrics.join_time.clone(); let _timer = join_time.timer(); - loop { match &self.state { SMJState::Init => { @@ -733,6 +825,22 @@ impl Stream for SMJStream { match self.current_ordering { Ordering::Less | Ordering::Equal => { if !streamed_exhausted { + if self.filter.is_some() + && matches!( + self.join_type, + JoinType::Left | JoinType::LeftSemi + ) + { + self.freeze_all()?; + + if !self.output_record_batches.batches.is_empty() + && self.buffered_data.scanning_finished() + { + let out_batch = self.filter_joined_batch()?; + return Poll::Ready(Some(Ok(out_batch))); + } + } + self.streamed_joined = false; self.streamed_state = StreamedState::Init; } @@ -786,8 +894,23 @@ impl Stream for SMJStream { } } else { self.freeze_all()?; - if !self.output_record_batches.is_empty() { + if !self.output_record_batches.batches.is_empty() { let record_batch = self.output_record_batch_and_reset()?; + // For non-filtered join output whenever the target output batch size + // is hit. For filtered join its needed to output on later phase + // because target output batch size can be hit in the middle of + // filtering causing the filtering to be incomplete and causing + // correctness issues + let record_batch = if !(self.filter.is_some() + && matches!( + self.join_type, + JoinType::Left | JoinType::LeftSemi + )) { + record_batch + } else { + continue; + }; + return Poll::Ready(Some(Ok(record_batch))); } return Poll::Pending; @@ -795,11 +918,23 @@ impl Stream for SMJStream { } SMJState::Exhausted => { self.freeze_all()?; - if !self.output_record_batches.is_empty() { - let record_batch = self.output_record_batch_and_reset()?; - return Poll::Ready(Some(Ok(record_batch))); + + if !self.output_record_batches.batches.is_empty() { + if self.filter.is_some() + && matches!( + self.join_type, + JoinType::Left | JoinType::LeftSemi + ) + { + let out = self.filter_joined_batch()?; + return Poll::Ready(Some(Ok(out))); + } else { + let record_batch = self.output_record_batch_and_reset()?; + return Poll::Ready(Some(Ok(record_batch))); + } + } else { + return Poll::Ready(None); } - return Poll::Ready(None); } } } @@ -844,13 +979,19 @@ impl SMJStream { on_streamed, on_buffered, filter, - output_record_batches: vec![], + output_record_batches: JoinedRecordBatches { + batches: vec![], + filter_mask: BooleanBuilder::new(), + row_indices: UInt64Builder::new(), + batch_ids: vec![], + }, output_size: 0, batch_size, join_type, join_metrics, reservation, runtime_env, + streamed_batch_counter: AtomicUsize::new(0), }) } @@ -882,6 +1023,10 @@ impl SMJStream { self.join_metrics.input_rows.add(batch.num_rows()); self.streamed_batch = StreamedBatch::new(batch, &self.on_streamed); + // Every incoming streaming batch should have its unique id + // Check `JoinedRecordBatches.self.streamed_batch_counter` documentation + self.streamed_batch_counter + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); self.streamed_state = StreamedState::Ready; } } @@ -1062,14 +1207,14 @@ impl SMJStream { return Ok(Ordering::Less); } - return compare_join_arrays( + compare_join_arrays( &self.streamed_batch.join_arrays, self.streamed_batch.idx, &self.buffered_data.head_batch().join_arrays, self.buffered_data.head_batch().range.start, &self.sort_options, self.null_equals_null, - ); + ) } /// Produce join and fill output buffer until reaching target batch size @@ -1228,7 +1373,7 @@ impl SMJStream { &buffered_indices, buffered_batch, )? { - self.output_record_batches.push(record_batch); + self.output_record_batches.batches.push(record_batch); } buffered_batch.null_joined.clear(); @@ -1251,7 +1396,7 @@ impl SMJStream { &buffered_indices, buffered_batch, )? { - self.output_record_batches.push(record_batch); + self.output_record_batches.batches.push(record_batch); } buffered_batch.join_filter_failed_map.clear(); } @@ -1329,15 +1474,14 @@ impl SMJStream { }; let columns = if matches!(self.join_type, JoinType::Right) { - buffered_columns.extend(streamed_columns.clone()); + buffered_columns.extend(streamed_columns); buffered_columns } else { streamed_columns.extend(buffered_columns); streamed_columns }; - let output_batch = - RecordBatch::try_new(Arc::clone(&self.schema), columns.clone())?; + let output_batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?; // Apply join filter if any if !filter_columns.is_empty() { @@ -1367,59 +1511,46 @@ impl SMJStream { pre_mask.clone() }; - // For certain join types, we need to adjust the initial mask to handle the join filter. - let maybe_filtered_join_mask: Option<(BooleanArray, Vec)> = - get_filtered_join_mask( - self.join_type, - &streamed_indices, - &mask, - &self.streamed_batch.join_filter_matched_idxs, - &self.buffered_data.scanning_offset, - ); - - let mask = - if let Some(ref filtered_join_mask) = maybe_filtered_join_mask { - self.streamed_batch - .join_filter_matched_idxs - .extend(&filtered_join_mask.1); - &filtered_join_mask.0 - } else { - &mask - }; - // Push the filtered batch which contains rows passing join filter to the output - let filtered_batch = - compute::filter_record_batch(&output_batch, mask)?; - self.output_record_batches.push(filtered_batch); + if matches!(self.join_type, JoinType::Left | JoinType::LeftSemi) { + self.output_record_batches + .batches + .push(output_batch.clone()); + } else { + let filtered_batch = filter_record_batch(&output_batch, &mask)?; + self.output_record_batches.batches.push(filtered_batch); + } + + self.output_record_batches.filter_mask.extend(&mask); + self.output_record_batches + .row_indices + .extend(&streamed_indices); + self.output_record_batches.batch_ids.extend(vec![ + self.streamed_batch_counter.load(Relaxed); + streamed_indices.len() + ]); // For outer joins, we need to push the null joined rows to the output if // all joined rows are failed on the join filter. // I.e., if all rows joined from a streamed row are failed with the join filter, // we need to join it with nulls as buffered side. - if matches!( - self.join_type, - JoinType::Left | JoinType::Right | JoinType::Full - ) { + if matches!(self.join_type, JoinType::Right | JoinType::Full) { // We need to get the mask for row indices that the joined rows are failed // on the join filter. I.e., for a row in streamed side, if all joined rows // between it and all buffered rows are failed on the join filter, we need to // output it with null columns from buffered side. For the mask here, it // behaves like LeftAnti join. - let null_mask: BooleanArray = get_filtered_join_mask( - // Set a mask slot as true only if all joined rows of same streamed index - // are failed on the join filter. - // The masking behavior is like LeftAnti join. - JoinType::LeftAnti, - &streamed_indices, - mask, - &self.streamed_batch.join_filter_matched_idxs, - &self.buffered_data.scanning_offset, - ) - .unwrap() - .0; + let not_mask = if mask.null_count() > 0 { + // If the mask contains nulls, we need to use `prep_null_mask_filter` to + // handle the nulls in the mask as false to produce rows where the mask + // was null itself. + compute::not(&compute::prep_null_mask_filter(&mask))? + } else { + compute::not(&mask)? + }; let null_joined_batch = - compute::filter_record_batch(&output_batch, &null_mask)?; + filter_record_batch(&output_batch, ¬_mask)?; let mut buffered_columns = self .buffered_schema @@ -1457,11 +1588,11 @@ impl SMJStream { }; // Push the streamed/buffered batch joined nulls to the output - let null_joined_streamed_batch = RecordBatch::try_new( - Arc::clone(&self.schema), - columns.clone(), - )?; - self.output_record_batches.push(null_joined_streamed_batch); + let null_joined_streamed_batch = + RecordBatch::try_new(Arc::clone(&self.schema), columns)?; + self.output_record_batches + .batches + .push(null_joined_streamed_batch); // For full join, we also need to output the null joined rows from the buffered side. // Usually this is done by `freeze_buffered`. However, if a buffered row is joined with @@ -1494,10 +1625,10 @@ impl SMJStream { } } } else { - self.output_record_batches.push(output_batch); + self.output_record_batches.batches.push(output_batch); } } else { - self.output_record_batches.push(output_batch); + self.output_record_batches.batches.push(output_batch); } } @@ -1507,7 +1638,8 @@ impl SMJStream { } fn output_record_batch_and_reset(&mut self) -> Result { - let record_batch = concat_batches(&self.schema, &self.output_record_batches)?; + let record_batch = + concat_batches(&self.schema, &self.output_record_batches.batches)?; self.join_metrics.output_batches.add(1); self.join_metrics.output_rows.add(record_batch.num_rows()); // If join filter exists, `self.output_size` is not accurate as we don't know the exact @@ -1520,9 +1652,92 @@ impl SMJStream { } else { self.output_size -= record_batch.num_rows(); } - self.output_record_batches.clear(); + + if !(self.filter.is_some() + && matches!(self.join_type, JoinType::Left | JoinType::LeftSemi)) + { + self.output_record_batches.batches.clear(); + } Ok(record_batch) } + + fn filter_joined_batch(&mut self) -> Result { + let record_batch = self.output_record_batch_and_reset()?; + let out_indices = self.output_record_batches.row_indices.finish(); + let out_mask = self.output_record_batches.filter_mask.finish(); + let maybe_corrected_mask = get_corrected_filter_mask( + self.join_type, + &out_indices, + &self.output_record_batches.batch_ids, + &out_mask, + record_batch.num_rows(), + ); + + let corrected_mask = if let Some(ref filtered_join_mask) = maybe_corrected_mask { + filtered_join_mask + } else { + &out_mask + }; + + let mut filtered_record_batch = + filter_record_batch(&record_batch, corrected_mask)?; + let buffered_columns_length = self.buffered_schema.fields.len(); + let streamed_columns_length = self.streamed_schema.fields.len(); + + if matches!(self.join_type, JoinType::Left | JoinType::Right) { + let null_mask = compute::not(corrected_mask)?; + let null_joined_batch = filter_record_batch(&record_batch, &null_mask)?; + + let mut buffered_columns = self + .buffered_schema + .fields() + .iter() + .map(|f| new_null_array(f.data_type(), null_joined_batch.num_rows())) + .collect::>(); + + let columns = if matches!(self.join_type, JoinType::Right) { + let streamed_columns = null_joined_batch + .columns() + .iter() + .skip(buffered_columns_length) + .cloned() + .collect::>(); + + buffered_columns.extend(streamed_columns); + buffered_columns + } else { + // Left join or full outer join + let mut streamed_columns = null_joined_batch + .columns() + .iter() + .take(streamed_columns_length) + .cloned() + .collect::>(); + + streamed_columns.extend(buffered_columns); + streamed_columns + }; + + // Push the streamed/buffered batch joined nulls to the output + let null_joined_streamed_batch = + RecordBatch::try_new(Arc::clone(&self.schema), columns)?; + + filtered_record_batch = concat_batches( + &self.schema, + &[filtered_record_batch, null_joined_streamed_batch], + )?; + } else if matches!(self.join_type, JoinType::LeftSemi) { + let output_column_indices = (0..streamed_columns_length).collect::>(); + filtered_record_batch = + filtered_record_batch.project(&output_column_indices)?; + } + + self.output_record_batches.batches.clear(); + self.output_record_batches.batch_ids = vec![]; + self.output_record_batches.filter_mask = BooleanBuilder::new(); + self.output_record_batches.row_indices = UInt64Builder::new(); + Ok(filtered_record_batch) + } } /// Gets the arrays which join filters are applied on. @@ -1631,101 +1846,6 @@ fn get_buffered_columns_from_batch( } } -/// Calculate join filter bit mask considering join type specifics -/// `streamed_indices` - array of streamed datasource JOINED row indices -/// `mask` - array booleans representing computed join filter expression eval result: -/// true = the row index matches the join filter -/// false = the row index doesn't match the join filter -/// `streamed_indices` have the same length as `mask` -/// `matched_indices` array of streaming indices that already has a join filter match -/// `scanning_buffered_offset` current buffered offset across batches -/// -/// This return a tuple of: -/// - corrected mask with respect to the join type -/// - indices of rows in streamed batch that have a join filter match -fn get_filtered_join_mask( - join_type: JoinType, - streamed_indices: &UInt64Array, - mask: &BooleanArray, - matched_indices: &HashSet, - scanning_buffered_offset: &usize, -) -> Option<(BooleanArray, Vec)> { - let mut seen_as_true: bool = false; - let streamed_indices_length = streamed_indices.len(); - let mut corrected_mask: BooleanBuilder = - BooleanBuilder::with_capacity(streamed_indices_length); - - let mut filter_matched_indices: Vec = vec![]; - - #[allow(clippy::needless_range_loop)] - match join_type { - // for LeftSemi Join the filter mask should be calculated in its own way: - // if we find at least one matching row for specific streaming index - // we don't need to check any others for the same index - JoinType::LeftSemi => { - // have we seen a filter match for a streaming index before - for i in 0..streamed_indices_length { - // LeftSemi respects only first true values for specific streaming index, - // others true values for the same index must be false - let streamed_idx = streamed_indices.value(i); - if mask.value(i) - && !seen_as_true - && !matched_indices.contains(&streamed_idx) - { - seen_as_true = true; - corrected_mask.append_value(true); - filter_matched_indices.push(streamed_idx); - } else { - corrected_mask.append_value(false); - } - - // if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag - if i < streamed_indices_length - 1 - && streamed_idx != streamed_indices.value(i + 1) - { - seen_as_true = false; - } - } - Some((corrected_mask.finish(), filter_matched_indices)) - } - // LeftAnti semantics: return true if for every x in the collection the join matching filter is false. - // `filter_matched_indices` needs to be set once per streaming index - // to prevent duplicates in the output - JoinType::LeftAnti => { - // have we seen a filter match for a streaming index before - for i in 0..streamed_indices_length { - let streamed_idx = streamed_indices.value(i); - if mask.value(i) - && !seen_as_true - && !matched_indices.contains(&streamed_idx) - { - seen_as_true = true; - filter_matched_indices.push(streamed_idx); - } - - // Reset `seen_as_true` flag and calculate mask for the current streaming index - // - if within the batch it switched to next streaming index(e.g. from 0 to 1, or from 1 to 2) - // - if it is at the end of the all buffered batches for the given streaming index, 0 index comes last - if (i < streamed_indices_length - 1 - && streamed_idx != streamed_indices.value(i + 1)) - || (i == streamed_indices_length - 1 - && *scanning_buffered_offset == 0) - { - corrected_mask.append_value( - !matched_indices.contains(&streamed_idx) && !seen_as_true, - ); - seen_as_true = false; - } else { - corrected_mask.append_value(false); - } - } - - Some((corrected_mask.finish(), filter_matched_indices)) - } - _ => None, - } -} - /// Buffered data contains all buffered batches with one unique join key #[derive(Debug, Default)] struct BufferedData { @@ -1966,13 +2086,13 @@ mod tests { use std::sync::Arc; use arrow::array::{Date32Array, Date64Array, Int32Array}; - use arrow::compute::SortOptions; + use arrow::compute::{concat_batches, filter_record_batch, SortOptions}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; + use arrow_array::builder::{BooleanBuilder, UInt64Builder}; use arrow_array::{BooleanArray, UInt64Array}; - use hashbrown::HashSet; - use datafusion_common::JoinType::{LeftAnti, LeftSemi}; + use datafusion_common::JoinType::*; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result, }; @@ -1982,7 +2102,7 @@ mod tests { use datafusion_execution::TaskContext; use crate::expressions::Column; - use crate::joins::sort_merge_join::get_filtered_join_mask; + use crate::joins::sort_merge_join::{get_corrected_filter_mask, JoinedRecordBatches}; use crate::joins::utils::JoinOn; use crate::joins::SortMergeJoinExec; use crate::memory::MemoryExec; @@ -3214,170 +3334,573 @@ mod tests { } #[tokio::test] - async fn left_semi_join_filtered_mask() -> Result<()> { + async fn test_left_outer_join_filtered_mask() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("x", DataType::Int32, true), + Field::new("y", DataType::Int32, true), + ])); + + let mut tb = JoinedRecordBatches { + batches: vec![], + filter_mask: BooleanBuilder::new(), + row_indices: UInt64Builder::new(), + batch_ids: vec![], + }; + + tb.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![10, 10])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![11, 9])), + ], + )?); + + tb.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![11])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![12])), + ], + )?); + + tb.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![12, 12])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![11, 13])), + ], + )?); + + tb.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![13])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![12])), + ], + )?); + + tb.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![14, 14])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![12, 11])), + ], + )?); + + let streamed_indices = vec![0, 0]; + tb.batch_ids.extend(vec![0; streamed_indices.len()]); + tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![1]; + tb.batch_ids.extend(vec![0; streamed_indices.len()]); + tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0, 0]; + tb.batch_ids.extend(vec![1; streamed_indices.len()]); + tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0]; + tb.batch_ids.extend(vec![2; streamed_indices.len()]); + tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0, 0]; + tb.batch_ids.extend(vec![3; streamed_indices.len()]); + tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + + tb.filter_mask + .extend(&BooleanArray::from(vec![true, false])); + tb.filter_mask.extend(&BooleanArray::from(vec![true])); + tb.filter_mask + .extend(&BooleanArray::from(vec![false, true])); + tb.filter_mask.extend(&BooleanArray::from(vec![false])); + tb.filter_mask + .extend(&BooleanArray::from(vec![false, false])); + + let output = concat_batches(&schema, &tb.batches)?; + let out_mask = tb.filter_mask.finish(); + let out_indices = tb.row_indices.finish(); + assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 1, 1]), - &BooleanArray::from(vec![true, true, false, false]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![true, false, false, false]), vec![0])) + get_corrected_filter_mask( + JoinType::Left, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + true, false, false, false, false, false, false, false + ]) ); assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 1]), + get_corrected_filter_mask( + JoinType::Left, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + false, false, false, false, false, false, false, false + ]) + ); + + assert_eq!( + get_corrected_filter_mask( + JoinType::Left, + &UInt64Array::from(vec![0, 0]), + &[0usize; 2], &BooleanArray::from(vec![true, true]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![true, true]), vec![0, 1])) + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + true, true, false, false, false, false, false, false + ]) ); assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![false, true]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![false, true]), vec![1])) + get_corrected_filter_mask( + JoinType::Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![true, true, true, false, false, false, false, false]) ); assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![true, false]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![true, false]), vec![0])) + get_corrected_filter_mask( + JoinType::Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + Some(true), + None, + Some(true), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) ); assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![false, true, true, true, true, true]), - &HashSet::new(), - &0, - ), - Some(( - BooleanArray::from(vec![false, true, false, true, false, false]), - vec![0, 1] - )) + get_corrected_filter_mask( + JoinType::Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + None, + None, + Some(true), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) ); assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![false, false, false, false, false, true]), - &HashSet::new(), - &0, - ), - Some(( - BooleanArray::from(vec![false, false, false, false, false, true]), - vec![1] - )) + get_corrected_filter_mask( + JoinType::Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + None, + Some(true), + Some(true), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) ); assert_eq!( - get_filtered_join_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![true, false, false, false, false, true]), - &HashSet::from_iter(vec![1]), - &0, - ), - Some(( - BooleanArray::from(vec![true, false, false, false, false, false]), - vec![0] - )) + get_corrected_filter_mask( + JoinType::Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + None, + None, + Some(false), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) + ); + + let corrected_mask = get_corrected_filter_mask( + JoinType::Left, + &out_indices, + &tb.batch_ids, + &out_mask, + output.num_rows(), + ) + .unwrap(); + + assert_eq!( + corrected_mask, + BooleanArray::from(vec![ + Some(true), + None, + Some(true), + None, + Some(true), + Some(false), + None, + Some(false) + ]) + ); + + let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + + assert_batches_eq!( + &[ + "+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 10 | 1 | 11 |", + "| 1 | 11 | 1 | 12 |", + "| 1 | 12 | 1 | 13 |", + "+---+----+---+----+", + ], + &[filtered_rb] ); + // output null rows + + let null_mask = arrow::compute::not(&corrected_mask)?; + assert_eq!( + null_mask, + BooleanArray::from(vec![ + Some(false), + None, + Some(false), + None, + Some(false), + Some(true), + None, + Some(true) + ]) + ); + + let null_joined_batch = filter_record_batch(&output, &null_mask)?; + + assert_batches_eq!( + &[ + "+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 13 | 1 | 12 |", + "| 1 | 14 | 1 | 11 |", + "+---+----+---+----+", + ], + &[null_joined_batch] + ); Ok(()) } #[tokio::test] - async fn left_anti_join_filtered_mask() -> Result<()> { + async fn test_left_semi_join_filtered_mask() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("x", DataType::Int32, true), + Field::new("y", DataType::Int32, true), + ])); + + let mut tb = JoinedRecordBatches { + batches: vec![], + filter_mask: BooleanBuilder::new(), + row_indices: UInt64Builder::new(), + batch_ids: vec![], + }; + + tb.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![10, 10])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![11, 9])), + ], + )?); + + tb.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![11])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![12])), + ], + )?); + + tb.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![12, 12])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![11, 13])), + ], + )?); + + tb.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![13])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![12])), + ], + )?); + + tb.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![14, 14])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![12, 11])), + ], + )?); + + let streamed_indices = vec![0, 0]; + tb.batch_ids.extend(vec![0; streamed_indices.len()]); + tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![1]; + tb.batch_ids.extend(vec![0; streamed_indices.len()]); + tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0, 0]; + tb.batch_ids.extend(vec![1; streamed_indices.len()]); + tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0]; + tb.batch_ids.extend(vec![2; streamed_indices.len()]); + tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0, 0]; + tb.batch_ids.extend(vec![3; streamed_indices.len()]); + tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + + tb.filter_mask + .extend(&BooleanArray::from(vec![true, false])); + tb.filter_mask.extend(&BooleanArray::from(vec![true])); + tb.filter_mask + .extend(&BooleanArray::from(vec![false, true])); + tb.filter_mask.extend(&BooleanArray::from(vec![false])); + tb.filter_mask + .extend(&BooleanArray::from(vec![false, false])); + + let output = concat_batches(&schema, &tb.batches)?; + let out_mask = tb.filter_mask.finish(); + let out_indices = tb.row_indices.finish(); + assert_eq!( - get_filtered_join_mask( - LeftAnti, - &UInt64Array::from(vec![0, 0, 1, 1]), - &BooleanArray::from(vec![true, true, false, false]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![false, false, false, true]), vec![0])) + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![true]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None]) ); assert_eq!( - get_filtered_join_mask( - LeftAnti, - &UInt64Array::from(vec![0, 1]), + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0]), + &[0usize; 2], &BooleanArray::from(vec![true, true]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![false, false]), vec![0, 1])) + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None]) ); assert_eq!( - get_filtered_join_mask( - LeftAnti, - &UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![false, true]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![true, false]), vec![1])) + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None, None]) ); assert_eq!( - get_filtered_join_mask( - LeftAnti, - &UInt64Array::from(vec![0, 1]), - &BooleanArray::from(vec![true, false]), - &HashSet::new(), - &0, - ), - Some((BooleanArray::from(vec![false, true]), vec![0])) + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None, None]) ); assert_eq!( - get_filtered_join_mask( - LeftAnti, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![false, true, true, true, true, true]), - &HashSet::new(), - &0, - ), - Some(( - BooleanArray::from(vec![false, false, false, false, false, false]), - vec![0, 1] - )) + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, None, Some(true),]) ); assert_eq!( - get_filtered_join_mask( - LeftAnti, - &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), - &BooleanArray::from(vec![false, false, false, false, false, true]), - &HashSet::new(), - &0, - ), - Some(( - BooleanArray::from(vec![false, false, true, false, false, false]), - vec![1] - )) + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, Some(true), None]) ); + assert_eq!( + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + let corrected_mask = get_corrected_filter_mask( + LeftSemi, + &out_indices, + &tb.batch_ids, + &out_mask, + output.num_rows(), + ) + .unwrap(); + + assert_eq!( + corrected_mask, + BooleanArray::from(vec![ + Some(true), + None, + Some(true), + None, + Some(true), + None, + None, + None + ]) + ); + + let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + + assert_batches_eq!( + &[ + "+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 10 | 1 | 11 |", + "| 1 | 11 | 1 | 12 |", + "| 1 | 12 | 1 | 13 |", + "+---+----+---+----+", + ], + &[filtered_rb] + ); + + // output null rows + let null_mask = arrow::compute::not(&corrected_mask)?; + assert_eq!( + null_mask, + BooleanArray::from(vec![ + Some(false), + None, + Some(false), + None, + Some(false), + None, + None, + None + ]) + ); + + let null_joined_batch = filter_record_batch(&output, &null_mask)?; + + assert_batches_eq!( + &[ + "+---+---+---+---+", + "| a | b | x | y |", + "+---+---+---+---+", + "+---+---+---+---+", + ], + &[null_joined_batch] + ); Ok(()) } diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index ebd53e9690fc..d00b7d6f0a52 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -100,13 +100,14 @@ Alice 100 Alice 2 Alice 50 Alice 1 Alice 50 Alice 2 +# Uncomment when filtered RIGHT moved # right join with join filter -query TITI rowsort -SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b ----- -Alice 100 Alice 1 -Alice 100 Alice 2 -Alice 50 Alice 1 +#query TITI rowsort +#SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b +#---- +#Alice 100 Alice 1 +#Alice 100 Alice 2 +#Alice 50 Alice 1 query TITI rowsort SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t1.b > t2.b @@ -126,22 +127,24 @@ Alice 50 Alice 1 Alice 50 Alice 2 Bob 1 NULL NULL +# Uncomment when filtered FULL moved # full join with join filter -query TITI rowsort -SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t2.b * 50 > t1.b ----- -Alice 100 NULL NULL -Alice 50 Alice 2 -Bob 1 NULL NULL -NULL NULL Alice 1 - -query TITI rowsort -SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t1.b > t2.b + 50 ----- -Alice 100 Alice 1 -Alice 100 Alice 2 -Alice 50 NULL NULL -Bob 1 NULL NULL +#query TITI rowsort +#SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t2.b * 50 > t1.b +#---- +#Alice 100 NULL NULL +#Alice 50 Alice 2 +#Bob 1 NULL NULL +#NULL NULL Alice 1 + +# Uncomment when filtered RIGHT moved +#query TITI rowsort +#SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t1.b > t2.b + 50 +#---- +#Alice 100 Alice 1 +#Alice 100 Alice 2 +#Alice 50 NULL NULL +#Bob 1 NULL NULL statement ok DROP TABLE t1; @@ -405,221 +408,236 @@ select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != statement ok set datafusion.execution.batch_size = 10; -query II -select * from ( -with -t1 as ( - select 11 a, 12 b), -t2 as ( - select 11 a, 13 c union all - select 11 a, 14 c - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -) order by 1, 2 ----- -11 12 - -query III -select * from ( -with -t1 as ( - select 11 a, 12 b, 1 c union all - select 11 a, 13 b, 2 c), -t2 as ( - select 11 a, 12 b, 3 c union all - select 11 a, 14 b, 4 c - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) -) order by 1, 2; ----- -11 12 1 -11 13 2 - -query III -select * from ( -with -t1 as ( - select 11 a, 12 b, 1 c union all - select 11 a, 13 b, 2 c), -t2 as ( - select 11 a, 12 b, 3 c where false - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) -) order by 1, 2; ----- -11 12 1 -11 13 2 - -query II -select * from ( -with -t1 as ( - select 11 a, 12 b), -t2 as ( - select 11 a, 13 c union all - select 11 a, 14 c union all - select 11 a, 15 c - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -) order by 1, 2 ----- -11 12 - -query II -select * from ( -with -t1 as ( - select 11 a, 12 b), -t2 as ( - select 11 a, 11 c union all - select 11 a, 14 c union all - select 11 a, 15 c - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -) order by 1, 2 ----- - -query II -select * from ( -with -t1 as ( - select 11 a, 12 b), -t2 as ( - select 11 a, 12 c union all - select 11 a, 11 c union all - select 11 a, 15 c - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -) order by 1, 2 ----- - -query II -select * from ( -with -t1 as ( - select 11 a, 12 b), -t2 as ( - select 11 a, 12 c union all - select 11 a, 14 c union all - select 11 a, 11 c - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -) order by 1, 2 ----- +# Uncomment when filtered LEFTANTI moved +#query II +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b), +#t2 as ( +# select 11 a, 13 c union all +# select 11 a, 14 c +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +#) order by 1, 2 +#---- +#11 12 + +# Uncomment when filtered LEFTANTI moved +#query III +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b, 1 c union all +# select 11 a, 13 b, 2 c), +#t2 as ( +# select 11 a, 12 b, 3 c union all +# select 11 a, 14 b, 4 c +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +#) order by 1, 2; +#---- +#11 12 1 +#11 13 2 + +# Uncomment when filtered LEFTANTI moved +#query III +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b, 1 c union all +# select 11 a, 13 b, 2 c), +#t2 as ( +# select 11 a, 12 b, 3 c where false +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +#) order by 1, 2; +#---- +#11 12 1 +#11 13 2 + +# Uncomment when filtered LEFTANTI moved +#query II +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b), +#t2 as ( +# select 11 a, 13 c union all +# select 11 a, 14 c union all +# select 11 a, 15 c +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +#) order by 1, 2 +#---- +#11 12 + +# Uncomment when filtered LEFTANTI moved +#query II +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b), +#t2 as ( +# select 11 a, 11 c union all +# select 11 a, 14 c union all +# select 11 a, 15 c +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +#) order by 1, 2 +#---- + +# Uncomment when filtered LEFTANTI moved +#query II +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b), +#t2 as ( +# select 11 a, 12 c union all +# select 11 a, 11 c union all +# select 11 a, 15 c +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +#) order by 1, 2 +#---- + + +# Uncomment when filtered LEFTANTI moved +#query II +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b), +#t2 as ( +# select 11 a, 12 c union all +# select 11 a, 14 c union all +# select 11 a, 11 c +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +#) order by 1, 2 +#---- # Test LEFT ANTI with cross batch data distribution statement ok set datafusion.execution.batch_size = 1; -query II -select * from ( -with -t1 as ( - select 11 a, 12 b), -t2 as ( - select 11 a, 13 c union all - select 11 a, 14 c - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -) order by 1, 2 ----- -11 12 - -query III -select * from ( -with -t1 as ( - select 11 a, 12 b, 1 c union all - select 11 a, 13 b, 2 c), -t2 as ( - select 11 a, 12 b, 3 c union all - select 11 a, 14 b, 4 c - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) -) order by 1, 2; ----- -11 12 1 -11 13 2 - -query III -select * from ( -with -t1 as ( - select 11 a, 12 b, 1 c union all - select 11 a, 13 b, 2 c), -t2 as ( - select 11 a, 12 b, 3 c where false - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) -) order by 1, 2; ----- -11 12 1 -11 13 2 - -query II -select * from ( -with -t1 as ( - select 11 a, 12 b), -t2 as ( - select 11 a, 13 c union all - select 11 a, 14 c union all - select 11 a, 15 c - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -) order by 1, 2 ----- -11 12 - -query II -select * from ( -with -t1 as ( - select 11 a, 12 b), -t2 as ( - select 11 a, 12 c union all - select 11 a, 11 c union all - select 11 a, 15 c - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -) order by 1, 2 ----- - -query II -select * from ( -with -t1 as ( - select 11 a, 12 b), -t2 as ( - select 11 a, 12 c union all - select 11 a, 14 c union all - select 11 a, 11 c - ) -select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -) order by 1, 2 ----- - -query IIII -select * from ( -with t as ( - select id, id % 5 id1 from (select unnest(range(0,10)) id) -), t1 as ( - select id % 10 id, id + 2 id1 from (select unnest(range(0,10)) id) -) -select * from t right join t1 on t.id1 = t1.id and t.id > t1.id1 -) order by 1, 2, 3, 4 ----- -5 0 0 2 -6 1 1 3 -7 2 2 4 -8 3 3 5 -9 4 4 6 -NULL NULL 5 7 -NULL NULL 6 8 -NULL NULL 7 9 -NULL NULL 8 10 -NULL NULL 9 11 +# Uncomment when filtered LEFTANTI moved +#query II +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b), +#t2 as ( +# select 11 a, 13 c union all +# select 11 a, 14 c +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +#) order by 1, 2 +#---- +#11 12 + +# Uncomment when filtered LEFTANTI moved +#query III +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b, 1 c union all +# select 11 a, 13 b, 2 c), +#t2 as ( +# select 11 a, 12 b, 3 c union all +# select 11 a, 14 b, 4 c +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +#) order by 1, 2; +#---- +#11 12 1 +#11 13 2 + +# Uncomment when filtered LEFTANTI moved +#query III +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b, 1 c union all +# select 11 a, 13 b, 2 c), +#t2 as ( +# select 11 a, 12 b, 3 c where false +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +#) order by 1, 2; +#---- +#11 12 1 +#11 13 2 + +# Uncomment when filtered LEFTANTI moved +#query II +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b), +#t2 as ( +# select 11 a, 13 c union all +# select 11 a, 14 c union all +# select 11 a, 15 c +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +#) order by 1, 2 +#---- +#11 12 + +# Uncomment when filtered LEFTANTI moved +#query II +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b), +#t2 as ( +# select 11 a, 12 c union all +# select 11 a, 11 c union all +# select 11 a, 15 c +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +#) order by 1, 2 +#---- + +# Uncomment when filtered LEFTANTI moved +#query II +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b), +#t2 as ( +# select 11 a, 12 c union all +# select 11 a, 14 c union all +# select 11 a, 11 c +# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +#) order by 1, 2 +#---- + +# Uncomment when filtered RIGHT moved +#query IIII +#select * from ( +#with t as ( +# select id, id % 5 id1 from (select unnest(range(0,10)) id) +#), t1 as ( +# select id % 10 id, id + 2 id1 from (select unnest(range(0,10)) id) +#) +#select * from t right join t1 on t.id1 = t1.id and t.id > t1.id1 +#) order by 1, 2, 3, 4 +#---- +#5 0 0 2 +#6 1 1 3 +#7 2 2 4 +#8 3 3 5 +#9 4 4 6 +#NULL NULL 5 7 +#NULL NULL 6 8 +#NULL NULL 7 9 +#NULL NULL 8 10 +#NULL NULL 9 11 query IIII select * from ( From 73ba4c45ff44e7c3c697aa8fea7bb019bb76711a Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Fri, 18 Oct 2024 16:19:48 -0400 Subject: [PATCH 021/110] feat: Add regexp_count function (#12970) * Implement regexp_ccount * Update document * fix check * add more tests * Update the world to 1.80 * Fix doc format * Add null tests * Add uft8 support and bench * Refactoring regexp_count * Refactoring regexp_count * Revert ci change * Fix ci * Updates for documentation, minor improvements. * Updates for documentation, minor improvements. * updates to fix scalar tests, doc updates. * updated regex and string features to remove deps on other features. --------- Co-authored-by: Xin Li --- datafusion/functions/Cargo.toml | 2 +- datafusion/functions/benches/regx.rs | 54 +- datafusion/functions/src/regex/mod.rs | 27 +- datafusion/functions/src/regex/regexpcount.rs | 951 ++++++++++++++++++ datafusion/sqllogictest/test_files/regexp.slt | 331 +++++- .../user-guide/sql/scalar_functions_new.md | 32 + 6 files changed, 1382 insertions(+), 15 deletions(-) create mode 100644 datafusion/functions/src/regex/regexpcount.rs diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 6099ad62c1d9..70a988dbfefb 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -54,7 +54,7 @@ math_expressions = [] # enable regular expressions regex_expressions = ["regex"] # enable string functions -string_expressions = ["regex_expressions", "uuid"] +string_expressions = ["uuid"] # enable unicode functions unicode_expressions = ["hashbrown", "unicode-segmentation"] diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs index c9a9c1dfb19e..468d3d548bcf 100644 --- a/datafusion/functions/benches/regx.rs +++ b/datafusion/functions/benches/regx.rs @@ -18,8 +18,11 @@ extern crate criterion; use arrow::array::builder::StringBuilder; -use arrow::array::{ArrayRef, AsArray, StringArray}; +use arrow::array::{ArrayRef, AsArray, Int64Array, StringArray}; +use arrow::compute::cast; +use arrow::datatypes::DataType; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_functions::regex::regexpcount::regexp_count_func; use datafusion_functions::regex::regexplike::regexp_like; use datafusion_functions::regex::regexpmatch::regexp_match; use datafusion_functions::regex::regexpreplace::regexp_replace; @@ -59,6 +62,15 @@ fn regex(rng: &mut ThreadRng) -> StringArray { StringArray::from(data) } +fn start(rng: &mut ThreadRng) -> Int64Array { + let mut data: Vec = vec![]; + for _ in 0..1000 { + data.push(rng.gen_range(1..5)); + } + + Int64Array::from(data) +} + fn flags(rng: &mut ThreadRng) -> StringArray { let samples = [Some("i".to_string()), Some("im".to_string()), None]; let mut sb = StringBuilder::new(); @@ -75,6 +87,46 @@ fn flags(rng: &mut ThreadRng) -> StringArray { } fn criterion_benchmark(c: &mut Criterion) { + c.bench_function("regexp_count_1000 string", |b| { + let mut rng = rand::thread_rng(); + let data = Arc::new(data(&mut rng)) as ArrayRef; + let regex = Arc::new(regex(&mut rng)) as ArrayRef; + let start = Arc::new(start(&mut rng)) as ArrayRef; + let flags = Arc::new(flags(&mut rng)) as ArrayRef; + + b.iter(|| { + black_box( + regexp_count_func(&[ + Arc::clone(&data), + Arc::clone(®ex), + Arc::clone(&start), + Arc::clone(&flags), + ]) + .expect("regexp_count should work on utf8"), + ) + }) + }); + + c.bench_function("regexp_count_1000 utf8view", |b| { + let mut rng = rand::thread_rng(); + let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap(); + let regex = cast(®ex(&mut rng), &DataType::Utf8View).unwrap(); + let start = Arc::new(start(&mut rng)) as ArrayRef; + let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap(); + + b.iter(|| { + black_box( + regexp_count_func(&[ + Arc::clone(&data), + Arc::clone(®ex), + Arc::clone(&start), + Arc::clone(&flags), + ]) + .expect("regexp_count should work on utf8view"), + ) + }) + }); + c.bench_function("regexp_like_1000", |b| { let mut rng = rand::thread_rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs index cde777311aa1..803f51e915a9 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -19,11 +19,13 @@ use std::sync::Arc; +pub mod regexpcount; pub mod regexplike; pub mod regexpmatch; pub mod regexpreplace; // create UDFs +make_udf_function!(regexpcount::RegexpCountFunc, REGEXP_COUNT, regexp_count); make_udf_function!(regexpmatch::RegexpMatchFunc, REGEXP_MATCH, regexp_match); make_udf_function!(regexplike::RegexpLikeFunc, REGEXP_LIKE, regexp_like); make_udf_function!( @@ -35,6 +37,24 @@ make_udf_function!( pub mod expr_fn { use datafusion_expr::Expr; + /// Returns the number of consecutive occurrences of a regular expression in a string. + pub fn regexp_count( + values: Expr, + regex: Expr, + start: Option, + flags: Option, + ) -> Expr { + let mut args = vec![values, regex]; + if let Some(start) = start { + args.push(start); + }; + + if let Some(flags) = flags { + args.push(flags); + }; + super::regexp_count().call(args) + } + /// Returns a list of regular expression matches in a string. pub fn regexp_match(values: Expr, regex: Expr, flags: Option) -> Expr { let mut args = vec![values, regex]; @@ -70,5 +90,10 @@ pub mod expr_fn { /// Returns all DataFusion functions defined in this package pub fn functions() -> Vec> { - vec![regexp_match(), regexp_like(), regexp_replace()] + vec![ + regexp_count(), + regexp_match(), + regexp_like(), + regexp_replace(), + ] } diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs new file mode 100644 index 000000000000..880c91094555 --- /dev/null +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -0,0 +1,951 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::strings::StringArrayType; +use arrow::array::{Array, ArrayRef, AsArray, Datum, Int64Array}; +use arrow::datatypes::{DataType, Int64Type}; +use arrow::datatypes::{ + DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View, +}; +use arrow::error::ArrowError; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_REGEX; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::Exact, + TypeSignature::Uniform, Volatility, +}; +use itertools::izip; +use regex::Regex; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::sync::{Arc, OnceLock}; + +#[derive(Debug)] +pub struct RegexpCountFunc { + signature: Signature, +} + +impl Default for RegexpCountFunc { + fn default() -> Self { + Self::new() + } +} + +impl RegexpCountFunc { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + Uniform(2, vec![Utf8View, LargeUtf8, Utf8]), + Exact(vec![Utf8View, Utf8View, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![Utf8View, Utf8View, Int64, Utf8View]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, LargeUtf8]), + Exact(vec![Utf8, Utf8, Int64, Utf8]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RegexpCountFunc { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "regexp_count" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Int64) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let is_scalar = len.is_none(); + let inferred_length = len.unwrap_or(1); + let args = args + .iter() + .map(|arg| arg.clone().into_array(inferred_length)) + .collect::>>()?; + + let result = regexp_count_func(&args); + if is_scalar { + // If all inputs are scalar, keeps output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) + } + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_regexp_count_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_regexp_count_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_REGEX) + .with_description("Returns the number of matches that a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has in a string.") + .with_syntax_example("regexp_count(str, regexp[, start, flags])") + .with_sql_example(r#"```sql +> select regexp_count('abcAbAbc', 'abc', 2, 'i'); ++---------------------------------------------------------------+ +| regexp_count(Utf8("abcAbAbc"),Utf8("abc"),Int64(2),Utf8("i")) | ++---------------------------------------------------------------+ +| 1 | ++---------------------------------------------------------------+ +```"#) + .with_standard_argument("str", "String") + .with_standard_argument("regexp","Regular") + .with_argument("start", "- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function.") + .with_argument("flags", + r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*?"#) + .build() + .unwrap() + }) +} + +pub fn regexp_count_func(args: &[ArrayRef]) -> Result { + let args_len = args.len(); + if !(2..=4).contains(&args_len) { + return exec_err!("regexp_count was called with {args_len} arguments. It requires at least 2 and at most 4."); + } + + let values = &args[0]; + match values.data_type() { + Utf8 | LargeUtf8 | Utf8View => (), + other => { + return internal_err!( + "Unsupported data type {other:?} for function regexp_count" + ); + } + } + + regexp_count( + values, + &args[1], + if args_len > 2 { Some(&args[2]) } else { None }, + if args_len > 3 { Some(&args[3]) } else { None }, + ) + .map_err(|e| e.into()) +} + +/// `arrow-rs` style implementation of `regexp_count` function. +/// This function `regexp_count` is responsible for counting the occurrences of a regular expression pattern +/// within a string array. It supports optional start positions and flags for case insensitivity. +/// +/// The function accepts a variable number of arguments: +/// - `values`: The array of strings to search within. +/// - `regex_array`: The array of regular expression patterns to search for. +/// - `start_array` (optional): The array of start positions for the search. +/// - `flags_array` (optional): The array of flags to modify the search behavior (e.g., case insensitivity). +/// +/// The function handles different combinations of scalar and array inputs for the regex patterns, start positions, +/// and flags. It uses a cache to store compiled regular expressions for efficiency. +/// +/// # Errors +/// Returns an error if the input arrays have mismatched lengths or if the regular expression fails to compile. +pub fn regexp_count( + values: &dyn Array, + regex_array: &dyn Datum, + start_array: Option<&dyn Datum>, + flags_array: Option<&dyn Datum>, +) -> Result { + let (regex_array, is_regex_scalar) = regex_array.get(); + let (start_array, is_start_scalar) = start_array.map_or((None, true), |start| { + let (start, is_start_scalar) = start.get(); + (Some(start), is_start_scalar) + }); + let (flags_array, is_flags_scalar) = flags_array.map_or((None, true), |flags| { + let (flags, is_flags_scalar) = flags.get(); + (Some(flags), is_flags_scalar) + }); + + match (values.data_type(), regex_array.data_type(), flags_array) { + (Utf8, Utf8, None) => regexp_count_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + None, + is_flags_scalar, + ), + (Utf8, Utf8, Some(flags_array)) if *flags_array.data_type() == Utf8 => regexp_count_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + Some(flags_array.as_string::()), + is_flags_scalar, + ), + (LargeUtf8, LargeUtf8, None) => regexp_count_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + None, + is_flags_scalar, + ), + (LargeUtf8, LargeUtf8, Some(flags_array)) if *flags_array.data_type() == LargeUtf8 => regexp_count_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + Some(flags_array.as_string::()), + is_flags_scalar, + ), + (Utf8View, Utf8View, None) => regexp_count_inner( + values.as_string_view(), + regex_array.as_string_view(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + None, + is_flags_scalar, + ), + (Utf8View, Utf8View, Some(flags_array)) if *flags_array.data_type() == Utf8View => regexp_count_inner( + values.as_string_view(), + regex_array.as_string_view(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + Some(flags_array.as_string_view()), + is_flags_scalar, + ), + _ => Err(ArrowError::ComputeError( + "regexp_count() expected the input arrays to be of type Utf8, LargeUtf8, or Utf8View and the data types of the values, regex_array, and flags_array to match".to_string(), + )), + } +} + +pub fn regexp_count_inner<'a, S>( + values: S, + regex_array: S, + is_regex_scalar: bool, + start_array: Option<&Int64Array>, + is_start_scalar: bool, + flags_array: Option, + is_flags_scalar: bool, +) -> Result +where + S: StringArrayType<'a>, +{ + let (regex_scalar, is_regex_scalar) = if is_regex_scalar || regex_array.len() == 1 { + (Some(regex_array.value(0)), true) + } else { + (None, false) + }; + + let (start_array, start_scalar, is_start_scalar) = + if let Some(start_array) = start_array { + if is_start_scalar || start_array.len() == 1 { + (None, Some(start_array.value(0)), true) + } else { + (Some(start_array), None, false) + } + } else { + (None, Some(1), true) + }; + + let (flags_array, flags_scalar, is_flags_scalar) = + if let Some(flags_array) = flags_array { + if is_flags_scalar || flags_array.len() == 1 { + (None, Some(flags_array.value(0)), true) + } else { + (Some(flags_array), None, false) + } + } else { + (None, None, true) + }; + + let mut regex_cache = HashMap::new(); + + match (is_regex_scalar, is_start_scalar, is_flags_scalar) { + (true, true, true) => { + let regex = match regex_scalar { + None | Some("") => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + } + Some(regex) => regex, + }; + + let pattern = compile_regex(regex, flags_scalar)?; + + Ok(Arc::new(Int64Array::from_iter_values( + values + .iter() + .map(|value| count_matches(value, &pattern, start_scalar)) + .collect::, ArrowError>>()?, + ))) + } + (true, true, false) => { + let regex = match regex_scalar { + None | Some("") => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + } + Some(regex) => regex, + }; + + let flags_array = flags_array.unwrap(); + if values.len() != flags_array.len() { + return Err(ArrowError::ComputeError(format!( + "flags_array must be the same length as values array; got {} and {}", + flags_array.len(), + values.len(), + ))); + } + + Ok(Arc::new(Int64Array::from_iter_values( + values + .iter() + .zip(flags_array.iter()) + .map(|(value, flags)| { + let pattern = + compile_and_cache_regex(regex, flags, &mut regex_cache)?; + count_matches(value, &pattern, start_scalar) + }) + .collect::, ArrowError>>()?, + ))) + } + (true, false, true) => { + let regex = match regex_scalar { + None | Some("") => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + } + Some(regex) => regex, + }; + + let pattern = compile_regex(regex, flags_scalar)?; + + let start_array = start_array.unwrap(); + + Ok(Arc::new(Int64Array::from_iter_values( + values + .iter() + .zip(start_array.iter()) + .map(|(value, start)| count_matches(value, &pattern, start)) + .collect::, ArrowError>>()?, + ))) + } + (true, false, false) => { + let regex = match regex_scalar { + None | Some("") => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + } + Some(regex) => regex, + }; + + let flags_array = flags_array.unwrap(); + if values.len() != flags_array.len() { + return Err(ArrowError::ComputeError(format!( + "flags_array must be the same length as values array; got {} and {}", + flags_array.len(), + values.len(), + ))); + } + + Ok(Arc::new(Int64Array::from_iter_values( + izip!( + values.iter(), + start_array.unwrap().iter(), + flags_array.iter() + ) + .map(|(value, start, flags)| { + let pattern = + compile_and_cache_regex(regex, flags, &mut regex_cache)?; + + count_matches(value, &pattern, start) + }) + .collect::, ArrowError>>()?, + ))) + } + (false, true, true) => { + if values.len() != regex_array.len() { + return Err(ArrowError::ComputeError(format!( + "regex_array must be the same length as values array; got {} and {}", + regex_array.len(), + values.len(), + ))); + } + + Ok(Arc::new(Int64Array::from_iter_values( + values + .iter() + .zip(regex_array.iter()) + .map(|(value, regex)| { + let regex = match regex { + None | Some("") => return Ok(0), + Some(regex) => regex, + }; + + let pattern = compile_and_cache_regex( + regex, + flags_scalar, + &mut regex_cache, + )?; + count_matches(value, &pattern, start_scalar) + }) + .collect::, ArrowError>>()?, + ))) + } + (false, true, false) => { + if values.len() != regex_array.len() { + return Err(ArrowError::ComputeError(format!( + "regex_array must be the same length as values array; got {} and {}", + regex_array.len(), + values.len(), + ))); + } + + let flags_array = flags_array.unwrap(); + if values.len() != flags_array.len() { + return Err(ArrowError::ComputeError(format!( + "flags_array must be the same length as values array; got {} and {}", + flags_array.len(), + values.len(), + ))); + } + + Ok(Arc::new(Int64Array::from_iter_values( + izip!(values.iter(), regex_array.iter(), flags_array.iter()) + .map(|(value, regex, flags)| { + let regex = match regex { + None | Some("") => return Ok(0), + Some(regex) => regex, + }; + + let pattern = + compile_and_cache_regex(regex, flags, &mut regex_cache)?; + + count_matches(value, &pattern, start_scalar) + }) + .collect::, ArrowError>>()?, + ))) + } + (false, false, true) => { + if values.len() != regex_array.len() { + return Err(ArrowError::ComputeError(format!( + "regex_array must be the same length as values array; got {} and {}", + regex_array.len(), + values.len(), + ))); + } + + let start_array = start_array.unwrap(); + if values.len() != start_array.len() { + return Err(ArrowError::ComputeError(format!( + "start_array must be the same length as values array; got {} and {}", + start_array.len(), + values.len(), + ))); + } + + Ok(Arc::new(Int64Array::from_iter_values( + izip!(values.iter(), regex_array.iter(), start_array.iter()) + .map(|(value, regex, start)| { + let regex = match regex { + None | Some("") => return Ok(0), + Some(regex) => regex, + }; + + let pattern = compile_and_cache_regex( + regex, + flags_scalar, + &mut regex_cache, + )?; + count_matches(value, &pattern, start) + }) + .collect::, ArrowError>>()?, + ))) + } + (false, false, false) => { + if values.len() != regex_array.len() { + return Err(ArrowError::ComputeError(format!( + "regex_array must be the same length as values array; got {} and {}", + regex_array.len(), + values.len(), + ))); + } + + let start_array = start_array.unwrap(); + if values.len() != start_array.len() { + return Err(ArrowError::ComputeError(format!( + "start_array must be the same length as values array; got {} and {}", + start_array.len(), + values.len(), + ))); + } + + let flags_array = flags_array.unwrap(); + if values.len() != flags_array.len() { + return Err(ArrowError::ComputeError(format!( + "flags_array must be the same length as values array; got {} and {}", + flags_array.len(), + values.len(), + ))); + } + + Ok(Arc::new(Int64Array::from_iter_values( + izip!( + values.iter(), + regex_array.iter(), + start_array.iter(), + flags_array.iter() + ) + .map(|(value, regex, start, flags)| { + let regex = match regex { + None | Some("") => return Ok(0), + Some(regex) => regex, + }; + + let pattern = + compile_and_cache_regex(regex, flags, &mut regex_cache)?; + count_matches(value, &pattern, start) + }) + .collect::, ArrowError>>()?, + ))) + } + } +} + +fn compile_and_cache_regex( + regex: &str, + flags: Option<&str>, + regex_cache: &mut HashMap, +) -> Result { + match regex_cache.entry(regex.to_string()) { + Entry::Vacant(entry) => { + let compiled = compile_regex(regex, flags)?; + entry.insert(compiled.clone()); + Ok(compiled) + } + Entry::Occupied(entry) => Ok(entry.get().to_owned()), + } +} + +fn compile_regex(regex: &str, flags: Option<&str>) -> Result { + let pattern = match flags { + None | Some("") => regex.to_string(), + Some(flags) => { + if flags.contains("g") { + return Err(ArrowError::ComputeError( + "regexp_count() does not support global flag".to_string(), + )); + } + format!("(?{}){}", flags, regex) + } + }; + + Regex::new(&pattern).map_err(|_| { + ArrowError::ComputeError(format!( + "Regular expression did not compile: {}", + pattern + )) + }) +} + +fn count_matches( + value: Option<&str>, + pattern: &Regex, + start: Option, +) -> Result { + let value = match value { + None | Some("") => return Ok(0), + Some(value) => value, + }; + + if let Some(start) = start { + if start < 1 { + return Err(ArrowError::ComputeError( + "regexp_count() requires start to be 1 based".to_string(), + )); + } + + let find_slice = value.chars().skip(start as usize - 1).collect::(); + let count = pattern.find_iter(find_slice.as_str()).count(); + Ok(count as i64) + } else { + let count = pattern.find_iter(value).count(); + Ok(count as i64) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{GenericStringArray, StringViewArray}; + + #[test] + fn test_regexp_count() { + test_case_sensitive_regexp_count_scalar(); + test_case_sensitive_regexp_count_scalar_start(); + test_case_insensitive_regexp_count_scalar_flags(); + test_case_sensitive_regexp_count_start_scalar_complex(); + + test_case_sensitive_regexp_count_array::>(); + test_case_sensitive_regexp_count_array::>(); + test_case_sensitive_regexp_count_array::(); + + test_case_sensitive_regexp_count_array_start::>(); + test_case_sensitive_regexp_count_array_start::>(); + test_case_sensitive_regexp_count_array_start::(); + + test_case_insensitive_regexp_count_array_flags::>(); + test_case_insensitive_regexp_count_array_flags::>(); + test_case_insensitive_regexp_count_array_flags::(); + + test_case_sensitive_regexp_count_array_complex::>(); + test_case_sensitive_regexp_count_array_complex::>(); + test_case_sensitive_regexp_count_array_complex::(); + } + + fn test_case_sensitive_regexp_count_scalar() { + let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; + let regex = "abc"; + let expected: Vec = vec![0, 1, 2, 1, 3]; + + values.iter().enumerate().for_each(|(pos, &v)| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); + let expected = expected.get(pos).cloned(); + + let re = RegexpCountFunc::new() + .invoke(&[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); + + let re = RegexpCountFunc::new() + .invoke(&[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); + + let re = RegexpCountFunc::new() + .invoke(&[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_count_scalar_start() { + let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; + let regex = "abc"; + let start = 2; + let expected: Vec = vec![0, 1, 1, 0, 2]; + + values.iter().enumerate().for_each(|(pos, &v)| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); + let start_sv = ScalarValue::Int64(Some(start)); + let expected = expected.get(pos).cloned(); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_insensitive_regexp_count_scalar_flags() { + let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; + let regex = "abc"; + let start = 1; + let flags = "i"; + let expected: Vec = vec![0, 1, 2, 2, 3]; + + values.iter().enumerate().for_each(|(pos, &v)| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); + let start_sv = ScalarValue::Int64(Some(start)); + let flags_sv = ScalarValue::Utf8(Some(flags.to_string())); + let expected = expected.get(pos).cloned(); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ColumnarValue::Scalar(flags_sv.clone()), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); + let flags_sv = ScalarValue::LargeUtf8(Some(flags.to_string())); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ColumnarValue::Scalar(flags_sv.clone()), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); + let flags_sv = ScalarValue::Utf8View(Some(flags.to_string())); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv), + ColumnarValue::Scalar(flags_sv.clone()), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_count_array() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aabca", "abcabc", "abcAbcab", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); + + let expected = Int64Array::from(vec![0, 1, 2, 2, 2]); + + let re = regexp_count_func(&[Arc::new(values), Arc::new(regex)]).unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_count_array_start() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); + let start = Int64Array::from(vec![1, 2, 3, 4, 5]); + + let expected = Int64Array::from(vec![0, 0, 1, 1, 0]); + + let re = regexp_count_func(&[Arc::new(values), Arc::new(regex), Arc::new(start)]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_insensitive_regexp_count_array_flags() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); + let start = Int64Array::from(vec![1]); + let flags = A::from(vec!["", "i", "", "", "i"]); + + let expected = Int64Array::from(vec![0, 1, 2, 2, 3]); + + let re = regexp_count_func(&[ + Arc::new(values), + Arc::new(regex), + Arc::new(start), + Arc::new(flags), + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_count_start_scalar_complex() { + let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; + let regex = ["", "abc", "a", "bc", "ab"]; + let start = 5; + let flags = ["", "i", "", "", "i"]; + let expected: Vec = vec![0, 0, 0, 1, 1]; + + values.iter().enumerate().for_each(|(pos, &v)| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(regex.get(pos).map(|s| s.to_string())); + let start_sv = ScalarValue::Int64(Some(start)); + let flags_sv = ScalarValue::Utf8(flags.get(pos).map(|f| f.to_string())); + let expected = expected.get(pos).cloned(); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ColumnarValue::Scalar(flags_sv.clone()), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(regex.get(pos).map(|s| s.to_string())); + let flags_sv = ScalarValue::LargeUtf8(flags.get(pos).map(|f| f.to_string())); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ColumnarValue::Scalar(flags_sv.clone()), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(regex.get(pos).map(|s| s.to_string())); + let flags_sv = ScalarValue::Utf8View(flags.get(pos).map(|f| f.to_string())); + + let re = RegexpCountFunc::new().invoke(&[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv), + ColumnarValue::Scalar(flags_sv.clone()), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_count_array_complex() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); + let start = Int64Array::from(vec![1, 2, 3, 4, 5]); + let flags = A::from(vec!["", "i", "", "", "i"]); + + let expected = Int64Array::from(vec![0, 1, 1, 1, 1]); + + let re = regexp_count_func(&[ + Arc::new(values), + Arc::new(regex), + Arc::new(start), + Arc::new(flags), + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } +} diff --git a/datafusion/sqllogictest/test_files/regexp.slt b/datafusion/sqllogictest/test_files/regexp.slt index eedc3ddb6d59..800026dd766d 100644 --- a/datafusion/sqllogictest/test_files/regexp.slt +++ b/datafusion/sqllogictest/test_files/regexp.slt @@ -16,18 +16,18 @@ # under the License. statement ok -CREATE TABLE t (str varchar, pattern varchar, flags varchar) AS VALUES - ('abc', '^(a)', 'i'), - ('ABC', '^(A).*', 'i'), - ('aBc', '(b|d)', 'i'), - ('AbC', '(B|D)', null), - ('aBC', '^(b|c)', null), - ('4000', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', null), - ('4010', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', null), - ('Düsseldorf','[\p{Letter}-]+', null), - ('Москва', '[\p{L}-]+', null), - ('Köln', '[a-zA-Z]ö[a-zA-Z]{2}', null), - ('إسرائيل', '^\p{Arabic}+$', null); +CREATE TABLE t (str varchar, pattern varchar, start int, flags varchar) AS VALUES + ('abc', '^(a)', 1, 'i'), + ('ABC', '^(A).*', 1, 'i'), + ('aBc', '(b|d)', 1, 'i'), + ('AbC', '(B|D)', 2, null), + ('aBC', '^(b|c)', 3, null), + ('4000', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', 1, null), + ('4010', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', 2, null), + ('Düsseldorf','[\p{Letter}-]+', 3, null), + ('Москва', '[\p{L}-]+', 4, null), + ('Köln', '[a-zA-Z]ö[a-zA-Z]{2}', 1, null), + ('إسرائيل', '^\p{Arabic}+$', 2, null); # # regexp_like tests @@ -460,6 +460,313 @@ SELECT NULL not iLIKE NULL; ---- NULL +# regexp_count tests + +# regexp_count tests from postgresql +# https://github.com/postgres/postgres/blob/56d23855c864b7384970724f3ad93fb0fc319e51/src/test/regress/sql/strings.sql#L226-L235 + +query I +SELECT regexp_count('123123123123123', '(12)3'); +---- +5 + +query I +SELECT regexp_count('123123123123', '123', 1); +---- +4 + +query I +SELECT regexp_count('123123123123', '123', 3); +---- +3 + +query I +SELECT regexp_count('123123123123', '123', 33); +---- +0 + +query I +SELECT regexp_count('ABCABCABCABC', 'Abc', 1, ''); +---- +0 + +query I +SELECT regexp_count('ABCABCABCABC', 'Abc', 1, 'i'); +---- +4 + +statement error +External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based +SELECT regexp_count('123123123123', '123', 0); + +statement error +External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based +SELECT regexp_count('123123123123', '123', -3); + +statement error +External error: statement failed: DataFusion error: Arrow error: Compute error: regexp_count() does not support global flag +SELECT regexp_count('123123123123', '123', 1, 'g'); + +query I +SELECT regexp_count(str, '\w') from t; +---- +3 +3 +3 +3 +3 +4 +4 +10 +6 +4 +7 + +query I +SELECT regexp_count(str, '\w{2}', start) from t; +---- +1 +1 +1 +1 +0 +2 +1 +4 +1 +2 +3 + +query I +SELECT regexp_count(str, 'ab', 1, 'i') from t; +---- +1 +1 +1 +1 +1 +0 +0 +0 +0 +0 +0 + + +query I +SELECT regexp_count(str, pattern) from t; +---- +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start) from t; +---- +1 +1 +0 +0 +0 +0 +0 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start, flags) from t; +---- +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test type coercion +query I +SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t; +---- +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test string views + +statement ok +CREATE TABLE t_stringview AS +SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(start, 'Int64') as start, arrow_cast(flags, 'Utf8View') as flags FROM t; + +query I +SELECT regexp_count(str, '\w') from t; +---- +3 +3 +3 +3 +3 +4 +4 +10 +6 +4 +7 + +query I +SELECT regexp_count(str, '\w{2}', start) from t; +---- +1 +1 +1 +1 +0 +2 +1 +4 +1 +2 +3 + +query I +SELECT regexp_count(str, 'ab', 1, 'i') from t; +---- +1 +1 +1 +1 +1 +0 +0 +0 +0 +0 +0 + + +query I +SELECT regexp_count(str, pattern) from t; +---- +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start) from t; +---- +1 +1 +0 +0 +0 +0 +0 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start, flags) from t; +---- +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test type coercion +query I +SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t; +---- +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# NULL tests + +query I +SELECT regexp_count(NULL, NULL); +---- +0 + +query I +SELECT regexp_count(NULL, 'a'); +---- +0 + +query I +SELECT regexp_count('a', NULL); +---- +0 + +query I +SELECT regexp_count(NULL, NULL, NULL, NULL); +---- +0 + +statement ok +CREATE TABLE empty_table (str varchar, pattern varchar, start int, flags varchar); + +query I +SELECT regexp_count(str, pattern, start, flags) from empty_table; +---- + +statement ok +INSERT INTO empty_table VALUES ('a', NULL, 1, 'i'), (NULL, 'a', 1, 'i'), (NULL, NULL, 1, 'i'), (NULL, NULL, NULL, 'i'); + +query I +SELECT regexp_count(str, pattern, start, flags) from empty_table; +---- +0 +0 +0 +0 + statement ok drop table t; diff --git a/docs/source/user-guide/sql/scalar_functions_new.md b/docs/source/user-guide/sql/scalar_functions_new.md index ffc2b680b5c5..ca70c83e58f9 100644 --- a/docs/source/user-guide/sql/scalar_functions_new.md +++ b/docs/source/user-guide/sql/scalar_functions_new.md @@ -1676,10 +1676,42 @@ regular expression [syntax](https://docs.rs/regex/latest/regex/#syntax) (minus support for several features including look-around and backreferences). The following regular expression functions are supported: +- [regexp_count](#regexp_count) - [regexp_like](#regexp_like) - [regexp_match](#regexp_match) - [regexp_replace](#regexp_replace) +### `regexp_count` + +Returns the number of matches that a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has in a string. + +``` +regexp_count(str, regexp[, start, flags]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **regexp**: Regular expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **start**: - **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function. +- **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*? + +#### Example + +```sql +> select regexp_count('abcAbAbc', 'abc', 2, 'i'); ++---------------------------------------------------------------+ +| regexp_count(Utf8("abcAbAbc"),Utf8("abc"),Int64(2),Utf8("i")) | ++---------------------------------------------------------------+ +| 1 | ++---------------------------------------------------------------+ +``` + ### `regexp_like` Returns true if a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has at least one match in a string, false otherwise. From 8c9b9152c8201d8b75d8e0b9b85b85d3199c94d8 Mon Sep 17 00:00:00 2001 From: Jonathan Chen <86070045+jonathanc-n@users.noreply.github.com> Date: Fri, 18 Oct 2024 16:41:04 -0400 Subject: [PATCH 022/110] Minor: Fixed regexpr_match docs (#13008) * regexp_match * update generated docs --------- Co-authored-by: Andrew Lamb --- datafusion/functions/src/regex/regexpmatch.rs | 2 +- docs/source/user-guide/sql/scalar_functions_new.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/functions/src/regex/regexpmatch.rs b/datafusion/functions/src/regex/regexpmatch.rs index 443e50533268..4a86adbe683a 100644 --- a/datafusion/functions/src/regex/regexpmatch.rs +++ b/datafusion/functions/src/regex/regexpmatch.rs @@ -119,7 +119,7 @@ fn get_regexp_match_doc() -> &'static Documentation { DOCUMENTATION.get_or_init(|| { Documentation::builder() .with_doc_section(DOC_SECTION_REGEX) - .with_description("Returns a list of [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string.") + .with_description("Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matche in a string.") .with_syntax_example("regexp_match(str, regexp[, flags])") .with_sql_example(r#"```sql > select regexp_match('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); diff --git a/docs/source/user-guide/sql/scalar_functions_new.md b/docs/source/user-guide/sql/scalar_functions_new.md index ca70c83e58f9..1915623012f4 100644 --- a/docs/source/user-guide/sql/scalar_functions_new.md +++ b/docs/source/user-guide/sql/scalar_functions_new.md @@ -1752,7 +1752,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo ### `regexp_match` -Returns a list of [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string. +Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matche in a string. ``` regexp_match(str, regexp[, flags]) From 10af8a73662f4f6aac09a34157b7cf5fee034502 Mon Sep 17 00:00:00 2001 From: Albert Skalt <133099191+askalt@users.noreply.github.com> Date: Fri, 18 Oct 2024 23:41:53 +0300 Subject: [PATCH 023/110] Improve performance for physical plan creation with many columns (#12950) * Add a benchmark for physical plan creation with many aggregates * Wrap AggregateFunctionExpr with Arc Patch f5c47fa274d53c1d524a1fb788d9a063bf5240ef removed Arc wrappers for AggregateFunctionExpr. But, it can be inefficient. When physical optimizer decides to replace a node child to other, it clones the node (with `with_new_children`). Assume, that node is `AggregateExec` than contains hundreds aggregates and these aggregates are cloned each time. This patch returns a Arc wrapping to not clone AggregateFunctionExpr itself but clone a pointer. * Do not build mapping if parent does not require any This patch adds a small optimization that can soft the edges on some queries. If there are no parent requirements we do not need to build column mapping. --- datafusion/core/benches/sql_planner.rs | 14 +++ .../physical_optimizer/update_aggr_exprs.rs | 10 +- datafusion/core/src/physical_planner.rs | 5 +- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 1 + .../combine_partial_final_agg.rs | 8 +- .../limited_distinct_aggregation.rs | 16 +-- datafusion/physical-expr/src/aggregate.rs | 2 +- datafusion/physical-expr/src/utils/mod.rs | 4 + .../physical-expr/src/window/aggregate.rs | 8 +- .../src/window/sliding_aggregate.rs | 13 +- .../src/aggregate_statistics.rs | 24 ++-- .../src/combine_partial_final_agg.rs | 2 +- .../physical-plan/src/aggregates/mod.rs | 119 ++++++++++-------- .../physical-plan/src/aggregates/row_hash.rs | 4 +- datafusion/physical-plan/src/windows/mod.rs | 5 +- datafusion/proto/src/physical_plan/mod.rs | 3 +- .../proto/src/physical_plan/to_proto.rs | 2 +- .../tests/cases/roundtrip_physical_plan.rs | 44 ++++--- 18 files changed, 165 insertions(+), 119 deletions(-) diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index 00f6d5916751..e7c35c8d86d6 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -144,6 +144,20 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); + c.bench_function("physical_select_aggregates_from_200", |b| { + let mut aggregates = String::new(); + for i in 0..200 { + if i > 0 { + aggregates.push_str(", "); + } + aggregates.push_str(format!("MAX(a{})", i).as_str()); + } + let query = format!("SELECT {} FROM t1", aggregates); + b.iter(|| { + physical_plan(&ctx, &query); + }); + }); + // --- TPC-H --- let tpch_ctx = register_defs(SessionContext::new(), tpch_schemas()); diff --git a/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs b/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs index c0d9140c025e..26cdd65883e4 100644 --- a/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs +++ b/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs @@ -131,10 +131,10 @@ impl PhysicalOptimizerRule for OptimizeAggregateOrder { /// successfully. Any errors occurring during the conversion process are /// passed through. fn try_convert_aggregate_if_better( - aggr_exprs: Vec, + aggr_exprs: Vec>, prefix_requirement: &[PhysicalSortRequirement], eq_properties: &EquivalenceProperties, -) -> Result> { +) -> Result>> { aggr_exprs .into_iter() .map(|aggr_expr| { @@ -154,7 +154,7 @@ fn try_convert_aggregate_if_better( let reqs = concat_slices(prefix_requirement, &aggr_sort_reqs); if eq_properties.ordering_satisfy_requirement(&reqs) { // Existing ordering satisfies the aggregator requirements: - aggr_expr.with_beneficial_ordering(true)? + aggr_expr.with_beneficial_ordering(true)?.map(Arc::new) } else if eq_properties.ordering_satisfy_requirement(&concat_slices( prefix_requirement, &reverse_aggr_req, @@ -163,12 +163,14 @@ fn try_convert_aggregate_if_better( // given the existing ordering (if possible): aggr_expr .reverse_expr() + .map(Arc::new) .unwrap_or(aggr_expr) .with_beneficial_ordering(true)? + .map(Arc::new) } else { // There is no beneficial ordering present -- aggregation // will still work albeit in a less efficient mode. - aggr_expr.with_beneficial_ordering(false)? + aggr_expr.with_beneficial_ordering(false)?.map(Arc::new) } .ok_or_else(|| { plan_datafusion_err!( diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index cf2a157b04b6..a4dffd3d0208 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1523,7 +1523,7 @@ pub fn create_window_expr( } type AggregateExprWithOptionalArgs = ( - AggregateFunctionExpr, + Arc, // The filter clause, if any Option>, // Ordering requirements, if any @@ -1587,7 +1587,8 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( .alias(name) .with_ignore_nulls(ignore_nulls) .with_distinct(*distinct) - .build()?; + .build() + .map(Arc::new)?; (agg_expr, filter, physical_sort_exprs) }; diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 34061a64d783..ff512829333a 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -405,6 +405,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str .schema(Arc::clone(&schema)) .alias("sum1") .build() + .map(Arc::new) .unwrap(), ]; let expr = group_by_columns diff --git a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs index 24e46b3ad97c..85076abdaf29 100644 --- a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs @@ -84,7 +84,7 @@ fn parquet_exec(schema: &SchemaRef) -> Arc { fn partial_aggregate_exec( input: Arc, group_by: PhysicalGroupBy, - aggr_expr: Vec, + aggr_expr: Vec>, ) -> Arc { let schema = input.schema(); let n_aggr = aggr_expr.len(); @@ -104,7 +104,7 @@ fn partial_aggregate_exec( fn final_aggregate_exec( input: Arc, group_by: PhysicalGroupBy, - aggr_expr: Vec, + aggr_expr: Vec>, ) -> Arc { let schema = input.schema(); let n_aggr = aggr_expr.len(); @@ -130,11 +130,12 @@ fn count_expr( expr: Arc, name: &str, schema: &Schema, -) -> AggregateFunctionExpr { +) -> Arc { AggregateExprBuilder::new(count_udaf(), vec![expr]) .schema(Arc::new(schema.clone())) .alias(name) .build() + .map(Arc::new) .unwrap() } @@ -218,6 +219,7 @@ fn aggregations_with_group_combined() -> datafusion_common::Result<()> { .schema(Arc::clone(&schema)) .alias("Sum(b)") .build() + .map(Arc::new) .unwrap(), ]; let groups: Vec<(Arc, String)> = diff --git a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs index 042f6d622565..d6991711f581 100644 --- a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs @@ -347,10 +347,10 @@ fn test_has_aggregate_expression() -> Result<()> { let single_agg = AggregateExec::try_new( AggregateMode::Single, build_group_by(&schema, vec!["a".to_string()]), - vec![agg.count_expr(&schema)], /* aggr_expr */ - vec![None], /* filter_expr */ - source, /* input */ - schema.clone(), /* input_schema */ + vec![Arc::new(agg.count_expr(&schema))], /* aggr_expr */ + vec![None], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ )?; let limit_exec = LocalLimitExec::new( Arc::new(single_agg), @@ -384,10 +384,10 @@ fn test_has_filter() -> Result<()> { let single_agg = AggregateExec::try_new( AggregateMode::Single, build_group_by(&schema.clone(), vec!["a".to_string()]), - vec![agg.count_expr(&schema)], /* aggr_expr */ - vec![filter_expr], /* filter_expr */ - source, /* input */ - schema.clone(), /* input_schema */ + vec![Arc::new(agg.count_expr(&schema))], /* aggr_expr */ + vec![filter_expr], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ )?; let limit_exec = LocalLimitExec::new( Arc::new(single_agg), diff --git a/datafusion/physical-expr/src/aggregate.rs b/datafusion/physical-expr/src/aggregate.rs index 866596d0b690..6330c240241a 100644 --- a/datafusion/physical-expr/src/aggregate.rs +++ b/datafusion/physical-expr/src/aggregate.rs @@ -328,7 +328,7 @@ impl AggregateFunctionExpr { /// not implement the method, returns an error. Order insensitive and hard /// requirement aggregators return `Ok(None)`. pub fn with_beneficial_ordering( - self, + self: Arc, beneficial_ordering: bool, ) -> Result> { let Some(updated_fn) = self diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 4c37db4849a7..4bd022975ac3 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -86,6 +86,10 @@ pub fn map_columns_before_projection( parent_required: &[Arc], proj_exprs: &[(Arc, String)], ) -> Vec> { + if parent_required.is_empty() { + // No need to build mapping. + return vec![]; + } let column_mapping = proj_exprs .iter() .filter_map(|(expr, name)| { diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index d012fef93b67..3fe5d842dfd1 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -41,7 +41,7 @@ use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr}; /// See comments on [`WindowExpr`] for more details. #[derive(Debug)] pub struct PlainAggregateWindowExpr { - aggregate: AggregateFunctionExpr, + aggregate: Arc, partition_by: Vec>, order_by: Vec, window_frame: Arc, @@ -50,7 +50,7 @@ pub struct PlainAggregateWindowExpr { impl PlainAggregateWindowExpr { /// Create a new aggregate window function expression pub fn new( - aggregate: AggregateFunctionExpr, + aggregate: Arc, partition_by: &[Arc], order_by: &[PhysicalSortExpr], window_frame: Arc, @@ -137,14 +137,14 @@ impl WindowExpr for PlainAggregateWindowExpr { let reverse_window_frame = self.window_frame.reverse(); if reverse_window_frame.start_bound.is_unbounded() { Arc::new(PlainAggregateWindowExpr::new( - reverse_expr, + Arc::new(reverse_expr), &self.partition_by.clone(), &reverse_order_bys(&self.order_by), Arc::new(self.window_frame.reverse()), )) as _ } else { Arc::new(SlidingAggregateWindowExpr::new( - reverse_expr, + Arc::new(reverse_expr), &self.partition_by.clone(), &reverse_order_bys(&self.order_by), Arc::new(self.window_frame.reverse()), diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 143d59eb4495..b889ec8c5d98 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -41,7 +41,7 @@ use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr}; /// See comments on [`WindowExpr`] for more details. #[derive(Debug)] pub struct SlidingAggregateWindowExpr { - aggregate: AggregateFunctionExpr, + aggregate: Arc, partition_by: Vec>, order_by: Vec, window_frame: Arc, @@ -50,7 +50,7 @@ pub struct SlidingAggregateWindowExpr { impl SlidingAggregateWindowExpr { /// Create a new (sliding) aggregate window function expression. pub fn new( - aggregate: AggregateFunctionExpr, + aggregate: Arc, partition_by: &[Arc], order_by: &[PhysicalSortExpr], window_frame: Arc, @@ -121,14 +121,14 @@ impl WindowExpr for SlidingAggregateWindowExpr { let reverse_window_frame = self.window_frame.reverse(); if reverse_window_frame.start_bound.is_unbounded() { Arc::new(PlainAggregateWindowExpr::new( - reverse_expr, + Arc::new(reverse_expr), &self.partition_by.clone(), &reverse_order_bys(&self.order_by), Arc::new(self.window_frame.reverse()), )) as _ } else { Arc::new(SlidingAggregateWindowExpr::new( - reverse_expr, + Arc::new(reverse_expr), &self.partition_by.clone(), &reverse_order_bys(&self.order_by), Arc::new(self.window_frame.reverse()), @@ -159,7 +159,10 @@ impl WindowExpr for SlidingAggregateWindowExpr { }) .collect::>(); Some(Arc::new(SlidingAggregateWindowExpr { - aggregate: self.aggregate.with_new_expressions(args, vec![])?, + aggregate: self + .aggregate + .with_new_expressions(args, vec![]) + .map(Arc::new)?, partition_by: partition_bys, order_by: new_order_by, window_frame: Arc::clone(&self.window_frame), diff --git a/datafusion/physical-optimizer/src/aggregate_statistics.rs b/datafusion/physical-optimizer/src/aggregate_statistics.rs index fd21362fd3eb..27870c7865f3 100644 --- a/datafusion/physical-optimizer/src/aggregate_statistics.rs +++ b/datafusion/physical-optimizer/src/aggregate_statistics.rs @@ -312,7 +312,7 @@ mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], + vec![Arc::new(agg.count_expr(&schema))], vec![None], source, Arc::clone(&schema), @@ -321,7 +321,7 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], + vec![Arc::new(agg.count_expr(&schema))], vec![None], Arc::new(partial_agg), Arc::clone(&schema), @@ -342,7 +342,7 @@ mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], + vec![Arc::new(agg.count_expr(&schema))], vec![None], source, Arc::clone(&schema), @@ -351,7 +351,7 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], + vec![Arc::new(agg.count_expr(&schema))], vec![None], Arc::new(partial_agg), Arc::clone(&schema), @@ -371,7 +371,7 @@ mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], + vec![Arc::new(agg.count_expr(&schema))], vec![None], source, Arc::clone(&schema), @@ -383,7 +383,7 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], + vec![Arc::new(agg.count_expr(&schema))], vec![None], Arc::new(coalesce), Arc::clone(&schema), @@ -403,7 +403,7 @@ mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], + vec![Arc::new(agg.count_expr(&schema))], vec![None], source, Arc::clone(&schema), @@ -415,7 +415,7 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], + vec![Arc::new(agg.count_expr(&schema))], vec![None], Arc::new(coalesce), Arc::clone(&schema), @@ -446,7 +446,7 @@ mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], + vec![Arc::new(agg.count_expr(&schema))], vec![None], filter, Arc::clone(&schema), @@ -455,7 +455,7 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], + vec![Arc::new(agg.count_expr(&schema))], vec![None], Arc::new(partial_agg), Arc::clone(&schema), @@ -491,7 +491,7 @@ mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], + vec![Arc::new(agg.count_expr(&schema))], vec![None], filter, Arc::clone(&schema), @@ -500,7 +500,7 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], + vec![Arc::new(agg.count_expr(&schema))], vec![None], Arc::new(partial_agg), Arc::clone(&schema), diff --git a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs index 4e352e25b52c..86f7e73e9e35 100644 --- a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs +++ b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs @@ -125,7 +125,7 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { type GroupExprsRef<'a> = ( &'a PhysicalGroupBy, - &'a [AggregateFunctionExpr], + &'a [Arc], &'a [Option>], ); diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 296c5811e577..f36bd920e83c 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -351,7 +351,7 @@ pub struct AggregateExec { /// Group by expressions group_by: PhysicalGroupBy, /// Aggregate expressions - aggr_expr: Vec, + aggr_expr: Vec>, /// FILTER (WHERE clause) expression for each aggregate expression filter_expr: Vec>>, /// Set if the output of this aggregation is truncated by a upstream sort/limit clause @@ -378,7 +378,10 @@ impl AggregateExec { /// Function used in `OptimizeAggregateOrder` optimizer rule, /// where we need parts of the new value, others cloned from the old one /// Rewrites aggregate exec with new aggregate expressions. - pub fn with_new_aggr_exprs(&self, aggr_expr: Vec) -> Self { + pub fn with_new_aggr_exprs( + &self, + aggr_expr: Vec>, + ) -> Self { Self { aggr_expr, // clone the rest of the fields @@ -404,7 +407,7 @@ impl AggregateExec { pub fn try_new( mode: AggregateMode, group_by: PhysicalGroupBy, - aggr_expr: Vec, + aggr_expr: Vec>, filter_expr: Vec>>, input: Arc, input_schema: SchemaRef, @@ -435,7 +438,7 @@ impl AggregateExec { fn try_new_with_schema( mode: AggregateMode, group_by: PhysicalGroupBy, - mut aggr_expr: Vec, + mut aggr_expr: Vec>, filter_expr: Vec>>, input: Arc, input_schema: SchemaRef, @@ -545,7 +548,7 @@ impl AggregateExec { } /// Aggregate expressions - pub fn aggr_expr(&self) -> &[AggregateFunctionExpr] { + pub fn aggr_expr(&self) -> &[Arc] { &self.aggr_expr } @@ -876,7 +879,7 @@ impl ExecutionPlan for AggregateExec { fn create_schema( input_schema: &Schema, group_by: &PhysicalGroupBy, - aggr_expr: &[AggregateFunctionExpr], + aggr_expr: &[Arc], mode: AggregateMode, ) -> Result { let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len()); @@ -1006,7 +1009,7 @@ pub fn concat_slices(lhs: &[T], rhs: &[T]) -> Vec { /// A `LexRequirement` instance, which is the requirement that satisfies all the /// aggregate requirements. Returns an error in case of conflicting requirements. pub fn get_finer_aggregate_exprs_requirement( - aggr_exprs: &mut [AggregateFunctionExpr], + aggr_exprs: &mut [Arc], group_by: &PhysicalGroupBy, eq_properties: &EquivalenceProperties, agg_mode: &AggregateMode, @@ -1034,7 +1037,7 @@ pub fn get_finer_aggregate_exprs_requirement( // Reverse requirement is satisfied by exiting ordering. // Hence reverse the aggregator requirement = finer_ordering; - *aggr_expr = reverse_aggr_expr; + *aggr_expr = Arc::new(reverse_aggr_expr); continue; } } @@ -1058,7 +1061,7 @@ pub fn get_finer_aggregate_exprs_requirement( // There is a requirement that both satisfies existing requirement and reverse // aggregate requirement. Use updated requirement requirement = finer_ordering; - *aggr_expr = reverse_aggr_expr; + *aggr_expr = Arc::new(reverse_aggr_expr); continue; } } @@ -1080,7 +1083,7 @@ pub fn get_finer_aggregate_exprs_requirement( /// * Partial: AggregateFunctionExpr::expressions /// * Final: columns of `AggregateFunctionExpr::state_fields()` pub fn aggregate_expressions( - aggr_expr: &[AggregateFunctionExpr], + aggr_expr: &[Arc], mode: &AggregateMode, col_idx_base: usize, ) -> Result>>> { @@ -1135,7 +1138,7 @@ fn merge_expressions( pub type AccumulatorItem = Box; pub fn create_accumulators( - aggr_expr: &[AggregateFunctionExpr], + aggr_expr: &[Arc], ) -> Result> { aggr_expr .iter() @@ -1458,10 +1461,12 @@ mod tests { ], ); - let aggregates = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)]) - .schema(Arc::clone(&input_schema)) - .alias("COUNT(1)") - .build()?]; + let aggregates = vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)]) + .schema(Arc::clone(&input_schema)) + .alias("COUNT(1)") + .build()?, + )]; let task_ctx = if spill { // adjust the max memory size to have the partial aggregate result for spill mode. @@ -1596,13 +1601,12 @@ mod tests { vec![vec![false]], ); - let aggregates: Vec = - vec![ - AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) - .schema(Arc::clone(&input_schema)) - .alias("AVG(b)") - .build()?, - ]; + let aggregates: Vec> = vec![Arc::new( + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) + .schema(Arc::clone(&input_schema)) + .alias("AVG(b)") + .build()?, + )]; let task_ctx = if spill { // set to an appropriate value to trigger spill @@ -1925,17 +1929,16 @@ mod tests { ); // something that allocates within the aggregator - let aggregates_v0: Vec = - vec![test_median_agg_expr(Arc::clone(&input_schema))?]; + let aggregates_v0: Vec> = + vec![Arc::new(test_median_agg_expr(Arc::clone(&input_schema))?)]; // use fast-path in `row_hash.rs`. - let aggregates_v2: Vec = - vec![ - AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) - .schema(Arc::clone(&input_schema)) - .alias("AVG(b)") - .build()?, - ]; + let aggregates_v2: Vec> = vec![Arc::new( + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) + .schema(Arc::clone(&input_schema)) + .alias("AVG(b)") + .build()?, + )]; for (version, groups, aggregates) in [ (0, groups_none, aggregates_v0), @@ -1989,13 +1992,12 @@ mod tests { let groups = PhysicalGroupBy::default(); - let aggregates: Vec = - vec![ - AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?]) - .schema(Arc::clone(&schema)) - .alias("AVG(a)") - .build()?, - ]; + let aggregates: Vec> = vec![Arc::new( + AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("AVG(a)") + .build()?, + )]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); @@ -2029,13 +2031,12 @@ mod tests { let groups = PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]); - let aggregates: Vec = - vec![ - AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) - .schema(Arc::clone(&schema)) - .alias("AVG(b)") - .build()?, - ]; + let aggregates: Vec> = vec![Arc::new( + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("AVG(b)") + .build()?, + )]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); @@ -2080,7 +2081,7 @@ mod tests { fn test_first_value_agg_expr( schema: &Schema, sort_options: SortOptions, - ) -> Result { + ) -> Result> { let ordering_req = [PhysicalSortExpr { expr: col("b", schema)?, options: sort_options, @@ -2092,13 +2093,14 @@ mod tests { .schema(Arc::new(schema.clone())) .alias(String::from("first_value(b) ORDER BY [b ASC NULLS LAST]")) .build() + .map(Arc::new) } // LAST_VALUE(b ORDER BY b ) fn test_last_value_agg_expr( schema: &Schema, sort_options: SortOptions, - ) -> Result { + ) -> Result> { let ordering_req = [PhysicalSortExpr { expr: col("b", schema)?, options: sort_options, @@ -2109,6 +2111,7 @@ mod tests { .schema(Arc::new(schema.clone())) .alias(String::from("last_value(b) ORDER BY [b ASC NULLS LAST]")) .build() + .map(Arc::new) } // This function either constructs the physical plan below, @@ -2153,7 +2156,7 @@ mod tests { descending: false, nulls_first: false, }; - let aggregates: Vec = if is_first_acc { + let aggregates: Vec> = if is_first_acc { vec![test_first_value_agg_expr(&schema, sort_options)?] } else { vec![test_last_value_agg_expr(&schema, sort_options)?] @@ -2289,6 +2292,7 @@ mod tests { .order_by(ordering_req.to_vec()) .schema(Arc::clone(&test_schema)) .build() + .map(Arc::new) .unwrap() }) .collect::>(); @@ -2318,7 +2322,7 @@ mod tests { }; let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]); - let aggregates: Vec = vec![ + let aggregates: Vec> = vec![ test_first_value_agg_expr(&schema, option_desc)?, test_last_value_agg_expr(&schema, option_desc)?, ]; @@ -2376,11 +2380,12 @@ mod tests { ], ); - let aggregates: Vec = + let aggregates: Vec> = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)]) .schema(Arc::clone(&schema)) .alias("1") - .build()?]; + .build() + .map(Arc::new)?]; let input_batches = (0..4) .map(|_| { @@ -2512,7 +2517,8 @@ mod tests { ) .schema(Arc::clone(&batch.schema())) .alias(String::from("SUM(value)")) - .build()?]; + .build() + .map(Arc::new)?]; let input = Arc::new(MemoryExec::try_new( &[vec![batch.clone()]], @@ -2560,7 +2566,8 @@ mod tests { AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?]) .schema(Arc::clone(&schema)) .alias(String::from("COUNT(val)")) - .build()?, + .build() + .map(Arc::new)?, ]; let input_data = vec![ @@ -2641,7 +2648,8 @@ mod tests { AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?]) .schema(Arc::clone(&schema)) .alias(String::from("COUNT(val)")) - .build()?, + .build() + .map(Arc::new)?, ]; let input_data = vec![ @@ -2728,7 +2736,8 @@ mod tests { AggregateExprBuilder::new(count_udaf(), vec![col("a", &input_schema)?]) .schema(Arc::clone(&input_schema)) .alias("COUNT(a)") - .build()?, + .build() + .map(Arc::new)?, ]; let grouping_set = PhysicalGroupBy::new( diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 624844b6b985..7d21cc2f1944 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -591,7 +591,7 @@ impl GroupedHashAggregateStream { /// that is supported by the aggregate, or a /// [`GroupsAccumulatorAdapter`] if not. pub(crate) fn create_group_accumulator( - agg_expr: &AggregateFunctionExpr, + agg_expr: &Arc, ) -> Result> { if agg_expr.groups_accumulator_supported() { agg_expr.create_groups_accumulator() @@ -601,7 +601,7 @@ pub(crate) fn create_group_accumulator( "Creating GroupsAccumulatorAdapter for {}: {agg_expr:?}", agg_expr.name() ); - let agg_expr_captured = agg_expr.clone(); + let agg_expr_captured = Arc::clone(agg_expr); let factory = move || agg_expr_captured.create_accumulator(); Ok(Box::new(GroupsAccumulatorAdapter::new(factory))) } diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index adf61f27bc6f..f6902fcbe2e7 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -119,7 +119,8 @@ pub fn create_window_expr( .schema(Arc::new(input_schema.clone())) .alias(name) .with_ignore_nulls(ignore_nulls) - .build()?; + .build() + .map(Arc::new)?; window_expr_from_aggregate_expr( partition_by, order_by, @@ -142,7 +143,7 @@ fn window_expr_from_aggregate_expr( partition_by: &[Arc], order_by: &[PhysicalSortExpr], window_frame: Arc, - aggregate: AggregateFunctionExpr, + aggregate: Arc, ) -> Arc { // Is there a potentially unlimited sized window frame? let unbounded_window = window_frame.start_bound.is_unbounded(); diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 9a6850cb2108..634ae284c955 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -488,7 +488,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { }) .collect::, _>>()?; - let physical_aggr_expr: Vec = hash_agg + let physical_aggr_expr: Vec> = hash_agg .aggr_expr .iter() .zip(hash_agg.aggr_expr_name.iter()) @@ -518,6 +518,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .with_distinct(agg_node.distinct) .order_by(ordering_req) .build() + .map(Arc::new) } } }).transpose()?.ok_or_else(|| { diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 6072baca688c..33eca0723103 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -48,7 +48,7 @@ use crate::protobuf::{ use super::PhysicalExtensionCodec; pub fn serialize_physical_aggr_expr( - aggr_expr: AggregateFunctionExpr, + aggr_expr: Arc, codec: &dyn PhysicalExtensionCodec, ) -> Result { let expressions = serialize_physical_exprs(&aggr_expr.expressions(), codec)?; diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 025676f790a8..4a9bf6afb49e 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -73,7 +73,6 @@ use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; -use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::unnest::{ListUnnest, UnnestExec}; use datafusion::physical_plan::windows::{ @@ -305,7 +304,8 @@ fn roundtrip_window() -> Result<()> { ) .schema(Arc::clone(&schema)) .alias("avg(b)") - .build()?, + .build() + .map(Arc::new)?, &[], &[], Arc::new(WindowFrame::new(None)), @@ -321,7 +321,8 @@ fn roundtrip_window() -> Result<()> { let sum_expr = AggregateExprBuilder::new(sum_udaf(), args) .schema(Arc::clone(&schema)) .alias("SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING") - .build()?; + .build() + .map(Arc::new)?; let sliding_aggr_window_expr = Arc::new(SlidingAggregateWindowExpr::new( sum_expr, @@ -367,13 +368,13 @@ fn rountrip_aggregate() -> Result<()> { .alias("NTH_VALUE(b, 1)") .build()?; - let test_cases: Vec> = vec![ + let test_cases = vec![ // AVG - vec![avg_expr], + vec![Arc::new(avg_expr)], // NTH_VALUE - vec![nth_expr], + vec![Arc::new(nth_expr)], // STRING_AGG - vec![str_agg_expr], + vec![Arc::new(str_agg_expr)], ]; for aggregates in test_cases { @@ -400,12 +401,13 @@ fn rountrip_aggregate_with_limit() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; - let aggregates: Vec = + let aggregates = vec![ AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) .schema(Arc::clone(&schema)) .alias("AVG(b)") - .build()?, + .build() + .map(Arc::new)?, ]; let agg = AggregateExec::try_new( @@ -429,13 +431,14 @@ fn rountrip_aggregate_with_approx_pencentile_cont() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; - let aggregates: Vec = vec![AggregateExprBuilder::new( + let aggregates = vec![AggregateExprBuilder::new( approx_percentile_cont_udaf(), vec![col("b", &schema)?, lit(0.5)], ) .schema(Arc::clone(&schema)) .alias("APPROX_PERCENTILE_CONT(b, 0.5)") - .build()?]; + .build() + .map(Arc::new)?]; let agg = AggregateExec::try_new( AggregateMode::Final, @@ -464,13 +467,14 @@ fn rountrip_aggregate_with_sort() -> Result<()> { }, }]; - let aggregates: Vec = + let aggregates = vec![ AggregateExprBuilder::new(array_agg_udaf(), vec![col("b", &schema)?]) .schema(Arc::clone(&schema)) .alias("ARRAY_AGG(b)") .order_by(sort_exprs) - .build()?, + .build() + .map(Arc::new)?, ]; let agg = AggregateExec::try_new( @@ -531,12 +535,13 @@ fn roundtrip_aggregate_udaf() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; - let aggregates: Vec = + let aggregates = vec![ AggregateExprBuilder::new(Arc::new(udaf), vec![col("b", &schema)?]) .schema(Arc::clone(&schema)) .alias("example_agg") - .build()?, + .build() + .map(Arc::new)?, ]; roundtrip_test_with_context( @@ -1001,7 +1006,8 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { AggregateExprBuilder::new(max_udaf(), vec![udf_expr as Arc]) .schema(schema.clone()) .alias("max") - .build()?; + .build() + .map(Arc::new)?; let window = Arc::new(WindowAggExec::try_new( vec![Arc::new(PlainAggregateWindowExpr::new( @@ -1052,7 +1058,8 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { let aggr_expr = AggregateExprBuilder::new(Arc::clone(&udaf), aggr_args.clone()) .schema(Arc::clone(&schema)) .alias("aggregate_udf") - .build()?; + .build() + .map(Arc::new)?; let filter = Arc::new(FilterExec::try_new( Arc::new(BinaryExpr::new( @@ -1079,7 +1086,8 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { .alias("aggregate_udf") .distinct() .ignore_nulls() - .build()?; + .build() + .map(Arc::new)?; let aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Final, From 34bd8237d2189eca5b560c034d15e63d97a15fa0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Fri, 18 Oct 2024 23:00:24 +0200 Subject: [PATCH 024/110] Remove logical cross join in planning (#12985) * Remove logical cross join in planning * WIP * WIP * WIP * WIP * WIP * WIP * WIP * WIP * WIP * WIP * WIP * WIP * WIP * WIP * WIP * WIP * WIP * Implement some more substrait pieces * Update datafusion/core/src/physical_planner.rs Co-authored-by: Oleks V * Remove incorrect comment --------- Co-authored-by: Oleks V --- datafusion/core/src/physical_planner.rs | 22 ++++--- datafusion/expr/src/logical_plan/builder.rs | 11 +++- datafusion/expr/src/logical_plan/plan.rs | 6 ++ .../optimizer/src/eliminate_cross_join.rs | 25 +++++--- datafusion/optimizer/src/eliminate_join.rs | 26 +------- datafusion/optimizer/src/push_down_filter.rs | 4 +- datafusion/optimizer/src/push_down_limit.rs | 7 +-- datafusion/sql/src/relation/join.rs | 4 +- datafusion/sql/tests/cases/plan_to_sql.rs | 2 +- datafusion/sql/tests/sql_integration.rs | 30 ++++----- datafusion/sqllogictest/test_files/cte.slt | 2 +- .../sqllogictest/test_files/group_by.slt | 2 +- datafusion/sqllogictest/test_files/join.slt | 4 +- datafusion/sqllogictest/test_files/joins.slt | 2 +- datafusion/sqllogictest/test_files/select.slt | 2 +- datafusion/sqllogictest/test_files/update.slt | 4 +- .../substrait/src/logical_plan/consumer.rs | 12 +++- .../tests/cases/consumer_integration.rs | 62 +++++++++---------- 18 files changed, 117 insertions(+), 110 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index a4dffd3d0208..918ebccbeb70 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -78,7 +78,7 @@ use datafusion_expr::expr::{ use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ - DescribeTable, DmlStatement, Extension, Filter, RecursiveQuery, SortExpr, + DescribeTable, DmlStatement, Extension, Filter, JoinType, RecursiveQuery, SortExpr, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; @@ -1045,14 +1045,18 @@ impl DefaultPhysicalPlanner { session_state.config_options().optimizer.prefer_hash_join; let join: Arc = if join_on.is_empty() { - // there is no equal join condition, use the nested loop join - // TODO optimize the plan, and use the config of `target_partitions` and `repartition_joins` - Arc::new(NestedLoopJoinExec::try_new( - physical_left, - physical_right, - join_filter, - join_type, - )?) + if join_filter.is_none() && matches!(join_type, JoinType::Inner) { + // cross join if there is no join conditions and no join filter set + Arc::new(CrossJoinExec::new(physical_left, physical_right)) + } else { + // there is no equal join condition, use the nested loop join + Arc::new(NestedLoopJoinExec::try_new( + physical_left, + physical_right, + join_filter, + join_type, + )?) + } } else if session_state.config().target_partitions() > 1 && session_state.config().repartition_joins() && !prefer_hash_join diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index da2a96327ce5..6ab50440ec5b 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -30,8 +30,8 @@ use crate::expr_rewriter::{ rewrite_sort_cols_by_aggs, }; use crate::logical_plan::{ - Aggregate, Analyze, CrossJoin, Distinct, DistinctOn, EmptyRelation, Explain, Filter, - Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, + Aggregate, Analyze, Distinct, DistinctOn, EmptyRelation, Explain, Filter, Join, + JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values, Window, }; @@ -950,9 +950,14 @@ impl LogicalPlanBuilder { pub fn cross_join(self, right: LogicalPlan) -> Result { let join_schema = build_join_schema(self.plan.schema(), right.schema(), &JoinType::Inner)?; - Ok(Self::new(LogicalPlan::CrossJoin(CrossJoin { + Ok(Self::new(LogicalPlan::Join(Join { left: self.plan, right: Arc::new(right), + on: vec![], + filter: None, + join_type: JoinType::Inner, + join_constraint: JoinConstraint::On, + null_equals_null: false, schema: DFSchemaRef::new(join_schema), }))) } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 9bd57d22128d..10a99c9e78da 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -222,6 +222,7 @@ pub enum LogicalPlan { Join(Join), /// Apply Cross Join to two logical plans. /// This is used to implement SQL `CROSS JOIN` + /// Deprecated: use [LogicalPlan::Join] instead with empty `on` / no filter CrossJoin(CrossJoin), /// Repartitions the input based on a partitioning scheme. This is /// used to add parallelism and is sometimes referred to as an @@ -1873,6 +1874,11 @@ impl LogicalPlan { .as_ref() .map(|expr| format!(" Filter: {expr}")) .unwrap_or_else(|| "".to_string()); + let join_type = if filter.is_none() && keys.is_empty() && matches!(join_type, JoinType::Inner) { + "Cross".to_string() + } else { + join_type.to_string() + }; match join_constraint { JoinConstraint::On => { write!( diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 550728ddd3f9..bce5c77ca674 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -25,7 +25,7 @@ use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{internal_err, Result}; use datafusion_expr::expr::{BinaryExpr, Expr}; use datafusion_expr::logical_plan::{ - CrossJoin, Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection, + Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection, }; use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair}; use datafusion_expr::{build_join_schema, ExprSchemable, Operator}; @@ -51,7 +51,7 @@ impl EliminateCrossJoin { /// Looks like this: /// ```text /// Filter(a.x = b.y AND b.xx = 100) -/// CrossJoin +/// Cross Join /// TableScan a /// TableScan b /// ``` @@ -351,10 +351,15 @@ fn find_inner_join( &JoinType::Inner, )?); - Ok(LogicalPlan::CrossJoin(CrossJoin { + Ok(LogicalPlan::Join(Join { left: Arc::new(left_input), right: Arc::new(right), schema: join_schema, + on: vec![], + filter: None, + join_type: JoinType::Inner, + join_constraint: JoinConstraint::On, + null_equals_null: false, })) } @@ -513,7 +518,7 @@ mod tests { let expected = vec![ "Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; @@ -601,7 +606,7 @@ mod tests { let expected = vec![ "Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; @@ -627,7 +632,7 @@ mod tests { let expected = vec![ "Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; @@ -843,7 +848,7 @@ mod tests { let expected = vec![ "Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", @@ -924,7 +929,7 @@ mod tests { " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", " Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; @@ -999,7 +1004,7 @@ mod tests { "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", " Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", @@ -1238,7 +1243,7 @@ mod tests { let expected = vec![ "Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index f9b79e036f9b..789235595dab 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -23,7 +23,7 @@ use datafusion_common::{Result, ScalarValue}; use datafusion_expr::JoinType::Inner; use datafusion_expr::{ logical_plan::{EmptyRelation, LogicalPlan}, - CrossJoin, Expr, + Expr, }; /// Eliminates joins when join condition is false. @@ -54,13 +54,6 @@ impl OptimizerRule for EliminateJoin { match plan { LogicalPlan::Join(join) if join.join_type == Inner && join.on.is_empty() => { match join.filter { - Some(Expr::Literal(ScalarValue::Boolean(Some(true)))) => { - Ok(Transformed::yes(LogicalPlan::CrossJoin(CrossJoin { - left: join.left, - right: join.right, - schema: join.schema, - }))) - } Some(Expr::Literal(ScalarValue::Boolean(Some(false)))) => Ok( Transformed::yes(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, @@ -105,21 +98,4 @@ mod tests { let expected = "EmptyRelation"; assert_optimized_plan_equal(plan, expected) } - - #[test] - fn join_on_true() -> Result<()> { - let plan = LogicalPlanBuilder::empty(false) - .join_on( - LogicalPlanBuilder::empty(false).build()?, - Inner, - Some(lit(true)), - )? - .build()?; - - let expected = "\ - CrossJoin:\ - \n EmptyRelation\ - \n EmptyRelation"; - assert_optimized_plan_equal(plan, expected) - } } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 6e2cc0cbdbcb..2e3bca5b0bbd 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1727,7 +1727,7 @@ mod tests { .build()?; let expected = "Projection: test.a, test1.d\ - \n CrossJoin:\ + \n Cross Join: \ \n Projection: test.a, test.b, test.c\ \n TableScan: test, full_filters=[test.a = Int32(1)]\ \n Projection: test1.d, test1.e, test1.f\ @@ -1754,7 +1754,7 @@ mod tests { .build()?; let expected = "Projection: test.a, test1.a\ - \n CrossJoin:\ + \n Cross Join: \ \n Projection: test.a, test.b, test.c\ \n TableScan: test, full_filters=[test.a = Int32(1)]\ \n Projection: test1.a, test1.b, test1.c\ diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 47fce64ae00e..6ed77387046e 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -254,10 +254,9 @@ fn push_down_join(mut join: Join, limit: usize) -> Transformed { let (left_limit, right_limit) = if is_no_join_condition(&join) { match join.join_type { - Left | Right | Full => (Some(limit), Some(limit)), + Left | Right | Full | Inner => (Some(limit), Some(limit)), LeftAnti | LeftSemi => (Some(limit), None), RightAnti | RightSemi => (None, Some(limit)), - Inner => (None, None), } } else { match join.join_type { @@ -1116,7 +1115,7 @@ mod test { .build()?; let expected = "Limit: skip=0, fetch=1000\ - \n CrossJoin:\ + \n Cross Join: \ \n Limit: skip=0, fetch=1000\ \n TableScan: test, fetch=1000\ \n Limit: skip=0, fetch=1000\ @@ -1136,7 +1135,7 @@ mod test { .build()?; let expected = "Limit: skip=1000, fetch=1000\ - \n CrossJoin:\ + \n Cross Join: \ \n Limit: skip=0, fetch=2000\ \n TableScan: test, fetch=2000\ \n Limit: skip=0, fetch=2000\ diff --git a/datafusion/sql/src/relation/join.rs b/datafusion/sql/src/relation/join.rs index 409533a3eaa5..3f34608e3756 100644 --- a/datafusion/sql/src/relation/join.rs +++ b/datafusion/sql/src/relation/join.rs @@ -151,7 +151,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .build() } } - JoinConstraint::None => not_impl_err!("NONE constraint is not supported"), + JoinConstraint::None => LogicalPlanBuilder::from(left) + .join_on(right, join_type, [])? + .build(), } } } diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 74abdf075f23..2a3c5b5f6b2b 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -243,7 +243,7 @@ fn roundtrip_crossjoin() -> Result<()> { .unwrap(); let expected = "Projection: j1.j1_id, j2.j2_string\ - \n Inner Join: Filter: Boolean(true)\ + \n Cross Join: \ \n TableScan: j1\ \n TableScan: j2"; diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 19f3d31321ce..edb614493b38 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -898,7 +898,7 @@ fn natural_right_join() { fn natural_join_no_common_becomes_cross_join() { let sql = "SELECT * FROM person a NATURAL JOIN lineitem b"; let expected = "Projection: *\ - \n CrossJoin:\ + \n Cross Join: \ \n SubqueryAlias: a\ \n TableScan: person\ \n SubqueryAlias: b\ @@ -2744,8 +2744,8 @@ fn cross_join_not_to_inner_join() { "select person.id from person, orders, lineitem where person.id = person.age;"; let expected = "Projection: person.id\ \n Filter: person.id = person.age\ - \n CrossJoin:\ - \n CrossJoin:\ + \n Cross Join: \ + \n Cross Join: \ \n TableScan: person\ \n TableScan: orders\ \n TableScan: lineitem"; @@ -2842,11 +2842,11 @@ fn exists_subquery_schema_outer_schema_overlap() { \n Subquery:\ \n Projection: person.first_name\ \n Filter: person.id = p2.id AND person.last_name = outer_ref(p.last_name) AND person.state = outer_ref(p.state)\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: person\ \n SubqueryAlias: p2\ \n TableScan: person\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: person\ \n SubqueryAlias: p\ \n TableScan: person"; @@ -2934,10 +2934,10 @@ fn scalar_subquery_reference_outer_field() { \n Projection: count(*)\ \n Aggregate: groupBy=[[]], aggr=[[count(*)]]\ \n Filter: outer_ref(j2.j2_id) = j1.j1_id AND j1.j1_id = j3.j3_id\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: j1\ \n TableScan: j3\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: j1\ \n TableScan: j2"; @@ -3123,7 +3123,7 @@ fn join_on_complex_condition() { fn lateral_constant() { let sql = "SELECT * FROM j1, LATERAL (SELECT 1) AS j2"; let expected = "Projection: *\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: j1\ \n SubqueryAlias: j2\ \n Subquery:\ @@ -3138,7 +3138,7 @@ fn lateral_comma_join() { j1, \ LATERAL (SELECT * FROM j2 WHERE j1_id < j2_id) AS j2"; let expected = "Projection: j1.j1_string, j2.j2_string\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: j1\ \n SubqueryAlias: j2\ \n Subquery:\ @@ -3154,7 +3154,7 @@ fn lateral_comma_join_referencing_join_rhs() { \n j1 JOIN (j2 JOIN j3 ON(j2_id = j3_id - 2)) ON(j1_id = j2_id),\ \n LATERAL (SELECT * FROM j3 WHERE j3_string = j2_string) as j4;"; let expected = "Projection: *\ - \n CrossJoin:\ + \n Cross Join: \ \n Inner Join: Filter: j1.j1_id = j2.j2_id\ \n TableScan: j1\ \n Inner Join: Filter: j2.j2_id = j3.j3_id - Int64(2)\ @@ -3178,12 +3178,12 @@ fn lateral_comma_join_with_shadowing() { ) as j2\ ) as j2;"; let expected = "Projection: *\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: j1\ \n SubqueryAlias: j2\ \n Subquery:\ \n Projection: *\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: j1\ \n SubqueryAlias: j2\ \n Subquery:\ @@ -3215,7 +3215,7 @@ fn lateral_nested_left_join() { j1, \ (j2 LEFT JOIN LATERAL (SELECT * FROM j3 WHERE j1_id + j2_id = j3_id) AS j3 ON(true))"; let expected = "Projection: *\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: j1\ \n Left Join: Filter: Boolean(true)\ \n TableScan: j2\ @@ -4281,7 +4281,7 @@ fn test_table_alias() { let expected = "Projection: *\ \n SubqueryAlias: f\ - \n CrossJoin:\ + \n Cross Join: \ \n SubqueryAlias: t1\ \n Projection: person.id\ \n TableScan: person\ @@ -4299,7 +4299,7 @@ fn test_table_alias() { let expected = "Projection: *\ \n SubqueryAlias: f\ \n Projection: t1.id AS c1, t2.age AS c2\ - \n CrossJoin:\ + \n Cross Join: \ \n SubqueryAlias: t1\ \n Projection: person.id\ \n TableScan: person\ diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index e9fcf07e7739..60569803322c 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -722,7 +722,7 @@ logical_plan 03)----Projection: Int64(1) AS val 04)------EmptyRelation 05)----Projection: Int64(2) AS val -06)------CrossJoin: +06)------Cross Join: 07)--------Filter: recursive_cte.val < Int64(2) 08)----------TableScan: recursive_cte 09)--------SubqueryAlias: sub_cte diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 8202b806a755..4f2778b5c0d1 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -4050,7 +4050,7 @@ EXPLAIN SELECT lhs.c, rhs.c, lhs.sum1, rhs.sum1 ---- logical_plan 01)Projection: lhs.c, rhs.c, lhs.sum1, rhs.sum1 -02)--CrossJoin: +02)--Cross Join: 03)----SubqueryAlias: lhs 04)------Projection: multiple_ordered_table_with_pk.c, sum(multiple_ordered_table_with_pk.d) AS sum1 05)--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[sum(CAST(multiple_ordered_table_with_pk.d AS Int64))]] diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index 519fbb887c7e..fe9ceaa7907a 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -671,7 +671,7 @@ query TT explain select * from t1 inner join t2 on true; ---- logical_plan -01)CrossJoin: +01)Cross Join: 02)--TableScan: t1 projection=[t1_id, t1_name, t1_int] 03)--TableScan: t2 projection=[t2_id, t2_name, t2_int] physical_plan @@ -905,7 +905,7 @@ JOIN department AS d ON (e.name = 'Alice' OR e.name = 'Bob'); ---- logical_plan -01)CrossJoin: +01)Cross Join: 02)--SubqueryAlias: e 03)----Filter: employees.name = Utf8("Alice") OR employees.name = Utf8("Bob") 04)------TableScan: employees projection=[emp_id, name] diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index be9321ddb945..558a9170c7d3 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -4050,7 +4050,7 @@ query TT explain select t1_id, t1_name, i from join_t1 t1 cross join lateral (select * from unnest(generate_series(1, t1_int))) as series(i); ---- logical_plan -01)CrossJoin: +01)Cross Join: 02)--SubqueryAlias: t1 03)----TableScan: join_t1 projection=[t1_id, t1_name] 04)--SubqueryAlias: series diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index 0fef56aeea5c..9910ca8da71f 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -558,7 +558,7 @@ EXPLAIN SELECT * FROM ((SELECT column1 FROM foo) "T1" CROSS JOIN (SELECT column2 ---- logical_plan 01)SubqueryAlias: F -02)--CrossJoin: +02)--Cross Join: 03)----SubqueryAlias: T1 04)------TableScan: foo projection=[column1] 05)----SubqueryAlias: T2 diff --git a/datafusion/sqllogictest/test_files/update.slt b/datafusion/sqllogictest/test_files/update.slt index 59133379d443..aaba6998ee63 100644 --- a/datafusion/sqllogictest/test_files/update.slt +++ b/datafusion/sqllogictest/test_files/update.slt @@ -67,7 +67,7 @@ logical_plan 01)Dml: op=[Update] table=[t1] 02)--Projection: t1.a AS a, t2.b AS b, CAST(t2.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d 03)----Filter: t1.a = t2.a AND t1.b > Utf8("foo") AND t2.c > Float64(1) -04)------CrossJoin: +04)------Cross Join: 05)--------TableScan: t1 06)--------TableScan: t2 @@ -86,7 +86,7 @@ logical_plan 01)Dml: op=[Update] table=[t1] 02)--Projection: t.a AS a, t2.b AS b, CAST(t.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d 03)----Filter: t.a = t2.a AND t.b > Utf8("foo") AND t2.c > Float64(1) -04)------CrossJoin: +04)------Cross Join: 05)--------SubqueryAlias: t 06)----------TableScan: t1 07)--------TableScan: t2 diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 08e54166d39a..5f1824bc4b30 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -780,7 +780,17 @@ pub async fn from_substrait_rel( )? .build() } - None => plan_err!("JoinRel without join condition is not allowed"), + None => { + let on: Vec = vec![]; + left.join_detailed( + right.build()?, + join_type, + (on.clone(), on), + None, + false, + )? + .build() + } } } Some(RelType::Cross(cross)) => { diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index fffa29df1db5..bc38ef82977f 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -73,17 +73,17 @@ mod tests { \n Aggregate: groupBy=[[]], aggr=[[min(PARTSUPP.PS_SUPPLYCOST)]]\ \n Projection: PARTSUPP.PS_SUPPLYCOST\ \n Filter: PARTSUPP.PS_PARTKEY = PARTSUPP.PS_PARTKEY AND SUPPLIER.S_SUPPKEY = PARTSUPP.PS_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = Utf8(\"EUROPE\")\ - \n CrossJoin:\ - \n CrossJoin:\ - \n CrossJoin:\ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ \n TableScan: PARTSUPP\ \n TableScan: SUPPLIER\ \n TableScan: NATION\ \n TableScan: REGION\ - \n CrossJoin:\ - \n CrossJoin:\ - \n CrossJoin:\ - \n CrossJoin:\ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ \n TableScan: PART\ \n TableScan: SUPPLIER\ \n TableScan: PARTSUPP\ @@ -105,8 +105,8 @@ mod tests { \n Aggregate: groupBy=[[LINEITEM.L_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_SHIPPRIORITY]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]]\ \n Projection: LINEITEM.L_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_SHIPPRIORITY, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT)\ \n Filter: CUSTOMER.C_MKTSEGMENT = Utf8(\"BUILDING\") AND CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND ORDERS.O_ORDERDATE < CAST(Utf8(\"1995-03-15\") AS Date32) AND LINEITEM.L_SHIPDATE > CAST(Utf8(\"1995-03-15\") AS Date32)\ - \n CrossJoin:\ - \n CrossJoin:\ + \n Cross Join: \ + \n Cross Join: \ \n TableScan: LINEITEM\ \n TableScan: CUSTOMER\ \n TableScan: ORDERS" @@ -142,11 +142,11 @@ mod tests { \n Aggregate: groupBy=[[NATION.N_NAME]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]]\ \n Projection: NATION.N_NAME, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT)\ \n Filter: CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND LINEITEM.L_SUPPKEY = SUPPLIER.S_SUPPKEY AND CUSTOMER.C_NATIONKEY = SUPPLIER.S_NATIONKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = Utf8(\"ASIA\") AND ORDERS.O_ORDERDATE >= CAST(Utf8(\"1994-01-01\") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8(\"1995-01-01\") AS Date32)\ - \n CrossJoin:\ - \n CrossJoin:\ - \n CrossJoin:\ - \n CrossJoin:\ - \n CrossJoin:\ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ \n TableScan: CUSTOMER\ \n TableScan: ORDERS\ \n TableScan: LINEITEM\ @@ -206,9 +206,9 @@ mod tests { \n Aggregate: groupBy=[[CUSTOMER.C_CUSTKEY, CUSTOMER.C_NAME, CUSTOMER.C_ACCTBAL, CUSTOMER.C_PHONE, NATION.N_NAME, CUSTOMER.C_ADDRESS, CUSTOMER.C_COMMENT]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]]\ \n Projection: CUSTOMER.C_CUSTKEY, CUSTOMER.C_NAME, CUSTOMER.C_ACCTBAL, CUSTOMER.C_PHONE, NATION.N_NAME, CUSTOMER.C_ADDRESS, CUSTOMER.C_COMMENT, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT)\ \n Filter: CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND ORDERS.O_ORDERDATE >= CAST(Utf8(\"1993-10-01\") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8(\"1994-01-01\") AS Date32) AND LINEITEM.L_RETURNFLAG = Utf8(\"R\") AND CUSTOMER.C_NATIONKEY = NATION.N_NATIONKEY\ - \n CrossJoin:\ - \n CrossJoin:\ - \n CrossJoin:\ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ \n TableScan: CUSTOMER\ \n TableScan: ORDERS\ \n TableScan: LINEITEM\ @@ -230,16 +230,16 @@ mod tests { \n Aggregate: groupBy=[[]], aggr=[[sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY)]]\ \n Projection: PARTSUPP.PS_SUPPLYCOST * CAST(PARTSUPP.PS_AVAILQTY AS Decimal128(19, 0))\ \n Filter: PARTSUPP.PS_SUPPKEY = SUPPLIER.S_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8(\"JAPAN\")\ - \n CrossJoin:\ - \n CrossJoin:\ + \n Cross Join: \ + \n Cross Join: \ \n TableScan: PARTSUPP\ \n TableScan: SUPPLIER\ \n TableScan: NATION\ \n Aggregate: groupBy=[[PARTSUPP.PS_PARTKEY]], aggr=[[sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY)]]\ \n Projection: PARTSUPP.PS_PARTKEY, PARTSUPP.PS_SUPPLYCOST * CAST(PARTSUPP.PS_AVAILQTY AS Decimal128(19, 0))\ \n Filter: PARTSUPP.PS_SUPPKEY = SUPPLIER.S_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8(\"JAPAN\")\ - \n CrossJoin:\ - \n CrossJoin:\ + \n Cross Join: \ + \n Cross Join: \ \n TableScan: PARTSUPP\ \n TableScan: SUPPLIER\ \n TableScan: NATION" @@ -257,7 +257,7 @@ mod tests { \n Aggregate: groupBy=[[LINEITEM.L_SHIPMODE]], aggr=[[sum(CASE WHEN ORDERS.O_ORDERPRIORITY = Utf8(\"1-URGENT\") OR ORDERS.O_ORDERPRIORITY = Utf8(\"2-HIGH\") THEN Int32(1) ELSE Int32(0) END), sum(CASE WHEN ORDERS.O_ORDERPRIORITY != Utf8(\"1-URGENT\") AND ORDERS.O_ORDERPRIORITY != Utf8(\"2-HIGH\") THEN Int32(1) ELSE Int32(0) END)]]\ \n Projection: LINEITEM.L_SHIPMODE, CASE WHEN ORDERS.O_ORDERPRIORITY = Utf8(\"1-URGENT\") OR ORDERS.O_ORDERPRIORITY = Utf8(\"2-HIGH\") THEN Int32(1) ELSE Int32(0) END, CASE WHEN ORDERS.O_ORDERPRIORITY != Utf8(\"1-URGENT\") AND ORDERS.O_ORDERPRIORITY != Utf8(\"2-HIGH\") THEN Int32(1) ELSE Int32(0) END\ \n Filter: ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY AND (LINEITEM.L_SHIPMODE = CAST(Utf8(\"MAIL\") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8(\"SHIP\") AS Utf8)) AND LINEITEM.L_COMMITDATE < LINEITEM.L_RECEIPTDATE AND LINEITEM.L_SHIPDATE < LINEITEM.L_COMMITDATE AND LINEITEM.L_RECEIPTDATE >= CAST(Utf8(\"1994-01-01\") AS Date32) AND LINEITEM.L_RECEIPTDATE < CAST(Utf8(\"1995-01-01\") AS Date32)\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: ORDERS\ \n TableScan: LINEITEM" ); @@ -292,7 +292,7 @@ mod tests { \n Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN PART.P_TYPE LIKE Utf8(\"PROMO%\") THEN LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT ELSE Decimal128(Some(0),19,4) END), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]]\ \n Projection: CASE WHEN PART.P_TYPE LIKE CAST(Utf8(\"PROMO%\") AS Utf8) THEN LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) ELSE Decimal128(Some(0),19,4) END, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT)\ \n Filter: LINEITEM.L_PARTKEY = PART.P_PARTKEY AND LINEITEM.L_SHIPDATE >= Date32(\"1995-09-01\") AND LINEITEM.L_SHIPDATE < CAST(Utf8(\"1995-10-01\") AS Date32)\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: LINEITEM\ \n TableScan: PART" ); @@ -321,7 +321,7 @@ mod tests { \n Projection: SUPPLIER.S_SUPPKEY\ \n Filter: SUPPLIER.S_COMMENT LIKE CAST(Utf8(\"%Customer%Complaints%\") AS Utf8)\ \n TableScan: SUPPLIER\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: PARTSUPP\ \n TableScan: PART" ); @@ -353,8 +353,8 @@ mod tests { \n Aggregate: groupBy=[[LINEITEM.L_ORDERKEY]], aggr=[[sum(LINEITEM.L_QUANTITY)]]\ \n Projection: LINEITEM.L_ORDERKEY, LINEITEM.L_QUANTITY\ \n TableScan: LINEITEM\ - \n CrossJoin:\ - \n CrossJoin:\ + \n Cross Join: \ + \n Cross Join: \ \n TableScan: CUSTOMER\ \n TableScan: ORDERS\ \n TableScan: LINEITEM" @@ -369,7 +369,7 @@ mod tests { "Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS REVENUE]]\ \n Projection: LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT)\ \n Filter: PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8(\"Brand#12\") AND (PART.P_CONTAINER = CAST(Utf8(\"SM CASE\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"SM BOX\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"SM PACK\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"SM PKG\") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(1) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(1) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(5) AND (LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR\") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR REG\") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8(\"DELIVER IN PERSON\") OR PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8(\"Brand#23\") AND (PART.P_CONTAINER = CAST(Utf8(\"MED BAG\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"MED BOX\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"MED PKG\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"MED PACK\") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(10) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(10) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(10) AND (LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR\") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR REG\") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8(\"DELIVER IN PERSON\") OR PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8(\"Brand#34\") AND (PART.P_CONTAINER = CAST(Utf8(\"LG CASE\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"LG BOX\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"LG PACK\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"LG PKG\") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(20) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(20) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(15) AND (LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR\") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR REG\") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8(\"DELIVER IN PERSON\")\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: LINEITEM\ \n TableScan: PART" ); @@ -398,7 +398,7 @@ mod tests { \n Filter: LINEITEM.L_PARTKEY = LINEITEM.L_ORDERKEY AND LINEITEM.L_SUPPKEY = LINEITEM.L_PARTKEY AND LINEITEM.L_SHIPDATE >= CAST(Utf8(\"1994-01-01\") AS Date32) AND LINEITEM.L_SHIPDATE < CAST(Utf8(\"1995-01-01\") AS Date32)\ \n TableScan: LINEITEM\ \n TableScan: PARTSUPP\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: SUPPLIER\ \n TableScan: NATION" ); @@ -422,9 +422,9 @@ mod tests { \n Subquery:\ \n Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_TAX AND LINEITEM.L_SUPPKEY != LINEITEM.L_LINESTATUS AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE\ \n TableScan: LINEITEM\ - \n CrossJoin:\ - \n CrossJoin:\ - \n CrossJoin:\ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ \n TableScan: SUPPLIER\ \n TableScan: LINEITEM\ \n TableScan: ORDERS\ From 12568bf1bd9b3a6cd1ea1b0632dfd5bdbc00bea1 Mon Sep 17 00:00:00 2001 From: Jonathan Chen <86070045+jonathanc-n@users.noreply.github.com> Date: Sat, 19 Oct 2024 07:42:16 -0400 Subject: [PATCH 025/110] fix spelling (#13014) --- datafusion/functions/src/regex/regexpmatch.rs | 2 +- docs/source/user-guide/sql/scalar_functions_new.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/functions/src/regex/regexpmatch.rs b/datafusion/functions/src/regex/regexpmatch.rs index 4a86adbe683a..a458b205f4e3 100644 --- a/datafusion/functions/src/regex/regexpmatch.rs +++ b/datafusion/functions/src/regex/regexpmatch.rs @@ -119,7 +119,7 @@ fn get_regexp_match_doc() -> &'static Documentation { DOCUMENTATION.get_or_init(|| { Documentation::builder() .with_doc_section(DOC_SECTION_REGEX) - .with_description("Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matche in a string.") + .with_description("Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string.") .with_syntax_example("regexp_match(str, regexp[, flags])") .with_sql_example(r#"```sql > select regexp_match('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); diff --git a/docs/source/user-guide/sql/scalar_functions_new.md b/docs/source/user-guide/sql/scalar_functions_new.md index 1915623012f4..ac6e56a44c10 100644 --- a/docs/source/user-guide/sql/scalar_functions_new.md +++ b/docs/source/user-guide/sql/scalar_functions_new.md @@ -1752,7 +1752,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo ### `regexp_match` -Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matche in a string. +Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string. ``` regexp_match(str, regexp[, flags]) From 7a3414774cb7858d9649820ddffa59f5712a3153 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <33904309+akurmustafa@users.noreply.github.com> Date: Sat, 19 Oct 2024 04:49:14 -0700 Subject: [PATCH 026/110] replace take_array with arrow util (#13013) --- datafusion/common/src/utils/mod.rs | 57 +------------------ .../src/aggregate/groups_accumulator.rs | 7 +-- .../functions-aggregate/src/first_last.rs | 8 +-- .../physical-plan/src/repartition/mod.rs | 5 +- datafusion/physical-plan/src/sorts/sort.rs | 5 +- .../src/windows/bounded_window_agg_exec.rs | 8 ++- 6 files changed, 19 insertions(+), 71 deletions(-) diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 5bf0f08b092a..def1def9853c 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -26,8 +26,7 @@ use crate::error::{_internal_datafusion_err, _internal_err}; use crate::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; use arrow::array::{ArrayRef, PrimitiveArray}; use arrow::buffer::OffsetBuffer; -use arrow::compute; -use arrow::compute::{partition, SortColumn, SortOptions}; +use arrow::compute::{partition, take_arrays, SortColumn, SortOptions}; use arrow::datatypes::{Field, SchemaRef, UInt32Type}; use arrow::record_batch::RecordBatch; use arrow_array::cast::AsArray; @@ -98,7 +97,7 @@ pub fn get_record_batch_at_indices( record_batch: &RecordBatch, indices: &PrimitiveArray, ) -> Result { - let new_columns = take_arrays(record_batch.columns(), indices)?; + let new_columns = take_arrays(record_batch.columns(), indices, None)?; RecordBatch::try_new_with_options( record_batch.schema(), new_columns, @@ -290,24 +289,6 @@ pub(crate) fn parse_identifiers(s: &str) -> Result> { Ok(idents) } -/// Construct a new [`Vec`] of [`ArrayRef`] from the rows of the `arrays` at the `indices`. -/// -/// TODO: use implementation in arrow-rs when available: -/// -pub fn take_arrays(arrays: &[ArrayRef], indices: &dyn Array) -> Result> { - arrays - .iter() - .map(|array| { - compute::take( - array.as_ref(), - indices, - None, // None: no index check - ) - .map_err(|e| arrow_datafusion_err!(e)) - }) - .collect() -} - pub(crate) fn parse_identifiers_normalized(s: &str, ignore_case: bool) -> Vec { parse_identifiers(s) .unwrap_or_default() @@ -1003,40 +984,6 @@ mod tests { Ok(()) } - #[test] - fn test_take_arrays() -> Result<()> { - let arrays: Vec = vec![ - Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 9., 10.])), - Arc::new(Float64Array::from(vec![2.0, 3.0, 3.0, 4.0, 5.0])), - Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 10., 11.0])), - Arc::new(Float64Array::from(vec![15.0, 13.0, 8.0, 5., 0.0])), - ]; - - let row_indices_vec: Vec> = vec![ - // Get rows 0 and 1 - vec![0, 1], - // Get rows 0 and 1 - vec![0, 2], - // Get rows 1 and 3 - vec![1, 3], - // Get rows 2 and 4 - vec![2, 4], - ]; - for row_indices in row_indices_vec { - let indices: PrimitiveArray = - PrimitiveArray::from_iter_values(row_indices.iter().cloned()); - let chunk = take_arrays(&arrays, &indices)?; - for (arr_orig, arr_chunk) in arrays.iter().zip(&chunk) { - for (idx, orig_idx) in row_indices.iter().enumerate() { - let res1 = ScalarValue::try_from_array(arr_orig, *orig_idx as usize)?; - let res2 = ScalarValue::try_from_array(arr_chunk, idx)?; - assert_eq!(res1, res2); - } - } - } - Ok(()) - } - #[test] fn test_get_at_indices() -> Result<()> { let in_vec = vec![1, 2, 3, 4, 5, 6, 7]; diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs index b03df0224089..c936c80cbed7 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs @@ -27,11 +27,10 @@ use arrow::array::new_empty_array; use arrow::{ array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray}, compute, + compute::take_arrays, datatypes::UInt32Type, }; -use datafusion_common::{ - arrow_datafusion_err, utils::take_arrays, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; use datafusion_expr_common::accumulator::Accumulator; use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; @@ -239,7 +238,7 @@ impl GroupsAccumulatorAdapter { // reorder the values and opt_filter by batch_indices so that // all values for each group are contiguous, then invoke the // accumulator once per group with values - let values = take_arrays(values, &batch_indices)?; + let values = take_arrays(values, &batch_indices, None)?; let opt_filter = get_filter_at_indices(opt_filter, &batch_indices)?; // invoke each accumulator with the appropriate rows, first diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index f6a84c84dcb0..2a3fc623657a 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -22,9 +22,9 @@ use std::fmt::Debug; use std::sync::{Arc, OnceLock}; use arrow::array::{ArrayRef, AsArray, BooleanArray}; -use arrow::compute::{self, lexsort_to_indices, SortColumn}; +use arrow::compute::{self, lexsort_to_indices, take_arrays, SortColumn}; use arrow::datatypes::{DataType, Field}; -use datafusion_common::utils::{compare_rows, get_row_at_idx, take_arrays}; +use datafusion_common::utils::{compare_rows, get_row_at_idx}; use datafusion_common::{ arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; @@ -340,7 +340,7 @@ impl Accumulator for FirstValueAccumulator { filtered_states } else { let indices = lexsort_to_indices(&sort_cols, None)?; - take_arrays(&filtered_states, &indices)? + take_arrays(&filtered_states, &indices, None)? }; if !ordered_states[0].is_empty() { let first_row = get_row_at_idx(&ordered_states, 0)?; @@ -670,7 +670,7 @@ impl Accumulator for LastValueAccumulator { filtered_states } else { let indices = lexsort_to_indices(&sort_cols, None)?; - take_arrays(&filtered_states, &indices)? + take_arrays(&filtered_states, &indices, None)? }; if !ordered_states[0].is_empty() { diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 902d9f4477bc..90e62d6f11f8 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -38,10 +38,11 @@ use crate::sorts::streaming_merge::StreamingMergeBuilder; use crate::stream::RecordBatchStreamAdapter; use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics}; +use arrow::compute::take_arrays; use arrow::datatypes::{SchemaRef, UInt32Type}; use arrow::record_batch::RecordBatch; use arrow_array::{PrimitiveArray, RecordBatchOptions}; -use datafusion_common::utils::{take_arrays, transpose}; +use datafusion_common::utils::transpose; use datafusion_common::{not_impl_err, DataFusionError, Result}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::memory_pool::MemoryConsumer; @@ -300,7 +301,7 @@ impl BatchPartitioner { let _timer = partitioner_timer.timer(); // Produce batches based on indices - let columns = take_arrays(batch.columns(), &indices)?; + let columns = take_arrays(batch.columns(), &indices, None)?; let mut options = RecordBatchOptions::new(); options = options.with_row_count(Some(indices.len())); diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 5d86c2183b9e..8e13a2e07e49 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -40,13 +40,12 @@ use crate::{ SendableRecordBatchStream, Statistics, }; -use arrow::compute::{concat_batches, lexsort_to_indices, SortColumn}; +use arrow::compute::{concat_batches, lexsort_to_indices, take_arrays, SortColumn}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, SortField}; use arrow_array::{Array, RecordBatchOptions, UInt32Array}; use arrow_schema::DataType; -use datafusion_common::utils::take_arrays; use datafusion_common::{internal_err, Result}; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; @@ -618,7 +617,7 @@ pub fn sort_batch( lexsort_to_indices(&sort_columns, fetch)? }; - let columns = take_arrays(batch.columns(), &indices)?; + let columns = take_arrays(batch.columns(), &indices, None)?; let options = RecordBatchOptions::new().with_row_count(Some(indices.len())); Ok(RecordBatch::try_new_with_options( diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 4a4c940b22e2..6254ae139a00 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -42,7 +42,7 @@ use crate::{ use ahash::RandomState; use arrow::{ array::{Array, ArrayRef, RecordBatchOptions, UInt32Builder}, - compute::{concat, concat_batches, sort_to_indices}, + compute::{concat, concat_batches, sort_to_indices, take_arrays}, datatypes::SchemaRef, record_batch::RecordBatch, }; @@ -50,7 +50,7 @@ use datafusion_common::hash_utils::create_hashes; use datafusion_common::stats::Precision; use datafusion_common::utils::{ evaluate_partition_ranges, get_at_indices, get_record_batch_at_indices, - get_row_at_idx, take_arrays, + get_row_at_idx, }; use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result}; use datafusion_execution::TaskContext; @@ -536,7 +536,9 @@ impl PartitionSearcher for LinearSearch { // We should emit columns according to row index ordering. let sorted_indices = sort_to_indices(&all_indices, None, None)?; // Construct new column according to row ordering. This fixes ordering - take_arrays(&new_columns, &sorted_indices).map(Some) + take_arrays(&new_columns, &sorted_indices, None) + .map(Some) + .map_err(|e| arrow_datafusion_err!(e)) } fn evaluate_partition_batches( From c7e5d8db453cf1b9d98aae520563a5ea67cdca4c Mon Sep 17 00:00:00 2001 From: Duong Cong Toai <35887761+duongcongtoai@users.noreply.github.com> Date: Sun, 20 Oct 2024 12:53:15 +0200 Subject: [PATCH 027/110] Improve recursive `unnest` options API (#12836) * refactor * refactor unnest options * more test * resolve comments * add back doc * fix proto * flaky test * clippy * use indexmap * chore: compile err * chore: update cargo * chore: fmt cargotoml --------- Co-authored-by: Andrew Lamb --- datafusion-cli/Cargo.lock | 1 + datafusion/common/src/lib.rs | 2 +- datafusion/common/src/unnest.rs | 26 ++ datafusion/expr/src/logical_plan/builder.rs | 186 +++++------- datafusion/expr/src/logical_plan/mod.rs | 4 +- datafusion/expr/src/logical_plan/plan.rs | 37 +-- datafusion/expr/src/logical_plan/tree_node.rs | 2 +- datafusion/physical-plan/src/unnest.rs | 70 ++--- datafusion/proto/proto/datafusion.proto | 18 +- datafusion/proto/src/generated/pbjson.rs | 285 +++++++++--------- datafusion/proto/src/generated/prost.rs | 32 +- .../proto/src/logical_plan/from_proto.rs | 13 +- datafusion/proto/src/logical_plan/mod.rs | 63 +--- datafusion/proto/src/logical_plan/to_proto.rs | 10 + datafusion/sql/Cargo.toml | 1 + datafusion/sql/src/select.rs | 57 +++- datafusion/sql/src/utils.rs | 226 ++++++++------ 17 files changed, 504 insertions(+), 529 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index dfd07a7658ff..08d5d4843c62 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1571,6 +1571,7 @@ dependencies = [ "arrow-schema", "datafusion-common", "datafusion-expr", + "indexmap", "log", "regex", "sqlparser", diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 10541e01914a..8323f5efc86d 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -70,7 +70,7 @@ pub use scalar::{ScalarType, ScalarValue}; pub use schema_reference::SchemaReference; pub use stats::{ColumnStatistics, Statistics}; pub use table_reference::{ResolvedTableReference, TableReference}; -pub use unnest::UnnestOptions; +pub use unnest::{RecursionUnnestOption, UnnestOptions}; pub use utils::project_schema; // These are hidden from docs purely to avoid polluting the public view of what this crate exports. diff --git a/datafusion/common/src/unnest.rs b/datafusion/common/src/unnest.rs index fd92267f9b4c..db48edd06160 100644 --- a/datafusion/common/src/unnest.rs +++ b/datafusion/common/src/unnest.rs @@ -17,6 +17,8 @@ //! [`UnnestOptions`] for unnesting structured types +use crate::Column; + /// Options for unnesting a column that contains a list type, /// replicating values in the other, non nested rows. /// @@ -60,10 +62,27 @@ /// └─────────┘ └─────┘ └─────────┘ └─────┘ /// c1 c2 c1 c2 /// ``` +/// +/// `recursions` instruct how a column should be unnested (e.g unnesting a column multiple +/// time, with depth = 1 and depth = 2). Any unnested column not being mentioned inside this +/// options is inferred to be unnested with depth = 1 #[derive(Debug, Clone, PartialEq, PartialOrd, Hash, Eq)] pub struct UnnestOptions { /// Should nulls in the input be preserved? Defaults to true pub preserve_nulls: bool, + /// If specific columns need to be unnested multiple times (e.g at different depth), + /// declare them here. Any unnested columns not being mentioned inside this option + /// will be unnested with depth = 1 + pub recursions: Vec, +} + +/// Instruction on how to unnest a column (mostly with a list type) +/// such as how to name the output, and how many level it should be unnested +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)] +pub struct RecursionUnnestOption { + pub input_column: Column, + pub output_column: Column, + pub depth: usize, } impl Default for UnnestOptions { @@ -71,6 +90,7 @@ impl Default for UnnestOptions { Self { // default to true to maintain backwards compatible behavior preserve_nulls: true, + recursions: vec![], } } } @@ -87,4 +107,10 @@ impl UnnestOptions { self.preserve_nulls = preserve_nulls; self } + + /// Set the recursions for the unnest operation + pub fn with_recursions(mut self, recursion: RecursionUnnestOption) -> Self { + self.recursions.push(recursion); + self + } } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 6ab50440ec5b..f119a2ade827 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -44,6 +44,8 @@ use crate::{ TableProviderFilterPushDown, TableSource, WriteOp, }; +use super::dml::InsertOp; +use super::plan::ColumnUnnestList; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::file_options::file_type::FileType; @@ -54,9 +56,6 @@ use datafusion_common::{ }; use datafusion_expr_common::type_coercion::binary::type_union_resolution; -use super::dml::InsertOp; -use super::plan::{ColumnUnnestList, ColumnUnnestType}; - /// Default table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; @@ -1186,7 +1185,7 @@ impl LogicalPlanBuilder { ) -> Result { unnest_with_options( Arc::unwrap_or_clone(self.plan), - vec![(column.into(), ColumnUnnestType::Inferred)], + vec![column.into()], options, ) .map(Self::new) @@ -1197,26 +1196,6 @@ impl LogicalPlanBuilder { self, columns: Vec, options: UnnestOptions, - ) -> Result { - unnest_with_options( - Arc::unwrap_or_clone(self.plan), - columns - .into_iter() - .map(|c| (c, ColumnUnnestType::Inferred)) - .collect(), - options, - ) - .map(Self::new) - } - - /// Unnest the given columns with the given [`UnnestOptions`] - /// if one column is a list type, it can be recursively and simultaneously - /// unnested into the desired recursion levels - /// e.g select unnest(list_col,depth=1), unnest(list_col,depth=2) - pub fn unnest_columns_recursive_with_options( - self, - columns: Vec<(Column, ColumnUnnestType)>, - options: UnnestOptions, ) -> Result { unnest_with_options(Arc::unwrap_or_clone(self.plan), columns, options) .map(Self::new) @@ -1594,14 +1573,12 @@ impl TableSource for LogicalTableSource { /// Create a [`LogicalPlan::Unnest`] plan pub fn unnest(input: LogicalPlan, columns: Vec) -> Result { - let unnestings = columns - .into_iter() - .map(|c| (c, ColumnUnnestType::Inferred)) - .collect(); - unnest_with_options(input, unnestings, UnnestOptions::default()) + unnest_with_options(input, columns, UnnestOptions::default()) } -pub fn get_unnested_list_datatype_recursive( +// Get the data type of a multi-dimensional type after unnesting it +// with a given depth +fn get_unnested_list_datatype_recursive( data_type: &DataType, depth: usize, ) -> Result { @@ -1620,27 +1597,6 @@ pub fn get_unnested_list_datatype_recursive( internal_err!("trying to unnest on invalid data type {:?}", data_type) } -/// Infer the unnest type based on the data type: -/// - list type: infer to unnest(list(col, depth=1)) -/// - struct type: infer to unnest(struct) -fn infer_unnest_type( - col_name: &String, - data_type: &DataType, -) -> Result { - match data_type { - DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => { - Ok(ColumnUnnestType::List(vec![ColumnUnnestList { - output_column: Column::from_name(col_name), - depth: 1, - }])) - } - DataType::Struct(_) => Ok(ColumnUnnestType::Struct), - _ => { - internal_err!("trying to unnest on invalid data type {:?}", data_type) - } - } -} - pub fn get_struct_unnested_columns( col_name: &String, inner_fields: &Fields, @@ -1729,20 +1685,15 @@ pub fn get_unnested_columns( /// ``` pub fn unnest_with_options( input: LogicalPlan, - columns_to_unnest: Vec<(Column, ColumnUnnestType)>, + columns_to_unnest: Vec, options: UnnestOptions, ) -> Result { let mut list_columns: Vec<(usize, ColumnUnnestList)> = vec![]; let mut struct_columns = vec![]; let indices_to_unnest = columns_to_unnest .iter() - .map(|col_unnesting| { - Ok(( - input.schema().index_of_column(&col_unnesting.0)?, - col_unnesting, - )) - }) - .collect::>>()?; + .map(|c| Ok((input.schema().index_of_column(c)?, c))) + .collect::>>()?; let input_schema = input.schema(); @@ -1767,51 +1718,59 @@ pub fn unnest_with_options( .enumerate() .map(|(index, (original_qualifier, original_field))| { match indices_to_unnest.get(&index) { - Some((column_to_unnest, unnest_type)) => { - let mut inferred_unnest_type = unnest_type.clone(); - if let ColumnUnnestType::Inferred = unnest_type { - inferred_unnest_type = infer_unnest_type( + Some(column_to_unnest) => { + let recursions_on_column = options + .recursions + .iter() + .filter(|p| -> bool { &p.input_column == *column_to_unnest }) + .collect::>(); + let mut transformed_columns = recursions_on_column + .iter() + .map(|r| { + list_columns.push(( + index, + ColumnUnnestList { + output_column: r.output_column.clone(), + depth: r.depth, + }, + )); + Ok(get_unnested_columns( + &r.output_column.name, + original_field.data_type(), + r.depth, + )? + .into_iter() + .next() + .unwrap()) // because unnesting a list column always result into one result + }) + .collect::)>>>()?; + if transformed_columns.is_empty() { + transformed_columns = get_unnested_columns( &column_to_unnest.name, original_field.data_type(), + 1, )?; - } - let transformed_columns: Vec<(Column, Arc)> = - match inferred_unnest_type { - ColumnUnnestType::Struct => { + match original_field.data_type() { + DataType::Struct(_) => { struct_columns.push(index); - get_unnested_columns( - &column_to_unnest.name, - original_field.data_type(), - 1, - )? } - ColumnUnnestType::List(unnest_lists) => { - list_columns.extend( - unnest_lists - .iter() - .map(|ul| (index, ul.to_owned().clone())), - ); - unnest_lists - .iter() - .map( - |ColumnUnnestList { - output_column, - depth, - }| { - get_unnested_columns( - &output_column.name, - original_field.data_type(), - *depth, - ) - }, - ) - .collect::)>>>>()? - .into_iter() - .flatten() - .collect::>() + DataType::List(_) + | DataType::FixedSizeList(_, _) + | DataType::LargeList(_) => { + list_columns.push(( + index, + ColumnUnnestList { + output_column: Column::from_name( + &column_to_unnest.name, + ), + depth: 1, + }, + )); } - _ => return internal_err!("Invalid unnest type"), + _ => {} }; + } + // new columns dependent on the same original index dependency_indices .extend(std::iter::repeat(index).take(transformed_columns.len())); @@ -1860,7 +1819,7 @@ mod tests { use crate::logical_plan::StringifiedPlan; use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery}; - use datafusion_common::SchemaError; + use datafusion_common::{RecursionUnnestOption, SchemaError}; #[test] fn plan_builder_simple() -> Result<()> { @@ -2268,24 +2227,19 @@ mod tests { // Simultaneously unnesting a list (with different depth) and a struct column let plan = nested_table_scan("test_table")? - .unnest_columns_recursive_with_options( - vec![ - ( - "stringss".into(), - ColumnUnnestType::List(vec![ - ColumnUnnestList { - output_column: Column::from_name("stringss_depth_1"), - depth: 1, - }, - ColumnUnnestList { - output_column: Column::from_name("stringss_depth_2"), - depth: 2, - }, - ]), - ), - ("struct_singular".into(), ColumnUnnestType::Inferred), - ], - UnnestOptions::default(), + .unnest_columns_with_options( + vec!["stringss".into(), "struct_singular".into()], + UnnestOptions::default() + .with_recursions(RecursionUnnestOption { + input_column: "stringss".into(), + output_column: "stringss_depth_1".into(), + depth: 1, + }) + .with_recursions(RecursionUnnestOption { + input_column: "stringss".into(), + output_column: "stringss_depth_2".into(), + depth: 2, + }), )? .build()?; diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index a189d4635e00..da44cfb010d7 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -35,8 +35,8 @@ pub use ddl::{ }; pub use dml::{DmlStatement, WriteOp}; pub use plan::{ - projection_schema, Aggregate, Analyze, ColumnUnnestList, ColumnUnnestType, CrossJoin, - DescribeTable, Distinct, DistinctOn, EmptyRelation, Explain, Extension, Filter, Join, + projection_schema, Aggregate, Analyze, ColumnUnnestList, CrossJoin, DescribeTable, + Distinct, DistinctOn, EmptyRelation, Explain, Extension, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, Projection, RecursiveQuery, Repartition, Sort, StringifiedPlan, Subquery, SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 10a99c9e78da..72d8f7158be2 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -3367,39 +3367,6 @@ pub enum Partitioning { DistributeBy(Vec), } -/// Represents the unnesting operation on a column based on the context (a known struct -/// column, a list column, or let the planner infer the unnesting type). -/// -/// The inferred unnesting type works for both struct and list column, but the unnesting -/// will only be done once (depth = 1). In case recursion is needed on a multi-dimensional -/// list type, use [`ColumnUnnestList`] -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)] -pub enum ColumnUnnestType { - // Unnesting a list column, a vector of ColumnUnnestList is used because - // a column can be unnested at different levels, resulting different output columns - List(Vec), - // for struct, there can only be one unnest performed on one column at a time - Struct, - // Infer the unnest type based on column schema - // If column is a list column, the unnest depth will be 1 - // This value is to support sugar syntax of old api in Dataframe (unnest(either_list_or_struct_column)) - Inferred, -} - -impl fmt::Display for ColumnUnnestType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - ColumnUnnestType::List(lists) => { - let list_strs: Vec = - lists.iter().map(|list| list.to_string()).collect(); - write!(f, "List([{}])", list_strs.join(", ")) - } - ColumnUnnestType::Struct => write!(f, "Struct"), - ColumnUnnestType::Inferred => write!(f, "Inferred"), - } - } -} - /// Represent the unnesting operation on a list column, such as the recursion depth and /// the output column name after unnesting /// @@ -3438,7 +3405,7 @@ pub struct Unnest { /// The incoming logical plan pub input: Arc, /// Columns to run unnest on, can be a list of (List/Struct) columns - pub exec_columns: Vec<(Column, ColumnUnnestType)>, + pub exec_columns: Vec, /// refer to the indices(in the input schema) of columns /// that have type list to run unnest on pub list_type_columns: Vec<(usize, ColumnUnnestList)>, @@ -3462,7 +3429,7 @@ impl PartialOrd for Unnest { /// The incoming logical plan pub input: &'a Arc, /// Columns to run unnest on, can be a list of (List/Struct) columns - pub exec_columns: &'a Vec<(Column, ColumnUnnestType)>, + pub exec_columns: &'a Vec, /// refer to the indices(in the input schema) of columns /// that have type list to run unnest on pub list_type_columns: &'a Vec<(usize, ColumnUnnestList)>, diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 83206a2b2af5..606868e75abf 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -501,7 +501,7 @@ impl LogicalPlan { let exprs = columns .iter() - .map(|(c, _)| Expr::Column(c.clone())) + .map(|c| Expr::Column(c.clone())) .collect::>(); exprs.iter().apply_until_stop(f) } diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs index 50af6b4960a5..2311541816f3 100644 --- a/datafusion/physical-plan/src/unnest.rs +++ b/datafusion/physical-plan/src/unnest.rs @@ -905,12 +905,10 @@ fn repeat_arrs_from_indices( #[cfg(test)] mod tests { use super::*; - use arrow::{ - datatypes::{Field, Int32Type}, - util::pretty::pretty_format_batches, - }; + use arrow::datatypes::{Field, Int32Type}; use arrow_array::{GenericListArray, OffsetSizeTrait, StringArray}; use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer}; + use datafusion_common::assert_batches_eq; // Create a GenericListArray with the following list values: // [A, B, C], [], NULL, [D], NULL, [NULL, F] @@ -1092,38 +1090,37 @@ mod tests { &HashSet::default(), &UnnestOptions { preserve_nulls: true, + recursions: vec![], }, )?; - let actual = - format!("{}", pretty_format_batches(vec![ret].as_ref())?).to_lowercase(); - let expected = r#" -+---------------------------------+---------------------------------+---------------------------------+ -| col1_unnest_placeholder_depth_1 | col1_unnest_placeholder_depth_2 | col2_unnest_placeholder_depth_1 | -+---------------------------------+---------------------------------+---------------------------------+ -| [1, 2, 3] | 1 | a | -| | 2 | b | -| [4, 5] | 3 | | -| [1, 2, 3] | | a | -| | | b | -| [4, 5] | | | -| [1, 2, 3] | 4 | a | -| | 5 | b | -| [4, 5] | | | -| [7, 8, 9, 10] | 7 | c | -| | 8 | d | -| [11, 12, 13] | 9 | | -| | 10 | | -| [7, 8, 9, 10] | | c | -| | | d | -| [11, 12, 13] | | | -| [7, 8, 9, 10] | 11 | c | -| | 12 | d | -| [11, 12, 13] | 13 | | -| | | e | -+---------------------------------+---------------------------------+---------------------------------+ - "# - .trim(); - assert_eq!(actual, expected); + + let expected = &[ +"+---------------------------------+---------------------------------+---------------------------------+", +"| col1_unnest_placeholder_depth_1 | col1_unnest_placeholder_depth_2 | col2_unnest_placeholder_depth_1 |", +"+---------------------------------+---------------------------------+---------------------------------+", +"| [1, 2, 3] | 1 | a |", +"| | 2 | b |", +"| [4, 5] | 3 | |", +"| [1, 2, 3] | | a |", +"| | | b |", +"| [4, 5] | | |", +"| [1, 2, 3] | 4 | a |", +"| | 5 | b |", +"| [4, 5] | | |", +"| [7, 8, 9, 10] | 7 | c |", +"| | 8 | d |", +"| [11, 12, 13] | 9 | |", +"| | 10 | |", +"| [7, 8, 9, 10] | | c |", +"| | | d |", +"| [11, 12, 13] | | |", +"| [7, 8, 9, 10] | 11 | c |", +"| | 12 | d |", +"| [11, 12, 13] | 13 | |", +"| | | e |", +"+---------------------------------+---------------------------------+---------------------------------+", + ]; + assert_batches_eq!(expected, &[ret]); Ok(()) } @@ -1177,7 +1174,10 @@ mod tests { preserve_nulls: bool, expected: Vec, ) -> datafusion_common::Result<()> { - let options = UnnestOptions { preserve_nulls }; + let options = UnnestOptions { + preserve_nulls, + recursions: vec![], + }; let longest_length = find_longest_length(list_arrays, &options)?; let expected_array = Int64Array::from(expected); assert_eq!( diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 9964ab498fb1..a15fa2c5f9c6 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -264,7 +264,7 @@ message CopyToNode { message UnnestNode { LogicalPlanNode input = 1; - repeated ColumnUnnestExec exec_columns = 2; + repeated datafusion_common.Column exec_columns = 2; repeated ColumnUnnestListItem list_type_columns = 3; repeated uint64 struct_type_columns = 4; repeated uint64 dependency_indices = 5; @@ -285,17 +285,15 @@ message ColumnUnnestListRecursion { uint32 depth = 2; } -message ColumnUnnestExec { - datafusion_common.Column column = 1; - oneof UnnestType { - ColumnUnnestListRecursions list = 2; - datafusion_common.EmptyMessage struct = 3; - datafusion_common.EmptyMessage inferred = 4; - } -} - message UnnestOptions { bool preserve_nulls = 1; + repeated RecursionUnnestOption recursions = 2; +} + +message RecursionUnnestOption { + datafusion_common.Column output_column = 1; + datafusion_common.Column input_column = 2; + uint32 depth = 3; } message UnionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 4417d1149681..d223e3646b51 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -2306,145 +2306,6 @@ impl<'de> serde::Deserialize<'de> for ColumnIndex { deserializer.deserialize_struct("datafusion.ColumnIndex", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ColumnUnnestExec { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.column.is_some() { - len += 1; - } - if self.unnest_type.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.ColumnUnnestExec", len)?; - if let Some(v) = self.column.as_ref() { - struct_ser.serialize_field("column", v)?; - } - if let Some(v) = self.unnest_type.as_ref() { - match v { - column_unnest_exec::UnnestType::List(v) => { - struct_ser.serialize_field("list", v)?; - } - column_unnest_exec::UnnestType::Struct(v) => { - struct_ser.serialize_field("struct", v)?; - } - column_unnest_exec::UnnestType::Inferred(v) => { - struct_ser.serialize_field("inferred", v)?; - } - } - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for ColumnUnnestExec { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "column", - "list", - "struct", - "inferred", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Column, - List, - Struct, - Inferred, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "column" => Ok(GeneratedField::Column), - "list" => Ok(GeneratedField::List), - "struct" => Ok(GeneratedField::Struct), - "inferred" => Ok(GeneratedField::Inferred), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ColumnUnnestExec; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ColumnUnnestExec") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut column__ = None; - let mut unnest_type__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Column => { - if column__.is_some() { - return Err(serde::de::Error::duplicate_field("column")); - } - column__ = map_.next_value()?; - } - GeneratedField::List => { - if unnest_type__.is_some() { - return Err(serde::de::Error::duplicate_field("list")); - } - unnest_type__ = map_.next_value::<::std::option::Option<_>>()?.map(column_unnest_exec::UnnestType::List) -; - } - GeneratedField::Struct => { - if unnest_type__.is_some() { - return Err(serde::de::Error::duplicate_field("struct")); - } - unnest_type__ = map_.next_value::<::std::option::Option<_>>()?.map(column_unnest_exec::UnnestType::Struct) -; - } - GeneratedField::Inferred => { - if unnest_type__.is_some() { - return Err(serde::de::Error::duplicate_field("inferred")); - } - unnest_type__ = map_.next_value::<::std::option::Option<_>>()?.map(column_unnest_exec::UnnestType::Inferred) -; - } - } - } - Ok(ColumnUnnestExec { - column: column__, - unnest_type: unnest_type__, - }) - } - } - deserializer.deserialize_struct("datafusion.ColumnUnnestExec", FIELDS, GeneratedVisitor) - } -} impl serde::Serialize for ColumnUnnestListItem { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -17489,6 +17350,135 @@ impl<'de> serde::Deserialize<'de> for ProjectionNode { deserializer.deserialize_struct("datafusion.ProjectionNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for RecursionUnnestOption { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.output_column.is_some() { + len += 1; + } + if self.input_column.is_some() { + len += 1; + } + if self.depth != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.RecursionUnnestOption", len)?; + if let Some(v) = self.output_column.as_ref() { + struct_ser.serialize_field("outputColumn", v)?; + } + if let Some(v) = self.input_column.as_ref() { + struct_ser.serialize_field("inputColumn", v)?; + } + if self.depth != 0 { + struct_ser.serialize_field("depth", &self.depth)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for RecursionUnnestOption { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "output_column", + "outputColumn", + "input_column", + "inputColumn", + "depth", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + OutputColumn, + InputColumn, + Depth, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "outputColumn" | "output_column" => Ok(GeneratedField::OutputColumn), + "inputColumn" | "input_column" => Ok(GeneratedField::InputColumn), + "depth" => Ok(GeneratedField::Depth), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = RecursionUnnestOption; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.RecursionUnnestOption") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut output_column__ = None; + let mut input_column__ = None; + let mut depth__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::OutputColumn => { + if output_column__.is_some() { + return Err(serde::de::Error::duplicate_field("outputColumn")); + } + output_column__ = map_.next_value()?; + } + GeneratedField::InputColumn => { + if input_column__.is_some() { + return Err(serde::de::Error::duplicate_field("inputColumn")); + } + input_column__ = map_.next_value()?; + } + GeneratedField::Depth => { + if depth__.is_some() { + return Err(serde::de::Error::duplicate_field("depth")); + } + depth__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(RecursionUnnestOption { + output_column: output_column__, + input_column: input_column__, + depth: depth__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.RecursionUnnestOption", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for RepartitionExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -20411,10 +20401,16 @@ impl serde::Serialize for UnnestOptions { if self.preserve_nulls { len += 1; } + if !self.recursions.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.UnnestOptions", len)?; if self.preserve_nulls { struct_ser.serialize_field("preserveNulls", &self.preserve_nulls)?; } + if !self.recursions.is_empty() { + struct_ser.serialize_field("recursions", &self.recursions)?; + } struct_ser.end() } } @@ -20427,11 +20423,13 @@ impl<'de> serde::Deserialize<'de> for UnnestOptions { const FIELDS: &[&str] = &[ "preserve_nulls", "preserveNulls", + "recursions", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { PreserveNulls, + Recursions, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -20454,6 +20452,7 @@ impl<'de> serde::Deserialize<'de> for UnnestOptions { { match value { "preserveNulls" | "preserve_nulls" => Ok(GeneratedField::PreserveNulls), + "recursions" => Ok(GeneratedField::Recursions), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -20474,6 +20473,7 @@ impl<'de> serde::Deserialize<'de> for UnnestOptions { V: serde::de::MapAccess<'de>, { let mut preserve_nulls__ = None; + let mut recursions__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::PreserveNulls => { @@ -20482,10 +20482,17 @@ impl<'de> serde::Deserialize<'de> for UnnestOptions { } preserve_nulls__ = Some(map_.next_value()?); } + GeneratedField::Recursions => { + if recursions__.is_some() { + return Err(serde::de::Error::duplicate_field("recursions")); + } + recursions__ = Some(map_.next_value()?); + } } } Ok(UnnestOptions { preserve_nulls: preserve_nulls__.unwrap_or_default(), + recursions: recursions__.unwrap_or_default(), }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index d3fe031a48c9..6b234be57a92 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -400,7 +400,7 @@ pub struct UnnestNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, repeated, tag = "2")] - pub exec_columns: ::prost::alloc::vec::Vec, + pub exec_columns: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "3")] pub list_type_columns: ::prost::alloc::vec::Vec, #[prost(uint64, repeated, tag = "4")] @@ -432,28 +432,20 @@ pub struct ColumnUnnestListRecursion { pub depth: u32, } #[derive(Clone, PartialEq, ::prost::Message)] -pub struct ColumnUnnestExec { - #[prost(message, optional, tag = "1")] - pub column: ::core::option::Option, - #[prost(oneof = "column_unnest_exec::UnnestType", tags = "2, 3, 4")] - pub unnest_type: ::core::option::Option, -} -/// Nested message and enum types in `ColumnUnnestExec`. -pub mod column_unnest_exec { - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum UnnestType { - #[prost(message, tag = "2")] - List(super::ColumnUnnestListRecursions), - #[prost(message, tag = "3")] - Struct(super::super::datafusion_common::EmptyMessage), - #[prost(message, tag = "4")] - Inferred(super::super::datafusion_common::EmptyMessage), - } -} -#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct UnnestOptions { #[prost(bool, tag = "1")] pub preserve_nulls: bool, + #[prost(message, repeated, tag = "2")] + pub recursions: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct RecursionUnnestOption { + #[prost(message, optional, tag = "1")] + pub output_column: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub input_column: ::core::option::Option, + #[prost(uint32, tag = "3")] + pub depth: u32, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct UnionNode { diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 20d007048a00..99b11939e95b 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -19,8 +19,8 @@ use std::sync::Arc; use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ - exec_datafusion_err, internal_err, plan_datafusion_err, Result, ScalarValue, - TableReference, UnnestOptions, + exec_datafusion_err, internal_err, plan_datafusion_err, RecursionUnnestOption, + Result, ScalarValue, TableReference, UnnestOptions, }; use datafusion_expr::expr::{Alias, Placeholder, Sort}; use datafusion_expr::expr::{Unnest, WildcardOptions}; @@ -56,6 +56,15 @@ impl From<&protobuf::UnnestOptions> for UnnestOptions { fn from(opts: &protobuf::UnnestOptions) -> Self { Self { preserve_nulls: opts.preserve_nulls, + recursions: opts + .recursions + .iter() + .map(|r| RecursionUnnestOption { + input_column: r.input_column.as_ref().unwrap().into(), + output_column: r.output_column.as_ref().unwrap().into(), + depth: r.depth as usize, + }) + .collect::>(), } } } diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 6061a7a0619a..f57910b09ade 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -19,11 +19,10 @@ use std::collections::HashMap; use std::fmt::Debug; use std::sync::Arc; -use crate::protobuf::column_unnest_exec::UnnestType; use crate::protobuf::logical_plan_node::LogicalPlanType::CustomScan; use crate::protobuf::{ - ColumnUnnestExec, ColumnUnnestListItem, ColumnUnnestListRecursion, - ColumnUnnestListRecursions, CustomTableScanNode, SortExprNodeCollection, + ColumnUnnestListItem, ColumnUnnestListRecursion, CustomTableScanNode, + SortExprNodeCollection, }; use crate::{ convert_required, into_required, @@ -69,8 +68,7 @@ use datafusion_expr::{ DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF, SortExpr, WindowUDF, }; -use datafusion_expr::{AggregateUDF, ColumnUnnestList, ColumnUnnestType, Unnest}; -use datafusion_proto_common::EmptyMessage; +use datafusion_expr::{AggregateUDF, ColumnUnnestList, Unnest}; use self::to_proto::{serialize_expr, serialize_exprs}; use crate::logical_plan::to_proto::serialize_sorts; @@ -875,33 +873,7 @@ impl AsLogicalPlan for LogicalPlanNode { into_logical_plan!(unnest.input, ctx, extension_codec)?; Ok(datafusion_expr::LogicalPlan::Unnest(Unnest { input: Arc::new(input), - exec_columns: unnest - .exec_columns - .iter() - .map(|c| { - ( - c.column.as_ref().unwrap().to_owned().into(), - match c.unnest_type.as_ref().unwrap() { - UnnestType::Inferred(_) => ColumnUnnestType::Inferred, - UnnestType::Struct(_) => ColumnUnnestType::Struct, - UnnestType::List(l) => ColumnUnnestType::List( - l.recursions - .iter() - .map(|ul| ColumnUnnestList { - output_column: ul - .output_column - .as_ref() - .unwrap() - .to_owned() - .into(), - depth: ul.depth as usize, - }) - .collect(), - ), - }, - ) - }) - .collect(), + exec_columns: unnest.exec_columns.iter().map(|c| c.into()).collect(), list_type_columns: unnest .list_type_columns .iter() @@ -1610,32 +1582,7 @@ impl AsLogicalPlan for LogicalPlanNode { input: Some(Box::new(input)), exec_columns: exec_columns .iter() - .map(|(col, unnesting)| ColumnUnnestExec { - column: Some(col.into()), - unnest_type: Some(match unnesting { - ColumnUnnestType::Inferred => { - UnnestType::Inferred(EmptyMessage {}) - } - ColumnUnnestType::Struct => { - UnnestType::Struct(EmptyMessage {}) - } - ColumnUnnestType::List(list) => { - UnnestType::List(ColumnUnnestListRecursions { - recursions: list - .iter() - .map(|ul| ColumnUnnestListRecursion { - output_column: Some( - ul.output_column - .to_owned() - .into(), - ), - depth: ul.depth as _, - }) - .collect(), - }) - } - }), - }) + .map(|col| col.into()) .collect(), list_type_columns: proto_unnest_list_items, struct_type_columns: struct_type_columns diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 15fec3a8b2a8..a34a220e490c 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -30,6 +30,7 @@ use datafusion_expr::{ WindowFrameUnits, WindowFunctionDefinition, }; +use crate::protobuf::RecursionUnnestOption; use crate::protobuf::{ self, plan_type::PlanTypeEnum::{ @@ -49,6 +50,15 @@ impl From<&UnnestOptions> for protobuf::UnnestOptions { fn from(opts: &UnnestOptions) -> Self { Self { preserve_nulls: opts.preserve_nulls, + recursions: opts + .recursions + .iter() + .map(|r| RecursionUnnestOption { + input_column: Some((&r.input_column).into()), + output_column: Some((&r.output_column).into()), + depth: r.depth as u32, + }) + .collect(), } } } diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index 5c4b83fe38e1..90be576a884e 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -46,6 +46,7 @@ arrow-array = { workspace = true } arrow-schema = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +indexmap = { workspace = true } log = { workspace = true } regex = { workspace = true } sqlparser = { workspace = true } diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index c665dec21df4..80a08da5e35d 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -25,8 +25,8 @@ use crate::utils::{ }; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::UnnestOptions; use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; +use datafusion_common::{RecursionUnnestOption, UnnestOptions}; use datafusion_expr::expr::{Alias, PlannedReplaceSelectItem, WildcardOptions}; use datafusion_expr::expr_rewriter::{ normalize_col, normalize_col_with_schemas_and_ambiguity_check, normalize_sorts, @@ -38,6 +38,7 @@ use datafusion_expr::{ qualified_wildcard_with_options, wildcard_with_options, Aggregate, Expr, Filter, GroupingSet, LogicalPlan, LogicalPlanBuilder, Partitioning, }; +use indexmap::IndexMap; use sqlparser::ast::{ Distinct, Expr as SQLExpr, GroupByExpr, NamedWindowExpr, OrderByExpr, WildcardAdditionalOptions, WindowType, @@ -301,7 +302,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // The transformation happen bottom up, one at a time for each iteration // Only exhaust the loop if no more unnest transformation is found for i in 0.. { - let mut unnest_columns = vec![]; + let mut unnest_columns = IndexMap::new(); // from which column used for projection, before the unnest happen // including non unnest column and unnest column let mut inner_projection_exprs = vec![]; @@ -329,14 +330,27 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { break; } else { // Set preserve_nulls to false to ensure compatibility with DuckDB and PostgreSQL - let unnest_options = UnnestOptions::new().with_preserve_nulls(false); - + let mut unnest_options = UnnestOptions::new().with_preserve_nulls(false); + let mut unnest_col_vec = vec![]; + + for (col, maybe_list_unnest) in unnest_columns.into_iter() { + if let Some(list_unnest) = maybe_list_unnest { + unnest_options = list_unnest.into_iter().fold( + unnest_options, + |options, unnest_list| { + options.with_recursions(RecursionUnnestOption { + input_column: col.clone(), + output_column: unnest_list.output_column, + depth: unnest_list.depth, + }) + }, + ); + } + unnest_col_vec.push(col); + } let plan = LogicalPlanBuilder::from(intermediate_plan) .project(inner_projection_exprs)? - .unnest_columns_recursive_with_options( - unnest_columns, - unnest_options, - )? + .unnest_columns_with_options(unnest_col_vec, unnest_options)? .build()?; intermediate_plan = plan; intermediate_select_exprs = outer_projection_exprs; @@ -405,7 +419,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut intermediate_select_exprs = group_expr; loop { - let mut unnest_columns = vec![]; + let mut unnest_columns = IndexMap::new(); let mut inner_projection_exprs = vec![]; let outer_projection_exprs = rewrite_recursive_unnests_bottom_up( @@ -418,7 +432,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if unnest_columns.is_empty() { break; } else { - let unnest_options = UnnestOptions::new().with_preserve_nulls(false); + let mut unnest_options = UnnestOptions::new().with_preserve_nulls(false); let mut projection_exprs = match &aggr_expr_using_columns { Some(exprs) => (*exprs).clone(), @@ -440,12 +454,27 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }; projection_exprs.extend(inner_projection_exprs); + let mut unnest_col_vec = vec![]; + + for (col, maybe_list_unnest) in unnest_columns.into_iter() { + if let Some(list_unnest) = maybe_list_unnest { + unnest_options = list_unnest.into_iter().fold( + unnest_options, + |options, unnest_list| { + options.with_recursions(RecursionUnnestOption { + input_column: col.clone(), + output_column: unnest_list.output_column, + depth: unnest_list.depth, + }) + }, + ); + } + unnest_col_vec.push(col); + } + intermediate_plan = LogicalPlanBuilder::from(intermediate_plan) .project(projection_exprs)? - .unnest_columns_recursive_with_options( - unnest_columns, - unnest_options, - )? + .unnest_columns_with_options(unnest_col_vec, unnest_options)? .build()?; intermediate_select_exprs = outer_projection_exprs; diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 787bc6634355..14436de01843 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -34,9 +34,9 @@ use datafusion_expr::builder::get_struct_unnested_columns; use datafusion_expr::expr::{Alias, GroupingSet, Unnest, WindowFunction}; use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs}; use datafusion_expr::{ - col, expr_vec_fmt, ColumnUnnestList, ColumnUnnestType, Expr, ExprSchemable, - LogicalPlan, + col, expr_vec_fmt, ColumnUnnestList, Expr, ExprSchemable, LogicalPlan, }; +use indexmap::IndexMap; use sqlparser::ast::{Ident, Value}; /// Make a best-effort attempt at resolving all columns in the expression tree @@ -295,7 +295,7 @@ pub(crate) fn value_to_string(value: &Value) -> Option { pub(crate) fn rewrite_recursive_unnests_bottom_up( input: &LogicalPlan, - unnest_placeholder_columns: &mut Vec<(Column, ColumnUnnestType)>, + unnest_placeholder_columns: &mut IndexMap>>, inner_projection_exprs: &mut Vec, original_exprs: &[Expr], ) -> Result> { @@ -326,7 +326,7 @@ struct RecursiveUnnestRewriter<'a> { top_most_unnest: Option, consecutive_unnest: Vec>, inner_projection_exprs: &'a mut Vec, - columns_unnestings: &'a mut Vec<(Column, ColumnUnnestType)>, + columns_unnestings: &'a mut IndexMap>>, transformed_root_exprs: Option>, } impl<'a> RecursiveUnnestRewriter<'a> { @@ -360,13 +360,11 @@ impl<'a> RecursiveUnnestRewriter<'a> { // Full context, we are trying to plan the execution as InnerProjection->Unnest->OuterProjection // inside unnest execution, each column inside the inner projection // will be transformed into new columns. Thus we need to keep track of these placeholding column names - // let placeholder_name = unnest_expr.display_name()?; let placeholder_name = format!("unnest_placeholder({})", inner_expr_name); let post_unnest_name = format!("unnest_placeholder({},depth={})", inner_expr_name, level); // This is due to the fact that unnest transformation should keep the original // column name as is, to comply with group by and order by - // let post_unnest_alias = print_unnest(&inner_expr_name, level); let placeholder_column = Column::from_name(placeholder_name.clone()); let (data_type, _) = expr_in_unnest.data_type_and_nullable(self.input_schema)?; @@ -380,10 +378,8 @@ impl<'a> RecursiveUnnestRewriter<'a> { self.inner_projection_exprs, expr_in_unnest.clone().alias(placeholder_name.clone()), ); - self.columns_unnestings.push(( - Column::from_name(placeholder_name.clone()), - ColumnUnnestType::Struct, - )); + self.columns_unnestings + .insert(Column::from_name(placeholder_name.clone()), None); Ok( get_struct_unnested_columns(&placeholder_name, &inner_fields) .into_iter() @@ -399,39 +395,18 @@ impl<'a> RecursiveUnnestRewriter<'a> { expr_in_unnest.clone().alias(placeholder_name.clone()), ); - // Let post_unnest_column = Column::from_name(post_unnest_name); let post_unnest_expr = col(post_unnest_name.clone()).alias(alias_name); - match self + let list_unnesting = self .columns_unnestings - .iter_mut() - .find(|(inner_col, _)| inner_col == &placeholder_column) - { - // There is not unnesting done on this column yet - None => { - self.columns_unnestings.push(( - Column::from_name(placeholder_name.clone()), - ColumnUnnestType::List(vec![ColumnUnnestList { - output_column: Column::from_name(post_unnest_name), - depth: level, - }]), - )); - } - // Some unnesting(at some level) has been done on this column - // e.g select unnest(column3), unnest(unnest(column3)) - Some((_, unnesting)) => match unnesting { - ColumnUnnestType::List(list) => { - let unnesting = ColumnUnnestList { - output_column: Column::from_name(post_unnest_name), - depth: level, - }; - if !list.contains(&unnesting) { - list.push(unnesting); - } - } - _ => { - return internal_err!("not reached"); - } - }, + .entry(placeholder_column) + .or_insert(Some(vec![])); + let unnesting = ColumnUnnestList { + output_column: Column::from_name(post_unnest_name), + depth: level, + }; + let list_unnestings = list_unnesting.as_mut().unwrap(); + if !list_unnestings.contains(&unnesting) { + list_unnestings.push(unnesting); } Ok(vec![post_unnest_expr]) } @@ -478,8 +453,7 @@ impl<'a> TreeNodeRewriter for RecursiveUnnestRewriter<'a> { } /// The rewriting only happens when the traversal has reached the top-most unnest expr - /// within a sequence of consecutive unnest exprs. - /// node, for example given a stack of expr + /// within a sequence of consecutive unnest exprs node /// /// For example an expr of **unnest(unnest(column1)) + unnest(unnest(unnest(column2)))** /// ```text @@ -560,7 +534,7 @@ impl<'a> TreeNodeRewriter for RecursiveUnnestRewriter<'a> { // For column exprs that are not descendants of any unnest node // retain their projection // e.g given expr tree unnest(col_a) + col_b, we have to retain projection of col_b - // this condition can be checked by maintaining an Option + // this condition can be checked by maintaining an Option if matches!(&expr, Expr::Column(_)) && self.top_most_unnest.is_none() { push_projection_dedupl(self.inner_projection_exprs, expr.clone()); } @@ -589,7 +563,7 @@ fn push_projection_dedupl(projection: &mut Vec, expr: Expr) { /// is done only for the bottom expression pub(crate) fn rewrite_recursive_unnest_bottom_up( input: &LogicalPlan, - unnest_placeholder_columns: &mut Vec<(Column, ColumnUnnestType)>, + unnest_placeholder_columns: &mut IndexMap>>, inner_projection_exprs: &mut Vec, original_expr: &Expr, ) -> Result> { @@ -610,8 +584,8 @@ pub(crate) fn rewrite_recursive_unnest_bottom_up( // TODO: This can be resolved after this issue is resolved: https://github.com/apache/datafusion/issues/10102 // // The transformation looks like: - // - unnest(array_col) will be transformed into unnest(array_col) - // - unnest(array_col) + 1 will be transformed into unnest(array_col) + 1 + // - unnest(array_col) will be transformed into Column("unnest_place_holder(array_col)") + // - unnest(array_col) + 1 will be transformed into Column("unnest_place_holder(array_col) + 1") let Transformed { data: transformed_expr, transformed, @@ -647,17 +621,33 @@ mod tests { use arrow_schema::Fields; use datafusion_common::{Column, DFSchema, Result}; use datafusion_expr::{ - col, lit, unnest, ColumnUnnestType, EmptyRelation, LogicalPlan, + col, lit, unnest, ColumnUnnestList, EmptyRelation, LogicalPlan, }; use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_functions_aggregate::expr_fn::count; + use indexmap::IndexMap; use crate::utils::{resolve_positions_to_exprs, rewrite_recursive_unnest_bottom_up}; - fn column_unnests_eq(l: Vec<(&str, &str)>, r: &[(Column, ColumnUnnestType)]) { - let r_formatted: Vec = - r.iter().map(|i| format!("{}|{}", i.0, i.1)).collect(); - let l_formatted: Vec = - l.iter().map(|i| format!("{}|{}", i.0, i.1)).collect(); + + fn column_unnests_eq( + l: Vec<&str>, + r: &IndexMap>>, + ) { + let r_formatted: Vec = r + .iter() + .map(|i| match i.1 { + None => format!("{}", i.0), + Some(vec) => format!( + "{}=>[{}]", + i.0, + vec.iter() + .map(|i| format!("{}", i)) + .collect::>() + .join(", ") + ), + }) + .collect(); + let l_formatted: Vec = l.iter().map(|i| i.to_string()).collect(); assert_eq!(l_formatted, r_formatted); } @@ -687,7 +677,7 @@ mod tests { schema: Arc::new(dfschema), }); - let mut unnest_placeholder_columns = vec![]; + let mut unnest_placeholder_columns = IndexMap::new(); let mut inner_projection_exprs = vec![]; // unnest(unnest(3d_col)) + unnest(unnest(3d_col)) @@ -712,10 +702,9 @@ mod tests { .add(col("i64_col"))] ); column_unnests_eq( - vec![( - "unnest_placeholder(3d_col)", - "List([unnest_placeholder(3d_col,depth=2)|depth=2])", - )], + vec![ + "unnest_placeholder(3d_col)=>[unnest_placeholder(3d_col,depth=2)|depth=2]", + ], &unnest_placeholder_columns, ); @@ -746,9 +735,7 @@ mod tests { ] ); column_unnests_eq( - vec![("unnest_placeholder(3d_col)", - "List([unnest_placeholder(3d_col,depth=2)|depth=2, unnest_placeholder(3d_col,depth=1)|depth=1])"), - ], + vec!["unnest_placeholder(3d_col)=>[unnest_placeholder(3d_col,depth=2)|depth=2, unnest_placeholder(3d_col,depth=1)|depth=1]"], &unnest_placeholder_columns, ); // Still reference struct_col in original schema but with alias, @@ -794,7 +781,7 @@ mod tests { schema: Arc::new(dfschema), }); - let mut unnest_placeholder_columns = vec![]; + let mut unnest_placeholder_columns = IndexMap::new(); let mut inner_projection_exprs = vec![]; // unnest(struct_col) @@ -813,7 +800,7 @@ mod tests { ] ); column_unnests_eq( - vec![("unnest_placeholder(struct_col)", "Struct")], + vec!["unnest_placeholder(struct_col)"], &unnest_placeholder_columns, ); // Still reference struct_col in original schema but with alias, @@ -833,11 +820,8 @@ mod tests { )?; column_unnests_eq( vec![ - ("unnest_placeholder(struct_col)", "Struct"), - ( - "unnest_placeholder(array_col)", - "List([unnest_placeholder(array_col,depth=1)|depth=1])", - ), + "unnest_placeholder(struct_col)", + "unnest_placeholder(array_col)=>[unnest_placeholder(array_col,depth=1)|depth=1]", ], &unnest_placeholder_columns, ); @@ -860,24 +844,44 @@ mod tests { ] ); - // A nested structure struct[[]] + Ok(()) + } + + // Unnest -> field access -> unnest + #[test] + fn test_transform_non_consecutive_unnests() -> Result<()> { + // List of struct + // [struct{'subfield1':list(i64), 'subfield2':list(utf8)}] let schema = Schema::new(vec![ Field::new( - "struct_col", // {array_col: [1,2,3]} - ArrowDataType::Struct(Fields::from(vec![Field::new( - "matrix", - ArrowDataType::List(Arc::new(Field::new( - "matrix_row", - ArrowDataType::List(Arc::new(Field::new( - "item", - ArrowDataType::Int64, + "struct_list", + ArrowDataType::List(Arc::new(Field::new( + "element", + ArrowDataType::Struct(Fields::from(vec![ + Field::new( + // list of i64 + "subfield1", + ArrowDataType::List(Arc::new(Field::new( + "i64_element", + ArrowDataType::Int64, + true, + ))), true, - ))), - true, - ))), + ), + Field::new( + // list of utf8 + "subfield2", + ArrowDataType::List(Arc::new(Field::new( + "utf8_element", + ArrowDataType::Utf8, + true, + ))), + true, + ), + ])), true, - )])), - false, + ))), + true, ), Field::new("int_col", ArrowDataType::Int32, false), ]); @@ -889,39 +893,69 @@ mod tests { schema: Arc::new(dfschema), }); - let mut unnest_placeholder_columns = vec![]; + let mut unnest_placeholder_columns = IndexMap::new(); let mut inner_projection_exprs = vec![]; // An expr with multiple unnest - let original_expr = unnest(unnest(col("struct_col").field("matrix"))); + let select_expr1 = unnest(unnest(col("struct_list")).field("subfield1")); let transformed_exprs = rewrite_recursive_unnest_bottom_up( &input, &mut unnest_placeholder_columns, &mut inner_projection_exprs, - &original_expr, + &select_expr1, )?; // Only the inner most/ bottom most unnest is transformed assert_eq!( transformed_exprs, - vec![col("unnest_placeholder(struct_col[matrix],depth=2)") - .alias("UNNEST(UNNEST(struct_col[matrix]))")] + vec![unnest( + col("unnest_placeholder(struct_list,depth=1)") + .alias("UNNEST(struct_list)") + .field("subfield1") + )] ); - // TODO: add a test case where - // unnest -> field access -> unnest column_unnests_eq( - vec![( - "unnest_placeholder(struct_col[matrix])", - "List([unnest_placeholder(struct_col[matrix],depth=2)|depth=2])", - )], + vec![ + "unnest_placeholder(struct_list)=>[unnest_placeholder(struct_list,depth=1)|depth=1]", + ], + &unnest_placeholder_columns, + ); + + assert_eq!( + inner_projection_exprs, + vec![col("struct_list").alias("unnest_placeholder(struct_list)")] + ); + + // continue rewrite another expr in select + let select_expr2 = unnest(unnest(col("struct_list")).field("subfield2")); + let transformed_exprs = rewrite_recursive_unnest_bottom_up( + &input, + &mut unnest_placeholder_columns, + &mut inner_projection_exprs, + &select_expr2, + )?; + // Only the inner most/ bottom most unnest is transformed + assert_eq!( + transformed_exprs, + vec![unnest( + col("unnest_placeholder(struct_list,depth=1)") + .alias("UNNEST(struct_list)") + .field("subfield2") + )] + ); + + // unnest place holder columns remain the same + // because expr1 and expr2 derive from the same unnest result + column_unnests_eq( + vec![ + "unnest_placeholder(struct_list)=>[unnest_placeholder(struct_list,depth=1)|depth=1]", + ], &unnest_placeholder_columns, ); assert_eq!( inner_projection_exprs, - vec![col("struct_col") - .field("matrix") - .alias("unnest_placeholder(struct_col[matrix])"),] + vec![col("struct_list").alias("unnest_placeholder(struct_list)")] ); Ok(()) From 373fe23733d97dfac6195d77ebca0646fe9c37d0 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 20 Oct 2024 08:45:51 -0400 Subject: [PATCH 028/110] Update version to 42.1.0, add CHANGELOG (#12986) (#12989) * Update version to 42.1.0, add CHANGELOG (#12986) * CHANGELOG for 42.1.0 * Update version to 42.1.0 * Update datafusion-cli/Cargo.lock * update config docs * update datafusion-cli --- Cargo.toml | 48 ++++----- datafusion-cli/Cargo.lock | 166 +++++++++++++++--------------- datafusion-cli/Cargo.toml | 4 +- dev/changelog/42.1.0.md | 42 ++++++++ docs/source/user-guide/configs.md | 2 +- 5 files changed, 152 insertions(+), 110 deletions(-) create mode 100644 dev/changelog/42.1.0.md diff --git a/Cargo.toml b/Cargo.toml index 2c142c87c892..63bfb7fce413 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,7 +59,7 @@ license = "Apache-2.0" readme = "README.md" repository = "https://github.com/apache/datafusion" rust-version = "1.79" -version = "42.0.0" +version = "42.1.0" [workspace.dependencies] # We turn off default-features for some dependencies here so the workspaces which inherit them can @@ -92,29 +92,29 @@ bytes = "1.4" chrono = { version = "0.4.38", default-features = false } ctor = "0.2.0" dashmap = "6.0.1" -datafusion = { path = "datafusion/core", version = "42.0.0", default-features = false } -datafusion-catalog = { path = "datafusion/catalog", version = "42.0.0" } -datafusion-common = { path = "datafusion/common", version = "42.0.0", default-features = false } -datafusion-common-runtime = { path = "datafusion/common-runtime", version = "42.0.0" } -datafusion-execution = { path = "datafusion/execution", version = "42.0.0" } -datafusion-expr = { path = "datafusion/expr", version = "42.0.0" } -datafusion-expr-common = { path = "datafusion/expr-common", version = "42.0.0" } -datafusion-functions = { path = "datafusion/functions", version = "42.0.0" } -datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "42.0.0" } -datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "42.0.0" } -datafusion-functions-nested = { path = "datafusion/functions-nested", version = "42.0.0" } -datafusion-functions-window = { path = "datafusion/functions-window", version = "42.0.0" } -datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "42.0.0" } -datafusion-optimizer = { path = "datafusion/optimizer", version = "42.0.0", default-features = false } -datafusion-physical-expr = { path = "datafusion/physical-expr", version = "42.0.0", default-features = false } -datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "42.0.0", default-features = false } -datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "42.0.0" } -datafusion-physical-plan = { path = "datafusion/physical-plan", version = "42.0.0" } -datafusion-proto = { path = "datafusion/proto", version = "42.0.0" } -datafusion-proto-common = { path = "datafusion/proto-common", version = "42.0.0" } -datafusion-sql = { path = "datafusion/sql", version = "42.0.0" } -datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "42.0.0" } -datafusion-substrait = { path = "datafusion/substrait", version = "42.0.0" } +datafusion = { path = "datafusion/core", version = "42.1.0", default-features = false } +datafusion-catalog = { path = "datafusion/catalog", version = "42.1.0" } +datafusion-common = { path = "datafusion/common", version = "42.1.0", default-features = false } +datafusion-common-runtime = { path = "datafusion/common-runtime", version = "42.1.0" } +datafusion-execution = { path = "datafusion/execution", version = "42.1.0" } +datafusion-expr = { path = "datafusion/expr", version = "42.1.0" } +datafusion-expr-common = { path = "datafusion/expr-common", version = "42.1.0" } +datafusion-functions = { path = "datafusion/functions", version = "42.1.0" } +datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "42.1.0" } +datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "42.1.0" } +datafusion-functions-nested = { path = "datafusion/functions-nested", version = "42.1.0" } +datafusion-functions-window = { path = "datafusion/functions-window", version = "42.1.0" } +datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "42.1.0" } +datafusion-optimizer = { path = "datafusion/optimizer", version = "42.1.0", default-features = false } +datafusion-physical-expr = { path = "datafusion/physical-expr", version = "42.1.0", default-features = false } +datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "42.1.0", default-features = false } +datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "42.1.0" } +datafusion-physical-plan = { path = "datafusion/physical-plan", version = "42.1.0" } +datafusion-proto = { path = "datafusion/proto", version = "42.1.0" } +datafusion-proto-common = { path = "datafusion/proto-common", version = "42.1.0" } +datafusion-sql = { path = "datafusion/sql", version = "42.1.0" } +datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "42.1.0" } +datafusion-substrait = { path = "datafusion/substrait", version = "42.1.0" } doc-comment = "0.3" env_logger = "0.11" futures = "0.3" diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 08d5d4843c62..612209fdd922 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -406,9 +406,9 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.13" +version = "0.4.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e614738943d3f68c628ae3dbce7c3daffb196665f82f8c8ea6b65de73c79429" +checksum = "103db485efc3e41214fe4fda9f3dbeae2eb9082f48fd236e6095627a9422066e" dependencies = [ "bzip2", "flate2", @@ -523,9 +523,9 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.45.0" +version = "1.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e33ae899566f3d395cbf42858e433930682cc9c1889fa89318896082fef45efb" +checksum = "0dc2faec3205d496c7e57eff685dd944203df7ce16a4116d0281c44021788a7b" dependencies = [ "aws-credential-types", "aws-runtime", @@ -545,9 +545,9 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.46.0" +version = "1.47.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f39c09e199ebd96b9f860b0fce4b6625f211e064ad7c8693b72ecf7ef03881e0" +checksum = "c93c241f52bc5e0476e259c953234dab7e2a35ee207ee202e86c0095ec4951dc" dependencies = [ "aws-credential-types", "aws-runtime", @@ -567,9 +567,9 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.45.0" +version = "1.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d95f93a98130389eb6233b9d615249e543f6c24a68ca1f109af9ca5164a8765" +checksum = "b259429be94a3459fa1b00c5684faee118d74f9577cc50aebadc36e507c63b5f" dependencies = [ "aws-credential-types", "aws-runtime", @@ -663,9 +663,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.7.1" +version = "1.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1ce695746394772e7000b39fe073095db6d45a862d0767dd5ad0ac0d7f8eb87" +checksum = "a065c0fe6fdbdf9f11817eb68582b2ab4aff9e9c39e986ae48f7ec576c6322db" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -678,7 +678,7 @@ dependencies = [ "http-body 0.4.6", "http-body 1.0.1", "httparse", - "hyper 0.14.30", + "hyper 0.14.31", "hyper-rustls 0.24.2", "once_cell", "pin-project-lite", @@ -917,9 +917,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.28" +version = "1.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e80e3b6a3ab07840e1cae9b0666a63970dc28e8ed5ffbcdacbfc760c281bfc1" +checksum = "b16803a61b81d9eabb7eae2588776c4c1e584b738ede45fdbb4c972cec1e9945" dependencies = [ "jobserver", "libc", @@ -974,9 +974,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.19" +version = "4.5.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7be5744db7978a28d9df86a214130d106a89ce49644cbc4e3f0c22c3fba30615" +checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8" dependencies = [ "clap_builder", "clap_derive", @@ -984,9 +984,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.19" +version = "4.5.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5fbc17d3ef8278f55b282b2a2e75ae6f6c7d4bb70ed3d0382375104bfafdb4b" +checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54" dependencies = [ "anstream", "anstyle", @@ -1162,9 +1162,9 @@ dependencies = [ [[package]] name = "dary_heap" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7762d17f1241643615821a8455a0b2c3e803784b058693d990b11f2dce25a0ca" +checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728" [[package]] name = "dashmap" @@ -1182,7 +1182,7 @@ dependencies = [ [[package]] name = "datafusion" -version = "42.0.0" +version = "42.1.0" dependencies = [ "ahash", "apache-avro", @@ -1239,7 +1239,7 @@ dependencies = [ [[package]] name = "datafusion-catalog" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow-schema", "async-trait", @@ -1252,7 +1252,7 @@ dependencies = [ [[package]] name = "datafusion-cli" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow", "assert_cmd", @@ -1282,7 +1282,7 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "42.0.0" +version = "42.1.0" dependencies = [ "ahash", "apache-avro", @@ -1305,7 +1305,7 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" -version = "42.0.0" +version = "42.1.0" dependencies = [ "log", "tokio", @@ -1313,7 +1313,7 @@ dependencies = [ [[package]] name = "datafusion-execution" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow", "chrono", @@ -1332,7 +1332,7 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "42.0.0" +version = "42.1.0" dependencies = [ "ahash", "arrow", @@ -1354,7 +1354,7 @@ dependencies = [ [[package]] name = "datafusion-expr-common" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow", "datafusion-common", @@ -1363,7 +1363,7 @@ dependencies = [ [[package]] name = "datafusion-functions" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow", "arrow-buffer", @@ -1388,7 +1388,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" -version = "42.0.0" +version = "42.1.0" dependencies = [ "ahash", "arrow", @@ -1407,7 +1407,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate-common" -version = "42.0.0" +version = "42.1.0" dependencies = [ "ahash", "arrow", @@ -1419,7 +1419,7 @@ dependencies = [ [[package]] name = "datafusion-functions-nested" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow", "arrow-array", @@ -1440,7 +1440,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window" -version = "42.0.0" +version = "42.1.0" dependencies = [ "datafusion-common", "datafusion-expr", @@ -1453,7 +1453,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window-common" -version = "42.0.0" +version = "42.1.0" dependencies = [ "datafusion-common", "datafusion-physical-expr-common", @@ -1461,7 +1461,7 @@ dependencies = [ [[package]] name = "datafusion-optimizer" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow", "async-trait", @@ -1479,7 +1479,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "42.0.0" +version = "42.1.0" dependencies = [ "ahash", "arrow", @@ -1505,7 +1505,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-common" -version = "42.0.0" +version = "42.1.0" dependencies = [ "ahash", "arrow", @@ -1517,7 +1517,7 @@ dependencies = [ [[package]] name = "datafusion-physical-optimizer" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow", "arrow-schema", @@ -1531,7 +1531,7 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" -version = "42.0.0" +version = "42.1.0" dependencies = [ "ahash", "arrow", @@ -1564,7 +1564,7 @@ dependencies = [ [[package]] name = "datafusion-sql" -version = "42.0.0" +version = "42.1.0" dependencies = [ "arrow", "arrow-array", @@ -2066,9 +2066,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.30" +version = "0.14.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a152ddd61dfaec7273fe8419ab357f33aee0d914c5f4efbf0d96fa749eea5ec9" +checksum = "8c08302e8fa335b151b788c775ff56e7a03ae64ff85c548ee820fecb70356e85" dependencies = [ "bytes", "futures-channel", @@ -2090,9 +2090,9 @@ dependencies = [ [[package]] name = "hyper" -version = "1.4.1" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" +checksum = "bbbff0a806a4728c99295b254c8838933b5b082d75e3cb70c8dab21fdfbcfa9a" dependencies = [ "bytes", "futures-channel", @@ -2116,7 +2116,7 @@ checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" dependencies = [ "futures-util", "http 0.2.12", - "hyper 0.14.30", + "hyper 0.14.31", "log", "rustls 0.21.12", "rustls-native-certs 0.6.3", @@ -2132,9 +2132,9 @@ checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" dependencies = [ "futures-util", "http 1.1.0", - "hyper 1.4.1", + "hyper 1.5.0", "hyper-util", - "rustls 0.23.14", + "rustls 0.23.15", "rustls-native-certs 0.8.0", "rustls-pki-types", "tokio", @@ -2153,7 +2153,7 @@ dependencies = [ "futures-util", "http 1.1.0", "http-body 1.0.1", - "hyper 1.4.1", + "hyper 1.5.0", "pin-project-lite", "socket2", "tokio", @@ -2260,9 +2260,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.70" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" +checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" dependencies = [ "wasm-bindgen", ] @@ -2339,9 +2339,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.159" +version = "0.2.161" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" +checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" [[package]] name = "libflate" @@ -2625,7 +2625,7 @@ dependencies = [ "chrono", "futures", "humantime", - "hyper 1.4.1", + "hyper 1.5.0", "itertools", "md-5", "parking_lot", @@ -2879,9 +2879,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.87" +version = "1.0.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3e4daa0dcf6feba26f985457cdf104d4b4256fc5a09547140f3631bb076b19a" +checksum = "7c3a7fc5db1e57d5a779a352c8cdb57b29aa4c40cc69c3a68a7fedc815fbf2f9" dependencies = [ "unicode-ident", ] @@ -2913,7 +2913,7 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash", - "rustls 0.23.14", + "rustls 0.23.15", "socket2", "thiserror", "tokio", @@ -2930,7 +2930,7 @@ dependencies = [ "rand", "ring", "rustc-hash", - "rustls 0.23.14", + "rustls 0.23.15", "slab", "thiserror", "tinyvec", @@ -3074,7 +3074,7 @@ dependencies = [ "http 1.1.0", "http-body 1.0.1", "http-body-util", - "hyper 1.4.1", + "hyper 1.5.0", "hyper-rustls 0.27.3", "hyper-util", "ipnet", @@ -3085,7 +3085,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.14", + "rustls 0.23.15", "rustls-native-certs 0.8.0", "rustls-pemfile 2.2.0", "rustls-pki-types", @@ -3204,9 +3204,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.14" +version = "0.23.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "415d9944693cb90382053259f89fbb077ea730ad7273047ec63b19bc9b160ba8" +checksum = "5fbb44d7acc4e873d613422379f69f237a1b141928c02f6bc6ccfddddc2d7993" dependencies = [ "once_cell", "ring", @@ -3261,9 +3261,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e696e35370c65c9c541198af4543ccd580cf17fc25d8e05c5a242b202488c55" +checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b" [[package]] name = "rustls-webpki" @@ -3288,9 +3288,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" +checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" [[package]] name = "rustyline" @@ -3411,9 +3411,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.128" +version = "1.0.130" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +checksum = "610f75ff4a8e3cb29b85da56eabdd1bff5b06739059a4b8e2967fef32e5d9944" dependencies = [ "itoa", "memchr", @@ -3772,7 +3772,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.14", + "rustls 0.23.15", "rustls-pki-types", "tokio", ] @@ -3950,9 +3950,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" dependencies = [ "getrandom", "serde", @@ -4006,9 +4006,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" +checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" dependencies = [ "cfg-if", "once_cell", @@ -4017,9 +4017,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" +checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" dependencies = [ "bumpalo", "log", @@ -4032,9 +4032,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.43" +version = "0.4.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed" +checksum = "cc7ec4f8827a71586374db3e87abdb5a2bb3a15afed140221307c3ec06b1f63b" dependencies = [ "cfg-if", "js-sys", @@ -4044,9 +4044,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" +checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -4054,9 +4054,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" +checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", @@ -4067,9 +4067,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" +checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" [[package]] name = "wasm-streams" @@ -4086,9 +4086,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.70" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" +checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" dependencies = [ "js-sys", "wasm-bindgen", diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index fe929495aae6..8e4352612889 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -18,7 +18,7 @@ [package] name = "datafusion-cli" description = "Command Line Client for DataFusion query engine." -version = "42.0.0" +version = "42.1.0" authors = ["Apache DataFusion "] edition = "2021" keywords = ["arrow", "datafusion", "query", "sql"] @@ -39,7 +39,7 @@ aws-sdk-sts = "1.43.0" # end pin aws-sdk crates aws-credential-types = "1.2.0" clap = { version = "4.5.16", features = ["derive", "cargo"] } -datafusion = { path = "../datafusion/core", version = "42.0.0", features = [ +datafusion = { path = "../datafusion/core", version = "42.1.0", features = [ "avro", "crypto_expressions", "datetime_expressions", diff --git a/dev/changelog/42.1.0.md b/dev/changelog/42.1.0.md new file mode 100644 index 000000000000..cf4f911150ac --- /dev/null +++ b/dev/changelog/42.1.0.md @@ -0,0 +1,42 @@ + + +# Apache DataFusion 42.1.0 Changelog + +This release consists of 5 commits from 4 contributors. See credits at the end of this changelog for more information. + +**Other:** + +- Backport update to arrow 53.1.0 on branch-42 [#12977](https://github.com/apache/datafusion/pull/12977) (alamb) +- Backport "Provide field and schema metadata missing on cross joins, and union with null fields" (#12729) [#12974](https://github.com/apache/datafusion/pull/12974) (matthewmturner) +- Backport "physical-plan: Cast nested group values back to dictionary if necessary" (#12586) [#12976](https://github.com/apache/datafusion/pull/12976) (matthewmturner) +- backport-to-DF-42: Provide field and schema metadata missing on distinct aggregations [#12975](https://github.com/apache/datafusion/pull/12975) (Xuanwo) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. + +``` + 2 Matthew Turner + 1 Andrew Lamb + 1 Andy Grove + 1 Xuanwo +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index c61a7b673334..10917932482c 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -66,7 +66,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.parquet.statistics_enabled | page | (writing) Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | | datafusion.execution.parquet.max_statistics_size | 4096 | (writing) Sets max statistics size for any column. If NULL, uses default parquet writer setting | | datafusion.execution.parquet.max_row_group_size | 1048576 | (writing) Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. | -| datafusion.execution.parquet.created_by | datafusion version 42.0.0 | (writing) Sets "created by" property | +| datafusion.execution.parquet.created_by | datafusion version 42.1.0 | (writing) Sets "created by" property | | datafusion.execution.parquet.column_index_truncate_length | 64 | (writing) Sets column index truncate length | | datafusion.execution.parquet.data_page_row_count_limit | 20000 | (writing) Sets best effort maximum number of rows in data page | | datafusion.execution.parquet.encoding | NULL | (writing) Sets default encoding for any column. Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting | From 8d4614d6c43104a13b42c062d957843f25ee32db Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Sun, 20 Oct 2024 05:46:25 -0700 Subject: [PATCH 029/110] Don't preserve functional dependency when generating UNION logical plan (#44) (#12979) * Don't preserve functional dependency when generating UNION logical plan * Remove extra lines --- datafusion/core/src/dataframe/mod.rs | 48 +++++++++++++++++++++ datafusion/expr/src/logical_plan/builder.rs | 11 +++-- 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 8a0829cd5e4b..4feadd260d7f 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -2623,6 +2623,54 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_aggregate_with_union() -> Result<()> { + let df = test_table().await?; + + let df1 = df + .clone() + // GROUP BY `c1` + .aggregate(vec![col("c1")], vec![min(col("c2"))])? + // SELECT `c1` , min(c2) as `result` + .select(vec![col("c1"), min(col("c2")).alias("result")])?; + let df2 = df + .clone() + // GROUP BY `c1` + .aggregate(vec![col("c1")], vec![max(col("c3"))])? + // SELECT `c1` , max(c3) as `result` + .select(vec![col("c1"), max(col("c3")).alias("result")])?; + + let df_union = df1.union(df2)?; + let df = df_union + // GROUP BY `c1` + .aggregate( + vec![col("c1")], + vec![sum(col("result")).alias("sum_result")], + )? + // SELECT `c1`, sum(result) as `sum_result` + .select(vec![(col("c1")), col("sum_result")])?; + + let df_results = df.collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + [ + "+----+------------+", + "| c1 | sum_result |", + "+----+------------+", + "| a | 84 |", + "| b | 69 |", + "| c | 124 |", + "| d | 126 |", + "| e | 121 |", + "+----+------------+" + ], + &df_results + ); + + Ok(()) + } + #[tokio::test] async fn test_aggregate_subexpr() -> Result<()> { let df = test_table().await?; diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index f119a2ade827..21304068a8ab 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -51,8 +51,8 @@ use datafusion_common::display::ToStringifiedPlan; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ get_target_functional_dependencies, internal_err, not_impl_err, plan_datafusion_err, - plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, - TableReference, ToDFSchema, UnnestOptions, + plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies, + Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions, }; use datafusion_expr_common::type_coercion::binary::type_union_resolution; @@ -1386,7 +1386,12 @@ pub fn validate_unique_names<'a>( pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result { // Temporarily use the schema from the left input and later rely on the analyzer to // coerce the two schemas into a common one. - let schema = Arc::clone(left_plan.schema()); + + // Functional Dependencies doesn't preserve after UNION operation + let schema = (**left_plan.schema()).clone(); + let schema = + Arc::new(schema.with_functional_dependencies(FunctionalDependencies::empty())?); + Ok(LogicalPlan::Union(Union { inputs: vec![Arc::new(left_plan), Arc::new(right_plan)], schema, From 972e3abea4286b0d06c44498d576c8498ddd3be2 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Sun, 20 Oct 2024 14:47:59 +0200 Subject: [PATCH 030/110] feat: Decorrelate more predicate subqueries (#12945) * Decorrelate more predicate subqueries * Added sqllogictest explain tests --- datafusion/core/tests/tpcds_planning.rs | 3 - .../src/decorrelate_predicate_subquery.rs | 500 ++++++++---------- .../sqllogictest/test_files/subquery.slt | 170 +++++- .../sqllogictest/test_files/tpch/q20.slt.part | 8 +- .../tests/cases/roundtrip_logical_plan.rs | 18 +- 5 files changed, 405 insertions(+), 294 deletions(-) diff --git a/datafusion/core/tests/tpcds_planning.rs b/datafusion/core/tests/tpcds_planning.rs index b99bc2680044..6beb29183483 100644 --- a/datafusion/core/tests/tpcds_planning.rs +++ b/datafusion/core/tests/tpcds_planning.rs @@ -571,7 +571,6 @@ async fn tpcds_physical_q9() -> Result<()> { create_physical_plan(9).await } -#[ignore] // Physical plan does not support logical expression Exists() #[tokio::test] async fn tpcds_physical_q10() -> Result<()> { create_physical_plan(10).await @@ -697,7 +696,6 @@ async fn tpcds_physical_q34() -> Result<()> { create_physical_plan(34).await } -#[ignore] // Physical plan does not support logical expression Exists() #[tokio::test] async fn tpcds_physical_q35() -> Result<()> { create_physical_plan(35).await @@ -750,7 +748,6 @@ async fn tpcds_physical_q44() -> Result<()> { create_physical_plan(44).await } -#[ignore] // Physical plan does not support logical expression () #[tokio::test] async fn tpcds_physical_q45() -> Result<()> { create_physical_plan(45).await diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index d1ac80003ba7..cdffa8c645ea 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -17,6 +17,7 @@ //! [`DecorrelatePredicateSubquery`] converts `IN`/`EXISTS` subquery predicates to `SEMI`/`ANTI` joins use std::collections::BTreeSet; +use std::iter; use std::ops::Deref; use std::sync::Arc; @@ -27,16 +28,17 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{internal_err, plan_err, Result}; +use datafusion_common::{internal_err, plan_err, Column, Result}; use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; -use datafusion_expr::utils::{conjunction, split_conjunction, split_conjunction_owned}; +use datafusion_expr::utils::{conjunction, split_conjunction_owned}; use datafusion_expr::{ - exists, in_subquery, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter, + exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator, }; +use itertools::chain; use log::debug; /// Optimizer rule for rewriting predicate(IN/EXISTS) subquery to left semi/anti joins @@ -48,79 +50,6 @@ impl DecorrelatePredicateSubquery { pub fn new() -> Self { Self::default() } - - fn rewrite_subquery( - &self, - mut subquery: Subquery, - config: &dyn OptimizerConfig, - ) -> Result { - subquery.subquery = Arc::new( - self.rewrite(Arc::unwrap_or_clone(subquery.subquery), config)? - .data, - ); - Ok(subquery) - } - - /// Finds expressions that have the predicate subqueries (and recurses when found) - /// - /// # Arguments - /// - /// * `predicate` - A conjunction to split and search - /// * `optimizer_config` - For generating unique subquery aliases - /// - /// Returns a tuple (subqueries, non-subquery expressions) - fn extract_subquery_exprs( - &self, - predicate: Expr, - config: &dyn OptimizerConfig, - ) -> Result<(Vec, Vec)> { - let filters = split_conjunction_owned(predicate); // TODO: add ExistenceJoin to support disjunctions - - let mut subqueries = vec![]; - let mut others = vec![]; - for it in filters.into_iter() { - match it { - Expr::Not(not_expr) => match *not_expr { - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => { - let new_subquery = self.rewrite_subquery(subquery, config)?; - subqueries.push(SubqueryInfo::new_with_in_expr( - new_subquery, - *expr, - !negated, - )); - } - Expr::Exists(Exists { subquery, negated }) => { - let new_subquery = self.rewrite_subquery(subquery, config)?; - subqueries.push(SubqueryInfo::new(new_subquery, !negated)); - } - expr => others.push(not(expr)), - }, - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => { - let new_subquery = self.rewrite_subquery(subquery, config)?; - subqueries.push(SubqueryInfo::new_with_in_expr( - new_subquery, - *expr, - negated, - )); - } - Expr::Exists(Exists { subquery, negated }) => { - let new_subquery = self.rewrite_subquery(subquery, config)?; - subqueries.push(SubqueryInfo::new(new_subquery, negated)); - } - expr => others.push(expr), - } - } - - Ok((subqueries, others)) - } } impl OptimizerRule for DecorrelatePredicateSubquery { @@ -133,69 +62,51 @@ impl OptimizerRule for DecorrelatePredicateSubquery { plan: LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { + let plan = plan + .map_subqueries(|subquery| { + subquery.transform_down(|p| self.rewrite(p, config)) + })? + .data; + let LogicalPlan::Filter(filter) = plan else { return Ok(Transformed::no(plan)); }; - // if there are no subqueries in the predicate, return the original plan - let has_subqueries = - split_conjunction(&filter.predicate) - .iter() - .any(|expr| match expr { - Expr::Not(not_expr) => { - matches!(not_expr.as_ref(), Expr::InSubquery(_) | Expr::Exists(_)) - } - Expr::InSubquery(_) | Expr::Exists(_) => true, - _ => false, - }); - - if !has_subqueries { + if !has_subquery(&filter.predicate) { return Ok(Transformed::no(LogicalPlan::Filter(filter))); } - let Filter { - predicate, input, .. - } = filter; - let (subqueries, mut other_exprs) = - self.extract_subquery_exprs(predicate, config)?; - if subqueries.is_empty() { + let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) = + split_conjunction_owned(filter.predicate) + .into_iter() + .partition(has_subquery); + + if with_subqueries.is_empty() { return internal_err!( "can not find expected subqueries in DecorrelatePredicateSubquery" ); } // iterate through all exists clauses in predicate, turning each into a join - let mut cur_input = Arc::unwrap_or_clone(input); - for subquery in subqueries { - if let Some(plan) = - build_join(&subquery, &cur_input, config.alias_generator())? - { - cur_input = plan; - } else { - // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter - let sub_query_expr = match subquery { - SubqueryInfo { - query, - where_in_expr: Some(expr), - negated: false, - } => in_subquery(expr, query.subquery), - SubqueryInfo { - query, - where_in_expr: Some(expr), - negated: true, - } => not_in_subquery(expr, query.subquery), - SubqueryInfo { - query, - where_in_expr: None, - negated: false, - } => exists(query.subquery), - SubqueryInfo { - query, - where_in_expr: None, - negated: true, - } => not_exists(query.subquery), - }; - other_exprs.push(sub_query_expr); + let mut cur_input = Arc::unwrap_or_clone(filter.input); + for subquery_expr in with_subqueries { + match extract_subquery_info(subquery_expr) { + // The subquery expression is at the top level of the filter + SubqueryPredicate::Top(subquery) => { + match build_join_top(&subquery, &cur_input, config.alias_generator())? + { + Some(plan) => cur_input = plan, + // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter + None => other_exprs.push(subquery.expr()), + } + } + // The subquery expression is embedded within another expression + SubqueryPredicate::Embedded(expr) => { + let (plan, expr_without_subqueries) = + rewrite_inner_subqueries(cur_input, expr, config)?; + cur_input = plan; + other_exprs.push(expr_without_subqueries); + } } } @@ -216,6 +127,104 @@ impl OptimizerRule for DecorrelatePredicateSubquery { } } +fn rewrite_inner_subqueries( + outer: LogicalPlan, + expr: Expr, + config: &dyn OptimizerConfig, +) -> Result<(LogicalPlan, Expr)> { + let mut cur_input = outer; + let alias = config.alias_generator(); + let expr_without_subqueries = expr.transform(|e| match e { + Expr::Exists(Exists { + subquery: Subquery { subquery, .. }, + negated, + }) => { + match existence_join(&cur_input, Arc::clone(&subquery), None, negated, alias)? + { + Some((plan, exists_expr)) => { + cur_input = plan; + Ok(Transformed::yes(exists_expr)) + } + None if negated => Ok(Transformed::no(not_exists(subquery))), + None => Ok(Transformed::no(exists(subquery))), + } + } + Expr::InSubquery(InSubquery { + expr, + subquery: Subquery { subquery, .. }, + negated, + }) => { + let in_predicate = subquery + .head_output_expr()? + .map_or(plan_err!("single expression required."), |output_expr| { + Ok(Expr::eq(*expr.clone(), output_expr)) + })?; + match existence_join( + &cur_input, + Arc::clone(&subquery), + Some(in_predicate), + negated, + alias, + )? { + Some((plan, exists_expr)) => { + cur_input = plan; + Ok(Transformed::yes(exists_expr)) + } + None if negated => Ok(Transformed::no(not_in_subquery(*expr, subquery))), + None => Ok(Transformed::no(in_subquery(*expr, subquery))), + } + } + _ => Ok(Transformed::no(e)), + })?; + Ok((cur_input, expr_without_subqueries.data)) +} + +enum SubqueryPredicate { + // The subquery expression is at the top level of the filter and can be fully replaced by a + // semi/anti join + Top(SubqueryInfo), + // The subquery expression is embedded within another expression and is replaced using an + // existence join + Embedded(Expr), +} + +fn extract_subquery_info(expr: Expr) -> SubqueryPredicate { + match expr { + Expr::Not(not_expr) => match *not_expr { + Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + }) => SubqueryPredicate::Top(SubqueryInfo::new_with_in_expr( + subquery, *expr, !negated, + )), + Expr::Exists(Exists { subquery, negated }) => { + SubqueryPredicate::Top(SubqueryInfo::new(subquery, !negated)) + } + expr => SubqueryPredicate::Embedded(not(expr)), + }, + Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + }) => SubqueryPredicate::Top(SubqueryInfo::new_with_in_expr( + subquery, *expr, negated, + )), + Expr::Exists(Exists { subquery, negated }) => { + SubqueryPredicate::Top(SubqueryInfo::new(subquery, negated)) + } + expr => SubqueryPredicate::Embedded(expr), + } +} + +fn has_subquery(expr: &Expr) -> bool { + expr.exists(|e| match e { + Expr::InSubquery(_) | Expr::Exists(_) => Ok(true), + _ => Ok(false), + }) + .unwrap() +} + /// Optimize the subquery to left-anti/left-semi join. /// If the subquery is a correlated subquery, we need extract the join predicate from the subquery. /// @@ -246,7 +255,7 @@ impl OptimizerRule for DecorrelatePredicateSubquery { /// Projection: t2.id /// TableScan: t2 /// ``` -fn build_join( +fn build_join_top( query_info: &SubqueryInfo, left: &LogicalPlan, alias: &Arc, @@ -265,9 +274,70 @@ fn build_join( }) .map_or(Ok(None), |v| v.map(Some))?; + let join_type = match query_info.negated { + true => JoinType::LeftAnti, + false => JoinType::LeftSemi, + }; let subquery = query_info.query.subquery.as_ref(); let subquery_alias = alias.next("__correlated_sq"); + build_join(left, subquery, in_predicate_opt, join_type, subquery_alias) +} + +/// Existence join is emulated by adding a non-nullable column to the subquery and using a left join +/// and checking if the column is null or not. If native support is added for Existence/Mark then +/// we should use that instead. +/// +/// This is used to handle the case when the subquery is embedded in a more complex boolean +/// expression like and OR. For example +/// +/// `select t1.id from t1 where t1.id < 0 OR exists(SELECT t2.id FROM t2 WHERE t1.id = t2.id)` +/// +/// The optimized plan will be: +/// +/// ```text +/// Projection: t1.id +/// Filter: t1.id < 0 OR __correlated_sq_1.__exists IS NOT NULL +/// Left Join: Filter: t1.id = __correlated_sq_1.id +/// TableScan: t1 +/// SubqueryAlias: __correlated_sq_1 +/// Projection: t2.id, true as __exists +/// TableScan: t2 +fn existence_join( + left: &LogicalPlan, + subquery: Arc, + in_predicate_opt: Option, + negated: bool, + alias_generator: &Arc, +) -> Result> { + // Add non nullable column to emulate existence join + let always_true_expr = lit(true).alias("__exists"); + let cols = chain( + subquery.schema().columns().into_iter().map(Expr::Column), + iter::once(always_true_expr), + ); + let subquery = LogicalPlanBuilder::from(subquery).project(cols)?.build()?; + let alias = alias_generator.next("__correlated_sq"); + + let exists_col = Expr::Column(Column::new(Some(alias.clone()), "__exists")); + let exists_expr = if negated { + exists_col.is_null() + } else { + exists_col.is_not_null() + }; + + Ok( + build_join(left, &subquery, in_predicate_opt, JoinType::Left, alias)? + .map(|plan| (plan, exists_expr)), + ) +} +fn build_join( + left: &LogicalPlan, + subquery: &LogicalPlan, + in_predicate_opt: Option, + join_type: JoinType, + alias: String, +) -> Result> { let mut pull_up = PullUpCorrelatedExpr::new() .with_in_predicate_opt(in_predicate_opt.clone()) .with_exists_sub_query(in_predicate_opt.is_none()); @@ -278,7 +348,7 @@ fn build_join( } let sub_query_alias = LogicalPlanBuilder::from(new_plan) - .alias(subquery_alias.to_string())? + .alias(alias.to_string())? .build()?; let mut all_correlated_cols = BTreeSet::new(); pull_up @@ -289,8 +359,7 @@ fn build_join( // alias the join filter let join_filter_opt = conjunction(pull_up.join_filters).map_or(Ok(None), |filter| { - replace_qualified_name(filter, &all_correlated_cols, &subquery_alias) - .map(Option::Some) + replace_qualified_name(filter, &all_correlated_cols, &alias).map(Option::Some) })?; if let Some(join_filter) = match (join_filter_opt, in_predicate_opt) { @@ -302,7 +371,7 @@ fn build_join( right, })), ) => { - let right_col = create_col_from_scalar_expr(right.deref(), subquery_alias)?; + let right_col = create_col_from_scalar_expr(right.deref(), alias)?; let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col)); Some(in_predicate.and(join_filter)) } @@ -315,17 +384,13 @@ fn build_join( right, })), ) => { - let right_col = create_col_from_scalar_expr(right.deref(), subquery_alias)?; + let right_col = create_col_from_scalar_expr(right.deref(), alias)?; let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col)); Some(in_predicate) } _ => None, } { // join our sub query into the main plan - let join_type = match query_info.negated { - true => JoinType::LeftAnti, - false => JoinType::LeftSemi, - }; let new_plan = LogicalPlanBuilder::from(left.clone()) .join_on(sub_query_alias, join_type, Some(join_filter))? .build()?; @@ -361,6 +426,19 @@ impl SubqueryInfo { negated, } } + + pub fn expr(self) -> Expr { + match self.where_in_expr { + Some(expr) => match self.negated { + true => not_in_subquery(expr, self.query.subquery), + false => in_subquery(expr, self.query.subquery), + }, + None => match self.negated { + true => not_exists(self.query.subquery), + false => exists(self.query.subquery), + }, + } + } } #[cfg(test)] @@ -371,7 +449,7 @@ mod tests { use crate::test::*; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_expr::{and, binary_expr, col, lit, not, or, out_ref_col, table_scan}; + use datafusion_expr::{and, binary_expr, col, lit, not, out_ref_col, table_scan}; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( @@ -442,60 +520,6 @@ mod tests { assert_optimized_plan_equal(plan, expected) } - /// Test for IN subquery with additional OR filter - /// filter expression not modified - #[test] - fn in_subquery_with_or_filters() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .filter(or( - and( - binary_expr(col("a"), Operator::Eq, lit(1_u32)), - binary_expr(col("b"), Operator::Lt, lit(30_u32)), - ), - in_subquery(col("c"), test_subquery_with_name("sq")?), - ))? - .project(vec![col("test.b")])? - .build()?; - - let expected = "Projection: test.b [b:UInt32]\ - \n Filter: test.a = UInt32(1) AND test.b < UInt32(30) OR test.c IN () [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) - } - - #[test] - fn in_subquery_with_and_or_filters() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .filter(and( - or( - binary_expr(col("a"), Operator::Eq, lit(1_u32)), - in_subquery(col("b"), test_subquery_with_name("sq1")?), - ), - in_subquery(col("c"), test_subquery_with_name("sq2")?), - ))? - .project(vec![col("test.b")])? - .build()?; - - let expected = "Projection: test.b [b:UInt32]\ - \n Filter: test.a = UInt32(1) OR test.b IN () [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [c:UInt32]\ - \n Projection: sq1.c [c:UInt32]\ - \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq2.c [c:UInt32]\ - \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) - } - /// Test for nested IN subqueries #[test] fn in_subquery_nested() -> Result<()> { @@ -512,51 +536,19 @@ mod tests { .build()?; let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.b = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.b = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [a:UInt32]\ + \n SubqueryAlias: __correlated_sq_2 [a:UInt32]\ \n Projection: sq.a [a:UInt32]\ - \n LeftSemi Join: Filter: sq.a = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: sq.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_2 [c:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ \n Projection: sq_nested.c [c:UInt32]\ \n TableScan: sq_nested [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) } - /// Test for filter input modification in case filter not supported - /// Outer filter expression not modified while inner converted to join - #[test] - fn in_subquery_input_modified() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .filter(in_subquery(col("c"), test_subquery_with_name("sq_inner")?))? - .project(vec![col("b"), col("c")])? - .alias("wrapped")? - .filter(or( - binary_expr(col("b"), Operator::Lt, lit(30_u32)), - in_subquery(col("c"), test_subquery_with_name("sq_outer")?), - ))? - .project(vec![col("b")])? - .build()?; - - let expected = "Projection: wrapped.b [b:UInt32]\ - \n Filter: wrapped.b < UInt32(30) OR wrapped.c IN () [b:UInt32, c:UInt32]\ - \n Subquery: [c:UInt32]\ - \n Projection: sq_outer.c [c:UInt32]\ - \n TableScan: sq_outer [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: wrapped [b:UInt32, c:UInt32]\ - \n Projection: test.b, test.c [b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq_inner.c [c:UInt32]\ - \n TableScan: sq_inner [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) - } - /// Test multiple correlated subqueries /// See subqueries.rs where_in_multiple() #[test] @@ -630,13 +622,13 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ + \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n LeftSemi Join: Filter: orders.o_orderkey = __correlated_sq_2.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n LeftSemi Join: Filter: orders.o_orderkey = __correlated_sq_1.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __correlated_sq_2 [l_orderkey:Int64]\ + \n SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64]\ \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; @@ -1003,44 +995,6 @@ mod tests { Ok(()) } - /// Test for correlated IN subquery filter with disjustions - #[test] - fn in_subquery_disjunction() -> Result<()> { - let sq = Arc::new( - LogicalPlanBuilder::from(scan_tpch_table("orders")) - .filter( - out_ref_col(DataType::Int64, "customer.c_custkey") - .eq(col("orders.o_custkey")), - )? - .project(vec![col("orders.o_custkey")])? - .build()?, - ); - - let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) - .filter( - in_subquery(col("customer.c_custkey"), sq) - .or(col("customer.c_custkey").eq(lit(1))), - )? - .project(vec![col("customer.c_custkey")])? - .build()?; - - // TODO: support disjunction - for now expect unaltered plan - let expected = r#"Projection: customer.c_custkey [c_custkey:Int64] - Filter: customer.c_custkey IN () OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8] - Subquery: [o_custkey:Int64] - Projection: orders.o_custkey [o_custkey:Int64] - Filter: outer_ref(customer.c_custkey) = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: customer [c_custkey:Int64, c_name:Utf8]"#; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), - plan, - expected, - ); - Ok(()) - } - /// Test for correlated IN subquery filter #[test] fn in_subquery_correlated() -> Result<()> { @@ -1407,13 +1361,13 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: __correlated_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ + \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n LeftSemi Join: Filter: __correlated_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n LeftSemi Join: Filter: __correlated_sq_1.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __correlated_sq_2 [l_orderkey:Int64]\ + \n SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64]\ \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; assert_optimized_plan_equal(plan, expected) diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 30b3631681e7..22857dd285c2 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -415,13 +415,13 @@ query TT explain SELECT t1_id, t1_name, t1_int FROM t1 WHERE t1_id IN(SELECT t2_id FROM t2 WHERE EXISTS(select * from t1 WHERE t1.t1_int > t2.t2_int)) ---- logical_plan -01)LeftSemi Join: t1.t1_id = __correlated_sq_1.t2_id +01)LeftSemi Join: t1.t1_id = __correlated_sq_2.t2_id 02)--TableScan: t1 projection=[t1_id, t1_name, t1_int] -03)--SubqueryAlias: __correlated_sq_1 +03)--SubqueryAlias: __correlated_sq_2 04)----Projection: t2.t2_id -05)------LeftSemi Join: Filter: __correlated_sq_2.t1_int > t2.t2_int +05)------LeftSemi Join: Filter: __correlated_sq_1.t1_int > t2.t2_int 06)--------TableScan: t2 projection=[t2_id, t2_int] -07)--------SubqueryAlias: __correlated_sq_2 +07)--------SubqueryAlias: __correlated_sq_1 08)----------TableScan: t1 projection=[t1_int] #invalid_scalar_subquery @@ -1028,6 +1028,168 @@ false true true +# in_subquery_to_join_with_correlated_outer_filter_disjunction +query TT +explain select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or t1.t1_id in (select t2.t2_id from t2 where t1.t1_int > 0) +---- +logical_plan +01)Projection: t1.t1_id, t1.t1_name, t1.t1_int +02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.__exists IS NOT NULL +03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists +04)------Left Join: t1.t1_id = __correlated_sq_1.t2_id Filter: t1.t1_int > Int32(0) +05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int] +06)--------SubqueryAlias: __correlated_sq_1 +07)----------Projection: t2.t2_id, Boolean(true) AS __exists +08)------------TableScan: t2 projection=[t2_id] + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or t1.t1_id in (select t2.t2_id from t2 where t1.t1_int > 0) +---- +11 a 1 +22 b 2 +44 d 4 + +# not_in_subquery_to_join_with_correlated_outer_filter_disjunction +query TT +explain select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id = 11 or t1.t1_id + 12 not in (select t2.t2_id + 1 from t2 where t1.t1_int > 0) +---- +logical_plan +01)Projection: t1.t1_id, t1.t1_name, t1.t1_int +02)--Filter: t1.t1_id = Int32(11) OR __correlated_sq_1.__exists IS NULL +03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists +04)------Left Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.t2.t2_id + Int64(1) Filter: t1.t1_int > Int32(0) +05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int] +06)--------SubqueryAlias: __correlated_sq_1 +07)----------Projection: CAST(t2.t2_id AS Int64) + Int64(1), Boolean(true) AS __exists +08)------------TableScan: t2 projection=[t2_id] + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id = 11 or t1.t1_id + 12 not in (select t2.t2_id + 1 from t2 where t1.t1_int > 0) +---- +11 a 1 +22 b 2 + +# exists_subquery_to_join_with_correlated_outer_filter_disjunction +query TT +explain select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or exists (select * from t2 where t1.t1_id = t2.t2_id) +---- +logical_plan +01)Projection: t1.t1_id, t1.t1_name, t1.t1_int +02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.__exists IS NOT NULL +03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists +04)------Left Join: t1.t1_id = __correlated_sq_1.t2_id +05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int] +06)--------SubqueryAlias: __correlated_sq_1 +07)----------Projection: t2.t2_id, Boolean(true) AS __exists +08)------------TableScan: t2 projection=[t2_id] + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or exists (select * from t2 where t1.t1_id = t2.t2_id) +---- +11 a 1 +22 b 2 +44 d 4 + +# not_exists_subquery_to_join_with_correlated_outer_filter_disjunction +query TT +explain select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or not exists (select * from t2 where t1.t1_id = t2.t2_id) +---- +logical_plan +01)Projection: t1.t1_id, t1.t1_name, t1.t1_int +02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.__exists IS NULL +03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists +04)------Left Join: t1.t1_id = __correlated_sq_1.t2_id +05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int] +06)--------SubqueryAlias: __correlated_sq_1 +07)----------Projection: t2.t2_id, Boolean(true) AS __exists +08)------------TableScan: t2 projection=[t2_id] + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or not exists (select * from t2 where t1.t1_id = t2.t2_id) +---- +33 c 3 +44 d 4 + +# in_subquery_to_join_with_correlated_outer_filter_and_or +query TT +explain select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id in (select t3.t3_id from t3) and (t1.t1_id > 40 or t1.t1_id in (select t2.t2_id from t2 where t1.t1_int > 0)) +---- +logical_plan +01)Projection: t1.t1_id, t1.t1_name, t1.t1_int +02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_2.__exists IS NOT NULL +03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_2.__exists +04)------Left Join: t1.t1_id = __correlated_sq_2.t2_id Filter: t1.t1_int > Int32(0) +05)--------LeftSemi Join: t1.t1_id = __correlated_sq_1.t3_id +06)----------TableScan: t1 projection=[t1_id, t1_name, t1_int] +07)----------SubqueryAlias: __correlated_sq_1 +08)------------TableScan: t3 projection=[t3_id] +09)--------SubqueryAlias: __correlated_sq_2 +10)----------Projection: t2.t2_id, Boolean(true) AS __exists +11)------------TableScan: t2 projection=[t2_id] + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id in (select t3.t3_id from t3) and (t1.t1_id > 40 or t1.t1_id in (select t2.t2_id from t2 where t1.t1_int > 0)) +---- +11 a 1 +22 b 2 +44 d 4 + +# Nested subqueries +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where exists ( + select * from t2 where t1.t1_id = t2.t2_id OR exists ( + select * from t3 where t2.t2_id = t3.t3_id + ) +) +---- +11 a 1 +22 b 2 +33 c 3 +44 d 4 # issue: https://github.com/apache/datafusion/issues/7027 query TTTT rowsort diff --git a/datafusion/sqllogictest/test_files/tpch/q20.slt.part b/datafusion/sqllogictest/test_files/tpch/q20.slt.part index 67ea87b6ee61..177e38e51ca4 100644 --- a/datafusion/sqllogictest/test_files/tpch/q20.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q20.slt.part @@ -58,19 +58,19 @@ order by logical_plan 01)Sort: supplier.s_name ASC NULLS LAST 02)--Projection: supplier.s_name, supplier.s_address -03)----LeftSemi Join: supplier.s_suppkey = __correlated_sq_1.ps_suppkey +03)----LeftSemi Join: supplier.s_suppkey = __correlated_sq_2.ps_suppkey 04)------Projection: supplier.s_suppkey, supplier.s_name, supplier.s_address 05)--------Inner Join: supplier.s_nationkey = nation.n_nationkey 06)----------TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey] 07)----------Projection: nation.n_nationkey 08)------------Filter: nation.n_name = Utf8("CANADA") 09)--------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("CANADA")] -10)------SubqueryAlias: __correlated_sq_1 +10)------SubqueryAlias: __correlated_sq_2 11)--------Projection: partsupp.ps_suppkey 12)----------Inner Join: partsupp.ps_partkey = __scalar_sq_3.l_partkey, partsupp.ps_suppkey = __scalar_sq_3.l_suppkey Filter: CAST(partsupp.ps_availqty AS Float64) > __scalar_sq_3.Float64(0.5) * sum(lineitem.l_quantity) -13)------------LeftSemi Join: partsupp.ps_partkey = __correlated_sq_2.p_partkey +13)------------LeftSemi Join: partsupp.ps_partkey = __correlated_sq_1.p_partkey 14)--------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty] -15)--------------SubqueryAlias: __correlated_sq_2 +15)--------------SubqueryAlias: __correlated_sq_1 16)----------------Projection: part.p_partkey 17)------------------Filter: part.p_name LIKE Utf8("forest%") 18)--------------------TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8("forest%")] diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index ae67b6924436..06a047b108bd 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -474,16 +474,14 @@ async fn roundtrip_inlist_5() -> Result<()> { // using assert_expected_plan here as a workaround assert_expected_plan( "SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT data2.a FROM data2 WHERE f IN ('b', 'c', 'd')))", - "Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data.a IN ()\ - \n Subquery:\ - \n Projection: data2.a\ - \n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\ - \n TableScan: data2\ - \n TableScan: data projection=[a, f], partial_filters=[data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data.a IN ()]\ - \n Subquery:\ - \n Projection: data2.a\ - \n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\ - \n TableScan: data2", + "Projection: data.a, data.f\ + \n Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR Boolean(true) IS NOT NULL\ + \n Projection: data.a, data.f, Boolean(true)\ + \n Left Join: data.a = data2.a\ + \n TableScan: data projection=[a, f]\ + \n Projection: data2.a, Boolean(true)\ + \n Filter: data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")\ + \n TableScan: data2 projection=[a, f], partial_filters=[data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")]", true).await } From e9584bc46ffc574cd65044d4199966402def1d15 Mon Sep 17 00:00:00 2001 From: Jonathan Chen <86070045+jonathanc-n@users.noreply.github.com> Date: Sun, 20 Oct 2024 16:20:01 -0400 Subject: [PATCH 031/110] Added expresion to "with_standard_argument" (#12926) * Added default value to 'with_standard_argument' * small fix * change function * small changes * with_argument change * ran build * small fix --- datafusion/expr/src/udf_docs.rs | 22 ++++++---- .../src/approx_distinct.rs | 2 +- .../functions-aggregate/src/approx_median.rs | 2 +- .../src/approx_percentile_cont.rs | 2 +- .../src/approx_percentile_cont_with_weight.rs | 2 +- .../functions-aggregate/src/array_agg.rs | 2 +- datafusion/functions-aggregate/src/average.rs | 2 +- .../functions-aggregate/src/bit_and_or_xor.rs | 6 +-- .../functions-aggregate/src/bool_and_or.rs | 4 +- .../functions-aggregate/src/correlation.rs | 4 +- datafusion/functions-aggregate/src/count.rs | 2 +- .../functions-aggregate/src/covariance.rs | 8 ++-- .../functions-aggregate/src/first_last.rs | 4 +- datafusion/functions-aggregate/src/median.rs | 2 +- datafusion/functions-aggregate/src/min_max.rs | 4 +- .../functions-aggregate/src/nth_value.rs | 2 +- datafusion/functions-aggregate/src/stddev.rs | 4 +- datafusion/functions-aggregate/src/sum.rs | 2 +- .../functions-aggregate/src/variance.rs | 4 +- datafusion/functions/src/crypto/digest.rs | 2 +- datafusion/functions/src/crypto/md5.rs | 2 +- datafusion/functions/src/crypto/sha224.rs | 2 +- datafusion/functions/src/crypto/sha256.rs | 2 +- datafusion/functions/src/crypto/sha384.rs | 2 +- datafusion/functions/src/datetime/to_date.rs | 2 +- datafusion/functions/src/math/abs.rs | 2 +- datafusion/functions/src/math/factorial.rs | 2 +- datafusion/functions/src/math/gcd.rs | 4 +- datafusion/functions/src/math/iszero.rs | 2 +- datafusion/functions/src/math/lcm.rs | 4 +- datafusion/functions/src/math/log.rs | 4 +- datafusion/functions/src/math/monotonicity.rs | 44 +++++++++---------- datafusion/functions/src/math/nans.rs | 2 +- datafusion/functions/src/math/power.rs | 4 +- datafusion/functions/src/math/round.rs | 2 +- datafusion/functions/src/math/signum.rs | 2 +- datafusion/functions/src/math/trunc.rs | 2 +- datafusion/functions/src/regex/regexpcount.rs | 4 +- datafusion/functions/src/regex/regexplike.rs | 4 +- datafusion/functions/src/regex/regexpmatch.rs | 2 +- .../functions/src/regex/regexpreplace.rs | 4 +- datafusion/functions/src/string/ascii.rs | 2 +- datafusion/functions/src/string/bit_length.rs | 2 +- datafusion/functions/src/string/btrim.rs | 2 +- datafusion/functions/src/string/chr.rs | 2 +- datafusion/functions/src/string/concat.rs | 2 +- datafusion/functions/src/string/concat_ws.rs | 7 +-- datafusion/functions/src/string/contains.rs | 2 +- datafusion/functions/src/string/ends_with.rs | 2 +- datafusion/functions/src/string/initcap.rs | 2 +- datafusion/functions/src/string/lower.rs | 2 +- datafusion/functions/src/string/ltrim.rs | 2 +- .../functions/src/string/octet_length.rs | 2 +- datafusion/functions/src/string/overlay.rs | 2 +- datafusion/functions/src/string/repeat.rs | 2 +- datafusion/functions/src/string/replace.rs | 6 +-- datafusion/functions/src/string/rtrim.rs | 2 +- datafusion/functions/src/string/split_part.rs | 2 +- .../functions/src/string/starts_with.rs | 2 +- datafusion/functions/src/string/to_hex.rs | 2 +- datafusion/functions/src/string/upper.rs | 2 +- .../functions/src/unicode/character_length.rs | 2 +- datafusion/functions/src/unicode/left.rs | 2 +- datafusion/functions/src/unicode/lpad.rs | 2 +- datafusion/functions/src/unicode/reverse.rs | 2 +- datafusion/functions/src/unicode/right.rs | 2 +- datafusion/functions/src/unicode/rpad.rs | 2 +- datafusion/functions/src/unicode/strpos.rs | 2 +- datafusion/functions/src/unicode/substr.rs | 2 +- .../functions/src/unicode/substrindex.rs | 2 +- datafusion/functions/src/unicode/translate.rs | 2 +- .../user-guide/sql/aggregate_functions_new.md | 2 +- .../user-guide/sql/scalar_functions_new.md | 4 +- 73 files changed, 129 insertions(+), 126 deletions(-) diff --git a/datafusion/expr/src/udf_docs.rs b/datafusion/expr/src/udf_docs.rs index 8e255566606c..63d1a964345d 100644 --- a/datafusion/expr/src/udf_docs.rs +++ b/datafusion/expr/src/udf_docs.rs @@ -147,23 +147,29 @@ impl DocumentationBuilder { /// Add a standard "expression" argument to the documentation /// - /// This is similar to [`Self::with_argument`] except that a standard - /// description is appended to the end: `"Can be a constant, column, or - /// function, and any combination of arithmetic operators."` - /// - /// The argument is rendered like + /// The argument is rendered like below if Some() is passed through: /// /// ```text /// : /// expression to operate on. Can be a constant, column, or function, and any combination of operators. /// ``` + /// + /// The argument is rendered like below if None is passed through: + /// + /// ```text + /// : + /// The expression to operate on. Can be a constant, column, or function, and any combination of operators. + /// ``` pub fn with_standard_argument( self, arg_name: impl Into, - expression_type: impl AsRef, + expression_type: Option<&str>, ) -> Self { - let expression_type = expression_type.as_ref(); - self.with_argument(arg_name, format!("{expression_type} expression to operate on. Can be a constant, column, or function, and any combination of operators.")) + let description = format!( + "{} expression to operate on. Can be a constant, column, or function, and any combination of operators.", + expression_type.unwrap_or("The") + ); + self.with_argument(arg_name, description) } pub fn with_related_udf(mut self, related_udf: impl Into) -> Self { diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index d6cc711147b5..1df106feb4d3 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -332,7 +332,7 @@ fn get_approx_distinct_doc() -> &'static Documentation { +-----------------------------------+ ```"#, ) - .with_standard_argument("expression", "The") + .with_standard_argument("expression", None) .build() .unwrap() }) diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index 84442fa5a2e6..96609622a51e 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -145,7 +145,7 @@ fn get_approx_median_doc() -> &'static Documentation { +-----------------------------------+ ```"#, ) - .with_standard_argument("expression", "The") + .with_standard_argument("expression", None) .build() .unwrap() }) diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 9b8a99e977d2..83b9f714fa89 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -293,7 +293,7 @@ fn get_approx_percentile_cont_doc() -> &'static Documentation { | 65.0 | +-------------------------------------------------+ ```"#) - .with_standard_argument("expression", "The") + .with_standard_argument("expression", None) .with_argument("percentile", "Percentile to compute. Must be a float value between 0 and 1 (inclusive).") .with_argument("centroids", "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory.") .build() diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index a5362713a6fb..b86fec1e037e 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -179,7 +179,7 @@ fn get_approx_percentile_cont_with_weight_doc() -> &'static Documentation { +----------------------------------------------------------------------+ ```"#, ) - .with_standard_argument("expression", "The") + .with_standard_argument("expression", None) .with_argument("weight", "Expression to use as weight. Can be a constant, column, or function, and any combination of arithmetic operators.") .with_argument("percentile", "Percentile to compute. Must be a float value between 0 and 1 (inclusive).") .build() diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 28ff6fb346e5..6f523756832e 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -168,7 +168,7 @@ fn get_array_agg_doc() -> &'static Documentation { +-----------------------------------------------+ ```"#, ) - .with_standard_argument("expression", "The") + .with_standard_argument("expression", None) .build() .unwrap() }) diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 8782f8cfcc7c..67b824c2ea79 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -263,7 +263,7 @@ fn get_avg_doc() -> &'static Documentation { +---------------------------+ ```"#, ) - .with_standard_argument("expression", "The") + .with_standard_argument("expression", None) .build() .unwrap() }) diff --git a/datafusion/functions-aggregate/src/bit_and_or_xor.rs b/datafusion/functions-aggregate/src/bit_and_or_xor.rs index c5382c168f17..0a281ad81467 100644 --- a/datafusion/functions-aggregate/src/bit_and_or_xor.rs +++ b/datafusion/functions-aggregate/src/bit_and_or_xor.rs @@ -142,7 +142,7 @@ fn get_bit_and_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_GENERAL) .with_description("Computes the bitwise AND of all non-null input values.") .with_syntax_example("bit_and(expression)") - .with_standard_argument("expression", "Integer") + .with_standard_argument("expression", Some("Integer")) .build() .unwrap() }) @@ -156,7 +156,7 @@ fn get_bit_or_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_GENERAL) .with_description("Computes the bitwise OR of all non-null input values.") .with_syntax_example("bit_or(expression)") - .with_standard_argument("expression", "Integer") + .with_standard_argument("expression", Some("Integer")) .build() .unwrap() }) @@ -172,7 +172,7 @@ fn get_bit_xor_doc() -> &'static Documentation { "Computes the bitwise exclusive OR of all non-null input values.", ) .with_syntax_example("bit_xor(expression)") - .with_standard_argument("expression", "Integer") + .with_standard_argument("expression", Some("Integer")) .build() .unwrap() }) diff --git a/datafusion/functions-aggregate/src/bool_and_or.rs b/datafusion/functions-aggregate/src/bool_and_or.rs index 63ad1ea573d5..b410bfa139e9 100644 --- a/datafusion/functions-aggregate/src/bool_and_or.rs +++ b/datafusion/functions-aggregate/src/bool_and_or.rs @@ -201,7 +201,7 @@ fn get_bool_and_doc() -> &'static Documentation { +----------------------------+ ```"#, ) - .with_standard_argument("expression", "The") + .with_standard_argument("expression", None) .build() .unwrap() }) @@ -350,7 +350,7 @@ fn get_bool_or_doc() -> &'static Documentation { +----------------------------+ ```"#, ) - .with_standard_argument("expression", "The") + .with_standard_argument("expression", None) .build() .unwrap() }) diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index d5dc482d68d2..40429289d768 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -134,8 +134,8 @@ fn get_corr_doc() -> &'static Documentation { +--------------------------------+ ```"#, ) - .with_standard_argument("expression1", "First") - .with_standard_argument("expression2", "Second") + .with_standard_argument("expression1", Some("First")) + .with_standard_argument("expression2", Some("Second")) .build() .unwrap() }) diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 2511c70c4608..61dbfd674993 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -357,7 +357,7 @@ fn get_count_doc() -> &'static Documentation { | 120 | +------------------+ ```"#) - .with_standard_argument("expression", "The") + .with_standard_argument("expression", None) .build() .unwrap() }) diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs index f3b323d74d30..4b2b21059d16 100644 --- a/datafusion/functions-aggregate/src/covariance.rs +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -150,8 +150,8 @@ fn get_covar_samp_doc() -> &'static Documentation { +-----------------------------------+ ```"#, ) - .with_standard_argument("expression1", "First") - .with_standard_argument("expression2", "Second") + .with_standard_argument("expression1", Some("First")) + .with_standard_argument("expression2", Some("Second")) .build() .unwrap() }) @@ -248,8 +248,8 @@ fn get_covar_pop_doc() -> &'static Documentation { +-----------------------------------+ ```"#, ) - .with_standard_argument("expression1", "First") - .with_standard_argument("expression2", "Second") + .with_standard_argument("expression1", Some("First")) + .with_standard_argument("expression2", Some("Second")) .build() .unwrap() }) diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 2a3fc623657a..c708d23ae6c5 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -191,7 +191,7 @@ fn get_first_value_doc() -> &'static Documentation { +-----------------------------------------------+ ```"#, ) - .with_standard_argument("expression", "The") + .with_standard_argument("expression", None) .build() .unwrap() }) @@ -519,7 +519,7 @@ fn get_last_value_doc() -> &'static Documentation { +-----------------------------------------------+ ```"#, ) - .with_standard_argument("expression", "The") + .with_standard_argument("expression", None) .build() .unwrap() }) diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 5a0cac2c829e..e0011e2e0f69 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -177,7 +177,7 @@ fn get_median_doc() -> &'static Documentation { +----------------------+ ```"#, ) - .with_standard_argument("expression", "The") + .with_standard_argument("expression", None) .build() .unwrap() }) diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index d576b1fdad78..8102d0e4794b 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -357,7 +357,7 @@ fn get_max_doc() -> &'static Documentation { +----------------------+ ```"#, ) - .with_standard_argument("expression", "The") + .with_standard_argument("expression", None) .build() .unwrap() }) @@ -1187,7 +1187,7 @@ fn get_min_doc() -> &'static Documentation { +----------------------+ ```"#, ) - .with_standard_argument("expression", "The") + .with_standard_argument("expression", None) .build() .unwrap() }) diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index 6d8cea8f0531..3e7f51af5265 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -191,7 +191,7 @@ fn get_nth_value_doc() -> &'static Documentation { | 2 | 45000 | 45000 | +---------+--------+-------------------------+ ```"#) - .with_standard_argument("expression", "The column or expression to retrieve the nth value from.") + .with_argument("expression", "The column or expression to retrieve the nth value from.") .with_argument("n", "The position (nth) of the value to retrieve, based on the ordering.") .build() .unwrap() diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 332a8efcc0f9..0d1821687524 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -158,7 +158,7 @@ fn get_stddev_doc() -> &'static Documentation { +----------------------+ ```"#, ) - .with_standard_argument("expression", "The") + .with_standard_argument("expression", None) .build() .unwrap() }) @@ -282,7 +282,7 @@ fn get_stddev_pop_doc() -> &'static Documentation { +--------------------------+ ```"#, ) - .with_standard_argument("expression", "The") + .with_standard_argument("expression", None) .build() .unwrap() }) diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index 3b561f3028de..943f66a92c00 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -260,7 +260,7 @@ fn get_sum_doc() -> &'static Documentation { +-----------------------+ ```"#, ) - .with_standard_argument("expression", "The") + .with_standard_argument("expression", None) .build() .unwrap() }) diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 49a30344c212..8453c9d3010b 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -153,7 +153,7 @@ fn get_variance_sample_doc() -> &'static Documentation { "Returns the statistical sample variance of a set of numbers.", ) .with_syntax_example("var(expression)") - .with_standard_argument("expression", "Numeric") + .with_standard_argument("expression", Some("Numeric")) .build() .unwrap() }) @@ -259,7 +259,7 @@ fn get_variance_population_doc() -> &'static Documentation { "Returns the statistical population variance of a set of numbers.", ) .with_syntax_example("var_pop(expression)") - .with_standard_argument("expression", "Numeric") + .with_standard_argument("expression", Some("Numeric")) .build() .unwrap() }) diff --git a/datafusion/functions/src/crypto/digest.rs b/datafusion/functions/src/crypto/digest.rs index 9ec07b1cab53..0e43fb7785df 100644 --- a/datafusion/functions/src/crypto/digest.rs +++ b/datafusion/functions/src/crypto/digest.rs @@ -98,7 +98,7 @@ fn get_digest_doc() -> &'static Documentation { ```"#, ) .with_standard_argument( - "expression", "String") + "expression", Some("String")) .with_argument( "algorithm", "String expression specifying algorithm to use. Must be one of: diff --git a/datafusion/functions/src/crypto/md5.rs b/datafusion/functions/src/crypto/md5.rs index f273c9d28c23..062d63bcc018 100644 --- a/datafusion/functions/src/crypto/md5.rs +++ b/datafusion/functions/src/crypto/md5.rs @@ -112,7 +112,7 @@ fn get_md5_doc() -> &'static Documentation { +-------------------------------------+ ```"#, ) - .with_standard_argument("expression", "String") + .with_standard_argument("expression", Some("String")) .build() .unwrap() }) diff --git a/datafusion/functions/src/crypto/sha224.rs b/datafusion/functions/src/crypto/sha224.rs index 868c8cdc3558..39202d5bf691 100644 --- a/datafusion/functions/src/crypto/sha224.rs +++ b/datafusion/functions/src/crypto/sha224.rs @@ -68,7 +68,7 @@ fn get_sha224_doc() -> &'static Documentation { +------------------------------------------+ ```"#, ) - .with_standard_argument("expression", "String") + .with_standard_argument("expression", Some("String")) .build() .unwrap() }) diff --git a/datafusion/functions/src/crypto/sha256.rs b/datafusion/functions/src/crypto/sha256.rs index 99a470efbc1f..74deb3fc6caa 100644 --- a/datafusion/functions/src/crypto/sha256.rs +++ b/datafusion/functions/src/crypto/sha256.rs @@ -92,7 +92,7 @@ fn get_sha256_doc() -> &'static Documentation { +--------------------------------------+ ```"#, ) - .with_standard_argument("expression", "String") + .with_standard_argument("expression", Some("String")) .build() .unwrap() }) diff --git a/datafusion/functions/src/crypto/sha384.rs b/datafusion/functions/src/crypto/sha384.rs index afe2db7478f7..9b1e1ba9ec3c 100644 --- a/datafusion/functions/src/crypto/sha384.rs +++ b/datafusion/functions/src/crypto/sha384.rs @@ -92,7 +92,7 @@ fn get_sha384_doc() -> &'static Documentation { +-----------------------------------------+ ```"#, ) - .with_standard_argument("expression", "String") + .with_standard_argument("expression", Some("String")) .build() .unwrap() }) diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index b21fe995cea6..82e189698c5e 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -111,7 +111,7 @@ Note: `to_date` returns Date32, which represents its values as the number of day Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_date.rs) "#) - .with_standard_argument("expression", "String") + .with_standard_argument("expression", Some("String")) .with_argument( "format_n", "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index 5dcbb99eae65..5511a57d8566 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -201,7 +201,7 @@ fn get_abs_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_MATH) .with_description("Returns the absolute value of a number.") .with_syntax_example("abs(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) diff --git a/datafusion/functions/src/math/factorial.rs b/datafusion/functions/src/math/factorial.rs index ac04c03190ae..4b87284744d3 100644 --- a/datafusion/functions/src/math/factorial.rs +++ b/datafusion/functions/src/math/factorial.rs @@ -85,7 +85,7 @@ fn get_factorial_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_MATH) .with_description("Factorial. Returns 1 if value is less than 2.") .with_syntax_example("factorial(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) diff --git a/datafusion/functions/src/math/gcd.rs b/datafusion/functions/src/math/gcd.rs index 5e56dacb9380..f4edef3acca3 100644 --- a/datafusion/functions/src/math/gcd.rs +++ b/datafusion/functions/src/math/gcd.rs @@ -87,8 +87,8 @@ fn get_gcd_doc() -> &'static Documentation { "Returns the greatest common divisor of `expression_x` and `expression_y`. Returns 0 if both inputs are zero.", ) .with_syntax_example("gcd(expression_x, expression_y)") - .with_standard_argument("expression_x", "First numeric") - .with_standard_argument("expression_y", "Second numeric") + .with_standard_argument("expression_x", Some("First numeric")) + .with_standard_argument("expression_y", Some("Second numeric")) .build() .unwrap() }) diff --git a/datafusion/functions/src/math/iszero.rs b/datafusion/functions/src/math/iszero.rs index b8deee2c6125..7e5d4fe77ffa 100644 --- a/datafusion/functions/src/math/iszero.rs +++ b/datafusion/functions/src/math/iszero.rs @@ -90,7 +90,7 @@ fn get_iszero_doc() -> &'static Documentation { "Returns true if a given number is +0.0 or -0.0 otherwise returns false.", ) .with_syntax_example("iszero(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) diff --git a/datafusion/functions/src/math/lcm.rs b/datafusion/functions/src/math/lcm.rs index 844dbfd39d38..64b07ce606f2 100644 --- a/datafusion/functions/src/math/lcm.rs +++ b/datafusion/functions/src/math/lcm.rs @@ -88,8 +88,8 @@ fn get_lcm_doc() -> &'static Documentation { "Returns the least common multiple of `expression_x` and `expression_y`. Returns 0 if either input is zero.", ) .with_syntax_example("lcm(expression_x, expression_y)") - .with_standard_argument("expression_x", "First numeric") - .with_standard_argument("expression_y", "Second numeric") + .with_standard_argument("expression_x", Some("First numeric")) + .with_standard_argument("expression_y", Some("Second numeric")) .build() .unwrap() }) diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 07ff8e2166ff..89ba14e32ba0 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -57,8 +57,8 @@ fn get_log_doc() -> &'static Documentation { .with_description("Returns the base-x logarithm of a number. Can either provide a specified base, or if omitted then takes the base-10 of a number.") .with_syntax_example(r#"log(base, numeric_expression) log(numeric_expression)"#) - .with_standard_argument("base", "Base numeric") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("base", Some("Base numeric")) + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) diff --git a/datafusion/functions/src/math/monotonicity.rs b/datafusion/functions/src/math/monotonicity.rs index 959434d74f82..19c85f4b6e3c 100644 --- a/datafusion/functions/src/math/monotonicity.rs +++ b/datafusion/functions/src/math/monotonicity.rs @@ -46,7 +46,7 @@ pub fn get_acos_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_MATH) .with_description("Returns the arc cosine or inverse cosine of a number.") .with_syntax_example("acos(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) @@ -79,7 +79,7 @@ pub fn get_acosh_doc() -> &'static Documentation { "Returns the area hyperbolic cosine or inverse hyperbolic cosine of a number.", ) .with_syntax_example("acosh(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) @@ -108,7 +108,7 @@ pub fn get_asin_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_MATH) .with_description("Returns the arc sine or inverse sine of a number.") .with_syntax_example("asin(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) @@ -129,7 +129,7 @@ pub fn get_asinh_doc() -> &'static Documentation { "Returns the area hyperbolic sine or inverse hyperbolic sine of a number.", ) .with_syntax_example("asinh(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) @@ -148,7 +148,7 @@ pub fn get_atan_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_MATH) .with_description("Returns the arc tangent or inverse tangent of a number.") .with_syntax_example("atan(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) @@ -179,7 +179,7 @@ pub fn get_atanh_doc() -> &'static Documentation { "Returns the area hyperbolic tangent or inverse hyperbolic tangent of a number.", ) .with_syntax_example("atanh(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) @@ -223,7 +223,7 @@ pub fn get_cbrt_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_MATH) .with_description("Returns the cube root of a number.") .with_syntax_example("cbrt(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) @@ -244,7 +244,7 @@ pub fn get_ceil_doc() -> &'static Documentation { "Returns the nearest integer greater than or equal to a number.", ) .with_syntax_example("ceil(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) @@ -265,7 +265,7 @@ pub fn get_cos_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_MATH) .with_description("Returns the cosine of a number.") .with_syntax_example("cos(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) @@ -295,7 +295,7 @@ pub fn get_cosh_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_MATH) .with_description("Returns the hyperbolic cosine of a number.") .with_syntax_example("cosh(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) @@ -314,7 +314,7 @@ pub fn get_degrees_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_MATH) .with_description("Converts radians to degrees.") .with_syntax_example("degrees(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) @@ -333,7 +333,7 @@ pub fn get_exp_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_MATH) .with_description("Returns the base-e exponential of a number.") .with_syntax_example("exp(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) @@ -354,7 +354,7 @@ pub fn get_floor_doc() -> &'static Documentation { "Returns the nearest integer less than or equal to a number.", ) .with_syntax_example("floor(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) @@ -382,7 +382,7 @@ pub fn get_ln_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_MATH) .with_description("Returns the natural logarithm of a number.") .with_syntax_example("ln(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) @@ -410,7 +410,7 @@ pub fn get_log2_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_MATH) .with_description("Returns the base-2 logarithm of a number.") .with_syntax_example("log2(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) @@ -438,7 +438,7 @@ pub fn get_log10_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_MATH) .with_description("Returns the base-10 logarithm of a number.") .with_syntax_example("log10(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) @@ -457,7 +457,7 @@ pub fn get_radians_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_MATH) .with_description("Converts degrees to radians.") .with_syntax_example("radians(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) @@ -478,7 +478,7 @@ pub fn get_sin_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_MATH) .with_description("Returns the sine of a number.") .with_syntax_example("sin(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) @@ -497,7 +497,7 @@ pub fn get_sinh_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_MATH) .with_description("Returns the hyperbolic sine of a number.") .with_syntax_example("sinh(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) @@ -525,7 +525,7 @@ pub fn get_sqrt_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_MATH) .with_description("Returns the square root of a number.") .with_syntax_example("sqrt(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) @@ -546,7 +546,7 @@ pub fn get_tan_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_MATH) .with_description("Returns the tangent of a number.") .with_syntax_example("tan(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) @@ -565,7 +565,7 @@ pub fn get_tanh_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_MATH) .with_description("Returns the hyperbolic tangent of a number.") .with_syntax_example("tanh(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) diff --git a/datafusion/functions/src/math/nans.rs b/datafusion/functions/src/math/nans.rs index 79e4587958dc..c1dd1aacc35a 100644 --- a/datafusion/functions/src/math/nans.rs +++ b/datafusion/functions/src/math/nans.rs @@ -107,7 +107,7 @@ fn get_isnan_doc() -> &'static Documentation { "Returns true if a given number is +NaN or -NaN otherwise returns false.", ) .with_syntax_example("isnan(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index a99afaec97f7..9125f9b0fecd 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -181,8 +181,8 @@ fn get_power_doc() -> &'static Documentation { "Returns a base expression raised to the power of an exponent.", ) .with_syntax_example("power(base, exponent)") - .with_standard_argument("base", "Numeric") - .with_standard_argument("exponent", "Exponent numeric") + .with_standard_argument("base", Some("Numeric")) + .with_standard_argument("exponent", Some("Exponent numeric")) .build() .unwrap() }) diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index ae8eee0dbba2..fec1f1ce1aee 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -114,7 +114,7 @@ fn get_round_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_MATH) .with_description("Rounds a number to the nearest integer.") .with_syntax_example("round(numeric_expression[, decimal_places])") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .with_argument( "decimal_places", "Optional. The number of decimal places to round to. Defaults to 0.", diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index 6c020b0ce52a..ac881eb42f26 100644 --- a/datafusion/functions/src/math/signum.rs +++ b/datafusion/functions/src/math/signum.rs @@ -101,7 +101,7 @@ Negative numbers return `-1`. Zero and positive numbers return `1`."#, ) .with_syntax_example("signum(numeric_expression)") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .build() .unwrap() }) diff --git a/datafusion/functions/src/math/trunc.rs b/datafusion/functions/src/math/trunc.rs index 17a84420318e..9a05684d238e 100644 --- a/datafusion/functions/src/math/trunc.rs +++ b/datafusion/functions/src/math/trunc.rs @@ -119,7 +119,7 @@ fn get_trunc_doc() -> &'static Documentation { "Truncates a number to a whole number or truncated to the specified decimal places.", ) .with_syntax_example("trunc(numeric_expression[, decimal_places])") - .with_standard_argument("numeric_expression", "Numeric") + .with_standard_argument("numeric_expression", Some("Numeric")) .with_argument("decimal_places", r#"Optional. The number of decimal places to truncate to. Defaults to 0 (truncate to a whole number). If `decimal_places` is a positive integer, truncates digits to the diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index 880c91094555..7f7896ecd923 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -127,8 +127,8 @@ fn get_regexp_count_doc() -> &'static Documentation { | 1 | +---------------------------------------------------------------+ ```"#) - .with_standard_argument("str", "String") - .with_standard_argument("regexp","Regular") + .with_standard_argument("str", Some("String")) + .with_standard_argument("regexp",Some("Regular")) .with_argument("start", "- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function.") .with_argument("flags", r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: diff --git a/datafusion/functions/src/regex/regexplike.rs b/datafusion/functions/src/regex/regexplike.rs index aad67e4ecab6..13de7888aa5f 100644 --- a/datafusion/functions/src/regex/regexplike.rs +++ b/datafusion/functions/src/regex/regexplike.rs @@ -67,8 +67,8 @@ SELECT regexp_like('aBc', '(b|d)', 'i'); ``` Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) "#) - .with_standard_argument("str", "String") - .with_standard_argument("regexp","Regular") + .with_standard_argument("str", Some("String")) + .with_standard_argument("regexp", Some("Regular")) .with_argument("flags", r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: - **i**: case-insensitive: letters match both upper and lower case diff --git a/datafusion/functions/src/regex/regexpmatch.rs b/datafusion/functions/src/regex/regexpmatch.rs index a458b205f4e3..019666bd7b2d 100644 --- a/datafusion/functions/src/regex/regexpmatch.rs +++ b/datafusion/functions/src/regex/regexpmatch.rs @@ -137,7 +137,7 @@ fn get_regexp_match_doc() -> &'static Documentation { ``` Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) "#) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_argument("regexp","Regular expression to match against. Can be a constant, column, or function.") .with_argument("flags", diff --git a/datafusion/functions/src/regex/regexpreplace.rs b/datafusion/functions/src/regex/regexpreplace.rs index 279e5c6ba9dd..4d8e5e5fe3e3 100644 --- a/datafusion/functions/src/regex/regexpreplace.rs +++ b/datafusion/functions/src/regex/regexpreplace.rs @@ -154,10 +154,10 @@ SELECT regexp_replace('aBc', '(b|d)', 'Ab\\1a', 'i'); ``` Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) "#) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_argument("regexp","Regular expression to match against. Can be a constant, column, or function.") - .with_standard_argument("replacement", "Replacement string") + .with_standard_argument("replacement", Some("Replacement string")) .with_argument("flags", r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: - **g**: (global) Search globally and don't return after the first match diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 8d61661f97b8..b76d70d7e9d2 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -99,7 +99,7 @@ fn get_ascii_doc() -> &'static Documentation { +-------------------+ ```"#, ) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_related_udf("chr") .build() .unwrap() diff --git a/datafusion/functions/src/string/bit_length.rs b/datafusion/functions/src/string/bit_length.rs index 7d162e7d411b..25b56341fcaa 100644 --- a/datafusion/functions/src/string/bit_length.rs +++ b/datafusion/functions/src/string/bit_length.rs @@ -107,7 +107,7 @@ fn get_bit_length_doc() -> &'static Documentation { +--------------------------------+ ```"#, ) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_related_udf("length") .with_related_udf("octet_length") .build() diff --git a/datafusion/functions/src/string/btrim.rs b/datafusion/functions/src/string/btrim.rs index 82b7599f0735..f689f27d9d24 100644 --- a/datafusion/functions/src/string/btrim.rs +++ b/datafusion/functions/src/string/btrim.rs @@ -122,7 +122,7 @@ fn get_btrim_doc() -> &'static Documentation { | datafusion | +-------------------------------------------+ ```"#) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_argument("trim_str", "String expression to operate on. Can be a constant, column, or function, and any combination of operators. _Default is whitespace characters._") .with_related_udf("ltrim") .with_related_udf("rtrim") diff --git a/datafusion/functions/src/string/chr.rs b/datafusion/functions/src/string/chr.rs index ae0900af37d3..0d94cab08d91 100644 --- a/datafusion/functions/src/string/chr.rs +++ b/datafusion/functions/src/string/chr.rs @@ -125,7 +125,7 @@ fn get_chr_doc() -> &'static Documentation { +--------------------+ ```"#, ) - .with_standard_argument("expression", "String") + .with_standard_argument("expression", Some("String")) .with_related_udf("ascii") .build() .unwrap() diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 33a926863a4a..a4218c39e7b2 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -270,7 +270,7 @@ fn get_concat_doc() -> &'static Documentation { +-------------------------------------------------------+ ```"#, ) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_argument("str_n", "Subsequent string expressions to concatenate.") .with_related_udf("concat_ws") .build() diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 17361b073315..8d966f495663 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -295,11 +295,8 @@ fn get_concat_ws_doc() -> &'static Documentation { "separator", "Separator to insert between concatenated strings.", ) - .with_standard_argument("str", "String") - .with_standard_argument( - "str_n", - "Subsequent string expressions to concatenate.", - ) + .with_standard_argument("str", Some("String")) + .with_argument("str_n", "Subsequent string expressions to concatenate.") .with_related_udf("concat") .build() .unwrap() diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index 86f1eda03342..d0e63bb0f353 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -95,7 +95,7 @@ fn get_contains_doc() -> &'static Documentation { +---------------------------------------------------+ ```"#, ) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_argument("search_str", "The string to search for in str.") .build() .unwrap() diff --git a/datafusion/functions/src/string/ends_with.rs b/datafusion/functions/src/string/ends_with.rs index 8c90cbc3b146..88978a35c0b7 100644 --- a/datafusion/functions/src/string/ends_with.rs +++ b/datafusion/functions/src/string/ends_with.rs @@ -103,7 +103,7 @@ fn get_ends_with_doc() -> &'static Documentation { +--------------------------------------------+ ```"#, ) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_argument("substr", "Substring to test for.") .build() .unwrap() diff --git a/datafusion/functions/src/string/initcap.rs b/datafusion/functions/src/string/initcap.rs index 78c95b9a5e35..5fd1e7929881 100644 --- a/datafusion/functions/src/string/initcap.rs +++ b/datafusion/functions/src/string/initcap.rs @@ -96,7 +96,7 @@ fn get_initcap_doc() -> &'static Documentation { | Apache Datafusion | +------------------------------------+ ```"#) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_related_udf("lower") .with_related_udf("upper") .build() diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index f82b11ca9051..b07189a832dc 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -89,7 +89,7 @@ fn get_lower_doc() -> &'static Documentation { +-------------------------+ ```"#, ) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_related_udf("initcap") .with_related_udf("upper") .build() diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs index b64dcda7218e..91809d691647 100644 --- a/datafusion/functions/src/string/ltrim.rs +++ b/datafusion/functions/src/string/ltrim.rs @@ -122,7 +122,7 @@ fn get_ltrim_doc() -> &'static Documentation { | datafusion___ | +-------------------------------------------+ ```"#) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_argument("trim_str", "String expression to trim from the beginning of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._") .with_related_udf("btrim") .with_related_udf("rtrim") diff --git a/datafusion/functions/src/string/octet_length.rs b/datafusion/functions/src/string/octet_length.rs index 04094396fadc..2ac2bf70da23 100644 --- a/datafusion/functions/src/string/octet_length.rs +++ b/datafusion/functions/src/string/octet_length.rs @@ -110,7 +110,7 @@ fn get_octet_length_doc() -> &'static Documentation { +--------------------------------+ ```"#, ) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_related_udf("bit_length") .with_related_udf("length") .build() diff --git a/datafusion/functions/src/string/overlay.rs b/datafusion/functions/src/string/overlay.rs index 3b31bc360851..796776304f4a 100644 --- a/datafusion/functions/src/string/overlay.rs +++ b/datafusion/functions/src/string/overlay.rs @@ -108,7 +108,7 @@ fn get_overlay_doc() -> &'static Documentation { | Thomas | +--------------------------------------------------------+ ```"#) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_argument("substr", "Substring to replace in str.") .with_argument("pos", "The start position to start the replace in str.") .with_argument("count", "The count of characters to be replaced from start position of str. If not specified, will use substr length instead.") diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 7364c7d36f10..aa69f9c6609a 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -107,7 +107,7 @@ fn get_repeat_doc() -> &'static Documentation { +-------------------------------+ ```"#, ) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_argument("n", "Number of times to repeat the input string.") .build() .unwrap() diff --git a/datafusion/functions/src/string/replace.rs b/datafusion/functions/src/string/replace.rs index 612cd7276bab..91abc39da058 100644 --- a/datafusion/functions/src/string/replace.rs +++ b/datafusion/functions/src/string/replace.rs @@ -96,9 +96,9 @@ fn get_replace_doc() -> &'static Documentation { | ABcdbaBA | +-------------------------------------------------+ ```"#) - .with_standard_argument("str", "String") - .with_standard_argument("substr", "Substring expression to replace in the input string. Substring expression") - .with_standard_argument("replacement", "Replacement substring") + .with_standard_argument("str", Some("String")) + .with_standard_argument("substr", Some("Substring expression to replace in the input string. Substring")) + .with_standard_argument("replacement", Some("Replacement substring")) .build() .unwrap() }) diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs index 1a27502a2082..06c8a85c38dd 100644 --- a/datafusion/functions/src/string/rtrim.rs +++ b/datafusion/functions/src/string/rtrim.rs @@ -122,7 +122,7 @@ fn get_rtrim_doc() -> &'static Documentation { | ___datafusion | +-------------------------------------------+ ```"#) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_argument("trim_str", "String expression to trim from the end of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._") .with_related_udf("btrim") .with_related_udf("ltrim") diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index cea3b0890f9b..ea01cb1f56f9 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -198,7 +198,7 @@ fn get_split_part_doc() -> &'static Documentation { | 3 | +--------------------------------------------------+ ```"#) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_argument("delimiter", "String or character to split on.") .with_argument("pos", "Position of the part to return.") .build() diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index 713b642d5e91..dce161a2e14b 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -102,7 +102,7 @@ fn get_starts_with_doc() -> &'static Documentation { +----------------------------------------------+ ```"#, ) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_argument("substr", "Substring to test for.") .build() .unwrap() diff --git a/datafusion/functions/src/string/to_hex.rs b/datafusion/functions/src/string/to_hex.rs index 72cd4fbffa33..e0033d2d1cb0 100644 --- a/datafusion/functions/src/string/to_hex.rs +++ b/datafusion/functions/src/string/to_hex.rs @@ -134,7 +134,7 @@ fn get_to_hex_doc() -> &'static Documentation { +-------------------------+ ```"#, ) - .with_standard_argument("int", "Integer") + .with_standard_argument("int", Some("Integer")) .build() .unwrap() }) diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index bfcb2a86994d..042c26b2e3da 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -88,7 +88,7 @@ fn get_upper_doc() -> &'static Documentation { +---------------------------+ ```"#, ) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_related_udf("initcap") .with_related_udf("lower") .build() diff --git a/datafusion/functions/src/unicode/character_length.rs b/datafusion/functions/src/unicode/character_length.rs index 6e74135b6028..7858a59664d3 100644 --- a/datafusion/functions/src/unicode/character_length.rs +++ b/datafusion/functions/src/unicode/character_length.rs @@ -103,7 +103,7 @@ fn get_character_length_doc() -> &'static Documentation { +------------------------------------+ ```"#, ) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_related_udf("bit_length") .with_related_udf("octet_length") .build() diff --git a/datafusion/functions/src/unicode/left.rs b/datafusion/functions/src/unicode/left.rs index 6610cfb25e79..a6c2b9768f0b 100644 --- a/datafusion/functions/src/unicode/left.rs +++ b/datafusion/functions/src/unicode/left.rs @@ -115,7 +115,7 @@ fn get_left_doc() -> &'static Documentation { | data | +-----------------------------------+ ```"#) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_argument("n", "Number of characters to return.") .with_related_udf("right") .build() diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs index 948afd050cdb..767eda203c8f 100644 --- a/datafusion/functions/src/unicode/lpad.rs +++ b/datafusion/functions/src/unicode/lpad.rs @@ -119,7 +119,7 @@ fn get_lpad_doc() -> &'static Documentation { | helloDolly | +---------------------------------------------+ ```"#) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_argument("n", "String length to pad to.") .with_argument("padding_str", "Optional string expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._") .with_related_udf("rpad") diff --git a/datafusion/functions/src/unicode/reverse.rs b/datafusion/functions/src/unicode/reverse.rs index 32872c28a613..baf3b56636e2 100644 --- a/datafusion/functions/src/unicode/reverse.rs +++ b/datafusion/functions/src/unicode/reverse.rs @@ -105,7 +105,7 @@ fn get_reverse_doc() -> &'static Documentation { +-----------------------------+ ```"#, ) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .build() .unwrap() }) diff --git a/datafusion/functions/src/unicode/right.rs b/datafusion/functions/src/unicode/right.rs index 585611fe60e4..ab3b7ba1a27e 100644 --- a/datafusion/functions/src/unicode/right.rs +++ b/datafusion/functions/src/unicode/right.rs @@ -115,7 +115,7 @@ fn get_right_doc() -> &'static Documentation { | fusion | +------------------------------------+ ```"#) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_argument("n", "Number of characters to return") .with_related_udf("left") .build() diff --git a/datafusion/functions/src/unicode/rpad.rs b/datafusion/functions/src/unicode/rpad.rs index fd4c1ee6fe38..bd9d625105e9 100644 --- a/datafusion/functions/src/unicode/rpad.rs +++ b/datafusion/functions/src/unicode/rpad.rs @@ -140,7 +140,7 @@ fn get_rpad_doc() -> &'static Documentation { ```"#) .with_standard_argument( "str", - "String", + Some("String"), ) .with_argument("n", "String length to pad to.") .with_argument("padding_str", diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index e4696e4e5c3f..152623b0e5dc 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -95,7 +95,7 @@ fn get_strpos_doc() -> &'static Documentation { | 5 | +----------------------------------------+ ```"#) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_argument("substr", "Substring expression to search for.") .build() .unwrap() diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 4e0c293577b9..5a8c2500900b 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -170,7 +170,7 @@ fn get_substr_doc() -> &'static Documentation { | fus | +----------------------------------------------+ ```"#) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_argument("start_pos", "Character position to start the substring at. The first character in the string has a position of 1.") .with_argument("length", "Number of characters to extract. If not specified, returns the rest of the string after the start position.") .build() diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs index 436d554a49f7..c04839783f58 100644 --- a/datafusion/functions/src/unicode/substrindex.rs +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -115,7 +115,7 @@ If count is negative, everything to the right of the final delimiter (counting f | org | +----------------------------------------------------------+ ```"#) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_argument("delim", "The string to find in str to split str.") .with_argument("count", "The number of times to search for the delimiter. Can be either a positive or negative number.") .build() diff --git a/datafusion/functions/src/unicode/translate.rs b/datafusion/functions/src/unicode/translate.rs index cbee9a6fe1f2..fa626b396b3b 100644 --- a/datafusion/functions/src/unicode/translate.rs +++ b/datafusion/functions/src/unicode/translate.rs @@ -101,7 +101,7 @@ fn get_translate_doc() -> &'static Documentation { | there | +--------------------------------------------------+ ```"#) - .with_standard_argument("str", "String") + .with_standard_argument("str", Some("String")) .with_argument("chars", "Characters to translate.") .with_argument("translation", "Translation characters. Translation characters replace only characters at the same position in the **chars** string.") .build() diff --git a/docs/source/user-guide/sql/aggregate_functions_new.md b/docs/source/user-guide/sql/aggregate_functions_new.md index 6c9d9b043fa6..24ef313f3d49 100644 --- a/docs/source/user-guide/sql/aggregate_functions_new.md +++ b/docs/source/user-guide/sql/aggregate_functions_new.md @@ -562,7 +562,7 @@ nth_value(expression, n ORDER BY expression) #### Arguments -- **expression**: The column or expression to retrieve the nth value from. expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression**: The column or expression to retrieve the nth value from. - **n**: The position (nth) of the value to retrieve, based on the ordering. #### Example diff --git a/docs/source/user-guide/sql/scalar_functions_new.md b/docs/source/user-guide/sql/scalar_functions_new.md index ac6e56a44c10..1f4ec1c27858 100644 --- a/docs/source/user-guide/sql/scalar_functions_new.md +++ b/docs/source/user-guide/sql/scalar_functions_new.md @@ -908,7 +908,7 @@ concat_ws(separator, str[, ..., str_n]) - **separator**: Separator to insert between concatenated strings. - **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **str_n**: Subsequent string expressions to concatenate. expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **str_n**: Subsequent string expressions to concatenate. #### Example @@ -1250,7 +1250,7 @@ replace(str, substr, replacement) #### Arguments - **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **substr**: Substring expression to replace in the input string. Substring expression expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **substr**: Substring expression to replace in the input string. Substring expression to operate on. Can be a constant, column, or function, and any combination of operators. - **replacement**: Replacement substring expression to operate on. Can be a constant, column, or function, and any combination of operators. #### Example From b42d9b81caddb5b53800f8eed32f9af4a9e3a01d Mon Sep 17 00:00:00 2001 From: peasee <98815791+peasee@users.noreply.github.com> Date: Mon, 21 Oct 2024 16:34:14 +1000 Subject: [PATCH 032/110] fix: Dialect requires derived table alias (#12994) * fix: Dialect requires table alias (#46) * fix: Add Dialect option for requiring table aliases * feat: Add CustomDialectBuilder for requires_table_alias * docs: Spelling * refactor: rename requires_derived_table_alias * refactor: rename requires_derived_table_alias * review: Rewrite match to if, add another test case * test: Update RHS expected * test: Update tests with more cases --- datafusion/sql/src/unparser/dialect.rs | 27 +++++++++++ datafusion/sql/src/unparser/plan.rs | 57 ++++++++++++++++++++--- datafusion/sql/tests/cases/plan_to_sql.rs | 39 ++++++++++++++++ 3 files changed, 116 insertions(+), 7 deletions(-) diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index cfc28f2c499f..02934a004d6f 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -120,6 +120,12 @@ pub trait Dialect: Send + Sync { true } + /// Whether the dialect requires a table alias for any subquery in the FROM clause + /// This affects behavior when deriving logical plans for Sort, Limit, etc. + fn requires_derived_table_alias(&self) -> bool { + false + } + /// Allows the dialect to override scalar function unparsing if the dialect has specific rules. /// Returns None if the default unparsing should be used, or Some(ast::Expr) if there is /// a custom implementation for the function. @@ -300,6 +306,10 @@ impl Dialect for MySqlDialect { ast::DataType::Datetime(None) } + fn requires_derived_table_alias(&self) -> bool { + true + } + fn scalar_function_to_sql_overrides( &self, unparser: &Unparser, @@ -362,6 +372,7 @@ pub struct CustomDialect { timestamp_tz_cast_dtype: ast::DataType, date32_cast_dtype: sqlparser::ast::DataType, supports_column_alias_in_table_alias: bool, + requires_derived_table_alias: bool, } impl Default for CustomDialect { @@ -384,6 +395,7 @@ impl Default for CustomDialect { ), date32_cast_dtype: sqlparser::ast::DataType::Date, supports_column_alias_in_table_alias: true, + requires_derived_table_alias: false, } } } @@ -472,6 +484,10 @@ impl Dialect for CustomDialect { Ok(None) } + + fn requires_derived_table_alias(&self) -> bool { + self.requires_derived_table_alias + } } /// `CustomDialectBuilder` to build `CustomDialect` using builder pattern @@ -503,6 +519,7 @@ pub struct CustomDialectBuilder { timestamp_tz_cast_dtype: ast::DataType, date32_cast_dtype: ast::DataType, supports_column_alias_in_table_alias: bool, + requires_derived_table_alias: bool, } impl Default for CustomDialectBuilder { @@ -531,6 +548,7 @@ impl CustomDialectBuilder { ), date32_cast_dtype: sqlparser::ast::DataType::Date, supports_column_alias_in_table_alias: true, + requires_derived_table_alias: false, } } @@ -551,6 +569,7 @@ impl CustomDialectBuilder { date32_cast_dtype: self.date32_cast_dtype, supports_column_alias_in_table_alias: self .supports_column_alias_in_table_alias, + requires_derived_table_alias: self.requires_derived_table_alias, } } @@ -653,4 +672,12 @@ impl CustomDialectBuilder { self.supports_column_alias_in_table_alias = supports_column_alias_in_table_alias; self } + + pub fn with_requires_derived_table_alias( + mut self, + requires_derived_table_alias: bool, + ) -> Self { + self.requires_derived_table_alias = requires_derived_table_alias; + self + } } diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index c22400f1faa1..8e70654d8d6f 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -222,9 +222,14 @@ impl Unparser<'_> { Ok(()) } - fn derive(&self, plan: &LogicalPlan, relation: &mut RelationBuilder) -> Result<()> { + fn derive( + &self, + plan: &LogicalPlan, + relation: &mut RelationBuilder, + alias: Option, + ) -> Result<()> { let mut derived_builder = DerivedRelationBuilder::default(); - derived_builder.lateral(false).alias(None).subquery({ + derived_builder.lateral(false).alias(alias).subquery({ let inner_statement = self.plan_to_sql(plan)?; if let ast::Statement::Query(inner_query) = inner_statement { inner_query @@ -239,6 +244,23 @@ impl Unparser<'_> { Ok(()) } + fn derive_with_dialect_alias( + &self, + alias: &str, + plan: &LogicalPlan, + relation: &mut RelationBuilder, + ) -> Result<()> { + if self.dialect.requires_derived_table_alias() { + self.derive( + plan, + relation, + Some(self.new_table_alias(alias.to_string(), vec![])), + ) + } else { + self.derive(plan, relation, None) + } + } + fn select_to_sql_recursively( &self, plan: &LogicalPlan, @@ -284,7 +306,11 @@ impl Unparser<'_> { // Projection can be top-level plan for derived table if select.already_projected() { - return self.derive(plan, relation); + return self.derive_with_dialect_alias( + "derived_projection", + plan, + relation, + ); } self.reconstruct_select_statement(plan, p, select)?; self.select_to_sql_recursively(p.input.as_ref(), query, select, relation) @@ -311,8 +337,13 @@ impl Unparser<'_> { LogicalPlan::Limit(limit) => { // Limit can be top-level plan for derived table if select.already_projected() { - return self.derive(plan, relation); + return self.derive_with_dialect_alias( + "derived_limit", + plan, + relation, + ); } + if let Some(fetch) = limit.fetch { let Some(query) = query.as_mut() else { return internal_err!( @@ -350,7 +381,11 @@ impl Unparser<'_> { LogicalPlan::Sort(sort) => { // Sort can be top-level plan for derived table if select.already_projected() { - return self.derive(plan, relation); + return self.derive_with_dialect_alias( + "derived_sort", + plan, + relation, + ); } let Some(query_ref) = query else { return internal_err!( @@ -396,7 +431,11 @@ impl Unparser<'_> { LogicalPlan::Distinct(distinct) => { // Distinct can be top-level plan for derived table if select.already_projected() { - return self.derive(plan, relation); + return self.derive_with_dialect_alias( + "derived_distinct", + plan, + relation, + ); } let (select_distinct, input) = match distinct { Distinct::All(input) => (ast::Distinct::Distinct, input.as_ref()), @@ -559,7 +598,11 @@ impl Unparser<'_> { // Covers cases where the UNION is a subquery and the projection is at the top level if select.already_projected() { - return self.derive(plan, relation); + return self.derive_with_dialect_alias( + "derived_union", + plan, + relation, + ); } let input_exprs: Vec = union diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 2a3c5b5f6b2b..0de74e050553 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -261,6 +261,45 @@ fn roundtrip_statement_with_dialect() -> Result<()> { unparser_dialect: Box, } let tests: Vec = vec![ + TestStatementWithDialect { + sql: "select min(ta.j1_id) as j1_min from j1 ta order by min(ta.j1_id) limit 10;", + expected: + // top projection sort gets derived into a subquery + // for MySQL, this subquery needs an alias + "SELECT `j1_min` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, min(`ta`.`j1_id`) FROM `j1` AS `ta` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10", + parser_dialect: Box::new(MySqlDialect {}), + unparser_dialect: Box::new(UnparserMySqlDialect {}), + }, + TestStatementWithDialect { + sql: "select min(ta.j1_id) as j1_min from j1 ta order by min(ta.j1_id) limit 10;", + expected: + // top projection sort still gets derived into a subquery in default dialect + // except for the default dialect, the subquery is left non-aliased + "SELECT j1_min FROM (SELECT min(ta.j1_id) AS j1_min, min(ta.j1_id) FROM j1 AS ta ORDER BY min(ta.j1_id) ASC NULLS LAST) LIMIT 10", + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "select min(ta.j1_id) as j1_min, max(tb.j1_max) from j1 ta, (select distinct max(ta.j1_id) as j1_max from j1 ta order by max(ta.j1_id)) tb order by min(ta.j1_id) limit 10;", + expected: + "SELECT `j1_min`, `max(tb.j1_max)` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, max(`tb`.`j1_max`), min(`ta`.`j1_id`) FROM `j1` AS `ta` JOIN (SELECT `j1_max` FROM (SELECT DISTINCT max(`ta`.`j1_id`) AS `j1_max` FROM `j1` AS `ta`) AS `derived_distinct`) AS `tb` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10", + parser_dialect: Box::new(MySqlDialect {}), + unparser_dialect: Box::new(UnparserMySqlDialect {}), + }, + TestStatementWithDialect { + sql: "select j1_id from (select 1 as j1_id);", + expected: + "SELECT `j1_id` FROM (SELECT 1 AS `j1_id`) AS `derived_projection`", + parser_dialect: Box::new(MySqlDialect {}), + unparser_dialect: Box::new(UnparserMySqlDialect {}), + }, + TestStatementWithDialect { + sql: "select * from (select * from j1 limit 10);", + expected: + "SELECT * FROM (SELECT * FROM `j1` LIMIT 10) AS `derived_limit`", + parser_dialect: Box::new(MySqlDialect {}), + unparser_dialect: Box::new(UnparserMySqlDialect {}), + }, TestStatementWithDialect { sql: "select ta.j1_id from j1 ta order by j1_id limit 10;", expected: From 69a464846a9253d73ea252b5882d3607ce7f2c7f Mon Sep 17 00:00:00 2001 From: Mustafa Akur <33904309+akurmustafa@users.noreply.github.com> Date: Mon, 21 Oct 2024 03:03:24 -0700 Subject: [PATCH 033/110] [Minor]: Add data based sort expression test (#12992) * Initial commit * Fix formatting, minor changes * Minor changes * Move test to fuzz tests * Add comment to test --- .../tests/fuzz_cases/equivalence/ordering.rs | 68 +++++++++- .../tests/fuzz_cases/equivalence/utils.rs | 122 +++++++++++++++++- .../physical-expr/src/equivalence/mod.rs | 31 +++-- 3 files changed, 204 insertions(+), 17 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs index b1ee24a7a373..604d1a1000c3 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs @@ -16,8 +16,9 @@ // under the License. use crate::fuzz_cases::equivalence::utils::{ - create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, - TestScalarUDF, + convert_to_orderings, create_random_schema, create_test_schema_2, + generate_table_for_eq_properties, generate_table_for_orderings, + is_table_same_after_sort, TestScalarUDF, }; use arrow_schema::SortOptions; use datafusion_common::{DFSchema, Result}; @@ -158,3 +159,66 @@ fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { Ok(()) } + +// This test checks given a table is ordered with `[a ASC, b ASC, c ASC, d ASC]` and `[a ASC, c ASC, b ASC, d ASC]` +// whether the table is also ordered with `[a ASC, b ASC, d ASC]` and `[a ASC, c ASC, d ASC]` +// Since these orderings cannot be deduced, these orderings shouldn't be satisfied by the table generated. +// For background see discussion: https://github.com/apache/datafusion/issues/12700#issuecomment-2411134296 +#[test] +fn test_ordering_satisfy_on_data() -> Result<()> { + let schema = create_test_schema_2()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + let orderings = vec![ + // [a ASC, b ASC, c ASC, d ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + (col_d, option_asc), + ], + // [a ASC, c ASC, b ASC, d ASC] + vec![ + (col_a, option_asc), + (col_c, option_asc), + (col_b, option_asc), + (col_d, option_asc), + ], + ]; + let orderings = convert_to_orderings(&orderings); + + let batch = generate_table_for_orderings(orderings, schema, 1000, 10)?; + + // [a ASC, c ASC, d ASC] cannot be deduced + let ordering = vec![ + (col_a, option_asc), + (col_c, option_asc), + (col_d, option_asc), + ]; + let ordering = convert_to_orderings(&[ordering])[0].clone(); + assert!(!is_table_same_after_sort(ordering, batch.clone())?); + + // [a ASC, b ASC, d ASC] cannot be deduced + let ordering = vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_d, option_asc), + ]; + let ordering = convert_to_orderings(&[ordering])[0].clone(); + assert!(!is_table_same_after_sort(ordering, batch.clone())?); + + // [a ASC, b ASC] can be deduced + let ordering = vec![(col_a, option_asc), (col_b, option_asc)]; + let ordering = convert_to_orderings(&[ordering])[0].clone(); + assert!(is_table_same_after_sort(ordering, batch.clone())?); + + Ok(()) +} diff --git a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs index e51dabd6437f..ce3afba81ee2 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs @@ -15,23 +15,29 @@ // specific language governing permissions and limitations // under the License. // -// use datafusion_physical_expr::expressions::{col, Column}; use datafusion::physical_plan::expressions::col; use datafusion::physical_plan::expressions::Column; use datafusion_physical_expr::{ConstExpr, EquivalenceProperties, PhysicalSortExpr}; use std::any::Any; +use std::cmp::Ordering; use std::sync::Arc; use arrow::compute::{lexsort_to_indices, SortColumn}; use arrow::datatypes::{DataType, Field, Schema}; -use arrow_array::{ArrayRef, Float32Array, Float64Array, RecordBatch, UInt32Array}; +use arrow_array::{ + ArrayRef, Float32Array, Float64Array, PrimitiveArray, RecordBatch, UInt32Array, +}; use arrow_schema::{SchemaRef, SortOptions}; +use datafusion_common::utils::{ + compare_rows, get_record_batch_at_indices, get_row_at_idx, +}; use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result}; - use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use datafusion_physical_expr::equivalence::{EquivalenceClass, ProjectionMapping}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexOrderingRef}; + use itertools::izip; use rand::prelude::*; @@ -67,7 +73,7 @@ pub fn output_schema( } // Generate a schema which consists of 6 columns (a, b, c, d, e, f) -fn create_test_schema_2() -> Result { +pub fn create_test_schema_2() -> Result { let a = Field::new("a", DataType::Float64, true); let b = Field::new("b", DataType::Float64, true); let c = Field::new("c", DataType::Float64, true); @@ -374,6 +380,114 @@ pub fn generate_table_for_eq_properties( Ok(RecordBatch::try_from_iter(res)?) } +// Generate a table that satisfies the given orderings; +pub fn generate_table_for_orderings( + mut orderings: Vec, + schema: SchemaRef, + n_elem: usize, + n_distinct: usize, +) -> Result { + let mut rng = StdRng::seed_from_u64(23); + + assert!(!orderings.is_empty()); + // Sort the inner vectors by their lengths (longest first) + orderings.sort_by_key(|v| std::cmp::Reverse(v.len())); + + let arrays = schema + .fields + .iter() + .map(|field| { + ( + field.name(), + generate_random_f64_array(n_elem, n_distinct, &mut rng), + ) + }) + .collect::>(); + let batch = RecordBatch::try_from_iter(arrays)?; + + // Sort batch according to first ordering expression + let sort_columns = get_sort_columns(&batch, &orderings[0])?; + let sort_indices = lexsort_to_indices(&sort_columns, None)?; + let mut batch = get_record_batch_at_indices(&batch, &sort_indices)?; + + // prune out rows that is invalid according to remaining orderings. + for ordering in orderings.iter().skip(1) { + let sort_columns = get_sort_columns(&batch, ordering)?; + + // Collect sort options and values into separate vectors. + let (sort_options, sort_col_values): (Vec<_>, Vec<_>) = sort_columns + .into_iter() + .map(|sort_col| (sort_col.options.unwrap(), sort_col.values)) + .unzip(); + + let mut cur_idx = 0; + let mut keep_indices = vec![cur_idx as u32]; + for next_idx in 1..batch.num_rows() { + let cur_row = get_row_at_idx(&sort_col_values, cur_idx)?; + let next_row = get_row_at_idx(&sort_col_values, next_idx)?; + + if compare_rows(&cur_row, &next_row, &sort_options)? != Ordering::Greater { + // next row satisfies ordering relation given, compared to the current row. + keep_indices.push(next_idx as u32); + cur_idx = next_idx; + } + } + // Only keep valid rows, that satisfies given ordering relation. + batch = get_record_batch_at_indices( + &batch, + &PrimitiveArray::from_iter_values(keep_indices), + )?; + } + + Ok(batch) +} + +// Convert each tuple to PhysicalSortExpr +pub fn convert_to_sort_exprs( + in_data: &[(&Arc, SortOptions)], +) -> Vec { + in_data + .iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: Arc::clone(*expr), + options: *options, + }) + .collect() +} + +// Convert each inner tuple to PhysicalSortExpr +pub fn convert_to_orderings( + orderings: &[Vec<(&Arc, SortOptions)>], +) -> Vec> { + orderings + .iter() + .map(|sort_exprs| convert_to_sort_exprs(sort_exprs)) + .collect() +} + +// Utility function to generate random f64 array +fn generate_random_f64_array( + n_elems: usize, + n_distinct: usize, + rng: &mut StdRng, +) -> ArrayRef { + let values: Vec = (0..n_elems) + .map(|_| rng.gen_range(0..n_distinct) as f64 / 2.0) + .collect(); + Arc::new(Float64Array::from_iter_values(values)) +} + +// Helper function to get sort columns from a batch +fn get_sort_columns( + batch: &RecordBatch, + ordering: LexOrderingRef, +) -> Result> { + ordering + .iter() + .map(|expr| expr.evaluate_to_sort_column(batch)) + .collect::>>() +} + #[derive(Debug, Clone)] pub struct TestScalarUDF { pub(crate) signature: Signature, diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index 7726458a46ac..253f1196491b 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -72,6 +72,7 @@ pub fn add_offset_to_expr( #[cfg(test)] mod tests { + use super::*; use crate::expressions::col; use crate::PhysicalSortExpr; @@ -385,14 +386,6 @@ mod tests { let schema = eq_properties.schema(); let mut schema_vec = vec![None; schema.fields.len()]; - // Utility closure to generate random array - let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { - let values: Vec = (0..num_elems) - .map(|_| rng.gen_range(0..max_val) as f64 / 2.0) - .collect(); - Arc::new(Float64Array::from_iter_values(values)) - }; - // Fill constant columns for constant in &eq_properties.constants { let col = constant.expr().as_any().downcast_ref::().unwrap(); @@ -409,7 +402,7 @@ mod tests { .map(|PhysicalSortExpr { expr, options }| { let col = expr.as_any().downcast_ref::().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - let arr = generate_random_array(n_elem, n_distinct); + let arr = generate_random_f64_array(n_elem, n_distinct, &mut rng); ( SortColumn { values: arr, @@ -430,7 +423,9 @@ mod tests { for eq_group in eq_properties.eq_group.iter() { let representative_array = get_representative_arr(eq_group, &schema_vec, Arc::clone(schema)) - .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); + .unwrap_or_else(|| { + generate_random_f64_array(n_elem, n_distinct, &mut rng) + }); for expr in eq_group.iter() { let col = expr.as_any().downcast_ref::().unwrap(); @@ -446,11 +441,25 @@ mod tests { ( field.name(), // Generate random values for columns that do not occur in any of the groups (equivalence, ordering equivalence, constants) - elem.unwrap_or_else(|| generate_random_array(n_elem, n_distinct)), + elem.unwrap_or_else(|| { + generate_random_f64_array(n_elem, n_distinct, &mut rng) + }), ) }) .collect(); Ok(RecordBatch::try_from_iter(res)?) } + + // Utility function to generate random f64 array + fn generate_random_f64_array( + n_elems: usize, + n_distinct: usize, + rng: &mut StdRng, + ) -> ArrayRef { + let values: Vec = (0..n_elems) + .map(|_| rng.gen_range(0..n_distinct) as f64 / 2.0) + .collect(); + Arc::new(Float64Array::from_iter_values(values)) + } } From edeca39c36a27d8cd4d0365c1f692398b79c88d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20=C5=9Een?= Date: Mon, 21 Oct 2024 14:12:27 +0300 Subject: [PATCH 034/110] Removed last usages of scalar_inputs, scalar_input_types and inputs2 to use arrow unary/binary for performance (#12972) * removed last uses of make_function_scalar_inputs * delete make_function_scalar_inputs * fix * refactored other macros * fix unary CI * fix base f32/f64 mismatch not caught by tests * import order changes * Update log.rs * stylistic changes --------- Co-authored-by: berkaysynnada --- datafusion/functions/src/macros.rs | 137 +++++++------------------ datafusion/functions/src/math/log.rs | 57 +++++----- datafusion/functions/src/math/nanvl.rs | 39 +++---- datafusion/functions/src/math/power.rs | 34 +++--- datafusion/functions/src/math/round.rs | 92 +++++++---------- 5 files changed, 138 insertions(+), 221 deletions(-) diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index 744a0189125c..9bc038e71edc 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -112,26 +112,6 @@ macro_rules! make_stub_package { }; } -/// Invokes a function on each element of an array and returns the result as a new array -/// -/// $ARG: ArrayRef -/// $NAME: name of the function (for error messages) -/// $ARGS_TYPE: the type of array to cast the argument to -/// $RETURN_TYPE: the type of array to return -/// $FUNC: the function to apply to each element of $ARG -macro_rules! make_function_scalar_inputs_return_type { - ($ARG: expr, $NAME:expr, $ARG_TYPE:ident, $RETURN_TYPE:ident, $FUNC: block) => {{ - let arg = downcast_arg!($ARG, $NAME, $ARG_TYPE); - - arg.iter() - .map(|a| match a { - Some(a) => Some($FUNC(a)), - _ => None, - }) - .collect::<$RETURN_TYPE>() - }}; -} - /// Downcast an argument to a specific array type, returning an internal error /// if the cast fails /// @@ -168,9 +148,9 @@ macro_rules! make_math_unary_udf { use std::any::Any; use std::sync::Arc; - use arrow::array::{ArrayRef, Float32Array, Float64Array}; - use arrow::datatypes::DataType; - use datafusion_common::{exec_err, DataFusionError, Result}; + use arrow::array::{ArrayRef, AsArray}; + use arrow::datatypes::{DataType, Float32Type, Float64Type}; + use datafusion_common::{exec_err, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ @@ -231,24 +211,16 @@ macro_rules! make_math_unary_udf { fn invoke(&self, args: &[ColumnarValue]) -> Result { let args = ColumnarValue::values_to_arrays(args)?; let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => { - Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - self.name(), - Float64Array, - Float64Array, - { f64::$UNARY_FUNC } - )) - } - DataType::Float32 => { - Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - self.name(), - Float32Array, - Float32Array, - { f32::$UNARY_FUNC } - )) - } + DataType::Float64 => Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float64Type>(|x: f64| f64::$UNARY_FUNC(x)), + ) as ArrayRef, + DataType::Float32 => Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float32Type>(|x: f32| f32::$UNARY_FUNC(x)), + ) as ArrayRef, other => { return exec_err!( "Unsupported data type {other:?} for function {}", @@ -286,9 +258,9 @@ macro_rules! make_math_binary_udf { use std::any::Any; use std::sync::Arc; - use arrow::array::{ArrayRef, Float32Array, Float64Array}; - use arrow::datatypes::DataType; - use datafusion_common::{exec_err, DataFusionError, Result}; + use arrow::array::{ArrayRef, AsArray}; + use arrow::datatypes::{DataType, Float32Type, Float64Type}; + use datafusion_common::{exec_err, Result}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature; use datafusion_expr::{ @@ -347,23 +319,26 @@ macro_rules! make_math_binary_udf { fn invoke(&self, args: &[ColumnarValue]) -> Result { let args = ColumnarValue::values_to_arrays(args)?; let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "y", - "x", - Float64Array, - { f64::$BINARY_FUNC } - )), - - DataType::Float32 => Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "y", - "x", - Float32Array, - { f32::$BINARY_FUNC } - )), + DataType::Float64 => { + let y = args[0].as_primitive::(); + let x = args[1].as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float64Type>( + y, + x, + |y, x| f64::$BINARY_FUNC(y, x), + )?; + Arc::new(result) as _ + } + DataType::Float32 => { + let y = args[0].as_primitive::(); + let x = args[1].as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float32Type>( + y, + x, + |y, x| f32::$BINARY_FUNC(y, x), + )?; + Arc::new(result) as _ + } other => { return exec_err!( "Unsupported data type {other:?} for function {}", @@ -382,43 +357,3 @@ macro_rules! make_math_binary_udf { } }; } - -macro_rules! make_function_scalar_inputs { - ($ARG: expr, $NAME:expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ - let arg = downcast_arg!($ARG, $NAME, $ARRAY_TYPE); - - arg.iter() - .map(|a| match a { - Some(a) => Some($FUNC(a)), - _ => None, - }) - .collect::<$ARRAY_TYPE>() - }}; -} - -macro_rules! make_function_inputs2 { - ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ - let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE); - let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE); - - arg1.iter() - .zip(arg2.iter()) - .map(|(a1, a2)| match (a1, a2) { - (Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)), - _ => None, - }) - .collect::<$ARRAY_TYPE>() - }}; - ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE1:ident, $ARRAY_TYPE2:ident, $FUNC: block) => {{ - let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE1); - let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE2); - - arg1.iter() - .zip(arg2.iter()) - .map(|(a1, a2)| match (a1, a2) { - (Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)), - _ => None, - }) - .collect::<$ARRAY_TYPE1>() - }}; -} diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 89ba14e32ba0..f82c0df34e27 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -22,11 +22,10 @@ use std::sync::{Arc, OnceLock}; use super::power::PowerFunc; -use arrow::array::{ArrayRef, Float32Array, Float64Array}; -use arrow::datatypes::DataType; +use arrow::array::{ArrayRef, AsArray}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; use datafusion_common::{ - exec_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, Result, - ScalarValue, + exec_err, internal_err, plan_datafusion_err, plan_err, Result, ScalarValue, }; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; @@ -140,37 +139,40 @@ impl ScalarUDFImpl for LogFunc { let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => match base { ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => { - Arc::new(make_function_scalar_inputs!(x, "x", Float64Array, { - |value: f64| f64::log(value, base as f64) - })) + Arc::new(x.as_primitive::().unary::<_, Float64Type>( + |value: f64| f64::log(value, base as f64), + )) + } + ColumnarValue::Array(base) => { + let x = x.as_primitive::(); + let base = base.as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float64Type>( + x, + base, + f64::log, + )?; + Arc::new(result) as _ } - ColumnarValue::Array(base) => Arc::new(make_function_inputs2!( - x, - base, - "x", - "base", - Float64Array, - { f64::log } - )), _ => { return exec_err!("log function requires a scalar or array for base") } }, DataType::Float32 => match base { - ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => { - Arc::new(make_function_scalar_inputs!(x, "x", Float32Array, { - |value: f32| f32::log(value, base) - })) + ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => Arc::new( + x.as_primitive::() + .unary::<_, Float32Type>(|value: f32| f32::log(value, base)), + ), + ColumnarValue::Array(base) => { + let x = x.as_primitive::(); + let base = base.as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float32Type>( + x, + base, + f32::log, + )?; + Arc::new(result) as _ } - ColumnarValue::Array(base) => Arc::new(make_function_inputs2!( - x, - base, - "x", - "base", - Float32Array, - { f32::log } - )), _ => { return exec_err!("log function requires a scalar or array for base") } @@ -259,6 +261,7 @@ mod tests { use super::*; + use arrow::array::{Float32Array, Float64Array}; use arrow::compute::SortOptions; use datafusion_common::cast::{as_float32_array, as_float64_array}; use datafusion_common::DFSchema; diff --git a/datafusion/functions/src/math/nanvl.rs b/datafusion/functions/src/math/nanvl.rs index b82ee0d45744..cfd21256dd96 100644 --- a/datafusion/functions/src/math/nanvl.rs +++ b/datafusion/functions/src/math/nanvl.rs @@ -18,10 +18,11 @@ use std::any::Any; use std::sync::{Arc, OnceLock}; -use arrow::array::{ArrayRef, Float32Array, Float64Array}; -use arrow::datatypes::DataType; -use arrow::datatypes::DataType::{Float32, Float64}; +use crate::utils::make_scalar_function; +use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array}; +use arrow::datatypes::DataType::{Float32, Float64}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::TypeSignature::Exact; @@ -29,8 +30,6 @@ use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; -use crate::utils::make_scalar_function; - #[derive(Debug)] pub struct NanvlFunc { signature: Signature, @@ -113,14 +112,11 @@ fn nanvl(args: &[ArrayRef]) -> Result { } }; - Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "x", - "y", - Float64Array, - { compute_nanvl } - )) as ArrayRef) + let x = args[0].as_primitive() as &Float64Array; + let y = args[1].as_primitive() as &Float64Array; + arrow::compute::binary::<_, _, _, Float64Type>(x, y, compute_nanvl) + .map(|res| Arc::new(res) as _) + .map_err(DataFusionError::from) } Float32 => { let compute_nanvl = |x: f32, y: f32| { @@ -131,14 +127,11 @@ fn nanvl(args: &[ArrayRef]) -> Result { } }; - Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "x", - "y", - Float32Array, - { compute_nanvl } - )) as ArrayRef) + let x = args[0].as_primitive() as &Float32Array; + let y = args[1].as_primitive() as &Float32Array; + arrow::compute::binary::<_, _, _, Float32Type>(x, y, compute_nanvl) + .map(|res| Arc::new(res) as _) + .map_err(DataFusionError::from) } other => exec_err!("Unsupported data type {other:?} for function nanvl"), } @@ -146,10 +139,12 @@ fn nanvl(args: &[ArrayRef]) -> Result { #[cfg(test)] mod test { + use std::sync::Arc; + use crate::math::nanvl::nanvl; + use arrow::array::{ArrayRef, Float32Array, Float64Array}; use datafusion_common::cast::{as_float32_array, as_float64_array}; - use std::sync::Arc; #[test] fn test_nanvl_f64() { diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 9125f9b0fecd..9bb6006d55b9 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -16,9 +16,13 @@ // under the License. //! Math function: `power()`. +use std::any::Any; +use std::sync::{Arc, OnceLock}; -use arrow::datatypes::{ArrowNativeTypeOp, DataType}; +use super::log::LogFunc; +use arrow::array::{ArrayRef, AsArray, Int64Array}; +use arrow::datatypes::{ArrowNativeTypeOp, DataType, Float64Type}; use datafusion_common::{ arrow_datafusion_err, exec_datafusion_err, exec_err, plan_datafusion_err, DataFusionError, Result, ScalarValue, @@ -27,13 +31,7 @@ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ColumnarValue, Documentation, Expr, ScalarUDF, TypeSignature}; - -use arrow::array::{ArrayRef, Float64Array, Int64Array}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; -use std::sync::{Arc, OnceLock}; - -use super::log::LogFunc; #[derive(Debug)] pub struct PowerFunc { @@ -90,15 +88,16 @@ impl ScalarUDFImpl for PowerFunc { let args = ColumnarValue::values_to_arrays(args)?; let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "base", - "exponent", - Float64Array, - { f64::powf } - )), - + DataType::Float64 => { + let bases = args[0].as_primitive::(); + let exponents = args[1].as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float64Type>( + bases, + exponents, + f64::powf, + )?; + Arc::new(result) as _ + } DataType::Int64 => { let bases = downcast_arg!(&args[0], "base", Int64Array); let exponents = downcast_arg!(&args[1], "exponent", Int64Array); @@ -116,7 +115,7 @@ impl ScalarUDFImpl for PowerFunc { _ => Ok(None), }) .collect::>() - .map(Arc::new)? as ArrayRef + .map(Arc::new)? as _ } other => { @@ -195,6 +194,7 @@ fn is_log(func: &ScalarUDF) -> bool { #[cfg(test)] mod tests { + use arrow::array::Float64Array; use datafusion_common::cast::{as_float64_array, as_int64_array}; use super::*; diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index fec1f1ce1aee..cf0f53a80a43 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -20,13 +20,11 @@ use std::sync::{Arc, OnceLock}; use crate::utils::make_scalar_function; -use arrow::array::{ArrayRef, Float32Array, Float64Array, Int32Array}; +use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; use arrow::compute::{cast_with_options, CastOptions}; -use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Float32, Float64, Int32}; -use datafusion_common::{ - exec_datafusion_err, exec_err, DataFusionError, Result, ScalarValue, -}; +use arrow::datatypes::{DataType, Float32Type, Float64Type, Int32Type}; +use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue}; use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; @@ -148,17 +146,13 @@ pub fn round(args: &[ArrayRef]) -> Result { ) })?; - Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "value", - Float64Array, - { - |value: f64| { - (value * 10.0_f64.powi(decimal_places)).round() - / 10.0_f64.powi(decimal_places) - } - } - )) as ArrayRef) + let result = args[0] + .as_primitive::() + .unary::<_, Float64Type>(|value: f64| { + (value * 10.0_f64.powi(decimal_places)).round() + / 10.0_f64.powi(decimal_places) + }); + Ok(Arc::new(result) as _) } ColumnarValue::Array(decimal_places) => { let options = CastOptions { @@ -169,20 +163,18 @@ pub fn round(args: &[ArrayRef]) -> Result { .map_err(|e| { exec_datafusion_err!("Invalid values for decimal places: {e}") })?; - Ok(Arc::new(make_function_inputs2!( - &args[0], + + let values = args[0].as_primitive::(); + let decimal_places = decimal_places.as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float64Type>( + values, decimal_places, - "value", - "decimal_places", - Float64Array, - Int32Array, - { - |value: f64, decimal_places: i32| { - (value * 10.0_f64.powi(decimal_places)).round() - / 10.0_f64.powi(decimal_places) - } - } - )) as ArrayRef) + |value, decimal_places| { + (value * 10.0_f64.powi(decimal_places)).round() + / 10.0_f64.powi(decimal_places) + }, + )?; + Ok(Arc::new(result) as _) } _ => { exec_err!("round function requires a scalar or array for decimal_places") @@ -196,18 +188,13 @@ pub fn round(args: &[ArrayRef]) -> Result { "Invalid value for decimal places: {decimal_places}: {e}" ) })?; - - Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "value", - Float32Array, - { - |value: f32| { - (value * 10.0_f32.powi(decimal_places)).round() - / 10.0_f32.powi(decimal_places) - } - } - )) as ArrayRef) + let result = args[0] + .as_primitive::() + .unary::<_, Float32Type>(|value: f32| { + (value * 10.0_f32.powi(decimal_places)).round() + / 10.0_f32.powi(decimal_places) + }); + Ok(Arc::new(result) as _) } ColumnarValue::Array(_) => { let ColumnarValue::Array(decimal_places) = @@ -218,20 +205,17 @@ pub fn round(args: &[ArrayRef]) -> Result { panic!("Unexpected result of ColumnarValue::Array.cast") }; - Ok(Arc::new(make_function_inputs2!( - &args[0], + let values = args[0].as_primitive::(); + let decimal_places = decimal_places.as_primitive::(); + let result: PrimitiveArray = arrow::compute::binary( + values, decimal_places, - "value", - "decimal_places", - Float32Array, - Int32Array, - { - |value: f32, decimal_places: i32| { - (value * 10.0_f32.powi(decimal_places)).round() - / 10.0_f32.powi(decimal_places) - } - } - )) as ArrayRef) + |value, decimal_places| { + (value * 10.0_f32.powi(decimal_places)).round() + / 10.0_f32.powi(decimal_places) + }, + )?; + Ok(Arc::new(result) as _) } _ => { exec_err!("round function requires a scalar or array for decimal_places") From 2de6e2907f773fe05afd97eeb3311e5efde37faa Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 21 Oct 2024 09:49:20 -0400 Subject: [PATCH 035/110] Minor: Update release instructions (#13024) --- dev/release/README.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/dev/release/README.md b/dev/release/README.md index bd9c0621fdbc..0e0daa9d6c40 100644 --- a/dev/release/README.md +++ b/dev/release/README.md @@ -260,19 +260,22 @@ Verify that the Cargo.toml in the tarball contains the correct version ```shell (cd datafusion/common && cargo publish) +(cd datafusion/expr-common && cargo publish) +(cd datafusion/physical-expr-common && cargo publish) +(cd datafusion/functions-aggregate-common && cargo publish) (cd datafusion/expr && cargo publish) (cd datafusion/execution && cargo publish) -(cd datafusion/physical-expr-common && cargo publish) -(cd datafusion/functions-aggregate && cargo publish) (cd datafusion/physical-expr && cargo publish) (cd datafusion/functions && cargo publish) +(cd datafusion/functions-aggregate && cargo publish) +(cd datafusion/functions-window && cargo publish) (cd datafusion/functions-nested && cargo publish) (cd datafusion/sql && cargo publish) (cd datafusion/optimizer && cargo publish) (cd datafusion/common-runtime && cargo publish) -(cd datafusion/catalog && cargo publish) (cd datafusion/physical-plan && cargo publish) (cd datafusion/physical-optimizer && cargo publish) +(cd datafusion/catalog && cargo publish) (cd datafusion/core && cargo publish) (cd datafusion/proto-common && cargo publish) (cd datafusion/proto && cargo publish) From 701cb00e3b9b24b00908ad4a07ed01d68ccfa7c2 Mon Sep 17 00:00:00 2001 From: Eduard Karacharov Date: Mon, 21 Oct 2024 18:48:24 +0300 Subject: [PATCH 036/110] fix: join swap for projected semi/anti joins (#13022) --- .../src/physical_optimizer/join_selection.rs | 85 ++++++++++++++----- 1 file changed, 65 insertions(+), 20 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index dfaa7dbb8910..1c63df1f0281 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -140,20 +140,32 @@ fn swap_join_projection( left_schema_len: usize, right_schema_len: usize, projection: Option<&Vec>, + join_type: &JoinType, ) -> Option> { - projection.map(|p| { - p.iter() - .map(|i| { - // If the index is less than the left schema length, it is from the left schema, so we add the right schema length to it. - // Otherwise, it is from the right schema, so we subtract the left schema length from it. - if *i < left_schema_len { - *i + right_schema_len - } else { - *i - left_schema_len - } - }) - .collect() - }) + match join_type { + // For Anti/Semi join types, projection should remain unmodified, + // since these joins output schema remains the same after swap + JoinType::LeftAnti + | JoinType::LeftSemi + | JoinType::RightAnti + | JoinType::RightSemi => projection.cloned(), + + _ => projection.map(|p| { + p.iter() + .map(|i| { + // If the index is less than the left schema length, it is from + // the left schema, so we add the right schema length to it. + // Otherwise, it is from the right schema, so we subtract the left + // schema length from it. + if *i < left_schema_len { + *i + right_schema_len + } else { + *i - left_schema_len + } + }) + .collect() + }), + } } /// This function swaps the inputs of the given join operator. @@ -179,6 +191,7 @@ pub fn swap_hash_join( left.schema().fields().len(), right.schema().fields().len(), hash_join.projection.as_ref(), + hash_join.join_type(), ), partition_mode, hash_join.null_equals_null(), @@ -1289,27 +1302,59 @@ mod tests_statistical { ); } + #[rstest( + join_type, projection, small_on_right, + case::inner(JoinType::Inner, vec![1], true), + case::left(JoinType::Left, vec![1], true), + case::right(JoinType::Right, vec![1], true), + case::full(JoinType::Full, vec![1], true), + case::left_anti(JoinType::LeftAnti, vec![0], false), + case::left_semi(JoinType::LeftSemi, vec![0], false), + case::right_anti(JoinType::RightAnti, vec![0], true), + case::right_semi(JoinType::RightSemi, vec![0], true), + )] #[tokio::test] - async fn test_hash_join_swap_on_joins_with_projections() -> Result<()> { + async fn test_hash_join_swap_on_joins_with_projections( + join_type: JoinType, + projection: Vec, + small_on_right: bool, + ) -> Result<()> { let (big, small) = create_big_and_small(); + + let left = if small_on_right { &big } else { &small }; + let right = if small_on_right { &small } else { &big }; + + let left_on = if small_on_right { + "big_col" + } else { + "small_col" + }; + let right_on = if small_on_right { + "small_col" + } else { + "big_col" + }; + let join = Arc::new(HashJoinExec::try_new( - Arc::clone(&big), - Arc::clone(&small), + Arc::clone(left), + Arc::clone(right), vec![( - Arc::new(Column::new_with_schema("big_col", &big.schema())?), - Arc::new(Column::new_with_schema("small_col", &small.schema())?), + Arc::new(Column::new_with_schema(left_on, &left.schema())?), + Arc::new(Column::new_with_schema(right_on, &right.schema())?), )], None, - &JoinType::Inner, - Some(vec![1]), + &join_type, + Some(projection), PartitionMode::Partitioned, false, )?); + let swapped = swap_hash_join(&join.clone(), PartitionMode::Partitioned) .expect("swap_hash_join must support joins with projections"); let swapped_join = swapped.as_any().downcast_ref::().expect( "ProjectionExec won't be added above if HashJoinExec contains embedded projection", ); + assert_eq!(swapped_join.projection, Some(vec![0_usize])); assert_eq!(swapped.schema().fields.len(), 1); assert_eq!(swapped.schema().fields[0].name(), "small_col"); From 45a316ca7e3b5faa21cc80bab6d879c2be3d90c8 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 21 Oct 2024 21:47:05 +0200 Subject: [PATCH 037/110] Extract CSE logic to `datafusion_common` (#13002) * Extract CSE logic * address review comments, move `HashNode` to `datafusion_common::cse`, shorter names for eliminator and controller, change `CSE::extract_common_nodes()` to return `Result>` (instead of `Result>>`) --- datafusion-cli/Cargo.lock | 21 +- datafusion/common/Cargo.toml | 1 + datafusion/common/src/cse.rs | 800 ++++++++++ datafusion/common/src/lib.rs | 1 + datafusion/common/src/tree_node.rs | 135 +- datafusion/expr/src/expr.rs | 79 +- .../optimizer/src/common_subexpr_eliminate.rs | 1361 +++++------------ 7 files changed, 1314 insertions(+), 1084 deletions(-) create mode 100644 datafusion/common/src/cse.rs diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 612209fdd922..401f203dd931 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -406,9 +406,9 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.16" +version = "0.4.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "103db485efc3e41214fe4fda9f3dbeae2eb9082f48fd236e6095627a9422066e" +checksum = "0cb8f1d480b0ea3783ab015936d2a55c87e219676f0c0b7dec61494043f21857" dependencies = [ "bzip2", "flate2", @@ -917,9 +917,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.30" +version = "1.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b16803a61b81d9eabb7eae2588776c4c1e584b738ede45fdbb4c972cec1e9945" +checksum = "c2e7962b54006dcfcc61cb72735f4d89bb97061dd6a7ed882ec6b8ee53714c6f" dependencies = [ "jobserver", "libc", @@ -1293,6 +1293,7 @@ dependencies = [ "chrono", "half", "hashbrown 0.14.5", + "indexmap", "instant", "libc", "num_cpus", @@ -2615,9 +2616,9 @@ dependencies = [ [[package]] name = "object_store" -version = "0.11.0" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25a0c4b3a0e31f8b66f71ad8064521efa773910196e2cde791436f13409f3b45" +checksum = "6eb4c22c6154a1e759d7099f9ffad7cc5ef8245f9efbab4a41b92623079c82f3" dependencies = [ "async-trait", "base64 0.22.1", @@ -3411,9 +3412,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.130" +version = "1.0.132" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "610f75ff4a8e3cb29b85da56eabdd1bff5b06739059a4b8e2967fef32e5d9944" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" dependencies = [ "itoa", "memchr", @@ -3605,9 +3606,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.79" +version = "2.0.82" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" +checksum = "83540f837a8afc019423a8edb95b52a8effe46957ee402287f4292fae35be021" dependencies = [ "proc-macro2", "quote", diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 1ac27b40c219..0747672a18f6 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -56,6 +56,7 @@ arrow-schema = { workspace = true } chrono = { workspace = true } half = { workspace = true } hashbrown = { workspace = true } +indexmap = { workspace = true } libc = "0.2.140" num_cpus = { workspace = true } object_store = { workspace = true, optional = true } diff --git a/datafusion/common/src/cse.rs b/datafusion/common/src/cse.rs new file mode 100644 index 000000000000..453ae26e7333 --- /dev/null +++ b/datafusion/common/src/cse.rs @@ -0,0 +1,800 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Common Subexpression Elimination logic implemented in [`CSE`] can be controlled with +//! a [`CSEController`], that defines how to eliminate common subtrees from a particular +//! [`TreeNode`] tree. + +use crate::hash_utils::combine_hashes; +use crate::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + TreeNodeVisitor, +}; +use crate::Result; +use indexmap::IndexMap; +use std::collections::HashMap; +use std::hash::{BuildHasher, Hash, Hasher, RandomState}; +use std::marker::PhantomData; +use std::sync::Arc; + +/// Hashes the direct content of an [`TreeNode`] without recursing into its children. +/// +/// This method is useful to incrementally compute hashes, such as in [`CSE`] which builds +/// a deep hash of a node and its descendants during the bottom-up phase of the first +/// traversal and so avoid computing the hash of the node and then the hash of its +/// descendants separately. +/// +/// If a node doesn't have any children then the value returned by `hash_node()` is +/// similar to '.hash()`, but not necessarily returns the same value. +pub trait HashNode { + fn hash_node(&self, state: &mut H); +} + +impl HashNode for Arc { + fn hash_node(&self, state: &mut H) { + (**self).hash_node(state); + } +} + +/// Identifier that represents a [`TreeNode`] tree. +/// +/// This identifier is designed to be efficient and "hash", "accumulate", "equal" and +/// "have no collision (as low as possible)" +#[derive(Debug, Eq, PartialEq)] +struct Identifier<'n, N> { + // Hash of `node` built up incrementally during the first, visiting traversal. + // Its value is not necessarily equal to default hash of the node. E.g. it is not + // equal to `expr.hash()` if the node is `Expr`. + hash: u64, + node: &'n N, +} + +impl Clone for Identifier<'_, N> { + fn clone(&self) -> Self { + *self + } +} +impl Copy for Identifier<'_, N> {} + +impl Hash for Identifier<'_, N> { + fn hash(&self, state: &mut H) { + state.write_u64(self.hash); + } +} + +impl<'n, N: HashNode> Identifier<'n, N> { + fn new(node: &'n N, random_state: &RandomState) -> Self { + let mut hasher = random_state.build_hasher(); + node.hash_node(&mut hasher); + let hash = hasher.finish(); + Self { hash, node } + } + + fn combine(mut self, other: Option) -> Self { + other.map_or(self, |other_id| { + self.hash = combine_hashes(self.hash, other_id.hash); + self + }) + } +} + +/// A cache that contains the postorder index and the identifier of [`TreeNode`]s by the +/// preorder index of the nodes. +/// +/// This cache is filled by [`CSEVisitor`] during the first traversal and is +/// used by [`CSERewriter`] during the second traversal. +/// +/// The purpose of this cache is to quickly find the identifier of a node during the +/// second traversal. +/// +/// Elements in this array are added during `f_down` so the indexes represent the preorder +/// index of nodes and thus element 0 belongs to the root of the tree. +/// +/// The elements of the array are tuples that contain: +/// - Postorder index that belongs to the preorder index. Assigned during `f_up`, start +/// from 0. +/// - The optional [`Identifier`] of the node. If none the node should not be considered +/// for CSE. +/// +/// # Example +/// An expression tree like `(a + b)` would have the following `IdArray`: +/// ```text +/// [ +/// (2, Some(Identifier(hash_of("a + b"), &"a + b"))), +/// (1, Some(Identifier(hash_of("a"), &"a"))), +/// (0, Some(Identifier(hash_of("b"), &"b"))) +/// ] +/// ``` +type IdArray<'n, N> = Vec<(usize, Option>)>; + +/// A map that contains the number of normal and conditional occurrences of [`TreeNode`]s +/// by their identifiers. +type NodeStats<'n, N> = HashMap, (usize, usize)>; + +/// A map that contains the common [`TreeNode`]s and their alias by their identifiers, +/// extracted during the second, rewriting traversal. +type CommonNodes<'n, N> = IndexMap, (N, String)>; + +type ChildrenList = (Vec, Vec); + +/// The [`TreeNode`] specific definition of elimination. +pub trait CSEController { + /// The type of the tree nodes. + type Node; + + /// Splits the children to normal and conditionally evaluated ones or returns `None` + /// if all are always evaluated. + fn conditional_children(node: &Self::Node) -> Option>; + + // Returns true if a node is valid. If a node is invalid then it can't be eliminated. + // Validity is propagated up which means no subtree can be eliminated that contains + // an invalid node. + // (E.g. volatile expressions are not valid and subtrees containing such a node can't + // be extracted.) + fn is_valid(node: &Self::Node) -> bool; + + // Returns true if a node should be ignored during CSE. Contrary to validity of a node, + // it is not propagated up. + fn is_ignored(&self, node: &Self::Node) -> bool; + + // Generates a new name for the extracted subtree. + fn generate_alias(&self) -> String; + + // Replaces a node to the generated alias. + fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node; + + // A helper method called on each node during top-down traversal during the second, + // rewriting traversal of CSE. + fn rewrite_f_down(&mut self, _node: &Self::Node) {} + + // A helper method called on each node during bottom-up traversal during the second, + // rewriting traversal of CSE. + fn rewrite_f_up(&mut self, _node: &Self::Node) {} +} + +/// The result of potentially rewriting a list of [`TreeNode`]s to eliminate common +/// subtrees. +#[derive(Debug)] +pub enum FoundCommonNodes { + /// No common [`TreeNode`]s were found + No { original_nodes_list: Vec> }, + + /// Common [`TreeNode`]s were found + Yes { + /// extracted common [`TreeNode`] + common_nodes: Vec<(N, String)>, + + /// new [`TreeNode`]s with common subtrees replaced + new_nodes_list: Vec>, + + /// original [`TreeNode`]s + original_nodes_list: Vec>, + }, +} + +/// Go through a [`TreeNode`] tree and generate identifiers for each subtrees. +/// +/// An identifier contains information of the [`TreeNode`] itself and its subtrees. +/// This visitor implementation use a stack `visit_stack` to track traversal, which +/// lets us know when a subtree's visiting is finished. When `pre_visit` is called +/// (traversing to a new node), an `EnterMark` and an `NodeItem` will be pushed into stack. +/// And try to pop out a `EnterMark` on leaving a node (`f_up()`). All `NodeItem` +/// before the first `EnterMark` is considered to be sub-tree of the leaving node. +/// +/// This visitor also records identifier in `id_array`. Makes the following traverse +/// pass can get the identifier of a node without recalculate it. We assign each node +/// in the tree a series number, start from 1, maintained by `series_number`. +/// Series number represents the order we left (`f_up()`) a node. Has the property +/// that child node's series number always smaller than parent's. While `id_array` is +/// organized in the order we enter (`f_down()`) a node. `node_count` helps us to +/// get the index of `id_array` for each node. +/// +/// A [`TreeNode`] without any children (column, literal etc.) will not have identifier +/// because they should not be recognized as common subtree. +struct CSEVisitor<'a, 'n, N, C: CSEController> { + /// statistics of [`TreeNode`]s + node_stats: &'a mut NodeStats<'n, N>, + + /// cache to speed up second traversal + id_array: &'a mut IdArray<'n, N>, + + /// inner states + visit_stack: Vec>, + + /// preorder index, start from 0. + down_index: usize, + + /// postorder index, start from 0. + up_index: usize, + + /// a [`RandomState`] to generate hashes during the first traversal + random_state: &'a RandomState, + + /// a flag to indicate that common [`TreeNode`]s found + found_common: bool, + + /// if we are in a conditional branch. A conditional branch means that the [`TreeNode`] + /// might not be executed depending on the runtime values of other [`TreeNode`]s, and + /// thus can not be extracted as a common [`TreeNode`]. + conditional: bool, + + controller: &'a C, +} + +/// Record item that used when traversing a [`TreeNode`] tree. +enum VisitRecord<'n, N> { + /// Marks the beginning of [`TreeNode`]. It contains: + /// - The post-order index assigned during the first, visiting traversal. + EnterMark(usize), + + /// Marks an accumulated subtree. It contains: + /// - The accumulated identifier of a subtree. + /// - A accumulated boolean flag if the subtree is valid for CSE. + /// The flag is propagated up from children to parent. (E.g. volatile expressions + /// are not valid and can't be extracted, but non-volatile children of volatile + /// expressions can be extracted.) + NodeItem(Identifier<'n, N>, bool), +} + +impl<'n, N: TreeNode + HashNode, C: CSEController> CSEVisitor<'_, 'n, N, C> { + /// Find the first `EnterMark` in the stack, and accumulates every `NodeItem` before + /// it. Returns a tuple that contains: + /// - The pre-order index of the [`TreeNode`] we marked. + /// - The accumulated identifier of the children of the marked [`TreeNode`]. + /// - An accumulated boolean flag from the children of the marked [`TreeNode`] if all + /// children are valid for CSE (i.e. it is safe to extract the [`TreeNode`] as a + /// common [`TreeNode`] from its children POV). + /// (E.g. if any of the children of the marked expression is not valid (e.g. is + /// volatile) then the expression is also not valid, so we can propagate this + /// information up from children to parents via `visit_stack` during the first, + /// visiting traversal and no need to test the expression's validity beforehand with + /// an extra traversal). + fn pop_enter_mark(&mut self) -> (usize, Option>, bool) { + let mut node_id = None; + let mut is_valid = true; + + while let Some(item) = self.visit_stack.pop() { + match item { + VisitRecord::EnterMark(down_index) => { + return (down_index, node_id, is_valid); + } + VisitRecord::NodeItem(sub_node_id, sub_node_is_valid) => { + node_id = Some(sub_node_id.combine(node_id)); + is_valid &= sub_node_is_valid; + } + } + } + unreachable!("EnterMark should paired with NodeItem"); + } +} + +impl<'n, N: TreeNode + HashNode + Eq, C: CSEController> TreeNodeVisitor<'n> + for CSEVisitor<'_, 'n, N, C> +{ + type Node = N; + + fn f_down(&mut self, node: &'n Self::Node) -> Result { + self.id_array.push((0, None)); + self.visit_stack + .push(VisitRecord::EnterMark(self.down_index)); + self.down_index += 1; + + // If a node can short-circuit then some of its children might not be executed so + // count the occurrence either normal or conditional. + Ok(if self.conditional { + // If we are already in a conditionally evaluated subtree then continue + // traversal. + TreeNodeRecursion::Continue + } else { + // If we are already in a node that can short-circuit then start new + // traversals on its normal conditional children. + match C::conditional_children(node) { + Some((normal, conditional)) => { + normal + .into_iter() + .try_for_each(|n| n.visit(self).map(|_| ()))?; + self.conditional = true; + conditional + .into_iter() + .try_for_each(|n| n.visit(self).map(|_| ()))?; + self.conditional = false; + + TreeNodeRecursion::Jump + } + + // In case of non-short-circuit node continue the traversal. + _ => TreeNodeRecursion::Continue, + } + }) + } + + fn f_up(&mut self, node: &'n Self::Node) -> Result { + let (down_index, sub_node_id, sub_node_is_valid) = self.pop_enter_mark(); + + let node_id = Identifier::new(node, self.random_state).combine(sub_node_id); + let is_valid = C::is_valid(node) && sub_node_is_valid; + + self.id_array[down_index].0 = self.up_index; + if is_valid && !self.controller.is_ignored(node) { + self.id_array[down_index].1 = Some(node_id); + let (count, conditional_count) = + self.node_stats.entry(node_id).or_insert((0, 0)); + if self.conditional { + *conditional_count += 1; + } else { + *count += 1; + } + if *count > 1 || (*count == 1 && *conditional_count > 0) { + self.found_common = true; + } + } + self.visit_stack + .push(VisitRecord::NodeItem(node_id, is_valid)); + self.up_index += 1; + + Ok(TreeNodeRecursion::Continue) + } +} + +/// Rewrite a [`TreeNode`] tree by replacing detected common subtrees with the +/// corresponding temporary [`TreeNode`], that column contains the evaluate result of +/// replaced [`TreeNode`] tree. +struct CSERewriter<'a, 'n, N, C: CSEController> { + /// statistics of [`TreeNode`]s + node_stats: &'a NodeStats<'n, N>, + + /// cache to speed up second traversal + id_array: &'a IdArray<'n, N>, + + /// common [`TreeNode`]s, that are replaced during the second traversal, are collected + /// to this map + common_nodes: &'a mut CommonNodes<'n, N>, + + // preorder index, starts from 0. + down_index: usize, + + controller: &'a mut C, +} + +impl> TreeNodeRewriter + for CSERewriter<'_, '_, N, C> +{ + type Node = N; + + fn f_down(&mut self, node: Self::Node) -> Result> { + self.controller.rewrite_f_down(&node); + + let (up_index, node_id) = self.id_array[self.down_index]; + self.down_index += 1; + + // Handle nodes with identifiers only + if let Some(node_id) = node_id { + let (count, conditional_count) = self.node_stats.get(&node_id).unwrap(); + if *count > 1 || *count == 1 && *conditional_count > 0 { + // step index to skip all sub-node (which has smaller series number). + while self.down_index < self.id_array.len() + && self.id_array[self.down_index].0 < up_index + { + self.down_index += 1; + } + + let (node, alias) = + self.common_nodes.entry(node_id).or_insert_with(|| { + let node_alias = self.controller.generate_alias(); + (node, node_alias) + }); + + let rewritten = self.controller.rewrite(node, alias); + + return Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump)); + } + } + + Ok(Transformed::no(node)) + } + + fn f_up(&mut self, node: Self::Node) -> Result> { + self.controller.rewrite_f_up(&node); + + Ok(Transformed::no(node)) + } +} + +/// The main entry point of Common Subexpression Elimination. +/// +/// [`CSE`] requires a [`CSEController`], that defines how common subtrees of a particular +/// [`TreeNode`] tree can be eliminated. The elimination process can be started with the +/// [`CSE::extract_common_nodes()`] method. +pub struct CSE> { + random_state: RandomState, + phantom_data: PhantomData, + controller: C, +} + +impl> CSE { + pub fn new(controller: C) -> Self { + Self { + random_state: RandomState::new(), + phantom_data: PhantomData, + controller, + } + } + + /// Add an identifier to `id_array` for every [`TreeNode`] in this tree. + fn node_to_id_array<'n>( + &self, + node: &'n N, + node_stats: &mut NodeStats<'n, N>, + id_array: &mut IdArray<'n, N>, + ) -> Result { + let mut visitor = CSEVisitor { + node_stats, + id_array, + visit_stack: vec![], + down_index: 0, + up_index: 0, + random_state: &self.random_state, + found_common: false, + conditional: false, + controller: &self.controller, + }; + node.visit(&mut visitor)?; + + Ok(visitor.found_common) + } + + /// Returns the identifier list for each element in `nodes` and a flag to indicate if + /// rewrite phase of CSE make sense. + /// + /// Returns and array with 1 element for each input node in `nodes` + /// + /// Each element is itself the result of [`CSE::node_to_id_array`] for that node + /// (e.g. the identifiers for each node in the tree) + fn to_arrays<'n>( + &self, + nodes: &'n [N], + node_stats: &mut NodeStats<'n, N>, + ) -> Result<(bool, Vec>)> { + let mut found_common = false; + nodes + .iter() + .map(|n| { + let mut id_array = vec![]; + self.node_to_id_array(n, node_stats, &mut id_array) + .map(|fc| { + found_common |= fc; + + id_array + }) + }) + .collect::>>() + .map(|id_arrays| (found_common, id_arrays)) + } + + /// Replace common subtrees in `node` with the corresponding temporary + /// [`TreeNode`], updating `common_nodes` with any replaced [`TreeNode`] + fn replace_common_node<'n>( + &mut self, + node: N, + id_array: &IdArray<'n, N>, + node_stats: &NodeStats<'n, N>, + common_nodes: &mut CommonNodes<'n, N>, + ) -> Result { + if id_array.is_empty() { + Ok(Transformed::no(node)) + } else { + node.rewrite(&mut CSERewriter { + node_stats, + id_array, + common_nodes, + down_index: 0, + controller: &mut self.controller, + }) + } + .data() + } + + /// Replace common subtrees in `nodes_list` with the corresponding temporary + /// [`TreeNode`], updating `common_nodes` with any replaced [`TreeNode`]. + fn rewrite_nodes_list<'n>( + &mut self, + nodes_list: Vec>, + arrays_list: &[Vec>], + node_stats: &NodeStats<'n, N>, + common_nodes: &mut CommonNodes<'n, N>, + ) -> Result>> { + nodes_list + .into_iter() + .zip(arrays_list.iter()) + .map(|(nodes, arrays)| { + nodes + .into_iter() + .zip(arrays.iter()) + .map(|(node, id_array)| { + self.replace_common_node(node, id_array, node_stats, common_nodes) + }) + .collect::>>() + }) + .collect::>>() + } + + /// Extracts common [`TreeNode`]s and rewrites `nodes_list`. + /// + /// Returns [`FoundCommonNodes`] recording the result of the extraction. + pub fn extract_common_nodes( + &mut self, + nodes_list: Vec>, + ) -> Result> { + let mut found_common = false; + let mut node_stats = NodeStats::new(); + let id_arrays_list = nodes_list + .iter() + .map(|nodes| { + self.to_arrays(nodes, &mut node_stats) + .map(|(fc, id_arrays)| { + found_common |= fc; + + id_arrays + }) + }) + .collect::>>()?; + if found_common { + let mut common_nodes = CommonNodes::new(); + let new_nodes_list = self.rewrite_nodes_list( + // Must clone the list of nodes as Identifiers use references to original + // nodes so we have to keep them intact. + nodes_list.clone(), + &id_arrays_list, + &node_stats, + &mut common_nodes, + )?; + assert!(!common_nodes.is_empty()); + + Ok(FoundCommonNodes::Yes { + common_nodes: common_nodes.into_values().collect(), + new_nodes_list, + original_nodes_list: nodes_list, + }) + } else { + Ok(FoundCommonNodes::No { + original_nodes_list: nodes_list, + }) + } + } +} + +#[cfg(test)] +mod test { + use crate::alias::AliasGenerator; + use crate::cse::{CSEController, HashNode, IdArray, Identifier, NodeStats, CSE}; + use crate::tree_node::tests::TestTreeNode; + use crate::Result; + use std::collections::HashSet; + use std::hash::{Hash, Hasher}; + + const CSE_PREFIX: &str = "__common_node"; + + #[derive(Clone, Copy)] + pub enum TestTreeNodeMask { + Normal, + NormalAndAggregates, + } + + pub struct TestTreeNodeCSEController<'a> { + alias_generator: &'a AliasGenerator, + mask: TestTreeNodeMask, + } + + impl<'a> TestTreeNodeCSEController<'a> { + fn new(alias_generator: &'a AliasGenerator, mask: TestTreeNodeMask) -> Self { + Self { + alias_generator, + mask, + } + } + } + + impl CSEController for TestTreeNodeCSEController<'_> { + type Node = TestTreeNode; + + fn conditional_children( + _: &Self::Node, + ) -> Option<(Vec<&Self::Node>, Vec<&Self::Node>)> { + None + } + + fn is_valid(_node: &Self::Node) -> bool { + true + } + + fn is_ignored(&self, node: &Self::Node) -> bool { + let is_leaf = node.is_leaf(); + let is_aggr = node.data == "avg" || node.data == "sum"; + + match self.mask { + TestTreeNodeMask::Normal => is_leaf || is_aggr, + TestTreeNodeMask::NormalAndAggregates => is_leaf, + } + } + + fn generate_alias(&self) -> String { + self.alias_generator.next(CSE_PREFIX) + } + + fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node { + TestTreeNode::new_leaf(format!("alias({}, {})", node.data, alias)) + } + } + + impl HashNode for TestTreeNode { + fn hash_node(&self, state: &mut H) { + self.data.hash(state); + } + } + + #[test] + fn id_array_visitor() -> Result<()> { + let alias_generator = AliasGenerator::new(); + let eliminator = CSE::new(TestTreeNodeCSEController::new( + &alias_generator, + TestTreeNodeMask::Normal, + )); + + let a_plus_1 = TestTreeNode::new( + vec![ + TestTreeNode::new_leaf("a".to_string()), + TestTreeNode::new_leaf("1".to_string()), + ], + "+".to_string(), + ); + let avg_c = TestTreeNode::new( + vec![TestTreeNode::new_leaf("c".to_string())], + "avg".to_string(), + ); + let sum_a_plus_1 = TestTreeNode::new(vec![a_plus_1], "sum".to_string()); + let sum_a_plus_1_minus_avg_c = + TestTreeNode::new(vec![sum_a_plus_1, avg_c], "-".to_string()); + let root = TestTreeNode::new( + vec![ + sum_a_plus_1_minus_avg_c, + TestTreeNode::new_leaf("2".to_string()), + ], + "*".to_string(), + ); + + let [sum_a_plus_1_minus_avg_c, _] = root.children.as_slice() else { + panic!("Cannot extract subtree references") + }; + let [sum_a_plus_1, avg_c] = sum_a_plus_1_minus_avg_c.children.as_slice() else { + panic!("Cannot extract subtree references") + }; + let [a_plus_1] = sum_a_plus_1.children.as_slice() else { + panic!("Cannot extract subtree references") + }; + + // skip aggregates + let mut id_array = vec![]; + eliminator.node_to_id_array(&root, &mut NodeStats::new(), &mut id_array)?; + + // Collect distinct hashes and set them to 0 in `id_array` + fn collect_hashes( + id_array: &mut IdArray<'_, TestTreeNode>, + ) -> HashSet { + id_array + .iter_mut() + .flat_map(|(_, id_option)| { + id_option.as_mut().map(|node_id| { + let hash = node_id.hash; + node_id.hash = 0; + hash + }) + }) + .collect::>() + } + + let hashes = collect_hashes(&mut id_array); + assert_eq!(hashes.len(), 3); + + let expected = vec![ + ( + 8, + Some(Identifier { + hash: 0, + node: &root, + }), + ), + ( + 6, + Some(Identifier { + hash: 0, + node: sum_a_plus_1_minus_avg_c, + }), + ), + (3, None), + ( + 2, + Some(Identifier { + hash: 0, + node: a_plus_1, + }), + ), + (0, None), + (1, None), + (5, None), + (4, None), + (7, None), + ]; + assert_eq!(expected, id_array); + + // include aggregates + let eliminator = CSE::new(TestTreeNodeCSEController::new( + &alias_generator, + TestTreeNodeMask::NormalAndAggregates, + )); + + let mut id_array = vec![]; + eliminator.node_to_id_array(&root, &mut NodeStats::new(), &mut id_array)?; + + let hashes = collect_hashes(&mut id_array); + assert_eq!(hashes.len(), 5); + + let expected = vec![ + ( + 8, + Some(Identifier { + hash: 0, + node: &root, + }), + ), + ( + 6, + Some(Identifier { + hash: 0, + node: sum_a_plus_1_minus_avg_c, + }), + ), + ( + 3, + Some(Identifier { + hash: 0, + node: sum_a_plus_1, + }), + ), + ( + 2, + Some(Identifier { + hash: 0, + node: a_plus_1, + }), + ), + (0, None), + (1, None), + ( + 5, + Some(Identifier { + hash: 0, + node: avg_c, + }), + ), + (4, None), + (7, None), + ]; + assert_eq!(expected, id_array); + + Ok(()) + } +} diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 8323f5efc86d..e4575038ab98 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -31,6 +31,7 @@ mod unnest; pub mod alias; pub mod cast; pub mod config; +pub mod cse; pub mod display; pub mod error; pub mod file_options; diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index b4d3251fd263..563f1fa85614 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -1027,7 +1027,7 @@ impl TreeNode for T { } #[cfg(test)] -mod tests { +pub(crate) mod tests { use std::collections::HashMap; use std::fmt::Display; @@ -1037,16 +1037,27 @@ mod tests { }; use crate::Result; - #[derive(Debug, Eq, Hash, PartialEq)] - struct TestTreeNode { - children: Vec>, - data: T, + #[derive(Debug, Eq, Hash, PartialEq, Clone)] + pub struct TestTreeNode { + pub(crate) children: Vec>, + pub(crate) data: T, } impl TestTreeNode { - fn new(children: Vec>, data: T) -> Self { + pub(crate) fn new(children: Vec>, data: T) -> Self { Self { children, data } } + + pub(crate) fn new_leaf(data: T) -> Self { + Self { + children: vec![], + data, + } + } + + pub(crate) fn is_leaf(&self) -> bool { + self.children.is_empty() + } } impl TreeNode for TestTreeNode { @@ -1086,12 +1097,12 @@ mod tests { // | // A fn test_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "a".to_string()); - let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_a = TestTreeNode::new_leaf("a".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string()); let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); @@ -1130,13 +1141,13 @@ mod tests { // Expected transformed tree after a combined traversal fn transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); - let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string()); + let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string()); @@ -1146,12 +1157,12 @@ mod tests { // Expected transformed tree after a top-down traversal fn transformed_down_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_down(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_down(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string()); + let node_h = TestTreeNode::new_leaf("f_down(h)".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1160,12 +1171,12 @@ mod tests { // Expected transformed tree after a bottom-up traversal fn transformed_up_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "f_up(h)".to_string()); + let node_h = TestTreeNode::new_leaf("f_up(h)".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string()); @@ -1202,12 +1213,12 @@ mod tests { } fn f_down_jump_on_a_transformed_down_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_down(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_down(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string()); + let node_h = TestTreeNode::new_leaf("f_down(h)".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1236,12 +1247,12 @@ mod tests { } fn f_down_jump_on_e_transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "a".to_string()); - let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_a = TestTreeNode::new_leaf("a".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); - let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string()); + let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string()); @@ -1250,12 +1261,12 @@ mod tests { } fn f_down_jump_on_e_transformed_down_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "a".to_string()); - let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_a = TestTreeNode::new_leaf("a".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string()); + let node_h = TestTreeNode::new_leaf("f_down(h)".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1289,12 +1300,12 @@ mod tests { } fn f_up_jump_on_a_transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string()); + let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string()); @@ -1303,12 +1314,12 @@ mod tests { } fn f_up_jump_on_a_transformed_up_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); - let node_h = TestTreeNode::new(vec![], "f_up(h)".to_string()); + let node_h = TestTreeNode::new_leaf("f_up(h)".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string()); @@ -1372,12 +1383,12 @@ mod tests { } fn f_down_stop_on_a_transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_a = TestTreeNode::new_leaf("f_down(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1385,12 +1396,12 @@ mod tests { } fn f_down_stop_on_a_transformed_down_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_down(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_down(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1406,12 +1417,12 @@ mod tests { } fn f_down_stop_on_e_transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "a".to_string()); - let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_a = TestTreeNode::new_leaf("a".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1419,12 +1430,12 @@ mod tests { } fn f_down_stop_on_e_transformed_down_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "a".to_string()); - let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_a = TestTreeNode::new_leaf("a".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1451,12 +1462,12 @@ mod tests { } fn f_up_stop_on_a_transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1464,12 +1475,12 @@ mod tests { } fn f_up_stop_on_a_transformed_up_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string()); let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); @@ -1499,13 +1510,13 @@ mod tests { } fn f_up_stop_on_e_transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1513,12 +1524,12 @@ mod tests { } fn f_up_stop_on_e_transformed_up_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string()); let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); @@ -2016,16 +2027,16 @@ mod tests { // A #[test] fn test_apply_and_visit_references() -> Result<()> { - let node_a = TestTreeNode::new(vec![], "a".to_string()); - let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_a = TestTreeNode::new_leaf("a".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); - let node_a_2 = TestTreeNode::new(vec![], "a".to_string()); - let node_b_2 = TestTreeNode::new(vec![], "b".to_string()); + let node_a_2 = TestTreeNode::new_leaf("a".to_string()); + let node_b_2 = TestTreeNode::new_leaf("b".to_string()); let node_d_2 = TestTreeNode::new(vec![node_a_2], "d".to_string()); let node_c_2 = TestTreeNode::new(vec![node_b_2, node_d_2], "c".to_string()); - let node_a_3 = TestTreeNode::new(vec![], "a".to_string()); + let node_a_3 = TestTreeNode::new_leaf("a".to_string()); let tree = TestTreeNode::new(vec![node_e, node_c_2, node_a_3], "f".to_string()); let node_f_ref = &tree; diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index f3f71a87278b..691b65d34443 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -34,6 +34,7 @@ use crate::{ }; use arrow::datatypes::{DataType, FieldRef}; +use datafusion_common::cse::HashNode; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; @@ -1652,47 +1653,39 @@ impl Expr { | Expr::Placeholder(..) => false, } } +} - /// Hashes the direct content of an `Expr` without recursing into its children. - /// - /// This method is useful to incrementally compute hashes, such as in - /// `CommonSubexprEliminate` which builds a deep hash of a node and its descendants - /// during the bottom-up phase of the first traversal and so avoid computing the hash - /// of the node and then the hash of its descendants separately. - /// - /// If a node doesn't have any children then this method is similar to `.hash()`, but - /// not necessarily returns the same value. - /// +impl HashNode for Expr { /// As it is pretty easy to forget changing this method when `Expr` changes the /// implementation doesn't use wildcard patterns (`..`, `_`) to catch changes /// compile time. - pub fn hash_node(&self, hasher: &mut H) { - mem::discriminant(self).hash(hasher); + fn hash_node(&self, state: &mut H) { + mem::discriminant(self).hash(state); match self { Expr::Alias(Alias { expr: _expr, relation, name, }) => { - relation.hash(hasher); - name.hash(hasher); + relation.hash(state); + name.hash(state); } Expr::Column(column) => { - column.hash(hasher); + column.hash(state); } Expr::ScalarVariable(data_type, name) => { - data_type.hash(hasher); - name.hash(hasher); + data_type.hash(state); + name.hash(state); } Expr::Literal(scalar_value) => { - scalar_value.hash(hasher); + scalar_value.hash(state); } Expr::BinaryExpr(BinaryExpr { left: _left, op, right: _right, }) => { - op.hash(hasher); + op.hash(state); } Expr::Like(Like { negated, @@ -1708,9 +1701,9 @@ impl Expr { escape_char, case_insensitive, }) => { - negated.hash(hasher); - escape_char.hash(hasher); - case_insensitive.hash(hasher); + negated.hash(state); + escape_char.hash(state); + case_insensitive.hash(state); } Expr::Not(_expr) | Expr::IsNotNull(_expr) @@ -1728,7 +1721,7 @@ impl Expr { low: _low, high: _high, }) => { - negated.hash(hasher); + negated.hash(state); } Expr::Case(Case { expr: _expr, @@ -1743,10 +1736,10 @@ impl Expr { expr: _expr, data_type, }) => { - data_type.hash(hasher); + data_type.hash(state); } Expr::ScalarFunction(ScalarFunction { func, args: _args }) => { - func.hash(hasher); + func.hash(state); } Expr::AggregateFunction(AggregateFunction { func, @@ -1756,9 +1749,9 @@ impl Expr { order_by: _order_by, null_treatment, }) => { - func.hash(hasher); - distinct.hash(hasher); - null_treatment.hash(hasher); + func.hash(state); + distinct.hash(state); + null_treatment.hash(state); } Expr::WindowFunction(WindowFunction { fun, @@ -1768,49 +1761,49 @@ impl Expr { window_frame, null_treatment, }) => { - fun.hash(hasher); - window_frame.hash(hasher); - null_treatment.hash(hasher); + fun.hash(state); + window_frame.hash(state); + null_treatment.hash(state); } Expr::InList(InList { expr: _expr, list: _list, negated, }) => { - negated.hash(hasher); + negated.hash(state); } Expr::Exists(Exists { subquery, negated }) => { - subquery.hash(hasher); - negated.hash(hasher); + subquery.hash(state); + negated.hash(state); } Expr::InSubquery(InSubquery { expr: _expr, subquery, negated, }) => { - subquery.hash(hasher); - negated.hash(hasher); + subquery.hash(state); + negated.hash(state); } Expr::ScalarSubquery(subquery) => { - subquery.hash(hasher); + subquery.hash(state); } Expr::Wildcard { qualifier, options } => { - qualifier.hash(hasher); - options.hash(hasher); + qualifier.hash(state); + options.hash(state); } Expr::GroupingSet(grouping_set) => { - mem::discriminant(grouping_set).hash(hasher); + mem::discriminant(grouping_set).hash(state); match grouping_set { GroupingSet::Rollup(_exprs) | GroupingSet::Cube(_exprs) => {} GroupingSet::GroupingSets(_exprs) => {} } } Expr::Placeholder(place_holder) => { - place_holder.hash(hasher); + place_holder.hash(state); } Expr::OuterReferenceColumn(data_type, column) => { - data_type.hash(hasher); - column.hash(hasher); + data_type.hash(state); + column.hash(state); } Expr::Unnest(Unnest { expr: _expr }) => {} }; diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index c13cb3a8e973..921011d33fc4 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -17,8 +17,8 @@ //! [`CommonSubexprEliminate`] to avoid redundant computation of common sub-expressions -use std::collections::{BTreeSet, HashMap}; -use std::hash::{BuildHasher, Hash, Hasher, RandomState}; +use std::collections::BTreeSet; +use std::fmt::Debug; use std::sync::Arc; use crate::{OptimizerConfig, OptimizerRule}; @@ -26,11 +26,9 @@ use crate::{OptimizerConfig, OptimizerRule}; use crate::optimizer::ApplyOrder; use crate::utils::NamePreserver; use datafusion_common::alias::AliasGenerator; -use datafusion_common::hash_utils::combine_hashes; -use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, - TreeNodeVisitor, -}; + +use datafusion_common::cse::{CSEController, FoundCommonNodes, CSE}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, Result}; use datafusion_expr::expr::{Alias, ScalarFunction}; use datafusion_expr::logical_plan::{ @@ -38,81 +36,9 @@ use datafusion_expr::logical_plan::{ }; use datafusion_expr::tree_node::replace_sort_expressions; use datafusion_expr::{col, BinaryExpr, Case, Expr, Operator}; -use indexmap::IndexMap; const CSE_PREFIX: &str = "__common_expr"; -/// Identifier that represents a subexpression tree. -/// -/// This identifier is designed to be efficient and "hash", "accumulate", "equal" and -/// "have no collision (as low as possible)" -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -struct Identifier<'n> { - // Hash of `expr` built up incrementally during the first, visiting traversal, but its - // value is not necessarily equal to `expr.hash()`. - hash: u64, - expr: &'n Expr, -} - -impl<'n> Identifier<'n> { - fn new(expr: &'n Expr, random_state: &RandomState) -> Self { - let mut hasher = random_state.build_hasher(); - expr.hash_node(&mut hasher); - let hash = hasher.finish(); - Self { hash, expr } - } - - fn combine(mut self, other: Option) -> Self { - other.map_or(self, |other_id| { - self.hash = combine_hashes(self.hash, other_id.hash); - self - }) - } -} - -impl Hash for Identifier<'_> { - fn hash(&self, state: &mut H) { - state.write_u64(self.hash); - } -} - -/// A cache that contains the postorder index and the identifier of expression tree nodes -/// by the preorder index of the nodes. -/// -/// This cache is filled by `ExprIdentifierVisitor` during the first traversal and is used -/// by `CommonSubexprRewriter` during the second traversal. -/// -/// The purpose of this cache is to quickly find the identifier of a node during the -/// second traversal. -/// -/// Elements in this array are added during `f_down` so the indexes represent the preorder -/// index of expression nodes and thus element 0 belongs to the root of the expression -/// tree. -/// The elements of the array are tuples that contain: -/// - Postorder index that belongs to the preorder index. Assigned during `f_up`, start -/// from 0. -/// - Identifier of the expression. If empty (`""`), expr should not be considered for -/// CSE. -/// -/// # Example -/// An expression like `(a + b)` would have the following `IdArray`: -/// ```text -/// [ -/// (2, "a + b"), -/// (1, "a"), -/// (0, "b") -/// ] -/// ``` -type IdArray<'n> = Vec<(usize, Option>)>; - -/// A map that contains the number of normal and conditional occurrences of expressions by -/// their identifiers. -type ExprStats<'n> = HashMap, (usize, usize)>; - -/// A map that contains the common expressions and their alias extracted during the -/// second, rewriting traversal. -type CommonExprs<'n> = IndexMap, (Expr, String)>; - /// Performs Common Sub-expression Elimination optimization. /// /// This optimization improves query performance by computing expressions that @@ -140,168 +66,11 @@ type CommonExprs<'n> = IndexMap, (Expr, String)>; /// ProjectionExec(exprs=[to_date(c1) as new_col]) <-- compute to_date once /// ``` #[derive(Debug)] -pub struct CommonSubexprEliminate { - random_state: RandomState, -} - -/// The result of potentially rewriting a list of expressions to eliminate common -/// subexpressions. -#[derive(Debug)] -enum FoundCommonExprs { - /// No common expressions were found - No { original_exprs_list: Vec> }, - /// Common expressions were found - Yes { - /// extracted common expressions - common_exprs: Vec<(Expr, String)>, - /// new expressions with common subexpressions replaced - new_exprs_list: Vec>, - /// original expressions - original_exprs_list: Vec>, - }, -} +pub struct CommonSubexprEliminate {} impl CommonSubexprEliminate { pub fn new() -> Self { - Self { - random_state: RandomState::new(), - } - } - - /// Returns the identifier list for each element in `exprs` and a flag to indicate if - /// rewrite phase of CSE make sense. - /// - /// Returns and array with 1 element for each input expr in `exprs` - /// - /// Each element is itself the result of [`CommonSubexprEliminate::expr_to_identifier`] for that expr - /// (e.g. the identifiers for each node in the tree) - fn to_arrays<'n>( - &self, - exprs: &'n [Expr], - expr_stats: &mut ExprStats<'n>, - expr_mask: ExprMask, - ) -> Result<(bool, Vec>)> { - let mut found_common = false; - exprs - .iter() - .map(|e| { - let mut id_array = vec![]; - self.expr_to_identifier(e, expr_stats, &mut id_array, expr_mask) - .map(|fc| { - found_common |= fc; - - id_array - }) - }) - .collect::>>() - .map(|id_arrays| (found_common, id_arrays)) - } - - /// Add an identifier to `id_array` for every subexpression in this tree. - fn expr_to_identifier<'n>( - &self, - expr: &'n Expr, - expr_stats: &mut ExprStats<'n>, - id_array: &mut IdArray<'n>, - expr_mask: ExprMask, - ) -> Result { - let mut visitor = ExprIdentifierVisitor { - expr_stats, - id_array, - visit_stack: vec![], - down_index: 0, - up_index: 0, - expr_mask, - random_state: &self.random_state, - found_common: false, - conditional: false, - }; - expr.visit(&mut visitor)?; - - Ok(visitor.found_common) - } - - /// Rewrites `exprs_list` with common sub-expressions replaced with a new - /// column. - /// - /// `common_exprs` is updated with any sub expressions that were replaced. - /// - /// Returns the rewritten expressions - fn rewrite_exprs_list<'n>( - &self, - exprs_list: Vec>, - arrays_list: &[Vec>], - expr_stats: &ExprStats<'n>, - common_exprs: &mut CommonExprs<'n>, - alias_generator: &AliasGenerator, - ) -> Result>> { - exprs_list - .into_iter() - .zip(arrays_list.iter()) - .map(|(exprs, arrays)| { - exprs - .into_iter() - .zip(arrays.iter()) - .map(|(expr, id_array)| { - replace_common_expr( - expr, - id_array, - expr_stats, - common_exprs, - alias_generator, - ) - }) - .collect::>>() - }) - .collect::>>() - } - - /// Extracts common sub-expressions and rewrites `exprs_list`. - /// - /// Returns `FoundCommonExprs` recording the result of the extraction - fn find_common_exprs( - &self, - exprs_list: Vec>, - config: &dyn OptimizerConfig, - expr_mask: ExprMask, - ) -> Result> { - let mut found_common = false; - let mut expr_stats = ExprStats::new(); - let id_arrays_list = exprs_list - .iter() - .map(|exprs| { - self.to_arrays(exprs, &mut expr_stats, expr_mask).map( - |(fc, id_arrays)| { - found_common |= fc; - - id_arrays - }, - ) - }) - .collect::>>()?; - if found_common { - let mut common_exprs = CommonExprs::new(); - let new_exprs_list = self.rewrite_exprs_list( - // Must clone as Identifiers use references to original expressions so we have - // to keep the original expressions intact. - exprs_list.clone(), - &id_arrays_list, - &expr_stats, - &mut common_exprs, - config.alias_generator().as_ref(), - )?; - assert!(!common_exprs.is_empty()); - - Ok(Transformed::yes(FoundCommonExprs::Yes { - common_exprs: common_exprs.into_values().collect(), - new_exprs_list, - original_exprs_list: exprs_list, - })) - } else { - Ok(Transformed::no(FoundCommonExprs::No { - original_exprs_list: exprs_list, - })) - } + Self {} } fn try_optimize_proj( @@ -372,80 +141,83 @@ impl CommonSubexprEliminate { get_consecutive_window_exprs(window); // Extract common sub-expressions from the list. - self.find_common_exprs(window_expr_list, config, ExprMask::Normal)? - .map_data(|common| match common { - // If there are common sub-expressions, then the insert a projection node - // with the common expressions between the new window nodes and the - // original input. - FoundCommonExprs::Yes { - common_exprs, - new_exprs_list, - original_exprs_list, - } => { - build_common_expr_project_plan(input, common_exprs).map(|new_input| { - (new_exprs_list, new_input, Some(original_exprs_list)) + + match CSE::new(ExprCSEController::new( + config.alias_generator().as_ref(), + ExprMask::Normal, + )) + .extract_common_nodes(window_expr_list)? + { + // If there are common sub-expressions, then the insert a projection node + // with the common expressions between the new window nodes and the + // original input. + FoundCommonNodes::Yes { + common_nodes: common_exprs, + new_nodes_list: new_exprs_list, + original_nodes_list: original_exprs_list, + } => build_common_expr_project_plan(input, common_exprs).map(|new_input| { + Transformed::yes((new_exprs_list, new_input, Some(original_exprs_list))) + }), + FoundCommonNodes::No { + original_nodes_list: original_exprs_list, + } => Ok(Transformed::no((original_exprs_list, input, None))), + }? + // Recurse into the new input. + // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) + .transform_data(|(new_window_expr_list, new_input, window_expr_list)| { + self.rewrite(new_input, config)?.map_data(|new_input| { + Ok((new_window_expr_list, new_input, window_expr_list)) + }) + })? + // Rebuild the consecutive window nodes. + .map_data(|(new_window_expr_list, new_input, window_expr_list)| { + // If there were common expressions extracted, then we need to make sure + // we restore the original column names. + // TODO: Although `find_common_exprs()` inserts aliases around extracted + // common expressions this doesn't mean that the original column names + // (schema) are preserved due to the inserted aliases are not always at + // the top of the expression. + // Let's consider improving `find_common_exprs()` to always keep column + // names and get rid of additional name preserving logic here. + if let Some(window_expr_list) = window_expr_list { + let name_preserver = NamePreserver::new_for_projection(); + let saved_names = window_expr_list + .iter() + .map(|exprs| { + exprs + .iter() + .map(|expr| name_preserver.save(expr)) + .collect::>() }) - } - FoundCommonExprs::No { - original_exprs_list, - } => Ok((original_exprs_list, input, None)), - })? - // Recurse into the new input. - // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) - .transform_data(|(new_window_expr_list, new_input, window_expr_list)| { - self.rewrite(new_input, config)?.map_data(|new_input| { - Ok((new_window_expr_list, new_input, window_expr_list)) - }) - })? - // Rebuild the consecutive window nodes. - .map_data(|(new_window_expr_list, new_input, window_expr_list)| { - // If there were common expressions extracted, then we need to make sure - // we restore the original column names. - // TODO: Although `find_common_exprs()` inserts aliases around extracted - // common expressions this doesn't mean that the original column names - // (schema) are preserved due to the inserted aliases are not always at - // the top of the expression. - // Let's consider improving `find_common_exprs()` to always keep column - // names and get rid of additional name preserving logic here. - if let Some(window_expr_list) = window_expr_list { - let name_preserver = NamePreserver::new_for_projection(); - let saved_names = window_expr_list - .iter() - .map(|exprs| { - exprs - .iter() - .map(|expr| name_preserver.save(expr)) - .collect::>() - }) - .collect::>(); - new_window_expr_list.into_iter().zip(saved_names).try_rfold( - new_input, - |plan, (new_window_expr, saved_names)| { - let new_window_expr = new_window_expr - .into_iter() - .zip(saved_names) - .map(|(new_window_expr, saved_name)| { - saved_name.restore(new_window_expr) - }) - .collect::>(); - Window::try_new(new_window_expr, Arc::new(plan)) - .map(LogicalPlan::Window) - }, - ) - } else { - new_window_expr_list - .into_iter() - .zip(window_schemas) - .try_rfold(new_input, |plan, (new_window_expr, schema)| { - Window::try_new_with_schema( - new_window_expr, - Arc::new(plan), - schema, - ) + .collect::>(); + new_window_expr_list.into_iter().zip(saved_names).try_rfold( + new_input, + |plan, (new_window_expr, saved_names)| { + let new_window_expr = new_window_expr + .into_iter() + .zip(saved_names) + .map(|(new_window_expr, saved_name)| { + saved_name.restore(new_window_expr) + }) + .collect::>(); + Window::try_new(new_window_expr, Arc::new(plan)) .map(LogicalPlan::Window) - }) - } - }) + }, + ) + } else { + new_window_expr_list + .into_iter() + .zip(window_schemas) + .try_rfold(new_input, |plan, (new_window_expr, schema)| { + Window::try_new_with_schema( + new_window_expr, + Arc::new(plan), + schema, + ) + .map(LogicalPlan::Window) + }) + } + }) } fn try_optimize_aggregate( @@ -462,174 +234,175 @@ impl CommonSubexprEliminate { } = aggregate; let input = Arc::unwrap_or_clone(input); // Extract common sub-expressions from the aggregate and grouping expressions. - self.find_common_exprs(vec![group_expr, aggr_expr], config, ExprMask::Normal)? - .map_data(|common| { - match common { - // If there are common sub-expressions, then insert a projection node - // with the common expressions between the new aggregate node and the - // original input. - FoundCommonExprs::Yes { - common_exprs, - mut new_exprs_list, - mut original_exprs_list, - } => { - let new_aggr_expr = new_exprs_list.pop().unwrap(); - let new_group_expr = new_exprs_list.pop().unwrap(); - - build_common_expr_project_plan(input, common_exprs).map( - |new_input| { - let aggr_expr = original_exprs_list.pop().unwrap(); - ( - new_aggr_expr, - new_group_expr, - new_input, - Some(aggr_expr), - ) - }, - ) - } - - FoundCommonExprs::No { - mut original_exprs_list, - } => { - let new_aggr_expr = original_exprs_list.pop().unwrap(); - let new_group_expr = original_exprs_list.pop().unwrap(); - - Ok((new_aggr_expr, new_group_expr, input, None)) - } - } - })? - // Recurse into the new input. - // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) - .transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| { - self.rewrite(new_input, config)?.map_data(|new_input| { - Ok(( + match CSE::new(ExprCSEController::new( + config.alias_generator().as_ref(), + ExprMask::Normal, + )) + .extract_common_nodes(vec![group_expr, aggr_expr])? + { + // If there are common sub-expressions, then insert a projection node + // with the common expressions between the new aggregate node and the + // original input. + FoundCommonNodes::Yes { + common_nodes: common_exprs, + new_nodes_list: mut new_exprs_list, + original_nodes_list: mut original_exprs_list, + } => { + let new_aggr_expr = new_exprs_list.pop().unwrap(); + let new_group_expr = new_exprs_list.pop().unwrap(); + + build_common_expr_project_plan(input, common_exprs).map(|new_input| { + let aggr_expr = original_exprs_list.pop().unwrap(); + Transformed::yes(( new_aggr_expr, new_group_expr, - aggr_expr, - Arc::new(new_input), + new_input, + Some(aggr_expr), )) }) - })? - // Try extracting common aggregate expressions and rebuild the aggregate node. - .transform_data(|(new_aggr_expr, new_group_expr, aggr_expr, new_input)| { + } + + FoundCommonNodes::No { + original_nodes_list: mut original_exprs_list, + } => { + let new_aggr_expr = original_exprs_list.pop().unwrap(); + let new_group_expr = original_exprs_list.pop().unwrap(); + + Ok(Transformed::no(( + new_aggr_expr, + new_group_expr, + input, + None, + ))) + } + }? + // Recurse into the new input. + // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) + .transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| { + self.rewrite(new_input, config)?.map_data(|new_input| { + Ok(( + new_aggr_expr, + new_group_expr, + aggr_expr, + Arc::new(new_input), + )) + }) + })? + // Try extracting common aggregate expressions and rebuild the aggregate node. + .transform_data( + |(new_aggr_expr, new_group_expr, aggr_expr, new_input)| { // Extract common aggregate sub-expressions from the aggregate expressions. - self.find_common_exprs( - vec![new_aggr_expr], - config, + match CSE::new(ExprCSEController::new( + config.alias_generator().as_ref(), ExprMask::NormalAndAggregates, - )? - .map_data(|common| { - match common { - FoundCommonExprs::Yes { - common_exprs, - mut new_exprs_list, - mut original_exprs_list, - } => { - let rewritten_aggr_expr = new_exprs_list.pop().unwrap(); - let new_aggr_expr = original_exprs_list.pop().unwrap(); - - let mut agg_exprs = common_exprs - .into_iter() - .map(|(expr, expr_alias)| expr.alias(expr_alias)) - .collect::>(); + )) + .extract_common_nodes(vec![new_aggr_expr])? + { + FoundCommonNodes::Yes { + common_nodes: common_exprs, + new_nodes_list: mut new_exprs_list, + original_nodes_list: mut original_exprs_list, + } => { + let rewritten_aggr_expr = new_exprs_list.pop().unwrap(); + let new_aggr_expr = original_exprs_list.pop().unwrap(); - let mut proj_exprs = vec![]; - for expr in &new_group_expr { - extract_expressions(expr, &mut proj_exprs) - } - for (expr_rewritten, expr_orig) in - rewritten_aggr_expr.into_iter().zip(new_aggr_expr) - { - if expr_rewritten == expr_orig { - if let Expr::Alias(Alias { expr, name, .. }) = - expr_rewritten - { - agg_exprs.push(expr.alias(&name)); - proj_exprs - .push(Expr::Column(Column::from_name(name))); - } else { - let expr_alias = - config.alias_generator().next(CSE_PREFIX); - let (qualifier, field_name) = - expr_rewritten.qualified_name(); - let out_name = qualified_name( - qualifier.as_ref(), - &field_name, - ); - - agg_exprs.push(expr_rewritten.alias(&expr_alias)); - proj_exprs.push( - Expr::Column(Column::from_name(expr_alias)) - .alias(out_name), - ); - } + let mut agg_exprs = common_exprs + .into_iter() + .map(|(expr, expr_alias)| expr.alias(expr_alias)) + .collect::>(); + + let mut proj_exprs = vec![]; + for expr in &new_group_expr { + extract_expressions(expr, &mut proj_exprs) + } + for (expr_rewritten, expr_orig) in + rewritten_aggr_expr.into_iter().zip(new_aggr_expr) + { + if expr_rewritten == expr_orig { + if let Expr::Alias(Alias { expr, name, .. }) = + expr_rewritten + { + agg_exprs.push(expr.alias(&name)); + proj_exprs + .push(Expr::Column(Column::from_name(name))); } else { - proj_exprs.push(expr_rewritten); + let expr_alias = + config.alias_generator().next(CSE_PREFIX); + let (qualifier, field_name) = + expr_rewritten.qualified_name(); + let out_name = + qualified_name(qualifier.as_ref(), &field_name); + + agg_exprs.push(expr_rewritten.alias(&expr_alias)); + proj_exprs.push( + Expr::Column(Column::from_name(expr_alias)) + .alias(out_name), + ); } + } else { + proj_exprs.push(expr_rewritten); } - - let agg = LogicalPlan::Aggregate(Aggregate::try_new( - new_input, - new_group_expr, - agg_exprs, - )?); - Projection::try_new(proj_exprs, Arc::new(agg)) - .map(LogicalPlan::Projection) } - // If there aren't any common aggregate sub-expressions, then just - // rebuild the aggregate node. - FoundCommonExprs::No { - mut original_exprs_list, - } => { - let rewritten_aggr_expr = original_exprs_list.pop().unwrap(); - - // If there were common expressions extracted, then we need to - // make sure we restore the original column names. - // TODO: Although `find_common_exprs()` inserts aliases around - // extracted common expressions this doesn't mean that the - // original column names (schema) are preserved due to the - // inserted aliases are not always at the top of the - // expression. - // Let's consider improving `find_common_exprs()` to always - // keep column names and get rid of additional name - // preserving logic here. - if let Some(aggr_expr) = aggr_expr { - let name_perserver = NamePreserver::new_for_projection(); - let saved_names = aggr_expr - .iter() - .map(|expr| name_perserver.save(expr)) - .collect::>(); - let new_aggr_expr = rewritten_aggr_expr - .into_iter() - .zip(saved_names) - .map(|(new_expr, saved_name)| { - saved_name.restore(new_expr) - }) - .collect::>(); - - // Since `group_expr` may have changed, schema may also. - // Use `try_new()` method. - Aggregate::try_new( - new_input, - new_group_expr, - new_aggr_expr, - ) - .map(LogicalPlan::Aggregate) - } else { - Aggregate::try_new_with_schema( - new_input, - new_group_expr, - rewritten_aggr_expr, - schema, - ) + let agg = LogicalPlan::Aggregate(Aggregate::try_new( + new_input, + new_group_expr, + agg_exprs, + )?); + Projection::try_new(proj_exprs, Arc::new(agg)) + .map(|p| Transformed::yes(LogicalPlan::Projection(p))) + } + + // If there aren't any common aggregate sub-expressions, then just + // rebuild the aggregate node. + FoundCommonNodes::No { + original_nodes_list: mut original_exprs_list, + } => { + let rewritten_aggr_expr = original_exprs_list.pop().unwrap(); + + // If there were common expressions extracted, then we need to + // make sure we restore the original column names. + // TODO: Although `find_common_exprs()` inserts aliases around + // extracted common expressions this doesn't mean that the + // original column names (schema) are preserved due to the + // inserted aliases are not always at the top of the + // expression. + // Let's consider improving `find_common_exprs()` to always + // keep column names and get rid of additional name + // preserving logic here. + if let Some(aggr_expr) = aggr_expr { + let name_perserver = NamePreserver::new_for_projection(); + let saved_names = aggr_expr + .iter() + .map(|expr| name_perserver.save(expr)) + .collect::>(); + let new_aggr_expr = rewritten_aggr_expr + .into_iter() + .zip(saved_names) + .map(|(new_expr, saved_name)| { + saved_name.restore(new_expr) + }) + .collect::>(); + + // Since `group_expr` may have changed, schema may also. + // Use `try_new()` method. + Aggregate::try_new(new_input, new_group_expr, new_aggr_expr) .map(LogicalPlan::Aggregate) - } + .map(Transformed::no) + } else { + Aggregate::try_new_with_schema( + new_input, + new_group_expr, + rewritten_aggr_expr, + schema, + ) + .map(LogicalPlan::Aggregate) + .map(Transformed::no) } } - }) - }) + } + }, + ) } /// Rewrites the expr list and input to remove common subexpressions @@ -653,30 +426,34 @@ impl CommonSubexprEliminate { config: &dyn OptimizerConfig, ) -> Result, LogicalPlan)>> { // Extract common sub-expressions from the expressions. - self.find_common_exprs(vec![exprs], config, ExprMask::Normal)? - .map_data(|common| match common { - FoundCommonExprs::Yes { - common_exprs, - mut new_exprs_list, - original_exprs_list: _, - } => { - let new_exprs = new_exprs_list.pop().unwrap(); - build_common_expr_project_plan(input, common_exprs) - .map(|new_input| (new_exprs, new_input)) - } - FoundCommonExprs::No { - mut original_exprs_list, - } => { - let new_exprs = original_exprs_list.pop().unwrap(); - Ok((new_exprs, input)) - } - })? - // Recurse into the new input. - // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) - .transform_data(|(new_exprs, new_input)| { - self.rewrite(new_input, config)? - .map_data(|new_input| Ok((new_exprs, new_input))) - }) + match CSE::new(ExprCSEController::new( + config.alias_generator().as_ref(), + ExprMask::Normal, + )) + .extract_common_nodes(vec![exprs])? + { + FoundCommonNodes::Yes { + common_nodes: common_exprs, + new_nodes_list: mut new_exprs_list, + original_nodes_list: _, + } => { + let new_exprs = new_exprs_list.pop().unwrap(); + build_common_expr_project_plan(input, common_exprs) + .map(|new_input| Transformed::yes((new_exprs, new_input))) + } + FoundCommonNodes::No { + original_nodes_list: mut original_exprs_list, + } => { + let new_exprs = original_exprs_list.pop().unwrap(); + Ok(Transformed::no((new_exprs, input))) + } + }? + // Recurse into the new input. + // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) + .transform_data(|(new_exprs, new_input)| { + self.rewrite(new_input, config)? + .map_data(|new_input| Ok((new_exprs, new_input))) + }) } } @@ -800,71 +577,6 @@ impl OptimizerRule for CommonSubexprEliminate { } } -impl Default for CommonSubexprEliminate { - fn default() -> Self { - Self::new() - } -} - -/// Build the "intermediate" projection plan that evaluates the extracted common -/// expressions. -/// -/// # Arguments -/// input: the input plan -/// -/// common_exprs: which common subexpressions were used (and thus are added to -/// intermediate projection) -/// -/// expr_stats: the set of common subexpressions -fn build_common_expr_project_plan( - input: LogicalPlan, - common_exprs: Vec<(Expr, String)>, -) -> Result { - let mut fields_set = BTreeSet::new(); - let mut project_exprs = common_exprs - .into_iter() - .map(|(expr, expr_alias)| { - fields_set.insert(expr_alias.clone()); - Ok(expr.alias(expr_alias)) - }) - .collect::>>()?; - - for (qualifier, field) in input.schema().iter() { - if fields_set.insert(qualified_name(qualifier, field.name())) { - project_exprs.push(Expr::from((qualifier, field))); - } - } - - Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection) -} - -/// Build the projection plan to eliminate unnecessary columns produced by -/// the "intermediate" projection plan built in [build_common_expr_project_plan]. -/// -/// This is required to keep the schema the same for plans that pass the input -/// on to the output, such as `Filter` or `Sort`. -fn build_recover_project_plan( - schema: &DFSchema, - input: LogicalPlan, -) -> Result { - let col_exprs = schema.iter().map(Expr::from).collect(); - Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection) -} - -fn extract_expressions(expr: &Expr, result: &mut Vec) { - if let Expr::GroupingSet(groupings) = expr { - for e in groupings.distinct_expr() { - let (qualifier, field_name) = e.qualified_name(); - let col = Column::new(qualifier, field_name); - result.push(Expr::Column(col)) - } - } else { - let (qualifier, field_name) = expr.qualified_name(); - let col = Column::new(qualifier, field_name); - result.push(Expr::Column(col)); - } -} - /// Which type of [expressions](Expr) should be considered for rewriting? #[derive(Debug, Clone, Copy)] enum ExprMask { @@ -882,156 +594,36 @@ enum ExprMask { NormalAndAggregates, } -impl ExprMask { - fn ignores(&self, expr: &Expr) -> bool { - let is_normal_minus_aggregates = matches!( - expr, - Expr::Literal(..) - | Expr::Column(..) - | Expr::ScalarVariable(..) - | Expr::Alias(..) - | Expr::Wildcard { .. } - ); - - let is_aggr = matches!(expr, Expr::AggregateFunction(..)); - - match self { - Self::Normal => is_normal_minus_aggregates || is_aggr, - Self::NormalAndAggregates => is_normal_minus_aggregates, - } - } -} - -/// Go through an expression tree and generate identifiers for each subexpression. -/// -/// An identifier contains information of the expression itself and its sub-expression. -/// This visitor implementation use a stack `visit_stack` to track traversal, which -/// lets us know when a sub-tree's visiting is finished. When `pre_visit` is called -/// (traversing to a new node), an `EnterMark` and an `ExprItem` will be pushed into stack. -/// And try to pop out a `EnterMark` on leaving a node (`f_up()`). All `ExprItem` -/// before the first `EnterMark` is considered to be sub-tree of the leaving node. -/// -/// This visitor also records identifier in `id_array`. Makes the following traverse -/// pass can get the identifier of a node without recalculate it. We assign each node -/// in the expr tree a series number, start from 1, maintained by `series_number`. -/// Series number represents the order we left (`f_up()`) a node. Has the property -/// that child node's series number always smaller than parent's. While `id_array` is -/// organized in the order we enter (`f_down()`) a node. `node_count` helps us to -/// get the index of `id_array` for each node. -/// -/// `Expr` without sub-expr (column, literal etc.) will not have identifier -/// because they should not be recognized as common sub-expr. -struct ExprIdentifierVisitor<'a, 'n> { - // statistics of expressions - expr_stats: &'a mut ExprStats<'n>, - // cache to speed up second traversal - id_array: &'a mut IdArray<'n>, - // inner states - visit_stack: Vec>, - // preorder index, start from 0. - down_index: usize, - // postorder index, start from 0. - up_index: usize, - // which expression should be skipped? - expr_mask: ExprMask, - // a `RandomState` to generate hashes during the first traversal - random_state: &'a RandomState, - // a flag to indicate that common expression found - found_common: bool, - // if we are in a conditional branch. A conditional branch means that the expression - // might not be executed depending on the runtime values of other expressions, and - // thus can not be extracted as a common expression. - conditional: bool, -} +struct ExprCSEController<'a> { + alias_generator: &'a AliasGenerator, + mask: ExprMask, -/// Record item that used when traversing an expression tree. -enum VisitRecord<'n> { - /// Marks the beginning of expression. It contains: - /// - The post-order index assigned during the first, visiting traversal. - EnterMark(usize), - - /// Marks an accumulated subexpression tree. It contains: - /// - The accumulated identifier of a subexpression. - /// - A boolean flag if the expression is valid for subexpression elimination. - /// The flag is propagated up from children to parent. (E.g. volatile expressions - /// are not valid and can't be extracted, but non-volatile children of volatile - /// expressions can be extracted.) - ExprItem(Identifier<'n>, bool), + // how many aliases have we seen so far + alias_counter: usize, } -impl<'n> ExprIdentifierVisitor<'_, 'n> { - /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem` before - /// it. Returns a tuple that contains: - /// - The pre-order index of the expression we marked. - /// - The accumulated identifier of the children of the marked expression. - /// - An accumulated boolean flag from the children of the marked expression if all - /// children are valid for subexpression elimination (i.e. it is safe to extract the - /// expression as a common expression from its children POV). - /// (E.g. if any of the children of the marked expression is not valid (e.g. is - /// volatile) then the expression is also not valid, so we can propagate this - /// information up from children to parents via `visit_stack` during the first, - /// visiting traversal and no need to test the expression's validity beforehand with - /// an extra traversal). - fn pop_enter_mark(&mut self) -> (usize, Option>, bool) { - let mut expr_id = None; - let mut is_valid = true; - - while let Some(item) = self.visit_stack.pop() { - match item { - VisitRecord::EnterMark(down_index) => { - return (down_index, expr_id, is_valid); - } - VisitRecord::ExprItem(sub_expr_id, sub_expr_is_valid) => { - expr_id = Some(sub_expr_id.combine(expr_id)); - is_valid &= sub_expr_is_valid; - } - } +impl<'a> ExprCSEController<'a> { + fn new(alias_generator: &'a AliasGenerator, mask: ExprMask) -> Self { + Self { + alias_generator, + mask, + alias_counter: 0, } - unreachable!("Enter mark should paired with node number"); - } - - /// Save the current `conditional` status and run `f` with `conditional` set to true. - fn conditionally Result<()>>( - &mut self, - mut f: F, - ) -> Result<()> { - let conditional = self.conditional; - self.conditional = true; - f(self)?; - self.conditional = conditional; - - Ok(()) } } -impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> { +impl CSEController for ExprCSEController<'_> { type Node = Expr; - fn f_down(&mut self, expr: &'n Expr) -> Result { - self.id_array.push((0, None)); - self.visit_stack - .push(VisitRecord::EnterMark(self.down_index)); - self.down_index += 1; - - // If an expression can short-circuit then some of its children might not be - // executed so count the occurrence of subexpressions as conditional in all - // children. - Ok(match expr { - // If we are already in a conditionally evaluated subtree then continue - // traversal. - _ if self.conditional => TreeNodeRecursion::Continue, - + fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> { + match node { // In case of `ScalarFunction`s we don't know which children are surely // executed so start visiting all children conditionally and stop the // recursion with `TreeNodeRecursion::Jump`. Expr::ScalarFunction(ScalarFunction { func, args }) if func.short_circuits() => { - self.conditionally(|visitor| { - args.iter().try_for_each(|e| e.visit(visitor).map(|_| ())) - })?; - - TreeNodeRecursion::Jump + Some((vec![], args.iter().collect())) } // In case of `And` and `Or` the first child is surely executed, but we @@ -1040,12 +632,7 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> { left, op: Operator::And | Operator::Or, right, - }) => { - left.visit(self)?; - self.conditionally(|visitor| right.visit(visitor).map(|_| ()))?; - - TreeNodeRecursion::Jump - } + }) => Some((vec![left.as_ref()], vec![right.as_ref()])), // In case of `Case` the optional base expression and the first when // expressions are surely executed, but we account subexpressions as @@ -1054,167 +641,151 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> { expr, when_then_expr, else_expr, - }) => { - expr.iter().try_for_each(|e| e.visit(self).map(|_| ()))?; - when_then_expr.iter().take(1).try_for_each(|(when, then)| { - when.visit(self)?; - self.conditionally(|visitor| then.visit(visitor).map(|_| ())) - })?; - self.conditionally(|visitor| { - when_then_expr.iter().skip(1).try_for_each(|(when, then)| { - when.visit(visitor)?; - then.visit(visitor).map(|_| ()) - })?; - else_expr - .iter() - .try_for_each(|e| e.visit(visitor).map(|_| ())) - })?; - - TreeNodeRecursion::Jump - } + }) => Some(( + expr.iter() + .map(|e| e.as_ref()) + .chain(when_then_expr.iter().take(1).map(|(when, _)| when.as_ref())) + .collect(), + when_then_expr + .iter() + .take(1) + .map(|(_, then)| then.as_ref()) + .chain( + when_then_expr + .iter() + .skip(1) + .flat_map(|(when, then)| [when.as_ref(), then.as_ref()]), + ) + .chain(else_expr.iter().map(|e| e.as_ref())) + .collect(), + )), + _ => None, + } + } - // In case of non-short-circuit expressions continue the traversal. - _ => TreeNodeRecursion::Continue, - }) + fn is_valid(node: &Expr) -> bool { + !node.is_volatile_node() } - fn f_up(&mut self, expr: &'n Expr) -> Result { - let (down_index, sub_expr_id, sub_expr_is_valid) = self.pop_enter_mark(); + fn is_ignored(&self, node: &Expr) -> bool { + let is_normal_minus_aggregates = matches!( + node, + Expr::Literal(..) + | Expr::Column(..) + | Expr::ScalarVariable(..) + | Expr::Alias(..) + | Expr::Wildcard { .. } + ); - let expr_id = Identifier::new(expr, self.random_state).combine(sub_expr_id); - let is_valid = !expr.is_volatile_node() && sub_expr_is_valid; + let is_aggr = matches!(node, Expr::AggregateFunction(..)); - self.id_array[down_index].0 = self.up_index; - if is_valid && !self.expr_mask.ignores(expr) { - self.id_array[down_index].1 = Some(expr_id); - let (count, conditional_count) = - self.expr_stats.entry(expr_id).or_insert((0, 0)); - if self.conditional { - *conditional_count += 1; - } else { - *count += 1; - } - if *count > 1 || (*count == 1 && *conditional_count > 0) { - self.found_common = true; - } + match self.mask { + ExprMask::Normal => is_normal_minus_aggregates || is_aggr, + ExprMask::NormalAndAggregates => is_normal_minus_aggregates, } - self.visit_stack - .push(VisitRecord::ExprItem(expr_id, is_valid)); - self.up_index += 1; - - Ok(TreeNodeRecursion::Continue) } -} -/// Rewrite expression by replacing detected common sub-expression with -/// the corresponding temporary column name. That column contains the -/// evaluate result of replaced expression. -struct CommonSubexprRewriter<'a, 'n> { - // statistics of expressions - expr_stats: &'a ExprStats<'n>, - // cache to speed up second traversal - id_array: &'a IdArray<'n>, - // common expression, that are replaced during the second traversal, are collected to - // this map - common_exprs: &'a mut CommonExprs<'n>, - // preorder index, starts from 0. - down_index: usize, - // how many aliases have we seen so far - alias_counter: usize, - // alias generator for extracted common expressions - alias_generator: &'a AliasGenerator, -} + fn generate_alias(&self) -> String { + self.alias_generator.next(CSE_PREFIX) + } -impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> { - type Node = Expr; + fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node { + // alias the expressions without an `Alias` ancestor node + if self.alias_counter > 0 { + col(alias) + } else { + self.alias_counter += 1; + col(alias).alias(node.schema_name().to_string()) + } + } - fn f_down(&mut self, expr: Expr) -> Result> { - if matches!(expr, Expr::Alias(_)) { + fn rewrite_f_down(&mut self, node: &Expr) { + if matches!(node, Expr::Alias(_)) { self.alias_counter += 1; } + } + fn rewrite_f_up(&mut self, node: &Expr) { + if matches!(node, Expr::Alias(_)) { + self.alias_counter -= 1 + } + } +} - let (up_index, expr_id) = self.id_array[self.down_index]; - self.down_index += 1; +impl Default for CommonSubexprEliminate { + fn default() -> Self { + Self::new() + } +} - // Handle `Expr`s with identifiers only - if let Some(expr_id) = expr_id { - let (count, conditional_count) = self.expr_stats.get(&expr_id).unwrap(); - if *count > 1 || *count == 1 && *conditional_count > 0 { - // step index to skip all sub-node (which has smaller series number). - while self.down_index < self.id_array.len() - && self.id_array[self.down_index].0 < up_index - { - self.down_index += 1; - } +/// Build the "intermediate" projection plan that evaluates the extracted common +/// expressions. +/// +/// # Arguments +/// input: the input plan +/// +/// common_exprs: which common subexpressions were used (and thus are added to +/// intermediate projection) +/// +/// expr_stats: the set of common subexpressions +fn build_common_expr_project_plan( + input: LogicalPlan, + common_exprs: Vec<(Expr, String)>, +) -> Result { + let mut fields_set = BTreeSet::new(); + let mut project_exprs = common_exprs + .into_iter() + .map(|(expr, expr_alias)| { + fields_set.insert(expr_alias.clone()); + Ok(expr.alias(expr_alias)) + }) + .collect::>>()?; - let expr_name = expr.schema_name().to_string(); - let (_, expr_alias) = - self.common_exprs.entry(expr_id).or_insert_with(|| { - let expr_alias = self.alias_generator.next(CSE_PREFIX); - (expr, expr_alias) - }); - - // alias the expressions without an `Alias` ancestor node - let rewritten = if self.alias_counter > 0 { - col(expr_alias.clone()) - } else { - self.alias_counter += 1; - col(expr_alias.clone()).alias(expr_name) - }; - - return Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump)); - } + for (qualifier, field) in input.schema().iter() { + if fields_set.insert(qualified_name(qualifier, field.name())) { + project_exprs.push(Expr::from((qualifier, field))); } - - Ok(Transformed::no(expr)) } - fn f_up(&mut self, expr: Expr) -> Result> { - if matches!(expr, Expr::Alias(_)) { - self.alias_counter -= 1 - } + Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection) +} - Ok(Transformed::no(expr)) - } +/// Build the projection plan to eliminate unnecessary columns produced by +/// the "intermediate" projection plan built in [build_common_expr_project_plan]. +/// +/// This is required to keep the schema the same for plans that pass the input +/// on to the output, such as `Filter` or `Sort`. +fn build_recover_project_plan( + schema: &DFSchema, + input: LogicalPlan, +) -> Result { + let col_exprs = schema.iter().map(Expr::from).collect(); + Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection) } -/// Replace common sub-expression in `expr` with the corresponding temporary -/// column name, updating `common_exprs` with any replaced expressions -fn replace_common_expr<'n>( - expr: Expr, - id_array: &IdArray<'n>, - expr_stats: &ExprStats<'n>, - common_exprs: &mut CommonExprs<'n>, - alias_generator: &AliasGenerator, -) -> Result { - if id_array.is_empty() { - Ok(Transformed::no(expr)) +fn extract_expressions(expr: &Expr, result: &mut Vec) { + if let Expr::GroupingSet(groupings) = expr { + for e in groupings.distinct_expr() { + let (qualifier, field_name) = e.qualified_name(); + let col = Column::new(qualifier, field_name); + result.push(Expr::Column(col)) + } } else { - expr.rewrite(&mut CommonSubexprRewriter { - expr_stats, - id_array, - common_exprs, - down_index: 0, - alias_counter: 0, - alias_generator, - }) + let (qualifier, field_name) = expr.qualified_name(); + let col = Column::new(qualifier, field_name); + result.push(Expr::Column(col)); } - .data() } #[cfg(test)] mod test { use std::any::Any; - use std::collections::HashSet; use std::iter; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_expr::expr::AggregateFunction; use datafusion_expr::logical_plan::{table_scan, JoinType}; use datafusion_expr::{ - grouping_set, AccumulatorFactoryFunction, AggregateUDF, BinaryExpr, - ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, - Volatility, + grouping_set, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF, + ScalarUDFImpl, Signature, SimpleAggregateUDF, Volatility, }; use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; @@ -1238,154 +809,6 @@ mod test { assert_eq!(expected, formatted_plan); } - #[test] - fn id_array_visitor() -> Result<()> { - let optimizer = CommonSubexprEliminate::new(); - - let a_plus_1 = col("a") + lit(1); - let avg_c = avg(col("c")); - let sum_a_plus_1 = sum(a_plus_1); - let sum_a_plus_1_minus_avg_c = sum_a_plus_1 - avg_c; - let expr = sum_a_plus_1_minus_avg_c * lit(2); - - let Expr::BinaryExpr(BinaryExpr { - left: sum_a_plus_1_minus_avg_c, - .. - }) = &expr - else { - panic!("Cannot extract subexpression reference") - }; - let Expr::BinaryExpr(BinaryExpr { - left: sum_a_plus_1, - right: avg_c, - .. - }) = sum_a_plus_1_minus_avg_c.as_ref() - else { - panic!("Cannot extract subexpression reference") - }; - let Expr::AggregateFunction(AggregateFunction { - args: a_plus_1_vec, .. - }) = sum_a_plus_1.as_ref() - else { - panic!("Cannot extract subexpression reference") - }; - let a_plus_1 = &a_plus_1_vec.as_slice()[0]; - - // skip aggregates - let mut id_array = vec![]; - optimizer.expr_to_identifier( - &expr, - &mut ExprStats::new(), - &mut id_array, - ExprMask::Normal, - )?; - - // Collect distinct hashes and set them to 0 in `id_array` - fn collect_hashes(id_array: &mut IdArray) -> HashSet { - id_array - .iter_mut() - .flat_map(|(_, expr_id_option)| { - expr_id_option.as_mut().map(|expr_id| { - let hash = expr_id.hash; - expr_id.hash = 0; - hash - }) - }) - .collect::>() - } - - let hashes = collect_hashes(&mut id_array); - assert_eq!(hashes.len(), 3); - - let expected = vec![ - ( - 8, - Some(Identifier { - hash: 0, - expr: &expr, - }), - ), - ( - 6, - Some(Identifier { - hash: 0, - expr: sum_a_plus_1_minus_avg_c, - }), - ), - (3, None), - ( - 2, - Some(Identifier { - hash: 0, - expr: a_plus_1, - }), - ), - (0, None), - (1, None), - (5, None), - (4, None), - (7, None), - ]; - assert_eq!(expected, id_array); - - // include aggregates - let mut id_array = vec![]; - optimizer.expr_to_identifier( - &expr, - &mut ExprStats::new(), - &mut id_array, - ExprMask::NormalAndAggregates, - )?; - - let hashes = collect_hashes(&mut id_array); - assert_eq!(hashes.len(), 5); - - let expected = vec![ - ( - 8, - Some(Identifier { - hash: 0, - expr: &expr, - }), - ), - ( - 6, - Some(Identifier { - hash: 0, - expr: sum_a_plus_1_minus_avg_c, - }), - ), - ( - 3, - Some(Identifier { - hash: 0, - expr: sum_a_plus_1, - }), - ), - ( - 2, - Some(Identifier { - hash: 0, - expr: a_plus_1, - }), - ), - (0, None), - (1, None), - ( - 5, - Some(Identifier { - hash: 0, - expr: avg_c, - }), - ), - (4, None), - (7, None), - ]; - assert_eq!(expected, id_array); - - Ok(()) - } - #[test] fn tpch_q1_simplified() -> Result<()> { // SQL: From 2535d88b4c8af7010fbc6366c06e9d4f0eb4f640 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Tue, 22 Oct 2024 04:37:25 +0800 Subject: [PATCH 038/110] enhance unparsing plan with pushdown to avoid unnamed subquery (#13006) --- datafusion/sql/src/unparser/plan.rs | 61 +++++++++++++++++++++-- datafusion/sql/src/unparser/rewrite.rs | 10 ++-- datafusion/sql/tests/cases/plan_to_sql.rs | 6 +-- 3 files changed, 65 insertions(+), 12 deletions(-) diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 8e70654d8d6f..77f885c1de5f 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -660,9 +660,10 @@ impl Unparser<'_> { if !Self::is_scan_with_pushdown(table_scan) { return Ok(None); } + let table_schema = table_scan.source.schema(); let mut filter_alias_rewriter = alias.as_ref().map(|alias_name| TableAliasRewriter { - table_schema: table_scan.source.schema(), + table_schema: &table_schema, alias_name: alias_name.clone(), }); @@ -671,6 +672,17 @@ impl Unparser<'_> { Arc::clone(&table_scan.source), None, )?; + // We will rebase the column references to the new alias if it exists. + // If the projection or filters are empty, we will append alias to the table scan. + // + // Example: + // select t1.c1 from t1 where t1.c1 > 1 -> select a.c1 from t1 as a where a.c1 > 1 + if alias.is_some() + && (table_scan.projection.is_some() || !table_scan.filters.is_empty()) + { + builder = builder.alias(alias.clone().unwrap())?; + } + if let Some(project_vec) = &table_scan.projection { let project_columns = project_vec .iter() @@ -688,9 +700,6 @@ impl Unparser<'_> { } }) .collect::>(); - if let Some(alias) = alias { - builder = builder.alias(alias)?; - } builder = builder.project(project_columns)?; } @@ -720,6 +729,16 @@ impl Unparser<'_> { builder = builder.limit(0, Some(fetch))?; } + // If the table scan has an alias but no projection or filters, it means no column references are rebased. + // So we will append the alias to this subquery. + // Example: + // select * from t1 limit 10 -> (select * from t1 limit 10) as a + if alias.is_some() + && (table_scan.projection.is_none() && table_scan.filters.is_empty()) + { + builder = builder.alias(alias.clone().unwrap())?; + } + Ok(Some(builder.build()?)) } LogicalPlan::SubqueryAlias(subquery_alias) => { @@ -728,6 +747,40 @@ impl Unparser<'_> { Some(subquery_alias.alias.clone()), ) } + // SubqueryAlias could be rewritten to a plan with a projection as the top node by [rewrite::subquery_alias_inner_query_and_columns]. + // The inner table scan could be a scan with pushdown operations. + LogicalPlan::Projection(projection) => { + if let Some(plan) = + Self::unparse_table_scan_pushdown(&projection.input, alias.clone())? + { + let exprs = if alias.is_some() { + let mut alias_rewriter = + alias.as_ref().map(|alias_name| TableAliasRewriter { + table_schema: plan.schema().as_arrow(), + alias_name: alias_name.clone(), + }); + projection + .expr + .iter() + .cloned() + .map(|expr| { + if let Some(ref mut rewriter) = alias_rewriter { + expr.rewrite(rewriter).data() + } else { + Ok(expr) + } + }) + .collect::>>()? + } else { + projection.expr.clone() + }; + Ok(Some( + LogicalPlanBuilder::from(plan).project(exprs)?.build()?, + )) + } else { + Ok(None) + } + } _ => Ok(None), } } diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index 3049df9396cb..57d700f86955 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -20,7 +20,7 @@ use std::{ sync::Arc, }; -use arrow_schema::SchemaRef; +use arrow_schema::Schema; use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, Column, Result, TableReference, @@ -293,7 +293,7 @@ pub(super) fn inject_column_aliases_into_subquery( /// - `SELECT col1, col2 FROM table` with aliases `["alias_1", "some_alias_2"]` will be transformed to /// - `SELECT col1 AS alias_1, col2 AS some_alias_2 FROM table` pub(super) fn inject_column_aliases( - projection: &datafusion_expr::Projection, + projection: &Projection, aliases: impl IntoIterator, ) -> LogicalPlan { let mut updated_projection = projection.clone(); @@ -343,12 +343,12 @@ fn find_projection(logical_plan: &LogicalPlan) -> Option<&Projection> { /// from which the columns are referenced. This is used to look up columns by their names. /// * `alias_name`: The alias (`TableReference`) that will replace the table name /// in the column references when applicable. -pub struct TableAliasRewriter { - pub table_schema: SchemaRef, +pub struct TableAliasRewriter<'a> { + pub table_schema: &'a Schema, pub alias_name: TableReference, } -impl TreeNodeRewriter for TableAliasRewriter { +impl TreeNodeRewriter for TableAliasRewriter<'_> { type Node = Expr; fn f_down(&mut self, expr: Expr) -> Result> { diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 0de74e050553..e7b96199511a 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -765,7 +765,7 @@ fn test_table_scan_alias() -> Result<()> { let table_scan_with_two_filter = plan_to_sql(&table_scan_with_two_filter)?; assert_eq!( table_scan_with_two_filter.to_string(), - "SELECT * FROM (SELECT t1.id FROM t1 WHERE ((t1.id > 1) AND (t1.age < 2))) AS a" + "SELECT a.id FROM t1 AS a WHERE ((a.id > 1) AND (a.age < 2))" ); let table_scan_with_fetch = @@ -776,7 +776,7 @@ fn test_table_scan_alias() -> Result<()> { let table_scan_with_fetch = plan_to_sql(&table_scan_with_fetch)?; assert_eq!( table_scan_with_fetch.to_string(), - "SELECT * FROM (SELECT t1.id FROM (SELECT * FROM t1 LIMIT 10)) AS a" + "SELECT a.id FROM (SELECT * FROM t1 LIMIT 10) AS a" ); let table_scan_with_pushdown_all = table_scan_with_filter_and_fetch( @@ -792,7 +792,7 @@ fn test_table_scan_alias() -> Result<()> { let table_scan_with_pushdown_all = plan_to_sql(&table_scan_with_pushdown_all)?; assert_eq!( table_scan_with_pushdown_all.to_string(), - "SELECT * FROM (SELECT t1.id FROM (SELECT t1.id, t1.age FROM t1 WHERE (t1.id > 1) LIMIT 10)) AS a" + "SELECT a.id FROM (SELECT a.id, a.age FROM t1 AS a WHERE (a.id > 1) LIMIT 10) AS a" ); Ok(()) } From 4fca0d5122918c70e0783e44016609b27ba7c253 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Mon, 21 Oct 2024 22:37:50 +0200 Subject: [PATCH 039/110] fix: Verify supported type for Unary::Plus in sql planner (#13019) This adds a type check when planning unary plus operator. Since we currently do not represent the operator in our logical plan we can not check it later. Instead of introducing a new `Expr` this patch just verifies the type during the translation instead. --- datafusion/sql/src/expr/unary_op.rs | 19 ++++++++++++++++--- datafusion/sqllogictest/test_files/scalar.slt | 3 +++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/datafusion/sql/src/expr/unary_op.rs b/datafusion/sql/src/expr/unary_op.rs index 3c547050380d..06988eb03893 100644 --- a/datafusion/sql/src/expr/unary_op.rs +++ b/datafusion/sql/src/expr/unary_op.rs @@ -16,8 +16,11 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{not_impl_err, DFSchema, Result}; -use datafusion_expr::Expr; +use datafusion_common::{not_impl_err, plan_err, DFSchema, Result}; +use datafusion_expr::{ + type_coercion::{is_interval, is_timestamp}, + Expr, ExprSchemable, +}; use sqlparser::ast::{Expr as SQLExpr, UnaryOperator, Value}; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -33,7 +36,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.sql_expr_to_logical_expr(expr, schema, planner_context)?, ))), UnaryOperator::Plus => { - Ok(self.sql_expr_to_logical_expr(expr, schema, planner_context)?) + let operand = + self.sql_expr_to_logical_expr(expr, schema, planner_context)?; + let (data_type, _) = operand.data_type_and_nullable(schema)?; + if data_type.is_numeric() + || is_interval(&data_type) + || is_timestamp(&data_type) + { + Ok(operand) + } else { + plan_err!("Unary operator '+' only supports numeric, interval and timestamp types") + } } UnaryOperator::Minus => { match expr { diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 0c2fa41e5bf8..d510206b1930 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1526,6 +1526,9 @@ NULL NULL query error DataFusion error: Error during planning: Negation only supports numeric, interval and timestamp types SELECT -'100' +query error DataFusion error: Error during planning: Unary operator '\+' only supports numeric, interval and timestamp types +SELECT +true + statement ok drop table test_boolean From 34fbe8e2161cd56171a04eeae6bb36c2b00040d9 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Mon, 21 Oct 2024 22:39:08 +0200 Subject: [PATCH 040/110] Fix count on all null `VALUES` clause (#13029) * Test Count accumulator with all-nulls * Fix count on null values Before the change, the `ValuesExec` containing `NullArray` would incorrectly report column statistics as being non-null, which would misinform `AggregateStatistics` optimizer and fold `count(always_null)` into row count instead of 0. This commit fixes the column statistics derivation for values with `NullArray` and therefore fixes execution of logical plans with count over such values. Note that the bug was not reproducible using DataFusion SQL frontend, because in DataFusion SQL the `VALUES (NULL)` doesn't have type `DataType:Null` (it has some apparently arbitrarily picked type instead). As a follow-up, all usages of `Array:null_count` should be inspected. The function can easily be misused (it returns "physical nulls", which do not exist for null type). --- datafusion/core/tests/core_integration.rs | 3 + .../core/tests/execution/logical_plan.rs | 95 +++++++++++++++++++ datafusion/core/tests/execution/mod.rs | 18 ++++ datafusion/functions-aggregate/src/count.rs | 14 +++ datafusion/physical-plan/src/common.rs | 6 +- datafusion/physical-plan/src/values.rs | 31 ++++++ 6 files changed, 166 insertions(+), 1 deletion(-) create mode 100644 datafusion/core/tests/execution/logical_plan.rs create mode 100644 datafusion/core/tests/execution/mod.rs diff --git a/datafusion/core/tests/core_integration.rs b/datafusion/core/tests/core_integration.rs index 79e5056e3cf5..e0917e6cca19 100644 --- a/datafusion/core/tests/core_integration.rs +++ b/datafusion/core/tests/core_integration.rs @@ -24,6 +24,9 @@ mod dataframe; /// Run all tests that are found in the `macro_hygiene` directory mod macro_hygiene; +/// Run all tests that are found in the `execution` directory +mod execution; + /// Run all tests that are found in the `expr_api` directory mod expr_api; diff --git a/datafusion/core/tests/execution/logical_plan.rs b/datafusion/core/tests/execution/logical_plan.rs new file mode 100644 index 000000000000..168bf484e541 --- /dev/null +++ b/datafusion/core/tests/execution/logical_plan.rs @@ -0,0 +1,95 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::Int64Array; +use arrow_schema::{DataType, Field}; +use datafusion::execution::session_state::SessionStateBuilder; +use datafusion_common::{Column, DFSchema, Result, ScalarValue}; +use datafusion_execution::TaskContext; +use datafusion_expr::expr::AggregateFunction; +use datafusion_expr::logical_plan::{LogicalPlan, Values}; +use datafusion_expr::{Aggregate, AggregateUDF, Expr}; +use datafusion_functions_aggregate::count::Count; +use datafusion_physical_plan::collect; +use std::collections::HashMap; +use std::fmt::Debug; +use std::ops::Deref; +use std::sync::Arc; + +///! Logical plans need to provide stable semantics, as downstream projects +///! create them and depend on them. Test executable semantics of logical plans. + +#[tokio::test] +async fn count_only_nulls() -> Result<()> { + // Input: VALUES (NULL), (NULL), (NULL) AS _(col) + let input_schema = Arc::new(DFSchema::from_unqualified_fields( + vec![Field::new("col", DataType::Null, true)].into(), + HashMap::new(), + )?); + let input = Arc::new(LogicalPlan::Values(Values { + schema: input_schema, + values: vec![ + vec![Expr::Literal(ScalarValue::Null)], + vec![Expr::Literal(ScalarValue::Null)], + vec![Expr::Literal(ScalarValue::Null)], + ], + })); + let input_col_ref = Expr::Column(Column { + relation: None, + name: "col".to_string(), + }); + + // Aggregation: count(col) AS count + let aggregate = LogicalPlan::Aggregate(Aggregate::try_new( + input, + vec![], + vec![Expr::AggregateFunction(AggregateFunction { + func: Arc::new(AggregateUDF::new_from_impl(Count::new())), + args: vec![input_col_ref], + distinct: false, + filter: None, + order_by: None, + null_treatment: None, + })], + )?); + + // Execute and verify results + let session_state = SessionStateBuilder::new().build(); + let physical_plan = session_state.create_physical_plan(&aggregate).await?; + let result = + collect(physical_plan, Arc::new(TaskContext::from(&session_state))).await?; + + let result = only(result.as_slice()); + let result_schema = result.schema(); + let field = only(result_schema.fields().deref()); + let column = only(result.columns()); + + assert_eq!(field.data_type(), &DataType::Int64); // TODO should be UInt64 + assert_eq!(column.deref(), &Int64Array::from(vec![0])); + + Ok(()) +} + +fn only(elements: &[T]) -> &T +where + T: Debug, +{ + let [element] = elements else { + panic!("Expected exactly one element, got {:?}", elements); + }; + element +} diff --git a/datafusion/core/tests/execution/mod.rs b/datafusion/core/tests/execution/mod.rs new file mode 100644 index 000000000000..8169db1a4611 --- /dev/null +++ b/datafusion/core/tests/execution/mod.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod logical_plan; diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 61dbfd674993..b4eeb937d4fb 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -715,3 +715,17 @@ impl Accumulator for DistinctCountAccumulator { } } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::NullArray; + + #[test] + fn count_accumulator_nulls() -> Result<()> { + let mut accumulator = CountAccumulator::new(); + accumulator.update_batch(&[Arc::new(NullArray::new(10))])?; + assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0))); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index 4b5eea6b760d..5abdf367c571 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -156,7 +156,11 @@ pub fn compute_record_batch_statistics( for partition in batches.iter() { for batch in partition { for (stat_index, col_index) in projection.iter().enumerate() { - null_counts[stat_index] += batch.column(*col_index).null_count(); + null_counts[stat_index] += batch + .column(*col_index) + .logical_nulls() + .map(|nulls| nulls.null_count()) + .unwrap_or_default(); } } } diff --git a/datafusion/physical-plan/src/values.rs b/datafusion/physical-plan/src/values.rs index e01aea1fdd6b..ab5b45463b0c 100644 --- a/datafusion/physical-plan/src/values.rs +++ b/datafusion/physical-plan/src/values.rs @@ -219,6 +219,7 @@ mod tests { use crate::test::{self, make_partition}; use arrow_schema::{DataType, Field}; + use datafusion_common::stats::{ColumnStatistics, Precision}; #[tokio::test] async fn values_empty_case() -> Result<()> { @@ -269,4 +270,34 @@ mod tests { let _ = ValuesExec::try_new(schema, vec![vec![lit(ScalarValue::UInt32(None))]]) .unwrap_err(); } + + #[test] + fn values_stats_with_nulls_only() -> Result<()> { + let data = vec![ + vec![lit(ScalarValue::Null)], + vec![lit(ScalarValue::Null)], + vec![lit(ScalarValue::Null)], + ]; + let rows = data.len(); + let values = ValuesExec::try_new( + Arc::new(Schema::new(vec![Field::new("col0", DataType::Null, true)])), + data, + )?; + + assert_eq!( + values.statistics()?, + Statistics { + num_rows: Precision::Exact(rows), + total_byte_size: Precision::Exact(8), // not important + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(rows), // there are only nulls + distinct_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + },], + } + ); + + Ok(()) + } } From b978cf8236436038a106ed94fb0d7eaa6ba99962 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Tue, 22 Oct 2024 04:57:04 +0200 Subject: [PATCH 041/110] Support filter in cross join elimination (#13025) * Support filter in cross join elimination * Support filter in cross join elimination * Support filter in cross join elimination * Support filter in cross join elimination --- .../optimizer/src/eliminate_cross_join.rs | 61 +++++++++++-------- datafusion/sqllogictest/test_files/join.slt | 2 +- 2 files changed, 38 insertions(+), 25 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index bce5c77ca674..8a365fb389be 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -22,13 +22,13 @@ use crate::{OptimizerConfig, OptimizerRule}; use crate::join_key_set::JoinKeySet; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{internal_err, Result}; +use datafusion_common::Result; use datafusion_expr::expr::{BinaryExpr, Expr}; use datafusion_expr::logical_plan::{ Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection, }; use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair}; -use datafusion_expr::{build_join_schema, ExprSchemable, Operator}; +use datafusion_expr::{and, build_join_schema, ExprSchemable, Operator}; #[derive(Default, Debug)] pub struct EliminateCrossJoin; @@ -88,6 +88,7 @@ impl OptimizerRule for EliminateCrossJoin { let plan_schema = Arc::clone(plan.schema()); let mut possible_join_keys = JoinKeySet::new(); let mut all_inputs: Vec = vec![]; + let mut all_filters: Vec = vec![]; let parent_predicate = if let LogicalPlan::Filter(filter) = plan { // if input isn't a join that can potentially be rewritten @@ -116,6 +117,7 @@ impl OptimizerRule for EliminateCrossJoin { Arc::unwrap_or_clone(input), &mut possible_join_keys, &mut all_inputs, + &mut all_filters, )?; extract_possible_join_keys(&predicate, &mut possible_join_keys); @@ -130,7 +132,12 @@ impl OptimizerRule for EliminateCrossJoin { if !can_flatten_join_inputs(&plan) { return Ok(Transformed::no(plan)); } - flatten_join_inputs(plan, &mut possible_join_keys, &mut all_inputs)?; + flatten_join_inputs( + plan, + &mut possible_join_keys, + &mut all_inputs, + &mut all_filters, + )?; None } else { // recursively try to rewrite children @@ -158,6 +165,13 @@ impl OptimizerRule for EliminateCrossJoin { )); } + if !all_filters.is_empty() { + // Add any filters on top - PushDownFilter can push filters down to applicable join + let first = all_filters.swap_remove(0); + let predicate = all_filters.into_iter().fold(first, and); + left = LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left))?); + } + let Some(predicate) = parent_predicate else { return Ok(Transformed::yes(left)); }; @@ -206,25 +220,25 @@ fn flatten_join_inputs( plan: LogicalPlan, possible_join_keys: &mut JoinKeySet, all_inputs: &mut Vec, + all_filters: &mut Vec, ) -> Result<()> { match plan { LogicalPlan::Join(join) if join.join_type == JoinType::Inner => { - // checked in can_flatten_join_inputs - if join.filter.is_some() { - return internal_err!( - "should not have filter in inner join in flatten_join_inputs" - ); + if let Some(filter) = join.filter { + all_filters.push(filter); } possible_join_keys.insert_all_owned(join.on); flatten_join_inputs( Arc::unwrap_or_clone(join.left), possible_join_keys, all_inputs, + all_filters, )?; flatten_join_inputs( Arc::unwrap_or_clone(join.right), possible_join_keys, all_inputs, + all_filters, )?; } LogicalPlan::CrossJoin(join) => { @@ -232,11 +246,13 @@ fn flatten_join_inputs( Arc::unwrap_or_clone(join.left), possible_join_keys, all_inputs, + all_filters, )?; flatten_join_inputs( Arc::unwrap_or_clone(join.right), possible_join_keys, all_inputs, + all_filters, )?; } _ => { @@ -253,13 +269,7 @@ fn flatten_join_inputs( fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool { // can only flatten inner / cross joins match plan { - LogicalPlan::Join(join) if join.join_type == JoinType::Inner => { - // The filter of inner join will lost, skip this rule. - // issue: https://github.com/apache/datafusion/issues/4844 - if join.filter.is_some() { - return false; - } - } + LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {} LogicalPlan::CrossJoin(_) => {} _ => return false, }; @@ -467,12 +477,6 @@ mod tests { assert_eq!(&starting_schema, optimized_plan.schema()) } - fn assert_optimization_rule_fails(plan: LogicalPlan) { - let rule = EliminateCrossJoin::new(); - let transformed_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap(); - assert!(!transformed_plan.transformed) - } - #[test] fn eliminate_cross_with_simple_and() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; @@ -642,8 +646,7 @@ mod tests { } #[test] - /// See https://github.com/apache/datafusion/issues/7530 - fn eliminate_cross_not_possible_nested_inner_join_with_filter() -> Result<()> { + fn eliminate_cross_possible_nested_inner_join_with_filter() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; let t2 = test_table_scan_with_name("t2")?; let t3 = test_table_scan_with_name("t3")?; @@ -660,7 +663,17 @@ mod tests { .filter(col("t1.a").gt(lit(15u32)))? .build()?; - assert_optimization_rule_fails(plan); + let expected = vec![ + "Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]" + ]; + + assert_optimized_plan_eq(plan, expected); Ok(()) } diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index fe9ceaa7907a..39f903a58714 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -1152,7 +1152,7 @@ logical_plan 01)Projection: t1.v0, t1.v1, t5.v2, t5.v3, t5.v4, t0.v0, t0.v1 02)--Inner Join: CAST(t1.v0 AS Float64) = t0.v1 Filter: t0.v1 + CAST(t5.v0 AS Float64) > Float64(0) 03)----Projection: t1.v0, t1.v1, t5.v0, t5.v2, t5.v3, t5.v4 -04)------Inner Join: Using t1.v0 = t5.v0, t1.v1 = t5.v1 +04)------Inner Join: t1.v0 = t5.v0, t1.v1 = t5.v1 05)--------TableScan: t1 projection=[v0, v1] 06)--------TableScan: t5 projection=[v0, v1, v2, v3, v4] 07)----TableScan: t0 projection=[v0, v1] From 465d6609bc7c284321ac7b1c2934a0a21951346f Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Tue, 22 Oct 2024 05:11:49 +0200 Subject: [PATCH 042/110] Do no alias in TableScan filters (#13048) --- datafusion/core/tests/expr_api/simplification.rs | 4 ++-- datafusion/expr/src/expr_rewriter/mod.rs | 9 ++++++--- .../optimizer/src/simplify_expressions/simplify_exprs.rs | 2 +- datafusion/sqllogictest/test_files/tpch/q22.slt.part | 2 +- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index d7995d4663be..800a087587da 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -333,8 +333,8 @@ fn simplify_scan_predicate() -> Result<()> { .build()?; // before simplify: t.g = power(t.f, 1.0) - // after simplify: (t.g = t.f) as "t.g = power(t.f, 1.0)" - let expected = "TableScan: test, full_filters=[g = f AS g = power(f,Float64(1))]"; + // after simplify: t.g = t.f" + let expected = "TableScan: test, full_filters=[g = f]"; let actual = get_optimized_plan_formatted(plan, &Utc::now()); assert_eq!(expected, actual); Ok(()) diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 15930914dd59..47cc947be3ca 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -306,9 +306,12 @@ impl NamePreserver { /// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan pub fn new(plan: &LogicalPlan) -> Self { Self { - // The schema of Filter and Join nodes comes from their inputs rather than their output expressions, - // so there is no need to use aliases to preserve expression names. - use_alias: !matches!(plan, LogicalPlan::Filter(_) | LogicalPlan::Join(_)), + // The schema of Filter, Join and TableScan nodes comes from their inputs rather than + // their expressions, so there is no need to use aliases to preserve expression names. + use_alias: !matches!( + plan, + LogicalPlan::Filter(_) | LogicalPlan::Join(_) | LogicalPlan::TableScan(_) + ), } } diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index c0142ae0fc5a..200f1f159d81 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -208,7 +208,7 @@ mod tests { assert_eq!(1, table_scan.schema().fields().len()); assert_fields_eq(&table_scan, vec!["a"]); - let expected = "TableScan: test projection=[a], full_filters=[Boolean(true) AS b IS NOT NULL]"; + let expected = "TableScan: test projection=[a], full_filters=[Boolean(true)]"; assert_optimized_plan_eq(table_scan, expected) } diff --git a/datafusion/sqllogictest/test_files/tpch/q22.slt.part b/datafusion/sqllogictest/test_files/tpch/q22.slt.part index d2168b0136ba..2955748160ea 100644 --- a/datafusion/sqllogictest/test_files/tpch/q22.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q22.slt.part @@ -72,7 +72,7 @@ logical_plan 14)--------------Aggregate: groupBy=[[]], aggr=[[avg(customer.c_acctbal)]] 15)----------------Projection: customer.c_acctbal 16)------------------Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) -17)--------------------TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2) AS customer.c_acctbal > Decimal128(Some(0),30,15), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]), customer.c_acctbal > Decimal128(Some(0),15,2)] +17)--------------------TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])] physical_plan 01)SortPreservingMergeExec: [cntrycode@0 ASC NULLS LAST] 02)--SortExec: expr=[cntrycode@0 ASC NULLS LAST], preserve_partitioning=[true] From 755ba9158ac2125b5d5b10bb76e27ee9137f7552 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <33904309+akurmustafa@users.noreply.github.com> Date: Mon, 21 Oct 2024 23:48:37 -0700 Subject: [PATCH 043/110] [minor]: remove same util functions from the code base. (#13026) * Initial commit * Resolve linter errors * Decrease diff --- .../tests/fuzz_cases/equivalence/ordering.rs | 173 ++++++++++++++++- .../tests/fuzz_cases/equivalence/utils.rs | 57 ++++++ .../physical-expr/src/equivalence/mod.rs | 178 ------------------ .../physical-expr/src/equivalence/ordering.rs | 171 +---------------- 4 files changed, 230 insertions(+), 349 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs index 604d1a1000c3..94157e11702c 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs @@ -16,7 +16,7 @@ // under the License. use crate::fuzz_cases::equivalence::utils::{ - convert_to_orderings, create_random_schema, create_test_schema_2, + convert_to_orderings, create_random_schema, create_test_params, create_test_schema_2, generate_table_for_eq_properties, generate_table_for_orderings, is_table_same_after_sort, TestScalarUDF, }; @@ -160,6 +160,177 @@ fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { Ok(()) } +#[test] +fn test_ordering_satisfy_with_equivalence() -> Result<()> { + // Schema satisfies following orderings: + // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] + // and + // Column [a=c] (e.g they are aliases). + let (test_schema, eq_properties) = create_test_params()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_g = &col("g", &test_schema)?; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, 625, 5)?; + + // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function + let requirements = vec![ + // `a ASC NULLS LAST`, expects `ordering_satisfy` to be `true`, since existing ordering `a ASC NULLS LAST, b ASC NULLS LAST` satisfies it + (vec![(col_a, option_asc)], true), + (vec![(col_a, option_desc)], false), + // Test whether equivalence works as expected + (vec![(col_c, option_asc)], true), + (vec![(col_c, option_desc)], false), + // Test whether ordering equivalence works as expected + (vec![(col_d, option_asc)], true), + (vec![(col_d, option_asc), (col_b, option_asc)], true), + (vec![(col_d, option_desc), (col_b, option_asc)], false), + ( + vec![ + (col_e, option_desc), + (col_f, option_asc), + (col_g, option_asc), + ], + true, + ), + (vec![(col_e, option_desc), (col_f, option_asc)], true), + (vec![(col_e, option_asc), (col_f, option_asc)], false), + (vec![(col_e, option_desc), (col_b, option_asc)], false), + (vec![(col_e, option_asc), (col_b, option_asc)], false), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_d, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_desc), + (col_f, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_desc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_d, option_desc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_asc), + (col_f, option_asc), + ], + false, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_asc), + (col_b, option_asc), + ], + false, + ), + (vec![(col_d, option_asc), (col_e, option_desc)], true), + ( + vec![ + (col_d, option_asc), + (col_c, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_f, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_c, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_b, option_asc), + (col_f, option_asc), + ], + true, + ), + ]; + + for (cols, expected) in requirements { + let err_msg = format!("Error in test case:{cols:?}"); + let required = cols + .into_iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: Arc::clone(expr), + options, + }) + .collect::>(); + + // Check expected result with experimental result. + assert_eq!( + is_table_same_after_sort( + required.clone(), + table_data_with_properties.clone() + )?, + expected + ); + assert_eq!( + eq_properties.ordering_satisfy(&required), + expected, + "{err_msg}" + ); + } + + Ok(()) +} + // This test checks given a table is ordered with `[a ASC, b ASC, c ASC, d ASC]` and `[a ASC, c ASC, b ASC, d ASC]` // whether the table is also ordered with `[a ASC, b ASC, d ASC]` and `[a ASC, c ASC, d ASC]` // Since these orderings cannot be deduced, these orderings shouldn't be satisfied by the table generated. diff --git a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs index ce3afba81ee2..61691311fe4e 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs @@ -299,6 +299,63 @@ fn get_representative_arr( None } +// Generate a schema which consists of 8 columns (a, b, c, d, e, f, g, h) +pub fn create_test_schema() -> Result { + let a = Field::new("a", DataType::Int32, true); + let b = Field::new("b", DataType::Int32, true); + let c = Field::new("c", DataType::Int32, true); + let d = Field::new("d", DataType::Int32, true); + let e = Field::new("e", DataType::Int32, true); + let f = Field::new("f", DataType::Int32, true); + let g = Field::new("g", DataType::Int32, true); + let h = Field::new("h", DataType::Int32, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g, h])); + + Ok(schema) +} + +/// Construct a schema with following properties +/// Schema satisfies following orderings: +/// [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] +/// and +/// Column [a=c] (e.g they are aliases). +pub fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> { + let test_schema = create_test_schema()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_g = &col("g", &test_schema)?; + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); + eq_properties.add_equal_conditions(col_a, col_c)?; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let orderings = vec![ + // [a ASC] + vec![(col_a, option_asc)], + // [d ASC, b ASC] + vec![(col_d, option_asc), (col_b, option_asc)], + // [e DESC, f ASC, g ASC] + vec![ + (col_e, option_desc), + (col_f, option_asc), + (col_g, option_asc), + ], + ]; + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + Ok((test_schema, eq_properties)) +} + // Generate a table that satisfies the given equivalence properties; i.e. // equivalences, ordering equivalences, and constants. pub fn generate_table_for_eq_properties( diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index 253f1196491b..95bb93d6ca57 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -77,16 +77,10 @@ mod tests { use crate::expressions::col; use crate::PhysicalSortExpr; - use arrow::compute::{lexsort_to_indices, SortColumn}; use arrow::datatypes::{DataType, Field, Schema}; - use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt32Array}; use arrow_schema::{SchemaRef, SortOptions}; use datafusion_common::{plan_datafusion_err, Result}; - use itertools::izip; - use rand::rngs::StdRng; - use rand::{Rng, SeedableRng}; - pub fn output_schema( mapping: &ProjectionMapping, input_schema: &Arc, @@ -290,176 +284,4 @@ mod tests { Ok(()) } - - /// Checks if the table (RecordBatch) remains unchanged when sorted according to the provided `required_ordering`. - /// - /// The function works by adding a unique column of ascending integers to the original table. This column ensures - /// that rows that are otherwise indistinguishable (e.g., if they have the same values in all other columns) can - /// still be differentiated. When sorting the extended table, the unique column acts as a tie-breaker to produce - /// deterministic sorting results. - /// - /// If the table remains the same after sorting with the added unique column, it indicates that the table was - /// already sorted according to `required_ordering` to begin with. - pub fn is_table_same_after_sort( - mut required_ordering: Vec, - batch: RecordBatch, - ) -> Result { - // Clone the original schema and columns - let original_schema = batch.schema(); - let mut columns = batch.columns().to_vec(); - - // Create a new unique column - let n_row = batch.num_rows(); - let vals: Vec = (0..n_row).collect::>(); - let vals: Vec = vals.into_iter().map(|val| val as f64).collect(); - let unique_col = Arc::new(Float64Array::from_iter_values(vals)) as ArrayRef; - columns.push(Arc::clone(&unique_col)); - - // Create a new schema with the added unique column - let unique_col_name = "unique"; - let unique_field = - Arc::new(Field::new(unique_col_name, DataType::Float64, false)); - let fields: Vec<_> = original_schema - .fields() - .iter() - .cloned() - .chain(std::iter::once(unique_field)) - .collect(); - let schema = Arc::new(Schema::new(fields)); - - // Create a new batch with the added column - let new_batch = RecordBatch::try_new(Arc::clone(&schema), columns)?; - - // Add the unique column to the required ordering to ensure deterministic results - required_ordering.push(PhysicalSortExpr { - expr: Arc::new(Column::new(unique_col_name, original_schema.fields().len())), - options: Default::default(), - }); - - // Convert the required ordering to a list of SortColumn - let sort_columns = required_ordering - .iter() - .map(|order_expr| { - let expr_result = order_expr.expr.evaluate(&new_batch)?; - let values = expr_result.into_array(new_batch.num_rows())?; - Ok(SortColumn { - values, - options: Some(order_expr.options), - }) - }) - .collect::>>()?; - - // Check if the indices after sorting match the initial ordering - let sorted_indices = lexsort_to_indices(&sort_columns, None)?; - let original_indices = UInt32Array::from_iter_values(0..n_row as u32); - - Ok(sorted_indices == original_indices) - } - - // If we already generated a random result for one of the - // expressions in the equivalence classes. For other expressions in the same - // equivalence class use same result. This util gets already calculated result, when available. - fn get_representative_arr( - eq_group: &EquivalenceClass, - existing_vec: &[Option], - schema: SchemaRef, - ) -> Option { - for expr in eq_group.iter() { - let col = expr.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - if let Some(res) = &existing_vec[idx] { - return Some(Arc::clone(res)); - } - } - None - } - - // Generate a table that satisfies the given equivalence properties; i.e. - // equivalences, ordering equivalences, and constants. - pub fn generate_table_for_eq_properties( - eq_properties: &EquivalenceProperties, - n_elem: usize, - n_distinct: usize, - ) -> Result { - let mut rng = StdRng::seed_from_u64(23); - - let schema = eq_properties.schema(); - let mut schema_vec = vec![None; schema.fields.len()]; - - // Fill constant columns - for constant in &eq_properties.constants { - let col = constant.expr().as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - let arr = Arc::new(Float64Array::from_iter_values(vec![0 as f64; n_elem])) - as ArrayRef; - schema_vec[idx] = Some(arr); - } - - // Fill columns based on ordering equivalences - for ordering in eq_properties.oeq_class.iter() { - let (sort_columns, indices): (Vec<_>, Vec<_>) = ordering - .iter() - .map(|PhysicalSortExpr { expr, options }| { - let col = expr.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - let arr = generate_random_f64_array(n_elem, n_distinct, &mut rng); - ( - SortColumn { - values: arr, - options: Some(*options), - }, - idx, - ) - }) - .unzip(); - - let sort_arrs = arrow::compute::lexsort(&sort_columns, None)?; - for (idx, arr) in izip!(indices, sort_arrs) { - schema_vec[idx] = Some(arr); - } - } - - // Fill columns based on equivalence groups - for eq_group in eq_properties.eq_group.iter() { - let representative_array = - get_representative_arr(eq_group, &schema_vec, Arc::clone(schema)) - .unwrap_or_else(|| { - generate_random_f64_array(n_elem, n_distinct, &mut rng) - }); - - for expr in eq_group.iter() { - let col = expr.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - schema_vec[idx] = Some(Arc::clone(&representative_array)); - } - } - - let res: Vec<_> = schema_vec - .into_iter() - .zip(schema.fields.iter()) - .map(|(elem, field)| { - ( - field.name(), - // Generate random values for columns that do not occur in any of the groups (equivalence, ordering equivalence, constants) - elem.unwrap_or_else(|| { - generate_random_f64_array(n_elem, n_distinct, &mut rng) - }), - ) - }) - .collect(); - - Ok(RecordBatch::try_from_iter(res)?) - } - - // Utility function to generate random f64 array - fn generate_random_f64_array( - n_elems: usize, - n_distinct: usize, - rng: &mut StdRng, - ) -> ArrayRef { - let values: Vec = (0..n_elems) - .map(|_| rng.gen_range(0..n_distinct) as f64 / 2.0) - .collect(); - Arc::new(Float64Array::from_iter_values(values)) - } } diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index a3cf8c965b69..d71f3b037fb1 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -254,8 +254,7 @@ mod tests { use std::sync::Arc; use crate::equivalence::tests::{ - convert_to_orderings, convert_to_sort_exprs, create_test_params, - create_test_schema, generate_table_for_eq_properties, is_table_same_after_sort, + convert_to_orderings, convert_to_sort_exprs, create_test_schema, }; use crate::equivalence::{ EquivalenceClass, EquivalenceGroup, EquivalenceProperties, @@ -600,174 +599,6 @@ mod tests { Ok(()) } - #[test] - fn test_ordering_satisfy_with_equivalence() -> Result<()> { - // Schema satisfies following orderings: - // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] - // and - // Column [a=c] (e.g they are aliases). - let (test_schema, eq_properties) = create_test_params()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let col_g = &col("g", &test_schema)?; - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, 625, 5)?; - - // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function - let requirements = vec![ - // `a ASC NULLS LAST`, expects `ordering_satisfy` to be `true`, since existing ordering `a ASC NULLS LAST, b ASC NULLS LAST` satisfies it - (vec![(col_a, option_asc)], true), - (vec![(col_a, option_desc)], false), - // Test whether equivalence works as expected - (vec![(col_c, option_asc)], true), - (vec![(col_c, option_desc)], false), - // Test whether ordering equivalence works as expected - (vec![(col_d, option_asc)], true), - (vec![(col_d, option_asc), (col_b, option_asc)], true), - (vec![(col_d, option_desc), (col_b, option_asc)], false), - ( - vec![ - (col_e, option_desc), - (col_f, option_asc), - (col_g, option_asc), - ], - true, - ), - (vec![(col_e, option_desc), (col_f, option_asc)], true), - (vec![(col_e, option_asc), (col_f, option_asc)], false), - (vec![(col_e, option_desc), (col_b, option_asc)], false), - (vec![(col_e, option_asc), (col_b, option_asc)], false), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_d, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_desc), - (col_f, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_desc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_d, option_desc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_asc), - (col_f, option_asc), - ], - false, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_asc), - (col_b, option_asc), - ], - false, - ), - (vec![(col_d, option_asc), (col_e, option_desc)], true), - ( - vec![ - (col_d, option_asc), - (col_c, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_e, option_desc), - (col_f, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_e, option_desc), - (col_c, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_e, option_desc), - (col_b, option_asc), - (col_f, option_asc), - ], - true, - ), - ]; - - for (cols, expected) in requirements { - let err_msg = format!("Error in test case:{cols:?}"); - let required = cols - .into_iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: Arc::clone(expr), - options, - }) - .collect::>(); - - // Check expected result with experimental result. - assert_eq!( - is_table_same_after_sort( - required.clone(), - table_data_with_properties.clone() - )?, - expected - ); - assert_eq!( - eq_properties.ordering_satisfy(&required), - expected, - "{err_msg}" - ); - } - Ok(()) - } - #[test] fn test_ordering_satisfy_different_lengths() -> Result<()> { let test_schema = create_test_schema()?; From c22abb4ac3f1af8bbdf176ef0198988fc7b0982c Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 22 Oct 2024 05:43:21 -0400 Subject: [PATCH 044/110] Improve `AggregateFuzz` testing: generate random queries (#12847) * Add random queries into aggregate fuzz tester * Address review comments * Update datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs Co-authored-by: Jax Liu * Update datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs Co-authored-by: Jax Liu --------- Co-authored-by: Jax Liu --- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 385 ++++++------------ .../aggregation_fuzzer/data_generator.rs | 37 +- .../fuzz_cases/aggregation_fuzzer/fuzzer.rs | 237 ++++++++++- test-utils/src/string_gen.rs | 2 +- 4 files changed, 370 insertions(+), 291 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index ff512829333a..1035fa31da08 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -45,299 +45,150 @@ use rand::{Rng, SeedableRng}; use tokio::task::JoinSet; use crate::fuzz_cases::aggregation_fuzzer::{ - AggregationFuzzerBuilder, ColumnDescr, DatasetGeneratorConfig, + AggregationFuzzerBuilder, ColumnDescr, DatasetGeneratorConfig, QueryBuilder, }; // ======================================================================== // The new aggregation fuzz tests based on [`AggregationFuzzer`] // ======================================================================== - -// TODO: write more test case to cover more `group by`s and `aggregation function`s -// TODO: maybe we can use macro to simply the case creating - -/// Fuzz test for `basic prim aggr(sum/sum distinct/max/min/count/avg)` + `no group by` +// +// Notes on tests: +// +// Since the supported types differ for each aggregation function, the tests +// below are structured so they enumerate each different aggregate function. +// +// The test framework handles varying combinations of arguments (data types), +// sortedness, and grouping parameters +// +// TODO: Test floating point values (where output needs to be compared with some +// acceptable range due to floating point rounding) +// +// TODO: test other aggregate functions +// - AVG (unstable given the wide range of inputs) +// +// TODO: specific test for ordering (ensure all group by columns are ordered) +// Currently the data is sorted by random columns, so there are almost no +// repeated runs. To improve coverage we should also sort by lower cardinality columns #[tokio::test(flavor = "multi_thread")] -async fn test_basic_prim_aggr_no_group() { - let builder = AggregationFuzzerBuilder::default(); - - // Define data generator config - let columns = vec![ColumnDescr::new("a", DataType::Int32)]; - - let data_gen_config = DatasetGeneratorConfig { - columns, - rows_num_range: (512, 1024), - sort_keys_set: Vec::new(), - }; - - // Build fuzzer - let fuzzer = builder - .data_gen_config(data_gen_config) - .data_gen_rounds(16) - .add_sql("SELECT sum(a) FROM fuzz_table") - .add_sql("SELECT sum(distinct a) FROM fuzz_table") - .add_sql("SELECT max(a) FROM fuzz_table") - .add_sql("SELECT min(a) FROM fuzz_table") - .add_sql("SELECT count(a) FROM fuzz_table") - .add_sql("SELECT count(distinct a) FROM fuzz_table") - .add_sql("SELECT avg(a) FROM fuzz_table") - .table_name("fuzz_table") - .build(); - - fuzzer.run().await +async fn test_min() { + let data_gen_config = baseline_config(); + + // Queries like SELECT min(a) FROM fuzz_table GROUP BY b + let query_builder = QueryBuilder::new() + .with_table_name("fuzz_table") + .with_aggregate_function("min") + // min works on all column types + .with_aggregate_arguments(data_gen_config.all_columns()) + .with_group_by_columns(data_gen_config.all_columns()); + + AggregationFuzzerBuilder::from(data_gen_config) + .add_query_builder(query_builder) + .build() + .run() + .await; } -/// Fuzz test for `basic prim aggr(sum/sum distinct/max/min/count/avg)` + `group by single int64` #[tokio::test(flavor = "multi_thread")] -async fn test_basic_prim_aggr_group_by_single_int64() { - let builder = AggregationFuzzerBuilder::default(); - - // Define data generator config - let columns = vec![ - ColumnDescr::new("a", DataType::Int32), - ColumnDescr::new("b", DataType::Int64), - ColumnDescr::new("c", DataType::Int64), - ]; - let sort_keys_set = vec![ - vec!["b".to_string()], - vec!["c".to_string(), "b".to_string()], - ]; - let data_gen_config = DatasetGeneratorConfig { - columns, - rows_num_range: (512, 1024), - sort_keys_set, - }; - - // Build fuzzer - let fuzzer = builder - .data_gen_config(data_gen_config) - .data_gen_rounds(16) - .add_sql("SELECT b, sum(a) FROM fuzz_table GROUP BY b") - .add_sql("SELECT b, sum(distinct a) FROM fuzz_table GROUP BY b") - .add_sql("SELECT b, max(a) FROM fuzz_table GROUP BY b") - .add_sql("SELECT b, min(a) FROM fuzz_table GROUP BY b") - .add_sql("SELECT b, count(a) FROM fuzz_table GROUP BY b") - .add_sql("SELECT b, count(distinct a) FROM fuzz_table GROUP BY b") - .add_sql("SELECT b, avg(a) FROM fuzz_table GROUP BY b") - .table_name("fuzz_table") - .build(); - - fuzzer.run().await; +async fn test_max() { + let data_gen_config = baseline_config(); + + // Queries like SELECT max(a) FROM fuzz_table GROUP BY b + let query_builder = QueryBuilder::new() + .with_table_name("fuzz_table") + .with_aggregate_function("max") + // max works on all column types + .with_aggregate_arguments(data_gen_config.all_columns()) + .with_group_by_columns(data_gen_config.all_columns()); + + AggregationFuzzerBuilder::from(data_gen_config) + .add_query_builder(query_builder) + .build() + .run() + .await; } -/// Fuzz test for `basic prim aggr(sum/sum distinct/max/min/count/avg)` + `group by single string` #[tokio::test(flavor = "multi_thread")] -async fn test_basic_prim_aggr_group_by_single_string() { - let builder = AggregationFuzzerBuilder::default(); - - // Define data generator config - let columns = vec![ - ColumnDescr::new("a", DataType::Int32), - ColumnDescr::new("b", DataType::Utf8), - ColumnDescr::new("c", DataType::Int64), - ]; - let sort_keys_set = vec![ - vec!["b".to_string()], - vec!["c".to_string(), "b".to_string()], - ]; - let data_gen_config = DatasetGeneratorConfig { - columns, - rows_num_range: (512, 1024), - sort_keys_set, - }; - - // Build fuzzer - let fuzzer = builder - .data_gen_config(data_gen_config) - .data_gen_rounds(16) - .add_sql("SELECT b, sum(a) FROM fuzz_table GROUP BY b") - .add_sql("SELECT b, sum(distinct a) FROM fuzz_table GROUP BY b") - .add_sql("SELECT b, max(a) FROM fuzz_table GROUP BY b") - .add_sql("SELECT b, min(a) FROM fuzz_table GROUP BY b") - .add_sql("SELECT b, count(a) FROM fuzz_table GROUP BY b") - .add_sql("SELECT b, count(distinct a) FROM fuzz_table GROUP BY b") - .add_sql("SELECT b, avg(a) FROM fuzz_table GROUP BY b") - .table_name("fuzz_table") - .build(); - - fuzzer.run().await; +async fn test_sum() { + let data_gen_config = baseline_config(); + + // Queries like SELECT sum(a), sum(distinct) FROM fuzz_table GROUP BY b + let query_builder = QueryBuilder::new() + .with_table_name("fuzz_table") + .with_aggregate_function("sum") + .with_distinct_aggregate_function("sum") + // sum only works on numeric columns + .with_aggregate_arguments(data_gen_config.numeric_columns()) + .with_group_by_columns(data_gen_config.all_columns()); + + AggregationFuzzerBuilder::from(data_gen_config) + .add_query_builder(query_builder) + .build() + .run() + .await; } -/// Fuzz test for `basic prim aggr(sum/sum distinct/max/min/count/avg)` + `group by string + int64` #[tokio::test(flavor = "multi_thread")] -async fn test_basic_prim_aggr_group_by_mixed_string_int64() { - let builder = AggregationFuzzerBuilder::default(); - - // Define data generator config - let columns = vec![ - ColumnDescr::new("a", DataType::Int32), - ColumnDescr::new("b", DataType::Utf8), - ColumnDescr::new("c", DataType::Int64), - ColumnDescr::new("d", DataType::Int32), - ]; - let sort_keys_set = vec![ - vec!["b".to_string(), "c".to_string()], - vec!["d".to_string(), "b".to_string(), "c".to_string()], - ]; - let data_gen_config = DatasetGeneratorConfig { - columns, - rows_num_range: (512, 1024), - sort_keys_set, - }; - - // Build fuzzer - let fuzzer = builder - .data_gen_config(data_gen_config) - .data_gen_rounds(16) - .add_sql("SELECT b, c, sum(a) FROM fuzz_table GROUP BY b, c") - .add_sql("SELECT b, c, sum(distinct a) FROM fuzz_table GROUP BY b,c") - .add_sql("SELECT b, c, max(a) FROM fuzz_table GROUP BY b, c") - .add_sql("SELECT b, c, min(a) FROM fuzz_table GROUP BY b, c") - .add_sql("SELECT b, c, count(a) FROM fuzz_table GROUP BY b, c") - .add_sql("SELECT b, c, count(distinct a) FROM fuzz_table GROUP BY b, c") - .add_sql("SELECT b, c, avg(a) FROM fuzz_table GROUP BY b, c") - .table_name("fuzz_table") - .build(); - - fuzzer.run().await; +async fn test_count() { + let data_gen_config = baseline_config(); + + // Queries like SELECT count(a), count(distinct) FROM fuzz_table GROUP BY b + let query_builder = QueryBuilder::new() + .with_table_name("fuzz_table") + .with_aggregate_function("count") + .with_distinct_aggregate_function("count") + // count work for all arguments + .with_aggregate_arguments(data_gen_config.all_columns()) + .with_group_by_columns(data_gen_config.all_columns()); + + AggregationFuzzerBuilder::from(data_gen_config) + .add_query_builder(query_builder) + .build() + .run() + .await; } -/// Fuzz test for `basic string aggr(count/count distinct/min/max)` + `no group by` -#[tokio::test(flavor = "multi_thread")] -async fn test_basic_string_aggr_no_group() { - let builder = AggregationFuzzerBuilder::default(); - - // Define data generator config - let columns = vec![ColumnDescr::new("a", DataType::Utf8)]; - - let data_gen_config = DatasetGeneratorConfig { - columns, - rows_num_range: (512, 1024), - sort_keys_set: Vec::new(), - }; - - // Build fuzzer - let fuzzer = builder - .data_gen_config(data_gen_config) - .data_gen_rounds(8) - .add_sql("SELECT max(a) FROM fuzz_table") - .add_sql("SELECT min(a) FROM fuzz_table") - .add_sql("SELECT count(a) FROM fuzz_table") - .add_sql("SELECT count(distinct a) FROM fuzz_table") - .table_name("fuzz_table") - .build(); - - fuzzer.run().await; -} - -/// Fuzz test for `basic string aggr(count/count distinct/min/max)` + `group by single int64` -#[tokio::test(flavor = "multi_thread")] -async fn test_basic_string_aggr_group_by_single_int64() { - let builder = AggregationFuzzerBuilder::default(); - - // Define data generator config - let columns = vec![ - ColumnDescr::new("a", DataType::Utf8), - ColumnDescr::new("b", DataType::Int64), - ColumnDescr::new("c", DataType::Int64), - ]; - let sort_keys_set = vec![ - vec!["b".to_string()], - vec!["c".to_string(), "b".to_string()], - ]; - let data_gen_config = DatasetGeneratorConfig { - columns, - rows_num_range: (512, 1024), - sort_keys_set, - }; - - // Build fuzzer - let fuzzer = builder - .data_gen_config(data_gen_config) - .data_gen_rounds(8) - .add_sql("SELECT b, max(a) FROM fuzz_table GROUP BY b") - .add_sql("SELECT b, min(a) FROM fuzz_table GROUP BY b") - .add_sql("SELECT b, count(a) FROM fuzz_table GROUP BY b") - .add_sql("SELECT b, count(distinct a) FROM fuzz_table GROUP BY b") - .table_name("fuzz_table") - .build(); - - fuzzer.run().await; -} - -/// Fuzz test for `basic string aggr(count/count distinct/min/max)` + `group by single string` -#[tokio::test(flavor = "multi_thread")] -async fn test_basic_string_aggr_group_by_single_string() { - let builder = AggregationFuzzerBuilder::default(); - - // Define data generator config +/// Return a standard set of columns for testing data generation +/// +/// Includes numeric and string types +/// +/// Does not include: +/// 1. Floating point numbers +/// 1. structured types +fn baseline_config() -> DatasetGeneratorConfig { let columns = vec![ - ColumnDescr::new("a", DataType::Utf8), - ColumnDescr::new("b", DataType::Utf8), - ColumnDescr::new("c", DataType::Int64), - ]; - let sort_keys_set = vec![ - vec!["b".to_string()], - vec!["c".to_string(), "b".to_string()], + ColumnDescr::new("i8", DataType::Int8), + ColumnDescr::new("i16", DataType::Int16), + ColumnDescr::new("i32", DataType::Int32), + ColumnDescr::new("i64", DataType::Int64), + ColumnDescr::new("u8", DataType::UInt8), + ColumnDescr::new("u16", DataType::UInt16), + ColumnDescr::new("u32", DataType::UInt32), + ColumnDescr::new("u64", DataType::UInt64), + // TODO: date/time columns + // todo decimal columns + // begin string columns + ColumnDescr::new("utf8", DataType::Utf8), + ColumnDescr::new("largeutf8", DataType::LargeUtf8), + // TODO add support for utf8view in data generator + // ColumnDescr::new("utf8view", DataType::Utf8View), + // todo binary ]; - let data_gen_config = DatasetGeneratorConfig { - columns, - rows_num_range: (512, 1024), - sort_keys_set, - }; - - // Build fuzzer - let fuzzer = builder - .data_gen_config(data_gen_config) - .data_gen_rounds(16) - .add_sql("SELECT b, max(a) FROM fuzz_table GROUP BY b") - .add_sql("SELECT b, min(a) FROM fuzz_table GROUP BY b") - .add_sql("SELECT b, count(a) FROM fuzz_table GROUP BY b") - .add_sql("SELECT b, count(distinct a) FROM fuzz_table GROUP BY b") - .table_name("fuzz_table") - .build(); - - fuzzer.run().await; -} -/// Fuzz test for `basic string aggr(count/count distinct/min/max)` + `group by string + int64` -#[tokio::test(flavor = "multi_thread")] -async fn test_basic_string_aggr_group_by_mixed_string_int64() { - let builder = AggregationFuzzerBuilder::default(); - - // Define data generator config - let columns = vec![ - ColumnDescr::new("a", DataType::Utf8), - ColumnDescr::new("b", DataType::Utf8), - ColumnDescr::new("c", DataType::Int64), - ColumnDescr::new("d", DataType::Int32), - ]; - let sort_keys_set = vec![ - vec!["b".to_string(), "c".to_string()], - vec!["d".to_string(), "b".to_string(), "c".to_string()], - ]; - let data_gen_config = DatasetGeneratorConfig { + DatasetGeneratorConfig { columns, rows_num_range: (512, 1024), - sort_keys_set, - }; - - // Build fuzzer - let fuzzer = builder - .data_gen_config(data_gen_config) - .data_gen_rounds(16) - .add_sql("SELECT b, c, max(a) FROM fuzz_table GROUP BY b, c") - .add_sql("SELECT b, c, min(a) FROM fuzz_table GROUP BY b, c") - .add_sql("SELECT b, c, count(a) FROM fuzz_table GROUP BY b, c") - .add_sql("SELECT b, c, count(distinct a) FROM fuzz_table GROUP BY b, c") - .table_name("fuzz_table") - .build(); - - fuzzer.run().await; + sort_keys_set: vec![ + // low cardinality to try and get many repeated runs + vec![String::from("u8")], + vec![String::from("utf8"), String::from("u8")], + ], + } } // ======================================================================== // The old aggregation fuzz tests // ======================================================================== + /// Tracks if this stream is generating input or output /// Tests that streaming aggregate and batch (non streaming) aggregate produce /// same results @@ -353,7 +204,7 @@ async fn streaming_aggregate_test() { vec!["d", "c", "a"], vec!["d", "c", "b", "a"], ]; - let n = 300; + let n = 10; let distincts = vec![10, 20]; for distinct in distincts { let mut join_set = JoinSet::new(); diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs index 9d45779295e7..44f96d5a1a07 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs @@ -48,16 +48,41 @@ use test_utils::{ /// #[derive(Debug, Clone)] pub struct DatasetGeneratorConfig { - // Descriptions of columns in datasets, it's `required` + /// Descriptions of columns in datasets, it's `required` pub columns: Vec, - // Rows num range of the generated datasets, it's `required` + /// Rows num range of the generated datasets, it's `required` pub rows_num_range: (usize, usize), - // Sort keys used to generate the sorted data set, it's optional + /// Additional optional sort keys + /// + /// The generated datasets always include a non-sorted copy. For each + /// element in `sort_keys_set`, an additional datasets is created that + /// is sorted by these values as well. pub sort_keys_set: Vec>, } +impl DatasetGeneratorConfig { + /// return a list of all column names + pub fn all_columns(&self) -> Vec<&str> { + self.columns.iter().map(|d| d.name.as_str()).collect() + } + + /// return a list of column names that are "numeric" + pub fn numeric_columns(&self) -> Vec<&str> { + self.columns + .iter() + .filter_map(|d| { + if d.column_type.is_numeric() { + Some(d.name.as_str()) + } else { + None + } + }) + .collect() + } +} + /// Dataset generator /// /// It will generate one random [`Dataset`]s when `generate` function is called. @@ -96,7 +121,7 @@ impl DatasetGenerator { pub fn generate(&self) -> Result> { let mut datasets = Vec::with_capacity(self.sort_keys_set.len() + 1); - // Generate the base batch + // Generate the base batch (unsorted) let base_batch = self.batch_generator.generate()?; let batches = stagger_batch(base_batch.clone()); let dataset = Dataset::new(batches, Vec::new()); @@ -362,7 +387,9 @@ impl RecordBatchGenerator { DataType::LargeUtf8 => { generate_string_array!(self, num_rows, batch_gen_rng, array_gen_rng, i64) } - _ => unreachable!(), + _ => { + panic!("Unsupported data generator type: {data_type}") + } } } } diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs index 6daebc894272..898d1081ff13 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashSet; use std::sync::Arc; use arrow::util::pretty::pretty_format_batches; @@ -61,7 +62,16 @@ impl AggregationFuzzerBuilder { } } - pub fn add_sql(mut self, sql: &str) -> Self { + /// Adds random SQL queries to the fuzzer along with the table name + pub fn add_query_builder(mut self, query_builder: QueryBuilder) -> Self { + const NUM_QUERIES: usize = 10; + for _ in 0..NUM_QUERIES { + self = self.add_sql(&query_builder.generate_query()); + } + self.table_name(query_builder.table_name()) + } + + fn add_sql(mut self, sql: &str) -> Self { self.candidate_sqls.push(Arc::from(sql)); self } @@ -76,11 +86,6 @@ impl AggregationFuzzerBuilder { self } - pub fn data_gen_rounds(mut self, data_gen_rounds: usize) -> Self { - self.data_gen_rounds = data_gen_rounds; - self - } - pub fn build(self) -> AggregationFuzzer { assert!(!self.candidate_sqls.is_empty()); let candidate_sqls = self.candidate_sqls; @@ -99,12 +104,18 @@ impl AggregationFuzzerBuilder { } } -impl Default for AggregationFuzzerBuilder { +impl std::default::Default for AggregationFuzzerBuilder { fn default() -> Self { Self::new() } } +impl From for AggregationFuzzerBuilder { + fn from(value: DatasetGeneratorConfig) -> Self { + Self::default().data_gen_config(value) + } +} + /// AggregationFuzzer randomly generating multiple [`AggregationFuzzTestTask`], /// and running them to check the correctness of the optimizations /// (e.g. sorted, partial skipping, spilling...) @@ -169,6 +180,10 @@ impl AggregationFuzzer { }) .collect::>(); + for q in &query_groups { + println!(" Testing with query {}", q.sql); + } + let tasks = self.generate_fuzz_tasks(query_groups).await; for task in tasks { join_set.spawn(async move { task.run().await }); @@ -270,20 +285,27 @@ impl AggregationFuzzTestTask { check_equality_of_batches(task_result, expected_result).map_err(|e| { // If we found inconsistent result, we print the test details for reproducing at first let message = format!( - "{}\n\ - ### Inconsistent row:\n\ - - row_idx:{}\n\ - - task_row:{}\n\ - - expected_row:{}\n\ - ### Task total result:\n{}\n\ - ### Expected total result:\n{}\n\ - ", - self.context_error_report(), + "##### AggregationFuzzer error report #####\n\ + ### Sql:\n{}\n\ + ### Schema:\n{}\n\ + ### Session context params:\n{:?}\n\ + ### Inconsistent row:\n\ + - row_idx:{}\n\ + - task_row:{}\n\ + - expected_row:{}\n\ + ### Task total result:\n{}\n\ + ### Expected total result:\n{}\n\ + ### Input:\n{}\n\ + ", + self.sql, + self.dataset_ref.batches[0].schema_ref(), + self.ctx_with_params.params, e.row_idx, e.lhs_row, e.rhs_row, - pretty_format_batches(task_result).unwrap(), - pretty_format_batches(expected_result).unwrap(), + format_batches_with_limit(task_result), + format_batches_with_limit(expected_result), + format_batches_with_limit(&self.dataset_ref.batches), ); DataFusionError::Internal(message) }) @@ -305,3 +327,182 @@ impl AggregationFuzzTestTask { ) } } + +/// Pretty prints the `RecordBatch`es, limited to the first 100 rows +fn format_batches_with_limit(batches: &[RecordBatch]) -> impl std::fmt::Display { + const MAX_ROWS: usize = 100; + let mut row_count = 0; + let to_print = batches + .iter() + .filter_map(|b| { + if row_count >= MAX_ROWS { + None + } else if row_count + b.num_rows() > MAX_ROWS { + // output last rows before limit + let slice_len = MAX_ROWS - row_count; + let b = b.slice(0, slice_len); + row_count += slice_len; + Some(b) + } else { + row_count += b.num_rows(); + Some(b.clone()) + } + }) + .collect::>(); + + pretty_format_batches(&to_print).unwrap() +} + +/// Random aggregate query builder +/// +/// Creates queries like +/// ```sql +/// SELECT AGG(..) FROM table_name GROUP BY +///``` +#[derive(Debug, Default)] +pub struct QueryBuilder { + /// The name of the table to query + table_name: String, + /// Aggregate functions to be used in the query + /// (function_name, is_distinct) + aggregate_functions: Vec<(String, bool)>, + /// Columns to be used in group by + group_by_columns: Vec, + /// Possible columns for arguments in the aggregate functions + /// + /// Assumes each + arguments: Vec, +} +impl QueryBuilder { + pub fn new() -> Self { + std::default::Default::default() + } + + /// return the table name if any + pub fn table_name(&self) -> &str { + &self.table_name + } + + /// Set the table name for the query builder + pub fn with_table_name(mut self, table_name: impl Into) -> Self { + self.table_name = table_name.into(); + self + } + + /// Add a new possible aggregate function to the query builder + pub fn with_aggregate_function( + mut self, + aggregate_function: impl Into, + ) -> Self { + self.aggregate_functions + .push((aggregate_function.into(), false)); + self + } + + /// Add a new possible `DISTINCT` aggregate function to the query + /// + /// This is different than `with_aggregate_function` because only certain + /// aggregates support `DISTINCT` + pub fn with_distinct_aggregate_function( + mut self, + aggregate_function: impl Into, + ) -> Self { + self.aggregate_functions + .push((aggregate_function.into(), true)); + self + } + + /// Add a column to be used in the group bys + pub fn with_group_by_columns<'a>( + mut self, + group_by: impl IntoIterator, + ) -> Self { + let group_by = group_by.into_iter().map(String::from); + self.group_by_columns.extend(group_by); + self + } + + /// Add a column to be used as an argument in the aggregate functions + pub fn with_aggregate_arguments<'a>( + mut self, + arguments: impl IntoIterator, + ) -> Self { + let arguments = arguments.into_iter().map(String::from); + self.arguments.extend(arguments); + self + } + + pub fn generate_query(&self) -> String { + let group_by = self.random_group_by(); + let mut query = String::from("SELECT "); + query.push_str(&self.random_aggregate_functions().join(", ")); + query.push_str(" FROM "); + query.push_str(&self.table_name); + if !group_by.is_empty() { + query.push_str(" GROUP BY "); + query.push_str(&group_by.join(", ")); + } + query + } + + /// Generate a some random aggregate function invocations (potentially repeating). + /// + /// Each aggregate function invocation is of the form + /// + /// ```sql + /// function_name( argument) as alias + /// ``` + /// + /// where + /// * `function_names` are randomly selected from [`Self::aggregate_functions`] + /// * ` argument` is randomly selected from [`Self::arguments`] + /// * `alias` is a unique alias `colN` for the column (to avoid duplicate column names) + fn random_aggregate_functions(&self) -> Vec { + const MAX_NUM_FUNCTIONS: usize = 5; + let mut rng = thread_rng(); + let num_aggregate_functions = rng.gen_range(1..MAX_NUM_FUNCTIONS); + + let mut alias_gen = 1; + + let mut aggregate_functions = vec![]; + while aggregate_functions.len() < num_aggregate_functions { + let idx = rng.gen_range(0..self.aggregate_functions.len()); + let (function_name, is_distinct) = &self.aggregate_functions[idx]; + let argument = self.random_argument(); + let alias = format!("col{}", alias_gen); + let distinct = if *is_distinct { "DISTINCT " } else { "" }; + alias_gen += 1; + let function = format!("{function_name}({distinct}{argument}) as {alias}"); + aggregate_functions.push(function); + } + aggregate_functions + } + + /// Pick a random aggregate function argument + fn random_argument(&self) -> String { + let mut rng = thread_rng(); + let idx = rng.gen_range(0..self.arguments.len()); + self.arguments[idx].clone() + } + + /// Pick a random number of fields to group by (non-repeating) + /// + /// Limited to 3 group by columns to ensure coverage for large groups. With + /// larger numbers of columns, each group has many fewer values. + fn random_group_by(&self) -> Vec { + let mut rng = thread_rng(); + const MAX_GROUPS: usize = 3; + let max_groups = self.group_by_columns.len().max(MAX_GROUPS); + let num_group_by = rng.gen_range(1..max_groups); + + let mut already_used = HashSet::new(); + let mut group_by = vec![]; + while group_by.len() < num_group_by { + let idx = rng.gen_range(0..self.group_by_columns.len()); + if already_used.insert(idx) { + group_by.push(self.group_by_columns[idx].clone()); + } + } + group_by + } +} diff --git a/test-utils/src/string_gen.rs b/test-utils/src/string_gen.rs index 725eb22b85af..b598241db1e9 100644 --- a/test-utils/src/string_gen.rs +++ b/test-utils/src/string_gen.rs @@ -62,7 +62,7 @@ impl StringBatchGenerator { let mut cases = vec![]; let mut rng = thread_rng(); for null_pct in [0.0, 0.01, 0.1, 0.5] { - for _ in 0..100 { + for _ in 0..10 { // max length of generated strings let max_len = rng.gen_range(1..50); let num_strings = rng.gen_range(1..100); From ef1365aa2416867caa63d610008ed41a7413a6e4 Mon Sep 17 00:00:00 2001 From: Agaev Guseyn <60943542+agscpp@users.noreply.github.com> Date: Tue, 22 Oct 2024 18:05:54 +0300 Subject: [PATCH 045/110] Fix functions with Volatility::Volatile and parameters (#13001) Co-authored-by: Agaev Huseyn --- .../user_defined_scalar_functions.rs | 181 ++++++++++++++++++ datafusion/expr/src/udf.rs | 31 ++- .../physical-expr/src/scalar_function.rs | 5 +- 3 files changed, 212 insertions(+), 5 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 013aec48d510..0887645b8cbf 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -16,9 +16,11 @@ // under the License. use std::any::Any; +use std::collections::HashMap; use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; +use arrow::array::as_string_array; use arrow::compute::kernels::numeric::add; use arrow_array::builder::BooleanBuilder; use arrow_array::cast::AsArray; @@ -483,6 +485,185 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { Ok(()) } +/// Volatile UDF that should append a different value to each row +#[derive(Debug)] +struct AddIndexToStringVolatileScalarUDF { + name: String, + signature: Signature, + return_type: DataType, +} + +impl AddIndexToStringVolatileScalarUDF { + fn new() -> Self { + Self { + name: "add_index_to_string".to_string(), + signature: Signature::exact(vec![DataType::Utf8], Volatility::Volatile), + return_type: DataType::Utf8, + } + } +} + +impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + not_impl_err!("index_with_offset function does not accept arguments") + } + + fn invoke_batch( + &self, + args: &[ColumnarValue], + number_rows: usize, + ) -> Result { + let answer = match &args[0] { + // When called with static arguments, the result is returned as an array. + ColumnarValue::Scalar(ScalarValue::Utf8(Some(value))) => { + let mut answer = vec![]; + for index in 1..=number_rows { + // When calling a function with immutable arguments, the result is returned with ")". + // Example: SELECT add_index_to_string('const_value') FROM table; + answer.push(index.to_string() + ") " + value); + } + answer + } + // The result is returned as an array when called with dynamic arguments. + ColumnarValue::Array(array) => { + let string_array = as_string_array(array); + let mut counter = HashMap::<&str, u64>::new(); + string_array + .iter() + .map(|value| { + let value = value.expect("Unexpected null"); + let index = counter.get(value).unwrap_or(&0) + 1; + counter.insert(value, index); + + // When calling a function with mutable arguments, the result is returned with ".". + // Example: SELECT add_index_to_string(table.value) FROM table; + index.to_string() + ". " + value + }) + .collect() + } + _ => unimplemented!(), + }; + Ok(ColumnarValue::Array(Arc::new(StringArray::from(answer)))) + } +} + +#[tokio::test] +async fn volatile_scalar_udf_with_params() -> Result<()> { + { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(StringArray::from(vec![ + "test_1", "test_1", "test_1", "test_2", "test_2", "test_1", "test_2", + ]))], + )?; + let ctx = SessionContext::new(); + + ctx.register_batch("t", batch)?; + + let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new(); + + ctx.register_udf(ScalarUDF::from(get_new_str_udf)); + + let result = + plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t") // with dynamic function parameters + .await?; + let expected = [ + "+-----------+", + "| str |", + "+-----------+", + "| 1. test_1 |", + "| 2. test_1 |", + "| 3. test_1 |", + "| 1. test_2 |", + "| 2. test_2 |", + "| 4. test_1 |", + "| 3. test_2 |", + "+-----------+", + ]; + assert_batches_eq!(expected, &result); + + let result = + plan_and_collect(&ctx, "select add_index_to_string('test') AS str from t") // with fixed function parameters + .await?; + let expected = [ + "+---------+", + "| str |", + "+---------+", + "| 1) test |", + "| 2) test |", + "| 3) test |", + "| 4) test |", + "| 5) test |", + "| 6) test |", + "| 7) test |", + "+---------+", + ]; + assert_batches_eq!(expected, &result); + + let result = + plan_and_collect(&ctx, "select add_index_to_string('test_value') as str") // with fixed function parameters + .await?; + let expected = [ + "+---------------+", + "| str |", + "+---------------+", + "| 1) test_value |", + "+---------------+", + ]; + assert_batches_eq!(expected, &result); + } + { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(StringArray::from(vec![ + "test_1", "test_1", "test_1", + ]))], + )?; + let ctx = SessionContext::new(); + + ctx.register_batch("t", batch)?; + + let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new(); + + ctx.register_udf(ScalarUDF::from(get_new_str_udf)); + + let result = + plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t") + .await?; + let expected = [ + "+-----------+", // + "| str |", // + "+-----------+", // + "| 1. test_1 |", // + "| 2. test_1 |", // + "| 3. test_1 |", // + "+-----------+", + ]; + assert_batches_eq!(expected, &result); + } + Ok(()) +} + #[derive(Debug)] struct CastToI64UDF { signature: Signature, diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 3759fb18f56d..83563603f2f3 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -201,6 +201,17 @@ impl ScalarUDF { self.inner.is_nullable(args, schema) } + /// Invoke the function with `args` and number of rows, returning the appropriate result. + /// + /// See [`ScalarUDFImpl::invoke_batch`] for more details. + pub fn invoke_batch( + &self, + args: &[ColumnarValue], + number_rows: usize, + ) -> Result { + self.inner.invoke_batch(args, number_rows) + } + /// Invoke the function without `args` but number of rows, returning the appropriate result. /// /// See [`ScalarUDFImpl::invoke_no_args`] for more details. @@ -467,7 +478,25 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// to arrays, which will likely be simpler code, but be slower. /// /// [invoke_no_args]: ScalarUDFImpl::invoke_no_args - fn invoke(&self, _args: &[ColumnarValue]) -> Result; + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + not_impl_err!( + "Function {} does not implement invoke but called", + self.name() + ) + } + + /// Invoke the function with `args` and the number of rows, + /// returning the appropriate result. + fn invoke_batch( + &self, + args: &[ColumnarValue], + number_rows: usize, + ) -> Result { + match args.is_empty() { + true => self.invoke_no_args(number_rows), + false => self.invoke(args), + } + } /// Invoke the function without `args`, instead the number of rows are provided, /// returning the appropriate result. diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 4d3db96ceb3c..ab53106f6059 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -141,10 +141,7 @@ impl PhysicalExpr for ScalarFunctionExpr { .collect::>>()?; // evaluate the function - let output = match self.args.is_empty() { - true => self.fun.invoke_no_args(batch.num_rows()), - false => self.fun.invoke(&inputs), - }?; + let output = self.fun.invoke_batch(&inputs, batch.num_rows())?; if let ColumnarValue::Array(array) = &output { if array.len() != batch.num_rows() { From 227908ff16a6eed6fa2b9c0f89ecd564be67a7a9 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 22 Oct 2024 11:07:24 -0400 Subject: [PATCH 046/110] Migrate documentation for `regr*` aggregate functions to code (#12871) * Migrate documentation for regr* functions to code * Fix double expression * Fix logical conflict --- datafusion/functions-aggregate/src/regr.rs | 187 ++++++++++++++---- .../user-guide/sql/aggregate_functions.md | 172 +--------------- .../user-guide/sql/aggregate_functions_new.md | 126 ++++++++++++ 3 files changed, 280 insertions(+), 205 deletions(-) diff --git a/datafusion/functions-aggregate/src/regr.rs b/datafusion/functions-aggregate/src/regr.rs index 390a769aca7f..a1fc5b094276 100644 --- a/datafusion/functions-aggregate/src/regr.rs +++ b/datafusion/functions-aggregate/src/regr.rs @@ -17,9 +17,6 @@ //! Defines physical expressions that can evaluated at runtime during query execution -use std::any::Any; -use std::fmt::Debug; - use arrow::array::Float64Array; use arrow::{ array::{ArrayRef, UInt64Array}, @@ -29,10 +26,17 @@ use arrow::{ }; use datafusion_common::{downcast_value, plan_err, unwrap_or_internal_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; -use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, +}; +use std::any::Any; +use std::collections::HashMap; +use std::fmt::Debug; +use std::sync::OnceLock; macro_rules! make_regr_udaf_expr_and_func { ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $REGR_TYPE:expr) => { @@ -76,23 +80,7 @@ impl Regr { } } -/* -#[derive(Debug)] -pub struct Regr { - name: String, - regr_type: RegrType, - expr_y: Arc, - expr_x: Arc, -} - -impl Regr { - pub fn get_regr_type(&self) -> RegrType { - self.regr_type.clone() - } -} -*/ - -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Hash, Eq)] #[allow(clippy::upper_case_acronyms)] pub enum RegrType { /// Variant for `regr_slope` aggregate expression @@ -135,6 +123,148 @@ pub enum RegrType { SXY, } +impl RegrType { + /// return the documentation for the `RegrType` + fn documentation(&self) -> Option<&Documentation> { + get_regr_docs().get(self) + } +} + +static DOCUMENTATION: OnceLock> = OnceLock::new(); +fn get_regr_docs() -> &'static HashMap { + DOCUMENTATION.get_or_init(|| { + let mut hash_map = HashMap::new(); + hash_map.insert( + RegrType::Slope, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Returns the slope of the linear regression line for non-null pairs in aggregate columns. \ + Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k*X + b) using minimal RSS fitting.", + ) + .with_syntax_example("regr_slope(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + + hash_map.insert( + RegrType::Intercept, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Computes the y-intercept of the linear regression line. For the equation (y = kx + b), \ + this function returns b.", + ) + .with_syntax_example("regr_intercept(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + + hash_map.insert( + RegrType::Count, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Counts the number of non-null paired data points.", + ) + .with_syntax_example("regr_count(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + + hash_map.insert( + RegrType::R2, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Computes the square of the correlation coefficient between the independent and dependent variables.", + ) + .with_syntax_example("regr_r2(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + + hash_map.insert( + RegrType::AvgX, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Computes the average of the independent variable (input) expression_x for the non-null paired data points.", + ) + .with_syntax_example("regr_avgx(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + + hash_map.insert( + RegrType::AvgY, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Computes the average of the dependent variable (output) expression_y for the non-null paired data points.", + ) + .with_syntax_example("regr_avgy(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + + hash_map.insert( + RegrType::SXX, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Computes the sum of squares of the independent variable.", + ) + .with_syntax_example("regr_sxx(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + + hash_map.insert( + RegrType::SYY, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Computes the sum of squares of the dependent variable.", + ) + .with_syntax_example("regr_syy(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + + hash_map.insert( + RegrType::SXY, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Computes the sum of products of paired data points.", + ) + .with_syntax_example("regr_sxy(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + hash_map + }) +} + impl AggregateUDFImpl for Regr { fn as_any(&self) -> &dyn Any { self @@ -198,22 +328,11 @@ impl AggregateUDFImpl for Regr { ), ]) } -} -/* -impl PartialEq for Regr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.expr_y.eq(&x.expr_y) - && self.expr_x.eq(&x.expr_x) - }) - .unwrap_or(false) + fn documentation(&self) -> Option<&Documentation> { + self.regr_type.documentation() } } -*/ /// `RegrAccumulator` is used to compute linear regression aggregate functions /// by maintaining statistics needed to compute them in an online fashion. diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index 4f774fe6d0f0..77f527c92cda 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -19,174 +19,4 @@ # Aggregate Functions -Aggregate functions operate on a set of values to compute a single result. - -Note: this documentation is in the process of being migrated to be [automatically created from the codebase]. -Please see the [Aggregate Functions (new)](aggregate_functions_new.md) page for -the rest of the documentation. - -[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740 - -## Statistical - -- [covar](#covar) -- [regr_avgx](#regr_avgx) -- [regr_avgy](#regr_avgy) -- [regr_count](#regr_count) -- [regr_intercept](#regr_intercept) -- [regr_r2](#regr_r2) -- [regr_slope](#regr_slope) -- [regr_sxx](#regr_sxx) -- [regr_syy](#regr_syy) -- [regr_sxy](#regr_sxy) - -### `covar` - -Returns the covariance of a set of number pairs. - -``` -covar(expression1, expression2) -``` - -#### Arguments - -- **expression1**: First expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression2**: Second expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `regr_slope` - -Returns the slope of the linear regression line for non-null pairs in aggregate columns. -Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k\*X + b) using minimal RSS fitting. - -``` -regr_slope(expression1, expression2) -``` - -#### Arguments - -- **expression_y**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `regr_avgx` - -Computes the average of the independent variable (input) `expression_x` for the non-null paired data points. - -``` -regr_avgx(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Independent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `regr_avgy` - -Computes the average of the dependent variable (output) `expression_y` for the non-null paired data points. - -``` -regr_avgy(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Independent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `regr_count` - -Counts the number of non-null paired data points. - -``` -regr_count(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Independent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `regr_intercept` - -Computes the y-intercept of the linear regression line. For the equation \(y = kx + b\), this function returns `b`. - -``` -regr_intercept(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Independent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `regr_r2` - -Computes the square of the correlation coefficient between the independent and dependent variables. - -``` -regr_r2(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Independent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `regr_sxx` - -Computes the sum of squares of the independent variable. - -``` -regr_sxx(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Independent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `regr_syy` - -Computes the sum of squares of the dependent variable. - -``` -regr_syy(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Independent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `regr_sxy` - -Computes the sum of products of paired data points. - -``` -regr_sxy(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Independent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. +Note: this documentation has been migrated to [Aggregate Functions (new)](aggregate_functions_new.md) diff --git a/docs/source/user-guide/sql/aggregate_functions_new.md b/docs/source/user-guide/sql/aggregate_functions_new.md index 24ef313f3d49..ad6d15b94ee5 100644 --- a/docs/source/user-guide/sql/aggregate_functions_new.md +++ b/docs/source/user-guide/sql/aggregate_functions_new.md @@ -468,6 +468,15 @@ _Alias of [var](#var)._ - [covar_pop](#covar_pop) - [covar_samp](#covar_samp) - [nth_value](#nth_value) +- [regr_avgx](#regr_avgx) +- [regr_avgy](#regr_avgy) +- [regr_count](#regr_count) +- [regr_intercept](#regr_intercept) +- [regr_r2](#regr_r2) +- [regr_slope](#regr_slope) +- [regr_sxx](#regr_sxx) +- [regr_sxy](#regr_sxy) +- [regr_syy](#regr_syy) - [stddev](#stddev) - [stddev_pop](#stddev_pop) - [stddev_samp](#stddev_samp) @@ -581,6 +590,123 @@ nth_value(expression, n ORDER BY expression) +---------+--------+-------------------------+ ``` +### `regr_avgx` + +Computes the average of the independent variable (input) expression_x for the non-null paired data points. + +``` +regr_avgx(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `regr_avgy` + +Computes the average of the dependent variable (output) expression_y for the non-null paired data points. + +``` +regr_avgy(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `regr_count` + +Counts the number of non-null paired data points. + +``` +regr_count(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `regr_intercept` + +Computes the y-intercept of the linear regression line. For the equation (y = kx + b), this function returns b. + +``` +regr_intercept(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `regr_r2` + +Computes the square of the correlation coefficient between the independent and dependent variables. + +``` +regr_r2(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `regr_slope` + +Returns the slope of the linear regression line for non-null pairs in aggregate columns. Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k\*X + b) using minimal RSS fitting. + +``` +regr_slope(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `regr_sxx` + +Computes the sum of squares of the independent variable. + +``` +regr_sxx(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `regr_sxy` + +Computes the sum of products of paired data points. + +``` +regr_sxy(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `regr_syy` + +Computes the sum of squares of the dependent variable. + +``` +regr_syy(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. + ### `stddev` Returns the standard deviation of a set of numbers. From 91d2886d7494c1966c5a05876b26f09f47a5e1c2 Mon Sep 17 00:00:00 2001 From: Tornike Gurgenidze Date: Tue, 22 Oct 2024 19:07:57 +0400 Subject: [PATCH 047/110] fix(substrait): disallow union with a single input (#13023) * fix(substrait): disallow union with a single input * flip if condition --- .../substrait/src/logical_plan/consumer.rs | 98 +++++++------------ 1 file changed, 34 insertions(+), 64 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 5f1824bc4b30..8a8d195507a2 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -937,72 +937,42 @@ pub async fn from_substrait_rel( } } Some(RelType::Set(set)) => match set_rel::SetOp::try_from(set.op) { - Ok(set_op) => match set_op { - set_rel::SetOp::UnionAll => { - if !set.inputs.is_empty() { - union_rels(&set.inputs, ctx, extensions, true).await - } else { - not_impl_err!("Union relation requires at least one input") - } - } - set_rel::SetOp::UnionDistinct => { - if !set.inputs.is_empty() { - union_rels(&set.inputs, ctx, extensions, false).await - } else { - not_impl_err!("Union relation requires at least one input") - } - } - set_rel::SetOp::IntersectionPrimary => { - if set.inputs.len() >= 2 { - LogicalPlanBuilder::intersect( - from_substrait_rel(ctx, &set.inputs[0], extensions).await?, - union_rels(&set.inputs[1..], ctx, extensions, true).await?, - false, - ) - } else { - not_impl_err!( - "Primary Intersect relation requires at least two inputs" - ) - } - } - set_rel::SetOp::IntersectionMultiset => { - if set.inputs.len() >= 2 { - intersect_rels(&set.inputs, ctx, extensions, false).await - } else { - not_impl_err!( - "Multiset Intersect relation requires at least two inputs" - ) - } - } - set_rel::SetOp::IntersectionMultisetAll => { - if set.inputs.len() >= 2 { - intersect_rels(&set.inputs, ctx, extensions, true).await - } else { - not_impl_err!( - "MultisetAll Intersect relation requires at least two inputs" - ) - } - } - set_rel::SetOp::MinusPrimary => { - if set.inputs.len() >= 2 { - except_rels(&set.inputs, ctx, extensions, false).await - } else { - not_impl_err!( - "Primary Minus relation requires at least two inputs" - ) - } - } - set_rel::SetOp::MinusPrimaryAll => { - if set.inputs.len() >= 2 { - except_rels(&set.inputs, ctx, extensions, true).await - } else { - not_impl_err!( - "PrimaryAll Minus relation requires at least two inputs" - ) + Ok(set_op) => { + if set.inputs.len() < 2 { + substrait_err!("Set operation requires at least two inputs") + } else { + match set_op { + set_rel::SetOp::UnionAll => { + union_rels(&set.inputs, ctx, extensions, true).await + } + set_rel::SetOp::UnionDistinct => { + union_rels(&set.inputs, ctx, extensions, false).await + } + set_rel::SetOp::IntersectionPrimary => { + LogicalPlanBuilder::intersect( + from_substrait_rel(ctx, &set.inputs[0], extensions) + .await?, + union_rels(&set.inputs[1..], ctx, extensions, true) + .await?, + false, + ) + } + set_rel::SetOp::IntersectionMultiset => { + intersect_rels(&set.inputs, ctx, extensions, false).await + } + set_rel::SetOp::IntersectionMultisetAll => { + intersect_rels(&set.inputs, ctx, extensions, true).await + } + set_rel::SetOp::MinusPrimary => { + except_rels(&set.inputs, ctx, extensions, false).await + } + set_rel::SetOp::MinusPrimaryAll => { + except_rels(&set.inputs, ctx, extensions, true).await + } + _ => not_impl_err!("Unsupported set operator: {set_op:?}"), } } - _ => not_impl_err!("Unsupported set operator: {set_op:?}"), - }, + } Err(e) => not_impl_err!("Invalid set operation type {}: {e}", set.op), }, Some(RelType::ExtensionLeaf(extension)) => { From 818ce3f01efe1213a9a1eda5dff1542bb9d457f7 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Tue, 22 Oct 2024 17:27:55 +0200 Subject: [PATCH 048/110] refactor: Incorporate RewriteDisjunctivePredicate rule into SimplifyExpressions (#13032) * Elliminate common factors in disjunctions This adds a rewrite rule that elliminates common factors in OR. This is already implmented in RewriteDisjunctivePredicate but this implementation is simpler and will apply in more cases. * Remove RewriteDisjunctivePredicate rule * Fix cse test * Update datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs Co-authored-by: Andrew Lamb --------- Co-authored-by: Andrew Lamb --- datafusion/expr/src/utils.rs | 48 ++ datafusion/optimizer/src/lib.rs | 1 - datafusion/optimizer/src/optimizer.rs | 2 - datafusion/optimizer/src/push_down_filter.rs | 4 +- .../src/rewrite_disjunctive_predicate.rs | 430 ------------------ .../simplify_expressions/expr_simplifier.rs | 73 ++- datafusion/sqllogictest/test_files/cse.slt | 8 +- .../sqllogictest/test_files/explain.slt | 2 - 8 files changed, 126 insertions(+), 442 deletions(-) delete mode 100644 datafusion/optimizer/src/rewrite_disjunctive_predicate.rs diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 9ee13f1e06d3..86562daf6909 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1101,6 +1101,54 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<& } } +/// Iteratate parts in a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` +/// +/// See [`split_conjunction_owned`] for more details and an example. +pub fn iter_conjunction(expr: &Expr) -> impl Iterator { + let mut stack = vec![expr]; + std::iter::from_fn(move || { + while let Some(expr) = stack.pop() { + match expr { + Expr::BinaryExpr(BinaryExpr { + right, + op: Operator::And, + left, + }) => { + stack.push(right); + stack.push(left); + } + Expr::Alias(Alias { expr, .. }) => stack.push(expr), + other => return Some(other), + } + } + None + }) +} + +/// Iteratate parts in a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` +/// +/// See [`split_conjunction_owned`] for more details and an example. +pub fn iter_conjunction_owned(expr: Expr) -> impl Iterator { + let mut stack = vec![expr]; + std::iter::from_fn(move || { + while let Some(expr) = stack.pop() { + match expr { + Expr::BinaryExpr(BinaryExpr { + right, + op: Operator::And, + left, + }) => { + stack.push(*right); + stack.push(*left); + } + Expr::Alias(Alias { expr, .. }) => stack.push(*expr), + other => return Some(other), + } + } + None + }) +} + /// Splits an owned conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` /// /// This is often used to "split" filter expressions such as `col1 = 5 diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 3b1df3510d2a..f31083831125 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -51,7 +51,6 @@ pub mod propagate_empty_relation; pub mod push_down_filter; pub mod push_down_limit; pub mod replace_distinct_aggregate; -pub mod rewrite_disjunctive_predicate; pub mod scalar_subquery_to_join; pub mod simplify_expressions; pub mod single_distinct_to_groupby; diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 08dcefa22f08..373c87718789 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -51,7 +51,6 @@ use crate::propagate_empty_relation::PropagateEmptyRelation; use crate::push_down_filter::PushDownFilter; use crate::push_down_limit::PushDownLimit; use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; -use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; use crate::simplify_expressions::SimplifyExpressions; use crate::single_distinct_to_groupby::SingleDistinctToGroupBy; @@ -255,7 +254,6 @@ impl Optimizer { // run it again after running the optimizations that potentially converted // subqueries to joins Arc::new(SimplifyExpressions::new()), - Arc::new(RewriteDisjunctivePredicate::new()), Arc::new(EliminateDuplicatedExpr::new()), Arc::new(EliminateFilter::new()), Arc::new(EliminateCrossJoin::new()), diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 2e3bca5b0bbd..ac81f3efaa11 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1213,7 +1213,7 @@ mod tests { }; use crate::optimizer::Optimizer; - use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; + use crate::simplify_expressions::SimplifyExpressions; use crate::test::*; use crate::OptimizerContext; use datafusion_expr::test::function_stub::sum; @@ -1235,7 +1235,7 @@ mod tests { expected: &str, ) -> Result<()> { let optimizer = Optimizer::with_rules(vec![ - Arc::new(RewriteDisjunctivePredicate::new()), + Arc::new(SimplifyExpressions::new()), Arc::new(PushDownFilter::new()), ]); let optimized_plan = diff --git a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs deleted file mode 100644 index a6b633fdb8fe..000000000000 --- a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs +++ /dev/null @@ -1,430 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! [`RewriteDisjunctivePredicate`] rewrites predicates to reduce redundancy - -use crate::optimizer::ApplyOrder; -use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::tree_node::Transformed; -use datafusion_common::Result; -use datafusion_expr::expr::BinaryExpr; -use datafusion_expr::logical_plan::Filter; -use datafusion_expr::{Expr, LogicalPlan, Operator}; - -/// Optimizer pass that rewrites predicates of the form -/// -/// ```text -/// (A = B AND ) OR (A = B AND ) OR ... (A = B AND ) -/// ``` -/// -/// Into -/// ```text -/// (A = B) AND ( OR OR ... ) -/// ``` -/// -/// Predicates connected by `OR` typically not able to be broken down -/// and distributed as well as those connected by `AND`. -/// -/// The idea is to rewrite predicates into `good_predicate1 AND -/// good_predicate2 AND ...` where `good_predicate` means the -/// predicate has special support in the execution engine. -/// -/// Equality join predicates (e.g. `col1 = col2`), or single column -/// expressions (e.g. `col = 5`) are examples of predicates with -/// special support. -/// -/// # TPCH Q19 -/// -/// This optimization is admittedly somewhat of a niche usecase. It's -/// main use is that it appears in TPCH Q19 and is required to avoid a -/// CROSS JOIN. -/// -/// Specifically, Q19 has a WHERE clause that looks like -/// -/// ```sql -/// where -/// ( -/// p_partkey = l_partkey -/// and p_brand = ‘[BRAND1]’ -/// and p_container in ( ‘SM CASE’, ‘SM BOX’, ‘SM PACK’, ‘SM PKG’) -/// and l_quantity >= [QUANTITY1] and l_quantity <= [QUANTITY1] + 10 -/// and p_size between 1 and 5 -/// and l_shipmode in (‘AIR’, ‘AIR REG’) -/// and l_shipinstruct = ‘DELIVER IN PERSON’ -/// ) -/// or -/// ( -/// p_partkey = l_partkey -/// and p_brand = ‘[BRAND2]’ -/// and p_container in (‘MED BAG’, ‘MED BOX’, ‘MED PKG’, ‘MED PACK’) -/// and l_quantity >= [QUANTITY2] and l_quantity <= [QUANTITY2] + 10 -/// and p_size between 1 and 10 -/// and l_shipmode in (‘AIR’, ‘AIR REG’) -/// and l_shipinstruct = ‘DELIVER IN PERSON’ -/// ) -/// or -/// ( -/// p_partkey = l_partkey -/// and p_brand = ‘[BRAND3]’ -/// and p_container in ( ‘LG CASE’, ‘LG BOX’, ‘LG PACK’, ‘LG PKG’) -/// and l_quantity >= [QUANTITY3] and l_quantity <= [QUANTITY3] + 10 -/// and p_size between 1 and 15 -/// and l_shipmode in (‘AIR’, ‘AIR REG’) -/// and l_shipinstruct = ‘DELIVER IN PERSON’ -/// ) -/// ) -/// ``` -/// -/// Naively planning this query will result in a CROSS join with that -/// single large OR filter. However, rewriting it using the rewrite in -/// this pass results in a proper join predicate, `p_partkey = l_partkey`: -/// -/// ```sql -/// where -/// p_partkey = l_partkey -/// and l_shipmode in (‘AIR’, ‘AIR REG’) -/// and l_shipinstruct = ‘DELIVER IN PERSON’ -/// and ( -/// ( -/// and p_brand = ‘[BRAND1]’ -/// and p_container in ( ‘SM CASE’, ‘SM BOX’, ‘SM PACK’, ‘SM PKG’) -/// and l_quantity >= [QUANTITY1] and l_quantity <= [QUANTITY1] + 10 -/// and p_size between 1 and 5 -/// ) -/// or -/// ( -/// and p_brand = ‘[BRAND2]’ -/// and p_container in (‘MED BAG’, ‘MED BOX’, ‘MED PKG’, ‘MED PACK’) -/// and l_quantity >= [QUANTITY2] and l_quantity <= [QUANTITY2] + 10 -/// and p_size between 1 and 10 -/// ) -/// or -/// ( -/// and p_brand = ‘[BRAND3]’ -/// and p_container in ( ‘LG CASE’, ‘LG BOX’, ‘LG PACK’, ‘LG PKG’) -/// and l_quantity >= [QUANTITY3] and l_quantity <= [QUANTITY3] + 10 -/// and p_size between 1 and 15 -/// ) -/// ) -/// ``` -/// -#[derive(Default, Debug)] -pub struct RewriteDisjunctivePredicate; - -impl RewriteDisjunctivePredicate { - pub fn new() -> Self { - Self - } -} - -impl OptimizerRule for RewriteDisjunctivePredicate { - fn name(&self) -> &str { - "rewrite_disjunctive_predicate" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } - - fn supports_rewrite(&self) -> bool { - true - } - - fn rewrite( - &self, - plan: LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - match plan { - LogicalPlan::Filter(filter) => { - let predicate = predicate(filter.predicate)?; - let rewritten_predicate = rewrite_predicate(predicate); - let rewritten_expr = normalize_predicate(rewritten_predicate); - Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new( - rewritten_expr, - filter.input, - )?))) - } - _ => Ok(Transformed::no(plan)), - } - } -} - -#[derive(Clone, PartialEq, Debug)] -enum Predicate { - And { args: Vec }, - Or { args: Vec }, - Other { expr: Box }, -} - -fn predicate(expr: Expr) -> Result { - match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op { - Operator::And => { - let args = vec![predicate(*left)?, predicate(*right)?]; - Ok(Predicate::And { args }) - } - Operator::Or => { - let args = vec![predicate(*left)?, predicate(*right)?]; - Ok(Predicate::Or { args }) - } - _ => Ok(Predicate::Other { - expr: Box::new(Expr::BinaryExpr(BinaryExpr::new(left, op, right))), - }), - }, - _ => Ok(Predicate::Other { - expr: Box::new(expr), - }), - } -} - -fn normalize_predicate(predicate: Predicate) -> Expr { - match predicate { - Predicate::And { args } => { - assert!(args.len() >= 2); - args.into_iter() - .map(normalize_predicate) - .reduce(Expr::and) - .expect("had more than one arg") - } - Predicate::Or { args } => { - assert!(args.len() >= 2); - args.into_iter() - .map(normalize_predicate) - .reduce(Expr::or) - .expect("had more than one arg") - } - Predicate::Other { expr } => *expr, - } -} - -fn rewrite_predicate(predicate: Predicate) -> Predicate { - match predicate { - Predicate::And { args } => { - let mut rewritten_args = Vec::with_capacity(args.len()); - for arg in args.into_iter() { - rewritten_args.push(rewrite_predicate(arg)); - } - rewritten_args = flatten_and_predicates(rewritten_args); - Predicate::And { - args: rewritten_args, - } - } - Predicate::Or { args } => { - let mut rewritten_args = vec![]; - for arg in args.into_iter() { - rewritten_args.push(rewrite_predicate(arg)); - } - rewritten_args = flatten_or_predicates(rewritten_args); - delete_duplicate_predicates(rewritten_args) - } - Predicate::Other { expr } => Predicate::Other { expr }, - } -} - -fn flatten_and_predicates( - and_predicates: impl IntoIterator, -) -> Vec { - let mut flattened_predicates = vec![]; - for predicate in and_predicates { - match predicate { - Predicate::And { args } => { - flattened_predicates.append(&mut flatten_and_predicates(args)); - } - _ => { - flattened_predicates.push(predicate); - } - } - } - flattened_predicates -} - -fn flatten_or_predicates( - or_predicates: impl IntoIterator, -) -> Vec { - let mut flattened_predicates = vec![]; - for predicate in or_predicates { - match predicate { - Predicate::Or { args } => { - flattened_predicates.append(&mut flatten_or_predicates(args)); - } - _ => { - flattened_predicates.push(predicate); - } - } - } - flattened_predicates -} - -fn delete_duplicate_predicates(or_predicates: Vec) -> Predicate { - let mut shortest_exprs: Vec = vec![]; - let mut shortest_exprs_len = 0; - // choose the shortest AND predicate - for or_predicate in or_predicates.iter() { - match or_predicate { - Predicate::And { args } => { - let args_num = args.len(); - if shortest_exprs.is_empty() || args_num < shortest_exprs_len { - shortest_exprs.clone_from(args); - shortest_exprs_len = args_num; - } - } - _ => { - // if there is no AND predicate, it must be the shortest expression. - shortest_exprs = vec![or_predicate.clone()]; - break; - } - } - } - - // dedup shortest_exprs - shortest_exprs.dedup(); - - // Check each element in shortest_exprs to see if it's in all the OR arguments. - let mut exist_exprs: Vec = vec![]; - for expr in shortest_exprs.iter() { - let found = or_predicates.iter().all(|or_predicate| match or_predicate { - Predicate::And { args } => args.contains(expr), - _ => or_predicate == expr, - }); - if found { - exist_exprs.push((*expr).clone()); - } - } - if exist_exprs.is_empty() { - return Predicate::Or { - args: or_predicates, - }; - } - - // Rebuild the OR predicate. - // (A AND B) OR A will be optimized to A. - let mut new_or_predicates = vec![]; - for or_predicate in or_predicates.into_iter() { - match or_predicate { - Predicate::And { mut args } => { - args.retain(|expr| !exist_exprs.contains(expr)); - if !args.is_empty() { - if args.len() == 1 { - new_or_predicates.push(args.remove(0)); - } else { - new_or_predicates.push(Predicate::And { args }); - } - } else { - new_or_predicates.clear(); - break; - } - } - _ => { - if exist_exprs.contains(&or_predicate) { - new_or_predicates.clear(); - break; - } - } - } - } - if !new_or_predicates.is_empty() { - if new_or_predicates.len() == 1 { - exist_exprs.push(new_or_predicates.remove(0)); - } else { - exist_exprs.push(Predicate::Or { - args: flatten_or_predicates(new_or_predicates), - }); - } - } - - if exist_exprs.len() == 1 { - exist_exprs.remove(0) - } else { - Predicate::And { - args: flatten_and_predicates(exist_exprs), - } - } -} - -#[cfg(test)] -mod tests { - use crate::rewrite_disjunctive_predicate::{ - normalize_predicate, predicate, rewrite_predicate, Predicate, - }; - - use datafusion_common::{Result, ScalarValue}; - use datafusion_expr::{and, col, lit, or}; - - #[test] - fn test_rewrite_predicate() -> Result<()> { - let equi_expr = col("t1.a").eq(col("t2.b")); - let gt_expr = col("t1.c").gt(lit(ScalarValue::Int8(Some(1)))); - let lt_expr = col("t1.d").lt(lit(ScalarValue::Int8(Some(2)))); - let expr = or( - and(equi_expr.clone(), gt_expr.clone()), - and(equi_expr.clone(), lt_expr.clone()), - ); - let predicate = predicate(expr)?; - assert_eq!( - predicate, - Predicate::Or { - args: vec![ - Predicate::And { - args: vec![ - Predicate::Other { - expr: Box::new(equi_expr.clone()) - }, - Predicate::Other { - expr: Box::new(gt_expr.clone()) - }, - ] - }, - Predicate::And { - args: vec![ - Predicate::Other { - expr: Box::new(equi_expr.clone()) - }, - Predicate::Other { - expr: Box::new(lt_expr.clone()) - }, - ] - }, - ] - } - ); - let rewritten_predicate = rewrite_predicate(predicate); - assert_eq!( - rewritten_predicate, - Predicate::And { - args: vec![ - Predicate::Other { - expr: Box::new(equi_expr.clone()) - }, - Predicate::Or { - args: vec![ - Predicate::Other { - expr: Box::new(gt_expr.clone()) - }, - Predicate::Other { - expr: Box::new(lt_expr.clone()) - }, - ] - }, - ] - } - ); - let rewritten_expr = normalize_predicate(rewritten_predicate); - assert_eq!(rewritten_expr, and(equi_expr, or(gt_expr, lt_expr))); - Ok(()) - } -} diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 2bac71a6ae1f..f9dfadc70826 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -32,14 +32,18 @@ use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue}; -use datafusion_expr::expr::{InList, InSubquery, WindowFunction}; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, WindowFunctionDefinition, }; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; +use datafusion_expr::{ + expr::{InList, InSubquery, WindowFunction}, + utils::{iter_conjunction, iter_conjunction_owned}, +}; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; +use indexmap::IndexSet; use crate::analyzer::type_coercion::TypeCoercionRewriter; use crate::simplify_expressions::guarantees::GuaranteeRewriter; @@ -850,6 +854,27 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Or, right, }) if is_op_with(And, &left, &right) => Transformed::yes(*right), + // Eliminate common factors in conjunctions e.g + // (A AND B) OR (A AND C) -> A AND (B OR C) + Expr::BinaryExpr(BinaryExpr { + left, + op: Or, + right, + }) if has_common_conjunction(&left, &right) => { + let lhs: IndexSet = iter_conjunction_owned(*left).collect(); + let (common, rhs): (Vec<_>, Vec<_>) = + iter_conjunction_owned(*right).partition(|e| lhs.contains(e)); + + let new_rhs = rhs.into_iter().reduce(and); + let new_lhs = lhs.into_iter().filter(|e| !common.contains(e)).reduce(and); + let common_conjunction = common.into_iter().reduce(and).unwrap(); + + let new_expr = match (new_lhs, new_rhs) { + (Some(lhs), Some(rhs)) => and(common_conjunction, or(lhs, rhs)), + (_, _) => common_conjunction, + }; + Transformed::yes(new_expr) + } // // Rules for AND @@ -1656,6 +1681,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { } } +fn has_common_conjunction(lhs: &Expr, rhs: &Expr) -> bool { + let lhs: HashSet<&Expr> = iter_conjunction(lhs).collect(); + iter_conjunction(rhs).any(|e| lhs.contains(&e)) +} + // TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121 fn are_inlist_and_eq_and_match_neg( left: &Expr, @@ -3743,6 +3773,47 @@ mod tests { assert_eq!(expr, expected); assert_eq!(num_iter, 2); } + + fn boolean_test_schema() -> DFSchemaRef { + Schema::new(vec![ + Field::new("A", DataType::Boolean, false), + Field::new("B", DataType::Boolean, false), + Field::new("C", DataType::Boolean, false), + Field::new("D", DataType::Boolean, false), + ]) + .to_dfschema_ref() + .unwrap() + } + + #[test] + fn simplify_common_factor_conjuction_in_disjunction() { + let props = ExecutionProps::new(); + let schema = boolean_test_schema(); + let simplifier = + ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema)); + + let a = || col("A"); + let b = || col("B"); + let c = || col("C"); + let d = || col("D"); + + // (A AND B) OR (A AND C) -> A AND (B OR C) + let expr = a().and(b()).or(a().and(c())); + let expected = a().and(b().or(c())); + + assert_eq!(expected, simplifier.simplify(expr).unwrap()); + + // (A AND B) OR (A AND C) OR (A AND D) -> A AND (B OR C OR D) + let expr = a().and(b()).or(a().and(c())).or(a().and(d())); + let expected = a().and(b().or(c()).or(d())); + assert_eq!(expected, simplifier.simplify(expr).unwrap()); + + // A OR (B AND C AND A) -> A + let expr = a().or(b().and(c().and(a()))); + let expected = a(); + assert_eq!(expected, simplifier.simplify(expr).unwrap()); + } + #[test] fn test_simplify_udaf() { let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_with_simplify()); diff --git a/datafusion/sqllogictest/test_files/cse.slt b/datafusion/sqllogictest/test_files/cse.slt index 9f0f654179e9..c95e9a1309f8 100644 --- a/datafusion/sqllogictest/test_files/cse.slt +++ b/datafusion/sqllogictest/test_files/cse.slt @@ -199,18 +199,18 @@ physical_plan # Surely only once but also conditionally evaluated subexpressions query TT EXPLAIN SELECT - (a = 1 OR random() = 0) AND (a = 1 OR random() = 1) AS c1, - (a = 2 AND random() = 0) OR (a = 2 AND random() = 1) AS c2, + (a = 1 OR random() = 0) AND (a = 2 OR random() = 1) AS c1, + (a = 2 AND random() = 0) OR (a = 1 AND random() = 1) AS c2, CASE WHEN a + 3 = 0 THEN a + 3 + random() ELSE 0 END AS c3, CASE WHEN a + 4 = 0 THEN 0 ELSE a + 4 + random() END AS c4 FROM t1 ---- logical_plan -01)Projection: (__common_expr_1 OR random() = Float64(0)) AND (__common_expr_1 OR random() = Float64(1)) AS c1, __common_expr_2 AND random() = Float64(0) OR __common_expr_2 AND random() = Float64(1) AS c2, CASE WHEN __common_expr_3 = Float64(0) THEN __common_expr_3 + random() ELSE Float64(0) END AS c3, CASE WHEN __common_expr_4 = Float64(0) THEN Float64(0) ELSE __common_expr_4 + random() END AS c4 +01)Projection: (__common_expr_1 OR random() = Float64(0)) AND (__common_expr_2 OR random() = Float64(1)) AS c1, __common_expr_2 AND random() = Float64(0) OR __common_expr_1 AND random() = Float64(1) AS c2, CASE WHEN __common_expr_3 = Float64(0) THEN __common_expr_3 + random() ELSE Float64(0) END AS c3, CASE WHEN __common_expr_4 = Float64(0) THEN Float64(0) ELSE __common_expr_4 + random() END AS c4 02)--Projection: t1.a = Float64(1) AS __common_expr_1, t1.a = Float64(2) AS __common_expr_2, t1.a + Float64(3) AS __common_expr_3, t1.a + Float64(4) AS __common_expr_4 03)----TableScan: t1 projection=[a] physical_plan -01)ProjectionExec: expr=[(__common_expr_1@0 OR random() = 0) AND (__common_expr_1@0 OR random() = 1) as c1, __common_expr_2@1 AND random() = 0 OR __common_expr_2@1 AND random() = 1 as c2, CASE WHEN __common_expr_3@2 = 0 THEN __common_expr_3@2 + random() ELSE 0 END as c3, CASE WHEN __common_expr_4@3 = 0 THEN 0 ELSE __common_expr_4@3 + random() END as c4] +01)ProjectionExec: expr=[(__common_expr_1@0 OR random() = 0) AND (__common_expr_2@1 OR random() = 1) as c1, __common_expr_2@1 AND random() = 0 OR __common_expr_1@0 AND random() = 1 as c2, CASE WHEN __common_expr_3@2 = 0 THEN __common_expr_3@2 + random() ELSE 0 END as c3, CASE WHEN __common_expr_4@3 = 0 THEN 0 ELSE __common_expr_4@3 + random() END as c4] 02)--ProjectionExec: expr=[a@0 = 1 as __common_expr_1, a@0 = 2 as __common_expr_2, a@0 + 3 as __common_expr_3, a@0 + 4 as __common_expr_4] 03)----MemoryExec: partitions=1, partition_sizes=[0] diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index b1962ffcc116..54340604ad40 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -189,7 +189,6 @@ logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE -logical_plan after rewrite_disjunctive_predicate SAME TEXT AS ABOVE logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE logical_plan after eliminate_filter SAME TEXT AS ABOVE logical_plan after eliminate_cross_join SAME TEXT AS ABOVE @@ -216,7 +215,6 @@ logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE -logical_plan after rewrite_disjunctive_predicate SAME TEXT AS ABOVE logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE logical_plan after eliminate_filter SAME TEXT AS ABOVE logical_plan after eliminate_cross_join SAME TEXT AS ABOVE From afecd7be32b23dfb5b9adf94793584c22113cc39 Mon Sep 17 00:00:00 2001 From: Oleks V Date: Tue, 22 Oct 2024 13:41:53 -0700 Subject: [PATCH 049/110] Move filtered SMJ right join out of `join_partial` phase (#13053) * Move filtered SMJ right join out of `join_partial` phase --- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 4 +- .../src/joins/sort_merge_join.rs | 257 +++++++----------- .../test_files/sort_merge_join.slt | 56 ++-- 3 files changed, 122 insertions(+), 195 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 2eab45256dbb..ca2c2bf4e438 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -149,8 +149,6 @@ async fn test_right_join_1k() { } #[tokio::test] -// flaky for HjSmj case -// https://github.com/apache/datafusion/issues/12359 async fn test_right_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -158,7 +156,7 @@ async fn test_right_join_1k_filtered() { JoinType::Right, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::NljHj], false) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 5e77becd1c5e..d5134855440a 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -727,15 +727,19 @@ impl RecordBatchStream for SMJStream { } } +/// True if next index refers to either: +/// - another batch id +/// - another row index within same batch id +/// - end of row indices #[inline(always)] fn last_index_for_row( row_index: usize, indices: &UInt64Array, - ids: &[usize], + batch_ids: &[usize], indices_len: usize, ) -> bool { row_index == indices_len - 1 - || ids[row_index] != ids[row_index + 1] + || batch_ids[row_index] != batch_ids[row_index + 1] || indices.value(row_index) != indices.value(row_index + 1) } @@ -746,21 +750,21 @@ fn last_index_for_row( // `false` - the row sent as NULL joined row fn get_corrected_filter_mask( join_type: JoinType, - indices: &UInt64Array, - ids: &[usize], + row_indices: &UInt64Array, + batch_ids: &[usize], filter_mask: &BooleanArray, expected_size: usize, ) -> Option { - let streamed_indices_length = indices.len(); + let row_indices_length = row_indices.len(); let mut corrected_mask: BooleanBuilder = - BooleanBuilder::with_capacity(streamed_indices_length); + BooleanBuilder::with_capacity(row_indices_length); let mut seen_true = false; match join_type { - JoinType::Left => { - for i in 0..streamed_indices_length { + JoinType::Left | JoinType::Right => { + for i in 0..row_indices_length { let last_index = - last_index_for_row(i, indices, ids, streamed_indices_length); + last_index_for_row(i, row_indices, batch_ids, row_indices_length); if filter_mask.value(i) { seen_true = true; corrected_mask.append_value(true); @@ -781,9 +785,9 @@ fn get_corrected_filter_mask( Some(corrected_mask.finish()) } JoinType::LeftSemi => { - for i in 0..streamed_indices_length { + for i in 0..row_indices_length { let last_index = - last_index_for_row(i, indices, ids, streamed_indices_length); + last_index_for_row(i, row_indices, batch_ids, row_indices_length); if filter_mask.value(i) && !seen_true { seen_true = true; corrected_mask.append_value(true); @@ -828,7 +832,9 @@ impl Stream for SMJStream { if self.filter.is_some() && matches!( self.join_type, - JoinType::Left | JoinType::LeftSemi + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right ) { self.freeze_all()?; @@ -904,7 +910,7 @@ impl Stream for SMJStream { let record_batch = if !(self.filter.is_some() && matches!( self.join_type, - JoinType::Left | JoinType::LeftSemi + JoinType::Left | JoinType::LeftSemi | JoinType::Right )) { record_batch } else { @@ -923,7 +929,7 @@ impl Stream for SMJStream { if self.filter.is_some() && matches!( self.join_type, - JoinType::Left | JoinType::LeftSemi + JoinType::Left | JoinType::LeftSemi | JoinType::Right ) { let out = self.filter_joined_batch()?; @@ -1445,7 +1451,6 @@ impl SMJStream { }; let streamed_columns_length = streamed_columns.len(); - let buffered_columns_length = buffered_columns.len(); // Prepare the columns we apply join filter on later. // Only for joined rows between streamed and buffered. @@ -1512,7 +1517,10 @@ impl SMJStream { }; // Push the filtered batch which contains rows passing join filter to the output - if matches!(self.join_type, JoinType::Left | JoinType::LeftSemi) { + if matches!( + self.join_type, + JoinType::Left | JoinType::LeftSemi | JoinType::Right + ) { self.output_record_batches .batches .push(output_batch.clone()); @@ -1534,7 +1542,7 @@ impl SMJStream { // all joined rows are failed on the join filter. // I.e., if all rows joined from a streamed row are failed with the join filter, // we need to join it with nulls as buffered side. - if matches!(self.join_type, JoinType::Right | JoinType::Full) { + if matches!(self.join_type, JoinType::Full) { // We need to get the mask for row indices that the joined rows are failed // on the join filter. I.e., for a row in streamed side, if all joined rows // between it and all buffered rows are failed on the join filter, we need to @@ -1552,7 +1560,7 @@ impl SMJStream { let null_joined_batch = filter_record_batch(&output_batch, ¬_mask)?; - let mut buffered_columns = self + let buffered_columns = self .buffered_schema .fields() .iter() @@ -1564,18 +1572,7 @@ impl SMJStream { }) .collect::>(); - let columns = if matches!(self.join_type, JoinType::Right) { - let streamed_columns = null_joined_batch - .columns() - .iter() - .skip(buffered_columns_length) - .cloned() - .collect::>(); - - buffered_columns.extend(streamed_columns); - buffered_columns - } else { - // Left join or full outer join + let columns = { let mut streamed_columns = null_joined_batch .columns() .iter() @@ -1590,6 +1587,7 @@ impl SMJStream { // Push the streamed/buffered batch joined nulls to the output let null_joined_streamed_batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?; + self.output_record_batches .batches .push(null_joined_streamed_batch); @@ -1654,7 +1652,10 @@ impl SMJStream { } if !(self.filter.is_some() - && matches!(self.join_type, JoinType::Left | JoinType::LeftSemi)) + && matches!( + self.join_type, + JoinType::Left | JoinType::LeftSemi | JoinType::Right + )) { self.output_record_batches.batches.clear(); } @@ -3333,8 +3334,7 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_left_outer_join_filtered_mask() -> Result<()> { + fn build_joined_record_batches() -> Result { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, true), Field::new("b", DataType::Int32, true), @@ -3342,14 +3342,14 @@ mod tests { Field::new("y", DataType::Int32, true), ])); - let mut tb = JoinedRecordBatches { + let mut batches = JoinedRecordBatches { batches: vec![], filter_mask: BooleanBuilder::new(), row_indices: UInt64Builder::new(), batch_ids: vec![], }; - tb.batches.push(RecordBatch::try_new( + batches.batches.push(RecordBatch::try_new( Arc::clone(&schema), vec![ Arc::new(Int32Array::from(vec![1, 1])), @@ -3359,7 +3359,7 @@ mod tests { ], )?); - tb.batches.push(RecordBatch::try_new( + batches.batches.push(RecordBatch::try_new( Arc::clone(&schema), vec![ Arc::new(Int32Array::from(vec![1])), @@ -3369,7 +3369,7 @@ mod tests { ], )?); - tb.batches.push(RecordBatch::try_new( + batches.batches.push(RecordBatch::try_new( Arc::clone(&schema), vec![ Arc::new(Int32Array::from(vec![1, 1])), @@ -3379,7 +3379,7 @@ mod tests { ], )?); - tb.batches.push(RecordBatch::try_new( + batches.batches.push(RecordBatch::try_new( Arc::clone(&schema), vec![ Arc::new(Int32Array::from(vec![1])), @@ -3389,7 +3389,7 @@ mod tests { ], )?); - tb.batches.push(RecordBatch::try_new( + batches.batches.push(RecordBatch::try_new( Arc::clone(&schema), vec![ Arc::new(Int32Array::from(vec![1, 1])), @@ -3400,41 +3400,62 @@ mod tests { )?); let streamed_indices = vec![0, 0]; - tb.batch_ids.extend(vec![0; streamed_indices.len()]); - tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + batches.batch_ids.extend(vec![0; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); let streamed_indices = vec![1]; - tb.batch_ids.extend(vec![0; streamed_indices.len()]); - tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + batches.batch_ids.extend(vec![0; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); let streamed_indices = vec![0, 0]; - tb.batch_ids.extend(vec![1; streamed_indices.len()]); - tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + batches.batch_ids.extend(vec![1; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); let streamed_indices = vec![0]; - tb.batch_ids.extend(vec![2; streamed_indices.len()]); - tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + batches.batch_ids.extend(vec![2; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); let streamed_indices = vec![0, 0]; - tb.batch_ids.extend(vec![3; streamed_indices.len()]); - tb.row_indices.extend(&UInt64Array::from(streamed_indices)); + batches.batch_ids.extend(vec![3; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); - tb.filter_mask + batches + .filter_mask .extend(&BooleanArray::from(vec![true, false])); - tb.filter_mask.extend(&BooleanArray::from(vec![true])); - tb.filter_mask + batches.filter_mask.extend(&BooleanArray::from(vec![true])); + batches + .filter_mask .extend(&BooleanArray::from(vec![false, true])); - tb.filter_mask.extend(&BooleanArray::from(vec![false])); - tb.filter_mask + batches.filter_mask.extend(&BooleanArray::from(vec![false])); + batches + .filter_mask .extend(&BooleanArray::from(vec![false, false])); - let output = concat_batches(&schema, &tb.batches)?; - let out_mask = tb.filter_mask.finish(); - let out_indices = tb.row_indices.finish(); + Ok(batches) + } + + #[tokio::test] + async fn test_left_outer_join_filtered_mask() -> Result<()> { + let mut joined_batches = build_joined_record_batches()?; + let schema = joined_batches.batches.first().unwrap().schema(); + + let output = concat_batches(&schema, &joined_batches.batches)?; + let out_mask = joined_batches.filter_mask.finish(); + let out_indices = joined_batches.row_indices.finish(); assert_eq!( get_corrected_filter_mask( - JoinType::Left, + Left, &UInt64Array::from(vec![0]), &[0usize], &BooleanArray::from(vec![true]), @@ -3448,7 +3469,7 @@ mod tests { assert_eq!( get_corrected_filter_mask( - JoinType::Left, + Left, &UInt64Array::from(vec![0]), &[0usize], &BooleanArray::from(vec![false]), @@ -3462,7 +3483,7 @@ mod tests { assert_eq!( get_corrected_filter_mask( - JoinType::Left, + Left, &UInt64Array::from(vec![0, 0]), &[0usize; 2], &BooleanArray::from(vec![true, true]), @@ -3476,7 +3497,7 @@ mod tests { assert_eq!( get_corrected_filter_mask( - JoinType::Left, + Left, &UInt64Array::from(vec![0, 0, 0]), &[0usize; 3], &BooleanArray::from(vec![true, true, true]), @@ -3488,7 +3509,7 @@ mod tests { assert_eq!( get_corrected_filter_mask( - JoinType::Left, + Left, &UInt64Array::from(vec![0, 0, 0]), &[0usize; 3], &BooleanArray::from(vec![true, false, true]), @@ -3509,7 +3530,7 @@ mod tests { assert_eq!( get_corrected_filter_mask( - JoinType::Left, + Left, &UInt64Array::from(vec![0, 0, 0]), &[0usize; 3], &BooleanArray::from(vec![false, false, true]), @@ -3530,7 +3551,7 @@ mod tests { assert_eq!( get_corrected_filter_mask( - JoinType::Left, + Left, &UInt64Array::from(vec![0, 0, 0]), &[0usize; 3], &BooleanArray::from(vec![false, true, true]), @@ -3551,7 +3572,7 @@ mod tests { assert_eq!( get_corrected_filter_mask( - JoinType::Left, + Left, &UInt64Array::from(vec![0, 0, 0]), &[0usize; 3], &BooleanArray::from(vec![false, false, false]), @@ -3571,9 +3592,9 @@ mod tests { ); let corrected_mask = get_corrected_filter_mask( - JoinType::Left, + Left, &out_indices, - &tb.batch_ids, + &joined_batches.batch_ids, &out_mask, output.num_rows(), ) @@ -3643,102 +3664,12 @@ mod tests { #[tokio::test] async fn test_left_semi_join_filtered_mask() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("x", DataType::Int32, true), - Field::new("y", DataType::Int32, true), - ])); - - let mut tb = JoinedRecordBatches { - batches: vec![], - filter_mask: BooleanBuilder::new(), - row_indices: UInt64Builder::new(), - batch_ids: vec![], - }; - - tb.batches.push(RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![1, 1])), - Arc::new(Int32Array::from(vec![10, 10])), - Arc::new(Int32Array::from(vec![1, 1])), - Arc::new(Int32Array::from(vec![11, 9])), - ], - )?); - - tb.batches.push(RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![1])), - Arc::new(Int32Array::from(vec![11])), - Arc::new(Int32Array::from(vec![1])), - Arc::new(Int32Array::from(vec![12])), - ], - )?); - - tb.batches.push(RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![1, 1])), - Arc::new(Int32Array::from(vec![12, 12])), - Arc::new(Int32Array::from(vec![1, 1])), - Arc::new(Int32Array::from(vec![11, 13])), - ], - )?); - - tb.batches.push(RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![1])), - Arc::new(Int32Array::from(vec![13])), - Arc::new(Int32Array::from(vec![1])), - Arc::new(Int32Array::from(vec![12])), - ], - )?); - - tb.batches.push(RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![1, 1])), - Arc::new(Int32Array::from(vec![14, 14])), - Arc::new(Int32Array::from(vec![1, 1])), - Arc::new(Int32Array::from(vec![12, 11])), - ], - )?); - - let streamed_indices = vec![0, 0]; - tb.batch_ids.extend(vec![0; streamed_indices.len()]); - tb.row_indices.extend(&UInt64Array::from(streamed_indices)); - - let streamed_indices = vec![1]; - tb.batch_ids.extend(vec![0; streamed_indices.len()]); - tb.row_indices.extend(&UInt64Array::from(streamed_indices)); - - let streamed_indices = vec![0, 0]; - tb.batch_ids.extend(vec![1; streamed_indices.len()]); - tb.row_indices.extend(&UInt64Array::from(streamed_indices)); - - let streamed_indices = vec![0]; - tb.batch_ids.extend(vec![2; streamed_indices.len()]); - tb.row_indices.extend(&UInt64Array::from(streamed_indices)); - - let streamed_indices = vec![0, 0]; - tb.batch_ids.extend(vec![3; streamed_indices.len()]); - tb.row_indices.extend(&UInt64Array::from(streamed_indices)); - - tb.filter_mask - .extend(&BooleanArray::from(vec![true, false])); - tb.filter_mask.extend(&BooleanArray::from(vec![true])); - tb.filter_mask - .extend(&BooleanArray::from(vec![false, true])); - tb.filter_mask.extend(&BooleanArray::from(vec![false])); - tb.filter_mask - .extend(&BooleanArray::from(vec![false, false])); + let mut joined_batches = build_joined_record_batches()?; + let schema = joined_batches.batches.first().unwrap().schema(); - let output = concat_batches(&schema, &tb.batches)?; - let out_mask = tb.filter_mask.finish(); - let out_indices = tb.row_indices.finish(); + let output = concat_batches(&schema, &joined_batches.batches)?; + let out_mask = joined_batches.filter_mask.finish(); + let out_indices = joined_batches.row_indices.finish(); assert_eq!( get_corrected_filter_mask( @@ -3839,7 +3770,7 @@ mod tests { let corrected_mask = get_corrected_filter_mask( LeftSemi, &out_indices, - &tb.batch_ids, + &joined_batches.batch_ids, &out_mask, output.num_rows(), ) diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index d00b7d6f0a52..051cc6dce3d4 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -100,14 +100,13 @@ Alice 100 Alice 2 Alice 50 Alice 1 Alice 50 Alice 2 -# Uncomment when filtered RIGHT moved # right join with join filter -#query TITI rowsort -#SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b -#---- -#Alice 100 Alice 1 -#Alice 100 Alice 2 -#Alice 50 Alice 1 +query TITI rowsort +SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b +---- +Alice 100 Alice 1 +Alice 100 Alice 2 +Alice 50 Alice 1 query TITI rowsort SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t1.b > t2.b @@ -137,7 +136,7 @@ Bob 1 NULL NULL #Bob 1 NULL NULL #NULL NULL Alice 1 -# Uncomment when filtered RIGHT moved +# Uncomment when filtered FULL moved #query TITI rowsort #SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t1.b > t2.b + 50 #---- @@ -617,27 +616,26 @@ set datafusion.execution.batch_size = 1; #) order by 1, 2 #---- -# Uncomment when filtered RIGHT moved -#query IIII -#select * from ( -#with t as ( -# select id, id % 5 id1 from (select unnest(range(0,10)) id) -#), t1 as ( -# select id % 10 id, id + 2 id1 from (select unnest(range(0,10)) id) -#) -#select * from t right join t1 on t.id1 = t1.id and t.id > t1.id1 -#) order by 1, 2, 3, 4 -#---- -#5 0 0 2 -#6 1 1 3 -#7 2 2 4 -#8 3 3 5 -#9 4 4 6 -#NULL NULL 5 7 -#NULL NULL 6 8 -#NULL NULL 7 9 -#NULL NULL 8 10 -#NULL NULL 9 11 +query IIII +select * from ( +with t as ( + select id, id % 5 id1 from (select unnest(range(0,10)) id) +), t1 as ( + select id % 10 id, id + 2 id1 from (select unnest(range(0,10)) id) +) +select * from t right join t1 on t.id1 = t1.id and t.id > t1.id1 +) order by 1, 2, 3, 4 +---- +5 0 0 2 +6 1 1 3 +7 2 2 4 +8 3 3 5 +9 4 4 6 +NULL NULL 5 7 +NULL NULL 6 8 +NULL NULL 7 9 +NULL NULL 8 10 +NULL NULL 9 11 query IIII select * from ( From cf60da9045a7aade515380193d1a17b40d2154fd Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Tue, 22 Oct 2024 22:45:46 +0200 Subject: [PATCH 050/110] Remove functions and types deprecated since 37 (#13056) * Remove deprecated CatalogList trait Deprecated since v 35. * Remove deprecated unnest_column(s) dataframe functions Deprecated since v 37. * Remove deprecated LogicalPlan::inspect_expressions Deprecated since v 37. --- datafusion/core/src/catalog_common/mod.rs | 4 ---- datafusion/core/src/dataframe/mod.rs | 25 +---------------------- datafusion/expr/src/logical_plan/plan.rs | 21 ------------------- 3 files changed, 1 insertion(+), 49 deletions(-) diff --git a/datafusion/core/src/catalog_common/mod.rs b/datafusion/core/src/catalog_common/mod.rs index 85207845a005..68c78dda4899 100644 --- a/datafusion/core/src/catalog_common/mod.rs +++ b/datafusion/core/src/catalog_common/mod.rs @@ -36,10 +36,6 @@ pub use datafusion_sql::{ResolvedTableReference, TableReference}; use std::collections::BTreeSet; use std::ops::ControlFlow; -/// See [`CatalogProviderList`] -#[deprecated(since = "35.0.0", note = "use [`CatalogProviderList`] instead")] -pub trait CatalogList: CatalogProviderList {} - /// Collects all tables and views referenced in the SQL statement. CTEs are collected separately. /// This can be used to determine which tables need to be in the catalog for a query to be planned. /// diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 4feadd260d7f..d1d49bfaa693 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -373,32 +373,9 @@ impl DataFrame { self.select(expr) } - /// Expand each list element of a column to multiple rows. - #[deprecated(since = "37.0.0", note = "use unnest_columns instead")] - pub fn unnest_column(self, column: &str) -> Result { - self.unnest_columns(&[column]) - } - - /// Expand each list element of a column to multiple rows, with - /// behavior controlled by [`UnnestOptions`]. - /// - /// Please see the documentation on [`UnnestOptions`] for more - /// details about the meaning of unnest. - #[deprecated(since = "37.0.0", note = "use unnest_columns_with_options instead")] - pub fn unnest_column_with_options( - self, - column: &str, - options: UnnestOptions, - ) -> Result { - self.unnest_columns_with_options(&[column], options) - } - /// Expand multiple list/struct columns into a set of rows and new columns. /// - /// See also: - /// - /// 1. [`UnnestOptions`] documentation for the behavior of `unnest` - /// 2. [`Self::unnest_column_with_options`] + /// See also: [`UnnestOptions`] documentation for the behavior of `unnest` /// /// # Example /// ``` diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 72d8f7158be2..d8dfe7b56e40 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -423,27 +423,6 @@ impl LogicalPlan { exprs } - #[deprecated(since = "37.0.0", note = "Use `apply_expressions` instead")] - pub fn inspect_expressions(self: &LogicalPlan, mut f: F) -> Result<(), E> - where - F: FnMut(&Expr) -> Result<(), E>, - { - let mut err = Ok(()); - self.apply_expressions(|e| { - if let Err(e) = f(e) { - // save the error for later (it may not be a DataFusionError - err = Err(e); - Ok(TreeNodeRecursion::Stop) - } else { - Ok(TreeNodeRecursion::Continue) - } - }) - // The closure always returns OK, so this will always too - .expect("no way to return error during recursion"); - - err - } - /// Returns all inputs / children of this `LogicalPlan` node. /// /// Note does not include inputs to inputs, or subqueries. From d3920f3060fc3745b8a50170dafb2beaa898adc2 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Tue, 22 Oct 2024 18:46:49 -0400 Subject: [PATCH 051/110] Minor: Cleaned physical-plan Comments (#13055) * Fixed documentations * small fixes --- datafusion/physical-plan/src/analyze.rs | 8 ++--- .../physical-plan/src/coalesce_batches.rs | 2 +- .../physical-plan/src/coalesce_partitions.rs | 4 +-- datafusion/physical-plan/src/common.rs | 14 ++++---- datafusion/physical-plan/src/display.rs | 1 + datafusion/physical-plan/src/empty.rs | 2 +- datafusion/physical-plan/src/explain.rs | 4 +-- datafusion/physical-plan/src/filter.rs | 18 +++++----- datafusion/physical-plan/src/insert.rs | 2 +- datafusion/physical-plan/src/limit.rs | 34 +++++++++---------- datafusion/physical-plan/src/memory.rs | 4 +-- .../physical-plan/src/placeholder_row.rs | 6 ++-- datafusion/physical-plan/src/projection.rs | 10 +++--- datafusion/physical-plan/src/stream.rs | 30 ++++++++-------- datafusion/physical-plan/src/streaming.rs | 6 ++-- datafusion/physical-plan/src/test.rs | 6 ++-- datafusion/physical-plan/src/unnest.rs | 28 +++++++-------- datafusion/physical-plan/src/values.rs | 6 ++-- datafusion/physical-plan/src/work_table.rs | 12 +++---- 19 files changed, 99 insertions(+), 98 deletions(-) diff --git a/datafusion/physical-plan/src/analyze.rs b/datafusion/physical-plan/src/analyze.rs index 287446328f8d..c8b329fabdaa 100644 --- a/datafusion/physical-plan/src/analyze.rs +++ b/datafusion/physical-plan/src/analyze.rs @@ -40,9 +40,9 @@ use futures::StreamExt; /// discards the results, and then prints out an annotated plan with metrics #[derive(Debug, Clone)] pub struct AnalyzeExec { - /// control how much extra to print + /// Control how much extra to print verbose: bool, - /// if statistics should be displayed + /// If statistics should be displayed show_statistics: bool, /// The input plan (the plan being analyzed) pub(crate) input: Arc, @@ -69,12 +69,12 @@ impl AnalyzeExec { } } - /// access to verbose + /// Access to verbose pub fn verbose(&self) -> bool { self.verbose } - /// access to show_statistics + /// Access to show_statistics pub fn show_statistics(&self) -> bool { self.show_statistics } diff --git a/datafusion/physical-plan/src/coalesce_batches.rs b/datafusion/physical-plan/src/coalesce_batches.rs index e1a2f32d8a38..61fb3599f013 100644 --- a/datafusion/physical-plan/src/coalesce_batches.rs +++ b/datafusion/physical-plan/src/coalesce_batches.rs @@ -52,7 +52,7 @@ use futures::stream::{Stream, StreamExt}; pub struct CoalesceBatchesExec { /// The input plan input: Arc, - /// Minimum number of rows for coalesces batches + /// Minimum number of rows for coalescing batches target_batch_size: usize, /// Maximum number of rows to fetch, `None` means fetching all rows fetch: Option, diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs index 2ab6e3de1add..f9d4ec6a1a34 100644 --- a/datafusion/physical-plan/src/coalesce_partitions.rs +++ b/datafusion/physical-plan/src/coalesce_partitions.rs @@ -236,10 +236,10 @@ mod tests { let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2)); let refs = blocking_exec.refs(); - let coaelesce_partitions_exec = + let coalesce_partitions_exec = Arc::new(CoalescePartitionsExec::new(blocking_exec)); - let fut = collect(coaelesce_partitions_exec, task_ctx); + let fut = collect(coalesce_partitions_exec, task_ctx); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index 5abdf367c571..844208999d25 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -109,7 +109,7 @@ pub(crate) fn spawn_buffered( builder.spawn(async move { while let Some(item) = input.next().await { if sender.send(item).await.is_err() { - // receiver dropped when query is shutdown early (e.g., limit) or error, + // Receiver dropped when query is shutdown early (e.g., limit) or error, // no need to return propagate the send error. return Ok(()); } @@ -182,15 +182,15 @@ pub fn compute_record_batch_statistics( /// Write in Arrow IPC format. pub struct IPCWriter { - /// path + /// Path pub path: PathBuf, - /// inner writer + /// Inner writer pub writer: FileWriter, - /// batches written + /// Batches written pub num_batches: usize, - /// rows written + /// Rows written pub num_rows: usize, - /// bytes written + /// Bytes written pub num_bytes: usize, } @@ -315,7 +315,7 @@ mod tests { ], )?; - // just select f32,f64 + // Just select f32,f64 let select_projection = Some(vec![0, 1]); let byte_size = batch .project(&select_projection.clone().unwrap()) diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs index 0d2653c5c775..4e936fb37a12 100644 --- a/datafusion/physical-plan/src/display.rs +++ b/datafusion/physical-plan/src/display.rs @@ -231,6 +231,7 @@ impl<'a> DisplayableExecutionPlan<'a> { } } +/// Enum representing the different levels of metrics to display #[derive(Debug, Clone, Copy)] enum ShowMetrics { /// Do not show any metrics diff --git a/datafusion/physical-plan/src/empty.rs b/datafusion/physical-plan/src/empty.rs index 4bacea48c347..f6e0abb94fa8 100644 --- a/datafusion/physical-plan/src/empty.rs +++ b/datafusion/physical-plan/src/empty.rs @@ -173,7 +173,7 @@ mod tests { let empty = EmptyExec::new(Arc::clone(&schema)); assert_eq!(empty.schema(), schema); - // we should have no results + // We should have no results let iter = empty.execute(0, task_ctx)?; let batches = common::collect(iter).await?; assert!(batches.is_empty()); diff --git a/datafusion/physical-plan/src/explain.rs b/datafusion/physical-plan/src/explain.rs index 56dc35e8819d..96f55a1446b0 100644 --- a/datafusion/physical-plan/src/explain.rs +++ b/datafusion/physical-plan/src/explain.rs @@ -67,7 +67,7 @@ impl ExplainExec { &self.stringified_plans } - /// access to verbose + /// Access to verbose pub fn verbose(&self) -> bool { self.verbose } @@ -112,7 +112,7 @@ impl ExecutionPlan for ExplainExec { } fn children(&self) -> Vec<&Arc> { - // this is a leaf node and has no children + // This is a leaf node and has no children vec![] } diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index c39a91e251b7..30b0af19f43b 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -115,7 +115,7 @@ impl FilterExec { /// Return new instance of [FilterExec] with the given projection. pub fn with_projection(&self, projection: Option>) -> Result { - // check if the projection is valid + // Check if the projection is valid can_project(&self.schema(), projection.as_ref())?; let projection = match projection { @@ -157,7 +157,7 @@ impl FilterExec { self.default_selectivity } - /// projection + /// Projection pub fn projection(&self) -> Option<&Vec> { self.projection.as_ref() } @@ -255,9 +255,9 @@ impl FilterExec { let expr = Arc::new(column) as _; ConstExpr::new(expr).with_across_partitions(true) }); - // this is for statistics + // This is for statistics eq_properties = eq_properties.with_constants(constants); - // this is for logical constant (for example: a = '1', then a could be marked as a constant) + // This is for logical constant (for example: a = '1', then a could be marked as a constant) // to do: how to deal with multiple situation to represent = (for example c1 between 0 and 0) eq_properties = eq_properties.with_constants(Self::extend_constants(input, predicate)); @@ -331,7 +331,7 @@ impl ExecutionPlan for FilterExec { } fn maintains_input_order(&self) -> Vec { - // tell optimizer this operator doesn't reorder its input + // Tell optimizer this operator doesn't reorder its input vec![true] } @@ -425,7 +425,7 @@ struct FilterExecStream { predicate: Arc, /// The input partition to filter. input: SendableRecordBatchStream, - /// runtime metrics recording + /// Runtime metrics recording baseline_metrics: BaselineMetrics, /// The projection indices of the columns in the input schema projection: Option>, @@ -449,7 +449,7 @@ fn filter_and_project( .and_then(|v| v.into_array(batch.num_rows())) .and_then(|array| { Ok(match (as_boolean_array(&array), projection) { - // apply filter array to record batch + // Apply filter array to record batch (Ok(filter_array), None) => filter_record_batch(batch, filter_array)?, (Ok(filter_array), Some(projection)) => { let projected_columns = projection @@ -490,7 +490,7 @@ impl Stream for FilterExecStream { &self.schema, )?; timer.done(); - // skip entirely filtered batches + // Skip entirely filtered batches if filtered_batch.num_rows() == 0 { continue; } @@ -507,7 +507,7 @@ impl Stream for FilterExecStream { } fn size_hint(&self) -> (usize, Option) { - // same number of record batches + // Same number of record batches self.input.size_hint() } } diff --git a/datafusion/physical-plan/src/insert.rs b/datafusion/physical-plan/src/insert.rs index 5dc27bc239d2..dda45ebebb0c 100644 --- a/datafusion/physical-plan/src/insert.rs +++ b/datafusion/physical-plan/src/insert.rs @@ -271,7 +271,7 @@ fn make_count_batch(count: u64) -> RecordBatch { } fn make_count_schema() -> SchemaRef { - // define a schema. + // Define a schema. Arc::new(Schema::new(vec![Field::new( "count", DataType::UInt64, diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs index a42e2da60587..eda75b37fe66 100644 --- a/datafusion/physical-plan/src/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -398,7 +398,7 @@ impl LimitStream { if batch.num_rows() > 0 { break poll; } else { - // continue to poll input stream + // Continue to poll input stream } } Poll::Ready(Some(Err(_e))) => break poll, @@ -408,12 +408,12 @@ impl LimitStream { } } - /// fetches from the batch + /// Fetches from the batch fn stream_limit(&mut self, batch: RecordBatch) -> Option { // records time on drop let _timer = self.baseline_metrics.elapsed_compute().timer(); if self.fetch == 0 { - self.input = None; // clear input so it can be dropped early + self.input = None; // Clear input so it can be dropped early None } else if batch.num_rows() < self.fetch { // @@ -422,7 +422,7 @@ impl LimitStream { } else if batch.num_rows() >= self.fetch { let batch_rows = self.fetch; self.fetch = 0; - self.input = None; // clear input so it can be dropped early + self.input = None; // Clear input so it can be dropped early // It is guaranteed that batch_rows is <= batch.num_rows Some(batch.slice(0, batch_rows)) @@ -453,7 +453,7 @@ impl Stream for LimitStream { other => other, }) } - // input has been cleared + // Input has been cleared None => Poll::Ready(None), }; @@ -489,17 +489,17 @@ mod tests { let num_partitions = 4; let csv = test::scan_partitioned(num_partitions); - // input should have 4 partitions + // Input should have 4 partitions assert_eq!(csv.output_partitioning().partition_count(), num_partitions); let limit = GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), 0, Some(7)); - // the result should contain 4 batches (one per input partition) + // The result should contain 4 batches (one per input partition) let iter = limit.execute(0, task_ctx)?; let batches = common::collect(iter).await?; - // there should be a total of 100 rows + // There should be a total of 100 rows let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum(); assert_eq!(row_count, 7); @@ -520,7 +520,7 @@ mod tests { let index = input.index(); assert_eq!(index.value(), 0); - // limit of six needs to consume the entire first record batch + // Limit of six needs to consume the entire first record batch // (5 rows) and 1 row from the second (1 row) let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let limit_stream = @@ -550,7 +550,7 @@ mod tests { let index = input.index(); assert_eq!(index.value(), 0); - // limit of six needs to consume the entire first record batch + // Limit of six needs to consume the entire first record batch // (6 rows) and stop immediately let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let limit_stream = @@ -580,7 +580,7 @@ mod tests { let index = input.index(); assert_eq!(index.value(), 0); - // limit of six needs to consume the entire first record batch + // Limit of six needs to consume the entire first record batch // (6 rows) and stop immediately let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let limit_stream = @@ -598,7 +598,7 @@ mod tests { Ok(()) } - // test cases for "skip" + // Test cases for "skip" async fn skip_and_fetch(skip: usize, fetch: Option) -> Result { let task_ctx = Arc::new(TaskContext::default()); @@ -611,7 +611,7 @@ mod tests { let offset = GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), skip, fetch); - // the result should contain 4 batches (one per input partition) + // The result should contain 4 batches (one per input partition) let iter = offset.execute(0, task_ctx)?; let batches = common::collect(iter).await?; Ok(batches.iter().map(|batch| batch.num_rows()).sum()) @@ -633,7 +633,7 @@ mod tests { #[tokio::test] async fn skip_3_fetch_none() -> Result<()> { - // there are total of 400 rows, we skipped 3 rows (offset = 3) + // There are total of 400 rows, we skipped 3 rows (offset = 3) let row_count = skip_and_fetch(3, None).await?; assert_eq!(row_count, 397); Ok(()) @@ -641,7 +641,7 @@ mod tests { #[tokio::test] async fn skip_3_fetch_10_stats() -> Result<()> { - // there are total of 100 rows, we skipped 3 rows (offset = 3) + // There are total of 100 rows, we skipped 3 rows (offset = 3) let row_count = skip_and_fetch(3, Some(10)).await?; assert_eq!(row_count, 10); Ok(()) @@ -656,7 +656,7 @@ mod tests { #[tokio::test] async fn skip_400_fetch_1() -> Result<()> { - // there are a total of 400 rows + // There are a total of 400 rows let row_count = skip_and_fetch(400, Some(1)).await?; assert_eq!(row_count, 0); Ok(()) @@ -664,7 +664,7 @@ mod tests { #[tokio::test] async fn skip_401_fetch_none() -> Result<()> { - // there are total of 400 rows, we skipped 401 rows (offset = 3) + // There are total of 400 rows, we skipped 401 rows (offset = 3) let row_count = skip_and_fetch(401, None).await?; assert_eq!(row_count, 0); Ok(()) diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 456f0ef2dcc8..52a8631d5a63 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -119,7 +119,7 @@ impl ExecutionPlan for MemoryExec { } fn children(&self) -> Vec<&Arc> { - // this is a leaf node and has no children + // This is a leaf node and has no children vec![] } @@ -179,7 +179,7 @@ impl MemoryExec { }) } - /// set `show_sizes` to determine whether to display partition sizes + /// Set `show_sizes` to determine whether to display partition sizes pub fn with_show_sizes(mut self, show_sizes: bool) -> Self { self.show_sizes = show_sizes; self diff --git a/datafusion/physical-plan/src/placeholder_row.rs b/datafusion/physical-plan/src/placeholder_row.rs index 272211d5056e..5d8ca7e76935 100644 --- a/datafusion/physical-plan/src/placeholder_row.rs +++ b/datafusion/physical-plan/src/placeholder_row.rs @@ -208,7 +208,7 @@ mod tests { let schema = test::aggr_test_schema(); let placeholder = PlaceholderRowExec::new(schema); - // ask for the wrong partition + // Ask for the wrong partition assert!(placeholder.execute(1, Arc::clone(&task_ctx)).is_err()); assert!(placeholder.execute(20, task_ctx).is_err()); Ok(()) @@ -223,7 +223,7 @@ mod tests { let iter = placeholder.execute(0, task_ctx)?; let batches = common::collect(iter).await?; - // should have one item + // Should have one item assert_eq!(batches.len(), 1); Ok(()) @@ -240,7 +240,7 @@ mod tests { let iter = placeholder.execute(n, Arc::clone(&task_ctx))?; let batches = common::collect(iter).await?; - // should have one item + // Should have one item assert_eq!(batches.len(), 1); } diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index a28328fb5d43..936cf742a792 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -90,7 +90,7 @@ impl ProjectionExec { input_schema.metadata().clone(), )); - // construct a map from the input expressions to the output expression of the Projection + // Construct a map from the input expressions to the output expression of the Projection let projection_mapping = ProjectionMapping::try_new(&expr, &input_schema)?; let cache = Self::compute_properties(&input, &projection_mapping, Arc::clone(&schema))?; @@ -183,7 +183,7 @@ impl ExecutionPlan for ProjectionExec { } fn maintains_input_order(&self) -> Vec { - // tell optimizer this operator doesn't reorder its input + // Tell optimizer this operator doesn't reorder its input vec![true] } @@ -240,7 +240,7 @@ impl ExecutionPlan for ProjectionExec { } } -/// If e is a direct column reference, returns the field level +/// If 'e' is a direct column reference, returns the field level /// metadata for that field, if any. Otherwise returns None pub(crate) fn get_field_metadata( e: &Arc, @@ -294,7 +294,7 @@ fn stats_projection( impl ProjectionStream { fn batch_project(&self, batch: &RecordBatch) -> Result { - // records time on drop + // Records time on drop let _timer = self.baseline_metrics.elapsed_compute().timer(); let arrays = self .expr @@ -340,7 +340,7 @@ impl Stream for ProjectionStream { } fn size_hint(&self) -> (usize, Option) { - // same number of record batches + // Same number of record batches self.input.size_hint() } } diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs index faeb4799f5af..9220646653e6 100644 --- a/datafusion/physical-plan/src/stream.rs +++ b/datafusion/physical-plan/src/stream.rs @@ -56,7 +56,7 @@ pub(crate) struct ReceiverStreamBuilder { } impl ReceiverStreamBuilder { - /// create new channels with the specified buffer size + /// Create new channels with the specified buffer size pub fn new(capacity: usize) -> Self { let (tx, rx) = tokio::sync::mpsc::channel(capacity); @@ -83,10 +83,10 @@ impl ReceiverStreamBuilder { } /// Spawn a blocking task that will be aborted if this builder (or the stream - /// built from it) are dropped + /// built from it) are dropped. /// - /// this is often used to spawn tasks that write to the sender - /// retrieved from `Self::tx` + /// This is often used to spawn tasks that write to the sender + /// retrieved from `Self::tx`. pub fn spawn_blocking(&mut self, f: F) where F: FnOnce() -> Result<()>, @@ -103,7 +103,7 @@ impl ReceiverStreamBuilder { mut join_set, } = self; - // don't need tx + // Doesn't need tx drop(tx); // future that checks the result of the join set, and propagates panic if seen @@ -112,7 +112,7 @@ impl ReceiverStreamBuilder { match result { Ok(task_result) => { match task_result { - // nothing to report + // Nothing to report Ok(_) => continue, // This means a blocking task error Err(error) => return Some(Err(error)), @@ -215,7 +215,7 @@ pub struct RecordBatchReceiverStreamBuilder { } impl RecordBatchReceiverStreamBuilder { - /// create new channels with the specified buffer size + /// Create new channels with the specified buffer size pub fn new(schema: SchemaRef, capacity: usize) -> Self { Self { schema, @@ -256,7 +256,7 @@ impl RecordBatchReceiverStreamBuilder { self.inner.spawn_blocking(f) } - /// runs the `partition` of the `input` ExecutionPlan on the + /// Runs the `partition` of the `input` ExecutionPlan on the /// tokio threadpool and writes its outputs to this stream /// /// If the input partition produces an error, the error will be @@ -299,7 +299,7 @@ impl RecordBatchReceiverStreamBuilder { return Ok(()); } - // stop after the first error is encontered (don't + // Stop after the first error is encountered (Don't // drive all streams to completion) if is_err { debug!( @@ -483,13 +483,13 @@ mod test { async fn record_batch_receiver_stream_propagates_panics_early_shutdown() { let schema = schema(); - // make 2 partitions, second partition panics before the first + // Make 2 partitions, second partition panics before the first let num_partitions = 2; let input = PanicExec::new(Arc::clone(&schema), num_partitions) .with_partition_panic(0, 10) .with_partition_panic(1, 3); // partition 1 should panic first (after 3 ) - // ensure that the panic results in an early shutdown (that + // Ensure that the panic results in an early shutdown (that // everything stops after the first panic). // Since the stream reads every other batch: (0,1,0,1,0,panic) @@ -512,10 +512,10 @@ mod test { builder.run_input(Arc::new(input), 0, Arc::clone(&task_ctx)); let stream = builder.build(); - // input should still be present + // Input should still be present assert!(std::sync::Weak::strong_count(&refs) > 0); - // drop the stream, ensure the refs go to zero + // Drop the stream, ensure the refs go to zero drop(stream); assert_strong_count_converges_to_zero(refs).await; } @@ -539,7 +539,7 @@ mod test { builder.run_input(Arc::new(error_stream), 0, Arc::clone(&task_ctx)); let mut stream = builder.build(); - // get the first result, which should be an error + // Get the first result, which should be an error let first_batch = stream.next().await.unwrap(); let first_err = first_batch.unwrap_err(); assert_eq!(first_err.strip_backtrace(), "Execution error: Test1"); @@ -570,7 +570,7 @@ mod test { } let mut stream = builder.build(); - // drain the stream until it is complete, panic'ing on error + // Drain the stream until it is complete, panic'ing on error let mut num_batches = 0; while let Some(next) = stream.next().await { next.unwrap(); diff --git a/datafusion/physical-plan/src/streaming.rs b/datafusion/physical-plan/src/streaming.rs index b02e4fb5738d..0f7c75c2c90b 100644 --- a/datafusion/physical-plan/src/streaming.rs +++ b/datafusion/physical-plan/src/streaming.rs @@ -295,7 +295,7 @@ mod test { #[tokio::test] async fn test_no_limit() { let exec = TestBuilder::new() - // make 2 batches, each with 100 rows + // Make 2 batches, each with 100 rows .with_batches(vec![make_partition(100), make_partition(100)]) .build(); @@ -306,9 +306,9 @@ mod test { #[tokio::test] async fn test_limit() { let exec = TestBuilder::new() - // make 2 batches, each with 100 rows + // Make 2 batches, each with 100 rows .with_batches(vec![make_partition(100), make_partition(100)]) - // limit to only the first 75 rows back + // Limit to only the first 75 rows back .with_limit(Some(75)) .build(); diff --git a/datafusion/physical-plan/src/test.rs b/datafusion/physical-plan/src/test.rs index 4da43b313403..90ec9b106850 100644 --- a/datafusion/physical-plan/src/test.rs +++ b/datafusion/physical-plan/src/test.rs @@ -65,7 +65,7 @@ pub fn aggr_test_schema() -> SchemaRef { Arc::new(schema) } -/// returns record batch with 3 columns of i32 in memory +/// Returns record batch with 3 columns of i32 in memory pub fn build_table_i32( a: (&str, &Vec), b: (&str, &Vec), @@ -88,7 +88,7 @@ pub fn build_table_i32( .unwrap() } -/// returns memory table scan wrapped around record batch with 3 columns of i32 +/// Returns memory table scan wrapped around record batch with 3 columns of i32 pub fn build_table_scan_i32( a: (&str, &Vec), b: (&str, &Vec), @@ -125,7 +125,7 @@ pub fn mem_exec(partitions: usize) -> MemoryExec { MemoryExec::try_new(&data, schema, projection).unwrap() } -// construct a stream partition for test purposes +// Construct a stream partition for test purposes #[derive(Debug)] pub struct TestPartitionStream { pub schema: SchemaRef, diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs index 2311541816f3..40ec3830ea0c 100644 --- a/datafusion/physical-plan/src/unnest.rs +++ b/datafusion/physical-plan/src/unnest.rs @@ -62,9 +62,9 @@ pub struct UnnestExec { input: Arc, /// The schema once the unnest is applied schema: SchemaRef, - /// indices of the list-typed columns in the input schema + /// Indices of the list-typed columns in the input schema list_column_indices: Vec, - /// indices of the struct-typed columns in the input schema + /// Indices of the struct-typed columns in the input schema struct_column_indices: Vec, /// Options options: UnnestOptions, @@ -115,12 +115,12 @@ impl UnnestExec { &self.input } - /// indices of the list-typed columns in the input schema + /// Indices of the list-typed columns in the input schema pub fn list_column_indices(&self) -> &[ListUnnest] { &self.list_column_indices } - /// indices of the struct-typed columns in the input schema + /// Indices of the struct-typed columns in the input schema pub fn struct_column_indices(&self) -> &[usize] { &self.struct_column_indices } @@ -203,7 +203,7 @@ impl ExecutionPlan for UnnestExec { #[derive(Clone, Debug)] struct UnnestMetrics { - /// total time for column unnesting + /// Total time for column unnesting elapsed_compute: metrics::Time, /// Number of batches consumed input_batches: metrics::Count, @@ -411,7 +411,7 @@ fn list_unnest_at_level( level_to_unnest: usize, options: &UnnestOptions, ) -> Result<(Vec, usize)> { - // extract unnestable columns at this level + // Extract unnestable columns at this level let (arrs_to_unnest, list_unnest_specs): (Vec>, Vec<_>) = list_type_unnests .iter() @@ -422,7 +422,7 @@ fn list_unnest_at_level( *unnesting, )); } - // this means the unnesting on this item has started at higher level + // This means the unnesting on this item has started at higher level // and need to continue until depth reaches 1 if level_to_unnest < unnesting.depth { return Some(( @@ -434,7 +434,7 @@ fn list_unnest_at_level( }) .unzip(); - // filter out so that list_arrays only contain column with the highest depth + // Filter out so that list_arrays only contain column with the highest depth // at the same time, during iteration remove this depth so next time we don't have to unnest them again let longest_length = find_longest_length(&arrs_to_unnest, options)?; let unnested_length = longest_length.as_primitive::(); @@ -456,7 +456,7 @@ fn list_unnest_at_level( // Create the take indices array for other columns let take_indices = create_take_indicies(unnested_length, total_length); - // dimension of arrays in batch is untouch, but the values are repeated + // Dimension of arrays in batch is untouched, but the values are repeated // as the side effect of unnesting let ret = repeat_arrs_from_indices(batch, &take_indices)?; unnested_temp_arrays @@ -548,8 +548,8 @@ fn build_batch( // This arr always has the same column count with the input batch let mut flatten_arrs = vec![]; - // original batch has the same columns - // all unnesting results are written to temp_batch + // Original batch has the same columns + // All unnesting results are written to temp_batch for depth in (1..=max_recursion).rev() { let input = match depth == max_recursion { true => batch.columns(), @@ -593,11 +593,11 @@ fn build_batch( .map(|(order, unnest_def)| (*unnest_def, order)) .collect(); - // one original column may be unnested multiple times into separate columns + // One original column may be unnested multiple times into separate columns let mut multi_unnested_per_original_index = unnested_array_map .into_iter() .map( - // each item in unnested_columns is the result of unnesting the same input column + // Each item in unnested_columns is the result of unnesting the same input column // we need to sort them to conform with the original expression order // e.g unnest(unnest(col)) must goes before unnest(col) |(original_index, mut unnested_columns)| { @@ -636,7 +636,7 @@ fn build_batch( .into_iter() .enumerate() .flat_map(|(col_idx, arr)| { - // convert original column into its unnested version(s) + // Convert original column into its unnested version(s) // Plural because one column can be unnested with different recursion level // and into separate output columns match multi_unnested_per_original_index.remove(&col_idx) { diff --git a/datafusion/physical-plan/src/values.rs b/datafusion/physical-plan/src/values.rs index ab5b45463b0c..991146d245a7 100644 --- a/datafusion/physical-plan/src/values.rs +++ b/datafusion/physical-plan/src/values.rs @@ -47,7 +47,7 @@ pub struct ValuesExec { } impl ValuesExec { - /// create a new values exec from data as expr + /// Create a new values exec from data as expr pub fn try_new( schema: SchemaRef, data: Vec>>, @@ -57,7 +57,7 @@ impl ValuesExec { } let n_row = data.len(); let n_col = schema.fields().len(); - // we have this single row batch as a placeholder to satisfy evaluation argument + // We have this single row batch as a placeholder to satisfy evaluation argument // and generate a single output row let batch = RecordBatch::try_new_with_options( Arc::new(Schema::empty()), @@ -126,7 +126,7 @@ impl ValuesExec { }) } - /// provides the data + /// Provides the data pub fn data(&self) -> Vec { self.data.clone() } diff --git a/datafusion/physical-plan/src/work_table.rs b/datafusion/physical-plan/src/work_table.rs index ba95640a87c7..61d444171cc7 100644 --- a/datafusion/physical-plan/src/work_table.rs +++ b/datafusion/physical-plan/src/work_table.rs @@ -225,31 +225,31 @@ mod tests { #[test] fn test_work_table() { let work_table = WorkTable::new(); - // can't take from empty work_table + // Can't take from empty work_table assert!(work_table.take().is_err()); let pool = Arc::new(UnboundedMemoryPool::default()) as _; let mut reservation = MemoryConsumer::new("test_work_table").register(&pool); - // update batch to work_table + // Update batch to work_table let array: ArrayRef = Arc::new((0..5).collect::()); let batch = RecordBatch::try_from_iter(vec![("col", array)]).unwrap(); reservation.try_grow(100).unwrap(); work_table.update(ReservedBatches::new(vec![batch.clone()], reservation)); - // take from work_table + // Take from work_table let reserved_batches = work_table.take().unwrap(); assert_eq!(reserved_batches.batches, vec![batch.clone()]); - // consume the batch by the MemoryStream + // Consume the batch by the MemoryStream let memory_stream = MemoryStream::try_new(reserved_batches.batches, batch.schema(), None) .unwrap() .with_reservation(reserved_batches.reservation); - // should still be reserved + // Should still be reserved assert_eq!(pool.reserved(), 100); - // the reservation should be freed after drop the memory_stream + // The reservation should be freed after drop the memory_stream drop(memory_stream); assert_eq!(pool.reserved(), 0); } From cfe05b85cf75e5b0ddd7bd866ddca48b2de16738 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Wed, 23 Oct 2024 14:24:14 +0800 Subject: [PATCH 052/110] improve the condition checking for unparsing table_scan (#13062) * improve the condition check * fix tests * remove unnecessary clone --- datafusion/sql/src/unparser/plan.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 77f885c1de5f..037748035fbf 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -677,10 +677,10 @@ impl Unparser<'_> { // // Example: // select t1.c1 from t1 where t1.c1 > 1 -> select a.c1 from t1 as a where a.c1 > 1 - if alias.is_some() - && (table_scan.projection.is_some() || !table_scan.filters.is_empty()) - { - builder = builder.alias(alias.clone().unwrap())?; + if let Some(ref alias) = alias { + if table_scan.projection.is_some() || !table_scan.filters.is_empty() { + builder = builder.alias(alias.clone())?; + } } if let Some(project_vec) = &table_scan.projection { @@ -733,10 +733,10 @@ impl Unparser<'_> { // So we will append the alias to this subquery. // Example: // select * from t1 limit 10 -> (select * from t1 limit 10) as a - if alias.is_some() - && (table_scan.projection.is_none() && table_scan.filters.is_empty()) - { - builder = builder.alias(alias.clone().unwrap())?; + if let Some(alias) = alias { + if table_scan.projection.is_none() && table_scan.filters.is_empty() { + builder = builder.alias(alias)?; + } } Ok(Some(builder.build()?)) From d2a5e27fda0f9a613b6dc1e2b3289c1f981c71ac Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Wed, 23 Oct 2024 14:49:02 +0800 Subject: [PATCH 053/110] minor: simplify associated item bound of `hash_array_primitive` (#13070) --- datafusion/common/src/hash_utils.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index 72cfeafd0bfe..8bd646626e06 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -102,8 +102,7 @@ fn hash_array_primitive( hashes_buffer: &mut [u64], rehash: bool, ) where - T: ArrowPrimitiveType, - ::Native: HashValue, + T: ArrowPrimitiveType, { assert_eq!( hashes_buffer.len(), From 3aa9714aac78c934c3b3f982f7a3ba38f52e40e8 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Wed, 23 Oct 2024 09:10:43 +0200 Subject: [PATCH 054/110] Run optimzer rules on subqueries by default (#13066) This patch makes it so that rules the configure an `apply_order` will also include subqueries in their traversel. This is a step twoards being able to run TPC-DS q41 (#4763) which has an expressions that needs simplification before we can decorrelate the subquery. This closes #3770 and maybe #2480 --- datafusion/optimizer/src/optimizer.rs | 14 +++------- .../sqllogictest/test_files/explain.slt | 2 -- .../sqllogictest/test_files/subquery.slt | 28 +++++++++---------- 3 files changed, 18 insertions(+), 26 deletions(-) diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 373c87718789..90a790a0e841 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -28,7 +28,7 @@ use log::{debug, warn}; use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; use datafusion_common::instant::Instant; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::{internal_err, DFSchema, DataFusionError, Result}; use datafusion_expr::logical_plan::LogicalPlan; @@ -250,10 +250,6 @@ impl Optimizer { Arc::new(DecorrelatePredicateSubquery::new()), Arc::new(ScalarSubqueryToJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), - // simplify expressions does not simplify expressions in subqueries, so we - // run it again after running the optimizations that potentially converted - // subqueries to joins - Arc::new(SimplifyExpressions::new()), Arc::new(EliminateDuplicatedExpr::new()), Arc::new(EliminateFilter::new()), Arc::new(EliminateCrossJoin::new()), @@ -384,11 +380,9 @@ impl Optimizer { let result = match rule.apply_order() { // optimizer handles recursion - Some(apply_order) => new_plan.rewrite(&mut Rewriter::new( - apply_order, - rule.as_ref(), - config, - )), + Some(apply_order) => new_plan.rewrite_with_subqueries( + &mut Rewriter::new(apply_order, rule.as_ref(), config), + ), // rule handles recursion itself None => optimize_plan_node(new_plan, rule.as_ref(), config), } diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 54340604ad40..1340fd490e06 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -188,7 +188,6 @@ logical_plan after eliminate_join SAME TEXT AS ABOVE logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE -logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE logical_plan after eliminate_filter SAME TEXT AS ABOVE logical_plan after eliminate_cross_join SAME TEXT AS ABOVE @@ -214,7 +213,6 @@ logical_plan after eliminate_join SAME TEXT AS ABOVE logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE -logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE logical_plan after eliminate_filter SAME TEXT AS ABOVE logical_plan after eliminate_cross_join SAME TEXT AS ABOVE diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 22857dd285c2..ab6dc3a9e588 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -391,7 +391,7 @@ logical_plan 01)Filter: EXISTS () 02)--Subquery: 03)----Projection: t1.t1_int -04)------Filter: t1.t1_id > t1.t1_int +04)------Filter: t1.t1_int < t1.t1_id 05)--------TableScan: t1 06)--TableScan: t1 projection=[t1_id, t1_name, t1_int] @@ -462,8 +462,8 @@ explain SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int limit 1 logical_plan 01)Projection: t1.t1_id, () AS t2_int 02)--Subquery: -03)----Limit: skip=0, fetch=1 -04)------Projection: t2.t2_int +03)----Projection: t2.t2_int +04)------Limit: skip=0, fetch=1 05)--------Filter: t2.t2_int = outer_ref(t1.t1_int) 06)----------TableScan: t2 07)--TableScan: t1 projection=[t1_id, t1_int] @@ -475,8 +475,8 @@ logical_plan 01)Projection: t1.t1_id 02)--Filter: t1.t1_int = () 03)----Subquery: -04)------Limit: skip=0, fetch=1 -05)--------Projection: t2.t2_int +04)------Projection: t2.t2_int +05)--------Limit: skip=0, fetch=1 06)----------Filter: t2.t2_int = outer_ref(t1.t1_int) 07)------------TableScan: t2 08)----TableScan: t1 projection=[t1_id, t1_int] @@ -542,13 +542,13 @@ query TT explain SELECT t0_id, t0_name FROM t0 WHERE EXISTS (SELECT 1 FROM t1 INNER JOIN t2 ON(t1.t1_id = t2.t2_id and t1.t1_name = t0.t0_name)) ---- logical_plan -01)Filter: EXISTS () -02)--Subquery: -03)----Projection: Int64(1) -04)------Inner Join: Filter: t1.t1_id = t2.t2_id AND t1.t1_name = outer_ref(t0.t0_name) -05)--------TableScan: t1 -06)--------TableScan: t2 -07)--TableScan: t0 projection=[t0_id, t0_name] +01)LeftSemi Join: t0.t0_name = __correlated_sq_2.t1_name +02)--TableScan: t0 projection=[t0_id, t0_name] +03)--SubqueryAlias: __correlated_sq_2 +04)----Projection: t1.t1_name +05)------Inner Join: t1.t1_id = t2.t2_id +06)--------TableScan: t1 projection=[t1_id, t1_name] +07)--------TableScan: t2 projection=[t2_id] #subquery_contains_join_contains_correlated_columns query TT @@ -656,8 +656,8 @@ explain SELECT t1_id, t1_name FROM t1 WHERE t1_id in (SELECT t2_id FROM t2 where logical_plan 01)Filter: t1.t1_id IN () 02)--Subquery: -03)----Limit: skip=0, fetch=10 -04)------Projection: t2.t2_id +03)----Projection: t2.t2_id +04)------Limit: skip=0, fetch=10 05)--------Filter: outer_ref(t1.t1_name) = t2.t2_name 06)----------TableScan: t2 07)--TableScan: t1 projection=[t1_id, t1_name] From 521966a249cd86c42b9bf52b7179a64973b7dbfa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20=C5=9Een?= Date: Wed, 23 Oct 2024 11:05:45 +0300 Subject: [PATCH 055/110] extended log.rs tests for unary/binary and f32/f64 casting (#13034) * added tests * added for binary * added scalar tests --- datafusion/functions/src/math/log.rs | 181 +++++++++++++++++- datafusion/sqllogictest/test_files/scalar.slt | 31 +++ 2 files changed, 211 insertions(+), 1 deletion(-) diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index f82c0df34e27..9d2e1be3df9d 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -261,13 +261,192 @@ mod tests { use super::*; - use arrow::array::{Float32Array, Float64Array}; + use arrow::array::{Float32Array, Float64Array, Int64Array}; use arrow::compute::SortOptions; use datafusion_common::cast::{as_float32_array, as_float64_array}; use datafusion_common::DFSchema; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::simplify::SimplifyContext; + #[test] + #[should_panic] + fn test_log_invalid_base_type() { + let args = [ + ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 10.0, 100.0, 1000.0, 10000.0, + ]))), // num + ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))), + ]; + + let _ = LogFunc::new().invoke(&args); + } + + #[test] + fn test_log_invalid_value() { + let args = [ + ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num + ]; + + let result = LogFunc::new().invoke(&args); + result.expect_err("expected error"); + } + + #[test] + fn test_log_scalar_f32_unary() { + let args = [ + ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num + ]; + + let result = LogFunc::new() + .invoke(&args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float32_array(&arr) + .expect("failed to convert result to a Float32Array"); + + assert_eq!(floats.len(), 1); + assert!((floats.value(0) - 1.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_scalar_f64_unary() { + let args = [ + ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num + ]; + + let result = LogFunc::new() + .invoke(&args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 1); + assert!((floats.value(0) - 1.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_scalar_f32() { + let args = [ + ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num + ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num + ]; + + let result = LogFunc::new() + .invoke(&args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float32_array(&arr) + .expect("failed to convert result to a Float32Array"); + + assert_eq!(floats.len(), 1); + assert!((floats.value(0) - 5.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_scalar_f64() { + let args = [ + ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num + ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num + ]; + + let result = LogFunc::new() + .invoke(&args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 1); + assert!((floats.value(0) - 6.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_f64_unary() { + let args = [ + ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 10.0, 100.0, 1000.0, 10000.0, + ]))), // num + ]; + + let result = LogFunc::new() + .invoke(&args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 4); + assert!((floats.value(0) - 1.0).abs() < 1e-10); + assert!((floats.value(1) - 2.0).abs() < 1e-10); + assert!((floats.value(2) - 3.0).abs() < 1e-10); + assert!((floats.value(3) - 4.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_f32_unary() { + let args = [ + ColumnarValue::Array(Arc::new(Float32Array::from(vec![ + 10.0, 100.0, 1000.0, 10000.0, + ]))), // num + ]; + + let result = LogFunc::new() + .invoke(&args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float32_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 4); + assert!((floats.value(0) - 1.0).abs() < 1e-10); + assert!((floats.value(1) - 2.0).abs() < 1e-10); + assert!((floats.value(2) - 3.0).abs() < 1e-10); + assert!((floats.value(3) - 4.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + #[test] fn test_log_f64() { let args = [ diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index d510206b1930..145172f31fd7 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -536,6 +536,37 @@ select log(a, 64) a, log(b), log(10, b) from signed_integers; NaN 2 2 NaN 4 4 +# log overloaded base 10 float64 and float32 casting scalar +query RR rowsort +select log(arrow_cast(10, 'Float64')) a ,log(arrow_cast(100, 'Float32')) b; +---- +1 2 + +# log overloaded base 10 float64 and float32 casting with columns +query RR rowsort +select log(arrow_cast(a, 'Float64')), log(arrow_cast(b, 'Float32')) from signed_integers; +---- +0.301029995664 NaN +0.602059991328 NULL +NaN 2 +NaN 4 + +# log float64 and float32 casting scalar +query RR rowsort +select log(2,arrow_cast(8, 'Float64')) a, log(2,arrow_cast(16, 'Float32')) b; +---- +3 4 + +# log float64 and float32 casting with columns +query RR rowsort +select log(2,arrow_cast(a, 'Float64')), log(4,arrow_cast(b, 'Float32')) from signed_integers; +---- +1 NaN +2 NULL +NaN 3.321928 +NaN 6.643856 + + ## log10 # log10 scalar function From a4e6b075b8fe448e0f1fde948626d7c43c1d9c4a Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Wed, 23 Oct 2024 07:00:20 -0400 Subject: [PATCH 056/110] feat: Convert CumeDist to UDWF (#13051) * Transferred cumedist * fixes * remove expr tests * small fix * small fix * check * clippy fix * roundtrip fix --- .../expr/src/built_in_window_function.rs | 6 - datafusion/expr/src/expr.rs | 23 +-- datafusion/expr/src/window_function.rs | 5 - datafusion/functions-window/src/cume_dist.rs | 170 ++++++++++++++++++ datafusion/functions-window/src/lib.rs | 3 + .../physical-expr/src/expressions/mod.rs | 1 - .../physical-expr/src/window/cume_dist.rs | 145 --------------- datafusion/physical-expr/src/window/mod.rs | 1 - datafusion/physical-plan/src/windows/mod.rs | 3 +- datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/generated/pbjson.rs | 3 - datafusion/proto/src/generated/prost.rs | 4 +- .../proto/src/logical_plan/from_proto.rs | 1 - datafusion/proto/src/logical_plan/to_proto.rs | 1 - .../proto/src/physical_plan/to_proto.rs | 10 +- .../tests/cases/roundtrip_logical_plan.rs | 3 +- datafusion/sqllogictest/test_files/window.slt | 4 +- .../user-guide/sql/window_functions_new.md | 9 + 18 files changed, 195 insertions(+), 199 deletions(-) create mode 100644 datafusion/functions-window/src/cume_dist.rs delete mode 100644 datafusion/physical-expr/src/window/cume_dist.rs diff --git a/datafusion/expr/src/built_in_window_function.rs b/datafusion/expr/src/built_in_window_function.rs index 2c70a07a4e15..36916a6b594f 100644 --- a/datafusion/expr/src/built_in_window_function.rs +++ b/datafusion/expr/src/built_in_window_function.rs @@ -40,8 +40,6 @@ impl fmt::Display for BuiltInWindowFunction { /// [Window Function]: https://en.wikipedia.org/wiki/Window_function_(SQL) #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)] pub enum BuiltInWindowFunction { - /// Relative rank of the current row: (number of rows preceding or peer with current row) / (total rows) - CumeDist, /// Integer ranging from 1 to the argument value, dividing the partition as equally as possible Ntile, /// returns value evaluated at the row that is the first row of the window frame @@ -56,7 +54,6 @@ impl BuiltInWindowFunction { pub fn name(&self) -> &str { use BuiltInWindowFunction::*; match self { - CumeDist => "CUME_DIST", Ntile => "NTILE", FirstValue => "first_value", LastValue => "last_value", @@ -69,7 +66,6 @@ impl FromStr for BuiltInWindowFunction { type Err = DataFusionError; fn from_str(name: &str) -> Result { Ok(match name.to_uppercase().as_str() { - "CUME_DIST" => BuiltInWindowFunction::CumeDist, "NTILE" => BuiltInWindowFunction::Ntile, "FIRST_VALUE" => BuiltInWindowFunction::FirstValue, "LAST_VALUE" => BuiltInWindowFunction::LastValue, @@ -102,7 +98,6 @@ impl BuiltInWindowFunction { match self { BuiltInWindowFunction::Ntile => Ok(DataType::UInt64), - BuiltInWindowFunction::CumeDist => Ok(DataType::Float64), BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue | BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()), @@ -113,7 +108,6 @@ impl BuiltInWindowFunction { pub fn signature(&self) -> Signature { // Note: The physical expression must accept the type returned by this function or the execution panics. match self { - BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable), BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => { Signature::any(1, Volatility::Immutable) } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 691b65d34443..7fadf6391bf3 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -2567,15 +2567,6 @@ mod test { Ok(()) } - #[test] - fn test_cume_dist_return_type() -> Result<()> { - let fun = find_df_window_func("cume_dist").unwrap(); - let observed = fun.return_type(&[], &[], "")?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - #[test] fn test_ntile_return_type() -> Result<()> { let fun = find_df_window_func("ntile").unwrap(); @@ -2587,13 +2578,7 @@ mod test { #[test] fn test_window_function_case_insensitive() -> Result<()> { - let names = vec![ - "cume_dist", - "ntile", - "first_value", - "last_value", - "nth_value", - ]; + let names = vec!["ntile", "first_value", "last_value", "nth_value"]; for name in names { let fun = find_df_window_func(name).unwrap(); let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap(); @@ -2609,12 +2594,6 @@ mod test { #[test] fn test_find_df_window_function() { - assert_eq!( - find_df_window_func("cume_dist"), - Some(WindowFunctionDefinition::BuiltInWindowFunction( - built_in_window_function::BuiltInWindowFunction::CumeDist - )) - ); assert_eq!( find_df_window_func("first_value"), Some(WindowFunctionDefinition::BuiltInWindowFunction( diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index 3e1870c59c15..c13a028e4a30 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -17,11 +17,6 @@ use crate::{expr::WindowFunction, BuiltInWindowFunction, Expr, Literal}; -/// Create an expression to represent the `cume_dist` window function -pub fn cume_dist() -> Expr { - Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::CumeDist, vec![])) -} - /// Create an expression to represent the `ntile` window function pub fn ntile(arg: Expr) -> Expr { Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::Ntile, vec![arg])) diff --git a/datafusion/functions-window/src/cume_dist.rs b/datafusion/functions-window/src/cume_dist.rs new file mode 100644 index 000000000000..9e30c672fee5 --- /dev/null +++ b/datafusion/functions-window/src/cume_dist.rs @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! `cume_dist` window function implementation + +use datafusion_common::arrow::array::{ArrayRef, Float64Array}; +use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::Field; +use datafusion_common::Result; +use datafusion_expr::window_doc_sections::DOC_SECTION_RANKING; +use datafusion_expr::{ + Documentation, PartitionEvaluator, Signature, Volatility, WindowUDFImpl, +}; +use datafusion_functions_window_common::field; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use field::WindowUDFFieldArgs; +use std::any::Any; +use std::fmt::Debug; +use std::iter; +use std::ops::Range; +use std::sync::{Arc, OnceLock}; + +define_udwf_and_expr!( + CumeDist, + cume_dist, + "Calculates the cumulative distribution of a value in a group of values." +); + +/// CumeDist calculates the cume_dist in the window function with order by +#[derive(Debug)] +pub struct CumeDist { + signature: Signature, +} + +impl CumeDist { + pub fn new() -> Self { + Self { + signature: Signature::any(0, Volatility::Immutable), + } + } +} + +impl Default for CumeDist { + fn default() -> Self { + Self::new() + } +} + +impl WindowUDFImpl for CumeDist { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "cume_dist" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + Ok(Box::::default()) + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::Float64, false)) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_cume_dist_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_cume_dist_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_RANKING) + .with_description( + "Relative rank of the current row: (number of rows preceding or peer with current row) / (total rows).", + ) + .with_syntax_example("cume_dist()") + .build() + .unwrap() + }) +} + +#[derive(Debug, Default)] +pub(crate) struct CumeDistEvaluator; + +impl PartitionEvaluator for CumeDistEvaluator { + /// Computes the cumulative distribution for all rows in the partition + fn evaluate_all_with_rank( + &self, + num_rows: usize, + ranks_in_partition: &[Range], + ) -> Result { + let scalar = num_rows as f64; + let result = Float64Array::from_iter_values( + ranks_in_partition + .iter() + .scan(0_u64, |acc, range| { + let len = range.end - range.start; + *acc += len as u64; + let value: f64 = (*acc as f64) / scalar; + let result = iter::repeat(value).take(len); + Some(result) + }) + .flatten(), + ); + Ok(Arc::new(result)) + } + + fn include_rank(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_common::cast::as_float64_array; + + fn test_f64_result( + num_rows: usize, + ranks: Vec>, + expected: Vec, + ) -> Result<()> { + let evaluator = CumeDistEvaluator; + let result = evaluator.evaluate_all_with_rank(num_rows, &ranks)?; + let result = as_float64_array(&result)?; + let result = result.values().to_vec(); + assert_eq!(expected, result); + Ok(()) + } + + #[test] + #[allow(clippy::single_range_in_vec_init)] + fn test_cume_dist() -> Result<()> { + test_f64_result(0, vec![], vec![])?; + + test_f64_result(1, vec![0..1], vec![1.0])?; + + test_f64_result(2, vec![0..2], vec![1.0, 1.0])?; + + test_f64_result(4, vec![0..2, 2..4], vec![0.5, 0.5, 1.0, 1.0])?; + + Ok(()) + } +} diff --git a/datafusion/functions-window/src/lib.rs b/datafusion/functions-window/src/lib.rs index 5a2aafa2892e..13a77977d579 100644 --- a/datafusion/functions-window/src/lib.rs +++ b/datafusion/functions-window/src/lib.rs @@ -32,6 +32,7 @@ use datafusion_expr::WindowUDF; #[macro_use] pub mod macros; +pub mod cume_dist; pub mod lead_lag; pub mod rank; @@ -40,6 +41,7 @@ mod utils; /// Fluent-style API for creating `Expr`s pub mod expr_fn { + pub use super::cume_dist::cume_dist; pub use super::lead_lag::lag; pub use super::lead_lag::lead; pub use super::rank::{dense_rank, percent_rank, rank}; @@ -49,6 +51,7 @@ pub mod expr_fn { /// Returns all default window functions pub fn all_default_window_functions() -> Vec> { vec![ + cume_dist::cume_dist_udwf(), row_number::row_number_udwf(), lead_lag::lead_udwf(), lead_lag::lag_udwf(), diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 54b8aafdb4da..63047f6929c1 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -35,7 +35,6 @@ mod unknown_column; /// Module with some convenient methods used in expression building pub use crate::aggregate::stats::StatsType; -pub use crate::window::cume_dist::{cume_dist, CumeDist}; pub use crate::window::nth_value::NthValue; pub use crate::window::ntile::Ntile; pub use crate::PhysicalSortExpr; diff --git a/datafusion/physical-expr/src/window/cume_dist.rs b/datafusion/physical-expr/src/window/cume_dist.rs deleted file mode 100644 index 9720187ea83d..000000000000 --- a/datafusion/physical-expr/src/window/cume_dist.rs +++ /dev/null @@ -1,145 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expression for `cume_dist` that can evaluated -//! at runtime during query execution - -use crate::window::BuiltInWindowFunctionExpr; -use crate::PhysicalExpr; -use arrow::array::ArrayRef; -use arrow::array::Float64Array; -use arrow::datatypes::{DataType, Field}; -use datafusion_common::Result; -use datafusion_expr::PartitionEvaluator; -use std::any::Any; -use std::iter; -use std::ops::Range; -use std::sync::Arc; - -/// CumeDist calculates the cume_dist in the window function with order by -#[derive(Debug)] -pub struct CumeDist { - name: String, - /// Output data type - data_type: DataType, -} - -/// Create a cume_dist window function -pub fn cume_dist(name: String, data_type: &DataType) -> CumeDist { - CumeDist { - name, - data_type: data_type.clone(), - } -} - -impl BuiltInWindowFunctionExpr for CumeDist { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - let nullable = false; - Ok(Field::new(self.name(), self.data_type.clone(), nullable)) - } - - fn expressions(&self) -> Vec> { - vec![] - } - - fn name(&self) -> &str { - &self.name - } - - fn create_evaluator(&self) -> Result> { - Ok(Box::new(CumeDistEvaluator {})) - } -} - -#[derive(Debug)] -pub(crate) struct CumeDistEvaluator; - -impl PartitionEvaluator for CumeDistEvaluator { - fn evaluate_all_with_rank( - &self, - num_rows: usize, - ranks_in_partition: &[Range], - ) -> Result { - let scalar = num_rows as f64; - let result = Float64Array::from_iter_values( - ranks_in_partition - .iter() - .scan(0_u64, |acc, range| { - let len = range.end - range.start; - *acc += len as u64; - let value: f64 = (*acc as f64) / scalar; - let result = iter::repeat(value).take(len); - Some(result) - }) - .flatten(), - ); - Ok(Arc::new(result)) - } - - fn include_rank(&self) -> bool { - true - } -} - -#[cfg(test)] -mod tests { - use super::*; - use datafusion_common::cast::as_float64_array; - - fn test_i32_result( - expr: &CumeDist, - num_rows: usize, - ranks: Vec>, - expected: Vec, - ) -> Result<()> { - let result = expr - .create_evaluator()? - .evaluate_all_with_rank(num_rows, &ranks)?; - let result = as_float64_array(&result)?; - let result = result.values(); - assert_eq!(expected, *result); - Ok(()) - } - - #[test] - #[allow(clippy::single_range_in_vec_init)] - fn test_cume_dist() -> Result<()> { - let r = cume_dist("arr".into(), &DataType::Float64); - - let expected = vec![0.0; 0]; - test_i32_result(&r, 0, vec![], expected)?; - - let expected = vec![1.0; 1]; - test_i32_result(&r, 1, vec![0..1], expected)?; - - let expected = vec![1.0; 2]; - test_i32_result(&r, 2, vec![0..2], expected)?; - - let expected = vec![0.5, 0.5, 1.0, 1.0]; - test_i32_result(&r, 4, vec![0..2, 2..4], expected)?; - - let expected = vec![0.25, 0.5, 0.75, 1.0]; - test_i32_result(&r, 4, vec![0..1, 1..2, 2..3, 3..4], expected)?; - - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/window/mod.rs b/datafusion/physical-expr/src/window/mod.rs index c0fe3c2933a7..7bab4dbc5af6 100644 --- a/datafusion/physical-expr/src/window/mod.rs +++ b/datafusion/physical-expr/src/window/mod.rs @@ -18,7 +18,6 @@ mod aggregate; mod built_in; mod built_in_window_function_expr; -pub(crate) mod cume_dist; pub(crate) mod nth_value; pub(crate) mod ntile; mod sliding_aggregate; diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index f6902fcbe2e7..39ff71496e21 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -21,7 +21,7 @@ use std::borrow::Borrow; use std::sync::Arc; use crate::{ - expressions::{cume_dist, Literal, NthValue, Ntile, PhysicalSortExpr}, + expressions::{Literal, NthValue, Ntile, PhysicalSortExpr}, ExecutionPlan, ExecutionPlanProperties, InputOrderMode, PhysicalExpr, }; @@ -219,7 +219,6 @@ fn create_built_in_window_expr( let out_data_type: &DataType = input_schema.field_with_name(&name)?.data_type(); Ok(match fun { - BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name, out_data_type)), BuiltInWindowFunction::Ntile => { let n = get_scalar_value_from_args(args, 0)?.ok_or_else(|| { DataFusionError::Execution( diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index a15fa2c5f9c6..c92328278e83 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -511,7 +511,7 @@ enum BuiltInWindowFunction { // RANK = 1; // DENSE_RANK = 2; // PERCENT_RANK = 3; - CUME_DIST = 4; + // CUME_DIST = 4; NTILE = 5; // LAG = 6; // LEAD = 7; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index d223e3646b51..ca331cdaa513 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -1662,7 +1662,6 @@ impl serde::Serialize for BuiltInWindowFunction { { let variant = match self { Self::Unspecified => "UNSPECIFIED", - Self::CumeDist => "CUME_DIST", Self::Ntile => "NTILE", Self::FirstValue => "FIRST_VALUE", Self::LastValue => "LAST_VALUE", @@ -1679,7 +1678,6 @@ impl<'de> serde::Deserialize<'de> for BuiltInWindowFunction { { const FIELDS: &[&str] = &[ "UNSPECIFIED", - "CUME_DIST", "NTILE", "FIRST_VALUE", "LAST_VALUE", @@ -1725,7 +1723,6 @@ impl<'de> serde::Deserialize<'de> for BuiltInWindowFunction { { match value { "UNSPECIFIED" => Ok(BuiltInWindowFunction::Unspecified), - "CUME_DIST" => Ok(BuiltInWindowFunction::CumeDist), "NTILE" => Ok(BuiltInWindowFunction::Ntile), "FIRST_VALUE" => Ok(BuiltInWindowFunction::FirstValue), "LAST_VALUE" => Ok(BuiltInWindowFunction::LastValue), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 6b234be57a92..fb0b3bcb2c13 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1818,7 +1818,7 @@ pub enum BuiltInWindowFunction { /// RANK = 1; /// DENSE_RANK = 2; /// PERCENT_RANK = 3; - CumeDist = 4, + /// CUME_DIST = 4; Ntile = 5, /// LAG = 6; /// LEAD = 7; @@ -1834,7 +1834,6 @@ impl BuiltInWindowFunction { pub fn as_str_name(&self) -> &'static str { match self { Self::Unspecified => "UNSPECIFIED", - Self::CumeDist => "CUME_DIST", Self::Ntile => "NTILE", Self::FirstValue => "FIRST_VALUE", Self::LastValue => "LAST_VALUE", @@ -1845,7 +1844,6 @@ impl BuiltInWindowFunction { pub fn from_str_name(value: &str) -> ::core::option::Option { match value { "UNSPECIFIED" => Some(Self::Unspecified), - "CUME_DIST" => Some(Self::CumeDist), "NTILE" => Some(Self::Ntile), "FIRST_VALUE" => Some(Self::FirstValue), "LAST_VALUE" => Some(Self::LastValue), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 99b11939e95b..4587c090c96a 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -152,7 +152,6 @@ impl From for BuiltInWindowFunction { match built_in_function { protobuf::BuiltInWindowFunction::Unspecified => todo!(), protobuf::BuiltInWindowFunction::FirstValue => Self::FirstValue, - protobuf::BuiltInWindowFunction::CumeDist => Self::CumeDist, protobuf::BuiltInWindowFunction::Ntile => Self::Ntile, protobuf::BuiltInWindowFunction::NthValue => Self::NthValue, protobuf::BuiltInWindowFunction::LastValue => Self::LastValue, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index a34a220e490c..dce0cd741fd3 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -128,7 +128,6 @@ impl From<&BuiltInWindowFunction> for protobuf::BuiltInWindowFunction { BuiltInWindowFunction::LastValue => Self::LastValue, BuiltInWindowFunction::NthValue => Self::NthValue, BuiltInWindowFunction::Ntile => Self::Ntile, - BuiltInWindowFunction::CumeDist => Self::CumeDist, } } } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 33eca0723103..37ea6a2b47be 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -23,8 +23,8 @@ use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, InListExpr, IsNotNullExpr, - IsNullExpr, Literal, NegativeExpr, NotExpr, NthValue, Ntile, TryCastExpr, + BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, + Literal, NegativeExpr, NotExpr, NthValue, Ntile, TryCastExpr, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -108,9 +108,9 @@ pub fn serialize_physical_window_expr( let expr = built_in_window_expr.get_built_in_func_expr(); let built_in_fn_expr = expr.as_any(); - let builtin_fn = if built_in_fn_expr.downcast_ref::().is_some() { - protobuf::BuiltInWindowFunction::CumeDist - } else if let Some(ntile_expr) = built_in_fn_expr.downcast_ref::() { + let builtin_fn = if let Some(ntile_expr) = + built_in_fn_expr.downcast_ref::() + { args.insert( 0, Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some( diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index c017395d979f..a8c82ff80f23 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -48,7 +48,7 @@ use datafusion::functions_aggregate::expr_fn::{ use datafusion::functions_aggregate::min_max::max_udaf; use datafusion::functions_nested::map::map; use datafusion::functions_window::expr_fn::{ - dense_rank, lag, lead, percent_rank, rank, row_number, + cume_dist, dense_rank, lag, lead, percent_rank, rank, row_number, }; use datafusion::functions_window::rank::rank_udwf; use datafusion::prelude::*; @@ -940,6 +940,7 @@ async fn roundtrip_expr_api() -> Result<()> { vec![lit(1), lit(2), lit(3)], vec![lit(10), lit(20), lit(30)], ), + cume_dist(), row_number(), rank(), dense_rank(), diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index b3f2786d3dba..51e859275512 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -1135,8 +1135,8 @@ SELECT query IRR SELECT c8, - CUME_DIST() OVER(ORDER BY c9) as cd1, - CUME_DIST() OVER(ORDER BY c9 ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as cd2 + cume_dist() OVER(ORDER BY c9) as cd1, + cume_dist() OVER(ORDER BY c9 ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as cd2 FROM aggregate_test_100 ORDER BY c8 LIMIT 5 diff --git a/docs/source/user-guide/sql/window_functions_new.md b/docs/source/user-guide/sql/window_functions_new.md index 462fc900d139..89ce2284a70c 100644 --- a/docs/source/user-guide/sql/window_functions_new.md +++ b/docs/source/user-guide/sql/window_functions_new.md @@ -157,11 +157,20 @@ All [aggregate functions](aggregate_functions.md) can be used as window function ## Ranking Functions +- [cume_dist](#cume_dist) - [dense_rank](#dense_rank) - [percent_rank](#percent_rank) - [rank](#rank) - [row_number](#row_number) +### `cume_dist` + +Relative rank of the current row: (number of rows preceding or peer with current row) / (total rows). + +``` +cume_dist() +``` + ### `dense_rank` Returns the rank of the current row without gaps. This function ranks rows in a dense manner, meaning consecutive ranks are assigned even for identical values. From 211e76ec0a2b9f0a62aa526dc059fb19c0fa0486 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Wed, 23 Oct 2024 13:01:12 +0200 Subject: [PATCH 057/110] Fix check_not_null_constraints null detection (#13033) * Fix function name typo * Fix check_not_null_constraints null detection `check_not_null_constraints` (aka `check_not_null_contraits`) checked for null using `Array::null_count` which does not return real null count. * Drop assertor dependency --- .../physical-plan/src/execution_plan.rs | 151 +++++++++++++++++- 1 file changed, 144 insertions(+), 7 deletions(-) diff --git a/datafusion/physical-plan/src/execution_plan.rs b/datafusion/physical-plan/src/execution_plan.rs index a89e265ad2f8..e6484452d43e 100644 --- a/datafusion/physical-plan/src/execution_plan.rs +++ b/datafusion/physical-plan/src/execution_plan.rs @@ -21,6 +21,7 @@ use std::sync::Arc; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; +use arrow_array::Array; use futures::stream::{StreamExt, TryStreamExt}; use tokio::task::JoinSet; @@ -852,7 +853,7 @@ pub fn execute_input_stream( Ok(Box::pin(RecordBatchStreamAdapter::new( sink_schema, input_stream - .map(move |batch| check_not_null_contraits(batch?, &risky_columns)), + .map(move |batch| check_not_null_constraints(batch?, &risky_columns)), ))) } } @@ -872,7 +873,7 @@ pub fn execute_input_stream( /// This function iterates over the specified column indices and ensures that none /// of the columns contain null values. If any column contains null values, an error /// is returned. -pub fn check_not_null_contraits( +pub fn check_not_null_constraints( batch: RecordBatch, column_indices: &Vec, ) -> Result { @@ -885,7 +886,13 @@ pub fn check_not_null_contraits( ); } - if batch.column(index).null_count() > 0 { + if batch + .column(index) + .logical_nulls() + .map(|nulls| nulls.null_count()) + .unwrap_or_default() + > 0 + { return exec_err!( "Invalid batch column at '{}' has null but schema specifies non-nullable", index @@ -920,11 +927,11 @@ pub enum CardinalityEffect { #[cfg(test)] mod tests { use super::*; + use arrow_array::{DictionaryArray, Int32Array, NullArray, RunArray}; + use arrow_schema::{DataType, Field, Schema, SchemaRef}; use std::any::Any; use std::sync::Arc; - use arrow_schema::{Schema, SchemaRef}; - use datafusion_common::{Result, Statistics}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; @@ -1068,6 +1075,136 @@ mod tests { fn use_execution_plan_as_trait_object(plan: &dyn ExecutionPlan) { let _ = plan.name(); } -} -// pub mod test; + #[test] + fn test_check_not_null_constraints_accept_non_null() -> Result<()> { + check_not_null_constraints( + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])), + vec![Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)]))], + )?, + &vec![0], + )?; + Ok(()) + } + + #[test] + fn test_check_not_null_constraints_reject_null() -> Result<()> { + let result = check_not_null_constraints( + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])), + vec![Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]))], + )?, + &vec![0], + ); + assert!(result.is_err()); + assert_starts_with( + result.err().unwrap().message().as_ref(), + "Invalid batch column at '0' has null but schema specifies non-nullable", + ); + Ok(()) + } + + #[test] + fn test_check_not_null_constraints_with_run_end_array() -> Result<()> { + // some null value inside REE array + let run_ends = Int32Array::from(vec![1, 2, 3, 4]); + let values = Int32Array::from(vec![Some(0), None, Some(1), None]); + let run_end_array = RunArray::try_new(&run_ends, &values)?; + let result = check_not_null_constraints( + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new( + "a", + run_end_array.data_type().to_owned(), + true, + )])), + vec![Arc::new(run_end_array)], + )?, + &vec![0], + ); + assert!(result.is_err()); + assert_starts_with( + result.err().unwrap().message().as_ref(), + "Invalid batch column at '0' has null but schema specifies non-nullable", + ); + Ok(()) + } + + #[test] + fn test_check_not_null_constraints_with_dictionary_array_with_null() -> Result<()> { + let values = Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(4)])); + let keys = Int32Array::from(vec![0, 1, 2, 3]); + let dictionary = DictionaryArray::new(keys, values); + let result = check_not_null_constraints( + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new( + "a", + dictionary.data_type().to_owned(), + true, + )])), + vec![Arc::new(dictionary)], + )?, + &vec![0], + ); + assert!(result.is_err()); + assert_starts_with( + result.err().unwrap().message().as_ref(), + "Invalid batch column at '0' has null but schema specifies non-nullable", + ); + Ok(()) + } + + #[test] + fn test_check_not_null_constraints_with_dictionary_masking_null() -> Result<()> { + // some null value marked out by dictionary array + let values = Arc::new(Int32Array::from(vec![ + Some(1), + None, // this null value is masked by dictionary keys + Some(3), + Some(4), + ])); + let keys = Int32Array::from(vec![0, /*1,*/ 2, 3]); + let dictionary = DictionaryArray::new(keys, values); + check_not_null_constraints( + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new( + "a", + dictionary.data_type().to_owned(), + true, + )])), + vec![Arc::new(dictionary)], + )?, + &vec![0], + )?; + Ok(()) + } + + #[test] + fn test_check_not_null_constraints_on_null_type() -> Result<()> { + // null value of Null type + let result = check_not_null_constraints( + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Null, true)])), + vec![Arc::new(NullArray::new(3))], + )?, + &vec![0], + ); + assert!(result.is_err()); + assert_starts_with( + result.err().unwrap().message().as_ref(), + "Invalid batch column at '0' has null but schema specifies non-nullable", + ); + Ok(()) + } + + fn assert_starts_with(actual: impl AsRef, expected_prefix: impl AsRef) { + let actual = actual.as_ref(); + let expected_prefix = expected_prefix.as_ref(); + assert!( + actual.starts_with(expected_prefix), + "Expected '{}' to start with '{}'", + actual, + expected_prefix + ); + } +} From 7a40344abb12966b1dc60ecfd9f23ed2484449b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Wed, 23 Oct 2024 19:14:56 +0200 Subject: [PATCH 058/110] Update list of TPC-DS queries (#13075) --- datafusion/core/benches/sql_planner.rs | 6 ++---- datafusion/core/tests/tpcds_planning.rs | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index e7c35c8d86d6..09f05c70fec6 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -203,10 +203,8 @@ fn criterion_benchmark(c: &mut Criterion) { let tpcds_ctx = register_defs(SessionContext::new(), tpcds_schemas()); - // 10, 35: Physical plan does not support logical expression Exists() - // 45: Physical plan does not support logical expression () - // 41: Optimizing disjunctions not supported - let ignored = [10, 35, 41, 45]; + // 41: check_analyzed_plan: Correlated column is not allowed in predicate + let ignored = [41]; let raw_tpcds_sql_queries = (1..100) .filter(|q| !ignored.contains(q)) diff --git a/datafusion/core/tests/tpcds_planning.rs b/datafusion/core/tests/tpcds_planning.rs index 6beb29183483..0077a2d35b1f 100644 --- a/datafusion/core/tests/tpcds_planning.rs +++ b/datafusion/core/tests/tpcds_planning.rs @@ -230,8 +230,8 @@ async fn tpcds_logical_q40() -> Result<()> { #[tokio::test] #[ignore] -// Optimizer rule 'scalar_subquery_to_join' failed: Optimizing disjunctions not supported! -// issue: https://github.com/apache/datafusion/issues/5368 +// check_analyzed_plan: Correlated column is not allowed in predicate +// issue: https://github.com/apache/datafusion/issues/13074 async fn tpcds_logical_q41() -> Result<()> { create_logical_plan(41).await } From 3e940a91285562a4351074dd63c9b8706ff8c397 Mon Sep 17 00:00:00 2001 From: wiedld Date: Wed, 23 Oct 2024 16:08:07 -0700 Subject: [PATCH 059/110] Fix logical vs physical schema mismatch for UNION where some inputs are constants (#12954) * test(12733): reproducer of when metadata from the left side is not transferred to the right side * fix(12733): because either the left or right fields may be chosen, add metadata from both to each other * test(12733): update regression test to show that fix works * Add extra test to fix other issue with schema metadata * Fix union_schema to merge metadatas for both fields and schema * fmt --------- Co-authored-by: itsjunetime --- datafusion/physical-plan/src/union.rs | 48 +++++++++++-------- datafusion/sqllogictest/src/test_context.rs | 21 ++++++-- .../sqllogictest/test_files/metadata.slt | 27 +++++++++++ 3 files changed, 72 insertions(+), 24 deletions(-) diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 108e42e7be42..433dda870def 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -468,35 +468,41 @@ pub fn can_interleave>>( } fn union_schema(inputs: &[Arc]) -> SchemaRef { - let fields: Vec = (0..inputs[0].schema().fields().len()) + let first_schema = inputs[0].schema(); + + let fields = (0..first_schema.fields().len()) .map(|i| { inputs .iter() - .filter_map(|input| { - if input.schema().fields().len() > i { - let field = input.schema().field(i).clone(); - let right_hand_metdata = inputs - .get(1) - .map(|right_input| { - right_input.schema().field(i).metadata().clone() - }) - .unwrap_or_default(); - let mut metadata = field.metadata().clone(); - metadata.extend(right_hand_metdata); - Some(field.with_metadata(metadata)) - } else { - None - } + .enumerate() + .map(|(input_idx, input)| { + let field = input.schema().field(i).clone(); + let mut metadata = field.metadata().clone(); + + let other_metadatas = inputs + .iter() + .enumerate() + .filter(|(other_idx, _)| *other_idx != input_idx) + .flat_map(|(_, other_input)| { + other_input.schema().field(i).metadata().clone().into_iter() + }); + + metadata.extend(other_metadatas); + field.with_metadata(metadata) }) - .find_or_first(|f| f.is_nullable()) + .find_or_first(Field::is_nullable) + // We can unwrap this because if inputs was empty, this would've already panic'ed when we + // indexed into inputs[0]. .unwrap() }) + .collect::>(); + + let all_metadata_merged = inputs + .iter() + .flat_map(|i| i.schema().metadata().clone().into_iter()) .collect(); - Arc::new(Schema::new_with_metadata( - fields, - inputs[0].schema().metadata().clone(), - )) + Arc::new(Schema::new_with_metadata(fields, all_metadata_merged)) } /// CombinedRecordBatchStream can be used to combine a Vec of SendableRecordBatchStreams into one diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index 2143b3089ee5..deeacb1b8819 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -319,17 +319,27 @@ pub async fn register_metadata_tables(ctx: &SessionContext) { String::from("metadata_key"), String::from("the l_name field"), )])); + let ts = Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), false) .with_metadata(HashMap::from([( String::from("metadata_key"), String::from("ts non-nullable field"), )])); - let schema = - Schema::new(vec![id, name, l_name, ts]).with_metadata(HashMap::from([( + let nonnull_name = + Field::new("nonnull_name", DataType::Utf8, false).with_metadata(HashMap::from([ + ( + String::from("metadata_key"), + String::from("the nonnull_name field"), + ), + ])); + + let schema = Schema::new(vec![id, name, l_name, ts, nonnull_name]).with_metadata( + HashMap::from([( String::from("metadata_key"), String::from("the entire schema"), - )])); + )]), + ); let batch = RecordBatch::try_new( Arc::new(schema), @@ -342,6 +352,11 @@ pub async fn register_metadata_tables(ctx: &SessionContext) { 1599572549190855123, 1599572549190855123, ])) as _, + Arc::new(StringArray::from(vec![ + Some("no_foo"), + Some("no_bar"), + Some("no_baz"), + ])) as _, ], ) .unwrap(); diff --git a/datafusion/sqllogictest/test_files/metadata.slt b/datafusion/sqllogictest/test_files/metadata.slt index 588a36e3d515..8f787254c096 100644 --- a/datafusion/sqllogictest/test_files/metadata.slt +++ b/datafusion/sqllogictest/test_files/metadata.slt @@ -123,7 +123,34 @@ ORDER BY id, name, l_name; NULL bar NULL NULL NULL l_bar +# Regression test: missing field metadata from left side of the union when right side is chosen +query T +select name from ( + SELECT nonnull_name as name FROM "table_with_metadata" + UNION ALL + SELECT NULL::string as name +) group by name order by name; +---- +no_bar +no_baz +no_foo +NULL +# Regression test: missing schema metadata from union when schema with metadata isn't the first one +# and also ensure it works fine with multiple unions +query T +select name from ( + SELECT NULL::string as name + UNION ALL + SELECT nonnull_name as name FROM "table_with_metadata" + UNION ALL + SELECT NULL::string as name +) group by name order by name; +---- +no_bar +no_baz +no_foo +NULL query P rowsort SELECT ts From de526a9e93f1ef38213eaa4d6a348ef7d7343f19 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 24 Oct 2024 02:09:03 +0200 Subject: [PATCH 060/110] Improve CSE stats (#13080) --- datafusion/common/src/cse.rs | 46 ++++++++++++++++++++++++------------ 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/datafusion/common/src/cse.rs b/datafusion/common/src/cse.rs index 453ae26e7333..ab02915858cd 100644 --- a/datafusion/common/src/cse.rs +++ b/datafusion/common/src/cse.rs @@ -121,9 +121,17 @@ impl<'n, N: HashNode> Identifier<'n, N> { /// ``` type IdArray<'n, N> = Vec<(usize, Option>)>; -/// A map that contains the number of normal and conditional occurrences of [`TreeNode`]s -/// by their identifiers. -type NodeStats<'n, N> = HashMap, (usize, usize)>; +#[derive(PartialEq, Eq)] +/// How many times a node is evaluated. A node can be considered common if evaluated +/// surely at least 2 times or surely only once but also conditionally. +enum NodeEvaluation { + SurelyOnce, + ConditionallyAtLeastOnce, + Common, +} + +/// A map that contains the evaluation stats of [`TreeNode`]s by their identifiers. +type NodeStats<'n, N> = HashMap, NodeEvaluation>; /// A map that contains the common [`TreeNode`]s and their alias by their identifiers, /// extracted during the second, rewriting traversal. @@ -331,16 +339,24 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController> TreeNodeVisito self.id_array[down_index].0 = self.up_index; if is_valid && !self.controller.is_ignored(node) { self.id_array[down_index].1 = Some(node_id); - let (count, conditional_count) = - self.node_stats.entry(node_id).or_insert((0, 0)); - if self.conditional { - *conditional_count += 1; - } else { - *count += 1; - } - if *count > 1 || (*count == 1 && *conditional_count > 0) { - self.found_common = true; - } + self.node_stats + .entry(node_id) + .and_modify(|evaluation| { + if *evaluation == NodeEvaluation::SurelyOnce + || *evaluation == NodeEvaluation::ConditionallyAtLeastOnce + && !self.conditional + { + *evaluation = NodeEvaluation::Common; + self.found_common = true; + } + }) + .or_insert_with(|| { + if self.conditional { + NodeEvaluation::ConditionallyAtLeastOnce + } else { + NodeEvaluation::SurelyOnce + } + }); } self.visit_stack .push(VisitRecord::NodeItem(node_id, is_valid)); @@ -383,8 +399,8 @@ impl> TreeNodeRewriter // Handle nodes with identifiers only if let Some(node_id) = node_id { - let (count, conditional_count) = self.node_stats.get(&node_id).unwrap(); - if *count > 1 || *count == 1 && *conditional_count > 0 { + let evaluation = self.node_stats.get(&node_id).unwrap(); + if *evaluation == NodeEvaluation::Common { // step index to skip all sub-node (which has smaller series number). while self.down_index < self.id_array.len() && self.id_array[self.down_index].0 < up_index From 18b2aaa04d956475ddc74fbbde3725370e2d7bde Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Thu, 24 Oct 2024 08:21:48 +0800 Subject: [PATCH 061/110] Infer data type from schema for `Values` and add struct coercion to `coalesce` (#12864) * first draft Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * add values table without schema Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * rm unused import Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * use option instead of vec Signed-off-by: jayzhan211 * Fix clippy * add values back and rename Signed-off-by: jayzhan211 * invalid query Signed-off-by: jayzhan211 * use values if no schema Signed-off-by: jayzhan211 * add doc Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 Co-authored-by: Andrew Lamb --- datafusion-cli/Cargo.lock | 1 + datafusion/common/src/dfschema.rs | 1 - datafusion/expr-common/Cargo.toml | 1 + .../expr-common/src/type_coercion/binary.rs | 90 ++++++++++++++++- datafusion/expr/src/logical_plan/builder.rs | 97 +++++++++++++++++-- datafusion/functions-nested/src/make_array.rs | 64 +++--------- datafusion/functions/src/core/coalesce.rs | 7 +- datafusion/proto/src/logical_plan/mod.rs | 1 + datafusion/sql/src/planner.rs | 15 +++ datafusion/sql/src/statement.rs | 11 ++- datafusion/sql/src/values.rs | 15 ++- datafusion/sqllogictest/test_files/array.slt | 1 - .../test_files/create_external_table.slt | 1 - datafusion/sqllogictest/test_files/ddl.slt | 6 ++ .../sqllogictest/test_files/group_by.slt | 12 +-- datafusion/sqllogictest/test_files/joins.slt | 33 ++++--- datafusion/sqllogictest/test_files/struct.slt | 94 +++++++++++++++--- .../sqllogictest/test_files/subquery.slt | 40 +++++--- datafusion/sqllogictest/test_files/unnest.slt | 2 +- datafusion/sqllogictest/test_files/window.slt | 2 +- 20 files changed, 368 insertions(+), 126 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 401f203dd931..24649832b27e 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1359,6 +1359,7 @@ version = "42.1.0" dependencies = [ "arrow", "datafusion-common", + "itertools", "paste", ] diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 9a1fe9bba267..aa2d93989da1 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -315,7 +315,6 @@ impl DFSchema { None => self_unqualified_names.contains(field.name().as_str()), }; if !duplicated_field { - // self.inner.fields.push(field.clone()); schema_builder.push(Arc::clone(field)); qualifiers.push(qualifier.cloned()); } diff --git a/datafusion/expr-common/Cargo.toml b/datafusion/expr-common/Cargo.toml index 7e477efc4ebc..de11b19c3b06 100644 --- a/datafusion/expr-common/Cargo.toml +++ b/datafusion/expr-common/Cargo.toml @@ -40,4 +40,5 @@ path = "src/lib.rs" [dependencies] arrow = { workspace = true } datafusion-common = { workspace = true } +itertools = { workspace = true } paste = "^1.0" diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index 887586f4f783..2f806bf76d16 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -28,7 +28,10 @@ use arrow::datatypes::{ DataType, Field, FieldRef, Fields, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; -use datafusion_common::{exec_datafusion_err, plan_datafusion_err, plan_err, Result}; +use datafusion_common::{ + exec_datafusion_err, exec_err, internal_err, plan_datafusion_err, plan_err, Result, +}; +use itertools::Itertools; /// The type signature of an instantiation of binary operator expression such as /// `lhs + rhs` @@ -372,6 +375,8 @@ impl From<&DataType> for TypeCategory { /// decimal precision and scale when coercing decimal types. /// /// This function doesn't preserve correct field name and nullability for the struct type, we only care about data type. +/// +/// Returns Option because we might want to continue on the code even if the data types are not coercible to the common type pub fn type_union_resolution(data_types: &[DataType]) -> Option { if data_types.is_empty() { return None; @@ -529,6 +534,89 @@ fn type_union_resolution_coercion( } } +/// Handle type union resolution including struct type and others. +pub fn try_type_union_resolution(data_types: &[DataType]) -> Result> { + let err = match try_type_union_resolution_with_struct(data_types) { + Ok(struct_types) => return Ok(struct_types), + Err(e) => Some(e), + }; + + if let Some(new_type) = type_union_resolution(data_types) { + Ok(vec![new_type; data_types.len()]) + } else { + exec_err!("Fail to find the coerced type, errors: {:?}", err) + } +} + +// Handle struct where we only change the data type but preserve the field name and nullability. +// Since field name is the key of the struct, so it shouldn't be updated to the common column name like "c0" or "c1" +pub fn try_type_union_resolution_with_struct( + data_types: &[DataType], +) -> Result> { + let mut keys_string: Option = None; + for data_type in data_types { + if let DataType::Struct(fields) = data_type { + let keys = fields.iter().map(|f| f.name().to_owned()).join(","); + if let Some(ref k) = keys_string { + if *k != keys { + return exec_err!("Expect same keys for struct type but got mismatched pair {} and {}", *k, keys); + } + } else { + keys_string = Some(keys); + } + } else { + return exec_err!("Expect to get struct but got {}", data_type); + } + } + + let mut struct_types: Vec = if let DataType::Struct(fields) = &data_types[0] + { + fields.iter().map(|f| f.data_type().to_owned()).collect() + } else { + return internal_err!("Struct type is checked is the previous function, so this should be unreachable"); + }; + + for data_type in data_types.iter().skip(1) { + if let DataType::Struct(fields) = data_type { + let incoming_struct_types: Vec = + fields.iter().map(|f| f.data_type().to_owned()).collect(); + // The order of field is verified above + for (lhs_type, rhs_type) in + struct_types.iter_mut().zip(incoming_struct_types.iter()) + { + if let Some(coerced_type) = + type_union_resolution_coercion(lhs_type, rhs_type) + { + *lhs_type = coerced_type; + } else { + return exec_err!( + "Fail to find the coerced type for {} and {}", + lhs_type, + rhs_type + ); + } + } + } else { + return exec_err!("Expect to get struct but got {}", data_type); + } + } + + let mut final_struct_types = vec![]; + for s in data_types { + let mut new_fields = vec![]; + if let DataType::Struct(fields) = s { + for (i, f) in fields.iter().enumerate() { + let field = Arc::unwrap_or_clone(Arc::clone(f)) + .with_data_type(struct_types[i].to_owned()); + new_fields.push(Arc::new(field)); + } + } + final_struct_types.push(DataType::Struct(new_fields.into())) + } + + Ok(final_struct_types) +} + /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a /// comparison operation /// diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 21304068a8ab..d2ecd56cdc23 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -46,13 +46,15 @@ use crate::{ use super::dml::InsertOp; use super::plan::ColumnUnnestList; +use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ - get_target_functional_dependencies, internal_err, not_impl_err, plan_datafusion_err, - plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies, - Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions, + exec_err, get_target_functional_dependencies, internal_err, not_impl_err, + plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, + FunctionalDependencies, Result, ScalarValue, TableReference, ToDFSchema, + UnnestOptions, }; use datafusion_expr_common::type_coercion::binary::type_union_resolution; @@ -172,12 +174,45 @@ impl LogicalPlanBuilder { /// `value`. See the [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) /// documentation for more details. /// + /// so it's usually better to override the default names with a table alias list. + /// + /// If the values include params/binders such as $1, $2, $3, etc, then the `param_data_types` should be provided. + pub fn values(values: Vec>) -> Result { + if values.is_empty() { + return plan_err!("Values list cannot be empty"); + } + let n_cols = values[0].len(); + if n_cols == 0 { + return plan_err!("Values list cannot be zero length"); + } + for (i, row) in values.iter().enumerate() { + if row.len() != n_cols { + return plan_err!( + "Inconsistent data length across values list: got {} values in row {} but expected {}", + row.len(), + i, + n_cols + ); + } + } + + // Infer from data itself + Self::infer_data(values) + } + + /// Create a values list based relation, and the schema is inferred from data itself or table schema if provided, consuming + /// `value`. See the [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) + /// documentation for more details. + /// /// By default, it assigns the names column1, column2, etc. to the columns of a VALUES table. /// The column names are not specified by the SQL standard and different database systems do it differently, /// so it's usually better to override the default names with a table alias list. /// /// If the values include params/binders such as $1, $2, $3, etc, then the `param_data_types` should be provided. - pub fn values(mut values: Vec>) -> Result { + pub fn values_with_schema( + values: Vec>, + schema: &DFSchemaRef, + ) -> Result { if values.is_empty() { return plan_err!("Values list cannot be empty"); } @@ -196,16 +231,53 @@ impl LogicalPlanBuilder { } } - let empty_schema = DFSchema::empty(); + // Check the type of value against the schema + Self::infer_values_from_schema(values, schema) + } + + fn infer_values_from_schema( + values: Vec>, + schema: &DFSchema, + ) -> Result { + let n_cols = values[0].len(); + let mut field_types: Vec = Vec::with_capacity(n_cols); + for j in 0..n_cols { + let field_type = schema.field(j).data_type(); + for row in values.iter() { + let value = &row[j]; + let data_type = value.get_type(schema)?; + + if !data_type.equals_datatype(field_type) { + if can_cast_types(&data_type, field_type) { + } else { + return exec_err!( + "type mistmatch and can't cast to got {} and {}", + data_type, + field_type + ); + } + } + } + field_types.push(field_type.to_owned()); + } + + Self::infer_inner(values, &field_types, schema) + } + + fn infer_data(values: Vec>) -> Result { + let n_cols = values[0].len(); + let schema = DFSchema::empty(); + let mut field_types: Vec = Vec::with_capacity(n_cols); for j in 0..n_cols { let mut common_type: Option = None; for (i, row) in values.iter().enumerate() { let value = &row[j]; - let data_type = value.get_type(&empty_schema)?; + let data_type = value.get_type(&schema)?; if data_type == DataType::Null { continue; } + if let Some(prev_type) = common_type { // get common type of each column values. let data_types = vec![prev_type.clone(), data_type.clone()]; @@ -221,14 +293,22 @@ impl LogicalPlanBuilder { // since the code loop skips NULL field_types.push(common_type.unwrap_or(DataType::Null)); } + + Self::infer_inner(values, &field_types, &schema) + } + + fn infer_inner( + mut values: Vec>, + field_types: &[DataType], + schema: &DFSchema, + ) -> Result { // wrap cast if data type is not same as common type. for row in &mut values { for (j, field_type) in field_types.iter().enumerate() { if let Expr::Literal(ScalarValue::Null) = row[j] { row[j] = Expr::Literal(ScalarValue::try_from(field_type)?); } else { - row[j] = - std::mem::take(&mut row[j]).cast_to(field_type, &empty_schema)?; + row[j] = std::mem::take(&mut row[j]).cast_to(field_type, schema)?; } } } @@ -243,6 +323,7 @@ impl LogicalPlanBuilder { .collect::>(); let dfschema = DFSchema::from_unqualified_fields(fields.into(), HashMap::new())?; let schema = DFSchemaRef::new(dfschema); + Ok(Self::new(LogicalPlan::Values(Values { schema, values }))) } diff --git a/datafusion/functions-nested/src/make_array.rs b/datafusion/functions-nested/src/make_array.rs index efc14cbbe519..abd7649e9ec7 100644 --- a/datafusion/functions-nested/src/make_array.rs +++ b/datafusion/functions-nested/src/make_array.rs @@ -28,15 +28,15 @@ use arrow_array::{ use arrow_buffer::OffsetBuffer; use arrow_schema::DataType::{LargeList, List, Null}; use arrow_schema::{DataType, Field}; -use datafusion_common::{exec_err, internal_err}; use datafusion_common::{plan_err, utils::array_into_list_array_nullable, Result}; -use datafusion_expr::binary::type_union_resolution; +use datafusion_expr::binary::{ + try_type_union_resolution_with_struct, type_union_resolution, +}; use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; use datafusion_expr::TypeSignature; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; -use itertools::Itertools; use crate::utils::make_scalar_function; @@ -111,33 +111,16 @@ impl ScalarUDFImpl for MakeArray { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if let Some(new_type) = type_union_resolution(arg_types) { - // TODO: Move the logic to type_union_resolution if this applies to other functions as well - // Handle struct where we only change the data type but preserve the field name and nullability. - // Since field name is the key of the struct, so it shouldn't be updated to the common column name like "c0" or "c1" - let is_struct_and_has_same_key = are_all_struct_and_have_same_key(arg_types)?; - if is_struct_and_has_same_key { - let data_types: Vec<_> = if let DataType::Struct(fields) = &arg_types[0] { - fields.iter().map(|f| f.data_type().to_owned()).collect() - } else { - return internal_err!("Struct type is checked is the previous function, so this should be unreachable"); - }; - - let mut final_struct_types = vec![]; - for s in arg_types { - let mut new_fields = vec![]; - if let DataType::Struct(fields) = s { - for (i, f) in fields.iter().enumerate() { - let field = Arc::unwrap_or_clone(Arc::clone(f)) - .with_data_type(data_types[i].to_owned()); - new_fields.push(Arc::new(field)); - } - } - final_struct_types.push(DataType::Struct(new_fields.into())) - } - return Ok(final_struct_types); + let mut errors = vec![]; + match try_type_union_resolution_with_struct(arg_types) { + Ok(r) => return Ok(r), + Err(e) => { + errors.push(e); } + } + if let Some(new_type) = type_union_resolution(arg_types) { + // TODO: Move FixedSizeList to List in type_union_resolution if let DataType::FixedSizeList(field, _) = new_type { Ok(vec![DataType::List(field); arg_types.len()]) } else if new_type.is_null() { @@ -147,9 +130,10 @@ impl ScalarUDFImpl for MakeArray { } } else { plan_err!( - "Fail to find the valid type between {:?} for {}", + "Fail to find the valid type between {:?} for {}, errors are {:?}", arg_types, - self.name() + self.name(), + errors ) } } @@ -188,26 +172,6 @@ fn get_make_array_doc() -> &'static Documentation { }) } -fn are_all_struct_and_have_same_key(data_types: &[DataType]) -> Result { - let mut keys_string: Option = None; - for data_type in data_types { - if let DataType::Struct(fields) = data_type { - let keys = fields.iter().map(|f| f.name().to_owned()).join(","); - if let Some(ref k) = keys_string { - if *k != keys { - return exec_err!("Expect same keys for struct type but got mismatched pair {} and {}", *k, keys); - } - } else { - keys_string = Some(keys); - } - } else { - return Ok(false); - } - } - - Ok(true) -} - // Empty array is a special case that is useful for many other array functions pub(super) fn empty_array_type() -> DataType { DataType::List(Arc::new(Field::new("item", DataType::Int64, true))) diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index 15cd733a8cd6..a05f3f08232c 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -20,8 +20,8 @@ use arrow::compute::kernels::zip::zip; use arrow::compute::{and, is_not_null, is_null}; use arrow::datatypes::DataType; use datafusion_common::{exec_err, ExprSchema, Result}; +use datafusion_expr::binary::try_type_union_resolution; use datafusion_expr::scalar_doc_sections::DOC_SECTION_CONDITIONAL; -use datafusion_expr::type_coercion::binary::type_union_resolution; use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use itertools::Itertools; @@ -137,9 +137,8 @@ impl ScalarUDFImpl for CoalesceFunc { if arg_types.is_empty() { return exec_err!("coalesce must have at least one argument"); } - let new_type = type_union_resolution(arg_types) - .unwrap_or(arg_types.first().unwrap().clone()); - Ok(vec![new_type; arg_types.len()]) + + try_type_union_resolution(arg_types) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index f57910b09ade..4adbb9318d51 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -281,6 +281,7 @@ impl AsLogicalPlan for LogicalPlanNode { .collect::, _>>() .map_err(|e| e.into()) }?; + LogicalPlanBuilder::values(values)?.build() } LogicalPlanType::Projection(projection) => { diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 66e360a9ade9..072d2320fccf 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -138,6 +138,8 @@ pub struct PlannerContext { /// The joined schemas of all FROM clauses planned so far. When planning LATERAL /// FROM clauses, this should become a suffix of the `outer_query_schema`. outer_from_schema: Option, + /// The query schema defined by the table + create_table_schema: Option, } impl Default for PlannerContext { @@ -154,6 +156,7 @@ impl PlannerContext { ctes: HashMap::new(), outer_query_schema: None, outer_from_schema: None, + create_table_schema: None, } } @@ -181,6 +184,18 @@ impl PlannerContext { schema } + pub fn set_table_schema( + &mut self, + mut schema: Option, + ) -> Option { + std::mem::swap(&mut self.create_table_schema, &mut schema); + schema + } + + pub fn table_schema(&self) -> Option { + self.create_table_schema.clone() + } + // Return a clone of the outer FROM schema pub fn outer_from_schema(&self) -> Option> { self.outer_from_schema.clone() diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 60e3413b836f..29852be3bf77 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -394,13 +394,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Build column default values let column_defaults = self.build_column_defaults(&columns, planner_context)?; + + let has_columns = !columns.is_empty(); + let schema = self.build_schema(columns)?.to_dfschema_ref()?; + if has_columns { + planner_context.set_table_schema(Some(Arc::clone(&schema))); + } + match query { Some(query) => { let plan = self.query_to_plan(*query, planner_context)?; let input_schema = plan.schema(); - let plan = if !columns.is_empty() { - let schema = self.build_schema(columns)?.to_dfschema_ref()?; + let plan = if has_columns { if schema.fields().len() != input_schema.fields().len() { return plan_err!( "Mismatch: {} columns specified, but result has {} columns", @@ -447,7 +453,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } None => { - let schema = self.build_schema(columns)?.to_dfschema_ref()?; let plan = EmptyRelation { produce_one_row: false, schema, diff --git a/datafusion/sql/src/values.rs b/datafusion/sql/src/values.rs index cd33ddb3cfe7..a4001bea7dea 100644 --- a/datafusion/sql/src/values.rs +++ b/datafusion/sql/src/values.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{DFSchema, Result}; use datafusion_expr::{LogicalPlan, LogicalPlanBuilder}; @@ -31,16 +33,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { rows, } = values; - // Values should not be based on any other schema - let schema = DFSchema::empty(); + let empty_schema = Arc::new(DFSchema::empty()); let values = rows .into_iter() .map(|row| { row.into_iter() - .map(|v| self.sql_to_expr(v, &schema, planner_context)) + .map(|v| self.sql_to_expr(v, &empty_schema, planner_context)) .collect::>>() }) .collect::>>()?; - LogicalPlanBuilder::values(values)?.build() + + let schema = planner_context.table_schema().unwrap_or(empty_schema); + if schema.fields().is_empty() { + LogicalPlanBuilder::values(values)?.build() + } else { + LogicalPlanBuilder::values_with_schema(values, &schema)?.build() + } } } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 69f62057c761..bfdbfb1bcc5e 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -7288,4 +7288,3 @@ drop table values_all_empty; statement ok drop table fixed_size_col_table; - diff --git a/datafusion/sqllogictest/test_files/create_external_table.slt b/datafusion/sqllogictest/test_files/create_external_table.slt index 7dba4d01d63b..ed001cf9f84c 100644 --- a/datafusion/sqllogictest/test_files/create_external_table.slt +++ b/datafusion/sqllogictest/test_files/create_external_table.slt @@ -283,4 +283,3 @@ CREATE EXTERNAL TABLE staging.foo STORED AS parquet LOCATION '../../parquet-test # Create external table with qualified name, but no schema should error statement error DataFusion error: Error during planning: failed to resolve schema: release CREATE EXTERNAL TABLE release.bar STORED AS parquet LOCATION '../../parquet-testing/data/alltypes_plain.parquet'; - diff --git a/datafusion/sqllogictest/test_files/ddl.slt b/datafusion/sqllogictest/test_files/ddl.slt index 813f7e95adf0..3205920d7110 100644 --- a/datafusion/sqllogictest/test_files/ddl.slt +++ b/datafusion/sqllogictest/test_files/ddl.slt @@ -799,3 +799,9 @@ CREATE EXTERNAL TEMPORARY TABLE tty STORED as ARROW LOCATION '../core/tests/data statement error DataFusion error: This feature is not implemented: Temporary views not supported CREATE TEMPORARY VIEW y AS VALUES (1,2,3); + +query error DataFusion error: Schema error: No field named a\. +EXPLAIN CREATE TABLE t(a int) AS VALUES (a + a); + +statement error DataFusion error: Schema error: No field named a\. +CREATE TABLE t(a int) AS SELECT x FROM (VALUES (a)) t(x) WHERE false; \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 4f2778b5c0d1..61b3ad73cd0a 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -3360,7 +3360,8 @@ physical_plan 05)--------CoalesceBatchesExec: target_batch_size=4 06)----------RepartitionExec: partitioning=Hash([sn@0, amount@1], 8), input_partitions=8 07)------------AggregateExec: mode=Partial, gby=[sn@0 as sn, amount@1 as amount], aggr=[] -08)--------------MemoryExec: partitions=8, partition_sizes=[1, 0, 0, 0, 0, 0, 0, 0] +08)--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +09)----------------MemoryExec: partitions=1, partition_sizes=[1] query IRI SELECT s.sn, s.amount, 2*s.sn @@ -3430,9 +3431,9 @@ physical_plan 07)------------AggregateExec: mode=Partial, gby=[sn@1 as sn, amount@2 as amount], aggr=[sum(l.amount)] 08)--------------ProjectionExec: expr=[amount@1 as amount, sn@2 as sn, amount@3 as amount] 09)----------------NestedLoopJoinExec: join_type=Inner, filter=sn@0 >= sn@1 -10)------------------CoalescePartitionsExec -11)--------------------MemoryExec: partitions=8, partition_sizes=[1, 0, 0, 0, 0, 0, 0, 0] -12)------------------MemoryExec: partitions=8, partition_sizes=[1, 0, 0, 0, 0, 0, 0, 0] +10)------------------MemoryExec: partitions=1, partition_sizes=[1] +11)------------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +12)--------------------MemoryExec: partitions=1, partition_sizes=[1] query IRR SELECT r.sn, SUM(l.amount), r.amount @@ -3579,8 +3580,7 @@ physical_plan 08)--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 09)----------------ProjectionExec: expr=[zip_code@0 as zip_code, country@1 as country, sn@2 as sn, ts@3 as ts, currency@4 as currency, amount@5 as amount, sum(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@6 as sum_amount] 10)------------------BoundedWindowAggExec: wdw=[sum(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -11)--------------------CoalescePartitionsExec -12)----------------------MemoryExec: partitions=8, partition_sizes=[1, 0, 0, 0, 0, 0, 0, 0] +11)--------------------MemoryExec: partitions=1, partition_sizes=[1] query ITIPTRR diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 558a9170c7d3..af272e8f5022 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -3901,8 +3901,8 @@ SELECT * FROM ( ) AS rhs ON lhs.b=rhs.b ---- 11 1 21 1 -14 2 22 2 12 3 23 3 +14 2 22 2 15 4 24 4 query TT @@ -3922,11 +3922,12 @@ logical_plan 05)----Sort: right_table_no_nulls.b ASC NULLS LAST, fetch=10 06)------TableScan: right_table_no_nulls projection=[a, b] physical_plan -01)CoalesceBatchesExec: target_batch_size=3 -02)--HashJoinExec: mode=CollectLeft, join_type=Right, on=[(b@1, b@1)] -03)----MemoryExec: partitions=1, partition_sizes=[1] -04)----SortExec: TopK(fetch=10), expr=[b@1 ASC NULLS LAST], preserve_partitioning=[false] -05)------MemoryExec: partitions=1, partition_sizes=[1] +01)ProjectionExec: expr=[a@2 as a, b@3 as b, a@0 as a, b@1 as b] +02)--CoalesceBatchesExec: target_batch_size=3 +03)----HashJoinExec: mode=CollectLeft, join_type=Left, on=[(b@1, b@1)] +04)------SortExec: TopK(fetch=10), expr=[b@1 ASC NULLS LAST], preserve_partitioning=[false] +05)--------MemoryExec: partitions=1, partition_sizes=[1] +06)------MemoryExec: partitions=1, partition_sizes=[1] @@ -3979,10 +3980,11 @@ logical_plan 04)--SubqueryAlias: rhs 05)----TableScan: right_table_no_nulls projection=[a, b] physical_plan -01)CoalesceBatchesExec: target_batch_size=3 -02)--HashJoinExec: mode=CollectLeft, join_type=Right, on=[(b@1, b@1)] -03)----MemoryExec: partitions=1, partition_sizes=[1] -04)----MemoryExec: partitions=1, partition_sizes=[1] +01)ProjectionExec: expr=[a@2 as a, b@3 as b, a@0 as a, b@1 as b] +02)--CoalesceBatchesExec: target_batch_size=3 +03)----HashJoinExec: mode=CollectLeft, join_type=Left, on=[(b@1, b@1)] +04)------MemoryExec: partitions=1, partition_sizes=[1] +05)------MemoryExec: partitions=1, partition_sizes=[1] # Null build indices: @@ -4038,11 +4040,12 @@ logical_plan 05)----Sort: right_table_no_nulls.b ASC NULLS LAST, fetch=10 06)------TableScan: right_table_no_nulls projection=[a, b] physical_plan -01)CoalesceBatchesExec: target_batch_size=3 -02)--HashJoinExec: mode=CollectLeft, join_type=Right, on=[(b@1, b@1)] -03)----MemoryExec: partitions=1, partition_sizes=[1] -04)----SortExec: TopK(fetch=10), expr=[b@1 ASC NULLS LAST], preserve_partitioning=[false] -05)------MemoryExec: partitions=1, partition_sizes=[1] +01)ProjectionExec: expr=[a@2 as a, b@3 as b, a@0 as a, b@1 as b] +02)--CoalesceBatchesExec: target_batch_size=3 +03)----HashJoinExec: mode=CollectLeft, join_type=Left, on=[(b@1, b@1)] +04)------SortExec: TopK(fetch=10), expr=[b@1 ASC NULLS LAST], preserve_partitioning=[false] +05)--------MemoryExec: partitions=1, partition_sizes=[1] +06)------MemoryExec: partitions=1, partition_sizes=[1] # Test CROSS JOIN LATERAL syntax (planning) diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index b76c78396aed..7596b820c688 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -392,12 +392,12 @@ create table t(a struct, b struct) as valu query T select arrow_typeof([a, b]) from t; ---- -List(Field { name: "item", data_type: Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +List(Field { name: "item", data_type: Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) query ? select [a, b] from t; ---- -[{r: red, c: 1}, {r: blue, c: 2}] +[{r: red, c: 1.0}, {r: blue, c: 2.3}] statement ok drop table t; @@ -453,6 +453,27 @@ Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ statement ok drop table t; +statement ok +create table t as values({r: 'a', c: 1}), ({r: 'b', c: 2.3}); + +query ? +select * from t; +---- +{c0: a, c1: 1.0} +{c0: b, c1: 2.3} + +query T +select arrow_typeof(column1) from t; +---- +Struct([Field { name: "c0", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct([Field { name: "c0", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + +statement ok +drop table t; + +query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Float64 type +create table t as values({r: 'a', c: 1}), ({c: 2.3, r: 'b'}); + ################################## ## Test Coalesce with Struct ################################## @@ -474,13 +495,12 @@ select coalesce(s1) from t; {a: 2, b: blue} {a: 3, b: green} -# TODO: a's type should be float query T -select arrow_typeof(coalesce(s1)) from t; +select arrow_typeof(coalesce(s1, s2)) from t; ---- -Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) -Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) -Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) statement ok drop table t; @@ -495,26 +515,32 @@ CREATE TABLE t ( (row(3, 'green'), row(33.2, 'string3')) ; -# TODO: second column should not be null query ? -select coalesce(s1) from t; +select coalesce(s1, s2) from t; ---- -{a: 1, b: red} -NULL -{a: 3, b: green} +{a: 1.0, b: red} +{a: 2.2, b: string2} +{a: 3.0, b: green} + +query T +select arrow_typeof(coalesce(s1, s2)) from t; +---- +Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) statement ok drop table t; # row() with incorrect order -statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'blue' to value of Float64 type +statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'blue' to value of Float32 type create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as values (row('red', 1), row(2.3, 'blue')), (row('purple', 1), row('green', 2.3)); # out of order struct literal # TODO: This query should not fail -statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type +statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'b' to value of Int32 type create table t(a struct(r varchar, c int)) as values ({r: 'a', c: 1}), ({c: 2, r: 'b'}); ################################## @@ -529,3 +555,43 @@ select [{r: 'a', c: 1}, {r: 'b', c: 2}]; # Can't create a list of struct with different field types query error select [{r: 'a', c: 1}, {c: 2, r: 'b'}]; + +statement ok +create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as values (row('a', 1), row('b', 2.3)); + +query T +select arrow_typeof([a, b]) from t; +---- +List(Field { name: "item", data_type: Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + +statement ok +drop table t; + +# create table with different struct type is fine +statement ok +create table t(a struct(r varchar, c int), b struct(c float, r varchar)) as values (row('a', 1), row(2.3, 'b')); + +# create array with different struct type is not valid +query error +select arrow_typeof([a, b]) from t; + +statement ok +drop table t; + +statement ok +create table t(a struct(r varchar, c int, g float), b struct(r varchar, c float, g int)) as values (row('a', 1, 2.3), row('b', 2.3, 2)); + +# type of each column should not coerced but perserve as it is +query T +select arrow_typeof(a) from t; +---- +Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "g", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + +# type of each column should not coerced but perserve as it is +query T +select arrow_typeof(b) from t; +---- +Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "g", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + +statement ok +drop table t; diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index ab6dc3a9e588..6b142302a543 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -208,10 +208,12 @@ physical_plan 06)----------CoalesceBatchesExec: target_batch_size=2 07)------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 08)--------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[sum(t2.t2_int)] -09)----------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] -10)------CoalesceBatchesExec: target_batch_size=2 -11)--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 -12)----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +10)------------------MemoryExec: partitions=1, partition_sizes=[1] +11)------CoalesceBatchesExec: target_batch_size=2 +12)--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +13)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +14)------------MemoryExec: partitions=1, partition_sizes=[1] query II rowsort SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id) as t2_sum from t1 @@ -242,10 +244,12 @@ physical_plan 06)----------CoalesceBatchesExec: target_batch_size=2 07)------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 08)--------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[sum(t2.t2_int * Float64(1))] -09)----------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] -10)------CoalesceBatchesExec: target_batch_size=2 -11)--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 -12)----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +10)------------------MemoryExec: partitions=1, partition_sizes=[1] +11)------CoalesceBatchesExec: target_batch_size=2 +12)--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +13)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +14)------------MemoryExec: partitions=1, partition_sizes=[1] query IR rowsort SELECT t1_id, (SELECT sum(t2_int * 1.0) + 1 FROM t2 WHERE t2.t2_id = t1.t1_id) as t2_sum from t1 @@ -276,10 +280,12 @@ physical_plan 06)----------CoalesceBatchesExec: target_batch_size=2 07)------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 08)--------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[sum(t2.t2_int)] -09)----------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] -10)------CoalesceBatchesExec: target_batch_size=2 -11)--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 -12)----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +10)------------------MemoryExec: partitions=1, partition_sizes=[1] +11)------CoalesceBatchesExec: target_batch_size=2 +12)--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +13)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +14)------------MemoryExec: partitions=1, partition_sizes=[1] query II rowsort SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id group by t2_id, 'a') as t2_sum from t1 @@ -313,10 +319,12 @@ physical_plan 08)--------------CoalesceBatchesExec: target_batch_size=2 09)----------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 10)------------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[sum(t2.t2_int)] -11)--------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] -12)------CoalesceBatchesExec: target_batch_size=2 -13)--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 -14)----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +11)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +12)----------------------MemoryExec: partitions=1, partition_sizes=[1] +13)------CoalesceBatchesExec: target_batch_size=2 +14)--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +15)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +16)------------MemoryExec: partitions=1, partition_sizes=[1] query II rowsort SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id having sum(t2_int) < 3) as t2_sum from t1 diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt index b923e94fc819..947eb8630b52 100644 --- a/datafusion/sqllogictest/test_files/unnest.slt +++ b/datafusion/sqllogictest/test_files/unnest.slt @@ -643,7 +643,7 @@ NULL [4] [{c0: [2], c1: [[3], [4]]}] 4 [3] [{c0: [2], c1: [[3], [4]]}] NULL [4] [{c0: [2], c1: [[3], [4]]}] -## demonstrate where recursive unnest is impossible +## demonstrate where recursive unnest is impossible ## and need multiple unnesting logical plans ## e.g unnest -> field_access -> unnest query TT diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 51e859275512..95d850795772 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -5021,4 +5021,4 @@ NULL statement ok DROP TABLE t1; -## end test handle NULL of lead \ No newline at end of file +## end test handle NULL of lead From 3f3a0cfd30d145121ec8f3f9de725ecbf5a335bb Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Thu, 24 Oct 2024 10:35:18 +0800 Subject: [PATCH 062/110] feat: support arbitrary expressions in `LIMIT` plan (#13028) * feat: support arbitrary expressions in `LIMIT` clause * restore test * Fix doc * Update datafusion/optimizer/src/eliminate_limit.rs Co-authored-by: Jax Liu * Update datafusion/expr/src/expr_rewriter/mod.rs Co-authored-by: Jax Liu * Fix clippy * Disallow non-integer types --------- Co-authored-by: Jax Liu --- datafusion/core/src/physical_planner.rs | 25 +++-- .../tests/user_defined/user_defined_plan.rs | 44 ++++---- datafusion/expr/src/expr_rewriter/mod.rs | 9 +- datafusion/expr/src/logical_plan/builder.rs | 19 +++- datafusion/expr/src/logical_plan/display.rs | 6 +- datafusion/expr/src/logical_plan/mod.rs | 4 +- datafusion/expr/src/logical_plan/plan.rs | 105 +++++++++++++++--- datafusion/expr/src/logical_plan/tree_node.rs | 28 ++++- .../optimizer/src/analyzer/type_coercion.rs | 37 +++++- datafusion/optimizer/src/decorrelate.rs | 15 +-- datafusion/optimizer/src/eliminate_limit.rs | 20 ++-- datafusion/optimizer/src/push_down_limit.rs | 54 ++++----- datafusion/proto/src/logical_plan/mod.rs | 23 +++- datafusion/sql/src/query.rs | 92 +++------------ datafusion/sql/src/unparser/plan.rs | 15 +-- datafusion/sql/tests/cases/plan_to_sql.rs | 2 +- datafusion/sqllogictest/test_files/select.slt | 67 ++++++++++- .../substrait/src/logical_plan/consumer.rs | 4 +- .../substrait/src/logical_plan/producer.rs | 15 ++- 19 files changed, 376 insertions(+), 208 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 918ebccbeb70..4a5c156e28ac 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -29,13 +29,12 @@ use crate::error::{DataFusionError, Result}; use crate::execution::context::{ExecutionProps, SessionState}; use crate::logical_expr::utils::generate_sort_key; use crate::logical_expr::{ - Aggregate, EmptyRelation, Join, Projection, Sort, TableScan, Unnest, Window, + Aggregate, EmptyRelation, Join, Projection, Sort, TableScan, Unnest, Values, Window, }; use crate::logical_expr::{ Expr, LogicalPlan, Partitioning as LogicalPartitioning, PlanType, Repartition, UserDefinedLogicalNode, }; -use crate::logical_expr::{Limit, Values}; use crate::physical_expr::{create_physical_expr, create_physical_exprs}; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::analyze::AnalyzeExec; @@ -78,8 +77,8 @@ use datafusion_expr::expr::{ use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ - DescribeTable, DmlStatement, Extension, Filter, JoinType, RecursiveQuery, SortExpr, - StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, + DescribeTable, DmlStatement, Extension, FetchType, Filter, JoinType, RecursiveQuery, + SkipType, SortExpr, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::expressions::Literal; @@ -796,8 +795,20 @@ impl DefaultPhysicalPlanner { } LogicalPlan::Subquery(_) => todo!(), LogicalPlan::SubqueryAlias(_) => children.one()?, - LogicalPlan::Limit(Limit { skip, fetch, .. }) => { + LogicalPlan::Limit(limit) => { let input = children.one()?; + let SkipType::Literal(skip) = limit.get_skip_type()? else { + return not_impl_err!( + "Unsupported OFFSET expression: {:?}", + limit.skip + ); + }; + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return not_impl_err!( + "Unsupported LIMIT expression: {:?}", + limit.fetch + ); + }; // GlobalLimitExec requires a single partition for input let input = if input.output_partitioning().partition_count() == 1 { @@ -806,13 +817,13 @@ impl DefaultPhysicalPlanner { // Apply a LocalLimitExec to each partition. The optimizer will also insert // a CoalescePartitionsExec between the GlobalLimitExec and LocalLimitExec if let Some(fetch) = fetch { - Arc::new(LocalLimitExec::new(input, *fetch + skip)) + Arc::new(LocalLimitExec::new(input, fetch + skip)) } else { input } }; - Arc::new(GlobalLimitExec::new(input, *skip, *fetch)) + Arc::new(GlobalLimitExec::new(input, skip, fetch)) } LogicalPlan::Unnest(Unnest { list_type_columns, diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 2b45d0ed600b..6c4e3c66e397 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -81,7 +81,7 @@ use datafusion::{ runtime_env::RuntimeEnv, }, logical_expr::{ - Expr, Extension, Limit, LogicalPlan, Sort, UserDefinedLogicalNode, + Expr, Extension, LogicalPlan, Sort, UserDefinedLogicalNode, UserDefinedLogicalNodeCore, }, optimizer::{OptimizerConfig, OptimizerRule}, @@ -98,7 +98,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::ScalarValue; use datafusion_expr::tree_node::replace_sort_expression; -use datafusion_expr::{Projection, SortExpr}; +use datafusion_expr::{FetchType, Projection, SortExpr}; use datafusion_optimizer::optimizer::ApplyOrder; use datafusion_optimizer::AnalyzerRule; @@ -361,28 +361,28 @@ impl OptimizerRule for TopKOptimizerRule { // Note: this code simply looks for the pattern of a Limit followed by a // Sort and replaces it by a TopK node. It does not handle many // edge cases (e.g multiple sort columns, sort ASC / DESC), etc. - if let LogicalPlan::Limit(Limit { - fetch: Some(fetch), - input, + let LogicalPlan::Limit(ref limit) = plan else { + return Ok(Transformed::no(plan)); + }; + let FetchType::Literal(Some(fetch)) = limit.get_fetch_type()? else { + return Ok(Transformed::no(plan)); + }; + + if let LogicalPlan::Sort(Sort { + ref expr, + ref input, .. - }) = &plan + }) = limit.input.as_ref() { - if let LogicalPlan::Sort(Sort { - ref expr, - ref input, - .. - }) = **input - { - if expr.len() == 1 { - // we found a sort with a single sort expr, replace with a a TopK - return Ok(Transformed::yes(LogicalPlan::Extension(Extension { - node: Arc::new(TopKPlanNode { - k: *fetch, - input: input.as_ref().clone(), - expr: expr[0].clone(), - }), - }))); - } + if expr.len() == 1 { + // we found a sort with a single sort expr, replace with a a TopK + return Ok(Transformed::yes(LogicalPlan::Extension(Extension { + node: Arc::new(TopKPlanNode { + k: fetch, + input: input.as_ref().clone(), + expr: expr[0].clone(), + }), + }))); } } diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 47cc947be3ca..d6d5c3e2931c 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -306,11 +306,14 @@ impl NamePreserver { /// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan pub fn new(plan: &LogicalPlan) -> Self { Self { - // The schema of Filter, Join and TableScan nodes comes from their inputs rather than - // their expressions, so there is no need to use aliases to preserve expression names. + // The expressions of these plans do not contribute to their output schema, + // so there is no need to preserve expression names to prevent a schema change. use_alias: !matches!( plan, - LogicalPlan::Filter(_) | LogicalPlan::Join(_) | LogicalPlan::TableScan(_) + LogicalPlan::Filter(_) + | LogicalPlan::Join(_) + | LogicalPlan::TableScan(_) + | LogicalPlan::Limit(_) ), } } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index d2ecd56cdc23..cef05b6f8814 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -40,7 +40,7 @@ use crate::utils::{ find_valid_equijoin_key_pair, group_window_expr_by_sort_keys, }; use crate::{ - and, binary_expr, DmlStatement, Expr, ExprSchemable, Operator, RecursiveQuery, + and, binary_expr, lit, DmlStatement, Expr, ExprSchemable, Operator, RecursiveQuery, TableProviderFilterPushDown, TableSource, WriteOp, }; @@ -512,9 +512,22 @@ impl LogicalPlanBuilder { /// `fetch` - Maximum number of rows to fetch, after skipping `skip` rows, /// if specified. pub fn limit(self, skip: usize, fetch: Option) -> Result { + let skip_expr = if skip == 0 { + None + } else { + Some(lit(skip as i64)) + }; + let fetch_expr = fetch.map(|f| lit(f as i64)); + self.limit_by_expr(skip_expr, fetch_expr) + } + + /// Limit the number of rows returned + /// + /// Similar to `limit` but uses expressions for `skip` and `fetch` + pub fn limit_by_expr(self, skip: Option, fetch: Option) -> Result { Ok(Self::new(LogicalPlan::Limit(Limit { - skip, - fetch, + skip: skip.map(Box::new), + fetch: fetch.map(Box::new), input: self.plan, }))) } diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 26d54803d403..0287846862af 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -549,11 +549,13 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { let mut object = serde_json::json!( { "Node Type": "Limit", - "Skip": skip, } ); + if let Some(s) = skip { + object["Skip"] = s.to_string().into() + }; if let Some(f) = fetch { - object["Fetch"] = serde_json::Value::Number((*f).into()); + object["Fetch"] = f.to_string().into() }; object } diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index da44cfb010d7..18ac3f2ab9cb 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -36,9 +36,9 @@ pub use ddl::{ pub use dml::{DmlStatement, WriteOp}; pub use plan::{ projection_schema, Aggregate, Analyze, ColumnUnnestList, CrossJoin, DescribeTable, - Distinct, DistinctOn, EmptyRelation, Explain, Extension, Filter, Join, + Distinct, DistinctOn, EmptyRelation, Explain, Extension, FetchType, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, - Projection, RecursiveQuery, Repartition, Sort, StringifiedPlan, Subquery, + Projection, RecursiveQuery, Repartition, SkipType, Sort, StringifiedPlan, Subquery, SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, }; pub use statement::{ diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index d8dfe7b56e40..e0aae4cb7448 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -49,7 +49,8 @@ use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence, - FunctionalDependencies, ParamValues, Result, TableReference, UnnestOptions, + FunctionalDependencies, ParamValues, Result, ScalarValue, TableReference, + UnnestOptions, }; use indexmap::IndexSet; @@ -960,11 +961,21 @@ impl LogicalPlan { .map(LogicalPlan::SubqueryAlias) } LogicalPlan::Limit(Limit { skip, fetch, .. }) => { - self.assert_no_expressions(expr)?; + let old_expr_len = skip.iter().chain(fetch.iter()).count(); + if old_expr_len != expr.len() { + return internal_err!( + "Invalid number of new Limit expressions: expected {}, got {}", + old_expr_len, + expr.len() + ); + } + // Pop order is same as the order returned by `LogicalPlan::expressions()` + let new_skip = skip.as_ref().and(expr.pop()); + let new_fetch = fetch.as_ref().and(expr.pop()); let input = self.only_input(inputs)?; Ok(LogicalPlan::Limit(Limit { - skip: *skip, - fetch: *fetch, + skip: new_skip.map(Box::new), + fetch: new_fetch.map(Box::new), input: Arc::new(input), })) } @@ -1339,7 +1350,10 @@ impl LogicalPlan { LogicalPlan::RecursiveQuery(_) => None, LogicalPlan::Subquery(_) => None, LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => input.max_rows(), - LogicalPlan::Limit(Limit { fetch, .. }) => *fetch, + LogicalPlan::Limit(limit) => match limit.get_fetch_type() { + Ok(FetchType::Literal(s)) => s, + _ => None, + }, LogicalPlan::Distinct( Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), ) => input.max_rows(), @@ -1909,16 +1923,20 @@ impl LogicalPlan { ) } }, - LogicalPlan::Limit(Limit { - ref skip, - ref fetch, - .. - }) => { + LogicalPlan::Limit(limit) => { + // Attempt to display `skip` and `fetch` as literals if possible, otherwise as expressions. + let skip_str = match limit.get_skip_type() { + Ok(SkipType::Literal(n)) => n.to_string(), + _ => limit.skip.as_ref().map_or_else(|| "None".to_string(), |x| x.to_string()), + }; + let fetch_str = match limit.get_fetch_type() { + Ok(FetchType::Literal(Some(n))) => n.to_string(), + Ok(FetchType::Literal(None)) => "None".to_string(), + _ => limit.fetch.as_ref().map_or_else(|| "None".to_string(), |x| x.to_string()) + }; write!( f, - "Limit: skip={}, fetch={}", - skip, - fetch.map_or_else(|| "None".to_string(), |x| x.to_string()) + "Limit: skip={}, fetch={}", skip_str,fetch_str, ) } LogicalPlan::Subquery(Subquery { .. }) => { @@ -2778,14 +2796,71 @@ impl PartialOrd for Extension { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct Limit { /// Number of rows to skip before fetch - pub skip: usize, + pub skip: Option>, /// Maximum number of rows to fetch, /// None means fetching all rows - pub fetch: Option, + pub fetch: Option>, /// The logical plan pub input: Arc, } +/// Different types of skip expression in Limit plan. +pub enum SkipType { + /// The skip expression is a literal value. + Literal(usize), + /// Currently only supports expressions that can be folded into constants. + UnsupportedExpr, +} + +/// Different types of fetch expression in Limit plan. +pub enum FetchType { + /// The fetch expression is a literal value. + /// `Literal(None)` means the fetch expression is not provided. + Literal(Option), + /// Currently only supports expressions that can be folded into constants. + UnsupportedExpr, +} + +impl Limit { + /// Get the skip type from the limit plan. + pub fn get_skip_type(&self) -> Result { + match self.skip.as_deref() { + Some(expr) => match *expr { + Expr::Literal(ScalarValue::Int64(s)) => { + // `skip = NULL` is equivalent to `skip = 0` + let s = s.unwrap_or(0); + if s >= 0 { + Ok(SkipType::Literal(s as usize)) + } else { + plan_err!("OFFSET must be >=0, '{}' was provided", s) + } + } + _ => Ok(SkipType::UnsupportedExpr), + }, + // `skip = None` is equivalent to `skip = 0` + None => Ok(SkipType::Literal(0)), + } + } + + /// Get the fetch type from the limit plan. + pub fn get_fetch_type(&self) -> Result { + match self.fetch.as_deref() { + Some(expr) => match *expr { + Expr::Literal(ScalarValue::Int64(Some(s))) => { + if s >= 0 { + Ok(FetchType::Literal(Some(s as usize))) + } else { + plan_err!("LIMIT must be >= 0, '{}' was provided", s) + } + } + Expr::Literal(ScalarValue::Int64(None)) => Ok(FetchType::Literal(None)), + _ => Ok(FetchType::UnsupportedExpr), + }, + None => Ok(FetchType::Literal(None)), + } + } +} + /// Removes duplicate rows from the input #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum Distinct { diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 606868e75abf..b8d7043d7746 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -43,6 +43,7 @@ use crate::{ Repartition, Sort, Subquery, SubqueryAlias, TableScan, Union, Unnest, UserDefinedLogicalNode, Values, Window, }; +use std::ops::Deref; use std::sync::Arc; use crate::expr::{Exists, InSubquery}; @@ -515,12 +516,16 @@ impl LogicalPlan { .chain(select_expr.iter()) .chain(sort_expr.iter().flatten().map(|sort| &sort.expr)) .apply_until_stop(f), + LogicalPlan::Limit(Limit { skip, fetch, .. }) => skip + .iter() + .chain(fetch.iter()) + .map(|e| e.deref()) + .apply_until_stop(f), // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Limit(_) | LogicalPlan::Statement(_) | LogicalPlan::CrossJoin(_) | LogicalPlan::Analyze(_) @@ -726,13 +731,32 @@ impl LogicalPlan { schema, })) }), + LogicalPlan::Limit(Limit { skip, fetch, input }) => { + let skip = skip.map(|e| *e); + let fetch = fetch.map(|e| *e); + map_until_stop_and_collect!( + skip.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| { + Ok(f(e)?.update_data(Some)) + }), + fetch, + fetch.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| { + Ok(f(e)?.update_data(Some)) + }) + )? + .update_data(|(skip, fetch)| { + LogicalPlan::Limit(Limit { + skip: skip.map(Box::new), + fetch: fetch.map(Box::new), + input, + }) + }) + } // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::Unnest(_) | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Limit(_) | LogicalPlan::Statement(_) | LogicalPlan::CrossJoin(_) | LogicalPlan::Analyze(_) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index e5d280289342..36b72233b5af 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -51,8 +51,9 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, - AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, Join, LogicalPlan, Operator, - Projection, ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits, + AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, Join, Limit, LogicalPlan, + Operator, Projection, ScalarUDF, Union, WindowFrame, WindowFrameBound, + WindowFrameUnits, }; /// Performs type coercion by determining the schema @@ -169,6 +170,7 @@ impl<'a> TypeCoercionRewriter<'a> { match plan { LogicalPlan::Join(join) => self.coerce_join(join), LogicalPlan::Union(union) => Self::coerce_union(union), + LogicalPlan::Limit(limit) => Self::coerce_limit(limit), _ => Ok(plan), } } @@ -230,6 +232,37 @@ impl<'a> TypeCoercionRewriter<'a> { })) } + /// Coerce the fetch and skip expression to Int64 type. + fn coerce_limit(limit: Limit) -> Result { + fn coerce_limit_expr( + expr: Expr, + schema: &DFSchema, + expr_name: &str, + ) -> Result { + let dt = expr.get_type(schema)?; + if dt.is_integer() || dt.is_null() { + expr.cast_to(&DataType::Int64, schema) + } else { + plan_err!("Expected {expr_name} to be an integer or null, but got {dt:?}") + } + } + + let empty_schema = DFSchema::empty(); + let new_fetch = limit + .fetch + .map(|expr| coerce_limit_expr(*expr, &empty_schema, "LIMIT")) + .transpose()?; + let new_skip = limit + .skip + .map(|expr| coerce_limit_expr(*expr, &empty_schema, "OFFSET")) + .transpose()?; + Ok(LogicalPlan::Limit(Limit { + input: limit.input, + fetch: new_fetch.map(Box::new), + skip: new_skip.map(Box::new), + })) + } + fn coerce_join_filter(&self, expr: Expr) -> Result { let expr_type = expr.get_type(self.schema)?; match expr_type { diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 7f918c03e3ac..baf449a045eb 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -31,7 +31,9 @@ use datafusion_common::{plan_err, Column, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::Alias; use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::utils::{conjunction, find_join_exprs, split_conjunction}; -use datafusion_expr::{expr, lit, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; +use datafusion_expr::{ + expr, lit, EmptyRelation, Expr, FetchType, LogicalPlan, LogicalPlanBuilder, +}; use datafusion_physical_expr::execution_props::ExecutionProps; /// This struct rewrite the sub query plan by pull up the correlated @@ -327,16 +329,15 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { let new_plan = match (self.exists_sub_query, self.join_filters.is_empty()) { // Correlated exist subquery, remove the limit(so that correlated expressions can pull up) - (true, false) => Transformed::yes( - if limit.fetch.filter(|limit_row| *limit_row == 0).is_some() { + (true, false) => Transformed::yes(match limit.get_fetch_type()? { + FetchType::Literal(Some(0)) => { LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, schema: Arc::clone(limit.input.schema()), }) - } else { - LogicalPlanBuilder::from((*limit.input).clone()).build()? - }, - ), + } + _ => LogicalPlanBuilder::from((*limit.input).clone()).build()?, + }), _ => Transformed::no(plan), }; if let Some(input_map) = input_expr_map { diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 25304d4ccafa..829d4c2d2217 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -20,7 +20,7 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::Result; -use datafusion_expr::logical_plan::{EmptyRelation, LogicalPlan}; +use datafusion_expr::logical_plan::{EmptyRelation, FetchType, LogicalPlan, SkipType}; use std::sync::Arc; /// Optimizer rule to replace `LIMIT 0` or `LIMIT` whose ancestor LIMIT's skip is @@ -63,8 +63,13 @@ impl OptimizerRule for EliminateLimit { > { match plan { LogicalPlan::Limit(limit) => { - if let Some(fetch) = limit.fetch { - if fetch == 0 { + // Only supports rewriting for literal fetch + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; + + if let Some(v) = fetch { + if v == 0 { return Ok(Transformed::yes(LogicalPlan::EmptyRelation( EmptyRelation { produce_one_row: false, @@ -72,11 +77,10 @@ impl OptimizerRule for EliminateLimit { }, ))); } - } else if limit.skip == 0 { - // input also can be Limit, so we should apply again. - return Ok(self - .rewrite(Arc::unwrap_or_clone(limit.input), _config) - .unwrap()); + } else if matches!(limit.get_skip_type()?, SkipType::Literal(0)) { + // If fetch is `None` and skip is 0, then Limit takes no effect and + // we can remove it. Its input also can be Limit, so we should apply again. + return self.rewrite(Arc::unwrap_or_clone(limit.input), _config); } Ok(Transformed::no(LogicalPlan::Limit(limit))) } diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 6ed77387046e..bf5ce0531e06 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -27,6 +27,7 @@ use datafusion_common::tree_node::Transformed; use datafusion_common::utils::combine_limit; use datafusion_common::Result; use datafusion_expr::logical_plan::{Join, JoinType, Limit, LogicalPlan}; +use datafusion_expr::{lit, FetchType, SkipType}; /// Optimization rule that tries to push down `LIMIT`. /// @@ -56,16 +57,27 @@ impl OptimizerRule for PushDownLimit { return Ok(Transformed::no(plan)); }; - let Limit { skip, fetch, input } = limit; + // Currently only rewrite if skip and fetch are both literals + let SkipType::Literal(skip) = limit.get_skip_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; // Merge the Parent Limit and the Child Limit. - if let LogicalPlan::Limit(child) = input.as_ref() { - let (skip, fetch) = - combine_limit(limit.skip, limit.fetch, child.skip, child.fetch); - + if let LogicalPlan::Limit(child) = limit.input.as_ref() { + let SkipType::Literal(child_skip) = child.get_skip_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; + let FetchType::Literal(child_fetch) = child.get_fetch_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; + + let (skip, fetch) = combine_limit(skip, fetch, child_skip, child_fetch); let plan = LogicalPlan::Limit(Limit { - skip, - fetch, + skip: Some(Box::new(lit(skip as i64))), + fetch: fetch.map(|f| Box::new(lit(f as i64))), input: Arc::clone(&child.input), }); @@ -75,14 +87,10 @@ impl OptimizerRule for PushDownLimit { // no fetch to push, so return the original plan let Some(fetch) = fetch else { - return Ok(Transformed::no(LogicalPlan::Limit(Limit { - skip, - fetch, - input, - }))); + return Ok(Transformed::no(LogicalPlan::Limit(limit))); }; - match Arc::unwrap_or_clone(input) { + match Arc::unwrap_or_clone(limit.input) { LogicalPlan::TableScan(mut scan) => { let rows_needed = if fetch != 0 { fetch + skip } else { 0 }; let new_fetch = scan @@ -162,8 +170,8 @@ impl OptimizerRule for PushDownLimit { .into_iter() .map(|child| { LogicalPlan::Limit(Limit { - skip: 0, - fetch: Some(fetch + skip), + skip: None, + fetch: Some(Box::new(lit((fetch + skip) as i64))), input: Arc::new(child.clone()), }) }) @@ -203,8 +211,8 @@ impl OptimizerRule for PushDownLimit { /// ``` fn make_limit(skip: usize, fetch: usize, input: Arc) -> LogicalPlan { LogicalPlan::Limit(Limit { - skip, - fetch: Some(fetch), + skip: Some(Box::new(lit(skip as i64))), + fetch: Some(Box::new(lit(fetch as i64))), input, }) } @@ -224,11 +232,7 @@ fn original_limit( fetch: usize, input: LogicalPlan, ) -> Result> { - Ok(Transformed::no(LogicalPlan::Limit(Limit { - skip, - fetch: Some(fetch), - input: Arc::new(input), - }))) + Ok(Transformed::no(make_limit(skip, fetch, Arc::new(input)))) } /// Returns the a transformed limit @@ -237,11 +241,7 @@ fn transformed_limit( fetch: usize, input: LogicalPlan, ) -> Result> { - Ok(Transformed::yes(LogicalPlan::Limit(Limit { - skip, - fetch: Some(fetch), - input: Arc::new(input), - }))) + Ok(Transformed::yes(make_limit(skip, fetch, Arc::new(input)))) } /// Adds a limit to the inputs of a join, if possible diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 4adbb9318d51..73df506397b1 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -62,13 +62,13 @@ use datafusion_expr::{ logical_plan::{ builder::project, Aggregate, CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateView, CrossJoin, DdlStatement, Distinct, - EmptyRelation, Extension, Join, JoinConstraint, Limit, Prepare, Projection, - Repartition, Sort, SubqueryAlias, TableScan, Values, Window, + EmptyRelation, Extension, Join, JoinConstraint, Prepare, Projection, Repartition, + Sort, SubqueryAlias, TableScan, Values, Window, }, DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF, SortExpr, WindowUDF, }; -use datafusion_expr::{AggregateUDF, ColumnUnnestList, Unnest}; +use datafusion_expr::{AggregateUDF, ColumnUnnestList, FetchType, SkipType, Unnest}; use self::to_proto::{serialize_expr, serialize_exprs}; use crate::logical_plan::to_proto::serialize_sorts; @@ -1265,17 +1265,28 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Limit(Limit { input, skip, fetch }) => { + LogicalPlan::Limit(limit) => { let input: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), + limit.input.as_ref(), extension_codec, )?; + let SkipType::Literal(skip) = limit.get_skip_type()? else { + return Err(proto_error( + "LogicalPlan::Limit only supports literal skip values", + )); + }; + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return Err(proto_error( + "LogicalPlan::Limit only supports literal fetch values", + )); + }; + Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Limit(Box::new( protobuf::LimitNode { input: Some(Box::new(input)), - skip: *skip as i64, + skip: skip as i64, fetch: fetch.unwrap_or(i64::MAX as usize) as i64, }, ))), diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index 54945ec43d10..842a1c0cbec1 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -19,15 +19,14 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{not_impl_err, plan_err, Constraints, Result, ScalarValue}; +use datafusion_common::{not_impl_err, Constraints, DFSchema, Result}; use datafusion_expr::expr::Sort; use datafusion_expr::{ - CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, - Operator, + CreateMemoryTable, DdlStatement, Distinct, LogicalPlan, LogicalPlanBuilder, }; use sqlparser::ast::{ Expr as SQLExpr, Offset as SQLOffset, OrderBy, OrderByExpr, Query, SelectInto, - SetExpr, Value, + SetExpr, }; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -85,35 +84,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return Ok(input); } - let skip = match skip { - Some(skip_expr) => { - let expr = self.sql_to_expr( - skip_expr.value, - input.schema(), - &mut PlannerContext::new(), - )?; - let n = get_constant_result(&expr, "OFFSET")?; - convert_usize_with_check(n, "OFFSET") - } - _ => Ok(0), - }?; - - let fetch = match fetch { - Some(limit_expr) - if limit_expr != sqlparser::ast::Expr::Value(Value::Null) => - { - let expr = self.sql_to_expr( - limit_expr, - input.schema(), - &mut PlannerContext::new(), - )?; - let n = get_constant_result(&expr, "LIMIT")?; - Some(convert_usize_with_check(n, "LIMIT")?) - } - _ => None, - }; - - LogicalPlanBuilder::from(input).limit(skip, fetch)?.build() + // skip and fetch expressions are not allowed to reference columns from the input plan + let empty_schema = DFSchema::empty(); + + let skip = skip + .map(|o| self.sql_to_expr(o.value, &empty_schema, &mut PlannerContext::new())) + .transpose()?; + let fetch = fetch + .map(|e| self.sql_to_expr(e, &empty_schema, &mut PlannerContext::new())) + .transpose()?; + LogicalPlanBuilder::from(input) + .limit_by_expr(skip, fetch)? + .build() } /// Wrap the logical in a sort @@ -159,50 +141,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } -/// Retrieves the constant result of an expression, evaluating it if possible. -/// -/// This function takes an expression and an argument name as input and returns -/// a `Result` indicating either the constant result of the expression or an -/// error if the expression cannot be evaluated. -/// -/// # Arguments -/// -/// * `expr` - An `Expr` representing the expression to evaluate. -/// * `arg_name` - The name of the argument for error messages. -/// -/// # Returns -/// -/// * `Result` - An `Ok` variant containing the constant result if evaluation is successful, -/// or an `Err` variant containing an error message if evaluation fails. -/// -/// tracks a more general solution -fn get_constant_result(expr: &Expr, arg_name: &str) -> Result { - match expr { - Expr::Literal(ScalarValue::Int64(Some(s))) => Ok(*s), - Expr::BinaryExpr(binary_expr) => { - let lhs = get_constant_result(&binary_expr.left, arg_name)?; - let rhs = get_constant_result(&binary_expr.right, arg_name)?; - let res = match binary_expr.op { - Operator::Plus => lhs + rhs, - Operator::Minus => lhs - rhs, - Operator::Multiply => lhs * rhs, - _ => return plan_err!("Unsupported operator for {arg_name} clause"), - }; - Ok(res) - } - _ => plan_err!("Unexpected expression in {arg_name} clause"), - } -} - -/// Converts an `i64` to `usize`, performing a boundary check. -fn convert_usize_with_check(n: i64, arg_name: &str) -> Result { - if n < 0 { - plan_err!("{arg_name} must be >= 0, '{n}' was provided.") - } else { - Ok(n as usize) - } -} - /// Returns the order by expressions from the query. fn to_order_by_exprs(order_by: Option) -> Result> { let Some(OrderBy { exprs, interpolate }) = order_by else { diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 037748035fbf..0147a607567b 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -343,20 +343,16 @@ impl Unparser<'_> { relation, ); } - - if let Some(fetch) = limit.fetch { + if let Some(fetch) = &limit.fetch { let Some(query) = query.as_mut() else { return internal_err!( "Limit operator only valid in a statement context." ); }; - query.limit(Some(ast::Expr::Value(ast::Value::Number( - fetch.to_string(), - false, - )))); + query.limit(Some(self.expr_to_sql(fetch)?)); } - if limit.skip > 0 { + if let Some(skip) = &limit.skip { let Some(query) = query.as_mut() else { return internal_err!( "Offset operator only valid in a statement context." @@ -364,10 +360,7 @@ impl Unparser<'_> { }; query.offset(Some(ast::Offset { rows: ast::OffsetRows::None, - value: ast::Expr::Value(ast::Value::Number( - limit.skip.to_string(), - false, - )), + value: self.expr_to_sql(skip)?, })); } diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index e7b96199511a..9ed084eec249 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -1016,7 +1016,7 @@ fn test_without_offset() { #[test] fn test_with_offset0() { - sql_round_trip(MySqlDialect {}, "select 1 offset 0", "SELECT 1"); + sql_round_trip(MySqlDialect {}, "select 1 offset 0", "SELECT 1 OFFSET 0"); } #[test] diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index 9910ca8da71f..f2ab4135aaa7 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -581,9 +581,32 @@ select * from (select 1 a union all select 2) b order by a limit 1; 1 # select limit clause invalid -statement error DataFusion error: Error during planning: LIMIT must be >= 0, '\-1' was provided\. +statement error Error during planning: LIMIT must be >= 0, '-1' was provided select * from (select 1 a union all select 2) b order by a limit -1; +statement error Error during planning: OFFSET must be >=0, '-1' was provided +select * from (select 1 a union all select 2) b order by a offset -1; + +statement error Unsupported LIMIT expression +select * from (values(1),(2)) limit (select 1); + +statement error Unsupported OFFSET expression +select * from (values(1),(2)) offset (select 1); + +# disallow non-integer limit/offset +statement error Expected LIMIT to be an integer or null, but got Float64 +select * from (values(1),(2)) limit 0.5; + +statement error Expected OFFSET to be an integer or null, but got Utf8 +select * from (values(1),(2)) offset '1'; + +# test with different integer types +query I +select * from (values (1), (2), (3), (4)) limit 2::int OFFSET 1::tinyint +---- +2 +3 + # select limit with basic arithmetic query I select * from (select 1 a union all select 2) b order by a limit 1+1; @@ -597,13 +620,38 @@ select * from (values (1)) LIMIT 10*100; ---- 1 -# More complex expressions in the limit is not supported yet. -# See issue: https://github.com/apache/datafusion/issues/9821 -statement error DataFusion error: Error during planning: Unsupported operator for LIMIT clause +# select limit with complex arithmetic +query I select * from (values (1)) LIMIT 100/10; +---- +1 -# More complex expressions in the limit is not supported yet. -statement error DataFusion error: Error during planning: Unexpected expression in LIMIT clause +# test constant-folding of LIMIT expr +query I +select * from (values (1), (2), (3), (4)) LIMIT abs(-4) + 4 / -2; -- LIMIT 2 +---- +1 +2 + +# test constant-folding of OFFSET expr +query I +select * from (values (1), (2), (3), (4)) OFFSET abs(-4) + 4 / -2; -- OFFSET 2 +---- +3 +4 + +# test constant-folding of LIMIT and OFFSET +query I +select * from (values (1), (2), (3), (4)) + -- LIMIT 2 + LIMIT abs(-4) + -1 * 2 + -- OFFSET 1 + OFFSET case when 1 < 2 then 1 else 0 end; +---- +2 +3 + +statement error Schema error: No field named column1. select * from (values (1)) LIMIT cast(column1 as tinyint); # select limit clause @@ -613,6 +661,13 @@ select * from (select 1 a union all select 2) b order by a limit null; 1 2 +# offset null takes no effect +query I +select * from (select 1 a union all select 2) b order by a offset null; +---- +1 +2 + # select limit clause query I select * from (select 1 a union all select 2) b order by a limit 0; diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 8a8d195507a2..3d5d7cce5673 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -623,8 +623,8 @@ pub async fn from_substrait_rel( from_substrait_rel(ctx, input, extensions).await?, ); let offset = fetch.offset as usize; - // Since protobuf can't directly distinguish `None` vs `0` `None` is encoded as `MAX` - let count = if fetch.count as usize == usize::MAX { + // -1 means that ALL records should be returned + let count = if fetch.count == -1 { None } else { Some(fetch.count as usize) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 7504a287c055..bb50c4b9610f 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -24,7 +24,7 @@ use substrait::proto::expression_reference::ExprType; use arrow_buffer::ToByteSlice; use datafusion::arrow::datatypes::{Field, IntervalUnit}; use datafusion::logical_expr::{ - CrossJoin, Distinct, Like, Partitioning, WindowFrameUnits, + CrossJoin, Distinct, FetchType, Like, Partitioning, SkipType, WindowFrameUnits, }; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, @@ -326,14 +326,19 @@ pub fn to_substrait_rel( } LogicalPlan::Limit(limit) => { let input = to_substrait_rel(limit.input.as_ref(), ctx, extensions)?; - // Since protobuf can't directly distinguish `None` vs `0` encode `None` as `MAX` - let limit_fetch = limit.fetch.unwrap_or(usize::MAX); + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return not_impl_err!("Non-literal limit fetch"); + }; + let SkipType::Literal(skip) = limit.get_skip_type()? else { + return not_impl_err!("Non-literal limit skip"); + }; Ok(Box::new(Rel { rel_type: Some(RelType::Fetch(Box::new(FetchRel { common: None, input: Some(input), - offset: limit.skip as i64, - count: limit_fetch as i64, + offset: skip as i64, + // use -1 to signal that ALL records should be returned + count: fetch.map(|f| f as i64).unwrap_or(-1), advanced_extension: None, }))), })) From 8adbc2324afac66d8cb88b20cb1482913b190d4b Mon Sep 17 00:00:00 2001 From: Mustafa Akur <33904309+akurmustafa@users.noreply.github.com> Date: Thu, 24 Oct 2024 00:10:09 -0700 Subject: [PATCH 063/110] [minor]: use arrow take_batch instead of get_record_batch_indices (#13084) * Initial commit * Fix linter errors * Minor changes * Fix error --- datafusion/common/src/utils/mod.rs | 24 ++++--------------- .../tests/fuzz_cases/equivalence/utils.rs | 17 ++++--------- .../src/windows/bounded_window_agg_exec.rs | 6 ++--- 3 files changed, 12 insertions(+), 35 deletions(-) diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index def1def9853c..dacf90af9bbf 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -23,16 +23,14 @@ pub mod proxy; pub mod string_utils; use crate::error::{_internal_datafusion_err, _internal_err}; -use crate::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; -use arrow::array::{ArrayRef, PrimitiveArray}; +use crate::{DataFusionError, Result, ScalarValue}; +use arrow::array::ArrayRef; use arrow::buffer::OffsetBuffer; -use arrow::compute::{partition, take_arrays, SortColumn, SortOptions}; -use arrow::datatypes::{Field, SchemaRef, UInt32Type}; -use arrow::record_batch::RecordBatch; +use arrow::compute::{partition, SortColumn, SortOptions}; +use arrow::datatypes::{Field, SchemaRef}; use arrow_array::cast::AsArray; use arrow_array::{ Array, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait, - RecordBatchOptions, }; use arrow_schema::DataType; use sqlparser::ast::Ident; @@ -92,20 +90,6 @@ pub fn get_row_at_idx(columns: &[ArrayRef], idx: usize) -> Result, -) -> Result { - let new_columns = take_arrays(record_batch.columns(), indices, None)?; - RecordBatch::try_new_with_options( - record_batch.schema(), - new_columns, - &RecordBatchOptions::new().with_row_count(Some(indices.len())), - ) - .map_err(|e| arrow_datafusion_err!(e)) -} - /// This function compares two tuples depending on the given sort options. pub fn compare_rows( x: &[ScalarValue], diff --git a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs index 61691311fe4e..acc45fe0e591 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs @@ -22,15 +22,11 @@ use std::any::Any; use std::cmp::Ordering; use std::sync::Arc; -use arrow::compute::{lexsort_to_indices, SortColumn}; +use arrow::compute::{lexsort_to_indices, take_record_batch, SortColumn}; use arrow::datatypes::{DataType, Field, Schema}; -use arrow_array::{ - ArrayRef, Float32Array, Float64Array, PrimitiveArray, RecordBatch, UInt32Array, -}; +use arrow_array::{ArrayRef, Float32Array, Float64Array, RecordBatch, UInt32Array}; use arrow_schema::{SchemaRef, SortOptions}; -use datafusion_common::utils::{ - compare_rows, get_record_batch_at_indices, get_row_at_idx, -}; +use datafusion_common::utils::{compare_rows, get_row_at_idx}; use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; @@ -465,7 +461,7 @@ pub fn generate_table_for_orderings( // Sort batch according to first ordering expression let sort_columns = get_sort_columns(&batch, &orderings[0])?; let sort_indices = lexsort_to_indices(&sort_columns, None)?; - let mut batch = get_record_batch_at_indices(&batch, &sort_indices)?; + let mut batch = take_record_batch(&batch, &sort_indices)?; // prune out rows that is invalid according to remaining orderings. for ordering in orderings.iter().skip(1) { @@ -490,10 +486,7 @@ pub fn generate_table_for_orderings( } } // Only keep valid rows, that satisfies given ordering relation. - batch = get_record_batch_at_indices( - &batch, - &PrimitiveArray::from_iter_values(keep_indices), - )?; + batch = take_record_batch(&batch, &UInt32Array::from_iter_values(keep_indices))?; } Ok(batch) diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 6254ae139a00..6495657339fa 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -40,6 +40,7 @@ use crate::{ SendableRecordBatchStream, Statistics, WindowExpr, }; use ahash::RandomState; +use arrow::compute::take_record_batch; use arrow::{ array::{Array, ArrayRef, RecordBatchOptions, UInt32Builder}, compute::{concat, concat_batches, sort_to_indices, take_arrays}, @@ -49,8 +50,7 @@ use arrow::{ use datafusion_common::hash_utils::create_hashes; use datafusion_common::stats::Precision; use datafusion_common::utils::{ - evaluate_partition_ranges, get_at_indices, get_record_batch_at_indices, - get_row_at_idx, + evaluate_partition_ranges, get_at_indices, get_row_at_idx, }; use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result}; use datafusion_execution::TaskContext; @@ -558,7 +558,7 @@ impl PartitionSearcher for LinearSearch { let mut new_indices = UInt32Builder::with_capacity(indices.len()); new_indices.append_slice(&indices); let indices = new_indices.finish(); - Ok((row, get_record_batch_at_indices(record_batch, &indices)?)) + Ok((row, take_record_batch(record_batch, &indices)?)) }) .collect() } From f2da32b3bde851c34e9df0a2f4c174a5392f8897 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20=C5=9Een?= Date: Thu, 24 Oct 2024 10:57:57 +0300 Subject: [PATCH 064/110] deprecated (#13076) --- datafusion-examples/examples/sql_analysis.rs | 6 +- datafusion/core/src/physical_planner.rs | 4 -- datafusion/expr/src/logical_plan/display.rs | 5 -- datafusion/expr/src/logical_plan/mod.rs | 4 +- datafusion/expr/src/logical_plan/plan.rs | 67 +------------------ datafusion/expr/src/logical_plan/tree_node.rs | 28 ++------ datafusion/optimizer/src/analyzer/subquery.rs | 1 - .../optimizer/src/common_subexpr_eliminate.rs | 1 - .../optimizer/src/eliminate_cross_join.rs | 35 +++------- .../optimizer/src/optimize_projections/mod.rs | 11 --- .../optimizer/src/propagate_empty_relation.rs | 13 ---- datafusion/optimizer/src/push_down_filter.rs | 66 ++---------------- datafusion/optimizer/src/push_down_limit.rs | 7 -- datafusion/proto/src/logical_plan/mod.rs | 24 +------ datafusion/sql/src/unparser/plan.rs | 38 ----------- .../substrait/src/logical_plan/producer.rs | 21 +----- 16 files changed, 32 insertions(+), 299 deletions(-) diff --git a/datafusion-examples/examples/sql_analysis.rs b/datafusion-examples/examples/sql_analysis.rs index 9a2aabaa79c2..2158b8e4b016 100644 --- a/datafusion-examples/examples/sql_analysis.rs +++ b/datafusion-examples/examples/sql_analysis.rs @@ -39,7 +39,7 @@ fn total_join_count(plan: &LogicalPlan) -> usize { // We can use the TreeNode API to walk over a LogicalPlan. plan.apply(|node| { // if we encounter a join we update the running count - if matches!(node, LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_)) { + if matches!(node, LogicalPlan::Join(_)) { total += 1; } Ok(TreeNodeRecursion::Continue) @@ -89,7 +89,7 @@ fn count_trees(plan: &LogicalPlan) -> (usize, Vec) { while let Some(node) = to_visit.pop() { // if we encounter a join, we know were at the root of the tree // count this tree and recurse on it's inputs - if matches!(node, LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_)) { + if matches!(node, LogicalPlan::Join(_)) { let (group_count, inputs) = count_tree(node); total += group_count; groups.push(group_count); @@ -151,7 +151,7 @@ fn count_tree(join: &LogicalPlan) -> (usize, Vec<&LogicalPlan>) { } // any join we count - if matches!(node, LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_)) { + if matches!(node, LogicalPlan::Join(_)) { total += 1; Ok(TreeNodeRecursion::Continue) } else { diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 4a5c156e28ac..5a4ae868d04a 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1127,10 +1127,6 @@ impl DefaultPhysicalPlanner { join } } - LogicalPlan::CrossJoin(_) => { - let [left, right] = children.two()?; - Arc::new(CrossJoinExec::new(left, right)) - } LogicalPlan::RecursiveQuery(RecursiveQuery { name, is_distinct, .. }) => { diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 0287846862af..c0549451a776 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -504,11 +504,6 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "Filter": format!("{}", filter_expr) }) } - LogicalPlan::CrossJoin(_) => { - json!({ - "Node Type": "Cross Join" - }) - } LogicalPlan::Repartition(Repartition { partitioning_scheme, .. diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index 18ac3f2ab9cb..80a896212442 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -35,8 +35,8 @@ pub use ddl::{ }; pub use dml::{DmlStatement, WriteOp}; pub use plan::{ - projection_schema, Aggregate, Analyze, ColumnUnnestList, CrossJoin, DescribeTable, - Distinct, DistinctOn, EmptyRelation, Explain, Extension, FetchType, Filter, Join, + projection_schema, Aggregate, Analyze, ColumnUnnestList, DescribeTable, Distinct, + DistinctOn, EmptyRelation, Explain, Extension, FetchType, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, Projection, RecursiveQuery, Repartition, SkipType, Sort, StringifiedPlan, Subquery, SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index e0aae4cb7448..4b42702f24bf 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -221,10 +221,6 @@ pub enum LogicalPlan { /// Join two logical plans on one or more join columns. /// This is used to implement SQL `JOIN` Join(Join), - /// Apply Cross Join to two logical plans. - /// This is used to implement SQL `CROSS JOIN` - /// Deprecated: use [LogicalPlan::Join] instead with empty `on` / no filter - CrossJoin(CrossJoin), /// Repartitions the input based on a partitioning scheme. This is /// used to add parallelism and is sometimes referred to as an /// "exchange" operator in other systems @@ -312,7 +308,6 @@ impl LogicalPlan { LogicalPlan::Aggregate(Aggregate { schema, .. }) => schema, LogicalPlan::Sort(Sort { input, .. }) => input.schema(), LogicalPlan::Join(Join { schema, .. }) => schema, - LogicalPlan::CrossJoin(CrossJoin { schema, .. }) => schema, LogicalPlan::Repartition(Repartition { input, .. }) => input.schema(), LogicalPlan::Limit(Limit { input, .. }) => input.schema(), LogicalPlan::Statement(statement) => statement.schema(), @@ -345,8 +340,7 @@ impl LogicalPlan { | LogicalPlan::Projection(_) | LogicalPlan::Aggregate(_) | LogicalPlan::Unnest(_) - | LogicalPlan::Join(_) - | LogicalPlan::CrossJoin(_) => self + | LogicalPlan::Join(_) => self .inputs() .iter() .map(|input| input.schema().as_ref()) @@ -436,7 +430,6 @@ impl LogicalPlan { LogicalPlan::Aggregate(Aggregate { input, .. }) => vec![input], LogicalPlan::Sort(Sort { input, .. }) => vec![input], LogicalPlan::Join(Join { left, right, .. }) => vec![left, right], - LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => vec![left, right], LogicalPlan::Limit(Limit { input, .. }) => vec![input], LogicalPlan::Subquery(Subquery { subquery, .. }) => vec![subquery], LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => vec![input], @@ -542,13 +535,6 @@ impl LogicalPlan { JoinType::LeftSemi | JoinType::LeftAnti => left.head_output_expr(), JoinType::RightSemi | JoinType::RightAnti => right.head_output_expr(), }, - LogicalPlan::CrossJoin(cross) => { - if cross.left.schema().fields().is_empty() { - cross.right.head_output_expr() - } else { - cross.left.head_output_expr() - } - } LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { static_term.head_output_expr() } @@ -674,20 +660,6 @@ impl LogicalPlan { null_equals_null, })) } - LogicalPlan::CrossJoin(CrossJoin { - left, - right, - schema: _, - }) => { - let join_schema = - build_join_schema(left.schema(), right.schema(), &JoinType::Inner)?; - - Ok(LogicalPlan::CrossJoin(CrossJoin { - left, - right, - schema: join_schema.into(), - })) - } LogicalPlan::Subquery(_) => Ok(self), LogicalPlan::SubqueryAlias(SubqueryAlias { input, @@ -938,11 +910,6 @@ impl LogicalPlan { null_equals_null: *null_equals_null, })) } - LogicalPlan::CrossJoin(_) => { - self.assert_no_expressions(expr)?; - let (left, right) = self.only_two_inputs(inputs)?; - LogicalPlanBuilder::from(left).cross_join(right)?.build() - } LogicalPlan::Subquery(Subquery { outer_ref_columns, .. }) => { @@ -1327,12 +1294,6 @@ impl LogicalPlan { JoinType::LeftSemi | JoinType::LeftAnti => left.max_rows(), JoinType::RightSemi | JoinType::RightAnti => right.max_rows(), }, - LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { - match (left.max_rows(), right.max_rows()) { - (Some(left_max), Some(right_max)) => Some(left_max * right_max), - _ => None, - } - } LogicalPlan::Repartition(Repartition { input, .. }) => input.max_rows(), LogicalPlan::Union(Union { inputs, .. }) => inputs .iter() @@ -1893,9 +1854,6 @@ impl LogicalPlan { } } } - LogicalPlan::CrossJoin(_) => { - write!(f, "CrossJoin:") - } LogicalPlan::Repartition(Repartition { partitioning_scheme, .. @@ -2601,28 +2559,7 @@ impl TableScan { } } -/// Apply Cross Join to two logical plans -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct CrossJoin { - /// Left input - pub left: Arc, - /// Right input - pub right: Arc, - /// The output schema, containing fields from the left and right inputs - pub schema: DFSchemaRef, -} - -// Manual implementation needed because of `schema` field. Comparison excludes this field. -impl PartialOrd for CrossJoin { - fn partial_cmp(&self, other: &Self) -> Option { - match self.left.partial_cmp(&other.left) { - Some(Ordering::Equal) => self.right.partial_cmp(&other.right), - cmp => cmp, - } - } -} - -/// Repartition the plan based on a partitioning scheme. +// Repartition the plan based on a partitioning scheme. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct Repartition { /// The incoming logical plan diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index b8d7043d7746..0658f7029740 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -37,11 +37,11 @@ //! * [`LogicalPlan::with_new_exprs`]: Create a new plan with different expressions //! * [`LogicalPlan::expressions`]: Return a copy of the plan's expressions use crate::{ - dml::CopyTo, Aggregate, Analyze, CreateMemoryTable, CreateView, CrossJoin, - DdlStatement, Distinct, DistinctOn, DmlStatement, Explain, Expr, Extension, Filter, - Join, Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, - Repartition, Sort, Subquery, SubqueryAlias, TableScan, Union, Unnest, - UserDefinedLogicalNode, Values, Window, + dml::CopyTo, Aggregate, Analyze, CreateMemoryTable, CreateView, DdlStatement, + Distinct, DistinctOn, DmlStatement, Explain, Expr, Extension, Filter, Join, Limit, + LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, Repartition, Sort, + Subquery, SubqueryAlias, TableScan, Union, Unnest, UserDefinedLogicalNode, Values, + Window, }; use std::ops::Deref; use std::sync::Arc; @@ -160,22 +160,6 @@ impl TreeNode for LogicalPlan { null_equals_null, }) }), - LogicalPlan::CrossJoin(CrossJoin { - left, - right, - schema, - }) => map_until_stop_and_collect!( - rewrite_arc(left, &mut f), - right, - rewrite_arc(right, &mut f) - )? - .update_data(|(left, right)| { - LogicalPlan::CrossJoin(CrossJoin { - left, - right, - schema, - }) - }), LogicalPlan::Limit(Limit { skip, fetch, input }) => rewrite_arc(input, f)? .update_data(|input| LogicalPlan::Limit(Limit { skip, fetch, input })), LogicalPlan::Subquery(Subquery { @@ -527,7 +511,6 @@ impl LogicalPlan { | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) | LogicalPlan::Statement(_) - | LogicalPlan::CrossJoin(_) | LogicalPlan::Analyze(_) | LogicalPlan::Explain(_) | LogicalPlan::Union(_) @@ -758,7 +741,6 @@ impl LogicalPlan { | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) | LogicalPlan::Statement(_) - | LogicalPlan::CrossJoin(_) | LogicalPlan::Analyze(_) | LogicalPlan::Explain(_) | LogicalPlan::Union(_) diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index aabc549de583..0a52685bd681 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -180,7 +180,6 @@ fn check_inner_plan( LogicalPlan::Projection(_) | LogicalPlan::Distinct(_) | LogicalPlan::Sort(_) - | LogicalPlan::CrossJoin(_) | LogicalPlan::Union(_) | LogicalPlan::TableScan(_) | LogicalPlan::EmptyRelation(_) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 921011d33fc4..ee9ae9fb15a7 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -534,7 +534,6 @@ impl OptimizerRule for CommonSubexprEliminate { LogicalPlan::Window(window) => self.try_optimize_window(window, config)?, LogicalPlan::Aggregate(agg) => self.try_optimize_aggregate(agg, config)?, LogicalPlan::Join(_) - | LogicalPlan::CrossJoin(_) | LogicalPlan::Repartition(_) | LogicalPlan::Union(_) | LogicalPlan::TableScan(_) diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 8a365fb389be..65ebac2106ad 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -98,7 +98,7 @@ impl OptimizerRule for EliminateCrossJoin { LogicalPlan::Join(Join { join_type: JoinType::Inner, .. - }) | LogicalPlan::CrossJoin(_) + }) ); if !rewriteable { @@ -241,20 +241,6 @@ fn flatten_join_inputs( all_filters, )?; } - LogicalPlan::CrossJoin(join) => { - flatten_join_inputs( - Arc::unwrap_or_clone(join.left), - possible_join_keys, - all_inputs, - all_filters, - )?; - flatten_join_inputs( - Arc::unwrap_or_clone(join.right), - possible_join_keys, - all_inputs, - all_filters, - )?; - } _ => { all_inputs.push(plan); } @@ -270,23 +256,18 @@ fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool { // can only flatten inner / cross joins match plan { LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {} - LogicalPlan::CrossJoin(_) => {} _ => return false, }; for child in plan.inputs() { - match child { - LogicalPlan::Join(Join { - join_type: JoinType::Inner, - .. - }) - | LogicalPlan::CrossJoin(_) => { - if !can_flatten_join_inputs(child) { - return false; - } + if let LogicalPlan::Join(Join { + join_type: JoinType::Inner, + .. + }) = child + { + if !can_flatten_join_inputs(child) { + return false; } - // the child is not a join/cross join - _ => (), } } true diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index b5d581f3919f..42eff7100fbe 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -367,17 +367,6 @@ fn optimize_projections( right_indices.with_projection_beneficial(), ] } - LogicalPlan::CrossJoin(cross_join) => { - let left_len = cross_join.left.schema().fields().len(); - let (left_indices, right_indices) = - split_join_requirements(left_len, indices, &JoinType::Inner); - // Joins benefit from "small" input tables (lower memory usage). - // Therefore, each child benefits from projection: - vec![ - left_indices.with_projection_beneficial(), - right_indices.with_projection_beneficial(), - ] - } // these nodes are explicitly rewritten in the match statement above LogicalPlan::Projection(_) | LogicalPlan::Aggregate(_) diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index b5e1077ee5be..d26df073dc6f 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -72,19 +72,6 @@ impl OptimizerRule for PropagateEmptyRelation { } Ok(Transformed::no(plan)) } - LogicalPlan::CrossJoin(ref join) => { - let (left_empty, right_empty) = binary_plan_children_is_empty(&plan)?; - if left_empty || right_empty { - return Ok(Transformed::yes(LogicalPlan::EmptyRelation( - EmptyRelation { - produce_one_row: false, - schema: Arc::clone(plan.schema()), - }, - ))); - } - Ok(Transformed::no(LogicalPlan::CrossJoin(join.clone()))) - } - LogicalPlan::Join(ref join) => { // TODO: For Join, more join type need to be careful: // For LeftOut/Full Join, if the right side is empty, the Join can be eliminated with a Projection with left side diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index ac81f3efaa11..a6c0a7310610 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -24,19 +24,15 @@ use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; use datafusion_common::{ - internal_err, plan_err, qualified_name, Column, DFSchema, DFSchemaRef, - JoinConstraint, Result, + internal_err, plan_err, qualified_name, Column, DFSchema, Result, }; use datafusion_expr::expr_rewriter::replace_col; -use datafusion_expr::logical_plan::{ - CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union, -}; +use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union}; use datafusion_expr::utils::{ conjunction, expr_to_columns, split_conjunction, split_conjunction_owned, }; use datafusion_expr::{ - and, build_join_schema, or, BinaryExpr, Expr, Filter, LogicalPlanBuilder, Operator, - Projection, TableProviderFilterPushDown, + and, or, BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown, }; use crate::optimizer::ApplyOrder; @@ -867,12 +863,6 @@ impl OptimizerRule for PushDownFilter { }) } LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)), - LogicalPlan::CrossJoin(cross_join) => { - let predicates = split_conjunction_owned(filter.predicate); - let join = convert_cross_join_to_inner_join(cross_join)?; - let plan = push_down_all_join(predicates, vec![], join, vec![])?; - convert_to_cross_join_if_beneficial(plan.data) - } LogicalPlan::TableScan(scan) => { let filter_predicates = split_conjunction(&filter.predicate); let results = scan @@ -1114,48 +1104,6 @@ impl PushDownFilter { } } -/// Converts the given cross join to an inner join with an empty equality -/// predicate and an empty filter condition. -fn convert_cross_join_to_inner_join(cross_join: CrossJoin) -> Result { - let CrossJoin { left, right, .. } = cross_join; - let join_schema = build_join_schema(left.schema(), right.schema(), &JoinType::Inner)?; - Ok(Join { - left, - right, - join_type: JoinType::Inner, - join_constraint: JoinConstraint::On, - on: vec![], - filter: None, - schema: DFSchemaRef::new(join_schema), - null_equals_null: false, - }) -} - -/// Converts the given inner join with an empty equality predicate and an -/// empty filter condition to a cross join. -fn convert_to_cross_join_if_beneficial( - plan: LogicalPlan, -) -> Result> { - match plan { - // Can be converted back to cross join - LogicalPlan::Join(join) if join.on.is_empty() && join.filter.is_none() => { - LogicalPlanBuilder::from(Arc::unwrap_or_clone(join.left)) - .cross_join(Arc::unwrap_or_clone(join.right))? - .build() - .map(Transformed::yes) - } - LogicalPlan::Filter(filter) => { - convert_to_cross_join_if_beneficial(Arc::unwrap_or_clone(filter.input))? - .transform_data(|child_plan| { - Filter::try_new(filter.predicate, Arc::new(child_plan)) - .map(LogicalPlan::Filter) - .map(Transformed::yes) - }) - } - plan => Ok(Transformed::no(plan)), - } -} - /// replaces columns by its name on the projection. pub fn replace_cols_by_name( e: Expr, @@ -1203,13 +1151,13 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use async_trait::async_trait; - use datafusion_common::ScalarValue; + use datafusion_common::{DFSchemaRef, ScalarValue}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{ - col, in_list, in_subquery, lit, ColumnarValue, Extension, ScalarUDF, - ScalarUDFImpl, Signature, TableSource, TableType, UserDefinedLogicalNodeCore, - Volatility, + col, in_list, in_subquery, lit, ColumnarValue, Extension, LogicalPlanBuilder, + ScalarUDF, ScalarUDFImpl, Signature, TableSource, TableType, + UserDefinedLogicalNodeCore, Volatility, }; use crate::optimizer::Optimizer; diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index bf5ce0531e06..ec7a0a1364b6 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -118,13 +118,6 @@ impl OptimizerRule for PushDownLimit { transformed_limit(skip, fetch, LogicalPlan::Union(union)) } - LogicalPlan::CrossJoin(mut cross_join) => { - // push limit to both inputs - cross_join.left = make_arc_limit(0, fetch + skip, cross_join.left); - cross_join.right = make_arc_limit(0, fetch + skip, cross_join.right); - transformed_limit(skip, fetch, LogicalPlan::CrossJoin(cross_join)) - } - LogicalPlan::Join(join) => Ok(push_down_join(join, fetch + skip) .update_data(|join| { make_limit(skip, fetch, Arc::new(LogicalPlan::Join(join))) diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 73df506397b1..d80c6b716537 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -61,9 +61,9 @@ use datafusion_expr::{ dml, logical_plan::{ builder::project, Aggregate, CreateCatalog, CreateCatalogSchema, - CreateExternalTable, CreateView, CrossJoin, DdlStatement, Distinct, - EmptyRelation, Extension, Join, JoinConstraint, Prepare, Projection, Repartition, - Sort, SubqueryAlias, TableScan, Values, Window, + CreateExternalTable, CreateView, DdlStatement, Distinct, EmptyRelation, + Extension, Join, JoinConstraint, Prepare, Projection, Repartition, Sort, + SubqueryAlias, TableScan, Values, Window, }, DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF, SortExpr, WindowUDF, @@ -1503,24 +1503,6 @@ impl AsLogicalPlan for LogicalPlanNode { )), }) } - LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { - let left = protobuf::LogicalPlanNode::try_from_logical_plan( - left.as_ref(), - extension_codec, - )?; - let right = protobuf::LogicalPlanNode::try_from_logical_plan( - right.as_ref(), - extension_codec, - )?; - Ok(protobuf::LogicalPlanNode { - logical_plan_type: Some(LogicalPlanType::CrossJoin(Box::new( - protobuf::CrossJoinNode { - left: Some(Box::new(left)), - right: Some(Box::new(right)), - }, - ))), - }) - } LogicalPlan::Extension(extension) => { let mut buf: Vec = vec![]; extension_codec.try_encode(extension, &mut buf)?; diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 0147a607567b..695027374fa0 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -96,7 +96,6 @@ impl Unparser<'_> { | LogicalPlan::Aggregate(_) | LogicalPlan::Sort(_) | LogicalPlan::Join(_) - | LogicalPlan::CrossJoin(_) | LogicalPlan::Repartition(_) | LogicalPlan::Union(_) | LogicalPlan::TableScan(_) @@ -497,43 +496,6 @@ impl Unparser<'_> { Ok(()) } - LogicalPlan::CrossJoin(cross_join) => { - // Cross joins are the same as unconditional inner joins - let mut right_relation = RelationBuilder::default(); - - self.select_to_sql_recursively( - cross_join.left.as_ref(), - query, - select, - relation, - )?; - self.select_to_sql_recursively( - cross_join.right.as_ref(), - query, - select, - &mut right_relation, - )?; - - let Ok(Some(relation)) = right_relation.build() else { - return internal_err!("Failed to build right relation"); - }; - - let ast_join = ast::Join { - relation, - global: false, - join_operator: self.join_operator_to_sql( - JoinType::Inner, - ast::JoinConstraint::On(ast::Expr::Value(ast::Value::Boolean( - true, - ))), - ), - }; - let mut from = select.pop_from().unwrap(); - from.push_join(ast_join); - select.push_from(from); - - Ok(()) - } LogicalPlan::SubqueryAlias(plan_alias) => { let (plan, mut columns) = subquery_alias_inner_query_and_columns(plan_alias); diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index bb50c4b9610f..4105dc1876db 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -24,7 +24,7 @@ use substrait::proto::expression_reference::ExprType; use arrow_buffer::ToByteSlice; use datafusion::arrow::datatypes::{Field, IntervalUnit}; use datafusion::logical_expr::{ - CrossJoin, Distinct, FetchType, Like, Partitioning, SkipType, WindowFrameUnits, + Distinct, FetchType, Like, Partitioning, SkipType, WindowFrameUnits, }; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, @@ -67,7 +67,7 @@ use substrait::proto::read_rel::VirtualTable; use substrait::proto::rel_common::EmitKind; use substrait::proto::rel_common::EmitKind::Emit; use substrait::proto::{ - rel_common, CrossRel, ExchangeRel, ExpressionReference, ExtendedExpression, RelCommon, + rel_common, ExchangeRel, ExpressionReference, ExtendedExpression, RelCommon, }; use substrait::{ proto::{ @@ -476,23 +476,6 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::CrossJoin(cross_join) => { - let CrossJoin { - left, - right, - schema: _, - } = cross_join; - let left = to_substrait_rel(left.as_ref(), ctx, extensions)?; - let right = to_substrait_rel(right.as_ref(), ctx, extensions)?; - Ok(Box::new(Rel { - rel_type: Some(RelType::Cross(Box::new(CrossRel { - common: None, - left: Some(left), - right: Some(right), - advanced_extension: None, - }))), - })) - } LogicalPlan::SubqueryAlias(alias) => { // Do nothing if encounters SubqueryAlias // since there is no corresponding relation type in Substrait From ac827abe1b66b1dfa02ce65ae857477f68667843 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Thu, 24 Oct 2024 07:26:37 -0400 Subject: [PATCH 065/110] feat: Migrate Map Functions (#13047) * add page * small fixes * delete md * Migrate map functions --- datafusion/functions-nested/src/map.rs | 69 ++++++++- .../functions-nested/src/map_extract.rs | 49 +++++- datafusion/functions-nested/src/map_keys.rs | 41 ++++- datafusion/functions-nested/src/map_values.rs | 41 ++++- dev/update_function_docs.sh | 1 - .../source/user-guide/sql/scalar_functions.md | 145 ------------------ .../user-guide/sql/scalar_functions_new.md | 144 +++++++++++++++++ 7 files changed, 334 insertions(+), 156 deletions(-) diff --git a/datafusion/functions-nested/src/map.rs b/datafusion/functions-nested/src/map.rs index 29afe4a7f3be..d7dce3bacbe1 100644 --- a/datafusion/functions-nested/src/map.rs +++ b/datafusion/functions-nested/src/map.rs @@ -17,7 +17,7 @@ use std::any::Any; use std::collections::{HashSet, VecDeque}; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::ArrayData; use arrow_array::{Array, ArrayRef, MapArray, OffsetSizeTrait, StructArray}; @@ -27,7 +27,10 @@ use arrow_schema::{DataType, Field, SchemaBuilder}; use datafusion_common::utils::{fixed_size_list_to_arrays, list_to_arrays}; use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MAP; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, +}; use crate::make_array::make_array; @@ -238,7 +241,69 @@ impl ScalarUDFImpl for MapFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { make_map_batch(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_map_doc()) + } } + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_map_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MAP) + .with_description( + "Returns an Arrow map with the specified key-value pairs.\n\n\ + The `make_map` function creates a map from two lists: one for keys and one for values. Each key must be unique and non-null." + ) + .with_syntax_example( + "map(key, value)\nmap(key: value)\nmake_map(['key1', 'key2'], ['value1', 'value2'])" + ) + .with_sql_example( + r#"```sql + -- Using map function + SELECT MAP('type', 'test'); + ---- + {type: test} + + SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, null]); + ---- + {POST: 41, HEAD: 33, PATCH: } + + SELECT MAP([[1,2], [3,4]], ['a', 'b']); + ---- + {[1, 2]: a, [3, 4]: b} + + SELECT MAP { 'a': 1, 'b': 2 }; + ---- + {a: 1, b: 2} + + -- Using make_map function + SELECT MAKE_MAP(['POST', 'HEAD'], [41, 33]); + ---- + {POST: 41, HEAD: 33} + + SELECT MAKE_MAP(['key1', 'key2'], ['value1', null]); + ---- + {key1: value1, key2: } + ```"# + ) + .with_argument( + "key", + "For `map`: Expression to be used for key. Can be a constant, column, function, or any combination of arithmetic or string operators.\n\ + For `make_map`: The list of keys to be used in the map. Each key must be unique and non-null." + ) + .with_argument( + "value", + "For `map`: Expression to be used for value. Can be a constant, column, function, or any combination of arithmetic or string operators.\n\ + For `make_map`: The list of values to be mapped to the corresponding keys." + ) + .build() + .unwrap() + }) +} + fn get_element_type(data_type: &DataType) -> Result<&DataType> { match data_type { DataType::List(element) => Ok(element.data_type()), diff --git a/datafusion/functions-nested/src/map_extract.rs b/datafusion/functions-nested/src/map_extract.rs index 9f0c4ad29c60..d2bb6595fe76 100644 --- a/datafusion/functions-nested/src/map_extract.rs +++ b/datafusion/functions-nested/src/map_extract.rs @@ -26,9 +26,12 @@ use arrow_buffer::OffsetBuffer; use arrow_schema::Field; use datafusion_common::{cast::as_map_array, exec_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MAP; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use std::vec; use crate::utils::{get_map_entry_field, make_scalar_function}; @@ -101,6 +104,48 @@ impl ScalarUDFImpl for MapExtract { field.first().unwrap().data_type().clone(), ]) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_map_extract_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_map_extract_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MAP) + .with_description( + "Returns a list containing the value for the given key or an empty list if the key is not present in the map.", + ) + .with_syntax_example("map_extract(map, key)") + .with_sql_example( + r#"```sql +SELECT map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'a'); +---- +[1] + +SELECT map_extract(MAP {1: 'one', 2: 'two'}, 2); +---- +['two'] + +SELECT map_extract(MAP {'x': 10, 'y': NULL, 'z': 30}, 'y'); +---- +[] +```"#, + ) + .with_argument( + "map", + "Map expression. Can be a constant, column, or function, and any combination of map operators.", + ) + .with_argument( + "key", + "Key to extract from the map. Can be a constant, column, or function, any combination of arithmetic or string operators, or a named expression of the previously listed.", + ) + .build() + .unwrap() + }) } fn general_map_extract_inner( diff --git a/datafusion/functions-nested/src/map_keys.rs b/datafusion/functions-nested/src/map_keys.rs index 0b1cebb27c86..f28de1c3b2c7 100644 --- a/datafusion/functions-nested/src/map_keys.rs +++ b/datafusion/functions-nested/src/map_keys.rs @@ -21,12 +21,13 @@ use crate::utils::{get_map_entry_field, make_scalar_function}; use arrow_array::{Array, ArrayRef, ListArray}; use arrow_schema::{DataType, Field}; use datafusion_common::{cast::as_map_array, exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MAP; use datafusion_expr::{ - ArrayFunctionSignature, ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, - Volatility, + ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, + TypeSignature, Volatility, }; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; make_udf_expr_and_func!( MapKeysFunc, @@ -81,6 +82,40 @@ impl ScalarUDFImpl for MapKeysFunc { fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { make_scalar_function(map_keys_inner)(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_map_keys_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_map_keys_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MAP) + .with_description( + "Returns a list of all keys in the map." + ) + .with_syntax_example("map_keys(map)") + .with_sql_example( + r#"```sql +SELECT map_keys(MAP {'a': 1, 'b': NULL, 'c': 3}); +---- +[a, b, c] + +SELECT map_keys(map([100, 5], [42, 43])); +---- +[100, 5] +```"#, + ) + .with_argument( + "map", + "Map expression. Can be a constant, column, or function, and any combination of map operators." + ) + .build() + .unwrap() + }) } fn map_keys_inner(args: &[ArrayRef]) -> Result { diff --git a/datafusion/functions-nested/src/map_values.rs b/datafusion/functions-nested/src/map_values.rs index 58c0d74eed5f..2b19d9fbbc76 100644 --- a/datafusion/functions-nested/src/map_values.rs +++ b/datafusion/functions-nested/src/map_values.rs @@ -21,12 +21,13 @@ use crate::utils::{get_map_entry_field, make_scalar_function}; use arrow_array::{Array, ArrayRef, ListArray}; use arrow_schema::{DataType, Field}; use datafusion_common::{cast::as_map_array, exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MAP; use datafusion_expr::{ - ArrayFunctionSignature, ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, - Volatility, + ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, + TypeSignature, Volatility, }; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; make_udf_expr_and_func!( MapValuesFunc, @@ -81,6 +82,40 @@ impl ScalarUDFImpl for MapValuesFunc { fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { make_scalar_function(map_values_inner)(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_map_values_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_map_values_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MAP) + .with_description( + "Returns a list of all values in the map." + ) + .with_syntax_example("map_values(map)") + .with_sql_example( + r#"```sql +SELECT map_values(MAP {'a': 1, 'b': NULL, 'c': 3}); +---- +[1, , 3] + +SELECT map_values(map([100, 5], [42, 43])); +---- +[42, 43] +```"#, + ) + .with_argument( + "map", + "Map expression. Can be a constant, column, or function, and any combination of map operators." + ) + .build() + .unwrap() + }) } fn map_values_inner(args: &[ArrayRef]) -> Result { diff --git a/dev/update_function_docs.sh b/dev/update_function_docs.sh index f1f26c8b2f58..13bc22afcc13 100755 --- a/dev/update_function_docs.sh +++ b/dev/update_function_docs.sh @@ -297,4 +297,3 @@ echo "Running prettier" npx prettier@2.3.2 --write "$TARGET_FILE" echo "'$TARGET_FILE' successfully updated!" - diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 547ea108080e..203411428777 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -199,151 +199,6 @@ Unwraps struct fields into columns. +-----------------------+-----------------------+ ``` -## Map Functions - -- [map](#map) -- [make_map](#make_map) -- [map_extract](#map_extract) -- [map_keys](#map_keys) -- [map_values](#map_values) - -### `map` - -Returns an Arrow map with the specified key-value pairs. - -``` -map(key, value) -map(key: value) -``` - -#### Arguments - -- **key**: Expression to be used for key. - Can be a constant, column, or function, any combination of arithmetic or - string operators, or a named expression of previous listed. -- **value**: Expression to be used for value. - Can be a constant, column, or function, any combination of arithmetic or - string operators, or a named expression of previous listed. - -#### Example - -``` -SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, null]); ----- -{POST: 41, HEAD: 33, PATCH: } - -SELECT MAP([[1,2], [3,4]], ['a', 'b']); ----- -{[1, 2]: a, [3, 4]: b} - -SELECT MAP { 'a': 1, 'b': 2 }; ----- -{a: 1, b: 2} -``` - -### `make_map` - -Returns an Arrow map with the specified key-value pairs. - -``` -make_map(key_1, value_1, ..., key_n, value_n) -``` - -#### Arguments - -- **key_n**: Expression to be used for key. - Can be a constant, column, or function, any combination of arithmetic or - string operators, or a named expression of previous listed. -- **value_n**: Expression to be used for value. - Can be a constant, column, or function, any combination of arithmetic or - string operators, or a named expression of previous listed. - -#### Example - -``` -SELECT MAKE_MAP('POST', 41, 'HEAD', 33, 'PATCH', null); ----- -{POST: 41, HEAD: 33, PATCH: } -``` - -### `map_extract` - -Return a list containing the value for a given key or an empty list if the key is not contained in the map. - -``` -map_extract(map, key) -``` - -#### Arguments - -- `map`: Map expression. - Can be a constant, column, or function, and any combination of map operators. -- `key`: Key to extract from the map. - Can be a constant, column, or function, any combination of arithmetic or - string operators, or a named expression of previous listed. - -#### Example - -``` -SELECT map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'a'); ----- -[1] -``` - -#### Aliases - -- element_at - -### `map_keys` - -Return a list of all keys in the map. - -``` -map_keys(map) -``` - -#### Arguments - -- `map`: Map expression. - Can be a constant, column, or function, and any combination of map operators. - -#### Example - -``` -SELECT map_keys(MAP {'a': 1, 'b': NULL, 'c': 3}); ----- -[a, b, c] - -select map_keys(map([100, 5], [42,43])); ----- -[100, 5] -``` - -### `map_values` - -Return a list of all values in the map. - -``` -map_values(map) -``` - -#### Arguments - -- `map`: Map expression. - Can be a constant, column, or function, and any combination of map operators. - -#### Example - -``` -SELECT map_values(MAP {'a': 1, 'b': NULL, 'c': 3}); ----- -[1, , 3] - -select map_values(map([100, 5], [42,43])); ----- -[42, 43] -``` - ## Other Functions See the new documentation [`here`](https://datafusion.apache.org/user-guide/sql/scalar_functions_new.html) diff --git a/docs/source/user-guide/sql/scalar_functions_new.md b/docs/source/user-guide/sql/scalar_functions_new.md index 1f4ec1c27858..7d0280dbc28f 100644 --- a/docs/source/user-guide/sql/scalar_functions_new.md +++ b/docs/source/user-guide/sql/scalar_functions_new.md @@ -3898,6 +3898,150 @@ select struct(a as field_a, b) from t; - row +## Map Functions + +- [element_at](#element_at) +- [map](#map) +- [map_extract](#map_extract) +- [map_keys](#map_keys) +- [map_values](#map_values) + +### `element_at` + +_Alias of [map_extract](#map_extract)._ + +### `map` + +Returns an Arrow map with the specified key-value pairs. + +The `make_map` function creates a map from two lists: one for keys and one for values. Each key must be unique and non-null. + +``` +map(key, value) +map(key: value) +make_map(['key1', 'key2'], ['value1', 'value2']) +``` + +#### Arguments + +- **key**: For `map`: Expression to be used for key. Can be a constant, column, function, or any combination of arithmetic or string operators. + For `make_map`: The list of keys to be used in the map. Each key must be unique and non-null. +- **value**: For `map`: Expression to be used for value. Can be a constant, column, function, or any combination of arithmetic or string operators. + For `make_map`: The list of values to be mapped to the corresponding keys. + +#### Example + +````sql + -- Using map function + SELECT MAP('type', 'test'); + ---- + {type: test} + + SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, null]); + ---- + {POST: 41, HEAD: 33, PATCH: } + + SELECT MAP([[1,2], [3,4]], ['a', 'b']); + ---- + {[1, 2]: a, [3, 4]: b} + + SELECT MAP { 'a': 1, 'b': 2 }; + ---- + {a: 1, b: 2} + + -- Using make_map function + SELECT MAKE_MAP(['POST', 'HEAD'], [41, 33]); + ---- + {POST: 41, HEAD: 33} + + SELECT MAKE_MAP(['key1', 'key2'], ['value1', null]); + ---- + {key1: value1, key2: } + ``` + + +### `map_extract` + +Returns a list containing the value for the given key or an empty list if the key is not present in the map. + +```` + +map_extract(map, key) + +```` +#### Arguments + +- **map**: Map expression. Can be a constant, column, or function, and any combination of map operators. +- **key**: Key to extract from the map. Can be a constant, column, or function, any combination of arithmetic or string operators, or a named expression of the previously listed. + +#### Example + +```sql +SELECT map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'a'); +---- +[1] + +SELECT map_extract(MAP {1: 'one', 2: 'two'}, 2); +---- +['two'] + +SELECT map_extract(MAP {'x': 10, 'y': NULL, 'z': 30}, 'y'); +---- +[] +```` + +#### Aliases + +- element_at + +### `map_keys` + +Returns a list of all keys in the map. + +``` +map_keys(map) +``` + +#### Arguments + +- **map**: Map expression. Can be a constant, column, or function, and any combination of map operators. + +#### Example + +```sql +SELECT map_keys(MAP {'a': 1, 'b': NULL, 'c': 3}); +---- +[a, b, c] + +SELECT map_keys(map([100, 5], [42, 43])); +---- +[100, 5] +``` + +### `map_values` + +Returns a list of all values in the map. + +``` +map_values(map) +``` + +#### Arguments + +- **map**: Map expression. Can be a constant, column, or function, and any combination of map operators. + +#### Example + +```sql +SELECT map_values(MAP {'a': 1, 'b': NULL, 'c': 3}); +---- +[1, , 3] + +SELECT map_values(map([100, 5], [42, 43])); +---- +[42, 43] +``` + ## Hashing Functions - [digest](#digest) From 307c1ea2ef2323aa347be029039b9daf9b419645 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 24 Oct 2024 11:04:26 -0400 Subject: [PATCH 066/110] Minor: Add documentation for `cot` (#13069) * Add documentation for `cot` * fmt --- datafusion/functions/src/math/cot.rs | 26 ++++++++++++++++--- .../user-guide/sql/scalar_functions_new.md | 13 ++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/datafusion/functions/src/math/cot.rs b/datafusion/functions/src/math/cot.rs index f039767536fa..eded50a20d8d 100644 --- a/datafusion/functions/src/math/cot.rs +++ b/datafusion/functions/src/math/cot.rs @@ -16,18 +16,18 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::{ArrayRef, AsArray}; use arrow::datatypes::DataType::{Float32, Float64}; use arrow::datatypes::{DataType, Float32Type, Float64Type}; +use crate::utils::make_scalar_function; use datafusion_common::{exec_err, Result}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; +use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use crate::utils::make_scalar_function; - #[derive(Debug)] pub struct CotFunc { signature: Signature, @@ -39,6 +39,20 @@ impl Default for CotFunc { } } +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_cot_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the cotangent of a number.") + .with_syntax_example(r#"cot(numeric_expression)"#) + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + impl CotFunc { pub fn new() -> Self { use DataType::*; @@ -77,6 +91,10 @@ impl ScalarUDFImpl for CotFunc { } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_cot_doc()) + } + fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(cot, vec![])(args) } diff --git a/docs/source/user-guide/sql/scalar_functions_new.md b/docs/source/user-guide/sql/scalar_functions_new.md index 7d0280dbc28f..55e61984d7f8 100644 --- a/docs/source/user-guide/sql/scalar_functions_new.md +++ b/docs/source/user-guide/sql/scalar_functions_new.md @@ -47,6 +47,7 @@ the rest of the documentation. - [ceil](#ceil) - [cos](#cos) - [cosh](#cosh) +- [cot](#cot) - [degrees](#degrees) - [exp](#exp) - [factorial](#factorial) @@ -221,6 +222,18 @@ cosh(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +### `cot` + +Returns the cotangent of a number. + +``` +cot(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + ### `degrees` Converts radians to degrees. From 631408baefa79d80f2f7e59e8a6e22714312cc3e Mon Sep 17 00:00:00 2001 From: Oleks V Date: Thu, 24 Oct 2024 09:25:03 -0700 Subject: [PATCH 067/110] Documentation: Add API deprecation policy (#13083) * Documentation: Add API deprecation policy --- README.md | 5 +++ docs/source/library-user-guide/api-health.md | 37 ++++++++++++++++++++ 2 files changed, 42 insertions(+) create mode 100644 docs/source/library-user-guide/api-health.md diff --git a/README.md b/README.md index 30505d7ca132..bbbdf7133518 100644 --- a/README.md +++ b/README.md @@ -134,3 +134,8 @@ For example, given the releases `1.78.0`, `1.79.0`, `1.80.0`, `1.80.1` and `1.81 If a hotfix is released for the minimum supported Rust version (MSRV), the MSRV will be the minor version with all hotfixes, even if it surpasses the four-month window. We enforce this policy using a [MSRV CI Check](https://github.com/search?q=repo%3Aapache%2Fdatafusion+rust-version+language%3ATOML+path%3A%2F%5ECargo.toml%2F&type=code) + +## DataFusion API evolution policy + +Public methods in Apache DataFusion are subject to evolve as part of the API lifecycle. +Deprecated methods will be phased out in accordance with the [policy](docs/source/library-user-guide/api-health.md), ensuring the API is stable and healthy. diff --git a/docs/source/library-user-guide/api-health.md b/docs/source/library-user-guide/api-health.md new file mode 100644 index 000000000000..943a370e8172 --- /dev/null +++ b/docs/source/library-user-guide/api-health.md @@ -0,0 +1,37 @@ + + +# API health policy + +To maintain API health, developers must track and properly deprecate outdated methods. +When deprecating a method: + +- clearly mark the API as deprecated and specify the exact DataFusion version in which it was deprecated. +- concisely describe the preferred API, if relevant + +API deprecation example: + +```rust + #[deprecated(since = "41.0.0", note = "Use SessionStateBuilder")] + pub fn new_with_config_rt(config: SessionConfig, runtime: Arc) -> Self +``` + +Deprecated methods will remain in the codebase for a period of 6 major versions or 6 months, whichever is longer, to provide users ample time to transition away from them. + +Please refer to [DataFusion releases](https://crates.io/crates/datafusion/versions) to plan ahead API migration From 1b14655148377fac317153b5b5f14f4256a5d375 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Thu, 24 Oct 2024 14:41:39 -0400 Subject: [PATCH 068/110] changed doc instance (#13097) --- datafusion/functions-nested/src/range.rs | 4 ++- .../user-guide/sql/scalar_functions_new.md | 29 +++++++------------ 2 files changed, 14 insertions(+), 19 deletions(-) diff --git a/datafusion/functions-nested/src/range.rs b/datafusion/functions-nested/src/range.rs index 2346b4d5b43f..ddc56b1e4ee8 100644 --- a/datafusion/functions-nested/src/range.rs +++ b/datafusion/functions-nested/src/range.rs @@ -283,8 +283,10 @@ impl ScalarUDFImpl for GenSeries { } } +static GENERATE_SERIES_DOCUMENTATION: OnceLock = OnceLock::new(); + fn get_generate_series_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { + GENERATE_SERIES_DOCUMENTATION.get_or_init(|| { Documentation::builder() .with_doc_section(DOC_SECTION_ARRAY) .with_description( diff --git a/docs/source/user-guide/sql/scalar_functions_new.md b/docs/source/user-guide/sql/scalar_functions_new.md index 55e61984d7f8..c15821ac89a3 100644 --- a/docs/source/user-guide/sql/scalar_functions_new.md +++ b/docs/source/user-guide/sql/scalar_functions_new.md @@ -3530,34 +3530,27 @@ flatten(array) ### `generate_series` -Returns an Arrow array between start and stop with step. The range start..end contains all values with start <= x < end. It is empty if start >= end. Step cannot be 0. +Similar to the range function, but it includes the upper bound. ``` -range(start, stop, step) +generate_series(start, stop, step) ``` #### Arguments -- **start**: Start of the range. Ints, timestamps, dates or string types that can be coerced to Date32 are supported. -- **end**: End of the range (not included). Type must be the same as start. -- **step**: Increase by step (cannot be 0). Steps less than a day are supported only for timestamp ranges. +- **start**: start of the series. Ints, timestamps, dates or string types that can be coerced to Date32 are supported. +- **end**: end of the series (included). Type must be the same as start. +- **step**: increase by step (can not be 0). Steps less than a day are supported only for timestamp ranges. #### Example ```sql -> select range(2, 10, 3); -+-----------------------------------+ -| range(Int64(2),Int64(10),Int64(3))| -+-----------------------------------+ -| [2, 5, 8] | -+-----------------------------------+ - -> select range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH); -+--------------------------------------------------------------+ -| range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH) | -+--------------------------------------------------------------+ -| [1992-09-01, 1992-10-01, 1992-11-01, 1992-12-01, 1993-01-01, 1993-02-01] | -+--------------------------------------------------------------+ +> select generate_series(1,3); ++------------------------------------+ +| generate_series(Int64(1),Int64(3)) | ++------------------------------------+ +| [1, 2, 3] | ++------------------------------------+ ``` ### `list_any_value` From 31701b8dc9c6486856c06a29a32107d9f4549cec Mon Sep 17 00:00:00 2001 From: Max Norfolk <66913041+mnorfolk03@users.noreply.github.com> Date: Thu, 24 Oct 2024 15:01:08 -0400 Subject: [PATCH 069/110] chore: Added a number of physical planning join benchmarks (#13085) * chore: Added a number of physical planning join benchmarks * Ran cargo fmt --- datafusion/core/benches/sql_planner.rs | 69 +++++++++++++++++++++++++- 1 file changed, 68 insertions(+), 1 deletion(-) diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index 09f05c70fec6..64d2760e9d97 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -60,7 +60,9 @@ fn create_schema(column_prefix: &str, num_columns: usize) -> Schema { fn create_table_provider(column_prefix: &str, num_columns: usize) -> Arc { let schema = Arc::new(create_schema(column_prefix, num_columns)); - MemTable::try_new(schema, vec![]).map(Arc::new).unwrap() + MemTable::try_new(schema, vec![vec![]]) + .map(Arc::new) + .unwrap() } fn create_context() -> SessionContext { @@ -158,6 +160,71 @@ fn criterion_benchmark(c: &mut Criterion) { }); }); + // Benchmark for Physical Planning Joins + c.bench_function("physical_join_consider_sort", |b| { + b.iter(|| { + physical_plan( + &ctx, + "SELECT t1.a7, t2.b8 \ + FROM t1, t2 WHERE a7 = b7 \ + ORDER BY a7", + ); + }); + }); + + c.bench_function("physical_theta_join_consider_sort", |b| { + b.iter(|| { + physical_plan( + &ctx, + "SELECT t1.a7, t2.b8 \ + FROM t1, t2 WHERE a7 < b7 \ + ORDER BY a7", + ); + }); + }); + + c.bench_function("physical_many_self_joins", |b| { + b.iter(|| { + physical_plan( + &ctx, + "SELECT ta.a9, tb.a10, tc.a11, td.a12, te.a13, tf.a14 \ + FROM t1 AS ta, t1 AS tb, t1 AS tc, t1 AS td, t1 AS te, t1 AS tf \ + WHERE ta.a9 = tb.a10 AND tb.a10 = tc.a11 AND tc.a11 = td.a12 AND \ + td.a12 = te.a13 AND te.a13 = tf.a14", + ); + }); + }); + + c.bench_function("physical_unnest_to_join", |b| { + b.iter(|| { + physical_plan( + &ctx, + "SELECT t1.a7 \ + FROM t1 WHERE a7 = (SELECT b8 FROM t2)", + ); + }); + }); + + c.bench_function("physical_intersection", |b| { + b.iter(|| { + physical_plan( + &ctx, + "SELECT t1.a7 FROM t1 \ + INTERSECT SELECT t2.b8 FROM t2", + ); + }); + }); + // these two queries should be equivalent + c.bench_function("physical_join_distinct", |b| { + b.iter(|| { + logical_plan( + &ctx, + "SELECT DISTINCT t1.a7 \ + FROM t1, t2 WHERE t1.a7 = t2.b8", + ); + }); + }); + // --- TPC-H --- let tpch_ctx = register_defs(SessionContext::new(), tpch_schemas()); From 9e636d8fe920340409e527da36cff3d2c25aef5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20=C5=9Een?= Date: Thu, 24 Oct 2024 23:28:13 +0300 Subject: [PATCH 070/110] [docs]: migrate lead/lag window function docs to new docs (#13095) * added lead-lag docs * deleted old --- datafusion/functions-window/src/lead_lag.rs | 58 ++++++++++++++++++- .../source/user-guide/sql/window_functions.md | 30 ---------- .../user-guide/sql/window_functions_new.md | 33 +++++++++++ 3 files changed, 88 insertions(+), 33 deletions(-) diff --git a/datafusion/functions-window/src/lead_lag.rs b/datafusion/functions-window/src/lead_lag.rs index f81521099751..bbe50cbbdc8a 100644 --- a/datafusion/functions-window/src/lead_lag.rs +++ b/datafusion/functions-window/src/lead_lag.rs @@ -22,9 +22,10 @@ use datafusion_common::arrow::array::ArrayRef; use datafusion_common::arrow::datatypes::DataType; use datafusion_common::arrow::datatypes::Field; use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL; use datafusion_expr::{ - Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature, Volatility, - WindowUDFImpl, + Documentation, Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature, + Volatility, WindowUDFImpl, }; use datafusion_functions_window_common::expr::ExpressionArgs; use datafusion_functions_window_common::field::WindowUDFFieldArgs; @@ -34,7 +35,7 @@ use std::any::Any; use std::cmp::min; use std::collections::VecDeque; use std::ops::{Neg, Range}; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; get_or_init_udwf!( Lag, @@ -147,6 +148,50 @@ impl WindowShift { } } +static LAG_DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_lag_doc() -> &'static Documentation { + LAG_DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ANALYTICAL) + .with_description( + "Returns value evaluated at the row that is offset rows before the \ + current row within the partition; if there is no such row, instead return default \ + (which must be of the same type as value).", + ) + .with_syntax_example("lag(expression, offset, default)") + .with_argument("expression", "Expression to operate on") + .with_argument("offset", "Integer. Specifies how many rows back \ + the value of expression should be retrieved. Defaults to 1.") + .with_argument("default", "The default value if the offset is \ + not within the partition. Must be of the same type as expression.") + .build() + .unwrap() + }) +} + +static LEAD_DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_lead_doc() -> &'static Documentation { + LEAD_DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ANALYTICAL) + .with_description( + "Returns value evaluated at the row that is offset rows after the \ + current row within the partition; if there is no such row, instead return default \ + (which must be of the same type as value).", + ) + .with_syntax_example("lead(expression, offset, default)") + .with_argument("expression", "Expression to operate on") + .with_argument("offset", "Integer. Specifies how many rows \ + forward the value of expression should be retrieved. Defaults to 1.") + .with_argument("default", "The default value if the offset is \ + not within the partition. Must be of the same type as expression.") + .build() + .unwrap() + }) +} + impl WindowUDFImpl for WindowShift { fn as_any(&self) -> &dyn Any { self @@ -212,6 +257,13 @@ impl WindowUDFImpl for WindowShift { WindowShiftKind::Lead => ReversedUDWF::Reversed(lead_udwf()), } } + + fn documentation(&self) -> Option<&Documentation> { + match self.kind { + WindowShiftKind::Lag => Some(get_lag_doc()), + WindowShiftKind::Lead => Some(get_lead_doc()), + } + } } /// When `lead`/`lag` is evaluated on a `NULL` expression we attempt to diff --git a/docs/source/user-guide/sql/window_functions.md b/docs/source/user-guide/sql/window_functions.md index 6c0de711bc0c..0799859e4371 100644 --- a/docs/source/user-guide/sql/window_functions.md +++ b/docs/source/user-guide/sql/window_functions.md @@ -184,8 +184,6 @@ ntile(expression) - [cume_dist](#cume_dist) - [percent_rank](#percent_rank) -- [lag](#lag) -- [lead](#lead) - [first_value](#first_value) - [last_value](#last_value) - [nth_value](#nth_value) @@ -206,34 +204,6 @@ Relative rank of the current row: (rank - 1) / (total rows - 1). percent_rank() ``` -### `lag` - -Returns value evaluated at the row that is offset rows before the current row within the partition; if there is no such row, instead return default (which must be of the same type as value). Both offset and default are evaluated with respect to the current row. If omitted, offset defaults to 1 and default to null. - -```sql -lag(expression, offset, default) -``` - -#### Arguments - -- **expression**: Expression to operate on -- **offset**: Integer. Specifies how many rows back the value of _expression_ should be retrieved. Defaults to 1. -- **default**: The default value if the offset is not within the partition. Must be of the same type as _expression_. - -### `lead` - -Returns value evaluated at the row that is offset rows after the current row within the partition; if there is no such row, instead return default (which must be of the same type as value). Both offset and default are evaluated with respect to the current row. If omitted, offset defaults to 1 and default to null. - -```sql -lead(expression, offset, default) -``` - -#### Arguments - -- **expression**: Expression to operate on -- **offset**: Integer. Specifies how many rows forward the value of _expression_ should be retrieved. Defaults to 1. -- **default**: The default value if the offset is not within the partition. Must be of the same type as _expression_. - ### `first_value` Returns value evaluated at the row that is the first row of the window frame. diff --git a/docs/source/user-guide/sql/window_functions_new.md b/docs/source/user-guide/sql/window_functions_new.md index 89ce2284a70c..267060abfdcc 100644 --- a/docs/source/user-guide/sql/window_functions_new.md +++ b/docs/source/user-guide/sql/window_functions_new.md @@ -202,3 +202,36 @@ Number of the current row within its partition, counting from 1. ``` row_number() ``` + +## Analytical Functions + +- [lag](#lag) +- [lead](#lead) + +### `lag` + +Returns value evaluated at the row that is offset rows before the current row within the partition; if there is no such row, instead return default (which must be of the same type as value). + +``` +lag(expression, offset, default) +``` + +#### Arguments + +- **expression**: Expression to operate on +- **offset**: Integer. Specifies how many rows back the value of expression should be retrieved. Defaults to 1. +- **default**: The default value if the offset is not within the partition. Must be of the same type as expression. + +### `lead` + +Returns value evaluated at the row that is offset rows after the current row within the partition; if there is no such row, instead return default (which must be of the same type as value). + +``` +lead(expression, offset, default) +``` + +#### Arguments + +- **expression**: Expression to operate on +- **offset**: Integer. Specifies how many rows forward the value of expression should be retrieved. Defaults to 1. +- **default**: The default value if the offset is not within the partition. Must be of the same type as expression. From 232293367d6f3ff10e291597e1dc45bcce7de7d7 Mon Sep 17 00:00:00 2001 From: Oleks V Date: Thu, 24 Oct 2024 13:29:33 -0700 Subject: [PATCH 071/110] minor: Add deprecated policy to the contributor guide contents and fix the link from main README (#13100) --- README.md | 2 +- docs/source/index.rst | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index bbbdf7133518..f89935d597c2 100644 --- a/README.md +++ b/README.md @@ -138,4 +138,4 @@ We enforce this policy using a [MSRV CI Check](https://github.com/search?q=repo% ## DataFusion API evolution policy Public methods in Apache DataFusion are subject to evolve as part of the API lifecycle. -Deprecated methods will be phased out in accordance with the [policy](docs/source/library-user-guide/api-health.md), ensuring the API is stable and healthy. +Deprecated methods will be phased out in accordance with the [policy](https://datafusion.apache.org/library-user-guide/api-health.html), ensuring the API is stable and healthy. diff --git a/docs/source/index.rst b/docs/source/index.rst index 27dd58cf50f4..9008950d3dd6 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -130,6 +130,7 @@ To get started, see library-user-guide/extending-operators library-user-guide/profiling library-user-guide/query-optimizer + library-user-guide/api-health .. _toc.contributor-guide: .. toctree:: From 6a3c0b0bce67553a4431b941d13fc995f310bee8 Mon Sep 17 00:00:00 2001 From: Filippo Rossi Date: Fri, 25 Oct 2024 08:28:41 +0200 Subject: [PATCH 072/110] feat: improve type inference for WindowFrame (#13059) * feat: improve type inference for WindowFrame Closes #11432 * Support Interval for groups and rows * Remove case for SingleQuotedString --- datafusion/expr/src/window_frame.rs | 191 ++++++++++++++---- .../optimizer/src/analyzer/type_coercion.rs | 23 ++- datafusion/sql/tests/cases/plan_to_sql.rs | 2 +- datafusion/sqllogictest/test_files/window.slt | 32 ++- .../substrait/src/logical_plan/producer.rs | 104 ++-------- 5 files changed, 217 insertions(+), 135 deletions(-) diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index b2e8268aa332..349968c3fa2f 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -23,11 +23,11 @@ //! - An ending frame boundary, //! - An EXCLUDE clause. +use crate::{expr::Sort, lit}; +use arrow::datatypes::DataType; use std::fmt::{self, Formatter}; use std::hash::Hash; -use crate::{expr::Sort, lit}; - use datafusion_common::{plan_err, sql_err, DataFusionError, Result, ScalarValue}; use sqlparser::ast; use sqlparser::parser::ParserError::ParserError; @@ -119,9 +119,9 @@ impl TryFrom for WindowFrame { type Error = DataFusionError; fn try_from(value: ast::WindowFrame) -> Result { - let start_bound = value.start_bound.try_into()?; + let start_bound = WindowFrameBound::try_parse(value.start_bound, &value.units)?; let end_bound = match value.end_bound { - Some(value) => value.try_into()?, + Some(bound) => WindowFrameBound::try_parse(bound, &value.units)?, None => WindowFrameBound::CurrentRow, }; @@ -138,6 +138,7 @@ impl TryFrom for WindowFrame { )? } }; + let units = value.units.into(); Ok(Self::new_bounds(units, start_bound, end_bound)) } @@ -334,17 +335,18 @@ impl WindowFrameBound { } } -impl TryFrom for WindowFrameBound { - type Error = DataFusionError; - - fn try_from(value: ast::WindowFrameBound) -> Result { +impl WindowFrameBound { + fn try_parse( + value: ast::WindowFrameBound, + units: &ast::WindowFrameUnits, + ) -> Result { Ok(match value { ast::WindowFrameBound::Preceding(Some(v)) => { - Self::Preceding(convert_frame_bound_to_scalar_value(*v)?) + Self::Preceding(convert_frame_bound_to_scalar_value(*v, units)?) } ast::WindowFrameBound::Preceding(None) => Self::Preceding(ScalarValue::Null), ast::WindowFrameBound::Following(Some(v)) => { - Self::Following(convert_frame_bound_to_scalar_value(*v)?) + Self::Following(convert_frame_bound_to_scalar_value(*v, units)?) } ast::WindowFrameBound::Following(None) => Self::Following(ScalarValue::Null), ast::WindowFrameBound::CurrentRow => Self::CurrentRow, @@ -352,33 +354,65 @@ impl TryFrom for WindowFrameBound { } } -pub fn convert_frame_bound_to_scalar_value(v: ast::Expr) -> Result { - Ok(ScalarValue::Utf8(Some(match v { - ast::Expr::Value(ast::Value::Number(value, false)) - | ast::Expr::Value(ast::Value::SingleQuotedString(value)) => value, - ast::Expr::Interval(ast::Interval { - value, - leading_field, - .. - }) => { - let result = match *value { - ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item, - e => { - return sql_err!(ParserError(format!( - "INTERVAL expression cannot be {e:?}" - ))); +fn convert_frame_bound_to_scalar_value( + v: ast::Expr, + units: &ast::WindowFrameUnits, +) -> Result { + match units { + // For ROWS and GROUPS we are sure that the ScalarValue must be a non-negative integer ... + ast::WindowFrameUnits::Rows | ast::WindowFrameUnits::Groups => match v { + ast::Expr::Value(ast::Value::Number(value, false)) => { + Ok(ScalarValue::try_from_string(value, &DataType::UInt64)?) + }, + ast::Expr::Interval(ast::Interval { + value, + leading_field: None, + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }) => { + let value = match *value { + ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item, + e => { + return sql_err!(ParserError(format!( + "INTERVAL expression cannot be {e:?}" + ))); + } + }; + Ok(ScalarValue::try_from_string(value, &DataType::UInt64)?) + } + _ => plan_err!( + "Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers" + ), + }, + // ... instead for RANGE it could be anything depending on the type of the ORDER BY clause, + // so we use a ScalarValue::Utf8. + ast::WindowFrameUnits::Range => Ok(ScalarValue::Utf8(Some(match v { + ast::Expr::Value(ast::Value::Number(value, false)) => value, + ast::Expr::Interval(ast::Interval { + value, + leading_field, + .. + }) => { + let result = match *value { + ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item, + e => { + return sql_err!(ParserError(format!( + "INTERVAL expression cannot be {e:?}" + ))); + } + }; + if let Some(leading_field) = leading_field { + format!("{result} {leading_field}") + } else { + result } - }; - if let Some(leading_field) = leading_field { - format!("{result} {leading_field}") - } else { - result } - } - _ => plan_err!( - "Invalid window frame: frame offsets must be non negative integers" - )?, - }))) + _ => plan_err!( + "Invalid window frame: frame offsets for RANGE must be either a numeric value, a string value or an interval" + )?, + }))), + } } impl fmt::Display for WindowFrameBound { @@ -479,8 +513,91 @@ mod tests { ast::Expr::Value(ast::Value::Number("1".to_string(), false)), )))), }; - let result = WindowFrame::try_from(window_frame); - assert!(result.is_ok()); + + let window_frame = WindowFrame::try_from(window_frame)?; + assert_eq!(window_frame.units, WindowFrameUnits::Rows); + assert_eq!( + window_frame.start_bound, + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))) + ); + assert_eq!( + window_frame.end_bound, + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))) + ); + + Ok(()) + } + + macro_rules! test_bound { + ($unit:ident, $value:expr, $expected:expr) => { + let preceding = WindowFrameBound::try_parse( + ast::WindowFrameBound::Preceding($value), + &ast::WindowFrameUnits::$unit, + )?; + assert_eq!(preceding, WindowFrameBound::Preceding($expected)); + let following = WindowFrameBound::try_parse( + ast::WindowFrameBound::Following($value), + &ast::WindowFrameUnits::$unit, + )?; + assert_eq!(following, WindowFrameBound::Following($expected)); + }; + } + + macro_rules! test_bound_err { + ($unit:ident, $value:expr, $expected:expr) => { + let err = WindowFrameBound::try_parse( + ast::WindowFrameBound::Preceding($value), + &ast::WindowFrameUnits::$unit, + ) + .unwrap_err(); + assert_eq!(err.strip_backtrace(), $expected); + let err = WindowFrameBound::try_parse( + ast::WindowFrameBound::Following($value), + &ast::WindowFrameUnits::$unit, + ) + .unwrap_err(); + assert_eq!(err.strip_backtrace(), $expected); + }; + } + + #[test] + fn test_window_frame_bound_creation() -> Result<()> { + // Unbounded + test_bound!(Rows, None, ScalarValue::Null); + test_bound!(Groups, None, ScalarValue::Null); + test_bound!(Range, None, ScalarValue::Null); + + // Number + let number = Some(Box::new(ast::Expr::Value(ast::Value::Number( + "42".to_string(), + false, + )))); + test_bound!(Rows, number.clone(), ScalarValue::UInt64(Some(42))); + test_bound!(Groups, number.clone(), ScalarValue::UInt64(Some(42))); + test_bound!( + Range, + number.clone(), + ScalarValue::Utf8(Some("42".to_string())) + ); + + // Interval + let number = Some(Box::new(ast::Expr::Interval(ast::Interval { + value: Box::new(ast::Expr::Value(ast::Value::SingleQuotedString( + "1".to_string(), + ))), + leading_field: Some(ast::DateTimeField::Day), + fractional_seconds_precision: None, + last_field: None, + leading_precision: None, + }))); + test_bound_err!(Rows, number.clone(), "Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers"); + test_bound_err!(Groups, number.clone(), "Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers"); + test_bound!( + Range, + number.clone(), + ScalarValue::Utf8(Some("1 DAY".to_string())) + ); + Ok(()) } } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 36b72233b5af..33eea1a661c6 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -696,20 +696,20 @@ fn coerce_window_frame( expressions: &[Sort], ) -> Result { let mut window_frame = window_frame; - let current_types = expressions - .iter() - .map(|s| s.expr.get_type(schema)) - .collect::>>()?; let target_type = match window_frame.units { WindowFrameUnits::Range => { - if let Some(col_type) = current_types.first() { + let current_types = expressions + .first() + .map(|s| s.expr.get_type(schema)) + .transpose()?; + if let Some(col_type) = current_types { if col_type.is_numeric() - || is_utf8_or_large_utf8(col_type) + || is_utf8_or_large_utf8(&col_type) || matches!(col_type, DataType::Null) { col_type - } else if is_datetime(col_type) { - &DataType::Interval(IntervalUnit::MonthDayNano) + } else if is_datetime(&col_type) { + DataType::Interval(IntervalUnit::MonthDayNano) } else { return internal_err!( "Cannot run range queries on datatype: {col_type:?}" @@ -719,10 +719,11 @@ fn coerce_window_frame( return internal_err!("ORDER BY column cannot be empty"); } } - WindowFrameUnits::Rows | WindowFrameUnits::Groups => &DataType::UInt64, + WindowFrameUnits::Rows | WindowFrameUnits::Groups => DataType::UInt64, }; - window_frame.start_bound = coerce_frame_bound(target_type, window_frame.start_bound)?; - window_frame.end_bound = coerce_frame_bound(target_type, window_frame.end_bound)?; + window_frame.start_bound = + coerce_frame_bound(&target_type, window_frame.start_bound)?; + window_frame.end_bound = coerce_frame_bound(&target_type, window_frame.end_bound)?; Ok(window_frame) } diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 9ed084eec249..8e25c1c5b1cd 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -1061,7 +1061,7 @@ fn test_aggregation_to_sql() { FROM person GROUP BY id, first_name;"#, r#"SELECT person.id, person.first_name, -sum(person.id) AS total_sum, sum(person.id) OVER (PARTITION BY person.first_name ROWS BETWEEN '5' PRECEDING AND '2' FOLLOWING) AS moving_sum, +sum(person.id) AS total_sum, sum(person.id) OVER (PARTITION BY person.first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum, max(sum(person.id)) OVER (PARTITION BY person.first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total, rank() OVER (PARTITION BY (grouping(person.id) + grouping(person.age)), CASE WHEN (grouping(person.age) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_1, rank() OVER (PARTITION BY (grouping(person.age) + grouping(person.id)), CASE WHEN (CAST(grouping(person.age) AS BIGINT) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_2 diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 95d850795772..4a2d9e1d6864 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -2208,7 +2208,7 @@ physical_plan 01)ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2] 02)--SortExec: TopK(fetch=5), expr=[c9@2 ASC NULLS LAST], preserve_partitioning=[false] 03)----ProjectionExec: expr=[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as sum1, sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING@4 as sum2, c9@1 as c9] -04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING: Ok(Field { name: "sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(3)), is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING: Ok(Field { name: "sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(3)), is_causal: true }], mode=[Sorted] 05)--------ProjectionExec: expr=[c1@0 as c1, c9@2 as c9, c12@3 as c12, sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING] 06)----------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] 07)------------SortExec: expr=[c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST], preserve_partitioning=[false] @@ -2378,17 +2378,41 @@ SELECT c9, rn1 FROM (SELECT c9, # invalid window frame. null as preceding -statement error DataFusion error: Error during planning: Invalid window frame: frame offsets must be non negative integers +statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers select row_number() over (rows between null preceding and current row) from (select 1 a) x # invalid window frame. null as preceding -statement error DataFusion error: Error during planning: Invalid window frame: frame offsets must be non negative integers +statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers select row_number() over (rows between null preceding and current row) from (select 1 a) x # invalid window frame. negative as following -statement error DataFusion error: Error during planning: Invalid window frame: frame offsets must be non negative integers +statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers select row_number() over (rows between current row and -1 following) from (select 1 a) x +# invalid window frame. null as preceding +statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers +select row_number() over (order by a groups between null preceding and current row) from (select 1 a) x + +# invalid window frame. null as preceding +statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers +select row_number() over (order by a groups between null preceding and current row) from (select 1 a) x + +# invalid window frame. negative as following +statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers +select row_number() over (order by a groups between current row and -1 following) from (select 1 a) x + +# interval for rows +query I +select row_number() over (rows between '1' preceding and current row) from (select 1 a) x +---- +1 + +# interval for groups +query I +select row_number() over (order by a groups between '1' preceding and current row) from (select 1 a) x +---- +1 + # This test shows that ordering satisfy considers ordering equivalences, # and can simplify (reduce expression size) multi expression requirements during normalization # For the example below, requirement rn1 ASC, c9 DESC should be simplified to the rn1 ASC. diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 4105dc1876db..4855af683b7d 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -1718,98 +1718,38 @@ fn make_substrait_like_expr( } } +fn to_substrait_bound_offset(value: &ScalarValue) -> Option { + match value { + ScalarValue::UInt8(Some(v)) => Some(*v as i64), + ScalarValue::UInt16(Some(v)) => Some(*v as i64), + ScalarValue::UInt32(Some(v)) => Some(*v as i64), + ScalarValue::UInt64(Some(v)) => Some(*v as i64), + ScalarValue::Int8(Some(v)) => Some(*v as i64), + ScalarValue::Int16(Some(v)) => Some(*v as i64), + ScalarValue::Int32(Some(v)) => Some(*v as i64), + ScalarValue::Int64(Some(v)) => Some(*v), + _ => None, + } +} + fn to_substrait_bound(bound: &WindowFrameBound) -> Bound { match bound { WindowFrameBound::CurrentRow => Bound { kind: Some(BoundKind::CurrentRow(SubstraitBound::CurrentRow {})), }, - WindowFrameBound::Preceding(s) => match s { - ScalarValue::UInt8(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), + WindowFrameBound::Preceding(s) => match to_substrait_bound_offset(s) { + Some(offset) => Bound { + kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { offset })), }, - ScalarValue::UInt16(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), - }, - ScalarValue::UInt32(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), - }, - ScalarValue::UInt64(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), - }, - ScalarValue::Int8(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), - }, - ScalarValue::Int16(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), - }, - ScalarValue::Int32(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), - }, - ScalarValue::Int64(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v, - })), - }, - _ => Bound { + None => Bound { kind: Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})), }, }, - WindowFrameBound::Following(s) => match s { - ScalarValue::UInt8(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::UInt16(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::UInt32(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::UInt64(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::Int8(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::Int16(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::Int32(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::Int64(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v, - })), + WindowFrameBound::Following(s) => match to_substrait_bound_offset(s) { + Some(offset) => Bound { + kind: Some(BoundKind::Following(SubstraitBound::Following { offset })), }, - _ => Bound { + None => Bound { kind: Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})), }, }, From 13a422579d3d0d68c90ee31fdeb5e9bb4bd2df7f Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Fri, 25 Oct 2024 21:02:54 +0800 Subject: [PATCH 073/110] Introduce `binary_as_string` parquet option, upgrade to arrow/parquet `53.2.0` (#12816) * Update to arrow-rs 53.2.0 * introduce binary_as_string parquet option * Fix test --------- Co-authored-by: Andrew Lamb --- Cargo.toml | 18 +- benchmarks/src/clickbench.rs | 15 +- datafusion-cli/Cargo.lock | 122 +++++------ datafusion/common/src/config.rs | 8 + .../common/src/file_options/parquet_writer.rs | 3 + .../core/src/datasource/file_format/mod.rs | 112 ++++++++-- .../src/datasource/file_format/parquet.rs | 37 +++- .../physical_plan/parquet/opener.rs | 23 +- .../proto/datafusion_common.proto | 1 + datafusion/proto-common/src/from_proto/mod.rs | 11 +- .../proto-common/src/generated/pbjson.rs | 75 ++++++- .../proto-common/src/generated/prost.rs | 190 +++++------------ datafusion/proto-common/src/to_proto/mod.rs | 1 + .../src/generated/datafusion_proto_common.rs | 3 + .../proto/src/logical_plan/file_formats.rs | 2 + .../test_files/information_schema.slt | 2 + .../sqllogictest/test_files/parquet.slt | 201 ++++++++++++++++++ docs/source/user-guide/configs.md | 1 + 18 files changed, 581 insertions(+), 244 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 63bfb7fce413..e1e3aca77153 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,22 +70,22 @@ version = "42.1.0" ahash = { version = "0.8", default-features = false, features = [ "runtime-rng", ] } -arrow = { version = "53.1.0", features = [ +arrow = { version = "53.2.0", features = [ "prettyprint", ] } -arrow-array = { version = "53.1.0", default-features = false, features = [ +arrow-array = { version = "53.2.0", default-features = false, features = [ "chrono-tz", ] } -arrow-buffer = { version = "53.1.0", default-features = false } -arrow-flight = { version = "53.1.0", features = [ +arrow-buffer = { version = "53.2.0", default-features = false } +arrow-flight = { version = "53.2.0", features = [ "flight-sql-experimental", ] } -arrow-ipc = { version = "53.1.0", default-features = false, features = [ +arrow-ipc = { version = "53.2.0", default-features = false, features = [ "lz4", ] } -arrow-ord = { version = "53.1.0", default-features = false } -arrow-schema = { version = "53.1.0", default-features = false } -arrow-string = { version = "53.1.0", default-features = false } +arrow-ord = { version = "53.2.0", default-features = false } +arrow-schema = { version = "53.2.0", default-features = false } +arrow-string = { version = "53.2.0", default-features = false } async-trait = "0.1.73" bigdecimal = "=0.4.1" bytes = "1.4" @@ -126,7 +126,7 @@ log = "^0.4" num_cpus = "1.13.0" object_store = { version = "0.11.0", default-features = false } parking_lot = "0.12" -parquet = { version = "53.1.0", default-features = false, features = [ +parquet = { version = "53.2.0", default-features = false, features = [ "arrow", "async", "object_store", diff --git a/benchmarks/src/clickbench.rs b/benchmarks/src/clickbench.rs index 6ebefa985b51..3564ae82585a 100644 --- a/benchmarks/src/clickbench.rs +++ b/benchmarks/src/clickbench.rs @@ -115,12 +115,15 @@ impl RunOpt { None => queries.min_query_id()..=queries.max_query_id(), }; + // configure parquet options let mut config = self.common.config(); - config - .options_mut() - .execution - .parquet - .schema_force_view_types = self.common.force_view_types; + { + let parquet_options = &mut config.options_mut().execution.parquet; + parquet_options.schema_force_view_types = self.common.force_view_types; + // The hits_partitioned dataset specifies string columns + // as binary due to how it was written. Force it to strings + parquet_options.binary_as_string = true; + } let ctx = SessionContext::new_with_config(config); self.register_hits(&ctx).await?; @@ -148,7 +151,7 @@ impl RunOpt { Ok(()) } - /// Registrs the `hits.parquet` as a table named `hits` + /// Registers the `hits.parquet` as a table named `hits` async fn register_hits(&self, ctx: &SessionContext) -> Result<()> { let options = Default::default(); let path = self.path.as_os_str().to_str().unwrap(); diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 24649832b27e..ca67e3e4f531 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -84,9 +84,9 @@ dependencies = [ [[package]] name = "anstream" -version = "0.6.15" +version = "0.6.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526" +checksum = "5f581a3637024bb8f62027f3ab6151f502090388c1dad05b01c70fb733b33c20" dependencies = [ "anstyle", "anstyle-parse", @@ -123,12 +123,12 @@ dependencies = [ [[package]] name = "anstyle-wincon" -version = "3.0.4" +version = "3.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8" +checksum = "abbf7eaf69f3b46121caf74645dd5d3078b4b205a2513930da0033156682cd28" dependencies = [ "anstyle", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -173,9 +173,9 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow" -version = "53.1.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9ba0d7248932f4e2a12fb37f0a2e3ec82b3bdedbac2a1dce186e036843b8f8c" +checksum = "4caf25cdc4a985f91df42ed9e9308e1adbcd341a31a72605c697033fcef163e3" dependencies = [ "arrow-arith", "arrow-array", @@ -194,9 +194,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "53.1.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d60afcdc004841a5c8d8da4f4fa22d64eb19c0c01ef4bcedd77f175a7cf6e38f" +checksum = "91f2dfd1a7ec0aca967dfaa616096aec49779adc8eccec005e2f5e4111b1192a" dependencies = [ "arrow-array", "arrow-buffer", @@ -209,9 +209,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "53.1.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f16835e8599dbbb1659fd869d865254c4cf32c6c2bb60b6942ac9fc36bfa5da" +checksum = "d39387ca628be747394890a6e47f138ceac1aa912eab64f02519fed24b637af8" dependencies = [ "ahash", "arrow-buffer", @@ -226,9 +226,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "53.1.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a1f34f0faae77da6b142db61deba2cb6d60167592b178be317b341440acba80" +checksum = "9e51e05228852ffe3eb391ce7178a0f97d2cf80cc6ef91d3c4a6b3cb688049ec" dependencies = [ "bytes", "half", @@ -237,9 +237,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "53.1.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "450e4abb5775bca0740bec0bcf1b1a5ae07eff43bd625661c4436d8e8e4540c4" +checksum = "d09aea56ec9fa267f3f3f6cdab67d8a9974cbba90b3aa38c8fe9d0bb071bd8c1" dependencies = [ "arrow-array", "arrow-buffer", @@ -258,9 +258,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "53.1.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3a4e4d63830a341713e35d9a42452fbc6241d5f42fa5cf6a4681b8ad91370c4" +checksum = "c07b5232be87d115fde73e32f2ca7f1b353bff1b44ac422d3c6fc6ae38f11f0d" dependencies = [ "arrow-array", "arrow-buffer", @@ -277,9 +277,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "53.1.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b1e618bbf714c7a9e8d97203c806734f012ff71ae3adc8ad1b075689f540634" +checksum = "b98ae0af50890b494cebd7d6b04b35e896205c1d1df7b29a6272c5d0d0249ef5" dependencies = [ "arrow-buffer", "arrow-schema", @@ -289,9 +289,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "53.1.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f98e983549259a2b97049af7edfb8f28b8911682040e99a94e4ceb1196bd65c2" +checksum = "0ed91bdeaff5a1c00d28d8f73466bcb64d32bbd7093b5a30156b4b9f4dba3eee" dependencies = [ "arrow-array", "arrow-buffer", @@ -304,9 +304,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "53.1.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b198b9c6fcf086501730efbbcb483317b39330a116125af7bb06467d04b352a3" +checksum = "0471f51260a5309307e5d409c9dc70aede1cd9cf1d4ff0f0a1e8e1a2dd0e0d3c" dependencies = [ "arrow-array", "arrow-buffer", @@ -324,9 +324,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "53.1.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2427f37b4459a4b9e533045abe87a5183a5e0995a3fc2c2fd45027ae2cc4ef3f" +checksum = "2883d7035e0b600fb4c30ce1e50e66e53d8656aa729f2bfa4b51d359cf3ded52" dependencies = [ "arrow-array", "arrow-buffer", @@ -339,9 +339,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "53.1.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15959657d92e2261a7a323517640af87f5afd9fd8a6492e424ebee2203c567f6" +checksum = "552907e8e587a6fde4f8843fd7a27a576a260f65dab6c065741ea79f633fc5be" dependencies = [ "ahash", "arrow-array", @@ -353,15 +353,15 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "53.1.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbf0388a18fd7f7f3fe3de01852d30f54ed5182f9004db700fbe3ba843ed2794" +checksum = "539ada65246b949bd99ffa0881a9a15a4a529448af1a07a9838dd78617dafab1" [[package]] name = "arrow-select" -version = "53.1.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b83e5723d307a38bf00ecd2972cd078d1339c7fd3eb044f609958a9a24463f3a" +checksum = "6259e566b752da6dceab91766ed8b2e67bf6270eb9ad8a6e07a33c1bede2b125" dependencies = [ "ahash", "arrow-array", @@ -373,9 +373,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "53.1.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ab3db7c09dd826e74079661d84ed01ed06547cf75d52c2818ef776d0d852305" +checksum = "f3179ccbd18ebf04277a095ba7321b93fd1f774f18816bd5f6b3ce2f594edb6c" dependencies = [ "arrow-array", "arrow-buffer", @@ -663,9 +663,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.7.2" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a065c0fe6fdbdf9f11817eb68582b2ab4aff9e9c39e986ae48f7ec576c6322db" +checksum = "be28bd063fa91fd871d131fc8b68d7cd4c5fa0869bea68daca50dcb1cbd76be2" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -707,9 +707,9 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.2.7" +version = "1.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "147100a7bea70fa20ef224a6bad700358305f5dc0f84649c53769761395b355b" +checksum = "07c9cdc179e6afbf5d391ab08c85eac817b51c87e1892a5edb5f7bbdc64314b4" dependencies = [ "base64-simd", "bytes", @@ -836,9 +836,9 @@ dependencies = [ [[package]] name = "brotli" -version = "6.0.0" +version = "7.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74f7971dbd9326d58187408ab83117d8ac1bb9c17b085fdacd1cf2f598719b6b" +checksum = "cc97b8f16f944bba54f0433f07e30be199b6dc2bd25937444bbad560bcea29bd" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -880,9 +880,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.7.2" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" +checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" [[package]] name = "bytes-utils" @@ -2704,9 +2704,9 @@ dependencies = [ [[package]] name = "parquet" -version = "53.1.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "310c46a70a3ba90d98fec39fa2da6d9d731e544191da6fb56c9d199484d0dd3e" +checksum = "dea02606ba6f5e856561d8d507dba8bac060aefca2a6c0f1aa1d361fed91ff3e" dependencies = [ "ahash", "arrow-array", @@ -2809,9 +2809,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" [[package]] name = "pin-utils" @@ -2881,9 +2881,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.88" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c3a7fc5db1e57d5a779a352c8cdb57b29aa4c40cc69c3a68a7fedc815fbf2f9" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" dependencies = [ "unicode-ident", ] @@ -3023,9 +3023,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.11.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", @@ -3393,18 +3393,18 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.210" +version = "1.0.213" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +checksum = "3ea7893ff5e2466df8d720bb615088341b295f849602c6956047f8f80f0e9bc1" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.210" +version = "1.0.213" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +checksum = "7e85ad2009c50b58e87caa8cd6dac16bdf511bbfb7af6c33df902396aa480fa5" dependencies = [ "proc-macro2", "quote", @@ -3607,9 +3607,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.82" +version = "2.0.85" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83540f837a8afc019423a8edb95b52a8effe46957ee402287f4292fae35be021" +checksum = "5023162dfcd14ef8f32034d8bcd4cc5ddc61ef7a247c024a33e24e1f24d21b56" dependencies = [ "proc-macro2", "quote", @@ -3646,18 +3646,18 @@ checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" [[package]] name = "thiserror" -version = "1.0.64" +version = "1.0.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84" +checksum = "5d11abd9594d9b38965ef50805c5e469ca9cc6f197f883f717e0269a3057b3d5" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.64" +version = "1.0.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" +checksum = "ae71770322cbd277e69d762a16c444af02aa0575ac0d174f0b9562d3b37f8602" dependencies = [ "proc-macro2", "quote", @@ -3731,9 +3731,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.40.0" +version = "1.41.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998" +checksum = "145f3413504347a2be84393cc8a7d2fb4d863b375909ea59f2158261aa258bbb" dependencies = [ "backtrace", "bytes", diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 47ffe0b1c66b..33e5184d2cac 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -390,6 +390,14 @@ config_namespace! { /// and `Binary/BinaryLarge` with `BinaryView`. pub schema_force_view_types: bool, default = false + /// (reading) If true, parquet reader will read columns of + /// `Binary/LargeBinary` with `Utf8`, and `BinaryView` with `Utf8View`. + /// + /// Parquet files generated by some legacy writers do not correctly set + /// the UTF8 flag for strings, causing string columns to be loaded as + /// BLOB instead. + pub binary_as_string: bool, default = false + // The following options affect writing to parquet files // and map to parquet::file::properties::WriterProperties diff --git a/datafusion/common/src/file_options/parquet_writer.rs b/datafusion/common/src/file_options/parquet_writer.rs index 5d553d59da4e..dd9d67d6bb47 100644 --- a/datafusion/common/src/file_options/parquet_writer.rs +++ b/datafusion/common/src/file_options/parquet_writer.rs @@ -176,6 +176,7 @@ impl ParquetOptions { maximum_buffered_record_batches_per_stream: _, bloom_filter_on_read: _, // reads not used for writer props schema_force_view_types: _, + binary_as_string: _, // not used for writer props } = self; let mut builder = WriterProperties::builder() @@ -442,6 +443,7 @@ mod tests { .maximum_buffered_record_batches_per_stream, bloom_filter_on_read: defaults.bloom_filter_on_read, schema_force_view_types: defaults.schema_force_view_types, + binary_as_string: defaults.binary_as_string, } } @@ -543,6 +545,7 @@ mod tests { .maximum_buffered_record_batches_per_stream, bloom_filter_on_read: global_options_defaults.bloom_filter_on_read, schema_force_view_types: global_options_defaults.schema_force_view_types, + binary_as_string: global_options_defaults.binary_as_string, }, column_specific_options, key_value_metadata, diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index e16986c660ad..a313a7a9bcb1 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -42,7 +42,7 @@ use crate::error::Result; use crate::execution::context::SessionState; use crate::physical_plan::{ExecutionPlan, Statistics}; -use arrow_schema::{DataType, Field, Schema}; +use arrow_schema::{DataType, Field, FieldRef, Schema}; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{internal_err, not_impl_err, GetExt}; use datafusion_expr::Expr; @@ -235,20 +235,26 @@ pub fn file_type_to_format( } } +/// Create a new field with the specified data type, copying the other +/// properties from the input field +fn field_with_new_type(field: &FieldRef, new_type: DataType) -> FieldRef { + Arc::new(field.as_ref().clone().with_data_type(new_type)) +} + /// Transform a schema to use view types for Utf8 and Binary +/// +/// See [parquet::ParquetFormat::force_view_types] for details pub fn transform_schema_to_view(schema: &Schema) -> Schema { let transformed_fields: Vec> = schema .fields .iter() .map(|field| match field.data_type() { - DataType::Utf8 | DataType::LargeUtf8 => Arc::new( - Field::new(field.name(), DataType::Utf8View, field.is_nullable()) - .with_metadata(field.metadata().to_owned()), - ), - DataType::Binary | DataType::LargeBinary => Arc::new( - Field::new(field.name(), DataType::BinaryView, field.is_nullable()) - .with_metadata(field.metadata().to_owned()), - ), + DataType::Utf8 | DataType::LargeUtf8 => { + field_with_new_type(field, DataType::Utf8View) + } + DataType::Binary | DataType::LargeBinary => { + field_with_new_type(field, DataType::BinaryView) + } _ => field.clone(), }) .collect(); @@ -274,6 +280,7 @@ pub(crate) fn coerce_file_schema_to_view_type( (f.name(), dt) }) .collect(); + if !transform { return None; } @@ -283,14 +290,13 @@ pub(crate) fn coerce_file_schema_to_view_type( .iter() .map( |field| match (table_fields.get(field.name()), field.data_type()) { - (Some(DataType::Utf8View), DataType::Utf8) - | (Some(DataType::Utf8View), DataType::LargeUtf8) => Arc::new( - Field::new(field.name(), DataType::Utf8View, field.is_nullable()), - ), - (Some(DataType::BinaryView), DataType::Binary) - | (Some(DataType::BinaryView), DataType::LargeBinary) => Arc::new( - Field::new(field.name(), DataType::BinaryView, field.is_nullable()), - ), + (Some(DataType::Utf8View), DataType::Utf8 | DataType::LargeUtf8) => { + field_with_new_type(field, DataType::Utf8View) + } + ( + Some(DataType::BinaryView), + DataType::Binary | DataType::LargeBinary, + ) => field_with_new_type(field, DataType::BinaryView), _ => field.clone(), }, ) @@ -302,6 +308,78 @@ pub(crate) fn coerce_file_schema_to_view_type( )) } +/// Transform a schema so that any binary types are strings +pub fn transform_binary_to_string(schema: &Schema) -> Schema { + let transformed_fields: Vec> = schema + .fields + .iter() + .map(|field| match field.data_type() { + DataType::Binary => field_with_new_type(field, DataType::Utf8), + DataType::LargeBinary => field_with_new_type(field, DataType::LargeUtf8), + DataType::BinaryView => field_with_new_type(field, DataType::Utf8View), + _ => field.clone(), + }) + .collect(); + Schema::new_with_metadata(transformed_fields, schema.metadata.clone()) +} + +/// If the table schema uses a string type, coerce the file schema to use a string type. +/// +/// See [parquet::ParquetFormat::binary_as_string] for details +pub(crate) fn coerce_file_schema_to_string_type( + table_schema: &Schema, + file_schema: &Schema, +) -> Option { + let mut transform = false; + let table_fields: HashMap<_, _> = table_schema + .fields + .iter() + .map(|f| (f.name(), f.data_type())) + .collect(); + let transformed_fields: Vec> = file_schema + .fields + .iter() + .map( + |field| match (table_fields.get(field.name()), field.data_type()) { + // table schema uses string type, coerce the file schema to use string type + ( + Some(DataType::Utf8), + DataType::Binary | DataType::LargeBinary | DataType::BinaryView, + ) => { + transform = true; + field_with_new_type(field, DataType::Utf8) + } + // table schema uses large string type, coerce the file schema to use large string type + ( + Some(DataType::LargeUtf8), + DataType::Binary | DataType::LargeBinary | DataType::BinaryView, + ) => { + transform = true; + field_with_new_type(field, DataType::LargeUtf8) + } + // table schema uses string view type, coerce the file schema to use view type + ( + Some(DataType::Utf8View), + DataType::Binary | DataType::LargeBinary | DataType::BinaryView, + ) => { + transform = true; + field_with_new_type(field, DataType::Utf8View) + } + _ => field.clone(), + }, + ) + .collect(); + + if !transform { + None + } else { + Some(Schema::new_with_metadata( + transformed_fields, + file_schema.metadata.clone(), + )) + } +} + #[cfg(test)] pub(crate) mod test_util { use std::ops::Range; diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 8647b5df90be..756c17fd67c6 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -26,8 +26,9 @@ use std::sync::Arc; use super::write::demux::start_demuxer_task; use super::write::{create_writer, SharedBuffer}; use super::{ - coerce_file_schema_to_view_type, transform_schema_to_view, FileFormat, - FileFormatFactory, FilePushdownSupport, FileScanConfig, + coerce_file_schema_to_string_type, coerce_file_schema_to_view_type, + transform_binary_to_string, transform_schema_to_view, FileFormat, FileFormatFactory, + FilePushdownSupport, FileScanConfig, }; use crate::arrow::array::RecordBatch; use crate::arrow::datatypes::{Fields, Schema, SchemaRef}; @@ -253,13 +254,29 @@ impl ParquetFormat { self.options.global.schema_force_view_types } - /// If true, will use view types (StringView and BinaryView). - /// - /// Refer to [`Self::force_view_types`]. + /// If true, will use view types. See [`Self::force_view_types`] for details pub fn with_force_view_types(mut self, use_views: bool) -> Self { self.options.global.schema_force_view_types = use_views; self } + + /// Return `true` if binary types will be read as strings. + /// + /// If this returns true, DataFusion will instruct the parquet reader + /// to read binary columns such as `Binary` or `BinaryView` as the + /// corresponding string type such as `Utf8` or `LargeUtf8`. + /// The parquet reader has special optimizations for `Utf8` and `LargeUtf8` + /// validation, and such queries are significantly faster than reading + /// binary columns and then casting to string columns. + pub fn binary_as_string(&self) -> bool { + self.options.global.binary_as_string + } + + /// If true, will read binary types as strings. See [`Self::binary_as_string`] for details + pub fn with_binary_as_string(mut self, binary_as_string: bool) -> Self { + self.options.global.binary_as_string = binary_as_string; + self + } } /// Clears all metadata (Schema level and field level) on an iterator @@ -350,6 +367,12 @@ impl FileFormat for ParquetFormat { Schema::try_merge(schemas) }?; + let schema = if self.binary_as_string() { + transform_binary_to_string(&schema) + } else { + schema + }; + let schema = if self.force_view_types() { transform_schema_to_view(&schema) } else { @@ -552,6 +575,10 @@ pub fn statistics_from_parquet_meta_calc( file_metadata.schema_descr(), file_metadata.key_value_metadata(), )?; + if let Some(merged) = coerce_file_schema_to_string_type(&table_schema, &file_schema) { + file_schema = merged; + } + if let Some(merged) = coerce_file_schema_to_view_type(&table_schema, &file_schema) { file_schema = merged; } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs index a818a8850284..4990cb4dd735 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs @@ -17,7 +17,9 @@ //! [`ParquetOpener`] for opening Parquet files -use crate::datasource::file_format::coerce_file_schema_to_view_type; +use crate::datasource::file_format::{ + coerce_file_schema_to_string_type, coerce_file_schema_to_view_type, +}; use crate::datasource::physical_plan::parquet::page_filter::PagePruningAccessPlanFilter; use crate::datasource::physical_plan::parquet::row_group_filter::RowGroupAccessPlanFilter; use crate::datasource::physical_plan::parquet::{ @@ -80,7 +82,7 @@ pub(super) struct ParquetOpener { } impl FileOpener for ParquetOpener { - fn open(&self, file_meta: FileMeta) -> datafusion_common::Result { + fn open(&self, file_meta: FileMeta) -> Result { let file_range = file_meta.range.clone(); let extensions = file_meta.extensions.clone(); let file_name = file_meta.location().to_string(); @@ -121,7 +123,14 @@ impl FileOpener for ParquetOpener { let mut metadata_timer = file_metrics.metadata_load_time.timer(); let metadata = ArrowReaderMetadata::load_async(&mut reader, options.clone()).await?; - let mut schema = metadata.schema().clone(); + let mut schema = Arc::clone(metadata.schema()); + + if let Some(merged) = + coerce_file_schema_to_string_type(&table_schema, &schema) + { + schema = Arc::new(merged); + } + // read with view types if let Some(merged) = coerce_file_schema_to_view_type(&table_schema, &schema) { @@ -130,16 +139,16 @@ impl FileOpener for ParquetOpener { let options = ArrowReaderOptions::new() .with_page_index(enable_page_index) - .with_schema(schema.clone()); + .with_schema(Arc::clone(&schema)); let metadata = - ArrowReaderMetadata::try_new(metadata.metadata().clone(), options)?; + ArrowReaderMetadata::try_new(Arc::clone(metadata.metadata()), options)?; metadata_timer.stop(); let mut builder = ParquetRecordBatchStreamBuilder::new_with_metadata(reader, metadata); - let file_schema = builder.schema().clone(); + let file_schema = Arc::clone(builder.schema()); let (schema_mapping, adapted_projections) = schema_adapter.map_schema(&file_schema)?; @@ -177,7 +186,7 @@ impl FileOpener for ParquetOpener { // Determine which row groups to actually read. The idea is to skip // as many row groups as possible based on the metadata and query - let file_metadata = builder.metadata().clone(); + let file_metadata = Arc::clone(builder.metadata()); let predicate = pruning_predicate.as_ref().map(|p| p.as_ref()); let rg_metadata = file_metadata.row_groups(); // track which row groups to actually read diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto index d1506fcd64f0..7f8bce6b206e 100644 --- a/datafusion/proto-common/proto/datafusion_common.proto +++ b/datafusion/proto-common/proto/datafusion_common.proto @@ -494,6 +494,7 @@ message ParquetOptions { bool bloom_filter_on_read = 26; // default = true bool bloom_filter_on_write = 27; // default = false bool schema_force_view_types = 28; // default = false + bool binary_as_string = 29; // default = false oneof metadata_size_hint_opt { uint64 metadata_size_hint = 4; diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index d1b4374fc0e7..d848f795c684 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -897,7 +897,7 @@ impl TryFrom<&protobuf::ParquetOptions> for ParquetOptions { pruning: value.pruning, skip_metadata: value.skip_metadata, metadata_size_hint: value - .metadata_size_hint_opt.clone() + .metadata_size_hint_opt .map(|opt| match opt { protobuf::parquet_options::MetadataSizeHintOpt::MetadataSizeHint(v) => Some(v as usize), }) @@ -958,6 +958,7 @@ impl TryFrom<&protobuf::ParquetOptions> for ParquetOptions { maximum_parallel_row_group_writers: value.maximum_parallel_row_group_writers as usize, maximum_buffered_record_batches_per_stream: value.maximum_buffered_record_batches_per_stream as usize, schema_force_view_types: value.schema_force_view_types, + binary_as_string: value.binary_as_string, }) } } @@ -979,7 +980,7 @@ impl TryFrom<&protobuf::ParquetColumnOptions> for ParquetColumnOptions { }) .unwrap_or(None), max_statistics_size: value - .max_statistics_size_opt.clone() + .max_statistics_size_opt .map(|opt| match opt { protobuf::parquet_column_options::MaxStatisticsSizeOpt::MaxStatisticsSize(v) => Some(v as usize), }) @@ -990,18 +991,18 @@ impl TryFrom<&protobuf::ParquetColumnOptions> for ParquetColumnOptions { protobuf::parquet_column_options::EncodingOpt::Encoding(v) => Some(v), }) .unwrap_or(None), - bloom_filter_enabled: value.bloom_filter_enabled_opt.clone().map(|opt| match opt { + bloom_filter_enabled: value.bloom_filter_enabled_opt.map(|opt| match opt { protobuf::parquet_column_options::BloomFilterEnabledOpt::BloomFilterEnabled(v) => Some(v), }) .unwrap_or(None), bloom_filter_fpp: value - .bloom_filter_fpp_opt.clone() + .bloom_filter_fpp_opt .map(|opt| match opt { protobuf::parquet_column_options::BloomFilterFppOpt::BloomFilterFpp(v) => Some(v), }) .unwrap_or(None), bloom_filter_ndv: value - .bloom_filter_ndv_opt.clone() + .bloom_filter_ndv_opt .map(|opt| match opt { protobuf::parquet_column_options::BloomFilterNdvOpt::BloomFilterNdv(v) => Some(v), }) diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs index fa5d1f442754..e8b46fbf7012 100644 --- a/datafusion/proto-common/src/generated/pbjson.rs +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -1548,18 +1548,22 @@ impl serde::Serialize for CsvOptions { let mut struct_ser = serializer.serialize_struct("datafusion_common.CsvOptions", len)?; if !self.has_header.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("hasHeader", pbjson::private::base64::encode(&self.has_header).as_str())?; } if !self.delimiter.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("delimiter", pbjson::private::base64::encode(&self.delimiter).as_str())?; } if !self.quote.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("quote", pbjson::private::base64::encode(&self.quote).as_str())?; } if !self.escape.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("escape", pbjson::private::base64::encode(&self.escape).as_str())?; } if self.compression != 0 { @@ -1569,6 +1573,7 @@ impl serde::Serialize for CsvOptions { } if self.schema_infer_max_rec != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("schemaInferMaxRec", ToString::to_string(&self.schema_infer_max_rec).as_str())?; } if !self.date_format.is_empty() { @@ -1591,18 +1596,22 @@ impl serde::Serialize for CsvOptions { } if !self.comment.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("comment", pbjson::private::base64::encode(&self.comment).as_str())?; } if !self.double_quote.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("doubleQuote", pbjson::private::base64::encode(&self.double_quote).as_str())?; } if !self.newlines_in_values.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("newlinesInValues", pbjson::private::base64::encode(&self.newlines_in_values).as_str())?; } if !self.terminator.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("terminator", pbjson::private::base64::encode(&self.terminator).as_str())?; } struct_ser.end() @@ -2276,14 +2285,17 @@ impl serde::Serialize for Decimal128 { let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal128", len)?; if !self.value.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("value", pbjson::private::base64::encode(&self.value).as_str())?; } if self.p != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("p", ToString::to_string(&self.p).as_str())?; } if self.s != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("s", ToString::to_string(&self.s).as_str())?; } struct_ser.end() @@ -2410,14 +2422,17 @@ impl serde::Serialize for Decimal256 { let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal256", len)?; if !self.value.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("value", pbjson::private::base64::encode(&self.value).as_str())?; } if self.p != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("p", ToString::to_string(&self.p).as_str())?; } if self.s != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("s", ToString::to_string(&self.s).as_str())?; } struct_ser.end() @@ -3080,6 +3095,7 @@ impl serde::Serialize for Field { } if self.dict_id != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("dictId", ToString::to_string(&self.dict_id).as_str())?; } if self.dict_ordered { @@ -3484,6 +3500,7 @@ impl serde::Serialize for IntervalMonthDayNanoValue { } if self.nanos != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("nanos", ToString::to_string(&self.nanos).as_str())?; } struct_ser.end() @@ -3917,6 +3934,7 @@ impl serde::Serialize for JsonOptions { } if self.schema_infer_max_rec != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("schemaInferMaxRec", ToString::to_string(&self.schema_infer_max_rec).as_str())?; } struct_ser.end() @@ -4474,6 +4492,7 @@ impl serde::Serialize for ParquetColumnOptions { match v { parquet_column_options::BloomFilterNdvOpt::BloomFilterNdv(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("bloomFilterNdv", ToString::to_string(&v).as_str())?; } } @@ -4894,6 +4913,9 @@ impl serde::Serialize for ParquetOptions { if self.schema_force_view_types { len += 1; } + if self.binary_as_string { + len += 1; + } if self.dictionary_page_size_limit != 0 { len += 1; } @@ -4951,10 +4973,12 @@ impl serde::Serialize for ParquetOptions { } if self.data_pagesize_limit != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("dataPagesizeLimit", ToString::to_string(&self.data_pagesize_limit).as_str())?; } if self.write_batch_size != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("writeBatchSize", ToString::to_string(&self.write_batch_size).as_str())?; } if !self.writer_version.is_empty() { @@ -4965,10 +4989,12 @@ impl serde::Serialize for ParquetOptions { } if self.maximum_parallel_row_group_writers != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("maximumParallelRowGroupWriters", ToString::to_string(&self.maximum_parallel_row_group_writers).as_str())?; } if self.maximum_buffered_record_batches_per_stream != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("maximumBufferedRecordBatchesPerStream", ToString::to_string(&self.maximum_buffered_record_batches_per_stream).as_str())?; } if self.bloom_filter_on_read { @@ -4980,16 +5006,22 @@ impl serde::Serialize for ParquetOptions { if self.schema_force_view_types { struct_ser.serialize_field("schemaForceViewTypes", &self.schema_force_view_types)?; } + if self.binary_as_string { + struct_ser.serialize_field("binaryAsString", &self.binary_as_string)?; + } if self.dictionary_page_size_limit != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("dictionaryPageSizeLimit", ToString::to_string(&self.dictionary_page_size_limit).as_str())?; } if self.data_page_row_count_limit != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("dataPageRowCountLimit", ToString::to_string(&self.data_page_row_count_limit).as_str())?; } if self.max_row_group_size != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("maxRowGroupSize", ToString::to_string(&self.max_row_group_size).as_str())?; } if !self.created_by.is_empty() { @@ -4999,6 +5031,7 @@ impl serde::Serialize for ParquetOptions { match v { parquet_options::MetadataSizeHintOpt::MetadataSizeHint(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("metadataSizeHint", ToString::to_string(&v).as_str())?; } } @@ -5028,6 +5061,7 @@ impl serde::Serialize for ParquetOptions { match v { parquet_options::MaxStatisticsSizeOpt::MaxStatisticsSize(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("maxStatisticsSize", ToString::to_string(&v).as_str())?; } } @@ -5036,6 +5070,7 @@ impl serde::Serialize for ParquetOptions { match v { parquet_options::ColumnIndexTruncateLengthOpt::ColumnIndexTruncateLength(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("columnIndexTruncateLength", ToString::to_string(&v).as_str())?; } } @@ -5058,6 +5093,7 @@ impl serde::Serialize for ParquetOptions { match v { parquet_options::BloomFilterNdvOpt::BloomFilterNdv(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("bloomFilterNdv", ToString::to_string(&v).as_str())?; } } @@ -5099,6 +5135,8 @@ impl<'de> serde::Deserialize<'de> for ParquetOptions { "bloomFilterOnWrite", "schema_force_view_types", "schemaForceViewTypes", + "binary_as_string", + "binaryAsString", "dictionary_page_size_limit", "dictionaryPageSizeLimit", "data_page_row_count_limit", @@ -5140,7 +5178,8 @@ impl<'de> serde::Deserialize<'de> for ParquetOptions { MaximumBufferedRecordBatchesPerStream, BloomFilterOnRead, BloomFilterOnWrite, - schemaForceViewTypes, + SchemaForceViewTypes, + BinaryAsString, DictionaryPageSizeLimit, DataPageRowCountLimit, MaxRowGroupSize, @@ -5188,7 +5227,8 @@ impl<'de> serde::Deserialize<'de> for ParquetOptions { "maximumBufferedRecordBatchesPerStream" | "maximum_buffered_record_batches_per_stream" => Ok(GeneratedField::MaximumBufferedRecordBatchesPerStream), "bloomFilterOnRead" | "bloom_filter_on_read" => Ok(GeneratedField::BloomFilterOnRead), "bloomFilterOnWrite" | "bloom_filter_on_write" => Ok(GeneratedField::BloomFilterOnWrite), - "schemaForceViewTypes" | "schema_force_view_types" => Ok(GeneratedField::schemaForceViewTypes), + "schemaForceViewTypes" | "schema_force_view_types" => Ok(GeneratedField::SchemaForceViewTypes), + "binaryAsString" | "binary_as_string" => Ok(GeneratedField::BinaryAsString), "dictionaryPageSizeLimit" | "dictionary_page_size_limit" => Ok(GeneratedField::DictionaryPageSizeLimit), "dataPageRowCountLimit" | "data_page_row_count_limit" => Ok(GeneratedField::DataPageRowCountLimit), "maxRowGroupSize" | "max_row_group_size" => Ok(GeneratedField::MaxRowGroupSize), @@ -5235,6 +5275,7 @@ impl<'de> serde::Deserialize<'de> for ParquetOptions { let mut bloom_filter_on_read__ = None; let mut bloom_filter_on_write__ = None; let mut schema_force_view_types__ = None; + let mut binary_as_string__ = None; let mut dictionary_page_size_limit__ = None; let mut data_page_row_count_limit__ = None; let mut max_row_group_size__ = None; @@ -5336,12 +5377,18 @@ impl<'de> serde::Deserialize<'de> for ParquetOptions { } bloom_filter_on_write__ = Some(map_.next_value()?); } - GeneratedField::schemaForceViewTypes => { + GeneratedField::SchemaForceViewTypes => { if schema_force_view_types__.is_some() { return Err(serde::de::Error::duplicate_field("schemaForceViewTypes")); } schema_force_view_types__ = Some(map_.next_value()?); } + GeneratedField::BinaryAsString => { + if binary_as_string__.is_some() { + return Err(serde::de::Error::duplicate_field("binaryAsString")); + } + binary_as_string__ = Some(map_.next_value()?); + } GeneratedField::DictionaryPageSizeLimit => { if dictionary_page_size_limit__.is_some() { return Err(serde::de::Error::duplicate_field("dictionaryPageSizeLimit")); @@ -5443,6 +5490,7 @@ impl<'de> serde::Deserialize<'de> for ParquetOptions { bloom_filter_on_read: bloom_filter_on_read__.unwrap_or_default(), bloom_filter_on_write: bloom_filter_on_write__.unwrap_or_default(), schema_force_view_types: schema_force_view_types__.unwrap_or_default(), + binary_as_string: binary_as_string__.unwrap_or_default(), dictionary_page_size_limit: dictionary_page_size_limit__.unwrap_or_default(), data_page_row_count_limit: data_page_row_count_limit__.unwrap_or_default(), max_row_group_size: max_row_group_size__.unwrap_or_default(), @@ -5867,6 +5915,7 @@ impl serde::Serialize for ScalarFixedSizeBinary { let mut struct_ser = serializer.serialize_struct("datafusion_common.ScalarFixedSizeBinary", len)?; if !self.values.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("values", pbjson::private::base64::encode(&self.values).as_str())?; } if self.length != 0 { @@ -5986,10 +6035,12 @@ impl serde::Serialize for ScalarNestedValue { let mut struct_ser = serializer.serialize_struct("datafusion_common.ScalarNestedValue", len)?; if !self.ipc_message.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("ipcMessage", pbjson::private::base64::encode(&self.ipc_message).as_str())?; } if !self.arrow_data.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("arrowData", pbjson::private::base64::encode(&self.arrow_data).as_str())?; } if let Some(v) = self.schema.as_ref() { @@ -6130,10 +6181,12 @@ impl serde::Serialize for scalar_nested_value::Dictionary { let mut struct_ser = serializer.serialize_struct("datafusion_common.ScalarNestedValue.Dictionary", len)?; if !self.ipc_message.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("ipcMessage", pbjson::private::base64::encode(&self.ipc_message).as_str())?; } if !self.arrow_data.is_empty() { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("arrowData", pbjson::private::base64::encode(&self.arrow_data).as_str())?; } struct_ser.end() @@ -6354,10 +6407,12 @@ impl serde::Serialize for ScalarTime64Value { match v { scalar_time64_value::Value::Time64MicrosecondValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("time64MicrosecondValue", ToString::to_string(&v).as_str())?; } scalar_time64_value::Value::Time64NanosecondValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("time64NanosecondValue", ToString::to_string(&v).as_str())?; } } @@ -6471,18 +6526,22 @@ impl serde::Serialize for ScalarTimestampValue { match v { scalar_timestamp_value::Value::TimeMicrosecondValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("timeMicrosecondValue", ToString::to_string(&v).as_str())?; } scalar_timestamp_value::Value::TimeNanosecondValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("timeNanosecondValue", ToString::to_string(&v).as_str())?; } scalar_timestamp_value::Value::TimeSecondValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("timeSecondValue", ToString::to_string(&v).as_str())?; } scalar_timestamp_value::Value::TimeMillisecondValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("timeMillisecondValue", ToString::to_string(&v).as_str())?; } } @@ -6645,6 +6704,7 @@ impl serde::Serialize for ScalarValue { } scalar_value::Value::Int64Value(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("int64Value", ToString::to_string(&v).as_str())?; } scalar_value::Value::Uint8Value(v) => { @@ -6658,6 +6718,7 @@ impl serde::Serialize for ScalarValue { } scalar_value::Value::Uint64Value(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("uint64Value", ToString::to_string(&v).as_str())?; } scalar_value::Value::Float32Value(v) => { @@ -6695,6 +6756,7 @@ impl serde::Serialize for ScalarValue { } scalar_value::Value::Date64Value(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("date64Value", ToString::to_string(&v).as_str())?; } scalar_value::Value::IntervalYearmonthValue(v) => { @@ -6702,18 +6764,22 @@ impl serde::Serialize for ScalarValue { } scalar_value::Value::DurationSecondValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("durationSecondValue", ToString::to_string(&v).as_str())?; } scalar_value::Value::DurationMillisecondValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("durationMillisecondValue", ToString::to_string(&v).as_str())?; } scalar_value::Value::DurationMicrosecondValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("durationMicrosecondValue", ToString::to_string(&v).as_str())?; } scalar_value::Value::DurationNanosecondValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("durationNanosecondValue", ToString::to_string(&v).as_str())?; } scalar_value::Value::TimestampValue(v) => { @@ -6724,14 +6790,17 @@ impl serde::Serialize for ScalarValue { } scalar_value::Value::BinaryValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("binaryValue", pbjson::private::base64::encode(&v).as_str())?; } scalar_value::Value::LargeBinaryValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("largeBinaryValue", pbjson::private::base64::encode(&v).as_str())?; } scalar_value::Value::BinaryViewValue(v) => { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("binaryViewValue", pbjson::private::base64::encode(&v).as_str())?; } scalar_value::Value::Time64Value(v) => { diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs index d6f982278d67..939a4b3c2cd2 100644 --- a/datafusion/proto-common/src/generated/prost.rs +++ b/datafusion/proto-common/src/generated/prost.rs @@ -1,11 +1,9 @@ // This file is @generated by prost-build. -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ColumnRelation { #[prost(string, tag = "1")] pub relation: ::prost::alloc::string::String, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Column { #[prost(string, tag = "1")] @@ -13,7 +11,6 @@ pub struct Column { #[prost(message, optional, tag = "2")] pub relation: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct DfField { #[prost(message, optional, tag = "1")] @@ -21,7 +18,6 @@ pub struct DfField { #[prost(message, optional, tag = "2")] pub qualifier: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct DfSchema { #[prost(message, repeated, tag = "1")] @@ -32,40 +28,33 @@ pub struct DfSchema { ::prost::alloc::string::String, >, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CsvFormat { #[prost(message, optional, tag = "5")] pub options: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ParquetFormat { #[prost(message, optional, tag = "2")] pub options: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct AvroFormat {} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct NdJsonFormat { #[prost(message, optional, tag = "1")] pub options: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PrimaryKeyConstraint { #[prost(uint64, repeated, tag = "1")] pub indices: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct UniqueConstraint { #[prost(uint64, repeated, tag = "1")] pub indices: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Constraint { #[prost(oneof = "constraint::ConstraintMode", tags = "1, 2")] @@ -73,7 +62,6 @@ pub struct Constraint { } /// Nested message and enum types in `Constraint`. pub mod constraint { - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum ConstraintMode { #[prost(message, tag = "1")] @@ -82,19 +70,15 @@ pub mod constraint { Unique(super::UniqueConstraint), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Constraints { #[prost(message, repeated, tag = "1")] pub constraints: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct AvroOptions {} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct ArrowOptions {} -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Schema { #[prost(message, repeated, tag = "1")] @@ -105,7 +89,6 @@ pub struct Schema { ::prost::alloc::string::String, >, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Field { /// name of the field @@ -128,7 +111,6 @@ pub struct Field { #[prost(bool, tag = "7")] pub dict_ordered: bool, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Timestamp { #[prost(enumeration = "TimeUnit", tag = "1")] @@ -136,29 +118,25 @@ pub struct Timestamp { #[prost(string, tag = "2")] pub timezone: ::prost::alloc::string::String, } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Decimal { #[prost(uint32, tag = "3")] pub precision: u32, #[prost(int32, tag = "4")] pub scale: i32, } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Decimal256Type { #[prost(uint32, tag = "3")] pub precision: u32, #[prost(int32, tag = "4")] pub scale: i32, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct List { #[prost(message, optional, boxed, tag = "1")] pub field_type: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FixedSizeList { #[prost(message, optional, boxed, tag = "1")] @@ -166,7 +144,6 @@ pub struct FixedSizeList { #[prost(int32, tag = "2")] pub list_size: i32, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Dictionary { #[prost(message, optional, boxed, tag = "1")] @@ -174,13 +151,11 @@ pub struct Dictionary { #[prost(message, optional, boxed, tag = "2")] pub value: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Struct { #[prost(message, repeated, tag = "1")] pub sub_field_types: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Map { #[prost(message, optional, boxed, tag = "1")] @@ -188,7 +163,6 @@ pub struct Map { #[prost(bool, tag = "2")] pub keys_sorted: bool, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Union { #[prost(message, repeated, tag = "1")] @@ -199,7 +173,6 @@ pub struct Union { pub type_ids: ::prost::alloc::vec::Vec, } /// Used for List/FixedSizeList/LargeList/Struct/Map -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarNestedValue { #[prost(bytes = "vec", tag = "1")] @@ -213,7 +186,6 @@ pub struct ScalarNestedValue { } /// Nested message and enum types in `ScalarNestedValue`. pub mod scalar_nested_value { - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Dictionary { #[prost(bytes = "vec", tag = "1")] @@ -222,16 +194,14 @@ pub mod scalar_nested_value { pub arrow_data: ::prost::alloc::vec::Vec, } } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct ScalarTime32Value { #[prost(oneof = "scalar_time32_value::Value", tags = "1, 2")] pub value: ::core::option::Option, } /// Nested message and enum types in `ScalarTime32Value`. pub mod scalar_time32_value { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum Value { #[prost(int32, tag = "1")] Time32SecondValue(i32), @@ -239,16 +209,14 @@ pub mod scalar_time32_value { Time32MillisecondValue(i32), } } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct ScalarTime64Value { #[prost(oneof = "scalar_time64_value::Value", tags = "1, 2")] pub value: ::core::option::Option, } /// Nested message and enum types in `ScalarTime64Value`. pub mod scalar_time64_value { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum Value { #[prost(int64, tag = "1")] Time64MicrosecondValue(i64), @@ -256,7 +224,6 @@ pub mod scalar_time64_value { Time64NanosecondValue(i64), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarTimestampValue { #[prost(string, tag = "5")] @@ -266,8 +233,7 @@ pub struct ScalarTimestampValue { } /// Nested message and enum types in `ScalarTimestampValue`. pub mod scalar_timestamp_value { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum Value { #[prost(int64, tag = "1")] TimeMicrosecondValue(i64), @@ -279,7 +245,6 @@ pub mod scalar_timestamp_value { TimeMillisecondValue(i64), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarDictionaryValue { #[prost(message, optional, tag = "1")] @@ -287,16 +252,14 @@ pub struct ScalarDictionaryValue { #[prost(message, optional, boxed, tag = "2")] pub value: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct IntervalDayTimeValue { #[prost(int32, tag = "1")] pub days: i32, #[prost(int32, tag = "2")] pub milliseconds: i32, } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct IntervalMonthDayNanoValue { #[prost(int32, tag = "1")] pub months: i32, @@ -305,7 +268,6 @@ pub struct IntervalMonthDayNanoValue { #[prost(int64, tag = "3")] pub nanos: i64, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct UnionField { #[prost(int32, tag = "1")] @@ -313,7 +275,6 @@ pub struct UnionField { #[prost(message, optional, tag = "2")] pub field: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct UnionValue { /// Note that a null union value must have one or more fields, so we @@ -327,7 +288,6 @@ pub struct UnionValue { #[prost(enumeration = "UnionMode", tag = "4")] pub mode: i32, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarFixedSizeBinary { #[prost(bytes = "vec", tag = "1")] @@ -335,7 +295,6 @@ pub struct ScalarFixedSizeBinary { #[prost(int32, tag = "2")] pub length: i32, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarValue { #[prost( @@ -346,7 +305,6 @@ pub struct ScalarValue { } /// Nested message and enum types in `ScalarValue`. pub mod scalar_value { - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum Value { /// was PrimitiveScalarType null_value = 19; @@ -434,7 +392,6 @@ pub mod scalar_value { UnionValue(::prost::alloc::boxed::Box), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Decimal128 { #[prost(bytes = "vec", tag = "1")] @@ -444,7 +401,6 @@ pub struct Decimal128 { #[prost(int64, tag = "3")] pub s: i64, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Decimal256 { #[prost(bytes = "vec", tag = "1")] @@ -455,7 +411,6 @@ pub struct Decimal256 { pub s: i64, } /// Serialized data type -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ArrowType { #[prost( @@ -466,7 +421,6 @@ pub struct ArrowType { } /// Nested message and enum types in `ArrowType`. pub mod arrow_type { - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum ArrowTypeEnum { /// arrow::Type::NA @@ -557,16 +511,13 @@ pub mod arrow_type { /// i32 Two = 2; /// } /// } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct EmptyMessage {} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct JsonWriterOptions { #[prost(enumeration = "CompressionTypeVariant", tag = "1")] pub compression: i32, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CsvWriterOptions { /// Compression type @@ -604,7 +555,6 @@ pub struct CsvWriterOptions { pub double_quote: bool, } /// Options controlling CSV format -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CsvOptions { /// Indicates if the CSV has a header row @@ -657,8 +607,7 @@ pub struct CsvOptions { pub terminator: ::prost::alloc::vec::Vec, } /// Options controlling CSV format -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct JsonOptions { /// Compression type #[prost(enumeration = "CompressionTypeVariant", tag = "1")] @@ -667,7 +616,6 @@ pub struct JsonOptions { #[prost(uint64, tag = "2")] pub schema_infer_max_rec: u64, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct TableParquetOptions { #[prost(message, optional, tag = "1")] @@ -680,7 +628,6 @@ pub struct TableParquetOptions { ::prost::alloc::string::String, >, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ParquetColumnSpecificOptions { #[prost(string, tag = "1")] @@ -688,7 +635,6 @@ pub struct ParquetColumnSpecificOptions { #[prost(message, optional, tag = "2")] pub options: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ParquetColumnOptions { #[prost(oneof = "parquet_column_options::BloomFilterEnabledOpt", tags = "1")] @@ -722,56 +668,47 @@ pub struct ParquetColumnOptions { } /// Nested message and enum types in `ParquetColumnOptions`. pub mod parquet_column_options { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum BloomFilterEnabledOpt { #[prost(bool, tag = "1")] BloomFilterEnabled(bool), } - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum EncodingOpt { #[prost(string, tag = "2")] Encoding(::prost::alloc::string::String), } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum DictionaryEnabledOpt { #[prost(bool, tag = "3")] DictionaryEnabled(bool), } - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum CompressionOpt { #[prost(string, tag = "4")] Compression(::prost::alloc::string::String), } - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum StatisticsEnabledOpt { #[prost(string, tag = "5")] StatisticsEnabled(::prost::alloc::string::String), } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum BloomFilterFppOpt { #[prost(double, tag = "6")] BloomFilterFpp(f64), } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum BloomFilterNdvOpt { #[prost(uint64, tag = "7")] BloomFilterNdv(u64), } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum MaxStatisticsSizeOpt { #[prost(uint32, tag = "8")] MaxStatisticsSize(u32), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ParquetOptions { /// Regular fields @@ -820,6 +757,9 @@ pub struct ParquetOptions { /// default = false #[prost(bool, tag = "28")] pub schema_force_view_types: bool, + /// default = false + #[prost(bool, tag = "29")] + pub binary_as_string: bool, #[prost(uint64, tag = "12")] pub dictionary_page_size_limit: u64, #[prost(uint64, tag = "18")] @@ -859,62 +799,52 @@ pub struct ParquetOptions { } /// Nested message and enum types in `ParquetOptions`. pub mod parquet_options { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum MetadataSizeHintOpt { #[prost(uint64, tag = "4")] MetadataSizeHint(u64), } - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum CompressionOpt { #[prost(string, tag = "10")] Compression(::prost::alloc::string::String), } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum DictionaryEnabledOpt { #[prost(bool, tag = "11")] DictionaryEnabled(bool), } - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum StatisticsEnabledOpt { #[prost(string, tag = "13")] StatisticsEnabled(::prost::alloc::string::String), } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum MaxStatisticsSizeOpt { #[prost(uint64, tag = "14")] MaxStatisticsSize(u64), } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum ColumnIndexTruncateLengthOpt { #[prost(uint64, tag = "17")] ColumnIndexTruncateLength(u64), } - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum EncodingOpt { #[prost(string, tag = "19")] Encoding(::prost::alloc::string::String), } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum BloomFilterFppOpt { #[prost(double, tag = "21")] BloomFilterFpp(f64), } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] pub enum BloomFilterNdvOpt { #[prost(uint64, tag = "22")] BloomFilterNdv(u64), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Precision { #[prost(enumeration = "PrecisionInfo", tag = "1")] @@ -922,7 +852,6 @@ pub struct Precision { #[prost(message, optional, tag = "2")] pub val: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Statistics { #[prost(message, optional, tag = "1")] @@ -932,7 +861,6 @@ pub struct Statistics { #[prost(message, repeated, tag = "3")] pub column_stats: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ColumnStats { #[prost(message, optional, tag = "1")] @@ -963,14 +891,14 @@ impl JoinType { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - JoinType::Inner => "INNER", - JoinType::Left => "LEFT", - JoinType::Right => "RIGHT", - JoinType::Full => "FULL", - JoinType::Leftsemi => "LEFTSEMI", - JoinType::Leftanti => "LEFTANTI", - JoinType::Rightsemi => "RIGHTSEMI", - JoinType::Rightanti => "RIGHTANTI", + Self::Inner => "INNER", + Self::Left => "LEFT", + Self::Right => "RIGHT", + Self::Full => "FULL", + Self::Leftsemi => "LEFTSEMI", + Self::Leftanti => "LEFTANTI", + Self::Rightsemi => "RIGHTSEMI", + Self::Rightanti => "RIGHTANTI", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -1001,8 +929,8 @@ impl JoinConstraint { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - JoinConstraint::On => "ON", - JoinConstraint::Using => "USING", + Self::On => "ON", + Self::Using => "USING", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -1029,10 +957,10 @@ impl TimeUnit { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - TimeUnit::Second => "Second", - TimeUnit::Millisecond => "Millisecond", - TimeUnit::Microsecond => "Microsecond", - TimeUnit::Nanosecond => "Nanosecond", + Self::Second => "Second", + Self::Millisecond => "Millisecond", + Self::Microsecond => "Microsecond", + Self::Nanosecond => "Nanosecond", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -1060,9 +988,9 @@ impl IntervalUnit { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - IntervalUnit::YearMonth => "YearMonth", - IntervalUnit::DayTime => "DayTime", - IntervalUnit::MonthDayNano => "MonthDayNano", + Self::YearMonth => "YearMonth", + Self::DayTime => "DayTime", + Self::MonthDayNano => "MonthDayNano", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -1088,8 +1016,8 @@ impl UnionMode { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - UnionMode::Sparse => "sparse", - UnionMode::Dense => "dense", + Self::Sparse => "sparse", + Self::Dense => "dense", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -1117,11 +1045,11 @@ impl CompressionTypeVariant { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - CompressionTypeVariant::Gzip => "GZIP", - CompressionTypeVariant::Bzip2 => "BZIP2", - CompressionTypeVariant::Xz => "XZ", - CompressionTypeVariant::Zstd => "ZSTD", - CompressionTypeVariant::Uncompressed => "UNCOMPRESSED", + Self::Gzip => "GZIP", + Self::Bzip2 => "BZIP2", + Self::Xz => "XZ", + Self::Zstd => "ZSTD", + Self::Uncompressed => "UNCOMPRESSED", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -1149,8 +1077,8 @@ impl JoinSide { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - JoinSide::LeftSide => "LEFT_SIDE", - JoinSide::RightSide => "RIGHT_SIDE", + Self::LeftSide => "LEFT_SIDE", + Self::RightSide => "RIGHT_SIDE", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -1176,9 +1104,9 @@ impl PrecisionInfo { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - PrecisionInfo::Exact => "EXACT", - PrecisionInfo::Inexact => "INEXACT", - PrecisionInfo::Absent => "ABSENT", + Self::Exact => "EXACT", + Self::Inexact => "INEXACT", + Self::Absent => "ABSENT", } } /// Creates an enum from field names used in the ProtoBuf definition. diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index ebb53ae7577c..f9b8973e2d41 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -831,6 +831,7 @@ impl TryFrom<&ParquetOptions> for protobuf::ParquetOptions { maximum_parallel_row_group_writers: value.maximum_parallel_row_group_writers as u64, maximum_buffered_record_batches_per_stream: value.maximum_buffered_record_batches_per_stream as u64, schema_force_view_types: value.schema_force_view_types, + binary_as_string: value.binary_as_string, }) } } diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs index 16de2c777241..939a4b3c2cd2 100644 --- a/datafusion/proto/src/generated/datafusion_proto_common.rs +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -757,6 +757,9 @@ pub struct ParquetOptions { /// default = false #[prost(bool, tag = "28")] pub schema_force_view_types: bool, + /// default = false + #[prost(bool, tag = "29")] + pub binary_as_string: bool, #[prost(uint64, tag = "12")] pub dictionary_page_size_limit: u64, #[prost(uint64, tag = "18")] diff --git a/datafusion/proto/src/logical_plan/file_formats.rs b/datafusion/proto/src/logical_plan/file_formats.rs index 98034e3082af..d0f82ecac62c 100644 --- a/datafusion/proto/src/logical_plan/file_formats.rs +++ b/datafusion/proto/src/logical_plan/file_formats.rs @@ -403,6 +403,7 @@ impl TableParquetOptionsProto { maximum_parallel_row_group_writers: global_options.global.maximum_parallel_row_group_writers as u64, maximum_buffered_record_batches_per_stream: global_options.global.maximum_buffered_record_batches_per_stream as u64, schema_force_view_types: global_options.global.schema_force_view_types, + binary_as_string: global_options.global.binary_as_string, }), column_specific_options: column_specific_options.into_iter().map(|(column_name, options)| { ParquetColumnSpecificOptions { @@ -493,6 +494,7 @@ impl From<&ParquetOptionsProto> for ParquetOptions { maximum_parallel_row_group_writers: proto.maximum_parallel_row_group_writers as usize, maximum_buffered_record_batches_per_stream: proto.maximum_buffered_record_batches_per_stream as usize, schema_force_view_types: proto.schema_force_view_types, + binary_as_string: proto.binary_as_string, } } } diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 57bf029a63c1..3630f6c36595 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -180,6 +180,7 @@ datafusion.execution.max_buffered_batches_per_output_file 2 datafusion.execution.meta_fetch_concurrency 32 datafusion.execution.minimum_parallel_output_files 4 datafusion.execution.parquet.allow_single_file_parallelism true +datafusion.execution.parquet.binary_as_string false datafusion.execution.parquet.bloom_filter_fpp NULL datafusion.execution.parquet.bloom_filter_ndv NULL datafusion.execution.parquet.bloom_filter_on_read true @@ -271,6 +272,7 @@ datafusion.execution.max_buffered_batches_per_output_file 2 This is the maximum datafusion.execution.meta_fetch_concurrency 32 Number of files to read in parallel when inferring schema and statistics datafusion.execution.minimum_parallel_output_files 4 Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. datafusion.execution.parquet.allow_single_file_parallelism true (writing) Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. +datafusion.execution.parquet.binary_as_string false (reading) If true, parquet reader will read columns of `Binary/LargeBinary` with `Utf8`, and `BinaryView` with `Utf8View`. Parquet files generated by some legacy writers do not correctly set the UTF8 flag for strings, causing string columns to be loaded as BLOB instead. datafusion.execution.parquet.bloom_filter_fpp NULL (writing) Sets bloom filter false positive probability. If NULL, uses default parquet writer setting datafusion.execution.parquet.bloom_filter_ndv NULL (writing) Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting datafusion.execution.parquet.bloom_filter_on_read true (writing) Use any available bloom filters when reading parquet files diff --git a/datafusion/sqllogictest/test_files/parquet.slt b/datafusion/sqllogictest/test_files/parquet.slt index f8b163adc796..bf68a1851137 100644 --- a/datafusion/sqllogictest/test_files/parquet.slt +++ b/datafusion/sqllogictest/test_files/parquet.slt @@ -348,3 +348,204 @@ DROP TABLE list_columns; # Clean up statement ok DROP TABLE listing_table; + +### Tests for binary_ar_string + +# This scenario models the case where a column has been stored in parquet +# "binary" column (without a String logical type annotation) +# this is the case with the `hits_partitioned` ClickBench datasets +# see https://github.com/apache/datafusion/issues/12788 + +## Create a table with a binary column + +query I +COPY ( + SELECT + arrow_cast(string_col, 'Binary') as binary_col, + arrow_cast(string_col, 'LargeBinary') as largebinary_col, + arrow_cast(string_col, 'BinaryView') as binaryview_col + FROM src_table + ) +TO 'test_files/scratch/parquet/binary_as_string.parquet' +STORED AS PARQUET; +---- +9 + +# Test 1: Read table with default options +statement ok +CREATE EXTERNAL TABLE binary_as_string_default +STORED AS PARQUET LOCATION 'test_files/scratch/parquet/binary_as_string.parquet' + +# NB the data is read and displayed as binary +query T?T?T? +select + arrow_typeof(binary_col), binary_col, + arrow_typeof(largebinary_col), largebinary_col, + arrow_typeof(binaryview_col), binaryview_col + FROM binary_as_string_default; +---- +Binary 616161 Binary 616161 Binary 616161 +Binary 626262 Binary 626262 Binary 626262 +Binary 636363 Binary 636363 Binary 636363 +Binary 646464 Binary 646464 Binary 646464 +Binary 656565 Binary 656565 Binary 656565 +Binary 666666 Binary 666666 Binary 666666 +Binary 676767 Binary 676767 Binary 676767 +Binary 686868 Binary 686868 Binary 686868 +Binary 696969 Binary 696969 Binary 696969 + +# Run an explain plan to show the cast happens in the plan (a CAST is needed for the predicates) +query TT +EXPLAIN + SELECT binary_col, largebinary_col, binaryview_col + FROM binary_as_string_default + WHERE + binary_col LIKE '%a%' AND + largebinary_col LIKE '%a%' AND + binaryview_col LIKE '%a%'; +---- +logical_plan +01)Filter: CAST(binary_as_string_default.binary_col AS Utf8) LIKE Utf8("%a%") AND CAST(binary_as_string_default.largebinary_col AS Utf8) LIKE Utf8("%a%") AND CAST(binary_as_string_default.binaryview_col AS Utf8) LIKE Utf8("%a%") +02)--TableScan: binary_as_string_default projection=[binary_col, largebinary_col, binaryview_col], partial_filters=[CAST(binary_as_string_default.binary_col AS Utf8) LIKE Utf8("%a%"), CAST(binary_as_string_default.largebinary_col AS Utf8) LIKE Utf8("%a%"), CAST(binary_as_string_default.binaryview_col AS Utf8) LIKE Utf8("%a%")] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: CAST(binary_col@0 AS Utf8) LIKE %a% AND CAST(largebinary_col@1 AS Utf8) LIKE %a% AND CAST(binaryview_col@2 AS Utf8) LIKE %a% +03)----RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +04)------ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/binary_as_string.parquet]]}, projection=[binary_col, largebinary_col, binaryview_col], predicate=CAST(binary_col@0 AS Utf8) LIKE %a% AND CAST(largebinary_col@1 AS Utf8) LIKE %a% AND CAST(binaryview_col@2 AS Utf8) LIKE %a% + + +statement ok +DROP TABLE binary_as_string_default; + +## Test 2: Read table using the binary_as_string option + +statement ok +CREATE EXTERNAL TABLE binary_as_string_option +STORED AS PARQUET LOCATION 'test_files/scratch/parquet/binary_as_string.parquet' +OPTIONS ('binary_as_string' 'true'); + +# NB the data is read and displayed as string +query TTTTTT +select + arrow_typeof(binary_col), binary_col, + arrow_typeof(largebinary_col), largebinary_col, + arrow_typeof(binaryview_col), binaryview_col + FROM binary_as_string_option; +---- +Utf8 aaa Utf8 aaa Utf8 aaa +Utf8 bbb Utf8 bbb Utf8 bbb +Utf8 ccc Utf8 ccc Utf8 ccc +Utf8 ddd Utf8 ddd Utf8 ddd +Utf8 eee Utf8 eee Utf8 eee +Utf8 fff Utf8 fff Utf8 fff +Utf8 ggg Utf8 ggg Utf8 ggg +Utf8 hhh Utf8 hhh Utf8 hhh +Utf8 iii Utf8 iii Utf8 iii + +# Run an explain plan to show the cast happens in the plan (there should be no casts) +query TT +EXPLAIN + SELECT binary_col, largebinary_col, binaryview_col + FROM binary_as_string_option + WHERE + binary_col LIKE '%a%' AND + largebinary_col LIKE '%a%' AND + binaryview_col LIKE '%a%'; +---- +logical_plan +01)Filter: binary_as_string_option.binary_col LIKE Utf8("%a%") AND binary_as_string_option.largebinary_col LIKE Utf8("%a%") AND binary_as_string_option.binaryview_col LIKE Utf8("%a%") +02)--TableScan: binary_as_string_option projection=[binary_col, largebinary_col, binaryview_col], partial_filters=[binary_as_string_option.binary_col LIKE Utf8("%a%"), binary_as_string_option.largebinary_col LIKE Utf8("%a%"), binary_as_string_option.binaryview_col LIKE Utf8("%a%")] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: binary_col@0 LIKE %a% AND largebinary_col@1 LIKE %a% AND binaryview_col@2 LIKE %a% +03)----RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +04)------ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/binary_as_string.parquet]]}, projection=[binary_col, largebinary_col, binaryview_col], predicate=binary_col@0 LIKE %a% AND largebinary_col@1 LIKE %a% AND binaryview_col@2 LIKE %a% + + +statement ok +DROP TABLE binary_as_string_option; + +## Test 3: Read table with binary_as_string option AND schema_force_view_types + +statement ok +CREATE EXTERNAL TABLE binary_as_string_both +STORED AS PARQUET LOCATION 'test_files/scratch/parquet/binary_as_string.parquet' +OPTIONS ( + 'binary_as_string' 'true', + 'schema_force_view_types' 'true' +); + +# NB the data is read and displayed a StringView +query TTTTTT +select + arrow_typeof(binary_col), binary_col, + arrow_typeof(largebinary_col), largebinary_col, + arrow_typeof(binaryview_col), binaryview_col + FROM binary_as_string_both; +---- +Utf8View aaa Utf8View aaa Utf8View aaa +Utf8View bbb Utf8View bbb Utf8View bbb +Utf8View ccc Utf8View ccc Utf8View ccc +Utf8View ddd Utf8View ddd Utf8View ddd +Utf8View eee Utf8View eee Utf8View eee +Utf8View fff Utf8View fff Utf8View fff +Utf8View ggg Utf8View ggg Utf8View ggg +Utf8View hhh Utf8View hhh Utf8View hhh +Utf8View iii Utf8View iii Utf8View iii + +# Run an explain plan to show the cast happens in the plan (there should be no casts) +query TT +EXPLAIN + SELECT binary_col, largebinary_col, binaryview_col + FROM binary_as_string_both + WHERE + binary_col LIKE '%a%' AND + largebinary_col LIKE '%a%' AND + binaryview_col LIKE '%a%'; +---- +logical_plan +01)Filter: binary_as_string_both.binary_col LIKE Utf8View("%a%") AND binary_as_string_both.largebinary_col LIKE Utf8View("%a%") AND binary_as_string_both.binaryview_col LIKE Utf8View("%a%") +02)--TableScan: binary_as_string_both projection=[binary_col, largebinary_col, binaryview_col], partial_filters=[binary_as_string_both.binary_col LIKE Utf8View("%a%"), binary_as_string_both.largebinary_col LIKE Utf8View("%a%"), binary_as_string_both.binaryview_col LIKE Utf8View("%a%")] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: binary_col@0 LIKE %a% AND largebinary_col@1 LIKE %a% AND binaryview_col@2 LIKE %a% +03)----RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +04)------ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/binary_as_string.parquet]]}, projection=[binary_col, largebinary_col, binaryview_col], predicate=binary_col@0 LIKE %a% AND largebinary_col@1 LIKE %a% AND binaryview_col@2 LIKE %a% + + +statement ok +drop table binary_as_string_both; + +# Read a parquet file with binary data in a FixedSizeBinary column + +# by default, the data is read as binary +statement ok +CREATE EXTERNAL TABLE test_non_utf8_binary +STORED AS PARQUET LOCATION '../core/tests/data/test_binary.parquet'; + +query T? +SELECT arrow_typeof(ids), ids FROM test_non_utf8_binary LIMIT 3; +---- +FixedSizeBinary(16) 008c7196f68089ab692e4739c5fd16b5 +FixedSizeBinary(16) 00a51a7bc5ff8eb1627f8f3dc959dce8 +FixedSizeBinary(16) 0166ce1d46129ad104fa4990c6057c91 + +statement ok +DROP TABLE test_non_utf8_binary; + + +# even with the binary_as_string option set, the data is read as binary +statement ok +CREATE EXTERNAL TABLE test_non_utf8_binary +STORED AS PARQUET LOCATION '../core/tests/data/test_binary.parquet' +OPTIONS ('binary_as_string' 'true'); + +query T? +SELECT arrow_typeof(ids), ids FROM test_non_utf8_binary LIMIT 3 +---- +FixedSizeBinary(16) 008c7196f68089ab692e4739c5fd16b5 +FixedSizeBinary(16) 00a51a7bc5ff8eb1627f8f3dc959dce8 +FixedSizeBinary(16) 0166ce1d46129ad104fa4990c6057c91 + +statement ok +DROP TABLE test_non_utf8_binary; diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 10917932482c..91a2e8b4389a 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -57,6 +57,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.parquet.pushdown_filters | false | (reading) If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded. This optimization is sometimes called "late materialization". | | datafusion.execution.parquet.reorder_filters | false | (reading) If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query | | datafusion.execution.parquet.schema_force_view_types | false | (reading) If true, parquet reader will read columns of `Utf8/Utf8Large` with `Utf8View`, and `Binary/BinaryLarge` with `BinaryView`. | +| datafusion.execution.parquet.binary_as_string | false | (reading) If true, parquet reader will read columns of `Binary/LargeBinary` with `Utf8`, and `BinaryView` with `Utf8View`. Parquet files generated by some legacy writers do not correctly set the UTF8 flag for strings, causing string columns to be loaded as BLOB instead. | | datafusion.execution.parquet.data_pagesize_limit | 1048576 | (writing) Sets best effort maximum size of data page in bytes | | datafusion.execution.parquet.write_batch_size | 1024 | (writing) Sets write_batch_size in bytes | | datafusion.execution.parquet.writer_version | 1.0 | (writing) Sets parquet writer version valid values are "1.0" and "2.0" | From 02b969381db1f3765676029aa47ddf5e54ccdf4f Mon Sep 17 00:00:00 2001 From: Jagdish Parihar Date: Fri, 25 Oct 2024 18:42:58 +0530 Subject: [PATCH 074/110] Convert `ntile` builtIn function to UDWF (#13040) * converting to ntile udwf * updated the window functions documentation file * wip: update the ntile udwf function * fix the roundtrip_logical_plan.rs * removed builtIn ntile function * fixed field name issue * fixing the return type of ntile udwf * error if UInt64 conversion fails * handling if null is found * handling if value is zero or less than zero * removed unused import * updated prost.rs file * removed dead code * fixed clippy error * added inner doc comment * minor fixes and added roundtrip logical plan test * removed parse_expr in ntile --- .../expr/src/built_in_window_function.rs | 19 -- datafusion/expr/src/expr.rs | 11 +- datafusion/expr/src/window_function.rs | 5 - datafusion/functions-window/src/lib.rs | 4 +- datafusion/functions-window/src/ntile.rs | 168 ++++++++++++++++++ datafusion/functions-window/src/utils.rs | 12 ++ .../physical-expr/src/expressions/mod.rs | 1 - datafusion/physical-expr/src/window/mod.rs | 1 - datafusion/physical-expr/src/window/ntile.rs | 111 ------------ datafusion/physical-plan/src/windows/mod.rs | 59 +----- datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/generated/pbjson.rs | 3 - datafusion/proto/src/generated/prost.rs | 4 +- .../proto/src/logical_plan/from_proto.rs | 1 - datafusion/proto/src/logical_plan/to_proto.rs | 1 - .../proto/src/physical_plan/to_proto.rs | 45 ++--- .../tests/cases/roundtrip_logical_plan.rs | 3 +- .../source/user-guide/sql/window_functions.md | 34 ---- .../user-guide/sql/window_functions_new.md | 13 ++ 19 files changed, 221 insertions(+), 276 deletions(-) create mode 100644 datafusion/functions-window/src/ntile.rs delete mode 100644 datafusion/physical-expr/src/window/ntile.rs diff --git a/datafusion/expr/src/built_in_window_function.rs b/datafusion/expr/src/built_in_window_function.rs index 36916a6b594f..ab41395ad371 100644 --- a/datafusion/expr/src/built_in_window_function.rs +++ b/datafusion/expr/src/built_in_window_function.rs @@ -40,8 +40,6 @@ impl fmt::Display for BuiltInWindowFunction { /// [Window Function]: https://en.wikipedia.org/wiki/Window_function_(SQL) #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)] pub enum BuiltInWindowFunction { - /// Integer ranging from 1 to the argument value, dividing the partition as equally as possible - Ntile, /// returns value evaluated at the row that is the first row of the window frame FirstValue, /// Returns value evaluated at the row that is the last row of the window frame @@ -54,7 +52,6 @@ impl BuiltInWindowFunction { pub fn name(&self) -> &str { use BuiltInWindowFunction::*; match self { - Ntile => "NTILE", FirstValue => "first_value", LastValue => "last_value", NthValue => "NTH_VALUE", @@ -66,7 +63,6 @@ impl FromStr for BuiltInWindowFunction { type Err = DataFusionError; fn from_str(name: &str) -> Result { Ok(match name.to_uppercase().as_str() { - "NTILE" => BuiltInWindowFunction::Ntile, "FIRST_VALUE" => BuiltInWindowFunction::FirstValue, "LAST_VALUE" => BuiltInWindowFunction::LastValue, "NTH_VALUE" => BuiltInWindowFunction::NthValue, @@ -97,7 +93,6 @@ impl BuiltInWindowFunction { })?; match self { - BuiltInWindowFunction::Ntile => Ok(DataType::UInt64), BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue | BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()), @@ -111,20 +106,6 @@ impl BuiltInWindowFunction { BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => { Signature::any(1, Volatility::Immutable) } - BuiltInWindowFunction::Ntile => Signature::uniform( - 1, - vec![ - DataType::UInt64, - DataType::UInt32, - DataType::UInt16, - DataType::UInt8, - DataType::Int64, - DataType::Int32, - DataType::Int16, - DataType::Int8, - ], - Volatility::Immutable, - ), BuiltInWindowFunction::NthValue => Signature::any(2, Volatility::Immutable), } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 7fadf6391bf3..4d73c2a04486 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -2567,18 +2567,9 @@ mod test { Ok(()) } - #[test] - fn test_ntile_return_type() -> Result<()> { - let fun = find_df_window_func("ntile").unwrap(); - let observed = fun.return_type(&[DataType::Int16], &[true], "")?; - assert_eq!(DataType::UInt64, observed); - - Ok(()) - } - #[test] fn test_window_function_case_insensitive() -> Result<()> { - let names = vec!["ntile", "first_value", "last_value", "nth_value"]; + let names = vec!["first_value", "last_value", "nth_value"]; for name in names { let fun = find_df_window_func(name).unwrap(); let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap(); diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index c13a028e4a30..be2b6575e2e9 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -17,11 +17,6 @@ use crate::{expr::WindowFunction, BuiltInWindowFunction, Expr, Literal}; -/// Create an expression to represent the `ntile` window function -pub fn ntile(arg: Expr) -> Expr { - Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::Ntile, vec![arg])) -} - /// Create an expression to represent the `nth_value` window function pub fn nth_value(arg: Expr, n: i64) -> Expr { Expr::WindowFunction(WindowFunction::new( diff --git a/datafusion/functions-window/src/lib.rs b/datafusion/functions-window/src/lib.rs index 13a77977d579..ff8542838df9 100644 --- a/datafusion/functions-window/src/lib.rs +++ b/datafusion/functions-window/src/lib.rs @@ -34,7 +34,7 @@ pub mod macros; pub mod cume_dist; pub mod lead_lag; - +pub mod ntile; pub mod rank; pub mod row_number; mod utils; @@ -44,6 +44,7 @@ pub mod expr_fn { pub use super::cume_dist::cume_dist; pub use super::lead_lag::lag; pub use super::lead_lag::lead; + pub use super::ntile::ntile; pub use super::rank::{dense_rank, percent_rank, rank}; pub use super::row_number::row_number; } @@ -58,6 +59,7 @@ pub fn all_default_window_functions() -> Vec> { rank::rank_udwf(), rank::dense_rank_udwf(), rank::percent_rank_udwf(), + ntile::ntile_udwf(), ] } /// Registers all enabled packages with a [`FunctionRegistry`] diff --git a/datafusion/functions-window/src/ntile.rs b/datafusion/functions-window/src/ntile.rs new file mode 100644 index 000000000000..b0a7241f24cd --- /dev/null +++ b/datafusion/functions-window/src/ntile.rs @@ -0,0 +1,168 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! `ntile` window function implementation + +use std::any::Any; +use std::fmt::Debug; +use std::sync::{Arc, OnceLock}; + +use crate::utils::{ + get_scalar_value_from_args, get_signed_integer, get_unsigned_integer, +}; +use datafusion_common::arrow::array::{ArrayRef, UInt64Array}; +use datafusion_common::arrow::datatypes::{DataType, Field}; +use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_expr::window_doc_sections::DOC_SECTION_RANKING; +use datafusion_expr::{ + Documentation, Expr, PartitionEvaluator, Signature, Volatility, WindowUDFImpl, +}; +use datafusion_functions_window_common::field; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use field::WindowUDFFieldArgs; + +get_or_init_udwf!( + Ntile, + ntile, + "integer ranging from 1 to the argument value, dividing the partition as equally as possible" +); + +pub fn ntile(arg: Expr) -> Expr { + ntile_udwf().call(vec![arg]) +} + +#[derive(Debug)] +pub struct Ntile { + signature: Signature, +} + +impl Ntile { + /// Create a new `ntile` function + pub fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![ + DataType::UInt64, + DataType::UInt32, + DataType::UInt16, + DataType::UInt8, + DataType::Int64, + DataType::Int32, + DataType::Int16, + DataType::Int8, + ], + Volatility::Immutable, + ), + } + } +} + +impl Default for Ntile { + fn default() -> Self { + Self::new() + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_ntile_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_RANKING) + .with_description( + "Integer ranging from 1 to the argument value, dividing the partition as equally as possible", + ) + .with_syntax_example("ntile(expression)") + .with_argument("expression","An integer describing the number groups the partition should be split into") + .build() + .unwrap() + }) +} + +impl WindowUDFImpl for Ntile { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "ntile" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + let scalar_n = + get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 0)? + .ok_or_else(|| { + DataFusionError::Execution( + "NTILE requires a positive integer".to_string(), + ) + })?; + + if scalar_n.is_null() { + return exec_err!("NTILE requires a positive integer, but finds NULL"); + } + + if scalar_n.is_unsigned() { + let n = get_unsigned_integer(scalar_n)?; + Ok(Box::new(NtileEvaluator { n })) + } else { + let n: i64 = get_signed_integer(scalar_n)?; + if n <= 0 { + return exec_err!("NTILE requires a positive integer"); + } + Ok(Box::new(NtileEvaluator { n: n as u64 })) + } + } + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + let nullable = false; + + Ok(Field::new(field_args.name(), DataType::UInt64, nullable)) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_ntile_doc()) + } +} + +#[derive(Debug)] +struct NtileEvaluator { + n: u64, +} + +impl PartitionEvaluator for NtileEvaluator { + fn evaluate_all( + &mut self, + _values: &[ArrayRef], + num_rows: usize, + ) -> Result { + let num_rows = num_rows as u64; + let mut vec: Vec = Vec::new(); + let n = u64::min(self.n, num_rows); + for i in 0..num_rows { + let res = i * n / num_rows; + vec.push(res + 1) + } + Ok(Arc::new(UInt64Array::from(vec))) + } +} diff --git a/datafusion/functions-window/src/utils.rs b/datafusion/functions-window/src/utils.rs index 69f68aa78f2c..3f8061dbea3e 100644 --- a/datafusion/functions-window/src/utils.rs +++ b/datafusion/functions-window/src/utils.rs @@ -51,3 +51,15 @@ pub(crate) fn get_scalar_value_from_args( None }) } + +pub(crate) fn get_unsigned_integer(value: ScalarValue) -> Result { + if value.is_null() { + return Ok(0); + } + + if !value.data_type().is_integer() { + return exec_err!("Expected an integer value"); + } + + value.cast_to(&DataType::UInt64)?.try_into() +} diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 63047f6929c1..7d71bd9ff17b 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -36,7 +36,6 @@ mod unknown_column; /// Module with some convenient methods used in expression building pub use crate::aggregate::stats::StatsType; pub use crate::window::nth_value::NthValue; -pub use crate::window::ntile::Ntile; pub use crate::PhysicalSortExpr; pub use binary::{binary, similar_to, BinaryExpr}; diff --git a/datafusion/physical-expr/src/window/mod.rs b/datafusion/physical-expr/src/window/mod.rs index 7bab4dbc5af6..3c37fff7a1ba 100644 --- a/datafusion/physical-expr/src/window/mod.rs +++ b/datafusion/physical-expr/src/window/mod.rs @@ -19,7 +19,6 @@ mod aggregate; mod built_in; mod built_in_window_function_expr; pub(crate) mod nth_value; -pub(crate) mod ntile; mod sliding_aggregate; mod window_expr; diff --git a/datafusion/physical-expr/src/window/ntile.rs b/datafusion/physical-expr/src/window/ntile.rs deleted file mode 100644 index fb7a7ad84fb7..000000000000 --- a/datafusion/physical-expr/src/window/ntile.rs +++ /dev/null @@ -1,111 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expression for `ntile` that can evaluated -//! at runtime during query execution - -use crate::expressions::Column; -use crate::window::BuiltInWindowFunctionExpr; -use crate::{PhysicalExpr, PhysicalSortExpr}; - -use arrow::array::{ArrayRef, UInt64Array}; -use arrow::datatypes::Field; -use arrow_schema::{DataType, SchemaRef, SortOptions}; -use datafusion_common::Result; -use datafusion_expr::PartitionEvaluator; - -use std::any::Any; -use std::sync::Arc; - -#[derive(Debug)] -pub struct Ntile { - name: String, - n: u64, - /// Output data type - data_type: DataType, -} - -impl Ntile { - pub fn new(name: String, n: u64, data_type: &DataType) -> Self { - Self { - name, - n, - data_type: data_type.clone(), - } - } - - pub fn get_n(&self) -> u64 { - self.n - } -} - -impl BuiltInWindowFunctionExpr for Ntile { - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - let nullable = false; - Ok(Field::new(self.name(), self.data_type.clone(), nullable)) - } - - fn expressions(&self) -> Vec> { - vec![] - } - - fn name(&self) -> &str { - &self.name - } - - fn create_evaluator(&self) -> Result> { - Ok(Box::new(NtileEvaluator { n: self.n })) - } - - fn get_result_ordering(&self, schema: &SchemaRef) -> Option { - // The built-in NTILE window function introduces a new ordering: - schema.column_with_name(self.name()).map(|(idx, field)| { - let expr = Arc::new(Column::new(field.name(), idx)); - let options = SortOptions { - descending: false, - nulls_first: false, - }; // ASC, NULLS LAST - PhysicalSortExpr { expr, options } - }) - } -} - -#[derive(Debug)] -pub(crate) struct NtileEvaluator { - n: u64, -} - -impl PartitionEvaluator for NtileEvaluator { - fn evaluate_all( - &mut self, - _values: &[ArrayRef], - num_rows: usize, - ) -> Result { - let num_rows = num_rows as u64; - let mut vec: Vec = Vec::new(); - let n = u64::min(self.n, num_rows); - for i in 0..num_rows { - let res = i * n / num_rows; - vec.push(res + 1) - } - Ok(Arc::new(UInt64Array::from(vec))) - } -} diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 39ff71496e21..7ebb7e71ec57 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -21,15 +21,13 @@ use std::borrow::Borrow; use std::sync::Arc; use crate::{ - expressions::{Literal, NthValue, Ntile, PhysicalSortExpr}, + expressions::{Literal, NthValue, PhysicalSortExpr}, ExecutionPlan, ExecutionPlanProperties, InputOrderMode, PhysicalExpr, }; use arrow::datatypes::Schema; use arrow_schema::{DataType, Field, SchemaRef}; -use datafusion_common::{ - exec_datafusion_err, exec_err, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue}; use datafusion_expr::{ BuiltInWindowFunction, PartitionEvaluator, ReversedUDWF, WindowFrame, WindowFunctionDefinition, WindowUDF, @@ -165,25 +163,6 @@ fn window_expr_from_aggregate_expr( } } -fn get_scalar_value_from_args( - args: &[Arc], - index: usize, -) -> Result> { - Ok(if let Some(field) = args.get(index) { - let tmp = field - .as_any() - .downcast_ref::() - .ok_or_else(|| DataFusionError::NotImplemented( - format!("There is only support Literal types for field at idx: {index} in Window Function"), - ))? - .value() - .clone(); - Some(tmp) - } else { - None - }) -} - fn get_signed_integer(value: ScalarValue) -> Result { if value.is_null() { return Ok(0); @@ -196,18 +175,6 @@ fn get_signed_integer(value: ScalarValue) -> Result { value.cast_to(&DataType::Int64)?.try_into() } -fn get_unsigned_integer(value: ScalarValue) -> Result { - if value.is_null() { - return Ok(0); - } - - if !value.data_type().is_integer() { - return exec_err!("Expected an integer value"); - } - - value.cast_to(&DataType::UInt64)?.try_into() -} - fn create_built_in_window_expr( fun: &BuiltInWindowFunction, args: &[Arc], @@ -219,28 +186,6 @@ fn create_built_in_window_expr( let out_data_type: &DataType = input_schema.field_with_name(&name)?.data_type(); Ok(match fun { - BuiltInWindowFunction::Ntile => { - let n = get_scalar_value_from_args(args, 0)?.ok_or_else(|| { - DataFusionError::Execution( - "NTILE requires a positive integer".to_string(), - ) - })?; - - if n.is_null() { - return exec_err!("NTILE requires a positive integer, but finds NULL"); - } - - if n.is_unsigned() { - let n = get_unsigned_integer(n)?; - Arc::new(Ntile::new(name, n, out_data_type)) - } else { - let n: i64 = get_signed_integer(n)?; - if n <= 0 { - return exec_err!("NTILE requires a positive integer"); - } - Arc::new(Ntile::new(name, n as u64, out_data_type)) - } - } BuiltInWindowFunction::NthValue => { let arg = Arc::clone(&args[0]); let n = get_signed_integer( diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index c92328278e83..b68c47c57eb9 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -512,7 +512,7 @@ enum BuiltInWindowFunction { // DENSE_RANK = 2; // PERCENT_RANK = 3; // CUME_DIST = 4; - NTILE = 5; + // NTILE = 5; // LAG = 6; // LEAD = 7; FIRST_VALUE = 8; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index ca331cdaa513..e54edb718808 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -1662,7 +1662,6 @@ impl serde::Serialize for BuiltInWindowFunction { { let variant = match self { Self::Unspecified => "UNSPECIFIED", - Self::Ntile => "NTILE", Self::FirstValue => "FIRST_VALUE", Self::LastValue => "LAST_VALUE", Self::NthValue => "NTH_VALUE", @@ -1678,7 +1677,6 @@ impl<'de> serde::Deserialize<'de> for BuiltInWindowFunction { { const FIELDS: &[&str] = &[ "UNSPECIFIED", - "NTILE", "FIRST_VALUE", "LAST_VALUE", "NTH_VALUE", @@ -1723,7 +1721,6 @@ impl<'de> serde::Deserialize<'de> for BuiltInWindowFunction { { match value { "UNSPECIFIED" => Ok(BuiltInWindowFunction::Unspecified), - "NTILE" => Ok(BuiltInWindowFunction::Ntile), "FIRST_VALUE" => Ok(BuiltInWindowFunction::FirstValue), "LAST_VALUE" => Ok(BuiltInWindowFunction::LastValue), "NTH_VALUE" => Ok(BuiltInWindowFunction::NthValue), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index fb0b3bcb2c13..dfc30e809108 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1819,7 +1819,7 @@ pub enum BuiltInWindowFunction { /// DENSE_RANK = 2; /// PERCENT_RANK = 3; /// CUME_DIST = 4; - Ntile = 5, + /// NTILE = 5; /// LAG = 6; /// LEAD = 7; FirstValue = 8, @@ -1834,7 +1834,6 @@ impl BuiltInWindowFunction { pub fn as_str_name(&self) -> &'static str { match self { Self::Unspecified => "UNSPECIFIED", - Self::Ntile => "NTILE", Self::FirstValue => "FIRST_VALUE", Self::LastValue => "LAST_VALUE", Self::NthValue => "NTH_VALUE", @@ -1844,7 +1843,6 @@ impl BuiltInWindowFunction { pub fn from_str_name(value: &str) -> ::core::option::Option { match value { "UNSPECIFIED" => Some(Self::Unspecified), - "NTILE" => Some(Self::Ntile), "FIRST_VALUE" => Some(Self::FirstValue), "LAST_VALUE" => Some(Self::LastValue), "NTH_VALUE" => Some(Self::NthValue), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 4587c090c96a..27bda7dd5ace 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -152,7 +152,6 @@ impl From for BuiltInWindowFunction { match built_in_function { protobuf::BuiltInWindowFunction::Unspecified => todo!(), protobuf::BuiltInWindowFunction::FirstValue => Self::FirstValue, - protobuf::BuiltInWindowFunction::Ntile => Self::Ntile, protobuf::BuiltInWindowFunction::NthValue => Self::NthValue, protobuf::BuiltInWindowFunction::LastValue => Self::LastValue, } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index dce0cd741fd3..5a6f3a32c668 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -127,7 +127,6 @@ impl From<&BuiltInWindowFunction> for protobuf::BuiltInWindowFunction { BuiltInWindowFunction::FirstValue => Self::FirstValue, BuiltInWindowFunction::LastValue => Self::LastValue, BuiltInWindowFunction::NthValue => Self::NthValue, - BuiltInWindowFunction::Ntile => Self::Ntile, } } } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 37ea6a2b47be..89a2403922e9 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -24,7 +24,7 @@ use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, - Literal, NegativeExpr, NotExpr, NthValue, Ntile, TryCastExpr, + Literal, NegativeExpr, NotExpr, NthValue, TryCastExpr, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -108,33 +108,24 @@ pub fn serialize_physical_window_expr( let expr = built_in_window_expr.get_built_in_func_expr(); let built_in_fn_expr = expr.as_any(); - let builtin_fn = if let Some(ntile_expr) = - built_in_fn_expr.downcast_ref::() - { - args.insert( - 0, - Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some( - ntile_expr.get_n() as i64, - )))), - ); - protobuf::BuiltInWindowFunction::Ntile - } else if let Some(nth_value_expr) = built_in_fn_expr.downcast_ref::() { - match nth_value_expr.get_kind() { - NthValueKind::First => protobuf::BuiltInWindowFunction::FirstValue, - NthValueKind::Last => protobuf::BuiltInWindowFunction::LastValue, - NthValueKind::Nth(n) => { - args.insert( - 1, - Arc::new(Literal::new(datafusion_common::ScalarValue::Int64( - Some(n), - ))), - ); - protobuf::BuiltInWindowFunction::NthValue + let builtin_fn = + if let Some(nth_value_expr) = built_in_fn_expr.downcast_ref::() { + match nth_value_expr.get_kind() { + NthValueKind::First => protobuf::BuiltInWindowFunction::FirstValue, + NthValueKind::Last => protobuf::BuiltInWindowFunction::LastValue, + NthValueKind::Nth(n) => { + args.insert( + 1, + Arc::new(Literal::new( + datafusion_common::ScalarValue::Int64(Some(n)), + )), + ); + protobuf::BuiltInWindowFunction::NthValue + } } - } - } else { - return not_impl_err!("BuiltIn function not supported: {expr:?}"); - }; + } else { + return not_impl_err!("BuiltIn function not supported: {expr:?}"); + }; ( physical_window_expr_node::WindowFunction::BuiltInFunction(builtin_fn as i32), diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index a8c82ff80f23..3fec7d1c6ea0 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -48,7 +48,7 @@ use datafusion::functions_aggregate::expr_fn::{ use datafusion::functions_aggregate::min_max::max_udaf; use datafusion::functions_nested::map::map; use datafusion::functions_window::expr_fn::{ - cume_dist, dense_rank, lag, lead, percent_rank, rank, row_number, + cume_dist, dense_rank, lag, lead, ntile, percent_rank, rank, row_number, }; use datafusion::functions_window::rank::rank_udwf; use datafusion::prelude::*; @@ -951,6 +951,7 @@ async fn roundtrip_expr_api() -> Result<()> { lag(col("b"), None, None), lag(col("b"), Some(2), None), lag(col("b"), Some(2), Some(ScalarValue::from(100))), + ntile(lit(3)), nth_value(col("b"), 1, vec![]), nth_value( col("b"), diff --git a/docs/source/user-guide/sql/window_functions.md b/docs/source/user-guide/sql/window_functions.md index 0799859e4371..6bf2005dabf9 100644 --- a/docs/source/user-guide/sql/window_functions.md +++ b/docs/source/user-guide/sql/window_functions.md @@ -146,40 +146,6 @@ RANGE and GROUPS modes require an ORDER BY clause (with RANGE the ORDER BY must All [aggregate functions](aggregate_functions.md) can be used as window functions. -## Ranking functions - -- [rank](#rank) -- [dense_rank](#dense_rank) -- [ntile](#ntile) - -### `rank` - -Rank of the current row with gaps; same as row_number of its first peer. - -```sql -rank() -``` - -### `dense_rank` - -Rank of the current row without gaps; this function counts peer groups. - -```sql -dense_rank() -``` - -### `ntile` - -Integer ranging from 1 to the argument value, dividing the partition as equally as possible. - -```sql -ntile(expression) -``` - -#### Arguments - -- **expression**: An integer describing the number groups the partition should be split into - ## Analytical functions - [cume_dist](#cume_dist) diff --git a/docs/source/user-guide/sql/window_functions_new.md b/docs/source/user-guide/sql/window_functions_new.md index 267060abfdcc..ae3edb832fcb 100644 --- a/docs/source/user-guide/sql/window_functions_new.md +++ b/docs/source/user-guide/sql/window_functions_new.md @@ -159,6 +159,7 @@ All [aggregate functions](aggregate_functions.md) can be used as window function - [cume_dist](#cume_dist) - [dense_rank](#dense_rank) +- [ntile](#ntile) - [percent_rank](#percent_rank) - [rank](#rank) - [row_number](#row_number) @@ -179,6 +180,18 @@ Returns the rank of the current row without gaps. This function ranks rows in a dense_rank() ``` +### `ntile` + +Integer ranging from 1 to the argument value, dividing the partition as equally as possible + +``` +ntile(expression) +``` + +#### Arguments + +- **expression**: An integer describing the number groups the partition should be split into + ### `percent_rank` Returns the percentage rank of the current row within its partition. The value ranges from 0 to 1 and is computed as `(rank - 1) / (total_rows - 1)`. From 7f32dcef3349059c7ee9ae0f24a3373ee2473982 Mon Sep 17 00:00:00 2001 From: June <61218022+itsjunetime@users.noreply.github.com> Date: Fri, 25 Oct 2024 07:18:55 -0600 Subject: [PATCH 075/110] Fix more instances of schema missing metadata (#13068) --- .../core/src/datasource/file_format/parquet.rs | 3 ++- datafusion/core/src/datasource/listing/table.rs | 8 +++++++- .../datasource/physical_plan/file_scan_config.rs | 14 ++++++++++---- datafusion/core/src/datasource/schema_adapter.rs | 3 ++- datafusion/core/src/physical_planner.rs | 11 +++++++++-- datafusion/expr/src/logical_plan/builder.rs | 8 ++++++-- datafusion/expr/src/utils.rs | 5 ++++- datafusion/physical-plan/src/joins/utils.rs | 8 +++++++- datafusion/physical-plan/src/windows/utils.rs | 4 +++- 9 files changed, 50 insertions(+), 14 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 756c17fd67c6..2d45c76ce918 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -738,13 +738,14 @@ impl ParquetSink { .iter() .map(|(s, _)| s) .collect(); - Arc::new(Schema::new( + Arc::new(Schema::new_with_metadata( schema .fields() .iter() .filter(|f| !partition_names.contains(&f.name())) .map(|f| (**f).clone()) .collect::>(), + schema.metadata().clone(), )) } else { self.config.output_schema().clone() diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 1e9f06c20b47..ea2e098ef14e 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -719,10 +719,16 @@ impl ListingTable { builder.push(Field::new(part_col_name, part_col_type.clone(), false)); } + let table_schema = Arc::new( + builder + .finish() + .with_metadata(file_schema.metadata().clone()), + ); + let table = Self { table_paths: config.table_paths, file_schema, - table_schema: Arc::new(builder.finish()), + table_schema, options, definition: None, collected_statistics: Arc::new(DefaultFileStatisticsCache::default()), diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index 2c438e8b0e78..415ea62b3bb3 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -248,9 +248,10 @@ impl FileScanConfig { column_statistics: table_cols_stats, }; - let projected_schema = Arc::new( - Schema::new(table_fields).with_metadata(self.file_schema.metadata().clone()), - ); + let projected_schema = Arc::new(Schema::new_with_metadata( + table_fields, + self.file_schema.metadata().clone(), + )); let projected_output_ordering = get_projected_output_ordering(self, &projected_schema); @@ -281,7 +282,12 @@ impl FileScanConfig { fields.map_or_else( || Arc::clone(&self.file_schema), - |f| Arc::new(Schema::new(f).with_metadata(self.file_schema.metadata.clone())), + |f| { + Arc::new(Schema::new_with_metadata( + f, + self.file_schema.metadata.clone(), + )) + }, ) } diff --git a/datafusion/core/src/datasource/schema_adapter.rs b/datafusion/core/src/datasource/schema_adapter.rs index fdf3381758a4..131b8c354ce7 100644 --- a/datafusion/core/src/datasource/schema_adapter.rs +++ b/datafusion/core/src/datasource/schema_adapter.rs @@ -304,7 +304,8 @@ impl SchemaMapper for SchemaMapping { // Necessary to handle empty batches let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); - let schema = Arc::new(Schema::new(fields)); + let schema = + Arc::new(Schema::new_with_metadata(fields, schema.metadata().clone())); let record_batch = RecordBatch::try_new_with_options(schema, cols, &options)?; Ok(record_batch) } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 5a4ae868d04a..ffedc2d6b6ef 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1025,14 +1025,21 @@ impl DefaultPhysicalPlanner { }) .collect(); + let metadata: HashMap<_, _> = left_df_schema + .metadata() + .clone() + .into_iter() + .chain(right_df_schema.metadata().clone()) + .collect(); + // Construct intermediate schemas used for filtering data and // convert logical expression to physical according to filter schema let filter_df_schema = DFSchema::new_with_metadata( filter_df_fields, - HashMap::new(), + metadata.clone(), )?; let filter_schema = - Schema::new_with_metadata(filter_fields, HashMap::new()); + Schema::new_with_metadata(filter_fields, metadata); let filter_expr = create_physical_expr( expr, &filter_df_schema, diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index cef05b6f8814..aef531a9dbf7 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1402,8 +1402,12 @@ pub fn build_join_schema( join_type, left.fields().len(), ); - let mut metadata = left.metadata().clone(); - metadata.extend(right.metadata().clone()); + let metadata = left + .metadata() + .clone() + .into_iter() + .chain(right.metadata().clone()) + .collect(); let dfschema = DFSchema::new_with_metadata(qualified_fields, metadata)?; dfschema.with_functional_dependencies(func_dependencies) } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 86562daf6909..bb5496c0f799 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -437,7 +437,10 @@ pub fn expand_qualified_wildcard( return plan_err!("Invalid qualifier {qualifier}"); } - let qualified_schema = Arc::new(Schema::new(fields_with_qualified)); + let qualified_schema = Arc::new(Schema::new_with_metadata( + fields_with_qualified, + schema.metadata().clone(), + )); let qualified_dfschema = DFSchema::try_from_qualified_schema(qualifier.clone(), &qualified_schema)? .with_functional_dependencies(projected_func_dependencies)?; diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index c520e4271416..17a32a67c743 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -701,7 +701,13 @@ pub fn build_join_schema( .unzip(), }; - (fields.finish(), column_indices) + let metadata = left + .metadata() + .clone() + .into_iter() + .chain(right.metadata().clone()) + .collect(); + (fields.finish().with_metadata(metadata), column_indices) } /// A [`OnceAsync`] can be used to run an async closure once, with subsequent calls diff --git a/datafusion/physical-plan/src/windows/utils.rs b/datafusion/physical-plan/src/windows/utils.rs index 3cf92daae0fb..13332ea82fa1 100644 --- a/datafusion/physical-plan/src/windows/utils.rs +++ b/datafusion/physical-plan/src/windows/utils.rs @@ -31,5 +31,7 @@ pub(crate) fn create_schema( for expr in window_expr { builder.push(expr.field()?); } - Ok(builder.finish()) + Ok(builder + .finish() + .with_metadata(input_schema.metadata().clone())) } From 06594c79f11d57b88467e9b87e5cb1ac6cf85d88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Berkay=20=C5=9Eahin?= <124376117+berkaysynnada@users.noreply.github.com> Date: Fri, 25 Oct 2024 16:50:38 +0300 Subject: [PATCH 076/110] Bug-fix / Limit with_new_exprs() (#13109) * Update plan.rs * Update plan.rs * Update plan.rs --- datafusion/expr/src/logical_plan/plan.rs | 26 +++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 4b42702f24bf..572285defba0 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -936,9 +936,8 @@ impl LogicalPlan { expr.len() ); } - // Pop order is same as the order returned by `LogicalPlan::expressions()` - let new_skip = skip.as_ref().and(expr.pop()); - let new_fetch = fetch.as_ref().and(expr.pop()); + let new_skip = skip.as_ref().and_then(|_| expr.pop()); + let new_fetch = fetch.as_ref().and_then(|_| expr.pop()); let input = self.only_input(inputs)?; Ok(LogicalPlan::Limit(Limit { skip: new_skip.map(Box::new), @@ -4101,4 +4100,25 @@ digraph { ); assert_eq!(describe_table.partial_cmp(&describe_table_clone), None); } + + #[test] + fn test_limit_with_new_children() { + let limit = LogicalPlan::Limit(Limit { + skip: None, + fetch: Some(Box::new(Expr::Literal( + ScalarValue::new_ten(&DataType::UInt32).unwrap(), + ))), + input: Arc::new(LogicalPlan::Values(Values { + schema: Arc::new(DFSchema::empty()), + values: vec![vec![]], + })), + }); + let new_limit = limit + .with_new_exprs( + limit.expressions(), + limit.inputs().into_iter().cloned().collect(), + ) + .unwrap(); + assert_eq!(limit, new_limit); + } } From 813220d54f08c5203ad79bfb066ca638abe208ed Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Fri, 25 Oct 2024 16:21:05 +0200 Subject: [PATCH 077/110] Move subquery check from analyzer to PullUpCorrelatedExpr (#13091) This patch moves subquery check `can_pull_over_aggregation` from analyzer into the PullUpCorrelatedExpr. Instead of failing the query we will no instead not decorrelate such queries and then fail during physical plan creation. The goal here is to support TPC-DS q41 which has an expression that can not be pull up until it has been simplified by SimplifyExpressions. This means that currently we reject the query already in the analyzer. But after this change we are able to plan that query. --- datafusion/core/benches/sql_planner.rs | 4 - datafusion/core/tests/tpcds_planning.rs | 5 - datafusion/optimizer/src/analyzer/subquery.rs | 96 ++++--------------- datafusion/optimizer/src/decorrelate.rs | 45 ++++++++- .../optimizer/src/scalar_subquery_to_join.rs | 54 ++++++++--- .../sqllogictest/test_files/subquery.slt | 14 ++- 6 files changed, 116 insertions(+), 102 deletions(-) diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index 64d2760e9d97..6f9cf02873d1 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -270,11 +270,7 @@ fn criterion_benchmark(c: &mut Criterion) { let tpcds_ctx = register_defs(SessionContext::new(), tpcds_schemas()); - // 41: check_analyzed_plan: Correlated column is not allowed in predicate - let ignored = [41]; - let raw_tpcds_sql_queries = (1..100) - .filter(|q| !ignored.contains(q)) .map(|q| std::fs::read_to_string(format!("./tests/tpc-ds/{q}.sql")).unwrap()) .collect::>(); diff --git a/datafusion/core/tests/tpcds_planning.rs b/datafusion/core/tests/tpcds_planning.rs index 0077a2d35b1f..252d76d0f9d9 100644 --- a/datafusion/core/tests/tpcds_planning.rs +++ b/datafusion/core/tests/tpcds_planning.rs @@ -229,9 +229,6 @@ async fn tpcds_logical_q40() -> Result<()> { } #[tokio::test] -#[ignore] -// check_analyzed_plan: Correlated column is not allowed in predicate -// issue: https://github.com/apache/datafusion/issues/13074 async fn tpcds_logical_q41() -> Result<()> { create_logical_plan(41).await } @@ -726,8 +723,6 @@ async fn tpcds_physical_q40() -> Result<()> { create_physical_plan(40).await } -#[ignore] -// Context("check_analyzed_plan", Plan("Correlated column is not allowed in predicate: (..) #[tokio::test] async fn tpcds_physical_q41() -> Result<()> { create_physical_plan(41).await diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index 0a52685bd681..e01ae625ed9c 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::ops::Deref; - use crate::analyzer::check_plan; use crate::utils::collect_subquery_cols; @@ -24,10 +22,7 @@ use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{plan_err, Result}; use datafusion_expr::expr_rewriter::strip_outer_reference; use datafusion_expr::utils::split_conjunction; -use datafusion_expr::{ - Aggregate, BinaryExpr, Cast, Expr, Filter, Join, JoinType, LogicalPlan, Operator, - Window, -}; +use datafusion_expr::{Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window}; /// Do necessary check on subquery expressions and fail the invalid plan /// 1) Check whether the outer plan is in the allowed outer plans list to use subquery expressions, @@ -98,7 +93,7 @@ pub fn check_subquery_expr( ) }?; } - check_correlations_in_subquery(inner_plan, true) + check_correlations_in_subquery(inner_plan) } else { if let Expr::InSubquery(subquery) = expr { // InSubquery should only return one column @@ -121,25 +116,17 @@ pub fn check_subquery_expr( Projection, Filter, Window functions, Aggregate and Join plan nodes" ), }?; - check_correlations_in_subquery(inner_plan, false) + check_correlations_in_subquery(inner_plan) } } // Recursively check the unsupported outer references in the sub query plan. -fn check_correlations_in_subquery( - inner_plan: &LogicalPlan, - is_scalar: bool, -) -> Result<()> { - check_inner_plan(inner_plan, is_scalar, false, true) +fn check_correlations_in_subquery(inner_plan: &LogicalPlan) -> Result<()> { + check_inner_plan(inner_plan, true) } // Recursively check the unsupported outer references in the sub query plan. -fn check_inner_plan( - inner_plan: &LogicalPlan, - is_scalar: bool, - is_aggregate: bool, - can_contain_outer_ref: bool, -) -> Result<()> { +fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Result<()> { if !can_contain_outer_ref && inner_plan.contains_outer_reference() { return plan_err!("Accessing outer reference columns is not allowed in the plan"); } @@ -147,32 +134,18 @@ fn check_inner_plan( match inner_plan { LogicalPlan::Aggregate(_) => { inner_plan.apply_children(|plan| { - check_inner_plan(plan, is_scalar, true, can_contain_outer_ref)?; + check_inner_plan(plan, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; Ok(()) } - LogicalPlan::Filter(Filter { - predicate, input, .. - }) => { - let (correlated, _): (Vec<_>, Vec<_>) = split_conjunction(predicate) - .into_iter() - .partition(|e| e.contains_outer()); - let maybe_unsupported = correlated - .into_iter() - .filter(|expr| !can_pullup_over_aggregation(expr)) - .collect::>(); - if is_aggregate && is_scalar && !maybe_unsupported.is_empty() { - return plan_err!( - "Correlated column is not allowed in predicate: {predicate}" - ); - } - check_inner_plan(input, is_scalar, is_aggregate, can_contain_outer_ref) + LogicalPlan::Filter(Filter { input, .. }) => { + check_inner_plan(input, can_contain_outer_ref) } LogicalPlan::Window(window) => { check_mixed_out_refer_in_window(window)?; inner_plan.apply_children(|plan| { - check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; + check_inner_plan(plan, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; Ok(()) @@ -188,7 +161,7 @@ fn check_inner_plan( | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) => { inner_plan.apply_children(|plan| { - check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; + check_inner_plan(plan, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; Ok(()) @@ -201,27 +174,22 @@ fn check_inner_plan( }) => match join_type { JoinType::Inner => { inner_plan.apply_children(|plan| { - check_inner_plan( - plan, - is_scalar, - is_aggregate, - can_contain_outer_ref, - )?; + check_inner_plan(plan, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; Ok(()) } JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => { - check_inner_plan(left, is_scalar, is_aggregate, can_contain_outer_ref)?; - check_inner_plan(right, is_scalar, is_aggregate, false) + check_inner_plan(left, can_contain_outer_ref)?; + check_inner_plan(right, false) } JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { - check_inner_plan(left, is_scalar, is_aggregate, false)?; - check_inner_plan(right, is_scalar, is_aggregate, can_contain_outer_ref) + check_inner_plan(left, false)?; + check_inner_plan(right, can_contain_outer_ref) } JoinType::Full => { inner_plan.apply_children(|plan| { - check_inner_plan(plan, is_scalar, is_aggregate, false)?; + check_inner_plan(plan, false)?; Ok(TreeNodeRecursion::Continue) })?; Ok(()) @@ -290,34 +258,6 @@ fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result> { Ok(exprs) } -/// Check whether the expression can pull up over the aggregation without change the result of the query -fn can_pullup_over_aggregation(expr: &Expr) -> bool { - if let Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::Eq, - right, - }) = expr - { - match (left.deref(), right.deref()) { - (Expr::Column(_), right) => !right.any_column_refs(), - (left, Expr::Column(_)) => !left.any_column_refs(), - (Expr::Cast(Cast { expr, .. }), right) - if matches!(expr.deref(), Expr::Column(_)) => - { - !right.any_column_refs() - } - (left, Expr::Cast(Cast { expr, .. })) - if matches!(expr.deref(), Expr::Column(_)) => - { - !left.any_column_refs() - } - (_, _) => false, - } - } else { - false - } -} - /// Check whether the window expressions contain a mixture of out reference columns and inner columns fn check_mixed_out_refer_in_window(window: &Window) -> Result<()> { let mixed = window @@ -398,6 +338,6 @@ mod test { }), }); - check_inner_plan(&plan, false, false, true).unwrap(); + check_inner_plan(&plan, true).unwrap(); } } diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index baf449a045eb..6aa59b77f7f9 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -32,7 +32,8 @@ use datafusion_expr::expr::Alias; use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::utils::{conjunction, find_join_exprs, split_conjunction}; use datafusion_expr::{ - expr, lit, EmptyRelation, Expr, FetchType, LogicalPlan, LogicalPlanBuilder, + expr, lit, BinaryExpr, Cast, EmptyRelation, Expr, FetchType, LogicalPlan, + LogicalPlanBuilder, Operator, }; use datafusion_physical_expr::execution_props::ExecutionProps; @@ -51,6 +52,9 @@ pub struct PullUpCorrelatedExpr { pub exists_sub_query: bool, /// Can the correlated expressions be pulled up. Defaults to **TRUE** pub can_pull_up: bool, + /// Indicates if we encounter any correlated expression that can not be pulled up + /// above a aggregation without changing the meaning of the query. + can_pull_over_aggregation: bool, /// Do we need to handle [the Count bug] during the pull up process /// /// [the Count bug]: https://github.com/apache/datafusion/pull/10500 @@ -75,6 +79,7 @@ impl PullUpCorrelatedExpr { in_predicate_opt: None, exists_sub_query: false, can_pull_up: true, + can_pull_over_aggregation: true, need_handle_count_bug: false, collected_count_expr_map: HashMap::new(), pull_up_having_expr: None, @@ -154,6 +159,11 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { match &plan { LogicalPlan::Filter(plan_filter) => { let subquery_filter_exprs = split_conjunction(&plan_filter.predicate); + self.can_pull_over_aggregation = self.can_pull_over_aggregation + && subquery_filter_exprs + .iter() + .filter(|e| e.contains_outer()) + .all(|&e| can_pullup_over_aggregation(e)); let (mut join_filters, subquery_filters) = find_join_exprs(subquery_filter_exprs)?; if let Some(in_predicate) = &self.in_predicate_opt { @@ -259,6 +269,12 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { LogicalPlan::Aggregate(aggregate) if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() => { + // If the aggregation is from a distinct it will not change the result for + // exists/in subqueries so we can still pull up all predicates. + let is_distinct = aggregate.aggr_expr.is_empty(); + if !is_distinct { + self.can_pull_up = self.can_pull_up && self.can_pull_over_aggregation; + } let mut local_correlated_cols = BTreeSet::new(); collect_local_correlated_cols( &plan, @@ -385,6 +401,33 @@ impl PullUpCorrelatedExpr { } } +fn can_pullup_over_aggregation(expr: &Expr) -> bool { + if let Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + }) = expr + { + match (left.deref(), right.deref()) { + (Expr::Column(_), right) => !right.any_column_refs(), + (left, Expr::Column(_)) => !left.any_column_refs(), + (Expr::Cast(Cast { expr, .. }), right) + if matches!(expr.deref(), Expr::Column(_)) => + { + !right.any_column_refs() + } + (left, Expr::Cast(Cast { expr, .. })) + if matches!(expr.deref(), Expr::Column(_)) => + { + !left.any_column_refs() + } + (_, _) => false, + } + } else { + false + } +} + fn collect_local_correlated_cols( plan: &LogicalPlan, all_cols_map: &HashMap>, diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 6409bb9e03f7..7b931e73abf9 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -625,11 +625,21 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "check_analyzed_plan\ - \ncaused by\ - \nError during planning: Correlated column is not allowed in predicate: outer_ref(customer.c_custkey) != orders.o_custkey"; + // Unsupported predicate, subquery should not be decorrelated + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8]\ + \n Subquery: [max(orders.o_custkey):Int64;N]\ + \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ + \n Filter: outer_ref(customer.c_custkey) != orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]"; - assert_analyzer_check_err(vec![], plan, expected); + assert_multi_rules_optimized_plan_eq_display_indent( + vec![Arc::new(ScalarSubqueryToJoin::new())], + plan, + expected, + ); Ok(()) } @@ -652,11 +662,21 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "check_analyzed_plan\ - \ncaused by\ - \nError during planning: Correlated column is not allowed in predicate: outer_ref(customer.c_custkey) < orders.o_custkey"; + // Unsupported predicate, subquery should not be decorrelated + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8]\ + \n Subquery: [max(orders.o_custkey):Int64;N]\ + \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ + \n Filter: outer_ref(customer.c_custkey) < orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]"; - assert_analyzer_check_err(vec![], plan, expected); + assert_multi_rules_optimized_plan_eq_display_indent( + vec![Arc::new(ScalarSubqueryToJoin::new())], + plan, + expected, + ); Ok(()) } @@ -680,11 +700,21 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "check_analyzed_plan\ - \ncaused by\ - \nError during planning: Correlated column is not allowed in predicate: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1)"; + // Unsupported predicate, subquery should not be decorrelated + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8]\ + \n Subquery: [max(orders.o_custkey):Int64;N]\ + \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ + \n Filter: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]"; - assert_analyzer_check_err(vec![], plan, expected); + assert_multi_rules_optimized_plan_eq_display_indent( + vec![Arc::new(ScalarSubqueryToJoin::new())], + plan, + expected, + ); Ok(()) } diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 6b142302a543..26b5d8b952f6 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -509,8 +509,18 @@ SELECT t1_id, (SELECT a FROM (select 1 as a) WHERE a = t1.t1_int) as t2_int from 44 NULL #non_equal_correlated_scalar_subquery -statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated column is not allowed in predicate: t2\.t2_id < outer_ref\(t1\.t1_id\) -SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id < t1.t1_id) as t2_sum from t1 +# Currently not supported and should not be decorrelated +query TT +explain SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id < t1.t1_id) as t2_sum from t1 +---- +logical_plan +01)Projection: t1.t1_id, () AS t2_sum +02)--Subquery: +03)----Projection: sum(t2.t2_int) +04)------Aggregate: groupBy=[[]], aggr=[[sum(CAST(t2.t2_int AS Int64))]] +05)--------Filter: t2.t2_id < outer_ref(t1.t1_id) +06)----------TableScan: t2 +07)--TableScan: t1 projection=[t1_id] #aggregated_correlated_scalar_subquery_with_extra_group_by_columns statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns From bdcf8225933c852e9f3a1b44a51d262627506f98 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Sat, 26 Oct 2024 04:59:41 +0800 Subject: [PATCH 078/110] Include IMDB in benchmark README (#13107) --- benchmarks/README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/benchmarks/README.md b/benchmarks/README.md index afaf28bb7576..a12662ccb846 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -330,6 +330,16 @@ steps. The tests sort the entire dataset using several different sort orders. +## IMDB + +Run Join Order Benchmark (JOB) on IMDB dataset. + +The Internet Movie Database (IMDB) dataset contains real-world movie data. Unlike synthetic datasets like TPCH, which assume uniform data distribution and uncorrelated columns, the IMDB dataset includes skewed data and correlated columns (which are common for real dataset), making it more suitable for testing query optimizers, particularly for cardinality estimation. + +This benchmark is derived from [Join Order Benchmark](https://github.com/gregrahn/join-order-benchmark). + +See paper [How Good Are Query Optimizers, Really](http://www.vldb.org/pvldb/vol9/p204-leis.pdf) for more details. + ## TPCH Run the tpch benchmark. From 96236908877d9cebbf115cff001f35f0729fab9e Mon Sep 17 00:00:00 2001 From: neyama Date: Sat, 26 Oct 2024 06:02:27 +0900 Subject: [PATCH 079/110] removed --prefer_hash_join option that causes an error when running the benchmark (#13106) --- benchmarks/bench.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh index 70faa9ef2b73..fc10cc5afc53 100755 --- a/benchmarks/bench.sh +++ b/benchmarks/bench.sh @@ -357,7 +357,7 @@ run_parquet() { RESULTS_FILE="${RESULTS_DIR}/parquet.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running parquet filter benchmark..." - $CARGO_COMMAND --bin parquet -- filter --path "${DATA_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" + $CARGO_COMMAND --bin parquet -- filter --path "${DATA_DIR}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" } # Runs the sort benchmark @@ -365,7 +365,7 @@ run_sort() { RESULTS_FILE="${RESULTS_DIR}/sort.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running sort benchmark..." - $CARGO_COMMAND --bin parquet -- sort --path "${DATA_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" + $CARGO_COMMAND --bin parquet -- sort --path "${DATA_DIR}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" } From 21cfd6ccdbd6996be24586568bc9c260f50505db Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 25 Oct 2024 17:14:06 -0400 Subject: [PATCH 080/110] Make CI test error if a function is not documented (#12938) --- .../core/src/bin/print_functions_docs.rs | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/datafusion/core/src/bin/print_functions_docs.rs b/datafusion/core/src/bin/print_functions_docs.rs index d87c3cefe666..598574c0703d 100644 --- a/datafusion/core/src/bin/print_functions_docs.rs +++ b/datafusion/core/src/bin/print_functions_docs.rs @@ -16,6 +16,7 @@ // under the License. use datafusion::execution::SessionStateDefaults; +use datafusion_common::{not_impl_err, Result}; use datafusion_expr::{ aggregate_doc_sections, scalar_doc_sections, window_doc_sections, AggregateUDF, DocSection, Documentation, ScalarUDF, WindowUDF, @@ -30,7 +31,7 @@ use std::fmt::Write as _; /// Usage: `cargo run --bin print_functions_docs -- ` /// /// Called from `dev/update_function_docs.sh` -fn main() { +fn main() -> Result<()> { let args: Vec = args().collect(); if args.len() != 2 { @@ -48,12 +49,13 @@ fn main() { _ => { panic!("Unknown function type: {}", function_type) } - }; + }?; println!("{docs}"); + Ok(()) } -fn print_aggregate_docs() -> String { +fn print_aggregate_docs() -> Result { let mut providers: Vec> = vec![]; for f in SessionStateDefaults::default_aggregate_functions() { @@ -63,7 +65,7 @@ fn print_aggregate_docs() -> String { print_docs(providers, aggregate_doc_sections::doc_sections()) } -fn print_scalar_docs() -> String { +fn print_scalar_docs() -> Result { let mut providers: Vec> = vec![]; for f in SessionStateDefaults::default_scalar_functions() { @@ -73,7 +75,7 @@ fn print_scalar_docs() -> String { print_docs(providers, scalar_doc_sections::doc_sections()) } -fn print_window_docs() -> String { +fn print_window_docs() -> Result { let mut providers: Vec> = vec![]; for f in SessionStateDefaults::default_window_functions() { @@ -86,7 +88,7 @@ fn print_window_docs() -> String { fn print_docs( providers: Vec>, doc_sections: Vec, -) -> String { +) -> Result { let mut docs = "".to_string(); // Ensure that all providers have documentation @@ -217,12 +219,13 @@ fn print_docs( // eventually make this an error: https://github.com/apache/datafusion/issues/12872 if !providers_with_no_docs.is_empty() { eprintln!("INFO: The following functions do not have documentation:"); - for f in providers_with_no_docs { + for f in &providers_with_no_docs { eprintln!(" - {f}"); } + not_impl_err!("Some functions do not have documentation. Please implement `documentation` for: {providers_with_no_docs:?}") + } else { + Ok(docs) } - - docs } /// Trait for accessing name / aliases / documentation for differnet functions From 7b2284c8a0b49234e9607bfef10d73ef788d9458 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 25 Oct 2024 17:27:23 -0400 Subject: [PATCH 081/110] Allow using `cargo nextest` for running tests (#13045) * Allow using `cargo nextest` for running tests * Update datafusion/sqllogictest/bin/sqllogictests.rs Co-authored-by: Piotr Findeisen * Clarify rationale for returning OK * Apply suggestions from code review Co-authored-by: Piotr Findeisen --------- Co-authored-by: Piotr Findeisen --- datafusion/sqllogictest/bin/sqllogictests.rs | 29 +++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/datafusion/sqllogictest/bin/sqllogictests.rs b/datafusion/sqllogictest/bin/sqllogictests.rs index baa49057e1b9..501fd3517a17 100644 --- a/datafusion/sqllogictest/bin/sqllogictests.rs +++ b/datafusion/sqllogictest/bin/sqllogictests.rs @@ -62,6 +62,15 @@ async fn run_tests() -> Result<()> { env_logger::init(); let options: Options = clap::Parser::parse(); + if options.list { + // nextest parses stdout, so print messages to stderr + eprintln!("NOTICE: --list option unsupported, quitting"); + // return Ok, not error so that tools like nextest which are listing all + // workspace tests (by running `cargo test ... --list --format terse`) + // do not fail when they encounter this binary. Instead, print nothing + // to stdout and return OK so they can continue listing other tests. + return Ok(()); + } options.warn_on_ignored(); // Run all tests in parallel, reporting failures at the end @@ -276,7 +285,7 @@ fn read_dir_recursive_impl(dst: &mut Vec, path: &Path) -> Result<()> { /// Parsed command line options /// -/// This structure attempts to mimic the command line options +/// This structure attempts to mimic the command line options of the built in rust test runner /// accepted by IDEs such as CLion that pass arguments /// /// See for more details @@ -320,6 +329,18 @@ struct Options { help = "IGNORED (for compatibility with built in rust test runner)" )] show_output: bool, + + #[clap( + long, + help = "Quits immediately, not listing anything (for compatibility with built-in rust test runner)" + )] + list: bool, + + #[clap( + long, + help = "IGNORED (for compatibility with built-in rust test runner)" + )] + ignored: bool, } impl Options { @@ -354,15 +375,15 @@ impl Options { /// Logs warning messages to stdout if any ignored options are passed fn warn_on_ignored(&self) { if self.format.is_some() { - println!("WARNING: Ignoring `--format` compatibility option"); + eprintln!("WARNING: Ignoring `--format` compatibility option"); } if self.z_options.is_some() { - println!("WARNING: Ignoring `-Z` compatibility option"); + eprintln!("WARNING: Ignoring `-Z` compatibility option"); } if self.show_output { - println!("WARNING: Ignoring `--show-output` compatibility option"); + eprintln!("WARNING: Ignoring `--show-output` compatibility option"); } } } From 73cfa6c266763b3db15942e3f331f3d5274169c1 Mon Sep 17 00:00:00 2001 From: Leslie Su <3530611790@qq.com> Date: Sat, 26 Oct 2024 18:29:37 +0800 Subject: [PATCH 082/110] feat: Add `Date32`/`Date64` in aggregate fuzz testing (#13041) * refactor PrimitiveArrayGenerator. * support Date32/Date64 type in data generator. * fix format. * remove unnecessary type para in PrimitiveArrayGenerator. * introduce FromNative trait and replace the unsafe. --- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 2 + .../aggregation_fuzzer/data_generator.rs | 46 ++++-- test-utils/src/array_gen/primitive.rs | 134 ++++++++++++------ 3 files changed, 126 insertions(+), 56 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 1035fa31da08..28901b14b5b7 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -164,6 +164,8 @@ fn baseline_config() -> DatasetGeneratorConfig { ColumnDescr::new("u16", DataType::UInt16), ColumnDescr::new("u32", DataType::UInt32), ColumnDescr::new("u64", DataType::UInt64), + ColumnDescr::new("date32", DataType::Date32), + ColumnDescr::new("date64", DataType::Date64), // TODO: date/time columns // todo decimal columns // begin string columns diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs index 44f96d5a1a07..ef9b5a7f355a 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs @@ -17,6 +17,10 @@ use std::sync::Arc; +use arrow::datatypes::{ + Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, + Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; use arrow_array::{ArrayRef, RecordBatch}; use arrow_schema::{DataType, Field, Schema}; use datafusion_common::{arrow_datafusion_err, DataFusionError, Result}; @@ -222,7 +226,7 @@ macro_rules! generate_string_array { } macro_rules! generate_primitive_array { - ($SELF:ident, $NUM_ROWS:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $DATA_TYPE:ident) => { + ($SELF:ident, $NUM_ROWS:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE:ident) => { paste::paste! {{ let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; @@ -239,7 +243,7 @@ macro_rules! generate_primitive_array { rng: $ARRAY_GEN_RNG, }; - generator.[< gen_data_ $DATA_TYPE >]() + generator.gen_data::<$ARROW_TYPE>() }}} } @@ -297,7 +301,7 @@ impl RecordBatchGenerator { num_rows, batch_gen_rng, array_gen_rng, - i8 + Int8Type ) } DataType::Int16 => { @@ -306,7 +310,7 @@ impl RecordBatchGenerator { num_rows, batch_gen_rng, array_gen_rng, - i16 + Int16Type ) } DataType::Int32 => { @@ -315,7 +319,7 @@ impl RecordBatchGenerator { num_rows, batch_gen_rng, array_gen_rng, - i32 + Int32Type ) } DataType::Int64 => { @@ -324,7 +328,7 @@ impl RecordBatchGenerator { num_rows, batch_gen_rng, array_gen_rng, - i64 + Int64Type ) } DataType::UInt8 => { @@ -333,7 +337,7 @@ impl RecordBatchGenerator { num_rows, batch_gen_rng, array_gen_rng, - u8 + UInt8Type ) } DataType::UInt16 => { @@ -342,7 +346,7 @@ impl RecordBatchGenerator { num_rows, batch_gen_rng, array_gen_rng, - u16 + UInt16Type ) } DataType::UInt32 => { @@ -351,7 +355,7 @@ impl RecordBatchGenerator { num_rows, batch_gen_rng, array_gen_rng, - u32 + UInt32Type ) } DataType::UInt64 => { @@ -360,7 +364,7 @@ impl RecordBatchGenerator { num_rows, batch_gen_rng, array_gen_rng, - u64 + UInt64Type ) } DataType::Float32 => { @@ -369,7 +373,7 @@ impl RecordBatchGenerator { num_rows, batch_gen_rng, array_gen_rng, - f32 + Float32Type ) } DataType::Float64 => { @@ -378,7 +382,25 @@ impl RecordBatchGenerator { num_rows, batch_gen_rng, array_gen_rng, - f64 + Float64Type + ) + } + DataType::Date32 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + Date32Type + ) + } + DataType::Date64 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + Date64Type ) } DataType::Utf8 => { diff --git a/test-utils/src/array_gen/primitive.rs b/test-utils/src/array_gen/primitive.rs index f70ebf6686d0..0581862d63bd 100644 --- a/test-utils/src/array_gen/primitive.rs +++ b/test-utils/src/array_gen/primitive.rs @@ -15,14 +15,45 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{ArrayRef, PrimitiveArray, UInt32Array}; -use arrow::datatypes::{ - Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, - UInt32Type, UInt64Type, UInt8Type, -}; +use arrow::array::{ArrayRef, ArrowPrimitiveType, PrimitiveArray, UInt32Array}; +use arrow::datatypes::DataType; +use rand::distributions::Standard; +use rand::prelude::Distribution; use rand::rngs::StdRng; use rand::Rng; +/// Trait for converting type safely from a native type T impl this trait. +pub trait FromNative: std::fmt::Debug + Send + Sync + Copy + Default { + /// Convert native type from i64. + fn from_i64(_: i64) -> Option { + None + } +} + +macro_rules! native_type { + ($t: ty $(, $from:ident)*) => { + impl FromNative for $t { + $( + #[inline] + fn $from(v: $t) -> Option { + Some(v) + } + )* + } + }; +} + +native_type!(i8); +native_type!(i16); +native_type!(i32); +native_type!(i64, from_i64); +native_type!(u8); +native_type!(u16); +native_type!(u32); +native_type!(u64); +native_type!(f32); +native_type!(f64); + /// Randomly generate primitive array pub struct PrimitiveArrayGenerator { /// the total number of strings in the output @@ -35,46 +66,61 @@ pub struct PrimitiveArrayGenerator { pub rng: StdRng, } -macro_rules! impl_gen_data { - ($NATIVE_TYPE:ty, $ARROW_TYPE:ident) => { - paste::paste! { - pub fn [< gen_data_ $NATIVE_TYPE >](&mut self) -> ArrayRef { - // table of strings from which to draw - let distinct_primitives: PrimitiveArray<$ARROW_TYPE> = (0..self.num_distinct_primitives) - .map(|_| Some(self.rng.gen::<$NATIVE_TYPE>())) - .collect(); +// TODO: support generating more primitive arrays +impl PrimitiveArrayGenerator { + pub fn gen_data(&mut self) -> ArrayRef + where + A: ArrowPrimitiveType, + A::Native: FromNative, + Standard: Distribution<::Native>, + { + // table of primitives from which to draw + let distinct_primitives: PrimitiveArray = (0..self.num_distinct_primitives) + .map(|_| { + Some(match A::DATA_TYPE { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 + | DataType::Date32 => self.rng.gen::(), - // pick num_strings randomly from the distinct string table - let indicies: UInt32Array = (0..self.num_primitives) - .map(|_| { - if self.rng.gen::() < self.null_pct { - None - } else if self.num_distinct_primitives > 1 { - let range = 1..(self.num_distinct_primitives as u32); - Some(self.rng.gen_range(range)) - } else { - Some(0) - } - }) - .collect(); + DataType::Date64 => { + // TODO: constrain this range to valid dates if necessary + let date_value = self.rng.gen_range(i64::MIN..=i64::MAX); + let millis_per_day = 86_400_000; + let adjusted_value = date_value - (date_value % millis_per_day); + A::Native::from_i64(adjusted_value).unwrap() + } - let options = None; - arrow::compute::take(&distinct_primitives, &indicies, options).unwrap() - } - } - }; -} + _ => { + let arrow_type = A::DATA_TYPE; + panic!("Unsupported arrow data type: {arrow_type}") + } + }) + }) + .collect(); -// TODO: support generating more primitive arrays -impl PrimitiveArrayGenerator { - impl_gen_data!(i8, Int8Type); - impl_gen_data!(i16, Int16Type); - impl_gen_data!(i32, Int32Type); - impl_gen_data!(i64, Int64Type); - impl_gen_data!(u8, UInt8Type); - impl_gen_data!(u16, UInt16Type); - impl_gen_data!(u32, UInt32Type); - impl_gen_data!(u64, UInt64Type); - impl_gen_data!(f32, Float32Type); - impl_gen_data!(f64, Float64Type); + // pick num_primitves randomly from the distinct string table + let indicies: UInt32Array = (0..self.num_primitives) + .map(|_| { + if self.rng.gen::() < self.null_pct { + None + } else if self.num_distinct_primitives > 1 { + let range = 1..(self.num_distinct_primitives as u32); + Some(self.rng.gen_range(range)) + } else { + Some(0) + } + }) + .collect(); + + let options = None; + arrow::compute::take(&distinct_primitives, &indicies, options).unwrap() + } } From 22a242c15deafc78f5e6b42eb98408181979cf00 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Sat, 26 Oct 2024 06:30:34 -0400 Subject: [PATCH 083/110] docs: Added Special Functions Page (#13102) * Added special function page * Add index entry, tweak wording * Improve example * Update docs/source/user-guide/sql/special_functions.md --------- Co-authored-by: Andrew Lamb Co-authored-by: Oleks V --- docs/source/user-guide/sql/index.rst | 1 + .../source/user-guide/sql/scalar_functions.md | 77 -------------- .../user-guide/sql/special_functions.md | 100 ++++++++++++++++++ 3 files changed, 101 insertions(+), 77 deletions(-) create mode 100644 docs/source/user-guide/sql/special_functions.md diff --git a/docs/source/user-guide/sql/index.rst b/docs/source/user-guide/sql/index.rst index 6eb451c83b96..8b8afc7b048a 100644 --- a/docs/source/user-guide/sql/index.rst +++ b/docs/source/user-guide/sql/index.rst @@ -35,5 +35,6 @@ SQL Reference window_functions_new scalar_functions scalar_functions_new + special_functions sql_status write_options diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 203411428777..a8e25930bef7 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -72,51 +72,8 @@ See [date_part](#date_part). ## Array Functions -- [unnest](#unnest) - [range](#range) -### `unnest` - -Transforms an array into rows. - -#### Arguments - -- **array**: Array expression to unnest. - Can be a constant, column, or function, and any combination of array operators. - -#### Examples - -``` -> select unnest(make_array(1, 2, 3, 4, 5)); -+------------------------------------------------------------------+ -| unnest(make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5))) | -+------------------------------------------------------------------+ -| 1 | -| 2 | -| 3 | -| 4 | -| 5 | -+------------------------------------------------------------------+ -``` - -``` -> select unnest(range(0, 10)); -+-----------------------------------+ -| unnest(range(Int64(0),Int64(10))) | -+-----------------------------------+ -| 0 | -| 1 | -| 2 | -| 3 | -| 4 | -| 5 | -| 6 | -| 7 | -| 8 | -| 9 | -+-----------------------------------+ -``` - ### `range` Returns an Arrow array between start and stop with step. `SELECT range(2, 10, 3) -> [2, 5, 8]` or @@ -165,40 +122,6 @@ are not allowed - generate_series -## Struct Functions - -- [unnest](#unnest-struct) - -For more struct functions see the new documentation [ -`here`](https://datafusion.apache.org/user-guide/sql/scalar_functions_new.html) - -### `unnest (struct)` - -Unwraps struct fields into columns. - -#### Arguments - -- **struct**: Object expression to unnest. - Can be a constant, column, or function, and any combination of object operators. - -#### Examples - -``` -> select * from foo; -+---------------------+ -| column1 | -+---------------------+ -| {a: 5, b: a string} | -+---------------------+ - -> select unnest(column1) from foo; -+-----------------------+-----------------------+ -| unnest(foo.column1).a | unnest(foo.column1).b | -+-----------------------+-----------------------+ -| 5 | a string | -+-----------------------+-----------------------+ -``` - ## Other Functions See the new documentation [`here`](https://datafusion.apache.org/user-guide/sql/scalar_functions_new.html) diff --git a/docs/source/user-guide/sql/special_functions.md b/docs/source/user-guide/sql/special_functions.md new file mode 100644 index 000000000000..7c9efbb66218 --- /dev/null +++ b/docs/source/user-guide/sql/special_functions.md @@ -0,0 +1,100 @@ + + +# Special Functions + +## Expansion Functions + +- [unnest](#unnest) +- [unnest(struct)](#unnest-struct) + +### `unnest` + +Expands an array or map into rows. + +#### Arguments + +- **array**: Array expression to unnest. + Can be a constant, column, or function, and any combination of array operators. + +#### Examples + +```sql +> select unnest(make_array(1, 2, 3, 4, 5)) as unnested; ++----------+ +| unnested | ++----------+ +| 1 | +| 2 | +| 3 | +| 4 | +| 5 | ++----------+ +``` + +```sql +> select unnest(range(0, 10)) as unnested_range; ++----------------+ +| unnested_range | ++----------------+ +| 0 | +| 1 | +| 2 | +| 3 | +| 4 | +| 5 | +| 6 | +| 7 | +| 8 | +| 9 | ++----------------+ +``` + +### `unnest (struct)` + +Expand a struct fields into individual columns. + +#### Arguments + +- **struct**: Object expression to unnest. + Can be a constant, column, or function, and any combination of object operators. + +#### Examples + +```sql +> create table foo as values ({a: 5, b: 'a string'}), ({a:6, b: 'another string'}); + +> create view foov as select column1 as struct_column from foo; + +> select * from foov; ++---------------------------+ +| struct_column | ++---------------------------+ +| {a: 5, b: a string} | +| {a: 6, b: another string} | ++---------------------------+ + +> select unnest(struct_column) from foov; ++------------------------------------------+------------------------------------------+ +| unnest_placeholder(foov.struct_column).a | unnest_placeholder(foov.struct_column).b | ++------------------------------------------+------------------------------------------+ +| 5 | a string | +| 6 | another string | ++------------------------------------------+------------------------------------------+ +``` From d2511b258e1d3e286ca50a21e671633e2281b105 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Sat, 26 Oct 2024 18:41:26 +0800 Subject: [PATCH 084/110] fix: planning of prepare statement with limit clause (#13088) * fix: planning of prepare statement with limit clause * Improve test --- datafusion/sql/src/query.rs | 10 ++++++---- datafusion/sql/tests/sql_integration.rs | 23 +++++++++++++++++++++++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index 842a1c0cbec1..1ef009132f9e 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -53,7 +53,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // so we need to process `SELECT` and `ORDER BY` together. let oby_exprs = to_order_by_exprs(query.order_by)?; let plan = self.select_to_plan(*select, oby_exprs, planner_context)?; - let plan = self.limit(plan, query.offset, query.limit)?; + let plan = + self.limit(plan, query.offset, query.limit, planner_context)?; // Process the `SELECT INTO` after `LIMIT`. self.select_into(plan, select_into) } @@ -68,7 +69,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { None, )?; let plan = self.order_by(plan, order_by_rex)?; - self.limit(plan, query.offset, query.limit) + self.limit(plan, query.offset, query.limit, planner_context) } } } @@ -79,6 +80,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input: LogicalPlan, skip: Option, fetch: Option, + planner_context: &mut PlannerContext, ) -> Result { if skip.is_none() && fetch.is_none() { return Ok(input); @@ -88,10 +90,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let empty_schema = DFSchema::empty(); let skip = skip - .map(|o| self.sql_to_expr(o.value, &empty_schema, &mut PlannerContext::new())) + .map(|o| self.sql_to_expr(o.value, &empty_schema, planner_context)) .transpose()?; let fetch = fetch - .map(|e| self.sql_to_expr(e, &empty_schema, &mut PlannerContext::new())) + .map(|e| self.sql_to_expr(e, &empty_schema, planner_context)) .transpose()?; LogicalPlanBuilder::from(input) .limit_by_expr(skip, fetch)? diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index edb614493b38..698c408e538f 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -4209,6 +4209,29 @@ fn test_prepare_statement_to_plan_having() { prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); } +#[test] +fn test_prepare_statement_to_plan_limit() { + let sql = "PREPARE my_plan(BIGINT, BIGINT) AS + SELECT id FROM person \ + OFFSET $1 LIMIT $2"; + + let expected_plan = "Prepare: \"my_plan\" [Int64, Int64] \ + \n Limit: skip=$1, fetch=$2\ + \n Projection: person.id\ + \n TableScan: person"; + + let expected_dt = "[Int64, Int64]"; + + let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + + // replace params with values + let param_values = vec![ScalarValue::Int64(Some(10)), ScalarValue::Int64(Some(200))]; + let expected_plan = "Limit: skip=10, fetch=200\ + \n Projection: person.id\ + \n TableScan: person"; + prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); +} + #[test] fn test_prepare_statement_to_plan_value_list() { let sql = "PREPARE my_plan(STRING, STRING) AS SELECT * FROM (VALUES(1, $1), (2, $2)) AS t (num, letter);"; From 7df3e5cd11f63226b90783564ae7268ee2512ec1 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Sat, 26 Oct 2024 18:59:09 +0800 Subject: [PATCH 085/110] Add benchmark for memory-limited aggregation (#13090) * Adding benchmark for external aggregation * comments --- benchmarks/README.md | 28 ++ benchmarks/bench.sh | 25 +- benchmarks/src/bin/external_aggr.rs | 390 ++++++++++++++++++++ datafusion/execution/src/memory_pool/mod.rs | 12 +- 4 files changed, 450 insertions(+), 5 deletions(-) create mode 100644 benchmarks/src/bin/external_aggr.rs diff --git a/benchmarks/README.md b/benchmarks/README.md index a12662ccb846..a9aa1afb97a1 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -352,6 +352,34 @@ This benchmarks is derived from the [TPC-H][1] version [2]: https://github.com/databricks/tpch-dbgen.git, [2.17.1]: https://www.tpc.org/tpc_documents_current_versions/pdf/tpc-h_v2.17.1.pdf +## External Aggregation + +Run the benchmark for aggregations with limited memory. + +When the memory limit is exceeded, the aggregation intermediate results will be spilled to disk, and finally read back with sort-merge. + +External aggregation benchmarks run several aggregation queries with different memory limits, on TPCH `lineitem` table. Queries can be found in [`external_aggr.rs`](src/bin/external_aggr.rs). + +This benchmark is inspired by [DuckDB's external aggregation paper](https://hannes.muehleisen.org/publications/icde2024-out-of-core-kuiper-boncz-muehleisen.pdf), specifically Section VI. + +### External Aggregation Example Runs +1. Run all queries with predefined memory limits: +```bash +# Under 'benchmarks/' directory +cargo run --release --bin external_aggr -- benchmark -n 4 --iterations 3 -p '....../data/tpch_sf1' -o '/tmp/aggr.json' +``` + +2. Run a query with specific memory limit: +```bash +cargo run --release --bin external_aggr -- benchmark -n 4 --iterations 3 -p '....../data/tpch_sf1' -o '/tmp/aggr.json' --query 1 --memory-limit 30M +``` + +3. Run all queries with `bench.sh` script: +```bash +./bench.sh data external_aggr +./bench.sh run external_aggr +``` + # Older Benchmarks diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh index fc10cc5afc53..47c5d1261605 100755 --- a/benchmarks/bench.sh +++ b/benchmarks/bench.sh @@ -78,6 +78,7 @@ sort: Benchmark of sorting speed clickbench_1: ClickBench queries against a single parquet file clickbench_partitioned: ClickBench queries against a partitioned (100 files) parquet clickbench_extended: ClickBench \"inspired\" queries against a single parquet (DataFusion specific) +external_aggr: External aggregation benchmark ********** * Supported Configuration (Environment Variables) @@ -170,6 +171,10 @@ main() { imdb) data_imdb ;; + external_aggr) + # same data as for tpch + data_tpch "1" + ;; *) echo "Error: unknown benchmark '$BENCHMARK' for data generation" usage @@ -212,6 +217,7 @@ main() { run_clickbench_partitioned run_clickbench_extended run_imdb + run_external_aggr ;; tpch) run_tpch "1" @@ -243,6 +249,9 @@ main() { imdb) run_imdb ;; + external_aggr) + run_external_aggr + ;; *) echo "Error: unknown benchmark '$BENCHMARK' for run" usage @@ -524,7 +533,21 @@ run_imdb() { $CARGO_COMMAND --bin imdb -- benchmark datafusion --iterations 5 --path "${IMDB_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format parquet -o "${RESULTS_FILE}" } - +# Runs the external aggregation benchmark +run_external_aggr() { + # Use TPC-H SF1 dataset + TPCH_DIR="${DATA_DIR}/tpch_sf1" + RESULTS_FILE="${RESULTS_DIR}/external_aggr.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running external aggregation benchmark..." + + # Only parquet is supported. + # Since per-operator memory limit is calculated as (total-memory-limit / + # number-of-partitions), and by default `--partitions` is set to number of + # CPU cores, we set a constant number of partitions to prevent this + # benchmark to fail on some machines. + $CARGO_COMMAND --bin external_aggr -- benchmark --partitions 4 --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" +} compare_benchmarks() { diff --git a/benchmarks/src/bin/external_aggr.rs b/benchmarks/src/bin/external_aggr.rs new file mode 100644 index 000000000000..1bc74e22ccfa --- /dev/null +++ b/benchmarks/src/bin/external_aggr.rs @@ -0,0 +1,390 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! external_aggr binary entrypoint + +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::Arc; +use std::sync::OnceLock; +use structopt::StructOpt; + +use arrow::record_batch::RecordBatch; +use arrow::util::pretty; +use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::listing::{ + ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, +}; +use datafusion::datasource::{MemTable, TableProvider}; +use datafusion::error::Result; +use datafusion::execution::memory_pool::FairSpillPool; +use datafusion::execution::memory_pool::{human_readable_size, units}; +use datafusion::execution::runtime_env::RuntimeConfig; +use datafusion::physical_plan::display::DisplayableExecutionPlan; +use datafusion::physical_plan::{collect, displayable}; +use datafusion::prelude::*; +use datafusion_benchmarks::util::{BenchmarkRun, CommonOpt}; +use datafusion_common::instant::Instant; +use datafusion_common::{exec_datafusion_err, exec_err, DEFAULT_PARQUET_EXTENSION}; + +#[derive(Debug, StructOpt)] +#[structopt( + name = "datafusion-external-aggregation", + about = "DataFusion external aggregation benchmark" +)] +enum ExternalAggrOpt { + Benchmark(ExternalAggrConfig), +} + +#[derive(Debug, StructOpt)] +struct ExternalAggrConfig { + /// Query number. If not specified, runs all queries + #[structopt(short, long)] + query: Option, + + /// Memory limit (e.g. '100M', '1.5G'). If not specified, run all pre-defined memory limits for given query. + #[structopt(long)] + memory_limit: Option, + + /// Common options + #[structopt(flatten)] + common: CommonOpt, + + /// Path to data files (lineitem). Only parquet format is supported + #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] + path: PathBuf, + + /// Load the data into a MemTable before executing the query + #[structopt(short = "m", long = "mem-table")] + mem_table: bool, + + /// Path to JSON benchmark result to be compare using `compare.py` + #[structopt(parse(from_os_str), short = "o", long = "output")] + output_path: Option, +} + +struct QueryResult { + elapsed: std::time::Duration, + row_count: usize, +} + +/// Query Memory Limits +/// Map query id to predefined memory limits +/// +/// Q1 requires 36MiB for aggregation +/// Memory limits to run: 64MiB, 32MiB, 16MiB +/// Q2 requires 250MiB for aggregation +/// Memory limits to run: 512MiB, 256MiB, 128MiB, 64MiB, 32MiB +static QUERY_MEMORY_LIMITS: OnceLock>> = OnceLock::new(); + +impl ExternalAggrConfig { + const AGGR_TABLES: [&'static str; 1] = ["lineitem"]; + const AGGR_QUERIES: [&'static str; 2] = [ + // Q1: Output size is ~25% of lineitem table + r#" + SELECT count(*) + FROM ( + SELECT DISTINCT l_orderkey + FROM lineitem + ) + "#, + // Q2: Output size is ~99% of lineitem table + r#" + SELECT count(*) + FROM ( + SELECT DISTINCT l_orderkey, l_suppkey + FROM lineitem + ) + "#, + ]; + + fn init_query_memory_limits() -> &'static HashMap> { + use units::*; + QUERY_MEMORY_LIMITS.get_or_init(|| { + let mut map = HashMap::new(); + map.insert(1, vec![64 * MB, 32 * MB, 16 * MB]); + map.insert(2, vec![512 * MB, 256 * MB, 128 * MB, 64 * MB, 32 * MB]); + map + }) + } + + /// If `--query` and `--memory-limit` is not speicified, run all queries + /// with pre-configured memory limits + /// If only `--query` is specified, run the query with all memory limits + /// for this query + /// If both `--query` and `--memory-limit` are specified, run the query + /// with the specified memory limit + pub async fn run(&self) -> Result<()> { + let mut benchmark_run = BenchmarkRun::new(); + + let memory_limit = match &self.memory_limit { + Some(limit) => Some(Self::parse_memory_limit(limit)?), + None => None, + }; + + let query_range = match self.query { + Some(query_id) => query_id..=query_id, + None => 1..=Self::AGGR_QUERIES.len(), + }; + + // Each element is (query_id, memory_limit) + // e.g. [(1, 64_000), (1, 32_000)...] means first run Q1 with 64KiB + // memory limit, next run Q1 with 32KiB memory limit, etc. + let mut query_executions = vec![]; + // Setup `query_executions` + for query_id in query_range { + if query_id > Self::AGGR_QUERIES.len() { + return exec_err!( + "Invalid '--query'(query number) {} for external aggregation benchmark.", + query_id + ); + } + + match memory_limit { + Some(limit) => { + query_executions.push((query_id, limit)); + } + None => { + let memory_limits_table = Self::init_query_memory_limits(); + let memory_limits = memory_limits_table.get(&query_id).unwrap(); + for limit in memory_limits { + query_executions.push((query_id, *limit)); + } + } + } + } + + for (query_id, mem_limit) in query_executions { + benchmark_run.start_new_case(&format!( + "{query_id}({})", + human_readable_size(mem_limit as usize) + )); + + let query_results = self.benchmark_query(query_id, mem_limit).await?; + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + + benchmark_run.maybe_write_json(self.output_path.as_ref())?; + + Ok(()) + } + + /// Benchmark query `query_id` in `AGGR_QUERIES` + async fn benchmark_query( + &self, + query_id: usize, + mem_limit: u64, + ) -> Result> { + let query_name = + format!("Q{query_id}({})", human_readable_size(mem_limit as usize)); + let mut config = self.common.config(); + config + .options_mut() + .execution + .parquet + .schema_force_view_types = self.common.force_view_types; + let runtime_config = RuntimeConfig::new() + .with_memory_pool(Arc::new(FairSpillPool::new(mem_limit as usize))) + .build_arc()?; + let ctx = SessionContext::new_with_config_rt(config, runtime_config); + + // register tables + self.register_tables(&ctx).await?; + + let mut millis = vec![]; + // run benchmark + let mut query_results = vec![]; + for i in 0..self.iterations() { + let start = Instant::now(); + + let query_idx = query_id - 1; // 1-indexed -> 0-indexed + let sql = Self::AGGR_QUERIES[query_idx]; + + let result = self.execute_query(&ctx, sql).await?; + + let elapsed = start.elapsed(); //.as_secs_f64() * 1000.0; + let ms = elapsed.as_secs_f64() * 1000.0; + millis.push(ms); + + let row_count = result.iter().map(|b| b.num_rows()).sum(); + println!( + "{query_name} iteration {i} took {ms:.1} ms and returned {row_count} rows" + ); + query_results.push(QueryResult { elapsed, row_count }); + } + + let avg = millis.iter().sum::() / millis.len() as f64; + println!("{query_name} avg time: {avg:.2} ms"); + + Ok(query_results) + } + + async fn register_tables(&self, ctx: &SessionContext) -> Result<()> { + for table in Self::AGGR_TABLES { + let table_provider = { self.get_table(ctx, table).await? }; + + if self.mem_table { + println!("Loading table '{table}' into memory"); + let start = Instant::now(); + let memtable = + MemTable::load(table_provider, Some(self.partitions()), &ctx.state()) + .await?; + println!( + "Loaded table '{}' into memory in {} ms", + table, + start.elapsed().as_millis() + ); + ctx.register_table(table, Arc::new(memtable))?; + } else { + ctx.register_table(table, table_provider)?; + } + } + Ok(()) + } + + async fn execute_query( + &self, + ctx: &SessionContext, + sql: &str, + ) -> Result> { + let debug = self.common.debug; + let plan = ctx.sql(sql).await?; + let (state, plan) = plan.into_parts(); + + if debug { + println!("=== Logical plan ===\n{plan}\n"); + } + + let plan = state.optimize(&plan)?; + if debug { + println!("=== Optimized logical plan ===\n{plan}\n"); + } + let physical_plan = state.create_physical_plan(&plan).await?; + if debug { + println!( + "=== Physical plan ===\n{}\n", + displayable(physical_plan.as_ref()).indent(true) + ); + } + let result = collect(physical_plan.clone(), state.task_ctx()).await?; + if debug { + println!( + "=== Physical plan with metrics ===\n{}\n", + DisplayableExecutionPlan::with_metrics(physical_plan.as_ref()) + .indent(true) + ); + if !result.is_empty() { + // do not call print_batches if there are no batches as the result is confusing + // and makes it look like there is a batch with no columns + pretty::print_batches(&result)?; + } + } + Ok(result) + } + + async fn get_table( + &self, + ctx: &SessionContext, + table: &str, + ) -> Result> { + let path = self.path.to_str().unwrap(); + + // Obtain a snapshot of the SessionState + let state = ctx.state(); + let path = format!("{path}/{table}"); + let format = Arc::new( + ParquetFormat::default() + .with_options(ctx.state().table_options().parquet.clone()), + ); + let extension = DEFAULT_PARQUET_EXTENSION; + + let options = ListingOptions::new(format) + .with_file_extension(extension) + .with_collect_stat(state.config().collect_statistics()); + + let table_path = ListingTableUrl::parse(path)?; + let config = ListingTableConfig::new(table_path).with_listing_options(options); + let config = config.infer_schema(&state).await?; + + Ok(Arc::new(ListingTable::try_new(config)?)) + } + + fn iterations(&self) -> usize { + self.common.iterations + } + + fn partitions(&self) -> usize { + self.common.partitions.unwrap_or(num_cpus::get()) + } + + /// Parse memory limit from string to number of bytes + /// e.g. '1.5G', '100M' -> 1572864 + fn parse_memory_limit(limit: &str) -> Result { + let (number, unit) = limit.split_at(limit.len() - 1); + let number: f64 = number.parse().map_err(|_| { + exec_datafusion_err!("Failed to parse number from memory limit '{}'", limit) + })?; + + match unit { + "K" => Ok((number * 1024.0) as u64), + "M" => Ok((number * 1024.0 * 1024.0) as u64), + "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as u64), + _ => exec_err!("Unsupported unit '{}' in memory limit '{}'", unit, limit), + } + } +} + +#[tokio::main] +pub async fn main() -> Result<()> { + env_logger::init(); + + match ExternalAggrOpt::from_args() { + ExternalAggrOpt::Benchmark(opt) => opt.run().await?, + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_memory_limit_all() { + // Test valid inputs + assert_eq!( + ExternalAggrConfig::parse_memory_limit("100K").unwrap(), + 102400 + ); + assert_eq!( + ExternalAggrConfig::parse_memory_limit("1.5M").unwrap(), + 1572864 + ); + assert_eq!( + ExternalAggrConfig::parse_memory_limit("2G").unwrap(), + 2147483648 + ); + + // Test invalid unit + assert!(ExternalAggrConfig::parse_memory_limit("500X").is_err()); + + // Test invalid number + assert!(ExternalAggrConfig::parse_memory_limit("abcM").is_err()); + } +} diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index d87ce1ebfed7..5bf30b724d0b 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -334,13 +334,17 @@ impl Drop for MemoryReservation { } } -const TB: u64 = 1 << 40; -const GB: u64 = 1 << 30; -const MB: u64 = 1 << 20; -const KB: u64 = 1 << 10; +pub mod units { + pub const TB: u64 = 1 << 40; + pub const GB: u64 = 1 << 30; + pub const MB: u64 = 1 << 20; + pub const KB: u64 = 1 << 10; +} /// Present size in human readable form pub fn human_readable_size(size: usize) -> String { + use units::*; + let size = size as u64; let (value, unit) = { if size >= 2 * TB { From 412ca4e1ffe0a005cb772393f1e920ab395a6200 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Sat, 26 Oct 2024 07:38:32 -0400 Subject: [PATCH 086/110] Add clickbench parquet based queries to sql_planner benchmark (#13103) * Add clickbench parquet based queries to sql_planner benchmark. * Cargo fmt. * Commented out most logical_plan tests & updated code to allow for running from either cargo or via target/release/deps/sql_planner-xyz --- datafusion/core/benches/sql_planner.rs | 121 ++++++++++++++++++++++--- 1 file changed, 107 insertions(+), 14 deletions(-) diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index 6f9cf02873d1..140e266a0272 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -15,22 +15,31 @@ // specific language governing permissions and limitations // under the License. +extern crate arrow; #[macro_use] extern crate criterion; -extern crate arrow; extern crate datafusion; mod data_utils; + use crate::criterion::Criterion; use arrow::datatypes::{DataType, Field, Fields, Schema}; use datafusion::datasource::MemTable; use datafusion::execution::context::SessionContext; +use itertools::Itertools; +use std::fs::File; +use std::io::{BufRead, BufReader}; +use std::path::PathBuf; use std::sync::Arc; use test_utils::tpcds::tpcds_schemas; use test_utils::tpch::tpch_schemas; use test_utils::TableDef; use tokio::runtime::Runtime; +const BENCHMARKS_PATH_1: &str = "../../benchmarks/"; +const BENCHMARKS_PATH_2: &str = "./benchmarks/"; +const CLICKBENCH_DATA_PATH: &str = "data/hits_partitioned/"; + /// Create a logical plan from the specified sql fn logical_plan(ctx: &SessionContext, sql: &str) { let rt = Runtime::new().unwrap(); @@ -91,7 +100,37 @@ fn register_defs(ctx: SessionContext, defs: Vec) -> SessionContext { ctx } +fn register_clickbench_hits_table() -> SessionContext { + let ctx = SessionContext::new(); + let rt = Runtime::new().unwrap(); + + // use an external table for clickbench benchmarks + let path = + if PathBuf::from(format!("{BENCHMARKS_PATH_1}{CLICKBENCH_DATA_PATH}")).exists() { + format!("{BENCHMARKS_PATH_1}{CLICKBENCH_DATA_PATH}") + } else { + format!("{BENCHMARKS_PATH_2}{CLICKBENCH_DATA_PATH}") + }; + + let sql = format!("CREATE EXTERNAL TABLE hits STORED AS PARQUET LOCATION '{path}'"); + + rt.block_on(ctx.sql(&sql)).unwrap(); + + let count = + rt.block_on(async { ctx.table("hits").await.unwrap().count().await.unwrap() }); + assert!(count > 0); + ctx +} + fn criterion_benchmark(c: &mut Criterion) { + // verify that we can load the clickbench data prior to running the benchmark + if !PathBuf::from(format!("{BENCHMARKS_PATH_1}{CLICKBENCH_DATA_PATH}")).exists() + && !PathBuf::from(format!("{BENCHMARKS_PATH_2}{CLICKBENCH_DATA_PATH}")).exists() + { + panic!("benchmarks/data/hits_partitioned/ could not be loaded. Please run \ + 'benchmarks/bench.sh data clickbench_partitioned' prior to running this benchmark") + } + let ctx = create_context(); // Test simplest @@ -235,9 +274,15 @@ fn criterion_benchmark(c: &mut Criterion) { "q16", "q17", "q18", "q19", "q20", "q21", "q22", ]; + let benchmarks_path = if PathBuf::from(BENCHMARKS_PATH_1).exists() { + BENCHMARKS_PATH_1 + } else { + BENCHMARKS_PATH_2 + }; + for q in tpch_queries { let sql = - std::fs::read_to_string(format!("../../benchmarks/queries/{q}.sql")).unwrap(); + std::fs::read_to_string(format!("{benchmarks_path}queries/{q}.sql")).unwrap(); c.bench_function(&format!("physical_plan_tpch_{}", q), |b| { b.iter(|| physical_plan(&tpch_ctx, &sql)) }); @@ -246,7 +291,7 @@ fn criterion_benchmark(c: &mut Criterion) { let all_tpch_sql_queries = tpch_queries .iter() .map(|q| { - std::fs::read_to_string(format!("../../benchmarks/queries/{q}.sql")).unwrap() + std::fs::read_to_string(format!("{benchmarks_path}queries/{q}.sql")).unwrap() }) .collect::>(); @@ -258,20 +303,25 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - c.bench_function("logical_plan_tpch_all", |b| { - b.iter(|| { - for sql in &all_tpch_sql_queries { - logical_plan(&tpch_ctx, sql) - } - }) - }); + // c.bench_function("logical_plan_tpch_all", |b| { + // b.iter(|| { + // for sql in &all_tpch_sql_queries { + // logical_plan(&tpch_ctx, sql) + // } + // }) + // }); // --- TPC-DS --- let tpcds_ctx = register_defs(SessionContext::new(), tpcds_schemas()); + let tests_path = if PathBuf::from("./tests/").exists() { + "./tests/" + } else { + "datafusion/core/tests/" + }; let raw_tpcds_sql_queries = (1..100) - .map(|q| std::fs::read_to_string(format!("./tests/tpc-ds/{q}.sql")).unwrap()) + .map(|q| std::fs::read_to_string(format!("{tests_path}tpc-ds/{q}.sql")).unwrap()) .collect::>(); // some queries have multiple statements @@ -288,10 +338,53 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - c.bench_function("logical_plan_tpcds_all", |b| { + // c.bench_function("logical_plan_tpcds_all", |b| { + // b.iter(|| { + // for sql in &all_tpcds_sql_queries { + // logical_plan(&tpcds_ctx, sql) + // } + // }) + // }); + + // -- clickbench -- + + let queries_file = + File::open(format!("{benchmarks_path}queries/clickbench/queries.sql")).unwrap(); + let extended_file = + File::open(format!("{benchmarks_path}queries/clickbench/extended.sql")).unwrap(); + + let clickbench_queries: Vec = BufReader::new(queries_file) + .lines() + .chain(BufReader::new(extended_file).lines()) + .map(|l| l.expect("Could not parse line")) + .collect_vec(); + + let clickbench_ctx = register_clickbench_hits_table(); + + // for (i, sql) in clickbench_queries.iter().enumerate() { + // c.bench_function(&format!("logical_plan_clickbench_q{}", i + 1), |b| { + // b.iter(|| logical_plan(&clickbench_ctx, sql)) + // }); + // } + + for (i, sql) in clickbench_queries.iter().enumerate() { + c.bench_function(&format!("physical_plan_clickbench_q{}", i + 1), |b| { + b.iter(|| physical_plan(&clickbench_ctx, sql)) + }); + } + + // c.bench_function("logical_plan_clickbench_all", |b| { + // b.iter(|| { + // for sql in &clickbench_queries { + // logical_plan(&clickbench_ctx, sql) + // } + // }) + // }); + + c.bench_function("physical_plan_clickbench_all", |b| { b.iter(|| { - for sql in &all_tpcds_sql_queries { - logical_plan(&tpcds_ctx, sql) + for sql in &clickbench_queries { + physical_plan(&clickbench_ctx, sql) } }) }); From 62b063cd36653b92a9f0cd53a358231be8c3e848 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 26 Oct 2024 07:41:41 -0400 Subject: [PATCH 087/110] Improve documentation and examples for `SchemaAdapterFactory`, make `record_batch` "hygenic" (#13063) * Improve documentation and examples for SchemaAdapterFactory and related classes * fix macro * Add macro hygene test * Fix example, add convenience function, update docs * Add tests and docs showing what happens when adapting a nullable column * review feedback * fix clippy --- datafusion/common/src/test_util.rs | 2 +- .../core/src/datasource/schema_adapter.rs | 286 ++++++++++++++---- datafusion/core/tests/macro_hygiene/mod.rs | 10 + 3 files changed, 239 insertions(+), 59 deletions(-) diff --git a/datafusion/common/src/test_util.rs b/datafusion/common/src/test_util.rs index 422fcb5eb3e0..d3b8c8451258 100644 --- a/datafusion/common/src/test_util.rs +++ b/datafusion/common/src/test_util.rs @@ -347,7 +347,7 @@ macro_rules! record_batch { let batch = arrow_array::RecordBatch::try_new( schema, vec![$( - create_array!($type, $values), + $crate::create_array!($type, $values), )*] ); diff --git a/datafusion/core/src/datasource/schema_adapter.rs b/datafusion/core/src/datasource/schema_adapter.rs index 131b8c354ce7..80d2bf987473 100644 --- a/datafusion/core/src/datasource/schema_adapter.rs +++ b/datafusion/core/src/datasource/schema_adapter.rs @@ -32,11 +32,19 @@ use std::sync::Arc; /// /// This interface provides a way to implement custom schema adaptation logic /// for ParquetExec (for example, to fill missing columns with default value -/// other than null) +/// other than null). +/// +/// Most users should use [`DefaultSchemaAdapterFactory`]. See that struct for +/// more details and examples. pub trait SchemaAdapterFactory: Debug + Send + Sync + 'static { - /// Provides `SchemaAdapter`. - // The design of this function is mostly modeled for the needs of DefaultSchemaAdapterFactory, - // read its implementation docs for the reasoning + /// Create a [`SchemaAdapter`] + /// + /// Arguments: + /// + /// * `projected_table_schema`: The schema for the table, projected to + /// include only the fields being output (projected) by the this mapping. + /// + /// * `table_schema`: The entire table schema for the table fn create( &self, projected_table_schema: SchemaRef, @@ -44,53 +52,57 @@ pub trait SchemaAdapterFactory: Debug + Send + Sync + 'static { ) -> Box; } -/// Adapt file-level [`RecordBatch`]es to a table schema, which may have a schema -/// obtained from merging multiple file-level schemas. -/// -/// This is useful for enabling schema evolution in partitioned datasets. -/// -/// This has to be done in two stages. +/// Creates [`SchemaMapper`]s to map file-level [`RecordBatch`]es to a table +/// schema, which may have a schema obtained from merging multiple file-level +/// schemas. /// -/// 1. Before reading the file, we have to map projected column indexes from the -/// table schema to the file schema. +/// This is useful for implementing schema evolution in partitioned datasets. /// -/// 2. After reading a record batch map the read columns back to the expected -/// columns indexes and insert null-valued columns wherever the file schema was -/// missing a column present in the table schema. +/// See [`DefaultSchemaAdapterFactory`] for more details and examples. pub trait SchemaAdapter: Send + Sync { /// Map a column index in the table schema to a column index in a particular /// file schema /// + /// This is used while reading a file to push down projections by mapping + /// projected column indexes from the table schema to the file schema + /// /// Panics if index is not in range for the table schema fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option; - /// Creates a `SchemaMapping` that can be used to cast or map the columns - /// from the file schema to the table schema. + /// Creates a mapping for casting columns from the file schema to the table + /// schema. /// - /// If the provided `file_schema` contains columns of a different type to the expected - /// `table_schema`, the method will attempt to cast the array data from the file schema - /// to the table schema where possible. + /// This is used after reading a record batch. The returned [`SchemaMapper`]: /// - /// Returns a [`SchemaMapper`] that can be applied to the output batch - /// along with an ordered list of columns to project from the file + /// 1. Maps columns to the expected columns indexes + /// 2. Handles missing values (e.g. fills nulls or a default value) for + /// columns in the in the table schema not in the file schema + /// 2. Handles different types: if the column in the file schema has a + /// different type than `table_schema`, the mapper will resolve this + /// difference (e.g. by casting to the appropriate type) + /// + /// Returns: + /// * a [`SchemaMapper`] + /// * an ordered list of columns to project from the file fn map_schema( &self, file_schema: &Schema, ) -> datafusion_common::Result<(Arc, Vec)>; } -/// Maps, by casting or reordering columns from the file schema to the table -/// schema. +/// Maps, columns from a specific file schema to the table schema. +/// +/// See [`DefaultSchemaAdapterFactory`] for more details and examples. pub trait SchemaMapper: Debug + Send + Sync { - /// Adapts a `RecordBatch` to match the `table_schema` using the stored - /// mapping and conversions. + /// Adapts a `RecordBatch` to match the `table_schema` fn map_batch(&self, batch: RecordBatch) -> datafusion_common::Result; /// Adapts a [`RecordBatch`] that does not have all the columns from the /// file schema. /// - /// This method is used when applying a filter to a subset of the columns as - /// part of `DataFusionArrowPredicate` when `filter_pushdown` is enabled. + /// This method is used, for example, when applying a filter to a subset of + /// the columns as part of `DataFusionArrowPredicate` when `filter_pushdown` + /// is enabled. /// /// This method is slower than `map_batch` as it looks up columns by name. fn map_partial_batch( @@ -99,11 +111,106 @@ pub trait SchemaMapper: Debug + Send + Sync { ) -> datafusion_common::Result; } -/// Implementation of [`SchemaAdapterFactory`] that maps columns by name -/// and casts columns to the expected type. +/// Default [`SchemaAdapterFactory`] for mapping schemas. +/// +/// This can be used to adapt file-level record batches to a table schema and +/// implement schema evolution. +/// +/// Given an input file schema and a table schema, this factory returns +/// [`SchemaAdapter`] that return [`SchemaMapper`]s that: +/// +/// 1. Reorder columns +/// 2. Cast columns to the correct type +/// 3. Fill missing columns with nulls +/// +/// # Errors: +/// +/// * If a column in the table schema is non-nullable but is not present in the +/// file schema (i.e. it is missing), the returned mapper tries to fill it with +/// nulls resulting in a schema error. +/// +/// # Illustration of Schema Mapping +/// +/// ```text +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// ┌───────┐ ┌───────┐ │ ┌───────┐ ┌───────┐ ┌───────┐ │ +/// ││ 1.0 │ │ "foo" │ ││ NULL │ │ "foo" │ │ "1.0" │ +/// ├───────┤ ├───────┤ │ Schema mapping ├───────┤ ├───────┤ ├───────┤ │ +/// ││ 2.0 │ │ "bar" │ ││ NULL │ │ "bar" │ │ "2.0" │ +/// └───────┘ └───────┘ │────────────────▶ └───────┘ └───────┘ └───────┘ │ +/// │ │ +/// column "c" column "b"│ column "a" column "b" column "c"│ +/// │ Float64 Utf8 │ Int32 Utf8 Utf8 +/// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ +/// Input Record Batch Output Record Batch +/// +/// Schema { Schema { +/// "c": Float64, "a": Int32, +/// "b": Utf8, "b": Utf8, +/// } "c": Utf8, +/// } +/// ``` +/// +/// # Example of using the `DefaultSchemaAdapterFactory` to map [`RecordBatch`]s +/// +/// Note `SchemaMapping` also supports mapping partial batches, which is used as +/// part of predicate pushdown. +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow::datatypes::{DataType, Field, Schema}; +/// # use datafusion::datasource::schema_adapter::{DefaultSchemaAdapterFactory, SchemaAdapterFactory}; +/// # use datafusion_common::record_batch; +/// // Table has fields "a", "b" and "c" +/// let table_schema = Schema::new(vec![ +/// Field::new("a", DataType::Int32, true), +/// Field::new("b", DataType::Utf8, true), +/// Field::new("c", DataType::Utf8, true), +/// ]); +/// +/// // create an adapter to map the table schema to the file schema +/// let adapter = DefaultSchemaAdapterFactory::from_schema(Arc::new(table_schema)); +/// +/// // The file schema has fields "c" and "b" but "b" is stored as an 'Float64' +/// // instead of 'Utf8' +/// let file_schema = Schema::new(vec![ +/// Field::new("c", DataType::Utf8, true), +/// Field::new("b", DataType::Float64, true), +/// ]); +/// +/// // Get a mapping from the file schema to the table schema +/// let (mapper, _indices) = adapter.map_schema(&file_schema).unwrap(); +/// +/// let file_batch = record_batch!( +/// ("c", Utf8, vec!["foo", "bar"]), +/// ("b", Float64, vec![1.0, 2.0]) +/// ).unwrap(); +/// +/// let mapped_batch = mapper.map_batch(file_batch).unwrap(); +/// +/// // the mapped batch has the correct schema and the "b" column has been cast to Utf8 +/// let expected_batch = record_batch!( +/// ("a", Int32, vec![None, None]), // missing column filled with nulls +/// ("b", Utf8, vec!["1.0", "2.0"]), // b was cast to string and order was changed +/// ("c", Utf8, vec!["foo", "bar"]) +/// ).unwrap(); +/// assert_eq!(mapped_batch, expected_batch); +/// ``` #[derive(Clone, Debug, Default)] pub struct DefaultSchemaAdapterFactory; +impl DefaultSchemaAdapterFactory { + /// Create a new factory for mapping batches from a file schema to a table + /// schema. + /// + /// This is a convenience for [`DefaultSchemaAdapterFactory::create`] with + /// the same schema for both the projected table schema and the table + /// schema. + pub fn from_schema(table_schema: SchemaRef) -> Box { + Self.create(Arc::clone(&table_schema), table_schema) + } +} + impl SchemaAdapterFactory for DefaultSchemaAdapterFactory { fn create( &self, @@ -117,8 +224,8 @@ impl SchemaAdapterFactory for DefaultSchemaAdapterFactory { } } -/// This SchemaAdapter requires both the table schema and the projected table schema because of the -/// needs of the [`SchemaMapping`] it creates. Read its documentation for more details +/// This SchemaAdapter requires both the table schema and the projected table +/// schema. See [`SchemaMapping`] for more details #[derive(Clone, Debug)] pub(crate) struct DefaultSchemaAdapter { /// The schema for the table, projected to include only the fields being output (projected) by the @@ -142,11 +249,12 @@ impl SchemaAdapter for DefaultSchemaAdapter { Some(file_schema.fields.find(field.name())?.0) } - /// Creates a `SchemaMapping` that can be used to cast or map the columns from the file schema to the table schema. + /// Creates a `SchemaMapping` for casting or mapping the columns from the + /// file schema to the table schema. /// - /// If the provided `file_schema` contains columns of a different type to the expected - /// `table_schema`, the method will attempt to cast the array data from the file schema - /// to the table schema where possible. + /// If the provided `file_schema` contains columns of a different type to + /// the expected `table_schema`, the method will attempt to cast the array + /// data from the file schema to the table schema where possible. /// /// Returns a [`SchemaMapping`] that can be applied to the output batch /// along with an ordered list of columns to project from the file @@ -189,36 +297,45 @@ impl SchemaAdapter for DefaultSchemaAdapter { } } -/// The SchemaMapping struct holds a mapping from the file schema to the table schema -/// and any necessary type conversions that need to be applied. +/// The SchemaMapping struct holds a mapping from the file schema to the table +/// schema and any necessary type conversions. +/// +/// Note, because `map_batch` and `map_partial_batch` functions have different +/// needs, this struct holds two schemas: +/// +/// 1. The projected **table** schema +/// 2. The full table schema /// -/// This needs both the projected table schema and full table schema because its different -/// functions have different needs. The [`map_batch`] function is only used by the ParquetOpener to -/// produce a RecordBatch which has the projected schema, since that's the schema which is supposed -/// to come out of the execution of this query. [`map_partial_batch`], however, is used to create a -/// RecordBatch with a schema that can be used for Parquet pushdown, meaning that it may contain -/// fields which are not in the projected schema (as the fields that parquet pushdown filters -/// operate can be completely distinct from the fields that are projected (output) out of the -/// ParquetExec). +/// [`map_batch`] is used by the ParquetOpener to produce a RecordBatch which +/// has the projected schema, since that's the schema which is supposed to come +/// out of the execution of this query. Thus `map_batch` uses +/// `projected_table_schema` as it can only operate on the projected fields. /// -/// [`map_partial_batch`] uses `table_schema` to create the resulting RecordBatch (as it could be -/// operating on any fields in the schema), while [`map_batch`] uses `projected_table_schema` (as -/// it can only operate on the projected fields). +/// [`map_partial_batch`] is used to create a RecordBatch with a schema that +/// can be used for Parquet predicate pushdown, meaning that it may contain +/// fields which are not in the projected schema (as the fields that parquet +/// pushdown filters operate can be completely distinct from the fields that are +/// projected (output) out of the ParquetExec). `map_partial_batch` thus uses +/// `table_schema` to create the resulting RecordBatch (as it could be operating +/// on any fields in the schema). /// /// [`map_batch`]: Self::map_batch /// [`map_partial_batch`]: Self::map_partial_batch #[derive(Debug)] pub struct SchemaMapping { - /// The schema of the table. This is the expected schema after conversion and it should match - /// the schema of the query result. + /// The schema of the table. This is the expected schema after conversion + /// and it should match the schema of the query result. projected_table_schema: SchemaRef, - /// Mapping from field index in `projected_table_schema` to index in projected file_schema. - /// They are Options instead of just plain `usize`s because the table could have fields that - /// don't exist in the file. + /// Mapping from field index in `projected_table_schema` to index in + /// projected file_schema. + /// + /// They are Options instead of just plain `usize`s because the table could + /// have fields that don't exist in the file. field_mappings: Vec>, - /// The entire table schema, as opposed to the projected_table_schema (which only contains the - /// columns that we are projecting out of this query). This contains all fields in the table, - /// regardless of if they will be projected out or not. + /// The entire table schema, as opposed to the projected_table_schema (which + /// only contains the columns that we are projecting out of this query). + /// This contains all fields in the table, regardless of if they will be + /// projected out or not. table_schema: SchemaRef, } @@ -331,8 +448,9 @@ mod tests { use crate::datasource::listing::PartitionedFile; use crate::datasource::schema_adapter::{ - SchemaAdapter, SchemaAdapterFactory, SchemaMapper, + DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory, SchemaMapper, }; + use datafusion_common::record_batch; #[cfg(feature = "parquet")] use parquet::arrow::ArrowWriter; use tempfile::TempDir; @@ -405,6 +523,58 @@ mod tests { assert_batches_sorted_eq!(expected, &read); } + #[test] + fn default_schema_adapter() { + let table_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ]); + + // file has a subset of the table schema fields and different type + let file_schema = Schema::new(vec![ + Field::new("c", DataType::Float64, true), // not in table schema + Field::new("b", DataType::Float64, true), + ]); + + let adapter = DefaultSchemaAdapterFactory::from_schema(Arc::new(table_schema)); + let (mapper, indices) = adapter.map_schema(&file_schema).unwrap(); + assert_eq!(indices, vec![1]); + + let file_batch = record_batch!(("b", Float64, vec![1.0, 2.0])).unwrap(); + + let mapped_batch = mapper.map_batch(file_batch).unwrap(); + + // the mapped batch has the correct schema and the "b" column has been cast to Utf8 + let expected_batch = record_batch!( + ("a", Int32, vec![None, None]), // missing column filled with nulls + ("b", Utf8, vec!["1.0", "2.0"]) // b was cast to string and order was changed + ) + .unwrap(); + assert_eq!(mapped_batch, expected_batch); + } + + #[test] + fn default_schema_adapter_non_nullable_columns() { + let table_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), // "a"" is declared non nullable + Field::new("b", DataType::Utf8, true), + ]); + let file_schema = Schema::new(vec![ + // since file doesn't have "a" it will be filled with nulls + Field::new("b", DataType::Float64, true), + ]); + + let adapter = DefaultSchemaAdapterFactory::from_schema(Arc::new(table_schema)); + let (mapper, indices) = adapter.map_schema(&file_schema).unwrap(); + assert_eq!(indices, vec![0]); + + let file_batch = record_batch!(("b", Float64, vec![1.0, 2.0])).unwrap(); + + // Mapping fails because it tries to fill in a non-nullable column with nulls + let err = mapper.map_batch(file_batch).unwrap_err().to_string(); + assert!(err.contains("Invalid argument error: Column 'a' is declared as non-nullable but contains null values"), "{err}"); + } + #[derive(Debug)] struct TestSchemaAdapterFactory; diff --git a/datafusion/core/tests/macro_hygiene/mod.rs b/datafusion/core/tests/macro_hygiene/mod.rs index 72ac6e64fb0c..c35e46c0c558 100644 --- a/datafusion/core/tests/macro_hygiene/mod.rs +++ b/datafusion/core/tests/macro_hygiene/mod.rs @@ -37,3 +37,13 @@ mod plan_datafusion_err { plan_datafusion_err!("foo"); } } + +mod record_batch { + // NO other imports! + use datafusion_common::record_batch; + + #[test] + fn test_macro() { + record_batch!(("column_name", Int32, vec![1, 2, 3])).unwrap(); + } +} From 146f16a0c14f0fe65e2bd8b7226508f27ced3f13 Mon Sep 17 00:00:00 2001 From: Oleks V Date: Sat, 26 Oct 2024 09:16:01 -0700 Subject: [PATCH 088/110] Move filtered SMJ Left Anti filtered join out of `join_partial` phase (#13111) * Move filtered SMJ Left Anti filtered join out of `join_partial` phase --- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 6 +- .../src/joins/sort_merge_join.rs | 245 ++++++++++- .../test_files/sort_merge_join.slt | 383 +++++++++--------- 3 files changed, 414 insertions(+), 220 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index ca2c2bf4e438..44d34b674bbb 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -41,6 +41,7 @@ use datafusion::physical_plan::joins::{ }; use datafusion::physical_plan::memory::MemoryExec; +use crate::fuzz_cases::join_fuzz::JoinTestType::NljHj; use datafusion::prelude::{SessionConfig, SessionContext}; use test_utils::stagger_batch_with_seed; @@ -223,9 +224,6 @@ async fn test_anti_join_1k() { } #[tokio::test] -// flaky for HjSmj case, giving 1 rows difference sometimes -// https://github.com/apache/datafusion/issues/11555 -#[ignore] async fn test_anti_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -233,7 +231,7 @@ async fn test_anti_join_1k_filtered() { JoinType::LeftAnti, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[JoinTestType::NljHj], false) + .run_test(&[JoinTestType::HjSmj, NljHj], false) .await } diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index d5134855440a..7b7b7462f7e4 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -802,6 +802,32 @@ fn get_corrected_filter_mask( Some(corrected_mask.finish()) } + JoinType::LeftAnti => { + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + + if filter_mask.value(i) { + seen_true = true; + } + + if last_index { + if !seen_true { + corrected_mask.append_value(true); + } else { + corrected_mask.append_null(); + } + + seen_true = false; + } else { + corrected_mask.append_null(); + } + } + + let null_matched = expected_size - corrected_mask.len(); + corrected_mask.extend(vec![Some(true); null_matched]); + Some(corrected_mask.finish()) + } // Only outer joins needs to keep track of processed rows and apply corrected filter mask _ => None, } @@ -835,15 +861,18 @@ impl Stream for SMJStream { JoinType::Left | JoinType::LeftSemi | JoinType::Right + | JoinType::LeftAnti ) { self.freeze_all()?; if !self.output_record_batches.batches.is_empty() - && self.buffered_data.scanning_finished() { - let out_batch = self.filter_joined_batch()?; - return Poll::Ready(Some(Ok(out_batch))); + let out_filtered_batch = + self.filter_joined_batch()?; + return Poll::Ready(Some(Ok( + out_filtered_batch, + ))); } } @@ -907,15 +936,17 @@ impl Stream for SMJStream { // because target output batch size can be hit in the middle of // filtering causing the filtering to be incomplete and causing // correctness issues - let record_batch = if !(self.filter.is_some() + if self.filter.is_some() && matches!( self.join_type, - JoinType::Left | JoinType::LeftSemi | JoinType::Right - )) { - record_batch - } else { + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::LeftAnti + ) + { continue; - }; + } return Poll::Ready(Some(Ok(record_batch))); } @@ -929,7 +960,10 @@ impl Stream for SMJStream { if self.filter.is_some() && matches!( self.join_type, - JoinType::Left | JoinType::LeftSemi | JoinType::Right + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::LeftAnti ) { let out = self.filter_joined_batch()?; @@ -1273,11 +1307,7 @@ impl SMJStream { }; if matches!(self.join_type, JoinType::LeftAnti) && self.filter.is_some() { - join_streamed = !self - .streamed_batch - .join_filter_matched_idxs - .contains(&(self.streamed_batch.idx as u64)) - && !self.streamed_joined; + join_streamed = !self.streamed_joined; join_buffered = join_streamed; } } @@ -1519,7 +1549,10 @@ impl SMJStream { // Push the filtered batch which contains rows passing join filter to the output if matches!( self.join_type, - JoinType::Left | JoinType::LeftSemi | JoinType::Right + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::LeftAnti ) { self.output_record_batches .batches @@ -1654,7 +1687,10 @@ impl SMJStream { if !(self.filter.is_some() && matches!( self.join_type, - JoinType::Left | JoinType::LeftSemi | JoinType::Right + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::LeftAnti )) { self.output_record_batches.batches.clear(); @@ -1727,7 +1763,7 @@ impl SMJStream { &self.schema, &[filtered_record_batch, null_joined_streamed_batch], )?; - } else if matches!(self.join_type, JoinType::LeftSemi) { + } else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { let output_column_indices = (0..streamed_columns_length).collect::>(); filtered_record_batch = filtered_record_batch.project(&output_column_indices)?; @@ -3349,6 +3385,7 @@ mod tests { batch_ids: vec![], }; + // Insert already prejoined non-filtered rows batches.batches.push(RecordBatch::try_new( Arc::clone(&schema), vec![ @@ -3835,6 +3872,178 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_left_anti_join_filtered_mask() -> Result<()> { + let mut joined_batches = build_joined_record_batches()?; + let schema = joined_batches.batches.first().unwrap().schema(); + + let output = concat_batches(&schema, &joined_batches.batches)?; + let out_mask = joined_batches.filter_mask.finish(); + let out_indices = joined_batches.row_indices.finish(); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![true]), + 1 + ) + .unwrap(), + BooleanArray::from(vec![None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![false]), + 1 + ) + .unwrap(), + BooleanArray::from(vec![Some(true)]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0, 0]), + &[0usize; 2], + &BooleanArray::from(vec![true, true]), + 2 + ) + .unwrap(), + BooleanArray::from(vec![None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, true, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, false, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, true, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, false]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, Some(true)]) + ); + + let corrected_mask = get_corrected_filter_mask( + LeftAnti, + &out_indices, + &joined_batches.batch_ids, + &out_mask, + output.num_rows(), + ) + .unwrap(); + + assert_eq!( + corrected_mask, + BooleanArray::from(vec![ + None, + None, + None, + None, + None, + Some(true), + None, + Some(true) + ]) + ); + + let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + + assert_batches_eq!( + &[ + "+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 13 | 1 | 12 |", + "| 1 | 14 | 1 | 11 |", + "+---+----+---+----+", + ], + &[filtered_rb] + ); + + // output null rows + let null_mask = arrow::compute::not(&corrected_mask)?; + assert_eq!( + null_mask, + BooleanArray::from(vec![ + None, + None, + None, + None, + None, + Some(false), + None, + Some(false), + ]) + ); + + let null_joined_batch = filter_record_batch(&output, &null_mask)?; + + assert_batches_eq!( + &[ + "+---+---+---+---+", + "| a | b | x | y |", + "+---+---+---+---+", + "+---+---+---+---+", + ], + &[null_joined_batch] + ); + Ok(()) + } + /// Returns the column names on the schema fn columns(schema: &Schema) -> Vec { schema.fields().iter().map(|f| f.name().clone()).collect() diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index 051cc6dce3d4..f4cc888d6b8e 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -407,214 +407,201 @@ select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != statement ok set datafusion.execution.batch_size = 10; -# Uncomment when filtered LEFTANTI moved -#query II -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b), -#t2 as ( -# select 11 a, 13 c union all -# select 11 a, 14 c -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -#) order by 1, 2 -#---- -#11 12 - -# Uncomment when filtered LEFTANTI moved -#query III -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b, 1 c union all -# select 11 a, 13 b, 2 c), -#t2 as ( -# select 11 a, 12 b, 3 c union all -# select 11 a, 14 b, 4 c -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) -#) order by 1, 2; -#---- -#11 12 1 -#11 13 2 - -# Uncomment when filtered LEFTANTI moved -#query III -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b, 1 c union all -# select 11 a, 13 b, 2 c), -#t2 as ( -# select 11 a, 12 b, 3 c where false -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) -#) order by 1, 2; -#---- -#11 12 1 -#11 13 2 - -# Uncomment when filtered LEFTANTI moved -#query II -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b), -#t2 as ( -# select 11 a, 13 c union all -# select 11 a, 14 c union all -# select 11 a, 15 c -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -#) order by 1, 2 -#---- -#11 12 - -# Uncomment when filtered LEFTANTI moved -#query II -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b), -#t2 as ( -# select 11 a, 11 c union all -# select 11 a, 14 c union all -# select 11 a, 15 c -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -#) order by 1, 2 -#---- +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 13 c union all + select 11 a, 14 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- +11 12 -# Uncomment when filtered LEFTANTI moved -#query II -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b), -#t2 as ( -# select 11 a, 12 c union all -# select 11 a, 11 c union all -# select 11 a, 15 c -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -#) order by 1, 2 -#---- +query III +select * from ( +with +t1 as ( + select 11 a, 12 b, 1 c union all + select 11 a, 13 b, 2 c), +t2 as ( + select 11 a, 12 b, 3 c union all + select 11 a, 14 b, 4 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +) order by 1, 2; +---- +11 12 1 +11 13 2 +query III +select * from ( +with +t1 as ( + select 11 a, 12 b, 1 c union all + select 11 a, 13 b, 2 c), +t2 as ( + select 11 a, 12 b, 3 c where false + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +) order by 1, 2; +---- +11 12 1 +11 13 2 -# Uncomment when filtered LEFTANTI moved -#query II -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b), -#t2 as ( -# select 11 a, 12 c union all -# select 11 a, 14 c union all -# select 11 a, 11 c -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -#) order by 1, 2 -#---- +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 13 c union all + select 11 a, 14 c union all + select 11 a, 15 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- +11 12 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 11 c union all + select 11 a, 14 c union all + select 11 a, 15 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 12 c union all + select 11 a, 11 c union all + select 11 a, 15 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- + + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 12 c union all + select 11 a, 14 c union all + select 11 a, 11 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- # Test LEFT ANTI with cross batch data distribution statement ok set datafusion.execution.batch_size = 1; -# Uncomment when filtered LEFTANTI moved -#query II -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b), -#t2 as ( -# select 11 a, 13 c union all -# select 11 a, 14 c -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -#) order by 1, 2 -#---- -#11 12 - -# Uncomment when filtered LEFTANTI moved -#query III -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b, 1 c union all -# select 11 a, 13 b, 2 c), -#t2 as ( -# select 11 a, 12 b, 3 c union all -# select 11 a, 14 b, 4 c -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) -#) order by 1, 2; -#---- -#11 12 1 -#11 13 2 - -# Uncomment when filtered LEFTANTI moved -#query III -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b, 1 c union all -# select 11 a, 13 b, 2 c), -#t2 as ( -# select 11 a, 12 b, 3 c where false -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) -#) order by 1, 2; -#---- -#11 12 1 -#11 13 2 - -# Uncomment when filtered LEFTANTI moved -#query II -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b), -#t2 as ( -# select 11 a, 13 c union all -# select 11 a, 14 c union all -# select 11 a, 15 c -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -#) order by 1, 2 -#---- -#11 12 - -# Uncomment when filtered LEFTANTI moved -#query II -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b), -#t2 as ( -# select 11 a, 12 c union all -# select 11 a, 11 c union all -# select 11 a, 15 c -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -#) order by 1, 2 -#---- +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 13 c union all + select 11 a, 14 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- +11 12 -# Uncomment when filtered LEFTANTI moved -#query II -#select * from ( -#with -#t1 as ( -# select 11 a, 12 b), -#t2 as ( -# select 11 a, 12 c union all -# select 11 a, 14 c union all -# select 11 a, 11 c -# ) -#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) -#) order by 1, 2 -#---- +query III +select * from ( +with +t1 as ( + select 11 a, 12 b, 1 c union all + select 11 a, 13 b, 2 c), +t2 as ( + select 11 a, 12 b, 3 c union all + select 11 a, 14 b, 4 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +) order by 1, 2; +---- +11 12 1 +11 13 2 + +query III +select * from ( +with +t1 as ( + select 11 a, 12 b, 1 c union all + select 11 a, 13 b, 2 c), +t2 as ( + select 11 a, 12 b, 3 c where false + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +) order by 1, 2; +---- +11 12 1 +11 13 2 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 13 c union all + select 11 a, 14 c union all + select 11 a, 15 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- +11 12 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 12 c union all + select 11 a, 11 c union all + select 11 a, 15 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 12 c union all + select 11 a, 14 c union all + select 11 a, 11 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- query IIII select * from ( From 5db274004bc4a7d493aba6764a8521694a67cd11 Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Sun, 27 Oct 2024 08:33:50 -0700 Subject: [PATCH 089/110] Improve TableScan with filters pushdown unparsing (multiple filters support) (#13131) --- datafusion/sql/src/unparser/ast.rs | 23 ++++++++++++++++++++++- datafusion/sql/tests/cases/plan_to_sql.rs | 15 +++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/datafusion/sql/src/unparser/ast.rs b/datafusion/sql/src/unparser/ast.rs index 71ff712985cd..2de1ce9125a7 100644 --- a/datafusion/sql/src/unparser/ast.rs +++ b/datafusion/sql/src/unparser/ast.rs @@ -182,7 +182,28 @@ impl SelectBuilder { self } pub fn selection(&mut self, value: Option) -> &mut Self { - self.selection = value; + // With filter pushdown optimization, the LogicalPlan can have filters defined as part of `TableScan` and `Filter` nodes. + // To avoid overwriting one of the filters, we combine the existing filter with the additional filter. + // Example: | + // | Projection: customer.c_phone AS cntrycode, customer.c_acctbal | + // | Filter: CAST(customer.c_acctbal AS Decimal128(38, 6)) > () | + // | Subquery: + // | .. | + // | TableScan: customer, full_filters=[customer.c_mktsegment = Utf8("BUILDING")] + match (&self.selection, value) { + (Some(existing_selection), Some(new_selection)) => { + self.selection = Some(ast::Expr::BinaryOp { + left: Box::new(existing_selection.clone()), + op: ast::BinaryOperator::And, + right: Box::new(new_selection), + }); + } + (None, Some(new_selection)) => { + self.selection = Some(new_selection); + } + (_, None) => (), + } + self } pub fn group_by(&mut self, value: ast::GroupByExpr) -> &mut Self { diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 8e25c1c5b1cd..a58bdf4a31c4 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -968,6 +968,21 @@ fn test_table_scan_pushdown() -> Result<()> { table_scan_with_all.to_string(), "SELECT t1.id, t1.age FROM t1 WHERE (t1.id > t1.age) LIMIT 10" ); + + let table_scan_with_additional_filter = table_scan_with_filters( + Some("t1"), + &schema, + None, + vec![col("id").gt(col("age"))], + )? + .filter(col("id").eq(lit(5)))? + .build()?; + let table_scan_with_filter = plan_to_sql(&table_scan_with_additional_filter)?; + assert_eq!( + table_scan_with_filter.to_string(), + "SELECT * FROM t1 WHERE (t1.id = 5) AND (t1.id > t1.age)" + ); + Ok(()) } From e22d23113f549a90483bbf161add84eba08510fd Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Sun, 27 Oct 2024 22:03:39 -0400 Subject: [PATCH 090/110] Raise a plan error on union if column count is not the same between plans. (#13117) --- datafusion/expr/src/logical_plan/builder.rs | 9 +++++++++ datafusion/sqllogictest/test_files/type_coercion.slt | 4 ++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index aef531a9dbf7..1f671626873f 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1482,6 +1482,15 @@ pub fn validate_unique_names<'a>( /// [`TypeCoercionRewriter::coerce_union`]: https://docs.rs/datafusion-optimizer/latest/datafusion_optimizer/analyzer/type_coercion/struct.TypeCoercionRewriter.html#method.coerce_union /// [`coerce_union_schema`]: https://docs.rs/datafusion-optimizer/latest/datafusion_optimizer/analyzer/type_coercion/fn.coerce_union_schema.html pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result { + if left_plan.schema().fields().len() != right_plan.schema().fields().len() { + return plan_err!( + "UNION queries have different number of columns: \ + left has {} columns whereas right has {} columns", + left_plan.schema().fields().len(), + right_plan.schema().fields().len() + ); + } + // Temporarily use the schema from the left input and later rely on the analyzer to // coerce the two schemas into a common one. diff --git a/datafusion/sqllogictest/test_files/type_coercion.slt b/datafusion/sqllogictest/test_files/type_coercion.slt index 0f9399cede2e..43e7c2f7bc25 100644 --- a/datafusion/sqllogictest/test_files/type_coercion.slt +++ b/datafusion/sqllogictest/test_files/type_coercion.slt @@ -103,11 +103,11 @@ CREATE TABLE orders( ); # union_different_num_columns_error() / UNION -query error Error during planning: Union schemas have different number of fields: query 1 has 1 fields whereas query 2 has 2 fields +query error DataFusion error: Error during planning: UNION queries have different number of columns: left has 1 columns whereas right has 2 columns SELECT order_id FROM orders UNION SELECT customer_id, o_item_id FROM orders # union_different_num_columns_error() / UNION ALL -query error Error during planning: Union schemas have different number of fields: query 1 has 1 fields whereas query 2 has 2 fields +query error DataFusion error: Error during planning: UNION queries have different number of columns: left has 1 columns whereas right has 2 columns SELECT order_id FROM orders UNION ALL SELECT customer_id, o_item_id FROM orders # union_with_different_column_names() From a0588cc806e310b01151f91a028ab79929194647 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Mon, 28 Oct 2024 07:35:19 -0400 Subject: [PATCH 091/110] [docs]: added `alternative_syntax` function for docs (#13140) * Add alternative syntax function. * fmt check --- datafusion/core/src/bin/print_functions_docs.rs | 7 +++++++ datafusion/expr/src/udf_docs.rs | 13 +++++++++++++ datafusion/functions/src/unicode/strpos.rs | 1 + docs/source/user-guide/sql/scalar_functions_new.md | 6 ++++++ 4 files changed, 27 insertions(+) diff --git a/datafusion/core/src/bin/print_functions_docs.rs b/datafusion/core/src/bin/print_functions_docs.rs index 598574c0703d..3aedcbc2aa63 100644 --- a/datafusion/core/src/bin/print_functions_docs.rs +++ b/datafusion/core/src/bin/print_functions_docs.rs @@ -195,6 +195,13 @@ fn print_docs( ); } + if let Some(alt_syntax) = &documentation.alternative_syntax { + let _ = writeln!(docs, "#### Alternative Syntax\n"); + for syntax in alt_syntax { + let _ = writeln!(docs, "```sql\n{}\n```", syntax); + } + } + // next, aliases if !f.get_aliases().is_empty() { let _ = writeln!(docs, "#### Aliases"); diff --git a/datafusion/expr/src/udf_docs.rs b/datafusion/expr/src/udf_docs.rs index 63d1a964345d..a124361e42a3 100644 --- a/datafusion/expr/src/udf_docs.rs +++ b/datafusion/expr/src/udf_docs.rs @@ -47,6 +47,8 @@ pub struct Documentation { /// Left member of a pair is the argument name, right is a /// description for the argument pub arguments: Option>, + /// A list of alternative syntax examples for a function + pub alternative_syntax: Option>, /// Related functions if any. Values should match the related /// udf's name exactly. Related udf's must be of the same /// UDF type (scalar, aggregate or window) for proper linking to @@ -96,6 +98,7 @@ pub struct DocumentationBuilder { pub syntax_example: Option, pub sql_example: Option, pub arguments: Option>, + pub alternative_syntax: Option>, pub related_udfs: Option>, } @@ -107,6 +110,7 @@ impl DocumentationBuilder { syntax_example: None, sql_example: None, arguments: None, + alternative_syntax: None, related_udfs: None, } } @@ -172,6 +176,13 @@ impl DocumentationBuilder { self.with_argument(arg_name, description) } + pub fn with_alternative_syntax(mut self, syntax_name: impl Into) -> Self { + let mut alternative_syntax_array = self.alternative_syntax.unwrap_or_default(); + alternative_syntax_array.push(syntax_name.into()); + self.alternative_syntax = Some(alternative_syntax_array); + self + } + pub fn with_related_udf(mut self, related_udf: impl Into) -> Self { let mut related = self.related_udfs.unwrap_or_default(); related.push(related_udf.into()); @@ -186,6 +197,7 @@ impl DocumentationBuilder { syntax_example, sql_example, arguments, + alternative_syntax, related_udfs, } = self; @@ -205,6 +217,7 @@ impl DocumentationBuilder { syntax_example: syntax_example.unwrap(), sql_example, arguments, + alternative_syntax, related_udfs, }) } diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index 152623b0e5dc..9c84590f7f94 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -97,6 +97,7 @@ fn get_strpos_doc() -> &'static Documentation { ```"#) .with_standard_argument("str", Some("String")) .with_argument("substr", "Substring expression to search for.") + .with_alternative_syntax("position(substr in origstr)") .build() .unwrap() }) diff --git a/docs/source/user-guide/sql/scalar_functions_new.md b/docs/source/user-guide/sql/scalar_functions_new.md index c15821ac89a3..6031a68d40e4 100644 --- a/docs/source/user-guide/sql/scalar_functions_new.md +++ b/docs/source/user-guide/sql/scalar_functions_new.md @@ -1465,6 +1465,12 @@ strpos(str, substr) +----------------------------------------+ ``` +#### Alternative Syntax + +```sql +position(substr in origstr) +``` + #### Aliases - instr From 132b232b8861888348f05cc99d11bf1d2d4f1c63 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Mon, 28 Oct 2024 07:36:18 -0400 Subject: [PATCH 092/110] Minor: Delete old cume_dist and percent_rank docs (#13137) * Delete cume_dist and percent_rank docs * fix docs --- docs/source/user-guide/sql/window_functions.md | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/docs/source/user-guide/sql/window_functions.md b/docs/source/user-guide/sql/window_functions.md index 6bf2005dabf9..8216a3b258b8 100644 --- a/docs/source/user-guide/sql/window_functions.md +++ b/docs/source/user-guide/sql/window_functions.md @@ -148,28 +148,10 @@ All [aggregate functions](aggregate_functions.md) can be used as window function ## Analytical functions -- [cume_dist](#cume_dist) -- [percent_rank](#percent_rank) - [first_value](#first_value) - [last_value](#last_value) - [nth_value](#nth_value) -### `cume_dist` - -Relative rank of the current row: (number of rows preceding or peer with current row) / (total rows). - -```sql -cume_dist() -``` - -### `percent_rank` - -Relative rank of the current row: (rank - 1) / (total rows - 1). - -```sql -percent_rank() -``` - ### `first_value` Returns value evaluated at the row that is the first row of the window frame. From 1fd6116dd9e1898540b4fbdbba735c4ebacc4227 Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Mon, 28 Oct 2024 07:02:03 -0700 Subject: [PATCH 093/110] Add basic support for `unnest` unparsing (#13129) * Add basic support for `unnest` unparsing (#45) * Fix taplo cargo check --- datafusion/sql/Cargo.toml | 1 + datafusion/sql/src/unparser/expr.rs | 35 ++++++++- datafusion/sql/src/unparser/plan.rs | 53 +++++++++++--- datafusion/sql/src/unparser/utils.rs | 87 ++++++++++++++++++----- datafusion/sql/tests/cases/plan_to_sql.rs | 19 ++++- 5 files changed, 163 insertions(+), 32 deletions(-) diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index 90be576a884e..1eef1b718ba6 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -56,6 +56,7 @@ strum = { version = "0.26.1", features = ["derive"] } ctor = { workspace = true } datafusion-functions = { workspace = true, default-features = true } datafusion-functions-aggregate = { workspace = true } +datafusion-functions-nested = { workspace = true } datafusion-functions-window = { workspace = true } env_logger = { workspace = true } paste = "^1.0" diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 8864c97bb1ff..1d0327fadbe4 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use datafusion_expr::expr::Unnest; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ self, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, ObjectName, @@ -466,7 +467,7 @@ impl Unparser<'_> { Ok(ast::Expr::Value(ast::Value::Placeholder(p.id.to_string()))) } Expr::OuterReferenceColumn(_, col) => self.col_to_sql(col), - Expr::Unnest(_) => not_impl_err!("Unsupported Expr conversion: {expr:?}"), + Expr::Unnest(unnest) => self.unnest_to_sql(unnest), } } @@ -1340,6 +1341,29 @@ impl Unparser<'_> { } } + /// Converts an UNNEST operation to an AST expression by wrapping it as a function call, + /// since there is no direct representation for UNNEST in the AST. + fn unnest_to_sql(&self, unnest: &Unnest) -> Result { + let args = self.function_args_to_sql(std::slice::from_ref(&unnest.expr))?; + + Ok(ast::Expr::Function(Function { + name: ast::ObjectName(vec![Ident { + value: "UNNEST".to_string(), + quote_style: None, + }]), + args: ast::FunctionArguments::List(ast::FunctionArgumentList { + duplicate_treatment: None, + args, + clauses: vec![], + }), + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + parameters: ast::FunctionArguments::None, + })) + } + fn arrow_dtype_to_ast_dtype(&self, data_type: &DataType) -> Result { match data_type { DataType::Null => { @@ -1855,6 +1879,15 @@ mod tests { }), r#"CAST(a AS DECIMAL(12,0))"#, ), + ( + Expr::Unnest(Unnest { + expr: Box::new(Expr::Column(Column { + relation: Some(TableReference::partial("schema", "table")), + name: "array_col".to_string(), + })), + }), + r#"UNNEST("schema"."table".array_col)"#, + ), ]; for (expr, expected) in tests { diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 695027374fa0..7c9054656b94 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -26,8 +26,9 @@ use super::{ subquery_alias_inner_query_and_columns, TableAliasRewriter, }, utils::{ - find_agg_node_within_select, find_window_nodes_within_select, - unproject_sort_expr, unproject_window_exprs, + find_agg_node_within_select, find_unnest_node_within_select, + find_window_nodes_within_select, unproject_sort_expr, unproject_unnest_expr, + unproject_window_exprs, }, Unparser, }; @@ -173,15 +174,24 @@ impl Unparser<'_> { p: &Projection, select: &mut SelectBuilder, ) -> Result<()> { + let mut exprs = p.expr.clone(); + + // If an Unnest node is found within the select, find and unproject the unnest column + if let Some(unnest) = find_unnest_node_within_select(plan) { + exprs = exprs + .into_iter() + .map(|e| unproject_unnest_expr(e, unnest)) + .collect::>>()?; + }; + match ( find_agg_node_within_select(plan, true), find_window_nodes_within_select(plan, None, true), ) { (Some(agg), window) => { let window_option = window.as_deref(); - let items = p - .expr - .iter() + let items = exprs + .into_iter() .map(|proj_expr| { let unproj = unproject_agg_exprs(proj_expr, agg, window_option)?; self.select_item_to_sql(&unproj) @@ -198,9 +208,8 @@ impl Unparser<'_> { )); } (None, Some(window)) => { - let items = p - .expr - .iter() + let items = exprs + .into_iter() .map(|proj_expr| { let unproj = unproject_window_exprs(proj_expr, &window)?; self.select_item_to_sql(&unproj) @@ -210,8 +219,7 @@ impl Unparser<'_> { select.projection(items); } _ => { - let items = p - .expr + let items = exprs .iter() .map(|e| self.select_item_to_sql(e)) .collect::>>()?; @@ -318,7 +326,8 @@ impl Unparser<'_> { if let Some(agg) = find_agg_node_within_select(plan, select.already_projected()) { - let unprojected = unproject_agg_exprs(&filter.predicate, agg, None)?; + let unprojected = + unproject_agg_exprs(filter.predicate.clone(), agg, None)?; let filter_expr = self.expr_to_sql(&unprojected)?; select.having(Some(filter_expr)); } else { @@ -596,6 +605,28 @@ impl Unparser<'_> { Ok(()) } LogicalPlan::Extension(_) => not_impl_err!("Unsupported operator: {plan:?}"), + LogicalPlan::Unnest(unnest) => { + if !unnest.struct_type_columns.is_empty() { + return internal_err!( + "Struct type columns are not currently supported in UNNEST: {:?}", + unnest.struct_type_columns + ); + } + + // In the case of UNNEST, the Unnest node is followed by a duplicate Projection node that we should skip. + // Otherwise, there will be a duplicate SELECT clause. + // | Projection: table.col1, UNNEST(table.col2) + // | Unnest: UNNEST(table.col2) + // | Projection: table.col1, table.col2 AS UNNEST(table.col2) + // | Filter: table.col3 = Int64(3) + // | TableScan: table projection=None + if let LogicalPlan::Projection(p) = unnest.input.as_ref() { + // continue with projection input + self.select_to_sql_recursively(&p.input, query, select, relation) + } else { + internal_err!("Unnest input is not a Projection: {unnest:?}") + } + } _ => not_impl_err!("Unsupported operator: {plan:?}"), } } diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 5e3a3aa600b6..d3d1bf351384 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -23,8 +23,8 @@ use datafusion_common::{ Column, Result, ScalarValue, }; use datafusion_expr::{ - utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Projection, SortExpr, - Window, + expr, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Projection, + SortExpr, Unnest, Window, }; use sqlparser::ast; @@ -62,6 +62,28 @@ pub(crate) fn find_agg_node_within_select( } } +/// Recursively searches children of [LogicalPlan] to find Unnest node if exist +pub(crate) fn find_unnest_node_within_select(plan: &LogicalPlan) -> Option<&Unnest> { + // Note that none of the nodes that have a corresponding node can have more + // than 1 input node. E.g. Projection / Filter always have 1 input node. + let input = plan.inputs(); + let input = if input.len() > 1 { + return None; + } else { + input.first()? + }; + + if let LogicalPlan::Unnest(unnest) = input { + Some(unnest) + } else if let LogicalPlan::TableScan(_) = input { + None + } else if let LogicalPlan::Projection(_) = input { + None + } else { + find_unnest_node_within_select(input) + } +} + /// Recursively searches children of [LogicalPlan] to find Window nodes if exist /// prior to encountering a Join, TableScan, or a nested subquery (derived table factor). /// If Window node is not found prior to this or at all before reaching the end @@ -104,18 +126,46 @@ pub(crate) fn find_window_nodes_within_select<'a>( } } +/// Recursively identify Column expressions and transform them into the appropriate unnest expression +/// +/// For example, if expr contains the column expr "unnest_placeholder(make_array(Int64(1),Int64(2),Int64(2),Int64(5),NULL),depth=1)" +/// it will be transformed into an actual unnest expression UNNEST([1, 2, 2, 5, NULL]) +pub(crate) fn unproject_unnest_expr(expr: Expr, unnest: &Unnest) -> Result { + expr.transform(|sub_expr| { + if let Expr::Column(col_ref) = &sub_expr { + // Check if the column is among the columns to run unnest on. + // Currently, only List/Array columns (defined in `list_type_columns`) are supported for unnesting. + if unnest.list_type_columns.iter().any(|e| e.1.output_column.name == col_ref.name) { + if let Ok(idx) = unnest.schema.index_of_column(col_ref) { + if let LogicalPlan::Projection(Projection { expr, .. }) = unnest.input.as_ref() { + if let Some(unprojected_expr) = expr.get(idx) { + let unnest_expr = Expr::Unnest(expr::Unnest::new(unprojected_expr.clone())); + return Ok(Transformed::yes(unnest_expr)); + } + } + } + return internal_err!( + "Tried to unproject unnest expr for column '{}' that was not found in the provided Unnest!", &col_ref.name + ); + } + } + + Ok(Transformed::no(sub_expr)) + + }).map(|e| e.data) +} + /// Recursively identify all Column expressions and transform them into the appropriate /// aggregate expression contained in agg. /// /// For example, if expr contains the column expr "COUNT(*)" it will be transformed /// into an actual aggregate expression COUNT(*) as identified in the aggregate node. pub(crate) fn unproject_agg_exprs( - expr: &Expr, + expr: Expr, agg: &Aggregate, windows: Option<&[&Window]>, ) -> Result { - expr.clone() - .transform(|sub_expr| { + expr.transform(|sub_expr| { if let Expr::Column(c) = sub_expr { if let Some(unprojected_expr) = find_agg_expr(agg, &c)? { Ok(Transformed::yes(unprojected_expr.clone())) @@ -123,7 +173,7 @@ pub(crate) fn unproject_agg_exprs( windows.and_then(|w| find_window_expr(w, &c.name).cloned()) { // Window function can contain an aggregation columns, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected - return Ok(Transformed::yes(unproject_agg_exprs(&unprojected_expr, agg, None)?)); + return Ok(Transformed::yes(unproject_agg_exprs(unprojected_expr, agg, None)?)); } else { internal_err!( "Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name @@ -141,20 +191,19 @@ pub(crate) fn unproject_agg_exprs( /// /// For example, if expr contains the column expr "COUNT(*) PARTITION BY id" it will be transformed /// into an actual window expression as identified in the window node. -pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result { - expr.clone() - .transform(|sub_expr| { - if let Expr::Column(c) = sub_expr { - if let Some(unproj) = find_window_expr(windows, &c.name) { - Ok(Transformed::yes(unproj.clone())) - } else { - Ok(Transformed::no(Expr::Column(c))) - } +pub(crate) fn unproject_window_exprs(expr: Expr, windows: &[&Window]) -> Result { + expr.transform(|sub_expr| { + if let Expr::Column(c) = sub_expr { + if let Some(unproj) = find_window_expr(windows, &c.name) { + Ok(Transformed::yes(unproj.clone())) } else { - Ok(Transformed::no(sub_expr)) + Ok(Transformed::no(Expr::Column(c))) } - }) - .map(|e| e.data) + } else { + Ok(Transformed::no(sub_expr)) + } + }) + .map(|e| e.data) } fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Result> { @@ -218,7 +267,7 @@ pub(crate) fn unproject_sort_expr( // In case of aggregation there could be columns containing aggregation functions we need to unproject if let Some(agg) = agg { if agg.schema.is_column_from_schema(col_ref) { - let new_expr = unproject_agg_exprs(&sort_expr.expr, agg, None)?; + let new_expr = unproject_agg_exprs(sort_expr.expr, agg, None)?; sort_expr.expr = new_expr; return Ok(sort_expr); } diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index a58bdf4a31c4..16941c5d9164 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -24,6 +24,7 @@ use datafusion_expr::test::function_stub::{count_udaf, max_udaf, min_udaf, sum_u use datafusion_expr::{col, lit, table_scan, wildcard, LogicalPlanBuilder}; use datafusion_functions::unicode; use datafusion_functions_aggregate::grouping::grouping_udaf; +use datafusion_functions_nested::make_array::make_array_udf; use datafusion_functions_window::rank::rank_udwf; use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_sql::unparser::dialect::{ @@ -711,7 +712,8 @@ where .with_aggregate_function(max_udaf()) .with_aggregate_function(grouping_udaf()) .with_window_function(rank_udwf()) - .with_scalar_function(Arc::new(unicode::substr().as_ref().clone())), + .with_scalar_function(Arc::new(unicode::substr().as_ref().clone())) + .with_scalar_function(make_array_udf()), }; let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); @@ -1084,3 +1086,18 @@ FROM person GROUP BY person.id, person.first_name"#.replace("\n", " ").as_str(), ); } + +#[test] +fn test_unnest_to_sql() { + sql_round_trip( + GenericDialect {}, + r#"SELECT unnest(array_col) as u1, struct_col, array_col FROM unnest_table WHERE array_col != NULL ORDER BY struct_col, array_col"#, + r#"SELECT UNNEST(unnest_table.array_col) AS u1, unnest_table.struct_col, unnest_table.array_col FROM unnest_table WHERE (unnest_table.array_col <> NULL) ORDER BY unnest_table.struct_col ASC NULLS LAST, unnest_table.array_col ASC NULLS LAST"#, + ); + + sql_round_trip( + GenericDialect {}, + r#"SELECT unnest(make_array(1, 2, 2, 5, NULL)) as u1"#, + r#"SELECT UNNEST(make_array(1, 2, 2, 5, NULL)) AS u1"#, + ); +} From 0b45b9a2dd84e30a68d8701b627293e6ac803643 Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Tue, 29 Oct 2024 02:20:05 -0700 Subject: [PATCH 094/110] Improve TableScan with filters pushdown unparsing (joins) (#13132) * Improve TableScan with filters pushdown unparsing (joins) * Fix formatting * Add test with filters before and after join --- datafusion/sql/src/unparser/plan.rs | 77 ++++++++++++++++--- datafusion/sql/src/unparser/utils.rs | 93 +++++++++++++++++++++-- datafusion/sql/tests/cases/plan_to_sql.rs | 87 +++++++++++++++++++++ 3 files changed, 240 insertions(+), 17 deletions(-) diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 7c9054656b94..2c38a1d36c1e 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -27,8 +27,8 @@ use super::{ }, utils::{ find_agg_node_within_select, find_unnest_node_within_select, - find_window_nodes_within_select, unproject_sort_expr, unproject_unnest_expr, - unproject_window_exprs, + find_window_nodes_within_select, try_transform_to_simple_table_scan_with_filters, + unproject_sort_expr, unproject_unnest_expr, unproject_window_exprs, }, Unparser, }; @@ -39,8 +39,8 @@ use datafusion_common::{ Column, DataFusionError, Result, TableReference, }; use datafusion_expr::{ - expr::Alias, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan, - LogicalPlanBuilder, Projection, SortExpr, TableScan, + expr::Alias, BinaryExpr, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan, + LogicalPlanBuilder, Operator, Projection, SortExpr, TableScan, }; use sqlparser::ast::{self, Ident, SetExpr}; use std::sync::Arc; @@ -468,22 +468,77 @@ impl Unparser<'_> { self.select_to_sql_recursively(input, query, select, relation) } LogicalPlan::Join(join) => { - let join_constraint = self.join_constraint_to_sql( - join.join_constraint, - &join.on, - join.filter.as_ref(), + let mut table_scan_filters = vec![]; + + let left_plan = + match try_transform_to_simple_table_scan_with_filters(&join.left)? { + Some((plan, filters)) => { + table_scan_filters.extend(filters); + Arc::new(plan) + } + None => Arc::clone(&join.left), + }; + + self.select_to_sql_recursively( + left_plan.as_ref(), + query, + select, + relation, )?; + let right_plan = + match try_transform_to_simple_table_scan_with_filters(&join.right)? { + Some((plan, filters)) => { + table_scan_filters.extend(filters); + Arc::new(plan) + } + None => Arc::clone(&join.right), + }; + let mut right_relation = RelationBuilder::default(); self.select_to_sql_recursively( - join.left.as_ref(), + right_plan.as_ref(), query, select, - relation, + &mut right_relation, )?; + + let join_filters = if table_scan_filters.is_empty() { + join.filter.clone() + } else { + // Combine `table_scan_filters` into a single filter using `AND` + let Some(combined_filters) = + table_scan_filters.into_iter().reduce(|acc, filter| { + Expr::BinaryExpr(BinaryExpr { + left: Box::new(acc), + op: Operator::And, + right: Box::new(filter), + }) + }) + else { + return internal_err!("Failed to combine TableScan filters"); + }; + + // Combine `join.filter` with `combined_filters` using `AND` + match &join.filter { + Some(filter) => Some(Expr::BinaryExpr(BinaryExpr { + left: Box::new(filter.clone()), + op: Operator::And, + right: Box::new(combined_filters), + })), + None => Some(combined_filters), + } + }; + + let join_constraint = self.join_constraint_to_sql( + join.join_constraint, + &join.on, + join_filters.as_ref(), + )?; + self.select_to_sql_recursively( - join.right.as_ref(), + right_plan.as_ref(), query, select, &mut right_relation, diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index d3d1bf351384..284956cef195 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -15,20 +15,20 @@ // specific language governing permissions and limitations // under the License. -use std::cmp::Ordering; +use std::{cmp::Ordering, sync::Arc, vec}; use datafusion_common::{ internal_err, - tree_node::{Transformed, TreeNode}, - Column, Result, ScalarValue, + tree_node::{Transformed, TransformedResult, TreeNode}, + Column, DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ - expr, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Projection, - SortExpr, Unnest, Window, + expr, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, + LogicalPlanBuilder, Projection, SortExpr, Unnest, Window, }; use sqlparser::ast; -use super::{dialect::DateFieldExtractStyle, Unparser}; +use super::{dialect::DateFieldExtractStyle, rewrite::TableAliasRewriter, Unparser}; /// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists /// prior to encountering a Join, TableScan, or a nested subquery (derived table factor). @@ -288,6 +288,87 @@ pub(crate) fn unproject_sort_expr( Ok(sort_expr) } +/// Iterates through the children of a [LogicalPlan] to find a TableScan node before encountering +/// a Projection or any unexpected node that indicates the presence of a Projection (SELECT) in the plan. +/// If a TableScan node is found, returns the TableScan node without filters, along with the collected filters separately. +/// If the plan contains a Projection, returns None. +/// +/// Note: If a table alias is present, TableScan filters are rewritten to reference the alias. +/// +/// LogicalPlan example: +/// Filter: ta.j1_id < 5 +/// Alias: ta +/// TableScan: j1, j1_id > 10 +/// +/// Will return LogicalPlan below: +/// Alias: ta +/// TableScan: j1 +/// And filters: [ta.j1_id < 5, ta.j1_id > 10] +pub(crate) fn try_transform_to_simple_table_scan_with_filters( + plan: &LogicalPlan, +) -> Result)>> { + let mut filters: Vec = vec![]; + let mut plan_stack = vec![plan]; + let mut table_alias = None; + + while let Some(current_plan) = plan_stack.pop() { + match current_plan { + LogicalPlan::SubqueryAlias(alias) => { + table_alias = Some(alias.alias.clone()); + plan_stack.push(alias.input.as_ref()); + } + LogicalPlan::Filter(filter) => { + filters.push(filter.predicate.clone()); + plan_stack.push(filter.input.as_ref()); + } + LogicalPlan::TableScan(table_scan) => { + let table_schema = table_scan.source.schema(); + // optional rewriter if table has an alias + let mut filter_alias_rewriter = + table_alias.as_ref().map(|alias_name| TableAliasRewriter { + table_schema: &table_schema, + alias_name: alias_name.clone(), + }); + + // rewrite filters to use table alias if present + let table_scan_filters = table_scan + .filters + .iter() + .cloned() + .map(|expr| { + if let Some(ref mut rewriter) = filter_alias_rewriter { + expr.rewrite(rewriter).data() + } else { + Ok(expr) + } + }) + .collect::, DataFusionError>>()?; + + filters.extend(table_scan_filters); + + let mut builder = LogicalPlanBuilder::scan( + table_scan.table_name.clone(), + Arc::clone(&table_scan.source), + None, + )?; + + if let Some(alias) = table_alias.take() { + builder = builder.alias(alias)?; + } + + let plan = builder.build()?; + + return Ok(Some((plan, filters))); + } + _ => { + return Ok(None); + } + } + } + + Ok(None) +} + /// Converts a date_part function to SQL, tailoring it to the supported date field extraction style. pub(crate) fn date_part_to_sql( unparser: &Unparser, diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 16941c5d9164..ea0ccb8e4b43 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -1008,6 +1008,93 @@ fn test_sort_with_push_down_fetch() -> Result<()> { Ok(()) } +#[test] +fn test_join_with_table_scan_filters() -> Result<()> { + let schema_left = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("name", DataType::Utf8, false), + ]); + + let schema_right = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("age", DataType::Utf8, false), + ]); + + let left_plan = table_scan_with_filters( + Some("left_table"), + &schema_left, + None, + vec![col("name").like(lit("some_name"))], + )? + .alias("left")? + .build()?; + + let right_plan = table_scan_with_filters( + Some("right_table"), + &schema_right, + None, + vec![col("age").gt(lit(10))], + )? + .build()?; + + let join_plan_with_filter = LogicalPlanBuilder::from(left_plan.clone()) + .join( + right_plan.clone(), + datafusion_expr::JoinType::Inner, + (vec!["left.id"], vec!["right_table.id"]), + Some(col("left.id").gt(lit(5))), + )? + .build()?; + + let sql = plan_to_sql(&join_plan_with_filter)?; + + let expected_sql = r#"SELECT * FROM left_table AS "left" JOIN right_table ON "left".id = right_table.id AND (("left".id > 5) AND ("left"."name" LIKE 'some_name' AND (age > 10)))"#; + + assert_eq!(sql.to_string(), expected_sql); + + let join_plan_no_filter = LogicalPlanBuilder::from(left_plan.clone()) + .join( + right_plan, + datafusion_expr::JoinType::Inner, + (vec!["left.id"], vec!["right_table.id"]), + None, + )? + .build()?; + + let sql = plan_to_sql(&join_plan_no_filter)?; + + let expected_sql = r#"SELECT * FROM left_table AS "left" JOIN right_table ON "left".id = right_table.id AND ("left"."name" LIKE 'some_name' AND (age > 10))"#; + + assert_eq!(sql.to_string(), expected_sql); + + let right_plan_with_filter = table_scan_with_filters( + Some("right_table"), + &schema_right, + None, + vec![col("age").gt(lit(10))], + )? + .filter(col("right_table.name").eq(lit("before_join_filter_val")))? + .build()?; + + let join_plan_multiple_filters = LogicalPlanBuilder::from(left_plan.clone()) + .join( + right_plan_with_filter, + datafusion_expr::JoinType::Inner, + (vec!["left.id"], vec!["right_table.id"]), + Some(col("left.id").gt(lit(5))), + )? + .filter(col("left.name").eq(lit("after_join_filter_val")))? + .build()?; + + let sql = plan_to_sql(&join_plan_multiple_filters)?; + + let expected_sql = r#"SELECT * FROM left_table AS "left" JOIN right_table ON "left".id = right_table.id AND (("left".id > 5) AND (("left"."name" LIKE 'some_name' AND (right_table."name" = 'before_join_filter_val')) AND (age > 10))) WHERE ("left"."name" = 'after_join_filter_val')"#; + + assert_eq!(sql.to_string(), expected_sql); + + Ok(()) +} + #[test] fn test_interval_lhs_eq() { sql_round_trip( From 467a80481ec0d537a0bcc1feee43590babe67b66 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Tue, 29 Oct 2024 12:21:50 +0100 Subject: [PATCH 095/110] Report offending plan node when In/Exist subquery misused (#13155) --- datafusion/optimizer/src/analyzer/subquery.rs | 4 +++- datafusion/sqllogictest/test_files/subquery.slt | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index e01ae625ed9c..7c0bddf1153f 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -113,7 +113,9 @@ pub fn check_subquery_expr( | LogicalPlan::Join(_) => Ok(()), _ => plan_err!( "In/Exist subquery can only be used in \ - Projection, Filter, Window functions, Aggregate and Join plan nodes" + Projection, Filter, Window functions, Aggregate and Join plan nodes, \ + but was used in [{}]", + outer_plan.display() ), }?; check_correlations_in_subquery(inner_plan) diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 26b5d8b952f6..36de19f1c3aa 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -438,7 +438,7 @@ SELECT t1_id, t1_name, t1_int, (select t2_id, t2_name FROM t2 WHERE t2.t2_id = t #subquery_not_allowed #In/Exist Subquery is not allowed in ORDER BY clause. -statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: In/Exist subquery can only be used in Projection, Filter, Window functions, Aggregate and Join plan nodes +statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: In/Exist subquery can only be used in Projection, Filter, Window functions, Aggregate and Join plan nodes, but was used in \[Sort: t1.t1_int IN \(\) ASC NULLS LAST\] SELECT t1_id, t1_name, t1_int FROM t1 order by t1_int in (SELECT t2_int FROM t2 WHERE t1.t1_id > t1.t1_int) #non_aggregated_correlated_scalar_subquery From 80ad713693133532b53c5a8e1fc4202084203986 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Tue, 29 Oct 2024 12:22:26 +0100 Subject: [PATCH 096/110] Remove unused assert_analyzed_plan_ne test helper (#13121) Plan textual representation is rich. Testing it's not a particular string is difficult to make robust, that's probably why the helper is unused. --- datafusion/optimizer/src/test/mod.rs | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index cabeafd8e7de..94d07a0791b3 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -133,20 +133,6 @@ pub fn assert_analyzed_plan_with_config_eq( Ok(()) } -pub fn assert_analyzed_plan_ne( - rule: Arc, - plan: LogicalPlan, - expected: &str, -) -> Result<()> { - let options = ConfigOptions::default(); - let analyzed_plan = - Analyzer::with_rules(vec![rule]).execute_and_check(plan, &options, |_, _| {})?; - let formatted_plan = format!("{analyzed_plan}"); - assert_ne!(formatted_plan, expected); - - Ok(()) -} - pub fn assert_analyzed_plan_eq_display_indent( rule: Arc, plan: LogicalPlan, From feeb32ab090fda25b9c0aca8b86cf10a5b91acaa Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Tue, 29 Oct 2024 07:22:57 -0400 Subject: [PATCH 097/110] Add alternative syntax for extract, trim and substring. (#13143) --- .../functions/src/datetime/date_part.rs | 1 + datafusion/functions/src/string/btrim.rs | 2 ++ datafusion/functions/src/string/ltrim.rs | 1 + datafusion/functions/src/string/rtrim.rs | 1 + datafusion/functions/src/unicode/substr.rs | 1 + .../user-guide/sql/scalar_functions_new.md | 34 +++++++++++++++++++ 6 files changed, 40 insertions(+) diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index 3fefa5051376..01e094bc4e0b 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -255,6 +255,7 @@ fn get_date_part_doc() -> &'static Documentation { "expression", "Time expression to operate on. Can be a constant, column, or function.", ) + .with_alternative_syntax("extract(field FROM source)") .build() .unwrap() }) diff --git a/datafusion/functions/src/string/btrim.rs b/datafusion/functions/src/string/btrim.rs index f689f27d9d24..e215b18d9c3c 100644 --- a/datafusion/functions/src/string/btrim.rs +++ b/datafusion/functions/src/string/btrim.rs @@ -124,6 +124,8 @@ fn get_btrim_doc() -> &'static Documentation { ```"#) .with_standard_argument("str", Some("String")) .with_argument("trim_str", "String expression to operate on. Can be a constant, column, or function, and any combination of operators. _Default is whitespace characters._") + .with_alternative_syntax("trim(BOTH trim_str FROM str)") + .with_alternative_syntax("trim(trim_str FROM str)") .with_related_udf("ltrim") .with_related_udf("rtrim") .build() diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs index 91809d691647..0b4c197646b6 100644 --- a/datafusion/functions/src/string/ltrim.rs +++ b/datafusion/functions/src/string/ltrim.rs @@ -124,6 +124,7 @@ fn get_ltrim_doc() -> &'static Documentation { ```"#) .with_standard_argument("str", Some("String")) .with_argument("trim_str", "String expression to trim from the beginning of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._") + .with_alternative_syntax("trim(LEADING trim_str FROM str)") .with_related_udf("btrim") .with_related_udf("rtrim") .build() diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs index 06c8a85c38dd..e934147efbbe 100644 --- a/datafusion/functions/src/string/rtrim.rs +++ b/datafusion/functions/src/string/rtrim.rs @@ -124,6 +124,7 @@ fn get_rtrim_doc() -> &'static Documentation { ```"#) .with_standard_argument("str", Some("String")) .with_argument("trim_str", "String expression to trim from the end of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._") + .with_alternative_syntax("trim(TRAILING trim_str FROM str)") .with_related_udf("btrim") .with_related_udf("ltrim") .build() diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 5a8c2500900b..edfe57210b71 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -173,6 +173,7 @@ fn get_substr_doc() -> &'static Documentation { .with_standard_argument("str", Some("String")) .with_argument("start_pos", "Character position to start the substring at. The first character in the string has a position of 1.") .with_argument("length", "Number of characters to extract. If not specified, returns the rest of the string after the start position.") + .with_alternative_syntax("substring(str from start_pos for length)") .build() .unwrap() }) diff --git a/docs/source/user-guide/sql/scalar_functions_new.md b/docs/source/user-guide/sql/scalar_functions_new.md index 6031a68d40e4..56173b97b405 100644 --- a/docs/source/user-guide/sql/scalar_functions_new.md +++ b/docs/source/user-guide/sql/scalar_functions_new.md @@ -808,6 +808,16 @@ btrim(str[, trim_str]) +-------------------------------------------+ ``` +#### Alternative Syntax + +```sql +trim(BOTH trim_str FROM str) +``` + +```sql +trim(trim_str FROM str) +``` + #### Aliases - trim @@ -1191,6 +1201,12 @@ ltrim(str[, trim_str]) +-------------------------------------------+ ``` +#### Alternative Syntax + +```sql +trim(LEADING trim_str FROM str) +``` + **Related functions**: - [btrim](#btrim) @@ -1387,6 +1403,12 @@ rtrim(str[, trim_str]) +-------------------------------------------+ ``` +#### Alternative Syntax + +```sql +trim(TRAILING trim_str FROM str) +``` + **Related functions**: - [btrim](#btrim) @@ -1501,6 +1523,12 @@ substr(str, start_pos[, length]) +----------------------------------------------+ ``` +#### Alternative Syntax + +```sql +substring(str from start_pos for length) +``` + #### Aliases - substring @@ -1965,6 +1993,12 @@ date_part(part, expression) - **expression**: Time expression to operate on. Can be a constant, column, or function. +#### Alternative Syntax + +```sql +extract(field FROM source) +``` + #### Aliases - datepart From 67b0f2573e3cedbfe0868d77210cbbacfb56e49b Mon Sep 17 00:00:00 2001 From: Dmitry Bugakov Date: Tue, 29 Oct 2024 12:23:18 +0100 Subject: [PATCH 098/110] Fix Utf8View as Join Key (#13115) --- datafusion/expr/src/utils.rs | 1 + datafusion/sqllogictest/test_files/joins.slt | 13 +++++++++++++ 2 files changed, 14 insertions(+) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index bb5496c0f799..9207ad00993c 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -981,6 +981,7 @@ pub fn can_hash(data_type: &DataType) -> bool { }, DataType::Utf8 => true, DataType::LargeUtf8 => true, + DataType::Utf8View => true, DataType::Decimal128(_, _) => true, DataType::Date32 => true, DataType::Date64 => true, diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index af272e8f5022..bc40f845cc8a 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -4277,3 +4277,16 @@ physical_plan 02)--HashJoinExec: mode=CollectLeft, join_type=Full, on=[(c1@0, c1@0)], filter=c2@0 >= c2@1 03)----MemoryExec: partitions=1, partition_sizes=[1] 04)----MemoryExec: partitions=1, partition_sizes=[1] + +# Test Utf8View as Join Key +# Issue: https://github.com/apache/datafusion/issues/12468 +statement ok +CREATE TABLE table1(v1 STRING) AS VALUES ('foo'), (NULL); + +statement ok +CREATE TABLE table1_stringview AS SELECT arrow_cast(v1, 'Utf8View') AS v1 FROM table1; + +query T +select * from table1 as t1 natural join table1_stringview as t2; +---- +foo From d00a089971b341d8f10cb5ecb446c27a4d824ac4 Mon Sep 17 00:00:00 2001 From: Yasser Latreche Date: Tue, 29 Oct 2024 07:24:12 -0400 Subject: [PATCH 099/110] Add Support for `modulus` operation in substrait (#13108) * Add modulus operation * change the producer to output `modulus` instead of `mod` for `modulo` operation * Add a roundtrip test case for modulus --- datafusion/substrait/src/logical_plan/consumer.rs | 1 + datafusion/substrait/src/logical_plan/producer.rs | 2 +- datafusion/substrait/tests/cases/roundtrip_logical_plan.rs | 5 +++++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 3d5d7cce5673..54b93cb7e345 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -118,6 +118,7 @@ pub fn name_to_op(name: &str) -> Option { "multiply" => Some(Operator::Multiply), "divide" => Some(Operator::Divide), "mod" => Some(Operator::Modulo), + "modulus" => Some(Operator::Modulo), "and" => Some(Operator::And), "or" => Some(Operator::Or), "is_distinct_from" => Some(Operator::IsDistinctFrom), diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 4855af683b7d..da8a4c994fb4 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -744,7 +744,7 @@ pub fn operator_to_name(op: Operator) -> &'static str { Operator::Minus => "subtract", Operator::Multiply => "multiply", Operator::Divide => "divide", - Operator::Modulo => "mod", + Operator::Modulo => "modulus", Operator::And => "and", Operator::Or => "or", Operator::IsDistinctFrom => "is_distinct_from", diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 06a047b108bd..9739afa99244 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -593,6 +593,11 @@ async fn roundtrip_ilike() -> Result<()> { roundtrip("SELECT f FROM data WHERE f ILIKE 'a%b'").await } +#[tokio::test] +async fn roundtrip_modulus() -> Result<()> { + roundtrip("SELECT a%3 from data").await +} + #[tokio::test] async fn roundtrip_not() -> Result<()> { roundtrip("SELECT * FROM data WHERE NOT d").await From 4e38abd71e61e5b6da9b6a486c0a40c2107dfa0a Mon Sep 17 00:00:00 2001 From: JasonLi Date: Tue, 29 Oct 2024 19:24:47 +0800 Subject: [PATCH 100/110] unify cast_to function of ScalarValue (#13122) --- datafusion/common/src/scalar/mod.rs | 33 +++++++++++++------- datafusion/expr-common/src/columnar_value.rs | 27 +++------------- 2 files changed, 25 insertions(+), 35 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 43f22265f5f6..f609e9f9ef6c 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -58,6 +58,7 @@ use arrow::{ use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano, ScalarBuffer}; use arrow_schema::{UnionFields, UnionMode}; +use crate::format::DEFAULT_CAST_OPTIONS; use half::f16; pub use struct_builder::ScalarStructBuilder; @@ -2809,22 +2810,30 @@ impl ScalarValue { /// Try to parse `value` into a ScalarValue of type `target_type` pub fn try_from_string(value: String, target_type: &DataType) -> Result { - let value = ScalarValue::from(value); - let cast_options = CastOptions { - safe: false, - format_options: Default::default(), - }; - let cast_arr = cast_with_options(&value.to_array()?, target_type, &cast_options)?; - ScalarValue::try_from_array(&cast_arr, 0) + ScalarValue::from(value).cast_to(target_type) } /// Try to cast this value to a ScalarValue of type `data_type` - pub fn cast_to(&self, data_type: &DataType) -> Result { - let cast_options = CastOptions { - safe: false, - format_options: Default::default(), + pub fn cast_to(&self, target_type: &DataType) -> Result { + self.cast_to_with_options(target_type, &DEFAULT_CAST_OPTIONS) + } + + /// Try to cast this value to a ScalarValue of type `data_type` with [`CastOptions`] + pub fn cast_to_with_options( + &self, + target_type: &DataType, + cast_options: &CastOptions<'static>, + ) -> Result { + let scalar_array = match (self, target_type) { + ( + ScalarValue::Float64(Some(float_ts)), + DataType::Timestamp(TimeUnit::Nanosecond, None), + ) => ScalarValue::Int64(Some((float_ts * 1_000_000_000_f64).trunc() as i64)) + .to_array()?, + _ => self.to_array()?, }; - let cast_arr = cast_with_options(&self.to_array()?, data_type, &cast_options)?; + + let cast_arr = cast_with_options(&scalar_array, target_type, cast_options)?; ScalarValue::try_from_array(&cast_arr, 0) } diff --git a/datafusion/expr-common/src/columnar_value.rs b/datafusion/expr-common/src/columnar_value.rs index 57056d0806a7..4b9454ed739d 100644 --- a/datafusion/expr-common/src/columnar_value.rs +++ b/datafusion/expr-common/src/columnar_value.rs @@ -19,7 +19,7 @@ use arrow::array::{Array, ArrayRef, NullArray}; use arrow::compute::{kernels, CastOptions}; -use arrow::datatypes::{DataType, TimeUnit}; +use arrow::datatypes::DataType; use datafusion_common::format::DEFAULT_CAST_OPTIONS; use datafusion_common::{internal_err, Result, ScalarValue}; use std::sync::Arc; @@ -193,28 +193,9 @@ impl ColumnarValue { ColumnarValue::Array(array) => Ok(ColumnarValue::Array( kernels::cast::cast_with_options(array, cast_type, &cast_options)?, )), - ColumnarValue::Scalar(scalar) => { - let scalar_array = - if cast_type == &DataType::Timestamp(TimeUnit::Nanosecond, None) { - if let ScalarValue::Float64(Some(float_ts)) = scalar { - ScalarValue::Int64(Some( - (float_ts * 1_000_000_000_f64).trunc() as i64, - )) - .to_array()? - } else { - scalar.to_array()? - } - } else { - scalar.to_array()? - }; - let cast_array = kernels::cast::cast_with_options( - &scalar_array, - cast_type, - &cast_options, - )?; - let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?; - Ok(ColumnarValue::Scalar(cast_scalar)) - } + ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( + scalar.cast_to_with_options(cast_type, &cast_options)?, + )), } } } From ac79ef3442e65f6197c7234da9fad964895b9101 Mon Sep 17 00:00:00 2001 From: Daniel Hegberg Date: Tue, 29 Oct 2024 04:31:24 -0700 Subject: [PATCH 101/110] Add unused_qualifications with deny level to linter. Fix unused_qualifications violations." (#13086) --- Cargo.toml | 1 + datafusion-examples/examples/advanced_udaf.rs | 6 +- .../examples/custom_datasource.rs | 4 +- .../examples/custom_file_format.rs | 5 +- .../examples/flight/flight_server.rs | 2 +- .../examples/function_factory.rs | 2 +- datafusion-examples/examples/simple_udaf.rs | 2 +- .../examples/simplify_udaf_expression.rs | 5 +- .../examples/simplify_udwf_expression.rs | 3 +- datafusion/common/src/config.rs | 2 +- datafusion/common/src/join_type.rs | 2 +- datafusion/common/src/parsers.rs | 3 +- datafusion/common/src/pyarrow.rs | 2 +- datafusion/common/src/scalar/mod.rs | 55 +++-- datafusion/common/src/stats.rs | 6 +- datafusion/common/src/utils/memory.rs | 7 +- datafusion/common/src/utils/proxy.rs | 7 +- datafusion/core/benches/parquet_query_sql.rs | 2 +- datafusion/core/src/dataframe/mod.rs | 23 +- .../avro_to_arrow/arrow_array_reader.rs | 32 ++- .../core/src/datasource/avro_to_arrow/mod.rs | 2 +- .../core/src/datasource/file_format/csv.rs | 4 +- .../core/src/datasource/file_format/json.rs | 2 +- .../core/src/datasource/file_format/mod.rs | 8 +- .../src/datasource/file_format/parquet.rs | 8 +- .../src/datasource/file_format/write/demux.rs | 5 +- .../src/datasource/listing_table_factory.rs | 2 +- .../core/src/datasource/physical_plan/csv.rs | 4 +- .../physical_plan/file_scan_config.rs | 5 +- .../core/src/datasource/physical_plan/mod.rs | 2 +- .../datasource/physical_plan/parquet/mod.rs | 2 +- .../physical_plan/parquet/row_group_filter.rs | 26 +-- .../datasource/physical_plan/statistics.rs | 8 +- .../core/src/datasource/schema_adapter.rs | 2 +- datafusion/core/src/execution/context/mod.rs | 4 +- .../core/src/execution/session_state.rs | 4 +- .../enforce_distribution.rs | 7 +- datafusion/core/src/test/mod.rs | 2 +- datafusion/core/tests/dataframe/mod.rs | 4 +- .../core/tests/expr_api/simplification.rs | 31 ++- .../fuzz_cases/aggregation_fuzzer/fuzzer.rs | 4 +- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 13 ++ .../core/tests/fuzz_cases/limit_fuzz.rs | 2 +- .../sort_preserving_repartition_fuzz.rs | 2 +- .../core/tests/fuzz_cases/window_fuzz.rs | 4 +- .../core/tests/parquet/file_statistics.rs | 5 +- .../limited_distinct_aggregation.rs | 4 +- datafusion/core/tests/sql/joins.rs | 4 +- datafusion/core/tests/sql/mod.rs | 6 +- .../user_defined/user_defined_aggregates.rs | 6 +- .../tests/user_defined/user_defined_plan.rs | 6 +- .../user_defined_scalar_functions.rs | 25 +- .../user_defined_window_functions.rs | 8 +- datafusion/execution/src/disk_manager.rs | 2 +- .../expr-common/src/type_coercion/binary.rs | 8 +- datafusion/expr/src/expr.rs | 34 ++- datafusion/expr/src/logical_plan/builder.rs | 2 +- datafusion/expr/src/logical_plan/ddl.rs | 2 +- datafusion/expr/src/logical_plan/dml.rs | 4 +- datafusion/expr/src/logical_plan/plan.rs | 4 +- datafusion/expr/src/logical_plan/statement.rs | 2 +- datafusion/expr/src/test/function_stub.rs | 15 +- datafusion/expr/src/utils.rs | 24 +- datafusion/expr/src/window_frame.rs | 6 +- .../src/aggregate/count_distinct/bytes.rs | 5 +- .../src/aggregate/count_distinct/native.rs | 7 +- .../src/aggregate/groups_accumulator.rs | 8 +- .../aggregate/groups_accumulator/prim_op.rs | 3 +- .../functions-aggregate-common/src/tdigest.rs | 4 +- .../src/approx_percentile_cont.rs | 6 +- .../src/approx_percentile_cont_with_weight.rs | 4 +- .../functions-aggregate/src/array_agg.rs | 29 ++- datafusion/functions-aggregate/src/average.rs | 16 +- .../functions-aggregate/src/bit_and_or_xor.rs | 10 +- .../functions-aggregate/src/bool_and_or.rs | 5 +- .../functions-aggregate/src/correlation.rs | 8 +- datafusion/functions-aggregate/src/count.rs | 23 +- .../functions-aggregate/src/covariance.rs | 3 +- .../functions-aggregate/src/first_last.rs | 13 +- .../functions-aggregate/src/grouping.rs | 2 +- datafusion/functions-aggregate/src/median.rs | 13 +- datafusion/functions-aggregate/src/min_max.rs | 9 +- .../src/min_max/min_max_bytes.rs | 4 +- .../functions-aggregate/src/nth_value.rs | 17 +- datafusion/functions-aggregate/src/regr.rs | 3 +- datafusion/functions-aggregate/src/stddev.rs | 4 +- .../functions-aggregate/src/string_agg.rs | 3 +- datafusion/functions-aggregate/src/sum.rs | 8 +- .../functions-aggregate/src/variance.rs | 13 +- datafusion/functions-nested/src/distance.rs | 2 +- datafusion/functions-nested/src/make_array.rs | 4 +- datafusion/functions-nested/src/map_keys.rs | 4 +- datafusion/functions-nested/src/map_values.rs | 4 +- datafusion/functions/src/core/named_struct.rs | 2 +- datafusion/functions/src/core/planner.rs | 2 +- .../functions/src/datetime/make_date.rs | 6 +- datafusion/functions/src/datetime/to_char.rs | 5 +- .../functions/src/datetime/to_local_time.rs | 4 +- .../functions/src/datetime/to_timestamp.rs | 15 +- datafusion/functions/src/math/factorial.rs | 2 +- datafusion/functions/src/math/round.rs | 4 +- datafusion/functions/src/strings.rs | 16 +- datafusion/functions/src/utils.rs | 2 +- .../src/analyzer/count_wildcard_rule.rs | 4 +- datafusion/optimizer/src/analyzer/subquery.rs | 2 +- .../optimizer/src/analyzer/type_coercion.rs | 38 +-- .../src/decorrelate_predicate_subquery.rs | 6 +- datafusion/optimizer/src/eliminate_limit.rs | 5 +- datafusion/optimizer/src/push_down_filter.rs | 2 +- .../optimizer/src/scalar_subquery_to_join.rs | 3 +- .../simplify_expressions/expr_simplifier.rs | 28 ++- .../src/single_distinct_to_groupby.rs | 8 +- .../physical-expr-common/src/binary_map.rs | 8 +- .../physical-expr-common/src/sort_expr.rs | 6 +- .../src/equivalence/properties.rs | 2 +- .../physical-expr/src/expressions/case.rs | 34 ++- .../physical-expr/src/expressions/cast.rs | 2 +- .../physical-expr/src/expressions/column.rs | 2 +- .../physical-expr/src/expressions/in_list.rs | 4 +- .../physical-expr/src/expressions/negative.rs | 2 +- .../src/expressions/unknown_column.rs | 2 +- .../physical-expr/src/intervals/cp_solver.rs | 10 +- datafusion/physical-expr/src/partitioning.rs | 4 +- .../src/aggregates/group_values/bytes.rs | 3 +- .../src/aggregates/group_values/bytes_view.rs | 3 +- .../src/aggregates/group_values/column.rs | 4 +- .../aggregates/group_values/group_column.rs | 18 +- .../src/aggregates/group_values/primitive.rs | 3 +- .../src/aggregates/group_values/row.rs | 15 +- .../physical-plan/src/aggregates/mod.rs | 13 +- .../src/aggregates/order/full.rs | 3 +- .../physical-plan/src/aggregates/order/mod.rs | 3 +- .../src/aggregates/order/partial.rs | 3 +- .../src/aggregates/topk/hash_table.rs | 4 +- datafusion/physical-plan/src/display.rs | 24 +- datafusion/physical-plan/src/insert.rs | 8 +- .../physical-plan/src/joins/cross_join.rs | 4 +- .../physical-plan/src/joins/hash_join.rs | 10 +- .../src/joins/sort_merge_join.rs | 90 +++---- .../src/joins/stream_join_utils.rs | 4 +- .../src/joins/symmetric_hash_join.rs | 35 +-- datafusion/physical-plan/src/joins/utils.rs | 46 ++-- datafusion/physical-plan/src/limit.rs | 6 +- datafusion/physical-plan/src/memory.rs | 6 +- datafusion/physical-plan/src/metrics/value.rs | 2 +- datafusion/physical-plan/src/projection.rs | 9 +- .../src/repartition/distributor_channels.rs | 2 +- .../physical-plan/src/repartition/mod.rs | 4 +- datafusion/physical-plan/src/sorts/sort.rs | 8 +- .../src/sorts/sort_preserving_merge.rs | 2 +- datafusion/physical-plan/src/stream.rs | 4 +- datafusion/physical-plan/src/streaming.rs | 2 +- datafusion/physical-plan/src/topk/mod.rs | 12 +- datafusion/physical-plan/src/unnest.rs | 12 +- datafusion/proto/Cargo.toml | 3 - .../proto/src/logical_plan/file_formats.rs | 20 +- datafusion/proto/src/logical_plan/mod.rs | 221 ++++++++---------- datafusion/proto/src/physical_plan/mod.rs | 12 +- .../tests/cases/roundtrip_logical_plan.rs | 4 +- datafusion/sql/src/expr/mod.rs | 2 +- datafusion/sql/src/statement.rs | 14 +- datafusion/sql/src/unparser/dialect.rs | 41 ++-- datafusion/sql/src/unparser/expr.rs | 196 ++++++++-------- datafusion/sql/tests/common/mod.rs | 7 +- datafusion/sqllogictest/bin/sqllogictests.rs | 4 +- datafusion/sqllogictest/src/test_context.rs | 2 +- .../substrait/src/logical_plan/consumer.rs | 10 +- .../substrait/src/logical_plan/producer.rs | 2 +- .../tests/cases/roundtrip_logical_plan.rs | 4 +- datafusion/substrait/tests/cases/serialize.rs | 5 +- 170 files changed, 865 insertions(+), 1003 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e1e3aca77153..0a7184ad2d99 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -169,3 +169,4 @@ large_futures = "warn" [workspace.lints.rust] unexpected_cfgs = { level = "warn", check-cfg = ["cfg(tarpaulin)"] } +unused_qualifications = "deny" diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 1259f90d6449..414596bdc678 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -193,7 +193,7 @@ impl Accumulator for GeometricMean { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } @@ -394,8 +394,8 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator { } fn size(&self) -> usize { - self.counts.capacity() * std::mem::size_of::() - + self.prods.capacity() * std::mem::size_of::() + self.counts.capacity() * size_of::() + + self.prods.capacity() * size_of::() } } diff --git a/datafusion-examples/examples/custom_datasource.rs b/datafusion-examples/examples/custom_datasource.rs index 0f7748b13365..7440e592962b 100644 --- a/datafusion-examples/examples/custom_datasource.rs +++ b/datafusion-examples/examples/custom_datasource.rs @@ -110,7 +110,7 @@ struct CustomDataSourceInner { } impl Debug for CustomDataSource { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.write_str("custom_db") } } @@ -220,7 +220,7 @@ impl CustomExec { } impl DisplayAs for CustomExec { - fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> fmt::Result { write!(f, "CustomExec") } } diff --git a/datafusion-examples/examples/custom_file_format.rs b/datafusion-examples/examples/custom_file_format.rs index b85127d42f71..95168597ebaa 100644 --- a/datafusion-examples/examples/custom_file_format.rs +++ b/datafusion-examples/examples/custom_file_format.rs @@ -74,10 +74,7 @@ impl FileFormat for TSVFileFormat { "tsv".to_string() } - fn get_ext_with_compression( - &self, - c: &FileCompressionType, - ) -> datafusion::error::Result { + fn get_ext_with_compression(&self, c: &FileCompressionType) -> Result { if c == &FileCompressionType::UNCOMPRESSED { Ok("tsv".to_string()) } else { diff --git a/datafusion-examples/examples/flight/flight_server.rs b/datafusion-examples/examples/flight/flight_server.rs index f9d1b8029f04..cc5f43746ddf 100644 --- a/datafusion-examples/examples/flight/flight_server.rs +++ b/datafusion-examples/examples/flight/flight_server.rs @@ -105,7 +105,7 @@ impl FlightService for FlightServiceImpl { } // add an initial FlightData message that sends schema - let options = datafusion::arrow::ipc::writer::IpcWriteOptions::default(); + let options = arrow::ipc::writer::IpcWriteOptions::default(); let schema_flight_data = SchemaAsIpc::new(&schema, &options); let mut flights = vec![FlightData::from(schema_flight_data)]; diff --git a/datafusion-examples/examples/function_factory.rs b/datafusion-examples/examples/function_factory.rs index f57b3bf60404..b42f25437d77 100644 --- a/datafusion-examples/examples/function_factory.rs +++ b/datafusion-examples/examples/function_factory.rs @@ -121,7 +121,7 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { &self.name } - fn signature(&self) -> &datafusion_expr::Signature { + fn signature(&self) -> &Signature { &self.signature } diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs index 140fc0d3572d..ef97bf9763b0 100644 --- a/datafusion-examples/examples/simple_udaf.rs +++ b/datafusion-examples/examples/simple_udaf.rs @@ -131,7 +131,7 @@ impl Accumulator for GeometricMean { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } diff --git a/datafusion-examples/examples/simplify_udaf_expression.rs b/datafusion-examples/examples/simplify_udaf_expression.rs index aedc511c62fe..52a27317e3c3 100644 --- a/datafusion-examples/examples/simplify_udaf_expression.rs +++ b/datafusion-examples/examples/simplify_udaf_expression.rs @@ -70,7 +70,7 @@ impl AggregateUDFImpl for BetterAvgUdaf { unimplemented!("should not be invoked") } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { unimplemented!("should not be invoked") } @@ -90,8 +90,7 @@ impl AggregateUDFImpl for BetterAvgUdaf { fn simplify(&self) -> Option { // as an example for this functionality we replace UDF function // with build-in aggregate function to illustrate the use - let simplify = |aggregate_function: datafusion_expr::expr::AggregateFunction, - _: &dyn SimplifyInfo| { + let simplify = |aggregate_function: AggregateFunction, _: &dyn SimplifyInfo| { Ok(Expr::AggregateFunction(AggregateFunction::new_udf( avg_udaf(), // yes it is the same Avg, `BetterAvgUdaf` was just a diff --git a/datafusion-examples/examples/simplify_udwf_expression.rs b/datafusion-examples/examples/simplify_udwf_expression.rs index d95f1147bc37..117063df4e0d 100644 --- a/datafusion-examples/examples/simplify_udwf_expression.rs +++ b/datafusion-examples/examples/simplify_udwf_expression.rs @@ -70,8 +70,7 @@ impl WindowUDFImpl for SimplifySmoothItUdf { /// this function will simplify `SimplifySmoothItUdf` to `SmoothItUdf`. fn simplify(&self) -> Option { - let simplify = |window_function: datafusion_expr::expr::WindowFunction, - _: &dyn SimplifyInfo| { + let simplify = |window_function: WindowFunction, _: &dyn SimplifyInfo| { Ok(Expr::WindowFunction(WindowFunction { fun: datafusion_expr::WindowFunctionDefinition::AggregateUDF(avg_udaf()), args: window_function.args, diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 33e5184d2cac..15290204fbac 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -876,7 +876,7 @@ pub trait ConfigExtension: ExtensionOptions { } /// An object-safe API for storing arbitrary configuration -pub trait ExtensionOptions: Send + Sync + std::fmt::Debug + 'static { +pub trait ExtensionOptions: Send + Sync + fmt::Debug + 'static { /// Return `self` as [`Any`] /// /// This is needed until trait upcasting is stabilised diff --git a/datafusion/common/src/join_type.rs b/datafusion/common/src/join_type.rs index fbdae1c50a83..d502e7836da3 100644 --- a/datafusion/common/src/join_type.rs +++ b/datafusion/common/src/join_type.rs @@ -97,7 +97,7 @@ pub enum JoinConstraint { } impl Display for JoinSide { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { JoinSide::Left => write!(f, "left"), JoinSide::Right => write!(f, "right"), diff --git a/datafusion/common/src/parsers.rs b/datafusion/common/src/parsers.rs index e23edb4e2adb..c73c8a55f18c 100644 --- a/datafusion/common/src/parsers.rs +++ b/datafusion/common/src/parsers.rs @@ -18,7 +18,6 @@ //! Interval parsing logic use std::fmt::Display; -use std::result; use std::str::FromStr; use sqlparser::parser::ParserError; @@ -41,7 +40,7 @@ pub enum CompressionTypeVariant { impl FromStr for CompressionTypeVariant { type Err = ParserError; - fn from_str(s: &str) -> result::Result { + fn from_str(s: &str) -> Result { let s = s.to_uppercase(); match s.as_str() { "GZIP" | "GZ" => Ok(Self::GZIP), diff --git a/datafusion/common/src/pyarrow.rs b/datafusion/common/src/pyarrow.rs index 87254a499fb1..bdcf831c7884 100644 --- a/datafusion/common/src/pyarrow.rs +++ b/datafusion/common/src/pyarrow.rs @@ -34,7 +34,7 @@ impl From for PyErr { } impl FromPyArrow for ScalarValue { - fn from_pyarrow_bound(value: &pyo3::Bound<'_, pyo3::PyAny>) -> PyResult { + fn from_pyarrow_bound(value: &Bound<'_, PyAny>) -> PyResult { let py = value.py(); let typ = value.getattr("type")?; let val = value.call_method0("as_py")?; diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index f609e9f9ef6c..7a1eaa2ad65b 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -28,6 +28,7 @@ use std::fmt; use std::hash::Hash; use std::hash::Hasher; use std::iter::repeat; +use std::mem::{size_of, size_of_val}; use std::str::FromStr; use std::sync::Arc; @@ -691,8 +692,8 @@ hash_float_value!((f64, u64), (f32, u32)); // # Panics // // Panics if there is an error when creating hash values for rows -impl std::hash::Hash for ScalarValue { - fn hash(&self, state: &mut H) { +impl Hash for ScalarValue { + fn hash(&self, state: &mut H) { use ScalarValue::*; match self { Decimal128(v, p, s) => { @@ -768,7 +769,7 @@ impl std::hash::Hash for ScalarValue { } } -fn hash_nested_array(arr: ArrayRef, state: &mut H) { +fn hash_nested_array(arr: ArrayRef, state: &mut H) { let arrays = vec![arr.to_owned()]; let hashes_buffer = &mut vec![0; arr.len()]; let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); @@ -802,7 +803,7 @@ fn dict_from_scalar( let values_array = value.to_array_of_size(1)?; // Create a key array with `size` elements, each of 0 - let key_array: PrimitiveArray = std::iter::repeat(if value.is_null() { + let key_array: PrimitiveArray = repeat(if value.is_null() { None } else { Some(K::default_value()) @@ -2043,7 +2044,7 @@ impl ScalarValue { scale: i8, size: usize, ) -> Result { - Ok(std::iter::repeat(value) + Ok(repeat(value) .take(size) .collect::() .with_precision_and_scale(precision, scale)?) @@ -2512,7 +2513,7 @@ impl ScalarValue { } fn list_to_array_of_size(arr: &dyn Array, size: usize) -> Result { - let arrays = std::iter::repeat(arr).take(size).collect::>(); + let arrays = repeat(arr).take(size).collect::>(); let ret = match !arrays.is_empty() { true => arrow::compute::concat(arrays.as_slice())?, false => arr.slice(0, 0), @@ -3083,7 +3084,7 @@ impl ScalarValue { /// Estimate size if bytes including `Self`. For values with internal containers such as `String` /// includes the allocated size (`capacity`) rather than the current length (`len`) pub fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) + match self { ScalarValue::Null | ScalarValue::Boolean(_) @@ -3137,12 +3138,12 @@ impl ScalarValue { ScalarValue::Map(arr) => arr.get_array_memory_size(), ScalarValue::Union(vals, fields, _mode) => { vals.as_ref() - .map(|(_id, sv)| sv.size() - std::mem::size_of_val(sv)) + .map(|(_id, sv)| sv.size() - size_of_val(sv)) .unwrap_or_default() // `fields` is boxed, so it is NOT already included in `self` - + std::mem::size_of_val(fields) - + (std::mem::size_of::() * fields.len()) - + fields.iter().map(|(_idx, field)| field.size() - std::mem::size_of_val(field)).sum::() + + size_of_val(fields) + + (size_of::() * fields.len()) + + fields.iter().map(|(_idx, field)| field.size() - size_of_val(field)).sum::() } ScalarValue::Dictionary(dt, sv) => { // `dt` and `sv` are boxed, so they are NOT already included in `self` @@ -3155,11 +3156,11 @@ impl ScalarValue { /// /// Includes the size of the [`Vec`] container itself. pub fn size_of_vec(vec: &Vec) -> usize { - std::mem::size_of_val(vec) - + (std::mem::size_of::() * vec.capacity()) + size_of_val(vec) + + (size_of::() * vec.capacity()) + vec .iter() - .map(|sv| sv.size() - std::mem::size_of_val(sv)) + .map(|sv| sv.size() - size_of_val(sv)) .sum::() } @@ -3167,11 +3168,11 @@ impl ScalarValue { /// /// Includes the size of the [`VecDeque`] container itself. pub fn size_of_vec_deque(vec_deque: &VecDeque) -> usize { - std::mem::size_of_val(vec_deque) - + (std::mem::size_of::() * vec_deque.capacity()) + size_of_val(vec_deque) + + (size_of::() * vec_deque.capacity()) + vec_deque .iter() - .map(|sv| sv.size() - std::mem::size_of_val(sv)) + .map(|sv| sv.size() - size_of_val(sv)) .sum::() } @@ -3179,11 +3180,11 @@ impl ScalarValue { /// /// Includes the size of the [`HashSet`] container itself. pub fn size_of_hashset(set: &HashSet) -> usize { - std::mem::size_of_val(set) - + (std::mem::size_of::() * set.capacity()) + size_of_val(set) + + (size_of::() * set.capacity()) + set .iter() - .map(|sv| sv.size() - std::mem::size_of_val(sv)) + .map(|sv| sv.size() - size_of_val(sv)) .sum::() } } @@ -4445,7 +4446,7 @@ mod tests { let right_array = right.to_array().expect("Failed to convert to array"); let arrow_left_array = left_array.as_primitive::(); let arrow_right_array = right_array.as_primitive::(); - let arrow_result = kernels::numeric::add(arrow_left_array, arrow_right_array); + let arrow_result = add(arrow_left_array, arrow_right_array); assert_eq!(scalar_result.is_ok(), arrow_result.is_ok()); } @@ -5060,13 +5061,13 @@ mod tests { // thus the size of the enum appears to as well // The value may also change depending on rust version - assert_eq!(std::mem::size_of::(), 64); + assert_eq!(size_of::(), 64); } #[test] fn memory_size() { let sv = ScalarValue::Binary(Some(Vec::with_capacity(10))); - assert_eq!(sv.size(), std::mem::size_of::() + 10,); + assert_eq!(sv.size(), size_of::() + 10,); let sv_size = sv.size(); let mut v = Vec::with_capacity(10); @@ -5075,9 +5076,7 @@ mod tests { assert_eq!(v.capacity(), 10); assert_eq!( ScalarValue::size_of_vec(&v), - std::mem::size_of::>() - + (9 * std::mem::size_of::()) - + sv_size, + size_of::>() + (9 * size_of::()) + sv_size, ); let mut s = HashSet::with_capacity(0); @@ -5087,8 +5086,8 @@ mod tests { let s_capacity = s.capacity(); assert_eq!( ScalarValue::size_of_hashset(&s), - std::mem::size_of::>() - + ((s_capacity - 1) * std::mem::size_of::()) + size_of::>() + + ((s_capacity - 1) * size_of::()) + sv_size, ); } diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index d8e62b3045f9..e669c674f78a 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -190,7 +190,7 @@ impl Precision { } } -impl Debug for Precision { +impl Debug for Precision { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Precision::Exact(inner) => write!(f, "Exact({:?})", inner), @@ -200,7 +200,7 @@ impl Debug for Precision } } -impl Display for Precision { +impl Display for Precision { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Precision::Exact(inner) => write!(f, "Exact({:?})", inner), @@ -341,7 +341,7 @@ fn check_num_rows(value: Option, is_exact: bool) -> Precision { } impl Display for Statistics { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // string of column statistics let column_stats = self .column_statistics diff --git a/datafusion/common/src/utils/memory.rs b/datafusion/common/src/utils/memory.rs index 2c34b61bd093..d5ce59e3421b 100644 --- a/datafusion/common/src/utils/memory.rs +++ b/datafusion/common/src/utils/memory.rs @@ -18,6 +18,7 @@ //! This module provides a function to estimate the memory size of a HashTable prior to alloaction use crate::{DataFusionError, Result}; +use std::mem::size_of; /// Estimates the memory size required for a hash table prior to allocation. /// @@ -87,7 +88,7 @@ pub fn estimate_memory_size(num_elements: usize, fixed_size: usize) -> Result // + size of entry * number of buckets // + 1 byte for each bucket // + fixed size of collection (HashSet/HashTable) - std::mem::size_of::() + size_of::() .checked_mul(estimated_buckets)? .checked_add(estimated_buckets)? .checked_add(fixed_size) @@ -108,7 +109,7 @@ mod tests { #[test] fn test_estimate_memory() { // size (bytes): 48 - let fixed_size = std::mem::size_of::>(); + let fixed_size = size_of::>(); // estimated buckets: 16 = (8 * 8 / 7).next_power_of_two() let num_elements = 8; @@ -126,7 +127,7 @@ mod tests { #[test] fn test_estimate_memory_overflow() { let num_elements = usize::MAX; - let fixed_size = std::mem::size_of::>(); + let fixed_size = size_of::>(); let estimated = estimate_memory_size::(num_elements, fixed_size); assert!(estimated.is_err()); diff --git a/datafusion/common/src/utils/proxy.rs b/datafusion/common/src/utils/proxy.rs index d68b5e354384..5d14a1517129 100644 --- a/datafusion/common/src/utils/proxy.rs +++ b/datafusion/common/src/utils/proxy.rs @@ -18,6 +18,7 @@ //! [`VecAllocExt`] and [`RawTableAllocExt`] to help tracking of memory allocations use hashbrown::raw::{Bucket, RawTable}; +use std::mem::size_of; /// Extension trait for [`Vec`] to account for allocations. pub trait VecAllocExt { @@ -93,7 +94,7 @@ impl VecAllocExt for Vec { let new_capacity = self.capacity(); if new_capacity > prev_capacty { // capacity changed, so we allocated more - let bump_size = (new_capacity - prev_capacty) * std::mem::size_of::(); + let bump_size = (new_capacity - prev_capacty) * size_of::(); // Note multiplication should never overflow because `push` would // have panic'd first, but the checked_add could potentially // overflow since accounting could be tracking additional values, and @@ -102,7 +103,7 @@ impl VecAllocExt for Vec { } } fn allocated_size(&self) -> usize { - std::mem::size_of::() * self.capacity() + size_of::() * self.capacity() } } @@ -157,7 +158,7 @@ impl RawTableAllocExt for RawTable { // need to request more memory let bump_elements = self.capacity().max(16); - let bump_size = bump_elements * std::mem::size_of::(); + let bump_size = bump_elements * size_of::(); *accounting = (*accounting).checked_add(bump_size).expect("overflow"); self.reserve(bump_elements, hasher); diff --git a/datafusion/core/benches/parquet_query_sql.rs b/datafusion/core/benches/parquet_query_sql.rs index bc4298786002..f82a126c5652 100644 --- a/datafusion/core/benches/parquet_query_sql.rs +++ b/datafusion/core/benches/parquet_query_sql.rs @@ -249,7 +249,7 @@ fn criterion_benchmark(c: &mut Criterion) { } // Temporary file must outlive the benchmarks, it is deleted when dropped - std::mem::drop(temp_file); + drop(temp_file); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index d1d49bfaa693..e5d352a63c7a 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1941,12 +1941,12 @@ mod tests { use crate::physical_plan::{ColumnarValue, Partitioning, PhysicalExpr}; use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name}; - use arrow::array::{self, Int32Array}; + use arrow::array::Int32Array; use datafusion_common::{assert_batches_eq, Constraint, Constraints, ScalarValue}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::expr::WindowFunction; use datafusion_expr::{ - cast, create_udf, expr, lit, BuiltInWindowFunction, ExprFunctionExt, + cast, create_udf, lit, BuiltInWindowFunction, ExprFunctionExt, ScalarFunctionImplementation, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; @@ -1979,8 +1979,8 @@ mod tests { let batch = RecordBatch::try_new( dual_schema.clone(), vec![ - Arc::new(array::Int32Array::from(vec![1])), - Arc::new(array::StringArray::from(vec!["a"])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(StringArray::from(vec!["a"])), ], ) .unwrap(); @@ -2176,7 +2176,7 @@ mod tests { async fn select_with_window_exprs() -> Result<()> { // build plan using Table API let t = test_table().await?; - let first_row = Expr::WindowFunction(expr::WindowFunction::new( + let first_row = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::BuiltInWindowFunction( BuiltInWindowFunction::FirstValue, ), @@ -3570,11 +3570,10 @@ mod tests { #[tokio::test] async fn with_column_renamed_case_sensitive() -> Result<()> { - let config = - SessionConfig::from_string_hash_map(&std::collections::HashMap::from([( - "datafusion.sql_parser.enable_ident_normalization".to_owned(), - "false".to_owned(), - )]))?; + let config = SessionConfig::from_string_hash_map(&HashMap::from([( + "datafusion.sql_parser.enable_ident_normalization".to_owned(), + "false".to_owned(), + )]))?; let ctx = SessionContext::new_with_config(config); let name = "aggregate_test_100"; register_aggregate_csv(&ctx, name).await?; @@ -3646,7 +3645,7 @@ mod tests { #[tokio::test] async fn row_writer_resize_test() -> Result<()> { - let schema = Arc::new(Schema::new(vec![arrow::datatypes::Field::new( + let schema = Arc::new(Schema::new(vec![Field::new( "column_1", DataType::Utf8, false, @@ -3655,7 +3654,7 @@ mod tests { let data = RecordBatch::try_new( schema, vec![ - Arc::new(arrow::array::StringArray::from(vec![ + Arc::new(StringArray::from(vec![ Some("2a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), Some("3a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800"), ])) diff --git a/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs index 98b6702bc383..9f089c7c0cea 100644 --- a/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs @@ -206,7 +206,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { fn build_primitive_array(&self, rows: RecordSlice, col_name: &str) -> ArrayRef where T: ArrowNumericType + Resolver, - T::Native: num_traits::cast::NumCast, + T::Native: NumCast, { Arc::new( rows.iter() @@ -354,7 +354,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { let builder = builder .as_any_mut() .downcast_mut::>() - .ok_or_else(||ArrowError::SchemaError( + .ok_or_else(||SchemaError( "Cast failed for ListBuilder during nested data parsing".to_string(), ))?; for val in vals { @@ -369,7 +369,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { builder.append(true); } DataType::Dictionary(_, _) => { - let builder = builder.as_any_mut().downcast_mut::>>().ok_or_else(||ArrowError::SchemaError( + let builder = builder.as_any_mut().downcast_mut::>>().ok_or_else(||SchemaError( "Cast failed for ListBuilder during nested data parsing".to_string(), ))?; for val in vals { @@ -402,7 +402,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { col_name: &str, ) -> ArrowResult where - T::Native: num_traits::cast::NumCast, + T::Native: NumCast, T: ArrowPrimitiveType + ArrowDictionaryKeyType, { let mut builder: StringDictionaryBuilder = @@ -453,12 +453,10 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { DataType::UInt64 => { self.build_dictionary_array::(rows, col_name) } - _ => Err(ArrowError::SchemaError( - "unsupported dictionary key type".to_string(), - )), + _ => Err(SchemaError("unsupported dictionary key type".to_string())), } } else { - Err(ArrowError::SchemaError( + Err(SchemaError( "dictionary types other than UTF-8 not yet supported".to_string(), )) } @@ -532,7 +530,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { DataType::UInt32 => self.read_primitive_list_values::(rows), DataType::UInt64 => self.read_primitive_list_values::(rows), DataType::Float16 => { - return Err(ArrowError::SchemaError("Float16 not supported".to_string())) + return Err(SchemaError("Float16 not supported".to_string())) } DataType::Float32 => self.read_primitive_list_values::(rows), DataType::Float64 => self.read_primitive_list_values::(rows), @@ -541,7 +539,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { | DataType::Date64 | DataType::Time32(_) | DataType::Time64(_) => { - return Err(ArrowError::SchemaError( + return Err(SchemaError( "Temporal types are not yet supported, see ARROW-4803".to_string(), )) } @@ -623,7 +621,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { .unwrap() } datatype => { - return Err(ArrowError::SchemaError(format!( + return Err(SchemaError(format!( "Nested list of {datatype:?} not supported" ))); } @@ -737,7 +735,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { &field_path, ), t => { - return Err(ArrowError::SchemaError(format!( + return Err(SchemaError(format!( "TimeUnit {t:?} not supported with Time64" ))) } @@ -751,7 +749,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { &field_path, ), t => { - return Err(ArrowError::SchemaError(format!( + return Err(SchemaError(format!( "TimeUnit {t:?} not supported with Time32" ))) } @@ -854,7 +852,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { make_array(data) } _ => { - return Err(ArrowError::SchemaError(format!( + return Err(SchemaError(format!( "type {:?} not supported", field.data_type() ))) @@ -870,7 +868,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { fn read_primitive_list_values(&self, rows: &[&Value]) -> ArrayData where T: ArrowPrimitiveType + ArrowNumericType, - T::Native: num_traits::cast::NumCast, + T::Native: NumCast, { let values = rows .iter() @@ -970,7 +968,7 @@ fn resolve_u8(v: &Value) -> AvroResult { other => Err(AvroError::GetU8(other.into())), }?; if let Value::Int(n) = int { - if n >= 0 && n <= std::convert::From::from(u8::MAX) { + if n >= 0 && n <= From::from(u8::MAX) { return Ok(n as u8); } } @@ -1048,7 +1046,7 @@ fn maybe_resolve_union(value: &Value) -> &Value { impl Resolver for N where N: ArrowNumericType, - N::Native: num_traits::cast::NumCast, + N::Native: NumCast, { fn resolve(value: &Value) -> Option { let value = maybe_resolve_union(value); diff --git a/datafusion/core/src/datasource/avro_to_arrow/mod.rs b/datafusion/core/src/datasource/avro_to_arrow/mod.rs index c59078c89dd0..71184a78c96f 100644 --- a/datafusion/core/src/datasource/avro_to_arrow/mod.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/mod.rs @@ -39,7 +39,7 @@ use std::io::Read; pub fn read_avro_schema_from_reader(reader: &mut R) -> Result { let avro_reader = apache_avro::Reader::new(reader)?; let schema = avro_reader.writer_schema(); - schema::to_arrow_schema(schema) + to_arrow_schema(schema) } #[cfg(not(feature = "avro"))] diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index f235c3b628a0..3cb5ae4f85ca 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -78,7 +78,7 @@ impl CsvFormatFactory { } } -impl fmt::Debug for CsvFormatFactory { +impl Debug for CsvFormatFactory { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("CsvFormatFactory") .field("options", &self.options) @@ -968,7 +968,7 @@ mod tests { limit: Option, has_header: bool, ) -> Result> { - let root = format!("{}/csv", crate::test_util::arrow_test_data()); + let root = format!("{}/csv", arrow_test_data()); let format = CsvFormat::default().with_has_header(has_header); scan_format(state, &format, &root, file_name, projection, limit).await } diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index c9ed0c0d2805..fd97da52165b 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -118,7 +118,7 @@ impl GetExt for JsonFormatFactory { } } -impl fmt::Debug for JsonFormatFactory { +impl Debug for JsonFormatFactory { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("JsonFormatFactory") .field("options", &self.options) diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index a313a7a9bcb1..24f1111517d2 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -79,7 +79,7 @@ pub trait FileFormatFactory: Sync + Send + GetExt + Debug { /// /// [`TableProvider`]: crate::catalog::TableProvider #[async_trait] -pub trait FileFormat: Send + Sync + fmt::Debug { +pub trait FileFormat: Send + Sync + Debug { /// Returns the table provider as [`Any`](std::any::Any) so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; @@ -224,7 +224,7 @@ pub fn format_as_file_type( /// downcasted to a [DefaultFileType]. pub fn file_type_to_format( file_type: &Arc, -) -> datafusion_common::Result> { +) -> Result> { match file_type .as_ref() .as_any() @@ -447,8 +447,8 @@ pub(crate) mod test_util { iterations_detected: Arc>, } - impl std::fmt::Display for VariableStream { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + impl Display for VariableStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "VariableStream") } } diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 2d45c76ce918..9153e71a5c26 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -165,7 +165,7 @@ impl GetExt for ParquetFormatFactory { } } -impl fmt::Debug for ParquetFormatFactory { +impl Debug for ParquetFormatFactory { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ParquetFormatFactory") .field("ParquetFormatFactory", &self.options) @@ -1439,7 +1439,7 @@ mod tests { } impl Display for RequestCountingObjectStore { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "RequestCounting({})", self.inner) } } @@ -1707,7 +1707,7 @@ mod tests { let null_utf8 = if force_views { ScalarValue::Utf8View(None) } else { - ScalarValue::Utf8(None) + Utf8(None) }; // Fetch statistics for first file @@ -1720,7 +1720,7 @@ mod tests { let expected_type = if force_views { ScalarValue::Utf8View } else { - ScalarValue::Utf8 + Utf8 }; assert_eq!( c1_stats.max_value, diff --git a/datafusion/core/src/datasource/file_format/write/demux.rs b/datafusion/core/src/datasource/file_format/write/demux.rs index 427b28db4030..1746ffef8282 100644 --- a/datafusion/core/src/datasource/file_format/write/demux.rs +++ b/datafusion/core/src/datasource/file_format/write/demux.rs @@ -280,9 +280,8 @@ async fn hive_style_partitions_demuxer( Some(part_tx) => part_tx, None => { // Create channel for previously unseen distinct partition key and notify consumer of new file - let (part_tx, part_rx) = tokio::sync::mpsc::channel::( - max_buffered_recordbatches, - ); + let (part_tx, part_rx) = + mpsc::channel::(max_buffered_recordbatches); let file_path = compute_hive_style_file_path( &part_key, &partition_by, diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 701a13477b5b..581d88d25884 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -91,7 +91,7 @@ impl TableProviderFactory for ListingTableFactory { .field_with_name(col) .map_err(|e| arrow_datafusion_err!(e)) }) - .collect::>>()? + .collect::>>()? .into_iter() .map(|f| (f.name().to_owned(), f.data_type().to_owned())) .collect(); diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 6cd1864deb1d..5beffc3b0581 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -1216,7 +1216,7 @@ mod tests { let session_ctx = SessionContext::new(); let store = object_store::memory::InMemory::new(); - let data = bytes::Bytes::from("a,b\n1,2\n3,4"); + let data = Bytes::from("a,b\n1,2\n3,4"); let path = object_store::path::Path::from("a.csv"); store.put(&path, data.into()).await.unwrap(); @@ -1247,7 +1247,7 @@ mod tests { let session_ctx = SessionContext::new(); let store = object_store::memory::InMemory::new(); - let data = bytes::Bytes::from("a,b\r1,2\r3,4"); + let data = Bytes::from("a,b\r1,2\r3,4"); let path = object_store::path::Path::from("a.csv"); store.put(&path, data.into()).await.unwrap(); diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index 415ea62b3bb3..96c0e452e29e 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -19,7 +19,8 @@ //! file sources. use std::{ - borrow::Cow, collections::HashMap, fmt::Debug, marker::PhantomData, sync::Arc, vec, + borrow::Cow, collections::HashMap, fmt::Debug, marker::PhantomData, mem::size_of, + sync::Arc, vec, }; use super::{get_projected_output_ordering, statistics::MinMaxStatistics}; @@ -497,7 +498,7 @@ impl ZeroBufferGenerator where T: ArrowNativeType, { - const SIZE: usize = std::mem::size_of::(); + const SIZE: usize = size_of::(); fn get_buffer(&mut self, n_vals: usize) -> Buffer { match &mut self.cache { diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 6e8752ccfbf4..407a3b74f79f 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -763,7 +763,7 @@ mod tests { /// create a PartitionedFile for testing fn partitioned_file(path: &str) -> PartitionedFile { let object_meta = ObjectMeta { - location: object_store::path::Path::parse(path).unwrap(), + location: Path::parse(path).unwrap(), last_modified: Utc::now(), size: 42, e_tag: None, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 743dd5896986..059f86ce110f 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -2227,7 +2227,7 @@ mod tests { // execute a simple query and write the results to parquet let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out"; - std::fs::create_dir(&out_dir).unwrap(); + fs::create_dir(&out_dir).unwrap(); let df = ctx.sql("SELECT c1, c2 FROM test").await?; let schema: Schema = df.schema().into(); // Register a listing table - this will use all files in the directory as data sources diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs index a1d74cb54355..7406676652f6 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs @@ -779,11 +779,8 @@ mod tests { // INT32: c1 > 5, the c1 is decimal(9,2) // The type of scalar value if decimal(9,2), don't need to do cast - let schema = Arc::new(Schema::new(vec![Field::new( - "c1", - DataType::Decimal128(9, 2), - false, - )])); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", Decimal128(9, 2), false)])); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32) .with_logical_type(LogicalType::Decimal { scale: 2, @@ -849,11 +846,8 @@ mod tests { // The c1 type is decimal(9,0) in the parquet file, and the type of scalar is decimal(5,2). // We should convert all type to the coercion type, which is decimal(11,2) // The decimal of arrow is decimal(5,2), the decimal of parquet is decimal(9,0) - let schema = Arc::new(Schema::new(vec![Field::new( - "c1", - DataType::Decimal128(9, 0), - false, - )])); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", Decimal128(9, 0), false)])); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32) .with_logical_type(LogicalType::Decimal { @@ -863,7 +857,7 @@ mod tests { .with_scale(0) .with_precision(9); let schema_descr = get_test_schema_descr(vec![field]); - let expr = cast(col("c1"), DataType::Decimal128(11, 2)).gt(cast( + let expr = cast(col("c1"), Decimal128(11, 2)).gt(cast( lit(ScalarValue::Decimal128(Some(500), 5, 2)), Decimal128(11, 2), )); @@ -947,7 +941,7 @@ mod tests { // INT64: c1 < 5, the c1 is decimal(18,2) let schema = Arc::new(Schema::new(vec![Field::new( "c1", - DataType::Decimal128(18, 2), + Decimal128(18, 2), false, )])); let field = PrimitiveTypeField::new("c1", PhysicalType::INT64) @@ -1005,7 +999,7 @@ mod tests { // the type of parquet is decimal(18,2) let schema = Arc::new(Schema::new(vec![Field::new( "c1", - DataType::Decimal128(18, 2), + Decimal128(18, 2), false, )])); let field = PrimitiveTypeField::new("c1", PhysicalType::FIXED_LEN_BYTE_ARRAY) @@ -1018,7 +1012,7 @@ mod tests { .with_byte_len(16); let schema_descr = get_test_schema_descr(vec![field]); // cast the type of c1 to decimal(28,3) - let left = cast(col("c1"), DataType::Decimal128(28, 3)); + let left = cast(col("c1"), Decimal128(28, 3)); let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); @@ -1083,7 +1077,7 @@ mod tests { // the type of parquet is decimal(18,2) let schema = Arc::new(Schema::new(vec![Field::new( "c1", - DataType::Decimal128(18, 2), + Decimal128(18, 2), false, )])); let field = PrimitiveTypeField::new("c1", PhysicalType::BYTE_ARRAY) @@ -1096,7 +1090,7 @@ mod tests { .with_byte_len(16); let schema_descr = get_test_schema_descr(vec![field]); // cast the type of c1 to decimal(28,3) - let left = cast(col("c1"), DataType::Decimal128(28, 3)); + let left = cast(col("c1"), Decimal128(28, 3)); let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); diff --git a/datafusion/core/src/datasource/physical_plan/statistics.rs b/datafusion/core/src/datasource/physical_plan/statistics.rs index e1c61ec1a712..3ca3ba89f4d9 100644 --- a/datafusion/core/src/datasource/physical_plan/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/statistics.rs @@ -278,13 +278,9 @@ impl MinMaxStatistics { fn sort_columns_from_physical_sort_exprs( sort_order: &[PhysicalSortExpr], -) -> Option> { +) -> Option> { sort_order .iter() - .map(|expr| { - expr.expr - .as_any() - .downcast_ref::() - }) + .map(|expr| expr.expr.as_any().downcast_ref::()) .collect::>>() } diff --git a/datafusion/core/src/datasource/schema_adapter.rs b/datafusion/core/src/datasource/schema_adapter.rs index 80d2bf987473..5ba597e4b542 100644 --- a/datafusion/core/src/datasource/schema_adapter.rs +++ b/datafusion/core/src/datasource/schema_adapter.rs @@ -478,7 +478,7 @@ mod tests { writer.close().unwrap(); let location = Path::parse(path.to_str().unwrap()).unwrap(); - let metadata = std::fs::metadata(path.as_path()).expect("Local file metadata"); + let metadata = fs::metadata(path.as_path()).expect("Local file metadata"); let meta = ObjectMeta { location, last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 606759aae5ee..333f83c673cc 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -2139,9 +2139,9 @@ mod tests { fn create_physical_expr( &self, _expr: &Expr, - _input_dfschema: &crate::common::DFSchema, + _input_dfschema: &DFSchema, _session_state: &SessionState, - ) -> Result> { + ) -> Result> { unimplemented!() } } diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 4953eecd66e3..d50c912dd2fd 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -512,7 +512,7 @@ impl SessionState { /// [`catalog::resolve_table_references`]: crate::catalog_common::resolve_table_references pub fn resolve_table_references( &self, - statement: &datafusion_sql::parser::Statement, + statement: &Statement, ) -> datafusion_common::Result> { let enable_ident_normalization = self.config.options().sql_parser.enable_ident_normalization; @@ -526,7 +526,7 @@ impl SessionState { /// Convert an AST Statement into a LogicalPlan pub async fn statement_to_plan( &self, - statement: datafusion_sql::parser::Statement, + statement: Statement, ) -> datafusion_common::Result { let references = self.resolve_table_references(&statement)?; diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index c971e6150633..aa4bcb683749 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -1416,8 +1416,8 @@ pub(crate) mod tests { use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; use datafusion_physical_expr::{ - expressions, expressions::binary, expressions::lit, LexOrdering, - PhysicalSortExpr, PhysicalSortRequirement, + expressions::binary, expressions::lit, LexOrdering, PhysicalSortExpr, + PhysicalSortRequirement, }; use datafusion_physical_expr_common::sort_expr::LexRequirement; use datafusion_physical_plan::PlanProperties; @@ -1646,8 +1646,7 @@ pub(crate) mod tests { .enumerate() .map(|(index, (_col, name))| { ( - Arc::new(expressions::Column::new(name, index)) - as Arc, + Arc::new(Column::new(name, index)) as Arc, name.clone(), ) }) diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index 08740daa0c8e..9ac75c8f3efb 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -69,7 +69,7 @@ pub fn create_table_dual() -> Arc { let batch = RecordBatch::try_new( dual_schema.clone(), vec![ - Arc::new(array::Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![1])), Arc::new(array::StringArray::from(vec!["a"])), ], ) diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 3520ab8fed2b..0c3c2a99517e 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -1434,9 +1434,7 @@ async fn unnest_analyze_metrics() -> Result<()> { .explain(false, true)? .collect() .await?; - let formatted = arrow::util::pretty::pretty_format_batches(&results) - .unwrap() - .to_string(); + let formatted = pretty_format_batches(&results).unwrap().to_string(); assert_contains!(&formatted, "elapsed_compute="); assert_contains!(&formatted, "input_batches=1"); assert_contains!(&formatted, "input_rows=5"); diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index 800a087587da..68785b7a5a45 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -29,10 +29,10 @@ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::logical_plan::builder::table_scan_with_filters; use datafusion_expr::simplify::SimplifyInfo; use datafusion_expr::{ - expr, table_scan, Cast, ColumnarValue, ExprSchemable, LogicalPlan, - LogicalPlanBuilder, ScalarUDF, Volatility, + table_scan, Cast, ColumnarValue, ExprSchemable, LogicalPlan, LogicalPlanBuilder, + ScalarUDF, Volatility, }; -use datafusion_functions::{math, string}; +use datafusion_functions::math; use datafusion_optimizer::optimizer::Optimizer; use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpressions}; use datafusion_optimizer::{OptimizerContext, OptimizerRule}; @@ -368,13 +368,13 @@ fn test_const_evaluator() { #[test] fn test_const_evaluator_scalar_functions() { // concat("foo", "bar") --> "foobar" - let expr = string::expr_fn::concat(vec![lit("foo"), lit("bar")]); + let expr = concat(vec![lit("foo"), lit("bar")]); test_evaluate(expr, lit("foobar")); // ensure arguments are also constant folded // concat("foo", concat("bar", "baz")) --> "foobarbaz" - let concat1 = string::expr_fn::concat(vec![lit("bar"), lit("baz")]); - let expr = string::expr_fn::concat(vec![lit("foo"), concat1]); + let concat1 = concat(vec![lit("bar"), lit("baz")]); + let expr = concat(vec![lit("foo"), concat1]); test_evaluate(expr, lit("foobarbaz")); // Check non string arguments @@ -407,7 +407,7 @@ fn test_const_evaluator_scalar_functions() { #[test] fn test_const_evaluator_now() { let ts_nanos = 1599566400000000000i64; - let time = chrono::Utc.timestamp_nanos(ts_nanos); + let time = Utc.timestamp_nanos(ts_nanos); let ts_string = "2020-09-08T12:05:00+00:00"; // now() --> ts test_evaluate_with_start_time(now(), lit_timestamp_nano(ts_nanos), &time); @@ -429,7 +429,7 @@ fn test_evaluator_udfs() { // immutable UDF should get folded // udf_add(1+2, 30+40) --> 73 - let expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( + let expr = Expr::ScalarFunction(ScalarFunction::new_udf( make_udf_add(Volatility::Immutable), args.clone(), )); @@ -438,21 +438,16 @@ fn test_evaluator_udfs() { // stable UDF should be entirely folded // udf_add(1+2, 30+40) --> 73 let fun = make_udf_add(Volatility::Stable); - let expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( - Arc::clone(&fun), - args.clone(), - )); + let expr = + Expr::ScalarFunction(ScalarFunction::new_udf(Arc::clone(&fun), args.clone())); test_evaluate(expr, lit(73)); // volatile UDF should have args folded // udf_add(1+2, 30+40) --> udf_add(3, 70) let fun = make_udf_add(Volatility::Volatile); - let expr = - Expr::ScalarFunction(expr::ScalarFunction::new_udf(Arc::clone(&fun), args)); - let expected_expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( - Arc::clone(&fun), - folded_args, - )); + let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::clone(&fun), args)); + let expected_expr = + Expr::ScalarFunction(ScalarFunction::new_udf(Arc::clone(&fun), folded_args)); test_evaluate(expr, expected_expr); } diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs index 898d1081ff13..0704bafa0318 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs @@ -104,7 +104,7 @@ impl AggregationFuzzerBuilder { } } -impl std::default::Default for AggregationFuzzerBuilder { +impl Default for AggregationFuzzerBuilder { fn default() -> Self { Self::new() } @@ -375,7 +375,7 @@ pub struct QueryBuilder { } impl QueryBuilder { pub fn new() -> Self { - std::default::Default::default() + Default::default() } /// return the table name if any diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 44d34b674bbb..c8478db22bd4 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -90,6 +90,7 @@ fn col_lt_col_filter(schema1: Arc, schema2: Arc) -> JoinFilter { } #[tokio::test] +#[allow(unused_qualifications)] async fn test_inner_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -102,6 +103,7 @@ async fn test_inner_join_1k_filtered() { } #[tokio::test] +#[allow(unused_qualifications)] async fn test_inner_join_1k() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -114,6 +116,7 @@ async fn test_inner_join_1k() { } #[tokio::test] +#[allow(unused_qualifications)] async fn test_left_join_1k() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -126,6 +129,7 @@ async fn test_left_join_1k() { } #[tokio::test] +#[allow(unused_qualifications)] async fn test_left_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -138,6 +142,7 @@ async fn test_left_join_1k_filtered() { } #[tokio::test] +#[allow(unused_qualifications)] async fn test_right_join_1k() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -150,6 +155,7 @@ async fn test_right_join_1k() { } #[tokio::test] +#[allow(unused_qualifications)] async fn test_right_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -162,6 +168,7 @@ async fn test_right_join_1k_filtered() { } #[tokio::test] +#[allow(unused_qualifications)] async fn test_full_join_1k() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -174,6 +181,7 @@ async fn test_full_join_1k() { } #[tokio::test] +#[allow(unused_qualifications)] // flaky for HjSmj case // https://github.com/apache/datafusion/issues/12359 async fn test_full_join_1k_filtered() { @@ -188,6 +196,7 @@ async fn test_full_join_1k_filtered() { } #[tokio::test] +#[allow(unused_qualifications)] async fn test_semi_join_1k() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -200,6 +209,7 @@ async fn test_semi_join_1k() { } #[tokio::test] +#[allow(unused_qualifications)] async fn test_semi_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -212,6 +222,7 @@ async fn test_semi_join_1k_filtered() { } #[tokio::test] +#[allow(unused_qualifications)] async fn test_anti_join_1k() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -224,6 +235,7 @@ async fn test_anti_join_1k() { } #[tokio::test] +#[allow(unused_qualifications)] async fn test_anti_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), @@ -449,6 +461,7 @@ impl JoinFuzzTestCase { /// `join_tests` - identifies what join types to test /// if `debug` flag is set the test will save randomly generated inputs and outputs to user folders, /// so it is easy to debug a test on top of the failed data + #[allow(unused_qualifications)] async fn run_test(&self, join_tests: &[JoinTestType], debug: bool) { for batch_size in self.batch_sizes { let session_config = SessionConfig::new().with_batch_size(*batch_size); diff --git a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs index 95d97709f319..c52acdd82764 100644 --- a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs @@ -341,7 +341,7 @@ async fn run_limit_test(fetch: usize, data: &SortedData) { /// Return random ASCII String with len fn get_random_string(len: usize) -> String { - rand::thread_rng() + thread_rng() .sample_iter(rand::distributions::Alphanumeric) .take(len) .map(char::from) diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs index a72affc2b079..353db8668363 100644 --- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs @@ -174,7 +174,7 @@ mod sp_repartition_fuzz_tests { }) .unzip(); - let sort_arrs = arrow::compute::lexsort(&sort_columns, None)?; + let sort_arrs = lexsort(&sort_columns, None)?; for (idx, arr) in izip!(indices, sort_arrs) { schema_vec[idx] = Some(arr); } diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index d649919f1b6a..61b4e32ad6c9 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -293,7 +293,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { vec![window_expr], memory_exec.clone(), vec![], - InputOrderMode::Linear, + Linear, )?); let task_ctx = ctx.task_ctx(); let mut collected_results = @@ -592,7 +592,7 @@ async fn run_window_test( orderby_columns: Vec<&str>, search_mode: InputOrderMode, ) -> Result<()> { - let is_linear = !matches!(search_mode, InputOrderMode::Sorted); + let is_linear = !matches!(search_mode, Sorted); let mut rng = StdRng::seed_from_u64(random_seed); let schema = input1[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); diff --git a/datafusion/core/tests/parquet/file_statistics.rs b/datafusion/core/tests/parquet/file_statistics.rs index 18d8300fb254..4b5d22bfa71f 100644 --- a/datafusion/core/tests/parquet/file_statistics.rs +++ b/datafusion/core/tests/parquet/file_statistics.rs @@ -28,7 +28,6 @@ use datafusion::execution::context::SessionState; use datafusion::prelude::SessionContext; use datafusion_common::stats::Precision; use datafusion_execution::cache::cache_manager::CacheManagerConfig; -use datafusion_execution::cache::cache_unit; use datafusion_execution::cache::cache_unit::{ DefaultFileStatisticsCache, DefaultListFilesCache, }; @@ -211,8 +210,8 @@ fn get_cache_runtime_state() -> ( SessionState, ) { let cache_config = CacheManagerConfig::default(); - let file_static_cache = Arc::new(cache_unit::DefaultFileStatisticsCache::default()); - let list_file_cache = Arc::new(cache_unit::DefaultListFilesCache::default()); + let file_static_cache = Arc::new(DefaultFileStatisticsCache::default()); + let list_file_cache = Arc::new(DefaultListFilesCache::default()); let cache_config = cache_config .with_files_statistics_cache(Some(file_static_cache.clone())) diff --git a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs index d6991711f581..6859e2f1468c 100644 --- a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs @@ -375,7 +375,7 @@ fn test_has_filter() -> Result<()> { // `SELECT a FROM MemoryExec WHERE a > 1 GROUP BY a LIMIT 10;`, Single AggregateExec // the `a > 1` filter is applied in the AggregateExec let filter_expr = Some(expressions::binary( - expressions::col("a", &schema)?, + col("a", &schema)?, Operator::Gt, cast(expressions::lit(1u32), &schema, DataType::Int32)?, &schema, @@ -408,7 +408,7 @@ fn test_has_filter() -> Result<()> { #[test] fn test_has_order_by() -> Result<()> { let sort_key = vec![PhysicalSortExpr { - expr: expressions::col("a", &schema()).unwrap(), + expr: col("a", &schema()).unwrap(), options: SortOptions::default(), }]; let source = parquet_exec_with_sort(vec![sort_key]); diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index addabc8a3612..fab92c0f9c2b 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -33,7 +33,7 @@ async fn join_change_in_planner() -> Result<()> { Field::new("a2", DataType::UInt32, false), ])); // Specify the ordering: - let file_sort_order = vec![[datafusion_expr::col("a1")] + let file_sort_order = vec![[col("a1")] .into_iter() .map(|e| { let ascending = true; @@ -101,7 +101,7 @@ async fn join_no_order_on_filter() -> Result<()> { Field::new("a3", DataType::UInt32, false), ])); // Specify the ordering: - let file_sort_order = vec![[datafusion_expr::col("a1")] + let file_sort_order = vec![[col("a1")] .into_iter() .map(|e| { let ascending = true; diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index dc9d04786021..177427b47d21 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -65,7 +65,7 @@ pub mod select; mod sql_api; async fn register_aggregate_csv_by_sql(ctx: &SessionContext) { - let testdata = datafusion::test_util::arrow_test_data(); + let testdata = test_util::arrow_test_data(); let df = ctx .sql(&format!( @@ -103,7 +103,7 @@ async fn register_aggregate_csv_by_sql(ctx: &SessionContext) { } async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> { - let testdata = datafusion::test_util::arrow_test_data(); + let testdata = test_util::arrow_test_data(); let schema = test_util::aggr_test_schema(); ctx.register_csv( "aggregate_test_100", @@ -227,7 +227,7 @@ fn result_vec(results: &[RecordBatch]) -> Vec> { } async fn register_alltypes_parquet(ctx: &SessionContext) { - let testdata = datafusion::test_util::parquet_test_data(); + let testdata = test_util::parquet_test_data(); ctx.register_parquet( "alltypes_plain", &format!("{testdata}/alltypes_plain.parquet"), diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 1e0d3d9d514e..497addd23094 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -747,7 +747,7 @@ impl Accumulator for FirstSelector { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } @@ -816,7 +816,7 @@ impl Accumulator for TestGroupsAccumulator { } fn size(&self) -> usize { - std::mem::size_of::() + size_of::() } fn state(&mut self) -> Result> { @@ -864,6 +864,6 @@ impl GroupsAccumulator for TestGroupsAccumulator { } fn size(&self) -> usize { - std::mem::size_of::() + size_of::() } } diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 6c4e3c66e397..c96256784402 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -513,11 +513,7 @@ impl Debug for TopKExec { } impl DisplayAs for TopKExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { write!(f, "TopKExec: k={}", self.k) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 0887645b8cbf..f1b172862399 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -936,11 +936,11 @@ struct ScalarFunctionWrapper { name: String, expr: Expr, signature: Signature, - return_type: arrow_schema::DataType, + return_type: DataType, } impl ScalarUDFImpl for ScalarFunctionWrapper { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } @@ -948,21 +948,15 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { &self.name } - fn signature(&self) -> &datafusion_expr::Signature { + fn signature(&self) -> &Signature { &self.signature } - fn return_type( - &self, - _arg_types: &[arrow_schema::DataType], - ) -> Result { + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(self.return_type.clone()) } - fn invoke( - &self, - _args: &[datafusion_expr::ColumnarValue], - ) -> Result { + fn invoke(&self, _args: &[ColumnarValue]) -> Result { internal_err!("This function should not get invoked!") } @@ -1042,10 +1036,7 @@ impl TryFrom for ScalarFunctionWrapper { .into_iter() .map(|a| a.data_type) .collect(), - definition - .params - .behavior - .unwrap_or(datafusion_expr::Volatility::Volatile), + definition.params.behavior.unwrap_or(Volatility::Volatile), ), }) } @@ -1350,7 +1341,7 @@ fn custom_sqrt(args: &[ColumnarValue]) -> Result { } async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> { - let testdata = datafusion::test_util::arrow_test_data(); + let testdata = test_util::arrow_test_data(); let schema = test_util::aggr_test_schema(); ctx.register_csv( "aggregate_test_100", @@ -1362,7 +1353,7 @@ async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> { } async fn register_alltypes_parquet(ctx: &SessionContext) -> Result<()> { - let testdata = datafusion::test_util::parquet_test_data(); + let testdata = test_util::parquet_test_data(); ctx.register_parquet( "alltypes_plain", &format!("{testdata}/alltypes_plain.parquet"), diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 3760328934bc..8fe028eedd44 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -593,11 +593,7 @@ impl PartitionEvaluator for OddCounter { Ok(scalar) } - fn evaluate_all( - &mut self, - values: &[arrow_array::ArrayRef], - num_rows: usize, - ) -> Result { + fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result { println!("evaluate_all, values: {values:#?}, num_rows: {num_rows}"); self.test_state.inc_evaluate_all_called(); @@ -641,7 +637,7 @@ fn odd_count(arr: &Int64Array) -> i64 { } /// returns an array of num_rows that has the number of odd values in `arr` -fn odd_count_arr(arr: &Int64Array, num_rows: usize) -> arrow_array::ArrayRef { +fn odd_count_arr(arr: &Int64Array, num_rows: usize) -> ArrayRef { let array: Int64Array = std::iter::repeat(odd_count(arr)).take(num_rows).collect(); Arc::new(array) } diff --git a/datafusion/execution/src/disk_manager.rs b/datafusion/execution/src/disk_manager.rs index c98d7e5579f0..38c259fcbdc8 100644 --- a/datafusion/execution/src/disk_manager.rs +++ b/datafusion/execution/src/disk_manager.rs @@ -173,7 +173,7 @@ fn create_local_dirs(local_dirs: Vec) -> Result>> { local_dirs .iter() .map(|root| { - if !std::path::Path::new(root).exists() { + if !Path::new(root).exists() { std::fs::create_dir(root)?; } Builder::new() diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index 2f806bf76d16..31fe6a59baee 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -89,7 +89,7 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result And | Or => if matches!((lhs, rhs), (Boolean | Null, Boolean | Null)) { // Logical binary boolean operators can only be evaluated for // boolean or null arguments. - Ok(Signature::uniform(DataType::Boolean)) + Ok(Signature::uniform(Boolean)) } else { plan_err!( "Cannot infer common argument type for logical boolean operation {lhs} {op} {rhs}" @@ -1225,9 +1225,9 @@ pub fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { - (DataType::Null, Utf8View | Utf8 | LargeUtf8) => Some(rhs_type.clone()), - (Utf8View | Utf8 | LargeUtf8, DataType::Null) => Some(lhs_type.clone()), - (DataType::Null, DataType::Null) => Some(Utf8), + (Null, Utf8View | Utf8 | LargeUtf8) => Some(rhs_type.clone()), + (Utf8View | Utf8 | LargeUtf8, Null) => Some(lhs_type.clone()), + (Null, Null) => Some(Utf8), _ => None, } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 4d73c2a04486..bda4d7ae3d7f 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -29,8 +29,8 @@ use crate::logical_plan::Subquery; use crate::utils::expr_to_columns; use crate::Volatility; use crate::{ - built_in_window_function, udaf, BuiltInWindowFunction, ExprSchemable, Operator, - Signature, WindowFrame, WindowUDF, + udaf, BuiltInWindowFunction, ExprSchemable, Operator, Signature, WindowFrame, + WindowUDF, }; use arrow::datatypes::{DataType, FieldRef}; @@ -695,11 +695,11 @@ impl AggregateFunction { pub enum WindowFunctionDefinition { /// A built in aggregate function that leverages an aggregate function /// A a built-in window function - BuiltInWindowFunction(built_in_window_function::BuiltInWindowFunction), + BuiltInWindowFunction(BuiltInWindowFunction), /// A user defined aggregate function AggregateUDF(Arc), /// A user defined aggregate function - WindowUDF(Arc), + WindowUDF(Arc), } impl WindowFunctionDefinition { @@ -742,14 +742,12 @@ impl WindowFunctionDefinition { } } -impl fmt::Display for WindowFunctionDefinition { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +impl Display for WindowFunctionDefinition { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self { - WindowFunctionDefinition::BuiltInWindowFunction(fun) => { - std::fmt::Display::fmt(fun, f) - } - WindowFunctionDefinition::AggregateUDF(fun) => std::fmt::Display::fmt(fun, f), - WindowFunctionDefinition::WindowUDF(fun) => std::fmt::Display::fmt(fun, f), + WindowFunctionDefinition::BuiltInWindowFunction(fun) => Display::fmt(fun, f), + WindowFunctionDefinition::AggregateUDF(fun) => Display::fmt(fun, f), + WindowFunctionDefinition::WindowUDF(fun) => Display::fmt(fun, f), } } } @@ -833,9 +831,7 @@ pub fn find_df_window_func(name: &str) -> Option { // may have different implementations for these cases. If the sought // function is not found among built-in window functions, we search for // it among aggregate functions. - if let Ok(built_in_function) = - built_in_window_function::BuiltInWindowFunction::from_str(name.as_str()) - { + if let Ok(built_in_function) = BuiltInWindowFunction::from_str(name.as_str()) { Some(WindowFunctionDefinition::BuiltInWindowFunction( built_in_function, )) @@ -2141,8 +2137,8 @@ pub fn schema_name_from_sorts(sorts: &[Sort]) -> Result { /// Format expressions for display as part of a logical plan. In many cases, this will produce /// similar output to `Expr.name()` except that column names will be prefixed with '#'. -impl fmt::Display for Expr { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +impl Display for Expr { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self { Expr::Alias(Alias { expr, name, .. }) => write!(f, "{expr} AS {name}"), Expr::Column(c) => write!(f, "{c}"), @@ -2346,7 +2342,7 @@ impl fmt::Display for Expr { } fn fmt_function( - f: &mut fmt::Formatter, + f: &mut Formatter, fun: &str, distinct: bool, args: &[Expr], @@ -2588,13 +2584,13 @@ mod test { assert_eq!( find_df_window_func("first_value"), Some(WindowFunctionDefinition::BuiltInWindowFunction( - built_in_window_function::BuiltInWindowFunction::FirstValue + BuiltInWindowFunction::FirstValue )) ); assert_eq!( find_df_window_func("LAST_value"), Some(WindowFunctionDefinition::BuiltInWindowFunction( - built_in_window_function::BuiltInWindowFunction::LastValue + BuiltInWindowFunction::LastValue )) ); assert_eq!(find_df_window_func("not_exist"), None) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 1f671626873f..2547aa23d3cd 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1678,7 +1678,7 @@ impl TableSource for LogicalTableSource { fn supports_filters_pushdown( &self, filters: &[&Expr], - ) -> Result> { + ) -> Result> { Ok(vec![TableProviderFilterPushDown::Exact; filters.len()]) } } diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index c4fa9f4c3fed..93e8b5fd045e 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -120,7 +120,7 @@ impl DdlStatement { /// children. /// /// See [crate::LogicalPlan::display] for an example - pub fn display(&self) -> impl fmt::Display + '_ { + pub fn display(&self) -> impl Display + '_ { struct Wrapper<'a>(&'a DdlStatement); impl<'a> Display for Wrapper<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { diff --git a/datafusion/expr/src/logical_plan/dml.rs b/datafusion/expr/src/logical_plan/dml.rs index 68b3ac41fa08..669bc8e8a7d3 100644 --- a/datafusion/expr/src/logical_plan/dml.rs +++ b/datafusion/expr/src/logical_plan/dml.rs @@ -165,7 +165,7 @@ impl WriteOp { } impl Display for WriteOp { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "{}", self.name()) } } @@ -196,7 +196,7 @@ impl InsertOp { } impl Display for InsertOp { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "{}", self.name()) } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 572285defba0..a301c48659d7 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -3382,8 +3382,8 @@ pub struct ColumnUnnestList { pub depth: usize, } -impl fmt::Display for ColumnUnnestList { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +impl Display for ColumnUnnestList { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "{}|depth={}", self.output_column, self.depth) } } diff --git a/datafusion/expr/src/logical_plan/statement.rs b/datafusion/expr/src/logical_plan/statement.rs index ed06375157c9..7ad18ce7bbf7 100644 --- a/datafusion/expr/src/logical_plan/statement.rs +++ b/datafusion/expr/src/logical_plan/statement.rs @@ -61,7 +61,7 @@ impl Statement { /// children. /// /// See [crate::LogicalPlan::display] for an example - pub fn display(&self) -> impl fmt::Display + '_ { + pub fn display(&self) -> impl Display + '_ { struct Wrapper<'a>(&'a Statement); impl<'a> Display for Wrapper<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index b4f768085fcc..262aa99e5007 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -34,7 +34,6 @@ use crate::{ function::{AccumulatorArgs, StateFieldsArgs}, utils::AggregateOrderSensitivity, Accumulator, AggregateUDFImpl, Expr, GroupsAccumulator, ReversedUDAF, Signature, - Volatility, }; macro_rules! create_func { @@ -106,7 +105,7 @@ pub struct Sum { impl Sum { pub fn new() -> Self { Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::user_defined(Immutable), } } } @@ -236,13 +235,13 @@ impl Count { pub fn new() -> Self { Self { aliases: vec!["count".to_string()], - signature: Signature::variadic_any(Volatility::Immutable), + signature: Signature::variadic_any(Immutable), } } } impl AggregateUDFImpl for Count { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } @@ -318,13 +317,13 @@ impl Default for Min { impl Min { pub fn new() -> Self { Self { - signature: Signature::variadic_any(Volatility::Immutable), + signature: Signature::variadic_any(Immutable), } } } impl AggregateUDFImpl for Min { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } @@ -403,13 +402,13 @@ impl Default for Max { impl Max { pub fn new() -> Self { Self { - signature: Signature::variadic_any(Volatility::Immutable), + signature: Signature::variadic_any(Immutable), } } } impl AggregateUDFImpl for Max { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 9207ad00993c..29c62440abb1 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1399,7 +1399,7 @@ pub fn format_state_name(name: &str, state_name: &str) -> String { mod tests { use super::*; use crate::{ - col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, + col, cube, expr_vec_fmt, grouping_set, lit, rollup, test::function_stub::max_udaf, test::function_stub::min_udaf, test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFunctionDefinition, }; @@ -1414,19 +1414,19 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { - let max1 = Expr::WindowFunction(expr::WindowFunction::new( + let max1 = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let max2 = Expr::WindowFunction(expr::WindowFunction::new( + let max2 = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let min3 = Expr::WindowFunction(expr::WindowFunction::new( + let min3 = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )); - let sum4 = Expr::WindowFunction(expr::WindowFunction::new( + let sum4 = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], )); @@ -1441,28 +1441,28 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys() -> Result<()> { - let age_asc = expr::Sort::new(col("age"), true, true); - let name_desc = expr::Sort::new(col("name"), false, true); - let created_at_desc = expr::Sort::new(col("created_at"), false, true); - let max1 = Expr::WindowFunction(expr::WindowFunction::new( + let age_asc = Sort::new(col("age"), true, true); + let name_desc = Sort::new(col("name"), false, true); + let created_at_desc = Sort::new(col("created_at"), false, true); + let max1 = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) .build() .unwrap(); - let max2 = Expr::WindowFunction(expr::WindowFunction::new( + let max2 = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let min3 = Expr::WindowFunction(expr::WindowFunction::new( + let min3 = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) .build() .unwrap(); - let sum4 = Expr::WindowFunction(expr::WindowFunction::new( + let sum4 = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], )) diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index 349968c3fa2f..222914315d70 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -94,7 +94,7 @@ pub struct WindowFrame { } impl fmt::Display for WindowFrame { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { write!( f, "{} BETWEEN {} AND {}", @@ -416,7 +416,7 @@ fn convert_frame_bound_to_scalar_value( } impl fmt::Display for WindowFrameBound { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self { WindowFrameBound::Preceding(n) => { if n.is_null() { @@ -457,7 +457,7 @@ pub enum WindowFrameUnits { } impl fmt::Display for WindowFrameUnits { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { f.write_str(match self { WindowFrameUnits::Rows => "ROWS", WindowFrameUnits::Range => "RANGE", diff --git a/datafusion/functions-aggregate-common/src/aggregate/count_distinct/bytes.rs b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/bytes.rs index ee61128979e1..07fa4efc990e 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/count_distinct/bytes.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/bytes.rs @@ -25,6 +25,7 @@ use datafusion_expr_common::accumulator::Accumulator; use datafusion_physical_expr_common::binary_map::{ArrowBytesSet, OutputType}; use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewSet; use std::fmt::Debug; +use std::mem::size_of_val; use std::sync::Arc; /// Specialized implementation of @@ -86,7 +87,7 @@ impl Accumulator for BytesDistinctCountAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + self.0.size() + size_of_val(self) + self.0.size() } } @@ -146,6 +147,6 @@ impl Accumulator for BytesViewDistinctCountAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + self.0.size() + size_of_val(self) + self.0.size() } } diff --git a/datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs index d128a8af58ee..405b2c2db7bd 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs @@ -23,6 +23,7 @@ use std::collections::HashSet; use std::fmt::Debug; use std::hash::Hash; +use std::mem::size_of_val; use std::sync::Arc; use ahash::RandomState; @@ -117,8 +118,7 @@ where fn size(&self) -> usize { let num_elements = self.values.len(); - let fixed_size = - std::mem::size_of_val(self) + std::mem::size_of_val(&self.values); + let fixed_size = size_of_val(self) + size_of_val(&self.values); estimate_memory_size::(num_elements, fixed_size).unwrap() } @@ -206,8 +206,7 @@ where fn size(&self) -> usize { let num_elements = self.values.len(); - let fixed_size = - std::mem::size_of_val(self) + std::mem::size_of_val(&self.values); + let fixed_size = size_of_val(self) + size_of_val(&self.values); estimate_memory_size::(num_elements, fixed_size).unwrap() } diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs index c936c80cbed7..03e4ef557269 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs @@ -23,6 +23,8 @@ pub mod bool_op; pub mod nulls; pub mod prim_op; +use std::mem::{size_of, size_of_val}; + use arrow::array::new_empty_array; use arrow::{ array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray}, @@ -122,9 +124,7 @@ impl AccumulatorState { /// Returns the amount of memory taken by this structure and its accumulator fn size(&self) -> usize { - self.accumulator.size() - + std::mem::size_of_val(self) - + self.indices.allocated_size() + self.accumulator.size() + size_of_val(self) + self.indices.allocated_size() } } @@ -464,7 +464,7 @@ pub trait VecAllocExt { impl VecAllocExt for Vec { type T = T; fn allocated_size(&self) -> usize { - std::mem::size_of::() * self.capacity() + size_of::() * self.capacity() } } diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs index 8bbcf756c37c..078982c983fc 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::mem::size_of; use std::sync::Arc; use arrow::array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray}; @@ -195,6 +196,6 @@ where } fn size(&self) -> usize { - self.values.capacity() * std::mem::size_of::() + self.null_state.size() + self.values.capacity() * size_of::() + self.null_state.size() } } diff --git a/datafusion/functions-aggregate-common/src/tdigest.rs b/datafusion/functions-aggregate-common/src/tdigest.rs index e6723b54b372..786d7ea3e361 100644 --- a/datafusion/functions-aggregate-common/src/tdigest.rs +++ b/datafusion/functions-aggregate-common/src/tdigest.rs @@ -33,6 +33,7 @@ use datafusion_common::cast::as_primitive_array; use datafusion_common::Result; use datafusion_common::ScalarValue; use std::cmp::Ordering; +use std::mem::{size_of, size_of_val}; pub const DEFAULT_MAX_SIZE: usize = 100; @@ -203,8 +204,7 @@ impl TDigest { /// Size in bytes including `Self`. pub fn size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.centroids.capacity()) + size_of_val(self) + (size_of::() * self.centroids.capacity()) } } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 83b9f714fa89..53fcfd641ddf 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -17,6 +17,7 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; +use std::mem::size_of_val; use std::sync::{Arc, OnceLock}; use arrow::array::{Array, RecordBatch}; @@ -486,10 +487,9 @@ impl Accumulator for ApproxPercentileAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + self.digest.size() - - std::mem::size_of_val(&self.digest) + size_of_val(self) + self.digest.size() - size_of_val(&self.digest) + self.return_type.size() - - std::mem::size_of_val(&self.return_type) + - size_of_val(&self.return_type) } } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index b86fec1e037e..5458d0f792b9 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -17,6 +17,7 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; +use std::mem::size_of_val; use std::sync::{Arc, OnceLock}; use arrow::{ @@ -239,8 +240,7 @@ impl Accumulator for ApproxPercentileWithWeightAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - - std::mem::size_of_val(&self.approx_percentile_cont_accumulator) + size_of_val(self) - size_of_val(&self.approx_percentile_cont_accumulator) + self.approx_percentile_cont_accumulator.size() } } diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 6f523756832e..b3e04c5584ef 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -34,6 +34,7 @@ use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays; use datafusion_functions_aggregate_common::utils::ordering_fields; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use std::collections::{HashSet, VecDeque}; +use std::mem::{size_of, size_of_val}; use std::sync::{Arc, OnceLock}; make_udaf_expr_and_func!( @@ -245,15 +246,15 @@ impl Accumulator for ArrayAggAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) + size_of_val(self) + + (size_of::() * self.values.capacity()) + self .values .iter() .map(|arr| arr.get_array_memory_size()) .sum::() + self.datatype.size() - - std::mem::size_of_val(&self.datatype) + - size_of_val(&self.datatype) } } @@ -318,10 +319,10 @@ impl Accumulator for DistinctArrayAggAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + ScalarValue::size_of_hashset(&self.values) - - std::mem::size_of_val(&self.values) + size_of_val(self) + ScalarValue::size_of_hashset(&self.values) + - size_of_val(&self.values) + self.datatype.size() - - std::mem::size_of_val(&self.datatype) + - size_of_val(&self.datatype) } } @@ -486,25 +487,23 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { } fn size(&self) -> usize { - let mut total = std::mem::size_of_val(self) - + ScalarValue::size_of_vec(&self.values) - - std::mem::size_of_val(&self.values); + let mut total = size_of_val(self) + ScalarValue::size_of_vec(&self.values) + - size_of_val(&self.values); // Add size of the `self.ordering_values` - total += - std::mem::size_of::>() * self.ordering_values.capacity(); + total += size_of::>() * self.ordering_values.capacity(); for row in &self.ordering_values { - total += ScalarValue::size_of_vec(row) - std::mem::size_of_val(row); + total += ScalarValue::size_of_vec(row) - size_of_val(row); } // Add size of the `self.datatypes` - total += std::mem::size_of::() * self.datatypes.capacity(); + total += size_of::() * self.datatypes.capacity(); for dtype in &self.datatypes { - total += dtype.size() - std::mem::size_of_val(dtype); + total += dtype.size() - size_of_val(dtype); } // Add size of the `self.ordering_req` - total += std::mem::size_of::() * self.ordering_req.capacity(); + total += size_of::() * self.ordering_req.capacity(); // TODO: Calculate size of each `PhysicalSortExpr` more accurately. total } diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 67b824c2ea79..710b7e69ac5c 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -18,8 +18,8 @@ //! Defines `Avg` & `Mean` aggregate & accumulators use arrow::array::{ - self, Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, - AsArray, BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array, + Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, AsArray, + BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array, }; use arrow::compute::sum; @@ -47,6 +47,7 @@ use datafusion_functions_aggregate_common::utils::DecimalAverager; use log::debug; use std::any::Any; use std::fmt::Debug; +use std::mem::{size_of, size_of_val}; use std::sync::{Arc, OnceLock}; make_udaf_expr_and_func!( @@ -294,7 +295,7 @@ impl Accumulator for AvgAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn state(&mut self) -> Result> { @@ -372,7 +373,7 @@ impl Accumulator for DecimalAvgAccumu } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn state(&mut self) -> Result> { @@ -471,7 +472,7 @@ where &mut self, values: &[ArrayRef], group_indices: &[usize], - opt_filter: Option<&array::BooleanArray>, + opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 1, "single argument to update_batch"); @@ -554,7 +555,7 @@ where &mut self, values: &[ArrayRef], group_indices: &[usize], - opt_filter: Option<&array::BooleanArray>, + opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 2, "two arguments to merge_batch"); @@ -614,7 +615,6 @@ where } fn size(&self) -> usize { - self.counts.capacity() * std::mem::size_of::() - + self.sums.capacity() * std::mem::size_of::() + self.counts.capacity() * size_of::() + self.sums.capacity() * size_of::() } } diff --git a/datafusion/functions-aggregate/src/bit_and_or_xor.rs b/datafusion/functions-aggregate/src/bit_and_or_xor.rs index 0a281ad81467..249ff02e7222 100644 --- a/datafusion/functions-aggregate/src/bit_and_or_xor.rs +++ b/datafusion/functions-aggregate/src/bit_and_or_xor.rs @@ -20,6 +20,7 @@ use std::any::Any; use std::collections::HashSet; use std::fmt::{Display, Formatter}; +use std::mem::{size_of, size_of_val}; use ahash::RandomState; use arrow::array::{downcast_integer, Array, ArrayRef, AsArray}; @@ -347,7 +348,7 @@ where } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn state(&mut self) -> Result> { @@ -392,7 +393,7 @@ where } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn state(&mut self) -> Result> { @@ -446,7 +447,7 @@ where } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn state(&mut self) -> Result> { @@ -509,8 +510,7 @@ where } fn size(&self) -> usize { - std::mem::size_of_val(self) - + self.values.capacity() * std::mem::size_of::() + size_of_val(self) + self.values.capacity() * size_of::() } fn state(&mut self) -> Result> { diff --git a/datafusion/functions-aggregate/src/bool_and_or.rs b/datafusion/functions-aggregate/src/bool_and_or.rs index b410bfa139e9..87293ccfa21f 100644 --- a/datafusion/functions-aggregate/src/bool_and_or.rs +++ b/datafusion/functions-aggregate/src/bool_and_or.rs @@ -18,6 +18,7 @@ //! Defines physical expressions that can evaluated at runtime during query execution use std::any::Any; +use std::mem::size_of_val; use std::sync::OnceLock; use arrow::array::ArrayRef; @@ -229,7 +230,7 @@ impl Accumulator for BoolAndAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn state(&mut self) -> Result> { @@ -378,7 +379,7 @@ impl Accumulator for BoolOrAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn state(&mut self) -> Result> { diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index 40429289d768..187a43ecbea3 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -19,6 +19,7 @@ use std::any::Any; use std::fmt::Debug; +use std::mem::size_of_val; use std::sync::{Arc, OnceLock}; use arrow::compute::{and, filter, is_not_null}; @@ -204,11 +205,10 @@ impl Accumulator for CorrelationAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.covar) - + self.covar.size() - - std::mem::size_of_val(&self.stddev1) + size_of_val(self) - size_of_val(&self.covar) + self.covar.size() + - size_of_val(&self.stddev1) + self.stddev1.size() - - std::mem::size_of_val(&self.stddev2) + - size_of_val(&self.stddev2) + self.stddev2.size() } diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index b4eeb937d4fb..bade589a908a 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -21,6 +21,7 @@ use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewD use datafusion_physical_expr::expressions; use std::collections::HashSet; use std::fmt::Debug; +use std::mem::{size_of, size_of_val}; use std::ops::BitAnd; use std::sync::{Arc, OnceLock}; @@ -394,7 +395,7 @@ impl Accumulator for CountAccumulator { fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { let counts = downcast_value!(states[0], Int64Array); - let delta = &arrow::compute::sum(counts); + let delta = &compute::sum(counts); if let Some(d) = delta { self.count += *d; } @@ -410,7 +411,7 @@ impl Accumulator for CountAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } @@ -583,7 +584,7 @@ impl GroupsAccumulator for CountGroupsAccumulator { } fn size(&self) -> usize { - self.counts.capacity() * std::mem::size_of::() + self.counts.capacity() * size_of::() } } @@ -627,28 +628,28 @@ impl DistinctCountAccumulator { // number of batches This method is faster than .full_size(), however it is // not suitable for variable length values like strings or complex types fn fixed_size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) + size_of_val(self) + + (size_of::() * self.values.capacity()) + self .values .iter() .next() - .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) + .map(|vals| ScalarValue::size(vals) - size_of_val(vals)) .unwrap_or(0) - + std::mem::size_of::() + + size_of::() } // calculates the size as accurately as possible. Note that calling this // method is expensive fn full_size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) + size_of_val(self) + + (size_of::() * self.values.capacity()) + self .values .iter() - .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) + .map(|vals| ScalarValue::size(vals) - size_of_val(vals)) .sum::() - + std::mem::size_of::() + + size_of::() } } diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs index 4b2b21059d16..063aaa92059d 100644 --- a/datafusion/functions-aggregate/src/covariance.rs +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -18,6 +18,7 @@ //! [`CovarianceSample`]: covariance sample aggregations. use std::fmt::Debug; +use std::mem::size_of_val; use std::sync::OnceLock; use arrow::{ @@ -448,6 +449,6 @@ impl Accumulator for CovarianceAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index c708d23ae6c5..da3fc62f8c8c 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -19,6 +19,7 @@ use std::any::Any; use std::fmt::Debug; +use std::mem::size_of_val; use std::sync::{Arc, OnceLock}; use arrow::array::{ArrayRef, AsArray, BooleanArray}; @@ -365,10 +366,10 @@ impl Accumulator for FirstValueAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.first) + size_of_val(self) - size_of_val(&self.first) + self.first.size() + ScalarValue::size_of_vec(&self.orderings) - - std::mem::size_of_val(&self.orderings) + - size_of_val(&self.orderings) } } @@ -698,10 +699,10 @@ impl Accumulator for LastValueAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.last) + size_of_val(self) - size_of_val(&self.last) + self.last.size() + ScalarValue::size_of_vec(&self.orderings) - - std::mem::size_of_val(&self.orderings) + - size_of_val(&self.orderings) } } @@ -795,7 +796,7 @@ mod tests { let mut states = vec![]; for idx in 0..state1.len() { - states.push(arrow::compute::concat(&[ + states.push(compute::concat(&[ &state1[idx].to_array()?, &state2[idx].to_array()?, ])?); @@ -825,7 +826,7 @@ mod tests { let mut states = vec![]; for idx in 0..state1.len() { - states.push(arrow::compute::concat(&[ + states.push(compute::concat(&[ &state1[idx].to_array()?, &state2[idx].to_array()?, ])?); diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs index 558d3055f1bf..27949aa3df27 100644 --- a/datafusion/functions-aggregate/src/grouping.rs +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -45,7 +45,7 @@ pub struct Grouping { } impl fmt::Debug for Grouping { - fn fmt(&self, f: &mut std::fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("Grouping") .field("name", &self.name()) .field("signature", &self.signature) diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index e0011e2e0f69..ff0a930d490b 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -17,6 +17,7 @@ use std::collections::HashSet; use std::fmt::{Debug, Formatter}; +use std::mem::{size_of, size_of_val}; use std::sync::{Arc, OnceLock}; use arrow::array::{downcast_integer, ArrowNumericType}; @@ -62,7 +63,7 @@ pub struct Median { } impl Debug for Median { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { f.debug_struct("Median") .field("name", &self.name()) .field("signature", &self.signature) @@ -195,7 +196,7 @@ struct MedianAccumulator { all_values: Vec, } -impl std::fmt::Debug for MedianAccumulator { +impl Debug for MedianAccumulator { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "MedianAccumulator({})", self.data_type) } @@ -235,8 +236,7 @@ impl Accumulator for MedianAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - + self.all_values.capacity() * std::mem::size_of::() + size_of_val(self) + self.all_values.capacity() * size_of::() } } @@ -252,7 +252,7 @@ struct DistinctMedianAccumulator { distinct_values: HashSet>, } -impl std::fmt::Debug for DistinctMedianAccumulator { +impl Debug for DistinctMedianAccumulator { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "DistinctMedianAccumulator({})", self.data_type) } @@ -307,8 +307,7 @@ impl Accumulator for DistinctMedianAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - + self.distinct_values.capacity() * std::mem::size_of::() + size_of_val(self) + self.distinct_values.capacity() * size_of::() } } diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 8102d0e4794b..b4256508e351 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -60,6 +60,7 @@ use datafusion_expr::{ }; use datafusion_expr::{GroupsAccumulator, StatisticsArgs}; use half::f16; +use std::mem::size_of_val; use std::ops::Deref; use std::sync::OnceLock; @@ -923,7 +924,7 @@ impl Accumulator for MaxAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.max) + self.max.size() + size_of_val(self) - size_of_val(&self.max) + self.max.size() } } @@ -982,7 +983,7 @@ impl Accumulator for SlidingMaxAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.max) + self.max.size() + size_of_val(self) - size_of_val(&self.max) + self.max.size() } } @@ -1231,7 +1232,7 @@ impl Accumulator for MinAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size() + size_of_val(self) - size_of_val(&self.min) + self.min.size() } } @@ -1294,7 +1295,7 @@ impl Accumulator for SlidingMinAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size() + size_of_val(self) - size_of_val(&self.min) + self.min.size() } } diff --git a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs index e3f01b91bf3e..501454edf77c 100644 --- a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs +++ b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs @@ -22,6 +22,7 @@ use arrow_schema::DataType; use datafusion_common::{internal_err, Result}; use datafusion_expr::{EmitTo, GroupsAccumulator}; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::apply_filter_as_nulls; +use std::mem::size_of; use std::sync::Arc; /// Implements fast Min/Max [`GroupsAccumulator`] for "bytes" types ([`StringArray`], @@ -509,7 +510,6 @@ impl MinMaxBytesState { } fn size(&self) -> usize { - self.total_data_bytes - + self.min_max.len() * std::mem::size_of::>>() + self.total_data_bytes + self.min_max.len() * size_of::>>() } } diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index 3e7f51af5265..2a1778d8b232 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -20,6 +20,7 @@ use std::any::Any; use std::collections::VecDeque; +use std::mem::{size_of, size_of_val}; use std::sync::{Arc, OnceLock}; use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray}; @@ -378,25 +379,23 @@ impl Accumulator for NthValueAccumulator { } fn size(&self) -> usize { - let mut total = std::mem::size_of_val(self) - + ScalarValue::size_of_vec_deque(&self.values) - - std::mem::size_of_val(&self.values); + let mut total = size_of_val(self) + ScalarValue::size_of_vec_deque(&self.values) + - size_of_val(&self.values); // Add size of the `self.ordering_values` - total += - std::mem::size_of::>() * self.ordering_values.capacity(); + total += size_of::>() * self.ordering_values.capacity(); for row in &self.ordering_values { - total += ScalarValue::size_of_vec(row) - std::mem::size_of_val(row); + total += ScalarValue::size_of_vec(row) - size_of_val(row); } // Add size of the `self.datatypes` - total += std::mem::size_of::() * self.datatypes.capacity(); + total += size_of::() * self.datatypes.capacity(); for dtype in &self.datatypes { - total += dtype.size() - std::mem::size_of_val(dtype); + total += dtype.size() - size_of_val(dtype); } // Add size of the `self.ordering_req` - total += std::mem::size_of::() * self.ordering_req.capacity(); + total += size_of::() * self.ordering_req.capacity(); // TODO: Calculate size of each `PhysicalSortExpr` more accurately. total } diff --git a/datafusion/functions-aggregate/src/regr.rs b/datafusion/functions-aggregate/src/regr.rs index a1fc5b094276..bf1e81949d23 100644 --- a/datafusion/functions-aggregate/src/regr.rs +++ b/datafusion/functions-aggregate/src/regr.rs @@ -36,6 +36,7 @@ use datafusion_expr::{ use std::any::Any; use std::collections::HashMap; use std::fmt::Debug; +use std::mem::size_of_val; use std::sync::OnceLock; macro_rules! make_regr_udaf_expr_and_func { @@ -614,6 +615,6 @@ impl Accumulator for RegrAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 0d1821687524..355d1d5ad2db 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -19,6 +19,7 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; +use std::mem::align_of_val; use std::sync::{Arc, OnceLock}; use arrow::array::Float64Array; @@ -343,8 +344,7 @@ impl Accumulator for StddevAccumulator { } fn size(&self) -> usize { - std::mem::align_of_val(self) - std::mem::align_of_val(&self.variance) - + self.variance.size() + align_of_val(self) - align_of_val(&self.variance) + self.variance.size() } fn supports_retract_batch(&self) -> bool { diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 66fc19910696..68267b9f72c7 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -29,6 +29,7 @@ use datafusion_expr::{ }; use datafusion_physical_expr::expressions::Literal; use std::any::Any; +use std::mem::size_of_val; use std::sync::OnceLock; make_udaf_expr_and_func!( @@ -179,7 +180,7 @@ impl Accumulator for StringAggAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0) + self.delimiter.capacity() } diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index 943f66a92c00..6ad376db4fb9 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -21,6 +21,7 @@ use ahash::RandomState; use datafusion_expr::utils::AggregateOrderSensitivity; use std::any::Any; use std::collections::HashSet; +use std::mem::{size_of, size_of_val}; use std::sync::OnceLock; use arrow::array::Array; @@ -310,7 +311,7 @@ impl Accumulator for SumAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } @@ -370,7 +371,7 @@ impl Accumulator for SlidingSumAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { @@ -464,7 +465,6 @@ impl Accumulator for DistinctSumAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - + self.values.capacity() * std::mem::size_of::() + size_of_val(self) + self.values.capacity() * size_of::() } } diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 8453c9d3010b..810247a2884a 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -24,6 +24,7 @@ use arrow::{ compute::kernels::cast, datatypes::{DataType, Field}, }; +use std::mem::{size_of, size_of_val}; use std::sync::OnceLock; use std::{fmt::Debug, sync::Arc}; @@ -424,7 +425,7 @@ impl Accumulator for VarianceAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } fn supports_retract_batch(&self) -> bool { @@ -529,7 +530,7 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { &mut self, values: &[ArrayRef], group_indices: &[usize], - opt_filter: Option<&arrow::array::BooleanArray>, + opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 1, "single argument to update_batch"); @@ -555,7 +556,7 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { &mut self, values: &[ArrayRef], group_indices: &[usize], - opt_filter: Option<&arrow::array::BooleanArray>, + opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 3, "two arguments to merge_batch"); @@ -606,8 +607,8 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { } fn size(&self) -> usize { - self.m2s.capacity() * std::mem::size_of::() - + self.means.capacity() * std::mem::size_of::() - + self.counts.capacity() * std::mem::size_of::() + self.m2s.capacity() * size_of::() + + self.means.capacity() * size_of::() + + self.counts.capacity() * size_of::() } } diff --git a/datafusion/functions-nested/src/distance.rs b/datafusion/functions-nested/src/distance.rs index 19a22690980b..4f890e4166e9 100644 --- a/datafusion/functions-nested/src/distance.rs +++ b/datafusion/functions-nested/src/distance.rs @@ -247,7 +247,7 @@ fn compute_array_distance( /// Converts an array of any numeric type to a Float64Array. fn convert_to_f64_array(array: &ArrayRef) -> Result { match array.data_type() { - DataType::Float64 => Ok(as_float64_array(array)?.clone()), + Float64 => Ok(as_float64_array(array)?.clone()), DataType::Float32 => { let array = as_float32_array(array)?; let converted: Float64Array = diff --git a/datafusion/functions-nested/src/make_array.rs b/datafusion/functions-nested/src/make_array.rs index abd7649e9ec7..c2c6f24948b8 100644 --- a/datafusion/functions-nested/src/make_array.rs +++ b/datafusion/functions-nested/src/make_array.rs @@ -122,7 +122,7 @@ impl ScalarUDFImpl for MakeArray { if let Some(new_type) = type_union_resolution(arg_types) { // TODO: Move FixedSizeList to List in type_union_resolution if let DataType::FixedSizeList(field, _) = new_type { - Ok(vec![DataType::List(field); arg_types.len()]) + Ok(vec![List(field); arg_types.len()]) } else if new_type.is_null() { Ok(vec![DataType::Int64; arg_types.len()]) } else { @@ -174,7 +174,7 @@ fn get_make_array_doc() -> &'static Documentation { // Empty array is a special case that is useful for many other array functions pub(super) fn empty_array_type() -> DataType { - DataType::List(Arc::new(Field::new("item", DataType::Int64, true))) + List(Arc::new(Field::new("item", DataType::Int64, true))) } /// `make_array_inner` is the implementation of the `make_array` function. diff --git a/datafusion/functions-nested/src/map_keys.rs b/datafusion/functions-nested/src/map_keys.rs index f28de1c3b2c7..03e381e372f6 100644 --- a/datafusion/functions-nested/src/map_keys.rs +++ b/datafusion/functions-nested/src/map_keys.rs @@ -66,7 +66,7 @@ impl ScalarUDFImpl for MapKeysFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { if arg_types.len() != 1 { return exec_err!("map_keys expects single argument"); } @@ -79,7 +79,7 @@ impl ScalarUDFImpl for MapKeysFunc { )))) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(map_keys_inner)(args) } diff --git a/datafusion/functions-nested/src/map_values.rs b/datafusion/functions-nested/src/map_values.rs index 2b19d9fbbc76..dc7d9c9db8ee 100644 --- a/datafusion/functions-nested/src/map_values.rs +++ b/datafusion/functions-nested/src/map_values.rs @@ -66,7 +66,7 @@ impl ScalarUDFImpl for MapValuesFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { if arg_types.len() != 1 { return exec_err!("map_values expects single argument"); } @@ -79,7 +79,7 @@ impl ScalarUDFImpl for MapValuesFunc { )))) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(map_values_inner)(args) } diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index 342f99274aca..b2c7f06d5868 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -124,7 +124,7 @@ impl ScalarUDFImpl for NamedStructFunc { fn return_type_from_exprs( &self, - args: &[datafusion_expr::Expr], + args: &[Expr], schema: &dyn datafusion_common::ExprSchema, _arg_types: &[DataType], ) -> Result { diff --git a/datafusion/functions/src/core/planner.rs b/datafusion/functions/src/core/planner.rs index 5873b4e1af41..717a74797c0b 100644 --- a/datafusion/functions/src/core/planner.rs +++ b/datafusion/functions/src/core/planner.rs @@ -49,7 +49,7 @@ impl ExprPlanner for CoreFunctionPlanner { Ok(PlannerResult::Planned(Expr::ScalarFunction( ScalarFunction::new_udf( if is_named_struct { - crate::core::named_struct() + named_struct() } else { crate::core::r#struct() }, diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index 78bd7c63a412..c8ef349dfbeb 100644 --- a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -89,9 +89,9 @@ impl ScalarUDFImpl for MakeDateFunc { ColumnarValue::Array(a) => Some(a.len()), }); - let years = args[0].cast_to(&DataType::Int32, None)?; - let months = args[1].cast_to(&DataType::Int32, None)?; - let days = args[2].cast_to(&DataType::Int32, None)?; + let years = args[0].cast_to(&Int32, None)?; + let months = args[1].cast_to(&Int32, None)?; + let days = args[2].cast_to(&Int32, None)?; let scalar_value_fn = |col: &ColumnarValue| -> Result { let ColumnarValue::Scalar(s) = col else { diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index 430dcedd92cf..2fbfb2261180 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -222,10 +222,7 @@ fn _to_char_scalar( if is_scalar_expression { return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); } else { - return Ok(ColumnarValue::Array(new_null_array( - &DataType::Utf8, - array.len(), - ))); + return Ok(ColumnarValue::Array(new_null_array(&Utf8, array.len()))); } } diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index 7646137ce656..376cb6f5f2f8 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -68,7 +68,7 @@ impl ToLocalTimeFunc { let time_value = &args[0]; let arg_type = time_value.data_type(); match arg_type { - DataType::Timestamp(_, None) => { + Timestamp(_, None) => { // if no timezone specified, just return the input Ok(time_value.clone()) } @@ -78,7 +78,7 @@ impl ToLocalTimeFunc { // for more details. // // Then remove the timezone in return type, i.e. return None - DataType::Timestamp(_, Some(timezone)) => { + Timestamp(_, Some(timezone)) => { let tz: Tz = timezone.parse()?; match time_value { diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index 9479e25fe61f..60482ee3c74a 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -374,7 +374,7 @@ impl ScalarUDFImpl for ToTimestampMillisFunc { static TO_TIMESTAMP_MILLIS_DOC: OnceLock = OnceLock::new(); fn get_to_timestamp_millis_doc() -> &'static Documentation { - crate::datetime::to_timestamp::TO_TIMESTAMP_MILLIS_DOC.get_or_init(|| { + TO_TIMESTAMP_MILLIS_DOC.get_or_init(|| { Documentation::builder() .with_doc_section(DOC_SECTION_DATETIME) .with_description("Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. Integers and unsigned integers are interpreted as milliseconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp.") @@ -1008,7 +1008,7 @@ mod tests { for udf in &udfs { for array in arrays { let rt = udf.return_type(&[array.data_type()]).unwrap(); - assert!(matches!(rt, DataType::Timestamp(_, Some(_)))); + assert!(matches!(rt, Timestamp(_, Some(_)))); let res = udf .invoke(&[array.clone()]) @@ -1018,7 +1018,7 @@ mod tests { _ => panic!("Expected a columnar array"), }; let ty = array.data_type(); - assert!(matches!(ty, DataType::Timestamp(_, Some(_)))); + assert!(matches!(ty, Timestamp(_, Some(_)))); } } @@ -1051,7 +1051,7 @@ mod tests { for udf in &udfs { for array in arrays { let rt = udf.return_type(&[array.data_type()]).unwrap(); - assert!(matches!(rt, DataType::Timestamp(_, None))); + assert!(matches!(rt, Timestamp(_, None))); let res = udf .invoke(&[array.clone()]) @@ -1061,7 +1061,7 @@ mod tests { _ => panic!("Expected a columnar array"), }; let ty = array.data_type(); - assert!(matches!(ty, DataType::Timestamp(_, None))); + assert!(matches!(ty, Timestamp(_, None))); } } } @@ -1137,10 +1137,7 @@ mod tests { .expect("that to_timestamp with format args parsed values without error"); if let ColumnarValue::Array(parsed_array) = parsed_timestamps { assert_eq!(parsed_array.len(), 1); - assert!(matches!( - parsed_array.data_type(), - DataType::Timestamp(_, None) - )); + assert!(matches!(parsed_array.data_type(), Timestamp(_, None))); match time_unit { Nanosecond => { diff --git a/datafusion/functions/src/math/factorial.rs b/datafusion/functions/src/math/factorial.rs index 4b87284744d3..bacdf47524f4 100644 --- a/datafusion/functions/src/math/factorial.rs +++ b/datafusion/functions/src/math/factorial.rs @@ -94,7 +94,7 @@ fn get_factorial_doc() -> &'static Documentation { /// Factorial SQL function fn factorial(args: &[ArrayRef]) -> Result { match args[0].data_type() { - DataType::Int64 => { + Int64 => { let arg = downcast_arg!((&args[0]), "value", Int64Array); Ok(arg .iter() diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index cf0f53a80a43..6000e5d765de 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -138,7 +138,7 @@ pub fn round(args: &[ArrayRef]) -> Result { } match args[0].data_type() { - DataType::Float64 => match decimal_places { + Float64 => match decimal_places { ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { let decimal_places: i32 = decimal_places.try_into().map_err(|e| { exec_datafusion_err!( @@ -181,7 +181,7 @@ pub fn round(args: &[ArrayRef]) -> Result { } }, - DataType::Float32 => match decimal_places { + Float32 => match decimal_places { ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { let decimal_places: i32 = decimal_places.try_into().map_err(|e| { exec_datafusion_err!( diff --git a/datafusion/functions/src/strings.rs b/datafusion/functions/src/strings.rs index 2e0e2c48390f..e0cec3cb5756 100644 --- a/datafusion/functions/src/strings.rs +++ b/datafusion/functions/src/strings.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use std::mem::size_of; + use arrow::array::{ make_view, Array, ArrayAccessor, ArrayDataBuilder, ArrayIter, ByteView, GenericStringArray, LargeStringArray, OffsetSizeTrait, StringArray, StringViewArray, @@ -122,9 +124,8 @@ pub struct StringArrayBuilder { impl StringArrayBuilder { pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { - let mut offsets_buffer = MutableBuffer::with_capacity( - (item_capacity + 1) * std::mem::size_of::(), - ); + let mut offsets_buffer = + MutableBuffer::with_capacity((item_capacity + 1) * size_of::()); // SAFETY: the first offset value is definitely not going to exceed the bounds. unsafe { offsets_buffer.push_unchecked(0_i32) }; Self { @@ -186,7 +187,7 @@ impl StringArrayBuilder { pub fn finish(self, null_buffer: Option) -> StringArray { let array_builder = ArrayDataBuilder::new(DataType::Utf8) - .len(self.offsets_buffer.len() / std::mem::size_of::() - 1) + .len(self.offsets_buffer.len() / size_of::() - 1) .add_buffer(self.offsets_buffer.into()) .add_buffer(self.value_buffer.into()) .nulls(null_buffer); @@ -273,9 +274,8 @@ pub struct LargeStringArrayBuilder { impl LargeStringArrayBuilder { pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { - let mut offsets_buffer = MutableBuffer::with_capacity( - (item_capacity + 1) * std::mem::size_of::(), - ); + let mut offsets_buffer = + MutableBuffer::with_capacity((item_capacity + 1) * size_of::()); // SAFETY: the first offset value is definitely not going to exceed the bounds. unsafe { offsets_buffer.push_unchecked(0_i64) }; Self { @@ -337,7 +337,7 @@ impl LargeStringArrayBuilder { pub fn finish(self, null_buffer: Option) -> LargeStringArray { let array_builder = ArrayDataBuilder::new(DataType::LargeUtf8) - .len(self.offsets_buffer.len() / std::mem::size_of::() - 1) + .len(self.offsets_buffer.len() / size_of::() - 1) .add_buffer(self.offsets_buffer.into()) .add_buffer(self.value_buffer.into()) .nulls(null_buffer); diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 818b4c64bd20..4d6574d2bd6c 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -107,7 +107,7 @@ where }; arg.clone().into_array(expansion_len) }) - .collect::>>()?; + .collect::>>()?; let result = (inner)(&args); if is_scalar { diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index b3b24724552a..454afa24b628 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -101,7 +101,7 @@ mod tests { use datafusion_expr::expr::Sort; use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ - col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, out_ref_col, + col, exists, in_subquery, logical_plan::LogicalPlanBuilder, out_ref_col, scalar_subquery, wildcard, WindowFrame, WindowFrameBound, WindowFrameUnits, }; use datafusion_functions_aggregate::count::count_udaf; @@ -219,7 +219,7 @@ mod tests { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .window(vec![Expr::WindowFunction(expr::WindowFunction::new( + .window(vec![Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], )) diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index 7c0bddf1153f..0ffc954388f5 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -305,7 +305,7 @@ mod test { vec![] } - fn schema(&self) -> &datafusion_common::DFSchemaRef { + fn schema(&self) -> &DFSchemaRef { &self.empty_schema } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 33eea1a661c6..5d33b58a0241 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -1243,7 +1243,7 @@ mod test { } fn return_type(&self, _args: &[DataType]) -> Result { - Ok(DataType::Utf8) + Ok(Utf8) } fn invoke(&self, _args: &[ColumnarValue]) -> Result { @@ -1446,7 +1446,7 @@ mod test { cast(lit("2002-05-08"), DataType::Date32) + lit(ScalarValue::new_interval_ym(0, 1)), ); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); let expected = "Filter: a BETWEEN Utf8(\"2002-05-08\") AND CAST(CAST(Utf8(\"2002-05-08\") AS Date32) + IntervalYearMonth(\"1\") AS Utf8)\ @@ -1462,7 +1462,7 @@ mod test { + lit(ScalarValue::new_interval_ym(0, 1)), lit("2002-12-08"), ); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); // TODO: we should cast col(a). let expected = @@ -1517,7 +1517,7 @@ mod test { let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); let expected = "Projection: a LIKE Utf8(\"abc\")\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -1525,7 +1525,7 @@ mod test { let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::Null)); let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); let expected = "Projection: a LIKE CAST(NULL AS Utf8)\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -1545,7 +1545,7 @@ mod test { let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); let expected = "Projection: a ILIKE Utf8(\"abc\")\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -1553,7 +1553,7 @@ mod test { let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::Null)); let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); let expected = "Projection: a ILIKE CAST(NULL AS Utf8)\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; @@ -1581,7 +1581,7 @@ mod test { let expected = "Projection: a IS UNKNOWN\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected); let err = ret.unwrap_err().to_string(); @@ -1599,7 +1599,7 @@ mod test { #[test] fn concat_for_type_coercion() -> Result<()> { - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let args = [col("a"), lit("b"), lit(true), lit(false), lit(13)]; // concat-type signature @@ -1734,7 +1734,7 @@ mod test { true, ), Field::new("binary", DataType::Binary, true), - Field::new("string", DataType::Utf8, true), + Field::new("string", Utf8, true), Field::new("decimal", DataType::Decimal128(10, 10), true), ] .into(), @@ -1751,7 +1751,7 @@ mod test { else_expr: None, }; let case_when_common_type = DataType::Boolean; - let then_else_common_type = DataType::Utf8; + let then_else_common_type = Utf8; let expected = cast_helper( case.clone(), &case_when_common_type, @@ -1770,8 +1770,8 @@ mod test { ], else_expr: Some(Box::new(col("string"))), }; - let case_when_common_type = DataType::Utf8; - let then_else_common_type = DataType::Utf8; + let case_when_common_type = Utf8; + let then_else_common_type = Utf8; let expected = cast_helper( case.clone(), &case_when_common_type, @@ -1861,7 +1861,7 @@ mod test { Some("list"), vec![(Box::new(col("large_list")), Box::new(lit("1")))], DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), - DataType::Utf8, + Utf8, schema ); @@ -1869,7 +1869,7 @@ mod test { Some("large_list"), vec![(Box::new(col("list")), Box::new(lit("1")))], DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), - DataType::Utf8, + Utf8, schema ); @@ -1877,7 +1877,7 @@ mod test { Some("list"), vec![(Box::new(col("fixed_list")), Box::new(lit("1")))], DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), - DataType::Utf8, + Utf8, schema ); @@ -1885,7 +1885,7 @@ mod test { Some("fixed_list"), vec![(Box::new(col("list")), Box::new(lit("1")))], DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), - DataType::Utf8, + Utf8, schema ); @@ -1893,7 +1893,7 @@ mod test { Some("fixed_list"), vec![(Box::new(col("large_list")), Box::new(lit("1")))], DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), - DataType::Utf8, + Utf8, schema ); @@ -1901,7 +1901,7 @@ mod test { Some("large_list"), vec![(Box::new(col("fixed_list")), Box::new(lit("1")))], DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), - DataType::Utf8, + Utf8, schema ); Ok(()) diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index cdffa8c645ea..cc1687cffe92 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -357,9 +357,9 @@ fn build_join( .for_each(|cols| all_correlated_cols.extend(cols.clone())); // alias the join filter - let join_filter_opt = - conjunction(pull_up.join_filters).map_or(Ok(None), |filter| { - replace_qualified_name(filter, &all_correlated_cols, &alias).map(Option::Some) + let join_filter_opt = conjunction(pull_up.join_filters) + .map_or(Ok(None), |filter| { + replace_qualified_name(filter, &all_correlated_cols, &alias).map(Some) })?; if let Some(join_filter) = match (join_filter_opt, in_predicate_opt) { diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 829d4c2d2217..267615c3e0d9 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -57,10 +57,7 @@ impl OptimizerRule for EliminateLimit { &self, plan: LogicalPlan, _config: &dyn OptimizerConfig, - ) -> Result< - datafusion_common::tree_node::Transformed, - datafusion_common::DataFusionError, - > { + ) -> Result, datafusion_common::DataFusionError> { match plan { LogicalPlan::Limit(limit) => { // Only supports rewriting for literal fetch diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index a6c0a7310610..f8e614a0aa84 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -2387,7 +2387,7 @@ mod tests { .collect()) } - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } } diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 7b931e73abf9..2e2c8fb1d6f8 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -318,8 +318,7 @@ fn build_join( // alias the join filter let join_filter_opt = conjunction(pull_up.join_filters).map_or(Ok(None), |filter| { - replace_qualified_name(filter, &all_correlated_cols, subquery_alias) - .map(Option::Some) + replace_qualified_name(filter, &all_correlated_cols, subquery_alias).map(Some) })?; // join our sub query into the main plan diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index f9dfadc70826..ce6734616b80 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1537,7 +1537,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // i.e. `a = 1 OR a = 2 OR a = 3` -> `a IN (1, 2, 3)` Expr::BinaryExpr(BinaryExpr { left, - op: Operator::Or, + op: Or, right, }) if are_inlist_and_eq(left.as_ref(), right.as_ref()) => { let lhs = to_inlist(*left).unwrap(); @@ -1577,7 +1577,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // 8. `a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)` Expr::BinaryExpr(BinaryExpr { left, - op: Operator::And, + op: And, right, }) if are_inlist_and_eq_and_match_neg( left.as_ref(), @@ -1597,7 +1597,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Expr::BinaryExpr(BinaryExpr { left, - op: Operator::And, + op: And, right, }) if are_inlist_and_eq_and_match_neg( left.as_ref(), @@ -1617,7 +1617,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Expr::BinaryExpr(BinaryExpr { left, - op: Operator::And, + op: And, right, }) if are_inlist_and_eq_and_match_neg( left.as_ref(), @@ -1637,7 +1637,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Expr::BinaryExpr(BinaryExpr { left, - op: Operator::And, + op: And, right, }) if are_inlist_and_eq_and_match_neg( left.as_ref(), @@ -1657,7 +1657,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Expr::BinaryExpr(BinaryExpr { left, - op: Operator::Or, + op: Or, right, }) if are_inlist_and_eq_and_match_neg( left.as_ref(), @@ -3818,7 +3818,7 @@ mod tests { fn test_simplify_udaf() { let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_with_simplify()); let aggregate_function_expr = - Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + Expr::AggregateFunction(expr::AggregateFunction::new_udf( udaf.into(), vec![], false, @@ -3832,7 +3832,7 @@ mod tests { let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_without_simplify()); let aggregate_function_expr = - Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + Expr::AggregateFunction(expr::AggregateFunction::new_udf( udaf.into(), vec![], false, @@ -3882,7 +3882,7 @@ mod tests { fn accumulator( &self, - _acc_args: function::AccumulatorArgs, + _acc_args: AccumulatorArgs, ) -> Result> { unimplemented!("not needed for tests") } @@ -3912,9 +3912,8 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(), ); - let window_function_expr = Expr::WindowFunction( - datafusion_expr::expr::WindowFunction::new(udwf, vec![]), - ); + let window_function_expr = + Expr::WindowFunction(WindowFunction::new(udwf, vec![])); let expected = col("result_column"); assert_eq!(simplify(window_function_expr), expected); @@ -3922,9 +3921,8 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(), ); - let window_function_expr = Expr::WindowFunction( - datafusion_expr::expr::WindowFunction::new(udwf, vec![]), - ); + let window_function_expr = + Expr::WindowFunction(WindowFunction::new(udwf, vec![])); let expected = window_function_expr.clone(); assert_eq!(simplify(window_function_expr), expected); diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 74251e5caad2..01875349c922 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -279,7 +279,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { mod tests { use super::*; use crate::test::*; - use datafusion_expr::expr::{self, GroupingSet}; + use datafusion_expr::expr::GroupingSet; use datafusion_expr::ExprFunctionExt; use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; use datafusion_functions_aggregate::count::count_udaf; @@ -288,7 +288,7 @@ mod tests { use datafusion_functions_aggregate::sum::sum_udaf; fn max_distinct(expr: Expr) -> Expr { - Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + Expr::AggregateFunction(AggregateFunction::new_udf( max_udaf(), vec![expr], true, @@ -569,7 +569,7 @@ mod tests { let table_scan = test_table_scan()?; // sum(a) FILTER (WHERE a > 5) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + let expr = Expr::AggregateFunction(AggregateFunction::new_udf( sum_udaf(), vec![col("a")], false, @@ -612,7 +612,7 @@ mod tests { let table_scan = test_table_scan()?; // SUM(a ORDER BY a) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + let expr = Expr::AggregateFunction(AggregateFunction::new_udf( sum_udaf(), vec![col("a")], false, diff --git a/datafusion/physical-expr-common/src/binary_map.rs b/datafusion/physical-expr-common/src/binary_map.rs index 03ac4769d9d9..80c4963ae035 100644 --- a/datafusion/physical-expr-common/src/binary_map.rs +++ b/datafusion/physical-expr-common/src/binary_map.rs @@ -31,7 +31,7 @@ use datafusion_common::hash_utils::create_hashes; use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt}; use std::any::type_name; use std::fmt::Debug; -use std::mem; +use std::mem::{size_of, swap}; use std::ops::Range; use std::sync::Arc; @@ -260,7 +260,7 @@ where /// the same output type pub fn take(&mut self) -> Self { let mut new_self = Self::new(self.output_type); - mem::swap(self, &mut new_self); + swap(self, &mut new_self); new_self } @@ -545,7 +545,7 @@ where /// this set, not including `self` pub fn size(&self) -> usize { self.map_size - + self.buffer.capacity() * mem::size_of::() + + self.buffer.capacity() * size_of::() + self.offsets.allocated_size() + self.hashes_buffer.allocated_size() } @@ -575,7 +575,7 @@ where } /// Maximum size of a value that can be inlined in the hash table -const SHORT_VALUE_LEN: usize = mem::size_of::(); +const SHORT_VALUE_LEN: usize = size_of::(); /// Entry in the hash table -- see [`ArrowBytesMap`] for more details #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index 6c4bf156ce56..d825bfe7e264 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -143,7 +143,7 @@ impl Hash for PhysicalSortExpr { } impl Display for PhysicalSortExpr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { write!(f, "{} {}", self.expr, to_str(&self.options)) } } @@ -188,7 +188,7 @@ impl PhysicalSortExpr { pub fn format_list(input: &[PhysicalSortExpr]) -> impl Display + '_ { struct DisplayableList<'a>(&'a [PhysicalSortExpr]); impl<'a> Display for DisplayableList<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { let mut first = true; for sort_expr in self.0 { if first { @@ -260,7 +260,7 @@ impl PartialEq for PhysicalSortRequirement { } impl Display for PhysicalSortRequirement { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { let opts_string = self.options.as_ref().map_or("NA", to_str); write!(f, "{} {}", self.expr, opts_string) } diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index a0cc29685f77..9a16b205ae25 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -1113,7 +1113,7 @@ impl EquivalenceProperties { /// order: [[a ASC, b ASC], [a ASC, c ASC]], eq: [[a = b], [a = c]], const: [a = 1] /// ``` impl Display for EquivalenceProperties { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if self.eq_group.is_empty() && self.oeq_class.is_empty() && self.constants.is_empty() diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index ffb431b200f2..981e49d73750 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -1096,16 +1096,15 @@ mod tests { let expr2 = Arc::clone(&expr) .transform(|e| { - let transformed = - match e.as_any().downcast_ref::() { - Some(lit_value) => match lit_value.value() { - ScalarValue::Utf8(Some(str_value)) => { - Some(lit(str_value.to_uppercase())) - } - _ => None, - }, + let transformed = match e.as_any().downcast_ref::() { + Some(lit_value) => match lit_value.value() { + ScalarValue::Utf8(Some(str_value)) => { + Some(lit(str_value.to_uppercase())) + } _ => None, - }; + }, + _ => None, + }; Ok(if let Some(transformed) = transformed { Transformed::yes(transformed) } else { @@ -1117,16 +1116,15 @@ mod tests { let expr3 = Arc::clone(&expr) .transform_down(|e| { - let transformed = - match e.as_any().downcast_ref::() { - Some(lit_value) => match lit_value.value() { - ScalarValue::Utf8(Some(str_value)) => { - Some(lit(str_value.to_uppercase())) - } - _ => None, - }, + let transformed = match e.as_any().downcast_ref::() { + Some(lit_value) => match lit_value.value() { + ScalarValue::Utf8(Some(str_value)) => { + Some(lit(str_value.to_uppercase())) + } _ => None, - }; + }, + _ => None, + }; Ok(if let Some(transformed) = transformed { Transformed::yes(transformed) } else { diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 5621473c4fdb..457c47097a19 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -693,7 +693,7 @@ mod tests { let result = cast( col("a", &schema).unwrap(), &schema, - DataType::Interval(IntervalUnit::MonthDayNano), + Interval(IntervalUnit::MonthDayNano), ); result.expect_err("expected Invalid CAST"); } diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index 4aad959584ac..3e2d49e9fa69 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -107,7 +107,7 @@ impl std::fmt::Display for Column { impl PhysicalExpr for Column { /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 0a3e5fcefcf6..cf57ce3e0e21 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -1102,7 +1102,7 @@ mod tests { let mut phy_exprs = vec![ lit(1i64), expressions::cast(lit(2i32), &schema, DataType::Int64)?, - expressions::try_cast(lit(3.13f32), &schema, DataType::Int64)?, + try_cast(lit(3.13f32), &schema, DataType::Int64)?, ]; let result = try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); @@ -1130,7 +1130,7 @@ mod tests { try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); // column - phy_exprs.push(expressions::col("a", &schema)?); + phy_exprs.push(col("a", &schema)?); assert!(try_cast_static_filter_to_set(&phy_exprs, &schema).is_err()); Ok(()) diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index b5ebc250cb89..399ebde9f726 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -257,7 +257,7 @@ mod tests { #[test] fn test_negation_valid_types() -> Result<()> { let negatable_types = [ - DataType::Int8, + Int8, DataType::Timestamp(TimeUnit::Second, None), DataType::Interval(IntervalUnit::YearMonth), ]; diff --git a/datafusion/physical-expr/src/expressions/unknown_column.rs b/datafusion/physical-expr/src/expressions/unknown_column.rs index cb7221e7fa15..590efd577963 100644 --- a/datafusion/physical-expr/src/expressions/unknown_column.rs +++ b/datafusion/physical-expr/src/expressions/unknown_column.rs @@ -57,7 +57,7 @@ impl std::fmt::Display for UnKnownColumn { impl PhysicalExpr for UnKnownColumn { /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index f05ac3624b8e..8084a52c78d8 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -19,6 +19,7 @@ use std::collections::HashSet; use std::fmt::{Display, Formatter}; +use std::mem::{size_of, size_of_val}; use std::sync::Arc; use super::utils::{ @@ -128,12 +129,11 @@ impl ExprIntervalGraph { /// Estimate size of bytes including `Self`. pub fn size(&self) -> usize { let node_memory_usage = self.graph.node_count() - * (std::mem::size_of::() - + std::mem::size_of::()); - let edge_memory_usage = self.graph.edge_count() - * (std::mem::size_of::() + std::mem::size_of::() * 2); + * (size_of::() + size_of::()); + let edge_memory_usage = + self.graph.edge_count() * (size_of::() + size_of::() * 2); - std::mem::size_of_val(self) + node_memory_usage + edge_memory_usage + size_of_val(self) + node_memory_usage + edge_memory_usage } } diff --git a/datafusion/physical-expr/src/partitioning.rs b/datafusion/physical-expr/src/partitioning.rs index 01f72a8efd9a..98c0c864b9f7 100644 --- a/datafusion/physical-expr/src/partitioning.rs +++ b/datafusion/physical-expr/src/partitioning.rs @@ -121,8 +121,8 @@ pub enum Partitioning { UnknownPartitioning(usize), } -impl fmt::Display for Partitioning { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { +impl Display for Partitioning { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Partitioning::RoundRobinBatch(size) => write!(f, "RoundRobinBatch({size})"), Partitioning::Hash(phy_exprs, size) => { diff --git a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/bytes.rs index f789af8b8a02..013c027e7306 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/bytes.rs @@ -19,6 +19,7 @@ use crate::aggregates::group_values::GroupValues; use arrow_array::{Array, ArrayRef, OffsetSizeTrait, RecordBatch}; use datafusion_expr::EmitTo; use datafusion_physical_expr_common::binary_map::{ArrowBytesMap, OutputType}; +use std::mem::size_of; /// A [`GroupValues`] storing single column of Utf8/LargeUtf8/Binary/LargeBinary values /// @@ -73,7 +74,7 @@ impl GroupValues for GroupValuesByes { } fn size(&self) -> usize { - self.map.size() + std::mem::size_of::() + self.map.size() + size_of::() } fn is_empty(&self) -> bool { diff --git a/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs index 1a0cb90a16d4..7379b7a538b4 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs @@ -20,6 +20,7 @@ use arrow_array::{Array, ArrayRef, RecordBatch}; use datafusion_expr::EmitTo; use datafusion_physical_expr::binary_map::OutputType; use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewMap; +use std::mem::size_of; /// A [`GroupValues`] storing single column of Utf8View/BinaryView values /// @@ -74,7 +75,7 @@ impl GroupValues for GroupValuesBytesView { } fn size(&self) -> usize { - self.map.size() + std::mem::size_of::() + self.map.size() + size_of::() } fn is_empty(&self) -> bool { diff --git a/datafusion/physical-plan/src/aggregates/group_values/column.rs b/datafusion/physical-plan/src/aggregates/group_values/column.rs index 4ad75844f7b7..958a4b58d800 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/column.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/column.rs @@ -35,8 +35,8 @@ use datafusion_common::{not_impl_err, DataFusionError, Result}; use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; use datafusion_expr::EmitTo; use datafusion_physical_expr::binary_map::OutputType; - use hashbrown::raw::RawTable; +use std::mem::size_of; /// A [`GroupValues`] that stores multiple columns of group values. /// @@ -351,7 +351,7 @@ impl GroupValues for GroupValuesColumn { self.group_values.clear(); self.map.clear(); self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared - self.map_size = self.map.capacity() * std::mem::size_of::<(u64, usize)>(); + self.map_size = self.map.capacity() * size_of::<(u64, usize)>(); self.hashes_buffer.clear(); self.hashes_buffer.shrink_to(count); } diff --git a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs index 41534958602e..bba59b6d0caa 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs @@ -37,7 +37,7 @@ use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; use arrow_array::types::GenericStringType; use datafusion_physical_expr_common::binary_map::{OutputType, INITIAL_BUFFER_CAPACITY}; use std::marker::PhantomData; -use std::mem; +use std::mem::{replace, size_of}; use std::sync::Arc; use std::vec; @@ -292,7 +292,7 @@ where } fn size(&self) -> usize { - self.buffer.capacity() * std::mem::size_of::() + self.buffer.capacity() * size_of::() + self.offsets.allocated_size() + self.nulls.allocated_size() } @@ -488,7 +488,7 @@ impl ByteViewGroupValueBuilder { // If current block isn't big enough, flush it and create a new in progress block if require_cap > self.max_block_size { - let flushed_block = mem::replace( + let flushed_block = replace( &mut self.in_progress, Vec::with_capacity(self.max_block_size), ); @@ -611,7 +611,7 @@ impl ByteViewGroupValueBuilder { // The `n == len` case, we need to take all if self.len() == n { let new_builder = Self::new().with_max_block_size(self.max_block_size); - let cur_builder = std::mem::replace(self, new_builder); + let cur_builder = replace(self, new_builder); return cur_builder.build_inner(); } @@ -759,7 +759,7 @@ impl ByteViewGroupValueBuilder { } fn flush_in_progress(&mut self) { - let flushed_block = mem::replace( + let flushed_block = replace( &mut self.in_progress, Vec::with_capacity(self.max_block_size), ); @@ -785,14 +785,14 @@ impl GroupColumn for ByteViewGroupValueBuilder { let buffers_size = self .completed .iter() - .map(|buf| buf.capacity() * std::mem::size_of::()) + .map(|buf| buf.capacity() * size_of::()) .sum::(); self.nulls.allocated_size() - + self.views.capacity() * std::mem::size_of::() - + self.in_progress.capacity() * std::mem::size_of::() + + self.views.capacity() * size_of::() + + self.in_progress.capacity() * size_of::() + buffers_size - + std::mem::size_of::() + + size_of::() } fn build(self: Box) -> ArrayRef { diff --git a/datafusion/physical-plan/src/aggregates/group_values/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/primitive.rs index d5b7f1b11ac5..05214ec10d68 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/primitive.rs @@ -30,6 +30,7 @@ use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_expr::EmitTo; use half::f16; use hashbrown::raw::RawTable; +use std::mem::size_of; use std::sync::Arc; /// A trait to allow hashing of floating point numbers @@ -151,7 +152,7 @@ where } fn size(&self) -> usize { - self.map.capacity() * std::mem::size_of::() + self.values.allocated_size() + self.map.capacity() * size_of::() + self.values.allocated_size() } fn is_empty(&self) -> bool { diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index 8ca88257bf1a..de0ae2e07dd2 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -27,6 +27,7 @@ use datafusion_common::Result; use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; use datafusion_expr::EmitTo; use hashbrown::raw::RawTable; +use std::mem::size_of; use std::sync::Arc; /// A [`GroupValues`] making use of [`Rows`] @@ -231,10 +232,8 @@ impl GroupValues for GroupValuesRows { // https://github.com/apache/datafusion/issues/7647 for (field, array) in self.schema.fields.iter().zip(&mut output) { let expected = field.data_type(); - *array = dictionary_encode_if_necessary( - Arc::::clone(array), - expected, - )?; + *array = + dictionary_encode_if_necessary(Arc::::clone(array), expected)?; } self.group_values = Some(group_values); @@ -249,7 +248,7 @@ impl GroupValues for GroupValuesRows { }); self.map.clear(); self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared - self.map_size = self.map.capacity() * std::mem::size_of::<(u64, usize)>(); + self.map_size = self.map.capacity() * size_of::<(u64, usize)>(); self.hashes_buffer.clear(); self.hashes_buffer.shrink_to(count); } @@ -267,7 +266,7 @@ fn dictionary_encode_if_necessary( .zip(struct_array.columns()) .map(|(expected_field, column)| { dictionary_encode_if_necessary( - Arc::::clone(column), + Arc::::clone(column), expected_field.data_type(), ) }) @@ -286,13 +285,13 @@ fn dictionary_encode_if_necessary( Arc::::clone(expected_field), list.offsets().clone(), dictionary_encode_if_necessary( - Arc::::clone(list.values()), + Arc::::clone(list.values()), expected_field.data_type(), )?, list.nulls().cloned(), )?)) } (DataType::Dictionary(_, _), _) => Ok(cast(array.as_ref(), expected)?), - (_, _) => Ok(Arc::::clone(&array)), + (_, _) => Ok(Arc::::clone(&array)), } } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index f36bd920e83c..48a03af19dbd 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1485,7 +1485,7 @@ mod tests { )?); let result = - common::collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; + collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; let expected = if spill { // In spill mode, we test with the limited memory, if the mem usage exceeds, @@ -1557,8 +1557,7 @@ mod tests { input_schema, )?); - let result = - common::collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; + let result = collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; let batch = concat_batches(&result[0].schema(), &result)?; assert_eq!(batch.num_columns(), 4); assert_eq!(batch.num_rows(), 12); @@ -1625,7 +1624,7 @@ mod tests { )?); let result = - common::collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; + collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; let expected = if spill { vec![ @@ -1671,7 +1670,7 @@ mod tests { } else { Arc::clone(&task_ctx) }; - let result = common::collect(merged_aggregate.execute(0, task_ctx)?).await?; + let result = collect(merged_aggregate.execute(0, task_ctx)?).await?; let batch = concat_batches(&result[0].schema(), &result)?; assert_eq!(batch.num_columns(), 2); assert_eq!(batch.num_rows(), 3); @@ -1971,7 +1970,7 @@ mod tests { } let stream: SendableRecordBatchStream = stream.into(); - let err = common::collect(stream).await.unwrap_err(); + let err = collect(stream).await.unwrap_err(); // error root cause traversal is a bit complicated, see #4172. let err = err.find_root(); @@ -2522,7 +2521,7 @@ mod tests { let input = Arc::new(MemoryExec::try_new( &[vec![batch.clone()]], - Arc::::clone(&batch.schema()), + Arc::::clone(&batch.schema()), None, )?); let aggregate_exec = Arc::new(AggregateExec::try_new( diff --git a/datafusion/physical-plan/src/aggregates/order/full.rs b/datafusion/physical-plan/src/aggregates/order/full.rs index d64c99ba1bee..218855459b1e 100644 --- a/datafusion/physical-plan/src/aggregates/order/full.rs +++ b/datafusion/physical-plan/src/aggregates/order/full.rs @@ -16,6 +16,7 @@ // under the License. use datafusion_expr::EmitTo; +use std::mem::size_of; /// Tracks grouping state when the data is ordered entirely by its /// group keys @@ -139,7 +140,7 @@ impl GroupOrderingFull { } pub(crate) fn size(&self) -> usize { - std::mem::size_of::() + size_of::() } } diff --git a/datafusion/physical-plan/src/aggregates/order/mod.rs b/datafusion/physical-plan/src/aggregates/order/mod.rs index 483150ee61af..accb2fda1131 100644 --- a/datafusion/physical-plan/src/aggregates/order/mod.rs +++ b/datafusion/physical-plan/src/aggregates/order/mod.rs @@ -20,6 +20,7 @@ use arrow_schema::Schema; use datafusion_common::Result; use datafusion_expr::EmitTo; use datafusion_physical_expr::PhysicalSortExpr; +use std::mem::size_of; mod full; mod partial; @@ -118,7 +119,7 @@ impl GroupOrdering { /// Return the size of memory used by the ordering state, in bytes pub fn size(&self) -> usize { - std::mem::size_of::() + size_of::() + match self { GroupOrdering::None => 0, GroupOrdering::Partial(partial) => partial.size(), diff --git a/datafusion/physical-plan/src/aggregates/order/partial.rs b/datafusion/physical-plan/src/aggregates/order/partial.rs index 2cbe3bbb784e..2dd1ea8a5449 100644 --- a/datafusion/physical-plan/src/aggregates/order/partial.rs +++ b/datafusion/physical-plan/src/aggregates/order/partial.rs @@ -22,6 +22,7 @@ use datafusion_common::Result; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_expr::EmitTo; use datafusion_physical_expr::PhysicalSortExpr; +use std::mem::size_of; use std::sync::Arc; /// Tracks grouping state when the data is ordered by some subset of @@ -244,7 +245,7 @@ impl GroupOrderingPartial { /// Return the size of memory allocated by this structure pub(crate) fn size(&self) -> usize { - std::mem::size_of::() + size_of::() + self.order_indices.allocated_size() + self.row_converter.size() } diff --git a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs index 232b87de3231..34df643b6cf0 100644 --- a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs +++ b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs @@ -109,7 +109,7 @@ impl StringHashTable { Self { owned, map: TopKHashTable::new(limit, limit * 10), - rnd: ahash::RandomState::default(), + rnd: RandomState::default(), } } } @@ -181,7 +181,7 @@ where Self { owned, map: TopKHashTable::new(limit, limit * 10), - rnd: ahash::RandomState::default(), + rnd: RandomState::default(), } } } diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs index 4e936fb37a12..e79b3c817bd1 100644 --- a/datafusion/physical-plan/src/display.rs +++ b/datafusion/physical-plan/src/display.rs @@ -125,7 +125,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_schema: bool, } impl<'a> fmt::Display for Wrapper<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { let mut visitor = IndentVisitor { t: self.format_type, f, @@ -164,7 +164,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_statistics: bool, } impl<'a> fmt::Display for Wrapper<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { let t = DisplayFormatType::Default; let mut visitor = GraphvizVisitor { @@ -203,7 +203,7 @@ impl<'a> DisplayableExecutionPlan<'a> { } impl<'a> fmt::Display for Wrapper<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { let mut visitor = IndentVisitor { f, t: DisplayFormatType::Default, @@ -257,7 +257,7 @@ struct IndentVisitor<'a, 'b> { /// How to format each node t: DisplayFormatType, /// Write to this formatter - f: &'a mut fmt::Formatter<'b>, + f: &'a mut Formatter<'b>, /// Indent size indent: usize, /// How to show metrics @@ -318,7 +318,7 @@ impl<'a, 'b> ExecutionPlanVisitor for IndentVisitor<'a, 'b> { } struct GraphvizVisitor<'a, 'b> { - f: &'a mut fmt::Formatter<'b>, + f: &'a mut Formatter<'b>, /// How to format each node t: DisplayFormatType, /// How to show metrics @@ -349,8 +349,8 @@ impl ExecutionPlanVisitor for GraphvizVisitor<'_, '_> { struct Wrapper<'a>(&'a dyn ExecutionPlan, DisplayFormatType); - impl<'a> std::fmt::Display for Wrapper<'a> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + impl<'a> fmt::Display for Wrapper<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { self.0.fmt_as(self.1, f) } } @@ -422,14 +422,14 @@ pub trait DisplayAs { /// different from the default one /// /// Should not include a newline - fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result; + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result; } /// A newtype wrapper to display `T` implementing`DisplayAs` using the `Default` mode pub struct DefaultDisplay(pub T); impl fmt::Display for DefaultDisplay { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { self.0.fmt_as(DisplayFormatType::Default, f) } } @@ -438,7 +438,7 @@ impl fmt::Display for DefaultDisplay { pub struct VerboseDisplay(pub T); impl fmt::Display for VerboseDisplay { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { self.0.fmt_as(DisplayFormatType::Verbose, f) } } @@ -448,7 +448,7 @@ impl fmt::Display for VerboseDisplay { pub struct ProjectSchemaDisplay<'a>(pub &'a SchemaRef); impl<'a> fmt::Display for ProjectSchemaDisplay<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { let parts: Vec<_> = self .0 .fields() @@ -464,7 +464,7 @@ impl<'a> fmt::Display for ProjectSchemaDisplay<'a> { pub struct OutputOrderingDisplay<'a>(pub &'a [PhysicalSortExpr]); impl<'a> fmt::Display for OutputOrderingDisplay<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { write!(f, "[")?; for (i, e) in self.0.iter().enumerate() { if i > 0 { diff --git a/datafusion/physical-plan/src/insert.rs b/datafusion/physical-plan/src/insert.rs index dda45ebebb0c..8b3ef5ae01e4 100644 --- a/datafusion/physical-plan/src/insert.rs +++ b/datafusion/physical-plan/src/insert.rs @@ -93,7 +93,7 @@ pub struct DataSinkExec { cache: PlanProperties, } -impl fmt::Debug for DataSinkExec { +impl Debug for DataSinkExec { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "DataSinkExec schema: {:?}", self.count_schema) } @@ -148,11 +148,7 @@ impl DataSinkExec { } impl DisplayAs for DataSinkExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { write!(f, "DataSinkExec: sink=")?; diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 8f2bef56da76..8f49885068fd 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -418,7 +418,7 @@ impl Stream for CrossJoinStream { fn poll_next( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + ) -> Poll> { self.poll_next_impl(cx) } } @@ -429,7 +429,7 @@ impl CrossJoinStream { fn poll_next_impl( &mut self, cx: &mut std::task::Context<'_>, - ) -> std::task::Poll>> { + ) -> Poll>> { loop { return match self.state { CrossJoinStreamState::WaitBuildSide => { diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 3b730c01291c..2d11e03814a3 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -18,6 +18,7 @@ //! [`HashJoinExec`] Partitioned Hash Join Operator use std::fmt; +use std::mem::size_of; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::task::Poll; @@ -849,7 +850,7 @@ async fn collect_left_input( // Estimation of memory size, required for hashtable, prior to allocation. // Final result can be verified using `RawTable.allocation_info()` - let fixed_size = std::mem::size_of::(); + let fixed_size = size_of::(); let estimated_hashtable_size = estimate_memory_size::<(u64, u64)>(num_rows, fixed_size)?; @@ -1524,7 +1525,7 @@ impl Stream for HashJoinStream { fn poll_next( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + ) -> Poll> { self.poll_next_impl(cx) } } @@ -3594,10 +3595,7 @@ mod tests { let stream = join.execute(0, task_ctx).unwrap(); // Expect that an error is returned - let result_string = crate::common::collect(stream) - .await - .unwrap_err() - .to_string(); + let result_string = common::collect(stream).await.unwrap_err().to_string(); assert!( result_string.contains("bad data error"), "actual: {result_string}" diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 7b7b7462f7e4..b299b495c504 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -26,7 +26,7 @@ use std::collections::{HashMap, VecDeque}; use std::fmt::Formatter; use std::fs::File; use std::io::BufReader; -use std::mem; +use std::mem::size_of; use std::ops::Range; use std::pin::Pin; use std::sync::atomic::AtomicUsize; @@ -411,13 +411,13 @@ struct SortMergeJoinMetrics { /// Total time for joining probe-side batches to the build-side batches join_time: metrics::Time, /// Number of batches consumed by this operator - input_batches: metrics::Count, + input_batches: Count, /// Number of rows consumed by this operator - input_rows: metrics::Count, + input_rows: Count, /// Number of batches produced by this operator - output_batches: metrics::Count, + output_batches: Count, /// Number of rows produced by this operator - output_rows: metrics::Count, + output_rows: Count, /// Peak memory used for buffered data. /// Calculated as sum of peak memory values across partitions peak_mem_used: metrics::Gauge, @@ -630,9 +630,9 @@ impl BufferedBatch { .iter() .map(|arr| arr.get_array_memory_size()) .sum::() - + batch.num_rows().next_power_of_two() * mem::size_of::() - + mem::size_of::>() - + mem::size_of::(); + + batch.num_rows().next_power_of_two() * size_of::() + + size_of::>() + + size_of::(); let num_rows = batch.num_rows(); BufferedBatch { @@ -2332,7 +2332,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_, batches) = join_collect(left, right, on, Inner).await?; let expected = [ "+----+----+----+----+----+----+", @@ -2371,7 +2371,7 @@ mod tests { ), ]; - let (_columns, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_columns, batches) = join_collect(left, right, on, Inner).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b2 | c1 | a1 | b2 | c2 |", @@ -2409,7 +2409,7 @@ mod tests { ), ]; - let (_columns, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_columns, batches) = join_collect(left, right, on, Inner).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b2 | c1 | a1 | b2 | c2 |", @@ -2448,7 +2448,7 @@ mod tests { ), ]; - let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_, batches) = join_collect(left, right, on, Inner).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b2 | c1 | a1 | b2 | c2 |", @@ -2489,7 +2489,7 @@ mod tests { left, right, on, - JoinType::Inner, + Inner, vec![ SortOptions { descending: true, @@ -2539,7 +2539,7 @@ mod tests { ]; let (_, batches) = - join_collect_batch_size_equals_two(left, right, on, JoinType::Inner).await?; + join_collect_batch_size_equals_two(left, right, on, Inner).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b2 | c1 | a1 | b2 | c2 |", @@ -2574,7 +2574,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Left).await?; + let (_, batches) = join_collect(left, right, on, Left).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b1 | c2 |", @@ -2606,7 +2606,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Right).await?; + let (_, batches) = join_collect(left, right, on, Right).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b1 | c2 |", @@ -2638,7 +2638,7 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Full).await?; + let (_, batches) = join_collect(left, right, on, Full).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -2670,7 +2670,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::LeftAnti).await?; + let (_, batches) = join_collect(left, right, on, LeftAnti).await?; let expected = [ "+----+----+----+", "| a1 | b1 | c1 |", @@ -2701,7 +2701,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::LeftSemi).await?; + let (_, batches) = join_collect(left, right, on, LeftSemi).await?; let expected = [ "+----+----+----+", "| a1 | b1 | c1 |", @@ -2734,7 +2734,7 @@ mod tests { Arc::new(Column::new_with_schema("b", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_, batches) = join_collect(left, right, on, Inner).await?; let expected = [ "+---+---+---+----+---+----+", "| a | b | c | a | b | c |", @@ -2766,7 +2766,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_, batches) = join_collect(left, right, on, Inner).await?; let expected = ["+------------+------------+------------+------------+------------+------------+", "| a1 | b1 | c1 | a2 | b1 | c2 |", @@ -2798,7 +2798,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_, batches) = join_collect(left, right, on, Inner).await?; let expected = ["+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", "| a1 | b1 | c1 | a2 | b1 | c2 |", @@ -2829,7 +2829,7 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Left).await?; + let (_, batches) = join_collect(left, right, on, Left).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -2865,7 +2865,7 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Right).await?; + let (_, batches) = join_collect(left, right, on, Right).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -2909,7 +2909,7 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Left).await?; + let (_, batches) = join_collect(left, right, on, Left).await?; let expected = vec![ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -2958,7 +2958,7 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Right).await?; + let (_, batches) = join_collect(left, right, on, Right).await?; let expected = vec![ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -3007,7 +3007,7 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Full).await?; + let (_, batches) = join_collect(left, right, on, Full).await?; let expected = vec![ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -3047,14 +3047,7 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = vec![ - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::Full, - JoinType::LeftSemi, - JoinType::LeftAnti, - ]; + let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti]; // Disable DiskManager to prevent spilling let runtime = RuntimeEnvBuilder::new() @@ -3132,14 +3125,7 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = vec![ - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::Full, - JoinType::LeftSemi, - JoinType::LeftAnti, - ]; + let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti]; // Disable DiskManager to prevent spilling let runtime = RuntimeEnvBuilder::new() @@ -3195,14 +3181,7 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = [ - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::Full, - JoinType::LeftSemi, - JoinType::LeftAnti, - ]; + let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti]; // Enable DiskManager to allow spilling let runtime = RuntimeEnvBuilder::new() @@ -3303,14 +3282,7 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = [ - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::Full, - JoinType::LeftSemi, - JoinType::LeftAnti, - ]; + let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti]; // Enable DiskManager to allow spilling let runtime = RuntimeEnvBuilder::new() diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index bddd152341da..02c71dab3df2 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -19,6 +19,7 @@ //! related functionality, used both in join calculations and optimization rules. use std::collections::{HashMap, VecDeque}; +use std::mem::size_of; use std::sync::Arc; use crate::joins::utils::{JoinFilter, JoinHashMapType}; @@ -153,8 +154,7 @@ impl PruningJoinHashMap { /// # Returns /// The size of the hash map in bytes. pub(crate) fn size(&self) -> usize { - self.map.allocation_info().1.size() - + self.next.capacity() * std::mem::size_of::() + self.map.allocation_info().1.size() + self.next.capacity() * size_of::() } /// Removes hash values from the map and the list based on the given pruning diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 70ada3892aea..eb6a30d17e92 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -27,6 +27,7 @@ use std::any::Any; use std::fmt::{self, Debug}; +use std::mem::{size_of, size_of_val}; use std::sync::Arc; use std::task::{Context, Poll}; use std::vec; @@ -604,7 +605,7 @@ impl Stream for SymmetricHashJoinStream { fn poll_next( mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, + cx: &mut Context<'_>, ) -> Poll> { self.poll_next_impl(cx) } @@ -1004,15 +1005,15 @@ pub struct OneSideHashJoiner { impl OneSideHashJoiner { pub fn size(&self) -> usize { let mut size = 0; - size += std::mem::size_of_val(self); - size += std::mem::size_of_val(&self.build_side); + size += size_of_val(self); + size += size_of_val(&self.build_side); size += self.input_buffer.get_array_memory_size(); - size += std::mem::size_of_val(&self.on); + size += size_of_val(&self.on); size += self.hashmap.size(); - size += self.hashes_buffer.capacity() * std::mem::size_of::(); - size += self.visited_rows.capacity() * std::mem::size_of::(); - size += std::mem::size_of_val(&self.offset); - size += std::mem::size_of_val(&self.deleted_offset); + size += self.hashes_buffer.capacity() * size_of::(); + size += self.visited_rows.capacity() * size_of::(); + size += size_of_val(&self.offset); + size += size_of_val(&self.deleted_offset); size } pub fn new( @@ -1463,18 +1464,18 @@ impl SymmetricHashJoinStream { fn size(&self) -> usize { let mut size = 0; - size += std::mem::size_of_val(&self.schema); - size += std::mem::size_of_val(&self.filter); - size += std::mem::size_of_val(&self.join_type); + size += size_of_val(&self.schema); + size += size_of_val(&self.filter); + size += size_of_val(&self.join_type); size += self.left.size(); size += self.right.size(); - size += std::mem::size_of_val(&self.column_indices); + size += size_of_val(&self.column_indices); size += self.graph.as_ref().map(|g| g.size()).unwrap_or(0); - size += std::mem::size_of_val(&self.left_sorted_filter_expr); - size += std::mem::size_of_val(&self.right_sorted_filter_expr); - size += std::mem::size_of_val(&self.random_state); - size += std::mem::size_of_val(&self.null_equals_null); - size += std::mem::size_of_val(&self.metrics); + size += size_of_val(&self.left_sorted_filter_expr); + size += size_of_val(&self.right_sorted_filter_expr); + size += size_of_val(&self.random_state); + size += size_of_val(&self.null_equals_null); + size += size_of_val(&self.metrics); size } diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 17a32a67c743..090cf9aa628a 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -369,7 +369,7 @@ impl JoinHashMapType for JoinHashMap { } } -impl fmt::Debug for JoinHashMap { +impl Debug for JoinHashMap { fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { Ok(()) } @@ -727,8 +727,8 @@ impl Default for OnceAsync { } } -impl std::fmt::Debug for OnceAsync { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl Debug for OnceAsync { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "OnceAsync") } } @@ -1952,13 +1952,13 @@ mod tests { ) -> Statistics { Statistics { num_rows: if is_exact { - num_rows.map(Precision::Exact) + num_rows.map(Exact) } else { - num_rows.map(Precision::Inexact) + num_rows.map(Inexact) } - .unwrap_or(Precision::Absent), + .unwrap_or(Absent), column_statistics: column_stats, - total_byte_size: Precision::Absent, + total_byte_size: Absent, } } @@ -2204,17 +2204,17 @@ mod tests { assert_eq!( estimate_inner_join_cardinality( Statistics { - num_rows: Precision::Inexact(400), - total_byte_size: Precision::Absent, + num_rows: Inexact(400), + total_byte_size: Absent, column_statistics: left_col_stats, }, Statistics { - num_rows: Precision::Inexact(400), - total_byte_size: Precision::Absent, + num_rows: Inexact(400), + total_byte_size: Absent, column_statistics: right_col_stats, }, ), - Some(Precision::Inexact((400 * 400) / 200)) + Some(Inexact((400 * 400) / 200)) ); Ok(()) } @@ -2222,33 +2222,33 @@ mod tests { #[test] fn test_inner_join_cardinality_decimal_range() -> Result<()> { let left_col_stats = vec![ColumnStatistics { - distinct_count: Precision::Absent, - min_value: Precision::Inexact(ScalarValue::Decimal128(Some(32500), 14, 4)), - max_value: Precision::Inexact(ScalarValue::Decimal128(Some(35000), 14, 4)), + distinct_count: Absent, + min_value: Inexact(ScalarValue::Decimal128(Some(32500), 14, 4)), + max_value: Inexact(ScalarValue::Decimal128(Some(35000), 14, 4)), ..Default::default() }]; let right_col_stats = vec![ColumnStatistics { - distinct_count: Precision::Absent, - min_value: Precision::Inexact(ScalarValue::Decimal128(Some(33500), 14, 4)), - max_value: Precision::Inexact(ScalarValue::Decimal128(Some(34000), 14, 4)), + distinct_count: Absent, + min_value: Inexact(ScalarValue::Decimal128(Some(33500), 14, 4)), + max_value: Inexact(ScalarValue::Decimal128(Some(34000), 14, 4)), ..Default::default() }]; assert_eq!( estimate_inner_join_cardinality( Statistics { - num_rows: Precision::Inexact(100), - total_byte_size: Precision::Absent, + num_rows: Inexact(100), + total_byte_size: Absent, column_statistics: left_col_stats, }, Statistics { - num_rows: Precision::Inexact(100), - total_byte_size: Precision::Absent, + num_rows: Inexact(100), + total_byte_size: Absent, column_statistics: right_col_stats, }, ), - Some(Precision::Inexact(100)) + Some(Inexact(100)) ); Ok(()) } diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs index eda75b37fe66..1fe550a93056 100644 --- a/datafusion/physical-plan/src/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -473,7 +473,7 @@ mod tests { use super::*; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::common::collect; - use crate::{common, test}; + use crate::test; use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use arrow_array::RecordBatchOptions; @@ -497,7 +497,7 @@ mod tests { // The result should contain 4 batches (one per input partition) let iter = limit.execute(0, task_ctx)?; - let batches = common::collect(iter).await?; + let batches = collect(iter).await?; // There should be a total of 100 rows let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum(); @@ -613,7 +613,7 @@ mod tests { // The result should contain 4 batches (one per input partition) let iter = offset.execute(0, task_ctx)?; - let batches = common::collect(iter).await?; + let batches = collect(iter).await?; Ok(batches.iter().map(|batch| batch.num_rows()).sum()) } diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 52a8631d5a63..dd4868d1bfcc 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -69,11 +69,7 @@ impl fmt::Debug for MemoryExec { } impl DisplayAs for MemoryExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { let partition_sizes: Vec<_> = diff --git a/datafusion/physical-plan/src/metrics/value.rs b/datafusion/physical-plan/src/metrics/value.rs index 5a335d9f99cd..2eb01914ee0a 100644 --- a/datafusion/physical-plan/src/metrics/value.rs +++ b/datafusion/physical-plan/src/metrics/value.rs @@ -168,7 +168,7 @@ impl PartialEq for Time { impl Display for Time { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - let duration = std::time::Duration::from_nanos(self.value() as u64); + let duration = Duration::from_nanos(self.value() as u64); write!(f, "{duration:?}") } } diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index 936cf742a792..c1d3f368366f 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -356,7 +356,6 @@ impl RecordBatchStream for ProjectionStream { mod tests { use super::*; use crate::common::collect; - use crate::expressions; use crate::test; use arrow_schema::DataType; @@ -418,8 +417,8 @@ mod tests { let schema = get_schema(); let exprs: Vec> = vec![ - Arc::new(expressions::Column::new("col1", 1)), - Arc::new(expressions::Column::new("col0", 0)), + Arc::new(Column::new("col1", 1)), + Arc::new(Column::new("col0", 0)), ]; let result = stats_projection(source, exprs.into_iter(), Arc::new(schema)); @@ -452,8 +451,8 @@ mod tests { let schema = get_schema(); let exprs: Vec> = vec![ - Arc::new(expressions::Column::new("col2", 2)), - Arc::new(expressions::Column::new("col0", 0)), + Arc::new(Column::new("col2", 2)), + Arc::new(Column::new("col0", 0)), ]; let result = stats_projection(source, exprs.into_iter(), Arc::new(schema)); diff --git a/datafusion/physical-plan/src/repartition/distributor_channels.rs b/datafusion/physical-plan/src/repartition/distributor_channels.rs index 675d26bbfb9f..2e5ef24beac3 100644 --- a/datafusion/physical-plan/src/repartition/distributor_channels.rs +++ b/datafusion/physical-plan/src/repartition/distributor_channels.rs @@ -829,7 +829,7 @@ mod tests { { let test_waker = Arc::new(TestWaker::default()); let waker = futures::task::waker(Arc::clone(&test_waker)); - let mut cx = std::task::Context::from_waker(&waker); + let mut cx = Context::from_waker(&waker); let res = fut.poll_unpin(&mut cx); (res, test_waker) } diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 90e62d6f11f8..601c1e873152 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -1326,7 +1326,7 @@ mod tests { // now, purposely drop output stream 0 // *before* any outputs are produced - std::mem::drop(output_stream0); + drop(output_stream0); // Now, start sending input let mut background_task = JoinSet::new(); @@ -1401,7 +1401,7 @@ mod tests { let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap(); // now, purposely drop output stream 0 // *before* any outputs are produced - std::mem::drop(output_stream0); + drop(output_stream0); let mut background_task = JoinSet::new(); background_task.spawn(async move { input.wait().await; diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 8e13a2e07e49..921678a4ad92 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -815,11 +815,7 @@ impl SortExec { } impl DisplayAs for SortExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { let expr = PhysicalSortExpr::format_list(&self.expr); @@ -1018,7 +1014,7 @@ mod tests { } impl DisplayAs for SortedUnboundedExec { - fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { write!(f, "UnboundableExec",).unwrap() diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index 3d3f9dcb98ee..31a4ed61cf9e 100644 --- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -946,7 +946,7 @@ mod tests { while let Some(batch) = stream.next().await { sender.send(batch).await.unwrap(); // This causes the MergeStream to wait for more input - tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + tokio::time::sleep(Duration::from_millis(10)).await; } Ok(()) diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs index 9220646653e6..ec4c9dd502a6 100644 --- a/datafusion/physical-plan/src/stream.rs +++ b/datafusion/physical-plan/src/stream.rs @@ -437,12 +437,12 @@ impl ObservedStream { } impl RecordBatchStream for ObservedStream { - fn schema(&self) -> arrow::datatypes::SchemaRef { + fn schema(&self) -> SchemaRef { self.inner.schema() } } -impl futures::Stream for ObservedStream { +impl Stream for ObservedStream { type Item = Result; fn poll_next( diff --git a/datafusion/physical-plan/src/streaming.rs b/datafusion/physical-plan/src/streaming.rs index 0f7c75c2c90b..cdb94af1fe8a 100644 --- a/datafusion/physical-plan/src/streaming.rs +++ b/datafusion/physical-plan/src/streaming.rs @@ -163,7 +163,7 @@ impl StreamingTableExec { } } -impl std::fmt::Debug for StreamingTableExec { +impl Debug for StreamingTableExec { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("LazyMemTableExec").finish_non_exhaustive() } diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index d3f1a4fd96ca..9b46ad2ec7b1 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -21,6 +21,7 @@ use arrow::{ compute::interleave, row::{RowConverter, Rows, SortField}, }; +use std::mem::size_of; use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc}; use arrow_array::{Array, ArrayRef, RecordBatch}; @@ -225,7 +226,7 @@ impl TopK { /// return the size of memory used by this operator, in bytes fn size(&self) -> usize { - std::mem::size_of::() + size_of::() + self.row_converter.size() + self.scratch_rows.size() + self.heap.size() @@ -444,8 +445,8 @@ impl TopKHeap { /// return the size of memory used by this heap, in bytes fn size(&self) -> usize { - std::mem::size_of::() - + (self.inner.capacity() * std::mem::size_of::()) + size_of::() + + (self.inner.capacity() * size_of::()) + self.store.size() + self.owned_bytes } @@ -636,9 +637,8 @@ impl RecordBatchStore { /// returns the size of memory used by this store, including all /// referenced `RecordBatch`es, in bytes pub fn size(&self) -> usize { - std::mem::size_of::() - + self.batches.capacity() - * (std::mem::size_of::() + std::mem::size_of::()) + size_of::() + + self.batches.capacity() * (size_of::() + size_of::()) + self.batches_size } } diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs index 40ec3830ea0c..3e312b7451be 100644 --- a/datafusion/physical-plan/src/unnest.rs +++ b/datafusion/physical-plan/src/unnest.rs @@ -984,7 +984,7 @@ mod tests { list_array: &dyn ListArrayType, lengths: Vec, expected: Vec>, - ) -> datafusion_common::Result<()> { + ) -> Result<()> { let length_array = Int64Array::from(lengths); let unnested_array = unnest_list_array(list_array, &length_array, 3 * 6)?; let strs = unnested_array.as_string::().iter().collect::>(); @@ -993,7 +993,7 @@ mod tests { } #[test] - fn test_build_batch_list_arr_recursive() -> datafusion_common::Result<()> { + fn test_build_batch_list_arr_recursive() -> Result<()> { // col1 | col2 // [[1,2,3],null,[4,5]] | ['a','b'] // [[7,8,9,10], null, [11,12,13]] | ['c','d'] @@ -1125,7 +1125,7 @@ mod tests { } #[test] - fn test_unnest_list_array() -> datafusion_common::Result<()> { + fn test_unnest_list_array() -> Result<()> { // [A, B, C], [], NULL, [D], NULL, [NULL, F] let list_array = make_generic_array::(); verify_unnest_list_array( @@ -1173,7 +1173,7 @@ mod tests { list_arrays: &[ArrayRef], preserve_nulls: bool, expected: Vec, - ) -> datafusion_common::Result<()> { + ) -> Result<()> { let options = UnnestOptions { preserve_nulls, recursions: vec![], @@ -1191,7 +1191,7 @@ mod tests { } #[test] - fn test_longest_list_length() -> datafusion_common::Result<()> { + fn test_longest_list_length() -> Result<()> { // Test with single ListArray // [A, B, C], [], NULL, [D], NULL, [NULL, F] let list_array = Arc::new(make_generic_array::()) as ArrayRef; @@ -1223,7 +1223,7 @@ mod tests { } #[test] - fn test_create_take_indicies() -> datafusion_common::Result<()> { + fn test_create_take_indicies() -> Result<()> { let length_array = Int64Array::from(vec![2, 3, 1]); let take_indicies = create_take_indicies(&length_array, 6); let expected = Int64Array::from(vec![0, 0, 1, 1, 1, 2]); diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index 3ffe5e3e76e7..9e4b331a01bf 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -32,9 +32,6 @@ rust-version = "1.79" # Exclude proto files so crates.io consumers don't need protoc exclude = ["*.proto"] -[lints] -workspace = true - [lib] name = "datafusion_proto" path = "src/lib.rs" diff --git a/datafusion/proto/src/logical_plan/file_formats.rs b/datafusion/proto/src/logical_plan/file_formats.rs index d0f82ecac62c..02be3e11c1cb 100644 --- a/datafusion/proto/src/logical_plan/file_formats.rs +++ b/datafusion/proto/src/logical_plan/file_formats.rs @@ -161,7 +161,7 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec { &self, _buf: &[u8], _inputs: &[datafusion_expr::LogicalPlan], - _ctx: &datafusion::prelude::SessionContext, + _ctx: &SessionContext, ) -> datafusion_common::Result { not_impl_err!("Method not implemented") } @@ -179,7 +179,7 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec { _buf: &[u8], _table_ref: &TableReference, _schema: arrow::datatypes::SchemaRef, - _ctx: &datafusion::prelude::SessionContext, + _ctx: &SessionContext, ) -> datafusion_common::Result> { not_impl_err!("Method not implemented") } @@ -271,7 +271,7 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec { &self, _buf: &[u8], _inputs: &[datafusion_expr::LogicalPlan], - _ctx: &datafusion::prelude::SessionContext, + _ctx: &SessionContext, ) -> datafusion_common::Result { not_impl_err!("Method not implemented") } @@ -289,7 +289,7 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec { _buf: &[u8], _table_ref: &TableReference, _schema: arrow::datatypes::SchemaRef, - _ctx: &datafusion::prelude::SessionContext, + _ctx: &SessionContext, ) -> datafusion_common::Result> { not_impl_err!("Method not implemented") } @@ -570,7 +570,7 @@ impl LogicalExtensionCodec for ParquetLogicalExtensionCodec { &self, _buf: &[u8], _inputs: &[datafusion_expr::LogicalPlan], - _ctx: &datafusion::prelude::SessionContext, + _ctx: &SessionContext, ) -> datafusion_common::Result { not_impl_err!("Method not implemented") } @@ -588,7 +588,7 @@ impl LogicalExtensionCodec for ParquetLogicalExtensionCodec { _buf: &[u8], _table_ref: &TableReference, _schema: arrow::datatypes::SchemaRef, - _ctx: &datafusion::prelude::SessionContext, + _ctx: &SessionContext, ) -> datafusion_common::Result> { not_impl_err!("Method not implemented") } @@ -658,7 +658,7 @@ impl LogicalExtensionCodec for ArrowLogicalExtensionCodec { &self, _buf: &[u8], _inputs: &[datafusion_expr::LogicalPlan], - _ctx: &datafusion::prelude::SessionContext, + _ctx: &SessionContext, ) -> datafusion_common::Result { not_impl_err!("Method not implemented") } @@ -676,7 +676,7 @@ impl LogicalExtensionCodec for ArrowLogicalExtensionCodec { _buf: &[u8], _table_ref: &TableReference, _schema: arrow::datatypes::SchemaRef, - _ctx: &datafusion::prelude::SessionContext, + _ctx: &SessionContext, ) -> datafusion_common::Result> { not_impl_err!("Method not implemented") } @@ -716,7 +716,7 @@ impl LogicalExtensionCodec for AvroLogicalExtensionCodec { &self, _buf: &[u8], _inputs: &[datafusion_expr::LogicalPlan], - _ctx: &datafusion::prelude::SessionContext, + _ctx: &SessionContext, ) -> datafusion_common::Result { not_impl_err!("Method not implemented") } @@ -734,7 +734,7 @@ impl LogicalExtensionCodec for AvroLogicalExtensionCodec { _buf: &[u8], _table_ref: &TableReference, _schema: arrow::datatypes::SchemaRef, - _cts: &datafusion::prelude::SessionContext, + _cts: &SessionContext, ) -> datafusion_common::Result> { not_impl_err!("Method not implemented") } diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index d80c6b716537..b90ae88aa74a 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -450,7 +450,7 @@ impl AsLogicalPlan for LogicalPlanNode { )? .build() } - LogicalPlanType::CustomScan(scan) => { + CustomScan(scan) => { let schema: Schema = convert_required!(scan.schema)?; let schema = Arc::new(schema); let mut projection = None; @@ -844,13 +844,13 @@ impl AsLogicalPlan for LogicalPlanNode { .prepare(prepare.name.clone(), data_types)? .build() } - LogicalPlanType::DropView(dropview) => Ok(datafusion_expr::LogicalPlan::Ddl( - datafusion_expr::DdlStatement::DropView(DropView { + LogicalPlanType::DropView(dropview) => { + Ok(LogicalPlan::Ddl(DdlStatement::DropView(DropView { name: from_table_reference(dropview.name.as_ref(), "DropView")?, if_exists: dropview.if_exists, schema: Arc::new(convert_required!(dropview.schema)?), - }), - )), + }))) + } LogicalPlanType::CopyTo(copy) => { let input: LogicalPlan = into_logical_plan!(copy.input, ctx, extension_codec)?; @@ -859,20 +859,18 @@ impl AsLogicalPlan for LogicalPlanNode { extension_codec.try_decode_file_format(©.file_type, ctx)?, ); - Ok(datafusion_expr::LogicalPlan::Copy( - datafusion_expr::dml::CopyTo { - input: Arc::new(input), - output_url: copy.output_url.clone(), - partition_by: copy.partition_by.clone(), - file_type, - options: Default::default(), - }, - )) + Ok(LogicalPlan::Copy(dml::CopyTo { + input: Arc::new(input), + output_url: copy.output_url.clone(), + partition_by: copy.partition_by.clone(), + file_type, + options: Default::default(), + })) } LogicalPlanType::Unnest(unnest) => { let input: LogicalPlan = into_logical_plan!(unnest.input, ctx, extension_codec)?; - Ok(datafusion_expr::LogicalPlan::Unnest(Unnest { + Ok(LogicalPlan::Unnest(Unnest { input: Arc::new(input), exec_columns: unnest.exec_columns.iter().map(|c| c.into()).collect(), list_type_columns: unnest @@ -926,7 +924,7 @@ impl AsLogicalPlan for LogicalPlanNode { } as u64; let values_list = serialize_exprs(values.iter().flatten(), extension_codec)?; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Values( protobuf::ValuesNode { n_cols, @@ -1018,7 +1016,7 @@ impl AsLogicalPlan for LogicalPlanNode { exprs_vec.push(expr_vec); } - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::ListingScan( protobuf::ListingTableScanNode { file_format_type: Some(file_format_type), @@ -1044,12 +1042,12 @@ impl AsLogicalPlan for LogicalPlanNode { )), }) } else if let Some(view_table) = source.downcast_ref::() { - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::ViewScan(Box::new( protobuf::ViewTableScanNode { table_name: Some(table_name.clone().into()), input: Some(Box::new( - protobuf::LogicalPlanNode::try_from_logical_plan( + LogicalPlanNode::try_from_logical_plan( view_table.logical_plan(), extension_codec, )?, @@ -1082,11 +1080,11 @@ impl AsLogicalPlan for LogicalPlanNode { } } LogicalPlan::Projection(Projection { expr, input, .. }) => { - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Projection(Box::new( protobuf::ProjectionNode { input: Some(Box::new( - protobuf::LogicalPlanNode::try_from_logical_plan( + LogicalPlanNode::try_from_logical_plan( input.as_ref(), extension_codec, )?, @@ -1098,12 +1096,11 @@ impl AsLogicalPlan for LogicalPlanNode { }) } LogicalPlan::Filter(filter) => { - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - filter.input.as_ref(), - extension_codec, - )?; - Ok(protobuf::LogicalPlanNode { + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + filter.input.as_ref(), + extension_codec, + )?; + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Selection(Box::new( protobuf::SelectionNode { input: Some(Box::new(input)), @@ -1116,12 +1113,11 @@ impl AsLogicalPlan for LogicalPlanNode { }) } LogicalPlan::Distinct(Distinct::All(input)) => { - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), - extension_codec, - )?; - Ok(protobuf::LogicalPlanNode { + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Distinct(Box::new( protobuf::DistinctNode { input: Some(Box::new(input)), @@ -1136,16 +1132,15 @@ impl AsLogicalPlan for LogicalPlanNode { input, .. })) => { - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), - extension_codec, - )?; + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; let sort_expr = match sort_expr { None => vec![], Some(sort_expr) => serialize_sorts(sort_expr, extension_codec)?, }; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::DistinctOn(Box::new( protobuf::DistinctOnNode { on_expr: serialize_exprs(on_expr, extension_codec)?, @@ -1159,12 +1154,11 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlan::Window(Window { input, window_expr, .. }) => { - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), - extension_codec, - )?; - Ok(protobuf::LogicalPlanNode { + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Window(Box::new( protobuf::WindowNode { input: Some(Box::new(input)), @@ -1179,12 +1173,11 @@ impl AsLogicalPlan for LogicalPlanNode { input, .. }) => { - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), - extension_codec, - )?; - Ok(protobuf::LogicalPlanNode { + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Aggregate(Box::new( protobuf::AggregateNode { input: Some(Box::new(input)), @@ -1204,16 +1197,14 @@ impl AsLogicalPlan for LogicalPlanNode { null_equals_null, .. }) => { - let left: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - left.as_ref(), - extension_codec, - )?; - let right: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - right.as_ref(), - extension_codec, - )?; + let left: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + left.as_ref(), + extension_codec, + )?; + let right: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + right.as_ref(), + extension_codec, + )?; let (left_join_key, right_join_key) = on .iter() .map(|(l, r)| { @@ -1232,7 +1223,7 @@ impl AsLogicalPlan for LogicalPlanNode { .as_ref() .map(|e| serialize_expr(e, extension_codec)) .map_or(Ok(None), |v| v.map(Some))?; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Join(Box::new( protobuf::JoinNode { left: Some(Box::new(left)), @@ -1251,12 +1242,11 @@ impl AsLogicalPlan for LogicalPlanNode { not_impl_err!("LogicalPlan serde is not yet implemented for subqueries") } LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => { - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), - extension_codec, - )?; - Ok(protobuf::LogicalPlanNode { + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::SubqueryAlias(Box::new( protobuf::SubqueryAliasNode { input: Some(Box::new(input)), @@ -1266,11 +1256,10 @@ impl AsLogicalPlan for LogicalPlanNode { }) } LogicalPlan::Limit(limit) => { - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - limit.input.as_ref(), - extension_codec, - )?; + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + limit.input.as_ref(), + extension_codec, + )?; let SkipType::Literal(skip) = limit.get_skip_type()? else { return Err(proto_error( "LogicalPlan::Limit only supports literal skip values", @@ -1282,7 +1271,7 @@ impl AsLogicalPlan for LogicalPlanNode { )); }; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Limit(Box::new( protobuf::LimitNode { input: Some(Box::new(input)), @@ -1293,14 +1282,13 @@ impl AsLogicalPlan for LogicalPlanNode { }) } LogicalPlan::Sort(Sort { input, expr, fetch }) => { - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), - extension_codec, - )?; + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; let sort_expr: Vec = serialize_sorts(expr, extension_codec)?; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Sort(Box::new( protobuf::SortNode { input: Some(Box::new(input)), @@ -1315,11 +1303,10 @@ impl AsLogicalPlan for LogicalPlanNode { partitioning_scheme, }) => { use datafusion::logical_expr::Partitioning; - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), - extension_codec, - )?; + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; // Assumed common usize field was batch size // Used u64 to avoid any nastyness involving large values, most data clusters are probably uniformly 64 bits any ways @@ -1340,7 +1327,7 @@ impl AsLogicalPlan for LogicalPlanNode { } }; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Repartition(Box::new( protobuf::RepartitionNode { input: Some(Box::new(input)), @@ -1351,7 +1338,7 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row, .. - }) => Ok(protobuf::LogicalPlanNode { + }) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::EmptyRelation( protobuf::EmptyRelationNode { produce_one_row: *produce_one_row, @@ -1390,7 +1377,7 @@ impl AsLogicalPlan for LogicalPlanNode { .insert(col_name.clone(), serialize_expr(expr, extension_codec)?); } - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateExternalTable( protobuf::CreateExternalTableNode { name: Some(name.clone().into()), @@ -1416,7 +1403,7 @@ impl AsLogicalPlan for LogicalPlanNode { or_replace, definition, temporary, - })) => Ok(protobuf::LogicalPlanNode { + })) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateView(Box::new( protobuf::CreateViewNode { name: Some(name.clone().into()), @@ -1436,7 +1423,7 @@ impl AsLogicalPlan for LogicalPlanNode { if_not_exists, schema: df_schema, }, - )) => Ok(protobuf::LogicalPlanNode { + )) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateCatalogSchema( protobuf::CreateCatalogSchemaNode { schema_name: schema_name.clone(), @@ -1449,7 +1436,7 @@ impl AsLogicalPlan for LogicalPlanNode { catalog_name, if_not_exists, schema: df_schema, - })) => Ok(protobuf::LogicalPlanNode { + })) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateCatalog( protobuf::CreateCatalogNode { catalog_name: catalog_name.clone(), @@ -1459,11 +1446,11 @@ impl AsLogicalPlan for LogicalPlanNode { )), }), LogicalPlan::Analyze(a) => { - let input = protobuf::LogicalPlanNode::try_from_logical_plan( + let input = LogicalPlanNode::try_from_logical_plan( a.input.as_ref(), extension_codec, )?; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Analyze(Box::new( protobuf::AnalyzeNode { input: Some(Box::new(input)), @@ -1473,11 +1460,11 @@ impl AsLogicalPlan for LogicalPlanNode { }) } LogicalPlan::Explain(a) => { - let input = protobuf::LogicalPlanNode::try_from_logical_plan( + let input = LogicalPlanNode::try_from_logical_plan( a.plan.as_ref(), extension_codec, )?; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Explain(Box::new( protobuf::ExplainNode { input: Some(Box::new(input)), @@ -1490,14 +1477,9 @@ impl AsLogicalPlan for LogicalPlanNode { let inputs: Vec = union .inputs .iter() - .map(|i| { - protobuf::LogicalPlanNode::try_from_logical_plan( - i, - extension_codec, - ) - }) + .map(|i| LogicalPlanNode::try_from_logical_plan(i, extension_codec)) .collect::>()?; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Union( protobuf::UnionNode { inputs }, )), @@ -1511,15 +1493,10 @@ impl AsLogicalPlan for LogicalPlanNode { .node .inputs() .iter() - .map(|i| { - protobuf::LogicalPlanNode::try_from_logical_plan( - i, - extension_codec, - ) - }) + .map(|i| LogicalPlanNode::try_from_logical_plan(i, extension_codec)) .collect::>()?; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Extension( LogicalExtensionNode { node: buf, inputs }, )), @@ -1530,11 +1507,9 @@ impl AsLogicalPlan for LogicalPlanNode { data_types, input, }) => { - let input = protobuf::LogicalPlanNode::try_from_logical_plan( - input, - extension_codec, - )?; - Ok(protobuf::LogicalPlanNode { + let input = + LogicalPlanNode::try_from_logical_plan(input, extension_codec)?; + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Prepare(Box::new( protobuf::PrepareNode { name: name.clone(), @@ -1556,10 +1531,8 @@ impl AsLogicalPlan for LogicalPlanNode { schema, options, }) => { - let input = protobuf::LogicalPlanNode::try_from_logical_plan( - input, - extension_codec, - )?; + let input = + LogicalPlanNode::try_from_logical_plan(input, extension_codec)?; let proto_unnest_list_items = list_type_columns .iter() .map(|(index, ul)| ColumnUnnestListItem { @@ -1570,7 +1543,7 @@ impl AsLogicalPlan for LogicalPlanNode { }), }) .collect(); - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Unnest(Box::new( protobuf::UnnestNode { input: Some(Box::new(input)), @@ -1606,7 +1579,7 @@ impl AsLogicalPlan for LogicalPlanNode { name, if_exists, schema, - })) => Ok(protobuf::LogicalPlanNode { + })) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::DropView( protobuf::DropViewNode { name: Some(name.clone().into()), @@ -1637,15 +1610,13 @@ impl AsLogicalPlan for LogicalPlanNode { partition_by, .. }) => { - let input = protobuf::LogicalPlanNode::try_from_logical_plan( - input, - extension_codec, - )?; + let input = + LogicalPlanNode::try_from_logical_plan(input, extension_codec)?; let mut buf = Vec::new(); extension_codec .try_encode_file_format(&mut buf, file_type_to_format(file_type)?)?; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CopyTo(Box::new( protobuf::CopyToNode { input: Some(Box::new(input)), diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 634ae284c955..326c7acab392 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -851,7 +851,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { "physical_plan::from_proto() Unexpected expr {self:?}" )) })?; - if let protobuf::physical_expr_node::ExprType::Sort(sort_expr) = expr { + if let ExprType::Sort(sort_expr) = expr { let expr = sort_expr .expr .as_ref() @@ -898,7 +898,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { "physical_plan::from_proto() Unexpected expr {self:?}" )) })?; - if let protobuf::physical_expr_node::ExprType::Sort(sort_expr) = expr { + if let ExprType::Sort(sort_expr) = expr { let expr = sort_expr .expr .as_ref() @@ -1713,9 +1713,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { nulls_first: expr.options.nulls_first, }); Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Sort( - sort_expr, - )), + expr_type: Some(ExprType::Sort(sort_expr)), }) }) .collect::>>()?; @@ -1782,9 +1780,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { nulls_first: expr.options.nulls_first, }); Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Sort( - sort_expr, - )), + expr_type: Some(ExprType::Sort(sort_expr)), }) }) .collect::>>()?; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 3fec7d1c6ea0..14d91913e7cd 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -2171,7 +2171,7 @@ fn roundtrip_aggregate_udf() { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } @@ -2395,7 +2395,7 @@ fn roundtrip_window() { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 34e119c45fdf..432e8668c52e 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -57,7 +57,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { enum StackEntry { SQLExpr(Box), - Operator(sqlparser::ast::BinaryOperator), + Operator(BinaryOperator), } // Virtual stack machine to convert SQLExpr to Expr diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 29852be3bf77..abb9912b712a 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -99,7 +99,7 @@ fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec constraints.push(ast::TableConstraint::Unique { + } => constraints.push(TableConstraint::Unique { name: name.clone(), columns: vec![column.name.clone()], characteristics: *characteristics, @@ -111,7 +111,7 @@ fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec constraints.push(ast::TableConstraint::PrimaryKey { + } => constraints.push(TableConstraint::PrimaryKey { name: name.clone(), columns: vec![column.name.clone()], characteristics: *characteristics, @@ -125,7 +125,7 @@ fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec constraints.push(ast::TableConstraint::ForeignKey { + } => constraints.push(TableConstraint::ForeignKey { name: name.clone(), columns: vec![], foreign_table: foreign_table.clone(), @@ -135,7 +135,7 @@ fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec { - constraints.push(ast::TableConstraint::Check { + constraints.push(TableConstraint::Check { name: name.clone(), expr: Box::new(expr.clone()), }) @@ -776,7 +776,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } let isolation_level: ast::TransactionIsolationLevel = modes .iter() - .filter_map(|m: &ast::TransactionMode| match m { + .filter_map(|m: &TransactionMode| match m { TransactionMode::AccessMode(_) => None, TransactionMode::IsolationLevel(level) => Some(level), }) @@ -785,7 +785,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .unwrap_or(ast::TransactionIsolationLevel::Serializable); let access_mode: ast::TransactionAccessMode = modes .iter() - .filter_map(|m: &ast::TransactionMode| match m { + .filter_map(|m: &TransactionMode| match m { TransactionMode::AccessMode(mode) => Some(mode), TransactionMode::IsolationLevel(_) => None, }) @@ -1650,7 +1650,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { None => { // If the target table has an alias, use it to qualify the column name if let Some(alias) = &table_alias { - datafusion_expr::Expr::Column(Column::new( + Expr::Column(Column::new( Some(self.ident_normalizer.normalize(alias.name.clone())), field.name(), )) diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index 02934a004d6f..88159ab6df15 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -59,8 +59,8 @@ pub trait Dialect: Send + Sync { /// Does the dialect use DOUBLE PRECISION to represent Float64 rather than DOUBLE? /// E.g. Postgres uses DOUBLE PRECISION instead of DOUBLE - fn float64_ast_dtype(&self) -> sqlparser::ast::DataType { - sqlparser::ast::DataType::Double + fn float64_ast_dtype(&self) -> ast::DataType { + ast::DataType::Double } /// The SQL type to use for Arrow Utf8 unparsing @@ -110,8 +110,8 @@ pub trait Dialect: Send + Sync { /// The SQL type to use for Arrow Date32 unparsing /// Most dialects use Date, but some, like SQLite require TEXT - fn date32_cast_dtype(&self) -> sqlparser::ast::DataType { - sqlparser::ast::DataType::Date + fn date32_cast_dtype(&self) -> ast::DataType { + ast::DataType::Date } /// Does the dialect support specifying column aliases as part of alias table definition? @@ -197,8 +197,8 @@ impl Dialect for PostgreSqlDialect { IntervalStyle::PostgresVerbose } - fn float64_ast_dtype(&self) -> sqlparser::ast::DataType { - sqlparser::ast::DataType::DoublePrecision + fn float64_ast_dtype(&self) -> ast::DataType { + ast::DataType::DoublePrecision } fn scalar_function_to_sql_overrides( @@ -245,7 +245,7 @@ impl PostgreSqlDialect { } Ok(ast::Expr::Function(Function { - name: ast::ObjectName(vec![Ident { + name: ObjectName(vec![Ident { value: func_name.to_string(), quote_style: None, }]), @@ -335,8 +335,8 @@ impl Dialect for SqliteDialect { DateFieldExtractStyle::Strftime } - fn date32_cast_dtype(&self) -> sqlparser::ast::DataType { - sqlparser::ast::DataType::Text + fn date32_cast_dtype(&self) -> ast::DataType { + ast::DataType::Text } fn supports_column_alias_in_table_alias(&self) -> bool { @@ -362,7 +362,7 @@ pub struct CustomDialect { supports_nulls_first_in_sort: bool, use_timestamp_for_date64: bool, interval_style: IntervalStyle, - float64_ast_dtype: sqlparser::ast::DataType, + float64_ast_dtype: ast::DataType, utf8_cast_dtype: ast::DataType, large_utf8_cast_dtype: ast::DataType, date_field_extract_style: DateFieldExtractStyle, @@ -370,7 +370,7 @@ pub struct CustomDialect { int32_cast_dtype: ast::DataType, timestamp_cast_dtype: ast::DataType, timestamp_tz_cast_dtype: ast::DataType, - date32_cast_dtype: sqlparser::ast::DataType, + date32_cast_dtype: ast::DataType, supports_column_alias_in_table_alias: bool, requires_derived_table_alias: bool, } @@ -382,7 +382,7 @@ impl Default for CustomDialect { supports_nulls_first_in_sort: true, use_timestamp_for_date64: false, interval_style: IntervalStyle::SQLStandard, - float64_ast_dtype: sqlparser::ast::DataType::Double, + float64_ast_dtype: ast::DataType::Double, utf8_cast_dtype: ast::DataType::Varchar(None), large_utf8_cast_dtype: ast::DataType::Text, date_field_extract_style: DateFieldExtractStyle::DatePart, @@ -393,7 +393,7 @@ impl Default for CustomDialect { None, TimezoneInfo::WithTimeZone, ), - date32_cast_dtype: sqlparser::ast::DataType::Date, + date32_cast_dtype: ast::DataType::Date, supports_column_alias_in_table_alias: true, requires_derived_table_alias: false, } @@ -428,7 +428,7 @@ impl Dialect for CustomDialect { self.interval_style } - fn float64_ast_dtype(&self) -> sqlparser::ast::DataType { + fn float64_ast_dtype(&self) -> ast::DataType { self.float64_ast_dtype.clone() } @@ -464,7 +464,7 @@ impl Dialect for CustomDialect { } } - fn date32_cast_dtype(&self) -> sqlparser::ast::DataType { + fn date32_cast_dtype(&self) -> ast::DataType { self.date32_cast_dtype.clone() } @@ -509,7 +509,7 @@ pub struct CustomDialectBuilder { supports_nulls_first_in_sort: bool, use_timestamp_for_date64: bool, interval_style: IntervalStyle, - float64_ast_dtype: sqlparser::ast::DataType, + float64_ast_dtype: ast::DataType, utf8_cast_dtype: ast::DataType, large_utf8_cast_dtype: ast::DataType, date_field_extract_style: DateFieldExtractStyle, @@ -535,7 +535,7 @@ impl CustomDialectBuilder { supports_nulls_first_in_sort: true, use_timestamp_for_date64: false, interval_style: IntervalStyle::PostgresVerbose, - float64_ast_dtype: sqlparser::ast::DataType::Double, + float64_ast_dtype: ast::DataType::Double, utf8_cast_dtype: ast::DataType::Varchar(None), large_utf8_cast_dtype: ast::DataType::Text, date_field_extract_style: DateFieldExtractStyle::DatePart, @@ -546,7 +546,7 @@ impl CustomDialectBuilder { None, TimezoneInfo::WithTimeZone, ), - date32_cast_dtype: sqlparser::ast::DataType::Date, + date32_cast_dtype: ast::DataType::Date, supports_column_alias_in_table_alias: true, requires_derived_table_alias: false, } @@ -604,10 +604,7 @@ impl CustomDialectBuilder { } /// Customize the dialect with a specific SQL type for Float64 casting: DOUBLE, DOUBLE PRECISION, etc. - pub fn with_float64_ast_dtype( - mut self, - float64_ast_dtype: sqlparser::ast::DataType, - ) -> Self { + pub fn with_float64_ast_dtype(mut self, float64_ast_dtype: ast::DataType) -> Self { self.float64_ast_dtype = float64_ast_dtype; self } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 1d0327fadbe4..6da0a32282c6 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -248,7 +248,7 @@ impl Unparser<'_> { })); Ok(ast::Expr::Function(Function { - name: ast::ObjectName(vec![Ident { + name: ObjectName(vec![Ident { value: func_name.to_string(), quote_style: None, }]), @@ -292,7 +292,7 @@ impl Unparser<'_> { None => None, }; Ok(ast::Expr::Function(Function { - name: ast::ObjectName(vec![Ident { + name: ObjectName(vec![Ident { value: func_name.to_string(), quote_style: None, }]), @@ -478,7 +478,7 @@ impl Unparser<'_> { ) -> Result { let args = self.function_args_to_sql(args)?; Ok(ast::Expr::Function(Function { - name: ast::ObjectName(vec![Ident { + name: ObjectName(vec![Ident { value: func_name.to_string(), quote_style: None, }]), @@ -519,7 +519,7 @@ impl Unparser<'_> { fn ast_type_for_date64_in_cast(&self) -> ast::DataType { if self.dialect.use_timestamp_for_date64() { - ast::DataType::Timestamp(None, ast::TimezoneInfo::None) + ast::DataType::Timestamp(None, TimezoneInfo::None) } else { ast::DataType::Datetime(None) } @@ -594,16 +594,16 @@ impl Unparser<'_> { } /// This function can create an identifier with or without quotes based on the dialect rules - pub(super) fn new_ident_quoted_if_needs(&self, ident: String) -> ast::Ident { + pub(super) fn new_ident_quoted_if_needs(&self, ident: String) -> Ident { let quote_style = self.dialect.identifier_quote_style(&ident); - ast::Ident { + Ident { value: ident, quote_style, } } - pub(super) fn new_ident_without_quote_style(&self, str: String) -> ast::Ident { - ast::Ident { + pub(super) fn new_ident_without_quote_style(&self, str: String) -> Ident { + Ident { value: str, quote_style: None, } @@ -613,7 +613,7 @@ impl Unparser<'_> { &self, lhs: ast::Expr, rhs: ast::Expr, - op: ast::BinaryOperator, + op: BinaryOperator, ) -> ast::Expr { ast::Expr::BinaryOp { left: Box::new(lhs), @@ -698,7 +698,7 @@ impl Unparser<'_> { // Closest precedence we currently have to Between is PGLikeMatch // (https://www.postgresql.org/docs/7.2/sql-precedence.html) ast::Expr::Between { .. } => { - self.sql_op_precedence(&ast::BinaryOperator::PGLikeMatch) + self.sql_op_precedence(&BinaryOperator::PGLikeMatch) } _ => 0, } @@ -728,70 +728,70 @@ impl Unparser<'_> { fn sql_to_op(&self, op: &BinaryOperator) -> Result { match op { - ast::BinaryOperator::Eq => Ok(Operator::Eq), - ast::BinaryOperator::NotEq => Ok(Operator::NotEq), - ast::BinaryOperator::Lt => Ok(Operator::Lt), - ast::BinaryOperator::LtEq => Ok(Operator::LtEq), - ast::BinaryOperator::Gt => Ok(Operator::Gt), - ast::BinaryOperator::GtEq => Ok(Operator::GtEq), - ast::BinaryOperator::Plus => Ok(Operator::Plus), - ast::BinaryOperator::Minus => Ok(Operator::Minus), - ast::BinaryOperator::Multiply => Ok(Operator::Multiply), - ast::BinaryOperator::Divide => Ok(Operator::Divide), - ast::BinaryOperator::Modulo => Ok(Operator::Modulo), - ast::BinaryOperator::And => Ok(Operator::And), - ast::BinaryOperator::Or => Ok(Operator::Or), - ast::BinaryOperator::PGRegexMatch => Ok(Operator::RegexMatch), - ast::BinaryOperator::PGRegexIMatch => Ok(Operator::RegexIMatch), - ast::BinaryOperator::PGRegexNotMatch => Ok(Operator::RegexNotMatch), - ast::BinaryOperator::PGRegexNotIMatch => Ok(Operator::RegexNotIMatch), - ast::BinaryOperator::PGILikeMatch => Ok(Operator::ILikeMatch), - ast::BinaryOperator::PGNotLikeMatch => Ok(Operator::NotLikeMatch), - ast::BinaryOperator::PGLikeMatch => Ok(Operator::LikeMatch), - ast::BinaryOperator::PGNotILikeMatch => Ok(Operator::NotILikeMatch), - ast::BinaryOperator::BitwiseAnd => Ok(Operator::BitwiseAnd), - ast::BinaryOperator::BitwiseOr => Ok(Operator::BitwiseOr), - ast::BinaryOperator::BitwiseXor => Ok(Operator::BitwiseXor), - ast::BinaryOperator::PGBitwiseShiftRight => Ok(Operator::BitwiseShiftRight), - ast::BinaryOperator::PGBitwiseShiftLeft => Ok(Operator::BitwiseShiftLeft), - ast::BinaryOperator::StringConcat => Ok(Operator::StringConcat), - ast::BinaryOperator::AtArrow => Ok(Operator::AtArrow), - ast::BinaryOperator::ArrowAt => Ok(Operator::ArrowAt), + BinaryOperator::Eq => Ok(Operator::Eq), + BinaryOperator::NotEq => Ok(Operator::NotEq), + BinaryOperator::Lt => Ok(Operator::Lt), + BinaryOperator::LtEq => Ok(Operator::LtEq), + BinaryOperator::Gt => Ok(Operator::Gt), + BinaryOperator::GtEq => Ok(Operator::GtEq), + BinaryOperator::Plus => Ok(Operator::Plus), + BinaryOperator::Minus => Ok(Operator::Minus), + BinaryOperator::Multiply => Ok(Operator::Multiply), + BinaryOperator::Divide => Ok(Operator::Divide), + BinaryOperator::Modulo => Ok(Operator::Modulo), + BinaryOperator::And => Ok(Operator::And), + BinaryOperator::Or => Ok(Operator::Or), + BinaryOperator::PGRegexMatch => Ok(Operator::RegexMatch), + BinaryOperator::PGRegexIMatch => Ok(Operator::RegexIMatch), + BinaryOperator::PGRegexNotMatch => Ok(Operator::RegexNotMatch), + BinaryOperator::PGRegexNotIMatch => Ok(Operator::RegexNotIMatch), + BinaryOperator::PGILikeMatch => Ok(Operator::ILikeMatch), + BinaryOperator::PGNotLikeMatch => Ok(Operator::NotLikeMatch), + BinaryOperator::PGLikeMatch => Ok(Operator::LikeMatch), + BinaryOperator::PGNotILikeMatch => Ok(Operator::NotILikeMatch), + BinaryOperator::BitwiseAnd => Ok(Operator::BitwiseAnd), + BinaryOperator::BitwiseOr => Ok(Operator::BitwiseOr), + BinaryOperator::BitwiseXor => Ok(Operator::BitwiseXor), + BinaryOperator::PGBitwiseShiftRight => Ok(Operator::BitwiseShiftRight), + BinaryOperator::PGBitwiseShiftLeft => Ok(Operator::BitwiseShiftLeft), + BinaryOperator::StringConcat => Ok(Operator::StringConcat), + BinaryOperator::AtArrow => Ok(Operator::AtArrow), + BinaryOperator::ArrowAt => Ok(Operator::ArrowAt), _ => not_impl_err!("unsupported operation: {op:?}"), } } - fn op_to_sql(&self, op: &Operator) -> Result { + fn op_to_sql(&self, op: &Operator) -> Result { match op { - Operator::Eq => Ok(ast::BinaryOperator::Eq), - Operator::NotEq => Ok(ast::BinaryOperator::NotEq), - Operator::Lt => Ok(ast::BinaryOperator::Lt), - Operator::LtEq => Ok(ast::BinaryOperator::LtEq), - Operator::Gt => Ok(ast::BinaryOperator::Gt), - Operator::GtEq => Ok(ast::BinaryOperator::GtEq), - Operator::Plus => Ok(ast::BinaryOperator::Plus), - Operator::Minus => Ok(ast::BinaryOperator::Minus), - Operator::Multiply => Ok(ast::BinaryOperator::Multiply), - Operator::Divide => Ok(ast::BinaryOperator::Divide), - Operator::Modulo => Ok(ast::BinaryOperator::Modulo), - Operator::And => Ok(ast::BinaryOperator::And), - Operator::Or => Ok(ast::BinaryOperator::Or), + Operator::Eq => Ok(BinaryOperator::Eq), + Operator::NotEq => Ok(BinaryOperator::NotEq), + Operator::Lt => Ok(BinaryOperator::Lt), + Operator::LtEq => Ok(BinaryOperator::LtEq), + Operator::Gt => Ok(BinaryOperator::Gt), + Operator::GtEq => Ok(BinaryOperator::GtEq), + Operator::Plus => Ok(BinaryOperator::Plus), + Operator::Minus => Ok(BinaryOperator::Minus), + Operator::Multiply => Ok(BinaryOperator::Multiply), + Operator::Divide => Ok(BinaryOperator::Divide), + Operator::Modulo => Ok(BinaryOperator::Modulo), + Operator::And => Ok(BinaryOperator::And), + Operator::Or => Ok(BinaryOperator::Or), Operator::IsDistinctFrom => not_impl_err!("unsupported operation: {op:?}"), Operator::IsNotDistinctFrom => not_impl_err!("unsupported operation: {op:?}"), - Operator::RegexMatch => Ok(ast::BinaryOperator::PGRegexMatch), - Operator::RegexIMatch => Ok(ast::BinaryOperator::PGRegexIMatch), - Operator::RegexNotMatch => Ok(ast::BinaryOperator::PGRegexNotMatch), - Operator::RegexNotIMatch => Ok(ast::BinaryOperator::PGRegexNotIMatch), - Operator::ILikeMatch => Ok(ast::BinaryOperator::PGILikeMatch), - Operator::NotLikeMatch => Ok(ast::BinaryOperator::PGNotLikeMatch), - Operator::LikeMatch => Ok(ast::BinaryOperator::PGLikeMatch), - Operator::NotILikeMatch => Ok(ast::BinaryOperator::PGNotILikeMatch), - Operator::BitwiseAnd => Ok(ast::BinaryOperator::BitwiseAnd), - Operator::BitwiseOr => Ok(ast::BinaryOperator::BitwiseOr), - Operator::BitwiseXor => Ok(ast::BinaryOperator::BitwiseXor), - Operator::BitwiseShiftRight => Ok(ast::BinaryOperator::PGBitwiseShiftRight), - Operator::BitwiseShiftLeft => Ok(ast::BinaryOperator::PGBitwiseShiftLeft), - Operator::StringConcat => Ok(ast::BinaryOperator::StringConcat), + Operator::RegexMatch => Ok(BinaryOperator::PGRegexMatch), + Operator::RegexIMatch => Ok(BinaryOperator::PGRegexIMatch), + Operator::RegexNotMatch => Ok(BinaryOperator::PGRegexNotMatch), + Operator::RegexNotIMatch => Ok(BinaryOperator::PGRegexNotIMatch), + Operator::ILikeMatch => Ok(BinaryOperator::PGILikeMatch), + Operator::NotLikeMatch => Ok(BinaryOperator::PGNotLikeMatch), + Operator::LikeMatch => Ok(BinaryOperator::PGLikeMatch), + Operator::NotILikeMatch => Ok(BinaryOperator::PGNotILikeMatch), + Operator::BitwiseAnd => Ok(BinaryOperator::BitwiseAnd), + Operator::BitwiseOr => Ok(BinaryOperator::BitwiseOr), + Operator::BitwiseXor => Ok(BinaryOperator::BitwiseXor), + Operator::BitwiseShiftRight => Ok(BinaryOperator::PGBitwiseShiftRight), + Operator::BitwiseShiftLeft => Ok(BinaryOperator::PGBitwiseShiftLeft), + Operator::StringConcat => Ok(BinaryOperator::StringConcat), Operator::AtArrow => not_impl_err!("unsupported operation: {op:?}"), Operator::ArrowAt => not_impl_err!("unsupported operation: {op:?}"), } @@ -935,17 +935,17 @@ impl Unparser<'_> { Ok(ast::Expr::Value(ast::Value::Number(ui.to_string(), false))) } ScalarValue::UInt64(None) => Ok(ast::Expr::Value(ast::Value::Null)), - ScalarValue::Utf8(Some(str)) => Ok(ast::Expr::Value( - ast::Value::SingleQuotedString(str.to_string()), - )), + ScalarValue::Utf8(Some(str)) => { + Ok(ast::Expr::Value(SingleQuotedString(str.to_string()))) + } ScalarValue::Utf8(None) => Ok(ast::Expr::Value(ast::Value::Null)), - ScalarValue::Utf8View(Some(str)) => Ok(ast::Expr::Value( - ast::Value::SingleQuotedString(str.to_string()), - )), + ScalarValue::Utf8View(Some(str)) => { + Ok(ast::Expr::Value(SingleQuotedString(str.to_string()))) + } ScalarValue::Utf8View(None) => Ok(ast::Expr::Value(ast::Value::Null)), - ScalarValue::LargeUtf8(Some(str)) => Ok(ast::Expr::Value( - ast::Value::SingleQuotedString(str.to_string()), - )), + ScalarValue::LargeUtf8(Some(str)) => { + Ok(ast::Expr::Value(SingleQuotedString(str.to_string()))) + } ScalarValue::LargeUtf8(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::Binary(Some(_)) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Binary(None) => Ok(ast::Expr::Value(ast::Value::Null)), @@ -978,7 +978,7 @@ impl Unparser<'_> { Ok(ast::Expr::Cast { kind: ast::CastKind::Cast, - expr: Box::new(ast::Expr::Value(ast::Value::SingleQuotedString( + expr: Box::new(ast::Expr::Value(SingleQuotedString( date.to_string(), ))), data_type: ast::DataType::Date, @@ -1001,7 +1001,7 @@ impl Unparser<'_> { Ok(ast::Expr::Cast { kind: ast::CastKind::Cast, - expr: Box::new(ast::Expr::Value(ast::Value::SingleQuotedString( + expr: Box::new(ast::Expr::Value(SingleQuotedString( datetime.to_string(), ))), data_type: self.ast_type_for_date64_in_cast(), @@ -1243,9 +1243,9 @@ impl Unparser<'_> { IntervalStyle::SQLStandard => match v { ScalarValue::IntervalYearMonth(Some(v)) => { let interval = Interval { - value: Box::new(ast::Expr::Value( - ast::Value::SingleQuotedString(v.to_string()), - )), + value: Box::new(ast::Expr::Value(SingleQuotedString( + v.to_string(), + ))), leading_field: Some(ast::DateTimeField::Month), leading_precision: None, last_field: None, @@ -1264,11 +1264,9 @@ impl Unparser<'_> { let millis = v.milliseconds % 1_000; let interval = Interval { - value: Box::new(ast::Expr::Value( - ast::Value::SingleQuotedString(format!( - "{days} {hours}:{mins}:{secs}.{millis:3}" - )), - )), + value: Box::new(ast::Expr::Value(SingleQuotedString(format!( + "{days} {hours}:{mins}:{secs}.{millis:3}" + )))), leading_field: Some(ast::DateTimeField::Day), leading_precision: None, last_field: Some(ast::DateTimeField::Second), @@ -1279,9 +1277,9 @@ impl Unparser<'_> { ScalarValue::IntervalMonthDayNano(Some(v)) => { if v.months >= 0 && v.days == 0 && v.nanoseconds == 0 { let interval = Interval { - value: Box::new(ast::Expr::Value( - ast::Value::SingleQuotedString(v.months.to_string()), - )), + value: Box::new(ast::Expr::Value(SingleQuotedString( + v.months.to_string(), + ))), leading_field: Some(ast::DateTimeField::Month), leading_precision: None, last_field: None, @@ -1300,11 +1298,9 @@ impl Unparser<'_> { let millis = (v.nanoseconds % 1_000_000_000) / 1_000_000; let interval = Interval { - value: Box::new(ast::Expr::Value( - ast::Value::SingleQuotedString(format!( - "{days} {hours}:{mins}:{secs}.{millis:03}" - )), - )), + value: Box::new(ast::Expr::Value(SingleQuotedString( + format!("{days} {hours}:{mins}:{secs}.{millis:03}"), + ))), leading_field: Some(ast::DateTimeField::Day), leading_precision: None, last_field: Some(ast::DateTimeField::Second), @@ -1962,11 +1958,8 @@ mod tests { #[test] fn custom_dialect_float64_ast_dtype() -> Result<()> { for (float64_ast_dtype, identifier) in [ - (sqlparser::ast::DataType::Double, "DOUBLE"), - ( - sqlparser::ast::DataType::DoublePrecision, - "DOUBLE PRECISION", - ), + (ast::DataType::Double, "DOUBLE"), + (ast::DataType::DoublePrecision, "DOUBLE PRECISION"), ] { let dialect = CustomDialectBuilder::new() .with_float64_ast_dtype(float64_ast_dtype) @@ -2383,10 +2376,7 @@ mod tests { expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( "variation".to_string(), )))), - data_type: DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::Utf8), - ), + data_type: DataType::Dictionary(Box::new(Int8), Box::new(DataType::Utf8)), }), "'variation'", )]; diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index 47caeec78dc7..b0fa17031849 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -232,10 +232,7 @@ impl ContextProvider for MockContextProvider { &self.state.config_options } - fn get_file_type( - &self, - _ext: &str, - ) -> Result> { + fn get_file_type(&self, _ext: &str) -> Result> { Ok(Arc::new(MockCsvType {})) } @@ -275,7 +272,7 @@ impl EmptyTable { } impl TableSource for EmptyTable { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/sqllogictest/bin/sqllogictests.rs b/datafusion/sqllogictest/bin/sqllogictests.rs index 501fd3517a17..2479252a7b5b 100644 --- a/datafusion/sqllogictest/bin/sqllogictests.rs +++ b/datafusion/sqllogictest/bin/sqllogictests.rs @@ -61,7 +61,7 @@ async fn run_tests() -> Result<()> { // Enable logging (e.g. set RUST_LOG=debug to see debug logs) env_logger::init(); - let options: Options = clap::Parser::parse(); + let options: Options = Parser::parse(); if options.list { // nextest parses stdout, so print messages to stderr eprintln!("NOTICE: --list option unsupported, quitting"); @@ -264,7 +264,7 @@ fn read_dir_recursive>(path: P) -> Result> { /// Append all paths recursively to dst fn read_dir_recursive_impl(dst: &mut Vec, path: &Path) -> Result<()> { - let entries = std::fs::read_dir(path) + let entries = fs::read_dir(path) .map_err(|e| exec_datafusion_err!("Error reading directory {path:?}: {e}"))?; for entry in entries { let path = entry diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index deeacb1b8819..477f225443e2 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -139,7 +139,7 @@ impl TestContext { } #[cfg(feature = "avro")] -pub async fn register_avro_tables(ctx: &mut crate::TestContext) { +pub async fn register_avro_tables(ctx: &mut TestContext) { use datafusion::prelude::AvroReadOptions; ctx.enable_testdir(); diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 54b93cb7e345..99e7990df623 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -77,7 +77,7 @@ use substrait::proto::expression::literal::{ IntervalDayToSecond, IntervalYearToMonth, UserDefined, }; use substrait::proto::expression::subquery::SubqueryType; -use substrait::proto::expression::{self, FieldReference, Literal, ScalarFunction}; +use substrait::proto::expression::{FieldReference, Literal, ScalarFunction}; use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; use substrait::proto::{ aggregate_function::AggregationInvocation, @@ -389,7 +389,7 @@ pub async fn from_substrait_extended_expr( pub fn apply_masking( schema: DFSchema, - mask_expression: &::core::option::Option, + mask_expression: &::core::option::Option, ) -> Result { match mask_expression { Some(MaskExpression { select, .. }) => match &select.as_ref() { @@ -2117,11 +2117,7 @@ fn from_substrait_literal( let s = d.scale.try_into().map_err(|e| { substrait_datafusion_err!("Failed to parse decimal scale: {e}") })?; - ScalarValue::Decimal128( - Some(std::primitive::i128::from_le_bytes(value)), - p, - s, - ) + ScalarValue::Decimal128(Some(i128::from_le_bytes(value)), p, s) } Some(LiteralType::List(l)) => { // Each element should start the name index from the same value, then we increase it diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index da8a4c994fb4..7b5165067225 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -491,7 +491,7 @@ pub fn to_substrait_rel( .map(|ptr| *ptr) .collect(); Ok(Box::new(Rel { - rel_type: Some(substrait::proto::rel::RelType::Set(SetRel { + rel_type: Some(RelType::Set(SetRel { common: None, inputs: input_rels, op: set_rel::SetOp::UnionAll as i32, // UNION DISTINCT gets translated to AGGREGATION + UNION ALL diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 9739afa99244..1f654f1d3c95 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -69,7 +69,7 @@ impl SerializerRegistry for MockSerializerRegistry { &self, name: &str, bytes: &[u8], - ) -> Result> { + ) -> Result> { if name == "MockUserDefinedLogicalPlan" { MockUserDefinedLogicalPlan::deserialize(bytes) } else { @@ -1005,7 +1005,7 @@ async fn roundtrip_aggregate_udf() -> Result<()> { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } diff --git a/datafusion/substrait/tests/cases/serialize.rs b/datafusion/substrait/tests/cases/serialize.rs index 72d685817d7d..54d55d1b6f10 100644 --- a/datafusion/substrait/tests/cases/serialize.rs +++ b/datafusion/substrait/tests/cases/serialize.rs @@ -20,13 +20,12 @@ mod tests { use datafusion::datasource::provider_as_source; use datafusion::logical_expr::LogicalPlanBuilder; use datafusion_substrait::logical_plan::consumer::from_substrait_plan; - use datafusion_substrait::logical_plan::producer; + use datafusion_substrait::logical_plan::producer::to_substrait_plan; use datafusion_substrait::serializer; use datafusion::error::Result; use datafusion::prelude::*; - use datafusion_substrait::logical_plan::producer::to_substrait_plan; use std::fs; use substrait::proto::plan_rel::RelType; use substrait::proto::rel_common::{Emit, EmitKind}; @@ -61,7 +60,7 @@ mod tests { let ctx = create_context().await?; let table = provider_as_source(ctx.table_provider("data").await?); let table_scan = LogicalPlanBuilder::scan("data", table, None)?.build()?; - let convert_result = producer::to_substrait_plan(&table_scan, &ctx); + let convert_result = to_substrait_plan(&table_scan, &ctx); assert!(convert_result.is_ok()); Ok(()) From 89e71ef5fd058278f1a0cc659dccf1757def3a62 Mon Sep 17 00:00:00 2001 From: JasonLi Date: Tue, 29 Oct 2024 19:31:44 +0800 Subject: [PATCH 102/110] [Optimization] Infer predicate under all JoinTypes (#13081) * optimize infer join predicate * pass clippy * chores: remove unnecessary debug code --- datafusion/optimizer/src/push_down_filter.rs | 393 ++++++++++++++++--- datafusion/optimizer/src/utils.rs | 171 +++++++- 2 files changed, 508 insertions(+), 56 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index f8e614a0aa84..a0262d7d95df 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -36,7 +36,7 @@ use datafusion_expr::{ }; use crate::optimizer::ApplyOrder; -use crate::utils::has_all_column_refs; +use crate::utils::{has_all_column_refs, is_restrict_null_predicate}; use crate::{OptimizerConfig, OptimizerRule}; /// Optimizer rule for pushing (moving) filter expressions down in a plan so @@ -558,10 +558,6 @@ fn infer_join_predicates( predicates: &[Expr], on_filters: &[Expr], ) -> Result> { - if join.join_type != JoinType::Inner { - return Ok(vec![]); - } - // Only allow both side key is column. let join_col_keys = join .on @@ -573,55 +569,176 @@ fn infer_join_predicates( }) .collect::>(); - // TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down - // For inner joins, duplicate filters for joined columns so filters can be pushed down - // to both sides. Take the following query as an example: - // - // ```sql - // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1 - // ``` - // - // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while - // `t2.uid > 1` predicate needs to be pushed down to t2 table scan. - // - // Join clauses with `Using` constraints also take advantage of this logic to make sure - // predicates reference the shared join columns are pushed to both sides. - // This logic should also been applied to conditions in JOIN ON clause - predicates - .iter() - .chain(on_filters.iter()) - .filter_map(|predicate| { - let mut join_cols_to_replace = HashMap::new(); - - let columns = predicate.column_refs(); - - for &col in columns.iter() { - for (l, r) in join_col_keys.iter() { - if col == *l { - join_cols_to_replace.insert(col, *r); - break; - } else if col == *r { - join_cols_to_replace.insert(col, *l); - break; - } - } - } + let join_type = join.join_type; - if join_cols_to_replace.is_empty() { - return None; - } + let mut inferred_predicates = InferredPredicates::new(join_type); - let join_side_predicate = - match replace_col(predicate.clone(), &join_cols_to_replace) { - Ok(p) => p, - Err(e) => { - return Some(Err(e)); - } - }; + infer_join_predicates_from_predicates( + &join_col_keys, + predicates, + &mut inferred_predicates, + )?; - Some(Ok(join_side_predicate)) - }) - .collect::>>() + infer_join_predicates_from_on_filters( + &join_col_keys, + join_type, + on_filters, + &mut inferred_predicates, + )?; + + Ok(inferred_predicates.predicates) +} + +/// Inferred predicates collector. +/// When the JoinType is not Inner, we need to detect whether the inferred predicate can strictly +/// filter out NULL, otherwise ignore it. e.g. +/// ```text +/// SELECT * FROM t1 LEFT JOIN t2 ON t1.c0 = t2.c0 WHERE t2.c0 IS NULL; +/// ``` +/// We cannot infer the predicate `t1.c0 IS NULL`, otherwise the predicate will be pushed down to +/// the left side, resulting in the wrong result. +struct InferredPredicates { + predicates: Vec, + is_inner_join: bool, +} + +impl InferredPredicates { + fn new(join_type: JoinType) -> Self { + Self { + predicates: vec![], + is_inner_join: matches!(join_type, JoinType::Inner), + } + } + + fn try_build_predicate( + &mut self, + predicate: Expr, + replace_map: &HashMap<&Column, &Column>, + ) -> Result<()> { + if self.is_inner_join + || matches!( + is_restrict_null_predicate( + predicate.clone(), + replace_map.keys().cloned() + ), + Ok(true) + ) + { + self.predicates.push(replace_col(predicate, replace_map)?); + } + + Ok(()) + } +} + +/// Infer predicates from the pushed down predicates. +/// +/// Parameters +/// * `join_col_keys` column pairs from the join ON clause +/// +/// * `predicates` the pushed down predicates +/// +/// * `inferred_predicates` the inferred results +/// +fn infer_join_predicates_from_predicates( + join_col_keys: &[(&Column, &Column)], + predicates: &[Expr], + inferred_predicates: &mut InferredPredicates, +) -> Result<()> { + infer_join_predicates_impl::( + join_col_keys, + predicates, + inferred_predicates, + ) +} + +/// Infer predicates from the join filter. +/// +/// Parameters +/// * `join_col_keys` column pairs from the join ON clause +/// +/// * `join_type` the JoinType of Join +/// +/// * `on_filters` filters from the join ON clause that have not already been +/// identified as join predicates +/// +/// * `inferred_predicates` the inferred results +/// +fn infer_join_predicates_from_on_filters( + join_col_keys: &[(&Column, &Column)], + join_type: JoinType, + on_filters: &[Expr], + inferred_predicates: &mut InferredPredicates, +) -> Result<()> { + match join_type { + JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()), + JoinType::Inner => infer_join_predicates_impl::( + join_col_keys, + on_filters, + inferred_predicates, + ), + JoinType::Left | JoinType::LeftSemi => infer_join_predicates_impl::( + join_col_keys, + on_filters, + inferred_predicates, + ), + JoinType::Right | JoinType::RightSemi => { + infer_join_predicates_impl::( + join_col_keys, + on_filters, + inferred_predicates, + ) + } + } +} + +/// Infer predicates from the given predicates. +/// +/// Parameters +/// * `join_col_keys` column pairs from the join ON clause +/// +/// * `input_predicates` the given predicates. It can be the pushed down predicates, +/// or it can be the filters of the Join +/// +/// * `inferred_predicates` the inferred results +/// +/// * `ENABLE_LEFT_TO_RIGHT` indicates that the right table related predicate can +/// be inferred from the left table related predicate +/// +/// * `ENABLE_RIGHT_TO_LEFT` indicates that the left table related predicate can +/// be inferred from the right table related predicate +/// +fn infer_join_predicates_impl< + const ENABLE_LEFT_TO_RIGHT: bool, + const ENABLE_RIGHT_TO_LEFT: bool, +>( + join_col_keys: &[(&Column, &Column)], + input_predicates: &[Expr], + inferred_predicates: &mut InferredPredicates, +) -> Result<()> { + for predicate in input_predicates { + let mut join_cols_to_replace = HashMap::new(); + + for &col in &predicate.column_refs() { + for (l, r) in join_col_keys.iter() { + if ENABLE_LEFT_TO_RIGHT && col == *l { + join_cols_to_replace.insert(col, *r); + break; + } + if ENABLE_RIGHT_TO_LEFT && col == *r { + join_cols_to_replace.insert(col, *l); + break; + } + } + } + if join_cols_to_replace.is_empty() { + continue; + } + + inferred_predicates + .try_build_predicate(predicate.clone(), &join_cols_to_replace)?; + } + Ok(()) } impl OptimizerRule for PushDownFilter { @@ -1992,7 +2109,7 @@ mod tests { let expected = "\ Filter: test2.a <= Int64(1)\ \n Left Join: Using test.a = test2.a\ - \n TableScan: test\ + \n TableScan: test, full_filters=[test.a <= Int64(1)]\ \n Projection: test2.a\ \n TableScan: test2"; assert_optimized_plan_eq(plan, expected) @@ -2032,7 +2149,7 @@ mod tests { \n Right Join: Using test.a = test2.a\ \n TableScan: test\ \n Projection: test2.a\ - \n TableScan: test2"; + \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; assert_optimized_plan_eq(plan, expected) } @@ -2814,6 +2931,46 @@ Projection: a, b assert_optimized_plan_eq(optimized_plan, expected) } + #[test] + fn left_semi_join() -> Result<()> { + let left = test_table_scan_with_name("test1")?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .join( + right, + JoinType::LeftSemi, + ( + vec![Column::from_qualified_name("test1.a")], + vec![Column::from_qualified_name("test2.a")], + ), + None, + )? + .filter(col("test2.a").lt_eq(lit(1i64)))? + .build()?; + + // not part of the test, just good to know: + assert_eq!( + format!("{plan}"), + "Filter: test2.a <= Int64(1)\ + \n LeftSemi Join: test1.a = test2.a\ + \n TableScan: test1\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2" + ); + + // Inferred the predicate `test1.a <= Int64(1)` and push it down to the left side. + let expected = "\ + Filter: test2.a <= Int64(1)\ + \n LeftSemi Join: test1.a = test2.a\ + \n TableScan: test1, full_filters=[test1.a <= Int64(1)]\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2"; + assert_optimized_plan_eq(plan, expected) + } + #[test] fn left_semi_join_with_filters() -> Result<()> { let left = test_table_scan_with_name("test1")?; @@ -2855,6 +3012,46 @@ Projection: a, b assert_optimized_plan_eq(plan, expected) } + #[test] + fn right_semi_join() -> Result<()> { + let left = test_table_scan_with_name("test1")?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .join( + right, + JoinType::RightSemi, + ( + vec![Column::from_qualified_name("test1.a")], + vec![Column::from_qualified_name("test2.a")], + ), + None, + )? + .filter(col("test1.a").lt_eq(lit(1i64)))? + .build()?; + + // not part of the test, just good to know: + assert_eq!( + format!("{plan}"), + "Filter: test1.a <= Int64(1)\ + \n RightSemi Join: test1.a = test2.a\ + \n TableScan: test1\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2", + ); + + // Inferred the predicate `test2.a <= Int64(1)` and push it down to the right side. + let expected = "\ + Filter: test1.a <= Int64(1)\ + \n RightSemi Join: test1.a = test2.a\ + \n TableScan: test1\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; + assert_optimized_plan_eq(plan, expected) + } + #[test] fn right_semi_join_with_filters() -> Result<()> { let left = test_table_scan_with_name("test1")?; @@ -2896,6 +3093,51 @@ Projection: a, b assert_optimized_plan_eq(plan, expected) } + #[test] + fn left_anti_join() -> Result<()> { + let table_scan = test_table_scan_with_name("test1")?; + let left = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .join( + right, + JoinType::LeftAnti, + ( + vec![Column::from_qualified_name("test1.a")], + vec![Column::from_qualified_name("test2.a")], + ), + None, + )? + .filter(col("test2.a").gt(lit(2u32)))? + .build()?; + + // not part of the test, just good to know: + assert_eq!( + format!("{plan}"), + "Filter: test2.a > UInt32(2)\ + \n LeftAnti Join: test1.a = test2.a\ + \n Projection: test1.a, test1.b\ + \n TableScan: test1\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2", + ); + + // For left anti, filter of the right side filter can be pushed down. + let expected = "\ + Filter: test2.a > UInt32(2)\ + \n LeftAnti Join: test1.a = test2.a\ + \n Projection: test1.a, test1.b\ + \n TableScan: test1, full_filters=[test1.a > UInt32(2)]\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2"; + assert_optimized_plan_eq(plan, expected) + } + #[test] fn left_anti_join_with_filters() -> Result<()> { let table_scan = test_table_scan_with_name("test1")?; @@ -2942,6 +3184,51 @@ Projection: a, b assert_optimized_plan_eq(plan, expected) } + #[test] + fn right_anti_join() -> Result<()> { + let table_scan = test_table_scan_with_name("test1")?; + let left = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .join( + right, + JoinType::RightAnti, + ( + vec![Column::from_qualified_name("test1.a")], + vec![Column::from_qualified_name("test2.a")], + ), + None, + )? + .filter(col("test1.a").gt(lit(2u32)))? + .build()?; + + // not part of the test, just good to know: + assert_eq!( + format!("{plan}"), + "Filter: test1.a > UInt32(2)\ + \n RightAnti Join: test1.a = test2.a\ + \n Projection: test1.a, test1.b\ + \n TableScan: test1\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2", + ); + + // For right anti, filter of the left side can be pushed down. + let expected = "\ + Filter: test1.a > UInt32(2)\ + \n RightAnti Join: test1.a = test2.a\ + \n Projection: test1.a, test1.b\ + \n TableScan: test1\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2, full_filters=[test2.a > UInt32(2)]"; + assert_optimized_plan_eq(plan, expected) + } + #[test] fn right_anti_join_with_filters() -> Result<()> { let table_scan = test_table_scan_with_name("test1")?; diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 6972c16c0ddf..9f325bc01b1d 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -21,11 +21,18 @@ use std::collections::{BTreeSet, HashMap, HashSet}; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{Column, DFSchema, Result}; +use crate::analyzer::type_coercion::TypeCoercionRewriter; +use arrow::array::{new_null_array, Array, RecordBatch}; +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion_common::cast::as_boolean_array; +use datafusion_common::tree_node::{TransformedResult, TreeNode}; +use datafusion_common::{Column, DFSchema, Result, ScalarValue}; +use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_rewriter::replace_col; -use datafusion_expr::{logical_plan::LogicalPlan, Expr}; - +use datafusion_expr::{logical_plan::LogicalPlan, ColumnarValue, Expr}; +use datafusion_physical_expr::create_physical_expr; use log::{debug, trace}; +use std::sync::Arc; /// Re-export of `NamesPreserver` for backwards compatibility, /// as it was initially placed here and then moved elsewhere. @@ -117,3 +124,161 @@ pub fn log_plan(description: &str, plan: &LogicalPlan) { debug!("{description}:\n{}\n", plan.display_indent()); trace!("{description}::\n{}\n", plan.display_indent_schema()); } + +/// Determine whether a predicate can restrict NULLs. e.g. +/// `c0 > 8` return true; +/// `c0 IS NULL` return false. +pub fn is_restrict_null_predicate<'a>( + predicate: Expr, + join_cols_of_predicate: impl IntoIterator, +) -> Result { + if matches!(predicate, Expr::Column(_)) { + return Ok(true); + } + + static DUMMY_COL_NAME: &str = "?"; + let schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Null, true)]); + let input_schema = DFSchema::try_from(schema.clone())?; + let column = new_null_array(&DataType::Null, 1); + let input_batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![column])?; + let execution_props = ExecutionProps::default(); + let null_column = Column::from_name(DUMMY_COL_NAME); + + let join_cols_to_replace = join_cols_of_predicate + .into_iter() + .map(|column| (column, &null_column)) + .collect::>(); + + let replaced_predicate = replace_col(predicate, &join_cols_to_replace)?; + let coerced_predicate = coerce(replaced_predicate, &input_schema)?; + let phys_expr = + create_physical_expr(&coerced_predicate, &input_schema, &execution_props)?; + + let result_type = phys_expr.data_type(&schema)?; + if !matches!(&result_type, DataType::Boolean) { + return Ok(false); + } + + // If result is single `true`, return false; + // If result is single `NULL` or `false`, return true; + Ok(match phys_expr.evaluate(&input_batch)? { + ColumnarValue::Array(array) => { + if array.len() == 1 { + let boolean_array = as_boolean_array(&array)?; + boolean_array.is_null(0) || !boolean_array.value(0) + } else { + false + } + } + ColumnarValue::Scalar(scalar) => matches!( + scalar, + ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false)) + ), + }) +} + +fn coerce(expr: Expr, schema: &DFSchema) -> Result { + let mut expr_rewrite = TypeCoercionRewriter { schema }; + expr.rewrite(&mut expr_rewrite).data() +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_expr::{binary_expr, case, col, in_list, is_null, lit, Operator}; + + #[test] + fn expr_is_restrict_null_predicate() -> Result<()> { + let test_cases = vec![ + // a + (col("a"), true), + // a IS NULL + (is_null(col("a")), false), + // a IS NOT NULL + (Expr::IsNotNull(Box::new(col("a"))), true), + // a = NULL + ( + binary_expr(col("a"), Operator::Eq, Expr::Literal(ScalarValue::Null)), + true, + ), + // a > 8 + (binary_expr(col("a"), Operator::Gt, lit(8i64)), true), + // a <= 8 + (binary_expr(col("a"), Operator::LtEq, lit(8i32)), true), + // CASE a WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END + ( + case(col("a")) + .when(lit(1i64), lit(true)) + .when(lit(0i64), lit(false)) + .otherwise(lit(ScalarValue::Null))?, + true, + ), + // CASE a WHEN 1 THEN true ELSE false END + ( + case(col("a")) + .when(lit(1i64), lit(true)) + .otherwise(lit(false))?, + true, + ), + // CASE a WHEN 0 THEN false ELSE true END + ( + case(col("a")) + .when(lit(0i64), lit(false)) + .otherwise(lit(true))?, + false, + ), + // (CASE a WHEN 0 THEN false ELSE true END) OR false + ( + binary_expr( + case(col("a")) + .when(lit(0i64), lit(false)) + .otherwise(lit(true))?, + Operator::Or, + lit(false), + ), + false, + ), + // (CASE a WHEN 0 THEN true ELSE false END) OR false + ( + binary_expr( + case(col("a")) + .when(lit(0i64), lit(true)) + .otherwise(lit(false))?, + Operator::Or, + lit(false), + ), + true, + ), + // a IN (1, 2, 3) + ( + in_list(col("a"), vec![lit(1i64), lit(2i64), lit(3i64)], false), + true, + ), + // a NOT IN (1, 2, 3) + ( + in_list(col("a"), vec![lit(1i64), lit(2i64), lit(3i64)], true), + true, + ), + // a IN (NULL) + ( + in_list(col("a"), vec![Expr::Literal(ScalarValue::Null)], false), + true, + ), + // a NOT IN (NULL) + ( + in_list(col("a"), vec![Expr::Literal(ScalarValue::Null)], true), + true, + ), + ]; + + let column_a = Column::from_name("a"); + for (predicate, expected) in test_cases { + let join_cols_of_predicate = std::iter::once(&column_a); + let actual = + is_restrict_null_predicate(predicate.clone(), join_cols_of_predicate)?; + assert_eq!(actual, expected, "{}", predicate); + } + + Ok(()) + } +} From d62f2621f6335899dd095b70c2d969320386edaa Mon Sep 17 00:00:00 2001 From: Bruno Volpato Date: Tue, 29 Oct 2024 07:33:11 -0400 Subject: [PATCH 103/110] feat(substrait): support order_by in aggregate functions (#13114) --- .../substrait/src/logical_plan/consumer.rs | 17 ++- .../tests/cases/roundtrip_logical_plan.rs | 16 ++- ...aggregate_sorted_no_project.substrait.json | 113 ++++++++++++++++++ 3 files changed, 143 insertions(+), 3 deletions(-) create mode 100644 datafusion/substrait/tests/testdata/test_plans/aggregate_sorted_no_project.substrait.json diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 99e7990df623..e0bb3b4e4f33 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -714,14 +714,27 @@ pub async fn from_substrait_rel( } _ => false, }; + let order_by = if !f.sorts.is_empty() { + Some( + from_substrait_sorts( + ctx, + &f.sorts, + input.schema(), + extensions, + ) + .await?, + ) + } else { + None + }; + from_substrait_agg_func( ctx, f, input.schema(), extensions, filter, - // TODO: Add parsing of order_by also - None, + order_by, distinct, ) .await diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 1f654f1d3c95..8108b9ad6767 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -685,6 +685,19 @@ async fn aggregate_wo_projection_consume() -> Result<()> { .await } +#[tokio::test] +async fn aggregate_wo_projection_sorted_consume() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/aggregate_sorted_no_project.substrait.json"); + + assert_expected_plan_substrait( + proto_plan, + "Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) ORDER BY [data.a DESC NULLS FIRST] AS countA]]\ + \n TableScan: data projection=[a]", + ) + .await +} + #[tokio::test] async fn simple_intersect_consume() -> Result<()> { let proto_plan = read_json("tests/testdata/test_plans/intersect.substrait.json"); @@ -1025,8 +1038,9 @@ async fn roundtrip_aggregate_udf() -> Result<()> { let ctx = create_context().await?; ctx.register_udaf(dummy_agg); + roundtrip_with_ctx("select dummy_agg(a) from data", ctx.clone()).await?; + roundtrip_with_ctx("select dummy_agg(a order by a) from data", ctx.clone()).await?; - roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await?; Ok(()) } diff --git a/datafusion/substrait/tests/testdata/test_plans/aggregate_sorted_no_project.substrait.json b/datafusion/substrait/tests/testdata/test_plans/aggregate_sorted_no_project.substrait.json new file mode 100644 index 000000000000..d5170223cd65 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/aggregate_sorted_no_project.substrait.json @@ -0,0 +1,113 @@ +{ + "extensionUris": [ + { + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 185, + "name": "count:any" + } + } + ], + "relations": [ + { + "root": { + "input": { + "aggregate": { + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "groupings": [ + { + "groupingExpressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + ], + "measures": [ + { + "measure": { + "functionReference": 185, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": {} + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + } + ], + "sorts": [ + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_DESC_NULLS_FIRST" + } + ] + } + } + ] + } + }, + "names": [ + "a", + "countA" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "manual" + } +} \ No newline at end of file From c03e2604f03c23fdfd777ce5043c3d17621a6349 Mon Sep 17 00:00:00 2001 From: Yasser Latreche Date: Tue, 29 Oct 2024 07:36:59 -0400 Subject: [PATCH 104/110] Support `negate` expression in substrait (#13112) --- datafusion/substrait/src/logical_plan/consumer.rs | 9 +++++---- datafusion/substrait/src/logical_plan/producer.rs | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index e0bb3b4e4f33..43263196bb84 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -2613,7 +2613,7 @@ impl BuiltinExprBuilder { match name { "not" | "like" | "ilike" | "is_null" | "is_not_null" | "is_true" | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" - | "is_not_unknown" | "negative" => Some(Self { + | "is_not_unknown" | "negative" | "negate" => Some(Self { expr_name: name.to_string(), }), _ => None, @@ -2634,8 +2634,9 @@ impl BuiltinExprBuilder { "ilike" => { Self::build_like_expr(ctx, true, f, input_schema, extensions).await } - "not" | "negative" | "is_null" | "is_not_null" | "is_true" | "is_false" - | "is_not_true" | "is_not_false" | "is_unknown" | "is_not_unknown" => { + "not" | "negative" | "negate" | "is_null" | "is_not_null" | "is_true" + | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" + | "is_not_unknown" => { Self::build_unary_expr(ctx, &self.expr_name, f, input_schema, extensions) .await } @@ -2664,7 +2665,7 @@ impl BuiltinExprBuilder { let expr = match fn_name { "not" => Expr::Not(arg), - "negative" => Expr::Negative(arg), + "negative" | "negate" => Expr::Negative(arg), "is_null" => Expr::IsNull(arg), "is_not_null" => Expr::IsNotNull(arg), "is_true" => Expr::IsTrue(arg), diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 7b5165067225..17ed41f016bd 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -1344,7 +1344,7 @@ pub fn to_substrait_rex( ), Expr::Negative(arg) => to_substrait_unary_scalar_fn( ctx, - "negative", + "negate", arg, schema, col_ref_offset, From 1c2a2fdcfe4ebe49a1d8619a2429a3e691afa80c Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Tue, 29 Oct 2024 07:38:10 -0400 Subject: [PATCH 105/110] Fix an issue with to_char signature not working correctly with timezones or other types because the ordering is not in most exact -> least exact order. (#13126) --- datafusion/functions/src/datetime/to_char.rs | 24 +++++++++---------- .../sqllogictest/test_files/timestamps.slt | 5 ++++ 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index 2fbfb2261180..f0c4a02c1523 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -54,34 +54,34 @@ impl ToCharFunc { vec![ Exact(vec![Date32, Utf8]), Exact(vec![Date64, Utf8]), + Exact(vec![Time64(Nanosecond), Utf8]), + Exact(vec![Time64(Microsecond), Utf8]), Exact(vec![Time32(Millisecond), Utf8]), Exact(vec![Time32(Second), Utf8]), - Exact(vec![Time64(Microsecond), Utf8]), - Exact(vec![Time64(Nanosecond), Utf8]), - Exact(vec![Timestamp(Second, None), Utf8]), Exact(vec![ - Timestamp(Second, Some(TIMEZONE_WILDCARD.into())), + Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), Utf8, ]), - Exact(vec![Timestamp(Millisecond, None), Utf8]), + Exact(vec![Timestamp(Nanosecond, None), Utf8]), Exact(vec![ - Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())), + Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())), Utf8, ]), Exact(vec![Timestamp(Microsecond, None), Utf8]), Exact(vec![ - Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())), + Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())), Utf8, ]), - Exact(vec![Timestamp(Nanosecond, None), Utf8]), + Exact(vec![Timestamp(Millisecond, None), Utf8]), Exact(vec![ - Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), + Timestamp(Second, Some(TIMEZONE_WILDCARD.into())), Utf8, ]), - Exact(vec![Duration(Second), Utf8]), - Exact(vec![Duration(Millisecond), Utf8]), - Exact(vec![Duration(Microsecond), Utf8]), + Exact(vec![Timestamp(Second, None), Utf8]), Exact(vec![Duration(Nanosecond), Utf8]), + Exact(vec![Duration(Microsecond), Utf8]), + Exact(vec![Duration(Millisecond), Utf8]), + Exact(vec![Duration(Second), Utf8]), ], Volatility::Immutable, ), diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index d866ec8c94dd..38c2a6647273 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -2780,6 +2780,11 @@ FROM NULL 01:01:2025 23-59-58 +query T +select to_char('2020-01-01 00:10:20.123'::timestamp at time zone 'America/New_York', '%Y-%m-%d %H:%M:%S.%3f'); +---- +2020-01-01 00:10:20.123 + statement ok drop table formats; From b30d12a73fb9867180c2fdf8ddc818b45f957bac Mon Sep 17 00:00:00 2001 From: Michael J Ward Date: Tue, 29 Oct 2024 06:38:48 -0500 Subject: [PATCH 106/110] chore: re-export functions_window_common::ExpressionArgs (#13149) * chore: re-export functions_window_common::ExpressionArgs This struct is needed to implement the WindowUDFImpl trait. * cargo fmt --- datafusion/expr/src/function.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 199a91bf5ace..23ffc83e3549 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -27,6 +27,7 @@ pub use datafusion_functions_aggregate_common::accumulator::{ AccumulatorArgs, AccumulatorFactoryFunction, StateFieldsArgs, }; +pub use datafusion_functions_window_common::expr::ExpressionArgs; pub use datafusion_functions_window_common::field::WindowUDFFieldArgs; pub use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; From d764c4af78f15188bf0a018be2bc00cd5ee3113f Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Tue, 29 Oct 2024 14:23:34 +0100 Subject: [PATCH 107/110] minor: Fix build on main (#13159) --- datafusion/sql/src/unparser/expr.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 6da0a32282c6..b41b0a54b86f 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1343,7 +1343,7 @@ impl Unparser<'_> { let args = self.function_args_to_sql(std::slice::from_ref(&unnest.expr))?; Ok(ast::Expr::Function(Function { - name: ast::ObjectName(vec![Ident { + name: ObjectName(vec![Ident { value: "UNNEST".to_string(), quote_style: None, }]), From 444a673682a4823b98695568ecdab1087b5b6a60 Mon Sep 17 00:00:00 2001 From: Arttu Date: Tue, 29 Oct 2024 10:22:41 -0400 Subject: [PATCH 108/110] feat: Support Substrait's IntervalCompound type/literal instead of interval-month-day-nano UDT (#12112) * feat(substrait): use IntervalCompound instead of interval-month-day-nano UDT * clippy * more clippy * even more clippy * fix precision exponent * add a test * update deprecation version * update deprecation comments --- .../substrait/src/logical_plan/consumer.rs | 116 +++++++++-- .../substrait/src/logical_plan/producer.rs | 180 +++++------------- datafusion/substrait/src/variation_const.rs | 6 +- .../tests/cases/roundtrip_logical_plan.rs | 22 +-- 4 files changed, 153 insertions(+), 171 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 43263196bb84..2aaf8ec0aa06 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -42,17 +42,18 @@ use crate::variation_const::{ DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, - INTERVAL_MONTH_DAY_NANO_TYPE_NAME, LARGE_CONTAINER_TYPE_VARIATION_REF, - UNSIGNED_INTEGER_TYPE_VARIATION_REF, VIEW_CONTAINER_TYPE_VARIATION_REF, + LARGE_CONTAINER_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, + VIEW_CONTAINER_TYPE_VARIATION_REF, }; #[allow(deprecated)] use crate::variation_const::{ - INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF, - INTERVAL_YEAR_MONTH_TYPE_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF, - TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF, - TIMESTAMP_SECOND_TYPE_VARIATION_REF, + INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_NAME, + INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF, + TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF, + TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF, }; use datafusion::arrow::array::{new_empty_array, AsArray}; +use datafusion::arrow::temporal_conversions::NANOSECONDS; use datafusion::common::scalar::ScalarStructBuilder; use datafusion::dataframe::DataFrame; use datafusion::logical_expr::expr::InList; @@ -71,10 +72,10 @@ use datafusion::{ use std::collections::HashSet; use std::sync::Arc; use substrait::proto::exchange_rel::ExchangeKind; -use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode; use substrait::proto::expression::literal::user_defined::Val; use substrait::proto::expression::literal::{ - IntervalDayToSecond, IntervalYearToMonth, UserDefined, + interval_day_to_second, IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, + UserDefined, }; use substrait::proto::expression::subquery::SubqueryType; use substrait::proto::expression::{FieldReference, Literal, ScalarFunction}; @@ -1845,9 +1846,14 @@ fn from_substrait_type( Ok(DataType::Interval(IntervalUnit::YearMonth)) } r#type::Kind::IntervalDay(_) => Ok(DataType::Interval(IntervalUnit::DayTime)), + r#type::Kind::IntervalCompound(_) => { + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } r#type::Kind::UserDefined(u) => { if let Some(name) = extensions.types.get(&u.type_reference) { + #[allow(deprecated)] match name.as_ref() { + // Kept for backwards compatibility, producers should use IntervalCompound instead INTERVAL_MONTH_DAY_NANO_TYPE_NAME => Ok(DataType::Interval(IntervalUnit::MonthDayNano)), _ => not_impl_err!( "Unsupported Substrait user defined type with ref {} and variation {}", @@ -1856,18 +1862,17 @@ fn from_substrait_type( ), } } else { - // Kept for backwards compatibility, new plans should include the extension instead #[allow(deprecated)] match u.type_reference { - // Kept for backwards compatibility, use IntervalYear instead + // Kept for backwards compatibility, producers should use IntervalYear instead INTERVAL_YEAR_MONTH_TYPE_REF => { Ok(DataType::Interval(IntervalUnit::YearMonth)) } - // Kept for backwards compatibility, use IntervalDay instead + // Kept for backwards compatibility, producers should use IntervalDay instead INTERVAL_DAY_TIME_TYPE_REF => { Ok(DataType::Interval(IntervalUnit::DayTime)) } - // Not supported yet by Substrait + // Kept for backwards compatibility, producers should use IntervalCompound instead INTERVAL_MONTH_DAY_NANO_TYPE_REF => { Ok(DataType::Interval(IntervalUnit::MonthDayNano)) } @@ -2285,6 +2290,7 @@ fn from_substrait_literal( subseconds, precision_mode, })) => { + use interval_day_to_second::PrecisionMode; // DF only supports millisecond precision, so for any more granular type we lose precision let milliseconds = match precision_mode { Some(PrecisionMode::Microseconds(ms)) => ms / 1000, @@ -2309,6 +2315,35 @@ fn from_substrait_literal( Some(LiteralType::IntervalYearToMonth(IntervalYearToMonth { years, months })) => { ScalarValue::new_interval_ym(*years, *months) } + Some(LiteralType::IntervalCompound(IntervalCompound { + interval_year_to_month, + interval_day_to_second, + })) => match (interval_year_to_month, interval_day_to_second) { + ( + Some(IntervalYearToMonth { years, months }), + Some(IntervalDayToSecond { + days, + seconds, + subseconds, + precision_mode: + Some(interval_day_to_second::PrecisionMode::Precision(p)), + }), + ) => { + if *p < 0 || *p > 9 { + return plan_err!( + "Unsupported Substrait interval day to second precision: {}", + p + ); + } + let nanos = *subseconds * i64::pow(10, (9 - p) as u32); + ScalarValue::new_interval_mdn( + *years * 12 + months, + *days, + *seconds as i64 * NANOSECONDS + nanos, + ) + } + _ => return plan_err!("Substrait compound interval missing components"), + }, Some(LiteralType::FixedChar(c)) => ScalarValue::Utf8(Some(c.clone())), Some(LiteralType::UserDefined(user_defined)) => { // Helper function to prevent duplicating this code - can be inlined once the non-extension path is removed @@ -2339,6 +2374,8 @@ fn from_substrait_literal( if let Some(name) = extensions.types.get(&user_defined.type_reference) { match name.as_ref() { + // Kept for backwards compatibility - producers should use IntervalCompound instead + #[allow(deprecated)] INTERVAL_MONTH_DAY_NANO_TYPE_NAME => { interval_month_day_nano(user_defined)? } @@ -2351,10 +2388,9 @@ fn from_substrait_literal( } } } else { - // Kept for backwards compatibility - new plans should include extension instead #[allow(deprecated)] match user_defined.type_reference { - // Kept for backwards compatibility, use IntervalYearToMonth instead + // Kept for backwards compatibility, producers should useIntervalYearToMonth instead INTERVAL_YEAR_MONTH_TYPE_REF => { let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { return substrait_err!("Interval year month value is empty"); @@ -2369,7 +2405,7 @@ fn from_substrait_literal( value_slice, ))) } - // Kept for backwards compatibility, use IntervalDayToSecond instead + // Kept for backwards compatibility, producers should useIntervalDayToSecond instead INTERVAL_DAY_TIME_TYPE_REF => { let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { return substrait_err!("Interval day time value is empty"); @@ -2389,6 +2425,7 @@ fn from_substrait_literal( milliseconds, })) } + // Kept for backwards compatibility, producers should useIntervalCompound instead INTERVAL_MONTH_DAY_NANO_TYPE_REF => { interval_month_day_nano(user_defined)? } @@ -2738,3 +2775,52 @@ impl BuiltinExprBuilder { })) } } + +#[cfg(test)] +mod test { + use crate::extensions::Extensions; + use crate::logical_plan::consumer::from_substrait_literal_without_names; + use arrow_buffer::IntervalMonthDayNano; + use datafusion::error::Result; + use datafusion::scalar::ScalarValue; + use substrait::proto::expression::literal::{ + interval_day_to_second, IntervalCompound, IntervalDayToSecond, + IntervalYearToMonth, LiteralType, + }; + use substrait::proto::expression::Literal; + + #[test] + fn interval_compound_different_precision() -> Result<()> { + // DF producer (and thus roundtrip) always uses precision = 9, + // this test exists to test with some other value. + let substrait = Literal { + nullable: false, + type_variation_reference: 0, + literal_type: Some(LiteralType::IntervalCompound(IntervalCompound { + interval_year_to_month: Some(IntervalYearToMonth { + years: 1, + months: 2, + }), + interval_day_to_second: Some(IntervalDayToSecond { + days: 3, + seconds: 4, + subseconds: 5, + precision_mode: Some( + interval_day_to_second::PrecisionMode::Precision(6), + ), + }), + })), + }; + + assert_eq!( + from_substrait_literal_without_names(&substrait, &Extensions::default())?, + ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { + months: 14, + days: 3, + nanoseconds: 4_000_005_000 + })) + ); + + Ok(()) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 17ed41f016bd..408885f70687 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -21,7 +21,6 @@ use datafusion::optimizer::AnalyzerRule; use std::sync::Arc; use substrait::proto::expression_reference::ExprType; -use arrow_buffer::ToByteSlice; use datafusion::arrow::datatypes::{Field, IntervalUnit}; use datafusion::logical_expr::{ Distinct, FetchType, Like, Partitioning, SkipType, WindowFrameUnits, @@ -39,10 +38,11 @@ use crate::variation_const::{ DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, - INTERVAL_MONTH_DAY_NANO_TYPE_NAME, LARGE_CONTAINER_TYPE_VARIATION_REF, - UNSIGNED_INTEGER_TYPE_VARIATION_REF, VIEW_CONTAINER_TYPE_VARIATION_REF, + LARGE_CONTAINER_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, + VIEW_CONTAINER_TYPE_VARIATION_REF, }; use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait}; +use datafusion::arrow::temporal_conversions::NANOSECONDS; use datafusion::common::{ exec_err, internal_err, not_impl_err, plan_err, substrait_datafusion_err, substrait_err, DFSchemaRef, ToDFSchema, @@ -58,8 +58,8 @@ use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode; use substrait::proto::expression::literal::map::KeyValue; use substrait::proto::expression::literal::{ - user_defined, IntervalDayToSecond, IntervalYearToMonth, List, Map, - PrecisionTimestamp, Struct, UserDefined, + IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, List, Map, + PrecisionTimestamp, Struct, }; use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; @@ -114,7 +114,7 @@ pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result>>()?; - let substrait_schema = to_substrait_named_struct(schema, &mut extensions)?; + let substrait_schema = to_substrait_named_struct(schema)?; Ok(Box::new(ExtendedExpression { advanced_extensions: None, @@ -203,7 +203,7 @@ pub fn to_substrait_rel( }); let table_schema = scan.source.schema().to_dfschema_ref()?; - let base_schema = to_substrait_named_struct(&table_schema, extensions)?; + let base_schema = to_substrait_named_struct(&table_schema)?; Ok(Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { @@ -229,7 +229,7 @@ pub fn to_substrait_rel( Ok(Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { common: None, - base_schema: Some(to_substrait_named_struct(&e.schema, extensions)?), + base_schema: Some(to_substrait_named_struct(&e.schema)?), filter: None, best_effort_filter: None, projection: None, @@ -268,7 +268,7 @@ pub fn to_substrait_rel( Ok(Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { common: None, - base_schema: Some(to_substrait_named_struct(&v.schema, extensions)?), + base_schema: Some(to_substrait_named_struct(&v.schema)?), filter: None, best_effort_filter: None, projection: None, @@ -664,10 +664,7 @@ fn flatten_names(field: &Field, skip_self: bool, names: &mut Vec) -> Res Ok(()) } -fn to_substrait_named_struct( - schema: &DFSchemaRef, - extensions: &mut Extensions, -) -> Result { +fn to_substrait_named_struct(schema: &DFSchemaRef) -> Result { let mut names = Vec::with_capacity(schema.fields().len()); for field in schema.fields() { flatten_names(field, false, &mut names)?; @@ -677,7 +674,7 @@ fn to_substrait_named_struct( types: schema .fields() .iter() - .map(|f| to_substrait_type(f.data_type(), f.is_nullable(), extensions)) + .map(|f| to_substrait_type(f.data_type(), f.is_nullable())) .collect::>()?, type_variation_reference: DEFAULT_TYPE_VARIATION_REF, nullability: r#type::Nullability::Unspecified as i32, @@ -1150,7 +1147,7 @@ pub fn to_substrait_rex( Ok(Expression { rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type, true, extensions)?), + r#type: Some(to_substrait_type(data_type, true)?), input: Some(Box::new(to_substrait_rex( ctx, expr, @@ -1356,11 +1353,7 @@ pub fn to_substrait_rex( } } -fn to_substrait_type( - dt: &DataType, - nullable: bool, - extensions: &mut Extensions, -) -> Result { +fn to_substrait_type(dt: &DataType, nullable: bool) -> Result { let nullability = if nullable { r#type::Nullability::Nullable as i32 } else { @@ -1489,16 +1482,14 @@ fn to_substrait_type( })), }), IntervalUnit::MonthDayNano => { - // Substrait doesn't currently support this type, so we represent it as a UDT Ok(substrait::proto::Type { - kind: Some(r#type::Kind::UserDefined(r#type::UserDefined { - type_reference: extensions.register_type( - INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string(), - ), - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - type_parameters: vec![], - })), + kind: Some(r#type::Kind::IntervalCompound( + r#type::IntervalCompound { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + precision: 9, // nanos + }, + )), }) } } @@ -1547,8 +1538,7 @@ fn to_substrait_type( })), }), DataType::List(inner) => { - let inner_type = - to_substrait_type(inner.data_type(), inner.is_nullable(), extensions)?; + let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::List(Box::new(r#type::List { r#type: Some(Box::new(inner_type)), @@ -1558,8 +1548,7 @@ fn to_substrait_type( }) } DataType::LargeList(inner) => { - let inner_type = - to_substrait_type(inner.data_type(), inner.is_nullable(), extensions)?; + let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::List(Box::new(r#type::List { r#type: Some(Box::new(inner_type)), @@ -1573,12 +1562,10 @@ fn to_substrait_type( let key_type = to_substrait_type( key_and_value[0].data_type(), key_and_value[0].is_nullable(), - extensions, )?; let value_type = to_substrait_type( key_and_value[1].data_type(), key_and_value[1].is_nullable(), - extensions, )?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Map(Box::new(r#type::Map { @@ -1594,9 +1581,7 @@ fn to_substrait_type( DataType::Struct(fields) => { let field_types = fields .iter() - .map(|field| { - to_substrait_type(field.data_type(), field.is_nullable(), extensions) - }) + .map(|field| to_substrait_type(field.data_type(), field.is_nullable())) .collect::>>()?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Struct(r#type::Struct { @@ -1783,7 +1768,6 @@ fn to_substrait_literal( literal_type: Some(LiteralType::Null(to_substrait_type( &value.data_type(), true, - extensions, )?)), }); } @@ -1892,23 +1876,21 @@ fn to_substrait_literal( }), DEFAULT_TYPE_VARIATION_REF, ), - ScalarValue::IntervalMonthDayNano(Some(i)) => { - // IntervalMonthDayNano is internally represented as a 128-bit integer, containing - // months (32bit), days (32bit), and nanoseconds (64bit) - let bytes = i.to_byte_slice(); - ( - LiteralType::UserDefined(UserDefined { - type_reference: extensions - .register_type(INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string()), - type_parameters: vec![], - val: Some(user_defined::Val::Value(ProtoAny { - type_url: INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string(), - value: bytes.to_vec().into(), - })), + ScalarValue::IntervalMonthDayNano(Some(i)) => ( + LiteralType::IntervalCompound(IntervalCompound { + interval_year_to_month: Some(IntervalYearToMonth { + years: i.months / 12, + months: i.months % 12, }), - DEFAULT_TYPE_VARIATION_REF, - ) - } + interval_day_to_second: Some(IntervalDayToSecond { + days: i.days, + seconds: (i.nanoseconds / NANOSECONDS) as i32, + subseconds: i.nanoseconds % NANOSECONDS, + precision_mode: Some(PrecisionMode::Precision(9)), // nanoseconds + }), + }), + DEFAULT_TYPE_VARIATION_REF, + ), ScalarValue::IntervalDayTime(Some(i)) => ( LiteralType::IntervalDayToSecond(IntervalDayToSecond { days: i.days, @@ -1964,7 +1946,7 @@ fn to_substrait_literal( ), ScalarValue::Map(m) => { let map = if m.is_empty() || m.value(0).is_empty() { - let mt = to_substrait_type(m.data_type(), m.is_nullable(), extensions)?; + let mt = to_substrait_type(m.data_type(), m.is_nullable())?; let mt = match mt { substrait::proto::Type { kind: Some(r#type::Kind::Map(mt)), @@ -2049,11 +2031,7 @@ fn convert_array_to_literal_list( .collect::>>()?; if values.is_empty() { - let lt = match to_substrait_type( - array.data_type(), - array.is_nullable(), - extensions, - )? { + let lt = match to_substrait_type(array.data_type(), array.is_nullable())? { substrait::proto::Type { kind: Some(r#type::Kind::List(lt)), } => lt.as_ref().to_owned(), @@ -2179,7 +2157,6 @@ mod test { use datafusion::arrow::datatypes::{Field, Fields, Schema}; use datafusion::common::scalar::ScalarStructBuilder; use datafusion::common::DFSchema; - use std::collections::HashMap; #[test] fn round_trip_literals() -> Result<()> { @@ -2310,39 +2287,6 @@ mod test { Ok(()) } - #[test] - fn custom_type_literal_extensions() -> Result<()> { - let mut extensions = Extensions::default(); - // IntervalMonthDayNano is represented as a custom type in Substrait - let scalar = ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano::new( - 17, 25, 1234567890, - ))); - let substrait_literal = to_substrait_literal(&scalar, &mut extensions)?; - let roundtrip_scalar = - from_substrait_literal_without_names(&substrait_literal, &extensions)?; - assert_eq!(scalar, roundtrip_scalar); - - assert_eq!( - extensions, - Extensions { - functions: HashMap::new(), - types: HashMap::from([( - 0, - INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string() - )]), - type_variations: HashMap::new(), - } - ); - - // Check we fail if we don't propagate extensions - assert!(from_substrait_literal_without_names( - &substrait_literal, - &Extensions::default() - ) - .is_err()); - Ok(()) - } - #[test] fn round_trip_types() -> Result<()> { round_trip_type(DataType::Boolean)?; @@ -2414,50 +2358,17 @@ mod test { fn round_trip_type(dt: DataType) -> Result<()> { println!("Checking round trip of {dt:?}"); - let mut extensions = Extensions::default(); - // As DataFusion doesn't consider nullability as a property of the type, but field, // it doesn't matter if we set nullability to true or false here. - let substrait = to_substrait_type(&dt, true, &mut extensions)?; - let roundtrip_dt = from_substrait_type_without_names(&substrait, &extensions)?; + let substrait = to_substrait_type(&dt, true)?; + let roundtrip_dt = + from_substrait_type_without_names(&substrait, &Extensions::default())?; assert_eq!(dt, roundtrip_dt); Ok(()) } - #[test] - fn custom_type_extensions() -> Result<()> { - let mut extensions = Extensions::default(); - // IntervalMonthDayNano is represented as a custom type in Substrait - let dt = DataType::Interval(IntervalUnit::MonthDayNano); - - let substrait = to_substrait_type(&dt, true, &mut extensions)?; - let roundtrip_dt = from_substrait_type_without_names(&substrait, &extensions)?; - assert_eq!(dt, roundtrip_dt); - - assert_eq!( - extensions, - Extensions { - functions: HashMap::new(), - types: HashMap::from([( - 0, - INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string() - )]), - type_variations: HashMap::new(), - } - ); - - // Check we fail if we don't propagate extensions - assert!( - from_substrait_type_without_names(&substrait, &Extensions::default()) - .is_err() - ); - - Ok(()) - } - #[test] fn named_struct_names() -> Result<()> { - let mut extensions = Extensions::default(); let schema = DFSchemaRef::new(DFSchema::try_from(Schema::new(vec![ Field::new("int", DataType::Int32, true), Field::new( @@ -2472,7 +2383,7 @@ mod test { Field::new("trailer", DataType::Float64, true), ]))?); - let named_struct = to_substrait_named_struct(&schema, &mut extensions)?; + let named_struct = to_substrait_named_struct(&schema)?; // Struct field names should be flattened DFS style // List field names should be omitted @@ -2481,7 +2392,8 @@ mod test { vec!["int", "struct", "inner", "trailer"] ); - let roundtrip_schema = from_substrait_named_struct(&named_struct, &extensions)?; + let roundtrip_schema = + from_substrait_named_struct(&named_struct, &Extensions::default())?; assert_eq!(schema.as_ref(), &roundtrip_schema); Ok(()) } diff --git a/datafusion/substrait/src/variation_const.rs b/datafusion/substrait/src/variation_const.rs index a3e76389d510..58774db424da 100644 --- a/datafusion/substrait/src/variation_const.rs +++ b/datafusion/substrait/src/variation_const.rs @@ -96,7 +96,7 @@ pub const INTERVAL_DAY_TIME_TYPE_REF: u32 = 2; /// [`ScalarValue::IntervalMonthDayNano`]: datafusion::common::ScalarValue::IntervalMonthDayNano #[deprecated( since = "41.0.0", - note = "Use Substrait `UserDefinedType` with name `INTERVAL_MONTH_DAY_NANO_TYPE_NAME` instead" + note = "Use Substrait `IntervalCompund` type instead" )] pub const INTERVAL_MONTH_DAY_NANO_TYPE_REF: u32 = 3; @@ -104,4 +104,8 @@ pub const INTERVAL_MONTH_DAY_NANO_TYPE_REF: u32 = 3; /// /// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval /// [`IntervalUnit::MonthDayNano`]: datafusion::arrow::datatypes::IntervalUnit::MonthDayNano +#[deprecated( + since = "43.0.0", + note = "Use Substrait `IntervalCompund` type instead" +)] pub const INTERVAL_MONTH_DAY_NANO_TYPE_NAME: &str = "interval-month-day-nano"; diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 8108b9ad6767..04530dd34d4b 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -39,10 +39,7 @@ use std::hash::Hash; use std::sync::Arc; use datafusion::execution::session_state::SessionStateBuilder; -use substrait::proto::extensions::simple_extension_declaration::{ - ExtensionType, MappingType, -}; -use substrait::proto::extensions::SimpleExtensionDeclaration; +use substrait::proto::extensions::simple_extension_declaration::MappingType; use substrait::proto::rel::RelType; use substrait::proto::{plan_rel, Plan, Rel}; @@ -230,23 +227,6 @@ async fn select_with_reused_functions() -> Result<()> { Ok(()) } -#[tokio::test] -async fn roundtrip_udt_extensions() -> Result<()> { - let ctx = create_context().await?; - let proto = - roundtrip_with_ctx("SELECT INTERVAL '1 YEAR 1 DAY 1 SECOND' FROM data", ctx) - .await?; - let expected_type = SimpleExtensionDeclaration { - mapping_type: Some(MappingType::ExtensionType(ExtensionType { - extension_uri_reference: u32::MAX, - type_anchor: 0, - name: "interval-month-day-nano".to_string(), - })), - }; - assert_eq!(proto.extensions, vec![expected_type]); - Ok(()) -} - #[tokio::test] async fn select_with_filter_date() -> Result<()> { roundtrip("SELECT * FROM data WHERE c > CAST('2020-01-01' AS DATE)").await From 223bb02fce886b47dc1ac81e2eda2bd3c6d60c3e Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Tue, 29 Oct 2024 14:39:48 -0400 Subject: [PATCH 109/110] docs: switch completely to generated docs for scalar and aggregate functions (#13161) * Remove _new docs, update index, update docs build script to point to main .md files for aggregate & scalar function pages. * update documentation --- dev/update_function_docs.sh | 20 +- .../user-guide/sql/aggregate_functions.md | 839 +++- .../user-guide/sql/aggregate_functions_new.md | 865 ---- docs/source/user-guide/sql/index.rst | 2 - .../source/user-guide/sql/scalar_functions.md | 4334 +++++++++++++++- .../user-guide/sql/scalar_functions_new.md | 4365 ----------------- 6 files changed, 5125 insertions(+), 5300 deletions(-) delete mode 100644 docs/source/user-guide/sql/aggregate_functions_new.md delete mode 100644 docs/source/user-guide/sql/scalar_functions_new.md diff --git a/dev/update_function_docs.sh b/dev/update_function_docs.sh index 13bc22afcc13..ad3bc9c7f69c 100755 --- a/dev/update_function_docs.sh +++ b/dev/update_function_docs.sh @@ -24,7 +24,7 @@ SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" cd "${SOURCE_DIR}/../" && pwd -TARGET_FILE="docs/source/user-guide/sql/aggregate_functions_new.md" +TARGET_FILE="docs/source/user-guide/sql/aggregate_functions.md" PRINT_AGGREGATE_FUNCTION_DOCS_COMMAND="cargo run --manifest-path datafusion/core/Cargo.toml --bin print_functions_docs -- aggregate" echo "Inserting header" @@ -56,13 +56,7 @@ update documentation for an individual UDF or the dev/update_function_docs.sh file for updating surrounding text. --> -# Aggregate Functions (NEW) - -Note: this documentation is in the process of being migrated to be [automatically created from the codebase]. -Please see the [Aggregate Functions (old)](aggregate_functions.md) page for -the rest of the documentation. - -[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740 +# Aggregate Functions Aggregate functions operate on a set of values to compute a single result. EOF @@ -75,7 +69,7 @@ npx prettier@2.3.2 --write "$TARGET_FILE" echo "'$TARGET_FILE' successfully updated!" -TARGET_FILE="docs/source/user-guide/sql/scalar_functions_new.md" +TARGET_FILE="docs/source/user-guide/sql/scalar_functions.md" PRINT_SCALAR_FUNCTION_DOCS_COMMAND="cargo run --manifest-path datafusion/core/Cargo.toml --bin print_functions_docs -- scalar" echo "Inserting header" @@ -107,13 +101,7 @@ update documentation for an individual UDF or the dev/update_function_docs.sh file for updating surrounding text. --> -# Scalar Functions (NEW) - -Note: this documentation is in the process of being migrated to be [automatically created from the codebase]. -Please see the [Scalar Functions (old)](aggregate_functions.md) page for -the rest of the documentation. - -[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740 +# Scalar Functions EOF diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index 77f527c92cda..d9fc28a81772 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -17,6 +17,843 @@ under the License. --> + + # Aggregate Functions -Note: this documentation has been migrated to [Aggregate Functions (new)](aggregate_functions_new.md) +Aggregate functions operate on a set of values to compute a single result. + +## General Functions + +- [array_agg](#array_agg) +- [avg](#avg) +- [bit_and](#bit_and) +- [bit_or](#bit_or) +- [bit_xor](#bit_xor) +- [bool_and](#bool_and) +- [bool_or](#bool_or) +- [count](#count) +- [first_value](#first_value) +- [grouping](#grouping) +- [last_value](#last_value) +- [max](#max) +- [mean](#mean) +- [median](#median) +- [min](#min) +- [string_agg](#string_agg) +- [sum](#sum) +- [var](#var) +- [var_pop](#var_pop) +- [var_population](#var_population) +- [var_samp](#var_samp) +- [var_sample](#var_sample) + +### `array_agg` + +Returns an array created from the expression elements. If ordering is required, elements are inserted in the specified order. + +``` +array_agg(expression [ORDER BY expression]) +``` + +#### Arguments + +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT array_agg(column_name ORDER BY other_column) FROM table_name; ++-----------------------------------------------+ +| array_agg(column_name ORDER BY other_column) | ++-----------------------------------------------+ +| [element1, element2, element3] | ++-----------------------------------------------+ +``` + +### `avg` + +Returns the average of numeric values in the specified column. + +``` +avg(expression) +``` + +#### Arguments + +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT avg(column_name) FROM table_name; ++---------------------------+ +| avg(column_name) | ++---------------------------+ +| 42.75 | ++---------------------------+ +``` + +#### Aliases + +- mean + +### `bit_and` + +Computes the bitwise AND of all non-null input values. + +``` +bit_and(expression) +``` + +#### Arguments + +- **expression**: Integer expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `bit_or` + +Computes the bitwise OR of all non-null input values. + +``` +bit_or(expression) +``` + +#### Arguments + +- **expression**: Integer expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `bit_xor` + +Computes the bitwise exclusive OR of all non-null input values. + +``` +bit_xor(expression) +``` + +#### Arguments + +- **expression**: Integer expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `bool_and` + +Returns true if all non-null input values are true, otherwise false. + +``` +bool_and(expression) +``` + +#### Arguments + +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT bool_and(column_name) FROM table_name; ++----------------------------+ +| bool_and(column_name) | ++----------------------------+ +| true | ++----------------------------+ +``` + +### `bool_or` + +Returns true if all non-null input values are true, otherwise false. + +``` +bool_and(expression) +``` + +#### Arguments + +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT bool_and(column_name) FROM table_name; ++----------------------------+ +| bool_and(column_name) | ++----------------------------+ +| true | ++----------------------------+ +``` + +### `count` + +Returns the number of non-null values in the specified column. To include null values in the total count, use `count(*)`. + +``` +count(expression) +``` + +#### Arguments + +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT count(column_name) FROM table_name; ++-----------------------+ +| count(column_name) | ++-----------------------+ +| 100 | ++-----------------------+ + +> SELECT count(*) FROM table_name; ++------------------+ +| count(*) | ++------------------+ +| 120 | ++------------------+ +``` + +### `first_value` + +Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group. + +``` +first_value(expression [ORDER BY expression]) +``` + +#### Arguments + +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT first_value(column_name ORDER BY other_column) FROM table_name; ++-----------------------------------------------+ +| first_value(column_name ORDER BY other_column)| ++-----------------------------------------------+ +| first_element | ++-----------------------------------------------+ +``` + +### `grouping` + +Returns 1 if the data is aggregated across the specified column, or 0 if it is not aggregated in the result set. + +``` +grouping(expression) +``` + +#### Arguments + +- **expression**: Expression to evaluate whether data is aggregated across the specified column. Can be a constant, column, or function. + +#### Example + +```sql +> SELECT column_name, GROUPING(column_name) AS group_column + FROM table_name + GROUP BY GROUPING SETS ((column_name), ()); ++-------------+-------------+ +| column_name | group_column | ++-------------+-------------+ +| value1 | 0 | +| value2 | 0 | +| NULL | 1 | ++-------------+-------------+ +``` + +### `last_value` + +Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group. + +``` +first_value(expression [ORDER BY expression]) +``` + +#### Arguments + +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT first_value(column_name ORDER BY other_column) FROM table_name; ++-----------------------------------------------+ +| first_value(column_name ORDER BY other_column)| ++-----------------------------------------------+ +| first_element | ++-----------------------------------------------+ +``` + +### `max` + +Returns the maximum value in the specified column. + +``` +max(expression) +``` + +#### Arguments + +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT max(column_name) FROM table_name; ++----------------------+ +| max(column_name) | ++----------------------+ +| 150 | ++----------------------+ +``` + +### `mean` + +_Alias of [avg](#avg)._ + +### `median` + +Returns the median value in the specified column. + +``` +median(expression) +``` + +#### Arguments + +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT median(column_name) FROM table_name; ++----------------------+ +| median(column_name) | ++----------------------+ +| 45.5 | ++----------------------+ +``` + +### `min` + +Returns the maximum value in the specified column. + +``` +max(expression) +``` + +#### Arguments + +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT max(column_name) FROM table_name; ++----------------------+ +| max(column_name) | ++----------------------+ +| 150 | ++----------------------+ +``` + +### `string_agg` + +Concatenates the values of string expressions and places separator values between them. + +``` +string_agg(expression, delimiter) +``` + +#### Arguments + +- **expression**: The string expression to concatenate. Can be a column or any valid string expression. +- **delimiter**: A literal string used as a separator between the concatenated values. + +#### Example + +```sql +> SELECT string_agg(name, ', ') AS names_list + FROM employee; ++--------------------------+ +| names_list | ++--------------------------+ +| Alice, Bob, Charlie | ++--------------------------+ +``` + +### `sum` + +Returns the sum of all values in the specified column. + +``` +sum(expression) +``` + +#### Arguments + +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT sum(column_name) FROM table_name; ++-----------------------+ +| sum(column_name) | ++-----------------------+ +| 12345 | ++-----------------------+ +``` + +### `var` + +Returns the statistical sample variance of a set of numbers. + +``` +var(expression) +``` + +#### Arguments + +- **expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Aliases + +- var_sample +- var_samp + +### `var_pop` + +Returns the statistical population variance of a set of numbers. + +``` +var_pop(expression) +``` + +#### Arguments + +- **expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Aliases + +- var_population + +### `var_population` + +_Alias of [var_pop](#var_pop)._ + +### `var_samp` + +_Alias of [var](#var)._ + +### `var_sample` + +_Alias of [var](#var)._ + +## Statistical Functions + +- [corr](#corr) +- [covar](#covar) +- [covar_pop](#covar_pop) +- [covar_samp](#covar_samp) +- [nth_value](#nth_value) +- [regr_avgx](#regr_avgx) +- [regr_avgy](#regr_avgy) +- [regr_count](#regr_count) +- [regr_intercept](#regr_intercept) +- [regr_r2](#regr_r2) +- [regr_slope](#regr_slope) +- [regr_sxx](#regr_sxx) +- [regr_sxy](#regr_sxy) +- [regr_syy](#regr_syy) +- [stddev](#stddev) +- [stddev_pop](#stddev_pop) +- [stddev_samp](#stddev_samp) + +### `corr` + +Returns the coefficient of correlation between two numeric values. + +``` +corr(expression1, expression2) +``` + +#### Arguments + +- **expression1**: First expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression2**: Second expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT corr(column1, column2) FROM table_name; ++--------------------------------+ +| corr(column1, column2) | ++--------------------------------+ +| 0.85 | ++--------------------------------+ +``` + +### `covar` + +_Alias of [covar_samp](#covar_samp)._ + +### `covar_pop` + +Returns the sample covariance of a set of number pairs. + +``` +covar_samp(expression1, expression2) +``` + +#### Arguments + +- **expression1**: First expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression2**: Second expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT covar_samp(column1, column2) FROM table_name; ++-----------------------------------+ +| covar_samp(column1, column2) | ++-----------------------------------+ +| 8.25 | ++-----------------------------------+ +``` + +### `covar_samp` + +Returns the sample covariance of a set of number pairs. + +``` +covar_samp(expression1, expression2) +``` + +#### Arguments + +- **expression1**: First expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression2**: Second expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT covar_samp(column1, column2) FROM table_name; ++-----------------------------------+ +| covar_samp(column1, column2) | ++-----------------------------------+ +| 8.25 | ++-----------------------------------+ +``` + +#### Aliases + +- covar + +### `nth_value` + +Returns the nth value in a group of values. + +``` +nth_value(expression, n ORDER BY expression) +``` + +#### Arguments + +- **expression**: The column or expression to retrieve the nth value from. +- **n**: The position (nth) of the value to retrieve, based on the ordering. + +#### Example + +```sql +> SELECT dept_id, salary, NTH_VALUE(salary, 2) OVER (PARTITION BY dept_id ORDER BY salary ASC) AS second_salary_by_dept + FROM employee; ++---------+--------+-------------------------+ +| dept_id | salary | second_salary_by_dept | ++---------+--------+-------------------------+ +| 1 | 30000 | NULL | +| 1 | 40000 | 40000 | +| 1 | 50000 | 40000 | +| 2 | 35000 | NULL | +| 2 | 45000 | 45000 | ++---------+--------+-------------------------+ +``` + +### `regr_avgx` + +Computes the average of the independent variable (input) expression_x for the non-null paired data points. + +``` +regr_avgx(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `regr_avgy` + +Computes the average of the dependent variable (output) expression_y for the non-null paired data points. + +``` +regr_avgy(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `regr_count` + +Counts the number of non-null paired data points. + +``` +regr_count(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `regr_intercept` + +Computes the y-intercept of the linear regression line. For the equation (y = kx + b), this function returns b. + +``` +regr_intercept(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `regr_r2` + +Computes the square of the correlation coefficient between the independent and dependent variables. + +``` +regr_r2(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `regr_slope` + +Returns the slope of the linear regression line for non-null pairs in aggregate columns. Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k\*X + b) using minimal RSS fitting. + +``` +regr_slope(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `regr_sxx` + +Computes the sum of squares of the independent variable. + +``` +regr_sxx(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `regr_sxy` + +Computes the sum of products of paired data points. + +``` +regr_sxy(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `regr_syy` + +Computes the sum of squares of the dependent variable. + +``` +regr_syy(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `stddev` + +Returns the standard deviation of a set of numbers. + +``` +stddev(expression) +``` + +#### Arguments + +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT stddev(column_name) FROM table_name; ++----------------------+ +| stddev(column_name) | ++----------------------+ +| 12.34 | ++----------------------+ +``` + +#### Aliases + +- stddev_samp + +### `stddev_pop` + +Returns the standard deviation of a set of numbers. + +``` +stddev(expression) +``` + +#### Arguments + +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT stddev(column_name) FROM table_name; ++----------------------+ +| stddev(column_name) | ++----------------------+ +| 12.34 | ++----------------------+ +``` + +### `stddev_samp` + +_Alias of [stddev](#stddev)._ + +## Approximate Functions + +- [approx_distinct](#approx_distinct) +- [approx_median](#approx_median) +- [approx_percentile_cont](#approx_percentile_cont) +- [approx_percentile_cont_with_weight](#approx_percentile_cont_with_weight) + +### `approx_distinct` + +Returns the approximate number of distinct input values calculated using the HyperLogLog algorithm. + +``` +approx_distinct(expression) +``` + +#### Arguments + +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT approx_distinct(column_name) FROM table_name; ++-----------------------------------+ +| approx_distinct(column_name) | ++-----------------------------------+ +| 42 | ++-----------------------------------+ +``` + +### `approx_median` + +Returns the approximate median (50th percentile) of input values. It is an alias of `approx_percentile_cont(x, 0.5)`. + +``` +approx_median(expression) +``` + +#### Arguments + +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT approx_median(column_name) FROM table_name; ++-----------------------------------+ +| approx_median(column_name) | ++-----------------------------------+ +| 23.5 | ++-----------------------------------+ +``` + +### `approx_percentile_cont` + +Returns the approximate percentile of input values using the t-digest algorithm. + +``` +approx_percentile_cont(expression, percentile, centroids) +``` + +#### Arguments + +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **percentile**: Percentile to compute. Must be a float value between 0 and 1 (inclusive). +- **centroids**: Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory. + +#### Example + +```sql +> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name; ++-------------------------------------------------+ +| approx_percentile_cont(column_name, 0.75, 100) | ++-------------------------------------------------+ +| 65.0 | ++-------------------------------------------------+ +``` + +### `approx_percentile_cont_with_weight` + +Returns the weighted approximate percentile of input values using the t-digest algorithm. + +``` +approx_percentile_cont_with_weight(expression, weight, percentile) +``` + +#### Arguments + +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **weight**: Expression to use as weight. Can be a constant, column, or function, and any combination of arithmetic operators. +- **percentile**: Percentile to compute. Must be a float value between 0 and 1 (inclusive). + +#### Example + +```sql +> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name; ++----------------------------------------------------------------------+ +| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) | ++----------------------------------------------------------------------+ +| 78.5 | ++----------------------------------------------------------------------+ +``` diff --git a/docs/source/user-guide/sql/aggregate_functions_new.md b/docs/source/user-guide/sql/aggregate_functions_new.md deleted file mode 100644 index ad6d15b94ee5..000000000000 --- a/docs/source/user-guide/sql/aggregate_functions_new.md +++ /dev/null @@ -1,865 +0,0 @@ - - - - -# Aggregate Functions (NEW) - -Note: this documentation is in the process of being migrated to be [automatically created from the codebase]. -Please see the [Aggregate Functions (old)](aggregate_functions.md) page for -the rest of the documentation. - -[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740 - -Aggregate functions operate on a set of values to compute a single result. - -## General Functions - -- [array_agg](#array_agg) -- [avg](#avg) -- [bit_and](#bit_and) -- [bit_or](#bit_or) -- [bit_xor](#bit_xor) -- [bool_and](#bool_and) -- [bool_or](#bool_or) -- [count](#count) -- [first_value](#first_value) -- [grouping](#grouping) -- [last_value](#last_value) -- [max](#max) -- [mean](#mean) -- [median](#median) -- [min](#min) -- [string_agg](#string_agg) -- [sum](#sum) -- [var](#var) -- [var_pop](#var_pop) -- [var_population](#var_population) -- [var_samp](#var_samp) -- [var_sample](#var_sample) - -### `array_agg` - -Returns an array created from the expression elements. If ordering is required, elements are inserted in the specified order. - -``` -array_agg(expression [ORDER BY expression]) -``` - -#### Arguments - -- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> SELECT array_agg(column_name ORDER BY other_column) FROM table_name; -+-----------------------------------------------+ -| array_agg(column_name ORDER BY other_column) | -+-----------------------------------------------+ -| [element1, element2, element3] | -+-----------------------------------------------+ -``` - -### `avg` - -Returns the average of numeric values in the specified column. - -``` -avg(expression) -``` - -#### Arguments - -- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> SELECT avg(column_name) FROM table_name; -+---------------------------+ -| avg(column_name) | -+---------------------------+ -| 42.75 | -+---------------------------+ -``` - -#### Aliases - -- mean - -### `bit_and` - -Computes the bitwise AND of all non-null input values. - -``` -bit_and(expression) -``` - -#### Arguments - -- **expression**: Integer expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `bit_or` - -Computes the bitwise OR of all non-null input values. - -``` -bit_or(expression) -``` - -#### Arguments - -- **expression**: Integer expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `bit_xor` - -Computes the bitwise exclusive OR of all non-null input values. - -``` -bit_xor(expression) -``` - -#### Arguments - -- **expression**: Integer expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `bool_and` - -Returns true if all non-null input values are true, otherwise false. - -``` -bool_and(expression) -``` - -#### Arguments - -- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> SELECT bool_and(column_name) FROM table_name; -+----------------------------+ -| bool_and(column_name) | -+----------------------------+ -| true | -+----------------------------+ -``` - -### `bool_or` - -Returns true if all non-null input values are true, otherwise false. - -``` -bool_and(expression) -``` - -#### Arguments - -- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> SELECT bool_and(column_name) FROM table_name; -+----------------------------+ -| bool_and(column_name) | -+----------------------------+ -| true | -+----------------------------+ -``` - -### `count` - -Returns the number of non-null values in the specified column. To include null values in the total count, use `count(*)`. - -``` -count(expression) -``` - -#### Arguments - -- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> SELECT count(column_name) FROM table_name; -+-----------------------+ -| count(column_name) | -+-----------------------+ -| 100 | -+-----------------------+ - -> SELECT count(*) FROM table_name; -+------------------+ -| count(*) | -+------------------+ -| 120 | -+------------------+ -``` - -### `first_value` - -Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group. - -``` -first_value(expression [ORDER BY expression]) -``` - -#### Arguments - -- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> SELECT first_value(column_name ORDER BY other_column) FROM table_name; -+-----------------------------------------------+ -| first_value(column_name ORDER BY other_column)| -+-----------------------------------------------+ -| first_element | -+-----------------------------------------------+ -``` - -### `grouping` - -Returns 1 if the data is aggregated across the specified column, or 0 if it is not aggregated in the result set. - -``` -grouping(expression) -``` - -#### Arguments - -- **expression**: Expression to evaluate whether data is aggregated across the specified column. Can be a constant, column, or function. - -#### Example - -```sql -> SELECT column_name, GROUPING(column_name) AS group_column - FROM table_name - GROUP BY GROUPING SETS ((column_name), ()); -+-------------+-------------+ -| column_name | group_column | -+-------------+-------------+ -| value1 | 0 | -| value2 | 0 | -| NULL | 1 | -+-------------+-------------+ -``` - -### `last_value` - -Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group. - -``` -first_value(expression [ORDER BY expression]) -``` - -#### Arguments - -- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> SELECT first_value(column_name ORDER BY other_column) FROM table_name; -+-----------------------------------------------+ -| first_value(column_name ORDER BY other_column)| -+-----------------------------------------------+ -| first_element | -+-----------------------------------------------+ -``` - -### `max` - -Returns the maximum value in the specified column. - -``` -max(expression) -``` - -#### Arguments - -- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> SELECT max(column_name) FROM table_name; -+----------------------+ -| max(column_name) | -+----------------------+ -| 150 | -+----------------------+ -``` - -### `mean` - -_Alias of [avg](#avg)._ - -### `median` - -Returns the median value in the specified column. - -``` -median(expression) -``` - -#### Arguments - -- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> SELECT median(column_name) FROM table_name; -+----------------------+ -| median(column_name) | -+----------------------+ -| 45.5 | -+----------------------+ -``` - -### `min` - -Returns the maximum value in the specified column. - -``` -max(expression) -``` - -#### Arguments - -- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> SELECT max(column_name) FROM table_name; -+----------------------+ -| max(column_name) | -+----------------------+ -| 150 | -+----------------------+ -``` - -### `string_agg` - -Concatenates the values of string expressions and places separator values between them. - -``` -string_agg(expression, delimiter) -``` - -#### Arguments - -- **expression**: The string expression to concatenate. Can be a column or any valid string expression. -- **delimiter**: A literal string used as a separator between the concatenated values. - -#### Example - -```sql -> SELECT string_agg(name, ', ') AS names_list - FROM employee; -+--------------------------+ -| names_list | -+--------------------------+ -| Alice, Bob, Charlie | -+--------------------------+ -``` - -### `sum` - -Returns the sum of all values in the specified column. - -``` -sum(expression) -``` - -#### Arguments - -- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> SELECT sum(column_name) FROM table_name; -+-----------------------+ -| sum(column_name) | -+-----------------------+ -| 12345 | -+-----------------------+ -``` - -### `var` - -Returns the statistical sample variance of a set of numbers. - -``` -var(expression) -``` - -#### Arguments - -- **expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Aliases - -- var_sample -- var_samp - -### `var_pop` - -Returns the statistical population variance of a set of numbers. - -``` -var_pop(expression) -``` - -#### Arguments - -- **expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Aliases - -- var_population - -### `var_population` - -_Alias of [var_pop](#var_pop)._ - -### `var_samp` - -_Alias of [var](#var)._ - -### `var_sample` - -_Alias of [var](#var)._ - -## Statistical Functions - -- [corr](#corr) -- [covar](#covar) -- [covar_pop](#covar_pop) -- [covar_samp](#covar_samp) -- [nth_value](#nth_value) -- [regr_avgx](#regr_avgx) -- [regr_avgy](#regr_avgy) -- [regr_count](#regr_count) -- [regr_intercept](#regr_intercept) -- [regr_r2](#regr_r2) -- [regr_slope](#regr_slope) -- [regr_sxx](#regr_sxx) -- [regr_sxy](#regr_sxy) -- [regr_syy](#regr_syy) -- [stddev](#stddev) -- [stddev_pop](#stddev_pop) -- [stddev_samp](#stddev_samp) - -### `corr` - -Returns the coefficient of correlation between two numeric values. - -``` -corr(expression1, expression2) -``` - -#### Arguments - -- **expression1**: First expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **expression2**: Second expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> SELECT corr(column1, column2) FROM table_name; -+--------------------------------+ -| corr(column1, column2) | -+--------------------------------+ -| 0.85 | -+--------------------------------+ -``` - -### `covar` - -_Alias of [covar_samp](#covar_samp)._ - -### `covar_pop` - -Returns the sample covariance of a set of number pairs. - -``` -covar_samp(expression1, expression2) -``` - -#### Arguments - -- **expression1**: First expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **expression2**: Second expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> SELECT covar_samp(column1, column2) FROM table_name; -+-----------------------------------+ -| covar_samp(column1, column2) | -+-----------------------------------+ -| 8.25 | -+-----------------------------------+ -``` - -### `covar_samp` - -Returns the sample covariance of a set of number pairs. - -``` -covar_samp(expression1, expression2) -``` - -#### Arguments - -- **expression1**: First expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **expression2**: Second expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> SELECT covar_samp(column1, column2) FROM table_name; -+-----------------------------------+ -| covar_samp(column1, column2) | -+-----------------------------------+ -| 8.25 | -+-----------------------------------+ -``` - -#### Aliases - -- covar - -### `nth_value` - -Returns the nth value in a group of values. - -``` -nth_value(expression, n ORDER BY expression) -``` - -#### Arguments - -- **expression**: The column or expression to retrieve the nth value from. -- **n**: The position (nth) of the value to retrieve, based on the ordering. - -#### Example - -```sql -> SELECT dept_id, salary, NTH_VALUE(salary, 2) OVER (PARTITION BY dept_id ORDER BY salary ASC) AS second_salary_by_dept - FROM employee; -+---------+--------+-------------------------+ -| dept_id | salary | second_salary_by_dept | -+---------+--------+-------------------------+ -| 1 | 30000 | NULL | -| 1 | 40000 | 40000 | -| 1 | 50000 | 40000 | -| 2 | 35000 | NULL | -| 2 | 45000 | 45000 | -+---------+--------+-------------------------+ -``` - -### `regr_avgx` - -Computes the average of the independent variable (input) expression_x for the non-null paired data points. - -``` -regr_avgx(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `regr_avgy` - -Computes the average of the dependent variable (output) expression_y for the non-null paired data points. - -``` -regr_avgy(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `regr_count` - -Counts the number of non-null paired data points. - -``` -regr_count(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `regr_intercept` - -Computes the y-intercept of the linear regression line. For the equation (y = kx + b), this function returns b. - -``` -regr_intercept(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `regr_r2` - -Computes the square of the correlation coefficient between the independent and dependent variables. - -``` -regr_r2(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `regr_slope` - -Returns the slope of the linear regression line for non-null pairs in aggregate columns. Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k\*X + b) using minimal RSS fitting. - -``` -regr_slope(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `regr_sxx` - -Computes the sum of squares of the independent variable. - -``` -regr_sxx(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `regr_sxy` - -Computes the sum of products of paired data points. - -``` -regr_sxy(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `regr_syy` - -Computes the sum of squares of the dependent variable. - -``` -regr_syy(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `stddev` - -Returns the standard deviation of a set of numbers. - -``` -stddev(expression) -``` - -#### Arguments - -- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> SELECT stddev(column_name) FROM table_name; -+----------------------+ -| stddev(column_name) | -+----------------------+ -| 12.34 | -+----------------------+ -``` - -#### Aliases - -- stddev_samp - -### `stddev_pop` - -Returns the standard deviation of a set of numbers. - -``` -stddev(expression) -``` - -#### Arguments - -- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> SELECT stddev(column_name) FROM table_name; -+----------------------+ -| stddev(column_name) | -+----------------------+ -| 12.34 | -+----------------------+ -``` - -### `stddev_samp` - -_Alias of [stddev](#stddev)._ - -## Approximate Functions - -- [approx_distinct](#approx_distinct) -- [approx_median](#approx_median) -- [approx_percentile_cont](#approx_percentile_cont) -- [approx_percentile_cont_with_weight](#approx_percentile_cont_with_weight) - -### `approx_distinct` - -Returns the approximate number of distinct input values calculated using the HyperLogLog algorithm. - -``` -approx_distinct(expression) -``` - -#### Arguments - -- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> SELECT approx_distinct(column_name) FROM table_name; -+-----------------------------------+ -| approx_distinct(column_name) | -+-----------------------------------+ -| 42 | -+-----------------------------------+ -``` - -### `approx_median` - -Returns the approximate median (50th percentile) of input values. It is an alias of `approx_percentile_cont(x, 0.5)`. - -``` -approx_median(expression) -``` - -#### Arguments - -- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> SELECT approx_median(column_name) FROM table_name; -+-----------------------------------+ -| approx_median(column_name) | -+-----------------------------------+ -| 23.5 | -+-----------------------------------+ -``` - -### `approx_percentile_cont` - -Returns the approximate percentile of input values using the t-digest algorithm. - -``` -approx_percentile_cont(expression, percentile, centroids) -``` - -#### Arguments - -- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **percentile**: Percentile to compute. Must be a float value between 0 and 1 (inclusive). -- **centroids**: Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory. - -#### Example - -```sql -> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name; -+-------------------------------------------------+ -| approx_percentile_cont(column_name, 0.75, 100) | -+-------------------------------------------------+ -| 65.0 | -+-------------------------------------------------+ -``` - -### `approx_percentile_cont_with_weight` - -Returns the weighted approximate percentile of input values using the t-digest algorithm. - -``` -approx_percentile_cont_with_weight(expression, weight, percentile) -``` - -#### Arguments - -- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **weight**: Expression to use as weight. Can be a constant, column, or function, and any combination of arithmetic operators. -- **percentile**: Percentile to compute. Must be a float value between 0 and 1 (inclusive). - -#### Example - -```sql -> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name; -+----------------------------------------------------------------------+ -| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) | -+----------------------------------------------------------------------+ -| 78.5 | -+----------------------------------------------------------------------+ -``` diff --git a/docs/source/user-guide/sql/index.rst b/docs/source/user-guide/sql/index.rst index 8b8afc7b048a..4499aac53611 100644 --- a/docs/source/user-guide/sql/index.rst +++ b/docs/source/user-guide/sql/index.rst @@ -30,11 +30,9 @@ SQL Reference information_schema operators aggregate_functions - aggregate_functions_new window_functions window_functions_new scalar_functions - scalar_functions_new special_functions sql_status write_options diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index a8e25930bef7..98c44cbd981d 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -17,111 +17,4343 @@ under the License. --> + + # Scalar Functions -Scalar functions operate on a single row at a time and return a single value. +## Math Functions + +- [abs](#abs) +- [acos](#acos) +- [acosh](#acosh) +- [asin](#asin) +- [asinh](#asinh) +- [atan](#atan) +- [atan2](#atan2) +- [atanh](#atanh) +- [cbrt](#cbrt) +- [ceil](#ceil) +- [cos](#cos) +- [cosh](#cosh) +- [cot](#cot) +- [degrees](#degrees) +- [exp](#exp) +- [factorial](#factorial) +- [floor](#floor) +- [gcd](#gcd) +- [isnan](#isnan) +- [iszero](#iszero) +- [lcm](#lcm) +- [ln](#ln) +- [log](#log) +- [log10](#log10) +- [log2](#log2) +- [nanvl](#nanvl) +- [pi](#pi) +- [pow](#pow) +- [power](#power) +- [radians](#radians) +- [random](#random) +- [round](#round) +- [signum](#signum) +- [sin](#sin) +- [sinh](#sinh) +- [sqrt](#sqrt) +- [tan](#tan) +- [tanh](#tanh) +- [trunc](#trunc) + +### `abs` + +Returns the absolute value of a number. + +``` +abs(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `acos` + +Returns the arc cosine or inverse cosine of a number. + +``` +acos(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `acosh` + +Returns the area hyperbolic cosine or inverse hyperbolic cosine of a number. + +``` +acosh(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `asin` + +Returns the arc sine or inverse sine of a number. + +``` +asin(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `asinh` + +Returns the area hyperbolic sine or inverse hyperbolic sine of a number. + +``` +asinh(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `atan` + +Returns the arc tangent or inverse tangent of a number. + +``` +atan(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `atan2` + +Returns the arc tangent or inverse tangent of `expression_y / expression_x`. + +``` +atan2(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: First numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_x**: Second numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `atanh` + +Returns the area hyperbolic tangent or inverse hyperbolic tangent of a number. + +``` +atanh(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `cbrt` + +Returns the cube root of a number. + +``` +cbrt(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `ceil` + +Returns the nearest integer greater than or equal to a number. + +``` +ceil(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `cos` + +Returns the cosine of a number. + +``` +cos(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `cosh` + +Returns the hyperbolic cosine of a number. + +``` +cosh(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `cot` + +Returns the cotangent of a number. + +``` +cot(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `degrees` + +Converts radians to degrees. + +``` +degrees(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `exp` + +Returns the base-e exponential of a number. + +``` +exp(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `factorial` + +Factorial. Returns 1 if value is less than 2. + +``` +factorial(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `floor` + +Returns the nearest integer less than or equal to a number. + +``` +floor(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `gcd` + +Returns the greatest common divisor of `expression_x` and `expression_y`. Returns 0 if both inputs are zero. + +``` +gcd(expression_x, expression_y) +``` + +#### Arguments + +- **expression_x**: First numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_y**: Second numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `isnan` + +Returns true if a given number is +NaN or -NaN otherwise returns false. + +``` +isnan(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `iszero` + +Returns true if a given number is +0.0 or -0.0 otherwise returns false. + +``` +iszero(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `lcm` + +Returns the least common multiple of `expression_x` and `expression_y`. Returns 0 if either input is zero. + +``` +lcm(expression_x, expression_y) +``` + +#### Arguments + +- **expression_x**: First numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_y**: Second numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `ln` + +Returns the natural logarithm of a number. + +``` +ln(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `log` + +Returns the base-x logarithm of a number. Can either provide a specified base, or if omitted then takes the base-10 of a number. + +``` +log(base, numeric_expression) +log(numeric_expression) +``` + +#### Arguments + +- **base**: Base numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `log10` + +Returns the base-10 logarithm of a number. + +``` +log10(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `log2` + +Returns the base-2 logarithm of a number. + +``` +log2(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `nanvl` + +Returns the first argument if it's not _NaN_. +Returns the second argument otherwise. + +``` +nanvl(expression_x, expression_y) +``` + +#### Arguments + +- **expression_x**: Numeric expression to return if it's not _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_y**: Numeric expression to return if the first expression is _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators. + +### `pi` + +Returns an approximate value of π. + +``` +pi() +``` + +### `pow` + +_Alias of [power](#power)._ + +### `power` + +Returns a base expression raised to the power of an exponent. + +``` +power(base, exponent) +``` + +#### Arguments + +- **base**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **exponent**: Exponent numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Aliases + +- pow + +### `radians` + +Converts degrees to radians. + +``` +radians(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `random` + +Returns a random float value in the range [0, 1). +The random seed is unique to each row. + +``` +random() +``` + +### `round` + +Rounds a number to the nearest integer. + +``` +round(numeric_expression[, decimal_places]) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **decimal_places**: Optional. The number of decimal places to round to. Defaults to 0. + +### `signum` + +Returns the sign of a number. +Negative numbers return `-1`. +Zero and positive numbers return `1`. + +``` +signum(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `sin` + +Returns the sine of a number. + +``` +sin(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `sinh` + +Returns the hyperbolic sine of a number. + +``` +sinh(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `sqrt` + +Returns the square root of a number. + +``` +sqrt(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `tan` + +Returns the tangent of a number. + +``` +tan(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `tanh` + +Returns the hyperbolic tangent of a number. + +``` +tanh(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `trunc` + +Truncates a number to a whole number or truncated to the specified decimal places. + +``` +trunc(numeric_expression[, decimal_places]) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **decimal_places**: Optional. The number of decimal places to + truncate to. Defaults to 0 (truncate to a whole number). If + `decimal_places` is a positive integer, truncates digits to the + right of the decimal point. If `decimal_places` is a negative + integer, replaces digits to the left of the decimal point with `0`. + +## Conditional Functions + +- [coalesce](#coalesce) +- [ifnull](#ifnull) +- [nullif](#nullif) +- [nvl](#nvl) +- [nvl2](#nvl2) + +### `coalesce` + +Returns the first of its arguments that is not _null_. Returns _null_ if all arguments are _null_. This function is often used to substitute a default value for _null_ values. + +``` +coalesce(expression1[, ..., expression_n]) +``` + +#### Arguments + +- **expression1, expression_n**: Expression to use if previous expressions are _null_. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary. + +#### Example + +```sql +> select coalesce(null, null, 'datafusion'); ++----------------------------------------+ +| coalesce(NULL,NULL,Utf8("datafusion")) | ++----------------------------------------+ +| datafusion | ++----------------------------------------+ +``` + +### `ifnull` + +_Alias of [nvl](#nvl)._ + +### `nullif` + +Returns _null_ if _expression1_ equals _expression2_; otherwise it returns _expression1_. +This can be used to perform the inverse operation of [`coalesce`](#coalesce). + +``` +nullif(expression1, expression2) +``` + +#### Arguments + +- **expression1**: Expression to compare and return if equal to expression2. Can be a constant, column, or function, and any combination of operators. +- **expression2**: Expression to compare to expression1. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select nullif('datafusion', 'data'); ++-----------------------------------------+ +| nullif(Utf8("datafusion"),Utf8("data")) | ++-----------------------------------------+ +| datafusion | ++-----------------------------------------+ +> select nullif('datafusion', 'datafusion'); ++-----------------------------------------------+ +| nullif(Utf8("datafusion"),Utf8("datafusion")) | ++-----------------------------------------------+ +| | ++-----------------------------------------------+ +``` + +### `nvl` + +Returns _expression2_ if _expression1_ is NULL otherwise it returns _expression1_. + +``` +nvl(expression1, expression2) +``` + +#### Arguments + +- **expression1**: Expression to return if not null. Can be a constant, column, or function, and any combination of operators. +- **expression2**: Expression to return if expr1 is null. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select nvl(null, 'a'); ++---------------------+ +| nvl(NULL,Utf8("a")) | ++---------------------+ +| a | ++---------------------+\ +> select nvl('b', 'a'); ++--------------------------+ +| nvl(Utf8("b"),Utf8("a")) | ++--------------------------+ +| b | ++--------------------------+ +``` + +#### Aliases + +- ifnull + +### `nvl2` + +Returns _expression2_ if _expression1_ is not NULL; otherwise it returns _expression3_. + +``` +nvl2(expression1, expression2, expression3) +``` + +#### Arguments + +- **expression1**: Expression to test for null. Can be a constant, column, or function, and any combination of operators. +- **expression2**: Expression to return if expr1 is not null. Can be a constant, column, or function, and any combination of operators. +- **expression3**: Expression to return if expr1 is null. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select nvl2(null, 'a', 'b'); ++--------------------------------+ +| nvl2(NULL,Utf8("a"),Utf8("b")) | ++--------------------------------+ +| b | ++--------------------------------+ +> select nvl2('data', 'a', 'b'); ++----------------------------------------+ +| nvl2(Utf8("data"),Utf8("a"),Utf8("b")) | ++----------------------------------------+ +| a | ++----------------------------------------+ +``` + +## String Functions + +- [ascii](#ascii) +- [bit_length](#bit_length) +- [btrim](#btrim) +- [char_length](#char_length) +- [character_length](#character_length) +- [chr](#chr) +- [concat](#concat) +- [concat_ws](#concat_ws) +- [contains](#contains) +- [ends_with](#ends_with) +- [find_in_set](#find_in_set) +- [initcap](#initcap) +- [instr](#instr) +- [left](#left) +- [length](#length) +- [levenshtein](#levenshtein) +- [lower](#lower) +- [lpad](#lpad) +- [ltrim](#ltrim) +- [octet_length](#octet_length) +- [position](#position) +- [repeat](#repeat) +- [replace](#replace) +- [reverse](#reverse) +- [right](#right) +- [rpad](#rpad) +- [rtrim](#rtrim) +- [split_part](#split_part) +- [starts_with](#starts_with) +- [strpos](#strpos) +- [substr](#substr) +- [substr_index](#substr_index) +- [substring](#substring) +- [substring_index](#substring_index) +- [to_hex](#to_hex) +- [translate](#translate) +- [trim](#trim) +- [upper](#upper) +- [uuid](#uuid) + +### `ascii` + +Returns the Unicode character code of the first character in a string. + +``` +ascii(str) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select ascii('abc'); ++--------------------+ +| ascii(Utf8("abc")) | ++--------------------+ +| 97 | ++--------------------+ +> select ascii('🚀'); ++-------------------+ +| ascii(Utf8("🚀")) | ++-------------------+ +| 128640 | ++-------------------+ +``` + +**Related functions**: + +- [chr](#chr) + +### `bit_length` + +Returns the bit length of a string. + +``` +bit_length(str) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select bit_length('datafusion'); ++--------------------------------+ +| bit_length(Utf8("datafusion")) | ++--------------------------------+ +| 80 | ++--------------------------------+ +``` + +**Related functions**: + +- [length](#length) +- [octet_length](#octet_length) + +### `btrim` + +Trims the specified trim string from the start and end of a string. If no trim string is provided, all whitespace is removed from the start and end of the input string. + +``` +btrim(str[, trim_str]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **trim_str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. _Default is whitespace characters._ + +#### Example + +```sql +> select btrim('__datafusion____', '_'); ++-------------------------------------------+ +| btrim(Utf8("__datafusion____"),Utf8("_")) | ++-------------------------------------------+ +| datafusion | ++-------------------------------------------+ +``` + +#### Alternative Syntax + +```sql +trim(BOTH trim_str FROM str) +``` + +```sql +trim(trim_str FROM str) +``` + +#### Aliases + +- trim + +**Related functions**: + +- [ltrim](#ltrim) +- [rtrim](#rtrim) + +### `char_length` + +_Alias of [character_length](#character_length)._ + +### `character_length` + +Returns the number of characters in a string. + +``` +character_length(str) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select character_length('Ångström'); ++------------------------------------+ +| character_length(Utf8("Ångström")) | ++------------------------------------+ +| 8 | ++------------------------------------+ +``` + +#### Aliases + +- length +- char_length + +**Related functions**: + +- [bit_length](#bit_length) +- [octet_length](#octet_length) + +### `chr` + +Returns the character with the specified ASCII or Unicode code value. + +``` +chr(expression) +``` + +#### Arguments + +- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select chr(128640); ++--------------------+ +| chr(Int64(128640)) | ++--------------------+ +| 🚀 | ++--------------------+ +``` + +**Related functions**: + +- [ascii](#ascii) + +### `concat` + +Concatenates multiple strings together. + +``` +concat(str[, ..., str_n]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **str_n**: Subsequent string expressions to concatenate. + +#### Example + +```sql +> select concat('data', 'f', 'us', 'ion'); ++-------------------------------------------------------+ +| concat(Utf8("data"),Utf8("f"),Utf8("us"),Utf8("ion")) | ++-------------------------------------------------------+ +| datafusion | ++-------------------------------------------------------+ +``` + +**Related functions**: + +- [concat_ws](#concat_ws) + +### `concat_ws` + +Concatenates multiple strings together with a specified separator. + +``` +concat_ws(separator, str[, ..., str_n]) +``` + +#### Arguments + +- **separator**: Separator to insert between concatenated strings. +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **str_n**: Subsequent string expressions to concatenate. + +#### Example + +```sql +> select concat_ws('_', 'data', 'fusion'); ++--------------------------------------------------+ +| concat_ws(Utf8("_"),Utf8("data"),Utf8("fusion")) | ++--------------------------------------------------+ +| data_fusion | ++--------------------------------------------------+ +``` + +**Related functions**: + +- [concat](#concat) + +### `contains` + +Return true if search_str is found within string (case-sensitive). + +``` +contains(str, search_str) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **search_str**: The string to search for in str. + +#### Example + +```sql +> select contains('the quick brown fox', 'row'); ++---------------------------------------------------+ +| contains(Utf8("the quick brown fox"),Utf8("row")) | ++---------------------------------------------------+ +| true | ++---------------------------------------------------+ +``` + +### `ends_with` + +Tests if a string ends with a substring. + +``` +ends_with(str, substr) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **substr**: Substring to test for. + +#### Example + +```sql +> select ends_with('datafusion', 'soin'); ++--------------------------------------------+ +| ends_with(Utf8("datafusion"),Utf8("soin")) | ++--------------------------------------------+ +| false | ++--------------------------------------------+ +> select ends_with('datafusion', 'sion'); ++--------------------------------------------+ +| ends_with(Utf8("datafusion"),Utf8("sion")) | ++--------------------------------------------+ +| true | ++--------------------------------------------+ +``` + +### `find_in_set` + +Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings. + +``` +find_in_set(str, strlist) +``` + +#### Arguments + +- **str**: String expression to find in strlist. +- **strlist**: A string list is a string composed of substrings separated by , characters. + +#### Example + +```sql +> select find_in_set('b', 'a,b,c,d'); ++----------------------------------------+ +| find_in_set(Utf8("b"),Utf8("a,b,c,d")) | ++----------------------------------------+ +| 2 | ++----------------------------------------+ +``` + +### `initcap` + +Capitalizes the first character in each word in the input string. Words are delimited by non-alphanumeric characters. + +``` +initcap(str) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select initcap('apache datafusion'); ++------------------------------------+ +| initcap(Utf8("apache datafusion")) | ++------------------------------------+ +| Apache Datafusion | ++------------------------------------+ +``` + +**Related functions**: + +- [lower](#lower) +- [upper](#upper) + +### `instr` + +_Alias of [strpos](#strpos)._ + +### `left` + +Returns a specified number of characters from the left side of a string. + +``` +left(str, n) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **n**: Number of characters to return. + +#### Example + +```sql +> select left('datafusion', 4); ++-----------------------------------+ +| left(Utf8("datafusion"),Int64(4)) | ++-----------------------------------+ +| data | ++-----------------------------------+ +``` + +**Related functions**: + +- [right](#right) + +### `length` + +_Alias of [character_length](#character_length)._ + +### `levenshtein` + +Returns the [`Levenshtein distance`](https://en.wikipedia.org/wiki/Levenshtein_distance) between the two given strings. + +``` +levenshtein(str1, str2) +``` + +#### Arguments + +- **str1**: String expression to compute Levenshtein distance with str2. +- **str2**: String expression to compute Levenshtein distance with str1. + +#### Example + +```sql +> select levenshtein('kitten', 'sitting'); ++---------------------------------------------+ +| levenshtein(Utf8("kitten"),Utf8("sitting")) | ++---------------------------------------------+ +| 3 | ++---------------------------------------------+ +``` + +### `lower` + +Converts a string to lower-case. + +``` +lower(str) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select lower('Ångström'); ++-------------------------+ +| lower(Utf8("Ångström")) | ++-------------------------+ +| ångström | ++-------------------------+ +``` + +**Related functions**: + +- [initcap](#initcap) +- [upper](#upper) + +### `lpad` + +Pads the left side of a string with another string to a specified string length. + +``` +lpad(str, n[, padding_str]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **n**: String length to pad to. +- **padding_str**: Optional string expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._ + +#### Example + +```sql +> select lpad('Dolly', 10, 'hello'); ++---------------------------------------------+ +| lpad(Utf8("Dolly"),Int64(10),Utf8("hello")) | ++---------------------------------------------+ +| helloDolly | ++---------------------------------------------+ +``` + +**Related functions**: + +- [rpad](#rpad) + +### `ltrim` + +Trims the specified trim string from the beginning of a string. If no trim string is provided, all whitespace is removed from the start of the input string. + +``` +ltrim(str[, trim_str]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **trim_str**: String expression to trim from the beginning of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._ + +#### Example + +```sql +> select ltrim(' datafusion '); ++-------------------------------+ +| ltrim(Utf8(" datafusion ")) | ++-------------------------------+ +| datafusion | ++-------------------------------+ +> select ltrim('___datafusion___', '_'); ++-------------------------------------------+ +| ltrim(Utf8("___datafusion___"),Utf8("_")) | ++-------------------------------------------+ +| datafusion___ | ++-------------------------------------------+ +``` + +#### Alternative Syntax + +```sql +trim(LEADING trim_str FROM str) +``` + +**Related functions**: + +- [btrim](#btrim) +- [rtrim](#rtrim) + +### `octet_length` + +Returns the length of a string in bytes. + +``` +octet_length(str) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select octet_length('Ångström'); ++--------------------------------+ +| octet_length(Utf8("Ångström")) | ++--------------------------------+ +| 10 | ++--------------------------------+ +``` + +**Related functions**: + +- [bit_length](#bit_length) +- [length](#length) + +### `position` + +_Alias of [strpos](#strpos)._ + +### `repeat` + +Returns a string with an input string repeated a specified number. + +``` +repeat(str, n) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **n**: Number of times to repeat the input string. + +#### Example + +```sql +> select repeat('data', 3); ++-------------------------------+ +| repeat(Utf8("data"),Int64(3)) | ++-------------------------------+ +| datadatadata | ++-------------------------------+ +``` + +### `replace` + +Replaces all occurrences of a specified substring in a string with a new substring. + +``` +replace(str, substr, replacement) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **substr**: Substring expression to replace in the input string. Substring expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **replacement**: Replacement substring expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select replace('ABabbaBA', 'ab', 'cd'); ++-------------------------------------------------+ +| replace(Utf8("ABabbaBA"),Utf8("ab"),Utf8("cd")) | ++-------------------------------------------------+ +| ABcdbaBA | ++-------------------------------------------------+ +``` + +### `reverse` + +Reverses the character order of a string. + +``` +reverse(str) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select reverse('datafusion'); ++-----------------------------+ +| reverse(Utf8("datafusion")) | ++-----------------------------+ +| noisufatad | ++-----------------------------+ +``` + +### `right` + +Returns a specified number of characters from the right side of a string. + +``` +right(str, n) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **n**: Number of characters to return + +#### Example + +```sql +> select right('datafusion', 6); ++------------------------------------+ +| right(Utf8("datafusion"),Int64(6)) | ++------------------------------------+ +| fusion | ++------------------------------------+ +``` + +**Related functions**: + +- [left](#left) + +### `rpad` + +Pads the right side of a string with another string to a specified string length. + +``` +rpad(str, n[, padding_str]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **n**: String length to pad to. +- **padding_str**: String expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._ + +#### Example + +```sql +> select rpad('datafusion', 20, '_-'); ++-----------------------------------------------+ +| rpad(Utf8("datafusion"),Int64(20),Utf8("_-")) | ++-----------------------------------------------+ +| datafusion_-_-_-_-_- | ++-----------------------------------------------+ +``` + +**Related functions**: + +- [lpad](#lpad) + +### `rtrim` + +Trims the specified trim string from the end of a string. If no trim string is provided, all whitespace is removed from the end of the input string. + +``` +rtrim(str[, trim_str]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **trim_str**: String expression to trim from the end of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._ + +#### Example + +```sql +> select rtrim(' datafusion '); ++-------------------------------+ +| rtrim(Utf8(" datafusion ")) | ++-------------------------------+ +| datafusion | ++-------------------------------+ +> select rtrim('___datafusion___', '_'); ++-------------------------------------------+ +| rtrim(Utf8("___datafusion___"),Utf8("_")) | ++-------------------------------------------+ +| ___datafusion | ++-------------------------------------------+ +``` + +#### Alternative Syntax + +```sql +trim(TRAILING trim_str FROM str) +``` + +**Related functions**: + +- [btrim](#btrim) +- [ltrim](#ltrim) + +### `split_part` + +Splits a string based on a specified delimiter and returns the substring in the specified position. + +``` +split_part(str, delimiter, pos) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **delimiter**: String or character to split on. +- **pos**: Position of the part to return. + +#### Example + +```sql +> select split_part('1.2.3.4.5', '.', 3); ++--------------------------------------------------+ +| split_part(Utf8("1.2.3.4.5"),Utf8("."),Int64(3)) | ++--------------------------------------------------+ +| 3 | ++--------------------------------------------------+ +``` + +### `starts_with` + +Tests if a string starts with a substring. + +``` +starts_with(str, substr) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **substr**: Substring to test for. + +#### Example + +```sql +> select starts_with('datafusion','data'); ++----------------------------------------------+ +| starts_with(Utf8("datafusion"),Utf8("data")) | ++----------------------------------------------+ +| true | ++----------------------------------------------+ +``` + +### `strpos` + +Returns the starting position of a specified substring in a string. Positions begin at 1. If the substring does not exist in the string, the function returns 0. + +``` +strpos(str, substr) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **substr**: Substring expression to search for. + +#### Example + +```sql +> select strpos('datafusion', 'fus'); ++----------------------------------------+ +| strpos(Utf8("datafusion"),Utf8("fus")) | ++----------------------------------------+ +| 5 | ++----------------------------------------+ +``` + +#### Alternative Syntax + +```sql +position(substr in origstr) +``` + +#### Aliases + +- instr +- position + +### `substr` + +Extracts a substring of a specified number of characters from a specific starting position in a string. + +``` +substr(str, start_pos[, length]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **start_pos**: Character position to start the substring at. The first character in the string has a position of 1. +- **length**: Number of characters to extract. If not specified, returns the rest of the string after the start position. + +#### Example + +```sql +> select substr('datafusion', 5, 3); ++----------------------------------------------+ +| substr(Utf8("datafusion"),Int64(5),Int64(3)) | ++----------------------------------------------+ +| fus | ++----------------------------------------------+ +``` + +#### Alternative Syntax + +```sql +substring(str from start_pos for length) +``` + +#### Aliases + +- substring + +### `substr_index` + +Returns the substring from str before count occurrences of the delimiter delim. +If count is positive, everything to the left of the final delimiter (counting from the left) is returned. +If count is negative, everything to the right of the final delimiter (counting from the right) is returned. + +``` +substr_index(str, delim, count) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **delim**: The string to find in str to split str. +- **count**: The number of times to search for the delimiter. Can be either a positive or negative number. + +#### Example + +```sql +> select substr_index('www.apache.org', '.', 1); ++---------------------------------------------------------+ +| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(1)) | ++---------------------------------------------------------+ +| www | ++---------------------------------------------------------+ +> select substr_index('www.apache.org', '.', -1); ++----------------------------------------------------------+ +| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(-1)) | ++----------------------------------------------------------+ +| org | ++----------------------------------------------------------+ +``` + +#### Aliases + +- substring_index + +### `substring` + +_Alias of [substr](#substr)._ + +### `substring_index` + +_Alias of [substr_index](#substr_index)._ + +### `to_hex` + +Converts an integer to a hexadecimal string. + +``` +to_hex(int) +``` + +#### Arguments + +- **int**: Integer expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select to_hex(12345689); ++-------------------------+ +| to_hex(Int64(12345689)) | ++-------------------------+ +| bc6159 | ++-------------------------+ +``` + +### `translate` + +Translates characters in a string to specified translation characters. + +``` +translate(str, chars, translation) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **chars**: Characters to translate. +- **translation**: Translation characters. Translation characters replace only characters at the same position in the **chars** string. + +#### Example + +```sql +> select translate('twice', 'wic', 'her'); ++--------------------------------------------------+ +| translate(Utf8("twice"),Utf8("wic"),Utf8("her")) | ++--------------------------------------------------+ +| there | ++--------------------------------------------------+ +``` + +### `trim` + +_Alias of [btrim](#btrim)._ + +### `upper` + +Converts a string to upper-case. + +``` +upper(str) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select upper('dataFusion'); ++---------------------------+ +| upper(Utf8("dataFusion")) | ++---------------------------+ +| DATAFUSION | ++---------------------------+ +``` + +**Related functions**: + +- [initcap](#initcap) +- [lower](#lower) + +### `uuid` + +Returns [`UUID v4`]() string value which is unique per row. + +``` +uuid() +``` + +#### Example + +```sql +> select uuid(); ++--------------------------------------+ +| uuid() | ++--------------------------------------+ +| 6ec17ef8-1934-41cc-8d59-d0c8f9eea1f0 | ++--------------------------------------+ +``` + +## Binary String Functions + +- [decode](#decode) +- [encode](#encode) + +### `decode` + +Decode binary data from textual representation in string. + +``` +decode(expression, format) +``` + +#### Arguments + +- **expression**: Expression containing encoded string data +- **format**: Same arguments as [encode](#encode) + +**Related functions**: + +- [encode](#encode) + +### `encode` + +Encode binary data into a textual representation. + +``` +encode(expression, format) +``` + +#### Arguments + +- **expression**: Expression containing string or binary data +- **format**: Supported formats are: `base64`, `hex` + +**Related functions**: + +- [decode](#decode) + +## Regular Expression Functions + +Apache DataFusion uses a [PCRE-like](https://en.wikibooks.org/wiki/Regular_Expressions/Perl-Compatible_Regular_Expressions) +regular expression [syntax](https://docs.rs/regex/latest/regex/#syntax) +(minus support for several features including look-around and backreferences). +The following regular expression functions are supported: + +- [regexp_count](#regexp_count) +- [regexp_like](#regexp_like) +- [regexp_match](#regexp_match) +- [regexp_replace](#regexp_replace) + +### `regexp_count` + +Returns the number of matches that a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has in a string. + +``` +regexp_count(str, regexp[, start, flags]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **regexp**: Regular expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **start**: - **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function. +- **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*? + +#### Example + +```sql +> select regexp_count('abcAbAbc', 'abc', 2, 'i'); ++---------------------------------------------------------------+ +| regexp_count(Utf8("abcAbAbc"),Utf8("abc"),Int64(2),Utf8("i")) | ++---------------------------------------------------------------+ +| 1 | ++---------------------------------------------------------------+ +``` + +### `regexp_like` + +Returns true if a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has at least one match in a string, false otherwise. + +``` +regexp_like(str, regexp[, flags]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **regexp**: Regular expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*? + +#### Example + +```sql +select regexp_like('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); ++--------------------------------------------------------+ +| regexp_like(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | ++--------------------------------------------------------+ +| true | ++--------------------------------------------------------+ +SELECT regexp_like('aBc', '(b|d)', 'i'); ++--------------------------------------------------+ +| regexp_like(Utf8("aBc"),Utf8("(b|d)"),Utf8("i")) | ++--------------------------------------------------+ +| true | ++--------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) + +### `regexp_match` + +Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string. + +``` +regexp_match(str, regexp[, flags]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **regexp**: Regular expression to match against. + Can be a constant, column, or function. +- **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*? + +#### Example + +```sql + > select regexp_match('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); + +---------------------------------------------------------+ + | regexp_match(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | + +---------------------------------------------------------+ + | [Köln] | + +---------------------------------------------------------+ + SELECT regexp_match('aBc', '(b|d)', 'i'); + +---------------------------------------------------+ + | regexp_match(Utf8("aBc"),Utf8("(b|d)"),Utf8("i")) | + +---------------------------------------------------+ + | [B] | + +---------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) + +### `regexp_replace` + +Replaces substrings in a string that match a [regular expression](https://docs.rs/regex/latest/regex/#syntax). + +``` +regexp_replace(str, regexp, replacement[, flags]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **regexp**: Regular expression to match against. + Can be a constant, column, or function. +- **replacement**: Replacement string expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: +- **g**: (global) Search globally and don't return after the first match +- **i**: case-insensitive: letters match both upper and lower case +- **m**: multi-line mode: ^ and $ match begin/end of line +- **s**: allow . to match \n +- **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used +- **U**: swap the meaning of x* and x*? + +#### Example + +```sql +> select regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', 'g'); ++------------------------------------------------------------------------+ +| regexp_replace(Utf8("foobarbaz"),Utf8("b(..)"),Utf8("X\1Y"),Utf8("g")) | ++------------------------------------------------------------------------+ +| fooXarYXazY | ++------------------------------------------------------------------------+ +SELECT regexp_replace('aBc', '(b|d)', 'Ab\\1a', 'i'); ++-------------------------------------------------------------------+ +| regexp_replace(Utf8("aBc"),Utf8("(b|d)"),Utf8("Ab\1a"),Utf8("i")) | ++-------------------------------------------------------------------+ +| aAbBac | ++-------------------------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) + +## Time and Date Functions + +- [current_date](#current_date) +- [current_time](#current_time) +- [current_timestamp](#current_timestamp) +- [date_bin](#date_bin) +- [date_format](#date_format) +- [date_part](#date_part) +- [date_trunc](#date_trunc) +- [datepart](#datepart) +- [datetrunc](#datetrunc) +- [from_unixtime](#from_unixtime) +- [make_date](#make_date) +- [now](#now) +- [to_char](#to_char) +- [to_date](#to_date) +- [to_local_time](#to_local_time) +- [to_timestamp](#to_timestamp) +- [to_timestamp_micros](#to_timestamp_micros) +- [to_timestamp_millis](#to_timestamp_millis) +- [to_timestamp_nanos](#to_timestamp_nanos) +- [to_timestamp_seconds](#to_timestamp_seconds) +- [to_unixtime](#to_unixtime) +- [today](#today) + +### `current_date` + +Returns the current UTC date. + +The `current_date()` return value is determined at query time and will return the same date, no matter when in the query plan the function executes. + +``` +current_date() +``` + +#### Aliases + +- today + +### `current_time` + +Returns the current UTC time. + +The `current_time()` return value is determined at query time and will return the same time, no matter when in the query plan the function executes. + +``` +current_time() +``` + +### `current_timestamp` + +_Alias of [now](#now)._ + +### `date_bin` + +Calculates time intervals and returns the start of the interval nearest to the specified timestamp. Use `date_bin` to downsample time series data by grouping rows into time-based "bins" or "windows" and applying an aggregate or selector function to each window. + +For example, if you "bin" or "window" data into 15 minute intervals, an input timestamp of `2023-01-01T18:18:18Z` will be updated to the start time of the 15 minute bin it is in: `2023-01-01T18:15:00Z`. + +``` +date_bin(interval, expression, origin-timestamp) +``` + +#### Arguments + +- **interval**: Bin interval. +- **expression**: Time expression to operate on. Can be a constant, column, or function. +- **origin-timestamp**: Optional. Starting point used to determine bin boundaries. If not specified defaults 1970-01-01T00:00:00Z (the UNIX epoch in UTC). + +The following intervals are supported: + +- nanoseconds +- microseconds +- milliseconds +- seconds +- minutes +- hours +- days +- weeks +- months +- years +- century + +### `date_format` + +_Alias of [to_char](#to_char)._ + +### `date_part` + +Returns the specified part of the date as an integer. + +``` +date_part(part, expression) +``` + +#### Arguments + +- **part**: Part of the date to return. The following date parts are supported: + + - year + - quarter (emits value in inclusive range [1, 4] based on which quartile of the year the date is in) + - month + - week (week of the year) + - day (day of the month) + - hour + - minute + - second + - millisecond + - microsecond + - nanosecond + - dow (day of the week) + - doy (day of the year) + - epoch (seconds since Unix epoch) + +- **expression**: Time expression to operate on. Can be a constant, column, or function. + +#### Alternative Syntax + +```sql +extract(field FROM source) +``` + +#### Aliases + +- datepart + +### `date_trunc` + +Truncates a timestamp value to a specified precision. + +``` +date_trunc(precision, expression) +``` + +#### Arguments + +- **precision**: Time precision to truncate to. The following precisions are supported: + + - year / YEAR + - quarter / QUARTER + - month / MONTH + - week / WEEK + - day / DAY + - hour / HOUR + - minute / MINUTE + - second / SECOND + +- **expression**: Time expression to operate on. Can be a constant, column, or function. + +#### Aliases + +- datetrunc + +### `datepart` + +_Alias of [date_part](#date_part)._ + +### `datetrunc` + +_Alias of [date_trunc](#date_trunc)._ + +### `from_unixtime` + +Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`) return the corresponding timestamp. + +``` +from_unixtime(expression) +``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. + +### `make_date` + +Make a date from year/month/day component parts. + +``` +make_date(year, month, day) +``` + +#### Arguments + +- **year**: Year to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators. +- **month**: Month to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators. +- **day**: Day to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators. + +#### Example + +```sql +> select make_date(2023, 1, 31); ++-------------------------------------------+ +| make_date(Int64(2023),Int64(1),Int64(31)) | ++-------------------------------------------+ +| 2023-01-31 | ++-------------------------------------------+ +> select make_date('2023', '01', '31'); ++-----------------------------------------------+ +| make_date(Utf8("2023"),Utf8("01"),Utf8("31")) | ++-----------------------------------------------+ +| 2023-01-31 | ++-----------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/make_date.rs) + +### `now` + +Returns the current UTC timestamp. + +The `now()` return value is determined at query time and will return the same timestamp, no matter when in the query plan the function executes. + +``` +now() +``` + +#### Aliases + +- current_timestamp + +### `to_char` + +Returns a string representation of a date, time, timestamp or duration based on a [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html). Unlike the PostgreSQL equivalent of this function numerical formatting is not supported. + +``` +to_char(expression, format) +``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function that results in a date, time, timestamp or duration. +- **format**: A [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) string to use to convert the expression. +- **day**: Day to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators. + +#### Example + +```sql +> select to_char('2023-03-01'::date, '%d-%m-%Y'); ++----------------------------------------------+ +| to_char(Utf8("2023-03-01"),Utf8("%d-%m-%Y")) | ++----------------------------------------------+ +| 01-03-2023 | ++----------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_char.rs) + +#### Aliases + +- date_format + +### `to_date` + +Converts a value to a date (`YYYY-MM-DD`). +Supports strings, integer and double types as input. +Strings are parsed as YYYY-MM-DD (e.g. '2023-07-20') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. +Integers and doubles are interpreted as days since the unix epoch (`1970-01-01T00:00:00Z`). +Returns the corresponding date. + +Note: `to_date` returns Date32, which represents its values as the number of days since unix epoch(`1970-01-01`) stored as signed 32 bit value. The largest supported date value is `9999-12-31`. + +``` +to_date('2017-05-31', '%Y-%m-%d') +``` + +#### Arguments + +- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order + they appear with the first successful one being returned. If none of the formats successfully parse the expression + an error will be returned. + +#### Example + +```sql +> select to_date('2023-01-31'); ++-----------------------------+ +| to_date(Utf8("2023-01-31")) | ++-----------------------------+ +| 2023-01-31 | ++-----------------------------+ +> select to_date('2023/01/31', '%Y-%m-%d', '%Y/%m/%d'); ++---------------------------------------------------------------+ +| to_date(Utf8("2023/01/31"),Utf8("%Y-%m-%d"),Utf8("%Y/%m/%d")) | ++---------------------------------------------------------------+ +| 2023-01-31 | ++---------------------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_date.rs) + +### `to_local_time` + +Converts a timestamp with a timezone to a timestamp without a timezone (with no offset or timezone information). This function handles daylight saving time changes. + +``` +to_local_time(expression) +``` + +#### Arguments + +- **expression**: Time expression to operate on. Can be a constant, column, or function. + +#### Example + +```sql +> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp); ++---------------------------------------------+ +| to_local_time(Utf8("2024-04-01T00:00:20Z")) | ++---------------------------------------------+ +| 2024-04-01T00:00:20 | ++---------------------------------------------+ + +> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels'); ++---------------------------------------------+ +| to_local_time(Utf8("2024-04-01T00:00:20Z")) | ++---------------------------------------------+ +| 2024-04-01T00:00:20 | ++---------------------------------------------+ + +> SELECT + time, + arrow_typeof(time) as type, + to_local_time(time) as to_local_time, + arrow_typeof(to_local_time(time)) as to_local_time_type +FROM ( + SELECT '2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels' AS time +); ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ +| time | type | to_local_time | to_local_time_type | ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ +| 2024-04-01T00:00:20+02:00 | Timestamp(Nanosecond, Some("Europe/Brussels")) | 2024-04-01T00:00:20 | Timestamp(Nanosecond, None) | ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ + +# combine `to_local_time()` with `date_bin()` to bin on boundaries in the timezone rather +# than UTC boundaries + +> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AS date_bin; ++---------------------+ +| date_bin | ++---------------------+ +| 2024-04-01T00:00:00 | ++---------------------+ + +> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AT TIME ZONE 'Europe/Brussels' AS date_bin_with_timezone; ++---------------------------+ +| date_bin_with_timezone | ++---------------------------+ +| 2024-04-01T00:00:00+02:00 | ++---------------------------+ +``` + +### `to_timestamp` + +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00Z`). Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats] are provided. Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. + +Note: `to_timestamp` returns `Timestamp(Nanosecond)`. The supported range for integer input is between `-9223372037` and `9223372036`. Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. Please use `to_timestamp_seconds` for the input outside of supported bounds. + +``` +to_timestamp(expression[, ..., format_n]) +``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. + +#### Example + +```sql +> select to_timestamp('2023-01-31T09:26:56.123456789-05:00'); ++-----------------------------------------------------------+ +| to_timestamp(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++-----------------------------------------------------------+ +| 2023-01-31T14:26:56.123456789 | ++-----------------------------------------------------------+ +> select to_timestamp('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++--------------------------------------------------------------------------------------------------------+ +| to_timestamp(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++--------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123456789 | ++--------------------------------------------------------------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) + +### `to_timestamp_micros` + +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as microseconds since the unix epoch (`1970-01-01T00:00:00Z`) Returns the corresponding timestamp. + +``` +to_timestamp_micros(expression[, ..., format_n]) +``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. + +#### Example + +```sql +> select to_timestamp_micros('2023-01-31T09:26:56.123456789-05:00'); ++------------------------------------------------------------------+ +| to_timestamp_micros(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++------------------------------------------------------------------+ +| 2023-01-31T14:26:56.123456 | ++------------------------------------------------------------------+ +> select to_timestamp_micros('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++---------------------------------------------------------------------------------------------------------------+ +| to_timestamp_micros(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++---------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123456 | ++---------------------------------------------------------------------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) + +### `to_timestamp_millis` + +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. Integers and unsigned integers are interpreted as milliseconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. + +``` +to_timestamp_millis(expression[, ..., format_n]) +``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. + +#### Example + +```sql +> select to_timestamp_millis('2023-01-31T09:26:56.123456789-05:00'); ++------------------------------------------------------------------+ +| to_timestamp_millis(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++------------------------------------------------------------------+ +| 2023-01-31T14:26:56.123 | ++------------------------------------------------------------------+ +> select to_timestamp_millis('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++---------------------------------------------------------------------------------------------------------------+ +| to_timestamp_millis(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++---------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123 | ++---------------------------------------------------------------------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) + +### `to_timestamp_nanos` + +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. + +``` +to_timestamp_nanos(expression[, ..., format_n]) +``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. + +#### Example + +```sql +> select to_timestamp_nanos('2023-01-31T09:26:56.123456789-05:00'); ++-----------------------------------------------------------------+ +| to_timestamp_nanos(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++-----------------------------------------------------------------+ +| 2023-01-31T14:26:56.123456789 | ++-----------------------------------------------------------------+ +> select to_timestamp_nanos('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++--------------------------------------------------------------------------------------------------------------+ +| to_timestamp_nanos(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++--------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123456789 | ++---------------------------------------------------------------------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) + +### `to_timestamp_seconds` + +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. + +``` +to_timestamp_seconds(expression[, ..., format_n]) +``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. + +#### Example + +```sql +> select to_timestamp_seconds('2023-01-31T09:26:56.123456789-05:00'); ++-------------------------------------------------------------------+ +| to_timestamp_seconds(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++-------------------------------------------------------------------+ +| 2023-01-31T14:26:56 | ++-------------------------------------------------------------------+ +> select to_timestamp_seconds('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++----------------------------------------------------------------------------------------------------------------+ +| to_timestamp_seconds(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++----------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00 | ++----------------------------------------------------------------------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) + +### `to_unixtime` + +Converts a value to seconds since the unix epoch (`1970-01-01T00:00:00Z`). Supports strings, dates, timestamps and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. + +``` +to_unixtime(expression[, ..., format_n]) +``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. + +#### Example + +```sql +> select to_unixtime('2020-09-08T12:00:00+00:00'); ++------------------------------------------------+ +| to_unixtime(Utf8("2020-09-08T12:00:00+00:00")) | ++------------------------------------------------+ +| 1599566400 | ++------------------------------------------------+ +> select to_unixtime('01-14-2023 01:01:30+05:30', '%q', '%d-%m-%Y %H/%M/%S', '%+', '%m-%d-%Y %H:%M:%S%#z'); ++-----------------------------------------------------------------------------------------------------------------------------+ +| to_unixtime(Utf8("01-14-2023 01:01:30+05:30"),Utf8("%q"),Utf8("%d-%m-%Y %H/%M/%S"),Utf8("%+"),Utf8("%m-%d-%Y %H:%M:%S%#z")) | ++-----------------------------------------------------------------------------------------------------------------------------+ +| 1673638290 | ++-----------------------------------------------------------------------------------------------------------------------------+ +``` + +### `today` + +_Alias of [current_date](#current_date)._ + +## Array Functions + +- [array_any_value](#array_any_value) +- [array_append](#array_append) +- [array_cat](#array_cat) +- [array_concat](#array_concat) +- [array_contains](#array_contains) +- [array_dims](#array_dims) +- [array_distance](#array_distance) +- [array_distinct](#array_distinct) +- [array_element](#array_element) +- [array_empty](#array_empty) +- [array_except](#array_except) +- [array_extract](#array_extract) +- [array_has](#array_has) +- [array_has_all](#array_has_all) +- [array_has_any](#array_has_any) +- [array_indexof](#array_indexof) +- [array_intersect](#array_intersect) +- [array_join](#array_join) +- [array_length](#array_length) +- [array_ndims](#array_ndims) +- [array_pop_back](#array_pop_back) +- [array_pop_front](#array_pop_front) +- [array_position](#array_position) +- [array_positions](#array_positions) +- [array_prepend](#array_prepend) +- [array_push_back](#array_push_back) +- [array_push_front](#array_push_front) +- [array_remove](#array_remove) +- [array_remove_all](#array_remove_all) +- [array_remove_n](#array_remove_n) +- [array_repeat](#array_repeat) +- [array_replace](#array_replace) +- [array_replace_all](#array_replace_all) +- [array_replace_n](#array_replace_n) +- [array_resize](#array_resize) +- [array_reverse](#array_reverse) +- [array_slice](#array_slice) +- [array_sort](#array_sort) +- [array_to_string](#array_to_string) +- [array_union](#array_union) +- [cardinality](#cardinality) +- [empty](#empty) +- [flatten](#flatten) +- [generate_series](#generate_series) +- [list_any_value](#list_any_value) +- [list_append](#list_append) +- [list_cat](#list_cat) +- [list_concat](#list_concat) +- [list_contains](#list_contains) +- [list_dims](#list_dims) +- [list_distance](#list_distance) +- [list_distinct](#list_distinct) +- [list_element](#list_element) +- [list_empty](#list_empty) +- [list_except](#list_except) +- [list_extract](#list_extract) +- [list_has](#list_has) +- [list_has_all](#list_has_all) +- [list_has_any](#list_has_any) +- [list_indexof](#list_indexof) +- [list_intersect](#list_intersect) +- [list_join](#list_join) +- [list_length](#list_length) +- [list_ndims](#list_ndims) +- [list_pop_back](#list_pop_back) +- [list_pop_front](#list_pop_front) +- [list_position](#list_position) +- [list_positions](#list_positions) +- [list_prepend](#list_prepend) +- [list_push_back](#list_push_back) +- [list_push_front](#list_push_front) +- [list_remove](#list_remove) +- [list_remove_all](#list_remove_all) +- [list_remove_n](#list_remove_n) +- [list_repeat](#list_repeat) +- [list_replace](#list_replace) +- [list_replace_all](#list_replace_all) +- [list_replace_n](#list_replace_n) +- [list_resize](#list_resize) +- [list_reverse](#list_reverse) +- [list_slice](#list_slice) +- [list_sort](#list_sort) +- [list_to_string](#list_to_string) +- [list_union](#list_union) +- [make_array](#make_array) +- [make_list](#make_list) +- [range](#range) +- [string_to_array](#string_to_array) +- [string_to_list](#string_to_list) + +### `array_any_value` + +Extracts the element with the index n from the array. + +``` +array_element(array, index) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **index**: Index to extract the element from the array. + +#### Example + +```sql +> select array_element([1, 2, 3, 4], 3); ++-----------------------------------------+ +| array_element(List([1,2,3,4]),Int64(3)) | ++-----------------------------------------+ +| 3 | ++-----------------------------------------+ +``` + +#### Aliases + +- list_any_value + +### `array_append` + +Appends an element to the end of an array. + +``` +array_append(array, element) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **element**: Element to append to the array. + +#### Example + +```sql +> select array_append([1, 2, 3], 4); ++--------------------------------------+ +| array_append(List([1,2,3]),Int64(4)) | ++--------------------------------------+ +| [1, 2, 3, 4] | ++--------------------------------------+ +``` + +#### Aliases + +- list_append +- array_push_back +- list_push_back + +### `array_cat` + +_Alias of [array_concat](#array_concat)._ + +### `array_concat` + +Appends an element to the end of an array. + +``` +array_append(array, element) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **element**: Element to append to the array. + +#### Example + +```sql +> select array_append([1, 2, 3], 4); ++--------------------------------------+ +| array_append(List([1,2,3]),Int64(4)) | ++--------------------------------------+ +| [1, 2, 3, 4] | ++--------------------------------------+ +``` + +#### Aliases + +- array_cat +- list_concat +- list_cat + +### `array_contains` + +_Alias of [array_has](#array_has)._ + +### `array_dims` + +Returns an array of the array's dimensions. + +``` +array_dims(array) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. + +#### Example + +```sql +> select array_dims([[1, 2, 3], [4, 5, 6]]); ++---------------------------------+ +| array_dims(List([1,2,3,4,5,6])) | ++---------------------------------+ +| [2, 3] | ++---------------------------------+ +``` + +#### Aliases + +- list_dims + +### `array_distance` + +Returns the Euclidean distance between two input arrays of equal length. + +``` +array_distance(array1, array2) +``` + +#### Arguments + +- **array1**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **array2**: Array expression. Can be a constant, column, or function, and any combination of array operators. + +#### Example + +```sql +> select array_distance([1, 2], [1, 4]); ++------------------------------------+ +| array_distance(List([1,2], [1,4])) | ++------------------------------------+ +| 2.0 | ++------------------------------------+ +``` + +#### Aliases + +- list_distance + +### `array_distinct` + +Returns distinct values from the array after removing duplicates. + +``` +array_distinct(array) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. + +#### Example + +```sql +> select array_distinct([1, 3, 2, 3, 1, 2, 4]); ++---------------------------------+ +| array_distinct(List([1,2,3,4])) | ++---------------------------------+ +| [1, 2, 3, 4] | ++---------------------------------+ +``` + +#### Aliases + +- list_distinct + +### `array_element` + +Extracts the element with the index n from the array. + +``` +array_element(array, index) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **index**: Index to extract the element from the array. + +#### Example + +```sql +> select array_element([1, 2, 3, 4], 3); ++-----------------------------------------+ +| array_element(List([1,2,3,4]),Int64(3)) | ++-----------------------------------------+ +| 3 | ++-----------------------------------------+ +``` + +#### Aliases + +- array_extract +- list_element +- list_extract + +### `array_empty` + +_Alias of [empty](#empty)._ + +### `array_except` + +Returns an array of the elements that appear in the first array but not in the second. + +``` +array_except(array1, array2) +``` + +#### Arguments + +- **array1**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **array2**: Array expression. Can be a constant, column, or function, and any combination of array operators. + +#### Example + +```sql +> select array_except([1, 2, 3, 4], [5, 6, 3, 4]); ++----------------------------------------------------+ +| array_except([1, 2, 3, 4], [5, 6, 3, 4]); | ++----------------------------------------------------+ +| [1, 2] | ++----------------------------------------------------+ +> select array_except([1, 2, 3, 4], [3, 4, 5, 6]); ++----------------------------------------------------+ +| array_except([1, 2, 3, 4], [3, 4, 5, 6]); | ++----------------------------------------------------+ +| [1, 2] | ++----------------------------------------------------+ +``` + +#### Aliases + +- list_except + +### `array_extract` + +_Alias of [array_element](#array_element)._ + +### `array_has` + +Returns true if the array contains the element. + +``` +array_has(array, element) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **element**: Scalar or Array expression. Can be a constant, column, or function, and any combination of array operators. + +#### Example + +```sql +> select array_has([1, 2, 3], 2); ++-----------------------------+ +| array_has(List([1,2,3]), 2) | ++-----------------------------+ +| true | ++-----------------------------+ +``` + +#### Aliases + +- list_has +- array_contains +- list_contains + +### `array_has_all` + +Returns true if the array contains the element. + +``` +array_has(array, element) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **element**: Scalar or Array expression. Can be a constant, column, or function, and any combination of array operators. + +#### Example + +```sql +> select array_has([1, 2, 3], 2); ++-----------------------------+ +| array_has(List([1,2,3]), 2) | ++-----------------------------+ +| true | ++-----------------------------+ +``` + +#### Aliases + +- list_has_all + +### `array_has_any` + +Returns true if the array contains the element. + +``` +array_has(array, element) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **element**: Scalar or Array expression. Can be a constant, column, or function, and any combination of array operators. + +#### Example + +```sql +> select array_has([1, 2, 3], 2); ++-----------------------------+ +| array_has(List([1,2,3]), 2) | ++-----------------------------+ +| true | ++-----------------------------+ +``` + +#### Aliases + +- list_has_any + +### `array_indexof` + +_Alias of [array_position](#array_position)._ + +### `array_intersect` + +Returns distinct values from the array after removing duplicates. + +``` +array_distinct(array) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. + +#### Example + +```sql +> select array_distinct([1, 3, 2, 3, 1, 2, 4]); ++---------------------------------+ +| array_distinct(List([1,2,3,4])) | ++---------------------------------+ +| [1, 2, 3, 4] | ++---------------------------------+ +``` + +#### Aliases + +- list_intersect + +### `array_join` + +_Alias of [array_to_string](#array_to_string)._ + +### `array_length` + +Returns the length of the array dimension. + +``` +array_length(array, dimension) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **dimension**: Array dimension. + +#### Example + +```sql +> select array_length([1, 2, 3, 4, 5], 1); ++-------------------------------------------+ +| array_length(List([1,2,3,4,5]), 1) | ++-------------------------------------------+ +| 5 | ++-------------------------------------------+ +``` + +#### Aliases + +- list_length + +### `array_ndims` + +Returns an array of the array's dimensions. + +``` +array_dims(array) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. + +#### Example + +```sql +> select array_dims([[1, 2, 3], [4, 5, 6]]); ++---------------------------------+ +| array_dims(List([1,2,3,4,5,6])) | ++---------------------------------+ +| [2, 3] | ++---------------------------------+ +``` + +#### Aliases + +- list_ndims + +### `array_pop_back` + +Extracts the element with the index n from the array. + +``` +array_element(array, index) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **index**: Index to extract the element from the array. + +#### Example + +```sql +> select array_element([1, 2, 3, 4], 3); ++-----------------------------------------+ +| array_element(List([1,2,3,4]),Int64(3)) | ++-----------------------------------------+ +| 3 | ++-----------------------------------------+ +``` + +#### Aliases + +- list_pop_back + +### `array_pop_front` + +Extracts the element with the index n from the array. + +``` +array_element(array, index) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **index**: Index to extract the element from the array. + +#### Example + +```sql +> select array_element([1, 2, 3, 4], 3); ++-----------------------------------------+ +| array_element(List([1,2,3,4]),Int64(3)) | ++-----------------------------------------+ +| 3 | ++-----------------------------------------+ +``` + +#### Aliases + +- list_pop_front + +### `array_position` + +Returns the position of the first occurrence of the specified element in the array. + +``` +array_position(array, element) +array_position(array, element, index) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **element**: Element to search for position in the array. +- **index**: Index at which to start searching. + +#### Example + +```sql +> select array_position([1, 2, 2, 3, 1, 4], 2); ++----------------------------------------------+ +| array_position(List([1,2,2,3,1,4]),Int64(2)) | ++----------------------------------------------+ +| 2 | ++----------------------------------------------+ +> select array_position([1, 2, 2, 3, 1, 4], 2, 3); ++----------------------------------------------------+ +| array_position(List([1,2,2,3,1,4]),Int64(2), Int64(3)) | ++----------------------------------------------------+ +| 3 | ++----------------------------------------------------+ +``` + +#### Aliases + +- list_position +- array_indexof +- list_indexof + +### `array_positions` + +Returns the position of the first occurrence of the specified element in the array. + +``` +array_position(array, element) +array_position(array, element, index) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **element**: Element to search for position in the array. +- **index**: Index at which to start searching. + +#### Example + +```sql +> select array_position([1, 2, 2, 3, 1, 4], 2); ++----------------------------------------------+ +| array_position(List([1,2,2,3,1,4]),Int64(2)) | ++----------------------------------------------+ +| 2 | ++----------------------------------------------+ +> select array_position([1, 2, 2, 3, 1, 4], 2, 3); ++----------------------------------------------------+ +| array_position(List([1,2,2,3,1,4]),Int64(2), Int64(3)) | ++----------------------------------------------------+ +| 3 | ++----------------------------------------------------+ +``` + +#### Aliases + +- list_positions + +### `array_prepend` + +Appends an element to the end of an array. + +``` +array_append(array, element) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **element**: Element to append to the array. + +#### Example + +```sql +> select array_append([1, 2, 3], 4); ++--------------------------------------+ +| array_append(List([1,2,3]),Int64(4)) | ++--------------------------------------+ +| [1, 2, 3, 4] | ++--------------------------------------+ +``` + +#### Aliases + +- list_prepend +- array_push_front +- list_push_front + +### `array_push_back` + +_Alias of [array_append](#array_append)._ + +### `array_push_front` + +_Alias of [array_prepend](#array_prepend)._ + +### `array_remove` + +Removes the first element from the array equal to the given value. + +``` +array_remove(array, element) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **element**: Element to be removed from the array. + +#### Example + +```sql +> select array_remove([1, 2, 2, 3, 2, 1, 4], 2); ++----------------------------------------------+ +| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) | ++----------------------------------------------+ +| [1, 2, 3, 2, 1, 4] | ++----------------------------------------------+ +``` + +#### Aliases + +- list_remove + +### `array_remove_all` + +Removes the first element from the array equal to the given value. + +``` +array_remove(array, element) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **element**: Element to be removed from the array. + +#### Example + +```sql +> select array_remove([1, 2, 2, 3, 2, 1, 4], 2); ++----------------------------------------------+ +| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) | ++----------------------------------------------+ +| [1, 2, 3, 2, 1, 4] | ++----------------------------------------------+ +``` + +#### Aliases + +- list_remove_all + +### `array_remove_n` + +Removes the first element from the array equal to the given value. + +``` +array_remove(array, element) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **element**: Element to be removed from the array. + +#### Example + +```sql +> select array_remove([1, 2, 2, 3, 2, 1, 4], 2); ++----------------------------------------------+ +| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) | ++----------------------------------------------+ +| [1, 2, 3, 2, 1, 4] | ++----------------------------------------------+ +``` + +#### Aliases + +- list_remove_n + +### `array_repeat` + +Returns an array containing element `count` times. + +``` +array_repeat(element, count) +``` + +#### Arguments + +- **element**: Element expression. Can be a constant, column, or function, and any combination of array operators. +- **count**: Value of how many times to repeat the element. + +#### Example + +```sql +> select array_repeat(1, 3); ++---------------------------------+ +| array_repeat(Int64(1),Int64(3)) | ++---------------------------------+ +| [1, 1, 1] | ++---------------------------------+ +> select array_repeat([1, 2], 2); ++------------------------------------+ +| array_repeat(List([1,2]),Int64(2)) | ++------------------------------------+ +| [[1, 2], [1, 2]] | ++------------------------------------+ +``` + +#### Aliases + +- list_repeat + +### `array_replace` + +Replaces the first `max` occurrences of the specified element with another specified element. + +``` +array_replace_n(array, from, to, max) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **from**: Initial element. +- **to**: Final element. +- **max**: Number of first occurrences to replace. + +#### Example + +```sql +> select array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2); ++-------------------------------------------------------------------+ +| array_replace_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(5),Int64(2)) | ++-------------------------------------------------------------------+ +| [1, 5, 5, 3, 2, 1, 4] | ++-------------------------------------------------------------------+ +``` + +#### Aliases + +- list_replace + +### `array_replace_all` + +Replaces the first `max` occurrences of the specified element with another specified element. + +``` +array_replace_n(array, from, to, max) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **from**: Initial element. +- **to**: Final element. +- **max**: Number of first occurrences to replace. + +#### Example + +```sql +> select array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2); ++-------------------------------------------------------------------+ +| array_replace_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(5),Int64(2)) | ++-------------------------------------------------------------------+ +| [1, 5, 5, 3, 2, 1, 4] | ++-------------------------------------------------------------------+ +``` + +#### Aliases + +- list_replace_all + +### `array_replace_n` + +Replaces the first `max` occurrences of the specified element with another specified element. + +``` +array_replace_n(array, from, to, max) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **from**: Initial element. +- **to**: Final element. +- **max**: Number of first occurrences to replace. + +#### Example + +```sql +> select array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2); ++-------------------------------------------------------------------+ +| array_replace_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(5),Int64(2)) | ++-------------------------------------------------------------------+ +| [1, 5, 5, 3, 2, 1, 4] | ++-------------------------------------------------------------------+ +``` + +#### Aliases + +- list_replace_n + +### `array_resize` + +Resizes the list to contain size elements. Initializes new elements with value or empty if value is not set. + +``` +array_resize(array, size, value) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **size**: New size of given array. +- **value**: Defines new elements' value or empty if value is not set. + +#### Example + +```sql +> select array_resize([1, 2, 3], 5, 0); ++-------------------------------------+ +| array_resize(List([1,2,3],5,0)) | ++-------------------------------------+ +| [1, 2, 3, 0, 0] | ++-------------------------------------+ +``` + +#### Aliases + +- list_resize + +### `array_reverse` + +Returns the array with the order of the elements reversed. + +``` +array_reverse(array) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. + +#### Example + +```sql +> select array_reverse([1, 2, 3, 4]); ++------------------------------------------------------------+ +| array_reverse(List([1, 2, 3, 4])) | ++------------------------------------------------------------+ +| [4, 3, 2, 1] | ++------------------------------------------------------------+ +``` + +#### Aliases + +- list_reverse + +### `array_slice` + +Extracts the element with the index n from the array. + +``` +array_element(array, index) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **index**: Index to extract the element from the array. + +#### Example + +```sql +> select array_element([1, 2, 3, 4], 3); ++-----------------------------------------+ +| array_element(List([1,2,3,4]),Int64(3)) | ++-----------------------------------------+ +| 3 | ++-----------------------------------------+ +``` + +#### Aliases + +- list_slice + +### `array_sort` + +Sort array. + +``` +array_sort(array, desc, nulls_first) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **desc**: Whether to sort in descending order(`ASC` or `DESC`). +- **nulls_first**: Whether to sort nulls first(`NULLS FIRST` or `NULLS LAST`). + +#### Example + +```sql +> select array_sort([3, 1, 2]); ++-----------------------------+ +| array_sort(List([3,1,2])) | ++-----------------------------+ +| [1, 2, 3] | ++-----------------------------+ +``` + +#### Aliases + +- list_sort + +### `array_to_string` + +Converts each element to its text representation. + +``` +array_to_string(array, delimiter) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **delimiter**: Array element separator. + +#### Example + +```sql +> select array_to_string([[1, 2, 3, 4], [5, 6, 7, 8]], ','); ++----------------------------------------------------+ +| array_to_string(List([1,2,3,4,5,6,7,8]),Utf8(",")) | ++----------------------------------------------------+ +| 1,2,3,4,5,6,7,8 | ++----------------------------------------------------+ +``` + +#### Aliases + +- list_to_string +- array_join +- list_join + +### `array_union` + +Returns distinct values from the array after removing duplicates. + +``` +array_distinct(array) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. + +#### Example + +```sql +> select array_distinct([1, 3, 2, 3, 1, 2, 4]); ++---------------------------------+ +| array_distinct(List([1,2,3,4])) | ++---------------------------------+ +| [1, 2, 3, 4] | ++---------------------------------+ +``` + +#### Aliases + +- list_union + +### `cardinality` + +Returns the total number of elements in the array. + +``` +cardinality(array) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. + +#### Example + +```sql +> select cardinality([[1, 2, 3, 4], [5, 6, 7, 8]]); ++--------------------------------------+ +| cardinality(List([1,2,3,4,5,6,7,8])) | ++--------------------------------------+ +| 8 | ++--------------------------------------+ +``` + +### `empty` + +Returns 1 for an empty array or 0 for a non-empty array. + +``` +empty(array) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. + +#### Example + +```sql +> select empty([1]); ++------------------+ +| empty(List([1])) | ++------------------+ +| 0 | ++------------------+ +``` + +#### Aliases + +- array_empty +- list_empty + +### `flatten` + +Converts an array of arrays to a flat array. + +- Applies to any depth of nested arrays +- Does not change arrays that are already flat + +The flattened array contains all the elements from all source arrays. + +``` +flatten(array) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. + +#### Example + +```sql +> select flatten([[1, 2], [3, 4]]); ++------------------------------+ +| flatten(List([1,2], [3,4])) | ++------------------------------+ +| [1, 2, 3, 4] | ++------------------------------+ +``` + +### `generate_series` + +Similar to the range function, but it includes the upper bound. + +``` +generate_series(start, stop, step) +``` + +#### Arguments + +- **start**: start of the series. Ints, timestamps, dates or string types that can be coerced to Date32 are supported. +- **end**: end of the series (included). Type must be the same as start. +- **step**: increase by step (can not be 0). Steps less than a day are supported only for timestamp ranges. + +#### Example + +```sql +> select generate_series(1,3); ++------------------------------------+ +| generate_series(Int64(1),Int64(3)) | ++------------------------------------+ +| [1, 2, 3] | ++------------------------------------+ +``` + +### `list_any_value` + +_Alias of [array_any_value](#array_any_value)._ + +### `list_append` + +_Alias of [array_append](#array_append)._ + +### `list_cat` + +_Alias of [array_concat](#array_concat)._ + +### `list_concat` + +_Alias of [array_concat](#array_concat)._ + +### `list_contains` + +_Alias of [array_has](#array_has)._ + +### `list_dims` + +_Alias of [array_dims](#array_dims)._ + +### `list_distance` + +_Alias of [array_distance](#array_distance)._ + +### `list_distinct` + +_Alias of [array_distinct](#array_distinct)._ + +### `list_element` + +_Alias of [array_element](#array_element)._ + +### `list_empty` + +_Alias of [empty](#empty)._ + +### `list_except` + +_Alias of [array_except](#array_except)._ + +### `list_extract` + +_Alias of [array_element](#array_element)._ + +### `list_has` + +_Alias of [array_has](#array_has)._ + +### `list_has_all` + +_Alias of [array_has_all](#array_has_all)._ + +### `list_has_any` + +_Alias of [array_has_any](#array_has_any)._ + +### `list_indexof` + +_Alias of [array_position](#array_position)._ + +### `list_intersect` + +_Alias of [array_intersect](#array_intersect)._ + +### `list_join` + +_Alias of [array_to_string](#array_to_string)._ + +### `list_length` + +_Alias of [array_length](#array_length)._ + +### `list_ndims` + +_Alias of [array_ndims](#array_ndims)._ + +### `list_pop_back` + +_Alias of [array_pop_back](#array_pop_back)._ + +### `list_pop_front` + +_Alias of [array_pop_front](#array_pop_front)._ + +### `list_position` + +_Alias of [array_position](#array_position)._ -Note: this documentation is in the process of being migrated to be [automatically created from the codebase]. -Please see the [Scalar Functions (new)](scalar_functions_new.md) page for -the rest of the documentation. +### `list_positions` -[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740 +_Alias of [array_positions](#array_positions)._ -## Conditional Functions +### `list_prepend` -See the new documentation [`here`](https://datafusion.apache.org/user-guide/sql/scalar_functions_new.html) +_Alias of [array_prepend](#array_prepend)._ -## String Functions +### `list_push_back` -See the new documentation [`here`](https://datafusion.apache.org/user-guide/sql/scalar_functions_new.html) +_Alias of [array_append](#array_append)._ -### `position` +### `list_push_front` + +_Alias of [array_prepend](#array_prepend)._ + +### `list_remove` + +_Alias of [array_remove](#array_remove)._ + +### `list_remove_all` + +_Alias of [array_remove_all](#array_remove_all)._ + +### `list_remove_n` + +_Alias of [array_remove_n](#array_remove_n)._ + +### `list_repeat` + +_Alias of [array_repeat](#array_repeat)._ + +### `list_replace` + +_Alias of [array_replace](#array_replace)._ + +### `list_replace_all` + +_Alias of [array_replace_all](#array_replace_all)._ -Returns the position of `substr` in `origstr` (counting from 1). If `substr` does -not appear in `origstr`, return 0. +### `list_replace_n` + +_Alias of [array_replace_n](#array_replace_n)._ + +### `list_resize` + +_Alias of [array_resize](#array_resize)._ + +### `list_reverse` + +_Alias of [array_reverse](#array_reverse)._ + +### `list_slice` + +_Alias of [array_slice](#array_slice)._ + +### `list_sort` + +_Alias of [array_sort](#array_sort)._ + +### `list_to_string` + +_Alias of [array_to_string](#array_to_string)._ + +### `list_union` + +_Alias of [array_union](#array_union)._ + +### `make_array` + +Returns an array using the specified input expressions. ``` -position(substr in origstr) +make_array(expression1[, ..., expression_n]) ``` #### Arguments -- **substr**: The pattern string. -- **origstr**: The model string. +- **expression_n**: Expression to include in the output array. Can be a constant, column, or function, and any combination of arithmetic or string operators. -## Time and Date Functions +#### Example + +```sql +> select make_array(1, 2, 3, 4, 5); ++----------------------------------------------------------+ +| make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)) | ++----------------------------------------------------------+ +| [1, 2, 3, 4, 5] | ++----------------------------------------------------------+ +``` + +#### Aliases + +- make_list + +### `make_list` -- [extract](#extract) +_Alias of [make_array](#make_array)._ -### `extract` +### `range` -Returns a sub-field from a time value as an integer. +Returns an Arrow array between start and stop with step. The range start..end contains all values with start <= x < end. It is empty if start >= end. Step cannot be 0. ``` -extract(field FROM source) +range(start, stop, step) ``` -Equivalent to calling `date_part('field', source)`. For example, these are equivalent: +#### Arguments + +- **start**: Start of the range. Ints, timestamps, dates or string types that can be coerced to Date32 are supported. +- **end**: End of the range (not included). Type must be the same as start. +- **step**: Increase by step (cannot be 0). Steps less than a day are supported only for timestamp ranges. + +#### Example ```sql -extract(day FROM '2024-04-13'::date) -date_part('day', '2024-04-13'::date) +> select range(2, 10, 3); ++-----------------------------------+ +| range(Int64(2),Int64(10),Int64(3))| ++-----------------------------------+ +| [2, 5, 8] | ++-----------------------------------+ + +> select range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH); ++--------------------------------------------------------------+ +| range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH) | ++--------------------------------------------------------------+ +| [1992-09-01, 1992-10-01, 1992-11-01, 1992-12-01, 1993-01-01, 1993-02-01] | ++--------------------------------------------------------------+ ``` -See [date_part](#date_part). +### `string_to_array` -## Array Functions +Converts each element to its text representation. -- [range](#range) +``` +array_to_string(array, delimiter) +``` -### `range` +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **delimiter**: Array element separator. + +#### Example + +```sql +> select array_to_string([[1, 2, 3, 4], [5, 6, 7, 8]], ','); ++----------------------------------------------------+ +| array_to_string(List([1,2,3,4,5,6,7,8]),Utf8(",")) | ++----------------------------------------------------+ +| 1,2,3,4,5,6,7,8 | ++----------------------------------------------------+ +``` + +#### Aliases + +- string_to_list -Returns an Arrow array between start and stop with step. `SELECT range(2, 10, 3) -> [2, 5, 8]` or -`SELECT range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH);` +### `string_to_list` -The range start..end contains all values with start <= x < end. It is empty if start >= end. +_Alias of [string_to_array](#string_to_array)._ -Step can not be 0 (then the range will be nonsense.). +## Struct Functions -Note that when the required range is a number, it accepts (stop), (start, stop), and (start, stop, step) as parameters, -but when the required range is a date or timestamp, it must be 3 non-NULL parameters. -For example, +- [named_struct](#named_struct) +- [row](#row) +- [struct](#struct) + +### `named_struct` + +Returns an Arrow struct using the specified name and input expressions pairs. ``` -SELECT range(3); -SELECT range(1,5); -SELECT range(1,5,1); +named_struct(expression1_name, expression1_input[, ..., expression_n_name, expression_n_input]) +``` + +#### Arguments + +- **expression_n_name**: Name of the column field. Must be a constant string. +- **expression_n_input**: Expression to include in the output struct. Can be a constant, column, or function, and any combination of arithmetic or string operators. + +#### Example + +For example, this query converts two columns `a` and `b` to a single column with +a struct type of fields `field_a` and `field_b`: + +```sql +> select * from t; ++---+---+ +| a | b | ++---+---+ +| 1 | 2 | +| 3 | 4 | ++---+---+ +> select named_struct('field_a', a, 'field_b', b) from t; ++-------------------------------------------------------+ +| named_struct(Utf8("field_a"),t.a,Utf8("field_b"),t.b) | ++-------------------------------------------------------+ +| {field_a: 1, field_b: 2} | +| {field_a: 3, field_b: 4} | ++-------------------------------------------------------+ ``` -are allowed in number ranges +### `row` + +_Alias of [struct](#struct)._ + +### `struct` -but in date and timestamp ranges, only +Returns an Arrow struct using the specified input expressions optionally named. +Fields in the returned struct use the optional name or the `cN` naming convention. +For example: `c0`, `c1`, `c2`, etc. ``` -SELECT range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH); -SELECT range(TIMESTAMP '1992-09-01', TIMESTAMP '1993-03-01', INTERVAL '1' MONTH); +struct(expression1[, ..., expression_n]) +``` + +#### Arguments + +- **expression1, expression_n**: Expression to include in the output struct. Can be a constant, column, or function, any combination of arithmetic or string operators. + +#### Example + +For example, this query converts two columns `a` and `b` to a single column with +a struct type of fields `field_a` and `c1`: + +```sql +> select * from t; ++---+---+ +| a | b | ++---+---+ +| 1 | 2 | +| 3 | 4 | ++---+---+ + +-- use default names `c0`, `c1` +> select struct(a, b) from t; ++-----------------+ +| struct(t.a,t.b) | ++-----------------+ +| {c0: 1, c1: 2} | +| {c0: 3, c1: 4} | ++-----------------+ + +-- name the first field `field_a` +select struct(a as field_a, b) from t; ++--------------------------------------------------+ +| named_struct(Utf8("field_a"),t.a,Utf8("c1"),t.b) | ++--------------------------------------------------+ +| {field_a: 1, c1: 2} | +| {field_a: 3, c1: 4} | ++--------------------------------------------------+ ``` -is allowed, and +#### Aliases + +- row + +## Map Functions + +- [element_at](#element_at) +- [map](#map) +- [map_extract](#map_extract) +- [map_keys](#map_keys) +- [map_values](#map_values) + +### `element_at` + +_Alias of [map_extract](#map_extract)._ + +### `map` + +Returns an Arrow map with the specified key-value pairs. + +The `make_map` function creates a map from two lists: one for keys and one for values. Each key must be unique and non-null. ``` -SELECT range(DATE '1992-09-01', DATE '1993-03-01', NULL); -SELECT range(NULL, DATE '1993-03-01', INTERVAL '1' MONTH); -SELECT range(DATE '1992-09-01', NULL, INTERVAL '1' MONTH); +map(key, value) +map(key: value) +make_map(['key1', 'key2'], ['value1', 'value2']) ``` -are not allowed +#### Arguments + +- **key**: For `map`: Expression to be used for key. Can be a constant, column, function, or any combination of arithmetic or string operators. + For `make_map`: The list of keys to be used in the map. Each key must be unique and non-null. +- **value**: For `map`: Expression to be used for value. Can be a constant, column, function, or any combination of arithmetic or string operators. + For `make_map`: The list of values to be mapped to the corresponding keys. + +#### Example + +````sql + -- Using map function + SELECT MAP('type', 'test'); + ---- + {type: test} + + SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, null]); + ---- + {POST: 41, HEAD: 33, PATCH: } + SELECT MAP([[1,2], [3,4]], ['a', 'b']); + ---- + {[1, 2]: a, [3, 4]: b} + + SELECT MAP { 'a': 1, 'b': 2 }; + ---- + {a: 1, b: 2} + + -- Using make_map function + SELECT MAKE_MAP(['POST', 'HEAD'], [41, 33]); + ---- + {POST: 41, HEAD: 33} + + SELECT MAKE_MAP(['key1', 'key2'], ['value1', null]); + ---- + {key1: value1, key2: } + ``` + + +### `map_extract` + +Returns a list containing the value for the given key or an empty list if the key is not present in the map. + +```` + +map_extract(map, key) + +```` #### Arguments -- **start**: start of the range. Ints, timestamps, dates or string types that can be coerced to Date32 are supported. -- **end**: end of the range (not included). Type must be the same as start. -- **step**: increase by step (can not be 0). Steps less than a day are supported only for timestamp ranges. +- **map**: Map expression. Can be a constant, column, or function, and any combination of map operators. +- **key**: Key to extract from the map. Can be a constant, column, or function, any combination of arithmetic or string operators, or a named expression of the previously listed. + +#### Example + +```sql +SELECT map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'a'); +---- +[1] + +SELECT map_extract(MAP {1: 'one', 2: 'two'}, 2); +---- +['two'] + +SELECT map_extract(MAP {'x': 10, 'y': NULL, 'z': 30}, 'y'); +---- +[] +```` #### Aliases -- generate_series +- element_at + +### `map_keys` + +Returns a list of all keys in the map. + +``` +map_keys(map) +``` + +#### Arguments + +- **map**: Map expression. Can be a constant, column, or function, and any combination of map operators. + +#### Example + +```sql +SELECT map_keys(MAP {'a': 1, 'b': NULL, 'c': 3}); +---- +[a, b, c] + +SELECT map_keys(map([100, 5], [42, 43])); +---- +[100, 5] +``` + +### `map_values` + +Returns a list of all values in the map. + +``` +map_values(map) +``` + +#### Arguments + +- **map**: Map expression. Can be a constant, column, or function, and any combination of map operators. + +#### Example + +```sql +SELECT map_values(MAP {'a': 1, 'b': NULL, 'c': 3}); +---- +[1, , 3] + +SELECT map_values(map([100, 5], [42, 43])); +---- +[42, 43] +``` + +## Hashing Functions + +- [digest](#digest) +- [md5](#md5) +- [sha224](#sha224) +- [sha256](#sha256) +- [sha384](#sha384) +- [sha512](#sha512) + +### `digest` + +Computes the binary hash of an expression using the specified algorithm. + +``` +digest(expression, algorithm) +``` + +#### Arguments + +- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **algorithm**: String expression specifying algorithm to use. Must be one of: +- md5 +- sha224 +- sha256 +- sha384 +- sha512 +- blake2s +- blake2b +- blake3 + +#### Example + +```sql +> select digest('foo', 'sha256'); ++------------------------------------------+ +| digest(Utf8("foo"), Utf8("sha256")) | ++------------------------------------------+ +| | ++------------------------------------------+ +``` + +### `md5` + +Computes an MD5 128-bit checksum for a string expression. + +``` +md5(expression) +``` + +#### Arguments + +- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select md5('foo'); ++-------------------------------------+ +| md5(Utf8("foo")) | ++-------------------------------------+ +| | ++-------------------------------------+ +``` + +### `sha224` + +Computes the SHA-224 hash of a binary string. + +``` +sha224(expression) +``` + +#### Arguments + +- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select sha224('foo'); ++------------------------------------------+ +| sha224(Utf8("foo")) | ++------------------------------------------+ +| | ++------------------------------------------+ +``` + +### `sha256` + +Computes the SHA-256 hash of a binary string. + +``` +sha256(expression) +``` + +#### Arguments + +- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select sha256('foo'); ++--------------------------------------+ +| sha256(Utf8("foo")) | ++--------------------------------------+ +| | ++--------------------------------------+ +``` + +### `sha384` + +Computes the SHA-384 hash of a binary string. + +``` +sha384(expression) +``` + +#### Arguments + +- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select sha384('foo'); ++-----------------------------------------+ +| sha384(Utf8("foo")) | ++-----------------------------------------+ +| | ++-----------------------------------------+ +``` + +### `sha512` + +Computes the SHA-512 hash of a binary string. + +``` +sha512(expression) +``` + +#### Arguments + +- **expression**: String + +#### Example + +```sql +> select sha512('foo'); ++-------------------------------------------+ +| sha512(Utf8("foo")) | ++-------------------------------------------+ +| | ++-------------------------------------------+ +``` ## Other Functions -See the new documentation [`here`](https://datafusion.apache.org/user-guide/sql/scalar_functions_new.html) +- [arrow_cast](#arrow_cast) +- [arrow_typeof](#arrow_typeof) +- [get_field](#get_field) +- [version](#version) + +### `arrow_cast` + +Casts a value to a specific Arrow data type. + +``` +arrow_cast(expression, datatype) +``` + +#### Arguments + +- **expression**: Expression to cast. The expression can be a constant, column, or function, and any combination of operators. +- **datatype**: [Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) name to cast to, as a string. The format is the same as that returned by [`arrow_typeof`] + +#### Example + +```sql +> select arrow_cast(-5, 'Int8') as a, + arrow_cast('foo', 'Dictionary(Int32, Utf8)') as b, + arrow_cast('bar', 'LargeUtf8') as c, + arrow_cast('2023-01-02T12:53:02', 'Timestamp(Microsecond, Some("+08:00"))') as d + ; ++----+-----+-----+---------------------------+ +| a | b | c | d | ++----+-----+-----+---------------------------+ +| -5 | foo | bar | 2023-01-02T12:53:02+08:00 | ++----+-----+-----+---------------------------+ +``` + +### `arrow_typeof` + +Returns the name of the underlying [Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) of the expression. + +``` +arrow_typeof(expression) +``` + +#### Arguments + +- **expression**: Expression to evaluate. The expression can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select arrow_typeof('foo'), arrow_typeof(1); ++---------------------------+------------------------+ +| arrow_typeof(Utf8("foo")) | arrow_typeof(Int64(1)) | ++---------------------------+------------------------+ +| Utf8 | Int64 | ++---------------------------+------------------------+ +``` + +### `get_field` + +Returns a field within a map or a struct with the given key. +Note: most users invoke `get_field` indirectly via field access +syntax such as `my_struct_col['field_name']` which results in a call to +`get_field(my_struct_col, 'field_name')`. + +``` +get_field(expression1, expression2) +``` + +#### Arguments + +- **expression1**: The map or struct to retrieve a field for. +- **expression2**: The field name in the map or struct to retrieve data for. Must evaluate to a string. + +#### Example + +```sql +> create table t (idx varchar, v varchar) as values ('data','fusion'), ('apache', 'arrow'); +> select struct(idx, v) from t as c; ++-------------------------+ +| struct(c.idx,c.v) | ++-------------------------+ +| {c0: data, c1: fusion} | +| {c0: apache, c1: arrow} | ++-------------------------+ +> select get_field((select struct(idx, v) from t), 'c0'); ++-----------------------+ +| struct(t.idx,t.v)[c0] | ++-----------------------+ +| data | +| apache | ++-----------------------+ +> select get_field((select struct(idx, v) from t), 'c1'); ++-----------------------+ +| struct(t.idx,t.v)[c1] | ++-----------------------+ +| fusion | +| arrow | ++-----------------------+ +``` + +### `version` + +Returns the version of DataFusion. + +``` +version() +``` + +#### Example + +```sql +> select version(); ++--------------------------------------------+ +| version() | ++--------------------------------------------+ +| Apache DataFusion 42.0.0, aarch64 on macos | ++--------------------------------------------+ +``` diff --git a/docs/source/user-guide/sql/scalar_functions_new.md b/docs/source/user-guide/sql/scalar_functions_new.md deleted file mode 100644 index 56173b97b405..000000000000 --- a/docs/source/user-guide/sql/scalar_functions_new.md +++ /dev/null @@ -1,4365 +0,0 @@ - - - - -# Scalar Functions (NEW) - -Note: this documentation is in the process of being migrated to be [automatically created from the codebase]. -Please see the [Scalar Functions (old)](aggregate_functions.md) page for -the rest of the documentation. - -[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740 - -## Math Functions - -- [abs](#abs) -- [acos](#acos) -- [acosh](#acosh) -- [asin](#asin) -- [asinh](#asinh) -- [atan](#atan) -- [atan2](#atan2) -- [atanh](#atanh) -- [cbrt](#cbrt) -- [ceil](#ceil) -- [cos](#cos) -- [cosh](#cosh) -- [cot](#cot) -- [degrees](#degrees) -- [exp](#exp) -- [factorial](#factorial) -- [floor](#floor) -- [gcd](#gcd) -- [isnan](#isnan) -- [iszero](#iszero) -- [lcm](#lcm) -- [ln](#ln) -- [log](#log) -- [log10](#log10) -- [log2](#log2) -- [nanvl](#nanvl) -- [pi](#pi) -- [pow](#pow) -- [power](#power) -- [radians](#radians) -- [random](#random) -- [round](#round) -- [signum](#signum) -- [sin](#sin) -- [sinh](#sinh) -- [sqrt](#sqrt) -- [tan](#tan) -- [tanh](#tanh) -- [trunc](#trunc) - -### `abs` - -Returns the absolute value of a number. - -``` -abs(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `acos` - -Returns the arc cosine or inverse cosine of a number. - -``` -acos(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `acosh` - -Returns the area hyperbolic cosine or inverse hyperbolic cosine of a number. - -``` -acosh(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `asin` - -Returns the arc sine or inverse sine of a number. - -``` -asin(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `asinh` - -Returns the area hyperbolic sine or inverse hyperbolic sine of a number. - -``` -asinh(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `atan` - -Returns the arc tangent or inverse tangent of a number. - -``` -atan(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `atan2` - -Returns the arc tangent or inverse tangent of `expression_y / expression_x`. - -``` -atan2(expression_y, expression_x) -``` - -#### Arguments - -- **expression_y**: First numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Second numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `atanh` - -Returns the area hyperbolic tangent or inverse hyperbolic tangent of a number. - -``` -atanh(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `cbrt` - -Returns the cube root of a number. - -``` -cbrt(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `ceil` - -Returns the nearest integer greater than or equal to a number. - -``` -ceil(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `cos` - -Returns the cosine of a number. - -``` -cos(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `cosh` - -Returns the hyperbolic cosine of a number. - -``` -cosh(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `cot` - -Returns the cotangent of a number. - -``` -cot(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `degrees` - -Converts radians to degrees. - -``` -degrees(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `exp` - -Returns the base-e exponential of a number. - -``` -exp(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `factorial` - -Factorial. Returns 1 if value is less than 2. - -``` -factorial(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `floor` - -Returns the nearest integer less than or equal to a number. - -``` -floor(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `gcd` - -Returns the greatest common divisor of `expression_x` and `expression_y`. Returns 0 if both inputs are zero. - -``` -gcd(expression_x, expression_y) -``` - -#### Arguments - -- **expression_x**: First numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **expression_y**: Second numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `isnan` - -Returns true if a given number is +NaN or -NaN otherwise returns false. - -``` -isnan(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `iszero` - -Returns true if a given number is +0.0 or -0.0 otherwise returns false. - -``` -iszero(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `lcm` - -Returns the least common multiple of `expression_x` and `expression_y`. Returns 0 if either input is zero. - -``` -lcm(expression_x, expression_y) -``` - -#### Arguments - -- **expression_x**: First numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **expression_y**: Second numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `ln` - -Returns the natural logarithm of a number. - -``` -ln(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `log` - -Returns the base-x logarithm of a number. Can either provide a specified base, or if omitted then takes the base-10 of a number. - -``` -log(base, numeric_expression) -log(numeric_expression) -``` - -#### Arguments - -- **base**: Base numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `log10` - -Returns the base-10 logarithm of a number. - -``` -log10(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `log2` - -Returns the base-2 logarithm of a number. - -``` -log2(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `nanvl` - -Returns the first argument if it's not _NaN_. -Returns the second argument otherwise. - -``` -nanvl(expression_x, expression_y) -``` - -#### Arguments - -- **expression_x**: Numeric expression to return if it's not _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_y**: Numeric expression to return if the first expression is _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators. - -### `pi` - -Returns an approximate value of π. - -``` -pi() -``` - -### `pow` - -_Alias of [power](#power)._ - -### `power` - -Returns a base expression raised to the power of an exponent. - -``` -power(base, exponent) -``` - -#### Arguments - -- **base**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **exponent**: Exponent numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Aliases - -- pow - -### `radians` - -Converts degrees to radians. - -``` -radians(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `random` - -Returns a random float value in the range [0, 1). -The random seed is unique to each row. - -``` -random() -``` - -### `round` - -Rounds a number to the nearest integer. - -``` -round(numeric_expression[, decimal_places]) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **decimal_places**: Optional. The number of decimal places to round to. Defaults to 0. - -### `signum` - -Returns the sign of a number. -Negative numbers return `-1`. -Zero and positive numbers return `1`. - -``` -signum(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `sin` - -Returns the sine of a number. - -``` -sin(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `sinh` - -Returns the hyperbolic sine of a number. - -``` -sinh(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `sqrt` - -Returns the square root of a number. - -``` -sqrt(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `tan` - -Returns the tangent of a number. - -``` -tan(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `tanh` - -Returns the hyperbolic tangent of a number. - -``` -tanh(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - -### `trunc` - -Truncates a number to a whole number or truncated to the specified decimal places. - -``` -trunc(numeric_expression[, decimal_places]) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **decimal_places**: Optional. The number of decimal places to - truncate to. Defaults to 0 (truncate to a whole number). If - `decimal_places` is a positive integer, truncates digits to the - right of the decimal point. If `decimal_places` is a negative - integer, replaces digits to the left of the decimal point with `0`. - -## Conditional Functions - -- [coalesce](#coalesce) -- [ifnull](#ifnull) -- [nullif](#nullif) -- [nvl](#nvl) -- [nvl2](#nvl2) - -### `coalesce` - -Returns the first of its arguments that is not _null_. Returns _null_ if all arguments are _null_. This function is often used to substitute a default value for _null_ values. - -``` -coalesce(expression1[, ..., expression_n]) -``` - -#### Arguments - -- **expression1, expression_n**: Expression to use if previous expressions are _null_. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary. - -#### Example - -```sql -> select coalesce(null, null, 'datafusion'); -+----------------------------------------+ -| coalesce(NULL,NULL,Utf8("datafusion")) | -+----------------------------------------+ -| datafusion | -+----------------------------------------+ -``` - -### `ifnull` - -_Alias of [nvl](#nvl)._ - -### `nullif` - -Returns _null_ if _expression1_ equals _expression2_; otherwise it returns _expression1_. -This can be used to perform the inverse operation of [`coalesce`](#coalesce). - -``` -nullif(expression1, expression2) -``` - -#### Arguments - -- **expression1**: Expression to compare and return if equal to expression2. Can be a constant, column, or function, and any combination of operators. -- **expression2**: Expression to compare to expression1. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> select nullif('datafusion', 'data'); -+-----------------------------------------+ -| nullif(Utf8("datafusion"),Utf8("data")) | -+-----------------------------------------+ -| datafusion | -+-----------------------------------------+ -> select nullif('datafusion', 'datafusion'); -+-----------------------------------------------+ -| nullif(Utf8("datafusion"),Utf8("datafusion")) | -+-----------------------------------------------+ -| | -+-----------------------------------------------+ -``` - -### `nvl` - -Returns _expression2_ if _expression1_ is NULL otherwise it returns _expression1_. - -``` -nvl(expression1, expression2) -``` - -#### Arguments - -- **expression1**: Expression to return if not null. Can be a constant, column, or function, and any combination of operators. -- **expression2**: Expression to return if expr1 is null. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> select nvl(null, 'a'); -+---------------------+ -| nvl(NULL,Utf8("a")) | -+---------------------+ -| a | -+---------------------+\ -> select nvl('b', 'a'); -+--------------------------+ -| nvl(Utf8("b"),Utf8("a")) | -+--------------------------+ -| b | -+--------------------------+ -``` - -#### Aliases - -- ifnull - -### `nvl2` - -Returns _expression2_ if _expression1_ is not NULL; otherwise it returns _expression3_. - -``` -nvl2(expression1, expression2, expression3) -``` - -#### Arguments - -- **expression1**: Expression to test for null. Can be a constant, column, or function, and any combination of operators. -- **expression2**: Expression to return if expr1 is not null. Can be a constant, column, or function, and any combination of operators. -- **expression3**: Expression to return if expr1 is null. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> select nvl2(null, 'a', 'b'); -+--------------------------------+ -| nvl2(NULL,Utf8("a"),Utf8("b")) | -+--------------------------------+ -| b | -+--------------------------------+ -> select nvl2('data', 'a', 'b'); -+----------------------------------------+ -| nvl2(Utf8("data"),Utf8("a"),Utf8("b")) | -+----------------------------------------+ -| a | -+----------------------------------------+ -``` - -## String Functions - -- [ascii](#ascii) -- [bit_length](#bit_length) -- [btrim](#btrim) -- [char_length](#char_length) -- [character_length](#character_length) -- [chr](#chr) -- [concat](#concat) -- [concat_ws](#concat_ws) -- [contains](#contains) -- [ends_with](#ends_with) -- [find_in_set](#find_in_set) -- [initcap](#initcap) -- [instr](#instr) -- [left](#left) -- [length](#length) -- [levenshtein](#levenshtein) -- [lower](#lower) -- [lpad](#lpad) -- [ltrim](#ltrim) -- [octet_length](#octet_length) -- [position](#position) -- [repeat](#repeat) -- [replace](#replace) -- [reverse](#reverse) -- [right](#right) -- [rpad](#rpad) -- [rtrim](#rtrim) -- [split_part](#split_part) -- [starts_with](#starts_with) -- [strpos](#strpos) -- [substr](#substr) -- [substr_index](#substr_index) -- [substring](#substring) -- [substring_index](#substring_index) -- [to_hex](#to_hex) -- [translate](#translate) -- [trim](#trim) -- [upper](#upper) -- [uuid](#uuid) - -### `ascii` - -Returns the Unicode character code of the first character in a string. - -``` -ascii(str) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> select ascii('abc'); -+--------------------+ -| ascii(Utf8("abc")) | -+--------------------+ -| 97 | -+--------------------+ -> select ascii('🚀'); -+-------------------+ -| ascii(Utf8("🚀")) | -+-------------------+ -| 128640 | -+-------------------+ -``` - -**Related functions**: - -- [chr](#chr) - -### `bit_length` - -Returns the bit length of a string. - -``` -bit_length(str) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> select bit_length('datafusion'); -+--------------------------------+ -| bit_length(Utf8("datafusion")) | -+--------------------------------+ -| 80 | -+--------------------------------+ -``` - -**Related functions**: - -- [length](#length) -- [octet_length](#octet_length) - -### `btrim` - -Trims the specified trim string from the start and end of a string. If no trim string is provided, all whitespace is removed from the start and end of the input string. - -``` -btrim(str[, trim_str]) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **trim_str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. _Default is whitespace characters._ - -#### Example - -```sql -> select btrim('__datafusion____', '_'); -+-------------------------------------------+ -| btrim(Utf8("__datafusion____"),Utf8("_")) | -+-------------------------------------------+ -| datafusion | -+-------------------------------------------+ -``` - -#### Alternative Syntax - -```sql -trim(BOTH trim_str FROM str) -``` - -```sql -trim(trim_str FROM str) -``` - -#### Aliases - -- trim - -**Related functions**: - -- [ltrim](#ltrim) -- [rtrim](#rtrim) - -### `char_length` - -_Alias of [character_length](#character_length)._ - -### `character_length` - -Returns the number of characters in a string. - -``` -character_length(str) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> select character_length('Ångström'); -+------------------------------------+ -| character_length(Utf8("Ångström")) | -+------------------------------------+ -| 8 | -+------------------------------------+ -``` - -#### Aliases - -- length -- char_length - -**Related functions**: - -- [bit_length](#bit_length) -- [octet_length](#octet_length) - -### `chr` - -Returns the character with the specified ASCII or Unicode code value. - -``` -chr(expression) -``` - -#### Arguments - -- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> select chr(128640); -+--------------------+ -| chr(Int64(128640)) | -+--------------------+ -| 🚀 | -+--------------------+ -``` - -**Related functions**: - -- [ascii](#ascii) - -### `concat` - -Concatenates multiple strings together. - -``` -concat(str[, ..., str_n]) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **str_n**: Subsequent string expressions to concatenate. - -#### Example - -```sql -> select concat('data', 'f', 'us', 'ion'); -+-------------------------------------------------------+ -| concat(Utf8("data"),Utf8("f"),Utf8("us"),Utf8("ion")) | -+-------------------------------------------------------+ -| datafusion | -+-------------------------------------------------------+ -``` - -**Related functions**: - -- [concat_ws](#concat_ws) - -### `concat_ws` - -Concatenates multiple strings together with a specified separator. - -``` -concat_ws(separator, str[, ..., str_n]) -``` - -#### Arguments - -- **separator**: Separator to insert between concatenated strings. -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **str_n**: Subsequent string expressions to concatenate. - -#### Example - -```sql -> select concat_ws('_', 'data', 'fusion'); -+--------------------------------------------------+ -| concat_ws(Utf8("_"),Utf8("data"),Utf8("fusion")) | -+--------------------------------------------------+ -| data_fusion | -+--------------------------------------------------+ -``` - -**Related functions**: - -- [concat](#concat) - -### `contains` - -Return true if search_str is found within string (case-sensitive). - -``` -contains(str, search_str) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **search_str**: The string to search for in str. - -#### Example - -```sql -> select contains('the quick brown fox', 'row'); -+---------------------------------------------------+ -| contains(Utf8("the quick brown fox"),Utf8("row")) | -+---------------------------------------------------+ -| true | -+---------------------------------------------------+ -``` - -### `ends_with` - -Tests if a string ends with a substring. - -``` -ends_with(str, substr) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **substr**: Substring to test for. - -#### Example - -```sql -> select ends_with('datafusion', 'soin'); -+--------------------------------------------+ -| ends_with(Utf8("datafusion"),Utf8("soin")) | -+--------------------------------------------+ -| false | -+--------------------------------------------+ -> select ends_with('datafusion', 'sion'); -+--------------------------------------------+ -| ends_with(Utf8("datafusion"),Utf8("sion")) | -+--------------------------------------------+ -| true | -+--------------------------------------------+ -``` - -### `find_in_set` - -Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings. - -``` -find_in_set(str, strlist) -``` - -#### Arguments - -- **str**: String expression to find in strlist. -- **strlist**: A string list is a string composed of substrings separated by , characters. - -#### Example - -```sql -> select find_in_set('b', 'a,b,c,d'); -+----------------------------------------+ -| find_in_set(Utf8("b"),Utf8("a,b,c,d")) | -+----------------------------------------+ -| 2 | -+----------------------------------------+ -``` - -### `initcap` - -Capitalizes the first character in each word in the input string. Words are delimited by non-alphanumeric characters. - -``` -initcap(str) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> select initcap('apache datafusion'); -+------------------------------------+ -| initcap(Utf8("apache datafusion")) | -+------------------------------------+ -| Apache Datafusion | -+------------------------------------+ -``` - -**Related functions**: - -- [lower](#lower) -- [upper](#upper) - -### `instr` - -_Alias of [strpos](#strpos)._ - -### `left` - -Returns a specified number of characters from the left side of a string. - -``` -left(str, n) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **n**: Number of characters to return. - -#### Example - -```sql -> select left('datafusion', 4); -+-----------------------------------+ -| left(Utf8("datafusion"),Int64(4)) | -+-----------------------------------+ -| data | -+-----------------------------------+ -``` - -**Related functions**: - -- [right](#right) - -### `length` - -_Alias of [character_length](#character_length)._ - -### `levenshtein` - -Returns the [`Levenshtein distance`](https://en.wikipedia.org/wiki/Levenshtein_distance) between the two given strings. - -``` -levenshtein(str1, str2) -``` - -#### Arguments - -- **str1**: String expression to compute Levenshtein distance with str2. -- **str2**: String expression to compute Levenshtein distance with str1. - -#### Example - -```sql -> select levenshtein('kitten', 'sitting'); -+---------------------------------------------+ -| levenshtein(Utf8("kitten"),Utf8("sitting")) | -+---------------------------------------------+ -| 3 | -+---------------------------------------------+ -``` - -### `lower` - -Converts a string to lower-case. - -``` -lower(str) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> select lower('Ångström'); -+-------------------------+ -| lower(Utf8("Ångström")) | -+-------------------------+ -| ångström | -+-------------------------+ -``` - -**Related functions**: - -- [initcap](#initcap) -- [upper](#upper) - -### `lpad` - -Pads the left side of a string with another string to a specified string length. - -``` -lpad(str, n[, padding_str]) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **n**: String length to pad to. -- **padding_str**: Optional string expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._ - -#### Example - -```sql -> select lpad('Dolly', 10, 'hello'); -+---------------------------------------------+ -| lpad(Utf8("Dolly"),Int64(10),Utf8("hello")) | -+---------------------------------------------+ -| helloDolly | -+---------------------------------------------+ -``` - -**Related functions**: - -- [rpad](#rpad) - -### `ltrim` - -Trims the specified trim string from the beginning of a string. If no trim string is provided, all whitespace is removed from the start of the input string. - -``` -ltrim(str[, trim_str]) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **trim_str**: String expression to trim from the beginning of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._ - -#### Example - -```sql -> select ltrim(' datafusion '); -+-------------------------------+ -| ltrim(Utf8(" datafusion ")) | -+-------------------------------+ -| datafusion | -+-------------------------------+ -> select ltrim('___datafusion___', '_'); -+-------------------------------------------+ -| ltrim(Utf8("___datafusion___"),Utf8("_")) | -+-------------------------------------------+ -| datafusion___ | -+-------------------------------------------+ -``` - -#### Alternative Syntax - -```sql -trim(LEADING trim_str FROM str) -``` - -**Related functions**: - -- [btrim](#btrim) -- [rtrim](#rtrim) - -### `octet_length` - -Returns the length of a string in bytes. - -``` -octet_length(str) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> select octet_length('Ångström'); -+--------------------------------+ -| octet_length(Utf8("Ångström")) | -+--------------------------------+ -| 10 | -+--------------------------------+ -``` - -**Related functions**: - -- [bit_length](#bit_length) -- [length](#length) - -### `position` - -_Alias of [strpos](#strpos)._ - -### `repeat` - -Returns a string with an input string repeated a specified number. - -``` -repeat(str, n) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **n**: Number of times to repeat the input string. - -#### Example - -```sql -> select repeat('data', 3); -+-------------------------------+ -| repeat(Utf8("data"),Int64(3)) | -+-------------------------------+ -| datadatadata | -+-------------------------------+ -``` - -### `replace` - -Replaces all occurrences of a specified substring in a string with a new substring. - -``` -replace(str, substr, replacement) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **substr**: Substring expression to replace in the input string. Substring expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **replacement**: Replacement substring expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> select replace('ABabbaBA', 'ab', 'cd'); -+-------------------------------------------------+ -| replace(Utf8("ABabbaBA"),Utf8("ab"),Utf8("cd")) | -+-------------------------------------------------+ -| ABcdbaBA | -+-------------------------------------------------+ -``` - -### `reverse` - -Reverses the character order of a string. - -``` -reverse(str) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> select reverse('datafusion'); -+-----------------------------+ -| reverse(Utf8("datafusion")) | -+-----------------------------+ -| noisufatad | -+-----------------------------+ -``` - -### `right` - -Returns a specified number of characters from the right side of a string. - -``` -right(str, n) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **n**: Number of characters to return - -#### Example - -```sql -> select right('datafusion', 6); -+------------------------------------+ -| right(Utf8("datafusion"),Int64(6)) | -+------------------------------------+ -| fusion | -+------------------------------------+ -``` - -**Related functions**: - -- [left](#left) - -### `rpad` - -Pads the right side of a string with another string to a specified string length. - -``` -rpad(str, n[, padding_str]) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **n**: String length to pad to. -- **padding_str**: String expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._ - -#### Example - -```sql -> select rpad('datafusion', 20, '_-'); -+-----------------------------------------------+ -| rpad(Utf8("datafusion"),Int64(20),Utf8("_-")) | -+-----------------------------------------------+ -| datafusion_-_-_-_-_- | -+-----------------------------------------------+ -``` - -**Related functions**: - -- [lpad](#lpad) - -### `rtrim` - -Trims the specified trim string from the end of a string. If no trim string is provided, all whitespace is removed from the end of the input string. - -``` -rtrim(str[, trim_str]) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **trim_str**: String expression to trim from the end of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._ - -#### Example - -```sql -> select rtrim(' datafusion '); -+-------------------------------+ -| rtrim(Utf8(" datafusion ")) | -+-------------------------------+ -| datafusion | -+-------------------------------+ -> select rtrim('___datafusion___', '_'); -+-------------------------------------------+ -| rtrim(Utf8("___datafusion___"),Utf8("_")) | -+-------------------------------------------+ -| ___datafusion | -+-------------------------------------------+ -``` - -#### Alternative Syntax - -```sql -trim(TRAILING trim_str FROM str) -``` - -**Related functions**: - -- [btrim](#btrim) -- [ltrim](#ltrim) - -### `split_part` - -Splits a string based on a specified delimiter and returns the substring in the specified position. - -``` -split_part(str, delimiter, pos) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **delimiter**: String or character to split on. -- **pos**: Position of the part to return. - -#### Example - -```sql -> select split_part('1.2.3.4.5', '.', 3); -+--------------------------------------------------+ -| split_part(Utf8("1.2.3.4.5"),Utf8("."),Int64(3)) | -+--------------------------------------------------+ -| 3 | -+--------------------------------------------------+ -``` - -### `starts_with` - -Tests if a string starts with a substring. - -``` -starts_with(str, substr) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **substr**: Substring to test for. - -#### Example - -```sql -> select starts_with('datafusion','data'); -+----------------------------------------------+ -| starts_with(Utf8("datafusion"),Utf8("data")) | -+----------------------------------------------+ -| true | -+----------------------------------------------+ -``` - -### `strpos` - -Returns the starting position of a specified substring in a string. Positions begin at 1. If the substring does not exist in the string, the function returns 0. - -``` -strpos(str, substr) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **substr**: Substring expression to search for. - -#### Example - -```sql -> select strpos('datafusion', 'fus'); -+----------------------------------------+ -| strpos(Utf8("datafusion"),Utf8("fus")) | -+----------------------------------------+ -| 5 | -+----------------------------------------+ -``` - -#### Alternative Syntax - -```sql -position(substr in origstr) -``` - -#### Aliases - -- instr -- position - -### `substr` - -Extracts a substring of a specified number of characters from a specific starting position in a string. - -``` -substr(str, start_pos[, length]) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **start_pos**: Character position to start the substring at. The first character in the string has a position of 1. -- **length**: Number of characters to extract. If not specified, returns the rest of the string after the start position. - -#### Example - -```sql -> select substr('datafusion', 5, 3); -+----------------------------------------------+ -| substr(Utf8("datafusion"),Int64(5),Int64(3)) | -+----------------------------------------------+ -| fus | -+----------------------------------------------+ -``` - -#### Alternative Syntax - -```sql -substring(str from start_pos for length) -``` - -#### Aliases - -- substring - -### `substr_index` - -Returns the substring from str before count occurrences of the delimiter delim. -If count is positive, everything to the left of the final delimiter (counting from the left) is returned. -If count is negative, everything to the right of the final delimiter (counting from the right) is returned. - -``` -substr_index(str, delim, count) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **delim**: The string to find in str to split str. -- **count**: The number of times to search for the delimiter. Can be either a positive or negative number. - -#### Example - -```sql -> select substr_index('www.apache.org', '.', 1); -+---------------------------------------------------------+ -| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(1)) | -+---------------------------------------------------------+ -| www | -+---------------------------------------------------------+ -> select substr_index('www.apache.org', '.', -1); -+----------------------------------------------------------+ -| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(-1)) | -+----------------------------------------------------------+ -| org | -+----------------------------------------------------------+ -``` - -#### Aliases - -- substring_index - -### `substring` - -_Alias of [substr](#substr)._ - -### `substring_index` - -_Alias of [substr_index](#substr_index)._ - -### `to_hex` - -Converts an integer to a hexadecimal string. - -``` -to_hex(int) -``` - -#### Arguments - -- **int**: Integer expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> select to_hex(12345689); -+-------------------------+ -| to_hex(Int64(12345689)) | -+-------------------------+ -| bc6159 | -+-------------------------+ -``` - -### `translate` - -Translates characters in a string to specified translation characters. - -``` -translate(str, chars, translation) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **chars**: Characters to translate. -- **translation**: Translation characters. Translation characters replace only characters at the same position in the **chars** string. - -#### Example - -```sql -> select translate('twice', 'wic', 'her'); -+--------------------------------------------------+ -| translate(Utf8("twice"),Utf8("wic"),Utf8("her")) | -+--------------------------------------------------+ -| there | -+--------------------------------------------------+ -``` - -### `trim` - -_Alias of [btrim](#btrim)._ - -### `upper` - -Converts a string to upper-case. - -``` -upper(str) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> select upper('dataFusion'); -+---------------------------+ -| upper(Utf8("dataFusion")) | -+---------------------------+ -| DATAFUSION | -+---------------------------+ -``` - -**Related functions**: - -- [initcap](#initcap) -- [lower](#lower) - -### `uuid` - -Returns [`UUID v4`]() string value which is unique per row. - -``` -uuid() -``` - -#### Example - -```sql -> select uuid(); -+--------------------------------------+ -| uuid() | -+--------------------------------------+ -| 6ec17ef8-1934-41cc-8d59-d0c8f9eea1f0 | -+--------------------------------------+ -``` - -## Binary String Functions - -- [decode](#decode) -- [encode](#encode) - -### `decode` - -Decode binary data from textual representation in string. - -``` -decode(expression, format) -``` - -#### Arguments - -- **expression**: Expression containing encoded string data -- **format**: Same arguments as [encode](#encode) - -**Related functions**: - -- [encode](#encode) - -### `encode` - -Encode binary data into a textual representation. - -``` -encode(expression, format) -``` - -#### Arguments - -- **expression**: Expression containing string or binary data -- **format**: Supported formats are: `base64`, `hex` - -**Related functions**: - -- [decode](#decode) - -## Regular Expression Functions - -Apache DataFusion uses a [PCRE-like](https://en.wikibooks.org/wiki/Regular_Expressions/Perl-Compatible_Regular_Expressions) -regular expression [syntax](https://docs.rs/regex/latest/regex/#syntax) -(minus support for several features including look-around and backreferences). -The following regular expression functions are supported: - -- [regexp_count](#regexp_count) -- [regexp_like](#regexp_like) -- [regexp_match](#regexp_match) -- [regexp_replace](#regexp_replace) - -### `regexp_count` - -Returns the number of matches that a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has in a string. - -``` -regexp_count(str, regexp[, start, flags]) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **regexp**: Regular expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **start**: - **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function. -- **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: - - **i**: case-insensitive: letters match both upper and lower case - - **m**: multi-line mode: ^ and $ match begin/end of line - - **s**: allow . to match \n - - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used - - **U**: swap the meaning of x* and x*? - -#### Example - -```sql -> select regexp_count('abcAbAbc', 'abc', 2, 'i'); -+---------------------------------------------------------------+ -| regexp_count(Utf8("abcAbAbc"),Utf8("abc"),Int64(2),Utf8("i")) | -+---------------------------------------------------------------+ -| 1 | -+---------------------------------------------------------------+ -``` - -### `regexp_like` - -Returns true if a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has at least one match in a string, false otherwise. - -``` -regexp_like(str, regexp[, flags]) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **regexp**: Regular expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: - - **i**: case-insensitive: letters match both upper and lower case - - **m**: multi-line mode: ^ and $ match begin/end of line - - **s**: allow . to match \n - - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used - - **U**: swap the meaning of x* and x*? - -#### Example - -```sql -select regexp_like('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); -+--------------------------------------------------------+ -| regexp_like(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | -+--------------------------------------------------------+ -| true | -+--------------------------------------------------------+ -SELECT regexp_like('aBc', '(b|d)', 'i'); -+--------------------------------------------------+ -| regexp_like(Utf8("aBc"),Utf8("(b|d)"),Utf8("i")) | -+--------------------------------------------------+ -| true | -+--------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) - -### `regexp_match` - -Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string. - -``` -regexp_match(str, regexp[, flags]) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **regexp**: Regular expression to match against. - Can be a constant, column, or function. -- **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: - - **i**: case-insensitive: letters match both upper and lower case - - **m**: multi-line mode: ^ and $ match begin/end of line - - **s**: allow . to match \n - - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used - - **U**: swap the meaning of x* and x*? - -#### Example - -```sql - > select regexp_match('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); - +---------------------------------------------------------+ - | regexp_match(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | - +---------------------------------------------------------+ - | [Köln] | - +---------------------------------------------------------+ - SELECT regexp_match('aBc', '(b|d)', 'i'); - +---------------------------------------------------+ - | regexp_match(Utf8("aBc"),Utf8("(b|d)"),Utf8("i")) | - +---------------------------------------------------+ - | [B] | - +---------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) - -### `regexp_replace` - -Replaces substrings in a string that match a [regular expression](https://docs.rs/regex/latest/regex/#syntax). - -``` -regexp_replace(str, regexp, replacement[, flags]) -``` - -#### Arguments - -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **regexp**: Regular expression to match against. - Can be a constant, column, or function. -- **replacement**: Replacement string expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: -- **g**: (global) Search globally and don't return after the first match -- **i**: case-insensitive: letters match both upper and lower case -- **m**: multi-line mode: ^ and $ match begin/end of line -- **s**: allow . to match \n -- **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used -- **U**: swap the meaning of x* and x*? - -#### Example - -```sql -> select regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', 'g'); -+------------------------------------------------------------------------+ -| regexp_replace(Utf8("foobarbaz"),Utf8("b(..)"),Utf8("X\1Y"),Utf8("g")) | -+------------------------------------------------------------------------+ -| fooXarYXazY | -+------------------------------------------------------------------------+ -SELECT regexp_replace('aBc', '(b|d)', 'Ab\\1a', 'i'); -+-------------------------------------------------------------------+ -| regexp_replace(Utf8("aBc"),Utf8("(b|d)"),Utf8("Ab\1a"),Utf8("i")) | -+-------------------------------------------------------------------+ -| aAbBac | -+-------------------------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) - -## Time and Date Functions - -- [current_date](#current_date) -- [current_time](#current_time) -- [current_timestamp](#current_timestamp) -- [date_bin](#date_bin) -- [date_format](#date_format) -- [date_part](#date_part) -- [date_trunc](#date_trunc) -- [datepart](#datepart) -- [datetrunc](#datetrunc) -- [from_unixtime](#from_unixtime) -- [make_date](#make_date) -- [now](#now) -- [to_char](#to_char) -- [to_date](#to_date) -- [to_local_time](#to_local_time) -- [to_timestamp](#to_timestamp) -- [to_timestamp_micros](#to_timestamp_micros) -- [to_timestamp_millis](#to_timestamp_millis) -- [to_timestamp_nanos](#to_timestamp_nanos) -- [to_timestamp_seconds](#to_timestamp_seconds) -- [to_unixtime](#to_unixtime) -- [today](#today) - -### `current_date` - -Returns the current UTC date. - -The `current_date()` return value is determined at query time and will return the same date, no matter when in the query plan the function executes. - -``` -current_date() -``` - -#### Aliases - -- today - -### `current_time` - -Returns the current UTC time. - -The `current_time()` return value is determined at query time and will return the same time, no matter when in the query plan the function executes. - -``` -current_time() -``` - -### `current_timestamp` - -_Alias of [now](#now)._ - -### `date_bin` - -Calculates time intervals and returns the start of the interval nearest to the specified timestamp. Use `date_bin` to downsample time series data by grouping rows into time-based "bins" or "windows" and applying an aggregate or selector function to each window. - -For example, if you "bin" or "window" data into 15 minute intervals, an input timestamp of `2023-01-01T18:18:18Z` will be updated to the start time of the 15 minute bin it is in: `2023-01-01T18:15:00Z`. - -``` -date_bin(interval, expression, origin-timestamp) -``` - -#### Arguments - -- **interval**: Bin interval. -- **expression**: Time expression to operate on. Can be a constant, column, or function. -- **origin-timestamp**: Optional. Starting point used to determine bin boundaries. If not specified defaults 1970-01-01T00:00:00Z (the UNIX epoch in UTC). - -The following intervals are supported: - -- nanoseconds -- microseconds -- milliseconds -- seconds -- minutes -- hours -- days -- weeks -- months -- years -- century - -### `date_format` - -_Alias of [to_char](#to_char)._ - -### `date_part` - -Returns the specified part of the date as an integer. - -``` -date_part(part, expression) -``` - -#### Arguments - -- **part**: Part of the date to return. The following date parts are supported: - - - year - - quarter (emits value in inclusive range [1, 4] based on which quartile of the year the date is in) - - month - - week (week of the year) - - day (day of the month) - - hour - - minute - - second - - millisecond - - microsecond - - nanosecond - - dow (day of the week) - - doy (day of the year) - - epoch (seconds since Unix epoch) - -- **expression**: Time expression to operate on. Can be a constant, column, or function. - -#### Alternative Syntax - -```sql -extract(field FROM source) -``` - -#### Aliases - -- datepart - -### `date_trunc` - -Truncates a timestamp value to a specified precision. - -``` -date_trunc(precision, expression) -``` - -#### Arguments - -- **precision**: Time precision to truncate to. The following precisions are supported: - - - year / YEAR - - quarter / QUARTER - - month / MONTH - - week / WEEK - - day / DAY - - hour / HOUR - - minute / MINUTE - - second / SECOND - -- **expression**: Time expression to operate on. Can be a constant, column, or function. - -#### Aliases - -- datetrunc - -### `datepart` - -_Alias of [date_part](#date_part)._ - -### `datetrunc` - -_Alias of [date_trunc](#date_trunc)._ - -### `from_unixtime` - -Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`) return the corresponding timestamp. - -``` -from_unixtime(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. - -### `make_date` - -Make a date from year/month/day component parts. - -``` -make_date(year, month, day) -``` - -#### Arguments - -- **year**: Year to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators. -- **month**: Month to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators. -- **day**: Day to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators. - -#### Example - -```sql -> select make_date(2023, 1, 31); -+-------------------------------------------+ -| make_date(Int64(2023),Int64(1),Int64(31)) | -+-------------------------------------------+ -| 2023-01-31 | -+-------------------------------------------+ -> select make_date('2023', '01', '31'); -+-----------------------------------------------+ -| make_date(Utf8("2023"),Utf8("01"),Utf8("31")) | -+-----------------------------------------------+ -| 2023-01-31 | -+-----------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/make_date.rs) - -### `now` - -Returns the current UTC timestamp. - -The `now()` return value is determined at query time and will return the same timestamp, no matter when in the query plan the function executes. - -``` -now() -``` - -#### Aliases - -- current_timestamp - -### `to_char` - -Returns a string representation of a date, time, timestamp or duration based on a [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html). Unlike the PostgreSQL equivalent of this function numerical formatting is not supported. - -``` -to_char(expression, format) -``` - -#### Arguments - -- **expression**: Expression to operate on. Can be a constant, column, or function that results in a date, time, timestamp or duration. -- **format**: A [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) string to use to convert the expression. -- **day**: Day to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators. - -#### Example - -```sql -> select to_char('2023-03-01'::date, '%d-%m-%Y'); -+----------------------------------------------+ -| to_char(Utf8("2023-03-01"),Utf8("%d-%m-%Y")) | -+----------------------------------------------+ -| 01-03-2023 | -+----------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_char.rs) - -#### Aliases - -- date_format - -### `to_date` - -Converts a value to a date (`YYYY-MM-DD`). -Supports strings, integer and double types as input. -Strings are parsed as YYYY-MM-DD (e.g. '2023-07-20') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. -Integers and doubles are interpreted as days since the unix epoch (`1970-01-01T00:00:00Z`). -Returns the corresponding date. - -Note: `to_date` returns Date32, which represents its values as the number of days since unix epoch(`1970-01-01`) stored as signed 32 bit value. The largest supported date value is `9999-12-31`. - -``` -to_date('2017-05-31', '%Y-%m-%d') -``` - -#### Arguments - -- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order - they appear with the first successful one being returned. If none of the formats successfully parse the expression - an error will be returned. - -#### Example - -```sql -> select to_date('2023-01-31'); -+-----------------------------+ -| to_date(Utf8("2023-01-31")) | -+-----------------------------+ -| 2023-01-31 | -+-----------------------------+ -> select to_date('2023/01/31', '%Y-%m-%d', '%Y/%m/%d'); -+---------------------------------------------------------------+ -| to_date(Utf8("2023/01/31"),Utf8("%Y-%m-%d"),Utf8("%Y/%m/%d")) | -+---------------------------------------------------------------+ -| 2023-01-31 | -+---------------------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_date.rs) - -### `to_local_time` - -Converts a timestamp with a timezone to a timestamp without a timezone (with no offset or timezone information). This function handles daylight saving time changes. - -``` -to_local_time(expression) -``` - -#### Arguments - -- **expression**: Time expression to operate on. Can be a constant, column, or function. - -#### Example - -```sql -> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp); -+---------------------------------------------+ -| to_local_time(Utf8("2024-04-01T00:00:20Z")) | -+---------------------------------------------+ -| 2024-04-01T00:00:20 | -+---------------------------------------------+ - -> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels'); -+---------------------------------------------+ -| to_local_time(Utf8("2024-04-01T00:00:20Z")) | -+---------------------------------------------+ -| 2024-04-01T00:00:20 | -+---------------------------------------------+ - -> SELECT - time, - arrow_typeof(time) as type, - to_local_time(time) as to_local_time, - arrow_typeof(to_local_time(time)) as to_local_time_type -FROM ( - SELECT '2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels' AS time -); -+---------------------------+------------------------------------------------+---------------------+-----------------------------+ -| time | type | to_local_time | to_local_time_type | -+---------------------------+------------------------------------------------+---------------------+-----------------------------+ -| 2024-04-01T00:00:20+02:00 | Timestamp(Nanosecond, Some("Europe/Brussels")) | 2024-04-01T00:00:20 | Timestamp(Nanosecond, None) | -+---------------------------+------------------------------------------------+---------------------+-----------------------------+ - -# combine `to_local_time()` with `date_bin()` to bin on boundaries in the timezone rather -# than UTC boundaries - -> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AS date_bin; -+---------------------+ -| date_bin | -+---------------------+ -| 2024-04-01T00:00:00 | -+---------------------+ - -> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AT TIME ZONE 'Europe/Brussels' AS date_bin_with_timezone; -+---------------------------+ -| date_bin_with_timezone | -+---------------------------+ -| 2024-04-01T00:00:00+02:00 | -+---------------------------+ -``` - -### `to_timestamp` - -Converts a value to a timestamp (`YYYY-MM-DDT00:00:00Z`). Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats] are provided. Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. - -Note: `to_timestamp` returns `Timestamp(Nanosecond)`. The supported range for integer input is between `-9223372037` and `9223372036`. Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. Please use `to_timestamp_seconds` for the input outside of supported bounds. - -``` -to_timestamp(expression[, ..., format_n]) -``` - -#### Arguments - -- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. - -#### Example - -```sql -> select to_timestamp('2023-01-31T09:26:56.123456789-05:00'); -+-----------------------------------------------------------+ -| to_timestamp(Utf8("2023-01-31T09:26:56.123456789-05:00")) | -+-----------------------------------------------------------+ -| 2023-01-31T14:26:56.123456789 | -+-----------------------------------------------------------+ -> select to_timestamp('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); -+--------------------------------------------------------------------------------------------------------+ -| to_timestamp(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | -+--------------------------------------------------------------------------------------------------------+ -| 2023-05-17T03:59:00.123456789 | -+--------------------------------------------------------------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) - -### `to_timestamp_micros` - -Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as microseconds since the unix epoch (`1970-01-01T00:00:00Z`) Returns the corresponding timestamp. - -``` -to_timestamp_micros(expression[, ..., format_n]) -``` - -#### Arguments - -- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. - -#### Example - -```sql -> select to_timestamp_micros('2023-01-31T09:26:56.123456789-05:00'); -+------------------------------------------------------------------+ -| to_timestamp_micros(Utf8("2023-01-31T09:26:56.123456789-05:00")) | -+------------------------------------------------------------------+ -| 2023-01-31T14:26:56.123456 | -+------------------------------------------------------------------+ -> select to_timestamp_micros('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); -+---------------------------------------------------------------------------------------------------------------+ -| to_timestamp_micros(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | -+---------------------------------------------------------------------------------------------------------------+ -| 2023-05-17T03:59:00.123456 | -+---------------------------------------------------------------------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) - -### `to_timestamp_millis` - -Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. Integers and unsigned integers are interpreted as milliseconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. - -``` -to_timestamp_millis(expression[, ..., format_n]) -``` - -#### Arguments - -- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. - -#### Example - -```sql -> select to_timestamp_millis('2023-01-31T09:26:56.123456789-05:00'); -+------------------------------------------------------------------+ -| to_timestamp_millis(Utf8("2023-01-31T09:26:56.123456789-05:00")) | -+------------------------------------------------------------------+ -| 2023-01-31T14:26:56.123 | -+------------------------------------------------------------------+ -> select to_timestamp_millis('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); -+---------------------------------------------------------------------------------------------------------------+ -| to_timestamp_millis(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | -+---------------------------------------------------------------------------------------------------------------+ -| 2023-05-17T03:59:00.123 | -+---------------------------------------------------------------------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) - -### `to_timestamp_nanos` - -Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. - -``` -to_timestamp_nanos(expression[, ..., format_n]) -``` - -#### Arguments - -- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. - -#### Example - -```sql -> select to_timestamp_nanos('2023-01-31T09:26:56.123456789-05:00'); -+-----------------------------------------------------------------+ -| to_timestamp_nanos(Utf8("2023-01-31T09:26:56.123456789-05:00")) | -+-----------------------------------------------------------------+ -| 2023-01-31T14:26:56.123456789 | -+-----------------------------------------------------------------+ -> select to_timestamp_nanos('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); -+--------------------------------------------------------------------------------------------------------------+ -| to_timestamp_nanos(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | -+--------------------------------------------------------------------------------------------------------------+ -| 2023-05-17T03:59:00.123456789 | -+---------------------------------------------------------------------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) - -### `to_timestamp_seconds` - -Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. - -``` -to_timestamp_seconds(expression[, ..., format_n]) -``` - -#### Arguments - -- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. - -#### Example - -```sql -> select to_timestamp_seconds('2023-01-31T09:26:56.123456789-05:00'); -+-------------------------------------------------------------------+ -| to_timestamp_seconds(Utf8("2023-01-31T09:26:56.123456789-05:00")) | -+-------------------------------------------------------------------+ -| 2023-01-31T14:26:56 | -+-------------------------------------------------------------------+ -> select to_timestamp_seconds('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); -+----------------------------------------------------------------------------------------------------------------+ -| to_timestamp_seconds(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | -+----------------------------------------------------------------------------------------------------------------+ -| 2023-05-17T03:59:00 | -+----------------------------------------------------------------------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) - -### `to_unixtime` - -Converts a value to seconds since the unix epoch (`1970-01-01T00:00:00Z`). Supports strings, dates, timestamps and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. - -``` -to_unixtime(expression[, ..., format_n]) -``` - -#### Arguments - -- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. - -#### Example - -```sql -> select to_unixtime('2020-09-08T12:00:00+00:00'); -+------------------------------------------------+ -| to_unixtime(Utf8("2020-09-08T12:00:00+00:00")) | -+------------------------------------------------+ -| 1599566400 | -+------------------------------------------------+ -> select to_unixtime('01-14-2023 01:01:30+05:30', '%q', '%d-%m-%Y %H/%M/%S', '%+', '%m-%d-%Y %H:%M:%S%#z'); -+-----------------------------------------------------------------------------------------------------------------------------+ -| to_unixtime(Utf8("01-14-2023 01:01:30+05:30"),Utf8("%q"),Utf8("%d-%m-%Y %H/%M/%S"),Utf8("%+"),Utf8("%m-%d-%Y %H:%M:%S%#z")) | -+-----------------------------------------------------------------------------------------------------------------------------+ -| 1673638290 | -+-----------------------------------------------------------------------------------------------------------------------------+ -``` - -### `today` - -_Alias of [current_date](#current_date)._ - -## Array Functions - -- [array_any_value](#array_any_value) -- [array_append](#array_append) -- [array_cat](#array_cat) -- [array_concat](#array_concat) -- [array_contains](#array_contains) -- [array_dims](#array_dims) -- [array_distance](#array_distance) -- [array_distinct](#array_distinct) -- [array_element](#array_element) -- [array_empty](#array_empty) -- [array_except](#array_except) -- [array_extract](#array_extract) -- [array_has](#array_has) -- [array_has_all](#array_has_all) -- [array_has_any](#array_has_any) -- [array_indexof](#array_indexof) -- [array_intersect](#array_intersect) -- [array_join](#array_join) -- [array_length](#array_length) -- [array_ndims](#array_ndims) -- [array_pop_back](#array_pop_back) -- [array_pop_front](#array_pop_front) -- [array_position](#array_position) -- [array_positions](#array_positions) -- [array_prepend](#array_prepend) -- [array_push_back](#array_push_back) -- [array_push_front](#array_push_front) -- [array_remove](#array_remove) -- [array_remove_all](#array_remove_all) -- [array_remove_n](#array_remove_n) -- [array_repeat](#array_repeat) -- [array_replace](#array_replace) -- [array_replace_all](#array_replace_all) -- [array_replace_n](#array_replace_n) -- [array_resize](#array_resize) -- [array_reverse](#array_reverse) -- [array_slice](#array_slice) -- [array_sort](#array_sort) -- [array_to_string](#array_to_string) -- [array_union](#array_union) -- [cardinality](#cardinality) -- [empty](#empty) -- [flatten](#flatten) -- [generate_series](#generate_series) -- [list_any_value](#list_any_value) -- [list_append](#list_append) -- [list_cat](#list_cat) -- [list_concat](#list_concat) -- [list_contains](#list_contains) -- [list_dims](#list_dims) -- [list_distance](#list_distance) -- [list_distinct](#list_distinct) -- [list_element](#list_element) -- [list_empty](#list_empty) -- [list_except](#list_except) -- [list_extract](#list_extract) -- [list_has](#list_has) -- [list_has_all](#list_has_all) -- [list_has_any](#list_has_any) -- [list_indexof](#list_indexof) -- [list_intersect](#list_intersect) -- [list_join](#list_join) -- [list_length](#list_length) -- [list_ndims](#list_ndims) -- [list_pop_back](#list_pop_back) -- [list_pop_front](#list_pop_front) -- [list_position](#list_position) -- [list_positions](#list_positions) -- [list_prepend](#list_prepend) -- [list_push_back](#list_push_back) -- [list_push_front](#list_push_front) -- [list_remove](#list_remove) -- [list_remove_all](#list_remove_all) -- [list_remove_n](#list_remove_n) -- [list_repeat](#list_repeat) -- [list_replace](#list_replace) -- [list_replace_all](#list_replace_all) -- [list_replace_n](#list_replace_n) -- [list_resize](#list_resize) -- [list_reverse](#list_reverse) -- [list_slice](#list_slice) -- [list_sort](#list_sort) -- [list_to_string](#list_to_string) -- [list_union](#list_union) -- [make_array](#make_array) -- [make_list](#make_list) -- [range](#range) -- [string_to_array](#string_to_array) -- [string_to_list](#string_to_list) - -### `array_any_value` - -Extracts the element with the index n from the array. - -``` -array_element(array, index) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **index**: Index to extract the element from the array. - -#### Example - -```sql -> select array_element([1, 2, 3, 4], 3); -+-----------------------------------------+ -| array_element(List([1,2,3,4]),Int64(3)) | -+-----------------------------------------+ -| 3 | -+-----------------------------------------+ -``` - -#### Aliases - -- list_any_value - -### `array_append` - -Appends an element to the end of an array. - -``` -array_append(array, element) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to append to the array. - -#### Example - -```sql -> select array_append([1, 2, 3], 4); -+--------------------------------------+ -| array_append(List([1,2,3]),Int64(4)) | -+--------------------------------------+ -| [1, 2, 3, 4] | -+--------------------------------------+ -``` - -#### Aliases - -- list_append -- array_push_back -- list_push_back - -### `array_cat` - -_Alias of [array_concat](#array_concat)._ - -### `array_concat` - -Appends an element to the end of an array. - -``` -array_append(array, element) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to append to the array. - -#### Example - -```sql -> select array_append([1, 2, 3], 4); -+--------------------------------------+ -| array_append(List([1,2,3]),Int64(4)) | -+--------------------------------------+ -| [1, 2, 3, 4] | -+--------------------------------------+ -``` - -#### Aliases - -- array_cat -- list_concat -- list_cat - -### `array_contains` - -_Alias of [array_has](#array_has)._ - -### `array_dims` - -Returns an array of the array's dimensions. - -``` -array_dims(array) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. - -#### Example - -```sql -> select array_dims([[1, 2, 3], [4, 5, 6]]); -+---------------------------------+ -| array_dims(List([1,2,3,4,5,6])) | -+---------------------------------+ -| [2, 3] | -+---------------------------------+ -``` - -#### Aliases - -- list_dims - -### `array_distance` - -Returns the Euclidean distance between two input arrays of equal length. - -``` -array_distance(array1, array2) -``` - -#### Arguments - -- **array1**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **array2**: Array expression. Can be a constant, column, or function, and any combination of array operators. - -#### Example - -```sql -> select array_distance([1, 2], [1, 4]); -+------------------------------------+ -| array_distance(List([1,2], [1,4])) | -+------------------------------------+ -| 2.0 | -+------------------------------------+ -``` - -#### Aliases - -- list_distance - -### `array_distinct` - -Returns distinct values from the array after removing duplicates. - -``` -array_distinct(array) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. - -#### Example - -```sql -> select array_distinct([1, 3, 2, 3, 1, 2, 4]); -+---------------------------------+ -| array_distinct(List([1,2,3,4])) | -+---------------------------------+ -| [1, 2, 3, 4] | -+---------------------------------+ -``` - -#### Aliases - -- list_distinct - -### `array_element` - -Extracts the element with the index n from the array. - -``` -array_element(array, index) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **index**: Index to extract the element from the array. - -#### Example - -```sql -> select array_element([1, 2, 3, 4], 3); -+-----------------------------------------+ -| array_element(List([1,2,3,4]),Int64(3)) | -+-----------------------------------------+ -| 3 | -+-----------------------------------------+ -``` - -#### Aliases - -- array_extract -- list_element -- list_extract - -### `array_empty` - -_Alias of [empty](#empty)._ - -### `array_except` - -Returns an array of the elements that appear in the first array but not in the second. - -``` -array_except(array1, array2) -``` - -#### Arguments - -- **array1**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **array2**: Array expression. Can be a constant, column, or function, and any combination of array operators. - -#### Example - -```sql -> select array_except([1, 2, 3, 4], [5, 6, 3, 4]); -+----------------------------------------------------+ -| array_except([1, 2, 3, 4], [5, 6, 3, 4]); | -+----------------------------------------------------+ -| [1, 2] | -+----------------------------------------------------+ -> select array_except([1, 2, 3, 4], [3, 4, 5, 6]); -+----------------------------------------------------+ -| array_except([1, 2, 3, 4], [3, 4, 5, 6]); | -+----------------------------------------------------+ -| [1, 2] | -+----------------------------------------------------+ -``` - -#### Aliases - -- list_except - -### `array_extract` - -_Alias of [array_element](#array_element)._ - -### `array_has` - -Returns true if the array contains the element. - -``` -array_has(array, element) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **element**: Scalar or Array expression. Can be a constant, column, or function, and any combination of array operators. - -#### Example - -```sql -> select array_has([1, 2, 3], 2); -+-----------------------------+ -| array_has(List([1,2,3]), 2) | -+-----------------------------+ -| true | -+-----------------------------+ -``` - -#### Aliases - -- list_has -- array_contains -- list_contains - -### `array_has_all` - -Returns true if the array contains the element. - -``` -array_has(array, element) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **element**: Scalar or Array expression. Can be a constant, column, or function, and any combination of array operators. - -#### Example - -```sql -> select array_has([1, 2, 3], 2); -+-----------------------------+ -| array_has(List([1,2,3]), 2) | -+-----------------------------+ -| true | -+-----------------------------+ -``` - -#### Aliases - -- list_has_all - -### `array_has_any` - -Returns true if the array contains the element. - -``` -array_has(array, element) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **element**: Scalar or Array expression. Can be a constant, column, or function, and any combination of array operators. - -#### Example - -```sql -> select array_has([1, 2, 3], 2); -+-----------------------------+ -| array_has(List([1,2,3]), 2) | -+-----------------------------+ -| true | -+-----------------------------+ -``` - -#### Aliases - -- list_has_any - -### `array_indexof` - -_Alias of [array_position](#array_position)._ - -### `array_intersect` - -Returns distinct values from the array after removing duplicates. - -``` -array_distinct(array) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. - -#### Example - -```sql -> select array_distinct([1, 3, 2, 3, 1, 2, 4]); -+---------------------------------+ -| array_distinct(List([1,2,3,4])) | -+---------------------------------+ -| [1, 2, 3, 4] | -+---------------------------------+ -``` - -#### Aliases - -- list_intersect - -### `array_join` - -_Alias of [array_to_string](#array_to_string)._ - -### `array_length` - -Returns the length of the array dimension. - -``` -array_length(array, dimension) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **dimension**: Array dimension. - -#### Example - -```sql -> select array_length([1, 2, 3, 4, 5], 1); -+-------------------------------------------+ -| array_length(List([1,2,3,4,5]), 1) | -+-------------------------------------------+ -| 5 | -+-------------------------------------------+ -``` - -#### Aliases - -- list_length - -### `array_ndims` - -Returns an array of the array's dimensions. - -``` -array_dims(array) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. - -#### Example - -```sql -> select array_dims([[1, 2, 3], [4, 5, 6]]); -+---------------------------------+ -| array_dims(List([1,2,3,4,5,6])) | -+---------------------------------+ -| [2, 3] | -+---------------------------------+ -``` - -#### Aliases - -- list_ndims - -### `array_pop_back` - -Extracts the element with the index n from the array. - -``` -array_element(array, index) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **index**: Index to extract the element from the array. - -#### Example - -```sql -> select array_element([1, 2, 3, 4], 3); -+-----------------------------------------+ -| array_element(List([1,2,3,4]),Int64(3)) | -+-----------------------------------------+ -| 3 | -+-----------------------------------------+ -``` - -#### Aliases - -- list_pop_back - -### `array_pop_front` - -Extracts the element with the index n from the array. - -``` -array_element(array, index) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **index**: Index to extract the element from the array. - -#### Example - -```sql -> select array_element([1, 2, 3, 4], 3); -+-----------------------------------------+ -| array_element(List([1,2,3,4]),Int64(3)) | -+-----------------------------------------+ -| 3 | -+-----------------------------------------+ -``` - -#### Aliases - -- list_pop_front - -### `array_position` - -Returns the position of the first occurrence of the specified element in the array. - -``` -array_position(array, element) -array_position(array, element, index) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to search for position in the array. -- **index**: Index at which to start searching. - -#### Example - -```sql -> select array_position([1, 2, 2, 3, 1, 4], 2); -+----------------------------------------------+ -| array_position(List([1,2,2,3,1,4]),Int64(2)) | -+----------------------------------------------+ -| 2 | -+----------------------------------------------+ -> select array_position([1, 2, 2, 3, 1, 4], 2, 3); -+----------------------------------------------------+ -| array_position(List([1,2,2,3,1,4]),Int64(2), Int64(3)) | -+----------------------------------------------------+ -| 3 | -+----------------------------------------------------+ -``` - -#### Aliases - -- list_position -- array_indexof -- list_indexof - -### `array_positions` - -Returns the position of the first occurrence of the specified element in the array. - -``` -array_position(array, element) -array_position(array, element, index) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to search for position in the array. -- **index**: Index at which to start searching. - -#### Example - -```sql -> select array_position([1, 2, 2, 3, 1, 4], 2); -+----------------------------------------------+ -| array_position(List([1,2,2,3,1,4]),Int64(2)) | -+----------------------------------------------+ -| 2 | -+----------------------------------------------+ -> select array_position([1, 2, 2, 3, 1, 4], 2, 3); -+----------------------------------------------------+ -| array_position(List([1,2,2,3,1,4]),Int64(2), Int64(3)) | -+----------------------------------------------------+ -| 3 | -+----------------------------------------------------+ -``` - -#### Aliases - -- list_positions - -### `array_prepend` - -Appends an element to the end of an array. - -``` -array_append(array, element) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to append to the array. - -#### Example - -```sql -> select array_append([1, 2, 3], 4); -+--------------------------------------+ -| array_append(List([1,2,3]),Int64(4)) | -+--------------------------------------+ -| [1, 2, 3, 4] | -+--------------------------------------+ -``` - -#### Aliases - -- list_prepend -- array_push_front -- list_push_front - -### `array_push_back` - -_Alias of [array_append](#array_append)._ - -### `array_push_front` - -_Alias of [array_prepend](#array_prepend)._ - -### `array_remove` - -Removes the first element from the array equal to the given value. - -``` -array_remove(array, element) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to be removed from the array. - -#### Example - -```sql -> select array_remove([1, 2, 2, 3, 2, 1, 4], 2); -+----------------------------------------------+ -| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) | -+----------------------------------------------+ -| [1, 2, 3, 2, 1, 4] | -+----------------------------------------------+ -``` - -#### Aliases - -- list_remove - -### `array_remove_all` - -Removes the first element from the array equal to the given value. - -``` -array_remove(array, element) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to be removed from the array. - -#### Example - -```sql -> select array_remove([1, 2, 2, 3, 2, 1, 4], 2); -+----------------------------------------------+ -| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) | -+----------------------------------------------+ -| [1, 2, 3, 2, 1, 4] | -+----------------------------------------------+ -``` - -#### Aliases - -- list_remove_all - -### `array_remove_n` - -Removes the first element from the array equal to the given value. - -``` -array_remove(array, element) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to be removed from the array. - -#### Example - -```sql -> select array_remove([1, 2, 2, 3, 2, 1, 4], 2); -+----------------------------------------------+ -| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) | -+----------------------------------------------+ -| [1, 2, 3, 2, 1, 4] | -+----------------------------------------------+ -``` - -#### Aliases - -- list_remove_n - -### `array_repeat` - -Returns an array containing element `count` times. - -``` -array_repeat(element, count) -``` - -#### Arguments - -- **element**: Element expression. Can be a constant, column, or function, and any combination of array operators. -- **count**: Value of how many times to repeat the element. - -#### Example - -```sql -> select array_repeat(1, 3); -+---------------------------------+ -| array_repeat(Int64(1),Int64(3)) | -+---------------------------------+ -| [1, 1, 1] | -+---------------------------------+ -> select array_repeat([1, 2], 2); -+------------------------------------+ -| array_repeat(List([1,2]),Int64(2)) | -+------------------------------------+ -| [[1, 2], [1, 2]] | -+------------------------------------+ -``` - -#### Aliases - -- list_repeat - -### `array_replace` - -Replaces the first `max` occurrences of the specified element with another specified element. - -``` -array_replace_n(array, from, to, max) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **from**: Initial element. -- **to**: Final element. -- **max**: Number of first occurrences to replace. - -#### Example - -```sql -> select array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2); -+-------------------------------------------------------------------+ -| array_replace_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(5),Int64(2)) | -+-------------------------------------------------------------------+ -| [1, 5, 5, 3, 2, 1, 4] | -+-------------------------------------------------------------------+ -``` - -#### Aliases - -- list_replace - -### `array_replace_all` - -Replaces the first `max` occurrences of the specified element with another specified element. - -``` -array_replace_n(array, from, to, max) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **from**: Initial element. -- **to**: Final element. -- **max**: Number of first occurrences to replace. - -#### Example - -```sql -> select array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2); -+-------------------------------------------------------------------+ -| array_replace_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(5),Int64(2)) | -+-------------------------------------------------------------------+ -| [1, 5, 5, 3, 2, 1, 4] | -+-------------------------------------------------------------------+ -``` - -#### Aliases - -- list_replace_all - -### `array_replace_n` - -Replaces the first `max` occurrences of the specified element with another specified element. - -``` -array_replace_n(array, from, to, max) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **from**: Initial element. -- **to**: Final element. -- **max**: Number of first occurrences to replace. - -#### Example - -```sql -> select array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2); -+-------------------------------------------------------------------+ -| array_replace_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(5),Int64(2)) | -+-------------------------------------------------------------------+ -| [1, 5, 5, 3, 2, 1, 4] | -+-------------------------------------------------------------------+ -``` - -#### Aliases - -- list_replace_n - -### `array_resize` - -Resizes the list to contain size elements. Initializes new elements with value or empty if value is not set. - -``` -array_resize(array, size, value) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **size**: New size of given array. -- **value**: Defines new elements' value or empty if value is not set. - -#### Example - -```sql -> select array_resize([1, 2, 3], 5, 0); -+-------------------------------------+ -| array_resize(List([1,2,3],5,0)) | -+-------------------------------------+ -| [1, 2, 3, 0, 0] | -+-------------------------------------+ -``` - -#### Aliases - -- list_resize - -### `array_reverse` - -Returns the array with the order of the elements reversed. - -``` -array_reverse(array) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. - -#### Example - -```sql -> select array_reverse([1, 2, 3, 4]); -+------------------------------------------------------------+ -| array_reverse(List([1, 2, 3, 4])) | -+------------------------------------------------------------+ -| [4, 3, 2, 1] | -+------------------------------------------------------------+ -``` - -#### Aliases - -- list_reverse - -### `array_slice` - -Extracts the element with the index n from the array. - -``` -array_element(array, index) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **index**: Index to extract the element from the array. - -#### Example - -```sql -> select array_element([1, 2, 3, 4], 3); -+-----------------------------------------+ -| array_element(List([1,2,3,4]),Int64(3)) | -+-----------------------------------------+ -| 3 | -+-----------------------------------------+ -``` - -#### Aliases - -- list_slice - -### `array_sort` - -Sort array. - -``` -array_sort(array, desc, nulls_first) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **desc**: Whether to sort in descending order(`ASC` or `DESC`). -- **nulls_first**: Whether to sort nulls first(`NULLS FIRST` or `NULLS LAST`). - -#### Example - -```sql -> select array_sort([3, 1, 2]); -+-----------------------------+ -| array_sort(List([3,1,2])) | -+-----------------------------+ -| [1, 2, 3] | -+-----------------------------+ -``` - -#### Aliases - -- list_sort - -### `array_to_string` - -Converts each element to its text representation. - -``` -array_to_string(array, delimiter) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **delimiter**: Array element separator. - -#### Example - -```sql -> select array_to_string([[1, 2, 3, 4], [5, 6, 7, 8]], ','); -+----------------------------------------------------+ -| array_to_string(List([1,2,3,4,5,6,7,8]),Utf8(",")) | -+----------------------------------------------------+ -| 1,2,3,4,5,6,7,8 | -+----------------------------------------------------+ -``` - -#### Aliases - -- list_to_string -- array_join -- list_join - -### `array_union` - -Returns distinct values from the array after removing duplicates. - -``` -array_distinct(array) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. - -#### Example - -```sql -> select array_distinct([1, 3, 2, 3, 1, 2, 4]); -+---------------------------------+ -| array_distinct(List([1,2,3,4])) | -+---------------------------------+ -| [1, 2, 3, 4] | -+---------------------------------+ -``` - -#### Aliases - -- list_union - -### `cardinality` - -Returns the total number of elements in the array. - -``` -cardinality(array) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. - -#### Example - -```sql -> select cardinality([[1, 2, 3, 4], [5, 6, 7, 8]]); -+--------------------------------------+ -| cardinality(List([1,2,3,4,5,6,7,8])) | -+--------------------------------------+ -| 8 | -+--------------------------------------+ -``` - -### `empty` - -Returns 1 for an empty array or 0 for a non-empty array. - -``` -empty(array) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. - -#### Example - -```sql -> select empty([1]); -+------------------+ -| empty(List([1])) | -+------------------+ -| 0 | -+------------------+ -``` - -#### Aliases - -- array_empty -- list_empty - -### `flatten` - -Converts an array of arrays to a flat array. - -- Applies to any depth of nested arrays -- Does not change arrays that are already flat - -The flattened array contains all the elements from all source arrays. - -``` -flatten(array) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. - -#### Example - -```sql -> select flatten([[1, 2], [3, 4]]); -+------------------------------+ -| flatten(List([1,2], [3,4])) | -+------------------------------+ -| [1, 2, 3, 4] | -+------------------------------+ -``` - -### `generate_series` - -Similar to the range function, but it includes the upper bound. - -``` -generate_series(start, stop, step) -``` - -#### Arguments - -- **start**: start of the series. Ints, timestamps, dates or string types that can be coerced to Date32 are supported. -- **end**: end of the series (included). Type must be the same as start. -- **step**: increase by step (can not be 0). Steps less than a day are supported only for timestamp ranges. - -#### Example - -```sql -> select generate_series(1,3); -+------------------------------------+ -| generate_series(Int64(1),Int64(3)) | -+------------------------------------+ -| [1, 2, 3] | -+------------------------------------+ -``` - -### `list_any_value` - -_Alias of [array_any_value](#array_any_value)._ - -### `list_append` - -_Alias of [array_append](#array_append)._ - -### `list_cat` - -_Alias of [array_concat](#array_concat)._ - -### `list_concat` - -_Alias of [array_concat](#array_concat)._ - -### `list_contains` - -_Alias of [array_has](#array_has)._ - -### `list_dims` - -_Alias of [array_dims](#array_dims)._ - -### `list_distance` - -_Alias of [array_distance](#array_distance)._ - -### `list_distinct` - -_Alias of [array_distinct](#array_distinct)._ - -### `list_element` - -_Alias of [array_element](#array_element)._ - -### `list_empty` - -_Alias of [empty](#empty)._ - -### `list_except` - -_Alias of [array_except](#array_except)._ - -### `list_extract` - -_Alias of [array_element](#array_element)._ - -### `list_has` - -_Alias of [array_has](#array_has)._ - -### `list_has_all` - -_Alias of [array_has_all](#array_has_all)._ - -### `list_has_any` - -_Alias of [array_has_any](#array_has_any)._ - -### `list_indexof` - -_Alias of [array_position](#array_position)._ - -### `list_intersect` - -_Alias of [array_intersect](#array_intersect)._ - -### `list_join` - -_Alias of [array_to_string](#array_to_string)._ - -### `list_length` - -_Alias of [array_length](#array_length)._ - -### `list_ndims` - -_Alias of [array_ndims](#array_ndims)._ - -### `list_pop_back` - -_Alias of [array_pop_back](#array_pop_back)._ - -### `list_pop_front` - -_Alias of [array_pop_front](#array_pop_front)._ - -### `list_position` - -_Alias of [array_position](#array_position)._ - -### `list_positions` - -_Alias of [array_positions](#array_positions)._ - -### `list_prepend` - -_Alias of [array_prepend](#array_prepend)._ - -### `list_push_back` - -_Alias of [array_append](#array_append)._ - -### `list_push_front` - -_Alias of [array_prepend](#array_prepend)._ - -### `list_remove` - -_Alias of [array_remove](#array_remove)._ - -### `list_remove_all` - -_Alias of [array_remove_all](#array_remove_all)._ - -### `list_remove_n` - -_Alias of [array_remove_n](#array_remove_n)._ - -### `list_repeat` - -_Alias of [array_repeat](#array_repeat)._ - -### `list_replace` - -_Alias of [array_replace](#array_replace)._ - -### `list_replace_all` - -_Alias of [array_replace_all](#array_replace_all)._ - -### `list_replace_n` - -_Alias of [array_replace_n](#array_replace_n)._ - -### `list_resize` - -_Alias of [array_resize](#array_resize)._ - -### `list_reverse` - -_Alias of [array_reverse](#array_reverse)._ - -### `list_slice` - -_Alias of [array_slice](#array_slice)._ - -### `list_sort` - -_Alias of [array_sort](#array_sort)._ - -### `list_to_string` - -_Alias of [array_to_string](#array_to_string)._ - -### `list_union` - -_Alias of [array_union](#array_union)._ - -### `make_array` - -Returns an array using the specified input expressions. - -``` -make_array(expression1[, ..., expression_n]) -``` - -#### Arguments - -- **expression_n**: Expression to include in the output array. Can be a constant, column, or function, and any combination of arithmetic or string operators. - -#### Example - -```sql -> select make_array(1, 2, 3, 4, 5); -+----------------------------------------------------------+ -| make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)) | -+----------------------------------------------------------+ -| [1, 2, 3, 4, 5] | -+----------------------------------------------------------+ -``` - -#### Aliases - -- make_list - -### `make_list` - -_Alias of [make_array](#make_array)._ - -### `range` - -Returns an Arrow array between start and stop with step. The range start..end contains all values with start <= x < end. It is empty if start >= end. Step cannot be 0. - -``` -range(start, stop, step) -``` - -#### Arguments - -- **start**: Start of the range. Ints, timestamps, dates or string types that can be coerced to Date32 are supported. -- **end**: End of the range (not included). Type must be the same as start. -- **step**: Increase by step (cannot be 0). Steps less than a day are supported only for timestamp ranges. - -#### Example - -```sql -> select range(2, 10, 3); -+-----------------------------------+ -| range(Int64(2),Int64(10),Int64(3))| -+-----------------------------------+ -| [2, 5, 8] | -+-----------------------------------+ - -> select range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH); -+--------------------------------------------------------------+ -| range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH) | -+--------------------------------------------------------------+ -| [1992-09-01, 1992-10-01, 1992-11-01, 1992-12-01, 1993-01-01, 1993-02-01] | -+--------------------------------------------------------------+ -``` - -### `string_to_array` - -Converts each element to its text representation. - -``` -array_to_string(array, delimiter) -``` - -#### Arguments - -- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **delimiter**: Array element separator. - -#### Example - -```sql -> select array_to_string([[1, 2, 3, 4], [5, 6, 7, 8]], ','); -+----------------------------------------------------+ -| array_to_string(List([1,2,3,4,5,6,7,8]),Utf8(",")) | -+----------------------------------------------------+ -| 1,2,3,4,5,6,7,8 | -+----------------------------------------------------+ -``` - -#### Aliases - -- string_to_list - -### `string_to_list` - -_Alias of [string_to_array](#string_to_array)._ - -## Struct Functions - -- [named_struct](#named_struct) -- [row](#row) -- [struct](#struct) - -### `named_struct` - -Returns an Arrow struct using the specified name and input expressions pairs. - -``` -named_struct(expression1_name, expression1_input[, ..., expression_n_name, expression_n_input]) -``` - -#### Arguments - -- **expression_n_name**: Name of the column field. Must be a constant string. -- **expression_n_input**: Expression to include in the output struct. Can be a constant, column, or function, and any combination of arithmetic or string operators. - -#### Example - -For example, this query converts two columns `a` and `b` to a single column with -a struct type of fields `field_a` and `field_b`: - -```sql -> select * from t; -+---+---+ -| a | b | -+---+---+ -| 1 | 2 | -| 3 | 4 | -+---+---+ -> select named_struct('field_a', a, 'field_b', b) from t; -+-------------------------------------------------------+ -| named_struct(Utf8("field_a"),t.a,Utf8("field_b"),t.b) | -+-------------------------------------------------------+ -| {field_a: 1, field_b: 2} | -| {field_a: 3, field_b: 4} | -+-------------------------------------------------------+ -``` - -### `row` - -_Alias of [struct](#struct)._ - -### `struct` - -Returns an Arrow struct using the specified input expressions optionally named. -Fields in the returned struct use the optional name or the `cN` naming convention. -For example: `c0`, `c1`, `c2`, etc. - -``` -struct(expression1[, ..., expression_n]) -``` - -#### Arguments - -- **expression1, expression_n**: Expression to include in the output struct. Can be a constant, column, or function, any combination of arithmetic or string operators. - -#### Example - -For example, this query converts two columns `a` and `b` to a single column with -a struct type of fields `field_a` and `c1`: - -```sql -> select * from t; -+---+---+ -| a | b | -+---+---+ -| 1 | 2 | -| 3 | 4 | -+---+---+ - --- use default names `c0`, `c1` -> select struct(a, b) from t; -+-----------------+ -| struct(t.a,t.b) | -+-----------------+ -| {c0: 1, c1: 2} | -| {c0: 3, c1: 4} | -+-----------------+ - --- name the first field `field_a` -select struct(a as field_a, b) from t; -+--------------------------------------------------+ -| named_struct(Utf8("field_a"),t.a,Utf8("c1"),t.b) | -+--------------------------------------------------+ -| {field_a: 1, c1: 2} | -| {field_a: 3, c1: 4} | -+--------------------------------------------------+ -``` - -#### Aliases - -- row - -## Map Functions - -- [element_at](#element_at) -- [map](#map) -- [map_extract](#map_extract) -- [map_keys](#map_keys) -- [map_values](#map_values) - -### `element_at` - -_Alias of [map_extract](#map_extract)._ - -### `map` - -Returns an Arrow map with the specified key-value pairs. - -The `make_map` function creates a map from two lists: one for keys and one for values. Each key must be unique and non-null. - -``` -map(key, value) -map(key: value) -make_map(['key1', 'key2'], ['value1', 'value2']) -``` - -#### Arguments - -- **key**: For `map`: Expression to be used for key. Can be a constant, column, function, or any combination of arithmetic or string operators. - For `make_map`: The list of keys to be used in the map. Each key must be unique and non-null. -- **value**: For `map`: Expression to be used for value. Can be a constant, column, function, or any combination of arithmetic or string operators. - For `make_map`: The list of values to be mapped to the corresponding keys. - -#### Example - -````sql - -- Using map function - SELECT MAP('type', 'test'); - ---- - {type: test} - - SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, null]); - ---- - {POST: 41, HEAD: 33, PATCH: } - - SELECT MAP([[1,2], [3,4]], ['a', 'b']); - ---- - {[1, 2]: a, [3, 4]: b} - - SELECT MAP { 'a': 1, 'b': 2 }; - ---- - {a: 1, b: 2} - - -- Using make_map function - SELECT MAKE_MAP(['POST', 'HEAD'], [41, 33]); - ---- - {POST: 41, HEAD: 33} - - SELECT MAKE_MAP(['key1', 'key2'], ['value1', null]); - ---- - {key1: value1, key2: } - ``` - - -### `map_extract` - -Returns a list containing the value for the given key or an empty list if the key is not present in the map. - -```` - -map_extract(map, key) - -```` -#### Arguments - -- **map**: Map expression. Can be a constant, column, or function, and any combination of map operators. -- **key**: Key to extract from the map. Can be a constant, column, or function, any combination of arithmetic or string operators, or a named expression of the previously listed. - -#### Example - -```sql -SELECT map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'a'); ----- -[1] - -SELECT map_extract(MAP {1: 'one', 2: 'two'}, 2); ----- -['two'] - -SELECT map_extract(MAP {'x': 10, 'y': NULL, 'z': 30}, 'y'); ----- -[] -```` - -#### Aliases - -- element_at - -### `map_keys` - -Returns a list of all keys in the map. - -``` -map_keys(map) -``` - -#### Arguments - -- **map**: Map expression. Can be a constant, column, or function, and any combination of map operators. - -#### Example - -```sql -SELECT map_keys(MAP {'a': 1, 'b': NULL, 'c': 3}); ----- -[a, b, c] - -SELECT map_keys(map([100, 5], [42, 43])); ----- -[100, 5] -``` - -### `map_values` - -Returns a list of all values in the map. - -``` -map_values(map) -``` - -#### Arguments - -- **map**: Map expression. Can be a constant, column, or function, and any combination of map operators. - -#### Example - -```sql -SELECT map_values(MAP {'a': 1, 'b': NULL, 'c': 3}); ----- -[1, , 3] - -SELECT map_values(map([100, 5], [42, 43])); ----- -[42, 43] -``` - -## Hashing Functions - -- [digest](#digest) -- [md5](#md5) -- [sha224](#sha224) -- [sha256](#sha256) -- [sha384](#sha384) -- [sha512](#sha512) - -### `digest` - -Computes the binary hash of an expression using the specified algorithm. - -``` -digest(expression, algorithm) -``` - -#### Arguments - -- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -- **algorithm**: String expression specifying algorithm to use. Must be one of: -- md5 -- sha224 -- sha256 -- sha384 -- sha512 -- blake2s -- blake2b -- blake3 - -#### Example - -```sql -> select digest('foo', 'sha256'); -+------------------------------------------+ -| digest(Utf8("foo"), Utf8("sha256")) | -+------------------------------------------+ -| | -+------------------------------------------+ -``` - -### `md5` - -Computes an MD5 128-bit checksum for a string expression. - -``` -md5(expression) -``` - -#### Arguments - -- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> select md5('foo'); -+-------------------------------------+ -| md5(Utf8("foo")) | -+-------------------------------------+ -| | -+-------------------------------------+ -``` - -### `sha224` - -Computes the SHA-224 hash of a binary string. - -``` -sha224(expression) -``` - -#### Arguments - -- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> select sha224('foo'); -+------------------------------------------+ -| sha224(Utf8("foo")) | -+------------------------------------------+ -| | -+------------------------------------------+ -``` - -### `sha256` - -Computes the SHA-256 hash of a binary string. - -``` -sha256(expression) -``` - -#### Arguments - -- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> select sha256('foo'); -+--------------------------------------+ -| sha256(Utf8("foo")) | -+--------------------------------------+ -| | -+--------------------------------------+ -``` - -### `sha384` - -Computes the SHA-384 hash of a binary string. - -``` -sha384(expression) -``` - -#### Arguments - -- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> select sha384('foo'); -+-----------------------------------------+ -| sha384(Utf8("foo")) | -+-----------------------------------------+ -| | -+-----------------------------------------+ -``` - -### `sha512` - -Computes the SHA-512 hash of a binary string. - -``` -sha512(expression) -``` - -#### Arguments - -- **expression**: String - -#### Example - -```sql -> select sha512('foo'); -+-------------------------------------------+ -| sha512(Utf8("foo")) | -+-------------------------------------------+ -| | -+-------------------------------------------+ -``` - -## Other Functions - -- [arrow_cast](#arrow_cast) -- [arrow_typeof](#arrow_typeof) -- [get_field](#get_field) -- [version](#version) - -### `arrow_cast` - -Casts a value to a specific Arrow data type. - -``` -arrow_cast(expression, datatype) -``` - -#### Arguments - -- **expression**: Expression to cast. The expression can be a constant, column, or function, and any combination of operators. -- **datatype**: [Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) name to cast to, as a string. The format is the same as that returned by [`arrow_typeof`] - -#### Example - -```sql -> select arrow_cast(-5, 'Int8') as a, - arrow_cast('foo', 'Dictionary(Int32, Utf8)') as b, - arrow_cast('bar', 'LargeUtf8') as c, - arrow_cast('2023-01-02T12:53:02', 'Timestamp(Microsecond, Some("+08:00"))') as d - ; -+----+-----+-----+---------------------------+ -| a | b | c | d | -+----+-----+-----+---------------------------+ -| -5 | foo | bar | 2023-01-02T12:53:02+08:00 | -+----+-----+-----+---------------------------+ -``` - -### `arrow_typeof` - -Returns the name of the underlying [Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) of the expression. - -``` -arrow_typeof(expression) -``` - -#### Arguments - -- **expression**: Expression to evaluate. The expression can be a constant, column, or function, and any combination of operators. - -#### Example - -```sql -> select arrow_typeof('foo'), arrow_typeof(1); -+---------------------------+------------------------+ -| arrow_typeof(Utf8("foo")) | arrow_typeof(Int64(1)) | -+---------------------------+------------------------+ -| Utf8 | Int64 | -+---------------------------+------------------------+ -``` - -### `get_field` - -Returns a field within a map or a struct with the given key. -Note: most users invoke `get_field` indirectly via field access -syntax such as `my_struct_col['field_name']` which results in a call to -`get_field(my_struct_col, 'field_name')`. - -``` -get_field(expression1, expression2) -``` - -#### Arguments - -- **expression1**: The map or struct to retrieve a field for. -- **expression2**: The field name in the map or struct to retrieve data for. Must evaluate to a string. - -#### Example - -```sql -> create table t (idx varchar, v varchar) as values ('data','fusion'), ('apache', 'arrow'); -> select struct(idx, v) from t as c; -+-------------------------+ -| struct(c.idx,c.v) | -+-------------------------+ -| {c0: data, c1: fusion} | -| {c0: apache, c1: arrow} | -+-------------------------+ -> select get_field((select struct(idx, v) from t), 'c0'); -+-----------------------+ -| struct(t.idx,t.v)[c0] | -+-----------------------+ -| data | -| apache | -+-----------------------+ -> select get_field((select struct(idx, v) from t), 'c1'); -+-----------------------+ -| struct(t.idx,t.v)[c1] | -+-----------------------+ -| fusion | -| arrow | -+-----------------------+ -``` - -### `version` - -Returns the version of DataFusion. - -``` -version() -``` - -#### Example - -```sql -> select version(); -+--------------------------------------------+ -| version() | -+--------------------------------------------+ -| Apache DataFusion 42.0.0, aarch64 on macos | -+--------------------------------------------+ -``` From 9df766f090cf4ecf4011f38afb82f4d70d7d02eb Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Wed, 30 Oct 2024 08:02:01 +0100 Subject: [PATCH 110/110] fix: add missing `NotExpr::evaluate_bounds` (#13082) * fix: add missing `NotExpr::evaluate_bounds` * Add a test --------- Co-authored-by: Andrew Lamb --- .../physical-expr/src/expressions/not.rs | 55 ++++++++++++++++++- 1 file changed, 52 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index b69954e00bba..6d91e9dfdd36 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -27,6 +27,7 @@ use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::{cast::as_boolean_array, Result, ScalarValue}; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::ColumnarValue; /// Not expression @@ -100,6 +101,10 @@ impl PhysicalExpr for NotExpr { Ok(Arc::new(NotExpr::new(Arc::clone(&children[0])))) } + fn evaluate_bounds(&self, children: &[&Interval]) -> Result { + children[0].not() + } + fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = state; self.hash(&mut s); @@ -125,10 +130,11 @@ mod tests { use super::*; use crate::expressions::col; use arrow::{array::BooleanArray, datatypes::*}; + use std::sync::OnceLock; #[test] fn neg_op() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); + let schema = schema(); let expr = not(col("a", &schema)?)?; assert_eq!(expr.data_type(&schema)?, DataType::Boolean); @@ -137,8 +143,7 @@ mod tests { let input = BooleanArray::from(vec![Some(true), None, Some(false)]); let expected = &BooleanArray::from(vec![Some(false), None, Some(true)]); - let batch = - RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?; + let batch = RecordBatch::try_new(schema, vec![Arc::new(input)])?; let result = expr .evaluate(&batch)? @@ -150,4 +155,48 @@ mod tests { Ok(()) } + + #[test] + fn test_evaluate_bounds() -> Result<()> { + // Note that `None` for boolean intervals is converted to `Some(false)` + // / `Some(true)` by `Interval::make`, so it is not explicitly tested + // here + + // if the bounds are all booleans (false, true) so is the negation + assert_evaluate_bounds( + Interval::make(Some(false), Some(true))?, + Interval::make(Some(false), Some(true))?, + )?; + // (true, false) is not tested because it is not a valid interval (lower + // bound is greater than upper bound) + assert_evaluate_bounds( + Interval::make(Some(true), Some(true))?, + Interval::make(Some(false), Some(false))?, + )?; + assert_evaluate_bounds( + Interval::make(Some(false), Some(false))?, + Interval::make(Some(true), Some(true))?, + )?; + Ok(()) + } + + fn assert_evaluate_bounds( + interval: Interval, + expected_interval: Interval, + ) -> Result<()> { + let not_expr = not(col("a", &schema())?)?; + assert_eq!( + not_expr.evaluate_bounds(&[&interval]).unwrap(), + expected_interval + ); + Ok(()) + } + + fn schema() -> SchemaRef { + Arc::clone(SCHEMA.get_or_init(|| { + Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, true)])) + })) + } + + static SCHEMA: OnceLock = OnceLock::new(); }