Skip to content

Commit d130006

Browse files
committed
simplify to_rgba() by extracting the part relating to RGBA data
1 parent e7d53e0 commit d130006

File tree

3 files changed

+52
-49
lines changed

3 files changed

+52
-49
lines changed

lib/matplotlib/artist.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1446,7 +1446,7 @@ def set_array(self, A):
14461446
A : array-like or None
14471447
The values that are mapped to colors.
14481448
1449-
The base class `.VectorMappable` does not make any assumptions on
1449+
The base class `.ColorizingArtist` does not make any assumptions on
14501450
the dimensionality and shape of the value array *A*.
14511451
"""
14521452
if A is None:
@@ -1466,7 +1466,7 @@ def get_array(self):
14661466
"""
14671467
Return the array of values, that are mapped to colors.
14681468
1469-
The base class `.VectorMappable` does not make any assumptions on
1469+
The base class `.ColorizingArtist` does not make any assumptions on
14701470
the dimensionality and shape of the array.
14711471
"""
14721472
return self._A

lib/matplotlib/cm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ class ScalarMappable(colorizer.ColorizerShim):
279279
"""
280280
A mixin class to map one or multiple sets of scalar data to RGBA.
281281
282-
The VectorMappable applies data normalization before returning RGBA colors
282+
The ScalarMappable applies data normalization before returning RGBA colors
283283
from the given `~matplotlib.colors.Colormap`, `~matplotlib.colors.BivarColormap`,
284284
or `~matplotlib.colors.MultivarColormap`.
285285
"""

lib/matplotlib/colorizer.py

Lines changed: 49 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
class Colorizer():
2525
"""
2626
Class that holds the data to color pipeline
27-
accessible via `.to_rgba(A)` and executed via
28-
the `.norm` and `.cmap` attributes.
27+
accessible via `Colorizer.to_rgba(A)` and executed via
28+
the `Colorizer.norm` and `Colorizer.cmap` attributes.
2929
"""
3030
def __init__(self, cmap=None, norm=None):
3131

@@ -125,56 +125,59 @@ def to_rgba(self, x, alpha=None, bytes=False, norm=True):
125125
126126
"""
127127
# First check for special case, image input:
128-
# First check for special case, image input:
129-
try:
130-
if x.ndim == 3:
131-
if x.shape[2] == 3:
132-
if alpha is None:
133-
alpha = 1
134-
if x.dtype == np.uint8:
135-
alpha = np.uint8(alpha * 255)
136-
m, n = x.shape[:2]
137-
xx = np.empty(shape=(m, n, 4), dtype=x.dtype)
138-
xx[:, :, :3] = x
139-
xx[:, :, 3] = alpha
140-
elif x.shape[2] == 4:
141-
xx = x
142-
else:
143-
raise ValueError("Third dimension must be 3 or 4")
144-
if xx.dtype.kind == 'f':
145-
# If any of R, G, B, or A is nan, set to 0
146-
if np.any(nans := np.isnan(x)):
147-
if x.shape[2] == 4:
148-
xx = xx.copy()
149-
xx[np.any(nans, axis=2), :] = 0
150-
151-
if norm and (xx.max() > 1 or xx.min() < 0):
152-
raise ValueError("Floating point image RGB values "
153-
"must be in the 0..1 range.")
154-
if bytes:
155-
xx = (xx * 255).astype(np.uint8)
156-
elif xx.dtype == np.uint8:
157-
if not bytes:
158-
xx = xx.astype(np.float32) / 255
159-
else:
160-
raise ValueError("Image RGB array must be uint8 or "
161-
"floating point; found %s" % xx.dtype)
162-
# Account for any masked entries in the original array
163-
# If any of R, G, B, or A are masked for an entry, we set alpha to 0
164-
if np.ma.is_masked(x):
165-
xx[np.any(np.ma.getmaskarray(x), axis=2), 3] = 0
166-
return xx
167-
except AttributeError:
168-
# e.g., x is not an ndarray; so try mapping it
169-
pass
170-
171-
# This is the normal case, mapping a scalar array:
128+
if isinstance(x, np.ndarray) and x.ndim == 3:
129+
return self._pass_image_data(x, alpha, bytes, norm)
130+
131+
# Otherwise run norm -> colormap pipeline
172132
x = ma.asarray(x)
173133
if norm:
174134
x = self.norm(x)
175135
rgba = self.cmap(x, alpha=alpha, bytes=bytes)
176136
return rgba
177137

138+
@staticmethod
139+
def _pass_image_data(x, alpha=None, bytes=False, norm=True):
140+
"""
141+
Helper function to pass ndarray of shape (...,3) or (..., 4)
142+
through `to_rgba()`, see `to_rgba()` for docstring.
143+
"""
144+
if x.shape[2] == 3:
145+
if alpha is None:
146+
alpha = 1
147+
if x.dtype == np.uint8:
148+
alpha = np.uint8(alpha * 255)
149+
m, n = x.shape[:2]
150+
xx = np.empty(shape=(m, n, 4), dtype=x.dtype)
151+
xx[:, :, :3] = x
152+
xx[:, :, 3] = alpha
153+
elif x.shape[2] == 4:
154+
xx = x
155+
else:
156+
raise ValueError("Third dimension must be 3 or 4")
157+
if xx.dtype.kind == 'f':
158+
# If any of R, G, B, or A is nan, set to 0
159+
if np.any(nans := np.isnan(x)):
160+
if x.shape[2] == 4:
161+
xx = xx.copy()
162+
xx[np.any(nans, axis=2), :] = 0
163+
164+
if norm and (xx.max() > 1 or xx.min() < 0):
165+
raise ValueError("Floating point image RGB values "
166+
"must be in the 0..1 range.")
167+
if bytes:
168+
xx = (xx * 255).astype(np.uint8)
169+
elif xx.dtype == np.uint8:
170+
if not bytes:
171+
xx = xx.astype(np.float32) / 255
172+
else:
173+
raise ValueError("Image RGB array must be uint8 or "
174+
"floating point; found %s" % xx.dtype)
175+
# Account for any masked entries in the original array
176+
# If any of R, G, B, or A are masked for an entry, we set alpha to 0
177+
if np.ma.is_masked(x):
178+
xx[np.any(np.ma.getmaskarray(x), axis=2), 3] = 0
179+
return xx
180+
178181
def normalize(self, x):
179182
"""
180183
Normalize the data in x.

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy