Skip to content

ENH: Add support for complex weights in np.bincount #23641

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 49 additions & 15 deletions numpy/core/src/multiarray/compiled_base.c
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,11 @@ arr_bincount(PyObject *NPY_UNUSED(self), PyObject *const *args,
{
PyObject *list = NULL, *weight = Py_None, *mlength = NULL;
PyArrayObject *lst = NULL, *ans = NULL, *wts = NULL;
npy_intp *numbers, *ians, len, mx, mn, ans_size;
npy_intp *numbers, len, mx, mn, ans_size;
npy_intp minlength = 0;
npy_intp i;
double *weights , *dans;
npy_uintp *ians, *iweights;
double *weights = NULL, *dans;

NPY_PREPARE_ARGPARSER;
if (npy_parse_arguments("bincount", args, len_args, kwnames,
Expand Down Expand Up @@ -183,39 +184,72 @@ arr_bincount(PyObject *NPY_UNUSED(self), PyObject *const *args,
}
}
if (weight == Py_None) {
ans = (PyArrayObject *)PyArray_ZEROS(1, &ans_size, NPY_INTP, 0);
ans = (PyArrayObject *)PyArray_ZEROS(1, &ans_size, NPY_UINTP, 0);
if (ans == NULL) {
goto fail;
}
ians = (npy_intp *)PyArray_DATA(ans);
ians = (npy_uintp *)PyArray_DATA(ans);
NPY_BEGIN_ALLOW_THREADS;
for (i = 0; i < len; i++)
ians[numbers[i]] += 1;
NPY_END_ALLOW_THREADS;
Py_DECREF(lst);
}
else {
wts = (PyArrayObject *)PyArray_ContiguousFromAny(
weight, NPY_DOUBLE, 1, 1);
wts = (PyArrayObject *)PyArray_FromAny(
weight, NULL, 1, 1, NPY_ARRAY_DEFAULT, NULL);
if (wts == NULL) {
goto fail;
}
weights = (double *)PyArray_DATA(wts);
if (PyArray_SIZE(wts) != len) {
PyErr_SetString(PyExc_ValueError,
"The weights and list don't have the same length.");
goto fail;
}
ans = (PyArrayObject *)PyArray_ZEROS(1, &ans_size, NPY_DOUBLE, 0);
if (ans == NULL) {
if (PyArray_ISINTEGER(wts)) {
iweights = (npy_uintp *)PyArray_DATA(wts);
Copy link
Member

@seberg seberg Apr 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is cast is wrong. You should stick with intp everywhere.
EDIT: Ops, that is unclear. The cast is wrong, not all integers are uintp (in fact it looks like you make it double here?). The intp comment is that uintp isn't the correct type unless you are preserving user input, you shouldn't change that.

One reason I haven't looked at this much yet is that I am not sure how far we actually want to go here, since np.add.at is faster now and is often a work-around (it doesn't auto-resize create/resize the result mainly).

Expanding this to more than just complex requires either proper templating (moving to c.src format or C++) or figuring out how to re-use the ufunc.at logic (which may well be easier, since that has fast loops now; but that needs some new logic like chunking the index to resize, which might affect performance, although I doubt much).

@mattip do you have a tought about this?

ans = (PyArrayObject *)PyArray_ZEROS(1, &ans_size, NPY_UINTP, 0);
if (ans == NULL) {
goto fail;
}
ians = (npy_uintp *)PyArray_DATA(ans);
NPY_BEGIN_ALLOW_THREADS;
for (i = 0; i < len; i++) {
ians[numbers[i]] += iweights[i];
}
NPY_END_ALLOW_THREADS;
} else if (PyArray_ISFLOAT(wts)) {
weights = (double *)PyArray_DATA(wts);
ans = (PyArrayObject *)PyArray_ZEROS(1, &ans_size, NPY_DOUBLE, 0);
if (ans == NULL) {
goto fail;
}
dans = (double *)PyArray_DATA(ans);
NPY_BEGIN_ALLOW_THREADS;
for (i = 0; i < len; i++) {
dans[numbers[i]] += weights[i];
}
NPY_END_ALLOW_THREADS;
} else if (PyArray_ISCOMPLEX(wts)) {
weights = (double *)PyArray_DATA(wts);
ans = (PyArrayObject *)PyArray_ZEROS(1, &ans_size, NPY_CDOUBLE, 0);
if (ans == NULL) {
goto fail;
}
dans = (double *)PyArray_DATA(ans);
NPY_BEGIN_ALLOW_THREADS;
for (i = 0; i < len; i++) {
/* Add real parts */
dans[2 * numbers[i]] += weights[2 * i];
/* Add complex parts */
dans[2 * numbers[i] + 1] += weights[2 * i + 1];
}
NPY_END_ALLOW_THREADS;
} else {
PyErr_SetString(PyExc_TypeError,
"The weights array must only contain floats or complex numbers.");
goto fail;
}
dans = (double *)PyArray_DATA(ans);
NPY_BEGIN_ALLOW_THREADS;
for (i = 0; i < len; i++) {
dans[numbers[i]] += weights[i];
}
NPY_END_ALLOW_THREADS;
Py_DECREF(lst);
Py_DECREF(wts);
}
Expand Down
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy