chore: add DenseRetriever abstraction #3252

Merged: 53 commits, Sep 21, 2022
Commits
9b74ab9
support cosine similiarity with faiss
tstadel Sep 14, 2022
b81cc00
update docs
tstadel Sep 14, 2022
6138fdf
update api docs
tstadel Sep 14, 2022
0b65cd0
fix tests
tstadel Sep 14, 2022
b107cfe
Revert "update api docs"
tstadel Sep 14, 2022
7178027
fix api docs
tstadel Sep 14, 2022
c9c050f
collapse test
tstadel Sep 14, 2022
4786d8b
rename similairity to space_type mappings
tstadel Sep 14, 2022
c441a1a
only normalize for faiss
tstadel Sep 15, 2022
fbd5983
Merge branch 'main' into os_faiss_cosine
tstadel Sep 15, 2022
77869f9
fix merge
tstadel Sep 15, 2022
25592f5
fix docs normalization
tstadel Sep 19, 2022
277d3ea
get rid of List[np.array]
tstadel Sep 19, 2022
2f08839
update docs
tstadel Sep 19, 2022
9cb849b
fix tests and tutorials
tstadel Sep 19, 2022
cf1f575
fix mypy
tstadel Sep 19, 2022
6714e1a
fix mypy
tstadel Sep 19, 2022
088a00d
fix mypy again
tstadel Sep 19, 2022
fae55bf
again mypy
tstadel Sep 19, 2022
52ebe5c
blacken
tstadel Sep 19, 2022
1b284ad
update tutorial 4 docs
tstadel Sep 19, 2022
6df02ab
fix embeddingretriever
tstadel Sep 19, 2022
408d81b
fix faiss
tstadel Sep 19, 2022
b16ab45
move dense specific logic to DenseRetriever
tstadel Sep 19, 2022
a17fc4c
fix mypy
tstadel Sep 20, 2022
51a8154
cosine tests for all documents stores
tstadel Sep 20, 2022
6cee6f2
fix pinecone
tstadel Sep 20, 2022
a56d424
add docstring
tstadel Sep 20, 2022
6ad31f4
docstring corrections
tstadel Sep 20, 2022
603c2ab
update docs
tstadel Sep 20, 2022
cefd688
add integration test marker
tstadel Sep 20, 2022
02571b9
docstrings update
tstadel Sep 20, 2022
7e68c6c
update docs
tstadel Sep 20, 2022
e39492c
fix typo
tstadel Sep 20, 2022
3d5995d
update docs
tstadel Sep 20, 2022
6e4c0a2
fix MockDenseRetriever
tstadel Sep 20, 2022
0bcadac
run integration tests for all documentstores
tstadel Sep 20, 2022
e93c9c0
fix test_update_embeddings_cosine_similarity
tstadel Sep 20, 2022
f12eece
fix faiss tests not running
tstadel Sep 20, 2022
2c74536
blacken
tstadel Sep 20, 2022
a5c3dc5
make test_cosine_sanity_check integration test
tstadel Sep 20, 2022
8df4ce1
Merge branch 'os_faiss_cosine' into dense_retriever
tstadel Sep 20, 2022
18b7497
update docs
tstadel Sep 20, 2022
95337c4
Merge branch 'main' of github.com:deepset-ai/haystack into dense_retr…
tstadel Sep 20, 2022
b23574a
fix imports
tstadel Sep 20, 2022
59de91c
import DenseRetriever normally
tstadel Sep 21, 2022
ce1796e
update docs
tstadel Sep 21, 2022
2faa7e2
fix deepcopy of documents
tstadel Sep 21, 2022
e4c688f
Merge branch 'main' into dense_retriever
tstadel Sep 21, 2022
83cf8f3
update schema
tstadel Sep 21, 2022
08770ee
Merge branch 'dense_retriever' of github.com:deepset-ai/haystack into…
tstadel Sep 21, 2022
dea5c05
Revert "update schema"
tstadel Sep 21, 2022
44fdbf6
fix schema for ci manually
tstadel Sep 21, 2022
14 changes: 7 additions & 7 deletions docs/_src/api/api/document_store.md
Expand Up @@ -1269,7 +1269,7 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
#### BaseElasticsearchDocumentStore.update\_embeddings

```python
def update_embeddings(retriever,
def update_embeddings(retriever: DenseRetriever,
index: Optional[str] = None,
filters: Optional[Dict[str, Union[Dict, List, str, int,
float, bool]]] = None,
Expand Down Expand Up @@ -2097,7 +2097,7 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
#### InMemoryDocumentStore.update\_embeddings

```python
def update_embeddings(retriever: "BaseRetriever",
def update_embeddings(retriever: DenseRetriever,
index: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None,
update_existing_embeddings: bool = True,
Expand Down Expand Up @@ -2913,7 +2913,7 @@ None
#### FAISSDocumentStore.update\_embeddings

```python
def update_embeddings(retriever: "BaseRetriever",
def update_embeddings(retriever: DenseRetriever,
index: Optional[str] = None,
update_existing_embeddings: bool = True,
filters: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -3277,7 +3277,7 @@ None
#### Milvus1DocumentStore.update\_embeddings

```python
def update_embeddings(retriever: "BaseRetriever",
def update_embeddings(retriever: DenseRetriever,
index: Optional[str] = None,
batch_size: int = 10_000,
update_existing_embeddings: bool = True,
Expand Down Expand Up @@ -3681,7 +3681,7 @@ exists.
#### Milvus2DocumentStore.update\_embeddings

```python
def update_embeddings(retriever: "BaseRetriever",
def update_embeddings(retriever: DenseRetriever,
index: Optional[str] = None,
batch_size: int = 10_000,
update_existing_embeddings: bool = True,
Expand Down Expand Up @@ -4398,7 +4398,7 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
#### WeaviateDocumentStore.update\_embeddings

```python
def update_embeddings(retriever,
def update_embeddings(retriever: DenseRetriever,
index: Optional[str] = None,
filters: Optional[Dict[str, Union[Dict, List, str, int,
float, bool]]] = None,
Expand Down Expand Up @@ -5452,7 +5452,7 @@ Parameter options:
#### PineconeDocumentStore.update\_embeddings

```python
def update_embeddings(retriever: "BaseRetriever",
def update_embeddings(retriever: DenseRetriever,
index: Optional[str] = None,
update_existing_embeddings: bool = True,
filters: Optional[Dict[str, Union[Dict, List, str, int,
Expand Down
97 changes: 72 additions & 25 deletions docs/_src/api/api/retriever.md
Expand Up @@ -547,12 +547,60 @@ Performing training on this class according to the TF-IDF algorithm.

# Module dense

<a id="dense.DenseRetriever"></a>

## DenseRetriever

```python
class DenseRetriever(BaseRetriever)
```

Base class for all dense retrievers.

<a id="dense.DenseRetriever.embed_queries"></a>

#### DenseRetriever.embed\_queries

```python
@abstractmethod
def embed_queries(queries: List[str]) -> np.ndarray
```

Create embeddings for a list of queries.

**Arguments**:

- `queries`: List of queries to embed.

**Returns**:

Embeddings, one per input query, shape: (queries, embedding_dim)

<a id="dense.DenseRetriever.embed_documents"></a>

#### DenseRetriever.embed\_documents

```python
@abstractmethod
def embed_documents(documents: List[Document]) -> np.ndarray
```

Create embeddings for a list of documents.

**Arguments**:

- `documents`: List of documents to embed.

**Returns**:

Embeddings of documents, one per input document, shape: (documents, embedding_dim)
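
The contract described above (two abstract methods, each returning a single 2-D array with one row per input) can be sketched as follows. This is a minimal illustration, not the Haystack implementation: `Document` and `DenseRetriever` are simplified stand-ins defined locally (the real `DenseRetriever` derives from `BaseRetriever`), and `ToyRetriever` with its character-count "embeddings" is entirely hypothetical.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List

import numpy as np


@dataclass
class Document:
    """Hypothetical stand-in for haystack.schema.Document (only `content` is modelled)."""

    content: str


class DenseRetriever(ABC):
    """Simplified stand-in; the real class derives from BaseRetriever."""

    @abstractmethod
    def embed_queries(self, queries: List[str]) -> np.ndarray:
        """Return an array of shape (len(queries), embedding_dim)."""

    @abstractmethod
    def embed_documents(self, documents: List[Document]) -> np.ndarray:
        """Return an array of shape (len(documents), embedding_dim)."""


class ToyRetriever(DenseRetriever):
    """Hypothetical retriever: embeds text as three simple character statistics."""

    embedding_dim = 3

    def _embed(self, texts: List[str]) -> np.ndarray:
        rows = [[len(t), t.count(" "), sum(ch.isdigit() for ch in t)] for t in texts]
        # reshape keeps the result 2-D even when the input list is empty
        return np.array(rows, dtype=np.float32).reshape(len(texts), self.embedding_dim)

    def embed_queries(self, queries: List[str]) -> np.ndarray:
        return self._embed(queries)

    def embed_documents(self, documents: List[Document]) -> np.ndarray:
        return self._embed([doc.content for doc in documents])


retriever = ToyRetriever()
query_embeddings = retriever.embed_queries(["what is faiss?", "cosine similarity in 2022"])
document_embeddings = retriever.embed_documents(
    [Document("FAISS is a library"), Document("Milvus"), Document("Pinecone index 1")]
)
empty_embeddings = retriever.embed_queries([])
```

Returning one `np.ndarray` rather than a `List[np.ndarray]` means callers can read both the number of embeddings and the embedding dimension from `.shape`, including for an empty batch.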

<a id="dense.DensePassageRetriever"></a>

## DensePassageRetriever

```python
class DensePassageRetriever(BaseRetriever)
class DensePassageRetriever(DenseRetriever)
```

Retriever that uses a bi-encoder (one transformer for query, one transformer for passage).
Expand Down Expand Up @@ -842,36 +890,36 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
#### DensePassageRetriever.embed\_queries

```python
def embed_queries(texts: List[str]) -> List[np.ndarray]
def embed_queries(queries: List[str]) -> np.ndarray
```

Create embeddings for a list of queries using the query encoder
Create embeddings for a list of queries using the query encoder.

**Arguments**:

- `texts`: Queries to embed
- `queries`: List of queries to embed.

**Returns**:

Embeddings, one per input queries
Embeddings, one per input query, shape: (queries, embedding_dim)

<a id="dense.DensePassageRetriever.embed_documents"></a>

#### DensePassageRetriever.embed\_documents

```python
def embed_documents(docs: List[Document]) -> List[np.ndarray]
def embed_documents(documents: List[Document]) -> np.ndarray
```

Create embeddings for a list of documents using the passage encoder
Create embeddings for a list of documents using the passage encoder.

**Arguments**:

- `docs`: List of Document objects used to represent documents / passages in a standardized way within Haystack.
- `documents`: List of documents to embed.

**Returns**:

Embeddings of documents / passages shape (batch_size, embedding_dim)
Embeddings of documents, one per input document, shape: (documents, embedding_dim)

<a id="dense.DensePassageRetriever.train"></a>

Expand Down Expand Up @@ -1005,7 +1053,7 @@ Load DensePassageRetriever from the specified directory.
## TableTextRetriever

```python
class TableTextRetriever(BaseRetriever)
class TableTextRetriever(DenseRetriever)
```

Retriever that uses a tri-encoder to jointly retrieve among a database consisting of text passages and tables
Expand Down Expand Up @@ -1198,25 +1246,25 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
#### TableTextRetriever.embed\_queries

```python
def embed_queries(texts: List[str]) -> List[np.ndarray]
def embed_queries(queries: List[str]) -> np.ndarray
```

Create embeddings for a list of queries using the query encoder
Create embeddings for a list of queries using the query encoder.

**Arguments**:

- `texts`: Queries to embed
- `queries`: List of queries to embed.

**Returns**:

Embeddings, one per input queries
Embeddings, one per input query, shape: (queries, embedding_dim)

<a id="dense.TableTextRetriever.embed_documents"></a>

#### TableTextRetriever.embed\_documents

```python
def embed_documents(docs: List[Document]) -> List[np.ndarray]
def embed_documents(documents: List[Document]) -> np.ndarray
```

Create embeddings for a list of text documents and / or tables using the text passage encoder and
Expand All @@ -1225,12 +1273,11 @@ the table encoder.

**Arguments**:

- `docs`: List of Document objects used to represent documents / passages in
a standardized way within Haystack.
- `documents`: List of documents to embed.

**Returns**:

Embeddings of documents / passages. Shape: (batch_size, embedding_dim)
Embeddings of documents, one per input document, shape: (documents, embedding_dim)

<a id="dense.TableTextRetriever.train"></a>

Expand Down Expand Up @@ -1370,7 +1417,7 @@ Load TableTextRetriever from the specified directory.
## EmbeddingRetriever

```python
class EmbeddingRetriever(BaseRetriever)
class EmbeddingRetriever(DenseRetriever)
```

<a id="dense.EmbeddingRetriever.__init__"></a>
Expand Down Expand Up @@ -1638,36 +1685,36 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
#### EmbeddingRetriever.embed\_queries

```python
def embed_queries(texts: List[str]) -> List[np.ndarray]
def embed_queries(queries: List[str]) -> np.ndarray
```

Create embeddings for a list of queries.

**Arguments**:

- `texts`: Queries to embed
- `queries`: List of queries to embed.

**Returns**:

Embeddings, one per input queries
Embeddings, one per input query, shape: (queries, embedding_dim)

<a id="dense.EmbeddingRetriever.embed_documents"></a>

#### EmbeddingRetriever.embed\_documents

```python
def embed_documents(docs: List[Document]) -> List[np.ndarray]
def embed_documents(documents: List[Document]) -> np.ndarray
```

Create embeddings for a list of documents.

**Arguments**:

- `docs`: List of documents to embed
- `documents`: List of documents to embed.

**Returns**:

Embeddings, one per input document
Embeddings, one per input document, shape: (documents, embedding_dim)

<a id="dense.EmbeddingRetriever.train"></a>

Expand Down
23 changes: 22 additions & 1 deletion haystack/document_stores/base.py
Expand Up @@ -12,7 +12,7 @@

from haystack.schema import Document, Label, MultiLabel
from haystack.nodes.base import BaseComponent
from haystack.errors import DuplicateDocumentError
from haystack.errors import DuplicateDocumentError, DocumentStoreError
from haystack.nodes.preprocessor import PreProcessor
from haystack.document_stores.utils import eval_data_from_json, eval_data_from_jsonl, squad_json_to_jsonl

Expand Down Expand Up @@ -698,6 +698,27 @@ def _get_duplicate_labels(

return [label for label in labels if label.id in duplicate_ids]

@classmethod
def _validate_embeddings_shape(cls, embeddings: np.ndarray, num_documents: int, embedding_dim: int):
"""
Validates the shape of model-generated embeddings against expected values for indexing.

:param embeddings: Embeddings to validate
:param num_documents: Number of documents the embeddings were generated for
:param embedding_dim: Number of embedding dimensions to expect
"""
num_embeddings, embedding_size = embeddings.shape
if num_embeddings != num_documents:
raise DocumentStoreError(
"The number of embeddings does not match the number of documents: "
f"({num_embeddings} != {num_documents})"
)
if embedding_size != embedding_dim:
raise RuntimeError(
f"Embedding dimensions of the model ({embedding_size}) don't match the embedding dimensions of the document store ({embedding_dim}). "
f"Initiate {cls.__name__} again with arg embedding_dim={embedding_size}."
)


class KeywordDocumentStore(BaseDocumentStore):
"""
Expand Down
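
The `_validate_embeddings_shape` helper added in this hunk performs two checks on the 2-D embeddings array: the row count against the number of documents, and the column count against the store's configured dimension. The following is a standalone sketch of that logic, assuming NumPy is available. `DocumentStoreError` is redefined locally as a stand-in for `haystack.errors.DocumentStoreError`, and `describe` is a hypothetical helper that exists only to show which error each case raises.

```python
import numpy as np


class DocumentStoreError(Exception):
    """Local stand-in for haystack.errors.DocumentStoreError."""


def validate_embeddings_shape(embeddings: np.ndarray, num_documents: int, embedding_dim: int) -> None:
    # Unpacking requires a 2-D array: (number of embeddings, embedding size)
    num_embeddings, embedding_size = embeddings.shape
    if num_embeddings != num_documents:
        raise DocumentStoreError(
            "The number of embeddings does not match the number of documents: "
            f"({num_embeddings} != {num_documents})"
        )
    if embedding_size != embedding_dim:
        raise RuntimeError(
            f"Embedding dimensions of the model ({embedding_size}) don't match "
            f"the embedding dimensions of the document store ({embedding_dim})."
        )


def describe(embeddings: np.ndarray, num_documents: int, embedding_dim: int) -> str:
    """Hypothetical helper: report which check failed, if any."""
    try:
        validate_embeddings_shape(embeddings, num_documents, embedding_dim)
    except DocumentStoreError:
        return "count mismatch"
    except RuntimeError:
        return "dimension mismatch"
    return "ok"


matching = describe(np.zeros((4, 768)), num_documents=4, embedding_dim=768)
wrong_count = describe(np.zeros((3, 768)), num_documents=4, embedding_dim=768)
wrong_dim = describe(np.zeros((4, 384)), num_documents=4, embedding_dim=768)
```

Because the count check runs first, a batch that is wrong in both respects reports the count mismatch.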
17 changes: 6 additions & 11 deletions haystack/document_stores/elasticsearch.py
Expand Up @@ -27,6 +27,7 @@
from haystack.document_stores.base import get_batches_from_generator
from haystack.document_stores.filter_utils import LogicalFilterClause
from haystack.errors import DocumentStoreError, HaystackError
from haystack.nodes.retriever import DenseRetriever

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1400,7 +1401,7 @@ def _get_raw_similarity_score(self, score):

def update_embeddings(
self,
retriever,
retriever: DenseRetriever,
index: Optional[str] = None,
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
update_existing_embeddings: bool = True,
Expand Down Expand Up @@ -1482,16 +1483,10 @@ def update_embeddings(
with tqdm(total=document_count, position=0, unit=" Docs", desc="Updating embeddings") as progress_bar:
for result_batch in get_batches_from_generator(result, batch_size):
document_batch = [self._convert_es_hit_to_document(hit, return_embedding=False) for hit in result_batch]
embeddings = retriever.embed_documents(document_batch) # type: ignore
if len(document_batch) != len(embeddings):
raise DocumentStoreError(
"The number of embeddings does not match the number of documents in the batch "
f"({len(embeddings)} != {len(document_batch)})"
)
if embeddings[0].shape[0] != self.embedding_dim:
raise RuntimeError(
f"Embedding dimensions of the model ({embeddings[0].shape[0]}) doesn't match the embedding dimensions of the document store ({self.embedding_dim}). Please reinitiate ElasticsearchDocumentStore() with arg embedding_dim={embeddings[0].shape[0]}."
)
embeddings = retriever.embed_documents(document_batch)
self._validate_embeddings_shape(
embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim
)

doc_updates = []
for doc, emb in zip(document_batch, embeddings):
Expand Down
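
The `zip(document_batch, embeddings)` loop above still works after the return type changed from `List[np.ndarray]` to a single `np.ndarray`, because iterating over a 2-D array yields one row per document. A small sketch of that behaviour follows; the document IDs and the shape of the update dictionaries are hypothetical and are not taken from the Elasticsearch store's actual payload.

```python
import numpy as np

# Hypothetical batch: three document IDs and a (3, 2) embeddings array
document_ids = ["doc-1", "doc-2", "doc-3"]
embeddings = np.arange(6, dtype=np.float32).reshape(3, 2)

doc_updates = []
for doc_id, emb in zip(document_ids, embeddings):
    # `emb` is a 1-D row of length embedding_dim
    doc_updates.append({"id": doc_id, "embedding": emb.tolist()})
```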