
Commit 1c7691e

Benjamin-eecs authored and XuehaiPan committed
feat(torchopt): adagrad optimizer support
1 parent fad99d8 commit 1c7691e

File tree

3 files changed: +227 −32 lines changed

  torchopt/alias.py
  torchopt/schedule.py
  torchopt/transform.py


torchopt/alias.py

Lines changed: 73 additions & 0 deletions
@@ -137,6 +137,79 @@ def schedule_wrapper(count):
     return transform._scale(-lr, already_flattened=True)  # pylint: disable=protected-access


+# pylint: disable-next=too-many-arguments
+def adagrad(
+    lr: ScalarOrSchedule = 1e-2,
+    lr_decay: float = 0.0,
+    weight_decay: float = 0.0,
+    initial_accumulator_value: float = 0.0,
+    eps: float = 1e-10,
+    *,
+    eps_root: float = 0.0,  # pylint: disable=unused-argument
+    moment_requires_grad: bool = False,  # pylint: disable=unused-argument
+    maximize: bool = False,
+) -> base.GradientTransformation:
+    """The functional Adagrad optimizer.
+
+    Adagrad is an algorithm for gradient-based optimization that anneals the
+    learning rate for each parameter during the course of training.
+
+    WARNING: Adagrad's main limit is the monotonic accumulation of squared
+    gradients in the denominator: since all terms are > 0, the sum keeps growing
+    during training, and the learning rate eventually becomes vanishingly small.
+
+    References:
+        Duchi et al., 2011: https://jmlr.org/papers/v12/duchi11a.html
+
+    Args:
+        lr: (default: :const:`1e-2`)
+            This is a fixed global scaling factor.
+        lr_decay: (default: :const:`0.0`)
+            Learning rate decay.
+        weight_decay: (default: :const:`0.0`)
+            Weight decay, add L2 penalty to parameters.
+        initial_accumulator_value: (default: :const:`0.0`)
+            Initial value for the accumulator.
+        eps: (default: :const:`1e-10`)
+            A small constant applied to the denominator outside of the square root (as in the
+            Adam paper) to avoid dividing by zero when rescaling.
+        eps_root: (default: :data:`0.0`)
+            A small constant applied to the denominator inside the square root (as in RMSProp),
+            to avoid dividing by zero when rescaling. This is needed for example when computing
+            (meta-)gradients through Adam.
+        moment_requires_grad: (default: :data:`False`)
+            If :data:`True`, the momentums will be created with flag ``requires_grad=True``; this
+            flag is often used in Meta-Learning algorithms.
+        maximize: (default: :data:`False`)
+            Maximize the params based on the objective, instead of minimizing.
+
+    Returns:
+        The corresponding :class:`GradientTransformation` instance.
+
+    See Also:
+        The functional optimizer wrapper :class:`torchopt.FuncOptimizer`.
+    """
+    # pylint: disable=unneeded-not
+    if not (callable(lr) or 0.0 <= lr):
+        raise ValueError(f'Invalid learning rate: {lr}')
+    if not 0.0 <= eps:
+        raise ValueError(f'Invalid epsilon value: {eps}')
+    if not 0.0 <= lr_decay:
+        raise ValueError(f'Invalid lr_decay value: {lr_decay}')
+    if not 0.0 <= weight_decay:
+        raise ValueError(f'Invalid weight_decay value: {weight_decay}')
+    # pylint: enable=unneeded-not
+    return transform.with_flattened_tree(
+        combine.chain(
+            _flip_sign_and_weight_decay(weight_decay=weight_decay, maximize=maximize),
+            transform.scale_by_rss(initial_accumulator_value=initial_accumulator_value, eps=eps),
+            _scale_by_neg_lr(lr),
+        )
+    )
+
+
 # pylint: disable-next=too-many-arguments
 def adam(
     lr: ScalarOrSchedule = 1e-3,
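For orientation, the adagrad alias above chains three transformations: _flip_sign_and_weight_decay, transform.scale_by_rss, and _scale_by_neg_lr. Below is a minimal usage sketch of the resulting functional optimizer; it assumes the API matches torchopt's other aliases (an init/update pair applied with torchopt.apply_updates), and the parameter tree and loss are purely illustrative:

    import torch
    import torchopt

    # Hypothetical toy parameter pytree and quadratic loss, for illustration only.
    params = (torch.tensor([1.0, 2.0], requires_grad=True),)
    optimizer = torchopt.adagrad(lr=1e-2)   # GradientTransformation: (init_fn, update_fn)
    opt_state = optimizer.init(params)      # per-parameter sum-of-squares accumulators

    loss = (params[0] ** 2).sum()
    grads = torch.autograd.grad(loss, params)

    updates, opt_state = optimizer.update(grads, opt_state, params=params, inplace=False)
    params = torchopt.apply_updates(params, updates, inplace=False)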

torchopt/schedule.py

Lines changed: 53 additions & 0 deletions
@@ -32,6 +32,7 @@
 """Learning rate schedules."""

 import logging
+from typing import Optional

 import numpy as np
 import torch
@@ -42,6 +43,58 @@
 __all__ = ['polynomial_schedule', 'linear_schedule']


+def linear_decay(
+    init_value: Scalar,
+    decay_rate: Scalar,
+    transition_begin: int = 0,
+    transition_steps: Optional[int] = None,
+    end_value: Optional[float] = None,
+) -> base.Schedule:
+    """Constructs a schedule that decays the value as
+    ``init_value / (1 + (count - transition_begin - 1) * decay_rate)``, clipped at ``end_value``.
+
+    Args:
+        init_value: Initial value of the schedule.
+        decay_rate: The decay rate; ``0`` yields a constant schedule.
+        transition_begin: Number of steps to wait before the decay starts.
+        transition_steps: If given and non-positive, the schedule is constant at ``init_value``.
+        end_value: Optional bound at which the decayed value is clipped.
+
+    Returns:
+        schedule: A function that maps step counts to values.
+    """
+    if transition_steps is not None and transition_steps <= 0:
+        logging.info(
+            'A linear schedule was set with a non-positive `transition_steps`'
+            ' value; this will result in a constant schedule with value '
+            '`init_value`.'
+        )
+        return lambda count: init_value
+
+    if decay_rate == 0:
+        logging.info(
+            'A linear schedule was set with a zero `decay_rate` value; '
+            'this will result in a constant schedule with value `init_value`.'
+        )
+        return lambda count: init_value
+
+    if transition_begin < 0:
+        logging.info(
+            'A linear schedule was set with a negative `transition_begin` '
+            'value; this will result in `transition_begin` falling back to `0`.'
+        )
+        transition_begin = 0
+
+    if end_value is not None:
+        clip_fn = max if decay_rate < 1.0 else min
+
+    def schedule(count: Numeric) -> Numeric:
+        decreased_count = count - transition_begin
+        decayed_value = (
+            init_value / (1 + (decreased_count - 1) * decay_rate)
+            if decreased_count > 0
+            else init_value
+        )
+        if end_value is not None:
+            decayed_value = clip_fn(decayed_value, end_value)
+        return decayed_value
+
+    return schedule
+
+
 def polynomial_schedule(
     init_value: Scalar,
     end_value: Scalar,
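To make the decay rule above concrete, here is a small sketch of the values the returned schedule produces; linear_decay is not added to __all__ in this diff, so it is shown here as a module-level call, and the numbers are purely illustrative:

    # Assumes the rule init_value / (1 + (count - transition_begin - 1) * decay_rate),
    # clipped at end_value once the decayed value reaches it.
    schedule = linear_decay(init_value=1.0, decay_rate=0.5, transition_begin=0, end_value=0.2)

    schedule(0)   # 1.0  -- count <= transition_begin keeps init_value
    schedule(1)   # 1.0  -- 1.0 / (1 + 0 * 0.5)
    schedule(3)   # 0.5  -- 1.0 / (1 + 2 * 0.5)
    schedule(9)   # 0.2  -- 1.0 / (1 + 8 * 0.5), held at the end_value floor from here on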

torchopt/transform.py

Lines changed: 101 additions & 32 deletions
@@ -78,7 +78,7 @@ def inc_count(updates: Updates, count: TensorTree) -> TensorTree:
     """Increments int counter by one.

     Returns:
-        A counter incremeted by one, or max_int if the maximum precision is reached.
+        A counter incremented by one, or max_int if the maximum precision is reached.
     """
     return _inc_count(updates=updates, count=count, already_flattened=False)

@@ -265,7 +265,7 @@ def scale_by_adam(
             Term added to the denominator inside the square-root to improve
             numerical stability when back-propagating gradients through the rescaling.
         moment_requires_grad: (default: :data:`False`)
-            if :data:`True`, states will be created with flag `requires_grad = True`.
+            If :data:`True`, states will be created with flag `requires_grad = True`.

     Returns:
         An (init_fn, update_fn) tuple.
@@ -367,7 +367,7 @@ def scale_by_accelerated_adam(
             Term added to the denominator inside the square-root to improve
             numerical stability when back-propagating gradients through the rescaling.
         moment_requires_grad: (default: :data:`False`)
-            if :data:`True`, states will be created with flag `requires_grad = True`.
+            If :data:`True`, states will be created with flag `requires_grad = True`.

     Returns:
         An (init_fn, update_fn) tuple.
@@ -474,7 +474,7 @@ def trace(
         nesterov: (default: :data:`False`)
             Whether to use Nesterov momentum.
         moment_requires_grad: (default: :data:`False`)
-            if :data:`True`, states will be created with flag `requires_grad = True`.
+            If :data:`True`, states will be created with flag `requires_grad = True`.

     Returns:
         An (init_fn, update_fn) tuple.
@@ -597,7 +597,7 @@ def scale_by_rms(
         eps: (default: :const:`1e-8`)
             Term added to the denominator to improve numerical stability.
         initial_scale: (default: :const:`0.0`)
-            Initial value for second moment
+            Initial value for second moment.

     Returns:
         An (init_fn, update_fn) tuple.
@@ -675,7 +675,7 @@ def scale_by_stddev(
         eps: (default: :const:`1e-8`)
             Term added to the denominator to improve numerical stability.
         initial_scale: (default: :const:`0.0`)
-            Initial value for second moment
+            Initial value for second moment.

     Returns:
         An (init_fn, update_fn) tuple.
@@ -745,9 +745,8 @@ class MaskedState(NamedTuple):
 class MaskedNode(NamedTuple):
     """A node used to mask out unspecified parts of a tree.

-    This node is ignored when mapping functions across the tree e.g. using
-    :func:`pytree.tree_map` since it is a container without children. It can
-    therefore be used to mask out parts of a tree.
+    This node is ignored when mapping functions across the tree e.g. using :func:`pytree.tree_map`
+    since it is a container without children. It can therefore be used to mask out parts of a tree.
     """


@@ -757,28 +756,27 @@ def masked(
 ) -> GradientTransformation:
     """Mask updates so only some are transformed, the rest are passed through.

-    For example, it is common to skip weight decay for BatchNorm scale and all
-    bias parameters. In many networks, these are the only parameters with only
-    one dimension. So, you may create a mask function to mask these out as
-    follows::
-        mask_fn = lambda p: pytree.tree_map(lambda x: x.ndim != 1, p)
-        weight_decay = torchopt.masked(torchopt.add_decayed_weights(0.001), mask_fn)
+    For example, it is common to skip weight decay for BatchNorm scale and all bias parameters. In
+    many networks, these are the only parameters with only one dimension. So, you may create a mask
+    function to mask these out as follows::
+        mask_fn = lambda p: pytree.tree_map(lambda x: x.ndim != 1, p)
+        weight_decay = torchopt.masked(torchopt.add_decayed_weights(0.001), mask_fn)
     You may alternatively create the mask pytree upfront::
-        mask = pytree.tree_map(lambda x: x.ndim != 1, params)
-        weight_decay = torchopt.masked(torchopt.add_decayed_weights(0.001), mask)
+        mask = pytree.tree_map(lambda x: x.ndim != 1, params)
+        weight_decay = torchopt.masked(torchopt.add_decayed_weights(0.001), mask)
     For the ``inner`` transform, state will only be stored for the parameters that
-    have a mask value of ``True``.
+    have a mask value of :data:`True`.

     Args:
-        inner: Inner transformation to mask.
-        mask: a PyTree with same structure as (or a prefix of) the params PyTree, or
-            a Callable that returns such a pytree given the params/updates. The leaves
-            should be booleans, ``True`` for leaves/subtrees you want to apply the
-            transformation to, and ``False`` for those you want to skip. The mask must
-            be static for the gradient transformation to be jit-compilable.
+        inner: Inner transformation to mask.
+        mask: A PyTree with same structure as (or a prefix of) the params pytree, or a Callable that
+            returns such a pytree given the params/updates. The leaves should be booleans,
+            :data:`True` for leaves/subtrees you want to apply the transformation to, and
+            :data:`False` for those you want to skip. The mask must be static for the gradient
+            transformation to be jit-compilable.

     Returns:
-        New GradientTransformation wrapping ``inner``.
+        A new :class:`GradientTransformation` wrapping ``inner``.
     """
     return _masked(
         inner=inner,
@@ -831,17 +829,17 @@ def add_decayed_weights(
     weight_decay: float = 0.0,
     mask: Optional[Union[Any, Callable[[Params], Any]]] = None,
 ) -> GradientTransformation:
-    """Add parameter scaled by `weight_decay`.
+    """Add parameter scaled by ``weight_decay``.

     Args:
-        weight_decay: a scalar weight decay rate.
-        mask: a tree with same structure as (or a prefix of) the params PyTree,
-            or a Callable that returns such a pytree given the params/updates.
-            The leaves should be booleans, `True` for leaves/subtrees you want to
-            apply the transformation to, and `False` for those you want to skip.
+        weight_decay: A scalar weight decay rate.
+        mask: A tree with same structure as (or a prefix of) the params pytree, or a Callable that
+            returns such a pytree given the params/updates. The leaves should be booleans,
+            :data:`True` for leaves/subtrees you want to apply the transformation to, and
+            :data:`False` for those you want to skip.

     Returns:
-        An (init_fn, update_fn) tuple.
+        An (init_fn, update_fn) tuple.
     """
     return _add_decayed_weights(
         weight_decay=weight_decay,
@@ -902,3 +900,74 @@ def f(g, p):
         already_flattened=already_flattened,
     )
     return GradientTransformation(init_fn, update_fn)
+
+
+class ScaleByRssState(NamedTuple):
+    """State holding the sum of gradient squares to date."""
+
+    sum_of_squares: Updates
+
+
+def scale_by_rss(
+    initial_accumulator_value: float = 0.1,
+    eps: float = 1e-7,
+) -> GradientTransformation:
+    """Rescale updates by the root of the sum of all squared gradients to date.
+
+    References:
+        [Duchi et al, 2011](https://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
+        [McMahan et al., 2010](https://arxiv.org/abs/1002.4908)
+
+    Args:
+        initial_accumulator_value: Starting value for accumulators, must be >= 0.
+        eps: A small floating point value to avoid zero denominator.
+
+    Returns:
+        An (init_fn, update_fn) tuple.
+    """
+    return _scale_by_rss(
+        initial_accumulator_value=initial_accumulator_value,
+        eps=eps,
+        already_flattened=False,
+    )
+
+
+def _scale_by_rss(
+    initial_accumulator_value: float = 0.1,
+    eps: float = 1e-7,
+    *,
+    already_flattened: bool = False,
+) -> GradientTransformation:
+
+    if already_flattened:
+        tree_map = map_flattened
+    else:
+        tree_map = pytree.tree_map
+
+    def init_fn(params):
+        sum_of_squares = tree_map(lambda t: torch.full_like(t, initial_accumulator_value), params)
+        return ScaleByRssState(sum_of_squares=sum_of_squares)
+
+    def update_fn(updates, state, params=None, inplace=True):  # pylint: disable=unused-argument
+        del params
+        sum_of_squares = tree_map(
+            lambda g, t: (g.conj() * g).real + t, updates, state.sum_of_squares
+        )
+        # inv_sqrt_g_square = tree_map(
+        #     lambda t: jnp.where(t > 0, jax.lax.rsqrt(t + eps), 0.0), sum_of_squares
+        # )
+        if inplace:
+
+            def f(t):
+                return t.add_(eps).rsqrt_() if t > 0.0 else 0.0
+
+        else:
+
+            def f(t):
+                return t.add(eps).rsqrt() if t > 0.0 else 0.0
+
+        inv_sqrt_g_square = tree_map(f, sum_of_squares)
+        updates = tree_map(lambda scale, g: scale * g, inv_sqrt_g_square, updates)
+        return updates, ScaleByRssState(sum_of_squares=sum_of_squares)
+
+    return GradientTransformation(init_fn, update_fn)
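As a sanity check on the rescaling rule in update_fn above (accumulate squared gradients, then multiply each gradient by the reciprocal square root of its accumulator plus eps), here is a standalone numeric sketch of the same arithmetic on a single tensor, written element-wise with torch.where instead of the tree machinery; the values are illustrative only:

    import torch

    initial_accumulator_value, eps = 0.1, 1e-7
    grad = torch.tensor([0.5, -2.0])

    # init_fn: every accumulator element starts at initial_accumulator_value.
    sum_of_squares = torch.full_like(grad, initial_accumulator_value)

    # update_fn: accumulate g^2, then scale g by 1 / sqrt(accumulator + eps).
    sum_of_squares = sum_of_squares + (grad.conj() * grad).real
    scale = torch.where(sum_of_squares > 0, (sum_of_squares + eps).rsqrt(), torch.zeros_like(grad))
    update = scale * grad   # approximately tensor([ 0.8452, -0.9877])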

0 commit comments
