From f394cd3878bd38cfe6fbe2fc7ac01ada853f6375 Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Mon, 26 Aug 2024 15:12:23 +0200 Subject: [PATCH] [SPARK-48441][SQL][FOLLOWUP] Fix StringTrim behaviour for UTF8_LCASE collation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? Fix how `StringTrim*` expressions handle edge cases with one-to-many case mapping, by updating the lowercase trimming logic that matches multiple characters from the `srcString` to a single character in `trimStr`. These methods now correctly follow the iterative "longest possible match" behaviour (proposed for StringTrim in the original PR). ### Why are the changes needed? Fix a subtle bug in `trim`-like functions. Example: ``` trim("ii\u0307", "İi"); -- returns: "\u0307" (wrong), instead of "" (correct) ``` ### Does this PR introduce _any_ user-facing change? Yes, trim* function behaviour is slightly altered for UTF8_LCASE. ### How was this patch tested? New tests in `CollationSupportSuite`. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #47836 from uros-db/alter-trim-followup. 
Authored-by: Uros Bojanic Signed-off-by: Max Gekk --- .../util/CollationAwareUTF8String.java | 49 +++++++++++++------ .../unsafe/types/CollationSupportSuite.java | 48 ++++++++++++++++++ 2 files changed, 81 insertions(+), 16 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index 9f26cc0bac21c..6e7a41af593db 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -990,20 +990,29 @@ public static UTF8String lowercaseTrimLeft( while (trimIter.hasNext()) trimChars.add(getLowercaseCodePoint(trimIter.next())); // Iterate over `srcString` from the left to find the first character that is not in the set. - int searchIndex = 0, codePoint; + int searchIndex = 0, codePoint, codePointBuffer = -1; Iterator srcIter = srcString.codePointIterator(); while (srcIter.hasNext()) { - codePoint = getLowercaseCodePoint(srcIter.next()); + // Get the next code point from either the buffer or the iterator. + if (codePointBuffer != -1) { + codePoint = codePointBuffer; + codePointBuffer = -1; + } + else { + codePoint = getLowercaseCodePoint(srcIter.next()); + } // Special handling for Turkish dotted uppercase letter I. 
if (codePoint == CODE_POINT_LOWERCASE_I && srcIter.hasNext() && trimChars.contains(CODE_POINT_COMBINED_LOWERCASE_I_DOT)) { - int nextCodePoint = getLowercaseCodePoint(srcIter.next()); - if ((trimChars.contains(codePoint) && trimChars.contains(nextCodePoint)) - || nextCodePoint == CODE_POINT_COMBINING_DOT) { + codePointBuffer = codePoint; + codePoint = getLowercaseCodePoint(srcIter.next()); + if (codePoint == CODE_POINT_COMBINING_DOT) { searchIndex += 2; - } - else { - if (trimChars.contains(codePoint)) ++searchIndex; + codePointBuffer = -1; + } else if (trimChars.contains(codePointBuffer)) { + ++searchIndex; + codePointBuffer = codePoint; + } else { break; } } else if (trimChars.contains(codePoint)) { @@ -1100,20 +1109,28 @@ public static UTF8String lowercaseTrimRight( while (trimIter.hasNext()) trimChars.add(getLowercaseCodePoint(trimIter.next())); // Iterate over `srcString` from the right to find the first character that is not in the set. - int searchIndex = srcString.numChars(), codePoint; + int searchIndex = srcString.numChars(), codePoint, codePointBuffer = -1; Iterator srcIter = srcString.reverseCodePointIterator(); while (srcIter.hasNext()) { - codePoint = getLowercaseCodePoint(srcIter.next()); + if (codePointBuffer != -1) { + codePoint = codePointBuffer; + codePointBuffer = -1; + } + else { + codePoint = getLowercaseCodePoint(srcIter.next()); + } // Special handling for Turkish dotted uppercase letter I. 
if (codePoint == CODE_POINT_COMBINING_DOT && srcIter.hasNext() && trimChars.contains(CODE_POINT_COMBINED_LOWERCASE_I_DOT)) { - int nextCodePoint = getLowercaseCodePoint(srcIter.next()); - if ((trimChars.contains(codePoint) && trimChars.contains(nextCodePoint)) - || nextCodePoint == CODE_POINT_LOWERCASE_I) { + codePointBuffer = codePoint; + codePoint = getLowercaseCodePoint(srcIter.next()); + if (codePoint == CODE_POINT_LOWERCASE_I) { searchIndex -= 2; - } - else { - if (trimChars.contains(codePoint)) --searchIndex; + codePointBuffer = -1; + } else if (trimChars.contains(codePointBuffer)) { + --searchIndex; + codePointBuffer = codePoint; + } else { break; } } else if (trimChars.contains(codePoint)) { diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 202d730974533..72fb1f65bf9b7 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -2741,6 +2741,10 @@ public void testStringTrim() throws SparkException { assertStringTrim("UTF8_BINARY", "ixi", "i", "x"); assertStringTrim("UTF8_BINARY", "i", "İ", "i"); assertStringTrim("UTF8_BINARY", "i\u0307", "İ", "i\u0307"); + assertStringTrim("UTF8_BINARY", "ii\u0307", "İi", "\u0307"); + assertStringTrim("UTF8_BINARY", "iii\u0307", "İi", "\u0307"); + assertStringTrim("UTF8_BINARY", "iiii\u0307", "iİ", "\u0307"); + assertStringTrim("UTF8_BINARY", "ii\u0307ii\u0307", "iİ", "\u0307ii\u0307"); assertStringTrim("UTF8_BINARY", "i\u0307", "i", "\u0307"); assertStringTrim("UTF8_BINARY", "i\u0307", "\u0307", "i"); assertStringTrim("UTF8_BINARY", "i\u0307", "i\u0307", ""); @@ -2766,6 +2770,10 @@ public void testStringTrim() throws SparkException { assertStringTrim("UTF8_LCASE", "ixi", "i", "x"); assertStringTrim("UTF8_LCASE", "i", "İ", "i"); 
assertStringTrim("UTF8_LCASE", "i\u0307", "İ", ""); + assertStringTrim("UTF8_LCASE", "ii\u0307", "İi", ""); + assertStringTrim("UTF8_LCASE", "iii\u0307", "İi", ""); + assertStringTrim("UTF8_LCASE", "iiii\u0307", "iİ", ""); + assertStringTrim("UTF8_LCASE", "ii\u0307ii\u0307", "iİ", ""); assertStringTrim("UTF8_LCASE", "i\u0307", "i", "\u0307"); assertStringTrim("UTF8_LCASE", "i\u0307", "\u0307", "i"); assertStringTrim("UTF8_LCASE", "i\u0307", "i\u0307", ""); @@ -2791,6 +2799,10 @@ public void testStringTrim() throws SparkException { assertStringTrim("UNICODE", "ixi", "i", "x"); assertStringTrim("UNICODE", "i", "İ", "i"); assertStringTrim("UNICODE", "i\u0307", "İ", "i\u0307"); + assertStringTrim("UNICODE", "ii\u0307", "İi", "i\u0307"); + assertStringTrim("UNICODE", "iii\u0307", "İi", "i\u0307"); + assertStringTrim("UNICODE", "iiii\u0307", "iİ", "i\u0307"); + assertStringTrim("UNICODE", "ii\u0307ii\u0307", "iİ", "i\u0307ii\u0307"); assertStringTrim("UNICODE", "i\u0307", "i", "i\u0307"); assertStringTrim("UNICODE", "i\u0307", "\u0307", "i\u0307"); assertStringTrim("UNICODE", "i\u0307", "i\u0307", "i\u0307"); @@ -2817,6 +2829,10 @@ public void testStringTrim() throws SparkException { assertStringTrim("UNICODE_CI", "ixi", "i", "x"); assertStringTrim("UNICODE_CI", "i", "İ", "i"); assertStringTrim("UNICODE_CI", "i\u0307", "İ", ""); + assertStringTrim("UNICODE_CI", "ii\u0307", "İi", ""); + assertStringTrim("UNICODE_CI", "iii\u0307", "İi", ""); + assertStringTrim("UNICODE_CI", "iiii\u0307", "iİ", ""); + assertStringTrim("UNICODE_CI", "ii\u0307ii\u0307", "iİ", ""); assertStringTrim("UNICODE_CI", "i\u0307", "i", "i\u0307"); assertStringTrim("UNICODE_CI", "i\u0307", "\u0307", "i\u0307"); assertStringTrim("UNICODE_CI", "i\u0307", "i\u0307", "i\u0307"); @@ -3021,6 +3037,10 @@ public void testStringTrimLeft() throws SparkException { assertStringTrimLeft("UTF8_BINARY", "ixi", "i", "xi"); assertStringTrimLeft("UTF8_BINARY", "i", "İ", "i"); assertStringTrimLeft("UTF8_BINARY", 
"i\u0307", "İ", "i\u0307"); + assertStringTrimLeft("UTF8_BINARY", "ii\u0307", "İi", "\u0307"); + assertStringTrimLeft("UTF8_BINARY", "iii\u0307", "İi", "\u0307"); + assertStringTrimLeft("UTF8_BINARY", "iiii\u0307", "iİ", "\u0307"); + assertStringTrimLeft("UTF8_BINARY", "ii\u0307ii\u0307", "iİ", "\u0307ii\u0307"); assertStringTrimLeft("UTF8_BINARY", "i\u0307", "i", "\u0307"); assertStringTrimLeft("UTF8_BINARY", "i\u0307", "\u0307", "i\u0307"); assertStringTrimLeft("UTF8_BINARY", "i\u0307", "i\u0307", ""); @@ -3046,6 +3066,10 @@ public void testStringTrimLeft() throws SparkException { assertStringTrimLeft("UTF8_LCASE", "ixi", "i", "xi"); assertStringTrimLeft("UTF8_LCASE", "i", "İ", "i"); assertStringTrimLeft("UTF8_LCASE", "i\u0307", "İ", ""); + assertStringTrimLeft("UTF8_LCASE", "ii\u0307", "İi", ""); + assertStringTrimLeft("UTF8_LCASE", "iii\u0307", "İi", ""); + assertStringTrimLeft("UTF8_LCASE", "iiii\u0307", "iİ", ""); + assertStringTrimLeft("UTF8_LCASE", "ii\u0307ii\u0307", "iİ", ""); assertStringTrimLeft("UTF8_LCASE", "i\u0307", "i", "\u0307"); assertStringTrimLeft("UTF8_LCASE", "i\u0307", "\u0307", "i\u0307"); assertStringTrimLeft("UTF8_LCASE", "i\u0307", "i\u0307", ""); @@ -3071,6 +3095,10 @@ public void testStringTrimLeft() throws SparkException { assertStringTrimLeft("UNICODE", "ixi", "i", "xi"); assertStringTrimLeft("UNICODE", "i", "İ", "i"); assertStringTrimLeft("UNICODE", "i\u0307", "İ", "i\u0307"); + assertStringTrimLeft("UNICODE", "ii\u0307", "İi", "i\u0307"); + assertStringTrimLeft("UNICODE", "iii\u0307", "İi", "i\u0307"); + assertStringTrimLeft("UNICODE", "iiii\u0307", "iİ", "i\u0307"); + assertStringTrimLeft("UNICODE", "ii\u0307ii\u0307", "iİ", "i\u0307ii\u0307"); assertStringTrimLeft("UNICODE", "i\u0307", "i", "i\u0307"); assertStringTrimLeft("UNICODE", "i\u0307", "\u0307", "i\u0307"); assertStringTrimLeft("UNICODE", "i\u0307", "i\u0307", "i\u0307"); @@ -3097,6 +3125,10 @@ public void testStringTrimLeft() throws SparkException { 
assertStringTrimLeft("UNICODE_CI", "ixi", "i", "xi"); assertStringTrimLeft("UNICODE_CI", "i", "İ", "i"); assertStringTrimLeft("UNICODE_CI", "i\u0307", "İ", ""); + assertStringTrimLeft("UNICODE_CI", "ii\u0307", "İi", ""); + assertStringTrimLeft("UNICODE_CI", "iii\u0307", "İi", ""); + assertStringTrimLeft("UNICODE_CI", "iiii\u0307", "iİ", ""); + assertStringTrimLeft("UNICODE_CI", "ii\u0307ii\u0307", "iİ", ""); assertStringTrimLeft("UNICODE_CI", "i\u0307", "i", "i\u0307"); assertStringTrimLeft("UNICODE_CI", "i\u0307", "\u0307", "i\u0307"); assertStringTrimLeft("UNICODE_CI", "i\u0307", "i\u0307", "i\u0307"); @@ -3302,6 +3334,10 @@ public void testStringTrimRight() throws SparkException { assertStringTrimRight("UTF8_BINARY", "ixi", "i", "ix"); assertStringTrimRight("UTF8_BINARY", "i", "İ", "i"); assertStringTrimRight("UTF8_BINARY", "i\u0307", "İ", "i\u0307"); + assertStringTrimRight("UTF8_BINARY", "ii\u0307", "İi", "ii\u0307"); + assertStringTrimRight("UTF8_BINARY", "iii\u0307", "İi", "iii\u0307"); + assertStringTrimRight("UTF8_BINARY", "iiii\u0307", "iİ", "iiii\u0307"); + assertStringTrimRight("UTF8_BINARY", "ii\u0307ii\u0307", "iİ", "ii\u0307ii\u0307"); assertStringTrimRight("UTF8_BINARY", "i\u0307", "i", "i\u0307"); assertStringTrimRight("UTF8_BINARY", "i\u0307", "\u0307", "i"); assertStringTrimRight("UTF8_BINARY", "i\u0307", "i\u0307", ""); @@ -3327,6 +3363,10 @@ public void testStringTrimRight() throws SparkException { assertStringTrimRight("UTF8_LCASE", "ixi", "i", "ix"); assertStringTrimRight("UTF8_LCASE", "i", "İ", "i"); assertStringTrimRight("UTF8_LCASE", "i\u0307", "İ", ""); + assertStringTrimRight("UTF8_LCASE", "ii\u0307", "İi", ""); + assertStringTrimRight("UTF8_LCASE", "iii\u0307", "İi", ""); + assertStringTrimRight("UTF8_LCASE", "iiii\u0307", "iİ", ""); + assertStringTrimRight("UTF8_LCASE", "ii\u0307ii\u0307", "iİ", ""); assertStringTrimRight("UTF8_LCASE", "i\u0307", "i", "i\u0307"); assertStringTrimRight("UTF8_LCASE", "i\u0307", "\u0307", "i"); 
assertStringTrimRight("UTF8_LCASE", "i\u0307", "i\u0307", ""); @@ -3352,6 +3392,10 @@ public void testStringTrimRight() throws SparkException { assertStringTrimRight("UNICODE", "ixi", "i", "ix"); assertStringTrimRight("UNICODE", "i", "İ", "i"); assertStringTrimRight("UNICODE", "i\u0307", "İ", "i\u0307"); + assertStringTrimRight("UNICODE", "ii\u0307", "İi", "ii\u0307"); + assertStringTrimRight("UNICODE", "iii\u0307", "İi", "iii\u0307"); + assertStringTrimRight("UNICODE", "iiii\u0307", "iİ", "iiii\u0307"); + assertStringTrimRight("UNICODE", "ii\u0307ii\u0307", "iİ", "ii\u0307ii\u0307"); assertStringTrimRight("UNICODE", "i\u0307", "i", "i\u0307"); assertStringTrimRight("UNICODE", "i\u0307", "\u0307", "i\u0307"); assertStringTrimRight("UNICODE", "i\u0307", "i\u0307", "i\u0307"); @@ -3378,6 +3422,10 @@ public void testStringTrimRight() throws SparkException { assertStringTrimRight("UNICODE_CI", "ixi", "i", "ix"); assertStringTrimRight("UNICODE_CI", "i", "İ", "i"); assertStringTrimRight("UNICODE_CI", "i\u0307", "İ", ""); + assertStringTrimRight("UNICODE_CI", "ii\u0307", "İi", ""); + assertStringTrimRight("UNICODE_CI", "iii\u0307", "İi", ""); + assertStringTrimRight("UNICODE_CI", "iiii\u0307", "iİ", ""); + assertStringTrimRight("UNICODE_CI", "ii\u0307ii\u0307", "iİ", ""); assertStringTrimRight("UNICODE_CI", "i\u0307", "i", "i\u0307"); assertStringTrimRight("UNICODE_CI", "i\u0307", "\u0307", "i\u0307"); assertStringTrimRight("UNICODE_CI", "i\u0307", "i\u0307", "i\u0307");