Skip to content

Commit a3d4ccf

Browse files
committed
2 parents c7bb055 + 6ed1d1c commit a3d4ccf

File tree

2 files changed

+37
-8
lines changed

2 files changed

+37
-8
lines changed

onnx_array_api/graph_api/graph_builder.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -836,30 +836,57 @@ def remove_identity_nodes(self):
836836
"""
837837
Removes identity nodes.
838838
"""
839-
# f<irst pass: detect replacements
839+
# first pass: detect replacements
840840
new_nodes = []
841841
input_names = set(i.name for i in self.inputs)
842842
output_names = set(i.name for i in self.outputs)
843843
replacements = {}
844+
replacements_rev = {}
844845
for node in self.nodes:
845846
if node.op_type != "Identity":
846847
new_nodes.append(node)
847848
continue
848849

849850
if node.output[0] not in output_names:
850851
old_name, new_name = node.output[0], node.input[0]
851-
elif node.input[0] not in input_names:
852+
elif (
853+
node.input[0] not in input_names
854+
and node.input[0] not in output_names
855+
and node.input[0] not in replacements
856+
):
852857
old_name, new_name = node.input[0], node.output[0]
853858
else:
854859
new_nodes.append(node)
855860
continue
856861

857862
# the new name can be set for replacements as well
858-
assert old_name not in replacements
859863
if new_name in replacements:
860864
new_name = replacements[new_name]
861-
assert new_name not in replacements
865+
assert new_name not in replacements, (
866+
f"Name {old_name!r} still in {replacements}, node.op_type={node.op_type!r}, "
867+
f"node.input={node.input}, node.output={node.output}, "
868+
f"input_names={input_names}, output_names={output_names}"
869+
)
870+
if old_name in replacements_rev:
871+
old_old_name = replacements_rev[old_name]
872+
replacements[old_old_name] = new_name
873+
replacements_rev[new_name] = old_old_name
874+
if old_name in replacements:
875+
replacements[replacements[old_name]] = new_name
876+
assert new_name not in replacements, (
877+
f"Name {old_name!r} still in {replacements}, node.op_type={node.op_type!r}, "
878+
f"node.input={node.input}, node.output={node.output}, "
879+
f"input_names={input_names}, output_names={output_names}"
880+
)
862881
replacements[old_name] = new_name
882+
replacements_rev[new_name] = old_name
883+
884+
# verification
885+
for k, v in replacements.items():
886+
assert v not in replacements, (
887+
f"replacement {k}->{v} is not possible because of "
888+
f"{v}->{replacements[v]}, old_name={old_name!r}, new_name={new_name!r}"
889+
)
863890

864891
# second pass: replacements in initializer
865892
for k, v in replacements.items():
@@ -876,10 +903,12 @@ def remove_identity_nodes(self):
876903
repo = {o for o in node.output if o in replacements}
877904
repi = {o for o in node.input if o in replacements}
878905
if repi or repo:
906+
new_inputs = [replacements.get(i, i) for i in node.input]
907+
new_outputs = [replacements.get(i, i) for i in node.output]
879908
new_node = oh.make_node(
880909
node.op_type,
881-
[replacements.get(i, i) for i in node.input],
882-
[replacements.get(i, i) for i in node.output],
910+
new_inputs,
911+
new_outputs,
883912
domain=node.domain,
884913
name=node.name,
885914
)

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ exclude = [
1111
# Same as Black.
1212
line-length = 88
1313

14-
[tool.ruff.mccabe]
14+
[tool.ruff.lint.mccabe]
1515
# Unlike Flake8, default to a complexity level of 10.
1616
max-complexity = 10
1717

18-
[tool.ruff.per-file-ignores]
18+
[tool.ruff.lint.per-file-ignores]
1919
"_doc/examples/plot_first_example.py" = ["E402", "F811"]
2020
"_doc/examples/plot_onnxruntime.py" = ["E402", "F811"]
2121
"onnx_array_api/array_api/_onnx_common.py" = ["F821"]

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy