diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index e868239e62ab..592940f3e328 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -25,7 +25,7 @@ import numpy as np import numpy.linalg._umath_linalg -from numpy import isfinite, isinf, isnan +from numpy import isfinite, isnan from numpy._core import arange, array, array_repr, empty, float32, intp, isnat, ndarray __all__ = [ @@ -758,17 +758,7 @@ def istime(x): def isvstring(x): return x.dtype.char == "T" - def func_assert_same_pos(x, y, func=isnan, hasval='nan'): - """Handling nan/inf. - - Combine results of running func on x and y, checking that they are True - at the same locations. - - """ - __tracebackhide__ = True # Hide traceback for py.test - - x_id = func(x) - y_id = func(y) + def robust_any_difference(x, y): # We include work-arounds here to handle three types of slightly # pathological ndarray subclasses: # (1) all() on fully masked arrays returns np.ma.masked, so we use != True @@ -781,10 +771,23 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'): # not implement np.all(), so favor using the .all() method # We are not committed to supporting cases (2) and (3), but it's nice to # support them if possible. - result = x_id == y_id + result = x == y if not hasattr(result, "all") or not callable(result.all): result = np.bool(result) - if result.all() != True: + return result.all() != True + + def func_assert_same_pos(x, y, func=isnan, hasval='nan'): + """Handling nan/inf. + + Combine results of running func on x and y, checking that they are True + at the same locations. + + """ + __tracebackhide__ = True # Hide traceback for py.test + + x_id = func(x) + y_id = func(y) + if robust_any_difference(x_id, y_id): msg = build_err_msg( [x, y], err_msg + '\n%s location mismatch:' @@ -804,6 +807,29 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'): else: return y_id + def assert_same_inf_values(x, y, infs_mask): + """ + Verify all inf values match in the two arrays + """ + __tracebackhide__ = True # Hide traceback for py.test + + if not infs_mask.any(): + return + if x.ndim > 0 and y.ndim > 0: + x = x[infs_mask] + y = y[infs_mask] + else: + assert infs_mask.all() + + if robust_any_difference(x, y): + msg = build_err_msg( + [x, y], + err_msg + '\ninf values mismatch:', + verbose=verbose, header=header, + names=names, + precision=precision) + raise AssertionError(msg) + try: if strict: cond = x.shape == y.shape and x.dtype == y.dtype @@ -828,12 +854,15 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'): flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan') if equal_inf: - flagged |= func_assert_same_pos(x, y, - func=lambda xy: xy == +inf, - hasval='+inf') - flagged |= func_assert_same_pos(x, y, - func=lambda xy: xy == -inf, - hasval='-inf') + # If equal_nan=True, skip comparing nans below for equality if they are + # also infs (e.g. inf+nanj) since that would always fail. + isinf_func = lambda xy: np.logical_and(np.isinf(xy), np.invert(flagged)) + infs_mask = func_assert_same_pos( + x, y, + func=isinf_func, + hasval='inf') + assert_same_inf_values(x, y, infs_mask) + flagged |= infs_mask elif istime(x) and istime(y): # If one is datetime64 and the other timedelta64 there is no point @@ -1183,24 +1212,9 @@ def assert_array_almost_equal(actual, desired, decimal=6, err_msg='', """ __tracebackhide__ = True # Hide traceback for py.test from numpy._core import number, result_type - from numpy._core.fromnumeric import any as npany from numpy._core.numerictypes import issubdtype def compare(x, y): - try: - if npany(isinf(x)) or npany(isinf(y)): - xinfid = isinf(x) - yinfid = isinf(y) - if not (xinfid == yinfid).all(): - return False - # if one item, x and y is +- inf - if x.size == y.size == 1: - return x == y - x = x[~xinfid] - y = y[~yinfid] - except (TypeError, NotImplementedError): - pass - # make sure y is an inexact type to avoid abs(MIN_INT); will cause # casting of x later. dtype = result_type(y, 1.) diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 243b8d420936..07f2c9da3005 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -612,6 +612,18 @@ def test_inf(self): assert_raises(AssertionError, lambda: self._assert_func(a, b)) + def test_complex_inf(self): + a = np.array([np.inf + 1.j, 2. + 1.j, 3. + 1.j]) + b = a.copy() + self._assert_func(a, b) + b[1] = 3. + 1.j + expected_msg = ('Mismatched elements: 1 / 3 (33.3%)\n' + 'Mismatch at index:\n' + ' [1]: (2+1j) (ACTUAL), (3+1j) (DESIRED)\n' + 'Max absolute difference among violations: 1.\n') + with pytest.raises(AssertionError, match=re.escape(expected_msg)): + self._assert_func(a, b) + def test_subclass(self): a = np.array([[1., 2.], [3., 4.]]) b = np.ma.masked_array([[1., 2.], [0., 4.]], @@ -1283,11 +1295,21 @@ def test_equal_nan(self): # Should not raise: assert_allclose(a, b, equal_nan=True) + a = np.array([complex(np.nan, np.inf)]) + b = np.array([complex(np.nan, np.inf)]) + assert_allclose(a, b, equal_nan=True) + b = np.array([complex(np.nan, -np.inf)]) + assert_allclose(a, b, equal_nan=True) + def test_not_equal_nan(self): a = np.array([np.nan]) b = np.array([np.nan]) assert_raises(AssertionError, assert_allclose, a, b, equal_nan=False) + a = np.array([complex(np.nan, np.inf)]) + b = np.array([complex(np.nan, np.inf)]) + assert_raises(AssertionError, assert_allclose, a, b, equal_nan=False) + def test_equal_nan_default(self): # Make sure equal_nan default behavior remains unchanged. (All # of these functions use assert_array_compare under the hood.) @@ -1336,6 +1358,33 @@ def test_strict(self): with pytest.raises(AssertionError): assert_allclose(x, x.astype(np.float32), strict=True) + def test_infs(self): + a = np.array([np.inf]) + b = np.array([np.inf]) + assert_allclose(a, b) + + b = np.array([3.]) + expected_msg = 'inf location mismatch:' + with pytest.raises(AssertionError, match=re.escape(expected_msg)): + assert_allclose(a, b) + + b = np.array([-np.inf]) + expected_msg = 'inf values mismatch:' + with pytest.raises(AssertionError, match=re.escape(expected_msg)): + assert_allclose(a, b) + b = np.array([complex(np.inf, 1.)]) + expected_msg = 'inf values mismatch:' + with pytest.raises(AssertionError, match=re.escape(expected_msg)): + assert_allclose(a, b) + + a = np.array([complex(np.inf, 1.)]) + b = np.array([complex(np.inf, 1.)]) + assert_allclose(a, b) + + b = np.array([complex(np.inf, 2.)]) + expected_msg = 'inf values mismatch:' + with pytest.raises(AssertionError, match=re.escape(expected_msg)): + assert_allclose(a, b) class TestArrayAlmostEqualNulp: pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy