
Commit 05e0f3f

feat(hook): add nan_to_num hooks (#119)
1 parent f831b53 commit 05e0f3f

8 files changed: +82 -5 lines changed


.github/workflows/set_release.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@

 VERSION_FILE = ROOT / 'torchopt' / 'version.py'

-VERSION_CONTENT = VERSION_FILE.read_text()
+VERSION_CONTENT = VERSION_FILE.read_text(encoding='UTF-8')

 VERSION_FILE.write_text(
     data=re.sub(
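
Why the one-line change above matters (editor's note, not part of the commit): without an explicit `encoding`, `Path.read_text()` decodes with the platform's locale encoding, which varies across CI runners. A minimal sketch of the contrast:

from pathlib import Path

# Implicit: falls back to locale.getpreferredencoding(), e.g. cp1252 on some
# Windows runners, which can mis-decode a UTF-8 version.py.
content = Path('torchopt/version.py').read_text()

# Explicit: deterministic decoding on every platform, as in the fix above.
content = Path('torchopt/version.py').read_text(encoding='UTF-8')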

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Added

+- Add `nan_to_num` hook and gradient transformation by [@XuehaiPan](https://github.com/XuehaiPan) in [#119](https://github.com/metaopt/torchopt/pull/119).
 - Add matrix inversion linear solver with neumann series approximation by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#98](https://github.com/metaopt/torchopt/pull/98).
 - Add if condition of number of threads for CPU OPs by [@JieRen98](https://github.com/JieRen98) in [#105](https://github.com/metaopt/torchopt/pull/105).
 - Add implicit MAML omniglot few-shot classification example with OOP APIs by [@XuehaiPan](https://github.com/XuehaiPan) in [#107](https://github.com/metaopt/torchopt/pull/107).

docs/source/api/api.rst

Lines changed: 5 additions & 1 deletion
@@ -186,28 +186,32 @@ Optimizer Hooks

     register_hook
     zero_nan_hook
+    nan_to_num_hook

 Hook
 ~~~~

 .. autofunction:: register_hook
 .. autofunction:: zero_nan_hook
+.. autofunction:: nan_to_num_hook

 ------

 Gradient Transformation
 =======================

-.. currentmodule:: torchopt.clip
+.. currentmodule:: torchopt

 .. autosummary::

     clip_grad_norm
+    nan_to_num

 Transforms
 ~~~~~~~~~~

 .. autofunction:: clip_grad_norm
+.. autofunction:: nan_to_num

 Optimizer Schedules
 ===================

docs/source/spelling_wordlist.txt

Lines changed: 3 additions & 0 deletions
@@ -93,3 +93,6 @@ Karush
 Kuhn
 Tucker
 Neumann
+num
+posinf
+neginf

torchopt/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -31,6 +31,7 @@
 from torchopt.alias import adam, adamw, rmsprop, sgd
 from torchopt.clip import clip_grad_norm
 from torchopt.combine import chain
+from torchopt.hook import register_hook
 from torchopt.optim import SGD, Adam, AdamW, Optimizer, RMSProp, RMSprop, meta
 from torchopt.optim.func import FuncOptimizer
 from torchopt.optim.meta import (
@@ -41,6 +42,7 @@
     MetaRMSprop,
     MetaSGD,
 )
+from torchopt.transform import nan_to_num
 from torchopt.update import apply_updates
 from torchopt.utils import (
     extract_state_dict,
@@ -60,6 +62,8 @@
     'rmsprop',
     'sgd',
     'clip_grad_norm',
+    'nan_to_num',
+    'register_hook',
     'chain',
     'Optimizer',
     'SGD',
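
A quick check (editor's sketch, assuming this revision is installed): the two new names resolve from the top-level package, while `nan_to_num_hook` itself still lives in `torchopt.hook`.

import torchopt
from torchopt.hook import nan_to_num_hook  # not re-exported at the top level

# Both re-exports added by this commit resolve without subpackage imports.
transform = torchopt.nan_to_num(posinf=1e3)
hooked = torchopt.register_hook(nan_to_num_hook(posinf=1e3))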

torchopt/hook.py

Lines changed: 17 additions & 3 deletions
@@ -14,18 +14,32 @@
 # ==============================================================================
 """Hook utilities."""

+from typing import Callable, Optional
+
 import torch

 from torchopt import pytree
 from torchopt.base import EmptyState, GradientTransformation


-__all__ = ['zero_nan_hook', 'register_hook']
+__all__ = ['zero_nan_hook', 'nan_to_num_hook', 'register_hook']


 def zero_nan_hook(g: torch.Tensor) -> torch.Tensor:
-    """Registers a zero nan hook to replace nan with zero."""
-    return torch.where(torch.isnan(g), torch.zeros_like(g), g)
+    """A zero ``nan`` hook to replace ``nan`` with zero."""
+    return g.nan_to_num(nan=0.0)
+
+
+def nan_to_num_hook(
+    nan: float = 0.0, posinf: Optional[float] = None, neginf: Optional[float] = None
+) -> Callable[[torch.Tensor], torch.Tensor]:
+    """Returns a ``nan`` to num hook to replace ``nan`` / ``+inf`` / ``-inf`` with the given numbers."""
+
+    def hook(g: torch.Tensor) -> torch.Tensor:
+        """A hook to replace ``nan`` / ``+inf`` / ``-inf`` with the given numbers."""
+        return g.nan_to_num(nan=nan, posinf=posinf, neginf=neginf)
+
+    return hook


 def register_hook(hook) -> GradientTransformation:
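
A hedged usage sketch (editor's addition, not from the diff): `nan_to_num_hook` is a factory returning a per-tensor hook, and `register_hook` registers that hook for gradient tensors as a chainable gradient transformation. The replacement values and optimizer below are illustrative.

import torch
import torchopt
from torchopt.hook import nan_to_num_hook, register_hook

# Build a hook mapping nan -> 0.0, +inf -> 1e6, -inf -> -1e6.
hook = nan_to_num_hook(nan=0.0, posinf=1e6, neginf=-1e6)

# The hook acts on a single gradient tensor.
g = torch.tensor([float('nan'), float('inf'), 1.0])
print(hook(g))  # nan -> 0.0, +inf -> 1e6; finite values pass through

# register_hook lifts it into a GradientTransformation, chainable with an optimizer.
impl = torchopt.chain(register_hook(hook), torchopt.sgd(lr=0.1))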

torchopt/transform/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -32,6 +32,7 @@
 """Preset transformations."""

 from torchopt.transform.add_decayed_weights import add_decayed_weights
+from torchopt.transform.nan_to_num import nan_to_num
 from torchopt.transform.scale import scale
 from torchopt.transform.scale_by_adam import scale_by_accelerated_adam, scale_by_adam
 from torchopt.transform.scale_by_rms import scale_by_rms
@@ -49,4 +50,5 @@
     'scale_by_accelerated_adam',
     'scale_by_rms',
     'scale_by_stddev',
+    'nan_to_num',
 ]

torchopt/transform/nan_to_num.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Preset transformations that replace non-finite update values with the given numbers."""
+
+from typing import Optional
+
+from torchopt import pytree
+from torchopt.base import EmptyState, GradientTransformation
+
+
+def nan_to_num(
+    nan: float = 0.0, posinf: Optional[float] = None, neginf: Optional[float] = None
+) -> GradientTransformation:
+    """Replaces ``nan`` / ``+inf`` / ``-inf`` values in updates with the given numbers.
+
+    Returns:
+        An ``(init_fn, update_fn)`` tuple.
+    """
+
+    def init_fn(params):  # pylint: disable=unused-argument
+        return EmptyState()
+
+    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
+        if inplace:
+
+            def f(g):
+                return g.nan_to_num_(nan=nan, posinf=posinf, neginf=neginf)
+
+        else:
+
+            def f(g):
+                return g.nan_to_num(nan=nan, posinf=posinf, neginf=neginf)
+
+        new_updates = pytree.tree_map(f, updates)
+        return new_updates, state
+
+    return GradientTransformation(init_fn, update_fn)
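
An end-to-end sketch of the new transformation in the functional API (editor's example, assuming this revision is installed; the parameter values and learning rate are illustrative):

import torch
import torchopt

params = {'w': torch.tensor(1.0)}
grads = {'w': torch.tensor(float('nan'))}  # a non-finite gradient

# Sanitize non-finite entries before the SGD step.
impl = torchopt.chain(torchopt.nan_to_num(nan=0.0), torchopt.sgd(lr=0.1))
state = impl.init(params)

updates, state = impl.update(grads, state, params=params, inplace=False)
params = torchopt.apply_updates(params, updates, inplace=False)
print(params['w'])  # w stays at 1.0: the nan gradient was zeroed before the step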
