diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 2b367ff6b3f..4fffc82ab2d 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -3278,6 +3278,19 @@ private static void assertIsSupportedMapKeyType(DType keyType) { assert isSupportedKeyType : "Map lookup by STRUCT and LIST keys is not supported."; } + /** + * Given a column of type List> and a key column of type X, return a column of type Y, + * where each row in the output column is the Y value corresponding to the X key. + * If the key is not found, the corresponding output value is null. + * @param keys the column view with keys to lookup in the column + * @return a column of values or nulls based on the lookup result + */ + public final ColumnVector getMapValue(ColumnView keys) { + assert type.equals(DType.LIST) : "column type must be a LIST"; + assert keys != null : "Lookup key may not be null"; + return new ColumnVector(mapLookupForKeys(getNativeView(), keys.getNativeView())); + } + /** * Given a column of type List> and a key of type X, return a column of type Y, * where each row in the output column is the Y value corresponding to the X key. @@ -3913,6 +3926,21 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat */ private static native long mapLookup(long columnView, long key) throws CudfException; + /** + * Native method for map lookup over a column of List> + * The lookup column must have as many rows as the map column, + * and must match the key-type of the map. + * A column of values is returned, with the same number of rows as the map column. + * If a key is repeated in a map row, the value corresponding to the last matching + * key is returned. + * If a lookup key is null or not found, the corresponding value is null. + * @param columnView the column view handle of the map + * @param keys the column view holding the keys + * @return a column of values corresponding the value of the lookup key. + * @throws CudfException + */ + private static native long mapLookupForKeys(long columnView, long keys) throws CudfException; + /** * Native method for check the existence of a key over a column of List> * @param columnView the column view handle of the map diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index 664f4a5561d..a3ccffdf8b3 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -1391,6 +1391,21 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringReplace(JNIEnv *env CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_mapLookupForKeys(JNIEnv *env, jclass, + jlong map_column_view, + jlong lookup_keys) { + JNI_NULL_CHECK(env, map_column_view, "column is null", 0); + JNI_NULL_CHECK(env, lookup_keys, "lookup key is null", 0); + try { + cudf::jni::auto_set_device(env); + auto const *cv = reinterpret_cast(map_column_view); + auto const *column_keys = reinterpret_cast(lookup_keys); + auto const maps_view = cudf::jni::maps_column_view{*cv}; + return release_as_jlong(maps_view.get_values_for(*column_keys)); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_mapLookup(JNIEnv *env, jclass, jlong map_column_view, jlong lookup_key) { diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 0c279a1e788..c38da81d38b 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -5864,6 +5864,21 @@ void testStructChildValidity() { } } + @Test + void testGetMapValueForKeys() { + List list1 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList(1, 2))); + List list2 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList(2, 3))); + List list3 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList(5, 4))); + HostColumnVector.StructType structType = new HostColumnVector.StructType(true, Arrays.asList(new HostColumnVector.BasicType(true, DType.INT32), + new HostColumnVector.BasicType(true, DType.INT32))); + try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, structType), list1, list2, list3); + ColumnVector lookupKey = ColumnVector.fromInts(1, 6, 5); + ColumnVector res = cv.getMapValue(lookupKey); + ColumnVector expected = ColumnVector.fromBoxedInts(2, null, 4)) { + assertColumnsAreEqual(expected, res); + } + } + @Test void testGetMapValueForInteger() { List list1 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList(1, 2)));