Skip to content

Commit

Permalink
add complicated logic to handle spectral-axis boolean array masking
Browse files Browse the repository at this point in the history
  • Loading branch information
keflavich committed May 13, 2016
1 parent 60ddb26 commit 4264006
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 12 deletions.
82 changes: 72 additions & 10 deletions spectral_cube/masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def dims_to_skip(shp1, shp2):
# b==1 is broadcastable but not desired
if a == 1:
dims.append(len(shp2) - ii - 1)
elif a == b:
elif a == b:
pass
else:
raise ValueError("This should not be possible")
Expand All @@ -75,6 +75,10 @@ def view_of_subset(shp1, shp2, view):
Given two shapes and a view, assuming that shape 1 can be broadcast
to shape 2, return the sub-view that applies to shape 1
"""
# if the view is 1-dimensional, we can't subset it
if not hasattr(view, '__len__'):
return view

dts = dims_to_skip(shp1, shp2)
if view:
cv_view = [x for ii,x in enumerate(view) if ii not in dts]
Expand All @@ -85,6 +89,25 @@ def view_of_subset(shp1, shp2, view):

return cv_view

def view_skipping_nulldims(shp1, shp2, view):
"""
Assuming shape 1 is broadcastable to shape 2, replace all elements of
'view' with 'None' (e.g., x[:]) if shp1 is broadcasting along that
dimension
"""
# if the view is 1-dimensional, we can't subset it
if not hasattr(view, '__len__'):
return view

dts = dims_to_skip(shp1, shp2)
if view:
new_view = [(slice(None) if ii in dts else x)
for ii,x in enumerate(view)]
else:
# if no view is specified, slice None in all dimensions
new_view = [x for ii,x in enumerate([slice(None)]*3)]

return new_view

class MaskBase(object):

Expand Down Expand Up @@ -343,13 +366,47 @@ def __init__(self, mask, wcs, shape=None, include=True):
if shape is not None and not is_broadcastable_and_smaller(mask.shape, shape):
raise ValueError("Mask cannot be broadcast to the specified shape.")
self._shape = shape or mask.shape
n_extra_dims = (len(self._shape)-mask.ndim)
if n_extra_dims > 0:
strides = (0,)*n_extra_dims + mask.strides
self._mask = as_strided(mask, shape=self.shape,
strides=strides)
else:
self._mask = mask
self._mask = mask

"""
Developer note (AG):
The logic in this following section seems overly complicated. All
of it is added to make sure that a 1D boolean array along the
spectral axis can be created. I thought this was possible
previously, but experience many errors in my latest attempt to use
one.
"""
# If a shape is given, we may need to broadcast to that shape
if shape is not None:
# these are dimensions that simply don't exist
n_empty_dims = (len(self._shape)-mask.ndim)

# these are dimensions of shape 1 that would be squeezed away but may
# be needed to make the arrays broadcastable (e.g., mask[:,None,None])
# Need to add n_empty_dims because (1,2) will broadcast to (3,1,2)
# and there will be no extra dims.
extra_dims = [ii
for ii,(sh1,sh2) in
enumerate(zip((0,)*n_empty_dims + mask.shape, shape))
if sh1 == 1 and sh1 != sh2]


# Add the [None,]'s and the nonexistant
n_extra_dims = n_empty_dims + len(extra_dims)

# if there are no extra dims, we're done, the original shape is fine
if n_extra_dims > 0:
strides = (0,)*n_empty_dims + mask.strides

for ed in extra_dims:
# all of the [None,] dims should have 0 stride
assert strides[ed] == 0,"Stride shape failure"

self._mask = as_strided(mask, shape=self.shape,
strides=strides)

# Make sure the mask shape matches the Mask object shape
assert self._mask.shape == self.shape,"Shape initialization failure"

def _validate_wcs(self, new_data=None, new_wcs=None, **kwargs):
"""
Expand All @@ -373,11 +430,16 @@ def _validate_wcs(self, new_data=None, new_wcs=None, **kwargs):
self._wcs_whitelist.add(new_wcs)

def _include(self, data=None, wcs=None, view=()):
result_mask = self._mask[view]
#sub_view = view_skipping_nulldims(self._mask.shape, data.shape, view)
#print("view={0} sub_view={1}".format(view,sub_view))
sub_view = view
result_mask = self._mask[sub_view]
return result_mask if self._mask_type == 'include' else ~result_mask

def _exclude(self, data=None, wcs=None, view=()):
result_mask = self._mask[view]
#sub_view = view_skipping_nulldims(self._mask.shape, data.shape, view)
sub_view = view
result_mask = self._mask[sub_view]
return result_mask if self._mask_type == 'exclude' else ~result_mask

@property
Expand Down
5 changes: 3 additions & 2 deletions spectral_cube/spectral_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -2449,13 +2449,14 @@ def __init__(self, *args, **kwargs):

beam_mask = BooleanArrayMask(goodbeams[:,None,None],
wcs=self._wcs,
shape=(len(beams),1,1),
shape=self.shape,
)
if not is_broadcastable_and_smaller(beam_mask.shape,
self._data.shape):
# this should never be allowed to happen
raise ValueError("Beam mask shape is not broadcastable to data shape: "
"%s vs %s" % (beam_mask.shape, self._data.shape))
assert beam_mask.shape == self.shape

new_mask = self._mask & beam_mask

Expand All @@ -2474,7 +2475,7 @@ def __init__(self, *args, **kwargs):
def __getitem__(self, view):

# Need to allow self[:], self[:,:]
if isinstance(view, (slice,int)):
if isinstance(view, (slice,int,np.int64)):
view = (view, slice(None), slice(None))
elif len(view) == 2:
view = view + (slice(None),)
Expand Down

0 comments on commit 4264006

Please sign in to comment.