Java bindings for Fixed-point type support for Parquet (#7153)
Adds Java support for writing the fixed-point (decimal) type to Parquet.

Authors:
  - Raza Jafri (@razajafri)

Approvers:
  - Karthikeyan (@karthikeyann)
  - Jason Lowe (@jlowe)

URL: #7153
razajafri authored Jan 21, 2021
1 parent 6390498 commit 4111cb7
Showing 5 changed files with 106 additions and 14 deletions.
14 changes: 13 additions & 1 deletion cpp/include/cudf/io/parquet.hpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -846,6 +846,18 @@ class chunked_parquet_writer_options_builder {
return *this;
}

/**
* @brief Sets decimal precision data.
*
* @param v Flattened vector of precisions, with exactly one entry per decimal column.
*/
chunked_parquet_writer_options_builder& decimal_precision(std::vector<uint8_t> const& v)
{
options._decimal_precision = v;
return *this;
}

/**
* @brief Sets compression type to chunked_parquet_writer_options.
*
38 changes: 37 additions & 1 deletion java/src/main/java/ai/rapids/cudf/ParquetWriterOptions.java
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2019, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -42,6 +42,7 @@ public enum StatisticsFrequency {
public static class Builder extends CMWriterBuilder<Builder> {
private StatisticsFrequency statsGranularity = StatisticsFrequency.ROWGROUP;
private boolean isTimestampTypeInt96 = false;
private int[] precisionValues = null;

public Builder withStatisticsFrequency(StatisticsFrequency statsGranularity) {
this.statsGranularity = statsGranularity;
@@ -56,6 +57,30 @@ public Builder withTimestampInt96(boolean int96) {
return this;
}

/**
* Sets the flattened precision values for all decimal columns expected in this Table. The
* list of precisions should be an in-order traversal of all Decimal columns, including
* nested columns. See the example below.
*
* NOTE: The number of `precisionValues` must equal the number of Decimal columns,
* otherwise a CudfException will be thrown. Also note that the values are overwritten
* every time this method is called.
*
* Example:
* Table0 : c0[type: INT32]
* c1[type: Decimal32(3, 1)]
* c2[type: Struct[col0[type: Decimal(2, 1)],
* col1[type: INT64],
*                          col2[type: Decimal(8, 6)]]]
* c3[type: Decimal64(12, 5)]
*
* The flattened list of precisions for the example above is {3, 2, 8, 12}.
*/
public Builder withPrecisionValues(int... precisionValues) {
this.precisionValues = precisionValues;
return this;
}

public ParquetWriterOptions build() {
return new ParquetWriterOptions(this);
}
@@ -73,12 +98,21 @@ private ParquetWriterOptions(Builder builder) {
super(builder);
this.statsGranularity = builder.statsGranularity;
this.isTimestampTypeInt96 = builder.isTimestampTypeInt96;
this.precisions = builder.precisionValues;
}

public StatisticsFrequency getStatisticsFrequency() {
return statsGranularity;
}

/**
* Returns the flattened list of precisions if set via {@link Builder#withPrecisionValues},
* otherwise null. See that method for the definition of `flattened`.
*/
public int[] getPrecisions() {
return precisions;
}

/**
* Returns true if the writer is expected to write timestamps in INT96
*/
@@ -87,4 +121,6 @@ public boolean isTimestampTypeInt96() {
}

private boolean isTimestampTypeInt96;

private int[] precisions;
}
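To make the flattening rule concrete, here is a minimal sketch of building writer options for the example table in the withPrecisionValues javadoc above (the column names are illustrative, not part of this change):

ParquetWriterOptions options = ParquetWriterOptions.builder()
    .withColumnNames("c0", "c1", "c2", "c3")
    // In-order traversal of the decimal columns, including those nested in the c2
    // struct: c1 -> 3, c2.col0 -> 2, c2.col2 -> 8, c3 -> 12
    .withPrecisionValues(3, 2, 8, 12)
    .build();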
13 changes: 12 additions & 1 deletion java/src/main/java/ai/rapids/cudf/Table.java
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -239,6 +239,9 @@ private static native long[] readParquet(String[] filterColumnNames, String file
* @param metadataValues Metadata values corresponding to metadataKeys
* @param compression native compression codec ID
* @param statsFreq native statistics frequency ID
* @param isInt96 true if timestamp type is int96
* @param precisions flattened list of precisions, with exactly one entry per decimal
*                   column in the table
* @param filename local output path
* @return a handle that is used in later calls to writeParquetChunk and writeParquetEnd.
*/
@@ -248,6 +251,8 @@ private static native long writeParquetFileBegin(String[] columnNames,
String[] metadataValues,
int compression,
int statsFreq,
boolean isInt96,
int[] precisions,
String filename) throws CudfException;

/**
@@ -259,6 +264,8 @@ private static native long writeParquetFileBegin(String[] columnNames,
* @param compression native compression codec ID
* @param statsFreq native statistics frequency ID
* @param isInt96 true if timestamp type is int96
* @param precisions flattened list of precisions, with exactly one entry per decimal
*                   column in the table
* @param consumer consumer of host buffers produced.
* @return a handle that is used in later calls to writeParquetChunk and writeParquetEnd.
*/
@@ -269,6 +276,7 @@ private static native long writeParquetBufferBegin(String[] columnNames,
int compression,
int statsFreq,
boolean isInt96,
int[] precisions,
HostBufferConsumer consumer) throws CudfException;

/**
@@ -778,6 +786,8 @@ private ParquetTableWriter(ParquetWriterOptions options, File outputFile) {
options.getMetadataValues(),
options.getCompressionType().nativeId,
options.getStatisticsFrequency().nativeId,
options.isTimestampTypeInt96(),
options.getPrecisions(),
outputFile.getAbsolutePath());
}

@@ -789,6 +799,7 @@ private ParquetTableWriter(ParquetWriterOptions options, HostBufferConsumer cons
options.getCompressionType().nativeId,
options.getStatisticsFrequency().nativeId,
options.isTimestampTypeInt96(),
options.getPrecisions(),
consumer);
this.consumer = consumer;
}
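Taken together with ParquetWriterOptions, the write path looks roughly like the following minimal sketch, modeled on the tests further down (the Table.TestBuilder data, column names, and output path are illustrative assumptions):

try (Table table = new Table.TestBuilder()
         .column(5, 1, 0, 2, 7)
         .decimal32Column(3, 298, 2473, 2119, 1273, 9879)
         .build()) {
  ParquetWriterOptions options = ParquetWriterOptions.builder()
      .withColumnNames("ints", "decimals")
      .withPrecisionValues(4)  // exactly one entry, for the single decimal column
      .build();
  try (TableWriter writer = Table.writeParquetChunked(options, new File("/tmp/decimals.parquet"))) {
    writer.write(table);
  }
}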
17 changes: 14 additions & 3 deletions java/src/main/native/src/TableJni.cpp
@@ -833,7 +833,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readParquet(
JNIEXPORT long JNICALL Java_ai_rapids_cudf_Table_writeParquetBufferBegin(
JNIEnv *env, jclass, jobjectArray j_col_names, jbooleanArray j_col_nullability,
jobjectArray j_metadata_keys, jobjectArray j_metadata_values, jint j_compression,
jint j_stats_freq, jboolean j_isInt96, jobject consumer) {
jint j_stats_freq, jboolean j_isInt96, jintArray j_precisions, jobject consumer) {
JNI_NULL_CHECK(env, j_col_names, "null columns", 0);
JNI_NULL_CHECK(env, j_col_nullability, "null nullability", 0);
JNI_NULL_CHECK(env, j_metadata_keys, "null metadata keys", 0);
@@ -859,13 +859,18 @@ JNIEXPORT long JNICALL Java_ai_rapids_cudf_Table_writeParquetBufferBegin(
std::unique_ptr<cudf::jni::jni_writer_data_sink> data_sink(
new cudf::jni::jni_writer_data_sink(env, consumer));
sink_info sink{data_sink.get()};
cudf::jni::native_jintArray precisions(env, j_precisions);
std::vector<uint8_t> const v_precisions(
precisions.data(), precisions.data() + precisions.size());
chunked_parquet_writer_options opts =
chunked_parquet_writer_options::builder(sink)
.nullable_metadata(&metadata)
.compression(static_cast<compression_type>(j_compression))
.stats_level(static_cast<statistics_freq>(j_stats_freq))
.int96_timestamps(static_cast<bool>(j_isInt96))
.decimal_precision(v_precisions)
.build();

std::shared_ptr<pq_chunked_state> state = write_parquet_chunked_begin(opts);
cudf::jni::native_parquet_writer_handle *ret =
new cudf::jni::native_parquet_writer_handle(state, data_sink);
@@ -877,7 +882,7 @@ JNIEXPORT long JNICALL Java_ai_rapids_cudf_Table_writeParquetBufferBegin(
JNIEXPORT long JNICALL Java_ai_rapids_cudf_Table_writeParquetFileBegin(
JNIEnv *env, jclass, jobjectArray j_col_names, jbooleanArray j_col_nullability,
jobjectArray j_metadata_keys, jobjectArray j_metadata_values, jint j_compression,
jint j_stats_freq, jstring j_output_path) {
jint j_stats_freq, jboolean j_isInt96, jintArray j_precisions, jstring j_output_path) {
JNI_NULL_CHECK(env, j_col_names, "null columns", 0);
JNI_NULL_CHECK(env, j_col_nullability, "null nullability", 0);
JNI_NULL_CHECK(env, j_metadata_keys, "null metadata keys", 0);
@@ -900,14 +905,20 @@ JNIEXPORT long JNICALL Java_ai_rapids_cudf_Table_writeParquetFileBegin(
for (size_t i = 0; i < meta_keys.size(); ++i) {
metadata.user_data[meta_keys[i].get()] = meta_values[i].get();
}

cudf::jni::native_jintArray precisions(env, j_precisions);
std::vector<uint8_t> v_precisions(
precisions.data(), precisions.data() + precisions.size());

sink_info sink{output_path.get()};
chunked_parquet_writer_options opts =
chunked_parquet_writer_options::builder(sink)
.nullable_metadata(&metadata)
.compression(static_cast<compression_type>(j_compression))
.stats_level(static_cast<statistics_freq>(j_stats_freq))
.int96_timestamps(static_cast<bool>(j_isInt96))
.decimal_precision(v_precisions)
.build();

std::shared_ptr<pq_chunked_state> state = write_parquet_chunked_begin(opts);
cudf::jni::native_parquet_writer_handle *ret =
new cudf::jni::native_parquet_writer_handle(state);
38 changes: 30 additions & 8 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -3949,6 +3949,20 @@ private Table getExpectedFileTable() {
.build();
}

private Table getExpectedFileTableWithDecimals() {
return new TestBuilder()
.column(true, false, false, true, false)
.column(5, 1, 0, 2, 7)
.column(new Byte[]{2, 3, 4, 5, 9})
.column(3L, 9L, 4L, 2L, 20L)
.column("this", "is", "a", "test", "string")
.column(1.0f, 3.5f, 5.9f, 7.1f, 9.8f)
.column(5.0d, 9.5d, 0.9d, 7.23d, 2.8d)
.decimal32Column(3, 298, 2473, 2119, 1273, 9879)
.decimal64Column(4, 398L, 1322L, 983237L, 99872L, 21337L)
.build();
}

@Test
void testParquetWriteToFileNoNames() throws IOException {
File tempFile = File.createTempFile("test-nonames", ".parquet");
@@ -4008,9 +4022,12 @@ public long readInto(HostMemoryBuffer buffer, long len) {

@Test
void testParquetWriteToBufferChunkedInt96() {
try (Table table0 = getExpectedFileTable();
try (Table table0 = getExpectedFileTableWithDecimals();
MyBufferConsumer consumer = new MyBufferConsumer()) {
ParquetWriterOptions options = ParquetWriterOptions.builder().withTimestampInt96(true).build();
ParquetWriterOptions options = ParquetWriterOptions.builder()
.withTimestampInt96(true)
.withPrecisionValues(5, 5)
.build();

try (TableWriter writer = Table.writeParquetChunked(options, consumer)) {
writer.write(table0);
@@ -4043,11 +4060,13 @@ void testParquetWriteToBufferChunked() {
@Test
void testParquetWriteToFileWithNames() throws IOException {
File tempFile = File.createTempFile("test-names", ".parquet");
try (Table table0 = getExpectedFileTable()) {
try (Table table0 = getExpectedFileTableWithDecimals()) {
ParquetWriterOptions options = ParquetWriterOptions.builder()
.withColumnNames("first", "second", "third", "fourth", "fifth", "sixth", "seventh")
.withColumnNames("first", "second", "third", "fourth", "fifth", "sixth", "seventh",
"eighth", "nineth")
.withCompressionType(CompressionType.NONE)
.withStatisticsFrequency(ParquetWriterOptions.StatisticsFrequency.NONE)
.withPrecisionValues(5, 6)
.build();
try (TableWriter writer = Table.writeParquetChunked(options, tempFile.getAbsoluteFile())) {
writer.write(table0);
@@ -4063,12 +4082,14 @@ void testParquetWriteToFileWithNames() throws IOException {
@Test
void testParquetWriteToFileWithNamesAndMetadata() throws IOException {
File tempFile = File.createTempFile("test-names-metadata", ".parquet");
try (Table table0 = getExpectedFileTable()) {
try (Table table0 = getExpectedFileTableWithDecimals()) {
ParquetWriterOptions options = ParquetWriterOptions.builder()
.withColumnNames("first", "second", "third", "fourth", "fifth", "sixth", "seventh")
.withColumnNames("first", "second", "third", "fourth", "fifth", "sixth", "seventh",
"eighth", "nineth")
.withMetadata("somekey", "somevalue")
.withCompressionType(CompressionType.NONE)
.withStatisticsFrequency(ParquetWriterOptions.StatisticsFrequency.NONE)
.withPrecisionValues(6, 8)
.build();
try (TableWriter writer = Table.writeParquetChunked(options, tempFile.getAbsoluteFile())) {
writer.write(table0);
@@ -4084,10 +4105,11 @@ void testParquetWriteToFileWithNamesAndMetadata() throws IOException {
@Test
void testParquetWriteToFileUncompressedNoStats() throws IOException {
File tempFile = File.createTempFile("test-uncompressed", ".parquet");
try (Table table0 = getExpectedFileTable()) {
try (Table table0 = getExpectedFileTableWithDecimals()) {
ParquetWriterOptions options = ParquetWriterOptions.builder()
.withCompressionType(CompressionType.NONE)
.withStatisticsFrequency(ParquetWriterOptions.StatisticsFrequency.NONE)
.withPrecisionValues(4, 6)
.build();
try (TableWriter writer = Table.writeParquetChunked(options, tempFile.getAbsoluteFile())) {
writer.write(table0);
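The updated tests exercise the write path only; a possible round-trip check, assuming decimal read support in Table.readParquet and the assertTablesAreEqual helper already used in this test class, might look like:

try (Table expected = getExpectedFileTableWithDecimals();
     Table found = Table.readParquet(tempFile)) {
  assertTablesAreEqual(expected, found);
}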
