
refactor: align argument names with PyTorch #65


Merged
33 commits merged on Aug 29, 2022

Changes from 1 commit (of 33 commits)

Commits
5e4676e
refactor: refactor optimizer bases
XuehaiPan Aug 25, 2022
20d5377
refactor: align Adam options with PyTorch
XuehaiPan Aug 25, 2022
bbc3bdd
refactor: align RMSProp options with PyTorch
XuehaiPan Aug 25, 2022
a45111a
refactor: align SGD options with PyTorch
XuehaiPan Aug 25, 2022
0c30794
feat(alias): check value range
XuehaiPan Aug 25, 2022
4a985ea
feat: add `params` to `update_fn`'s signature
XuehaiPan Aug 25, 2022
a3cb948
feat: add weight decay
XuehaiPan Aug 26, 2022
cbf1a52
test: add weight decay tests
XuehaiPan Aug 26, 2022
a0be5d9
lint: pass lint
XuehaiPan Aug 26, 2022
71a4a63
docs: update docstring
XuehaiPan Aug 26, 2022
2ab1da0
chore: update type hints
XuehaiPan Aug 26, 2022
a7d7643
fix: fix grad tracing for weight decay
XuehaiPan Aug 26, 2022
63d77cc
test: reorganize tests
XuehaiPan Aug 26, 2022
b31e245
chore: add RMSprop aliases for PyTorch compatibility
XuehaiPan Aug 26, 2022
a929c51
test: add module buffers
XuehaiPan Aug 26, 2022
f082d6c
test: update test parameters
XuehaiPan Aug 26, 2022
1c8421d
chore: update .gitignore
XuehaiPan Aug 26, 2022
2723c01
test: update test parameters
XuehaiPan Aug 27, 2022
1ec3d0f
refactor: refactor transform
XuehaiPan Aug 27, 2022
e8bd609
refactor: chain
XuehaiPan Aug 27, 2022
353b628
refactor: identity
XuehaiPan Aug 27, 2022
bee81bb
feat: add with_flattened_tree
XuehaiPan Aug 27, 2022
a8b6dc0
test: update test parameters
XuehaiPan Aug 27, 2022
7d1d20a
feat: add dampening
XuehaiPan Aug 27, 2022
58dcc56
docs: update docstring
XuehaiPan Aug 28, 2022
acde3fd
lint: fix mypy
XuehaiPan Aug 28, 2022
b1e1521
fix: fix grad tracing for initial value
XuehaiPan Aug 29, 2022
6331cdc
test: update test parameters
XuehaiPan Aug 29, 2022
1b39ad5
docs: update docstrings
XuehaiPan Aug 29, 2022
9976f96
chore: rename variables
XuehaiPan Aug 29, 2022
4c726f2
test: update test parameters
XuehaiPan Aug 29, 2022
3eee243
docs(CHANGELOG): update CHANGELOG.md
XuehaiPan Aug 29, 2022
1fded5d
test: test with pre-release of PyTorch
XuehaiPan Aug 29, 2022
refactor: chain
XuehaiPan committed Aug 27, 2022
commit e8bd609003efa12a7f72f14570925a017b340842
2 changes: 1 addition & 1 deletion tests/test_clip.py
@@ -45,7 +45,7 @@ def test_sgd(

    model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype)

    chain = torchopt.combine.chain(
    chain = torchopt.chain(
        torchopt.clip.clip_grad_norm(max_norm=max_norm),
        torchopt.sgd(
            lr=lr,
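For reference, a minimal sketch of the renamed call site outside the test harness. It assumes only what the diff shows (``torchopt.chain``, ``torchopt.clip.clip_grad_norm``, and ``torchopt.sgd``); the max_norm and lr values are illustrative.

import torchopt

# Gradient clipping chained with SGD through the new top-level alias.
optimizer = torchopt.chain(
    torchopt.clip.clip_grad_norm(max_norm=1.0),
    torchopt.sgd(lr=0.1),
)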
4 changes: 4 additions & 0 deletions torchopt/__init__.py
@@ -16,6 +16,8 @@

from torchopt._src import accelerated_op_available, clip, combine, hook, schedule, visual
from torchopt._src.alias import adam, rmsprop, sgd
from torchopt._src.clip import clip_grad_norm
from torchopt._src.combine import chain
from torchopt._src.optimizer import SGD, Adam, Optimizer, RMSProp, RMSprop, meta
from torchopt._src.optimizer.meta import MetaAdam, MetaOptimizer, MetaRMSProp, MetaRMSprop, MetaSGD
from torchopt._src.update import apply_updates
@@ -33,6 +35,8 @@
    'adam',
    'rmsprop',
    'sgd',
    'clip_grad_norm',
    'chain',
    'Optimizer',
    'SGD',
    'Adam',
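A rough sketch of what the new top-level exports enable, assuming the functional ``init``/``update`` interface shown in base.py and the exported ``apply_updates`` helper; the parameter and gradient trees below are made up for illustration.

import torch
import torchopt

params = [torch.zeros(3)]
transform = torchopt.chain(
    torchopt.clip_grad_norm(max_norm=1.0),  # previously only reachable as torchopt.clip.clip_grad_norm
    torchopt.sgd(lr=0.1),
)
state = transform.init(params)
grads = [torch.ones(3)]
updates, state = transform.update(grads, state, params=params)
params = torchopt.apply_updates(params, updates)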
79 changes: 77 additions & 2 deletions torchopt/_src/base.py
@@ -30,14 +30,19 @@
# limitations under the License.
# ==============================================================================

import itertools
from abc import abstractmethod
from typing import Callable, NamedTuple, Optional, Tuple

from typing_extensions import Protocol

from torchopt._src.typing import Numeric, TensorTree


try:
    from typing import Protocol
except ImportError:
    from typing_extensions import Protocol


OptState = TensorTree # States are arbitrary nests of `torch.Tensor`.
# Parameters are arbitrary nests of `torch.Tensor`.
Params = TensorTree
@@ -132,6 +137,76 @@ class GradientTransformation(NamedTuple):
    init: TransformInitFn
    update: TransformUpdateFn

    # pylint: disable-next=redefined-builtin
    def chain(self, next: 'GradientTransformation') -> 'ChainedGradientTransformation':
        """Chain two gradient transformations together."""
        return ChainedGradientTransformation(self, next)


class ChainedGradientTransformation(GradientTransformation):
    """A chain of gradient transformations.

    This class is a subclass of :class:`GradientTransformation` which allows for chaining of
    gradient transformations.
    """

    transformations: Tuple[GradientTransformation, ...]

    def __new__(cls, *transformations: GradientTransformation) -> 'ChainedGradientTransformation':
        transformations = tuple(
            itertools.chain.from_iterable(
                t.transformations if isinstance(t, ChainedGradientTransformation) else (t,)
                for t in transformations
            )
        )

        init_fns, update_fns = tuple(zip(*transformations))

        def init_fn(params):
            return tuple(fn(params) for fn in init_fns)

        def update_fn(updates, state, *, params=None, inplace=True):
            if len(update_fns) != len(state):
                raise ValueError(
                    'The number of updates and states has to be the same in chain! Make sure you '
                    'have called init first!'
                )
            new_state = []
            for s, fn in zip(state, update_fns):  # pylint: disable=invalid-name
                updates, new_s = fn(updates, s, params=params, inplace=inplace)
                new_state.append(new_s)
            return updates, tuple(new_state)

        instance = super().__new__(cls, init_fn, update_fn)
        instance.transformations = tuple(transformations)
        return instance

    def __str__(self):
        return '{}(\n {}\n)'.format(
            self.__class__.__name__, ',\n '.join(repr(t) for t in self.transformations)
        )

    __repr__ = __str__

    def __eq__(self, other: object) -> bool:
        if isinstance(other, ChainedGradientTransformation):
            return self.transformations == other.transformations
        if isinstance(other, GradientTransformation):
            return self.transformations == (other,)
        return False

    def __hash__(self) -> int:
        return hash(self.transformations)

    def __getstate__(self) -> Tuple[GradientTransformation, ...]:
        return self.transformations

    def __setstate__(self, state: Tuple[GradientTransformation, ...]) -> None:
        self.transformations = state

    def __reduce__(self) -> Tuple[Callable, Tuple[Tuple[GradientTransformation, ...]]]:
        return ChainedGradientTransformation, (self.transformations,)


def identity() -> GradientTransformation:
    """Stateless identity transformation that leaves input gradients untouched.
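A brief sketch of the behaviour this class adds: method-style chaining on ``GradientTransformation`` and flattening of nested chains in ``__new__``. It assumes the transformations returned by the top-level aliases are ``GradientTransformation`` instances, as the rest of this PR implies; the option values are illustrative.

import torchopt
from torchopt._src.base import ChainedGradientTransformation

clip = torchopt.clip_grad_norm(max_norm=1.0)
sgd = torchopt.sgd(lr=0.1)

chained = clip.chain(sgd)  # method-style chaining added to GradientTransformation
# Nested chains are flattened, so wrapping an existing chain does not create a
# chain of chains; `transformations` stays a flat tuple.
rechained = ChainedGradientTransformation(chained, torchopt.adam(lr=1e-3))
print(rechained.transformations)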
3 changes: 1 addition & 2 deletions torchopt/_src/clip.py
@@ -38,8 +38,7 @@ def clip_grad_norm(
        An ``(init_fn, update_fn)`` tuple.
    """

    def init_fn(params):
        del params
    def init_fn(_):
        return ClipState()

    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
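For illustration, a small sketch of the clip transform used on its own through the functional interface; since ``init_fn`` ignores its argument (which is why the signature could shrink to ``init_fn(_)``), any placeholder works. The tensor values and the ``inplace=False`` call are made up for the example.

import torch
import torchopt

clip = torchopt.clip_grad_norm(max_norm=1.0)
state = clip.init(None)                 # stateless: the argument is ignored
updates = [torch.full((3,), 10.0)]      # deliberately large "gradients"
clipped, state = clip.update(updates, state, inplace=False)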
18 changes: 1 addition & 17 deletions torchopt/_src/combine.py
@@ -47,21 +47,5 @@ def chain(*args: base.GradientTransformation) -> base.GradientTransformation:
    Returns:
        A single ``(init_fn, update_fn)`` tuple.
    """
    init_fns, update_fns = tuple(zip(*args))

    def init_fn(params):
        return tuple(fn(params) for fn in init_fns)

    def update_fn(updates, state, *, params=None, inplace=True):
        if len(update_fns) != len(state):
            raise ValueError(
                'The number of updates and states has to be the same in chain! Make sure you have '
                'called init first!'
            )
        new_state = []
        for s, fn in zip(state, update_fns):  # pylint: disable=invalid-name
            updates, new_s = fn(updates, s, params=params, inplace=inplace)
            new_state.append(new_s)
        return updates, tuple(new_state)

    return base.GradientTransformation(init_fn, update_fn)
    return base.ChainedGradientTransformation(*args)
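Since ``combine.chain`` now delegates to ``ChainedGradientTransformation``, the returned object is still an ``(init_fn, update_fn)`` named tuple, so existing unpacking code keeps working. A rough sketch (the chained transformations are illustrative):

import torchopt

transform = torchopt.combine.chain(
    torchopt.clip_grad_norm(max_norm=1.0),
    torchopt.sgd(lr=0.1),
)
init_fn, update_fn = transform      # NamedTuple unpacking is unchanged
print(type(transform).__name__)     # ChainedGradientTransformation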