Skip to content

MAINT: simplify power fast path logic #27901

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 28 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
6923108
MAINT: remove fast paths from array power
MaanasArora Dec 4, 2024
8196fbb
MAINT: Add fast paths to power loops
MaanasArora Dec 4, 2024
4090a1c
MAINT: Clean loops for integer power in umath
MaanasArora Dec 4, 2024
91dd9dd
MAINT: Remove blocking regression test for power fast paths
MaanasArora Dec 4, 2024
c0e88c8
MAINT: Add helper function for power fast paths
MaanasArora Dec 4, 2024
c00489d
BUG: Change misspelled bitwise and to logical and
MaanasArora Dec 4, 2024
1d9f355
BUG: Fix missing value on power helper return
MaanasArora Dec 4, 2024
4b2920c
BUG: Fix exponent bitwise logic in power fast paths
MaanasArora Dec 4, 2024
dd6e773
MAINT: Add power fast paths to floating point umath
MaanasArora Dec 4, 2024
c95bff2
MAINT: Add fast power paths to array power when exponent is python ob…
MaanasArora Dec 4, 2024
084416e
MAINT: Fix division by zero runtime warning in test regression
MaanasArora Dec 4, 2024
e23423c
MAINT: Adapt object regression test for linalg to power fast paths
MaanasArora Dec 5, 2024
7cad24e
MAINT: Remove incorrect declarations in power fast paths
MaanasArora Dec 5, 2024
c028996
MAINT: Reduce calls to power fast path helper when scalar is ineligible
MaanasArora Dec 5, 2024
3297309
MAINT: Fix missing sliding loop
MaanasArora Dec 5, 2024
455407f
BUG: Fix syntax error
MaanasArora Dec 5, 2024
df6f54a
MAINT: Fix semantic misuse of -1 for non-error returns
MaanasArora Dec 5, 2024
d10bce5
MAINT: Improve error checking in power fast paths to remove PyErr_Clear
MaanasArora Dec 5, 2024
21c12a6
MAINT: Improve type checking in power fast paths
MaanasArora Dec 5, 2024
c9929ff
MAINT: Efficient handling of ones arrays in scalar fast paths
MaanasArora Dec 5, 2024
ed449e7
MAINT: Simplify outer check for scalar power fast paths
MaanasArora Dec 6, 2024
ba24783
MAINT: Reduce code reuse in float power fast paths and add reciprocal
MaanasArora Dec 6, 2024
b5f8a2b
MAINT: Remove Python scalar checking for fast power paths
MaanasArora Dec 10, 2024
023d55a
MAINT: Add benchmarks for power operators in float binary bench
MaanasArora Dec 11, 2024
efbd6b8
MAINT: Add scalar power fast paths
MaanasArora Dec 11, 2024
31418e9
BUG: Add missing pointer cast
MaanasArora Dec 11, 2024
a5d0ef4
BUG: Allow scalar power fast paths only for non-integers
MaanasArora Dec 11, 2024
2e0f6ca
MAINT: Restore outdated changes in regression test to master
MaanasArora Dec 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions benchmarks/benchmarks/bench_ufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,12 @@ def time_pow_2(self, dtype):
def time_pow_half(self, dtype):
np.power(self.a, 0.5)

def time_pow_2_op(self, dtype):
self.a ** 2

def time_pow_half_op(self, dtype):
self.a ** 0.5

def time_atan2(self, dtype):
np.arctan2(self.a, self.b)

Expand Down
183 changes: 36 additions & 147 deletions numpy/_core/src/multiarray/number.c
Original file line number Diff line number Diff line change
Expand Up @@ -328,165 +328,53 @@ array_inplace_matrix_multiply(PyArrayObject *self, PyObject *other)
return res;
}

/*
* Determine if object is a scalar and if so, convert the object
* to a double and place it in the out_exponent argument
* and return the "scalar kind" as a result. If the object is
* not a scalar (or if there are other error conditions)
* return NPY_NOSCALAR, and out_exponent is undefined.
*/
static NPY_SCALARKIND
is_scalar_with_conversion(PyObject *o2, double* out_exponent)
static int
fast_scalar_power(PyObject *o1, PyObject *o2, int inplace, PyObject **result)
{
PyObject *temp;
const int optimize_fpexps = 1;

if (PyLong_Check(o2)) {
long tmp = PyLong_AsLong(o2);
if (error_converting(tmp)) {
PyErr_Clear();
return NPY_NOSCALAR;
PyObject *fastop = NULL;
if (PyLong_CheckExact(o2)) {
int overflow = 0;
long exp = PyLong_AsLongAndOverflow(o2, &overflow);
if (overflow != 0) {
return -1;
}
*out_exponent = (double)tmp;
return NPY_INTPOS_SCALAR;
}

if (optimize_fpexps && PyFloat_Check(o2)) {
*out_exponent = PyFloat_AsDouble(o2);
return NPY_FLOAT_SCALAR;
}

if (PyArray_Check(o2)) {
if ((PyArray_NDIM((PyArrayObject *)o2) == 0) &&
((PyArray_ISINTEGER((PyArrayObject *)o2) ||
(optimize_fpexps && PyArray_ISFLOAT((PyArrayObject *)o2))))) {
temp = Py_TYPE(o2)->tp_as_number->nb_float(o2);
if (temp == NULL) {
return NPY_NOSCALAR;
}
*out_exponent = PyFloat_AsDouble(o2);
Py_DECREF(temp);
if (PyArray_ISINTEGER((PyArrayObject *)o2)) {
return NPY_INTPOS_SCALAR;
}
else { /* ISFLOAT */
return NPY_FLOAT_SCALAR;
}
if (exp == -1) {
fastop = n_ops.reciprocal;
}
}
else if (PyArray_IsScalar(o2, Integer) ||
(optimize_fpexps && PyArray_IsScalar(o2, Floating))) {
temp = Py_TYPE(o2)->tp_as_number->nb_float(o2);
if (temp == NULL) {
return NPY_NOSCALAR;
}
*out_exponent = PyFloat_AsDouble(o2);
Py_DECREF(temp);

if (PyArray_IsScalar(o2, Integer)) {
return NPY_INTPOS_SCALAR;
else if (exp == 2) {
fastop = n_ops.square;
}
else { /* IsScalar(o2, Floating) */
return NPY_FLOAT_SCALAR;
else {
return 1;
}
}
else if (PyIndex_Check(o2)) {
PyObject* value = PyNumber_Index(o2);
Py_ssize_t val;
if (value == NULL) {
if (PyErr_Occurred()) {
PyErr_Clear();
}
return NPY_NOSCALAR;
else if (PyFloat_CheckExact(o2)) {
double exp = PyFloat_AsDouble(o2);
if (exp == 0.5) {
fastop = n_ops.sqrt;
}
val = PyLong_AsSsize_t(value);
Py_DECREF(value);
if (error_converting(val)) {
PyErr_Clear();
return NPY_NOSCALAR;
else {
return 1;
}
*out_exponent = (double) val;
return NPY_INTPOS_SCALAR;
}
return NPY_NOSCALAR;
}
else {
return 1;
}

/*
* optimize float array or complex array to a scalar power
* returns 0 on success, -1 if no optimization is possible
* the result is in value (can be NULL if an error occurred)
*/
static int
fast_scalar_power(PyObject *o1, PyObject *o2, int inplace,
PyObject **value)
{
double exponent;
NPY_SCALARKIND kind; /* NPY_NOSCALAR is not scalar */

if (PyArray_Check(o1) &&
!PyArray_ISOBJECT((PyArrayObject *)o1) &&
((kind=is_scalar_with_conversion(o2, &exponent))>0)) {
PyArrayObject *a1 = (PyArrayObject *)o1;
PyObject *fastop = NULL;
if (PyArray_ISFLOAT(a1) || PyArray_ISCOMPLEX(a1)) {
if (exponent == 1.0) {
fastop = n_ops.positive;
}
else if (exponent == -1.0) {
fastop = n_ops.reciprocal;
}
else if (exponent == 0.0) {
fastop = n_ops._ones_like;
}
else if (exponent == 0.5) {
fastop = n_ops.sqrt;
}
else if (exponent == 2.0) {
fastop = n_ops.square;
}
else {
return -1;
}
PyArrayObject *a1 = (PyArrayObject *)o1;
if (!(PyArray_ISFLOAT(a1) || PyArray_ISCOMPLEX(a1))) {
return 1;
}

if (inplace || can_elide_temp_unary(a1)) {
*value = PyArray_GenericInplaceUnaryFunction(a1, fastop);
}
else {
*value = PyArray_GenericUnaryFunction(a1, fastop);
}
return 0;
}
/* Because this is called with all arrays, we need to
* change the output if the kind of the scalar is different
* than that of the input and inplace is not on ---
* (thus, the input should be up-cast)
*/
else if (exponent == 2.0) {
fastop = n_ops.square;
if (inplace) {
*value = PyArray_GenericInplaceUnaryFunction(a1, fastop);
}
else {
/* We only special-case the FLOAT_SCALAR and integer types */
if (kind == NPY_FLOAT_SCALAR && PyArray_ISINTEGER(a1)) {
PyArray_Descr *dtype = PyArray_DescrFromType(NPY_DOUBLE);
a1 = (PyArrayObject *)PyArray_CastToType(a1, dtype,
PyArray_ISFORTRAN(a1));
if (a1 != NULL) {
/* cast always creates a new array */
*value = PyArray_GenericInplaceUnaryFunction(a1, fastop);
Py_DECREF(a1);
}
}
else {
*value = PyArray_GenericUnaryFunction(a1, fastop);
}
}
return 0;
}
if (inplace || can_elide_temp_unary(a1)) {
*result = PyArray_GenericInplaceUnaryFunction(a1, fastop);
}
/* no fast operation found */
return -1;
else {
*result = PyArray_GenericUnaryFunction(a1, fastop);
}

return 0;
}

static PyObject *
Expand Down Expand Up @@ -643,7 +531,8 @@ array_inplace_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo

INPLACE_GIVE_UP_IF_NEEDED(
a1, o2, nb_inplace_power, array_inplace_power);
if (fast_scalar_power((PyObject *)a1, o2, 1, &value) != 0) {

if (fast_scalar_power((PyObject *) a1, o2, 1, &value) != 0) {
value = PyArray_GenericInplaceBinaryFunction(a1, o2, n_ops.power);
}
return value;
Expand Down
62 changes: 41 additions & 21 deletions numpy/_core/src/umath/loops.c.src
Original file line number Diff line number Diff line change
Expand Up @@ -486,28 +486,54 @@ _@TYPE@_squared_exponentiation_helper(@type@ base, @type@ exponent_two, int firs
return out;
}

static inline @type@
_@TYPE@_power_fast_path_helper(@type@ in1, @type@ in2, @type@ *op1) {
// Fast path for power calculation
if (in2 == 0 || in1 == 1) {
*op1 = 1;
}
else if (in2 == 1) {
*op1 = in1;
}
else if (in2 == 2) {
*op1 = in1 * in1;
}
else {
return 1;
}
return 0;
}


NPY_NO_EXPORT void
@TYPE@_power(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
{
if (steps[1]==0) {
// stride for second argument is 0
BINARY_DEFS
const @type@ in2 = *(@type@ *)ip2;
#if @SIGNED@
if (in2 < 0) {
npy_gil_error(PyExc_ValueError,
"Integers to negative integer powers are not allowed.");
return;
}
#endif

#if @SIGNED@
if (in2 < 0) {
npy_gil_error(PyExc_ValueError,
"Integers to negative integer powers are not allowed.");
return;
}
#endif

int first_bit = in2 & 1;
@type@ in2start = in2 >> 1;

int fastop_exists = (in2 == 0) || (in2 == 1) || (in2 == 2);

BINARY_LOOP_SLIDING {
@type@ in1 = *(@type@ *)ip1;

*((@type@ *) op1) = _@TYPE@_squared_exponentiation_helper(in1, in2start, first_bit);
if (fastop_exists) {
_@TYPE@_power_fast_path_helper(in1, in2, (@type@ *)op1);
}
else {
*((@type@ *) op1) = _@TYPE@_squared_exponentiation_helper(in1, in2start, first_bit);
}
}
return;
}
Expand All @@ -518,22 +544,16 @@ NPY_NO_EXPORT void
#if @SIGNED@
if (in2 < 0) {
npy_gil_error(PyExc_ValueError,
"Integers to negative integer powers are not allowed.");
"Integers to negative integer powers are not allowed.");
return;
}
#endif
if (in2 == 0) {
*((@type@ *)op1) = 1;
continue;
}
if (in1 == 1) {
*((@type@ *)op1) = 1;
continue;
}

int first_bit = in2 & 1;
in2 >>= 1;
*((@type@ *) op1) = _@TYPE@_squared_exponentiation_helper(in1, in2, first_bit);
if (_@TYPE@_power_fast_path_helper(in1, in2, (@type@ *)op1) != 0) {
int first_bit = in2 & 1;
in2 >>= 1;
*((@type@ *) op1) = _@TYPE@_squared_exponentiation_helper(in1, in2, first_bit);
}
}
}
/**end repeat**/
Expand Down
25 changes: 22 additions & 3 deletions numpy/_core/src/umath/loops_umath_fp.dispatch.c.src
Original file line number Diff line number Diff line change
Expand Up @@ -239,11 +239,30 @@ NPY_NO_EXPORT void NPY_CPU_DISPATCH_CURFX(@TYPE@_@func@)
if (stride_zero) {
BINARY_DEFS
const @type@ in2 = *(@type@ *)ip2;
if (in2 == 2.0) {
BINARY_LOOP_SLIDING {
const @type@ in1 = *(@type@ *)ip1;
int fastop_found = 1;
BINARY_LOOP_SLIDING {
const @type@ in1 = *(@type@ *)ip1;
if (in2 == -1.0) {
*(@type@ *)op1 = 1.0 / in1;
}
else if (in2 == 0.0) {
*(@type@ *)op1 = 1.0;
}
else if (in2 == 0.5) {
*(@type@ *)op1 = @sqrt@(in1);
}
else if (in2 == 1.0) {
*(@type@ *)op1 = in1;
}
else if (in2 == 2.0) {
*(@type@ *)op1 = in1 * in1;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, here we also should maybe just call the normal multiply loop (with duplicated op1 pointer).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do with the SQRT!

}
else {
fastop_found = 0;
break;
}
}
if (fastop_found) {
return;
}
}
Expand Down
21 changes: 0 additions & 21 deletions numpy/_core/tests/test_multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4123,27 +4123,6 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kw):
assert_equal(A[0], 30)
assert_(isinstance(A, OutClass))

def test_pow_override_with_errors(self):
# regression test for gh-9112
class PowerOnly(np.ndarray):
def __array_ufunc__(self, ufunc, method, *inputs, **kw):
if ufunc is not np.power:
raise NotImplementedError
return "POWER!"
# explicit cast to float, to ensure the fast power path is taken.
a = np.array(5., dtype=np.float64).view(PowerOnly)
assert_equal(a ** 2.5, "POWER!")
with assert_raises(NotImplementedError):
a ** 0.5
with assert_raises(NotImplementedError):
a ** 0
with assert_raises(NotImplementedError):
a ** 1
with assert_raises(NotImplementedError):
a ** -1
with assert_raises(NotImplementedError):
a ** 2

def test_pow_array_object_dtype(self):
# test pow on arrays of object dtype
class SomeClass:
Expand Down
Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy