
Commit 02da719

feat: meta-gradient module (#101)

1 parent 9def086 · commit 02da719

27 files changed: +1,083 −182 lines

CHANGELOG.md

Lines changed: 1 addition & 0 deletions

@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ### Added
 
+- Add object-oriented modules support for implicit meta-gradient by [@XuehaiPan](https://github.com/XuehaiPan) in [#101](https://github.com/metaopt/torchopt/pull/101).
 - Bump PyTorch version to 1.13.0 by [@XuehaiPan](https://github.com/XuehaiPan) in [#104](https://github.com/metaopt/torchopt/pull/104).
 - Add zero-order gradient estimation by [@JieRen98](https://github.com/JieRen98) in [#93](https://github.com/metaopt/torchopt/pull/93).
 - Add RPC-based distributed training support and add distributed MAML example by [@XuehaiPan](https://github.com/XuehaiPan) in [#83](https://github.com/metaopt/torchopt/pull/83).

conda-recipe.yaml

Lines changed: 2 additions & 1 deletion

@@ -54,10 +54,11 @@ dependencies:
   - gxx = 10
   - nvidia/label/cuda-11.7.1::cuda-nvcc
   - nvidia/label/cuda-11.7.1::cuda-cudart-dev
-  - patchelf >= 0.9
+  - patchelf >= 0.14
   - pybind11
 
   # Misc
+  - optree >= 0.3.0
   - typing-extensions >= 4.0.0
   - numpy
   - matplotlib-base

docs/conda-recipe.yaml

Lines changed: 2 additions & 0 deletions

@@ -33,6 +33,7 @@ dependencies:
   # Learning
   - pytorch::pytorch >= 1.13  # sync with project.dependencies
   - pytorch::cpuonly
+  - pytorch::pytorch-mutex = *=*cpu*
   - pip:
     - torchviz
     - sphinxcontrib-katex  # for documentation
@@ -47,6 +48,7 @@ dependencies:
   - pybind11
 
   # Misc
+  - optree >= 0.3.0
   - typing-extensions >= 4.0.0
   - numpy
   - matplotlib-base

docs/source/api/api.rst

Lines changed: 10 additions & 0 deletions

@@ -139,12 +139,22 @@ Implicit differentiation
 .. autosummary::
 
     custom_root
+    nn.ImplicitMetaGradientModule
 
 Custom solvers
 ~~~~~~~~~~~~~~
 
 .. autofunction:: custom_root
 
+
+Implicit Meta-Gradient Module
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. currentmodule:: torchopt.diff.implicit.nn
+
+.. autoclass:: ImplicitMetaGradientModule
+    :members:
+
 ------
 
 Linear system solving

docs/source/conf.py

Lines changed: 21 additions & 1 deletion

@@ -25,6 +25,7 @@
 # add these directories to sys.path here. If the directory is relative to the
 # documentation root, use os.path.abspath to make it absolute, like shown here.
 #
+import logging
 import os
 import pathlib
 import sys
@@ -43,6 +44,24 @@ def get_version() -> str:
     return version.__version__
 
 
+try:
+    import sphinx_autodoc_typehints
+except ImportError:
+    pass
+else:
+
+    class RecursiveForwardRefFilter(logging.Filter):
+        def filter(self, record):
+            if (
+                "name 'TensorTree' is not defined" in record.getMessage()
+                or "name 'OptionalTensorTree' is not defined" in record.getMessage()
+            ):
+                return False
+            return super().filter(record)
+
+    sphinx_autodoc_typehints._LOGGER.logger.addFilter(RecursiveForwardRefFilter())
+
+
 # -- Project information -----------------------------------------------------
 
 project = 'TorchOpt'
@@ -75,7 +94,7 @@ def get_version() -> str:
     'sphinxcontrib.bibtex',
     'sphinxcontrib.katex',
     'sphinx_autodoc_typehints',
-    'myst_nb',  # This is used for the .ipynb notebooks
+    'myst_nb',  # this is used for the .ipynb notebooks
 ]
 
 if not os.getenv('READTHEDOCS', None):
@@ -120,6 +139,7 @@ def get_version() -> str:
     'exclude-members': '__module__, __dict__, __repr__, __str__, __weakref__',
 }
 autoclass_content = 'both'
+simplify_optional_unions = False
 
 # -- Options for bibtex -----------------------------------------------------
docs/source/spelling_wordlist.txt

Lines changed: 4 additions & 0 deletions

@@ -88,3 +88,7 @@ deepcopy
 deepclone
 RRef
 rref
+ints
+Karush
+Kuhn
+Tucker

torchopt/alias/adamw.py

Lines changed: 2 additions & 3 deletions

@@ -36,8 +36,7 @@
 from torchopt.alias.utils import flip_sign_and_add_weight_decay, scale_by_neg_lr
 from torchopt.combine import chain_flat
 from torchopt.transform import add_decayed_weights, scale_by_accelerated_adam, scale_by_adam
-from torchopt.typing import Params  # pylint: disable=unused-import
-from torchopt.typing import GradientTransformation, ScalarOrSchedule
+from torchopt.typing import GradientTransformation, Params, ScalarOrSchedule
 
 
 __all__ = ['adamw']
@@ -51,7 +50,7 @@ def adamw(
     weight_decay: float = 1e-2,
     *,
     eps_root: float = 0.0,
-    mask: Optional[Union[Any, Callable[['Params'], Any]]] = None,
+    mask: Optional[Union[Any, Callable[[Params], Any]]] = None,
     moment_requires_grad: bool = False,
     maximize: bool = False,
     use_accelerated_op: bool = False,
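
The `mask` argument whose annotation is cleaned up here follows the usual masked-weight-decay convention: a pytree of booleans matching the parameters, or a callable producing one from them, with `True` marking the leaves that receive weight decay. A minimal sketch of both forms; the parameter dict and values are hypothetical, not from the commit:

```python
# Sketch: selectively applying weight decay via `mask` in torchopt.adamw.
import torch
import torchopt

# Hypothetical two-parameter "model" as a pytree of tensors.
params = {
    'weight': torch.randn(4, 4, requires_grad=True),
    'bias': torch.zeros(4, requires_grad=True),
}

# Decay the weight matrix but not the bias. A callable of the shape
# `Callable[[Params], Any]` (the annotation fixed above) works too:
#     mask=lambda params: {name: name == 'weight' for name in params}
optimizer = torchopt.adamw(
    lr=1e-3,
    weight_decay=1e-2,
    mask={'weight': True, 'bias': False},
)
opt_state = optimizer.init(params)  # functional API: init, then update per step
```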

torchopt/diff/implicit/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -14,4 +14,9 @@
 # ==============================================================================
 """Implicit Meta-Gradient."""
 
+from torchopt.diff.implicit import nn
 from torchopt.diff.implicit.decorator import custom_root
+from torchopt.diff.implicit.nn import ImplicitMetaGradientModule
+
+
+__all__ = ['custom_root', 'ImplicitMetaGradientModule']
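
With this export in place, the object-oriented interface added by #101 is importable from `torchopt.diff.implicit`. A minimal sketch of the intended usage pattern, under stated assumptions: subclasses define `solve()` (the inner-loop solver) plus either `objective()` or `optimality()`, and gradients with respect to the meta-parameters flow through the solution via the implicit function theorem. The toy inner problem and every name other than `ImplicitMetaGradientModule` are ours, and the exact meta-parameter registration rules are as documented on the class, not shown in this diff:

```python
# Sketch (not from the commit) of the new OO implicit meta-gradient API.
# Assumption: plain requires-grad tensors assigned in __init__ act as
# meta-parameters, while torch.nn.Parameter attributes are inner parameters.
import torch
import torchopt
from torchopt.diff.implicit import ImplicitMetaGradientModule


class InnerNet(ImplicitMetaGradientModule):
    def __init__(self, meta_param: torch.Tensor) -> None:
        super().__init__()
        self.meta_param = meta_param  # meta-parameter (outer loop)
        self.x = torch.nn.Parameter(torch.zeros_like(meta_param))  # inner parameter

    def objective(self) -> torch.Tensor:
        # Inner objective; its minimizer x*(meta_param) is the fixed point
        # that implicit differentiation differentiates through.
        return torch.sum((self.x - self.meta_param) ** 2)

    def solve(self) -> 'InnerNet':
        # Inner solver: plain gradient descent. No unrolled differentiable
        # graph is needed -- only the final solution matters.
        params = tuple(self.parameters())
        optimizer = torch.optim.SGD(params, lr=0.1)
        with torch.enable_grad():
            for _ in range(100):
                loss = self.objective()
                optimizer.zero_grad()
                loss.backward(inputs=params)
                optimizer.step()
        return self


meta_param = torch.tensor([1.0, 2.0], requires_grad=True)
net = InnerNet(meta_param).solve()
outer_loss = net.x.sum()  # any scalar of the inner solution
outer_loss.backward()     # d(outer_loss)/d(meta_param) via the implicit function theorem
print(meta_param.grad)    # analytically [1., 1.] for this toy problem, since x* = meta_param
```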
