Skip to content

Commit ab2ebff

Browse files
JieRen98 and XuehaiPan authored
feat(custom_op): use dynamic process number (#42)
Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
1 parent 809f9c7 commit ab2ebff

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
## [Unreleased]
1111

12+
### Added
13+
14+
- Use dynamic process number in CPU kernels by [@JieRen98](https://github.com/JieRen98) in [#42](https://github.com/metaopt/TorchOpt/pull/42).
15+
1216
------
1317

1418
## [0.4.2] - 2022-07-26

src/adam_op/adam_op_impl.cpp

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ void adamForwardInplaceCPUKernel(
3131
const other_t inv_one_minus_pow_b2, const other_t eps,
3232
const other_t eps_root, const size_t n, scalar_t* __restrict__ updates_ptr,
3333
scalar_t* __restrict__ mu_ptr, scalar_t* __restrict__ nu_ptr) {
34-
#pragma omp parallel for num_threads(32)
34+
#pragma omp parallel for num_threads(omp_get_num_procs())
3535
for (size_t tid = 0; tid < n; ++tid) {
3636
const scalar_t updates = updates_ptr[tid];
3737
const scalar_t mu = mu_ptr[tid];
@@ -76,7 +76,7 @@ void adamForwardMuCPUKernel(const scalar_t* __restrict__ updates_ptr,
7676
const scalar_t* __restrict__ mu_ptr,
7777
const other_t b1, const size_t n,
7878
scalar_t* __restrict__ mu_out_ptr) {
79-
#pragma omp parallel for num_threads(32)
79+
#pragma omp parallel for num_threads(omp_get_num_procs())
8080
for (size_t tid = 0; tid < n; ++tid) {
8181
const scalar_t updates = updates_ptr[tid];
8282
const scalar_t mu = mu_ptr[tid];
@@ -108,7 +108,7 @@ void adamForwardNuCPUKernel(const scalar_t* __restrict__ updates_ptr,
108108
const scalar_t* __restrict__ nu_ptr,
109109
const other_t b2, const size_t n,
110110
scalar_t* __restrict__ nu_out_ptr) {
111-
#pragma omp parallel for num_threads(32)
111+
#pragma omp parallel for num_threads(omp_get_num_procs())
112112
for (size_t tid = 0; tid < n; ++tid) {
113113
const scalar_t updates = updates_ptr[tid];
114114
const scalar_t nu = nu_ptr[tid];
@@ -144,7 +144,7 @@ void adamForwardUpdatesCPUKernel(const scalar_t* __restrict__ new_mu_ptr,
144144
const other_t eps, const other_t eps_root,
145145
const size_t n,
146146
scalar_t* __restrict__ updates_out_ptr) {
147-
#pragma omp parallel for num_threads(32)
147+
#pragma omp parallel for num_threads(omp_get_num_procs())
148148
for (size_t tid = 0; tid < n; ++tid) {
149149
const scalar_t new_mu = new_mu_ptr[tid];
150150
const scalar_t new_nu = new_nu_ptr[tid];
@@ -185,7 +185,7 @@ void adamBackwardMuCPUKernel(const scalar_t* __restrict__ dmu_ptr,
185185
const other_t b1, const size_t n,
186186
scalar_t* __restrict__ dupdates_out_ptr,
187187
scalar_t* __restrict__ dmu_out_ptr) {
188-
#pragma omp parallel for num_threads(32)
188+
#pragma omp parallel for num_threads(omp_get_num_procs())
189189
for (size_t tid = 0; tid < n; ++tid) {
190190
const scalar_t dmu = dmu_ptr[tid];
191191

@@ -220,7 +220,7 @@ void adamBackwardNuCPUKernel(const scalar_t* __restrict__ dnu_ptr,
220220
const other_t b2, const size_t n,
221221
scalar_t* __restrict__ dupdates_out_ptr,
222222
scalar_t* __restrict__ dnu_out_ptr) {
223-
#pragma omp parallel for num_threads(32)
223+
#pragma omp parallel for num_threads(omp_get_num_procs())
224224
for (size_t tid = 0; tid < n; ++tid) {
225225
const scalar_t dnu = dnu_ptr[tid];
226226
const scalar_t updates = updates_ptr[tid];
@@ -259,7 +259,7 @@ void adamBackwardUpdatesCPUKernel(const scalar_t* __restrict__ dupdates_ptr,
259259
const size_t n,
260260
scalar_t* __restrict__ dnew_mu_out_ptr,
261261
scalar_t* __restrict__ dnew_nu_out_ptr) {
262-
#pragma omp parallel for num_threads(32)
262+
#pragma omp parallel for num_threads(omp_get_num_procs())
263263
for (size_t tid = 0; tid < n; ++tid) {
264264
const scalar_t dupdates = dupdates_ptr[tid];
265265
const scalar_t updates = updates_ptr[tid];

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy