diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs
index 2010a5c66412..9fcd0d90d08f 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -707,10 +707,6 @@ impl DefaultPhysicalPlanner {
                     physical_input_schema.clone(),
                 )?);
 
-                // update group column indices based on partial aggregate plan evaluation
-                let final_group: Vec<Arc<dyn PhysicalExpr>> =
-                    initial_aggr.output_group_expr();
-
                 let can_repartition = !groups.is_empty()
                     && session_state.config().target_partitions() > 1
                     && session_state.config().repartition_aggregations();
@@ -731,13 +727,7 @@ impl DefaultPhysicalPlanner {
                     AggregateMode::Final
                 };
 
-                let final_grouping_set = PhysicalGroupBy::new_single(
-                    final_group
-                        .iter()
-                        .enumerate()
-                        .map(|(i, expr)| (expr.clone(), groups.expr()[i].1.clone()))
-                        .collect(),
-                );
+                let final_grouping_set = initial_aggr.group_expr().as_final();
 
                 Arc::new(AggregateExec::try_new(
                     next_partition_mode,
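The removed block above rebuilt the Final stage's grouping by hand from `output_group_expr()` and the original expression names. With the internal grouping id introduced below, the Partial stage's output can carry one more column than the user wrote, so the re-mapping now lives in `PhysicalGroupBy::as_final()`, which knows about that extra column. A self-contained sketch of the re-indexing idea; `Col` and `reindex` are illustrative stand-ins, not DataFusion API:

// Illustrative stand-in for DataFusion's `Column` physical expression.
#[derive(Debug, PartialEq)]
struct Col {
    name: String,
    index: usize,
}

// Mirrors what the removed planner code did: refer to the Partial stage's
// output columns by position. That re-mapping breaks down once the Partial
// output also carries the internal grouping_id column, which only the
// PhysicalGroupBy itself knows about -- hence `as_final()`.
fn reindex(partial_output: &[&str]) -> Vec<Col> {
    partial_output
        .iter()
        .enumerate()
        .map(|(index, name)| Col { name: (*name).to_string(), index })
        .collect()
}

fn main() {
    // For GROUP BY CUBE(a, b), the Partial stage now emits a third column.
    let finals = reindex(&["a", "b", "grouping_id"]);
    assert_eq!(finals[2], Col { name: "grouping_id".into(), index: 2 });
}
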
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index c3bc7b042e65..063a6a8030b3 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -108,6 +108,8 @@ impl AggregateMode {
     }
 }
 
+const INTERNAL_GROUPING_ID: &str = "grouping_id";
+
 /// Represents `GROUP BY` clause in the plan (including the more general GROUPING SET)
 /// In the case of a simple `GROUP BY a, b` clause, this will contain the expression [a, b]
 /// and a single group [false, false].
@@ -137,6 +139,10 @@ pub struct PhysicalGroupBy {
     /// expression in null_expr. If `groups[i][j]` is true, then the
     /// j-th expression in the i-th group is NULL, otherwise it is `expr[j]`.
     groups: Vec<Vec<bool>>,
+    /// The number of expressions that are output by this `PhysicalGroupBy`.
+    /// Internal expressions, like the one used to implement grouping sets,
+    /// are not part of the output.
+    num_output_exprs: usize,
 }
 
 impl PhysicalGroupBy {
@@ -146,10 +152,12 @@ impl PhysicalGroupBy {
         null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
         groups: Vec<Vec<bool>>,
     ) -> Self {
+        let num_output_exprs = expr.len() + if !null_expr.is_empty() { 1 } else { 0 };
         Self {
             expr,
             null_expr,
             groups,
+            num_output_exprs,
         }
     }
@@ -161,6 +169,7 @@
             expr,
             null_expr: vec![],
             groups: vec![vec![false; num_exprs]],
+            num_output_exprs: num_exprs,
         }
     }
@@ -212,11 +221,78 @@ impl PhysicalGroupBy {
 
     /// Return grouping expressions as they occur in the output schema.
     pub fn output_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
-        self.expr
-            .iter()
-            .enumerate()
-            .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _)
-            .collect()
+        let mut output_exprs = Vec::with_capacity(self.num_output_exprs);
+        output_exprs.extend(
+            self.expr
+                .iter()
+                .take(self.num_output_exprs)
+                .enumerate()
+                .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _),
+        );
+        if !self.is_single() {
+            output_exprs
+                .push(Arc::new(Column::new(INTERNAL_GROUPING_ID, self.expr.len())) as _);
+        }
+        output_exprs
+    }
+
+    pub fn expr_count(&self) -> usize {
+        self.expr.len() + self.internal_expr_count()
+    }
+
+    pub fn group_fields(&self, input_schema: &Schema) -> Result<Vec<Field>> {
+        let mut fields = Vec::with_capacity(self.expr_count());
+        for ((expr, name), group_expr_nullable) in
+            self.expr.iter().zip(self.exprs_nullable().into_iter())
+        {
+            fields.push(Field::new(
+                name,
+                expr.data_type(input_schema)?,
+                group_expr_nullable || expr.nullable(input_schema)?,
+            ))
+        }
+        if !self.is_single() {
+            fields.push(Field::new(
+                INTERNAL_GROUPING_ID,
+                arrow::datatypes::DataType::UInt32,
+                false,
+            ));
+        }
+        Ok(fields)
+    }
+
+    fn output_fields(&self, input_schema: &Schema) -> Result<Vec<Field>> {
+        let mut fields = self.group_fields(input_schema)?;
+        fields.truncate(self.num_output_exprs);
+        Ok(fields)
+    }
+
+    fn internal_expr_count(&self) -> usize {
+        if self.is_single() {
+            0
+        } else {
+            1
+        }
+    }
+
+    pub fn as_final(&self) -> PhysicalGroupBy {
+        let expr: Vec<_> = self
+            .output_exprs()
+            .into_iter()
+            .zip(
+                self.expr
+                    .iter()
+                    .map(|t| t.1.clone())
+                    .chain(std::iter::once(INTERNAL_GROUPING_ID.to_owned())),
+            )
+            .collect();
+        let num_exprs = expr.len();
+        Self {
+            expr,
+            null_expr: vec![],
+            groups: vec![vec![false; num_exprs]],
+            num_output_exprs: num_exprs - self.internal_expr_count(),
+        }
+    }
 }
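The counting rules above are the crux of the change: `expr_count()` is how many columns the aggregation state keys on (user expressions plus the internal grouping id), while `num_output_exprs` is how many of them the operator exposes. The grouping id is part of the Partial stage's output, since the Final stage needs it to re-group, but is dropped from the Final stage's output, as `as_final()` computes. A self-contained model of that bookkeeping, with made-up names rather than the real type:

// Standalone model (not the DataFusion type) of the expression counting
// introduced by the patch.
struct GroupByCounts {
    user_exprs: usize,
    is_single: bool, // plain GROUP BY vs. grouping sets / CUBE / ROLLUP
}

impl GroupByCounts {
    fn internal_expr_count(&self) -> usize {
        if self.is_single { 0 } else { 1 } // the synthetic grouping_id column
    }
    fn expr_count(&self) -> usize {
        self.user_exprs + self.internal_expr_count()
    }
    fn partial_output_exprs(&self) -> usize {
        // grouping_id must flow to the Final stage, so the Partial stage
        // keeps it in its output.
        self.expr_count()
    }
    fn final_output_exprs(&self) -> usize {
        // The Final stage groups on grouping_id but truncates it away.
        self.user_exprs
    }
}

fn main() {
    let cube_ab = GroupByCounts { user_exprs: 2, is_single: false };
    assert_eq!(cube_ab.expr_count(), 3);           // a, b, grouping_id
    assert_eq!(cube_ab.partial_output_exprs(), 3); // a, b, grouping_id
    assert_eq!(cube_ab.final_output_exprs(), 2);   // a, b
}
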
@@ -320,13 +396,7 @@ impl AggregateExec {
         input: Arc<dyn ExecutionPlan>,
         input_schema: SchemaRef,
     ) -> Result<Self> {
-        let schema = create_schema(
-            &input.schema(),
-            &group_by.expr,
-            &aggr_expr,
-            group_by.exprs_nullable(),
-            mode,
-        )?;
+        let schema = create_schema(&input.schema(), &group_by, &aggr_expr, mode)?;
 
         let schema = Arc::new(schema);
         AggregateExec::try_new_with_schema(
@@ -786,22 +856,12 @@ impl ExecutionPlan for AggregateExec {
 
 fn create_schema(
     input_schema: &Schema,
-    group_expr: &[(Arc<dyn PhysicalExpr>, String)],
+    group_by: &PhysicalGroupBy,
     aggr_expr: &[AggregateFunctionExpr],
-    group_expr_nullable: Vec<bool>,
     mode: AggregateMode,
 ) -> Result<Schema> {
-    let mut fields = Vec::with_capacity(group_expr.len() + aggr_expr.len());
-    for (index, (expr, name)) in group_expr.iter().enumerate() {
-        fields.push(Field::new(
-            name,
-            expr.data_type(input_schema)?,
-            // In cases where we have multiple grouping sets, we will use NULL expressions in
-            // order to align the grouping sets. So the field must be nullable even if the underlying
-            // schema field is not.
-            group_expr_nullable[index] || expr.nullable(input_schema)?,
-        ))
-    }
+    let mut fields = Vec::with_capacity(group_by.num_output_exprs + aggr_expr.len());
+    fields.extend(group_by.output_fields(input_schema)?);
 
     match mode {
         AggregateMode::Partial => {
@@ -824,9 +884,8 @@ fn create_schema(
     Ok(Schema::new(fields))
 }
 
-fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef {
-    let group_fields = schema.fields()[0..group_count].to_vec();
-    Arc::new(Schema::new(group_fields))
+fn group_schema(input_schema: &Schema, group_by: &PhysicalGroupBy) -> Result<SchemaRef> {
+    Ok(Arc::new(Schema::new(group_by.group_fields(input_schema)?)))
 }
 
 /// Determines the lexical ordering requirement for an aggregate expression.
@@ -1133,6 +1192,19 @@ fn evaluate_optional(
         .collect()
 }
 
+fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result<ArrayRef> {
+    if group.len() > 32 {
+        return not_impl_err!("Grouping sets with more than 32 columns are not supported");
+    }
+    let group_id = group.iter().fold(0u32, |acc, &is_null| {
+        (acc << 1) | if is_null { 1 } else { 0 }
+    });
+    Ok(Arc::new(arrow::array::UInt32Array::from(vec![
+        group_id;
+        batch.num_rows()
+    ])))
+}
+
 /// Evaluate a group by expression against a `RecordBatch`
 ///
 /// Arguments:
@@ -1165,23 +1237,24 @@ pub(crate) fn evaluate_group_by(
         })
         .collect::<Result<Vec<_>>>()?;
 
-    Ok(group_by
+    group_by
         .groups
         .iter()
         .map(|group| {
-            group
-                .iter()
-                .enumerate()
-                .map(|(idx, is_null)| {
-                    if *is_null {
-                        Arc::clone(&null_exprs[idx])
-                    } else {
-                        Arc::clone(&exprs[idx])
-                    }
-                })
-                .collect()
+            let mut group_values = Vec::with_capacity(group_by.expr_count());
+            group_values.extend(group.iter().enumerate().map(|(idx, is_null)| {
+                if *is_null {
+                    Arc::clone(&null_exprs[idx])
+                } else {
+                    Arc::clone(&exprs[idx])
+                }
+            }));
+            if !group_by.is_single() {
+                group_values.push(group_id_array(group, batch)?);
+            }
+            Ok(group_values)
         })
-        .collect())
+        .collect()
 }
 
 #[cfg(test)]
@@ -2502,25 +2575,24 @@ mod tests {
             .build()?,
         ];
 
-        let grouping_set = PhysicalGroupBy {
-            expr: vec![
+        let grouping_set = PhysicalGroupBy::new(
+            vec![
                 (col("a", &input_schema)?, "a".to_string()),
                 (col("b", &input_schema)?, "b".to_string()),
             ],
-            null_expr: vec![
+            vec![
                 (lit(ScalarValue::Float32(None)), "a".to_string()),
                 (lit(ScalarValue::Float32(None)), "b".to_string()),
             ],
-            groups: vec![
+            vec![
                 vec![false, true],  // (a, NULL)
                 vec![false, false], // (a,b)
             ],
-        };
+        );
 
         let aggr_schema = create_schema(
             &input_schema,
-            &grouping_set.expr,
+            &grouping_set,
             &aggr_expr,
-            grouping_set.exprs_nullable(),
             AggregateMode::Final,
         )?;
         let expected_schema = Schema::new(vec![
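`group_id_array` encodes each grouping set as a bitmask over its columns: the fold shifts left as it walks the columns, setting a bit whenever that column is NULLed out, so the last grouping column lands in the least significant bit. This is also why grouping sets are capped at 32 columns for a `u32`. A runnable distillation of the encoding, using the same fold as the patch:

// Standalone copy of the bit encoding in group_id_array above.
fn group_id(group: &[bool]) -> u32 {
    group
        .iter()
        .fold(0u32, |acc, &is_null| (acc << 1) | u32::from(is_null))
}

fn main() {
    // GROUP BY CUBE(a, b) produces four grouping sets:
    assert_eq!(group_id(&[false, false]), 0b00); // (a, b)
    assert_eq!(group_id(&[false, true]), 0b01);  // (a, NULL)
    assert_eq!(group_id(&[true, false]), 0b10);  // (NULL, b)
    assert_eq!(group_id(&[true, true]), 0b11);   // (NULL, NULL)
}
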
diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs
index 60efc7711216..5d0f93458971 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -456,13 +456,13 @@ impl GroupedHashAggregateStream {
         let aggregate_arguments = aggregates::aggregate_expressions(
             &agg.aggr_expr,
             &agg.mode,
-            agg_group_by.expr.len(),
+            agg_group_by.expr_count(),
         )?;
         // arguments for aggregating spilled data is the same as the one for final aggregation
         let merging_aggregate_arguments = aggregates::aggregate_expressions(
             &agg.aggr_expr,
             &AggregateMode::Final,
-            agg_group_by.expr.len(),
+            agg_group_by.expr_count(),
         )?;
 
         let filter_expressions = match agg.mode {
@@ -480,7 +480,7 @@ impl GroupedHashAggregateStream {
             .map(create_group_accumulator)
             .collect::<Result<Vec<_>>>()?;
 
-        let group_schema = group_schema(&agg_schema, agg_group_by.expr.len());
+        let group_schema = group_schema(&agg.input().schema(), &agg_group_by)?;
         let spill_expr = group_schema
             .fields
             .into_iter()
@@ -851,6 +851,9 @@ impl GroupedHashAggregateStream {
         }
 
         let mut output = self.group_values.emit(emit_to)?;
+        if !spilling {
+            output.truncate(self.group_by.num_output_exprs);
+        }
         if let EmitTo::First(n) = emit_to {
             self.group_ordering.remove_groups(n);
         }
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 576abe5c6f5a..c146d730806f 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -3512,6 +3512,18 @@ SELECT MIN(value), MAX(value) FROM integers_with_nulls
 ----
 1 5
 
+# grouping_sets with null values
+query II rowsort
+SELECT value, min(value) FROM integers_with_nulls GROUP BY CUBE(value)
+----
+1 1
+3 3
+4 4
+5 5
+NULL 1
+NULL NULL
+
+
 statement ok
 DROP TABLE integers_with_nulls;
 
@@ -4876,7 +4888,7 @@ logical_plan
 03)----TableScan: aggregate_test_100 projection=[c2, c3]
 physical_plan
 01)GlobalLimitExec: skip=0, fetch=3
-02)--AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[3]
+02)--AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3, grouping_id@2 as grouping_id], aggr=[], lim=[3]
03)----CoalescePartitionsExec
 04)------AggregateExec: mode=Partial, gby=[(NULL as c2, NULL as c3), (c2@0 as c2, NULL as c3), (c2@0 as c2, c3@1 as c3)], aggr=[]
 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
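The new `CUBE(value)` test exercises exactly the case this patch fixes: `NULL 1` is the cube's grand-total row (where `value` is NULLed out by the grouping set, so `min(value)` over all rows is 1), while `NULL NULL` is the group of rows whose `value` is genuinely NULL. A minimal sketch of why the internal grouping id keeps those two keys distinct in the hash table before `row_hash.rs` truncates it away from the final output:

fn main() {
    // (value, grouping_id) pairs as the aggregation state sees them; `None`
    // models SQL NULL. A set grouping_id bit means `value` was NULLed out by
    // the grouping set, not NULL in the data.
    let keys: &[(Option<i32>, u32)] = &[
        (Some(1), 0b0),
        (None, 0b0), // NULL that is really present in the data
        (None, 0b1), // the CUBE's grand-total row: value replaced by NULL
    ];
    // The two NULL rows stay distinct because their grouping ids differ.
    assert_ne!(keys[1], keys[2]);
}
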