
Commit 5bc8133

chore(pre-commit): update pre-commit hooks

1 parent 605929a, commit 5bc8133

75 files changed, +519 -280 lines. Only a subset of the changed files is shown below.

.pre-commit-config.yaml

Lines changed: 4 additions & 4 deletions
@@ -26,11 +26,11 @@ repos:
       - id: debug-statements
       - id: double-quote-string-fixer
   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v18.1.5
+    rev: v18.1.6
     hooks:
       - id: clang-format
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.4.7
+    rev: v0.4.9
     hooks:
       - id: ruff
         args: [--fix, --exit-non-zero-on-fix]
@@ -43,7 +43,7 @@ repos:
     hooks:
       - id: black-jupyter
   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.15.2
+    rev: v3.16.0
     hooks:
       - id: pyupgrade
         args: [--py38-plus] # sync with requires-python
@@ -52,7 +52,7 @@ repos:
             ^examples/
           )
   - repo: https://github.com/pycqa/flake8
-    rev: 7.0.0
+    rev: 7.1.0
     hooks:
       - id: flake8
         additional_dependencies:
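These rev bumps are the kind of change `pre-commit autoupdate` produces; the source-file edits below presumably come from re-running the updated hooks (for example with `pre-commit run --all-files`) together with the expanded Ruff rule selection in pyproject.toml.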

pyproject.toml

Lines changed: 14 additions & 3 deletions
@@ -235,6 +235,7 @@ extend-exclude = ["examples"]
 select = [
     "E", "W", # pycodestyle
     "F", # pyflakes
+    "C90", # mccabe
     "UP", # pyupgrade
     "ANN", # flake8-annotations
     "S", # flake8-bandit
@@ -243,14 +244,21 @@ select = [
     "COM", # flake8-commas
     "C4", # flake8-comprehensions
     "EXE", # flake8-executable
+    "FA", # flake8-future-annotations
+    "LOG", # flake8-logging
     "ISC", # flake8-implicit-str-concat
+    "INP", # flake8-no-pep420
     "PIE", # flake8-pie
     "PYI", # flake8-pyi
     "Q", # flake8-quotes
     "RSE", # flake8-raise
     "RET", # flake8-return
     "SIM", # flake8-simplify
     "TID", # flake8-tidy-imports
+    "TCH", # flake8-type-checking
+    "PERF", # perflint
+    "FURB", # refurb
+    "TRY", # tryceratops
     "RUF", # ruff
 ]
 ignore = [
@@ -268,9 +276,9 @@ ignore = [
     # S101: use of `assert` detected
     # internal use and may never raise at runtime
     "S101",
-    # PLR0402: use from {module} import {name} in lieu of alias
-    # use alias for import convention (e.g., `import torch.nn as nn`)
-    "PLR0402",
+    # TRY003: avoid specifying long messages outside the exception class
+    # long messages are necessary for clarity
+    "TRY003",
 ]
 typing-modules = ["torchopt.typing"]

@@ -296,6 +304,9 @@ typing-modules = ["torchopt.typing"]
     "F401", # unused-import
     "F811", # redefined-while-unused
 ]
+"docs/source/conf.py" = [
+    "INP001", # flake8-no-pep420
+]

 [tool.ruff.lint.flake8-annotations]
 allow-star-arg-any = true
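The expanded rule selection above drives most of the source edits below. The ignore-list swap replaces the old PLR0402 entry with TRY003, which the newly selected TRY (tryceratops) family would otherwise report across the code base. A hedged sketch of what TRY003 flags (the function and message here are invented for illustration):

def check_matrix(tensor) -> None:
    if tensor.ndim != 2:
        # TRY003 would flag this long message built at the raise site rather than
        # inside a dedicated exception class; TorchOpt keeps such messages for clarity.
        raise ValueError(
            f'expected a 2-dimensional tensor, got {tensor.ndim} dimension(s)',
        )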

tests/helpers.py

Lines changed: 5 additions & 2 deletions
@@ -20,7 +20,7 @@
 import itertools
 import os
 import random
-from typing import Iterable
+from typing import TYPE_CHECKING, Iterable

 import numpy as np
 import pytest
@@ -30,7 +30,10 @@
 from torch.utils import data

 from torchopt import pytree
-from torchopt.typing import TensorTree
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import TensorTree


 BATCH_SIZE = 64
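Moving the TensorTree import under `if TYPE_CHECKING:` is the pattern the newly selected TCH (flake8-type-checking) rules push toward: names used only in annotations are imported for type checkers but skipped at runtime. A minimal sketch of the pattern, assuming postponed annotation evaluation via `from __future__ import annotations`:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by static type checkers only; never executed at runtime.
    from torchopt.typing import TensorTree


def describe(tree: TensorTree) -> str:  # annotation resolves lazily under PEP 563
    return repr(tree)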

tests/test_alias.py

Lines changed: 5 additions & 2 deletions
@@ -15,7 +15,7 @@

 from __future__ import annotations

-from typing import Callable
+from typing import TYPE_CHECKING, Callable

 import functorch
 import pytest
@@ -26,7 +26,10 @@
 import torchopt
 from torchopt import pytree
 from torchopt.alias.utils import _set_use_chain_flat
-from torchopt.typing import TensorTree
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import TensorTree


 @helpers.parametrize(

tests/test_implicit.py

Lines changed: 16 additions & 7 deletions
@@ -18,7 +18,7 @@
 import copy
 import re
 from collections import OrderedDict
-from types import FunctionType
+from typing import TYPE_CHECKING

 import functorch
 import numpy as np
@@ -47,6 +47,10 @@
     HAS_JAX = False


+if TYPE_CHECKING:
+    from types import FunctionType
+
+
 BATCH_SIZE = 8
 NUM_UPDATES = 3

@@ -123,7 +127,7 @@ def get_rr_dataset_torch() -> data.DataLoader:
     inner_lr=[2e-2, 2e-3],
     inner_update=[20, 50, 100],
 )
-def test_imaml_solve_normal_cg(
+def test_imaml_solve_normal_cg(  # noqa: C901
     dtype: torch.dtype,
     lr: float,
     inner_lr: float,
@@ -251,7 +255,7 @@ def outer_level(p, xs, ys):
     inner_update=[20, 50, 100],
     ns=[False, True],
 )
-def test_imaml_solve_inv(
+def test_imaml_solve_inv(  # noqa: C901
     dtype: torch.dtype,
     lr: float,
     inner_lr: float,
@@ -375,7 +379,12 @@ def outer_level(p, xs, ys):
     inner_lr=[2e-2, 2e-3],
     inner_update=[20, 50, 100],
 )
-def test_imaml_module(dtype: torch.dtype, lr: float, inner_lr: float, inner_update: int) -> None:
+def test_imaml_module(  # noqa: C901
+    dtype: torch.dtype,
+    lr: float,
+    inner_lr: float,
+    inner_update: int,
+) -> None:
     np_dtype = helpers.dtype_torch2numpy(dtype)

     jax_model, jax_params = get_model_jax(dtype=np_dtype)
@@ -763,7 +772,7 @@ def solve(self):
     make_optimality_from_objective(MyModule2)


-def test_module_abstract_methods() -> None:
+def test_module_abstract_methods() -> None:  # noqa: C901
     class MyModule1(torchopt.nn.ImplicitMetaGradientModule):
         def objective(self):
             return torch.tensor(0.0)
@@ -809,7 +818,7 @@ def solve(self):

     class MyModule5(torchopt.nn.ImplicitMetaGradientModule):
         @classmethod
-        def optimality(self):
+        def optimality(cls):
             return ()

         def solve(self):
@@ -846,7 +855,7 @@ def solve(self):

     class MyModule8(torchopt.nn.ImplicitMetaGradientModule):
         @classmethod
-        def objective(self):
+        def objective(cls):
             return ()

         def solve(self):
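Two recurring changes in this file: long parametrized tests opt out of the newly enabled mccabe complexity check with a per-line `# noqa: C901`, and `@classmethod` bodies name their first parameter `cls` rather than `self`, since a classmethod receives the class object. A small illustration (the names here are invented):

class Example:
    @classmethod
    def objective(cls):  # first positional argument is the class itself
        return ()


def deliberately_branchy_test() -> None:  # noqa: C901  # opt out of the complexity check
    ...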

tests/test_utils.py

Lines changed: 4 additions & 2 deletions
@@ -13,6 +13,8 @@
 # limitations under the License.
 # ==============================================================================

+import operator
+
 import torch

 import torchopt
@@ -80,7 +82,7 @@ def test_module_clone() -> None:
     assert y.is_cuda


-def test_extract_state_dict():
+def test_extract_state_dict():  # noqa: C901
     fc = torch.nn.Linear(1, 1)
     state_dict = torchopt.extract_state_dict(fc, by='reference', device=torch.device('meta'))
     for param_dict in state_dict.params:
@@ -121,7 +123,7 @@ def test_extract_state_dict():
     loss = fc(torch.ones(1, 1)).sum()
     optim.step(loss)
     state_dict = torchopt.extract_state_dict(optim)
-    same = pytree.tree_map(lambda x, y: x is y, state_dict, tuple(optim.state_groups))
+    same = pytree.tree_map(operator.is_, state_dict, tuple(optim.state_groups))
     assert all(pytree.tree_flatten(same)[0])

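`operator.is_` is a drop-in replacement for the removed `lambda x, y: x is y`; preferring the operator module over re-implementing a builtin comparison is the style the newly selected FURB/PERF families favor. A quick equivalence check:

import operator

a, b = object(), object()
assert operator.is_(a, a) is True
assert operator.is_(a, b) is False
# Same result as the lambda it replaces:
assert operator.is_(a, b) == (lambda x, y: x is y)(a, b)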

torchopt/__init__.py

Lines changed: 33 additions & 33 deletions
@@ -81,50 +81,50 @@


 __all__ = [
-    'accelerated_op_available',
-    'adam',
-    'adamax',
-    'adadelta',
-    'radam',
-    'adamw',
-    'adagrad',
-    'rmsprop',
-    'sgd',
-    'clip_grad_norm',
-    'nan_to_num',
-    'register_hook',
-    'chain',
-    'Optimizer',
     'SGD',
-    'Adam',
-    'AdaMax',
-    'Adamax',
     'AdaDelta',
-    'Adadelta',
-    'RAdam',
-    'AdamW',
     'AdaGrad',
+    'AdaMax',
+    'Adadelta',
     'Adagrad',
-    'RMSProp',
-    'RMSprop',
-    'MetaOptimizer',
-    'MetaSGD',
-    'MetaAdam',
-    'MetaAdaMax',
-    'MetaAdamax',
+    'Adam',
+    'AdamW',
+    'Adamax',
+    'FuncOptimizer',
     'MetaAdaDelta',
-    'MetaAdadelta',
-    'MetaRAdam',
-    'MetaAdamW',
     'MetaAdaGrad',
+    'MetaAdaMax',
+    'MetaAdadelta',
     'MetaAdagrad',
+    'MetaAdam',
+    'MetaAdamW',
+    'MetaAdamax',
+    'MetaOptimizer',
+    'MetaRAdam',
     'MetaRMSProp',
     'MetaRMSprop',
-    'FuncOptimizer',
+    'MetaSGD',
+    'Optimizer',
+    'RAdam',
+    'RMSProp',
+    'RMSprop',
+    'accelerated_op_available',
+    'adadelta',
+    'adagrad',
+    'adam',
+    'adamax',
+    'adamw',
     'apply_updates',
+    'chain',
+    'clip_grad_norm',
     'extract_state_dict',
-    'recover_state_dict',
-    'stop_gradient',
     'module_clone',
     'module_detach_',
+    'nan_to_num',
+    'radam',
+    'recover_state_dict',
+    'register_hook',
+    'rmsprop',
+    'sgd',
+    'stop_gradient',
 ]
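No names are added or removed in this hunk; the `__all__` entries are only re-ordered. The new order looks like the isort-style grouping applied by Ruff's `__all__`-sorting check: ALL_CAPS names first, then CamelCase, then snake_case, alphabetical within each group. A hypothetical key function that reproduces that grouping:

def dunder_all_key(name):
    # ALL_CAPS -> 0, CamelCase -> 1, snake_case -> 2; alphabetical within each group.
    group = 0 if name.isupper() else (1 if name[:1].isupper() else 2)
    return (group, name)


names = ['sgd', 'SGD', 'MetaAdam', 'adam', 'Adam']
print(sorted(names, key=dunder_all_key))  # ['SGD', 'Adam', 'MetaAdam', 'adam', 'sgd']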

torchopt/accelerated_op/__init__.py

Lines changed: 6 additions & 3 deletions
@@ -16,12 +16,15 @@

 from __future__ import annotations

-from typing import Iterable
+from typing import TYPE_CHECKING, Iterable

 import torch

 from torchopt.accelerated_op.adam_op import AdamOp
-from torchopt.typing import Device
+
+
+if TYPE_CHECKING:
+    from torchopt.typing import Device


 def is_available(devices: Device | Iterable[Device] | None = None) -> bool:
@@ -42,6 +45,6 @@ def is_available(devices: Device | Iterable[Device] | None = None) -> bool:
                 return False
             updates = torch.tensor(1.0, device=device)
             op(updates, updates, updates, 1)
-        return True
     except Exception:  # noqa: BLE001 # pylint: disable=broad-except
         return False
+    return True
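Moving the success `return True` out of the `try` body is the shape the newly selected TRY (tryceratops) rules favor: the broad `except` then only guards the statements that can actually raise. A minimal hedged sketch of the control flow (`probe` is a stand-in callable):

def is_available_sketch(probe) -> bool:
    try:
        probe()  # the only call expected to raise
    except Exception:  # noqa: BLE001  # deliberately broad, mirroring the original
        return False
    return True  # reached only if probe() succeeded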

torchopt/accelerated_op/_src/adam_op.py

Lines changed: 5 additions & 1 deletion
@@ -18,7 +18,11 @@

 from __future__ import annotations

-import torch
+from typing import TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+    import torch


 def forward_(

torchopt/alias/__init__.py

Lines changed: 10 additions & 1 deletion
@@ -41,4 +41,13 @@
 from torchopt.alias.sgd import sgd


-__all__ = ['adagrad', 'radam', 'adam', 'adamax', 'adadelta', 'adamw', 'rmsprop', 'sgd']
+__all__ = [
+    'adadelta',
+    'adagrad',
+    'adam',
+    'adamax',
+    'adamw',
+    'radam',
+    'rmsprop',
+    'sgd',
+]
