Skip to content

Commit fd5fc82

Browse files
committed
Fix as_tensor
1 parent 96eb50e commit fd5fc82

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

onnx_array_api/plotting/text_plot.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ def __init__(self, i, atts):
6464
self.nodes_missing_value_tracks_true = None
6565
for k, v in atts.items():
6666
if k.startswith("nodes"):
67-
setattr(self, k, v[i])
67+
if k.endswith("_as_tensor"):
68+
setattr(self, k.replace("_as_tensor", ""), v[i])
69+
else:
70+
setattr(self, k, v[i])
6871
self.depth = 0
6972
self.true_false = ""
7073
self.targets = []
@@ -120,10 +123,7 @@ def process_tree(atts, treeid):
120123
]
121124
for k, v in atts.items():
122125
if k.startswith(prefix):
123-
if "classlabels" in k:
124-
short[k] = list(v)
125-
else:
126-
short[k] = [v[i] for i in idx]
126+
short[k] = list(v) if "classlabels" in k else [v[i] for i in idx]
127127

128128
nodes = OrderedDict()
129129
for i in range(len(short["nodes_treeids"])):
@@ -132,9 +132,10 @@ def process_tree(atts, treeid):
132132
for i in range(len(short[f"{prefix}_treeids"])):
133133
idn = short[f"{prefix}_nodeids"][i]
134134
node = nodes[idn]
135-
node.append_target(
136-
tid=short[f"{prefix}_ids"][i], weight=short[f"{prefix}_weights"][i]
137-
)
135+
key = f"{prefix}_weights"
136+
if key not in short:
137+
key = f"{prefix}_weights_as_tensor"
138+
node.append_target(tid=short[f"{prefix}_ids"][i], weight=short[key][i])
138139

139140
def iterate(nodes, node, depth=0, true_false=""):
140141
node.depth = depth

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy