@@ -27,34 +27,6 @@ using std::size_t;
 
 namespace adam_op {
 
-template <typename scalar_t, typename other_t>
-void adamForwardInplaceCPUKernel(const other_t b1,
-                                 const other_t inv_one_minus_pow_b1,
-                                 const other_t b2,
-                                 const other_t inv_one_minus_pow_b2,
-                                 const other_t eps,
-                                 const other_t eps_root,
-                                 const size_t n,
-                                 scalar_t *__restrict__ updates_ptr,
-                                 scalar_t *__restrict__ mu_ptr,
-                                 scalar_t *__restrict__ nu_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
-  for (size_t tid = 0; tid < n; ++tid) {
-    const scalar_t updates = updates_ptr[tid];
-    const scalar_t mu = mu_ptr[tid];
-    const scalar_t nu = nu_ptr[tid];
-
-    const scalar_t mu_out = b1 * mu + (1 - b1) * updates;
-    const scalar_t nu_out = b2 * nu + (1 - b2) * updates * updates;
-    const scalar_t updates_out =
-        mu_out * inv_one_minus_pow_b1 / (sqrt(nu_out * inv_one_minus_pow_b2 + eps_root) + eps);
-
-    mu_ptr[tid] = mu_out;
-    nu_ptr[tid] = nu_out;
-    updates_ptr[tid] = updates_out;
-  }
-}
-
 TensorArray<3> adamForwardInplaceCPU(const torch::Tensor &updates,
                                      const torch::Tensor &mu,
                                      const torch::Tensor &nu,
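The fused kernel removed above made a single OpenMP pass over the flattened tensors, advancing both Adam moment estimates and overwriting updates_ptr with the bias-corrected step; the caller precomputes inv_one_minus_pow_b1 = 1/(1 - b1^t) and inv_one_minus_pow_b2 = 1/(1 - b2^t). A minimal standalone sketch of the same per-element arithmetic, with double standing in for the scalar_t/other_t template parameters and the function name chosen here purely for illustration:

  #include <cmath>
  #include <cstddef>

  // Sketch only: mirrors the deleted kernel's math on plain doubles.
  void adam_inplace_sketch(const double b1, const double b2, const double eps,
                           const double eps_root, const double inv_one_minus_pow_b1,
                           const double inv_one_minus_pow_b2, const std::size_t n,
                           double *g, double *m, double *v) {
    for (std::size_t i = 0; i < n; ++i) {
      const double m_out = b1 * m[i] + (1 - b1) * g[i];         // first-moment EMA
      const double v_out = b2 * v[i] + (1 - b2) * g[i] * g[i];  // second-moment EMA
      g[i] = m_out * inv_one_minus_pow_b1 /
             (std::sqrt(v_out * inv_one_minus_pow_b2 + eps_root) + eps);
      m[i] = m_out;  // moments are updated in place, as in the original
      v[i] = v_out;
    }
  }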
@@ -82,21 +54,6 @@ TensorArray<3> adamForwardInplaceCPU(const torch::Tensor &updates,
   return TensorArray<3>{updates, mu, nu};
 }
 
-template <typename scalar_t, typename other_t>
-void adamForwardMuCPUKernel(const scalar_t *__restrict__ updates_ptr,
-                            const scalar_t *__restrict__ mu_ptr,
-                            const other_t b1,
-                            const size_t n,
-                            scalar_t *__restrict__ mu_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
-  for (size_t tid = 0; tid < n; ++tid) {
-    const scalar_t updates = updates_ptr[tid];
-    const scalar_t mu = mu_ptr[tid];
-    const scalar_t mu_out = b1 * mu + (1 - b1) * updates;
-    mu_out_ptr[tid] = mu_out;
-  }
-}
-
 torch::Tensor adamForwardMuCPU(const torch::Tensor &updates,
                                const torch::Tensor &mu,
                                const pyfloat_t b1) {
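The wrapper adamForwardMuCPU that survives the commit computes the first-moment EMA mu_out = b1 * mu + (1 - b1) * updates over a tensor. Its body is elided between the hunks, so as a hedged guess at the shape of such a call site, here is how a per-dtype kernel is typically driven through ATen's dispatch macro; the name forward_mu_sketch and the serial loop are illustrative assumptions, not the repository's code:

  #include <torch/extension.h>

  // Sketch: dtype dispatch around the EMA loop; AT_DISPATCH_FLOATING_TYPES
  // defines scalar_t inside the lambda. The actual wrapper may differ.
  torch::Tensor forward_mu_sketch(const torch::Tensor &updates,
                                  const torch::Tensor &mu, const double b1) {
    auto mu_out = torch::empty_like(mu);
    AT_DISPATCH_FLOATING_TYPES(updates.scalar_type(), "forward_mu_sketch", [&] {
      const scalar_t *g = updates.data_ptr<scalar_t>();
      const scalar_t *m = mu.data_ptr<scalar_t>();
      scalar_t *out = mu_out.data_ptr<scalar_t>();
      const int64_t n = updates.numel();
      for (int64_t i = 0; i < n; ++i) {
        out[i] = b1 * m[i] + (1 - b1) * g[i];  // first-moment EMA
      }
    });
    return mu_out;
  }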
@@ -108,22 +65,6 @@ torch::Tensor adamForwardMuCPU(const torch::Tensor &updates,
   return mu_out;
 }
 
-template <typename scalar_t, typename other_t>
-void adamForwardNuCPUKernel(const scalar_t *__restrict__ updates_ptr,
-                            const scalar_t *__restrict__ nu_ptr,
-                            const other_t b2,
-                            const size_t n,
-                            scalar_t *__restrict__ nu_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
-  for (size_t tid = 0; tid < n; ++tid) {
-    const scalar_t updates = updates_ptr[tid];
-    const scalar_t nu = nu_ptr[tid];
-
-    const scalar_t nu_out = b2 * nu + (1 - b2) * pow(updates, 2);
-    nu_out_ptr[tid] = nu_out;
-  }
-}
-
 torch::Tensor adamForwardNuCPU(const torch::Tensor &updates,
                                const torch::Tensor &nu,
                                const pyfloat_t b2) {
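One detail worth noting in the hunk above: this kernel squares the gradient with pow(updates, 2), whereas the fused in-place kernel squares it with updates * updates. The two are mathematically identical, and the plain multiply does not depend on the compiler folding the pow call away. The element update it computes is simply:

  // v_t = b2 * v_{t-1} + (1 - b2) * g_t^2   (sketch of the per-element step)
  inline double nu_step(const double b2, const double v, const double g) {
    return b2 * v + (1 - b2) * g * g;
  }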
@@ -136,25 +77,6 @@ torch::Tensor adamForwardNuCPU(const torch::Tensor &updates,
   return nu_out;
 }
 
-template <typename scalar_t, typename other_t>
-void adamForwardUpdatesCPUKernel(const scalar_t *__restrict__ new_mu_ptr,
-                                 const scalar_t *__restrict__ new_nu_ptr,
-                                 const other_t inv_one_minus_pow_b1,
-                                 const other_t inv_one_minus_pow_b2,
-                                 const other_t eps,
-                                 const other_t eps_root,
-                                 const size_t n,
-                                 scalar_t *__restrict__ updates_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
-  for (size_t tid = 0; tid < n; ++tid) {
-    const scalar_t new_mu = new_mu_ptr[tid];
-    const scalar_t new_nu = new_nu_ptr[tid];
-    const scalar_t mu_hat = new_mu * inv_one_minus_pow_b1;
-    const scalar_t nu_hat = new_nu * inv_one_minus_pow_b2;
-    updates_out_ptr[tid] = mu_hat / (sqrt(nu_hat + eps_root) + eps);
-  }
-}
-
 torch::Tensor adamForwardUpdatesCPU(const torch::Tensor &new_mu,
                                     const torch::Tensor &new_nu,
                                     const pyfloat_t b1,
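The deleted adamForwardUpdatesCPUKernel applies the bias corrections and assembles the step: mu_hat = mu_t * 1/(1 - b1^t), nu_hat = nu_t * 1/(1 - b2^t), then update = mu_hat / (sqrt(nu_hat + eps_root) + eps). A scalar sketch that, unlike the kernel, derives the inverse factors from a step count t (an assumption made here so the snippet is self-contained):

  #include <cmath>

  // Sketch: one bias-corrected Adam step for a single element.
  double adam_update_sketch(const double m, const double v, const double b1,
                            const double b2, const double eps,
                            const double eps_root, const int t) {
    const double m_hat = m / (1 - std::pow(b1, t));  // = m * inv_one_minus_pow_b1
    const double v_hat = v / (1 - std::pow(b2, t));  // = v * inv_one_minus_pow_b2
    return m_hat / (std::sqrt(v_hat + eps_root) + eps);
  }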
@@ -181,21 +103,6 @@ torch::Tensor adamForwardUpdatesCPU(const torch::Tensor &new_mu,
   return updates_out;
 }
 
-template <typename scalar_t, typename other_t>
-void adamBackwardMuCPUKernel(const scalar_t *__restrict__ dmu_ptr,
-                             const other_t b1,
-                             const size_t n,
-                             scalar_t *__restrict__ dupdates_out_ptr,
-                             scalar_t *__restrict__ dmu_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
-  for (size_t tid = 0; tid < n; ++tid) {
-    const scalar_t dmu = dmu_ptr[tid];
-
-    dupdates_out_ptr[tid] = (1 - b1) * dmu;
-    dmu_out_ptr[tid] = b1 * dmu;
-  }
-}
-
 TensorArray<2> adamBackwardMuCPU(const torch::Tensor &dmu,
                                  const torch::Tensor &updates,
                                  const torch::Tensor &mu,
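Because mu_out = b1 * mu + (1 - b1) * updates is linear in both inputs, its backward pass is just the two constant partial derivatives, which is all the deleted adamBackwardMuCPUKernel writes out. A one-element sketch of that vector-Jacobian product:

  // VJP of the first-moment EMA; dmu_out is the incoming cotangent.
  inline void mu_vjp_sketch(const double b1, const double dmu_out,
                            double *dupdates, double *dmu) {
    *dupdates = (1 - b1) * dmu_out;  // d(mu_out)/d(updates) = 1 - b1
    *dmu = b1 * dmu_out;             // d(mu_out)/d(mu) = b1
  }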
@@ -210,23 +117,6 @@ TensorArray<2> adamBackwardMuCPU(const torch::Tensor &dmu,
   return TensorArray<2>{std::move(dupdates_out), std::move(dmu_out)};
 }
 
-template <typename scalar_t, typename other_t>
-void adamBackwardNuCPUKernel(const scalar_t *__restrict__ dnu_ptr,
-                             const scalar_t *__restrict__ updates_ptr,
-                             const other_t b2,
-                             const size_t n,
-                             scalar_t *__restrict__ dupdates_out_ptr,
-                             scalar_t *__restrict__ dnu_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
-  for (size_t tid = 0; tid < n; ++tid) {
-    const scalar_t dnu = dnu_ptr[tid];
-    const scalar_t updates = updates_ptr[tid];
-
-    dupdates_out_ptr[tid] = 2 * (1 - b2) * updates * dnu;
-    dnu_out_ptr[tid] = b2 * dnu;
-  }
-}
-
 TensorArray<2> adamBackwardNuCPU(const torch::Tensor &dnu,
                                  const torch::Tensor &updates,
                                  const torch::Tensor &nu,
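Likewise, for nu_out = b2 * nu + (1 - b2) * g^2 the chain rule gives d(g) = 2 * (1 - b2) * g * dnu and d(nu) = b2 * dnu, matching the two stores in the hunk above. A one-element sketch:

  // VJP of the second-moment EMA; the 2 * g factor comes from d(g^2)/dg.
  inline void nu_vjp_sketch(const double b2, const double g, const double dnu_out,
                            double *dupdates, double *dnu) {
    *dupdates = 2 * (1 - b2) * g * dnu_out;
    *dnu = b2 * dnu_out;
  }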