From 13a83f7ae882f473070e05e3d11c9d648ff75bfb Mon Sep 17 00:00:00 2001
From: Wei Liu
Date: Thu, 19 Sep 2024 14:53:20 -0700
Subject: [PATCH] [SPARK-49722] Share dropDuplicates column resolution between
 SparkConnectPlanner and Dataset

---
 .../connect/planner/SparkConnectPlanner.scala | 35 +++++++++++++------
 .../scala/org/apache/spark/sql/Dataset.scala  |  2 ++
 2 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 33c9edb1cd21a..f96aa8420837b 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1194,6 +1194,23 @@ class SparkConnectPlanner(
     }
   }
 
+  private def groupColsFromDropDuplicates(
+      colNames: Seq[String], allColumns: Seq[Attribute]): Seq[Attribute] = {
+    val resolver = session.sessionState.analyzer.resolver
+    // [SPARK-31990][SPARK-49722]: We must keep `toSet.toSeq` here because of the backward
+    // compatibility issue (the Streaming's state store depends on the `groupCols` order).
+    // If you modify this function, please also modify the same function in `Dataset.scala` in SQL.
+    colNames.toSet.toSeq.flatMap { (colName: String) =>
+      // It is possible that there is more than one column with the same name,
+      // so we call filter instead of find.
+      val cols = allColumns.filter(col => resolver(col.name, colName))
+      if (cols.isEmpty) {
+        throw InvalidPlanInput(s"Invalid deduplicate column ${colName}")
+      }
+      cols
+    }
+  }
+
   private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
     if (!rel.hasInput) {
       throw InvalidPlanInput("Deduplicate needs a plan input")
@@ -1209,19 +1226,15 @@
     val resolver = session.sessionState.analyzer.resolver
     val allColumns = queryExecution.analyzed.output
     if (rel.getAllColumnsAsKeys) {
-      if (rel.getWithinWatermark) DeduplicateWithinWatermark(allColumns, queryExecution.analyzed)
-      else Deduplicate(allColumns, queryExecution.analyzed)
+      val groupCols = groupColsFromDropDuplicates(allColumns.map(_.name), allColumns)
+      if (rel.getWithinWatermark) {
+        DeduplicateWithinWatermark(groupCols, queryExecution.analyzed)
+      } else {
+        Deduplicate(groupCols, queryExecution.analyzed)
+      }
     } else {
       val toGroupColumnNames = rel.getColumnNamesList.asScala.toSeq
-      val groupCols = toGroupColumnNames.flatMap { (colName: String) =>
-        // It is possibly there are more than one columns with the same name,
-        // so we call filter instead of find.
-        val cols = allColumns.filter(col => resolver(col.name, colName))
-        if (cols.isEmpty) {
-          throw InvalidPlanInput(s"Invalid deduplicate column ${colName}")
-        }
-        cols
-      }
+      val groupCols = groupColsFromDropDuplicates(toGroupColumnNames, allColumns)
       if (rel.getWithinWatermark) DeduplicateWithinWatermark(groupCols, queryExecution.analyzed)
       else Deduplicate(groupCols, queryExecution.analyzed)
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index ef628ca612b49..0eaa79a9d5ea0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1304,6 +1304,8 @@ class Dataset[T] private[sql](
     val allColumns = queryExecution.analyzed.output
     // SPARK-31990: We must keep `toSet.toSeq` here because of the backward compatibility issue
     // (the Streaming's state store depends on the `groupCols` order).
+    // SPARK-49722: If you modify this function, please also modify the same function in
+    // `SparkConnectPlanner.scala` in connect.
    colNames.toSet.toSeq.flatMap { (colName: String) =>
       // It is possibly there are more than one columns with the same name,
       // so we call filter instead of find.
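
For readers who want to see the resolution logic in isolation, below is a minimal, self-contained Scala sketch of what groupColsFromDropDuplicates does. Attribute, resolver, and the exception used here are simplified stand-ins for the Catalyst types, not the real Spark API; the resolver is modeled as case-insensitive string equality, which matches Spark's default configuration.

// A standalone sketch (assumed names, not the real Catalyst types) of the
// column resolution performed by groupColsFromDropDuplicates above.
case class Attribute(name: String)

object DropDuplicatesResolutionSketch {
  // Spark's default analyzer resolver is case-insensitive; we model it as such here.
  val resolver: (String, String) => Boolean = (a, b) => a.equalsIgnoreCase(b)

  def groupColsFromDropDuplicates(
      colNames: Seq[String], allColumns: Seq[Attribute]): Seq[Attribute] = {
    // toSet.toSeq drops duplicate requests; the resulting order is deterministic
    // for a given input, which the streaming state store depends on (SPARK-31990).
    colNames.toSet.toSeq.flatMap { (colName: String) =>
      // Several output attributes can share one name (e.g. after a join),
      // so filter keeps every match rather than stopping at the first.
      val cols = allColumns.filter(col => resolver(col.name, colName))
      if (cols.isEmpty) {
        throw new IllegalArgumentException(s"Invalid deduplicate column $colName")
      }
      cols
    }
  }

  def main(args: Array[String]): Unit = {
    val output = Seq(Attribute("id"), Attribute("value"), Attribute("ID"))
    // "id" resolves to both id and ID; the duplicate request collapses via toSet.
    println(groupColsFromDropDuplicates(Seq("id", "value", "id"), output))
  }
}

The filter-instead-of-find choice is deliberate: find would keep only the first attribute when several share a name, whereas filter keeps every match and fails only when a requested name resolves to nothing. Likewise, toSet.toSeq must stay exactly as written in both copies of the function, since the streaming state store depends on the resulting column order.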