Skip to content

Commit 3f46666

Browse files
committed
chore: resolve lint
1 parent 886639b commit 3f46666

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

examples/iMAML/imaml_omniglot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class InnerNet(
5151
def __init__(self, meta_net, n_inner_iter, reg_param):
5252
super().__init__()
5353
self.meta_net = meta_net
54-
self.net = copy.deepcopy(meta_net)
54+
self.net = torchopt.module_clone(meta_net, by='deepcopy', detach_buffers=True)
5555
self.n_inner_iter = n_inner_iter
5656
self.reg_param = reg_param
5757

tests/test_implicit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ class InnerNet(ImplicitMetaGradientModule, has_aux=True):
243243
def __init__(self, meta_model):
244244
super().__init__()
245245
self.meta_model = meta_model
246-
self.model = copy.deepcopy(meta_model)
246+
self.model = torchopt.module_clone(meta_model, by='deepcopy', detach_buffers=True)
247247

248248
def forward(self, x):
249249
return self.model(x)

torchopt/diff/implicit/nn/module.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,9 @@ def enable_implicit_gradients(
8181
raise ValueError('Implicit gradients are already enabled for the solve function.')
8282

8383
cls_has_aux = cls.has_aux
84-
custom_root_kwargs = dict(has_aux=cls_has_aux)
85-
if cls.linear_solve is not None:
86-
custom_root_kwargs.update(solve=cls.linear_solve)
84+
custom_root_kwargs = dict(has_aux=cls_has_aux, solve=cls.linear_solve)
85+
if cls.linear_solve is None:
86+
custom_root_kwargs.pop('solve')
8787

8888
@functools.wraps(cls_solve)
8989
def wrapped( # pylint: disable=too-many-locals
@@ -145,7 +145,7 @@ def optimality_fn(
145145
):
146146
container.update(container_backup)
147147

148-
@custom_root(optimality_fn, argnums=1, **custom_root_kwargs)
148+
@custom_root(optimality_fn, argnums=1, **custom_root_kwargs) # type: ignore[arg-type]
149149
def solver_fn(
150150
flat_params: TupleOfTensors, # pylint: disable=unused-argument
151151
flat_meta_params: TupleOfTensors, # pylint: disable=unused-argument

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy