Skip to content

Commit

Permalink
ENH: support RGB in pcolormesh (wrapped case)
Browse files Browse the repository at this point in the history
  • Loading branch information
kylehofmann authored and rcomer committed Jun 5, 2023
1 parent bd42c77 commit 24ed4a8
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 20 deletions.
29 changes: 18 additions & 11 deletions lib/cartopy/mpl/geoaxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1922,24 +1922,31 @@ def _wrap_quadmesh(self, collection, **kwargs):
"map it must be fully transparent.",
stacklevel=3)

# The original data mask (regardless of wrapped cells)
C_mask = getattr(C, 'mask', None)
if C.ndim == 3:
# RGB(A) array.
if _MPL_VERSION.release < (3, 8):
raise ValueError(
"GeoQuadMesh wrapping requires Matplotlib v3.8 or later")
pcolormesh_data = C.copy()
pcolormesh_data[mask, -1] = 0 # TODO: add trailing dim for RGB
if C.shape[-1] == 3:
pcolormesh_data = ma.dstack((C, np.ones(C.shape[:2],
dtype=C.dtype)))
elif C.shape[-1] == 4:
pcolormesh_data = C.copy()
else:
raise ValueError("Last dimension of 3-dimensional input must"
f" have length 3 or 4, not {C.shape[-1]}")
if C_mask is not None:
full_mask = np.any(C_mask, axis=-1)
full_mask |= mask
else:
full_mask = mask
pcolormesh_data[full_mask, -1] = 0
collection.set_array(pcolormesh_data)

# mask needs an extra trailing dimension.
mask = np.broadcast_to(mask[:, :, np.newaxis], C.shape)
C_mask = None

# mask will need an extra trailing dimension later
mask = np.broadcast_to(mask[..., np.newaxis], C.shape)
else:

# The original data mask (regardless of wrapped cells)
C_mask = getattr(C, 'mask', None)

# create the masked array to be used with this pcolormesh
full_mask = mask if C_mask is None else mask | C_mask
pcolormesh_data = np.ma.array(C, mask=full_mask)
Expand Down
34 changes: 25 additions & 9 deletions lib/cartopy/tests/mpl/test_mpl_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,16 +244,21 @@ def test_cursor_values():
SKIP_PRE_MPL38 = pytest.mark.skipif(
MPL_VERSION.release[:2] < (3, 8), reason='mpl < 3.8')
PARAMETRIZE_PCOLORMESH_WRAP = pytest.mark.parametrize(
'as_rgba',
[False, pytest.param(True, marks=SKIP_PRE_MPL38)],
ids=['standard', 'rgba'])
'mesh_data_kind',
[
'standard',
pytest.param('rgb', marks=SKIP_PRE_MPL38),
pytest.param('rgba', marks=SKIP_PRE_MPL38),
],
ids=['standard', 'rgb', 'rgba'],
)


@PARAMETRIZE_PCOLORMESH_WRAP
@pytest.mark.natural_earth
@pytest.mark.mpl_image_compare(filename='pcolormesh_global_wrap1.png',
tolerance=1.27)
def test_pcolormesh_global_with_wrap1(as_rgba):
def test_pcolormesh_global_with_wrap1(mesh_data_kind):
# make up some realistic data with bounds (such as data from the UM)
nx, ny = 36, 18
xbnds = np.linspace(0, 360, nx, endpoint=True)
Expand All @@ -264,10 +269,12 @@ def test_pcolormesh_global_with_wrap1(as_rgba):
data = data[:-1, :-1]
fig = plt.figure()

if as_rgba:
if mesh_data_kind in ('rgb', 'rgba'):
cmap = plt.get_cmap()
norm = mcolors.Normalize()
data = cmap(norm(data))
if mesh_data_kind == 'rgb':
data = data[..., 0:3]

ax = fig.add_subplot(2, 1, 1, projection=ccrs.PlateCarree())
ax.pcolormesh(xbnds, ybnds, data, transform=ccrs.PlateCarree(), snap=False)
Expand Down Expand Up @@ -338,7 +345,7 @@ def test_pcolormesh_get_array_with_mask():
@pytest.mark.natural_earth
@pytest.mark.mpl_image_compare(filename='pcolormesh_global_wrap2.png',
tolerance=1.87)
def test_pcolormesh_global_with_wrap2(as_rgba):
def test_pcolormesh_global_with_wrap2(mesh_data_kind):
# make up some realistic data with bounds (such as data from the UM)
nx, ny = 36, 18
xbnds, xstep = np.linspace(0, 360, nx - 1, retstep=True, endpoint=True)
Expand All @@ -353,10 +360,12 @@ def test_pcolormesh_global_with_wrap2(as_rgba):
data = data[:-1, :-1]
fig = plt.figure()

if as_rgba:
if mesh_data_kind in ('rgb', 'rgba'):
cmap = plt.get_cmap()
norm = mcolors.Normalize()
data = cmap(norm(data))
if mesh_data_kind == 'rgb':
data = data[..., 0:3]

ax = fig.add_subplot(2, 1, 1, projection=ccrs.PlateCarree())
ax.pcolormesh(xbnds, ybnds, data, transform=ccrs.PlateCarree(), snap=False)
Expand All @@ -375,7 +384,7 @@ def test_pcolormesh_global_with_wrap2(as_rgba):
@pytest.mark.natural_earth
@pytest.mark.mpl_image_compare(filename='pcolormesh_global_wrap3.png',
tolerance=1.42)
def test_pcolormesh_global_with_wrap3(as_rgba):
def test_pcolormesh_global_with_wrap3(mesh_data_kind):
nx, ny = 33, 17
xbnds = np.linspace(-1.875, 358.125, nx, endpoint=True)
ybnds = np.linspace(91.25, -91.25, ny, endpoint=True)
Expand All @@ -393,10 +402,17 @@ def test_pcolormesh_global_with_wrap3(as_rgba):
data = np.ma.masked_greater(data, 2.6)
fig = plt.figure()

if as_rgba:
if mesh_data_kind in ('rgb', 'rgba'):
if mesh_data_kind == 'rgb':
mask = np.ma.getmaskarray(data)
cmap = plt.get_cmap()
norm = mcolors.Normalize()
data = cmap(norm(data))
if mesh_data_kind == 'rgb':
data = data[..., 0:3]
# Use data's mask as an alpha channel
mask = np.broadcast_to(mask[..., np.newaxis], data.shape).copy()
data = np.ma.array(data, mask=mask)

ax = fig.add_subplot(3, 1, 1, projection=ccrs.PlateCarree(-45))
c = ax.pcolormesh(xbnds, ybnds, data, transform=ccrs.PlateCarree(),
Expand Down

0 comments on commit 24ed4a8

Please sign in to comment.