Skip to content

Commit 06f4d5e

Browse files
authored
ENH: add __array_function__ protocol in polynomial (#28996)
1 parent 5378d3d commit 06f4d5e

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

numpy/polynomial/polynomial.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282

8383
import numpy as np
8484
import numpy.linalg as la
85+
from numpy._core.overrides import array_function_dispatch
8586
from numpy.lib.array_utils import normalize_axis_index
8687

8788
from . import polyutils as pu
@@ -841,7 +842,13 @@ def polyvalfromroots(x, r, tensor=True):
841842
raise ValueError("x.ndim must be < r.ndim when tensor == False")
842843
return np.prod(x - r, axis=0)
843844

845+
def _polyval2d_dispatcher(x, y, c):
846+
return (x, y, c)
844847

848+
def _polygrid2d_dispatcher(x, y, c):
849+
return (x, y, c)
850+
851+
@array_function_dispatch(_polyval2d_dispatcher)
845852
def polyval2d(x, y, c):
846853
"""
847854
Evaluate a 2-D polynomial at points (x, y).
@@ -893,7 +900,7 @@ def polyval2d(x, y, c):
893900
"""
894901
return pu._valnd(polyval, c, x, y)
895902

896-
903+
@array_function_dispatch(_polygrid2d_dispatcher)
897904
def polygrid2d(x, y, c):
898905
"""
899906
Evaluate a 2-D polynomial on the Cartesian product of x and y.

numpy/polynomial/tests/test_polynomial.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,3 +667,25 @@ def test_result_type(self):
667667

668668
arr = np.polydiv(1, np.float32(1))
669669
assert_equal(arr[0].dtype, np.float64)
670+
671+
class ArrayFunctionInterceptor:
672+
def __init__(self):
673+
self.called = False
674+
675+
def __array_function__(self, func, types, args, kwargs):
676+
self.called = True
677+
return "intercepted"
678+
679+
def test_polyval2d_array_function_hook():
680+
x = ArrayFunctionInterceptor()
681+
y = ArrayFunctionInterceptor()
682+
c = ArrayFunctionInterceptor()
683+
result = np.polynomial.polynomial.polyval2d(x, y, c)
684+
assert result == "intercepted"
685+
686+
def test_polygrid2d_array_function_hook():
687+
x = ArrayFunctionInterceptor()
688+
y = ArrayFunctionInterceptor()
689+
c = ArrayFunctionInterceptor()
690+
result = np.polynomial.polynomial.polygrid2d(x, y, c)
691+
assert result == "intercepted"

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy