Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix data_returned and data_available #93

Merged
merged 3 commits into from
Nov 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 23 additions & 20 deletions optimade/server/entry_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def get_attribute_fields(self) -> set:
@abstractmethod
def find(
self, params: EntryListingQueryParams
) -> Tuple[List[EntryResource], bool, NonnegativeInt, set]:
) -> Tuple[List[EntryResource], NonnegativeInt, bool, set]:
"""
Fetches results, reports how many entries match the query, and indicates if more data is available.

Expand All @@ -56,7 +56,7 @@ def find(
params (EntryListingQueryParams): entry listing URL query params

Returns:
Tuple[List[Entry], bool, NonnegativeInt, set]: (results, more_data_available, data_available, fields)
Tuple[List[Entry], NonnegativeInt, bool, set]: (results, data_returned, more_data_available, fields)

"""

Expand Down Expand Up @@ -93,40 +93,43 @@ def count(self, **kwargs):
for k in list(kwargs.keys()):
if k not in ("filter", "skip", "limit", "hint", "maxTimeMS"):
del kwargs[k]
if "filter" not in kwargs: # "filter" is needed for count_documents()
kwargs["filter"] = {}
return self.collection.count_documents(**kwargs)

def find(
self, params: Union[EntryListingQueryParams, SingleEntryQueryParams]
) -> Tuple[List[EntryResource], bool, NonnegativeInt, set]:
) -> Tuple[List[EntryResource], NonnegativeInt, bool, set]:
criteria = self._parse_params(params)
if isinstance(params, EntryListingQueryParams):
criteria_nolimit = criteria.copy()
del criteria_nolimit["limit"]
nresults_now = self.count(**criteria)
nresults_total = self.count(**criteria_nolimit)
more_data_available = nresults_now < nresults_total
data_available = nresults_total
else:
more_data_available = False
data_available = self.count(**criteria)
if data_available > 1:
raise HTTPException(
status_code=404,
detail=f"Instead of a single entry, {data_available} entries were found",
)

all_fields = criteria.pop("fields")
if getattr(params, "response_fields", False):
fields = set(params.response_fields.split(","))
else:
fields = all_fields.copy()

results = []
for doc in self.collection.find(**criteria):
results.append(self.resource_cls(**self.resource_mapper.map_back(doc)))

if isinstance(params, SingleEntryQueryParams):
if isinstance(params, EntryListingQueryParams):
nresults_now = len(results)
criteria_nolimit = criteria.copy()
criteria_nolimit.pop("limit", None)
data_returned = self.count(**criteria_nolimit)
more_data_available = nresults_now < data_returned
else:
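# SingleEntryQueryParams, e.g., /structures/{entry_id}
# NOTE(review): data_returned is set to 1 here even when no entry is found
# (results becomes None below); the caller previously used
# `1 if results else 0` — confirm that 1 is intended for a missing entry.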
data_returned = 1
more_data_available = False
if len(results) > 1:
raise HTTPException(
status_code=404,
detail=f"Instead of a single entry, {len(results)} entries were found",
)
results = results[0] if results else None

return results, more_data_available, data_available, all_fields - fields
return results, data_returned, more_data_available, all_fields - fields

def _alias_filter(self, filter_: dict) -> dict:
res = {}
Expand Down
76 changes: 39 additions & 37 deletions optimade/server/tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,180 +200,182 @@ class FilterTests(unittest.TestCase):
def test_custom_field(self):
request = '/structures?filter=_exmpl__mp_chemsys="Ac"'
expected_ids = ["mpf_1"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

def test_id(self):
request = "/structures?filter=id=mpf_2"
expected_ids = ["mpf_2"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

def test_geq(self):
request = "/structures?filter=nelements>=9"
expected_ids = ["mpf_3819"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

def test_gt(self):
request = "/structures?filter=nelements>8"
expected_ids = ["mpf_3819"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

def test_gt_none(self):
request = "/structures?filter=nelements>9"
expected_ids = []
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

def test_list_has(self):
request = '/structures?filter=elements HAS "Ti"'
expected_ids = ["mpf_3803", "mpf_3819"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

def test_page_limit(self):
request = '/structures?filter=elements HAS "Ac"&page_limit=2'
expected_ids = ["mpf_1", "mpf_2"]
self._check_response(request, expected_ids)
expected_return = 6
self._check_response(request, expected_ids, expected_return)

request = '/structures?page_limit=2&filter=elements HAS "Ac"'
expected_ids = ["mpf_1", "mpf_2"]
self._check_response(request, expected_ids)
expected_return = 6
self._check_response(request, expected_ids, expected_return)

@unittest.skip("Skipping HAS ALL until implemented in server code.")
def test_list_has_all(self):
request = '/structures?filter=elements HAS ALL "Ba","F","H","Mn","O","Re","Si"'
expected_ids = ["mpf_3819"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

request = '/structures?filter=elements HAS ALL "Re","Ti"'
expected_ids = ["mpf_3819"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

@unittest.skip("Skipping HAS ANY until implemented in server code.")
def test_list_has_any(self):
request = '/structures?filter=elements HAS ANY "Re","Ti"'
expected_ids = ["mpf_3819"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

def test_list_length_basic(self):
request = "/structures?filter=LENGTH elements = 9"
expected_ids = ["mpf_3819"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

@unittest.skip("Skipping LENGTH until implemented in server code.")
def test_list_length(self):
request = "/structures?filter=LENGTH elements = 9"
expected_ids = ["mpf_3819"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

request = "/structures?filter=LENGTH elements >= 9"
expected_ids = ["mpf_3819"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

request = "/structures?filter=LENGTH structure_features > 0"
expected_ids = []
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

@unittest.skip("Skipping HAS ONLY until implemented in server code.")
def test_list_has_only(self):
request = '/structures?filter=elements HAS ONLY "Ac"'
expected_ids = ["mpf_1"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

@unittest.skip("Skipping correlated list query until implemented in server code.")
def test_list_correlated(self):
request = '/structures?filter=elements:elements_ratios HAS "Ag":"0.2"'
expected_ids = ["mpf_259"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

def test_is_known(self):
request = "/structures?filter=nsites IS KNOWN AND nsites>=44"
expected_ids = ["mpf_551", "mpf_3803", "mpf_3819"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

request = "/structures?filter=lattice_vectors IS KNOWN AND nsites>=44"
expected_ids = ["mpf_551", "mpf_3803", "mpf_3819"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

def test_aliased_is_known(self):
request = "/structures?filter=id IS KNOWN AND nsites>=44"
expected_ids = ["mpf_551", "mpf_3803", "mpf_3819"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

request = "/structures?filter=chemical_formula_reduced IS KNOWN AND nsites>=44"
expected_ids = ["mpf_551", "mpf_3803", "mpf_3819"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

request = (
"/structures?filter=chemical_formula_descriptive IS KNOWN AND nsites>=44"
)
expected_ids = ["mpf_551", "mpf_3803", "mpf_3819"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

def test_aliased_fields(self):
request = '/structures?filter=chemical_formula_anonymous="A"'
expected_ids = ["mpf_1", "mpf_200"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

request = '/structures?filter=chemical_formula_anonymous CONTAINS "A2BC"'
expected_ids = ["mpf_2", "mpf_3", "mpf_110"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

def test_string_contains(self):
request = '/structures?filter=chemical_formula_descriptive CONTAINS "c2Ag"'
expected_ids = ["mpf_3", "mpf_2"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

def test_string_start(self):
request = (
'/structures?filter=chemical_formula_descriptive STARTS WITH "Ag2CSNCl"'
)
expected_ids = ["mpf_259"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

def test_string_end(self):
request = '/structures?filter=chemical_formula_descriptive ENDS WITH "NClO4"'
expected_ids = ["mpf_259"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

def test_list_has_and(self):
request = '/structures?filter=elements HAS "Ac" AND nelements=1'
expected_ids = ["mpf_1"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

def test_not_or_and_precedence(self):
request = '/structures?filter=NOT elements HAS "Ac" AND nelements=1'
expected_ids = ["mpf_200"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

request = '/structures?filter=nelements=1 AND NOT elements HAS "Ac"'
expected_ids = ["mpf_200"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

request = '/structures?filter=NOT elements HAS "Ac" AND nelements=1 OR nsites=1'
expected_ids = ["mpf_1", "mpf_200"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

request = '/structures?filter=elements HAS "Ac" AND nelements>1 AND nsites=1'
expected_ids = []
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

def test_brackets(self):
request = '/structures?filter=elements HAS "Ac" AND nelements=1 OR nsites=1'
expected_ids = ["mpf_200", "mpf_1"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

request = '/structures?filter=(elements HAS "Ac" AND nelements=1) OR (elements HAS "Ac" AND nsites=1)'
expected_ids = ["mpf_1"]
self._check_response(request, expected_ids)
self._check_response(request, expected_ids, len(expected_ids))

def _check_response(self, request, expected_id):
def _check_response(self, request, expected_ids, expected_return):
try:
response = self.client.get(request)
self.assertEqual(
response.status_code, 200, msg=f"Request failed: {response.json()}"
)
response = response.json()
response_ids = [struct["id"] for struct in response["data"]]
self.assertEqual(sorted(expected_id), sorted(response_ids))
self.assertEqual(response["meta"]["data_returned"], len(expected_id))
self.assertEqual(sorted(expected_ids), sorted(response_ids))
self.assertEqual(response["meta"]["data_returned"], expected_return)
except Exception as exc:
print("Request attempted:")
print(f"http://localhost:5000{request}")
Expand Down
16 changes: 10 additions & 6 deletions optimade/server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def get_entries(
params: EntryListingQueryParams,
) -> EntryResponseMany:
"""Generalized /{entry} endpoint getter"""
results, more_data_available, data_available, fields = collection.find(params)
results, data_returned, more_data_available, fields = collection.find(params)

if more_data_available:
parse_result = urllib.parse.urlparse(str(request.url))
Expand All @@ -140,7 +140,10 @@ def get_entries(
links=links,
data=results,
meta=meta_values(
str(request.url), len(results), data_available, more_data_available
url=str(request.url),
data_returned=data_returned,
data_available=len(collection),
more_data_available=more_data_available,
),
)

Expand All @@ -153,7 +156,7 @@ def get_single_entry(
params: SingleEntryQueryParams,
) -> EntryResponseOne:
params.filter = f'id="{entry_id}"'
results, more_data_available, data_available, fields = collection.find(params)
results, data_returned, more_data_available, fields = collection.find(params)

if more_data_available:
raise StarletteHTTPException(
Expand All @@ -166,13 +169,14 @@ def get_single_entry(
if fields and results is not None:
results = handle_response_fields(results, fields)[0]

data_returned = 1 if results else 0

return response(
links=links,
data=results,
meta=meta_values(
str(request.url), data_returned, data_available, more_data_available
url=str(request.url),
data_returned=data_returned,
data_available=len(collection),
more_data_available=more_data_available,
),
)

Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
filterwarnings =
ignore:.*PY_SSIZE_T_CLEAN will be required for '#' formats.*:DeprecationWarning
ignore:.*"@coroutine" decorator is deprecated since Python 3.8, use "async def" instead.*:DeprecationWarning
ignore:.*Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated.*:DeprecationWarning