Skip to content

Commit

Permalink
Java bindings for left outer distinct join (#15154)
Browse files Browse the repository at this point in the history
Adds Java bindings to the distinct left join functionality added in #15149.

Authors:
  - Jason Lowe (https://github.com/jlowe)

Approvers:
  - Robert (Bobby) Evans (https://github.com/revans2)
  - Jim Brennan (https://github.com/jbrennan333)

URL: #15154
  • Loading branch information
jlowe authored Mar 6, 2024
1 parent eb8de18 commit d824fa5
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 11 deletions.
52 changes: 41 additions & 11 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,9 @@ private static native long[] merge(long[] tableHandles, int[] sortKeyIndexes,
private static native long[] leftJoinGatherMaps(long leftKeys, long rightKeys,
boolean compareNullsEqual) throws CudfException;

/**
 * Native entry point backing {@link #leftDistinctJoinGatherMap(Table, boolean)}.
 * The public wrapper passes the native view handles of the left and right key tables and
 * hands the returned array to the single-gather-map builder, which reads the buffer size,
 * address and handle from the first three elements.
 *
 * @param leftKeys native view handle of the left key table
 * @param rightKeys native view handle of the right key table
 * @param compareNullsEqual true if null key values should match otherwise false
 * @return raw gather map data describing the right table gather map
 */
private static native long[] leftDistinctJoinGatherMap(long leftKeys, long rightKeys,
boolean compareNullsEqual) throws CudfException;

private static native long leftJoinRowCount(long leftTable, long rightHashJoin) throws CudfException;

private static native long[] leftHashJoinGatherMaps(long leftTable, long rightHashJoin) throws CudfException;
Expand Down Expand Up @@ -2949,6 +2952,33 @@ public GatherMap[] leftJoinGatherMaps(Table rightKeys, boolean compareNullsEqual
return buildJoinGatherMaps(gatherMapData);
}

/**
 * Computes a gather map that can be used to manifest the result of a left equi-join between
 * two tables where the right table is guaranteed to not contain any duplicated join keys.
 * The left table can be used as-is to produce the left table columns resulting from the join,
 * i.e.: left table ordering is preserved in the join result, so no gather map is required for
 * the left table. The resulting gather map can be applied to the right table to produce the
 * right table columns resulting from the join. It is assumed this table instance holds the
 * key columns from the left table, and the table argument represents the key columns from the
 * right table. A {@link GatherMap} instance will be returned that can be used to gather the
 * right table and that result combined with the left table to produce a left outer join result.
 *
 * It is the responsibility of the caller to close the resulting gather map instance.
 *
 * @param rightKeys join key columns from the right table
 * @param compareNullsEqual true if null key values should match otherwise false
 * @return right table gather map
 * @throws IllegalArgumentException if this table and {@code rightKeys} do not have the same
 *         number of columns
 */
public GatherMap leftDistinctJoinGatherMap(Table rightKeys, boolean compareNullsEqual) {
  if (getNumberOfColumns() != rightKeys.getNumberOfColumns()) {
    // Separate the two counts so the message does not read like "this: 3rightKeys: 4".
    throw new IllegalArgumentException("Column count mismatch, this: " + getNumberOfColumns() +
        ", rightKeys: " + rightKeys.getNumberOfColumns());
  }
  long[] gatherMapData =
      leftDistinctJoinGatherMap(getNativeView(), rightKeys.getNativeView(), compareNullsEqual);
  return buildSingleJoinGatherMap(gatherMapData);
}

/**
* Computes the number of rows resulting from a left equi-join between two tables.
* It is assumed this table instance holds the key columns from the left table, and the
Expand Down Expand Up @@ -3576,7 +3606,7 @@ public static GatherMap[] mixedFullJoinGatherMaps(Table leftKeys, Table rightKey
return buildJoinGatherMaps(gatherMapData);
}

private static GatherMap buildSemiJoinGatherMap(long[] gatherMapData) {
private static GatherMap buildSingleJoinGatherMap(long[] gatherMapData) {
long bufferSize = gatherMapData[0];
long leftAddr = gatherMapData[1];
long leftHandle = gatherMapData[2];
Expand All @@ -3601,7 +3631,7 @@ public GatherMap leftSemiJoinGatherMap(Table rightKeys, boolean compareNullsEqua
}
long[] gatherMapData =
leftSemiJoinGatherMap(getNativeView(), rightKeys.getNativeView(), compareNullsEqual);
return buildSemiJoinGatherMap(gatherMapData);
return buildSingleJoinGatherMap(gatherMapData);
}

/**
Expand Down Expand Up @@ -3634,7 +3664,7 @@ public GatherMap conditionalLeftSemiJoinGatherMap(Table rightTable,
long[] gatherMapData =
conditionalLeftSemiJoinGatherMap(getNativeView(), rightTable.getNativeView(),
condition.getNativeHandle());
return buildSemiJoinGatherMap(gatherMapData);
return buildSingleJoinGatherMap(gatherMapData);
}

/**
Expand All @@ -3659,7 +3689,7 @@ public GatherMap conditionalLeftSemiJoinGatherMap(Table rightTable,
long[] gatherMapData =
conditionalLeftSemiJoinGatherMapWithCount(getNativeView(), rightTable.getNativeView(),
condition.getNativeHandle(), outputRowCount);
return buildSemiJoinGatherMap(gatherMapData);
return buildSingleJoinGatherMap(gatherMapData);
}

/**
Expand Down Expand Up @@ -3716,7 +3746,7 @@ public static GatherMap mixedLeftSemiJoinGatherMap(Table leftKeys, Table rightKe
leftConditional.getNativeView(), rightConditional.getNativeView(),
condition.getNativeHandle(),
nullEquality == NullEquality.EQUAL);
return buildSemiJoinGatherMap(gatherMapData);
return buildSingleJoinGatherMap(gatherMapData);
}

/**
Expand Down Expand Up @@ -3752,7 +3782,7 @@ public static GatherMap mixedLeftSemiJoinGatherMap(Table leftKeys, Table rightKe
condition.getNativeHandle(),
nullEquality == NullEquality.EQUAL,
joinSize.getOutputRowCount(), joinSize.getMatches().getNativeView());
return buildSemiJoinGatherMap(gatherMapData);
return buildSingleJoinGatherMap(gatherMapData);
}

/**
Expand All @@ -3773,7 +3803,7 @@ public GatherMap leftAntiJoinGatherMap(Table rightKeys, boolean compareNullsEqua
}
long[] gatherMapData =
leftAntiJoinGatherMap(getNativeView(), rightKeys.getNativeView(), compareNullsEqual);
return buildSemiJoinGatherMap(gatherMapData);
return buildSingleJoinGatherMap(gatherMapData);
}

/**
Expand Down Expand Up @@ -3806,7 +3836,7 @@ public GatherMap conditionalLeftAntiJoinGatherMap(Table rightTable,
long[] gatherMapData =
conditionalLeftAntiJoinGatherMap(getNativeView(), rightTable.getNativeView(),
condition.getNativeHandle());
return buildSemiJoinGatherMap(gatherMapData);
return buildSingleJoinGatherMap(gatherMapData);
}

/**
Expand All @@ -3831,7 +3861,7 @@ public GatherMap conditionalLeftAntiJoinGatherMap(Table rightTable,
long[] gatherMapData =
conditionalLeftAntiJoinGatherMapWithCount(getNativeView(), rightTable.getNativeView(),
condition.getNativeHandle(), outputRowCount);
return buildSemiJoinGatherMap(gatherMapData);
return buildSingleJoinGatherMap(gatherMapData);
}

/**
Expand Down Expand Up @@ -3888,7 +3918,7 @@ public static GatherMap mixedLeftAntiJoinGatherMap(Table leftKeys, Table rightKe
leftConditional.getNativeView(), rightConditional.getNativeView(),
condition.getNativeHandle(),
nullEquality == NullEquality.EQUAL);
return buildSemiJoinGatherMap(gatherMapData);
return buildSingleJoinGatherMap(gatherMapData);
}

/**
Expand Down Expand Up @@ -3924,7 +3954,7 @@ public static GatherMap mixedLeftAntiJoinGatherMap(Table leftKeys, Table rightKe
condition.getNativeHandle(),
nullEquality == NullEquality.EQUAL,
joinSize.getOutputRowCount(), joinSize.getMatches().getNativeView());
return buildSemiJoinGatherMap(gatherMapData);
return buildSingleJoinGatherMap(gatherMapData);
}

/**
Expand Down
18 changes: 18 additions & 0 deletions java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2434,6 +2434,24 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftJoinGatherMaps(
});
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftDistinctJoinGatherMap(
    JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) {
  // JNI binding for Table.leftDistinctJoinGatherMap: produces a single gather map from a
  // distinct hash join constructed with the right keys first and the left keys second.
  auto const compute_gather_map = [](cudf::table_view const &left, cudf::table_view const &right,
                                     cudf::null_equality nulleq) {
    // Nullability is decided from both inputs, including nulls nested inside child columns.
    bool const any_nulls = cudf::has_nested_nulls(left) || cudf::has_nested_nulls(right);
    auto const nullable = any_nulls ? cudf::nullable_join::YES : cudf::nullable_join::NO;

    // The nested-ness template argument is selected from the right table only.
    if (!cudf::detail::has_nested_columns(right)) {
      cudf::distinct_hash_join<cudf::has_nested::NO> hash(right, left, nullable, nulleq);
      return hash.left_join();
    }
    cudf::distinct_hash_join<cudf::has_nested::YES> hash(right, left, nullable, nulleq);
    return hash.left_join();
  };

  return cudf::jni::join_gather_single_map(env, j_left_keys, j_right_keys, compare_nulls_equal,
                                           compute_gather_map);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_leftJoinRowCount(JNIEnv *env, jclass,
jlong j_left_table,
jlong j_right_hash_join) {
Expand Down
101 changes: 101 additions & 0 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1734,6 +1734,107 @@ void testLeftJoinGatherMapsNulls() {
}
}

/**
 * Runs a left distinct join of {@code leftKeys} against {@code rightKeys} and asserts that the
 * returned gather map has the same row count and contents as {@code expected}.
 */
private void checkLeftDistinctJoin(Table leftKeys, Table rightKeys, ColumnView expected,
                                   boolean compareNullsEqual) {
  final int expectedRows = (int) expected.getRowCount();
  try (GatherMap gatherMap = leftKeys.leftDistinctJoinGatherMap(rightKeys, compareNullsEqual)) {
    assertEquals(expectedRows, gatherMap.getRowCount());
    try (ColumnView actual = gatherMap.toColumnView(0, expectedRows)) {
      assertColumnsAreEqual(expected, actual);
    }
  }
}

@Test
void testLeftDistinctJoinGatherMaps() {
  // Expected gather map entry for left rows whose key is absent from the right table.
  final int noMatch = Integer.MIN_VALUE;
  try (Table left = new Table.TestBuilder()
           .column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8, 6)
           .build();
       Table right = new Table.TestBuilder()
           .column(6, 5, 9, 8, 10, 32)
           .build();
       ColumnVector expectedMap = ColumnVector.fromInts(
           noMatch, noMatch, 2, noMatch, noMatch, noMatch, noMatch, 0, 1, 3, 0)) {
    checkLeftDistinctJoin(left, right, expectedMap, false);
  }
}

@Test
void testLeftDistinctJoinGatherMapsWithNested() {
  // Expected gather map entry for left rows whose struct key is absent from the right table.
  final int noMatch = Integer.MIN_VALUE;
  // Non-nullable struct of (STRING, INT32) used as the single join key column.
  StructType keyType = new StructType(false,
      new BasicType(false, DType.STRING),
      new BasicType(false, DType.INT32));
  StructData[] leftRows = {
      new StructData("abc", 1),
      new StructData("xyz", 1),
      new StructData("abc", 2),
      new StructData("xyz", 2),
      new StructData("abc", 1),
      new StructData("abc", 3),
      new StructData("xyz", 3)
  };
  StructData[] rightRows = {
      new StructData("abc", 1),
      new StructData("xyz", 4),
      new StructData("xyz", 2),
      new StructData("abc", -1),
  };
  try (Table left = new Table.TestBuilder().column(keyType, leftRows).build();
       Table right = new Table.TestBuilder().column(keyType, rightRows).build();
       ColumnVector expectedMap = ColumnVector.fromInts(
           0, noMatch, noMatch, 2, 0, noMatch, noMatch)) {
    checkLeftDistinctJoin(left, right, expectedMap, false);
  }
}

@Test
void testLeftDistinctJoinGatherMapsNullsEqual() {
  // Expected gather map entry for left rows whose key is absent from the right table.
  final int noMatch = Integer.MIN_VALUE;
  try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, null, null, 8).build();
       Table right = new Table.TestBuilder().column(null, 9, 8, 10, 32).build();
       // With nulls comparing equal, both null left keys map to right row 0.
       ColumnVector expectedMap = ColumnVector.fromInts(
           noMatch, noMatch, 1, noMatch, noMatch, noMatch, noMatch, 0, 0, 2)) {
    checkLeftDistinctJoin(left, right, expectedMap, true);
  }
}

@Test
void testLeftDistinctJoinGatherMapsWithNestedNullsEqual() {
  // Expected gather map entry for left rows whose struct key is absent from the right table.
  final int noMatch = Integer.MIN_VALUE;
  // Nullable struct of nullable (STRING, INT32) used as the single join key column.
  StructType keyType = new StructType(true,
      new BasicType(true, DType.STRING),
      new BasicType(true, DType.INT32));
  StructData[] leftRows = {
      new StructData("abc", 1),
      null,
      new StructData("xyz", 1),
      new StructData("abc", 2),
      new StructData("xyz", null),
      null,
      new StructData("abc", 1),
      new StructData("abc", 3),
      new StructData("xyz", 3),
      new StructData(null, null),
      new StructData(null, 1)
  };
  StructData[] rightRows = {
      null,
      new StructData("abc", 1),
      new StructData("xyz", 4),
      new StructData("xyz", 2),
      new StructData(null, null),
      new StructData(null, 2),
      new StructData(null, 1),
      new StructData("xyz", null),
      new StructData("abc", null),
      new StructData("abc", -1)
  };
  try (Table left = new Table.TestBuilder().column(keyType, leftRows).build();
       Table right = new Table.TestBuilder().column(keyType, rightRows).build();
       ColumnVector expectedMap = ColumnVector.fromInts(
           1, 0, noMatch, noMatch, 7, 0, 1, noMatch, noMatch, 4, 6)) {
    checkLeftDistinctJoin(left, right, expectedMap, true);
  }
}

@Test
void testLeftHashJoinGatherMaps() {
final int inv = Integer.MIN_VALUE;
Expand Down

0 comments on commit d824fa5

Please sign in to comment.