From 70fb777e6e68b60f382d581d06d4c2d28ba29294 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 12 Apr 2024 13:20:31 +0100 Subject: [PATCH] Fix special case testing signbit on NaNs --- array_api_tests/test_special_cases.py | 33 ++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index e2481229..d7be6b47 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -19,7 +19,7 @@ from dataclasses import dataclass, field from decimal import ROUND_HALF_EVEN, Decimal from enum import Enum, auto -from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple +from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Literal from warnings import warn import pytest @@ -544,6 +544,10 @@ class UnaryCase(Case): "If two integers are equally close to ``x_i``, " "the result is the even integer closest to ``x_i``" ) +r_nan_signbit = re.compile( + "If ``x_i`` is ``NaN`` and the sign bit of ``x_i`` is ``(.+)``, " + "the result is ``(.+)``" +) def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: @@ -599,6 +603,25 @@ def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]: ) +def make_nan_signbit_case(signbit: Literal[0, 1], expected: bool) -> UnaryCase: + if signbit: + sign = -1 + nan_expr = "-NaN" + float_arg = "-nan" + else: + sign = 1 + nan_expr = "+NaN" + float_arg = "nan" + + return UnaryCase( + cond_expr=f"x_i is {nan_expr}", + cond=lambda i: math.isnan(i) and math.copysign(1, i) == sign, + cond_from_dtype=lambda _: st.just(float(float_arg)), + result_expr=str(expected), + check_result=lambda _, result: result == float(expected), + ) + + def make_unary_check_result(check_just_result: UnaryCheck) -> UnaryResultCheck: def check_result(i: float, result: float) -> bool: return check_just_result(result) @@ -655,10 +678,14 @@ def parse_unary_case_block(case_block: str) -> List[UnaryCase]: cases = [] for case_m in r_case.finditer(case_block): case_str = case_m.group(1) - if m := r_already_int_case.search(case_str): + if r_already_int_case.search(case_str): cases.append(already_int_case) - elif m := r_even_round_halves_case.search(case_str): + elif r_even_round_halves_case.search(case_str): cases.append(even_round_halves_case) + elif m := r_nan_signbit.search(case_str): + signbit = parse_value(m.group(1)) + expected = bool(parse_value(m.group(2))) + cases.append(make_nan_signbit_case(signbit, expected)) elif m := r_unary_case.search(case_str): try: cond, cond_expr_template, cond_from_dtype = parse_cond(m.group(1)) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy