Commit 28f76ac

feat(acc_op): add if condition for the element number small situations

1 parent e9a260a · commit 28f76ac

File tree

1 file changed: +180 −32 lines changed

1 file changed

+180
-32
lines changed

src/adam_op/adam_op_impl_cpu.cpp

Lines changed: 180 additions & 32 deletions
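
Every hunk below applies the same change: each elementwise loop gets an OpenMP `if` clause so that small tensors run serially, and the thread count is capped so each thread handles at least `min_elements_use_omp` elements. A minimal standalone sketch of that pattern follows (illustrative only; `saxpy_kernel` and the toy operation are not part of the diff):

```cpp
// Illustrative sketch of the gating pattern introduced by this commit; not part of the diff.
#include <omp.h>

#include <algorithm>
#include <cstddef>

constexpr int min_elements_use_omp = 1000;  // same threshold the commit introduces

void saxpy_kernel(const float a, const float *x, float *y, const std::size_t n) {
  // Serial when n <= min_elements_use_omp; otherwise use at most one thread per
  // min_elements_use_omp elements, capped at the number of available processors.
#pragma omp parallel for num_threads(std::min( \
    n / (std::size_t)min_elements_use_omp, (std::size_t)omp_get_num_procs())) if (n > min_elements_use_omp)
  for (std::size_t i = 0; i < n; ++i) {
    y[i] = a * x[i] + y[i];
  }
}
```

With the threshold at 1000, a 500-element tensor stays on the calling thread (the `if` clause disables the parallel region), an 8,000-element tensor gets `min(8, omp_get_num_procs())` threads, and very large tensors are capped at the processor count.
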
@@ -27,6 +27,37 @@ using std::size_t;
 
 namespace adam_op {
 
+constexpr int min_elements_use_omp = 1000;
+
+template <typename scalar_t, typename other_t>
+void adamForwardInplaceCPUKernel(const other_t b1,
+                                 const other_t inv_one_minus_pow_b1,
+                                 const other_t b2,
+                                 const other_t inv_one_minus_pow_b2,
+                                 const other_t eps,
+                                 const other_t eps_root,
+                                 const size_t n,
+                                 scalar_t *__restrict__ updates_ptr,
+                                 scalar_t *__restrict__ mu_ptr,
+                                 scalar_t *__restrict__ nu_ptr) {
+#pragma omp parallel for num_threads(std::min( \
+    n / (size_t)min_elements_use_omp, (size_t)omp_get_num_procs())) if (n > min_elements_use_omp)
+  for (size_t tid = 0; tid < n; ++tid) {
+    const scalar_t updates = updates_ptr[tid];
+    const scalar_t mu = mu_ptr[tid];
+    const scalar_t nu = nu_ptr[tid];
+
+    const scalar_t mu_out = b1 * mu + (1 - b1) * updates;
+    const scalar_t nu_out = b2 * nu + (1 - b2) * updates * updates;
+    const scalar_t updates_out =
+        mu_out * inv_one_minus_pow_b1 / (sqrt(nu_out * inv_one_minus_pow_b2 + eps_root) + eps);
+
+    mu_ptr[tid] = mu_out;
+    nu_ptr[tid] = nu_out;
+    updates_ptr[tid] = updates_out;
+  }
+}
+
 TensorArray<3> adamForwardInplaceCPU(const torch::Tensor &updates,
                                      const torch::Tensor &mu,
                                      const torch::Tensor &nu,
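
For reference, the elementwise update the new `adamForwardInplaceCPUKernel` computes, written in the usual Adam notation (here `updates` plays the role of the gradient g, `mu` is the first moment m, `nu` the second moment v, and `count` is the step t):

$$
\begin{aligned}
m_t &= \beta_1\, m_{t-1} + (1-\beta_1)\, g_t, \\
v_t &= \beta_2\, v_{t-1} + (1-\beta_2)\, g_t^2, \\
u_t &= \frac{m_t \,/\, (1-\beta_1^t)}{\sqrt{v_t \,/\, (1-\beta_2^t) + \varepsilon_{\text{root}}} + \varepsilon},
\end{aligned}
$$

with `inv_one_minus_pow_b1` = 1/(1−β₁ᵗ) and `inv_one_minus_pow_b2` = 1/(1−β₂ᵗ) precomputed once per call, so the per-element work is a handful of multiply-adds plus one `sqrt`.
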
@@ -39,44 +70,110 @@ TensorArray<3> adamForwardInplaceCPU(const torch::Tensor &updates,
   const other_t inv_one_minus_pow_b1 = 1 / (1 - std::pow(b1, count));
   const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count));
 
-  AT_DISPATCH_SCALAR_TYPES(
-      updates.scalar_type(), "adamForwardInplaceCPU", ([&] {
-        mu.mul_(scalar_t(b1)).add_(updates, 1 - scalar_t(b1));
-
-        nu.mul_(scalar_t(b2)).addcmul_(updates, updates.conj(), 1 - scalar_t(b2));
-
-        updates.copy_(mu.mul(scalar_t(inv_one_minus_pow_b1))
-                          .div_(nu.mul(inv_one_minus_pow_b2)
-                                    .add_(scalar_t(eps_root))
-                                    .sqrt_()
-                                    .add_(scalar_t(eps))));
-      }));
+  const size_t n = getTensorPlainSize(updates);
+  AT_DISPATCH_SCALAR_TYPES(updates.scalar_type(), "adamForwardInplaceCPU", ([&] {
+                             adamForwardInplaceCPUKernel<scalar_t, scalar_t>(
+                                 scalar_t(b1),
+                                 scalar_t(inv_one_minus_pow_b1),
+                                 scalar_t(b2),
+                                 scalar_t(inv_one_minus_pow_b2),
+                                 scalar_t(eps),
+                                 scalar_t(eps_root),
+                                 n,
+                                 updates.data_ptr<scalar_t>(),
+                                 mu.data_ptr<scalar_t>(),
+                                 nu.data_ptr<scalar_t>());
+                           }));
   return TensorArray<3>{updates, mu, nu};
 }
 
+template <typename scalar_t, typename other_t>
+void adamForwardMuCPUKernel(const scalar_t *__restrict__ updates_ptr,
+                            const scalar_t *__restrict__ mu_ptr,
+                            const other_t b1,
+                            const size_t n,
+                            scalar_t *__restrict__ mu_out_ptr) {
+#pragma omp parallel for num_threads(std::min( \
+    n / (size_t)min_elements_use_omp, (size_t)omp_get_num_procs())) if (n > min_elements_use_omp)
+  for (size_t tid = 0; tid < n; ++tid) {
+    const scalar_t updates = updates_ptr[tid];
+    const scalar_t mu = mu_ptr[tid];
+    const scalar_t mu_out = b1 * mu + (1 - b1) * updates;
+    mu_out_ptr[tid] = mu_out;
+  }
+}
+
 torch::Tensor adamForwardMuCPU(const torch::Tensor &updates,
                                const torch::Tensor &mu,
                                const pyfloat_t b1) {
-  torch::Tensor mu_out;
+  auto mu_out = torch::empty_like(mu);
 
+  const size_t n = getTensorPlainSize(updates);
   AT_DISPATCH_SCALAR_TYPES(updates.scalar_type(), "adamForwardMuCPU", ([&] {
-                             mu_out = mu.mul(b1).add_(updates, 1 - scalar_t(b1));
+                             adamForwardMuCPUKernel<scalar_t, scalar_t>(
+                                 updates.data_ptr<scalar_t>(),
+                                 mu.data_ptr<scalar_t>(),
+                                 scalar_t(b1),
+                                 n,
+                                 mu_out.data_ptr<scalar_t>());
                            }));
   return mu_out;
 }
 
+template <typename scalar_t, typename other_t>
+void adamForwardNuCPUKernel(const scalar_t *__restrict__ updates_ptr,
+                            const scalar_t *__restrict__ nu_ptr,
+                            const other_t b2,
+                            const size_t n,
+                            scalar_t *__restrict__ nu_out_ptr) {
+#pragma omp parallel for num_threads(std::min( \
+    n / (size_t)min_elements_use_omp, (size_t)omp_get_num_procs())) if (n > min_elements_use_omp)
+  for (size_t tid = 0; tid < n; ++tid) {
+    const scalar_t updates = updates_ptr[tid];
+    const scalar_t nu = nu_ptr[tid];
+
+    const scalar_t nu_out = b2 * nu + (1 - b2) * pow(updates, 2);
+    nu_out_ptr[tid] = nu_out;
+  }
+}
+
 torch::Tensor adamForwardNuCPU(const torch::Tensor &updates,
                                const torch::Tensor &nu,
                                const pyfloat_t b2) {
-  torch::Tensor nu_out;
+  auto nu_out = torch::empty_like(nu);
 
+  const size_t n = getTensorPlainSize(updates);
   AT_DISPATCH_SCALAR_TYPES(updates.scalar_type(), "adamForwardNuCPU", ([&] {
-                             nu_out =
-                                 nu.mul(b2).addcmul_(updates, updates.conj(), 1 - scalar_t(b2));
+                             adamForwardNuCPUKernel<scalar_t, scalar_t>(
+                                 updates.data_ptr<scalar_t>(),
+                                 nu.data_ptr<scalar_t>(),
+                                 scalar_t(b2),
+                                 n,
+                                 nu_out.data_ptr<scalar_t>());
                            }));
   return nu_out;
 }
 
+template <typename scalar_t, typename other_t>
+void adamForwardUpdatesCPUKernel(const scalar_t *__restrict__ new_mu_ptr,
+                                 const scalar_t *__restrict__ new_nu_ptr,
+                                 const other_t inv_one_minus_pow_b1,
+                                 const other_t inv_one_minus_pow_b2,
+                                 const other_t eps,
+                                 const other_t eps_root,
+                                 const size_t n,
+                                 scalar_t *__restrict__ updates_out_ptr) {
+#pragma omp parallel for num_threads(std::min( \
+    n / (size_t)min_elements_use_omp, (size_t)omp_get_num_procs())) if (n > min_elements_use_omp)
+  for (size_t tid = 0; tid < n; ++tid) {
+    const scalar_t new_mu = new_mu_ptr[tid];
+    const scalar_t new_nu = new_nu_ptr[tid];
+    const scalar_t mu_hat = new_mu * inv_one_minus_pow_b1;
+    const scalar_t nu_hat = new_nu * inv_one_minus_pow_b2;
+    updates_out_ptr[tid] = mu_hat / (sqrt(nu_hat + eps_root) + eps);
+  }
+}
+
 torch::Tensor adamForwardUpdatesCPU(const torch::Tensor &new_mu,
                                     const torch::Tensor &new_nu,
                                     const pyfloat_t b1,
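
The forward wrappers above all follow the same dispatch-to-raw-pointer shape: allocate the output with `torch::empty_like`, take the flat element count, then let the dtype dispatch macro instantiate the templated kernel. A hypothetical minimal example of that shape (the names `scaleCPU`/`scaleCPUKernel` are not from the diff; the real code uses the project's `AT_DISPATCH_SCALAR_TYPES` macro and `getTensorPlainSize` helper, while this sketch falls back to ATen's stock `AT_DISPATCH_FLOATING_TYPES` and `numel()`):

```cpp
// Hypothetical sketch of the wrapper/kernel split used above; not part of the diff.
#include <torch/torch.h>

#include <cstddef>

template <typename scalar_t>
void scaleCPUKernel(const scalar_t *__restrict__ in_ptr,
                    const scalar_t alpha,
                    const std::size_t n,
                    scalar_t *__restrict__ out_ptr) {
  for (std::size_t i = 0; i < n; ++i) {
    out_ptr[i] = alpha * in_ptr[i];  // toy elementwise op
  }
}

torch::Tensor scaleCPU(const torch::Tensor &x, const double alpha) {
  auto out = torch::empty_like(x);  // output with the same dtype/shape
  const std::size_t n = x.numel();  // flat element count
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "scaleCPU", ([&] {
    scaleCPUKernel<scalar_t>(
        x.data_ptr<scalar_t>(), scalar_t(alpha), n, out.data_ptr<scalar_t>());
  }));
  return out;
}
```

Working on raw contiguous pointers is what makes the gated OpenMP loops straightforward; the trade-off is that the kernels assume densely laid-out inputs, which is what the flat element count passed as `n` reflects.
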
@@ -86,47 +183,97 @@ torch::Tensor adamForwardUpdatesCPU(const torch::Tensor &new_mu,
                                     const pyuint_t count) {
   using other_t = pyfloat_t;
 
-  torch::Tensor updates_out;
+  auto updates_out = torch::empty_like(new_mu);
 
   const other_t one_minus_pow_b1 = 1 - std::pow(b1, count);
   const other_t inv_one_minus_pow_b1 = 1 / one_minus_pow_b1;
   const other_t one_minus_pow_b2 = 1 - std::pow(b2, count);
   const other_t inv_one_minus_pow_b2 = 1 / one_minus_pow_b2;
 
+  const size_t n = getTensorPlainSize(new_mu);
   AT_DISPATCH_SCALAR_TYPES(new_mu.scalar_type(), "adamForwardUpdatesCPU", ([&] {
-                             updates_out = new_mu.mul(scalar_t(inv_one_minus_pow_b1))
-                                               .div_(new_nu.mul(scalar_t(inv_one_minus_pow_b2))
-                                                         .add_(scalar_t(eps_root))
-                                                         .sqrt_()
-                                                         .add_(scalar_t(eps)));
+                             adamForwardUpdatesCPUKernel<scalar_t, scalar_t>(
+                                 new_mu.data_ptr<scalar_t>(),
+                                 new_nu.data_ptr<scalar_t>(),
+                                 scalar_t(inv_one_minus_pow_b1),
+                                 scalar_t(inv_one_minus_pow_b2),
+                                 scalar_t(eps),
+                                 scalar_t(eps_root),
+                                 n,
+                                 updates_out.data_ptr<scalar_t>());
                            }));
   return updates_out;
 }
 
+template <typename scalar_t, typename other_t>
+void adamBackwardMuCPUKernel(const scalar_t *__restrict__ dmu_ptr,
+                             const other_t b1,
+                             const size_t n,
+                             scalar_t *__restrict__ dupdates_out_ptr,
+                             scalar_t *__restrict__ dmu_out_ptr) {
+#pragma omp parallel for num_threads(std::min( \
+    n / (size_t)min_elements_use_omp, (size_t)omp_get_num_procs())) if (n > min_elements_use_omp)
+  for (size_t tid = 0; tid < n; ++tid) {
+    const scalar_t dmu = dmu_ptr[tid];
+
+    dupdates_out_ptr[tid] = (1 - b1) * dmu;
+    dmu_out_ptr[tid] = b1 * dmu;
+  }
+}
+
 TensorArray<2> adamBackwardMuCPU(const torch::Tensor &dmu,
                                  const torch::Tensor &updates,
                                  const torch::Tensor &mu,
                                  const pyfloat_t b1) {
-  torch::Tensor dupdates_out;
-  torch::Tensor dmu_out;
+  auto dupdates_out = torch::empty_like(updates);
+  auto dmu_out = torch::empty_like(mu);
 
+  const size_t n = getTensorPlainSize(dmu);
   AT_DISPATCH_SCALAR_TYPES(dmu.scalar_type(), "adamBackwardMuCPU", ([&] {
-                             dupdates_out = dmu.mul(1 - scalar_t(b1));
-                             dmu_out = dmu.mul(scalar_t(b1));
+                             adamBackwardMuCPUKernel<scalar_t, scalar_t>(
+                                 dmu.data_ptr<scalar_t>(),
+                                 scalar_t(b1),
+                                 n,
+                                 dupdates_out.data_ptr<scalar_t>(),
+                                 dmu_out.data_ptr<scalar_t>());
                            }));
   return TensorArray<2>{std::move(dupdates_out), std::move(dmu_out)};
 }
 
+template <typename scalar_t, typename other_t>
+void adamBackwardNuCPUKernel(const scalar_t *__restrict__ dnu_ptr,
+                             const scalar_t *__restrict__ updates_ptr,
+                             const other_t b2,
+                             const size_t n,
+                             scalar_t *__restrict__ dupdates_out_ptr,
+                             scalar_t *__restrict__ dnu_out_ptr) {
+#pragma omp parallel for num_threads(std::min( \
+    n / (size_t)min_elements_use_omp, (size_t)omp_get_num_procs())) if (n > min_elements_use_omp)
+  for (size_t tid = 0; tid < n; ++tid) {
+    const scalar_t dnu = dnu_ptr[tid];
+    const scalar_t updates = updates_ptr[tid];
+
+    dupdates_out_ptr[tid] = 2 * (1 - b2) * updates * dnu;
+    dnu_out_ptr[tid] = b2 * dnu;
+  }
+}
+
 TensorArray<2> adamBackwardNuCPU(const torch::Tensor &dnu,
                                  const torch::Tensor &updates,
                                  const torch::Tensor &nu,
                                  const pyfloat_t b2) {
-  torch::Tensor dupdates_out;
-  torch::Tensor dnu_out;
+  auto dupdates_out = torch::empty_like(updates);
+  auto dnu_out = torch::empty_like(nu);
 
+  const size_t n = getTensorPlainSize(dnu);
   AT_DISPATCH_SCALAR_TYPES(dnu.scalar_type(), "adamForwardNuCPU", ([&] {
-                             dupdates_out = updates.mul(2 - 2 * scalar_t(b2)).mul_(dnu);
-                             dnu_out = dnu.mul(scalar_t(b2));
+                             adamBackwardNuCPUKernel<scalar_t, scalar_t>(
+                                 dnu.data_ptr<scalar_t>(),
+                                 updates.data_ptr<scalar_t>(),
+                                 scalar_t(b2),
+                                 n,
+                                 dupdates_out.data_ptr<scalar_t>(),
+                                 dnu_out.data_ptr<scalar_t>());
                            }));
   return TensorArray<2>{std::move(dupdates_out), std::move(dnu_out)};
 }
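
The new backward kernels above are direct chain-rule reads of the forward moment updates; in the same notation as before:

$$
\begin{aligned}
m_t &= \beta_1\, m_{t-1} + (1-\beta_1)\, g
&&\Rightarrow&
\frac{\partial m_t}{\partial g} &= 1-\beta_1, &
\frac{\partial m_t}{\partial m_{t-1}} &= \beta_1, \\
v_t &= \beta_2\, v_{t-1} + (1-\beta_2)\, g^2
&&\Rightarrow&
\frac{\partial v_t}{\partial g} &= 2\,(1-\beta_2)\, g, &
\frac{\partial v_t}{\partial v_{t-1}} &= \beta_2,
\end{aligned}
$$

which is exactly what `adamBackwardMuCPUKernel` writes out, `((1 − b1)·dmu, b1·dmu)`, and what `adamBackwardNuCPUKernel` writes out, `(2·(1 − b2)·updates·dnu, b2·dnu)`.
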
@@ -140,7 +287,8 @@ void adamBackwardUpdatesCPUKernel(const scalar_t *__restrict__ dupdates_ptr,
                                   const size_t n,
                                   scalar_t *__restrict__ dnew_mu_out_ptr,
                                   scalar_t *__restrict__ dnew_nu_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
+#pragma omp parallel for num_threads(std::min( \
+    n / (size_t)min_elements_use_omp, (size_t)omp_get_num_procs())) if (n > min_elements_use_omp)
   for (size_t tid = 0; tid < n; ++tid) {
     const scalar_t dupdates = dupdates_ptr[tid];
     const scalar_t updates = updates_ptr[tid];

0 commit comments
