Contextualise filter edge case fixes #537

Merged
merged 3 commits into from Jul 13, 2023

Changes from all commits
6 changes: 6 additions & 0 deletions src/marqo/tensor_search/constants.py
@@ -21,3 +21,9 @@
     '/', '*', '^', '\\', '!', '[', '||', '?',
     '&&', '"', ']', '-', '{', '~', '+', '}', ':', ')', '('
 }
+
+# these are chars that are not officially listed as Lucene special chars, but
+# aren't treated as normal chars either
+NON_OFFICIAL_LUCENE_SPECIAL_CHARS = {
+    ' '
+}
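
For context, an unescaped space inside a field name would be read by Lucene as a term separator, so field names constructed behind the scenes need it escaped even though it is not on the official special-character list. A minimal sketch of how the two sets are meant to combine (assuming the definitions above; the actual use is in filtering.py below):

    # chars the sanitiser escapes: every special char except the backslash,
    # which is handled separately to avoid double-escaping
    chars_to_escape = LUCENE_SPECIAL_CHARS.union(NON_OFFICIAL_LUCENE_SPECIAL_CHARS) - {'\\'}
    assert ' ' in chars_to_escape  # so "spaced int" becomes "spaced\ int"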
70 changes: 53 additions & 17 deletions src/marqo/tensor_search/filtering.py
@@ -50,32 +50,43 @@ def build_searchable_attributes_filter(searchable_attribs: Sequence) -> str:
f"{enums.TensorField.chunks}.{enums.TensorField.field_name}:{sanitised_attr_name}"
f" OR {build_searchable_attributes_filter(searchable_attribs=searchable_attribs)}")

def sanitise_lucene_special_chars(user_str: str) -> str:
"""Santitises Lucene's special chars.

def sanitise_lucene_special_chars(to_be_sanitised: str) -> str:
"""Santitises Lucene's special chars in a string.

We shouldn't apply this to the user's filter string, as they can choose to escape
Lucene's special chars themselves.

This should be used to sanitise a filter string constructed for users behind the
scenes (such as for searchable attributes).

See here for more info:
https://lucene.apache.org/core/6_0_0/queryparser/org/apache/lucene/queryparser/classic/package-summary.html#Escaping_Special_Characters

"""
# this prevents us from double-escaping backslashes. This may be unnecessary.
non_backslash_chars = constants.LUCENE_SPECIAL_CHARS.union(constants.NON_OFFICIAL_LUCENE_SPECIAL_CHARS) - {'\\'}

# Add backslash to backslash:
# We do this first, so that we don't double-escape the backslashes
user_str = user_str.replace("\\", "\\\\")
to_be_sanitised.replace("\\", "\\\\")

# Add backslash to all other special chars:
for char in constants.LUCENE_SPECIAL_CHARS:
if not char is "\\":
user_str = user_str.replace(char, f'\\{char}')

return user_str
for char in non_backslash_chars:
to_be_sanitised = to_be_sanitised.replace(char, f'\\{char}')
return to_be_sanitised


 def contextualise_user_filter(filter_string: Optional[str], simple_properties: typing.Iterable) -> str:
     """adds the chunk prefix to the start of properties found in the simple string (filter_string).
     This allows for filtering within chunks.
 
+    Because this is a user-defined filter, if they want to filter on field names that contain
+    special characters, we expect them to escape the special characters themselves.
+
+    In order to search chunks we need to append the chunk prefix to the start of the field name.
+    This will only work if they escape the special characters in the field names themselves in
+    the exact same way that we do.
+
     Args:
-        filter_string:
+        filter_string: the user defined filter string
         simple_properties: simple properties of an index (such as text or floats
             and bools)
 
@@ -85,10 +96,35 @@ def contextualise_user_filter(filter_string: Optional[str], simple_properties: typing.Iterable) -> str:
     if filter_string is None:
         return ''
     contextualised_filter = filter_string
+
     for field in simple_properties:
-        if ' ' in field:
-            field_with_escaped_space = field.replace(' ', r'\ ')  # monitor this: fixed the invalid escape sequence (Deprecation warning).
-            contextualised_filter = contextualised_filter.replace(f'{field_with_escaped_space}:', f'{enums.TensorField.chunks}.{field_with_escaped_space}:')
-        else:
-            contextualised_filter = contextualised_filter.replace(f'{field}:', f'{enums.TensorField.chunks}.{field}:')
+        escaped_field_name = sanitise_lucene_special_chars(field)
+        if escaped_field_name in filter_string:
+            # we want to replace only the field name that directly corresponds to the simple property,
+            # not any other field names that contain the simple property as a substring.
+            if (
+                # this is for the case where the field name is at the start of the filter string
+                filter_string.startswith(escaped_field_name) and
+
+                # for cases like filter_string="z_z_z:foo", escaped_field_name="z",
+                # where the simple property is a prefix of a longer field name
+                # in the filter string.
+                # This prevents us from accidentally generating the filter_string:
+                # "__chunks_.z___chunks_.z___chunks_.z:foo":
+                len(filter_string.split(':')[0]) == len(escaped_field_name)
+            ):
+                contextualised_filter = contextualised_filter.replace(
+                    f'{escaped_field_name}:', f'{enums.TensorField.chunks}.{escaped_field_name}:')
+            else:
+                # the case where the field name is not at the start of the filter string
+
+                # the case where the field name is directly after a space,
+                # e.g. filter_string="field_a:foo AND field_b:bar", escaped_field_name="field_b"
+                contextualised_filter = contextualised_filter.replace(
+                    f' {escaped_field_name}:', f' {enums.TensorField.chunks}.{escaped_field_name}:')
+
+                # the case where the field name is directly after an opening bracket
+                contextualised_filter = contextualised_filter.replace(
+                    f'({escaped_field_name}:', f'({enums.TensorField.chunks}.{escaped_field_name}:')
 
     return contextualised_filter
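
A sketch of the intended behaviour of the two changed functions, assuming the character sets in constants.py above (output literals written in the same escaped style as the tests below):

    # sanitiser: spaces are now escaped, and backslashes are doubled first
    sanitise_lucene_special_chars("spaced int")  # -> "spaced\\ int"
    sanitise_lucene_special_chars("some text!")  # -> "some\\ text\\!"

    # contextualiser: prefixes only exact field names, whether at the start
    # of the string, after a space, or after an opening bracket
    contextualise_user_filter("z:foo", ["z"])      # -> f"{enums.TensorField.chunks}.z:foo"
    contextualise_user_filter("z_z_z:foo", ["z"])  # -> "z_z_z:foo" (unchanged: "z" is only a prefix of "z_z_z")
    contextualise_user_filter("a:1 AND (b:2)", ["a", "b"])
    # -> f"{enums.TensorField.chunks}.a:1 AND ({enums.TensorField.chunks}.b:2)"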
40 changes: 21 additions & 19 deletions tests/tensor_search/test_filtering.py
@@ -19,9 +19,9 @@ def test_contextualise_user_filter(self):
f"({enums.TensorField.chunks}.an_int:[0 TO 30] AND {enums.TensorField.chunks}.an_int:2) AND {enums.TensorField.chunks}.abc:(some text)"
),
( # fields with spaces
r'spaced\ int:[0 TO 30]',
"spaced\\ int:[0 TO 30]",
["spaced int"],
rf'{enums.TensorField.chunks}.spaced\ int:[0 TO 30]'
f"{enums.TensorField.chunks}.spaced\\ int:[0 TO 30]"
),
( # field in string not in properties
"field_not_in_properties:random AND normal_field:3",
@@ -45,11 +45,12 @@
             )
         ]
         for given_filter_string, given_simple_properties, expected in expected_mappings:
-            assert expected == filtering.contextualise_user_filter(
-                filter_string=given_filter_string,
+            contextualised_user_filter = filtering.contextualise_user_filter(
+                filter_string=given_filter_string,
                 simple_properties=given_simple_properties,
             )
+            assert expected == contextualised_user_filter
 
     def test_build_searchable_attributes_filter(self):
         expected_mappings = [
             (["an_int", "abc"],
@@ -67,31 +68,32 @@ def test_build_tensor_search_filter(self):
         test_cases = (
             {
                 "filter_string": "abc:(some text)",
-                "simple_properties": ["abc"],
-                "searchable_attributes": ["bleh"],
-                "expected": f"({enums.TensorField.chunks}.{enums.TensorField.field_name}:(bleh)) AND ({enums.TensorField.chunks}.abc:(some text))"
+                "simple_properties": {"abc": "xyz"},
+                "searchable_attributes": ["abc"],
+                "expected": f"({enums.TensorField.chunks}.{enums.TensorField.field_name}:(abc)) AND ({enums.TensorField.chunks}.abc:(some text))"
             },
             # special character in searchable attribute
             # escaped space in filter string
         )
         for case in test_cases:
-            assert case["expected"] == filtering.build_tensor_search_filter(
+            tensor_search_filter = filtering.build_tensor_search_filter(
                 filter_string=case["filter_string"],
                 simple_properties=case["simple_properties"],
                 searchable_attribs=case["searchable_attributes"]
             )
+            assert case["expected"] == tensor_search_filter
 
     def test_sanitise_lucene_special_chars(self):
         expected_mappings = [
-            ("plain text", "plain text"),
-            ("exclamation!", "exclamation\\!"),
-            ("some text?", "some text\\?"),
-            ("some text&&", "some text\\&&"),
-            ("some text\\", "some text\\\\"),
-            ("everything ||&?\\ combined", "everything \\||&\\?\\\\ combined"),
-            ("some text&", "some text&")  # no change, & is not a special char
+            ("some text", "some\\ text"),
+            ("text!", "text\\!"),
+            ("some text!", "some\\ text\\!"),
+            ("text?", "text\\?"),
+            ("text&&", "text\\&&"),
+            ("text&", "text&")  # no change, & is not a special char
         ]
         for given, expected in expected_mappings:
-            assert expected == filtering.sanitise_lucene_special_chars(
+            escaped_output = filtering.sanitise_lucene_special_chars(
                 given
             )
+            assert expected == escaped_output
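
Taken together, the test_build_tensor_search_filter case implies the composed flow: sanitise the searchable attributes, contextualise the user filter, and AND the two parts. A usage sketch inferred from the test expectations above (not a definitive API description):

    filter_str = filtering.build_tensor_search_filter(
        filter_string="abc:(some text)",
        simple_properties={"abc": "xyz"},
        searchable_attribs=["abc"],
    )
    # filter_str == f"({enums.TensorField.chunks}.{enums.TensorField.field_name}:(abc))" \
    #               f" AND ({enums.TensorField.chunks}.abc:(some text))"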