feat: add Python implementation of accelerated OP #67

Merged (10 commits, Nov 29, 2022)

Changes from 1 commit
feat: make CXX/CUDA extension optional
XuehaiPan committed Nov 25, 2022
commit 26dc3349c68c4cb9f70253b7de8834c011be7906
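
With the extension now optional, callers need a pure-Python fallback path. A minimal sketch of how this commit's fallback might be consumed; the try/except wiring is illustrative and not part of this diff, only the module paths torchopt._C.adam_op and torchopt.accelerated_op._src.adam_op come from the files below:

```python
# Illustrative only (not part of this diff): prefer the compiled extension,
# fall back to the pure-Python ops added in this PR when torchopt._C is absent.
try:
    from torchopt._C import adam_op  # C++/CUDA extension, now optional
except ImportError:
    from torchopt.accelerated_op._src import adam_op  # Python fallback
```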
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -39,7 +39,9 @@ classifiers = [
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Operating System :: Microsoft :: Windows",
"Operating System :: POSIX :: Linux",
"Operating System :: MacOS",
"Environment :: GPU",
"Environment :: GPU :: NVIDIA CUDA",
"Intended Audience :: Developers",
49 changes: 37 additions & 12 deletions setup.py
@@ -97,25 +97,50 @@ def build_extension(self, ext):
os.chdir(HERE)


CIBUILDWHEEL = os.getenv('CIBUILDWHEEL', '0') == '1'
LINUX = platform.system() == 'Linux'
MACOS = platform.system() == 'Darwin'
WINDOWS = platform.system() == 'Windows'
ext_kwargs = dict(
cmdclass={'build_ext': cmake_build_ext},
ext_modules=[
CMakeExtension(
'torchopt._C',
source_dir=HERE,
optional=not (LINUX and CIBUILDWHEEL),
)
],
)

TORCHOPT_NO_EXTENSIONS = (
bool(os.getenv('TORCHOPT_NO_EXTENSIONS', '')) or WINDOWS or (MACOS and CIBUILDWHEEL)
)
if TORCHOPT_NO_EXTENSIONS:
ext_kwargs.clear()


VERSION_CONTENT = None
if not version.__release__:
VERSION_CONTENT = VERSION_FILE.read_text(encoding='UTF-8')
VERSION_FILE.write_text(
data=re.sub(
r"""__version__\s*=\s*('[^']+'|"[^"]+")""",
f"__version__ = '{version.__version__}'",
string=VERSION_CONTENT,
),
encoding='UTF-8',
)

try:
if not version.__release__:
try:
VERSION_CONTENT = VERSION_FILE.read_text(encoding='UTF-8')
VERSION_FILE.write_text(
data=re.sub(
r"""__version__\s*=\s*('[^']+'|"[^"]+")""",
r"__version__ = '{}'".format(version.__version__),
string=VERSION_CONTENT,
),
encoding='UTF-8',
)
except OSError:
VERSION_CONTENT = None

setup(
version=version.__version__,
package_data={'sharedlib': ['*.so', '*.pyd']},
include_package_data=True,
cmdclass={'build_ext': cmake_build_ext},
ext_modules=[CMakeExtension('torchopt._C', source_dir=HERE)],
**ext_kwargs,
)
finally:
if VERSION_CONTENT is not None:
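
Net effect of the new setup.py logic: the CMake extension is built best-effort (optional=True) everywhere except Linux cibuildwheel jobs, and is skipped entirely on Windows, on macOS cibuildwheel jobs, or when TORCHOPT_NO_EXTENSIONS is set to any non-empty value. A hedged sketch of forcing a pure-Python install (the command line is an assumption, not taken from this PR):

```python
import os
import subprocess

# Hypothetical: install without building the C++/CUDA extension. Any non-empty
# TORCHOPT_NO_EXTENSIONS value clears ext_kwargs, so setup() runs without
# cmdclass/ext_modules and only the pure-Python files are packaged.
env = dict(os.environ, TORCHOPT_NO_EXTENSIONS='1')
subprocess.run(['python', '-m', 'pip', 'install', '.'], env=env, check=True)
```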
12 changes: 6 additions & 6 deletions src/adam_op/adam_op.cpp
@@ -162,19 +162,19 @@ void buildSubmodule(py::module &mod) {  // NOLINT
py::arg("eps"),
py::arg("eps_root"),
py::arg("count"));
m.def("forwardMu",
m.def("forward_mu",
&adamForwardMu,
"Adam forward mu",
py::arg("updates"),
py::arg("mu"),
py::arg("b1"));
m.def("forwardNu",
m.def("forward_nu",
&adamForwardNu,
"Adam forward nu",
py::arg("updates"),
py::arg("nu"),
py::arg("b2"));
m.def("forwardUpdates",
m.def("forward_updates",
&adamForwardUpdates,
"Adam forward updates",
py::arg("new_mu"),
@@ -184,21 +184,21 @@ void buildSubmodule(py::module &mod) {  // NOLINT
py::arg("eps"),
py::arg("eps_root"),
py::arg("count"));
m.def("backwardMu",
m.def("backward_mu",
&adamBackwardMu,
"Adam backward mu",
py::arg("dmu"),
py::arg("updates"),
py::arg("mu"),
py::arg("b1"));
m.def("backwardNu",
m.def("backward_nu",
&adamBackwardNu,
"Adam backward nu",
py::arg("dnu"),
py::arg("updates"),
py::arg("nu"),
py::arg("b1"));
m.def("backwardUpdates",
m.def("backward_updates",
&adamBackwardUpdates,
"Adam backward updates",
py::arg("dupdates"),
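
The pybind11 binding names move from camelCase to snake_case, so Python call sites change accordingly. A small sketch, assuming the compiled extension is importable; the signatures follow torchopt/_C/adam_op.pyi below, and the tensor shapes are arbitrary:

```python
import torch

from torchopt._C import adam_op  # requires the compiled extension to be present

updates = torch.randn(4)
mu = torch.zeros(4)
nu = torch.zeros(4)

new_mu = adam_op.forward_mu(updates, mu, 0.9)    # was: adam_op.forwardMu(...)
new_nu = adam_op.forward_nu(updates, nu, 0.999)  # was: adam_op.forwardNu(...)
```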
13 changes: 4 additions & 9 deletions src/adam_op/adam_op_impl_cpu.cpp
@@ -186,14 +186,11 @@ torch::Tensor adamForwardUpdatesCPU(const torch::Tensor &new_mu,
const pyfloat_t eps_root,
const pyuint_t count) {
using other_t = pyfloat_t;
const other_t inv_one_minus_pow_b1 = 1 / (1 - std::pow(b1, count));
const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count));

auto updates_out = torch::empty_like(new_mu);

const other_t one_minus_pow_b1 = 1 - std::pow(b1, count);
const other_t inv_one_minus_pow_b1 = 1 / one_minus_pow_b1;
const other_t one_minus_pow_b2 = 1 - std::pow(b2, count);
const other_t inv_one_minus_pow_b2 = 1 / one_minus_pow_b2;

const size_t n = getTensorPlainSize(new_mu);
AT_DISPATCH_SCALAR_TYPES(new_mu.scalar_type(), "adamForwardUpdatesCPU", ([&] {
adamForwardUpdatesCPUKernel<scalar_t, scalar_t>(
@@ -325,14 +322,12 @@ TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor &dupdates,
const pyfloat_t b2,
const pyuint_t count) {
using other_t = pyfloat_t;
const other_t one_minus_pow_b1 = 1 - std::pow(b1, count);
const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count));

auto dmu_out = torch::empty_like(new_mu);
auto dnu_out = torch::empty_like(new_nu);

const other_t one_minus_pow_b1 = 1 - std::pow(b1, count);
const other_t one_minus_pow_b2 = 1 - std::pow(b2, count);
const other_t inv_one_minus_pow_b2 = 1 / one_minus_pow_b2;

const size_t n = getTensorPlainSize(dupdates);
AT_DISPATCH_SCALAR_TYPES(dupdates.scalar_type(), "adamBackwardUpdatesCPU", ([&] {
adamBackwardUpdatesCPUKernel<scalar_t, scalar_t>(
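
For reference, the hoisted constants are the usual Adam bias-correction factors; with count = t, b1 = β₁, b2 = β₂, the kernels compute (this restates the existing code, not new behaviour):

$$\hat{m}_t = \frac{m_t}{1 - \beta_1^{\,t}}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^{\,t}}, \qquad u_t = \frac{\hat{m}_t}{\sqrt{\hat{v}_t + \varepsilon_{\mathrm{root}}} + \varepsilon}$$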
13 changes: 4 additions & 9 deletions src/adam_op/adam_op_impl_cuda.cu
@@ -186,14 +186,11 @@ torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor &new_mu,
const pyfloat_t eps_root,
const pyuint_t count) {
using other_t = pyfloat_t;
const other_t inv_one_minus_pow_b1 = 1 / (1 - std::pow(b1, count));
const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count));

auto updates_out = torch::empty_like(new_mu);

const other_t one_minus_pow_b1 = 1 - std::pow(b1, count);
const other_t inv_one_minus_pow_b1 = 1 / one_minus_pow_b1;
const other_t one_minus_pow_b2 = 1 - std::pow(b2, count);
const other_t inv_one_minus_pow_b2 = 1 / one_minus_pow_b2;

const size_t n = getTensorPlainSize(new_mu);
const dim3 block(std::min(n, size_t(256)));
const dim3 grid((n - 1) / block.x + 1);
@@ -331,14 +328,12 @@ TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates,
const pyfloat_t b2,
const pyuint_t count) {
using other_t = pyfloat_t;
const other_t one_minus_pow_b1 = 1 - std::pow(b1, count);
const other_t inv_one_minus_pow_b2 = 1 / (1 - std::pow(b2, count));

auto dmu_out = torch::empty_like(new_mu);
auto dnu_out = torch::empty_like(new_nu);

const other_t one_minus_pow_b1 = 1 - std::pow(b1, count);
const other_t one_minus_pow_b2 = 1 - std::pow(b2, count);
const other_t inv_one_minus_pow_b2 = 1 / one_minus_pow_b2;

const size_t n = getTensorPlainSize(dupdates);
const dim3 block(std::min(n, size_t(256)));
const dim3 grid((n - 1) / block.x + 1);
12 changes: 6 additions & 6 deletions torchopt/_C/adam_op.pyi
@@ -30,9 +30,9 @@ def forward_(
eps_root: float,
count: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ...
def forwardMu(updates: torch.Tensor, mu: torch.Tensor, b1: float) -> torch.Tensor: ...
def forwardNu(updates: torch.Tensor, nu: torch.Tensor, b2: float) -> torch.Tensor: ...
def forwardUpdates(
def forward_mu(updates: torch.Tensor, mu: torch.Tensor, b1: float) -> torch.Tensor: ...
def forward_nu(updates: torch.Tensor, nu: torch.Tensor, b2: float) -> torch.Tensor: ...
def forward_updates(
new_mu: torch.Tensor,
new_nu: torch.Tensor,
b1: float,
@@ -41,13 +41,13 @@ def forwardUpdates(
eps_root: float,
count: int,
) -> torch.Tensor: ...
def backwardMu(
def backward_mu(
dmu: torch.Tensor, updates: torch.Tensor, mu: torch.Tensor, b1: float
) -> Tuple[torch.Tensor, torch.Tensor]: ...
def backwardNu(
def backward_nu(
dnu: torch.Tensor, updates: torch.Tensor, nu: torch.Tensor, b2: float
) -> Tuple[torch.Tensor, torch.Tensor]: ...
def backwardUpdates(
def backward_updates(
dupdates: torch.Tensor,
updates: torch.Tensor,
new_mu: torch.Tensor,
15 changes: 15 additions & 0 deletions torchopt/accelerated_op/_src/__init__.py
@@ -0,0 +1,15 @@
# Copyright 2022 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Python implementation of accelerated ops."""
114 changes: 114 additions & 0 deletions torchopt/accelerated_op/_src/adam_op.py
@@ -0,0 +1,114 @@
# Copyright 2022 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Python implementation of accelerated AdamOp."""

# pylint: disable=invalid-name,too-many-arguments,unused-argument

from typing import Tuple

import torch


def forward_(
updates: torch.Tensor,
mu: torch.Tensor,
nu: torch.Tensor,
b1: float,
b2: float,
eps: float,
eps_root: float,
count: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Adam forward inplace."""
inv_one_minus_pow_b1 = 1.0 / (1.0 - pow(b1, count))
inv_one_minus_pow_b2 = 1.0 / (1.0 - pow(b2, count))

mu = mu.mul_(b1).add_(updates, alpha=1.0 - b1)
nu = nu.mul_(b2).add_(updates.square(), alpha=1.0 - b2)
updates.data.copy_(
mu.mul(inv_one_minus_pow_b1).div_(
nu.mul(inv_one_minus_pow_b2).add_(eps_root).sqrt_().add_(eps)
)
)
return updates, mu, nu


def forward_mu(updates: torch.Tensor, mu: torch.Tensor, b1: float) -> torch.Tensor:
"""Adam forward mu."""
return mu.mul(b1).add_(updates, alpha=1.0 - b1)


def forward_nu(updates: torch.Tensor, nu: torch.Tensor, b2: float) -> torch.Tensor:
"""Adam forward nu."""
return nu.mul(b2).add_(updates.square(), alpha=1.0 - b2)


def forward_updates(
new_mu: torch.Tensor,
new_nu: torch.Tensor,
b1: float,
b2: float,
eps: float,
eps_root: float,
count: int,
) -> torch.Tensor:
"""Adam forward updates."""
inv_one_minus_pow_b1 = 1.0 / (1.0 - pow(b1, count))
inv_one_minus_pow_b2 = 1.0 / (1.0 - pow(b2, count))
return new_mu.mul(inv_one_minus_pow_b1).div_(
new_nu.mul(inv_one_minus_pow_b2).add_(eps_root).sqrt_().add_(eps)
)


def backward_mu(
dmu: torch.Tensor, updates: torch.Tensor, mu: torch.Tensor, b1: float
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Adam backward mu."""
dupdates = dmu.mul(1.0 - b1)
Inline review comment (Collaborator): 1.0 -> double or float or?

Reply (Member, Author): It will be cast to the same floating point type as dmu.dtype, usually torch.float32.

dmu = dmu.mul(b1)
return dupdates, dmu


def backward_nu(
dnu: torch.Tensor, updates: torch.Tensor, nu: torch.Tensor, b2: float
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Adam backward nu."""
dupdates = updates.mul(dnu).mul_(2.0 * (1.0 - b2))
dnu = dnu.mul(b2)
return dupdates, dnu


def backward_updates(
dupdates: torch.Tensor,
updates: torch.Tensor,
new_mu: torch.Tensor,
new_nu: torch.Tensor,
b1: float,
b2: float,
count: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Adam backward updates."""
one_minus_pow_b1 = 1.0 - pow(b1, count)
inv_one_minus_pow_b2 = 1.0 / (1.0 - pow(b2, count))

updates_div_new_mu = updates.div(new_mu)
denominator = updates_div_new_mu.mul(one_minus_pow_b1)
dnew_mu_out = dupdates.mul(updates_div_new_mu)
dnew_nu_out = dupdates.mul(updates).mul_(denominator.square()).mul_(-0.5 * inv_one_minus_pow_b2)

mask = new_mu == 0
dnew_mu_out[mask] = 0
dnew_nu_out[mask] = 0
return dnew_mu_out, dnew_nu_out
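
Regarding the inline question above about the literal 1.0: a quick, hypothetical check (not part of this PR) that a Python float scalar follows the tensor's dtype under PyTorch's type-promotion rules:

```python
import torch

dmu = torch.ones(3, dtype=torch.float32)
dupdates = dmu.mul(1.0 - 0.9)   # Python float scalar
print(dupdates.dtype)           # torch.float32: the scalar is cast to dmu.dtype
```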