Skip to content

Commit

Permalink
Sync-up the support of parse_url in qualification tool (#1195)
Browse files Browse the repository at this point in the history
* Sync-up the support of parse_url in qualification tool

Fixes #1190

- parse_url(*,query,*) should be treated as supported
- fix the unit-test
- identify unsupported parts ref-file-authority-userinfo

---------

Signed-off-by: Ahmed Hussein (amahussein) <a@ahussein.me>
  • Loading branch information
amahussein authored Jul 17, 2024
1 parent c580851 commit b31e00a
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,14 @@ object SQLPlanParser extends Logging {
// We do not want them to appear as independent expressions.
"structfield", "structtype")

// As RAPIDS plugin rev 2b09372, it only supports parse_url(*,HOST|PROTOCOL|QUERY|PATH[,*]).
// the following partToExtract parse_url(*,REF|FILE|AUTHORITY|USERINFO[,*]) are not supported
val unsupportedParseURLParts = Set("FILE", "REF", "AUTHORITY", "USERINFO")
// define a pattern to identify whether a certain string contains the unsupported extractParts of
// the parse_url
val regExParseURLPart =
s"(?i)parse_url\\(.*,\\s*(${unsupportedParseURLParts.mkString("|")})(?:\\s*,.*)*\\)".r

/**
* This function is used to create a set of nodes that should be skipped while parsing the Execs
* of a specific node.
Expand Down Expand Up @@ -652,9 +660,10 @@ object SQLPlanParser extends Logging {

// This method aims at doing some common processing to an expression before
// we start parsing it. For example, some special handling is required for some functions.
private def processSpecialFunctions(expr: String): String = {
// For parse_url, we only support parse_url(*,Host,*); parse_url(*,Protocol,*)
// So we want to be able to define that parse_url(*,QUERY,*) is not supported.
def processSpecialFunctions(expr: String): String = {
// For parse_url, we only support parse_url(*,HOST|PROTOCOL|QUERY|PATH[,*]).
// So we want to be able to define that parse_url(*,REF|FILE|AUTHORITY|USERINFO[,*])
// is not supported.

// The following regex uses forward references to find matches for parse_url(*)
// we need to use forward references because otherwise multiple occurrences will be matched
Expand All @@ -666,16 +675,25 @@ object SQLPlanParser extends Logging {
// parse_url(url_col#7, QUERY, false) AS QUERY#10]
val parseURLPattern = ("parse_url(?=\\()(?:(?=.*?\\((?!.*?\\1)(.*\\)(?!.*\\2).*))(?=.*?\\)" +
"(?!.*?\\2)(.*)).)+?.*?(?=\\1)[^(]*(?=\\2$)").r
var newExpr = expr
parseURLPattern.findAllMatchIn(expr).foreach { parse_call =>
// iterate on all matches replacing parse_url by parse_url_query
// note that we do replaceFirst because we want to map 1-to-1 and the order does
// not matter here.
if (parse_call.matched.matches("parse_url\\(.*,\\s*(?i)query\\s*,.*\\)")) {
newExpr = newExpr.replaceFirst("parse_url\\(", "parse_url_query(")
val allMatches = parseURLPattern.findAllMatchIn(expr)
if (allMatches.nonEmpty) {
var newExpr = expr
allMatches.foreach { parse_call =>
// iterate on all matches replacing parse_url by parse_url_{parttoextract} if any
// note that we do replaceFirst because we want to map 1-to-1 and the order does
// not matter here.
val matched = parse_call.matched
val extractPart = regExParseURLPart.findFirstMatchIn(matched).map(_.group(1))
if (extractPart.isDefined) {
val replacedParseClass =
matched.replaceFirst("parse_url\\(", s"parse_url_${extractPart.get.toLowerCase}(")
newExpr = newExpr.replace(matched, replacedParseClass)
}
}
newExpr
} else {
expr
}
newExpr
}

private def getAllFunctionNames(regPattern: Regex, expr: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import scala.util.control.NonFatal
import com.nvidia.spark.rapids.BaseTestSuite
import com.nvidia.spark.rapids.tool.{EventLogPathProcessor, ToolTestUtils}
import com.nvidia.spark.rapids.tool.qualification._
import org.scalatest.Matchers.{be, convertToAnyShouldWrapper}
import org.scalatest.Matchers.{be, contain, convertToAnyShouldWrapper}
import org.scalatest.exceptions.TestFailedException

import org.apache.spark.sql.{DataFrame, TrampolineUtil}
Expand Down Expand Up @@ -1192,9 +1192,35 @@ class SQLPlanParserSuite extends BaseTestSuite {
}
}

test("ParseUrl is supported except that for parse_url_query") {
// parse_url(*,QUERY,*) should cause the project to be unsupported
// the expression will appear in the unsupportedExpression summary
test("ParseUrl is supported") {
// Verify that each partToExtract expression causes the parse_url to be renamed.
SQLPlanParser.unsupportedParseURLParts.foreach { part =>
val partLC = part.toLowerCase
val a1 = SQLPlanParser.processSpecialFunctions(s"parse_url(test1, $part, cast(test))")
val a2 = SQLPlanParser.processSpecialFunctions(s"parse_url(test1, $partLC, cast(test))")
val a3 = SQLPlanParser.processSpecialFunctions(s"parse_url(test1, $part)")
val a4 = SQLPlanParser.processSpecialFunctions(s"parse_url(test1, $partLC)")
a1 shouldEqual s"parse_url_$partLC(test1, $part, cast(test))"
a2 shouldEqual s"parse_url_$partLC(test1, $partLC, cast(test))"
a3 shouldEqual s"parse_url_$partLC(test1, $part)"
a4 shouldEqual s"parse_url_$partLC(test1, $partLC)"
}
// verify that having the keywords in different argument does not affect the correctness
val a5 = SQLPlanParser.processSpecialFunctions("parse_url(AUTHORITY, ANY_PART, query)")
a5 shouldEqual "parse_url(AUTHORITY, ANY_PART, query)"
// verify multiple calls
val a6 = SQLPlanParser.processSpecialFunctions(
"parse_url(AUTHORITY, ANY_PART, query), parse_url(AUTHORITY, REF, query), " +
"parse_url(AUTHORITY, FILE, query)")
a6 shouldEqual "parse_url(AUTHORITY, ANY_PART, query), parse_url_ref(AUTHORITY, REF, query), " +
"parse_url_file(AUTHORITY, FILE, query)"
// verify nested. Note it does not matter the order as long as we get it right
val a7 = SQLPlanParser.processSpecialFunctions(
"parse_url(parse_url(AUTHORITY, ANY_PART, query), REF, query)")
val a8 = SQLPlanParser.processSpecialFunctions(
"parse_url(parse_url(AUTHORITY, REF, query), ANY_PART, query)")
a7 shouldEqual "parse_url_ref(parse_url(AUTHORITY, ANY_PART, query), REF, query)"
a8 shouldEqual "parse_url_ref(parse_url(AUTHORITY, REF, query), ANY_PART, query)"
TrampolineUtil.withTempDir { parquetoutputLoc =>
TrampolineUtil.withTempDir { eventLogDir =>
val (eventLog, _) = ToolTestUtils.generateEventLog(eventLogDir,
Expand All @@ -1207,7 +1233,8 @@ class SQLPlanParserSuite extends BaseTestSuite {
df1.write.parquet(s"$parquetoutputLoc/testparse")
val df2 = spark.read.parquet(s"$parquetoutputLoc/testparse")
df2.selectExpr("*", "parse_url(`url_col`, 'HOST') as HOST",
"parse_url(`url_col`,'QUERY') as QUERY")
"parse_url(`url_col`,'QUERY') as QUERY", "parse_url(`url_col`, 'REF') as REF",
"parse_url(`url_col`,'USERINFO') as NOT_QUERY")
}
val pluginTypeChecker = new PluginTypeChecker()
val app = createAppFromEventlog(eventLog)
Expand All @@ -1219,6 +1246,8 @@ class SQLPlanParserSuite extends BaseTestSuite {
val allExecInfo = getAllExecsFromPlan(parsedPlans.toSeq)
val projects = allExecInfo.filter(_.exec.contains("Project"))
assertSizeAndNotSupported(1, projects)
val expectedExprss = Seq("parse_url_ref", "parse_url_userinfo")
projects(0).unsupportedExprs.map(_.exprName) should contain theSameElementsAs expectedExprss
}
}
}
Expand Down

0 comments on commit b31e00a

Please sign in to comment.