Skip to content

Commit

Permalink
[SPARK-48441][SQL][FOLLOWUP] Fix StringTrim behaviour for UTF8_LCASE …
Browse files Browse the repository at this point in the history
…collation

### What changes were proposed in this pull request?
Fix how `StringTrim*` expressions handle edge cases with one-to-many case mapping, by updating the lowercase trimming logic that matches multiple characters from the `srcString` to a single character in `trimStr`. These methods now correctly follow the iterative "longest possible match" behaviour (proposed for StringTrim in the original PR).

### Why are the changes needed?
Fix a subtle bug in `trim`-like functions.

Example:
```
trim("ii\u0307", "İi"); -- returns: "\u0307" (wrong), instead of "" (correct)
```

### Does this PR introduce _any_ user-facing change?
Yes, trim* function behaviour is slightly altered for UTF8_LCASE.

### How was this patch tested?
New tests in `CollationSupportSuite`.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #47836 from uros-db/alter-trim-followup.

Authored-by: Uros Bojanic <uros.bojanic@databricks.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
  • Loading branch information
uros-db authored and MaxGekk committed Aug 26, 2024
1 parent e10a789 commit f394cd3
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -990,20 +990,29 @@ public static UTF8String lowercaseTrimLeft(
while (trimIter.hasNext()) trimChars.add(getLowercaseCodePoint(trimIter.next()));

// Iterate over `srcString` from the left to find the first character that is not in the set.
int searchIndex = 0, codePoint;
int searchIndex = 0, codePoint, codePointBuffer = -1;
Iterator<Integer> srcIter = srcString.codePointIterator();
while (srcIter.hasNext()) {
codePoint = getLowercaseCodePoint(srcIter.next());
// Get the next code point from either the buffer or the iterator.
if (codePointBuffer != -1) {
codePoint = codePointBuffer;
codePointBuffer = -1;
}
else {
codePoint = getLowercaseCodePoint(srcIter.next());
}
// Special handling for Turkish dotted uppercase letter I.
if (codePoint == CODE_POINT_LOWERCASE_I && srcIter.hasNext() &&
trimChars.contains(CODE_POINT_COMBINED_LOWERCASE_I_DOT)) {
int nextCodePoint = getLowercaseCodePoint(srcIter.next());
if ((trimChars.contains(codePoint) && trimChars.contains(nextCodePoint))
|| nextCodePoint == CODE_POINT_COMBINING_DOT) {
codePointBuffer = codePoint;
codePoint = getLowercaseCodePoint(srcIter.next());
if (codePoint == CODE_POINT_COMBINING_DOT) {
searchIndex += 2;
}
else {
if (trimChars.contains(codePoint)) ++searchIndex;
codePointBuffer = -1;
} else if (trimChars.contains(codePointBuffer)) {
++searchIndex;
codePointBuffer = codePoint;
} else {
break;
}
} else if (trimChars.contains(codePoint)) {
Expand Down Expand Up @@ -1100,20 +1109,28 @@ public static UTF8String lowercaseTrimRight(
while (trimIter.hasNext()) trimChars.add(getLowercaseCodePoint(trimIter.next()));

// Iterate over `srcString` from the right to find the first character that is not in the set.
int searchIndex = srcString.numChars(), codePoint;
int searchIndex = srcString.numChars(), codePoint, codePointBuffer = -1;
Iterator<Integer> srcIter = srcString.reverseCodePointIterator();
while (srcIter.hasNext()) {
codePoint = getLowercaseCodePoint(srcIter.next());
if (codePointBuffer != -1) {
codePoint = codePointBuffer;
codePointBuffer = -1;
}
else {
codePoint = getLowercaseCodePoint(srcIter.next());
}
// Special handling for Turkish dotted uppercase letter I.
if (codePoint == CODE_POINT_COMBINING_DOT && srcIter.hasNext() &&
trimChars.contains(CODE_POINT_COMBINED_LOWERCASE_I_DOT)) {
int nextCodePoint = getLowercaseCodePoint(srcIter.next());
if ((trimChars.contains(codePoint) && trimChars.contains(nextCodePoint))
|| nextCodePoint == CODE_POINT_LOWERCASE_I) {
codePointBuffer = codePoint;
codePoint = getLowercaseCodePoint(srcIter.next());
if (codePoint == CODE_POINT_LOWERCASE_I) {
searchIndex -= 2;
}
else {
if (trimChars.contains(codePoint)) --searchIndex;
codePointBuffer = -1;
} else if (trimChars.contains(codePointBuffer)) {
--searchIndex;
codePointBuffer = codePoint;
} else {
break;
}
} else if (trimChars.contains(codePoint)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2741,6 +2741,10 @@ public void testStringTrim() throws SparkException {
assertStringTrim("UTF8_BINARY", "ixi", "i", "x");
assertStringTrim("UTF8_BINARY", "i", "İ", "i");
assertStringTrim("UTF8_BINARY", "i\u0307", "İ", "i\u0307");
assertStringTrim("UTF8_BINARY", "ii\u0307", "İi", "\u0307");
assertStringTrim("UTF8_BINARY", "iii\u0307", "İi", "\u0307");
assertStringTrim("UTF8_BINARY", "iiii\u0307", "iİ", "\u0307");
assertStringTrim("UTF8_BINARY", "ii\u0307ii\u0307", "iİ", "\u0307ii\u0307");
assertStringTrim("UTF8_BINARY", "i\u0307", "i", "\u0307");
assertStringTrim("UTF8_BINARY", "i\u0307", "\u0307", "i");
assertStringTrim("UTF8_BINARY", "i\u0307", "i\u0307", "");
Expand All @@ -2766,6 +2770,10 @@ public void testStringTrim() throws SparkException {
assertStringTrim("UTF8_LCASE", "ixi", "i", "x");
assertStringTrim("UTF8_LCASE", "i", "İ", "i");
assertStringTrim("UTF8_LCASE", "i\u0307", "İ", "");
assertStringTrim("UTF8_LCASE", "ii\u0307", "İi", "");
assertStringTrim("UTF8_LCASE", "iii\u0307", "İi", "");
assertStringTrim("UTF8_LCASE", "iiii\u0307", "iİ", "");
assertStringTrim("UTF8_LCASE", "ii\u0307ii\u0307", "iİ", "");
assertStringTrim("UTF8_LCASE", "i\u0307", "i", "\u0307");
assertStringTrim("UTF8_LCASE", "i\u0307", "\u0307", "i");
assertStringTrim("UTF8_LCASE", "i\u0307", "i\u0307", "");
Expand All @@ -2791,6 +2799,10 @@ public void testStringTrim() throws SparkException {
assertStringTrim("UNICODE", "ixi", "i", "x");
assertStringTrim("UNICODE", "i", "İ", "i");
assertStringTrim("UNICODE", "i\u0307", "İ", "i\u0307");
assertStringTrim("UNICODE", "ii\u0307", "İi", "i\u0307");
assertStringTrim("UNICODE", "iii\u0307", "İi", "i\u0307");
assertStringTrim("UNICODE", "iiii\u0307", "iİ", "i\u0307");
assertStringTrim("UNICODE", "ii\u0307ii\u0307", "iİ", "i\u0307ii\u0307");
assertStringTrim("UNICODE", "i\u0307", "i", "i\u0307");
assertStringTrim("UNICODE", "i\u0307", "\u0307", "i\u0307");
assertStringTrim("UNICODE", "i\u0307", "i\u0307", "i\u0307");
Expand All @@ -2817,6 +2829,10 @@ public void testStringTrim() throws SparkException {
assertStringTrim("UNICODE_CI", "ixi", "i", "x");
assertStringTrim("UNICODE_CI", "i", "İ", "i");
assertStringTrim("UNICODE_CI", "i\u0307", "İ", "");
assertStringTrim("UNICODE_CI", "ii\u0307", "İi", "");
assertStringTrim("UNICODE_CI", "iii\u0307", "İi", "");
assertStringTrim("UNICODE_CI", "iiii\u0307", "iİ", "");
assertStringTrim("UNICODE_CI", "ii\u0307ii\u0307", "iİ", "");
assertStringTrim("UNICODE_CI", "i\u0307", "i", "i\u0307");
assertStringTrim("UNICODE_CI", "i\u0307", "\u0307", "i\u0307");
assertStringTrim("UNICODE_CI", "i\u0307", "i\u0307", "i\u0307");
Expand Down Expand Up @@ -3021,6 +3037,10 @@ public void testStringTrimLeft() throws SparkException {
assertStringTrimLeft("UTF8_BINARY", "ixi", "i", "xi");
assertStringTrimLeft("UTF8_BINARY", "i", "İ", "i");
assertStringTrimLeft("UTF8_BINARY", "i\u0307", "İ", "i\u0307");
assertStringTrimLeft("UTF8_BINARY", "ii\u0307", "İi", "\u0307");
assertStringTrimLeft("UTF8_BINARY", "iii\u0307", "İi", "\u0307");
assertStringTrimLeft("UTF8_BINARY", "iiii\u0307", "iİ", "\u0307");
assertStringTrimLeft("UTF8_BINARY", "ii\u0307ii\u0307", "iİ", "\u0307ii\u0307");
assertStringTrimLeft("UTF8_BINARY", "i\u0307", "i", "\u0307");
assertStringTrimLeft("UTF8_BINARY", "i\u0307", "\u0307", "i\u0307");
assertStringTrimLeft("UTF8_BINARY", "i\u0307", "i\u0307", "");
Expand All @@ -3046,6 +3066,10 @@ public void testStringTrimLeft() throws SparkException {
assertStringTrimLeft("UTF8_LCASE", "ixi", "i", "xi");
assertStringTrimLeft("UTF8_LCASE", "i", "İ", "i");
assertStringTrimLeft("UTF8_LCASE", "i\u0307", "İ", "");
assertStringTrimLeft("UTF8_LCASE", "ii\u0307", "İi", "");
assertStringTrimLeft("UTF8_LCASE", "iii\u0307", "İi", "");
assertStringTrimLeft("UTF8_LCASE", "iiii\u0307", "iİ", "");
assertStringTrimLeft("UTF8_LCASE", "ii\u0307ii\u0307", "iİ", "");
assertStringTrimLeft("UTF8_LCASE", "i\u0307", "i", "\u0307");
assertStringTrimLeft("UTF8_LCASE", "i\u0307", "\u0307", "i\u0307");
assertStringTrimLeft("UTF8_LCASE", "i\u0307", "i\u0307", "");
Expand All @@ -3071,6 +3095,10 @@ public void testStringTrimLeft() throws SparkException {
assertStringTrimLeft("UNICODE", "ixi", "i", "xi");
assertStringTrimLeft("UNICODE", "i", "İ", "i");
assertStringTrimLeft("UNICODE", "i\u0307", "İ", "i\u0307");
assertStringTrimLeft("UNICODE", "ii\u0307", "İi", "i\u0307");
assertStringTrimLeft("UNICODE", "iii\u0307", "İi", "i\u0307");
assertStringTrimLeft("UNICODE", "iiii\u0307", "iİ", "i\u0307");
assertStringTrimLeft("UNICODE", "ii\u0307ii\u0307", "iİ", "i\u0307ii\u0307");
assertStringTrimLeft("UNICODE", "i\u0307", "i", "i\u0307");
assertStringTrimLeft("UNICODE", "i\u0307", "\u0307", "i\u0307");
assertStringTrimLeft("UNICODE", "i\u0307", "i\u0307", "i\u0307");
Expand All @@ -3097,6 +3125,10 @@ public void testStringTrimLeft() throws SparkException {
assertStringTrimLeft("UNICODE_CI", "ixi", "i", "xi");
assertStringTrimLeft("UNICODE_CI", "i", "İ", "i");
assertStringTrimLeft("UNICODE_CI", "i\u0307", "İ", "");
assertStringTrimLeft("UNICODE_CI", "ii\u0307", "İi", "");
assertStringTrimLeft("UNICODE_CI", "iii\u0307", "İi", "");
assertStringTrimLeft("UNICODE_CI", "iiii\u0307", "iİ", "");
assertStringTrimLeft("UNICODE_CI", "ii\u0307ii\u0307", "iİ", "");
assertStringTrimLeft("UNICODE_CI", "i\u0307", "i", "i\u0307");
assertStringTrimLeft("UNICODE_CI", "i\u0307", "\u0307", "i\u0307");
assertStringTrimLeft("UNICODE_CI", "i\u0307", "i\u0307", "i\u0307");
Expand Down Expand Up @@ -3302,6 +3334,10 @@ public void testStringTrimRight() throws SparkException {
assertStringTrimRight("UTF8_BINARY", "ixi", "i", "ix");
assertStringTrimRight("UTF8_BINARY", "i", "İ", "i");
assertStringTrimRight("UTF8_BINARY", "i\u0307", "İ", "i\u0307");
assertStringTrimRight("UTF8_BINARY", "ii\u0307", "İi", "ii\u0307");
assertStringTrimRight("UTF8_BINARY", "iii\u0307", "İi", "iii\u0307");
assertStringTrimRight("UTF8_BINARY", "iiii\u0307", "iİ", "iiii\u0307");
assertStringTrimRight("UTF8_BINARY", "ii\u0307ii\u0307", "iİ", "ii\u0307ii\u0307");
assertStringTrimRight("UTF8_BINARY", "i\u0307", "i", "i\u0307");
assertStringTrimRight("UTF8_BINARY", "i\u0307", "\u0307", "i");
assertStringTrimRight("UTF8_BINARY", "i\u0307", "i\u0307", "");
Expand All @@ -3327,6 +3363,10 @@ public void testStringTrimRight() throws SparkException {
assertStringTrimRight("UTF8_LCASE", "ixi", "i", "ix");
assertStringTrimRight("UTF8_LCASE", "i", "İ", "i");
assertStringTrimRight("UTF8_LCASE", "i\u0307", "İ", "");
assertStringTrimRight("UTF8_LCASE", "ii\u0307", "İi", "");
assertStringTrimRight("UTF8_LCASE", "iii\u0307", "İi", "");
assertStringTrimRight("UTF8_LCASE", "iiii\u0307", "iİ", "");
assertStringTrimRight("UTF8_LCASE", "ii\u0307ii\u0307", "iİ", "");
assertStringTrimRight("UTF8_LCASE", "i\u0307", "i", "i\u0307");
assertStringTrimRight("UTF8_LCASE", "i\u0307", "\u0307", "i");
assertStringTrimRight("UTF8_LCASE", "i\u0307", "i\u0307", "");
Expand All @@ -3352,6 +3392,10 @@ public void testStringTrimRight() throws SparkException {
assertStringTrimRight("UNICODE", "ixi", "i", "ix");
assertStringTrimRight("UNICODE", "i", "İ", "i");
assertStringTrimRight("UNICODE", "i\u0307", "İ", "i\u0307");
assertStringTrimRight("UTF8_BINARY", "ii\u0307", "İi", "ii\u0307");
assertStringTrimRight("UTF8_BINARY", "iii\u0307", "İi", "iii\u0307");
assertStringTrimRight("UTF8_BINARY", "iiii\u0307", "iİ", "iiii\u0307");
assertStringTrimRight("UTF8_BINARY", "ii\u0307ii\u0307", "iİ", "ii\u0307ii\u0307");
assertStringTrimRight("UNICODE", "i\u0307", "i", "i\u0307");
assertStringTrimRight("UNICODE", "i\u0307", "\u0307", "i\u0307");
assertStringTrimRight("UNICODE", "i\u0307", "i\u0307", "i\u0307");
Expand All @@ -3378,6 +3422,10 @@ public void testStringTrimRight() throws SparkException {
assertStringTrimRight("UNICODE_CI", "ixi", "i", "ix");
assertStringTrimRight("UNICODE_CI", "i", "İ", "i");
assertStringTrimRight("UNICODE_CI", "i\u0307", "İ", "");
assertStringTrimRight("UNICODE_CI", "ii\u0307", "İi", "");
assertStringTrimRight("UNICODE_CI", "iii\u0307", "İi", "");
assertStringTrimRight("UNICODE_CI", "iiii\u0307", "iİ", "");
assertStringTrimRight("UNICODE_CI", "ii\u0307ii\u0307", "iİ", "");
assertStringTrimRight("UNICODE_CI", "i\u0307", "i", "i\u0307");
assertStringTrimRight("UNICODE_CI", "i\u0307", "\u0307", "i\u0307");
assertStringTrimRight("UNICODE_CI", "i\u0307", "i\u0307", "i\u0307");
Expand Down

0 comments on commit f394cd3

Please sign in to comment.