feat: functorch integration #6

Merged: 26 commits, Sep 11, 2022
Changes from 1 commit
Commits (26)
766ca2f  init (vmoens, Apr 13, 2022)
0635bc9  amend (vmoens, Apr 13, 2022)
e7ee747  amend (vmoens, Apr 13, 2022)
b9c0373  amend (vmoens, Apr 13, 2022)
589de77  feat: resolve conflicts and init func optimizer (Benjamin-eecs, Aug 12, 2022)
639c4dc  feat: init func optimizer (Benjamin-eecs, Aug 12, 2022)
4ea3fe5  feat: init func optimizer (Benjamin-eecs, Aug 12, 2022)
53e9194  feat: init func optimizer (Benjamin-eecs, Aug 12, 2022)
f0bdbdf  fix: keep the tensor types after update (XuehaiPan, Aug 14, 2022)
788fd7e  Merge branch 'main' into functorch_functional (XuehaiPan, Aug 14, 2022)
ced2114  Merge branch 'main' into functorch_functional (Benjamin-eecs, Sep 5, 2022)
1358386  Merge branch 'main' into functorch_functional (Benjamin-eecs, Sep 7, 2022)
5b94aab  fix: pass lint (Benjamin-eecs, Sep 7, 2022)
e2ff85f  Merge branch 'main' into functorch_functional (Benjamin-eecs, Sep 9, 2022)
d434103  Merge branch 'main' into functorch_functional (Benjamin-eecs, Sep 10, 2022)
f7c6858  fix: revert nn.Parameter fix (Benjamin-eecs, Sep 11, 2022)
3d39c30  chore: cleanup (XuehaiPan, Sep 11, 2022)
402e48e  feat: update `FuncOptimizer` (XuehaiPan, Sep 11, 2022)
ff4533c  docs: add docs for `FuncOptimizer` (XuehaiPan, Sep 11, 2022)
e376123  docs(CHANGELOG): update CHANGELOG.md (XuehaiPan, Sep 11, 2022)
847549b  chore: cleanup (XuehaiPan, Sep 11, 2022)
f9ab259  chore: cleanup (XuehaiPan, Sep 11, 2022)
8f8d3ef  chore: cleanup (XuehaiPan, Sep 11, 2022)
ad71291  chore: handle corner case (XuehaiPan, Sep 11, 2022)
474dee5  test: add tests for `FuncOptimizer` (XuehaiPan, Sep 11, 2022)
fa2a38c  chore: add type check (XuehaiPan, Sep 11, 2022)
chore: add type check
XuehaiPan committed Sep 11, 2022
commit fa2a38ce68c79e31732ac6a73170dd579081a55a
torchopt/_src/optimizer/base.py: 3 additions & 0 deletions

@@ -37,6 +37,9 @@ def __init__(self, params: Iterable[torch.Tensor], impl: GradientTransformation)
                 Note that using ``Optimizer(sgd())`` or ``Optimizer(chain(sgd()))`` is equivalent to
                 :class:`torchopt.SGD`.
         """
+        if not isinstance(impl, GradientTransformation):
+            raise TypeError(f'{impl} (type: {type(impl).__name__}) is not a GradientTransformation')
+
         self.impl = impl
         self.param_groups = []  # type: ignore
         self.param_tree_groups = []  # type: ignore
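
The docstring above notes that ``Optimizer(sgd())`` and ``Optimizer(chain(sgd()))`` behave like :class:`torchopt.SGD`, and the added check now rejects anything that is not a ``GradientTransformation``. A minimal sketch of both points; the parameter shapes and learning rate are illustrative, not taken from this PR:

    import torch
    import torchopt

    params = [torch.zeros(3, requires_grad=True)]

    # Per the docstring: wrapping the functional ``sgd`` transformation is
    # equivalent to using the high-level ``torchopt.SGD`` class.
    opt_a = torchopt.Optimizer(params, torchopt.sgd(lr=0.1))
    opt_b = torchopt.SGD(params, lr=0.1)

    # With this commit, passing anything that is not a GradientTransformation
    # fails fast with a TypeError instead of breaking later inside ``step()``.
    try:
        torchopt.Optimizer(params, 'sgd')  # wrong type on purpose
    except TypeError as exc:
        print(exc)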
torchopt/_src/optimizer/func/base.py: 3 additions & 0 deletions

@@ -48,6 +48,9 @@ def __init__(self, impl: GradientTransformation, *, inplace: bool = False) -> None
             inplace (optional): (default: :data:`False`)
                 The default value of ``inplace`` for each optimization update.
         """
+        if not isinstance(impl, GradientTransformation):
+            raise TypeError(f'{impl} (type: {type(impl).__name__}) is not a GradientTransformation')
+
         self.impl = impl
         self.optim_state = self.__NOT_INITIALIZED
         self.inplace = bool(inplace)
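
``FuncOptimizer`` is the functorch-facing wrapper introduced by this PR, and its constructor now gets the same guard. Below is a hedged usage sketch: ``functorch.make_functional`` and ``torchopt.adam`` are existing APIs, while the ``step(loss, params)`` call returning updated parameters is an assumption based on this PR's functional-optimizer design and the ``inplace`` default documented above:

    import functorch
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torchopt

    model = nn.Linear(4, 1)
    fmodel, params = functorch.make_functional(model)  # functional model + parameter tuple

    # The constructor only accepts a GradientTransformation (checked by this commit).
    optimizer = torchopt.FuncOptimizer(torchopt.adam(lr=1e-3))

    x, y = torch.randn(8, 4), torch.randn(8, 1)
    for _ in range(3):
        loss = F.mse_loss(fmodel(params, x), y)
        # Assumed API: ``step`` differentiates ``loss`` w.r.t. ``params`` and
        # returns the updated parameters (out of place by default).
        params = optimizer.step(loss, params)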
torchopt/_src/optimizer/meta/base.py: 3 additions & 0 deletions

@@ -37,6 +37,9 @@ def __init__(self, net: nn.Module, impl: GradientTransformation):
                 ``MetaOptimizer(chain(sgd(moment_requires_grad=True)))`` is equivalent to
                 :class:`torchopt.MetaSGD`.
         """
+        if not isinstance(impl, GradientTransformation):
+            raise TypeError(f'{impl} (type: {type(impl).__name__}) is not a GradientTransformation')
+
         self.impl = impl
         self.param_containers_groups = []  # type: ignore
         self.state_groups = []  # type: ignore
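
The same guard lands in ``MetaOptimizer``. A small sketch of the equivalence mentioned in the docstring and of the failure mode the check catches; the network and learning rate are illustrative, and the ``moment_requires_grad`` keyword is taken from the docstring above:

    import torch.nn as nn
    import torchopt

    net = nn.Linear(4, 1)

    # Per the docstring: a differentiable SGD transformation wrapped in
    # MetaOptimizer behaves like the dedicated MetaSGD class.
    meta_opt = torchopt.MetaOptimizer(net, torchopt.sgd(lr=0.1, moment_requires_grad=True))
    meta_sgd = torchopt.MetaSGD(net, lr=0.1)

    # The new isinstance check turns a wrong ``impl`` into an immediate TypeError.
    try:
        torchopt.MetaOptimizer(net, impl=None)
    except TypeError as exc:
        print(exc)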