Skip to content

Commit

Permalink
Relax cudf version check for patch-level versions (NVIDIA#1930)
Browse files Browse the repository at this point in the history
* Relax cudf version check for patch-level versions

Signed-off-by: Jason Lowe <jlowe@nvidia.com>

* whitespace

* cleanup and use NumberFormatException

* Cleanup checking with Try and startsWith

* Add tests for many patch levels
  • Loading branch information
jlowe authored Mar 18, 2021
1 parent f048750 commit 89eb0b0
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 1 deletion.
23 changes: 22 additions & 1 deletion sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.util.Properties
import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}

import scala.collection.JavaConverters._
import scala.util.Try

import com.nvidia.spark.rapids.python.PythonWorkerSemaphore

Expand Down Expand Up @@ -222,7 +223,7 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging {
}
val expectedCudfVersion = pluginCudfVersion.toString
// compare cudf version in the classpath with the cudf version expected by plugin
if (!cudfVersion.equals(expectedCudfVersion)) {
if (!RapidsExecutorPlugin.cudfVersionSatisfied(expectedCudfVersion, cudfVersion)) {
throw CudfVersionMismatchException(s"Cudf version in the classpath is different. " +
s"Found $cudfVersion, RAPIDS Accelerator expects $expectedCudfVersion")
}
Expand All @@ -242,6 +243,26 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging {
}
}

object RapidsExecutorPlugin {
/**
* Return true if the expected cudf version is satisfied by the actual version found.
* The version is satisfied if the major and minor versions match exactly. If there is a requested
* patch version then the actual patch version must be greater than or equal.
* For example, version 7.1 is not satisfied by version 7.2, but version 7.1 is satisfied by
* version 7.1.1.
*/
def cudfVersionSatisfied(expected: String, actual: String): Boolean = {
val (expMajorMinor, expPatch) = expected.split('.').splitAt(2)
val (actMajorMinor, actPatch) = actual.split('.').splitAt(2)
actMajorMinor.startsWith(expMajorMinor) && {
val expPatchInts = expPatch.map(_.toInt)
val actPatchInts = actPatch.map(v => Try(v.toInt).getOrElse(Int.MinValue))
val zipped = expPatchInts.zipAll(actPatchInts, 0, 0)
zipped.forall { case (e, a) => e <= a }
}
}
}

object ExecutionPlanCaptureCallback {
private[this] val shouldCapture: AtomicBoolean = new AtomicBoolean(false)
private[this] val execPlan: AtomicReference[SparkPlan] = new AtomicReference[SparkPlan]()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import org.scalatest.FunSuite

class RapidsExecutorPluginSuite extends FunSuite {
test("cudf version check") {
assert(RapidsExecutorPlugin.cudfVersionSatisfied("7", "7"))
assert(!RapidsExecutorPlugin.cudfVersionSatisfied("7", "8"))
assert(RapidsExecutorPlugin.cudfVersionSatisfied("7", "7.2"))
assert(!RapidsExecutorPlugin.cudfVersionSatisfied("7", "8.7"))
assert(RapidsExecutorPlugin.cudfVersionSatisfied("7", "7.2.1"))
assert(RapidsExecutorPlugin.cudfVersionSatisfied("7.0", "7.0"))
assert(RapidsExecutorPlugin.cudfVersionSatisfied("7.0", "7.0.1"))
assert(RapidsExecutorPlugin.cudfVersionSatisfied("7.0", "7.0.1.3"))
assert(!RapidsExecutorPlugin.cudfVersionSatisfied("7.0", "7"))
assert(!RapidsExecutorPlugin.cudfVersionSatisfied("7.0", "7.1"))
assert(RapidsExecutorPlugin.cudfVersionSatisfied("7.0.1", "7.0.1"))
assert(RapidsExecutorPlugin.cudfVersionSatisfied("7.0.1", "7.0.1.3"))
assert(RapidsExecutorPlugin.cudfVersionSatisfied("7.0.1", "7.0.2"))
assert(RapidsExecutorPlugin.cudfVersionSatisfied("7.0.1", "7.0.2.3"))
assert(!RapidsExecutorPlugin.cudfVersionSatisfied("7.0.1", "7"))
assert(!RapidsExecutorPlugin.cudfVersionSatisfied("7.0.1", "7.0"))
assert(!RapidsExecutorPlugin.cudfVersionSatisfied("7.0.1", "7.0.0"))
assert(!RapidsExecutorPlugin.cudfVersionSatisfied("7.0.1", "7.1"))
assert(!RapidsExecutorPlugin.cudfVersionSatisfied("7.0.1", "7.1.1"))
assert(!RapidsExecutorPlugin.cudfVersionSatisfied("7.0.1", "7.0.1-special"))
assert(!RapidsExecutorPlugin.cudfVersionSatisfied("7.0.2.2.2", "7.0.2.2"))
assert(RapidsExecutorPlugin.cudfVersionSatisfied("7.0.2.2.2", "7.0.2.2.2"))
}
}

0 comments on commit 89eb0b0

Please sign in to comment.