Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for BM25Retriever in InMemoryDocumentStore #3561

Merged
merged 31 commits into from
Nov 22, 2022
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
cf16c25
Merge remote-tracking branch 'origin/main' into imds_support_for_bm25
anakin87 Nov 10, 2022
ee89a34
very first draft
anakin87 Nov 10, 2022
3d7e8aa
implement query and query_batch
anakin87 Nov 11, 2022
5890742
add more bm25 parameters
anakin87 Nov 12, 2022
d94433f
add rank_bm25 dependency
anakin87 Nov 12, 2022
ce5efae
fix mypy
anakin87 Nov 12, 2022
432eff7
remove tokenizer callable parameter
anakin87 Nov 12, 2022
91d40ff
remove unused import
anakin87 Nov 12, 2022
2e39f05
only json serializable attributes
anakin87 Nov 12, 2022
343f1d4
try to fix: pylint too-many-public-methods / R0904
anakin87 Nov 12, 2022
514c248
bm25 attribute always present
anakin87 Nov 12, 2022
03a35a2
convert errors into warnings to make the tutorial 1 work
anakin87 Nov 12, 2022
25d6d42
add docstrings; tests
anakin87 Nov 13, 2022
707a81b
try to make tests run
anakin87 Nov 13, 2022
ac67603
better docstrings; revert not running tests
anakin87 Nov 13, 2022
2381901
some suggestions from review
anakin87 Nov 14, 2022
2d830c9
Merge remote-tracking branch 'upstream/main' into imds_support_for_bm25
anakin87 Nov 14, 2022
34fedfc
rename elasticsearch retriever as bm25 in tests; try to test memory_bm25
anakin87 Nov 14, 2022
bbd9faa
exclude tests with filters
anakin87 Nov 15, 2022
b3b2668
change elasticsearch to bm25 retriever in test_summarizer
anakin87 Nov 15, 2022
be5969e
merge; bm25_algorithm as a property
anakin87 Nov 15, 2022
47df196
Merge branch 'imds_support_for_bm25' of https://github.com/anakin87/h…
anakin87 Nov 15, 2022
d408128
add tests
anakin87 Nov 15, 2022
de970e8
Merge branch 'main' into imds_support_for_bm25
anakin87 Nov 15, 2022
2e06683
try to improve tests
anakin87 Nov 17, 2022
1ee1544
Merge branch 'main' into imds_support_for_bm25
anakin87 Nov 17, 2022
ba89540
better type hint
anakin87 Nov 17, 2022
832ef82
Merge branch 'main' into imds_support_for_bm25
anakin87 Nov 18, 2022
f64016e
adapt test_table_text_retriever_embedding
anakin87 Nov 18, 2022
99429f6
handle non-textual docs
anakin87 Nov 21, 2022
aad1970
query only textual documents
anakin87 Nov 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions haystack/document_stores/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,22 @@ def __init__(

self.main_device = self.devices[0]

@property
def bm25_tokenization_regex(self):
    """
    Tokenizer currently used for BM25.

    Despite the name, reading this property does not give back the regex string
    that was assigned: it returns the ``findall`` method of the compiled pattern,
    i.e. a callable that maps a string to the list of its matches.
    """
    return self._tokenizer

@bm25_tokenization_regex.setter
def bm25_tokenization_regex(self, regex_string: str):
    """
    Compile the given regex and keep the pattern's ``findall`` as the tokenizer.

    :param regex_string: Regular expression whose matches are the tokens.
    """
    compiled_pattern = re.compile(regex_string)
    self._tokenizer = compiled_pattern.findall

@property
def bm25_algorithm(self):
    """
    The object resolved from ``rank_bm25`` for the configured algorithm name.

    Reading this property returns the resolved attribute itself (not the name
    that was assigned).
    """
    return self._bm25_class

@bm25_algorithm.setter
def bm25_algorithm(self, algorithm: str):
    """
    Resolve ``algorithm`` by name on the ``rank_bm25`` module and store the result.

    :param algorithm: Name of the attribute to look up on ``rank_bm25``.
    :raises AttributeError: If ``rank_bm25`` has no attribute with that name.
    """
    resolved = getattr(rank_bm25, algorithm)
    self._bm25_class = resolved

def write_documents(
self,
documents: Union[List[dict], List[Document]],
Expand Down Expand Up @@ -184,13 +200,11 @@ def update_bm25(self, index: Optional[str] = None):
:param index: Index name for which the BM25 representation is to be updated. If set to None, the default self.index is used.
"""
index = index or self.index
tokenizer = re.compile(self.bm25_tokenization_regex).findall

logger.info("Updating BM25 representation...")
all_documents = self.get_all_documents(index=index)
tokenized_corpus = [tokenizer(doc.content.lower()) for doc in all_documents]
bm25_class = getattr(rank_bm25, self.bm25_algorithm)
self.bm25[index] = bm25_class(tokenized_corpus, **self.bm25_parameters)
tokenized_corpus = [
self.bm25_tokenization_regex(doc.content.lower())
for doc in tqdm(self.get_all_documents(index=index), unit=" docs", desc="Updating BM25 representation...")
]
self.bm25[index] = self.bm25_algorithm(tokenized_corpus, **self.bm25_parameters)

def _create_document_field_map(self):
return {self.embedding_field: "embedding"}
Expand Down Expand Up @@ -922,10 +936,9 @@ def query(
)

if query is None:
query = ""
return []

tokenizer = re.compile(self.bm25_tokenization_regex).findall
tokenized_query = tokenizer(query.lower())
tokenized_query = self.bm25_tokenization_regex(query.lower())
docs_scores = self.bm25[index].get_scores(tokenized_query)
top_docs_positions = np.argsort(docs_scores)[::-1][:top_k]

Expand Down
7 changes: 5 additions & 2 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,7 @@ def indexing_document_classifier():
)


@pytest.fixture(params=["es_filter_only", "elasticsearch", "dpr", "embedding", "tfidf", "table_text_retriever"])
@pytest.fixture(params=["es_filter_only", "bm25", "dpr", "embedding", "tfidf", "table_text_retriever"])
def retriever(request, document_store):
return get_retriever(request.param, document_store)

Expand Down Expand Up @@ -771,7 +771,7 @@ def get_retriever(retriever_type, document_store):
use_gpu=False,
embed_title=True,
)
elif retriever_type == "elasticsearch":
elif retriever_type == "bm25":
retriever = BM25Retriever(document_store=document_store)
elif retriever_type == "es_filter_only":
retriever = FilterRetriever(document_store=document_store)
Expand Down Expand Up @@ -961,6 +961,9 @@ def get_document_store(
similarity=similarity,
)

elif document_store_type == "memory_bm25":
document_store = InMemoryDocumentStore(index=index, use_bm25=True)

anakin87 marked this conversation as resolved.
Show resolved Hide resolved
elif document_store_type == "elasticsearch":
# make sure we start from a fresh index
document_store = ElasticsearchDocumentStore(
Expand Down
29 changes: 28 additions & 1 deletion test/document_stores/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import pandas as pd
from rank_bm25 import BM25
import pytest
from unittest.mock import Mock

Expand Down Expand Up @@ -99,7 +100,6 @@ def test_write_with_duplicate_doc_ids_custom_index(document_store: BaseDocumentS


def test_get_all_documents_without_filters(document_store_with_docs):
print("hey!")
documents = document_store_with_docs.get_all_documents()
assert all(isinstance(d, Document) for d in documents)
assert len(documents) == 5
Expand Down Expand Up @@ -1447,3 +1447,30 @@ def test_normalize_embeddings_diff_shapes():
VEC_1 = np.array([0.1, 0.2, 0.3], dtype="float32").reshape(1, -1)
BaseDocumentStore.normalize_embedding(VEC_1)
assert np.linalg.norm(VEC_1) - 1 < 0.01


@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
def test_update_bm25(document_store_with_docs):
    """
    After ``update_bm25()``, the default index holds a ``BM25`` object whose
    corpus size equals the store's document count.
    """
    store = document_store_with_docs
    store.update_bm25()

    index_representation = store.bm25[store.index]
    assert isinstance(index_representation, BM25)
    assert index_representation.corpus_size == store.get_document_count()


@pytest.mark.parametrize("document_store_with_docs", ["memory_bm25"], indirect=True)
def test_memory_query(document_store_with_docs):
    """
    A single query with ``top_k=1`` returns exactly one document: the one mentioning Rome.
    """
    results = document_store_with_docs.query(query="Rome", top_k=1)

    assert len(results) == 1
    assert results[0].content == "My name is Matteo and I live in Rome"


@pytest.mark.parametrize("document_store_with_docs", ["memory_bm25"], indirect=True)
def test_memory_query_batch(document_store_with_docs):
    """
    A batch of two queries returns two result lists of five documents each,
    with the document naming the queried city ranked first.
    """
    results = document_store_with_docs.query_batch(queries=["Paris", "Madrid"], top_k=5)

    # One result list per query, in the order the queries were given.
    assert len(results) == 2

    paris_hits = results[0]
    assert len(paris_hits) == 5
    assert paris_hits[0].content == "My name is Christelle and I live in Paris"

    madrid_hits = results[1]
    assert len(madrid_hits) == 5
    assert madrid_hits[0].content == "My name is Camila and I live in Madrid"
11 changes: 8 additions & 3 deletions test/nodes/test_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@
("embedding", "faiss"),
("embedding", "memory"),
("embedding", "milvus"),
("elasticsearch", "elasticsearch"),
("bm25", "elasticsearch"),
("bm25", "memory_bm25"),
("es_filter_only", "elasticsearch"),
("tfidf", "memory"),
],
Expand All @@ -65,8 +66,12 @@ def test_retrieval(retriever_with_docs: BaseRetriever, document_store_with_docs:
assert res[0].meta["name"] == "filename1"

# test with filters
if not isinstance(document_store_with_docs, (FAISSDocumentStore, MilvusDocumentStore)) and not isinstance(
retriever_with_docs, TfidfRetriever
if (
not isinstance(document_store_with_docs, (FAISSDocumentStore, MilvusDocumentStore))
and not isinstance(retriever_with_docs, TfidfRetriever)
and not (
isinstance(document_store_with_docs, InMemoryDocumentStore) and document_store_with_docs.use_bm25 == True
)
anakin87 marked this conversation as resolved.
Show resolved Hide resolved
):
# single filter
result = retriever_with_docs.retrieve(query="Christelle", filters={"name": ["filename3"]}, top_k=5)
Expand Down
4 changes: 2 additions & 2 deletions test/nodes/test_summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_summarization_batch_multiple_doc_lists(summarizer):
@pytest.mark.integration
@pytest.mark.summarizer
@pytest.mark.parametrize(
"retriever,document_store", [("embedding", "memory"), ("elasticsearch", "elasticsearch")], indirect=True
"retriever,document_store", [("embedding", "memory"), ("bm25", "elasticsearch")], indirect=True
)
def test_summarization_pipeline(document_store, retriever, summarizer):
document_store.write_documents(DOCS)
Expand Down Expand Up @@ -118,7 +118,7 @@ def test_summarization_one_summary(summarizer):
@pytest.mark.integration
@pytest.mark.summarizer
@pytest.mark.parametrize(
"retriever,document_store", [("embedding", "memory"), ("elasticsearch", "elasticsearch")], indirect=True
"retriever,document_store", [("embedding", "memory"), ("bm25", "elasticsearch")], indirect=True
)
def test_summarization_pipeline_one_summary(document_store, retriever, summarizer):
document_store.write_documents(SPLIT_DOCS)
Expand Down
6 changes: 3 additions & 3 deletions test/pipelines/test_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def test_eval_reader(reader, document_store, use_confidence_scores):
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
@pytest.mark.parametrize("open_domain", [True, False])
@pytest.mark.parametrize("retriever", ["elasticsearch"], indirect=True)
@pytest.mark.parametrize("retriever", ["bm25"], indirect=True)
def test_eval_elastic_retriever(document_store, open_domain, retriever):
# add eval data (SQUAD format)
document_store.add_eval_data(
Expand All @@ -188,7 +188,7 @@ def test_eval_elastic_retriever(document_store, open_domain, retriever):
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
@pytest.mark.parametrize("retriever", ["elasticsearch"], indirect=True)
@pytest.mark.parametrize("retriever", ["bm25"], indirect=True)
def test_eval_pipeline(document_store, reader, retriever):
# add eval data (SQUAD format)
document_store.add_eval_data(
Expand Down Expand Up @@ -431,7 +431,7 @@ def test_extractive_qa_eval_multiple_queries(reader, retriever_with_docs, tmp_pa
assert metrics["Retriever"]["ndcg"] == 0.5


@pytest.mark.parametrize("retriever_with_docs", ["elasticsearch"], indirect=True)
@pytest.mark.parametrize("retriever_with_docs", ["bm25"], indirect=True)
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
def test_extractive_qa_labels_with_filters(reader, retriever_with_docs, tmp_path):
Expand Down
2 changes: 1 addition & 1 deletion test/pipelines/test_standard_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_document_search_pipeline_batch(retriever, document_store):


@pytest.mark.integration
@pytest.mark.parametrize("retriever_with_docs", ["elasticsearch", "dpr", "embedding"], indirect=True)
@pytest.mark.parametrize("retriever_with_docs", ["bm25", "dpr", "embedding"], indirect=True)
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_documentsearch_es_authentication(retriever_with_docs, document_store_with_docs: ElasticsearchDocumentStore):
if isinstance(retriever_with_docs, (DensePassageRetriever, EmbeddingRetriever)):
Expand Down