Skip to content

Commit babbee0

Browse files
committed
Updated input types for MultiNorm.__call__()
1 parent c6cf321 commit babbee0

File tree

2 files changed

+182
-12
lines changed

2 files changed

+182
-12
lines changed

lib/matplotlib/colors.py

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3413,10 +3413,9 @@ def __call__(self, values, clip=None, structured_output=None):
34133413
Parameters
34143414
----------
34153415
values : array-like
3416-
Data to normalize, as tuple, scalar array or structured array.
3416+
Data to normalize, as tuple or list or structured array.
34173417
3418-
- If tuple, must be of length `n_components`
3419-
- If scalar array, the first axis must be of length `n_components`
3418+
- If tuple or list, must be of length `n_components`
34203419
- If structured array, must have `n_components` fields.
34213420
34223421
clip : list of bools or bool or None, optional
@@ -3530,22 +3529,72 @@ def _iterable_components_in_data(data, n_components):
35303529
Parameters
35313530
----------
35323531
data : np.ndarray, tuple or list
3533-
The input array. It must either be an array with n_components fields or have
3534-
a length (n_components)
3532+
The input data, as a tuple or list or structured array.
3533+
3534+
- If tuple or list, must be of length `n_components`
3535+
- If structured array, must have `n_components` fields.
35353536
35363537
Returns
35373538
-------
35383539
tuple of np.ndarray
35393540
35403541
"""
3541-
if isinstance(data, np.ndarray) and data.dtype.fields is not None:
3542-
data = tuple(data[descriptor[0]] for descriptor in data.dtype.descr)
3543-
if len(data) != n_components:
3544-
raise ValueError("The input to this `MultiNorm` must be of shape "
3545-
f"({n_components}, ...), or be structured array or scalar "
3546-
f"with {n_components} fields.")
3542+
if isinstance(data, np.ndarray):
3543+
if data.dtype.fields is not None:
3544+
data = tuple(data[descriptor[0]] for descriptor in data.dtype.descr)
3545+
if len(data) != n_components:
3546+
raise ValueError(f"{MultiNorm._get_input_err(n_components)}"
3547+
f". A structured array with "
3548+
f"{len(data)} fields is not compatible")
3549+
else:
3550+
# Input is a scalar array, which we do not support.
3551+
# try to give a hint as to how the data can be converted to
3552+
# an accepted format
3553+
if ((len(data.shape) == 1 and
3554+
data.shape[0] == n_components) or
3555+
(len(data.shape) > 1 and
3556+
data.shape[0] == n_components and
3557+
data.shape[-1] != n_components)
3558+
):
3559+
raise ValueError(f"{MultiNorm._get_input_err(n_components)}"
3560+
". You can use `list(data)` to convert"
3561+
f" the input data of shape {data.shape} to"
3562+
" a compatible list")
3563+
3564+
elif (len(data.shape) > 1 and
3565+
data.shape[-1] == n_components and
3566+
data.shape[0] != n_components):
3567+
raise ValueError(f"{MultiNorm._get_input_err(n_components)}"
3568+
". You can use "
3569+
"`rfn.unstructured_to_structured(data)` available "
3570+
"with `from numpy.lib import recfunctions as rfn` "
3571+
"to convert the input array of shape "
3572+
f"{data.shape} to a structured array")
3573+
else:
3574+
# Cannot give shape hint
3575+
# Either neither first nor last axis matches, or both do.
3576+
raise ValueError(f"{MultiNorm._get_input_err(n_components)}"
3577+
f". An np.ndarray of shape {data.shape} is"
3578+
" not compatible")
3579+
elif isinstance(data, (tuple, list)):
3580+
if len(data) != n_components:
3581+
raise ValueError(f"{MultiNorm._get_input_err(n_components)}"
3582+
f". A {type(data)} of length {len(data)} is"
3583+
" not compatible")
3584+
else:
3585+
raise ValueError(f"{MultiNorm._get_input_err(n_components)}"
3586+
f". Input of type {type(data)} is not supported")
3587+
35473588
return data
35483589

3590+
@staticmethod
3591+
def _get_input_err(n_components):
3592+
# returns the start of the error message given when a
3593+
# MultiNorm receives incompatible input
3594+
return ("The input to this `MultiNorm` must be a list or tuple "
3595+
f"of length {n_components}, or be structured array "
3596+
f"with {n_components} fields")
3597+
35493598
@staticmethod
35503599
def _ensure_multicomponent_data(data, n_components):
35513600
"""

lib/matplotlib/tests/test_colors.py

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import base64
1010
import platform
1111

12+
from numpy.lib import recfunctions as rfn
1213
from numpy.testing import assert_array_equal, assert_array_almost_equal
1314

1415
from matplotlib import cbook, cm
@@ -1891,7 +1892,7 @@ def test_close_error_name():
18911892
matplotlib.colormaps["grays"]
18921893

18931894

1894-
def test_multi_norm():
1895+
def test_multi_norm_creation():
18951896
# tests for mcolors.MultiNorm
18961897

18971898
# test wrong input
@@ -1911,13 +1912,24 @@ def test_multi_norm():
19111912
match="Invalid norm str name"):
19121913
mcolors.MultiNorm(["None"])
19131914

1915+
norm = mpl.colors.MultiNorm(['linear', 'linear'])
1916+
1917+
1918+
def test_multi_norm_call_vmin_vmax():
19141919
# test get vmin, vmax
19151920
norm = mpl.colors.MultiNorm(['linear', 'log'])
19161921
norm.vmin = 1
19171922
norm.vmax = 2
19181923
assert norm.vmin == (1, 1)
19191924
assert norm.vmax == (2, 2)
19201925

1926+
1927+
def test_multi_norm_call_clip_inverse():
1928+
# test get vmin, vmax
1929+
norm = mpl.colors.MultiNorm(['linear', 'log'])
1930+
norm.vmin = 1
1931+
norm.vmax = 2
1932+
19211933
# test call with clip
19221934
assert_array_equal(norm([3, 3], clip=False), [2.0, 1.584962500721156])
19231935
assert_array_equal(norm([3, 3], clip=True), [1.0, 1.0])
@@ -1933,6 +1945,9 @@ def test_multi_norm():
19331945
# test inverse
19341946
assert_array_almost_equal(norm.inverse([0.5, 0.5849625007211562]), [1.5, 1.5])
19351947

1948+
1949+
def test_multi_norm_autoscale():
1950+
norm = mpl.colors.MultiNorm(['linear', 'log'])
19361951
# test autoscale
19371952
norm.autoscale([[0, 1, 2, 3], [0.1, 1, 2, 3]])
19381953
assert_array_equal(norm.vmin, [0, 0.1])
@@ -1945,3 +1960,109 @@ def test_multi_norm():
19451960
assert_array_equal(norm([5, 0]), [1, 0.5])
19461961
assert_array_equal(norm.vmin, (0, -50))
19471962
assert_array_equal(norm.vmax, (5, 50))
1963+
1964+
1965+
def test_mult_norm_call_types():
1966+
mn = mpl.colors.MultiNorm(['linear', 'linear'])
1967+
mn.vmin = -2
1968+
mn.vmax = 2
1969+
1970+
vals = np.arange(6).reshape((3,2))
1971+
target = np.ma.array([(0.5, 0.75),
1972+
(1., 1.25),
1973+
(1.5, 1.75)])
1974+
1975+
# test structured array as input
1976+
structured_target = rfn.unstructured_to_structured(target)
1977+
from_mn= mn(rfn.unstructured_to_structured(vals))
1978+
assert from_mn.dtype == structured_target.dtype
1979+
assert_array_almost_equal(rfn.structured_to_unstructured(from_mn),
1980+
rfn.structured_to_unstructured(structured_target))
1981+
1982+
# test list of arrays as input
1983+
assert_array_almost_equal(mn(list(vals.T)),
1984+
list(target.T))
1985+
# test list of floats as input
1986+
assert_array_almost_equal(mn(list(vals[0])),
1987+
list(target[0]))
1988+
# test tuple of arrays as input
1989+
assert_array_almost_equal(mn(tuple(vals.T)),
1990+
list(target.T))
1991+
1992+
1993+
# test setting structured_output true/false:
1994+
# structured input, structured output
1995+
from_mn = mn(rfn.unstructured_to_structured(vals), structured_output=True)
1996+
assert from_mn.dtype == structured_target.dtype
1997+
assert_array_almost_equal(rfn.structured_to_unstructured(from_mn),
1998+
rfn.structured_to_unstructured(structured_target))
1999+
# structured input, list as output
2000+
from_mn = mn(rfn.unstructured_to_structured(vals), structured_output=False)
2001+
assert_array_almost_equal(from_mn,
2002+
list(target.T))
2003+
# list as input, structured output
2004+
from_mn= mn(list(vals.T), structured_output=True)
2005+
assert from_mn.dtype == structured_target.dtype
2006+
assert_array_almost_equal(rfn.structured_to_unstructured(from_mn),
2007+
rfn.structured_to_unstructured(structured_target))
2008+
# list as input, list as output
2009+
from_mn = mn(list(vals.T), structured_output=False)
2010+
assert_array_almost_equal(from_mn,
2011+
list(target.T))
2012+
2013+
# test with NoNorm, list as input
2014+
mn_no_norm = mpl.colors.MultiNorm(['linear', mcolors.NoNorm()])
2015+
no_norm_out = mn_no_norm(list(vals.T))
2016+
assert_array_almost_equal(no_norm_out,
2017+
[[0., 0.5, 1.],
2018+
[1, 3, 5]])
2019+
assert no_norm_out[0].dtype == np.dtype('float64')
2020+
assert no_norm_out[1].dtype == np.dtype('int64')
2021+
2022+
# test with NoNorm, structured array as input
2023+
mn_no_norm = mpl.colors.MultiNorm(['linear', mcolors.NoNorm()])
2024+
no_norm_out = mn_no_norm(rfn.unstructured_to_structured(vals))
2025+
assert_array_almost_equal(rfn.structured_to_unstructured(no_norm_out),
2026+
np.array(\
2027+
[[0., 0.5, 1.],
2028+
[1, 3, 5]]).T)
2029+
assert no_norm_out.dtype['f0'] == np.dtype('float64')
2030+
assert no_norm_out.dtype['f1'] == np.dtype('int64')
2031+
2032+
# test single int as input
2033+
with pytest.raises(ValueError,
2034+
match="Input of type <class 'int'> is not supported"):
2035+
mn(1)
2036+
2037+
# test list of incompatible size
2038+
with pytest.raises(ValueError,
2039+
match="A <class 'list'> of length 3 is not compatible"):
2040+
mn([3, 2, 1])
2041+
2042+
# np.arrays of shapes that can be converted:
2043+
for data in [np.zeros(2), np.zeros((2,3)), np.zeros((2,3,3))]:
2044+
with pytest.raises(ValueError,
2045+
match=r"You can use `list\(data\)` to convert"):
2046+
mn(data)
2047+
2048+
for data in [np.zeros((3, 2)), np.zeros((3, 3, 2))]:
2049+
with pytest.raises(ValueError,
2050+
match=r"You can use `rfn.unstructured_to_structured"):
2051+
mn(data)
2052+
2053+
# np.ndarray that can be converted, but unclear if first or last axis
2054+
for data in [np.zeros((2, 2)), np.zeros((2, 3, 2))]:
2055+
with pytest.raises(ValueError,
2056+
match="An np.ndarray of shape"):
2057+
mn(data)
2058+
2059+
# incompatible arrays where no relevant axis matches
2060+
for data in [np.zeros(3), np.zeros((3, 2, 3))]:
2061+
with pytest.raises(ValueError,
2062+
match=r"An np.ndarray of shape"):
2063+
mn(data)
2064+
2065+
# test incompatible class
2066+
with pytest.raises(ValueError,
2067+
match="Input of type <class 'str'> is not supported"):
2068+
mn("An object of incompatible class")

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy