Skip to content

Commit

Permalink
Add Support For SchemaEvolution on Enumerations (#1834)
Browse files Browse the repository at this point in the history
This PR adds `SchemaEvolution.add_enumeration` and `SchemaEvolution.drop_enumeration`. The bindings were done in a previous PR but never added to `SchemaEvolution`.
  • Loading branch information
nguyenv committed Sep 20, 2023
1 parent 3e13a83 commit 799a434
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
15 changes: 15 additions & 0 deletions tiledb/schema_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import tiledb

from .enumeration import Enumeration
from .main import ArraySchemaEvolution as ASE


Expand Down Expand Up @@ -29,6 +30,20 @@ def drop_attribute(self, attr_name: str):

self.ase.drop_attribute(attr_name)

def add_enumeration(self, enmr: Enumeration):
"""Add the given enumeration to the schema evolution plan.
Note: this function does not apply any changes; the changes are
only applied when `ArraySchemaEvolution.array_evolve` is called."""

self.ase.add_enumeration(enmr)

def drop_enumeration(self, enmr_name: str):
"""Drop the given enumeration (by name) in the schema evolution.
Note: this function does not apply any changes; the changes are
only applied when `ArraySchemaEvolution.array_evolve` is called."""

self.ase.drop_enumeration(enmr_name)

def array_evolve(self, uri: str):
"""Apply ArraySchemaEvolution actions to Array at given URI."""

Expand Down
53 changes: 53 additions & 0 deletions tiledb/tests/test_schema_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,56 @@ def get_schema_timestamps(schema_uri):
se.array_evolve(uri)

assert 123456789 in get_schema_timestamps(schema_uri)


def test_schema_evolution_with_enmr(tmp_path):
ctx = tiledb.default_ctx()
se = tiledb.ArraySchemaEvolution(ctx)

uri = str(tmp_path)

attrs = [
tiledb.Attr(name="a1", dtype=np.float64),
tiledb.Attr(name="a2", dtype=np.int32),
]
dims = [tiledb.Dim(domain=(0, 3), dtype=np.uint64)]
domain = tiledb.Domain(*dims)
schema = tiledb.ArraySchema(domain=domain, attrs=attrs, sparse=False)
tiledb.Array.create(uri, schema)

data1 = {
"a1": np.arange(5, 9),
"a2": np.random.randint(0, 1e7, size=4).astype(np.int32),
}

with tiledb.open(uri, "w") as A:
A[:] = data1

with tiledb.open(uri) as A:
assert not A.schema.has_attr("a3")

newattr = tiledb.Attr("a3", dtype=np.int8, enum_label="e3")
se.add_attribute(newattr)

with pytest.raises(tiledb.TileDBError) as excinfo:
se.array_evolve(uri)
assert " Attribute refers to an unknown enumeration" in str(excinfo.value)

se.add_enumeration(tiledb.Enumeration("e3", True, np.arange(0, 8)))
se.array_evolve(uri)

with tiledb.open(uri) as A:
assert A.schema.has_attr("a3")
assert A.attr("a3").enum_label == "e3"

se.drop_enumeration("e3")

with pytest.raises(tiledb.TileDBError) as excinfo:
se.array_evolve(uri)
assert "the enumeration has not been loaded" in str(excinfo.value)

se.drop_attribute("a3")
se.array_evolve(uri)

with tiledb.open(uri) as A:
assert not A.schema.has_attr("a3")

0 comments on commit 799a434

Please sign in to comment.