Support kNN filter on nested metadata
Currently, kNN search over nested vectors only supports filtering
on the parent's metadata. This adds support for filtering over
nested metadata.

Closes elastic#106994
mayya-sharipova committed Oct 2, 2024
1 parent eb9b897 commit d29e647
Showing 14 changed files with 512 additions and 45 deletions.
37 changes: 22 additions & 15 deletions docs/reference/query-dsl/knn-query.asciidoc
@@ -240,26 +240,30 @@ to <<nested-knn-search, top level nested kNN search>>:

* kNN search over nested dense_vectors diversifies the top results over
the top-level document
* `filter` over the top-level document metadata is supported and acts as a
pre-filter
* `filter` over `nested` field metadata is not supported
* `filter` is supported both over the top-level document metadata and `nested`
field metadata. Filter acts as a pre-filter.

A sample query can look like below:
A sample query with filter over nested metadata can look like below:

[source,js]
----
{
"query" : {
"nested" : {
"path" : "paragraph",
"query" : {
"knn": {
"query_vector": [
0.45,
45
],
"field": "paragraph.vector",
"num_candidates": 2
"query": {
"nested": {
"path": "paragraph",
"query": {
"knn": {
"query_vector": [
0.45,
45
],
"field": "paragraph.vector",
"k": 10,
"filter": {
"match": {
"paragraph.language": "EN"
}
}
}
}
}
@@ -268,6 +272,9 @@ A sample query can look like below:
----
// NOTCONSOLE

The above query only considers vectors with `"paragraph.language": "EN"`
for scoring parents' documents.
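
As a rough sketch of how a client might assemble the request body above, the Python helper below builds the same `nested` + `knn` structure with a filter on a nested field. The function name `nested_knn_query` and its parameters are hypothetical; only the JSON keys (`nested`, `path`, `knn`, `query_vector`, `field`, `k`, `filter`) are taken from the sample query.

```python
import json


def nested_knn_query(path, vector_field, query_vector, k, nested_filter=None):
    """Build a `nested` query wrapping a `knn` query (hypothetical helper).

    `nested_filter` is any query clause; it may target nested fields such as
    `paragraph.language`, or be omitted for an unfiltered search.
    """
    knn = {
        "query_vector": list(query_vector),
        "field": f"{path}.{vector_field}",
        "k": k,
    }
    if nested_filter is not None:
        knn["filter"] = nested_filter
    return {"query": {"nested": {"path": path, "query": {"knn": knn}}}}


body = nested_knn_query(
    path="paragraph",
    vector_field="vector",
    query_vector=[0.45, 45],
    k=10,
    nested_filter={"match": {"paragraph.language": "EN"}},
)
print(json.dumps(body, indent=2))
```

Keeping the filter inside the `knn` clause, rather than next to it, is what makes it act as a pre-filter on the candidate vectors.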

[[knn-query-aggregations]]
==== Knn query with aggregations
`knn` query calculates aggregations on top `k` documents from each shard.
46 changes: 38 additions & 8 deletions docs/reference/search/search-your-data/knn-search.asciidoc
@@ -677,6 +677,9 @@ PUT passage_vectors
"type": "hnsw"
}
},
"language": {
"type" : "keyword"
},
"text": {
"type": "text",
"index": false
@@ -695,9 +698,9 @@ With the above mapping, we can index multiple passage vectors along with storing
----
POST passage_vectors/_bulk?refresh=true
{ "index": { "_id": "1" } }
{ "full_text": "first paragraph another paragraph", "creation_time": "2019-05-04", "paragraph": [ { "vector": [ 0.45, 45 ], "text": "first paragraph", "paragraph_id": "1" }, { "vector": [ 0.8, 0.6 ], "text": "another paragraph", "paragraph_id": "2" } ] }
{ "full_text": "first paragraph another paragraph", "creation_time": "2019-05-04", "paragraph": [ { "vector": [ 0.45, 45 ], "text": "first paragraph", "paragraph_id": "1", "language": "EN" }, { "vector": [ 0.8, 0.6 ], "text": "another paragraph", "paragraph_id": "2", "language": "FR" } ] }
{ "index": { "_id": "2" } }
{ "full_text": "number one paragraph number two paragraph", "creation_time": "2020-05-04", "paragraph": [ { "vector": [ 1.2, 4.5 ], "text": "number one paragraph", "paragraph_id": "1" }, { "vector": [ -1, 42 ], "text": "number two paragraph", "paragraph_id": "2" } ] }
{ "full_text": "number one paragraph number two paragraph", "creation_time": "2020-05-04", "paragraph": [ { "vector": [ 1.2, 4.5 ], "text": "number one paragraph", "paragraph_id": "1", "language": "FR" }, { "vector": [ -1, 42 ], "text": "number two paragraph", "paragraph_id": "2", "language": "EN" } ] }
----
//TEST[continued]
//TEST[s/\.\.\.//]
@@ -776,12 +779,8 @@ scored by their nearest passage vector (e.g. `"paragraph.vector"`).
----
// TESTRESPONSE[s/"took": 4/"took" : "$body.took"/]

What if you wanted to filter by some top-level document metadata? You can do this by adding `filter` to your
`knn` clause.


NOTE: `filter` will always be over the top-level document metadata. This means you cannot filter based on `nested`
field metadata.
What if you wanted to filter by some document metadata? You can do this by adding `filter` to your
`knn` clause. `filter` can be run based on both: the top-level document metadata and `nested` field metadata.

[source,console]
----
@@ -858,6 +857,37 @@ Now we have filtered based on the top level `"creation_time"` and only one docum
----
// TESTRESPONSE[s/"took": 4/"took" : "$body.took"/]


Filtering by nested field metadata: `paragraph.language` makes kNN search only consider vectors with this metadata
for scoring parents' documents:

[source,console]
----
POST passage_vectors/_search
{
"fields": [
"creation_time",
"full_text"
],
"_source": false,
"knn": {
"query_vector": [
0.45,
45
],
"field": "paragraph.vector",
"k": 2,
"num_candidates": 2,
"filter": {
"match" : {
"paragraph.language" : "EN"
}
}
}
}
----
//TEST[continued]

[discrete]
[[nested-knn-search-inner-hits]]
==== Nested kNN Search with Inner hits
Original file line number Diff line number Diff line change
@@ -16,6 +16,8 @@ setup:
nested:
type: nested
properties:
language:
type: keyword
paragraph_id:
type: keyword
vector:
@@ -37,8 +39,10 @@ setup:
nested:
- paragraph_id: 0
vector: [230.0, 300.33, -34.8988, 15.555, -200.0]
language: EN
- paragraph_id: 1
vector: [240.0, 300, -3, 1, -20]
language: FR

- do:
index:
@@ -49,10 +53,13 @@ setup:
nested:
- paragraph_id: 0
vector: [-0.5, 100.0, -13, 14.8, -156.0]
language: EN
- paragraph_id: 2
vector: [0, 100.0, 0, 14.8, -156.0]
language: EN
- paragraph_id: 3
vector: [0, 1.0, 0, 1.8, -15.0]
language: FR

- do:
index:
@@ -63,6 +70,7 @@ setup:
nested:
- paragraph_id: 0
vector: [0.5, 111.3, -13.0, 14.8, -156.0]
language: FR

- do:
indices.refresh: {}
@@ -461,3 +469,72 @@ setup:
- match: {hits.hits.0._id: "2"}
- length: {hits.hits.0.inner_hits.nested.hits.hits: 1}
- match: {hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0"}


---
"Test filter on nested fields":
- requires:
capabilities:
- method: POST
path: /_search
capabilities: [ knn_filter_on_nested_fields ]
test_runner_features: ["capabilities", "close_to"]
reason: "Capability for filtering on nested fields required"

- do:
search:
index: test
body:
_source: false
knn:
boost: 2
field: nested.vector
query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ]
k: 3
num_candidates: 10
filter: { match: { nested.language: "EN" } }
inner_hits: { size: 3, "fields": [ "nested.paragraph_id", "nested.language"], _source: false }

- match: { hits.total.value: 2 }
- match: { hits.hits.0._id: "2" }
- match: { hits.hits.0.inner_hits.nested.hits.total.value: 2 }
- match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
- match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "EN" }
- match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" }
- match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.language.0: "EN" }
- close_to: { hits.hits.0._score: { value: 0.0182, error: 0.0001 } }
- close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: { value: 0.0182, error: 0.0001 } }
- match: { hits.hits.1._id: "1" }
- match: { hits.hits.1.inner_hits.nested.hits.total.value: 1 }
- match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
- match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "EN" }

- do:
search:
index: test
body:
_source: false
knn:
boost: 2
field: nested.vector
query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ]
k: 3
num_candidates: 10
filter: { match: { nested.language: "FR" } }
inner_hits: { size: 3, "fields": [ "nested.paragraph_id", "nested.language"], _source: false }

- match: { hits.total.value: 3 }
- match: { hits.hits.0._id: "3" }
- match: { hits.hits.0.inner_hits.nested.hits.total.value: 1 }
- match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
- match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" }
- close_to: { hits.hits.0._score: { value: 0.0043, error: 0.0001 } }
- close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: { value: 0.0043, error: 0.0001 } }
- match: { hits.hits.1._id: "2" }
- match: { hits.hits.1.inner_hits.nested.hits.total.value: 1 }
- match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "3" }
- match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" }
- match: { hits.hits.2._id: "1" }
- match: { hits.hits.2.inner_hits.nested.hits.total.value: 1 }
- match: { hits.hits.2.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "1" }
- match: { hits.hits.2.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" }
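
The expectations in the test above can be checked by hand with a brute-force model of the pre-filter. The sketch below is not how the search is implemented; it only mirrors the intended semantics on the fixture data. Two things are assumptions because the diff elides them: the documents are taken to have ids "1", "2" and "3" in setup order (inferred from the assertions), and the vector field is taken to use `l2_norm` similarity, scored as `boost / (1 + squared_distance)`. The names `score` and `filtered_nested_knn` are hypothetical.

```python
# Fixture data copied from the setup section; ids are inferred, not shown in the diff.
DOCS = {
    "1": [
        {"paragraph_id": "0", "vector": [230.0, 300.33, -34.8988, 15.555, -200.0], "language": "EN"},
        {"paragraph_id": "1", "vector": [240.0, 300, -3, 1, -20], "language": "FR"},
    ],
    "2": [
        {"paragraph_id": "0", "vector": [-0.5, 100.0, -13, 14.8, -156.0], "language": "EN"},
        {"paragraph_id": "2", "vector": [0, 100.0, 0, 14.8, -156.0], "language": "EN"},
        {"paragraph_id": "3", "vector": [0, 1.0, 0, 1.8, -15.0], "language": "FR"},
    ],
    "3": [
        {"paragraph_id": "0", "vector": [0.5, 111.3, -13.0, 14.8, -156.0], "language": "FR"},
    ],
}
QUERY = [-0.5, 90.0, -10, 14.8, -156.0]


def score(query, vector, boost=1.0):
    """Assumed l2_norm scoring: boost / (1 + squared Euclidean distance)."""
    d2 = sum((q - v) ** 2 for q, v in zip(query, vector))
    return boost / (1.0 + d2)


def filtered_nested_knn(docs, query, language, k, boost=1.0):
    """Keep only nested vectors matching the filter, then score each parent
    by its best remaining child and return the top k parents."""
    hits = []
    for doc_id, paragraphs in docs.items():
        matching = [p for p in paragraphs if p["language"] == language]
        if not matching:
            continue  # no child passes the filter, so the parent is not a hit
        scored = sorted(
            ((score(query, p["vector"], boost), p["paragraph_id"]) for p in matching),
            reverse=True,
        )
        hits.append({
            "_id": doc_id,
            "_score": scored[0][0],
            "inner_hits": [pid for _, pid in scored],
        })
    hits.sort(key=lambda h: h["_score"], reverse=True)
    return hits[:k]


en_hits = filtered_nested_knn(DOCS, QUERY, "EN", k=3, boost=2.0)
fr_hits = filtered_nested_knn(DOCS, QUERY, "FR", k=3, boost=2.0)
print([h["_id"] for h in en_hits])  # ['2', '1']
print([h["_id"] for h in fr_hits])  # ['3', '2', '1']
```

Under these assumptions the top scores land within the `close_to` tolerances asserted above (about 0.0182 for "EN" and 0.0043 for "FR"), and document "3" drops out of the "EN" results because none of its nested vectors pass the filter.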
Original file line number Diff line number Diff line change
@@ -16,6 +16,8 @@ setup:
nested:
type: nested
properties:
language:
type: keyword
paragraph_id:
type: keyword
vector:
@@ -38,8 +40,10 @@ setup:
nested:
- paragraph_id: 0
vector: [230.0, 300.33, -34.8988, 15.555, -200.0]
language: EN
- paragraph_id: 1
vector: [240.0, 300, -3, 1, -20]
language: FR

- do:
index:
@@ -50,10 +54,13 @@ setup:
nested:
- paragraph_id: 0
vector: [-0.5, 100.0, -13, 14.8, -156.0]
language: EN
- paragraph_id: 2
vector: [0, 100.0, 0, 14.8, -156.0]
language: EN
- paragraph_id: 3
vector: [0, 1.0, 0, 1.8, -15.0]
language: FR

- do:
index:
@@ -64,6 +71,7 @@ setup:
nested:
- paragraph_id: 0
vector: [0.5, 111.3, -13.0, 14.8, -156.0]
language: FR

- do:
indices.refresh: {}
@@ -406,3 +414,82 @@ setup:

- match: {hits.total.value: 1}
- match: {hits.hits.0._id: "2"}


---
"Test filter on nested fields":
- requires:
capabilities:
- method: POST
path: /_search
capabilities: [ knn_filter_on_nested_fields ]
test_runner_features: ["capabilities", "close_to"]
reason: "Capability for filtering on nested fields required"

- do:
search:
index: test
body:
_source: false
query:
nested:
path: nested
query:
knn:
boost: 2
field: nested.vector
query_vector: [-0.5, 90.0, -10, 14.8, -156.0]
k: 10
filter:
match:
nested.language: "EN"
inner_hits: { size: 3, "fields": [ "nested.paragraph_id", "nested.language"], _source: false }

- match: {hits.total.value: 2}
- match: {hits.hits.0._id: "2"}
- match: { hits.hits.0.inner_hits.nested.hits.total.value: 2 }
- match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
- match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "EN" }
- match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" }
- match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.language.0: "EN" }
- close_to: { hits.hits.0._score: { value: 0.02036, error: 0.0001 } }
- close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: { value: 0.02036, error: 0.0001 } }
- match: {hits.hits.1._id: "1"}
- match: { hits.hits.1.inner_hits.nested.hits.total.value: 1 }
- match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
- match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "EN" }

- do:
search:
index: test
body:
_source: false
query:
nested:
path: nested
query:
knn:
boost: 2
field: nested.vector
query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ]
k: 10
filter:
match:
nested.language: "FR"
inner_hits: { size: 3, "fields": [ "nested.paragraph_id", "nested.language" ], _source: false }

- match: { hits.total.value: 3 }
- match: { hits.hits.0._id: "3" }
- match: { hits.hits.0.inner_hits.nested.hits.total.value: 1 }
- match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
- match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" }
- close_to: { hits.hits.0._score: { value: 0.0041, error: 0.0001 } }
- close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: { value: 0.0041, error: 0.0001 } }
- match: { hits.hits.1._id: "2" }
- match: { hits.hits.1.inner_hits.nested.hits.total.value: 1 }
- match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "3" }
- match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" }
- match: { hits.hits.2._id: "1" }
- match: { hits.hits.2.inner_hits.nested.hits.total.value: 1 }
- match: { hits.hits.2.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "1" }
- match: { hits.hits.2.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" }
Original file line number Diff line number Diff line change
@@ -231,6 +231,7 @@ static TransportVersion def(int id) {
public static final TransportVersion CCS_REMOTE_TELEMETRY_STATS = def(8_755_00_0);
public static final TransportVersion ESQL_CCS_EXECUTION_INFO = def(8_756_00_0);
public static final TransportVersion REGEX_AND_RANGE_INTERVAL_QUERIES = def(8_757_00_0);
public static final TransportVersion TO_CHILD_BLOCK_JOIN_QUERY = def(8_758_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Original file line number Diff line number Diff line change
@@ -157,7 +157,8 @@ ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {
scoreDocs.toArray(Lucene.EMPTY_SCORE_DOCS),
source.knnSearch().get(i).getField(),
source.knnSearch().get(i).getQueryVector(),
source.knnSearch().get(i).getSimilarity()
source.knnSearch().get(i).getSimilarity(),
source.knnSearch().get(i).getFilterQueries()
).boost(source.knnSearch().get(i).boost()).queryName(source.knnSearch().get(i).queryName());
if (nestedPath != null) {
query = new NestedQueryBuilder(nestedPath, query, ScoreMode.Max).innerHit(source.knnSearch().get(i).innerHit());