Skip to content

Commit

Permalink
Forbid usage of *args and **kwargs in any node's __init__ (#2362)
Browse files Browse the repository at this point in the history
* Add failing test

* Remove `**kwargs` from docstores' `__init__` functions (#2407)

* Remove kwargs from ESDocStore subclasses

* Remove kwargs from subclasses of SQLDocumentStore

* Remove kwargs from Weaviate

* Revert change in pinecone

* Fix tests

* Fix retriever test with weaviate

* Change Exception into DocumentStoreError

* Update Documentation & Code Style

* Remove `**kwargs` from `FARMReader` (#2413)

* Remove FARMReader kwargs without trying to replace them functionally

* Update Documentation & Code Style

* enforce same index values before and after saving/loading eval dataframes (#2398)

* Add tests for missing `__init__` and `super().__init__()` in custom nodes (#2350)

* Add tests for missing init and super

* Update Documentation & Code Style

* Change `in` with `endswith`

* Move test in pipeline.py and change test in pipeline_yaml.py

* Update Documentation & Code Style

* Use caplog to test the warning

* Update Documentation & Code Style

* move tests into test_pipeline and use get_config

* Update Documentation & Code Style

* Unmock version name

* Improve variadic args test

* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
ZanSara and github-actions[bot] authored Apr 14, 2022
1 parent 46a50fb commit 929c685
Show file tree
Hide file tree
Showing 17 changed files with 4,848 additions and 74 deletions.
13 changes: 8 additions & 5 deletions docs/_src/api/api/document_store.md
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ class ElasticsearchDocumentStore(KeywordDocumentStore)
#### \_\_init\_\_

```python
def __init__(host: Union[str, List[str]] = "localhost", port: Union[int, List[int]] = 9200, username: str = "", password: str = "", api_key_id: Optional[str] = None, api_key: Optional[str] = None, aws4auth=None, index: str = "document", label_index: str = "label", search_fields: Union[str, list] = "content", content_field: str = "content", name_field: str = "name", embedding_field: str = "embedding", embedding_dim: int = 768, custom_mapping: Optional[dict] = None, excluded_meta_data: Optional[list] = None, analyzer: str = "standard", scheme: str = "http", ca_certs: Optional[str] = None, verify_certs: bool = True, recreate_index: bool = False, create_index: bool = True, refresh_type: str = "wait_for", similarity="dot_product", timeout=30, return_embedding: bool = False, duplicate_documents: str = "overwrite", index_type: str = "flat", scroll: str = "1d", skip_missing_embeddings: bool = True, synonyms: Optional[List] = None, synonym_type: str = "synonym", use_system_proxy: bool = False)
def __init__(host: Union[str, List[str]] = "localhost", port: Union[int, List[int]] = 9200, username: str = "", password: str = "", api_key_id: Optional[str] = None, api_key: Optional[str] = None, aws4auth=None, index: str = "document", label_index: str = "label", search_fields: Union[str, list] = "content", content_field: str = "content", name_field: str = "name", embedding_field: str = "embedding", embedding_dim: int = 768, custom_mapping: Optional[dict] = None, excluded_meta_data: Optional[list] = None, analyzer: str = "standard", scheme: str = "http", ca_certs: Optional[str] = None, verify_certs: bool = True, recreate_index: bool = False, create_index: bool = True, refresh_type: str = "wait_for", similarity: str = "dot_product", timeout: int = 30, return_embedding: bool = False, duplicate_documents: str = "overwrite", index_type: str = "flat", scroll: str = "1d", skip_missing_embeddings: bool = True, synonyms: Optional[List] = None, synonym_type: str = "synonym", use_system_proxy: bool = False)
```

A DocumentStore using Elasticsearch to store and query the documents for our search.
Expand Down Expand Up @@ -1231,7 +1231,7 @@ class OpenSearchDocumentStore(ElasticsearchDocumentStore)
#### \_\_init\_\_

```python
def __init__(verify_certs=False, scheme="https", username="admin", password="admin", port=9200, **kwargs)
def __init__(scheme: str = "https", username: str = "admin", password: str = "admin", host: Union[str, List[str]] = "localhost", port: Union[int, List[int]] = 9200, api_key_id: Optional[str] = None, api_key: Optional[str] = None, aws4auth=None, index: str = "document", label_index: str = "label", search_fields: Union[str, list] = "content", content_field: str = "content", name_field: str = "name", embedding_field: str = "embedding", embedding_dim: int = 768, custom_mapping: Optional[dict] = None, excluded_meta_data: Optional[list] = None, analyzer: str = "standard", ca_certs: Optional[str] = None, verify_certs: bool = False, recreate_index: bool = False, create_index: bool = True, refresh_type: str = "wait_for", similarity: str = "dot_product", timeout: int = 30, return_embedding: bool = False, duplicate_documents: str = "overwrite", index_type: str = "flat", scroll: str = "1d", skip_missing_embeddings: bool = True, synonyms: Optional[List] = None, synonym_type: str = "synonym", use_system_proxy: bool = False)
```

Document Store using OpenSearch (https://opensearch.org/). It is compatible with the AWS Elasticsearch Service.
Expand Down Expand Up @@ -2235,7 +2235,7 @@ the vector embeddings are indexed in a FAISS Index.
#### \_\_init\_\_

```python
def __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = None, embedding_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional[faiss.swigfaiss.Index] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = "overwrite", faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, isolation_level: str = None, **kwargs, ,)
def __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = None, embedding_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional[faiss.swigfaiss.Index] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = "overwrite", faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, isolation_level: str = None, n_links: int = 64, ef_search: int = 20, ef_construction: int = 80)
```

**Arguments**:
Expand Down Expand Up @@ -2282,6 +2282,9 @@ If specified no other params besides faiss_config_path must be specified.
- `faiss_config_path`: Stored FAISS initial configuration parameters.
Can be created via calling `save()`
- `isolation_level`: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
- `n_links`: used only if `faiss_index_factory_str` == "HNSW"
- `ef_search`: used only if `faiss_index_factory_str` == "HNSW"
- `ef_construction`: used only if `faiss_index_factory_str` == "HNSW"

<a id="faiss.FAISSDocumentStore.write_documents"></a>

Expand Down Expand Up @@ -2545,7 +2548,7 @@ Usage:
#### \_\_init\_\_

```python
def __init__(sql_url: str = "sqlite:///", milvus_url: str = "tcp://localhost:19530", connection_pool: str = "SingletonThread", index: str = "document", vector_dim: int = None, embedding_dim: int = 768, index_file_size: int = 1024, similarity: str = "dot_product", index_type: IndexType = IndexType.FLAT, index_param: Optional[Dict[str, Any]] = None, search_param: Optional[Dict[str, Any]] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = "overwrite", isolation_level: str = None, **kwargs, ,)
def __init__(sql_url: str = "sqlite:///", milvus_url: str = "tcp://localhost:19530", connection_pool: str = "SingletonThread", index: str = "document", vector_dim: int = None, embedding_dim: int = 768, index_file_size: int = 1024, similarity: str = "dot_product", index_type: IndexType = IndexType.FLAT, index_param: Optional[Dict[str, Any]] = None, search_param: Optional[Dict[str, Any]] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = "overwrite", isolation_level: str = None)
```

**Arguments**:
Expand Down Expand Up @@ -3168,7 +3171,7 @@ The current implementation is not supporting the storage of labels, so you canno
#### \_\_init\_\_

```python
def __init__(host: Union[str, List[str]] = "http://localhost", port: Union[int, List[int]] = 8080, timeout_config: tuple = (5, 15), username: str = None, password: str = None, index: str = "Document", embedding_dim: int = 768, content_field: str = "content", name_field: str = "name", similarity: str = "cosine", index_type: str = "hnsw", custom_schema: Optional[dict] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = "overwrite", **kwargs, ,)
def __init__(host: Union[str, List[str]] = "http://localhost", port: Union[int, List[int]] = 8080, timeout_config: tuple = (5, 15), username: str = None, password: str = None, index: str = "Document", embedding_dim: int = 768, content_field: str = "content", name_field: str = "name", similarity: str = "cosine", index_type: str = "hnsw", custom_schema: Optional[dict] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = "overwrite")
```

**Arguments**:
Expand Down
2 changes: 1 addition & 1 deletion docs/_src/api/api/reader.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ While the underlying model can vary (BERT, Roberta, DistilBERT, ...), the interf
#### \_\_init\_\_

```python
def __init__(model_name_or_path: str, model_version: Optional[str] = None, context_window_size: int = 150, batch_size: int = 50, use_gpu: bool = True, devices: List[torch.device] = [], no_ans_boost: float = 0.0, return_no_answer: bool = False, top_k: int = 10, top_k_per_candidate: int = 3, top_k_per_sample: int = 1, num_processes: Optional[int] = None, max_seq_len: int = 256, doc_stride: int = 128, progress_bar: bool = True, duplicate_filtering: int = 0, use_confidence_scores: bool = True, confidence_threshold: Optional[float] = None, proxies: Optional[Dict[str, str]] = None, local_files_only=False, force_download=False, use_auth_token: Optional[Union[str, bool]] = None, **kwargs, ,)
def __init__(model_name_or_path: str, model_version: Optional[str] = None, context_window_size: int = 150, batch_size: int = 50, use_gpu: bool = True, devices: List[torch.device] = [], no_ans_boost: float = 0.0, return_no_answer: bool = False, top_k: int = 10, top_k_per_candidate: int = 3, top_k_per_sample: int = 1, num_processes: Optional[int] = None, max_seq_len: int = 256, doc_stride: int = 128, progress_bar: bool = True, duplicate_filtering: int = 0, use_confidence_scores: bool = True, confidence_threshold: Optional[float] = None, proxies: Optional[Dict[str, str]] = None, local_files_only=False, force_download=False, use_auth_token: Optional[Union[str, bool]] = None)
```

**Arguments**:
Expand Down
161 changes: 148 additions & 13 deletions haystack/document_stores/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from haystack.schema import Document, Label
from haystack.document_stores.base import get_batches_from_generator
from haystack.document_stores.filter_utils import LogicalFilterClause
from haystack.errors import DocumentStoreError


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -54,8 +55,8 @@ def __init__(
recreate_index: bool = False,
create_index: bool = True,
refresh_type: str = "wait_for",
similarity="dot_product",
timeout=30,
similarity: str = "dot_product",
timeout: int = 30,
return_embedding: bool = False,
duplicate_documents: str = "overwrite",
index_type: str = "flat",
Expand Down Expand Up @@ -179,9 +180,9 @@ def __init__(
self.scroll = scroll
self.skip_missing_embeddings: bool = skip_missing_embeddings
if similarity in ["cosine", "dot_product", "l2"]:
self.similarity = similarity
self.similarity: str = similarity
else:
raise Exception(
raise DocumentStoreError(
f"Invalid value {similarity} for similarity in ElasticSearchDocumentStore constructor. Choose between 'cosine', 'l2' and 'dot_product'"
)
if index_type in ["flat", "hnsw"]:
Expand Down Expand Up @@ -1592,7 +1593,42 @@ def delete_index(self, index: str):


class OpenSearchDocumentStore(ElasticsearchDocumentStore):
def __init__(self, verify_certs=False, scheme="https", username="admin", password="admin", port=9200, **kwargs):
def __init__(
self,
scheme: str = "https", # Mind this different default param
username: str = "admin", # Mind this different default param
password: str = "admin", # Mind this different default param
host: Union[str, List[str]] = "localhost",
port: Union[int, List[int]] = 9200,
api_key_id: Optional[str] = None,
api_key: Optional[str] = None,
aws4auth=None,
index: str = "document",
label_index: str = "label",
search_fields: Union[str, list] = "content",
content_field: str = "content",
name_field: str = "name",
embedding_field: str = "embedding",
embedding_dim: int = 768,
custom_mapping: Optional[dict] = None,
excluded_meta_data: Optional[list] = None,
analyzer: str = "standard",
ca_certs: Optional[str] = None,
verify_certs: bool = False, # Mind this different default param
recreate_index: bool = False,
create_index: bool = True,
refresh_type: str = "wait_for",
similarity: str = "dot_product",
timeout: int = 30,
return_embedding: bool = False,
duplicate_documents: str = "overwrite",
index_type: str = "flat",
scroll: str = "1d",
skip_missing_embeddings: bool = True,
synonyms: Optional[List] = None,
synonym_type: str = "synonym",
use_system_proxy: bool = False,
):
"""
Document Store using OpenSearch (https://opensearch.org/). It is compatible with the AWS Elasticsearch Service.
Expand Down Expand Up @@ -1662,14 +1698,44 @@ def __init__(self, verify_certs=False, scheme="https", username="admin", passwor
Synonym or Synonym_graph to handle synonyms, including multi-word synonyms correctly during the analysis process.
More info at https://www.elastic.co/guide/en/elasticsearch/reference/current/analysis-synonym-graph-tokenfilter.html
"""
super().__init__(
scheme=scheme,
username=username,
password=password,
host=host,
port=port,
api_key_id=api_key_id,
api_key=api_key,
aws4auth=aws4auth,
index=index,
label_index=label_index,
search_fields=search_fields,
content_field=content_field,
name_field=name_field,
embedding_field=embedding_field,
embedding_dim=embedding_dim,
custom_mapping=custom_mapping,
excluded_meta_data=excluded_meta_data,
analyzer=analyzer,
ca_certs=ca_certs,
verify_certs=verify_certs,
recreate_index=recreate_index,
create_index=create_index,
refresh_type=refresh_type,
similarity=similarity,
timeout=timeout,
return_embedding=return_embedding,
duplicate_documents=duplicate_documents,
index_type=index_type,
scroll=scroll,
skip_missing_embeddings=skip_missing_embeddings,
synonyms=synonyms,
synonym_type=synonym_type,
use_system_proxy=use_system_proxy,
)
self.embeddings_field_supports_similarity = False
self.similarity_to_space_type = {"cosine": "cosinesimil", "dot_product": "innerproduct", "l2": "l2"}
self.space_type_to_similarity = {v: k for k, v in self.similarity_to_space_type.items()}
# Overwrite default kwarg values of parent class so that in default cases we can initialize
# an OpenSearchDocumentStore without providing any arguments
super(OpenSearchDocumentStore, self).__init__(
verify_certs=verify_certs, scheme=scheme, username=username, password=password, port=port, **kwargs
)

def query_by_embedding(
self,
Expand Down Expand Up @@ -1914,7 +1980,7 @@ def _create_document_index(self, index_name: str, headers: Optional[Dict[str, st
if not self.client.indices.exists(index=index_name, headers=headers):
raise e

def _get_embedding_field_mapping(self, similarity: Optional[str]):
def _get_embedding_field_mapping(self, similarity: str):
space_type = self.similarity_to_space_type[similarity]
method: dict = {"space_type": space_type, "name": "hnsw", "engine": "nmslib"}

Expand Down Expand Up @@ -2049,10 +2115,79 @@ class OpenDistroElasticsearchDocumentStore(OpenSearchDocumentStore):
A DocumentStore which has an Open Distro for Elasticsearch service behind it.
"""

def __init__(self, similarity="cosine", **kwargs):
def __init__(
self,
scheme: str = "https",
username: str = "admin",
password: str = "admin",
host: Union[str, List[str]] = "localhost",
port: Union[int, List[int]] = 9200,
api_key_id: Optional[str] = None,
api_key: Optional[str] = None,
aws4auth=None,
index: str = "document",
label_index: str = "label",
search_fields: Union[str, list] = "content",
content_field: str = "content",
name_field: str = "name",
embedding_field: str = "embedding",
embedding_dim: int = 768,
custom_mapping: Optional[dict] = None,
excluded_meta_data: Optional[list] = None,
analyzer: str = "standard",
ca_certs: Optional[str] = None,
verify_certs: bool = False,
recreate_index: bool = False,
create_index: bool = True,
refresh_type: str = "wait_for",
similarity: str = "cosine", # Mind this different default param
timeout: int = 30,
return_embedding: bool = False,
duplicate_documents: str = "overwrite",
index_type: str = "flat",
scroll: str = "1d",
skip_missing_embeddings: bool = True,
synonyms: Optional[List] = None,
synonym_type: str = "synonym",
use_system_proxy: bool = False,
):
logger.warning(
"Open Distro for Elasticsearch has been replaced by OpenSearch! "
"See https://opensearch.org/faq/ for details. "
"We recommend using the OpenSearchDocumentStore instead."
)
super(OpenDistroElasticsearchDocumentStore, self).__init__(similarity=similarity, **kwargs)
super().__init__(
scheme=scheme,
username=username,
password=password,
host=host,
port=port,
api_key_id=api_key_id,
api_key=api_key,
aws4auth=aws4auth,
index=index,
label_index=label_index,
search_fields=search_fields,
content_field=content_field,
name_field=name_field,
embedding_field=embedding_field,
embedding_dim=embedding_dim,
custom_mapping=custom_mapping,
excluded_meta_data=excluded_meta_data,
analyzer=analyzer,
ca_certs=ca_certs,
verify_certs=verify_certs,
recreate_index=recreate_index,
create_index=create_index,
refresh_type=refresh_type,
similarity=similarity,
timeout=timeout,
return_embedding=return_embedding,
duplicate_documents=duplicate_documents,
index_type=index_type,
scroll=scroll,
skip_missing_embeddings=skip_missing_embeddings,
synonyms=synonyms,
synonym_type=synonym_type,
use_system_proxy=use_system_proxy,
)
Loading

0 comments on commit 929c685

Please sign in to comment.