diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3a524953..073d935d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- Add unroll pragma for CUDA OPs by [@JieRen98](https://github.com/JieRen98) and [@XuehaiPan](https://github.com/XuehaiPan) in [#112](https://github.com/metaopt/torchopt/pull/112).
 - Add Python implementation of accelerated OP and pure-Python wheels by [@XuehaiPan](https://github.com/XuehaiPan) in [#67](https://github.com/metaopt/torchopt/pull/67).
 - Add `nan_to_num` hook and gradient transformation by [@XuehaiPan](https://github.com/XuehaiPan) in [#119](https://github.com/metaopt/torchopt/pull/119).
 - Add matrix inversion linear solver with neumann series approximation by [@Benjamin-eecs](https://github.com/Benjamin-eecs) and [@XuehaiPan](https://github.com/XuehaiPan) in [#98](https://github.com/metaopt/torchopt/pull/98).
diff --git a/src/adam_op/adam_op_impl_cpu.cpp b/src/adam_op/adam_op_impl_cpu.cpp
index a3b1e9e2..cf734c4f 100644
--- a/src/adam_op/adam_op_impl_cpu.cpp
+++ b/src/adam_op/adam_op_impl_cpu.cpp
@@ -135,7 +135,7 @@ void adamForwardNuCPUKernel(const scalar_t *__restrict__ updates_ptr,
     const scalar_t updates = updates_ptr[tid];
     const scalar_t nu = nu_ptr[tid];
 
-    const scalar_t nu_out = b2 * nu + (1 - b2) * pow(updates, 2);
+    const scalar_t nu_out = b2 * nu + (1 - b2) * updates * updates;
     nu_out_ptr[tid] = nu_out;
   }
 }
diff --git a/src/adam_op/adam_op_impl_cuda.cu b/src/adam_op/adam_op_impl_cuda.cu
index 4e0b9ac7..4b65869f 100644
--- a/src/adam_op/adam_op_impl_cuda.cu
+++ b/src/adam_op/adam_op_impl_cuda.cu
@@ -24,7 +24,10 @@ namespace torchopt {
 namespace adam_op {
 
-template <typename scalar_t, typename other_t>
+constexpr int UNROLL_SIZE = 4;
+constexpr int BLOCK_SIZE = 256;
+
+template <typename scalar_t, typename other_t, int unroll_size>
 __global__ void adamForwardInplaceCUDAKernel(const other_t b1,
                                              const other_t inv_one_minus_pow_b1,
                                              const other_t b2,
@@ -35,22 +38,26 @@ __global__ void adamForwardInplaceCUDAKernel(const other_t b1,
                                              scalar_t *__restrict__ updates_ptr,
                                              scalar_t *__restrict__ mu_ptr,
                                              scalar_t *__restrict__ nu_ptr) {
-  unsigned tid = threadIdx.x + blockIdx.x * blockDim.x;
-  if (tid >= n) {
-    return;
+  const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size;
+#pragma unroll
+  for (int i = 0; i < unroll_size; ++i) {
+    size_t tid = toffset + i;
+    if (tid >= n) {
+      return;
+    }
+    const scalar_t updates = updates_ptr[tid];
+    const scalar_t mu = mu_ptr[tid];
+    const scalar_t nu = nu_ptr[tid];
+
+    const scalar_t mu_out = b1 * mu + (1 - b1) * updates;
+    const scalar_t nu_out = b2 * nu + (1 - b2) * updates * updates;
+    const scalar_t updates_out =
+        mu_out * inv_one_minus_pow_b1 / (sqrt(nu_out * inv_one_minus_pow_b2 + eps_root) + eps);
+
+    mu_ptr[tid] = mu_out;
+    nu_ptr[tid] = nu_out;
+    updates_ptr[tid] = updates_out;
   }
-  const scalar_t updates = updates_ptr[tid];
-  const scalar_t mu = mu_ptr[tid];
-  const scalar_t nu = nu_ptr[tid];
-
-  const scalar_t mu_out = b1 * mu + (1 - b1) * updates;
-  const scalar_t nu_out = b2 * nu + (1 - b2) * updates * updates;
-  const scalar_t updates_out =
-      mu_out * inv_one_minus_pow_b1 / (sqrt(nu_out * inv_one_minus_pow_b2 + eps_root) + eps);
-
-  mu_ptr[tid] = mu_out;
-  nu_ptr[tid] = nu_out;
-  updates_ptr[tid] = updates_out;
 }
 
 TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor &updates,
@@ -66,39 +73,61 @@ TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor &updates,
   const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count));
 
   const size_t n = getTensorPlainSize(updates);
-  const dim3 block(std::min(n, size_t(256)));
-  const dim3 grid((n - 1) / block.x + 1);
-  AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardInplaceCUDA", ([&] {
-    adamForwardInplaceCUDAKernel<scalar_t, scalar_t>
-        <<<grid, block>>>(scalar_t(b1),
-                          scalar_t(inv_one_minus_pow_b1),
-                          scalar_t(b2),
-                          scalar_t(inv_one_minus_pow_b2),
-                          scalar_t(eps),
-                          scalar_t(eps_root),
-                          n,
-                          updates.data_ptr<scalar_t>(),
-                          mu.data_ptr<scalar_t>(),
-                          nu.data_ptr<scalar_t>());
-  }));
+  if (n < BLOCK_SIZE * UNROLL_SIZE) {
+    const dim3 block(std::min(n, size_t(BLOCK_SIZE)));
+    const dim3 grid((n - 1) / block.x + 1);
+    AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardInplaceCUDA", ([&] {
+      adamForwardInplaceCUDAKernel<scalar_t, scalar_t, 1>
+          <<<grid, block>>>(scalar_t(b1),
+                            scalar_t(inv_one_minus_pow_b1),
+                            scalar_t(b2),
+                            scalar_t(inv_one_minus_pow_b2),
+                            scalar_t(eps),
+                            scalar_t(eps_root),
+                            n,
+                            updates.data_ptr<scalar_t>(),
+                            mu.data_ptr<scalar_t>(),
+                            nu.data_ptr<scalar_t>());
+    }));
+  } else {
+    const dim3 block(std::min(n / UNROLL_SIZE, size_t(BLOCK_SIZE)));
+    const dim3 grid((n / UNROLL_SIZE - 1) / block.x + 1);
+    AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardInplaceCUDA", ([&] {
+      adamForwardInplaceCUDAKernel<scalar_t, scalar_t, UNROLL_SIZE>
+          <<<grid, block>>>(scalar_t(b1),
+                            scalar_t(inv_one_minus_pow_b1),
+                            scalar_t(b2),
+                            scalar_t(inv_one_minus_pow_b2),
+                            scalar_t(eps),
+                            scalar_t(eps_root),
+                            n,
+                            updates.data_ptr<scalar_t>(),
+                            mu.data_ptr<scalar_t>(),
+                            nu.data_ptr<scalar_t>());
+    }));
+  }
 
   return TensorArray<3>{updates, mu, nu};
 }
 
-template <typename scalar_t, typename other_t>
+template <typename scalar_t, typename other_t, int unroll_size>
 __global__ void adamForwardMuCUDAKernel(const scalar_t *__restrict__ updates_ptr,
                                         const scalar_t *__restrict__ mu_ptr,
                                         const other_t b1,
                                         const size_t n,
                                         scalar_t *__restrict__ mu_out_ptr) {
-  size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
-  if (tid >= n) {
-    return;
+  const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size;
+#pragma unroll
+  for (int i = 0; i < unroll_size; ++i) {
+    size_t tid = toffset + i;
+    if (tid >= n) {
+      return;
+    }
+
+    const scalar_t updates = updates_ptr[tid];
+    const scalar_t mu = mu_ptr[tid];
+    const scalar_t mu_out = b1 * mu + (1 - b1) * updates;
+    mu_out_ptr[tid] = mu_out;
   }
-
-  const scalar_t updates = updates_ptr[tid];
-  const scalar_t mu = mu_ptr[tid];
-  const scalar_t mu_out = b1 * mu + (1 - b1) * updates;
-  mu_out_ptr[tid] = mu_out;
 }
 
 torch::Tensor adamForwardMuCUDA(const torch::Tensor &updates,
@@ -107,35 +136,52 @@ torch::Tensor adamForwardMuCUDA(const torch::Tensor &updates,
   auto mu_out = torch::empty_like(mu);
 
   const size_t n = getTensorPlainSize(updates);
-  const dim3 block(std::min(n, size_t(256)));
-  const dim3 grid((n - 1) / block.x + 1);
-  AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardMuCUDA", ([&] {
-    adamForwardMuCUDAKernel<scalar_t, scalar_t>
-        <<<grid, block>>>(updates.data_ptr<scalar_t>(),
-                          mu.data_ptr<scalar_t>(),
-                          scalar_t(b1),
-                          n,
-                          mu_out.data_ptr<scalar_t>());
-  }));
+  if (n < BLOCK_SIZE * UNROLL_SIZE) {
+    const dim3 block(std::min(n, size_t(BLOCK_SIZE)));
+    const dim3 grid((n - 1) / block.x + 1);
+    AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardMuCUDA", ([&] {
+      adamForwardMuCUDAKernel<scalar_t, scalar_t, 1>
+          <<<grid, block>>>(updates.data_ptr<scalar_t>(),
+                            mu.data_ptr<scalar_t>(),
+                            scalar_t(b1),
+                            n,
+                            mu_out.data_ptr<scalar_t>());
+    }));
+  } else {
+    const dim3 block(std::min(n / UNROLL_SIZE, size_t(BLOCK_SIZE)));
+    const dim3 grid((n / UNROLL_SIZE - 1) / block.x + 1);
+    AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardMuCUDA", ([&] {
+      adamForwardMuCUDAKernel<scalar_t, scalar_t, UNROLL_SIZE>
+          <<<grid, block>>>(updates.data_ptr<scalar_t>(),
+                            mu.data_ptr<scalar_t>(),
+                            scalar_t(b1),
+                            n,
+                            mu_out.data_ptr<scalar_t>());
+    }));
+  }
 
   return mu_out;
 }
 
-template <typename scalar_t, typename other_t>
+template <typename scalar_t, typename other_t, int unroll_size>
 __global__ void adamForwardNuCUDAKernel(const scalar_t *__restrict__ updates_ptr,
                                         const scalar_t *__restrict__ nu_ptr,
                                        const other_t b2,
                                        const size_t n,
                                        scalar_t *__restrict__ nu_out_ptr) {
-  size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
-  if (tid >= n) {
-    return;
+  const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size;
+#pragma unroll
+  for (int i = 0; i < unroll_size; ++i) {
+    size_t tid = toffset + i;
+    if (tid >= n) {
+      return;
+    }
+
+    const scalar_t updates = updates_ptr[tid];
+    const scalar_t nu = nu_ptr[tid];
+
+    const scalar_t nu_out = b2 * nu + (1 - b2) * updates * updates;
+    nu_out_ptr[tid] = nu_out;
   }
-
-  const scalar_t updates = updates_ptr[tid];
-  const scalar_t nu = nu_ptr[tid];
-
-  const scalar_t nu_out = b2 * nu + (1 - b2) * pow(updates, 2);
-  nu_out_ptr[tid] = nu_out;
 }
 
 torch::Tensor adamForwardNuCUDA(const torch::Tensor &updates,
@@ -144,20 +190,33 @@ torch::Tensor adamForwardNuCUDA(const torch::Tensor &updates,
   auto nu_out = torch::empty_like(nu);
 
   const size_t n = getTensorPlainSize(updates);
-  const dim3 block(std::min(n, size_t(256)));
-  const dim3 grid((n - 1) / block.x + 1);
-  AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardNuCUDA", ([&] {
-    adamForwardNuCUDAKernel<scalar_t, scalar_t>
-        <<<grid, block>>>(updates.data_ptr<scalar_t>(),
-                          nu.data_ptr<scalar_t>(),
-                          scalar_t(b2),
-                          n,
-                          nu_out.data_ptr<scalar_t>());
-  }));
+  if (n < BLOCK_SIZE * UNROLL_SIZE) {
+    const dim3 block(std::min(n, size_t(BLOCK_SIZE)));
+    const dim3 grid((n - 1) / block.x + 1);
+    AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardNuCUDA", ([&] {
+      adamForwardNuCUDAKernel<scalar_t, scalar_t, 1>
+          <<<grid, block>>>(updates.data_ptr<scalar_t>(),
+                            nu.data_ptr<scalar_t>(),
+                            scalar_t(b2),
+                            n,
+                            nu_out.data_ptr<scalar_t>());
+    }));
+  } else {
+    const dim3 block(std::min(n / UNROLL_SIZE, size_t(BLOCK_SIZE)));
+    const dim3 grid((n / UNROLL_SIZE - 1) / block.x + 1);
+    AT_DISPATCH_SCALAR_TYPES_CUDA(updates.scalar_type(), "adamForwardNuCUDA", ([&] {
+      adamForwardNuCUDAKernel<scalar_t, scalar_t, UNROLL_SIZE>
+          <<<grid, block>>>(updates.data_ptr<scalar_t>(),
+                            nu.data_ptr<scalar_t>(),
+                            scalar_t(b2),
+                            n,
+                            nu_out.data_ptr<scalar_t>());
+    }));
+  }
 
   return nu_out;
 }
 
-template <typename scalar_t, typename other_t>
+template <typename scalar_t, typename other_t, int unroll_size>
 __global__ void adamForwardUpdatesCUDAKernel(const scalar_t *__restrict__ new_mu_ptr,
                                              const scalar_t *__restrict__ new_nu_ptr,
                                              const other_t inv_one_minus_pow_b1,
@@ -166,16 +225,20 @@ __global__ void adamForwardUpdatesCUDAKernel(const scalar_t *__restrict__ new_mu_ptr,
                                              const other_t eps_root,
                                              const size_t n,
                                              scalar_t *__restrict__ updates_out_ptr) {
-  size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
-  if (tid >= n) {
-    return;
+  const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size;
+#pragma unroll
+  for (int i = 0; i < unroll_size; ++i) {
+    size_t tid = toffset + i;
+    if (tid >= n) {
+      return;
+    }
+
+    const scalar_t new_mu = new_mu_ptr[tid];
+    const scalar_t new_nu = new_nu_ptr[tid];
+    const scalar_t mu_hat = new_mu * inv_one_minus_pow_b1;
+    const scalar_t nu_hat = new_nu * inv_one_minus_pow_b2;
+    updates_out_ptr[tid] = mu_hat / (sqrt(nu_hat + eps_root) + eps);
   }
-
-  const scalar_t new_mu = new_mu_ptr[tid];
-  const scalar_t new_nu = new_nu_ptr[tid];
-  const scalar_t mu_hat = new_mu * inv_one_minus_pow_b1;
-  const scalar_t nu_hat = new_nu * inv_one_minus_pow_b2;
-  updates_out_ptr[tid] = mu_hat / (sqrt(nu_hat + eps_root) + eps);
 }
 
 torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor &new_mu,
@@ -192,37 +255,58 @@ torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor &new_mu,
   auto updates_out = torch::empty_like(new_mu);
 
   const size_t n = getTensorPlainSize(new_mu);
-  const dim3 block(std::min(n, size_t(256)));
-  const dim3 grid((n - 1) / block.x + 1);
-  AT_DISPATCH_SCALAR_TYPES_CUDA(new_mu.scalar_type(), "adamForwardUpdatesCUDA", ([&] {
-    adamForwardUpdatesCUDAKernel<scalar_t, scalar_t>
-        <<<grid, block>>>(new_mu.data_ptr<scalar_t>(),
-                          new_nu.data_ptr<scalar_t>(),
-                          scalar_t(inv_one_minus_pow_b1),
-                          scalar_t(inv_one_minus_pow_b2),
-                          scalar_t(eps),
-                          scalar_t(eps_root),
-                          n,
-                          updates_out.data_ptr<scalar_t>());
-  }));
+  if (n < BLOCK_SIZE * UNROLL_SIZE) {
+    const dim3 block(std::min(n, size_t(BLOCK_SIZE)));
+    const dim3 grid((n - 1) / block.x + 1);
+    AT_DISPATCH_SCALAR_TYPES_CUDA(new_mu.scalar_type(), "adamForwardUpdatesCUDA", ([&] {
+      adamForwardUpdatesCUDAKernel<scalar_t, scalar_t, 1>
+          <<<grid, block>>>(new_mu.data_ptr<scalar_t>(),
+                            new_nu.data_ptr<scalar_t>(),
+                            scalar_t(inv_one_minus_pow_b1),
+                            scalar_t(inv_one_minus_pow_b2),
+                            scalar_t(eps),
+                            scalar_t(eps_root),
+                            n,
+                            updates_out.data_ptr<scalar_t>());
+    }));
+  } else {
+    const dim3 block(std::min(n / UNROLL_SIZE, size_t(BLOCK_SIZE)));
+    const dim3 grid((n / UNROLL_SIZE - 1) / block.x + 1);
+    AT_DISPATCH_SCALAR_TYPES_CUDA(new_mu.scalar_type(), "adamForwardUpdatesCUDA", ([&] {
+      adamForwardUpdatesCUDAKernel<scalar_t, scalar_t, UNROLL_SIZE>
+          <<<grid, block>>>(new_mu.data_ptr<scalar_t>(),
+                            new_nu.data_ptr<scalar_t>(),
+                            scalar_t(inv_one_minus_pow_b1),
+                            scalar_t(inv_one_minus_pow_b2),
+                            scalar_t(eps),
+                            scalar_t(eps_root),
+                            n,
+                            updates_out.data_ptr<scalar_t>());
+    }));
+  }
+
   return updates_out;
 }
 
-template <typename scalar_t, typename other_t>
+template <typename scalar_t, typename other_t, int unroll_size>
 __global__ void adamBackwardMuCUDAKernel(const scalar_t *__restrict__ dmu_ptr,
                                          const other_t b1,
                                          const size_t n,
                                          scalar_t *__restrict__ dupdates_out_ptr,
                                          scalar_t *__restrict__ dmu_out_ptr) {
-  size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
-  if (tid >= n) {
-    return;
+  const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size;
+#pragma unroll
+  for (int i = 0; i < unroll_size; ++i) {
+    size_t tid = toffset + i;
+    if (tid >= n) {
+      return;
+    }
+
+    const scalar_t dmu = dmu_ptr[tid];
+
+    dupdates_out_ptr[tid] = (1 - b1) * dmu;
+    dmu_out_ptr[tid] = b1 * dmu;
   }
-
-  const scalar_t dmu = dmu_ptr[tid];
-
-  dupdates_out_ptr[tid] = (1 - b1) * dmu;
-  dmu_out_ptr[tid] = b1 * dmu;
 }
 
 TensorArray<2> adamBackwardMuCUDA(const torch::Tensor &dmu,
@@ -233,36 +317,53 @@ TensorArray<2> adamBackwardMuCUDA(const torch::Tensor &dmu,
   auto dmu_out = torch::empty_like(mu);
 
   const size_t n = getTensorPlainSize(dmu);
-  const dim3 block(std::min(n, size_t(256)));
-  const dim3 grid((n - 1) / block.x + 1);
-  AT_DISPATCH_SCALAR_TYPES_CUDA(dmu.scalar_type(), "adamBackwardMuCUDA", ([&] {
-    adamBackwardMuCUDAKernel<scalar_t, scalar_t>
-        <<<grid, block>>>(dmu.data_ptr<scalar_t>(),
-                          scalar_t(b1),
-                          n,
-                          dupdates_out.data_ptr<scalar_t>(),
-                          dmu_out.data_ptr<scalar_t>());
-  }));
+  if (n < BLOCK_SIZE * UNROLL_SIZE) {
+    const dim3 block(std::min(n, size_t(BLOCK_SIZE)));
+    const dim3 grid((n - 1) / block.x + 1);
+    AT_DISPATCH_SCALAR_TYPES_CUDA(dmu.scalar_type(), "adamBackwardMuCUDA", ([&] {
+      adamBackwardMuCUDAKernel<scalar_t, scalar_t, 1>
+          <<<grid, block>>>(dmu.data_ptr<scalar_t>(),
+                            scalar_t(b1),
+                            n,
+                            dupdates_out.data_ptr<scalar_t>(),
+                            dmu_out.data_ptr<scalar_t>());
+    }));
+  } else {
+    const dim3 block(std::min(n / UNROLL_SIZE, size_t(BLOCK_SIZE)));
+    const dim3 grid((n / UNROLL_SIZE - 1) / block.x + 1);
+    AT_DISPATCH_SCALAR_TYPES_CUDA(dmu.scalar_type(), "adamBackwardMuCUDA", ([&] {
+      adamBackwardMuCUDAKernel<scalar_t, scalar_t, UNROLL_SIZE>
+          <<<grid, block>>>(dmu.data_ptr<scalar_t>(),
+                            scalar_t(b1),
+                            n,
+                            dupdates_out.data_ptr<scalar_t>(),
+                            dmu_out.data_ptr<scalar_t>());
+    }));
+  }
 
   return TensorArray<2>{std::move(dupdates_out), std::move(dmu_out)};
 }
 
-template <typename scalar_t, typename other_t>
+template <typename scalar_t, typename other_t, int unroll_size>
 __global__ void adamBackwardNuCUDAKernel(const scalar_t *__restrict__ dnu_ptr,
                                          const scalar_t *__restrict__ updates_ptr,
                                          const other_t b2,
                                          const size_t n,
                                          scalar_t *__restrict__ dupdates_out_ptr,
                                          scalar_t *__restrict__ dnu_out_ptr) {
-  size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
-  if (tid >= n) {
-    return;
+  const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size;
+#pragma unroll
+  for (int i = 0; i < unroll_size; ++i) {
+    size_t tid = toffset + i;
+    if (tid >= n) {
+      return;
+    }
+
+    const scalar_t dnu = dnu_ptr[tid];
+    const scalar_t updates = updates_ptr[tid];
+
+    dupdates_out_ptr[tid] = 2 * (1 - b2) * updates * dnu;
+    dnu_out_ptr[tid] = b2 * dnu;
   }
-
-  const scalar_t dnu = dnu_ptr[tid];
-  const scalar_t updates = updates_ptr[tid];
-
-  dupdates_out_ptr[tid] = 2 * (1 - b2) * updates * dnu;
-  dnu_out_ptr[tid] = b2 * dnu;
 }
 
 TensorArray<2> adamBackwardNuCUDA(const torch::Tensor &dnu,
@@ -273,21 +374,35 @@ TensorArray<2> adamBackwardNuCUDA(const torch::Tensor &dnu,
   auto dnu_out = torch::empty_like(nu);
 
   const size_t n = getTensorPlainSize(dnu);
-  const dim3 block(std::min(n, size_t(256)));
-  const dim3 grid((n - 1) / block.x + 1);
-  AT_DISPATCH_SCALAR_TYPES_CUDA(dnu.scalar_type(), "adamForwardNuCUDA", ([&] {
-    adamBackwardNuCUDAKernel<scalar_t, scalar_t>
-        <<<grid, block>>>(dnu.data_ptr<scalar_t>(),
-                          updates.data_ptr<scalar_t>(),
-                          scalar_t(b2),
-                          n,
-                          dupdates_out.data_ptr<scalar_t>(),
-                          dnu_out.data_ptr<scalar_t>());
-  }));
+  if (n < BLOCK_SIZE * UNROLL_SIZE) {
+    const dim3 block(std::min(n, size_t(BLOCK_SIZE)));
+    const dim3 grid((n - 1) / block.x + 1);
+    AT_DISPATCH_SCALAR_TYPES_CUDA(dnu.scalar_type(), "adamForwardNuCUDA", ([&] {
+      adamBackwardNuCUDAKernel<scalar_t, scalar_t, 1>
+          <<<grid, block>>>(dnu.data_ptr<scalar_t>(),
+                            updates.data_ptr<scalar_t>(),
+                            scalar_t(b2),
+                            n,
+                            dupdates_out.data_ptr<scalar_t>(),
+                            dnu_out.data_ptr<scalar_t>());
+    }));
+  } else {
+    const dim3 block(std::min(n / UNROLL_SIZE, size_t(BLOCK_SIZE)));
+    const dim3 grid((n / UNROLL_SIZE - 1) / block.x + 1);
+    AT_DISPATCH_SCALAR_TYPES_CUDA(dnu.scalar_type(), "adamForwardNuCUDA", ([&] {
+      adamBackwardNuCUDAKernel<scalar_t, scalar_t, UNROLL_SIZE>
+          <<<grid, block>>>(dnu.data_ptr<scalar_t>(),
+                            updates.data_ptr<scalar_t>(),
+                            scalar_t(b2),
+                            n,
+                            dupdates_out.data_ptr<scalar_t>(),
+                            dnu_out.data_ptr<scalar_t>());
+    }));
+  }
 
   return TensorArray<2>{std::move(dupdates_out), std::move(dnu_out)};
 }
 
-template <typename scalar_t, typename other_t>
+template <typename scalar_t, typename other_t, int unroll_size>
 __global__ void adamBackwardUpdatesCUDAKernel(const scalar_t *__restrict__ dupdates_ptr,
                                               const scalar_t *__restrict__ updates_ptr,
                                               const scalar_t *__restrict__ new_mu_ptr,
@@ -296,28 +411,32 @@ __global__ void adamBackwardUpdatesCUDAKernel(const scalar_t *__restrict__ dupdates_ptr,
                                               const size_t n,
                                               scalar_t *__restrict__ dnew_mu_out_ptr,
                                               scalar_t *__restrict__ dnew_nu_out_ptr) {
-  size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
-  if (tid >= n) {
-    return;
+  const size_t toffset = (threadIdx.x + blockIdx.x * blockDim.x) * unroll_size;
+#pragma unroll
+  for (int i = 0; i < unroll_size; ++i) {
+    size_t tid = toffset + i;
+    if (tid >= n) {
+      return;
+    }
+
+    const scalar_t dupdates = dupdates_ptr[tid];
+    const scalar_t updates = updates_ptr[tid];
+    const scalar_t new_mu = new_mu_ptr[tid];
+
+    if (new_mu == 0) {
+      dnew_mu_out_ptr[tid] = 0;
+      dnew_nu_out_ptr[tid] = 0;
+      return;
+    }
+
+    const scalar_t updates_div_new_mu = updates / new_mu;
+
+    const scalar_t denominator = updates_div_new_mu * one_minus_pow_b1;
+
+    dnew_mu_out_ptr[tid] = dupdates * updates_div_new_mu;
+    dnew_nu_out_ptr[tid] =
+        -dupdates * updates * denominator * 0.5 * inv_one_minus_pow_b2 * denominator;
   }
-
-  const scalar_t dupdates = dupdates_ptr[tid];
-  const scalar_t updates = updates_ptr[tid];
-  const scalar_t new_mu = new_mu_ptr[tid];
-
-  if (new_mu == 0) {
-    dnew_mu_out_ptr[tid] = 0;
-    dnew_nu_out_ptr[tid] = 0;
-    return;
-  }
-
-  const scalar_t updates_div_new_mu = updates / new_mu;
-
-  const scalar_t denominator = updates_div_new_mu * one_minus_pow_b1;
-
-  dnew_mu_out_ptr[tid] = dupdates * updates_div_new_mu;
-  dnew_nu_out_ptr[tid] =
-      -dupdates * updates * denominator * 0.5 * inv_one_minus_pow_b2 * denominator;
 }
 TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates,
@@ -335,19 +454,35 @@ TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates,
   auto dnu_out = torch::empty_like(new_nu);
 
   const size_t n = getTensorPlainSize(dupdates);
-  const dim3 block(std::min(n, size_t(256)));
-  const dim3 grid((n - 1) / block.x + 1);
-  AT_DISPATCH_SCALAR_TYPES_CUDA(dupdates.scalar_type(), "adamBackwardUpdatesCUDA", ([&] {
-    adamBackwardUpdatesCUDAKernel<scalar_t, scalar_t>
-        <<<grid, block>>>(dupdates.data_ptr<scalar_t>(),
-                          updates.data_ptr<scalar_t>(),
-                          new_mu.data_ptr<scalar_t>(),
-                          scalar_t(one_minus_pow_b1),
-                          scalar_t(inv_one_minus_pow_b2),
-                          n,
-                          dmu_out.data_ptr<scalar_t>(),
-                          dnu_out.data_ptr<scalar_t>());
-  }));
+  if (n < BLOCK_SIZE * UNROLL_SIZE) {
+    const dim3 block(std::min(n, size_t(BLOCK_SIZE)));
+    const dim3 grid((n - 1) / block.x + 1);
+    AT_DISPATCH_SCALAR_TYPES_CUDA(dupdates.scalar_type(), "adamBackwardUpdatesCUDA", ([&] {
+      adamBackwardUpdatesCUDAKernel<scalar_t, scalar_t, 1>
+          <<<grid, block>>>(dupdates.data_ptr<scalar_t>(),
+                            updates.data_ptr<scalar_t>(),
+                            new_mu.data_ptr<scalar_t>(),
+                            scalar_t(one_minus_pow_b1),
+                            scalar_t(inv_one_minus_pow_b2),
+                            n,
+                            dmu_out.data_ptr<scalar_t>(),
+                            dnu_out.data_ptr<scalar_t>());
+    }));
+  } else {
+    const dim3 block(std::min(n / UNROLL_SIZE, size_t(BLOCK_SIZE)));
+    const dim3 grid((n / UNROLL_SIZE - 1) / block.x + 1);
+    AT_DISPATCH_SCALAR_TYPES_CUDA(dupdates.scalar_type(), "adamBackwardUpdatesCUDA", ([&] {
+      adamBackwardUpdatesCUDAKernel<scalar_t, scalar_t, UNROLL_SIZE>
+          <<<grid, block>>>(dupdates.data_ptr<scalar_t>(),
+                            updates.data_ptr<scalar_t>(),
+                            new_mu.data_ptr<scalar_t>(),
+                            scalar_t(one_minus_pow_b1),
+                            scalar_t(inv_one_minus_pow_b2),
+                            n,
+                            dmu_out.data_ptr<scalar_t>(),
+                            dnu_out.data_ptr<scalar_t>());
+    }));
+  }
 
   return TensorArray<2>{std::move(dmu_out), std::move(dnu_out)};
 }
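The diff applies one pattern uniformly to every Adam kernel above: the unroll factor becomes a compile-time template parameter, each thread walks `unroll_size` consecutive elements under `#pragma unroll`, and the host wrapper launches the unrolled instantiation only when the tensor is large enough (`n >= BLOCK_SIZE * UNROLL_SIZE`) to keep all threads busy. Below is a minimal standalone sketch of that pattern, not code from the PR itself; `axpyKernel`, `launchAxpy`, `kUnrollSize`, and `kBlockSize` are hypothetical names used only for illustration.

```cuda
#include <algorithm>
#include <cstddef>

constexpr int kUnrollSize = 4;   // assumed unroll factor, mirrors UNROLL_SIZE in the diff
constexpr int kBlockSize = 256;  // assumed block size, mirrors BLOCK_SIZE in the diff

// Element-wise y = a * x + y with a compile-time unroll factor.
template <typename scalar_t, int unroll_size>
__global__ void axpyKernel(const scalar_t a,
                           const scalar_t *__restrict__ x,
                           scalar_t *__restrict__ y,
                           const size_t n) {
  // Each thread owns a contiguous chunk of `unroll_size` elements.
  const size_t offset = (threadIdx.x + blockIdx.x * (size_t)blockDim.x) * unroll_size;
#pragma unroll
  for (int i = 0; i < unroll_size; ++i) {
    const size_t tid = offset + i;
    if (tid >= n) {
      return;  // indices grow monotonically, so later iterations are also out of range
    }
    y[tid] = a * x[tid] + y[tid];
  }
}

template <typename scalar_t>
void launchAxpy(const scalar_t a, const scalar_t *x, scalar_t *y, const size_t n) {
  if (n < (size_t)kBlockSize * kUnrollSize) {
    // Small tensors: one element per thread, no unrolling.
    const dim3 block(std::min(n, (size_t)kBlockSize));
    const dim3 grid((n - 1) / block.x + 1);
    axpyKernel<scalar_t, 1><<<grid, block>>>(a, x, y, n);
  } else {
    // Large tensors: each thread handles kUnrollSize consecutive elements.
    const dim3 block(std::min(n / kUnrollSize, (size_t)kBlockSize));
    const dim3 grid((n / kUnrollSize - 1) / block.x + 1);
    axpyKernel<scalar_t, kUnrollSize><<<grid, block>>>(a, x, y, n);
  }
}
```

The sketch relies on the same property the PR's kernels do: each thread's indices increase monotonically, so the early `return` on `tid >= n` can safely skip the remainder of that thread's chunk. The trade-off of giving each thread a contiguous chunk is that neighboring threads read addresses `unroll_size` elements apart, which coalesces less well than the one-element-per-thread layout; the `n < BLOCK_SIZE * UNROLL_SIZE` fallback keeps small tensors on the simpler launch.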
