Skip to content

Commit eee76cc

Browse files
authored
Lint (#89)
* example * lint * exc * array " * fix * fix missing dependency * yml * disable some tests
1 parent 6076c1c commit eee76cc

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+295
-189
lines changed

_doc/examples/plot_benchmark_rf.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,7 @@ def skl2onnx_convert_lightgbm(scope, operator, container):
4040
)
4141

4242
options = scope.get_options(operator.raw_operator)
43-
if "split" in options:
44-
operator.split = options["split"]
45-
else:
46-
operator.split = None
43+
operator.split = options.get("split", None)
4744
convert_lightgbm(scope, operator, container)
4845

4946

@@ -103,7 +100,7 @@ def measure_inference(fct, X, repeat, max_time=5, quantile=1):
103100
:return: number of runs, sum of the time, average, median
104101
"""
105102
times = []
106-
for n in range(repeat):
103+
for _n in range(repeat):
107104
perf = time.perf_counter()
108105
fct(X)
109106
delta = time.perf_counter() - perf
@@ -241,7 +238,10 @@ def measure_inference(fct, X, repeat, max_time=5, quantile=1):
241238
# onnxruntime
242239
bar.set_description(f"J={n_j} E={n_estimators} D={max_depth} predictO")
243240
r, t, mean, med = measure_inference(
244-
lambda x: sess.run(None, {"X": x}), X, repeat=repeat, max_time=max_time
241+
lambda x, sess=sess: sess.run(None, {"X": x}),
242+
X,
243+
repeat=repeat,
244+
max_time=max_time,
245245
)
246246
o2 = obs.copy()
247247
o2.update(dict(avg=mean, med=med, n_runs=r, ttime=t, name="ort_"))

_doc/examples/plot_onnxruntime.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,14 @@ def loop(n=1000):
8787
x = np.random.randn(n, 2).astype(np.float32)
8888
y = np.random.randn(n, 2).astype(np.float32)
8989

90-
obs = measure_time(lambda: myloss(x, y))
90+
obs = measure_time(lambda x=x, y=y: myloss(x, y))
9191
obs["name"] = "numpy"
9292
obs["n"] = n
9393
data.append(obs)
9494

9595
xort = OrtTensor.from_array(x)
9696
yort = OrtTensor.from_array(y)
97-
obs = measure_time(lambda: ort_myloss(xort, yort))
97+
obs = measure_time(lambda xort=xort, yort=yort: ort_myloss(xort, yort))
9898
obs["name"] = "ort"
9999
obs["n"] = n
100100
data.append(obs)

_unittests/ut_array_api/test_hypothesis_array_api.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import warnings
33
from os import getenv
44
from functools import reduce
5+
import packaging.version as pv
56
import numpy as np
67
from operator import mul
78
from hypothesis import given
@@ -44,9 +45,12 @@ class TestHypothesisArraysApis(ExtTestCase):
4445

4546
@classmethod
4647
def setUpClass(cls):
47-
with warnings.catch_warnings():
48-
warnings.simplefilter("ignore")
49-
from numpy import array_api as xp
48+
try:
49+
import array_api_strict as xp
50+
except ImportError:
51+
with warnings.catch_warnings():
52+
warnings.simplefilter("ignore")
53+
from numpy import array_api as xp
5054

5155
api_version = getenv(
5256
"ARRAY_API_TESTS_VERSION",
@@ -63,6 +67,9 @@ def test_strategies(self):
6367
self.assertNotEmpty(self.xps)
6468
self.assertNotEmpty(self.onxps)
6569

70+
@unittest.skipIf(
71+
pv.Version(np.__version__) >= pv.Version("2.0"), reason="abandonned"
72+
)
6673
def test_scalar_strategies(self):
6774
dtypes = dict(
6875
integer_dtypes=self.xps.integer_dtypes(),
@@ -139,6 +146,9 @@ def fctonx(x, kw):
139146
fctonx()
140147
self.assertEqual(len(args_onxp), len(args_np))
141148

149+
@unittest.skipIf(
150+
pv.Version(np.__version__) >= pv.Version("2.0"), reason="abandonned"
151+
)
142152
def test_square_sizes_strategies(self):
143153
dtypes = dict(
144154
integer_dtypes=self.xps.integer_dtypes(),

_unittests/ut_light_api/test_backend_export.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import packaging.version as pv
66
import numpy
77
from numpy.testing import assert_allclose
8+
from onnx.defs import onnx_opset_version
89
import onnx.backend.base
910
import onnx.backend.test
1011
import onnx.shape_inference
@@ -31,7 +32,6 @@
3132

3233
class ReferenceImplementationError(RuntimeError):
3334
"Fails, export cannot be compared."
34-
pass
3535

3636

3737
class ExportWrapper:
@@ -64,7 +64,8 @@ def run(
6464
expected = self.expected_sess.run(names, feeds)
6565
except (RuntimeError, AssertionError, TypeError, KeyError) as e:
6666
raise ReferenceImplementationError(
67-
f"ReferenceImplementation fails with {onnx_simple_text_plot(self.model)}"
67+
f"ReferenceImplementation fails with "
68+
f"{onnx_simple_text_plot(self.model)}"
6869
f"\n--RAW--\n{self.model}"
6970
) from e
7071

@@ -85,7 +86,7 @@ def run(
8586
new_code = "\n".join(
8687
[f"{i+1:04} {line}" for i, line in enumerate(code.split("\n"))]
8788
)
88-
raise AssertionError(f"ERROR {e}\n{new_code}")
89+
raise AssertionError(f"ERROR {e}\n{new_code}") # noqa: B904
8990

9091
locs = {
9192
"np": numpy,
@@ -154,7 +155,8 @@ def run(
154155
):
155156
if a.tolist() != b.tolist():
156157
raise AssertionError(
157-
f"Text discrepancies for api {api!r} with a.dtype={a.dtype} "
158+
f"Text discrepancies for api {api!r} "
159+
f"with a.dtype={a.dtype} "
158160
f"and b.dtype={b.dtype}"
159161
f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
160162
f"\n--EXP[{api}]--\n{onnx_simple_text_plot(export_model)}"
@@ -275,6 +277,22 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
275277
")"
276278
)
277279

280+
if onnx_opset_version() < 22:
281+
backend_test.exclude(
282+
"("
283+
"test_dft_inverse_cpu"
284+
"|test_dft_inverse_opset19_cpu"
285+
"|test_lppool_1d_default_cpu"
286+
"|test_lppool_2d_default_cpu"
287+
"|test_lppool_2d_dilations_cpu"
288+
"|test_lppool_2d_pads_cpu"
289+
"|test_lppool_2d_same_lower_cpu"
290+
"|test_lppool_2d_same_upper_cpu"
291+
"|test_lppool_2d_strides_cpu"
292+
"|test_lppool_3d_default_cpu"
293+
")"
294+
)
295+
278296
if pv.Version(onnx_version) < pv.Version("1.16.0"):
279297
backend_test.exclude("(test_strnorm|test_range_)")
280298

_unittests/ut_light_api/test_light_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def g(self):
484484
def ah(self):
485485
return True
486486

487-
setattr(A, "h", ah)
487+
setattr(A, "h", ah) # noqa: B010
488488

489489
self.assertTrue(A().h())
490490
self.assertIn("(self)", str(inspect.signature(A.h)))

_unittests/ut_plotting/test_dot_plot.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# -*- coding: utf-8 -*-
21
import os
32
import unittest
43

_unittests/ut_plotting/test_text_plot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# -*- coding: utf-8 -*-
21
import os
32
import textwrap
43
import unittest
@@ -95,6 +94,7 @@ def test_onnx_text_plot_tree_cls_2(self):
9594
+f 0:1 1:0 2:0
9695
"""
9796
).strip(" \n\r")
97+
res = res.replace("np.float32(", "").replace(")", "")
9898
self.assertEqual(expected, res.strip(" \n\r"))
9999

100100
@ignore_warnings((UserWarning, FutureWarning))

_unittests/ut_reference/test_backend_extended_reference_evaluator.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,25 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
217217
# The following tests fail due to a type mismatch.
218218
backend_test.exclude("(test_eyelike_without_dtype)")
219219

220+
if onnx_opset_version() < 22:
221+
backend_test.exclude(
222+
"("
223+
"test_adagrad_cpu"
224+
"|test_adagrad_multiple_cpu"
225+
"|test_dft_inverse_cpu"
226+
"|test_dft_inverse_opset19_cpu"
227+
"|test_lppool_1d_default_cpu"
228+
"|test_lppool_2d_default_cpu"
229+
"|test_lppool_2d_dilations_cpu"
230+
"|test_lppool_2d_pads_cpu"
231+
"|test_lppool_2d_same_lower_cpu"
232+
"|test_lppool_2d_same_upper_cpu"
233+
"|test_lppool_2d_strides_cpu"
234+
"|test_lppool_3d_default_cpu"
235+
")"
236+
)
237+
238+
220239
# The following tests fail due to discrepancies (small but still higher than 1e-7).
221240
backend_test.exclude("test_adam_multiple") # 1e-2
222241

_unittests/ut_translate_api/test_translate.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,14 @@ def test_export_if(self):
160160
self.assertEqualArray(np.array([1], dtype=np.int64), got[0])
161161

162162
code = translate(onx)
163-
selse = "g().cst(np.array([0], dtype=np.int64)).rename('Z').bring('Z').vout(elem_type=TensorProto.FLOAT)"
164-
sthen = "g().cst(np.array([1], dtype=np.int64)).rename('Z').bring('Z').vout(elem_type=TensorProto.FLOAT)"
163+
selse = (
164+
"g().cst(np.array([0], dtype=np.int64)).rename('Z')."
165+
"bring('Z').vout(elem_type=TensorProto.FLOAT)"
166+
)
167+
sthen = (
168+
"g().cst(np.array([1], dtype=np.int64)).rename('Z')."
169+
"bring('Z').vout(elem_type=TensorProto.FLOAT)"
170+
)
165171
expected = dedent(
166172
f"""
167173
(

_unittests/ut_translate_api/test_translate_classic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def test_fft(self):
252252
new_code = "\n".join(
253253
[f"{i+1:04} {line}" for i, line in enumerate(code.split("\n"))]
254254
)
255-
raise AssertionError(f"ERROR {e}\n{new_code}")
255+
raise AssertionError(f"ERROR {e}\n{new_code}") # noqa: B904
256256

257257
def test_aionnxml(self):
258258
onx = (

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy