From 3903d27ea5df315ea75317b4e84166e04a10b629 Mon Sep 17 00:00:00 2001 From: Sriram Mohanty <69749553+sriram251-code@users.noreply.github.com> Date: Fri, 3 Feb 2023 20:07:51 +0530 Subject: [PATCH] Spark events bronze par session (#678) * initial commit * Test Spark per session * notes from meeting * persession implemented * persession implemented * persession implemented * persession implemented * persession implemented * persession implemented * pr review implemented * 686 - SparkEvents Executor ID Schema Handler * global Session - initialize configs * Concurrent Writes - Table Lock for Parallelized Loads (#691) * initi commit -- working * code modified preWriteActions added * re-instanted perform retry for legacy deployments and added comments * added logging details * clear sessionsMap on batch runner --------- Co-authored-by: sriram251-code * minor fixes --------- Co-authored-by: Daniel Tomes <10840635+GeekSheikh@users.noreply.github.com> Co-authored-by: geeksheikh --- .../databricks/labs/overwatch/ApiCallV2.scala | 12 +- .../labs/overwatch/BatchRunner.scala | 3 +- .../overwatch/MultiWorkspaceDeployment.scala | 83 +++++---- .../labs/overwatch/env/Database.scala | 171 +++++++++++++++--- .../labs/overwatch/env/Workspace.scala | 8 +- .../overwatch/pipeline/BronzeTransforms.scala | 20 +- .../labs/overwatch/pipeline/Initializer.scala | 3 +- .../labs/overwatch/pipeline/Pipeline.scala | 5 +- .../labs/overwatch/utils/Config.scala | 10 +- .../overwatch/utils/SparkSessionWrapper.scala | 70 +++++-- .../labs/overwatch/ApiCallV2Test.scala | 16 +- 11 files changed, 282 insertions(+), 119 deletions(-) diff --git a/src/main/scala/com/databricks/labs/overwatch/ApiCallV2.scala b/src/main/scala/com/databricks/labs/overwatch/ApiCallV2.scala index 1aa1e3aff..7f3e6b149 100644 --- a/src/main/scala/com/databricks/labs/overwatch/ApiCallV2.scala +++ b/src/main/scala/com/databricks/labs/overwatch/ApiCallV2.scala @@ -56,13 +56,12 @@ object ApiCallV2 extends SparkSessionWrapper { * @param accumulator To make track of number of api request. * @return */ - def apply(apiEnv: ApiEnv, apiName: String, queryMap: Map[String, String], tempSuccessPath: String, accumulator: LongAccumulator): ApiCallV2 = { + def apply(apiEnv: ApiEnv, apiName: String, queryMap: Map[String, String], tempSuccessPath: String): ApiCallV2 = { new ApiCallV2(apiEnv) .setEndPoint(apiName) .buildMeta(apiName) .setQueryMap(queryMap) .setSuccessTempPath(tempSuccessPath) - .setAccumulator(accumulator) } /** @@ -125,9 +124,7 @@ class ApiCallV2(apiEnv: ApiEnv) extends SparkSessionWrapper { private var _apiFailureCount: Int = 0 private var _printFinalStatusFlag: Boolean = true private var _queryMap: Map[String, String] = Map[String, String]() - private var _accumulator: LongAccumulator = sc.longAccumulator("ApiAccumulator") //Multithreaded call accumulator will make track of the request. - protected def accumulator: LongAccumulator = _accumulator protected def apiSuccessCount: Int = _apiSuccessCount protected def apiFailureCount: Int = _apiFailureCount @@ -148,10 +145,6 @@ class ApiCallV2(apiEnv: ApiEnv) extends SparkSessionWrapper { protected def queryMap: Map[String, String] = _queryMap - private[overwatch] def setAccumulator(value: LongAccumulator): this.type = { - _accumulator = value - this - } private[overwatch] def setApiV(value: Double): this.type = { apiMeta.setApiV("api/"+value) @@ -549,7 +542,7 @@ class ApiCallV2(apiEnv: ApiEnv) extends SparkSessionWrapper { * Performs api calls in parallel. 
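// A condensed, self-contained sketch of the caller-supplied accumulator pattern introduced in this
// hunk: rather than every ApiCallV2 instance allocating its own LongAccumulator, the caller builds
// one accumulator and passes it to executeMultiThread so all parallel calls in a batch share a
// single request counter. `fetchPage` below is a hypothetical stand-in for the real HTTP call; only
// the accumulator wiring reflects the change, the rest is illustrative.
import org.apache.spark.sql.SparkSession
import org.apache.spark.util.LongAccumulator

import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object AccumulatorSketch {
  // placeholder for an API request; the shared accumulator counts every request issued
  def fetchPage(endpoint: String, requestCounter: LongAccumulator): String = {
    requestCounter.add(1)
    s"""{"endpoint":"$endpoint"}"""
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("acc-sketch").getOrCreate()
    val apiRequestCounter = spark.sparkContext.longAccumulator("ApiAccumulator")
    implicit val ec: ExecutionContext = ExecutionContext.global

    // many parallel "API calls" sharing one counter, as callers of executeMultiThread now do
    val calls = (1 to 8).map(i => Future(fetchPage(s"clusters/events?page=$i", apiRequestCounter)))
    Await.result(Future.sequence(calls), Duration.Inf)

    println(s"requests issued: ${apiRequestCounter.value}") // 8
    spark.stop()
  }
}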
* @return */ - def executeMultiThread(): util.ArrayList[String] = { + def executeMultiThread(accumulator: LongAccumulator): util.ArrayList[String] = { @tailrec def executeThreadedHelper(): util.ArrayList[String] = { val response = getResponse responseCodeHandler(response) @@ -607,7 +600,6 @@ class ApiCallV2(apiEnv: ApiEnv) extends SparkSessionWrapper { responseCodeHandler(response) _apiResponseArray.add(response.body) if (apiMeta.storeInTempLocation && successTempPath.nonEmpty) { - accumulator.add(1) if (apiEnv.successBatchSize <= _apiResponseArray.size()) { //Checking if its right time to write the batches into persistent storage val responseFlag = PipelineFunctions.writeMicroBatchToTempLocation(successTempPath.get, _apiResponseArray.toString) if (responseFlag) { //Clearing the resultArray in-case of successful write diff --git a/src/main/scala/com/databricks/labs/overwatch/BatchRunner.scala b/src/main/scala/com/databricks/labs/overwatch/BatchRunner.scala index 7f7892ba4..809ca9f84 100644 --- a/src/main/scala/com/databricks/labs/overwatch/BatchRunner.scala +++ b/src/main/scala/com/databricks/labs/overwatch/BatchRunner.scala @@ -7,6 +7,8 @@ import org.apache.log4j.{Level, Logger} object BatchRunner extends SparkSessionWrapper { private val logger: Logger = Logger.getLogger(this.getClass) + SparkSessionWrapper.sessionsMap.clear() + SparkSessionWrapper.globalTableLock.clear() private def setGlobalDeltaOverrides(): Unit = { spark.conf.set("spark.databricks.delta.optimize.maxFileSize", 1024 * 1024 * 128) @@ -68,7 +70,6 @@ object BatchRunner extends SparkSessionWrapper { Gold(workspace).run() } - } diff --git a/src/main/scala/com/databricks/labs/overwatch/MultiWorkspaceDeployment.scala b/src/main/scala/com/databricks/labs/overwatch/MultiWorkspaceDeployment.scala index 29993561f..fc0657a84 100644 --- a/src/main/scala/com/databricks/labs/overwatch/MultiWorkspaceDeployment.scala +++ b/src/main/scala/com/databricks/labs/overwatch/MultiWorkspaceDeployment.scala @@ -2,6 +2,7 @@ package com.databricks.labs.overwatch import com.databricks.labs.overwatch.pipeline.TransformFunctions._ import com.databricks.labs.overwatch.pipeline._ +import com.databricks.labs.overwatch.utils.SparkSessionWrapper.parSessionsOn import com.databricks.labs.overwatch.utils._ import com.databricks.labs.overwatch.validation.DeploymentValidation import org.apache.log4j.{Level, Logger} @@ -191,6 +192,8 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { fullMsg, Some(multiWorkspaceParams.deploymentId) )) + } finally { + clearThreadFromSessionsMap() } } @@ -215,6 +218,8 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { fullMsg, Some(multiWorkspaceParams.deploymentId) )) + } finally { + clearThreadFromSessionsMap() } } @@ -239,6 +244,8 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { fullMsg, Some(multiWorkspaceParams.deploymentId) )) + }finally { + clearThreadFromSessionsMap() } } @@ -355,42 +362,54 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { */ def deploy(parallelism: Int = 4, zones: String = "Bronze,Silver,Gold"): Unit = { val processingStartTime = System.currentTimeMillis(); - println("ParallelismLevel :" + parallelism) - - val multiWorkspaceConfig = generateMultiWorkspaceConfig(configCsvPath, deploymentId, outputPath) - snapshotConfig(multiWorkspaceConfig) - val params = DeploymentValidation - .performMandatoryValidation(multiWorkspaceConfig, parallelism) - .map(buildParams) - - println("Workspace to be Deployed :" + params.size) - val zoneArray = zones.split(",") - 
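// The deployment fan-out in deploy(), reduced to a runnable skeleton: each zone is processed on a
// fixed-size thread pool, one Future per workspace, and a synchronized counter is polled until every
// Future has completed before the next zone starts. `deployWorkspace` is a hypothetical stand-in for
// startBronzeDeployment / startSilverDeployment / startGoldDeployment; the pool size and 5-second
// poll interval mirror the surrounding code, and the pool shutdown is added only so the sketch
// terminates cleanly.
import java.util.{ArrayList => JArrayList, Collections}
import java.util.concurrent.Executors
import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future}

object DeploySketch {
  def deployWorkspace(zone: String, workspaceId: String): Unit =
    println(s"deploying $zone for $workspaceId on ${Thread.currentThread().getName}")

  def main(args: Array[String]): Unit = {
    val parallelism = 4
    val workspaces = Seq("ws-101", "ws-102", "ws-103")

    Seq("bronze", "silver", "gold").foreach { zone =>
      val responseCounter = Collections.synchronizedList(new JArrayList[Int]())
      val pool = Executors.newFixedThreadPool(parallelism)
      implicit val ec: ExecutionContextExecutor = ExecutionContext.fromExecutor(pool)

      workspaces.foreach { ws =>
        Future(deployWorkspace(zone, ws)).onComplete(_ => responseCounter.add(1))
      }
      // block until every workspace in this zone has reported back (success or failure)
      while (responseCounter.size() < workspaces.length) Thread.sleep(5000)
      pool.shutdown()
    }
  }
}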
zoneArray.foreach(zone => { - val responseCounter = Collections.synchronizedList(new util.ArrayList[Int]()) - implicit val ec: ExecutionContextExecutor = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(parallelism)) - params.foreach(deploymentParams => { - val future = Future { - zone.toLowerCase match { - case "bronze" => - startBronzeDeployment(deploymentParams) - case "silver" => - startSilverDeployment(deploymentParams) - case "gold" => - startGoldDeployment(deploymentParams) + try { + if (parallelism > 1) SparkSessionWrapper.parSessionsOn = true + SparkSessionWrapper.sessionsMap.clear() + SparkSessionWrapper.globalTableLock.clear() + + // initialize spark overrides for global spark conf + PipelineFunctions.setSparkOverrides(spark(globalSession = true), SparkSessionWrapper.globalSparkConfOverrides) + + println("ParallelismLevel :" + parallelism) + val multiWorkspaceConfig = generateMultiWorkspaceConfig(configCsvPath, deploymentId, outputPath) + snapshotConfig(multiWorkspaceConfig) + val params = DeploymentValidation + .performMandatoryValidation(multiWorkspaceConfig, parallelism) + .map(buildParams) + println("Workspace to be Deployed :" + params.size) + + val zoneArray = zones.split(",") + zoneArray.foreach(zone => { + val responseCounter = Collections.synchronizedList(new util.ArrayList[Int]()) + implicit val ec: ExecutionContextExecutor = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(parallelism)) + params.foreach(deploymentParams => { + val future = Future { + zone.toLowerCase match { + case "bronze" => + startBronzeDeployment(deploymentParams) + case "silver" => + startSilverDeployment(deploymentParams) + case "gold" => + startGoldDeployment(deploymentParams) + } } - } - future.onComplete { - case _ => - responseCounter.add(1) + future.onComplete { + case _ => + responseCounter.add(1) + } + }) + while (responseCounter.size() < params.length) { + Thread.sleep(5000) } }) - while (responseCounter.size() < params.length) { - Thread.sleep(5000) - } - }) - saveDeploymentReport(deploymentReport, multiWorkspaceConfig.head.etl_storage_prefix, "deploymentReport") + saveDeploymentReport(deploymentReport, multiWorkspaceConfig.head.etl_storage_prefix, "deploymentReport") + } catch { + case e: Exception => throw e + } finally { + SparkSessionWrapper.sessionsMap.clear() + SparkSessionWrapper.globalTableLock.clear() + } println(s"""Deployment completed in sec ${(System.currentTimeMillis() - processingStartTime) / 1000}""") - } /** diff --git a/src/main/scala/com/databricks/labs/overwatch/env/Database.scala b/src/main/scala/com/databricks/labs/overwatch/env/Database.scala index f7b0c260d..b3e22e3de 100644 --- a/src/main/scala/com/databricks/labs/overwatch/env/Database.scala +++ b/src/main/scala/com/databricks/labs/overwatch/env/Database.scala @@ -13,6 +13,7 @@ import org.apache.spark.sql.{Column, DataFrame, DataFrameWriter, Row} import java.util import java.util.UUID import scala.annotation.tailrec +import scala.util.Random class Database(config: Config) extends SparkSessionWrapper { @@ -166,12 +167,16 @@ class Database(config: Config) extends SparkSessionWrapper { .map(_.name).headOption } - // TODO - refactor this write function and the writer from the target - // write function has gotten overly complex - def write(df: DataFrame, target: PipelineTable, pipelineSnapTime: Column, maxMergeScanDates: Array[String] = Array()): Boolean = { + /** + * Add metadata, cleanse duplicates, and persistBeforeWrite for perf as needed based on target + * + * @param df original dataframe 
that is to be written + * @param target target to which the df is to be written + * @param pipelineSnapTime pipelineSnapTime + * @return + */ + private def preWriteActions(df: DataFrame, target: PipelineTable, pipelineSnapTime: Column): DataFrame = { var finalSourceDF: DataFrame = df - - // apend metadata to source DF finalSourceDF = if (target.withCreateDate) finalSourceDF.withColumn("Pipeline_SnapTS", pipelineSnapTime) else finalSourceDF finalSourceDF = if (target.withOverwatchRunID) finalSourceDF.withColumn("Overwatch_RunID", lit(config.runID)) else finalSourceDF finalSourceDF = if (target.workspaceName) finalSourceDF.withColumn("workspace_name", lit(config.workspaceName)) else finalSourceDF @@ -180,6 +185,35 @@ class Database(config: Config) extends SparkSessionWrapper { finalSourceDF = if (!target.permitDuplicateKeys) finalSourceDF.dedupByKey(target.keys, target.incrementalColumns) else finalSourceDF val finalDF = if (target.persistBeforeWrite) persistAndLoad(finalSourceDF, target) else finalSourceDF + finalDF + } + + // TODO - refactor this write function and the writer from the target + // write function has gotten overly complex + + /** + * Write the dataframe to the target + * + * @param df df with data to be written + * @param target target to where data should be saved + * @param pipelineSnapTime pipelineSnapTime + * @param maxMergeScanDates perf -- max dates to scan on a merge write. If populated this will filter the target source + * df to minimize the right-side scan. + * @param preWritesPerformed whether pre-write actions have already been completed prior to calling write. This is + * defaulted to false and should only be set to true in specific circumstances when + * preWriteActions was called prior to calling the write function. Typically used for + * concurrency locking. + * @return + */ + def write(df: DataFrame, target: PipelineTable, pipelineSnapTime: Column, maxMergeScanDates: Array[String] = Array(), preWritesPerformed: Boolean = false): Unit = { + val finalSourceDF: DataFrame = df + + // append metadata to source DF and cleanse as necessary + val finalDF = if (!preWritesPerformed) { + preWriteActions(finalSourceDF, target, pipelineSnapTime) + } else { + finalSourceDF + } // ON FIRST RUN - WriteMode is automatically overwritten to APPEND if (target.writeMode == WriteMode.merge) { // DELTA MERGE / UPSERT @@ -254,30 +288,59 @@ class Database(config: Config) extends SparkSessionWrapper { logger.log(Level.INFO, s"Completed write to ${target.tableFullName}") } registerTarget(target) - true } + /** + * Function forces the thread to sleep for some specified time. + * @param tableName Name of table for logging only + * @param minimumSeconds minimum number of seconds for which to sleep + * @param maxRandomSeconds max random seconds to add + */ + private def coolDown(tableName: String, minimumSeconds: Long, maxRandomSeconds: Long): Unit = { + val rnd = new scala.util.Random + val number: Long = ((rnd.nextFloat() * maxRandomSeconds) + minimumSeconds).toLong * 1000L + logger.log(Level.INFO,"Slowing parallel writes to " + tableName + "sleeping..." 
+ number + + " thread name " + Thread.currentThread().getName) + Thread.sleep(number) + } - def performRetry(inputDf: DataFrame, + /** + * Perform write retry after cooldown for legacy deployments + * race conditions can occur when multiple workspaces attempt to modify the schema at the same time, when this + * occurs, simply retry the write after some cooldown + * @param inputDf DF to write + * @param target target to where to write + * @param pipelineSnapTime pipelineSnapTime + * @param maxMergeScanDates same as write function + */ + private def performRetry(inputDf: DataFrame, target: PipelineTable, pipelineSnapTime: Column, - maxMergeScanDates: Array[String] = Array()): this.type = { - @tailrec def executeRetry(retryCount: Int): this.type = { + maxMergeScanDates: Array[String] = Array()): Unit = { + @tailrec def executeRetry(retryCount: Int): Unit = { val rerunFlag = try { write(inputDf, target, pipelineSnapTime, maxMergeScanDates) false } catch { case e: Throwable => val exceptionMsg = e.getMessage.toLowerCase() + logger.log(Level.WARN, + s""" + |DELTA Table Write Failure: + |$exceptionMsg + |Will Retry After a small delay. + |This is usually caused by multiple writes attempting to evolve the schema simultaneously + |""".stripMargin) if (exceptionMsg != null && (exceptionMsg.contains("concurrent") || exceptionMsg.contains("conflicting")) && retryCount < 5) { - coolDown(target.tableFullName) + coolDown(target.tableFullName, 30, 30) true } else { throw e } } - if (retryCount < 5 && rerunFlag) executeRetry(retryCount + 1) else this + if (retryCount < 5 && rerunFlag) executeRetry(retryCount + 1) } + try { executeRetry(1) } catch { @@ -287,11 +350,59 @@ class Database(config: Config) extends SparkSessionWrapper { } - def writeWithRetry(df: DataFrame, - target: PipelineTable, - pipelineSnapTime: Column, - maxMergeScanDates: Array[String] = Array(), - daysToProcess: Option[Int] = None): Boolean = { + /** + * Used for multithreaded multiworkspace deployments + * Check if a table is locked and if it is wait until max timeout for it to be unlocked and fail the write if + * the timeout is reached + * Table Lock Timeout can be overridden by setting cluster spark config overwatch.tableLockTimeout -- in milliseconds + * default table lock timeout is 20 minutes or 1200000 millis + * + * @param tableName name of the table to be written + * @return + */ + private def targetNotLocked(tableName: String): Boolean = { + val defaultTimeout: String = "1200000" + val timeout = spark(globalSession = true).conf.getOption("overwatch.tableLockTimeout").getOrElse(defaultTimeout).toLong + val timerStart = System.currentTimeMillis() + + @tailrec def testLock(retryCount: Int): Boolean = { + val currWaitTime = System.currentTimeMillis() - timerStart + val withinTimeout = currWaitTime < timeout + if (SparkSessionWrapper.globalTableLock.contains(tableName)) { + if (withinTimeout) { + logger.log(Level.WARN, s"TABLE LOCKED: $tableName for $currWaitTime -- waiting for parallel writes to complete") + coolDown(tableName, retryCount * 10, 2) // add 10 to 12 seconds to sleep for each lock test + testLock(retryCount + 1) + } else { + throw new Exception(s"TABLE LOCK TIMEOUT - The table $tableName remained locked for more than the configured " + + s"max timeout of $timeout millis. This may be increased by setting the following spark config in the cluster" + + s"to something higher than the default (20 minutes). Usually only necessary for historical loads. 
\n" + + s"overwatch.tableLockTimeout") + } + } else true // table not locked + } + + testLock(retryCount = 1) + } + + + /** + * Wrapper for the write function + * If legacy deployment retry delta write in event of race condition + * If multiworkspace -- implement table locking to alleviate race conditions on parallelize workspace loads + * + * @param df + * @param target + * @param pipelineSnapTime + * @param maxMergeScanDates + * @param daysToProcess + * @return + */ + private[overwatch] def writeWithRetry(df: DataFrame, + target: PipelineTable, + pipelineSnapTime: Column, + maxMergeScanDates: Array[String] = Array(), + daysToProcess: Option[Int] = None): Unit = { val needsCache = daysToProcess.getOrElse(1000) < 5 && !target.autoOptimize val inputDf = if (needsCache) { @@ -299,21 +410,25 @@ class Database(config: Config) extends SparkSessionWrapper { df.persist() } else df if (needsCache) inputDf.count() - performRetry(inputDf,target, pipelineSnapTime, maxMergeScanDates) - true - } - /** - * Function forces the thread to sleep for a random 30-60 seconds. - * @param tableName - */ - private def coolDown(tableName: String): Unit = { - val rnd = new scala.util.Random - val number:Long = (rnd.nextFloat() * 30 + 30).toLong*1000 - logger.log(Level.INFO,"Slowing multithreaded writing for " + tableName + "sleeping..." + number+" thread name "+Thread.currentThread().getName) - Thread.sleep(number) + if (!target.config.isMultiworkspaceDeployment) { // legacy deployment + performRetry(inputDf, target, pipelineSnapTime, maxMergeScanDates) + } else { // multi-workspace deployment + val withMetaDf = preWriteActions(df, target, pipelineSnapTime) + if (targetNotLocked(target.tableFullName)) { + try { + SparkSessionWrapper.globalTableLock.add(target.tableFullName) + write(withMetaDf, target, pipelineSnapTime, maxMergeScanDates, preWritesPerformed = true) + } catch { + case e: Throwable => throw e + } finally { + SparkSessionWrapper.globalTableLock.remove(target.tableFullName) + } + } + } } + } object Database { diff --git a/src/main/scala/com/databricks/labs/overwatch/env/Workspace.scala b/src/main/scala/com/databricks/labs/overwatch/env/Workspace.scala index 4550db71c..c138dcd5f 100644 --- a/src/main/scala/com/databricks/labs/overwatch/env/Workspace.scala +++ b/src/main/scala/com/databricks/labs/overwatch/env/Workspace.scala @@ -150,8 +150,7 @@ class Workspace(config: Config) extends SparkSessionWrapper { config.apiEnv, sqlQueryHistoryEndpoint, jsonQuery, - tempSuccessPath = s"${config.tempWorkingDir}/sqlqueryhistory_silver/${System.currentTimeMillis()}", - accumulator = acc + tempSuccessPath = s"${config.tempWorkingDir}/sqlqueryhistory_silver/${System.currentTimeMillis()}" ) .execute() .asDF() @@ -195,9 +194,8 @@ class Workspace(config: Config) extends SparkSessionWrapper { config.apiEnv, sqlQueryHistoryEndpoint, jsonQuery, - tempSuccessPath = tmpSqlQueryHistorySuccessPath, - accumulator = acc - ).executeMultiThread() + tempSuccessPath = tmpSqlQueryHistorySuccessPath + ).executeMultiThread(acc) synchronized { apiObj.forEach( diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/BronzeTransforms.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/BronzeTransforms.scala index d4f357f58..526aa51ec 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/BronzeTransforms.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/BronzeTransforms.scala @@ -478,7 +478,7 @@ trait BronzeTransforms extends SparkSessionWrapper { "limit" -> "500" ) val future = Future { - 
val apiObj = ApiCallV2(apiEnv, "clusters/events", jsonQuery, tmpClusterEventsSuccessPath,accumulator).executeMultiThread() + val apiObj = ApiCallV2(apiEnv, "clusters/events", jsonQuery, tmpClusterEventsSuccessPath).executeMultiThread(accumulator) synchronized { apiObj.forEach( obj => if (obj.contains("events")) { @@ -741,7 +741,7 @@ trait BronzeTransforms extends SparkSessionWrapper { val pathsGlob = validNewFilesWMetaDF .filter(!'failed && 'withinSpecifiedTimeRange) .orderBy('fileSize.desc) - .select('fileName) + .select('filename) .as[String].collect if (pathsGlob.nonEmpty) { // new files less bad files and already-processed files logger.log(Level.INFO, s"VALID NEW EVENT LOGS FOUND: COUNT --> ${pathsGlob.length}") @@ -834,10 +834,20 @@ trait BronzeTransforms extends SparkSessionWrapper { .otherwise(col("Stage Attempt ID")) } else col("Stage Attempt ID") + // raw data contains both "Executor ID" and "executorId" at root for different events + val executorIdOverride: Column = if(baseEventsDF.columns.contains("Executor ID")) { + if (baseEventsDF.columns.contains("executorId")) { // blacklisted executor ids cannot exist if executor ids do not + concat(col("Executor ID"), 'executorId) + } else col("Executor ID") + } else { // handle missing Executor ID field + lit(null).cast("long") + } + val bronzeSparkEventsScrubber = getSparkEventsSchemaScrubber(baseEventsDF) val rawScrubbed = if (baseEventsDF.columns.count(_.toLowerCase().replace(" ", "") == "stageid") > 1) { baseEventsDF + .withColumn("Executor ID", executorIdOverride) .withColumn("progress", progressCol) .withColumn("filename", input_file_name) .withColumn("pathSize", size(split('filename, "/"))) @@ -845,17 +855,18 @@ trait BronzeTransforms extends SparkSessionWrapper { .withColumn("clusterId", split('filename, "/")('pathSize - lit(5))) .withColumn("StageID", stageIDColumnOverride) .withColumn("StageAttemptID", stageAttemptIDColumnOverride) - .drop("pathSize", "Stage ID", "stageId", "Stage Attempt ID", "stageAttemptId") + .drop("pathSize", "executorId", "Stage ID", "stageId", "Stage Attempt ID", "stageAttemptId") .withColumn("filenameGroup", groupFilename('filename)) .scrubSchema(bronzeSparkEventsScrubber) } else { baseEventsDF + .withColumn("Executor ID", executorIdOverride) .withColumn("progress", progressCol) .withColumn("filename", input_file_name) .withColumn("pathSize", size(split('filename, "/"))) .withColumn("SparkContextId", split('filename, "/")('pathSize - lit(2))) .withColumn("clusterId", split('filename, "/")('pathSize - lit(5))) - .drop("pathSize") + .drop("pathSize", "executorId") .withColumn("filenameGroup", groupFilename('filename)) .scrubSchema(bronzeSparkEventsScrubber) } @@ -866,7 +877,6 @@ trait BronzeTransforms extends SparkSessionWrapper { rawScrubbed.withColumn("Properties", SchemaTools.structToMap(rawScrubbed, "Properties")) .withColumn("modifiedConfigs", SchemaTools.structToMap(rawScrubbed, "modifiedConfigs")) .withColumn("extraTags", SchemaTools.structToMap(rawScrubbed, "extraTags")) - .withColumnRenamed("executorId", "blackListedExecutorIds") .join(eventLogsDF, Seq("filename")) .withColumn("organization_id", lit(organizationId)) .withColumn("Properties", expr("map_filter(Properties, (k,v) -> k not in ('sparkexecutorextraClassPath'))")) diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/Initializer.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/Initializer.scala index 8be96d630..24cd24fb4 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/Initializer.scala +++ 
b/src/main/scala/com/databricks/labs/overwatch/pipeline/Initializer.scala @@ -568,7 +568,7 @@ object Initializer extends SparkSessionWrapper { val config = new Config() if(organizationID.isEmpty) { config.setOrganizationId(getOrgId) - }else{ + }else{ // is multiWorkspace deployment since orgID is passed logger.log(Level.INFO, "Setting multiworkspace deployment") config.setOrganizationId(organizationID.get) if (apiUrl.nonEmpty) { @@ -576,6 +576,7 @@ object Initializer extends SparkSessionWrapper { } config.setIsMultiworkspaceDeployment(true) } + // set spark overrides in scoped spark session config.registerInitialSparkConf(spark.conf.getAll) config.setInitialWorkerCount(getNumberOfWorkerNodes) config.setInitialShuffleParts(spark.conf.get("spark.sql.shuffle.partitions").toInt) diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/Pipeline.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/Pipeline.scala index f0452f35b..0b60c5ca2 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/Pipeline.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/Pipeline.scala @@ -384,8 +384,9 @@ class Pipeline( dbutils.fs.rm(config.tempWorkingDir) postProcessor.refreshPipReportView(pipelineStateViewTarget) - - spark.catalog.clearCache() + //TODO clearcache will clear global cache multithread performance issue + // spark.catalog.clearCache() + clearThreadFromSessionsMap() } private[overwatch] def restoreSparkConf(): Unit = { diff --git a/src/main/scala/com/databricks/labs/overwatch/utils/Config.scala b/src/main/scala/com/databricks/labs/overwatch/utils/Config.scala index bc0f3676b..8bc002f4c 100644 --- a/src/main/scala/com/databricks/labs/overwatch/utils/Config.scala +++ b/src/main/scala/com/databricks/labs/overwatch/utils/Config.scala @@ -153,15 +153,9 @@ class Config() { "spark.databricks.delta.optimizeWrite.numShuffleBlocks" -> value.getOrElse("spark.databricks.delta.optimizeWrite.numShuffleBlocks", "50000"), "spark.databricks.delta.optimizeWrite.binSize" -> - value.getOrElse("spark.databricks.delta.optimizeWrite.binSize", "512"), - "spark.sql.shuffle.partitions" -> "400", // allow aqe to shrink - "spark.sql.caseSensitive" -> "false", - "spark.sql.autoBroadcastJoinThreshold" -> "10485760", - "spark.sql.adaptive.autoBroadcastJoinThreshold" -> "10485760", - "spark.databricks.delta.schema.autoMerge.enabled" -> "true", - "spark.sql.optimizer.collapseProjectAlwaysInline" -> "true" // temporary workaround ES-318365 + value.getOrElse("spark.databricks.delta.optimizeWrite.binSize", "512") ) - _initialSparkConf = value ++ manualOverrides + _initialSparkConf = value ++ manualOverrides ++ SparkSessionWrapper.globalSparkConfOverrides this } diff --git a/src/main/scala/com/databricks/labs/overwatch/utils/SparkSessionWrapper.scala b/src/main/scala/com/databricks/labs/overwatch/utils/SparkSessionWrapper.scala index 18c5bae76..177e087bf 100644 --- a/src/main/scala/com/databricks/labs/overwatch/utils/SparkSessionWrapper.scala +++ b/src/main/scala/com/databricks/labs/overwatch/utils/SparkSessionWrapper.scala @@ -1,8 +1,30 @@ package com.databricks.labs.overwatch.utils +import com.databricks.labs.overwatch.utils.SparkSessionWrapper.parSessionsOn import org.apache.log4j.{Level, Logger} import org.apache.spark.SparkContext import org.apache.spark.sql.SparkSession +import org.eclipse.jetty.util.ConcurrentHashSet + +import java.util.concurrent.ConcurrentHashMap +import scala.collection.JavaConverters._ + + +object SparkSessionWrapper { + + var parSessionsOn = false + 
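// A minimal model of the table-lock handshake that the new multi-workspace write path
// (writeWithRetry / targetNotLocked in Database.scala above) performs against globalTableLock:
// acquire by adding the fully qualified table name to a shared concurrent set, write, and always
// release in a finally block, while competing writers poll with a cool-down until the name
// disappears or a timeout elapses. The set below uses java.util.concurrent in place of Jetty's
// ConcurrentHashSet, and `doWrite`, the table name, and the timings are illustrative stand-ins.
import java.util.concurrent.ConcurrentHashMap

object TableLockSketch {
  private val globalTableLock = ConcurrentHashMap.newKeySet[String]()

  def withTableLock(tableFullName: String, timeoutMs: Long = 1200000L)(doWrite: => Unit): Unit = {
    val start = System.currentTimeMillis()
    // wait with a small random cool-down while another thread holds the lock
    while (globalTableLock.contains(tableFullName)) {
      if (System.currentTimeMillis() - start > timeoutMs)
        throw new Exception(s"TABLE LOCK TIMEOUT: $tableFullName held for more than $timeoutMs ms")
      Thread.sleep(1000L + scala.util.Random.nextInt(2000))
    }
    // acquisition is check-then-add (not atomic), mirroring the polling approach of targetNotLocked
    globalTableLock.add(tableFullName)
    try doWrite
    finally globalTableLock.remove(tableFullName) // always release, even when the write fails
  }

  def main(args: Array[String]): Unit =
    withTableLock("overwatch_etl.spark_events_bronze") {
      println(s"writing on ${Thread.currentThread().getName}")
    }
}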
private[overwatch] val sessionsMap = new ConcurrentHashMap[Long, SparkSession]().asScala + private[overwatch] val globalTableLock = new ConcurrentHashSet[String] + private[overwatch] val globalSparkConfOverrides = Map( + "spark.sql.shuffle.partitions" -> "400", // allow aqe to shrink + "spark.sql.caseSensitive" -> "false", + "spark.sql.autoBroadcastJoinThreshold" -> "10485760", + "spark.sql.adaptive.autoBroadcastJoinThreshold" -> "10485760", + "spark.databricks.delta.schema.autoMerge.enabled" -> "true", + "spark.sql.optimizer.collapseProjectAlwaysInline" -> "true" // temporary workaround ES-318365 + ) + +} /** * Enables access to the Spark variable. @@ -12,42 +34,54 @@ import org.apache.spark.sql.SparkSession trait SparkSessionWrapper extends Serializable { private val logger: Logger = Logger.getLogger(this.getClass) + private val sessionsMap = SparkSessionWrapper.sessionsMap /** * Init environment. This structure alows for multiple calls to "reinit" the environment. Important in the case of * autoscaling. When the cluster scales up/down envInit and then check for current cluster cores. */ - @transient lazy protected val _envInit: Boolean = envInit() + + private def buildSpark(): SparkSession = { + SparkSession + .builder() + .appName("GlobalSession") + .getOrCreate() + } /** * Access to spark * If testing locally or using DBConnect, the System variable "OVERWATCH" is set to "LOCAL" to make the code base * behavior differently to work in remote execution AND/OR local only mode but local only mode * requires some additional setup. */ - lazy val spark: SparkSession = if (System.getenv("OVERWATCH") != "LOCAL") { - logger.log(Level.INFO, "Using Databricks SparkSession") - SparkSession - .builder().master("local").appName("OverwatchBatch") - .getOrCreate() - } else { - logger.log(Level.INFO, "Using Custom, local SparkSession") - SparkSession.builder() - .master("local[*]") - .config("spark.driver.maxResultSize", "8g") - .appName("OverwatchBatch") -// Useful configs for local spark configs and/or using labs/spark-local-execution -// https://github.com/databricks-academy/spark-local-execution -// .config("spark.driver.bindAddress", "0.0.0.0") -// .enableHiveSupport() -// .config("spark.warehouse.dir", "metastore") - .getOrCreate() + private[overwatch] def spark(globalSession : Boolean = false): SparkSession = { + + if(SparkSessionWrapper.parSessionsOn){ + if(globalSession){ + buildSpark() + } + else{ + val currentThreadID = Thread.currentThread().getId + val sparkSession = sessionsMap.getOrElse(currentThreadID, buildSpark().newSession()) + sessionsMap.put(currentThreadID, sparkSession) + sparkSession + } + }else{ + buildSpark() + } } + @transient lazy val spark:SparkSession = spark(false) + lazy val sc: SparkContext = spark.sparkContext // sc.setLogLevel("WARN") + protected def clearThreadFromSessionsMap(): Unit ={ + sessionsMap.remove(Thread.currentThread().getId) + logger.log(Level.INFO, s"""Removed ${Thread.currentThread().getId} from sessionMap""") + } + def getCoresPerWorker: Int = sc.parallelize("1", 1) .map(_ => java.lang.Runtime.getRuntime.availableProcessors).collect()(0) diff --git a/src/test/scala/com/databricks/labs/overwatch/ApiCallV2Test.scala b/src/test/scala/com/databricks/labs/overwatch/ApiCallV2Test.scala index 15f4473bb..647fece6d 100644 --- a/src/test/scala/com/databricks/labs/overwatch/ApiCallV2Test.scala +++ b/src/test/scala/com/databricks/labs/overwatch/ApiCallV2Test.scala @@ -2,7 +2,7 @@ package com.databricks.labs.overwatch import 
com.databricks.labs.overwatch.ApiCallV2.sc import com.databricks.labs.overwatch.pipeline.PipelineFunctions -import com.databricks.labs.overwatch.utils.{ApiCallFailureV2, ApiEnv} +import com.databricks.labs.overwatch.utils.{ApiCallFailureV2, ApiEnv, SparkSessionWrapper} import org.apache.spark.SparkContext import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions.lit @@ -96,8 +96,7 @@ class ApiCallV2Test extends AnyFunSpec with BeforeAndAfterAll { apiEnv, endPoint, query, - tempSuccessPath = "src/test/scala/tempDir/sqlqueryhistory", - accumulator = acc + tempSuccessPath = "src/test/scala/tempDir/sqlqueryhistory" ).execute().asDF() != null) } @@ -206,6 +205,7 @@ class ApiCallV2Test extends AnyFunSpec with BeforeAndAfterAll { assert(oldAPI.count() == newAPI.count() && oldAPI.except(newAPI).count() == 0 && newAPI.except(oldAPI).count() == 0) } it("test multithreading") { + SparkSessionWrapper.parSessionsOn =true val endPoint = "clusters/list" val clusterIDsDf = ApiCallV2(apiEnv, endPoint).execute().asDF().select("cluster_id") clusterIDsDf.show(false) @@ -226,7 +226,7 @@ class ApiCallV2Test extends AnyFunSpec with BeforeAndAfterAll { ) println(jsonQuery) val future = Future { - val apiObj = ApiCallV2(apiEnv, "clusters/events", jsonQuery, tmpClusterEventsSuccessPath,accumulator).executeMultiThread() + val apiObj = ApiCallV2(apiEnv, "clusters/events", jsonQuery, tmpClusterEventsSuccessPath).executeMultiThread(accumulator) synchronized { apiResponseArray.addAll(apiObj) if (apiResponseArray.size() >= apiEnv.successBatchSize) { @@ -306,8 +306,7 @@ class ApiCallV2Test extends AnyFunSpec with BeforeAndAfterAll { apiEnv, sqlQueryHistoryEndpoint, jsonQuery, - tempSuccessPath = s"${tempWorkingDir}/sqlqueryhistory_silver/${System.currentTimeMillis()}", - accumulator = acc + tempSuccessPath = s"${tempWorkingDir}/sqlqueryhistory_silver/${System.currentTimeMillis()}" ) .execute() .asDF() @@ -371,9 +370,8 @@ class ApiCallV2Test extends AnyFunSpec with BeforeAndAfterAll { apiEnv, sqlQueryHistoryEndpoint, jsonQuery, - tempSuccessPath = tmpSqlQueryHistorySuccessPath, - accumulator = acc - ).executeMultiThread() + tempSuccessPath = tmpSqlQueryHistorySuccessPath + ).executeMultiThread(acc) synchronized { apiObj.forEach(