diff --git a/lib/matplotlib/dates.py b/lib/matplotlib/dates.py index bd27bef3759a..ef536e9c1740 100644 --- a/lib/matplotlib/dates.py +++ b/lib/matplotlib/dates.py @@ -118,6 +118,7 @@ import time import math import datetime +import functools import warnings @@ -732,20 +733,105 @@ def __call__(self, x, pos=None): class rrulewrapper(object): + def __init__(self, freq, tzinfo=None, **kwargs): + kwargs['freq'] = freq + self._base_tzinfo = tzinfo - def __init__(self, freq, **kwargs): - self._construct = kwargs.copy() - self._construct["freq"] = freq - self._rrule = rrule(**self._construct) + self._update_rrule(**kwargs) def set(self, **kwargs): self._construct.update(kwargs) + + self._update_rrule(**self._construct) + + def _update_rrule(self, **kwargs): + tzinfo = self._base_tzinfo + + # rrule does not play nicely with time zones - especially pytz time + # zones, it's best to use naive zones and attach timezones once the + # datetimes are returned + if 'dtstart' in kwargs: + dtstart = kwargs['dtstart'] + if dtstart.tzinfo is not None: + if tzinfo is None: + tzinfo = dtstart.tzinfo + else: + dtstart = dtstart.astimezone(tzinfo) + + kwargs['dtstart'] = dtstart.replace(tzinfo=None) + + if 'until' in kwargs: + until = kwargs['until'] + if until.tzinfo is not None: + if tzinfo is not None: + until = until.astimezone(tzinfo) + else: + raise ValueError('until cannot be aware if dtstart ' + 'is naive and tzinfo is None') + + kwargs['until'] = until.replace(tzinfo=None) + + self._construct = kwargs.copy() + self._tzinfo = tzinfo self._rrule = rrule(**self._construct) + def _attach_tzinfo(self, dt, tzinfo): + # pytz zones are attached by "localizing" the datetime + if hasattr(tzinfo, 'localize'): + return tzinfo.localize(dt, is_dst=True) + + return dt.replace(tzinfo=tzinfo) + + def _aware_return_wrapper(self, f, returns_list=False): + """Decorator function that allows rrule methods to handle tzinfo.""" + # This is only necessary if we're actually attaching a tzinfo + if self._tzinfo is None: + return f + + # All datetime arguments must be naive. If they are not naive, they are + # converted to the _tzinfo zone before dropping the zone. + def normalize_arg(arg): + if isinstance(arg, datetime.datetime) and arg.tzinfo is not None: + if arg.tzinfo is not self._tzinfo: + arg = arg.astimezone(self._tzinfo) + + return arg.replace(tzinfo=None) + + return arg + + def normalize_args(args, kwargs): + args = tuple(normalize_arg(arg) for arg in args) + kwargs = {kw: normalize_arg(arg) for kw, arg in kwargs.items()} + + return args, kwargs + + # There are two kinds of functions we care about - ones that return + # dates and ones that return lists of dates. + if not returns_list: + def inner_func(*args, **kwargs): + args, kwargs = normalize_args(args, kwargs) + dt = f(*args, **kwargs) + return self._attach_tzinfo(dt, self._tzinfo) + else: + def inner_func(*args, **kwargs): + args, kwargs = normalize_args(args, kwargs) + dts = f(*args, **kwargs) + return [self._attach_tzinfo(dt, self._tzinfo) for dt in dts] + + return functools.wraps(f)(inner_func) + def __getattr__(self, name): if name in self.__dict__: return self.__dict__[name] - return getattr(self._rrule, name) + + f = getattr(self._rrule, name) + + if name in {'after', 'before'}: + return self._aware_return_wrapper(f) + elif name in {'xafter', 'xbefore', 'between'}: + return self._aware_return_wrapper(f, returns_list=True) + else: + return f def __setstate__(self, state): self.__dict__.update(state) @@ -1226,7 +1312,7 @@ def __init__(self, bymonth=None, bymonthday=1, interval=1, tz=None): bymonth = [x.item() for x in bymonth.astype(int)] rule = rrulewrapper(MONTHLY, bymonth=bymonth, bymonthday=bymonthday, - interval=interval, **self.hms0d) + interval=interval, **self.hms0d) RRuleLocator.__init__(self, rule, tz) diff --git a/lib/matplotlib/tests/test_dates.py b/lib/matplotlib/tests/test_dates.py index 5a25e6182b7e..792341ee1527 100644 --- a/lib/matplotlib/tests/test_dates.py +++ b/lib/matplotlib/tests/test_dates.py @@ -442,6 +442,24 @@ def tz_convert(*args): _test_date2num_dst(pd.date_range, tz_convert) +@pytest.mark.parametrize("attach_tz, get_tz", [ + (lambda dt, zi: zi.localize(dt), lambda n: pytz.timezone(n)), + (lambda dt, zi: dt.replace(tzinfo=zi), lambda n: dateutil.tz.gettz(n))]) +def test_rrulewrapper(attach_tz, get_tz): + SYD = get_tz('Australia/Sydney') + + dtstart = attach_tz(datetime.datetime(2017, 4, 1, 0), SYD) + dtend = attach_tz(datetime.datetime(2017, 4, 4, 0), SYD) + + rule = mdates.rrulewrapper(freq=dateutil.rrule.DAILY, dtstart=dtstart) + + act = rule.between(dtstart, dtend) + exp = [datetime.datetime(2017, 4, 1, 13, tzinfo=dateutil.tz.tzutc()), + datetime.datetime(2017, 4, 2, 14, tzinfo=dateutil.tz.tzutc())] + + assert act == exp + + def test_DayLocator(): with pytest.raises(ValueError): mdates.DayLocator(interval=-1)
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: