Skip to content

Commit

Permalink
Implement native StringView support for REPEAT
Browse files Browse the repository at this point in the history
Signed-off-by: Tai Le Manh <[email protected]>
  • Loading branch information
tlm365 committed Aug 13, 2024
1 parent 18193e6 commit 2cdf18b
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 14 deletions.
82 changes: 70 additions & 12 deletions datafusion/functions/src/string/repeat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
use std::any::Any;
use std::sync::Arc;

use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, StringArray};
use arrow::datatypes::DataType;

use datafusion_common::cast::{as_generic_string_array, as_int64_array};
use datafusion_common::cast::{as_generic_string_array, as_int64_array, as_string_view_array};
use datafusion_common::{exec_err, Result};
use datafusion_expr::TypeSignature::*;
use datafusion_expr::{ColumnarValue, Volatility};
Expand All @@ -45,7 +45,14 @@ impl RepeatFunc {
use DataType::*;
Self {
signature: Signature::one_of(
vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])],
vec![
// Planner attempts coercion to the target type starting with the most preferred candidate.
// For example, given input `(Utf8View, Int64)`, it first tries coercing to `(Utf8View, Int64)`.
// If that fails, it proceeds to `(Utf8, Int64)`.
Exact(vec![Utf8View, Int64]),
Exact(vec![Utf8, Int64]),
Exact(vec![LargeUtf8, Int64]),
],
Volatility::Immutable,
),
}
Expand All @@ -71,9 +78,10 @@ impl ScalarUDFImpl for RepeatFunc {

/// Selects the `repeat` kernel based on the data type of the first argument.
///
/// `Utf8View` is routed to `repeat_utf8view`, `Utf8` to `repeat::<i32>` and
/// `LargeUtf8` to `repeat::<i64>`; each kernel is wrapped with
/// `make_scalar_function` and then applied to `args`. Any other data type
/// produces an error through `exec_err!`.
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
// Only the first argument's type drives the dispatch.
match args[0].data_type() {
DataType::Utf8View => make_scalar_function(repeat_utf8view, vec![])(args),
DataType::Utf8 => make_scalar_function(repeat::<i32>, vec![])(args),
DataType::LargeUtf8 => make_scalar_function(repeat::<i64>, vec![])(args),
// NOTE(review): the two `other` arms below look like the before/after versions
// of a single line in this diff view — confirm that only the longer message
// (the one listing the expected types) remains in the actual file.
other => exec_err!("Unsupported data type {other:?} for function repeat"),
other => exec_err!("Unsupported data type {other:?} for function repeat. Expected Utf8, Utf8View or LargeUtf8"),
}
}
}
Expand All @@ -87,18 +95,35 @@ fn repeat<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let result = string_array
.iter()
.zip(number_array.iter())
.map(|(string, number)| match (string, number) {
(Some(string), Some(number)) if number >= 0 => {
Some(string.repeat(number as usize))
}
(Some(_), Some(_)) => Some("".to_string()),
_ => None,
})
.map(|(string, number)| repeat_common(string, number))
.collect::<GenericStringArray<T>>();

Ok(Arc::new(result) as ArrayRef)
}

/// Repeats every string of a string-view array by the matching `Int64` count.
///
/// `args[0]` is read through `as_string_view_array` and `args[1]` through
/// `as_int64_array`. The two arrays are walked in lockstep, each pair is
/// handed to `repeat_common`, and the per-row results are collected into a
/// `StringArray`.
///
/// # Errors
/// Returns the error produced by `as_string_view_array` or `as_int64_array`
/// when the corresponding argument cannot be read as that array type.
fn repeat_utf8view(args: &[ArrayRef]) -> Result<ArrayRef> {
    let strings = as_string_view_array(&args[0])?;
    let counts = as_int64_array(&args[1])?;

    let repeated: StringArray = strings
        .iter()
        .zip(counts.iter())
        .map(|(value, count)| repeat_common(value, count))
        .collect();

    Ok(Arc::new(repeated) as ArrayRef)
}

/// Computes the `repeat` result for a single row.
///
/// Returns `None` when either `string` or `number` is `None`. Otherwise
/// returns `string` concatenated with itself `number` times; a negative
/// `number` yields an empty string.
fn repeat_common(string: Option<&str>, number: Option<i64>) -> Option<String> {
    // Both inputs must be present; a missing value on either side gives a null row.
    let (text, count) = string.zip(number)?;

    if count < 0 {
        // Negative counts are treated as "repeat zero times".
        return Some(String::new());
    }

    Some(text.repeat(count as usize))
}

#[cfg(test)]
mod tests {
use arrow::array::{Array, StringArray};
Expand All @@ -124,7 +149,6 @@ mod tests {
Utf8,
StringArray
);

test_function!(
RepeatFunc::new(),
&[
Expand All @@ -148,6 +172,40 @@ mod tests {
StringArray
);

test_function!(
RepeatFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))),
ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
],
Ok(Some("PgPgPgPg")),
&str,
Utf8,
StringArray
);
test_function!(
RepeatFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
],
Ok(None),
&str,
Utf8,
StringArray
);
test_function!(
RepeatFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))),
ColumnarValue::Scalar(ScalarValue::Int64(None)),
],
Ok(None),
&str,
Utf8,
StringArray
);

Ok(())
}
}
3 changes: 1 addition & 2 deletions datafusion/sqllogictest/test_files/string_view.slt
Original file line number Diff line number Diff line change
Expand Up @@ -761,14 +761,13 @@ logical_plan


## Ensure no casts for REPEAT
## TODO file ticket
query TT
EXPLAIN SELECT
REPEAT(column1_utf8view, 2) as c1
FROM test;
----
logical_plan
01)Projection: repeat(CAST(test.column1_utf8view AS Utf8), Int64(2)) AS c1
01)Projection: repeat(test.column1_utf8view, Int64(2)) AS c1
02)--TableScan: test projection=[column1_utf8view]

## Ensure no casts for REPLACE
Expand Down

0 comments on commit 2cdf18b

Please sign in to comment.