Skip to content

Commit 381c658

Browse files
committed
refactor: format code
1 parent 91c0567 commit 381c658

File tree

10 files changed

+490
-317
lines changed

10 files changed

+490
-317
lines changed

.clang-format

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
11
BasedOnStyle: Google
2+
ColumnLimit: 100
3+
BinPackArguments: false
4+
BinPackParameters: false

include/adam_op/adam_op.h

Lines changed: 49 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -21,40 +21,55 @@
2121
#include "include/common.h"
2222

2323
namespace torchopt {
24+
25+
namespace py = pybind11;
26+
2427
namespace adam_op {
25-
TensorArray<3> adamForwardInplace(const torch::Tensor& updates,
26-
const torch::Tensor& mu,
27-
const torch::Tensor& nu, const pyfloat_t& b1,
28-
const pyfloat_t& b2, const pyfloat_t& eps,
29-
const pyfloat_t& eps_root,
30-
const pyuint_t& count);
31-
32-
torch::Tensor adamForwardMu(const torch::Tensor& updates,
33-
const torch::Tensor& mu, const pyfloat_t& b1);
34-
35-
torch::Tensor adamForwardNu(const torch::Tensor& updates,
36-
const torch::Tensor& nu, const pyfloat_t& b2);
37-
38-
torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu,
39-
const torch::Tensor& new_nu,
40-
const pyfloat_t& b1, const pyfloat_t& b2,
41-
const pyfloat_t& eps,
42-
const pyfloat_t& eps_root,
43-
const pyuint_t& count);
44-
45-
TensorArray<2> adamBackwardMu(const torch::Tensor& dmu,
46-
const torch::Tensor& updates,
47-
const torch::Tensor& mu, const pyfloat_t& b1);
48-
49-
TensorArray<2> adamBackwardNu(const torch::Tensor& dnu,
50-
const torch::Tensor& updates,
51-
const torch::Tensor& nu, const pyfloat_t& b2);
52-
53-
TensorArray<2> adamBackwardUpdates(const torch::Tensor& dupdates,
54-
const torch::Tensor& updates,
55-
const torch::Tensor& new_mu,
56-
const torch::Tensor& new_nu,
57-
const pyfloat_t& b1, const pyfloat_t& b2,
58-
const pyuint_t& count);
28+
29+
TensorArray<3> adamForwardInplace(const torch::Tensor &updates,
30+
const torch::Tensor &mu,
31+
const torch::Tensor &nu,
32+
const pyfloat_t &b1,
33+
const pyfloat_t &b2,
34+
const pyfloat_t &eps,
35+
const pyfloat_t &eps_root,
36+
const pyuint_t &count);
37+
38+
torch::Tensor adamForwardMu(const torch::Tensor &updates,
39+
const torch::Tensor &mu,
40+
const pyfloat_t &b1);
41+
42+
torch::Tensor adamForwardNu(const torch::Tensor &updates,
43+
const torch::Tensor &nu,
44+
const pyfloat_t &b2);
45+
46+
torch::Tensor adamForwardUpdates(const torch::Tensor &new_mu,
47+
const torch::Tensor &new_nu,
48+
const pyfloat_t &b1,
49+
const pyfloat_t &b2,
50+
const pyfloat_t &eps,
51+
const pyfloat_t &eps_root,
52+
const pyuint_t &count);
53+
54+
TensorArray<2> adamBackwardMu(const torch::Tensor &dmu,
55+
const torch::Tensor &updates,
56+
const torch::Tensor &mu,
57+
const pyfloat_t &b1);
58+
59+
TensorArray<2> adamBackwardNu(const torch::Tensor &dnu,
60+
const torch::Tensor &updates,
61+
const torch::Tensor &nu,
62+
const pyfloat_t &b2);
63+
64+
TensorArray<2> adamBackwardUpdates(const torch::Tensor &dupdates,
65+
const torch::Tensor &updates,
66+
const torch::Tensor &new_mu,
67+
const torch::Tensor &new_nu,
68+
const pyfloat_t &b1,
69+
const pyfloat_t &b2,
70+
const pyuint_t &count);
71+
72+
void buildSubmodule(pybind11::module &mod);
73+
5974
} // namespace adam_op
6075
} // namespace torchopt

include/adam_op/adam_op_impl_cpu.h

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22,37 +22,47 @@
2222

2323
namespace torchopt {
2424
namespace adam_op {
25-
TensorArray<3> adamForwardInplaceCPU(
26-
const torch::Tensor& updates, const torch::Tensor& mu,
27-
const torch::Tensor& nu, const pyfloat_t& b1, const pyfloat_t& b2,
28-
const pyfloat_t& eps, const pyfloat_t& eps_root, const pyuint_t& count);
29-
30-
torch::Tensor adamForwardMuCPU(const torch::Tensor& updates,
31-
const torch::Tensor& mu, const pyfloat_t& b1);
32-
33-
torch::Tensor adamForwardNuCPU(const torch::Tensor& updates,
34-
const torch::Tensor& nu, const pyfloat_t& b2);
35-
36-
torch::Tensor adamForwardUpdatesCPU(const torch::Tensor& new_mu,
37-
const torch::Tensor& new_nu,
38-
const pyfloat_t& b1, const pyfloat_t& b2,
39-
const pyfloat_t& eps,
40-
const pyfloat_t& eps_root,
41-
const pyuint_t& count);
42-
43-
TensorArray<2> adamBackwardMuCPU(const torch::Tensor& dmu,
44-
const torch::Tensor& updates,
45-
const torch::Tensor& mu, const pyfloat_t& b1);
46-
47-
TensorArray<2> adamBackwardNuCPU(const torch::Tensor& dnu,
48-
const torch::Tensor& updates,
49-
const torch::Tensor& nu, const pyfloat_t& b2);
50-
51-
TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor& dupdates,
52-
const torch::Tensor& updates,
53-
const torch::Tensor& new_mu,
54-
const torch::Tensor& new_nu,
55-
const pyfloat_t& b1, const pyfloat_t& b2,
56-
const pyuint_t& count);
25+
TensorArray<3> adamForwardInplaceCPU(const torch::Tensor &updates,
26+
const torch::Tensor &mu,
27+
const torch::Tensor &nu,
28+
const pyfloat_t &b1,
29+
const pyfloat_t &b2,
30+
const pyfloat_t &eps,
31+
const pyfloat_t &eps_root,
32+
const pyuint_t &count);
33+
34+
torch::Tensor adamForwardMuCPU(const torch::Tensor &updates,
35+
const torch::Tensor &mu,
36+
const pyfloat_t &b1);
37+
38+
torch::Tensor adamForwardNuCPU(const torch::Tensor &updates,
39+
const torch::Tensor &nu,
40+
const pyfloat_t &b2);
41+
42+
torch::Tensor adamForwardUpdatesCPU(const torch::Tensor &new_mu,
43+
const torch::Tensor &new_nu,
44+
const pyfloat_t &b1,
45+
const pyfloat_t &b2,
46+
const pyfloat_t &eps,
47+
const pyfloat_t &eps_root,
48+
const pyuint_t &count);
49+
50+
TensorArray<2> adamBackwardMuCPU(const torch::Tensor &dmu,
51+
const torch::Tensor &updates,
52+
const torch::Tensor &mu,
53+
const pyfloat_t &b1);
54+
55+
TensorArray<2> adamBackwardNuCPU(const torch::Tensor &dnu,
56+
const torch::Tensor &updates,
57+
const torch::Tensor &nu,
58+
const pyfloat_t &b2);
59+
60+
TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor &dupdates,
61+
const torch::Tensor &updates,
62+
const torch::Tensor &new_mu,
63+
const torch::Tensor &new_nu,
64+
const pyfloat_t &b1,
65+
const pyfloat_t &b2,
66+
const pyuint_t &count);
5767
} // namespace adam_op
5868
} // namespace torchopt

include/adam_op/adam_op_impl_cuda.cuh

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,37 +22,47 @@
2222

2323
namespace torchopt {
2424
namespace adam_op {
25-
TensorArray<3> adamForwardInplaceCUDA(
26-
const torch::Tensor &updates, const torch::Tensor &mu,
27-
const torch::Tensor &nu, const pyfloat_t &b1, const pyfloat_t &b2,
28-
const pyfloat_t &eps, const pyfloat_t &eps_root, const pyuint_t &count);
25+
TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor &updates,
26+
const torch::Tensor &mu,
27+
const torch::Tensor &nu,
28+
const pyfloat_t &b1,
29+
const pyfloat_t &b2,
30+
const pyfloat_t &eps,
31+
const pyfloat_t &eps_root,
32+
const pyuint_t &count);
2933

3034
torch::Tensor adamForwardMuCUDA(const torch::Tensor &updates,
31-
const torch::Tensor &mu, const pyfloat_t &b1);
35+
const torch::Tensor &mu,
36+
const pyfloat_t &b1);
3237

3338
torch::Tensor adamForwardNuCUDA(const torch::Tensor &updates,
34-
const torch::Tensor &nu, const pyfloat_t &b2);
39+
const torch::Tensor &nu,
40+
const pyfloat_t &b2);
3541

3642
torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor &new_mu,
3743
const torch::Tensor &new_nu,
38-
const pyfloat_t &b1, const pyfloat_t &b2,
44+
const pyfloat_t &b1,
45+
const pyfloat_t &b2,
3946
const pyfloat_t &eps,
4047
const pyfloat_t &eps_root,
4148
const pyuint_t &count);
4249

4350
TensorArray<2> adamBackwardMuCUDA(const torch::Tensor &dmu,
4451
const torch::Tensor &updates,
45-
const torch::Tensor &mu, const pyfloat_t &b1);
52+
const torch::Tensor &mu,
53+
const pyfloat_t &b1);
4654

4755
TensorArray<2> adamBackwardNuCUDA(const torch::Tensor &dnu,
4856
const torch::Tensor &updates,
49-
const torch::Tensor &nu, const pyfloat_t &b2);
57+
const torch::Tensor &nu,
58+
const pyfloat_t &b2);
5059

5160
TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates,
5261
const torch::Tensor &updates,
5362
const torch::Tensor &new_mu,
5463
const torch::Tensor &new_nu,
55-
const pyfloat_t &b1, const pyfloat_t &b2,
64+
const pyfloat_t &b1,
65+
const pyfloat_t &b2,
5666
const pyuint_t &count);
5767
} // namespace adam_op
5868
} // namespace torchopt

include/common.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,6 @@ using pyfloat_t = double;
2323
using pyuint_t = std::size_t;
2424

2525
namespace torchopt {
26-
template <size_t _Nm>
27-
using TensorArray = std::array<torch::Tensor, _Nm>;
26+
template <size_t N>
27+
using TensorArray = std::array<torch::Tensor, N>;
2828
}

include/utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
#endif
2424

2525
namespace torchopt {
26-
__forceinline__ size_t getTensorPlainSize(const torch::Tensor& tensor) {
26+
__forceinline__ size_t getTensorPlainSize(const torch::Tensor &tensor) {
2727
const auto dim = tensor.dim();
2828
size_t n = 1;
2929
for (std::decay_t<decltype(dim)> i = 0; i < dim; ++i) {

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy