diff --git a/graphblas/core/matrix.py b/graphblas/core/matrix.py index 8b9b4b678..e922ead87 100644 --- a/graphblas/core/matrix.py +++ b/graphblas/core/matrix.py @@ -1,6 +1,6 @@ import itertools import warnings -from collections.abc import Sequence +from collections.abc import Mapping, Sequence import numpy as np @@ -12,7 +12,14 @@ from .descriptor import lookup as descriptor_lookup from .expr import _ALL_INDICES, AmbiguousAssignOrExtract, IndexerResolver, Updater from .mask import Mask, StructuralMask, ValueMask -from .operator import UNKNOWN_OPCLASS, find_opclass, get_semiring, get_typed_op, op_from_string +from .operator import ( + UNKNOWN_OPCLASS, + _dict_to_func, + find_opclass, + get_semiring, + get_typed_op, + op_from_string, +) from .scalar import ( _COMPLETE, _MATERIALIZE, @@ -2279,7 +2286,11 @@ def apply(self, op, right=None, *, left=None): right = False # most basic form of 0 when unifying dtypes if left is not None: raise TypeError("Do not pass `left` when applying IndexUnaryOp") - + elif opclass == UNKNOWN_OPCLASS and isinstance(op, Mapping): + if left is not None: + raise TypeError("Do not pass `left` when applying a Mapping") + op = _dict_to_func(op, right) + right = None if left is None and right is None: op = get_typed_op(op, self.dtype, kind="unary") self._expect_op( diff --git a/graphblas/core/operator.py b/graphblas/core/operator.py index eca7c9d75..78bf37cc9 100644 --- a/graphblas/core/operator.py +++ b/graphblas/core/operator.py @@ -3597,3 +3597,25 @@ def aggregator_from_string(string): from .. import agg # noqa: E402 isort:skip agg.from_string = aggregator_from_string + + +def _dict_to_func(d, default): + # This probably doesn't work on UDTs, and we could probably be smarter with dtypes + if default is None: + default = False + keys, vals = zip(*d.items()) + keys = np.array(keys) + lookup_dtype(keys.dtype) + vals = np.array(vals) + lookup_dtype(vals.dtype) + p = np.argsort(keys) + keys = keys[p] + vals = vals[p] + + def func(x): + i = np.searchsorted(keys, x) + if i < keys.size and keys[i] == x: + return vals[i] + return default + + return func diff --git a/graphblas/core/vector.py b/graphblas/core/vector.py index dd183d856..0ba8cec11 100644 --- a/graphblas/core/vector.py +++ b/graphblas/core/vector.py @@ -1,5 +1,6 @@ import itertools import warnings +from collections.abc import Mapping import numpy as np @@ -11,7 +12,14 @@ from .descriptor import lookup as descriptor_lookup from .expr import _ALL_INDICES, AmbiguousAssignOrExtract, IndexerResolver, Updater from .mask import Mask, StructuralMask, ValueMask -from .operator import UNKNOWN_OPCLASS, find_opclass, get_semiring, get_typed_op, op_from_string +from .operator import ( + UNKNOWN_OPCLASS, + _dict_to_func, + find_opclass, + get_semiring, + get_typed_op, + op_from_string, +) from .scalar import ( _COMPLETE, _MATERIALIZE, @@ -1315,7 +1323,11 @@ def apply(self, op, right=None, *, left=None): right = False # most basic form of 0 when unifying dtypes if left is not None: raise TypeError("Do not pass `left` when applying IndexUnaryOp") - + elif opclass == UNKNOWN_OPCLASS and isinstance(op, Mapping): + if left is not None: + raise TypeError("Do not pass `left` when applying a Mapping") + op = _dict_to_func(op, right) + right = None if left is None and right is None: op = get_typed_op(op, self.dtype, kind="unary") self._expect_op( diff --git a/graphblas/tests/test_matrix.py b/graphblas/tests/test_matrix.py index 40676f71a..160217235 100644 --- a/graphblas/tests/test_matrix.py +++ b/graphblas/tests/test_matrix.py @@ -1232,6 +1232,23 @@ def test_apply_indexunary(A): A.apply(select.valueeq, left=s3) +def test_apply_dict(): + rows = [0, 0, 0, 0] + cols = [1, 3, 4, 6] + vals = [1, 1, 2, 0] + V = Matrix.from_coo(rows, cols, vals) + # Use right as default + W1 = V.apply({1: 10, 2: 20}, 100).new() + expected = Matrix.from_coo(rows, cols, [10, 10, 20, 100]) + assert W1.isequal(expected) + # Default is 0 if unspecified + W2 = V.apply({0: 10, 2: 20}).new() + expected = Matrix.from_coo(rows, cols, [0, 0, 20, 10]) + assert W2.isequal(expected) + with pytest.raises(TypeError, match="left"): + V.apply({0: 10, 2: 20}, left=999) + + def test_select(A): A3 = Matrix.from_coo([0, 3, 3, 6], [3, 0, 2, 4], [3, 3, 3, 3], nrows=7, ncols=7) w1 = A.select(select.valueeq, 3).new() diff --git a/graphblas/tests/test_vector.py b/graphblas/tests/test_vector.py index 8505313e4..7da626b3b 100644 --- a/graphblas/tests/test_vector.py +++ b/graphblas/tests/test_vector.py @@ -724,6 +724,29 @@ def test_apply_indexunary(v): v.apply(indexunary.valueeq, left=s2) +def test_apply_dict(v): + # Use right as default + w1 = v.apply({1: 10, 2: 20}, 100).new() + expected = Vector.from_coo([1, 3, 4, 6], [10, 10, 20, 100]) + assert w1.isequal(expected) + # Default is 0 if unspecified + w2 = v.apply({0: 10, 2: 20}).new() + expected = Vector.from_coo([1, 3, 4, 6], [0, 0, 20, 10]) + assert w2.isequal(expected) + # Scalar default can up-cast dtype + w3 = v.apply({1: 10, 2: 20}, 0.5).new() + expected = Vector.from_coo([1, 3, 4, 6], [10, 10, 20, 0.5]) + assert w3.isequal(expected) + with pytest.raises(TypeError, match="left"): + v.apply({0: 10, 2: 20}, left=999) + with pytest.raises(ValueError, match="Unknown dtype"): + v.apply({0: 10, 2: object()}) + import numba + + with pytest.raises(numba.TypingError): # TODO: this error and message should be better + v.apply({0: 10, 2: 20}, object()) + + def test_select(v): result = Vector.from_coo([1, 3], [1, 1], size=7) w1 = v.select(select.valueeq, 1).new()
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: