@@ -31,7 +31,7 @@ void adamForwardInplaceCPUKernel(
31
31
const other_t inv_one_minus_pow_b2, const other_t eps,
32
32
const other_t eps_root, const size_t n, scalar_t * __restrict__ updates_ptr,
33
33
scalar_t * __restrict__ mu_ptr, scalar_t * __restrict__ nu_ptr) {
34
- #pragma omp parallel for num_threads(32 )
34
+ #pragma omp parallel for num_threads(omp_get_num_procs() )
35
35
for (size_t tid = 0 ; tid < n; ++tid) {
36
36
const scalar_t updates = updates_ptr[tid];
37
37
const scalar_t mu = mu_ptr[tid];
@@ -76,7 +76,7 @@ void adamForwardMuCPUKernel(const scalar_t* __restrict__ updates_ptr,
76
76
const scalar_t * __restrict__ mu_ptr,
77
77
const other_t b1, const size_t n,
78
78
scalar_t * __restrict__ mu_out_ptr) {
79
- #pragma omp parallel for num_threads(32 )
79
+ #pragma omp parallel for num_threads(omp_get_num_procs() )
80
80
for (size_t tid = 0 ; tid < n; ++tid) {
81
81
const scalar_t updates = updates_ptr[tid];
82
82
const scalar_t mu = mu_ptr[tid];
@@ -108,7 +108,7 @@ void adamForwardNuCPUKernel(const scalar_t* __restrict__ updates_ptr,
108
108
const scalar_t * __restrict__ nu_ptr,
109
109
const other_t b2, const size_t n,
110
110
scalar_t * __restrict__ nu_out_ptr) {
111
- #pragma omp parallel for num_threads(32 )
111
+ #pragma omp parallel for num_threads(omp_get_num_procs() )
112
112
for (size_t tid = 0 ; tid < n; ++tid) {
113
113
const scalar_t updates = updates_ptr[tid];
114
114
const scalar_t nu = nu_ptr[tid];
@@ -144,7 +144,7 @@ void adamForwardUpdatesCPUKernel(const scalar_t* __restrict__ new_mu_ptr,
144
144
const other_t eps, const other_t eps_root,
145
145
const size_t n,
146
146
scalar_t * __restrict__ updates_out_ptr) {
147
- #pragma omp parallel for num_threads(32 )
147
+ #pragma omp parallel for num_threads(omp_get_num_procs() )
148
148
for (size_t tid = 0 ; tid < n; ++tid) {
149
149
const scalar_t new_mu = new_mu_ptr[tid];
150
150
const scalar_t new_nu = new_nu_ptr[tid];
@@ -185,7 +185,7 @@ void adamBackwardMuCPUKernel(const scalar_t* __restrict__ dmu_ptr,
185
185
const other_t b1, const size_t n,
186
186
scalar_t * __restrict__ dupdates_out_ptr,
187
187
scalar_t * __restrict__ dmu_out_ptr) {
188
- #pragma omp parallel for num_threads(32 )
188
+ #pragma omp parallel for num_threads(omp_get_num_procs() )
189
189
for (size_t tid = 0 ; tid < n; ++tid) {
190
190
const scalar_t dmu = dmu_ptr[tid];
191
191
@@ -220,7 +220,7 @@ void adamBackwardNuCPUKernel(const scalar_t* __restrict__ dnu_ptr,
220
220
const other_t b2, const size_t n,
221
221
scalar_t * __restrict__ dupdates_out_ptr,
222
222
scalar_t * __restrict__ dnu_out_ptr) {
223
- #pragma omp parallel for num_threads(32 )
223
+ #pragma omp parallel for num_threads(omp_get_num_procs() )
224
224
for (size_t tid = 0 ; tid < n; ++tid) {
225
225
const scalar_t dnu = dnu_ptr[tid];
226
226
const scalar_t updates = updates_ptr[tid];
@@ -259,7 +259,7 @@ void adamBackwardUpdatesCPUKernel(const scalar_t* __restrict__ dupdates_ptr,
259
259
const size_t n,
260
260
scalar_t * __restrict__ dnew_mu_out_ptr,
261
261
scalar_t * __restrict__ dnew_nu_out_ptr) {
262
- #pragma omp parallel for num_threads(32 )
262
+ #pragma omp parallel for num_threads(omp_get_num_procs() )
263
263
for (size_t tid = 0 ; tid < n; ++tid) {
264
264
const scalar_t dupdates = dupdates_ptr[tid];
265
265
const scalar_t updates = updates_ptr[tid];
0 commit comments