Skip to content

Commit 7b65822

Browse files
committed
add complex
1 parent df4f45e commit 7b65822

File tree

6 files changed

+24
-0
lines changed

6 files changed

+24
-0
lines changed

onnx_array_api/_helpers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ def np_dtype_to_tensor_dtype(dtype: Any):
4040
dt = TensorProto.INT64
4141
elif dtype is float:
4242
dt = TensorProto.DOUBLE
43+
elif dtype == np.complex64:
44+
dt = TensorProto.COMPLEX64
45+
elif dtype == np.complex128:
46+
dt = TensorProto.COMPLEX128
4347
else:
4448
raise KeyError(f"Unable to guess type for dtype={dtype}.") # noqa: B904
4549
return dt

onnx_array_api/annotations.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ def wrapper(self, *args: List[Any], **kwargs: Dict[str, Any]) -> Any:
6464
np.uint64: TensorProto.UINT64,
6565
np.bool_: TensorProto.BOOL,
6666
np.str_: TensorProto.STRING,
67+
np.complex64: TensorProto.COMPLEX64,
68+
np.complex128: TensorProto.COMPLEX128,
6769
}
6870

6971

onnx_array_api/array_api/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ def _finfo(dtype):
4747
continue
4848
if isinstance(v, (np.float32, np.float64, np.float16)):
4949
d[k] = float(v)
50+
elif isinstance(v, (np.complex128, np.complex64)):
51+
d[k] = complex(v)
5052
else:
5153
d[k] = v
5254
d["dtype"] = DType(np_dtype_to_tensor_dtype(dt))
@@ -124,6 +126,8 @@ def _finalize_array_api(module, function_names, TEagerTensor):
124126
module.float16 = DType(TensorProto.FLOAT16)
125127
module.float32 = DType(TensorProto.FLOAT)
126128
module.float64 = DType(TensorProto.DOUBLE)
129+
module.complex64 = DType(TensorProto.COMPLEX64)
130+
module.complex128 = DType(TensorProto.COMPLEX128)
127131
module.int8 = DType(TensorProto.INT8)
128132
module.int16 = DType(TensorProto.INT16)
129133
module.int32 = DType(TensorProto.INT32)

onnx_array_api/npx/npx_var.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,13 +1171,17 @@ def __init__(self, cst: Any):
11711171
Var.__init__(self, np.array(cst, dtype=np.int64), op="Identity")
11721172
elif isinstance(cst, float):
11731173
Var.__init__(self, np.array(cst, dtype=np.float64), op="Identity")
1174+
elif isinstance(cst, complex):
1175+
Var.__init__(self, np.array(cst, dtype=np.complex128), op="Identity")
11741176
elif isinstance(cst, list):
11751177
if all(isinstance(t, bool) for t in cst):
11761178
Var.__init__(self, np.array(cst, dtype=np.bool_), op="Identity")
11771179
elif all(isinstance(t, (int, bool)) for t in cst):
11781180
Var.__init__(self, np.array(cst, dtype=np.int64), op="Identity")
11791181
elif all(isinstance(t, (float, int, bool)) for t in cst):
11801182
Var.__init__(self, np.array(cst, dtype=np.float64), op="Identity")
1183+
elif all(isinstance(t, (float, int, bool, complex)) for t in cst):
1184+
Var.__init__(self, np.array(cst, dtype=np.complex128), op="Identity")
11811185
else:
11821186
raise ValueError(
11831187
f"Unable to convert cst (type={type(cst)}), value={cst}."

onnx_array_api/reference/evaluator_yield.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,12 @@ def generate_input(info: ValueInfoProto) -> np.ndarray:
485485
return (value.astype(np.float16) / p).astype(np.float16).reshape(new_shape)
486486
if elem_type == TensorProto.DOUBLE:
487487
return (value.astype(np.float64) / p).astype(np.float64).reshape(new_shape)
488+
if elem_type == TensorProto.COMPLEX64:
489+
return (value.astype(np.complex64) / p).astype(np.complex64).reshape(new_shape)
490+
if elem_type == TensorProto.COMPLEX128:
491+
return (
492+
(value.astype(np.complex128) / p).astype(np.complex128).reshape(new_shape)
493+
)
488494
raise RuntimeError(f"Unexpected element_type {elem_type} for info={info}")
489495

490496

onnx_array_api/reference/ops/op_constant_of_shape.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ def _process(value):
1919
cst = np.int64(cst)
2020
elif isinstance(cst, float):
2121
cst = np.float64(cst)
22+
elif isinstance(cst, complex):
23+
cst = np.complex128(cst)
2224
elif cst is None:
2325
cst = np.float32(0)
2426
if not isinstance(
@@ -27,6 +29,8 @@ def _process(value):
2729
np.float16,
2830
np.float32,
2931
np.float64,
32+
np.complex64,
33+
np.complex128,
3034
np.int64,
3135
np.int32,
3236
np.int16,

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy