Skip to content

Commit 7737f70

Browse files
committed
test: update tests
1 parent 0439f57 commit 7737f70

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

tests/test_accelerated_op.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,47 @@
2121
import torchopt
2222

2323

24+
try:
25+
import torchopt._C.adam_op
26+
except ImportError:
27+
CXX_ACCELERATED_OP_AVAILABLE = False
28+
else:
29+
CXX_ACCELERATED_OP_AVAILABLE = True
30+
31+
32+
def test_accelerated_op_is_available():
33+
assert torchopt.accelerated_op_available('cpu')
34+
assert torchopt.accelerated_op_available(torch.device('cpu'))
35+
36+
if CXX_ACCELERATED_OP_AVAILABLE:
37+
assert not torchopt.accelerated_op_available('meta')
38+
assert not torchopt.accelerated_op_available(torch.device('meta'))
39+
assert not torchopt.accelerated_op_available(['cpu', 'meta'])
40+
assert not torchopt.accelerated_op_available([torch.device('cpu'), torch.device('meta')])
41+
else:
42+
assert torchopt.accelerated_op_available('meta')
43+
assert torchopt.accelerated_op_available(torch.device('meta'))
44+
assert torchopt.accelerated_op_available(['cpu', 'meta'])
45+
assert torchopt.accelerated_op_available([torch.device('cpu'), torch.device('meta')])
46+
47+
if torch.cuda.is_available():
48+
assert torchopt.accelerated_op_available()
49+
assert torchopt.accelerated_op_available('cuda')
50+
assert torchopt.accelerated_op_available('cuda:0')
51+
assert torchopt.accelerated_op_available(0)
52+
assert torchopt.accelerated_op_available(['cpu', 'cuda'])
53+
assert torchopt.accelerated_op_available(['cpu', 'cuda:0'])
54+
assert torchopt.accelerated_op_available(['cpu', 0])
55+
else:
56+
assert not torchopt.accelerated_op_available()
57+
assert not torchopt.accelerated_op_available('cuda')
58+
assert not torchopt.accelerated_op_available('cuda:0')
59+
assert not torchopt.accelerated_op_available(0)
60+
assert not torchopt.accelerated_op_available(['cpu', 'cuda'])
61+
assert not torchopt.accelerated_op_available(['cpu', 'cuda:0'])
62+
assert not torchopt.accelerated_op_available(['cpu', 0])
63+
64+
2465
@helpers.parametrize(
2566
dtype=[torch.float64, torch.float32],
2667
lr=[1e-2, 1e-3, 1e-4],

torchopt/accelerated_op/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,5 @@ def is_available(devices: Optional[Union[Device, Iterable[Device]]] = None) -> b
4141
updates = torch.tensor(1.0, device=device)
4242
op(updates, updates, updates, 1)
4343
return True
44-
except BaseException: # pylint: disable=broad-except
44+
except Exception: # pylint: disable=broad-except
4545
return False

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy