diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 13c81ab..9f22a80 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.2.0 +++++ +* :pr:`77`: supports ConstantOfShape and Slice with the light API * :pr:`76`: add a mode to compare models without execution * :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator * :pr:`71`: adds tools to compare two onnx graphs diff --git a/_unittests/ut_light_api/test_light_api.py b/_unittests/ut_light_api/test_light_api.py index 0483354..e14896a 100644 --- a/_unittests/ut_light_api/test_light_api.py +++ b/_unittests/ut_light_api/test_light_api.py @@ -538,7 +538,35 @@ def test_constant_of_shape(self): got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0] self.assertEqualArray(np.zeros((2, 3), dtype=np.float32), got) + def test_constant_of_shape_value(self): + onx = ( + start() + .vin("X", TensorProto.INT64, shape=[None, None]) + .ConstantOfShape(value=np.array([1], dtype=np.float32)) + .vout(shape=[]) + .to_onnx() + ) + ref = ReferenceEvaluator(onx) + got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0] + self.assertEqualArray(np.ones((2, 3), dtype=np.float32), got) + + def test_slice(self): + onx = ( + start(opset=18, ir_version=9) + .cst(np.array([1], dtype=np.int64), name="one") + .cst(np.array([2], dtype=np.int64), name="two") + .vin("X", TensorProto.INT64, shape=[None, None]) + .ConstantOfShape(value=np.array([1], dtype=np.float32)) + .rename("CX") + .bring("CX", "one", "two", "one") + .Slice() + .vout(shape=[]) + .to_onnx() + ) + ref = ReferenceEvaluator(onx) + got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0] + self.assertEqualArray(np.ones((2, 1), dtype=np.float32), got) + if __name__ == "__main__": - TestLightApi().test_add() unittest.main(verbosity=2) diff --git a/onnx_array_api/light_api/_op_var.py b/onnx_array_api/light_api/_op_var.py index 3a74ed2..1291594 100644 --- a/onnx_array_api/light_api/_op_var.py +++ b/onnx_array_api/light_api/_op_var.py @@ -314,6 
+314,13 @@ def Selu( def Shrink(self, bias: float = 0.0, lambd: float = 0.5) -> "Var": return self.make_node("Shrink", self, bias=bias, lambd=lambd) + def Slice( + self, starts: "Var", ends: "Var", axes: "Var", steps: Optional["Var"] = None + ) -> "Var": + if steps is None: + return self.make_node("Slice", self, starts, ends, axes) + return self.make_node("Slice", self, starts, ends, axes, steps) + def Softmax(self, axis: int = -1) -> "Var": return self.make_node("Softmax", self, axis=axis) diff --git a/onnx_array_api/light_api/model.py b/onnx_array_api/light_api/model.py index 25194ac..6478c4d 100644 --- a/onnx_array_api/light_api/model.py +++ b/onnx_array_api/light_api/model.py @@ -406,7 +406,7 @@ def to_onnx(self) -> GRAPH_PROTO: return graph model = make_model(graph, opset_imports=opsets) if self.ir_version: - model.ir_version = ir_version + model.ir_version = self.ir_version if not is_windows() or not is_azure(): # check_model fails sometimes on Windows check_model(model)