Skip to content

Commit

Permalink
Support octal and hex within regex character class pattern (#11112)
Browse files Browse the repository at this point in the history
Closes #11109 

Adds support for `\0` octal and `\x` hex patterns within a regex character class `[ ]` pattern. Refactors the existing octal and hex parsing for non-class expressions so it can be reused when building a character class instruction. The refactored code was also simplified.
This change fixes the linked issue by supporting `\0` and `\x00` in the expression which identify embedded null characters.
Additional gtests were added to check for octal and hex within an `[ ]` expression.

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)
  - Bradley Dice (https://github.com/bdice)

URL: #11112
  • Loading branch information
davidwendt authored Jul 7, 2022
1 parent 8426a99 commit 008f1d5
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 94 deletions.
201 changes: 110 additions & 91 deletions cpp/src/strings/regex/regcomp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,42 @@ class regex_parser {
std::vector<Item> _items;
bool _has_counted{false};

/**
 * @brief Converts an octal escape sequence into the character it encodes
 *
 * The leading digit is supplied in `in_chr`; it is not read from the
 * expression. At most two more octal digits (`0`-`7`) are then taken from
 * the current expression position, which is advanced past each digit
 * that is consumed.
 *
 * @param in_chr The first character of the octal pattern
 * @return The resulting character
 */
char32_t handle_octal(char32_t in_chr)
{
  constexpr auto max_digits = 3;  // the leading digit counts as the first one

  auto value = in_chr - '0';
  for (auto count = 1; count < max_digits; ++count) {
    auto const next     = *_expr_ptr;
    auto const is_octal = (next >= '0') && (next <= '7');
    if (!is_octal) { break; }
    // shift the accumulated value by one octal digit and append the new one
    value = (value << 3) | (next - '0');
    ++_expr_ptr;
  }
  return value;
}

/**
* @brief Parses 2 hex characters at the current expression position
* to return the represented character
*
* @return The resulting character
*/
char32_t handle_hex()
{
std::string hex(1, static_cast<char>(*_expr_ptr++));
hex.append(1, static_cast<char>(*_expr_ptr++));
return static_cast<char32_t>(std::stol(hex, nullptr, 16)); // 16 = hex
}

/**
* @brief Returns the next character in the expression
*
Expand Down Expand Up @@ -239,6 +275,14 @@ class regex_parser {
case 'a': chr = 0x07; break;
case 'b': chr = 0x08; break;
case 'f': chr = 0x0C; break;
case '0' ... '7': {
chr = handle_octal(chr);
break;
}
case 'x': {
chr = handle_hex();
break;
}
case 'w':
builtins |= cclass_w.builtins;
std::tie(is_quoted, chr) = next_char();
Expand Down Expand Up @@ -313,101 +357,76 @@ class regex_parser {

auto [is_quoted, chr] = next_char();
if (is_quoted) {
// treating all quoted numbers as Octal, since we are not supporting backreferences
if (chr >= '0' && chr <= '7') {
chr = chr - '0';
auto c = *_expr_ptr;
auto digits = 1;
while (c >= '0' && c <= '7' && digits < 3) {
chr = (chr << 3) | (c - '0');
c = *(++_expr_ptr);
++digits;
switch (chr) {
case 't': chr = '\t'; break;
case 'n': chr = '\n'; break;
case 'r': chr = '\r'; break;
case 'a': chr = 0x07; break;
case 'f': chr = 0x0C; break;
case '0' ... '7': {
chr = handle_octal(chr);
break;
}
_chr = chr;
return CHAR;
} else {
switch (chr) {
case 't': chr = '\t'; break;
case 'n': chr = '\n'; break;
case 'r': chr = '\r'; break;
case 'a': chr = 0x07; break;
case 'f': chr = 0x0C; break;
case '0': chr = 0; break;
case 'x': {
char32_t a = *_expr_ptr++;
char32_t b = *_expr_ptr++;
chr = 0;
if (a >= '0' && a <= '9')
chr += (a - '0') << 4;
else if (a >= 'a' && a <= 'f')
chr += (a - 'a' + 10) << 4;
else if (a >= 'A' && a <= 'F')
chr += (a - 'A' + 10) << 4;
if (b >= '0' && b <= '9')
chr += b - '0';
else if (b >= 'a' && b <= 'f')
chr += b - 'a' + 10;
else if (b >= 'A' && b <= 'F')
chr += b - 'A' + 10;
break;
}
case 'w': {
if (_id_cclass_w < 0) { _id_cclass_w = _prog.add_class(cclass_w); }
_cclass_id = _id_cclass_w;
return CCLASS;
}
case 'W': {
if (_id_cclass_W < 0) {
reclass cls = cclass_w;
cls.literals.push_back({'\n', '\n'});
_id_cclass_W = _prog.add_class(cls);
}
_cclass_id = _id_cclass_W;
return NCCLASS;
}
case 's': {
if (_id_cclass_s < 0) { _id_cclass_s = _prog.add_class(cclass_s); }
_cclass_id = _id_cclass_s;
return CCLASS;
}
case 'S': {
if (_id_cclass_s < 0) { _id_cclass_s = _prog.add_class(cclass_s); }
_cclass_id = _id_cclass_s;
return NCCLASS;
}
case 'd': {
if (_id_cclass_d < 0) { _id_cclass_d = _prog.add_class(cclass_d); }
_cclass_id = _id_cclass_d;
return CCLASS;
case 'x': {
chr = handle_hex();
break;
}
case 'w': {
if (_id_cclass_w < 0) { _id_cclass_w = _prog.add_class(cclass_w); }
_cclass_id = _id_cclass_w;
return CCLASS;
}
case 'W': {
if (_id_cclass_W < 0) {
reclass cls = cclass_w;
cls.literals.push_back({'\n', '\n'});
_id_cclass_W = _prog.add_class(cls);
}
case 'D': {
if (_id_cclass_D < 0) {
reclass cls = cclass_d;
cls.literals.push_back({'\n', '\n'});
_id_cclass_D = _prog.add_class(cls);
}
_cclass_id = _id_cclass_D;
return NCCLASS;
_cclass_id = _id_cclass_W;
return NCCLASS;
}
case 's': {
if (_id_cclass_s < 0) { _id_cclass_s = _prog.add_class(cclass_s); }
_cclass_id = _id_cclass_s;
return CCLASS;
}
case 'S': {
if (_id_cclass_s < 0) { _id_cclass_s = _prog.add_class(cclass_s); }
_cclass_id = _id_cclass_s;
return NCCLASS;
}
case 'd': {
if (_id_cclass_d < 0) { _id_cclass_d = _prog.add_class(cclass_d); }
_cclass_id = _id_cclass_d;
return CCLASS;
}
case 'D': {
if (_id_cclass_D < 0) {
reclass cls = cclass_d;
cls.literals.push_back({'\n', '\n'});
_id_cclass_D = _prog.add_class(cls);
}
case 'b': return BOW;
case 'B': return NBOW;
case 'A': return BOL;
case 'Z': return EOL;
default: {
// let valid escapable chars fall through as literal CHAR
if (chr && (std::find(escapable_chars.begin(),
escapable_chars.end(),
static_cast<char>(chr)) != escapable_chars.end())) {
break;
}
// anything else is a bad escape so throw an error
CUDF_FAIL("invalid regex pattern: bad escape character at position " +
std::to_string(_expr_ptr - _pattern_begin - 1));
_cclass_id = _id_cclass_D;
return NCCLASS;
}
case 'b': return BOW;
case 'B': return NBOW;
case 'A': return BOL;
case 'Z': return EOL;
default: {
// let valid escapable chars fall through as literal CHAR
if (chr &&
(std::find(escapable_chars.begin(), escapable_chars.end(), static_cast<char>(chr)) !=
escapable_chars.end())) {
break;
}
} // end-switch
_chr = chr;
return CHAR;
}
// anything else is a bad escape so throw an error
CUDF_FAIL("invalid regex pattern: bad escape character at position " +
std::to_string(_expr_ptr - _pattern_begin - 1));
}
} // end-switch
_chr = chr;
return CHAR;
}

// handle regex characters
Expand Down
24 changes: 21 additions & 3 deletions cpp/tests/strings/contains_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,15 +244,21 @@ TEST_F(StringsContainsTests, MatchesIPV4Test)

// Checks that backslash-octal escapes in a regex pattern match the character
// they encode, both as plain expressions and inside `[ ]` character classes.
// NOTE(review): `strings`, `expected` and `results` each appear to be declared
// twice below -- this looks like pre- and post-change lines of the diff view
// shown together; confirm against the actual test file.
TEST_F(StringsContainsTests, OctalTest)
{
cudf::test::strings_column_wrapper strings({"A3", "B", "CDA3EY", ""});
cudf::test::strings_column_wrapper strings({"A3", "B", "CDA3EY", "", "99", "\a\t\r"});
auto strings_view = cudf::strings_column_view(strings);
cudf::test::fixed_width_column_wrapper<bool> expected({1, 0, 1, 0});
auto results = cudf::strings::contains_re(strings_view, "\\101");
auto expected = cudf::test::fixed_width_column_wrapper<bool>({1, 0, 1, 0, 0, 0});
// octal 101 is 65, the letter 'A'
auto results = cudf::strings::contains_re(strings_view, "\\101");
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected);
// at most 3 octal digits are consumed, so this is 'A' followed by a literal '3'
results = cudf::strings::contains_re(strings_view, "\\1013");
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected);
// octal 063 is 51, the digit '3'; `D*` may match zero characters
results = cudf::strings::contains_re(strings_view, "D*\\101\\063");
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected);
// '9' is not an octal digit, so this is octal 71 (57, the digit '9') followed by a literal '9'
results = cudf::strings::contains_re(strings_view, "\\719");
expected = cudf::test::fixed_width_column_wrapper<bool>({0, 0, 0, 0, 1, 0});
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected);
// octal escapes inside character classes: 7 is '\a', 11 is '\t' (9), 15 is '\r' (13)
results = cudf::strings::contains_re(strings_view, "[\\7][\\11][\\15]");
expected = cudf::test::fixed_width_column_wrapper<bool>({0, 0, 0, 0, 0, 1});
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected);
}

TEST_F(StringsContainsTests, HexTest)
Expand All @@ -279,6 +285,10 @@ TEST_F(StringsContainsTests, HexTest)
0, [ch](auto idx) { return ch == static_cast<char>(idx); });
cudf::test::fixed_width_column_wrapper<bool> expected(true_dat, true_dat + count);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected);
// also test hex character appearing in character class brackets
pattern = "[" + pattern + "]";
results = cudf::strings::contains_re(strings_view, pattern);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected);
}
}

Expand All @@ -304,6 +314,14 @@ TEST_F(StringsContainsTests, EmbeddedNullCharacter)
results = cudf::strings::contains_re(strings_view, "J\\0B");
expected = cudf::test::fixed_width_column_wrapper<bool>({0, 0, 0, 0, 0, 0, 0, 0, 0, 1});
CUDF_TEST_EXPECT_COLUMNS_EQUAL(results->view(), expected);

results = cudf::strings::contains_re(strings_view, "[G-J][\\0]B");
expected = cudf::test::fixed_width_column_wrapper<bool>({0, 0, 0, 0, 0, 0, 1, 1, 1, 1});
CUDF_TEST_EXPECT_COLUMNS_EQUAL(results->view(), expected);

results = cudf::strings::contains_re(strings_view, "[A-D][\\x00]B");
expected = cudf::test::fixed_width_column_wrapper<bool>({1, 1, 1, 1, 0, 0, 0, 0, 0, 0});
CUDF_TEST_EXPECT_COLUMNS_EQUAL(results->view(), expected);
}

TEST_F(StringsContainsTests, Errors)
Expand Down

0 comments on commit 008f1d5

Please sign in to comment.