perf(acc_op): add if condition for the element number small situations #105

Merged: 5 commits merged on Nov 6, 2022

Changes from 1 commit
perf(acc_op): use torch API for accelerating CPU codes
JieRen98 authored and XuehaiPan committed Nov 3, 2022
commit 5ac9a39a3ce7a8e08c9094c84b45261cdf4d8979

src/adam_op/adam_op_impl_cpu.cpp: 91 changes (31 additions, 60 deletions)
@@ -67,20 +67,18 @@ TensorArray<3> adamForwardInplaceCPU(const torch::Tensor &updates,
   const other_t inv_one_minus_pow_b1 = 1 / (1 - std::pow(b1, count));
   const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count));

-  const size_t n = getTensorPlainSize(updates);
-  AT_DISPATCH_SCALAR_TYPES(updates.scalar_type(), "adamForwardInplaceCPU", ([&] {
-                             adamForwardInplaceCPUKernel<scalar_t, scalar_t>(
-                                 scalar_t(b1),
-                                 scalar_t(inv_one_minus_pow_b1),
-                                 scalar_t(b2),
-                                 scalar_t(inv_one_minus_pow_b2),
-                                 scalar_t(eps),
-                                 scalar_t(eps_root),
-                                 n,
-                                 updates.data_ptr<scalar_t>(),
-                                 mu.data_ptr<scalar_t>(),
-                                 nu.data_ptr<scalar_t>());
-                           }));
+  AT_DISPATCH_SCALAR_TYPES(
+      updates.scalar_type(), "adamForwardInplaceCPU", ([&] {
+        mu.mul_(scalar_t(b1)).add_(updates, 1 - scalar_t(b1));
+
+        nu.mul_(scalar_t(b2)).addcmul_(updates, updates.conj(), 1 - scalar_t(b2));
+
+        updates.copy_(mu.mul(scalar_t(inv_one_minus_pow_b1))
+                          .div_(nu.mul(inv_one_minus_pow_b2)
+                                    .add_(scalar_t(eps_root))
+                                    .sqrt_()
+                                    .add_(scalar_t(eps))));
+      }));
   return TensorArray<3>{updates, mu, nu};
 }
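For reference, the three tensor statements in the new lambda implement the usual Adam recurrences, with g_t the incoming updates and t = count; this is a reading of the code above, not text from the patch:

$$
\begin{aligned}
\mu_t &= \beta_1 \mu_{t-1} + (1 - \beta_1)\, g_t, \\
\nu_t &= \beta_2 \nu_{t-1} + (1 - \beta_2)\, g_t \bar g_t, \\
u_t &= \frac{\mu_t / (1 - \beta_1^t)}{\sqrt{\nu_t / (1 - \beta_2^t) + \varepsilon_{\mathrm{root}}} + \varepsilon}.
\end{aligned}
$$

The inv_one_minus_pow_* scalars are the two bias-correction factors, computed once on the host with std::pow and folded into the tensor expression.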

@@ -102,16 +100,10 @@ void adamForwardMuCPUKernel(const scalar_t *__restrict__ updates_ptr,
 torch::Tensor adamForwardMuCPU(const torch::Tensor &updates,
                                const torch::Tensor &mu,
                                const pyfloat_t b1) {
-  auto mu_out = torch::empty_like(mu);
+  torch::Tensor mu_out;

-  const size_t n = getTensorPlainSize(updates);
   AT_DISPATCH_SCALAR_TYPES(updates.scalar_type(), "adamForwardMuCPU", ([&] {
-                             adamForwardMuCPUKernel<scalar_t, scalar_t>(
-                                 updates.data_ptr<scalar_t>(),
-                                 mu.data_ptr<scalar_t>(),
-                                 scalar_t(b1),
-                                 n,
-                                 mu_out.data_ptr<scalar_t>());
+                             mu_out = mu.mul(b1).add_(updates, 1 - scalar_t(b1));
                            }));
   return mu_out;
 }
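The replacement leans on the two-argument overload Tensor::add_(other, alpha), which accumulates alpha * other into self in place, so the whole chain allocates exactly one tensor (from mul). A minimal standalone sketch, assuming only libtorch; ema_step is a hypothetical name, not a function in this patch:

    #include <torch/torch.h>

    // Exponential moving average: mu_out = b1 * mu + (1 - b1) * updates.
    // mul() allocates the result once; add_(other, alpha) then adds
    // alpha * other into that buffer without a second temporary.
    torch::Tensor ema_step(const torch::Tensor &updates,
                           const torch::Tensor &mu, double b1) {
      return mu.mul(b1).add_(updates, 1 - b1);
    }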
@@ -135,16 +127,11 @@ void adamForwardNuCPUKernel(const scalar_t *__restrict__ updates_ptr,
 torch::Tensor adamForwardNuCPU(const torch::Tensor &updates,
                                const torch::Tensor &nu,
                                const pyfloat_t b2) {
-  auto nu_out = torch::empty_like(nu);
+  torch::Tensor nu_out;

-  const size_t n = getTensorPlainSize(updates);
   AT_DISPATCH_SCALAR_TYPES(updates.scalar_type(), "adamForwardNuCPU", ([&] {
-                             adamForwardNuCPUKernel<scalar_t, scalar_t>(
-                                 updates.data_ptr<scalar_t>(),
-                                 nu.data_ptr<scalar_t>(),
-                                 scalar_t(b2),
-                                 n,
-                                 nu_out.data_ptr<scalar_t>());
+                             nu_out =
+                                 nu.mul(b2).addcmul_(updates, updates.conj(), 1 - scalar_t(b2));
                            }));
   return nu_out;
 }
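Here addcmul_(t1, t2, value) accumulates value * t1 * t2 into self, and updates.conj() makes the second moment g * conj(g) = |g|^2 for complex parameters (for real tensors, conj is a lazy no-op view). A sketch of the same pattern in isolation, again assuming libtorch and a hypothetical helper name:

    #include <torch/torch.h>

    // Second-moment EMA: nu_out = b2 * nu + (1 - b2) * updates * conj(updates).
    torch::Tensor second_moment_step(const torch::Tensor &updates,
                                     const torch::Tensor &nu, double b2) {
      // addcmul_ fuses the elementwise product and the scaled accumulation.
      return nu.mul(b2).addcmul_(updates, updates.conj(), 1 - b2);
    }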
@@ -177,24 +164,19 @@ torch::Tensor adamForwardUpdatesCPU(const torch::Tensor &new_mu,
                                     const pyuint_t count) {
   using other_t = pyfloat_t;

-  auto updates_out = torch::empty_like(new_mu);
+  torch::Tensor updates_out;

   const other_t one_minus_pow_b1 = 1 - std::pow(b1, count);
   const other_t inv_one_minus_pow_b1 = 1 / one_minus_pow_b1;
   const other_t one_minus_pow_b2 = 1 - std::pow(b2, count);
   const other_t inv_one_minus_pow_b2 = 1 / one_minus_pow_b2;

-  const size_t n = getTensorPlainSize(new_mu);
   AT_DISPATCH_SCALAR_TYPES(new_mu.scalar_type(), "adamForwardUpdatesCPU", ([&] {
-                             adamForwardUpdatesCPUKernel<scalar_t, scalar_t>(
-                                 new_mu.data_ptr<scalar_t>(),
-                                 new_nu.data_ptr<scalar_t>(),
-                                 scalar_t(inv_one_minus_pow_b1),
-                                 scalar_t(inv_one_minus_pow_b2),
-                                 scalar_t(eps),
-                                 scalar_t(eps_root),
-                                 n,
-                                 updates_out.data_ptr<scalar_t>());
+                             updates_out = new_mu.mul(scalar_t(inv_one_minus_pow_b1))
+                                               .div_(new_nu.mul(scalar_t(inv_one_minus_pow_b2))
+                                                         .add_(scalar_t(eps_root))
+                                                         .sqrt_()
+                                                         .add_(scalar_t(eps)));
                            }));
   return updates_out;
 }
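Only the two mul calls in the new expression allocate; div_, add_, and sqrt_ all mutate those temporaries in place. Put together as a free function (a sketch under the same assumptions as above, with a hypothetical name):

    #include <torch/torch.h>
    #include <cmath>

    // Bias-corrected Adam step:
    //   (mu / (1 - b1^t)) / (sqrt(nu / (1 - b2^t) + eps_root) + eps)
    torch::Tensor adam_step(const torch::Tensor &mu, const torch::Tensor &nu,
                            double b1, double b2, double eps, double eps_root,
                            int64_t count) {
      const double c1 = 1.0 / (1.0 - std::pow(b1, count));
      const double c2 = 1.0 / (1.0 - std::pow(b2, count));
      return mu.mul(c1).div_(nu.mul(c2).add_(eps_root).sqrt_().add_(eps));
    }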
@@ -218,17 +200,12 @@ TensorArray<2> adamBackwardMuCPU(const torch::Tensor &dmu,
                                  const torch::Tensor &updates,
                                  const torch::Tensor &mu,
                                  const pyfloat_t b1) {
-  auto dupdates_out = torch::empty_like(updates);
-  auto dmu_out = torch::empty_like(mu);
+  torch::Tensor dupdates_out;
+  torch::Tensor dmu_out;

-  const size_t n = getTensorPlainSize(dmu);
   AT_DISPATCH_SCALAR_TYPES(dmu.scalar_type(), "adamBackwardMuCPU", ([&] {
-                             adamBackwardMuCPUKernel<scalar_t, scalar_t>(
-                                 dmu.data_ptr<scalar_t>(),
-                                 scalar_t(b1),
-                                 n,
-                                 dupdates_out.data_ptr<scalar_t>(),
-                                 dmu_out.data_ptr<scalar_t>());
+                             dupdates_out = dmu.mul(1 - scalar_t(b1));
+                             dmu_out = dmu.mul(scalar_t(b1));
                            }));
   return TensorArray<2>{std::move(dupdates_out), std::move(dmu_out)};
 }
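The two scalings fall directly out of the forward rule: differentiating the first-moment update gives

$$
\frac{\partial L}{\partial g_t} = (1 - \beta_1)\,\frac{\partial L}{\partial \mu_t},
\qquad
\frac{\partial L}{\partial \mu_{t-1}} = \beta_1\,\frac{\partial L}{\partial \mu_t},
$$

which is exactly dmu.mul(1 - b1) and dmu.mul(b1) above.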
@@ -254,18 +231,12 @@ TensorArray<2> adamBackwardNuCPU(const torch::Tensor &dnu,
                                  const torch::Tensor &updates,
                                  const torch::Tensor &nu,
                                  const pyfloat_t b2) {
-  auto dupdates_out = torch::empty_like(updates);
-  auto dnu_out = torch::empty_like(nu);
+  torch::Tensor dupdates_out;
+  torch::Tensor dnu_out;

-  const size_t n = getTensorPlainSize(dnu);
   AT_DISPATCH_SCALAR_TYPES(dnu.scalar_type(), "adamForwardNuCPU", ([&] {
-                             adamBackwardNuCPUKernel<scalar_t, scalar_t>(
-                                 dnu.data_ptr<scalar_t>(),
-                                 updates.data_ptr<scalar_t>(),
-                                 scalar_t(b2),
-                                 n,
-                                 dupdates_out.data_ptr<scalar_t>(),
-                                 dnu_out.data_ptr<scalar_t>());
+                             dupdates_out = updates.mul(2 - 2 * scalar_t(b2)).mul_(dnu);
+                             dnu_out = dnu.mul(scalar_t(b2));
                            }));
   return TensorArray<2>{std::move(dupdates_out), std::move(dnu_out)};
 }
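Likewise, for the real-valued second moment the chain rule applied to the forward update gives

$$
\frac{\partial L}{\partial g_t} = 2 (1 - \beta_2)\, g_t \odot \frac{\partial L}{\partial \nu_t},
\qquad
\frac{\partial L}{\partial \nu_{t-1}} = \beta_2\,\frac{\partial L}{\partial \nu_t},
$$

which matches the updates.mul(2 - 2 * b2).mul_(dnu) expression. (Note that the dispatch label in this function still reads "adamForwardNuCPU" even though the surrounding function is adamBackwardNuCPU.)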