Skip to content

Commit ac4acc6

Browse files
authored
Fix as_tensor in onnx_text_plot_tree (#101)
* Fix as_tensor * fix issues * lint * fix clean * atol * fix issues
1 parent 96eb50e commit ac4acc6

File tree

9 files changed

+78
-86
lines changed

9 files changed

+78
-86
lines changed

CHANGELOGS.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
Change Logs
22
===========
33

4+
0.3.2
5+
+++++
6+
7+
* :pr:`101`: fix as_tensor in onnx_text_plot_tree
8+
49
0.3.1
510
+++++
611

_unittests/ut_light_api/test_backend_export.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
make_opsetid,
2020
make_tensor_value_info,
2121
)
22-
from onnx.reference.op_run import to_array_extended
22+
23+
try:
24+
from onnx.reference.op_run import to_array_extended
25+
except ImportError:
26+
from onnx.numpy_helper import to_array as to_array_extended
2327
from onnx.numpy_helper import from_array, to_array
2428
from onnx.backend.base import Device, DeviceType
2529
from onnx_array_api.reference import ExtendedReferenceEvaluator
@@ -240,7 +244,19 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
240244
raise NotImplementedError("Unable to run the model node by node.")
241245

242246

243-
backend_test = onnx.backend.test.BackendTest(ExportBackend, __name__)
247+
dft_atol = 1e-3 if sys.platform != "linux" else 1e-5
248+
backend_test = onnx.backend.test.BackendTest(
249+
ExportBackend,
250+
__name__,
251+
test_kwargs={
252+
"test_dft": {"atol": dft_atol},
253+
"test_dft_axis": {"atol": dft_atol},
254+
"test_dft_axis_opset19": {"atol": dft_atol},
255+
"test_dft_inverse": {"atol": dft_atol},
256+
"test_dft_inverse_opset19": {"atol": dft_atol},
257+
"test_dft_opset19": {"atol": dft_atol},
258+
},
259+
)
244260

245261
# The following tests are too slow with the reference implementation (Conv).
246262
backend_test.exclude(

_unittests/ut_reference/test_backend_extended_reference_evaluator.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import platform
3+
import sys
34
import unittest
45
from typing import Any
56
import numpy
@@ -78,10 +79,21 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
7879
raise NotImplementedError("Unable to run the model node by node.")
7980

8081

82+
dft_atol = 1e-3 if sys.platform != "linux" else 1e-5
8183
backend_test = onnx.backend.test.BackendTest(
82-
ExtendedReferenceEvaluatorBackend, __name__
84+
ExtendedReferenceEvaluatorBackend,
85+
__name__,
86+
test_kwargs={
87+
"test_dft": {"atol": dft_atol},
88+
"test_dft_axis": {"atol": dft_atol},
89+
"test_dft_axis_opset19": {"atol": dft_atol},
90+
"test_dft_inverse": {"atol": dft_atol},
91+
"test_dft_inverse_opset19": {"atol": dft_atol},
92+
"test_dft_opset19": {"atol": dft_atol},
93+
},
8394
)
8495

96+
8597
if os.getenv("APPVEYOR"):
8698
backend_test.exclude("(test_vgg19|test_zfnet)")
8799
if platform.architecture()[0] == "32bit":

azure-pipelines.yml

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -93,63 +93,6 @@ jobs:
9393
python -m pytest
9494
displayName: 'Runs Unit Tests'
9595
96-
- job: 'TestLinuxArrayApi'
97-
pool:
98-
vmImage: 'ubuntu-latest'
99-
strategy:
100-
matrix:
101-
Python310-Linux:
102-
python.version: '3.10'
103-
maxParallel: 3
104-
105-
steps:
106-
- task: UsePythonVersion@0
107-
inputs:
108-
versionSpec: '$(python.version)'
109-
architecture: 'x64'
110-
- script: sudo apt-get update
111-
displayName: 'AptGet Update'
112-
- script: python -m pip install --upgrade pip setuptools wheel
113-
displayName: 'Install tools'
114-
- script: pip install -r requirements.txt
115-
displayName: 'Install Requirements'
116-
- script: pip install onnxruntime
117-
displayName: 'Install onnxruntime'
118-
- script: python setup.py install
119-
displayName: 'Install onnx_array_api'
120-
- script: |
121-
git clone https://github.com/data-apis/array-api-tests.git
122-
displayName: 'clone array-api-tests'
123-
- script: |
124-
cd array-api-tests
125-
git submodule update --init --recursive
126-
cd ..
127-
displayName: 'get submodules for array-api-tests'
128-
- script: pip install -r array-api-tests/requirements.txt
129-
displayName: 'Install Requirements dev'
130-
- script: |
131-
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
132-
cd array-api-tests
133-
displayName: 'Set API'
134-
- script: |
135-
python -m pip freeze
136-
displayName: 'pip freeze'
137-
- script: |
138-
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
139-
cd array-api-tests
140-
python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-numpy-skips.txt --hypothesis-explain
141-
displayName: "numpy test_creation_functions.py"
142-
# - script: |
143-
# export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_ort
144-
# cd array-api-tests
145-
# python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-ort-skips.txt --hypothesis-explain
146-
# displayName: "ort test_creation_functions.py"
147-
#- script: |
148-
# export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
149-
# cd array-api-tests
150-
# python -m pytest -x array_api_tests
151-
# displayName: "all tests"
152-
15396
- job: 'TestLinux'
15497
pool:
15598
vmImage: 'ubuntu-latest'

onnx_array_api/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
APIs to create ONNX Graphs.
33
"""
44

5-
__version__ = "0.3.1"
5+
__version__ = "0.3.2"
66
__author__ = "Xavier Dupré"

onnx_array_api/plotting/text_plot.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ def __init__(self, i, atts):
6464
self.nodes_missing_value_tracks_true = None
6565
for k, v in atts.items():
6666
if k.startswith("nodes"):
67-
setattr(self, k, v[i])
67+
if k.endswith("_as_tensor"):
68+
setattr(self, k.replace("_as_tensor", ""), v[i])
69+
else:
70+
setattr(self, k, v[i])
6871
self.depth = 0
6972
self.true_false = ""
7073
self.targets = []
@@ -120,10 +123,7 @@ def process_tree(atts, treeid):
120123
]
121124
for k, v in atts.items():
122125
if k.startswith(prefix):
123-
if "classlabels" in k:
124-
short[k] = list(v)
125-
else:
126-
short[k] = [v[i] for i in idx]
126+
short[k] = list(v) if "classlabels" in k else [v[i] for i in idx]
127127

128128
nodes = OrderedDict()
129129
for i in range(len(short["nodes_treeids"])):
@@ -132,9 +132,10 @@ def process_tree(atts, treeid):
132132
for i in range(len(short[f"{prefix}_treeids"])):
133133
idn = short[f"{prefix}_nodeids"][i]
134134
node = nodes[idn]
135-
node.append_target(
136-
tid=short[f"{prefix}_ids"][i], weight=short[f"{prefix}_weights"][i]
137-
)
135+
key = f"{prefix}_weights"
136+
if key not in short:
137+
key = f"{prefix}_weights_as_tensor"
138+
node.append_target(tid=short[f"{prefix}_ids"][i], weight=short[key][i])
138139

139140
def iterate(nodes, node, depth=0, true_false=""):
140141
node.depth = depth

onnx_array_api/profiling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ def add_rows(rows, d):
438438
if verbose and fLOG is not None:
439439
fLOG(
440440
"[pstats] %s=%r"
441-
% ((clean_text(k[0].replace("\\", "/")),) + k[1:], v)
441+
% ((clean_text(k[0].replace("\\", "/")), *k[1:]), v)
442442
)
443443
if len(v) < 5:
444444
continue

onnx_array_api/reference/__init__.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,21 @@
22
import numpy as np
33
from onnx import TensorProto
44
from onnx.numpy_helper import from_array as onnx_from_array
5-
from onnx.reference.ops.op_cast import (
6-
bfloat16,
7-
float8e4m3fn,
8-
float8e4m3fnuz,
9-
float8e5m2,
10-
float8e5m2fnuz,
11-
)
12-
from onnx.reference.op_run import to_array_extended
5+
6+
try:
7+
from onnx.reference.ops.op_cast import (
8+
bfloat16,
9+
float8e4m3fn,
10+
float8e4m3fnuz,
11+
float8e5m2,
12+
float8e5m2fnuz,
13+
)
14+
except ImportError:
15+
bfloat16 = None
16+
try:
17+
from onnx.reference.op_run import to_array_extended
18+
except ImportError:
19+
from onnx.numpy_helper import to_array as to_array_extended
1320
from .evaluator import ExtendedReferenceEvaluator
1421
from .evaluator_yield import (
1522
DistanceExecution,
@@ -28,6 +35,8 @@ def from_array_extended(tensor: np.array, name: Optional[str] = None) -> TensorP
2835
:param name: name
2936
:return: TensorProto
3037
"""
38+
if bfloat16 is None:
39+
return onnx_from_array(tensor, name)
3140
dt = tensor.dtype
3241
if dt == float8e4m3fn and dt.descr[0][0] == "e4m3fn":
3342
to = TensorProto.FLOAT8E4M3FN

onnx_array_api/reference/ops/op_cast_like.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11
from onnx.helper import np_dtype_to_tensor_dtype
22
from onnx.onnx_pb import TensorProto
33
from onnx.reference.op_run import OpRun
4-
from onnx.reference.ops.op_cast import (
5-
bfloat16,
6-
cast_to,
7-
float8e4m3fn,
8-
float8e4m3fnuz,
9-
float8e5m2,
10-
float8e5m2fnuz,
11-
)
4+
from onnx.reference.ops.op_cast import cast_to
5+
6+
try:
7+
from onnx.reference.ops.op_cast import (
8+
bfloat16,
9+
float8e4m3fn,
10+
float8e4m3fnuz,
11+
float8e5m2,
12+
float8e5m2fnuz,
13+
)
14+
except ImportError:
15+
bfloat16 = None
1216

1317

1418
def _cast_like(x, y, saturate):
19+
if bfloat16 is None:
20+
return (cast_to(x, y.dtype, saturate),)
1521
if y.dtype == bfloat16 and y.dtype.descr[0][0] == "bfloat16":
1622
# np.uint16 == np.uint16 is True as well as np.uint16 == bfloat16
1723
to = TensorProto.BFLOAT16

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy