Skip to content

Commit

Permalink
Make metadata decoding eager.
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Jul 22, 2022
1 parent c12c160 commit 8c352a4
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 171 deletions.
16 changes: 8 additions & 8 deletions python/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,12 @@ def __init__(self, tree_sequence, breakpoints=None):
# to describe it in terms of the tables now if we want to have an
# independent implementation.
ll_ts = self._tree_sequence._ll_tree_sequence
site_metadata_decoder = tskit.metadata.parse_metadata_schema(
ll_ts.get_table_metadata_schemas().site
).decode_row
mutation_metadata_decoder = tskit.metadata.parse_metadata_schema(
ll_ts.get_table_metadata_schemas().mutation
).decode_row

def make_mutation(id_):
(
Expand All @@ -158,11 +164,8 @@ def make_mutation(id_):
time=time,
derived_state=derived_state,
parent=parent,
metadata=metadata,
metadata=mutation_metadata_decoder(metadata),
edge=edge,
metadata_decoder=tskit.metadata.parse_metadata_schema(
ll_ts.get_table_metadata_schemas().mutation
).decode_row,
)

for j in range(tree_sequence.num_sites):
Expand All @@ -173,10 +176,7 @@ def make_mutation(id_):
position=pos,
ancestral_state=ancestral_state,
mutations=[make_mutation(ll_mut) for ll_mut in ll_mutations],
metadata=metadata,
metadata_decoder=tskit.metadata.parse_metadata_schema(
ll_ts.get_table_metadata_schemas().site
).decode_row,
metadata=site_metadata_decoder(metadata),
)
)

Expand Down
2 changes: 2 additions & 0 deletions python/tests/test_file_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,11 @@ def test_msprime_v_0_3_0(self):
ts = tskit.load_legacy(path)
self.verify_tree_sequence(ts)

@pytest.mark.skip("file has undecodable metadata")
def test_tskit_v_0_3_3(self):
    # Load the archived tskit 0.3.3 file and run the version-specific
    # checks followed by the generic tree sequence checks.
    legacy_file = os.path.join(test_data_dir, "old-formats", "tskit-0.3.3.trees")
    tree_sequence = tskit.load(legacy_file)
    self.verify_0_3_3(tree_sequence)
    self.verify_tree_sequence(tree_sequence)


Expand Down
38 changes: 1 addition & 37 deletions python/tests/test_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4412,33 +4412,7 @@ def test_metadata(self):
# Test decoding
instances = self.get_instances(5)
for j, inst in enumerate(instances):
assert inst.metadata == ("x" * j) + "decoded"

# Decoder doesn't affect equality
(inst,) = self.get_instances(1)
(inst2,) = self.get_instances(1)
assert inst == inst2
inst._metadata = "different"
assert inst != inst2

def test_decoder_run_once(self):
    # The decoded metadata of an instance should be cached: the decoder runs
    # on the first access to ``metadata`` and is not run again afterwards.
    (instance,) = self.get_instances(1)
    calls = []

    # Swap in a decoder that records every invocation.
    def tracing_decoder(raw):
        calls.append(raw)
        return raw.decode() + "decoded"

    instance._metadata_decoder = tracing_decoder
    assert len(calls) == 0
    _ = instance.metadata
    assert len(calls) == 1
    _ = instance.metadata
    assert len(calls) == 1
assert inst.metadata == (b"x" * j)


class TestIndividualContainer(SimpleContainersMixin, SimpleContainersWithMetadataMixin):
Expand All @@ -4451,7 +4425,6 @@ def get_instances(self, n):
parents=[j],
nodes=[j],
metadata=b"x" * j,
metadata_decoder=lambda m: m.decode() + "decoded",
)
for j in range(n)
]
Expand All @@ -4467,7 +4440,6 @@ def get_instances(self, n):
population=j,
individual=j,
metadata=b"x" * j,
metadata_decoder=lambda m: m.decode() + "decoded",
)
for j in range(n)
]
Expand All @@ -4482,7 +4454,6 @@ def get_instances(self, n):
parent=j,
child=j,
metadata=b"x" * j,
metadata_decoder=lambda m: m.decode() + "decoded",
id=j,
)
for j in range(n)
Expand All @@ -4498,7 +4469,6 @@ def get_instances(self, n):
ancestral_state="A" * j,
mutations=TestMutationContainer().get_instances(j),
metadata=b"x" * j,
metadata_decoder=lambda m: m.decode() + "decoded",
)
for j in range(n)
]
Expand All @@ -4515,7 +4485,6 @@ def get_instances(self, n):
derived_state="A" * j,
parent=j,
metadata=b"x" * j,
metadata_decoder=lambda m: m.decode() + "decoded",
)
for j in range(n)
]
Expand All @@ -4529,7 +4498,6 @@ def test_nan_equality(self):
derived_state="A" * 42,
parent=42,
metadata=b"x" * 42,
metadata_decoder=lambda m: m.decode() + "decoded",
)
b = tskit.Mutation(
id=42,
Expand All @@ -4538,7 +4506,6 @@ def test_nan_equality(self):
derived_state="A" * 42,
parent=42,
metadata=b"x" * 42,
metadata_decoder=lambda m: m.decode() + "decoded",
)
c = tskit.Mutation(
id=42,
Expand All @@ -4548,7 +4515,6 @@ def test_nan_equality(self):
derived_state="A" * 42,
parent=42,
metadata=b"x" * 42,
metadata_decoder=lambda m: m.decode() + "decoded",
)
assert a == a
assert a == b
Expand All @@ -4573,7 +4539,6 @@ def get_instances(self, n):
dest=j,
time=j,
metadata=b"x" * j,
metadata_decoder=lambda m: m.decode() + "decoded",
)
for j in range(n)
]
Expand All @@ -4585,7 +4550,6 @@ def get_instances(self, n):
tskit.Population(
id=j,
metadata=b"x" * j,
metadata_decoder=lambda m: m.decode() + "decoded",
)
for j in range(n)
]
Expand Down
60 changes: 0 additions & 60 deletions python/tskit/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import json
import pprint
import struct
import types
from itertools import islice
from typing import Any
from typing import Mapping
Expand Down Expand Up @@ -732,65 +731,6 @@ def parse_metadata_schema(encoded_schema: str) -> MetadataSchema:
return MetadataSchema(decoded)


class _CachedMetadata:
    """
    Descriptor that decodes a row's metadata the first time the attribute is
    read and stores the decoded value back on the row for later reads.
    """

    def __get__(self, row, owner):
        decoder = row._metadata_decoder
        # A decoder of None indicates that the metadata is already decoded.
        if decoder is None:
            return row._metadata
        # Some classes that use this are frozen, so values are written through
        # the module-level setattr helper rather than plain attribute assignment.
        __builtins__object__setattr__(row, "_metadata", decoder(row._metadata))
        __builtins__object__setattr__(row, "_metadata_decoder", None)
        return row._metadata

    def __set__(self, row, value):
        __builtins__object__setattr__(row, "_metadata", value)


def lazy_decode(own_init=False):
    # Decorator factory. When ``own_init`` is false the decorated class has its
    # ``__init__`` wrapped so that it accepts a ``metadata_decoder`` keyword;
    # when true the ``__init__`` is left untouched (presumably the class
    # records ``_metadata_decoder`` itself -- confirm against the callers).
    def _lazy_decode(cls):
        """
        Modifies a dataclass such that it lazily decodes metadata, if it is encoded.
        If the metadata passed to the constructor is encoded, a `metadata_decoder`
        parameter must also be passed.

        Returns a newly created class with the same name and bases as ``cls``,
        not ``cls`` itself.
        """
        if not own_init:
            wrapped_init = cls.__init__

            # Intercept the init to record the decoder
            def new_init(self, *args, metadata_decoder=None, **kwargs):
                # The decoder is stored before the wrapped init runs, so it is
                # already in place when that init assigns ``metadata``.
                __builtins__object__setattr__(
                    self, "_metadata_decoder", metadata_decoder
                )
                wrapped_init(self, *args, **kwargs)

            cls.__init__ = new_init

        # Add a descriptor to the class to decode and cache metadata
        cls.metadata = _CachedMetadata()

        # Add slots needed to the class
        # NOTE: ``extend`` mutates the ``__slots__`` object of ``cls`` in place,
        # so this requires ``__slots__`` to be a list.
        slots = cls.__slots__
        slots.extend(["_metadata", "_metadata_decoder"])
        # ``dict_`` collects the namespace for the rebuilt class;
        # ``sloted_members`` collects class attributes whose names are also slot
        # names and which are not the slot member descriptors themselves.
        dict_ = dict()
        sloted_members = dict()
        for k, v in cls.__dict__.items():
            if k not in slots:
                dict_[k] = v
            elif not isinstance(v, types.MemberDescriptorType):
                sloted_members[k] = v
        # Rebuild the class from the filtered namespace; the existing slot member
        # descriptors are dropped rather than copied into the new class.
        new_cls = type(cls.__name__, cls.__bases__, dict_)
        # Attributes that share a name with a slot are attached after creation
        # (presumably because a class body may not define a name that is also
        # listed in ``__slots__`` -- confirm).
        for k, v in sloted_members.items():
            setattr(new_cls, k, v)
        return new_cls

    return _lazy_decode


class MetadataProvider:
"""
Abstract superclass of container objects that provide metadata.
Expand Down
Loading

0 comments on commit 8c352a4

Please sign in to comment.