Skip to content

Implements ArrayAPI #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Jun 10, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
new udpates
  • Loading branch information
sdpython committed Jun 8, 2023
commit 550d0dc6fe061d7075293afd7aae18542c2f4f40
2 changes: 2 additions & 0 deletions _unittests/ut_npx/test_sklearn_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def test_sklearn_array_api_linear_discriminant(self):
new_x = EagerNumpyTensor(X)
self.assertStartsWith("EagerNumpyTensor(array([[", repr(new_x))
with config_context(array_api_dispatch=True):
# It fails if scikit-learn <= 1.2.2 because the ArrayAPI
# is not strictly applied.
got = ana.predict(new_x)
self.assertEqualArray(expected, got.numpy())

Expand Down
2 changes: 1 addition & 1 deletion onnx_array_api/array_api/_onnx_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def template_asarray(
else:
raise RuntimeError(f"Unexpected type {type(a)} for the first input.")
if dtype is not None:
vt = v.astype(dtype=dtype)
vt = v.astype(dtype)
else:
vt = v
return vt
8 changes: 6 additions & 2 deletions onnx_array_api/npx/npx_jit_eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def make_key(*values, **kwargs):
for iv, v in enumerate(values):
if isinstance(v, (Var, EagerTensor, JitTensor)):
res.append(v.key)
elif isinstance(v, (int, float)):
elif isinstance(v, (int, float, DType)):
res.append(v)
elif isinstance(v, slice):
res.append(("slice", v.start, v.stop, v.step))
Expand Down Expand Up @@ -344,7 +344,11 @@ def jit_call(self, *values, **kwargs):
self.info("+", "jit_call")
if self.input_to_kwargs_ is None:
# No jitting was ever called.
onx, fct = self.to_jit(*values, **kwargs)
try:
onx, fct = self.to_jit(*values, **kwargs)
except Exception as e:
raise RuntimeError(f"ERROR with self.f={self.f}, "
f"values={values!r}, kwargs={kwargs!r}") from e
if self.input_to_kwargs_ is None:
raise RuntimeError(
f"Attribute 'input_to_kwargs_' should be set for "
Expand Down
17 changes: 7 additions & 10 deletions onnx_array_api/npx/npx_tensors.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Any
from typing import Any, Union

import numpy as np
from onnx.helper import np_dtype_to_tensor_dtype

from .npx_types import DType, OptParType
from .npx_types import DType, ParType
from .npx_array_api import BaseArrayApi, ArrayApiError


Expand Down Expand Up @@ -74,7 +74,7 @@ def _getitem_impl_var(obj, index, method_name=None):
return meth(obj, index)

@staticmethod
def _astype_impl(x, dtype: OptParType[DType] = None, method_name=None):
def _astype_impl(x, dtype: ParType[DType], method_name=None):
# avoids circular imports.
if dtype is None:
raise ValueError("dtype cannot be None.")
Expand Down Expand Up @@ -182,18 +182,15 @@ def _np_dtype_to_tensor_dtype(dtype):
dtype = np.dtype("float64")
return np_dtype_to_tensor_dtype(dtype)

def _generic_method_astype(self, method_name, *args: Any, **kwargs: Any) -> Any:
def _generic_method_astype(self, method_name, dtype: Union[DType, "Var"], **kwargs: Any) -> Any:
# avoids circular imports.
from .npx_jit_eager import eager_onnx
from .npx_var import Var

if len(args) != 1:
raise ValueError(f"astype takes only one argument not {len(args)}.")

dtype = (
args[0]
if isinstance(args[0], (DType, Var))
else self._np_dtype_to_tensor_dtype(args[0])
dtype
if isinstance(dtype, (DType, Var))
else self._np_dtype_to_tensor_dtype(dtype)
)
eag = eager_onnx(EagerTensor._astype_impl, self.__class__, bypass_eager=True)
res = eag(self, dtype, method_name=method_name, already_eager=True, **kwargs)
Expand Down
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy