Skip to content

Commit b5737e8

Browse files
committed
ENH: add __array_function__ protocol in polynomial
1 parent c807e09 commit b5737e8

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

numpy/polynomial/polynomial.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
import numpy as np
8484
import numpy.linalg as la
8585
from numpy.lib.array_utils import normalize_axis_index
86+
from numpy._core.overrides import array_function_dispatch
8687

8788
from . import polyutils as pu
8889
from ._polybase import ABCPolyBase
@@ -841,7 +842,13 @@ def polyvalfromroots(x, r, tensor=True):
841842
raise ValueError("x.ndim must be < r.ndim when tensor == False")
842843
return np.prod(x - r, axis=0)
843844

845+
def _polyval2d_dispatcher(x, y, c):
846+
return (x, y, c)
844847

848+
def _polygrid2d_dispatcher(x, y, c):
849+
return (x, y, c)
850+
851+
@array_function_dispatch(_polyval2d_dispatcher)
845852
def polyval2d(x, y, c):
846853
"""
847854
Evaluate a 2-D polynomial at points (x, y).
@@ -893,7 +900,7 @@ def polyval2d(x, y, c):
893900
"""
894901
return pu._valnd(polyval, c, x, y)
895902

896-
903+
@array_function_dispatch(_polygrid2d_dispatcher)
897904
def polygrid2d(x, y, c):
898905
"""
899906
Evaluate a 2-D polynomial on the Cartesian product of x and y.

numpy/polynomial/tests/test_polynomial.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,3 +667,25 @@ def test_result_type(self):
667667

668668
arr = np.polydiv(1, np.float32(1))
669669
assert_equal(arr[0].dtype, np.float64)
670+
671+
class ArrayFunctionInterceptor:
672+
def __init__(self):
673+
self.called = False
674+
675+
def __array_function__(self, func, types, args, kwargs):
676+
self.called = True
677+
return "intercepted"
678+
679+
def test_polyval2d_array_function_hook():
680+
x = ArrayFunctionInterceptor()
681+
y = ArrayFunctionInterceptor()
682+
c = ArrayFunctionInterceptor()
683+
result = np.polynomial.polynomial.polyval2d(x, y, c)
684+
assert result == "intercepted"
685+
686+
def test_polygrid2d_array_function_hook():
687+
x = ArrayFunctionInterceptor()
688+
y = ArrayFunctionInterceptor()
689+
c = ArrayFunctionInterceptor()
690+
result = np.polynomial.polynomial.polygrid2d(x, y, c)
691+
assert result == "intercepted"

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy