Skip to content

Commit 19e86ca

Browse files
Oscar Gustafssonoscargus
authored andcommitted
Improve pandas and xarray conversion
1 parent 0359832 commit 19e86ca

File tree

9 files changed

+63
-22
lines changed

9 files changed

+63
-22
lines changed

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,4 @@ dependencies:
5656
- pytest-xdist
5757
- tornado
5858
- pytz
59+
- xarray

lib/matplotlib/axes/_axes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7907,8 +7907,8 @@ def violinplot(self, dataset, positions=None, vert=True, widths=0.5,
79077907
"""
79087908

79097909
def _kde_method(X, coords):
7910-
if hasattr(X, 'values'): # support pandas.Series
7911-
X = X.values
7910+
# Unpack in case of e.g. Pandas or xarray object
7911+
X = cbook._unpack_to_numpy(X)
79127912
# fallback gracefully if the vector contains only one value
79137913
if np.all(X[0] == X):
79147914
return (X[0] == coords).astype(float)

lib/matplotlib/cbook/__init__.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1311,9 +1311,8 @@ def _to_unmasked_float_array(x):
13111311

13121312
def _check_1d(x):
13131313
"""Convert scalars to 1D arrays; pass-through arrays as is."""
1314-
if hasattr(x, 'to_numpy'):
1315-
# if we are given an object that creates a numpy, we should use it...
1316-
x = x.to_numpy()
1314+
# Unpack in case of e.g. Pandas or xarray object
1315+
x = _unpack_to_numpy(x)
13171316
if not hasattr(x, 'shape') or len(x.shape) < 1:
13181317
return np.atleast_1d(x)
13191318
else:
@@ -1332,15 +1331,8 @@ def _reshape_2D(X, name):
13321331
*name* is used to generate the error message for invalid inputs.
13331332
"""
13341333

1335-
# unpack if we have a values or to_numpy method.
1336-
try:
1337-
X = X.to_numpy()
1338-
except AttributeError:
1339-
try:
1340-
if isinstance(X.values, np.ndarray):
1341-
X = X.values
1342-
except AttributeError:
1343-
pass
1334+
# Unpack in case of e.g. Pandas or xarray object
1335+
X = _unpack_to_numpy(X)
13441336

13451337
# Iterate over columns for ndarrays.
13461338
if isinstance(X, np.ndarray):
@@ -1626,7 +1618,7 @@ def index_of(y):
16261618
The x and y values to plot.
16271619
"""
16281620
try:
1629-
return y.index.to_numpy(), y.to_numpy()
1621+
return _unpack_to_numpy(y.index), _unpack_to_numpy(y)
16301622
except AttributeError:
16311623
pass
16321624
try:
@@ -2231,3 +2223,20 @@ def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class):
22312223
factory = _make_class_factory(mixin_class, fmt, attr_name)
22322224
cls = factory(base_class)
22332225
return cls.__new__(cls)
2226+
2227+
2228+
def _unpack_to_numpy(x):
2229+
"""Internal helper to extract data from e.g. pandas and xarray objects."""
2230+
if isinstance(x, np.ndarray):
2231+
# If numpy, return directly
2232+
return x
2233+
if hasattr(x, 'to_numpy'):
2234+
# Assume that any function to_numpy() do actually return a numpy array
2235+
return x.to_numpy()
2236+
if hasattr(x, 'values'):
2237+
xtmp = x.values
2238+
# For example a dict has a 'values' attribute, but it is not a property
2239+
# so in this case we do not want to return a function
2240+
if isinstance(xtmp, np.ndarray):
2241+
return xtmp
2242+
return x

lib/matplotlib/dates.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -437,9 +437,8 @@ def date2num(d):
437437
The Gregorian calendar is assumed; this is not universal practice.
438438
For details see the module docstring.
439439
"""
440-
if hasattr(d, "values"):
441-
# this unpacks pandas series or dataframes...
442-
d = d.values
440+
# Unpack in case of e.g. Pandas or xarray object
441+
d = cbook._unpack_to_numpy(d)
443442

444443
# make an iterable, but save state to unpack later:
445444
iterable = np.iterable(d)

lib/matplotlib/testing/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,10 @@ def pd():
125125
except ImportError:
126126
pass
127127
return pd
128+
129+
130+
@pytest.fixture
131+
def xr():
132+
"""Fixture to import xarray."""
133+
xr = pytest.importorskip('xarray')
134+
return xr

lib/matplotlib/tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from matplotlib.testing.conftest import (mpl_test_settings,
22
pytest_configure, pytest_unconfigure,
3-
pd)
3+
pd, xr)

lib/matplotlib/tests/test_cbook.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -680,14 +680,37 @@ def test_reshape2d_pandas(pd):
680680
for x, xnew in zip(X.T, Xnew):
681681
np.testing.assert_array_equal(x, xnew)
682682

683+
684+
def test_reshape2d_xarray(xr):
685+
# separate to allow the rest of the tests to run if no xarray...
683686
X = np.arange(30).reshape(10, 3)
684-
x = pd.DataFrame(X, columns=["a", "b", "c"])
687+
x = xr.DataArray(X, dims=["x", "y"])
685688
Xnew = cbook._reshape_2D(x, 'x')
686689
# Need to check each row because _reshape_2D returns a list of arrays:
687690
for x, xnew in zip(X.T, Xnew):
688691
np.testing.assert_array_equal(x, xnew)
689692

690693

694+
def test_index_of_pandas(pd):
695+
# separate to allow the rest of the tests to run if no pandas...
696+
X = np.arange(30).reshape(10, 3)
697+
x = pd.DataFrame(X, columns=["a", "b", "c"])
698+
Idx, Xnew = cbook.index_of(x)
699+
np.testing.assert_array_equal(X, Xnew)
700+
IdxRef = np.arange(10)
701+
np.testing.assert_array_equal(Idx, IdxRef)
702+
703+
704+
def test_index_of_xarray(xr):
705+
# separate to allow the rest of the tests to run if no xarray...
706+
X = np.arange(30).reshape(10, 3)
707+
x = xr.DataArray(X, dims=["x", "y"])
708+
Idx, Xnew = cbook.index_of(x)
709+
np.testing.assert_array_equal(X, Xnew)
710+
IdxRef = np.arange(10)
711+
np.testing.assert_array_equal(Idx, IdxRef)
712+
713+
691714
def test_contiguous_regions():
692715
a, b, c = 3, 4, 5
693716
# Starts and ends with True

lib/matplotlib/units.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,9 @@ class Registry(dict):
180180

181181
def get_converter(self, x):
182182
"""Get the converter interface instance for *x*, or None."""
183-
if hasattr(x, "values"):
184-
x = x.values # Unpack pandas Series and DataFrames.
183+
# Unpack in case of e.g. Pandas or xarray object
184+
x = cbook._unpack_to_numpy(x)
185+
185186
if isinstance(x, np.ndarray):
186187
# In case x in a masked array, access the underlying data (only its
187188
# type matters). If x is a regular ndarray, getdata() just returns

requirements/testing/extra.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ pandas!=0.25.0
77
pikepdf
88
pytz
99
pywin32; sys.platform == 'win32'
10+
xarray

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy