Skip to content

Commit d7f73ab

Browse files
committed
Allow unit-ful image data
1 parent a142369 commit d7f73ab

File tree

7 files changed

+84
-14
lines changed

7 files changed

+84
-14
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
Unit converters can now support units on images
2+
-----------------------------------------------
3+
4+
Matplotlib now supports using `~.axes.Axes.imshow` to plot data with units.
5+
For this to be supported by third-party `~.units.ConversionInterface`s,
6+
the `~.units.ConversionInterface.default_units` and
7+
`~.units.ConversionInterface.convert` must allow for the *axis* argument to be
8+
a ``matplotlib.images._ImageBase`` object.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
`~.axes.Axes.imshow` can now be used with unit-ful data
2+
-------------------------------------------------------
3+
4+
`~.axes.Axes.imshow` can now be used with data that has units attached to it.

examples/units/basic_units.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -346,19 +346,16 @@ def convert(val, unit, axis):
346346
if np.iterable(val):
347347
if isinstance(val, np.ma.MaskedArray):
348348
val = val.astype(float).filled(np.nan)
349-
out = np.empty(len(val))
350-
for i, thisval in enumerate(val):
351-
if np.ma.is_masked(thisval):
352-
out[i] = np.nan
353-
else:
354-
try:
355-
out[i] = thisval.convert_to(unit).get_value()
356-
except AttributeError:
357-
out[i] = thisval
349+
out = np.empty(np.shape(val))
350+
masked_mask = np.ma.getmaskarray(val)
351+
out[masked_mask] = np.nan
352+
out[~masked_mask] = val[~masked_mask].convert_to(unit).get_value()
358353
return out
354+
359355
if np.ma.is_masked(val):
360356
return np.nan
361357
else:
358+
# Scalar
362359
return val.convert_to(unit).get_value()
363360

364361
@staticmethod

examples/units/units_image.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""
2+
=================
3+
Images with units
4+
=================
5+
Plotting images with units.
6+
7+
.. only:: builder_html
8+
9+
This example requires :download:`basic_units.py <basic_units.py>`
10+
"""
11+
import numpy as np
12+
import matplotlib.pyplot as plt
13+
from basic_units import secs
14+
15+
data = np.array([[1, 2],
16+
[3, 4]]) * secs
17+
18+
fig, ax = plt.subplots()
19+
image = ax.imshow(data)
20+
fig.colorbar(image)
21+
plt.show()

lib/matplotlib/colorbar.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -470,8 +470,6 @@ def __init__(self, ax, mappable=None, *, cmap=None,
470470
linewidths=[0.5 * mpl.rcParams['axes.linewidth']])
471471
self.ax.add_collection(self.dividers)
472472

473-
self.locator = None
474-
self.formatter = None
475473
self.__scale = None # linear, log10 for now. Hopefully more?
476474

477475
if ticklocation == 'auto':
@@ -481,15 +479,23 @@ def __init__(self, ax, mappable=None, *, cmap=None,
481479
self.set_label(label)
482480
self._reset_locator_formatter_scale()
483481

482+
self.locator = None
484483
if np.iterable(ticks):
485484
self.locator = ticker.FixedLocator(ticks, nbins=len(ticks))
486485
else:
487486
self.locator = ticks # Handle default in _ticker()
488487

488+
self.formatter = None
489489
if isinstance(format, str):
490490
self.formatter = ticker.FormatStrFormatter(format)
491-
else:
492-
self.formatter = format # Assume it is a Formatter or None
491+
elif format is not None:
492+
self.formatter = format # Assume it is a Formatter
493+
elif hasattr(mappable, 'converter') and hasattr(mappable, 'units'):
494+
# Set from mappable if it has a converter and units
495+
info = mappable.converter.axisinfo(mappable.units, self._long_axis)
496+
if info is not None and info.majfmt is not None:
497+
self.formatter = info.majfmt
498+
493499
self.draw_all()
494500

495501
if isinstance(mappable, contour.ContourSet) and not mappable.filled:

lib/matplotlib/image.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import matplotlib.colors as mcolors
1919
import matplotlib.cm as cm
2020
import matplotlib.cbook as cbook
21+
import matplotlib.units as munits
2122
# For clarity, names from _image are given explicitly in this module:
2223
import matplotlib._image as _image
2324
# For user convenience, the names from _image are also imported into
@@ -696,6 +697,7 @@ def set_data(self, A):
696697
"""
697698
if isinstance(A, PIL.Image.Image):
698699
A = pil_to_array(A) # Needed e.g. to apply png palette.
700+
A = self._convert_units(A)
699701
self._A = cbook.safe_masked_invalid(A, copy=True)
700702

701703
if (self._A.dtype != np.uint8 and
@@ -733,6 +735,30 @@ def set_data(self, A):
733735
self._rgbacache = None
734736
self.stale = True
735737

738+
def _convert_units(self, A):
739+
# Take the first element since units expects a 1D sequence, not 2D
740+
converter = munits.registry.get_converter(A[0])
741+
if converter is None:
742+
return A
743+
744+
try:
745+
units = converter.default_units(A, self)
746+
except Exception as e:
747+
raise RuntimeError(
748+
f'{converter} failed when trying to return the default units '
749+
f'for this image. This may be because {converter} has not '
750+
'implemented support for images in the default_units() method.'
751+
) from e
752+
753+
try:
754+
return converter.convert(A, units, self)
755+
except Exception as e:
756+
raise RuntimeError(
757+
f'{converter} failed when trying to convert the units '
758+
f'for this image. This may be because {converter} has not '
759+
'implemented support for images in the convert() method.'
760+
) from e
761+
736762
def set_array(self, A):
737763
"""
738764
Retained for backwards compatibility - use set_data instead.

lib/matplotlib/units.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,12 @@ def axisinfo(unit, axis):
118118

119119
@staticmethod
120120
def default_units(x, axis):
121-
"""Return the default unit for *x* or ``None`` for the given axis."""
121+
"""
122+
Return the default unit for *x* or ``None``.
123+
124+
*axis* can be either an `Axis` or an ``_ImageBase`` (if units of a 2D
125+
image are being converted).
126+
"""
122127
return None
123128

124129
@staticmethod
@@ -128,6 +133,9 @@ def convert(obj, unit, axis):
128133
129134
If *obj* is a sequence, return the converted sequence. The output must
130135
be a sequence of scalars that can be used by the numpy array layer.
136+
137+
*axis* can be either an `Axis` or an ``_ImageBase`` (if units of a 2D
138+
image are being converted).
131139
"""
132140
return obj
133141

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy