diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 9e1e6b81b61df..4621a18864373 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -17,32 +17,14 @@ use crate::utils::make_scalar_function; use arrow::array::Int32Array; -use arrow::array::{ArrayRef, OffsetSizeTrait}; +use arrow::array::{ArrayRef, AsArray}; use arrow::datatypes::DataType; -use datafusion_common::{cast::as_generic_string_array, internal_err, Result}; +use datafusion_common::Result; use datafusion_expr::ColumnarValue; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; -/// Returns the numeric code of the first character of the argument. -/// ascii('x') = 120 -pub fn ascii(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - - let result = string_array - .iter() - .map(|string| { - string.map(|string: &str| { - let mut chars = string.chars(); - chars.next().map_or(0, |v| v as i32) - }) - }) - .collect::(); - - Ok(Arc::new(result) as ArrayRef) -} - #[derive(Debug)] pub struct AsciiFunc { signature: Signature, @@ -60,7 +42,7 @@ impl AsciiFunc { Self { signature: Signature::uniform( 1, - vec![Utf8, LargeUtf8], + vec![Utf8, LargeUtf8, Utf8View], Volatility::Immutable, ), } @@ -87,12 +69,92 @@ impl ScalarUDFImpl for AsciiFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - DataType::Utf8 => make_scalar_function(ascii::, vec![])(args), - DataType::LargeUtf8 => { - return make_scalar_function(ascii::, vec![])(args); - } - _ => internal_err!("Unsupported data type"), + make_scalar_function(ascii, vec![])(args) + } +} + +fn calculate_ascii<'a, I>(string_array: I) -> Result +where + I: IntoIterator>, +{ + let result = string_array + .into_iter() + .map(|string| { + string.map(|s| { + let mut chars = s.chars(); + chars.next().map_or(0, |v| v as i32) + }) + }) + .collect::(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Returns the numeric code of the first character of the argument. +pub fn ascii(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Utf8 => { + let string_array = args[0].as_string::(); + calculate_ascii(string_array.iter()) + } + DataType::LargeUtf8 => { + let string_array = args[0].as_string::(); + calculate_ascii(string_array.iter()) } + DataType::Utf8View => { + let string_array = args[0].as_string_view(); + calculate_ascii(string_array.iter()) + } + _ => unreachable!(), + } +} + +#[cfg(test)] +mod tests { + use crate::string::ascii::AsciiFunc; + use crate::utils::test::test_function; + use arrow::array::{Array, Int32Array}; + use arrow::datatypes::DataType::Int32; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + macro_rules! test_ascii { + ($INPUT:expr, $EXPECTED:expr) => { + test_function!( + AsciiFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + + test_function!( + AsciiFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + + test_function!( + AsciiFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + }; + } + + #[test] + fn test_functions() -> Result<()> { + test_ascii!(Some(String::from("x")), Ok(Some(120))); + test_ascii!(Some(String::from("a")), Ok(Some(97))); + test_ascii!(Some(String::from("")), Ok(Some(0))); + test_ascii!(None, Ok(None)); + Ok(()) } }