
Commit 2892929

feat(torchopt): adagrad optimizer support
Co-authored-by: Benjamin-eecs <benjaminliu.eecs@gmail.com>
1 parent 89e7912 commit 2892929

6 files changed (+327 additions, −2 deletions)

torchopt/alias/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -31,10 +31,11 @@
 # ==============================================================================
 r"""The aliases of preset :class:`GradientTransformation`\s for optimizers."""

+from torchopt.alias.adagrad import adagrad
 from torchopt.alias.adam import adam
 from torchopt.alias.adamw import adamw
 from torchopt.alias.rmsprop import rmsprop
 from torchopt.alias.sgd import sgd


-__all__ = ['adam', 'adamw', 'rmsprop', 'sgd']
+__all__ = ['adagrad', 'adam', 'adamw', 'rmsprop', 'sgd']

torchopt/alias/adagrad.py

Lines changed: 102 additions & 0 deletions (new file)

@@ -0,0 +1,102 @@
# Copyright 2022 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# This file is modified from:
# https://github.com/deepmind/optax/blob/master/optax/_src/alias.py
# ==============================================================================
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preset :class:`GradientTransformation` for the AdaGrad optimizer."""

from torchopt.alias.utils import flip_sign_and_add_weight_decay, scale_by_neg_lr
from torchopt.combine import chain_flat
from torchopt.transform import scale_by_rss
from torchopt.typing import GradientTransformation, ScalarOrSchedule


__all__ = ['adagrad']


# pylint: disable-next=too-many-arguments
def adagrad(
    lr: ScalarOrSchedule = 1e-2,
    lr_decay: float = 0.0,
    weight_decay: float = 0.0,
    initial_accumulator_value: float = 0.0,
    eps: float = 1e-10,
    *,
    maximize: bool = False,
) -> GradientTransformation:
    """The functional AdaGrad optimizer.

    AdaGrad is a gradient-based optimization algorithm that anneals the learning rate of each
    parameter over the course of training.

    WARNING: AdaGrad's main limitation is the monotonic accumulation of squared gradients in the
    denominator: since every term is non-negative, the sum keeps growing during training, and the
    effective learning rate eventually becomes vanishingly small.

    References:
        Duchi et al., 2011: https://jmlr.org/papers/v12/duchi11a.html

    Args:
        lr: (default: :const:`1e-2`)
            This is a fixed global scaling factor.
        lr_decay: (default: :const:`0.0`)
            Learning rate decay.
        weight_decay: (default: :const:`0.0`)
            Weight decay, add L2 penalty to parameters.
        initial_accumulator_value: (default: :const:`0.0`)
            Initial value for the accumulator.
        eps: (default: :const:`1e-10`)
            A small constant added to the accumulated sum of squared gradients to avoid dividing
            by zero when rescaling.
        maximize: (default: :data:`False`)
            Maximize the params based on the objective, instead of minimizing.

    Returns:
        The corresponding :class:`GradientTransformation` instance.

    See Also:
        The functional optimizer wrapper :class:`torchopt.FuncOptimizer`.
    """
    # pylint: disable=unneeded-not
    if not (callable(lr) or 0.0 <= lr):
        raise ValueError(f'Invalid learning rate: {lr}')
    if not 0.0 <= eps:
        raise ValueError(f'Invalid epsilon value: {eps}')
    if not 0.0 <= lr_decay:
        raise ValueError(f'Invalid lr_decay value: {lr_decay}')
    if not 0.0 <= weight_decay:
        raise ValueError(f'Invalid weight_decay value: {weight_decay}')
    # pylint: enable=unneeded-not

    return chain_flat(
        flip_sign_and_add_weight_decay(weight_decay=weight_decay, maximize=maximize),
        scale_by_rss.flat(initial_accumulator_value=initial_accumulator_value, eps=eps),  # type: ignore[attr-defined]
        scale_by_neg_lr(lr),
    )
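
For context on how this alias is consumed, here is a minimal, hypothetical sketch that drives the functional adagrad transformation through its init/update pair on two toy scalar parameters. The parameter values, loss, and learning rate are placeholders, and the calling convention assumed here is the same GradientTransformation protocol used by the existing torchopt aliases (updates, state = optimizer.update(grads, state, params=params, inplace=False)).

import torch

from torchopt.alias import adagrad

# Two scalar parameters and a toy quadratic loss, purely for illustration.
params = (torch.tensor(1.5, requires_grad=True), torch.tensor(-2.0, requires_grad=True))

optimizer = adagrad(lr=0.1)          # a GradientTransformation, i.e. an (init, update) pair
opt_state = optimizer.init(params)   # holds the AdaGrad accumulator (sum of squared gradients)

for _ in range(10):
    loss = sum(p**2 for p in params)
    grads = torch.autograd.grad(loss, params)

    # Turn raw gradients into AdaGrad updates; they already carry the -lr factor,
    # so applying them is a plain addition.
    updates, opt_state = optimizer.update(grads, opt_state, params=params, inplace=False)
    params = tuple((p + u).detach().requires_grad_() for p, u in zip(params, updates))

print([p.item() for p in params])    # both parameters move towards 0

The same alias can also be wrapped by torchopt's object-oriented and functional optimizer interfaces; this sketch only shows the raw transformation contract that the new file implements.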

torchopt/schedule/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -31,7 +31,8 @@
 # ==============================================================================
 """Learning rate schedules."""

+from torchopt.schedule.exponential_decay import exponential_decay
 from torchopt.schedule.polynomial import linear_schedule, polynomial_schedule


-__all__ = ['polynomial_schedule', 'linear_schedule']
+__all__ = ['exponential_decay', 'polynomial_schedule', 'linear_schedule']

torchopt/schedule/exponential_decay.py

Lines changed: 92 additions & 0 deletions (new file)

@@ -0,0 +1,92 @@
# Copyright 2022 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# This file is modified from:
# https://github.com/deepmind/optax/blob/master/optax/_src/schedule.py
# ==============================================================================
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Exponential learning rate decay."""

import logging
from typing import Optional

from torchopt.typing import Numeric, Scalar, Schedule


__all__ = ['exponential_decay']


def exponential_decay(
    init_value: Scalar,
    decay_rate: Scalar,
    transition_begin: int = 0,
    transition_steps: Optional[int] = None,
    end_value: Optional[float] = None,
) -> Schedule:
    """Construct a schedule with either continuous or discrete exponential decay.

    Args:
        init_value: Initial value of the scalar to be scheduled.
        decay_rate: The decay rate applied once decay begins.
        transition_begin: Number of steps to wait before starting to decay.
        transition_steps: If set to a non-positive value, the schedule is constant at
            ``init_value``.
        end_value: If set, the decayed value is clipped so that it never passes ``end_value``.

    Returns:
        schedule: A function that maps step counts to values.
    """
    if transition_steps is not None and transition_steps <= 0:
        logging.info(
            'An exponential decay schedule was set with a non-positive `transition_steps`'
            ' value; this will result in a constant schedule with value '
            '`init_value`.'
        )
        return lambda count: init_value

    if decay_rate == 0:
        logging.info(
            'An exponential decay schedule was set with a zero `decay_rate` value; '
            'this will result in a constant schedule with value `init_value`.'
        )
        return lambda count: init_value

    if transition_begin < 0:
        logging.info(
            'An exponential decay schedule was set with a negative `transition_begin` '
            'value; this will result in `transition_begin` falling back to `0`.'
        )
        transition_begin = 0

    if end_value is not None:
        clip_fn = max if decay_rate < 1.0 else min

    def schedule(count: Numeric) -> Numeric:
        decreased_count = count - transition_begin
        decayed_value = (
            init_value / (1 + (decreased_count - 1) * decay_rate)
            if decreased_count > 0
            else init_value
        )
        if end_value is not None:
            decayed_value = clip_fn(decayed_value, end_value)
        return decayed_value

    return schedule
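
To illustrate the new schedule as a plain step-to-value function, here is a small, hypothetical sketch; the constants are arbitrary, and the expected values follow directly from the formula in the code above.

from torchopt.schedule import exponential_decay

# Once decay begins (count > transition_begin), the value follows
# init_value / (1 + (count - transition_begin - 1) * decay_rate),
# optionally clipped so it never passes end_value.
schedule = exponential_decay(
    init_value=1e-2,
    decay_rate=0.1,
    transition_begin=2,
    end_value=1e-4,
)

for step in range(6):
    print(step, schedule(step))
# Steps 0-3 print 0.01; from step 4 the value shrinks towards (but never below) end_value.

Since adagrad accepts a ScalarOrSchedule for lr, a schedule like this could presumably be passed directly as the learning rate as well.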

torchopt/transform/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -35,6 +35,7 @@
 from torchopt.transform.scale import scale
 from torchopt.transform.scale_by_adam import scale_by_accelerated_adam, scale_by_adam
 from torchopt.transform.scale_by_rms import scale_by_rms
+from torchopt.transform.scale_by_rss import scale_by_rss
 from torchopt.transform.scale_by_schedule import scale_by_schedule
 from torchopt.transform.scale_by_stddev import scale_by_stddev
 from torchopt.transform.trace import trace
@@ -47,6 +48,7 @@
     'add_decayed_weights',
     'scale_by_adam',
     'scale_by_accelerated_adam',
+    'scale_by_rss',
     'scale_by_rms',
     'scale_by_stddev',
 ]

torchopt/transform/scale_by_rss.py

Lines changed: 127 additions & 0 deletions (new file)

@@ -0,0 +1,127 @@
# Copyright 2022 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# This file is modified from:
# https://github.com/deepmind/optax/blob/master/optax/_src/transform.py
# ==============================================================================
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preset transformations for scaling updates by the root of the sum of all squared gradients."""

from typing import NamedTuple

import torch

from torchopt import pytree
from torchopt.base import GradientTransformation
from torchopt.transform.utils import tree_map_flat
from torchopt.typing import Updates


__all__ = ['scale_by_rss']


class ScaleByRssState(NamedTuple):
    """State holding the sum of gradient squares to date."""

    sum_of_squares: Updates


def scale_by_rss(
    initial_accumulator_value: float = 0.1,
    eps: float = 1e-7,
) -> GradientTransformation:
    """Rescale updates by the root of the sum of all squared gradients to date.

    References:
        [Duchi et al, 2011](https://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
        [McMahan et al., 2010](https://arxiv.org/abs/1002.4908)

    Args:
        initial_accumulator_value: Starting value for accumulators, must be >= 0.
        eps: A small floating point value to avoid zero denominator.

    Returns:
        An (init_fn, update_fn) tuple.
    """
    return _scale_by_rss(
        initial_accumulator_value=initial_accumulator_value,
        eps=eps,
        already_flattened=False,
    )


def _scale_by_rss_flat(
    initial_accumulator_value: float = 0.1,
    eps: float = 1e-7,
) -> GradientTransformation:
    return _scale_by_rss(
        initial_accumulator_value=initial_accumulator_value,
        eps=eps,
        already_flattened=True,
    )


def _scale_by_rss(
    initial_accumulator_value: float = 0.1,
    eps: float = 1e-7,
    *,
    already_flattened: bool = False,
) -> GradientTransformation:
    if already_flattened:
        tree_map = tree_map_flat
    else:
        tree_map = pytree.tree_map

    def init_fn(params):
        sum_of_squares = tree_map(lambda t: torch.full_like(t, initial_accumulator_value), params)
        return ScaleByRssState(sum_of_squares=sum_of_squares)

    def update_fn(updates, state, params=None, inplace=True):  # pylint: disable=unused-argument
        sum_of_squares = tree_map(
            lambda g, t: (g.conj() * g).real + t, updates, state.sum_of_squares
        )

        if inplace:

            def f(t):
                return t.add_(eps).rsqrt_() if t > 0.0 else 0.0

        else:

            def f(t):
                return t.add(eps).rsqrt() if t > 0.0 else 0.0

        inv_sqrt_g_square = tree_map(f, sum_of_squares)
        updates = tree_map(lambda scale, g: scale * g, inv_sqrt_g_square, updates)
        return updates, ScaleByRssState(sum_of_squares=sum_of_squares)

    return GradientTransformation(init_fn, update_fn)


scale_by_rss.flat = _scale_by_rss_flat  # type: ignore[attr-defined]
scale_by_rss.impl = _scale_by_rss  # type: ignore[attr-defined]
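
To make the transform's bookkeeping concrete, here is a minimal, hypothetical sketch of a single scale_by_rss step on one scalar (0-dimensional) parameter; the gradient value and accumulator settings are arbitrary and chosen so the arithmetic is easy to check by hand.

import torch

from torchopt.transform import scale_by_rss

transform = scale_by_rss(initial_accumulator_value=0.0, eps=1e-10)

# A single scalar parameter keeps the example minimal.
params = (torch.tensor(0.5),)
state = transform.init(params)     # ScaleByRssState(sum_of_squares=(tensor(0.),))

grads = (torch.tensor(3.0),)
updates, state = transform.update(grads, state, params=params, inplace=False)

# After one step the accumulator equals g**2 = 9, so the rescaled update is
# g / sqrt(9 + eps), i.e. approximately 1.0.
print(updates[0].item())               # ~1.0
print(state.sum_of_squares[0].item())  # 9.0

In the adagrad alias above, this transform is chained between the weight-decay/sign step and scale_by_neg_lr, so these rescaled values are what ultimately get multiplied by the (negative) learning rate.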
