Skip to content

Add dlpack support #1306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
add dlpack support
  • Loading branch information
SunDoge committed Jul 14, 2023
commit 512c1bfc3e6adb4a5dc1c73986d3df90e90ce3f9
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm
serde = { version = "1.0", optional = true, default-features = false, features = ["alloc"] }
rawpointer = { version = "0.2" }

dlpark = { git = "https://github.com/SunDoge/dlpark", rev = "f4d45cd" }

[dev-dependencies]
defmac = "0.2"
quickcheck = { version = "1.0", default-features = false }
Expand Down
42 changes: 41 additions & 1 deletion src/data_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use alloc::sync::Arc;
use alloc::vec::Vec;

use crate::{
ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr,
ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr, ManagedRepr
};

/// Array representation trait.
Expand Down Expand Up @@ -346,6 +346,24 @@ unsafe impl<A> RawData for OwnedRepr<A> {
private_impl! {}
}


unsafe impl<A> RawData for ManagedRepr<A> {
type Elem = A;

fn _data_slice(&self) -> Option<&[A]> {
Some(self.as_slice())
}

fn _is_pointer_inbounds(&self, self_ptr: *const Self::Elem) -> bool {
let slc = self.as_slice();
let ptr = slc.as_ptr() as *mut A;
let end = unsafe { ptr.add(slc.len()) };
self_ptr >= ptr && self_ptr <= end
}

private_impl! {}
}

unsafe impl<A> RawDataMut for OwnedRepr<A> {
#[inline]
fn try_ensure_unique<D>(_: &mut ArrayBase<Self, D>)
Expand Down Expand Up @@ -382,6 +400,28 @@ unsafe impl<A> Data for OwnedRepr<A> {
}
}


unsafe impl<A> Data for ManagedRepr<A> {
#[inline]
fn into_owned<D>(self_: ArrayBase<Self, D>) -> Array<Self::Elem, D>
where
A: Clone,
D: Dimension,
{
self_.to_owned()
}

#[inline]
fn try_into_owned_nocopy<D>(
self_: ArrayBase<Self, D>,
) -> Result<Array<Self::Elem, D>, ArrayBase<Self, D>>
where
D: Dimension,
{
Err(self_)
}
}

unsafe impl<A> DataMut for OwnedRepr<A> {}

unsafe impl<A> RawDataClone for OwnedRepr<A>
Expand Down
93 changes: 93 additions & 0 deletions src/dlpack.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use core::ptr::NonNull;
use std::marker::PhantomData;

use dlpark::prelude::*;

use crate::{ArrayBase, Dimension, IntoDimension, IxDyn, ManagedArray, RawData};

impl<A, S, D> ToTensor for ArrayBase<S, D>
where
A: InferDtype,
S: RawData<Elem = A>,
D: Dimension,
{
fn data_ptr(&self) -> *mut std::ffi::c_void {
self.as_ptr() as *mut std::ffi::c_void
}

fn byte_offset(&self) -> u64 {
0
}

fn device(&self) -> Device {
Device::CPU
}

fn dtype(&self) -> DataType {
A::infer_dtype()
}

fn shape(&self) -> CowIntArray {
dlpark::prelude::CowIntArray::from_owned(
self.shape().into_iter().map(|&x| x as i64).collect(),
)
}

fn strides(&self) -> Option<CowIntArray> {
Some(dlpark::prelude::CowIntArray::from_owned(
self.strides().into_iter().map(|&x| x as i64).collect(),
))
}
}

pub struct ManagedRepr<A> {
managed_tensor: ManagedTensor,
_ty: PhantomData<A>,
}

impl<A> ManagedRepr<A> {
pub fn new(managed_tensor: ManagedTensor) -> Self {
Self {
managed_tensor,
_ty: PhantomData,
}
}

pub fn as_slice(&self) -> &[A] {
self.managed_tensor.as_slice()
}

pub fn as_ptr(&self) -> *const A {
self.managed_tensor.data_ptr() as *const A
}
}

unsafe impl<A> Sync for ManagedRepr<A> where A: Sync {}
unsafe impl<A> Send for ManagedRepr<A> where A: Send {}

impl<A> FromDLPack for ManagedArray<A, IxDyn> {
fn from_dlpack(dlpack: NonNull<dlpark::ffi::DLManagedTensor>) -> Self {
Copy link
Member

@bluss bluss Mar 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this function takes a raw pointer (wrapped in NonNull) and it must be an unsafe function, otherwise we can trivially violate memory safety unfortunately.

The only way to remove this requirement - the requirement of using unsafe - would be if you have a "magical" function that can take an arbitrary pointer and say whether it's a valid, live, non-mutably aliased pointer to a tensor.

Here's how to create a dangling bad pointer: NonNull::new(1 as *mut u8 as *mut dlpark::ffi::DLManagedTensor) does this code crash if we run with this pointer? I think it would..

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with you. from_dlpack should be unsafe, and users should use it at their own risk.

Copy link
Member

@bluss bluss Mar 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say, we normally don't commit to public dependencies that are not stable (yes, not a very fair policy since ndarray itself is not so stable.), and dlpark is a public dependency here because it becomes part of our API. It could mean it takes a long time between version bumps.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we don't need to include dlpark as a dependency. We can create an ArrayView using ArrayView::from_shape_ptr and ManagedTensor. I can implement ToTensor for ArrayD in dlpark with a new feature ndarray. I'll do some quick experiments.

let managed_tensor = ManagedTensor::new(dlpack);
let shape: Vec<usize> = managed_tensor
.shape()
.into_iter()
.map(|x| *x as _)
.collect();

let strides: Vec<usize> = match (managed_tensor.strides(), managed_tensor.is_contiguous()) {
(Some(s), _) => s.into_iter().map(|&x| x as _).collect(),
(None, true) => dlpark::tensor::calculate_contiguous_strides(managed_tensor.shape())
.into_iter()
.map(|x| x as _)
.collect(),
(None, false) => panic!("fail"),
};
let ptr = managed_tensor.data_ptr() as *mut A;

let managed_repr = ManagedRepr::<A>::new(managed_tensor);
unsafe {
ArrayBase::from_data_ptr(managed_repr, NonNull::new_unchecked(ptr))
.with_strides_dim(strides.into_dimension(), shape.into_dimension())
}
}
}
9 changes: 9 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ mod zip;

mod dimension;

mod dlpack;

pub use crate::zip::{FoldWhile, IntoNdProducer, NdProducer, Zip};

pub use crate::layout::Layout;
Expand Down Expand Up @@ -1346,6 +1348,11 @@ pub type Array<A, D> = ArrayBase<OwnedRepr<A>, D>;
/// instead of either a view or a uniquely owned copy.
pub type CowArray<'a, A, D> = ArrayBase<CowRepr<'a, A>, D>;


/// An array from managed memory
pub type ManagedArray<A, D> = ArrayBase<ManagedRepr<A>, D>;


/// A read-only array view.
///
/// An array view represents an array or a part of it, created from
Expand Down Expand Up @@ -1419,6 +1426,8 @@ pub type RawArrayView<A, D> = ArrayBase<RawViewRepr<*const A>, D>;
pub type RawArrayViewMut<A, D> = ArrayBase<RawViewRepr<*mut A>, D>;

pub use data_repr::OwnedRepr;
pub use dlpack::ManagedRepr;


/// ArcArray's representation.
///
Expand Down
17 changes: 17 additions & 0 deletions tests/dlpack.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
use dlpark::prelude::*;
use ndarray::ManagedArray;

#[test]
fn test_dlpack() {
let arr = ndarray::arr1(&[1i32, 2, 3]);
let ptr = arr.as_ptr();
let dlpack = arr.to_dlpack();
let arr2 = ManagedArray::<i32, _>::from_dlpack(dlpack);
let ptr2 = arr2.as_ptr();
assert_eq!(ptr, ptr2);
// dbg!(&arr2);
let arr3 = arr2.to_owned();
// dbg!(&arr3);
let ptr3 = arr3.as_ptr();
assert_ne!(ptr2, ptr3);
}
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy