Skip to content

Commit 8f68377

Browse files
authored
BUG: Any dtype should call square on arr ** 2 (#29392)
* BUG: update fast_scalar_power to handle special-case squaring for any array type except object arrays * BUG: fix missing declaration * TST: add test to ensure `arr**2` calls square for structured dtypes * STY: remove whitespace * BUG: replace new variable `is_square` with direct op comparison in `fast_scalar_power` function
1 parent f13cf6e commit 8f68377

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

numpy/_core/src/multiarray/number.c

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ static int
332332
fast_scalar_power(PyObject *o1, PyObject *o2, int inplace, PyObject **result)
333333
{
334334
PyObject *fastop = NULL;
335+
335336
if (PyLong_CheckExact(o2)) {
336337
int overflow = 0;
337338
long exp = PyLong_AsLongAndOverflow(o2, &overflow);
@@ -363,7 +364,12 @@ fast_scalar_power(PyObject *o1, PyObject *o2, int inplace, PyObject **result)
363364
}
364365

365366
PyArrayObject *a1 = (PyArrayObject *)o1;
366-
if (!(PyArray_ISFLOAT(a1) || PyArray_ISCOMPLEX(a1))) {
367+
if (PyArray_ISOBJECT(a1)) {
368+
return 1;
369+
}
370+
if (fastop != n_ops.square && !PyArray_ISFLOAT(a1) && !PyArray_ISCOMPLEX(a1)) {
371+
// we special-case squaring for any array type
372+
// gh-29388
367373
return 1;
368374
}
369375

numpy/_core/tests/test_multiarray.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4215,6 +4215,13 @@ def pow_for(exp, arr):
42154215
assert_equal(obj_arr ** -1, pow_for(-1, obj_arr))
42164216
assert_equal(obj_arr ** 2, pow_for(2, obj_arr))
42174217

4218+
def test_pow_calls_square_structured_dtype(self):
4219+
# gh-29388
4220+
dt = np.dtype([('a', 'i4'), ('b', 'i4')])
4221+
a = np.array([(1, 2), (3, 4)], dtype=dt)
4222+
with pytest.raises(TypeError, match="ufunc 'square' not supported"):
4223+
a ** 2
4224+
42184225
def test_pos_array_ufunc_override(self):
42194226
class A(np.ndarray):
42204227
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy