diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala index edbdda8a0bcb6..34ee3a48f8e74 100644 --- a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala @@ -45,8 +45,7 @@ class SparkStatusTracker private[spark] (sc: SparkContext) { */ def getJobIdsForGroup(jobGroup: String): Array[Int] = { jobProgressListener.synchronized { - val jobData = jobProgressListener.jobIdToData.valuesIterator - jobData.filter(_.jobGroup.orNull == jobGroup).map(_.jobId).toArray + jobProgressListener.jobGroupToJobIds.getOrElse(jobGroup, Seq.empty).toArray } } diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 0997507d016f5..9db6fd1ac4dbe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -101,6 +101,8 @@ private[deploy] object DeployMessages { case class RegisterApplication(appDescription: ApplicationDescription) extends DeployMessage + case class UnregisterApplication(appId: String) + case class MasterChangeAcknowledged(appId: String) // Master to AppClient diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 3b729725257ef..4f06d7f96c46e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -157,6 +157,7 @@ private[spark] class AppClient( case StopAppClient => markDead("Application has been stopped.") + master ! UnregisterApplication(appId) sender ! true context.stop(self) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index 536aedb6f9fe9..bc5b293379f2b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -91,7 +91,7 @@ private[deploy] class ApplicationInfo( } } - private[master] val requestedCores = desc.maxCores.getOrElse(defaultCores) + private val requestedCores = desc.maxCores.getOrElse(defaultCores) private[master] def coresLeft: Int = requestedCores - coresGranted @@ -111,6 +111,10 @@ private[deploy] class ApplicationInfo( endTime = System.currentTimeMillis() } + private[master] def isFinished: Boolean = { + state != ApplicationState.WAITING && state != ApplicationState.RUNNING + } + def duration: Long = { if (endTime != -1) { endTime - startTime diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 80506621f4d24..9a5d5877da86d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -339,7 +339,11 @@ private[master] class Master( if (ExecutorState.isFinished(state)) { // Remove this executor from the worker and app logInfo(s"Removing executor ${exec.fullId} because it is $state") - appInfo.removeExecutor(exec) + // If an application has already finished, preserve its + // state to display its information properly on the UI + if (!appInfo.isFinished) { + appInfo.removeExecutor(exec) + } exec.worker.removeExecutor(exec) val normalExit = exitStatus == Some(0) @@ -428,6 +432,10 @@ private[master] class Master( if (canCompleteRecovery) { completeRecovery() } } + case UnregisterApplication(applicationId) => + logInfo(s"Received unregister request from application $applicationId") + idToApp.get(applicationId).foreach(finishApplication) + case DisassociatedEvent(_, address, _) => { // The disconnected client could've been either a worker or an app; remove whichever it was logInfo(s"$address got disassociated, removing it.") diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 46509e39c0f23..45412a35e9a7d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -75,16 +75,12 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { val workers = state.workers.sortBy(_.id) val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers) - val activeAppHeaders = Seq("Application ID", "Name", "Cores in Use", - "Cores Requested", "Memory per Node", "Submitted Time", "User", "State", "Duration") + val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time", + "User", "State", "Duration") val activeApps = state.activeApps.sortBy(_.startTime).reverse - val activeAppsTable = UIUtils.listingTable(activeAppHeaders, activeAppRow, activeApps) - - val completedAppHeaders = Seq("Application ID", "Name", "Cores Requested", "Memory per Node", - "Submitted Time", "User", "State", "Duration") + val activeAppsTable = UIUtils.listingTable(appHeaders, appRow, activeApps) val completedApps = state.completedApps.sortBy(_.endTime).reverse - val completedAppsTable = UIUtils.listingTable(completedAppHeaders, completeAppRow, - completedApps) + val completedAppsTable = UIUtils.listingTable(appHeaders, appRow, completedApps) val driverHeaders = Seq("Submission ID", "Submitted Time", "Worker", "State", "Cores", "Memory", "Main Class") @@ -191,7 +187,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } - private def appRow(app: ApplicationInfo, active: Boolean): Seq[Node] = { + private def appRow(app: ApplicationInfo): Seq[Node] = { val killLink = if (parent.killEnabled && (app.state == ApplicationState.RUNNING || app.state == ApplicationState.WAITING)) { val killLinkUri = s"app/kill?id=${app.id}&terminate=true" @@ -201,7 +197,6 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { (kill) } - {app.id} @@ -210,15 +205,8 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {app.desc.name} - { - if (active) { - - {app.coresGranted} - - } - } - {if (app.requestedCores == Int.MaxValue) "*" else app.requestedCores} + {app.coresGranted} {Utils.megabytesToString(app.desc.memoryPerSlave)} @@ -230,14 +218,6 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } - private def activeAppRow(app: ApplicationInfo): Seq[Node] = { - appRow(app, active = true) - } - - private def completeAppRow(app: ApplicationInfo): Seq[Node] = { - appRow(app, active = false) - } - private def driverRow(driver: DriverInfo): Seq[Node] = { val killLink = if (parent.killEnabled && (driver.state == DriverState.RUNNING || diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 03afc289736bb..71e6e300fec5f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -191,25 +191,23 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { } } // Determine the bucket function in constant time. Requires that buckets are evenly spaced - def fastBucketFunction(min: Double, increment: Double, count: Int)(e: Double): Option[Int] = { + def fastBucketFunction(min: Double, max: Double, count: Int)(e: Double): Option[Int] = { // If our input is not a number unless the increment is also NaN then we fail fast - if (e.isNaN()) { - return None - } - val bucketNumber = (e - min)/(increment) - // We do this rather than buckets.lengthCompare(bucketNumber) - // because Array[Double] fails to override it (for now). - if (bucketNumber > count || bucketNumber < 0) { + if (e.isNaN || e < min || e > max) { None } else { - Some(bucketNumber.toInt.min(count - 1)) + // Compute ratio of e's distance along range to total range first, for better precision + val bucketNumber = (((e - min) / (max - min)) * count).toInt + // should be less than count, but will equal count if e == max, in which case + // it's part of the last end-range-inclusive bucket, so return count-1 + Some(math.min(bucketNumber, count - 1)) } } // Decide which bucket function to pass to histogramPartition. We decide here - // rather than having a general function so that the decission need only be made + // rather than having a general function so that the decision need only be made // once rather than once per shard val bucketFunction = if (evenBuckets) { - fastBucketFunction(buckets(0), buckets(1)-buckets(0), buckets.length-1) _ + fastBucketFunction(buckets.head, buckets.last, buckets.length - 1) _ } else { basicBucketFunction _ } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 6fa1f2c880f7a..132a9ced77700 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -81,9 +81,11 @@ class TaskInfo( def status: String = { if (running) { - "RUNNING" - } else if (gettingResult) { - "GET RESULT" + if (gettingResult) { + "GET RESULT" + } else { + "RUNNING" + } } else if (failed) { "FAILED" } else if (successful) { diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 660df00bc32f5..d0178dfde6935 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -112,6 +112,7 @@ class FileShuffleBlockManager(conf: SparkConf) private val shuffleState = shuffleStates(shuffleId) private var fileGroup: ShuffleFileGroup = null + val openStartTime = System.nanoTime val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) { fileGroup = getUnusedFileGroup() Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => @@ -135,6 +136,9 @@ class FileShuffleBlockManager(conf: SparkConf) blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics) } } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, so should be included in the shuffle write time. + writeMetrics.incShuffleWriteTime(System.nanoTime - openStartTime) override def releaseWriters(success: Boolean) { if (consolidateShuffleFiles) { diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index fa2e617762f55..55ea0f17b156a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -63,6 +63,9 @@ private[spark] class SortShuffleWriter[K, V, C]( sorter.insertAll(records) } + // Don't bother including the time to open the merged output file in the shuffle write time, + // because it just opens a single file, so is typically too fast to measure accurately + // (see SPARK-3570). val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId) val blockId = shuffleBlockManager.consolidateId(dep.shuffleId, mapId) val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 80d66e59132da..1dff09a75d038 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -535,9 +535,14 @@ private[spark] class BlockManager( /* We'll store the bytes in memory if the block's storage level includes * "memory serialized", or if it should be cached as objects in memory * but we only requested its serialized bytes. */ - val copyForMemory = ByteBuffer.allocate(bytes.limit) - copyForMemory.put(bytes) - memoryStore.putBytes(blockId, copyForMemory, level) + memoryStore.putBytes(blockId, bytes.limit, () => { + // https://issues.apache.org/jira/browse/SPARK-6076 + // If the file size is bigger than the free memory, OOM will happen. So if we cannot + // put it into MemoryStore, copyForMemory should not be created. That's why this + // action is put into a `() => ByteBuffer` and created lazily. + val copyForMemory = ByteBuffer.allocate(bytes.limit) + copyForMemory.put(bytes) + }) bytes.rewind() } if (!asBlockResult) { @@ -991,15 +996,23 @@ private[spark] class BlockManager( putIterator(blockId, Iterator(value), level, tellMaster) } + def dropFromMemory( + blockId: BlockId, + data: Either[Array[Any], ByteBuffer]): Option[BlockStatus] = { + dropFromMemory(blockId, () => data) + } + /** * Drop a block from memory, possibly putting it on disk if applicable. Called when the memory * store reaches its limit and needs to free up space. * + * If `data` is not put on disk, it won't be created. + * * Return the block status if the given block has been updated, else None. */ def dropFromMemory( blockId: BlockId, - data: Either[Array[Any], ByteBuffer]): Option[BlockStatus] = { + data: () => Either[Array[Any], ByteBuffer]): Option[BlockStatus] = { logInfo(s"Dropping block $blockId from memory") val info = blockInfo.get(blockId).orNull @@ -1023,7 +1036,7 @@ private[spark] class BlockManager( // Drop to disk, if storage level requires if (level.useDisk && !diskStore.contains(blockId)) { logInfo(s"Writing block $blockId to disk") - data match { + data() match { case Left(elements) => diskStore.putArray(blockId, elements, level, returnValues = false) case Right(bytes) => diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 12cd8ea3bdf1f..2883137872600 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -47,6 +47,8 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon logError("Failed to create any local dir.") System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) } + // The content of subDirs is immutable but the content of subDirs(i) is mutable. And the content + // of subDirs(i) is protected by the lock of subDirs(i) private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) private val shutdownHook = addShutdownHook() @@ -61,20 +63,17 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon val subDirId = (hash / localDirs.length) % subDirsPerLocalDir // Create the subdirectory if it doesn't already exist - var subDir = subDirs(dirId)(subDirId) - if (subDir == null) { - subDir = subDirs(dirId).synchronized { - val old = subDirs(dirId)(subDirId) - if (old != null) { - old - } else { - val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) - if (!newDir.exists() && !newDir.mkdir()) { - throw new IOException(s"Failed to create local dir in $newDir.") - } - subDirs(dirId)(subDirId) = newDir - newDir + val subDir = subDirs(dirId).synchronized { + val old = subDirs(dirId)(subDirId) + if (old != null) { + old + } else { + val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) + if (!newDir.exists() && !newDir.mkdir()) { + throw new IOException(s"Failed to create local dir in $newDir.") } + subDirs(dirId)(subDirId) = newDir + newDir } } @@ -91,7 +90,12 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon /** List all the files currently stored on disk by the disk manager. */ def getAllFiles(): Seq[File] = { // Get all the files inside the array of array of directories - subDirs.flatten.filter(_ != null).flatMap { dir => + subDirs.flatMap { dir => + dir.synchronized { + // Copy the content of dir because it may be modified in other threads + dir.clone() + } + }.filter(_ != null).flatMap { dir => val files = dir.listFiles() if (files != null) files else Seq.empty } diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 1be860aea63d0..ed609772e6979 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -98,6 +98,26 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } + /** + * Use `size` to test if there is enough space in MemoryStore. If so, create the ByteBuffer and + * put it into MemoryStore. Otherwise, the ByteBuffer won't be created. + * + * The caller should guarantee that `size` is correct. + */ + def putBytes(blockId: BlockId, size: Long, _bytes: () => ByteBuffer): PutResult = { + // Work on a duplicate - since the original input might be used elsewhere. + lazy val bytes = _bytes().duplicate().rewind().asInstanceOf[ByteBuffer] + val putAttempt = tryToPut(blockId, () => bytes, size, deserialized = false) + val data = + if (putAttempt.success) { + assert(bytes.limit == size) + Right(bytes.duplicate()) + } else { + null + } + PutResult(size, data, putAttempt.droppedBlocks) + } + override def putArray( blockId: BlockId, values: Array[Any], @@ -312,11 +332,22 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) blockId.asRDDId.map(_.rddId) } + private def tryToPut( + blockId: BlockId, + value: Any, + size: Long, + deserialized: Boolean): ResultWithDroppedBlocks = { + tryToPut(blockId, () => value, size, deserialized) + } + /** * Try to put in a set of values, if we can free up enough space. The value should either be * an Array if deserialized is true or a ByteBuffer otherwise. Its (possibly estimated) size * must also be passed by the caller. * + * `value` will be lazily created. If it cannot be put into MemoryStore or disk, `value` won't be + * created to avoid OOM since it may be a big ByteBuffer. + * * Synchronize on `accountingLock` to ensure that all the put requests and its associated block * dropping is done by only on thread at a time. Otherwise while one thread is dropping * blocks to free memory for one block, another thread may use up the freed space for @@ -326,7 +357,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) */ private def tryToPut( blockId: BlockId, - value: Any, + value: () => Any, size: Long, deserialized: Boolean): ResultWithDroppedBlocks = { @@ -345,7 +376,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) droppedBlocks ++= freeSpaceResult.droppedBlocks if (enoughFreeSpace) { - val entry = new MemoryEntry(value, size, deserialized) + val entry = new MemoryEntry(value(), size, deserialized) entries.synchronized { entries.put(blockId, entry) currentMemory += size @@ -357,12 +388,12 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } else { // Tell the block manager that we couldn't put it in memory so that it can drop it to // disk if the block allows disk storage. - val data = if (deserialized) { - Left(value.asInstanceOf[Array[Any]]) + lazy val data = if (deserialized) { + Left(value().asInstanceOf[Array[Any]]) } else { - Right(value.asInstanceOf[ByteBuffer].duplicate()) + Right(value().asInstanceOf[ByteBuffer].duplicate()) } - val droppedBlockStatus = blockManager.dropFromMemory(blockId, data) + val droppedBlockStatus = blockManager.dropFromMemory(blockId, () => data) droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } } // Release the unroll memory used because we no longer need the underlying Array diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala index 19ac7a826e306..5fbcd6bb8ad94 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala @@ -17,6 +17,8 @@ package org.apache.spark.ui +import java.util.concurrent.Semaphore + import scala.util.Random import org.apache.spark.{SparkConf, SparkContext} @@ -88,6 +90,8 @@ private[spark] object UIWorkloadGenerator { ("Job with delays", baseData.map(x => Thread.sleep(100)).count) ) + val barrier = new Semaphore(-nJobSet * jobs.size + 1) + (1 to nJobSet).foreach { _ => for ((desc, job) <- jobs) { new Thread { @@ -99,12 +103,17 @@ private[spark] object UIWorkloadGenerator { } catch { case e: Exception => println("Job Failed: " + desc) + } finally { + barrier.release() } } }.start Thread.sleep(INTER_JOB_WAIT_MS) } } + + // Waiting for threads. + barrier.acquire() sc.stop() } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 949e80d30f5eb..625596885faa1 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -44,6 +44,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { // These type aliases are public because they're used in the types of public fields: type JobId = Int + type JobGroupId = String type StageId = Int type StageAttemptId = Int type PoolName = String @@ -54,6 +55,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val completedJobs = ListBuffer[JobUIData]() val failedJobs = ListBuffer[JobUIData]() val jobIdToData = new HashMap[JobId, JobUIData] + val jobGroupToJobIds = new HashMap[JobGroupId, HashSet[JobId]] // Stages: val pendingStages = new HashMap[StageId, StageInfo] @@ -119,7 +121,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { Map( "jobIdToData" -> jobIdToData.size, "stageIdToData" -> stageIdToData.size, - "stageIdToStageInfo" -> stageIdToInfo.size + "stageIdToStageInfo" -> stageIdToInfo.size, + "jobGroupToJobIds" -> jobGroupToJobIds.values.map(_.size).sum, + // Since jobGroupToJobIds is map of sets, check that we don't leak keys with empty values: + "jobGroupToJobIds keySet" -> jobGroupToJobIds.keys.size ) } @@ -140,7 +145,19 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { if (jobs.size > retainedJobs) { val toRemove = math.max(retainedJobs / 10, 1) jobs.take(toRemove).foreach { job => - jobIdToData.remove(job.jobId) + // Remove the job's UI data, if it exists + jobIdToData.remove(job.jobId).foreach { removedJob => + // A null jobGroupId is used for jobs that are run without a job group + val jobGroupId = removedJob.jobGroup.orNull + // Remove the job group -> job mapping entry, if it exists + jobGroupToJobIds.get(jobGroupId).foreach { jobsInGroup => + jobsInGroup.remove(job.jobId) + // If this was the last job in this job group, remove the map entry for the job group + if (jobsInGroup.isEmpty) { + jobGroupToJobIds.remove(jobGroupId) + } + } + } } jobs.trimStart(toRemove) } @@ -158,6 +175,8 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageIds = jobStart.stageIds, jobGroup = jobGroup, status = JobExecutionStatus.RUNNING) + // A null jobGroupId is used for jobs that are run without a job group + jobGroupToJobIds.getOrElseUpdate(jobGroup.orNull, new HashSet[JobId]).add(jobStart.jobId) jobStart.stageInfos.foreach(x => pendingStages(x.stageId) = x) // Compute (a potential underestimate of) the number of tasks that will be run by this job. // This may be an underestimate because the job start event references all of the result diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index e03442894c5cc..797c9404bc449 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -269,11 +269,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { +: getFormattedTimeQuantiles(serializationTimes) val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) => - if (info.gettingResultTime > 0) { - (info.finishTime - info.gettingResultTime).toDouble - } else { - 0.0 - } + getGettingResultTime(info).toDouble } val gettingResultQuantiles = @@ -464,7 +460,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) - val gettingResultTime = info.gettingResultTime + val gettingResultTime = getGettingResultTime(info) val maybeAccumulators = info.accumulables val accumulatorsReadable = maybeAccumulators.map{acc => s"${acc.name}: ${acc.update.get}"} @@ -627,6 +623,19 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {errorSummary}{details} } + private def getGettingResultTime(info: TaskInfo): Long = { + if (info.gettingResultTime > 0) { + if (info.finishTime > 0) { + info.finishTime - info.gettingResultTime + } else { + // The task is still fetching the result. + System.currentTimeMillis - info.gettingResultTime + } + } else { + 0L + } + } + private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = { val totalExecutionTime = if (info.gettingResult) { @@ -638,6 +647,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { } val executorOverhead = (metrics.executorDeserializeTime + metrics.resultSerializationTime) - math.max(0, totalExecutionTime - metrics.executorRunTime - executorOverhead) + math.max( + 0, + totalExecutionTime - metrics.executorRunTime - executorOverhead - getGettingResultTime(info)) } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 3262e670c2030..b962c101c91da 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -352,6 +352,7 @@ private[spark] class ExternalSorter[K, V, C]( // Create our file writers if we haven't done so yet if (partitionWriters == null) { curWriteMetrics = new ShuffleWriteMetrics() + val openStartTime = System.nanoTime partitionWriters = Array.fill(numPartitions) { // Because these files may be read during shuffle, their compression must be controlled by // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use @@ -359,6 +360,10 @@ private[spark] class ExternalSorter[K, V, C]( val (blockId, file) = diskBlockManager.createTempShuffleBlock() blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics).open() } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, and can take a long time in aggregate when we open many files, so should be + // included in the shuffle write time. + curWriteMetrics.incShuffleWriteTime(System.nanoTime - openStartTime) } // No need to sort stuff, just write each element out diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala index c52591b352340..efc2482c74ddf 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala @@ -53,6 +53,15 @@ class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag]( override def size: Int = if (haveNullValue) _keySet.size + 1 else _keySet.size + /** Tests whether this map contains a binding for a key. */ + def contains(k: K): Boolean = { + if (k == null) { + haveNullValue + } else { + _keySet.getPos(k) != OpenHashSet.INVALID_POS + } + } + /** Get the value for a given key */ def apply(k: K): V = { if (k == null) { diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index c80057f95e0b2..1501111a06655 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -122,7 +122,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( */ def addWithoutResize(k: T): Int = { var pos = hashcode(hasher.hash(k)) & _mask - var i = 1 + var delta = 1 while (true) { if (!_bitset.get(pos)) { // This is a new key. @@ -134,14 +134,12 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( // Found an existing key. return pos } else { - val delta = i + // quadratic probing with values increase by 1, 2, 3, ... pos = (pos + delta) & _mask - i += 1 + delta += 1 } } - // Never reached here - assert(INVALID_POS != INVALID_POS) - INVALID_POS + throw new RuntimeException("Should never reach here.") } /** @@ -163,21 +161,19 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( */ def getPos(k: T): Int = { var pos = hashcode(hasher.hash(k)) & _mask - var i = 1 - val maxProbe = _data.size - while (i < maxProbe) { + var delta = 1 + while (true) { if (!_bitset.get(pos)) { return INVALID_POS } else if (k == _data(pos)) { return pos } else { - val delta = i + // quadratic probing with values increase by 1, 2, 3, ... pos = (pos + delta) & _mask - i += 1 + delta += 1 } } - // Never reached here - INVALID_POS + throw new RuntimeException("Should never reach here.") } /** Return the value at the specified position. */ diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala index 61e22642761f0..b4ec4ea521253 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala @@ -48,6 +48,11 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, override def size: Int = _keySet.size + /** Tests whether this map contains a binding for a key. */ + def contains(k: K): Boolean = { + _keySet.getPos(k) != OpenHashSet.INVALID_POS + } + /** Get the value for a given key */ def apply(k: K): V = { val pos = _keySet.getPos(k) diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala index 4cd0f97368ca3..97079382c716f 100644 --- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -235,6 +235,12 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramBuckets === expectedHistogramBuckets) } + test("WorksWithDoubleValuesAtMinMax") { + val rdd = sc.parallelize(Seq(1, 1, 1, 2, 3, 3)) + assert(Array(3, 0, 1, 2) === rdd.map(_.toDouble).histogram(4)._2) + assert(Array(3, 1, 2) === rdd.map(_.toDouble).histogram(3)._2) + } + test("WorksWithoutBucketsWithMoreRequestedThanElements") { // Verify the basic case of one bucket and all elements in that bucket works val rdd = sc.parallelize(Seq(1, 2)) @@ -248,7 +254,7 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { } test("WorksWithoutBucketsForLargerDatasets") { - // Verify the case of slighly larger datasets + // Verify the case of slightly larger datasets val rdd = sc.parallelize(6 to 99) val (histogramBuckets, histogramResults) = rdd.histogram(8) val expectedHistogramResults = @@ -259,17 +265,27 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramBuckets === expectedHistogramBuckets) } - test("WorksWithoutBucketsWithIrrationalBucketEdges") { - // Verify the case of buckets with irrational edges. See #SPARK-2862. + test("WorksWithoutBucketsWithNonIntegralBucketEdges") { + // Verify the case of buckets with nonintegral edges. See #SPARK-2862. val rdd = sc.parallelize(6 to 99) val (histogramBuckets, histogramResults) = rdd.histogram(9) + // Buckets are 6.0, 16.333333333333336, 26.666666666666668, 37.0, 47.333333333333336 ... val expectedHistogramResults = - Array(11, 10, 11, 10, 10, 11, 10, 10, 11) + Array(11, 10, 10, 11, 10, 10, 11, 10, 11) assert(histogramResults === expectedHistogramResults) assert(histogramBuckets(0) === 6.0) assert(histogramBuckets(9) === 99.0) } + test("WorksWithHugeRange") { + val rdd = sc.parallelize(Array(0, 1.0e24, 1.0e30)) + val histogramResults = rdd.histogram(1000000)._2 + assert(histogramResults(0) === 1) + assert(histogramResults(1) === 1) + assert(histogramResults.last === 1) + assert((2 to histogramResults.length - 2).forall(i => histogramResults(i) == 0)) + } + // Test the failure mode with an invalid RDD test("ThrowsExceptionOnInvalidRDDs") { // infinity diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 3fdbe99b5d02b..ecd1cba5b5abe 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -170,8 +170,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach assert(master.getLocations("a3").size === 0, "master was told about a3") // Drop a1 and a2 from memory; this should be reported back to the master - store.dropFromMemory("a1", null) - store.dropFromMemory("a2", null) + store.dropFromMemory("a1", null: Either[Array[Any], ByteBuffer]) + store.dropFromMemory("a2", null: Either[Array[Any], ByteBuffer]) assert(store.getSingle("a1") === None, "a1 not removed from store") assert(store.getSingle("a2") === None, "a2 not removed from store") assert(master.getLocations("a1").size === 0, "master did not remove a1") @@ -413,8 +413,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach t2.join() t3.join() - store.dropFromMemory("a1", null) - store.dropFromMemory("a2", null) + store.dropFromMemory("a1", null: Either[Array[Any], ByteBuffer]) + store.dropFromMemory("a2", null: Either[Array[Any], ByteBuffer]) store.waitForAsyncReregister() } } @@ -1223,4 +1223,30 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach assert(unrollMemoryAfterB6 === unrollMemoryAfterB4) assert(unrollMemoryAfterB7 === unrollMemoryAfterB4) } + + test("lazily create a big ByteBuffer to avoid OOM if it cannot be put into MemoryStore") { + store = makeBlockManager(12000) + val memoryStore = store.memoryStore + val blockId = BlockId("rdd_3_10") + val result = memoryStore.putBytes(blockId, 13000, () => { + fail("A big ByteBuffer that cannot be put into MemoryStore should not be created") + }) + assert(result.size === 13000) + assert(result.data === null) + assert(result.droppedBlocks === Nil) + } + + test("put a small ByteBuffer to MemoryStore") { + store = makeBlockManager(12000) + val memoryStore = store.memoryStore + val blockId = BlockId("rdd_3_10") + var bytes: ByteBuffer = null + val result = memoryStore.putBytes(blockId, 10000, () => { + bytes = ByteBuffer.allocate(10000) + bytes + }) + assert(result.size === 10000) + assert(result.data === Right(bytes)) + assert(result.droppedBlocks === Nil) + } } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 730a4b54f5aa1..c0c28cb60e21d 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ui.jobs +import java.util.Properties + import org.scalatest.FunSuite import org.scalatest.Matchers @@ -44,11 +46,19 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc SparkListenerStageCompleted(stageInfo) } - private def createJobStartEvent(jobId: Int, stageIds: Seq[Int]) = { + private def createJobStartEvent( + jobId: Int, + stageIds: Seq[Int], + jobGroup: Option[String] = None): SparkListenerJobStart = { val stageInfos = stageIds.map { stageId => new StageInfo(stageId, 0, stageId.toString, 0, null, "") } - SparkListenerJobStart(jobId, jobSubmissionTime, stageInfos) + val properties: Option[Properties] = jobGroup.map { groupId => + val props = new Properties() + props.setProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId) + props + } + SparkListenerJobStart(jobId, jobSubmissionTime, stageInfos, properties.orNull) } private def createJobEndEvent(jobId: Int, failed: Boolean = false) = { @@ -110,6 +120,23 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc listener.stageIdToActiveJobIds.size should be (0) } + test("test clearing of jobGroupToJobIds") { + val conf = new SparkConf() + conf.set("spark.ui.retainedJobs", 5.toString) + val listener = new JobProgressListener(conf) + + // Run 50 jobs, each with one stage + for (jobId <- 0 to 50) { + listener.onJobStart(createJobStartEvent(jobId, Seq(0), jobGroup = Some(jobId.toString))) + listener.onStageSubmitted(createStageStartEvent(0)) + listener.onStageCompleted(createStageEndEvent(0, failed = false)) + listener.onJobEnd(createJobEndEvent(jobId, false)) + } + assertActiveJobsStateIsEmpty(listener) + // This collection won't become empty, but it should be bounded by spark.ui.retainedJobs + listener.jobGroupToJobIds.size should be (5) + } + test("test LRU eviction of jobs") { val conf = new SparkConf() conf.set("spark.ui.retainedStages", 5.toString) diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala index 6a70877356409..ef890d2ba60f3 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -176,4 +176,14 @@ class OpenHashMapSuite extends FunSuite with Matchers { assert(map(i.toString) === i.toString) } } + + test("contains") { + val map = new OpenHashMap[String, Int](2) + map("a") = 1 + assert(map.contains("a")) + assert(!map.contains("b")) + assert(!map.contains(null)) + map(null) = 0 + assert(map.contains(null)) + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala index 8c7df7d73dcd3..caf378fec8b3e 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala @@ -118,4 +118,11 @@ class PrimitiveKeyOpenHashMapSuite extends FunSuite with Matchers { assert(map(i.toLong) === i.toString) } } + + test("contains") { + val map = new PrimitiveKeyOpenHashMap[Int, Int](1) + map(0) = 0 + assert(map.contains(0)) + assert(!map.contains(1)) + } } diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 0b6db4fcb7b1f..f5aa15b7d9b79 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -173,6 +173,7 @@ to the algorithm. We then output the parameters of the mixture model. {% highlight scala %} import org.apache.spark.mllib.clustering.GaussianMixture +import org.apache.spark.mllib.clustering.GaussianMixtureModel import org.apache.spark.mllib.linalg.Vectors // Load and parse the data @@ -182,6 +183,10 @@ val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble))) // Cluster the data into two classes using GaussianMixture val gmm = new GaussianMixture().setK(2).run(parsedData) +// Save and load model +gmm.save(sc, "myGMMModel") +val sameModel = GaussianMixtureModel.load(sc, "myGMMModel") + // output parameters of max-likelihood model for (i <- 0 until gmm.k) { println("weight=%f\nmu=%s\nsigma=\n%s\n" format @@ -231,6 +236,9 @@ public class GaussianMixtureExample { // Cluster the data into two classes using GaussianMixture GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd()); + // Save and load GaussianMixtureModel + gmm.save(sc, "myGMMModel") + GaussianMixtureModel sameModel = GaussianMixtureModel.load(sc, "myGMMModel") // Output the parameters of the mixture model for(int j=0; j Seq[String] = { str => + val re = paramMap(pattern).r + val tokens = if (paramMap(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq + val minLength = paramMap(minTokenLength) + tokens.filter(_.length >= minLength) + } + + override protected def validateInputType(inputType: DataType): Unit = { + require(inputType == StringType, s"Input type must be string type but got $inputType.") + } + + override protected def outputDataType: DataType = new ArrayType(StringType, false) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 69c6751df37f8..642250f30b90e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -110,9 +110,11 @@ private[python] class PythonMLLibAPI extends Serializable { initialWeights: Vector, regParam: Double, regType: String, - intercept: Boolean): JList[Object] = { + intercept: Boolean, + validateData: Boolean): JList[Object] = { val lrAlg = new LinearRegressionWithSGD() lrAlg.setIntercept(intercept) + .setValidateData(validateData) lrAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -134,8 +136,12 @@ private[python] class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeights: Vector): JList[Object] = { + initialWeights: Vector, + intercept: Boolean, + validateData: Boolean): JList[Object] = { val lassoAlg = new LassoWithSGD() + lassoAlg.setIntercept(intercept) + .setValidateData(validateData) lassoAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) @@ -156,8 +162,12 @@ private[python] class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeights: Vector): JList[Object] = { + initialWeights: Vector, + intercept: Boolean, + validateData: Boolean): JList[Object] = { val ridgeAlg = new RidgeRegressionWithSGD() + ridgeAlg.setIntercept(intercept) + .setValidateData(validateData) ridgeAlg.optimizer .setNumIterations(numIterations) .setRegParam(regParam) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index af6f83c74bb40..ec65a3da689de 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -19,11 +19,17 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseVector => BreezeVector} +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian -import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{SQLContext, Row} /** * :: Experimental :: @@ -41,10 +47,16 @@ import org.apache.spark.rdd.RDD @Experimental class GaussianMixtureModel( val weights: Array[Double], - val gaussians: Array[MultivariateGaussian]) extends Serializable { + val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable{ require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match") - + + override protected def formatVersion = "1.0" + + override def save(sc: SparkContext, path: String): Unit = { + GaussianMixtureModel.SaveLoadV1_0.save(sc, path, weights, gaussians) + } + /** Number of gaussians in mixture */ def k: Int = weights.length @@ -83,5 +95,79 @@ class GaussianMixtureModel( p(i) /= pSum } p - } + } +} + +@Experimental +object GaussianMixtureModel extends Loader[GaussianMixtureModel] { + + private object SaveLoadV1_0 { + + case class Data(weight: Double, mu: Vector, sigma: Matrix) + + val formatVersionV1_0 = "1.0" + + val classNameV1_0 = "org.apache.spark.mllib.clustering.GaussianMixtureModel" + + def save( + sc: SparkContext, + path: String, + weights: Array[Double], + gaussians: Array[MultivariateGaussian]): Unit = { + + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Create JSON metadata. + val metadata = compact(render + (("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ ("k" -> weights.length))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. + val dataArray = Array.tabulate(weights.length) { i => + Data(weights(i), gaussians(i).mu, gaussians(i).sigma) + } + sc.parallelize(dataArray, 1).toDF().saveAsParquetFile(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): GaussianMixtureModel = { + val dataPath = Loader.dataPath(path) + val sqlContext = new SQLContext(sc) + val dataFrame = sqlContext.parquetFile(dataPath) + val dataArray = dataFrame.select("weight", "mu", "sigma").collect() + + // Check schema explicitly since erasure makes it hard to use match-case for checking. + Loader.checkSchema[Data](dataFrame.schema) + + val (weights, gaussians) = dataArray.map { + case Row(weight: Double, mu: Vector, sigma: Matrix) => + (weight, new MultivariateGaussian(mu, sigma)) + }.unzip + + return new GaussianMixtureModel(weights.toArray, gaussians.toArray) + } + } + + override def load(sc: SparkContext, path: String) : GaussianMixtureModel = { + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + implicit val formats = DefaultFormats + val k = (metadata \ "k").extract[Int] + val classNameV1_0 = SaveLoadV1_0.classNameV1_0 + (loadedClassName, version) match { + case (classNameV1_0, "1.0") => { + val model = SaveLoadV1_0.load(sc, path) + require(model.weights.length == k, + s"GaussianMixtureModel requires weights of length $k " + + s"got weights of length ${model.weights.length}") + require(model.gaussians.length == k, + s"GaussianMixtureModel requires gaussians of length $k" + + s"got gaussians of length ${model.gaussians.length}") + model + } + case _ => throw new Exception( + s"GaussianMixtureModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 5e17c8da61134..9d63a08e211bc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.clustering import java.util.Random -import breeze.linalg.{DenseVector => BDV, normalize, axpy => brzAxpy} +import breeze.linalg.{DenseVector => BDV, normalize} import org.apache.spark.Logging import org.apache.spark.annotation.Experimental diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 45b9ebb4cc0d6..9fd60ff7a0c79 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -211,6 +211,10 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ def run(input: RDD[LabeledPoint], initialWeights: Vector): M = { + if (numFeatures < 0) { + numFeatures = input.map(_.features.size).first() + } + if (input.getStorageLevel == StorageLevel.NONE) { logWarning("The input data is not directly cached, which may hurt performance if its" + " parent RDDs are also uncached.") diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java new file mode 100644 index 0000000000000..3806f650025b2 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; + +public class JavaTokenizerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaTokenizerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void regexTokenizer() { + RegexTokenizer myRegExTokenizer = new RegexTokenizer() + .setInputCol("rawText") + .setOutputCol("tokens") + .setPattern("\\s") + .setGaps(true) + .setMinTokenLength(3); + + JavaRDD rdd = jsc.parallelize(Lists.newArrayList( + new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}), + new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"}) + )); + DataFrame dataset = jsql.createDataFrame(rdd, TokenizerTestData.class); + + Row[] pairs = myRegExTokenizer.transform(dataset) + .select("tokens", "wantedTokens") + .collect(); + + for (Row r : pairs) { + Assert.assertEquals(r.get(0), r.get(1)); + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala new file mode 100644 index 0000000000000..bf862b912d326 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.beans.BeanInfo + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + +@BeanInfo +case class TokenizerTestData(rawText: String, wantedTokens: Seq[String]) { + /** Constructor used in [[org.apache.spark.ml.feature.JavaTokenizerSuite]] */ + def this(rawText: String, wantedTokens: Array[String]) = this(rawText, wantedTokens.toSeq) +} + +class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext { + import org.apache.spark.ml.feature.RegexTokenizerSuite._ + + @transient var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + } + + test("RegexTokenizer") { + val tokenizer = new RegexTokenizer() + .setInputCol("rawText") + .setOutputCol("tokens") + + val dataset0 = sqlContext.createDataFrame(Seq( + TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization", ".")), + TokenizerTestData("Te,st. punct", Seq("Te", ",", "st", ".", "punct")) + )) + testRegexTokenizer(tokenizer, dataset0) + + val dataset1 = sqlContext.createDataFrame(Seq( + TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization")), + TokenizerTestData("Te,st. punct", Seq("punct")) + )) + + tokenizer.setMinTokenLength(3) + testRegexTokenizer(tokenizer, dataset1) + + tokenizer + .setPattern("\\s") + .setGaps(true) + .setMinTokenLength(0) + val dataset2 = sqlContext.createDataFrame(Seq( + TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization.")), + TokenizerTestData("Te,st. punct", Seq("Te,st.", "", "punct")) + )) + testRegexTokenizer(tokenizer, dataset2) + } +} + +object RegexTokenizerSuite extends FunSuite { + + def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = { + t.transform(dataset) + .select("tokens", "wantedTokens") + .collect() + .foreach { + case Row(tokens, wantedTokens) => + assert(tokens === wantedTokens) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index aaa81da9e273c..a26c52852c4d7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -425,6 +425,12 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M val model = lr.run(testRDD) + val numFeatures = testRDD.map(_.features.size).first() + val initialWeights = Vectors.dense(new Array[Double]((numFeatures + 1) * 2)) + val model2 = lr.run(testRDD, initialWeights) + + LogisticRegressionSuite.checkModelsEqual(model, model2) + /** * The following is the instruction to reproduce the model using R's glmnet package. * diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala index 1b46a4012d731..f356ffa3e3a26 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Matrices} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { test("single cluster") { @@ -48,13 +49,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { } test("two clusters") { - val data = sc.parallelize(Array( - Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220), - Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118), - Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322), - Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026), - Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) - )) + val data = sc.parallelize(GaussianTestData.data) // we set an initial gaussian to induce expected results val initialGmm = new GaussianMixtureModel( @@ -105,14 +100,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { } test("two clusters with sparse data") { - val data = sc.parallelize(Array( - Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220), - Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118), - Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322), - Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026), - Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) - )) - + val data = sc.parallelize(GaussianTestData.data) val sparseData = data.map(point => Vectors.sparse(1, Array(0), point.toArray)) // we set an initial gaussian to induce expected results val initialGmm = new GaussianMixtureModel( @@ -138,4 +126,38 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { assert(sparseGMM.gaussians(0).sigma ~== Esigma(0) absTol 1E-3) assert(sparseGMM.gaussians(1).sigma ~== Esigma(1) absTol 1E-3) } + + test("model save / load") { + val data = sc.parallelize(GaussianTestData.data) + + val gmm = new GaussianMixture().setK(2).setSeed(0).run(data) + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + try { + gmm.save(sc, path) + + // TODO: GaussianMixtureModel should implement equals/hashcode directly. + val sameModel = GaussianMixtureModel.load(sc, path) + assert(sameModel.k === gmm.k) + (0 until sameModel.k).foreach { i => + assert(sameModel.gaussians(i).mu === gmm.gaussians(i).mu) + assert(sameModel.gaussians(i).sigma === gmm.gaussians(i).sigma) + } + } finally { + Utils.deleteRecursively(tempDir) + } + } + + object GaussianTestData { + + val data = Array( + Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220), + Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118), + Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322), + Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026), + Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) + ) + + } } diff --git a/pom.xml b/pom.xml index 23bb16130b504..b3cecd1893a06 100644 --- a/pom.xml +++ b/pom.xml @@ -1452,7 +1452,8 @@ ${basedir}/src/test/scala scalastyle-config.xml scalastyle-output.xml - UTF-8 + ${project.build.sourceEncoding} + ${project.reporting.outputEncoding} diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 414a0ada80787..209f1ee473b5b 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -140,6 +140,13 @@ class LinearRegressionModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=100, step=1.0, + ... miniBatchFraction=1.0, initialWeights=array([1.0]), regParam=0.1, regType="l2", + ... intercept=True, validateData=True) + >>> abs(lrm.predict(array([0.0])) - 0) < 0.5 + True + >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True """ def save(self, sc, path): java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel( @@ -173,7 +180,8 @@ class LinearRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, - initialWeights=None, regParam=0.0, regType=None, intercept=False): + initialWeights=None, regParam=0.0, regType=None, intercept=False, + validateData=True): """ Train a linear regression model on the given data. @@ -195,15 +203,18 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, (default: None) - @param intercept: Boolean parameter which indicates the use + :param intercept: Boolean parameter which indicates the use or not of the augmented representation for training data (i.e. whether bias features are activated or not). (default: False) + :param validateData: Boolean parameter which indicates if the + algorithm should validate data before training. + (default: True) """ def train(rdd, i): return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), float(step), float(miniBatchFraction), i, float(regParam), - regType, bool(intercept)) + regType, bool(intercept), bool(validateData)) return _regression_train_wrapper(train, LinearRegressionModel, data, initialWeights) @@ -253,6 +264,13 @@ class LassoModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> lrm = LassoWithSGD.train(sc.parallelize(data), iterations=100, step=1.0, + ... regParam=0.01, miniBatchFraction=1.0, initialWeights=array([1.0]), intercept=True, + ... validateData=True) + >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True """ def save(self, sc, path): java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel( @@ -273,11 +291,13 @@ class LassoWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=0.01, - miniBatchFraction=1.0, initialWeights=None): + miniBatchFraction=1.0, initialWeights=None, intercept=False, + validateData=True): """Train a Lasso regression model on the given data.""" def train(rdd, i): return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step), - float(regParam), float(miniBatchFraction), i) + float(regParam), float(miniBatchFraction), i, bool(intercept), + bool(validateData)) return _regression_train_wrapper(train, LassoModel, data, initialWeights) @@ -327,6 +347,13 @@ class RidgeRegressionModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), iterations=100, step=1.0, + ... regParam=0.01, miniBatchFraction=1.0, initialWeights=array([1.0]), intercept=True, + ... validateData=True) + >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True """ def save(self, sc, path): java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel( @@ -347,11 +374,13 @@ class RidgeRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=0.01, - miniBatchFraction=1.0, initialWeights=None): + miniBatchFraction=1.0, initialWeights=None, intercept=False, + validateData=True): """Train a ridge regression model on the given data.""" def train(rdd, i): return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step), - float(regParam), float(miniBatchFraction), i) + float(regParam), float(miniBatchFraction), i, bool(intercept), + bool(validateData)) return _regression_train_wrapper(train, RidgeRegressionModel, data, initialWeights) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ae32c2870c97b..890b6d2b05acb 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -16,7 +16,6 @@ # import sys -import itertools import warnings import random @@ -534,6 +533,25 @@ def sort(self, *cols): orderBy = sort + def describe(self, *cols): + """Computes statistics for numeric columns. + + This include count, mean, stddev, min, and max. If no columns are + given, this function computes statistics for all numerical columns. + + >>> df.describe().show() + summary age + count 2 + mean 3.5 + stddev 1.5 + min 2 + max 5 + """ + cols = ListConverter().convert(cols, + self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.describe(self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)) + return DataFrame(jdf, self.sql_ctx) + @ignore_unicode_prefix def head(self, n=None): """ Return the first `n` rows or the first row if n is None. @@ -1011,6 +1029,23 @@ def substr(self, startPos, length): __getslice__ = substr + def inSet(self, *cols): + """ A boolean expression that is evaluated to true if the value of this + expression is contained by the evaluated values of the arguments. + + >>> df[df.name.inSet("Bob", "Mike")].collect() + [Row(age=5, name=u'Bob')] + >>> df[df.age.inSet([1, 2, 3])].collect() + [Row(age=2, name=u'Alice')] + """ + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols] + sc = SparkContext._active_spark_context + jcols = ListConverter().convert(cols, sc._gateway._gateway_client) + jc = getattr(self._jc, "in")(sc._jvm.PythonUtils.toSeq(jcols)) + return Column(jc) + # order asc = _unary_op("asc", "Returns a sort expression based on the" " ascending order of the given column name.") diff --git a/repl/pom.xml b/repl/pom.xml index edfa1c7f2c29c..03053b4c3b287 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -84,6 +84,11 @@ scalacheck_${scala.binary.version} test + + org.mockito + mockito-all + test + diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 9805609120005..004941d5f50ae 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -17,9 +17,10 @@ package org.apache.spark.repl -import java.io.{ByteArrayOutputStream, InputStream, FileNotFoundException} -import java.net.{URI, URL, URLEncoder} -import java.util.concurrent.{Executors, ExecutorService} +import java.io.{IOException, ByteArrayOutputStream, InputStream} +import java.net.{HttpURLConnection, URI, URL, URLEncoder} + +import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} @@ -43,6 +44,9 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader val parentLoader = new ParentClassLoader(parent) + // Allows HTTP connect and read timeouts to be controlled for testing / debugging purposes + private[repl] var httpUrlConnectionTimeoutMillis: Int = -1 + // Hadoop FileSystem object for our URI, if it isn't using HTTP var fileSystem: FileSystem = { if (Set("http", "https", "ftp").contains(uri.getScheme)) { @@ -71,30 +75,66 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader } } + private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = { + val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { + val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) + val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager) + newuri.toURL + } else { + new URL(classUri + "/" + urlEncode(pathInDirectory)) + } + val connection: HttpURLConnection = Utils.setupSecureURLConnection(url.openConnection(), + SparkEnv.get.securityManager).asInstanceOf[HttpURLConnection] + // Set the connection timeouts (for testing purposes) + if (httpUrlConnectionTimeoutMillis != -1) { + connection.setConnectTimeout(httpUrlConnectionTimeoutMillis) + connection.setReadTimeout(httpUrlConnectionTimeoutMillis) + } + connection.connect() + try { + if (connection.getResponseCode != 200) { + // Close the error stream so that the connection is eligible for re-use + try { + connection.getErrorStream.close() + } catch { + case ioe: IOException => + logError("Exception while closing error stream", ioe) + } + throw new ClassNotFoundException(s"Class file not found at URL $url") + } else { + connection.getInputStream + } + } catch { + case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] => + connection.disconnect() + throw e + } + } + + private def getClassFileInputStreamFromFileSystem(pathInDirectory: String): InputStream = { + val path = new Path(directory, pathInDirectory) + if (fileSystem.exists(path)) { + fileSystem.open(path) + } else { + throw new ClassNotFoundException(s"Class file not found at path $path") + } + } + def findClassLocally(name: String): Option[Class[_]] = { + val pathInDirectory = name.replace('.', '/') + ".class" + var inputStream: InputStream = null try { - val pathInDirectory = name.replace('.', '/') + ".class" - val inputStream = { + inputStream = { if (fileSystem != null) { - fileSystem.open(new Path(directory, pathInDirectory)) + getClassFileInputStreamFromFileSystem(pathInDirectory) } else { - val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { - val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) - val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager) - newuri.toURL - } else { - new URL(classUri + "/" + urlEncode(pathInDirectory)) - } - - Utils.setupSecureURLConnection(url.openConnection(), SparkEnv.get.securityManager) - .getInputStream + getClassFileInputStreamFromHttpServer(pathInDirectory) } } val bytes = readAndTransformClass(name, inputStream) - inputStream.close() Some(defineClass(name, bytes, 0, bytes.length)) } catch { - case e: FileNotFoundException => + case e: ClassNotFoundException => // We did not find the class logDebug(s"Did not load class $name from REPL class server at $uri", e) None @@ -102,6 +142,15 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader // Something bad happened while checking if the class exists logError(s"Failed to check existence of class $name on REPL class server at $uri", e) None + } finally { + if (inputStream != null) { + try { + inputStream.close() + } catch { + case e: Exception => + logError("Exception while closing inputStream", e) + } + } } } diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index 6a79e76a34db8..c709cde740748 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -20,13 +20,25 @@ package org.apache.spark.repl import java.io.File import java.net.{URL, URLClassLoader} +import scala.concurrent.duration._ +import scala.language.implicitConversions +import scala.language.postfixOps + import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite +import org.scalatest.concurrent.Interruptor +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.mock.MockitoSugar +import org.mockito.Mockito._ -import org.apache.spark.{SparkConf, TestUtils} +import org.apache.spark._ import org.apache.spark.util.Utils -class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { +class ExecutorClassLoaderSuite + extends FunSuite + with BeforeAndAfterAll + with MockitoSugar + with Logging { val childClassNames = List("ReplFakeClass1", "ReplFakeClass2") val parentClassNames = List("ReplFakeClass1", "ReplFakeClass2", "ReplFakeClass3") @@ -34,6 +46,7 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { var tempDir2: File = _ var url1: String = _ var urls2: Array[URL] = _ + var classServer: HttpServer = _ override def beforeAll() { super.beforeAll() @@ -47,8 +60,12 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { override def afterAll() { super.afterAll() + if (classServer != null) { + classServer.stop() + } Utils.deleteRecursively(tempDir1) Utils.deleteRecursively(tempDir2) + SparkEnv.set(null) } test("child first") { @@ -83,4 +100,53 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { } } + test("failing to fetch classes from HTTP server should not leak resources (SPARK-6209)") { + // This is a regression test for SPARK-6209, a bug where each failed attempt to load a class + // from the driver's class server would leak a HTTP connection, causing the class server's + // thread / connection pool to be exhausted. + val conf = new SparkConf() + val securityManager = new SecurityManager(conf) + classServer = new HttpServer(conf, tempDir1, securityManager) + classServer.start() + // ExecutorClassLoader uses SparkEnv's SecurityManager, so we need to mock this + val mockEnv = mock[SparkEnv] + when(mockEnv.securityManager).thenReturn(securityManager) + SparkEnv.set(mockEnv) + // Create an ExecutorClassLoader that's configured to load classes from the HTTP server + val parentLoader = new URLClassLoader(Array.empty, null) + val classLoader = new ExecutorClassLoader(conf, classServer.uri, parentLoader, false) + classLoader.httpUrlConnectionTimeoutMillis = 500 + // Check that this class loader can actually load classes that exist + val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() + val fakeClassVersion = fakeClass.toString + assert(fakeClassVersion === "1") + // Try to perform a full GC now, since GC during the test might mask resource leaks + System.gc() + // When the original bug occurs, the test thread becomes blocked in a classloading call + // and does not respond to interrupts. Therefore, use a custom ScalaTest interruptor to + // shut down the HTTP server when the test times out + val interruptor: Interruptor = new Interruptor { + override def apply(thread: Thread): Unit = { + classServer.stop() + classServer = null + thread.interrupt() + } + } + def tryAndFailToLoadABunchOfClasses(): Unit = { + // The number of trials here should be much larger than Jetty's thread / connection limit + // in order to expose thread or connection leaks + for (i <- 1 to 1000) { + if (Thread.currentThread().isInterrupted) { + throw new InterruptedException() + } + // Incorporate the iteration number into the class name in order to avoid any response + // caching that might be added in the future + intercept[ClassNotFoundException] { + classLoader.loadClass(s"ReplFakeClassDoesNotExist$i").newInstance() + } + } + } + failAfter(10 seconds)(tryAndFailToLoadABunchOfClasses())(interruptor) + } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 366be00473d1c..3823584287741 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -26,7 +26,7 @@ import scala.util.parsing.input.CharArrayReader.EofCh import org.apache.spark.sql.catalyst.plans.logical._ private[sql] object KeywordNormalizer { - def apply(str: String) = str.toLowerCase() + def apply(str: String): String = str.toLowerCase() } private[sql] abstract class AbstractSparkSQLParser @@ -42,7 +42,7 @@ private[sql] abstract class AbstractSparkSQLParser } protected case class Keyword(str: String) { - def normalize = KeywordNormalizer(str) + def normalize: String = KeywordNormalizer(str) def parser: Parser[String] = normalize } @@ -81,7 +81,7 @@ private[sql] abstract class AbstractSparkSQLParser class SqlLexical extends StdLexical { case class FloatLit(chars: String) extends Token { - override def toString = chars + override def toString: String = chars } /* This is a work around to support the lazy setting */ @@ -120,7 +120,7 @@ class SqlLexical extends StdLexical { | failure("illegal character") ) - override def identChar = letter | elem('_') + override def identChar: Parser[Elem] = letter | elem('_') override def whitespace: Parser[Any] = ( whitespaceChar diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c93af79795bc7..44eceb0b372e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -64,9 +64,7 @@ class Analyzer(catalog: Catalog, UnresolvedHavingClauseAttributes :: TrimGroupingAliases :: typeCoercionRules ++ - extendedResolutionRules : _*), - Batch("Remove SubQueries", fixedPoint, - EliminateSubQueries) + extendedResolutionRules : _*) ) /** @@ -170,7 +168,7 @@ class Analyzer(catalog: Catalog, * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog. */ object ResolveRelations extends Rule[LogicalPlan] { - def getTable(u: UnresolvedRelation) = { + def getTable(u: UnresolvedRelation): LogicalPlan = { try { catalog.lookupRelation(u.tableIdentifier, u.alias) } catch { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 9e6e2912e0622..5eb7dff0cede8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -86,12 +86,12 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog { tables += ((getDbTableName(tableIdent), plan)) } - override def unregisterTable(tableIdentifier: Seq[String]) = { + override def unregisterTable(tableIdentifier: Seq[String]): Unit = { val tableIdent = processTableIdentifier(tableIdentifier) tables -= getDbTableName(tableIdent) } - override def unregisterAllTables() = { + override def unregisterAllTables(): Unit = { tables.clear() } @@ -147,8 +147,8 @@ trait OverrideCatalog extends Catalog { } abstract override def lookupRelation( - tableIdentifier: Seq[String], - alias: Option[String] = None): LogicalPlan = { + tableIdentifier: Seq[String], + alias: Option[String] = None): LogicalPlan = { val tableIdent = processTableIdentifier(tableIdentifier) val overriddenTable = overrides.get(getDBTable(tableIdent)) val tableWithQualifers = overriddenTable.map(r => Subquery(tableIdent.last, r)) @@ -205,15 +205,15 @@ trait OverrideCatalog extends Catalog { */ object EmptyCatalog extends Catalog { - val caseSensitive: Boolean = true + override val caseSensitive: Boolean = true - def tableExists(tableIdentifier: Seq[String]): Boolean = { + override def tableExists(tableIdentifier: Seq[String]): Boolean = { throw new UnsupportedOperationException } - def lookupRelation( - tableIdentifier: Seq[String], - alias: Option[String] = None) = { + override def lookupRelation( + tableIdentifier: Seq[String], + alias: Option[String] = None): LogicalPlan = { throw new UnsupportedOperationException } @@ -221,11 +221,11 @@ object EmptyCatalog extends Catalog { throw new UnsupportedOperationException } - def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { + override def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { throw new UnsupportedOperationException } - def unregisterTable(tableIdentifier: Seq[String]): Unit = { + override def unregisterTable(tableIdentifier: Seq[String]): Unit = { throw new UnsupportedOperationException } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 425e1e41cbf21..40472a1cbb3b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -33,7 +33,7 @@ class CheckAnalysis { */ val extendedCheckRules: Seq[LogicalPlan => Unit] = Nil - def failAnalysis(msg: String) = { + def failAnalysis(msg: String): Nothing = { throw new AnalysisException(msg) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 9f334f6d42ad1..c43ea55899695 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -35,7 +35,7 @@ trait OverrideFunctionRegistry extends FunctionRegistry { val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive) - def registerFunction(name: String, builder: FunctionBuilder) = { + override def registerFunction(name: String, builder: FunctionBuilder): Unit = { functionBuilders.put(name, builder) } @@ -47,7 +47,7 @@ trait OverrideFunctionRegistry extends FunctionRegistry { class SimpleFunctionRegistry(val caseSensitive: Boolean) extends FunctionRegistry { val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive) - def registerFunction(name: String, builder: FunctionBuilder) = { + override def registerFunction(name: String, builder: FunctionBuilder): Unit = { functionBuilders.put(name, builder) } @@ -61,13 +61,15 @@ class SimpleFunctionRegistry(val caseSensitive: Boolean) extends FunctionRegistr * functions are already filled in and the analyser needs only to resolve attribute references. */ object EmptyFunctionRegistry extends FunctionRegistry { - def registerFunction(name: String, builder: FunctionBuilder) = ??? + override def registerFunction(name: String, builder: FunctionBuilder): Unit = { + throw new UnsupportedOperationException + } - def lookupFunction(name: String, children: Seq[Expression]): Expression = { + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { throw new UnsupportedOperationException } - def caseSensitive: Boolean = ??? + override def caseSensitive: Boolean = throw new UnsupportedOperationException } /** @@ -76,7 +78,7 @@ object EmptyFunctionRegistry extends FunctionRegistry { * TODO move this into util folder? */ object StringKeyHashMap { - def apply[T](caseSensitive: Boolean) = caseSensitive match { + def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = caseSensitive match { case false => new StringKeyHashMap[T](_.toLowerCase) case true => new StringKeyHashMap[T](identity) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala index a7d3a8ee7deb3..c61c395cb4bb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala @@ -38,7 +38,7 @@ package object analysis { implicit class AnalysisErrorAt(t: TreeNode[_]) { /** Fails the analysis at the point where a specific tree node was parsed. */ - def failAnalysis(msg: String) = { + def failAnalysis(msg: String): Nothing = { throw new AnalysisException(msg, t.origin.line, t.origin.startPosition) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index ad5172c0349eb..300e9ba187bc5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.types.DataType /** * Thrown when an invalid attempt is made to access a property of a tree that has yet to be fully @@ -38,9 +39,10 @@ case class UnresolvedRelation( alias: Option[String] = None) extends LeafNode { /** Returns a `.` separated name for this relation. */ - def tableName = tableIdentifier.mkString(".") + def tableName: String = tableIdentifier.mkString(".") + + override def output: Seq[Attribute] = Nil - override def output = Nil override lazy val resolved = false } @@ -48,16 +50,16 @@ case class UnresolvedRelation( * Holds the name of an attribute that has yet to be resolved. */ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNode[Expression] { - override def exprId = throw new UnresolvedException(this, "exprId") - override def dataType = throw new UnresolvedException(this, "dataType") - override def nullable = throw new UnresolvedException(this, "nullable") - override def qualifiers = throw new UnresolvedException(this, "qualifiers") + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") override lazy val resolved = false - override def newInstance() = this - override def withNullability(newNullability: Boolean) = this - override def withQualifiers(newQualifiers: Seq[String]) = this - override def withName(newName: String) = UnresolvedAttribute(name) + override def newInstance(): UnresolvedAttribute = this + override def withNullability(newNullability: Boolean): UnresolvedAttribute = this + override def withQualifiers(newQualifiers: Seq[String]): UnresolvedAttribute = this + override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute(name) // Unresolved attributes are transient at compile time and don't get evaluated during execution. override def eval(input: Row = null): EvaluatedType = @@ -67,16 +69,16 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo } case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression { - override def dataType = throw new UnresolvedException(this, "dataType") - override def foldable = throw new UnresolvedException(this, "foldable") - override def nullable = throw new UnresolvedException(this, "nullable") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false // Unresolved functions are transient at compile time and don't get evaluated during execution. override def eval(input: Row = null): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString = s"'$name(${children.mkString(",")})" + override def toString: String = s"'$name(${children.mkString(",")})" } /** @@ -86,17 +88,17 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E trait Star extends Attribute with trees.LeafNode[Expression] { self: Product => - override def name = throw new UnresolvedException(this, "name") - override def exprId = throw new UnresolvedException(this, "exprId") - override def dataType = throw new UnresolvedException(this, "dataType") - override def nullable = throw new UnresolvedException(this, "nullable") - override def qualifiers = throw new UnresolvedException(this, "qualifiers") + override def name: String = throw new UnresolvedException(this, "name") + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") override lazy val resolved = false - override def newInstance() = this - override def withNullability(newNullability: Boolean) = this - override def withQualifiers(newQualifiers: Seq[String]) = this - override def withName(newName: String) = this + override def newInstance(): Star = this + override def withNullability(newNullability: Boolean): Star = this + override def withQualifiers(newQualifiers: Seq[String]): Star = this + override def withName(newName: String): Star = this // Star gets expanded at runtime so we never evaluate a Star. override def eval(input: Row = null): EvaluatedType = @@ -129,7 +131,7 @@ case class UnresolvedStar(table: Option[String]) extends Star { } } - override def toString = table.map(_ + ".").getOrElse("") + "*" + override def toString: String = table.map(_ + ".").getOrElse("") + "*" } /** @@ -144,25 +146,25 @@ case class UnresolvedStar(table: Option[String]) extends Star { case class MultiAlias(child: Expression, names: Seq[String]) extends Attribute with trees.UnaryNode[Expression] { - override def name = throw new UnresolvedException(this, "name") + override def name: String = throw new UnresolvedException(this, "name") - override def exprId = throw new UnresolvedException(this, "exprId") + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") - override def dataType = throw new UnresolvedException(this, "dataType") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") - override def nullable = throw new UnresolvedException(this, "nullable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override def qualifiers = throw new UnresolvedException(this, "qualifiers") + override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") override lazy val resolved = false - override def newInstance() = this + override def newInstance(): MultiAlias = this - override def withNullability(newNullability: Boolean) = this + override def withNullability(newNullability: Boolean): MultiAlias = this - override def withQualifiers(newQualifiers: Seq[String]) = this + override def withQualifiers(newQualifiers: Seq[String]): MultiAlias = this - override def withName(newName: String) = this + override def withName(newName: String): MultiAlias = this override def eval(input: Row = null): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") @@ -179,17 +181,17 @@ case class MultiAlias(child: Expression, names: Seq[String]) */ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star { override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = expressions - override def toString = expressions.mkString("ResolvedStar(", ", ", ")") + override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")") } case class UnresolvedGetField(child: Expression, fieldName: String) extends UnaryExpression { - override def dataType = throw new UnresolvedException(this, "dataType") - override def foldable = throw new UnresolvedException(this, "foldable") - override def nullable = throw new UnresolvedException(this, "nullable") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false override def eval(input: Row = null): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString = s"$child.$fieldName" + override def toString: String = s"$child.$fieldName" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 51a09ac0e1249..145f062dd6817 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp} import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} @@ -61,60 +61,60 @@ package object dsl { trait ImplicitOperators { def expr: Expression - def unary_- = UnaryMinus(expr) - def unary_! = Not(expr) - def unary_~ = BitwiseNot(expr) - - def + (other: Expression) = Add(expr, other) - def - (other: Expression) = Subtract(expr, other) - def * (other: Expression) = Multiply(expr, other) - def / (other: Expression) = Divide(expr, other) - def % (other: Expression) = Remainder(expr, other) - def & (other: Expression) = BitwiseAnd(expr, other) - def | (other: Expression) = BitwiseOr(expr, other) - def ^ (other: Expression) = BitwiseXor(expr, other) - - def && (other: Expression) = And(expr, other) - def || (other: Expression) = Or(expr, other) - - def < (other: Expression) = LessThan(expr, other) - def <= (other: Expression) = LessThanOrEqual(expr, other) - def > (other: Expression) = GreaterThan(expr, other) - def >= (other: Expression) = GreaterThanOrEqual(expr, other) - def === (other: Expression) = EqualTo(expr, other) - def <=> (other: Expression) = EqualNullSafe(expr, other) - def !== (other: Expression) = Not(EqualTo(expr, other)) - - def in(list: Expression*) = In(expr, list) - - def like(other: Expression) = Like(expr, other) - def rlike(other: Expression) = RLike(expr, other) - def contains(other: Expression) = Contains(expr, other) - def startsWith(other: Expression) = StartsWith(expr, other) - def endsWith(other: Expression) = EndsWith(expr, other) - def substr(pos: Expression, len: Expression = Literal(Int.MaxValue)) = + def unary_- : Expression= UnaryMinus(expr) + def unary_! : Predicate = Not(expr) + def unary_~ : Expression = BitwiseNot(expr) + + def + (other: Expression): Expression = Add(expr, other) + def - (other: Expression): Expression = Subtract(expr, other) + def * (other: Expression): Expression = Multiply(expr, other) + def / (other: Expression): Expression = Divide(expr, other) + def % (other: Expression): Expression = Remainder(expr, other) + def & (other: Expression): Expression = BitwiseAnd(expr, other) + def | (other: Expression): Expression = BitwiseOr(expr, other) + def ^ (other: Expression): Expression = BitwiseXor(expr, other) + + def && (other: Expression): Predicate = And(expr, other) + def || (other: Expression): Predicate = Or(expr, other) + + def < (other: Expression): Predicate = LessThan(expr, other) + def <= (other: Expression): Predicate = LessThanOrEqual(expr, other) + def > (other: Expression): Predicate = GreaterThan(expr, other) + def >= (other: Expression): Predicate = GreaterThanOrEqual(expr, other) + def === (other: Expression): Predicate = EqualTo(expr, other) + def <=> (other: Expression): Predicate = EqualNullSafe(expr, other) + def !== (other: Expression): Predicate = Not(EqualTo(expr, other)) + + def in(list: Expression*): Expression = In(expr, list) + + def like(other: Expression): Expression = Like(expr, other) + def rlike(other: Expression): Expression = RLike(expr, other) + def contains(other: Expression): Expression = Contains(expr, other) + def startsWith(other: Expression): Expression = StartsWith(expr, other) + def endsWith(other: Expression): Expression = EndsWith(expr, other) + def substr(pos: Expression, len: Expression = Literal(Int.MaxValue)): Expression = Substring(expr, pos, len) - def substring(pos: Expression, len: Expression = Literal(Int.MaxValue)) = + def substring(pos: Expression, len: Expression = Literal(Int.MaxValue)): Expression = Substring(expr, pos, len) - def isNull = IsNull(expr) - def isNotNull = IsNotNull(expr) + def isNull: Predicate = IsNull(expr) + def isNotNull: Predicate = IsNotNull(expr) - def getItem(ordinal: Expression) = GetItem(expr, ordinal) - def getField(fieldName: String) = UnresolvedGetField(expr, fieldName) + def getItem(ordinal: Expression): Expression = GetItem(expr, ordinal) + def getField(fieldName: String): UnresolvedGetField = UnresolvedGetField(expr, fieldName) - def cast(to: DataType) = Cast(expr, to) + def cast(to: DataType): Expression = Cast(expr, to) - def asc = SortOrder(expr, Ascending) - def desc = SortOrder(expr, Descending) + def asc: SortOrder = SortOrder(expr, Ascending) + def desc: SortOrder = SortOrder(expr, Descending) - def as(alias: String) = Alias(expr, alias)() - def as(alias: Symbol) = Alias(expr, alias.name)() + def as(alias: String): NamedExpression = Alias(expr, alias)() + def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)() } trait ExpressionConversions { implicit class DslExpression(e: Expression) extends ImplicitOperators { - def expr = e + def expr: Expression = e } implicit def booleanToLiteral(b: Boolean): Literal = Literal(b) @@ -144,94 +144,100 @@ package object dsl { } } - def sum(e: Expression) = Sum(e) - def sumDistinct(e: Expression) = SumDistinct(e) - def count(e: Expression) = Count(e) - def countDistinct(e: Expression*) = CountDistinct(e) - def approxCountDistinct(e: Expression, rsd: Double = 0.05) = ApproxCountDistinct(e, rsd) - def avg(e: Expression) = Average(e) - def first(e: Expression) = First(e) - def last(e: Expression) = Last(e) - def min(e: Expression) = Min(e) - def max(e: Expression) = Max(e) - def upper(e: Expression) = Upper(e) - def lower(e: Expression) = Lower(e) - def sqrt(e: Expression) = Sqrt(e) - def abs(e: Expression) = Abs(e) - - implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name } + def sum(e: Expression): Expression = Sum(e) + def sumDistinct(e: Expression): Expression = SumDistinct(e) + def count(e: Expression): Expression = Count(e) + def countDistinct(e: Expression*): Expression = CountDistinct(e) + def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression = + ApproxCountDistinct(e, rsd) + def avg(e: Expression): Expression = Average(e) + def first(e: Expression): Expression = First(e) + def last(e: Expression): Expression = Last(e) + def min(e: Expression): Expression = Min(e) + def max(e: Expression): Expression = Max(e) + def upper(e: Expression): Expression = Upper(e) + def lower(e: Expression): Expression = Lower(e) + def sqrt(e: Expression): Expression = Sqrt(e) + def abs(e: Expression): Expression = Abs(e) + + implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name } // TODO more implicit class for literal? implicit class DslString(val s: String) extends ImplicitOperators { override def expr: Expression = Literal(s) - def attr = analysis.UnresolvedAttribute(s) + def attr: UnresolvedAttribute = analysis.UnresolvedAttribute(s) } abstract class ImplicitAttribute extends ImplicitOperators { def s: String - def expr = attr - def attr = analysis.UnresolvedAttribute(s) + def expr: UnresolvedAttribute = attr + def attr: UnresolvedAttribute = analysis.UnresolvedAttribute(s) /** Creates a new AttributeReference of type boolean */ - def boolean = AttributeReference(s, BooleanType, nullable = true)() + def boolean: AttributeReference = AttributeReference(s, BooleanType, nullable = true)() /** Creates a new AttributeReference of type byte */ - def byte = AttributeReference(s, ByteType, nullable = true)() + def byte: AttributeReference = AttributeReference(s, ByteType, nullable = true)() /** Creates a new AttributeReference of type short */ - def short = AttributeReference(s, ShortType, nullable = true)() + def short: AttributeReference = AttributeReference(s, ShortType, nullable = true)() /** Creates a new AttributeReference of type int */ - def int = AttributeReference(s, IntegerType, nullable = true)() + def int: AttributeReference = AttributeReference(s, IntegerType, nullable = true)() /** Creates a new AttributeReference of type long */ - def long = AttributeReference(s, LongType, nullable = true)() + def long: AttributeReference = AttributeReference(s, LongType, nullable = true)() /** Creates a new AttributeReference of type float */ - def float = AttributeReference(s, FloatType, nullable = true)() + def float: AttributeReference = AttributeReference(s, FloatType, nullable = true)() /** Creates a new AttributeReference of type double */ - def double = AttributeReference(s, DoubleType, nullable = true)() + def double: AttributeReference = AttributeReference(s, DoubleType, nullable = true)() /** Creates a new AttributeReference of type string */ - def string = AttributeReference(s, StringType, nullable = true)() + def string: AttributeReference = AttributeReference(s, StringType, nullable = true)() /** Creates a new AttributeReference of type date */ - def date = AttributeReference(s, DateType, nullable = true)() + def date: AttributeReference = AttributeReference(s, DateType, nullable = true)() /** Creates a new AttributeReference of type decimal */ - def decimal = AttributeReference(s, DecimalType.Unlimited, nullable = true)() + def decimal: AttributeReference = + AttributeReference(s, DecimalType.Unlimited, nullable = true)() /** Creates a new AttributeReference of type decimal */ - def decimal(precision: Int, scale: Int) = + def decimal(precision: Int, scale: Int): AttributeReference = AttributeReference(s, DecimalType(precision, scale), nullable = true)() /** Creates a new AttributeReference of type timestamp */ - def timestamp = AttributeReference(s, TimestampType, nullable = true)() + def timestamp: AttributeReference = AttributeReference(s, TimestampType, nullable = true)() /** Creates a new AttributeReference of type binary */ - def binary = AttributeReference(s, BinaryType, nullable = true)() + def binary: AttributeReference = AttributeReference(s, BinaryType, nullable = true)() /** Creates a new AttributeReference of type array */ - def array(dataType: DataType) = AttributeReference(s, ArrayType(dataType), nullable = true)() + def array(dataType: DataType): AttributeReference = + AttributeReference(s, ArrayType(dataType), nullable = true)() /** Creates a new AttributeReference of type map */ def map(keyType: DataType, valueType: DataType): AttributeReference = map(MapType(keyType, valueType)) - def map(mapType: MapType) = AttributeReference(s, mapType, nullable = true)() + + def map(mapType: MapType): AttributeReference = + AttributeReference(s, mapType, nullable = true)() /** Creates a new AttributeReference of type struct */ def struct(fields: StructField*): AttributeReference = struct(StructType(fields)) - def struct(structType: StructType) = AttributeReference(s, structType, nullable = true)() + def struct(structType: StructType): AttributeReference = + AttributeReference(s, structType, nullable = true)() } implicit class DslAttribute(a: AttributeReference) { - def notNull = a.withNullability(false) - def nullable = a.withNullability(true) + def notNull: AttributeReference = a.withNullability(false) + def nullable: AttributeReference = a.withNullability(true) // Protobuf terminology - def required = a.withNullability(false) + def required: AttributeReference = a.withNullability(false) - def at(ordinal: Int) = BoundReference(ordinal, a.dataType, a.nullable) + def at(ordinal: Int): BoundReference = BoundReference(ordinal, a.dataType, a.nullable) } } @@ -241,23 +247,23 @@ package object dsl { abstract class LogicalPlanFunctions { def logicalPlan: LogicalPlan - def select(exprs: NamedExpression*) = Project(exprs, logicalPlan) + def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan) - def where(condition: Expression) = Filter(condition, logicalPlan) + def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan) - def limit(limitExpr: Expression) = Limit(limitExpr, logicalPlan) + def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan) def join( otherPlan: LogicalPlan, joinType: JoinType = Inner, - condition: Option[Expression] = None) = + condition: Option[Expression] = None): LogicalPlan = Join(logicalPlan, otherPlan, joinType, condition) - def orderBy(sortExprs: SortOrder*) = Sort(sortExprs, true, logicalPlan) + def orderBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, true, logicalPlan) - def sortBy(sortExprs: SortOrder*) = Sort(sortExprs, false, logicalPlan) + def sortBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, false, logicalPlan) - def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*) = { + def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): LogicalPlan = { val aliasedExprs = aggregateExprs.map { case ne: NamedExpression => ne case e => Alias(e, e.toString)() @@ -265,41 +271,43 @@ package object dsl { Aggregate(groupingExprs, aliasedExprs, logicalPlan) } - def subquery(alias: Symbol) = Subquery(alias.name, logicalPlan) + def subquery(alias: Symbol): LogicalPlan = Subquery(alias.name, logicalPlan) - def unionAll(otherPlan: LogicalPlan) = Union(logicalPlan, otherPlan) + def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) - def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean) = + def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean): LogicalPlan = Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan) def sample( fraction: Double, withReplacement: Boolean = true, - seed: Int = (math.random * 1000).toInt) = + seed: Int = (math.random * 1000).toInt): LogicalPlan = Sample(fraction, withReplacement, seed, logicalPlan) def generate( generator: Generator, join: Boolean = false, outer: Boolean = false, - alias: Option[String] = None) = + alias: Option[String] = None): LogicalPlan = Generate(generator, join, outer, None, logicalPlan) - def insertInto(tableName: String, overwrite: Boolean = false) = + def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite) - def analyze = analysis.SimpleAnalyzer(logicalPlan) + def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer(logicalPlan)) } object plans { // scalastyle:ignore implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) extends LogicalPlanFunctions { - def writeToFile(path: String) = WriteToFile(path, logicalPlan) + def writeToFile(path: String): LogicalPlan = WriteToFile(path, logicalPlan) } } case class ScalaUdfBuilder[T: TypeTag](f: AnyRef) { - def call(args: Expression*) = ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args) + def call(args: Expression*): ScalaUdf = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args) + } } // scalastyle:off diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala index 82e760b6c6916..96a11e352ec50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -23,7 +23,9 @@ package org.apache.spark.sql.catalyst.expressions * of the name, or the expected nullability). */ object AttributeMap { - def apply[A](kvs: Seq[(Attribute, A)]) = new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap) + def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = { + new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap) + } } class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index adaeab0b5c027..11b4eb5c888be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -19,27 +19,27 @@ package org.apache.spark.sql.catalyst.expressions protected class AttributeEquals(val a: Attribute) { - override def hashCode() = a match { + override def hashCode(): Int = a match { case ar: AttributeReference => ar.exprId.hashCode() case a => a.hashCode() } - override def equals(other: Any) = (a, other.asInstanceOf[AttributeEquals].a) match { + override def equals(other: Any): Boolean = (a, other.asInstanceOf[AttributeEquals].a) match { case (a1: AttributeReference, a2: AttributeReference) => a1.exprId == a2.exprId case (a1, a2) => a1 == a2 } } object AttributeSet { - def apply(a: Attribute) = - new AttributeSet(Set(new AttributeEquals(a))) + def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a))) /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */ - def apply(baseSet: Seq[Expression]) = + def apply(baseSet: Seq[Expression]): AttributeSet = { new AttributeSet( baseSet .flatMap(_.references) .map(new AttributeEquals(_)).toSet) + } } /** @@ -57,8 +57,9 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) extends Traversable[Attribute] with Serializable { /** Returns true if the members of this AttributeSet and other are the same. */ - override def equals(other: Any) = other match { - case otherSet: AttributeSet => baseSet.map(_.a).forall(otherSet.contains) + override def equals(other: Any): Boolean = other match { + case otherSet: AttributeSet => + otherSet.size == baseSet.size && baseSet.map(_.a).forall(otherSet.contains) case _ => false } @@ -81,32 +82,34 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) * Returns true if the [[Attribute Attributes]] in this set are a subset of the Attributes in * `other`. */ - def subsetOf(other: AttributeSet) = baseSet.subsetOf(other.baseSet) + def subsetOf(other: AttributeSet): Boolean = baseSet.subsetOf(other.baseSet) /** * Returns a new [[AttributeSet]] that does not contain any of the [[Attribute Attributes]] found * in `other`. */ - def --(other: Traversable[NamedExpression]) = + def --(other: Traversable[NamedExpression]): AttributeSet = new AttributeSet(baseSet -- other.map(a => new AttributeEquals(a.toAttribute))) /** * Returns a new [[AttributeSet]] that contains all of the [[Attribute Attributes]] found * in `other`. */ - def ++(other: AttributeSet) = new AttributeSet(baseSet ++ other.baseSet) + def ++(other: AttributeSet): AttributeSet = new AttributeSet(baseSet ++ other.baseSet) /** * Returns a new [[AttributeSet]] contain only the [[Attribute Attributes]] where `f` evaluates to * true. */ - override def filter(f: Attribute => Boolean) = new AttributeSet(baseSet.filter(ae => f(ae.a))) + override def filter(f: Attribute => Boolean): AttributeSet = + new AttributeSet(baseSet.filter(ae => f(ae.a))) /** * Returns a new [[AttributeSet]] that only contains [[Attribute Attributes]] that are found in * `this` and `other`. */ - def intersect(other: AttributeSet) = new AttributeSet(baseSet.intersect(other.baseSet)) + def intersect(other: AttributeSet): AttributeSet = + new AttributeSet(baseSet.intersect(other.baseSet)) override def foreach[U](f: (Attribute) => U): Unit = baseSet.map(_.a).foreach(f) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 76a9f08dea85f..2225621dbaabd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -32,7 +32,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) type EvaluatedType = Any - override def toString = s"input[$ordinal]" + override def toString: String = s"input[$ordinal]" override def eval(input: Row): Any = input(ordinal) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index b1bc858478ee1..31f1a5fdc7e53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -29,9 +29,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w override lazy val resolved = childrenResolved && resolve(child.dataType, dataType) - override def foldable = child.foldable + override def foldable: Boolean = child.foldable - override def nullable = forceNullable(child.dataType, dataType) || child.nullable + override def nullable: Boolean = forceNullable(child.dataType, dataType) || child.nullable private[this] def forceNullable(from: DataType, to: DataType) = (from, to) match { case (StringType, _: NumericType) => true @@ -103,7 +103,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } } - override def toString = s"CAST($child, $dataType)" + override def toString: String = s"CAST($child, $dataType)" type EvaluatedType = Any @@ -394,10 +394,17 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w val casts = from.fields.zip(to.fields).map { case (fromField, toField) => cast(fromField.dataType, toField.dataType) } - // TODO: This is very slow! - buildCast[Row](_, row => Row(row.toSeq.zip(casts).map { - case (v, cast) => if (v == null) null else cast(v) - }: _*)) + // TODO: Could be faster? + val newRow = new GenericMutableRow(from.fields.size) + buildCast[Row](_, row => { + var i = 0 + while (i < row.length) { + val v = row(i) + newRow.update(i, if (v == null) null else casts(i)(v)) + i += 1 + } + newRow.copy() + }) } private[this] def cast(from: DataType, to: DataType): Any => Any = to match { @@ -430,14 +437,14 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w object Cast { // `SimpleDateFormat` is not thread-safe. private[sql] val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { - override def initialValue() = { + override def initialValue(): SimpleDateFormat = { new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") } } // `SimpleDateFormat` is not thread-safe. private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] { - override def initialValue() = { + override def initialValue(): SimpleDateFormat = { new SimpleDateFormat("yyyy-MM-dd") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 6ad39b8372cfb..4e3bbc06a5b4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -65,7 +65,7 @@ abstract class Expression extends TreeNode[Expression] { * Returns true if all the children of this expression have been resolved to a specific schema * and false if any still contains any unresolved placeholders. */ - def childrenResolved = !children.exists(!_.resolved) + def childrenResolved: Boolean = !children.exists(!_.resolved) /** * Returns a string representation of this expression that does not have developer centric @@ -84,9 +84,9 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express def symbol: String - override def foldable = left.foldable && right.foldable + override def foldable: Boolean = left.foldable && right.foldable - override def toString = s"($left $symbol $right)" + override def toString: String = s"($left $symbol $right)" } abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { @@ -104,8 +104,8 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio case class GroupExpression(children: Seq[Expression]) extends Expression { self: Product => type EvaluatedType = Seq[Any] - override def eval(input: Row): EvaluatedType = ??? - override def nullable = false - override def foldable = false - override def dataType = ??? + override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException + override def nullable: Boolean = false + override def foldable: Boolean = false + override def dataType: DataType = throw new UnsupportedOperationException } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index db5d897ee569f..c2866cd955409 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -40,7 +40,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { new GenericRow(outputArray) } - override def toString = s"Row => [${exprArray.mkString(",")}]" + override def toString: String = s"Row => [${exprArray.mkString(",")}]" } /** @@ -107,12 +107,12 @@ class JoinedRow extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -142,7 +142,7 @@ class JoinedRow extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -153,7 +153,7 @@ class JoinedRow extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" @@ -207,12 +207,12 @@ class JoinedRow2 extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -242,7 +242,7 @@ class JoinedRow2 extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -253,7 +253,7 @@ class JoinedRow2 extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" @@ -301,12 +301,12 @@ class JoinedRow3 extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -336,7 +336,7 @@ class JoinedRow3 extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -347,7 +347,7 @@ class JoinedRow3 extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" @@ -395,12 +395,12 @@ class JoinedRow4 extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -430,7 +430,7 @@ class JoinedRow4 extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -441,7 +441,7 @@ class JoinedRow4 extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" @@ -489,12 +489,12 @@ class JoinedRow5 extends Row { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length = row1.length + row2.length + override def length: Int = row1.length + row2.length - override def apply(i: Int) = + override def apply(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) - override def isNullAt(i: Int) = + override def isNullAt(i: Int): Boolean = if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) override def getInt(i: Int): Int = @@ -524,7 +524,7 @@ class JoinedRow5 extends Row { override def getAs[T](i: Int): T = if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy() = { + override def copy(): Row = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) var i = 0 @@ -535,7 +535,7 @@ class JoinedRow5 extends Row { new GenericRow(copiedValues) } - override def toString() = { + override def toString: String = { // Make sure toString never throws NullPointerException. if ((row1 eq null) && (row2 eq null)) { "[ empty row ]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala index b2c6d3029031d..f5fea3f015dc4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala @@ -18,16 +18,19 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Random -import org.apache.spark.sql.types.DoubleType + +import org.apache.spark.sql.types.{DataType, DoubleType} case object Rand extends LeafExpression { - override def dataType = DoubleType - override def nullable = false + override def dataType: DataType = DoubleType + override def nullable: Boolean = false private[this] lazy val rand = new Random - override def eval(input: Row = null) = rand.nextDouble().asInstanceOf[EvaluatedType] + override def eval(input: Row = null): EvaluatedType = { + rand.nextDouble().asInstanceOf[EvaluatedType] + } - override def toString = "RAND()" + override def toString: String = "RAND()" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 8a36c6810790d..389dc4f745723 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -29,9 +29,9 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi type EvaluatedType = Any - def nullable = true + override def nullable: Boolean = true - override def toString = s"scalaUDF(${children.mkString(",")})" + override def toString: String = s"scalaUDF(${children.mkString(",")})" // scalastyle:off @@ -39,363 +39,669 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi (1 to 22).map { x => val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _) - val evals = (0 to x - 1).map(x => s" ScalaReflection.convertToScala(children($x).eval(input), children($x).dataType)").reduce(_ + ",\n " + _) - - s""" - case $x => - function.asInstanceOf[($anys) => Any]( - $evals) - """ + val childs = (0 to x - 1).map(x => s"val child$x = children($x)").reduce(_ + "\n " + _) + val evals = (0 to x - 1).map(x => s"ScalaReflection.convertToScala(child$x.eval(input), child$x.dataType)").reduce(_ + ",\n " + _) + + s""" case $x => + val func = function.asInstanceOf[($anys) => Any] + $childs + (input: Row) => { + func( + $evals) + } + """ }.foreach(println) */ - - override def eval(input: Row): Any = { - val result = children.size match { - case 0 => function.asInstanceOf[() => Any]() - case 1 => - function.asInstanceOf[(Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType)) - - - case 2 => - function.asInstanceOf[(Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType)) - - - case 3 => - function.asInstanceOf[(Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType)) - - - case 4 => - function.asInstanceOf[(Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType)) - - - case 5 => - function.asInstanceOf[(Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType)) - - - case 6 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType)) - - - case 7 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType)) - - - case 8 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType)) - - - case 9 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType)) - - - case 10 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType)) - - - case 11 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType)) - - - case 12 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType)) - - - case 13 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType)) - - - case 14 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType)) - - - case 15 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType)) - - - case 16 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType)) - - - case 17 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), - ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType)) - - - case 18 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), - ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), - ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType)) - - - case 19 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), - ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), - ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), - ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType)) - - - case 20 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), - ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), - ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), - ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), - ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType)) - - - case 21 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), - ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), - ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), - ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), - ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType), - ScalaReflection.convertToScala(children(20).eval(input), children(20).dataType)) - - - case 22 => - function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), - ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), - ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), - ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), - ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), - ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), - ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), - ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), - ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), - ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), - ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), - ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), - ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), - ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), - ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), - ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), - ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), - ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), - ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), - ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType), - ScalaReflection.convertToScala(children(20).eval(input), children(20).dataType), - ScalaReflection.convertToScala(children(21).eval(input), children(21).dataType)) - - } - // scalastyle:on - - ScalaReflection.convertToCatalyst(result, dataType) + + val f = children.size match { + case 0 => + val func = function.asInstanceOf[() => Any] + (input: Row) => { + func() + } + + case 1 => + val func = function.asInstanceOf[(Any) => Any] + val child0 = children(0) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType)) + } + + case 2 => + val func = function.asInstanceOf[(Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType)) + } + + case 3 => + val func = function.asInstanceOf[(Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType)) + } + + case 4 => + val func = function.asInstanceOf[(Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType)) + } + + case 5 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType)) + } + + case 6 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType)) + } + + case 7 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType)) + } + + case 8 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType)) + } + + case 9 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType)) + } + + case 10 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType)) + } + + case 11 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType)) + } + + case 12 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType)) + } + + case 13 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType)) + } + + case 14 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType)) + } + + case 15 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType), + ScalaReflection.convertToScala(child14.eval(input), child14.dataType)) + } + + case 16 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType), + ScalaReflection.convertToScala(child14.eval(input), child14.dataType), + ScalaReflection.convertToScala(child15.eval(input), child15.dataType)) + } + + case 17 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + val child16 = children(16) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType), + ScalaReflection.convertToScala(child14.eval(input), child14.dataType), + ScalaReflection.convertToScala(child15.eval(input), child15.dataType), + ScalaReflection.convertToScala(child16.eval(input), child16.dataType)) + } + + case 18 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + val child16 = children(16) + val child17 = children(17) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType), + ScalaReflection.convertToScala(child14.eval(input), child14.dataType), + ScalaReflection.convertToScala(child15.eval(input), child15.dataType), + ScalaReflection.convertToScala(child16.eval(input), child16.dataType), + ScalaReflection.convertToScala(child17.eval(input), child17.dataType)) + } + + case 19 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + val child16 = children(16) + val child17 = children(17) + val child18 = children(18) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType), + ScalaReflection.convertToScala(child14.eval(input), child14.dataType), + ScalaReflection.convertToScala(child15.eval(input), child15.dataType), + ScalaReflection.convertToScala(child16.eval(input), child16.dataType), + ScalaReflection.convertToScala(child17.eval(input), child17.dataType), + ScalaReflection.convertToScala(child18.eval(input), child18.dataType)) + } + + case 20 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + val child16 = children(16) + val child17 = children(17) + val child18 = children(18) + val child19 = children(19) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType), + ScalaReflection.convertToScala(child14.eval(input), child14.dataType), + ScalaReflection.convertToScala(child15.eval(input), child15.dataType), + ScalaReflection.convertToScala(child16.eval(input), child16.dataType), + ScalaReflection.convertToScala(child17.eval(input), child17.dataType), + ScalaReflection.convertToScala(child18.eval(input), child18.dataType), + ScalaReflection.convertToScala(child19.eval(input), child19.dataType)) + } + + case 21 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + val child16 = children(16) + val child17 = children(17) + val child18 = children(18) + val child19 = children(19) + val child20 = children(20) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType), + ScalaReflection.convertToScala(child14.eval(input), child14.dataType), + ScalaReflection.convertToScala(child15.eval(input), child15.dataType), + ScalaReflection.convertToScala(child16.eval(input), child16.dataType), + ScalaReflection.convertToScala(child17.eval(input), child17.dataType), + ScalaReflection.convertToScala(child18.eval(input), child18.dataType), + ScalaReflection.convertToScala(child19.eval(input), child19.dataType), + ScalaReflection.convertToScala(child20.eval(input), child20.dataType)) + } + + case 22 => + val func = function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any] + val child0 = children(0) + val child1 = children(1) + val child2 = children(2) + val child3 = children(3) + val child4 = children(4) + val child5 = children(5) + val child6 = children(6) + val child7 = children(7) + val child8 = children(8) + val child9 = children(9) + val child10 = children(10) + val child11 = children(11) + val child12 = children(12) + val child13 = children(13) + val child14 = children(14) + val child15 = children(15) + val child16 = children(16) + val child17 = children(17) + val child18 = children(18) + val child19 = children(19) + val child20 = children(20) + val child21 = children(21) + (input: Row) => { + func( + ScalaReflection.convertToScala(child0.eval(input), child0.dataType), + ScalaReflection.convertToScala(child1.eval(input), child1.dataType), + ScalaReflection.convertToScala(child2.eval(input), child2.dataType), + ScalaReflection.convertToScala(child3.eval(input), child3.dataType), + ScalaReflection.convertToScala(child4.eval(input), child4.dataType), + ScalaReflection.convertToScala(child5.eval(input), child5.dataType), + ScalaReflection.convertToScala(child6.eval(input), child6.dataType), + ScalaReflection.convertToScala(child7.eval(input), child7.dataType), + ScalaReflection.convertToScala(child8.eval(input), child8.dataType), + ScalaReflection.convertToScala(child9.eval(input), child9.dataType), + ScalaReflection.convertToScala(child10.eval(input), child10.dataType), + ScalaReflection.convertToScala(child11.eval(input), child11.dataType), + ScalaReflection.convertToScala(child12.eval(input), child12.dataType), + ScalaReflection.convertToScala(child13.eval(input), child13.dataType), + ScalaReflection.convertToScala(child14.eval(input), child14.dataType), + ScalaReflection.convertToScala(child15.eval(input), child15.dataType), + ScalaReflection.convertToScala(child16.eval(input), child16.dataType), + ScalaReflection.convertToScala(child17.eval(input), child17.dataType), + ScalaReflection.convertToScala(child18.eval(input), child18.dataType), + ScalaReflection.convertToScala(child19.eval(input), child19.dataType), + ScalaReflection.convertToScala(child20.eval(input), child20.dataType), + ScalaReflection.convertToScala(child21.eval(input), child21.dataType)) + } } + + // scalastyle:on + + override def eval(input: Row): Any = ScalaReflection.convertToCatalyst(f(input), dataType) + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index d00b2ac09745c..83074eb1e6310 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.types.DataType abstract sealed class SortDirection case object Ascending extends SortDirection @@ -31,12 +32,12 @@ case object Descending extends SortDirection case class SortOrder(child: Expression, direction: SortDirection) extends Expression with trees.UnaryNode[Expression] { - override def dataType = child.dataType - override def nullable = child.nullable + override def dataType: DataType = child.dataType + override def nullable: Boolean = child.nullable // SortOrder itself is never evaluated. override def eval(input: Row = null): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" + override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 21d714c9a8c3b..47b6f358ed1b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -62,126 +62,126 @@ abstract class MutableValue extends Serializable { var isNull: Boolean = true def boxed: Any def update(v: Any) - def copy(): this.type + def copy(): MutableValue } final class MutableInt extends MutableValue { var value: Int = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Int] + value = v.asInstanceOf[Int] } - def copy() = { + override def copy(): MutableInt = { val newCopy = new MutableInt newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableInt] } } final class MutableFloat extends MutableValue { var value: Float = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Float] + value = v.asInstanceOf[Float] } - def copy() = { + override def copy(): MutableFloat = { val newCopy = new MutableFloat newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableFloat] } } final class MutableBoolean extends MutableValue { var value: Boolean = false - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Boolean] + value = v.asInstanceOf[Boolean] } - def copy() = { + override def copy(): MutableBoolean = { val newCopy = new MutableBoolean newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableBoolean] } } final class MutableDouble extends MutableValue { var value: Double = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Double] + value = v.asInstanceOf[Double] } - def copy() = { + override def copy(): MutableDouble = { val newCopy = new MutableDouble newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableDouble] } } final class MutableShort extends MutableValue { var value: Short = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = value = { isNull = false v.asInstanceOf[Short] } - def copy() = { + override def copy(): MutableShort = { val newCopy = new MutableShort newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableShort] } } final class MutableLong extends MutableValue { var value: Long = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = value = { isNull = false v.asInstanceOf[Long] } - def copy() = { + override def copy(): MutableLong = { val newCopy = new MutableLong newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableLong] } } final class MutableByte extends MutableValue { var value: Byte = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = value = { isNull = false v.asInstanceOf[Byte] } - def copy() = { + override def copy(): MutableByte = { val newCopy = new MutableByte newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableByte] } } final class MutableAny extends MutableValue { var value: Any = _ - def boxed = if (isNull) null else value - def update(v: Any) = value = { + override def boxed: Any = if (isNull) null else value + override def update(v: Any): Unit = { isNull = false - v.asInstanceOf[Any] + value = v.asInstanceOf[Any] } - def copy() = { + override def copy(): MutableAny = { val newCopy = new MutableAny newCopy.isNull = isNull newCopy.value = value - newCopy.asInstanceOf[this.type] + newCopy.asInstanceOf[MutableAny] } } @@ -234,9 +234,9 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR if (value == null) setNullAt(ordinal) else values(ordinal).update(value) } - override def setString(ordinal: Int, value: String) = update(ordinal, value) + override def setString(ordinal: Int, value: String): Unit = update(ordinal, value) - override def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String] + override def getString(ordinal: Int): String = apply(ordinal).asInstanceOf[String] override def setInt(ordinal: Int, value: Int): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableInt] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 5297d1e31246c..30da4faa3f1c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -79,27 +79,29 @@ abstract class AggregateFunction /** Base should return the generic aggregate expression that this function is computing */ val base: AggregateExpression - override def nullable = base.nullable - override def dataType = base.dataType + override def nullable: Boolean = base.nullable + override def dataType: DataType = base.dataType def update(input: Row): Unit // Do we really need this? - override def newInstance() = makeCopy(productIterator.map { case a: AnyRef => a }.toArray) + override def newInstance(): AggregateFunction = { + makeCopy(productIterator.map { case a: AnyRef => a }.toArray) + } } case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true - override def dataType = child.dataType - override def toString = s"MIN($child)" + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"MIN($child)" override def asPartial: SplitEvaluation = { val partialMin = Alias(Min(child), "PartialMin")() SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil) } - override def newInstance() = new MinFunction(child, this) + override def newInstance(): MinFunction = new MinFunction(child, this) } case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -121,16 +123,16 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true - override def dataType = child.dataType - override def toString = s"MAX($child)" + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"MAX($child)" override def asPartial: SplitEvaluation = { val partialMax = Alias(Max(child), "PartialMax")() SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil) } - override def newInstance() = new MaxFunction(child, this) + override def newInstance(): MaxFunction = new MaxFunction(child, this) } case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -152,29 +154,29 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = false - override def dataType = LongType - override def toString = s"COUNT($child)" + override def nullable: Boolean = false + override def dataType: LongType.type = LongType + override def toString: String = s"COUNT($child)" override def asPartial: SplitEvaluation = { val partialCount = Alias(Count(child), "PartialCount")() SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), partialCount :: Nil) } - override def newInstance() = new CountFunction(child, this) + override def newInstance(): CountFunction = new CountFunction(child, this) } case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate { def this() = this(null) - override def children = expressions + override def children: Seq[Expression] = expressions - override def nullable = false - override def dataType = LongType - override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})" - override def newInstance() = new CountDistinctFunction(expressions, this) + override def nullable: Boolean = false + override def dataType: DataType = LongType + override def toString: String = s"COUNT(DISTINCT ${expressions.mkString(",")})" + override def newInstance(): CountDistinctFunction = new CountDistinctFunction(expressions, this) - override def asPartial = { + override def asPartial: SplitEvaluation = { val partialSet = Alias(CollectHashSet(expressions), "partialSets")() SplitEvaluation( CombineSetsAndCount(partialSet.toAttribute), @@ -185,11 +187,11 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression { def this() = this(null) - override def children = expressions - override def nullable = false - override def dataType = ArrayType(expressions.head.dataType) - override def toString = s"AddToHashSet(${expressions.mkString(",")})" - override def newInstance() = new CollectHashSetFunction(expressions, this) + override def children: Seq[Expression] = expressions + override def nullable: Boolean = false + override def dataType: ArrayType = ArrayType(expressions.head.dataType) + override def toString: String = s"AddToHashSet(${expressions.mkString(",")})" + override def newInstance(): CollectHashSetFunction = new CollectHashSetFunction(expressions, this) } case class CollectHashSetFunction( @@ -219,11 +221,13 @@ case class CollectHashSetFunction( case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression { def this() = this(null) - override def children = inputSet :: Nil - override def nullable = false - override def dataType = LongType - override def toString = s"CombineAndCount($inputSet)" - override def newInstance() = new CombineSetsAndCountFunction(inputSet, this) + override def children: Seq[Expression] = inputSet :: Nil + override def nullable: Boolean = false + override def dataType: DataType = LongType + override def toString: String = s"CombineAndCount($inputSet)" + override def newInstance(): CombineSetsAndCountFunction = { + new CombineSetsAndCountFunction(inputSet, this) + } } case class CombineSetsAndCountFunction( @@ -249,27 +253,31 @@ case class CombineSetsAndCountFunction( case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) extends AggregateExpression with trees.UnaryNode[Expression] { - override def nullable = false - override def dataType = child.dataType - override def toString = s"APPROXIMATE COUNT(DISTINCT $child)" - override def newInstance() = new ApproxCountDistinctPartitionFunction(child, this, relativeSD) + override def nullable: Boolean = false + override def dataType: DataType = child.dataType + override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" + override def newInstance(): ApproxCountDistinctPartitionFunction = { + new ApproxCountDistinctPartitionFunction(child, this, relativeSD) + } } case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) extends AggregateExpression with trees.UnaryNode[Expression] { - override def nullable = false - override def dataType = LongType - override def toString = s"APPROXIMATE COUNT(DISTINCT $child)" - override def newInstance() = new ApproxCountDistinctMergeFunction(child, this, relativeSD) + override def nullable: Boolean = false + override def dataType: LongType.type = LongType + override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" + override def newInstance(): ApproxCountDistinctMergeFunction = { + new ApproxCountDistinctMergeFunction(child, this, relativeSD) + } } case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = false - override def dataType = LongType - override def toString = s"APPROXIMATE COUNT(DISTINCT $child)" + override def nullable: Boolean = false + override def dataType: LongType.type = LongType + override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" override def asPartial: SplitEvaluation = { val partialCount = @@ -280,14 +288,14 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) partialCount :: Nil) } - override def newInstance() = new CountDistinctFunction(child :: Nil, this) + override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this) } case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true + override def nullable: Boolean = true - override def dataType = child.dataType match { + override def dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType(precision + 4, scale + 4) // Add 4 digits after decimal point, like Hive case DecimalType.Unlimited => @@ -296,7 +304,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN DoubleType } - override def toString = s"AVG($child)" + override def toString: String = s"AVG($child)" override def asPartial: SplitEvaluation = { child.dataType match { @@ -323,14 +331,14 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN } } - override def newInstance() = new AverageFunction(child, this) + override def newInstance(): AverageFunction = new AverageFunction(child, this) } case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true + override def nullable: Boolean = true - override def dataType = child.dataType match { + override def dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive case DecimalType.Unlimited => @@ -339,7 +347,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ child.dataType } - override def toString = s"SUM($child)" + override def toString: String = s"SUM($child)" override def asPartial: SplitEvaluation = { child.dataType match { @@ -357,7 +365,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ } } - override def newInstance() = new SumFunction(child, this) + override def newInstance(): SumFunction = new SumFunction(child, this) } /** @@ -377,19 +385,19 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ case class CombineSum(child: Expression) extends AggregateExpression { def this() = this(null) - override def children = child :: Nil - override def nullable = true - override def dataType = child.dataType - override def toString = s"CombineSum($child)" - override def newInstance() = new CombineSumFunction(child, this) + override def children: Seq[Expression] = child :: Nil + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"CombineSum($child)" + override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this) } case class SumDistinct(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { def this() = this(null) - override def nullable = true - override def dataType = child.dataType match { + override def nullable: Boolean = true + override def dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive case DecimalType.Unlimited => @@ -397,10 +405,10 @@ case class SumDistinct(child: Expression) case _ => child.dataType } - override def toString = s"SUM(DISTINCT ${child})" - override def newInstance() = new SumDistinctFunction(child, this) + override def toString: String = s"SUM(DISTINCT $child)" + override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this) - override def asPartial = { + override def asPartial: SplitEvaluation = { val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")() SplitEvaluation( CombineSetsAndSum(partialSet.toAttribute, this), @@ -411,11 +419,13 @@ case class SumDistinct(child: Expression) case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression { def this() = this(null, null) - override def children = inputSet :: Nil - override def nullable = true - override def dataType = base.dataType - override def toString = s"CombineAndSum($inputSet)" - override def newInstance() = new CombineSetsAndSumFunction(inputSet, this) + override def children: Seq[Expression] = inputSet :: Nil + override def nullable: Boolean = true + override def dataType: DataType = base.dataType + override def toString: String = s"CombineAndSum($inputSet)" + override def newInstance(): CombineSetsAndSumFunction = { + new CombineSetsAndSumFunction(inputSet, this) + } } case class CombineSetsAndSumFunction( @@ -449,9 +459,9 @@ case class CombineSetsAndSumFunction( } case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = true - override def dataType = child.dataType - override def toString = s"FIRST($child)" + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"FIRST($child)" override def asPartial: SplitEvaluation = { val partialFirst = Alias(First(child), "PartialFirst")() @@ -459,14 +469,14 @@ case class First(child: Expression) extends PartialAggregate with trees.UnaryNod First(partialFirst.toAttribute), partialFirst :: Nil) } - override def newInstance() = new FirstFunction(child, this) + override def newInstance(): FirstFunction = new FirstFunction(child, this) } case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references = child.references - override def nullable = true - override def dataType = child.dataType - override def toString = s"LAST($child)" + override def references: AttributeSet = child.references + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"LAST($child)" override def asPartial: SplitEvaluation = { val partialLast = Alias(Last(child), "PartialLast")() @@ -474,7 +484,7 @@ case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode Last(partialLast.toAttribute), partialLast :: Nil) } - override def newInstance() = new LastFunction(child, this) + override def newInstance(): LastFunction = new LastFunction(child, this) } case class AverageFunction(expr: Expression, base: AggregateExpression) @@ -713,6 +723,7 @@ case class LastFunction(expr: Expression, base: AggregateExpression) extends Agg result = input } - override def eval(input: Row): Any = if (result != null) expr.eval(result.asInstanceOf[Row]) - else null + override def eval(input: Row): Any = { + if (result != null) expr.eval(result.asInstanceOf[Row]) else null + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 00b0d3c683fe2..1f6526ef66c56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -24,10 +24,10 @@ import org.apache.spark.sql.types._ case class UnaryMinus(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def dataType = child.dataType - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"-$child" + override def dataType: DataType = child.dataType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"-$child" lazy val numeric = dataType match { case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] @@ -47,10 +47,10 @@ case class UnaryMinus(child: Expression) extends UnaryExpression { case class Sqrt(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def dataType = DoubleType - override def foldable = child.foldable - def nullable = true - override def toString = s"SQRT($child)" + override def dataType: DataType = DoubleType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = true + override def toString: String = s"SQRT($child)" lazy val numeric = child.dataType match { case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] @@ -74,14 +74,14 @@ abstract class BinaryArithmetic extends BinaryExpression { type EvaluatedType = Any - def nullable = left.nullable || right.nullable + def nullable: Boolean = left.nullable || right.nullable override lazy val resolved = left.resolved && right.resolved && left.dataType == right.dataType && !DecimalType.isFixed(left.dataType) - def dataType = { + def dataType: DataType = { if (!resolved) { throw new UnresolvedException(this, s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") @@ -108,7 +108,7 @@ abstract class BinaryArithmetic extends BinaryExpression { } case class Add(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "+" + override def symbol: String = "+" lazy val numeric = dataType match { case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] @@ -131,7 +131,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { } case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "-" + override def symbol: String = "-" lazy val numeric = dataType match { case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] @@ -154,7 +154,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti } case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "*" + override def symbol: String = "*" lazy val numeric = dataType match { case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] @@ -177,9 +177,9 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti } case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "/" + override def symbol: String = "/" - override def nullable = true + override def nullable: Boolean = true lazy val div: (Any, Any) => Any = dataType match { case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div @@ -203,9 +203,9 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "%" + override def symbol: String = "%" - override def nullable = true + override def nullable: Boolean = true lazy val integral = dataType match { case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] @@ -232,7 +232,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet * A function that calculates bitwise and(&) of two numbers. */ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "&" + override def symbol: String = "&" lazy val and: (Any, Any) => Any = dataType match { case ByteType => @@ -253,7 +253,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme * A function that calculates bitwise or(|) of two numbers. */ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "|" + override def symbol: String = "|" lazy val or: (Any, Any) => Any = dataType match { case ByteType => @@ -274,7 +274,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet * A function that calculates bitwise xor(^) of two numbers. */ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "^" + override def symbol: String = "^" lazy val xor: (Any, Any) => Any = dataType match { case ByteType => @@ -297,10 +297,10 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme case class BitwiseNot(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def dataType = child.dataType - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"~$child" + override def dataType: DataType = child.dataType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"~$child" lazy val not: (Any) => Any = dataType match { case ByteType => @@ -327,17 +327,17 @@ case class BitwiseNot(child: Expression) extends UnaryExpression { case class MaxOf(left: Expression, right: Expression) extends Expression { type EvaluatedType = Any - override def foldable = left.foldable && right.foldable + override def foldable: Boolean = left.foldable && right.foldable - override def nullable = left.nullable && right.nullable + override def nullable: Boolean = left.nullable && right.nullable - override def children = left :: right :: Nil + override def children: Seq[Expression] = left :: right :: Nil override lazy val resolved = left.resolved && right.resolved && left.dataType == right.dataType - override def dataType = { + override def dataType: DataType = { if (!resolved) { throw new UnresolvedException(this, s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") @@ -366,7 +366,7 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { } } - override def toString = s"MaxOf($left, $right)" + override def toString: String = s"MaxOf($left, $right)" } /** @@ -375,10 +375,10 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { case class Abs(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def dataType = child.dataType - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"Abs($child)" + override def dataType: DataType = child.dataType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"Abs($child)" lazy val numeric = dataType match { case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e48b8cde20eda..d1abf3c0b64a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -91,7 +91,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val startTime = System.nanoTime() val result = create(in) val endTime = System.nanoTime() - def timeMs = (endTime - startTime).toDouble / 1000000 + def timeMs: Double = (endTime - startTime).toDouble / 1000000 logInfo(s"Code generated expression $in in $timeMs ms") result } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 68051a2a2007e..3fd78db297462 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -27,12 +27,12 @@ import org.apache.spark.sql.types._ case class GetItem(child: Expression, ordinal: Expression) extends Expression { type EvaluatedType = Any - val children = child :: ordinal :: Nil + val children: Seq[Expression] = child :: ordinal :: Nil /** `Null` is returned for invalid ordinals. */ - override def nullable = true - override def foldable = child.foldable && ordinal.foldable + override def nullable: Boolean = true + override def foldable: Boolean = child.foldable && ordinal.foldable - def dataType = child.dataType match { + override def dataType: DataType = child.dataType match { case ArrayType(dt, _) => dt case MapType(_, vt, _) => vt } @@ -40,7 +40,7 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { childrenResolved && (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) - override def toString = s"$child[$ordinal]" + override def toString: String = s"$child[$ordinal]" override def eval(input: Row): Any = { val value = child.eval(input) @@ -75,8 +75,8 @@ trait GetField extends UnaryExpression { self: Product => type EvaluatedType = Any - override def foldable = child.foldable - override def toString = s"$child.${field.name}" + override def foldable: Boolean = child.foldable + override def toString: String = s"$child.${field.name}" def field: StructField } @@ -86,8 +86,8 @@ trait GetField extends UnaryExpression { */ case class StructGetField(child: Expression, field: StructField, ordinal: Int) extends GetField { - def dataType = field.dataType - override def nullable = child.nullable || field.nullable + override def dataType: DataType = field.dataType + override def nullable: Boolean = child.nullable || field.nullable override def eval(input: Row): Any = { val baseValue = child.eval(input).asInstanceOf[Row] @@ -101,8 +101,8 @@ case class StructGetField(child: Expression, field: StructField, ordinal: Int) e case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, containsNull: Boolean) extends GetField { - def dataType = ArrayType(field.dataType, containsNull) - override def nullable = child.nullable + override def dataType: DataType = ArrayType(field.dataType, containsNull) + override def nullable: Boolean = child.nullable override def eval(input: Row): Any = { val baseValue = child.eval(input).asInstanceOf[Seq[Row]] @@ -120,7 +120,7 @@ case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, co case class CreateArray(children: Seq[Expression]) extends Expression { override type EvaluatedType = Any - override def foldable = !children.exists(!_.foldable) + override def foldable: Boolean = !children.exists(!_.foldable) lazy val childTypes = children.map(_.dataType).distinct @@ -140,5 +140,5 @@ case class CreateArray(children: Seq[Expression]) extends Expression { children.map(_.eval(input)) } - override def toString = s"Array(${children.mkString(",")})" + override def toString: String = s"Array(${children.mkString(",")})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 83d8c1d42bca4..adb94df7d1c7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -24,9 +24,9 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { override type EvaluatedType = Any override def dataType: DataType = LongType - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"UnscaledValue($child)" + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"UnscaledValue($child)" override def eval(input: Row): Any = { val childResult = child.eval(input) @@ -43,9 +43,9 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un override type EvaluatedType = Decimal override def dataType: DataType = DecimalType(precision, scale) - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"MakeDecimal($child,$precision,$scale)" + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"MakeDecimal($child,$precision,$scale)" override def eval(input: Row): Decimal = { val childResult = child.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 0983d274def3f..860b72fad38b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -45,7 +45,7 @@ abstract class Generator extends Expression { override lazy val dataType = ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))) - override def nullable = false + override def nullable: Boolean = false /** * Should be overridden by specific generators. Called only once for each instance to ensure @@ -89,7 +89,7 @@ case class UserDefinedGenerator( function(inputRow(input)) } - override def toString = s"UserDefinedGenerator(${children.mkString(",")})" + override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})" } /** @@ -130,5 +130,5 @@ case class Explode(attributeNames: Seq[String], child: Expression) } } - override def toString() = s"explode($child)" + override def toString: String = s"explode($child)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 9ff66563c8164..19f3fc9c2291a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -64,14 +64,13 @@ object IntegerLiteral { case class Literal(value: Any, dataType: DataType) extends LeafExpression { - override def foldable = true - def nullable = value == null + override def foldable: Boolean = true + override def nullable: Boolean = value == null - - override def toString = if (value != null) value.toString else "null" + override def toString: String = if (value != null) value.toString else "null" type EvaluatedType = Any - override def eval(input: Row):Any = value + override def eval(input: Row): Any = value } // TODO: Specialize @@ -79,9 +78,9 @@ case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean extends LeafExpression { type EvaluatedType = Any - def update(expression: Expression, input: Row) = { + def update(expression: Expression, input: Row): Unit = { value = expression.eval(input) } - override def eval(input: Row) = value + override def eval(input: Row): Any = value } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 08361d043b6ed..bcbcbeb31c7b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.trees.LeafNode import org.apache.spark.sql.types._ object NamedExpression { private val curId = new java.util.concurrent.atomic.AtomicLong() - def newExprId = ExprId(curId.getAndIncrement()) + def newExprId: ExprId = ExprId(curId.getAndIncrement()) def unapply(expr: NamedExpression): Option[(String, DataType)] = Some(expr.name, expr.dataType) } @@ -79,13 +80,13 @@ abstract class NamedExpression extends Expression { abstract class Attribute extends NamedExpression { self: Product => - override def references = AttributeSet(this) + override def references: AttributeSet = AttributeSet(this) def withNullability(newNullability: Boolean): Attribute def withQualifiers(newQualifiers: Seq[String]): Attribute def withName(newName: String): Attribute - def toAttribute = this + def toAttribute: Attribute = this def newInstance(): Attribute } @@ -112,10 +113,10 @@ case class Alias(child: Expression, name: String)( override type EvaluatedType = Any - override def eval(input: Row) = child.eval(input) + override def eval(input: Row): Any = child.eval(input) - override def dataType = child.dataType - override def nullable = child.nullable + override def dataType: DataType = child.dataType + override def nullable: Boolean = child.nullable override def metadata: Metadata = { explicitMetadata.getOrElse { child match { @@ -125,7 +126,7 @@ case class Alias(child: Expression, name: String)( } } - override def toAttribute = { + override def toAttribute: Attribute = { if (resolved) { AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifiers) } else { @@ -135,7 +136,9 @@ case class Alias(child: Expression, name: String)( override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix" - override protected final def otherCopyArgs = exprId :: qualifiers :: explicitMetadata :: Nil + override protected final def otherCopyArgs: Seq[AnyRef] = { + exprId :: qualifiers :: explicitMetadata :: Nil + } override def equals(other: Any): Boolean = other match { case a: Alias => @@ -166,7 +169,7 @@ case class AttributeReference( val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] { - override def equals(other: Any) = other match { + override def equals(other: Any): Boolean = other match { case ar: AttributeReference => name == ar.name && exprId == ar.exprId && dataType == ar.dataType case _ => false } @@ -180,7 +183,7 @@ case class AttributeReference( h } - override def newInstance() = + override def newInstance(): AttributeReference = AttributeReference(name, dataType, nullable, metadata)(qualifiers = qualifiers) /** @@ -205,7 +208,7 @@ case class AttributeReference( /** * Returns a copy of this [[AttributeReference]] with new qualifiers. */ - override def withQualifiers(newQualifiers: Seq[String]) = { + override def withQualifiers(newQualifiers: Seq[String]): AttributeReference = { if (newQualifiers.toSet == qualifiers.toSet) { this } else { @@ -227,20 +230,22 @@ case class AttributeReference( case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[Expression] { type EvaluatedType = Any - override def toString = name - - override def withNullability(newNullability: Boolean): Attribute = ??? - override def newInstance(): Attribute = ??? - override def withQualifiers(newQualifiers: Seq[String]): Attribute = ??? - override def withName(newName: String): Attribute = ??? - override def qualifiers: Seq[String] = ??? - override def exprId: ExprId = ??? - override def eval(input: Row): EvaluatedType = ??? - override def nullable: Boolean = ??? + override def toString: String = name + + override def withNullability(newNullability: Boolean): Attribute = + throw new UnsupportedOperationException + override def newInstance(): Attribute = throw new UnsupportedOperationException + override def withQualifiers(newQualifiers: Seq[String]): Attribute = + throw new UnsupportedOperationException + override def withName(newName: String): Attribute = throw new UnsupportedOperationException + override def qualifiers: Seq[String] = throw new UnsupportedOperationException + override def exprId: ExprId = throw new UnsupportedOperationException + override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException + override def nullable: Boolean = throw new UnsupportedOperationException override def dataType: DataType = NullType } object VirtualColumn { - val groupingIdName = "grouping__id" - def newGroupingId = AttributeReference(groupingIdName, IntegerType, false)() + val groupingIdName: String = "grouping__id" + def newGroupingId: AttributeReference = AttributeReference(groupingIdName, IntegerType, false)() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 08b982bc671e7..d1f3d4f4ee9ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -19,22 +19,23 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.types.DataType case class Coalesce(children: Seq[Expression]) extends Expression { type EvaluatedType = Any /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ - def nullable = !children.exists(!_.nullable) + override def nullable: Boolean = !children.exists(!_.nullable) // Coalesce is foldable if all children are foldable. - override def foldable = !children.exists(!_.foldable) + override def foldable: Boolean = !children.exists(!_.foldable) // Only resolved if all the children are of the same type. override lazy val resolved = childrenResolved && (children.map(_.dataType).distinct.size == 1) - override def toString = s"Coalesce(${children.mkString(",")})" + override def toString: String = s"Coalesce(${children.mkString(",")})" - def dataType = if (resolved) { + def dataType: DataType = if (resolved) { children.head.dataType } else { val childTypes = children.map(c => s"$c: ${c.dataType}").mkString(", ") @@ -54,20 +55,20 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { - override def foldable = child.foldable - def nullable = false + override def foldable: Boolean = child.foldable + override def nullable: Boolean = false override def eval(input: Row): Any = { child.eval(input) == null } - override def toString = s"IS NULL $child" + override def toString: String = s"IS NULL $child" } case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { - override def foldable = child.foldable - def nullable = false - override def toString = s"IS NOT NULL $child" + override def foldable: Boolean = child.foldable + override def nullable: Boolean = false + override def toString: String = s"IS NOT NULL $child" override def eval(input: Row): Any = { child.eval(input) != null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 0024ef92c0452..7e47cb3fffe12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.{BinaryType, BooleanType, NativeType} +import org.apache.spark.sql.types.{DataType, BinaryType, BooleanType, NativeType} object InterpretedPredicate { def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = @@ -34,7 +34,7 @@ object InterpretedPredicate { trait Predicate extends Expression { self: Product => - def dataType = BooleanType + override def dataType: DataType = BooleanType type EvaluatedType = Any } @@ -72,13 +72,13 @@ trait PredicateHelper { abstract class BinaryPredicate extends BinaryExpression with Predicate { self: Product => - def nullable = left.nullable || right.nullable + override def nullable: Boolean = left.nullable || right.nullable } case class Not(child: Expression) extends UnaryExpression with Predicate { - override def foldable = child.foldable - def nullable = child.nullable - override def toString = s"NOT $child" + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + override def toString: String = s"NOT $child" override def eval(input: Row): Any = { child.eval(input) match { @@ -92,10 +92,10 @@ case class Not(child: Expression) extends UnaryExpression with Predicate { * Evaluates to `true` if `list` contains `value`. */ case class In(value: Expression, list: Seq[Expression]) extends Predicate { - def children = value +: list + override def children: Seq[Expression] = value +: list - def nullable = true // TODO: Figure out correct nullability semantics of IN. - override def toString = s"$value IN ${list.mkString("(", ",", ")")}" + override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. + override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}" override def eval(input: Row): Any = { val evaluatedValue = value.eval(input) @@ -110,10 +110,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { case class InSet(value: Expression, hset: Set[Any]) extends Predicate { - def children = value :: Nil + override def children: Seq[Expression] = value :: Nil - def nullable = true // TODO: Figure out correct nullability semantics of IN. - override def toString = s"$value INSET ${hset.mkString("(", ",", ")")}" + override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. + override def toString: String = s"$value INSET ${hset.mkString("(", ",", ")")}" override def eval(input: Row): Any = { hset.contains(value.eval(input)) @@ -121,7 +121,7 @@ case class InSet(value: Expression, hset: Set[Any]) } case class And(left: Expression, right: Expression) extends BinaryPredicate { - def symbol = "&&" + override def symbol: String = "&&" override def eval(input: Row): Any = { val l = left.eval(input) @@ -143,7 +143,7 @@ case class And(left: Expression, right: Expression) extends BinaryPredicate { } case class Or(left: Expression, right: Expression) extends BinaryPredicate { - def symbol = "||" + override def symbol: String = "||" override def eval(input: Row): Any = { val l = left.eval(input) @@ -169,7 +169,8 @@ abstract class BinaryComparison extends BinaryPredicate { } case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { - def symbol = "=" + override def symbol: String = "=" + override def eval(input: Row): Any = { val l = left.eval(input) if (l == null) { @@ -185,8 +186,10 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison } case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { - def symbol = "<=>" - override def nullable = false + override def symbol: String = "<=>" + + override def nullable: Boolean = false + override def eval(input: Row): Any = { val l = left.eval(input) val r = right.eval(input) @@ -201,9 +204,9 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } case class LessThan(left: Expression, right: Expression) extends BinaryComparison { - def symbol = "<" + override def symbol: String = "<" - lazy val ordering = { + lazy val ordering: Ordering[Any] = { if (left.dataType != right.dataType) { throw new TreeNodeException(this, s"Types do not match ${left.dataType} != ${right.dataType}") @@ -216,7 +219,7 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso override def eval(input: Row): Any = { val evalE1 = left.eval(input) - if(evalE1 == null) { + if (evalE1 == null) { null } else { val evalE2 = right.eval(input) @@ -230,9 +233,9 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso } case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { - def symbol = "<=" + override def symbol: String = "<=" - lazy val ordering = { + lazy val ordering: Ordering[Any] = { if (left.dataType != right.dataType) { throw new TreeNodeException(this, s"Types do not match ${left.dataType} != ${right.dataType}") @@ -245,7 +248,7 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo override def eval(input: Row): Any = { val evalE1 = left.eval(input) - if(evalE1 == null) { + if (evalE1 == null) { null } else { val evalE2 = right.eval(input) @@ -259,9 +262,9 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo } case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { - def symbol = ">" + override def symbol: String = ">" - lazy val ordering = { + lazy val ordering: Ordering[Any] = { if (left.dataType != right.dataType) { throw new TreeNodeException(this, s"Types do not match ${left.dataType} != ${right.dataType}") @@ -288,9 +291,9 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar } case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { - def symbol = ">=" + override def symbol: String = ">=" - lazy val ordering = { + lazy val ordering: Ordering[Any] = { if (left.dataType != right.dataType) { throw new TreeNodeException(this, s"Types do not match ${left.dataType} != ${right.dataType}") @@ -303,7 +306,7 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar override def eval(input: Row): Any = { val evalE1 = left.eval(input) - if(evalE1 == null) { + if (evalE1 == null) { null } else { val evalE2 = right.eval(input) @@ -317,13 +320,13 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar } case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) - extends Expression { + extends Expression { - def children = predicate :: trueValue :: falseValue :: Nil - override def nullable = trueValue.nullable || falseValue.nullable + override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil + override def nullable: Boolean = trueValue.nullable || falseValue.nullable override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType - def dataType = { + override def dataType: DataType = { if (!resolved) { throw new UnresolvedException( this, @@ -342,7 +345,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } } - override def toString = s"if ($predicate) $trueValue else $falseValue" + override def toString: String = s"if ($predicate) $trueValue else $falseValue" } // scalastyle:off @@ -362,9 +365,10 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi // scalastyle:on case class CaseWhen(branches: Seq[Expression]) extends Expression { type EvaluatedType = Any - def children = branches - def dataType = { + override def children: Seq[Expression] = branches + + override def dataType: DataType = { if (!resolved) { throw new UnresolvedException(this, "cannot resolve due to differing types in some branches") } @@ -379,12 +383,12 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression { @transient private[this] lazy val elseValue = if (branches.length % 2 == 0) None else Option(branches.last) - override def nullable = { + override def nullable: Boolean = { // If no value is nullable and no elseValue is provided, the whole statement defaults to null. values.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) } - override lazy val resolved = { + override lazy val resolved: Boolean = { if (!childrenResolved) { false } else { @@ -415,7 +419,7 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression { res } - override def toString = { + override def toString: String = { "CASE" + branches.sliding(2, 2).map { case Seq(cond, value) => s" WHEN $cond THEN $value" case Seq(elseValue) => s" ELSE $elseValue" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index f03d6f71a9fae..a8983df208318 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -44,8 +44,8 @@ trait MutableRow extends Row { */ object EmptyRow extends Row { override def apply(i: Int): Any = throw new UnsupportedOperationException - override def toSeq = Seq.empty - override def length = 0 + override def toSeq: Seq[Any] = Seq.empty + override def length: Int = 0 override def isNullAt(i: Int): Boolean = throw new UnsupportedOperationException override def getInt(i: Int): Int = throw new UnsupportedOperationException override def getLong(i: Int): Long = throw new UnsupportedOperationException @@ -56,7 +56,7 @@ object EmptyRow extends Row { override def getByte(i: Int): Byte = throw new UnsupportedOperationException override def getString(i: Int): String = throw new UnsupportedOperationException override def getAs[T](i: Int): T = throw new UnsupportedOperationException - def copy() = this + override def copy(): Row = this } /** @@ -66,17 +66,17 @@ object EmptyRow extends Row { */ class GenericRow(protected[sql] val values: Array[Any]) extends Row { /** No-arg constructor for serialization. */ - def this() = this(null) + protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) - override def toSeq = values.toSeq + override def toSeq: Seq[Any] = values.toSeq - override def length = values.length + override def length: Int = values.length - override def apply(i: Int) = values(i) + override def apply(i: Int): Any = values(i) - override def isNullAt(i: Int) = values(i) == null + override def isNullAt(i: Int): Boolean = values(i) == null override def getInt(i: Int): Int = { if (values(i) == null) sys.error("Failed to check null bit for primitive int value.") @@ -167,16 +167,19 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { case _ => false } - def copy() = this + override def copy(): Row = this } class GenericRowWithSchema(values: Array[Any], override val schema: StructType) extends GenericRow(values) { + + /** No-arg constructor for serialization. */ + protected def this() = this(null, null) } class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { /** No-arg constructor for serialization. */ - def this() = this(null) + protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) @@ -194,7 +197,7 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value } - override def copy() = new GenericRow(values.clone()) + override def copy(): Row = new GenericRow(values.clone()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 3a5bdca1f07c3..35faa00782e80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -26,17 +26,17 @@ import org.apache.spark.util.collection.OpenHashSet case class NewSet(elementType: DataType) extends LeafExpression { type EvaluatedType = Any - def nullable = false + override def nullable: Boolean = false // We are currently only using these Expressions internally for aggregation. However, if we ever // expose these to users we'll want to create a proper type instead of hijacking ArrayType. - def dataType = ArrayType(elementType) + override def dataType: DataType = ArrayType(elementType) - def eval(input: Row): Any = { + override def eval(input: Row): Any = { new OpenHashSet[Any]() } - override def toString = s"new Set($dataType)" + override def toString: String = s"new Set($dataType)" } /** @@ -46,12 +46,13 @@ case class NewSet(elementType: DataType) extends LeafExpression { case class AddItemToSet(item: Expression, set: Expression) extends Expression { type EvaluatedType = Any - def children = item :: set :: Nil + override def children: Seq[Expression] = item :: set :: Nil - def nullable = set.nullable + override def nullable: Boolean = set.nullable - def dataType = set.dataType - def eval(input: Row): Any = { + override def dataType: DataType = set.dataType + + override def eval(input: Row): Any = { val itemEval = item.eval(input) val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]] @@ -67,7 +68,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { } } - override def toString = s"$set += $item" + override def toString: String = s"$set += $item" } /** @@ -77,13 +78,13 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { case class CombineSets(left: Expression, right: Expression) extends BinaryExpression { type EvaluatedType = Any - def nullable = left.nullable || right.nullable + override def nullable: Boolean = left.nullable || right.nullable - def dataType = left.dataType + override def dataType: DataType = left.dataType - def symbol = "++=" + override def symbol: String = "++=" - def eval(input: Row): Any = { + override def eval(input: Row): Any = { val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]] if(leftEval != null) { val rightEval = right.eval(input).asInstanceOf[OpenHashSet[Any]] @@ -109,16 +110,16 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres case class CountSet(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def nullable = child.nullable + override def nullable: Boolean = child.nullable - def dataType = LongType + override def dataType: DataType = LongType - def eval(input: Row): Any = { + override def eval(input: Row): Any = { val childEval = child.eval(input).asInstanceOf[OpenHashSet[Any]] if (childEval != null) { childEval.size.toLong } } - override def toString = s"$child.count()" + override def toString: String = s"$child.count()" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index f85ee0a9bb6d8..3cdca4e9dd2d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -33,8 +33,8 @@ trait StringRegexExpression { def escape(v: String): String def matches(regex: Pattern, str: String): Boolean - def nullable: Boolean = left.nullable || right.nullable - def dataType: DataType = BooleanType + override def nullable: Boolean = left.nullable || right.nullable + override def dataType: DataType = BooleanType // try cache the pattern for Literal private lazy val cache: Pattern = right match { @@ -98,11 +98,11 @@ trait CaseConversionExpression { case class Like(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - def symbol = "LIKE" + override def symbol: String = "LIKE" // replace the _ with .{1} exactly match 1 time of any character // replace the % with .*, match 0 or more times with any character - override def escape(v: String) = + override def escape(v: String): String = if (!v.isEmpty) { "(?s)" + (' ' +: v.init).zip(v).flatMap { case (prev, '\\') => "" @@ -129,7 +129,7 @@ case class Like(left: Expression, right: Expression) case class RLike(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - def symbol = "RLIKE" + override def symbol: String = "RLIKE" override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) } @@ -141,7 +141,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE override def convert(v: String): String = v.toUpperCase() - override def toString() = s"Upper($child)" + override def toString: String = s"Upper($child)" } /** @@ -151,7 +151,7 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE override def convert(v: String): String = v.toLowerCase() - override def toString() = s"Lower($child)" + override def toString: String = s"Lower($child)" } /** A base trait for functions that compare two strings, returning a boolean. */ @@ -160,7 +160,7 @@ trait StringComparison { type EvaluatedType = Any - def nullable: Boolean = left.nullable || right.nullable + override def nullable: Boolean = left.nullable || right.nullable override def dataType: DataType = BooleanType def compare(l: String, r: String): Boolean @@ -175,9 +175,9 @@ trait StringComparison { } } - def symbol: String = nodeName + override def symbol: String = nodeName - override def toString() = s"$nodeName($left, $right)" + override def toString: String = s"$nodeName($left, $right)" } /** @@ -185,7 +185,7 @@ trait StringComparison { */ case class Contains(left: Expression, right: Expression) extends BinaryExpression with StringComparison { - override def compare(l: String, r: String) = l.contains(r) + override def compare(l: String, r: String): Boolean = l.contains(r) } /** @@ -193,7 +193,7 @@ case class Contains(left: Expression, right: Expression) */ case class StartsWith(left: Expression, right: Expression) extends BinaryExpression with StringComparison { - def compare(l: String, r: String) = l.startsWith(r) + override def compare(l: String, r: String): Boolean = l.startsWith(r) } /** @@ -201,7 +201,7 @@ case class StartsWith(left: Expression, right: Expression) */ case class EndsWith(left: Expression, right: Expression) extends BinaryExpression with StringComparison { - def compare(l: String, r: String) = l.endsWith(r) + override def compare(l: String, r: String): Boolean = l.endsWith(r) } /** @@ -212,17 +212,17 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends type EvaluatedType = Any - override def foldable = str.foldable && pos.foldable && len.foldable + override def foldable: Boolean = str.foldable && pos.foldable && len.foldable - def nullable: Boolean = str.nullable || pos.nullable || len.nullable - def dataType: DataType = { + override def nullable: Boolean = str.nullable || pos.nullable || len.nullable + override def dataType: DataType = { if (!resolved) { throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved") } if (str.dataType == BinaryType) str.dataType else StringType } - override def children = str :: pos :: len :: Nil + override def children: Seq[Expression] = str :: pos :: len :: Nil @inline def slice[T, C <: Any](str: C, startPos: Int, sliceLen: Int) @@ -267,7 +267,8 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends } } - override def toString = len match { + override def toString: String = len match { + // TODO: This is broken because max is not an integer value. case max if max == Integer.MAX_VALUE => s"SUBSTR($str, $pos)" case _ => s"SUBSTR($str, $pos, $len)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 1a75fcf3545bd..c23d3b61887c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.FullOuter @@ -32,6 +33,9 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] object DefaultOptimizer extends Optimizer { val batches = + // SubQueries are only needed for analysis and can be removed before execution. + Batch("Remove SubQueries", FixedPoint(100), + EliminateSubQueries) :: Batch("Combine Limits", FixedPoint(100), CombineLimits) :: Batch("ConstantFolding", FixedPoint(100), @@ -137,7 +141,7 @@ object ColumnPruning extends Rule[LogicalPlan] { condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) /** Applies a projection only when the child is producing unnecessary attributes */ - def pruneJoinChild(c: LogicalPlan) = prunedChild(c, allReferences) + def pruneJoinChild(c: LogicalPlan): LogicalPlan = prunedChild(c, allReferences) Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index b4c445b3badf1..9c8c643f7d17a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -91,16 +91,18 @@ object PhysicalOperation extends PredicateHelper { (None, Nil, other, Map.empty) } - def collectAliases(fields: Seq[Expression]) = fields.collect { + def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect { case a @ Alias(child, _) => a.toAttribute.asInstanceOf[Attribute] -> child }.toMap - def substitute(aliases: Map[Attribute, Expression])(expr: Expression) = expr.transform { - case a @ Alias(ref: AttributeReference, name) => - aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a) + def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = { + expr.transform { + case a @ Alias(ref: AttributeReference, name) => + aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a) - case a: AttributeReference => - aliases.get(a).map(Alias(_, a.name)(a.exprId, a.qualifiers)).getOrElse(a) + case a: AttributeReference => + aliases.get(a).map(Alias(_, a.name)(a.exprId, a.qualifiers)).getOrElse(a) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index bd9291e9ba5d7..02f7c26a8ab6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -71,7 +71,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = { var changed = false - @inline def transformExpressionDown(e: Expression) = { + @inline def transformExpressionDown(e: Expression): Expression = { val newE = e.transformDown(rule) if (newE.fastEquals(e)) { e @@ -104,7 +104,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type = { var changed = false - @inline def transformExpressionUp(e: Expression) = { + @inline def transformExpressionUp(e: Expression): Expression = { val newE = e.transformUp(rule) if (newE.fastEquals(e)) { e @@ -165,5 +165,5 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy */ protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" else "" - override def simpleString = statePrefix + super.simpleString + override def simpleString: String = statePrefix + super.simpleString } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 0f8b144ccc113..b01a61d7bf8d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, Resolver} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, Resolver} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.TreeNode @@ -73,12 +73,16 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * can do better should override this function. */ def sameResult(plan: LogicalPlan): Boolean = { - plan.getClass == this.getClass && - plan.children.size == children.size && { - logDebug(s"[${cleanArgs.mkString(", ")}] == [${plan.cleanArgs.mkString(", ")}]") - cleanArgs == plan.cleanArgs + val cleanLeft = EliminateSubQueries(this) + val cleanRight = EliminateSubQueries(plan) + + cleanLeft.getClass == cleanRight.getClass && + cleanLeft.children.size == cleanRight.children.size && { + logDebug( + s"[${cleanRight.cleanArgs.mkString(", ")}] == [${cleanLeft.cleanArgs.mkString(", ")}]") + cleanRight.cleanArgs == cleanLeft.cleanArgs } && - (plan.children, children).zipped.forall(_ sameResult _) + (cleanLeft.children, cleanRight.children).zipped.forall(_ sameResult _) } /** Args that have cleaned such that differences in expression id should not affect equality */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 384fe53a68362..4d9e41a2b5d85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { - def output = projectList.map(_.toAttribute) + override def output: Seq[Attribute] = projectList.map(_.toAttribute) override lazy val resolved: Boolean = { val containsAggregatesOrGenerators = projectList.exists ( _.collect { @@ -66,19 +66,19 @@ case class Generate( } } - override def output = + override def output: Seq[Attribute] = if (join) child.output ++ generatorOutput else generatorOutput } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { // TODO: These aren't really the same attributes as nullability etc might change. - override def output = left.output + override def output: Seq[Attribute] = left.output - override lazy val resolved = + override lazy val resolved: Boolean = childrenResolved && !left.output.zip(right.output).exists { case (l,r) => l.dataType != r.dataType } @@ -94,7 +94,7 @@ case class Join( joinType: JoinType, condition: Option[Expression]) extends BinaryNode { - override def output = { + override def output: Seq[Attribute] = { joinType match { case LeftSemi => left.output @@ -109,7 +109,7 @@ case class Join( } } - def selfJoinResolved = left.outputSet.intersect(right.outputSet).isEmpty + private def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty // Joins are only resolved if they don't introduce ambiguious expression ids. override lazy val resolved: Boolean = { @@ -118,7 +118,7 @@ case class Join( } case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - def output = left.output + override def output: Seq[Attribute] = left.output } case class InsertIntoTable( @@ -128,10 +128,10 @@ case class InsertIntoTable( overwrite: Boolean) extends LogicalPlan { - override def children = child :: Nil - override def output = child.output + override def children: Seq[LogicalPlan] = child :: Nil + override def output: Seq[Attribute] = child.output - override lazy val resolved = childrenResolved && child.output.zip(table.output).forall { + override lazy val resolved: Boolean = childrenResolved && child.output.zip(table.output).forall { case (childAttr, tableAttr) => DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType) } @@ -143,14 +143,14 @@ case class CreateTableAsSelect[T]( child: LogicalPlan, allowExisting: Boolean, desc: Option[T] = None) extends UnaryNode { - override def output = Seq.empty[Attribute] - override lazy val resolved = databaseName != None && childrenResolved + override def output: Seq[Attribute] = Seq.empty[Attribute] + override lazy val resolved: Boolean = databaseName != None && childrenResolved } case class WriteToFile( path: String, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } /** @@ -163,7 +163,7 @@ case class Sort( order: Seq[SortOrder], global: Boolean, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } case class Aggregate( @@ -172,7 +172,7 @@ case class Aggregate( child: LogicalPlan) extends UnaryNode { - override def output = aggregateExpressions.map(_.toAttribute) + override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) } /** @@ -199,7 +199,7 @@ trait GroupingAnalytics extends UnaryNode { def groupByExprs: Seq[Expression] def aggregations: Seq[NamedExpression] - override def output = aggregations.map(_.toAttribute) + override def output: Seq[Attribute] = aggregations.map(_.toAttribute) } /** @@ -264,7 +264,7 @@ case class Rollup( gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output override lazy val statistics: Statistics = { val limit = limitExpr.eval(null).asInstanceOf[Int] @@ -274,21 +274,21 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { } case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { - override def output = child.output.map(_.withQualifiers(alias :: Nil)) + override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil)) } case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } case class Distinct(child: LogicalPlan) extends UnaryNode { - override def output = child.output + override def output: Seq[Attribute] = child.output } case object NoRelation extends LeafNode { - override def output = Nil + override def output: Seq[Attribute] = Nil /** * Computes [[Statistics]] for this plan. The default implementation assumes the output @@ -301,5 +301,5 @@ case object NoRelation extends LeafNode { } case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - override def output = left.output + override def output: Seq[Attribute] = left.output } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala index 72b0c5c8e7a26..e737418d9c3bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrder} /** * Performs a physical redistribution of the data. Used when the consumer of the query @@ -26,14 +26,11 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder} abstract class RedistributeData extends UnaryNode { self: Product => - def output = child.output + override def output: Seq[Attribute] = child.output } case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan) - extends RedistributeData { -} + extends RedistributeData case class Repartition(partitionExpressions: Seq[Expression], child: LogicalPlan) - extends RedistributeData { -} - + extends RedistributeData diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 3c3d7a3119064..288c11f69fe22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Expression, Row, SortOrder} -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{DataType, IntegerType} /** * Specifies how tuples that share common expressions will be distributed when a query is executed @@ -72,7 +72,7 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { "a single partition.") // TODO: This is not really valid... - def clustering = ordering.map(_.child).toSet + def clustering: Set[Expression] = ordering.map(_.child).toSet } sealed trait Partitioning { @@ -113,7 +113,7 @@ case object SinglePartition extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def compatibleWith(other: Partitioning) = other match { + override def compatibleWith(other: Partitioning): Boolean = other match { case SinglePartition => true case _ => false } @@ -124,7 +124,7 @@ case object BroadcastPartitioning extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def compatibleWith(other: Partitioning) = other match { + override def compatibleWith(other: Partitioning): Boolean = other match { case SinglePartition => true case _ => false } @@ -139,9 +139,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) extends Expression with Partitioning { - override def children = expressions - override def nullable = false - override def dataType = IntegerType + override def children: Seq[Expression] = expressions + override def nullable: Boolean = false + override def dataType: DataType = IntegerType private[this] lazy val clusteringSet = expressions.toSet @@ -152,7 +152,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } - override def compatibleWith(other: Partitioning) = other match { + override def compatibleWith(other: Partitioning): Boolean = other match { case BroadcastPartitioning => true case h: HashPartitioning if h == this => true case _ => false @@ -178,9 +178,9 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) extends Expression with Partitioning { - override def children = ordering - override def nullable = false - override def dataType = IntegerType + override def children: Seq[SortOrder] = ordering + override def nullable: Boolean = false + override def dataType: DataType = IntegerType private[this] lazy val clusteringSet = ordering.map(_.child).toSet @@ -194,7 +194,7 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case _ => false } - override def compatibleWith(other: Partitioning) = other match { + override def compatibleWith(other: Partitioning): Boolean = other match { case BroadcastPartitioning => true case r: RangePartitioning if r == this => true case _ => false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 0ae9f6b2965d4..a2df51e598a2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -36,12 +36,12 @@ object CurrentOrigin { override def initialValue: Origin = Origin() } - def get = value.get() - def set(o: Origin) = value.set(o) + def get: Origin = value.get() + def set(o: Origin): Unit = value.set(o) - def reset() = value.set(Origin()) + def reset(): Unit = value.set(Origin()) - def setPosition(line: Int, start: Int) = { + def setPosition(line: Int, start: Int): Unit = { value.set( value.get.copy(line = Some(line), startPosition = Some(start))) } @@ -57,7 +57,7 @@ object CurrentOrigin { abstract class TreeNode[BaseType <: TreeNode[BaseType]] { self: BaseType with Product => - val origin = CurrentOrigin.get + val origin: Origin = CurrentOrigin.get /** Returns a Seq of the children of this node */ def children: Seq[BaseType] @@ -340,12 +340,12 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } /** Returns the name of this type of TreeNode. Defaults to the class name. */ - def nodeName = getClass.getSimpleName + def nodeName: String = getClass.getSimpleName /** * The arguments that should be included in the arg string. Defaults to the `productIterator`. */ - protected def stringArgs = productIterator + protected def stringArgs: Iterator[Any] = productIterator /** Returns a string representing the arguments to this node, minus any children */ def argString: String = productIterator.flatMap { @@ -357,18 +357,18 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { }.mkString(", ") /** String representation of this node without any children */ - def simpleString = s"$nodeName $argString".trim + def simpleString: String = s"$nodeName $argString".trim override def toString: String = treeString /** Returns a string representation of the nodes in this tree */ - def treeString = generateTreeString(0, new StringBuilder).toString + def treeString: String = generateTreeString(0, new StringBuilder).toString /** * Returns a string representation of the nodes in this tree, where each operator is numbered. * The numbers can be used with [[trees.TreeNode.apply apply]] to easily access specific subtrees. */ - def numberedTreeString = + def numberedTreeString: String = treeString.split("\n").zipWithIndex.map { case (line, i) => f"$i%02d $line" }.mkString("\n") /** @@ -420,14 +420,14 @@ trait BinaryNode[BaseType <: TreeNode[BaseType]] { def left: BaseType def right: BaseType - def children = Seq(left, right) + def children: Seq[BaseType] = Seq(left, right) } /** * A [[TreeNode]] with no children. */ trait LeafNode[BaseType <: TreeNode[BaseType]] { - def children = Nil + def children: Seq[BaseType] = Nil } /** @@ -435,6 +435,5 @@ trait LeafNode[BaseType <: TreeNode[BaseType]] { */ trait UnaryNode[BaseType <: TreeNode[BaseType]] { def child: BaseType - def children = child :: Nil + def children: Seq[BaseType] = child :: Nil } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala index 79a8e06d4b4d4..ea6aa1850db4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala @@ -41,11 +41,11 @@ package object trees extends Logging { * A [[TreeNode]] companion for reference equality for Hash based Collection. */ class TreeNodeRef(val obj: TreeNode[_]) { - override def equals(o: Any) = o match { + override def equals(o: Any): Boolean = o match { case that: TreeNodeRef => that.obj.eq(obj) case _ => false } - override def hashCode = if (obj == null) 0 else obj.hashCode + override def hashCode: Int = if (obj == null) 0 else obj.hashCode } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index feed50f9a2a2d..c86214a2aa944 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -23,7 +23,7 @@ import org.apache.spark.util.Utils package object util { - def fileToString(file: File, encoding: String = "UTF-8") = { + def fileToString(file: File, encoding: String = "UTF-8"): String = { val inStream = new FileInputStream(file) val outStream = new ByteArrayOutputStream try { @@ -45,7 +45,7 @@ package object util { def resourceToString( resource:String, encoding: String = "UTF-8", - classLoader: ClassLoader = Utils.getSparkClassLoader) = { + classLoader: ClassLoader = Utils.getSparkClassLoader): String = { val inStream = classLoader.getResourceAsStream(resource) val outStream = new ByteArrayOutputStream try { @@ -93,7 +93,7 @@ package object util { new String(out.toByteArray) } - def stringOrNull(a: AnyRef) = if (a == null) null else a.toString + def stringOrNull(a: AnyRef): String = if (a == null) null else a.toString def benchmark[A](f: => A): A = { val startTime = System.nanoTime() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index e50e9761431f5..6ee24ee0c1913 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -41,6 +41,9 @@ import org.apache.spark.annotation.DeveloperApi sealed class Metadata private[types] (private[types] val map: Map[String, Any]) extends Serializable { + /** No-arg constructor for kryo. */ + protected def this() = this(null) + /** Tests whether this Metadata contains a binding for a key. */ def contains(key: String): Boolean = map.contains(key) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index d973144de3468..952cf5c75688d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -670,6 +670,10 @@ case class PrecisionInfo(precision: Int, scale: Int) */ @DeveloperApi case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType { + + /** No-arg constructor for kryo. */ + protected def this() = this(null) + private[sql] type JvmType = Decimal @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val numeric = Decimal.DecimalIsFractional @@ -819,6 +823,10 @@ object ArrayType { */ @DeveloperApi case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { + + /** No-arg constructor for kryo. */ + protected def this() = this(null, false) + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append( s"$prefix-- element: ${elementType.typeName} (containsNull = $containsNull)\n") @@ -857,6 +865,9 @@ case class StructField( nullable: Boolean = true, metadata: Metadata = Metadata.empty) { + /** No-arg constructor for kryo. */ + protected def this() = this(null, null) + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n") DataType.buildFormattedString(dataType, s"$prefix |", builder) @@ -1003,6 +1014,9 @@ object StructType { @DeveloperApi case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { + /** No-arg constructor for kryo. */ + protected def this() = this(null) + /** Returns all field names in an array. */ def fieldNames: Array[String] = fields.map(_.name) @@ -1121,6 +1135,10 @@ case class MapType( keyType: DataType, valueType: DataType, valueContainsNull: Boolean) extends DataType { + + /** No-arg constructor for kryo. */ + def this() = this(null, null, false) + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append(s"$prefix-- key: ${keyType.typeName}\n") builder.append(s"$prefix-- value: ${valueType.typeName} " + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 359aec4a7b5ab..756cd36f05c8c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -32,9 +32,13 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { val caseInsensitiveCatalog = new SimpleCatalog(false) val caseSensitiveAnalyzer = - new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true) + new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true) { + override val extendedResolutionRules = EliminateSubQueries :: Nil + } val caseInsensitiveAnalyzer = - new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false) + new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false) { + override val extendedResolutionRules = EliminateSubQueries :: Nil + } val checkAnalysis = new CheckAnalysis diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala new file mode 100644 index 0000000000000..f2f3a84d19380 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.scalatest.FunSuite + +import org.apache.spark.sql.types.IntegerType + +class AttributeSetSuite extends FunSuite { + + val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1)) + val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1)) + val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3)) + val aSet = AttributeSet(aLower :: Nil) + + val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2)) + val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2)) + val bSet = AttributeSet(bUpper :: Nil) + + val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil) + + test("sanity check") { + assert(aUpper != aLower) + assert(bUpper != bLower) + } + + test("checks by id not name") { + assert(aSet.contains(aUpper) === true) + assert(aSet.contains(aLower) === true) + assert(aSet.contains(fakeA) === false) + + assert(aSet.contains(bUpper) === false) + assert(aSet.contains(bLower) === false) + } + + test("++ preserves AttributeSet") { + assert((aSet ++ bSet).contains(aUpper) === true) + assert((aSet ++ bSet).contains(aLower) === true) + } + + test("extracts all references references") { + val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil) + assert(addSet.contains(aUpper)) + assert(addSet.contains(aLower)) + assert(addSet.contains(bUpper)) + assert(addSet.contains(bLower)) + } + + test("dedups attributes") { + assert(AttributeSet(aUpper :: aLower :: Nil).size === 1) + } + + test("subset") { + assert(aSet.subsetOf(aAndBSet) === true) + assert(aAndBSet.subsetOf(aSet) === false) + } + + test("equality") { + assert(aSet != aAndBSet) + assert(aAndBSet != aSet) + assert(aSet != bSet) + assert(bSet != aSet) + + assert(aSet == aSet) + assert(aSet == AttributeSet(aUpper :: Nil)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 5aece166aad22..4c80359cf07af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -33,7 +33,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.sql.catalyst.{expressions, ScalaReflection, SqlParser} +import org.apache.spark.sql.catalyst.{ScalaReflection, SqlParser} import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD} import org.apache.spark.sql.jdbc.JDBCWriteDetails import org.apache.spark.sql.json.JsonRDD -import org.apache.spark.sql.types.{NumericType, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect} import org.apache.spark.util.Utils @@ -751,6 +751,67 @@ class DataFrame private[sql]( select(colNames :_*) } + /** + * Computes statistics for numeric columns, including count, mean, stddev, min, and max. + * If no columns are given, this function computes statistics for all numerical columns. + * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. If you want to + * programmatically compute summary statistics, use the `agg` function instead. + * + * {{{ + * df.describe("age", "height").show() + * + * // output: + * // summary age height + * // count 10.0 10.0 + * // mean 53.3 178.05 + * // stddev 11.6 15.7 + * // min 18.0 163.0 + * // max 92.0 192.0 + * }}} + * + * @group action + */ + @scala.annotation.varargs + def describe(cols: String*): DataFrame = { + + // TODO: Add stddev as an expression, and remove it from here. + def stddevExpr(expr: Expression): Expression = + Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr)))) + + // The list of summary statistics to compute, in the form of expressions. + val statistics = List[(String, Expression => Expression)]( + "count" -> Count, + "mean" -> Average, + "stddev" -> stddevExpr, + "min" -> Min, + "max" -> Max) + + val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList + + val ret: Seq[Row] = if (outputCols.nonEmpty) { + val aggExprs = statistics.flatMap { case (_, colToAgg) => + outputCols.map(c => Column(colToAgg(Column(c).expr)).as(c)) + } + + val row = agg(aggExprs.head, aggExprs.tail: _*).head().toSeq + + // Pivot the data so each summary is one row + row.grouped(outputCols.size).toSeq.zip(statistics).map { + case (aggregation, (statistic, _)) => Row(statistic :: aggregation.toList: _*) + } + } else { + // If there are no output columns, just output a single column that contains the stats. + statistics.map { case (name, _) => Row(name) } + } + + // The first column is string type, and the rest are double type. + val schema = StructType( + StructField("summary", StringType) :: outputCols.map(StructField(_, DoubleType))).toAttributes + LocalRelation(schema, ret) + } + /** * Returns the first `n` rows. * @group action diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index c4534fd5f67e4..967bd76b302d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHa private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) { override def newKryo(): Kryo = { - val kryo = new Kryo() + val kryo = super.newKryo() kryo.setRegistrationRequired(false) kryo.register(classOf[MutablePair[_, _]]) kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow]) @@ -57,8 +57,6 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co kryo.register(classOf[Decimal]) kryo.setReferences(false) - kryo.setClassLoader(Utils.getSparkClassLoader) - new AllScalaRegistrar().apply(kryo) kryo } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 410600b0529d3..0d68810ec6043 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -435,11 +435,18 @@ private[sql] case class ParquetRelation2( // Push down filters when possible. Notice that not all filters can be converted to Parquet // filter predicate. Here we try to convert each individual predicate and only collect those // convertible ones. - predicates - .flatMap(ParquetFilters.createFilter) - .reduceOption(FilterApi.and) - .filter(_ => sqlContext.conf.parquetFilterPushDown) - .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _)) + if (sqlContext.conf.parquetFilterPushDown) { + predicates + // Don't push down predicates which reference partition columns + .filter { pred => + val partitionColNames = partitionColumns.map(_.name).toSet + val referencedColNames = pred.references.map(_.name).toSet + referencedColNames.intersect(partitionColNames).isEmpty + } + .flatMap(ParquetFilters.createFilter) + .reduceOption(FilterApi.and) + .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _)) + } if (isPartitioned) { logInfo { @@ -758,12 +765,13 @@ private[sql] object ParquetRelation2 extends Logging { |${parquetSchema.prettyJson} """.stripMargin - assert(metastoreSchema.size == parquetSchema.size, schemaConflictMessage) + assert(metastoreSchema.size <= parquetSchema.size, schemaConflictMessage) val ordinalMap = metastoreSchema.zipWithIndex.map { case (field, index) => field.name.toLowerCase -> index }.toMap - val reorderedParquetSchema = parquetSchema.sortBy(f => ordinalMap(f.name.toLowerCase)) + val reorderedParquetSchema = parquetSchema.sortBy(f => + ordinalMap.getOrElse(f.name.toLowerCase, metastoreSchema.size + 1)) StructType(metastoreSchema.zip(reorderedParquetSchema).map { // Uses Parquet field names but retains Metastore data types. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index d2e807d3a69b6..eb46b46ca5bf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -21,7 +21,7 @@ import scala.language.existentials import scala.language.implicitConversions import org.apache.spark.Logging -import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} +import org.apache.spark.sql.{AnalysisException, SaveMode, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation @@ -204,19 +204,25 @@ private[sql] object ResolvedDataSource { provider: String, options: Map[String, String]): ResolvedDataSource = { val clazz: Class[_] = lookupDataSource(provider) + def className = clazz.getCanonicalName val relation = userSpecifiedSchema match { case Some(schema: StructType) => clazz.newInstance() match { case dataSource: SchemaRelationProvider => dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema) case dataSource: org.apache.spark.sql.sources.RelationProvider => - sys.error(s"${clazz.getCanonicalName} does not allow user-specified schemas.") + throw new AnalysisException(s"$className does not allow user-specified schemas.") + case _ => + throw new AnalysisException(s"$className is not a RelationProvider.") } case None => clazz.newInstance() match { case dataSource: RelationProvider => dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options)) case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => - sys.error(s"A schema needs to be specified when using ${clazz.getCanonicalName}.") + throw new AnalysisException( + s"A schema needs to be specified when using $className.") + case _ => + throw new AnalysisException(s"$className is not a RelationProvider.") } } new ResolvedDataSource(clazz, relation) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ff441ef26f9c0..fbc4065a9666c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -108,6 +108,13 @@ class DataFrameSuite extends QueryTest { ) } + test("self join with aliases") { + val df = Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str") + checkAnswer( + df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("x.str").count(), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) + } + test("explode") { val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDF("number", "letters") val df2 = @@ -436,6 +443,50 @@ class DataFrameSuite extends QueryTest { assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol")) } + test("describe") { + val describeTestData = Seq( + ("Bob", 16, 176), + ("Alice", 32, 164), + ("David", 60, 192), + ("Amy", 24, 180)).toDF("name", "age", "height") + + val describeResult = Seq( + Row("count", 4, 4), + Row("mean", 33.0, 178.0), + Row("stddev", 16.583123951777, 10.0), + Row("min", 16, 164), + Row("max", 60, 192)) + + val emptyDescribeResult = Seq( + Row("count", 0, 0), + Row("mean", null, null), + Row("stddev", null, null), + Row("min", null, null), + Row("max", null, null)) + + def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) + + val describeTwoCols = describeTestData.describe("age", "height") + assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "age", "height")) + checkAnswer(describeTwoCols, describeResult) + + val describeAllCols = describeTestData.describe() + assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "age", "height")) + checkAnswer(describeAllCols, describeResult) + + val describeOneCol = describeTestData.describe("age") + assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age")) + checkAnswer(describeOneCol, describeResult.map { case Row(s, d, _) => Row(s, d)} ) + + val describeNoCol = describeTestData.select("name").describe() + assert(getSchemaAsSeq(describeNoCol) === Seq("summary")) + checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _) => Row(s)} ) + + val emptyDescription = describeTestData.limit(0).describe() + assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "age", "height")) + checkAnswer(emptyDescription, emptyDescribeResult) + } + test("apply on query results (SPARK-5462)") { val df = testData.sqlContext.sql("select key from testData") checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index dd0948ad824be..e4dee87849fd4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -34,7 +34,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") - val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.analyzed + val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan val planned = planner.HashJoin(join) assert(planned.size === 1) } @@ -109,7 +109,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("multiple-key equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") - val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.analyzed + val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan val planned = planner.HashJoin(join) assert(planned.size === 1) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index f5b945f468dad..36465cc2fa11a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql +import org.apache.spark.sql.execution.SparkSqlSerializer import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ class RowSuite extends FunSuite { @@ -50,4 +53,13 @@ class RowSuite extends FunSuite { row(0) = null assert(row.isNullAt(0)) } + + test("serialize w/ kryo") { + val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first() + val serializer = new SparkSqlSerializer(TestSQLContext.sparkContext.getConf) + val instance = serializer.newInstance() + val ser = instance.serialize(row) + val de = instance.deserialize(ser).asInstanceOf[Row] + assert(de === row) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 4d32e84fc1115..6a2c2a7c4080a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -321,6 +321,23 @@ class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeA override protected def afterAll(): Unit = { sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) } + + test("SPARK-6554: don't push down predicates which reference partition columns") { + import sqlContext.implicits._ + + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/part=1" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").saveAsParquetFile(path) + + // If the "part = 1" filter gets pushed down, this query will throw an exception since + // "part" is not a valid column in the actual Parquet file + checkAnswer( + sqlContext.parquetFile(path).filter("part = 1"), + (1 to 3).map(i => Row(i, i.toString, 1))) + } + } + } } class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index 321832cd43211..8462f9bb2d620 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -212,8 +212,11 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { StructField("UPPERCase", IntegerType, nullable = true)))) } - // Conflicting field count - assert(intercept[Throwable] { + // MetaStore schema is subset of parquet schema + assertResult( + StructType(Seq( + StructField("UPPERCase", DoubleType, nullable = false)))) { + ParquetRelation2.mergeMetastoreParquetSchema( StructType(Seq( StructField("uppercase", DoubleType, nullable = false))), @@ -221,6 +224,17 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { StructType(Seq( StructField("lowerCase", BinaryType), StructField("UPPERCase", IntegerType, nullable = true)))) + } + + // Conflicting field count + assert(intercept[Throwable] { + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("uppercase", DoubleType, nullable = false), + StructField("lowerCase", BinaryType))), + + StructType(Seq( + StructField("UPPERCase", IntegerType, nullable = true)))) }.getMessage.contains("detected conflicting schemas")) // Conflicting field names diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 4c5eb48661f7d..d1a99555e90c6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -459,7 +459,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) - (relation, parquetRelation, attributedRewrites) + (relation -> relation.output, parquetRelation, attributedRewrites) // Write path case InsertIntoHiveTable(relation: MetastoreRelation, _, _, _) @@ -470,7 +470,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) - (relation, parquetRelation, attributedRewrites) + (relation -> relation.output, parquetRelation, attributedRewrites) // Read path case p @ PhysicalOperation(_, _, relation: MetastoreRelation) @@ -479,33 +479,35 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) - (relation, parquetRelation, attributedRewrites) + (relation -> relation.output, parquetRelation, attributedRewrites) } + // Quick fix for SPARK-6450: Notice that we're using both the MetastoreRelation instances and + // their output attributes as the key of the map. This is because MetastoreRelation.equals + // doesn't take output attributes into account, thus multiple MetastoreRelation instances + // pointing to the same table get collapsed into a single entry in the map. A proper fix for + // this should be overriding equals & hashCode in MetastoreRelation. val relationMap = toBeReplaced.map(r => (r._1, r._2)).toMap val attributedRewrites = AttributeMap(toBeReplaced.map(_._3).fold(Nil)(_ ++: _)) // Replaces all `MetastoreRelation`s with corresponding `ParquetRelation2`s, and fixes // attribute IDs referenced in other nodes. plan.transformUp { - case r: MetastoreRelation if relationMap.contains(r) => { - val parquetRelation = relationMap(r) - val withAlias = - r.alias.map(a => Subquery(a, parquetRelation)).getOrElse( - Subquery(r.tableName, parquetRelation)) + case r: MetastoreRelation if relationMap.contains(r -> r.output) => + val parquetRelation = relationMap(r -> r.output) + val alias = r.alias.getOrElse(r.tableName) + Subquery(alias, parquetRelation) - withAlias - } case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite) - if relationMap.contains(r) => { - val parquetRelation = relationMap(r) + if relationMap.contains(r -> r.output) => + val parquetRelation = relationMap(r -> r.output) InsertIntoTable(parquetRelation, partition, child, overwrite) - } + case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite) - if relationMap.contains(r) => { - val parquetRelation = relationMap(r) + if relationMap.contains(r -> r.output) => + val parquetRelation = relationMap(r -> r.output) InsertIntoTable(parquetRelation, partition, child, overwrite) - } + case other => other.transformExpressions { case a: Attribute if a.resolved => attributedRewrites.getOrElse(a, a) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 51775eb4cd6a0..c45c4ad70fae9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -55,37 +55,8 @@ private[hive] case object NativePlaceholder extends Command /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ private[hive] object HiveQl { protected val nativeCommands = Seq( - "TOK_DESCFUNCTION", - "TOK_DESCDATABASE", - "TOK_SHOW_CREATETABLE", - "TOK_SHOWCOLUMNS", - "TOK_SHOW_TABLESTATUS", - "TOK_SHOWDATABASES", - "TOK_SHOWFUNCTIONS", - "TOK_SHOWINDEXES", - "TOK_SHOWINDEXES", - "TOK_SHOWPARTITIONS", - "TOK_SHOW_TBLPROPERTIES", - - "TOK_LOCKTABLE", - "TOK_SHOWLOCKS", - "TOK_UNLOCKTABLE", - - "TOK_SHOW_ROLES", - "TOK_CREATEROLE", - "TOK_DROPROLE", - "TOK_GRANT", - "TOK_GRANT_ROLE", - "TOK_REVOKE", - "TOK_SHOW_GRANT", - "TOK_SHOW_ROLE_GRANT", - "TOK_SHOW_SET_ROLE", - - "TOK_CREATEFUNCTION", - "TOK_DROPFUNCTION", - - "TOK_ALTERDATABASE_PROPERTIES", "TOK_ALTERDATABASE_OWNER", + "TOK_ALTERDATABASE_PROPERTIES", "TOK_ALTERINDEX_PROPERTIES", "TOK_ALTERINDEX_REBUILD", "TOK_ALTERTABLE_ADDCOLS", @@ -102,28 +73,61 @@ private[hive] object HiveQl { "TOK_ALTERTABLE_SKEWED", "TOK_ALTERTABLE_TOUCH", "TOK_ALTERTABLE_UNARCHIVE", - "TOK_CREATEDATABASE", - "TOK_CREATEFUNCTION", - "TOK_CREATEINDEX", - "TOK_DROPDATABASE", - "TOK_DROPINDEX", - "TOK_DROPTABLE_PROPERTIES", - "TOK_MSCK", - "TOK_ALTERVIEW_ADDPARTS", "TOK_ALTERVIEW_AS", "TOK_ALTERVIEW_DROPPARTS", "TOK_ALTERVIEW_PROPERTIES", "TOK_ALTERVIEW_RENAME", + + "TOK_CREATEDATABASE", + "TOK_CREATEFUNCTION", + "TOK_CREATEINDEX", + "TOK_CREATEROLE", "TOK_CREATEVIEW", - "TOK_DROPVIEW_PROPERTIES", + + "TOK_DESCDATABASE", + "TOK_DESCFUNCTION", + + "TOK_DROPDATABASE", + "TOK_DROPFUNCTION", + "TOK_DROPINDEX", + "TOK_DROPROLE", + "TOK_DROPTABLE_PROPERTIES", "TOK_DROPVIEW", - + "TOK_DROPVIEW_PROPERTIES", + "TOK_EXPORT", + + "TOK_GRANT", + "TOK_GRANT_ROLE", + "TOK_IMPORT", + "TOK_LOAD", - - "TOK_SWITCHDATABASE" + + "TOK_LOCKTABLE", + + "TOK_MSCK", + + "TOK_REVOKE", + + "TOK_SHOW_CREATETABLE", + "TOK_SHOW_GRANT", + "TOK_SHOW_ROLE_GRANT", + "TOK_SHOW_ROLES", + "TOK_SHOW_SET_ROLE", + "TOK_SHOW_TABLESTATUS", + "TOK_SHOW_TBLPROPERTIES", + "TOK_SHOWCOLUMNS", + "TOK_SHOWDATABASES", + "TOK_SHOWFUNCTIONS", + "TOK_SHOWINDEXES", + "TOK_SHOWLOCKS", + "TOK_SHOWPARTITIONS", + + "TOK_SWITCHDATABASE", + + "TOK_UNLOCKTABLE" ) // Commands that we do not need to explain. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index af309c0c6ce2c..3563472c7ae81 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} import org.apache.hadoop.hive.serde2.Deserializer -import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, StructObjectInspector} import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} @@ -116,7 +116,7 @@ class HadoopTableReader( val hconf = broadcastedHiveConf.value.value val deserializer = deserializerClass.newInstance() deserializer.initialize(hconf, tableDesc.getProperties) - HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow) + HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow, deserializer) } deserializedHadoopRDD @@ -189,9 +189,13 @@ class HadoopTableReader( val hconf = broadcastedHiveConf.value.value val deserializer = localDeserializer.newInstance() deserializer.initialize(hconf, partProps) + // get the table deserializer + val tableSerDe = tableDesc.getDeserializerClass.newInstance() + tableSerDe.initialize(hconf, tableDesc.getProperties) // fill the non partition key attributes - HadoopTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs, mutableRow) + HadoopTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs, + mutableRow, tableSerDe) } }.toSeq @@ -261,25 +265,36 @@ private[hive] object HadoopTableReader extends HiveInspectors { * Transform all given raw `Writable`s into `Row`s. * * @param iterator Iterator of all `Writable`s to be transformed - * @param deserializer The `Deserializer` associated with the input `Writable` + * @param rawDeser The `Deserializer` associated with the input `Writable` * @param nonPartitionKeyAttrs Attributes that should be filled together with their corresponding * positions in the output schema * @param mutableRow A reusable `MutableRow` that should be filled + * @param tableDeser Table Deserializer * @return An `Iterator[Row]` transformed from `iterator` */ def fillObject( iterator: Iterator[Writable], - deserializer: Deserializer, + rawDeser: Deserializer, nonPartitionKeyAttrs: Seq[(Attribute, Int)], - mutableRow: MutableRow): Iterator[Row] = { + mutableRow: MutableRow, + tableDeser: Deserializer): Iterator[Row] = { + + val soi = if (rawDeser.getObjectInspector.equals(tableDeser.getObjectInspector)) { + rawDeser.getObjectInspector.asInstanceOf[StructObjectInspector] + } else { + HiveShim.getConvertedOI( + rawDeser.getObjectInspector, + tableDeser.getObjectInspector).asInstanceOf[StructObjectInspector] + } - val soi = deserializer.getObjectInspector().asInstanceOf[StructObjectInspector] val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { case (attr, ordinal) => soi.getStructFieldRef(attr.name) -> ordinal }.unzip - // Builds specific unwrappers ahead of time according to object inspector types to avoid pattern - // matching and branching costs per row. + /** + * Builds specific unwrappers ahead of time according to object inspector + * types to avoid pattern matching and branching costs per row. + */ val unwrappers: Seq[(Any, MutableRow, Int) => Unit] = fieldRefs.map { _.getFieldObjectInspector match { case oi: BooleanObjectInspector => @@ -316,9 +331,11 @@ private[hive] object HadoopTableReader extends HiveInspectors { } } + val converter = ObjectInspectorConverters.getConverter(rawDeser.getObjectInspector, soi) + // Map each tuple to a row object iterator.map { value => - val raw = deserializer.deserialize(value) + val raw = converter.convert(rawDeser.deserialize(value)) var i = 0 while (i < fieldRefs.length) { val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index bfe43373d9534..47305571e579e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -375,9 +375,8 @@ private[hive] case class HiveUdafFunction( private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) - // Cast required to avoid type inference selecting a deprecated Hive API. private val buffer = - function.getNewAggregationBuffer.asInstanceOf[GenericUDAFEvaluator.AbstractAggregationBuffer] + function.getNewAggregationBuffer override def eval(input: Row): Any = unwrap(function.evaluate(buffer), returnInspector) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index dc61e9d2e3522..a3497eadd67f6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -23,6 +23,7 @@ import java.util.{Set => JavaSet} import org.apache.hadoop.hive.ql.exec.FunctionRegistry import org.apache.hadoop.hive.ql.io.avro.{AvroContainerInputFormat, AvroContainerOutputFormat} import org.apache.hadoop.hive.ql.metadata.Table +import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.serde2.RegexSerDe import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe @@ -153,8 +154,13 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { val describedTable = "DESCRIBE (\\w+)".r + val vs = new VariableSubstitution() + + // we should substitute variables in hql to pass the text to parseSql() as a parameter. + // Hive parser need substituted text. HiveContext.sql() does this but return a DataFrame, + // while we need a logicalPlan so we cannot reuse that. protected[hive] class HiveQLQueryExecution(hql: String) - extends this.QueryExecution(HiveQl.parseSql(hql)) { + extends this.QueryExecution(HiveQl.parseSql(vs.substitute(hiveconf, hql))) { def hiveExec(): Seq[String] = runSqlHive(hql) override def toString: String = hql + "\n" + super.toString } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 44d24273e722a..221a0c263d36c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -92,12 +92,12 @@ class CachedTableSuite extends QueryTest { } test("Drop cached table") { - sql("CREATE TABLE test(a INT)") - cacheTable("test") - sql("SELECT * FROM test").collect() - sql("DROP TABLE test") + sql("CREATE TABLE cachedTableTest(a INT)") + cacheTable("cachedTableTest") + sql("SELECT * FROM cachedTableTest").collect() + sql("DROP TABLE cachedTableTest") intercept[AnalysisException] { - sql("SELECT * FROM test").collect() + sql("SELECT * FROM cachedTableTest").collect() } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 381cd2a29123e..8011952e0d535 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -32,9 +32,12 @@ import org.apache.spark.sql.hive.test.TestHive._ case class TestData(key: Int, value: String) +case class ThreeCloumntable(key: Int, value: String, key1: String) + class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { import org.apache.spark.sql.hive.test.TestHive.implicits._ + val testData = TestHive.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() @@ -186,4 +189,43 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { sql("DROP TABLE hiveTableWithStructValue") } + + test("SPARK-5498:partition schema does not match table schema") { + val testData = TestHive.sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString))).toDF() + testData.registerTempTable("testData") + + val testDatawithNull = TestHive.sparkContext.parallelize( + (1 to 10).map(i => ThreeCloumntable(i, i.toString,null))).toDF() + + val tmpDir = Utils.createTempDir() + sql(s"CREATE TABLE table_with_partition(key int,value string) PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') SELECT key,value FROM testData") + + // test schema the same between partition and table + sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") + checkAnswer(sql("select key,value from table_with_partition where ds='1' "), + testData.collect.toSeq + ) + + // test difference type of field + sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") + checkAnswer(sql("select key,value from table_with_partition where ds='1' "), + testData.collect.toSeq + ) + + // add column to table + sql("ALTER TABLE table_with_partition ADD COLUMNS(key1 string)") + checkAnswer(sql("select key,value,key1 from table_with_partition where ds='1' "), + testDatawithNull.collect.toSeq + ) + + // change column name to table + sql("ALTER TABLE table_with_partition CHANGE COLUMN key keynew BIGINT") + checkAnswer(sql("select keynew,value from table_with_partition where ds='1' "), + testData.collect.toSeq + ) + + sql("DROP TABLE table_with_partition") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index ff2e6ea9ea51d..e5ad0bf552073 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -579,7 +579,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { Row(3) :: Row(4) :: Nil ) - table("test_parquet_ctas").queryExecution.analyzed match { + table("test_parquet_ctas").queryExecution.optimizedPlan match { case LogicalRelation(p: ParquetRelation2) => // OK case _ => fail( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index cb405f56bf53d..d7c5d1a25a82b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -22,7 +22,7 @@ import java.util import java.util.Properties import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF +import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFAverage, GenericUDF} import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} @@ -93,6 +93,15 @@ class HiveUdfSuite extends QueryTest { sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") } + test("SPARK-6409 UDAFAverage test") { + sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'") + checkAnswer( + sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"), + Seq(Row(1.0, 260.182))) + sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg") + TestHive.reset() + } + test("SPARK-2693 udaf aggregates test") { checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"), sql("SELECT max(key) FROM src").collect().toSeq) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index d891c4e8903d9..432d65a874518 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -292,7 +292,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { Seq(Row(1, "str1")) ) - table("test_parquet_ctas").queryExecution.analyzed match { + table("test_parquet_ctas").queryExecution.optimizedPlan match { case LogicalRelation(p: ParquetRelation2) => // OK case _ => fail( @@ -365,6 +365,31 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { sql("DROP TABLE IF EXISTS test_insert_parquet") } + + test("SPARK-6450 regression test") { + sql( + """CREATE TABLE IF NOT EXISTS ms_convert (key INT) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + + // This shouldn't throw AnalysisException + val analyzed = sql( + """SELECT key FROM ms_convert + |UNION ALL + |SELECT key FROM ms_convert + """.stripMargin).queryExecution.analyzed + + assertResult(2) { + analyzed.collect { + case r @ LogicalRelation(_: ParquetRelation2) => r + }.size + } + + sql("DROP TABLE ms_convert") + } } class ParquetDataSourceOffMetastoreSuite extends ParquetMetastoreSuiteBase { diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala index 30646ddbc29d8..0ed93c2c5b1fa 100644 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala @@ -34,7 +34,7 @@ import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.stats.StatsSetupConst import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo} -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, PrimitiveObjectInspector} +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, ObjectInspector, PrimitiveObjectInspector} import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory} @@ -210,7 +210,7 @@ private[hive] object HiveShim { def getDataLocationPath(p: Partition) = p.getPartitionPath - def getAllPartitionsOf(client: Hive, tbl: Table) = client.getAllPartitionsForPruner(tbl) + def getAllPartitionsOf(client: Hive, tbl: Table) = client.getAllPartitionsForPruner(tbl) def compatibilityBlackList = Seq( "decimal_.*", @@ -244,6 +244,12 @@ private[hive] object HiveShim { } } + def getConvertedOI( + inputOI: ObjectInspector, + outputOI: ObjectInspector): ObjectInspector = { + ObjectInspectorConverters.getConvertedOI(inputOI, outputOI, true) + } + def prepareWritable(w: Writable): Writable = { w } diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index f9fcbdae15745..7577309900209 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.hive +import java.util import java.util.{ArrayList => JArrayList} import java.util.Properties import java.rmi.server.UID @@ -38,7 +39,7 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, DecimalTypeInfo, TypeInfoFactory} import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} -import org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector, ObjectInspector} +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, PrimitiveObjectInspector, ObjectInspector} import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils} import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable @@ -400,7 +401,11 @@ private[hive] object HiveShim { Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) } } - + + def getConvertedOI(inputOI: ObjectInspector, outputOI: ObjectInspector): ObjectInspector = { + ObjectInspectorConverters.getConvertedOI(inputOI, outputOI) + } + /* * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that * is needed to initialize before serialization. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index db64e11e16304..f73b463d07779 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -67,12 +67,12 @@ object Checkpoint extends Logging { val REGEX = (PREFIX + """([\d]+)([\w\.]*)""").r /** Get the checkpoint file for the given checkpoint time */ - def checkpointFile(checkpointDir: String, checkpointTime: Time) = { + def checkpointFile(checkpointDir: String, checkpointTime: Time): Path = { new Path(checkpointDir, PREFIX + checkpointTime.milliseconds) } /** Get the checkpoint backup file for the given checkpoint time */ - def checkpointBackupFile(checkpointDir: String, checkpointTime: Time) = { + def checkpointBackupFile(checkpointDir: String, checkpointTime: Time): Path = { new Path(checkpointDir, PREFIX + checkpointTime.milliseconds + ".bk") } @@ -232,6 +232,8 @@ object CheckpointReader extends Logging { def read(checkpointDir: String, conf: SparkConf, hadoopConf: Configuration): Option[Checkpoint] = { val checkpointPath = new Path(checkpointDir) + + // TODO(rxin): Why is this a def?! def fs = checkpointPath.getFileSystem(hadoopConf) // Try to find the checkpoint files diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index 0e285d6088ec1..175140481e5ae 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -100,11 +100,11 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { } } - def getInputStreams() = this.synchronized { inputStreams.toArray } + def getInputStreams(): Array[InputDStream[_]] = this.synchronized { inputStreams.toArray } - def getOutputStreams() = this.synchronized { outputStreams.toArray } + def getOutputStreams(): Array[DStream[_]] = this.synchronized { outputStreams.toArray } - def getReceiverInputStreams() = this.synchronized { + def getReceiverInputStreams(): Array[ReceiverInputDStream[_]] = this.synchronized { inputStreams.filter(_.isInstanceOf[ReceiverInputDStream[_]]) .map(_.asInstanceOf[ReceiverInputDStream[_]]) .toArray diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala b/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala index a0d8fb5ab93ec..3249bb348981f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala @@ -55,7 +55,6 @@ case class Duration (private val millis: Long) { def div(that: Duration): Double = this / that - def isMultipleOf(that: Duration): Boolean = (this.millis % that.millis == 0) @@ -71,7 +70,7 @@ case class Duration (private val millis: Long) { def milliseconds: Long = millis - def prettyPrint = Utils.msDurationToString(millis) + def prettyPrint: String = Utils.msDurationToString(millis) } @@ -80,7 +79,7 @@ case class Duration (private val millis: Long) { * a given number of milliseconds. */ object Milliseconds { - def apply(milliseconds: Long) = new Duration(milliseconds) + def apply(milliseconds: Long): Duration = new Duration(milliseconds) } /** @@ -88,7 +87,7 @@ object Milliseconds { * a given number of seconds. */ object Seconds { - def apply(seconds: Long) = new Duration(seconds * 1000) + def apply(seconds: Long): Duration = new Duration(seconds * 1000) } /** @@ -96,7 +95,7 @@ object Seconds { * a given number of minutes. */ object Minutes { - def apply(minutes: Long) = new Duration(minutes * 60000) + def apply(minutes: Long): Duration = new Duration(minutes * 60000) } // Java-friendlier versions of the objects above. @@ -107,16 +106,16 @@ object Durations { /** * @return [[org.apache.spark.streaming.Duration]] representing given number of milliseconds. */ - def milliseconds(milliseconds: Long) = Milliseconds(milliseconds) + def milliseconds(milliseconds: Long): Duration = Milliseconds(milliseconds) /** * @return [[org.apache.spark.streaming.Duration]] representing given number of seconds. */ - def seconds(seconds: Long) = Seconds(seconds) + def seconds(seconds: Long): Duration = Seconds(seconds) /** * @return [[org.apache.spark.streaming.Duration]] representing given number of minutes. */ - def minutes(minutes: Long) = Minutes(minutes) + def minutes(minutes: Long): Duration = Minutes(minutes) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala b/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala index ad4f3fdd14ad6..3f5be785e1b1a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala @@ -39,18 +39,18 @@ class Interval(val beginTime: Time, val endTime: Time) { this.endTime < that.endTime } - def <= (that: Interval) = (this < that || this == that) + def <= (that: Interval): Boolean = (this < that || this == that) - def > (that: Interval) = !(this <= that) + def > (that: Interval): Boolean = !(this <= that) - def >= (that: Interval) = !(this < that) + def >= (that: Interval): Boolean = !(this < that) - override def toString = "[" + beginTime + ", " + endTime + "]" + override def toString: String = "[" + beginTime + ", " + endTime + "]" } private[streaming] object Interval { - def currentInterval(duration: Duration): Interval = { + def currentInterval(duration: Duration): Interval = { val time = new Time(System.currentTimeMillis) val intervalBegin = time.floor(duration) new Interval(intervalBegin, intervalBegin + duration) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 543224d4b07bc..f57f295874645 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -188,7 +188,7 @@ class StreamingContext private[streaming] ( /** * Return the associated Spark context */ - def sparkContext = sc + def sparkContext: SparkContext = sc /** * Set each DStreams in this context to remember RDDs it generated in the last given duration. @@ -596,7 +596,8 @@ object StreamingContext extends Logging { @deprecated("Replaced by implicit functions in the DStream companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def toPairDStreamFunctions[K, V](stream: DStream[(K, V)]) - (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = { + (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) + : PairDStreamFunctions[K, V] = { DStream.toPairDStreamFunctions(stream)(kt, vt, ord) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index 2eabdd9387913..73030e15c5661 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -415,8 +415,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T implicit val cmv2: ClassTag[V2] = fakeClassTag implicit val cmw: ClassTag[W] = fakeClassTag - def scalaTransform (inThis: RDD[T], inThat: RDD[(K2, V2)], time: Time): RDD[W] = + def scalaTransform (inThis: RDD[T], inThat: RDD[(K2, V2)], time: Time): RDD[W] = { transformFunc.call(wrapRDD(inThis), other.wrapRDD(inThat), time).rdd + } dstream.transformWith[(K2, V2), W](other.dstream, scalaTransform(_, _, _)) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 7053f47ec69a2..4c28654ef6413 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -176,11 +176,11 @@ private[python] abstract class PythonDStream( val func = new TransformFunction(pfunc) - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration - val asJavaDStream = JavaDStream.fromDStream(this) + val asJavaDStream: JavaDStream[Array[Byte]] = JavaDStream.fromDStream(this) } /** @@ -212,7 +212,7 @@ private[python] class PythonTransformed2DStream( val func = new TransformFunction(pfunc) - override def dependencies = List(parent, parent2) + override def dependencies: List[DStream[_]] = List(parent, parent2) override def slideDuration: Duration = parent.slideDuration @@ -223,7 +223,7 @@ private[python] class PythonTransformed2DStream( func(Some(rdd1), Some(rdd2), validTime) } - val asJavaDStream = JavaDStream.fromDStream(this) + val asJavaDStream: JavaDStream[Array[Byte]] = JavaDStream.fromDStream(this) } /** @@ -260,12 +260,15 @@ private[python] class PythonReducedWindowedDStream( extends PythonDStream(parent, preduceFunc) { super.persist(StorageLevel.MEMORY_ONLY) - override val mustCheckpoint = true - val invReduceFunc = new TransformFunction(pinvReduceFunc) + override val mustCheckpoint: Boolean = true + + val invReduceFunc: TransformFunction = new TransformFunction(pinvReduceFunc) def windowDuration: Duration = _windowDuration + override def slideDuration: Duration = _slideDuration + override def parentRememberDuration: Duration = rememberDuration + windowDuration override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index b874f561c12eb..795c5aa6d585b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -104,7 +104,7 @@ abstract class DStream[T: ClassTag] ( private[streaming] def parentRememberDuration = rememberDuration /** Return the StreamingContext associated with this DStream */ - def context = ssc + def context: StreamingContext = ssc /* Set the creation call site */ private[streaming] val creationSite = DStream.getCreationSite() @@ -619,14 +619,16 @@ abstract class DStream[T: ClassTag] ( * operator, so this DStream will be registered as an output stream and there materialized. */ def print(num: Int) { - def foreachFunc = (rdd: RDD[T], time: Time) => { - val firstNum = rdd.take(num + 1) - println ("-------------------------------------------") - println ("Time: " + time) - println ("-------------------------------------------") - firstNum.take(num).foreach(println) - if (firstNum.size > num) println("...") - println() + def foreachFunc: (RDD[T], Time) => Unit = { + (rdd: RDD[T], time: Time) => { + val firstNum = rdd.take(num + 1) + println("-------------------------------------------") + println("Time: " + time) + println("-------------------------------------------") + firstNum.take(num).foreach(println) + if (firstNum.size > num) println("...") + println() + } } new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register() } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala index 0dc72790fbdbd..39fd21342813e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala @@ -114,7 +114,7 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) } } - override def toString() = { + override def toString: String = { "[\n" + currentCheckpointFiles.size + " checkpoint files \n" + currentCheckpointFiles.mkString("\n") + "\n]" } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 22de8c02e63c8..66d519171fd76 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -298,7 +298,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( private[streaming] class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) { - def hadoopFiles = data.asInstanceOf[mutable.HashMap[Time, Array[String]]] + private def hadoopFiles = data.asInstanceOf[mutable.HashMap[Time, Array[String]]] override def update(time: Time) { hadoopFiles.clear() @@ -320,7 +320,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( } } - override def toString() = { + override def toString: String = { "[\n" + hadoopFiles.size + " file sets\n" + hadoopFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n") + "\n]" } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala index c81534ae584ea..fcd5216f101af 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala @@ -27,7 +27,7 @@ class FilteredDStream[T: ClassTag]( filterFunc: T => Boolean ) extends DStream[T](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala index 658623455498c..9d09a3baf37ca 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala @@ -28,7 +28,7 @@ class FlatMapValuedDStream[K: ClassTag, V: ClassTag, U: ClassTag]( flatMapValueFunc: V => TraversableOnce[U] ) extends DStream[(K, U)](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala index c7bb2833eabb8..475ea2d2d4f38 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala @@ -27,7 +27,7 @@ class FlatMappedDStream[T: ClassTag, U: ClassTag]( flatMapFunc: T => Traversable[U] ) extends DStream[U](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala index 1361c30395b57..685a32e1d280d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala @@ -28,7 +28,7 @@ class ForEachDStream[T: ClassTag] ( foreachFunc: (RDD[T], Time) => Unit ) extends DStream[Unit](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala index a9bb51f054048..dbb295fe54f71 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala @@ -25,7 +25,7 @@ private[streaming] class GlommedDStream[T: ClassTag](parent: DStream[T]) extends DStream[Array[T]](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index aa1993f0580a8..e652702e213ef 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -61,7 +61,7 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) } } - override def dependencies = List() + override def dependencies: List[DStream[_]] = List() override def slideDuration: Duration = { if (ssc == null) throw new Exception("ssc is null") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala index 3d8ee29df1e82..5994bc1e23f2b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala @@ -28,7 +28,7 @@ class MapPartitionedDStream[T: ClassTag, U: ClassTag]( preservePartitioning: Boolean ) extends DStream[U](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala index 7aea1f945d9db..954d2eb4a7b00 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala @@ -28,7 +28,7 @@ class MapValuedDStream[K: ClassTag, V: ClassTag, U: ClassTag]( mapValueFunc: V => U ) extends DStream[(K, U)](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala index 02704a8d1c2e0..fa14b2e897c3e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala @@ -27,7 +27,7 @@ class MappedDStream[T: ClassTag, U: ClassTag] ( mapFunc: T => U ) extends DStream[U](parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala index c0a5af0b65cc3..1385ccbf56ee5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -52,7 +52,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( // Reduce each batch of data using reduceByKey which will be further reduced by window // by ReducedWindowedDStream - val reducedStream = parent.reduceByKey(reduceFunc, partitioner) + private val reducedStream = parent.reduceByKey(reduceFunc, partitioner) // Persist RDDs to memory by default as these RDDs are going to be reused. super.persist(StorageLevel.MEMORY_ONLY_SER) @@ -60,7 +60,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( def windowDuration: Duration = _windowDuration - override def dependencies = List(reducedStream) + override def dependencies: List[DStream[_]] = List(reducedStream) override def slideDuration: Duration = _slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala index 880a89bc36895..7757ccac09a58 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala @@ -33,7 +33,7 @@ class ShuffledDStream[K: ClassTag, V: ClassTag, C: ClassTag]( mapSideCombine: Boolean = true ) extends DStream[(K,C)] (parent.ssc) { - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index ebb04dd35b9a2..de8718d0a80fe 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -36,7 +36,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( super.persist(StorageLevel.MEMORY_ONLY_SER) - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala index 71b61856e23c0..5d46ca0715ffd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala @@ -32,7 +32,7 @@ class TransformedDStream[U: ClassTag] ( require(parents.map(_.slideDuration).distinct.size == 1, "Some of the DStreams have different slide durations") - override def dependencies = parents.toList + override def dependencies: List[DStream[_]] = parents.toList override def slideDuration: Duration = parents.head.slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala index abbc40befa95b..9405dbaa12329 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala @@ -33,17 +33,17 @@ class UnionDStream[T: ClassTag](parents: Array[DStream[T]]) require(parents.map(_.slideDuration).distinct.size == 1, "Some of the DStreams have different slide durations") - override def dependencies = parents.toList + override def dependencies: List[DStream[_]] = parents.toList override def slideDuration: Duration = parents.head.slideDuration override def compute(validTime: Time): Option[RDD[T]] = { val rdds = new ArrayBuffer[RDD[T]]() - parents.map(_.getOrCompute(validTime)).foreach(_ match { + parents.map(_.getOrCompute(validTime)).foreach { case Some(rdd) => rdds += rdd case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + validTime) - }) + } if (rdds.size > 0) { Some(new UnionRDD(ssc.sc, rdds)) } else { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala index 775b6bfd065c0..899865a906c27 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala @@ -46,7 +46,7 @@ class WindowedDStream[T: ClassTag]( def windowDuration: Duration = _windowDuration - override def dependencies = List(parent) + override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = _slideDuration diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index dd1e96334952f..93caa4ba35c7f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -117,8 +117,8 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( override def getPreferredLocations(split: Partition): Seq[String] = { val partition = split.asInstanceOf[WriteAheadLogBackedBlockRDDPartition] val blockLocations = getBlockIdLocations().get(partition.blockId) - def segmentLocations = HdfsUtils.getFileSegmentLocations( - partition.segment.path, partition.segment.offset, partition.segment.length, hadoopConfig) - blockLocations.getOrElse(segmentLocations) + blockLocations.getOrElse( + HdfsUtils.getFileSegmentLocations( + partition.segment.path, partition.segment.offset, partition.segment.length, hadoopConfig)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala index a7d63bd4f2dbf..cd309788a7717 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala @@ -17,6 +17,7 @@ package org.apache.spark.streaming.receiver +import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.duration._ @@ -25,10 +26,10 @@ import scala.reflect.ClassTag import akka.actor._ import akka.actor.SupervisorStrategy.{Escalate, Restart} + import org.apache.spark.{Logging, SparkEnv} -import org.apache.spark.storage.StorageLevel -import java.nio.ByteBuffer import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.storage.StorageLevel /** * :: DeveloperApi :: @@ -149,13 +150,13 @@ private[streaming] class ActorReceiver[T: ClassTag]( class Supervisor extends Actor { override val supervisorStrategy = receiverSupervisorStrategy - val worker = context.actorOf(props, name) + private val worker = context.actorOf(props, name) logInfo("Started receiver worker at:" + worker.path) - val n: AtomicInteger = new AtomicInteger(0) - val hiccups: AtomicInteger = new AtomicInteger(0) + private val n: AtomicInteger = new AtomicInteger(0) + private val hiccups: AtomicInteger = new AtomicInteger(0) - def receive = { + override def receive: PartialFunction[Any, Unit] = { case IteratorData(iterator) => logDebug("received iterator") @@ -189,13 +190,12 @@ private[streaming] class ActorReceiver[T: ClassTag]( } } - def onStart() = { + def onStart(): Unit = { supervisor logInfo("Supervision tree for receivers initialized at:" + supervisor.path) - } - def onStop() = { + def onStop(): Unit = { supervisor ! PoisonPill } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index ee5e639b26d91..42514d8b47dcf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -120,7 +120,7 @@ private[streaming] class BlockGenerator( * `BlockGeneratorListener.onAddData` callback will be called. All received data items * will be periodically pushed into BlockManager. */ - def addDataWithCallback(data: Any, metadata: Any) = synchronized { + def addDataWithCallback(data: Any, metadata: Any): Unit = synchronized { waitToPush() currentBuffer += data listener.onAddData(data, metadata) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala index 5acf8a9a811ee..5b5a3fe648602 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala @@ -245,7 +245,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * Get the unique identifier the receiver input stream that this * receiver is associated with. */ - def streamId = id + def streamId: Int = id /* * ================= diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index 1f0244c251eba..4943f29395d12 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -162,13 +162,13 @@ private[streaming] abstract class ReceiverSupervisor( } /** Check if receiver has been marked for stopping */ - def isReceiverStarted() = { + def isReceiverStarted(): Boolean = { logDebug("state = " + receiverState) receiverState == Started } /** Check if receiver has been marked for stopping */ - def isReceiverStopped() = { + def isReceiverStopped(): Boolean = { logDebug("state = " + receiverState) receiverState == Stopped } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 7d29ed88cfcb4..8f2f1fef76874 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -23,7 +23,7 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable.ArrayBuffer import scala.concurrent.Await -import akka.actor.{Actor, Props} +import akka.actor.{ActorRef, Actor, Props} import akka.pattern.ask import com.google.common.base.Throwables import org.apache.hadoop.conf.Configuration @@ -83,7 +83,7 @@ private[streaming] class ReceiverSupervisorImpl( private val actor = env.actorSystem.actorOf( Props(new Actor { - override def receive() = { + override def receive: PartialFunction[Any, Unit] = { case StopReceiver => logInfo("Received stop signal") stop("Stopped by driver", None) @@ -92,7 +92,7 @@ private[streaming] class ReceiverSupervisorImpl( cleanupOldBlocks(threshTime) } - def ref = self + def ref: ActorRef = self }), "Receiver-" + streamId + "-" + System.currentTimeMillis()) /** Unique block ids if one wants to add blocks directly */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala index 7e0f6b2cdfc08..30cf87f5b7dd1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala @@ -36,5 +36,5 @@ class Job(val time: Time, func: () => _) { id = "streaming job " + time + "." + number } - override def toString = id + override def toString: String = id } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 59488dfb0f8c6..4946806d2ee95 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -82,7 +82,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { if (eventActor != null) return // generator has already been started eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { - def receive = { + override def receive: PartialFunction[Any, Unit] = { case event: JobGeneratorEvent => processEvent(event) } }), "JobGenerator") @@ -111,8 +111,8 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val pollTime = 100 // To prevent graceful stop to get stuck permanently - def hasTimedOut = { - val timedOut = System.currentTimeMillis() - timeWhenStopStarted > stopTimeout + def hasTimedOut: Boolean = { + val timedOut = (System.currentTimeMillis() - timeWhenStopStarted) > stopTimeout if (timedOut) { logWarning("Timed out while stopping the job generator (timeout = " + stopTimeout + ")") } @@ -133,7 +133,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { logInfo("Stopped generation timer") // Wait for the jobs to complete and checkpoints to be written - def haveAllBatchesBeenProcessed = { + def haveAllBatchesBeenProcessed: Boolean = { lastProcessedBatch != null && lastProcessedBatch.milliseconds == stopTime } logInfo("Waiting for jobs to be processed and checkpoints to be written") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 60bc099b27a4c..d6a93acbe711b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -56,7 +56,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { logDebug("Starting JobScheduler") eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { - def receive = { + override def receive: PartialFunction[Any, Unit] = { case event: JobSchedulerEvent => processEvent(event) } }), "JobScheduler") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index 8c15a75b1b0e0..5b134877d0b2d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -28,8 +28,7 @@ private[streaming] case class JobSet( time: Time, jobs: Seq[Job], - receivedBlockInfo: Map[Int, Array[ReceivedBlockInfo]] = Map.empty - ) { + receivedBlockInfo: Map[Int, Array[ReceivedBlockInfo]] = Map.empty) { private val incompleteJobs = new HashSet[Job]() private val submissionTime = System.currentTimeMillis() // when this jobset was submitted @@ -48,17 +47,17 @@ case class JobSet( if (hasCompleted) processingEndTime = System.currentTimeMillis() } - def hasStarted = processingStartTime > 0 + def hasStarted: Boolean = processingStartTime > 0 - def hasCompleted = incompleteJobs.isEmpty + def hasCompleted: Boolean = incompleteJobs.isEmpty // Time taken to process all the jobs from the time they started processing // (i.e. not including the time they wait in the streaming scheduler queue) - def processingDelay = processingEndTime - processingStartTime + def processingDelay: Long = processingEndTime - processingStartTime // Time taken to process all the jobs from the time they were submitted // (i.e. including the time they wait in the streaming scheduler queue) - def totalDelay = { + def totalDelay: Long = { processingEndTime - time.milliseconds } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index b36aeb341d25e..98900473138fe 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -72,7 +72,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false private var actor: ActorRef = null /** Start the actor and receiver execution thread. */ - def start() = synchronized { + def start(): Unit = synchronized { if (actor != null) { throw new SparkException("ReceiverTracker already started") } @@ -86,7 +86,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } /** Stop the receiver execution thread. */ - def stop(graceful: Boolean) = synchronized { + def stop(graceful: Boolean): Unit = synchronized { if (!receiverInputStreams.isEmpty && actor != null) { // First, stop the receivers if (!skipReceiverLaunch) receiverExecutor.stop(graceful) @@ -201,7 +201,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Actor to receive messages from the receivers. */ private class ReceiverTrackerActor extends Actor { - def receive = { + override def receive: PartialFunction[Any, Unit] = { case RegisterReceiver(streamId, typ, host, receiverActor) => registerReceiver(streamId, typ, host, receiverActor, sender) sender ! true @@ -244,16 +244,15 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false if (graceful) { val pollTime = 100 - def done = { receiverInfo.isEmpty && !running } logInfo("Waiting for receiver job to terminate gracefully") - while(!done) { + while (receiverInfo.nonEmpty || running) { Thread.sleep(pollTime) } logInfo("Waited for receiver job to terminate gracefully") } // Check if all the receivers have been deregistered or not - if (!receiverInfo.isEmpty) { + if (receiverInfo.nonEmpty) { logWarning("Not all of the receivers have deregistered, " + receiverInfo) } else { logInfo("All of the receivers have deregistered successfully") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index 5ee53a5c5f561..e4bd067cacb77 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -17,9 +17,10 @@ package org.apache.spark.streaming.ui +import scala.collection.mutable.{Queue, HashMap} + import org.apache.spark.streaming.{Time, StreamingContext} import org.apache.spark.streaming.scheduler._ -import scala.collection.mutable.{Queue, HashMap} import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted import org.apache.spark.streaming.scheduler.StreamingListenerBatchStarted import org.apache.spark.streaming.scheduler.BatchInfo @@ -59,11 +60,13 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } } - override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted) = synchronized { - runningBatchInfos(batchSubmitted.batchInfo.batchTime) = batchSubmitted.batchInfo + override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit = { + synchronized { + runningBatchInfos(batchSubmitted.batchInfo.batchTime) = batchSubmitted.batchInfo + } } - override def onBatchStarted(batchStarted: StreamingListenerBatchStarted) = synchronized { + override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = synchronized { runningBatchInfos(batchStarted.batchInfo.batchTime) = batchStarted.batchInfo waitingBatchInfos.remove(batchStarted.batchInfo.batchTime) @@ -72,19 +75,21 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } } - override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) = synchronized { - waitingBatchInfos.remove(batchCompleted.batchInfo.batchTime) - runningBatchInfos.remove(batchCompleted.batchInfo.batchTime) - completedaBatchInfos.enqueue(batchCompleted.batchInfo) - if (completedaBatchInfos.size > batchInfoLimit) completedaBatchInfos.dequeue() - totalCompletedBatches += 1L - - batchCompleted.batchInfo.receivedBlockInfo.foreach { case (_, infos) => - totalProcessedRecords += infos.map(_.numRecords).sum + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = { + synchronized { + waitingBatchInfos.remove(batchCompleted.batchInfo.batchTime) + runningBatchInfos.remove(batchCompleted.batchInfo.batchTime) + completedaBatchInfos.enqueue(batchCompleted.batchInfo) + if (completedaBatchInfos.size > batchInfoLimit) completedaBatchInfos.dequeue() + totalCompletedBatches += 1L + + batchCompleted.batchInfo.receivedBlockInfo.foreach { case (_, infos) => + totalProcessedRecords += infos.map(_.numRecords).sum + } } } - def numReceivers = synchronized { + def numReceivers: Int = synchronized { ssc.graph.getReceiverInputStreams().size } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala index a73d6f3bf0661..4d968f8bfa7a8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala @@ -18,9 +18,7 @@ package org.apache.spark.streaming.util import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ import org.apache.spark.util.collection.OpenHashMap -import scala.collection.JavaConversions.mapAsScalaMap private[streaming] object RawTextHelper { @@ -71,7 +69,7 @@ object RawTextHelper { var count = 0 while(data.hasNext) { - value = data.next + value = data.next() if (value != null) { count += 1 if (len == 0) { @@ -108,9 +106,13 @@ object RawTextHelper { } } - def add(v1: Long, v2: Long) = (v1 + v2) + def add(v1: Long, v2: Long): Long = { + v1 + v2 + } - def subtract(v1: Long, v2: Long) = (v1 - v2) + def subtract(v1: Long, v2: Long): Long = { + v1 - v2 + } - def max(v1: Long, v2: Long) = math.max(v1, v2) + def max(v1: Long, v2: Long): Long = math.max(v1, v2) }