Skip to content

feat(hook): add nan_to_num hooks #119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Nov 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/set_release.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

VERSION_FILE = ROOT / 'torchopt' / 'version.py'

VERSION_CONTENT = VERSION_FILE.read_text()
VERSION_CONTENT = VERSION_FILE.read_text(encoding='UTF-8')

VERSION_FILE.write_text(
data=re.sub(
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add `nan_to_num` hook and gradient transformation by [@XuehaiPan](https://github.com/XuehaiPan) in [#119](https://github.com/metaopt/torchopt/pull/119).
- Add matrix inversion linear solver with neumann series approximation by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#98](https://github.com/metaopt/torchopt/pull/98).
- Add if condition of number of threads for CPU OPs by [@JieRen98](https://github.com/JieRen98) in [#105](https://github.com/metaopt/torchopt/pull/105).
- Add implicit MAML omniglot few-shot classification example with OOP APIs by [@XuehaiPan](https://github.com/XuehaiPan) in [#107](https://github.com/metaopt/torchopt/pull/107).
Expand Down
6 changes: 5 additions & 1 deletion docs/source/api/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -186,28 +186,32 @@ Optimizer Hooks

register_hook
zero_nan_hook
nan_to_num_hook

Hook
~~~~

.. autofunction:: register_hook
.. autofunction:: zero_nan_hook
.. autofunction:: nan_to_num_hook

------

Gradient Transformation
=======================

.. currentmodule:: torchopt.clip
.. currentmodule:: torchopt

.. autosummary::

clip_grad_norm
nan_to_num

Transforms
~~~~~~~~~~

.. autofunction:: clip_grad_norm
.. autofunction:: nan_to_num

Optimizer Schedules
===================
Expand Down
3 changes: 3 additions & 0 deletions docs/source/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,6 @@ Karush
Kuhn
Tucker
Neumann
num
posinf
neginf
4 changes: 4 additions & 0 deletions torchopt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from torchopt.alias import adam, adamw, rmsprop, sgd
from torchopt.clip import clip_grad_norm
from torchopt.combine import chain
from torchopt.hook import register_hook
from torchopt.optim import SGD, Adam, AdamW, Optimizer, RMSProp, RMSprop, meta
from torchopt.optim.func import FuncOptimizer
from torchopt.optim.meta import (
Expand All @@ -41,6 +42,7 @@
MetaRMSprop,
MetaSGD,
)
from torchopt.transform import nan_to_num
from torchopt.update import apply_updates
from torchopt.utils import (
extract_state_dict,
Expand All @@ -60,6 +62,8 @@
'rmsprop',
'sgd',
'clip_grad_norm',
'nan_to_num',
'register_hook',
'chain',
'Optimizer',
'SGD',
Expand Down
20 changes: 17 additions & 3 deletions torchopt/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,32 @@
# ==============================================================================
"""Hook utilities."""

from typing import Callable, Optional

import torch

from torchopt import pytree
from torchopt.base import EmptyState, GradientTransformation


__all__ = ['zero_nan_hook', 'register_hook']
__all__ = ['zero_nan_hook', 'nan_to_num_hook', 'register_hook']


def zero_nan_hook(g: torch.Tensor) -> torch.Tensor:
    """A zero ``nan`` hook to replace ``nan`` with zero.

    Args:
        g: the gradient tensor to sanitize.

    Returns:
        A new tensor with every ``nan`` entry replaced by ``0.0``.
        NOTE(review): ``Tensor.nan_to_num`` also clamps ``+inf`` / ``-inf`` to the
        dtype's finite extremes (the previous ``torch.where`` form left them
        untouched) — confirm this widening is intended.
    """
    # The span contained diff residue (a dead docstring and an unreachable
    # second return after the first one); this is the merged final version.
    return g.nan_to_num(nan=0.0)


def nan_to_num_hook(
    nan: float = 0.0, posinf: Optional[float] = None, neginf: Optional[float] = None
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Returns a ``nan`` to num hook to replace ``nan`` / ``+inf`` / ``-inf`` with the given numbers.

    Args:
        nan: replacement for ``nan`` entries.
        posinf: replacement for ``+inf``; ``None`` uses the dtype's largest finite value.
        neginf: replacement for ``-inf``; ``None`` uses the dtype's most negative finite value.

    Returns:
        A callable suitable for registration as a gradient hook.
    """

    def replace_nonfinite(g: torch.Tensor) -> torch.Tensor:
        """Replaces ``nan`` / ``+inf`` / ``-inf`` entries of ``g`` out of place."""
        return torch.nan_to_num(g, nan=nan, posinf=posinf, neginf=neginf)

    return replace_nonfinite


def register_hook(hook) -> GradientTransformation:
Expand Down
2 changes: 2 additions & 0 deletions torchopt/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"""Preset transformations."""

from torchopt.transform.add_decayed_weights import add_decayed_weights
from torchopt.transform.nan_to_num import nan_to_num
from torchopt.transform.scale import scale
from torchopt.transform.scale_by_adam import scale_by_accelerated_adam, scale_by_adam
from torchopt.transform.scale_by_rms import scale_by_rms
Expand All @@ -49,4 +50,5 @@
'scale_by_accelerated_adam',
'scale_by_rms',
'scale_by_stddev',
'nan_to_num',
]
49 changes: 49 additions & 0 deletions torchopt/transform/nan_to_num.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2022 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preset transformations that replace non-finite values in the updates with the given numbers."""

from typing import Optional

from torchopt import pytree
from torchopt.base import EmptyState, GradientTransformation


def nan_to_num(
    nan: float = 0.0, posinf: Optional[float] = None, neginf: Optional[float] = None
) -> GradientTransformation:
    """Replaces updates with values ``nan`` / ``+inf`` / ``-inf`` to the given numbers.

    Args:
        nan: replacement for ``nan`` entries.
        posinf: replacement for ``+inf``; ``None`` uses the dtype's largest finite value.
        neginf: replacement for ``-inf``; ``None`` uses the dtype's most negative finite value.

    Returns:
        An ``(init_fn, update_fn)`` tuple.
    """

    def init_fn(params):  # pylint: disable=unused-argument
        # The transformation carries no state.
        return EmptyState()

    def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
        def replace(g):
            # Choose the in-place or out-of-place tensor op per the caller's request.
            op = g.nan_to_num_ if inplace else g.nan_to_num
            return op(nan=nan, posinf=posinf, neginf=neginf)

        return pytree.tree_map(replace, updates), state

    return GradientTransformation(init_fn, update_fn)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy