
Commit e9a260a

chore: remove unused code
1 parent 450e380 commit e9a260a

File tree

1 file changed: +0 -110 lines changed


src/adam_op/adam_op_impl_cpu.cpp

Lines changed: 0 additions & 110 deletions
@@ -27,34 +27,6 @@ using std::size_t;
 
 namespace adam_op {
 
-template <typename scalar_t, typename other_t>
-void adamForwardInplaceCPUKernel(const other_t b1,
-                                 const other_t inv_one_minus_pow_b1,
-                                 const other_t b2,
-                                 const other_t inv_one_minus_pow_b2,
-                                 const other_t eps,
-                                 const other_t eps_root,
-                                 const size_t n,
-                                 scalar_t *__restrict__ updates_ptr,
-                                 scalar_t *__restrict__ mu_ptr,
-                                 scalar_t *__restrict__ nu_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
-  for (size_t tid = 0; tid < n; ++tid) {
-    const scalar_t updates = updates_ptr[tid];
-    const scalar_t mu = mu_ptr[tid];
-    const scalar_t nu = nu_ptr[tid];
-
-    const scalar_t mu_out = b1 * mu + (1 - b1) * updates;
-    const scalar_t nu_out = b2 * nu + (1 - b2) * updates * updates;
-    const scalar_t updates_out =
-        mu_out * inv_one_minus_pow_b1 / (sqrt(nu_out * inv_one_minus_pow_b2 + eps_root) + eps);
-
-    mu_ptr[tid] = mu_out;
-    nu_ptr[tid] = nu_out;
-    updates_ptr[tid] = updates_out;
-  }
-}
-
 TensorArray<3> adamForwardInplaceCPU(const torch::Tensor &updates,
                                      const torch::Tensor &mu,
                                      const torch::Tensor &nu,
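
The deleted adamForwardInplaceCPUKernel fused the whole per-element Adam step: the exponential moving averages of the gradient and its square, followed by the bias-corrected update. As a reference for the arithmetic only, here is a minimal standalone sketch; the function name and plain double-pointer interface are hypothetical and not part of this repository.

// Illustrative only: restates the per-element math of the deleted kernel
// with hypothetical names; it is not TorchOpt code.
#include <cmath>
#include <cstddef>

void adam_step_inplace(double b1, double inv_one_minus_pow_b1,
                       double b2, double inv_one_minus_pow_b2,
                       double eps, double eps_root, std::size_t n,
                       double *updates, double *mu, double *nu) {
  for (std::size_t i = 0; i < n; ++i) {
    // Exponential moving averages of the gradient and its square.
    const double mu_out = b1 * mu[i] + (1 - b1) * updates[i];
    const double nu_out = b2 * nu[i] + (1 - b2) * updates[i] * updates[i];
    // Bias-corrected step: mu_hat / (sqrt(nu_hat + eps_root) + eps).
    updates[i] = mu_out * inv_one_minus_pow_b1 /
                 (std::sqrt(nu_out * inv_one_minus_pow_b2 + eps_root) + eps);
    mu[i] = mu_out;
    nu[i] = nu_out;
  }
}
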
@@ -82,21 +54,6 @@ TensorArray<3> adamForwardInplaceCPU(const torch::Tensor &updates,
   return TensorArray<3>{updates, mu, nu};
 }
 
-template <typename scalar_t, typename other_t>
-void adamForwardMuCPUKernel(const scalar_t *__restrict__ updates_ptr,
-                            const scalar_t *__restrict__ mu_ptr,
-                            const other_t b1,
-                            const size_t n,
-                            scalar_t *__restrict__ mu_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
-  for (size_t tid = 0; tid < n; ++tid) {
-    const scalar_t updates = updates_ptr[tid];
-    const scalar_t mu = mu_ptr[tid];
-    const scalar_t mu_out = b1 * mu + (1 - b1) * updates;
-    mu_out_ptr[tid] = mu_out;
-  }
-}
-
 torch::Tensor adamForwardMuCPU(const torch::Tensor &updates,
                                const torch::Tensor &mu,
                                const pyfloat_t b1) {
@@ -108,22 +65,6 @@ torch::Tensor adamForwardMuCPU(const torch::Tensor &updates,
   return mu_out;
 }
 
-template <typename scalar_t, typename other_t>
-void adamForwardNuCPUKernel(const scalar_t *__restrict__ updates_ptr,
-                            const scalar_t *__restrict__ nu_ptr,
-                            const other_t b2,
-                            const size_t n,
-                            scalar_t *__restrict__ nu_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
-  for (size_t tid = 0; tid < n; ++tid) {
-    const scalar_t updates = updates_ptr[tid];
-    const scalar_t nu = nu_ptr[tid];
-
-    const scalar_t nu_out = b2 * nu + (1 - b2) * pow(updates, 2);
-    nu_out_ptr[tid] = nu_out;
-  }
-}
-
 torch::Tensor adamForwardNuCPU(const torch::Tensor &updates,
                                const torch::Tensor &nu,
                                const pyfloat_t b2) {
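
The two moment kernels deleted above, adamForwardMuCPUKernel and adamForwardNuCPUKernel, were each a single exponential moving average over the flattened tensor, parallelized with the same OpenMP pragma (the nu kernel squared the gradient via pow(updates, 2), which is equivalent to updates * updates). A hedged sketch of that shared pattern follows; the name ema_update is illustrative, and it assumes OpenMP 3.0 or newer so the size_t loop variable is permitted.

// Sketch of the shared pattern: an elementwise EMA over n values spread
// across all available cores. Only the formulas come from the deleted
// kernels (mu: b1*m + (1-b1)*g; nu: same with g replaced by g*g).
#include <cstddef>
#include <omp.h>

void ema_update(const double *g, const double *m, double beta, std::size_t n,
                double *out) {
#pragma omp parallel for num_threads(omp_get_num_procs())
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = beta * m[i] + (1 - beta) * g[i];
  }
}
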
@@ -136,25 +77,6 @@ torch::Tensor adamForwardNuCPU(const torch::Tensor &updates,
   return nu_out;
 }
 
-template <typename scalar_t, typename other_t>
-void adamForwardUpdatesCPUKernel(const scalar_t *__restrict__ new_mu_ptr,
-                                 const scalar_t *__restrict__ new_nu_ptr,
-                                 const other_t inv_one_minus_pow_b1,
-                                 const other_t inv_one_minus_pow_b2,
-                                 const other_t eps,
-                                 const other_t eps_root,
-                                 const size_t n,
-                                 scalar_t *__restrict__ updates_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
-  for (size_t tid = 0; tid < n; ++tid) {
-    const scalar_t new_mu = new_mu_ptr[tid];
-    const scalar_t new_nu = new_nu_ptr[tid];
-    const scalar_t mu_hat = new_mu * inv_one_minus_pow_b1;
-    const scalar_t nu_hat = new_nu * inv_one_minus_pow_b2;
-    updates_out_ptr[tid] = mu_hat / (sqrt(nu_hat + eps_root) + eps);
-  }
-}
-
 torch::Tensor adamForwardUpdatesCPU(const torch::Tensor &new_mu,
                                     const torch::Tensor &new_nu,
                                     const pyfloat_t b1,
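
The inv_one_minus_pow_b1 and inv_one_minus_pow_b2 parameters of the deleted adamForwardUpdatesCPUKernel are the standard Adam bias-correction factors 1/(1 - b1^t) and 1/(1 - b2^t). The diff does not show where they are computed; the sketch below shows how such factors are typically derived from the step count, with a hypothetical helper name and plain doubles rather than this file's types.

// Hypothetical helper: 1 / (1 - b^t) for step t >= 1. These are the values
// the deleted kernel received as inv_one_minus_pow_b1 / inv_one_minus_pow_b2.
#include <cmath>
#include <utility>

std::pair<double, double> bias_correction(double b1, double b2, long step) {
  return {1.0 / (1.0 - std::pow(b1, step)),
          1.0 / (1.0 - std::pow(b2, step))};
}
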
@@ -181,21 +103,6 @@ torch::Tensor adamForwardUpdatesCPU(const torch::Tensor &new_mu,
   return updates_out;
 }
 
-template <typename scalar_t, typename other_t>
-void adamBackwardMuCPUKernel(const scalar_t *__restrict__ dmu_ptr,
-                             const other_t b1,
-                             const size_t n,
-                             scalar_t *__restrict__ dupdates_out_ptr,
-                             scalar_t *__restrict__ dmu_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
-  for (size_t tid = 0; tid < n; ++tid) {
-    const scalar_t dmu = dmu_ptr[tid];
-
-    dupdates_out_ptr[tid] = (1 - b1) * dmu;
-    dmu_out_ptr[tid] = b1 * dmu;
-  }
-}
-
 TensorArray<2> adamBackwardMuCPU(const torch::Tensor &dmu,
                                  const torch::Tensor &updates,
                                  const torch::Tensor &mu,
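
The deleted adamBackwardMuCPUKernel follows directly from differentiating the forward rule mu_out = b1 * mu + (1 - b1) * updates: the partial derivative with respect to updates is (1 - b1) and with respect to mu is b1, so the incoming gradient dmu is simply scaled by each factor. A minimal sketch of that rule, with illustrative names only:

// Backward of mu_out = b1 * mu + (1 - b1) * updates:
//   d(updates) = (1 - b1) * dmu_out,  d(mu) = b1 * dmu_out.
#include <cstddef>

void adam_backward_mu(const double *dmu, double b1, std::size_t n,
                      double *dupdates_out, double *dmu_out) {
  for (std::size_t i = 0; i < n; ++i) {
    dupdates_out[i] = (1 - b1) * dmu[i];
    dmu_out[i] = b1 * dmu[i];
  }
}
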
@@ -210,23 +117,6 @@ TensorArray<2> adamBackwardMuCPU(const torch::Tensor &dmu,
   return TensorArray<2>{std::move(dupdates_out), std::move(dmu_out)};
 }
 
-template <typename scalar_t, typename other_t>
-void adamBackwardNuCPUKernel(const scalar_t *__restrict__ dnu_ptr,
-                             const scalar_t *__restrict__ updates_ptr,
-                             const other_t b2,
-                             const size_t n,
-                             scalar_t *__restrict__ dupdates_out_ptr,
-                             scalar_t *__restrict__ dnu_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
-  for (size_t tid = 0; tid < n; ++tid) {
-    const scalar_t dnu = dnu_ptr[tid];
-    const scalar_t updates = updates_ptr[tid];
-
-    dupdates_out_ptr[tid] = 2 * (1 - b2) * updates * dnu;
-    dnu_out_ptr[tid] = b2 * dnu;
-  }
-}
-
 TensorArray<2> adamBackwardNuCPU(const torch::Tensor &dnu,
                                  const torch::Tensor &updates,
                                  const torch::Tensor &nu,
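
Likewise, the deleted adamBackwardNuCPUKernel differentiates nu_out = b2 * nu + (1 - b2) * updates^2: the updates branch picks up 2 * (1 - b2) * updates and the nu branch picks up b2. A hedged sketch with standalone, non-repository names:

// Backward of nu_out = b2 * nu + (1 - b2) * updates^2:
//   d(updates) = 2 * (1 - b2) * updates * dnu_out,  d(nu) = b2 * dnu_out.
#include <cstddef>

void adam_backward_nu(const double *dnu, const double *updates, double b2,
                      std::size_t n, double *dupdates_out, double *dnu_out) {
  for (std::size_t i = 0; i < n; ++i) {
    dupdates_out[i] = 2 * (1 - b2) * updates[i] * dnu[i];
    dnu_out[i] = b2 * dnu[i];
  }
}
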
