
Commit 84d1c3d

chore(accelerated_op): use correct Python Ctype for pybind11 function prototype (#52)

1 parent: 5b5b21d

File tree: 13 files changed (+162, -156 lines)

.github/workflows/lint.yml

Lines changed: 6 additions & 8 deletions

@@ -54,16 +54,9 @@ jobs:
         run: |
           python -m pip install --upgrade pip setuptools
 
-      - name: Install dependencies
-        run: |
-          python -m pip install --extra-index-url "${TORCH_INDEX_URL}" \
-            -r tests/requirements.txt
-          python -m pip install --extra-index-url "${TORCH_INDEX_URL}" \
-            -r docs/requirements.txt
-
       - name: Install TorchOpt
         run: |
-          python -m pip install -e .
+          python -m pip install -vvv -e '.[lint]'
 
       - name: pre-commit
         run: |
@@ -97,6 +90,11 @@ jobs:
         run: |
           make mypy
 
+      - name: Install dependencies
+        run: |
+          python -m pip install --extra-index-url "${TORCH_INDEX_URL}" \
+            -r docs/requirements.txt
+
       - name: docstyle
         run: |
           make docstyle

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion

@@ -72,7 +72,7 @@ jobs:
 
       - name: Install TorchOpt
         run: |
-          python -m pip install -e .
+          python -m pip install -vvv -e .
 
       - name: Test with pytest
         run: |

CHANGELOG.md

Lines changed: 4 additions & 0 deletions

@@ -15,6 +15,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 - Use [`cibuildwheel`](https://github.com/pypa/cibuildwheel) to build wheels by [@XuehaiPan](https://github.com/XuehaiPan) in [#45](https://github.com/metaopt/TorchOpt/pull/45).
 - Use dynamic process number in CPU kernels by [@JieRen98](https://github.com/JieRen98) in [#42](https://github.com/metaopt/TorchOpt/pull/42).
 
+### Changed
+
+- Use correct Python Ctype for pybind11 function prototype by [@XuehaiPan](https://github.com/XuehaiPan) in [#52](https://github.com/metaopt/TorchOpt/pull/52).
+
 ------
 
 ## [0.4.2] - 2022-07-26

conda-recipe.yaml

Lines changed: 1 addition & 1 deletion

@@ -76,7 +76,7 @@ dependencies:
     - mypy
     - flake8
     - flake8-bugbear
-    - doc8
+    - doc8 < 1.0.0a0
     - pydocstyle
     - clang-format
     - clang-tools  # clang-tidy

include/adam_op/adam_op.h

Lines changed: 15 additions & 12 deletions

@@ -23,32 +23,35 @@
 namespace torchopt {
 TensorArray<3> adamForwardInplace(const torch::Tensor& updates,
                                   const torch::Tensor& mu,
-                                  const torch::Tensor& nu, const float b1,
-                                  const float b2, const float eps,
-                                  const float eps_root, const int count);
+                                  const torch::Tensor& nu, const pyfloat_t b1,
+                                  const pyfloat_t b2, const pyfloat_t eps,
+                                  const pyfloat_t eps_root,
+                                  const pyuint_t count);
 
 torch::Tensor adamForwardMu(const torch::Tensor& updates,
-                            const torch::Tensor& mu, const float b1);
+                            const torch::Tensor& mu, const pyfloat_t b1);
 
 torch::Tensor adamForwardNu(const torch::Tensor& updates,
-                            const torch::Tensor& nu, const float b2);
+                            const torch::Tensor& nu, const pyfloat_t b2);
 
 torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu,
-                                 const torch::Tensor& new_nu, const float b1,
-                                 const float b2, const float eps,
-                                 const float eps_root, const int count);
+                                 const torch::Tensor& new_nu,
+                                 const pyfloat_t b1, const pyfloat_t b2,
+                                 const pyfloat_t eps, const pyfloat_t eps_root,
+                                 const pyuint_t count);
 
 TensorArray<2> adamBackwardMu(const torch::Tensor& dmu,
                               const torch::Tensor& updates,
-                              const torch::Tensor& mu, const float b1);
+                              const torch::Tensor& mu, const pyfloat_t b1);
 
 TensorArray<2> adamBackwardNu(const torch::Tensor& dnu,
                               const torch::Tensor& updates,
-                              const torch::Tensor& nu, const float b2);
+                              const torch::Tensor& nu, const pyfloat_t b2);
 
 TensorArray<2> adamBackwardUpdates(const torch::Tensor& dupdates,
                                    const torch::Tensor& updates,
                                    const torch::Tensor& new_mu,
-                                   const torch::Tensor& new_nu, const float b1,
-                                   const float b2, const int count);
+                                   const torch::Tensor& new_nu,
+                                   const pyfloat_t b1, const pyfloat_t b2,
+                                   const pyuint_t count);
 }  // namespace torchopt
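These are the prototypes ultimately exposed to Python through pybind11, so their scalar parameters must match the C types pybind11 actually produces from Python numbers. Below is a minimal binding sketch; the module name `adam_op` and the exported names are illustrative assumptions, not taken from this commit (the real bindings presumably live alongside the dispatchers in src/adam_op/adam_op.cpp):

// Hypothetical pybind11 binding sketch -- the module name and exported
// names here are assumptions for illustration only.
#include <pybind11/pybind11.h>

#include "include/adam_op/adam_op.h"

PYBIND11_MODULE(adam_op, m) {
  // pybind11 deduces the Python-facing signature from the C++ prototype,
  // so scalars arrive as pyfloat_t (double) and pyuint_t (std::size_t).
  m.def("forward_", &torchopt::adamForwardInplace);
  m.def("forward_mu", &torchopt::adamForwardMu);
  m.def("forward_nu", &torchopt::adamForwardNu);
  m.def("forward_updates", &torchopt::adamForwardUpdates);
  m.def("backward_mu", &torchopt::adamBackwardMu);
  m.def("backward_nu", &torchopt::adamBackwardNu);
  m.def("backward_updates", &torchopt::adamBackwardUpdates);
}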

include/adam_op/adam_op_impl_cpu.h

Lines changed: 15 additions & 14 deletions

@@ -21,35 +21,36 @@
 #include "include/common.h"
 
 namespace torchopt {
-TensorArray<3> adamForwardInplaceCPU(const torch::Tensor& updates,
-                                     const torch::Tensor& mu,
-                                     const torch::Tensor& nu, const float b1,
-                                     const float b2, const float eps,
-                                     const float eps_root, const int count);
+TensorArray<3> adamForwardInplaceCPU(
+    const torch::Tensor& updates, const torch::Tensor& mu,
+    const torch::Tensor& nu, const pyfloat_t b1, const pyfloat_t b2,
+    const pyfloat_t eps, const pyfloat_t eps_root, const pyuint_t count);
 
 torch::Tensor adamForwardMuCPU(const torch::Tensor& updates,
-                               const torch::Tensor& mu, const float b1);
+                               const torch::Tensor& mu, const pyfloat_t b1);
 
 torch::Tensor adamForwardNuCPU(const torch::Tensor& updates,
-                               const torch::Tensor& nu, const float b2);
+                               const torch::Tensor& nu, const pyfloat_t b2);
 
 torch::Tensor adamForwardUpdatesCPU(const torch::Tensor& new_mu,
-                                    const torch::Tensor& new_nu, const float b1,
-                                    const float b2, const float eps,
-                                    const float eps_root, const int count);
+                                    const torch::Tensor& new_nu,
+                                    const pyfloat_t b1, const pyfloat_t b2,
+                                    const pyfloat_t eps,
+                                    const pyfloat_t eps_root,
+                                    const pyuint_t count);
 
 TensorArray<2> adamBackwardMuCPU(const torch::Tensor& dmu,
                                  const torch::Tensor& updates,
-                                 const torch::Tensor& mu, const float b1);
+                                 const torch::Tensor& mu, const pyfloat_t b1);
 
 TensorArray<2> adamBackwardNuCPU(const torch::Tensor& dnu,
                                  const torch::Tensor& updates,
-                                 const torch::Tensor& nu, const float b2);
+                                 const torch::Tensor& nu, const pyfloat_t b2);
 
 TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor& dupdates,
                                       const torch::Tensor& updates,
                                       const torch::Tensor& new_mu,
                                       const torch::Tensor& new_nu,
-                                      const float b1, const float b2,
-                                      const int count);
+                                      const pyfloat_t b1, const pyfloat_t b2,
+                                      const pyuint_t count);
 }  // namespace torchopt

include/adam_op/adam_op_impl_cuda.cuh

Lines changed: 14 additions & 14 deletions

@@ -21,36 +21,36 @@
 #include "include/common.h"
 
 namespace torchopt {
-TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor &updates,
-                                      const torch::Tensor &mu,
-                                      const torch::Tensor &nu, const float b1,
-                                      const float b2, const float eps,
-                                      const float eps_root, const int count);
+TensorArray<3> adamForwardInplaceCUDA(
+    const torch::Tensor &updates, const torch::Tensor &mu,
+    const torch::Tensor &nu, const pyfloat_t b1, const pyfloat_t b2,
+    const pyfloat_t eps, const pyfloat_t eps_root, const pyuint_t count);
 
 torch::Tensor adamForwardMuCUDA(const torch::Tensor &updates,
-                                const torch::Tensor &mu, const float b1);
+                                const torch::Tensor &mu, const pyfloat_t b1);
 
 torch::Tensor adamForwardNuCUDA(const torch::Tensor &updates,
-                                const torch::Tensor &nu, const float b2);
+                                const torch::Tensor &nu, const pyfloat_t b2);
 
 torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor &new_mu,
                                      const torch::Tensor &new_nu,
-                                     const float b1, const float b2,
-                                     const float eps, const float eps_root,
-                                     const int count);
+                                     const pyfloat_t b1, const pyfloat_t b2,
+                                     const pyfloat_t eps,
+                                     const pyfloat_t eps_root,
+                                     const pyuint_t count);
 
 TensorArray<2> adamBackwardMuCUDA(const torch::Tensor &dmu,
                                   const torch::Tensor &updates,
-                                  const torch::Tensor &mu, const float b1);
+                                  const torch::Tensor &mu, const pyfloat_t b1);
 
 TensorArray<2> adamBackwardNuCUDA(const torch::Tensor &dnu,
                                   const torch::Tensor &updates,
-                                  const torch::Tensor &nu, const float b2);
+                                  const torch::Tensor &nu, const pyfloat_t b2);
 
 TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates,
                                        const torch::Tensor &updates,
                                        const torch::Tensor &new_mu,
                                        const torch::Tensor &new_nu,
-                                       const float b1, const float b2,
-                                       const int count);
+                                       const pyfloat_t b1, const pyfloat_t b2,
+                                       const pyuint_t count);
 }  // namespace torchopt

include/common.h

Lines changed: 4 additions & 0 deletions

@@ -17,6 +17,10 @@
 #include <torch/extension.h>
 
 #include <array>
+#include <cstddef>
+
+using pyfloat_t = double;
+using pyuint_t = std::size_t;
 
 namespace torchopt {
 template <size_t _Nm>
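These aliases pin down the C types that CPython actually hands to the extension: a Python `float` is a C `double`, and a non-negative Python `int` fits `std::size_t`. Declaring a prototype with plain `float` therefore forces a silent narrowing conversion on every call. A standalone sketch of that narrowing, assuming nothing from this repository beyond a C++ compiler:

// Demonstrates the narrowing that the old `float` prototypes caused.
#include <cstdio>

int main() {
  const double b1 = 0.9;  // a Python float arrives in C++ as a double
  const float narrowed = static_cast<float>(b1);  // the old prototype's view
  std::printf("as double: %.17g\n", b1);  // prints 0.90000000000000002
  std::printf("as float : %.17g\n",
              static_cast<double>(narrowed));  // prints 0.89999997615814209
  return 0;
}

With `pyfloat_t = double`, hyperparameters such as b1, b2, eps, and eps_root keep their full Python precision all the way into the kernels.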

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -66,7 +66,7 @@ lint = [
     "mypy",
     "flake8",
     "flake8-bugbear",
-    "doc8",
+    "doc8 < 1.0.0a0",
     "pydocstyle",
     "pyenchant",
     "cpplint",

src/adam_op/adam_op.cpp

Lines changed: 15 additions & 12 deletions

@@ -26,9 +26,10 @@
 namespace torchopt {
 TensorArray<3> adamForwardInplace(const torch::Tensor& updates,
                                   const torch::Tensor& mu,
-                                  const torch::Tensor& nu, const float b1,
-                                  const float b2, const float eps,
-                                  const float eps_root, const int count) {
+                                  const torch::Tensor& nu, const pyfloat_t b1,
+                                  const pyfloat_t b2, const pyfloat_t eps,
+                                  const pyfloat_t eps_root,
+                                  const pyuint_t count) {
 #if defined(__CUDACC__)
   if (updates.device().is_cuda()) {
     return adamForwardInplaceCUDA(updates, mu, nu, b1, b2, eps, eps_root,
@@ -42,7 +43,7 @@ TensorArray<3> adamForwardInplace(const torch::Tensor& updates,
   }
 }
 torch::Tensor adamForwardMu(const torch::Tensor& updates,
-                            const torch::Tensor& mu, const float b1) {
+                            const torch::Tensor& mu, const pyfloat_t b1) {
 #if defined(__CUDACC__)
   if (updates.device().is_cuda()) {
     return adamForwardMuCUDA(updates, mu, b1);
@@ -56,7 +57,7 @@ torch::Tensor adamForwardMu(const torch::Tensor& updates,
 }
 
 torch::Tensor adamForwardNu(const torch::Tensor& updates,
-                            const torch::Tensor& nu, const float b2) {
+                            const torch::Tensor& nu, const pyfloat_t b2) {
 #if defined(__CUDACC__)
   if (updates.device().is_cuda()) {
     return adamForwardNuCUDA(updates, nu, b2);
@@ -70,9 +71,10 @@ torch::Tensor adamForwardNu(const torch::Tensor& updates,
 }
 
 torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu,
-                                 const torch::Tensor& new_nu, const float b1,
-                                 const float b2, const float eps,
-                                 const float eps_root, const int count) {
+                                 const torch::Tensor& new_nu,
+                                 const pyfloat_t b1, const pyfloat_t b2,
+                                 const pyfloat_t eps, const pyfloat_t eps_root,
+                                 const pyuint_t count) {
 #if defined(__CUDACC__)
   if (new_mu.device().is_cuda()) {
     return adamForwardUpdatesCUDA(new_mu, new_nu, b1, b2, eps, eps_root, count);
@@ -87,7 +89,7 @@ torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu,
 
 TensorArray<2> adamBackwardMu(const torch::Tensor& dmu,
                               const torch::Tensor& updates,
-                              const torch::Tensor& mu, const float b1) {
+                              const torch::Tensor& mu, const pyfloat_t b1) {
 #if defined(__CUDACC__)
   if (dmu.device().is_cuda()) {
     return adamBackwardMuCUDA(dmu, updates, mu, b1);
@@ -102,7 +104,7 @@ TensorArray<2> adamBackwardMu(const torch::Tensor& dmu,
 
 TensorArray<2> adamBackwardNu(const torch::Tensor& dnu,
                               const torch::Tensor& updates,
-                              const torch::Tensor& nu, const float b2) {
+                              const torch::Tensor& nu, const pyfloat_t b2) {
 #if defined(__CUDACC__)
   if (dnu.device().is_cuda()) {
     return adamBackwardNuCUDA(dnu, updates, nu, b2);
@@ -118,8 +120,9 @@ TensorArray<2> adamBackwardNu(const torch::Tensor& dnu,
 TensorArray<2> adamBackwardUpdates(const torch::Tensor& dupdates,
                                    const torch::Tensor& updates,
                                    const torch::Tensor& new_mu,
-                                   const torch::Tensor& new_nu, const float b1,
-                                   const float b2, const int count) {
+                                   const torch::Tensor& new_nu,
+                                   const pyfloat_t b1, const pyfloat_t b2,
+                                   const pyuint_t count) {
 #if defined(__CUDACC__)
   if (dupdates.device().is_cuda()) {
     return adamBackwardUpdatesCUDA(dupdates, updates, new_mu, new_nu, b1, b2,
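Each dispatcher above follows the same pattern: when the translation unit is compiled by nvcc (`__CUDACC__` defined) and the input tensor lives on a CUDA device, the call routes to the CUDA kernel; otherwise it falls through to the CPU implementation. A distilled sketch of that pattern with a placeholder op (`op`, `opCPU`, and `opCUDA` are illustrative names, not from this commit; the real dispatchers also validate the device, which is elided here):

// Distilled sketch of the compile-time + runtime dispatch used above.
#include <torch/extension.h>

using pyfloat_t = double;  // as defined in include/common.h

torch::Tensor opCPU(const torch::Tensor& x, pyfloat_t b1);   // CPU kernel
torch::Tensor opCUDA(const torch::Tensor& x, pyfloat_t b1);  // CUDA kernel

torch::Tensor op(const torch::Tensor& x, const pyfloat_t b1) {
#if defined(__CUDACC__)
  // Compiled in only by nvcc: GPU tensors go to the CUDA kernel.
  if (x.device().is_cuda()) {
    return opCUDA(x, b1);
  }
#endif
  return opCPU(x, b1);  // everything else falls back to the CPU kernel
}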
