Skip to content

[WIP] CSR/ CSC Elemwise #465

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
Resize array
  • Loading branch information
ivirshup committed Apr 24, 2021
commit c9510af87a4467eed3005d6fee257808b1cd0674
74 changes: 48 additions & 26 deletions sparse/_compressed/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,27 @@ def op_union_indices(
# TODO: numpy is weird with bools here
out_dtype = np.array(op(a.data[0], b.data[0])).dtype
default_value = out_dtype.type(default_value)
return type(a)(
op_union_indices_csr_csr(
out_indptr = np.zeros_like(a.indptr)
out_indices = np.zeros(len(a.indices) + len(b.indices), dtype=np.promote_types(a.indices.dtype, b.indices.dtype))
out_data = np.zeros(len(out_indices), dtype=out_dtype)

nnz = op_union_indices_csr_csr(
op,
a.indptr,
a.indices,
a.data,
b.indptr,
b.indices,
b.data,
out_indptr,
out_indices,
out_data,
out_dtype=out_dtype,
default_value=default_value,
),
a.shape,
)
)
out_data.resize(nnz)
out_indices.resize(nnz)
return type(a)((out_data, out_indices, out_indptr), shape=a.shape)


@njit
Expand All @@ -65,12 +72,15 @@ def op_union_indices_csr_csr(
b_indptr: np.ndarray,
b_indices: np.ndarray,
b_data: np.ndarray,
out_indptr: np.ndarray,
out_indices: np.ndarray,
out_data: np.ndarray,
out_dtype,
default_value,
):
out_indptr = np.zeros_like(a_indptr)
out_indices = np.zeros(len(a_indices) + len(b_indices), dtype=a_indices.dtype)
out_data = np.zeros(len(out_indices), dtype=out_dtype)
# out_indptr = np.zeros_like(a_indptr)
# out_indices = np.zeros(len(a_indices) + len(b_indices), dtype=a_indices.dtype)
# out_data = np.zeros(len(out_indices), dtype=out_dtype)

out_idx = 0

Expand All @@ -85,38 +95,50 @@ def op_union_indices_csr_csr(
a_j = a_indices[a_idx]
b_j = b_indices[b_idx]
if a_j < b_j:
out_indices[out_idx] = a_j
out_data[out_idx] = op(a_data[a_idx], default_value)
val = op(a_data[a_idx], default_value)
if val != default_value:
out_indices[out_idx] = a_j
out_data[out_idx] = val
out_idx += 1
a_idx += 1
elif b_j < a_j:
out_indices[out_idx] = b_j
out_data[out_idx] = op(default_value, b_data[b_idx])
val = op(default_value, b_data[b_idx])
if val != default_value:
out_indices[out_idx] = b_j
out_data[out_idx] = val
out_idx += 1
b_idx += 1
else:
out_indices[out_idx] = a_j
out_data[out_idx] = op(a_data[a_idx], b_data[b_idx])
val = op(a_data[a_idx], b_data[b_idx])
if val != default_value:
out_indices[out_idx] = a_j
out_data[out_idx] = val
out_idx += 1
a_idx += 1
b_idx += 1
out_idx += 1

# Catch up the other set
while a_idx < a_end:
a_j = a_indices[a_idx]
out_indices[out_idx] = a_j
out_data[out_idx] = op(a_data[a_idx], default_value)
val = op(a_data[a_idx], default_value)
if val != default_value:
out_indices[out_idx] = a_indices[a_idx]
out_data[out_idx] = val
out_idx += 1
a_idx += 1
out_idx += 1

while b_idx < b_end:
b_j = b_indices[b_idx]
out_indices[out_idx] = b_j
out_data[out_idx] = op(default_value, b_data[b_idx])
val = op(default_value, b_data[b_idx])
if val != default_value:
out_indices[out_idx] = b_indices[b_idx]
out_data[out_idx] = val
out_idx += 1
b_idx += 1
out_idx += 1

out_indptr[i + 1] = out_idx

out_indices = out_indices[: out_idx]
out_data = out_data[: out_idx]
# This may need to change to be "resize" to allow memory reallocation
# resize is currently not implemented in numba
# out_indices = out_indices[: out_idx]
# out_data = out_data[: out_idx]

return out_data, out_indices, out_indptr
return out_idx
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy