From 44a27d2b3202c403b34225ce289ae606104632d2 Mon Sep 17 00:00:00 2001 From: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com> Date: Sun, 21 May 2023 02:56:38 +0200 Subject: [PATCH] subplots() returns AxesArray subplots() used to return a numpy array of Axes, which has some drawbacks. The numpy array is mainly used as a 2D container structure that allows 2D indexing. Apart from that, it's not particularly well suited: - Many of the numpy functions do not work on Axes. - Some functions work, but have awkward semantics; e.g. len() gives the number of rows. - We can't add our own functionality. AxesArray introduces a facade to the underlying array to allow us to customize the API. For the beginning, the API is 100% compatible with the previous numpy array behavior, but we deprecate everything except for a few reasonable methods. --- lib/matplotlib/gridspec.py | 71 +++++++++++++++++++++++- lib/matplotlib/tests/test_subplots.py | 78 +++++++++++++++++++++++++++ 2 files changed, 147 insertions(+), 2 deletions(-) diff --git a/lib/matplotlib/gridspec.py b/lib/matplotlib/gridspec.py index b2f1b5d8ff2e..9cd9b81f9aca 100644 --- a/lib/matplotlib/gridspec.py +++ b/lib/matplotlib/gridspec.py @@ -309,10 +309,10 @@ def subplots(self, *, sharex=False, sharey=False, squeeze=True, if squeeze: # Discarding unneeded dimensions that equal 1. If we only have one # subplot, just return it instead of a 1-element array. - return axarr.item() if axarr.size == 1 else axarr.squeeze() + return axarr.item() if axarr.size == 1 else AxesArray(axarr.squeeze()) else: # Returned axis array will be always 2-d, even if nrows=ncols=1. - return axarr + return AxesArray(axarr) class GridSpec(GridSpecBase): @@ -736,3 +736,70 @@ def subgridspec(self, nrows, ncols, **kwargs): fig.add_subplot(gssub[0, i]) """ return GridSpecFromSubplotSpec(nrows, ncols, self, **kwargs) + + +class AxesArray: + """ + A container for a 1D or 2D grid arrangement of Axes. + + This is used as the return type of ``subplots()``. + + Formerly, ``subplots()`` returned a numpy array of Axes. For a transition period, + AxesArray will act like a numpy array, but all functions and properties that + are not listed explicitly below are deprecated. + """ + def __init__(self, array): + self._array = array + + @staticmethod + def _ensure_wrapped(ax_or_axs): + if isinstance(ax_or_axs, np.ndarray): + return AxesArray(ax_or_axs) + else: + return ax_or_axs + + def __getitem__(self, index): + return self._ensure_wrapped(self._array[index]) + + @property + def __array_struct__(self): + return self._array.__array_struct__ + + @property + def ndim(self): + return self._array.ndim + + @property + def shape(self): + return self._array.shape + + @property + def size(self): + return self._array.size + + @property + def flat(self): + return self._array.flat + + @property + def flatten(self): + """[Disouraged] Use ``axs.flat`` instead.""" + return self._array.flatten + + @property + def ravel(self): + """[Disouraged] Use ``axs.flat`` instead.""" + return self._array.ravel + + @property + def __iter__(self): + return iter([self._ensure_wrapped(row) for row in self._array]) + + def __getattr__(self, item): + # forward all other attributes to the underlying array + # (this is a temporary measure to allow a smooth transition) + attr = getattr(self._array, item) + _api.warn_deprecated("3.9", + message=f"Using {item!r} on AxesArray is deprecated.", + pending=True) + return attr diff --git a/lib/matplotlib/tests/test_subplots.py b/lib/matplotlib/tests/test_subplots.py index cf5f4b902e24..d067edd2543e 100644 --- a/lib/matplotlib/tests/test_subplots.py +++ b/lib/matplotlib/tests/test_subplots.py @@ -3,6 +3,7 @@ import numpy as np import pytest +import matplotlib as mpl from matplotlib.axes import Axes, SubplotBase import matplotlib.pyplot as plt from matplotlib.testing.decorators import check_figures_equal, image_comparison @@ -283,3 +284,80 @@ def test_old_subplot_compat(): assert not isinstance(fig.add_axes(rect=[0, 0, 1, 1]), SubplotBase) with pytest.raises(TypeError): Axes(fig, [0, 0, 1, 1], rect=[0, 0, 1, 1]) + + +class TestAxesArray: + @staticmethod + def contain_same_axes(axs1, axs2): + return all(ax1 is ax2 for ax1, ax2 in zip(axs1.flat, axs2.flat)) + + def test_1d(self): + axs = plt.figure().subplots(1, 3) + # shape and size + assert axs.shape == (3,) + assert axs.size == 3 + assert axs.ndim == 1 + # flat + assert all(isinstance(ax, Axes) for ax in axs.flat) + assert len(set(id(ax) for ax in axs.flat)) == 3 + # flatten + assert all(isinstance(ax, Axes) for ax in axs.flatten()) + assert len(set(id(ax) for ax in axs.flatten())) == 3 + # ravel + assert all(isinstance(ax, Axes) for ax in axs.ravel()) + assert len(set(id(ax) for ax in axs.ravel())) == 3 + # single index + assert all(isinstance(axs[i], Axes) for i in range(axs.size)) + assert len(set(axs[i] for i in range(axs.size))) == 3 +# iteration + assert all(ax1 is ax2 for ax1, ax2 in zip(axs, axs.flat)) + + def test_1d_no_squeeze(self): + axs = plt.figure().subplots(1, 3, squeeze=False) + # shape and size + assert axs.shape == (1, 3) + assert axs.size == 3 + assert axs.ndim == 2 + # flat + assert all(isinstance(ax, Axes) for ax in axs.flat) + assert len(set(id(ax) for ax in axs.flat)) == 3 + # 2d indexing + assert axs[0, 0] is axs.flat[0] + assert axs[0, 2] is axs.flat[-1] + # single index + axs_type = type(axs) + assert type(axs[0]) is axs_type + assert axs[0].shape == (3,) + # iteration + assert all(self.contain_same_axes(axi, axs[i]) for i, axi in enumerate(axs)) + + def test_2d(self): + axs = plt.figure().subplots(2, 3) + # shape and size + assert axs.shape == (2, 3) + assert axs.size == 6 + assert axs.ndim == 2 + # flat + assert all(isinstance(ax, Axes) for ax in axs.flat) + assert len(set(id(ax) for ax in axs.flat)) == 6 + # flatten + assert all(isinstance(ax, Axes) for ax in axs.flatten()) + assert len(set(id(ax) for ax in axs.flatten())) == 6 + # ravel + assert all(isinstance(ax, Axes) for ax in axs.ravel()) + assert len(set(id(ax) for ax in axs.ravel())) == 6 + # 2d indexing + assert axs[0, 0] is axs.flat[0] + assert axs[1, 2] is axs.flat[-1] + # single index + axs_type = type(axs) + assert type(axs[0]) is axs_type + assert axs[0].shape == (3,) + # iteration + assert all(self.contain_same_axes(axi, axs[i]) for i, axi in enumerate(axs)) + + def test_deprecated(self): + axs = plt.figure().subplots(2, 2) + with pytest.warns(PendingDeprecationWarning, + match="Using 'diagonal' on AxesArray"): + axs.diagonal()
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: