@@ -27,6 +27,37 @@ using std::size_t;
namespace adam_op {
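+ // Heuristic threshold: tensors with at most this many elements are processed
+ // serially; larger ones get roughly one OpenMP thread per 1000 elements,
+ // capped at the number of available processors.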
+ constexpr int min_elements_use_omp = 1000;
+
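+ // Fused in-place Adam step over a flat array: updates the first and second
+ // moment estimates (mu, nu) and overwrites `updates` with the bias-corrected
+ // update mu_hat / (sqrt(nu_hat + eps_root) + eps).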
+ template <typename scalar_t, typename other_t>
+ void adamForwardInplaceCPUKernel(const other_t b1,
+                                  const other_t inv_one_minus_pow_b1,
+                                  const other_t b2,
+                                  const other_t inv_one_minus_pow_b2,
+                                  const other_t eps,
+                                  const other_t eps_root,
+                                  const size_t n,
+                                  scalar_t *__restrict__ updates_ptr,
+                                  scalar_t *__restrict__ mu_ptr,
+                                  scalar_t *__restrict__ nu_ptr) {
+ #pragma omp parallel for num_threads(std::min( \
+     n / (size_t)min_elements_use_omp, (size_t)omp_get_num_procs())) if (n > min_elements_use_omp)
+   for (size_t tid = 0; tid < n; ++tid) {
+     const scalar_t updates = updates_ptr[tid];
+     const scalar_t mu = mu_ptr[tid];
+     const scalar_t nu = nu_ptr[tid];
+
+     const scalar_t mu_out = b1 * mu + (1 - b1) * updates;
+     const scalar_t nu_out = b2 * nu + (1 - b2) * updates * updates;
+     const scalar_t updates_out =
+         mu_out * inv_one_minus_pow_b1 / (sqrt(nu_out * inv_one_minus_pow_b2 + eps_root) + eps);
+
+     mu_ptr[tid] = mu_out;
+     nu_ptr[tid] = nu_out;
+     updates_ptr[tid] = updates_out;
+   }
+ }
+
TensorArray<3> adamForwardInplaceCPU(const torch::Tensor &updates,
                                     const torch::Tensor &mu,
                                     const torch::Tensor &nu,
@@ -39,44 +70,110 @@ TensorArray<3> adamForwardInplaceCPU(const torch::Tensor &updates,
  const other_t inv_one_minus_pow_b1 = 1 / (1 - std::pow(b1, count));
  const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count));

-   AT_DISPATCH_SCALAR_TYPES(
-       updates.scalar_type(), "adamForwardInplaceCPU", ([&] {
-         mu.mul_(scalar_t(b1)).add_(updates, 1 - scalar_t(b1));
-
-         nu.mul_(scalar_t(b2)).addcmul_(updates, updates.conj(), 1 - scalar_t(b2));
-
-         updates.copy_(mu.mul(scalar_t(inv_one_minus_pow_b1))
-                           .div_(nu.mul(inv_one_minus_pow_b2)
-                                     .add_(scalar_t(eps_root))
-                                     .sqrt_()
-                                     .add_(scalar_t(eps))));
-       }));
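+   // Number of elements to process; the kernel runs element-wise over the raw
+   // data pointers of the dispatched scalar type.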
+   const size_t n = getTensorPlainSize(updates);
+   AT_DISPATCH_SCALAR_TYPES(updates.scalar_type(), "adamForwardInplaceCPU", ([&] {
+     adamForwardInplaceCPUKernel<scalar_t, scalar_t>(
+         scalar_t(b1),
+         scalar_t(inv_one_minus_pow_b1),
+         scalar_t(b2),
+         scalar_t(inv_one_minus_pow_b2),
+         scalar_t(eps),
+         scalar_t(eps_root),
+         n,
+         updates.data_ptr<scalar_t>(),
+         mu.data_ptr<scalar_t>(),
+         nu.data_ptr<scalar_t>());
+   }));
  return TensorArray<3>{updates, mu, nu};
}

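+ // Element-wise EMA of the gradients: mu_out = b1 * mu + (1 - b1) * updates.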
+ template <typename scalar_t, typename other_t>
+ void adamForwardMuCPUKernel(const scalar_t *__restrict__ updates_ptr,
+                             const scalar_t *__restrict__ mu_ptr,
+                             const other_t b1,
+                             const size_t n,
+                             scalar_t *__restrict__ mu_out_ptr) {
+ #pragma omp parallel for num_threads(std::min( \
+     n / (size_t)min_elements_use_omp, (size_t)omp_get_num_procs())) if (n > min_elements_use_omp)
+   for (size_t tid = 0; tid < n; ++tid) {
+     const scalar_t updates = updates_ptr[tid];
+     const scalar_t mu = mu_ptr[tid];
+     const scalar_t mu_out = b1 * mu + (1 - b1) * updates;
+     mu_out_ptr[tid] = mu_out;
+   }
+ }
+
torch::Tensor adamForwardMuCPU(const torch::Tensor &updates,
                               const torch::Tensor &mu,
                               const pyfloat_t b1) {
-   torch::Tensor mu_out;
+   auto mu_out = torch::empty_like(mu);

+   const size_t n = getTensorPlainSize(updates);
  AT_DISPATCH_SCALAR_TYPES(updates.scalar_type(), "adamForwardMuCPU", ([&] {
-     mu_out = mu.mul(b1).add_(updates, 1 - scalar_t(b1));
+     adamForwardMuCPUKernel<scalar_t, scalar_t>(
+         updates.data_ptr<scalar_t>(),
+         mu.data_ptr<scalar_t>(),
+         scalar_t(b1),
+         n,
+         mu_out.data_ptr<scalar_t>());
  }));
  return mu_out;
}

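+ // Element-wise EMA of the squared gradients: nu_out = b2 * nu + (1 - b2) * updates^2.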
+ template <typename scalar_t, typename other_t>
+ void adamForwardNuCPUKernel(const scalar_t *__restrict__ updates_ptr,
+                             const scalar_t *__restrict__ nu_ptr,
+                             const other_t b2,
+                             const size_t n,
+                             scalar_t *__restrict__ nu_out_ptr) {
+ #pragma omp parallel for num_threads(std::min( \
+     n / (size_t)min_elements_use_omp, (size_t)omp_get_num_procs())) if (n > min_elements_use_omp)
+   for (size_t tid = 0; tid < n; ++tid) {
+     const scalar_t updates = updates_ptr[tid];
+     const scalar_t nu = nu_ptr[tid];
+
+     const scalar_t nu_out = b2 * nu + (1 - b2) * pow(updates, 2);
+     nu_out_ptr[tid] = nu_out;
+   }
+ }
+
torch::Tensor adamForwardNuCPU(const torch::Tensor &updates,
                               const torch::Tensor &nu,
                               const pyfloat_t b2) {
-   torch::Tensor nu_out;
+   auto nu_out = torch::empty_like(nu);

+   const size_t n = getTensorPlainSize(updates);
  AT_DISPATCH_SCALAR_TYPES(updates.scalar_type(), "adamForwardNuCPU", ([&] {
-     nu_out =
-         nu.mul(b2).addcmul_(updates, updates.conj(), 1 - scalar_t(b2));
+     adamForwardNuCPUKernel<scalar_t, scalar_t>(
+         updates.data_ptr<scalar_t>(),
+         nu.data_ptr<scalar_t>(),
+         scalar_t(b2),
+         n,
+         nu_out.data_ptr<scalar_t>());
  }));
  return nu_out;
}

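+ // Bias-corrects the moment estimates and computes the Adam update:
+ // updates_out = mu_hat / (sqrt(nu_hat + eps_root) + eps).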
+ template <typename scalar_t, typename other_t>
+ void adamForwardUpdatesCPUKernel(const scalar_t *__restrict__ new_mu_ptr,
+                                  const scalar_t *__restrict__ new_nu_ptr,
+                                  const other_t inv_one_minus_pow_b1,
+                                  const other_t inv_one_minus_pow_b2,
+                                  const other_t eps,
+                                  const other_t eps_root,
+                                  const size_t n,
+                                  scalar_t *__restrict__ updates_out_ptr) {
+ #pragma omp parallel for num_threads(std::min( \
+     n / (size_t)min_elements_use_omp, (size_t)omp_get_num_procs())) if (n > min_elements_use_omp)
+   for (size_t tid = 0; tid < n; ++tid) {
+     const scalar_t new_mu = new_mu_ptr[tid];
+     const scalar_t new_nu = new_nu_ptr[tid];
+     const scalar_t mu_hat = new_mu * inv_one_minus_pow_b1;
+     const scalar_t nu_hat = new_nu * inv_one_minus_pow_b2;
+     updates_out_ptr[tid] = mu_hat / (sqrt(nu_hat + eps_root) + eps);
+   }
+ }
+
torch::Tensor adamForwardUpdatesCPU(const torch::Tensor &new_mu,
                                    const torch::Tensor &new_nu,
                                    const pyfloat_t b1,
@@ -86,47 +183,97 @@ torch::Tensor adamForwardUpdatesCPU(const torch::Tensor &new_mu,
                                    const pyuint_t count) {
  using other_t = pyfloat_t;

-   torch::Tensor updates_out;
+   auto updates_out = torch::empty_like(new_mu);

  const other_t one_minus_pow_b1 = 1 - std::pow(b1, count);
  const other_t inv_one_minus_pow_b1 = 1 / one_minus_pow_b1;
  const other_t one_minus_pow_b2 = 1 - std::pow(b2, count);
  const other_t inv_one_minus_pow_b2 = 1 / one_minus_pow_b2;

+   const size_t n = getTensorPlainSize(new_mu);
  AT_DISPATCH_SCALAR_TYPES(new_mu.scalar_type(), "adamForwardUpdatesCPU", ([&] {
-     updates_out = new_mu.mul(scalar_t(inv_one_minus_pow_b1))
-                       .div_(new_nu.mul(scalar_t(inv_one_minus_pow_b2))
-                                 .add_(scalar_t(eps_root))
-                                 .sqrt_()
-                                 .add_(scalar_t(eps)));
+     adamForwardUpdatesCPUKernel<scalar_t, scalar_t>(
+         new_mu.data_ptr<scalar_t>(),
+         new_nu.data_ptr<scalar_t>(),
+         scalar_t(inv_one_minus_pow_b1),
+         scalar_t(inv_one_minus_pow_b2),
+         scalar_t(eps),
+         scalar_t(eps_root),
+         n,
+         updates_out.data_ptr<scalar_t>());
  }));
  return updates_out;
}

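+ // Backward of the mu update: d(updates) = (1 - b1) * dmu, d(mu) = b1 * dmu.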
+ template <typename scalar_t, typename other_t>
+ void adamBackwardMuCPUKernel(const scalar_t *__restrict__ dmu_ptr,
+                              const other_t b1,
+                              const size_t n,
+                              scalar_t *__restrict__ dupdates_out_ptr,
+                              scalar_t *__restrict__ dmu_out_ptr) {
+ #pragma omp parallel for num_threads(std::min( \
+     n / (size_t)min_elements_use_omp, (size_t)omp_get_num_procs())) if (n > min_elements_use_omp)
+   for (size_t tid = 0; tid < n; ++tid) {
+     const scalar_t dmu = dmu_ptr[tid];
+
+     dupdates_out_ptr[tid] = (1 - b1) * dmu;
+     dmu_out_ptr[tid] = b1 * dmu;
+   }
+ }
+
TensorArray<2> adamBackwardMuCPU(const torch::Tensor &dmu,
                                 const torch::Tensor &updates,
                                 const torch::Tensor &mu,
                                 const pyfloat_t b1) {
-   torch::Tensor dupdates_out;
-   torch::Tensor dmu_out;
+   auto dupdates_out = torch::empty_like(updates);
+   auto dmu_out = torch::empty_like(mu);

+   const size_t n = getTensorPlainSize(dmu);
  AT_DISPATCH_SCALAR_TYPES(dmu.scalar_type(), "adamBackwardMuCPU", ([&] {
-     dupdates_out = dmu.mul(1 - scalar_t(b1));
-     dmu_out = dmu.mul(scalar_t(b1));
+     adamBackwardMuCPUKernel<scalar_t, scalar_t>(
+         dmu.data_ptr<scalar_t>(),
+         scalar_t(b1),
+         n,
+         dupdates_out.data_ptr<scalar_t>(),
+         dmu_out.data_ptr<scalar_t>());
  }));
  return TensorArray<2>{std::move(dupdates_out), std::move(dmu_out)};
}

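+ // Backward of the nu update: d(updates) = 2 * (1 - b2) * updates * dnu, d(nu) = b2 * dnu.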
+ template <typename scalar_t, typename other_t>
+ void adamBackwardNuCPUKernel(const scalar_t *__restrict__ dnu_ptr,
+                              const scalar_t *__restrict__ updates_ptr,
+                              const other_t b2,
+                              const size_t n,
+                              scalar_t *__restrict__ dupdates_out_ptr,
+                              scalar_t *__restrict__ dnu_out_ptr) {
+ #pragma omp parallel for num_threads(std::min( \
+     n / (size_t)min_elements_use_omp, (size_t)omp_get_num_procs())) if (n > min_elements_use_omp)
+   for (size_t tid = 0; tid < n; ++tid) {
+     const scalar_t dnu = dnu_ptr[tid];
+     const scalar_t updates = updates_ptr[tid];
+
+     dupdates_out_ptr[tid] = 2 * (1 - b2) * updates * dnu;
+     dnu_out_ptr[tid] = b2 * dnu;
+   }
+ }
+
TensorArray<2> adamBackwardNuCPU(const torch::Tensor &dnu,
                                 const torch::Tensor &updates,
                                 const torch::Tensor &nu,
                                 const pyfloat_t b2) {
-   torch::Tensor dupdates_out;
-   torch::Tensor dnu_out;
+   auto dupdates_out = torch::empty_like(updates);
+   auto dnu_out = torch::empty_like(nu);

+   const size_t n = getTensorPlainSize(dnu);
  AT_DISPATCH_SCALAR_TYPES(dnu.scalar_type(), "adamForwardNuCPU", ([&] {
-     dupdates_out = updates.mul(2 - 2 * scalar_t(b2)).mul_(dnu);
-     dnu_out = dnu.mul(scalar_t(b2));
+     adamBackwardNuCPUKernel<scalar_t, scalar_t>(
+         dnu.data_ptr<scalar_t>(),
+         updates.data_ptr<scalar_t>(),
+         scalar_t(b2),
+         n,
+         dupdates_out.data_ptr<scalar_t>(),
+         dnu_out.data_ptr<scalar_t>());
  }));
  return TensorArray<2>{std::move(dupdates_out), std::move(dnu_out)};
}
@@ -140,7 +287,8 @@ void adamBackwardUpdatesCPUKernel(const scalar_t *__restrict__ dupdates_ptr,
                                  const size_t n,
                                  scalar_t *__restrict__ dnew_mu_out_ptr,
                                  scalar_t *__restrict__ dnew_nu_out_ptr) {
- #pragma omp parallel for num_threads(omp_get_num_procs())
+ #pragma omp parallel for num_threads(std::min( \
+     n / (size_t)min_elements_use_omp, (size_t)omp_get_num_procs())) if (n > min_elements_use_omp)
  for (size_t tid = 0; tid < n; ++tid) {
    const scalar_t dupdates = dupdates_ptr[tid];
    const scalar_t updates = updates_ptr[tid];