From 0cd0c03bc3ef26280eb5bedfecc2c6fcd35d7b91 Mon Sep 17 00:00:00 2001 From: QQ Date: Fri, 23 Aug 2024 09:22:15 +0800 Subject: [PATCH] Fix exception when calling `label_index` on a Query object (#2044) * bug: No module named 'tiledb.multirange_indexer' * Add test --------- Co-authored-by: Agisilaos Kounelis --- tiledb/libtiledb.pyx | 2 +- tiledb/tests/test_dimension_label.py | 67 ++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/tiledb/libtiledb.pyx b/tiledb/libtiledb.pyx index dd871839b4..0a20b112ff 100644 --- a/tiledb/libtiledb.pyx +++ b/tiledb/libtiledb.pyx @@ -2158,7 +2158,7 @@ cdef class Query(object): def label_index(self, labels): """Apply Array.label_index with query parameters.""" - from .multirange_indexer import LabelIndexer + from .multirange_indexing import LabelIndexer return LabelIndexer(self.array, tuple(labels), query=self) @property diff --git a/tiledb/tests/test_dimension_label.py b/tiledb/tests/test_dimension_label.py index 0daeab2e36..bb7c98e5bd 100644 --- a/tiledb/tests/test_dimension_label.py +++ b/tiledb/tests/test_dimension_label.py @@ -1,3 +1,5 @@ +from collections import OrderedDict + import numpy as np import pytest @@ -416,3 +418,68 @@ def test_dimension_label_round_trip_dense_var(self): result["value"], attr_data[:, index:] ) np.testing.assert_array_equal(result[label_name], label_index) + + @pytest.mark.skipif( + tiledb.libtiledb.version()[0] == 2 and tiledb.libtiledb.version()[1] < 15, + reason="dimension labels requires libtiledb version 2.15 or greater", + ) + def test_dimension_label_on_query(self): + uri = self.path("query_label_index") + + dim1 = tiledb.Dim("d1", domain=(1, 4)) + dim2 = tiledb.Dim("d2", domain=(1, 3)) + dom = tiledb.Domain(dim1, dim2) + att = tiledb.Attr("a1", dtype=np.int64) + dim_labels = { + 0: {"l1": dim1.create_label_schema("decreasing", np.int64)}, + 1: { + "l2": dim2.create_label_schema("increasing", np.int64), + "l3": dim2.create_label_schema("increasing", np.float64), + }, + } + schema = tiledb.ArraySchema(domain=dom, attrs=(att,), dim_labels=dim_labels) + tiledb.Array.create(uri, schema) + + a1_data = np.reshape(np.arange(1, 13), (4, 3)) + l1_data = np.arange(4, 0, -1) + l2_data = np.arange(-1, 2) + l3_data = np.linspace(0, 1.0, 3) + + with tiledb.open(uri, "w") as A: + A[:] = {"a1": a1_data, "l1": l1_data, "l2": l2_data, "l3": l3_data} + + with tiledb.open(uri, "r") as A: + np.testing.assert_equal( + A.query().label_index(["l1"])[3:4], + OrderedDict( + {"l1": np.array([4, 3]), "a1": np.array([[1, 2, 3], [4, 5, 6]])} + ), + ) + np.testing.assert_equal( + A.query().label_index(["l1", "l3"])[2, 0.5:1.0], + OrderedDict( + { + "l3": np.array([0.5, 1.0]), + "l1": np.array([2]), + "a1": np.array([[8, 9]]), + } + ), + ) + np.testing.assert_equal( + A.query().label_index(["l2"])[:, -1:0], + OrderedDict( + { + "l2": np.array([-1, 0]), + "a1": np.array([[1, 2], [4, 5], [7, 8], [10, 11]]), + }, + ), + ) + np.testing.assert_equal( + A.query().label_index(["l3"])[:, 0.5:1.0], + OrderedDict( + { + "l3": np.array([0.5, 1.0]), + "a1": np.array([[2, 3], [5, 6], [8, 9], [11, 12]]), + }, + ), + )