Skip to content

Commit

Permalink
Update concat_ws scalar function to support Utf8View (#12309)
Browse files Browse the repository at this point in the history
* Update concat_ws scalar function to support Utf8View

Signed-off-by: Devan <[email protected]>

* fmt

Signed-off-by: Devan <[email protected]>

* adds one_of for type sig

* revert type sig, the problem was outside the actual signature xD

* fmt and clippy

* add utf8view

* fix match

* fix match

* 🤔 why nulls why

* small fix

* log ms

* pushing up -- wip

* make it so the return type is just Utf8

Signed-off-by: Devan <[email protected]>

* fmt

Signed-off-by: Devan <[email protected]>

* fmt

Signed-off-by: Devan <[email protected]>

* fmt

Signed-off-by: Devan <[email protected]>

* fmt

Signed-off-by: Devan <[email protected]>

* fmt

Signed-off-by: Devan <[email protected]>

* order matters

Signed-off-by: Devan <[email protected]>

* sum inner buffer for stringviewarray data_size

* need turbo fish

---------

Signed-off-by: Devan <[email protected]>
  • Loading branch information
devanbenz authored Sep 12, 2024
1 parent b25aa33 commit 199d028
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 35 deletions.
10 changes: 3 additions & 7 deletions datafusion/functions/src/string/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,9 @@ impl ScalarUDFImpl for ConcatFunc {

for arg in args {
match arg {
ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => {
if let Some(s) = maybe_value {
data_size += s.len() * len;
columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
}
}
ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => {
ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value))
| ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => {
if let Some(s) = maybe_value {
data_size += s.len() * len;
columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
Expand Down
102 changes: 75 additions & 27 deletions datafusion/functions/src/string/concat_ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,22 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::StringArray;
use arrow::array::{as_largestring_array, Array, StringArray};
use std::any::Any;
use std::sync::Arc;

use arrow::datatypes::DataType;
use arrow::datatypes::DataType::Utf8;

use datafusion_common::cast::as_string_array;
use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
use crate::string::common::*;
use crate::string::concat::simplify_concat;
use crate::string::concat_ws;
use datafusion_common::cast::{as_string_array, as_string_view_array};
use datafusion_common::{exec_err, internal_err, plan_err, Result, ScalarValue};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{lit, ColumnarValue, Expr, Volatility};
use datafusion_expr::{ScalarUDFImpl, Signature};

use crate::string::common::*;
use crate::string::concat::simplify_concat;
use crate::string::concat_ws;

#[derive(Debug)]
pub struct ConcatWsFunc {
signature: Signature,
Expand All @@ -48,7 +46,10 @@ impl ConcatWsFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::variadic(vec![Utf8], Volatility::Immutable),
signature: Signature::variadic(
vec![Utf8View, Utf8, LargeUtf8],
Volatility::Immutable,
),
}
}
}
Expand All @@ -67,13 +68,14 @@ impl ScalarUDFImpl for ConcatWsFunc {
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
use DataType::*;
Ok(Utf8)
}

/// Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. Other NULL arguments are ignored.
/// concat_ws(',', 'abcde', 2, NULL, 22) = 'abcde,2,22'
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
// do not accept 0 or 1 arguments.
// do not accept 0 arguments.
if args.len() < 2 {
return exec_err!(
"concat_ws was called with {} arguments. It requires at least 2.",
Expand All @@ -92,8 +94,12 @@ impl ScalarUDFImpl for ConcatWsFunc {
// Scalar
if array_len.is_none() {
let sep = match &args[0] {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s,
ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
| ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s)))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => s,
ColumnarValue::Scalar(ScalarValue::Utf8(None))
| ColumnarValue::Scalar(ScalarValue::Utf8View(None))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {
return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
}
_ => unreachable!(),
Expand All @@ -104,22 +110,30 @@ impl ScalarUDFImpl for ConcatWsFunc {

for arg in iter.by_ref() {
match arg {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
| ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s)))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => {
result.push_str(s);
break;
}
ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {}
ColumnarValue::Scalar(ScalarValue::Utf8(None))
| ColumnarValue::Scalar(ScalarValue::Utf8View(None))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {}
_ => unreachable!(),
}
}

for arg in iter.by_ref() {
match arg {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
| ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s)))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => {
result.push_str(sep);
result.push_str(s);
}
ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {}
ColumnarValue::Scalar(ScalarValue::Utf8(None))
| ColumnarValue::Scalar(ScalarValue::Utf8View(None))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {}
_ => unreachable!(),
}
}
Expand Down Expand Up @@ -155,21 +169,53 @@ impl ScalarUDFImpl for ConcatWsFunc {
let mut columns = Vec::with_capacity(args.len() - 1);
for arg in &args[1..] {
match arg {
ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => {
ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value))
| ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => {
if let Some(s) = maybe_value {
data_size += s.len() * len;
columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
}
}
ColumnarValue::Array(array) => {
let string_array = as_string_array(array)?;
data_size += string_array.values().len();
let column = if array.is_nullable() {
ColumnarValueRef::NullableArray(string_array)
} else {
ColumnarValueRef::NonNullableArray(string_array)
match array.data_type() {
DataType::Utf8 => {
let string_array = as_string_array(array)?;

data_size += string_array.values().len();
let column = if array.is_nullable() {
ColumnarValueRef::NullableArray(string_array)
} else {
ColumnarValueRef::NonNullableArray(string_array)
};
columns.push(column);
},
DataType::LargeUtf8 => {
let string_array = as_largestring_array(array);

data_size += string_array.values().len();
let column = if array.is_nullable() {
ColumnarValueRef::NullableLargeStringArray(string_array)
} else {
ColumnarValueRef::NonNullableLargeStringArray(string_array)
};
columns.push(column);
},
DataType::Utf8View => {
let string_array = as_string_view_array(array)?;

data_size += string_array.data_buffers().iter().map(|buf| buf.len()).sum::<usize>();
let column = if array.is_nullable() {
ColumnarValueRef::NullableStringViewArray(string_array)
} else {
ColumnarValueRef::NonNullableStringViewArray(string_array)
};
columns.push(column);
},
other => {
return plan_err!("Input was {other} which is not a supported datatype for concat_ws function.")
}
};
columns.push(column);
}
_ => unreachable!(),
}
Expand Down Expand Up @@ -223,7 +269,9 @@ impl ScalarUDFImpl for ConcatWsFunc {
fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result<ExprSimplifyResult> {
match delimiter {
Expr::Literal(
ScalarValue::Utf8(delimiter) | ScalarValue::LargeUtf8(delimiter),
ScalarValue::Utf8(delimiter)
| ScalarValue::LargeUtf8(delimiter)
| ScalarValue::Utf8View(delimiter),
) => {
match delimiter {
// when the delimiter is an empty string,
Expand All @@ -236,8 +284,8 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result<ExprSimplifyRes
for arg in args {
match arg {
// filter out null args
Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None)) => {}
Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v))) => {
Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None)) => {}
Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v))) => {
match contiguous_scalar {
None => contiguous_scalar = Some(v.to_string()),
Some(mut pre) => {
Expand Down
79 changes: 78 additions & 1 deletion datafusion/sqllogictest/test_files/string_view.slt
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,7 @@ EXPLAIN SELECT
FROM test;
----
logical_plan
01)Projection: concat_ws(Utf8(", "), CAST(test.column1_utf8view AS Utf8), CAST(test.column2_utf8view AS Utf8)) AS c
01)Projection: concat_ws(Utf8(", "), test.column1_utf8view, test.column2_utf8view) AS c
02)--TableScan: test projection=[column1_utf8view, column2_utf8view]

## Ensure no casts for CONTAINS
Expand Down Expand Up @@ -1039,6 +1039,83 @@ XiangpengXiangpeng
RaphaelR
R

## Should run CONCAT successfully with utf8view
query T
SELECT
concat(column1_utf8view, column2_utf8view) as c
FROM test;
----
AndrewX
XiangpengXiangpeng
RaphaelR
R

## Should run CONCAT_WS successfully with utf8
query T
SELECT
concat_ws(',', column1_utf8, column2_utf8) as c
FROM test;
----
Andrew,X
Xiangpeng,Xiangpeng
Raphael,R
R

## Should run CONCAT_WS successfully with utf8view
query T
SELECT
concat_ws(',', column1_utf8view, column2_utf8view) as c
FROM test;
----
Andrew,X
Xiangpeng,Xiangpeng
Raphael,R
R

## Should run CONCAT_WS successfully with largeutf8
query T
SELECT
concat_ws(',', column1_large_utf8, column2_large_utf8) as c
FROM test;
----
Andrew,X
Xiangpeng,Xiangpeng
Raphael,R
R

## Should run CONCAT_WS successfully with utf8 and largeutf8
query T
SELECT
concat_ws(',', column1_utf8, column2_large_utf8) as c
FROM test;
----
Andrew,X
Xiangpeng,Xiangpeng
Raphael,R
R

## Should run CONCAT_WS successfully with utf8 and utf8view
query T
SELECT
concat_ws(',', column1_utf8view, column2_utf8) as c
FROM test;
----
Andrew,X
Xiangpeng,Xiangpeng
Raphael,R
R

## Should run CONCAT_WS successfully with largeutf8 and utf8view
query T
SELECT
concat_ws(',', column1_utf8view, column2_large_utf8) as c
FROM test;
----
Andrew,X
Xiangpeng,Xiangpeng
Raphael,R
R

## Ensure no casts for LPAD
query TT
EXPLAIN SELECT
Expand Down

0 comments on commit 199d028

Please sign in to comment.