diff --git a/.gitignore b/.gitignore index e9b5ca25b..44b22a5e8 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,6 @@ target/ # Editor settings .vscode .idea + +# Apple details +**/.DS_Store diff --git a/crates/blas-tests/tests/oper.rs b/crates/blas-tests/tests/oper.rs index a9dca7e83..f604ae091 100644 --- a/crates/blas-tests/tests/oper.rs +++ b/crates/blas-tests/tests/oper.rs @@ -280,7 +280,7 @@ fn gen_mat_mul() cv = c.view_mut(); } - let answer_part = alpha * reference_mat_mul(&av, &bv) + beta * &cv; + let answer_part: Array = alpha * reference_mat_mul(&av, &bv) + beta * &cv; answer.slice_mut(s![..;s1, ..;s2]).assign(&answer_part); general_mat_mul(alpha, &av, &bv, beta, &mut cv); diff --git a/examples/axis_ops.rs b/examples/axis_ops.rs index 3a54a52fb..7f80a637f 100644 --- a/examples/axis_ops.rs +++ b/examples/axis_ops.rs @@ -13,7 +13,7 @@ use ndarray::prelude::*; /// it corresponds to their order in memory. /// /// Errors if array has a 0-stride axis -fn regularize(a: &mut Array) -> Result<(), &'static str> +fn regularize(a: &mut ArrayRef) -> Result<(), &'static str> where D: Dimension, A: ::std::fmt::Debug, diff --git a/examples/convo.rs b/examples/convo.rs index a59795e12..79e8ab6b6 100644 --- a/examples/convo.rs +++ b/examples/convo.rs @@ -14,7 +14,7 @@ type Kernel3x3 = [[A; 3]; 3]; #[inline(never)] #[cfg(feature = "std")] -fn conv_3x3(a: &ArrayView2<'_, F>, out: &mut ArrayViewMut2<'_, F>, kernel: &Kernel3x3) +fn conv_3x3(a: &ArrayRef2, out: &mut ArrayRef2, kernel: &Kernel3x3) where F: Float { let (n, m) = a.dim(); diff --git a/examples/functions_and_traits.rs b/examples/functions_and_traits.rs new file mode 100644 index 000000000..dc8f73da4 --- /dev/null +++ b/examples/functions_and_traits.rs @@ -0,0 +1,178 @@ +//! Examples of how to write functions and traits that operate on `ndarray` types. +//! +//! `ndarray` has four kinds of array types that users may interact with: +//! 1. [`ArrayBase`], the owner of the layout that describes an array in memory; +//! this includes [`ndarray::Array`], [`ndarray::ArcArray`], [`ndarray::ArrayView`], +//! [`ndarray::RawArrayView`], and other variants. +//! 2. [`ArrayRef`], which represents a read-safe, uniquely-owned look at an array. +//! 3. [`RawRef`], which represents a read-unsafe, possibly-shared look at an array. +//! 4. [`LayoutRef`], which represents a look at an array's underlying structure, +//! but does not allow data reading of any kind. +//! +//! Below, we illustrate how to write functions and traits for most variants of these types. + +use ndarray::{ArrayBase, ArrayRef, Data, DataMut, Dimension, LayoutRef, RawData, RawDataMut, RawRef}; + +/// Take an array with the most basic requirements. +/// +/// This function takes its data as owning. It is very rare that a user will need to specifically +/// take a reference to an `ArrayBase`, rather than to one of the other four types. +#[rustfmt::skip] +fn takes_base_raw(arr: ArrayBase) -> ArrayBase +{ + // These skip from a possibly-raw array to `RawRef` and `LayoutRef`, and so must go through `AsRef` + takes_rawref(arr.as_ref()); // Caller uses `.as_ref` + takes_rawref_asref(&arr); // Implementor uses `.as_ref` + takes_layout(arr.as_ref()); // Caller uses `.as_ref` + takes_layout_asref(&arr); // Implementor uses `.as_ref` + + arr +} + +/// Similar to above, but allow us to read the underlying data. +#[rustfmt::skip] +fn takes_base_raw_mut(mut arr: ArrayBase) -> ArrayBase +{ + // These skip from a possibly-raw array to `RawRef` and `LayoutRef`, and so must go through `AsMut` + takes_rawref_mut(arr.as_mut()); // Caller uses `.as_mut` + takes_rawref_asmut(&mut arr); // Implementor uses `.as_mut` + takes_layout_mut(arr.as_mut()); // Caller uses `.as_mut` + takes_layout_asmut(&mut arr); // Implementor uses `.as_mut` + + arr +} + +/// Now take an array whose data is safe to read. +#[allow(dead_code)] +fn takes_base(mut arr: ArrayBase) -> ArrayBase +{ + // Raw call + arr = takes_base_raw(arr); + + // No need for AsRef, since data is safe + takes_arrref(&arr); + takes_rawref(&arr); + takes_rawref_asref(&arr); + takes_layout(&arr); + takes_layout_asref(&arr); + + arr +} + +/// Now, an array whose data is safe to read and that we can mutate. +/// +/// Notice that we include now a trait bound on `D: Dimension`; this is necessary in order +/// for the `ArrayBase` to dereference to an `ArrayRef` (or to any of the other types). +#[allow(dead_code)] +fn takes_base_mut(mut arr: ArrayBase) -> ArrayBase +{ + // Raw call + arr = takes_base_raw_mut(arr); + + // No need for AsMut, since data is safe + takes_arrref_mut(&mut arr); + takes_rawref_mut(&mut arr); + takes_rawref_asmut(&mut arr); + takes_layout_mut(&mut arr); + takes_layout_asmut(&mut arr); + + arr +} + +/// Now for new stuff: we want to read (but not alter) any array whose data is safe to read. +/// +/// This is probably the most common functionality that one would want to write. +/// As we'll see below, calling this function is very simple for `ArrayBase`. +fn takes_arrref(arr: &ArrayRef) +{ + // No need for AsRef, since data is safe + takes_rawref(arr); + takes_rawref_asref(arr); + takes_layout(arr); + takes_layout_asref(arr); +} + +/// Now we want any array whose data is safe to mutate. +/// +/// **Importantly**, any array passed to this function is guaranteed to uniquely point to its data. +/// As a result, passing a shared array to this function will **silently** un-share the array. +#[allow(dead_code)] +fn takes_arrref_mut(arr: &mut ArrayRef) +{ + // Immutable call + takes_arrref(arr); + + // No need for AsMut, since data is safe + takes_rawref_mut(arr); + takes_rawref_asmut(arr); + takes_layout_mut(arr); + takes_rawref_asmut(arr); +} + +/// Now, we no longer care about whether we can safely read data. +/// +/// This is probably the rarest type to deal with, since `LayoutRef` can access and modify an array's +/// shape and strides, and even do in-place slicing. As a result, `RawRef` is only for functionality +/// that requires unsafe data access, something that `LayoutRef` can't do. +/// +/// Writing functions and traits that deal with `RawRef`s and `LayoutRef`s can be done two ways: +/// 1. Directly on the types; calling these functions on arrays whose data are not known to be safe +/// to dereference (i.e., raw array views or `ArrayBase`) must explicitly call `.as_ref()`. +/// 2. Via a generic with `: AsRef>`; doing this will allow direct calling for all `ArrayBase` and +/// `ArrayRef` instances. +/// We'll demonstrate #1 here for both immutable and mutable references, then #2 directly below. +#[allow(dead_code)] +fn takes_rawref(arr: &RawRef) +{ + takes_layout(arr); + takes_layout_asref(arr); +} + +/// Mutable, directly take `RawRef` +#[allow(dead_code)] +fn takes_rawref_mut(arr: &mut RawRef) +{ + takes_layout(arr); + takes_layout_asmut(arr); +} + +/// Immutable, take a generic that implements `AsRef` to `RawRef` +#[allow(dead_code)] +fn takes_rawref_asref(_arr: &T) +where T: AsRef> +{ + takes_layout(_arr.as_ref()); + takes_layout_asref(_arr.as_ref()); +} + +/// Mutable, take a generic that implements `AsMut` to `RawRef` +#[allow(dead_code)] +fn takes_rawref_asmut(_arr: &mut T) +where T: AsMut> +{ + takes_layout_mut(_arr.as_mut()); + takes_layout_asmut(_arr.as_mut()); +} + +/// Finally, there's `LayoutRef`: this type provides read and write access to an array's *structure*, but not its *data*. +/// +/// Practically, this means that functions that only read/modify an array's shape or strides, +/// such as checking dimensionality or slicing, should take `LayoutRef`. +/// +/// Like `RawRef`, functions can be written either directly on `LayoutRef` or as generics with `: AsRef>>`. +#[allow(dead_code)] +fn takes_layout(_arr: &LayoutRef) {} + +/// Mutable, directly take `LayoutRef` +#[allow(dead_code)] +fn takes_layout_mut(_arr: &mut LayoutRef) {} + +/// Immutable, take a generic that implements `AsRef` to `LayoutRef` +#[allow(dead_code)] +fn takes_layout_asref>, A, D>(_arr: &T) {} + +/// Mutable, take a generic that implements `AsMut` to `LayoutRef` +#[allow(dead_code)] +fn takes_layout_asmut>, A, D>(_arr: &mut T) {} + +fn main() {} diff --git a/src/alias_asref.rs b/src/alias_asref.rs new file mode 100644 index 000000000..ab78af605 --- /dev/null +++ b/src/alias_asref.rs @@ -0,0 +1,359 @@ +use crate::{ + iter::Axes, + ArrayBase, + Axis, + AxisDescription, + Dimension, + NdIndex, + RawArrayView, + RawData, + RawDataMut, + Slice, + SliceArg, +}; + +/// Functions coming from RawRef +impl, D: Dimension> ArrayBase +{ + /// Return a raw pointer to the element at `index`, or return `None` + /// if the index is out of bounds. + /// + /// ``` + /// use ndarray::arr2; + /// + /// let a = arr2(&[[1., 2.], [3., 4.]]); + /// + /// let v = a.raw_view(); + /// let p = a.get_ptr((0, 1)).unwrap(); + /// + /// assert_eq!(unsafe { *p }, 2.); + /// ``` + pub fn get_ptr(&self, index: I) -> Option<*const A> + where I: NdIndex + { + self.as_raw_ref().get_ptr(index) + } + + /// Return a raw pointer to the element at `index`, or return `None` + /// if the index is out of bounds. + /// + /// ``` + /// use ndarray::arr2; + /// + /// let mut a = arr2(&[[1., 2.], [3., 4.]]); + /// + /// let v = a.raw_view_mut(); + /// let p = a.get_mut_ptr((0, 1)).unwrap(); + /// + /// unsafe { + /// *p = 5.; + /// } + /// + /// assert_eq!(a.get((0, 1)), Some(&5.)); + /// ``` + pub fn get_mut_ptr(&mut self, index: I) -> Option<*mut A> + where + S: RawDataMut, + I: NdIndex, + { + self.as_raw_ref_mut().get_mut_ptr(index) + } + + /// Return a pointer to the first element in the array. + /// + /// Raw access to array elements needs to follow the strided indexing + /// scheme: an element at multi-index *I* in an array with strides *S* is + /// located at offset + /// + /// *Σ0 ≤ k < d Ik × Sk* + /// + /// where *d* is `self.ndim()`. + #[inline(always)] + pub fn as_ptr(&self) -> *const A + { + self.as_raw_ref().as_ptr() + } + + /// Return a raw view of the array. + #[inline] + pub fn raw_view(&self) -> RawArrayView + { + self.as_raw_ref().raw_view() + } +} + +/// Functions coming from LayoutRef +impl ArrayBase +{ + /// Slice the array in place without changing the number of dimensions. + /// + /// In particular, if an axis is sliced with an index, the axis is + /// collapsed, as in [`.collapse_axis()`], rather than removed, as in + /// [`.slice_move()`] or [`.index_axis_move()`]. + /// + /// [`.collapse_axis()`]: Self::collapse_axis + /// [`.slice_move()`]: Self::slice_move + /// [`.index_axis_move()`]: Self::index_axis_move + /// + /// See [*Slicing*](#slicing) for full documentation. + /// See also [`s!`], [`SliceArg`], and [`SliceInfo`](crate::SliceInfo). + /// + /// **Panics** in the following cases: + /// + /// - if an index is out of bounds + /// - if a step size is zero + /// - if [`NewAxis`](`crate::SliceInfoElem::NewAxis`) is in `info`, e.g. if `NewAxis` was + /// used in the [`s!`] macro + /// - if `D` is `IxDyn` and `info` does not match the number of array axes + #[track_caller] + pub fn slice_collapse(&mut self, info: I) + where I: SliceArg + { + self.as_layout_ref_mut().slice_collapse(info); + } + + /// Slice the array in place along the specified axis. + /// + /// **Panics** if an index is out of bounds or step size is zero.
+ /// **Panics** if `axis` is out of bounds. + #[track_caller] + pub fn slice_axis_inplace(&mut self, axis: Axis, indices: Slice) + { + self.as_layout_ref_mut().slice_axis_inplace(axis, indices); + } + + /// Slice the array in place, with a closure specifying the slice for each + /// axis. + /// + /// This is especially useful for code which is generic over the + /// dimensionality of the array. + /// + /// **Panics** if an index is out of bounds or step size is zero. + #[track_caller] + pub fn slice_each_axis_inplace(&mut self, f: F) + where F: FnMut(AxisDescription) -> Slice + { + self.as_layout_ref_mut().slice_each_axis_inplace(f); + } + + /// Selects `index` along the axis, collapsing the axis into length one. + /// + /// **Panics** if `axis` or `index` is out of bounds. + #[track_caller] + pub fn collapse_axis(&mut self, axis: Axis, index: usize) + { + self.as_layout_ref_mut().collapse_axis(axis, index); + } + + /// Return `true` if the array data is laid out in contiguous “C order” in + /// memory (where the last index is the most rapidly varying). + /// + /// Return `false` otherwise, i.e. the array is possibly not + /// contiguous in memory, it has custom strides, etc. + pub fn is_standard_layout(&self) -> bool + { + self.as_layout_ref().is_standard_layout() + } + + /// Return true if the array is known to be contiguous. + pub(crate) fn is_contiguous(&self) -> bool + { + self.as_layout_ref().is_contiguous() + } + + /// Return an iterator over the length and stride of each axis. + pub fn axes(&self) -> Axes<'_, D> + { + self.as_layout_ref().axes() + } + + /* + /// Return the axis with the least stride (by absolute value) + pub fn min_stride_axis(&self) -> Axis { + self.dim.min_stride_axis(&self.strides) + } + */ + + /// Return the axis with the greatest stride (by absolute value), + /// preferring axes with len > 1. + pub fn max_stride_axis(&self) -> Axis + { + self.as_layout_ref().max_stride_axis() + } + + /// Reverse the stride of `axis`. + /// + /// ***Panics*** if the axis is out of bounds. + #[track_caller] + pub fn invert_axis(&mut self, axis: Axis) + { + self.as_layout_ref_mut().invert_axis(axis); + } + + /// Swap axes `ax` and `bx`. + /// + /// This does not move any data, it just adjusts the array’s dimensions + /// and strides. + /// + /// **Panics** if the axes are out of bounds. + /// + /// ``` + /// use ndarray::arr2; + /// + /// let mut a = arr2(&[[1., 2., 3.]]); + /// a.swap_axes(0, 1); + /// assert!( + /// a == arr2(&[[1.], [2.], [3.]]) + /// ); + /// ``` + #[track_caller] + pub fn swap_axes(&mut self, ax: usize, bx: usize) + { + self.as_layout_ref_mut().swap_axes(ax, bx); + } + + /// If possible, merge in the axis `take` to `into`. + /// + /// Returns `true` iff the axes are now merged. + /// + /// This method merges the axes if movement along the two original axes + /// (moving fastest along the `into` axis) can be equivalently represented + /// as movement along one (merged) axis. Merging the axes preserves this + /// order in the merged axis. If `take` and `into` are the same axis, then + /// the axis is "merged" if its length is ≤ 1. + /// + /// If the return value is `true`, then the following hold: + /// + /// * The new length of the `into` axis is the product of the original + /// lengths of the two axes. + /// + /// * The new length of the `take` axis is 0 if the product of the original + /// lengths of the two axes is 0, and 1 otherwise. + /// + /// If the return value is `false`, then merging is not possible, and the + /// original shape and strides have been preserved. + /// + /// Note that the ordering constraint means that if it's possible to merge + /// `take` into `into`, it's usually not possible to merge `into` into + /// `take`, and vice versa. + /// + /// ``` + /// use ndarray::Array3; + /// use ndarray::Axis; + /// + /// let mut a = Array3::::zeros((2, 3, 4)); + /// assert!(a.merge_axes(Axis(1), Axis(2))); + /// assert_eq!(a.shape(), &[2, 1, 12]); + /// ``` + /// + /// ***Panics*** if an axis is out of bounds. + #[track_caller] + pub fn merge_axes(&mut self, take: Axis, into: Axis) -> bool + { + self.as_layout_ref_mut().merge_axes(take, into) + } + + /// Return the total number of elements in the array. + pub fn len(&self) -> usize + { + self.as_layout_ref().len() + } + + /// Return the length of `axis`. + /// + /// The axis should be in the range `Axis(` 0 .. *n* `)` where *n* is the + /// number of dimensions (axes) of the array. + /// + /// ***Panics*** if the axis is out of bounds. + #[track_caller] + pub fn len_of(&self, axis: Axis) -> usize + { + self.as_layout_ref().len_of(axis) + } + + /// Return whether the array has any elements + pub fn is_empty(&self) -> bool + { + self.as_layout_ref().is_empty() + } + + /// Return the number of dimensions (axes) in the array + pub fn ndim(&self) -> usize + { + self.as_layout_ref().ndim() + } + + /// Return the shape of the array in its “pattern” form, + /// an integer in the one-dimensional case, tuple in the n-dimensional cases + /// and so on. + pub fn dim(&self) -> D::Pattern + { + self.as_layout_ref().dim() + } + + /// Return the shape of the array as it's stored in the array. + /// + /// This is primarily useful for passing to other `ArrayBase` + /// functions, such as when creating another array of the same + /// shape and dimensionality. + /// + /// ``` + /// use ndarray::Array; + /// + /// let a = Array::from_elem((2, 3), 5.); + /// + /// // Create an array of zeros that's the same shape and dimensionality as `a`. + /// let b = Array::::zeros(a.raw_dim()); + /// ``` + pub fn raw_dim(&self) -> D + { + self.as_layout_ref().raw_dim() + } + + /// Return the shape of the array as a slice. + /// + /// Note that you probably don't want to use this to create an array of the + /// same shape as another array because creating an array with e.g. + /// [`Array::zeros()`](ArrayBase::zeros) using a shape of type `&[usize]` + /// results in a dynamic-dimensional array. If you want to create an array + /// that has the same shape and dimensionality as another array, use + /// [`.raw_dim()`](ArrayBase::raw_dim) instead: + /// + /// ```rust + /// use ndarray::{Array, Array2}; + /// + /// let a = Array2::::zeros((3, 4)); + /// let shape = a.shape(); + /// assert_eq!(shape, &[3, 4]); + /// + /// // Since `a.shape()` returned `&[usize]`, we get an `ArrayD` instance: + /// let b = Array::zeros(shape); + /// assert_eq!(a.clone().into_dyn(), b); + /// + /// // To get the same dimension type, use `.raw_dim()` instead: + /// let c = Array::zeros(a.raw_dim()); + /// assert_eq!(a, c); + /// ``` + pub fn shape(&self) -> &[usize] + { + self.as_layout_ref().shape() + } + + /// Return the strides of the array as a slice. + pub fn strides(&self) -> &[isize] + { + self.as_layout_ref().strides() + } + + /// Return the stride of `axis`. + /// + /// The axis should be in the range `Axis(` 0 .. *n* `)` where *n* is the + /// number of dimensions (axes) of the array. + /// + /// ***Panics*** if the axis is out of bounds. + #[track_caller] + pub fn stride_of(&self, axis: Axis) -> isize + { + self.as_layout_ref().stride_of(axis) + } +} diff --git a/src/aliases.rs b/src/aliases.rs index 5df0c95ec..7f897304b 100644 --- a/src/aliases.rs +++ b/src/aliases.rs @@ -2,7 +2,7 @@ //! use crate::dimension::Dim; -use crate::{ArcArray, Array, ArrayView, ArrayViewMut, Ix, IxDynImpl}; +use crate::{ArcArray, Array, ArrayRef, ArrayView, ArrayViewMut, Ix, IxDynImpl, LayoutRef}; /// Create a zero-dimensional index #[allow(non_snake_case)] @@ -123,6 +123,40 @@ pub type Array6
= Array; /// dynamic-dimensional array pub type ArrayD = Array; +/// zero-dimensional array reference +pub type ArrayRef0 = ArrayRef; +/// one-dimensional array reference +pub type ArrayRef1 = ArrayRef; +/// two-dimensional array reference +pub type ArrayRef2 = ArrayRef; +/// three-dimensional array reference +pub type ArrayRef3 = ArrayRef; +/// four-dimensional array reference +pub type ArrayRef4 = ArrayRef; +/// five-dimensional array reference +pub type ArrayRef5 = ArrayRef; +/// six-dimensional array reference +pub type ArrayRef6 = ArrayRef; +/// dynamic-dimensional array reference +pub type ArrayRefD = ArrayRef; + +/// zero-dimensional layout reference +pub type LayoutRef0 = LayoutRef; +/// one-dimensional layout reference +pub type LayoutRef1 = LayoutRef; +/// two-dimensional layout reference +pub type LayoutRef2 = LayoutRef; +/// three-dimensional layout reference +pub type LayoutRef3 = LayoutRef; +/// four-dimensional layout reference +pub type LayoutRef4 = LayoutRef; +/// five-dimensional layout reference +pub type LayoutRef5 = LayoutRef; +/// six-dimensional layout reference +pub type LayoutRef6 = LayoutRef; +/// dynamic-dimensional layout reference +pub type LayoutRefD = LayoutRef; + /// zero-dimensional array view pub type ArrayView0<'a, A> = ArrayView<'a, A, Ix0>; /// one-dimensional array view diff --git a/src/array_approx.rs b/src/array_approx.rs index 493864c7e..e350c3883 100644 --- a/src/array_approx.rs +++ b/src/array_approx.rs @@ -3,20 +3,16 @@ mod approx_methods { use crate::imp_prelude::*; - impl ArrayBase - where - S: Data, - D: Dimension, + impl ArrayRef { /// A test for equality that uses the elementwise absolute difference to compute the /// approximate equality of two arrays. /// /// **Requires crate feature `"approx"`** - pub fn abs_diff_eq(&self, other: &ArrayBase, epsilon: A::Epsilon) -> bool + pub fn abs_diff_eq(&self, other: &ArrayRef, epsilon: A::Epsilon) -> bool where - A: ::approx::AbsDiffEq, + A: ::approx::AbsDiffEq, A::Epsilon: Clone, - S2: Data, { >::abs_diff_eq(self, other, epsilon) } @@ -25,11 +21,10 @@ mod approx_methods /// apart; and the absolute difference otherwise. /// /// **Requires crate feature `"approx"`** - pub fn relative_eq(&self, other: &ArrayBase, epsilon: A::Epsilon, max_relative: A::Epsilon) -> bool + pub fn relative_eq(&self, other: &ArrayRef, epsilon: A::Epsilon, max_relative: A::Epsilon) -> bool where - A: ::approx::RelativeEq, + A: ::approx::RelativeEq, A::Epsilon: Clone, - S2: Data, { >::relative_eq(self, other, epsilon, max_relative) } @@ -44,12 +39,10 @@ macro_rules! impl_approx_traits { use $approx::{AbsDiffEq, RelativeEq, UlpsEq}; #[doc = $doc] - impl AbsDiffEq> for ArrayBase + impl AbsDiffEq> for ArrayRef where A: AbsDiffEq, A::Epsilon: Clone, - S: Data, - S2: Data, D: Dimension, { type Epsilon = A::Epsilon; @@ -58,7 +51,7 @@ macro_rules! impl_approx_traits { A::default_epsilon() } - fn abs_diff_eq(&self, other: &ArrayBase, epsilon: A::Epsilon) -> bool { + fn abs_diff_eq(&self, other: &ArrayRef, epsilon: A::Epsilon) -> bool { if self.shape() != other.shape() { return false; } @@ -70,13 +63,31 @@ macro_rules! impl_approx_traits { } #[doc = $doc] - impl RelativeEq> for ArrayBase + impl AbsDiffEq> for ArrayBase where - A: RelativeEq, + A: AbsDiffEq, A::Epsilon: Clone, S: Data, S2: Data, D: Dimension, + { + type Epsilon = A::Epsilon; + + fn default_epsilon() -> A::Epsilon { + A::default_epsilon() + } + + fn abs_diff_eq(&self, other: &ArrayBase, epsilon: A::Epsilon) -> bool { + (**self).abs_diff_eq(other, epsilon) + } + } + + #[doc = $doc] + impl RelativeEq> for ArrayRef + where + A: RelativeEq, + A::Epsilon: Clone, + D: Dimension, { fn default_max_relative() -> A::Epsilon { A::default_max_relative() @@ -84,7 +95,7 @@ macro_rules! impl_approx_traits { fn relative_eq( &self, - other: &ArrayBase, + other: &ArrayRef, epsilon: A::Epsilon, max_relative: A::Epsilon, ) -> bool { @@ -99,13 +110,34 @@ macro_rules! impl_approx_traits { } #[doc = $doc] - impl UlpsEq> for ArrayBase + impl RelativeEq> for ArrayBase where - A: UlpsEq, + A: RelativeEq, A::Epsilon: Clone, S: Data, S2: Data, D: Dimension, + { + fn default_max_relative() -> A::Epsilon { + A::default_max_relative() + } + + fn relative_eq( + &self, + other: &ArrayBase, + epsilon: A::Epsilon, + max_relative: A::Epsilon, + ) -> bool { + (**self).relative_eq(other, epsilon, max_relative) + } + } + + #[doc = $doc] + impl UlpsEq> for ArrayRef + where + A: UlpsEq, + A::Epsilon: Clone, + D: Dimension, { fn default_max_ulps() -> u32 { A::default_max_ulps() @@ -113,7 +145,7 @@ macro_rules! impl_approx_traits { fn ulps_eq( &self, - other: &ArrayBase, + other: &ArrayRef, epsilon: A::Epsilon, max_ulps: u32, ) -> bool { @@ -127,6 +159,29 @@ macro_rules! impl_approx_traits { } } + #[doc = $doc] + impl UlpsEq> for ArrayBase + where + A: UlpsEq, + A::Epsilon: Clone, + S: Data, + S2: Data, + D: Dimension, + { + fn default_max_ulps() -> u32 { + A::default_max_ulps() + } + + fn ulps_eq( + &self, + other: &ArrayBase, + epsilon: A::Epsilon, + max_ulps: u32, + ) -> bool { + (**self).ulps_eq(other, epsilon, max_ulps) + } + } + #[cfg(test)] mod tests { use crate::prelude::*; diff --git a/src/arrayformat.rs b/src/arrayformat.rs index 1a3b714c3..7e5e1b1c9 100644 --- a/src/arrayformat.rs +++ b/src/arrayformat.rs @@ -6,7 +6,10 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. use super::{ArrayBase, ArrayView, Axis, Data, Dimension, NdProducer}; -use crate::aliases::{Ix1, IxDyn}; +use crate::{ + aliases::{Ix1, IxDyn}, + ArrayRef, +}; use alloc::format; use std::fmt; @@ -112,13 +115,12 @@ fn format_with_overflow( Ok(()) } -fn format_array( - array: &ArrayBase, f: &mut fmt::Formatter<'_>, format: F, fmt_opt: &FormatOptions, +fn format_array( + array: &ArrayRef, f: &mut fmt::Formatter<'_>, format: F, fmt_opt: &FormatOptions, ) -> fmt::Result where F: FnMut(&A, &mut fmt::Formatter<'_>) -> fmt::Result + Clone, D: Dimension, - S: Data, { // Cast into a dynamically dimensioned view // This is required to be able to use `index_axis` for the recursive case @@ -174,6 +176,18 @@ where /// The array is shown in multiline style. impl fmt::Display for ArrayBase where S: Data +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + (**self).fmt(f) + } +} + +/// Format the array reference using `Display` and apply the formatting parameters +/// used to each element. +/// +/// The array is shown in multiline style. +impl fmt::Display for ArrayRef { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -188,6 +202,18 @@ where S: Data /// The array is shown in multiline style. impl fmt::Debug for ArrayBase where S: Data +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + (**self).fmt(f) + } +} + +/// Format the array reference using `Debug` and apply the formatting parameters used +/// to each element. +/// +/// The array is shown in multiline style. +impl fmt::Debug for ArrayRef { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -216,6 +242,18 @@ where S: Data /// The array is shown in multiline style. impl fmt::LowerExp for ArrayBase where S: Data +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + (**self).fmt(f) + } +} + +/// Format the array reference using `LowerExp` and apply the formatting parameters used +/// to each element. +/// +/// The array is shown in multiline style. +impl fmt::LowerExp for ArrayRef { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -230,6 +268,18 @@ where S: Data /// The array is shown in multiline style. impl fmt::UpperExp for ArrayBase where S: Data +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + (**self).fmt(f) + } +} + +/// Format the array using `UpperExp` and apply the formatting parameters used +/// to each element. +/// +/// The array is shown in multiline style. +impl fmt::UpperExp for ArrayRef { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -237,12 +287,25 @@ where S: Data format_array(self, f, <_>::fmt, &fmt_opt) } } + /// Format the array using `LowerHex` and apply the formatting parameters used /// to each element. /// /// The array is shown in multiline style. impl fmt::LowerHex for ArrayBase where S: Data +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + (**self).fmt(f) + } +} + +/// Format the array using `LowerHex` and apply the formatting parameters used +/// to each element. +/// +/// The array is shown in multiline style. +impl fmt::LowerHex for ArrayRef { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -257,6 +320,18 @@ where S: Data /// The array is shown in multiline style. impl fmt::Binary for ArrayBase where S: Data +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + (**self).fmt(f) + } +} + +/// Format the array using `Binary` and apply the formatting parameters used +/// to each element. +/// +/// The array is shown in multiline style. +impl fmt::Binary for ArrayRef { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { diff --git a/src/arraytraits.rs b/src/arraytraits.rs index e68b5d56a..1a843af56 100644 --- a/src/arraytraits.rs +++ b/src/arraytraits.rs @@ -19,6 +19,7 @@ use std::{iter::FromIterator, slice}; use crate::imp_prelude::*; use crate::Arc; +use crate::LayoutRef; use crate::{ dimension, iter::{Iter, IterMut}, @@ -37,11 +38,10 @@ pub(crate) fn array_out_of_bounds() -> ! } #[inline(always)] -pub fn debug_bounds_check(_a: &ArrayBase, _index: &I) +pub fn debug_bounds_check(_a: &LayoutRef, _index: &I) where D: Dimension, I: NdIndex, - S: Data, { debug_bounds_check!(_a, *_index); } @@ -49,15 +49,15 @@ where /// Access the element at **index**. /// /// **Panics** if index is out of bounds. -impl Index for ArrayBase +impl Index for ArrayRef where D: Dimension, I: NdIndex, - S: Data, { - type Output = S::Elem; + type Output = A; + #[inline] - fn index(&self, index: I) -> &S::Elem + fn index(&self, index: I) -> &Self::Output { debug_bounds_check!(self, index); unsafe { @@ -73,14 +73,13 @@ where /// Access the element at **index** mutably. /// /// **Panics** if index is out of bounds. -impl IndexMut for ArrayBase +impl IndexMut for ArrayRef where D: Dimension, I: NdIndex, - S: DataMut, { #[inline] - fn index_mut(&mut self, index: I) -> &mut S::Elem + fn index_mut(&mut self, index: I) -> &mut A { debug_bounds_check!(self, index); unsafe { @@ -93,16 +92,48 @@ where } } +/// Access the element at **index**. +/// +/// **Panics** if index is out of bounds. +impl Index for ArrayBase +where + D: Dimension, + I: NdIndex, + S: Data, +{ + type Output = S::Elem; + + #[inline] + fn index(&self, index: I) -> &S::Elem + { + Index::index(&**self, index) + } +} + +/// Access the element at **index** mutably. +/// +/// **Panics** if index is out of bounds. +impl IndexMut for ArrayBase +where + D: Dimension, + I: NdIndex, + S: DataMut, +{ + #[inline] + fn index_mut(&mut self, index: I) -> &mut S::Elem + { + IndexMut::index_mut(&mut (**self), index) + } +} + /// Return `true` if the array shapes and all elements of `self` and /// `rhs` are equal. Return `false` otherwise. -impl PartialEq> for ArrayBase +impl PartialEq> for ArrayRef where A: PartialEq, - S: Data, - S2: Data, D: Dimension, { - fn eq(&self, rhs: &ArrayBase) -> bool + fn eq(&self, rhs: &ArrayRef) -> bool { if self.shape() != rhs.shape() { return false; @@ -125,6 +156,56 @@ where } } +/// Return `true` if the array shapes and all elements of `self` and +/// `rhs` are equal. Return `false` otherwise. +#[allow(clippy::unconditional_recursion)] // false positive +impl<'a, A, B, D> PartialEq<&'a ArrayRef> for ArrayRef +where + A: PartialEq, + D: Dimension, +{ + fn eq(&self, rhs: &&ArrayRef) -> bool + { + *self == **rhs + } +} + +/// Return `true` if the array shapes and all elements of `self` and +/// `rhs` are equal. Return `false` otherwise. +#[allow(clippy::unconditional_recursion)] // false positive +impl<'a, A, B, D> PartialEq> for &'a ArrayRef +where + A: PartialEq, + D: Dimension, +{ + fn eq(&self, rhs: &ArrayRef) -> bool + { + **self == *rhs + } +} + +impl Eq for ArrayRef +where + D: Dimension, + A: Eq, +{ +} + +/// Return `true` if the array shapes and all elements of `self` and +/// `rhs` are equal. Return `false` otherwise. +impl PartialEq> for ArrayBase +where + A: PartialEq, + S: Data, + S2: Data, + D: Dimension, +{ + fn eq(&self, rhs: &ArrayBase) -> bool + { + PartialEq::eq(&**self, &**rhs) + } +} + /// Return `true` if the array shapes and all elements of `self` and /// `rhs` are equal. Return `false` otherwise. #[allow(clippy::unconditional_recursion)] // false positive @@ -216,6 +297,32 @@ where S: DataOwned } } +impl<'a, A, D> IntoIterator for &'a ArrayRef +where D: Dimension +{ + type Item = &'a A; + + type IntoIter = Iter<'a, A, D>; + + fn into_iter(self) -> Self::IntoIter + { + self.iter() + } +} + +impl<'a, A, D> IntoIterator for &'a mut ArrayRef +where D: Dimension +{ + type Item = &'a mut A; + + type IntoIter = IterMut<'a, A, D>; + + fn into_iter(self) -> Self::IntoIter + { + self.iter_mut() + } +} + impl<'a, S, D> IntoIterator for &'a ArrayBase where D: Dimension, @@ -268,11 +375,10 @@ where D: Dimension } } -impl hash::Hash for ArrayBase +impl hash::Hash for ArrayRef where D: Dimension, - S: Data, - S::Elem: hash::Hash, + A: hash::Hash, { // Note: elements are hashed in the logical order fn hash(&self, state: &mut H) @@ -294,6 +400,19 @@ where } } +impl hash::Hash for ArrayBase +where + D: Dimension, + S: Data, + S::Elem: hash::Hash, +{ + // Note: elements are hashed in the logical order + fn hash(&self, state: &mut H) + { + (**self).hash(state) + } +} + // NOTE: ArrayBase keeps an internal raw pointer that always // points into the storage. This is Sync & Send as long as we // follow the usual inherited mutability rules, as we do with @@ -463,7 +582,7 @@ where D: Dimension { let data = OwnedArcRepr(Arc::new(arr.data)); // safe because: equivalent unmoved data, ptr and dims remain valid - unsafe { ArrayBase::from_data_ptr(data, arr.ptr).with_strides_dim(arr.strides, arr.dim) } + unsafe { ArrayBase::from_data_ptr(data, arr.layout.ptr).with_strides_dim(arr.layout.strides, arr.layout.dim) } } } diff --git a/src/data_traits.rs b/src/data_traits.rs index f43bfb4ef..2d7a7de31 100644 --- a/src/data_traits.rs +++ b/src/data_traits.rs @@ -23,7 +23,7 @@ use std::mem::MaybeUninit; use std::mem::{self, size_of}; use std::ptr::NonNull; -use crate::{ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr}; +use crate::{ArcArray, Array, ArrayBase, ArrayRef, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr}; /// Array representation trait. /// @@ -251,7 +251,7 @@ where A: Clone if Arc::get_mut(&mut self_.data.0).is_some() { return; } - if self_.dim.size() <= self_.data.0.len() / 2 { + if self_.layout.dim.size() <= self_.data.0.len() / 2 { // Clone only the visible elements if the current view is less than // half of backing data. *self_ = self_.to_owned().into_shared(); @@ -260,13 +260,13 @@ where A: Clone let rcvec = &mut self_.data.0; let a_size = mem::size_of::() as isize; let our_off = if a_size != 0 { - (self_.ptr.as_ptr() as isize - rcvec.as_ptr() as isize) / a_size + (self_.layout.ptr.as_ptr() as isize - rcvec.as_ptr() as isize) / a_size } else { 0 }; let rvec = Arc::make_mut(rcvec); unsafe { - self_.ptr = rvec.as_nonnull_mut().offset(our_off); + self_.layout.ptr = rvec.as_nonnull_mut().offset(our_off); } } @@ -286,7 +286,9 @@ unsafe impl Data for OwnedArcRepr Self::ensure_unique(&mut self_); let data = Arc::try_unwrap(self_.data.0).ok().unwrap(); // safe because data is equivalent - unsafe { ArrayBase::from_data_ptr(data, self_.ptr).with_strides_dim(self_.strides, self_.dim) } + unsafe { + ArrayBase::from_data_ptr(data, self_.layout.ptr).with_strides_dim(self_.layout.strides, self_.layout.dim) + } } fn try_into_owned_nocopy(self_: ArrayBase) -> Result, ArrayBase> @@ -295,13 +297,14 @@ unsafe impl Data for OwnedArcRepr match Arc::try_unwrap(self_.data.0) { Ok(owned_data) => unsafe { // Safe because the data is equivalent. - Ok(ArrayBase::from_data_ptr(owned_data, self_.ptr).with_strides_dim(self_.strides, self_.dim)) + Ok(ArrayBase::from_data_ptr(owned_data, self_.layout.ptr) + .with_strides_dim(self_.layout.strides, self_.layout.dim)) }, Err(arc_data) => unsafe { // Safe because the data is equivalent; we're just // reconstructing `self_`. - Err(ArrayBase::from_data_ptr(OwnedArcRepr(arc_data), self_.ptr) - .with_strides_dim(self_.strides, self_.dim)) + Err(ArrayBase::from_data_ptr(OwnedArcRepr(arc_data), self_.layout.ptr) + .with_strides_dim(self_.layout.strides, self_.layout.dim)) }, } } @@ -598,11 +601,11 @@ where A: Clone { match array.data { CowRepr::View(_) => { - let owned = array.to_owned(); + let owned = ArrayRef::to_owned(array); array.data = CowRepr::Owned(owned.data); - array.ptr = owned.ptr; - array.dim = owned.dim; - array.strides = owned.strides; + array.layout.ptr = owned.layout.ptr; + array.layout.dim = owned.layout.dim; + array.layout.strides = owned.layout.strides; } CowRepr::Owned(_) => {} } @@ -663,7 +666,8 @@ unsafe impl<'a, A> Data for CowRepr<'a, A> CowRepr::View(_) => self_.to_owned(), CowRepr::Owned(data) => unsafe { // safe because the data is equivalent so ptr, dims remain valid - ArrayBase::from_data_ptr(data, self_.ptr).with_strides_dim(self_.strides, self_.dim) + ArrayBase::from_data_ptr(data, self_.layout.ptr) + .with_strides_dim(self_.layout.strides, self_.layout.dim) }, } } @@ -675,7 +679,8 @@ unsafe impl<'a, A> Data for CowRepr<'a, A> CowRepr::View(_) => Err(self_), CowRepr::Owned(data) => unsafe { // safe because the data is equivalent so ptr, dims remain valid - Ok(ArrayBase::from_data_ptr(data, self_.ptr).with_strides_dim(self_.strides, self_.dim)) + Ok(ArrayBase::from_data_ptr(data, self_.layout.ptr) + .with_strides_dim(self_.layout.strides, self_.layout.dim)) }, } } diff --git a/src/doc/ndarray_for_numpy_users/mod.rs b/src/doc/ndarray_for_numpy_users/mod.rs index eba96cdd0..bb6b7ae83 100644 --- a/src/doc/ndarray_for_numpy_users/mod.rs +++ b/src/doc/ndarray_for_numpy_users/mod.rs @@ -322,7 +322,7 @@ //! //! //! -//! [`mat1.dot(&mat2)`][matrix-* dot] +//! [`mat1.dot(&mat2)`][dot-2-2] //! //! //! @@ -336,7 +336,7 @@ //! //! //! -//! [`mat.dot(&vec)`][matrix-* dot] +//! [`mat.dot(&vec)`][dot-2-1] //! //! //! @@ -350,7 +350,7 @@ //! //! //! -//! [`vec.dot(&mat)`][vec-* dot] +//! [`vec.dot(&mat)`][dot-1-2] //! //! //! @@ -364,7 +364,7 @@ //! //! //! -//! [`vec1.dot(&vec2)`][vec-* dot] +//! [`vec1.dot(&vec2)`][dot-1-1] //! //! //! @@ -670,22 +670,22 @@ //! `a[:,4]` | [`a.column(4)`][.column()] or [`a.column_mut(4)`][.column_mut()] | view (or mutable view) of column 4 in a 2-D array //! `a.shape[0] == a.shape[1]` | [`a.is_square()`][.is_square()] | check if the array is square //! -//! [.abs_diff_eq()]: ArrayBase#impl-AbsDiffEq> -//! [.assign()]: ArrayBase::assign -//! [.axis_iter()]: ArrayBase::axis_iter -//! [.ncols()]: ArrayBase::ncols -//! [.column()]: ArrayBase::column -//! [.column_mut()]: ArrayBase::column_mut +//! [.abs_diff_eq()]: ArrayRef#impl-AbsDiffEq%3CArrayRef%3CB,+D%3E%3E +//! [.assign()]: ArrayRef::assign +//! [.axis_iter()]: ArrayRef::axis_iter +//! [.ncols()]: LayoutRef::ncols +//! [.column()]: ArrayRef::column +//! [.column_mut()]: ArrayRef::column_mut //! [concatenate()]: crate::concatenate() //! [concatenate!]: crate::concatenate! //! [stack!]: crate::stack! //! [::default()]: ArrayBase::default -//! [.diag()]: ArrayBase::diag -//! [.dim()]: ArrayBase::dim +//! [.diag()]: ArrayRef::diag +//! [.dim()]: LayoutRef::dim //! [::eye()]: ArrayBase::eye -//! [.fill()]: ArrayBase::fill -//! [.fold()]: ArrayBase::fold -//! [.fold_axis()]: ArrayBase::fold_axis +//! [.fill()]: ArrayRef::fill +//! [.fold()]: ArrayRef::fold +//! [.fold_axis()]: ArrayRef::fold_axis //! [::from_elem()]: ArrayBase::from_elem //! [::from_iter()]: ArrayBase::from_iter //! [::from_diag()]: ArrayBase::from_diag @@ -694,48 +694,51 @@ //! [::from_shape_vec_unchecked()]: ArrayBase::from_shape_vec_unchecked //! [::from_vec()]: ArrayBase::from_vec //! [.index()]: ArrayBase#impl-Index -//! [.indexed_iter()]: ArrayBase::indexed_iter +//! [.indexed_iter()]: ArrayRef::indexed_iter //! [.insert_axis()]: ArrayBase::insert_axis -//! [.is_empty()]: ArrayBase::is_empty -//! [.is_square()]: ArrayBase::is_square -//! [.iter()]: ArrayBase::iter -//! [.len()]: ArrayBase::len -//! [.len_of()]: ArrayBase::len_of +//! [.is_empty()]: LayoutRef::is_empty +//! [.is_square()]: LayoutRef::is_square +//! [.iter()]: ArrayRef::iter +//! [.len()]: LayoutRef::len +//! [.len_of()]: LayoutRef::len_of //! [::linspace()]: ArrayBase::linspace //! [::logspace()]: ArrayBase::logspace //! [::geomspace()]: ArrayBase::geomspace -//! [.map()]: ArrayBase::map -//! [.map_axis()]: ArrayBase::map_axis -//! [.map_inplace()]: ArrayBase::map_inplace -//! [.mapv()]: ArrayBase::mapv -//! [.mapv_inplace()]: ArrayBase::mapv_inplace +//! [.map()]: ArrayRef::map +//! [.map_axis()]: ArrayRef::map_axis +//! [.map_inplace()]: ArrayRef::map_inplace +//! [.mapv()]: ArrayRef::mapv +//! [.mapv_inplace()]: ArrayRef::mapv_inplace //! [.mapv_into()]: ArrayBase::mapv_into -//! [matrix-* dot]: ArrayBase::dot-1 -//! [.mean()]: ArrayBase::mean -//! [.mean_axis()]: ArrayBase::mean_axis -//! [.ndim()]: ArrayBase::ndim +//! [dot-2-2]: ArrayRef#impl-Dot>>-for-ArrayRef> +//! [dot-1-1]: ArrayRef#impl-Dot>>-for-ArrayRef> +//! [dot-1-2]: ArrayRef#impl-Dot>>-for-ArrayRef> +//! [dot-2-1]: ArrayRef#impl-Dot>>-for-ArrayRef> +//! [.mean()]: ArrayRef::mean +//! [.mean_axis()]: ArrayRef::mean_axis +//! [.ndim()]: LayoutRef::ndim //! [::ones()]: ArrayBase::ones -//! [.outer_iter()]: ArrayBase::outer_iter +//! [.outer_iter()]: ArrayRef::outer_iter //! [::range()]: ArrayBase::range -//! [.raw_dim()]: ArrayBase::raw_dim +//! [.raw_dim()]: LayoutRef::raw_dim //! [.reversed_axes()]: ArrayBase::reversed_axes -//! [.row()]: ArrayBase::row -//! [.row_mut()]: ArrayBase::row_mut -//! [.nrows()]: ArrayBase::nrows -//! [.sum()]: ArrayBase::sum -//! [.slice()]: ArrayBase::slice -//! [.slice_axis()]: ArrayBase::slice_axis -//! [.slice_collapse()]: ArrayBase::slice_collapse +//! [.row()]: ArrayRef::row +//! [.row_mut()]: ArrayRef::row_mut +//! [.nrows()]: LayoutRef::nrows +//! [.sum()]: ArrayRef::sum +//! [.slice()]: ArrayRef::slice +//! [.slice_axis()]: ArrayRef::slice_axis +//! [.slice_collapse()]: LayoutRef::slice_collapse //! [.slice_move()]: ArrayBase::slice_move -//! [.slice_mut()]: ArrayBase::slice_mut -//! [.shape()]: ArrayBase::shape +//! [.slice_mut()]: ArrayRef::slice_mut +//! [.shape()]: LayoutRef::shape //! [stack()]: crate::stack() -//! [.strides()]: ArrayBase::strides -//! [.index_axis()]: ArrayBase::index_axis -//! [.sum_axis()]: ArrayBase::sum_axis -//! [.t()]: ArrayBase::t -//! [vec-* dot]: ArrayBase::dot -//! [.for_each()]: ArrayBase::for_each +//! [.strides()]: LayoutRef::strides +//! [.index_axis()]: ArrayRef::index_axis +//! [.sum_axis()]: ArrayRef::sum_axis +//! [.t()]: ArrayRef::t +//! [vec-* dot]: ArrayRef::dot +//! [.for_each()]: ArrayRef::for_each //! [::zeros()]: ArrayBase::zeros //! [`Zip`]: crate::Zip diff --git a/src/doc/ndarray_for_numpy_users/rk_step.rs b/src/doc/ndarray_for_numpy_users/rk_step.rs index c882a3d00..820c71d9c 100644 --- a/src/doc/ndarray_for_numpy_users/rk_step.rs +++ b/src/doc/ndarray_for_numpy_users/rk_step.rs @@ -122,7 +122,7 @@ //! //! * Use [`c.mul_add(h, t)`](f64::mul_add) instead of `t + c * h`. This is //! faster and reduces the floating-point error. It might also be beneficial -//! to use [`.scaled_add()`] or a combination of +//! to use [`.scaled_add()`](crate::ArrayRef::scaled_add) or a combination of //! [`azip!()`] and [`.mul_add()`](f64::mul_add) on the arrays in //! some places, but that's not demonstrated in the example below. //! diff --git a/src/free_functions.rs b/src/free_functions.rs index 5659d7024..c1889cec8 100644 --- a/src/free_functions.rs +++ b/src/free_functions.rs @@ -14,8 +14,8 @@ use std::compile_error; use std::mem::{forget, size_of}; use std::ptr::NonNull; -use crate::imp_prelude::*; use crate::{dimension, ArcArray1, ArcArray2}; +use crate::{imp_prelude::*, LayoutRef}; /// Create an **[`Array`]** with one, two, three, four, five, or six dimensions. /// @@ -106,10 +106,12 @@ pub const fn aview0(x: &A) -> ArrayView0<'_, A> { ArrayBase { data: ViewRepr::new(), - // Safe because references are always non-null. - ptr: unsafe { NonNull::new_unchecked(x as *const A as *mut A) }, - dim: Ix0(), - strides: Ix0(), + layout: LayoutRef { + // Safe because references are always non-null. + ptr: unsafe { NonNull::new_unchecked(x as *const A as *mut A) }, + dim: Ix0(), + strides: Ix0(), + }, } } @@ -144,10 +146,12 @@ pub const fn aview1(xs: &[A]) -> ArrayView1<'_, A> } ArrayBase { data: ViewRepr::new(), - // Safe because references are always non-null. - ptr: unsafe { NonNull::new_unchecked(xs.as_ptr() as *mut A) }, - dim: Ix1(xs.len()), - strides: Ix1(1), + layout: LayoutRef { + // Safe because references are always non-null. + ptr: unsafe { NonNull::new_unchecked(xs.as_ptr() as *mut A) }, + dim: Ix1(xs.len()), + strides: Ix1(1), + }, } } @@ -200,9 +204,7 @@ pub const fn aview2(xs: &[[A; N]]) -> ArrayView2<'_, A> }; ArrayBase { data: ViewRepr::new(), - ptr, - dim, - strides, + layout: LayoutRef { ptr, dim, strides }, } } diff --git a/src/impl_1d.rs b/src/impl_1d.rs index e49fdd731..bd34ba2ca 100644 --- a/src/impl_1d.rs +++ b/src/impl_1d.rs @@ -15,14 +15,11 @@ use crate::imp_prelude::*; use crate::low_level_util::AbortIfPanic; /// # Methods For 1-D Arrays -impl ArrayBase -where S: RawData +impl ArrayRef { /// Return an vector with the elements of the one-dimensional array. pub fn to_vec(&self) -> Vec - where - A: Clone, - S: Data, + where A: Clone { if let Some(slc) = self.as_slice() { slc.to_vec() @@ -34,7 +31,6 @@ where S: RawData /// Rotate the elements of the array by 1 element towards the front; /// the former first element becomes the last. pub(crate) fn rotate1_front(&mut self) - where S: DataMut { // use swapping to keep all elements initialized (as required by owned storage) let mut lane_iter = self.iter_mut(); diff --git a/src/impl_2d.rs b/src/impl_2d.rs index c2e9725ac..b6379e67b 100644 --- a/src/impl_2d.rs +++ b/src/impl_2d.rs @@ -10,8 +10,7 @@ use crate::imp_prelude::*; /// # Methods For 2-D Arrays -impl ArrayBase -where S: RawData +impl ArrayRef { /// Return an array view of row `index`. /// @@ -24,7 +23,6 @@ where S: RawData /// ``` #[track_caller] pub fn row(&self, index: Ix) -> ArrayView1<'_, A> - where S: Data { self.index_axis(Axis(0), index) } @@ -41,11 +39,13 @@ where S: RawData /// ``` #[track_caller] pub fn row_mut(&mut self, index: Ix) -> ArrayViewMut1<'_, A> - where S: DataMut { self.index_axis_mut(Axis(0), index) } +} +impl LayoutRef +{ /// Return the number of rows (length of `Axis(0)`) in the two-dimensional array. /// /// ``` @@ -67,7 +67,10 @@ where S: RawData { self.len_of(Axis(0)) } +} +impl ArrayRef +{ /// Return an array view of column `index`. /// /// **Panics** if `index` is out of bounds. @@ -79,7 +82,6 @@ where S: RawData /// ``` #[track_caller] pub fn column(&self, index: Ix) -> ArrayView1<'_, A> - where S: Data { self.index_axis(Axis(1), index) } @@ -96,11 +98,13 @@ where S: RawData /// ``` #[track_caller] pub fn column_mut(&mut self, index: Ix) -> ArrayViewMut1<'_, A> - where S: DataMut { self.index_axis_mut(Axis(1), index) } +} +impl LayoutRef +{ /// Return the number of columns (length of `Axis(1)`) in the two-dimensional array. /// /// ``` @@ -144,3 +148,70 @@ where S: RawData m == n } } + +impl ArrayBase +{ + /// Return the number of rows (length of `Axis(0)`) in the two-dimensional array. + /// + /// ``` + /// use ndarray::{array, Axis}; + /// + /// let array = array![[1., 2.], + /// [3., 4.], + /// [5., 6.]]; + /// assert_eq!(array.nrows(), 3); + /// + /// // equivalent ways of getting the dimensions + /// // get nrows, ncols by using dim: + /// let (m, n) = array.dim(); + /// assert_eq!(m, array.nrows()); + /// // get length of any particular axis with .len_of() + /// assert_eq!(m, array.len_of(Axis(0))); + /// ``` + pub fn nrows(&self) -> usize + { + self.as_layout_ref().nrows() + } + + /// Return the number of columns (length of `Axis(1)`) in the two-dimensional array. + /// + /// ``` + /// use ndarray::{array, Axis}; + /// + /// let array = array![[1., 2.], + /// [3., 4.], + /// [5., 6.]]; + /// assert_eq!(array.ncols(), 2); + /// + /// // equivalent ways of getting the dimensions + /// // get nrows, ncols by using dim: + /// let (m, n) = array.dim(); + /// assert_eq!(n, array.ncols()); + /// // get length of any particular axis with .len_of() + /// assert_eq!(n, array.len_of(Axis(1))); + /// ``` + pub fn ncols(&self) -> usize + { + self.as_layout_ref().ncols() + } + + /// Return true if the array is square, false otherwise. + /// + /// # Examples + /// Square: + /// ``` + /// use ndarray::array; + /// let array = array![[1., 2.], [3., 4.]]; + /// assert!(array.is_square()); + /// ``` + /// Not square: + /// ``` + /// use ndarray::array; + /// let array = array![[1., 2., 5.], [3., 4., 6.]]; + /// assert!(!array.is_square()); + /// ``` + pub fn is_square(&self) -> bool + { + self.as_layout_ref().is_square() + } +} diff --git a/src/impl_clone.rs b/src/impl_clone.rs index d65f6c338..402437941 100644 --- a/src/impl_clone.rs +++ b/src/impl_clone.rs @@ -7,6 +7,7 @@ // except according to those terms. use crate::imp_prelude::*; +use crate::LayoutRef; use crate::RawDataClone; impl Clone for ArrayBase @@ -15,12 +16,14 @@ impl Clone for ArrayBase { // safe because `clone_with_ptr` promises to provide equivalent data and ptr unsafe { - let (data, ptr) = self.data.clone_with_ptr(self.ptr); + let (data, ptr) = self.data.clone_with_ptr(self.layout.ptr); ArrayBase { data, - ptr, - dim: self.dim.clone(), - strides: self.strides.clone(), + layout: LayoutRef { + ptr, + dim: self.layout.dim.clone(), + strides: self.layout.strides.clone(), + }, } } } @@ -31,9 +34,9 @@ impl Clone for ArrayBase fn clone_from(&mut self, other: &Self) { unsafe { - self.ptr = self.data.clone_from_with_ptr(&other.data, other.ptr); - self.dim.clone_from(&other.dim); - self.strides.clone_from(&other.strides); + self.layout.ptr = self.data.clone_from_with_ptr(&other.data, other.layout.ptr); + self.layout.dim.clone_from(&other.layout.dim); + self.layout.strides.clone_from(&other.layout.strides); } } } diff --git a/src/impl_cow.rs b/src/impl_cow.rs index f064ce7bd..c409ae0cc 100644 --- a/src/impl_cow.rs +++ b/src/impl_cow.rs @@ -33,7 +33,10 @@ where D: Dimension fn from(view: ArrayView<'a, A, D>) -> CowArray<'a, A, D> { // safe because equivalent data - unsafe { ArrayBase::from_data_ptr(CowRepr::View(view.data), view.ptr).with_strides_dim(view.strides, view.dim) } + unsafe { + ArrayBase::from_data_ptr(CowRepr::View(view.data), view.ptr) + .with_strides_dim(view.layout.strides, view.layout.dim) + } } } @@ -44,7 +47,8 @@ where D: Dimension { // safe because equivalent data unsafe { - ArrayBase::from_data_ptr(CowRepr::Owned(array.data), array.ptr).with_strides_dim(array.strides, array.dim) + ArrayBase::from_data_ptr(CowRepr::Owned(array.data), array.layout.ptr) + .with_strides_dim(array.layout.strides, array.layout.dim) } } } diff --git a/src/impl_dyn.rs b/src/impl_dyn.rs index b86c5dd69..409fe991a 100644 --- a/src/impl_dyn.rs +++ b/src/impl_dyn.rs @@ -10,8 +10,7 @@ use crate::imp_prelude::*; /// # Methods for Dynamic-Dimensional Arrays -impl ArrayBase -where S: Data +impl LayoutRef { /// Insert new array axis of length 1 at `axis`, modifying the shape and /// strides in-place. @@ -58,7 +57,56 @@ where S: Data self.dim = self.dim.remove_axis(axis); self.strides = self.strides.remove_axis(axis); } +} + +impl ArrayBase +{ + /// Insert new array axis of length 1 at `axis`, modifying the shape and + /// strides in-place. + /// + /// **Panics** if the axis is out of bounds. + /// + /// ``` + /// use ndarray::{Axis, arr2, arr3}; + /// + /// let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]).into_dyn(); + /// assert_eq!(a.shape(), &[2, 3]); + /// + /// a.insert_axis_inplace(Axis(1)); + /// assert_eq!(a, arr3(&[[[1, 2, 3]], [[4, 5, 6]]]).into_dyn()); + /// assert_eq!(a.shape(), &[2, 1, 3]); + /// ``` + #[track_caller] + pub fn insert_axis_inplace(&mut self, axis: Axis) + { + self.as_mut().insert_axis_inplace(axis) + } + /// Collapses the array to `index` along the axis and removes the axis, + /// modifying the shape and strides in-place. + /// + /// **Panics** if `axis` or `index` is out of bounds. + /// + /// ``` + /// use ndarray::{Axis, arr1, arr2}; + /// + /// let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]).into_dyn(); + /// assert_eq!(a.shape(), &[2, 3]); + /// + /// a.index_axis_inplace(Axis(1), 1); + /// assert_eq!(a, arr1(&[2, 5]).into_dyn()); + /// assert_eq!(a.shape(), &[2]); + /// ``` + #[track_caller] + pub fn index_axis_inplace(&mut self, axis: Axis, index: usize) + { + self.as_mut().index_axis_inplace(axis, index) + } +} + +impl ArrayBase +where S: Data +{ /// Remove axes of length 1 and return the modified array. /// /// If the array has more the one dimension, the result array will always diff --git a/src/impl_internal_constructors.rs b/src/impl_internal_constructors.rs index adb4cbd35..7f95339d5 100644 --- a/src/impl_internal_constructors.rs +++ b/src/impl_internal_constructors.rs @@ -8,7 +8,7 @@ use std::ptr::NonNull; -use crate::imp_prelude::*; +use crate::{imp_prelude::*, LayoutRef}; // internal "builder-like" methods impl ArrayBase @@ -27,9 +27,11 @@ where S: RawData { let array = ArrayBase { data, - ptr, - dim: Ix1(0), - strides: Ix1(1), + layout: LayoutRef { + ptr, + dim: Ix1(0), + strides: Ix1(1), + }, }; debug_assert!(array.pointer_is_inbounds()); array @@ -58,9 +60,11 @@ where debug_assert_eq!(strides.ndim(), dim.ndim()); ArrayBase { data: self.data, - ptr: self.ptr, - dim, - strides, + layout: LayoutRef { + ptr: self.layout.ptr, + dim, + strides, + }, } } } diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 4a00ea000..3bd1a8b68 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -38,7 +38,10 @@ use crate::math_cell::MathCell; use crate::order::Order; use crate::shape_builder::ShapeArg; use crate::zip::{IntoNdProducer, Zip}; +use crate::ArrayRef; use crate::AxisDescription; +use crate::LayoutRef; +use crate::RawRef; use crate::{arraytraits, DimMax}; use crate::iter::{ @@ -62,10 +65,7 @@ use crate::stacking::concatenate; use crate::{NdIndex, Slice, SliceInfoElem}; /// # Methods For All Array Types -impl ArrayBase -where - S: RawData, - D: Dimension, +impl LayoutRef { /// Return the total number of elements in the array. pub fn len(&self) -> usize @@ -173,20 +173,20 @@ where // strides are reinterpreted as isize self.strides[axis.index()] as isize } +} +impl ArrayRef +{ /// Return a read-only view of the array pub fn view(&self) -> ArrayView<'_, A, D> - where S: Data { - debug_assert!(self.pointer_is_inbounds()); + // debug_assert!(self.pointer_is_inbounds()); unsafe { ArrayView::new(self.ptr, self.dim.clone(), self.strides.clone()) } } /// Return a read-write view of the array pub fn view_mut(&mut self) -> ArrayViewMut<'_, A, D> - where S: DataMut { - self.ensure_unique(); unsafe { ArrayViewMut::new(self.ptr, self.dim.clone(), self.strides.clone()) } } @@ -198,7 +198,6 @@ where /// The view acts "as if" the elements are temporarily in cells, and elements /// can be changed through shared references using the regular cell methods. pub fn cell_view(&mut self) -> ArrayView<'_, MathCell, D> - where S: DataMut { self.view_mut().into_cell_view() } @@ -234,9 +233,7 @@ where /// # assert_eq!(arr, owned); /// ``` pub fn to_owned(&self) -> Array - where - A: Clone, - S: Data, + where A: Clone { if let Some(slc) = self.as_slice_memory_order() { unsafe { Array::from_shape_vec_unchecked(self.dim.clone().strides(self.strides.clone()), slc.to_vec()) } @@ -244,6 +241,50 @@ where self.map(A::clone) } } +} + +impl ArrayBase +where + S: RawData, + D: Dimension, +{ + /// Return an uniquely owned copy of the array. + /// + /// If the input array is contiguous, then the output array will have the same + /// memory layout. Otherwise, the layout of the output array is unspecified. + /// If you need a particular layout, you can allocate a new array with the + /// desired memory layout and [`.assign()`](ArrayRef::assign) the data. + /// Alternatively, you can collectan iterator, like this for a result in + /// standard layout: + /// + /// ``` + /// # use ndarray::prelude::*; + /// # let arr = Array::from_shape_vec((2, 2).f(), vec![1, 2, 3, 4]).unwrap(); + /// # let owned = { + /// Array::from_shape_vec(arr.raw_dim(), arr.iter().cloned().collect()).unwrap() + /// # }; + /// # assert!(owned.is_standard_layout()); + /// # assert_eq!(arr, owned); + /// ``` + /// + /// or this for a result in column-major (Fortran) layout: + /// + /// ``` + /// # use ndarray::prelude::*; + /// # let arr = Array::from_shape_vec((2, 2), vec![1, 2, 3, 4]).unwrap(); + /// # let owned = { + /// Array::from_shape_vec(arr.raw_dim().f(), arr.t().iter().cloned().collect()).unwrap() + /// # }; + /// # assert!(owned.t().is_standard_layout()); + /// # assert_eq!(arr, owned); + /// ``` + pub fn to_owned(&self) -> Array + where + A: Clone, + S: Data, + { + (**self).to_owned() + } /// Return a shared ownership (copy on write) array, cloning the array /// elements if necessary. @@ -305,7 +346,10 @@ where { S::into_shared(self) } +} +impl ArrayRef +{ /// Returns a reference to the first element of the array, or `None` if it /// is empty. /// @@ -322,7 +366,6 @@ where /// assert_eq!(b.first(), None); /// ``` pub fn first(&self) -> Option<&A> - where S: Data { if self.is_empty() { None @@ -347,7 +390,6 @@ where /// assert_eq!(b.first_mut(), None); /// ``` pub fn first_mut(&mut self) -> Option<&mut A> - where S: DataMut { if self.is_empty() { None @@ -372,7 +414,6 @@ where /// assert_eq!(b.last(), None); /// ``` pub fn last(&self) -> Option<&A> - where S: Data { if self.is_empty() { None @@ -401,12 +442,10 @@ where /// assert_eq!(b.last_mut(), None); /// ``` pub fn last_mut(&mut self) -> Option<&mut A> - where S: DataMut { if self.is_empty() { None } else { - self.ensure_unique(); let mut index = self.raw_dim(); for ax in 0..index.ndim() { index[ax] -= 1; @@ -422,9 +461,8 @@ where /// /// Iterator element type is `&A`. pub fn iter(&self) -> Iter<'_, A, D> - where S: Data { - debug_assert!(self.pointer_is_inbounds()); + // debug_assert!(self.pointer_is_inbounds()); self.view().into_iter_() } @@ -435,7 +473,6 @@ where /// /// Iterator element type is `&mut A`. pub fn iter_mut(&mut self) -> IterMut<'_, A, D> - where S: DataMut { self.view_mut().into_iter_() } @@ -449,7 +486,6 @@ where /// /// See also [`Zip::indexed`] pub fn indexed_iter(&self) -> IndexedIter<'_, A, D> - where S: Data { IndexedIter::new(self.view().into_elements_base()) } @@ -461,7 +497,6 @@ where /// /// Iterator element type is `(D::Pattern, &mut A)`. pub fn indexed_iter_mut(&mut self) -> IndexedIterMut<'_, A, D> - where S: DataMut { IndexedIterMut::new(self.view_mut().into_elements_base()) } @@ -475,9 +510,7 @@ where /// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) #[track_caller] pub fn slice(&self, info: I) -> ArrayView<'_, A, I::OutDim> - where - I: SliceArg, - S: Data, + where I: SliceArg { self.view().slice_move(info) } @@ -491,9 +524,7 @@ where /// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) #[track_caller] pub fn slice_mut(&mut self, info: I) -> ArrayViewMut<'_, A, I::OutDim> - where - I: SliceArg, - S: DataMut, + where I: SliceArg { self.view_mut().slice_move(info) } @@ -523,13 +554,17 @@ where /// ``` #[track_caller] pub fn multi_slice_mut<'a, M>(&'a mut self, info: M) -> M::Output - where - M: MultiSliceArg<'a, A, D>, - S: DataMut, + where M: MultiSliceArg<'a, A, D> { info.multi_slice_move(self.view_mut()) } +} +impl ArrayBase +where + S: RawData, + D: Dimension, +{ /// Slice the array, possibly changing the number of dimensions. /// /// See [*Slicing*](#slicing) for full documentation. @@ -557,8 +592,8 @@ where // Slice the axis in-place to update the `dim`, `strides`, and `ptr`. self.slice_axis_inplace(Axis(old_axis), Slice { start, end, step }); // Copy the sliced dim and stride to corresponding axis. - new_dim[new_axis] = self.dim[old_axis]; - new_strides[new_axis] = self.strides[old_axis]; + new_dim[new_axis] = self.layout.dim[old_axis]; + new_strides[new_axis] = self.layout.strides[old_axis]; old_axis += 1; new_axis += 1; } @@ -585,16 +620,19 @@ where // safe because new dimension, strides allow access to a subset of old data unsafe { self.with_strides_dim(new_strides, new_dim) } } +} +impl LayoutRef +{ /// Slice the array in place without changing the number of dimensions. /// /// In particular, if an axis is sliced with an index, the axis is /// collapsed, as in [`.collapse_axis()`], rather than removed, as in /// [`.slice_move()`] or [`.index_axis_move()`]. /// - /// [`.collapse_axis()`]: Self::collapse_axis - /// [`.slice_move()`]: Self::slice_move - /// [`.index_axis_move()`]: Self::index_axis_move + /// [`.collapse_axis()`]: LayoutRef::collapse_axis + /// [`.slice_move()`]: ArrayBase::slice_move + /// [`.index_axis_move()`]: ArrayBase::index_axis_move /// /// See [*Slicing*](#slicing) for full documentation. /// See also [`s!`], [`SliceArg`], and [`SliceInfo`](crate::SliceInfo). @@ -630,7 +668,10 @@ where }); debug_assert_eq!(axis, self.ndim()); } +} +impl ArrayRef +{ /// Return a view of the array, sliced along the specified axis. /// /// **Panics** if an index is out of bounds or step size is zero.
@@ -638,7 +679,6 @@ where #[track_caller] #[must_use = "slice_axis returns an array view with the sliced result"] pub fn slice_axis(&self, axis: Axis, indices: Slice) -> ArrayView<'_, A, D> - where S: Data { let mut view = self.view(); view.slice_axis_inplace(axis, indices); @@ -652,13 +692,15 @@ where #[track_caller] #[must_use = "slice_axis_mut returns an array view with the sliced result"] pub fn slice_axis_mut(&mut self, axis: Axis, indices: Slice) -> ArrayViewMut<'_, A, D> - where S: DataMut { let mut view_mut = self.view_mut(); view_mut.slice_axis_inplace(axis, indices); view_mut } +} +impl LayoutRef +{ /// Slice the array in place along the specified axis. /// /// **Panics** if an index is out of bounds or step size is zero.
@@ -671,9 +713,15 @@ where unsafe { self.ptr = self.ptr.offset(offset); } - debug_assert!(self.pointer_is_inbounds()); + // debug_assert!(self.pointer_is_inbounds()); } +} +impl ArrayBase +where + S: RawData, + D: Dimension, +{ /// Slice the array in place along the specified axis, then return the sliced array. /// /// **Panics** if an index is out of bounds or step size is zero.
@@ -684,7 +732,10 @@ where self.slice_axis_inplace(axis, indices); self } +} +impl ArrayRef +{ /// Return a view of a slice of the array, with a closure specifying the /// slice for each axis. /// @@ -694,9 +745,7 @@ where /// **Panics** if an index is out of bounds or step size is zero. #[track_caller] pub fn slice_each_axis(&self, f: F) -> ArrayView<'_, A, D> - where - F: FnMut(AxisDescription) -> Slice, - S: Data, + where F: FnMut(AxisDescription) -> Slice { let mut view = self.view(); view.slice_each_axis_inplace(f); @@ -712,15 +761,16 @@ where /// **Panics** if an index is out of bounds or step size is zero. #[track_caller] pub fn slice_each_axis_mut(&mut self, f: F) -> ArrayViewMut<'_, A, D> - where - F: FnMut(AxisDescription) -> Slice, - S: DataMut, + where F: FnMut(AxisDescription) -> Slice { let mut view = self.view_mut(); view.slice_each_axis_inplace(f); view } +} +impl LayoutRef +{ /// Slice the array in place, with a closure specifying the slice for each /// axis. /// @@ -743,7 +793,10 @@ where ) } } +} +impl ArrayRef +{ /// Return a reference to the element at `index`, or return `None` /// if the index is out of bounds. /// @@ -763,13 +816,14 @@ where /// ); /// ``` pub fn get(&self, index: I) -> Option<&A> - where - S: Data, - I: NdIndex, + where I: NdIndex { unsafe { self.get_ptr(index).map(|ptr| &*ptr) } } +} +impl RawRef +{ /// Return a raw pointer to the element at `index`, or return `None` /// if the index is out of bounds. /// @@ -791,17 +845,21 @@ where .index_checked(&self.dim, &self.strides) .map(move |offset| unsafe { ptr.as_ptr().offset(offset) as *const _ }) } +} +impl ArrayRef +{ /// Return a mutable reference to the element at `index`, or return `None` /// if the index is out of bounds. pub fn get_mut(&mut self, index: I) -> Option<&mut A> - where - S: DataMut, - I: NdIndex, + where I: NdIndex { unsafe { self.get_mut_ptr(index).map(|ptr| &mut *ptr) } } +} +impl RawRef +{ /// Return a raw pointer to the element at `index`, or return `None` /// if the index is out of bounds. /// @@ -820,9 +878,7 @@ where /// assert_eq!(a.get((0, 1)), Some(&5.)); /// ``` pub fn get_mut_ptr(&mut self, index: I) -> Option<*mut A> - where - S: RawDataMut, - I: NdIndex, + where I: NdIndex { // const and mut are separate to enforce &mutness as well as the // extra code in as_mut_ptr @@ -831,7 +887,10 @@ where .index_checked(&self.dim, &self.strides) .map(move |offset| unsafe { ptr.offset(offset) }) } +} +impl ArrayRef +{ /// Perform *unchecked* array indexing. /// /// Return a reference to the element at `index`. @@ -843,9 +902,7 @@ where /// The caller must ensure that the index is in-bounds. #[inline] pub unsafe fn uget(&self, index: I) -> &A - where - S: Data, - I: NdIndex, + where I: NdIndex { arraytraits::debug_bounds_check(self, &index); let off = index.index_unchecked(&self.strides); @@ -868,11 +925,9 @@ where /// for `Array` and `ArrayViewMut`, but not for `ArcArray` or `CowArray`.) #[inline] pub unsafe fn uget_mut(&mut self, index: I) -> &mut A - where - S: DataMut, - I: NdIndex, + where I: NdIndex { - debug_assert!(self.data.is_unique()); + // debug_assert!(self.data.is_unique()); arraytraits::debug_bounds_check(self, &index); let off = index.index_unchecked(&self.strides); &mut *self.ptr.as_ptr().offset(off) @@ -885,9 +940,7 @@ where /// ***Panics*** if an index is out of bounds. #[track_caller] pub fn swap(&mut self, index1: I, index2: I) - where - S: DataMut, - I: NdIndex, + where I: NdIndex { let ptr = self.as_mut_ptr(); let offset1 = index1.index_checked(&self.dim, &self.strides); @@ -918,11 +971,9 @@ where /// 2. the data is uniquely held by the array. (This property is guaranteed /// for `Array` and `ArrayViewMut`, but not for `ArcArray` or `CowArray`.) pub unsafe fn uswap(&mut self, index1: I, index2: I) - where - S: DataMut, - I: NdIndex, + where I: NdIndex { - debug_assert!(self.data.is_unique()); + // debug_assert!(self.data.is_unique()); arraytraits::debug_bounds_check(self, &index1); arraytraits::debug_bounds_check(self, &index2); let off1 = index1.index_unchecked(&self.strides); @@ -933,7 +984,6 @@ where // `get` for zero-dimensional arrays // panics if dimension is not zero. otherwise an element is always present. fn get_0d(&self) -> &A - where S: Data { assert!(self.ndim() == 0); unsafe { &*self.as_ptr() } @@ -962,9 +1012,7 @@ where /// ``` #[track_caller] pub fn index_axis(&self, axis: Axis, index: usize) -> ArrayView<'_, A, D::Smaller> - where - S: Data, - D: RemoveAxis, + where D: RemoveAxis { self.view().index_axis_move(axis, index) } @@ -995,16 +1043,20 @@ where /// ``` #[track_caller] pub fn index_axis_mut(&mut self, axis: Axis, index: usize) -> ArrayViewMut<'_, A, D::Smaller> - where - S: DataMut, - D: RemoveAxis, + where D: RemoveAxis { self.view_mut().index_axis_move(axis, index) } +} +impl ArrayBase +where + S: RawData, + D: Dimension, +{ /// Collapses the array to `index` along the axis and removes the axis. /// - /// See [`.index_axis()`](Self::index_axis) and [*Subviews*](#subviews) for full documentation. + /// See [`.index_axis()`](ArrayRef::index_axis) and [*Subviews*](#subviews) for full documentation. /// /// **Panics** if `axis` or `index` is out of bounds. #[track_caller] @@ -1012,12 +1064,15 @@ where where D: RemoveAxis { self.collapse_axis(axis, index); - let dim = self.dim.remove_axis(axis); - let strides = self.strides.remove_axis(axis); + let dim = self.layout.dim.remove_axis(axis); + let strides = self.layout.strides.remove_axis(axis); // safe because new dimension, strides allow access to a subset of old data unsafe { self.with_strides_dim(strides, dim) } } +} +impl LayoutRef +{ /// Selects `index` along the axis, collapsing the axis into length one. /// /// **Panics** if `axis` or `index` is out of bounds. @@ -1026,9 +1081,12 @@ where { let offset = dimension::do_collapse_axis(&mut self.dim, &self.strides, axis.index(), index); self.ptr = unsafe { self.ptr.offset(offset) }; - debug_assert!(self.pointer_is_inbounds()); + // debug_assert!(self.pointer_is_inbounds()); } +} +impl ArrayRef +{ /// Along `axis`, select arbitrary subviews corresponding to `indices` /// and copy them into a new array. /// @@ -1054,7 +1112,6 @@ where pub fn select(&self, axis: Axis, indices: &[Ix]) -> Array where A: Clone, - S: Data, D: RemoveAxis, { if self.ndim() == 1 { @@ -1116,7 +1173,6 @@ where /// } /// ``` pub fn rows(&self) -> Lanes<'_, A, D::Smaller> - where S: Data { let mut n = self.ndim(); if n == 0 { @@ -1130,7 +1186,6 @@ where /// /// Iterator element is `ArrayView1
` (1D read-write array view). pub fn rows_mut(&mut self) -> LanesMut<'_, A, D::Smaller> - where S: DataMut { let mut n = self.ndim(); if n == 0 { @@ -1166,7 +1221,6 @@ where /// } /// ``` pub fn columns(&self) -> Lanes<'_, A, D::Smaller> - where S: Data { Lanes::new(self.view(), Axis(0)) } @@ -1176,7 +1230,6 @@ where /// /// Iterator element is `ArrayView1` (1D read-write array view). pub fn columns_mut(&mut self) -> LanesMut<'_, A, D::Smaller> - where S: DataMut { LanesMut::new(self.view_mut(), Axis(0)) } @@ -1210,7 +1263,6 @@ where /// assert_eq!(inner2.into_iter().next().unwrap(), aview1(&[0, 1, 2])); /// ``` pub fn lanes(&self, axis: Axis) -> Lanes<'_, A, D::Smaller> - where S: Data { Lanes::new(self.view(), axis) } @@ -1220,7 +1272,6 @@ where /// /// Iterator element is `ArrayViewMut1` (1D read-write array view). pub fn lanes_mut(&mut self, axis: Axis) -> LanesMut<'_, A, D::Smaller> - where S: DataMut { LanesMut::new(self.view_mut(), axis) } @@ -1233,9 +1284,7 @@ where /// Iterator element is `ArrayView` (read-only array view). #[allow(deprecated)] pub fn outer_iter(&self) -> AxisIter<'_, A, D::Smaller> - where - S: Data, - D: RemoveAxis, + where D: RemoveAxis { self.view().into_outer_iter() } @@ -1248,9 +1297,7 @@ where /// Iterator element is `ArrayViewMut` (read-write array view). #[allow(deprecated)] pub fn outer_iter_mut(&mut self) -> AxisIterMut<'_, A, D::Smaller> - where - S: DataMut, - D: RemoveAxis, + where D: RemoveAxis { self.view_mut().into_outer_iter() } @@ -1272,9 +1319,7 @@ where /// #[track_caller] pub fn axis_iter(&self, axis: Axis) -> AxisIter<'_, A, D::Smaller> - where - S: Data, - D: RemoveAxis, + where D: RemoveAxis { AxisIter::new(self.view(), axis) } @@ -1288,9 +1333,7 @@ where /// **Panics** if `axis` is out of bounds. #[track_caller] pub fn axis_iter_mut(&mut self, axis: Axis) -> AxisIterMut<'_, A, D::Smaller> - where - S: DataMut, - D: RemoveAxis, + where D: RemoveAxis { AxisIterMut::new(self.view_mut(), axis) } @@ -1323,7 +1366,6 @@ where /// ``` #[track_caller] pub fn axis_chunks_iter(&self, axis: Axis, size: usize) -> AxisChunksIter<'_, A, D> - where S: Data { AxisChunksIter::new(self.view(), axis, size) } @@ -1336,7 +1378,6 @@ where /// **Panics** if `axis` is out of bounds or if `size` is zero. #[track_caller] pub fn axis_chunks_iter_mut(&mut self, axis: Axis, size: usize) -> AxisChunksIterMut<'_, A, D> - where S: DataMut { AxisChunksIterMut::new(self.view_mut(), axis, size) } @@ -1354,9 +1395,7 @@ where /// number of array axes.) #[track_caller] pub fn exact_chunks(&self, chunk_size: E) -> ExactChunks<'_, A, D> - where - E: IntoDimension, - S: Data, + where E: IntoDimension { ExactChunks::new(self.view(), chunk_size) } @@ -1395,9 +1434,7 @@ where /// ``` #[track_caller] pub fn exact_chunks_mut(&mut self, chunk_size: E) -> ExactChunksMut<'_, A, D> - where - E: IntoDimension, - S: DataMut, + where E: IntoDimension { ExactChunksMut::new(self.view_mut(), chunk_size) } @@ -1410,9 +1447,7 @@ where /// This is essentially equivalent to [`.windows_with_stride()`] with unit stride. #[track_caller] pub fn windows(&self, window_size: E) -> Windows<'_, A, D> - where - E: IntoDimension, - S: Data, + where E: IntoDimension { Windows::new(self.view(), window_size) } @@ -1433,13 +1468,13 @@ where /// `window_size`. /// /// Note that passing a stride of only ones is similar to - /// calling [`ArrayBase::windows()`]. + /// calling [`ArrayRef::windows()`]. /// /// **Panics** if any dimension of `window_size` or `stride` is zero.
/// (**Panics** if `D` is `IxDyn` and `window_size` or `stride` does not match the /// number of array axes.) /// - /// This is the same illustration found in [`ArrayBase::windows()`], + /// This is the same illustration found in [`ArrayRef::windows()`], /// 2×2 windows in a 3×4 array, but now with a (1, 2) stride: /// /// ```text @@ -1463,9 +1498,7 @@ where /// ``` #[track_caller] pub fn windows_with_stride(&self, window_size: E, stride: E) -> Windows<'_, A, D> - where - E: IntoDimension, - S: Data, + where E: IntoDimension { Windows::new_with_stride(self.view(), window_size, stride) } @@ -1492,7 +1525,6 @@ where /// } /// ``` pub fn axis_windows(&self, axis: Axis, window_size: usize) -> AxisWindows<'_, A, D> - where S: Data { let axis_index = axis.index(); @@ -1510,31 +1542,35 @@ where AxisWindows::new(self.view(), axis, window_size) } - // Return (length, stride) for diagonal - fn diag_params(&self) -> (Ix, Ixs) - { - /* empty shape has len 1 */ - let len = self.dim.slice().iter().cloned().min().unwrap_or(1); - let stride = self.strides().iter().sum(); - (len, stride) - } - /// Return a view of the diagonal elements of the array. /// /// The diagonal is simply the sequence indexed by *(0, 0, .., 0)*, /// *(1, 1, ..., 1)* etc as long as all axes have elements. pub fn diag(&self) -> ArrayView1<'_, A> - where S: Data { self.view().into_diag() } /// Return a read-write view over the diagonal elements of the array. pub fn diag_mut(&mut self) -> ArrayViewMut1<'_, A> - where S: DataMut { self.view_mut().into_diag() } +} + +impl ArrayBase +where + S: RawData, + D: Dimension, +{ + // Return (length, stride) for diagonal + fn diag_params(&self) -> (Ix, Ixs) + { + /* empty shape has len 1 */ + let len = self.layout.dim.slice().iter().cloned().min().unwrap_or(1); + let stride = self.strides().iter().sum(); + (len, stride) + } /// Return the diagonal as a one-dimensional array. pub fn into_diag(self) -> ArrayBase @@ -1560,14 +1596,17 @@ where /// Make the array unshared. /// /// This method is mostly only useful with unsafe code. - fn ensure_unique(&mut self) + pub(crate) fn ensure_unique(&mut self) where S: DataMut { debug_assert!(self.pointer_is_inbounds()); S::ensure_unique(self); debug_assert!(self.pointer_is_inbounds()); } +} +impl LayoutRef +{ /// Return `true` if the array data is laid out in contiguous “C order” in /// memory (where the last index is the most rapidly varying). /// @@ -1583,7 +1622,10 @@ where { D::is_contiguous(&self.dim, &self.strides) } +} +impl ArrayRef +{ /// Return a standard-layout array containing the data, cloning if /// necessary. /// @@ -1607,9 +1649,7 @@ where /// assert!(cow_owned.is_standard_layout()); /// ``` pub fn as_standard_layout(&self) -> CowArray<'_, A, D> - where - S: Data, - A: Clone, + where A: Clone { if self.is_standard_layout() { CowArray::from(self.view()) @@ -1625,7 +1665,10 @@ where } } } +} +impl RawRef +{ /// Return a pointer to the first element in the array. /// /// Raw access to array elements needs to follow the strided indexing @@ -1641,6 +1684,19 @@ where self.ptr.as_ptr() as *const A } + /// Return a mutable pointer to the first element in the array reference. + #[inline(always)] + pub fn as_mut_ptr(&mut self) -> *mut A + { + self.ptr.as_ptr() + } +} + +impl ArrayBase +where + S: RawData, + D: Dimension, +{ /// Return a mutable pointer to the first element in the array. /// /// This method attempts to unshare the data. If `S: DataMut`, then the @@ -1656,9 +1712,12 @@ where where S: RawDataMut { self.try_ensure_unique(); // for ArcArray - self.ptr.as_ptr() + self.layout.ptr.as_ptr() } +} +impl RawRef +{ /// Return a raw view of the array. #[inline] pub fn raw_view(&self) -> RawArrayView @@ -1666,6 +1725,19 @@ where unsafe { RawArrayView::new(self.ptr, self.dim.clone(), self.strides.clone()) } } + /// Return a raw mutable view of the array. + #[inline] + pub fn raw_view_mut(&mut self) -> RawArrayViewMut + { + unsafe { RawArrayViewMut::new(self.ptr, self.dim.clone(), self.strides.clone()) } + } +} + +impl ArrayBase +where + S: RawData, + D: Dimension, +{ /// Return a raw mutable view of the array. /// /// This method attempts to unshare the data. If `S: DataMut`, then the @@ -1675,7 +1747,7 @@ where where S: RawDataMut { self.try_ensure_unique(); // for ArcArray - unsafe { RawArrayViewMut::new(self.ptr, self.dim.clone(), self.strides.clone()) } + unsafe { RawArrayViewMut::new(self.layout.ptr, self.layout.dim.clone(), self.layout.strides.clone()) } } /// Return a raw mutable view of the array. @@ -1688,13 +1760,54 @@ where RawArrayViewMut::new(self.ptr, self.dim.clone(), self.strides.clone()) } + /// Return the array’s data as a slice, if it is contiguous and in standard order. + /// Return `None` otherwise. + pub fn as_slice_mut(&mut self) -> Option<&mut [A]> + where S: DataMut + { + if self.is_standard_layout() { + self.ensure_unique(); + unsafe { Some(slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len())) } + } else { + None + } + } + + /// Return the array’s data as a slice if it is contiguous, + /// return `None` otherwise. + /// + /// In the contiguous case, in order to return a unique reference, this + /// method unshares the data if necessary, but it preserves the existing + /// strides. + pub fn as_slice_memory_order_mut(&mut self) -> Option<&mut [A]> + where S: DataMut + { + self.try_as_slice_memory_order_mut().ok() + } + + /// Return the array’s data as a slice if it is contiguous, otherwise + /// return `self` in the `Err` variant. + pub(crate) fn try_as_slice_memory_order_mut(&mut self) -> Result<&mut [A], &mut Self> + where S: DataMut + { + if self.is_contiguous() { + self.ensure_unique(); + let offset = offset_from_low_addr_ptr_to_logical_ptr(&self.dim, &self.strides); + unsafe { Ok(slice::from_raw_parts_mut(self.ptr.sub(offset).as_ptr(), self.len())) } + } else { + Err(self) + } + } +} + +impl ArrayRef +{ /// Return the array’s data as a slice, if it is contiguous and in standard order. /// Return `None` otherwise. /// /// If this function returns `Some(_)`, then the element order in the slice /// corresponds to the logical order of the array’s elements. pub fn as_slice(&self) -> Option<&[A]> - where S: Data { if self.is_standard_layout() { unsafe { Some(slice::from_raw_parts(self.ptr.as_ptr(), self.len())) } @@ -1706,10 +1819,8 @@ where /// Return the array’s data as a slice, if it is contiguous and in standard order. /// Return `None` otherwise. pub fn as_slice_mut(&mut self) -> Option<&mut [A]> - where S: DataMut { if self.is_standard_layout() { - self.ensure_unique(); unsafe { Some(slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len())) } } else { None @@ -1722,7 +1833,6 @@ where /// If this function returns `Some(_)`, then the elements in the slice /// have whatever order the elements have in memory. pub fn as_slice_memory_order(&self) -> Option<&[A]> - where S: Data { if self.is_contiguous() { let offset = offset_from_low_addr_ptr_to_logical_ptr(&self.dim, &self.strides); @@ -1739,7 +1849,6 @@ where /// method unshares the data if necessary, but it preserves the existing /// strides. pub fn as_slice_memory_order_mut(&mut self) -> Option<&mut [A]> - where S: DataMut { self.try_as_slice_memory_order_mut().ok() } @@ -1747,10 +1856,8 @@ where /// Return the array’s data as a slice if it is contiguous, otherwise /// return `self` in the `Err` variant. pub(crate) fn try_as_slice_memory_order_mut(&mut self) -> Result<&mut [A], &mut Self> - where S: DataMut { if self.is_contiguous() { - self.ensure_unique(); let offset = offset_from_low_addr_ptr_to_logical_ptr(&self.dim, &self.strides); unsafe { Ok(slice::from_raw_parts_mut(self.ptr.sub(offset).as_ptr(), self.len())) } } else { @@ -1817,7 +1924,6 @@ where where E: ShapeArg, A: Clone, - S: Data, { let (shape, order) = new_shape.into_shape_and_order(); self.to_shape_order(shape, order.unwrap_or(Order::RowMajor)) @@ -1827,7 +1933,6 @@ where where E: Dimension, A: Clone, - S: Data, { let len = self.dim.size(); if size_of_shape_checked(&shape) != Ok(len) { @@ -1861,7 +1966,13 @@ where Ok(CowArray::from(Array::from_shape_trusted_iter_unchecked(shape, view.into_iter(), A::clone))) } } +} +impl ArrayBase +where + S: RawData, + D: Dimension, +{ /// Transform the array into `shape`; any shape with the same number of /// elements is accepted, but the source array must be contiguous. /// @@ -1918,8 +2029,8 @@ where where E: Dimension { let shape = shape.into_dimension(); - if size_of_shape_checked(&shape) != Ok(self.dim.size()) { - return Err(error::incompatible_shapes(&self.dim, &shape)); + if size_of_shape_checked(&shape) != Ok(self.layout.dim.size()) { + return Err(error::incompatible_shapes(&self.layout.dim, &shape)); } // Check if contiguous, then we can change shape @@ -1963,8 +2074,8 @@ where where E: IntoDimension { let shape = shape.into_dimension(); - if size_of_shape_checked(&shape) != Ok(self.dim.size()) { - return Err(error::incompatible_shapes(&self.dim, &shape)); + if size_of_shape_checked(&shape) != Ok(self.layout.dim.size()) { + return Err(error::incompatible_shapes(&self.layout.dim, &shape)); } // Check if contiguous, if not => copy all, else just adapt strides unsafe { @@ -2092,7 +2203,10 @@ where unsafe { ArrayBase::from_shape_vec_unchecked(shape, v) } } } +} +impl ArrayRef +{ /// Flatten the array to a one-dimensional array. /// /// The array is returned as a `CowArray`; a view if possible, otherwise an owned array. @@ -2105,9 +2219,7 @@ where /// assert_eq!(flattened, arr1(&[1, 2, 3, 4, 5, 6, 7, 8])); /// ``` pub fn flatten(&self) -> CowArray<'_, A, Ix1> - where - A: Clone, - S: Data, + where A: Clone { self.flatten_with_order(Order::RowMajor) } @@ -2128,13 +2240,17 @@ where /// assert_eq!(flattened, arr1(&[1, 3, 5, 7, 2, 4, 6, 8])); /// ``` pub fn flatten_with_order(&self, order: Order) -> CowArray<'_, A, Ix1> - where - A: Clone, - S: Data, + where A: Clone { self.to_shape((self.len(), order)).unwrap() } +} +impl ArrayBase +where + S: RawData, + D: Dimension, +{ /// Flatten the array to a one-dimensional array, consuming the array. /// /// If possible, no copy is made, and the new array use the same memory as the original array. @@ -2169,7 +2285,8 @@ where { // safe because new dims equivalent unsafe { - ArrayBase::from_data_ptr(self.data, self.ptr).with_strides_dim(self.strides.into_dyn(), self.dim.into_dyn()) + ArrayBase::from_data_ptr(self.data, self.layout.ptr) + .with_strides_dim(self.layout.strides.into_dyn(), self.layout.dim.into_dyn()) } } @@ -2195,14 +2312,14 @@ where unsafe { if D::NDIM == D2::NDIM { // safe because D == D2 - let dim = unlimited_transmute::(self.dim); - let strides = unlimited_transmute::(self.strides); - return Ok(ArrayBase::from_data_ptr(self.data, self.ptr).with_strides_dim(strides, dim)); + let dim = unlimited_transmute::(self.layout.dim); + let strides = unlimited_transmute::(self.layout.strides); + return Ok(ArrayBase::from_data_ptr(self.data, self.layout.ptr).with_strides_dim(strides, dim)); } else if D::NDIM.is_none() || D2::NDIM.is_none() { // one is dynamic dim // safe because dim, strides are equivalent under a different type - if let Some(dim) = D2::from_dimension(&self.dim) { - if let Some(strides) = D2::from_dimension(&self.strides) { + if let Some(dim) = D2::from_dimension(&self.layout.dim) { + if let Some(strides) = D2::from_dimension(&self.layout.strides) { return Ok(self.with_strides_dim(strides, dim)); } } @@ -2210,7 +2327,10 @@ where } Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)) } +} +impl ArrayRef +{ /// Act like a larger size and/or shape array by *broadcasting* /// into a larger shape, if possible. /// @@ -2241,9 +2361,7 @@ where /// ); /// ``` pub fn broadcast(&self, dim: E) -> Option> - where - E: IntoDimension, - S: Data, + where E: IntoDimension { /// Return new stride when trying to grow `from` into shape `to` /// @@ -2311,12 +2429,10 @@ where /// /// Return `ShapeError` if their shapes can not be broadcast together. #[allow(clippy::type_complexity)] - pub(crate) fn broadcast_with<'a, 'b, B, S2, E>( - &'a self, other: &'b ArrayBase, + pub(crate) fn broadcast_with<'a, 'b, B, E>( + &'a self, other: &'b ArrayRef, ) -> Result<(ArrayView<'a, A, DimMaxOf>, ArrayView<'b, B, DimMaxOf>), ShapeError> where - S: Data, - S2: Data, D: Dimension + DimMax, E: Dimension, { @@ -2342,7 +2458,10 @@ where }; Ok((view1, view2)) } +} +impl LayoutRef +{ /// Swap axes `ax` and `bx`. /// /// This does not move any data, it just adjusts the array’s dimensions @@ -2365,7 +2484,13 @@ where self.dim.slice_mut().swap(ax, bx); self.strides.slice_mut().swap(ax, bx); } +} +impl ArrayBase +where + S: RawData, + D: Dimension, +{ /// Permute the axes. /// /// This does not move any data, it just adjusts the array’s dimensions @@ -2405,8 +2530,8 @@ where let mut new_dim = usage_counts; // reuse to avoid an allocation let mut new_strides = D::zeros(self.ndim()); { - let dim = self.dim.slice(); - let strides = self.strides.slice(); + let dim = self.layout.dim.slice(); + let strides = self.layout.strides.slice(); for (new_axis, &axis) in axes.slice().iter().enumerate() { new_dim[new_axis] = dim[axis]; new_strides[new_axis] = strides[axis]; @@ -2422,22 +2547,27 @@ where /// while retaining the same data. pub fn reversed_axes(mut self) -> ArrayBase { - self.dim.slice_mut().reverse(); - self.strides.slice_mut().reverse(); + self.layout.dim.slice_mut().reverse(); + self.layout.strides.slice_mut().reverse(); self } +} +impl ArrayRef +{ /// Return a transposed view of the array. /// /// This is a shorthand for `self.view().reversed_axes()`. /// /// See also the more general methods `.reversed_axes()` and `.swap_axes()`. pub fn t(&self) -> ArrayView<'_, A, D> - where S: Data { self.view().reversed_axes() } +} +impl LayoutRef +{ /// Return an iterator over the length and stride of each axis. pub fn axes(&self) -> Axes<'_, D> { @@ -2514,7 +2644,13 @@ where { merge_axes(&mut self.dim, &mut self.strides, take, into) } +} +impl ArrayBase +where + S: RawData, + D: Dimension, +{ /// Insert new array axis at `axis` and return the result. /// /// ``` @@ -2542,8 +2678,8 @@ where assert!(axis.index() <= self.ndim()); // safe because a new axis of length one does not affect memory layout unsafe { - let strides = self.strides.insert_axis(axis); - let dim = self.dim.insert_axis(axis); + let strides = self.layout.strides.insert_axis(axis); + let dim = self.layout.dim.insert_axis(axis); self.with_strides_dim(strides, dim) } } @@ -2565,18 +2701,18 @@ where { self.data._is_pointer_inbounds(self.as_ptr()) } +} +impl ArrayRef +{ /// Perform an elementwise assigment to `self` from `rhs`. /// /// If their shapes disagree, `rhs` is broadcast to the shape of `self`. /// /// **Panics** if broadcasting isn’t possible. #[track_caller] - pub fn assign(&mut self, rhs: &ArrayBase) - where - S: DataMut, - A: Clone, - S2: Data, + pub fn assign(&mut self, rhs: &ArrayRef) + where A: Clone { self.zip_mut_with(rhs, |x, y| x.clone_from(y)); } @@ -2590,7 +2726,6 @@ where #[track_caller] pub fn assign_to

(&self, to: P) where - S: Data, P: IntoNdProducer, P::Item: AssignElem, A: Clone, @@ -2600,17 +2735,13 @@ where /// Perform an elementwise assigment to `self` from element `x`. pub fn fill(&mut self, x: A) - where - S: DataMut, - A: Clone, + where A: Clone { self.map_inplace(move |elt| elt.clone_from(&x)); } - pub(crate) fn zip_mut_with_same_shape(&mut self, rhs: &ArrayBase, mut f: F) + pub(crate) fn zip_mut_with_same_shape(&mut self, rhs: &ArrayRef, mut f: F) where - S: DataMut, - S2: Data, E: Dimension, F: FnMut(&mut A, &B), { @@ -2633,10 +2764,8 @@ where // zip two arrays where they have different layout or strides #[inline(always)] - fn zip_mut_with_by_rows(&mut self, rhs: &ArrayBase, mut f: F) + fn zip_mut_with_by_rows(&mut self, rhs: &ArrayRef, mut f: F) where - S: DataMut, - S2: Data, E: Dimension, F: FnMut(&mut A, &B), { @@ -2652,9 +2781,7 @@ where } fn zip_mut_with_elem(&mut self, rhs_elem: &B, mut f: F) - where - S: DataMut, - F: FnMut(&mut A, &B), + where F: FnMut(&mut A, &B) { self.map_inplace(move |elt| f(elt, rhs_elem)); } @@ -2667,10 +2794,8 @@ where /// **Panics** if broadcasting isn’t possible. #[track_caller] #[inline] - pub fn zip_mut_with(&mut self, rhs: &ArrayBase, f: F) + pub fn zip_mut_with(&mut self, rhs: &ArrayRef, f: F) where - S: DataMut, - S2: Data, E: Dimension, F: FnMut(&mut A, &B), { @@ -2693,13 +2818,12 @@ where where F: FnMut(B, &'a A) -> B, A: 'a, - S: Data, { if let Some(slc) = self.as_slice_memory_order() { slc.iter().fold(init, f) } else { let mut v = self.view(); - move_min_stride_axis_to_last(&mut v.dim, &mut v.strides); + move_min_stride_axis_to_last(&mut v.layout.dim, &mut v.layout.strides); v.into_elements_base().fold(init, f) } } @@ -2726,7 +2850,6 @@ where where F: FnMut(&'a A) -> B, A: 'a, - S: Data, { unsafe { if let Some(slc) = self.as_slice_memory_order() { @@ -2751,7 +2874,6 @@ where where F: FnMut(&'a mut A) -> B, A: 'a, - S: DataMut, { let dim = self.dim.clone(); if self.is_contiguous() { @@ -2784,11 +2906,16 @@ where where F: FnMut(A) -> B, A: Clone, - S: Data, { self.map(move |x| f(x.clone())) } +} +impl ArrayBase +where + S: RawData, + D: Dimension, +{ /// Call `f` by **v**alue on each element, update the array with the new values /// and return it. /// @@ -2816,7 +2943,7 @@ where /// Elements are visited in arbitrary order. /// /// [`mapv_into`]: ArrayBase::mapv_into - /// [`mapv`]: ArrayBase::mapv + /// [`mapv`]: ArrayRef::mapv pub fn mapv_into_any(self, mut f: F) -> Array where S: DataMut, @@ -2844,13 +2971,15 @@ where self.mapv(f) } } +} +impl ArrayRef +{ /// Modify the array in place by calling `f` by mutable reference on each element. /// /// Elements are visited in arbitrary order. pub fn map_inplace<'a, F>(&'a mut self, f: F) where - S: DataMut, A: 'a, F: FnMut(&'a mut A), { @@ -2858,7 +2987,7 @@ where Ok(slc) => slc.iter_mut().for_each(f), Err(arr) => { let mut v = arr.view_mut(); - move_min_stride_axis_to_last(&mut v.dim, &mut v.strides); + move_min_stride_axis_to_last(&mut v.layout.dim, &mut v.layout.strides); v.into_elements_base().for_each(f); } } @@ -2887,7 +3016,6 @@ where /// ``` pub fn mapv_inplace(&mut self, mut f: F) where - S: DataMut, F: FnMut(A) -> A, A: Clone, { @@ -2901,7 +3029,6 @@ where where F: FnMut(&'a A), A: 'a, - S: Data, { self.fold((), move |(), elt| f(elt)) } @@ -2920,7 +3047,6 @@ where D: RemoveAxis, F: FnMut(&B, &A) -> B, B: Clone, - S: Data, { let mut res = Array::from_elem(self.raw_dim().remove_axis(axis), init); for subview in self.axis_iter(axis) { @@ -2943,7 +3069,6 @@ where D: RemoveAxis, F: FnMut(ArrayView1<'a, A>) -> B, A: 'a, - S: Data, { if self.len_of(axis) == 0 { let new_dim = self.dim.remove_axis(axis); @@ -2969,7 +3094,6 @@ where D: RemoveAxis, F: FnMut(ArrayViewMut1<'a, A>) -> B, A: 'a, - S: DataMut, { if self.len_of(axis) == 0 { let new_dim = self.dim.remove_axis(axis); @@ -2978,7 +3102,13 @@ where Zip::from(self.lanes_mut(axis)).map_collect(mapping) } } +} +impl ArrayBase +where + S: DataOwned + DataMut, + D: Dimension, +{ /// Remove the `index`th elements along `axis` and shift down elements from higher indexes. /// /// Note that this "removes" the elements by swapping them around to the end of the axis and @@ -2991,7 +3121,6 @@ where /// ***Panics*** if `axis` is out of bounds
/// ***Panics*** if not `index < self.len_of(axis)`. pub fn remove_index(&mut self, axis: Axis, index: usize) - where S: DataOwned + DataMut { assert!(index < self.len_of(axis), "index {} must be less than length of Axis({})", index, axis.index()); @@ -3001,7 +3130,10 @@ where // then slice the axis in place to cut out the removed final element self.slice_axis_inplace(axis, Slice::new(0, Some(-1), 1)); } +} +impl ArrayRef +{ /// Iterates over pairs of consecutive elements along the axis. /// /// The first argument to the closure is an element, and the second @@ -3031,9 +3163,7 @@ where /// ); /// ``` pub fn accumulate_axis_inplace(&mut self, axis: Axis, mut f: F) - where - F: FnMut(&A, &mut A), - S: DataMut, + where F: FnMut(&A, &mut A) { if self.len_of(axis) <= 1 { return; diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 46ea18a7c..cb5249479 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -72,6 +72,7 @@ where E: Dimension, { type Output = ArrayBase>::Output>; + #[track_caller] fn $mth(self, rhs: ArrayBase) -> Self::Output { @@ -100,8 +101,37 @@ where E: Dimension, { type Output = ArrayBase>::Output>; + #[track_caller] fn $mth(self, rhs: &ArrayBase) -> Self::Output + { + self.$mth(&**rhs) + } +} + +/// Perform elementwise +#[doc=$doc] +/// between `self` and reference `rhs`, +/// and return the result. +/// +/// `rhs` must be an `Array` or `ArcArray`. +/// +/// If their shapes disagree, `self` is broadcast to their broadcast shape, +/// cloning the data if needed. +/// +/// **Panics** if broadcasting isn’t possible. +impl<'a, A, B, S, D, E> $trt<&'a ArrayRef> for ArrayBase +where + A: Clone + $trt, + B: Clone, + S: DataOwned + DataMut, + D: Dimension + DimMax, + E: Dimension, +{ + type Output = ArrayBase>::Output>; + + #[track_caller] + fn $mth(self, rhs: &ArrayRef) -> Self::Output { if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() { let mut out = self.into_dimensionality::<>::Output>().unwrap(); @@ -141,6 +171,36 @@ where E: Dimension + DimMax, { type Output = ArrayBase>::Output>; + + #[track_caller] + fn $mth(self, rhs: ArrayBase) -> Self::Output + where + { + (&**self).$mth(rhs) + } +} + +/// Perform elementwise +#[doc=$doc] +/// between reference `self` and `rhs`, +/// and return the result. +/// +/// `rhs` must be an `Array` or `ArcArray`. +/// +/// If their shapes disagree, `self` is broadcast to their broadcast shape, +/// cloning the data if needed. +/// +/// **Panics** if broadcasting isn’t possible. +impl<'a, A, B, S2, D, E> $trt> for &'a ArrayRef +where + A: Clone + $trt, + B: Clone, + S2: DataOwned + DataMut, + D: Dimension, + E: Dimension + DimMax, +{ + type Output = ArrayBase>::Output>; + #[track_caller] fn $mth(self, rhs: ArrayBase) -> Self::Output where @@ -173,16 +233,41 @@ where /// **Panics** if broadcasting isn’t possible. impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase> for &'a ArrayBase where - A: Clone + $trt, + A: Clone + $trt, B: Clone, S: Data, S2: Data, D: Dimension + DimMax, E: Dimension, { - type Output = Array>::Output>; + type Output = Array<
>::Output, >::Output>; + #[track_caller] fn $mth(self, rhs: &'a ArrayBase) -> Self::Output { + (&**self).$mth(&**rhs) + } +} + +/// Perform elementwise +#[doc=$doc] +/// between references `self` and `rhs`, +/// and return the result as a new `Array`. +/// +/// If their shapes disagree, `self` and `rhs` is broadcast to their broadcast shape, +/// cloning the data if needed. +/// +/// **Panics** if broadcasting isn’t possible. +impl<'a, A, B, D, E> $trt<&'a ArrayRef> for &'a ArrayRef +where + A: Clone + $trt, + B: Clone, + D: Dimension + DimMax, + E: Dimension, +{ + type Output = Array<>::Output, >::Output>; + + #[track_caller] + fn $mth(self, rhs: &'a ArrayRef) -> Self::Output { let (lhs, rhs) = if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() { let lhs = self.view().into_dimensionality::<>::Output>().unwrap(); let rhs = rhs.view().into_dimensionality::<>::Output>().unwrap(); @@ -226,6 +311,23 @@ impl<'a, A, S, D, B> $trt for &'a ArrayBase B: ScalarOperand, { type Output = Array; + + fn $mth(self, x: B) -> Self::Output { + (&**self).$mth(x) + } +} + +/// Perform elementwise +#[doc=$doc] +/// between the reference `self` and the scalar `x`, +/// and return the result as a new `Array`. +impl<'a, A, D, B> $trt for &'a ArrayRef + where A: Clone + $trt, + D: Dimension, + B: ScalarOperand, +{ + type Output = Array; + fn $mth(self, x: B) -> Self::Output { self.map(move |elt| elt.clone() $operator x.clone()) } @@ -277,7 +379,21 @@ impl<'a, S, D> $trt<&'a ArrayBase> for $scalar D: Dimension, { type Output = Array<$scalar, D>; + fn $mth(self, rhs: &ArrayBase) -> Self::Output { + self.$mth(&**rhs) + } +} + +// Perform elementwise +// between the scalar `self` and array `rhs`, +// and return the result as a new `Array`. +impl<'a, D> $trt<&'a ArrayRef<$scalar, D>> for $scalar + where D: Dimension +{ + type Output = Array<$scalar, D>; + + fn $mth(self, rhs: &ArrayRef<$scalar, D>) -> Self::Output { if_commutative!($commutative { rhs.$mth(self) } or { @@ -381,6 +497,7 @@ mod arithmetic_ops D: Dimension, { type Output = Self; + /// Perform an elementwise negation of `self` and return the result. fn neg(mut self) -> Self { @@ -398,6 +515,22 @@ mod arithmetic_ops D: Dimension, { type Output = Array; + + /// Perform an elementwise negation of reference `self` and return the + /// result as a new `Array`. + fn neg(self) -> Array + { + (&**self).neg() + } + } + + impl<'a, A, D> Neg for &'a ArrayRef + where + &'a A: 'a + Neg, + D: Dimension, + { + type Output = Array; + /// Perform an elementwise negation of reference `self` and return the /// result as a new `Array`. fn neg(self) -> Array @@ -413,6 +546,7 @@ mod arithmetic_ops D: Dimension, { type Output = Self; + /// Perform an elementwise unary not of `self` and return the result. fn not(mut self) -> Self { @@ -430,6 +564,22 @@ mod arithmetic_ops D: Dimension, { type Output = Array; + + /// Perform an elementwise unary not of reference `self` and return the + /// result as a new `Array`. + fn not(self) -> Array + { + (&**self).not() + } + } + + impl<'a, A, D> Not for &'a ArrayRef + where + &'a A: 'a + Not, + D: Dimension, + { + type Output = Array; + /// Perform an elementwise unary not of reference `self` and return the /// result as a new `Array`. fn not(self) -> Array @@ -462,6 +612,22 @@ mod assign_ops { #[track_caller] fn $method(&mut self, rhs: &ArrayBase) { + (**self).$method(&**rhs) + } + } + + #[doc=$doc] + /// If their shapes disagree, `rhs` is broadcast to the shape of `self`. + /// + /// **Panics** if broadcasting isn’t possible. + impl<'a, A, D, E> $trt<&'a ArrayRef> for ArrayRef + where + A: Clone + $trt, + D: Dimension, + E: Dimension, + { + #[track_caller] + fn $method(&mut self, rhs: &ArrayRef) { self.zip_mut_with(rhs, |x, y| { x.$method(y.clone()); }); @@ -474,6 +640,17 @@ mod assign_ops A: ScalarOperand + $trt, S: DataMut, D: Dimension, + { + fn $method(&mut self, rhs: A) { + (**self).$method(rhs) + } + } + + #[doc=$doc] + impl $trt for ArrayRef + where + A: ScalarOperand + $trt, + D: Dimension, { fn $method(&mut self, rhs: A) { self.map_inplace(move |elt| { diff --git a/src/impl_owned_array.rs b/src/impl_owned_array.rs index bb970f876..023e9ebb4 100644 --- a/src/impl_owned_array.rs +++ b/src/impl_owned_array.rs @@ -849,7 +849,7 @@ where D: Dimension 0 }; debug_assert!(data_to_array_offset >= 0); - self.ptr = self + self.layout.ptr = self .data .reserve(len_to_append) .offset(data_to_array_offset); @@ -880,7 +880,7 @@ pub(crate) unsafe fn drop_unreachable_raw( } sort_axes_in_default_order(&mut self_); // with uninverted axes this is now the element with lowest address - let array_memory_head_ptr = self_.ptr; + let array_memory_head_ptr = self_.layout.ptr; let data_end_ptr = data_ptr.add(data_len); debug_assert!(data_ptr <= array_memory_head_ptr); debug_assert!(array_memory_head_ptr <= data_end_ptr); @@ -897,19 +897,19 @@ pub(crate) unsafe fn drop_unreachable_raw( // As an optimization, the innermost axis is removed if it has stride 1, because // we then have a long stretch of contiguous elements we can skip as one. let inner_lane_len; - if self_.ndim() > 1 && self_.strides.last_elem() == 1 { - self_.dim.slice_mut().rotate_right(1); - self_.strides.slice_mut().rotate_right(1); - inner_lane_len = self_.dim[0]; - self_.dim[0] = 1; - self_.strides[0] = 1; + if self_.ndim() > 1 && self_.layout.strides.last_elem() == 1 { + self_.layout.dim.slice_mut().rotate_right(1); + self_.layout.strides.slice_mut().rotate_right(1); + inner_lane_len = self_.layout.dim[0]; + self_.layout.dim[0] = 1; + self_.layout.strides[0] = 1; } else { inner_lane_len = 1; } // iter is a raw pointer iterator traversing the array in memory order now with the // sorted axes. - let mut iter = Baseiter::new(self_.ptr, self_.dim, self_.strides); + let mut iter = Baseiter::new(self_.layout.ptr, self_.layout.dim, self_.layout.strides); let mut dropped_elements = 0; let mut last_ptr = data_ptr; @@ -948,7 +948,7 @@ where if a.ndim() <= 1 { return; } - sort_axes1_impl(&mut a.dim, &mut a.strides); + sort_axes1_impl(&mut a.layout.dim, &mut a.layout.strides); } fn sort_axes1_impl(adim: &mut D, astrides: &mut D) @@ -988,7 +988,7 @@ where if a.ndim() <= 1 { return; } - sort_axes2_impl(&mut a.dim, &mut a.strides, &mut b.dim, &mut b.strides); + sort_axes2_impl(&mut a.layout.dim, &mut a.layout.strides, &mut b.layout.dim, &mut b.layout.strides); } fn sort_axes2_impl(adim: &mut D, astrides: &mut D, bdim: &mut D, bstrides: &mut D) diff --git a/src/impl_raw_views.rs b/src/impl_raw_views.rs index 5132b1158..5bb2a0e42 100644 --- a/src/impl_raw_views.rs +++ b/src/impl_raw_views.rs @@ -98,10 +98,10 @@ where D: Dimension pub unsafe fn deref_into_view<'a>(self) -> ArrayView<'a, A, D> { debug_assert!( - is_aligned(self.ptr.as_ptr()), + is_aligned(self.layout.ptr.as_ptr()), "The pointer must be aligned." ); - ArrayView::new(self.ptr, self.dim, self.strides) + ArrayView::new(self.layout.ptr, self.layout.dim, self.layout.strides) } /// Split the array view along `axis` and return one array pointer strictly @@ -113,23 +113,23 @@ where D: Dimension pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) { assert!(index <= self.len_of(axis)); - let left_ptr = self.ptr.as_ptr(); + let left_ptr = self.layout.ptr.as_ptr(); let right_ptr = if index == self.len_of(axis) { - self.ptr.as_ptr() + self.layout.ptr.as_ptr() } else { - let offset = stride_offset(index, self.strides.axis(axis)); + let offset = stride_offset(index, self.layout.strides.axis(axis)); // The `.offset()` is safe due to the guarantees of `RawData`. - unsafe { self.ptr.as_ptr().offset(offset) } + unsafe { self.layout.ptr.as_ptr().offset(offset) } }; - let mut dim_left = self.dim.clone(); + let mut dim_left = self.layout.dim.clone(); dim_left.set_axis(axis, index); - let left = unsafe { Self::new_(left_ptr, dim_left, self.strides.clone()) }; + let left = unsafe { Self::new_(left_ptr, dim_left, self.layout.strides.clone()) }; - let mut dim_right = self.dim; + let mut dim_right = self.layout.dim; let right_len = dim_right.axis(axis) - index; dim_right.set_axis(axis, right_len); - let right = unsafe { Self::new_(right_ptr, dim_right, self.strides) }; + let right = unsafe { Self::new_(right_ptr, dim_right, self.layout.strides) }; (left, right) } @@ -152,8 +152,8 @@ where D: Dimension mem::size_of::(), "size mismatch in raw view cast" ); - let ptr = self.ptr.cast::(); - unsafe { RawArrayView::new(ptr, self.dim, self.strides) } + let ptr = self.layout.ptr.cast::(); + unsafe { RawArrayView::new(ptr, self.layout.dim, self.layout.strides) } } } @@ -172,11 +172,11 @@ where D: Dimension ); assert_eq!(mem::align_of::>(), mem::align_of::()); - let dim = self.dim.clone(); + let dim = self.layout.dim.clone(); // Double the strides. In the zero-sized element case and for axes of // length <= 1, we leave the strides as-is to avoid possible overflow. - let mut strides = self.strides.clone(); + let mut strides = self.layout.strides.clone(); if mem::size_of::() != 0 { for ax in 0..strides.ndim() { if dim[ax] > 1 { @@ -185,7 +185,7 @@ where D: Dimension } } - let ptr_re: *mut T = self.ptr.as_ptr().cast(); + let ptr_re: *mut T = self.layout.ptr.as_ptr().cast(); let ptr_im: *mut T = if self.is_empty() { // In the empty case, we can just reuse the existing pointer since // it won't be dereferenced anyway. It is not safe to offset by @@ -308,7 +308,7 @@ where D: Dimension #[inline] pub(crate) fn into_raw_view(self) -> RawArrayView { - unsafe { RawArrayView::new(self.ptr, self.dim, self.strides) } + unsafe { RawArrayView::new(self.layout.ptr, self.layout.dim, self.layout.strides) } } /// Converts to a read-only view of the array. @@ -323,10 +323,10 @@ where D: Dimension pub unsafe fn deref_into_view<'a>(self) -> ArrayView<'a, A, D> { debug_assert!( - is_aligned(self.ptr.as_ptr()), + is_aligned(self.layout.ptr.as_ptr()), "The pointer must be aligned." ); - ArrayView::new(self.ptr, self.dim, self.strides) + ArrayView::new(self.layout.ptr, self.layout.dim, self.layout.strides) } /// Converts to a mutable view of the array. @@ -341,10 +341,10 @@ where D: Dimension pub unsafe fn deref_into_view_mut<'a>(self) -> ArrayViewMut<'a, A, D> { debug_assert!( - is_aligned(self.ptr.as_ptr()), + is_aligned(self.layout.ptr.as_ptr()), "The pointer must be aligned." ); - ArrayViewMut::new(self.ptr, self.dim, self.strides) + ArrayViewMut::new(self.layout.ptr, self.layout.dim, self.layout.strides) } /// Split the array view along `axis` and return one array pointer strictly @@ -356,7 +356,12 @@ where D: Dimension pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) { let (left, right) = self.into_raw_view().split_at(axis, index); - unsafe { (Self::new(left.ptr, left.dim, left.strides), Self::new(right.ptr, right.dim, right.strides)) } + unsafe { + ( + Self::new(left.layout.ptr, left.layout.dim, left.layout.strides), + Self::new(right.layout.ptr, right.layout.dim, right.layout.strides), + ) + } } /// Cast the raw pointer of the raw array view to a different type @@ -377,8 +382,8 @@ where D: Dimension mem::size_of::(), "size mismatch in raw view cast" ); - let ptr = self.ptr.cast::(); - unsafe { RawArrayViewMut::new(ptr, self.dim, self.strides) } + let ptr = self.layout.ptr.cast::(); + unsafe { RawArrayViewMut::new(ptr, self.layout.dim, self.layout.strides) } } } @@ -392,8 +397,8 @@ where D: Dimension let Complex { re, im } = self.into_raw_view().split_complex(); unsafe { Complex { - re: RawArrayViewMut::new(re.ptr, re.dim, re.strides), - im: RawArrayViewMut::new(im.ptr, im.dim, im.strides), + re: RawArrayViewMut::new(re.layout.ptr, re.layout.dim, re.layout.strides), + im: RawArrayViewMut::new(im.layout.ptr, im.layout.dim, im.layout.strides), } } } diff --git a/src/impl_ref_types.rs b/src/impl_ref_types.rs new file mode 100644 index 000000000..d93a996bf --- /dev/null +++ b/src/impl_ref_types.rs @@ -0,0 +1,370 @@ +//! Implementations that connect arrays to their reference types. +//! +//! `ndarray` has four kinds of array types that users may interact with: +//! 1. [`ArrayBase`], which represents arrays which own their layout (shape and strides) +//! 2. [`ArrayRef`], which represents a read-safe, uniquely-owned look at an array +//! 3. [`RawRef`], which represents a read-unsafe, possibly-shared look at an array +//! 4. [`LayoutRef`], which represents a look at an array's underlying structure, +//! but does not allow data reading of any kind +//! +//! These types are connected through a number of `Deref` and `AsRef` implementations. +//! 1. `ArrayBase` dereferences to `ArrayRef` when `S: Data` +//! 2. `ArrayBase` mutably dereferences to `ArrayRef` when `S: DataMut`, and ensures uniqueness +//! 3. `ArrayRef` mutably dereferences to `RawRef` +//! 4. `RawRef` mutably dereferences to `LayoutRef` +//! This chain works very well for arrays whose data is safe to read and is uniquely held. +//! Because raw views do not meet `S: Data`, they cannot dereference to `ArrayRef`; furthermore, +//! technical limitations of Rust's compiler means that `ArrayBase` cannot have multiple `Deref` implementations. +//! In addition, shared-data arrays do not want to go down the `Deref` path to get to methods on `RawRef` +//! or `LayoutRef`, since that would unecessarily ensure their uniqueness. +//! +//! To mitigate these problems, `ndarray` also provides `AsRef` and `AsMut` implementations as follows: +//! 1. `ArrayBase` implements `AsRef` to `RawRef` and `LayoutRef` when `S: RawData` +//! 2. `ArrayBase` implements `AsMut` to `RawRef` when `S: RawDataMut` +//! 3. `ArrayBase` implements `AsRef` and `AsMut` to `LayoutRef` unconditionally +//! 4. `ArrayRef` implements `AsRef` and `AsMut` to `RawRef` and `LayoutRef` unconditionally +//! 5. `RawRef` implements `AsRef` and `AsMut` to `LayoutRef` +//! 6. `RawRef` and `LayoutRef` implement `AsRef` and `AsMut` to themselves +//! +//! This allows users to write a single method or trait implementation that takes `T: AsRef>` +//! or `T: AsRef>` and have that functionality work on any of the relevant array types. + +use alloc::borrow::ToOwned; +use core::{ + borrow::{Borrow, BorrowMut}, + ops::{Deref, DerefMut}, +}; + +use crate::{Array, ArrayBase, ArrayRef, Data, DataMut, Dimension, LayoutRef, RawData, RawDataMut, RawRef}; + +// D1: &ArrayBase -> &ArrayRef when data is safe to read +impl Deref for ArrayBase +where S: Data +{ + type Target = ArrayRef; + + fn deref(&self) -> &Self::Target + { + // SAFETY: + // - The pointer is aligned because neither type uses repr(align) + // - It is "dereferencable" because it comes from a reference + // - For the same reason, it is initialized + // - The cast is valid because ArrayRef uses #[repr(transparent)] + unsafe { &*(&self.layout as *const LayoutRef).cast::>() } + } +} + +// D2: &mut ArrayBase -> &mut ArrayRef when data is safe to read; ensure uniqueness +impl DerefMut for ArrayBase +where + S: DataMut, + D: Dimension, +{ + fn deref_mut(&mut self) -> &mut Self::Target + { + self.ensure_unique(); + // SAFETY: + // - The pointer is aligned because neither type uses repr(align) + // - It is "dereferencable" because it comes from a reference + // - For the same reason, it is initialized + // - The cast is valid because ArrayRef uses #[repr(transparent)] + unsafe { &mut *(&mut self.layout as *mut LayoutRef).cast::>() } + } +} + +// D3: &ArrayRef -> &RawRef +impl Deref for ArrayRef +{ + type Target = RawRef; + + fn deref(&self) -> &Self::Target + { + // SAFETY: + // - The pointer is aligned because neither type uses repr(align) + // - It is "dereferencable" because it comes from a reference + // - For the same reason, it is initialized + // - The cast is valid because ArrayRef uses #[repr(transparent)] + unsafe { &*(self as *const ArrayRef).cast::>() } + } +} + +// D4: &mut ArrayRef -> &mut RawRef +impl DerefMut for ArrayRef +{ + fn deref_mut(&mut self) -> &mut Self::Target + { + // SAFETY: + // - The pointer is aligned because neither type uses repr(align) + // - It is "dereferencable" because it comes from a reference + // - For the same reason, it is initialized + // - The cast is valid because ArrayRef uses #[repr(transparent)] + unsafe { &mut *(self as *mut ArrayRef).cast::>() } + } +} + +// D5: &RawRef -> &LayoutRef +impl Deref for RawRef +{ + type Target = LayoutRef; + + fn deref(&self) -> &Self::Target + { + &self.0 + } +} + +// D5: &mut RawRef -> &mut LayoutRef +impl DerefMut for RawRef +{ + fn deref_mut(&mut self) -> &mut Self::Target + { + &mut self.0 + } +} + +// A1: &ArrayBase -AR-> &RawRef +impl AsRef> for ArrayBase +where S: RawData +{ + fn as_ref(&self) -> &RawRef + { + // SAFETY: + // - The pointer is aligned because neither type uses repr(align) + // - It is "dereferencable" because it comes from a reference + // - For the same reason, it is initialized + // - The cast is valid because ArrayRef uses #[repr(transparent)] + unsafe { &*(&self.layout as *const LayoutRef).cast::>() } + } +} + +// A2: &mut ArrayBase -AM-> &mut RawRef +impl AsMut> for ArrayBase +where S: RawDataMut +{ + fn as_mut(&mut self) -> &mut RawRef + { + // SAFETY: + // - The pointer is aligned because neither type uses repr(align) + // - It is "dereferencable" because it comes from a reference + // - For the same reason, it is initialized + // - The cast is valid because ArrayRef uses #[repr(transparent)] + unsafe { &mut *(&mut self.layout as *mut LayoutRef).cast::>() } + } +} + +// A3: &ArrayBase -AR-> &LayoutRef +impl AsRef> for ArrayBase +where S: RawData +{ + fn as_ref(&self) -> &LayoutRef + { + &self.layout + } +} + +// A3: &mut ArrayBase -AM-> &mut LayoutRef +impl AsMut> for ArrayBase +where S: RawData +{ + fn as_mut(&mut self) -> &mut LayoutRef + { + &mut self.layout + } +} + +// A4: &ArrayRef -AR-> &RawRef +impl AsRef> for ArrayRef +{ + fn as_ref(&self) -> &RawRef + { + self + } +} + +// A4: &mut ArrayRef -AM-> &mut RawRef +impl AsMut> for ArrayRef +{ + fn as_mut(&mut self) -> &mut RawRef + { + self + } +} + +// A4: &ArrayRef -AR-> &LayoutRef +impl AsRef> for ArrayRef +{ + fn as_ref(&self) -> &LayoutRef + { + self + } +} + +// A4: &mut ArrayRef -AM-> &mut LayoutRef +impl AsMut> for ArrayRef +{ + fn as_mut(&mut self) -> &mut LayoutRef + { + self + } +} + +// A5: &RawRef -AR-> &LayoutRef +impl AsRef> for RawRef +{ + fn as_ref(&self) -> &LayoutRef + { + self + } +} + +// A5: &mut RawRef -AM-> &mut LayoutRef +impl AsMut> for RawRef +{ + fn as_mut(&mut self) -> &mut LayoutRef + { + self + } +} + +// A6: &RawRef -AR-> &RawRef +impl AsRef> for RawRef +{ + fn as_ref(&self) -> &RawRef + { + self + } +} + +// A6: &mut RawRef -AM-> &mut RawRef +impl AsMut> for RawRef +{ + fn as_mut(&mut self) -> &mut RawRef + { + self + } +} + +// A6: &LayoutRef -AR-> &LayoutRef +impl AsRef> for LayoutRef +{ + fn as_ref(&self) -> &LayoutRef + { + self + } +} + +// A6: &mut LayoutRef -AR-> &mut LayoutRef +impl AsMut> for LayoutRef +{ + fn as_mut(&mut self) -> &mut LayoutRef + { + self + } +} + +/// # Safety +/// +/// Usually the pointer would be bad to just clone, as we'd have aliasing +/// and completely separated references to the same data. However, it is +/// impossible to read the data behind the pointer from a LayoutRef (this +/// is a safety invariant that *must* be maintained), and therefore we can +/// Clone and Copy as desired. +impl Clone for LayoutRef +{ + fn clone(&self) -> Self + { + Self { + dim: self.dim.clone(), + strides: self.strides.clone(), + ptr: self.ptr, + } + } +} + +impl Copy for LayoutRef {} + +impl Borrow> for ArrayBase +where S: RawData +{ + fn borrow(&self) -> &RawRef + { + self.as_ref() + } +} + +impl BorrowMut> for ArrayBase +where S: RawDataMut +{ + fn borrow_mut(&mut self) -> &mut RawRef + { + self.as_mut() + } +} + +impl Borrow> for ArrayBase +where S: Data +{ + fn borrow(&self) -> &ArrayRef + { + self + } +} + +impl BorrowMut> for ArrayBase +where + S: DataMut, + D: Dimension, +{ + fn borrow_mut(&mut self) -> &mut ArrayRef + { + self + } +} + +impl ToOwned for ArrayRef +where + A: Clone, + D: Dimension, +{ + type Owned = Array; + + fn to_owned(&self) -> Self::Owned + { + self.to_owned() + } + + fn clone_into(&self, target: &mut Array) + { + target.zip_mut_with(self, |tgt, src| tgt.clone_from(src)); + } +} + +/// Shortcuts for the various as_ref calls +impl ArrayBase +where S: RawData +{ + /// Cheaply convert a reference to the array to an &LayoutRef + pub fn as_layout_ref(&self) -> &LayoutRef + { + self.as_ref() + } + + /// Cheaply and mutably convert a reference to the array to an &LayoutRef + pub fn as_layout_ref_mut(&mut self) -> &mut LayoutRef + { + self.as_mut() + } + + /// Cheaply convert a reference to the array to an &RawRef + pub fn as_raw_ref(&self) -> &RawRef + { + self.as_ref() + } + + /// Cheaply and mutably convert a reference to the array to an &RawRef + pub fn as_raw_ref_mut(&mut self) -> &mut RawRef + where S: RawDataMut + { + self.as_mut() + } +} diff --git a/src/impl_special_element_types.rs b/src/impl_special_element_types.rs index e430b20bc..42b524bc2 100644 --- a/src/impl_special_element_types.rs +++ b/src/impl_special_element_types.rs @@ -9,6 +9,7 @@ use std::mem::MaybeUninit; use crate::imp_prelude::*; +use crate::LayoutRef; use crate::RawDataSubst; /// Methods specific to arrays with `MaybeUninit` elements. @@ -35,9 +36,7 @@ where { let ArrayBase { data, - ptr, - dim, - strides, + layout: LayoutRef { ptr, dim, strides }, } = self; // "transmute" from storage of MaybeUninit to storage of A diff --git a/src/impl_views/constructors.rs b/src/impl_views/constructors.rs index 15f2b9b6b..e20644548 100644 --- a/src/impl_views/constructors.rs +++ b/src/impl_views/constructors.rs @@ -225,7 +225,7 @@ where D: Dimension pub fn reborrow<'b>(self) -> ArrayViewMut<'b, A, D> where 'a: 'b { - unsafe { ArrayViewMut::new(self.ptr, self.dim, self.strides) } + unsafe { ArrayViewMut::new(self.layout.ptr, self.layout.dim, self.layout.strides) } } } diff --git a/src/impl_views/conversions.rs b/src/impl_views/conversions.rs index 1dd7d97f2..efd876f7a 100644 --- a/src/impl_views/conversions.rs +++ b/src/impl_views/conversions.rs @@ -29,13 +29,13 @@ where D: Dimension pub fn reborrow<'b>(self) -> ArrayView<'b, A, D> where 'a: 'b { - unsafe { ArrayView::new(self.ptr, self.dim, self.strides) } + unsafe { ArrayView::new(self.layout.ptr, self.layout.dim, self.layout.strides) } } /// Return the array’s data as a slice, if it is contiguous and in standard order. /// Return `None` otherwise. /// - /// Note that while the method is similar to [`ArrayBase::as_slice()`], this method transfers + /// Note that while the method is similar to [`ArrayRef::as_slice()`], this method transfers /// the view's lifetime to the slice, so it is a bit more powerful. pub fn to_slice(&self) -> Option<&'a [A]> { @@ -50,7 +50,7 @@ where D: Dimension /// Return `None` otherwise. /// /// Note that while the method is similar to - /// [`ArrayBase::as_slice_memory_order()`], this method transfers the view's + /// [`ArrayRef::as_slice_memory_order()`], this method transfers the view's /// lifetime to the slice, so it is a bit more powerful. pub fn to_slice_memory_order(&self) -> Option<&'a [A]> { @@ -66,7 +66,7 @@ where D: Dimension #[inline] pub(crate) fn into_raw_view(self) -> RawArrayView { - unsafe { RawArrayView::new(self.ptr, self.dim, self.strides) } + unsafe { RawArrayView::new(self.layout.ptr, self.layout.dim, self.layout.strides) } } } @@ -199,7 +199,7 @@ where D: Dimension #[inline] pub(crate) fn into_base_iter(self) -> Baseiter { - unsafe { Baseiter::new(self.ptr, self.dim, self.strides) } + unsafe { Baseiter::new(self.layout.ptr, self.layout.dim, self.layout.strides) } } } @@ -209,7 +209,7 @@ where D: Dimension #[inline] pub(crate) fn into_base_iter(self) -> Baseiter { - unsafe { Baseiter::new(self.ptr, self.dim, self.strides) } + unsafe { Baseiter::new(self.layout.ptr, self.layout.dim, self.layout.strides) } } } @@ -220,7 +220,7 @@ where D: Dimension #[inline] pub(crate) fn into_base_iter(self) -> Baseiter { - unsafe { Baseiter::new(self.ptr, self.dim, self.strides) } + unsafe { Baseiter::new(self.layout.ptr, self.layout.dim, self.layout.strides) } } #[inline] @@ -250,19 +250,19 @@ where D: Dimension // Convert into a read-only view pub(crate) fn into_view(self) -> ArrayView<'a, A, D> { - unsafe { ArrayView::new(self.ptr, self.dim, self.strides) } + unsafe { ArrayView::new(self.layout.ptr, self.layout.dim, self.layout.strides) } } /// Converts to a mutable raw array view. pub(crate) fn into_raw_view_mut(self) -> RawArrayViewMut { - unsafe { RawArrayViewMut::new(self.ptr, self.dim, self.strides) } + unsafe { RawArrayViewMut::new(self.layout.ptr, self.layout.dim, self.layout.strides) } } #[inline] pub(crate) fn into_base_iter(self) -> Baseiter { - unsafe { Baseiter::new(self.ptr, self.dim, self.strides) } + unsafe { Baseiter::new(self.layout.ptr, self.layout.dim, self.layout.strides) } } #[inline] diff --git a/src/impl_views/indexing.rs b/src/impl_views/indexing.rs index 2b72c2142..12a47e46a 100644 --- a/src/impl_views/indexing.rs +++ b/src/impl_views/indexing.rs @@ -60,7 +60,7 @@ pub trait IndexLonger /// See also [the `get` method][1] which works for all arrays and array /// views. /// - /// [1]: ArrayBase::get + /// [1]: ArrayRef::get /// /// **Panics** if index is out of bounds. #[track_caller] @@ -68,15 +68,15 @@ pub trait IndexLonger /// Get a reference of a element through the view. /// - /// This method is like `ArrayBase::get` but with a longer lifetime (matching + /// This method is like `ArrayRef::get` but with a longer lifetime (matching /// the array view); which we can only do for the array view and not in the /// `Index` trait. /// /// See also [the `get` method][1] (and [`get_mut`][2]) which works for all arrays and array /// views. /// - /// [1]: ArrayBase::get - /// [2]: ArrayBase::get_mut + /// [1]: ArrayRef::get + /// [2]: ArrayRef::get_mut /// /// **Panics** if index is out of bounds. #[track_caller] @@ -90,7 +90,7 @@ pub trait IndexLonger /// See also [the `uget` method][1] which works for all arrays and array /// views. /// - /// [1]: ArrayBase::uget + /// [1]: ArrayRef::uget /// /// **Note:** only unchecked for non-debug builds of ndarray. /// @@ -116,7 +116,7 @@ where /// See also [the `get` method][1] which works for all arrays and array /// views. /// - /// [1]: ArrayBase::get + /// [1]: ArrayRef::get /// /// **Panics** if index is out of bounds. #[track_caller] @@ -139,7 +139,7 @@ where /// See also [the `uget` method][1] which works for all arrays and array /// views. /// - /// [1]: ArrayBase::uget + /// [1]: ArrayRef::uget /// /// **Note:** only unchecked for non-debug builds of ndarray. unsafe fn uget(self, index: I) -> &'a A @@ -165,7 +165,7 @@ where /// See also [the `get_mut` method][1] which works for all arrays and array /// views. /// - /// [1]: ArrayBase::get_mut + /// [1]: ArrayRef::get_mut /// /// **Panics** if index is out of bounds. #[track_caller] @@ -186,7 +186,7 @@ where /// See also [the `get_mut` method][1] which works for all arrays and array /// views. /// - /// [1]: ArrayBase::get_mut + /// [1]: ArrayRef::get_mut /// fn get(mut self, index: I) -> Option<&'a mut A> { @@ -205,7 +205,7 @@ where /// See also [the `uget_mut` method][1] which works for all arrays and array /// views. /// - /// [1]: ArrayBase::uget_mut + /// [1]: ArrayRef::uget_mut /// /// **Note:** only unchecked for non-debug builds of ndarray. unsafe fn uget(mut self, index: I) -> &'a mut A diff --git a/src/impl_views/splitting.rs b/src/impl_views/splitting.rs index 6d6ea275b..d4ccb9552 100644 --- a/src/impl_views/splitting.rs +++ b/src/impl_views/splitting.rs @@ -157,7 +157,7 @@ where D: Dimension /// [`MultiSliceArg`], [`s!`], [`SliceArg`](crate::SliceArg), and /// [`SliceInfo`](crate::SliceInfo). /// - /// [`.multi_slice_mut()`]: ArrayBase::multi_slice_mut + /// [`.multi_slice_mut()`]: ArrayRef::multi_slice_mut /// /// **Panics** if any of the following occur: /// diff --git a/src/iterators/chunks.rs b/src/iterators/chunks.rs index 9e2f08e1e..4dd99f002 100644 --- a/src/iterators/chunks.rs +++ b/src/iterators/chunks.rs @@ -27,7 +27,7 @@ impl_ndproducer! { /// Exact chunks producer and iterable. /// -/// See [`.exact_chunks()`](ArrayBase::exact_chunks) for more +/// See [`.exact_chunks()`](crate::ArrayRef::exact_chunks) for more /// information. //#[derive(Debug)] pub struct ExactChunks<'a, A, D> @@ -59,10 +59,10 @@ impl<'a, A, D: Dimension> ExactChunks<'a, A, D> a.shape() ); for i in 0..a.ndim() { - a.dim[i] /= chunk[i]; + a.layout.dim[i] /= chunk[i]; } - let inner_strides = a.strides.clone(); - a.strides *= &chunk; + let inner_strides = a.layout.strides.clone(); + a.layout.strides *= &chunk; ExactChunks { base: a, @@ -93,7 +93,7 @@ where /// Exact chunks iterator. /// -/// See [`.exact_chunks()`](ArrayBase::exact_chunks) for more +/// See [`.exact_chunks()`](crate::ArrayRef::exact_chunks) for more /// information. pub struct ExactChunksIter<'a, A, D> { @@ -126,7 +126,7 @@ impl_ndproducer! { /// Exact chunks producer and iterable. /// -/// See [`.exact_chunks_mut()`](ArrayBase::exact_chunks_mut) +/// See [`.exact_chunks_mut()`](crate::ArrayRef::exact_chunks_mut) /// for more information. //#[derive(Debug)] pub struct ExactChunksMut<'a, A, D> @@ -158,10 +158,10 @@ impl<'a, A, D: Dimension> ExactChunksMut<'a, A, D> a.shape() ); for i in 0..a.ndim() { - a.dim[i] /= chunk[i]; + a.layout.dim[i] /= chunk[i]; } - let inner_strides = a.strides.clone(); - a.strides *= &chunk; + let inner_strides = a.layout.strides.clone(); + a.layout.strides *= &chunk; ExactChunksMut { base: a, @@ -237,7 +237,7 @@ impl_iterator! { /// Exact chunks iterator. /// -/// See [`.exact_chunks_mut()`](ArrayBase::exact_chunks_mut) +/// See [`.exact_chunks_mut()`](crate::ArrayRef::exact_chunks_mut) /// for more information. pub struct ExactChunksIterMut<'a, A, D> { diff --git a/src/iterators/into_iter.rs b/src/iterators/into_iter.rs index 9374608cb..b51315a0f 100644 --- a/src/iterators/into_iter.rs +++ b/src/iterators/into_iter.rs @@ -39,9 +39,9 @@ where D: Dimension let array_head_ptr = array.ptr; let mut array_data = array.data; let data_len = array_data.release_all_elements(); - debug_assert!(data_len >= array.dim.size()); - let has_unreachable_elements = array.dim.size() != data_len; - let inner = Baseiter::new(array_head_ptr, array.dim, array.strides); + debug_assert!(data_len >= array.layout.dim.size()); + let has_unreachable_elements = array.layout.dim.size() != data_len; + let inner = Baseiter::new(array_head_ptr, array.layout.dim, array.layout.strides); IntoIter { array_data, diff --git a/src/iterators/lanes.rs b/src/iterators/lanes.rs index 11c83d002..0f9678872 100644 --- a/src/iterators/lanes.rs +++ b/src/iterators/lanes.rs @@ -23,7 +23,7 @@ impl_ndproducer! { } } -/// See [`.lanes()`](ArrayBase::lanes) +/// See [`.lanes()`](crate::ArrayRef::lanes) /// for more information. pub struct Lanes<'a, A, D> { @@ -92,7 +92,7 @@ where D: Dimension } } -/// See [`.lanes_mut()`](ArrayBase::lanes_mut) +/// See [`.lanes_mut()`](crate::ArrayRef::lanes_mut) /// for more information. pub struct LanesMut<'a, A, D> { diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index e7321d15b..b779d2be5 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -336,7 +336,7 @@ pub enum ElementsRepr /// /// Iterator element type is `&'a A`. /// -/// See [`.iter()`](ArrayBase::iter) for more information. +/// See [`.iter()`](crate::ArrayRef::iter) for more information. #[derive(Debug)] pub struct Iter<'a, A, D> { @@ -355,7 +355,7 @@ pub struct ElementsBase<'a, A, D> /// /// Iterator element type is `&'a mut A`. /// -/// See [`.iter_mut()`](ArrayBase::iter_mut) for more information. +/// See [`.iter_mut()`](crate::ArrayRef::iter_mut) for more information. #[derive(Debug)] pub struct IterMut<'a, A, D> { @@ -385,12 +385,12 @@ impl<'a, A, D: Dimension> ElementsBaseMut<'a, A, D> /// An iterator over the indexes and elements of an array. /// -/// See [`.indexed_iter()`](ArrayBase::indexed_iter) for more information. +/// See [`.indexed_iter()`](crate::ArrayRef::indexed_iter) for more information. #[derive(Clone)] pub struct IndexedIter<'a, A, D>(ElementsBase<'a, A, D>); /// An iterator over the indexes and elements of an array (mutable). /// -/// See [`.indexed_iter_mut()`](ArrayBase::indexed_iter_mut) for more information. +/// See [`.indexed_iter_mut()`](crate::ArrayRef::indexed_iter_mut) for more information. pub struct IndexedIterMut<'a, A, D>(ElementsBaseMut<'a, A, D>); impl<'a, A, D> IndexedIter<'a, A, D> @@ -729,7 +729,7 @@ where D: Dimension /// An iterator that traverses over all axes but one, and yields a view for /// each lane along that axis. /// -/// See [`.lanes()`](ArrayBase::lanes) for more information. +/// See [`.lanes()`](crate::ArrayRef::lanes) for more information. pub struct LanesIter<'a, A, D> { inner_len: Ix, @@ -792,7 +792,7 @@ impl<'a, A> DoubleEndedIterator for LanesIter<'a, A, Ix1> /// An iterator that traverses over all dimensions but the innermost, /// and yields each inner row (mutable). /// -/// See [`.lanes_mut()`](ArrayBase::lanes_mut) +/// See [`.lanes_mut()`](crate::ArrayRef::lanes_mut) /// for more information. pub struct LanesIterMut<'a, A, D> { @@ -1007,8 +1007,8 @@ where D: Dimension /// /// Iterator element type is `ArrayView<'a, A, D>`. /// -/// See [`.outer_iter()`](ArrayBase::outer_iter) -/// or [`.axis_iter()`](ArrayBase::axis_iter) +/// See [`.outer_iter()`](crate::ArrayRef::outer_iter) +/// or [`.axis_iter()`](crate::ArrayRef::axis_iter) /// for more information. #[derive(Debug)] pub struct AxisIter<'a, A, D> @@ -1108,8 +1108,8 @@ where D: Dimension /// /// Iterator element type is `ArrayViewMut<'a, A, D>`. /// -/// See [`.outer_iter_mut()`](ArrayBase::outer_iter_mut) -/// or [`.axis_iter_mut()`](ArrayBase::axis_iter_mut) +/// See [`.outer_iter_mut()`](crate::ArrayRef::outer_iter_mut) +/// or [`.axis_iter_mut()`](crate::ArrayRef::axis_iter_mut) /// for more information. pub struct AxisIterMut<'a, A, D> { @@ -1314,7 +1314,7 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> /// /// Iterator element type is `ArrayView<'a, A, D>`. /// -/// See [`.axis_chunks_iter()`](ArrayBase::axis_chunks_iter) for more information. +/// See [`.axis_chunks_iter()`](crate::ArrayRef::axis_chunks_iter) for more information. pub struct AxisChunksIter<'a, A, D> { iter: AxisIterCore, @@ -1372,7 +1372,7 @@ fn chunk_iter_parts(v: ArrayView<'_, A, D>, axis: Axis, size: u let mut inner_dim = v.dim.clone(); inner_dim[axis] = size; - let mut partial_chunk_dim = v.dim; + let mut partial_chunk_dim = v.layout.dim; partial_chunk_dim[axis] = chunk_remainder; let partial_chunk_index = n_whole_chunks; @@ -1381,8 +1381,8 @@ fn chunk_iter_parts(v: ArrayView<'_, A, D>, axis: Axis, size: u end: iter_len, stride, inner_dim, - inner_strides: v.strides, - ptr: v.ptr.as_ptr(), + inner_strides: v.layout.strides, + ptr: v.layout.ptr.as_ptr(), }; (iter, partial_chunk_index, partial_chunk_dim) @@ -1496,7 +1496,7 @@ macro_rules! chunk_iter_impl { /// /// Iterator element type is `ArrayViewMut<'a, A, D>`. /// -/// See [`.axis_chunks_iter_mut()`](ArrayBase::axis_chunks_iter_mut) +/// See [`.axis_chunks_iter_mut()`](crate::ArrayRef::axis_chunks_iter_mut) /// for more information. pub struct AxisChunksIterMut<'a, A, D> { diff --git a/src/iterators/windows.rs b/src/iterators/windows.rs index 1c2ab6a85..afdaaa895 100644 --- a/src/iterators/windows.rs +++ b/src/iterators/windows.rs @@ -9,7 +9,7 @@ use crate::Slice; /// Window producer and iterable /// -/// See [`.windows()`](ArrayBase::windows) for more +/// See [`.windows()`](crate::ArrayRef::windows) for more /// information. pub struct Windows<'a, A, D> { @@ -91,7 +91,7 @@ where /// Window iterator. /// -/// See [`.windows()`](ArrayBase::windows) for more +/// See [`.windows()`](crate::ArrayRef::windows) for more /// information. pub struct WindowsIter<'a, A, D> { @@ -129,7 +129,7 @@ send_sync_read_only!(WindowsIter); /// Window producer and iterable /// -/// See [`.axis_windows()`](ArrayBase::axis_windows) for more +/// See [`.axis_windows()`](crate::ArrayRef::axis_windows) for more /// information. pub struct AxisWindows<'a, A, D> { diff --git a/src/layout/mod.rs b/src/layout/mod.rs index 026688d63..36853848e 100644 --- a/src/layout/mod.rs +++ b/src/layout/mod.rs @@ -1,6 +1,6 @@ mod layoutfmt; -// Layout it a bitset used for internal layout description of +// Layout is a bitset used for internal layout description of // arrays, producers and sets of producers. // The type is public but users don't interact with it. #[doc(hidden)] diff --git a/src/lib.rs b/src/lib.rs index f52f25e5e..599ca59a6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,12 +29,17 @@ //! dimensions, then an element in the array is accessed by using that many indices. //! Each dimension is also called an *axis*. //! +//! To get started, functionality is provided in the following core types: //! - **[`ArrayBase`]**: //! The *n*-dimensional array type itself.
//! It is used to implement both the owned arrays and the views; see its docs //! for an overview of all array features.
//! - The main specific array type is **[`Array`]**, which owns //! its elements. +//! - A reference type, **[`ArrayRef`]**, that contains most of the functionality +//! for reading and writing to arrays. +//! - A reference type, **[`LayoutRef`]**, that contains most of the functionality +//! for reading and writing to array layouts: their shape and strides. //! //! ## Highlights //! @@ -60,8 +65,8 @@ //! - Performance: //! + Prefer higher order methods and arithmetic operations on arrays first, //! then iteration, and as a last priority using indexed algorithms. -//! + The higher order functions like [`.map()`](ArrayBase::map), -//! [`.map_inplace()`](ArrayBase::map_inplace), [`.zip_mut_with()`](ArrayBase::zip_mut_with), +//! + The higher order functions like [`.map()`](ArrayRef::map), +//! [`.map_inplace()`](ArrayRef::map_inplace), [`.zip_mut_with()`](ArrayRef::zip_mut_with), //! [`Zip`] and [`azip!()`](azip) are the most efficient ways //! to perform single traversal and lock step traversal respectively. //! + Performance of an operation depends on the memory layout of the array @@ -160,6 +165,7 @@ pub use crate::shape_builder::{Shape, ShapeArg, ShapeBuilder, StrideShape}; mod macro_utils; #[macro_use] mod private; +mod impl_ref_types; mod aliases; #[macro_use] mod itertools; @@ -307,7 +313,7 @@ pub type Ixs = isize; /// data (shared ownership). /// Sharing requires that it uses copy-on-write for mutable operations. /// Calling a method for mutating elements on `ArcArray`, for example -/// [`view_mut()`](Self::view_mut) or [`get_mut()`](Self::get_mut), +/// [`view_mut()`](ArrayRef::view_mut) or [`get_mut()`](ArrayRef::get_mut), /// will break sharing and require a clone of the data (if it is not uniquely held). /// /// ## `CowArray` @@ -331,9 +337,9 @@ pub type Ixs = isize; /// Please see the documentation for the respective array view for an overview /// of methods specific to array views: [`ArrayView`], [`ArrayViewMut`]. /// -/// A view is created from an array using [`.view()`](ArrayBase::view), -/// [`.view_mut()`](ArrayBase::view_mut), using -/// slicing ([`.slice()`](ArrayBase::slice), [`.slice_mut()`](ArrayBase::slice_mut)) or from one of +/// A view is created from an array using [`.view()`](ArrayRef::view), +/// [`.view_mut()`](ArrayRef::view_mut), using +/// slicing ([`.slice()`](ArrayRef::slice), [`.slice_mut()`](ArrayRef::slice_mut)) or from one of /// the many iterators that yield array views. /// /// You can also create an array view from a regular slice of data not @@ -475,12 +481,12 @@ pub type Ixs = isize; /// [`.columns()`][gc], [`.columns_mut()`][gcm], /// [`.lanes(axis)`][l], [`.lanes_mut(axis)`][lm]. /// -/// [gr]: Self::rows -/// [grm]: Self::rows_mut -/// [gc]: Self::columns -/// [gcm]: Self::columns_mut -/// [l]: Self::lanes -/// [lm]: Self::lanes_mut +/// [gr]: ArrayRef::rows +/// [grm]: ArrayRef::rows_mut +/// [gc]: ArrayRef::columns +/// [gcm]: ArrayRef::columns_mut +/// [l]: ArrayRef::lanes +/// [lm]: ArrayRef::lanes_mut /// /// Yes, for 2D arrays `.rows()` and `.outer_iter()` have about the same /// effect: @@ -506,10 +512,10 @@ pub type Ixs = isize; /// [`.slice_collapse()`] panics on `NewAxis` elements and behaves like /// [`.collapse_axis()`] by preserving the number of dimensions. /// -/// [`.slice()`]: Self::slice -/// [`.slice_mut()`]: Self::slice_mut +/// [`.slice()`]: ArrayRef::slice +/// [`.slice_mut()`]: ArrayRef::slice_mut /// [`.slice_move()`]: Self::slice_move -/// [`.slice_collapse()`]: Self::slice_collapse +/// [`.slice_collapse()`]: LayoutRef::slice_collapse /// /// When slicing arrays with generic dimensionality, creating an instance of /// [`SliceInfo`] to pass to the multi-axis slicing methods like [`.slice()`] @@ -518,17 +524,17 @@ pub type Ixs = isize; /// or to create a view and then slice individual axes of the view using /// methods such as [`.slice_axis_inplace()`] and [`.collapse_axis()`]. /// -/// [`.slice_each_axis()`]: Self::slice_each_axis -/// [`.slice_each_axis_mut()`]: Self::slice_each_axis_mut +/// [`.slice_each_axis()`]: ArrayRef::slice_each_axis +/// [`.slice_each_axis_mut()`]: ArrayRef::slice_each_axis_mut /// [`.slice_each_axis_inplace()`]: Self::slice_each_axis_inplace /// [`.slice_axis_inplace()`]: Self::slice_axis_inplace -/// [`.collapse_axis()`]: Self::collapse_axis +/// [`.collapse_axis()`]: LayoutRef::collapse_axis /// /// It's possible to take multiple simultaneous *mutable* slices with /// [`.multi_slice_mut()`] or (for [`ArrayViewMut`] only) /// [`.multi_slice_move()`]. /// -/// [`.multi_slice_mut()`]: Self::multi_slice_mut +/// [`.multi_slice_mut()`]: ArrayRef::multi_slice_mut /// [`.multi_slice_move()`]: ArrayViewMut#method.multi_slice_move /// /// ``` @@ -627,16 +633,16 @@ pub type Ixs = isize; /// Methods for selecting an individual subview take two arguments: `axis` and /// `index`. /// -/// [`.axis_iter()`]: Self::axis_iter -/// [`.axis_iter_mut()`]: Self::axis_iter_mut -/// [`.fold_axis()`]: Self::fold_axis -/// [`.index_axis()`]: Self::index_axis -/// [`.index_axis_inplace()`]: Self::index_axis_inplace -/// [`.index_axis_mut()`]: Self::index_axis_mut +/// [`.axis_iter()`]: ArrayRef::axis_iter +/// [`.axis_iter_mut()`]: ArrayRef::axis_iter_mut +/// [`.fold_axis()`]: ArrayRef::fold_axis +/// [`.index_axis()`]: ArrayRef::index_axis +/// [`.index_axis_inplace()`]: LayoutRef::index_axis_inplace +/// [`.index_axis_mut()`]: ArrayRef::index_axis_mut /// [`.index_axis_move()`]: Self::index_axis_move -/// [`.collapse_axis()`]: Self::collapse_axis -/// [`.outer_iter()`]: Self::outer_iter -/// [`.outer_iter_mut()`]: Self::outer_iter_mut +/// [`.collapse_axis()`]: LayoutRef::collapse_axis +/// [`.outer_iter()`]: ArrayRef::outer_iter +/// [`.outer_iter_mut()`]: ArrayRef::outer_iter_mut /// /// ``` /// @@ -742,7 +748,7 @@ pub type Ixs = isize; /// Arrays support limited *broadcasting*, where arithmetic operations with /// array operands of different sizes can be carried out by repeating the /// elements of the smaller dimension array. See -/// [`.broadcast()`](Self::broadcast) for a more detailed +/// [`.broadcast()`](ArrayRef::broadcast) for a more detailed /// description. /// /// ``` @@ -1043,9 +1049,9 @@ pub type Ixs = isize; /// `&[A]` | `ArrayView` | [`::from_shape()`](ArrayView#method.from_shape) /// `&mut [A]` | `ArrayViewMut1
` | [`::from()`](ArrayViewMut#method.from) /// `&mut [A]` | `ArrayViewMut` | [`::from_shape()`](ArrayViewMut#method.from_shape) -/// `&ArrayBase` | `Vec` | [`.to_vec()`](Self::to_vec) +/// `&ArrayBase` | `Vec` | [`.to_vec()`](ArrayRef::to_vec) /// `Array` | `Vec` | [`.into_raw_vec()`](Array#method.into_raw_vec)[1](#into_raw_vec) -/// `&ArrayBase` | `&[A]` | [`.as_slice()`](Self::as_slice)[2](#req_contig_std), [`.as_slice_memory_order()`](Self::as_slice_memory_order)[3](#req_contig) +/// `&ArrayBase` | `&[A]` | [`.as_slice()`](ArrayRef::as_slice)[2](#req_contig_std), [`.as_slice_memory_order()`](ArrayRef::as_slice_memory_order)[3](#req_contig) /// `&mut ArrayBase` | `&mut [A]` | [`.as_slice_mut()`](Self::as_slice_mut)[2](#req_contig_std), [`.as_slice_memory_order_mut()`](Self::as_slice_memory_order_mut)[3](#req_contig) /// `ArrayView` | `&[A]` | [`.to_slice()`](ArrayView#method.to_slice)[2](#req_contig_std) /// `ArrayViewMut` | `&mut [A]` | [`.into_slice()`](ArrayViewMut#method.into_slice)[2](#req_contig_std) @@ -1069,9 +1075,9 @@ pub type Ixs = isize; /// [.into_owned()]: Self::into_owned /// [.into_shared()]: Self::into_shared /// [.to_owned()]: Self::to_owned -/// [.map()]: Self::map -/// [.view()]: Self::view -/// [.view_mut()]: Self::view_mut +/// [.map()]: ArrayRef::map +/// [.view()]: ArrayRef::view +/// [.view_mut()]: ArrayRef::view_mut /// /// ### Conversions from Nested `Vec`s/`Array`s /// @@ -1272,6 +1278,9 @@ pub type Ixs = isize; // implementation since `ArrayBase` doesn't implement `Drop` and `&mut // ArrayBase` is `!UnwindSafe`, but the implementation must not call // methods/functions on the array while it violates the constraints. +// Critically, this includes calling `DerefMut`; as a result, methods/functions +// that temporarily violate these must not rely on the `DerefMut` implementation +// for access to the underlying `ptr`, `strides`, or `dim`. // // Users of the `ndarray` crate cannot rely on these constraints because they // may change in the future. @@ -1283,15 +1292,206 @@ where S: RawData /// Data buffer / ownership information. (If owned, contains the data /// buffer; if borrowed, contains the lifetime and mutability.) data: S, + /// The dimension, strides, and pointer to inside of `data` + layout: LayoutRef, +} + +/// A reference to the layout of an *n*-dimensional array. +/// +/// This type can be used to read and write to the layout of an array; +/// that is to say, its shape and strides. It does not provide any read +/// or write access to the array's underlying data. It is generic on two +/// types: `D`, its dimensionality, and `A`, the element type of its data. +/// +/// ## Example +/// Say we wanted to write a function that provides the aspect ratio +/// of any 2D array: the ratio of its width (number of columns) to its +/// height (number of rows). We would write that as follows: +/// ```rust +/// use ndarray::{LayoutRef2, array}; +/// +/// fn aspect_ratio(layout: &T) -> (usize, usize) +/// where T: AsRef> +/// { +/// let layout = layout.as_ref(); +/// (layout.ncols(), layout.nrows()) +/// } +/// +/// let arr = array![[1, 2], [3, 4]]; +/// assert_eq!(aspect_ratio(&arr), (2, 2)); +/// ``` +/// Similarly, new traits that provide functions that only depend on +/// or alter the layout of an array should do so via a blanket +/// implementation. Lets write a trait that both provides the aspect ratio +/// and lets users cut down arrays to a desired aspect ratio. +/// For simplicity, we'll panic if the user provides an aspect ratio +/// where either element is larger than the array's size. +/// ```rust +/// use ndarray::{LayoutRef2, array, s}; +/// +/// trait Ratioable { +/// fn aspect_ratio(&self) -> (usize, usize) +/// where Self: AsRef>; +/// +/// fn cut_to_ratio(&mut self, ratio: (usize, usize)) +/// where Self: AsMut>; +/// } +/// +/// impl Ratioable for T +/// where T: AsRef> + AsMut> +/// { +/// fn aspect_ratio(&self) -> (usize, usize) +/// { +/// let layout = self.as_ref(); +/// (layout.ncols(), layout.nrows()) +/// } +/// +/// fn cut_to_ratio(&mut self, ratio: (usize, usize)) +/// { +/// let layout = self.as_mut(); +/// layout.slice_collapse(s![..ratio.1, ..ratio.0]); +/// } +/// } +/// +/// let mut arr = array![[1, 2, 3], [4, 5, 6]]; +/// assert_eq!(arr.aspect_ratio(), (3, 2)); +/// arr.cut_to_ratio((2, 2)); +/// assert_eq!(arr, array![[1, 2], [4, 5]]); +/// ``` +/// Continue reading for why we use `AsRef` instead of taking `&LayoutRef` directly. +/// +/// ## Writing Functions +/// Writing functions that accept `LayoutRef` is not as simple as taking +/// a `&LayoutRef` argument, as the above examples show. This is because +/// `LayoutRef` can be obtained either cheaply or expensively, depending +/// on the method used. `LayoutRef` can be obtained from all kinds of arrays +/// -- [owned](Array), [shared](ArcArray), [viewed](ArrayView), [referenced](ArrayRef), +/// and [raw referenced](RawRef) -- via `.as_ref()`. Critically, this way of +/// obtaining a `LayoutRef` is cheap, as it does not guarantee that the +/// underlying data is uniquely held. +/// +/// However, `LayoutRef`s can be obtained a second way: they sit at the bottom +/// of a "deref chain" going from shared arrays, through `ArrayRef`, through +/// `RawRef`, and finally to `LayoutRef`. As a result, `LayoutRef`s can also +/// be obtained via auto-dereferencing. When requesting a mutable reference -- +/// `&mut LayoutRef` -- the `deref_mut` to `ArrayRef` triggers a (possibly +/// expensive) guarantee that the data is uniquely held (see [`ArrayRef`] +/// for more information). +/// +/// To help users avoid this error cost, functions that operate on `LayoutRef`s +/// should take their parameters as a generic type `T: AsRef>`, +/// as the above examples show. This aids the caller in two ways: they can pass +/// their arrays by reference (`&arr`) instead of explicitly calling `as_ref`, +/// and they will avoid paying a performance penalty for mutating the shape. +// +// # Safety for Implementors +// +// Despite carrying around a `ptr`, maintainers of `LayoutRef` +// must *guarantee* that the pointer is *never* dereferenced. +// No read access can be used when handling a `LayoutRef`, and +// the `ptr` can *never* be exposed to the user. +// +// The reason the pointer is included here is because some methods +// which alter the layout / shape / strides of an array must also +// alter the offset of the pointer. This is allowed, as it does not +// cause a pointer deref. +#[derive(Debug)] +pub struct LayoutRef +{ /// A non-null pointer into the buffer held by `data`; may point anywhere /// in its range. If `S: Data`, this pointer must be aligned. - ptr: std::ptr::NonNull, + ptr: std::ptr::NonNull, /// The lengths of the axes. dim: D, /// The element count stride per axis. To be parsed as `isize`. strides: D, } +/// A reference to an *n*-dimensional array whose data is safe to read and write. +/// +/// This type's relationship to [`ArrayBase`] can be thought of a bit like the +/// relationship between [`Vec`] and [`std::slice`]: it represents a look into the +/// array, and is the [`Deref`](std::ops::Deref) target for owned, shared, and viewed +/// arrays. Most functionality is implemented on `ArrayRef`, and most functions +/// should take `&ArrayRef` instead of `&ArrayBase`. +/// +/// ## Relationship to Views +/// `ArrayRef` and [`ArrayView`] are very similar types: they both represent a +/// "look" into an array. There is one key difference: views have their own +/// shape and strides, while `ArrayRef` just points to the shape and strides of +/// whatever array it came from. +/// +/// As an example, let's write a function that takes an array, trims it +/// down to a square in-place, and then returns the sum: +/// ```rust +/// use std::cmp; +/// use std::ops::Add; +/// +/// use ndarray::{ArrayRef2, array, s}; +/// use num_traits::Zero; +/// +/// fn square_and_sum(arr: &mut ArrayRef2) -> A +/// where A: Clone + Add + Zero +/// { +/// let side_len = cmp::min(arr.nrows(), arr.ncols()); +/// arr.slice_collapse(s![..side_len, ..side_len]); +/// arr.sum() +/// } +/// +/// let mut arr = array![ +/// [ 1, 2, 3], +/// [ 4, 5, 6], +/// [ 7, 8, 9], +/// [10, 11, 12] +/// ]; +/// // Take a view of the array, excluding the first column +/// let mut view = arr.slice_mut(s![.., 1..]); +/// let sum_view = square_and_sum(&mut view); +/// assert_eq!(sum_view, 16); +/// assert_eq!(view.ncols(), 2usize); // The view has changed shape... +/// assert_eq!(view.nrows(), 2usize); +/// assert_eq!(arr.ncols(), 3usize); // ... but the original array has not +/// assert_eq!(arr.nrows(), 4usize); +/// +/// let sum_all = square_and_sum(&mut arr); +/// assert_eq!(sum_all, 45); +/// assert_eq!(arr.ncols(), 3usize); // Now the original array has changed shape +/// assert_eq!(arr.nrows(), 3usize); // because we passed it directly to the function +/// ``` +/// Critically, we can call the same function on both the view and the array itself. +/// We can see that, because the view has its own shape and strides, "squaring" it does +/// not affect the shape of the original array. Those only change when we pass the array +/// itself into the function. +/// +/// Also notice that the output of `slice_mut` is a *view*, not an `ArrayRef`. +/// This is where the analogy to `Vec`/`slice` breaks down a bit: due to limitations of +/// the Rust language, `ArrayRef` *cannot* have a different shape / stride from the +/// array from which it is dereferenced. So slicing still produces an `ArrayView`, +/// not an `ArrayRef`. +/// +/// ## Uniqueness +/// `ndarray` has copy-on-write shared data; see [`ArcArray`], for example. +/// When a copy-on-write array is passed to a function that takes `ArrayRef` as mutable +/// (i.e., `&mut ArrayRef`, like above), that array will be un-shared when it is dereferenced +/// into `ArrayRef`. In other words, having a `&mut ArrayRef` guarantees that the underlying +/// data is un-shared and safe to write to. +#[repr(transparent)] +pub struct ArrayRef(LayoutRef); + +/// A reference to an *n*-dimensional array whose data is not safe to read or write. +/// +/// This type is similar to [`ArrayRef`] but does not guarantee that its data is safe +/// to read or write; i.e., the underlying data may come from a shared array or be otherwise +/// unsafe to dereference. This type should be used sparingly and with extreme caution; +/// most of its methods either provide pointers or return [`RawArrayView`], both of +/// which tend to be full of unsafety. +/// +/// For the few times when this type is appropriate, it has the same `AsRef` semantics +/// as [`LayoutRef`]; see [its documentation on writing functions](LayoutRef#writing-functions) +/// for information on how to properly handle functionality on this type. +#[repr(transparent)] +pub struct RawRef(LayoutRef); + /// An array where the data has shared ownership and is copy on write. /// /// The `ArcArray` is parameterized by `A` for the element type and `D` for @@ -1300,8 +1500,8 @@ where S: RawData /// It can act as both an owner as the data as well as a shared reference (view /// like). /// Calling a method for mutating elements on `ArcArray`, for example -/// [`view_mut()`](ArrayBase::view_mut) or -/// [`get_mut()`](ArrayBase::get_mut), will break sharing and +/// [`view_mut()`](ArrayRef::view_mut) or +/// [`get_mut()`](ArrayRef::get_mut), will break sharing and /// require a clone of the data (if it is not uniquely held). /// /// `ArcArray` uses atomic reference counting like `Arc`, so it is `Send` and @@ -1528,14 +1728,12 @@ mod impl_internal_constructors; mod impl_constructors; mod impl_methods; +mod alias_asref; mod impl_owned_array; mod impl_special_element_types; /// Private Methods -impl ArrayBase -where - S: Data, - D: Dimension, +impl ArrayRef { #[inline] fn broadcast_unwrap(&self, dim: E) -> ArrayView<'_, A, E> @@ -1548,11 +1746,7 @@ where D: Dimension, E: Dimension, { - panic!( - "ndarray: could not broadcast array from shape: {:?} to: {:?}", - from.slice(), - to.slice() - ) + panic!("ndarray: could not broadcast array from shape: {:?} to: {:?}", from.slice(), to.slice()) } match self.broadcast(dim.clone()) { @@ -1574,12 +1768,18 @@ where strides.slice_mut().copy_from_slice(self.strides.slice()); unsafe { ArrayView::new(ptr, dim, strides) } } +} +impl ArrayBase +where + S: Data, + D: Dimension, +{ /// Remove array axis `axis` and return the result. fn try_remove_axis(self, axis: Axis) -> ArrayBase { - let d = self.dim.try_remove_axis(axis); - let s = self.strides.try_remove_axis(axis); + let d = self.layout.dim.try_remove_axis(axis); + let s = self.layout.strides.try_remove_axis(axis); // safe because new dimension, strides allow access to a subset of old data unsafe { self.with_strides_dim(s, d) } } diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 7472d8292..df7e26abc 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -11,6 +11,8 @@ use crate::imp_prelude::*; #[cfg(feature = "blas")] use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr; use crate::numeric_util; +use crate::ArrayRef1; +use crate::ArrayRef2; use crate::{LinalgScalar, Zip}; @@ -40,8 +42,7 @@ const GEMM_BLAS_CUTOFF: usize = 7; #[allow(non_camel_case_types)] type blas_index = c_int; // blas index type -impl ArrayBase -where S: Data +impl ArrayRef { /// Perform dot product or matrix multiplication of arrays `self` and `rhs`. /// @@ -67,10 +68,8 @@ where S: Data Dot::dot(self, rhs) } - fn dot_generic(&self, rhs: &ArrayBase) -> A - where - S2: Data, - A: LinalgScalar, + fn dot_generic(&self, rhs: &ArrayRef) -> A + where A: LinalgScalar { debug_assert_eq!(self.len(), rhs.len()); assert!(self.len() == rhs.len()); @@ -89,19 +88,15 @@ where S: Data } #[cfg(not(feature = "blas"))] - fn dot_impl(&self, rhs: &ArrayBase) -> A - where - S2: Data, - A: LinalgScalar, + fn dot_impl(&self, rhs: &ArrayRef) -> A + where A: LinalgScalar { self.dot_generic(rhs) } #[cfg(feature = "blas")] - fn dot_impl(&self, rhs: &ArrayBase) -> A - where - S2: Data, - A: LinalgScalar, + fn dot_impl(&self, rhs: &ArrayRef) -> A + where A: LinalgScalar { // Use only if the vector is large enough to be worth it if self.len() >= DOT_BLAS_CUTOFF { @@ -165,14 +160,64 @@ pub trait Dot /// /// For two-dimensional arrays: a rectangular array. type Output; + fn dot(&self, rhs: &Rhs) -> Self::Output; } -impl Dot> for ArrayBase -where - S: Data, - S2: Data, - A: LinalgScalar, +macro_rules! impl_dots { + ( + $shape1:ty, + $shape2:ty + ) => { + impl Dot> for ArrayBase + where + S: Data, + S2: Data, + A: LinalgScalar, + { + type Output = as Dot>>::Output; + + fn dot(&self, rhs: &ArrayBase) -> Self::Output + { + Dot::dot(&**self, &**rhs) + } + } + + impl Dot> for ArrayBase + where + S: Data, + A: LinalgScalar, + { + type Output = as Dot>>::Output; + + fn dot(&self, rhs: &ArrayRef) -> Self::Output + { + (**self).dot(rhs) + } + } + + impl Dot> for ArrayRef + where + S: Data, + A: LinalgScalar, + { + type Output = as Dot>>::Output; + + fn dot(&self, rhs: &ArrayBase) -> Self::Output + { + self.dot(&**rhs) + } + } + }; +} + +impl_dots!(Ix1, Ix1); +impl_dots!(Ix1, Ix2); +impl_dots!(Ix2, Ix1); +impl_dots!(Ix2, Ix2); + +impl Dot> for ArrayRef +where A: LinalgScalar { type Output = A; @@ -185,17 +230,14 @@ where /// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory /// layout allows. #[track_caller] - fn dot(&self, rhs: &ArrayBase) -> A + fn dot(&self, rhs: &ArrayRef) -> A { self.dot_impl(rhs) } } -impl Dot> for ArrayBase -where - S: Data, - S2: Data, - A: LinalgScalar, +impl Dot> for ArrayRef +where A: LinalgScalar { type Output = Array; @@ -209,14 +251,13 @@ where /// /// **Panics** if shapes are incompatible. #[track_caller] - fn dot(&self, rhs: &ArrayBase) -> Array + fn dot(&self, rhs: &ArrayRef) -> Array { - rhs.t().dot(self) + (*rhs.t()).dot(self) } } -impl ArrayBase -where S: Data +impl ArrayRef { /// Perform matrix multiplication of rectangular arrays `self` and `rhs`. /// @@ -255,14 +296,12 @@ where S: Data } } -impl Dot> for ArrayBase -where - S: Data, - S2: Data, - A: LinalgScalar, +impl Dot> for ArrayRef +where A: LinalgScalar { type Output = Array2; - fn dot(&self, b: &ArrayBase) -> Array2 + + fn dot(&self, b: &ArrayRef) -> Array2 { let a = self.view(); let b = b.view(); @@ -318,15 +357,13 @@ fn general_dot_shape_error(m: usize, k: usize, k2: usize, n: usize, c1: usize, c /// Return a result array with shape *M*. /// /// **Panics** if shapes are incompatible. -impl Dot> for ArrayBase -where - S: Data, - S2: Data, - A: LinalgScalar, +impl Dot> for ArrayRef +where A: LinalgScalar { type Output = Array; + #[track_caller] - fn dot(&self, rhs: &ArrayBase) -> Array + fn dot(&self, rhs: &ArrayRef) -> Array { let ((m, a), n) = (self.dim(), rhs.dim()); if a != n { @@ -342,10 +379,8 @@ where } } -impl ArrayBase -where - S: Data, - D: Dimension, +impl ArrayRef +where D: Dimension { /// Perform the operation `self += alpha * rhs` efficiently, where /// `alpha` is a scalar and `rhs` is another array. This operation is @@ -355,10 +390,8 @@ where /// /// **Panics** if broadcasting isn’t possible. #[track_caller] - pub fn scaled_add(&mut self, alpha: A, rhs: &ArrayBase) + pub fn scaled_add(&mut self, alpha: A, rhs: &ArrayRef) where - S: DataMut, - S2: Data, A: LinalgScalar, E: Dimension, { @@ -366,13 +399,13 @@ where } } -// mat_mul_impl uses ArrayView arguments to send all array kinds into +// mat_mul_impl uses ArrayRef arguments to send all array kinds into // the same instantiated implementation. #[cfg(not(feature = "blas"))] use self::mat_mul_general as mat_mul_impl; #[cfg(feature = "blas")] -fn mat_mul_impl(alpha: A, a: &ArrayView2<'_, A>, b: &ArrayView2<'_, A>, beta: A, c: &mut ArrayViewMut2<'_, A>) +fn mat_mul_impl(alpha: A, a: &ArrayRef2, b: &ArrayRef2, beta: A, c: &mut ArrayRef2) where A: LinalgScalar { let ((m, k), (k2, n)) = (a.dim(), b.dim()); @@ -458,9 +491,8 @@ where A: LinalgScalar } /// C ← α A B + β C -fn mat_mul_general( - alpha: A, lhs: &ArrayView2<'_, A>, rhs: &ArrayView2<'_, A>, beta: A, c: &mut ArrayViewMut2<'_, A>, -) where A: LinalgScalar +fn mat_mul_general(alpha: A, lhs: &ArrayRef2, rhs: &ArrayRef2, beta: A, c: &mut ArrayRef2) +where A: LinalgScalar { let ((m, k), (_, n)) = (lhs.dim(), rhs.dim()); @@ -592,13 +624,8 @@ fn mat_mul_general( /// layout allows. The default matrixmultiply backend is otherwise used for /// `f32, f64` for all memory layouts. #[track_caller] -pub fn general_mat_mul( - alpha: A, a: &ArrayBase, b: &ArrayBase, beta: A, c: &mut ArrayBase, -) where - S1: Data, - S2: Data, - S3: DataMut, - A: LinalgScalar, +pub fn general_mat_mul(alpha: A, a: &ArrayRef2, b: &ArrayRef2, beta: A, c: &mut ArrayRef2) +where A: LinalgScalar { let ((m, k), (k2, n)) = (a.dim(), b.dim()); let (m2, n2) = c.dim(); @@ -621,13 +648,8 @@ pub fn general_mat_mul( /// layout allows. #[track_caller] #[allow(clippy::collapsible_if)] -pub fn general_mat_vec_mul( - alpha: A, a: &ArrayBase, x: &ArrayBase, beta: A, y: &mut ArrayBase, -) where - S1: Data, - S2: Data, - S3: DataMut, - A: LinalgScalar, +pub fn general_mat_vec_mul(alpha: A, a: &ArrayRef2, x: &ArrayRef1, beta: A, y: &mut ArrayRef1) +where A: LinalgScalar { unsafe { general_mat_vec_mul_impl(alpha, a, x, beta, y.raw_view_mut()) } } @@ -641,12 +663,9 @@ pub fn general_mat_vec_mul( /// The caller must ensure that the raw view is valid for writing. /// the destination may be uninitialized iff beta is zero. #[allow(clippy::collapsible_else_if)] -unsafe fn general_mat_vec_mul_impl( - alpha: A, a: &ArrayBase, x: &ArrayBase, beta: A, y: RawArrayViewMut, -) where - S1: Data, - S2: Data, - A: LinalgScalar, +unsafe fn general_mat_vec_mul_impl( + alpha: A, a: &ArrayRef2, x: &ArrayRef1, beta: A, y: RawArrayViewMut, +) where A: LinalgScalar { let ((m, k), k2) = (a.dim(), x.dim()); let m2 = y.dim(); @@ -658,7 +677,7 @@ unsafe fn general_mat_vec_mul_impl( ($ty:ty, $gemv:ident) => { if same_type::() { if let Some(layout) = get_blas_compatible_layout(&a) { - if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y) { + if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y.as_ref()) { // Determine stride between rows or columns. Note that the stride is // adjusted to at least `k` or `m` to handle the case of a matrix with a // trivial (length 1) dimension, since the stride for the trivial dimension @@ -671,8 +690,8 @@ unsafe fn general_mat_vec_mul_impl( // Low addr in memory pointers required for x, y let x_offset = offset_from_low_addr_ptr_to_logical_ptr(&x.dim, &x.strides); let x_ptr = x.ptr.as_ptr().sub(x_offset); - let y_offset = offset_from_low_addr_ptr_to_logical_ptr(&y.dim, &y.strides); - let y_ptr = y.ptr.as_ptr().sub(y_offset); + let y_offset = offset_from_low_addr_ptr_to_logical_ptr(&y.layout.dim, &y.layout.strides); + let y_ptr = y.layout.ptr.as_ptr().sub(y_offset); let x_stride = x.strides()[0] as blas_index; let y_stride = y.strides()[0] as blas_index; @@ -721,11 +740,8 @@ unsafe fn general_mat_vec_mul_impl( /// /// The kronecker product of a LxN matrix A and a MxR matrix B is a (L*M)x(N*R) /// matrix K formed by the block multiplication A_ij * B. -pub fn kron(a: &ArrayBase, b: &ArrayBase) -> Array -where - S1: Data, - S2: Data, - A: LinalgScalar, +pub fn kron(a: &ArrayRef2, b: &ArrayRef2) -> Array +where A: LinalgScalar { let dimar = a.shape()[0]; let dimac = a.shape()[1]; @@ -774,13 +790,12 @@ fn complex_array(z: Complex) -> [A; 2] } #[cfg(feature = "blas")] -fn blas_compat_1d(a: &ArrayBase) -> bool +fn blas_compat_1d(a: &RawRef) -> bool where - S: RawData, A: 'static, - S::Elem: 'static, + B: 'static, { - if !same_type::() { + if !same_type::() { return false; } if a.len() > blas_index::MAX as usize { @@ -886,8 +901,7 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: BlasOrder) -> bool /// Get BLAS compatible layout if any (C or F, preferring the former) #[cfg(feature = "blas")] -fn get_blas_compatible_layout(a: &ArrayBase) -> Option -where S: Data +fn get_blas_compatible_layout(a: &ArrayRef) -> Option { if is_blas_2d(&a.dim, &a.strides, BlasOrder::C) { Some(BlasOrder::C) @@ -903,8 +917,7 @@ where S: Data /// /// Return leading stride (lda, ldb, ldc) of array #[cfg(feature = "blas")] -fn blas_stride(a: &ArrayBase, order: BlasOrder) -> blas_index -where S: Data +fn blas_stride(a: &ArrayRef, order: BlasOrder) -> blas_index { let axis = order.get_blas_lead_axis(); let other_axis = 1 - axis; @@ -925,13 +938,12 @@ where S: Data #[cfg(test)] #[cfg(feature = "blas")] -fn blas_row_major_2d(a: &ArrayBase) -> bool +fn blas_row_major_2d(a: &ArrayRef2) -> bool where - S: Data, A: 'static, - S::Elem: 'static, + B: 'static, { - if !same_type::() { + if !same_type::() { return false; } is_blas_2d(&a.dim, &a.strides, BlasOrder::C) @@ -939,13 +951,12 @@ where #[cfg(test)] #[cfg(feature = "blas")] -fn blas_column_major_2d(a: &ArrayBase) -> bool +fn blas_column_major_2d(a: &ArrayRef2) -> bool where - S: Data, A: 'static, - S::Elem: 'static, + B: 'static, { - if !same_type::() { + if !same_type::() { return false; } is_blas_2d(&a.dim, &a.strides, BlasOrder::F) diff --git a/src/math_cell.rs b/src/math_cell.rs index 6ed1ed71f..629e5575d 100644 --- a/src/math_cell.rs +++ b/src/math_cell.rs @@ -7,7 +7,7 @@ use std::ops::{Deref, DerefMut}; /// A transparent wrapper of [`Cell`](std::cell::Cell) which is identical in every way, except /// it will implement arithmetic operators as well. /// -/// The purpose of `MathCell` is to be used from [.cell_view()](crate::ArrayBase::cell_view). +/// The purpose of `MathCell` is to be used from [.cell_view()](crate::ArrayRef::cell_view). /// The `MathCell` derefs to `Cell`, so all the cell's methods are available. #[repr(transparent)] #[derive(Default)] diff --git a/src/numeric/impl_float_maths.rs b/src/numeric/impl_float_maths.rs index 54fed49c2..7f6731ab8 100644 --- a/src/numeric/impl_float_maths.rs +++ b/src/numeric/impl_float_maths.rs @@ -54,10 +54,9 @@ macro_rules! binary_ops { /// /// Element-wise math functions for any array type that contains float number. #[cfg(feature = "std")] -impl ArrayBase +impl ArrayRef where A: 'static + Float, - S: Data, D: Dimension, { boolean_ops! { @@ -143,10 +142,9 @@ where } } -impl ArrayBase +impl ArrayRef where A: 'static + PartialOrd + Clone, - S: Data, D: Dimension, { /// Limit the values for each element, similar to NumPy's `clip` function. diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 6c67b9135..0f209fa33 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -17,10 +17,8 @@ use crate::numeric_util; use crate::Slice; /// # Numerical Methods for Arrays -impl ArrayBase -where - S: Data, - D: Dimension, +impl ArrayRef +where D: Dimension { /// Return the sum of all elements in the array. /// diff --git a/src/parallel/impl_par_methods.rs b/src/parallel/impl_par_methods.rs index c6af4e8f3..0a2f41c8e 100644 --- a/src/parallel/impl_par_methods.rs +++ b/src/parallel/impl_par_methods.rs @@ -1,5 +1,5 @@ use crate::AssignElem; -use crate::{Array, ArrayBase, DataMut, Dimension, IntoNdProducer, NdProducer, Zip}; +use crate::{Array, ArrayRef, Dimension, IntoNdProducer, NdProducer, Zip}; use super::send_producer::SendProducer; use crate::parallel::par::ParallelSplits; @@ -10,9 +10,8 @@ use crate::partial::Partial; /// # Parallel methods /// /// These methods require crate feature `rayon`. -impl ArrayBase +impl ArrayRef where - S: DataMut, D: Dimension, A: Send + Sync, { diff --git a/src/parallel/mod.rs b/src/parallel/mod.rs index 0c84baa91..2eef69307 100644 --- a/src/parallel/mod.rs +++ b/src/parallel/mod.rs @@ -19,8 +19,8 @@ //! //! The following other parallelized methods exist: //! -//! - [`ArrayBase::par_map_inplace()`] -//! - [`ArrayBase::par_mapv_inplace()`] +//! - [`ArrayRef::par_map_inplace()`](crate::ArrayRef::par_map_inplace) +//! - [`ArrayRef::par_mapv_inplace()`](crate::ArrayRef::par_mapv_inplace) //! - [`Zip::par_for_each()`] (all arities) //! - [`Zip::par_map_collect()`] (all arities) //! - [`Zip::par_map_assign_into()`] (all arities) diff --git a/src/prelude.rs b/src/prelude.rs index acf39da1a..072eb4825 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -18,11 +18,26 @@ //! ``` #[doc(no_inline)] -pub use crate::{ArcArray, Array, ArrayBase, ArrayView, ArrayViewMut, CowArray, RawArrayView, RawArrayViewMut}; +pub use crate::{ + ArcArray, + Array, + ArrayBase, + ArrayRef, + ArrayView, + ArrayViewMut, + CowArray, + LayoutRef, + RawArrayView, + RawArrayViewMut, + RawRef, +}; #[doc(no_inline)] pub use crate::{Axis, Dim, Dimension}; +#[doc(no_inline)] +pub use crate::{ArrayRef0, ArrayRef1, ArrayRef2, ArrayRef3, ArrayRef4, ArrayRef5, ArrayRef6, ArrayRefD}; + #[doc(no_inline)] pub use crate::{Array0, Array1, Array2, Array3, Array4, Array5, Array6, ArrayD}; diff --git a/src/shape_builder.rs b/src/shape_builder.rs index cd790a25f..b9a4b0ab6 100644 --- a/src/shape_builder.rs +++ b/src/shape_builder.rs @@ -210,7 +210,7 @@ where D: Dimension /// This is an argument conversion trait that is used to accept an array shape and /// (optionally) an ordering argument. /// -/// See for example [`.to_shape()`](crate::ArrayBase::to_shape). +/// See for example [`.to_shape()`](crate::ArrayRef::to_shape). pub trait ShapeArg { type Dim: Dimension; diff --git a/src/slice.rs b/src/slice.rs index e6c237a92..e2ce1e727 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -430,7 +430,7 @@ unsafe impl SliceArg for [SliceInfoElem] /// `SliceInfo` instance can still be used to slice an array with dimension /// `IxDyn` as long as the number of axes matches. /// -/// [`.slice()`]: crate::ArrayBase::slice +/// [`.slice()`]: crate::ArrayRef::slice #[derive(Debug)] pub struct SliceInfo { @@ -521,7 +521,7 @@ where } /// Returns the number of dimensions of the input array for - /// [`.slice()`](crate::ArrayBase::slice). + /// [`.slice()`](crate::ArrayRef::slice). /// /// If `Din` is a fixed-size dimension type, then this is equivalent to /// `Din::NDIM.unwrap()`. Otherwise, the value is calculated by iterating @@ -536,7 +536,7 @@ where } /// Returns the number of dimensions after calling - /// [`.slice()`](crate::ArrayBase::slice) (including taking + /// [`.slice()`](crate::ArrayRef::slice) (including taking /// subviews). /// /// If `Dout` is a fixed-size dimension type, then this is equivalent to @@ -755,10 +755,10 @@ impl_slicenextdim!((), NewAxis, Ix0, Ix1); /// panic. Without the `NewAxis`, i.e. `s![0..4;2, 6, 1..5]`, /// [`.slice_collapse()`] would result in an array of shape `[2, 1, 4]`. /// -/// [`.slice()`]: crate::ArrayBase::slice -/// [`.slice_mut()`]: crate::ArrayBase::slice_mut +/// [`.slice()`]: crate::ArrayRef::slice +/// [`.slice_mut()`]: crate::ArrayRef::slice_mut /// [`.slice_move()`]: crate::ArrayBase::slice_move -/// [`.slice_collapse()`]: crate::ArrayBase::slice_collapse +/// [`.slice_collapse()`]: crate::LayoutRef::slice_collapse /// /// See also [*Slicing*](crate::ArrayBase#slicing). /// diff --git a/src/tri.rs b/src/tri.rs index b7d297fcc..6e3b90b5b 100644 --- a/src/tri.rs +++ b/src/tri.rs @@ -13,16 +13,14 @@ use num_traits::Zero; use crate::{ dimension::{is_layout_c, is_layout_f}, Array, - ArrayBase, + ArrayRef, Axis, - Data, Dimension, Zip, }; -impl ArrayBase +impl ArrayRef where - S: Data, D: Dimension, A: Clone + Zero, { @@ -32,7 +30,7 @@ where /// For arrays with `ndim` exceeding 2, `triu` will apply to the final two axes. /// For 0D and 1D arrays, `triu` will return an unchanged clone. /// - /// See also [`ArrayBase::tril`] + /// See also [`ArrayRef::tril`] /// /// ``` /// use ndarray::array; @@ -83,7 +81,9 @@ where false => row_num.saturating_sub(k.unsigned_abs()), // Avoid underflow, go to 0 }; lower = min(lower, ncols); - dst.slice_mut(s![lower..]).assign(&src.slice(s![lower..])); + (*dst) + .slice_mut(s![lower..]) + .assign(&(*src).slice(s![lower..])); }); res @@ -95,7 +95,7 @@ where /// For arrays with `ndim` exceeding 2, `tril` will apply to the final two axes. /// For 0D and 1D arrays, `tril` will return an unchanged clone. /// - /// See also [`ArrayBase::triu`] + /// See also [`ArrayRef::triu`] /// /// ``` /// use ndarray::array; @@ -127,10 +127,10 @@ where if is_layout_f(&self.dim, &self.strides) && !is_layout_c(&self.dim, &self.strides) && k > isize::MIN { let mut x = self.view(); x.swap_axes(n - 2, n - 1); - let mut tril = x.triu(-k); - tril.swap_axes(n - 2, n - 1); + let mut triu = x.triu(-k); + triu.swap_axes(n - 2, n - 1); - return tril; + return triu; } let mut res = Array::zeros(self.raw_dim()); @@ -147,7 +147,9 @@ where false => row_num.saturating_sub((k + 1).unsigned_abs()), // Avoid underflow }; upper = min(upper, ncols); - dst.slice_mut(s![..upper]).assign(&src.slice(s![..upper])); + (*dst) + .slice_mut(s![..upper]) + .assign(&(*src).slice(s![..upper])); }); res diff --git a/src/zip/mod.rs b/src/zip/mod.rs index b58752f66..640a74d1b 100644 --- a/src/zip/mod.rs +++ b/src/zip/mod.rs @@ -76,10 +76,8 @@ fn array_layout(dim: &D, strides: &D) -> Layout } } -impl ArrayBase -where - S: RawData, - D: Dimension, +impl LayoutRef +where D: Dimension { pub(crate) fn layout_impl(&self) -> Layout { @@ -96,8 +94,8 @@ where fn broadcast_unwrap(self, shape: E) -> Self::Output { #[allow(clippy::needless_borrow)] - let res: ArrayView<'_, A, E::Dim> = (&self).broadcast_unwrap(shape.into_dimension()); - unsafe { ArrayView::new(res.ptr, res.dim, res.strides) } + let res: ArrayView<'_, A, E::Dim> = (*self).broadcast_unwrap(shape.into_dimension()); + unsafe { ArrayView::new(res.layout.ptr, res.layout.dim, res.layout.strides) } } private_impl! {} } @@ -762,7 +760,8 @@ macro_rules! map_impl { pub(crate) fn map_collect_owned(self, f: impl FnMut($($p::Item,)* ) -> R) -> ArrayBase - where S: DataOwned + where + S: DataOwned, { // safe because: all elements are written before the array is completed diff --git a/src/zip/ndproducer.rs b/src/zip/ndproducer.rs index 1d1b3391b..82f3f43a7 100644 --- a/src/zip/ndproducer.rs +++ b/src/zip/ndproducer.rs @@ -1,4 +1,5 @@ use crate::imp_prelude::*; +use crate::ArrayRef; use crate::Layout; use crate::NdIndex; #[cfg(not(feature = "std"))] @@ -156,6 +157,34 @@ where } } +/// An array reference is an n-dimensional producer of element references +/// (like ArrayView). +impl<'a, A: 'a, D> IntoNdProducer for &'a ArrayRef +where D: Dimension +{ + type Item = &'a A; + type Dim = D; + type Output = ArrayView<'a, A, D>; + fn into_producer(self) -> Self::Output + { + self.view() + } +} + +/// A mutable array reference is an n-dimensional producer of mutable element +/// references (like ArrayViewMut). +impl<'a, A: 'a, D> IntoNdProducer for &'a mut ArrayRef +where D: Dimension +{ + type Item = &'a mut A; + type Dim = D; + type Output = ArrayViewMut<'a, A, D>; + fn into_producer(self) -> Self::Output + { + self.view_mut() + } +} + /// A slice is a one-dimensional producer impl<'a, A: 'a> IntoNdProducer for &'a [A] { @@ -239,7 +268,7 @@ impl<'a, A, D: Dimension> NdProducer for ArrayView<'a, A, D> fn raw_dim(&self) -> Self::Dim { - self.raw_dim() + (***self).raw_dim() } fn equal_dim(&self, dim: &Self::Dim) -> bool @@ -249,7 +278,7 @@ impl<'a, A, D: Dimension> NdProducer for ArrayView<'a, A, D> fn as_ptr(&self) -> *mut A { - self.as_ptr() as _ + (**self).as_ptr() as _ } fn layout(&self) -> Layout @@ -269,7 +298,7 @@ impl<'a, A, D: Dimension> NdProducer for ArrayView<'a, A, D> fn stride_of(&self, axis: Axis) -> isize { - self.stride_of(axis) + (**self).stride_of(axis) } #[inline(always)] @@ -295,7 +324,7 @@ impl<'a, A, D: Dimension> NdProducer for ArrayViewMut<'a, A, D> fn raw_dim(&self) -> Self::Dim { - self.raw_dim() + (***self).raw_dim() } fn equal_dim(&self, dim: &Self::Dim) -> bool @@ -305,7 +334,7 @@ impl<'a, A, D: Dimension> NdProducer for ArrayViewMut<'a, A, D> fn as_ptr(&self) -> *mut A { - self.as_ptr() as _ + (**self).as_ptr() as _ } fn layout(&self) -> Layout @@ -325,7 +354,7 @@ impl<'a, A, D: Dimension> NdProducer for ArrayViewMut<'a, A, D> fn stride_of(&self, axis: Axis) -> isize { - self.stride_of(axis) + (**self).stride_of(axis) } #[inline(always)] @@ -356,17 +385,17 @@ impl NdProducer for RawArrayView fn equal_dim(&self, dim: &Self::Dim) -> bool { - self.dim.equal(dim) + self.layout.dim.equal(dim) } fn as_ptr(&self) -> *const A { - self.as_ptr() + self.as_ptr() as _ } fn layout(&self) -> Layout { - self.layout_impl() + AsRef::>::as_ref(self).layout_impl() } unsafe fn as_ref(&self, ptr: *const A) -> *const A @@ -376,7 +405,10 @@ impl NdProducer for RawArrayView unsafe fn uget_ptr(&self, i: &Self::Dim) -> *const A { - self.ptr.as_ptr().offset(i.index_unchecked(&self.strides)) + self.layout + .ptr + .as_ptr() + .offset(i.index_unchecked(&self.layout.strides)) } fn stride_of(&self, axis: Axis) -> isize @@ -412,7 +444,7 @@ impl NdProducer for RawArrayViewMut fn equal_dim(&self, dim: &Self::Dim) -> bool { - self.dim.equal(dim) + self.layout.dim.equal(dim) } fn as_ptr(&self) -> *mut A @@ -422,7 +454,7 @@ impl NdProducer for RawArrayViewMut fn layout(&self) -> Layout { - self.layout_impl() + AsRef::>::as_ref(self).layout_impl() } unsafe fn as_ref(&self, ptr: *mut A) -> *mut A @@ -432,7 +464,10 @@ impl NdProducer for RawArrayViewMut unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A { - self.ptr.as_ptr().offset(i.index_unchecked(&self.strides)) + self.layout + .ptr + .as_ptr() + .offset(i.index_unchecked(&self.layout.strides)) } fn stride_of(&self, axis: Axis) -> isize diff --git a/tests/array.rs b/tests/array.rs index 696904dab..fbc7b4d11 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -2819,3 +2819,11 @@ fn test_split_complex_invert_axis() assert_eq!(cmplx.re, a.mapv(|z| z.re)); assert_eq!(cmplx.im, a.mapv(|z| z.im)); } + +#[test] +fn test_slice_assign() +{ + let mut a = array![0, 1, 2, 3, 4]; + *a.slice_mut(s![1..3]) += 1; + assert_eq!(a, array![0, 2, 3, 3, 4]); +} pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy