diff --git a/lib/matplotlib/category.py b/lib/matplotlib/category.py index a1ca001fe587..8b64c7d6cb98 100644 --- a/lib/matplotlib/category.py +++ b/lib/matplotlib/category.py @@ -4,15 +4,35 @@ """ from __future__ import (absolute_import, division, print_function, unicode_literals) - import six import numpy as np -import matplotlib.cbook as cbook import matplotlib.units as units import matplotlib.ticker as ticker +# np 1.6/1.7 support +from distutils.version import LooseVersion +import collections + + +if LooseVersion(np.__version__) >= LooseVersion('1.8.0'): + def shim_array(data): + return np.array(data, dtype=np.unicode) +else: + def shim_array(data): + if (isinstance(data, six.string_types) or + not isinstance(data, collections.Iterable)): + data = [data] + try: + data = [str(d) for d in data] + except UnicodeEncodeError: + # this yields gibberish but unicode text doesn't + # render under numpy1.6 anyway + data = [d.encode('utf-8', 'ignore').decode('utf-8') + for d in data] + return np.array(data, dtype=np.unicode) + class StrCategoryConverter(units.ConversionInterface): @staticmethod @@ -25,7 +45,8 @@ def convert(value, unit, axis): if isinstance(value, six.string_types): return vmap[value] - vals = np.array(value, dtype=np.unicode) + vals = shim_array(value) + for lab, loc in vmap.items(): vals[vals == lab] = loc @@ -81,8 +102,7 @@ def update(self, new_data): self._set_seq_locs(new_data, value) def _set_seq_locs(self, data, value): - strdata = np.array(data, dtype=np.unicode) - # np.unique makes dateframes work + strdata = shim_array(data) new_s = [d for d in np.unique(strdata) if d not in self.seq] for ns in new_s: self.seq.append(ns) diff --git a/lib/matplotlib/tests/test_category.py b/lib/matplotlib/tests/test_category.py index 83847e3150bb..6e5c43d76fb9 100644 --- a/lib/matplotlib/tests/test_category.py +++ b/lib/matplotlib/tests/test_category.py @@ -3,8 +3,6 @@ from __future__ import (absolute_import, division, print_function, unicode_literals) -from distutils.version import LooseVersion - import pytest import numpy as np @@ -14,11 +12,6 @@ import unittest -needs_new_numpy = pytest.mark.xfail( - LooseVersion(np.__version__) < LooseVersion('1.8.0'), - reason='NumPy < 1.8.0 is broken.') - - class TestUnitData(object): testdata = [("hello world", ["hello world"], [0]), ("Здравствуйте мир", ["Здравствуйте мир"], [0]), @@ -28,14 +21,12 @@ class TestUnitData(object): ids = ["single", "unicode", "mixed"] - @needs_new_numpy @pytest.mark.parametrize("data, seq, locs", testdata, ids=ids) def test_unit(self, data, seq, locs): act = cat.UnitData(data) assert act.seq == seq assert act.locs == locs - @needs_new_numpy def test_update_map(self): data = ['a', 'd'] oseq = ['a', 'd'] @@ -87,7 +78,6 @@ class TestStrCategoryConverter(object): def mock_axis(self, request): self.cc = cat.StrCategoryConverter() - @needs_new_numpy @pytest.mark.parametrize("data, unitmap, exp", testdata, ids=ids) def test_convert(self, data, unitmap, exp): MUD = MockUnitData(unitmap)
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: