Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add SQLDocumentStore tests #3517

Merged
merged 6 commits into from
Nov 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,6 @@ jobs:
ES_JAVA_OPTS: "-Xms128m -Xmx256m"
ports:
- 9200:9200
# env:
# ELASTICSEARCH_HOST: "elasticsearch"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

leftover from previous PR

steps:
- uses: actions/checkout@v3

Expand All @@ -154,7 +152,35 @@ jobs:

- name: Run tests
run: |
pytest -x -m "document_store and integration" test/document_stores/test_elasticsearch.py
pytest --maxfail=5 -m "document_store and integration" test/document_stores/test_elasticsearch.py

- uses: act10ns/slack@v1
with:
status: ${{ job.status }}
channel: '#haystack'
if: failure() && github.repository_owner == 'deepset-ai' && github.ref == 'refs/heads/main'

integration-tests-sql:
name: Integration / SQL / ${{ matrix.os }}
needs:
- unit-tests
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest,macos-latest,windows-latest]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3

- name: Setup Python
uses: ./.github/actions/python_cache/

- name: Install Haystack
run: pip install -U .[sql]

- name: Run tests
run: |
pytest --maxfail=5 -m "document_store and integration" test/document_stores/test_sql.py

- uses: act10ns/slack@v1
with:
Expand All @@ -179,8 +205,6 @@ jobs:
ES_JAVA_OPTS: "-Xms128m -Xmx256m"
ports:
- 9200:9200
# env:
# OPENSEARCH_HOST: "opensearch"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

leftover from previous PR

steps:
- uses: actions/checkout@v3

Expand All @@ -192,7 +216,7 @@ jobs:

- name: Run tests
run: |
pytest -x -m "document_store and integration" test/document_stores/test_opensearch.py
pytest --maxfail=5 -m "document_store and integration" test/document_stores/test_opensearch.py

- uses: act10ns/slack@v1
with:
Expand Down
26 changes: 21 additions & 5 deletions haystack/document_stores/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,16 +460,30 @@ def write_labels(self, labels, index=None, headers: Optional[Dict[str, str]] = N
# self.write_documents(documents=[label.document], index=index, duplicate_documents="skip")

# TODO: Handle label meta data

# Sanitize fields to adhere to SQL constraints
answer = label.answer
if answer is not None:
answer = answer.to_json()

no_answer = label.no_answer
if label.no_answer is None:
no_answer = False

document = label.document
if document is not None:
document = document.to_json()

label_orm = LabelORM(
id=label.id,
no_answer=label.no_answer,
no_answer=no_answer,
# document_id=label.document.id,
document=label.document.to_json(),
document=document,
origin=label.origin,
query=label.query,
is_correct_answer=label.is_correct_answer,
is_correct_document=label.is_correct_document,
answer=label.answer.to_json(),
answer=answer,
pipeline_id=label.pipeline_id,
index=index,
)
Expand Down Expand Up @@ -576,11 +590,13 @@ def _convert_sql_row_to_document(self, row) -> Document:
return document

def _convert_sql_row_to_label(self, row) -> Label:
# doc = self._convert_sql_row_to_document(row.document)
answer = row.answer
if answer is not None:
answer = Answer.from_json(answer)

label = Label(
query=row.query,
answer=Answer.from_json(row.answer), # type: ignore
answer=answer,
document=Document.from_json(row.document),
is_correct_answer=row.is_correct_answer,
is_correct_document=row.is_correct_document,
Expand Down
5 changes: 1 addition & 4 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,10 +1013,7 @@ def get_document_store(
recreate_index: bool = True,
): # cosine is default similarity as dot product is not supported by Weaviate
document_store: BaseDocumentStore
if document_store_type == "sql":
document_store = SQLDocumentStore(url=get_sql_url(tmp_path), index=index, isolation_level="AUTOCOMMIT")

elif document_store_type == "memory":
if document_store_type == "memory":
document_store = InMemoryDocumentStore(
return_embedding=True,
embedding_dim=embedding_dim,
Expand Down
52 changes: 40 additions & 12 deletions test/document_stores/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def test_write_documents(self, ds, documents):
ds.write_documents(documents)
docs = ds.get_all_documents()
assert len(docs) == len(documents)
for i, doc in enumerate(docs):
expected = documents[i]
assert doc.id == expected.id
expected_ids = set(doc.id for doc in documents)
ids = set(doc.id for doc in docs)
assert ids == expected_ids

@pytest.mark.integration
def test_write_labels(self, ds, labels):
Expand Down Expand Up @@ -142,27 +142,41 @@ def test_get_all_documents_with_incorrect_filter_value(self, ds, documents):
assert len(result) == 0

@pytest.mark.integration
def test_extended_filter(self, ds, documents):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function was too big in any case, it was ported as it was in the first step of the refactoring but breaking it down was functional to this PR as the docstore doesn't fully support metadata and some tests are expected not to pass

def test_eq_filters(self, ds, documents):
    """Explicit `$eq` and its implicit shorthand form must match the same 3 documents."""
    ds.write_documents(documents)
    for year_filter in ({"year": {"$eq": "2020"}}, {"year": "2020"}):
        hits = ds.get_all_documents(filters=year_filter)
        assert len(hits) == 3

@pytest.mark.integration
def test_in_filters(self, ds, documents):
    """`$in` and the bare-list shorthand must both select the same 6 documents."""
    ds.write_documents(documents)
    wanted_years = ["2020", "2021", "n.a."]
    for year_filter in ({"year": {"$in": wanted_years}}, {"year": wanted_years}):
        hits = ds.get_all_documents(filters=year_filter)
        assert len(hits) == 6

@pytest.mark.integration
def test_ne_filters(self, ds, documents):
    """`$ne` excludes the matching documents, leaving the other 6."""
    ds.write_documents(documents)
    remaining = ds.get_all_documents(filters={"year": {"$ne": "2020"}})
    assert len(remaining) == 6

@pytest.mark.integration
def test_nin_filters(self, ds, documents):
    """`$nin` keeps only documents whose year is outside the given list (3 of them)."""
    ds.write_documents(documents)
    excluded = {"year": {"$nin": ["2020", "2021", "n.a."]}}
    hits = ds.get_all_documents(filters=excluded)
    assert len(hits) == 3

@pytest.mark.integration
def test_comparison_filters(self, ds, documents):
ds.write_documents(documents)

result = ds.get_all_documents(filters={"numbers": {"$gt": 0}})
assert len(result) == 3

Expand All @@ -175,11 +189,17 @@ def test_extended_filter(self, ds, documents):
result = ds.get_all_documents(filters={"numbers": {"$lte": 2.0}})
assert len(result) == 6

# Test compound filters
@pytest.mark.integration
def test_compound_filters(self, ds, documents):
    """Two comparison operators on the same field combine as a logical AND."""
    ds.write_documents(documents)
    year_range = {"year": {"$lte": "2021", "$gte": "2020"}}
    in_range = ds.get_all_documents(filters=year_range)
    assert len(in_range) == 6

@pytest.mark.integration
def test_simplified_filters(self, ds, documents):
ds.write_documents(documents)

filters = {"$and": {"year": {"$lte": "2021", "$gte": "2020"}, "name": {"$in": ["name_0", "name_1"]}}}
result = ds.get_all_documents(filters=filters)
assert len(result) == 4
Expand All @@ -188,6 +208,9 @@ def test_extended_filter(self, ds, documents):
result = ds.get_all_documents(filters=filters_simplified)
assert len(result) == 4

@pytest.mark.integration
def test_nested_condition_filters(self, ds, documents):
ds.write_documents(documents)
filters = {
"$and": {
"year": {"$lte": "2021", "$gte": "2020"},
Expand Down Expand Up @@ -223,8 +246,12 @@ def test_extended_filter(self, ds, documents):
result = ds.get_all_documents(filters=filters_simplified)
assert len(result) == 5

# Test nested logical operations within "$not", important as we apply De Morgan's laws in WeaviateDocumentstore

@pytest.mark.integration
def test_nested_condition_not_filters(self, ds, documents):
"""
Test nested logical operations within "$not", important as we apply De Morgan's laws in WeaviateDocumentStore
"""
ds.write_documents(documents)
filters = {
"$not": {
"$or": {
Expand All @@ -234,8 +261,9 @@ def test_extended_filter(self, ds, documents):
}
}
result = ds.get_all_documents(filters=filters)
docs_meta = result[0].meta["numbers"]
assert len(result) == 3

docs_meta = result[0].meta["numbers"]
assert [2, 4] == docs_meta

# Test same logical operator twice on same level
Expand Down Expand Up @@ -289,8 +317,8 @@ def test_duplicate_documents_skip(self, ds, documents):
updated_docs.append(updated_d)

ds.write_documents(updated_docs, duplicate_documents="skip")
result = ds.get_all_documents()
assert result[0].meta["name"] == "name_0"
for d in ds.get_all_documents():
assert d.meta["name"] != "Updated"

@pytest.mark.integration
def test_duplicate_documents_overwrite(self, ds, documents):
Expand Down
58 changes: 1 addition & 57 deletions test/document_stores/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,21 +172,6 @@ def test_get_all_documents_with_correct_filters(document_store_with_docs):
assert {d.meta["meta_field"] for d in documents} == {"test1", "test3"}


def test_get_all_documents_with_correct_filters_legacy_sqlite(docs, tmp_path):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was redundant at this point

document_store_with_docs = get_document_store("sql", tmp_path)
document_store_with_docs.write_documents(docs)

document_store_with_docs.use_windowed_query = False
documents = document_store_with_docs.get_all_documents(filters={"meta_field": ["test2"]})
assert len(documents) == 1
assert documents[0].meta["name"] == "filename2"

documents = document_store_with_docs.get_all_documents(filters={"meta_field": ["test1", "test3"]})
assert len(documents) == 2
assert {d.meta["name"] for d in documents} == {"filename1", "filename3"}
assert {d.meta["meta_field"] for d in documents} == {"test1", "test3"}


def test_get_all_documents_with_incorrect_filter_name(document_store_with_docs):
documents = document_store_with_docs.get_all_documents(filters={"incorrect_meta_field": ["test2"]})
assert len(documents) == 0
Expand All @@ -198,7 +183,7 @@ def test_get_all_documents_with_incorrect_filter_value(document_store_with_docs)


# See test_pinecone.py
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch", "sql", "weaviate", "memory"], indirect=True)
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch", "weaviate", "memory"], indirect=True)
def test_extended_filter(document_store_with_docs):
# Test comparison operators individually
documents = document_store_with_docs.get_all_documents(filters={"meta_field": {"$eq": "test1"}})
Expand Down Expand Up @@ -410,47 +395,6 @@ def test_write_document_meta(document_store: BaseDocumentStore):
assert document_store.get_document_by_id("4").meta["meta_field"] == "test4"


@pytest.mark.parametrize("document_store", ["sql"], indirect=True)
def test_sql_write_document_invalid_meta(document_store: BaseDocumentStore):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these two tests were ported to the new class

documents = [
{
"content": "dict_with_invalid_meta",
"valid_meta_field": "test1",
"invalid_meta_field": [1, 2, 3],
"name": "filename1",
"id": "1",
},
Document(
content="document_object_with_invalid_meta",
meta={"valid_meta_field": "test2", "invalid_meta_field": [1, 2, 3], "name": "filename2"},
id="2",
),
]
document_store.write_documents(documents)
documents_in_store = document_store.get_all_documents()
assert len(documents_in_store) == 2

assert document_store.get_document_by_id("1").meta == {"name": "filename1", "valid_meta_field": "test1"}
assert document_store.get_document_by_id("2").meta == {"name": "filename2", "valid_meta_field": "test2"}


@pytest.mark.parametrize("document_store", ["sql"], indirect=True)
def test_sql_write_different_documents_same_vector_id(document_store: BaseDocumentStore):
    """A vector_id may be reused across indexes but must stay unique within one index."""
    doc1 = {"content": "content 1", "name": "doc1", "id": "1", "vector_id": "vector_id"}
    doc2 = {"content": "content 2", "name": "doc2", "id": "2", "vector_id": "vector_id"}

    # Same vector_id written to two separate indexes: both writes succeed.
    document_store.write_documents([doc1], index="index1")
    assert len(document_store.get_all_documents(index="index1")) == 1
    document_store.write_documents([doc2], index="index2")
    assert len(document_store.get_all_documents(index="index2")) == 1

    # Reusing the vector_id inside a single index must trip the UNIQUE constraint.
    document_store.write_documents([doc1], index="index3")
    with pytest.raises(Exception, match=r"(?i)unique"):
        document_store.write_documents([doc2], index="index3")


def test_write_document_index(document_store: BaseDocumentStore):
document_store.delete_index("haystack_test_one")
document_store.delete_index("haystack_test_two")
Expand Down
Loading