|
21 | 21 | import torchopt
|
22 | 22 |
|
23 | 23 |
|
# Probe for the compiled C++ Adam op extension; the flag records whether the
# native accelerated kernels were built into this torchopt installation.
try:
    import torchopt._C.adam_op

    CXX_ACCELERATED_OP_AVAILABLE = True
except ImportError:
    CXX_ACCELERATED_OP_AVAILABLE = False

| 31 | + |
def test_accelerated_op_is_available():
    """Check `torchopt.accelerated_op_available` across device specifications.

    CPU must always be reported as supported. The 'meta' device is supported
    only by the pure-Python fallback, so the expectation flips depending on
    whether the C++ extension was importable. CUDA queries mirror
    `torch.cuda.is_available()`.
    """
    # CPU is always usable, whichever backend is active.
    for cpu_spec in ('cpu', torch.device('cpu')):
        assert torchopt.accelerated_op_available(cpu_spec)

    meta_specs = (
        'meta',
        torch.device('meta'),
        ['cpu', 'meta'],
        [torch.device('cpu'), torch.device('meta')],
    )
    if CXX_ACCELERATED_OP_AVAILABLE:
        # Native extension present: meta tensors are not supported by it.
        for spec in meta_specs:
            assert not torchopt.accelerated_op_available(spec)
    else:
        # Pure-Python fallback: works on meta devices too.
        for spec in meta_specs:
            assert torchopt.accelerated_op_available(spec)

    cuda_specs = ('cuda', 'cuda:0', 0, ['cpu', 'cuda'], ['cpu', 'cuda:0'], ['cpu', 0])
    if torch.cuda.is_available():
        assert torchopt.accelerated_op_available()
        for spec in cuda_specs:
            assert torchopt.accelerated_op_available(spec)
    else:
        assert not torchopt.accelerated_op_available()
        for spec in cuda_specs:
            assert not torchopt.accelerated_op_available(spec)
| 63 | + |
| 64 | + |
24 | 65 | @helpers.parametrize(
|
25 | 66 | dtype=[torch.float64, torch.float32],
|
26 | 67 | lr=[1e-2, 1e-3, 1e-4],
|
|
0 commit comments