diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala index 6b6e5c6a750..92da5f807d1 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala @@ -21,6 +21,7 @@ import java.util.Properties import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} import scala.collection.JavaConverters._ +import scala.util.Try import com.nvidia.spark.rapids.python.PythonWorkerSemaphore @@ -222,7 +223,7 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging { } val expectedCudfVersion = pluginCudfVersion.toString // compare cudf version in the classpath with the cudf version expected by plugin - if (!cudfVersion.equals(expectedCudfVersion)) { + if (!RapidsExecutorPlugin.cudfVersionSatisfied(expectedCudfVersion, cudfVersion)) { throw CudfVersionMismatchException(s"Cudf version in the classpath is different. " + s"Found $cudfVersion, RAPIDS Accelerator expects $expectedCudfVersion") } @@ -242,6 +243,26 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging { } } +object RapidsExecutorPlugin { + /** + * Return true if the expected cudf version is satisfied by the actual version found. + * The version is satisfied if the major and minor versions match exactly. If there is a requested + * patch version then the actual patch version must be greater than or equal. + * For example, version 7.1 is not satisfied by version 7.2, but version 7.1 is satisfied by + * version 7.1.1. + */ + def cudfVersionSatisfied(expected: String, actual: String): Boolean = { + val (expMajorMinor, expPatch) = expected.split('.').splitAt(2) + val (actMajorMinor, actPatch) = actual.split('.').splitAt(2) + actMajorMinor.startsWith(expMajorMinor) && { + val expPatchInts = expPatch.map(_.toInt) + val actPatchInts = actPatch.map(v => Try(v.toInt).getOrElse(Int.MinValue)) + val zipped = expPatchInts.zipAll(actPatchInts, 0, 0) + zipped.forall { case (e, a) => e <= a } + } + } +} + object ExecutionPlanCaptureCallback { private[this] val shouldCapture: AtomicBoolean = new AtomicBoolean(false) private[this] val execPlan: AtomicReference[SparkPlan] = new AtomicReference[SparkPlan]() diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsExecutorPluginSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsExecutorPluginSuite.scala new file mode 100644 index 00000000000..1ca5aad2bf4 --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsExecutorPluginSuite.scala @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import org.scalatest.FunSuite + +class RapidsExecutorPluginSuite extends FunSuite { + test("cudf version check") { + assert(RapidsExecutorPlugin.cudfVersionSatisfied("7", "7")) + assert(!RapidsExecutorPlugin.cudfVersionSatisfied("7", "8")) + assert(RapidsExecutorPlugin.cudfVersionSatisfied("7", "7.2")) + assert(!RapidsExecutorPlugin.cudfVersionSatisfied("7", "8.7")) + assert(RapidsExecutorPlugin.cudfVersionSatisfied("7", "7.2.1")) + assert(RapidsExecutorPlugin.cudfVersionSatisfied("7.0", "7.0")) + assert(RapidsExecutorPlugin.cudfVersionSatisfied("7.0", "7.0.1")) + assert(RapidsExecutorPlugin.cudfVersionSatisfied("7.0", "7.0.1.3")) + assert(!RapidsExecutorPlugin.cudfVersionSatisfied("7.0", "7")) + assert(!RapidsExecutorPlugin.cudfVersionSatisfied("7.0", "7.1")) + assert(RapidsExecutorPlugin.cudfVersionSatisfied("7.0.1", "7.0.1")) + assert(RapidsExecutorPlugin.cudfVersionSatisfied("7.0.1", "7.0.1.3")) + assert(RapidsExecutorPlugin.cudfVersionSatisfied("7.0.1", "7.0.2")) + assert(RapidsExecutorPlugin.cudfVersionSatisfied("7.0.1", "7.0.2.3")) + assert(!RapidsExecutorPlugin.cudfVersionSatisfied("7.0.1", "7")) + assert(!RapidsExecutorPlugin.cudfVersionSatisfied("7.0.1", "7.0")) + assert(!RapidsExecutorPlugin.cudfVersionSatisfied("7.0.1", "7.0.0")) + assert(!RapidsExecutorPlugin.cudfVersionSatisfied("7.0.1", "7.1")) + assert(!RapidsExecutorPlugin.cudfVersionSatisfied("7.0.1", "7.1.1")) + assert(!RapidsExecutorPlugin.cudfVersionSatisfied("7.0.1", "7.0.1-special")) + assert(!RapidsExecutorPlugin.cudfVersionSatisfied("7.0.2.2.2", "7.0.2.2")) + assert(RapidsExecutorPlugin.cudfVersionSatisfied("7.0.2.2.2", "7.0.2.2.2")) + } +}