Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Fix ORC reader issue with decimal type #6466

Merged
merged 9 commits into from
Oct 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@
- PR #6285 Removed unsafe `reinterpret_cast` and implicit pointer-to-bool casts
- PR #6281 Fix unreachable code warning in datetime.cuh
- PR #6286 Fix `read_csv` `int32` overflow
- PR #6466 Fix ORC reader issue with decimal type
- PR #6310 Replace a misspelled reference to `master` branch with `main` branch in a comment in changelog.sh
- PR #6289 Revert #6206
- PR #6291 Fix issue related to row-wise operations in `cudf.DataFrame`
Expand Down
121 changes: 72 additions & 49 deletions cpp/src/io/orc/stripe_data.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1033,61 +1033,84 @@ static __device__ int Decode_Decimals(orc_bytestream_s *bs,
int col_scale,
int t)
{
if (t == 0) {
uint32_t maxpos = min(bs->len, bs->pos + (BYTESTREAM_BFRSZ - 8u));
uint32_t lastpos = bs->pos;
uint32_t n;
for (n = 0; n < numvals; n++) {
uint32_t pos = lastpos;
*(volatile int32_t *)&vals[n] = lastpos;
pos += varint_length<uint4>(bs, pos);
if (pos > maxpos) break;
lastpos = pos;
uint32_t num_vals_read = 0;
// Iterate until `numvals` values have been read, or until the stream has
// reached its end and nothing more can be read.
while (num_vals_read != numvals) {
if (t == 0) {
uint32_t maxpos = min(bs->len, bs->pos + (BYTESTREAM_BFRSZ - 8u));
uint32_t lastpos = bs->pos;
uint32_t n;
for (n = num_vals_read; n < numvals; n++) {
uint32_t pos = lastpos;
pos += varint_length<uint4>(bs, pos);
if (pos > maxpos) break;
*(volatile int32_t *)&vals[n] = lastpos;
lastpos = pos;
}
scratch->num_vals = n;
bytestream_flush_bytes(bs, lastpos - bs->pos);
}
scratch->num_vals = n;
bytestream_flush_bytes(bs, lastpos - bs->pos);
}
__syncthreads();
numvals = scratch->num_vals;
if (t < numvals) {
int pos = *(volatile int32_t *)&vals[t];
int128_s v = decode_varint128(bs, pos);
__syncthreads();
uint32_t num_vals_to_read = scratch->num_vals;
if (t >= num_vals_read and t < num_vals_to_read) {
int pos = *(volatile int32_t *)&vals[t];
int128_s v = decode_varint128(bs, pos);

if (col_scale & ORC_DECIMAL2FLOAT64_SCALE) {
double f = Int128ToDouble_rn(v.lo, v.hi);
int32_t scale = (t < numvals) ? val_scale : 0;
if (scale >= 0)
reinterpret_cast<volatile double *>(vals)[t] = f / kPow10[min(scale, 39)];
else
reinterpret_cast<volatile double *>(vals)[t] = f * kPow10[min(-scale, 39)];
} else {
int32_t scale = (t < numvals) ? (col_scale & ~ORC_DECIMAL2FLOAT64_SCALE) - val_scale : 0;
if (scale >= 0) {
scale = min(scale, 27);
vals[t] = ((int64_t)v.lo * kPow5i[scale]) << scale;
} else // if (scale < 0)
{
bool is_negative = (v.hi < 0);
uint64_t hi = v.hi, lo = v.lo;
scale = min(-scale, 27);
if (is_negative) {
hi = (~hi) + (lo == 0);
lo = (~lo) + 1;
}
lo = (lo >> (uint32_t)scale) | ((uint64_t)hi << (64 - scale));
hi >>= (int32_t)scale;
if (hi != 0) {
// Use intermediate float
lo = __double2ull_rn(Int128ToDouble_rn(lo, hi) / __ll2double_rn(kPow5i[scale]));
hi = 0;
} else {
lo /= kPow5i[scale];
if (col_scale & ORC_DECIMAL2FLOAT64_SCALE) {
double f = Int128ToDouble_rn(v.lo, v.hi);
int32_t scale = (t < numvals) ? val_scale : 0;
if (scale >= 0)
reinterpret_cast<volatile double *>(vals)[t] = f / kPow10[min(scale, 39)];
else
reinterpret_cast<volatile double *>(vals)[t] = f * kPow10[min(-scale, 39)];
} else {
int32_t scale = (t < numvals) ? (col_scale & ~ORC_DECIMAL2FLOAT64_SCALE) - val_scale : 0;
if (scale >= 0) {
scale = min(scale, 27);
vals[t] = ((int64_t)v.lo * kPow5i[scale]) << scale;
} else // if (scale < 0)
{
bool is_negative = (v.hi < 0);
uint64_t hi = v.hi, lo = v.lo;
scale = min(-scale, 27);
if (is_negative) {
hi = (~hi) + (lo == 0);
lo = (~lo) + 1;
}
lo = (lo >> (uint32_t)scale) | ((uint64_t)hi << (64 - scale));
hi >>= (int32_t)scale;
if (hi != 0) {
// Use intermediate float
lo = __double2ull_rn(Int128ToDouble_rn(lo, hi) / __ll2double_rn(kPow5i[scale]));
hi = 0;
} else {
lo /= kPow5i[scale];
}
vals[t] = (is_negative) ? -(int64_t)lo : (int64_t)lo;
}
vals[t] = (is_negative) ? -(int64_t)lo : (int64_t)lo;
}
}
// There is nothing to read, so break
if (num_vals_read == num_vals_to_read) break;

// Update number of values read (This contains values of previous iteration)
num_vals_read = num_vals_to_read;

// Have to wait till all threads have copied data
__syncthreads();
if (num_vals_read != numvals) {
bytestream_fill(bs, t);
__syncthreads();
if (t == 0) {
rgsl888prabhu marked this conversation as resolved.
Show resolved Hide resolved
// Needs to be reset since bytestream has been filled
bs->fill_count = 0;
}
}
// Adding to get all threads in sync before next read
__syncthreads();
}
return numvals;
return num_vals_read;
}

/**
Expand Down
Binary file not shown.
Binary file not shown.
18 changes: 18 additions & 0 deletions python/cudf/cudf/tests/test_orc.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,3 +344,21 @@ def test_orc_writer_sliced(tmpdir):

df_select.to_orc(cudf_path)
assert_eq(cudf.read_orc(cudf_path), df_select.reset_index(drop=True))


@pytest.mark.parametrize(
    "orc_file",
    [
        "TestOrcFile.decimal.same.values.orc",
        "TestOrcFile.decimal.multiple.values.orc",
    ],
)
def test_orc_reader_decimal_type(datadir, orc_file):
    """Compare cuDF's reading of an ORC file with decimal data against pandas.

    Each parametrized file is loaded once with ``pd.read_orc`` and once with
    ``cudf.read_orc`` (converted to pandas), and the two frames must be equal.
    """
    orc_path = datadir / orc_file
    expected = pd.read_orc(orc_path)
    actual = cudf.read_orc(orc_path).to_pandas()

    # pandas keeps "col8" in decimal form, so both sides are compared through
    # their string representation.
    for frame in (expected, actual):
        frame["col8"] = frame["col8"].astype("str")

    assert_eq(expected, actual)