Skip to content

Commit

Permalink
patch
Browse files Browse the repository at this point in the history
  • Loading branch information
WweiL committed Sep 19, 2024
1 parent 0445579 commit 13a83f7
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1194,6 +1194,23 @@ class SparkConnectPlanner(
}
}

/**
 * Resolves the given deduplication column names against the plan's output attributes.
 *
 * @param colNames   names of the columns to deduplicate on
 * @param allColumns all output attributes of the analyzed child plan
 * @return the resolved grouping attributes (one or more per name when duplicates exist)
 * @throws InvalidPlanInput if any name resolves to no attribute
 */
private def groupColsFromDropDuplicates(
    colNames: Seq[String], allColumns: Seq[Attribute]): Seq[Attribute] = {
  val resolver = session.sessionState.analyzer.resolver
  // [SPARK-31990][SPARK-49722]: The `toSet.toSeq` below is load-bearing — the Streaming
  // state store depends on the resulting `groupCols` order, so it must be kept for
  // backward compatibility. If you modify this function, please also modify the same
  // function in `Dataset.scala` in SQL.
  val distinctNames = colNames.toSet.toSeq
  distinctNames.flatMap { name =>
    // Several output attributes may share the same name, so collect every match
    // with `filter` instead of stopping at the first via `find`.
    val matched = allColumns.filter(attr => resolver(attr.name, name))
    if (matched.isEmpty) {
      throw InvalidPlanInput(s"Invalid deduplicate column ${name}")
    }
    matched
  }
}

private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
if (!rel.hasInput) {
throw InvalidPlanInput("Deduplicate needs a plan input")
Expand All @@ -1209,19 +1226,15 @@ class SparkConnectPlanner(
val resolver = session.sessionState.analyzer.resolver
val allColumns = queryExecution.analyzed.output
if (rel.getAllColumnsAsKeys) {
if (rel.getWithinWatermark) DeduplicateWithinWatermark(allColumns, queryExecution.analyzed)
else Deduplicate(allColumns, queryExecution.analyzed)
val groupCals = groupColsFromDropDuplicates(allColumns.map(_.name), allColumns)
if (rel.getWithinWatermark) {
DeduplicateWithinWatermark(groupCals, queryExecution.analyzed)
} else {
Deduplicate(groupCals, queryExecution.analyzed)
}
} else {
val toGroupColumnNames = rel.getColumnNamesList.asScala.toSeq
val groupCols = toGroupColumnNames.flatMap { (colName: String) =>
// It is possibly there are more than one columns with the same name,
// so we call filter instead of find.
val cols = allColumns.filter(col => resolver(col.name, colName))
if (cols.isEmpty) {
throw InvalidPlanInput(s"Invalid deduplicate column ${colName}")
}
cols
}
val groupCols = groupColsFromDropDuplicates(toGroupColumnNames, allColumns)
if (rel.getWithinWatermark) DeduplicateWithinWatermark(groupCols, queryExecution.analyzed)
else Deduplicate(groupCols, queryExecution.analyzed)
}
Expand Down
2 changes: 2 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1304,6 +1304,8 @@ class Dataset[T] private[sql](
val allColumns = queryExecution.analyzed.output
// SPARK-31990: We must keep `toSet.toSeq` here because of the backward compatibility issue
// (the Streaming's state store depends on the `groupCols` order).
// SPARK-49722: If you modify this function, please also modify the same function in
// `SparkConnectPlanner.scala` in connect.
colNames.toSet.toSeq.flatMap { (colName: String) =>
// It is possibly there are more than one columns with the same name,
// so we call filter instead of find.
Expand Down

0 comments on commit 13a83f7

Please sign in to comment.