Skip to content

Commit 52162af

Browse files
authored
MAINT: simplify power fast path logic (#27901)
* MAINT: remove fast paths from array power * MAINT: Add fast paths to power loops * MAINT: Clean loops for integer power in umath * MAINT: Remove blocking regression test for power fast paths * MAINT: Add helper function for power fast paths * BUG: Change misspelled bitwise and to logical and * BUG: Fix missing value on power helper return * BUG: Fix exponent bitwise logic in power fast paths * MAINT: Add power fast paths to floating point umath * MAINT: Add fast power paths to array power when exponent is python object * MAINT: Fix division by zero runtime warning in test regression * MAINT: Adapt object regression test for linalg to power fast paths * MAINT: Remove incorrect declarations in power fast paths * MAINT: Reduce calls to power fast path helper when scalar is ineligible * MAINT: Fix missing sliding loop * BUG: Fix syntax error * MAINT: Fix semantic misuse of -1 for non-error returns * MAINT: Improve error checking in power fast paths to remove PyErr_Clear * MAINT: Improve type checking in power fast paths * MAINT: Efficient handling of ones arrays in scalar fast paths * MAINT: Simplify outer check for scalar power fast paths * MAINT: Reduce code reuse in float power fast paths and add reciprocal * MAINT: Remove Python scalar checking for fast power paths * MAINT: Add benchmarks for power operators in float binary bench * MAINT: Add scalar power fast paths * BUG: Add missing pointer cast * BUG: Allow scalar power fast paths only for non-integers * MAINT: Restore outdated changes in regression test to master
1 parent 1d77082 commit 52162af

File tree

5 files changed

+105
-192
lines changed

5 files changed

+105
-192
lines changed

benchmarks/benchmarks/bench_ufunc.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,12 @@ def time_pow_2(self, dtype):
588588
def time_pow_half(self, dtype):
589589
np.power(self.a, 0.5)
590590

591+
def time_pow_2_op(self, dtype):
592+
self.a ** 2
593+
594+
def time_pow_half_op(self, dtype):
595+
self.a ** 0.5
596+
591597
def time_atan2(self, dtype):
592598
np.arctan2(self.a, self.b)
593599

numpy/_core/src/multiarray/number.c

Lines changed: 36 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -328,165 +328,53 @@ array_inplace_matrix_multiply(PyArrayObject *self, PyObject *other)
328328
return res;
329329
}
330330

331-
/*
332-
* Determine if object is a scalar and if so, convert the object
333-
* to a double and place it in the out_exponent argument
334-
* and return the "scalar kind" as a result. If the object is
335-
* not a scalar (or if there are other error conditions)
336-
* return NPY_NOSCALAR, and out_exponent is undefined.
337-
*/
338-
static NPY_SCALARKIND
339-
is_scalar_with_conversion(PyObject *o2, double* out_exponent)
331+
static int
332+
fast_scalar_power(PyObject *o1, PyObject *o2, int inplace, PyObject **result)
340333
{
341-
PyObject *temp;
342-
const int optimize_fpexps = 1;
343-
344-
if (PyLong_Check(o2)) {
345-
long tmp = PyLong_AsLong(o2);
346-
if (error_converting(tmp)) {
347-
PyErr_Clear();
348-
return NPY_NOSCALAR;
334+
PyObject *fastop = NULL;
335+
if (PyLong_CheckExact(o2)) {
336+
int overflow = 0;
337+
long exp = PyLong_AsLongAndOverflow(o2, &overflow);
338+
if (overflow != 0) {
339+
return -1;
349340
}
350-
*out_exponent = (double)tmp;
351-
return NPY_INTPOS_SCALAR;
352-
}
353341

354-
if (optimize_fpexps && PyFloat_Check(o2)) {
355-
*out_exponent = PyFloat_AsDouble(o2);
356-
return NPY_FLOAT_SCALAR;
357-
}
358-
359-
if (PyArray_Check(o2)) {
360-
if ((PyArray_NDIM((PyArrayObject *)o2) == 0) &&
361-
((PyArray_ISINTEGER((PyArrayObject *)o2) ||
362-
(optimize_fpexps && PyArray_ISFLOAT((PyArrayObject *)o2))))) {
363-
temp = Py_TYPE(o2)->tp_as_number->nb_float(o2);
364-
if (temp == NULL) {
365-
return NPY_NOSCALAR;
366-
}
367-
*out_exponent = PyFloat_AsDouble(o2);
368-
Py_DECREF(temp);
369-
if (PyArray_ISINTEGER((PyArrayObject *)o2)) {
370-
return NPY_INTPOS_SCALAR;
371-
}
372-
else { /* ISFLOAT */
373-
return NPY_FLOAT_SCALAR;
374-
}
342+
if (exp == -1) {
343+
fastop = n_ops.reciprocal;
375344
}
376-
}
377-
else if (PyArray_IsScalar(o2, Integer) ||
378-
(optimize_fpexps && PyArray_IsScalar(o2, Floating))) {
379-
temp = Py_TYPE(o2)->tp_as_number->nb_float(o2);
380-
if (temp == NULL) {
381-
return NPY_NOSCALAR;
382-
}
383-
*out_exponent = PyFloat_AsDouble(o2);
384-
Py_DECREF(temp);
385-
386-
if (PyArray_IsScalar(o2, Integer)) {
387-
return NPY_INTPOS_SCALAR;
345+
else if (exp == 2) {
346+
fastop = n_ops.square;
388347
}
389-
else { /* IsScalar(o2, Floating) */
390-
return NPY_FLOAT_SCALAR;
348+
else {
349+
return 1;
391350
}
392351
}
393-
else if (PyIndex_Check(o2)) {
394-
PyObject* value = PyNumber_Index(o2);
395-
Py_ssize_t val;
396-
if (value == NULL) {
397-
if (PyErr_Occurred()) {
398-
PyErr_Clear();
399-
}
400-
return NPY_NOSCALAR;
352+
else if (PyFloat_CheckExact(o2)) {
353+
double exp = PyFloat_AsDouble(o2);
354+
if (exp == 0.5) {
355+
fastop = n_ops.sqrt;
401356
}
402-
val = PyLong_AsSsize_t(value);
403-
Py_DECREF(value);
404-
if (error_converting(val)) {
405-
PyErr_Clear();
406-
return NPY_NOSCALAR;
357+
else {
358+
return 1;
407359
}
408-
*out_exponent = (double) val;
409-
return NPY_INTPOS_SCALAR;
410360
}
411-
return NPY_NOSCALAR;
412-
}
361+
else {
362+
return 1;
363+
}
413364

414-
/*
415-
* optimize float array or complex array to a scalar power
416-
* returns 0 on success, -1 if no optimization is possible
417-
* the result is in value (can be NULL if an error occurred)
418-
*/
419-
static int
420-
fast_scalar_power(PyObject *o1, PyObject *o2, int inplace,
421-
PyObject **value)
422-
{
423-
double exponent;
424-
NPY_SCALARKIND kind; /* NPY_NOSCALAR is not scalar */
425-
426-
if (PyArray_Check(o1) &&
427-
!PyArray_ISOBJECT((PyArrayObject *)o1) &&
428-
((kind=is_scalar_with_conversion(o2, &exponent))>0)) {
429-
PyArrayObject *a1 = (PyArrayObject *)o1;
430-
PyObject *fastop = NULL;
431-
if (PyArray_ISFLOAT(a1) || PyArray_ISCOMPLEX(a1)) {
432-
if (exponent == 1.0) {
433-
fastop = n_ops.positive;
434-
}
435-
else if (exponent == -1.0) {
436-
fastop = n_ops.reciprocal;
437-
}
438-
else if (exponent == 0.0) {
439-
fastop = n_ops._ones_like;
440-
}
441-
else if (exponent == 0.5) {
442-
fastop = n_ops.sqrt;
443-
}
444-
else if (exponent == 2.0) {
445-
fastop = n_ops.square;
446-
}
447-
else {
448-
return -1;
449-
}
365+
PyArrayObject *a1 = (PyArrayObject *)o1;
366+
if (!(PyArray_ISFLOAT(a1) || PyArray_ISCOMPLEX(a1))) {
367+
return 1;
368+
}
450369

451-
if (inplace || can_elide_temp_unary(a1)) {
452-
*value = PyArray_GenericInplaceUnaryFunction(a1, fastop);
453-
}
454-
else {
455-
*value = PyArray_GenericUnaryFunction(a1, fastop);
456-
}
457-
return 0;
458-
}
459-
/* Because this is called with all arrays, we need to
460-
* change the output if the kind of the scalar is different
461-
* than that of the input and inplace is not on ---
462-
* (thus, the input should be up-cast)
463-
*/
464-
else if (exponent == 2.0) {
465-
fastop = n_ops.square;
466-
if (inplace) {
467-
*value = PyArray_GenericInplaceUnaryFunction(a1, fastop);
468-
}
469-
else {
470-
/* We only special-case the FLOAT_SCALAR and integer types */
471-
if (kind == NPY_FLOAT_SCALAR && PyArray_ISINTEGER(a1)) {
472-
PyArray_Descr *dtype = PyArray_DescrFromType(NPY_DOUBLE);
473-
a1 = (PyArrayObject *)PyArray_CastToType(a1, dtype,
474-
PyArray_ISFORTRAN(a1));
475-
if (a1 != NULL) {
476-
/* cast always creates a new array */
477-
*value = PyArray_GenericInplaceUnaryFunction(a1, fastop);
478-
Py_DECREF(a1);
479-
}
480-
}
481-
else {
482-
*value = PyArray_GenericUnaryFunction(a1, fastop);
483-
}
484-
}
485-
return 0;
486-
}
370+
if (inplace || can_elide_temp_unary(a1)) {
371+
*result = PyArray_GenericInplaceUnaryFunction(a1, fastop);
487372
}
488-
/* no fast operation found */
489-
return -1;
373+
else {
374+
*result = PyArray_GenericUnaryFunction(a1, fastop);
375+
}
376+
377+
return 0;
490378
}
491379

492380
static PyObject *
@@ -643,7 +531,8 @@ array_inplace_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo
643531

644532
INPLACE_GIVE_UP_IF_NEEDED(
645533
a1, o2, nb_inplace_power, array_inplace_power);
646-
if (fast_scalar_power((PyObject *)a1, o2, 1, &value) != 0) {
534+
535+
if (fast_scalar_power((PyObject *) a1, o2, 1, &value) != 0) {
647536
value = PyArray_GenericInplaceBinaryFunction(a1, o2, n_ops.power);
648537
}
649538
return value;

numpy/_core/src/umath/loops.c.src

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -486,28 +486,54 @@ _@TYPE@_squared_exponentiation_helper(@type@ base, @type@ exponent_two, int firs
486486
return out;
487487
}
488488

489+
static inline @type@
490+
_@TYPE@_power_fast_path_helper(@type@ in1, @type@ in2, @type@ *op1) {
491+
// Fast path for power calculation
492+
if (in2 == 0 || in1 == 1) {
493+
*op1 = 1;
494+
}
495+
else if (in2 == 1) {
496+
*op1 = in1;
497+
}
498+
else if (in2 == 2) {
499+
*op1 = in1 * in1;
500+
}
501+
else {
502+
return 1;
503+
}
504+
return 0;
505+
}
506+
507+
489508
NPY_NO_EXPORT void
490509
@TYPE@_power(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
491510
{
492511
if (steps[1]==0) {
493512
// stride for second argument is 0
494513
BINARY_DEFS
495514
const @type@ in2 = *(@type@ *)ip2;
496-
#if @SIGNED@
497-
if (in2 < 0) {
498-
npy_gil_error(PyExc_ValueError,
499-
"Integers to negative integer powers are not allowed.");
500-
return;
501-
}
502-
#endif
515+
516+
#if @SIGNED@
517+
if (in2 < 0) {
518+
npy_gil_error(PyExc_ValueError,
519+
"Integers to negative integer powers are not allowed.");
520+
return;
521+
}
522+
#endif
503523

504524
int first_bit = in2 & 1;
505525
@type@ in2start = in2 >> 1;
506526

527+
int fastop_exists = (in2 == 0) || (in2 == 1) || (in2 == 2);
528+
507529
BINARY_LOOP_SLIDING {
508530
@type@ in1 = *(@type@ *)ip1;
509-
510-
*((@type@ *) op1) = _@TYPE@_squared_exponentiation_helper(in1, in2start, first_bit);
531+
if (fastop_exists) {
532+
_@TYPE@_power_fast_path_helper(in1, in2, (@type@ *)op1);
533+
}
534+
else {
535+
*((@type@ *) op1) = _@TYPE@_squared_exponentiation_helper(in1, in2start, first_bit);
536+
}
511537
}
512538
return;
513539
}
@@ -518,22 +544,16 @@ NPY_NO_EXPORT void
518544
#if @SIGNED@
519545
if (in2 < 0) {
520546
npy_gil_error(PyExc_ValueError,
521-
"Integers to negative integer powers are not allowed.");
547+
"Integers to negative integer powers are not allowed.");
522548
return;
523549
}
524550
#endif
525-
if (in2 == 0) {
526-
*((@type@ *)op1) = 1;
527-
continue;
528-
}
529-
if (in1 == 1) {
530-
*((@type@ *)op1) = 1;
531-
continue;
532-
}
533551

534-
int first_bit = in2 & 1;
535-
in2 >>= 1;
536-
*((@type@ *) op1) = _@TYPE@_squared_exponentiation_helper(in1, in2, first_bit);
552+
if (_@TYPE@_power_fast_path_helper(in1, in2, (@type@ *)op1) != 0) {
553+
int first_bit = in2 & 1;
554+
in2 >>= 1;
555+
*((@type@ *) op1) = _@TYPE@_squared_exponentiation_helper(in1, in2, first_bit);
556+
}
537557
}
538558
}
539559
/**end repeat**/

numpy/_core/src/umath/loops_umath_fp.dispatch.c.src

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,11 +239,30 @@ NPY_NO_EXPORT void NPY_CPU_DISPATCH_CURFX(@TYPE@_@func@)
239239
if (stride_zero) {
240240
BINARY_DEFS
241241
const @type@ in2 = *(@type@ *)ip2;
242-
if (in2 == 2.0) {
243-
BINARY_LOOP_SLIDING {
244-
const @type@ in1 = *(@type@ *)ip1;
242+
int fastop_found = 1;
243+
BINARY_LOOP_SLIDING {
244+
const @type@ in1 = *(@type@ *)ip1;
245+
if (in2 == -1.0) {
246+
*(@type@ *)op1 = 1.0 / in1;
247+
}
248+
else if (in2 == 0.0) {
249+
*(@type@ *)op1 = 1.0;
250+
}
251+
else if (in2 == 0.5) {
252+
*(@type@ *)op1 = @sqrt@(in1);
253+
}
254+
else if (in2 == 1.0) {
255+
*(@type@ *)op1 = in1;
256+
}
257+
else if (in2 == 2.0) {
245258
*(@type@ *)op1 = in1 * in1;
246259
}
260+
else {
261+
fastop_found = 0;
262+
break;
263+
}
264+
}
265+
if (fastop_found) {
247266
return;
248267
}
249268
}

numpy/_core/tests/test_multiarray.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4125,27 +4125,6 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kw):
41254125
assert_equal(A[0], 30)
41264126
assert_(isinstance(A, OutClass))
41274127

4128-
def test_pow_override_with_errors(self):
4129-
# regression test for gh-9112
4130-
class PowerOnly(np.ndarray):
4131-
def __array_ufunc__(self, ufunc, method, *inputs, **kw):
4132-
if ufunc is not np.power:
4133-
raise NotImplementedError
4134-
return "POWER!"
4135-
# explicit cast to float, to ensure the fast power path is taken.
4136-
a = np.array(5., dtype=np.float64).view(PowerOnly)
4137-
assert_equal(a ** 2.5, "POWER!")
4138-
with assert_raises(NotImplementedError):
4139-
a ** 0.5
4140-
with assert_raises(NotImplementedError):
4141-
a ** 0
4142-
with assert_raises(NotImplementedError):
4143-
a ** 1
4144-
with assert_raises(NotImplementedError):
4145-
a ** -1
4146-
with assert_raises(NotImplementedError):
4147-
a ** 2
4148-
41494128
def test_pow_array_object_dtype(self):
41504129
# test pow on arrays of object dtype
41514130
class SomeClass:

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy