Skip to content

Commit 3323ae8

Browse files
authored
Merge pull request #25887 from patel-zeel/speed_up_jax_torch_plots
Update `_unpack_to_numpy` function to convert JAX and PyTorch arrays to NumPy
2 parents 7fde77e + 9acfb5b commit 3323ae8

File tree

2 files changed

+74
-0
lines changed

2 files changed

+74
-0
lines changed

lib/matplotlib/cbook.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2349,6 +2349,30 @@ def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class):
23492349
return cls.__new__(cls)
23502350

23512351

2352+
def _is_torch_array(x):
2353+
"""Check if 'x' is a PyTorch Tensor."""
2354+
try:
2355+
# we're intentionally not attempting to import torch. If somebody
2356+
# has created a torch array, torch should already be in sys.modules
2357+
return isinstance(x, sys.modules['torch'].Tensor)
2358+
except Exception: # TypeError, KeyError, AttributeError, maybe others?
2359+
# we're attempting to access attributes on imported modules which
2360+
# may have arbitrary user code, so we deliberately catch all exceptions
2361+
return False
2362+
2363+
2364+
def _is_jax_array(x):
2365+
"""Check if 'x' is a JAX Array."""
2366+
try:
2367+
# we're intentionally not attempting to import jax. If somebody
2368+
# has created a jax array, jax should already be in sys.modules
2369+
return isinstance(x, sys.modules['jax'].Array)
2370+
except Exception: # TypeError, KeyError, AttributeError, maybe others?
2371+
# we're attempting to access attributes on imported modules which
2372+
# may have arbitrary user code, so we deliberately catch all exceptions
2373+
return False
2374+
2375+
23522376
def _unpack_to_numpy(x):
23532377
"""Internal helper to extract data from e.g. pandas and xarray objects."""
23542378
if isinstance(x, np.ndarray):
@@ -2363,6 +2387,12 @@ def _unpack_to_numpy(x):
23632387
# so in this case we do not want to return a function
23642388
if isinstance(xtmp, np.ndarray):
23652389
return xtmp
2390+
if _is_torch_array(x) or _is_jax_array(x):
2391+
xtmp = x.__array__()
2392+
2393+
# In case __array__() method does not return a numpy array in future
2394+
if isinstance(xtmp, np.ndarray):
2395+
return xtmp
23662396
return x
23672397

23682398

lib/matplotlib/tests/test_cbook.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import sys
34
import itertools
45
import pickle
56

@@ -16,6 +17,7 @@
1617
from matplotlib import _api, cbook
1718
import matplotlib.colors as mcolors
1819
from matplotlib.cbook import delete_masked_points, strip_math
20+
from types import ModuleType
1921

2022

2123
class Test_delete_masked_points:
@@ -938,3 +940,45 @@ def test_auto_format_str(fmt, value, result):
938940
"""Apply *value* to the format string *fmt*."""
939941
assert cbook._auto_format_str(fmt, value) == result
940942
assert cbook._auto_format_str(fmt, np.float64(value)) == result
943+
944+
945+
def test_unpack_to_numpy_from_torch():
946+
"""Test that torch tensors are converted to numpy arrays.
947+
We don't want to create a dependency on torch in the test suite, so we mock it.
948+
"""
949+
class Tensor:
950+
def __init__(self, data):
951+
self.data = data
952+
def __array__(self):
953+
return self.data
954+
torch = ModuleType('torch')
955+
torch.Tensor = Tensor
956+
sys.modules['torch'] = torch
957+
958+
data = np.arange(10)
959+
torch_tensor = torch.Tensor(data)
960+
961+
result = cbook._unpack_to_numpy(torch_tensor)
962+
assert result is torch_tensor.__array__()
963+
964+
965+
def test_unpack_to_numpy_from_jax():
966+
"""Test that jax arrays are converted to numpy arrays.
967+
We don't want to create a dependency on jax in the test suite, so we mock it.
968+
"""
969+
class Array:
970+
def __init__(self, data):
971+
self.data = data
972+
def __array__(self):
973+
return self.data
974+
975+
jax = ModuleType('jax')
976+
jax.Array = Array
977+
978+
sys.modules['jax'] = jax
979+
980+
data = np.arange(10)
981+
jax_array = jax.Array(data)
982+
983+
result = cbook._unpack_to_numpy(jax_array)
984+
assert result is jax_array.__array__()

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy