Skip to content

Commit 6305e8d

Browse files
authored
Merge pull request #16178 from yozhikoff/add-multiple-label-support
ENH: Add multiple label support for Axes.plot()
2 parents 99e6240 + a161ae3 commit 6305e8d

File tree

4 files changed

+105
-2
lines changed

4 files changed

+105
-2
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
An iterable object with labels can be passed to `.Axes.plot`
2+
------------------------------------------------------------
3+
4+
When plotting multiple datasets by passing 2D data as *y* value to
5+
`~.Axes.plot`, labels for the datasets can be passed as a list, the
6+
length matching the number of columns in *y*.
7+
8+
.. plot::
9+
10+
import matplotlib.pyplot as plt
11+
12+
x = [1, 2, 3]
13+
14+
y = [[1, 2],
15+
[2, 5],
16+
[4, 9]]
17+
18+
plt.plot(x, y, label=['low', 'high'])
19+
plt.legend()

lib/matplotlib/axes/_axes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1499,6 +1499,8 @@ def plot(self, *args, scalex=True, scaley=True, data=None, **kwargs):
14991499
15001500
If you make multiple lines with one plot call, the kwargs
15011501
apply to all those lines.
1502+
In case if label object is iterable, each its element is
1503+
used as label for a separate line.
15021504
15031505
Here is a list of available `.Line2D` properties:
15041506

lib/matplotlib/axes/_base.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,12 @@ def __call__(self, *args, data=None, **kwargs):
294294
replaced[label_namer_idx], args[label_namer_idx])
295295
args = replaced
296296

297+
if len(args) >= 4 and not cbook.is_scalar_or_string(
298+
kwargs.get("label")):
299+
raise ValueError("plot() with multiple groups of data (i.e., "
300+
"pairs of x and y) does not support multiple "
301+
"labels")
302+
297303
# Repeatedly grab (x, y) or (x, y, format) from the front of args and
298304
# massage them into arguments to plot() or fill().
299305

@@ -447,8 +453,22 @@ def _plot_args(self, tup, kwargs, return_kwargs=False):
447453
ncx, ncy = x.shape[1], y.shape[1]
448454
if ncx > 1 and ncy > 1 and ncx != ncy:
449455
raise ValueError(f"x has {ncx} columns but y has {ncy} columns")
450-
result = (func(x[:, j % ncx], y[:, j % ncy], kw, kwargs)
451-
for j in range(max(ncx, ncy)))
456+
457+
label = kwargs.get('label')
458+
n_datasets = max(ncx, ncy)
459+
if n_datasets > 1 and not cbook.is_scalar_or_string(label):
460+
if len(label) != n_datasets:
461+
raise ValueError(f"label must be scalar or have the same "
462+
f"length as the input data, but found "
463+
f"{len(label)} for {n_datasets} datasets.")
464+
labels = label
465+
else:
466+
labels = [label] * n_datasets
467+
468+
result = (func(x[:, j % ncx], y[:, j % ncy], kw,
469+
{**kwargs, 'label': label})
470+
for j, label in enumerate(labels))
471+
452472
if return_kwargs:
453473
return list(result)
454474
else:

lib/matplotlib/tests/test_legend.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,3 +671,65 @@ def test_no_warn_big_data_when_loc_specified():
671671
ax.plot(np.arange(5000), label=idx)
672672
legend = ax.legend('best')
673673
fig.draw_artist(legend) # Check that no warning is emitted.
674+
675+
676+
@pytest.mark.parametrize('label_array', [['low', 'high'],
677+
('low', 'high'),
678+
np.array(['low', 'high'])])
679+
def test_plot_multiple_input_multiple_label(label_array):
680+
# test ax.plot() with multidimensional input
681+
# and multiple labels
682+
x = [1, 2, 3]
683+
y = [[1, 2],
684+
[2, 5],
685+
[4, 9]]
686+
687+
fig, ax = plt.subplots()
688+
ax.plot(x, y, label=label_array)
689+
leg = ax.legend()
690+
legend_texts = [entry.get_text() for entry in leg.get_texts()]
691+
assert legend_texts == ['low', 'high']
692+
693+
694+
@pytest.mark.parametrize('label', ['one', 1, int])
695+
def test_plot_multiple_input_single_label(label):
696+
# test ax.plot() with multidimensional input
697+
# and single label
698+
x = [1, 2, 3]
699+
y = [[1, 2],
700+
[2, 5],
701+
[4, 9]]
702+
703+
fig, ax = plt.subplots()
704+
ax.plot(x, y, label=label)
705+
leg = ax.legend()
706+
legend_texts = [entry.get_text() for entry in leg.get_texts()]
707+
assert legend_texts == [str(label)] * 2
708+
709+
710+
@pytest.mark.parametrize('label_array', [['low', 'high'],
711+
('low', 'high'),
712+
np.array(['low', 'high'])])
713+
def test_plot_single_input_multiple_label(label_array):
714+
# test ax.plot() with 1D array like input
715+
# and iterable label
716+
x = [1, 2, 3]
717+
y = [2, 5, 6]
718+
fig, ax = plt.subplots()
719+
ax.plot(x, y, label=label_array)
720+
leg = ax.legend()
721+
assert len(leg.get_texts()) == 1
722+
assert leg.get_texts()[0].get_text() == str(label_array)
723+
724+
725+
def test_plot_multiple_label_incorrect_length_exception():
726+
# check that excepton is raised if multiple labels
727+
# are given, but number of on labels != number of lines
728+
with pytest.raises(ValueError):
729+
x = [1, 2, 3]
730+
y = [[1, 2],
731+
[2, 5],
732+
[4, 9]]
733+
label = ['high', 'low', 'medium']
734+
fig, ax = plt.subplots()
735+
ax.plot(x, y, label=label)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy