Skip to content

Commit 9615aba

Browse files
committed
Introduce array_api_stubs.py
1 parent 792f11e commit 9615aba

File tree

1 file changed

+173
-0
lines changed

1 file changed

+173
-0
lines changed

array_api_stubs.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
import sys
2+
from importlib import import_module
3+
from importlib.util import find_spec
4+
from inspect import getmembers, isfunction, signature
5+
from pathlib import Path
6+
from types import FunctionType, ModuleType, SimpleNamespace
7+
from typing import Dict, List, Optional
8+
from unittest import TestCase
9+
10+
__all__ = ["make_stubs_namespace"]
11+
12+
API_VERSIONS = {"2012.12"} # TODO: infer released versions dynamically
13+
14+
15+
def make_stubs_namespace(api_version: Optional[str] = None) -> SimpleNamespace:
16+
"""
17+
Returns a ``SimpleNamespace`` where
18+
19+
* ``functions`` (``dict[str, FunctionType]``) maps names of top-level
20+
functions to their respective stubs.
21+
* ``array_methods`` (``dict[str, FunctionType]``) maps names of array
22+
methods to their respective stubs.
23+
* ``dtype_methods`` (``dict[str, FunctionType]``) maps names of dtype object
24+
methods to their respective stubs.
25+
* ``category_to_functions`` (``dict[str, dict[str, FunctionType]]``) maps
26+
names of categories to their respective function mappings.
27+
* ``extension_to_functions`` (``dict[str, dict[str, FunctionType]]``) maps
28+
names of extensions to their respective function mappings.
29+
30+
Examples
31+
--------
32+
33+
Make a stubs namespace.
34+
35+
>>> from array_api_stubs import make_stubs_namespace
36+
>>> stubs = make_stubs_namespace()
37+
38+
Access the ``array_api.square()`` stub.
39+
40+
>>> stubs.functions["square"]
41+
<function array_api.square(x: ~array, /) -> ~array>
42+
43+
Find names of all set functions.
44+
45+
>>> stubs.category_to_functions["set"].keys()
46+
dict_keys(['unique_all', 'unique_counts', 'unique_inverse', 'unique_values'])
47+
48+
Access the array object's ``__add__`` stub.
49+
50+
>>> stubs.array_methods["__add__"]
51+
<function array_api._Array.__add__(self: 'array', other: 'Union[int, float, array]', /) -> 'array'>
52+
53+
Access the ``array_api.linalg.cross()`` stub.
54+
55+
>>> stubs.extension_to_functions["linalg"]["cross"]
56+
<function array_api.linalg.cross(x1: ~array, x2: ~array, /, *, axis: int = -1) -> ~array>
57+
58+
"""
59+
if api_version is None:
60+
api_version = "draft"
61+
if api_version in API_VERSIONS or api_version == "latest":
62+
raise NotImplementedError("{api_version=} not yet supported")
63+
elif api_version != "draft":
64+
raise ValueError(
65+
f"{api_version=} not 'draft', 'latest', "
66+
f"or a released version ({API_VERSIONS})"
67+
)
68+
69+
spec_dir = Path(__file__).parent / "spec" / "API_specification"
70+
signatures_dir = spec_dir / "array_api"
71+
assert signatures_dir.exists() # sanity check
72+
spec_abs_path: str = str(spec_dir.resolve())
73+
sys.path.append(spec_abs_path)
74+
assert find_spec("array_api") is not None # sanity check
75+
76+
name_to_mod: Dict[str, ModuleType] = {}
77+
for path in signatures_dir.glob("*.py"):
78+
name = path.name.replace(".py", "")
79+
name_to_mod[name] = import_module(f"array_api.{name}")
80+
81+
array = name_to_mod["array_object"].array
82+
array_methods: Dict[str, FunctionType] = {}
83+
for name, func in getmembers(array, predicate=isfunction):
84+
func.__module__ = "array_api"
85+
assert "Alias" not in func.__doc__ # sanity check
86+
func.__qualname__ = f"_Array.{func.__name__}"
87+
array_methods[name] = func
88+
89+
dtype_eq = name_to_mod["data_types"].__eq__
90+
assert isinstance(dtype_eq, FunctionType) # for mypy
91+
dtype_eq.__module__ = "array_api"
92+
dtype_eq.__qualname__ = "_DataType.__eq__"
93+
dtype_methods: Dict[str, FunctionType] = {"__eq__": dtype_eq}
94+
95+
functions: Dict[str, FunctionType] = {}
96+
category_to_functions: Dict[str, Dict[str, FunctionType]] = {}
97+
for name, mod in name_to_mod.items():
98+
if name.endswith("_functions"):
99+
category = name.replace("_functions", "")
100+
name_to_func = {}
101+
for name in mod.__all__:
102+
func = getattr(mod, name)
103+
assert isinstance(func, FunctionType) # sanity check
104+
func.__module__ = "array_api"
105+
name_to_func[name] = func
106+
functions.update(name_to_func)
107+
category_to_functions[category] = name_to_func
108+
109+
extensions: List[str] = ["linalg"] # TODO: infer on runtime
110+
extension_to_functions: Dict[str, Dict[str, FunctionType]] = {}
111+
for ext in extensions:
112+
mod = name_to_mod[ext]
113+
name_to_func = {name: getattr(mod, name) for name in mod.__all__}
114+
name_to_func = {}
115+
for name in mod.__all__:
116+
func = getattr(mod, name)
117+
assert isinstance(func, FunctionType) # sanity check
118+
assert func.__doc__ is not None # for mypy
119+
if "Alias" in func.__doc__:
120+
func.__doc__ = functions[name].__doc__
121+
func.__module__ = f"array_api.{ext}"
122+
name_to_func[name] = func
123+
extension_to_functions[ext] = name_to_func
124+
125+
return SimpleNamespace(
126+
functions=functions,
127+
array_methods=array_methods,
128+
dtype_methods=dtype_methods,
129+
category_to_functions=category_to_functions,
130+
extension_to_functions=extension_to_functions,
131+
)
132+
133+
134+
class TestMakeStubsNamespace(TestCase):
135+
def setUp(self):
136+
self.stubs = make_stubs_namespace()
137+
138+
def test_attributes(self):
139+
assert isinstance(self.stubs, SimpleNamespace)
140+
for attr in ["functions", "array_methods", "dtype_methods"]:
141+
mapping = getattr(self.stubs, attr)
142+
assert isinstance(mapping, dict)
143+
assert all(isinstance(k, str) for k in mapping.keys())
144+
assert all(isinstance(v, FunctionType) for v in mapping.values())
145+
for attr in ["category_to_functions", "extension_to_functions"]:
146+
mapping = getattr(self.stubs, attr)
147+
assert isinstance(mapping, dict)
148+
assert all(isinstance(k, str) for k in mapping.keys())
149+
for sub_mapping in mapping.values():
150+
assert isinstance(sub_mapping, dict)
151+
assert all(isinstance(k, str) for k in sub_mapping.keys())
152+
assert all(isinstance(v, FunctionType) for v in sub_mapping.values())
153+
154+
def test_function_meta(self):
155+
toplevel_stub = self.stubs.functions["matmul"]
156+
assert toplevel_stub.__module__ == "array_api"
157+
extension_stub = self.stubs.extension_to_functions["linalg"]["matmul"]
158+
assert extension_stub.__module__ == "array_api.linalg"
159+
assert extension_stub.__doc__ == toplevel_stub.__doc__
160+
161+
def test_array_method_meta(self):
162+
stub = self.stubs.array_methods["__add__"]
163+
assert stub.__module__ == "array_api"
164+
assert stub.__qualname__ == "_Array.__add__"
165+
first_arg = next(iter(signature(stub).parameters.values()))
166+
assert first_arg.name == "self"
167+
168+
def test_dtype_method_meta(self):
169+
stub = self.stubs.dtype_methods["__eq__"]
170+
assert stub.__module__ == "array_api"
171+
assert stub.__qualname__ == "_DataType.__eq__"
172+
first_arg = next(iter(signature(stub).parameters.values()))
173+
assert first_arg.name == "self"

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy