Skip to content

Commit

Permalink
Add field trait method to WindowUDFImpl, remove return_type/`nu…
Browse files Browse the repository at this point in the history
…llable` (#12374)

* Adds new library `functions-window-common`

* Adds `FieldArgs` struct for field of final result

* Adds `field` method to `WindowUDFImpl` trait

* Minor: fixes formatting

* Fixes: udwf doc test

* Fixes: implements missing trait items

* Updates `datafusion-cli` dependencies

* Fixes: formatting of `Cargo.toml` files

* Fixes: implementation of `field` in udwf example

* Pass `FieldArgs` argument to `field`

* Use `field` in place of `return_type` for udwf

* Update `field` in udwf implementations

* Fixes: implementation of `field` in udwf example

* Revert unrelated change

* Mark `return_type` for udwf as unreachable

* Delete code

* Uses schema name of udwf to construct `FieldArgs`

* Adds deprecated notice to `return_type` trait method

* Add doc comments to `field` trait method

* Reify `input_types` when creating the udwf window expression

* Rename name field to `schema_name` in `FieldArgs`

* Make `FieldArgs` opaque

* Minor refactor

* Removes `nullable` trait method from `WindowUDFImpl`

* Add doc comments

* Rename to `WindowUDFResultArgs`

* Minor: fixes formatting

* Copy edits for doc comments

* Renames field to `function_name`

* Rename struct to `WindowUDFFieldArgs`

* Add comments for unreachable code

* Copy edit for `WindowUDFImpl::field` trait method

* Renames module

* Fix warning: unused doc comment

* Minor: rename bindings

* Minor refactor

* Minor: copy edit

* Fixes: use `Expr::qualified_name` for window function name

* Fixes: apply previous fix to `Expr::nullable`

* Refactor: reuse type coercion for window functions

* Fixes: clippy errors

* Adds name parameter to `WindowFunctionDefinition::return_type`

* Removes `return_type` field from `SimpleWindowUDF`

* Add doc comment for helper method

* Rewrite doc comments

* Minor: remove empty comment

* Remove `WindowUDFImpl::return_type`

* Fixes doc test
  • Loading branch information
jcsherin committed Sep 21, 2024
1 parent d9cb6e6 commit e1b992a
Show file tree
Hide file tree
Showing 24 changed files with 357 additions and 168 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ members = [
"datafusion/functions-aggregate-common",
"datafusion/functions-nested",
"datafusion/functions-window",
"datafusion/functions-window-common",
"datafusion/optimizer",
"datafusion/physical-expr",
"datafusion/physical-expr-common",
Expand Down Expand Up @@ -103,6 +104,7 @@ datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", vers
datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "42.0.0" }
datafusion-functions-nested = { path = "datafusion/functions-nested", version = "42.0.0" }
datafusion-functions-window = { path = "datafusion/functions-window", version = "42.0.0" }
datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "42.0.0" }
datafusion-optimizer = { path = "datafusion/optimizer", version = "42.0.0", default-features = false }
datafusion-physical-expr = { path = "datafusion/physical-expr", version = "42.0.0", default-features = false }
datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "42.0.0", default-features = false }
Expand Down
10 changes: 10 additions & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 6 additions & 5 deletions datafusion-examples/examples/advanced_udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ use arrow::{
array::{ArrayRef, AsArray, Float64Array},
datatypes::Float64Type,
};
use arrow_schema::Field;
use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::ScalarValue;
use datafusion_expr::function::WindowUDFFieldArgs;
use datafusion_expr::{
PartitionEvaluator, Signature, WindowFrame, WindowUDF, WindowUDFImpl,
};
Expand Down Expand Up @@ -70,16 +72,15 @@ impl WindowUDFImpl for SmoothItUdf {
&self.signature
}

/// What is the type of value that will be returned by this function.
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}

/// Create a `PartitionEvaluator` to evaluate this function on a new
/// partition.
fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
Ok(Box::new(MyPartitionEvaluator::new()))
}

fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
Ok(Field::new(field_args.name(), DataType::Float64, true))
}
}

/// This implements the lowest level evaluation for a window function
Expand Down
12 changes: 6 additions & 6 deletions datafusion-examples/examples/simplify_udwf_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@

use std::any::Any;

use arrow_schema::DataType;
use arrow_schema::{DataType, Field};

use datafusion::execution::context::SessionContext;
use datafusion::functions_aggregate::average::avg_udaf;
use datafusion::{error::Result, execution::options::CsvReadOptions};
use datafusion_expr::function::WindowFunctionSimplification;
use datafusion_expr::function::{WindowFunctionSimplification, WindowUDFFieldArgs};
use datafusion_expr::{
expr::WindowFunction, simplify::SimplifyInfo, Expr, PartitionEvaluator, Signature,
Volatility, WindowUDF, WindowUDFImpl,
Expand Down Expand Up @@ -60,10 +60,6 @@ impl WindowUDFImpl for SimplifySmoothItUdf {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}

fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
todo!()
}
Expand All @@ -84,6 +80,10 @@ impl WindowUDFImpl for SimplifySmoothItUdf {

Some(Box::new(simplify))
}

fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
Ok(Field::new(field_args.name(), DataType::Float64, true))
}
}

// create local execution context with `cars.csv` registered as a table named `cars`
Expand Down
1 change: 1 addition & 0 deletions datafusion/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ bigdecimal = { workspace = true }
criterion = { version = "0.5", features = ["async_tokio"] }
csv = "1.1.6"
ctor = { workspace = true }
datafusion-functions-window-common = { workspace = true }
doc-comment = { workspace = true }
env_logger = { workspace = true }
half = { workspace = true, default-features = true }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@ use std::{

use arrow::array::AsArray;
use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray};
use arrow_schema::DataType;
use arrow_schema::{DataType, Field};
use datafusion::{assert_batches_eq, prelude::SessionContext};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{
PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl,
};
use datafusion_functions_window_common::field::WindowUDFFieldArgs;

/// A query with a window function evaluated over the entire partition
const UNBOUNDED_WINDOW_QUERY: &str = "SELECT x, y, val, \
Expand Down Expand Up @@ -522,7 +523,6 @@ impl OddCounter {
#[derive(Debug, Clone)]
struct SimpleWindowUDF {
signature: Signature,
return_type: DataType,
test_state: Arc<TestState>,
aliases: Vec<String>,
}
Expand All @@ -531,10 +531,8 @@ impl OddCounter {
fn new(test_state: Arc<TestState>) -> Self {
let signature =
Signature::exact(vec![DataType::Float64], Volatility::Immutable);
let return_type = DataType::Int64;
Self {
signature,
return_type,
test_state,
aliases: vec!["odd_counter_alias".to_string()],
}
Expand All @@ -554,17 +552,17 @@ impl OddCounter {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(self.return_type.clone())
}

fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
Ok(Box::new(OddCounter::new(Arc::clone(&self.test_state))))
}

fn aliases(&self) -> &[String] {
&self.aliases
}

fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
Ok(Field::new(field_args.name(), DataType::Int64, true))
}
}

ctx.register_udwf(WindowUDF::from(SimpleWindowUDF::new(test_state)))
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ chrono = { workspace = true }
datafusion-common = { workspace = true }
datafusion-expr-common = { workspace = true }
datafusion-functions-aggregate-common = { workspace = true }
datafusion-functions-window-common = { workspace = true }
datafusion-physical-expr-common = { workspace = true }
paste = "^1.0"
serde_json = { workspace = true }
Expand Down
32 changes: 18 additions & 14 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ use datafusion_common::tree_node::{
use datafusion_common::{
plan_err, Column, DFSchema, Result, ScalarValue, TableReference,
};
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
use sqlparser::ast::{
display_comma_separated, ExceptSelectItem, ExcludeSelectItem, IlikeSelectItem,
NullTreatment, RenameSelectItem, ReplaceSelectElement,
Expand Down Expand Up @@ -706,6 +707,7 @@ impl WindowFunctionDefinition {
&self,
input_expr_types: &[DataType],
_input_expr_nullable: &[bool],
display_name: &str,
) -> Result<DataType> {
match self {
WindowFunctionDefinition::BuiltInWindowFunction(fun) => {
Expand All @@ -714,7 +716,9 @@ impl WindowFunctionDefinition {
WindowFunctionDefinition::AggregateUDF(fun) => {
fun.return_type(input_expr_types)
}
WindowFunctionDefinition::WindowUDF(fun) => fun.return_type(input_expr_types),
WindowFunctionDefinition::WindowUDF(fun) => fun
.field(WindowUDFFieldArgs::new(input_expr_types, display_name))
.map(|field| field.data_type().clone()),
}
}

Expand Down Expand Up @@ -2536,10 +2540,10 @@ mod test {
#[test]
fn test_first_value_return_type() -> Result<()> {
let fun = find_df_window_func("first_value").unwrap();
let observed = fun.return_type(&[DataType::Utf8], &[true])?;
let observed = fun.return_type(&[DataType::Utf8], &[true], "")?;
assert_eq!(DataType::Utf8, observed);

let observed = fun.return_type(&[DataType::UInt64], &[true])?;
let observed = fun.return_type(&[DataType::UInt64], &[true], "")?;
assert_eq!(DataType::UInt64, observed);

Ok(())
Expand All @@ -2548,10 +2552,10 @@ mod test {
#[test]
fn test_last_value_return_type() -> Result<()> {
let fun = find_df_window_func("last_value").unwrap();
let observed = fun.return_type(&[DataType::Utf8], &[true])?;
let observed = fun.return_type(&[DataType::Utf8], &[true], "")?;
assert_eq!(DataType::Utf8, observed);

let observed = fun.return_type(&[DataType::Float64], &[true])?;
let observed = fun.return_type(&[DataType::Float64], &[true], "")?;
assert_eq!(DataType::Float64, observed);

Ok(())
Expand All @@ -2560,10 +2564,10 @@ mod test {
#[test]
fn test_lead_return_type() -> Result<()> {
let fun = find_df_window_func("lead").unwrap();
let observed = fun.return_type(&[DataType::Utf8], &[true])?;
let observed = fun.return_type(&[DataType::Utf8], &[true], "")?;
assert_eq!(DataType::Utf8, observed);

let observed = fun.return_type(&[DataType::Float64], &[true])?;
let observed = fun.return_type(&[DataType::Float64], &[true], "")?;
assert_eq!(DataType::Float64, observed);

Ok(())
Expand All @@ -2572,10 +2576,10 @@ mod test {
#[test]
fn test_lag_return_type() -> Result<()> {
let fun = find_df_window_func("lag").unwrap();
let observed = fun.return_type(&[DataType::Utf8], &[true])?;
let observed = fun.return_type(&[DataType::Utf8], &[true], "")?;
assert_eq!(DataType::Utf8, observed);

let observed = fun.return_type(&[DataType::Float64], &[true])?;
let observed = fun.return_type(&[DataType::Float64], &[true], "")?;
assert_eq!(DataType::Float64, observed);

Ok(())
Expand All @@ -2585,11 +2589,11 @@ mod test {
fn test_nth_value_return_type() -> Result<()> {
let fun = find_df_window_func("nth_value").unwrap();
let observed =
fun.return_type(&[DataType::Utf8, DataType::UInt64], &[true, true])?;
fun.return_type(&[DataType::Utf8, DataType::UInt64], &[true, true], "")?;
assert_eq!(DataType::Utf8, observed);

let observed =
fun.return_type(&[DataType::Float64, DataType::UInt64], &[true, true])?;
fun.return_type(&[DataType::Float64, DataType::UInt64], &[true, true], "")?;
assert_eq!(DataType::Float64, observed);

Ok(())
Expand All @@ -2598,7 +2602,7 @@ mod test {
#[test]
fn test_percent_rank_return_type() -> Result<()> {
let fun = find_df_window_func("percent_rank").unwrap();
let observed = fun.return_type(&[], &[])?;
let observed = fun.return_type(&[], &[], "")?;
assert_eq!(DataType::Float64, observed);

Ok(())
Expand All @@ -2607,7 +2611,7 @@ mod test {
#[test]
fn test_cume_dist_return_type() -> Result<()> {
let fun = find_df_window_func("cume_dist").unwrap();
let observed = fun.return_type(&[], &[])?;
let observed = fun.return_type(&[], &[], "")?;
assert_eq!(DataType::Float64, observed);

Ok(())
Expand All @@ -2616,7 +2620,7 @@ mod test {
#[test]
fn test_ntile_return_type() -> Result<()> {
let fun = find_df_window_func("ntile").unwrap();
let observed = fun.return_type(&[DataType::Int16], &[true])?;
let observed = fun.return_type(&[DataType::Int16], &[true], "")?;
assert_eq!(DataType::UInt64, observed);

Ok(())
Expand Down
13 changes: 9 additions & 4 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ use arrow::compute::kernels::cast_utils::{
};
use arrow::datatypes::{DataType, Field};
use datafusion_common::{plan_err, Column, Result, ScalarValue, TableReference};
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
use sqlparser::ast::NullTreatment;
use std::any::Any;
use std::fmt::Debug;
Expand Down Expand Up @@ -657,13 +658,17 @@ impl WindowUDFImpl for SimpleWindowUDF {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(self.return_type.clone())
}

fn partition_evaluator(&self) -> Result<Box<dyn crate::PartitionEvaluator>> {
(self.partition_evaluator_factory)()
}

fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
Ok(Field::new(
field_args.name(),
self.return_type.clone(),
true,
))
}
}

pub fn interval_year_month_lit(value: &str) -> Expr {
Expand Down
Loading

0 comments on commit e1b992a

Please sign in to comment.