Skip to content

Commit 1fb7b6e

Browse files
crusaderky and NeilGirdhar
authored and committed
Merge pull request data-apis#284 from crusaderky/autojit
ENH: `jax_autojit`
1 parent 074d0ee commit 1fb7b6e

File tree

8 files changed

+628
-104
lines changed

8 files changed

+628
-104
lines changed

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 259 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,24 @@
22

33
from __future__ import annotations
44

5+
import io
56
import math
6-
from collections.abc import Generator, Iterable
7+
import pickle
8+
import types
9+
from collections.abc import Callable, Generator, Iterable
10+
from functools import wraps
711
from types import ModuleType
8-
from typing import TYPE_CHECKING, cast
12+
from typing import (
13+
TYPE_CHECKING,
14+
Any,
15+
ClassVar,
16+
Generic,
17+
Literal,
18+
ParamSpec,
19+
TypeAlias,
20+
TypeVar,
21+
cast,
22+
)
923

1024
from . import _compat
1125
from ._compat import (
@@ -19,8 +33,16 @@
1933
from ._typing import Array
2034

2135
if TYPE_CHECKING: # pragma: no cover
22-
# TODO import from typing (requires Python >=3.13)
23-
from typing_extensions import TypeIs
36+
# TODO import from typing (override requires Python >=3.12; TypeIs requires >=3.13)
37+
from typing_extensions import TypeIs, override
38+
else:
39+
40+
def override(func):
41+
return func
42+
43+
44+
P = ParamSpec("P")
45+
T = TypeVar("T")
2446

2547

2648
__all__ = [
@@ -29,8 +51,11 @@
2951
"eager_shape",
3052
"in1d",
3153
"is_python_scalar",
54+
"jax_autojit",
3255
"mean",
3356
"meta_namespace",
57+
"pickle_flatten",
58+
"pickle_unflatten",
3459
]
3560

3661

@@ -302,3 +327,233 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
302327
out = out.copy()
303328
out["boolean indexing"] = False
304329
return out
330+
331+
332+
# Exact builtin types for which ``pickle_flatten`` returns ``None`` from its
# ``persistent_id`` hook, i.e. they are left to the regular pickle machinery
# (which also descends into the collections among them). Membership is tested
# against ``type(obj)``, so subclasses are deliberately not matched.
_BASIC_PICKLED_TYPES = frozenset((
    bool, int, float, complex, str, bytes, bytearray,
    list, tuple, dict, set, frozenset, range, slice,
    types.NoneType, types.EllipsisType,
))  # fmt: skip
# Exact types whose instances ``pickle_flatten`` never writes to the pickle
# stream: they are appended to the ``rest`` list and referenced by persistent id.
_BASIC_REST_TYPES = frozenset((
    type, types.BuiltinFunctionType, types.FunctionType, types.ModuleType
))  # fmt: skip

# Second item returned by ``pickle_flatten``: the pickled bytes followed by the
# objects that were kept out of the pickle stream.
FlattenRest: TypeAlias = tuple[object, ...]
342+
343+
344+
def pickle_flatten(
    obj: object, cls: type[T] | tuple[type[T], ...]
) -> tuple[list[T], FlattenRest]:
    """
    Take an arbitrary object apart with the pickle machinery.

    A ``pickle.Pickler`` walks `obj`; its ``persistent_id`` hook diverts
    instances of `cls`, as well as anything that cannot be serialised, away
    from the byte stream. Because unpicklable objects are set aside instead of
    being pickled, this is designed to succeed where ``pickle.dumps`` would fail.

    Parameters
    ----------
    obj : object
        The object to take apart.
    cls : type | tuple[type, ...]
        Class or classes to extract. Their instances found inside `obj`
        are collected as-is and are not pickled.

    Returns
    -------
    instances : list[cls]
        Every instance of `cls` found inside `obj`, unpickled, in the order
        in which the pickler encountered them.
    rest
        Opaque tuple: the pickled bytes, followed by every other object that
        was kept out of the stream (types, modules, functions, and objects
        whose ``__reduce_ex__`` raised).

        It is usually hashable; it is not when it holds an object that is
        neither pickleable nor hashable.

        It is pickleable as long as everything in `obj` other than
        ``instances`` was pickleable.

    See Also
    --------
    pickle_unflatten : Reverse function.

    Examples
    --------
    >>> class A:
    ...     def __repr__(self):
    ...         return "<A>"
    >>> class NS:
    ...     def __repr__(self):
    ...         return "<NS>"
    ...     def __reduce__(self):
    ...         assert False, "not serializable"
    >>> obj = {1: A(), 2: [A(), NS(), A()]}
    >>> instances, rest = pickle_flatten(obj, A)
    >>> instances
    [<A>, <A>, <A>]
    >>> pickle_unflatten(instances, rest)
    {1: <A>, 2: [<A>, <NS>, <A>]}

    The extracted objects may be replaced before rebuilding, provided that
    as many objects go back in as came out:

    >>> pickle_unflatten(["foo", "bar", "baz"], rest)
    {1: 'foo', 2: ['bar', <NS>, 'baz']}
    """
    extracted: list[T] = []
    set_aside: list[object] = []

    def reducible(candidate: object) -> bool:  # numpydoc ignore=GL08
        # Probe with __reduce_ex__ rather than __reduce__: a class that defines
        # __slots__ without __getstate__ cannot be pickled with __reduce__(),
        # but can with __reduce_ex__(5).
        try:
            _ = candidate.__reduce_ex__(pickle.HIGHEST_PROTOCOL)
        except Exception:  # pylint: disable=broad-exception-caught
            return False
        return True

    class Pickler(pickle.Pickler):  # numpydoc ignore=GL08
        """
        Pickler that diverts objects through `pickle.Pickler.persistent_id`.
        """

        @override
        def persistent_id(self, obj: object) -> Literal[0, 1, None]:  # pyright: ignore[reportIncompatibleMethodOverride]  # numpydoc ignore=GL08
            # Persistent id 0: extracted instance; 1: object set aside;
            # None: let the pickler serialise (and descend into) the object.
            if isinstance(obj, cls):
                extracted.append(obj)  # type: ignore[arg-type]
                return 0

            exact_type = type(obj)  # exact match only: no subclasses!
            if exact_type in _BASIC_PICKLED_TYPES or (
                exact_type not in _BASIC_REST_TYPES and reducible(obj)
            ):
                return None

            set_aside.append(obj)
            return 1

    stream = io.BytesIO()
    Pickler(stream, protocol=pickle.HIGHEST_PROTOCOL).dump(obj)
    return extracted, (stream.getvalue(), *set_aside)
439+
440+
441+
def pickle_unflatten(instances: Iterable[object], rest: FlattenRest) -> Any:  # type: ignore[explicit-any]
    """
    Reverse of ``pickle_flatten``.

    Parameters
    ----------
    instances : Iterable
        Inner objects to be reinserted into the flattened container.
    rest : FlattenRest
        Extra bits, as returned by ``pickle_flatten``.

    Returns
    -------
    object
        The outer object originally passed to ``pickle_flatten`` after a
        pickle->unpickle round-trip.

    Raises
    ------
    ValueError
        If `rest` is empty, or if `instances` or `rest` yield fewer objects
        than the pickled stream refers to.

    See Also
    --------
    pickle_flatten : Serializing function.
    pickle.loads : Standard unpickle function.

    Notes
    -----
    The `instances` iterable must yield at least the same number of elements as the
    ones returned by ``pickle_flatten``, but the elements do not need to be the same
    objects or even the same types of objects. Excess elements, if any, are not
    consumed.

    The first element of `rest` is fed to ``pickle.Unpickler``; only pass a `rest`
    that comes from a trusted call to ``pickle_flatten``.
    """
    # iters[0] feeds persistent id 0 (extracted instances);
    # iters[1] feeds persistent id 1 (objects kept out of the pickle stream).
    iters = iter(instances), iter(rest)
    try:
        # The first element of rest is always the pickled byte string
        pik = cast(bytes, next(iters[1]))
    except StopIteration as e:
        # Don't let a bare StopIteration escape: it would silently terminate
        # any iteration that happens to be running in the caller.
        msg = "rest is empty; expected the output of pickle_flatten"
        raise ValueError(msg) from e

    class Unpickler(pickle.Unpickler):  # numpydoc ignore=GL08
        """Mirror of the overridden Pickler in pickle_flatten."""

        @override
        def persistent_load(self, pid: Literal[0, 1]) -> object:  # pyright: ignore[reportIncompatibleMethodOverride]  # numpydoc ignore=GL08
            try:
                return next(iters[pid])
            except StopIteration as e:
                msg = "Not enough objects to unpickle"
                raise ValueError(msg) from e

    f = io.BytesIO(pik)
    return Unpickler(f).load()
486+
487+
class _AutoJITWrapper(Generic[T]):  # numpydoc ignore=PR01
    """
    Box used by :func:`jax_autojit` around the inputs and outputs of the
    jitted function.

    The class is registered with JAX as a PyTree node: flattening goes through
    ``pickle_flatten`` (extracting ``jax.Array`` objects as children) and
    unflattening through ``pickle_unflatten``.
    """

    __slots__: tuple[str, ...] = ("obj",)
    _registered: ClassVar[bool] = False
    # The wrapped payload
    obj: T

    def __init__(self, obj: T) -> None:  # numpydoc ignore=GL08
        # Make sure JAX knows about this class before an instance exists
        self._register()
        self.obj = obj

    @classmethod
    def _register(cls):  # numpydoc ignore=SS06
        """
        Register the class as a PyTree node the first time it is needed,
        rather than at import time, so that JAX is not imported globally.
        """
        if cls._registered:
            return

        import jax

        jax.tree_util.register_pytree_node(
            cls,
            # flatten: wrapper -> (children, aux_data)
            lambda obj: pickle_flatten(obj, jax.Array),  # pyright: ignore[reportUnknownArgumentType]
            # unflatten: (aux_data, children) -> wrapper
            lambda aux_data, children: pickle_unflatten(children, aux_data),  # pyright: ignore[reportUnknownArgumentType]
        )
        cls._registered = True
519+
520+
def jax_autojit(
    func: Callable[P, T],
) -> Callable[P, T]:  # numpydoc ignore=PR01,RT01,SS03
    """
    Apply ``jax.jit`` to `func`, relaxing what it may accept and return.

    Compared to plain ``jax.jit``:

    - Python scalars passed in or returned are left alone instead of being
      converted to ``jax.Array`` objects.
    - Every non-array argument is treated as static automatically. Static
      arguments must be either hashable or serializable with ``pickle``.
    - Non-array arguments and return values need not be tuple/list/dict; any
      object serializable with ``pickle`` is accepted.
    - ``jax.Array`` objects nested inside non-array arguments are found
      automatically; the arguments are rebuilt upon entering `func` with tracer
      objects in place of the concrete JAX arrays.
    - ``jax.Array`` objects nested inside non-array return values are found
      automatically; the return values are rebuilt after leaving the JIT with
      concrete arrays in place of the JAX tracer objects.

    See Also
    --------
    jax.jit : JAX JIT compilation function.
    """
    import jax

    def inner(  # type: ignore[decorated-any,explicit-any]  # numpydoc ignore=GL08
        wargs: _AutoJITWrapper[Any],
    ) -> _AutoJITWrapper[T]:
        # Unbox the call arguments, run func, box the result
        call_args, call_kwargs = wargs.obj
        return _AutoJITWrapper(func(*call_args, **call_kwargs))  # pyright: ignore[reportCallIssue]

    jitted = jax.jit(inner)  # type: ignore[misc]  # pyright: ignore[reportUntypedFunctionDecorator]

    @wraps(func)
    def outer(*args: P.args, **kwargs: P.kwargs) -> T:  # numpydoc ignore=GL08
        # Box positional and keyword arguments together; unbox the result
        return jitted(_AutoJITWrapper((args, kwargs))).obj

    return outer

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy