Skip to content

fix: fix lr scheduling #76

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 9, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix: fix lr scheduling
  • Loading branch information
XuehaiPan committed Sep 9, 2022
commit 7930a5fed86da0a8ec53be39297256313656f7d7
10 changes: 3 additions & 7 deletions torchopt/_src/alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,10 @@ def _scale_by_neg_lr(lr: ScalarOrSchedule):
if callable(lr):

def schedule_wrapper(count):
def f(scaled_lr):
return -scaled_lr
return -lr(count) # type: ignore[operator]

return transform.map_flattened(f, lr(count)) # type: ignore[operator]

return transform._scale_by_schedule( # pylint: disable=protected-access
schedule_wrapper, already_flattened=True
)
# pylint: disable-next=protected-access
return transform._scale_by_schedule(schedule_wrapper, already_flattened=True)
return transform._scale(-lr, already_flattened=True) # pylint: disable=protected-access


Expand Down
12 changes: 6 additions & 6 deletions torchopt/_src/implicit_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def _split_tensor_and_others(
non_tensors = []
is_tensor_mask = []
for item in flattened:
if isinstance(item, torch.Tensor):
if torch.is_tensor(item):
tensors.append(item)
is_tensor_mask.append(True)
else:
Expand Down Expand Up @@ -236,18 +236,18 @@ def forward(ctx, *flat_args): # pylint: disable=arguments-differ
if has_aux:
aux = res[1]
res = res[0]
if isinstance(res, torch.Tensor):
if torch.is_tensor(res):
ctx.save_for_backward(res, *args_tensors)
else:
ctx.save_for_backward(*res, *args_tensors)
ctx.res_is_tensor = isinstance(res, torch.Tensor)
ctx.res_is_tensor = torch.is_tensor(res)
return res + (aux,)

if isinstance(res, torch.Tensor):
if torch.is_tensor(res):
ctx.save_for_backward(res, *args_tensors)
else:
ctx.save_for_backward(*res, *args_tensors)
ctx.res_is_tensor = isinstance(res, torch.Tensor)
ctx.res_is_tensor = torch.is_tensor(res)
return res

@staticmethod
Expand Down Expand Up @@ -314,7 +314,7 @@ def wrapped_solver_fun(*args, **kwargs):
args_counter = 0
for idx, arg in enumerate(args):
if idx in argnums:
if isinstance(arg, torch.Tensor):
if torch.is_tensor(arg):
args_sign.append((args_counter, False)) # start position, is_tuple
flatten_args.append(arg)
args_counter += 1
Expand Down
2 changes: 1 addition & 1 deletion torchopt/_src/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _normalize_matvec(
if callable(f):
return f

assert isinstance(f, torch.Tensor)
assert torch.is_tensor(f)
if f.ndim != 2 or f.shape[0] != f.shape[1]:
raise ValueError(f'linear operator must be a square matrix, but has shape: {f.shape}')
return partial(torch.matmul, f)
Expand Down
16 changes: 7 additions & 9 deletions torchopt/_src/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@
import logging

import numpy as np
import torch

from torchopt._src import base
from torchopt._src.typing import Scalar
from torchopt._src.utils import pytree
from torchopt._src.typing import Numeric, Scalar


def polynomial_schedule(
Expand Down Expand Up @@ -80,13 +80,11 @@ def polynomial_schedule(
)
transition_begin = 0

def schedule(count):
def impl(count):
count = np.clip(count - transition_begin, 0, transition_steps)
frac = 1 - count / transition_steps
return (init_value - end_value) * (frac**power) + end_value

return pytree.tree_map(impl, count)
def schedule(count: Numeric) -> Numeric:
clip = torch.clamp if torch.is_tensor(count) else np.clip
count = clip(count - transition_begin, 0, transition_steps) # type: ignore[operator]
frac = 1.0 - count / transition_steps
return (init_value - end_value) * (frac**power) + end_value

return schedule

Expand Down
17 changes: 11 additions & 6 deletions torchopt/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,20 +166,25 @@ def init_fn(params):
return ScaleByScheduleState(count=zero)

def update_fn(updates, state, *, params=None, inplace=True): # pylint: disable=unused-argument
step_size = step_size_fn(state.count)

if inplace:

def f(g):
def f(g, c):
step_size = step_size_fn(c)
return g.mul_(step_size) if g is not None else None

else:

def f(g):
def f(g, c):
step_size = step_size_fn(c)
return g.mul(step_size) if g is not None else None

updates = tree_map(f, updates)
return updates, ScaleByScheduleState(count=inc_count(updates, state.count))
updates = tree_map(f, updates, state.count)
return (
updates,
ScaleByScheduleState(
count=_inc_count(updates, state.count, already_flattened=already_flattened)
),
)

return base.GradientTransformation(init_fn, update_fn)

Expand Down
4 changes: 2 additions & 2 deletions torchopt/_src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def stop_gradient(target):
from torchopt._src.optimizer.meta.base import MetaOptimizer

def f(obj):
if isinstance(obj, torch.Tensor):
if torch.is_tensor(obj):
requires_grad = obj.requires_grad
obj.detach_().requires_grad_(requires_grad)

Expand Down Expand Up @@ -134,7 +134,7 @@ def _update(term):
if copy:

def get_variable(t):
if not isinstance(t, torch.Tensor):
if not torch.is_tensor(t):
return t
requires_grad = t.requires_grad
return t.clone().detach_().requires_grad_(requires_grad)
Expand Down
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy