Commit 381f4d1

feat(diff/zero_order): add OOP API for zero-order differentiation (#125)

1 parent 03686dc · commit 381f4d1

12 files changed: +435 −82 lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
-
+- Add object-oriented modules support for zero-order differentiation by [@XuehaiPan](https://github.com/XuehaiPan) in [#125](https://github.com/metaopt/torchopt/pull/125).
 
 ### Changed
 
```

Makefile

Lines changed: 2 additions & 2 deletions

```diff
@@ -138,7 +138,7 @@ clang-format: clang-format-install
 # Documentation
 
 addlicense: addlicense-install
-	addlicense -c $(COPYRIGHT) -ignore tests/coverage.xml -l apache -y 2022 -check $(SOURCE_FOLDERS)
+	addlicense -c $(COPYRIGHT) -ignore tests/coverage.xml -l apache -y 2022-$(shell date +"%Y") -check $(SOURCE_FOLDERS)
 
 docstyle: docs-install
 	make -C docs clean
@@ -162,7 +162,7 @@ format: py-format-install clang-format-install addlicense-install
 	$(PYTHON) -m isort --project $(PROJECT_NAME) $(PYTHON_FILES)
 	$(PYTHON) -m black $(PYTHON_FILES) tutorials
 	$(CLANG_FORMAT) -style=file -i $(CXX_FILES)
-	addlicense -c $(COPYRIGHT) -ignore tests/coverage.xml -l apache -y 2022 $(SOURCE_FOLDERS)
+	addlicense -c $(COPYRIGHT) -ignore tests/coverage.xml -l apache -y 2022-$(shell date +"%Y") $(SOURCE_FOLDERS)
 
 clean-py:
 	find . -type f -name '*.py[co]' -delete
```
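With this change, the license-year argument tracks the current year automatically: `$(shell date +"%Y")` expands to the build-time year, so `-y 2022-$(shell date +"%Y")` yields a range like `2022-2023`, matching the updated copyright headers in the Python files below.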

README.md

Lines changed: 47 additions & 6 deletions

````diff
@@ -247,7 +247,7 @@ Users need to define the stationary condition/objective function and the inner-loop
 ```python
 # Inherited from the class ImplicitMetaGradientModule
 # Optionally specify the linear solver (conjugate gradient or Neumann series)
-class InnerNet(ImplicitMetaGradientModule, linear_solver):
+class InnerNet(ImplicitMetaGradientModule, linear_solve=linear_solver):
     def __init__(self, meta_param):
         super().__init__()
         self.meta_param = meta_param
@@ -293,17 +293,58 @@ Refer to the tutorial notebook [Zero-order Differentiation](tutorials/6_Zero_Ord
 
 #### Functional API <!-- omit in toc -->
 
+For zero-order differentiation, users need to define the forward pass calculation and the noise sampling procedure. TorchOpt provides a decorator to wrap the forward function for enabling zero-order differentiation.
+
 ```python
 # Customize the noise sampling function in ES
-def sample(sample_shape):
+def distribution(sample_shape):
+    # Generate a batch of noise samples
+    # NOTE: The distribution should be spherically symmetric with a constant variance of 1.
     ...
-    return sample_noise
+    return noise_batch
+
+# Distribution can also be an instance of `torch.distributions.Distribution`, e.g., `torch.distributions.Normal(...)`
+distribution = torch.distributions.Normal(loc=0, scale=1)
 
 # Specify method and hyper-parameter of ES
-@torchopt.diff.zero_order(sample, method)
+@torchopt.diff.zero_order(distribution, method)
 def forward(params, batch, labels):
-    # forward process
-    return output
+    # Forward process
+    ...
+    return objective  # the returned tensor should be a scalar tensor
+```
+
+#### OOP API <!-- omit in toc -->
+
+TorchOpt also offers an OOP API. Users need to inherit from the class `torchopt.nn.ZeroOrderGradientModule` to construct the network as an `nn.Module` in classical PyTorch style,
+and define the zero-order gradient procedures: the forward process `forward()` and a noise sampling function `sample()`.
+
+```python
+# Inherited from the class ZeroOrderGradientModule
+# Optionally specify the `method` and/or `num_samples` and/or `sigma` used for sampling
+class Net(ZeroOrderGradientModule, method=method, num_samples=num_samples, sigma=sigma):
+    def __init__(self, ...):
+        ...
+
+    def forward(self, batch):
+        # Forward process
+        ...
+        return objective  # the returned tensor should be a scalar tensor
+
+    def sample(self, sample_shape=torch.Size()):
+        # Generate a batch of noise samples
+        # NOTE: The distribution should be spherically symmetric with a constant variance of 1.
+        ...
+        return noise_batch
+
+# Get model and data
+net = Net(...)
+data = ...
+
+# Forward pass
+loss = net(data)
+# Backward pass using zero-order differentiation
+grads = torch.autograd.grad(loss, net.parameters())
 ```
 
 --------------------------------------------------------------------------------
````
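To make the functional snippet above concrete, here is a self-contained sketch of the decorator in use. The toy quadratic objective and the hyper-parameter values are illustrative assumptions; only the decorator signature is taken from the README snippet and the test file below.

```python
import torch
import torchopt

# A unit-variance Gaussian satisfies the spherical-symmetry requirement
distribution = torch.distributions.Normal(loc=0, scale=1)

# Hypothetical toy objective, used only to exercise the decorator
@torchopt.diff.zero_order(
    distribution=distribution, method='naive', argnums=0, sigma=0.1, num_samples=100
)
def forward(params, x):
    (w,) = params
    return ((w - x) ** 2).mean()  # must return a scalar tensor

params = (torch.randn(8).requires_grad_(),)
x = torch.randn(8)

loss = forward(params, x)
# Gradients are estimated by evolution strategies, not by differentiating `forward`
grads = torch.autograd.grad(loss, params)
```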

tests/test_zero_order.py

Lines changed: 49 additions & 9 deletions

```diff
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
 import functorch
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.types
 
 import helpers
@@ -30,20 +31,17 @@ class FcNet(nn.Module):
     def __init__(self, dim, out):
         super().__init__()
         self.fc = nn.Linear(in_features=dim, out_features=out, bias=True)
-        nn.init.ones_(self.fc.weight)
-        nn.init.zeros_(self.fc.bias)
 
     def forward(self, x):
         return self.fc(x)
 
 
 @helpers.parametrize(
-    dtype=[torch.float64, torch.float32],
     lr=[1e-2, 1e-3],
     method=['naive', 'forward', 'antithetic'],
     sigma=[0.01, 0.1, 1],
 )
-def test_zero_order(dtype: torch.dtype, lr: float, method: str, sigma: float) -> None:
+def test_zero_order(lr: float, method: str, sigma: float) -> None:
     helpers.seed_everything(42)
     input_size = 32
     output_size = 1
@@ -59,21 +57,63 @@ def test_zero_order(dtype: torch.dtype, lr: float, method: str, sigma: float) -> None:
     y = torch.randn(input_size) * coef
     distribution = torch.distributions.Normal(loc=0, scale=1)
 
-    @torchopt.diff.zero_order.zero_order(
+    @torchopt.diff.zero_order(
         distribution=distribution, method=method, argnums=0, sigma=sigma, num_samples=num_samples
     )
     def forward_process(params, fn, x, y):
         y_pred = fn(params, x)
-        loss = torch.mean((y - y_pred) ** 2)
+        loss = F.mse_loss(y_pred, y)
         return loss
 
     optimizer = torchopt.adam(lr=lr)
-    opt_state = optimizer.init(params)
+    opt_state = optimizer.init(params)  # init optimizer
 
     for i in range(num_iterations):
-        opt_state = optimizer.init(params)  # init optimizer
         loss = forward_process(params, fmodel, x, y)  # compute loss
 
         grads = torch.autograd.grad(loss, params)  # compute gradients
         updates, opt_state = optimizer.update(grads, opt_state)  # get updates
         params = torchopt.apply_updates(params, updates)  # update network parameters
+
+
+@helpers.parametrize(
+    lr=[1e-2, 1e-3],
+    method=['naive', 'forward', 'antithetic'],
+    sigma=[0.01, 0.1, 1],
+)
+def test_zero_order_module(lr: float, method: str, sigma: float) -> None:
+    helpers.seed_everything(42)
+    input_size = 32
+    output_size = 1
+    batch_size = BATCH_SIZE
+    coef = 0.1
+    num_iterations = NUM_UPDATES
+    num_samples = 500
+
+    class FcNetWithLoss(
+        torchopt.nn.ZeroOrderGradientModule, method=method, sigma=sigma, num_samples=num_samples
+    ):
+        def __init__(self, dim, out):
+            super().__init__()
+            self.net = FcNet(dim, out)
+            self.loss = nn.MSELoss()
+            self.distribution = torch.distributions.Normal(loc=0, scale=1)
+
+        def forward(self, x, y):
+            return self.loss(self.net(x), y)
+
+        def sample(self, sample_shape=torch.Size()):
+            return self.distribution.sample(sample_shape)
+
+    x = torch.randn(batch_size, input_size) * coef
+    y = torch.randn(input_size) * coef
+    model_with_loss = FcNetWithLoss(input_size, output_size)
+
+    optimizer = torchopt.Adam(model_with_loss.parameters(), lr=lr)
+
+    for i in range(num_iterations):
+        loss = model_with_loss(x, y)  # compute loss
+
+        optimizer.zero_grad()
+        loss.backward()  # compute gradients
+        optimizer.step()  # update network parameters
```
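Note the intended pairing in the two tests: the functional test drives updates with the functional optimizer (`torchopt.adam`, `optimizer.init`/`optimizer.update` plus `torchopt.apply_updates`), while the new OOP test uses the `torch.optim`-style `torchopt.Adam` with the familiar `zero_grad()`/`backward()`/`step()` loop. Zero-order estimation is what makes `loss.backward()` work here even though the forward pass is never differentiated analytically.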

torchopt/diff/__init__.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,3 +16,4 @@
 
 from torchopt.diff import implicit, zero_order
 from torchopt.diff.implicit import ImplicitMetaGradientModule
+from torchopt.diff.zero_order import ZeroOrderGradientModule
```

torchopt/diff/implicit/nn/__init__.py

Lines changed: 4 additions & 3 deletions

```diff
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,9 +14,10 @@
 # ==============================================================================
 """The base class for differentiable implicit meta-gradient models."""
 
-# Preload to resolve circular references
-import torchopt.nn.module  # pylint: disable=unused-import
+import torchopt.nn.module  # preload to resolve circular references
 from torchopt.diff.implicit.nn.module import ImplicitMetaGradientModule
 
 
 __all__ = ['ImplicitMetaGradientModule']
+
+del torchopt
```
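The new trailing `del torchopt` removes the module name bound by the preload import from the package namespace, so `torchopt` does not leak as a public attribute of `torchopt.diff.implicit.nn`; the new `torchopt/diff/zero_order/nn/__init__.py` below follows the same pattern.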

torchopt/diff/zero_order/__init__.py

Lines changed: 4 additions & 2 deletions

```diff
@@ -1,4 +1,4 @@
-# Copyright 2022 MetaOPT Team. All Rights Reserved.
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,10 +17,12 @@
 import sys as _sys
 from types import ModuleType as _ModuleType
 
+from torchopt.diff.zero_order import nn
 from torchopt.diff.zero_order.decorator import zero_order
+from torchopt.diff.zero_order.nn import ZeroOrderGradientModule
 
 
-__all__ = ['zero_order']
+__all__ = ['zero_order', 'ZeroOrderGradientModule']
 
 
 class _CallableModule(_ModuleType):  # pylint: disable=too-few-public-methods
```
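The `_CallableModule` class left as trailing context here is what lets `torchopt.diff.zero_order` act as both a subpackage and a decorator. A minimal sketch of the pattern, assuming the module body continues roughly as follows (the actual implementation in the file may differ):

```python
class _CallableModule(_ModuleType):  # pylint: disable=too-few-public-methods
    def __call__(self, *args, **kwargs):
        # Calling the module delegates to the `zero_order` decorator
        return zero_order(*args, **kwargs)


# Swap this module's class so `torchopt.diff.zero_order(...)` is callable
_sys.modules[__name__].__class__ = _CallableModule
```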
torchopt/diff/zero_order/nn/__init__.py

Lines changed: 23 additions & 0 deletions

```diff
@@ -0,0 +1,23 @@
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The base class for zero-order gradient models."""
+
+import torchopt.nn.module  # preload to resolve circular references
+from torchopt.diff.zero_order.nn.module import ZeroOrderGradientModule
+
+
+__all__ = ['ZeroOrderGradientModule']
+
+del torchopt
```

torchopt/diff/zero_order/nn/module.py

Lines changed: 116 additions & 0 deletions

```diff
@@ -0,0 +1,116 @@
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The base class for zero-order gradient models."""
+
+# pylint: disable=redefined-builtin
+
+import abc
+import functools
+from typing import Dict, Optional, Sequence, Tuple, Type, Union
+
+import torch
+import torch.nn as nn
+
+from torchopt import pytree
+from torchopt.diff.implicit.nn.module import container_context
+from torchopt.diff.zero_order.decorator import Method, Samplable, zero_order
+from torchopt.typing import Numeric, TupleOfTensors
+from torchopt.utils import extract_module_containers
+
+
+__all__ = ['ZeroOrderGradientModule']
+
+
+def enable_zero_order_gradients(
+    cls: Type['ZeroOrderGradientModule'],
+    method: Method = 'naive',
+    num_samples: int = 1,
+    sigma: Numeric = 1.0,
+) -> Type['ZeroOrderGradientModule']:
+    """Enable zero-order gradient estimation for the :func:`forward` method."""
+    cls_forward = cls.forward
+    if getattr(cls_forward, '__zero_order_gradients_enabled__', False):
+        raise TypeError(
+            'Zero-order gradient estimation is already enabled for the `forward` method.'
+        )
+
+    @functools.wraps(cls_forward)
+    def wrapped(  # pylint: disable=too-many-locals
+        self: 'ZeroOrderGradientModule', *input, **kwargs
+    ) -> torch.Tensor:
+        """Do the forward pass calculation."""
+        params_containers = extract_module_containers(self, with_buffers=False)[0]
+
+        flat_params: TupleOfTensors
+        flat_params, params_containers_treespec = pytree.tree_flatten_as_tuple(
+            params_containers  # type: ignore[arg-type]
+        )
+
+        @zero_order(self.sample, argnums=0, method=method, num_samples=num_samples, sigma=sigma)
+        def forward_fn(
+            __flat_params: TupleOfTensors,  # pylint: disable=unused-argument
+            *input,
+            **kwargs,
+        ) -> torch.Tensor:
+            flat_grad_tracking_params = __flat_params
+            grad_tracking_params_containers: Tuple[
+                Dict[str, Optional[torch.Tensor]], ...
+            ] = pytree.tree_unflatten(  # type: ignore[assignment]
+                params_containers_treespec, flat_grad_tracking_params
+            )
+
+            with container_context(
+                params_containers,
+                grad_tracking_params_containers,
+            ):
+                return cls_forward(self, *input, **kwargs)
+
+        return forward_fn(flat_params, *input, **kwargs)
+
+    wrapped.__zero_order_gradients_enabled__ = True  # type: ignore[attr-defined]
+    cls.forward = wrapped  # type: ignore[assignment]
+    return cls
+
+
+class ZeroOrderGradientModule(nn.Module, Samplable):
+    """The base class for zero-order gradient models."""
+
+    def __init_subclass__(  # pylint: disable=arguments-differ
+        cls,
+        method: Method = 'naive',
+        num_samples: int = 1,
+        sigma: Numeric = 1.0,
+    ) -> None:
+        """Validate and initialize the subclass."""
+        super().__init_subclass__()
+        enable_zero_order_gradients(
+            cls,
+            method=method,
+            num_samples=num_samples,
+            sigma=sigma,
+        )
+
+    @abc.abstractmethod
+    def forward(self, *args, **kwargs) -> torch.Tensor:
+        """Do the forward pass of the model."""
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def sample(
+        self, sample_shape: torch.Size = torch.Size()  # pylint: disable=unused-argument
+    ) -> Union[torch.Tensor, Sequence[Numeric]]:
+        # pylint: disable-next=line-too-long
+        """Generate a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched."""
+        raise NotImplementedError
```
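The class keyword arguments (`method`, `num_samples`, `sigma`) are consumed once at class-definition time by `__init_subclass__`, which rewraps `forward` via `enable_zero_order_gradients`. A standalone sketch of this class-keyword mechanism, independent of TorchOpt (the names here are hypothetical, for illustration only):

```python
class Base:
    def __init_subclass__(cls, method: str = 'naive', **kwargs) -> None:
        super().__init_subclass__(**kwargs)
        # In TorchOpt, this is where `enable_zero_order_gradients` wraps `cls.forward`
        cls._method = method  # hypothetical attribute captured at class-definition time


class Derived(Base, method='antithetic'):
    pass


assert Derived._method == 'antithetic'
```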

0 commit comments