From d5ea041587aee13e14d0d9337a31c1d368eee9c3 Mon Sep 17 00:00:00 2001
From: geeksheikh
Date: Wed, 11 Jan 2023 14:27:19 -0500
Subject: [PATCH] global Session - initialize configs

---
 .../overwatch/MultiWorkspaceDeployment.scala      |  8 +++++++-
 .../databricks/labs/overwatch/env/Database.scala  | 16 +++++++++++++---
 .../overwatch/pipeline/BronzeTransforms.scala     |  1 -
 .../labs/overwatch/pipeline/Initializer.scala     |  3 ++-
 .../databricks/labs/overwatch/utils/Config.scala  | 10 ++--------
 .../overwatch/utils/SparkSessionWrapper.scala     | 15 +++++++++++----
 6 files changed, 35 insertions(+), 18 deletions(-)

diff --git a/src/main/scala/com/databricks/labs/overwatch/MultiWorkspaceDeployment.scala b/src/main/scala/com/databricks/labs/overwatch/MultiWorkspaceDeployment.scala
index 06722e592..dc16c9f62 100644
--- a/src/main/scala/com/databricks/labs/overwatch/MultiWorkspaceDeployment.scala
+++ b/src/main/scala/com/databricks/labs/overwatch/MultiWorkspaceDeployment.scala
@@ -363,6 +363,12 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper {
   def deploy(parallelism: Int = 4, zones: String = "Bronze,Silver,Gold"): Unit = {
     val processingStartTime = System.currentTimeMillis();
     try {
+      if (parallelism > 1) SparkSessionWrapper.parSessionsOn = true
+      SparkSessionWrapper.sessionsMap.clear()
+
+      // initialize spark overrides for global spark conf
+      PipelineFunctions.setSparkOverrides(spark(globalSession = true), SparkSessionWrapper.globalSparkConfOverrides)
+
       println("ParallelismLevel :" + parallelism)
       val multiWorkspaceConfig = generateMultiWorkspaceConfig(configCsvPath, deploymentId, outputPath)
       snapshotConfig(multiWorkspaceConfig)
@@ -370,7 +376,7 @@
         .performMandatoryValidation(multiWorkspaceConfig, parallelism)
         .map(buildParams)
       println("Workspace to be Deployed :" + params.size)
-      SparkSessionWrapper.parSessionsOn = true
+
       val zoneArray = zones.split(",")
       zoneArray.foreach(zone => {
         val responseCounter = Collections.synchronizedList(new util.ArrayList[Int]())
diff --git a/src/main/scala/com/databricks/labs/overwatch/env/Database.scala b/src/main/scala/com/databricks/labs/overwatch/env/Database.scala
index f7b0c260d..6a1ecec58 100644
--- a/src/main/scala/com/databricks/labs/overwatch/env/Database.scala
+++ b/src/main/scala/com/databricks/labs/overwatch/env/Database.scala
@@ -269,7 +269,16 @@
     } catch {
       case e: Throwable =>
         val exceptionMsg = e.getMessage.toLowerCase()
-        if (exceptionMsg != null && (exceptionMsg.contains("concurrent") || exceptionMsg.contains("conflicting")) && retryCount < 5) {
+        logger.log(Level.WARN,
+          s"""
+             |DELTA Table Write Failure:
+             |$exceptionMsg
+             |Attempting Retry
+             |""".stripMargin)
+        val concurrentWriteFailure = exceptionMsg.contains("concurrent") ||
+          exceptionMsg.contains("conflicting") ||
+          exceptionMsg.contains("all nested columns must match")
+        if (exceptionMsg != null && concurrentWriteFailure && retryCount < 5) {
           coolDown(target.tableFullName)
           true
         } else {
@@ -309,8 +318,9 @@
    */
   private def coolDown(tableName: String): Unit = {
     val rnd = new scala.util.Random
-    val number:Long = (rnd.nextFloat() * 30 + 30).toLong*1000
-    logger.log(Level.INFO,"Slowing multithreaded writing for " + tableName + "sleeping..." + number+" thread name "+Thread.currentThread().getName)
+    val number:Long = ((rnd.nextFloat() * 30) + 30 + (rnd.nextFloat() * 30)).toLong*1000
+    logger.log(Level.INFO,"DELTA WRITE COOLDOWN: Slowing multithreaded writing for " +
+      tableName + "sleeping..." + number + " thread name " + Thread.currentThread().getName)
     Thread.sleep(number)
   }
 
diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/BronzeTransforms.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/BronzeTransforms.scala
index b7b618f44..02121b9b3 100644
--- a/src/main/scala/com/databricks/labs/overwatch/pipeline/BronzeTransforms.scala
+++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/BronzeTransforms.scala
@@ -15,7 +15,6 @@ import org.apache.log4j.{Level, Logger}
 import org.apache.spark.eventhubs.{ConnectionStringBuilder, EventHubsConf, EventPosition}
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.LongType
 import org.apache.spark.sql.{AnalysisException, Column, DataFrame}
 import org.apache.spark.util.SerializableConfiguration
 
diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/Initializer.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/Initializer.scala
index 8be96d630..24cd24fb4 100644
--- a/src/main/scala/com/databricks/labs/overwatch/pipeline/Initializer.scala
+++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/Initializer.scala
@@ -568,7 +568,7 @@
     val config = new Config()
     if(organizationID.isEmpty) {
       config.setOrganizationId(getOrgId)
-    }else{
+    }else{ // is multiWorkspace deployment since orgID is passed
       logger.log(Level.INFO, "Setting multiworkspace deployment")
       config.setOrganizationId(organizationID.get)
       if (apiUrl.nonEmpty) {
@@ -576,6 +576,7 @@
       }
       config.setIsMultiworkspaceDeployment(true)
     }
+    // set spark overrides in scoped spark session
     config.registerInitialSparkConf(spark.conf.getAll)
     config.setInitialWorkerCount(getNumberOfWorkerNodes)
     config.setInitialShuffleParts(spark.conf.get("spark.sql.shuffle.partitions").toInt)
diff --git a/src/main/scala/com/databricks/labs/overwatch/utils/Config.scala b/src/main/scala/com/databricks/labs/overwatch/utils/Config.scala
index bc0f3676b..8bc002f4c 100644
--- a/src/main/scala/com/databricks/labs/overwatch/utils/Config.scala
+++ b/src/main/scala/com/databricks/labs/overwatch/utils/Config.scala
@@ -153,15 +153,9 @@
       "spark.databricks.delta.optimizeWrite.numShuffleBlocks" ->
         value.getOrElse("spark.databricks.delta.optimizeWrite.numShuffleBlocks", "50000"),
       "spark.databricks.delta.optimizeWrite.binSize" ->
-        value.getOrElse("spark.databricks.delta.optimizeWrite.binSize", "512"),
-      "spark.sql.shuffle.partitions" -> "400", // allow aqe to shrink
-      "spark.sql.caseSensitive" -> "false",
-      "spark.sql.autoBroadcastJoinThreshold" -> "10485760",
-      "spark.sql.adaptive.autoBroadcastJoinThreshold" -> "10485760",
-      "spark.databricks.delta.schema.autoMerge.enabled" -> "true",
-      "spark.sql.optimizer.collapseProjectAlwaysInline" -> "true" // temporary workaround ES-318365
+        value.getOrElse("spark.databricks.delta.optimizeWrite.binSize", "512")
     )
-    _initialSparkConf = value ++ manualOverrides
+    _initialSparkConf = value ++ manualOverrides ++ SparkSessionWrapper.globalSparkConfOverrides
     this
   }
 
diff --git a/src/main/scala/com/databricks/labs/overwatch/utils/SparkSessionWrapper.scala b/src/main/scala/com/databricks/labs/overwatch/utils/SparkSessionWrapper.scala
index 3e9b04b6d..ff5add8cb 100644
--- a/src/main/scala/com/databricks/labs/overwatch/utils/SparkSessionWrapper.scala
+++ b/src/main/scala/com/databricks/labs/overwatch/utils/SparkSessionWrapper.scala
@@ -13,6 +13,14 @@ object SparkSessionWrapper {
 
   var parSessionsOn = false
   private[overwatch] val sessionsMap = new ConcurrentHashMap[Long, SparkSession]().asScala
+  private[overwatch] val globalSparkConfOverrides = Map(
+    "spark.sql.shuffle.partitions" -> "400", // allow aqe to shrink
+    "spark.sql.caseSensitive" -> "false",
+    "spark.sql.autoBroadcastJoinThreshold" -> "10485760",
+    "spark.sql.adaptive.autoBroadcastJoinThreshold" -> "10485760",
+    "spark.databricks.delta.schema.autoMerge.enabled" -> "true",
+    "spark.sql.optimizer.collapseProjectAlwaysInline" -> "true" // temporary workaround ES-318365
+  )
 
 }
 
@@ -34,7 +42,6 @@ trait SparkSessionWrapper extends Serializable {
 
 
   private def buildSpark(): SparkSession = {
-    sessionsMap.hashCode()
     SparkSession
       .builder()
       .appName("GlobalSession")
@@ -53,7 +60,7 @@ trait SparkSessionWrapper extends Serializable {
         buildSpark()
       }
       else{
-        val currentThreadID = Thread.currentThread().getId()
+        val currentThreadID = Thread.currentThread().getId
         val sparkSession = sessionsMap.getOrElse(currentThreadID, buildSpark().newSession())
         sessionsMap.put(currentThreadID, sparkSession)
         sparkSession
@@ -62,8 +69,8 @@ trait SparkSessionWrapper extends Serializable {
       buildSpark()
     }
   }
-  @transient
-  lazy val spark:SparkSession = spark(false)
+
+  @transient lazy val spark:SparkSession = spark(false)
 
   lazy val sc: SparkContext = spark.sparkContext
 
 // sc.setLogLevel("WARN")
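
A minimal standalone sketch of the session handling introduced above: with parallelism > 1, each deployment thread works against its own newSession() child cached by thread ID, while the conf overrides in SparkSessionWrapper.globalSparkConfOverrides are pushed once onto the shared parent session (the patch does that via PipelineFunctions.setSparkOverrides). The object and method names below are illustrative only, not code from this patch:

    import java.util.concurrent.ConcurrentHashMap

    import scala.collection.JavaConverters._

    import org.apache.spark.sql.SparkSession

    // Illustrative sketch only -- mirrors SparkSessionWrapper's pattern,
    // not the actual class.
    object PerThreadSessionsSketch {

      // one cached child session per worker thread (mirrors sessionsMap)
      private val sessionsMap = new ConcurrentHashMap[Long, SparkSession]().asScala

      private def buildSpark(): SparkSession =
        SparkSession.builder().appName("GlobalSession").getOrCreate()

      // globalSession = true returns the shared parent session; otherwise the
      // calling thread gets (and caches) an isolated newSession(), so conf set
      // by one deployment thread cannot leak into another.
      def spark(globalSession: Boolean = false): SparkSession =
        if (globalSession) buildSpark()
        else {
          val threadId = Thread.currentThread().getId
          val session = sessionsMap.getOrElse(threadId, buildSpark().newSession())
          sessionsMap.put(threadId, session)
          session
        }

      // mirrors pushing globalSparkConfOverrides onto the global session
      // (the patch routes this through PipelineFunctions.setSparkOverrides)
      def setOverrides(spark: SparkSession, overrides: Map[String, String]): Unit =
        overrides.foreach { case (k, v) =>
          try spark.conf.set(k, v)
          catch {
            // static confs cannot be modified on a live session; skip them
            case _: org.apache.spark.sql.AnalysisException => ()
          }
        }
    }

newSession() shares the underlying SparkContext and cached data with the parent but isolates SQL conf, temp views, and registered UDFs, which is why per-thread children prevent conf leakage between concurrently deploying workspaces.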
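Similarly, a hedged sketch of the widened retry guard in Database.scala: the failure is now logged at WARN before classification, "all nested columns must match" joins "concurrent"/"conflicting" as a retryable failure, and the cooldown jitter widens from 30-60 to roughly 30-90 seconds. writeWithRetry is an illustrative wrapper, not the patch's actual write path, and it null-guards the message before lowercasing (the patched code still calls toLowerCase first):

    import org.apache.log4j.{Level, Logger}

    // Illustrative sketch only -- mirrors the retry/cooldown pattern in Database.scala.
    object DeltaWriteRetrySketch {

      private val logger: Logger = Logger.getLogger(getClass)

      // randomized 30-90s cooldown so competing writer threads de-synchronize
      private def coolDown(tableName: String): Unit = {
        val rnd = new scala.util.Random
        val sleepMs = ((rnd.nextFloat() * 30) + 30 + (rnd.nextFloat() * 30)).toLong * 1000
        logger.log(Level.INFO, s"DELTA WRITE COOLDOWN: $tableName sleeping $sleepMs ms " +
          s"on thread ${Thread.currentThread().getName}")
        Thread.sleep(sleepMs)
      }

      // retry transient Delta write conflicts up to 5 times; rethrow anything else
      def writeWithRetry(tableName: String)(write: () => Unit): Unit = {
        var retryCount = 0
        var succeeded = false
        while (!succeeded) {
          try {
            write()
            succeeded = true
          } catch {
            case e: Throwable =>
              val msg = Option(e.getMessage).getOrElse("").toLowerCase
              logger.log(Level.WARN, s"DELTA Table Write Failure: $msg -- attempting retry")
              val retryable = msg.contains("concurrent") ||
                msg.contains("conflicting") ||
                msg.contains("all nested columns must match")
              if (retryable && retryCount < 5) {
                retryCount += 1
                coolDown(tableName)
              } else throw e
          }
        }
      }
    }

Usage would look like writeWithRetry("db.table")(() => df.write.format("delta").mode("append").saveAsTable("db.table")), with the table name reused for the cooldown log line.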