Skip to content

Commit a906010

Browse files
authored
Documentation (#78)
* update requirements * Add ConstantOfShape to light API * add slice * changelogs * k
1 parent 2dd0686 commit a906010

File tree

4 files changed

+38
-2
lines changed

4 files changed

+38
-2
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.2.0
55
+++++
66

7+
* :pr:`77`: supports ConcatOfShape and Slice with the light API
78
* :pr:`76`: add a mode to compare models without execution
89
* :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
910
* :pr:`71`: adds tools to compare two onnx graphs

_unittests/ut_light_api/test_light_api.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,35 @@ def test_constant_of_shape(self):
538538
got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
539539
self.assertEqualArray(np.zeros((2, 3), dtype=np.float32), got)
540540

541+
def test_constant_of_shape_value(self):
542+
onx = (
543+
start()
544+
.vin("X", TensorProto.INT64, shape=[None, None])
545+
.ConstantOfShape(value=np.array([1], dtype=np.float32))
546+
.vout(shape=[])
547+
.to_onnx()
548+
)
549+
ref = ReferenceEvaluator(onx)
550+
got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
551+
self.assertEqualArray(np.ones((2, 3), dtype=np.float32), got)
552+
553+
def test_slice(self):
554+
onx = (
555+
start(opset=18, ir_version=9)
556+
.cst(np.array([1], dtype=np.int64), name="one")
557+
.cst(np.array([2], dtype=np.int64), name="two")
558+
.vin("X", TensorProto.INT64, shape=[None, None])
559+
.ConstantOfShape(value=np.array([1], dtype=np.float32))
560+
.rename("CX")
561+
.bring("CX", "one", "two", "one")
562+
.Slice()
563+
.vout(shape=[])
564+
.to_onnx()
565+
)
566+
ref = ReferenceEvaluator(onx)
567+
got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
568+
self.assertEqualArray(np.ones((2, 1), dtype=np.float32), got)
569+
541570

542571
if __name__ == "__main__":
543-
TestLightApi().test_add()
544572
unittest.main(verbosity=2)

onnx_array_api/light_api/_op_var.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,13 @@ def Selu(
314314
def Shrink(self, bias: float = 0.0, lambd: float = 0.5) -> "Var":
315315
return self.make_node("Shrink", self, bias=bias, lambd=lambd)
316316

317+
def Slice(
318+
self, starts: "Var", ends: "Var", axes: "Var", steps: Optional["Var"] = None
319+
) -> "Var":
320+
if steps is None:
321+
return self.make_node("Slice", self, starts, ends, axes)
322+
return self.make_node("Slice", self, starts, ends, axes, steps)
323+
317324
def Softmax(self, axis: int = -1) -> "Var":
318325
return self.make_node("Softmax", self, axis=axis)
319326

onnx_array_api/light_api/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ def to_onnx(self) -> GRAPH_PROTO:
406406
return graph
407407
model = make_model(graph, opset_imports=opsets)
408408
if self.ir_version:
409-
model.ir_version = ir_version
409+
model.ir_version = self.ir_version
410410
if not is_windows() or not is_azure():
411411
# check_model fails sometimes on Windows
412412
check_model(model)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy