diff --git a/CHANGELOG.md b/CHANGELOG.md
index a01f6751..d0747f6c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- Add if condition of number of threads for CPU OPs by [@JieRen98](https://github.com/JieRen98) in [#105](https://github.com/metaopt/torchopt/pull/105).
 - Add implicit MAML omniglot few-shot classification example with OOP APIs by [@XuehaiPan](https://github.com/XuehaiPan) in [#107](https://github.com/metaopt/torchopt/pull/107).
 - Add implicit MAML omniglot few-shot classification example by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#48](https://github.com/metaopt/torchopt/pull/48).
 - Add object-oriented modules support for implicit meta-gradient by [@XuehaiPan](https://github.com/XuehaiPan) in [#101](https://github.com/metaopt/torchopt/pull/101).
diff --git a/src/adam_op/adam_op_impl_cpu.cpp b/src/adam_op/adam_op_impl_cpu.cpp
index 82accd8c..c2ae4ae0 100644
--- a/src/adam_op/adam_op_impl_cpu.cpp
+++ b/src/adam_op/adam_op_impl_cpu.cpp
@@ -27,6 +27,8 @@ using std::size_t;
 
 namespace adam_op {
 
+constexpr size_t MIN_NUMEL_USE_OMP = 1000;
+
 template <typename scalar_t, typename other_t>
 void adamForwardInplaceCPUKernel(const other_t b1,
                                  const other_t inv_one_minus_pow_b1,
@@ -38,7 +40,9 @@ void adamForwardInplaceCPUKernel(const other_t b1,
                                  scalar_t *__restrict__ updates_ptr,
                                  scalar_t *__restrict__ mu_ptr,
                                  scalar_t *__restrict__ nu_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
+#pragma omp parallel for num_threads(                               \
+    std::min(n / MIN_NUMEL_USE_OMP,                                 \
+             static_cast<size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)  // NOLINT
   for (size_t tid = 0; tid < n; ++tid) {
     const scalar_t updates = updates_ptr[tid];
     const scalar_t mu = mu_ptr[tid];
@@ -90,7 +94,9 @@ void adamForwardMuCPUKernel(const scalar_t *__restrict__ updates_ptr,
                             const other_t b1,
                             const size_t n,
                             scalar_t *__restrict__ mu_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
+#pragma omp parallel for num_threads(                               \
+    std::min(n / MIN_NUMEL_USE_OMP,                                 \
+             static_cast<size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)  // NOLINT
   for (size_t tid = 0; tid < n; ++tid) {
     const scalar_t updates = updates_ptr[tid];
     const scalar_t mu = mu_ptr[tid];
@@ -122,7 +128,9 @@ void adamForwardNuCPUKernel(const scalar_t *__restrict__ updates_ptr,
                             const other_t b2,
                             const size_t n,
                             scalar_t *__restrict__ nu_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
+#pragma omp parallel for num_threads(                               \
+    std::min(n / MIN_NUMEL_USE_OMP,                                 \
+             static_cast<size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)  // NOLINT
   for (size_t tid = 0; tid < n; ++tid) {
     const scalar_t updates = updates_ptr[tid];
     const scalar_t nu = nu_ptr[tid];
@@ -158,7 +166,9 @@ void adamForwardUpdatesCPUKernel(const scalar_t *__restrict__ new_mu_ptr,
                                  const other_t eps_root,
                                  const size_t n,
                                  scalar_t *__restrict__ updates_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
+#pragma omp parallel for num_threads(                               \
+    std::min(n / MIN_NUMEL_USE_OMP,                                 \
+             static_cast<size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)  // NOLINT
   for (size_t tid = 0; tid < n; ++tid) {
     const scalar_t new_mu = new_mu_ptr[tid];
     const scalar_t new_nu = new_nu_ptr[tid];
@@ -205,7 +215,9 @@ void adamBackwardMuCPUKernel(const scalar_t *__restrict__ dmu_ptr,
                              const size_t n,
                              scalar_t *__restrict__ dupdates_out_ptr,
                              scalar_t *__restrict__ dmu_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
+#pragma omp parallel for num_threads(                               \
+    std::min(n / MIN_NUMEL_USE_OMP,                                 \
+             static_cast<size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)  // NOLINT
   for (size_t tid = 0; tid < n; ++tid) {
     const scalar_t dmu = dmu_ptr[tid];
 
@@ -240,7 +252,9 @@ void adamBackwardNuCPUKernel(const scalar_t *__restrict__ dnu_ptr,
                              const size_t n,
                              scalar_t *__restrict__ dupdates_out_ptr,
                              scalar_t *__restrict__ dnu_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
+#pragma omp parallel for num_threads(                               \
+    std::min(n / MIN_NUMEL_USE_OMP,                                 \
+             static_cast<size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)  // NOLINT
   for (size_t tid = 0; tid < n; ++tid) {
     const scalar_t dnu = dnu_ptr[tid];
     const scalar_t updates = updates_ptr[tid];
@@ -279,7 +293,9 @@ void adamBackwardUpdatesCPUKernel(const scalar_t *__restrict__ dupdates_ptr,
                                   const size_t n,
                                   scalar_t *__restrict__ dnew_mu_out_ptr,
                                   scalar_t *__restrict__ dnew_nu_out_ptr) {
-#pragma omp parallel for num_threads(omp_get_num_procs())
+#pragma omp parallel for num_threads(                               \
+    std::min(n / MIN_NUMEL_USE_OMP,                                 \
+             static_cast<size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)  // NOLINT
   for (size_t tid = 0; tid < n; ++tid) {
     const scalar_t dupdates = dupdates_ptr[tid];
     const scalar_t updates = updates_ptr[tid];
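Note on the pattern: each kernel above now parallelizes only when the element count `n` exceeds `MIN_NUMEL_USE_OMP` (1000), and caps the OpenMP team at `std::min(n / MIN_NUMEL_USE_OMP, omp_get_num_procs())` so every thread handles at least roughly 1000 elements. Below is a minimal standalone sketch of the same guard outside the torchopt sources; the `axpyCPUKernel` name and the `main` driver are illustrative assumptions, not part of this patch (compile with `-fopenmp`).

```cpp
#include <omp.h>

#include <algorithm>
#include <cstddef>
#include <vector>

// Illustrative threshold, mirroring the constant introduced by the patch.
constexpr std::size_t MIN_NUMEL_USE_OMP = 1000;

// Toy scale-and-add kernel used only to demonstrate the threading guard.
void axpyCPUKernel(const float a, const float *x, float *y, const std::size_t n) {
// The `if` clause keeps the loop on a single thread for small inputs; when it
// is parallel, `num_threads` caps the team at min(n / MIN_NUMEL_USE_OMP,
// number of processors) so each thread gets a meaningful chunk of work.
#pragma omp parallel for num_threads(                               \
    std::min(n / MIN_NUMEL_USE_OMP,                                 \
             static_cast<std::size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP)
  for (std::size_t tid = 0; tid < n; ++tid) {
    y[tid] = a * x[tid] + y[tid];
  }
}

int main() {
  std::vector<float> x(1 << 20, 1.0F), y(1 << 20, 2.0F);
  axpyCPUKernel(0.5F, x.data(), y.data(), x.size());  // large input: multi-threaded
  axpyCPUKernel(0.5F, x.data(), y.data(), 128);       // small input: stays serial
  return 0;
}
```

For small tensors the loop runs serially, so the cost of spawning an OpenMP team is only paid when there is enough work to amortize it.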
