Incremental Snapshot #909

Closed
26 changes: 2 additions & 24 deletions src/main/scala/com/databricks/labs/overwatch/env/Database.scala
@@ -2,7 +2,7 @@ package com.databricks.labs.overwatch.env

import com.databricks.labs.overwatch.pipeline.TransformFunctions._
import com.databricks.labs.overwatch.pipeline.{PipelineFunctions, PipelineTable}
import com.databricks.labs.overwatch.utils.{Config, SparkSessionWrapper, WriteMode, MergeScope}
import com.databricks.labs.overwatch.utils.{Config, Helpers, MergeScope, SparkSessionWrapper, WriteMode}
import io.delta.tables.{DeltaMergeBuilder, DeltaTable}
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.functions.lit
@@ -109,28 +109,6 @@ class Database(config: Config) extends SparkSessionWrapper {
registerTarget(target)
}

private def getQueryListener(query: StreamingQuery, minEventsPerTrigger: Long): StreamingQueryListener = {
val streamManager = new StreamingQueryListener() {
override def onQueryStarted(queryStarted: QueryStartedEvent): Unit = {
println("Query started: " + queryStarted.id)
}

override def onQueryTerminated(queryTerminated: QueryTerminatedEvent): Unit = {
println("Query terminated: " + queryTerminated.id)
}

override def onQueryProgress(queryProgress: QueryProgressEvent): Unit = {
println("Query made progress: " + queryProgress.progress)
if (config.debugFlag) {
println(query.status.prettyJson)
}
if (queryProgress.progress.numInputRows <= minEventsPerTrigger) {
query.stop()
}
}
}
streamManager
}

/**
* It's often more efficient to write a temporary version of the data to be merged than to compare complex
Expand Down Expand Up @@ -304,7 +282,7 @@ class Database(config: Config) extends SparkSessionWrapper {
.asInstanceOf[DataStreamWriter[Row]]
.option("path", target.tableLocation)
.start()
val streamManager = getQueryListener(streamWriter, config.auditLogConfig.azureAuditLogEventhubConfig.get.minEventsPerTrigger)
val streamManager = Helpers.getQueryListener(streamWriter, config, config.auditLogConfig.azureAuditLogEventhubConfig.get.minEventsPerTrigger)
spark.streams.addListener(streamManager)
val listenerAddedMsg = s"Event Listener Added.\nStream: ${streamWriter.name}\nID: ${streamWriter.id}"
if (config.debugFlag) println(listenerAddedMsg)
@@ -337,7 +337,7 @@ class Workspace(config: Config) extends SparkSessionWrapper {
val sourceName = dataset.name
val sourcePath = dataset.path
val targetPath = if (targetPrefix.takeRight(1) == "/") s"$targetPrefix$sourceName" else s"$targetPrefix/$sourceName"
CloneDetail(sourcePath, targetPath, asOfTS, cloneLevel)
CloneDetail(sourcePath, targetPath, asOfTS, cloneLevel, Array(), WriteMode.append)
}).toArray.toSeq
Helpers.parClone(cloneSpecs)
}
@@ -59,6 +59,8 @@ class Bronze(_workspace: Workspace, _database: Database, _config: Config)
* if choosing to overwrite, only one backup will be maintained
* @param excludes which bronze targets to exclude from the snapshot
*/
//TODO: Add link for new Snapshot Class functionality
@deprecated("This Method is deprecated, Use new Snapshot Class instead. Please check the link for more details")
def snapshot(
targetPrefix: String,
overwrite: Boolean,
277 changes: 277 additions & 0 deletions src/main/scala/com/databricks/labs/overwatch/pipeline/Snapshot.scala
@@ -0,0 +1,277 @@
package com.databricks.labs.overwatch.pipeline

import com.databricks.labs.overwatch.env.{Database, Workspace}
import com.databricks.labs.overwatch.utils._
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.streaming.Trigger

import scala.collection.parallel.ForkJoinTaskSupport
import scala.concurrent.forkjoin.ForkJoinPool
import com.databricks.labs.overwatch.utils.Helpers.removeTrailingSlashes
import io.delta.tables.DeltaTable
import org.apache.spark.sql.{Column, DataFrame, DataFrameWriter, Row}

class Snapshot (_sourceETLDB: String, _targetPrefix: String, _workspace: Workspace, _database: Database, _config: Config)
extends Pipeline(_workspace, _database, _config){


import spark.implicits._
private val snapshotRootPath = removeTrailingSlashes(_targetPrefix)
private val workSpace = _workspace
private val bronze = Bronze(workSpace)
private val silver = Silver(workSpace)
private val gold = Gold(workSpace)

private val logger: Logger = Logger.getLogger(this.getClass)
private val driverCores = java.lang.Runtime.getRuntime.availableProcessors()
val Config = _config


private def parallelism: Int = {
driverCores
}

private[overwatch] def snapStream(cloneDetails: Seq[CloneDetail]): Unit = {

val cloneDetailsPar = cloneDetails.par
val taskSupport = new ForkJoinTaskSupport(new ForkJoinPool(parallelism))
cloneDetailsPar.tasksupport = taskSupport
import spark.implicits._
spark.conf.set("spark.databricks.delta.schema.autoMerge.enabled",true)

logger.log(Level.INFO, "Streaming START:")
val cloneReport = cloneDetailsPar.map(cloneSpec => {
try {
val rawStreamingDF = spark.readStream.format("delta").option("ignoreChanges", "true").load(s"${cloneSpec.source}")
val sourceName = s"${cloneSpec.source}".split("/").takeRight(1).head
val checkPointLocation = s"${snapshotRootPath}/checkpoint/${sourceName}"
val targetLocation = s"${cloneSpec.target}"
val streamWriter = if (cloneSpec.mode == WriteMode.merge) {
if (Helpers.pathExists(targetLocation)) {
val deltaTable = DeltaTable.forPath(spark, targetLocation)
val immutableColumns = cloneSpec.immutableColumns

def upsertToDelta(microBatchOutputDF: DataFrame, batchId: Long): Unit = {

val mergeCondition: String = immutableColumns.map(k => s"updates.$k = target.$k").mkString(" AND ")
deltaTable.as("target")
.merge(
microBatchOutputDF.as("updates"),
mergeCondition)
.whenMatched().updateAll()
.whenNotMatched().insertAll()
.execute()
}

rawStreamingDF
.writeStream
.format("delta")
.outputMode("append")
.foreachBatch(upsertToDelta _)
.trigger(Trigger.Once())
.option("checkpointLocation", checkPointLocation)
.option("mergeSchema", "true")
.option("path", targetLocation)
.start()

} else {
rawStreamingDF
.writeStream
.format("delta")
.trigger(Trigger.Once())
.option("checkpointLocation", checkPointLocation)
.option("mergeSchema", "true")
.option("path", targetLocation)
.start()
}
} else {
rawStreamingDF
.writeStream
.format("delta")
.trigger(Trigger.Once())
.option("checkpointLocation", checkPointLocation)
.option("mergeSchema", "true")
.option("path", targetLocation)
.start()
}
val streamManager = Helpers.getQueryListener(streamWriter, workspace.getConfig, workspace.getConfig.auditLogConfig.azureAuditLogEventhubConfig.get.minEventsPerTrigger)
spark.streams.addListener(streamManager)
val listenerAddedMsg = s"Event Listener Added.\nStream: ${streamWriter.name}\nID: ${streamWriter.id}"
if (config.debugFlag) println(listenerAddedMsg)
logger.log(Level.INFO, listenerAddedMsg)

streamWriter.awaitTermination()
spark.streams.removeListener(streamManager)

logger.log(Level.INFO, s"Streaming COMPLETE: ${cloneSpec.source} --> ${cloneSpec.target}")
CloneReport(cloneSpec, s"Streaming For: ${cloneSpec.source} --> ${cloneSpec.target}", "SUCCESS")
} catch {
case e: Throwable if (e.getMessage.contains("is after the latest commit timestamp of")) => {
val failMsg = PipelineFunctions.appendStackStrace(e)
val msg = s"SUCCESS WITH WARNINGS: The timestamp provided, ${cloneSpec.asOfTS.get} " +
s"resulted in a temporally unsafe exception. Cloned the source without the as of timestamp arg. " +
s"\nDELTA ERROR MESSAGE: ${failMsg}"
logger.log(Level.WARN, msg)
CloneReport(cloneSpec, s"Streaming For: ${cloneSpec.source} --> ${cloneSpec.target}", msg)
}
case e: Throwable => {
val failMsg = PipelineFunctions.appendStackStrace(e)
CloneReport(cloneSpec, s"Streaming For: ${cloneSpec.source} --> ${cloneSpec.target}", failMsg)
}
}
}).toArray.toSeq
val cloneReportPath = s"${snapshotRootPath}/clone_report/"
cloneReport.toDS.write.mode("append").option("mergeSchema", "true").format("delta").save(cloneReportPath)
}


private[overwatch] def buildCloneSpecs(
sourceToSnap: Array[PipelineTable]
): Seq[CloneDetail] = {

val finalSnapshotRootPath = s"${snapshotRootPath}/data"
val cloneSpecs = sourceToSnap.map(dataset => {
val sourceName = dataset.name.toLowerCase
val sourcePath = dataset.tableLocation
val mode = dataset._mode
val immutableColumns = (dataset.keys ++ dataset.incrementalColumns).distinct
val targetPath = s"$finalSnapshotRootPath/$sourceName"
CloneDetail(sourcePath, targetPath, None, "Deep", immutableColumns, mode)
}).toArray.toSeq
cloneSpecs
}

private[overwatch] def incrementalSnap(
pipeline : String,
excludes: Option[String] = Some("")
): this.type = {


val sourceToSnap = {
if (pipeline.toLowerCase() == "bronze") bronze.getAllTargets
else if (pipeline.toLowerCase() == "silver") silver.getAllTargets
else if (pipeline.toLowerCase() == "gold") gold.getAllTargets
else Array(pipelineStateTarget)
}

val exclude = excludes match {
case Some(s) if s.nonEmpty => s
case _ => ""
}
val excludeList = exclude.split(":")

val cleanExcludes = excludeList.map(_.toLowerCase).map(exclude => {
if (exclude.contains(".")) exclude.split("\\.").takeRight(1).head else exclude
})
cleanExcludes.foreach(x => println(x))

val sourceToSnapFiltered = sourceToSnap
.filter(_.exists()) // source path must exist
.filterNot(t => cleanExcludes.contains(t.name.toLowerCase))

val cloneSpecs = buildCloneSpecs(sourceToSnapFiltered)
snapStream(cloneSpecs)
this
}

private[overwatch] def snap(
pipeline : String,
cloneLevel: String = "DEEP",
excludes: Option[String] = Some("")
): this.type = {
val acceptableCloneLevels = Array("DEEP", "SHALLOW")
require(acceptableCloneLevels.contains(cloneLevel.toUpperCase), s"SNAP CLONE ERROR: cloneLevel provided is " +
s"$cloneLevel. CloneLevels supported are ${acceptableCloneLevels.mkString(",")}.")

val sourceToSnap = {
if (pipeline.toLowerCase() == "bronze") bronze.getAllTargets
else if (pipeline.toLowerCase() == "silver") silver.getAllTargets
else if (pipeline.toLowerCase() == "gold") gold.getAllTargets
else Array(pipelineStateTarget)
}

val exclude = excludes match {
case Some(s) if s.nonEmpty => s
case _ => ""
}
val excludeList = exclude.split(":")

val cleanExcludes = excludeList.map(_.toLowerCase).map(exclude => {
if (exclude.contains(".")) exclude.split("\\.").takeRight(1).head else exclude
})
cleanExcludes.foreach(x => println(x))


val sourceToSnapFiltered = sourceToSnap
.filter(_.exists()) // source path must exist
.filterNot(t => cleanExcludes.contains(t.name.toLowerCase))

val cloneSpecs = buildCloneSpecs(sourceToSnapFiltered)
val cloneReport = Helpers.parClone(cloneSpecs)
val cloneReportPath = s"${snapshotRootPath}/clone_report/"
cloneReport.toDS.write.format("delta").mode("append").save(cloneReportPath)
this
}

}

object Snapshot extends SparkSessionWrapper {


def apply(workspace: Workspace,
sourceETLDB : String,
targetPrefix : String,
pipeline : String,
snapshotType: String,
excludes: Option[String],
CloneLevel: String
): Any = {
if (snapshotType.toLowerCase() == "incremental")
new Snapshot(sourceETLDB, targetPrefix, workspace, workspace.database, workspace.getConfig).incrementalSnap(pipeline, excludes)
if (snapshotType.toLowerCase() == "full")
new Snapshot(sourceETLDB, targetPrefix, workspace, workspace.database, workspace.getConfig).snap(pipeline, CloneLevel, excludes)

}


/**
* Create a backup of the Overwatch datasets
*
* @param arg(0) Source Database Name.
* @param arg(1) Target snapshotRootPath
* @param arg(2) Medallion layers to snapshot. The argument should be of the form "Bronze,Silver,Gold" (all three or any combination of them)
* @param arg(3) Type of snapshot to perform: "Full" for a full snapshot, "Incremental" for an incremental snapshot
* @param arg(4) Colon-separated list of table names to exclude from the snapshot;
*               table names only, without the database prefix
* @param arg(5) Clone level for full snapshots: "Deep" (default) or "Shallow"
*/

def main(args: Array[String]): Unit = {

val sourceETLDB = args(0)
val snapshotRootPath = args(1)
val pipeline = args(2)
val snapshotType = args(3)
val tablesToExclude = args.lift(4).getOrElse("")
val cloneLevel = args.lift(5).getOrElse("Deep")

val snapWorkSpace = Helpers.getWorkspaceByDatabase(sourceETLDB)

val pipelineLower = pipeline.toLowerCase
if (pipelineLower.contains("bronze")) Snapshot(snapWorkSpace,sourceETLDB,snapshotRootPath,"Bronze",snapshotType,Some(tablesToExclude),cloneLevel)
if (pipelineLower.contains("silver")) Snapshot(snapWorkSpace,sourceETLDB,snapshotRootPath,"Silver",snapshotType,Some(tablesToExclude),cloneLevel)
if (pipelineLower.contains("gold")) Snapshot(snapWorkSpace,sourceETLDB,snapshotRootPath,"Gold",snapshotType,Some(tablesToExclude),cloneLevel)
Snapshot(snapWorkSpace,sourceETLDB,snapshotRootPath,"pipeline_report",snapshotType,Some(tablesToExclude),cloneLevel)

println("SnapShot Completed")
}







}
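For context, a minimal invocation sketch of the new entry point. The database name, paths, and excluded table below are hypothetical values; the colon-separated exclude list and the optional sixth clone-level argument follow the parsing in main above.

// Hypothetical job arguments: incrementally snapshot the Bronze and Silver layers
// of the "overwatch_etl" database, excluding the audit_log_bronze table.
Snapshot.main(Array(
  "overwatch_etl",            // arg(0): source ETL database
  "/mnt/overwatch/snapshots", // arg(1): target snapshot root path
  "Bronze,Silver",            // arg(2): medallion layers to snapshot
  "Incremental",              // arg(3): snapshot type ("Full" or "Incremental")
  "audit_log_bronze",         // arg(4): colon-separated tables to exclude (optional)
  "Deep"                      // arg(5): clone level, used only by full snapshots (optional)
))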


@@ -320,7 +320,8 @@ case class MultiWSDeploymentReport(

case class WorkspaceMetastoreRegistrationReport(workspaceDataset: WorkspaceDataset, registerStatement: String, status: String)

case class CloneDetail(source: String, target: String, asOfTS: Option[String] = None, cloneLevel: String = "DEEP")

case class CloneDetail(source: String, target: String, asOfTS: Option[String] = None, cloneLevel: String = "DEEP", immutableColumns: Array[String] = Array(), mode: WriteMode.WriteMode = WriteMode.append)

case class CloneReport(cloneSpec: CloneDetail, cloneStatement: String, status: String)
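For illustration, a hypothetical CloneDetail for an incremental (merge-mode) target. The paths and key columns are placeholder values; in the pipeline they are derived from PipelineTable.tableLocation, keys, and incrementalColumns in buildCloneSpecs.

// Sketch only: placeholder paths and merge keys.
val spec = CloneDetail(
  source = "/mnt/overwatch/global_share/jobrun_silver",   // placeholder source table location
  target = "/mnt/overwatch/snapshots/data/jobrun_silver", // placeholder snapshot target
  asOfTS = None,
  cloneLevel = "Deep",
  immutableColumns = Array("organization_id", "run_id"),  // placeholder merge keys
  mode = WriteMode.merge                                   // stream-merge rather than append
)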

25 changes: 25 additions & 0 deletions src/main/scala/com/databricks/labs/overwatch/utils/Tools.scala
@@ -19,6 +19,8 @@ import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryListener}
import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent}
import org.apache.spark.util.SerializableConfiguration

import java.net.URI
@@ -1029,4 +1031,27 @@ object Helpers extends SparkSessionWrapper {
when(url.endsWith("/"), url.substr(lit(0), length(url) - 1)).otherwise(url)
}

def getQueryListener(query: StreamingQuery, config: Config, minEventsPerTrigger: Long): StreamingQueryListener = {
val streamManager = new StreamingQueryListener() {
override def onQueryStarted(queryStarted: QueryStartedEvent): Unit = {
println("Query started: " + queryStarted.id)
}

override def onQueryTerminated(queryTerminated: QueryTerminatedEvent): Unit = {
println("Query terminated: " + queryTerminated.id)
}

override def onQueryProgress(queryProgress: QueryProgressEvent): Unit = {
println("Query made progress: " + queryProgress.progress)
if (config.debugFlag) {
println(query.status.prettyJson)
}
if (queryProgress.progress.numInputRows <= minEventsPerTrigger) {
query.stop()
}
}
}
streamManager
}

}
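A brief usage sketch for the relocated listener; the streaming DataFrame, paths, and threshold are placeholder values. The listener stops the query once a progress event reports minEventsPerTrigger input rows or fewer.

// Hypothetical wiring: attach the listener so a stream shuts itself down
// once incoming volume drops below the threshold.
val query = eventsDF.writeStream
  .format("delta")
  .option("checkpointLocation", "/tmp/checkpoints/audit_log") // placeholder path
  .option("path", "/tmp/tables/audit_log")                    // placeholder path
  .start()
val listener = Helpers.getQueryListener(query, config, minEventsPerTrigger = 100L)
spark.streams.addListener(listener)
query.awaitTermination()
spark.streams.removeListener(listener)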