@@ -195,8 +195,8 @@ arr_bincount(PyObject *NPY_UNUSED(self), PyObject *const *args,
195
195
Py_DECREF (lst );
196
196
}
197
197
else {
198
- wts = (PyArrayObject * )PyArray_ContiguousFromAny (
199
- weight , NPY_DOUBLE , 1 , 1 );
198
+ wts = (PyArrayObject * )PyArray_FromAny (
199
+ weight , NULL , 1 , 1 , NPY_ARRAY_DEFAULT , NULL );
200
200
if (wts == NULL ) {
201
201
goto fail ;
202
202
}
@@ -206,16 +206,36 @@ arr_bincount(PyObject *NPY_UNUSED(self), PyObject *const *args,
206
206
"The weights and list don't have the same length." );
207
207
goto fail ;
208
208
}
209
- ans = (PyArrayObject * )PyArray_ZEROS (1 , & ans_size , NPY_DOUBLE , 0 );
210
- if (ans == NULL ) {
209
+ if (PyArray_ISFLOAT (wts )) {
210
+ ans = (PyArrayObject * )PyArray_ZEROS (1 , & ans_size , NPY_DOUBLE , 0 );
211
+ if (ans == NULL ) {
212
+ goto fail ;
213
+ }
214
+ dans = (double * )PyArray_DATA (ans );
215
+ NPY_BEGIN_ALLOW_THREADS ;
216
+ for (i = 0 ; i < len ; i ++ ) {
217
+ dans [numbers [i ]] += weights [i ];
218
+ }
219
+ NPY_END_ALLOW_THREADS ;
220
+ } else if (PyArray_ISCOMPLEX (wts )) {
221
+ ans = (PyArrayObject * )PyArray_ZEROS (1 , & ans_size , NPY_CDOUBLE , 0 );
222
+ if (ans == NULL ) {
223
+ goto fail ;
224
+ }
225
+ dans = (double * )PyArray_DATA (ans );
226
+ NPY_BEGIN_ALLOW_THREADS ;
227
+ for (i = 0 ; i < len ; i ++ ) {
228
+ /* Add real parts */
229
+ dans [2 * numbers [i ]] += weights [2 * i ];
230
+ /* Add complex parts */
231
+ dans [2 * numbers [i ] + 1 ] += weights [2 * i + 1 ];
232
+ }
233
+ NPY_END_ALLOW_THREADS ;
234
+ } else {
235
+ PyErr_SetString (PyExc_TypeError ,
236
+ "The weights array must only contain floats or complex numbers." );
211
237
goto fail ;
212
238
}
213
- dans = (double * )PyArray_DATA (ans );
214
- NPY_BEGIN_ALLOW_THREADS ;
215
- for (i = 0 ; i < len ; i ++ ) {
216
- dans [numbers [i ]] += weights [i ];
217
- }
218
- NPY_END_ALLOW_THREADS ;
219
239
Py_DECREF (lst );
220
240
Py_DECREF (wts );
221
241
}
0 commit comments