Commit 7f4b991

feat(hook): add nan_to_num hooks
1 parent f831b53

3 files changed: +53 −3 lines

docs/source/api/api.rst

Lines changed: 4 additions & 0 deletions
@@ -186,12 +186,16 @@ Optimizer Hooks

     register_hook
     zero_nan_hook
+    nan_to_num_hook
+    nan_to_num

 Hook
 ~~~~

 .. autofunction:: register_hook
 .. autofunction:: zero_nan_hook
+.. autofunction:: nan_to_num_hook
+.. autofunction:: nan_to_num

 ------

torchopt/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -31,6 +31,7 @@
 from torchopt.alias import adam, adamw, rmsprop, sgd
 from torchopt.clip import clip_grad_norm
 from torchopt.combine import chain
+from torchopt.hook import nan_to_num, register_hook
 from torchopt.optim import SGD, Adam, AdamW, Optimizer, RMSProp, RMSprop, meta
 from torchopt.optim.func import FuncOptimizer
 from torchopt.optim.meta import (
@@ -60,6 +61,8 @@
     'rmsprop',
     'sgd',
     'clip_grad_norm',
+    'nan_to_num',
+    'register_hook',
     'chain',
     'Optimizer',
     'SGD',
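
With this change, the new transformation is importable from the package root alongside the existing utilities. A minimal usage sketch, assuming the behavior defined in torchopt/hook.py below (the example tensor and replacement values are illustrative only):

import torch
import torchopt

# `nan_to_num` returns a GradientTransformation; its `init` returns an EmptyState.
transform = torchopt.nan_to_num(nan=0.0, posinf=1e8, neginf=-1e8)

updates = {'w': torch.tensor([1.0, float('nan'), float('inf')])}
state = transform.init(updates)
new_updates, state = transform.update(updates, state, inplace=False)
# new_updates['w'] is tensor([1.0, 0.0, 1e8]): nan and inf replaced as configured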

torchopt/hook.py

Lines changed: 46 additions & 3 deletions
@@ -14,18 +14,61 @@
 # ==============================================================================
 """Hook utilities."""

+from typing import Callable, Optional
+
 import torch

 from torchopt import pytree
 from torchopt.base import EmptyState, GradientTransformation


-__all__ = ['zero_nan_hook', 'register_hook']
+__all__ = ['zero_nan_hook', 'nan_to_num_hook', 'nan_to_num', 'register_hook']


 def zero_nan_hook(g: torch.Tensor) -> torch.Tensor:
-    """Registers a zero nan hook to replace nan with zero."""
-    return torch.where(torch.isnan(g), torch.zeros_like(g), g)
+    """A zero ``nan`` hook to replace ``nan`` with zero."""
+    return g.nan_to_num(nan=0.0)
+
+
+def nan_to_num_hook(
+    nan: float = 0.0, posinf: Optional[float] = None, neginf: Optional[float] = None
+) -> Callable[[torch.Tensor], torch.Tensor]:
+    """Returns a ``nan`` to num hook to replace ``nan`` with the given number."""
+
+    def hook(g: torch.Tensor) -> torch.Tensor:
+        """A ``nan`` to num hook to replace ``nan`` with the given number."""
+        return g.nan_to_num(nan=nan, posinf=posinf, neginf=neginf)
+
+    return hook
+
+
+def nan_to_num(
+    nan: float = 0.0, posinf: Optional[float] = None, neginf: Optional[float] = None
+) -> GradientTransformation:
+    """A gradient transformation that replaces gradient values of ``nan`` with the given number.
+
+    Returns:
+        An ``(init_fn, update_fn)`` tuple.
+    """
+
+    def init_fn(params):  # pylint: disable=unused-argument
+        return EmptyState()
+
+    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
+        if inplace:
+
+            def f(g):
+                return g.nan_to_num_(nan=nan, posinf=posinf, neginf=neginf)
+
+        else:
+
+            def f(g):
+                return g.nan_to_num(nan=nan, posinf=posinf, neginf=neginf)
+
+        new_updates = pytree.tree_map(f, updates)
+        return new_updates, state
+
+    return GradientTransformation(init_fn, update_fn)


 def register_hook(hook) -> GradientTransformation:
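
Both entry points compose with the rest of the functional API. A hedged usage sketch, assuming the chain and sgd aliases already re-exported by torchopt/__init__.py and the hook semantics shown above (the learning rate and replacement bounds are arbitrary example values):

import torchopt
from torchopt.hook import nan_to_num_hook

# Sanitize non-finite update values before the usual SGD scaling.
optimizer = torchopt.chain(
    torchopt.nan_to_num(nan=0.0, posinf=1e3, neginf=-1e3),
    torchopt.sgd(lr=0.1),
)

# Alternatively, wrap the per-tensor hook produced by the factory.
hook_transform = torchopt.register_hook(nan_to_num_hook(nan=0.0))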
