Skip to content

Commit 7330b58

Browse files
committed
remove some torch issues
1 parent 395e281 commit 7330b58

File tree

1 file changed

+8
-23
lines changed

1 file changed

+8
-23
lines changed

onnx_array_api/graph_api/graph_builder.py

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from onnx import AttributeProto, FunctionProto, ModelProto, NodeProto, TensorProto
77
from onnx.reference import ReferenceEvaluator
88

9+
T = "TENSOR"
10+
911

1012
class Opset:
1113
# defined for opset >= 18
@@ -78,8 +80,8 @@ def make_node(
7880
class OptimizationOptions:
7981
def __init__(
8082
self,
81-
remove_unused: bool = False,
82-
constant_folding: bool = True,
83+
remove_unused: bool = True,
84+
constant_folding: bool = False,
8385
constant_size: int = 1024,
8486
):
8587
self.remove_unused = remove_unused
@@ -205,10 +207,6 @@ def get_constant(self, name: str) -> np.ndarray:
205207
if isinstance(value, np.ndarray):
206208
return value
207209

208-
import torch
209-
210-
if isinstance(value, torch.Tensor):
211-
return value.detach().numpy()
212210
raise TypeError(f"Unable to convert type {type(value)} into numpy array.")
213211

214212
def set_shape(self, name: str, shape: Tuple[int, ...]):
@@ -513,9 +511,7 @@ def make_nodes(
513511
return output_names[0]
514512
return output_names
515513

516-
def from_array(
517-
self, arr: "torch.Tensor", name: str = None # noqa: F821
518-
) -> TensorProto:
514+
def from_array(self, arr: T, name: str = None) -> TensorProto: # noqa: F821
519515
import sys
520516
import torch
521517

@@ -552,15 +548,8 @@ def from_array(
552548
return tensor
553549

554550
def _build_initializers(self) -> List[TensorProto]:
555-
import torch
556-
557551
res = []
558552
for k, v in sorted(self.initializers_dict.items()):
559-
if isinstance(v, torch.Tensor):
560-
# no string tensor
561-
t = self.from_array(v, name=k)
562-
res.append(t)
563-
continue
564553
if isinstance(v, np.ndarray):
565554
if self.verbose and np.prod(v.shape) > 100:
566555
print(f"[GraphBuilder] onh.from_array:{k}:{v.dtype}[{v.shape}]")
@@ -575,7 +564,7 @@ def _build_initializers(self) -> List[TensorProto]:
575564

576565
def process(
577566
self,
578-
graph_module: "torch.f.GraphModule", # noqa: F821
567+
graph_module: Any,
579568
interpreter: "Interpreter", # noqa: F821
580569
):
581570
for node in graph_module.graph.nodes:
@@ -656,19 +645,15 @@ def remove_unused(self):
656645
self.constants_ = {k: v for k, v in self.constants_.items() if k in marked}
657646
self.nodes = [node for i, node in enumerate(self.nodes) if i not in removed]
658647

659-
def _apply_transpose(
660-
self, node: NodeProto, feeds: Dict[str, "torch.Tensor"] # noqa: F821
661-
) -> "torch.Tensor": # noqa: F821
662-
import torch
663-
648+
def _apply_transpose(self, node: NodeProto, feeds: Dict[str, T]) -> T: # noqa: F821
664649
perm = None
665650
for att in node.attribute:
666651
if att.name == "perm":
667652
perm = tuple(att.ints)
668653
break
669654
assert perm, f"perm not here in node {node}"
670655
assert len(perm) == 2, f"perm={perm} is not supported with torch"
671-
return [torch.transpose(feeds[node.input[0]], *perm)]
656+
return [np.transpose(feeds[node.input[0]], *perm)]
672657

673658
def constant_folding(self):
674659
"""

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy