From 008f1d5e0af5400d44b06b89b2837a06689707c6 Mon Sep 17 00:00:00 2001 From: David Wendt <45795991+davidwendt@users.noreply.github.com> Date: Thu, 7 Jul 2022 09:45:55 -0400 Subject: [PATCH] Support octal and hex within regex character class pattern (#11112) Closes #11109 Adds support for `\0` octal and `\x` hex patterns within a regex character class `[ ]` pattern. Refactors the existing octal and hex parsing in non-class expressions so it can be reused when building a character class instruction. The refactored code was also simplified. This change fixes the linked issue by supporting `\0` and `\x00` in the expression, which identify embedded null characters. Additional gtests were added to check for octal and hex within an `[ ]` expression. Authors: - David Wendt (https://github.com/davidwendt) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/cudf/pull/11112 --- cpp/src/strings/regex/regcomp.cpp | 201 +++++++++++++++------------ cpp/tests/strings/contains_tests.cpp | 24 +++- 2 files changed, 131 insertions(+), 94 deletions(-) diff --git a/cpp/src/strings/regex/regcomp.cpp b/cpp/src/strings/regex/regcomp.cpp index 992d66a5ff4..50d641c9a74 100644 --- a/cpp/src/strings/regex/regcomp.cpp +++ b/cpp/src/strings/regex/regcomp.cpp @@ -188,6 +188,42 @@ class regex_parser { std::vector _items; bool _has_counted{false}; + /** + * @brief Parses octal characters at the current expression position + * to return the represented character + * + * Reads up to 3 octal digits. The first digit should be passed + * in `in_chr`. 
+ * + * @param in_chr The first character of the octal pattern + * @return The resulting character + */ + char32_t handle_octal(char32_t in_chr) + { + auto out_chr = in_chr - '0'; + auto c = *_expr_ptr; + auto digits = 1; + while ((c >= '0') && (c <= '7') && (digits < 3)) { + out_chr = (out_chr * 8) | (c - '0'); + c = *(++_expr_ptr); + ++digits; + } + return out_chr; + } + + /** + * @brief Parses 2 hex characters at the current expression position + * to return the represented character + * + * @return The resulting character + */ + char32_t handle_hex() + { + std::string hex(1, static_cast(*_expr_ptr++)); + hex.append(1, static_cast(*_expr_ptr++)); + return static_cast(std::stol(hex, nullptr, 16)); // 16 = hex + } + /** * @brief Returns the next character in the expression * @@ -239,6 +275,14 @@ class regex_parser { case 'a': chr = 0x07; break; case 'b': chr = 0x08; break; case 'f': chr = 0x0C; break; + case '0' ... '7': { + chr = handle_octal(chr); + break; + } + case 'x': { + chr = handle_hex(); + break; + } case 'w': builtins |= cclass_w.builtins; std::tie(is_quoted, chr) = next_char(); @@ -313,101 +357,76 @@ class regex_parser { auto [is_quoted, chr] = next_char(); if (is_quoted) { - // treating all quoted numbers as Octal, since we are not supporting backreferences - if (chr >= '0' && chr <= '7') { - chr = chr - '0'; - auto c = *_expr_ptr; - auto digits = 1; - while (c >= '0' && c <= '7' && digits < 3) { - chr = (chr << 3) | (c - '0'); - c = *(++_expr_ptr); - ++digits; + switch (chr) { + case 't': chr = '\t'; break; + case 'n': chr = '\n'; break; + case 'r': chr = '\r'; break; + case 'a': chr = 0x07; break; + case 'f': chr = 0x0C; break; + case '0' ... 
'7': { + chr = handle_octal(chr); + break; } - _chr = chr; - return CHAR; - } else { - switch (chr) { - case 't': chr = '\t'; break; - case 'n': chr = '\n'; break; - case 'r': chr = '\r'; break; - case 'a': chr = 0x07; break; - case 'f': chr = 0x0C; break; - case '0': chr = 0; break; - case 'x': { - char32_t a = *_expr_ptr++; - char32_t b = *_expr_ptr++; - chr = 0; - if (a >= '0' && a <= '9') - chr += (a - '0') << 4; - else if (a >= 'a' && a <= 'f') - chr += (a - 'a' + 10) << 4; - else if (a >= 'A' && a <= 'F') - chr += (a - 'A' + 10) << 4; - if (b >= '0' && b <= '9') - chr += b - '0'; - else if (b >= 'a' && b <= 'f') - chr += b - 'a' + 10; - else if (b >= 'A' && b <= 'F') - chr += b - 'A' + 10; - break; - } - case 'w': { - if (_id_cclass_w < 0) { _id_cclass_w = _prog.add_class(cclass_w); } - _cclass_id = _id_cclass_w; - return CCLASS; - } - case 'W': { - if (_id_cclass_W < 0) { - reclass cls = cclass_w; - cls.literals.push_back({'\n', '\n'}); - _id_cclass_W = _prog.add_class(cls); - } - _cclass_id = _id_cclass_W; - return NCCLASS; - } - case 's': { - if (_id_cclass_s < 0) { _id_cclass_s = _prog.add_class(cclass_s); } - _cclass_id = _id_cclass_s; - return CCLASS; - } - case 'S': { - if (_id_cclass_s < 0) { _id_cclass_s = _prog.add_class(cclass_s); } - _cclass_id = _id_cclass_s; - return NCCLASS; - } - case 'd': { - if (_id_cclass_d < 0) { _id_cclass_d = _prog.add_class(cclass_d); } - _cclass_id = _id_cclass_d; - return CCLASS; + case 'x': { + chr = handle_hex(); + break; + } + case 'w': { + if (_id_cclass_w < 0) { _id_cclass_w = _prog.add_class(cclass_w); } + _cclass_id = _id_cclass_w; + return CCLASS; + } + case 'W': { + if (_id_cclass_W < 0) { + reclass cls = cclass_w; + cls.literals.push_back({'\n', '\n'}); + _id_cclass_W = _prog.add_class(cls); } - case 'D': { - if (_id_cclass_D < 0) { - reclass cls = cclass_d; - cls.literals.push_back({'\n', '\n'}); - _id_cclass_D = _prog.add_class(cls); - } - _cclass_id = _id_cclass_D; - return NCCLASS; + _cclass_id = 
_id_cclass_W; + return NCCLASS; + } + case 's': { + if (_id_cclass_s < 0) { _id_cclass_s = _prog.add_class(cclass_s); } + _cclass_id = _id_cclass_s; + return CCLASS; + } + case 'S': { + if (_id_cclass_s < 0) { _id_cclass_s = _prog.add_class(cclass_s); } + _cclass_id = _id_cclass_s; + return NCCLASS; + } + case 'd': { + if (_id_cclass_d < 0) { _id_cclass_d = _prog.add_class(cclass_d); } + _cclass_id = _id_cclass_d; + return CCLASS; + } + case 'D': { + if (_id_cclass_D < 0) { + reclass cls = cclass_d; + cls.literals.push_back({'\n', '\n'}); + _id_cclass_D = _prog.add_class(cls); } - case 'b': return BOW; - case 'B': return NBOW; - case 'A': return BOL; - case 'Z': return EOL; - default: { - // let valid escapable chars fall through as literal CHAR - if (chr && (std::find(escapable_chars.begin(), - escapable_chars.end(), - static_cast(chr)) != escapable_chars.end())) { - break; - } - // anything else is a bad escape so throw an error - CUDF_FAIL("invalid regex pattern: bad escape character at position " + - std::to_string(_expr_ptr - _pattern_begin - 1)); + _cclass_id = _id_cclass_D; + return NCCLASS; + } + case 'b': return BOW; + case 'B': return NBOW; + case 'A': return BOL; + case 'Z': return EOL; + default: { + // let valid escapable chars fall through as literal CHAR + if (chr && + (std::find(escapable_chars.begin(), escapable_chars.end(), static_cast(chr)) != + escapable_chars.end())) { + break; } - } // end-switch - _chr = chr; - return CHAR; - } + // anything else is a bad escape so throw an error + CUDF_FAIL("invalid regex pattern: bad escape character at position " + + std::to_string(_expr_ptr - _pattern_begin - 1)); + } + } // end-switch + _chr = chr; + return CHAR; } // handle regex characters diff --git a/cpp/tests/strings/contains_tests.cpp b/cpp/tests/strings/contains_tests.cpp index 21c18977746..70f28aa139d 100644 --- a/cpp/tests/strings/contains_tests.cpp +++ b/cpp/tests/strings/contains_tests.cpp @@ -244,15 +244,21 @@ TEST_F(StringsContainsTests, 
MatchesIPV4Test) TEST_F(StringsContainsTests, OctalTest) { - cudf::test::strings_column_wrapper strings({"A3", "B", "CDA3EY", ""}); + cudf::test::strings_column_wrapper strings({"A3", "B", "CDA3EY", "", "99", "\a\t\r"}); auto strings_view = cudf::strings_column_view(strings); - cudf::test::fixed_width_column_wrapper expected({1, 0, 1, 0}); - auto results = cudf::strings::contains_re(strings_view, "\\101"); + auto expected = cudf::test::fixed_width_column_wrapper({1, 0, 1, 0, 0, 0}); + auto results = cudf::strings::contains_re(strings_view, "\\101"); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); results = cudf::strings::contains_re(strings_view, "\\1013"); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); results = cudf::strings::contains_re(strings_view, "D*\\101\\063"); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + results = cudf::strings::contains_re(strings_view, "\\719"); + expected = cudf::test::fixed_width_column_wrapper({0, 0, 0, 0, 1, 0}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + results = cudf::strings::contains_re(strings_view, "[\\7][\\11][\\15]"); + expected = cudf::test::fixed_width_column_wrapper({0, 0, 0, 0, 0, 1}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); } TEST_F(StringsContainsTests, HexTest) @@ -279,6 +285,10 @@ TEST_F(StringsContainsTests, HexTest) 0, [ch](auto idx) { return ch == static_cast(idx); }); cudf::test::fixed_width_column_wrapper expected(true_dat, true_dat + count); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + // also test hex character appearing in character class brackets + pattern = "[" + pattern + "]"; + results = cudf::strings::contains_re(strings_view, pattern); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); } } @@ -304,6 +314,14 @@ TEST_F(StringsContainsTests, EmbeddedNullCharacter) results = cudf::strings::contains_re(strings_view, "J\\0B"); expected = cudf::test::fixed_width_column_wrapper({0, 0, 0, 0, 0, 0, 0, 0, 0, 1}); 
CUDF_TEST_EXPECT_COLUMNS_EQUAL(results->view(), expected); + + results = cudf::strings::contains_re(strings_view, "[G-J][\\0]B"); + expected = cudf::test::fixed_width_column_wrapper({0, 0, 0, 0, 0, 0, 1, 1, 1, 1}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(results->view(), expected); + + results = cudf::strings::contains_re(strings_view, "[A-D][\\x00]B"); + expected = cudf::test::fixed_width_column_wrapper({1, 1, 1, 1, 0, 0, 0, 0, 0, 0}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(results->view(), expected); } TEST_F(StringsContainsTests, Errors)