Skip to content

Commit

Permalink
Updates to address most comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
mairanteodoro committed Jul 12, 2023
1 parent f4f379d commit afa9e7a
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 55 deletions.
1 change: 1 addition & 0 deletions src/stcal/alignment/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .util import *
140 changes: 111 additions & 29 deletions src/stcal/alignment/util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Utility function for assign_wcs.
Common utility functions for datamodel alignment.
"""
import logging
Expand All @@ -22,12 +22,10 @@
log.setLevel(logging.DEBUG)


_MAX_SIP_DEGREE = 6


__all__ = [
"wcs_from_footprints",
"compute_scale",
"compute_fiducial",
"calc_rotation_matrix",
]

Expand All @@ -53,10 +51,12 @@ def compute_scale(
Reference WCS object from which to compute a scaling factor.
fiducial : tuple
Input fiducial of (RA, DEC) or (RA, DEC, Wavelength) used in calculating reference points.
Input fiducial of (RA, DEC) or (RA, DEC, Wavelength) used in calculating
reference points.
disp_axis : int
Dispersion axis integer. Assumes the same convention as `wcsinfo.dispersion_direction`
Dispersion axis integer. Assumes the same convention as
`wcsinfo.dispersion_direction`
pscale_ratio : int
Ratio of input to output pixel scale
Expand Down Expand Up @@ -104,7 +104,7 @@ def compute_scale(


def calc_rotation_matrix(
roll_ref: float, v3i_yang: float, vparity: int = 1
roll_ref: float, v3i_yangle: float, vparity: int = 1
) -> List[float]:
"""Calculate the rotation matrix.
Expand All @@ -113,7 +113,7 @@ def calc_rotation_matrix(
roll_ref : float
Telescope roll angle of V3 North over East at the ref. point in radians
v3i_yang : float
v3i_yangle : float
The angle between ideal Y-axis and V3 in radians.
vparity : int
Expand All @@ -128,17 +128,14 @@ def calc_rotation_matrix(
Notes
-----
The rotation is
pc1_1 | pc2_1
----------------
| pc1_1 pc2_1 |
| pc1_2 pc2_2 |
----------------
pc1_2 | pc2_2
"""
if vparity not in (1, -1):
raise ValueError(f"vparity should be 1 or -1. Input was: {vparity}")

rel_angle = roll_ref - (vparity * v3i_yang)
rel_angle = roll_ref - (vparity * v3i_yangle)

pc1_1 = vparity * np.cos(rel_angle)
pc1_2 = np.sin(rel_angle)
Expand All @@ -149,8 +146,22 @@ def calc_rotation_matrix(


def _calculate_fiducial_from_spatial_footprint(
spatial_footprint, fiducial, spatial_axes
):
spatial_footprint: np.ndarray,
) -> np.ndarray:
"""
Calculates the fiducial coordinates from a given spatial footprint.
Parameters
----------
spatial_footprint : `~numpy.ndarray`
A 2xN array containing the world coordinates of the WCS footprint's
bounding box, where N is the number of bounding box positions.
Returns
-------
lon_fiducial, lat_fiducial : `numpy.ndarray`, `numpy.ndarray`
The world coordinates of the fiducial point in the output coordinate frame.
"""
lon, lat = spatial_footprint
lon, lat = np.deg2rad(lon), np.deg2rad(lat)
x = np.cos(lat) * np.cos(lon)
Expand All @@ -164,14 +175,34 @@ def _calculate_fiducial_from_spatial_footprint(
lat_fiducial = np.rad2deg(
np.arctan2(z_mid, np.sqrt(x_mid**2 + y_mid**2))
)
fiducial[spatial_axes] = lon_fiducial, lat_fiducial
return lon_fiducial, lat_fiducial


def compute_fiducial(wcslist, bounding_box=None):
def compute_fiducial(wcslist: list, bounding_box=None) -> np.ndarray:
"""
For a celestial footprint this is the center.
For a spectral footprint, it is the beginning of the range.
Calculates the world coordinates of the fiducial point of a list of WCS objects.
For a celestial footprint this is the center. For a spectral footprint, it is the
beginning of its range.
Parameters
----------
wcslist : list
A list containing all the WCS objects for which the fiducial is to be
calculated.
bounding_box : `~astropy.modeling.bounding_box` or list, optional
The bounding box to be used when calculating the fiducial.
If a list is provided, it should be in the following format:
[[x0_lower, x0_upper], [x1_lower, x1_upper]].
Returns
-------
fiducial : `numpy.ndarray`
A two-elements array containing the world coordinates of the fiducial point
in the combined output coordinate frame.
Notes
-----
This function assumes all WCSs have the same output coordinate frame.
"""

Expand All @@ -186,8 +217,8 @@ def compute_fiducial(wcslist, bounding_box=None):

fiducial = np.empty(len(axes_types))
if spatial_footprint.any():
_calculate_fiducial_from_spatial_footprint(
spatial_footprint, fiducial, spatial_axes
fiducial[spatial_axes] = _calculate_fiducial_from_spatial_footprint(
spatial_footprint
)
if spectral_footprint.any():
fiducial[spectral_axes] = spectral_footprint.min()
Expand All @@ -196,12 +227,17 @@ def compute_fiducial(wcslist, bounding_box=None):

def wcsinfo_from_model(input_model: SupportsDataWithWcs):
"""
Create a dict {wcs_keyword: array_of_values} pairs from a data model.
Creates a dict {wcs_keyword: array_of_values} pairs from a data model.
Parameters
----------
input_model : `~stdatamodels.jwst.datamodels.JwstDataModel`
The input data model
The input data model.
Returns
-------
wcsinfo : dict
A dict containing the WCS FITS keywords and corresponding values.
"""
defaults = {
Expand All @@ -221,7 +257,7 @@ def wcsinfo_from_model(input_model: SupportsDataWithWcs):
val.append(v)
wcsinfo[key] = np.array(val)

pc = np.zeros((wcsaxes, wcsaxes))
pc = np.zeros((wcsaxes, wcsaxes), dtype=np.float32)
for i in range(1, wcsaxes + 1):
for j in range(1, wcsaxes + 1):
pc[i - 1, j - 1] = getattr(
Expand All @@ -234,8 +270,45 @@ def wcsinfo_from_model(input_model: SupportsDataWithWcs):


def _generate_tranform_from_datamodel(
refmodel, pscale_ratio, pscale, rotation, ref_fiducial
refmodel: SupportsDataWithWcs,
ref_fiducial: np.array,
pscale_ratio: int = None,
pscale: float = None,
rotation: float = None,
):
"""
Creates a transform from pixel to world coordinates based on a
reference datamodel's WCS.
Parameters
----------
refmodel : a valid datamodel
The datamodel that should be used as reference for calculating the
transform parameters.
pscale_ratio : int, None, optional
Ratio of input to output pixel scale. This parameter is only used when
pscale=`None` and, in that case, it is passed on to `compute_scale`.
pscale : float, None, optional
The plate scale. If `None`, the plate scale is calculated from the reference
datamodel.
rotation : float, None, optional
Position angle of output image's Y-axis relative to North.
A value of 0.0 would orient the final output image to be North up.
The default of `None` specifies that the images will not be rotated,
but will instead be resampled in the default orientation for the camera
with the x and y axes of the resampled image corresponding
approximately to the detector axes. Ignored when ``transform`` is
provided. If `None`, the rotation angle is extracted from the
reference model's `meta.wcsinfo.roll_ref`.
ref_fiducial : np.array
A two-elements array containing the world coordinates of the fiducial point.
Returns
-------
transform : `~astropy.modeling.core.CompoundModel`
An `~astropy.modeling` compound model containing the transform from pixel to
world coordinates.
"""
sky_axes = refmodel.meta.wcs._get_axes_indices().tolist()

# Need to put the rotation matrix (List[float, float, float, float])
Expand Down Expand Up @@ -298,9 +371,9 @@ def wcs_from_footprints(
Parameters
----------
dmodels : list of `~jwst.datamodels.JwstDataModel`
dmodels : list of valid datamodels
A list of data models.
refmodel : `~jwst.datamodels.JwstDataModel`, optional
refmodel : a valid datamodel, optional
This model's WCS is used as a reference.
WCS. The output coordinate frame, the projection and a
scaling and rotation transform is created from it. If not supplied
Expand Down Expand Up @@ -338,6 +411,11 @@ def wcs_from_footprints(
Right ascension and declination of the reference pixel. Automatically
computed if not provided.
Returns
-------
wnew : `~gwcs.WCS`
The WCS object associated with the combined input footprints.
"""
bb = bounding_box
wcslist = [im.meta.wcs for im in dmodels]
Expand Down Expand Up @@ -368,7 +446,11 @@ def wcs_from_footprints(

if transform is None:
transform = _generate_tranform_from_datamodel(
refmodel, pscale_ratio, pscale, rotation, ref_fiducial
refmodel=refmodel,
pscale_ratio=pscale_ratio,
pscale=pscale,
rotation=rotation,
ref_fiducial=ref_fiducial,
)

out_frame = refmodel.meta.wcs.output_frame
Expand Down
51 changes: 25 additions & 26 deletions tests/test_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,16 @@ def _create_wcs_object_without_distortion(
pscale,
shape,
):
fiducial_detector = tuple(shape.value)

# subtract 1 to account for pixel indexing starting at 0
shift = models.Shift(-(fiducial_detector[0] - 1)) & models.Shift(
-(fiducial_detector[1] - 1)
)
shift = models.Shift(-(shape[0] - 1)) & models.Shift(-(shape[1] - 1))

scale = models.Scale(pscale[0].to("deg")) & models.Scale(
pscale[1].to("deg")
)
scale = models.Scale(pscale[0]) & models.Scale(pscale[1])

tan = models.Pix2Sky_TAN()
celestial_rotation = models.RotateNative2Celestial(
fiducial_world[0],
fiducial_world[1],
180 * u.deg,
180,
)

det2sky = shift | scale | tan | celestial_rotation
Expand All @@ -54,8 +48,8 @@ def _create_wcs_object_without_distortion(
wcs_obj = WCS(pipeline)

wcs_obj.bounding_box = (
(-0.5, fiducial_detector[0] - 0.5),
(-0.5, fiducial_detector[0] - 0.5),
(-0.5, shape[0] - 0.5),
(-0.5, shape[0] - 0.5),
)

return wcs_obj
Expand All @@ -65,7 +59,7 @@ def _create_wcs_and_datamodel(fiducial_world, shape, pscale):
wcs = _create_wcs_object_without_distortion(
fiducial_world=fiducial_world, shape=shape, pscale=pscale
)
ra_ref, dec_ref = fiducial_world[0].value, fiducial_world[1].value
ra_ref, dec_ref = fiducial_world[0], fiducial_world[1]
return DataModel(
ra_ref=ra_ref,
dec_ref=dec_ref,
Expand Down Expand Up @@ -121,9 +115,9 @@ def test_compute_fiducial():
WCS's footprint.
"""

shape = (3, 3) * u.pix
fiducial_world = (0, 0) * u.deg
pscale = (0.05, 0.05) * u.arcsec
shape = (3, 3) # in pixels
fiducial_world = (0, 0) # in deg
pscale = (0.000014, 0.000014) # in deg/pixel

wcs = _create_wcs_object_without_distortion(
fiducial_world=fiducial_world, shape=shape, pscale=pscale
Expand All @@ -134,35 +128,40 @@ def test_compute_fiducial():
assert all(np.isclose(wcs(1, 1), computed_fiducial))


@pytest.mark.parametrize("pscales", [(0.05, 0.05), (0.1, 0.05)])
@pytest.mark.parametrize(
"pscales", [(0.000014, 0.000014), (0.000028, 0.000014)]
)
def test_compute_scale(pscales):
"""Test that util.compute_scale can properly determine the pixel scale of a
WCS object.
"""
shape = (3, 3) * u.pix
fiducial_world = (0, 0) * u.deg
pscale = (pscales[0], pscales[1]) * u.arcsec
shape = (3, 3) # in pixels
fiducial_world = (0, 0) # in deg
pscale = (pscales[0], pscales[1]) # in deg/pixel

wcs = _create_wcs_object_without_distortion(
fiducial_world=fiducial_world, shape=shape, pscale=pscale
)
expected_scale = np.sqrt(pscale[0].to("deg") * pscale[1].to("deg")).value
expected_scale = np.sqrt(pscale[0] * pscale[1])

computed_scale = compute_scale(wcs=wcs, fiducial=fiducial_world.value)
computed_scale = compute_scale(wcs=wcs, fiducial=fiducial_world)

assert np.isclose(expected_scale, computed_scale)


def test_wcs_from_footprints():
shape = (3, 3) * u.pix
fiducial_world = (10, 0) * u.deg
pscale = (0.1, 0.1) * u.arcsec
shape = (3, 3) # in pixels
fiducial_world = (10, 0) # in deg
pscale = (0.000028, 0.000028) # in deg/pixel

dm_1 = _create_wcs_and_datamodel(fiducial_world, shape, pscale)
wcs_1 = dm_1.meta.wcs

# new fiducial will be shifted by one pixel in both directions
fiducial_world -= pscale
# shift fiducial by one pixel in both directions and create a new WCS
fiducial_world = (
fiducial_world[0] - 0.000028,
fiducial_world[1] - 0.000028,
)
dm_2 = _create_wcs_and_datamodel(fiducial_world, shape, pscale)
wcs_2 = dm_2.meta.wcs

Expand Down

0 comments on commit afa9e7a

Please sign in to comment.