Skip to content

Commit e1bef1f

Browse files
committed
Merge branch 'main' into fix-implicit-oop
2 parents 5591aa9 + 3e7c857 commit e1bef1f

File tree

6 files changed

+18
-18
lines changed

6 files changed

+18
-18
lines changed

Makefile

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@ COMMIT_HASH = $(shell git log -1 --format=%h)
1010
PATH := $(HOME)/go/bin:$(PATH)
1111
PYTHON ?= $(shell command -v python3 || command -v python)
1212
CLANG_FORMAT ?= $(shell command -v clang-format-14 || command -v clang-format)
13+
PYTESTOPTS ?=
1314

1415
.PHONY: default
1516
default: install
@@ -104,7 +105,7 @@ pytest: pytest-install
104105
cd tests && \
105106
$(PYTHON) -m pytest --verbose --color=yes --durations=0 \
106107
--cov="$(PROJECT_NAME)" --cov-report=xml --cov-report=term-missing \
107-
.
108+
$(PYTESTOPTS) .
108109

109110
test: pytest
110111

README.md

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -212,7 +212,7 @@ def stationary(params, meta_params, data):
212212
return stationary condition
213213

214214
# Decorator for wrapping the function
215-
# and specify the linear solver (conjugate gradient or Neumann series)
215+
# Optionally specify the linear solver (conjugate gradient or Neumann series)
216216
@torchopt.diff.implicit.custom_root(stationary, solve=linear_solver)
217217
def solve(params, meta_params, data):
218218
# Forward optimization process for params
@@ -233,29 +233,29 @@ Users need to define the stationary condition/objective function and the inner-l
233233

234234
```python
235235
# Inherited from the class ImplicitMetaGradientModule
236-
# and specify the linear solver (conjugate gradient or Neumann series)
236+
# Optionally specify the linear solver (conjugate gradient or Neumann series)
237237
class InnerNet(ImplicitMetaGradientModule, linear_solver):
238238
def __init__(self, meta_param):
239239
super().__init__()
240240
self.meta_param = meta_param
241241
...
242242

243-
def forward(self, data):
243+
def forward(self, batch):
244244
# Forward process
245245
...
246246

247-
def optimality(self, data):
247+
def optimality(self, batch, labels):
248248
# Stationary condition construction for calculating implicit gradient
249249
# NOTE: If this method is not implemented, it will be automatically
250250
# derived from the gradient of the `objective` function.
251251
...
252252

253-
def objective(self, data):
253+
def objective(self, batch, labels):
254254
# Define the inner-loop optimization objective
255255
...
256256

257-
def solve(self, data):
258-
# conduct the inner-loop optimization
257+
def solve(self, batch, labels):
258+
# Conduct the inner-loop optimization
259259
...
260260

261261
# Get meta_params and data
@@ -282,12 +282,12 @@ Refer to the tutorial notebook [Zero-order Differentiation](tutorials/6_Zero_Ord
282282

283283
```python
284284
# Customize the noise sampling function in ES
285-
def sample(params, batch, labels, *, sample_shape):
285+
def sample(sample_shape):
286286
...
287287
return sample_noise
288288

289289
# Specify method and hyper-parameter of ES
290-
@torchopt.diff.zero_order(method, sample)
290+
@torchopt.diff.zero_order(sample, method)
291291
def forward(params, batch, labels):
292292
# forward process
293293
return output

conda-recipe.yaml

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,6 @@ dependencies:
4242
- optax # for tutorials
4343
- jaxopt # for tests
4444
- tensorboard # for examples
45-
- wandb
4645

4746
# Device select
4847
- nvidia/label/cuda-11.7.1::cuda-toolkit = 11.7
@@ -68,13 +67,13 @@ dependencies:
6867
- setproctitle
6968

7069
# Documentation
71-
- sphinx
70+
- sphinx >= 5.2.1
7271
- sphinx_rtd_theme
7372
- sphinx-autobuild
7473
- sphinx-copybutton
7574
- sphinxcontrib-spelling
7675
- sphinxcontrib-bibtex
77-
- sphinx-autodoc-typehints
76+
- sphinx-autodoc-typehints >= 1.19.2
7877
- pyenchant
7978
- hunspell-en
8079
- myst-nb

docs/conda-recipe.yaml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -57,13 +57,13 @@ dependencies:
5757
- pillow
5858

5959
# Documentation
60-
- sphinx
60+
- sphinx >= 5.2.1
6161
- sphinx_rtd_theme
6262
- sphinx-autobuild
6363
- sphinx-copybutton
6464
- sphinxcontrib-spelling
6565
- sphinxcontrib-bibtex
66-
- sphinx-autodoc-typehints
66+
- sphinx-autodoc-typehints >= 1.19.2
6767
- pyenchant
6868
- hunspell-en
6969
- myst-nb

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -154,7 +154,7 @@ def filter(self, record):
154154

155155
# See: https://sphinxcontrib-katex.readthedocs.io/en/0.4.1/macros.html
156156
latex_macros = r"""
157-
\def \d #1{\operatorname{#1}}
157+
\def \d #1{\operatorname{#1}}
158158
"""
159159

160160
# Translate LaTeX macros to KaTeX and add to options for HTML builder

src/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,10 +23,10 @@ endif()
2323

2424
list(APPEND torchopt_csrc "${adam_op_src}")
2525

26-
pybind11_add_module(_C THIN_LTO "${torchopt_csrc}")
26+
pybind11_add_module(_C MODULE THIN_LTO "${torchopt_csrc}")
2727

2828
target_link_libraries(
2929
_C PRIVATE
30-
${TORCH_LIBRARIES}
30+
"${TORCH_LIBRARIES}"
3131
OpenMP::OpenMP_CXX
3232
)

0 commit comments

Comments (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy