
[BUG] Got empty meta-parameters using ImplicitMetaGradientModule #144

Closed
@lmz123321

Description

Required prerequisites

What version of TorchOpt are you using?

0.7.0

System information

System version: 3.9.16 (main, Jan 11 2023, 16:05:54)
System platform: [GCC 11.2.0] linux
Torchopt version: 0.7.0 (installed via conda)
Torch version: 1.13.1
Functorch version: 1.13.1

Problem description

Hi, I am using the implicit model from torchopt.nn.ImplicitMetaGradientModule.

One of my input neural networks contains batch normalization layers, and I want to freeze them when calling implicit_model.solve(). To do so, I call network.eval(), as provided by PyTorch.
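For context, "freezing" BatchNorm in this sense means switching it to inference behaviour. In eval mode, `nn.BatchNorm1d` normalizes with its stored running statistics and does not update them. In training mode it uses batch statistics and updates `running_mean`/`running_var`. The sketch below is plain PyTorch, independent of TorchOpt, and checks this. `running_mean_changes` is a hypothetical helper written only for this illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def running_mean_changes(bn: nn.BatchNorm1d, x: torch.Tensor) -> bool:
    """Return True if one forward pass modifies bn.running_mean (hypothetical helper)."""
    before = bn.running_mean.clone()
    with torch.no_grad():  # no_grad does not stop the running-stat buffer update
        bn(x)
    return not torch.equal(before, bn.running_mean)

x = torch.rand(10, 5)

bn = nn.BatchNorm1d(5)
bn.train()
print(running_mean_changes(bn, x))  # True: training mode updates the running statistics

bn.eval()
print(running_mean_changes(bn, x))  # False: eval mode leaves them untouched
```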

However, torchopt throws the following error:

  • ValueError: not enough values to unpack (expected 2, got 0)

It is raised at line 140 of /torchopt/diff/implicit/nn/module.py: meta_params_names, flat_meta_params = tuple(zip(*self.named_meta_parameters())).


What I found about this bug:

  • self.named_meta_parameters() is empty
  • freezing only the BN layers with the following code avoids the error:
for m in network.modules():
    if isinstance(m, nn.BatchNorm1d):
        m.eval()
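The two approaches differ in which modules end up with `training == False`. `mlp.eval()` recursively switches the top-level `nn.Sequential` and every child (including the `Linear` layers). The loop above switches only the `BatchNorm1d` layers and leaves `mlp.training` as `True`. The plain-PyTorch check below shows this. It is an assumption, not something this report confirms, that TorchOpt looks at the `training` flag of a module passed to the constructor when deciding which parameters are meta-parameters; if so, it would be consistent with the behaviour observed here. `make_mlp` and `training_flags` are hypothetical helpers for this illustration.

```python
import torch.nn as nn

def make_mlp() -> nn.Sequential:
    # Same architecture as in the reproducible example below
    return nn.Sequential(nn.Linear(5, 5), nn.BatchNorm1d(5), nn.Tanh(), nn.Linear(5, 1))

def training_flags(module: nn.Module) -> dict:
    """Map each submodule name ('' is the root) to its training flag."""
    return {name: m.training for name, m in module.named_modules()}

# Approach 1: eval() on the whole network
mlp_a = make_mlp()
mlp_a.eval()

# Approach 2: eval() only on the BatchNorm layers (the workaround above)
mlp_b = make_mlp()
for m in mlp_b.modules():
    if isinstance(m, nn.BatchNorm1d):
        m.eval()

print(training_flags(mlp_a))  # {'': False, '0': False, '1': False, '2': False, '3': False}
print(training_flags(mlp_b))  # {'': True, '0': True, '1': False, '2': True, '3': True}
```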

Reproducible example code

The Python snippet:

import torch
import torch.nn as nn
import torchopt

_ = torch.manual_seed(123)
torch.set_default_dtype(torch.float64)

class ImplicitModel(torchopt.nn.ImplicitMetaGradientModule):
    def __init__(self, mlp, x0):
        super().__init__()
        self.mlp = mlp
        self.x = nn.Parameter(x0.clone().detach_(), requires_grad=True)

    def objective(self):
        return self.mlp(self.x).mean()

    @torch.enable_grad()
    def solve(self):
        optimizer = torch.optim.Adam([self.x], lr=0.01)
        for epoch in range(100):
            optimizer.zero_grad()
            loss = self.objective()
            loss.backward(inputs=[self.x])
            optimizer.step()

mlp = nn.Sequential(nn.Linear(5, 5), nn.BatchNorm1d(5), nn.Tanh(), nn.Linear(5, 1))
_ = mlp.eval()

# Replacing mlp.eval() above with this loop avoids the error:
# for m in mlp.modules():
#   if isinstance(m, nn.BatchNorm1d):
#       m.eval()

x0 = torch.rand(10, 5)

model = ImplicitModel(mlp, x0)
model.solve()

Traceback

Traceback (most recent call last):
  /home/Miniconda3/envs/torch1.13/lib/python3.9/site-packages/torchopt/diff/implicit/nn/module.py:140 in wrapped

    137     def wrapped(self: ImplicitMetaGradientModule, *input, **kwargs) -> Any:
    138         """Solve the optimization problem."""
    139         params_names, flat_params = tuple(zip(*self.named_parameters()))
  ❱ 140         meta_params_names, flat_meta_params = tuple(zip(*self.named_meta_parameters()))
    141
    142         flat_optimal_params, output = stateless_solver_fn(
    143             flat_params,
ValueError: not enough values to unpack (expected 2, got 0)
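The message itself comes from unpacking an empty result. When `named_meta_parameters()` yields nothing, `tuple(zip(*...))` is an empty tuple, and unpacking it into two names raises exactly this `ValueError`. Since the issue is labelled "better errors", here is a hedged sketch of a guard that would report what is actually missing. `unzip_named` is a hypothetical helper, not part of TorchOpt's API.

```python
from typing import Any, Iterable, Tuple

def unzip_named(named: Iterable[Tuple[str, Any]], kind: str) -> Tuple[tuple, tuple]:
    """Split (name, value) pairs into (names, values), failing clearly if empty.

    Hypothetical helper for illustration; not part of TorchOpt.
    """
    pairs = tuple(named)
    if not pairs:
        raise ValueError(f"no {kind} found; cannot split names and values of an empty collection")
    names, values = zip(*pairs)
    return names, values

# Reproduce the original failure mode with an empty iterator
try:
    meta_names, meta_values = tuple(zip(*iter(())))
except ValueError as err:
    print(err)  # not enough values to unpack (expected 2, got 0)

# The guarded version names what is missing instead
try:
    unzip_named(iter(()), "meta-parameters")
except ValueError as err:
    print(err)
```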

Expected behavior

No response

Additional context

No response

Metadata

Labels

better errors: Need better user-end error messages
