Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

link FITS_rec instances to hdu extensions on save #178

Merged
merged 1 commit into from
Jul 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ Other
TARGCAT and TARGDESC, which record the target category and description
as given by the user in the APT. [#179]

Bug Fixes
---------

- Link FITS_rec instances to created HDU on save to avoid data duplication. [#178]


1.7.0 (2023-06-29)
==================
Expand Down
33 changes: 23 additions & 10 deletions src/stdatamodels/fits_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,8 @@ def _fits_array_writer(fits_context, validator, _, instance, schema):
if instance is None:
return

instance_id = id(instance)

instance = np.asanyarray(instance)

if not len(instance.shape):
Expand All @@ -297,6 +299,10 @@ def _fits_array_writer(fits_context, validator, _, instance, schema):
index=index, hdu_type=hdu_type)

hdu.data = instance
if instance_id in fits_context.extension_array_links:
if fits_context.extension_array_links[instance_id]() is not hdu:
raise ValueError("Linking one array to multiple hdus is not supported")
fits_context.extension_array_links[instance_id] = weakref.ref(hdu)
hdu.ver = index + 1


Expand Down Expand Up @@ -331,6 +337,7 @@ def __init__(self, hdulist):
self.hdulist = weakref.ref(hdulist)
self.comment_stack = []
self.sequence_index = None
self.extension_array_links = {}


def _get_validators(hdulist):
Expand All @@ -350,7 +357,7 @@ def _get_validators(hdulist):
'type': partial(_fits_type, fits_context),
})

return validators
return validators, fits_context


def _save_from_schema(hdulist, tree, schema):
Expand All @@ -370,24 +377,30 @@ def datetime_callback(node, json_id):
else:
kwargs = {}

validator = asdf_schema.get_validator(
schema, None, _get_validators(hdulist), **kwargs)
validators, context = _get_validators(hdulist)
validator = asdf_schema.get_validator(schema, None, validators, **kwargs)

# This actually kicks off the saving
validator.validate(tree, _schema=schema)

# Replace arrays in the tree that are identical to HDU arrays
# with ndarray-1.0.0 tagged objects with special source values
# that represent links to the surrounding FITS file.
def ndarray_callback(node, json_id):
if (isinstance(node, (np.ndarray, NDArrayType))):
# Now link extensions to items in the tree

def callback(node, json_id):
if id(node) in context.extension_array_links:
hdu = context.extension_array_links[id(node)]()
return _create_tagged_dict_for_fits_array(hdu, hdulist.index(hdu))
elif isinstance(node, (np.ndarray, NDArrayType)):
# in addition to links generated during validation
# replace arrays in the tree that are identical to HDU arrays
# with ndarray-1.0.0 tagged objects with special source values
# that represent links to the surrounding FITS file.
# This is important for general ASDF-in-FITS support
for hdu_index, hdu in enumerate(hdulist):
if hdu.data is not None and node is hdu.data:
return _create_tagged_dict_for_fits_array(hdu, hdu_index)

return node

tree = treeutil.walk_and_modify(tree, ndarray_callback)
tree = treeutil.walk_and_modify(tree, callback)

return tree

Expand Down
48 changes: 48 additions & 0 deletions tests/test_fits.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import re

import pytest
from astropy.io import fits
import numpy as np
Expand Down Expand Up @@ -627,3 +629,49 @@ def test_resave_duplication_bug(tmp_path):

with fits.open(fn1) as ff1, fits.open(fn2) as ff2:
assert ff1['ASDF'].size == ff2['ASDF'].size


def test_table_linking(tmp_path):
    """Saving a FITS_rec table should link it to its HDU in the embedded ASDF.

    After ``DataModel.save`` the serialized ASDF tree must reference the
    table data via a FITS source link (``source: fits:...``) rather than
    storing a duplicate copy of the array in an ASDF block.
    """
    file_path = tmp_path / "test.fits"

    # Minimal schema declaring a single two-column table extension.
    schema = {
        'title': 'Test data model',
        '$schema': 'http://stsci.edu/schemas/fits-schema/fits-schema',
        'type': 'object',
        'properties': {
            'meta': {
                'type': 'object',
                'properties': {}
            },
            'test_table': {
                'title': 'Test table',
                'fits_hdu': 'TESTTABL',
                'datatype': [
                    {'name': 'A_COL', 'datatype': 'int8'},
                    {'name': 'B_COL', 'datatype': 'int8'}
                ]
            }
        }
    }

    with DataModel(schema=schema) as dm:
        records = np.array(
            [(1, 2), (3, 4)],
            dtype=[('A_COL', 'i1'), ('B_COL', 'i1')],
        )

        # Assignment through the model converts the plain ndarray
        # into a FITS_rec.
        dm.test_table = records
        assert isinstance(dm.test_table, fits.FITS_rec)

        # Write the model (including the table) to disk.
        dm.save(file_path)

        # Reopen the file and verify the table was linked to an HDU.
        with fits.open(file_path) as ff:
            # Raw bytes of the embedded ASDF extension.
            asdf_bytes = ff['ASDF'].data.tobytes()

            # Keep only the YAML tree (everything before the first
            # end-of-document marker '...'), discarding binary blocks.
            tree_string = asdf_bytes.split(b'...')[0].decode('ascii')

            # Any 'source:' whose value does not start with 'f'
            # (i.e. not 'fits:...') is an unlinked, duplicated array.
            unlinked_arrays = re.findall(r'source:\s+[^f]', tree_string)
            assert not len(unlinked_arrays), unlinked_arrays