Skip to content

Commit

Permalink
Fix exception when calling label_index on a Query object (#2044)
Browse files Browse the repository at this point in the history
* bug: No module named 'tiledb.multirange_indexer'

* Add test

---------

Co-authored-by: Agisilaos Kounelis <kounelisagis@gmail.com>
  • Loading branch information
sric0880 and kounelisagis committed Aug 23, 2024
1 parent aeb95f1 commit 0cd0c03
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tiledb/libtiledb.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2158,7 +2158,7 @@ cdef class Query(object):

def label_index(self, labels):
"""Apply Array.label_index with query parameters."""
from .multirange_indexer import LabelIndexer
from .multirange_indexing import LabelIndexer
return LabelIndexer(self.array, tuple(labels), query=self)

@property
Expand Down
67 changes: 67 additions & 0 deletions tiledb/tests/test_dimension_label.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections import OrderedDict

import numpy as np
import pytest

Expand Down Expand Up @@ -416,3 +418,68 @@ def test_dimension_label_round_trip_dense_var(self):
result["value"], attr_data[:, index:]
)
np.testing.assert_array_equal(result[label_name], label_index)

@pytest.mark.skipif(
tiledb.libtiledb.version()[0] == 2 and tiledb.libtiledb.version()[1] < 15,
reason="dimension labels requires libtiledb version 2.15 or greater",
)
def test_dimension_label_on_query(self):
uri = self.path("query_label_index")

dim1 = tiledb.Dim("d1", domain=(1, 4))
dim2 = tiledb.Dim("d2", domain=(1, 3))
dom = tiledb.Domain(dim1, dim2)
att = tiledb.Attr("a1", dtype=np.int64)
dim_labels = {
0: {"l1": dim1.create_label_schema("decreasing", np.int64)},
1: {
"l2": dim2.create_label_schema("increasing", np.int64),
"l3": dim2.create_label_schema("increasing", np.float64),
},
}
schema = tiledb.ArraySchema(domain=dom, attrs=(att,), dim_labels=dim_labels)
tiledb.Array.create(uri, schema)

a1_data = np.reshape(np.arange(1, 13), (4, 3))
l1_data = np.arange(4, 0, -1)
l2_data = np.arange(-1, 2)
l3_data = np.linspace(0, 1.0, 3)

with tiledb.open(uri, "w") as A:
A[:] = {"a1": a1_data, "l1": l1_data, "l2": l2_data, "l3": l3_data}

with tiledb.open(uri, "r") as A:
np.testing.assert_equal(
A.query().label_index(["l1"])[3:4],
OrderedDict(
{"l1": np.array([4, 3]), "a1": np.array([[1, 2, 3], [4, 5, 6]])}
),
)
np.testing.assert_equal(
A.query().label_index(["l1", "l3"])[2, 0.5:1.0],
OrderedDict(
{
"l3": np.array([0.5, 1.0]),
"l1": np.array([2]),
"a1": np.array([[8, 9]]),
}
),
)
np.testing.assert_equal(
A.query().label_index(["l2"])[:, -1:0],
OrderedDict(
{
"l2": np.array([-1, 0]),
"a1": np.array([[1, 2], [4, 5], [7, 8], [10, 11]]),
},
),
)
np.testing.assert_equal(
A.query().label_index(["l3"])[:, 0.5:1.0],
OrderedDict(
{
"l3": np.array([0.5, 1.0]),
"a1": np.array([[2, 3], [5, 6], [8, 9], [11, 12]]),
},
),
)

0 comments on commit 0cd0c03

Please sign in to comment.