Skip to content

Commit

Permalink
fix: serialize user-defined window functions to proto (apache#13421)
Browse files Browse the repository at this point in the history
* Adds roundtrip physical plan test

* Adds enum for udwf to `WindowFunction`

* initial fix for serializing udwf

* Revives deleted test

* Adds codec methods for physical plan

* Rewrite error message

* Minor: rename binding + formatting fixes

* Extends `PhysicalExtensionCodec` for udwf

* Minor: formatting

* Restricts visibility to tests
  • Loading branch information
jcsherin authored and mwylde committed Nov 22, 2024
1 parent ea2113a commit 319f59b
Show file tree
Hide file tree
Showing 9 changed files with 272 additions and 15 deletions.
8 changes: 7 additions & 1 deletion datafusion/physical-plan/src/windows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ pub fn create_udwf_window_expr(

/// Implements [`BuiltInWindowFunctionExpr`] for [`WindowUDF`]
#[derive(Clone, Debug)]
struct WindowUDFExpr {
pub struct WindowUDFExpr {
fun: Arc<WindowUDF>,
args: Vec<Arc<dyn PhysicalExpr>>,
/// Display name
Expand All @@ -213,6 +213,12 @@ struct WindowUDFExpr {
ignore_nulls: bool,
}

impl WindowUDFExpr {
    /// Returns a shared reference to the [`WindowUDF`] stored in this expression.
    ///
    /// The `Arc` itself is not cloned; callers that need ownership can clone it.
    pub fn fun(&self) -> &Arc<WindowUDF> {
        &self.fun
    }
}

impl BuiltInWindowFunctionExpr for WindowUDFExpr {
fn as_any(&self) -> &dyn std::any::Any {
self
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,7 @@ message PhysicalWindowExprNode {
oneof window_function {
BuiltInWindowFunction built_in_function = 2;
string user_defined_aggr_function = 3;
string user_defined_window_function = 10;
}
repeated PhysicalExprNode args = 4;
repeated PhysicalExprNode partition_by = 5;
Expand Down
13 changes: 13 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions datafusion/proto/src/physical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,12 @@ pub fn parse_physical_window_expr(
None => registry.udaf(udaf_name)?
})
}
protobuf::physical_window_expr_node::WindowFunction::UserDefinedWindowFunction(udwf_name) => {
WindowFunctionDefinition::WindowUDF(match &proto.fun_definition {
Some(buf) => codec.try_decode_udwf(udwf_name, buf)?,
None => registry.udwf(udwf_name)?
})
}
}
} else {
return Err(proto_error("Missing required field in protobuf"));
Expand Down
10 changes: 9 additions & 1 deletion datafusion/proto/src/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ use datafusion::physical_plan::{
ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr,
};
use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result};
use datafusion_expr::{AggregateUDF, ScalarUDF};
use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF};

use crate::common::{byte_to_string, str_to_byte};
use crate::physical_plan::from_proto::{
Expand Down Expand Up @@ -2126,6 +2126,14 @@ pub trait PhysicalExtensionCodec: Debug + Send + Sync {
/// Encodes the aggregate UDF `_node` into `_buf`.
///
/// This default implementation ignores both arguments, leaves the buffer
/// untouched and returns `Ok(())`.
fn try_encode_udaf(&self, _node: &AggregateUDF, _buf: &mut Vec<u8>) -> Result<()> {
    Ok(())
}

/// Decodes the user-defined window function called `name` from `_buf`.
///
/// This default implementation never inspects the buffer: it always returns
/// the error produced by `not_impl_err!`, whose message includes `name`.
fn try_decode_udwf(&self, name: &str, _buf: &[u8]) -> Result<Arc<WindowUDF>> {
    not_impl_err!("PhysicalExtensionCodec is not provided for window function {name}")
}

/// Encodes the window UDF `_node` into `_buf`.
///
/// This default implementation ignores both arguments, leaves the buffer
/// untouched and returns `Ok(())`.
fn try_encode_udwf(&self, _node: &WindowUDF, _buf: &mut Vec<u8>) -> Result<()> {
    Ok(())
}
}

#[derive(Debug)]
Expand Down
25 changes: 22 additions & 3 deletions datafusion/proto/src/physical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ use std::sync::Arc;

#[cfg(feature = "parquet")]
use datafusion::datasource::file_format::parquet::ParquetSink;
use datafusion::physical_expr::window::SlidingAggregateWindowExpr;
use datafusion::physical_expr::window::{BuiltInWindowExpr, SlidingAggregateWindowExpr};
use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr};
use datafusion::physical_plan::expressions::{
BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr,
Literal, NegativeExpr, NotExpr, TryCastExpr,
};
use datafusion::physical_plan::udaf::AggregateFunctionExpr;
use datafusion::physical_plan::windows::PlainAggregateWindowExpr;
use datafusion::physical_plan::windows::{PlainAggregateWindowExpr, WindowUDFExpr};
use datafusion::physical_plan::{Partitioning, PhysicalExpr, WindowExpr};
use datafusion::{
datasource::{
Expand Down Expand Up @@ -69,7 +69,7 @@ pub fn serialize_physical_aggr_expr(
ordering_req,
distinct: aggr_expr.is_distinct(),
ignore_nulls: aggr_expr.ignore_nulls(),
fun_definition: (!buf.is_empty()).then_some(buf)
fun_definition: (!buf.is_empty()).then_some(buf),
},
)),
})
Expand Down Expand Up @@ -121,6 +121,25 @@ pub fn serialize_physical_window_expr(
window_frame,
codec,
)?
} else if let Some(built_in_window_expr) = expr.downcast_ref::<BuiltInWindowExpr>() {
if let Some(expr) = built_in_window_expr
.get_built_in_func_expr()
.as_any()
.downcast_ref::<WindowUDFExpr>()
{
let mut buf = Vec::new();
codec.try_encode_udwf(expr.fun(), &mut buf)?;
(
physical_window_expr_node::WindowFunction::UserDefinedWindowFunction(
expr.fun().name().to_string(),
),
(!buf.is_empty()).then_some(buf),
)
} else {
return not_impl_err!(
"User-defined window function not supported: {window_expr:?}"
);
}
} else {
return not_impl_err!("WindowExpr not supported: {window_expr:?}");
};
Expand Down
60 changes: 57 additions & 3 deletions datafusion/proto/tests/cases/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::{DataType, Field};
use std::any::Any;

use arrow::datatypes::DataType;
use std::fmt::Debug;

use datafusion_common::plan_err;
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::{
Accumulator, AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, Signature, Volatility,
Accumulator, AggregateUDFImpl, ColumnarValue, PartitionEvaluator, ScalarUDFImpl,
Signature, Volatility, WindowUDFImpl,
};
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;

mod roundtrip_logical_plan;
mod roundtrip_physical_plan;
Expand Down Expand Up @@ -125,3 +128,54 @@ pub struct MyAggregateUdfNode {
#[prost(string, tag = "1")]
pub result: String,
}

/// Window UDF fixture for the test cases; visible only inside `crate::cases`.
#[derive(Debug)]
pub(in crate::cases) struct CustomUDWF {
    /// Signature handed back by `WindowUDFImpl::signature`.
    signature: Signature,
    /// String supplied to `CustomUDWF::new`.
    payload: String,
}

impl CustomUDWF {
    /// Builds a `CustomUDWF` that carries `payload`.
    ///
    /// The signature is fixed: exactly one `Int64` argument, with
    /// `Volatility::Immutable`.
    pub fn new(payload: String) -> Self {
        let signature = Signature::exact(vec![DataType::Int64], Volatility::Immutable);
        Self { signature, payload }
    }
}

impl WindowUDFImpl for CustomUDWF {
    /// Exposes `self` as `Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Always returns the fixed name `"custom_udwf"`.
    fn name(&self) -> &str {
        "custom_udwf"
    }

    /// Returns the signature built in `CustomUDWF::new`.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Creates a new boxed `CustomUDWFEvaluator`; the arguments are ignored.
    fn partition_evaluator(
        &self,
        _partition_evaluator_args: PartitionEvaluatorArgs,
    ) -> datafusion_common::Result<Box<dyn PartitionEvaluator>> {
        let evaluator: Box<dyn PartitionEvaluator> = Box::new(CustomUDWFEvaluator {});
        Ok(evaluator)
    }

    /// Describes the output column: a `UInt64` field named after
    /// `field_args.name()`, constructed with `false` as the nullability flag.
    fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result<Field> {
        let output = Field::new(field_args.name(), DataType::UInt64, false);
        Ok(output)
    }
}

/// Partition evaluator handed out by `CustomUDWF::partition_evaluator`.
#[derive(Debug)]
struct CustomUDWFEvaluator;

// Nothing is overridden here: every `PartitionEvaluator` method falls back to
// the trait's default implementation.
impl PartitionEvaluator for CustomUDWFEvaluator {}

/// Protobuf message with a single string field, `payload`, at tag 1.
///
/// Visible only inside `crate::cases`.
#[derive(Clone, PartialEq, ::prost::Message)]
pub(in crate::cases) struct CustomUDWFNode {
    /// String payload carried by the message.
    #[prost(string, tag = "1")]
    pub payload: String,
}
Loading

0 comments on commit 319f59b

Please sign in to comment.