Commit d5ea041

global Session - initialize configs
GeekSheikh committed Jan 12, 2023
1 parent 01e2192 commit d5ea041
Showing 6 changed files with 35 additions and 18 deletions.
@@ -363,14 +363,20 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper {
   def deploy(parallelism: Int = 4, zones: String = "Bronze,Silver,Gold"): Unit = {
     val processingStartTime = System.currentTimeMillis();
     try {
+      if (parallelism > 1) SparkSessionWrapper.parSessionsOn = true
+      SparkSessionWrapper.sessionsMap.clear()
+
+      // initialize spark overrides for global spark conf
+      PipelineFunctions.setSparkOverrides(spark(globalSession = true), SparkSessionWrapper.globalSparkConfOverrides)
+
       println("ParallelismLevel :" + parallelism)
       val multiWorkspaceConfig = generateMultiWorkspaceConfig(configCsvPath, deploymentId, outputPath)
       snapshotConfig(multiWorkspaceConfig)
       val params = DeploymentValidation
         .performMandatoryValidation(multiWorkspaceConfig, parallelism)
         .map(buildParams)
       println("Workspace to be Deployed :" + params.size)
-      SparkSessionWrapper.parSessionsOn = true
+
       val zoneArray = zones.split(",")
       zoneArray.foreach(zone => {
         val responseCounter = Collections.synchronizedList(new util.ArrayList[Int]())
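Note: with parallelism > 1, deploy() now turns on parallel sessions and clears the session cache before pushing SparkSessionWrapper.globalSparkConfOverrides onto the shared global session, so every override is in place before any workspace thread starts. A minimal sketch of what an override helper like PipelineFunctions.setSparkOverrides plausibly does (its actual implementation is not part of this diff):

    // hypothetical sketch: apply each key/value override to a session's runtime conf
    import org.apache.spark.sql.SparkSession
    def setSparkOverrides(spark: SparkSession, overrides: Map[String, String]): Unit = {
      overrides.foreach { case (k, v) =>
        try spark.conf.set(k, v)
        catch { case _: Throwable => } // static confs cannot be set at runtime; skip them
      }
    }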
16 changes: 13 additions & 3 deletions src/main/scala/com/databricks/labs/overwatch/env/Database.scala
@@ -269,7 +269,16 @@ class Database(config: Config) extends SparkSessionWrapper {
     } catch {
       case e: Throwable =>
         val exceptionMsg = e.getMessage.toLowerCase()
-        if (exceptionMsg != null && (exceptionMsg.contains("concurrent") || exceptionMsg.contains("conflicting")) && retryCount < 5) {
+        logger.log(Level.WARN,
+          s"""
+             |DELTA Table Write Failure:
+             |$exceptionMsg
+             |Attempting Retry
+             |""".stripMargin)
+        val concurrentWriteFailure = exceptionMsg.contains("concurrent") ||
+          exceptionMsg.contains("conflicting") ||
+          exceptionMsg.contains("all nested columns must match")
+        if (exceptionMsg != null && concurrentWriteFailure && retryCount < 5) {
           coolDown(target.tableFullName)
           true
         } else {
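Note: this catch block only classifies the failure and returns a retry flag; the loop that re-attempts the write sits outside the hunk. (Also worth noting: the exceptionMsg != null guard runs after toLowerCase() has already been called on the message, so a null message would throw one line earlier.) A hedged sketch of the surrounding pattern, with illustrative names (df, target, isConcurrentWriteFailure are stand-ins, not the file's actual identifiers):

    // illustrative retry-with-cooldown loop around a Delta write
    var retryCount = 0
    var needsRetry = true
    while (needsRetry && retryCount < 5) {
      needsRetry = try {
        df.write.format("delta").mode("append").saveAsTable(target.tableFullName)
        false // success: stop retrying
      } catch {
        case e: Throwable if isConcurrentWriteFailure(e) =>
          retryCount += 1
          coolDown(target.tableFullName) // jittered sleep before the next attempt
          true
      }
    }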
@@ -309,8 +318,9 @@
    */
   private def coolDown(tableName: String): Unit = {
     val rnd = new scala.util.Random
-    val number:Long = (rnd.nextFloat() * 30 + 30).toLong*1000
-    logger.log(Level.INFO,"Slowing multithreaded writing for " + tableName + "sleeping..." + number+" thread name "+Thread.currentThread().getName)
+    val number:Long = ((rnd.nextFloat() * 30) + 30 + (rnd.nextFloat() * 30)).toLong*1000
+    logger.log(Level.INFO,"DELTA WRITE COOLDOWN: Slowing multithreaded writing for " +
+      tableName + "sleeping..." + number + " thread name " + Thread.currentThread().getName)
     Thread.sleep(number)
   }

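Note: the widened jitter ((rnd.nextFloat() * 30) + 30 + (rnd.nextFloat() * 30)).toLong * 1000 now sleeps each blocked writer between 30 and 90 seconds (previously 30 to 60), spreading concurrent retries further apart.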
@@ -15,7 +15,6 @@ import org.apache.log4j.{Level, Logger}
 import org.apache.spark.eventhubs.{ConnectionStringBuilder, EventHubsConf, EventPosition}
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.LongType
 import org.apache.spark.sql.{AnalysisException, Column, DataFrame}
 import org.apache.spark.util.SerializableConfiguration

@@ -568,14 +568,15 @@ object Initializer extends SparkSessionWrapper {
     val config = new Config()
     if(organizationID.isEmpty) {
       config.setOrganizationId(getOrgId)
-    }else{
+    }else{ // is multiWorkspace deployment since orgID is passed
       logger.log(Level.INFO, "Setting multiworkspace deployment")
       config.setOrganizationId(organizationID.get)
       if (apiUrl.nonEmpty) {
         config.setApiUrl(apiUrl)
       }
       config.setIsMultiworkspaceDeployment(true)
     }
+    // set spark overrides in scoped spark session
     config.registerInitialSparkConf(spark.conf.getAll)
     config.setInitialWorkerCount(getNumberOfWorkerNodes)
     config.setInitialShuffleParts(spark.conf.get("spark.sql.shuffle.partitions").toInt)
10 changes: 2 additions & 8 deletions src/main/scala/com/databricks/labs/overwatch/utils/Config.scala
@@ -153,15 +153,9 @@ class Config() {
       "spark.databricks.delta.optimizeWrite.numShuffleBlocks" ->
         value.getOrElse("spark.databricks.delta.optimizeWrite.numShuffleBlocks", "50000"),
       "spark.databricks.delta.optimizeWrite.binSize" ->
-        value.getOrElse("spark.databricks.delta.optimizeWrite.binSize", "512"),
-      "spark.sql.shuffle.partitions" -> "400", // allow aqe to shrink
-      "spark.sql.caseSensitive" -> "false",
-      "spark.sql.autoBroadcastJoinThreshold" -> "10485760",
-      "spark.sql.adaptive.autoBroadcastJoinThreshold" -> "10485760",
-      "spark.databricks.delta.schema.autoMerge.enabled" -> "true",
-      "spark.sql.optimizer.collapseProjectAlwaysInline" -> "true" // temporary workaround ES-318365
+        value.getOrElse("spark.databricks.delta.optimizeWrite.binSize", "512")
     )
-    _initialSparkConf = value ++ manualOverrides
+    _initialSparkConf = value ++ manualOverrides ++ SparkSessionWrapper.globalSparkConfOverrides
     this
   }

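Note: the six session-tuning defaults deleted above now arrive via value ++ manualOverrides ++ SparkSessionWrapper.globalSparkConfOverrides. In Scala's Map concatenation the rightmost operand wins on duplicate keys, so the global overrides take precedence over whatever was captured from the cluster conf:

    // Scala Map ++ semantics: later maps override earlier ones on key collisions
    val merged = Map("spark.sql.shuffle.partitions" -> "200") ++
      Map("spark.sql.shuffle.partitions" -> "400")
    assert(merged("spark.sql.shuffle.partitions") == "400")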
@@ -13,6 +13,14 @@ object SparkSessionWrapper {
 
   var parSessionsOn = false
   private[overwatch] val sessionsMap = new ConcurrentHashMap[Long, SparkSession]().asScala
+  private[overwatch] val globalSparkConfOverrides = Map(
+    "spark.sql.shuffle.partitions" -> "400", // allow aqe to shrink
+    "spark.sql.caseSensitive" -> "false",
+    "spark.sql.autoBroadcastJoinThreshold" -> "10485760",
+    "spark.sql.adaptive.autoBroadcastJoinThreshold" -> "10485760",
+    "spark.databricks.delta.schema.autoMerge.enabled" -> "true",
+    "spark.sql.optimizer.collapseProjectAlwaysInline" -> "true" // temporary workaround ES-318365
+  )
 
 }
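Note: these are the same six settings removed from Config.scala in this commit; hoisting them into the SparkSessionWrapper companion object gives the global session and every per-thread newSession() a single shared source of conf defaults.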

@@ -34,7 +42,6 @@ trait SparkSessionWrapper extends Serializable {
 
 
   private def buildSpark(): SparkSession = {
-    sessionsMap.hashCode()
     SparkSession
       .builder()
       .appName("GlobalSession")
@@ -53,7 +60,7 @@
       buildSpark()
     }
     else{
-      val currentThreadID = Thread.currentThread().getId()
+      val currentThreadID = Thread.currentThread().getId
      val sparkSession = sessionsMap.getOrElse(currentThreadID, buildSpark().newSession())
      sessionsMap.put(currentThreadID, sparkSession)
      sparkSession
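Note: in parallel mode each worker thread caches and reuses its own isolated SparkSession keyed by thread ID. The getOrElse-then-put pair is a non-atomic check-then-act, but each key is only ever written by its own thread, so it is effectively safe; a single-call alternative on the wrapped ConcurrentHashMap would be:

    // equivalent per-thread caching in one call (sessionsMap is a scala.collection.concurrent.Map)
    val currentThreadID = Thread.currentThread().getId
    val sparkSession = sessionsMap.getOrElseUpdate(currentThreadID, buildSpark().newSession())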
@@ -62,8 +69,8 @@
       buildSpark()
     }
   }
-  @transient
-  lazy val spark:SparkSession = spark(false)
+
+  @transient lazy val spark:SparkSession = spark(false)
 
   lazy val sc: SparkContext = spark.sparkContext
   // sc.setLogLevel("WARN")