Skip to content

Commit 32ad385

Browse files
authored
Fixes Array API with onnxruntime (#3)
* Check Array API with onnxruntime * better error message * improvment * disable one test for older version of sklearn * add one more pipeline * fix pipeline * fix array api * remove unnecessary code * disable one test one the current scikit-learn version
1 parent 062b6c1 commit 32ad385

18 files changed

+265
-68
lines changed

CHANGELOGS.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Change Logs
2+
===========
3+
4+
0.2.0
5+
+++++
6+
7+
* :pr:`3`: fixes Array API with onnxruntime

_doc/conf.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@
5757
}
5858

5959
epkg_dictionary = {
60+
"Array API": "https://data-apis.org/array-api/",
61+
"ArrayAPI": (
62+
"https://data-apis.org/array-api/",
63+
("2022.12/API_specification/generated/array_api.{0}.html", 1),
64+
),
6065
"DOT": "https://graphviz.org/doc/info/lang.html",
6166
"JIT": "https://en.wikipedia.org/wiki/Just-in-time_compilation",
6267
"onnx": "https://onnx.ai/onnx/",
@@ -65,7 +70,7 @@
6570
"numpy": "https://numpy.org/",
6671
"numba": "https://numba.pydata.org/",
6772
"onnx-array-api": (
68-
"http://www.xavierdupre.fr/app/" "onnx-array-api/helpsphinx/index.html"
73+
"http://www.xavierdupre.fr/app/onnx-array-api/helpsphinx/index.html"
6974
),
7075
"pyinstrument": "https://github.com/joerick/pyinstrument",
7176
"python": "https://www.python.org/",

_doc/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ well as to execute it.
3434
tutorial/index
3535
api/index
3636
auto_examples/index
37+
../CHANGELOGS
3738

3839
Sources available on
3940
`github/onnx-array-api <https://github.com/sdpython/onnx-array-api>`_,

_unittests/ut_npx/test_sklearn_array_api.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import unittest
22
import numpy as np
3+
from packaging.version import Version
34
from onnx.defs import onnx_opset_version
4-
from sklearn import config_context
5+
from sklearn import config_context, __version__ as sklearn_version
56
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
67
from onnx_array_api.ext_test_case import ExtTestCase
78
from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor
@@ -10,23 +11,15 @@
1011
DEFAULT_OPSET = onnx_opset_version()
1112

1213

13-
def take(self, X, indices, *, axis):
14-
# Overwritting method take as it is using iterators.
15-
# When array_api supports `take` we can use this directly
16-
# https://github.com/data-apis/array-api/issues/177
17-
X_np = self._namespace.take(X, indices, axis=axis)
18-
return self._namespace.asarray(X_np)
19-
20-
2114
class TestSklearnArrayAPI(ExtTestCase):
15+
@unittest.skipIf(
16+
Version(sklearn_version) <= Version("1.2.2"),
17+
reason="reshape ArrayAPI not followed",
18+
)
2219
def test_sklearn_array_api_linear_discriminant(self):
23-
from sklearn.utils._array_api import _ArrayAPIWrapper
24-
25-
_ArrayAPIWrapper.take = take
2620
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
2721
y = np.array([1, 1, 1, 2, 2, 2])
2822
ana = LinearDiscriminantAnalysis()
29-
ana = LinearDiscriminantAnalysis()
3023
ana.fit(X, y)
3124
expected = ana.predict(X)
3225

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import unittest
2+
import numpy as np
3+
from packaging.version import Version
4+
from onnx.defs import onnx_opset_version
5+
from sklearn import config_context, __version__ as sklearn_version
6+
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
7+
from onnx_array_api.ext_test_case import ExtTestCase
8+
from onnx_array_api.ort.ort_tensors import EagerOrtTensor, OrtTensor
9+
10+
11+
DEFAULT_OPSET = onnx_opset_version()
12+
13+
14+
class TestSklearnArrayAPIOrt(ExtTestCase):
15+
@unittest.skipIf(
16+
Version(sklearn_version) <= Version("1.2.2"),
17+
reason="reshape ArrayAPI not followed",
18+
)
19+
def test_sklearn_array_api_linear_discriminant(self):
20+
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
21+
y = np.array([1, 1, 1, 2, 2, 2])
22+
ana = LinearDiscriminantAnalysis()
23+
ana.fit(X, y)
24+
expected = ana.predict(X)
25+
26+
new_x = EagerOrtTensor(OrtTensor.from_array(X))
27+
self.assertEqual(new_x.device_name, "Cpu")
28+
self.assertStartsWith(
29+
"EagerOrtTensor(OrtTensor.from_array(array([[", repr(new_x)
30+
)
31+
with config_context(array_api_dispatch=True):
32+
got = ana.predict(new_x)
33+
self.assertEqualArray(expected, got.numpy())
34+
35+
36+
if __name__ == "__main__":
37+
# import logging
38+
# logging.basicConfig(level=logging.DEBUG)
39+
unittest.main(verbosity=2)

azure-pipelines.yml

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,53 @@ jobs:
4343
artifactName: 'wheel-linux-wheel-$(python.version)'
4444
targetPath: 'dist'
4545

46+
- job: 'TestLinuxNightly'
47+
pool:
48+
vmImage: 'ubuntu-latest'
49+
strategy:
50+
matrix:
51+
Python310-Linux:
52+
python.version: '3.11'
53+
maxParallel: 3
54+
55+
steps:
56+
- task: UsePythonVersion@0
57+
inputs:
58+
versionSpec: '$(python.version)'
59+
architecture: 'x64'
60+
- script: sudo apt-get update
61+
displayName: 'AptGet Update'
62+
- script: sudo apt-get install -y pandoc
63+
displayName: 'Install Pandoc'
64+
- script: sudo apt-get install -y inkscape
65+
displayName: 'Install Inkscape'
66+
- script: sudo apt-get install -y graphviz
67+
displayName: 'Install Graphviz'
68+
- script: python -m pip install --upgrade pip setuptools wheel
69+
displayName: 'Install tools'
70+
- script: pip install -r requirements.txt
71+
displayName: 'Install Requirements'
72+
- script: pip install -r requirements-dev.txt
73+
displayName: 'Install Requirements dev'
74+
- script: pip uninstall -y scikit-learn
75+
displayName: 'Uninstall scikit-learn'
76+
- script: pip install --pre --extra-index https://pypi.anaconda.org/scipy-wheels-nightly/simple scikit-learn
77+
displayName: 'Install scikit-learn nightly'
78+
- script: pip install onnxmltools --no-deps
79+
displayName: 'Install onnxmltools'
80+
- script: |
81+
ruff .
82+
displayName: 'Ruff'
83+
- script: |
84+
rstcheck -r ./_doc ./onnx_array_api
85+
displayName: 'rstcheck'
86+
- script: |
87+
black --diff .
88+
displayName: 'Black'
89+
- script: |
90+
python -m pytest -v
91+
displayName: 'Runs Unit Tests'
92+
4693
- job: 'TestLinux'
4794
pool:
4895
vmImage: 'ubuntu-latest'

onnx_array_api/npx/npx_array_api.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,34 @@
1-
from typing import Any
1+
from typing import Any, Optional
22

33
import numpy as np
44

55
from .npx_types import OptParType, ParType, TupleType
66

77

8+
class ArrayApiError(RuntimeError):
9+
"""
10+
Raised when a function is not supported by the :epkg:`Array API`.
11+
"""
12+
13+
pass
14+
15+
816
class ArrayApi:
917
"""
1018
List of supported method by a tensor.
1119
"""
1220

13-
def __array_namespace__(self):
21+
def __array_namespace__(self, api_version: Optional[str] = None):
1422
"""
1523
Returns the module holding all the available functions.
1624
"""
17-
from onnx_array_api.npx import npx_functions
25+
if api_version is None or api_version == "2022.12":
26+
from onnx_array_api.npx import npx_functions
1827

19-
return npx_functions
28+
return npx_functions
29+
raise ValueError(
30+
f"Unable to return an implementation for api_version={api_version!r}."
31+
)
2032

2133
def generic_method(self, method_name, *args: Any, **kwargs: Any) -> Any:
2234
raise NotImplementedError(

onnx_array_api/npx/npx_core_api.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,3 +252,10 @@ def npxapi_inline(fn):
252252
to call.
253253
"""
254254
return _xapi(fn, inline=True)
255+
256+
257+
def npxapi_no_inline(fn):
258+
"""
259+
Functions decorated with this decorator are not converted into ONNX.
260+
"""
261+
return fn

onnx_array_api/npx/npx_functions.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
from typing import Any, Optional, Tuple, Union
22

3+
import array_api_compat.numpy as np_array_api
34
import numpy as np
45
from onnx import FunctionProto, ModelProto, NodeProto, TensorProto
56
from onnx.helper import np_dtype_to_tensor_dtype
67
from onnx.numpy_helper import from_array
78

89
from .npx_constants import FUNCTION_DOMAIN
9-
from .npx_core_api import cst, make_tuple, npxapi_inline, var
10+
from .npx_core_api import cst, make_tuple, npxapi_inline, npxapi_no_inline, var
1011
from .npx_tensors import ArrayApi
1112
from .npx_types import (
13+
DType,
1214
ElemType,
1315
OptParType,
1416
ParType,
@@ -397,6 +399,17 @@ def identity(n: ParType[int], dtype=None) -> TensorType[ElemType.numerics, "T"]:
397399
return v
398400

399401

402+
@npxapi_no_inline
403+
def isdtype(
404+
dtype: DType, kind: Union[DType, str, Tuple[Union[DType, str], ...]]
405+
) -> bool:
406+
"""
407+
See :epkg:`ArrayAPI:isdtype`.
408+
This function is not converted into an onnx graph.
409+
"""
410+
return np_array_api.isdtype(dtype, kind)
411+
412+
400413
@npxapi_inline
401414
def isnan(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.bool_, "T"]:
402415
"See :func:`numpy.isnan`."
@@ -460,9 +473,23 @@ def relu(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics,
460473

461474
@npxapi_inline
462475
def reshape(
463-
x: TensorType[ElemType.numerics, "T"], shape: TensorType[ElemType.int64, "I"]
476+
x: TensorType[ElemType.numerics, "T"],
477+
shape: TensorType[ElemType.int64, "I", (None,)],
464478
) -> TensorType[ElemType.numerics, "T"]:
465-
"See :func:`numpy.reshape`."
479+
"""
480+
See :func:`numpy.reshape`.
481+
482+
.. warning::
483+
484+
Numpy definition is tricky because onnxruntime does not handle well
485+
dimensions with an undefined number of dimensions.
486+
However the array API defines a more stricly signature for
487+
`reshape <https://data-apis.org/array-api/2022.12/
488+
API_specification/generated/array_api.reshape.html>`_.
489+
:epkg:`scikit-learn` updated its code to follow the Array API in
490+
`PR 26030 ENH Forces shape to be tuple when using Array API's reshape
491+
<https://github.com/scikit-learn/scikit-learn/pull/26030>`_.
492+
"""
466493
if isinstance(shape, int):
467494
shape = cst(np.array([shape], dtype=np.int64))
468495
shape_reshaped = var(shape, cst(np.array([-1], dtype=np.int64)), op="Reshape")

onnx_array_api/npx/npx_graph_builder.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,23 @@ def to_onnx(
798798
node_inputs.append(input_name)
799799
continue
800800

801+
if isinstance(i, tuple) and all(map(lambda x: isinstance(x, int), i)):
802+
ai = np.array(list(i), dtype=np.int64)
803+
c = Cst(ai)
804+
input_name = self._unique(var._prefix)
805+
self._id_vars[id(i), index] = input_name
806+
self._id_vars[id(c), index] = input_name
807+
self.make_node(
808+
"Constant",
809+
[],
810+
[input_name],
811+
value=from_array(ai),
812+
opset=self.target_opsets[""],
813+
)
814+
self.onnx_names_[input_name] = c
815+
node_inputs.append(input_name)
816+
continue
817+
801818
raise NotImplementedError(
802819
f"Unexpected type {type(i)} for node={domop}."
803820
)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy