|
14 | 14 | # ==============================================================================
|
15 | 15 | """Hook utilities."""
|
16 | 16 |
|
| 17 | +from typing import Callable, Optional |
| 18 | + |
17 | 19 | import torch
|
18 | 20 |
|
19 | 21 | from torchopt import pytree
|
20 | 22 | from torchopt.base import EmptyState, GradientTransformation
|
21 | 23 |
|
22 | 24 |
|
23 | | -__all__ = ['zero_nan_hook', 'register_hook'] |
| 25 | +__all__ = ['zero_nan_hook', 'nan_to_num_hook', 'nan_to_num', 'register_hook'] |
24 | 26 |
|
25 | 27 |
|
26 | 28 | def zero_nan_hook(g: torch.Tensor) -> torch.Tensor:
|
27 | | - """Registers a zero nan hook to replace nan with zero.""" |
28 | | - return torch.where(torch.isnan(g), torch.zeros_like(g), g) |
| 29 | + """A zero ``nan`` hook to replace ``nan`` with zero.""" |
| 30 | + return g.nan_to_num(nan=0.0) |
| 31 | + |
| 32 | + |
| 33 | +def nan_to_num_hook( |
| 34 | + nan: float = 0.0, posinf: Optional[float] = None, neginf: Optional[float] = None |
| 35 | +) -> Callable[[torch.Tensor], torch.Tensor]: |
| 36 | + """Returns a ``nan`` to num hook to replace ``nan`` with given number.""" |
| 37 | + |
| 38 | + def hook(g: torch.Tensor) -> torch.Tensor: |
| 39 | + """A zero ``nan`` hook to replace ``nan`` with given number.""" |
| 40 | + return g.nan_to_num(nan=nan, posinf=posinf, neginf=neginf) |
| 41 | + |
| 42 | + return hook |
| 43 | + |
| 44 | + |
| 45 | +def nan_to_num( |
| 46 | + nan: float = 0.0, posinf: Optional[float] = None, neginf: Optional[float] = None |
| 47 | +) -> GradientTransformation: |
| 48 | + """A gradient transformation that replaces gradient values of ``nan`` with given number. |
| 49 | +
|
| 50 | + Returns: |
| 51 | + An ``(init_fn, update_fn)`` tuple. |
| 52 | + """ |
| 53 | + |
| 54 | + def init_fn(params): # pylint: disable=unused-argument |
| 55 | + return EmptyState() |
| 56 | + |
| 57 | + def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument |
| 58 | + if inplace: |
| 59 | + |
| 60 | + def f(g): |
| 61 | + return g.nan_to_num_(nan=nan, posinf=posinf, neginf=neginf) |
| 62 | + |
| 63 | + else: |
| 64 | + |
| 65 | + def f(g): |
| 66 | + return g.nan_to_num(nan=nan, posinf=posinf, neginf=neginf) |
| 67 | + |
| 68 | + new_updates = pytree.tree_map(f, updates) |
| 69 | + return new_updates, state |
| 70 | + |
| 71 | + return GradientTransformation(init_fn, update_fn) |
29 | 72 |
|
30 | 73 |
|
31 | 74 | def register_hook(hook) -> GradientTransformation:
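
A minimal usage sketch of the new ``nan_to_num`` transformation added in this diff. The module path ``torchopt.hook``, the ``(init_fn, update_fn)`` pair, and the ``EmptyState`` behavior are taken from the code above; the example data and printed output are illustrative only.

```python
import torch

from torchopt.hook import nan_to_num  # added in this diff

# Build the transformation and initialize its (empty) state.
tx = nan_to_num(nan=0.0, posinf=1e3, neginf=-1e3)
state = tx.init(None)  # init_fn ignores params and returns EmptyState()

# Sanitize a pytree of gradients out of place.
grads = [torch.tensor([float('nan'), float('inf'), 1.0])]
new_grads, state = tx.update(grads, state, inplace=False)
print(new_grads[0])  # tensor([   0., 1000.,    1.])
```

With ``inplace=True`` (the default) the same gradients are rewritten in place via ``nan_to_num_``, which avoids allocating new tensors.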
|
|