Skip to content

Commit 709fba8

Browse files
authored
Merge pull request #22560 from oscargus/pandasconversion
Improve pandas/xarray/... conversion
2 parents a7b7260 + 56af810 commit 709fba8

File tree

8 files changed

+61
-21
lines changed

8 files changed

+61
-21
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7944,8 +7944,8 @@ def violinplot(self, dataset, positions=None, vert=True, widths=0.5,
79447944
"""
79457945

79467946
def _kde_method(X, coords):
7947-
if hasattr(X, 'values'): # support pandas.Series
7948-
X = X.values
7947+
# Unpack in case of e.g. Pandas or xarray object
7948+
X = cbook._unpack_to_numpy(X)
79497949
# fallback gracefully if the vector contains only one value
79507950
if np.all(X[0] == X):
79517951
return (X[0] == coords).astype(float)

lib/matplotlib/cbook/__init__.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,9 +1331,8 @@ def _to_unmasked_float_array(x):
13311331

13321332
def _check_1d(x):
13331333
"""Convert scalars to 1D arrays; pass-through arrays as is."""
1334-
if hasattr(x, 'to_numpy'):
1335-
# if we are given an object that creates a numpy, we should use it...
1336-
x = x.to_numpy()
1334+
# Unpack in case of e.g. Pandas or xarray object
1335+
x = _unpack_to_numpy(x)
13371336
if not hasattr(x, 'shape') or len(x.shape) < 1:
13381337
return np.atleast_1d(x)
13391338
else:
@@ -1352,15 +1351,8 @@ def _reshape_2D(X, name):
13521351
*name* is used to generate the error message for invalid inputs.
13531352
"""
13541353

1355-
# unpack if we have a values or to_numpy method.
1356-
try:
1357-
X = X.to_numpy()
1358-
except AttributeError:
1359-
try:
1360-
if isinstance(X.values, np.ndarray):
1361-
X = X.values
1362-
except AttributeError:
1363-
pass
1354+
# Unpack in case of e.g. Pandas or xarray object
1355+
X = _unpack_to_numpy(X)
13641356

13651357
# Iterate over columns for ndarrays.
13661358
if isinstance(X, np.ndarray):
@@ -2251,3 +2243,20 @@ def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class):
22512243
factory = _make_class_factory(mixin_class, fmt, attr_name)
22522244
cls = factory(base_class)
22532245
return cls.__new__(cls)
2246+
2247+
2248+
def _unpack_to_numpy(x):
2249+
"""Internal helper to extract data from e.g. pandas and xarray objects."""
2250+
if isinstance(x, np.ndarray):
2251+
# If numpy, return directly
2252+
return x
2253+
if hasattr(x, 'to_numpy'):
2254+
# Assume that any function to_numpy() do actually return a numpy array
2255+
return x.to_numpy()
2256+
if hasattr(x, 'values'):
2257+
xtmp = x.values
2258+
# For example a dict has a 'values' attribute, but it is not a property
2259+
# so in this case we do not want to return a function
2260+
if isinstance(xtmp, np.ndarray):
2261+
return xtmp
2262+
return x

lib/matplotlib/dates.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -434,9 +434,8 @@ def date2num(d):
434434
The Gregorian calendar is assumed; this is not universal practice.
435435
For details see the module docstring.
436436
"""
437-
if hasattr(d, "values"):
438-
# this unpacks pandas series or dataframes...
439-
d = d.values
437+
# Unpack in case of e.g. Pandas or xarray object
438+
d = cbook._unpack_to_numpy(d)
440439

441440
# make an iterable, but save state to unpack later:
442441
iterable = np.iterable(d)

lib/matplotlib/testing/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,10 @@ def pd():
125125
except ImportError:
126126
pass
127127
return pd
128+
129+
130+
@pytest.fixture
131+
def xr():
132+
"""Fixture to import xarray."""
133+
xr = pytest.importorskip('xarray')
134+
return xr

lib/matplotlib/tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from matplotlib.testing.conftest import (mpl_test_settings,
22
pytest_configure, pytest_unconfigure,
3-
pd)
3+
pd, xr)

lib/matplotlib/tests/test_cbook.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -701,14 +701,37 @@ def test_reshape2d_pandas(pd):
701701
for x, xnew in zip(X.T, Xnew):
702702
np.testing.assert_array_equal(x, xnew)
703703

704+
705+
def test_reshape2d_xarray(xr):
706+
# separate to allow the rest of the tests to run if no xarray...
704707
X = np.arange(30).reshape(10, 3)
705-
x = pd.DataFrame(X, columns=["a", "b", "c"])
708+
x = xr.DataArray(X, dims=["x", "y"])
706709
Xnew = cbook._reshape_2D(x, 'x')
707710
# Need to check each row because _reshape_2D returns a list of arrays:
708711
for x, xnew in zip(X.T, Xnew):
709712
np.testing.assert_array_equal(x, xnew)
710713

711714

715+
def test_index_of_pandas(pd):
716+
# separate to allow the rest of the tests to run if no pandas...
717+
X = np.arange(30).reshape(10, 3)
718+
x = pd.DataFrame(X, columns=["a", "b", "c"])
719+
Idx, Xnew = cbook.index_of(x)
720+
np.testing.assert_array_equal(X, Xnew)
721+
IdxRef = np.arange(10)
722+
np.testing.assert_array_equal(Idx, IdxRef)
723+
724+
725+
def test_index_of_xarray(xr):
726+
# separate to allow the rest of the tests to run if no xarray...
727+
X = np.arange(30).reshape(10, 3)
728+
x = xr.DataArray(X, dims=["x", "y"])
729+
Idx, Xnew = cbook.index_of(x)
730+
np.testing.assert_array_equal(X, Xnew)
731+
IdxRef = np.arange(10)
732+
np.testing.assert_array_equal(Idx, IdxRef)
733+
734+
712735
def test_contiguous_regions():
713736
a, b, c = 3, 4, 5
714737
# Starts and ends with True

lib/matplotlib/units.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,9 @@ class Registry(dict):
180180

181181
def get_converter(self, x):
182182
"""Get the converter interface instance for *x*, or None."""
183-
if hasattr(x, "values"):
184-
x = x.values # Unpack pandas Series and DataFrames.
183+
# Unpack in case of e.g. Pandas or xarray object
184+
x = cbook._unpack_to_numpy(x)
185+
185186
if isinstance(x, np.ndarray):
186187
# In case x in a masked array, access the underlying data (only its
187188
# type matters). If x is a regular ndarray, getdata() just returns

requirements/testing/extra.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ pandas!=0.25.0
77
pikepdf
88
pytz
99
pywin32; sys.platform == 'win32'
10+
xarray

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy