@@ -78,7 +78,7 @@ def inc_count(updates: Updates, count: TensorTree) -> TensorTree:
     """Increments int counter by one.

     Returns:
-        A counter incremeted by one, or max_int if the maximum precision is reached.
+        A counter incremented by one, or max_int if the maximum precision is reached.
     """
     return _inc_count(updates=updates, count=count, already_flattened=False)
@@ -265,7 +265,7 @@ def scale_by_adam(
             Term added to the denominator inside the square-root to improve
             numerical stability when back-propagating gradients through the rescaling.
         moment_requires_grad: (default: :data:`False`)
-            if :data:`True`, states will be created with flag `requires_grad = True`.
+            If :data:`True`, states will be created with flag `requires_grad = True`.

     Returns:
         An (init_fn, update_fn) tuple.
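
For context (not shown in this diff): the ``moment_requires_grad`` flag documented above is what lets the Adam moments participate in autograd. A minimal sketch, in which the import path and the ``b1``/``b2`` keyword names are assumptions (only ``eps`` and ``moment_requires_grad`` appear in this hunk)::

    import torchopt

    # Assumed entry point; the diff only shows the function definition.
    adam_scaler = torchopt.transform.scale_by_adam(
        b1=0.9,
        b2=0.999,
        eps=1e-8,
        moment_requires_grad=True,
    )
    # The first/second moment states are created with requires_grad=True, so
    # gradients can flow through the optimizer state in differentiable
    # (meta-)optimization setups.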
@@ -367,7 +367,7 @@ def scale_by_accelerated_adam(
             Term added to the denominator inside the square-root to improve
             numerical stability when back-propagating gradients through the rescaling.
         moment_requires_grad: (default: :data:`False`)
-            if :data:`True`, states will be created with flag `requires_grad = True`.
+            If :data:`True`, states will be created with flag `requires_grad = True`.

     Returns:
         An (init_fn, update_fn) tuple.
@@ -474,7 +474,7 @@ def trace(
         nesterov: (default: :data:`False`)
             Whether to use Nesterov momentum.
         moment_requires_grad: (default: :data:`False`)
-            if :data:`True`, states will be created with flag `requires_grad = True`.
+            If :data:`True`, states will be created with flag `requires_grad = True`.

     Returns:
         An (init_fn, update_fn) tuple.
@@ -597,7 +597,7 @@ def scale_by_rms(
         eps: (default: :const:`1e-8`)
             Term added to the denominator to improve numerical stability.
         initial_scale: (default: :const:`0.0`)
-            Initial value for second moment
+            Initial value for second moment.

     Returns:
         An (init_fn, update_fn) tuple.
@@ -675,7 +675,7 @@ def scale_by_stddev(
         eps: (default: :const:`1e-8`)
             Term added to the denominator to improve numerical stability.
         initial_scale: (default: :const:`0.0`)
-            Initial value for second moment
+            Initial value for second moment.

     Returns:
         An (init_fn, update_fn) tuple.
@@ -745,9 +745,8 @@ class MaskedState(NamedTuple):
 class MaskedNode(NamedTuple):
     """A node used to mask out unspecified parts of a tree.

-    This node is ignored when mapping functions across the tree e.g. using
-    :func:`pytree.tree_map` since it is a container without children. It can
-    therefore be used to mask out parts of a tree.
+    This node is ignored when mapping functions across the tree e.g. using :func:`pytree.tree_map`
+    since it is a container without children. It can therefore be used to mask out parts of a tree.
     """


@@ -757,28 +756,27 @@ def masked(
 ) -> GradientTransformation:
     """Mask updates so only some are transformed, the rest are passed through.

-    For example, it is common to skip weight decay for BatchNorm scale and all
-    bias parameters. In many networks, these are the only parameters with only
-    one dimension. So, you may create a mask function to mask these out as
-    follows::
-        mask_fn = lambda p: pytree.tree_map(lambda x: x.ndim != 1, p)
-        weight_decay = torchopt.masked(torchopt.add_decayed_weights(0.001), mask_fn)
+    For example, it is common to skip weight decay for BatchNorm scale and all bias parameters. In
+    many networks, these are the only parameters with only one dimension. So, you may create a mask
+    function to mask these out as follows::
+        mask_fn = lambda p: pytree.tree_map(lambda x: x.ndim != 1, p)
+        weight_decay = torchopt.masked(torchopt.add_decayed_weights(0.001), mask_fn)
     You may alternatively create the mask pytree upfront::
-        mask = pytree.tree_map(lambda x: x.ndim != 1, params)
-        weight_decay = torchopt.masked(torchopt.add_decayed_weights(0.001), mask)
+        mask = pytree.tree_map(lambda x: x.ndim != 1, params)
+        weight_decay = torchopt.masked(torchopt.add_decayed_weights(0.001), mask)
     For the ``inner`` transform, state will only be stored for the parameters that
-    have a mask value of ``True``.
+    have a mask value of :data:`True`.

     Args:
-        inner: Inner transformation to mask.
-        mask: a PyTree with same structure as (or a prefix of) the params PyTree, or
-            a Callable that returns such a pytree given the params/updates. The leaves
-            should be booleans, ``True`` for leaves/subtrees you want to apply the
-            transformation to, and ``False`` for those you want to skip. The mask must
-            be static for the gradient transformation to be jit-compilable.
+        inner: Inner transformation to mask.
+        mask: A PyTree with same structure as (or a prefix of) the params pytree, or a Callable that
+            returns such a pytree given the params/updates. The leaves should be booleans,
+            :data:`True` for leaves/subtrees you want to apply the transformation to, and
+            :data:`False` for those you want to skip. The mask must be static for the gradient
+            transformation to be jit-compilable.

     Returns:
-        New GradientTransformation wrapping ``inner``.
+        A new :class:`GradientTransformation` wrapping ``inner``.
     """
     return _masked(
         inner=inner,
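
For context (not part of this diff), a minimal end-to-end sketch of the masking pattern described in the docstring above. It reuses the docstring's own ``torchopt.masked`` / ``torchopt.add_decayed_weights`` example; the parameter dict and the ``from torchopt import pytree`` import path are assumptions::

    import torch
    import torchopt
    from torchopt import pytree  # assumed import path for the pytree helpers

    params = {'weight': torch.randn(3, 4), 'bias': torch.zeros(4)}
    # Decay only parameters with more than one dimension, i.e. skip the bias.
    mask_fn = lambda p: pytree.tree_map(lambda x: x.ndim != 1, p)
    weight_decay = torchopt.masked(torchopt.add_decayed_weights(1e-3), mask_fn)
    state = weight_decay.init(params)
    # Per the docstring, inner state is kept only for leaves whose mask value is
    # True; the masked-out 'bias' leaf passes through unchanged.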
@@ -831,17 +829,17 @@ def add_decayed_weights(
     weight_decay: float = 0.0,
     mask: Optional[Union[Any, Callable[[Params], Any]]] = None,
 ) -> GradientTransformation:
-    """Add parameter scaled by `weight_decay`.
+    """Add parameter scaled by ``weight_decay``.

     Args:
-        weight_decay: a scalar weight decay rate.
-        mask: a tree with same structure as (or a prefix of) the params PyTree,
-            or a Callable that returns such a pytree given the params/updates.
-            The leaves should be booleans, `True` for leaves/subtrees you want to
-            apply the transformation to, and `False` for those you want to skip.
+        weight_decay: A scalar weight decay rate.
+        mask: A tree with same structure as (or a prefix of) the params pytree, or a Callable that
+            returns such a pytree given the params/updates. The leaves should be booleans,
+            :data:`True` for leaves/subtrees you want to apply the transformation to, and
+            :data:`False` for those you want to skip.

     Returns:
-        An (init_fn, update_fn) tuple.
+        An (init_fn, update_fn) tuple.
     """
     return _add_decayed_weights(
         weight_decay=weight_decay,
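
For context (not part of this diff), what ``add_decayed_weights`` does per leaf is plain decoupled weight decay; a conceptual sketch with a made-up helper name::

    import torch

    def add_decayed_weight(update: torch.Tensor, param: torch.Tensor, weight_decay: float) -> torch.Tensor:
        # Conceptual per-leaf behaviour; the real transformation maps this over
        # whole pytrees and also supports in-place updates.
        return update + weight_decay * param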
@@ -902,3 +900,74 @@ def f(g, p):
         already_flattened=already_flattened,
     )
     return GradientTransformation(init_fn, update_fn)
+
+
+class ScaleByRssState(NamedTuple):
+    """State holding the sum of gradient squares to date."""
+
+    sum_of_squares: Updates
+
+
+def scale_by_rss(
+    initial_accumulator_value: float = 0.1,
+    eps: float = 1e-7,
+) -> GradientTransformation:
+    """Rescale updates by the root of the sum of all squared gradients to date.
+
+    References:
+        [Duchi et al, 2011](https://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
+        [McMahan et al., 2010](https://arxiv.org/abs/1002.4908)
+
+    Args:
+        initial_accumulator_value: Starting value for accumulators, must be >= 0.
+        eps: A small floating point value to avoid zero denominator.
+
+    Returns:
+        An (init_fn, update_fn) tuple.
+    """
+    return _scale_by_rss(
+        initial_accumulator_value=initial_accumulator_value,
+        eps=eps,
+        already_flattened=False,
+    )
+
+
+def _scale_by_rss(
+    initial_accumulator_value: float = 0.1,
+    eps: float = 1e-7,
+    *,
+    already_flattened: bool = False,
+) -> GradientTransformation:
+
+    if already_flattened:
+        tree_map = map_flattened
+    else:
+        tree_map = pytree.tree_map
+
+    def init_fn(params):
+        sum_of_squares = tree_map(lambda t: torch.full_like(t, initial_accumulator_value), params)
+        return ScaleByRssState(sum_of_squares=sum_of_squares)
+
+    def update_fn(updates, state, params=None, inplace=True):  # pylint: disable=unused-argument
+        del params
+        sum_of_squares = tree_map(
+            lambda g, t: (g.conj() * g).real + t, updates, state.sum_of_squares
+        )
+        # inv_sqrt_g_square = tree_map(
+        #     lambda t: jnp.where(t > 0, jax.lax.rsqrt(t + eps), 0.0), sum_of_squares
+        # )
+        if inplace:
+
+            def f(t):
+                return t.add_(eps).rsqrt_() if t > 0.0 else 0.0
+
+        else:
+
+            def f(t):
+                return t.add(eps).rsqrt() if t > 0.0 else 0.0
+
+        inv_sqrt_g_square = tree_map(f, sum_of_squares)
+        updates = tree_map(lambda scale, g: scale * g, inv_sqrt_g_square, updates)
+        return updates, ScaleByRssState(sum_of_squares=sum_of_squares)
+
+    return GradientTransformation(init_fn, update_fn)
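
For context (not part of this diff), a rough usage sketch of the new ``scale_by_rss`` transformation; the import path is an assumption, since the diff only shows the definitions::

    import torch
    import torchopt

    params = [torch.randn(2, 3)]
    # AdaGrad-style preconditioner: accumulates squared gradients and rescales
    # each update by 1 / sqrt(sum of squared gradients to date + eps).
    rss = torchopt.transform.scale_by_rss(initial_accumulator_value=0.1, eps=1e-7)
    state = rss.init(params)  # sum_of_squares starts at 0.1 for every element

Chaining this with a negative learning-rate scaling would yield an AdaGrad-like optimizer.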