diff --git a/CHANGELOG.md b/CHANGELOG.md index 62d6ab83..e99aa3b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fix CUDA build for accelerated OP [@XuehaiPan](https://github.com/XuehaiPan) in [#53](https://github.com/metaopt/TorchOpt/pull/53). + ### Removed ------ diff --git a/CMakeLists.txt b/CMakeLists.txt index b4b5400c..fe93c7aa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -30,6 +30,7 @@ find_package(CUDA) if(CUDA_FOUND) message(STATUS "Found CUDA, enabling CUDA support.") enable_language(CUDA) + add_definitions(-D__CUDA_ENABLED__) cuda_select_nvcc_arch_flags(CUDA_ARCH_FLAGS All) list(APPEND CUDA_NVCC_FLAGS ${CUDA_ARCH_FLAGS}) diff --git a/src/adam_op/adam_op.cpp b/src/adam_op/adam_op.cpp index bb4531ac..0fa2fa4d 100644 --- a/src/adam_op/adam_op.cpp +++ b/src/adam_op/adam_op.cpp @@ -19,7 +19,7 @@ #include #include "include/adam_op/adam_op_impl_cpu.h" -#if defined(__CUDACC__) +#if defined(__CUDA_ENABLED__) #include "include/adam_op/adam_op_impl_cuda.cuh" #endif @@ -30,7 +30,7 @@ TensorArray<3> adamForwardInplace(const torch::Tensor& updates, const pyfloat_t b2, const pyfloat_t eps, const pyfloat_t eps_root, const pyuint_t count) { -#if defined(__CUDACC__) +#if defined(__CUDA_ENABLED__) if (updates.device().is_cuda()) { return adamForwardInplaceCUDA(updates, mu, nu, b1, b2, eps, eps_root, count); @@ -44,7 +44,7 @@ TensorArray<3> adamForwardInplace(const torch::Tensor& updates, } torch::Tensor adamForwardMu(const torch::Tensor& updates, const torch::Tensor& mu, const pyfloat_t b1) { -#if defined(__CUDACC__) +#if defined(__CUDA_ENABLED__) if (updates.device().is_cuda()) { return adamForwardMuCUDA(updates, mu, b1); } @@ -58,7 +58,7 @@ torch::Tensor adamForwardMu(const torch::Tensor& updates, torch::Tensor adamForwardNu(const torch::Tensor& updates, const torch::Tensor& nu, const pyfloat_t b2) { -#if defined(__CUDACC__) +#if defined(__CUDA_ENABLED__) 
if (updates.device().is_cuda()) { return adamForwardNuCUDA(updates, nu, b2); } @@ -75,7 +75,7 @@ torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu, const pyfloat_t b1, const pyfloat_t b2, const pyfloat_t eps, const pyfloat_t eps_root, const pyuint_t count) { -#if defined(__CUDACC__) +#if defined(__CUDA_ENABLED__) if (new_mu.device().is_cuda()) { return adamForwardUpdatesCUDA(new_mu, new_nu, b1, b2, eps, eps_root, count); } @@ -90,7 +90,7 @@ torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu, TensorArray<2> adamBackwardMu(const torch::Tensor& dmu, const torch::Tensor& updates, const torch::Tensor& mu, const pyfloat_t b1) { -#if defined(__CUDACC__) +#if defined(__CUDA_ENABLED__) if (dmu.device().is_cuda()) { return adamBackwardMuCUDA(dmu, updates, mu, b1); } @@ -105,7 +105,7 @@ TensorArray<2> adamBackwardMu(const torch::Tensor& dmu, TensorArray<2> adamBackwardNu(const torch::Tensor& dnu, const torch::Tensor& updates, const torch::Tensor& nu, const pyfloat_t b2) { -#if defined(__CUDACC__) +#if defined(__CUDA_ENABLED__) if (dnu.device().is_cuda()) { return adamBackwardNuCUDA(dnu, updates, nu, b2); } @@ -123,7 +123,7 @@ TensorArray<2> adamBackwardUpdates(const torch::Tensor& dupdates, const torch::Tensor& new_nu, const pyfloat_t b1, const pyfloat_t b2, const pyuint_t count) { -#if defined(__CUDACC__) +#if defined(__CUDA_ENABLED__) if (dupdates.device().is_cuda()) { return adamBackwardUpdatesCUDA(dupdates, updates, new_mu, new_nu, b1, b2, count);