diff --git a/build.sbt b/build.sbt index 56fa8ce08..be73e8bb5 100644 --- a/build.sbt +++ b/build.sbt @@ -2,7 +2,7 @@ name := "overwatch" organization := "com.databricks.labs" -version := "0.7.1.0" +version := "0.7.1.1" scalaVersion := "2.12.12" scalacOptions ++= Seq("-Xmax-classfile-name", "78") diff --git a/src/main/scala/com/databricks/labs/overwatch/ApiCallV2.scala b/src/main/scala/com/databricks/labs/overwatch/ApiCallV2.scala index 1aa1e3aff..3086046df 100644 --- a/src/main/scala/com/databricks/labs/overwatch/ApiCallV2.scala +++ b/src/main/scala/com/databricks/labs/overwatch/ApiCallV2.scala @@ -1,5 +1,6 @@ package com.databricks.labs.overwatch +import com.databricks.labs.overwatch.ApiCallV2.sc import com.databricks.labs.overwatch.pipeline.PipelineFunctions import com.databricks.labs.overwatch.utils._ import com.fasterxml.jackson.databind.ObjectMapper @@ -12,7 +13,12 @@ import org.json.JSONObject import scalaj.http.{HttpOptions, HttpResponse} import java.util +import java.util.Collections +import java.util.concurrent.Executors import scala.annotation.tailrec +import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future} +import scala.math.Ordered.orderingToOrdered +import scala.util.{Failure, Success} /** * Companion object for APICallV2. @@ -56,15 +62,16 @@ object ApiCallV2 extends SparkSessionWrapper { * @param accumulator To make track of number of api request. * @return */ - def apply(apiEnv: ApiEnv, apiName: String, queryMap: Map[String, String], tempSuccessPath: String, accumulator: LongAccumulator): ApiCallV2 = { + def apply(apiEnv: ApiEnv, apiName: String, queryMap: Map[String, String], tempSuccessPath: String + ): ApiCallV2 = { new ApiCallV2(apiEnv) .setEndPoint(apiName) .buildMeta(apiName) .setQueryMap(queryMap) .setSuccessTempPath(tempSuccessPath) - .setAccumulator(accumulator) } + /** * Companion Object which takes three parameter and initialise the ApiCallV2. * @@ -97,6 +104,25 @@ object ApiCallV2 extends SparkSessionWrapper { .setApiV(apiVersion) } + /** + * + * @param apiEnv + * @param apiName + * @param queryMap + * @param tempSuccessPath + * @param apiVersion + * @return + */ + def apply(apiEnv: ApiEnv, apiName: String, queryMap: Map[String, String], + tempSuccessPath: String, apiVersion: Double): ApiCallV2 = { + new ApiCallV2(apiEnv) + .setEndPoint(apiName) + .buildMeta(apiName) + .setQueryMap(queryMap) + .setSuccessTempPath(tempSuccessPath) + .setApiV(apiVersion) + } + } /** @@ -127,7 +153,6 @@ class ApiCallV2(apiEnv: ApiEnv) extends SparkSessionWrapper { private var _queryMap: Map[String, String] = Map[String, String]() private var _accumulator: LongAccumulator = sc.longAccumulator("ApiAccumulator") //Multithreaded call accumulator will make track of the request. - protected def accumulator: LongAccumulator = _accumulator protected def apiSuccessCount: Int = _apiSuccessCount protected def apiFailureCount: Int = _apiFailureCount @@ -232,6 +257,8 @@ class ApiCallV2(apiEnv: ApiEnv) extends SparkSessionWrapper { this } + def asStrings: Array[String] = _apiResponseArray.toArray(new Array[String](_apiResponseArray.size)) + /** * Setting up the api name and api metadata for that api. * @@ -549,7 +576,7 @@ class ApiCallV2(apiEnv: ApiEnv) extends SparkSessionWrapper { * Performs api calls in parallel. 
* @return */ - def executeMultiThread(): util.ArrayList[String] = { + def executeMultiThread(accumulator: LongAccumulator): util.ArrayList[String] = { @tailrec def executeThreadedHelper(): util.ArrayList[String] = { val response = getResponse responseCodeHandler(response) @@ -607,7 +634,6 @@ class ApiCallV2(apiEnv: ApiEnv) extends SparkSessionWrapper { responseCodeHandler(response) _apiResponseArray.add(response.body) if (apiMeta.storeInTempLocation && successTempPath.nonEmpty) { - accumulator.add(1) if (apiEnv.successBatchSize <= _apiResponseArray.size()) { //Checking if its right time to write the batches into persistent storage val responseFlag = PipelineFunctions.writeMicroBatchToTempLocation(successTempPath.get, _apiResponseArray.toString) if (responseFlag) { //Clearing the resultArray in-case of successful write @@ -646,6 +672,125 @@ class ApiCallV2(apiEnv: ApiEnv) extends SparkSessionWrapper { } } + /** + * Function to make parallel API calls. Currently this functions supports only SqlQueryHistory and ClusterEvents + * @param endpoint + * @param jsonInput + * @param config + * @return + */ + def makeParallelApiCalls(endpoint: String, jsonInput: Map[String, String], config: Config): String = { + val tempEndpointLocation = endpoint.replaceAll("/","") + val acc = sc.longAccumulator(tempEndpointLocation) + + val tmpSuccessPath = if(jsonInput.contains("tmp_success_path")) jsonInput.get("tmp_success_path").get + else s"${config.tempWorkingDir}/${tempEndpointLocation}/${System.currentTimeMillis()}" + + val tmpErrorPath = if(jsonInput.contains("tmp_error_path")) jsonInput.get("tmp_error_path").get + else s"${config.tempWorkingDir}/errors/${tempEndpointLocation}/${System.currentTimeMillis()}" + + var apiResponseArray = Collections.synchronizedList(new util.ArrayList[String]()) + var apiErrorArray = Collections.synchronizedList(new util.ArrayList[String]()) + val apiResponseCounter = Collections.synchronizedList(new util.ArrayList[Int]()) + implicit val ec: ExecutionContextExecutor = ExecutionContext.fromExecutor( + Executors.newFixedThreadPool(config.apiEnv.threadPoolSize)) + val apiMetaFactoryObj = new ApiMetaFactory().getApiClass(endpoint) + val dataFrame_column = apiMetaFactoryObj.dataframeColumn + val parallelApiCallsParams = apiMetaFactoryObj.getParallelAPIParams(jsonInput) + var startValue = parallelApiCallsParams.get("start_value").get.toLong + val endValue = parallelApiCallsParams.get("end_value").get.toLong + val incrementCounter = parallelApiCallsParams.get("increment_counter").get.toLong + val finalResponseCount = parallelApiCallsParams.get("final_response_count").get.toLong + + while (startValue < endValue){ + val jsonQuery = apiMetaFactoryObj.getAPIJsonQuery(startValue, endValue, jsonInput) + + //call future + val future = Future { + val apiObj = ApiCallV2( + config.apiEnv, + endpoint, + jsonQuery, + tempSuccessPath = tmpSuccessPath + ).executeMultiThread(acc) + + synchronized { + apiObj.forEach( + obj=>if(obj.contains(dataFrame_column)){ + apiResponseArray.add(obj) + } + ) + if (apiResponseArray.size() >= config.apiEnv.successBatchSize) { + PipelineFunctions.writeMicroBatchToTempLocation(tmpSuccessPath, apiResponseArray.toString) + apiResponseArray = Collections.synchronizedList(new util.ArrayList[String]()) + } + } + } + future.onComplete { + case Success(_) => + apiResponseCounter.add(1) + + case Failure(e) => + if (e.isInstanceOf[ApiCallFailureV2]) { + synchronized { + apiErrorArray.add(e.getMessage) + if (apiErrorArray.size() >= config.apiEnv.errorBatchSize) { + 
PipelineFunctions.writeMicroBatchToTempLocation(tmpErrorPath, apiErrorArray.toString) + apiErrorArray = Collections.synchronizedList(new util.ArrayList[String]()) + } + } + logger.log(Level.ERROR, "Future failure message: " + e.getMessage, e) + } + apiResponseCounter.add(1) + } + startValue = startValue + incrementCounter + } + + val timeoutThreshold = config.apiEnv.apiWaitingTime // 5 minutes + var currentSleepTime = 0 + var accumulatorCountWhileSleeping = acc.value + while (apiResponseCounter.size() < finalResponseCount && currentSleepTime < timeoutThreshold) { + //As we are using Futures and running 4 threads in parallel, We are checking if all the treads has completed + // the execution or not. If we have not received the response from all the threads then we are waiting for 5 + // seconds and again revalidating the count. + if (currentSleepTime > 120000) //printing the waiting message only if the waiting time is more than 2 minutes. + { + println( + s"""Waiting for other queued API Calls to complete; cumulative wait time ${currentSleepTime / 1000} + |seconds; Api response yet to receive ${finalResponseCount - apiResponseCounter.size()}""".stripMargin) + } + Thread.sleep(5000) + currentSleepTime += 5000 + if (accumulatorCountWhileSleeping < acc.value) { //new API response received while waiting. + currentSleepTime = 0 //resetting the sleep time. + accumulatorCountWhileSleeping = acc.value + } + } + if (apiResponseCounter.size() != finalResponseCount) { // Checking whether all the api responses has been received or not. + logger.log(Level.ERROR, + s"""Unable to receive all the ${endpoint} api responses; Api response + |received ${apiResponseCounter.size()};Api response not + |received ${finalResponseCount - apiResponseCounter.size()}""".stripMargin) + throw new Exception( + s"""Unable to receive all the ${endpoint} api responses; Api response received + |${apiResponseCounter.size()}; + |Api response not received ${finalResponseCount - apiResponseCounter.size()}""".stripMargin) + } + if (apiResponseArray.size() > 0) { //In case of response array didn't hit the batch-size as a + // final step we will write it to the persistent storage. + PipelineFunctions.writeMicroBatchToTempLocation(tmpSuccessPath, apiResponseArray.toString) + apiResponseArray = Collections.synchronizedList(new util.ArrayList[String]()) + + + } + if (apiErrorArray.size() > 0) { //In case of error array didn't hit the batch-size + // as a final step we will write it to the persistent storage. 
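As an aside for reviewers, here is a minimal sketch of the jsonInput contract that makeParallelApiCalls expects; the time values are invented for illustration, and the keys mirror SqlQueryHistoryApi.getParallelAPIParams and the caller added in Workspace.getSqlQueryHistoryParallelDF:

```scala
import com.databricks.labs.overwatch.ApiCallV2
import com.databricks.labs.overwatch.utils.Config

// config is assumed to be an already-initialized Overwatch Config instance.
def pullSqlQueryHistory(config: Config): String = {
  val untilMs = System.currentTimeMillis()
  val fromMs  = untilMs - (1000L * 60 * 60 * 24 * 2) // two-day lookback, as in the Workspace caller
  val jsonInput = Map(
    "start_value"          -> fromMs.toString,
    "end_value"            -> untilMs.toString,
    "increment_counter"    -> "3600000", // one-hour window per API call
    "final_response_count" -> math.ceil((untilMs - fromMs).toDouble / 3600000).toLong.toString
  )
  // returns the temp success path that holds the raw JSON responses
  new ApiCallV2(config.apiEnv).makeParallelApiCalls("sql/history/queries", jsonInput, config)
}
```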
+ PipelineFunctions.writeMicroBatchToTempLocation(tmpErrorPath, apiErrorArray.toString) + apiErrorArray = Collections.synchronizedList(new util.ArrayList[String]()) + } + tmpSuccessPath + } + } diff --git a/src/main/scala/com/databricks/labs/overwatch/ApiMeta.scala b/src/main/scala/com/databricks/labs/overwatch/ApiMeta.scala index b8787acc1..9f83d751c 100644 --- a/src/main/scala/com/databricks/labs/overwatch/ApiMeta.scala +++ b/src/main/scala/com/databricks/labs/overwatch/ApiMeta.scala @@ -1,7 +1,7 @@ package com.databricks.labs.overwatch import com.databricks.dbutils_v1.DBUtilsHolder.dbutils -import com.databricks.labs.overwatch.utils.ApiEnv +import com.databricks.labs.overwatch.utils.{ApiEnv, TimeTypes} import com.fasterxml.jackson.databind.JsonNode import org.apache.log4j.{Level, Logger} import scalaj.http.{Http, HttpRequest} @@ -107,7 +107,17 @@ trait ApiMeta { logger.log(Level.INFO, s"""Proxy has been set to IP: ${apiEnv.proxyHost.get} PORT:${apiEnv.proxyPort.get}""") } if (apiEnv.proxyUserName.nonEmpty && apiEnv.proxyPasswordScope.nonEmpty && apiEnv.proxyPasswordKey.nonEmpty) { - val password = dbutils.secrets.get(scope = apiEnv.proxyPasswordScope.get, apiEnv.proxyPasswordKey.get) + val password = try { + dbutils.secrets.get(scope = apiEnv.proxyPasswordScope.get, apiEnv.proxyPasswordKey.get) + } catch { + case e: IllegalArgumentException if e.getMessage.contains("Secret does not exist") => + val failMsg = + s"""Error getting proxy secret details using: + |ProxyPasswordScope: ${apiEnv.proxyPasswordScope} + |ProxyPasswordKey: ${apiEnv.proxyPasswordKey} + |""".stripMargin + throw new Exception(failMsg, e) + } request = request.proxyAuth(apiEnv.proxyUserName.get, password) logger.log(Level.INFO, s"""Proxy UserName set to IP: ${apiEnv.proxyUserName.get} scope:${apiEnv.proxyPasswordScope.get} key:${apiEnv.proxyPasswordKey.get}""") } @@ -135,6 +145,16 @@ trait ApiMeta { |""".stripMargin } + private[overwatch] def getAPIJsonQuery(startValue: Long, endValue: Long, jsonInput: Map[String, String]): Map[String, String] = { + logger.log(Level.INFO, s"""Needs to be override for specific API for manipulating the input JSON Query""") + Map[String, String]() + } + + private[overwatch] def getParallelAPIParams(jsonInput: Map[String, String]): Map[String, String] = { + logger.log(Level.INFO, s"""Needs to be override for specific API for intializing Parallel API call function""") + Map[String, String]() + } + } /** @@ -157,6 +177,11 @@ class ApiMetaFactory { case "clusters/resize" => new ClusterResizeApi case "jobs/runs/get" => new JobRunGetApi case "dbfs/search-mounts" => new DbfsSearchMountsApi + case "jobs/runs/list" => new JobRunsApi + case "libraries/all-cluster-statuses" => new ClusterLibraryApi + case "policies/clusters/list" => new ClusterPolicesApi + case "token/list" => new TokensApi + case "global-init-scripts" => new GlobalInitsScriptsApi case _ => new UnregisteredApi } logger.log(Level.INFO, meta.toString) @@ -200,6 +225,32 @@ class SqlQueryHistoryApi extends ApiMeta { // logger.info(s"DEBUG - NEXT_PAGE_TOKEN = ${_jsonValue}") requestMap.filterNot { case (k, _) => k.toLowerCase.startsWith("filter_by")} ++ Map(s"page_token" -> s"${_jsonValue}") } + + private[overwatch] override def getAPIJsonQuery(startValue: Long, endValue: Long, jsonInput: Map[String, String]): Map[String, String] = { + val (startTime, endTime) = if ((endValue - startValue)/(1000*60*60) > 1) { + (startValue, + startValue+(1000*60*60)) + } + else{ + (startValue, + endValue) + } + Map( + "max_results" -> "50", + 
"include_metrics" -> "true", + "filter_by.query_start_time_range.start_time_ms" -> s"$startTime", + "filter_by.query_start_time_range.end_time_ms" -> s"$endTime" + ) + } + + private[overwatch] override def getParallelAPIParams(jsonInput: Map[String, String]): Map[String, String] = { + Map( + "start_value" -> s"""${jsonInput.get("start_value").get.toLong}""", + "end_value" -> s"""${jsonInput.get("end_value").get.toLong}""", + "increment_counter" -> s"""${jsonInput.get("increment_counter").get.toLong}""", + "final_response_count" -> s"""${jsonInput.get("final_response_count").get.toLong}""" + ) + } } class WorkspaceListApi extends ApiMeta { @@ -259,4 +310,70 @@ class ClusterEventsApi extends ApiMeta { setDataframeColumn("events") setApiCallType("POST") setStoreInTempLocation(true) + + private[overwatch] override def getAPIJsonQuery(startValue: Long, endValue: Long,jsonInput: Map[String, String]): Map[String, String] = { + val clusterIDs = jsonInput.get("cluster_ids").get.split(",").map(_.trim).toArray + val startTime = jsonInput.get("start_time").get.toLong + val endTime = jsonInput.get("end_time").get.toLong + + Map("cluster_id" -> s"""${clusterIDs(startValue.toInt)}""", + "start_time" -> s"""${startTime}""", + "end_time" -> s"""${endTime}""", + "limit" -> "500" + ) + } + + private[overwatch] override def getParallelAPIParams(jsonInput: Map[String, String]): Map[String, String] = { + Map( + "start_value" -> s"""${jsonInput.get("start_value").get.toLong}""", + "end_value" -> s"""${jsonInput.get("end_value").get.toLong}""", + "increment_counter" -> s"""${jsonInput.get("increment_counter").get.toLong}""", + "final_response_count" -> s"""${jsonInput.get("final_response_count").get.toLong}""" + ) + } + +} + +class JobRunsApi extends ApiMeta { + setDataframeColumn("runs") + setApiCallType("GET") + setPaginationKey("has_more") + setIsDerivePaginationLogic(true) + setStoreInTempLocation(true) + + private[overwatch] override def hasNextPage(jsonObject: JsonNode): Boolean = { + jsonObject.get(paginationKey).asBoolean() + } + + private[overwatch] override def getPaginationLogic(jsonObject: JsonNode, requestMap: Map[String, String]): Map[String, String] = { + val limit = Integer.parseInt(requestMap.get("limit").get) + var offset = Integer.parseInt(requestMap.get("offset").get) + val expand_tasks = requestMap.get("expand_tasks").get + offset = offset + limit + Map( + "limit" -> s"${limit}", + "expand_tasks" -> s"${expand_tasks}", + "offset" -> s"${offset}" + ) + } +} + +class ClusterLibraryApi extends ApiMeta { + setDataframeColumn("statuses") + setApiCallType("GET") +} + +class ClusterPolicesApi extends ApiMeta { + setDataframeColumn("policies") + setApiCallType("GET") +} + +class TokensApi extends ApiMeta { + setDataframeColumn("token_infos") + setApiCallType("GET") +} + +class GlobalInitsScriptsApi extends ApiMeta { + setDataframeColumn("scripts") + setApiCallType("GET") } diff --git a/src/main/scala/com/databricks/labs/overwatch/BatchRunner.scala b/src/main/scala/com/databricks/labs/overwatch/BatchRunner.scala index 7f7892ba4..3bf71aeb5 100644 --- a/src/main/scala/com/databricks/labs/overwatch/BatchRunner.scala +++ b/src/main/scala/com/databricks/labs/overwatch/BatchRunner.scala @@ -7,11 +7,8 @@ import org.apache.log4j.{Level, Logger} object BatchRunner extends SparkSessionWrapper { private val logger: Logger = Logger.getLogger(this.getClass) - - private def setGlobalDeltaOverrides(): Unit = { - spark.conf.set("spark.databricks.delta.optimize.maxFileSize", 1024 * 1024 * 128) - 
spark.conf.set("spark.sql.files.maxPartitionBytes", 1024 * 1024 * 128) - } + SparkSessionWrapper.sessionsMap.clear() + SparkSessionWrapper.globalTableLock.clear() /** * if args length == 2, 0 = pipeline of bronze, silver, or gold and 1 = overwatch args @@ -22,7 +19,6 @@ object BatchRunner extends SparkSessionWrapper { */ def main(args: Array[String]): Unit = { envInit() - setGlobalDeltaOverrides() // JARS for databricks remote // sc.addJar("C:\\Dev\\git\\Databricks--Overwatch\\target\\scala-2.11\\overwatch_2.11-0.2.jar") @@ -68,7 +64,6 @@ object BatchRunner extends SparkSessionWrapper { Gold(workspace).run() } - } diff --git a/src/main/scala/com/databricks/labs/overwatch/MultiWorkspaceDeployment.scala b/src/main/scala/com/databricks/labs/overwatch/MultiWorkspaceDeployment.scala index 29993561f..dab54816a 100644 --- a/src/main/scala/com/databricks/labs/overwatch/MultiWorkspaceDeployment.scala +++ b/src/main/scala/com/databricks/labs/overwatch/MultiWorkspaceDeployment.scala @@ -1,11 +1,13 @@ package com.databricks.labs.overwatch +import com.databricks.labs.overwatch.env.Workspace import com.databricks.labs.overwatch.pipeline.TransformFunctions._ import com.databricks.labs.overwatch.pipeline._ import com.databricks.labs.overwatch.utils._ import com.databricks.labs.overwatch.validation.DeploymentValidation import org.apache.log4j.{Level, Logger} -import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ import java.text.SimpleDateFormat import java.time.LocalDateTime @@ -13,25 +15,33 @@ import java.util import java.util.concurrent.Executors import java.util.{Collections, Date} import scala.collection.mutable.ArrayBuffer -import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future} +import scala.concurrent.duration.Duration +import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutor, Future} +import scala.language.postfixOps /** * * MultiWorkspaceDeployment class is the main class which runs the deployment for multiple workspaces. * - * @params configCsvPath: path of the csv file which will contain the different configs for the workspaces. + * @params configLocation: can be either a delta table or path of the delta table or fully qualified path of the csv file which contains the configuration, * @params tempOutputPath: location which will be used as a temp storage.It will be automatically cleaned after each run. * @params apiEnvConfig: configs related to api call. */ object MultiWorkspaceDeployment extends SparkSessionWrapper { - def apply(configCsvPath: String): MultiWorkspaceDeployment = { - apply(configCsvPath, "/mnt/tmp/overwatch") + def apply(configLocation: String): MultiWorkspaceDeployment = { + apply(configLocation, "/mnt/tmp/overwatch") } - def apply(configCsvPath: String, tempOutputPath: String) = { + /** + * + * @param configLocation can be either a delta table or path of the delta table or fully qualified path of the csv file which contains the configuration, + * @param tempOutputPath location which will be used as a temp storage.It will be automatically cleaned after each run. 
+ * @return + */ + def apply(configLocation: String, tempOutputPath: String) = { new MultiWorkspaceDeployment() - .setConfigCsvPath(configCsvPath) + .setConfigLocation(configLocation) .setOutputPath(tempOutputPath) .setPipelineSnapTime() } @@ -44,7 +54,7 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { private val logger: Logger = Logger.getLogger(this.getClass) - private var _configCsvPath: String = _ + private var _configLocation: String = _ private var _outputPath: String = _ @@ -58,13 +68,13 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { private def outputPath: String = _outputPath - protected def configCsvPath: String = _configCsvPath + protected def configLocation: String = _configLocation private var _pipelineSnapTime: Long = _ - private def setConfigCsvPath(value: String): this.type = { - _configCsvPath = value + private def setConfigLocation(value: String): this.type = { + _configLocation = value this } @@ -100,14 +110,15 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { Some(s"${config.etl_storage_prefix}/${config.consumer_database_name}.db") ) val tokenSecret = TokenSecret(config.secret_scope, config.secret_key_dbpat) - val ehConnString = s"{{secrets/${config.secret_scope}/${config.eh_scope_key}}}" - val ehStatePath = s"${config.etl_storage_prefix}/${config.workspace_id}/ehState" val badRecordsPath = s"${config.etl_storage_prefix}/${config.workspace_id}/sparkEventsBadrecords" + // TODO -- ISSUE 781 - quick fix to support non-json audit logs but needs to be added back to external parameters + val auditLogFormat = spark.conf.getOption("overwatch.aws.auditlogformat").getOrElse("json") val auditLogConfig = if (s"${config.cloud}" == "AWS") { - val awsAuditSourcePath = s"${config.auditlogprefix_source_aws}" - AuditLogConfig(rawAuditPath = Some(awsAuditSourcePath)) + AuditLogConfig(rawAuditPath = config.auditlogprefix_source_aws, auditLogFormat = auditLogFormat) } else { - val azureLogConfig = AzureAuditLogEventhubConfig(connectionString = ehConnString, eventHubName = config.eh_name, auditRawEventsPrefix = ehStatePath) + val ehConnString = s"{{secrets/${config.secret_scope}/${config.eh_scope_key.get}}}" + val ehStatePath = s"${config.etl_storage_prefix}/${config.workspace_id}/ehState" + val azureLogConfig = AzureAuditLogEventhubConfig(connectionString = ehConnString, eventHubName = config.eh_name.get, auditRawEventsPrefix = ehStatePath) AuditLogConfig(azureAuditLogEventhubConfig = Some(azureLogConfig)) } val interactiveDBUPrice: Double = config.interactive_dbu_price @@ -115,10 +126,9 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { val sqlComputerDBUPrice: Double = config.sql_compute_dbu_price val jobsLightDBUPrice: Double = config.jobs_light_dbu_price val customWorkspaceName: String = config.workspace_name - val standardScopes = "audit,sparkEvents,jobs,clusters,clusterEvents,notebooks,pools,accounts".split(",").toBuffer - if (config.excluded_scopes != null) { - config.excluded_scopes.split(":").foreach(scope => standardScopes -= scope) - } + val standardScopes = "audit,sparkEvents,jobs,clusters,clusterEvents,notebooks,pools,accounts,dbsql".split(",") + val scopesToExecute = (standardScopes.map(_.toLowerCase).toSet -- + config.excluded_scopes.getOrElse("").split(":").map(_.toLowerCase).toSet).toArray val maxDaysToLoad: Int = config.max_days val primordialDateString: Date = config.primordial_date @@ -131,14 +141,13 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { dataTarget = Some(dataTarget), tokenSecret = 
Some(tokenSecret), badRecordsPath = Some(badRecordsPath), - overwatchScope = Some(standardScopes), + overwatchScope = Some(scopesToExecute), maxDaysToLoad = maxDaysToLoad, databricksContractPrices = DatabricksContractPrices(interactiveDBUPrice, automatedDBUPrice, sqlComputerDBUPrice, jobsLightDBUPrice), primordialDateString = Some(stringDate), workspace_name = Some(customWorkspaceName), externalizeOptimize = true, - apiEnvConfig = Some(apiEnvConfig), - tempWorkingDir = "" + apiEnvConfig = Some(apiEnvConfig) ) MultiWorkspaceParams(JsonUtils.objToJson(params).compactString, s"""${config.api_url}""", @@ -166,79 +175,82 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { config.enable_unsafe_SSL.getOrElse(false), config.thread_pool_size.getOrElse(4), config.api_waiting_time.getOrElse(300000), - Some(apiProxyConfig)) + Some(apiProxyConfig), + config.mount_mapping_path) apiEnvConfig } - private def startBronzeDeployment(multiWorkspaceParams: MultiWorkspaceParams) = { + private def startBronzeDeployment(workspace: Workspace, deploymentId: String): MultiWSDeploymentReport = { + val workspaceId = workspace.getConfig.organizationId + val args = JsonUtils.objToJson(workspace.getConfig.inputConfig) + println(s"""************Bronze Deployment Started workspaceID:$workspaceId\nargs:${args.prettyString}********** """) try { - println(s"""************Bronze Deployment Started workspaceID:${multiWorkspaceParams.workspaceId} args:${multiWorkspaceParams.args}********** """) - val workspace = Initializer(multiWorkspaceParams.args, debugFlag = false, - apiURL = Some(multiWorkspaceParams.apiUrl), - organizationID = Some(multiWorkspaceParams.workspaceId)) - Bronze(workspace).run() - println(s"""************Bronze Deployment Completed workspaceID:${multiWorkspaceParams.workspaceId}************ """) - deploymentReport.append(MultiWSDeploymentReport(multiWorkspaceParams.workspaceId, "Bronze", Some(multiWorkspaceParams.args), + println(s"""************Bronze Deployment Completed workspaceID:$workspaceId************ """) + MultiWSDeploymentReport(workspaceId, "Bronze", Some(args.compactString), "SUCCESS", - Some(multiWorkspaceParams.deploymentId) - )) + Some(deploymentId) + ) } catch { case exception: Exception => val fullMsg = PipelineFunctions.appendStackStrace(exception, "Got Exception while Deploying,") logger.log(Level.ERROR, fullMsg) - deploymentReport.append(MultiWSDeploymentReport(multiWorkspaceParams.workspaceId, "Bronze", Some(multiWorkspaceParams.args), + MultiWSDeploymentReport(workspaceId, "Bronze", Some(args.compactString), fullMsg, - Some(multiWorkspaceParams.deploymentId) - )) + Some(deploymentId) + ) + } finally { + clearThreadFromSessionsMap() } } - private def startSilverDeployment(multiWorkspaceParams: MultiWorkspaceParams) = { + private def startSilverDeployment(workspace: Workspace, deploymentId: String): MultiWSDeploymentReport = { + val workspaceId = workspace.getConfig.organizationId + val args = JsonUtils.objToJson(workspace.getConfig.inputConfig) try { - println(s"""************Silver Deployment Started workspaceID:${multiWorkspaceParams.workspaceId} args:${multiWorkspaceParams.args} ************""") - val workspace = Initializer(multiWorkspaceParams.args, debugFlag = false, - apiURL = Some(multiWorkspaceParams.apiUrl), - organizationID = Some(multiWorkspaceParams.workspaceId)) + println(s"""************Silver Deployment Started workspaceID:$workspaceId\nargs:${args.prettyString} ************""") Silver(workspace).run() - 
deploymentReport.append(MultiWSDeploymentReport(multiWorkspaceParams.workspaceId, "Silver", Some(multiWorkspaceParams.args), + println(s"""************Silver Deployment Completed workspaceID:$workspaceId************""") + MultiWSDeploymentReport(workspaceId, "Silver", Some(args.compactString), "SUCCESS", - Some(multiWorkspaceParams.deploymentId) - )) - println(s"""************Silver Deployment Completed workspaceID:${multiWorkspaceParams.workspaceId}************""") + Some(deploymentId) + ) } catch { case exception: Exception => val fullMsg = PipelineFunctions.appendStackStrace(exception, "Got Exception while Deploying,") logger.log(Level.ERROR, fullMsg) - deploymentReport.append(MultiWSDeploymentReport(multiWorkspaceParams.workspaceId, "Silver", Some(multiWorkspaceParams.args), + MultiWSDeploymentReport(workspaceId, "Silver", Some(args.compactString), fullMsg, - Some(multiWorkspaceParams.deploymentId) - )) + Some(deploymentId) + ) + } finally { + clearThreadFromSessionsMap() } } - private def startGoldDeployment(multiWorkspaceParams: MultiWorkspaceParams) = { + private def startGoldDeployment(workspace: Workspace, deploymentId: String): MultiWSDeploymentReport = { + val workspaceId = workspace.getConfig.organizationId + val args = JsonUtils.objToJson(workspace.getConfig.inputConfig) try { - println(s"""************Gold Deployment Started workspaceID:${multiWorkspaceParams.workspaceId} args:${multiWorkspaceParams.args} ************"""") - val workspace = Initializer(multiWorkspaceParams.args, debugFlag = false, - apiURL = Some(multiWorkspaceParams.apiUrl), - organizationID = Some(multiWorkspaceParams.workspaceId)) + println(s"""************Gold Deployment Started workspaceID:$workspaceId args:${args.prettyString} ************"""") Gold(workspace).run() - deploymentReport.append(MultiWSDeploymentReport(multiWorkspaceParams.workspaceId, "Gold", Some(multiWorkspaceParams.args), + println(s"""************Gold Deployment Completed workspaceID:$workspaceId************""") + MultiWSDeploymentReport(workspaceId, "Gold", Some(args.compactString), "SUCCESS", - Some(multiWorkspaceParams.deploymentId) - )) - println(s"""************Gold Deployment Completed workspaceID:${multiWorkspaceParams.workspaceId}************""") + Some(deploymentId) + ) } catch { case exception: Exception => val fullMsg = PipelineFunctions.appendStackStrace(exception, "Got Exception while Deploying,") logger.log(Level.ERROR, fullMsg) - deploymentReport.append(MultiWSDeploymentReport(multiWorkspaceParams.workspaceId, "Gold", Some(multiWorkspaceParams.args), + MultiWSDeploymentReport(workspaceId, "Gold", Some(args.compactString), fullMsg, - Some(multiWorkspaceParams.deploymentId) - )) + Some(deploymentId) + ) + }finally { + clearThreadFromSessionsMap() } } @@ -255,7 +267,7 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { multiworkspaceConfigs.toSeq.toDS().toDF() .withColumn("snapTS", lit(pipelineSnapTime.asTSString)) .withColumn("timestamp", lit(pipelineSnapTime.asUnixTimeMilli)) - .withColumn("configFile", lit(configCsvPath)) + .withColumn("configFile", lit(configLocation)) .write.format("delta") .mode("append") .option("mergeSchema", "true") @@ -292,12 +304,12 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { * @param path * @param reportName */ - private def saveDeploymentReport(validationArray: ArrayBuffer[MultiWSDeploymentReport], path: String, reportName: String): Unit = { + private def saveDeploymentReport(validationArray: Array[MultiWSDeploymentReport], path: String, reportName: String): Unit 
= { var reportPath = path if (!path.startsWith("dbfs:") && !path.startsWith("s3") && !path.startsWith("abfss")) { reportPath = s"""dbfs:${path}""" } - validationArray.toDS().toDF() + validationArray.toSeq.toDF() .withColumn("snapTS", lit(pipelineSnapTime.asTSString)) .withColumn("timestamp", lit(pipelineSnapTime.asUnixTimeMilli)) .write.format("delta").mode("append").save(s"""${reportPath}/report/${reportName}""") @@ -306,33 +318,66 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { /** * Validates the config csv file existence. */ - private[overwatch] def validateFileExistence(configCsvPath: String): Boolean = { - if (!Helpers.pathExists(configCsvPath)) { - throw new BadConfigException("Unable to find config file in the given location:" + configCsvPath) + private[overwatch] def validateFileExistence(configLocation: String): Boolean = { + if (!Helpers.pathExists(configLocation)) { + throw new BadConfigException("Unable to find config file in the given location:" + configLocation) } true } + + private def generateBaseConfig(configLocation: String): DataFrame = { + val rawBaseConfigDF = try { + if (configLocation.toLowerCase().endsWith(".csv")) { // CSV file + println(s"Config source: csv path ${configLocation}") + validateFileExistence(configLocation) + spark.read.option("header", "true") + .option("ignoreLeadingWhiteSpace", true) + .option("ignoreTrailingWhiteSpace", true) + .csv(configLocation) + } else if (configLocation.contains("/")) { // delta path + println(s"Config source: delta path ${configLocation}") + validateFileExistence(configLocation) + spark.read.format("delta").load(configLocation) + } else { // delta table + println(s"Config source: delta table ${configLocation}") + if (!spark.catalog.tableExists(configLocation)) { + throw new BadConfigException("Unable to find Delta table" + configLocation) + } + spark.read.table(configLocation) + } + } catch { + case e: Exception => + println("Exception while reading config , please provide config csv path/config delta path/config delta table") + throw e + } + + val deploymentSelectsNoNullStrings = Schema.deployementMinimumSchema.fields.map(f => { + when(trim(lower(col(f.name))) === "null", lit(null).cast(f.dataType)).otherwise(col(f.name)).alias(f.name) + }) + + rawBaseConfigDF + .verifyMinimumSchema(Schema.deployementMinimumSchema) + .select(deploymentSelectsNoNullStrings: _*) + + } + private def generateMultiWorkspaceConfig( - configCsvPath: String, + configLocation: String, deploymentId: String, outputPath: String = "" ): Array[MultiWorkspaceConfig] = { // Array[MultiWorkspaceConfig] = { try { - validateFileExistence(configCsvPath) - val multiWorkspaceConfig = spark.read.option("header", "true") - .option("ignoreLeadingWhiteSpace", true) - .option("ignoreTrailingWhiteSpace", true) - .csv(configCsvPath) - .scrubSchema - .verifyMinimumSchema(Schema.deployementMinimumSchema) - .filter(MultiWorkspaceConfigColumns.active.toString) + val baseConfig = generateBaseConfig(configLocation) + val multiWorkspaceConfig = baseConfig + .withColumn("api_url", when('api_url.endsWith("/"), 'api_url.substr(lit(0), length('api_url) - 1)).otherwise('api_url)) .withColumn("deployment_id", lit(deploymentId)) .withColumn("output_path", lit(outputPath)) .as[MultiWorkspaceConfig] + .filter(_.active) .collect() - if(multiWorkspaceConfig.size<1){ - throw new BadConfigException("Config file has 0 record, config file:" + configCsvPath) + if(multiWorkspaceConfig.length < 1){ + throw new BadConfigException("Config file has 0 record, config file:" + 
configLocation) } multiWorkspaceConfig } catch { @@ -345,6 +390,34 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { } + /** + * crate pipeline executions as futures and return the deployment reports + * @param deploymentParams deployment params for a specific workspace + * @param medallions medallions to execute (bronze, silver, gold) + * @param ec futures executionContext + * @return future deployment report + */ + private def executePipelines( + deploymentParams: MultiWorkspaceParams, + medallions: String, + ec: ExecutionContextExecutor + ): Future[Array[MultiWSDeploymentReport]] = { + + Future { + val threadDeploymentReport = ArrayBuffer[MultiWSDeploymentReport]() + val deploymentId = deploymentParams.deploymentId + val workspace = Initializer(deploymentParams.args, + apiURL = Some(deploymentParams.apiUrl), + organizationID = Some(deploymentParams.workspaceId)) + + val zonesLower = medallions.toLowerCase + if (zonesLower.contains("bronze")) threadDeploymentReport.append(startBronzeDeployment(workspace, deploymentId)) + if (zonesLower.contains("silver")) threadDeploymentReport.append(startSilverDeployment(workspace, deploymentId)) + if (zonesLower.contains("gold")) threadDeploymentReport.append(startGoldDeployment(workspace, deploymentId)) + threadDeploymentReport.toArray + }(ec) + + } /** @@ -355,44 +428,42 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { */ def deploy(parallelism: Int = 4, zones: String = "Bronze,Silver,Gold"): Unit = { val processingStartTime = System.currentTimeMillis(); - println("ParallelismLevel :" + parallelism) - - val multiWorkspaceConfig = generateMultiWorkspaceConfig(configCsvPath, deploymentId, outputPath) - snapshotConfig(multiWorkspaceConfig) - val params = DeploymentValidation - .performMandatoryValidation(multiWorkspaceConfig, parallelism) - .map(buildParams) - - println("Workspace to be Deployed :" + params.size) - val zoneArray = zones.split(",") - zoneArray.foreach(zone => { - val responseCounter = Collections.synchronizedList(new util.ArrayList[Int]()) - implicit val ec: ExecutionContextExecutor = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(parallelism)) - params.foreach(deploymentParams => { - val future = Future { - zone.toLowerCase match { - case "bronze" => - startBronzeDeployment(deploymentParams) - case "silver" => - startSilverDeployment(deploymentParams) - case "gold" => - startGoldDeployment(deploymentParams) - } - } - future.onComplete { - case _ => - responseCounter.add(1) - } - }) - while (responseCounter.size() < params.length) { - Thread.sleep(5000) - } - }) - saveDeploymentReport(deploymentReport, multiWorkspaceConfig.head.etl_storage_prefix, "deploymentReport") + try { + // initialize spark overrides for global spark conf + // global overrides should be set BEFORE parSessionsOn is set to true + PipelineFunctions.setSparkOverrides(spark(globalSession = true), SparkSessionWrapper.globalSparkConfOverrides) + + if (parallelism > 1) SparkSessionWrapper.parSessionsOn = true + SparkSessionWrapper.sessionsMap.clear() + SparkSessionWrapper.globalTableLock.clear() + + println("ParallelismLevel :" + parallelism) + val multiWorkspaceConfig = generateMultiWorkspaceConfig(configLocation, deploymentId, outputPath) + snapshotConfig(multiWorkspaceConfig) + val params = DeploymentValidation + .performMandatoryValidation(multiWorkspaceConfig, parallelism) + .map(buildParams) + println("Workspaces to be Deployed :" + params.length) + val ec: ExecutionContextExecutor = 
ExecutionContext.fromExecutor(Executors.newFixedThreadPool(parallelism)) + val deploymentReports = params.map(executePipelines(_, zones, ec)) + .flatMap(f => Await.result(f, Duration.Inf)) + + deploymentReport.appendAll(deploymentReports) + saveDeploymentReport(deploymentReport.toArray, multiWorkspaceConfig.head.etl_storage_prefix, "deploymentReport") + } catch { + case e: Exception => + val failMsg = s"FAILED DEPLOYMENT WITH EXCEPTION" + println(failMsg) + logger.log(Level.ERROR, failMsg, e) + throw e + } finally { + SparkSessionWrapper.sessionsMap.clear() + SparkSessionWrapper.globalTableLock.clear() + } println(s"""Deployment completed in sec ${(System.currentTimeMillis() - processingStartTime) / 1000}""") - } + /** * Validates all the parameters provided in config csv file and generates a report which is stored at /etl_storrage_prefix/report/validationReport * @@ -401,7 +472,7 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { */ def validate(parallelism: Int = 4): Unit = { val processingStartTime = System.currentTimeMillis() - val multiWorkspaceConfig = generateMultiWorkspaceConfig(configCsvPath, deploymentId, outputPath) + val multiWorkspaceConfig = generateMultiWorkspaceConfig(configLocation, deploymentId, outputPath) val validations = DeploymentValidation.performValidation(multiWorkspaceConfig, parallelism) val notValidatedCount = validations.filterNot(_.validated).length diff --git a/src/main/scala/com/databricks/labs/overwatch/MultiWorkspaceRunner.scala b/src/main/scala/com/databricks/labs/overwatch/MultiWorkspaceRunner.scala index 5bb55e6b8..b228b62f9 100644 --- a/src/main/scala/com/databricks/labs/overwatch/MultiWorkspaceRunner.scala +++ b/src/main/scala/com/databricks/labs/overwatch/MultiWorkspaceRunner.scala @@ -10,11 +10,6 @@ object MultiWorkspaceRunner extends SparkSessionWrapper{ private val logger: Logger = Logger.getLogger(this.getClass) - private def setGlobalDeltaOverrides(): Unit = { - spark.conf.set("spark.databricks.delta.optimize.maxFileSize", 1024 * 1024 * 128) - spark.conf.set("spark.sql.files.maxPartitionBytes", 1024 * 1024 * 128) - } - private def validateInputZone(zones: String): Unit = { val zoneArray = zones.split(",").distinct zoneArray.foreach(zone => { @@ -36,7 +31,6 @@ object MultiWorkspaceRunner extends SparkSessionWrapper{ */ def main(args: Array[String]): Unit = { envInit() - setGlobalDeltaOverrides() if (args.length == 1) { //Deploy Bronze,Silver and Gold with default parallelism. 
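For context, a hedged end-to-end usage sketch of the refactored deployment entry point; the table name, temp path, and parallelism below are placeholders chosen for illustration:

```scala
import com.databricks.labs.overwatch.MultiWorkspaceDeployment

val deployment = MultiWorkspaceDeployment("overwatch_meta.deployment_config", "/mnt/tmp/overwatch")
deployment.validate(parallelism = 4) // writes the validation report under <etl_storage_prefix>/report
deployment.deploy(parallelism = 4, zones = "Bronze,Silver,Gold")
```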
logger.log(Level.INFO, "Deploying Bronze,Silver and Gold") MultiWorkspaceDeployment(args(0)).deploy(4,"Bronze,Silver,Gold") diff --git a/src/main/scala/com/databricks/labs/overwatch/ParamDeserializer.scala b/src/main/scala/com/databricks/labs/overwatch/ParamDeserializer.scala index 09adbff43..7c4af604d 100644 --- a/src/main/scala/com/databricks/labs/overwatch/ParamDeserializer.scala +++ b/src/main/scala/com/databricks/labs/overwatch/ParamDeserializer.scala @@ -188,8 +188,9 @@ class ParamDeserializer() extends StdDeserializer[OverwatchParams](classOf[Overw getOptionBoolean(masterNode, "apiEnvConfig.enableUnsafeSSL").getOrElse(false), getOptionInt(masterNode, "apiEnvConfig.threadPoolSize").getOrElse(4), getOptionLong(masterNode, "apiEnvConfig.apiWaitingTime").getOrElse(300000), - apiProxyNodeConfig - )) + apiProxyNodeConfig, + getOptionString(masterNode, "apiEnvConfig.mountMappingPath")) + ) } else { None } diff --git a/src/main/scala/com/databricks/labs/overwatch/env/Database.scala b/src/main/scala/com/databricks/labs/overwatch/env/Database.scala index f7b0c260d..ec95330ae 100644 --- a/src/main/scala/com/databricks/labs/overwatch/env/Database.scala +++ b/src/main/scala/com/databricks/labs/overwatch/env/Database.scala @@ -2,8 +2,8 @@ package com.databricks.labs.overwatch.env import com.databricks.labs.overwatch.pipeline.TransformFunctions._ import com.databricks.labs.overwatch.pipeline.{PipelineFunctions, PipelineTable} -import com.databricks.labs.overwatch.utils.{Config, SparkSessionWrapper, WriteMode} -import io.delta.tables.DeltaTable +import com.databricks.labs.overwatch.utils.{Config, SparkSessionWrapper, WriteMode, MergeScope} +import io.delta.tables.{DeltaMergeBuilder, DeltaTable} import org.apache.log4j.{Level, Logger} import org.apache.spark.sql.functions.lit import org.apache.spark.sql.streaming.StreamingQueryListener._ @@ -13,6 +13,7 @@ import org.apache.spark.sql.{Column, DataFrame, DataFrameWriter, Row} import java.util import java.util.UUID import scala.annotation.tailrec +import scala.util.Random class Database(config: Config) extends SparkSessionWrapper { @@ -154,9 +155,9 @@ class Database(config: Config) extends SparkSessionWrapper { |""".stripMargin) df.write.format("delta").save(dfTempPath) - spark.conf.set("spark.sql.files.maxPartitionBytes", 1024 * 1024 * 16) // maximize parallelism on re-read and let + // maximize parallelism on re-read and let AQE bring it back down + spark.conf.set("spark.sql.files.maxPartitionBytes", 1024 * 1024 * 16) spark.conf.set("spark.databricks.delta.formatCheck.enabled", "true") - // AQE bring it back down spark.read.format("delta").load(dfTempPath) } @@ -166,12 +167,16 @@ class Database(config: Config) extends SparkSessionWrapper { .map(_.name).headOption } - // TODO - refactor this write function and the writer from the target - // write function has gotten overly complex - def write(df: DataFrame, target: PipelineTable, pipelineSnapTime: Column, maxMergeScanDates: Array[String] = Array()): Boolean = { + /** + * Add metadata, cleanse duplicates, and persistBeforeWrite for perf as needed based on target + * + * @param df original dataframe that is to be written + * @param target target to which the df is to be written + * @param pipelineSnapTime pipelineSnapTime + * @return + */ + private def preWriteActions(df: DataFrame, target: PipelineTable, pipelineSnapTime: Column): DataFrame = { var finalSourceDF: DataFrame = df - - // apend metadata to source DF finalSourceDF = if (target.withCreateDate) finalSourceDF.withColumn("Pipeline_SnapTS", 
pipelineSnapTime) else finalSourceDF finalSourceDF = if (target.withOverwatchRunID) finalSourceDF.withColumn("Overwatch_RunID", lit(config.runID)) else finalSourceDF finalSourceDF = if (target.workspaceName) finalSourceDF.withColumn("workspace_name", lit(config.workspaceName)) else finalSourceDF @@ -179,7 +184,79 @@ class Database(config: Config) extends SparkSessionWrapper { // if target is to be deduped, dedup it by keys finalSourceDF = if (!target.permitDuplicateKeys) finalSourceDF.dedupByKey(target.keys, target.incrementalColumns) else finalSourceDF - val finalDF = if (target.persistBeforeWrite) persistAndLoad(finalSourceDF, target) else finalSourceDF + // always persistAndLoad when parallelism > 1 to reduce table lock times + // don't persist and load metadata tables with fixed schemas (isEvolvingSchema false) + val isParallelNotStreamingNotMeta = SparkSessionWrapper.parSessionsOn && !target.isStreaming && !target.isEvolvingSchema + if (target.persistBeforeWrite || isParallelNotStreamingNotMeta) persistAndLoad(finalSourceDF, target) else finalSourceDF + } + + /** + * + * @param deltaTarget Delta table to which to write + * @param updatesDF DF to be merged into the delta table + * @param mergeCondition merge logic as a string + * @param target Pipeline Table target to write to + * @return + */ + private def deriveDeltaMergeBuilder( + deltaTarget: DeltaTable, + updatesDF: DataFrame, + mergeCondition: String, + target: PipelineTable + ): DeltaMergeBuilder = { + + val mergeScope = target.mergeScope + logger.log(Level.INFO, s"BEGINNING MERGE for target ${target.tableFullName}. \nMERGE SCOPE: " + + s"$mergeScope") + + if (mergeScope == MergeScope.insertOnly) { + deltaTarget + .merge(updatesDF, mergeCondition) + .whenNotMatched + .insertAll() + } else if (mergeScope == MergeScope.updateOnly) { + deltaTarget + .merge(updatesDF, mergeCondition) + .whenMatched + .updateAll() + } + else if (mergeScope == MergeScope.full) { + deltaTarget + .merge(updatesDF, mergeCondition) + .whenMatched + .updateAll() + .whenNotMatched + .insertAll() + } else { + throw new Exception("Merge Scope Not Supported") + } + } + + // TODO - refactor this write function and the writer from the target + // write function has gotten overly complex + /** + * Write the dataframe to the target + * + * @param df df with data to be written + * @param target target to where data should be saved + * @param pipelineSnapTime pipelineSnapTime + * @param maxMergeScanDates perf -- max dates to scan on a merge write. If populated this will filter the target source + * df to minimize the right-side scan. + * @param preWritesPerformed whether pre-write actions have already been completed prior to calling write. This is + * defaulted to false and should only be set to true in specific circumstances when + * preWriteActions was called prior to calling the write function. Typically used for + * concurrency locking. 
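For reviewers less familiar with the Delta merge API, a standalone sketch of the three merge scopes dispatched by deriveDeltaMergeBuilder above; the target path and join key are placeholders:

```scala
import io.delta.tables.{DeltaMergeBuilder, DeltaTable}
import org.apache.spark.sql.{DataFrame, SparkSession}

// Hypothetical helper mirroring the insertOnly / updateOnly / full scopes.
def mergeByScope(spark: SparkSession, updates: DataFrame, scope: String): Unit = {
  val target = DeltaTable.forPath(spark, "/tmp/example_target")      // placeholder target
  val builder = target.as("t").merge(updates.as("u"), "t.id = u.id") // placeholder merge condition
  val resolved: DeltaMergeBuilder = scope match {
    case "insertOnly" => builder.whenNotMatched.insertAll()          // append-only targets
    case "updateOnly" => builder.whenMatched.updateAll()             // update-in-place targets
    case _            => builder.whenMatched.updateAll().whenNotMatched.insertAll() // full upsert
  }
  resolved.execute()
}
```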
+ * @return + */ + def write(df: DataFrame, target: PipelineTable, pipelineSnapTime: Column, maxMergeScanDates: Array[String] = Array(), preWritesPerformed: Boolean = false): Unit = { + val finalSourceDF: DataFrame = df + + // append metadata to source DF and cleanse as necessary + val finalDF = if (!preWritesPerformed) { + preWriteActions(finalSourceDF, target, pipelineSnapTime) + } else { + finalSourceDF + } // ON FIRST RUN - WriteMode is automatically overwritten to APPEND if (target.writeMode == WriteMode.merge) { // DELTA MERGE / UPSERT @@ -204,12 +281,7 @@ class Database(config: Config) extends SparkSessionWrapper { logger.log(Level.INFO, mergeDetailMsg) spark.conf.set("spark.databricks.delta.commitInfo.userMetadata", config.runID) // TODO -- when DBR 9.1 LTS GA, use LSM (low-shuffle-merge) to improve pipeline - deltaTarget - .merge(updatesDF, mergeCondition) - .whenMatched - .updateAll() - .whenNotMatched - .insertAll() + deriveDeltaMergeBuilder(deltaTarget, updatesDF, mergeCondition, target) .execute() spark.conf.unset("spark.databricks.delta.commitInfo.userMetadata") @@ -254,30 +326,59 @@ class Database(config: Config) extends SparkSessionWrapper { logger.log(Level.INFO, s"Completed write to ${target.tableFullName}") } registerTarget(target) - true } + /** + * Function forces the thread to sleep for some specified time. + * @param tableName Name of table for logging only + * @param minimumSeconds minimum number of seconds for which to sleep + * @param maxRandomSeconds max random seconds to add + */ + private def coolDown(tableName: String, minimumSeconds: Long, maxRandomSeconds: Long): Unit = { + val rnd = new scala.util.Random + val number: Long = ((rnd.nextFloat() * maxRandomSeconds) + minimumSeconds).toLong * 1000L + logger.log(Level.INFO,"Slowing parallel writes to " + tableName + "sleeping..." + number + + " thread name " + Thread.currentThread().getName) + Thread.sleep(number) + } - def performRetry(inputDf: DataFrame, + /** + * Perform write retry after cooldown for legacy deployments + * race conditions can occur when multiple workspaces attempt to modify the schema at the same time, when this + * occurs, simply retry the write after some cooldown + * @param inputDf DF to write + * @param target target to where to write + * @param pipelineSnapTime pipelineSnapTime + * @param maxMergeScanDates same as write function + */ + private def performRetry(inputDf: DataFrame, target: PipelineTable, pipelineSnapTime: Column, - maxMergeScanDates: Array[String] = Array()): this.type = { - @tailrec def executeRetry(retryCount: Int): this.type = { + maxMergeScanDates: Array[String] = Array()): Unit = { + @tailrec def executeRetry(retryCount: Int): Unit = { val rerunFlag = try { write(inputDf, target, pipelineSnapTime, maxMergeScanDates) false } catch { case e: Throwable => val exceptionMsg = e.getMessage.toLowerCase() + logger.log(Level.WARN, + s""" + |DELTA Table Write Failure: + |$exceptionMsg + |Will Retry After a small delay. 
+ |This is usually caused by multiple writes attempting to evolve the schema simultaneously + |""".stripMargin) if (exceptionMsg != null && (exceptionMsg.contains("concurrent") || exceptionMsg.contains("conflicting")) && retryCount < 5) { - coolDown(target.tableFullName) + coolDown(target.tableFullName, 30, 30) true } else { throw e } } - if (retryCount < 5 && rerunFlag) executeRetry(retryCount + 1) else this + if (retryCount < 5 && rerunFlag) executeRetry(retryCount + 1) } + try { executeRetry(1) } catch { @@ -287,33 +388,93 @@ class Database(config: Config) extends SparkSessionWrapper { } - def writeWithRetry(df: DataFrame, - target: PipelineTable, - pipelineSnapTime: Column, - maxMergeScanDates: Array[String] = Array(), - daysToProcess: Option[Int] = None): Boolean = { + /** + * Used for multithreaded multiworkspace deployments + * Check if a table is locked and if it is wait until max timeout for it to be unlocked and fail the write if + * the timeout is reached + * Table Lock Timeout can be overridden by setting cluster spark config overwatch.tableLockTimeout -- in milliseconds + * default table lock timeout is 20 minutes or 1200000 millis + * + * @param tableName name of the table to be written + * @return + */ + private def targetNotLocked(tableName: String): Boolean = { + val defaultTimeout: String = "1200000" + val timeout = spark(globalSession = true).conf.getOption("overwatch.tableLockTimeout").getOrElse(defaultTimeout).toLong + val timerStart = System.currentTimeMillis() + + @tailrec def testLock(retryCount: Int): Boolean = { + val currWaitTime = System.currentTimeMillis() - timerStart + val withinTimeout = currWaitTime < timeout + if (SparkSessionWrapper.globalTableLock.contains(tableName)) { + if (withinTimeout) { + logger.log(Level.WARN, s"TABLE LOCKED: $tableName for $currWaitTime -- waiting for parallel writes to complete") + coolDown(tableName, retryCount * 5, 2) // add 5 to 7 seconds to sleep for each lock test + testLock(retryCount + 1) + } else { + throw new Exception(s"TABLE LOCK TIMEOUT - The table $tableName remained locked for more than the configured " + + s"max timeout of $timeout millis. This may be increased by setting the following spark config in the cluster" + + s"to something higher than the default (20 minutes). Usually only necessary for historical loads. \n" + + s"overwatch.tableLockTimeout") + } + } else true // table not locked + } - val needsCache = daysToProcess.getOrElse(1000) < 5 && !target.autoOptimize + testLock(retryCount = 1) + } + + /** + * Wrapper for the write function + * If legacy deployment retry delta write in event of race condition + * If multiworkspace -- implement table locking to alleviate race conditions on parallelize workspace loads + * + * @param df + * @param target + * @param pipelineSnapTime + * @param maxMergeScanDates + * @param daysToProcess + * @return + */ + private[overwatch] def writeWithRetry(df: DataFrame, + target: PipelineTable, + pipelineSnapTime: Column, + maxMergeScanDates: Array[String] = Array(), + daysToProcess: Option[Int] = None): Unit = { + + // needsCache when it's a small number of days and not in parallel and not autoOptimize + // when in parallel disable cache because it will always use persistAndLoad to reduce table lock times. + // persist and load will all be able to happen in parallel to temp location and use a simple read/write to + // merge into target rather than locking the target for the entire time all the transforms are being executed. 
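As a usage note, the lock timeout checked by targetNotLocked is read from the cluster Spark conf in milliseconds (default 1200000, i.e. 20 minutes); a minimal sketch of raising it before a large historical load, with an arbitrary 30-minute value chosen only for illustration:

```scala
// Raise the table lock timeout to 30 minutes (value in milliseconds).
spark.conf.set("overwatch.tableLockTimeout", (30 * 60 * 1000).toString)
```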
+ val needsCache = daysToProcess.getOrElse(1000) < 5 && + !target.autoOptimize && + !SparkSessionWrapper.parSessionsOn && + target.isEvolvingSchema // don't cache small meta tables + + logger.log(Level.INFO, s"PRE-CACHING TARGET ${target.tableFullName} ENABLED: $needsCache") val inputDf = if (needsCache) { logger.log(Level.INFO, "Persisting data :" + target.tableFullName) df.persist() } else df if (needsCache) inputDf.count() - performRetry(inputDf,target, pipelineSnapTime, maxMergeScanDates) - true - } - /** - * Function forces the thread to sleep for a random 30-60 seconds. - * @param tableName - */ - private def coolDown(tableName: String): Unit = { - val rnd = new scala.util.Random - val number:Long = (rnd.nextFloat() * 30 + 30).toLong*1000 - logger.log(Level.INFO,"Slowing multithreaded writing for " + tableName + "sleeping..." + number+" thread name "+Thread.currentThread().getName) - Thread.sleep(number) + if (target.config.isMultiworkspaceDeployment && target.requiresLocking) { // multi-workspace ++ locking + val withMetaDf = preWriteActions(df, target, pipelineSnapTime) + if (targetNotLocked(target.tableFullName)) { + try { + SparkSessionWrapper.globalTableLock.add(target.tableFullName) + write(withMetaDf, target, pipelineSnapTime, maxMergeScanDates, preWritesPerformed = true) + } catch { + case e: Throwable => throw e + } finally { + SparkSessionWrapper.globalTableLock.remove(target.tableFullName) + } + } + } else { // not multiworkspace -- or if multiworkspace and does not require locking + performRetry(inputDf, target, pipelineSnapTime, maxMergeScanDates) + } } + } object Database { diff --git a/src/main/scala/com/databricks/labs/overwatch/env/Workspace.scala b/src/main/scala/com/databricks/labs/overwatch/env/Workspace.scala index 4550db71c..5cb8bcbbc 100644 --- a/src/main/scala/com/databricks/labs/overwatch/env/Workspace.scala +++ b/src/main/scala/com/databricks/labs/overwatch/env/Workspace.scala @@ -150,8 +150,7 @@ class Workspace(config: Config) extends SparkSessionWrapper { config.apiEnv, sqlQueryHistoryEndpoint, jsonQuery, - tempSuccessPath = s"${config.tempWorkingDir}/sqlqueryhistory_silver/${System.currentTimeMillis()}", - accumulator = acc + tempSuccessPath = s"${config.tempWorkingDir}/sqlqueryhistory_silver/${System.currentTimeMillis()}" ) .execute() .asDF() @@ -160,115 +159,24 @@ class Workspace(config: Config) extends SparkSessionWrapper { def getSqlQueryHistoryParallelDF(fromTime: TimeTypes, untilTime: TimeTypes): DataFrame = { val sqlQueryHistoryEndpoint = "sql/history/queries" - val acc = sc.longAccumulator("sqlQueryHistoryAccumulator") - var apiResponseArray = Collections.synchronizedList(new util.ArrayList[String]()) - var apiErrorArray = Collections.synchronizedList(new util.ArrayList[String]()) - val apiResponseCounter = Collections.synchronizedList(new util.ArrayList[Int]()) - implicit val ec: ExecutionContextExecutor = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(config.apiEnv.threadPoolSize)) - val tmpSqlQueryHistorySuccessPath = s"${config.tempWorkingDir}/sqlqueryhistory_silver/${System.currentTimeMillis()}" - val tmpSqlQueryHistoryErrorPath = s"${config.tempWorkingDir}/errors/sqlqueryhistory_silver/${System.currentTimeMillis()}" val untilTimeMs = untilTime.asUnixTimeMilli - var fromTimeMs = fromTime.asUnixTimeMilli - (1000*60*60*24*2) //subtracting 2 days for running query merge - val finalResponseCount = scala.math.ceil((untilTimeMs - fromTimeMs).toDouble/(1000*60*60)) // Total no. 
of API Calls - while (fromTimeMs < untilTimeMs){ - val (startTime, endTime) = if ((untilTimeMs- fromTimeMs)/(1000*60*60) > 1) { - (fromTimeMs, - fromTimeMs+(1000*60*60)) - } - else{ - (fromTimeMs, - untilTimeMs) - } - //create payload for the API calls - val jsonQuery = Map( - "max_results" -> "50", - "include_metrics" -> "true", - "filter_by.query_start_time_range.start_time_ms" -> s"$startTime", - "filter_by.query_start_time_range.end_time_ms" -> s"$endTime" - ) - /**TODO: - * Refactor the below code to make it more generic - */ - //call future - val future = Future { - val apiObj = ApiCallV2( - config.apiEnv, - sqlQueryHistoryEndpoint, - jsonQuery, - tempSuccessPath = tmpSqlQueryHistorySuccessPath, - accumulator = acc - ).executeMultiThread() - - synchronized { - apiObj.forEach( - obj=>if(obj.contains("res")){ - apiResponseArray.add(obj) - } - ) - if (apiResponseArray.size() >= config.apiEnv.successBatchSize) { - PipelineFunctions.writeMicroBatchToTempLocation(tmpSqlQueryHistorySuccessPath, apiResponseArray.toString) - apiResponseArray = Collections.synchronizedList(new util.ArrayList[String]()) - } - } - } - future.onComplete { - case Success(_) => - apiResponseCounter.add(1) - - case Failure(e) => - if (e.isInstanceOf[ApiCallFailureV2]) { - synchronized { - apiErrorArray.add(e.getMessage) - if (apiErrorArray.size() >= config.apiEnv.errorBatchSize) { - PipelineFunctions.writeMicroBatchToTempLocation(tmpSqlQueryHistoryErrorPath, apiErrorArray.toString) - apiErrorArray = Collections.synchronizedList(new util.ArrayList[String]()) - } - } - logger.log(Level.ERROR, "Future failure message: " + e.getMessage, e) - } - apiResponseCounter.add(1) - } - fromTimeMs = fromTimeMs+(1000*60*60) - } + val fromTimeMs = fromTime.asUnixTimeMilli - (1000*60*60*24*2) //subtracting 2 days for running query merge + val finalResponseCount = scala.math.ceil((untilTimeMs - fromTimeMs).toDouble/(1000*60*60)).toLong// Total no. of API Calls + + // creating Json input for parallel API calls + val jsonInput = Map( + "start_value" -> s"${fromTimeMs}", + "end_value" -> s"${untilTimeMs}", + "increment_counter" -> "3600000", + "final_response_count" -> s"${finalResponseCount}", + "result_key" -> "res" + ) - val timeoutThreshold = config.apiEnv.apiWaitingTime // 5 minutes - var currentSleepTime = 0 - var accumulatorCountWhileSleeping = acc.value - while (apiResponseCounter.size() < finalResponseCount && currentSleepTime < timeoutThreshold) { - //As we are using Futures and running 4 threads in parallel, We are checking if all the treads has completed - // the execution or not. If we have not received the response from all the threads then we are waiting for 5 - // seconds and again revalidating the count. - if (currentSleepTime > 120000) //printing the waiting message only if the waiting time is more than 2 minutes. - { - println( - s"""Waiting for other queued API Calls to complete; cumulative wait time ${currentSleepTime / 1000} - |seconds; Api response yet to receive ${finalResponseCount - apiResponseCounter.size()}""".stripMargin) - } - Thread.sleep(5000) - currentSleepTime += 5000 - if (accumulatorCountWhileSleeping < acc.value) { //new API response received while waiting. - currentSleepTime = 0 //resetting the sleep time. - accumulatorCountWhileSleeping = acc.value - } - } - if (apiResponseCounter.size() != finalResponseCount) { // Checking whether all the api responses has been received or not. 
- logger.log(Level.ERROR, - s"""Unable to receive all the sql/history/queries api responses; Api response - |received ${apiResponseCounter.size()};Api response not - |received ${finalResponseCount - apiResponseCounter.size()}""".stripMargin) - throw new Exception( - s"""Unable to receive all the sql/history/queries api responses; Api response received - |${apiResponseCounter.size()};Api response not received ${finalResponseCount - apiResponseCounter.size()}""".stripMargin) - } - if (apiResponseArray.size() > 0) { //In case of response array didn't hit the batch-size as a final step we will write it to the persistent storage. - PipelineFunctions.writeMicroBatchToTempLocation(tmpSqlQueryHistorySuccessPath, apiResponseArray.toString) - apiResponseArray = Collections.synchronizedList(new util.ArrayList[String]()) - } - if (apiErrorArray.size() > 0) { //In case of error array didn't hit the batch-size as a final step we will write it to the persistent storage. - PipelineFunctions.writeMicroBatchToTempLocation(tmpSqlQueryHistorySuccessPath, apiErrorArray.toString) - apiErrorArray = Collections.synchronizedList(new util.ArrayList[String]()) - } + // calling function to make parallel API calls + val apiCallV2Obj = new ApiCallV2(config.apiEnv) + val tmpSqlQueryHistorySuccessPath= apiCallV2Obj.makeParallelApiCalls(sqlQueryHistoryEndpoint, jsonInput, config) logger.log(Level.INFO, " sql query history landing completed") + if(Helpers.pathExists(tmpSqlQueryHistorySuccessPath)) { try { spark.read.json(tmpSqlQueryHistorySuccessPath) @@ -317,6 +225,91 @@ class Workspace(config: Config) extends SparkSessionWrapper { }) } + def getClusterLibraries: DataFrame = { + val libsEndpoint = "libraries/all-cluster-statuses" + ApiCallV2(config.apiEnv, libsEndpoint) + .execute() + .asDF() + .withColumn("organization_id", lit(config.organizationId)) + } + + def getClusterPolicies: DataFrame = { + val policiesEndpoint = "policies/clusters/list" + ApiCallV2(config.apiEnv, policiesEndpoint) + .execute() + .asDF() + .withColumn("organization_id", lit(config.organizationId)) + } + + def getTokens: DataFrame = { + val tokenEndpoint = "token/list" + ApiCallV2(config.apiEnv, tokenEndpoint) + .execute() + .asDF() + .withColumn("organization_id", lit(config.organizationId)) + } + + def getGlobalInitScripts: DataFrame = { + val globalInitScEndpoint = "global-init-scripts" + ApiCallV2(config.apiEnv, globalInitScEndpoint) + .execute() + .asDF() + .withColumn("organization_id", lit(config.organizationId)) + } + + /** + * Function to get the the list of Job Runs + * @return + */ + def getJobRunsDF(fromTime: TimeTypes, untilTime: TimeTypes): DataFrame = { + val jobsRunsEndpoint = "jobs/runs/list" + val jsonQuery = Map( + "limit" -> "25", + "expand_tasks" -> "true", + "offset" -> "0", + "start_time_from" -> s"${fromTime.asUnixTimeMilli}", + "start_time_to" -> s"${untilTime.asUnixTimeMilli}" + ) + val acc = sc.longAccumulator("sqlQueryHistoryAccumulator") + var apiResponseArray = Collections.synchronizedList(new util.ArrayList[String]()) + val tempWorkingDir = s"${config.tempWorkingDir}/jobrunslist_bronze/${System.currentTimeMillis()}" + + val apiObj = ApiCallV2(config.apiEnv, + jobsRunsEndpoint, + jsonQuery, + tempSuccessPath = tempWorkingDir, + 2.1).executeMultiThread(acc) + + apiObj.forEach( + obj => if (obj.contains("runs")) { + apiResponseArray.add(obj) + } + ) + + if (apiResponseArray.size() > 0) { //In case of response array didn't hit the batch-size as a final step we will write it to the persistent storage. 
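The four new one-shot snapshot fetchers (cluster libraries, cluster policies, tokens, global init scripts) all follow the same shape: a single API call flattened to a DataFrame and tagged with the workspace id. The helper below is hypothetical (not part of this PR) and assumes the same Workspace scope where `config` and ApiCallV2 are already available; it only illustrates the shared pattern.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.lit

// Hypothetical consolidation of the snapshot fetchers above: one GET-style endpoint,
// one page of results flattened to a DataFrame, tagged with organization_id.
private def snapshotOf(endpoint: String): DataFrame =
  ApiCallV2(config.apiEnv, endpoint)
    .execute()
    .asDF()
    .withColumn("organization_id", lit(config.organizationId))

// e.g. snapshotOf("libraries/all-cluster-statuses") or snapshotOf("token/list")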
+ PipelineFunctions.writeMicroBatchToTempLocation(tempWorkingDir, apiResponseArray.toString) + } + + if(Helpers.pathExists(tempWorkingDir)) { + try { + spark.conf.set("spark.sql.caseSensitive", "true") + val df = spark.read.json(tempWorkingDir) + .select(explode(col("runs")).alias("runs")).select(col("runs" + ".*")) + .withColumn("organization_id", lit(config.organizationId)) + spark.conf.set("spark.sql.caseSensitive", "false") + df + } catch { + case e: Throwable => + throw new Exception(e) + } + } else { + println(s"""No Data is present for jobs/runs/list from - ${fromTime.asUnixTimeMilli} to - ${untilTime.asUnixTimeMilli}""") + logger.log(Level.INFO,s"""No Data is present for jobs/runs/list from - ${fromTime.asUnixTimeMilli} to - ${untilTime.asUnixTimeMilli}""") + spark.emptyDataFrame + } + + } + /** * Create a backup of the Overwatch datasets * diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/Bronze.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/Bronze.scala index 66d4caeae..9c0cbcdfa 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/Bronze.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/Bronze.scala @@ -26,7 +26,13 @@ class Bronze(_workspace: Workspace, _database: Database, _config: Config) BronzeTargets.processedEventLogs, BronzeTargets.cloudMachineDetail, BronzeTargets.dbuCostDetail, - BronzeTargets.clusterEventsErrorsTarget + BronzeTargets.clusterEventsErrorsTarget, + BronzeTargets.libsSnapshotTarget, + BronzeTargets.policiesSnapshotTarget, + BronzeTargets.instanceProfilesSnapshotTarget, + BronzeTargets.tokensSnapshotTarget, + BronzeTargets.globalInitScSnapshotTarget, + BronzeTargets.jobRunsSnapshotTarget ) } @@ -131,7 +137,7 @@ class Bronze(_workspace: Workspace, _database: Database, _config: Config) BronzeTargets.clustersSnapshotTarget.asDF, Seq( prepClusterEventLogs( - BronzeTargets.auditLogsTarget.asIncrementalDF(clusterEventLogsModule, BronzeTargets.auditLogsTarget.incrementalColumns), + BronzeTargets.auditLogsTarget.asIncrementalDF(clusterEventLogsModule, BronzeTargets.auditLogsTarget.incrementalColumns, additionalLagDays = 1), // 1 lag day to get laggard records clusterEventLogsModule.fromTime, clusterEventLogsModule.untilTime, pipelineSnapTime, @@ -139,7 +145,7 @@ class Bronze(_workspace: Workspace, _database: Database, _config: Config) config.organizationId, database, BronzeTargets.clusterEventsErrorsTarget, - config.tempWorkingDir + config ) ), append(BronzeTargets.clusterEventsTarget) @@ -167,7 +173,8 @@ class Bronze(_workspace: Workspace, _database: Database, _config: Config) BronzeTargets.clustersSnapshotTarget, sparkLogClusterScaleCoefficient, config.apiEnv, - config.isMultiworkspaceDeployment + config.isMultiworkspaceDeployment, + config.organizationId ), generateEventLogsDF( database, @@ -182,9 +189,47 @@ class Bronze(_workspace: Workspace, _database: Database, _config: Config) append(BronzeTargets.sparkEventLogsTarget) // Not new data only -- date filters handled in function logic ) + lazy private[overwatch] val libsSnapshotModule = Module(1007, "Bronze_Libraries_Snapshot", this) + lazy private val appendLibsProcess = ETLDefinition( + workspace.getClusterLibraries, + append(BronzeTargets.libsSnapshotTarget) + ) + + lazy private[overwatch] val policiesSnapshotModule = Module(1008, "Bronze_Policies_Snapshot", this) + lazy private val appendPoliciesProcess = ETLDefinition( + workspace.getClusterPolicies, + append(BronzeTargets.policiesSnapshotTarget) + ) + + lazy private[overwatch] val 
instanceProfileSnapshotModule = Module(1009, "Bronze_Instance_Profile_Snapshot", this) + lazy private val appendInstanceProfileProcess = ETLDefinition( + workspace.getProfilesDF, + append(BronzeTargets.instanceProfilesSnapshotTarget) + ) + + lazy private[overwatch] val tokenSnapshotModule = Module(1010, "Bronze_Token_Snapshot", this) + lazy private val appendTokenProcess = ETLDefinition( + workspace.getTokens, + append(BronzeTargets.tokensSnapshotTarget) + ) + + lazy private[overwatch] val globalInitScSnapshotModule = Module(1011, "Bronze_Global_Init_Scripts_Snapshot", this) + lazy private val appendGlobalInitScProcess = ETLDefinition( + workspace.getGlobalInitScripts, + append(BronzeTargets.globalInitScSnapshotTarget) + ) + + lazy private[overwatch] val jobRunsSnapshotModule = Module(1012, "Bronze_Job_Runs_Snapshot", this) // check module number + lazy private val appendJobRunsProcess = ETLDefinition( + workspace.getJobRunsDF(jobRunsSnapshotModule.fromTime, jobRunsSnapshotModule.untilTime), + Seq(cleanseRawJobRunsSnapDF(BronzeTargets.jobRunsSnapshotTarget.keys, config.runID)), + append(BronzeTargets.jobRunsSnapshotTarget) + ) + // TODO -- convert and merge this into audit's ETLDefinition private def landAzureAuditEvents(): Unit = { + println(s"Audit Logs Bronze: Land Stream Beginning for WorkspaceID: ${config.organizationId}") val rawAzureAuditEvents = landAzureAuditLogDF( BronzeTargets.auditLogAzureLandRaw, config.auditLogConfig.azureAuditLogEventhubConfig.get, @@ -208,11 +253,25 @@ class Bronze(_workspace: Workspace, _database: Database, _config: Config) landAzureAuditEvents() } auditLogsModule.execute(appendAuditLogsProcess) - case OverwatchScope.clusters => clustersSnapshotModule.execute(appendClustersAPIProcess) + case OverwatchScope.clusters => + clustersSnapshotModule.execute(appendClustersAPIProcess) + libsSnapshotModule.execute(appendLibsProcess) + policiesSnapshotModule.execute(appendPoliciesProcess) + if (config.cloudProvider== "aws") { + instanceProfileSnapshotModule.execute(appendInstanceProfileProcess) + } case OverwatchScope.clusterEvents => clusterEventLogsModule.execute(appendClusterEventLogsProcess) - case OverwatchScope.jobs => jobsSnapshotModule.execute(appendJobsProcess) + case OverwatchScope.jobs => + jobsSnapshotModule.execute(appendJobsProcess) + // setting this to experimental -- runtimes can be EXTREMELY long for customers with MANY job runs + if (spark(globalSession = true).conf.getOption("overwatch.experimental.enablejobrunsnapshot").getOrElse("false").toBoolean) { + jobRunsSnapshotModule.execute(appendJobRunsProcess) + } case OverwatchScope.pools => poolsSnapshotModule.execute(appendPoolsProcess) case OverwatchScope.sparkEvents => sparkEventLogsModule.execute(appendSparkEventLogsProcess) + case OverwatchScope.accounts => + tokenSnapshotModule.execute(appendTokenProcess) + globalInitScSnapshotModule.execute(appendGlobalInitScProcess) case _ => } } diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/BronzeTransforms.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/BronzeTransforms.scala index d4f357f58..dbc134d33 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/BronzeTransforms.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/BronzeTransforms.scala @@ -1,5 +1,6 @@ package com.databricks.labs.overwatch.pipeline +import com.databricks.dbutils_v1.DBUtilsHolder.dbutils import com.databricks.labs.overwatch.env.Database import com.databricks.labs.overwatch.eventhubs.AadAuthInstance import 
com.databricks.labs.overwatch.pipeline.WorkflowsTransforms.{workflowsCleanseJobClusters, workflowsCleanseTasks} @@ -273,7 +274,8 @@ trait BronzeTransforms extends SparkSessionWrapper { "settings.job_clusters" -> col("cleansedJobsClusters"), "settings.tags" -> SchemaTools.structToMap(outputDF, "settings.tags"), "settings.notebook_task.base_parameters" -> SchemaTools.structToMap(outputDF, "settings.notebook_task.base_parameters") - ) ++ PipelineFunctions.newClusterCleaner(outputDF, "settings.tasks.new_cluster") + ) ++ PipelineFunctions.newClusterCleaner(outputDF, "settings.tasks.new_cluster") ++ + PipelineFunctions.newClusterCleaner(outputDF, "settings.new_cluster") outputDF .join(cleansedTasksDF, keys.toSeq, "left") @@ -349,7 +351,9 @@ trait BronzeTransforms extends SparkSessionWrapper { PipelineFunctions.cleanseCorruptAuditLogs(spark, baselineAuditLogs) .withColumn("response", structFromJson(spark, schemaBuilders, "response")) - .drop("logId") + .withColumn("requestParamsJson", to_json('requestParams)) + .withColumn("hashKey", xxhash64('organization_id, 'timestamp, 'serviceName, 'actionName, 'requestId, 'requestParamsJson)) + .drop("logId", "requestParamsJson") } else { @@ -357,7 +361,7 @@ trait BronzeTransforms extends SparkSessionWrapper { val datesGlob = if (fromDT == untilDT) { Array(s"${auditLogConfig.rawAuditPath.get}/date=${fromDT.toString}") } else { - getDatesGlob(fromDT, untilDT) + getDatesGlob(fromDT, untilDT.plusDays(1)) // add one day to until to ensure intra-day audit logs prior to untilTS are captured. .map(dt => s"${auditLogConfig.rawAuditPath.get}/date=${dt}") .filter(Helpers.pathExists) } @@ -387,6 +391,9 @@ trait BronzeTransforms extends SparkSessionWrapper { baseDF // When globbing the paths, the date must be reconstructed and re-added manually .withColumn("organization_id", lit(organizationId)) + .withColumn("requestParamsJson", to_json('requestParams)) + .withColumn("hashKey", xxhash64('organization_id, 'timestamp, 'serviceName, 'actionName, 'requestId, 'requestParamsJson)) + .drop("requestParamsJson") .withColumn("filename", input_file_name) .withColumn("filenameAR", split(input_file_name, "/")) .withColumn("date", @@ -463,85 +470,28 @@ trait BronzeTransforms extends SparkSessionWrapper { endTime: TimeTypes, apiEnv: ApiEnv, tmpClusterEventsSuccessPath: String, - tmpClusterEventsErrorPath: String) = { + tmpClusterEventsErrorPath: String, + config: Config) = { val finalResponseCount = clusterIDs.length - var apiResponseArray = Collections.synchronizedList(new util.ArrayList[String]()) - var apiErrorArray = Collections.synchronizedList(new util.ArrayList[String]()) - val apiResponseCounter = Collections.synchronizedList(new util.ArrayList[Int]()) - implicit val ec: ExecutionContextExecutor = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(apiEnv.threadPoolSize)) - //TODO identify the best practice to implement the future. 
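The new hashKey gives audit_log_bronze a compact, deterministic identity per event: requestParams is first serialized to JSON so a struct with evolving fields still hashes the same way, then xxhash64 is taken over the identifying columns. The standalone Spark sketch below repeats that derivation on a toy frame (column names mirror the bronze table; the data and the literal requestParamsJson string, which stands in for to_json('requestParams), are made up) and shows why duplicate rows collapse once the key exists.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, xxhash64}

val spark = SparkSession.builder.appName("hashKeySketch").master("local[*]").getOrCreate()
import spark.implicits._

val audit = Seq(
  ("ws1", 1690000000000L, "clusters", "create", "req-1", """{"cluster_name":"a"}"""),
  ("ws1", 1690000000000L, "clusters", "create", "req-1", """{"cluster_name":"a"}""") // duplicate event
).toDF("organization_id", "timestamp", "serviceName", "actionName", "requestId", "requestParamsJson")

val withKey = audit.withColumn(
  "hashKey",
  xxhash64(col("organization_id"), col("timestamp"), col("serviceName"),
    col("actionName"), col("requestId"), col("requestParamsJson"))
)
withKey.dropDuplicates("hashKey").show(truncate = false) // duplicates collapse to one row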
- val accumulator = sc.longAccumulator("ClusterEventsAccumulator") - for (i <- clusterIDs.indices) { - val jsonQuery = Map("cluster_id" -> s"""${clusterIDs(i)}""", - "start_time" -> s"""${startTime.asUnixTimeMilli}""", - "end_time" -> s"""${endTime.asUnixTimeMilli}""", - "limit" -> "500" - ) - val future = Future { - val apiObj = ApiCallV2(apiEnv, "clusters/events", jsonQuery, tmpClusterEventsSuccessPath,accumulator).executeMultiThread() - synchronized { - apiObj.forEach( - obj => if (obj.contains("events")) { - apiResponseArray.add(obj) - }else{ - logger.log(Level.INFO,"NO real events found:"+obj) - } - - ) - if (apiResponseArray.size() >= apiEnv.successBatchSize) { - PipelineFunctions.writeMicroBatchToTempLocation(tmpClusterEventsSuccessPath, apiResponseArray.toString) - apiResponseArray = Collections.synchronizedList(new util.ArrayList[String]()) - } - } + implicit val ec: ExecutionContextExecutor = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(config.apiEnv.threadPoolSize)) + val clusterEventsEndpoint = "clusters/events" + + // creating Json input for parallel API calls + val jsonInput = Map( + "start_value" -> "0", + "end_value" -> s"${finalResponseCount}", + "increment_counter" -> "1", + "final_response_count" -> s"${finalResponseCount}", + "cluster_ids" -> s"${clusterIDs.mkString(",")}", + "start_time" -> s"${startTime.asUnixTimeMilli}", + "end_time" -> s"${endTime.asUnixTimeMilli}", + "tmp_success_path" -> tmpClusterEventsSuccessPath, + "tmp_error_path" -> tmpClusterEventsErrorPath + ) - } - future.onComplete { - case Success(_) => - apiResponseCounter.add(1) - - case Failure(e) => - if (e.isInstanceOf[ApiCallFailureV2]) { - synchronized { - apiErrorArray.add(e.getMessage) - if (apiErrorArray.size() >= apiEnv.errorBatchSize) { - PipelineFunctions.writeMicroBatchToTempLocation(tmpClusterEventsErrorPath, apiErrorArray.toString) - apiErrorArray = Collections.synchronizedList(new util.ArrayList[String]()) - } - } - logger.log(Level.ERROR, "Future failure message: " + e.getMessage, e) - } - apiResponseCounter.add(1) - } - } - val timeoutThreshold = apiEnv.apiWaitingTime // 5 minutes - var currentSleepTime = 0 - var accumulatorCountWhileSleeping = accumulator.value - while (apiResponseCounter.size() < finalResponseCount && currentSleepTime < timeoutThreshold) { - //As we are using Futures and running 4 threads in parallel, We are checking if all the treads has completed the execution or not. - // If we have not received the response from all the threads then we are waiting for 5 seconds and again revalidating the count. - if (currentSleepTime > 120000) //printing the waiting message only if the waiting time is more than 2 minutes. - { - println(s"""Waiting for other queued API Calls to complete; cumulative wait time ${currentSleepTime / 1000} seconds; Api response yet to receive ${finalResponseCount - apiResponseCounter.size()}""") - } - Thread.sleep(5000) - currentSleepTime += 5000 - if (accumulatorCountWhileSleeping < accumulator.value) { //new API response received while waiting. - currentSleepTime = 0 //resetting the sleep time. - accumulatorCountWhileSleeping = accumulator.value - } - } - if (apiResponseCounter.size() != finalResponseCount) { // Checking whether all the api responses has been received or not. 
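For clusters/events the start_value/end_value pair indexes into cluster_ids rather than a time range (increment_counter is 1), so each slice resolves to one cluster's payload. The sketch below mirrors the jsonQuery the removed loop used to build; the actual index-to-payload mapping is owned by the endpoint's ApiMeta (getAPIJsonQuery), which is not shown in this hunk, so treat this as illustrative.

// Illustrative only: one payload per cluster id, matching the removed per-index loop.
val clusterIds: Array[String] = Array("0001-aaaa", "0002-bbbb", "0003-cccc")
val startTimeMs = 1690000000000L
val endTimeMs   = 1690086400000L

val perCallQueries: Seq[Map[String, String]] =
  clusterIds.indices.map { i =>
    Map(
      "cluster_id" -> clusterIds(i),
      "start_time" -> startTimeMs.toString,
      "end_time"   -> endTimeMs.toString,
      "limit"      -> "500"
    )
  }
// perCallQueries.size == clusterIds.length == final_response_count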
- logger.log(Level.ERROR, s"""Unable to receive all the clusters/events api responses; Api response received ${apiResponseCounter.size()};Api response not received ${finalResponseCount - apiResponseCounter.size()}""") - throw new Exception(s"""Unable to receive all the clusters/events api responses; Api response received ${apiResponseCounter.size()};Api response not received ${finalResponseCount - apiResponseCounter.size()}""") - } - if (apiResponseArray.size() > 0) { //In case of response array didn't hit the batch-size as a final step we will write it to the persistent storage. - PipelineFunctions.writeMicroBatchToTempLocation(tmpClusterEventsSuccessPath, apiResponseArray.toString) - apiResponseArray = Collections.synchronizedList(new util.ArrayList[String]()) - } - if (apiErrorArray.size() > 0) { //In case of error array didn't hit the batch-size as a final step we will write it to the persistent storage. - PipelineFunctions.writeMicroBatchToTempLocation(tmpClusterEventsErrorPath, apiErrorArray.toString) - apiErrorArray = Collections.synchronizedList(new util.ArrayList[String]()) - } + // calling function to make parallel API calls + val apiCallV2Obj = new ApiCallV2(config.apiEnv) + apiCallV2Obj.makeParallelApiCalls(clusterEventsEndpoint, jsonInput, config) logger.log(Level.INFO, " Cluster event landing completed") } @@ -604,7 +554,7 @@ trait BronzeTransforms extends SparkSessionWrapper { organizationId: String, database: Database, erroredBronzeEventsTarget: PipelineTable, - tempWorkingDir: String + config: Config )(clusterSnapshotDF: DataFrame): DataFrame = { val clusterIDs = getClusterIdsWithNewEvents(filteredAuditLogDF, clusterSnapshotDF) @@ -617,10 +567,11 @@ trait BronzeTransforms extends SparkSessionWrapper { val processingStartTime = System.currentTimeMillis(); logger.log(Level.INFO, "Calling APIv2, Number of cluster id:" + clusterIDs.length + " run id :" + apiEnv.runID) - val tmpClusterEventsSuccessPath = s"$tempWorkingDir/clusterEventsBronze/success" + apiEnv.runID - val tmpClusterEventsErrorPath = s"$tempWorkingDir/clusterEventsBronze/error" + apiEnv.runID + val tmpClusterEventsSuccessPath = s"${config.tempWorkingDir}/clusterEventsBronze/success" + apiEnv.runID + val tmpClusterEventsErrorPath = s"${config.tempWorkingDir}/clusterEventsBronze/error" + apiEnv.runID - landClusterEvents(clusterIDs, startTime, endTime, apiEnv, tmpClusterEventsSuccessPath, tmpClusterEventsErrorPath) + landClusterEvents(clusterIDs, startTime, endTime, apiEnv, tmpClusterEventsSuccessPath, + tmpClusterEventsErrorPath, config) if (Helpers.pathExists(tmpClusterEventsErrorPath)) { persistErrors( spark.read.json(tmpClusterEventsErrorPath) @@ -741,7 +692,7 @@ trait BronzeTransforms extends SparkSessionWrapper { val pathsGlob = validNewFilesWMetaDF .filter(!'failed && 'withinSpecifiedTimeRange) .orderBy('fileSize.desc) - .select('fileName) + .select('filename) .as[String].collect if (pathsGlob.nonEmpty) { // new files less bad files and already-processed files logger.log(Level.INFO, s"VALID NEW EVENT LOGS FOUND: COUNT --> ${pathsGlob.length}") @@ -834,10 +785,20 @@ trait BronzeTransforms extends SparkSessionWrapper { .otherwise(col("Stage Attempt ID")) } else col("Stage Attempt ID") + // raw data contains both "Executor ID" and "executorId" at root for different events + val executorIdOverride: Column = if(baseEventsDF.columns.contains("Executor ID")) { + if (baseEventsDF.columns.contains("executorId")) { // blacklisted executor ids cannot exist if executor ids do not + concat(col("Executor ID"), 'executorId) 
+ } else col("Executor ID") + } else { // handle missing Executor ID field + lit(null).cast("long") + } + val bronzeSparkEventsScrubber = getSparkEventsSchemaScrubber(baseEventsDF) val rawScrubbed = if (baseEventsDF.columns.count(_.toLowerCase().replace(" ", "") == "stageid") > 1) { baseEventsDF + .withColumn("Executor ID", executorIdOverride) .withColumn("progress", progressCol) .withColumn("filename", input_file_name) .withColumn("pathSize", size(split('filename, "/"))) @@ -845,17 +806,18 @@ trait BronzeTransforms extends SparkSessionWrapper { .withColumn("clusterId", split('filename, "/")('pathSize - lit(5))) .withColumn("StageID", stageIDColumnOverride) .withColumn("StageAttemptID", stageAttemptIDColumnOverride) - .drop("pathSize", "Stage ID", "stageId", "Stage Attempt ID", "stageAttemptId") + .drop("pathSize", "executorId", "Stage ID", "stageId", "Stage Attempt ID", "stageAttemptId") .withColumn("filenameGroup", groupFilename('filename)) .scrubSchema(bronzeSparkEventsScrubber) } else { baseEventsDF + .withColumn("Executor ID", executorIdOverride) .withColumn("progress", progressCol) .withColumn("filename", input_file_name) .withColumn("pathSize", size(split('filename, "/"))) .withColumn("SparkContextId", split('filename, "/")('pathSize - lit(2))) .withColumn("clusterId", split('filename, "/")('pathSize - lit(5))) - .drop("pathSize") + .drop("pathSize", "executorId") .withColumn("filenameGroup", groupFilename('filename)) .scrubSchema(bronzeSparkEventsScrubber) } @@ -866,7 +828,6 @@ trait BronzeTransforms extends SparkSessionWrapper { rawScrubbed.withColumn("Properties", SchemaTools.structToMap(rawScrubbed, "Properties")) .withColumn("modifiedConfigs", SchemaTools.structToMap(rawScrubbed, "modifiedConfigs")) .withColumn("extraTags", SchemaTools.structToMap(rawScrubbed, "extraTags")) - .withColumnRenamed("executorId", "blackListedExecutorIds") .join(eventLogsDF, Seq("filename")) .withColumn("organization_id", lit(organizationId)) .withColumn("Properties", expr("map_filter(Properties, (k,v) -> k not in ('sparkexecutorextraClassPath'))")) @@ -897,17 +858,17 @@ trait BronzeTransforms extends SparkSessionWrapper { private[overwatch] def getAllEventLogPrefix(inputDataframe: DataFrame, apiEnv: ApiEnv): DataFrame = { + try{ val mountMap = getMountPointMapping(apiEnv) //Getting the mount info from api and cleaning the data + .withColumn("mount_point", when('mount_point.endsWith("/"), 'mount_point.substr(lit(0), length('mount_point) - 1)).otherwise('mount_point)) .withColumn("source", when('source.endsWith("/"), 'source.substr(lit(0), length('source) - 1)).otherwise('source)) .filter(col("mount_point") =!= "/") - //Cleaning the data for cluster log path val formattedInputDf = inputDataframe.withColumn("cluster_log_conf", when('cluster_log_conf.endsWith("/"), 'cluster_log_conf.substr(lit(0), length('cluster_log_conf) - 1)).otherwise('cluster_log_conf)) .withColumn("cluster_mount_point_temp", regexp_replace('cluster_log_conf, "dbfs:", "")) .withColumn("cluster_mount_point", 'cluster_mount_point_temp) // .withColumn("cluster_mount_point", regexp_replace('cluster_mount_point_temp, "//", "/")) - //Joining the cluster log data with mount point data val joinDF = formattedInputDf .join(mountMap, formattedInputDf.col("cluster_mount_point").startsWith(mountMap.col("mount_point")), "left") //starts with then when @@ -925,11 +886,35 @@ trait BronzeTransforms extends SparkSessionWrapper { val result = pathsDF.select('wildPrefix, 'cluster_id) result + }catch { + case e:Exception=> + 
logger.log(Level.ERROR,"Unable to get all the event log prefix",e) + throw e + } + } private def getMountPointMapping(apiEnv: ApiEnv): DataFrame = { - val endPoint = "dbfs/search-mounts" - ApiCallV2(apiEnv, endPoint).execute().asDF() + try{ + if (apiEnv.mountMappingPath.nonEmpty) { + logger.log(Level.INFO, "Reading cluster logs from " + apiEnv.mountMappingPath) + spark.read.option("header", "true") + .option("ignoreLeadingWhiteSpace", true) + .option("ignoreTrailingWhiteSpace", true) + .csv(apiEnv.mountMappingPath.get) + .withColumnRenamed("mountPoint","mount_point") + .select("mount_point", "source") + } else { + logger.log(Level.INFO,"Calling dbfs/search-mounts for cluster logs") + val endPoint = "dbfs/search-mounts" + ApiCallV2(apiEnv, endPoint).execute().asDF() + } + }catch { + case e:Exception=> + logger.log(Level.ERROR,"ERROR while reading mount point",e) + throw e + } + } @@ -941,7 +926,8 @@ trait BronzeTransforms extends SparkSessionWrapper { clusterSnapshotTable: PipelineTable, sparkLogClusterScaleCoefficient: Double, apiEnv: ApiEnv, - isMultiWorkSpaceDeployment: Boolean + isMultiWorkSpaceDeployment: Boolean, + organisationId: String )(incrementalAuditDF: DataFrame): DataFrame = { logger.log(Level.INFO, "Collecting Event Log Paths Glob. This can take a while depending on the " + @@ -965,7 +951,8 @@ trait BronzeTransforms extends SparkSessionWrapper { val incrementalClusterWLogging = historicalAuditLookupDF .withColumn("global_cluster_id", cluster_idFromAudit) .select('global_cluster_id.alias("cluster_id"), $"requestParams.cluster_log_conf") - .join(incrementalClusterIDs, Seq("cluster_id")) + // Change for #357 + .join(incrementalClusterIDs.hint("SHUFFLE_HASH"), Seq("cluster_id")) .withColumn("cluster_log_conf", coalesce(get_json_object('cluster_log_conf, "$.dbfs"), get_json_object('cluster_log_conf, "$.s3"))) .withColumn("cluster_log_conf", get_json_object('cluster_log_conf, "$.destination")) .filter('cluster_log_conf.isNotNull) @@ -985,7 +972,7 @@ trait BronzeTransforms extends SparkSessionWrapper { // Build root level eventLog path prefix from clusterID and log conf // /some/log/prefix/cluster_id/eventlog val allEventLogPrefixes = - if(isMultiWorkSpaceDeployment) { + if(isMultiWorkSpaceDeployment && organisationId != Initializer.getOrgId) { getAllEventLogPrefix(newLogDirsNotIdentifiedInAudit .unionByName(incrementalClusterWLogging), apiEnv).select('wildPrefix).distinct() } else { @@ -1022,4 +1009,29 @@ trait BronzeTransforms extends SparkSessionWrapper { } + protected def cleanseRawJobRunsSnapDF(keys: Array[String], runId: String)(df: DataFrame): DataFrame = { + val outputDF = df.scrubSchema + val rawDf = outputDF + .withColumn("Overwatch_RunID", lit(runId)) + .modifyStruct(PipelineFunctions.newClusterCleaner(outputDF, "cluster_spec.new_cluster")) + +// val keys = Array("organization_id", "job_id", "run_id", "Overwatch_RunID") + val emptyKeysDF = Seq.empty[(String, Long, Long, String)].toDF("organization_id", "job_id", "run_id", "Overwatch_RunID") + val cleansedTasksDF = workflowsCleanseTasks(rawDf, keys, emptyKeysDF, "tasks") + val cleansedJobClustersDF = workflowsCleanseJobClusters(rawDf, keys, emptyKeysDF, "job_clusters") + + val changeInventory = Map[String, Column]( + "tasks" -> col("cleansedTasks"), + "job_clusters" -> col("cleansedJobsClusters"), + ) + + val cleanDF = rawDf + .join(cleansedTasksDF, keys.toSeq, "left") + .join(cleansedJobClustersDF, keys.toSeq, "left") + .modifyStruct(changeInventory) + .drop("cleansedTasks", "cleansedJobsClusters") + 
.scrubSchema(SchemaScrubber(cullNullTypes = true)) + cleanDF + } + } diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/Gold.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/Gold.scala index 013a42f28..ea348bb8b 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/Gold.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/Gold.scala @@ -83,7 +83,11 @@ class Gold(_workspace: Workspace, _database: Database, _config: Config) append(GoldTargets.clusterTarget) ) + private val clsfSparkOverrides = Map( + "spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes" -> "67108864" // lower to 64MB due to high skew potential + ) lazy private[overwatch] val clusterStateFactModule = Module(3005, "Gold_ClusterStateFact", this, Array(2019, 2014), 3.0) + .withSparkOverrides(clsfSparkOverrides) lazy private val appendClusterStateFactProccess = ETLDefinition( SilverTargets.clusterStateDetailTarget.asIncrementalDF( clusterStateFactModule, @@ -122,7 +126,8 @@ class Gold(_workspace: Workspace, _database: Database, _config: Config) ) val jrcpSparkOverrides = Map( - "spark.sql.autoBroadcastJoinThreshold" -> "-1" + "spark.sql.autoBroadcastJoinThreshold" -> "-1", + "spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes" -> "67108864" // lower to 64MB due to high skew potential ) lazy private[overwatch] val jobRunCostPotentialFactModule = Module(3015, "Gold_jobRunCostPotentialFact", this, Array(3001, 3003, 3005), 3.0) .withSparkOverrides(jrcpSparkOverrides) diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/GoldTransforms.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/GoldTransforms.scala index a52a9c9ef..5cbeece46 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/GoldTransforms.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/GoldTransforms.scala @@ -292,7 +292,7 @@ trait GoldTransforms extends SparkSessionWrapper { .lookupWhen(dbuCostDetailsTSDF) .df val clusterPotMetaToFill = Array( - "cluster_name", "custom_tags", "driver_node_type_id", + "cluster_name", "custom_tags", "driver_node_type_id", "runtime_engine", "node_type_id", "spark_version", "sku", "dbu_rate", "driverSpecs", "workerSpecs" ) val clusterPotKeys = Seq("organization_id", "cluster_id") @@ -335,12 +335,15 @@ trait GoldTransforms extends SparkSessionWrapper { val workerPotentialCoreS = when('databricks_billable, $"workerSpecs.vCPUs" * 'current_num_workers * 'uptime_in_state_S).otherwise(lit(0)) - val driverDBUs = when('databricks_billable, $"driverSpecs.Hourly_DBUs" * 'uptime_in_state_H).otherwise(lit(0)).alias("driver_dbus") - val workerDBUs = when('databricks_billable, $"workerSpecs.Hourly_DBUs" * 'current_num_workers * 'uptime_in_state_H).otherwise(lit(0)).alias("worker_dbus") + val isPhotonEnabled = upper('runtime_engine).equalTo("PHOTON") + val isNotAnSQlWarehouse = !upper('sku).equalTo("SQLCOMPUTE") + val photonDBUMultiplier = when(isPhotonEnabled && isNotAnSQlWarehouse, lit(2)).otherwise(lit(1)) + val driverDBUs = when('databricks_billable, $"driverSpecs.Hourly_DBUs" * 'uptime_in_state_H * photonDBUMultiplier).otherwise(lit(0)).alias("driver_dbus") + val workerDBUs = when('databricks_billable, $"workerSpecs.Hourly_DBUs" * 'current_num_workers * 'uptime_in_state_H * photonDBUMultiplier).otherwise(lit(0)).alias("worker_dbus") val driverComputeCost = Costs.compute('cloud_billable, $"driverSpecs.Compute_Contract_Price", lit(1), 'uptime_in_state_H).alias("driver_compute_cost") val workerComputeCost = 
Costs.compute('cloud_billable, $"workerSpecs.Compute_Contract_Price", 'target_num_workers, 'uptime_in_state_H).alias("worker_compute_cost") - val driverDBUCost = Costs.dbu('databricks_billable, $"driverSpecs.Hourly_DBUs", 'dbu_rate, lit(1), 'uptime_in_state_H, runtimeEngine = 'runtime_engine, sku = 'sku).alias("driver_dbu_cost") - val workerDBUCost = Costs.dbu('databricks_billable, $"workerSpecs.Hourly_DBUs", 'dbu_rate, 'current_num_workers, 'uptime_in_state_H, runtimeEngine = 'runtime_engine, sku = 'sku).alias("worker_dbu_cost") + val driverDBUCost = Costs.dbu(driverDBUs, 'dbu_rate).alias("driver_dbu_cost") + val workerDBUCost = Costs.dbu(workerDBUs, 'dbu_rate).alias("worker_dbu_cost") val clusterStateFactCols: Array[Column] = Array( 'organization_id, @@ -790,7 +793,7 @@ trait GoldTransforms extends SparkSessionWrapper { protected val jobRunCostPotentialFactViewColumnMapping: String = """ |organization_id, workspace_name, job_id, job_name, run_id, job_run_id, task_run_id, task_key, repair_id, run_name, - |startEpochMS, cluster_id, cluster_name, cluster_tags, driver_node_type_id, node_type_id, dbu_rate, + |startEpochMS, cluster_id, cluster_name, cluster_type, cluster_tags, driver_node_type_id, node_type_id, dbu_rate, |multitask_parent_run_id, parent_run_id, task_runtime, task_execution_runtime, terminal_state, |job_trigger_type, task_type, created_by, last_edited_by, running_days, avg_cluster_share, |avg_overlapping_runs, max_overlapping_runs, run_cluster_states, worker_potential_core_H, driver_compute_cost, diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/Initializer.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/Initializer.scala index 8be96d630..84a09c0f2 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/Initializer.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/Initializer.scala @@ -388,6 +388,10 @@ class Initializer(config: Config) extends SparkSessionWrapper { if (!disableValidations) validateIntelligentScaling(rawParams.intelligentScaling) config.setIntelligentScaling(rawParams.intelligentScaling) + // as of 0711 + val disabledModulesString = spark(globalSession = true).conf.getOption("overwatch.modules.disabled").getOrElse("0") + config.registerDisabledModules(disabledModulesString) + this } @@ -558,9 +562,10 @@ object Initializer extends SparkSessionWrapper { envInit() def getOrgId: String = { - if (dbutils.notebook.getContext.tags("orgId") == "0") { + val clusterOwnerOrgID = spark.conf.get("spark.databricks.clusterUsageTags.clusterOwnerOrgId") + if (clusterOwnerOrgID == " " || clusterOwnerOrgID == "0") { dbutils.notebook.getContext.apiUrl.get.split("\\.")(0).split("/").last - } else dbutils.notebook.getContext.tags("orgId") + } else clusterOwnerOrgID } private def initConfigState(debugFlag: Boolean,organizationID: Option[String],apiUrl: Option[String]): Config = { @@ -568,7 +573,7 @@ object Initializer extends SparkSessionWrapper { val config = new Config() if(organizationID.isEmpty) { config.setOrganizationId(getOrgId) - }else{ + }else{ // is multiWorkspace deployment since orgID is passed logger.log(Level.INFO, "Setting multiworkspace deployment") config.setOrganizationId(organizationID.get) if (apiUrl.nonEmpty) { @@ -576,9 +581,9 @@ object Initializer extends SparkSessionWrapper { } config.setIsMultiworkspaceDeployment(true) } - config.registerInitialSparkConf(spark.conf.getAll) + // set spark overrides in scoped spark session and override the necessary values for Pipeline Run + 
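registerDisabledModules is not part of this diff, so the sketch below is only one plausible reading of the new overwatch.modules.disabled conf: a comma-separated list of module ids, with the default "0" meaning nothing is disabled.

// Sketch (not the actual registerDisabledModules implementation, which is not shown here):
// turning "1007,1012" into a set of module ids, treating "0" as the "none disabled" sentinel.
val disabledModulesString = "1007,1012" // would normally come from spark.conf
val disabledModules: Set[Int] = disabledModulesString
  .split(",")
  .map(_.trim)
  .filter(s => s.nonEmpty && s.forall(_.isDigit))
  .map(_.toInt)
  .filterNot(_ == 0) // "0" is the default when the conf is unset
  .toSet
// Module.execute then throws ModuleDisabled when disabledModules.contains(moduleId)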
config.registerInitialSparkConf(spark(globalSession = true).conf.getAll) config.setInitialWorkerCount(getNumberOfWorkerNodes) - config.setInitialShuffleParts(spark.conf.get("spark.sql.shuffle.partitions").toInt) if (debugFlag) { envInit("DEBUG") config.setDebugFlag(debugFlag) diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/Module.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/Module.scala index 7d67b7a42..dfda985c1 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/Module.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/Module.scala @@ -6,6 +6,7 @@ import org.apache.log4j.{Level, Logger} import org.apache.spark.sql.DataFrame import java.time.Duration +import scala.util.parsing.json.JSON.number class Module( val moduleId: Int, @@ -25,6 +26,8 @@ class Module( private var _isFirstRun: Boolean = false + private var _moduleStartMessage: String = "" + private[overwatch] val moduleState: SimplifiedModuleStatusReport = { if (pipeline.getModuleState(moduleId).isEmpty) { initModuleState @@ -36,6 +39,13 @@ class Module( def isFirstRun: Boolean = _isFirstRun + private[overwatch] def moduleStartMessage: String = _moduleStartMessage + + private def setModuleStartMessage(value: String): this.type = { + _moduleStartMessage = value + this + } + def daysToProcess: Int = { Duration.between( fromTime.asLocalDateTime.toLocalDate.atStartOfDay(), @@ -331,15 +341,23 @@ class Module( @throws(classOf[IllegalArgumentException]) def execute(_etlDefinition: ETLDefinition): ModuleStatusReport = { - optimizeShufflePartitions() + + if (config.disabledModules.contains(moduleId)) throw new ModuleDisabled(moduleId, s"MODULE DISABLED: $moduleId-$moduleName") + val shufflePartitions = spark.conf.get("spark.sql.shuffle.partitions") + val notAQEAutoOptimizeShuffle = spark.conf.getOption("spark.databricks.adaptive.autoOptimizeShuffle.enabled").getOrElse("false").toBoolean + if (Helpers.isNumeric(shufflePartitions) && notAQEAutoOptimizeShuffle){ + optimizeShufflePartitions() + } + logger.log(Level.INFO, s"Spark Overrides Initialized for target: $moduleName to\n${sparkOverrides.mkString(", ")}") PipelineFunctions.setSparkOverrides(spark, sparkOverrides, config.debugFlag) - val startMsg = s"\nBeginning: $moduleId-$moduleName\nTIME RANGE: ${fromTime.asTSString} -> ${untilTime.asTSString} --> Workspace ID: ${config.organizationId}" - println(startMsg) - - if (config.debugFlag) println(startMsg) + val startMsg = s"$moduleId-$moduleName --> Workspace ID: ${config.organizationId}\nTIME RANGE: " + + s"${fromTime.asTSString} -> ${untilTime.asTSString}\n" + setModuleStartMessage(startMsg) + println(s"\nBeginning: $startMsg") logger.log(Level.INFO, startMsg) + try { if (fromTime.asUnixTimeMilli == untilTime.asUnixTimeMilli) throw new NoNewDataException("FROM and UNTIL times are identical. 
Likely due to upstream dependencies " + @@ -357,6 +375,8 @@ class Module( noNewDataHandler(PipelineFunctions.appendStackStrace(e, e.apiCallDetail), Level.ERROR, allowModuleProgression = e.allowModuleProgression) case e: ApiCallFailure if e.failPipeline => fail(PipelineFunctions.appendStackStrace(e, e.msg)) + case e: ModuleDisabled => + fail(e.getMessage) case e: FailedModuleException => val errMessage = s"FAILED: $moduleId-$moduleName Module" logger.log(Level.ERROR, errMessage, e) diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/Pipeline.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/Pipeline.scala index f0452f35b..3e1706037 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/Pipeline.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/Pipeline.scala @@ -384,14 +384,22 @@ class Pipeline( dbutils.fs.rm(config.tempWorkingDir) postProcessor.refreshPipReportView(pipelineStateViewTarget) - - spark.catalog.clearCache() + //TODO clearcache will clear global cache multithread performance issue + // spark.catalog.clearCache() + clearThreadFromSessionsMap() } + /** + * restore the spark config to the way it was when config / workspace were first instantiated + */ private[overwatch] def restoreSparkConf(): Unit = { restoreSparkConf(config.initialSparkConf) } + /** + * restore spark configs that are passed into this function + * @param value map of configs to be overridden + */ protected def restoreSparkConf(value: Map[String, String]): Unit = { PipelineFunctions.setSparkOverrides(spark, value, config.debugFlag) } @@ -448,9 +456,10 @@ class Pipeline( val rowsWritten = writeOpsMetrics.getOrElse("numOutputRows", "0") val execMins: Double = (endTime - startTime) / 1000.0 / 60.0 val simplifiedExecMins: Double = execMins - (execMins % 0.01) - val msg = s"SUCCESS! ${module.moduleName}\nOUTPUT ROWS: $rowsWritten\nRUNTIME MINS: $simplifiedExecMins --> Workspace ID: ${config.organizationId}" - println(msg) - logger.log(Level.INFO, msg) + val successMessage = s"SUCCESS! 
${module.moduleName}\nOUTPUT ROWS: $rowsWritten\nRUNTIME MINS: " + + s"$simplifiedExecMins --> Workspace ID: ${config.organizationId}" + println(s"COMPLETED: ${module.moduleStartMessage} $successMessage") + logger.log(Level.INFO, module.moduleStartMessage ++ successMessage) // Generate Success Report ModuleStatusReport( diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/PipelineFunctions.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/PipelineFunctions.scala index 1bda55e21..5260578e8 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/PipelineFunctions.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/PipelineFunctions.scala @@ -378,7 +378,7 @@ object PipelineFunctions extends SparkSessionWrapper { debugFlag: Boolean = false): Unit = { logger.info( s""" - |SPARK OVERRIDES BEING SET: + |SPARK OVERRIDES BEING SET FOR THREAD ${Thread.currentThread().getId}: |${sparkOverrides.mkString("\n")} |""".stripMargin) sparkOverrides foreach { case (k, v) => @@ -515,7 +515,7 @@ object PipelineFunctions extends SparkSessionWrapper { when(isAutomated && isJobsLight, "jobsLight") .when(isAutomated && !isJobsLight, "automated") .when(clusterType === "SQL Analytics", lit("sqlCompute")) - .when(clusterType === "Serverless", lit("serverless")) + .when(clusterType === "High-Concurrency", lit("interactive")) .when(!isAutomated, "interactive") .otherwise("unknown") } @@ -619,6 +619,12 @@ object PipelineFunctions extends SparkSessionWrapper { case 1004 => "audit_log_bronze" case 1005 => "cluster_events_bronze" case 1006 => "spark_events_bronze" + case 1007 => "libs_snapshot_bronze" + case 1008 => "policies_snapshot_bronze" + case 1009 => "instance_profiles_snapshot_bronze" + case 1010 => "tokens_snapshot_bronze" + case 1011 => "global_inits_snapshot_bronze" + case 1012 => "job_runs_snapshot_bronze" case 2003 => "spark_executors_silver" case 2005 => "spark_Executions_silver" case 2006 => "spark_jobs_silver" diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/PipelineTable.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/PipelineTable.scala index 05af1ce21..b76de5708 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/PipelineTable.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/PipelineTable.scala @@ -1,6 +1,7 @@ package com.databricks.labs.overwatch.pipeline import com.databricks.labs.overwatch.pipeline.TransformFunctions._ +import com.databricks.labs.overwatch.utils.MergeScope.MergeScope import com.databricks.labs.overwatch.utils.WriteMode.WriteMode import com.databricks.labs.overwatch.utils._ import io.delta.tables.DeltaTable @@ -24,6 +25,7 @@ case class PipelineTable( format: String = "delta", // TODO -- Convert to Enum persistBeforeWrite: Boolean = false, _mode: WriteMode = WriteMode.append, + mergeScope: MergeScope = MergeScope.full, maxMergeScanDates: Int = 33, // used to create explicit date merge condition -- should be removed after merge dynamic partition pruning is enabled DBR 11.x LTS private val _permitDuplicateKeys: Boolean = true, private val _databaseName: String = "default", @@ -53,6 +55,12 @@ case class PipelineTable( // Minimum Schema Enforcement Management private var withMasterMinimumSchema: Boolean = if (masterSchema.nonEmpty) true else false private var enforceNonNullable: Boolean = if (masterSchema.nonEmpty) true else false + private var _existsConfirmed: Boolean = false + + private def setExists(value: Boolean): Unit = _existsConfirmed = value + private def 
existsConfirmed: Boolean = _existsConfirmed + + def isStreaming: Boolean = if(checkpointPath.nonEmpty) true else false private def emitMissingMasterSchemaMessage(): Unit = { val msg = s"No Master Schema defined for Table $tableFullName" @@ -127,6 +135,32 @@ case class PipelineTable( // DeltaTable.forName(tableFullName). val tableLocation: String = s"${config.etlDataPathPrefix}/$name".toLowerCase + /** + * is the schema evolving + * @return + */ + def isEvolvingSchema: Boolean = { + name match { // when exists -- locking not necessary if schema does not evolve + case "pipeline_report" => false + case "instanceDetails" => false + case "dbuCostDetails" => false + case "spark_events_processedFiles" => false + case _ => true + } + } + + private[overwatch] def requiresLocking: Boolean = { + // If target's schema is evolving or does not exist true + val enableLocking = if (!exists || isEvolvingSchema) true else false + val logMsg = if(enableLocking) { + s"LOCKING ENABLED for table $name" + } else { + s"LOCKING DISABLED for table $name" + } + logger.log(Level.INFO, logMsg) + enableLocking + } + /** * default catalog only validation * @@ -146,16 +180,33 @@ case class PipelineTable( * @return */ def exists(pathValidation: Boolean = true, dataValidation: Boolean = false, catalogValidation: Boolean = false): Boolean = { - var entityExists = true - if (pathValidation || dataValidation) entityExists = Helpers.pathExists(tableLocation) - if (catalogValidation) entityExists = spark.catalog.tableExists(tableFullName) - if (dataValidation) { // if other validation is enabled it must first pass those for this test to be attempted - // opposite -- when result is empty source data does not exist - entityExists = entityExists && !spark.read.format("delta").load(tableLocation) - .filter(col("organization_id") === config.organizationId) - .isEmpty + // If target already confirmed to exist just return that it exists + // only determine once per state + if (!existsConfirmed) { // if not already confirmed to exist -- check existence + if (pathValidation || dataValidation) { // when path or data validation is enabled + if (format == "delta") { // if delta verify the _delta_log is present not just the path + setExists(Helpers.pathExists(s"$tableLocation/_delta_log")) + } else { // not delta verify the parent dir exists + setExists(Helpers.pathExists(tableLocation)) + } + } + if (catalogValidation) setExists(spark.catalog.tableExists(tableFullName)) + + // if other validation is enabled it must first pass those for this test to be attempted + if (dataValidation && existsConfirmed) { // ++ entity exists to ensure path validation complete + // opposite -- when result is empty source data does not exist + try { + val workspaceDataPresent = !spark.read.format("delta") + .load(tableLocation) + .filter(col("organization_id") === config.organizationId) + .isEmpty + setExists(workspaceDataPresent) + } catch { + case _: Throwable => setExists(false) + } + } } - entityExists + existsConfirmed } /** diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/PipelineTargets.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/PipelineTargets.scala index 679f196d5..44635c8b0 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/PipelineTargets.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/PipelineTargets.scala @@ -1,6 +1,6 @@ package com.databricks.labs.overwatch.pipeline -import com.databricks.labs.overwatch.utils.{Config, WriteMode} +import 
com.databricks.labs.overwatch.utils.{Config, MergeScope, WriteMode} abstract class PipelineTargets(config: Config) { // TODO -- Refactor -- this class should extend workspace so these are "WorkspaceTargets" @@ -65,12 +65,15 @@ abstract class PipelineTargets(config: Config) { lazy private[overwatch] val auditLogsTarget: PipelineTable = PipelineTable( name = "audit_log_bronze", - _keys = Array("serviceName", "actionName", "requestId"), + _keys = Array("timestamp", "serviceName", "actionName", "requestId", "hashKey"), config, incrementalColumns = Array("date", "timestamp"), partitionBy = Seq("organization_id", "date"), statsColumns = ("actionName, requestId, serviceName, sessionId, " + "timestamp, date, Pipeline_SnapTS, Overwatch_RunID").split(", "), + _permitDuplicateKeys = false, + _mode = WriteMode.merge, + mergeScope = MergeScope.insertOnly, masterSchema = Some(Schema.auditMasterSchema) ) @@ -92,6 +95,8 @@ abstract class PipelineTargets(config: Config) { name = "cluster_events_bronze", _keys = Array("cluster_id", "type", "timestamp"), config, + _mode = WriteMode.merge, + mergeScope = MergeScope.insertOnly, partitionBy = Seq("organization_id", "__overwatch_ctrl_noise"), incrementalColumns = Array("timestamp"), statsColumns = "cluster_id, timestamp, type, Pipeline_SnapTS, Overwatch_RunID".split(", "), @@ -154,6 +159,53 @@ abstract class PipelineTargets(config: Config) { config ) + lazy private[overwatch] val libsSnapshotTarget: PipelineTable = PipelineTable( + name = "libs_snapshot_bronze", + _keys = Array("cluster_id", "Overwatch_RunID"), + config, + incrementalColumns = Array("Pipeline_SnapTS"), + partitionBy = Seq("organization_id") + ) + + lazy private[overwatch] val policiesSnapshotTarget: PipelineTable = PipelineTable( + name = "policies_snapshot_bronze", + _keys = Array("policy_id", "Overwatch_RunID"), + config, + incrementalColumns = Array("Pipeline_SnapTS"), + partitionBy = Seq("organization_id") + ) + + lazy private[overwatch] val instanceProfilesSnapshotTarget: PipelineTable = PipelineTable( + name = "instance_profiles_snapshot_bronze", + _keys = Array("cluster_id", "Overwatch_RunID"), + config, + incrementalColumns = Array("Pipeline_SnapTS"), + partitionBy = Seq("organization_id") + ) + + lazy private[overwatch] val tokensSnapshotTarget: PipelineTable = PipelineTable( + name = "tokens_snapshot_bronze", + _keys = Array("token_id", "Overwatch_RunID"), + config, + incrementalColumns = Array("Pipeline_SnapTS"), + partitionBy = Seq("organization_id") + ) + + lazy private[overwatch] val globalInitScSnapshotTarget: PipelineTable = PipelineTable( + name = "global_inits_snapshot_bronze", + _keys = Array("script_id", "Overwatch_RunID"), + config, + incrementalColumns = Array("Pipeline_SnapTS"), + partitionBy = Seq("organization_id") + ) + + lazy private[overwatch] val jobRunsSnapshotTarget: PipelineTable = PipelineTable( + name = "job_runs_snapshot_bronze", + _keys = Array("job_id", "run_id", "Overwatch_RunID"), + config, + incrementalColumns = Array("Pipeline_SnapTS"), + partitionBy = Seq("organization_id") + ) } /** diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/PipelineView.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/PipelineView.scala index cbdc13887..70084bb11 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/PipelineView.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/PipelineView.scala @@ -24,7 +24,6 @@ case class PipelineView(name: String, } else { partitionMapOverrides } - print(s"partmap is ${partMap}") 
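audit_log_bronze and cluster_events_bronze now write with WriteMode.merge and MergeScope.insertOnly over widened keys. Assuming that scope resolves to a whenNotMatched-only Delta merge (the Database/PipelineTable wiring that actually builds the merge is not part of this hunk), the effect is sketched below; the table name and the inclusion of organization_id in the condition are illustrative.

import io.delta.tables.DeltaTable
import org.apache.spark.sql.{DataFrame, SparkSession}

// Hedged sketch of an insert-only merge over the widened audit key: rows whose key
// already exists are left untouched, new rows are appended, so replays cannot duplicate.
def insertOnlyMergeSketch(spark: SparkSession, newAudit: DataFrame): Unit = {
  val keyCols = Seq("organization_id", "timestamp", "serviceName", "actionName", "requestId", "hashKey")
  val cond = keyCols.map(k => s"t.$k <=> s.$k").mkString(" AND ") // null-safe match on every key
  DeltaTable.forName(spark, "overwatch_etl.audit_log_bronze") // illustrative table name
    .as("t")
    .merge(newAudit.as("s"), cond)
    .whenNotMatched() // insert-only: existing rows are never updated
    .insertAll()
    .execute()
}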
pubStatementSB.append(s" where ${partMap.head._1} = ${s"delta.`${config.etlDataPathPrefix}/${dataSourceName}`"}.${partMap.head._2} ") if (partMap.keys.toArray.length > 1) { partMap.tail.foreach(pCol => { diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/Schema.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/Schema.scala index ce4cc5574..6541ba096 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/Schema.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/Schema.scala @@ -88,6 +88,7 @@ object Schema extends SparkSessionWrapper { StructField("requestId", StringType, nullable = true), StructField("sessionId", StringType, nullable = true), StructField("version", StringType, nullable = true), + StructField("hashKey", LongType, nullable = true), StructField("requestParams", StructType(Seq( StructField("clusterId", StringType, nullable = true), @@ -1031,26 +1032,26 @@ object Schema extends SparkSessionWrapper { val deployementMinimumSchema:StructType = StructType(Seq( StructField("workspace_name", StringType, nullable = false), - StructField("workspace_id", StringType, nullable = true), - StructField("workspace_url", StringType, nullable = true), - StructField("api_url", StringType, nullable = true), - StructField("cloud", StringType, nullable = true), - StructField("primordial_date", DateType, nullable = true), - StructField("etl_storage_prefix", StringType, nullable = true), - StructField("etl_database_name", StringType, nullable = true), - StructField("consumer_database_name", StringType, nullable = true), - StructField("secret_scope", StringType, nullable = true), - StructField("secret_key_dbpat", StringType, nullable = true), + StructField("workspace_id", StringType, nullable = false), + StructField("workspace_url", StringType, nullable = false), + StructField("api_url", StringType, nullable = false), + StructField("cloud", StringType, nullable = false), + StructField("primordial_date", DateType, nullable = false), + StructField("etl_storage_prefix", StringType, nullable = false), + StructField("etl_database_name", StringType, nullable = false), + StructField("consumer_database_name", StringType, nullable = false), + StructField("secret_scope", StringType, nullable = false), + StructField("secret_key_dbpat", StringType, nullable = false), StructField("auditlogprefix_source_aws", StringType, nullable = true), StructField("eh_name", StringType, nullable = true), StructField("eh_scope_key", StringType, nullable = true), - StructField("interactive_dbu_price", DoubleType, nullable = true), - StructField("automated_dbu_price", DoubleType, nullable = true), - StructField("sql_compute_dbu_price", DoubleType, nullable = true), - StructField("jobs_light_dbu_price", DoubleType, nullable = true), - StructField("max_days", IntegerType, nullable = true), + StructField("interactive_dbu_price", DoubleType, nullable = false), + StructField("automated_dbu_price", DoubleType, nullable = false), + StructField("sql_compute_dbu_price", DoubleType, nullable = false), + StructField("jobs_light_dbu_price", DoubleType, nullable = false), + StructField("max_days", IntegerType, nullable = false), StructField("excluded_scopes", StringType, nullable = true), - StructField("active", BooleanType, nullable = true), + StructField("active", BooleanType, nullable = false), StructField("proxy_host", StringType, nullable = true), StructField("proxy_port", IntegerType, nullable = true), StructField("proxy_user_name", StringType, nullable = true), @@ -1060,6 +1061,13 @@ 
object Schema extends SparkSessionWrapper { StructField("error_batch_size", IntegerType, nullable = true), StructField("enable_unsafe_SSL", BooleanType, nullable = true), StructField("thread_pool_size", IntegerType, nullable = true), - StructField("api_waiting_time", LongType, nullable = true) + StructField("api_waiting_time", LongType, nullable = true), + StructField("mount_mapping_path", StringType, nullable = true) + )) + + val mountMinimumSchema: StructType = StructType(Seq( + StructField("mountPoint", StringType, nullable = false), + StructField("source", StringType, nullable = false), + StructField("workspace_id", StringType, nullable = false) )) } diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/Silver.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/Silver.scala index 47e662db9..1a8d3f03a 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/Silver.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/Silver.scala @@ -178,6 +178,7 @@ class Silver(_workspace: Workspace, _database: Database, _config: Config) "spark.databricks.delta.optimizeWrite.numShuffleBlocks" -> "500000", "spark.databricks.delta.optimizeWrite.binSize" -> "2048", "spark.sql.files.maxPartitionBytes" -> (1024 * 1024 * 64).toString, + "spark.sql.adaptive.advisoryPartitionSizeInBytes" -> (1024 * 1024 * 32).toString, "spark.sql.autoBroadcastJoinThreshold" -> ((1024 * 1024 * 2).toString), "spark.sql.adaptive.autoBroadcastJoinThreshold" -> ((1024 * 1024 * 2).toString) ) @@ -213,7 +214,7 @@ class Silver(_workspace: Workspace, _database: Database, _config: Config) "spark.databricks.delta.optimizeWrite.numShuffleBlocks" -> "500000", "spark.databricks.delta.optimizeWrite.binSize" -> "2048", // output is very dense, shrink output file size "spark.sql.files.maxPartitionBytes" -> (1024 * 1024 * 64).toString, - "spark.sql.adaptive.advisoryPartitionSizeInBytes" -> (1024 * 1024 * 4).toString, + "spark.sql.adaptive.advisoryPartitionSizeInBytes" -> (1024 * 1024 * 8).toString, "spark.sql.autoBroadcastJoinThreshold" -> ((1024 * 1024 * 2).toString), "spark.sql.adaptive.autoBroadcastJoinThreshold" -> ((1024 * 1024 * 2).toString) ) diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/SilverTransforms.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/SilverTransforms.scala index 1b9407b75..194b76797 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/SilverTransforms.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/SilverTransforms.scala @@ -688,7 +688,7 @@ trait SilverTransforms extends SparkSessionWrapper { .orderBy('timestamp).rowsBetween(-1000, Window.currentRow) val isSingleNode = get_json_object(regexp_replace('spark_conf, "\\.", "_"), "$.spark_databricks_cluster_profile") === lit("singleNode") - val isServerless = get_json_object(regexp_replace('spark_conf, "\\.", "_"), + val isHC = get_json_object(regexp_replace('spark_conf, "\\.", "_"), "$.spark_databricks_cluster_profile") === lit("serverless") val isSQLAnalytics = get_json_object('custom_tags, "$.SqlEndpointId").isNotNull val tableAcls = coalesce(get_json_object(regexp_replace('spark_conf, "\\.", "_"), "$.spark_databricks_acl_dfAclsEnabled").cast("boolean"), lit(false)).alias("table_acls_enabled") @@ -702,7 +702,7 @@ trait SilverTransforms extends SparkSessionWrapper { val enableElasticDisk = when('enable_elastic_disk === "false", lit(false)) .otherwise(lit(true)) val deriveClusterType = when(isSingleNode, lit("Single Node")) - .when(isServerless, lit("Serverless")) 
+ .when(isHC, lit("High-Concurrency")) .when(isSQLAnalytics, lit("SQL Analytics")) .otherwise("Standard").alias("cluster_type") @@ -710,8 +710,13 @@ trait SilverTransforms extends SparkSessionWrapper { val clusterBaseWMetaDF = clusterBaseDF // remove start, startResults, and permanentDelete as they do not contain sufficient metadata .filter('actionName.isin("create", "edit")) - val bronzeClusterSnapUntilCurrent = bronze_cluster_snap.asDF - .filter('Pipeline_SnapTS <= untilTime.asColumnTS) + + val lastClusterSnapW = Window.partitionBy('organization_id, 'cluster_id) + .orderBy('Pipeline_SnapTS.desc) + val bronzeClusterSnapLatest = bronze_cluster_snap.asDF + .withColumn("rnk", rank().over(lastClusterSnapW)) + .withColumn("rn", row_number().over(lastClusterSnapW)) + .filter('rnk === 1 && 'rn === 1).drop("rnk", "rn") /** * clusterBaseFilled - if first run, baseline cluster spec for existing clusters that haven't been edited since @@ -724,14 +729,14 @@ trait SilverTransforms extends SparkSessionWrapper { "current initial state for all existing clusters." logger.log(Level.INFO, firstRunMsg) println(firstRunMsg) - val missingClusterIds = bronzeClusterSnapUntilCurrent.select('organization_id, 'cluster_id).distinct + val missingClusterIds = bronzeClusterSnapLatest.select('organization_id, 'cluster_id).distinct .join( clusterBaseWMetaDF .select('organization_id, 'cluster_id).distinct, Seq("organization_id", "cluster_id"), "anti" ) val latestClusterSnapW = Window.partitionBy('organization_id, 'cluster_id).orderBy('Pipeline_SnapTS.desc) - val missingClusterBaseFromSnap = bronzeClusterSnapUntilCurrent + val missingClusterBaseFromSnap = bronzeClusterSnapLatest .join(missingClusterIds, Seq("organization_id", "cluster_id")) .withColumn("rnk", rank().over(latestClusterSnapW)) .filter('rnk === 1).drop("rnk") @@ -941,7 +946,7 @@ trait SilverTransforms extends SparkSessionWrapper { .withColumn("pool_snap_node_type", lit(null).cast("string")) } - val onlyOnceSemanticsW = Window.partitionBy('organization_id, 'cluster_id, 'actionName).orderBy('timestamp) + val onlyOnceSemanticsW = Window.partitionBy('organization_id, 'cluster_id, 'actionName,'timestamp).orderBy('timestamp) clusterBaseWithPoolsAndSnapPools .select(clusterSpecBaseCols: _*) .join(creatorLookup, Seq("organization_id", "cluster_id"), "left") diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/TransformFunctions.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/TransformFunctions.scala index bd6c7a512..12f096768 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/TransformFunctions.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/TransformFunctions.scala @@ -25,6 +25,7 @@ object TransformFunctions { /** * drops columns that contain only nulls + * * @param df dataframe to more data * @return * @@ -79,12 +80,13 @@ object TransformFunctions { * EX: startEvent.joinWithLag(endEvent, keyColumn[s], "laggingDateCol", lagDays = 30, joinType = "left") * The above example with match the join condition on all keyColumn[s] AND where left.laggingDateCol >= * date_sub(right.laggingDateCol, 30) - * @param df2 df to be joined - * @param usingColumns key columns ** less the lagDateColumn + * + * @param df2 df to be joined + * @param usingColumns key columns ** less the lagDateColumn * @param lagDateColumnName name of the lag control column - * @param laggingSide which side of the join is the lagging side - * @param lagDays how many days to allow for lag - * @param joinType join type, one of left | inner | right 
+ * @param laggingSide which side of the join is the lagging side + * @param lagDays how many days to allow for lag + * @param joinType join type, one of left | inner | right * @return */ def joinWithLag( @@ -100,7 +102,7 @@ object TransformFunctions { s"supported, you selected $joinType, switch to supported join type") require( // both sides contain the lagDateColumnName df.schema.fields.exists(_.name == lagDateColumnName) && - df2.schema.fields.exists(_.name == lagDateColumnName), + df2.schema.fields.exists(_.name == lagDateColumnName), s"$lagDateColumnName must exist on both sides of the join" ) @@ -108,9 +110,9 @@ object TransformFunctions { df.schema.fields .filter(f => f.name == lagDateColumnName) .exists(f => f.dataType == TimestampType || f.dataType == DateType) && - df2.schema.fields - .filter(f => f.name == lagDateColumnName) - .exists(f => f.dataType == TimestampType || f.dataType == DateType), + df2.schema.fields + .filter(f => f.name == lagDateColumnName) + .exists(f => f.dataType == TimestampType || f.dataType == DateType), s"$lagDateColumnName must be either a Date or Timestamp type on both sides of the join" ) @@ -121,16 +123,17 @@ object TransformFunctions { (df, df2.suffixDFCols(rightSuffix, allJoinCols, caseSensitive = true)) } else (df.suffixDFCols(leftSuffix, allJoinCols, caseSensitive = true), df2) - val baseJoinCondition = if (joinType == "left" || joinType == "inner"){ + val baseJoinCondition = if (joinType == "left" || joinType == "inner") { usingColumns.map(k => s"$k = ${k}${rightSuffix}").mkString(" AND ") } else usingColumns.map(k => s"$k = ${k}${leftSuffix}").mkString(" AND ") val joinConditionWLag = if (joinType == "left" || joinType == "inner") { if (laggingSide == "left") { - expr(s"$baseJoinCondition AND ${lagDateColumnName} >= date_sub(${lagDateColumnName}${rightSuffix}, $lagDays)") + expr(s"$baseJoinCondition AND ${lagDateColumnName} >= date_sub(${lagDateColumnName}${rightSuffix}, $lagDays)") + } else { + expr(s"$baseJoinCondition AND ${lagDateColumnName}${rightSuffix} >= date_sub(${lagDateColumnName}, $lagDays)") + } } else { - expr(s"$baseJoinCondition AND ${lagDateColumnName}${rightSuffix} >= date_sub(${lagDateColumnName}, $lagDays)") - }} else { if (laggingSide == "left") { expr(s"$baseJoinCondition AND ${lagDateColumnName}${leftSuffix} >= date_sub(${lagDateColumnName}, $lagDays)") } else { @@ -149,6 +152,7 @@ object TransformFunctions { } def requireFields(fieldName: Seq[String]): DataFrame = requireFields(false, fieldName: _*) + def requireFields(caseSensitive: Boolean, fieldName: String*): DataFrame = { fieldName.map(f => { val fWithCase = if (caseSensitive) f else f.toLowerCase @@ -237,13 +241,14 @@ object TransformFunctions { * fills metadata columns in a dataframe using windows * the windows will use the keys and the incrementals to go back as far as needed to get a value * if a value cannot be filled from previous data, first future value will be used to fill - * @param fieldsToFill Array of fields to fill - * @param keys keys by which to partition the window + * + * @param fieldsToFill Array of fields to fill + * @param keys keys by which to partition the window * @param incrementalFields fields by which to order the window - * @param orderedLookups Seq of columns that provide a secondary lookup for the value within the row - * @param noiseBuckets Optional number of buckets to split lookup window - * creates an intermediate step that can help shrink skew on large datasets with - * heavily skewed lookups + * @param orderedLookups Seq of columns 
that provide a secondary lookup for the value within the row + * @param noiseBuckets Optional number of buckets to split lookup window + * creates an intermediate step that can help shrink skew on large datasets with + * heavily skewed lookups * @return */ def fillMeta( @@ -251,68 +256,72 @@ object TransformFunctions { keys: Seq[String], incrementalFields: Seq[String], orderedLookups: Seq[Column] = Seq[Column](), - noiseBuckets: Int = 0 + noiseBuckets: Int = 1 ) : DataFrame = { + val dfFields = df.columns // generate noise as per the number of noise buckets created - val stepDF = if (noiseBuckets > 0) { - val keysWithNoise = keys :+ "__overwatch_ctrl_noiseBucket" - val wNoise = Window.partitionBy(keysWithNoise map col: _*).orderBy(incrementalFields map col: _*) - val wNoisePrev = wNoise.rowsBetween(Window.unboundedPreceding, Window.currentRow) - val wNoiseNext = wNoise.rowsBetween(Window.currentRow, Window.unboundedFollowing) - - val selectsWithFills = dfFields.map(f => { - if(fieldsToFill.map(_.toLowerCase).contains(f.toLowerCase)) { // field to fill - bidirectionalFill(f, wNoisePrev, wNoiseNext, orderedLookups) - } else { // not a fill field just return original value - col(f) - } - }) - df - .withColumn("__overwatch_ctrl_noiseBucket", round(rand() * noiseBuckets, 0)) - .select(selectsWithFills: _*) + val noiseBucketCount = round(rand() * noiseBuckets, 0) + val keysWithNoise = keys :+ "__overwatch_ctrl_noiseBucket" - } else df + val dfc = df + .withColumn("__overwatch_ctrl_noiseBucket", lit(noiseBucketCount)) + .cache() - val wRaw = Window.partitionBy(keys map col: _*).orderBy(incrementalFields map col: _*) - val wPrev = wRaw.rowsBetween(Window.unboundedPreceding, Window.currentRow) - val wNext = wRaw.rowsBetween(Window.currentRow, Window.unboundedFollowing) + dfc.count() + + val wNoise = Window.partitionBy(keysWithNoise map col: _*).orderBy(incrementalFields map col: _*) + val wNoisePrev = wNoise.rowsBetween(Window.unboundedPreceding, Window.currentRow) + val wNoiseNext = wNoise.rowsBetween(Window.currentRow, Window.unboundedFollowing) val selectsWithFills = dfFields.map(f => { if(fieldsToFill.map(_.toLowerCase).contains(f.toLowerCase)) { // field to fill - bidirectionalFill(f, wPrev, wNext, orderedLookups) + bidirectionalFill(f, wNoisePrev, wNoiseNext, orderedLookups) } else { // not a fill field just return original value col(f) } }) - stepDF - .drop("__overwatch_ctrl_noiseBucket") // drop noise col if exists - .select(selectsWithFills: _*) + + val stepDF = dfc + .select(selectsWithFills :+ col("__overwatch_ctrl_noiseBucket"): _*) + + val lookupSelects = (keys ++ fieldsToFill) ++ Array("unixTimeMS_state_start") + val lookupTSDF = stepDF + .select(lookupSelects map col: _*) + .distinct + .toTSDF("unixTimeMS_state_start", keys: _*) + + dfc.toTSDF("unixTimeMS_state_start", keys: _*) + .lookupWhen(lookupTSDF, maxLookAhead = 1000L) + .df } /** * remove dups via a window - * @param keys seq of keys for the df + * + * @param keys seq of keys for the df * @param incrementalFields seq of incremental fields for the df * @return */ def dedupByKey( - keys: Seq[String], - incrementalFields: Seq[String] - ): DataFrame = { -// val keysLessIncrementals = (keys.toSet -- incrementalFields.toSet).toArray - val w = Window.partitionBy(keys map col: _*).orderBy(incrementalFields map col: _*) + keys: Seq[String], + incrementalFields: Seq[String] + ): DataFrame = { + // val keysLessIncrementals = (keys.toSet -- incrementalFields.toSet).toArray + val distinctKeys = (keys ++ incrementalFields).toSet.toArray + 
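// Note: because the partition below includes the incremental (ordering) fields, every row in a
// partition shares the same ordering values -- rank() is 1 for all of them and row_number()
// arbitrarily keeps one row, so exactly one record per distinct (keys ++ incrementalFields)
// combination survives the filter.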
val w = Window.partitionBy(distinctKeys map col: _*).orderBy(incrementalFields map col: _*) df .withColumn("rnk", rank().over(w)) .withColumn("rn", row_number().over(w)) - .filter(col("rnk") === 1 && col("rn") === 1) + .filter(col("rnk") === 1 && col("rn") === 1) .drop("rnk", "rn") } /** * Supports strings, numericals, booleans. Defined keys don't contain any other types thus this function should * ensure no nulls present for keys + * * @return */ def fillAllNAs: DataFrame = { @@ -363,13 +372,14 @@ object TransformFunctions { /** * appends fields to an existing struct - * @param structFieldName name of struct to which namedColumns should be applied - * @param namedColumns Array of NamedColumn + * + * @param structFieldName name of struct to which namedColumns should be applied + * @param namedColumns Array of NamedColumn * @param overrideExistingStructCols Whether or not to override the value of existing struct field if it exists - * @param newStructFieldName If not provided, the original struct will be morphed, if a secondary struct is desired - * provide a name here and the original struct will not be altered. - * New, named struct will be added to the top level - * @param caseSensitive whether or not the field names are case sensitive + * @param newStructFieldName If not provided, the original struct will be morphed, if a secondary struct is desired + * provide a name here and the original struct will not be altered. + * New, named struct will be added to the top level + * @param caseSensitive whether or not the field names are case sensitive * @return */ def appendToStruct( @@ -381,8 +391,8 @@ object TransformFunctions { ): DataFrame = { require(df.hasFieldNamed(structFieldName, caseSensitive), s"ERROR: Dataframe must contain the struct field to be altered. " + - s"$structFieldName was not found. Struct fields include " + - s"${df.schema.fields.filter(_.dataType.typeName == "struct").map(_.name).mkString(", ")}" + s"$structFieldName was not found. 
Struct fields include " + + s"${df.schema.fields.filter(_.dataType.typeName == "struct").map(_.name).mkString(", ")}" ) val fieldToAlterTypeName = df.select(structFieldName).schema.fields.head.dataType.typeName @@ -400,9 +410,9 @@ object TransformFunctions { } - private def bidirectionalFill(colToFillName: String, wPrev: WindowSpec, wNext: WindowSpec, orderedLookups: Seq[Column] = Seq[Column]()) : Column = { + private def bidirectionalFill(colToFillName: String, wPrev: WindowSpec, wNext: WindowSpec, orderedLookups: Seq[Column] = Seq[Column]()): Column = { val colToFill = col(colToFillName) - if (orderedLookups.nonEmpty){ // TODO -- omit nulls from lookup + if (orderedLookups.nonEmpty) { // TODO -- omit nulls from lookup val coalescedLookup = Array(colToFill) ++ orderedLookups.map(lookupCol => { last(lookupCol, true).over(wPrev) }) ++ orderedLookups.map(lookupCol => { @@ -429,31 +439,26 @@ object TransformFunctions { ) } + /** + * Calculates DBU Costs + * @param isDatabricksBillable is state in DBU Billable State + * @param dbus dbus from the state + * @param dbuRate_H cost of a dbu per hour for the given sku + * @param nodeCount number of nodes considered in the calculation + * @param stateTime uptime (wall time) in state + * @param smoothingCol coefficient derived from the smooth function + * @return + */ def dbu( - isDatabricksBillable: Column, - dbu_H: Column, + dbus: Column, dbuRate_H: Column, - nodeCount: Column, - computeTime_H: Column, - runtimeEngine: Column, - sku: Column, smoothingCol: Option[Column] = None ): Column = { - //Check if the cluster is enabled with Photon or not - val isPhotonEnabled = upper(runtimeEngine).equalTo("PHOTON") - //Check if the cluster is not a SQL warehouse/endpoint - val isNotAnSQlWarehouse = !upper(sku).equalTo("SQLCOMPUTE") - //This is the default logic for DBU calculation - val defaultCalculation = dbu_H * computeTime_H * nodeCount * dbuRate_H * smoothingCol.getOrElse(lit(1)) - val dbuMultiplier = 2 - //assign the variables and return column with calculation - coalesce( - when(isDatabricksBillable && isPhotonEnabled && isNotAnSQlWarehouse , defaultCalculation * dbuMultiplier) - .when(isDatabricksBillable, defaultCalculation) - .otherwise(lit(0)), - lit(0) // don't allow costs to be null (i.e. 
missing worker node type and/or single node workers - ) + //This is the default logic for DBU calculation + dbus * dbuRate_H +// val defaultCalculation = dbus * stateTime * nodeCount * dbuRate_H * smoothingCol.getOrElse(lit(1)) +// when(isDatabricksBillable, defaultCalculation).otherwise(lit(0.0)) } /** diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/WorkflowsTransforms.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/WorkflowsTransforms.scala index 8b034f1fe..e96388b00 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/WorkflowsTransforms.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/WorkflowsTransforms.scala @@ -18,6 +18,15 @@ object WorkflowsTransforms extends SparkSessionWrapper { * BEGIN Workflow generic functions */ + /** + * When the colIfExists expression returns a null type due to conversion to a different type, set it to null and + * cast it to the definedNullType + * @param df df to fix + * @param fieldName name of the field to test for null types + * @param colIfExists expression that creates the field + * @param definedNullType type to cast to when expression results in nullType + * @return + */ private def handleRootNull(df: DataFrame, fieldName: String, colIfExists: Column, definedNullType: DataType): Column = { val nullField = lit(null).cast(definedNullType).alias(fieldName) if (SchemaTools.nestedColExists(df.schema, fieldName)) { @@ -38,6 +47,16 @@ object WorkflowsTransforms extends SparkSessionWrapper { .filter('rnk === 1 && 'rn === 1).drop("rnk", "rn") } + /** + * Clean the "tasks" field. This is done by exploding the tasks, cleaning them and then rewrapping them into an array + * using collect_list + * @param df df that contains the "tasks" field within the jobs / jobruns context + * @param keys dataframe keys as defined in pipeline_target + * @param emptyKeysDF empty DF with the typed keys and no data to protect against empty arrays + * @param pathToTasksField dot-delimited location of tasks such as settings.tasks + * @param cleansedTaskAlias alias of the clean tasks to return + * @return + */ def workflowsCleanseTasks( df: DataFrame, keys: Array[String], @@ -72,6 +91,16 @@ object WorkflowsTransforms extends SparkSessionWrapper { } else emptyDFWKeysAndCleansedTasks // build empty DF with keys to allow the subsequent joins } + /** + * Clean the "job_clusters" field.
This is done by exploding the job_clusters, cleaning them and then + * rewrapping them into an array using collect_list + * @param df df that contains the "job_clusters" field within the jobs / jobruns context + * @param keys dataframe keys as defined in pipeline_target + * @param emptyKeysDF empty DF with the typed keys and no data to protect against empty arrays + * @param pathToJobClustersField dot-delimited location of job_clusters such as settings.job_clusters + * @param cleansedJobClustersAlias alias of the clean job_clusters to return + * @return + */ def workflowsCleanseJobClusters( df: DataFrame, keys: Array[String], diff --git a/src/main/scala/com/databricks/labs/overwatch/utils/Config.scala b/src/main/scala/com/databricks/labs/overwatch/utils/Config.scala index bc0f3676b..a2e0cacfb 100644 --- a/src/main/scala/com/databricks/labs/overwatch/utils/Config.scala +++ b/src/main/scala/com/databricks/labs/overwatch/utils/Config.scala @@ -33,13 +33,13 @@ class Config() { private var _badRecordsPath: String = _ private var _primordialDateString: Option[String] = None private var _maxDays: Int = 60 + private var _disabledModules: Array[Int] = Array[Int]() private var _initialWorkerCount: Int = _ private var _intelligentScaling: IntelligentScaling = IntelligentScaling() private var _passthroughLogPath: Option[String] = None private var _inputConfig: OverwatchParams = _ private var _overwatchScope: Seq[OverwatchScope.Value] = OverwatchScope.values.toSeq private var _initialSparkConf: Map[String, String] = Map() - private var _intialShuffleParts: Int = 200 private var _contractInteractiveDBUPrice: Double = _ private var _contractAutomatedDBUPrice: Double = _ private var _contractSQLComputeDBUPrice: Double = _ @@ -55,6 +55,9 @@ class Config() { * The next section is getters that provide access to local configuration variables. Only adding details where * the getter may be obscure or more complicated.
*/ + + // as of 0711 + def disabledModules: Array[Int] = _disabledModules def isMultiworkspaceDeployment: Boolean = _isMultiworkspaceDeployment def apiUrl: Option[String] = _apiUrl @@ -77,8 +80,6 @@ class Config() { def cloudProvider: String = _cloudProvider - def initialShuffleParts: Int = _intialShuffleParts - def maxDays: Int = _maxDays def initialWorkerCount: Int = _initialWorkerCount @@ -140,28 +141,17 @@ class Config() { def overwatchScope: Seq[OverwatchScope.Value] = _overwatchScope + /** + * override spark confs with Overwatch global overrides + * meant to be used only from Initializer to ensure at the beginning of a PipelineRun all the spark Confs are set + * correctly + * @param value + * @return + */ private[overwatch] def registerInitialSparkConf(value: Map[String, String]): this.type = { - val manualOverrides = Map( - "spark.databricks.delta.properties.defaults.autoOptimize.autoCompact" -> - value.getOrElse("spark.databricks.delta.properties.defaults.autoOptimize.autoCompact", "false"), - "spark.databricks.delta.properties.defaults.autoOptimize.optimizeWrite" -> - value.getOrElse("spark.databricks.delta.properties.defaults.autoOptimize.optimizeWrite", "false"), - "spark.databricks.delta.optimize.maxFileSize" -> - value.getOrElse("spark.databricks.delta.optimize.maxFileSize", (1024 * 1024 * 128).toString), - "spark.databricks.delta.retentionDurationCheck.enabled" -> - value.getOrElse("spark.databricks.delta.retentionDurationCheck.enabled", "true"), - "spark.databricks.delta.optimizeWrite.numShuffleBlocks" -> - value.getOrElse("spark.databricks.delta.optimizeWrite.numShuffleBlocks", "50000"), - "spark.databricks.delta.optimizeWrite.binSize" -> - value.getOrElse("spark.databricks.delta.optimizeWrite.binSize", "512"), - "spark.sql.shuffle.partitions" -> "400", // allow aqe to shrink - "spark.sql.caseSensitive" -> "false", - "spark.sql.autoBroadcastJoinThreshold" -> "10485760", - "spark.sql.adaptive.autoBroadcastJoinThreshold" -> "10485760", - "spark.databricks.delta.schema.autoMerge.enabled" -> "true", - "spark.sql.optimizer.collapseProjectAlwaysInline" -> "true" // temporary workaround ES-318365 - ) - _initialSparkConf = value ++ manualOverrides + logger.log(Level.INFO, s"Config Initialized with Spark Overrides of:\n" + + s"${SparkSessionWrapper.globalSparkConfOverrides}") + _initialSparkConf = value ++ SparkSessionWrapper.globalSparkConfOverrides this } @@ -185,19 +175,13 @@ class Config() { * BEGIN SETTERS */ - /** - * Identify the initial value before overwatch for shuffle partitions. 
This value gets modified a lot through - * this process but should be set back to the same as the value before Overwatch process when Overwatch finishes - * its work - * - * @param value number of shuffle partitions to be set - * @return - */ - private[overwatch] def setInitialShuffleParts(value: Int): this.type = { - _intialShuffleParts = value + // as of 0711 + private[overwatch] def registerDisabledModules(value: String): this.type = { + val disabledModulesArray = value.replaceAll("\\s", "").split(",").map(_.toInt) + _disabledModules = disabledModulesArray + logger.log(Level.INFO, s"DISABLING MODULES: ${disabledModulesArray.mkString(", ")}") this } - private[overwatch] def setMaxDays(value: Int): this.type = { _maxDays = value this @@ -389,7 +373,8 @@ class Config() { setApiEnv(ApiEnv(isLocalTesting, workspaceURL, rawToken, packageVersion, derivedApiEnvConfig.successBatchSize, derivedApiEnvConfig.errorBatchSize, runID, derivedApiEnvConfig.enableUnsafeSSL, derivedApiEnvConfig.threadPoolSize, derivedApiEnvConfig.apiWaitingTime, derivedApiProxy.proxyHost, derivedApiProxy.proxyPort, - derivedApiProxy.proxyUserName, derivedApiProxy.proxyPasswordScope, derivedApiProxy.proxyPasswordKey + derivedApiProxy.proxyUserName, derivedApiProxy.proxyPasswordScope, derivedApiProxy.proxyPasswordKey , + derivedApiEnvConfig.mountMappingPath )) this diff --git a/src/main/scala/com/databricks/labs/overwatch/utils/SparkSessionWrapper.scala b/src/main/scala/com/databricks/labs/overwatch/utils/SparkSessionWrapper.scala index 18c5bae76..b7134e933 100644 --- a/src/main/scala/com/databricks/labs/overwatch/utils/SparkSessionWrapper.scala +++ b/src/main/scala/com/databricks/labs/overwatch/utils/SparkSessionWrapper.scala @@ -1,8 +1,39 @@ package com.databricks.labs.overwatch.utils +import com.databricks.labs.overwatch.utils.SparkSessionWrapper.parSessionsOn import org.apache.log4j.{Level, Logger} import org.apache.spark.SparkContext import org.apache.spark.sql.SparkSession +import org.eclipse.jetty.util.ConcurrentHashSet + +import java.util.concurrent.ConcurrentHashMap +import scala.collection.JavaConverters._ + + +object SparkSessionWrapper { + + var parSessionsOn = false + private[overwatch] val sessionsMap = new ConcurrentHashMap[Long, SparkSession]().asScala + private[overwatch] val globalTableLock = new ConcurrentHashSet[String] + private[overwatch] val globalSparkConfOverrides = Map( + "spark.sql.shuffle.partitions" -> "800", // allow aqe to shrink + "spark.databricks.adaptive.autoOptimizeShuffle.enabled" -> "true", // enable AQE + "spark.databricks.delta.optimize.maxFileSize" -> "134217728", // 128 MB default + "spark.sql.files.maxPartitionBytes" -> "134217728", // 128 MB default + "spark.sql.caseSensitive" -> "false", + "spark.sql.autoBroadcastJoinThreshold" -> "10485760", + "spark.sql.adaptive.autoBroadcastJoinThreshold" -> "10485760", + "spark.databricks.delta.schema.autoMerge.enabled" -> "true", + "spark.databricks.delta.properties.defaults.autoOptimize.autoCompact" -> "false", + "spark.databricks.delta.properties.defaults.autoOptimize.optimizeWrite" -> "false", + "spark.databricks.delta.retentionDurationCheck.enabled" -> "true", + "spark.databricks.delta.optimizeWrite.numShuffleBlocks" -> "50000", + "spark.databricks.delta.optimizeWrite.binSize" -> "512", + "spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes" -> "268435456", // reset to default 256MB + "spark.sql.optimizer.collapseProjectAlwaysInline" -> "true" // temporary workaround ES-318365 + ) + +} /** * Enables access to the Spark variable. 
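A minimal sketch, not part of the patch, of how the new SparkSessionWrapper.globalSparkConfOverrides interact with the cluster confs captured by Config.registerInitialSparkConf above: the overrides sit on the right-hand side of the ++ merge, so they win on conflicting keys while unrelated workspace settings pass through. The conf keys are real; the cluster-side values below are illustrative only, and the snippet assumes it runs inside the com.databricks.labs.overwatch package because the overrides map is private[overwatch].

import com.databricks.labs.overwatch.utils.SparkSessionWrapper

// hypothetical confs captured from the cluster at pipeline start
val clusterConf = Map(
  "spark.sql.shuffle.partitions" -> "200", // overridden below
  "spark.executor.memory" -> "8g"          // passes through untouched
)
val merged = clusterConf ++ SparkSessionWrapper.globalSparkConfOverrides
merged("spark.sql.shuffle.partitions") // "800" -- the Overwatch global override wins
merged("spark.executor.memory")        // "8g"  -- keys without an override are preserved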
@@ -12,42 +43,54 @@ import org.apache.spark.sql.SparkSession trait SparkSessionWrapper extends Serializable { private val logger: Logger = Logger.getLogger(this.getClass) + private val sessionsMap = SparkSessionWrapper.sessionsMap /** * Init environment. This structure alows for multiple calls to "reinit" the environment. Important in the case of * autoscaling. When the cluster scales up/down envInit and then check for current cluster cores. */ - @transient lazy protected val _envInit: Boolean = envInit() + + private def buildSpark(): SparkSession = { + SparkSession + .builder() + .appName("Overwatch - GlobalSession") + .getOrCreate() + } /** * Access to spark * If testing locally or using DBConnect, the System variable "OVERWATCH" is set to "LOCAL" to make the code base * behavior differently to work in remote execution AND/OR local only mode but local only mode * requires some additional setup. */ - lazy val spark: SparkSession = if (System.getenv("OVERWATCH") != "LOCAL") { - logger.log(Level.INFO, "Using Databricks SparkSession") - SparkSession - .builder().master("local").appName("OverwatchBatch") - .getOrCreate() - } else { - logger.log(Level.INFO, "Using Custom, local SparkSession") - SparkSession.builder() - .master("local[*]") - .config("spark.driver.maxResultSize", "8g") - .appName("OverwatchBatch") -// Useful configs for local spark configs and/or using labs/spark-local-execution -// https://github.com/databricks-academy/spark-local-execution -// .config("spark.driver.bindAddress", "0.0.0.0") -// .enableHiveSupport() -// .config("spark.warehouse.dir", "metastore") - .getOrCreate() + private[overwatch] def spark(globalSession : Boolean = false): SparkSession = { + + if(SparkSessionWrapper.parSessionsOn){ + if(globalSession){ + buildSpark() + } + else{ + val currentThreadID = Thread.currentThread().getId + val sparkSession = sessionsMap.getOrElse(currentThreadID, buildSpark().newSession()) + sessionsMap.put(currentThreadID, sparkSession) + sparkSession + } + }else{ + buildSpark() + } } + @transient lazy val spark:SparkSession = spark(false) + lazy val sc: SparkContext = spark.sparkContext // sc.setLogLevel("WARN") + protected def clearThreadFromSessionsMap(): Unit ={ + sessionsMap.remove(Thread.currentThread().getId) + logger.log(Level.INFO, s"""Removed ${Thread.currentThread().getId} from sessionMap""") + } + def getCoresPerWorker: Int = sc.parallelize("1", 1) .map(_ => java.lang.Runtime.getRuntime.availableProcessors).collect()(0) diff --git a/src/main/scala/com/databricks/labs/overwatch/utils/Structures.scala b/src/main/scala/com/databricks/labs/overwatch/utils/Structures.scala index 29973654f..3b3b15529 100644 --- a/src/main/scala/com/databricks/labs/overwatch/utils/Structures.scala +++ b/src/main/scala/com/databricks/labs/overwatch/utils/Structures.scala @@ -49,7 +49,8 @@ case class ApiEnv( proxyPort: Option[Int] = None, proxyUserName: Option[String] = None, proxyPasswordScope: Option[String] = None, - proxyPasswordKey: Option[String] = None + proxyPasswordKey: Option[String] = None, + mountMappingPath: Option[String] = None ) @@ -59,7 +60,8 @@ case class ApiEnvConfig( enableUnsafeSSL: Boolean = false, threadPoolSize: Int = 4, apiWaitingTime: Long = 300000, - apiProxyConfig: Option[ApiProxyConfig] = None + apiProxyConfig: Option[ApiProxyConfig] = None, + mountMappingPath: Option[String] = None ) case class ApiProxyConfig( @@ -81,15 +83,15 @@ case class MultiWorkspaceConfig(workspace_name: String, consumer_database_name: String, secret_scope: String, secret_key_dbpat: 
String, - auditlogprefix_source_aws: String, - eh_name: String, - eh_scope_key: String, + auditlogprefix_source_aws: Option[String], + eh_name: Option[String], + eh_scope_key: Option[String], interactive_dbu_price: Double, automated_dbu_price: Double, sql_compute_dbu_price: Double, jobs_light_dbu_price: Double, max_days: Int, - excluded_scopes: String, + excluded_scopes: Option[String], active: Boolean, proxy_host: Option[String] = None, proxy_port: Option[Int] = None, @@ -101,6 +103,7 @@ case class MultiWorkspaceConfig(workspace_name: String, enable_unsafe_SSL: Option[Boolean]= None, thread_pool_size: Option[Int] = None, api_waiting_time: Option[Long] = None, + mount_mapping_path: Option[String], deployment_id: String, output_path: String ) @@ -329,6 +332,18 @@ object WriteMode extends Enumeration { val merge: Value = Value("merge") } +/** + * insertOnly = whenNotMatched --> Insert + * updateOnly = whenMatched --> update + * full = both insert and update + */ +object MergeScope extends Enumeration { + type MergeScope = Value + val full: Value = Value("full") + val insertOnly: Value = Value("insertOnly") + val updateOnly: Value = Value("updateOnly") +} + // Todo Issue_56 private[overwatch] class NoNewDataException(s: String, val level: Level, val allowModuleProgression: Boolean = false) extends Exception(s) {} @@ -340,6 +355,11 @@ private[overwatch] class ApiCallEmptyResponse(val apiCallDetail: String, val all private[overwatch] class ApiCallFailureV2(s: String) extends Exception(s) {} +private[overwatch] class ModuleDisabled(moduleId: Int, s: String) extends Exception(s) { + private val logger = Logger.getLogger("ModuleDisabled") + logger.log(Level.INFO, s"MODULE DISABLED: MODULE_ID: $moduleId -- SKIPPING") +} + private[overwatch] class ApiCallFailure( val httpResponse: HttpResponse[String], apiCallDetail: String, diff --git a/src/main/scala/com/databricks/labs/overwatch/utils/Tools.scala b/src/main/scala/com/databricks/labs/overwatch/utils/Tools.scala index a2f9b6320..aee5bf1bb 100644 --- a/src/main/scala/com/databricks/labs/overwatch/utils/Tools.scala +++ b/src/main/scala/com/databricks/labs/overwatch/utils/Tools.scala @@ -145,6 +145,18 @@ object Helpers extends SparkSessionWrapper { import spark.implicits._ + /** + * Checks whether the provided String value is numeric.//We also need to check for double/float, + * ticket TODO#770 has been created for the same. + * TODO Unit test case for the function. 
+ * + * @param value + * @return + */ + def isNumeric(value:String):Boolean={ + value.forall(Character.isDigit) + } + /** * Getter for parallelism between 8 and driver cores * @@ -539,19 +551,30 @@ object Helpers extends SparkSessionWrapper { * Cannot derive schemas < 0.6.0.3 * * @param etlDB Overwatch ETL database + * @param organization_id Optional - Use only when trying to instantiate remote deployment - org id of remote workspace + * @param apiUrl Optiona - Use only when trying to instantiate remote deployment apiURL of remote workspace + * @param successfullOnly Only consider successful runs when looking for latest config + * @param disableValidations Whether or not to have initializer disable validations * @return */ - def getWorkspaceByDatabase(etlDB: String, successfulOnly: Boolean = true, disableValidations: Boolean = false): Workspace = { + def getWorkspaceByDatabase( + etlDB: String, + organization_id: Option[String] = None, + apiUrl: Option[String] = None, + successfullOnly: Boolean = true, + disableValidations: Boolean = false + ): Workspace = { // verify database exists assert(spark.catalog.databaseExists(etlDB), s"The database provided, $etlDB, does not exist.") val dbMeta = spark.sessionState.catalog.getDatabaseMetadata(etlDB) val dbProperties = dbMeta.properties + val isRemoteWorkspace = organization_id.nonEmpty // verify database is owned and managed by Overwatch assert(dbProperties.getOrElse("OVERWATCHDB", "FALSE") == "TRUE", s"The database provided, $etlDB, is not an Overwatch managed Database. Please provide an Overwatch managed database") - val workspaceID = Initializer.getOrgId + val workspaceID = if (isRemoteWorkspace) organization_id.get else Initializer.getOrgId - val statusFilter = if (successfulOnly) 'status === "SUCCESS" else lit(true) + val statusFilter = if (successfullOnly) 'status === "SUCCESS" else lit(true) val latestConfigByOrg = Window.partitionBy('organization_id).orderBy('Pipeline_SnapTS.desc) val testConfig = spark.table(s"${etlDB}.pipeline_report") @@ -563,7 +586,25 @@ object Helpers extends SparkSessionWrapper { .select(to_json('inputConfig).alias("compactString")) .as[String].first() - Initializer(testConfig, disableValidations = disableValidations) + val workspace = if (isRemoteWorkspace) { // single workspace deployment + Initializer(testConfig, disableValidations = true) + } else { // multi workspace deployment + Initializer( + testConfig, + disableValidations = disableValidations, + apiURL = apiUrl, + organizationID = organization_id + ) + } + + // set cloud provider for remote workspaces + if (isRemoteWorkspace && workspace.getConfig.auditLogConfig.rawAuditPath.nonEmpty) { + workspace.getConfig.setCloudProvider("aws") + } + if (isRemoteWorkspace && workspace.getConfig.auditLogConfig.rawAuditPath.isEmpty) { + workspace.getConfig.setCloudProvider("azure") + } + workspace } /** @@ -650,9 +691,7 @@ object Helpers extends SparkSessionWrapper { val remoteConfig = remoteWorkspace.getConfig val etlDatabaseNameToCreate = if (localETLDatabaseName == "" & !usingExternalMetastore) {remoteConfig.databaseName} else {localETLDatabaseName} val consumerDatabaseNameToCreate = if (localConsumerDatabaseName == "" & !usingExternalMetastore) {remoteConfig.consumerDatabaseName} else {localConsumerDatabaseName} - val LocalWorkSpaceID = if (dbutils.notebook.getContext.tags("orgId") == "0") { - dbutils.notebook.getContext.apiUrl.get.split("\\.")(0).split("/").last - } else dbutils.notebook.getContext.tags("orgId") + val LocalWorkSpaceID = Initializer.getOrgId val 
localETLDBPath = if (!usingExternalMetastore ){ Some(s"${remoteStoragePrefix}/${LocalWorkSpaceID}/${etlDatabaseNameToCreate}.db") @@ -727,6 +766,7 @@ object Helpers extends SparkSessionWrapper { val incrementalFilters = incrementalFields.map(f => { f.dataType.typeName match { case "long" => s"${f.name} >= ${rollbackToTime.asUnixTimeMilli}" + case "double" => s"""cast(${f.name} as long) >= ${rollbackToTime.asUnixTimeMilli}""" case "date" => s"${f.name} >= '${rollbackToTime.asDTString}'" case "timestamp" => s"${f.name} >= '${rollbackToTime.asTSString}'" } @@ -793,7 +833,7 @@ object Helpers extends SparkSessionWrapper { if (dryRun) println("DRY RUN: Nothing will be changed") val config = workspace.getConfig val orgFilter = if (workspaceIds.isEmpty) { - 'organization_id === config.organizationId + throw new Exception("""workspaceIDs cannot be empty. To rollback all workspaces use "global" in the array """) } else if (workspaceIds.headOption.getOrElse(config.organizationId) == "global") { lit(true) } else { @@ -817,9 +857,10 @@ object Helpers extends SparkSessionWrapper { s"${rollbackTSByModule.map(_.organization_id).distinct.mkString(", ")}") rollbackPipelineStateToTimestamp(rollbackTSByModule, customRollbackStatus, config, dryRun) - val allTargets = Bronze(workspace, suppressReport = true, suppressStaticDatasets = true).getAllTargets ++ + val allTargets = (Bronze(workspace, suppressReport = true, suppressStaticDatasets = true).getAllTargets ++ Silver(workspace, suppressReport = true, suppressStaticDatasets = true).getAllTargets ++ - Gold(workspace, suppressReport = true, suppressStaticDatasets = true).getAllTargets + Gold(workspace, suppressReport = true, suppressStaticDatasets = true).getAllTargets) + .filter(_.exists(pathValidation = false, catalogValidation = true)) val targetsToRollback = rollbackTSByModule.map(rollback => { val targetTableName = PipelineFunctions.getTargetTableNameByModule(rollback.moduleId) diff --git a/src/main/scala/com/databricks/labs/overwatch/validation/DeploymentValidation.scala b/src/main/scala/com/databricks/labs/overwatch/validation/DeploymentValidation.scala index d2171473b..79a6b2f07 100644 --- a/src/main/scala/com/databricks/labs/overwatch/validation/DeploymentValidation.scala +++ b/src/main/scala/com/databricks/labs/overwatch/validation/DeploymentValidation.scala @@ -3,7 +3,7 @@ package com.databricks.labs.overwatch.validation import com.databricks.dbutils_v1.DBUtilsHolder.dbutils import com.databricks.labs.overwatch.ApiCallV2 import com.databricks.labs.overwatch.pipeline.TransformFunctions._ -import com.databricks.labs.overwatch.pipeline.{Pipeline, PipelineFunctions, Schema} +import com.databricks.labs.overwatch.pipeline.{Initializer, Pipeline, PipelineFunctions, Schema} import com.databricks.labs.overwatch.utils.SchemaTools.structFromJson import com.databricks.labs.overwatch.utils._ import com.databricks.labs.validation.{Rule, RuleSet} @@ -98,44 +98,66 @@ object DeploymentValidation extends SparkSessionWrapper { } } - private def validateMountCount(conf: MultiWorkspaceConfig): DeploymentValidationReport = { + /** + * Validates the content of provided mount_mapping_path csv file.Below are the validation rules. + * 1)validate for file existence. + * 2)validate the provided csv file belongs to the provided workspace_id. + * 3)validate the provided csv file contains columns "mountPoint", "source","workspace_id" and has some values in it. 
+ * + * @param conf + * @return + */ + private def validateMountMappingPath(conf: MultiWorkspaceConfig): DeploymentValidationReport = { + // get fine here -- already verified non-empty in calling function + val path = conf.mount_mapping_path.get.trim val testDetails = s"""WorkSpaceMountTest - |APIURL:${conf.api_url} - |DBPATWorkspaceScope:${conf.secret_scope} - |SecretKey_DBPAT:${conf.secret_key_dbpat}""".stripMargin + |mount_mapping_path:${path} + """.stripMargin + try { - val patToken = dbutils.secrets.get(scope = conf.secret_scope, key = conf.secret_key_dbpat) - val apiEnv = ApiEnv(false, conf.api_url, patToken, getClass.getPackage.getImplementationVersion) - val endPoint = "dbfs/search-mounts" - val mountCount = ApiCallV2(apiEnv, endPoint).execute().asDF().count() - if(mountCount<50) - { - DeploymentValidationReport(true, - getSimpleMsg("Validate_Mount"), - testDetails, - Some("SUCCESS"), - Some(conf.workspace_id) - ) - }else{ + if (!Helpers.pathExists(path)) { DeploymentValidationReport(false, getSimpleMsg("Validate_Mount"), testDetails, - Some("Number of mounts found in workspace is more than 50"), + Some("Unable to find the provided csv: " + path), Some(conf.workspace_id) ) + } else { + val inputDf = spark.read.option("header", "true") + .option("ignoreLeadingWhiteSpace", true) + .option("ignoreTrailingWhiteSpace", true) + .csv(path) + .filter('source.isNotNull) + .verifyMinimumSchema(Schema.mountMinimumSchema) + .select("mountPoint", "source","workspace_id") + .filter('workspace_id === conf.workspace_id) + + + val dataCount = inputDf.count() + if (dataCount > 0) { + DeploymentValidationReport(true, + getSimpleMsg("Validate_Mount"), + s"""WorkSpaceMountTest + |mount_mapping_path:${path} + |mount points found:${dataCount} + """.stripMargin, + Some("SUCCESS"), + Some(conf.workspace_id) + ) + } else { + DeploymentValidationReport(false, + getSimpleMsg("Validate_Mount"), + testDetails, + Some(s"""No data found for workspace_id: ${conf.workspace_id} in provided csv: ${path}"""), + Some(conf.workspace_id) + ) + } } - } catch { - case exception: Exception => - val msg = - s"""No Data retrieved - |WorkspaceId:${conf.workspace_id} - |APIURL:${conf.api_url} - | DBPATWorkspaceScope:${conf.secret_scope} - | SecretKey_DBPAT:${conf.secret_key_dbpat}""".stripMargin - val fullMsg = PipelineFunctions.appendStackStrace(exception, msg) + case e: Exception => + val fullMsg = PipelineFunctions.appendStackStrace(e, s"""Exception while reading the mount_mapping_path :${path}""") logger.log(Level.ERROR, fullMsg) DeploymentValidationReport(false, getSimpleMsg("Validate_Mount"), @@ -143,9 +165,74 @@ object DeploymentValidation extends SparkSessionWrapper { Some(fullMsg), Some(conf.workspace_id) ) + } + } + private def validateMountCount(conf: MultiWorkspaceConfig): DeploymentValidationReport = { + + val isAzure = conf.cloud.toLowerCase == "azure" //Mount-point validation is only done for Azure + val isRemoteWorkspace = conf.workspace_id.trim != Initializer.getOrgId // No need to perform mount-point validation for driver workspace. 
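// Note: the mount_mapping_path csv validated below is expected to carry at least the columns in
// Schema.mountMinimumSchema -- mountPoint, source, workspace_id -- with at least one row for this
// workspace_id (rows for other workspaces are filtered out, so one file can serve many workspaces).
// Purely illustrative example of the expected layout:
//   mountPoint,source,workspace_id
//   /mnt/audit-logs,abfss://logs@examplestorage.dfs.core.windows.net/,1234567890123456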
+ val isMountMappingPathProvided = conf.mount_mapping_path.nonEmpty + + if (isAzure && isRemoteWorkspace) { //Performing mount test + if (isMountMappingPathProvided) { + validateMountMappingPath(conf) + } else { + val testDetails = + s"""WorkSpaceMountTest + |APIURL:${conf.api_url} + |DBPATWorkspaceScope:${conf.secret_scope} + |SecretKey_DBPAT:${conf.secret_key_dbpat}""".stripMargin + try { + val patToken = dbutils.secrets.get(scope = conf.secret_scope, key = conf.secret_key_dbpat) + val apiEnv = ApiEnv(false, conf.api_url, patToken, getClass.getPackage.getImplementationVersion) + val endPoint = "dbfs/search-mounts" + val mountCount = ApiCallV2(apiEnv, endPoint).execute().asDF().count() + if (mountCount < 50) { + DeploymentValidationReport(true, + getSimpleMsg("Validate_Mount"), + testDetails, + Some("SUCCESS"), + Some(conf.workspace_id) + ) + } else { + DeploymentValidationReport(false, + getSimpleMsg("Validate_Mount"), + testDetails, + Some("Number of mounts found in workspace is more than 50"), + Some(conf.workspace_id) + ) + } + + } catch { + case exception: Exception => + val msg = + s"""No Data retrieved + |WorkspaceId:${conf.workspace_id} + |APIURL:${conf.api_url} + | DBPATWorkspaceScope:${conf.secret_scope} + | SecretKey_DBPAT:${conf.secret_key_dbpat}""".stripMargin + val fullMsg = PipelineFunctions.appendStackStrace(exception, msg) + logger.log(Level.ERROR, fullMsg) + DeploymentValidationReport(false, + getSimpleMsg("Validate_Mount"), + testDetails, + Some(fullMsg), + Some(conf.workspace_id) + ) + + } + } + } else { + DeploymentValidationReport(true, + getSimpleMsg("Validate_Mount"), + "Skipping mount point check", + Some("SUCCESS"), + Some(conf.workspace_id) + ) } + } /** @@ -303,9 +390,12 @@ object DeploymentValidation extends SparkSessionWrapper { * @param primordial_date * @param maxDate */ - private def validateAuditLog(workspace_id: String, auditlogprefix_source_aws: String, primordial_date: Date, maxDate: Int): DeploymentValidationReport = { + private def validateAuditLog(workspace_id: String, auditlogprefix_source_aws: Option[String], primordial_date: Date, maxDate: Int): DeploymentValidationReport = { try { - val fromDT = new java.sql.Date(primordial_date.getTime()).toLocalDate() + if (auditlogprefix_source_aws.isEmpty) throw new BadConfigException( + "auditlogprefix_source_aws cannot be null when cloud is AWS") + val auditLogPrefix = auditlogprefix_source_aws.get + val fromDT = new java.sql.Date(primordial_date.getTime).toLocalDate var untilDT = fromDT.plusDays(maxDate.toLong) val dateCompare = untilDT.compareTo(LocalDate.now()) val msgBuffer = new StringBuffer() @@ -315,12 +405,12 @@ object DeploymentValidation extends SparkSessionWrapper { val daysBetween = ChronoUnit.DAYS.between(fromDT, untilDT) var validationFlag = false if (daysBetween == 0) { - validationFlag = Helpers.pathExists(s"${auditlogprefix_source_aws}/date=${fromDT.toString}") + validationFlag = Helpers.pathExists(s"${auditLogPrefix}/date=${fromDT.toString}") } else { val pathsToCheck = datesStream(fromDT).takeWhile(_.isBefore(untilDT)).toArray - .map(dt => s"${auditlogprefix_source_aws}/date=${dt}") + .map(dt => s"${auditLogPrefix}/date=${dt}") val presentPaths = datesStream(fromDT).takeWhile(_.isBefore(untilDT)).toArray - .map(dt => s"${auditlogprefix_source_aws}/date=${dt}") + .map(dt => s"${auditLogPrefix}/date=${dt}") .filter(Helpers.pathExists) if (presentPaths.length == daysBetween) { validationFlag = true @@ -344,7 +434,7 @@ object DeploymentValidation extends SparkSessionWrapper { } else { val 
msg = s"""ReValidate the folder existence - | Make sure audit log with required date folder exist inside ${auditlogprefix_source_aws} + | Make sure audit log with required date folder exist inside ${auditlogprefix_source_aws.getOrElse("EMPTY")} |, primordial_date:${primordial_date} |, maxDate:${maxDate} """.stripMargin @@ -362,7 +452,7 @@ object DeploymentValidation extends SparkSessionWrapper { case exception: Exception => val msg = s"""AuditLogPrefixTest workspace_id:${workspace_id} - | Make sure audit log with required date folder exist inside ${auditlogprefix_source_aws} + | Make sure audit log with required date folder exist inside ${auditlogprefix_source_aws.getOrElse("EMPTY")} |, primordial_date:${primordial_date} |, maxDate:${maxDate} """.stripMargin logger.log(Level.ERROR, msg) @@ -384,7 +474,18 @@ object DeploymentValidation extends SparkSessionWrapper { * @param key * @param ehName */ - private def validateEventHub(workspace_id: String, scope: String, key: String, ehName: String, outputPath: String): DeploymentValidationReport = { + private def validateEventHub( + workspace_id: String, + scope: String, + optKey: Option[String], + optEHName: Option[String], + outputPath: String + ): DeploymentValidationReport = { + if (optKey.isEmpty || optEHName.isEmpty) throw new BadConfigException("When cloud is Azure, the eh_name and " + + "eh_scope_key are required fields but they were empty in the config") + // using gets here because if they were empty from above check, exception would already be thrown + val key = optKey.get + val ehName = optEHName.get val testDetails = s"""Connectivity test with ehName:${ehName} scope:${scope} SecretKey_DBPAT:${key}""" try { import org.apache.spark.eventhubs.{ConnectionStringBuilder, EventHubsConf, EventPosition} diff --git a/src/test/scala/com/databricks/labs/overwatch/ApiCallV2Test.scala b/src/test/scala/com/databricks/labs/overwatch/ApiCallV2Test.scala index 15f4473bb..42b6b8229 100644 --- a/src/test/scala/com/databricks/labs/overwatch/ApiCallV2Test.scala +++ b/src/test/scala/com/databricks/labs/overwatch/ApiCallV2Test.scala @@ -2,7 +2,7 @@ package com.databricks.labs.overwatch import com.databricks.labs.overwatch.ApiCallV2.sc import com.databricks.labs.overwatch.pipeline.PipelineFunctions -import com.databricks.labs.overwatch.utils.{ApiCallFailureV2, ApiEnv} +import com.databricks.labs.overwatch.utils.{ApiCallFailureV2, ApiEnv, SparkSessionWrapper} import org.apache.spark.SparkContext import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions.lit @@ -96,8 +96,7 @@ class ApiCallV2Test extends AnyFunSpec with BeforeAndAfterAll { apiEnv, endPoint, query, - tempSuccessPath = "src/test/scala/tempDir/sqlqueryhistory", - accumulator = acc + tempSuccessPath = "src/test/scala/tempDir/sqlqueryhistory" ).execute().asDF() != null) } @@ -160,15 +159,15 @@ class ApiCallV2Test extends AnyFunSpec with BeforeAndAfterAll { it("comparison test for jobs/list API") { val endPoint = "jobs/list" val oldAPI = ApiCall(endPoint, apiEnv).executeGet().asDF - val query = Map( - "limit"->"2", - "expand_tasks"->"true", - "offset"->"0" - ) - val newAPI = ApiCallV2(apiEnv, endPoint,query,2.1).execute().asDF() - println(oldAPI.count()+"old api count") - println(newAPI.count()+"new api count") - assert(oldAPI.count() == newAPI.count() && oldAPI.except(newAPI).count() == 0 && newAPI.except(oldAPI).count() == 0) + val query = Map( + "limit" -> "2", + "expand_tasks" -> "true", + "offset" -> "0" + ) + val newAPI = ApiCallV2(apiEnv, endPoint, query, 
2.1).execute().asDF() + println(oldAPI.count() + "old api count") + println(newAPI.count() + "new api count") + assert(oldAPI.count() == newAPI.count() && oldAPI.except(newAPI).count() == 0 && newAPI.except(oldAPI).count() == 0) } it("comparison test for clusters/list API") { @@ -206,6 +205,7 @@ class ApiCallV2Test extends AnyFunSpec with BeforeAndAfterAll { assert(oldAPI.count() == newAPI.count() && oldAPI.except(newAPI).count() == 0 && newAPI.except(oldAPI).count() == 0) } it("test multithreading") { + SparkSessionWrapper.parSessionsOn = true val endPoint = "clusters/list" val clusterIDsDf = ApiCallV2(apiEnv, endPoint).execute().asDF().select("cluster_id") clusterIDsDf.show(false) @@ -219,14 +219,14 @@ class ApiCallV2Test extends AnyFunSpec with BeforeAndAfterAll { val tmpClusterEventsErrorPath = "" val accumulator = sc.longAccumulator("ClusterEventsAccumulator") for (i <- clusterIDs.indices) { - val jsonQuery = Map("cluster_id" -> s"""${clusterIDs(i).get(0)}""", - "start_time"->"1052775426000", - "end_time"->"1655453826000", + val jsonQuery = Map("cluster_id" -> s"""${clusterIDs(i).get(0)}""", + "start_time" -> "1052775426000", + "end_time" -> "1655453826000", "limit" -> "500" ) println(jsonQuery) val future = Future { - val apiObj = ApiCallV2(apiEnv, "clusters/events", jsonQuery, tmpClusterEventsSuccessPath,accumulator).executeMultiThread() + val apiObj = ApiCallV2(apiEnv, "clusters/events", jsonQuery, tmpClusterEventsSuccessPath).executeMultiThread(accumulator) synchronized { apiResponseArray.addAll(apiObj) if (apiResponseArray.size() >= apiEnv.successBatchSize) { @@ -268,9 +268,9 @@ class ApiCallV2Test extends AnyFunSpec with BeforeAndAfterAll { } Thread.sleep(5000) currentSleepTime += 5000 - if (accumulatorCountWhileSleeping < accumulator.value) {//new API response received while waiting + if (accumulatorCountWhileSleeping < accumulator.value) { //new API response received while waiting currentSleepTime = 0 //resetting the sleep time - accumulatorCountWhileSleeping = accumulator.value + accumulatorCountWhileSleeping = accumulator.value } } if (responseCounter != finalResponseCount) { @@ -278,7 +278,7 @@ class ApiCallV2Test extends AnyFunSpec with BeforeAndAfterAll { } } - it("test sqlHistoryDF"){ + it("test sqlHistoryDF") { lazy val spark: SparkSession = { SparkSession .builder() @@ -292,22 +292,21 @@ class ApiCallV2Test extends AnyFunSpec with BeforeAndAfterAll { val acc = sc.longAccumulator("sqlQueryHistoryAccumulator") // val startTime = "1665878400000".toLong // subtract 2 days for running query merge // val startTime = "1669026035073".toLong -(1000*60*60*24)// subtract 2 days for running query merge - val startTime = "1669026035073".toLong + val startTime = "1669026035073".toLong val untilTime = "1669112526676".toLong val tempWorkingDir = "" println("test_started") val jsonQuery = Map( "max_results" -> "50", "include_metrics" -> "true", - "filter_by.query_start_time_range.start_time_ms" -> s"$startTime", - "filter_by.query_start_time_range.end_time_ms" -> s"${untilTime}" + "filter_by.query_start_time_range.start_time_ms" -> s"$startTime", + "filter_by.query_start_time_range.end_time_ms" -> s"${untilTime}" ) ApiCallV2( apiEnv, sqlQueryHistoryEndpoint, jsonQuery, - tempSuccessPath = s"${tempWorkingDir}/sqlqueryhistory_silver/${System.currentTimeMillis()}", - accumulator = acc + tempSuccessPath = s"${tempWorkingDir}/sqlqueryhistory_silver/${System.currentTimeMillis()}" ) .execute() .asDF() @@ -335,7 +334,7 @@ class ApiCallV2Test extends AnyFunSpec with BeforeAndAfterAll { implicit 
val ec: ExecutionContextExecutor = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(apiEnv.threadPoolSize)) val tmpSqlQueryHistorySuccessPath = "/tmp/test/" val tmpSqlQueryHistoryErrorPath = "" - val startTime = "1669026035073".toLong + val startTime = "1669026035073".toLong val untilTimeMs = "1669112526676".toLong // val untilTimeMs = "1666182197381".toLong @@ -343,14 +342,14 @@ class ApiCallV2Test extends AnyFunSpec with BeforeAndAfterAll { // var fromTimeMs = "1662249600000".toLong - (1000*60*60*24*2) //subtracting 2 days for running query merge // var fromTimeMs = "1666182197381".toLong - (1000*60*60*24*2) var fromTimeMs = "1669026035073".toLong - val finalResponseCount = scala.math.ceil((untilTimeMs - fromTimeMs).toDouble/(1000*60*60)) // Total no. of API Calls + val finalResponseCount = scala.math.ceil((untilTimeMs - fromTimeMs).toDouble / (1000 * 60 * 60)) // Total no. of API Calls - while (fromTimeMs < untilTimeMs){ - val (startTime, endTime) = if ((untilTimeMs- fromTimeMs)/(1000*60*60) > 1) { + while (fromTimeMs < untilTimeMs) { + val (startTime, endTime) = if ((untilTimeMs - fromTimeMs) / (1000 * 60 * 60) > 1) { (fromTimeMs, - fromTimeMs+(1000*60*60)) + fromTimeMs + (1000 * 60 * 60)) } - else{ + else { (fromTimeMs, untilTimeMs) } @@ -359,10 +358,10 @@ class ApiCallV2Test extends AnyFunSpec with BeforeAndAfterAll { val jsonQuery = Map( "max_results" -> "50", "include_metrics" -> "true", - "filter_by.query_start_time_range.start_time_ms" -> s"$startTime", // Do we need to subtract 2 days for every API call? - "filter_by.query_start_time_range.end_time_ms" -> s"$endTime" + "filter_by.query_start_time_range.start_time_ms" -> s"$startTime", // Do we need to subtract 2 days for every API call? + "filter_by.query_start_time_range.end_time_ms" -> s"$endTime" ) - /**TODO: + /** TODO: * Refactor the below code to make it more generic */ //call future @@ -371,13 +370,12 @@ class ApiCallV2Test extends AnyFunSpec with BeforeAndAfterAll { apiEnv, sqlQueryHistoryEndpoint, jsonQuery, - tempSuccessPath = tmpSqlQueryHistorySuccessPath, - accumulator = acc - ).executeMultiThread() + tempSuccessPath = tmpSqlQueryHistorySuccessPath + ).executeMultiThread(acc) synchronized { apiObj.forEach( - obj=>if(obj.contains("res")){ + obj => if (obj.contains("res")) { apiResponseArray.add(obj) } ) @@ -405,7 +403,7 @@ class ApiCallV2Test extends AnyFunSpec with BeforeAndAfterAll { } responseCounter = responseCounter + 1 } - fromTimeMs = fromTimeMs+(1000*60*60) + fromTimeMs = fromTimeMs + (1000 * 60 * 60) } val timeoutThreshold = apiEnv.apiWaitingTime // 5 minutes var currentSleepTime = 0 @@ -445,8 +443,58 @@ class ApiCallV2Test extends AnyFunSpec with BeforeAndAfterAll { } - } + it("test jobRunsList") { + lazy val spark: SparkSession = { + SparkSession + .builder() + .master("local") + .appName("spark session") + .config("spark.sql.shuffle.partitions", "1") + .getOrCreate() + } + var apiResponseArray = Collections.synchronizedList(new util.ArrayList[String]()) + + lazy val sc: SparkContext = spark.sparkContext + val jobsRunsEndpoint = "jobs/runs/list" + val fromTime = "1676160000000".toLong + val untilTime = "1676574760631".toLong + val tempWorkingDir = "/tmp/test/" + println("test_started") + + val jsonQuery = Map( + "limit" -> "25", + "expand_tasks" -> "true", + "offset" -> "0", + "start_time_from" -> s"${fromTime}", + "start_time_to" -> s"${untilTime}" + ) + val acc = sc.longAccumulator("jobRunsListAccumulator") + val apiObj = ApiCallV2( + apiEnv, + jobsRunsEndpoint, + jsonQuery, + tempSuccessPath = 
tempWorkingDir, + 2.1 + ).executeMultiThread(acc) + // .execute() + // .asDF() + // .withColumn("organization_id", lit("1234")) + // .show(false) + + apiObj.forEach( + obj => if (obj.contains("res")) { + apiResponseArray.add(obj) + } + ) + println(apiObj.size()) + + if (apiResponseArray.size() > 0) { //In case of response array didn't hit the batch-size as a final step we will write it to the persistent storage. + PipelineFunctions.writeMicroBatchToTempLocation(tempWorkingDir, apiResponseArray.toString) + } + + } + } } diff --git a/src/test/scala/com/databricks/labs/overwatch/utils/HelpersTest.scala b/src/test/scala/com/databricks/labs/overwatch/utils/HelpersTest.scala new file mode 100644 index 000000000..7479f190f --- /dev/null +++ b/src/test/scala/com/databricks/labs/overwatch/utils/HelpersTest.scala @@ -0,0 +1,15 @@ +package com.databricks.labs.overwatch.utils + +import org.scalatest.funspec.AnyFunSpec + +class HelpersTest extends AnyFunSpec { + + describe("Helpers Test") { + it("Numeric test"){ + assert(Helpers.isNumeric("1234") == true) + assert(Helpers.isNumeric("abcd") == false) + } + + } + +}
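Since Helpers.isNumeric only checks Character.isDigit character by character, a minimal sketch of extra assertions documenting the current behaviour may help; the test class below is hypothetical (not part of the PR), assumes the same scalatest setup as HelpersTest, and the double/float gap it shows is what ticket TODO#770 tracks.

package com.databricks.labs.overwatch.utils

import org.scalatest.funspec.AnyFunSpec

class HelpersNumericEdgeCasesTest extends AnyFunSpec {

  describe("Helpers.isNumeric edge cases") {
    it("accepts only plain digit strings") {
      assert(Helpers.isNumeric("0042"))  // digits only -> true
      assert(!Helpers.isNumeric("12.5")) // decimals are rejected until TODO#770 lands
      assert(!Helpers.isNumeric("-3"))   // sign characters are rejected
      assert(Helpers.isNumeric(""))      // vacuously true: forall over an empty string
    }
  }

}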