`` function pointing
+ to the original tensors.
- :const:`'deepcopy'`: The extracted tensors will be deep-copied from the original
tensors. The deep-copied tensors will detach from the original computation graph.
- detach_buffers: Whether to detach the reference to the buffers, this argument is only used
- if the input target is :class:`nn.Module` and ``by='reference'``.
- device: If specified, move the cloned module to the specified device.
+ detach_buffers (bool, optional): Whether to detach the reference to the buffers. This
+ argument is only used if the input target is :class:`nn.Module` and ``by='reference'``.
+ (default: :const:`False`)
+ device (Device or None, optional): If specified, move the cloned module to the specified
+ device. (default: :const:`None`)
Returns:
The cloned module.
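A minimal usage sketch of the three extraction modes documented above, assuming this hunk edits `torchopt.module_clone` and that the function is exported at the package top level (not verified against a specific torchopt release):

```python
# Hedged sketch: `torchopt.module_clone` and its top-level export are
# assumptions based on the docstring above, not verified here.
import torch.nn as nn
import torchopt

net = nn.Linear(3, 2)

# by='reference': the clone shares references to the original tensors.
ref = torchopt.module_clone(net, by='reference')

# by='copy': cloned tensors get a <CloneBackward> grad_fn pointing back
# to the originals, so gradients still flow into `net`.
copied = torchopt.module_clone(net, by='copy')

# by='deepcopy': deep copies, detached from the original computation graph.
deep = torchopt.module_clone(net, by='deepcopy')
assert all(p.grad_fn is None for p in deep.parameters())
```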
@@ -499,7 +491,7 @@ def module_detach_(target: nn.Module) -> nn.Module: # pragma: no cover
@overload
-def module_detach_(target: 'MetaOptimizer') -> 'MetaOptimizer': # pragma: no cover
+def module_detach_(target: MetaOptimizer) -> MetaOptimizer: # pragma: no cover
...
@@ -509,12 +501,13 @@ def module_detach_(target: TensorTree) -> TensorTree: # pragma: no cover
def module_detach_(
- target: Union[ModuleState, nn.Module, 'MetaOptimizer', TensorTree]
-) -> Union[ModuleState, nn.Module, 'MetaOptimizer', TensorTree]:
+ target: ModuleState | nn.Module | MetaOptimizer | TensorTree,
+) -> ModuleState | nn.Module | MetaOptimizer | TensorTree:
"""Detach a module from the computation graph.
Args:
- target: The target to be detached.
+ target (ModuleState, nn.Module, MetaOptimizer, or tree of Tensor): The
+ target to be detached.
Returns:
The detached module.
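Since the trailing underscore signals an in-place operation (PyTorch convention), a short sketch of the tree-of-tensors overload may help; treating `torchopt.module_detach_` as a top-level export is an assumption here:

```python
# Hedged sketch of in-place detaching over a tensor tree; the TensorTree
# overload is taken from the signatures in the diff above.
import torch
import torchopt

leaf = torch.ones(2, requires_grad=True)
tree = {'a': leaf * 2.0, 'b': (leaf + 1.0,)}  # non-leaf tensors
torchopt.module_detach_(tree)                 # detaches each tensor in place
assert tree['a'].grad_fn is None              # no longer on the graph
```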
diff --git a/torchopt/visual.py b/torchopt/visual.py
index e8145240..7afe65a4 100644
--- a/torchopt/visual.py
+++ b/torchopt/visual.py
@@ -17,8 +17,10 @@
# ==============================================================================
"""Computation graph visualization."""
+from __future__ import annotations
+
from collections import namedtuple
-from typing import Generator, Iterable, Mapping, Optional, Union, cast
+from typing import Generator, Iterable, Mapping, cast
import torch
from graphviz import Digraph
@@ -71,14 +73,13 @@ def truncate(s): # pylint: disable=invalid-name
# pylint: disable-next=too-many-branches,too-many-statements,too-many-locals
def make_dot(
var: TensorOrTensors,
- params: Optional[
- Union[
- Mapping[str, torch.Tensor],
- ModuleState,
- Generator,
- Iterable[Union[Mapping[str, torch.Tensor], ModuleState, Generator]],
- ]
- ] = None,
+ params: (
+ Mapping[str, torch.Tensor]
+ | ModuleState
+ | Generator
+ | Iterable[Mapping[str, torch.Tensor] | ModuleState | Generator]
+ | None
+ ) = None,
show_attrs: bool = False,
show_saved: bool = False,
max_attr_chars: int = 50,
@@ -89,7 +90,7 @@ def make_dot(
and is either blue, orange, or green:
- **Blue**
- Reachable leaf tensors that requires grad (tensors whose :attr:`grad` fields will be
+ Reachable leaf tensors that require grad (tensors whose ``grad`` fields will be
populated during :meth:`backward`).
- **Orange**
Saved tensors of custom autograd functions as well as those saved by built-in backward
@@ -100,16 +101,16 @@ def make_dot(
If any output is a view, we represent its base tensor with a dark green node.
Args:
- var: Output tensor.
- params: ([dict of (name, tensor) or state_dict])
- Parameters to add names to node that requires grad.
- show_attrs: Whether to display non-tensor attributes of backward nodes
- (Requires PyTorch version >= 1.9)
- show_saved: Whether to display saved tensor nodes that are not by custom autograd
- functions. Saved tensor nodes for custom functions, if present, are always displayed.
- (Requires PyTorch version >= 1.9)
- max_attr_chars: If ``show_attrs`` is :data:`True`, sets max number of characters to display
- for any given attribute.
+ var (Tensor or sequence of Tensor): Output tensor.
+ params (dict[str, Tensor], ModuleState, iterable of tuple[str, Tensor], or None, optional):
+ Parameters to add names to nodes that require grad. (default: :data:`None`)
+ show_attrs (bool, optional): Whether to display non-tensor attributes of backward nodes.
+ (default: :data:`False`)
+ show_saved (bool, optional): Whether to display saved tensor nodes that are not by custom
+ autograd functions. Saved tensor nodes for custom functions, if present, are always
+ displayed. (default: :data:`False`)
+ max_attr_chars (int, optional): If ``show_attrs`` is :data:`True`, sets max number of
+ characters to display for any given attribute. (default: :const:`50`)
"""
param_map = {}
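For reference, an illustrative call matching the updated signature; this is a sketch that assumes the Graphviz system binaries are installed and relies on `make_dot` returning a `graphviz.Digraph` (consistent with the `Digraph` import above):

```python
# Illustrative sketch of the documented parameters; the output filename is
# arbitrary. Requires the graphviz Python package and system binaries.
import torch
import torch.nn as nn
from torchopt.visual import make_dot

net = nn.Linear(4, 1)
loss = net(torch.randn(2, 4)).sum()

# `params` as a (name -> tensor) mapping labels the blue leaf nodes.
dot = make_dot(loss, params=dict(net.named_parameters()), max_attr_chars=50)
dot.render('computation_graph', format='svg')  # writes computation_graph.svg
```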