
Commit 065e57b (1 parent: ce349ca)

chore: suppress unresolved forwardrefs type hint warnings

The commit registers a logging filter in docs/source/conf.py that drops the sphinx_autodoc_typehints warning about the unresolved 'TensorTree' forward reference, sets simplify_optional_unions = False, and rewrites annotations across the package to reference the torchopt.typing aliases (Params, Updates, OptState, TensorTree, TupleOfTensors, ...) directly rather than as quoted strings, which also removes the need for the '# pylint: disable=unused-import' comments on those imports.

10 files changed: +44 -33 lines

docs/source/conf.py
Lines changed: 18 additions & 1 deletion

@@ -25,6 +25,7 @@
 # add these directories to sys.path here. If the directory is relative to the
 # documentation root, use os.path.abspath to make it absolute, like shown here.
 #
+import logging
 import os
 import pathlib
 import sys
@@ -43,6 +44,21 @@ def get_version() -> str:
     return version.__version__


+try:
+    import sphinx_autodoc_typehints
+except ImportError:
+    pass
+else:
+
+    class RecursiveForwardRefFilter(logging.Filter):
+        def filter(self, record):
+            if "name 'TensorTree' is not defined" in record.getMessage():
+                return False
+            return super().filter(record)
+
+    sphinx_autodoc_typehints._LOGGER.logger.addFilter(RecursiveForwardRefFilter())
+
+
 # -- Project information -----------------------------------------------------

 project = 'TorchOpt'
@@ -75,7 +91,7 @@ def get_version() -> str:
     'sphinxcontrib.bibtex',
     'sphinxcontrib.katex',
     'sphinx_autodoc_typehints',
-    'myst_nb',  # This is used for the .ipynb notebooks
+    'myst_nb',  # this is used for the .ipynb notebooks
 ]

 if not os.getenv('READTHEDOCS', None):
@@ -120,6 +136,7 @@ def get_version() -> str:
     'exclude-members': '__module__, __dict__, __repr__, __str__, __weakref__',
 }
 autoclass_content = 'both'
+simplify_optional_unions = False

 # -- Options for bibtex -----------------------------------------------------
torchopt/alias/adamw.py
Lines changed: 2 additions & 3 deletions

@@ -36,8 +36,7 @@
 from torchopt.alias.utils import flip_sign_and_add_weight_decay, scale_by_neg_lr
 from torchopt.combine import chain_flat
 from torchopt.transform import add_decayed_weights, scale_by_accelerated_adam, scale_by_adam
-from torchopt.typing import Params  # pylint: disable=unused-import
-from torchopt.typing import GradientTransformation, ScalarOrSchedule
+from torchopt.typing import GradientTransformation, Params, ScalarOrSchedule


 __all__ = ['adamw']
@@ -51,7 +50,7 @@ def adamw(
     weight_decay: float = 1e-2,
     *,
     eps_root: float = 0.0,
-    mask: Optional[Union[Any, Callable[['Params'], Any]]] = None,
+    mask: Optional[Union[Any, Callable[[Params], Any]]] = None,
     moment_requires_grad: bool = False,
     maximize: bool = False,
     use_accelerated_op: bool = False,

torchopt/diff/implicit/nn/module.py
Lines changed: 2 additions & 2 deletions

@@ -24,7 +24,7 @@
 import torchopt.nn
 from torchopt import pytree
 from torchopt.diff.implicit.decorator import custom_root
-from torchopt.typing import TensorTree, TupleOfTensors  # pylint: disable=unused-import
+from torchopt.typing import TensorTree, TupleOfTensors
 from torchopt.utils import extract_module_containers


@@ -228,7 +228,7 @@ def solve(self, batch, labels):
             raise NotImplementedError  # update parameters

     # pylint: disable-next=redefined-builtin
-    def residual(self, *input, **kwargs) -> 'TensorTree':
+    def residual(self, *input, **kwargs) -> TensorTree:
         r"""Computes the optimality residual.

         This method stands for the residual to the optimal parameters after solving the inner

torchopt/optim/adamw.py
Lines changed: 2 additions & 2 deletions

@@ -20,7 +20,7 @@

 from torchopt import alias
 from torchopt.optim.base import Optimizer
-from torchopt.typing import Params, ScalarOrSchedule  # pylint: disable=unused-import
+from torchopt.typing import Params, ScalarOrSchedule


 __all__ = ['AdamW']
@@ -44,7 +44,7 @@ def __init__(
         weight_decay: float = 1e-2,
         *,
         eps_root: float = 0.0,
-        mask: Optional[Union[Any, Callable[['Params'], Any]]] = None,
+        mask: Optional[Union[Any, Callable[[Params], Any]]] = None,
         maximize: bool = False,
         use_accelerated_op: bool = False,
     ) -> None:

torchopt/optim/base.py
Lines changed: 4 additions & 9 deletions

@@ -19,12 +19,7 @@
 import torch

 from torchopt import pytree
-from torchopt.typing import (  # pylint: disable=unused-import
-    GradientTransformation,
-    OptState,
-    Params,
-    TupleOfTensors,
-)
+from torchopt.typing import GradientTransformation, OptState, Params, TupleOfTensors
 from torchopt.update import apply_updates


@@ -84,11 +79,11 @@ def f(p):

         pytree.tree_map(f, self.param_groups)  # type: ignore[arg-type]

-    def state_dict(self) -> Tuple['OptState', ...]:
+    def state_dict(self) -> Tuple[OptState, ...]:
         """Returns the state of the optimizer."""
         return tuple(self.state_groups)

-    def load_state_dict(self, state_dict: Sequence['OptState']) -> None:
+    def load_state_dict(self, state_dict: Sequence[OptState]) -> None:
         """Loads the optimizer state.

         Args:
@@ -121,7 +116,7 @@ def f(p):

         return loss

-    def add_param_group(self, params: 'Params') -> None:
+    def add_param_group(self, params: Params) -> None:
         """Add a param group to the optimizer's :attr:`param_groups`."""
         flat_params, params_treespec = pytree.tree_flatten(params)
         flat_params: TupleOfTensors = tuple(flat_params)

torchopt/optim/func/base.py
Lines changed: 3 additions & 3 deletions

@@ -19,7 +19,7 @@
 import torch

 from torchopt.base import GradientTransformation
-from torchopt.typing import OptState, Params  # pylint: disable=unused-import
+from torchopt.typing import OptState, Params
 from torchopt.update import apply_updates


@@ -61,9 +61,9 @@ def __init__(self, impl: GradientTransformation, *, inplace: bool = False) -> No
     def step(
         self,
         loss: torch.Tensor,
-        params: 'Params',
+        params: Params,
         inplace: Optional[bool] = None,
-    ) -> 'Params':
+    ) -> Params:
         r"""Compute the gradients of loss to the network parameters and update network parameters.

         Graph of the derivative will be constructed, allowing to compute higher order derivative

torchopt/optim/meta/adamw.py
Lines changed: 2 additions & 2 deletions

@@ -20,7 +20,7 @@

 from torchopt import alias
 from torchopt.optim.meta.base import MetaOptimizer
-from torchopt.typing import Params, ScalarOrSchedule  # pylint: disable=unused-import
+from torchopt.typing import Params, ScalarOrSchedule


 __all__ = ['MetaAdamW']
@@ -44,7 +44,7 @@ def __init__(
         weight_decay: float = 1e-2,
         *,
         eps_root: float = 0.0,
-        mask: Optional[Union[Any, Callable[['Params'], Any]]] = None,
+        mask: Optional[Union[Any, Callable[[Params], Any]]] = None,
         moment_requires_grad: bool = False,
         maximize: bool = False,
         use_accelerated_op: bool = False,

torchopt/optim/meta/base.py
Lines changed: 2 additions & 2 deletions

@@ -99,14 +99,14 @@ def add_param_group(self, net: nn.Module) -> None:
         self.param_containers_groups.append(params_container)
         self.state_groups.append(optimizer_state)

-    def state_dict(self) -> Tuple['OptState', ...]:
+    def state_dict(self) -> Tuple[OptState, ...]:
         """Extract the references of the optimizer states.

         Note that the states are references, so any in-place operations will change the states
         inside :class:`MetaOptimizer` at the same time.
         """
         return tuple(self.state_groups)

-    def load_state_dict(self, state_dict: Sequence['OptState']) -> None:
+    def load_state_dict(self, state_dict: Sequence[OptState]) -> None:
         """Load the references of the optimizer states."""
         self.state_groups[:] = list(state_dict)

torchopt/update.py
Lines changed: 2 additions & 2 deletions

@@ -32,13 +32,13 @@
 """Helper functions for applying updates."""

 from torchopt import pytree
-from torchopt.typing import Params, Updates  # pylint: disable=unused-import
+from torchopt.typing import Params, Updates


 __all__ = ['apply_updates']


-def apply_updates(params: 'Params', updates: 'Updates', *, inplace: bool = True) -> 'Params':
+def apply_updates(params: Params, updates: Updates, *, inplace: bool = True) -> Params:
     """Applies an update to the corresponding parameters.

     This is a utility functions that applies an update to a set of parameters, and then returns the

torchopt/utils.py
Lines changed: 7 additions & 7 deletions

@@ -35,7 +35,7 @@
 import torch.nn as nn

 from torchopt import pytree
-from torchopt.typing import OptState, TensorTree  # pylint: disable=unused-import
+from torchopt.typing import OptState, TensorTree


 if TYPE_CHECKING:
@@ -64,7 +64,7 @@ class ModuleState(NamedTuple):
 CopyMode: TypeAlias = Literal['reference', 'copy', 'deepcopy', 'ref', 'clone', 'deepclone']


-def stop_gradient(target: Union['TensorTree', ModuleState, nn.Module, 'MetaOptimizer']) -> None:
+def stop_gradient(target: Union[TensorTree, ModuleState, nn.Module, 'MetaOptimizer']) -> None:
     """Stop the gradient for the input object.

     Since a tensor use :attr:`grad_fn` to connect itself with the previous computation graph, the
@@ -123,7 +123,7 @@ def extract_state_dict(
     with_buffers: bool = True,
     enable_visual: bool = False,
     visual_prefix: str = '',
-) -> Tuple['OptState', ...]:
+) -> Tuple[OptState, ...]:
     ...


@@ -137,7 +137,7 @@ def extract_state_dict(
     detach_buffers: bool = False,
     enable_visual: bool = False,
     visual_prefix: str = '',
-) -> Union[ModuleState, Tuple['OptState', ...]]:
+) -> Union[ModuleState, Tuple[OptState, ...]]:
     """Extract target state.

     Since a tensor use :attr:`grad_fn` to connect itself with the previous computation graph, the
@@ -312,7 +312,7 @@ def update_container(container, items):

 def recover_state_dict(
     target: Union[nn.Module, 'MetaOptimizer'],
-    state: Union[ModuleState, Sequence['OptState']],
+    state: Union[ModuleState, Sequence[OptState]],
 ) -> None:
     """Recover state.

@@ -478,8 +478,8 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor:


 def module_detach_(
-    target: Union['TensorTree', ModuleState, nn.Module, 'MetaOptimizer']
-) -> Union['TensorTree', ModuleState, nn.Module, 'MetaOptimizer']:
+    target: Union[TensorTree, ModuleState, nn.Module, 'MetaOptimizer']
+) -> Union[TensorTree, ModuleState, nn.Module, 'MetaOptimizer']:
     """Detach a module from the computation graph.

     Args:
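The source files above all apply the same mechanical change: the torchopt.typing aliases are referenced directly in annotations instead of as quoted forward references, so the imports are no longer flagged as unused. A minimal sketch of the before/after style, using an illustrative recursive alias (torchopt.typing defines its own TensorTree; the definition below is an assumption for demonstration only):

from typing import Dict, List, Union

import torch

# Illustrative recursive alias standing in for torchopt.typing.TensorTree.
TensorTree = Union[torch.Tensor, List['TensorTree'], Dict[str, 'TensorTree']]


def stop_gradient_old(target: 'TensorTree') -> None:  # old style: quoted forward reference
    ...


def stop_gradient_new(target: TensorTree) -> None:  # new style: plain alias
    ...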
