diff --git a/core/src/main/scala/org/apache/spark/ContextAwareIterator.scala b/core/src/main/scala/org/apache/spark/ContextAwareIterator.scala index 84ae93f178823..facb03365e813 100644 --- a/core/src/main/scala/org/apache/spark/ContextAwareIterator.scala +++ b/core/src/main/scala/org/apache/spark/ContextAwareIterator.scala @@ -30,8 +30,10 @@ import org.apache.spark.annotation.DeveloperApi * Thus, we should use [[ContextAwareIterator]] to stop consuming after the task ends. * * @since 3.1.0 + * @deprecated since 4.0.0 as its only usage for Python evaluation is now extinct */ @DeveloperApi +@deprecated("Only usage for Python evaluation is now extinct", "3.5.0") class ContextAwareIterator[+T](val context: TaskContext, val delegate: Iterator[T]) extends Iterator[T] { diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index eef99c26e77c3..e404c9ee8b4cf 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -18,7 +18,6 @@ package org.apache.spark import java.io.File -import java.net.Socket import java.util.Locale import scala.collection.JavaConverters._ @@ -30,7 +29,7 @@ import com.google.common.cache.CacheBuilder import org.apache.hadoop.conf.Configuration import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.api.python.PythonWorkerFactory +import org.apache.spark.api.python.{PythonWorker, PythonWorkerFactory} import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.ExecutorBackend import org.apache.spark.internal.{config, Logging} @@ -129,7 +128,7 @@ class SparkEnv ( pythonExec: String, workerModule: String, daemonModule: String, - envVars: Map[String, String]): (java.net.Socket, Option[Int]) = { + envVars: Map[String, String]): (PythonWorker, Option[Int]) = { synchronized { val key = PythonWorkersKey(pythonExec, workerModule, daemonModule, envVars) pythonWorkers.getOrElseUpdate(key, @@ -140,7 +139,7 @@ class SparkEnv ( private[spark] def createPythonWorker( pythonExec: String, workerModule: String, - envVars: Map[String, String]): (java.net.Socket, Option[Int]) = { + envVars: Map[String, String]): (PythonWorker, Option[Int]) = { createPythonWorker( pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, envVars) } @@ -150,7 +149,7 @@ class SparkEnv ( workerModule: String, daemonModule: String, envVars: Map[String, String], - worker: Socket): Unit = { + worker: PythonWorker): Unit = { synchronized { val key = PythonWorkersKey(pythonExec, workerModule, daemonModule, envVars) pythonWorkers.get(key).foreach(_.stopWorker(worker)) @@ -161,7 +160,7 @@ class SparkEnv ( pythonExec: String, workerModule: String, envVars: Map[String, String], - worker: Socket): Unit = { + worker: PythonWorker): Unit = { destroyPythonWorker( pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, envVars, worker) } @@ -171,7 +170,7 @@ class SparkEnv ( workerModule: String, daemonModule: String, envVars: Map[String, String], - worker: Socket): Unit = { + worker: PythonWorker): Unit = { synchronized { val key = PythonWorkersKey(pythonExec, workerModule, daemonModule, envVars) pythonWorkers.get(key).foreach(_.releaseWorker(worker)) @@ -182,7 +181,7 @@ class SparkEnv ( pythonExec: String, workerModule: String, envVars: Map[String, String], - worker: Socket): Unit = { + worker: PythonWorker): Unit = { releasePythonWorker( pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, envVars, worker) } diff --git 
a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 91fd92d4422c8..a2f2d566db5a3 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -137,7 +137,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte] private[spark] object PythonRDD extends Logging { // remember the broadcasts sent to each worker - private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() + private val workerBroadcasts = new mutable.WeakHashMap[PythonWorker, mutable.Set[Long]]() // Authentication helper used when serving iterator data. private lazy val authHelper = { @@ -145,7 +145,7 @@ private[spark] object PythonRDD extends Logging { new SocketAuthHelper(conf) } - def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = { + def getWorkerBroadcasts(worker: PythonWorker): mutable.Set[Long] = { synchronized { workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]()) } @@ -300,7 +300,11 @@ private[spark] object PythonRDD extends Logging { new PythonBroadcast(path) } - def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream): Unit = { + /** + * Writes the next element of the iterator `iter` to `dataOut`. Returns true if any data was + * written to the stream. Returns false if no data was written as the iterator has been exhausted. + */ + def writeNextElementToStream[T](iter: Iterator[T], dataOut: DataOutputStream): Boolean = { def write(obj: Any): Unit = obj match { case null => @@ -318,8 +322,18 @@ private[spark] object PythonRDD extends Logging { case other => throw new SparkException("Unexpected element type " + other.getClass) } + if (iter.hasNext) { + write(iter.next()) + true + } else { + false + } + } - iter.foreach(write) + def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream): Unit = { + while (writeNextElementToStream(iter, dataOut)) { + // Nothing. 
+ } } /** diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 0173de75ff23e..d7801d2e83b97 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -19,6 +19,8 @@ package org.apache.spark.api.python import java.io._ import java.net._ +import java.nio.ByteBuffer +import java.nio.channels.SelectionKey import java.nio.charset.StandardCharsets import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.{Files => JavaFiles, Path} @@ -32,6 +34,7 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES} import org.apache.spark.internal.config.Python._ +import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.resource.ResourceProfile.{EXECUTOR_CORES_LOCAL_PROPERTY, PYSPARK_MEMORY_LOCAL_PROPERTY} import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util._ @@ -103,6 +106,8 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( private val conf = SparkEnv.get.conf protected val bufferSize: Int = conf.get(BUFFER_SIZE) + protected val timelyFlushEnabled: Boolean = false + protected val timelyFlushTimeoutNanos: Long = 0 protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT) private val reuseWorker = conf.get(PYTHON_WORKER_REUSE) private val faultHandlerEnabled = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED) @@ -143,7 +148,8 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( // Python accumulator is always set in production except in tests. See SPARK-27893 private val maybeAccumulator: Option[PythonAccumulatorV2] = Option(accumulator) - // Expose a ServerSocket to support method calls via socket from Python side. + // Expose a ServerSocket to support method calls via socket from Python side. Only relevant for + // for tasks that are a part of barrier stage, refer [[BarrierTaskContext]] for details. private[spark] var serverSocket: Option[ServerSocket] = None // Authentication helper used when serving method calls via socket from Python side. @@ -194,7 +200,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default")) - val (worker: Socket, pid: Option[Int]) = env.createPythonWorker( + val (worker: PythonWorker, pid: Option[Int]) = env.createPythonWorker( pythonExec, workerModule, daemonModule, envVars.asScala.toMap) // Whether is the worker released into idle pool or closed. 
When any codes try to release or // close a worker, they should use `releasedOrClosed.compareAndSet` to flip the state to make @@ -202,22 +208,19 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( val releasedOrClosed = new AtomicBoolean(false) // Start a thread to feed the process input from our parent's iterator - val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context) + val writer = newWriter(env, worker, inputIterator, partitionIndex, context) context.addTaskCompletionListener[Unit] { _ => - writerThread.shutdownOnTaskCompletion() if (!reuseWorker || releasedOrClosed.compareAndSet(false, true)) { try { - worker.close() + worker.stop() } catch { case e: Exception => - logWarning("Failed to close worker socket", e) + logWarning("Failed to stop worker") } } } - writerThread.start() - new WriterMonitorThread(SparkEnv.get, worker, writerThread, context).start() if (reuseWorker) { val key = (worker, context.taskAttemptId) // SPARK-35009: avoid creating multiple monitor threads for the same python worker @@ -230,68 +233,49 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( } // Return an iterator that read lines from the process's stdout - val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) - + val dataIn = new DataInputStream( + new BufferedInputStream(new ReaderInputStream(worker, writer), bufferSize)) val stdoutIterator = newReaderIterator( - stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) + dataIn, writer, startTime, env, worker, pid, releasedOrClosed, context) new InterruptibleIterator(context, stdoutIterator) } - protected def newWriterThread( + protected def newWriter( env: SparkEnv, - worker: Socket, + worker: PythonWorker, inputIterator: Iterator[IN], partitionIndex: Int, - context: TaskContext): WriterThread + context: TaskContext): Writer protected def newReaderIterator( stream: DataInputStream, - writerThread: WriterThread, + writer: Writer, startTime: Long, env: SparkEnv, - worker: Socket, + worker: PythonWorker, pid: Option[Int], releasedOrClosed: AtomicBoolean, context: TaskContext): Iterator[OUT] /** - * The thread responsible for writing the data from the PythonRDD's parent iterator to the + * Responsible for writing the data from the PythonRDD's parent iterator to the * Python process. */ - abstract class WriterThread( + abstract class Writer( env: SparkEnv, - worker: Socket, + worker: PythonWorker, inputIterator: Iterator[IN], partitionIndex: Int, - context: TaskContext) - extends Thread(s"stdout writer for $pythonExec") { + context: TaskContext) { - @volatile private var _exception: Throwable = null + @volatile private var _exception: Throwable = _ private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala)) - setDaemon(true) - /** Contains the throwable thrown while writing the parent iterator to the Python process. */ def exception: Option[Throwable] = Option(_exception) - /** - * Terminates the writer thread and waits for it to exit, ignoring any exceptions that may occur - * due to cleanup. - */ - def shutdownOnTaskCompletion(): Unit = { - assert(context.isCompleted) - this.interrupt() - // Task completion listeners that run after this method returns may invalidate - // `inputIterator`. 
For example, when `inputIterator` was generated by the off-heap vectorized - // reader, a task completion listener will free the underlying off-heap buffers. If the writer - // thread is still running when `inputIterator` is invalidated, it can cause a use-after-free - // bug that crashes the executor (SPARK-33277). Therefore this method must wait for the writer - // thread to exit before returning. - this.join() - } - /** * Writes a command section to the stream connected to the Python worker. */ @@ -299,14 +283,12 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( /** * Writes input data to the stream connected to the Python worker. + * Returns true if any data was written to the stream, false if the input is exhausted. */ - protected def writeIteratorToStream(dataOut: DataOutputStream): Unit + def writeNextInputToStream(dataOut: DataOutputStream): Boolean - override def run(): Unit = Utils.logUncaughtExceptions { + def open(dataOut: DataOutputStream): Unit = Utils.logUncaughtExceptions { try { - TaskContext.setTaskContext(context) - val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) - val dataOut = new DataOutputStream(stream) // Partition index dataOut.writeInt(partitionIndex) @@ -367,21 +349,25 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( } else { "" } - // Close ServerSocket on task completion. - serverSocket.foreach { server => - context.addTaskCompletionListener[Unit](_ => server.close()) - } - val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0) - if (boundPort == -1) { - val message = "ServerSocket failed to bind to Java side." - logError(message) - throw new SparkException(message) - } else if (isBarrier) { + if (isBarrier) { + // Close ServerSocket on task completion. + serverSocket.foreach { server => + context.addTaskCompletionListener[Unit](_ => server.close()) + } + val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0) + if (boundPort == -1) { + val message = "ServerSocket failed to bind to Java side." + logError(message) + throw new SparkException(message) + } logDebug(s"Started ServerSocket on port $boundPort.") + dataOut.writeBoolean(/* isBarrier = */true) + dataOut.writeInt(boundPort) + } else { + dataOut.writeBoolean(/* isBarrier = */false) + dataOut.writeInt(0) } // Write out the TaskContextInfo - dataOut.writeBoolean(isBarrier) - dataOut.writeInt(boundPort) val secretBytes = secret.getBytes(UTF_8) dataOut.writeInt(secretBytes.length) dataOut.write(secretBytes, 0, secretBytes.length) @@ -412,30 +398,33 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( dataOut.writeInt(evalType) writeCommand(dataOut) - writeIteratorToStream(dataOut) - dataOut.writeInt(SpecialLengths.END_OF_STREAM) dataOut.flush() } catch { - case t: Throwable if (NonFatal(t) || t.isInstanceOf[Exception]) => + case t: Throwable if NonFatal(t) || t.isInstanceOf[Exception] => if (context.isCompleted || context.isInterrupted) { logDebug("Exception/NonFatal Error thrown after task completion (likely due to " + "cleanup)", t) - if (!worker.isClosed) { - Utils.tryLog(worker.shutdownOutput()) + if (worker.channel.isConnected) { + Utils.tryLog(worker.channel.shutdownOutput()) } } else { // We must avoid throwing exceptions/NonFatals here, because the thread uncaught // exception handler will kill the whole executor (see // org.apache.spark.executor.Executor). 
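// --- Illustrative sketch (not part of the patch): the pull contract that the new Writer above
// exposes in place of the old WriterThread. writeNextInputToStream() writes at most one chunk of
// input and reports whether anything remains, so a single task thread can alternate between
// feeding the Python worker and reading its output. All names below are hypothetical stand-ins.
import java.io.DataOutputStream

object PullWriterSketch {
  trait OneShotWriter {
    /** Writes one unit of input; returns false once the input iterator is exhausted. */
    def writeNextInputToStream(dataOut: DataOutputStream): Boolean
    /** Writes the end-of-stream marker once there is nothing left to send. */
    def close(dataOut: DataOutputStream): Unit
  }

  // Pump input only while the worker can accept it; stop as soon as the input runs out.
  def pump(
      writer: OneShotWriter,
      dataOut: DataOutputStream,
      workerAcceptsInput: () => Boolean): Unit = {
    var hasMoreInput = true
    while (hasMoreInput && workerAcceptsInput()) {
      hasMoreInput = writer.writeNextInputToStream(dataOut)
    }
    if (!hasMoreInput) {
      writer.close(dataOut)
    }
  }
}
// --- (end of illustrative sketch)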
_exception = t - if (!worker.isClosed) { - Utils.tryLog(worker.shutdownOutput()) + if (worker.channel.isConnected) { + Utils.tryLog(worker.channel.shutdownOutput()) } } } } + def close(dataOut: DataOutputStream): Unit = { + dataOut.writeInt(SpecialLengths.END_OF_STREAM) + dataOut.flush() + } + /** * Gateway to call BarrierTaskContext methods. */ @@ -470,10 +459,10 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( abstract class ReaderIterator( stream: DataInputStream, - writerThread: WriterThread, + writer: Writer, startTime: Long, env: SparkEnv, - worker: Socket, + worker: PythonWorker, pid: Option[Int], releasedOrClosed: AtomicBoolean, context: TaskContext) @@ -531,7 +520,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( val obj = new Array[Byte](exLength) stream.readFully(obj) new PythonException(new String(obj, StandardCharsets.UTF_8), - writerThread.exception.orNull) + writer.exception.orNull) } protected def handleEndOfDataSection(): Unit = { @@ -554,10 +543,10 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( logDebug("Exception thrown after task interruption", e) throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason")) - case e: Exception if writerThread.exception.isDefined => + case e: Exception if writer.exception.isDefined => logError("Python worker exited unexpectedly (crashed)", e) - logError("This may have been caused by a prior exception:", writerThread.exception.get) - throw writerThread.exception.get + logError("This may have been caused by a prior exception:", writer.exception.get) + throw writer.exception.get case eof: EOFException if faultHandlerEnabled && pid.isDefined && JavaFiles.exists(BasePythonRunner.faultHandlerLogPath(pid.get)) => @@ -576,7 +565,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the * threads can block indefinitely. */ - class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext) + class MonitorThread(env: SparkEnv, worker: PythonWorker, context: TaskContext) extends Thread(s"Worker Monitor for $pythonExec") { /** How long to wait before killing the python worker if a task cannot be interrupted. */ @@ -620,60 +609,185 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( } } - /** - * This thread monitors the WriterThread and kills it in case of deadlock. - * - * A deadlock can arise if the task completes while the writer thread is sending input to the - * Python process (e.g. due to the use of `take()`), and the Python process is still producing - * output. When the inputs are sufficiently large, this can result in a deadlock due to the use of - * blocking I/O (SPARK-38677). To resolve the deadlock, we need to close the socket. - */ - class WriterMonitorThread( - env: SparkEnv, worker: Socket, writerThread: WriterThread, context: TaskContext) - extends Thread(s"Writer Monitor for $pythonExec (writer thread id ${writerThread.getId})") { - + class ReaderInputStream(worker: PythonWorker, writer: Writer) extends InputStream { + private[this] var writerIfbhThreadLocalValue: Object = null + private[this] val temp = new Array[Byte](1) + private[this] val bufferStream = new DirectByteBufferOutputStream() /** - * How long to wait before closing the socket if the writer thread has not exited after the task - * ends. + * Buffers data to be written to the Python worker until the socket is + * available for write. 
+ * A best-effort attempt is made to not grow the buffer beyond "spark.buffer.size". See + * `writeAdditionalInputToPythonWorker()` for details. */ - private val taskKillTimeout = env.conf.get(PYTHON_TASK_KILL_TIMEOUT) + private[this] var buffer: ByteBuffer = _ + private[this] var hasInput = true - setDaemon(true) + writer.open(new DataOutputStream(bufferStream)) + buffer = bufferStream.toByteBuffer - override def run(): Unit = { - // Wait until the task is completed (or the writer thread exits, in which case this thread has - // nothing to do). - while (!context.isCompleted && writerThread.isAlive) { - Thread.sleep(2000) + override def read(): Int = { + val n = read(temp) + if (n <= 0) { + -1 + } else { + // Signed byte to unsigned integer + temp(0) & 0xff } - if (writerThread.isAlive) { - Thread.sleep(taskKillTimeout) - // If the writer thread continues running, this indicates a deadlock. Kill the worker to - // resolve the deadlock. - if (writerThread.isAlive) { - try { - // Mimic the task name used in `Executor` to help the user find out the task to blame. - val taskName = s"${context.partitionId}.${context.attemptNumber} " + - s"in stage ${context.stageId} (TID ${context.taskAttemptId})" - logWarning( - s"Detected deadlock while completing task $taskName: " + - "Attempting to kill Python Worker") - env.destroyPythonWorker( - pythonExec, workerModule, daemonModule, envVars.asScala.toMap, worker) - } catch { - case e: Exception => - logError("Exception when trying to kill worker", e) + } + + override def read(b: Array[Byte], off: Int, len: Int): Int = { + // The code below manipulates the InputFileBlockHolder thread local in order + // to prevent behavior changes in the input_file_name() expression due to the switch from + // multi-threaded to single-threaded Python execution (SPARK-44705). + // + // Prior to that change, scan operations feeding into PythonRunner would be evaluated in + // "writer" threads that were child threads of the main task thread. As a result, when + // a scan operation hit end-of-input and called InputFileBlockHolder.unset(), the effects + // of unset() would only occur in the writer thread and not the main task thread: this + // meant that code "downstream" of a PythonRunner would continue to observe the writer's + // last pre-unset() value (i.e. the last read filename). + // + // Switching to a single-threaded Python runner changed this behavior: now, unset() would + // impact operators both upstream and downstream of the PythonRunner and this would cause + // unset()'s effects to be immediately visible to downstream operators, in turn causing the + // input_file_name() expression to return empty filenames in situations where it previously + // would have returned the last non-empty filename. + // + // To avoid this behavior change, the code below simulates the behavior of the + // InputFileBlockHolder's inheritable thread local: + // + // - Detect whether code that previously would have run in the writer thread has changed + // the thread local value itself. Note that the thread local holds a mutable + // AtomicReference, so the thread local's value only changes objects when unset() is + // called. + // - If an object change was detected, then henceforth we will swap between the "main" + // and "writer" thread local values when context switching between upstream and + // downstream operator execution. 
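// --- Illustrative sketch (not part of the patch) of the save/swap/restore dance described in the
// comment above, reduced to a plain ThreadLocal. runWriterCode() stands in for
// writeAdditionalInputToPythonWorker(); every name here is a hypothetical stand-in.
object ThreadLocalSwapSketch {
  private val holder = new ThreadLocal[AnyRef]
  // Value last installed by the "writer" code path; null means "same object as the main path".
  private var writerValue: AnyRef = null

  def runAsWriter(runWriterCode: () => Unit): Unit = {
    val mainValue = holder.get()
    if (writerValue eq null) {
      // Common case: writer and main paths still share the same object, so no swap is needed.
      try {
        runWriterCode()
      } finally {
        val maybeNew = holder.get()
        if (maybeNew ne mainValue) {
          // The writer replaced the object; start swapping from now on.
          writerValue = maybeNew
          holder.set(mainValue)
        }
      }
    } else {
      // The two paths have diverged: swap the writer's value in, then restore the main value.
      try {
        holder.set(writerValue)
        try {
          runWriterCode()
        } finally {
          writerValue = holder.get()
        }
      } finally {
        holder.set(mainValue)
      }
    }
  }
}
// --- (end of illustrative sketch)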
+ // + // This issue is subtle and several other alternative approaches were considered + val buf = ByteBuffer.wrap(b, off, len) + var n = 0 + while (n == 0) { + worker.selector.select() + if (worker.selectionKey.isReadable) { + n = worker.channel.read(buf) + } + if (worker.selectionKey.isWritable) { + val mainIfbhThreadLocalValue = InputFileBlockHolder.getThreadLocalValue() + // Check whether the writer's thread local value has diverged from its parent's value: + if (writerIfbhThreadLocalValue eq null) { + // Default case (which is why it appears first): the writer's thread local value + // is the same object as the main code, so no need to swap before executing the + // writer code. + try { + // Execute the writer code: + writeAdditionalInputToPythonWorker() + } finally { + // Check whether the writer code changed the thread local value: + val maybeNewIfbh = InputFileBlockHolder.getThreadLocalValue() + if (maybeNewIfbh ne mainIfbhThreadLocalValue) { + // The writer thread change the thread local, so henceforth we need to + // swap. Store the writer thread's value and restore the old main thread + // value: + writerIfbhThreadLocalValue = maybeNewIfbh + InputFileBlockHolder.setThreadLocalValue(mainIfbhThreadLocalValue) + } + } + } else { + // The writer thread and parent thread have different values, so we must swap + // them when switching between writer and parent code: + try { + // Swap in the writer value: + InputFileBlockHolder.setThreadLocalValue(writerIfbhThreadLocalValue) + try { + // Execute the writer code: + writeAdditionalInputToPythonWorker() + } finally { + // Store an updated writer thread value: + writerIfbhThreadLocalValue = InputFileBlockHolder.getThreadLocalValue() + } + } finally { + // Restore the main thread's value: + InputFileBlockHolder.setThreadLocalValue(mainIfbhThreadLocalValue) + } } } } + n + } + + private var lastFlushTime = System.nanoTime() + + /** + * Returns false if `timelyFlushEnabled` is disabled. + * + * Otherwise, returns true if `buffer` should be flushed before any additional data is + * written to it. + * For small input rows the data might stay in the buffer for long before it is sent to the + * Python worker. We should flush the buffer periodically so that the downstream can make + * continued progress. + */ + private def shouldFlush(): Boolean = { + if (!timelyFlushEnabled) { + false + } else { + val currentTime = System.nanoTime() + if (currentTime - lastFlushTime > timelyFlushTimeoutNanos) { + lastFlushTime = currentTime + bufferStream.size() > 0 + } else { + false + } + } + } + + /** + * Reads input data from `writer.inputIterator` into `buffer` and writes the buffer to the + * Python worker if the socket is available for writing. + */ + private def writeAdditionalInputToPythonWorker(): Unit = { + var acceptsInput = true + while (acceptsInput && (hasInput || buffer.hasRemaining)) { + if (!buffer.hasRemaining && hasInput) { + // No buffered data is available. Try to read input into the buffer. + bufferStream.reset() + // Set the `buffer` to null to make it eligible for GC + buffer = null + + val dataOut = new DataOutputStream(bufferStream) + // Try not to grow the buffer much beyond `bufferSize`. This is inevitable for large + // input rows. + while (bufferStream.size() < bufferSize && hasInput && !shouldFlush()) { + hasInput = writer.writeNextInputToStream(dataOut) + } + if (!hasInput) { + // Reached the end of the input. 
+ writer.close(dataOut) + } + buffer = bufferStream.toByteBuffer + } + + // Try to write as much buffered data as possible to the socket. + while (buffer.hasRemaining && acceptsInput) { + val n = worker.channel.write(buffer) + acceptsInput = n > 0 + } + } + + if (!hasInput && !buffer.hasRemaining) { + // We no longer have any data to write to the socket. + worker.selectionKey.interestOps(SelectionKey.OP_READ) + bufferStream.close() + } } } + } private[spark] object PythonRunner { // already running worker monitor threads for worker and task attempts ID pairs - val runningMonitorThreads = ConcurrentHashMap.newKeySet[(Socket, Long)]() + val runningMonitorThreads = ConcurrentHashMap.newKeySet[(PythonWorker, Long)]() private var printPythonInfo: AtomicBoolean = new AtomicBoolean(true) @@ -693,13 +807,13 @@ private[spark] class PythonRunner( extends BasePythonRunner[Array[Byte], Array[Byte]]( funcs, PythonEvalType.NON_UDF, Array(Array(0)), jobArtifactUUID) { - protected override def newWriterThread( + protected override def newWriter( env: SparkEnv, - worker: Socket, + worker: PythonWorker, inputIterator: Iterator[Array[Byte]], partitionIndex: Int, - context: TaskContext): WriterThread = { - new WriterThread(env, worker, inputIterator, partitionIndex, context) { + context: TaskContext): Writer = { + new Writer(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { val command = funcs.head.funcs.head.command @@ -707,28 +821,32 @@ private[spark] class PythonRunner( dataOut.write(command.toArray) } - protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { - PythonRDD.writeIteratorToStream(inputIterator, dataOut) - dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = { + if (PythonRDD.writeNextElementToStream(inputIterator, dataOut)) { + true + } else { + dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + false + } } } } protected override def newReaderIterator( stream: DataInputStream, - writerThread: WriterThread, + writer: Writer, startTime: Long, env: SparkEnv, - worker: Socket, + worker: PythonWorker, pid: Option[Int], releasedOrClosed: AtomicBoolean, context: TaskContext): Iterator[Array[Byte]] = { new ReaderIterator( - stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) { + stream, writer, startTime, env, worker, pid, releasedOrClosed, context) { protected override def read(): Array[Byte] = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get + if (writer.exception.isDefined) { + throw writer.exception.get } try { stream.readInt() match { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 4ba6dd949b14a..1db8748c327ab 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -18,7 +18,8 @@ package org.apache.spark.api.python import java.io.{DataInputStream, DataOutputStream, EOFException, File, InputStream} -import java.net.{InetAddress, ServerSocket, Socket, SocketException} +import java.net.{InetAddress, InetSocketAddress, SocketException} +import java.nio.channels._ import java.util.Arrays import java.util.concurrent.TimeUnit import javax.annotation.concurrent.GuardedBy @@ -33,6 +34,14 @@ import org.apache.spark.internal.config.Python._ 
import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util.{RedirectThread, Utils} +case class PythonWorker(channel: SocketChannel, selector: Selector, selectionKey: SelectionKey) { + def stop(): Unit = { + selectionKey.cancel() + selector.close() + channel.close() + } +} + private[spark] class PythonWorkerFactory( pythonExec: String, workerModule: String, @@ -67,32 +76,33 @@ private[spark] class PythonWorkerFactory( @GuardedBy("self") private var daemonPort: Int = 0 @GuardedBy("self") - private val daemonWorkers = new mutable.WeakHashMap[Socket, Int]() + private val daemonWorkers = new mutable.WeakHashMap[PythonWorker, Int]() @GuardedBy("self") - private val idleWorkers = new mutable.Queue[Socket]() + private val idleWorkers = new mutable.Queue[PythonWorker]() @GuardedBy("self") private var lastActivityNs = 0L new MonitorThread().start() @GuardedBy("self") - private val simpleWorkers = new mutable.WeakHashMap[Socket, Process]() + private val simpleWorkers = new mutable.WeakHashMap[PythonWorker, Process]() private val pythonPath = PythonUtils.mergePythonPaths( PythonUtils.sparkPythonPath, envVars.getOrElse("PYTHONPATH", ""), sys.env.getOrElse("PYTHONPATH", "")) - def create(): (Socket, Option[Int]) = { + def create(): (PythonWorker, Option[Int]) = { if (useDaemon) { self.synchronized { if (idleWorkers.nonEmpty) { val worker = idleWorkers.dequeue() + worker.selectionKey.interestOps(SelectionKey.OP_READ | SelectionKey.OP_WRITE) return (worker, daemonWorkers.get(worker)) } } createThroughDaemon() } else { - createSimpleWorker() + createSimpleWorker(blockingMode = false) } } @@ -101,18 +111,25 @@ private[spark] class PythonWorkerFactory( * processes itself to avoid the high cost of forking from Java. This currently only works * on UNIX-based systems. */ - private def createThroughDaemon(): (Socket, Option[Int]) = { + private def createThroughDaemon(): (PythonWorker, Option[Int]) = { - def createSocket(): (Socket, Option[Int]) = { - val socket = new Socket(daemonHost, daemonPort) - val pid = new DataInputStream(socket.getInputStream).readInt() + def createWorker(): (PythonWorker, Option[Int]) = { + val socketChannel = SocketChannel.open(new InetSocketAddress(daemonHost, daemonPort)) + // These calls are blocking. + val pid = new DataInputStream(Channels.newInputStream(socketChannel)).readInt() if (pid < 0) { throw new IllegalStateException("Python daemon failed to launch worker with code " + pid) } - authHelper.authToServer(socket) - daemonWorkers.put(socket, pid) - (socket, Some(pid)) + authHelper.authToServer(socketChannel.socket()) + socketChannel.configureBlocking(false) + val selector = Selector.open() + val selectionKey = socketChannel.register(selector, + SelectionKey.OP_READ | SelectionKey.OP_WRITE) + val worker = PythonWorker(socketChannel, selector, selectionKey) + + daemonWorkers.put(worker, pid) + (worker, Some(pid)) } self.synchronized { @@ -121,14 +138,14 @@ private[spark] class PythonWorkerFactory( // Attempt to connect, restart and retry once if it fails try { - createSocket() + createWorker() } catch { case exc: SocketException => logWarning("Failed to open socket to Python daemon:", exc) logWarning("Assuming that daemon unexpectedly quit, attempting to restart") stopDaemon() startDaemon() - createSocket() + createWorker() } } } @@ -136,10 +153,11 @@ private[spark] class PythonWorkerFactory( /** * Launch a worker by executing worker.py (by default) directly and telling it to connect to us. 
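// --- Illustrative sketch (not part of the patch) of the non-blocking channel setup that
// createWorker() above performs: connect in blocking mode, do the pid handshake over ordinary
// streams, then switch the channel to non-blocking mode and register it for both read and write
// interest. Host, port, and error handling are simplified placeholders.
import java.io.DataInputStream
import java.net.InetSocketAddress
import java.nio.channels.{Channels, SelectionKey, Selector, SocketChannel}

object NonBlockingWorkerSketch {
  def connect(host: String, port: Int): (SocketChannel, Selector, SelectionKey) = {
    // SocketChannel.open(addr) connects in blocking mode, which keeps the handshake simple.
    val channel = SocketChannel.open(new InetSocketAddress(host, port))
    val pid = new DataInputStream(Channels.newInputStream(channel)).readInt()
    require(pid >= 0, s"daemon failed to launch a worker, exit code: $pid")

    // From here on the runner multiplexes reads and writes, so go non-blocking.
    channel.configureBlocking(false)
    val selector = Selector.open()
    val key = channel.register(selector, SelectionKey.OP_READ | SelectionKey.OP_WRITE)
    (channel, selector, key)
  }
}
// --- (end of illustrative sketch)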
*/ - private[spark] def createSimpleWorker(): (Socket, Option[Int]) = { - var serverSocket: ServerSocket = null + private[spark] def createSimpleWorker(blockingMode: Boolean): (PythonWorker, Option[Int]) = { + var serverSocketChannel: ServerSocketChannel = null try { - serverSocket = new ServerSocket(0, 1, InetAddress.getLoopbackAddress()) + serverSocketChannel = ServerSocketChannel.open() + serverSocketChannel.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 1) // Create and start the worker val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", workerModule)) @@ -154,38 +172,49 @@ private[spark] class PythonWorkerFactory( workerEnv.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: workerEnv.put("PYTHONUNBUFFERED", "YES") - workerEnv.put("PYTHON_WORKER_FACTORY_PORT", serverSocket.getLocalPort.toString) + workerEnv.put("PYTHON_WORKER_FACTORY_PORT", serverSocketChannel.socket().getLocalPort + .toString) workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret) if (Utils.preferIPv6) { workerEnv.put("SPARK_PREFER_IPV6", "True") } - val worker = pb.start() + val workerProcess = pb.start() // Redirect worker stdout and stderr - redirectStreamsToStderr(worker.getInputStream, worker.getErrorStream) + redirectStreamsToStderr(workerProcess.getInputStream, workerProcess.getErrorStream) // Wait for it to connect to our socket, and validate the auth secret. - serverSocket.setSoTimeout(10000) + serverSocketChannel.socket().setSoTimeout(10000) try { - val socket = serverSocket.accept() - authHelper.authClient(socket) - // TODO: When we drop JDK 8, we can just use worker.pid() - val pid = new DataInputStream(socket.getInputStream).readInt() + val socketChannel = serverSocketChannel.accept() + authHelper.authClient(socketChannel.socket()) + // TODO: When we drop JDK 8, we can just use workerProcess.pid() + val pid = new DataInputStream(Channels.newInputStream(socketChannel)).readInt() if (pid < 0) { throw new IllegalStateException("Python failed to launch worker with code " + pid) } + if (!blockingMode) { + socketChannel.configureBlocking(false) + } + val selector = Selector.open() + val selectionKey = if (blockingMode) { + null + } else { + socketChannel.register(selector, SelectionKey.OP_READ | SelectionKey.OP_WRITE) + } + val worker = PythonWorker(socketChannel, selector, selectionKey) self.synchronized { - simpleWorkers.put(socket, worker) + simpleWorkers.put(worker, workerProcess) } - return (socket, Some(pid)) + return (worker, Some(pid)) } catch { case e: Exception => throw new SparkException("Python worker failed to connect back.", e) } } finally { - if (serverSocket != null) { - serverSocket.close() + if (serverSocketChannel != null) { + serverSocketChannel.close() } } null @@ -320,11 +349,10 @@ private[spark] class PythonWorkerFactory( while (idleWorkers.nonEmpty) { val worker = idleWorkers.dequeue() try { - // the worker will exit after closing the socket - worker.close() + worker.stop() } catch { case e: Exception => - logWarning("Failed to close worker socket", e) + logWarning("Failed to stop worker socket", e) } } } @@ -351,7 +379,7 @@ private[spark] class PythonWorkerFactory( stopDaemon() } - def stopWorker(worker: Socket): Unit = { + def stopWorker(worker: PythonWorker): Unit = { self.synchronized { if (useDaemon) { if (daemon != null) { @@ -367,22 +395,21 @@ private[spark] class PythonWorkerFactory( simpleWorkers.get(worker).foreach(_.destroy()) } } - worker.close() + worker.stop() } - def 
releaseWorker(worker: Socket): Unit = { + def releaseWorker(worker: PythonWorker): Unit = { if (useDaemon) { self.synchronized { lastActivityNs = System.nanoTime() idleWorkers.enqueue(worker) } } else { - // Cleanup the worker socket. This will also cause the Python worker to exit. try { - worker.close() + worker.stop() } catch { case e: Exception => - logWarning("Failed to close worker socket", e) + logWarning("Failed to close worker", e) } } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala index b6ab031d388b2..3f7b11a40ada1 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala @@ -18,7 +18,6 @@ package org.apache.spark.api.python import java.io.{DataInputStream, DataOutputStream, File} -import java.net.Socket import java.nio.charset.StandardCharsets import org.apache.spark.{SparkEnv, SparkFiles} @@ -76,7 +75,7 @@ private[spark] object PythonWorkerUtils extends Logging { */ def writeBroadcasts( broadcastVars: Seq[Broadcast[PythonBroadcast]], - worker: Socket, + worker: PythonWorker, env: SparkEnv, dataOut: DataOutputStream): Unit = { // Broadcast variables @@ -117,9 +116,6 @@ private[spark] object PythonWorkerUtils extends Logging { dataOut.writeLong(id) } dataOut.flush() - logTrace("waiting for python to read decrypted broadcast data from server") - server.waitTillBroadcastDataSent() - logTrace("done sending decrypted data to python") } else { sendBidsToRemove() for (broadcast <- broadcastVars) { diff --git a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala index fdfe388db2d40..e82052e41be1a 100644 --- a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala @@ -18,7 +18,6 @@ package org.apache.spark.api.python import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream} -import java.net.Socket import scala.collection.JavaConverters._ @@ -50,7 +49,8 @@ private[spark] class StreamingPythonRunner( private val envVars: java.util.Map[String, String] = func.envVars private val pythonExec: String = func.pythonExec - private var pythonWorker: Option[Socket] = None + private var pythonWorker: Option[PythonWorker] = None + private var pythonWorkerFactory: Option[PythonWorkerFactory] = None protected val pythonVer: String = func.pythonVer /** @@ -71,14 +71,17 @@ private[spark] class StreamingPythonRunner( val prevConf = conf.get(PYTHON_USE_DAEMON) conf.set(PYTHON_USE_DAEMON, false) try { - val (worker, _) = env.createPythonWorker( - pythonExec, workerModule, envVars.asScala.toMap) + val workerFactory = + new PythonWorkerFactory(pythonExec, workerModule, envVars.asScala.toMap) + val (worker: PythonWorker, _) = workerFactory.createSimpleWorker(blockingMode = true) pythonWorker = Some(worker) + pythonWorkerFactory = Some(workerFactory) } finally { conf.set(PYTHON_USE_DAEMON, prevConf) } - val stream = new BufferedOutputStream(pythonWorker.get.getOutputStream, bufferSize) + val stream = new BufferedOutputStream( + pythonWorker.get.channel.socket().getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) PythonWorkerUtils.writePythonVersion(pythonVer, dataOut) @@ -93,7 +96,7 @@ private[spark] class StreamingPythonRunner( 
dataOut.flush() val dataIn = new DataInputStream( - new BufferedInputStream(pythonWorker.get.getInputStream, bufferSize)) + new BufferedInputStream(pythonWorker.get.channel.socket().getInputStream, bufferSize)) val resFromPython = dataIn.readInt() logInfo(s"Runner initialization returned $resFromPython") diff --git a/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala index 8230144025feb..5f2a9dd2743c6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala +++ b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala @@ -55,6 +55,14 @@ private[spark] object InputFileBlockHolder { new AtomicReference(new FileBlock) } + private[spark] def setThreadLocalValue(ref: Object): Unit = { + inputBlock.set(ref.asInstanceOf[AtomicReference[FileBlock]]) + } + + private[spark] def getThreadLocalValue(): Object = { + inputBlock.get() + } + /** * Returns the holding file name or empty string if it is unknown. */ @@ -72,6 +80,9 @@ private[spark] object InputFileBlockHolder { /** * Sets the thread-local input block. + * + * Callers of this method must ensure a task completion listener has been registered to unset() + * the thread local in the task thread. */ def set(filePath: String, startOffset: Long, length: Long): Unit = { require(filePath != null, "filePath cannot be null") diff --git a/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala new file mode 100644 index 0000000000000..a4145bb36acc9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.OutputStream +import java.nio.ByteBuffer + +import org.apache.spark.storage.StorageUtils +import org.apache.spark.unsafe.Platform + +/** + * An output stream that dumps data into a direct byte buffer. The byte buffer grows in size + * as more data is written to the stream. 
+ * @param capacity The initial capacity of the direct byte buffer + */ +private[spark] class DirectByteBufferOutputStream(capacity: Int) extends OutputStream { + private var buffer = Platform.allocateDirectBuffer(capacity) + + def this() = this(32) + + override def write(b: Int): Unit = { + ensureCapacity(buffer.position() + 1) + buffer.put(b.toByte) + } + + override def write(b: Array[Byte], off: Int, len: Int): Unit = { + ensureCapacity(buffer.position() + len) + buffer.put(b, off, len) + } + + private def ensureCapacity(minCapacity: Int): Unit = { + if (minCapacity > buffer.capacity()) grow(minCapacity) + } + + /** + * Grows the current buffer to at least `minCapacity` capacity. + * As a side effect, all references to the old buffer will be invalidated. + */ + private def grow(minCapacity: Int): Unit = { + val oldCapacity = buffer.capacity() + var newCapacity = oldCapacity << 1 + if (newCapacity < minCapacity) newCapacity = minCapacity + val oldBuffer = buffer + oldBuffer.flip() + val newBuffer = ByteBuffer.allocateDirect(newCapacity) + newBuffer.put(oldBuffer) + StorageUtils.dispose(oldBuffer) + buffer = newBuffer + } + + def reset(): Unit = buffer.clear() + + def size(): Int = buffer.position() + + /** + * Any subsequent call to [[close()]], [[write()]], [[reset()]] will invalidate the buffer + * returned by this method. + */ + def toByteBuffer: ByteBuffer = { + val outputBuffer = buffer.duplicate() + outputBuffer.flip() + outputBuffer + } + + override def close(): Unit = { + // Eagerly free the direct byte buffer without waiting for GC to reduce memory pressure. + StorageUtils.dispose(buffer) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala index d4c535fe76a3e..a60d0beeeed54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -55,7 +55,7 @@ class ApplyInPandasWithStatePythonRunner( evalType: Int, argOffsets: Array[Array[Int]], inputSchema: StructType, - override protected val timeZoneId: String, + _timeZoneId: String, initialWorkerConf: Map[String, String], stateEncoder: ExpressionEncoder[Row], keySchema: StructType, @@ -73,8 +73,10 @@ class ApplyInPandasWithStatePythonRunner( private val sqlConf = SQLConf.get - override protected val schema: StructType = inputSchema.add("__state", STATE_METADATA_SCHEMA) - + // Use lazy val to initialize the fields before these are accessed in [[PythonArrowInput]]'s + // constructor. + override protected lazy val schema: StructType = inputSchema.add("__state", STATE_METADATA_SCHEMA) + override protected lazy val timeZoneId: String = _timeZoneId override val errorOnDuplicatedFieldNames: Boolean = true override val simplifiedTraceback: Boolean = sqlConf.pysparkSimplifiedTraceback @@ -113,37 +115,41 @@ class ApplyInPandasWithStatePythonRunner( // Also write the schema for state value PythonRDD.writeUTF(stateValueSchema.json, stream) } - + private var pandasWriter: ApplyInPandasWithStateWriter = _ /** * Read the (key, state, values) from input iterator and construct Arrow RecordBatches, and * write constructed RecordBatches to the writer. * * See [[ApplyInPandasWithStateWriter]] for more details. 
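// --- Illustrative usage of the DirectByteBufferOutputStream added above (not part of the patch):
// fill the direct buffer through a DataOutputStream, flip it with toByteBuffer, drain it into a
// channel, then reset() and repeat. `writeOneRecord` is a hypothetical callback that returns false
// once the input is exhausted; a blocking channel is assumed for brevity.
import java.io.DataOutputStream
import java.nio.channels.SocketChannel
import org.apache.spark.util.DirectByteBufferOutputStream

object DirectBufferDrainSketch {
  def fillAndDrain(channel: SocketChannel, writeOneRecord: DataOutputStream => Boolean): Unit = {
    val bufferStream = new DirectByteBufferOutputStream()
    val dataOut = new DataOutputStream(bufferStream)
    var more = true
    while (more) {
      bufferStream.reset()                          // reuse the direct buffer between rounds
      more = writeOneRecord(dataOut)                // append one record to the buffer
      val buf = bufferStream.toByteBuffer           // flipped view of everything written so far
      while (buf.hasRemaining) channel.write(buf)   // drain the buffer to the socket
    }
    bufferStream.close()                            // eagerly free the direct buffer
  }
}
// --- (end of illustrative sketch)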
*/ - protected def writeIteratorToArrowStream( + protected def writeNextInputToArrowStream( root: VectorSchemaRoot, writer: ArrowStreamWriter, dataOut: DataOutputStream, - inputIterator: Iterator[InType]): Unit = { - val w = new ApplyInPandasWithStateWriter(root, writer, arrowMaxRecordsPerBatch) - - while (inputIterator.hasNext) { + inputIterator: Iterator[InType]): Boolean = { + if (pandasWriter == null) { + pandasWriter = new ApplyInPandasWithStateWriter(root, writer, arrowMaxRecordsPerBatch) + } + if (inputIterator.hasNext) { val startData = dataOut.size() val (keyRow, groupState, dataIter) = inputIterator.next() assert(dataIter.hasNext, "should have at least one data row!") - w.startNewGroup(keyRow, groupState) + pandasWriter.startNewGroup(keyRow, groupState) while (dataIter.hasNext) { val dataRow = dataIter.next() - w.writeRow(dataRow) + pandasWriter.writeRow(dataRow) } - w.finalizeGroup() + pandasWriter.finalizeGroup() val deltaData = dataOut.size() - startData pythonMetrics("pythonDataSent") += deltaData + true + } else { + pandasWriter.finalizeData() + super[PythonArrowInput].close() + false } - - w.finalizeData() } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index d9bce96c47768..0f26d8f21f8d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -31,8 +31,8 @@ class ArrowPythonRunner( funcs: Seq[ChainedPythonFunctions], evalType: Int, argOffsets: Array[Array[Int]], - protected override val schema: StructType, - protected override val timeZoneId: String, + _schema: StructType, + _timeZoneId: String, protected override val largeVarTypes: Boolean, protected override val workerConf: Map[String, String], val pythonMetrics: Map[String, SQLMetric], @@ -50,6 +50,10 @@ class ArrowPythonRunner( override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback + // Use lazy val to initialize the fields before these are accessed in [[PythonArrowInput]]'s + // constructor. 
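// --- Illustrative sketch (not part of the patch) of the initialization-order issue the comment
// above refers to: a trait's body runs before the subclass constructor, so a plain val override is
// still null when the trait reads it, while a lazy val override is computed on first access. The
// trait and classes below are hypothetical stand-ins for PythonArrowInput and its runners.
trait ArrowInputLike {
  protected val timeZoneId: String
  // Evaluated while the trait initializes, i.e. before the subclass assigns its plain vals.
  private val normalizedTz: String = timeZoneId.toUpperCase
  def describe(): String = normalizedTz
}

class EagerRunner extends ArrowInputLike {
  // Assigned only after ArrowInputLike has initialized: the trait sees null and throws an NPE.
  override protected val timeZoneId: String = "UTC"
}

class LazyRunner extends ArrowInputLike {
  // Computed on first access, so the trait's initializer already sees "UTC".
  override protected lazy val timeZoneId: String = "UTC"
}

object InitOrderDemo {
  def main(args: Array[String]): Unit = {
    println(new LazyRunner().describe())   // prints "UTC"
    // new EagerRunner()                   // would throw NullPointerException
  }
}
// --- (end of illustrative sketch)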
+ override protected lazy val timeZoneId: String = _timeZoneId + override protected lazy val schema: StructType = _schema override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize require( bufferSize >= 4, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala index 9dae874e3ed96..6c8412f8b3770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala @@ -18,14 +18,13 @@ package org.apache.spark.sql.execution.python import java.io.DataOutputStream -import java.net.Socket import scala.collection.JavaConverters._ import net.razorvine.pickle.Unpickler import org.apache.spark.{JobArtifactSet, SparkEnv, TaskContext} -import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonWorkerUtils} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonWorker, PythonWorkerUtils} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.GenericArrayData @@ -101,13 +100,13 @@ class PythonUDTFRunner( Seq(ChainedPythonFunctions(Seq(udtf.func))), PythonEvalType.SQL_TABLE_UDF, Array(argOffsets), pythonMetrics, jobArtifactUUID) { - protected override def newWriterThread( + protected override def newWriter( env: SparkEnv, - worker: Socket, + worker: PythonWorker, inputIterator: Iterator[Array[Byte]], partitionIndex: Int, - context: TaskContext): WriterThread = { - new PythonUDFWriterThread(env, worker, inputIterator, partitionIndex, context) { + context: TaskContext): Writer = { + new PythonUDFWriter(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { PythonUDTFRunner.writeUDTF(dataOut, udtf, argOffsets) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala index eef8be7c940b0..bd901545bb03c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala @@ -18,13 +18,12 @@ package org.apache.spark.sql.execution.python import java.io.DataOutputStream -import java.net.Socket import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.ipc.ArrowStreamWriter import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD} +import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD, PythonWorker} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.ArrowWriter import org.apache.spark.sql.execution.metric.SQLMetric @@ -60,14 +59,14 @@ class CoGroupedArrowPythonRunner( override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback - protected def newWriterThread( + protected def newWriter( env: SparkEnv, - worker: Socket, + worker: PythonWorker, inputIterator: Iterator[(Iterator[InternalRow], Iterator[InternalRow])], partitionIndex: Int, - context: TaskContext): WriterThread = { + context: TaskContext): Writer = { - new WriterThread(env, worker, 
inputIterator, partitionIndex, context) { + new Writer(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { @@ -81,10 +80,10 @@ class CoGroupedArrowPythonRunner( PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) } - protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = { // For each we first send the number of dataframes in each group then send // first df, then send second df. End of data is marked by sending 0. - while (inputIterator.hasNext) { + if (inputIterator.hasNext) { val startData = dataOut.size() dataOut.writeInt(2) val (nextLeft, nextRight) = inputIterator.next() @@ -93,8 +92,11 @@ class CoGroupedArrowPythonRunner( val deltaData = dataOut.size() - startData pythonMetrics("pythonDataSent") += deltaData + true + } else { + dataOut.writeInt(0) + false } - dataOut.writeInt(0) } private def writeGroup( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala index 10bb3a45be94a..373e17c0aa383 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala @@ -21,7 +21,7 @@ import java.io.File import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{ContextAwareIterator, PartitionEvaluator, PartitionEvaluatorFactory, SparkEnv, TaskContext} +import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory, SparkEnv, TaskContext} import org.apache.spark.api.python.ChainedPythonFunctions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -62,7 +62,6 @@ abstract class EvalPythonEvaluatorFactory( iters: Iterator[InternalRow]*): Iterator[InternalRow] = { val iter = iters.head val context = TaskContext.get() - val contextAwareIterator = new ContextAwareIterator(context, iter) // The queue used to buffer input rows so we can drain it to // combine input with output from Python. @@ -97,7 +96,7 @@ abstract class EvalPythonEvaluatorFactory( }.toArray) // Add rows to queue to join later with the result. 
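// --- Illustrative only (not part of the patch): the buffer-and-rejoin pattern the comments above
// describe. Each input row is stashed in a queue as a side effect of projecting it for Python, and
// is joined back with the corresponding Python output afterwards. mutable.Queue stands in for
// HybridRowQueue, `runPython` for the Python runner; one output per input is assumed.
import scala.collection.mutable

object TeeAndJoinSketch {
  def teeAndJoin[A, B, C](
      input: Iterator[A],
      runPython: Iterator[A] => Iterator[B])(join: (A, B) => C): Iterator[C] = {
    val queue = mutable.Queue.empty[A]
    // Enqueue lazily, as the runner pulls each input row.
    val projected = input.map { row => queue.enqueue(row); row }
    val results = runPython(projected)
    // The runner consumes inputs ahead of producing their outputs, so the head of the queue
    // always corresponds to the next result.
    results.map(result => join(queue.dequeue(), result))
  }
}
// --- (end of illustrative sketch)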
- val projectedRowIter = contextAwareIterator.map { inputRow => + val projectedRowIter = iter.map { inputRow => queue.add(inputRow.asInstanceOf[UnsafeRow]) projection(inputRow) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 8d2f788e05cc7..6664acf957263 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -24,7 +24,6 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle.{IObjectPickler, Opcodes, Pickler} -import org.apache.spark.{ContextAwareIterator, TaskContext} import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -302,7 +301,7 @@ object EvaluatePython { def javaToPython(rdd: RDD[Any]): RDD[Array[Byte]] = { rdd.mapPartitions { iter => registerPicklers() // let it called in executor - new SerDeUtil.AutoBatchedPickler(new ContextAwareIterator(TaskContext.get, iter)) + new SerDeUtil.AutoBatchedPickler(iter) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala index 1e15aa7f777bf..6f501e1411ac0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.python import scala.collection.JavaConverters._ -import org.apache.spark.{ContextAwareIterator, PartitionEvaluator, PartitionEvaluatorFactory, TaskContext} +import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory, TaskContext} import org.apache.spark.api.python.ChainedPythonFunctions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -52,11 +52,10 @@ class MapInBatchEvaluatorFactory( // Single function with one struct. val argOffsets = Array(Array(0)) val context = TaskContext.get() - val contextAwareIterator = new ContextAwareIterator(context, inputIter) // Here we wrap it via another row so that Python sides understand it // as a DataFrame. - val wrappedIter = contextAwareIterator.map(InternalRow(_)) + val wrappedIter = inputIter.map(InternalRow(_)) // DO NOT use iter.grouped(). See BatchIterator. 
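// --- Illustrative sketch (not part of the patch) around the comment above: Iterator.grouped()
// materializes each batch into a Seq before handing it out, whereas a BatchIterator-style wrapper
// can yield sub-iterators that stream rows one at a time. This simplified version assumes each
// batch is fully consumed before the next one is requested.
class LazyBatchIterator[T](underlying: Iterator[T], batchSize: Int) extends Iterator[Iterator[T]] {
  require(batchSize > 0, "batchSize must be positive")

  override def hasNext: Boolean = underlying.hasNext

  override def next(): Iterator[T] = new Iterator[T] {
    private var consumed = 0
    override def hasNext: Boolean = consumed < batchSize && underlying.hasNext
    override def next(): T = {
      consumed += 1
      underlying.next()  // rows are pulled from the source lazily, never buffered
    }
  }
}
// --- (end of illustrative sketch)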
val batchIter = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala index 5c99a3f9808bb..00ee3a175631a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql.execution.python import java.io.DataOutputStream -import java.net.Socket import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.ipc.ArrowStreamWriter import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD} +import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD, PythonWorker} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.arrow import org.apache.spark.sql.execution.arrow.ArrowWriter import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.StructType @@ -48,11 +48,11 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => protected def pythonMetrics: Map[String, SQLMetric] - protected def writeIteratorToArrowStream( + protected def writeNextInputToArrowStream( root: VectorSchemaRoot, writer: ArrowStreamWriter, dataOut: DataOutputStream, - inputIterator: Iterator[IN]): Unit + inputIterator: Iterator[IN]): Boolean protected def writeUDF( dataOut: DataOutputStream, @@ -68,51 +68,46 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => PythonRDD.writeUTF(v, stream) } } + private val arrowSchema = ArrowUtils.toArrowSchema( + schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes) + private val allocator = + ArrowUtils.rootAllocator.newChildAllocator(s"stdout writer for $pythonExec", 0, Long.MaxValue) + protected val root = VectorSchemaRoot.create(arrowSchema, allocator) + protected var writer: ArrowStreamWriter = _ + +protected def close(): Unit = { + Utils.tryWithSafeFinally { + // end writes footer to the output stream and doesn't clean any resources. + // It could throw exception if the output stream is closed, so it should be + // in the try block. 
+ writer.end() + } { + root.close() + allocator.close() + } +} - protected override def newWriterThread( + protected override def newWriter( env: SparkEnv, - worker: Socket, + worker: PythonWorker, inputIterator: Iterator[IN], partitionIndex: Int, - context: TaskContext): WriterThread = { - new WriterThread(env, worker, inputIterator, partitionIndex, context) { - + context: TaskContext): Writer = { + new Writer(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { handleMetadataBeforeExec(dataOut) writeUDF(dataOut, funcs, argOffsets) } - protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { - val arrowSchema = ArrowUtils.toArrowSchema( - schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes) - val allocator = ArrowUtils.rootAllocator.newChildAllocator( - s"stdout writer for $pythonExec", 0, Long.MaxValue) - val root = VectorSchemaRoot.create(arrowSchema, allocator) + override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = { - Utils.tryWithSafeFinally { - val writer = new ArrowStreamWriter(root, null, dataOut) + if (writer == null) { + writer = new ArrowStreamWriter(root, null, dataOut) writer.start() - - writeIteratorToArrowStream(root, writer, dataOut, inputIterator) - - // end writes footer to the output stream and doesn't clean any resources. - // It could throw exception if the output stream is closed, so it should be - // in the try block. - writer.end() - } { - // If we close root and allocator in TaskCompletionListener, there could be a race - // condition where the writer thread keeps writing to the VectorSchemaRoot while - // it's being closed by the TaskCompletion listener. - // Closing root and allocator here is cleaner because root and allocator is owned - // by the writer thread and is only visible to the writer thread. - // - // If the writer thread is interrupted by TaskCompletionListener, it should either - // (1) in the try block, in which case it will get an InterruptedException when - // performing io, and goes into the finally block or (2) in the finally block, - // in which case it will ignore the interruption and close the resources. 
- root.close() - allocator.close() } + + assert(writer != null) + writeNextInputToArrowStream(root, writer, dataOut, inputIterator) } } } @@ -120,15 +115,15 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[InternalRow]] { self: BasePythonRunner[Iterator[InternalRow], _] => + private val arrowWriter: arrow.ArrowWriter = ArrowWriter.create(root) - protected def writeIteratorToArrowStream( + protected def writeNextInputToArrowStream( root: VectorSchemaRoot, writer: ArrowStreamWriter, dataOut: DataOutputStream, - inputIterator: Iterator[Iterator[InternalRow]]): Unit = { - val arrowWriter = ArrowWriter.create(root) + inputIterator: Iterator[Iterator[InternalRow]]): Boolean = { - while (inputIterator.hasNext) { + if (inputIterator.hasNext) { val startData = dataOut.size() val nextBatch = inputIterator.next() @@ -141,6 +136,10 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In arrowWriter.reset() val deltaData = dataOut.size() - startData pythonMetrics("pythonDataSent") += deltaData + true + } else { + super[PythonArrowInput].close() + false } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala index c12c690f776a3..8f99325e4e08c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.python import java.io.DataInputStream -import java.net.Socket import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ @@ -26,7 +25,7 @@ import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.ipc.ArrowStreamReader import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.api.python.{BasePythonRunner, SpecialLengths} +import org.apache.spark.api.python.{BasePythonRunner, PythonWorker, SpecialLengths} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils @@ -46,16 +45,16 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[ protected def newReaderIterator( stream: DataInputStream, - writerThread: WriterThread, + writer: Writer, startTime: Long, env: SparkEnv, - worker: Socket, + worker: PythonWorker, pid: Option[Int], releasedOrClosed: AtomicBoolean, context: TaskContext): Iterator[OUT] = { new ReaderIterator( - stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) { + stream, writer, startTime, env, worker, pid, releasedOrClosed, context) { private val allocator = ArrowUtils.rootAllocator.newChildAllocator( s"stdin reader for $pythonExec", 0, Long.MaxValue) @@ -80,8 +79,8 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[ } protected override def read(): OUT = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get + if (writer.exception.isDefined) { + throw writer.exception.get } try { if (reader != null && batchLoaded) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala index 3857f084bcb0b..a229931cec89e 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala @@ -31,6 +31,44 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.util.{NextIterator, Utils} +/** + * Writes the rows buffered in [[UnsafeRowBuffer]] to the Python worker. + * Any exceptions encountered will be cached to be read later by the parent thread. + */ +class WriterThread(outputIterator: Iterator[Array[Byte]]) + extends Thread(s"Thread streaming data to the Python worker") { + + @volatile var _exception: Throwable = _ + + override def run(): Unit = { + try { + // [[PythonForeachWriter]] is a sink and thus the Python worker does not generate any output. + // The `hasNext()` and `next()` calls are an indirect way to ship the input data to the + // Python worker. Consuming the Python worker's output iterator, as a side-effect, drives the + // write of the input data to the Python worker through [[org.apache.spark.api.python. + // BasePythonRunner.ReaderInputStream .writeAdditionalInputToPythonWorker]]. + if (outputIterator.hasNext) { + outputIterator.next() + } + } catch { + // Cache exceptions seen while evaluating the Python function on the streamed input. The + // parent thread will throw this cached exception eventually. + case t: Throwable => + _exception = t + } + } +} + +/** + * The class proceeds as follows: + * - Rows streamed through a `process()` call on the + * [[org.apache.spark.sql.execution.streaming.QueryExecutionThread]] are buffered in the + * `UnsafeRowBuffer`. + * - The [[WriterThread]] streams the buffered data to the Python worker. + * - Once the streaming query ends, [[close()]] is called, which signals the buffer to mark the + * end of streaming input. The streaming query execution thread waits for the [[WriterThread]] to + * complete and throws any exceptions seen by the [[WriterThread]]. + */ class PythonForeachWriter(func: PythonFunction, schema: StructType) extends ForeachWriter[UnsafeRow] { @@ -58,8 +96,11 @@ class PythonForeachWriter(func: PythonFunction, schema: StructType) private lazy val outputIterator = pythonRunner.compute(inputByteIterator, context.partitionId(), context) + private lazy val writerThread = new WriterThread(outputIterator) + override def open(partitionId: Long, version: Long): Boolean = { outputIterator // initialize everything + writerThread.start() TaskContext.get.addTaskCompletionListener[Unit] { _ => buffer.close() } true } @@ -68,9 +109,15 @@ class PythonForeachWriter(func: PythonFunction, schema: StructType) buffer.add(value) } + /** + * Waits for the writer thread to finish evaluating the Python function. Throws any exceptions + * seen by the writer thread. + */ override def close(errorOrNull: Throwable): Unit = { buffer.allRowsAdded() - if (outputIterator.hasNext) outputIterator.next() // to throw python exception if there was one + writerThread.join() + // Throw Python exception if there was one. + if (writerThread._exception != null) throw writerThread._exception } } @@ -78,18 +125,20 @@ object PythonForeachWriter { /** * A buffer that is designed for the sole purpose of buffering UnsafeRows in PythonForeachWriter. - * It is designed to be used with only 1 writer thread (i.e. JVM task thread) and only 1 reader - * thread (i.e. PythonRunner writing thread that reads from the buffer and writes to the Python - * worker stdin).
Adds to the buffer are non-blocking, and reads through the buffer's iterator - are blocking, that is, it blocks until new data is available or all data has been added. + * It is designed to be used with only two threads: the QueryExecutionThread, which writes data + * to the buffer, and the [[WriterThread]], which reads from the buffer and writes to the + * Python worker stdin. Adds to the buffer are non-blocking, and reads through the buffer's + * iterator are blocking, that is, it blocks until new data is available or all data has been + * added. * * Internally, it uses a [[HybridRowQueue]] to buffer the rows in a practically unlimited queue * across memory and local disk. However, HybridRowQueue is designed to be used only with - EvalPythonExec where the reader is always behind the writer, that is, the reader does not - try to read n+1 rows if the writer has only written n rows at any point of time. This - assumption is not true for PythonForeachWriter where rows may be added at a different rate as - they are consumed by the python worker. Hence, to maintain the invariant of the reader being - behind the writer while using HybridRowQueue, the buffer does the following + * EvalPythonExec where the buffer's consumer is always behind the buffer's producer, that is, + * the [[WriterThread]] does not try to read n + 1 rows if the streaming thread has only + * written n rows at any point in time. This assumption is not true for PythonForeachWriter + * where rows may be added at a different rate as they are consumed by the Python worker. + * Hence, to maintain the invariant of the reader being behind the writer while using + * HybridRowQueue, the buffer does the following: - Keeps a count of the rows in the HybridRowQueue - Blocks the buffer's consuming iterator when the count is 0 so that the reader does not - try to read more rows than what has been written. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index 22083e0473b7d..bc27ee6919dfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.python import java.io._ -import java.net._ import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark._ @@ -44,40 +43,42 @@ abstract class BasePythonUDFRunner( override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback - abstract class PythonUDFWriterThread( + abstract class PythonUDFWriter( env: SparkEnv, - worker: Socket, + worker: PythonWorker, inputIterator: Iterator[Array[Byte]], partitionIndex: Int, context: TaskContext) - extends WriterThread(env, worker, inputIterator, partitionIndex, context) { + extends Writer(env, worker, inputIterator, partitionIndex, context) { - protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = { val startData = dataOut.size() - - PythonRDD.writeIteratorToStream(inputIterator, dataOut) - dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) - + val wroteData = PythonRDD.writeNextElementToStream(inputIterator, dataOut) + if (!wroteData) { + // Reached the end of input.
+ dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + } val deltaData = dataOut.size() - startData pythonMetrics("pythonDataSent") += deltaData + wroteData } } protected override def newReaderIterator( stream: DataInputStream, - writerThread: WriterThread, + writer: Writer, startTime: Long, env: SparkEnv, - worker: Socket, + worker: PythonWorker, pid: Option[Int], releasedOrClosed: AtomicBoolean, context: TaskContext): Iterator[Array[Byte]] = { new ReaderIterator( - stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) { + stream, writer, startTime, env, worker, pid, releasedOrClosed, context) { protected override def read(): Array[Byte] = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get + if (writer.exception.isDefined) { + throw writer.exception.get } try { stream.readInt() match { @@ -110,13 +111,13 @@ class PythonUDFRunner( jobArtifactUUID: Option[String]) extends BasePythonUDFRunner(funcs, evalType, argOffsets, pythonMetrics, jobArtifactUUID) { - protected override def newWriterThread( + protected override def newWriter( env: SparkEnv, - worker: Socket, + worker: PythonWorker, inputIterator: Iterator[Array[Byte]], partitionIndex: Int, - context: TaskContext): WriterThread = { - new PythonUDFWriterThread(env, worker, inputIterator, partitionIndex, context) { + context: TaskContext): Writer = { + new PythonUDFWriter(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 36cb2e17835a4..5fa9c89b3d15b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.execution.python -import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException} -import java.net.Socket +import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException, InputStream} +import java.nio.ByteBuffer +import java.nio.channels.SelectionKey import java.nio.charset.StandardCharsets import java.util.HashMap @@ -27,7 +28,7 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle.Pickler import org.apache.spark.{JobArtifactSet, SparkEnv, SparkException} -import org.apache.spark.api.python.{PythonEvalType, PythonFunction, PythonWorkerUtils, SpecialLengths} +import org.apache.spark.api.python.{PythonEvalType, PythonFunction, PythonWorker, PythonWorkerUtils, SpecialLengths} import org.apache.spark.internal.config.BUFFER_SIZE import org.apache.spark.internal.config.Python._ import org.apache.spark.sql.{Column, DataFrame, Dataset, SparkSession} @@ -36,6 +37,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, OneRo import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.DirectByteBufferOutputStream /** * A user-defined Python function. This is used by the Python API. 
@@ -205,13 +207,14 @@ val pickler = new Pickler(/* useMemo = */ true, /* valueCompare = */ false) - val (worker: Socket, _) = + val (worker: PythonWorker, _) = env.createPythonWorker(pythonExec, workerModule, envVars.asScala.toMap) var releasedOrClosed = false + val bufferStream = new DirectByteBufferOutputStream() try { - val dataOut = - new DataOutputStream(new BufferedOutputStream(worker.getOutputStream, bufferSize)) - val dataIn = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) + val dataOut = new DataOutputStream(new BufferedOutputStream(bufferStream, bufferSize)) + val dataIn = new DataInputStream(new BufferedInputStream( + new WorkerInputStream(worker, bufferStream), bufferSize)) PythonWorkerUtils.writePythonVersion(pythonVer, dataOut) PythonWorkerUtils.writeSparkFiles(jobArtifactUUID, pythonIncludes, dataOut) @@ -276,4 +279,50 @@ object UserDefinedPythonTableFunction { } } } + + /** + * A wrapper around the non-blocking IO used to write to/read from the worker. + * + * Since we use non-blocking IO to communicate with workers (see SPARK-44705), + * a wrapper is needed to do IO with the worker. + * This is a ported and simplified version of `PythonRunner.ReaderInputStream`, + * and it only supports writing all the input at once and then reading all the output. + */ + private class WorkerInputStream( + worker: PythonWorker, bufferStream: DirectByteBufferOutputStream) extends InputStream { + + private[this] val temp = new Array[Byte](1) + + override def read(): Int = { + val n = read(temp) + if (n <= 0) { + -1 + } else { + // Signed byte to unsigned integer + temp(0) & 0xff + } + } + + override def read(b: Array[Byte], off: Int, len: Int): Int = { + val buf = ByteBuffer.wrap(b, off, len) + var n = 0 + while (n == 0) { + worker.selector.select() + if (worker.selectionKey.isReadable) { + n = worker.channel.read(buf) + } + if (worker.selectionKey.isWritable) { + val buffer = bufferStream.toByteBuffer + var acceptsInput = true + while (acceptsInput && buffer.hasRemaining) { + val written = worker.channel.write(buffer) + acceptsInput = written > 0 + } + // We no longer have any data to write to the socket. + worker.selectionKey.interestOps(SelectionKey.OP_READ) + } + } + n + } + } }
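The writer-side hunks above (PythonRDD.writeNextElementToStream, PythonArrowInput.writeNextInputToArrowStream, BasePythonUDFRunner.writeNextInputToStream) all move from "write the whole iterator, then return" to "write at most one element per call and report whether anything was written", so the runner can interleave writes with reads over the non-blocking channel. A minimal sketch of that Boolean contract follows; the object name, the length-prefixed framing, and the -1 end marker are illustrative stand-ins, not Spark's actual wire protocol.

import java.io.{ByteArrayOutputStream, DataOutputStream}

object IncrementalWriteSketch {
  // Writes the next element of `iter` to `dataOut`; returns false once the iterator is
  // exhausted, mirroring the Boolean contract introduced by the patch.
  def writeNextElement(iter: Iterator[Array[Byte]], dataOut: DataOutputStream): Boolean = {
    if (iter.hasNext) {
      val bytes = iter.next()
      dataOut.writeInt(bytes.length)
      dataOut.write(bytes)
      true
    } else {
      false
    }
  }

  def main(args: Array[String]): Unit = {
    val bytesOut = new ByteArrayOutputStream()
    val dataOut = new DataOutputStream(bytesOut)
    val input = Iterator(Array[Byte](1, 2, 3), Array[Byte](4, 5))
    // The caller drives the loop one element at a time, so a non-blocking runner could
    // stop between elements to service reads from the worker.
    while (writeNextElement(input, dataOut)) {
      // Nothing.
    }
    dataOut.writeInt(-1) // stand-in for an end-of-data marker such as END_OF_DATA_SECTION
    println(s"wrote ${bytesOut.size()} bytes")
  }
}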
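PythonForeachWriter now drains the runner's output iterator on a dedicated thread and caches any Throwable so that close() can rethrow it after join(). Below is a self-contained sketch of that drain-and-rethrow pattern; the class names are illustrative and a plain iterator stands in for the Python runner's output.

// Minimal sketch of the "drain on a helper thread, rethrow in the parent" pattern;
// the iterator stands in for BasePythonRunner's output iterator.
class DrainingThread(output: Iterator[Array[Byte]]) extends Thread("drain-python-output") {
  @volatile var exception: Throwable = _

  override def run(): Unit = {
    try {
      // Consuming the iterator is what actually pushes the buffered rows to the worker.
      if (output.hasNext) output.next()
    } catch {
      case t: Throwable => exception = t // cached for the parent thread
    }
  }
}

object DrainingThreadSketch {
  def main(args: Array[String]): Unit = {
    val failing = new Iterator[Array[Byte]] {
      override def hasNext: Boolean = true
      override def next(): Array[Byte] = throw new RuntimeException("python udf failed")
    }
    val thread = new DrainingThread(failing)
    thread.start()
    thread.join()
    // The parent surfaces whatever the helper thread saw, mirroring close(errorOrNull).
    if (thread.exception != null) println(s"worker failed: ${thread.exception.getMessage}")
  }
}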
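The UnsafeRowBuffer comment above keeps the consumer behind the producer by counting buffered rows and blocking the consuming iterator while the count is zero. The sketch below illustrates that invariant only; an in-memory queue with a lock and condition stands in for HybridRowQueue, and all names are illustrative.

import java.util.concurrent.locks.ReentrantLock
import scala.collection.mutable

// Sketch of the "consumer never overtakes the producer" invariant: take() blocks while the
// buffered count is zero, and returns None once all rows have been added and drained.
class BlockingRowBuffer[T] {
  private val queue = mutable.Queue.empty[T]
  private val lock = new ReentrantLock()
  private val notEmpty = lock.newCondition()
  private var allAdded = false

  def add(row: T): Unit = {
    lock.lock()
    try { queue.enqueue(row); notEmpty.signal() } finally { lock.unlock() }
  }

  def allRowsAdded(): Unit = {
    lock.lock()
    try { allAdded = true; notEmpty.signalAll() } finally { lock.unlock() }
  }

  def take(): Option[T] = {
    lock.lock()
    try {
      while (queue.isEmpty && !allAdded) notEmpty.await()
      if (queue.nonEmpty) Some(queue.dequeue()) else None
    } finally { lock.unlock() }
  }
}

object BlockingRowBufferSketch {
  def main(args: Array[String]): Unit = {
    val buffer = new BlockingRowBuffer[Int]
    val consumer = new Thread(() => {
      var next = buffer.take()
      while (next.isDefined) { println(s"consumed ${next.get}"); next = buffer.take() }
    })
    consumer.start()
    (1 to 3).foreach(buffer.add)
    buffer.allRowsAdded()
    consumer.join()
  }
}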
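WorkerInputStream above multiplexes a single non-blocking channel: while the selection key is writable it flushes the buffered request bytes and then drops OP_WRITE from the interest set, and read() returns as soon as the channel yields data. The sketch below replays the same select loop against a plain SocketChannel; the endpoint, payload, and buffer size are placeholders, not Spark's worker protocol.

import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.{SelectionKey, Selector, SocketChannel}
import java.nio.charset.StandardCharsets

// Sketch of the select loop used by WorkerInputStream: flush pending output while the
// channel is writable, and return once a read makes progress (or hits EOF).
object SelectLoopSketch {
  def readSome(selector: Selector, key: SelectionKey, channel: SocketChannel,
      pendingOutput: ByteBuffer, dst: ByteBuffer): Int = {
    var n = 0
    while (n == 0) {
      selector.select()
      if (key.isReadable) {
        n = channel.read(dst)
      }
      if (key.isWritable) {
        var acceptsInput = true
        while (acceptsInput && pendingOutput.hasRemaining) {
          acceptsInput = channel.write(pendingOutput) > 0
        }
        // Nothing left to send; from now on only wait for reads.
        key.interestOps(SelectionKey.OP_READ)
      }
    }
    n
  }

  def main(args: Array[String]): Unit = {
    // Placeholder endpoint; connect in blocking mode first, then switch to non-blocking.
    val channel = SocketChannel.open(new InetSocketAddress("localhost", 9999))
    channel.configureBlocking(false)
    val selector = Selector.open()
    val key = channel.register(selector, SelectionKey.OP_READ | SelectionKey.OP_WRITE)
    val request = ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8))
    val response = ByteBuffer.allocate(1024)
    val n = readSome(selector, key, channel, request, response)
    println(s"read $n bytes")
  }
}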