Skip to content

Commit 81ab0d5

Browse files
committed
Updated input types for MultiNorm.__call__()
1 parent 67b8264 commit 81ab0d5

File tree

2 files changed

+182
-12
lines changed

2 files changed

+182
-12
lines changed

lib/matplotlib/colors.py

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3413,10 +3413,9 @@ def __call__(self, values, clip=None, structured_output=None):
34133413
Parameters
34143414
----------
34153415
values : array-like
3416-
Data to normalize, as tuple, scalar array or structured array.
3416+
Data to normalize, as tuple or list or structured array.
34173417
3418-
- If tuple, must be of length `n_components`
3419-
- If scalar array, the first axis must be of length `n_components`
3418+
- If tuple or list, must be of length `n_components`
34203419
- If structured array, must have `n_components` fields.
34213420
34223421
clip : list of bools or bool or None, optional
@@ -3530,22 +3529,72 @@ def _iterable_components_in_data(data, n_components):
35303529
Parameters
35313530
----------
35323531
data : np.ndarray, tuple or list
3533-
The input array. It must either be an array with n_components fields or have
3534-
a length (n_components)
3532+
The input data, as a tuple or list or structured array.
3533+
3534+
- If tuple or list, must be of length `n_components`
3535+
- If structured array, must have `n_components` fields.
35353536
35363537
Returns
35373538
-------
35383539
tuple of np.ndarray
35393540
35403541
"""
3541-
if isinstance(data, np.ndarray) and data.dtype.fields is not None:
3542-
data = tuple(data[descriptor[0]] for descriptor in data.dtype.descr)
3543-
if len(data) != n_components:
3544-
raise ValueError("The input to this `MultiNorm` must be of shape "
3545-
f"({n_components}, ...), or be structured array or scalar "
3546-
f"with {n_components} fields.")
3542+
if isinstance(data, np.ndarray):
3543+
if data.dtype.fields is not None:
3544+
data = tuple(data[descriptor[0]] for descriptor in data.dtype.descr)
3545+
if len(data) != n_components:
3546+
raise ValueError(f"{MultiNorm._get_input_err(n_components)}"
3547+
f". A structured array with "
3548+
f"{len(data)} fields is not compatible")
3549+
else:
3550+
# Input is a scalar array, which we do not support.
3551+
# try to give a hint as to how the data can be converted to
3552+
# an accepted format
3553+
if ((len(data.shape) == 1 and
3554+
data.shape[0] == n_components) or
3555+
(len(data.shape) > 1 and
3556+
data.shape[0] == n_components and
3557+
data.shape[-1] != n_components)
3558+
):
3559+
raise ValueError(f"{MultiNorm._get_input_err(n_components)}"
3560+
". You can use `list(data)` to convert"
3561+
f" the input data of shape {data.shape} to"
3562+
" a compatible list")
3563+
3564+
elif (len(data.shape) > 1 and
3565+
data.shape[-1] == n_components and
3566+
data.shape[0] != n_components):
3567+
raise ValueError(f"{MultiNorm._get_input_err(n_components)}"
3568+
". You can use "
3569+
"`rfn.unstructured_to_structured(data)` available "
3570+
"with `from numpy.lib import recfunctions as rfn` "
3571+
"to convert the input array of shape "
3572+
f"{data.shape} to a structured array")
3573+
else:
3574+
# Cannot give shape hint
3575+
# Either neither first nor last axis matches, or both do.
3576+
raise ValueError(f"{MultiNorm._get_input_err(n_components)}"
3577+
f". An np.ndarray of shape {data.shape} is"
3578+
" not compatible")
3579+
elif isinstance(data, (tuple, list)):
3580+
if len(data) != n_components:
3581+
raise ValueError(f"{MultiNorm._get_input_err(n_components)}"
3582+
f". A {type(data)} of length {len(data)} is"
3583+
" not compatible")
3584+
else:
3585+
raise ValueError(f"{MultiNorm._get_input_err(n_components)}"
3586+
f". Input of type {type(data)} is not supported")
3587+
35473588
return data
35483589

3590+
@staticmethod
3591+
def _get_input_err(n_components):
3592+
# returns the start of the error message given when a
3593+
# MultiNorm receives incompatible input
3594+
return ("The input to this `MultiNorm` must be a list or tuple "
3595+
f"of length {n_components}, or be structured array "
3596+
f"with {n_components} fields")
3597+
35493598
@staticmethod
35503599
def _ensure_multicomponent_data(data, n_components):
35513600
"""

lib/matplotlib/tests/test_colors.py

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import base64
1010
import platform
1111

12+
from numpy.lib import recfunctions as rfn
1213
from numpy.testing import assert_array_equal, assert_array_almost_equal
1314

1415
from matplotlib import cbook, cm
@@ -1881,7 +1882,7 @@ def n_components(self):
18811882
axes[1,1].contourf(r, colorizer=colorizer)
18821883

18831884

1884-
def test_multi_norm():
1885+
def test_multi_norm_creation():
18851886
# tests for mcolors.MultiNorm
18861887

18871888
# test wrong input
@@ -1901,13 +1902,24 @@ def test_multi_norm():
19011902
match="Invalid norm str name"):
19021903
mcolors.MultiNorm(["None"])
19031904

1905+
norm = mpl.colors.MultiNorm(['linear', 'linear'])
1906+
1907+
1908+
def test_multi_norm_call_vmin_vmax():
19041909
# test get vmin, vmax
19051910
norm = mpl.colors.MultiNorm(['linear', 'log'])
19061911
norm.vmin = 1
19071912
norm.vmax = 2
19081913
assert norm.vmin == (1, 1)
19091914
assert norm.vmax == (2, 2)
19101915

1916+
1917+
def test_multi_norm_call_clip_inverse():
1918+
# test get vmin, vmax
1919+
norm = mpl.colors.MultiNorm(['linear', 'log'])
1920+
norm.vmin = 1
1921+
norm.vmax = 2
1922+
19111923
# test call with clip
19121924
assert_array_equal(norm([3, 3], clip=False), [2.0, 1.584962500721156])
19131925
assert_array_equal(norm([3, 3], clip=True), [1.0, 1.0])
@@ -1923,6 +1935,9 @@ def test_multi_norm():
19231935
# test inverse
19241936
assert_array_almost_equal(norm.inverse([0.5, 0.5849625007211562]), [1.5, 1.5])
19251937

1938+
1939+
def test_multi_norm_autoscale():
1940+
norm = mpl.colors.MultiNorm(['linear', 'log'])
19261941
# test autoscale
19271942
norm.autoscale([[0, 1, 2, 3], [0.1, 1, 2, 3]])
19281943
assert_array_equal(norm.vmin, [0, 0.1])
@@ -1935,3 +1950,109 @@ def test_multi_norm():
19351950
assert_array_equal(norm([5, 0]), [1, 0.5])
19361951
assert_array_equal(norm.vmin, (0, -50))
19371952
assert_array_equal(norm.vmax, (5, 50))
1953+
1954+
1955+
def test_mult_norm_call_types():
1956+
mn = mpl.colors.MultiNorm(['linear', 'linear'])
1957+
mn.vmin = -2
1958+
mn.vmax = 2
1959+
1960+
vals = np.arange(6).reshape((3,2))
1961+
target = np.ma.array([(0.5, 0.75),
1962+
(1., 1.25),
1963+
(1.5, 1.75)])
1964+
1965+
# test structured array as input
1966+
structured_target = rfn.unstructured_to_structured(target)
1967+
from_mn= mn(rfn.unstructured_to_structured(vals))
1968+
assert from_mn.dtype == structured_target.dtype
1969+
assert_array_almost_equal(rfn.structured_to_unstructured(from_mn),
1970+
rfn.structured_to_unstructured(structured_target))
1971+
1972+
# test list of arrays as input
1973+
assert_array_almost_equal(mn(list(vals.T)),
1974+
list(target.T))
1975+
# test list of floats as input
1976+
assert_array_almost_equal(mn(list(vals[0])),
1977+
list(target[0]))
1978+
# test tuple of arrays as input
1979+
assert_array_almost_equal(mn(tuple(vals.T)),
1980+
list(target.T))
1981+
1982+
1983+
# test setting structured_output true/false:
1984+
# structured input, structured output
1985+
from_mn = mn(rfn.unstructured_to_structured(vals), structured_output=True)
1986+
assert from_mn.dtype == structured_target.dtype
1987+
assert_array_almost_equal(rfn.structured_to_unstructured(from_mn),
1988+
rfn.structured_to_unstructured(structured_target))
1989+
# structured input, list as output
1990+
from_mn = mn(rfn.unstructured_to_structured(vals), structured_output=False)
1991+
assert_array_almost_equal(from_mn,
1992+
list(target.T))
1993+
# list as input, structured output
1994+
from_mn= mn(list(vals.T), structured_output=True)
1995+
assert from_mn.dtype == structured_target.dtype
1996+
assert_array_almost_equal(rfn.structured_to_unstructured(from_mn),
1997+
rfn.structured_to_unstructured(structured_target))
1998+
# list as input, list as output
1999+
from_mn = mn(list(vals.T), structured_output=False)
2000+
assert_array_almost_equal(from_mn,
2001+
list(target.T))
2002+
2003+
# test with NoNorm, list as input
2004+
mn_no_norm = mpl.colors.MultiNorm(['linear', mcolors.NoNorm()])
2005+
no_norm_out = mn_no_norm(list(vals.T))
2006+
assert_array_almost_equal(no_norm_out,
2007+
[[0., 0.5, 1.],
2008+
[1, 3, 5]])
2009+
assert no_norm_out[0].dtype == np.dtype('float64')
2010+
assert no_norm_out[1].dtype == np.dtype('int64')
2011+
2012+
# test with NoNorm, structured array as input
2013+
mn_no_norm = mpl.colors.MultiNorm(['linear', mcolors.NoNorm()])
2014+
no_norm_out = mn_no_norm(rfn.unstructured_to_structured(vals))
2015+
assert_array_almost_equal(rfn.structured_to_unstructured(no_norm_out),
2016+
np.array(\
2017+
[[0., 0.5, 1.],
2018+
[1, 3, 5]]).T)
2019+
assert no_norm_out.dtype['f0'] == np.dtype('float64')
2020+
assert no_norm_out.dtype['f1'] == np.dtype('int64')
2021+
2022+
# test single int as input
2023+
with pytest.raises(ValueError,
2024+
match="Input of type <class 'int'> is not supported"):
2025+
mn(1)
2026+
2027+
# test list of incompatible size
2028+
with pytest.raises(ValueError,
2029+
match="A <class 'list'> of length 3 is not compatible"):
2030+
mn([3, 2, 1])
2031+
2032+
# np.arrays of shapes that can be converted:
2033+
for data in [np.zeros(2), np.zeros((2,3)), np.zeros((2,3,3))]:
2034+
with pytest.raises(ValueError,
2035+
match=r"You can use `list\(data\)` to convert"):
2036+
mn(data)
2037+
2038+
for data in [np.zeros((3, 2)), np.zeros((3, 3, 2))]:
2039+
with pytest.raises(ValueError,
2040+
match=r"You can use `rfn.unstructured_to_structured"):
2041+
mn(data)
2042+
2043+
# np.ndarray that can be converted, but unclear if first or last axis
2044+
for data in [np.zeros((2, 2)), np.zeros((2, 3, 2))]:
2045+
with pytest.raises(ValueError,
2046+
match="An np.ndarray of shape"):
2047+
mn(data)
2048+
2049+
# incompatible arrays where no relevant axis matches
2050+
for data in [np.zeros(3), np.zeros((3, 2, 3))]:
2051+
with pytest.raises(ValueError,
2052+
match=r"An np.ndarray of shape"):
2053+
mn(data)
2054+
2055+
# test incompatible class
2056+
with pytest.raises(ValueError,
2057+
match="Input of type <class 'str'> is not supported"):
2058+
mn("An object of incompatible class")

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy