Pytree Protocol #8099

@cgarciae

Description

Hey! I'd like to propose that, in addition to the current registration mechanism, JAX allow users to define pytrees via a Pytree Protocol. There are two reasons for doing this:

  1. It simplifies the definition of custom pytrees. Currently the best way to ensure that a class and all its subclasses are registered is to override the __init_subclass__ method and perform the registration there. Implementing a protocol is more straightforward, as __init_subclass__ is not widely known.
  2. As the concept of a pytree becomes more widespread within the Python ecosystem (e.g. torch.utils._pytree -> stable pytorch/pytorch#65761, and dm-tree), a protocol could be a simple way of getting cross-library compatibility if the different implementations adopt it.
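For context, the __init_subclass__ workaround mentioned in point 1 might look roughly like the sketch below. It uses jax.tree_util.register_pytree_node_class, which expects tree_flatten and tree_unflatten methods; the Module and Linear classes, and the choice to treat every instance attribute as a child, are assumptions made up for illustration.

```python
from jax import tree_util


class Module:
    """Hypothetical base class: every subclass is registered as a pytree."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Runs once per subclass definition. Module itself is NOT registered,
        # because __init_subclass__ only fires for subclasses.
        tree_util.register_pytree_node_class(cls)

    def tree_flatten(self):
        # Assumption: all instance attributes are children; their sorted
        # names are the (hashable) auxiliary data.
        names = tuple(sorted(vars(self)))
        children = tuple(getattr(self, name) for name in names)
        return children, names

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__ and restore the attributes directly.
        obj = object.__new__(cls)
        for name, value in zip(aux_data, children):
            setattr(obj, name, value)
        return obj


class Linear(Module):
    def __init__(self, w, b):
        self.w = w
        self.b = b


layer = Linear(2.0, 1.0)
leaves = tree_util.tree_leaves(layer)  # leaves in sorted-name order: b, w
doubled = tree_util.tree_map(lambda x: x * 2, layer)
```

Note that the registration call expects the argument order tree_unflatten(cls, aux_data, children), which differs from the order in the protocol proposed below.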

Implementation

The idea would be to take the same mechanism implemented in register_pytree_node_class but with "special methods":

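from typing import Any, Protocol, Sequence, Tuple

class Pytree(Protocol):
    # Returns (children, aux): the child nodes plus static auxiliary data.
    def __tree_flatten__(self) -> Tuple[Sequence[Any], Any]:
        ...

    # Rebuilds an instance from the children and auxiliary data.
    @classmethod
    def __tree_unflatten__(cls, children: Sequence[Any], aux: Any) -> Any:
        ...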

Any object that defines these methods would be treated as a Pytree at runtime.
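To make the runtime behaviour concrete, here is a minimal pure-Python sketch of how a flattener could dispatch on the proposed special methods. This is not how JAX is implemented; is_pytree, flatten, unflatten, the tuple-based treedef, and the Point class are all hypothetical names and design choices used only to illustrate the idea.

```python
from typing import Any, List, Tuple


def is_pytree(obj: Any) -> bool:
    # Duck-typed check: the type defines both proposed special methods.
    cls = type(obj)
    return hasattr(cls, "__tree_flatten__") and hasattr(cls, "__tree_unflatten__")


def flatten(tree: Any) -> Tuple[List[Any], Any]:
    """Return (leaves, treedef). A treedef of None marks a leaf."""
    if not is_pytree(tree):
        return [tree], None
    children, aux = tree.__tree_flatten__()
    leaves: List[Any] = []
    child_defs = []
    for child in children:
        child_leaves, child_def = flatten(child)
        leaves.extend(child_leaves)
        child_defs.append((len(child_leaves), child_def))
    return leaves, (type(tree), aux, child_defs)


def unflatten(treedef: Any, leaves: List[Any]) -> Any:
    if treedef is None:
        (leaf,) = leaves
        return leaf
    cls, aux, child_defs = treedef
    children = []
    start = 0
    for count, child_def in child_defs:
        children.append(unflatten(child_def, leaves[start:start + count]))
        start += count
    return cls.__tree_unflatten__(children, aux)


class Point:
    """Hypothetical user class implementing the proposed protocol."""

    def __init__(self, x, y, name):
        self.x, self.y, self.name = x, y, name

    def __tree_flatten__(self):
        return (self.x, self.y), self.name  # children, aux

    @classmethod
    def __tree_unflatten__(cls, children, aux):
        x, y = children
        return cls(x, y, aux)


nested = Point(1, Point(2, 3, "inner"), "outer")
leaves, treedef = flatten(nested)
rebuilt = unflatten(treedef, [v * 10 for v in leaves])
```

No registration step appears anywhere: Point becomes a pytree purely by defining the two methods, and any subclass would inherit that behaviour.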

Other comments/ideas

  • Pytree seems like a general concept orthogonal to numeric computing. Making JAX's Pytree implementation a separate project that JAX depends on would have a positive impact on the whole Python ecosystem.
  • Libraries that focus on general Pytree manipulation like Treeo could start to be useful independent of JAX.

Labels: enhancement (New feature or request)
