Skip to content

test: refactor tests using pytest.mark.parametrize #55

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 40 commits into from
Aug 24, 2022
Merged
Changes from 1 commit
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
dfb1b63
test: use `torch.allclose` to compare tensors
XuehaiPan Aug 8, 2022
63d5012
test: refactor tests using `pytest.mark.parametrize`
XuehaiPan Aug 8, 2022
d2ea7fc
chore: update test ids
XuehaiPan Aug 8, 2022
29284c4
chore: update test ids
XuehaiPan Aug 8, 2022
591999f
chore(workflows): increase timeout
XuehaiPan Aug 8, 2022
803c7b9
test: reorganize tests
XuehaiPan Aug 8, 2022
99a561d
feat: parallel testing
XuehaiPan Aug 8, 2022
5b7dedb
test: reorganize tests
XuehaiPan Aug 8, 2022
4992712
test: loop for more update iterations
XuehaiPan Aug 8, 2022
85bddb0
test: reduce number of tests
XuehaiPan Aug 8, 2022
92fa5c7
test: reduce number of tests
XuehaiPan Aug 8, 2022
acdb41a
test: update assert
XuehaiPan Aug 8, 2022
cf45c8e
test: rename variable
XuehaiPan Aug 9, 2022
aba2efc
test: update assert
XuehaiPan Aug 9, 2022
2d28f49
docs(CHANGELOG): update CHANGELOG.md
XuehaiPan Aug 10, 2022
2745eb1
test: cov project name
XuehaiPan Aug 10, 2022
de69946
test: update tests
XuehaiPan Aug 11, 2022
468c4dc
test: set `CUBLAS_WORKSPACE_CONFIG`
XuehaiPan Aug 12, 2022
c7e72c5
test: compare diffs
XuehaiPan Aug 15, 2022
83717bb
test: update tol
XuehaiPan Aug 15, 2022
998b400
test: test float32
XuehaiPan Aug 15, 2022
933eb5d
test: show tol in assert messages
XuehaiPan Aug 15, 2022
876e9d3
test: use smaller network for tests
XuehaiPan Aug 16, 2022
9450669
Merge remote-tracking branch 'upstream/main' into fix-tests
Benjamin-eecs Aug 22, 2022
7034ec3
fix: change the order of add and multiply in sgd
Benjamin-eecs Aug 22, 2022
4b30fc0
Merge branch 'main' into fix-tests
XuehaiPan Aug 22, 2022
1b0a35c
fix: correct test writing and pass sgd
Benjamin-eecs Aug 23, 2022
2a09c8c
Merge branch 'fix-tests' of https://github.com/XuehaiPan/TorchOpt int…
Benjamin-eecs Aug 23, 2022
d6179f6
fix: correct test writing and fix other optims
Benjamin-eecs Aug 23, 2022
48f4a94
to(tests): high level non differentiable optimizer unfixed
Benjamin-eecs Aug 23, 2022
4ddf147
to(tests): high level non differentiable optimizer unfixed
Benjamin-eecs Aug 23, 2022
d77cdd8
to(tests): high level non differentiable optimizer unfixed
Benjamin-eecs Aug 23, 2022
f32ef9f
test: disable parallel testing
XuehaiPan Aug 24, 2022
629a36f
fix(transform): fix momentum trace
XuehaiPan Aug 24, 2022
ca580a6
refactor: refactor transform and utils
XuehaiPan Aug 24, 2022
6771907
test: test inplace operators
XuehaiPan Aug 24, 2022
579e786
fix: fix RMSProp optimizer
XuehaiPan Aug 24, 2022
0cb3c44
fix: fix Makefile
XuehaiPan Aug 24, 2022
51d15b3
lint: appease linters
XuehaiPan Aug 24, 2022
1e1d86f
chore: update unused function
XuehaiPan Aug 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix(transform): fix momentum trace
  • Loading branch information
XuehaiPan committed Aug 24, 2022
commit 629a36f1c0bb5a3b840342fbe4b4b4c7c564e28d
17 changes: 8 additions & 9 deletions torchopt/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def update_fn(updates, state, inplace=True):
if inplace:

def f1(g, t):
return t.copy_(t.mul_(decay).add_(g, alpha=1.0))
return t.mul_(decay).add_(g)

def f2(g, t):
return g.add_(t, alpha=decay)
Expand All @@ -369,21 +369,20 @@ def f2(g, t):
if inplace:

def f(g, t):
return t.mul_(decay).add_(g, alpha=1.0)
return t.mul_(decay).add_(g)

def copy_(t, g):
t.copy_(g)
def copy_(g, t):
return g.copy_(t)

updates = pytree.tree_map(f, updates, state.trace)
pytree.tree_map(copy_, state.trace, updates)
new_trace = state.trace
new_trace = pytree.tree_map(f, updates, state.trace)
updates = pytree.tree_map(copy_, updates, state.trace)
else:

def f(g, t):
return t.mul(decay).add(g)

updates = pytree.tree_map(f, updates, state.trace)
new_trace = updates
new_trace = pytree.tree_map(f, updates, state.trace)
updates = new_trace

return updates, TraceState(trace=new_trace)

Expand Down
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy