Merge remote-tracking branch 'spark/master' into SPARK-46092-row-group-skipping-overflow
johanl-db committed Nov 28, 2023
2 parents 1f88c4f + a6cda23 commit 8744cc2
Showing 225 changed files with 2,334 additions and 860 deletions.
12 changes: 2 additions & 10 deletions .github/workflows/build_and_test.yml
@@ -692,11 +692,7 @@ jobs:
- name: Install Python linter dependencies
if: inputs.branch != 'branch-3.3' && inputs.branch != 'branch-3.4' && inputs.branch != 'branch-3.5'
run: |
# TODO(SPARK-32407): Sphinx 3.1+ does not correctly index nested classes.
# See also https://github.com/sphinx-doc/sphinx/issues/7551.
# Jinja2 3.0.0+ causes error when building with Sphinx.
# See also https://issues.apache.org/jira/browse/SPARK-35375.
python3.9 -m pip install 'flake8==3.9.0' pydata_sphinx_theme 'mypy==0.982' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' numpydoc 'jinja2<3.0.0' 'black==23.9.1'
python3.9 -m pip install 'flake8==3.9.0' pydata_sphinx_theme 'mypy==0.982' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' numpydoc jinja2 'black==23.9.1'
python3.9 -m pip install 'pandas-stubs==1.2.0.53' ipython 'grpcio==1.59.3' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0'
- name: Python linter
run: PYTHON_EXECUTABLE=python3.9 ./dev/lint-python
@@ -745,13 +741,9 @@ jobs:
Rscript -e "devtools::install_version('preferably', version='0.4', repos='https://cloud.r-project.org')"
- name: Install dependencies for documentation generation
run: |
# TODO(SPARK-32407): Sphinx 3.1+ does not correctly index nested classes.
# See also https://github.com/sphinx-doc/sphinx/issues/7551.
# Jinja2 3.0.0+ causes error when building with Sphinx.
# See also https://issues.apache.org/jira/browse/SPARK-35375.
# Pin the MarkupSafe to 2.0.1 to resolve the CI error.
# See also https://issues.apache.org/jira/browse/SPARK-38279.
python3.9 -m pip install 'sphinx<3.1.0' mkdocs pydata_sphinx_theme sphinx-copybutton nbsphinx numpydoc 'jinja2<3.0.0' 'markupsafe==2.0.1' 'pyzmq<24.0.0'
python3.9 -m pip install 'sphinx==4.2.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 'markupsafe==2.0.1' 'pyzmq<24.0.0'
python3.9 -m pip install ipython_genutils # See SPARK-38517
python3.9 -m pip install sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8'
python3.9 -m pip install 'docutils<0.18.0' # See SPARK-39421
@@ -80,7 +80,9 @@ class UTF8StringPropertyCheckSuite extends AnyFunSuite with ScalaCheckDrivenProp

test("compare") {
forAll { (s1: String, s2: String) =>
assert(Math.signum(toUTF8(s1).compareTo(toUTF8(s2))) === Math.signum(s1.compareTo(s2)))
assert(Math.signum {
toUTF8(s1).compareTo(toUTF8(s2)).toFloat
} === Math.signum(s1.compareTo(s2).toFloat))
}
}

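The property exercised above is that UTF8String ordering agrees with java.lang.String ordering; the added .toFloat calls only make the Int-to-Float widening explicit before Math.signum, which has no Int overload. A minimal standalone sketch of one concrete instance (hypothetical inputs, assuming spark-unsafe on the classpath):

    import org.apache.spark.unsafe.types.UTF8String

    // One concrete instance of the forAll property: both comparisons must yield the same sign.
    val (s1, s2) = ("apple", "banana")
    val utf8Sign = Math.signum(UTF8String.fromString(s1).compareTo(UTF8String.fromString(s2)).toFloat)
    val strSign = Math.signum(s1.compareTo(s2).toFloat)
    assert(utf8Sign == strSign) // both -1.0f here, since "apple" sorts before "banana"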
2 changes: 1 addition & 1 deletion common/utils/src/main/resources/error/error-classes.json
@@ -1067,7 +1067,7 @@
},
"FAILED_EXECUTE_UDF" : {
"message" : [
"Failed to execute user defined function (<functionName>: (<signature>) => <result>)."
"User defined function (<functionName>: (<signature>) => <result>) failed due to: <reason>."
],
"sqlState" : "39000"
},
@@ -1532,6 +1532,41 @@ class Dataset[T] private[sql] (
proto.Aggregate.GroupType.GROUP_TYPE_CUBE)
}

/**
* Create multi-dimensional aggregation for the current Dataset using the specified grouping
* sets, so we can run aggregation on them. See [[RelationalGroupedDataset]] for all the
* available aggregate functions.
*
* {{{
* // Compute the average for all numeric columns grouped by the specified grouping sets.
* ds.groupingSets(Seq(Seq($"department", $"group"), Seq()), $"department", $"group").avg()
*
* // Compute the max age and average salary, grouped by the specified grouping sets.
* ds.groupingSets(Seq(Seq($"department", $"gender"), Seq()), $"department", $"group").agg(Map(
* "salary" -> "avg",
* "age" -> "max"
* ))
* }}}
*
* @group untypedrel
* @since 4.0.0
*/
@scala.annotation.varargs
def groupingSets(groupingSets: Seq[Seq[Column]], cols: Column*): RelationalGroupedDataset = {
val groupingSetMsgs = groupingSets.map { groupingSet =>
val groupingSetMsg = proto.Aggregate.GroupingSets.newBuilder()
for (groupCol <- groupingSet) {
groupingSetMsg.addGroupingSet(groupCol.expr)
}
groupingSetMsg.build()
}
new RelationalGroupedDataset(
toDF(),
cols,
proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS,
groupingSets = Some(groupingSetMsgs))
}

/**
* (Scala-specific) Aggregates on the entire Dataset without groups.
* {{{
@@ -39,7 +39,8 @@ class RelationalGroupedDataset private[sql] (
private[sql] val df: DataFrame,
private[sql] val groupingExprs: Seq[Column],
groupType: proto.Aggregate.GroupType,
pivot: Option[proto.Aggregate.Pivot] = None) {
pivot: Option[proto.Aggregate.Pivot] = None,
groupingSets: Option[Seq[proto.Aggregate.GroupingSets]] = None) {

private[this] def toDF(aggExprs: Seq[Column]): DataFrame = {
df.sparkSession.newDataFrame { builder =>
@@ -60,6 +61,11 @@
builder.getAggregateBuilder
.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_PIVOT)
.setPivot(pivot.get)
case proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS =>
assert(groupingSets.isDefined)
val aggBuilder = builder.getAggregateBuilder
.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS)
groupingSets.get.foreach(aggBuilder.addGroupingSets)
case g => throw new UnsupportedOperationException(g.toString)
}
}
@@ -3017,6 +3017,12 @@ class PlanGenerationTestSuite
simple.groupBy(Column("id")).pivot("a").agg(functions.count(Column("b")))
}

test("groupingSets") {
simple
.groupingSets(Seq(Seq(fn.col("a")), Seq.empty[Column]), fn.col("a"))
.agg("a" -> "max", "a" -> "count")
}

test("width_bucket") {
simple.select(fn.width_bucket(fn.col("b"), fn.col("b"), fn.col("b"), fn.col("a")))
}
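For context on the API exercised by the groupingSets test above (and added to the Spark Connect Dataset earlier in this diff): groupingSets generalizes rollup/cube by letting the caller enumerate the exact grouping sets to aggregate over. A minimal usage sketch, assuming a SparkSession named spark and a hypothetical employees table; names are illustrative only:

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.functions.col

    val ds = spark.read.table("employees") // hypothetical table with department, age, salary columns

    // Aggregate over two grouping sets: (department) and the grand total ().
    // Roughly the same as SQL:
    //   SELECT department, max(age), avg(salary)
    //   FROM employees
    //   GROUP BY GROUPING SETS ((department), ())
    val result = ds
      .groupingSets(Seq(Seq(col("department")), Seq.empty[Column]), col("department"))
      .agg("age" -> "max", "salary" -> "avg")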
@@ -55,7 +55,7 @@ object RetryPolicy {
def defaultPolicy(): RetryPolicy = RetryPolicy(
name = "DefaultPolicy",
// Please synchronize changes here with Python side:
// pyspark/sql/connect/client/core.py
// pyspark/sql/connect/client/retries.py
//
// Note: these constants are selected so that the maximum tolerated wait is guaranteed
// to be at least 10 minutes
@@ -134,7 +134,7 @@ private[arrow] class SmallIntVectorReader(v: SmallIntVector)
private[arrow] class IntVectorReader(v: IntVector) extends TypedArrowVectorReader[IntVector](v) {
override def getInt(i: Int): Int = vector.get(i)
override def getLong(i: Int): Long = getInt(i)
override def getFloat(i: Int): Float = getInt(i)
override def getFloat(i: Int): Float = getInt(i).toFloat
override def getDouble(i: Int): Double = getInt(i)
override def getString(i: Int): String = String.valueOf(getInt(i))
override def getJavaDecimal(i: Int): JBigDecimal = JBigDecimal.valueOf(getInt(i))
@@ -143,8 +143,8 @@ private[arrow] class IntVectorReader(v: IntVector) extends TypedArrowVectorReade
private[arrow] class BigIntVectorReader(v: BigIntVector)
extends TypedArrowVectorReader[BigIntVector](v) {
override def getLong(i: Int): Long = vector.get(i)
override def getFloat(i: Int): Float = getLong(i)
override def getDouble(i: Int): Double = getLong(i)
override def getFloat(i: Int): Float = getLong(i).toFloat
override def getDouble(i: Int): Double = getLong(i).toDouble
override def getString(i: Int): String = String.valueOf(getLong(i))
override def getJavaDecimal(i: Int): JBigDecimal = JBigDecimal.valueOf(getLong(i))
override def getTimestamp(i: Int): Timestamp = toJavaTimestamp(getLong(i) * MICROS_PER_SECOND)
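Several hunks in this merge (here, and in the signum, Kafka, StatsdReporter and evaluator changes below) replace implicit Int/Long-to-Float/Double conversions with explicit .toFloat/.toDouble calls. That reads like compiler-warning hygiene rather than a behavior change; a standalone sketch of what the implicit widening does, with made-up values:

    // The compiler silently widens Int/Long to Float/Double; for large Long values the
    // Float conversion is lossy, so spelling out .toFloat/.toDouble documents the intent.
    val big: Long = (1L << 24) + 1   // 16777217 is not exactly representable as a Float
    val lossy: Float = big.toFloat   // rounds to 1.6777216E7f
    val exact: Double = big.toDouble // still exact at this magnitude
    assert(lossy.toLong != big && exact.toLong == big)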
@@ -1,2 +1,2 @@
Project [encode(g#0, UTF-8) AS encode(g, UTF-8)#0]
Project [encode(g#0, UTF-8, false) AS encode(g, UTF-8)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -1,2 +1,2 @@
Project [encode(g#0, UTF-8) AS to_binary(g, utf-8)#0]
Project [encode(g#0, UTF-8, false) AS to_binary(g, utf-8)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -0,0 +1,4 @@
Aggregate [a#0, spark_grouping_id#0L], [a#0, max(a#0) AS max(a)#0, count(a#0) AS count(a)#0L]
+- Expand [[id#0L, a#0, b#0, a#0, 0], [id#0L, a#0, b#0, null, 1]], [id#0L, a#0, b#0, a#0, spark_grouping_id#0L]
+- Project [id#0L, a#0, b#0, a#0 AS a#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0]
@@ -0,0 +1,50 @@
{
  "common": {
    "planId": "1"
  },
  "aggregate": {
    "input": {
      "common": {
        "planId": "0"
      },
      "localRelation": {
        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
      }
    },
    "groupType": "GROUP_TYPE_GROUPING_SETS",
    "groupingExpressions": [{
      "unresolvedAttribute": {
        "unparsedIdentifier": "a"
      }
    }],
    "aggregateExpressions": [{
      "unresolvedFunction": {
        "functionName": "max",
        "arguments": [{
          "unresolvedAttribute": {
            "unparsedIdentifier": "a",
            "planId": "0"
          }
        }]
      }
    }, {
      "unresolvedFunction": {
        "functionName": "count",
        "arguments": [{
          "unresolvedAttribute": {
            "unparsedIdentifier": "a",
            "planId": "0"
          }
        }]
      }
    }],
    "groupingSets": [{
      "groupingSet": [{
        "unresolvedAttribute": {
          "unparsedIdentifier": "a"
        }
      }]
    }, {
    }]
  }
}
Binary file not shown.
@@ -2235,7 +2235,7 @@ class SparkConnectPlanner(

JoinWith.typedJoinWith(
joined,
session.sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity,
session.sessionState.conf.dataFrameSelfJoinAutoResolveAmbiguity,
session.sessionState.analyzer.resolver,
rel.getJoinDataType.getIsLeftStruct,
rel.getJoinDataType.getIsRightStruct)
@@ -2563,6 +2563,8 @@ class SparkConnectPlanner(
// To avoid explicit handling of the result on the client, we build the expected input
// of the relation on the server. The client has to simply forward the result.
val result = SqlCommandResult.newBuilder()
// Only filled when isCommand
val metrics = ExecutePlanResponse.Metrics.newBuilder()
if (isCommand) {
// Convert the results to Arrow.
val schema = df.schema
@@ -2596,10 +2598,10 @@ class SparkConnectPlanner(
proto.LocalRelation
.newBuilder()
.setData(ByteString.copyFrom(bytes))))
metrics.addAllMetrics(MetricGenerator.transformPlan(df).asJava)
} else {
// Trigger assertExecutedPlanPrepared to ensure post ReadyForExecution before finished
// executedPlan is currently called by createMetricsResponse below
df.queryExecution.assertExecutedPlanPrepared()
// No execution triggered for relations. Manually set ready
tracker.setReadyForExecution()
result.setRelation(
proto.Relation
.newBuilder()
Expand All @@ -2622,8 +2624,17 @@ class SparkConnectPlanner(
.setSqlCommandResult(result)
.build())

// Send Metrics
responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, df))
// Send Metrics when isCommand (i.e. show tables) which is eagerly executed & has metrics
// Skip metrics when !isCommand (i.e. select 1) which is not executed & doesn't have metrics
if (isCommand) {
responseObserver.onNext(
ExecutePlanResponse
.newBuilder()
.setSessionId(sessionHolder.sessionId)
.setServerSideSessionId(sessionHolder.serverSessionId)
.setMetrics(metrics.build)
.build)
}
}

private def handleRegisterUserDefinedFunction(
@@ -51,6 +51,12 @@ private[connect] object MetricGenerator extends AdaptiveSparkPlanHelper {
allChildren(p).flatMap(c => transformPlan(c, p.id))
}

private[connect] def transformPlan(
rows: DataFrame): Seq[ExecutePlanResponse.Metrics.MetricObject] = {
val executedPlan = rows.queryExecution.executedPlan
transformPlan(executedPlan, executedPlan.id)
}

private def transformPlan(
p: SparkPlan,
parentId: Int): Seq[ExecutePlanResponse.Metrics.MetricObject] = {
@@ -146,7 +146,7 @@ private[spark] class DirectKafkaInputDStream[K, V](
val maxRateLimitPerPartition = ppc.maxRatePerPartition(tp)
val backpressureRate = lag / totalLag.toDouble * rate
tp -> (if (maxRateLimitPerPartition > 0) {
Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate)
Math.min(backpressureRate, maxRateLimitPerPartition.toDouble)} else backpressureRate)
}
case None => offsets.map { case (tp, offset) => tp -> ppc.maxRatePerPartition(tp).toDouble }
}
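Only the explicit .toDouble on the per-partition cap changes in the hunk above, but the surrounding backpressure split is easy to misread, so here is a small worked sketch with made-up lags and rates (none of these numbers come from the source):

    // Each partition gets a share of the controller's rate proportional to its share of the
    // total lag, then is capped by the configured max rate per partition when that cap is > 0.
    val rate = 1000.0              // records/sec suggested by the rate controller (assumed)
    val maxRatePerPartition = 250L // assumed config value; <= 0 would mean "no cap"
    val lagByPartition = Map(0 -> 600L, 1 -> 300L, 2 -> 100L)
    val totalLag = lagByPartition.values.sum
    val perPartitionRate = lagByPartition.map { case (p, lag) =>
      val backpressureRate = lag / totalLag.toDouble * rate
      p -> (if (maxRatePerPartition > 0) Math.min(backpressureRate, maxRatePerPartition.toDouble)
            else backpressureRate)
    }
    // perPartitionRate == Map(0 -> 250.0, 1 -> 250.0, 2 -> 100.0)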
@@ -98,7 +98,7 @@ private[spark] class KafkaRDD[K, V](
if (compacted) {
super.countApprox(timeout, confidence)
} else {
val c = count()
val c = count().toDouble
new PartialResult(new BoundedDouble(c, 1.0, c, c), true)
}

@@ -805,7 +805,7 @@ private[streaming] class ConstantEstimator(@volatile private var rate: Long)
time: Long,
elements: Long,
processingDelay: Long,
schedulingDelay: Long): Option[Double] = Some(rate)
schedulingDelay: Long): Option[Double] = Some(rate.toDouble)
}

private[streaming] class ConstantRateController(id: Int, estimator: RateEstimator, rate: Long)
@@ -378,7 +378,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
resources.foreach { case (k, v) =>
PythonRDD.writeUTF(k, dataOut)
PythonRDD.writeUTF(v.name, dataOut)
dataOut.writeInt(v.addresses.size)
dataOut.writeInt(v.addresses.length)
v.addresses.foreach { case addr =>
PythonRDD.writeUTF(addr, dataOut)
}
@@ -83,13 +83,13 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
.flatMap(_.iterator)
.groupBy(_._1) // group by resource name
.map { case (rName, rInfoArr) =>
rName -> rInfoArr.map(_._2.addresses.size).sum
rName -> rInfoArr.map(_._2.addresses.length).sum
}
val usedInfo = aliveWorkers.map(_.resourcesInfoUsed)
.flatMap(_.iterator)
.groupBy(_._1) // group by resource name
.map { case (rName, rInfoArr) =>
rName -> rInfoArr.map(_._2.addresses.size).sum
rName -> rInfoArr.map(_._2.addresses.length).sum
}
formatResourcesUsed(totalInfo, usedInfo)
}
@@ -46,7 +46,7 @@ class ExecutorMetrics private[spark] extends Serializable {

private[spark] def this(metrics: Array[Long]) = {
this()
Array.copy(metrics, 0, this.metrics, 0, Math.min(metrics.size, this.metrics.size))
Array.copy(metrics, 0, this.metrics, 0, Math.min(metrics.length, this.metrics.length))
}

private[spark] def this(metrics: AtomicLongArray) = {
@@ -74,7 +74,7 @@ private[spark] class FixedLengthBinaryInputFormat
if (defaultSize < recordLength) {
recordLength.toLong
} else {
(Math.floor(defaultSize / recordLength) * recordLength).toLong
defaultSize / recordLength * recordLength
}
}

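The rewrite above drops a redundant Math.floor: with non-negative integer operands, Long/Int division already truncates, so rounding the quotient down again is a no-op and the Double round-trip (plus the implicit widening it required) can go. A quick check with made-up sizes:

    val defaultSize = 128L * 1024 * 1024 // hypothetical split size; the real value comes from Hadoop
    val recordLength = 100               // hypothetical fixed record length
    val oldStyle = (Math.floor((defaultSize / recordLength).toDouble) * recordLength.toDouble).toLong
    val newStyle = defaultSize / recordLength * recordLength
    assert(oldStyle == newStyle) // both 134217700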
@@ -124,9 +124,9 @@ private[spark] class StatsdReporter(

private def reportTimer(name: String, timer: Timer)(implicit socket: DatagramSocket): Unit = {
val snapshot = timer.getSnapshot
send(fullName(name, "max"), format(convertDuration(snapshot.getMax)), TIMER)
send(fullName(name, "max"), format(convertDuration(snapshot.getMax.toDouble)), TIMER)
send(fullName(name, "mean"), format(convertDuration(snapshot.getMean)), TIMER)
send(fullName(name, "min"), format(convertDuration(snapshot.getMin)), TIMER)
send(fullName(name, "min"), format(convertDuration(snapshot.getMin.toDouble)), TIMER)
send(fullName(name, "stddev"), format(convertDuration(snapshot.getStdDev)), TIMER)
send(fullName(name, "p50"), format(convertDuration(snapshot.getMedian)), TIMER)
send(fullName(name, "p75"), format(convertDuration(snapshot.get75thPercentile)), TIMER)
@@ -35,7 +35,7 @@ private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double)

override def currentResult(): BoundedDouble = {
if (outputsMerged == totalOutputs) {
new BoundedDouble(sum, 1.0, sum, sum)
new BoundedDouble(sum.toDouble, 1.0, sum.toDouble, sum.toDouble)
} else if (outputsMerged == 0 || sum == 0) {
new BoundedDouble(0, 0.0, 0.0, Double.PositiveInfinity)
} else {
@@ -57,7 +57,8 @@
val low = dist.inverseCumulativeProbability((1 - confidence) / 2)
val high = dist.inverseCumulativeProbability((1 + confidence) / 2)
// Add 'sum' to each because distribution is just of remaining count, not observed
new BoundedDouble(sum + dist.getNumericalMean, confidence, sum + low, sum + high)
new BoundedDouble(
sum + dist.getNumericalMean, confidence, (sum + low).toDouble, (sum + high).toDouble)
}


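To see the interval math above with numbers: the evaluator models the count still to come from unfinished partitions with a probability distribution (a Poisson model in Spark's implementation) and shifts its quantiles by the count already observed. A rough sketch under those assumptions, with made-up inputs and commons-math3 on the classpath:

    import org.apache.commons.math3.distribution.PoissonDistribution

    val sum = 1000L               // count observed from the partitions merged so far (hypothetical)
    val expectedRemaining = 250.0 // assumed mean of the count still outstanding
    val confidence = 0.95
    val dist = new PoissonDistribution(expectedRemaining)
    val low = dist.inverseCumulativeProbability((1 - confidence) / 2)  // 2.5% quantile of the remainder
    val high = dist.inverseCumulativeProbability((1 + confidence) / 2) // 97.5% quantile of the remainder
    // Add 'sum' to each bound because the distribution covers only the unseen remainder.
    val estimate = sum + dist.getNumericalMean
    val bounds = ((sum + low).toDouble, (sum + high).toDouble)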
@@ -41,7 +41,9 @@ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, conf

override def currentResult(): Map[T, BoundedDouble] = {
if (outputsMerged == totalOutputs) {
sums.map { case (key, sum) => (key, new BoundedDouble(sum, 1.0, sum, sum)) }.toMap
sums.map { case (key, sum) =>
(key, new BoundedDouble(sum.toDouble, 1.0, sum.toDouble, sum.toDouble))
}.toMap
} else if (outputsMerged == 0) {
new HashMap[T, BoundedDouble]
} else {