feat(diff/zero_order): add OOP API for zero-order differentiation #125

Merged 6 commits on Jan 11, 2023
docs(README): add notes for zero-order differentiation OOP APIs
XuehaiPan committed Jan 5, 2023
commit ba0a8b027dac1e2033d90c7a8686b27a7fa60af6
53 changes: 47 additions & 6 deletions README.md
@@ -247,7 +247,7 @@ Users need to define the stationary condition/objective function and the inner-l
```python
 # Inherited from the class ImplicitMetaGradientModule
 # Optionally specify the linear solver (conjugate gradient or Neumann series)
-class InnerNet(ImplicitMetaGradientModule, linear_solver):
+class InnerNet(ImplicitMetaGradientModule, linear_solve=linear_solver):
     def __init__(self, meta_param):
         super().__init__()
         self.meta_param = meta_param
```
@@ -293,17 +293,58 @@ Refer to the tutorial notebook [Zero-order Differentiation](tutorials/6_Zero_Ord

#### Functional API <!-- omit in toc -->

For zero-order differentiation, users need to define the forward pass calculation and the noise sampling procedure. TorchOpt provides a decorator that wraps the forward function to enable zero-order differentiation.

```python
 # Customize the noise sampling function in ES
-def sample(sample_shape):
+def distribution(sample_shape):
+    # Generate a batch of noise samples
+    # NOTE: The distribution should be spherically symmetric with a constant variance of 1.
     ...
-    return sample_noise
+    return noise_batch
+
+# Distribution can also be an instance of `torch.distributions.Distribution`, e.g., `torch.distributions.Normal(...)`
+distribution = torch.distributions.Normal(loc=0, scale=1)

 # Specify method and hyper-parameter of ES
-@torchopt.diff.zero_order(sample, method)
+@torchopt.diff.zero_order(distribution, method)
 def forward(params, batch, labels):
-    # forward process
-    return output
+    # Forward process
+    ...
+    return objective  # the returned tensor should be a scalar tensor
```
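
To make concrete what the decorator has to estimate, here is a minimal pure-Python sketch of an antithetic evolution-strategies gradient estimator. It is not TorchOpt's implementation; `antithetic_grad`, `objective`, and the default values of `sigma` and `num_samples` are hypothetical names and choices made for illustration. Noise is drawn from a standard normal, which satisfies the note in the code comment above (spherically symmetric, unit variance).

```python
import random


def antithetic_grad(f, params, sigma=0.1, num_samples=20_000, seed=0):
    """Estimate the gradient of f at params from function values only.

    Averages (f(p + s*e) - f(p - s*e)) / (2*s) * e over standard normal e.
    Hypothetical helper for illustration, not part of TorchOpt.
    """
    rng = random.Random(seed)
    grad = [0.0] * len(params)
    for _ in range(num_samples):
        # One noise vector per sample, one standard normal draw per parameter
        eps = [rng.gauss(0.0, 1.0) for _ in params]
        plus = f([p + sigma * e for p, e in zip(params, eps)])
        minus = f([p - sigma * e for p, e in zip(params, eps)])
        scale = (plus - minus) / (2.0 * sigma)
        for i, e in enumerate(eps):
            grad[i] += scale * e / num_samples
    return grad


def objective(params):
    # Scalar objective; its exact gradient is 2 * params
    return sum(p * p for p in params)


estimate = antithetic_grad(objective, [1.0, -2.0])
print(estimate)  # close to [2.0, -4.0], up to sampling noise
```

Only function values are used, so `objective` does not need to be differentiable. A larger `num_samples` reduces the noise of the estimate at the cost of more forward passes.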

#### OOP API <!-- omit in toc -->

TorchOpt also offers an OOP API. Users inherit from the class `torchopt.nn.ZeroOrderGradientModule` to construct the network as an `nn.Module` in the classical PyTorch style.
Users need to define the forward pass `forward()`, whose gradient is estimated by zero-order differentiation, and a noise sampling function `sample()`.

```python
# Inherited from the class ZeroOrderGradientModule
# Optionally specify the `method` and/or `num_samples` and/or `sigma` used for sampling
class Net(ZeroOrderGradientModule, method=method, num_samples=num_samples, sigma=sigma):
    def __init__(self, ...):
        ...

    def forward(self, batch):
        # Forward process
        ...
        return objective  # the returned tensor should be a scalar tensor

    def sample(self, sample_shape=torch.Size()):
        # Generate a batch of noise samples
        # NOTE: The distribution should be spherically symmetric with a constant variance of 1.
        ...
        return noise_batch

# Get model and data
net = Net(...)
data = ...

# Forward pass (call the instance `net`, not the class `Net`)
loss = net(data)
# Backward pass using zero-order differentiation
grads = torch.autograd.grad(loss, net.parameters())
```
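
The class statement above passes `method`, `num_samples`, and `sigma` as class keyword arguments. One way Python supports this is `__init_subclass__`; whether TorchOpt uses this mechanism is not shown in this diff, so treat the following as a sketch under that assumption. `ZeroOrderSketch`, `estimate_grad`, and `Quadratic` are hypothetical names, and the estimator is a plain forward-difference one written in pure Python rather than TorchOpt's code.

```python
import random


class ZeroOrderSketch:
    """Hypothetical base class; NOT torchopt.nn.ZeroOrderGradientModule."""

    def __init_subclass__(cls, num_samples=1, sigma=1.0, **kwargs):
        # Keyword arguments written in the class statement arrive here
        super().__init_subclass__(**kwargs)
        cls.num_samples = num_samples
        cls.sigma = sigma

    def forward(self, params):
        raise NotImplementedError  # scalar objective, defined by the subclass

    def sample(self, size):
        raise NotImplementedError  # one noise vector, defined by the subclass

    def estimate_grad(self, params):
        """Forward-difference estimate: mean of (f(p + s*e) - f(p)) / s * e."""
        base = self.forward(params)
        grad = [0.0] * len(params)
        for _ in range(self.num_samples):
            eps = self.sample(len(params))
            shifted = [p + self.sigma * e for p, e in zip(params, eps)]
            scale = (self.forward(shifted) - base) / self.sigma
            for i, e in enumerate(eps):
                grad[i] += scale * e / self.num_samples
        return grad


class Quadratic(ZeroOrderSketch, num_samples=20_000, sigma=0.01):
    def __init__(self, target, seed=0):
        self.target = target
        self.rng = random.Random(seed)

    def forward(self, params):
        # Scalar objective: squared distance to the target
        return sum((p - t) ** 2 for p, t in zip(params, self.target))

    def sample(self, size):
        # Standard normal noise: spherically symmetric, unit variance
        return [self.rng.gauss(0.0, 1.0) for _ in range(size)]


net = Quadratic(target=[1.0, -1.0])
grads = net.estimate_grad([0.0, 0.0])  # exact gradient is [-2.0, 2.0]
print(grads)
```

Keeping the sampling hyper-parameters on the class means every instance of `Quadratic` shares them, while `forward()` and `sample()` stay ordinary overridable methods.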

--------------------------------------------------------------------------------