From e82f5ff48562f6add3a3b8f310120c4e7aac4eb9 Mon Sep 17 00:00:00 2001
From: Guenia Izquierdo Delgado
Date: Thu, 25 Jan 2024 13:52:29 -0500
Subject: [PATCH] 0800 release (#1144)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Initial commit
* traceability implemented (#1102)
* traceability implemented
* code review implemented
* missed code implemented (#1105)
* Initial commit
* traceability implemented (#1102)
* traceability implemented
* code review implemented
* missed code implemented
* missed code implemented
---------
Co-authored-by: Guenia Izquierdo
* Added proper exception for Spark Stream Gold if progress c… (#1085)
* Initial commit
* 09-Nov-23: Added proper exception for Spark Stream Gold if progress column contains only null in SparkEvents_Bronze
---------
Co-authored-by: Guenia Izquierdo
Co-authored-by: Sourav Banerjee <30810740+Sourav692@users.noreply.github.com>
* Gracefully Handle Exception for NotebookCommands_Gold (#1095)
* Initial commit
* Gracefully Handle Exception for NotebookCommands_Gold
* Convert the check in buildNotebookCommandsFact to a single or clause
---------
Co-authored-by: Guenia Izquierdo
Co-authored-by: Sourav Banerjee <30810740+Sourav692@users.noreply.github.com>
* code missed in merge (#1120)
* Fix Helper Method to Instantiate Remote Workspaces (#1110)
* Initial commit
* Change getRemoteWorkspaceByPath and getWorkspaceByDatabase to take in RemoteWorkspace
* Remove Unnecessary println Statements
---------
Co-authored-by: Guenia Izquierdo
* Ensure we test the write into a partitioned storage_prefix (#1088)
* Initial commit
* Ensure we test the write into a partitioned storage_prefix
* silver warehouse spec fix (#1121)
* added missed copy-pasta (#1129)
* Exclude cluster logs in S3 root bucket (#1118)
* Exclude cluster logs in S3 root bucket
* Omit cluster log paths pointing to s3a as well
* implemented recon (#1116)
* implemented recon
* docs added
* file path change
* review comments implemented
* Added ShuffleFactor to NotebookCommands (#1124)
Co-authored-by: Sourav Banerjee <30810740+Sourav692@users.noreply.github.com>
* disabled traceability (#1130)
* Added JobRun_Silver in buildClusterStateFact for Cluster E… (#1083)
* Initial commit
* 08-Nov-23: Added JobRun_Silver in buildClusterStateFact for Cluster End Time Imputation
* Impute Terminating Events in CLSF from JR_Silver
* Impute Terminating Events in CLSD
* Impute Terminating Events in CLSD
* Change CLSF to original 0730 version
* Change CLSF to original 0730 version
* Added cluster_spec in CLSD to get job Cluster only
* Make the variable names in buildClusterStateDetail more descriptive
* Make the variable names in buildClusterStateDetail more descriptive
---------
Co-authored-by: Guenia Izquierdo
Co-authored-by: Sourav Banerjee <30810740+Sourav692@users.noreply.github.com>
* Sys table audit log integration (#1122)
* system table integration with audit log
* adding code to resolve issues with response col
* fixed timestamp issue
* adding print statement for from and until time
* adding fix for azure
* removed comments
* removed comments and print statements
* removed comments
* implemented code review comments
* implemented code review comments
* adding review comment
* Sys table integration multi account (#1131)
* added code changes for multi account deployment
* code for multi account system table integration
* Sys table integration multi account (#1132)
* added code changes for multi account deployment
* code for multi account system table integration
* adding code for system table migration check
* changing exception for empty audit log from system table
* adding code to handle sql_endpoint in configs and fix in migration validation (#1133)
* corner case commit (#1134)
* Handle CLSD Cluster Impute when jrcp and clusterSpec is Empty (#1135)
* Handle CLSD Cluster Impute when jrcp and clusterSpec is Empty
* Exclude last_state from clsd as it is not needed in the logic.
---------
Co-authored-by: Sourav Banerjee <30810740+Sourav692@users.noreply.github.com>
* Exclude 2011 and 2014 as dependency module for 2019 (#1136)
* Exclude 2011 and 2014 as dependency module for 2019
* Added comment in CLSD for understandability
---------
Co-authored-by: Sourav Banerjee <30810740+Sourav692@users.noreply.github.com>
* corner case commit (#1137)
* Update version
* adding fix for empty EH config for system tables (#1140)
* corner case commit (#1142)
* adding fix for empty audit log for warehouse_spec_silver (#1141)
* recon columns removed (#1143)
* recon columns removed
* recon columns removed
---------
Co-authored-by: Sriram Mohanty <69749553+sriram251-code@users.noreply.github.com>
Co-authored-by: Sourav Banerjee <109206082+souravbaner-da@users.noreply.github.com>
Co-authored-by: Sourav Banerjee <30810740+Sourav692@users.noreply.github.com>
Co-authored-by: Aman <91308367+aman-db@users.noreply.github.com>
---
 build.sbt                                          |   2 +-
 .../overwatch/MultiWorkspaceDeployment.scala       |  96 +++--
 .../labs/overwatch/ParamDeserializer.scala         |  10 +-
 .../labs/overwatch/api/ApiCallV2.scala             | 104 ++++--
 .../labs/overwatch/api/ApiMeta.scala               |  48 ++-
 .../labs/overwatch/env/Workspace.scala             |  67 ++--
 .../labs/overwatch/pipeline/Bronze.scala           |  30 +-
 .../overwatch/pipeline/BronzeTransforms.scala      | 352 ++++++++++++------
 .../overwatch/pipeline/DbsqlTransforms.scala       |  13 +-
 .../labs/overwatch/pipeline/Gold.scala             |   2 +-
 .../overwatch/pipeline/GoldTransforms.scala        |   9 +-
 .../labs/overwatch/pipeline/Initializer.scala      |  12 +
 .../pipeline/InitializerFunctions.scala            | 215 ++++++++---
 .../labs/overwatch/pipeline/Module.scala           |  53 +++
 .../overwatch/pipeline/PipelineTable.scala         |   3 +-
 .../overwatch/pipeline/PipelineTargets.scala       |  45 ++-
 .../labs/overwatch/pipeline/Schema.scala           |   3 +-
 .../labs/overwatch/pipeline/Silver.scala           |  12 +-
 .../overwatch/pipeline/SilverTransforms.scala      |  57 ++-
 .../labs/overwatch/utils/Config.scala              |   5 +
 .../labs/overwatch/utils/SchemaTools.scala         |  21 ++
 .../labs/overwatch/utils/Structures.scala          |  28 +-
 .../labs/overwatch/utils/Tools.scala               | 144 ++++++-
 .../validation/DataReconciliation.scala            | 316 ++++++++++++++++
 .../validation/DeploymentValidation.scala          |  85 ++++-
 .../overwatch/ParamDeserializerTest.scala          |   3 +-
 .../overwatch/pipeline/InitializeTest.scala        |  26 +-
 27 files changed, 1420 insertions(+), 341 deletions(-)
 create mode 100644 src/main/scala/com/databricks/labs/overwatch/validation/DataReconciliation.scala

diff --git a/build.sbt b/build.sbt
index 38eeaa40f..3128812a2 100644
--- a/build.sbt
+++ b/build.sbt
@@ -2,7 +2,7 @@ name := "overwatch"
 organization := "com.databricks.labs"
-version := "0.7.2.2.1"
+version := "0.8.0.0"
 scalaVersion := "2.12.12"
 scalacOptions ++= Seq("-Xmax-classfile-name", "78")
diff --git a/src/main/scala/com/databricks/labs/overwatch/MultiWorkspaceDeployment.scala b/src/main/scala/com/databricks/labs/overwatch/MultiWorkspaceDeployment.scala
index 1551cf73d..24c5de6a1 100644
---
a/src/main/scala/com/databricks/labs/overwatch/MultiWorkspaceDeployment.scala +++ b/src/main/scala/com/databricks/labs/overwatch/MultiWorkspaceDeployment.scala @@ -73,6 +73,9 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { private var _pipelineSnapTime: Long = _ + private var _systemTableAudit: String = "system.access.audit" + + private def systemTableAudit: String = _systemTableAudit private def setConfigLocation(value: String): this.type = { _configLocation = value @@ -113,30 +116,7 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { val tokenSecret = TokenSecret(config.secret_scope, config.secret_key_dbpat) val badRecordsPath = s"${config.storage_prefix}/${config.workspace_id}/sparkEventsBadrecords" // TODO -- ISSUE 781 - quick fix to support non-json audit logs but needs to be added back to external parameters - val auditLogFormat = spark.conf.getOption("overwatch.auditlogformat").getOrElse("json") - val auditLogConfig = if (s"${config.cloud.toLowerCase()}" != "azure") { - AuditLogConfig(rawAuditPath = config.auditlogprefix_source_path, auditLogFormat = auditLogFormat) - } else { - - val ehStatePath = s"${config.storage_prefix}/${config.workspace_id}/ehState" - val isAAD = config.aad_client_id.nonEmpty && - config.aad_tenant_id.nonEmpty && - config.aad_client_secret_key.nonEmpty && - config.eh_conn_string.nonEmpty - val azureLogConfig = if(isAAD){ - AzureAuditLogEventhubConfig(connectionString = config.eh_conn_string.get, eventHubName = config.eh_name.get - , auditRawEventsPrefix = ehStatePath, - azureClientId = Some(config.aad_client_id.get), - azureClientSecret = Some(dbutils.secrets.get(config.secret_scope, key = config.aad_client_secret_key.get)), - azureTenantId = Some(config.aad_tenant_id.get), - azureAuthEndpoint = config.aad_authority_endpoint.getOrElse("https://login.microsoftonline.com/") - ) - }else{ - val ehConnString = s"{{secrets/${config.secret_scope}/${config.eh_scope_key.get}}}" - AzureAuditLogEventhubConfig(connectionString = ehConnString, eventHubName = config.eh_name.get, auditRawEventsPrefix = ehStatePath) - } - AuditLogConfig(azureAuditLogEventhubConfig = Some(azureLogConfig)) - } + val auditLogConfig = getAuditlogConfigs(config) val interactiveDBUPrice: Double = config.interactive_dbu_price val automatedDBUPrice: Double = config.automated_dbu_price val sqlComputerDBUPrice: Double = config.sql_compute_dbu_price @@ -152,6 +132,7 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { val stringDate = dateFormat.format(primordialDateString) val apiEnvConfig = getProxyConfig(config) val temp_dir_path = config.temp_dir_path.getOrElse("") + val sql_endpoint = config.sql_endpoint val params = OverwatchParams( auditLogConfig = auditLogConfig, @@ -165,7 +146,8 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { workspace_name = Some(customWorkspaceName), externalizeOptimize = true, apiEnvConfig = Some(apiEnvConfig), - tempWorkingDir = temp_dir_path + tempWorkingDir = temp_dir_path, + sqlEndpoint = sql_endpoint ) MultiWorkspaceParams(JsonUtils.objToJson(params).compactString, s"""${config.api_url}""", @@ -186,7 +168,61 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { } } - private def getProxyConfig(config: MultiWorkspaceConfig): ApiEnvConfig = { + private def getAuditLogConfigForSystemTable(config: MultiWorkspaceConfig): AuditLogConfig = { + if(config.sql_endpoint.getOrElse("").isEmpty) { + val auditLogFormat = "delta" + AuditLogConfig(rawAuditPath = config.auditlogprefix_source_path, + auditLogFormat = 
auditLogFormat, systemTableName = Some(systemTableAudit)) + } + else { + val auditLogFormat = "delta" + AuditLogConfig(rawAuditPath = config.auditlogprefix_source_path, + auditLogFormat = auditLogFormat, systemTableName = Some(systemTableAudit), + sqlEndpoint = config.sql_endpoint) + } + } + + private def getAuditLogConfigForAwsGcp(config: MultiWorkspaceConfig): AuditLogConfig = { + val auditLogFormat = spark.conf.getOption("overwatch.auditlogformat").getOrElse("json") + AuditLogConfig(rawAuditPath = config.auditlogprefix_source_path, auditLogFormat = auditLogFormat) + } + + private def getAuditLogConfigForzure(config: MultiWorkspaceConfig): AuditLogConfig = { + val ehStatePath = s"${config.storage_prefix}/${config.workspace_id}/ehState" + val isAAD = config.aad_client_id.nonEmpty && + config.aad_tenant_id.nonEmpty && + config.aad_client_secret_key.nonEmpty && + config.eh_conn_string.nonEmpty + val azureLogConfig = if (isAAD) { + AzureAuditLogEventhubConfig(connectionString = config.eh_conn_string.get, eventHubName = config.eh_name.get + , auditRawEventsPrefix = ehStatePath, + azureClientId = Some(config.aad_client_id.get), + azureClientSecret = Some(dbutils.secrets.get(config.secret_scope, key = config.aad_client_secret_key.get)), + azureTenantId = Some(config.aad_tenant_id.get), + azureAuthEndpoint = config.aad_authority_endpoint.getOrElse("https://login.microsoftonline.com/") + ) + } else { + val ehConnString = s"{{secrets/${config.secret_scope}/${config.eh_scope_key.get}}}" + AzureAuditLogEventhubConfig(connectionString = ehConnString, eventHubName = config.eh_name.get, auditRawEventsPrefix = ehStatePath) + } + AuditLogConfig(azureAuditLogEventhubConfig = Some(azureLogConfig)) + } + + private def getAuditlogConfigs(config: MultiWorkspaceConfig): AuditLogConfig = { + if (config.auditlogprefix_source_path.getOrElse("").toLowerCase.equals("system")) { + getAuditLogConfigForSystemTable(config) + } else { + if(s"${config.cloud.toLowerCase()}" != "azure") { + getAuditLogConfigForAwsGcp(config) + } + else { + getAuditLogConfigForzure(config) + } + } + } + + + private def getProxyConfig(config: MultiWorkspaceConfig): ApiEnvConfig = { val apiProxyConfig = ApiProxyConfig(config.proxy_host, config.proxy_port, config.proxy_user_name, config.proxy_password_scope, config.proxy_password_key) val apiEnvConfig = ApiEnvConfig(config.success_batch_size.getOrElse(200), config.error_batch_size.getOrElse(500), @@ -416,7 +452,7 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { .as[MultiWorkspaceConfig] .filter(_.active) .collect() - if(multiWorkspaceConfig.length < 1){ + if (multiWorkspaceConfig.length < 1) { throw new BadConfigException("Config file has 0 record, config file:" + configLocation) } multiWorkspaceConfig @@ -426,7 +462,6 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { logger.log(Level.ERROR, fullMsg) throw e } - } @@ -479,6 +514,7 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { println("ParallelismLevel :" + parallelism) val multiWorkspaceConfig = generateMultiWorkspaceConfig(configLocation, deploymentId, outputPath) + snapshotConfig(multiWorkspaceConfig) val params = DeploymentValidation .performMandatoryValidation(multiWorkspaceConfig, parallelism) @@ -492,7 +528,7 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { deploymentReport.appendAll(deploymentReports) saveDeploymentReport(deploymentReport.toArray, multiWorkspaceConfig.head.storage_prefix, "deploymentReport") } catch { - case e: Exception => + case e: Throwable => val failMsg 
= s"FAILED DEPLOYMENT WITH EXCEPTION" println(failMsg) logger.log(Level.ERROR, failMsg, e) @@ -544,4 +580,6 @@ class MultiWorkspaceDeployment extends SparkSessionWrapper { returnParam } + + } \ No newline at end of file diff --git a/src/main/scala/com/databricks/labs/overwatch/ParamDeserializer.scala b/src/main/scala/com/databricks/labs/overwatch/ParamDeserializer.scala index 7c4af604d..356b783b0 100644 --- a/src/main/scala/com/databricks/labs/overwatch/ParamDeserializer.scala +++ b/src/main/scala/com/databricks/labs/overwatch/ParamDeserializer.scala @@ -101,6 +101,8 @@ class ParamDeserializer() extends StdDeserializer[OverwatchParams](classOf[Overw val rawAuditPath = getOptionString(masterNode, "auditLogConfig.rawAuditPath") val auditLogFormat = getOptionString(masterNode, "auditLogConfig.auditLogFormat").getOrElse("json") val azureEventHubNode = getNodeFromPath(masterNode, "auditLogConfig.azureAuditLogEventhubConfig") + val systemTableName = getOptionString(masterNode, "auditLogConfig.systemTableName") + val sqlEndpoint= getOptionString(masterNode, "auditLogConfig.sqlEndpoint") val azureAuditEventHubConfig = if (azureEventHubNode.nonEmpty) { val node = azureEventHubNode.get @@ -121,7 +123,8 @@ class ParamDeserializer() extends StdDeserializer[OverwatchParams](classOf[Overw None } - val auditLogConfig = AuditLogConfig(rawAuditPath, auditLogFormat, azureAuditEventHubConfig) + val auditLogConfig = AuditLogConfig(rawAuditPath, auditLogFormat, azureAuditEventHubConfig, systemTableName, + sqlEndpoint) val dataTarget = if (masterNode.has("dataTarget")) { Some(DataTarget( @@ -194,7 +197,7 @@ class ParamDeserializer() extends StdDeserializer[OverwatchParams](classOf[Overw } else { None } - + val sql_endpoint = getOptionString(masterNode, "sqlEndpoint") OverwatchParams( auditLogConfig, token, @@ -208,7 +211,8 @@ class ParamDeserializer() extends StdDeserializer[OverwatchParams](classOf[Overw workspace_name, externalizeOptimize, apiEnvConfig, - tempWorkingDir + tempWorkingDir, + sql_endpoint ) } } \ No newline at end of file diff --git a/src/main/scala/com/databricks/labs/overwatch/api/ApiCallV2.scala b/src/main/scala/com/databricks/labs/overwatch/api/ApiCallV2.scala index eea9af1f6..e4b8c35b5 100644 --- a/src/main/scala/com/databricks/labs/overwatch/api/ApiCallV2.scala +++ b/src/main/scala/com/databricks/labs/overwatch/api/ApiCallV2.scala @@ -1,6 +1,7 @@ package com.databricks.labs.overwatch.api import com.databricks.labs.overwatch.pipeline.PipelineFunctions +import com.databricks.labs.overwatch.utils.Helpers.deriveRawApiResponseDF import com.databricks.labs.overwatch.utils._ import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.scala.DefaultScalaModule @@ -61,7 +62,7 @@ object ApiCallV2 extends SparkSessionWrapper { * @return */ def apply(apiEnv: ApiEnv, apiName: String, queryMap: Map[String, String], tempSuccessPath: String - ): ApiCallV2 = { + ): ApiCallV2 = { new ApiCallV2(apiEnv) .setEndPoint(apiName) .buildMeta(apiName) @@ -341,6 +342,9 @@ class ApiCallV2(apiEnv: ApiEnv) extends SparkSessionWrapper { hibernate(response) execute() } else { + if(writeTraceApiFlag()){ + PipelineFunctions.writeMicroBatchToTempLocation(successTempPath.get, apiMeta.enrichAPIResponse(response,jsonQuery,queryMap)) + } throw new ApiCallFailure(response, buildGenericErrorMessage, debugFlag = false) } @@ -442,7 +446,7 @@ class ApiCallV2(apiEnv: ApiEnv) extends SparkSessionWrapper { case "GET" => response = try { - apiMeta.getBaseRequest() + apiMeta.getBaseRequest() .params(queryMap) 
.options(reqOptions) .asString @@ -509,17 +513,20 @@ class ApiCallV2(apiEnv: ApiEnv) extends SparkSessionWrapper { /** * Converting the API response to Dataframe. * - * @return Dataframe which is created from the API response. + * + * @withMetaData if this flag is set to true the returned dataframe will include both raw api response and the metadata of the api. + * if this flag is set to false the returned dataframe will only include raw api response. + * @return Dataframe which is created from the API response with the api call metadata. */ - def asDF(): DataFrame = { + def asDF(withMetaData: Boolean = false): DataFrame = { var apiResultDF: DataFrame = null; - if (_apiResponseArray.size == 0 && !apiMeta.storeInTempLocation) { //If response contains no Data. + if (_apiResponseArray.size == 0 && !apiMeta.batchPersist) { //If response contains no Data. val errMsg = s"API CALL Resulting DF is empty BUT no errors detected, progressing module. " + s"Details Below:\n$buildGenericErrorMessage" throw new ApiCallEmptyResponse(errMsg, true) - } else if (_apiResponseArray.size != 0 && successTempPath.isEmpty) { //If API response don't have pagination/volume of response is not huge then we directly convert the response which is in-memory to spark DF. + } else if (_apiResponseArray.size != 0 ) { //If API response don't have pagination/volume of response is not huge then we directly convert the response which is in-memory to spark DF. apiResultDF = spark.read.json(Seq(_apiResponseArray.toString).toDS()) - } else if (apiMeta.storeInTempLocation && successTempPath.nonEmpty) { //Read the response from the Temp location/Disk and convert it to Dataframe. + } else if (apiMeta.batchPersist && successTempPath.nonEmpty) { //Read the response from the Temp location/Disk and convert it to Dataframe. apiResultDF = try { spark.read.json(successTempPath.get) } catch { @@ -535,10 +542,17 @@ class ApiCallV2(apiEnv: ApiEnv) extends SparkSessionWrapper { logger.error(errMsg) spark.emptyDataFrame }else { - extrapolateSupportedStructure(apiResultDF) + val rawDf = if(withMetaData){ + apiResultDF + }else{ + deriveRawApiResponseDF(apiResultDF) + } + extrapolateSupportedStructure(rawDf) } } + + private def jsonQueryToApiErrorDetail(e: ApiCallFailure): String = { val mapper = new ObjectMapper() val jsonObject = mapper.readTree(jsonQuery); @@ -560,19 +574,27 @@ class ApiCallV2(apiEnv: ApiEnv) extends SparkSessionWrapper { * @return */ private def emptyDFCheck(apiResultDF: DataFrame): Boolean = { - if (apiResultDF.columns.length == 0) { //Check number of columns in result Dataframe - true - } else if (apiResultDF.columns.size == 1 && apiResultDF.columns.contains(apiMeta.paginationKey)) { //Check if only pagination key in present in the response - true - } else if (apiResultDF.columns.size == 1 && apiResultDF.columns.contains(apiMeta.emptyResponseColumn)) { //Check if only pagination key in present in the response + + val filteredDf = apiResultDF.select('rawResponse) + .filter('rawResponse =!= "{}") + if (filteredDf.isEmpty) { true - } - else { - false + } else { + val rawDF= filteredDf + .withColumn("rawResponse", SchemaTools.structFromJson(spark, apiResultDF, "rawResponse")) + .select("rawResponse.*") + if (rawDF.columns.length == 0) { //Check number of columns in result Dataframe + true + } else if (rawDF.columns.size == 1 && rawDF.columns.contains(apiMeta.paginationKey)) { //Check if only pagination key in present in the response + true + } else { + false + } } } + /** * Performs api calls in parallel. 
* @return @@ -581,8 +603,8 @@ class ApiCallV2(apiEnv: ApiEnv) extends SparkSessionWrapper { @tailrec def executeThreadedHelper(): util.ArrayList[String] = { val response = getResponse responseCodeHandler(response) - _apiResponseArray.add(response.body) - if (apiMeta.storeInTempLocation && successTempPath.nonEmpty) { + _apiResponseArray.add(apiMeta.enrichAPIResponse(response,jsonQuery,queryMap))//for GET request we have to convert queryMap to Json + if (apiMeta.batchPersist && successTempPath.nonEmpty) { accumulator.add(1) if (apiEnv.successBatchSize <= _apiResponseArray.size()) { //Checking if its right time to write the batches into persistent storage val responseFlag = PipelineFunctions.writeMicroBatchToTempLocation(successTempPath.get, _apiResponseArray.toString) @@ -595,11 +617,7 @@ class ApiCallV2(apiEnv: ApiEnv) extends SparkSessionWrapper { logger.log(Level.INFO, buildDetailMessage()) setPrintFinalStatsFlag(false) } - if (paginate(response.body)) { - executeThreadedHelper() - } else { - _apiResponseArray - } + if (paginate(response.body)) executeThreadedHelper() else _apiResponseArray } try { executeThreadedHelper() @@ -637,8 +655,8 @@ class ApiCallV2(apiEnv: ApiEnv) extends SparkSessionWrapper { @tailrec def executeHelper(): this.type = { val response = getResponse responseCodeHandler(response) - _apiResponseArray.add(response.body) - if (apiMeta.storeInTempLocation && successTempPath.nonEmpty) { + _apiResponseArray.add(apiMeta.enrichAPIResponse(response,jsonQuery,queryMap)) + if (apiMeta.batchPersist && successTempPath.nonEmpty) { if (apiEnv.successBatchSize <= _apiResponseArray.size()) { //Checking if its right time to write the batches into persistent storage val responseFlag = PipelineFunctions.writeMicroBatchToTempLocation(successTempPath.get, _apiResponseArray.toString) if (responseFlag) { //Clearing the resultArray in-case of successful write @@ -654,6 +672,7 @@ class ApiCallV2(apiEnv: ApiEnv) extends SparkSessionWrapper { } try { executeHelper() + this } catch { case e: java.lang.NoClassDefFoundError => { val excMsg = "DEPENDENCY MISSING: scalaj. Ensure that the proper scalaj library is attached to your cluster" @@ -674,9 +693,21 @@ class ApiCallV2(apiEnv: ApiEnv) extends SparkSessionWrapper { logger.log(Level.WARN, excMsg, e) throw e } + }finally { + if (writeTraceApiFlag()) { + PipelineFunctions.writeMicroBatchToTempLocation(successTempPath.get, _apiResponseArray.toString) + } } } + /** + * Function to check if traceability is enabled or not for API calls.. + * @return + */ + private def writeTraceApiFlag(): Boolean ={ + spark.conf.getOption("overwatch.traceapi").getOrElse("false").toBoolean && !apiMeta.batchPersist && successTempPath.nonEmpty + } + /** * Function to make parallel API calls. 
Currently this functions supports only SqlQueryHistory and ClusterEvents * @param endpoint @@ -684,12 +715,17 @@ class ApiCallV2(apiEnv: ApiEnv) extends SparkSessionWrapper { * @param config * @return */ - def makeParallelApiCalls(endpoint: String, jsonInput: Map[String, String], config: Config): String = { - val tempEndpointLocation = endpoint.replaceAll("/","") + def makeParallelApiCalls( + endpoint: String, + jsonInput: Map[String, String], + pipelineSnapTime: Long, + config: Config + ): String = { + val tempEndpointLocation = endpoint.replaceAll("/","_") val acc = sc.longAccumulator(tempEndpointLocation) val tmpSuccessPath = if(jsonInput.contains("tmp_success_path")) jsonInput.get("tmp_success_path").get - else s"${config.tempWorkingDir}/${tempEndpointLocation}/${System.currentTimeMillis()}" + else s"${config.tempWorkingDir}/${tempEndpointLocation}/${pipelineSnapTime}" val tmpErrorPath = if(jsonInput.contains("tmp_error_path")) jsonInput.get("tmp_error_path").get else s"${config.tempWorkingDir}/errors/${tempEndpointLocation}/${System.currentTimeMillis()}" @@ -712,12 +748,12 @@ class ApiCallV2(apiEnv: ApiEnv) extends SparkSessionWrapper { //call future val future = Future { - val apiObj = ApiCallV2( - config.apiEnv, - endpoint, - jsonQuery, - tempSuccessPath = tmpSuccessPath - ).executeMultiThread(acc) + val apiObj = ApiCallV2( + config.apiEnv, + endpoint, + jsonQuery, + tempSuccessPath = tmpSuccessPath + ).executeMultiThread(acc) synchronized { apiObj.forEach( diff --git a/src/main/scala/com/databricks/labs/overwatch/api/ApiMeta.scala b/src/main/scala/com/databricks/labs/overwatch/api/ApiMeta.scala index e41986a33..831b18a17 100644 --- a/src/main/scala/com/databricks/labs/overwatch/api/ApiMeta.scala +++ b/src/main/scala/com/databricks/labs/overwatch/api/ApiMeta.scala @@ -2,9 +2,11 @@ package com.databricks.labs.overwatch.api import com.databricks.dbutils_v1.DBUtilsHolder.dbutils import com.databricks.labs.overwatch.utils.ApiEnv -import com.fasterxml.jackson.databind.JsonNode +import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper} +import com.fasterxml.jackson.module.scala.DefaultScalaModule import org.apache.log4j.{Level, Logger} -import scalaj.http.{Http, HttpRequest} +import org.json.JSONObject +import scalaj.http.{Http, HttpRequest, HttpResponse} /** * Configuration for each API. 
@@ -16,7 +18,7 @@ trait ApiMeta { protected var _paginationToken: String = _ protected var _dataframeColumn: String = "*" protected var _apiCallType: String = _ - protected var _storeInTempLocation = false + protected var _batchPersist = false protected var _apiV = "api/2.0" protected var _isDerivePaginationLogic = false protected var _apiEnv: ApiEnv = _ @@ -33,7 +35,7 @@ trait ApiMeta { protected[overwatch] def apiCallType: String = _apiCallType - protected[overwatch] def storeInTempLocation: Boolean = _storeInTempLocation + protected[overwatch] def batchPersist : Boolean = _batchPersist protected[overwatch] def apiV: String = _apiV @@ -55,8 +57,8 @@ trait ApiMeta { this } - private[overwatch] def setStoreInTempLocation(value: Boolean): this.type = { - _storeInTempLocation = value + private[overwatch] def setBatchPersist(value: Boolean): this.type = { + _batchPersist = value this } @@ -147,7 +149,7 @@ trait ApiMeta { |paginationToken: ${paginationToken} |dataframeColumns: ${dataframeColumn} |apiCallType: ${apiCallType} - |storeInTempLocation: ${storeInTempLocation} + |storeInTempLocation: ${batchPersist} |apiV: ${apiV} |isDerivePaginationLogic: ${isDerivePaginationLogic} |""".stripMargin @@ -163,6 +165,32 @@ trait ApiMeta { Map[String, String]() } + /** + * Function will add the meta info to the api response. + * + * @param response + * @param jsonQuery + * @param queryMap + * @return a string containing the api response and the meta for the api call. + */ + private[overwatch] def enrichAPIResponse(response: HttpResponse[String], jsonQuery: String, queryMap: Map[String, String]): String = { + val filter: String = if (apiCallType.equals("POST")) jsonQuery else { + val mapper = new ObjectMapper() + mapper.registerModule(DefaultScalaModule) + mapper.writeValueAsString(queryMap) + } + val jsonObject = new JSONObject(); + val apiTraceabilityMeta = new JSONObject(); + apiTraceabilityMeta.put("endPoint", apiName) + apiTraceabilityMeta.put("type", apiCallType) + apiTraceabilityMeta.put("apiVersion", apiV) + apiTraceabilityMeta.put("responseCode", response.code) + apiTraceabilityMeta.put("batchKeyFilter", filter) + jsonObject.put("rawResponse", response.body.trim) + jsonObject.put("apiTraceabilityMeta", apiTraceabilityMeta) + jsonObject.toString + } + } /** @@ -221,7 +249,7 @@ class SqlQueryHistoryApi extends ApiMeta { setPaginationToken("next_page_token") setDataframeColumn("res") setApiCallType("GET") - setStoreInTempLocation(true) + setBatchPersist(true) setIsDerivePaginationLogic(true) private[overwatch] override def hasNextPage(jsonObject: JsonNode): Boolean = { @@ -311,7 +339,7 @@ class ClusterEventsApi extends ApiMeta { setPaginationToken("next_page") setDataframeColumn("events") setApiCallType("POST") - setStoreInTempLocation(true) + setBatchPersist(true) private[overwatch] override def getAPIJsonQuery(startValue: Long, endValue: Long,jsonInput: Map[String, String]): Map[String, String] = { val clusterIDs = jsonInput.get("cluster_ids").get.split(",").map(_.trim).toArray @@ -342,7 +370,7 @@ class JobRunsApi extends ApiMeta { setPaginationKey("has_more") setPaginationToken("next_page_token") setIsDerivePaginationLogic(true) - setStoreInTempLocation(true) + setBatchPersist(true) private[overwatch] override def hasNextPage(jsonObject: JsonNode): Boolean = { jsonObject.get(paginationKey).asBoolean() diff --git a/src/main/scala/com/databricks/labs/overwatch/env/Workspace.scala b/src/main/scala/com/databricks/labs/overwatch/env/Workspace.scala index 092d95fab..c31057404 100644 --- 
a/src/main/scala/com/databricks/labs/overwatch/env/Workspace.scala +++ b/src/main/scala/com/databricks/labs/overwatch/env/Workspace.scala @@ -3,6 +3,7 @@ package com.databricks.labs.overwatch.env import com.databricks.dbutils_v1.DBUtilsHolder.dbutils import com.databricks.labs.overwatch.api.ApiCallV2 import com.databricks.labs.overwatch.pipeline.PipelineFunctions +import com.databricks.labs.overwatch.utils.Helpers.deriveRawApiResponseDF import com.databricks.labs.overwatch.utils._ import org.apache.log4j.{Level, Logger} import org.apache.spark.sql.DataFrame @@ -10,11 +11,8 @@ import org.apache.spark.sql.functions._ import java.util import java.util.Collections -import java.util.concurrent.Executors import scala.collection.parallel.ForkJoinTaskSupport -import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future} import scala.concurrent.forkjoin.ForkJoinPool -import scala.util.{Failure, Success} /** @@ -57,14 +55,14 @@ class Workspace(config: Config) extends SparkSessionWrapper { * * @return */ - def getJobsDF: DataFrame = { - + def getJobsDF(apiTempPath: String): DataFrame = { val jobsEndpoint = "jobs/list" val query = Map( "limit" -> "25", "expand_tasks" -> "true" ) ApiCallV2(config.apiEnv, jobsEndpoint,query,2.1) + .setSuccessTempPath(apiTempPath) .execute() .asDF() .withColumn("organization_id", lit(config.organizationId)) @@ -78,9 +76,10 @@ class Workspace(config: Config) extends SparkSessionWrapper { */ def getConfig: Config = config - def getClustersDF: DataFrame = { + def getClustersDF(tempApiDir: String): DataFrame = { val clustersEndpoint = "clusters/list" ApiCallV2(config.apiEnv, clustersEndpoint) + .setSuccessTempPath(tempApiDir) .execute() .asDF() .withColumn("organization_id", lit(config.organizationId)) @@ -106,9 +105,10 @@ class Workspace(config: Config) extends SparkSessionWrapper { * * @return */ - def getPoolsDF: DataFrame = { + def getPoolsDF(tempApiDir: String): DataFrame = { val poolsEndpoint = "instance-pools/list" ApiCallV2(config.apiEnv, poolsEndpoint) + .setSuccessTempPath(tempApiDir) .execute() .asDF() .withColumn("organization_id", lit(config.organizationId)) @@ -119,9 +119,12 @@ class Workspace(config: Config) extends SparkSessionWrapper { * * @return */ - def getProfilesDF: DataFrame = { + def getProfilesDF(tempApiDir: String): DataFrame = { val profilesEndpoint = "instance-profiles/list" - ApiCallV2(config.apiEnv, profilesEndpoint).execute().asDF().withColumn("organization_id", lit(config.organizationId)) + ApiCallV2(config.apiEnv, profilesEndpoint) + .setSuccessTempPath(tempApiDir) + .execute() + .asDF().withColumn("organization_id", lit(config.organizationId)) } @@ -156,7 +159,11 @@ class Workspace(config: Config) extends SparkSessionWrapper { .withColumn("organization_id", lit(config.organizationId)) } - def getSqlQueryHistoryParallelDF(fromTime: TimeTypes, untilTime: TimeTypes): DataFrame = { + def getSqlQueryHistoryParallelDF(fromTime: TimeTypes, + untilTime: TimeTypes, + pipelineSnapTime: TimeTypes, + tmpSqlHistorySuccessPath: String, + tmpSqlHistoryErrorPath: String): DataFrame = { val sqlQueryHistoryEndpoint = "sql/history/queries" val untilTimeMs = untilTime.asUnixTimeMilli val fromTimeMs = fromTime.asUnixTimeMilli - (1000*60*60*24*2) //subtracting 2 days for running query merge @@ -168,19 +175,26 @@ class Workspace(config: Config) extends SparkSessionWrapper { "end_value" -> s"${untilTimeMs}", "increment_counter" -> "3600000", "final_response_count" -> s"${finalResponseCount}", - "result_key" -> "res" + "result_key" -> "res", + 
"tmp_success_path" -> tmpSqlHistorySuccessPath, + "tmp_error_path" -> tmpSqlHistoryErrorPath ) // calling function to make parallel API calls val apiCallV2Obj = new ApiCallV2(config.apiEnv) - val tmpSqlQueryHistorySuccessPath= apiCallV2Obj.makeParallelApiCalls(sqlQueryHistoryEndpoint, jsonInput, config) + val tmpSqlQueryHistorySuccessPath= apiCallV2Obj.makeParallelApiCalls(sqlQueryHistoryEndpoint, jsonInput, pipelineSnapTime.asUnixTimeMilli,config) logger.log(Level.INFO, " sql query history landing completed") if(Helpers.pathExists(tmpSqlQueryHistorySuccessPath)) { try { - spark.read.json(tmpSqlQueryHistorySuccessPath) - .select(explode(col("res")).alias("res")).select(col("res" + ".*")) - .withColumn("organization_id", lit(config.organizationId)) + val rawDF = deriveRawApiResponseDF(spark.read.json(tmpSqlQueryHistorySuccessPath)) + if (rawDF.columns.contains("res")) { + rawDF.select(explode(col("res")).alias("res")).select(col("res" + ".*")) + .withColumn("organization_id", lit(config.organizationId)) + } else { + logger.log(Level.INFO, s"""No Data is present for sql/query/history from - ${fromTimeMs} to - ${untilTimeMs}, res column not found in dataset""") + spark.emptyDataFrame + } } catch { case e: Throwable => throw new Exception(e) @@ -224,33 +238,37 @@ class Workspace(config: Config) extends SparkSessionWrapper { }) } - def getClusterLibraries: DataFrame = { + def getClusterLibraries(tempApiDir: String): DataFrame = { val libsEndpoint = "libraries/all-cluster-statuses" ApiCallV2(config.apiEnv, libsEndpoint) + .setSuccessTempPath(tempApiDir) .execute() .asDF() .withColumn("organization_id", lit(config.organizationId)) } - def getClusterPolicies: DataFrame = { + def getClusterPolicies(tempApiDir: String): DataFrame = { val policiesEndpoint = "policies/clusters/list" ApiCallV2(config.apiEnv, policiesEndpoint) + .setSuccessTempPath(tempApiDir) .execute() .asDF() .withColumn("organization_id", lit(config.organizationId)) } - def getTokens: DataFrame = { + def getTokens(tempApiDir: String): DataFrame = { val tokenEndpoint = "token/list" ApiCallV2(config.apiEnv, tokenEndpoint) + .setSuccessTempPath(tempApiDir) .execute() .asDF() .withColumn("organization_id", lit(config.organizationId)) } - def getGlobalInitScripts: DataFrame = { + def getGlobalInitScripts(tempApiDir: String): DataFrame = { val globalInitScEndpoint = "global-init-scripts" ApiCallV2(config.apiEnv, globalInitScEndpoint) + .setSuccessTempPath(tempApiDir) .execute() .asDF() .withColumn("organization_id", lit(config.organizationId)) @@ -260,7 +278,7 @@ class Workspace(config: Config) extends SparkSessionWrapper { * Function to get the the list of Job Runs * @return */ - def getJobRunsDF(fromTime: TimeTypes, untilTime: TimeTypes): DataFrame = { + def getJobRunsDF(fromTime: TimeTypes, untilTime: TimeTypes,tempWorkingDir: String): DataFrame = { val jobsRunsEndpoint = "jobs/runs/list" val jsonQuery = Map( "limit" -> "25", @@ -270,7 +288,6 @@ class Workspace(config: Config) extends SparkSessionWrapper { ) val acc = sc.longAccumulator("sqlQueryHistoryAccumulator") var apiResponseArray = Collections.synchronizedList(new util.ArrayList[String]()) - val tempWorkingDir = s"${config.tempWorkingDir}/jobrunslist_bronze/${System.currentTimeMillis()}" val apiObj = ApiCallV2(config.apiEnv, jobsRunsEndpoint, @@ -279,7 +296,7 @@ class Workspace(config: Config) extends SparkSessionWrapper { 2.1).executeMultiThread(acc) apiObj.forEach( - obj => if (obj.contains("runs")) { + obj => if (obj.contains("job_id")) { apiResponseArray.add(obj) } ) @@ 
-291,7 +308,8 @@ class Workspace(config: Config) extends SparkSessionWrapper { if(Helpers.pathExists(tempWorkingDir)) { try { spark.conf.set("spark.sql.caseSensitive", "true") - val df = spark.read.json(tempWorkingDir) + val baseDF = spark.read.json(tempWorkingDir) + val df = deriveRawApiResponseDF(baseDF) .select(explode(col("runs")).alias("runs")).select(col("runs" + ".*")) .withColumn("organization_id", lit(config.organizationId)) spark.conf.set("spark.sql.caseSensitive", "false") @@ -314,9 +332,10 @@ class Workspace(config: Config) extends SparkSessionWrapper { * a snapshot of actively defined warehouses is captured and used to fill in the blanks in the silver+ layers. * @return */ - def getWarehousesDF: DataFrame = { + def getWarehousesDF(tempApiDir: String): DataFrame = { val warehousesEndpoint = "sql/warehouses" ApiCallV2(config.apiEnv, warehousesEndpoint) + .setSuccessTempPath(tempApiDir) .execute() .asDF() .withColumn("organization_id", lit(config.organizationId)) diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/Bronze.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/Bronze.scala index 6d417eec5..b7e55e871 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/Bronze.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/Bronze.scala @@ -1,6 +1,7 @@ package com.databricks.labs.overwatch.pipeline import com.databricks.labs.overwatch.env.{Database, Workspace} +import com.databricks.labs.overwatch.utils.Helpers.deriveApiTempDir import com.databricks.labs.overwatch.utils.{CloneDetail, Config, Helpers, OverwatchScope} import org.apache.log4j.{Level, Logger} @@ -109,7 +110,7 @@ class Bronze(_workspace: Workspace, _database: Database, _config: Config) lazy private val appendJobsProcess: () => ETLDefinition = { () => ETLDefinition( - workspace.getJobsDF, + workspace.getJobsDF(deriveApiTempDir(config.tempWorkingDir,jobsSnapshotModule.moduleName,pipelineSnapTime)), Seq(cleanseRawJobsSnapDF(BronzeTargets.jobsSnapshotTarget.keys, config.runID)), append(BronzeTargets.jobsSnapshotTarget) ) @@ -119,7 +120,7 @@ class Bronze(_workspace: Workspace, _database: Database, _config: Config) lazy private val appendClustersAPIProcess: () => ETLDefinition = { () => ETLDefinition( - workspace.getClustersDF, + workspace.getClustersDF(deriveApiTempDir(config.tempWorkingDir,clustersSnapshotModule.moduleName,pipelineSnapTime)), Seq(cleanseRawClusterSnapDF), append(BronzeTargets.clustersSnapshotTarget) ) @@ -129,7 +130,7 @@ class Bronze(_workspace: Workspace, _database: Database, _config: Config) lazy private val appendPoolsProcess: () => ETLDefinition = { () => ETLDefinition( - workspace.getPoolsDF, + workspace.getPoolsDF(deriveApiTempDir(config.tempWorkingDir,poolsSnapshotModule.moduleName,pipelineSnapTime)), Seq(cleanseRawPoolsDF()), append(BronzeTargets.poolsSnapshotTarget) ) @@ -146,7 +147,9 @@ class Bronze(_workspace: Workspace, _database: Database, _config: Config) auditLogsModule.untilTime.asLocalDateTime, BronzeTargets.auditLogAzureLandRaw, config.runID, - config.organizationId + config.organizationId, + config.sqlEndpoint, + config.apiEnv ), append(BronzeTargets.auditLogsTarget) ) @@ -167,7 +170,8 @@ class Bronze(_workspace: Workspace, _database: Database, _config: Config) config.organizationId, database, BronzeTargets.clusterEventsErrorsTarget, - config + config, + clusterEventLogsModule.moduleName ) ), append(BronzeTargets.clusterEventsTarget) @@ -220,7 +224,7 @@ class Bronze(_workspace: Workspace, _database: Database, _config: Config) lazy 
private val appendLibsProcess: () => ETLDefinition = { () => ETLDefinition( - workspace.getClusterLibraries, + workspace.getClusterLibraries(deriveApiTempDir(config.tempWorkingDir,libsSnapshotModule.moduleName,pipelineSnapTime)), append(BronzeTargets.libsSnapshotTarget) ) } @@ -229,7 +233,7 @@ class Bronze(_workspace: Workspace, _database: Database, _config: Config) lazy private val appendPoliciesProcess: () => ETLDefinition = { () => ETLDefinition( - workspace.getClusterPolicies, + workspace.getClusterPolicies(deriveApiTempDir(config.tempWorkingDir,policiesSnapshotModule.moduleName,pipelineSnapTime)), append(BronzeTargets.policiesSnapshotTarget) ) } @@ -238,7 +242,7 @@ class Bronze(_workspace: Workspace, _database: Database, _config: Config) lazy private val appendInstanceProfileProcess: () => ETLDefinition = { () => ETLDefinition( - workspace.getProfilesDF, + workspace.getProfilesDF(deriveApiTempDir(config.tempWorkingDir,instanceProfileSnapshotModule.moduleName,pipelineSnapTime)), append(BronzeTargets.instanceProfilesSnapshotTarget) ) } @@ -247,7 +251,7 @@ class Bronze(_workspace: Workspace, _database: Database, _config: Config) lazy private val appendTokenProcess: () => ETLDefinition = { () => ETLDefinition( - workspace.getTokens, + workspace.getTokens(deriveApiTempDir(config.tempWorkingDir,tokenSnapshotModule.moduleName,pipelineSnapTime)), append(BronzeTargets.tokensSnapshotTarget) ) } @@ -256,7 +260,7 @@ class Bronze(_workspace: Workspace, _database: Database, _config: Config) lazy private val appendGlobalInitScProcess: () => ETLDefinition = { () => ETLDefinition( - workspace.getGlobalInitScripts, + workspace.getGlobalInitScripts(deriveApiTempDir(config.tempWorkingDir,globalInitScSnapshotModule.moduleName,pipelineSnapTime)), append(BronzeTargets.globalInitScSnapshotTarget) ) } @@ -265,7 +269,7 @@ class Bronze(_workspace: Workspace, _database: Database, _config: Config) lazy private val appendJobRunsProcess: () => ETLDefinition = { () => ETLDefinition( - workspace.getJobRunsDF(jobRunsSnapshotModule.fromTime, jobRunsSnapshotModule.untilTime), + workspace.getJobRunsDF(jobRunsSnapshotModule.fromTime, jobRunsSnapshotModule.untilTime,deriveApiTempDir(config.tempWorkingDir,jobRunsSnapshotModule.moduleName,pipelineSnapTime)), Seq(cleanseRawJobRunsSnapDF(BronzeTargets.jobRunsSnapshotTarget.keys, config.runID)), append(BronzeTargets.jobRunsSnapshotTarget) ) @@ -276,7 +280,7 @@ class Bronze(_workspace: Workspace, _database: Database, _config: Config) lazy private val appendWarehousesAPIProcess: () => ETLDefinition = { () => ETLDefinition ( - workspace.getWarehousesDF, + workspace.getWarehousesDF(deriveApiTempDir(config.tempWorkingDir,warehousesSnapshotModule.moduleName,pipelineSnapTime)), Seq(cleanseRawWarehouseSnapDF), append(BronzeTargets.warehousesSnapshotTarget) ) @@ -305,7 +309,7 @@ class Bronze(_workspace: Workspace, _database: Database, _config: Config) private def executeModules(): Unit = { config.overwatchScope.foreach { case OverwatchScope.audit => - if (config.cloudProvider == "azure") { + if (config.cloudProvider == "azure" && !config.auditLogConfig.systemTableName.isDefined) { landAzureAuditEvents() } auditLogsModule.execute(appendAuditLogsProcess) diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/BronzeTransforms.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/BronzeTransforms.scala index 74aea67d8..bfd732d53 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/BronzeTransforms.scala +++ 
b/src/main/scala/com/databricks/labs/overwatch/pipeline/BronzeTransforms.scala @@ -5,7 +5,7 @@ import com.databricks.labs.overwatch.api.{ApiCall, ApiCallV2} import com.databricks.labs.overwatch.env.Database import com.databricks.labs.overwatch.eventhubs.AadAuthInstance import com.databricks.labs.overwatch.pipeline.WorkflowsTransforms.{workflowsCleanseJobClusters, workflowsCleanseTasks} -import com.databricks.labs.overwatch.utils.Helpers.{getDatesGlob, removeTrailingSlashes} +import com.databricks.labs.overwatch.utils.Helpers.{deriveRawApiResponseDF, getDatesGlob, removeTrailingSlashes} import com.databricks.labs.overwatch.utils.SchemaTools.structFromJson import com.databricks.labs.overwatch.utils._ import com.fasterxml.jackson.databind.ObjectMapper @@ -21,11 +21,11 @@ import org.apache.spark.sql.{AnalysisException, Column, DataFrame} import org.apache.spark.util.SerializableConfiguration import java.time.LocalDateTime +import java.time.format.DateTimeFormatter +import java.sql.Timestamp import java.util.concurrent.Executors import scala.collection.parallel.ForkJoinTaskSupport import scala.concurrent.forkjoin.ForkJoinPool -import scala.concurrent.{ExecutionContext, ExecutionContextExecutor} - trait BronzeTransforms extends SparkSessionWrapper { @@ -331,114 +331,14 @@ trait BronzeTransforms extends SparkSessionWrapper { untilTime: LocalDateTime, auditRawLand: PipelineTable, overwatchRunID: String, - organizationId: String + organizationId: String, + sqlEndpoint: String, + apiEnv: ApiEnv ): DataFrame = { - val fromDT = fromTime.toLocalDate - val untilDT = untilTime.toLocalDate - if (cloudProvider == "azure") { - val azureAuditSourceFilters = 'Overwatch_RunID === lit(overwatchRunID) && 'organization_id === organizationId - val rawBodyLookup = auditRawLand.asDF - .filter(azureAuditSourceFilters) - - val requiredColumns : Array[Column] = Array( - col("category"), - col("version"), - col("timestamp"), - col("date"), - col("identity").alias("userIdentity"), - col("organization_id"), - col("properties.actionName"), - col("properties.logId"), - col("properties.requestId"), - col("properties.requestParams"), - col("properties.response"), - col("properties.serviceName"), - col("properties.sessionId"), - col("properties.sourceIPAddress"), - col("properties.userAgent"), - ) - - val schemaBuilders = auditRawLand.asDF - .filter(azureAuditSourceFilters) - .withColumn("parsedBody", structFromJson(spark, rawBodyLookup, "deserializedBody")) - .select(explode($"parsedBody.records").alias("streamRecord"), 'organization_id) - .selectExpr("streamRecord.*", "organization_id") - .withColumn("version", 'operationVersion) - .withColumn("time", 'time.cast("timestamp")) - .withColumn("timestamp", unix_timestamp('time) * 1000) - .withColumn("date", 'time.cast("date")) - .select(requiredColumns: _*) - - val baselineAuditLogs = auditRawLand.asDF - .filter(azureAuditSourceFilters) - .withColumn("parsedBody", structFromJson(spark, rawBodyLookup, "deserializedBody")) - .select(explode($"parsedBody.records").alias("streamRecord"), 'organization_id) - .selectExpr("streamRecord.*", "organization_id") - .withColumn("version", 'operationVersion) - .withColumn("time", 'time.cast("timestamp")) - .withColumn("timestamp", unix_timestamp('time) * 1000) - .withColumn("date", 'time.cast("date")) - .select(requiredColumns: _*) - .withColumn("userIdentity", structFromJson(spark, schemaBuilders, "userIdentity")) - .withColumn("requestParams", structFromJson(spark, schemaBuilders, "requestParams")) - - val auditDF = 
PipelineFunctions.cleanseCorruptAuditLogs(spark, baselineAuditLogs) - .withColumn("response", structFromJson(spark, schemaBuilders, "response")) - .withColumn("requestParamsJson", to_json('requestParams)) - .withColumn("hashKey", xxhash64('organization_id, 'timestamp, 'serviceName, 'actionName, 'requestId, 'requestParamsJson)) - .drop("logId", "requestParamsJson") - - auditDF - - } else { - - // inclusive from exclusive to - val datesGlob = if (fromDT == untilDT) { - Array(s"${auditLogConfig.rawAuditPath.get}/date=${fromDT.toString}") - } else { - getDatesGlob(fromDT, untilDT.plusDays(1)) // add one day to until to ensure intra-day audit logs prior to untilTS are captured. - .map(dt => s"${auditLogConfig.rawAuditPath.get}/date=${dt}") - .filter(Helpers.pathExists) - } - - val auditLogsFailureMsg = s"Audit Logs Module Failure: Audit logs are required to use Overwatch and no data " + - s"was found in the following locations: ${datesGlob.mkString(", ")}" - - if (datesGlob.nonEmpty) { - val rawDF = try { - spark.read.format(auditLogConfig.auditLogFormat).load(datesGlob: _*) - } catch { // corrupted audit logs with duplicate columns in the source - case e: AnalysisException if e.message.contains("Found duplicate column(s) in the data schema") => - spark.conf.set("spark.sql.caseSensitive", "true") - spark.read.format(auditLogConfig.auditLogFormat).load(datesGlob: _*) - } - // clean corrupted source audit logs even when there is only one of the duplicate columns in the source - // but still will conflict with the existing columns in the target - val cleanRawDF = PipelineFunctions.cleanseCorruptAuditLogs(spark, rawDF) - - val baseDF = if (auditLogConfig.auditLogFormat == "json") cleanRawDF else { - val rawDFWRPJsonified = cleanRawDF - .withColumn("requestParams", to_json('requestParams)) - rawDFWRPJsonified - .withColumn("requestParams", structFromJson(spark, rawDFWRPJsonified, "requestParams")) - } - - baseDF - // When globbing the paths, the date must be reconstructed and re-added manually - .withColumn("organization_id", lit(organizationId)) - .withColumn("requestParamsJson", to_json('requestParams)) - .withColumn("hashKey", xxhash64('organization_id, 'timestamp, 'serviceName, 'actionName, 'requestId, 'requestParamsJson)) - .drop("requestParamsJson") - .withColumn("filename", input_file_name) - .withColumn("filenameAR", split(input_file_name, "/")) - .withColumn("date", - split(expr("filter(filenameAR, x -> x like ('date=%'))")(0), "=")(1).cast("date")) - .drop("filenameAR") - .verifyMinimumSchema(Schema.auditMasterSchema) - } else { - throw new Exception(auditLogsFailureMsg) - } - } + if(auditLogConfig.systemTableName.isDefined) + getAuditLogsDfFromSystemTables(fromTime, untilTime, organizationId, auditLogConfig, apiEnv) + else + getAuditLogsDfFromCloud(auditLogConfig, cloudProvider, fromTime, untilTime, auditRawLand, overwatchRunID, organizationId) } private def buildClusterEventBatches(apiEnv: ApiEnv, @@ -504,12 +404,11 @@ trait BronzeTransforms extends SparkSessionWrapper { private def landClusterEvents(clusterIDs: Array[String], startTime: TimeTypes, endTime: TimeTypes, - apiEnv: ApiEnv, + pipelineSnapTime: Long, tmpClusterEventsSuccessPath: String, tmpClusterEventsErrorPath: String, config: Config) = { val finalResponseCount = clusterIDs.length - implicit val ec: ExecutionContextExecutor = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(config.apiEnv.threadPoolSize)) val clusterEventsEndpoint = "clusters/events" val lagTime = 600000 //10 minutes @@ -529,17 +428,19 @@ trait 
BronzeTransforms extends SparkSessionWrapper { // calling function to make parallel API calls val apiCallV2Obj = new ApiCallV2(config.apiEnv) - apiCallV2Obj.makeParallelApiCalls(clusterEventsEndpoint, jsonInput, config) + apiCallV2Obj.makeParallelApiCalls(clusterEventsEndpoint, jsonInput, pipelineSnapTime, config) logger.log(Level.INFO, " Cluster event landing completed") } private def processClusterEvents(tmpClusterEventsSuccessPath: String, organizationId: String, erroredBronzeEventsTarget: PipelineTable): DataFrame = { logger.log(Level.INFO, "COMPLETE: Cluster Events acquisition, building data") if (Helpers.pathExists(tmpClusterEventsSuccessPath)) { - if (spark.read.json(tmpClusterEventsSuccessPath).columns.contains("events")) { + val baseDF = spark.read.json(tmpClusterEventsSuccessPath) + val rawDf = deriveRawApiResponseDF(baseDF) + if (rawDf.columns.contains("events")) { try { val tdf = SchemaScrubber.scrubSchema( - spark.read.json(tmpClusterEventsSuccessPath) + rawDf .select(explode('events).alias("events")) .select(col("events.*")) ).scrubSchema @@ -594,7 +495,8 @@ trait BronzeTransforms extends SparkSessionWrapper { organizationId: String, database: Database, erroredBronzeEventsTarget: PipelineTable, - config: Config + config: Config, + apiEndpointTempDir: String )(clusterSnapshotDF: DataFrame): DataFrame = { val clusterIDs = getClusterIdsWithNewEvents(filteredAuditLogDF, clusterSnapshotDF) @@ -607,10 +509,11 @@ trait BronzeTransforms extends SparkSessionWrapper { val processingStartTime = System.currentTimeMillis(); logger.log(Level.INFO, "Calling APIv2, Number of cluster id:" + clusterIDs.length + " run id :" + apiEnv.runID) - val tmpClusterEventsSuccessPath = s"${config.tempWorkingDir}/clusterEventsBronze/success" + apiEnv.runID - val tmpClusterEventsErrorPath = s"${config.tempWorkingDir}/clusterEventsBronze/error" + apiEnv.runID - landClusterEvents(clusterIDs, startTime, endTime, apiEnv, tmpClusterEventsSuccessPath, + val tmpClusterEventsSuccessPath = s"${config.tempWorkingDir}/${apiEndpointTempDir}/success_" + pipelineSnapTS.asUnixTimeMilli + val tmpClusterEventsErrorPath = s"${config.tempWorkingDir}/${apiEndpointTempDir}/error_" + pipelineSnapTS.asUnixTimeMilli + + landClusterEvents(clusterIDs, startTime, endTime, pipelineSnapTS.asUnixTimeMilli, tmpClusterEventsSuccessPath, tmpClusterEventsErrorPath, config) if (Helpers.pathExists(tmpClusterEventsErrorPath)) { persistErrors( @@ -1141,4 +1044,215 @@ trait BronzeTransforms extends SparkSessionWrapper { cleanDF } + private def fetchDatafromSystemTableAuditLog( + fromTimeSysTableCompatible: String, + untilTimeSysTableCompatible: String, + organizationId: String, + auditLogConfig: AuditLogConfig, + apiEnv: ApiEnv + ): DataFrame = { + try { + if(auditLogConfig.sqlEndpoint.getOrElse("").nonEmpty) { + val host = apiEnv.workspaceURL.stripPrefix("https://").stripSuffix("/") + val sqlEndpoint = auditLogConfig.sqlEndpoint.get + val query = + s"""select * from ${auditLogConfig.systemTableName.get.toString} + |where workspace_id='${organizationId}' + |and event_time >= '${fromTimeSysTableCompatible}' + |and event_time <= '${untilTimeSysTableCompatible}' + |""".stripMargin + logger.log(Level.INFO, query) + val systemTableNameDf = spark.read + .format("databricks") + .option("host", host) + .option("httpPath", sqlEndpoint) + .option("personalAccessToken", apiEnv.rawToken) + .option("query", query) + .load() + systemTableNameDf + } + else{ + spark.table(auditLogConfig.systemTableName.get.toString) + .filter('workspace_id === organizationId) + 
.filter('event_time >= fromTimeSysTableCompatible + && 'event_time <= untilTimeSysTableCompatible) + } + } catch { + case e: org.apache.spark.sql.AnalysisException => + throw new Exception(s"Issues while fetching data from system.access.audit: ${e.getMessage}") + } + } + + + def getAuditLogsDfFromSystemTables( + fromTime: LocalDateTime, + untilTime: LocalDateTime, + organizationId: String, + auditLogConfig: AuditLogConfig, + apiEnv: ApiEnv + ): DataFrame = { + try { + println(s"Fetching data from system.access.audit for workspace_id - ${organizationId}") + val sysTableFormat = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSS") + // Adding the below code to add time so that the whole day data can be fetched, despite of the fromTime and untilTime + val fromTimeSysTableCompatible = fromTime.withHour(0).withMinute(0).withSecond(0).format(sysTableFormat) + val untilTimeSysTableCompatible = untilTime.withHour(23).withMinute(59).withSecond(59).format(sysTableFormat) + println(s"system.access.audit fromTime - ${fromTimeSysTableCompatible}") + println(s"system.access.audit untilTime - ${untilTimeSysTableCompatible}") + + val rawSystemTableFiltered = fetchDatafromSystemTableAuditLog(fromTimeSysTableCompatible, + untilTimeSysTableCompatible, + organizationId, + auditLogConfig, + apiEnv + ) + + if (rawSystemTableFiltered.isEmpty) { + val message = s"No Data present in system.access.audit for organizationId: $organizationId " + + s"and fromTime: $fromTimeSysTableCompatible and untilTime: $untilTimeSysTableCompatible" + logger.log(Level.WARN, message) + throw new NoNewDataException(message, Level.WARN, allowModuleProgression = false) + } + + val isSqlEndpointEmpty = auditLogConfig.sqlEndpoint.getOrElse("").isEmpty + // if sql endpoint is not empty, the audit log data will be fetched from + // internal system tables then we need to convert the requestParams to json + val deriveRequestParams = if(isSqlEndpointEmpty) to_json(col("requestParams")) else col("requestParams") + + + val auditLogFromSysTable = SchemaTools.snakeToCamel(rawSystemTableFiltered) + .withColumn("organization_id", col("workspaceID")) + .withColumnRenamed("eventDate", "date") + .withColumn("timestamp", (col("eventTime").cast("double") * 1000).cast("long")) + .withColumn("requestParamsString", deriveRequestParams) + .drop("requestParams", "eventTime") + + // if sql endpoint is not empty, the audit log data will be derived from extenal system tables + // then userIdentity and response cols needs to be converted to struct from string + val deriveUserIdentity = if(isSqlEndpointEmpty) col("userIdentity") + else structFromJson(spark, auditLogFromSysTable, "userIdentity") + val deriveResponse = if(isSqlEndpointEmpty) col("response") + else structFromJson(spark, auditLogFromSysTable, "response") + + val auditLogFromSysTableToStruct = auditLogFromSysTable + .withColumn("requestParams", structFromJson(spark, auditLogFromSysTable, "requestParamsString")) + .withColumn("userIdentity", deriveUserIdentity) + .withColumn("response", deriveResponse) + .withColumn("hashKey", xxhash64('organization_id, 'timestamp, 'serviceName, 'actionName, 'requestId, 'requestParamsString)) + .verifyMinimumSchema(Schema.auditMasterSchema) + .drop("requestParamsString") + .withColumn("response", $"response".withField("statusCode", + coalesce($"response.statusCode", $"response.status_code".cast(LongType)))) + .withColumn("response", $"response".withField("errorMessage", + coalesce($"response.errorMessage", $"response.error_message"))) + .withColumn("response", 
struct($"response.statusCode", $"response.errorMessage", $"response.result")) + + auditLogFromSysTableToStruct + } catch { + case e: org.apache.spark.sql.AnalysisException => + throw new Exception(s"Issues while fetching data from table system.access.audit: ${e.getMessage}") + case e: Exception => throw e + } + } + + + def getAuditLogsDfFromCloud(auditLogConfig: AuditLogConfig, + cloudProvider: String, + fromTime: LocalDateTime, + untilTime: LocalDateTime, + auditRawLand: PipelineTable, + overwatchRunID: String, + organizationId: String + ): DataFrame = { + val fromDT = fromTime.toLocalDate + val untilDT = untilTime.toLocalDate + if (cloudProvider == "azure") { + val azureAuditSourceFilters = 'Overwatch_RunID === lit(overwatchRunID) && 'organization_id === organizationId + val rawBodyLookup = auditRawLand.asDF + .filter(azureAuditSourceFilters) + val schemaBuilders = auditRawLand.asDF + .filter(azureAuditSourceFilters) + .withColumn("parsedBody", structFromJson(spark, rawBodyLookup, "deserializedBody")) + .select(explode($"parsedBody.records").alias("streamRecord"), 'organization_id) + .selectExpr("streamRecord.*", "organization_id") + .withColumn("version", 'operationVersion) + .withColumn("time", 'time.cast("timestamp")) + .withColumn("timestamp", unix_timestamp('time) * 1000) + .withColumn("date", 'time.cast("date")) + .select('category, 'version, 'timestamp, 'date, 'properties, 'identity.alias("userIdentity"), 'organization_id) + .selectExpr("*", "properties.*").drop("properties") + + + val baselineAuditLogs = auditRawLand.asDF + .filter(azureAuditSourceFilters) + .withColumn("parsedBody", structFromJson(spark, rawBodyLookup, "deserializedBody")) + .select(explode($"parsedBody.records").alias("streamRecord"), 'organization_id) + .selectExpr("streamRecord.*", "organization_id") + .withColumn("version", 'operationVersion) + .withColumn("time", 'time.cast("timestamp")) + .withColumn("timestamp", unix_timestamp('time) * 1000) + .withColumn("date", 'time.cast("date")) + .select('category, 'version, 'timestamp, 'date, 'properties, 'identity.alias("userIdentity"), 'organization_id) + .withColumn("userIdentity", structFromJson(spark, schemaBuilders, "userIdentity")) + .selectExpr("*", "properties.*").drop("properties") + .withColumn("requestParams", structFromJson(spark, schemaBuilders, "requestParams")) + + PipelineFunctions.cleanseCorruptAuditLogs(spark, baselineAuditLogs) + .withColumn("response", structFromJson(spark, schemaBuilders, "response")) + .withColumn("requestParamsJson", to_json('requestParams)) + .withColumn("hashKey", xxhash64('organization_id, 'timestamp, 'serviceName, 'actionName, 'requestId, 'requestParamsJson)) + .drop("logId", "requestParamsJson") + + } else { + + // inclusive from exclusive to + val datesGlob = if (fromDT == untilDT) { + Array(s"${auditLogConfig.rawAuditPath.get}/date=${fromDT.toString}") + } else { + getDatesGlob(fromDT, untilDT.plusDays(1)) // add one day to until to ensure intra-day audit logs prior to untilTS are captured. 
+ .map(dt => s"${auditLogConfig.rawAuditPath.get}/date=${dt}") + .filter(Helpers.pathExists) + } + + val auditLogsFailureMsg = s"Audit Logs Module Failure: Audit logs are required to use Overwatch and no data " + + s"was found in the following locations: ${datesGlob.mkString(", ")}" + + if (datesGlob.nonEmpty) { + val rawDF = try { + spark.read.format(auditLogConfig.auditLogFormat).load(datesGlob: _*) + } catch { // corrupted audit logs with duplicate columns in the source + case e: AnalysisException if e.message.contains("Found duplicate column(s) in the data schema") => + spark.conf.set("spark.sql.caseSensitive", "true") + spark.read.format(auditLogConfig.auditLogFormat).load(datesGlob: _*) + } + // clean corrupted source audit logs even when there is only one of the duplicate columns in the source + // but still will conflict with the existing columns in the target + val cleanRawDF = PipelineFunctions.cleanseCorruptAuditLogs(spark, rawDF) + + val baseDF = if (auditLogConfig.auditLogFormat == "json") cleanRawDF else { + val rawDFWRPJsonified = cleanRawDF + .withColumn("requestParams", to_json('requestParams)) + rawDFWRPJsonified + .withColumn("requestParams", structFromJson(spark, rawDFWRPJsonified, "requestParams")) + } + + baseDF + // When globbing the paths, the date must be reconstructed and re-added manually + .withColumn("organization_id", lit(organizationId)) + .withColumn("requestParamsJson", to_json('requestParams)) + .withColumn("hashKey", xxhash64('organization_id, 'timestamp, 'serviceName, 'actionName, 'requestId, 'requestParamsJson)) + .drop("requestParamsJson") + .withColumn("filename", input_file_name) + .withColumn("filenameAR", split(input_file_name, "/")) + .withColumn("date", + split(expr("filter(filenameAR, x -> x like ('date=%'))")(0), "=")(1).cast("date")) + .drop("filenameAR") + .verifyMinimumSchema(Schema.auditMasterSchema) + } else { + throw new Exception(auditLogsFailureMsg) + } + } + } + + } \ No newline at end of file diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/DbsqlTransforms.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/DbsqlTransforms.scala index 27e15db20..419aeed18 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/DbsqlTransforms.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/DbsqlTransforms.scala @@ -58,9 +58,11 @@ object DbsqlTransforms extends SparkSessionWrapper { * @param warehouseBaseWMetaDF * @return */ - def deriveWarehouseBaseFilled(isFirstRun: Boolean, bronzeWarehouseSnapUntilCurrent: DataFrame) + def deriveWarehouseBaseFilled(isFirstRun: Boolean, + bronzeWarehouseSnapUntilCurrent: DataFrame, + warehouseSpecSilver: PipelineTable) (warehouseBaseWMetaDF: DataFrame): DataFrame = { - val result = if (isFirstRun) { + val result = if (isFirstRun || warehouseSpecSilver.exists(dataValidation = true)) { val firstRunMsg = "Silver_WarehouseSpec -- First run detected, will impute warehouse state from bronze to derive " + "current initial state for all existing warehouses." 
logger.log(Level.INFO, firstRunMsg) @@ -161,11 +163,16 @@ object DbsqlTransforms extends SparkSessionWrapper { 'warehouse_type ) - val auditLogDfWithStructs = auditLogDf + val rawAuditLogDf = auditLogDf .filter('actionName.isin("createEndpoint", "editEndpoint", "createWarehouse", "editWarehouse", "deleteEndpoint", "deleteWarehouse") && responseSuccessFilter && 'serviceName === "databrickssql") + + if(rawAuditLogDf.isEmpty) + throw new NoNewDataException("No New Data", Level.INFO, allowModuleProgression = true) + + val auditLogDfWithStructs = rawAuditLogDf .selectExpr("*", "requestParams.*").drop("requestParams", "Overwatch_RunID") .select(warehouseSummaryCols: _*) diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/Gold.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/Gold.scala index 999af2916..08d5c6aca 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/Gold.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/Gold.scala @@ -306,7 +306,7 @@ class Gold(_workspace: Workspace, _database: Database, _config: Config) ) } - lazy private[overwatch] val notebookCommandsFactModule = Module(3019, "Gold_NotebookCommands", this, Array(1004,3004,3005)) + lazy private[overwatch] val notebookCommandsFactModule = Module(3019, "Gold_NotebookCommands", this, Array(1004,3004,3005),6.0, shuffleFactor = 4.0) lazy private val appendNotebookCommandsFactProcess: () => ETLDefinition = { () => ETLDefinition( diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/GoldTransforms.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/GoldTransforms.scala index ac2420b5c..3f41ab076 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/GoldTransforms.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/GoldTransforms.scala @@ -499,11 +499,16 @@ trait GoldTransforms extends SparkSessionWrapper { clsfIncrementalDF : DataFrame, )(auditIncrementalDF: DataFrame): DataFrame = { + if (auditIncrementalDF.isEmpty || notebook.asDF.isEmpty || clsfIncrementalDF.isEmpty) { + throw new NoNewDataException("No New Data", Level.WARN, true) + } + val auditDF_base = auditIncrementalDF .filter(col("serviceName") === "notebook" && col("actionName") === "runCommand") .selectExpr("*", "requestParams.*").drop("requestParams") if (auditDF_base.columns.contains("executionTime")){ + val notebookLookupTSDF = notebook.asDF .select("organization_id", "notebook_id", "notebook_path", "notebook_name", "unixTimeMS", "date") .withColumnRenamed("notebook_id", "notebookId") @@ -710,7 +715,7 @@ trait GoldTransforms extends SparkSessionWrapper { .withColumn("db_id_in_job", when(isDatabricksJob && 'db_run_id.isNull, extractDBIdInJob('jobGroupAr)) .otherwise( - when(isAutomatedCluster && 'db_run_id.isNull, extractDBJobId('cluster_name)) + when(isAutomatedCluster && 'db_run_id.isNull, extractDBIdInJob('cluster_name)) .otherwise('db_run_id) ) ) @@ -816,7 +821,7 @@ trait GoldTransforms extends SparkSessionWrapper { // when there is no input data break out of module, progress timeline and continue with pipeline val emptyMsg = s"No new streaming data found." 
- if (streamRawDF.isEmpty) throw new NoNewDataException(emptyMsg, Level.WARN, allowModuleProgression = true) + if (streamRawDF.filter('progress.isNotNull).isEmpty) throw new NoNewDataException(emptyMsg, Level.WARN, allowModuleProgression = true) val lastStreamValue = Window.partitionBy('organization_id, 'SparkContextId, 'clusterId, 'stream_id, 'stream_run_id).orderBy('stream_timestamp) val onlyOnceEventGuaranteeW = Window.partitionBy(streamTargetKeys map col: _*).orderBy('fileCreateEpochMS.desc) diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/Initializer.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/Initializer.scala index 40133a9c6..bb8ba41c6 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/Initializer.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/Initializer.scala @@ -191,6 +191,18 @@ object Initializer extends SparkSessionWrapper { ) } + def apply(overwatchArgs: String, disableValidations: Boolean,initializeDatabase: Boolean,organizationID : Option[String]): Workspace = { + apply( + overwatchArgs, + debugFlag = false, + isSnap = false, + disableValidations, + initializeDatabase, + apiURL = None, + organizationID + ) + } + /** * * @param overwatchArgs Json string of args -- When passing into args in Databricks job UI, the json string must diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/InitializerFunctions.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/InitializerFunctions.scala index cecb4c704..4c6ff8fb5 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/InitializerFunctions.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/InitializerFunctions.scala @@ -2,11 +2,19 @@ package com.databricks.labs.overwatch.pipeline import com.databricks.dbutils_v1.DBUtilsHolder.dbutils import com.databricks.labs.overwatch.env.Database +import com.databricks.labs.overwatch.utils.Helpers.{getCurrentCatalogName, setCurrentCatalog, spark} import com.databricks.labs.overwatch.utils.OverwatchScope.{OverwatchScope, _} import com.databricks.labs.overwatch.utils._ import org.apache.log4j.{Level, Logger} +import org.apache.spark.sql.catalyst.dsl.expressions.{DslExpression, DslSymbol} +import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.functions.{lit, rank, row_number} import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.sql.{Column, DataFrame, SparkSession} +import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.functions._ + trait InitializerFunctions extends SparkSessionWrapper { @@ -95,9 +103,14 @@ trait InitializerFunctions // Audit logs are required and paramount to Overwatch delivery -- they must be present and valid /** Validate and set Audit Log Configs */ val rawAuditLogConfig = rawParams.auditLogConfig - val validatedAuditLogConfig = validateAuditLogConfigs(rawAuditLogConfig) + + val validatedAuditLogConfig = validateAuditLogConfigs(rawAuditLogConfig, config.organizationId, + config.workspaceURL, config.sqlEndpoint, validatedTokenSecret) config.setAuditLogConfig(validatedAuditLogConfig) + // check if it is a valid migration from cloud to system table + checkSystemTableMigrationValidity(config) + // must happen AFTER data target validation // persistent location for corrupted spark event log files /** Validate and set Bad Records Path */ @@ -400,69 +413,26 @@ trait InitializerFunctions } @throws(classOf[BadConfigException]) - def 
validateAuditLogConfigs(auditLogConfig: AuditLogConfig): AuditLogConfig = { - - if (disableValidations) { + def validateAuditLogConfigs(auditLogConfig: AuditLogConfig, + organization_id: String, + workspace_url: String, + sql_endpoint: String, + token: Option[TokenSecret]): AuditLogConfig = { + if (disableValidations) { //need to double check this quickBuildAuditLogConfig(auditLogConfig) } else { - if (config.cloudProvider != "azure") { - - val auditLogPath = auditLogConfig.rawAuditPath - val auditLogFormat = auditLogConfig.auditLogFormat.toLowerCase.trim - if (config.overwatchScope.contains(audit) && auditLogPath.isEmpty) { - throw new BadConfigException("Audit cannot be in scope without the 'auditLogPath' being set. ") - } - - if (auditLogPath.nonEmpty) - dbutils.fs.ls(auditLogPath.get).foreach(auditFolder => { - if (auditFolder.isDir) require(auditFolder.name.startsWith("date="), s"Audit directory must contain " + - s"partitioned date folders in the format of ${auditLogPath.get}/date=. Received ${auditFolder} instead.") - }) - - val supportedAuditLogFormats = Array("json", "parquet", "delta") - if (!supportedAuditLogFormats.contains(auditLogFormat)) { - throw new BadConfigException(s"Audit Log Format: Supported formats are ${supportedAuditLogFormats.mkString(",")} " + - s"but $auditLogFormat was placed in teh configuration. Please select a supported audit log format.") - } - - val finalAuditLogPath = if (auditLogPath.get.endsWith("/")) auditLogPath.get.dropRight(1) else auditLogPath.get - - // return validated audit log config for aws - auditLogConfig.copy(rawAuditPath = Some(finalAuditLogPath), auditLogFormat = auditLogFormat) - + if (ifFetchFromSystemTable(auditLogConfig)){ + validateAuditLogConfigsFromSystemTable(auditLogConfig,organization_id, workspace_url, token) } else { - val ehConfigOp = auditLogConfig.azureAuditLogEventhubConfig - require(ehConfigOp.nonEmpty, "When using Azure, an Eventhub must be configured for audit log retrieval") - val ehConfig = ehConfigOp.get - val ehPrefix = ehConfig.auditRawEventsPrefix - - val cleanPrefix = if (ehPrefix.endsWith("/")) ehPrefix.dropRight(1) else ehPrefix - val rawEventsCheckpoint = ehConfig.auditRawEventsChk.getOrElse(s"${ehPrefix}/rawEventsCheckpoint") - // TODO -- Audit log bronze is no longer streaming target -- remove this path - val auditLogBronzeChk = ehConfig.auditLogChk.getOrElse(s"${ehPrefix}/auditLogBronzeCheckpoint") - - if (config.debugFlag) { - println("DEBUG FROM Init") - println(s"cleanPrefix = ${cleanPrefix}") - println(s"rawEventsCheck = ${rawEventsCheckpoint}") - println(s"auditLogsBronzeChk = ${auditLogBronzeChk}") - println(s"ehPrefix = ${ehPrefix}") - } - - val ehFinalConfig = ehConfig.copy( - auditRawEventsPrefix = cleanPrefix, - auditRawEventsChk = Some(rawEventsCheckpoint), - auditLogChk = Some(auditLogBronzeChk) - ) - - // parse the connection string to validate format - PipelineFunctions.parseAndValidateEHConnectionString(ehFinalConfig.connectionString, ehFinalConfig.azureClientId.isEmpty) - // return validated auditLogConfig for Azure - auditLogConfig.copy(azureAuditLogEventhubConfig = Some(ehFinalConfig)) + validateAuditLogConfigsFromCloud(auditLogConfig) } } } + def ifFetchFromSystemTable(auditLogConfig: AuditLogConfig): Boolean = { + auditLogConfig.rawAuditPath.getOrElse("").toLowerCase.equals("system") + } + /** * defaults temp working dir to etlTargetPath/organizationId * this is important to minimize bandwidth issues @@ -536,4 +506,135 @@ trait InitializerFunctions */ def 
validateAndSetDataTarget(dataTarget: DataTarget): Unit + def validateSqlEndpoint(sqlEndpoint: String, organizationId: String): String = { + val pattern = """/sql/1\.0/warehouses/.*""".r + sqlEndpoint.trim match { + case pattern(_*) => sqlEndpoint + case "" => sqlEndpoint + case _ => throw new Exception(s"Invalid sqlEndpoint for organizationId: $organizationId ") + } + } + def validateAuditLogConfigsFromSystemTable(auditLogConfig: AuditLogConfig, + organizationId: String, + workspace_url: String, + token: Option[TokenSecret]): AuditLogConfig = { + val auditLogFormat = "delta" + val systemTableName = auditLogConfig.systemTableName.get + + val sqlEndpoint = validateSqlEndpoint(auditLogConfig.sqlEndpoint.getOrElse(""),organizationId) + if(sqlEndpoint.isEmpty) { + val systemTableNameDf = spark.table(systemTableName).filter(s"workspace_id = '$organizationId'").limit(1) + if (systemTableNameDf.isEmpty) + throw new Exception(s"No data found in ${systemTableName} for organizationId: $organizationId ") + auditLogConfig.copy(auditLogFormat=auditLogFormat,systemTableName = Some(systemTableName)) + } + else { + val host = workspace_url.stripPrefix("https://").stripSuffix("/") + val scope = token.get.scope + val key = token.get.key + val rawToken = dbutils.secrets.get(scope, key) + val systemTableNameDf = spark.read + .format("databricks") + .option("host", host) + .option("httpPath", sqlEndpoint) + .option("personalAccessToken", rawToken) + .option("query", s"select * from ${systemTableName} where workspace_id='${organizationId}' limit 1") + .load() + if (systemTableNameDf.isEmpty) + throw new Exception(s"No data found in ${systemTableName} for organizationId: $organizationId ") + auditLogConfig.copy(auditLogFormat=auditLogFormat,systemTableName = Some(systemTableName),sqlEndpoint = Some(sqlEndpoint)) + } + } + + def validateAuditLogConfigsFromCloud(auditLogConfig: AuditLogConfig): AuditLogConfig = { + if (config.cloudProvider != "azure") { + + val auditLogPath = auditLogConfig.rawAuditPath + val auditLogFormat = auditLogConfig.auditLogFormat.toLowerCase.trim + if (config.overwatchScope.contains(audit) && auditLogPath.isEmpty) { + throw new BadConfigException("Audit cannot be in scope without the 'auditLogPath' being set. ") + } + + if (auditLogPath.nonEmpty) + dbutils.fs.ls(auditLogPath.get).foreach(auditFolder => { + if (auditFolder.isDir) require(auditFolder.name.startsWith("date="), s"Audit directory must contain " + + s"partitioned date folders in the format of ${auditLogPath.get}/date=. Received ${auditFolder} instead.") + }) + + val supportedAuditLogFormats = Array("json", "parquet", "delta") + if (!supportedAuditLogFormats.contains(auditLogFormat)) { + throw new BadConfigException(s"Audit Log Format: Supported formats are ${supportedAuditLogFormats.mkString(",")} " + + s"but $auditLogFormat was placed in teh configuration. 
Please select a supported audit log format.") + } + + val finalAuditLogPath = if (auditLogPath.get.endsWith("/")) auditLogPath.get.dropRight(1) else auditLogPath.get + + // return validated audit log config for aws + auditLogConfig.copy(rawAuditPath = Some(finalAuditLogPath), auditLogFormat = auditLogFormat) + + } else { + val ehConfigOp = auditLogConfig.azureAuditLogEventhubConfig + require(ehConfigOp.nonEmpty, "When using Azure, an Eventhub must be configured for audit log retrieval") + val ehConfig = ehConfigOp.get + val ehPrefix = ehConfig.auditRawEventsPrefix + + val cleanPrefix = if (ehPrefix.endsWith("/")) ehPrefix.dropRight(1) else ehPrefix + val rawEventsCheckpoint = ehConfig.auditRawEventsChk.getOrElse(s"${ehPrefix}/rawEventsCheckpoint") + // TODO -- Audit log bronze is no longer streaming target -- remove this path + val auditLogBronzeChk = ehConfig.auditLogChk.getOrElse(s"${ehPrefix}/auditLogBronzeCheckpoint") + + if (config.debugFlag) { + println("DEBUG FROM Init") + println(s"cleanPrefix = ${cleanPrefix}") + println(s"rawEventsCheck = ${rawEventsCheckpoint}") + println(s"auditLogsBronzeChk = ${auditLogBronzeChk}") + println(s"ehPrefix = ${ehPrefix}") + } + + val ehFinalConfig = ehConfig.copy( + auditRawEventsPrefix = cleanPrefix, + auditRawEventsChk = Some(rawEventsCheckpoint), + auditLogChk = Some(auditLogBronzeChk) + ) + + // parse the connection string to validate format + PipelineFunctions.parseAndValidateEHConnectionString(ehFinalConfig.connectionString, ehFinalConfig.azureClientId.isEmpty) + // return validated auditLogConfig for Azure + auditLogConfig.copy(azureAuditLogEventhubConfig = Some(ehFinalConfig)) + } + } + + def checkSystemTableMigrationValidity(config: Config): Boolean = { + val workspaceID = config.organizationId + val latestConfigByOrg = Window.partitionBy(col("organization_id")).orderBy(col("Pipeline_SnapTS").desc) + val etlDB = config.databaseName + val initialCatalog = getCurrentCatalogName(spark) + val etlDBWithOutCatalog = if(etlDB.contains(".")){ + setCurrentCatalog(spark, etlDB.split("\\.").head) + etlDB.split("\\.").last + } else etlDB + + if(!spark.catalog.tableExists(s"${etlDBWithOutCatalog}.pipeline_report")) { + logger.log(Level.INFO, s"Since it is a first run no need to check for migration validity") + setCurrentCatalog(spark, initialCatalog) + return true + } + + val lastValue = spark.table(s"${etlDBWithOutCatalog}.pipeline_report") + .filter(col("status") === "SUCCESS") + .withColumn("rnk", rank().over(latestConfigByOrg)) + .withColumn("rn", row_number().over(latestConfigByOrg)) + .filter(col("rnk") === 1 && col("rn") === 1) + .filter(col("organization_id") === workspaceID) + .select("inputConfig.auditLogConfig.rawAuditPath").collect.map(x=>x(0)).mkString + + setCurrentCatalog(spark, initialCatalog) + val currentValue = config.auditLogConfig.rawAuditPath.getOrElse("") + if( lastValue == "system" && currentValue !="system" ) + throw new Exception(s"Cannot migrate from system table to cloud for organization_id - ${workspaceID}" + + s" Please use the same configuration as the last run") + else + true + } + } diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/Module.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/Module.scala index 4c52dabe9..229a3982e 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/Module.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/Module.scala @@ -1,9 +1,11 @@ package com.databricks.labs.overwatch.pipeline import 
com.databricks.labs.overwatch.pipeline.TransformFunctions._ +import com.databricks.labs.overwatch.utils.Helpers.{deriveApiTempDir, deriveApiTempErrDir, pathExists} import com.databricks.labs.overwatch.utils._ import org.apache.log4j.{Level, Logger} import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.{col, lit} import java.time.Duration import scala.util.parsing.json.JSON.number @@ -423,11 +425,62 @@ class Module( logger.log(Level.ERROR, msg, e) fail(msg) } finally { + if (spark.conf.getOption("overwatch.traceapi").getOrElse("false").toBoolean) { + persistApiEvents() + } spark.catalog.clearCache() } } + /** + * Function add necessary fields to the apiEventDetails and persist it. + * + * @param rawTraceDF + */ + private def transformAndPersistApiEvents(rawTraceDF: DataFrame) = { + val rawStructDF = rawTraceDF + .select("apiTraceabilityMeta.*", "rawResponse") + val batchKeyFilterOverride = if (rawStructDF.columns.contains("batchKeyFilter")) { + col("batchKeyFilter") + } else { + lit("") + } + + val finalDF = rawStructDF + .withColumn("batchKeyFilter", batchKeyFilterOverride) + .withColumn("data", col("rawResponse").cast("binary")) + .select("endPoint", "type", "apiVersion", "batchKeyFilter", "responseCode", "data") + .withColumn("moduleId", lit(moduleId)) + .withColumn("moduleName", lit(moduleName)) + .withColumn("organization_id", lit(config.organizationId)) + .withColumn("snapTS", lit(pipeline.pipelineSnapTime.asTSString)) + .withColumn("timestamp", lit(pipeline.pipelineSnapTime.asUnixTimeMilli)) + .withColumn("Overwatch_RunID", lit(config.runID)) + + val optimizeDF = PipelineFunctions.optimizeDFForWrite(finalDF, pipeline.apiEventsTarget) + pipeline.database.writeWithRetry(optimizeDF, pipeline.apiEventsTarget, pipeline.pipelineSnapTime.asColumnTS) + + } + + /** + * Function persists the apiEvent details. 
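+ * Concretely, it reads the raw API trace JSON landed under <tempWorkingDir>/<moduleName>/success_<pipelineSnapTime>
+ * (see Helpers.deriveApiTempDir) and appends it to the apiEventDetails target via transformAndPersistApiEvents.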
+ */ + private def persistApiEvents(): Unit = { + try { + val successPath = deriveApiTempDir(config.tempWorkingDir, moduleName, pipeline.pipelineSnapTime) + if (pathExists(successPath)) { + val rawTraceDF = spark.read.json(successPath) + transformAndPersistApiEvents(rawTraceDF) + } + + } catch { + case e: Throwable => + println("got exception while writing" + e.getMessage) + e.getMessage + } + } + } diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/PipelineTable.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/PipelineTable.scala index 381626c91..205bbebd2 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/PipelineTable.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/PipelineTable.scala @@ -42,7 +42,8 @@ case class PipelineTable( workspaceName: Boolean = true, isTemp: Boolean = false, checkpointPath: Option[String] = None, - masterSchema: Option[StructType] = None + masterSchema: Option[StructType] = None, + excludedReconColumn: Array[String] = Array() ) extends SparkSessionWrapper { private val logger: Logger = Logger.getLogger(this.getClass) diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/PipelineTargets.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/PipelineTargets.scala index 066568fcd..979152fec 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/PipelineTargets.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/PipelineTargets.scala @@ -19,6 +19,16 @@ abstract class PipelineTargets(config: Config) { "writeOpsMetrics, lastOptimizedTS, Pipeline_SnapTS, primordialDateString").split(", ") ) + val apiEventsTarget: PipelineTable = PipelineTable( + name = "apiEventDetails", + _keys = Array("organization_id", "Overwatch_RunID", "endPoint"), + config = config, + incrementalColumns = Array("Pipeline_SnapTS"), + partitionBy = Array("organization_id", "endPoint"), + statsColumns = ("organization_id,endPoint").split(", ") + ) + + lazy private[overwatch] val pipelineStateViewTarget: PipelineView = PipelineView( name = "pipReport", pipelineStateTarget, @@ -74,7 +84,8 @@ abstract class PipelineTargets(config: Config) { _permitDuplicateKeys = false, _mode = WriteMode.merge, mergeScope = MergeScope.insertOnly, - masterSchema = Some(Schema.auditMasterSchema) + masterSchema = Some(Schema.auditMasterSchema), + excludedReconColumn = Array("hashKey","response","requestParams","userIdentity")// It will be different for system table ) lazy private[overwatch] val auditLogAzureLandRaw: PipelineTable = PipelineTable( @@ -86,7 +97,7 @@ abstract class PipelineTargets(config: Config) { incrementalColumns = Array("Pipeline_SnapTS"), zOrderBy = Array("Overwatch_RunID"), withOverwatchRunID = if (config.cloudProvider == "azure") false else true, - checkpointPath = if (config.cloudProvider == "azure") + checkpointPath = if (config.cloudProvider == "azure" && !config.auditLogConfig.systemTableName.isDefined) config.auditLogConfig.azureAuditLogEventhubConfig.get.auditRawEventsChk else None ) @@ -294,7 +305,8 @@ abstract class PipelineTargets(config: Config) { incrementalColumns = Array("startEpochMS"), // don't load into gold until run is terminated zOrderBy = Array("runId", "jobId"), partitionBy = Seq("organization_id", "__overwatch_ctrl_noise"), - persistBeforeWrite = true + persistBeforeWrite = true, + excludedReconColumn = Array("requestDetails") //for system tables extra data are coming ) lazy private[overwatch] val accountLoginTarget: PipelineTable = PipelineTable( @@ -319,7 +331,8 @@ 
abstract class PipelineTargets(config: Config) { _keys = Array("timestamp", "cluster_id"), config, incrementalColumns = Array("timestamp"), - partitionBy = Seq("organization_id", "__overwatch_ctrl_noise") + partitionBy = Seq("organization_id", "__overwatch_ctrl_noise"), + excludedReconColumn = Array("timestamp")// It will be SnapTS in epoc ) lazy private[overwatch] val clusterStateDetailTarget: PipelineTable = PipelineTable( @@ -339,7 +352,8 @@ abstract class PipelineTargets(config: Config) { _mode = WriteMode.merge, incrementalColumns = Array("timestamp"), statsColumns = Array("instance_pool_id", "instance_pool_name", "node_type_id"), - partitionBy = Seq("organization_id") + partitionBy = Seq("organization_id"), + excludedReconColumn = Array("request_details") ) lazy private[overwatch] val dbJobsStatusTarget: PipelineTable = PipelineTable( @@ -347,7 +361,8 @@ abstract class PipelineTargets(config: Config) { _keys = Array("timestamp", "jobId", "actionName", "requestId"), config, incrementalColumns = Array("timestamp"), - partitionBy = Seq("organization_id", "__overwatch_ctrl_noise") + partitionBy = Seq("organization_id", "__overwatch_ctrl_noise"), + excludedReconColumn = Array("response") ) lazy private[overwatch] val notebookStatusTarget: PipelineTable = PipelineTable( @@ -365,7 +380,8 @@ abstract class PipelineTargets(config: Config) { _mode = WriteMode.merge, _permitDuplicateKeys = false, incrementalColumns = Array("query_start_time_ms"), - partitionBy = Seq("organization_id") + partitionBy = Seq("organization_id"), + excludedReconColumn = Array("Timestamp") //Timestamp is the pipelineSnapTs in epoc ) lazy private[overwatch] val warehousesSpecTarget: PipelineTable = PipelineTable( @@ -373,7 +389,8 @@ abstract class PipelineTargets(config: Config) { _keys = Array("timestamp", "warehouse_id"), config, incrementalColumns = Array("timestamp"), - partitionBy = Seq("organization_id") + partitionBy = Seq("organization_id"), + excludedReconColumn = Array("Timestamp") //Timestamp is the pipelineSnapTs in epoc ) } @@ -401,7 +418,8 @@ abstract class PipelineTargets(config: Config) { _mode = WriteMode.merge, incrementalColumns = Array("timestamp"), statsColumns = Array("instance_pool_id", "instance_pool_name", "node_type_id"), - partitionBy = Seq("organization_id") + partitionBy = Seq("organization_id"), + excludedReconColumn = Array("request_details") ) lazy private[overwatch] val poolsViewTarget: PipelineView = PipelineView( @@ -415,7 +433,8 @@ abstract class PipelineTargets(config: Config) { _keys = Array("job_id", "unixTimeMS", "action", "request_id"), config, incrementalColumns = Array("unixTimeMS"), - partitionBy = Seq("organization_id", "__overwatch_ctrl_noise") + partitionBy = Seq("organization_id", "__overwatch_ctrl_noise"), + excludedReconColumn = Array("response") ) lazy private[overwatch] val jobViewTarget: PipelineView = PipelineView( @@ -431,7 +450,8 @@ abstract class PipelineTargets(config: Config) { _mode = WriteMode.merge, zOrderBy = Array("job_id", "run_id"), incrementalColumns = Array("startEpochMS"), - partitionBy = Seq("organization_id", "__overwatch_ctrl_noise") + partitionBy = Seq("organization_id", "__overwatch_ctrl_noise"), + excludedReconColumn = Array("request_detail") ) lazy private[overwatch] val jobRunsViewTarget: PipelineView = PipelineView( @@ -509,7 +529,8 @@ abstract class PipelineTargets(config: Config) { partitionBy = Seq("organization_id", "state_start_date", "__overwatch_ctrl_noise"), maxMergeScanDates = 31, // 1 greater than clusterStateDetail 
incrementalColumns = Array("state_start_date", "unixTimeMS_state_start"), - zOrderBy = Array("cluster_id", "unixTimeMS_state_start") + zOrderBy = Array("cluster_id", "unixTimeMS_state_start"), + excludedReconColumn = Array("driverSpecs","workerSpecs") //driverSpecs and workerSpecs contains PipelineSnapTs and runID ) lazy private[overwatch] val clusterStateFactViewTarget: PipelineView = PipelineView( diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/Schema.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/Schema.scala index 9b74162ae..416170c39 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/Schema.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/Schema.scala @@ -1135,7 +1135,8 @@ object Schema extends SparkSessionWrapper { StructField("aad_tenant_id", StringType, nullable = true), StructField("aad_client_id", StringType, nullable = true), StructField("aad_client_secret_key", StringType, nullable = true), - StructField("aad_authority_endpoint", StringType, nullable = true) + StructField("aad_authority_endpoint", StringType, nullable = true), + StructField("sql_endpoint", StringType, nullable = true) )) val mountMinimumSchema: StructType = StructType(Seq( diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/Silver.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/Silver.scala index 7204bcc5f..6a99e0182 100644 --- a/src/main/scala/com/databricks/labs/overwatch/pipeline/Silver.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/Silver.scala @@ -1,6 +1,7 @@ package com.databricks.labs.overwatch.pipeline import com.databricks.labs.overwatch.env.{Database, Workspace} +import com.databricks.labs.overwatch.utils.Helpers.{deriveApiTempDir, deriveApiTempErrDir} import com.databricks.labs.overwatch.utils.{Config, OverwatchScope} import org.apache.log4j.Logger @@ -329,7 +330,9 @@ class Silver(_workspace: Workspace, _database: Database, _config: Config) ), Seq(buildClusterStateDetail( clusterStateDetailModule.untilTime, - BronzeTargets.auditLogsTarget.asIncrementalDF(clusterSpecModule, BronzeTargets.auditLogsTarget.incrementalColumns,1) //Added to get the Removed Cluster + BronzeTargets.auditLogsTarget.asIncrementalDF(clusterSpecModule, BronzeTargets.auditLogsTarget.incrementalColumns,1), //Added to get the Removed Cluster, + SilverTargets.dbJobRunsTarget.asIncrementalDF(clusterStateDetailModule, SilverTargets.dbJobRunsTarget.incrementalColumns, 30), + SilverTargets.clustersSpecTarget )), append(SilverTargets.clusterStateDetailTarget) ) @@ -369,7 +372,12 @@ class Silver(_workspace: Workspace, _database: Database, _config: Config) lazy private val appendSqlQueryHistoryProcess: () => ETLDefinition = { () => ETLDefinition( - workspace.getSqlQueryHistoryParallelDF(sqlQueryHistoryModule.fromTime, sqlQueryHistoryModule.untilTime), + workspace.getSqlQueryHistoryParallelDF( + sqlQueryHistoryModule.fromTime, + sqlQueryHistoryModule.untilTime, + sqlQueryHistoryModule.pipeline.pipelineSnapTime, + deriveApiTempDir(config.tempWorkingDir, sqlQueryHistoryModule.moduleName, pipelineSnapTime), + deriveApiTempErrDir(config.tempWorkingDir, sqlQueryHistoryModule.moduleName, pipelineSnapTime)), Seq(enhanceSqlQueryHistory), append(SilverTargets.sqlQueryHistoryTarget) ) diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/SilverTransforms.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/SilverTransforms.scala index f348aaa04..fc458032f 100644 --- 
a/src/main/scala/com/databricks/labs/overwatch/pipeline/SilverTransforms.scala +++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/SilverTransforms.scala @@ -998,20 +998,23 @@ trait SilverTransforms extends SparkSessionWrapper { def buildClusterStateDetail( untilTime: TimeTypes, - auditLogDF: DataFrame + auditLogDF: DataFrame, + jrsilverDF: DataFrame, + clusterSpec: PipelineTable, )(clusterEventsDF: DataFrame): DataFrame = { val stateUnboundW = Window.partitionBy('organization_id, 'cluster_id).orderBy('timestamp) val stateFromCurrentW = Window.partitionBy('organization_id, 'cluster_id).rowsBetween(1L, 1000L).orderBy('timestamp) val stateUntilCurrentW = Window.partitionBy('organization_id, 'cluster_id).rowsBetween(-1000L, -1L).orderBy('timestamp) val stateUntilPreviousRowW = Window.partitionBy('organization_id, 'cluster_id).rowsBetween(Window.unboundedPreceding, -1L).orderBy('timestamp) val uptimeW = Window.partitionBy('organization_id, 'cluster_id, 'reset_partition).orderBy('unixTimeMS_state_start) + val orderingWindow = Window.partitionBy('organization_id, 'cluster_id).orderBy(desc("timestamp")) val nonBillableTypes = Array( - "STARTING", "TERMINATING", "CREATING", "RESTARTING" + "STARTING", "TERMINATING", "CREATING", "RESTARTING" , "TERMINATING_IMPUTED" ) - // some states like EXPANDED_DISK and NODES_LOST, etc are excluded because they - // occasionally do come after the cluster has been terminated; thus they are not a guaranteed event + // some states like EXPANDED_DISK and NODES_LOST, etc are excluded because + // they occasionally do come after the cluster has been terminated; thus they are not a guaranteed event // goal is to be certain about the 99th percentile val runningStates = Array( "STARTING", "INIT_SCRIPTS_STARTED", "RUNNING", "CREATING", @@ -1020,13 +1023,47 @@ trait SilverTransforms extends SparkSessionWrapper { val invalidEventChain = lead('runningSwitch, 1).over(stateUnboundW).isNotNull && lead('runningSwitch, 1) .over(stateUnboundW) === lead('previousSwitch, 1).over(stateUnboundW) - val clusterEventsBaseline = clusterEventsDF + + + val refinedClusterEventsDF = clusterEventsDF .selectExpr("*", "details.*") .drop("details") .withColumnRenamed("type", "state") + + val clusterEventsFinal = if (jrsilverDF.isEmpty || clusterSpec.asDF.isEmpty) { + refinedClusterEventsDF + }else{ + val refinedClusterEventsDFFiltered = refinedClusterEventsDF + .withColumn("row", row_number().over(orderingWindow)) + .filter('state =!= "TERMINATING" && 'row === 1) + + val exceptClusterEventsDF1 = refinedClusterEventsDF.join(refinedClusterEventsDFFiltered.select("cluster_id","timestamp","state"),Seq("cluster_id","timestamp","state"),"leftAnti") + + val jrSilverAgg= jrsilverDF + .groupBy("clusterID") + .agg(max("TaskExecutionRunTime.endTS").alias("end_run_time")) + .filter('end_run_time.isNotNull) + + val joined = refinedClusterEventsDFFiltered.join(jrSilverAgg, refinedClusterEventsDFFiltered("cluster_id") === jrSilverAgg("clusterID"), "inner") + .withColumn("state", lit("TERMINATING_IMPUTED")) + + + // Join with Cluster Spec to get filter on automated cluster + val clusterSpecDF = clusterSpec.asDF.withColumnRenamed("cluster_id","clusterID") + .withColumn("isAutomated",isAutomated('cluster_name)) + .select("clusterID","cluster_name","isAutomated") + .filter('isAutomated).dropDuplicates() + + val jobClusterImputed = joined.join(clusterSpecDF,Seq("clusterID"),"inner") + .drop("row","clusterID","end_run_time","cluster_name","isAutomated") + + refinedClusterEventsDF.union(jobClusterImputed) + } + 
+ val clusterEventsBaseline = clusterEventsFinal .withColumn( "runningSwitch", - when('state === "TERMINATING", lit(false)) + when('state.isin("TERMINATING","TERMINATING_IMPUTED"), lit(false)) .when('state.isin("CREATING", "STARTING"), lit(true)) .otherwise(lit(null).cast("boolean"))) .withColumn( @@ -1117,11 +1154,11 @@ trait SilverTransforms extends SparkSessionWrapper { val stateBeforeRemoval = clusterEventsBaselineForRemovedCluster .withColumn("rnk",rank().over(window)) .withColumn("rn", row_number().over(window)) - .withColumn("unixTimeMS_state_end",when('state === "TERMINATING",'unixTimeMS_state_end).otherwise('deletion_timestamp)) + .withColumn("unixTimeMS_state_end",when('state.isin("TERMINATING","TERMINATING_IMPUTED"),'unixTimeMS_state_end).otherwise('deletion_timestamp)) .filter('rnk === 1 && 'rn === 1).drop("rnk", "rn") val stateDuringRemoval = stateBeforeRemoval - .withColumn("timestamp",when('state === "TERMINATING",'unixTimeMS_state_end+1).otherwise(col("deletion_timestamp")+1)) + .withColumn("timestamp",when('state.isin("TERMINATING","TERMINATING_IMPUTED"),'unixTimeMS_state_end+1).otherwise(col("deletion_timestamp")+1)) .withColumn("isRunning",lit(false)) .withColumn("unixTimeMS_state_start",('timestamp)) .withColumn("unixTimeMS_state_end",('timestamp)) @@ -1143,7 +1180,7 @@ trait SilverTransforms extends SparkSessionWrapper { clusterEventsBaselineFinal .withColumn("counter_reset", when( - lag('state, 1).over(stateUnboundW).isin("TERMINATING", "RESTARTING", "EDITED") || + lag('state, 1).over(stateUnboundW).isin("TERMINATING", "RESTARTING", "EDITED","TERMINATING_IMPUTED") || !'isRunning, lit(1) ).otherwise(lit(0)) ) @@ -1369,6 +1406,6 @@ trait SilverTransforms extends SparkSessionWrapper { deriveInputForWarehouseBase(df,silver_warehouse_spec,auditBaseCols) .transform(deriveWarehouseBase()) - .transform(deriveWarehouseBaseFilled(isFirstRun, bronzeWarehouseSnapLatest)) + .transform(deriveWarehouseBaseFilled(isFirstRun, bronzeWarehouseSnapLatest, silver_warehouse_spec)) } } diff --git a/src/main/scala/com/databricks/labs/overwatch/utils/Config.scala b/src/main/scala/com/databricks/labs/overwatch/utils/Config.scala index aa70c935f..054695eac 100644 --- a/src/main/scala/com/databricks/labs/overwatch/utils/Config.scala +++ b/src/main/scala/com/databricks/labs/overwatch/utils/Config.scala @@ -50,6 +50,8 @@ class Config() { private var _deploymentType: String = _ private var _etlCatalogName: String = _ private var _consumerCatalogName: String = _ + private val _systemTableAudit: String = "system.access.audit" + private var _sqlEndpoint: String = _ private val logger: Logger = Logger.getLogger(this.getClass) @@ -144,6 +146,9 @@ class Config() { def consumerCatalogName: String = _consumerCatalogName + def systemTableAudit: String = _systemTableAudit + + def sqlEndpoint: String = _sqlEndpoint /** * OverwatchScope defines the modules active for the current run diff --git a/src/main/scala/com/databricks/labs/overwatch/utils/SchemaTools.scala b/src/main/scala/com/databricks/labs/overwatch/utils/SchemaTools.scala index b19b6dc09..97d26b4f8 100644 --- a/src/main/scala/com/databricks/labs/overwatch/utils/SchemaTools.scala +++ b/src/main/scala/com/databricks/labs/overwatch/utils/SchemaTools.scala @@ -7,6 +7,9 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame, SparkSession} import scala.util.Random +import java.util +import java.util.Collections +import scala.collection.JavaConverters._ /** * SchemaTools is one of the more complex objects in Overwatch as it 
handles the schema (or lack there of rather) @@ -536,5 +539,23 @@ object SchemaTools extends SparkSessionWrapper { validator } } + + /** + * Function to change to column's naming convention from snake case to camel case + * For example: "cluster_id" to "clusterId" and "cluster_name" to "clusterName" + * @param df + * @return + */ + def snakeToCamel(df: DataFrame) : DataFrame = { + val columnNames = df.columns.toSeq + var newColumnNames = Collections.synchronizedList(new util.ArrayList[String]()) + columnNames.foreach( + x=>{ + newColumnNames.add(x.split("_").head.concat(x.split("_").tail.map(_.capitalize).mkString(""))) + } + ) + val renamedCols = newColumnNames.asScala.toSeq + df.toDF(renamedCols:_*) + } } diff --git a/src/main/scala/com/databricks/labs/overwatch/utils/Structures.scala b/src/main/scala/com/databricks/labs/overwatch/utils/Structures.scala index 55b878b7e..c05bb6c3d 100644 --- a/src/main/scala/com/databricks/labs/overwatch/utils/Structures.scala +++ b/src/main/scala/com/databricks/labs/overwatch/utils/Structures.scala @@ -138,12 +138,31 @@ case class MultiWorkspaceConfig(workspace_name: String, aad_authority_endpoint: Option[String], deployment_id: String, output_path: String, - temp_dir_path: Option[String] + temp_dir_path: Option[String], + sql_endpoint: Option[String] = None ) case class RulesValidationResult(ruleName: String, passed: String, permitted: String, actual: String) case class RulesValidationReport(deployment_id: String, workspace_id: String, result: RulesValidationResult) +case class ReconReport( + validated: Boolean = false, + workspaceId: String, + reconType: String, + sourceDB: String, + targetDB: String, + tableName: String, + sourceCount: Option[Long] = None, + targetCount: Option[Long] = None, + missingInSource: Option[Long] = None, + missingInTarget: Option[Long] = None, + commonDataCount: Option[Long] = None, + deviationPercentage: Option[Double] = None, + sourceQuery: Option[String] = None, + targetQuery: Option[String] = None, + errorMsg: Option[String] = None + ) + object MultiWorkspaceConfigColumns extends Enumeration { val workspace_name, workspace_id, workspace_url, api_url, cloud, primordial_date, storage_prefix, etl_database_name, consumer_database_name, secret_scope, @@ -196,7 +215,9 @@ case class AzureAuditLogEventhubConfig( case class AuditLogConfig( rawAuditPath: Option[String] = None, auditLogFormat: String = "json", - azureAuditLogEventhubConfig: Option[AzureAuditLogEventhubConfig] = None + azureAuditLogEventhubConfig: Option[AzureAuditLogEventhubConfig] = None, + systemTableName: Option[String] = None, + sqlEndpoint: Option[String] = None ) case class IntelligentScaling(enabled: Boolean = false, minimumCores: Int = 4, maximumCores: Int = 512, coeff: Double = 1.0) @@ -213,7 +234,8 @@ case class OverwatchParams(auditLogConfig: AuditLogConfig, workspace_name: Option[String] = None, externalizeOptimize: Boolean = false, apiEnvConfig: Option[ApiEnvConfig] = None, - tempWorkingDir: String = "" // will be set after data target validated if not overridden + tempWorkingDir: String = "", // will be set after data target validated if not overridden + sqlEndpoint: Option[String] = None ) case class ParsedConfig( diff --git a/src/main/scala/com/databricks/labs/overwatch/utils/Tools.scala b/src/main/scala/com/databricks/labs/overwatch/utils/Tools.scala index 5500c72e1..7353867f8 100644 --- a/src/main/scala/com/databricks/labs/overwatch/utils/Tools.scala +++ b/src/main/scala/com/databricks/labs/overwatch/utils/Tools.scala @@ -3,10 +3,12 @@ package 
com.databricks.labs.overwatch.utils import com.amazonaws.services.s3.model.AmazonS3Exception import com.databricks.labs.overwatch.env.Workspace import com.databricks.dbutils_v1.DBUtilsHolder.dbutils +import com.databricks.labs.overwatch.api.ApiMetaFactory import java.io.FileNotFoundException import com.databricks.labs.overwatch.pipeline.TransformFunctions._ import com.databricks.labs.overwatch.pipeline._ +import com.databricks.labs.overwatch.validation.DataReconciliation import com.fasterxml.jackson.annotation.JsonInclude.{Include, Value} import com.fasterxml.jackson.core.io.JsonStringEncoder import com.fasterxml.jackson.databind.ObjectMapper @@ -26,6 +28,7 @@ import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryProgressEvent import java.net.URI import java.time.LocalDate import scala.collection.parallel.ForkJoinTaskSupport +import scala.collection.parallel.mutable.ParArray import scala.concurrent.forkjoin.ForkJoinPool // TODO -- Add loggers to objects with throwables @@ -654,11 +657,11 @@ object Helpers extends SparkSessionWrapper { .as[String].first() val workspace = if (isRemoteWorkspace) { // single workspace deployment - Initializer(testConfig, disableValidations = true) + Initializer(testConfig, disableValidations = true, initializeDatabase = false,Some(workspaceID)) } else { // multi workspace deployment Initializer( testConfig, - disableValidations = disableValidations, + disableValidations = disableValidations , apiURL = apiUrl, organizationID = organization_id ) @@ -704,7 +707,7 @@ object Helpers extends SparkSessionWrapper { .filter('organization_id === workspaceID) .select(to_json('inputConfig).alias("compactString")) .as[String].first() - Initializer(testConfig, disableValidations = true, initializeDatabase = false) + Initializer(testConfig, disableValidations = true, initializeDatabase = false,Some(workspaceID)) } /** @@ -1185,4 +1188,139 @@ object Helpers extends SparkSessionWrapper { } streamManager } + + + /** + * Function separates the raw api response from the enriched api response which contains the meta info of the api call and + * returns a dataframe which contains only raw api response. + * + * @param dataFrame + * @return + */ + def deriveRawApiResponseDF(dataFrame: DataFrame): DataFrame = { + val filteredDf = dataFrame.select('rawResponse) + .filter('rawResponse =!= "{}") + if (filteredDf.isEmpty) { + filteredDf + } else { + filteredDf + .withColumn("rawResponse", SchemaTools.structFromJson(spark, dataFrame, "rawResponse")) + .select("rawResponse.*") + } + } + + /** + * Function returns a temp path which can be used to store successful api responses. + * @param tempWOrkingDir + * @param endpointDir + * @param pipelineSnapTs + * @return + */ + private[overwatch] def deriveApiTempDir(tempWorkingDir: String, endpointDir: String, pipelineSnapTs: TimeTypes): String = { + s"${tempWorkingDir}/${endpointDir}/success_" + pipelineSnapTs.asUnixTimeMilli + } + + /** + * Function return a temp path which can be used to store the api responses which was not successful. + * @param tempWOrkingDir + * @param endpointDir + * @param pipelineSnapTs + * @return + */ + private[overwatch] def deriveApiTempErrDir(tempWOrkingDir: String, endpointDir: String, pipelineSnapTs: TimeTypes): String = { + s"${tempWOrkingDir}/${endpointDir}/error_" + pipelineSnapTs.asUnixTimeMilli + } + + + /** + * Function converts the data column which is binary in traceApi dataframe to string. 
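+ * The data column is the rawResponse payload persisted as binary in the apiEventDetails target; casting it back to
+ * a string makes it consumable by deriveRawApiResponseDF and deriveTraceDFByApiName below.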
+ * @param df + * @return + */ + def transformBinaryDf(df: DataFrame): DataFrame = { + df.withColumnRenamed("data", "rawResponse").withColumn("rawResponse", col("rawResponse").cast("String")) + } + + /** + * Function returns the api response for provided apiName from api trace data. + * @param df dataframe which contains the traceApi data. + * @param apiName name of the api. + * @return + */ + def deriveTraceDFByApiName(df: DataFrame, apiName: String): DataFrame = { + val rawDF = deriveRawApiResponseDF(transformBinaryDf(df)) + val apiMetaFactory = new ApiMetaFactory().getApiClass(apiName) + rawDF.select(explode(col(apiMetaFactory.dataframeColumn)).alias(apiMetaFactory.dataframeColumn)).select(col(apiMetaFactory.dataframeColumn + ".*")) + } + + /** + * Function returns the api response for provided moduleID from api trace data. + * @param apiEventTable apiEventTable name. + * @param moduleId + * @return + */ + def getTraceDFByModule(apiEventTable: String, moduleId: Long): DataFrame = { + val rawDF = spark.read.table(apiEventTable).filter('moduleId === moduleId) + val endPoint = rawDF.head().getAs[String]("endPoint") + deriveTraceDFByApiName(rawDF, endPoint) + } + + /** + * Function returns the api response for provided apiName from api trace data. + * @param apiEventTable apiEventTable name. + * @param endPoint + * @return + */ + def getTraceDFByApi(apiEventTable: String, endPoint: String): DataFrame = { + deriveTraceDFByApiName(spark.read.table(apiEventTable).filter('endPoint === endPoint), endPoint) + } + + /** + * Function returns the api response for provided apiName from api trace data. + * @param apiEventPath path of the apiEvent data. + * @param endPoint + * @return + */ + def getTraceDFByPath(apiEventPath: String, endPoint: String): DataFrame = { + deriveTraceDFByApiName(spark.read.load(apiEventPath).filter('endPoint === endPoint), endPoint) + } + + + /** + * This function will perform the data reconciliation between two deployments, we need two overwatch deployments with current and previous versions. + * After running the reconciliation it will generate a report which will contain all comparison results for each table. + * + * @param sourceEtl : ETL database name of previous version of OW + * @param targetEtl : ETL database name of current version of OW + */ + def performRecon(sourceEtl: String, targetEtl: String) = { + DataReconciliation.performRecon(sourceEtl, targetEtl) + } + + /** + * This function will perform the data reconciliation between two tables. + * Function will return a dataframe which has that data which is present in source but not present in target with hashcode for each columns. + * + * @param sourceTable : ETL database name of previous version of OW + * @param targetTable : ETL database name of current version of OW + * @param includeNonHashCol : If true the result dataframe will contain all the columns with the real value. + */ + def reconSingleTable(sourceTable: String, targetTable: String,includeNonHashCol:Boolean= true ) = { + DataReconciliation.reconTable(sourceTable, targetTable) + } + + /** + * This method fetches all targets for a workspace. 
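+ * Illustrative usage (the etl database name and workspace id are placeholders):
+ * {{{
+ *   val ws = Helpers.getWorkspaceByDatabase("overwatch_etl", Some("1234567890"))
+ *   val targets = Helpers.getAllPipelineTargets(ws) // ParArray[PipelineTable] that exist with data
+ * }}}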
+ * + * @param workspace : Workspace object + * @return ParArray of PipelineTable + */ + def getAllPipelineTargets(workspace: Workspace): ParArray[PipelineTable] = { + val b = Bronze(workspace) + val s = Silver(workspace) + val g = Gold(workspace) + (b.getAllTargets ++ s.getAllTargets ++ g.getAllTargets).filter(_.exists(dataValidation = true, catalogValidation = false)).par + } + + } diff --git a/src/main/scala/com/databricks/labs/overwatch/validation/DataReconciliation.scala b/src/main/scala/com/databricks/labs/overwatch/validation/DataReconciliation.scala new file mode 100644 index 000000000..7407949fe --- /dev/null +++ b/src/main/scala/com/databricks/labs/overwatch/validation/DataReconciliation.scala @@ -0,0 +1,316 @@ +package com.databricks.labs.overwatch.validation + +import com.databricks.labs.overwatch.env.Workspace +import com.databricks.labs.overwatch.pipeline._ +import com.databricks.labs.overwatch.utils.Helpers.getAllPipelineTargets +import com.databricks.labs.overwatch.utils.{Helpers, ReconReport, SparkSessionWrapper} +import org.apache.log4j.Logger +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.{col, hash, lit} + +import java.time.LocalDateTime +import scala.collection.mutable.ArrayBuffer +import scala.collection.parallel.mutable.ParArray + +/** + * Data Reconciliation is a new feature of OW which verifies that the data is consistent across the current and the previous release. + * In order to perform the data reconciliation, we need two Overwatch deployments, one on the current version and one on the previous version. + * After running the reconciliation, it generates a report containing the comparison results for each table. + * This reconciliation module is independent of the pipeline run and is used as a helper function. + */ +object DataReconciliation extends SparkSessionWrapper { + + import spark.implicits._ + + + private val logger: Logger = Logger.getLogger(this.getClass) + + /** + * Function is the starting point of the reconciliation. + * @param sourceEtl : ETL name of the previous version of OW + * @param targetEtl : ETL name of the current version of OW + */ + private[overwatch] def performRecon(sourceEtl:String,targetEtl:String): Unit ={ + val sourceOrgIDArr = getAllOrgID(sourceEtl) + val targetOrgIDArr = getAllOrgID(targetEtl) + performBasicRecon(sourceOrgIDArr,targetOrgIDArr) + val sourceWorkspace = getConfig(sourceEtl,sourceOrgIDArr(0)) + val targetWorkspace = getConfig(targetEtl,targetOrgIDArr(0)) + val targets = getAllPipelineTargets(sourceWorkspace) + println("Number of tables for recon: "+targets.length) + println(targets.foreach(t => println(t.name))) + val report = runRecon(targets, sourceEtl, sourceOrgIDArr, targetEtl) + val reconRunId: String = java.util.UUID.randomUUID.toString + val etlStoragePrefix = targetWorkspace.getConfig.etlDataPathPrefix.substring(0, targetWorkspace.getConfig.etlDataPathPrefix.length - 13) + saveReconReport(report, etlStoragePrefix, "ReconReport", reconRunId) + } + + /** + * Performs the below comparisons between two tables, referred to as the source table and the target table.
+ * Count validation in Source + * Count validation in Target + * Common data between source and target + * Missing data in source + * Missing data in target + * Deviation percentage: it is calculated with the formula ((missingSourceCount + missingTargetCount)/SourceCount)*100 + * @param target + * @param orgId + * @param sourceEtl + * @param targetEtl + * @return + */ + private def hashValidation(target: PipelineTable,orgId: String, sourceEtl:String,targetEtl:String ):ReconReport ={ + val reconType = "Validation by hashing" + try { + val sourceQuery = getQuery(s"""${target.tableFullName}""", orgId) + val sourceTable = hashAllColumns(getTableDF(sourceQuery,target)) + val targetQuery = getQuery(s"""${target.tableFullName.replaceAll(sourceEtl,targetEtl)}""", orgId) + val targetTable = hashAllColumns(getTableDF(targetQuery,target)) + val sourceCount = sourceTable.count() + val targetCount = targetTable.count() + val missingSourceCount = targetTable.exceptAll(sourceTable).count() + val missingTargetCount = sourceTable.exceptAll(targetTable).count() + val commonDataCount = sourceTable.intersectAll(targetTable).count() + val deviationFactor = { + if ((missingSourceCount + missingTargetCount) == 0) { + 1 + } else { + missingSourceCount + missingTargetCount + } + + } + val deviation:Double = { + if(deviationFactor == 1){ + 0 + }else{ + (deviationFactor.toDouble/sourceCount)*100 + } + } + + val validated: Boolean = { + if ((sourceCount == targetCount) && (missingSourceCount == 0 && missingTargetCount == 0)) { + true + } else { + false + } + } + + ReconReport(validated = validated, + workspaceId = orgId, + reconType = reconType, + sourceDB = sourceEtl, + targetDB = targetEtl, + tableName = target.name, + sourceCount = Some(sourceCount), + targetCount = Some(targetCount), + missingInSource = Some(missingSourceCount), + missingInTarget = Some(missingTargetCount), + commonDataCount = Some(commonDataCount), + deviationPercentage = Some(deviation), + sourceQuery = Some(sourceQuery), + targetQuery = Some(targetQuery), + errorMsg = Some("")) + + } catch { + case e: Exception => + e.printStackTrace() + val fullMsg = PipelineFunctions.appendStackStrace(e, "Got Exception while running recon,") + ReconReport( + workspaceId = orgId, + reconType = reconType, + sourceDB = sourceEtl, + targetDB = targetEtl, + tableName = target.tableFullName, + errorMsg = Some(fullMsg) + ) + } + + + + } + + /** + * This method runs the reconciliation for all targets in parallel. + * + * @param targets : Array of PipelineTable + * @param sourceEtl : ETL name of Previous version of OW + * @param sourceOrgIDArr : Array of organization IDs + * @param targetEtl : ETL name of the current version of OW + * @return Array of ReconReport + */ + private[overwatch] def runRecon(targets: ParArray[PipelineTable] , + sourceEtl:String, + sourceOrgIDArr: Array[String], + targetEtl:String, + ):Array[ReconReport]={ + spark.conf.set("spark.sql.legacy.allowHashOnMapType","true") + val reconStatus: ArrayBuffer[ReconReport] = new ArrayBuffer[ReconReport]() + sourceOrgIDArr.foreach(orgId=> { + targets.foreach(target => { + reconStatus.append(hashValidation(target, orgId, sourceEtl, targetEtl)) + }) + } + ) + spark.conf.set("spark.sql.legacy.allowHashOnMapType","false") + reconStatus.toArray + } + + /** + * Function saves the recon report. 
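+ * The report is appended in Delta format under ${path}/report/${reportName}; performRecon passes the target
+ * workspace's etl storage prefix as the path and "ReconReport" as the report name.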
+ * @param reconStatusArray Array of ReconReport produced by runRecon
+ * @param path Storage prefix under which the report is written
+ * @param reportName Name of the report folder
+ * @param reconRunId Unique ID of this reconciliation run
+ */
+ private def saveReconReport(reconStatusArray: Array[ReconReport], path: String, reportName: String, reconRunId: String): Unit = {
+ val validationPath = {
+ if (!path.startsWith("dbfs:") && !path.startsWith("s3") && !path.startsWith("abfss") && !path.startsWith("gs")) {
+ s"""dbfs:${path}"""
+ } else {
+ path
+ }
+ }
+
+ val pipelineSnapTime = Pipeline.createTimeDetail(LocalDateTime.now(Pipeline.systemZoneId).toInstant(Pipeline.systemZoneOffset).toEpochMilli)
+ reconStatusArray.toSeq.toDS().toDF()
+ .withColumn("reconRunId", lit(reconRunId))
+ .withColumn("snapTS", lit(pipelineSnapTime.asTSString))
+ .withColumn("timestamp", lit(pipelineSnapTime.asUnixTimeMilli))
+ .write.format("delta")
+ .option("mergeSchema", "true")
+ .mode("append")
+ .save(s"""${validationPath}/report/${reportName}""")
+ println("ReconRunID: " + reconRunId)
+ println("Validation report has been saved to " + s"""${validationPath}/report/${reportName}""")
+ }
+
+
+ /**
+ * Generates a query to fetch data from a table for a specific organization.
+ *
+ * @param tableName : Name of the table
+ * @param orgId : Organization ID
+ * @return String : Query
+ */
+ private def getQuery(tableName: String, orgId: String): String = {
+ s"""select * from $tableName where organization_id = ${orgId}"""
+ }
+
+
+ /**
+ * Fetches a DataFrame from a table using a query, dropping excluded and Overwatch bookkeeping columns.
+ *
+ * @param query : Query to fetch data
+ * @param target : PipelineTable object
+ * @return DataFrame
+ */
+ private def getTableDF(query: String, target: PipelineTable): DataFrame = {
+ try {
+ val excludedCol = target.excludedReconColumn
+ val dropCol = excludedCol ++ Array("Overwatch_RunID", "Pipeline_SnapTS", "__overwatch_ctrl_noise")
+ val filterDF = spark.sql(query).drop(dropCol: _*)
+ filterDF
+ } catch {
+ case exception: Exception =>
+ println(s"""Exception: Unable to run the query ${query}: """ + exception.getMessage)
+ spark.emptyDataFrame
+ }
+
+ }
+
+
+ /**
+ * Performs a basic reconciliation between two arrays of organization IDs.
+ *
+ * @param sourceOrgIDArr : Array of organization IDs from source ETL
+ * @param targetOrgIDArr : Array of organization IDs from target ETL
+ */
+ private[overwatch] def performBasicRecon(sourceOrgIDArr: Array[String], targetOrgIDArr: Array[String]): Unit = {
+ println("Number of workspaces in Source: " + sourceOrgIDArr.size)
+ println("Number of workspaces in Target: " + targetOrgIDArr.size)
+ if (sourceOrgIDArr.size < 1 || targetOrgIDArr.size < 1) {
+ val msg = "Number of workspaces in source/target ETL is 0, exiting"
+ println(msg)
+ throw new Exception(msg)
+ }
+
+ }
+
+ /**
+ * Retrieves the Workspace configuration for a given ETL database and organization ID.
+ *
+ * @param sourceEtl The name of the ETL database.
+ * @param orgID The organization ID.
+ * @return A Workspace object containing the configuration for the specified ETL database and organization ID.
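+ *
+ * @example A minimal usage sketch (the ETL database name and workspace ID below are hypothetical):
+ * {{{
+ * val sourceWorkspace: Workspace = getConfig("overwatch_etl_0730", "1234567890123456")
+ * }}}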
+ */
+ private[overwatch] def getConfig(sourceEtl: String, orgID: String): Workspace = {
+ Helpers.getWorkspaceByDatabase(sourceEtl, Some(orgID))
+ }
+
+
+ /**
+ * Retrieves the distinct organization (workspace) IDs recorded in the pipeline_report table of an ETL database.
+ *
+ * @param etlDB ETL database name
+ * @return Array of organization IDs
+ */
+ private[overwatch] def getAllOrgID(etlDB: String): Array[String] = {
+ try {
+ spark.table(s"${etlDB}.pipeline_report").select("organization_id").distinct().collect().map(row => row.getString(0))
+ } catch {
+ case e: Throwable =>
+ val msg = "Got exception while reading from pipeline_report: "
+ println(msg + e.getMessage)
+ throw e
+ }
+
+ }
+
+ /**
+ * Hashes all columns of a DataFrame.
+ *
+ * @param df The DataFrame whose columns are to be hashed.
+ * @param includeNonHashCol A boolean flag indicating whether to include the original (non-hashed) columns in the output DataFrame. Default is false.
+ * @return A DataFrame with the hashed columns. If includeNonHashCol is true, the original columns are also included.
+ *
+ * The method works as follows:
+ * 1. Retrieves the column names of the input DataFrame.
+ * 2. Creates a new set of columns where each column is the hash of the original column.
+ * 3. Selects the original columns and the hashed columns from the input DataFrame.
+ * 4. If includeNonHashCol is true, returns the DataFrame with both original and hashed columns. Otherwise, returns the DataFrame with only the hashed columns.
+ */
+ private[overwatch] def hashAllColumns(df: DataFrame, includeNonHashCol: Boolean = false): DataFrame = {
+ val columns = df.columns
+ val hashCols = columns.map(column => hash(col(column)).alias(s"${column}_hash"))
+ val selectDf = df.select((columns.map(name => df(name)) ++ hashCols): _*)
+ if (includeNonHashCol) {
+ selectDf
+ } else {
+ selectDf.select(hashCols: _*)
+ }
+
+ }
+
+
+ /**
+ * Performs a reconciliation between two tables.
+ *
+ * @param sourceTable The name of the source table to be reconciled.
+ * @param targetTable The name of the target table to be reconciled.
+ * @param includeNonHashCol A boolean flag indicating whether to include non-hash columns in the reconciliation. Default is true.
+ * @return A DataFrame containing the data present in the source table but missing from the target table.
+ *
+ * The method works as follows:
+ * 1. Reads the source and target tables from Spark.
+ * 2. Drops the columns "Overwatch_RunID", "Pipeline_SnapTS", "__overwatch_ctrl_noise" from both tables.
+ * 3. Hashes all columns of both tables using the `hashAllColumns` method.
+ * 4. Finds the data present in the source table but missing from the target table using the `exceptAll` method.
+ * 5. Returns the DataFrame containing the missing data.
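+ *
+ * @example A minimal usage sketch (the table names below are hypothetical):
+ * {{{
+ * val missingFromTarget = reconTable(
+ *   "overwatch_etl_0730.clusterstatefact_gold",
+ *   "overwatch_etl_0800.clusterstatefact_gold")
+ * missingFromTarget.show(10, truncate = false)
+ * }}}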
+ */ + private[overwatch] def reconTable(sourceTable: String, targetTable: String, includeNonHashCol: Boolean = true): DataFrame = { + val sourceHashTable = hashAllColumns(spark.read.table(sourceTable).drop("Overwatch_RunID", "Pipeline_SnapTS", "__overwatch_ctrl_noise"), includeNonHashCol) + val targetHashTable = hashAllColumns(spark.read.table(targetTable).drop("Overwatch_RunID", "Pipeline_SnapTS", "__overwatch_ctrl_noise"), includeNonHashCol) + val missingSource = sourceHashTable.exceptAll(targetHashTable) + missingSource + } + + + + + +} diff --git a/src/main/scala/com/databricks/labs/overwatch/validation/DeploymentValidation.scala b/src/main/scala/com/databricks/labs/overwatch/validation/DeploymentValidation.scala index 725a9334f..ff7cc2a09 100644 --- a/src/main/scala/com/databricks/labs/overwatch/validation/DeploymentValidation.scala +++ b/src/main/scala/com/databricks/labs/overwatch/validation/DeploymentValidation.scala @@ -70,12 +70,12 @@ object DeploymentValidation extends SparkSessionWrapper { * if false it will register the exception in the validation report. */ private def storagePrefixAccessValidation(config: MultiWorkspaceConfig, fastFail: Boolean = false): DeploymentValidationReport = { - val testDetails = s"""StorageAccessTest storage : ${config.storage_prefix}""" + val testDetails = s"""StorageAccessTest storage : ${config.storage_prefix}/${config.workspace_id}/""" try { - dbutils.fs.mkdirs(s"""${config.storage_prefix}/test_access""") - dbutils.fs.put(s"""${config.storage_prefix}/test_access/testwrite""", "This is a file in cloud storage.") - dbutils.fs.head(s"""${config.storage_prefix}/test_access/testwrite""") - dbutils.fs.rm(s"""${config.storage_prefix}/test_access""", true) + dbutils.fs.mkdirs(s"""${config.storage_prefix}/${config.workspace_id}/test_access""") + dbutils.fs.put(s"""${config.storage_prefix}/${config.workspace_id}/test_access/testwrite""", "This is a file in cloud storage.") + dbutils.fs.head(s"""${config.storage_prefix}/${config.workspace_id}/test_access/testwrite""") + dbutils.fs.rm(s"""${config.storage_prefix}/${config.workspace_id}/test_access""", true) DeploymentValidationReport(true, getSimpleMsg("Storage_Access"), testDetails, @@ -364,7 +364,12 @@ object DeploymentValidation extends SparkSessionWrapper { * @return */ private def cloudSpecificValidation(config: MultiWorkspaceConfig): DeploymentValidationReport = { - + if(config.auditlogprefix_source_path.getOrElse("").toLowerCase.equals("system")) { + val tableName = fetchTableName(config.auditlogprefix_source_path) + validateSystemTableAudit(tableName, config.workspace_id, config.sql_endpoint.getOrElse(""), + config.workspace_url, config.secret_scope, config.secret_key_dbpat) + } + else config.cloud.toLowerCase match { case cloudType if cloudType == "aws" || cloudType == "gcp" => validateAuditLog( @@ -377,7 +382,6 @@ object DeploymentValidation extends SparkSessionWrapper { validateEventHub( config) } - } /** @@ -529,6 +533,8 @@ object DeploymentValidation extends SparkSessionWrapper { * @return */ private def checkAAD(config: MultiWorkspaceConfig): Boolean = { + if(config.auditlogprefix_source_path.getOrElse("").toLowerCase.equals("system")) + return true if (config.eh_name.isEmpty) throw new BadConfigException("eh_name should be nonempty, please check the configuration.") @@ -703,6 +709,7 @@ object DeploymentValidation extends SparkSessionWrapper { case "Valid_Excluded_Scopes" => "Excluded scope can be audit:sparkEvents:jobs:clusters:clusterEvents:notebooks:pools:accounts." 
case "Storage_Access" => "ETL_STORAGE_PREFIX should have read,write and create access" case "Validate_Mount" => "Number of mount points in the workspace should not exceed 50" + case "Validate_SystemTablesAudit" => "System table name should be system.access.audit" } } @@ -780,4 +787,68 @@ object DeploymentValidation extends SparkSessionWrapper { validationStatus.toArray } + private def fetchTableName(auditlogprefix_source_path: Option[String]): String = { + if(auditlogprefix_source_path.get.toLowerCase().equals("system")) + "system.access.audit" + else + auditlogprefix_source_path.get + } + + private def validateSystemTableAudit(table_name: String, + workspace_id: String, + sql_endpoint: String, + workspace_host: String, + workspace_token_key: String, + workspace_scope:String + ): DeploymentValidationReport = { + val testDetails = s"Testing for System table - ${table_name} and workspace_id - ${workspace_id}" + val ifDataExists = if(sql_endpoint.trim.isEmpty) + verifySystemTableAudit(table_name, workspace_id) + else + verifyAuditLogFromExternalAccount(table_name,workspace_id, workspace_host, + sql_endpoint, workspace_scope, workspace_token_key) + + if(spark.catalog.tableExists(table_name) && !ifDataExists) { + DeploymentValidationReport(true, + getSimpleMsg("Validate_SystemTablesAudit"), + testDetails, + Some("SUCCESS"), + Some(workspace_id) + ) + } else { + DeploymentValidationReport(false, + getSimpleMsg("Validate_SystemTablesAudit"), + testDetails, + Some("FAILED"), + Some(workspace_id) + ) + throw new BadConfigException( + s"${table_name} does not exists for workspace_id - ${workspace_id}") + } + } + + private def verifySystemTableAudit(table_name: String, workspace_id: String): Boolean = { + val auditLogData = spark.read.table(table_name) + .filter('workspace_id === workspace_id) + auditLogData.isEmpty + } + + private def verifyAuditLogFromExternalAccount(table_name: String, + workspace_id: String, + workspace_host: String, + sql_endpoint: String, + workspace_token_key: String, + workspace_scope: String): Boolean = { + val patToken = dbutils.secrets.get(scope = workspace_scope, key = workspace_token_key) + val host = workspace_host.stripPrefix("https://").stripSuffix("/") + val auditLogData = spark.read + .format("databricks") + .option("host",host) + .option("httpPath",sql_endpoint) + .option("personalAccessToken",patToken) + .option("query",s"select * from ${table_name} where workspace_id='${workspace_id}'") + .load() + auditLogData.isEmpty + } + } \ No newline at end of file diff --git a/src/test/scala/com/databricks/labs/overwatch/ParamDeserializerTest.scala b/src/test/scala/com/databricks/labs/overwatch/ParamDeserializerTest.scala index ce29352f3..b4fbfe8eb 100644 --- a/src/test/scala/com/databricks/labs/overwatch/ParamDeserializerTest.scala +++ b/src/test/scala/com/databricks/labs/overwatch/ParamDeserializerTest.scala @@ -51,7 +51,8 @@ class ParamDeserializerTest extends AnyFunSpec { workspace_name = Some("myTestWorkspace"), externalizeOptimize = false, apiEnvConfig = None, - tempWorkingDir = "" + tempWorkingDir = "", + sqlEndpoint = None ) assertResult(expected)(mapper.readValue[OverwatchParams](incomplete)) diff --git a/src/test/scala/com/databricks/labs/overwatch/pipeline/InitializeTest.scala b/src/test/scala/com/databricks/labs/overwatch/pipeline/InitializeTest.scala index f40d39181..ec683d900 100644 --- a/src/test/scala/com/databricks/labs/overwatch/pipeline/InitializeTest.scala +++ b/src/test/scala/com/databricks/labs/overwatch/pipeline/InitializeTest.scala @@ -3,7 +3,7 @@ 
package com.databricks.labs.overwatch.pipeline import com.databricks.labs.overwatch.SparkSessionTestWrapper import com.databricks.labs.overwatch.env.Database import com.databricks.labs.overwatch.utils.OverwatchScope._ -import com.databricks.labs.overwatch.utils.{BadConfigException, Config, OverwatchScope} +import com.databricks.labs.overwatch.utils.{BadConfigException, Config, OverwatchScope, TokenSecret} import com.fasterxml.jackson.core.JsonParseException import com.fasterxml.jackson.core.io.JsonEOFException import com.github.mrpowers.spark.fast.tests.DataFrameComparer @@ -111,7 +111,11 @@ class InitializeTest extends AnyFunSpec with DataFrameComparer with SparkSession conf.setDeploymentType("default") val init = Initializer.buildInitializer(conf) val expectedAuditConf = AuditLogConfig(Some("path/to/auditLog"), "json", None) - val actualAuditConf = init.validateAuditLogConfigs(configInput) + val organizationId = "123" + val workspace_url = "https://test.databricks.com" + val sql_endpoint = "/sql/1.0/warehouses/123abc" + val token = Some(TokenSecret("scope", "key")) + val actualAuditConf = init.validateAuditLogConfigs(configInput,organizationId,workspace_url,sql_endpoint,token) assert(expectedAuditConf == actualAuditConf) } @@ -124,9 +128,13 @@ class InitializeTest extends AnyFunSpec with DataFrameComparer with SparkSession val conf = new Config conf.setDeploymentType("default") val init = Initializer.buildInitializer(conf) + val organizationId = "123" + val workspace_url = "https://test.databricks.com" + val sql_endpoint = "/sql/1.0/warehouses/123abc" + val token = Some(TokenSecret("scope", "key")) val expectedAuditConf = AuditLogConfig(None, "json", Some(AzureAuditLogEventhubConfig("sample.connection.string", "auditLog", "path/to/auditLog/prefix",10000,10, Some("path/to/auditLog/prefix/rawEventsCheckpoint"), Some("path/to/auditLog/prefix/auditLogBronzeCheckpoint")))) - val actualAuditConf = init.validateAuditLogConfigs(configInput) + val actualAuditConf = init.validateAuditLogConfigs(configInput,organizationId,workspace_url, sql_endpoint, token) assert(expectedAuditConf == actualAuditConf) } @@ -143,7 +151,11 @@ class InitializeTest extends AnyFunSpec with DataFrameComparer with SparkSession conf.setCloudProvider("aws") conf.setDeploymentType("default") val init = Initializer.buildInitializer(conf) - assertThrows[BadConfigException](init.validateAuditLogConfigs(configInput)) + val organizationId = "123" + val workspace_url = "https://test.databricks.com" + val sql_endpoint = "/sql/1.0/warehouses/123abc" + val token = Some(TokenSecret("scope", "key")) + assertThrows[BadConfigException](init.validateAuditLogConfigs(configInput,organizationId,workspace_url,sql_endpoint,token)) } ignore ("validateAuditLogConfigs function validate audit log format in the config ") { @@ -154,9 +166,13 @@ class InitializeTest extends AnyFunSpec with DataFrameComparer with SparkSession conf.setCloudProvider("aws") conf.setDeploymentType("default") val init = Initializer.buildInitializer(conf) + val organizationId = "123" + val workspace_url = "https://test.databricks.com" + val sql_endpoint = "/sql/1.0/warehouses/123abc" + val token = Some(TokenSecret("scope", "key")) val quickBuildAuditLogConfig = PrivateMethod[AuditLogConfig]('validateAuditLogConfigs) - assertThrows[BadConfigException](init.validateAuditLogConfigs(configInput)) + assertThrows[BadConfigException](init.validateAuditLogConfigs(configInput,organizationId,workspace_url,sql_endpoint,token)) } }