Skip to content

Commit

Permalink
fix(time): allow zero length arrays as arguments
Browse files Browse the repository at this point in the history
This fixes many of the routines in caput.time to allow zero length
arrays as arguments. To do this it introduces a decorator `scalarize`
which allows code to be written which assumes it is receiving a non-zero
length array as an argument, while the decorator takes care of scalar
and zero length array arguments.
  • Loading branch information
jrs65 committed Oct 14, 2020
1 parent 956d3a1 commit 1a4b324
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 12 deletions.
82 changes: 82 additions & 0 deletions caput/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,88 @@ def __get__(self, obj, type=None):
return _vectorize_desc


def scalarize(dtype=np.float64):
"""Handle scalars and other iterables being passed to numpy requiring code.
Parameters
----------
dtype : np.dtype, optional
The output datatype. Used only to set the return type of zero-length arrays.
Returns
-------
vectorized_function : func
"""

class _scalarize_desc(object):
# See
# http://www.ianbicking.org/blog/2008/10/decorators-and-descriptors.html
# for a description of this pattern

def __init__(self, func):
# Save a reference to the function and set various properties so the
# docstrings etc. get passed through
self.func = func
self.__doc__ = func.__doc__
self.__name__ = func.__name__
self.__module__ = func.__module__

def __call__(self, *args, **kwargs):
# This gets called whenever the wrapped function is invoked

args, scalar, empty = zip(*[self._make_array(a) for a in args])

if all(empty):
return np.array([], dtype=dtype)

ret = self.func(*args, **kwargs)

if all(scalar):
ret = ret[0]

return ret

def _make_array(self, x):
# Change iterables to arrays and scalars into length-1 arrays

from skyfield import timelib

# Special handling for the slightly awkward skyfield types
if isinstance(x, timelib.Time):

if isinstance(x.tt, np.ndarray):
scalar = False
else:
scalar = True
x = x.ts.tt_jd(np.array([x.tt]))

elif isinstance(x, np.ndarray):
scalar = False

elif isinstance(x, (list, tuple)):
x = np.array(x)
scalar = False

else:
x = np.array([x])
scalar = True

return (x, scalar, len(x) == 0)

def __get__(self, obj, type=None):

# As a descriptor, this gets called whenever this is used to wrap a
# function, and simply binds it to the instance

if obj is None:
return self

new_func = self.func.__get__(obj, type)
return self.__class__(new_func)

return _scalarize_desc


def open_h5py_mpi(f, mode, use_mpi=True, comm=None):
"""Ensure that we have an h5py File object.
Expand Down
9 changes: 7 additions & 2 deletions caput/tests/test_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,9 @@ def test_lsd_array():
itimes = obs.lsd_to_unix(lsds)
assert times == approx(itimes, rel=1e-5, abs=1e-5)

# Check that it works with zero length arrays
assert obs.lsd_to_unix(np.array([])).size == 0


def test_datetime_to_string():
dt = datetime(2014, 4, 21, 16, 33, 12, 12356)
Expand All @@ -230,8 +233,7 @@ def test_string_to_datetime():
def test_from_unix_time():
"""Make sure we are properly parsing the unix time.
This is as much a test of ephem as our code. See issue #29 on the
PyEphem github page.
This is as much a test of Skyfield as our code.
"""

unix_time = random.random() * 2e6
Expand Down Expand Up @@ -357,3 +359,6 @@ def test_ensure_unix():
assert (ctime.ensure_unix(dt_list) == ut_array).all()
assert (ctime.ensure_unix(ut_array) == ut_array).all()
assert ctime.ensure_unix(sf_array) == approx(ut_array, rel=1e-10, abs=1e-4)

# Check that it works for zero length arrays
assert ctime.ensure_unix(np.array([])).size == 0
28 changes: 18 additions & 10 deletions caput/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,10 @@
import warnings

import numpy as np
from skyfield import timelib

from . import config
from .misc import vectorize
from .misc import vectorize, scalarize


# Approximate number of seconds in a sidereal second.
Expand Down Expand Up @@ -405,6 +406,7 @@ def transit_RA(self, time):
return ra


@scalarize()
def unix_to_skyfield_time(unix_time):
"""Formats the Unix time into a time that can be interpreted by ephem.
Expand Down Expand Up @@ -440,10 +442,11 @@ def unix_to_skyfield_time(unix_time):
return t


@scalarize()
def unix_to_era(unix_time):
"""Calculate the Earth Rotation Angle for a given time.
The Earth Rotation Angle is the angle between the Celetial and Terrestrial
The Earth Rotation Angle is the angle between the Celestial and Terrestrial
Intermediate origins, and is a modern replacement for the Greenwich Sidereal
Time.
Expand All @@ -467,6 +470,7 @@ def unix_to_era(unix_time):
return 360.0 * era


@scalarize()
def era_to_unix(era, time0):
"""Calculate the UNIX time for a given Earth Rotation Angle.
Expand Down Expand Up @@ -505,7 +509,7 @@ def era_to_unix(era, time0):
return time0 + diff_time - leap_seconds


@vectorize()
@vectorize(otypes=[object])
def unix_to_datetime(unix_time):
"""Converts unix time to a :class:`~datetime.datetime` object.
Expand All @@ -531,7 +535,7 @@ def unix_to_datetime(unix_time):
return naive_datetime_to_utc(dt)


@vectorize()
@vectorize(otypes=[np.float64])
def datetime_to_unix(dt):
"""Converts a :class:`~datetime.datetime` object to the unix time.
Expand Down Expand Up @@ -559,6 +563,7 @@ def datetime_to_unix(dt):
return since_epoch.total_seconds()


@vectorize(otypes=[np.unicode])
def datetime_to_timestr(dt):
"""Converts a :class:`~datetime.datetime` to "YYYYMMDDTHHMMSSZ" format.
Expand All @@ -582,6 +587,7 @@ def datetime_to_timestr(dt):
return dt.strftime("%Y%m%dT%H%M%SZ")


@vectorize(otypes=[object])
def timestr_to_datetime(time_str):
"""Converts date "YYYYMMDDTHHMMSS*" to a :class:`~datetime.datetime`.
Expand All @@ -603,6 +609,7 @@ def timestr_to_datetime(time_str):
return datetime.strptime(time_str[:15], "%Y%m%dT%H%M%S")


@scalarize(dtype=np.int64)
def leap_seconds_between(time_a, time_b):
"""Determine how many leap seconds occurred between two Unix times.
Expand Down Expand Up @@ -645,6 +652,7 @@ def leap_seconds_between(time_a, time_b):
return time_shift_int


@scalarize()
def ensure_unix(time):
"""Convert the input time to Unix time format.
Expand All @@ -659,11 +667,9 @@ def ensure_unix(time):
Output time.
"""

time0 = np.array(time).flatten()[0] if hasattr(time, "__len__") else time

if isinstance(time0, datetime):
if isinstance(time[0], datetime):
return datetime_to_unix(time)
elif isinstance(time0, basestring):
elif isinstance(time[0], basestring):
return datetime_to_unix(timestr_to_datetime(time))
else:

Expand All @@ -672,21 +678,23 @@ def ensure_unix(time):
try:
from skyfield import timelib

if isinstance(time0, timelib.Time):
if isinstance(time[0], timelib.Time):
return datetime_to_unix(time.utc_datetime())
except ImportError:
pass

# Finally try and convert into a float.
try:
return np.float64(time)
if np.issubdtype(time.dtype, np.number):
return time.astype(np.float64)
except TypeError:
raise TypeError("Could not convert %s into a UNIX time" % repr(type(time)))


_warned_utc_datetime = False


@vectorize(otypes=[object])
def naive_datetime_to_utc(dt):
"""Add UTC timezone info to a naive datetime.
Expand Down

0 comments on commit 1a4b324

Please sign in to comment.