Skip to content

Commit

Permalink
chore: add DenseRetriever abstraction (#3252)
Browse files Browse the repository at this point in the history
* support cosine similiarity with faiss

* update docs

* update api docs

* fix tests

* Revert "update api docs"

This reverts commit 6138fdf.

* fix api docs

* collapse test

* rename similairity to space_type mappings

* only normalize for faiss

* fix merge

* fix docs normalization

* get rid of List[np.array]

* update docs

* fix tests and tutorials

* fix mypy

* fix mypy

* fix mypy again

* again mypy

* blacken

* update tutorial  4 docs

* fix embeddingretriever

* fix faiss

* move dense specific logic to DenseRetriever

* fix mypy

* cosine tests for all documents stores

* fix pinecone

* add docstring

* docstring corrections

* update docs

* add integration test marker

* docstrings update

* update docs

* fix typo

* update docs

* fix MockDenseRetriever

* run integration tests for all documentstores

* fix test_update_embeddings_cosine_similarity

* fix faiss tests not running

* blacken

* make test_cosine_sanity_check integration test

* update docs

* fix imports

* import  DenseRetriever normally

* update docs

* fix deepcopy of documents

* update schema

* Revert "update schema"

This reverts commit 83cf8f3.

* fix schema for ci manually
  • Loading branch information
tstadel authored Sep 21, 2022
1 parent 492a804 commit b10e2c3
Show file tree
Hide file tree
Showing 19 changed files with 333 additions and 259 deletions.
14 changes: 7 additions & 7 deletions docs/_src/api/api/document_store.md
Original file line number Diff line number Diff line change
Expand Up @@ -1269,7 +1269,7 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
#### BaseElasticsearchDocumentStore.update\_embeddings

```python
def update_embeddings(retriever,
def update_embeddings(retriever: DenseRetriever,
index: Optional[str] = None,
filters: Optional[Dict[str, Union[Dict, List, str, int,
float, bool]]] = None,
Expand Down Expand Up @@ -2097,7 +2097,7 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
#### InMemoryDocumentStore.update\_embeddings

```python
def update_embeddings(retriever: "BaseRetriever",
def update_embeddings(retriever: DenseRetriever,
index: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None,
update_existing_embeddings: bool = True,
Expand Down Expand Up @@ -2913,7 +2913,7 @@ None
#### FAISSDocumentStore.update\_embeddings

```python
def update_embeddings(retriever: "BaseRetriever",
def update_embeddings(retriever: DenseRetriever,
index: Optional[str] = None,
update_existing_embeddings: bool = True,
filters: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -3277,7 +3277,7 @@ None
#### Milvus1DocumentStore.update\_embeddings

```python
def update_embeddings(retriever: "BaseRetriever",
def update_embeddings(retriever: DenseRetriever,
index: Optional[str] = None,
batch_size: int = 10_000,
update_existing_embeddings: bool = True,
Expand Down Expand Up @@ -3681,7 +3681,7 @@ exists.
#### Milvus2DocumentStore.update\_embeddings

```python
def update_embeddings(retriever: "BaseRetriever",
def update_embeddings(retriever: DenseRetriever,
index: Optional[str] = None,
batch_size: int = 10_000,
update_existing_embeddings: bool = True,
Expand Down Expand Up @@ -4398,7 +4398,7 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
#### WeaviateDocumentStore.update\_embeddings

```python
def update_embeddings(retriever,
def update_embeddings(retriever: DenseRetriever,
index: Optional[str] = None,
filters: Optional[Dict[str, Union[Dict, List, str, int,
float, bool]]] = None,
Expand Down Expand Up @@ -5452,7 +5452,7 @@ Parameter options:
#### PineconeDocumentStore.update\_embeddings

```python
def update_embeddings(retriever: "BaseRetriever",
def update_embeddings(retriever: DenseRetriever,
index: Optional[str] = None,
update_existing_embeddings: bool = True,
filters: Optional[Dict[str, Union[Dict, List, str, int,
Expand Down
97 changes: 72 additions & 25 deletions docs/_src/api/api/retriever.md
Original file line number Diff line number Diff line change
Expand Up @@ -547,12 +547,60 @@ Performing training on this class according to the TF-IDF algorithm.

# Module dense

<a id="dense.DenseRetriever"></a>

## DenseRetriever

```python
class DenseRetriever(BaseRetriever)
```

Base class for all dense retrievers.

<a id="dense.DenseRetriever.embed_queries"></a>

#### DenseRetriever.embed\_queries

```python
@abstractmethod
def embed_queries(queries: List[str]) -> np.ndarray
```

Create embeddings for a list of queries.

**Arguments**:

- `queries`: List of queries to embed.

**Returns**:

Embeddings, one per input query, shape: (queries, embedding_dim)

<a id="dense.DenseRetriever.embed_documents"></a>

#### DenseRetriever.embed\_documents

```python
@abstractmethod
def embed_documents(documents: List[Document]) -> np.ndarray
```

Create embeddings for a list of documents.

**Arguments**:

- `documents`: List of documents to embed.

**Returns**:

Embeddings of documents, one per input document, shape: (documents, embedding_dim)

<a id="dense.DensePassageRetriever"></a>

## DensePassageRetriever

```python
class DensePassageRetriever(BaseRetriever)
class DensePassageRetriever(DenseRetriever)
```

Retriever that uses a bi-encoder (one transformer for query, one transformer for passage).
Expand Down Expand Up @@ -842,36 +890,36 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
#### DensePassageRetriever.embed\_queries

```python
def embed_queries(texts: List[str]) -> List[np.ndarray]
def embed_queries(queries: List[str]) -> np.ndarray
```

Create embeddings for a list of queries using the query encoder
Create embeddings for a list of queries using the query encoder.

**Arguments**:

- `texts`: Queries to embed
- `queries`: List of queries to embed.

**Returns**:

Embeddings, one per input queries
Embeddings, one per input query, shape: (queries, embedding_dim)

<a id="dense.DensePassageRetriever.embed_documents"></a>

#### DensePassageRetriever.embed\_documents

```python
def embed_documents(docs: List[Document]) -> List[np.ndarray]
def embed_documents(documents: List[Document]) -> np.ndarray
```

Create embeddings for a list of documents using the passage encoder
Create embeddings for a list of documents using the passage encoder.

**Arguments**:

- `docs`: List of Document objects used to represent documents / passages in a standardized way within Haystack.
- `documents`: List of documents to embed.

**Returns**:

Embeddings of documents / passages shape (batch_size, embedding_dim)
Embeddings of documents, one per input document, shape: (documents, embedding_dim)

<a id="dense.DensePassageRetriever.train"></a>

Expand Down Expand Up @@ -1005,7 +1053,7 @@ Load DensePassageRetriever from the specified directory.
## TableTextRetriever

```python
class TableTextRetriever(BaseRetriever)
class TableTextRetriever(DenseRetriever)
```

Retriever that uses a tri-encoder to jointly retrieve among a database consisting of text passages and tables
Expand Down Expand Up @@ -1198,25 +1246,25 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
#### TableTextRetriever.embed\_queries

```python
def embed_queries(texts: List[str]) -> List[np.ndarray]
def embed_queries(queries: List[str]) -> np.ndarray
```

Create embeddings for a list of queries using the query encoder
Create embeddings for a list of queries using the query encoder.

**Arguments**:

- `texts`: Queries to embed
- `queries`: List of queries to embed.

**Returns**:

Embeddings, one per input queries
Embeddings, one per input query, shape: (queries, embedding_dim)

<a id="dense.TableTextRetriever.embed_documents"></a>

#### TableTextRetriever.embed\_documents

```python
def embed_documents(docs: List[Document]) -> List[np.ndarray]
def embed_documents(documents: List[Document]) -> np.ndarray
```

Create embeddings for a list of text documents and / or tables using the text passage encoder and
Expand All @@ -1225,12 +1273,11 @@ the table encoder.

**Arguments**:

- `docs`: List of Document objects used to represent documents / passages in
a standardized way within Haystack.
- `documents`: List of documents to embed.

**Returns**:

Embeddings of documents / passages. Shape: (batch_size, embedding_dim)
Embeddings of documents, one per input document, shape: (documents, embedding_dim)

<a id="dense.TableTextRetriever.train"></a>

Expand Down Expand Up @@ -1370,7 +1417,7 @@ Load TableTextRetriever from the specified directory.
## EmbeddingRetriever

```python
class EmbeddingRetriever(BaseRetriever)
class EmbeddingRetriever(DenseRetriever)
```

<a id="dense.EmbeddingRetriever.__init__"></a>
Expand Down Expand Up @@ -1638,36 +1685,36 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
#### EmbeddingRetriever.embed\_queries

```python
def embed_queries(texts: List[str]) -> List[np.ndarray]
def embed_queries(queries: List[str]) -> np.ndarray
```

Create embeddings for a list of queries.

**Arguments**:

- `texts`: Queries to embed
- `queries`: List of queries to embed.

**Returns**:

Embeddings, one per input queries
Embeddings, one per input query, shape: (queries, embedding_dim)

<a id="dense.EmbeddingRetriever.embed_documents"></a>

#### EmbeddingRetriever.embed\_documents

```python
def embed_documents(docs: List[Document]) -> List[np.ndarray]
def embed_documents(documents: List[Document]) -> np.ndarray
```

Create embeddings for a list of documents.

**Arguments**:

- `docs`: List of documents to embed
- `documents`: List of documents to embed.

**Returns**:

Embeddings, one per input document
Embeddings, one per input document, shape: (docs, embedding_dim)

<a id="dense.EmbeddingRetriever.train"></a>

Expand Down
23 changes: 22 additions & 1 deletion haystack/document_stores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from haystack.schema import Document, Label, MultiLabel
from haystack.nodes.base import BaseComponent
from haystack.errors import DuplicateDocumentError
from haystack.errors import DuplicateDocumentError, DocumentStoreError
from haystack.nodes.preprocessor import PreProcessor
from haystack.document_stores.utils import eval_data_from_json, eval_data_from_jsonl, squad_json_to_jsonl

Expand Down Expand Up @@ -698,6 +698,27 @@ def _get_duplicate_labels(

return [label for label in labels if label.id in duplicate_ids]

@classmethod
def _validate_embeddings_shape(cls, embeddings: np.ndarray, num_documents: int, embedding_dim: int):
"""
Validates the shape of model-generated embeddings against expected values for indexing.
:param embeddings: Embeddings to validate
:param num_documents: Number of documents the embeddings were generated for
:param embedding_dim: Number of embedding dimensions to expect
"""
num_embeddings, embedding_size = embeddings.shape
if num_embeddings != num_documents:
raise DocumentStoreError(
"The number of embeddings does not match the number of documents: "
f"({num_embeddings} != {num_documents})"
)
if embedding_size != embedding_dim:
raise RuntimeError(
f"Embedding dimensions of the model ({embedding_size}) don't match the embedding dimensions of the document store ({embedding_dim}). "
f"Initiate {cls.__name__} again with arg embedding_dim={embedding_size}."
)


class KeywordDocumentStore(BaseDocumentStore):
"""
Expand Down
17 changes: 6 additions & 11 deletions haystack/document_stores/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from haystack.document_stores.base import get_batches_from_generator
from haystack.document_stores.filter_utils import LogicalFilterClause
from haystack.errors import DocumentStoreError, HaystackError
from haystack.nodes.retriever import DenseRetriever

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1400,7 +1401,7 @@ def _get_raw_similarity_score(self, score):

def update_embeddings(
self,
retriever,
retriever: DenseRetriever,
index: Optional[str] = None,
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
update_existing_embeddings: bool = True,
Expand Down Expand Up @@ -1482,16 +1483,10 @@ def update_embeddings(
with tqdm(total=document_count, position=0, unit=" Docs", desc="Updating embeddings") as progress_bar:
for result_batch in get_batches_from_generator(result, batch_size):
document_batch = [self._convert_es_hit_to_document(hit, return_embedding=False) for hit in result_batch]
embeddings = retriever.embed_documents(document_batch) # type: ignore
if len(document_batch) != len(embeddings):
raise DocumentStoreError(
"The number of embeddings does not match the number of documents in the batch "
f"({len(embeddings)} != {len(document_batch)})"
)
if embeddings[0].shape[0] != self.embedding_dim:
raise RuntimeError(
f"Embedding dimensions of the model ({embeddings[0].shape[0]}) doesn't match the embedding dimensions of the document store ({self.embedding_dim}). Please reinitiate ElasticsearchDocumentStore() with arg embedding_dim={embeddings[0].shape[0]}."
)
embeddings = retriever.embed_documents(document_batch)
self._validate_embeddings_shape(
embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim
)

doc_updates = []
for doc, emb in zip(document_batch, embeddings):
Expand Down
Loading

0 comments on commit b10e2c3

Please sign in to comment.