diff --git a/Cargo.toml b/Cargo.toml index a648b09bc..c75066c5d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,8 @@ matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm serde = { version = "1.0", optional = true, default-features = false, features = ["alloc"] } rawpointer = { version = "0.2" } +dlpark = { version = "0.3.0", optional = true } + [dev-dependencies] defmac = "0.2" quickcheck = { version = "1.0", default-features = false } @@ -73,6 +75,8 @@ rayon = ["rayon_", "std"] matrixmultiply-threading = ["matrixmultiply/threading"] +dlpack = ["dep:dlpark"] + [profile.bench] debug = true [profile.dev.package.numeric-tests] diff --git a/src/data_traits.rs b/src/data_traits.rs index acf4b0b7a..7095db73d 100644 --- a/src/data_traits.rs +++ b/src/data_traits.rs @@ -17,9 +17,12 @@ use alloc::sync::Arc; use alloc::vec::Vec; use crate::{ - ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr, + ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr }; +#[cfg(feature = "dlpack")] +use crate::ManagedRepr; + /// Array representation trait. /// /// For an array that meets the invariants of the `ArrayBase` type. This trait @@ -346,6 +349,24 @@ unsafe impl RawData for OwnedRepr { private_impl! {} } +#[cfg(feature = "dlpack")] +unsafe impl RawData for ManagedRepr { + type Elem = A; + + fn _data_slice(&self) -> Option<&[A]> { + Some(self.as_slice()) + } + + fn _is_pointer_inbounds(&self, self_ptr: *const Self::Elem) -> bool { + let slc = self.as_slice(); + let ptr = slc.as_ptr() as *mut A; + let end = unsafe { ptr.add(slc.len()) }; + self_ptr >= ptr && self_ptr <= end + } + + private_impl! {} +} + unsafe impl RawDataMut for OwnedRepr { #[inline] fn try_ensure_unique(_: &mut ArrayBase) @@ -382,6 +403,28 @@ unsafe impl Data for OwnedRepr { } } +#[cfg(feature = "dlpack")] +unsafe impl Data for ManagedRepr { + #[inline] + fn into_owned(self_: ArrayBase) -> Array + where + A: Clone, + D: Dimension, + { + self_.to_owned() + } + + #[inline] + fn try_into_owned_nocopy( + self_: ArrayBase, + ) -> Result, ArrayBase> + where + D: Dimension, + { + Err(self_) + } +} + unsafe impl DataMut for OwnedRepr {} unsafe impl RawDataClone for OwnedRepr diff --git a/src/dlpack.rs b/src/dlpack.rs new file mode 100644 index 000000000..d80919e11 --- /dev/null +++ b/src/dlpack.rs @@ -0,0 +1,94 @@ +use core::ptr::NonNull; +use std::marker::PhantomData; + +use dlpark::prelude::*; + +use crate::{ArrayBase, Dimension, IntoDimension, IxDyn, ManagedArray, RawData}; + +impl ToTensor for ArrayBase +where + A: InferDtype, + S: RawData, + D: Dimension, +{ + fn data_ptr(&self) -> *mut std::ffi::c_void { + self.as_ptr() as *mut std::ffi::c_void + } + + fn byte_offset(&self) -> u64 { + 0 + } + + fn device(&self) -> Device { + Device::CPU + } + + fn dtype(&self) -> DataType { + A::infer_dtype() + } + + fn shape(&self) -> CowIntArray { + dlpark::prelude::CowIntArray::from_owned( + self.shape().into_iter().map(|&x| x as i64).collect(), + ) + } + + fn strides(&self) -> Option { + Some(dlpark::prelude::CowIntArray::from_owned( + self.strides().into_iter().map(|&x| x as i64).collect(), + )) + } +} + +pub struct ManagedRepr { + managed_tensor: ManagedTensor, + _ty: PhantomData, +} + +impl ManagedRepr { + pub fn new(managed_tensor: ManagedTensor) -> Self { + Self { + managed_tensor, + _ty: PhantomData, + } + } + + pub fn as_slice(&self) -> &[A] { + self.managed_tensor.as_slice() + } + + pub fn as_ptr(&self) -> *const A { + self.managed_tensor.data_ptr() as *const A + } +} + +unsafe impl Sync for ManagedRepr where A: Sync {} +unsafe impl Send for ManagedRepr where A: Send {} + +impl FromDLPack for ManagedArray { + fn from_dlpack(dlpack: NonNull) -> Self { + let managed_tensor = ManagedTensor::new(dlpack); + let shape: Vec = managed_tensor + .shape() + .into_iter() + .map(|x| *x as _) + .collect(); + + let strides: Vec = match (managed_tensor.strides(), managed_tensor.is_contiguous()) { + (Some(s), _) => s.into_iter().map(|&x| x as _).collect(), + (None, true) => managed_tensor + .calculate_contiguous_strides() + .into_iter() + .map(|x| x as _) + .collect(), + (None, false) => panic!("dlpack: invalid strides"), + }; + let ptr = managed_tensor.data_ptr() as *mut A; + + let managed_repr = ManagedRepr::::new(managed_tensor); + unsafe { + ArrayBase::from_data_ptr(managed_repr, NonNull::new_unchecked(ptr)) + .with_strides_dim(strides.into_dimension(), shape.into_dimension()) + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 07e5ed680..2927f80fc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -207,6 +207,9 @@ mod zip; mod dimension; +#[cfg(feature = "dlpack")] +mod dlpack; + pub use crate::zip::{FoldWhile, IntoNdProducer, NdProducer, Zip}; pub use crate::layout::Layout; @@ -1346,6 +1349,12 @@ pub type Array = ArrayBase, D>; /// instead of either a view or a uniquely owned copy. pub type CowArray<'a, A, D> = ArrayBase, D>; + +/// An array from managed memory +#[cfg(feature = "dlpack")] +pub type ManagedArray = ArrayBase, D>; + + /// A read-only array view. /// /// An array view represents an array or a part of it, created from @@ -1420,6 +1429,10 @@ pub type RawArrayViewMut = ArrayBase, D>; pub use data_repr::OwnedRepr; +#[cfg(feature = "dlpack")] +pub use dlpack::ManagedRepr; + + /// ArcArray's representation. /// /// *Don’t use this type directly—use the type alias diff --git a/tests/dlpack.rs b/tests/dlpack.rs new file mode 100644 index 000000000..c0ba6e307 --- /dev/null +++ b/tests/dlpack.rs @@ -0,0 +1,17 @@ +#![cfg(feature = "dlpack")] + +use dlpark::prelude::*; +use ndarray::ManagedArray; + +#[test] +fn test_dlpack() { + let arr = ndarray::arr1(&[1i32, 2, 3]); + let ptr = arr.as_ptr(); + let dlpack = arr.into_dlpack(); + let arr2 = ManagedArray::::from_dlpack(dlpack); + let ptr2 = arr2.as_ptr(); + assert_eq!(ptr, ptr2); + let arr3 = arr2.to_owned(); + let ptr3 = arr3.as_ptr(); + assert_ne!(ptr2, ptr3); +} pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy