## Motivation

A speedup solution for training small neural networks on a single GPU (#17).

## Solution

- [x] functorch + torchopt

<img width="621" alt="image" src="https://user-images.githubusercontent.com/32269413/177388694-b78a91e1-56bc-4df5-8008-6b3ff8ee5578.png">

## Resources

- [functorch example](https://github.com/pytorch/functorch/blob/main/examples/ensembling/parallel_train.py)
- [functorch zou example](https://github.com/Chillee/pt_transformation_experiments/blob/ad5b2ff19989ba3d6bab38169a2f4b287a25e827/ensembling/parallel_train.py)
- [JAX + FLAX](http://willwhitney.com/parallel-training-jax.html)

## Checklist

- [x] I have checked that there is no similar issue in the repo (**required**)