Open
Labels: enhancement (New feature or request)
Description
Hey! I'd like to propose that, in addition to the current registration mechanism, JAX allow users to define pytrees via a Pytree protocol. There are two reasons for doing this:
- It simplifies the definition of custom Pytrees. Currently, the best way to ensure that a class and all its subclasses are registered is to override the `__init_subclass__` method and perform the registration there. Implementing a protocol is more straightforward, since `__init_subclass__` is not widely known.
- As the concept of a Pytree becomes more widespread within the Python ecosystem (e.g. `torch.utils._pytree` -> stable pytorch/pytorch#65761, and `dm-tree`), a protocol could be a simple way of achieving cross-library compatibility if the different implementations adopt it.
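For comparison, the `__init_subclass__` workaround could look roughly like the sketch below. It relies on JAX's existing `jax.tree_util.register_pytree_node_class`, which expects a `tree_flatten` method and a `tree_unflatten(aux_data, children)` classmethod. `Module` and `Linear` are hypothetical names, and treating every instance attribute as a child is an assumption made for brevity, not a recommended design.

```python
from typing import Any, Sequence, Tuple

import jax
from jax.tree_util import register_pytree_node_class


class Module:
    """Hypothetical base class whose subclasses are registered as pytrees automatically."""

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # Runs once per subclass definition, so no subclass can be forgotten.
        register_pytree_node_class(cls)

    def tree_flatten(self) -> Tuple[Sequence[Any], Any]:
        # Assumption: every instance attribute is a child; the sorted
        # attribute names (a hashable tuple) serve as the static aux data.
        names = tuple(sorted(vars(self)))
        return [getattr(self, n) for n in names], names

    @classmethod
    def tree_unflatten(cls, aux: Any, children: Sequence[Any]) -> "Module":
        # JAX passes (aux_data, children); bypass __init__ and restore attributes.
        obj = object.__new__(cls)
        for name, value in zip(aux, children):
            setattr(obj, name, value)
        return obj


class Linear(Module):
    def __init__(self, w: Any, b: Any) -> None:
        self.w = w
        self.b = b


layer = Linear(1.0, 2.0)
doubled = jax.tree_util.tree_map(lambda x: x * 2, layer)  # still a Linear
```

Only subclasses are registered this way; the base class itself is not, because `__init_subclass__` never fires for the class that defines it. Note also that JAX's hook takes `(aux_data, children)`, whereas the proposed protocol takes `(children, aux)`.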
Implementation
The idea would be to take the same mechanism implemented in `register_pytree_node_class`, but with "special methods":
```python
from typing import Any, Protocol, Sequence, Tuple


class Pytree(Protocol):
    def __tree_flatten__(self) -> Tuple[Sequence[Any], Any]:
        ...

    @classmethod
    def __tree_unflatten__(cls, children: Sequence[Any], aux: Any) -> Any:
        ...
```
Any object that defines these methods would be treated as a Pytree at runtime.
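A minimal pure-Python sketch of what "treated as a Pytree at runtime" could mean: instead of consulting a registry, the traversal checks whether an object's type defines `__tree_flatten__`. `Point`, `tree_map` and `tree_leaves` here are hypothetical helpers, not JAX's API, and only plain lists, tuples and dicts are handled as built-in containers.

```python
from typing import Any, Callable, List, Sequence, Tuple


class Point:
    """Hypothetical user type implementing the proposed protocol."""

    def __init__(self, x: Any, y: Any, name: str = "p") -> None:
        self.x, self.y, self.name = x, y, name

    def __tree_flatten__(self) -> Tuple[Sequence[Any], Any]:
        # x and y are children (leaves or nested pytrees); name is static aux data.
        return (self.x, self.y), self.name

    @classmethod
    def __tree_unflatten__(cls, children: Sequence[Any], aux: Any) -> "Point":
        x, y = children
        return cls(x, y, aux)


def _implements_protocol(obj: Any) -> bool:
    # Look the method up on the type, as Python does for special methods.
    return hasattr(type(obj), "__tree_flatten__")


def tree_map(f: Callable[[Any], Any], tree: Any) -> Any:
    """Apply f to every leaf, rebuilding containers; no registry lookup involved."""
    if _implements_protocol(tree):
        children, aux = tree.__tree_flatten__()
        return type(tree).__tree_unflatten__([tree_map(f, c) for c in children], aux)
    if type(tree) in (list, tuple):  # plain lists/tuples only (not namedtuples)
        return type(tree)(tree_map(f, c) for c in tree)
    if type(tree) is dict:
        return {k: tree_map(f, v) for k, v in tree.items()}
    return f(tree)  # anything else is a leaf


def tree_leaves(tree: Any) -> List[Any]:
    """Collect leaves left to right; dict keys are visited in sorted order."""
    if _implements_protocol(tree):
        children, _ = tree.__tree_flatten__()
        return [leaf for c in children for leaf in tree_leaves(c)]
    if type(tree) in (list, tuple):
        return [leaf for c in tree for leaf in tree_leaves(c)]
    if type(tree) is dict:
        return [leaf for k in sorted(tree) for leaf in tree_leaves(tree[k])]
    return [tree]
```

For example, `tree_map(lambda v: v * 10, {"a": [Point(1, 2)], "b": 3})` rebuilds the dict with a new `Point(10, 20, "p")` and `30`, without `Point` ever having been registered anywhere.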
Other comments/ideas
- Pytrees seem like a general concept orthogonal to numerical computing; making JAX's Pytree implementation a separate project that JAX depends on would have a positive impact on the whole Python ecosystem.
- Libraries that focus on general Pytree manipulation, like Treeo, could then become useful independently of JAX.