diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index e2481229..d7be6b47 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -19,7 +19,7 @@ from dataclasses import dataclass, field from decimal import ROUND_HALF_EVEN, Decimal from enum import Enum, auto -from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple +from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Literal from warnings import warn import pytest @@ -544,6 +544,10 @@ class UnaryCase(Case): "If two integers are equally close to ``x_i``, " "the result is the even integer closest to ``x_i``" ) +r_nan_signbit = re.compile( + "If ``x_i`` is ``NaN`` and the sign bit of ``x_i`` is ``(.+)``, " + "the result is ``(.+)``" +) def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: @@ -599,6 +603,25 @@ def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]: ) +def make_nan_signbit_case(signbit: Literal[0, 1], expected: bool) -> UnaryCase: + if signbit: + sign = -1 + nan_expr = "-NaN" + float_arg = "-nan" + else: + sign = 1 + nan_expr = "+NaN" + float_arg = "nan" + + return UnaryCase( + cond_expr=f"x_i is {nan_expr}", + cond=lambda i: math.isnan(i) and math.copysign(1, i) == sign, + cond_from_dtype=lambda _: st.just(float(float_arg)), + result_expr=str(expected), + check_result=lambda _, result: result == float(expected), + ) + + def make_unary_check_result(check_just_result: UnaryCheck) -> UnaryResultCheck: def check_result(i: float, result: float) -> bool: return check_just_result(result) @@ -655,10 +678,14 @@ def parse_unary_case_block(case_block: str) -> List[UnaryCase]: cases = [] for case_m in r_case.finditer(case_block): case_str = case_m.group(1) - if m := r_already_int_case.search(case_str): + if r_already_int_case.search(case_str): cases.append(already_int_case) - elif m := r_even_round_halves_case.search(case_str): + elif r_even_round_halves_case.search(case_str): cases.append(even_round_halves_case) + elif m := r_nan_signbit.search(case_str): + signbit = parse_value(m.group(1)) + expected = bool(parse_value(m.group(2))) + cases.append(make_nan_signbit_case(signbit, expected)) elif m := r_unary_case.search(case_str): try: cond, cond_expr_template, cond_from_dtype = parse_cond(m.group(1)) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy