## Motivation

Improve the speed of the current PyTree implementation, specifically `tree_map`, `tree_flatten`, and `tree_unflatten`.

## Solution

Implement these operations in a separate C++ library and expose them to Python through bindings.

## Resource

- https://github.com/pytorch/pytorch/issues/65761
- https://github.com/google/jax/issues/8099

## Checklist

- [x] I have checked that there is no similar issue in the repo (**required**)