From 31c71da59472dd7a6f76185a563d35ce77e80b72 Mon Sep 17 00:00:00 2001 From: bluss Date: Sat, 17 Apr 2021 13:52:01 +0200 Subject: [PATCH 1/9] order: Add enum Order --- src/lib.rs | 2 ++ src/order.rs | 83 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) create mode 100644 src/order.rs diff --git a/src/lib.rs b/src/lib.rs index a4c59eb66..d4f721224 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -145,6 +145,7 @@ pub use crate::dimension::IxDynImpl; pub use crate::dimension::NdIndex; pub use crate::error::{ErrorKind, ShapeError}; pub use crate::indexes::{indices, indices_of}; +pub use crate::order::Order; pub use crate::slice::{ MultiSliceArg, NewAxis, Slice, SliceArg, SliceInfo, SliceInfoElem, SliceNextDim, }; @@ -202,6 +203,7 @@ mod linspace; mod logspace; mod math_cell; mod numeric_util; +mod order; mod partial; mod shape_builder; #[macro_use] diff --git a/src/order.rs b/src/order.rs new file mode 100644 index 000000000..e8d9c8db1 --- /dev/null +++ b/src/order.rs @@ -0,0 +1,83 @@ + +/// Array order +/// +/// Order refers to indexing order, or how a linear sequence is translated +/// into a two-dimensional or multi-dimensional array. +/// +/// - `RowMajor` means that the index along the row is the most rapidly changing +/// - `ColumnMajor` means that the index along the column is the most rapidly changing +/// +/// Given a sequence like: 1, 2, 3, 4, 5, 6 +/// +/// If it is laid it out in a 2 x 3 matrix using row major ordering, it results in: +/// +/// ```text +/// 1 2 3 +/// 4 5 6 +/// ``` +/// +/// If it is laid using column major ordering, it results in: +/// +/// ```text +/// 1 3 5 +/// 2 4 6 +/// ``` +/// +/// It can be seen as filling in "rows first" or "columns first". +/// +/// `Order` can be used both to refer to logical ordering as well as memory ordering or memory +/// layout. The orderings have common short names, also seen in other environments, where +/// row major is called "C" order (after the C programming language) and column major is called "F" +/// or "Fortran" order. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum Order { + /// Row major or "C" order + RowMajor, + /// Column major or "F" order + ColumnMajor, +} + +impl Order { + /// "C" is an alias for row major ordering + pub const C: Order = Order::RowMajor; + + /// "F" (for Fortran) is an alias for column major ordering + pub const F: Order = Order::ColumnMajor; + + /// Return true if input is Order::RowMajor, false otherwise + #[inline] + pub fn is_row_major(self) -> bool { + match self { + Order::RowMajor => true, + Order::ColumnMajor => false, + } + } + + /// Return true if input is Order::ColumnMajor, false otherwise + #[inline] + pub fn is_column_major(self) -> bool { + !self.is_row_major() + } + + /// Return Order::RowMajor if the input is true, Order::ColumnMajor otherwise + #[inline] + pub fn row_major(row_major: bool) -> Order { + if row_major { Order::RowMajor } else { Order::ColumnMajor } + } + + /// Return Order::ColumnMajor if the input is true, Order::RowMajor otherwise + #[inline] + pub fn column_major(column_major: bool) -> Order { + Self::row_major(!column_major) + } + + /// Return the transpose: row major becomes column major and vice versa. + #[inline] + pub fn transpose(self) -> Order { + match self { + Order::RowMajor => Order::ColumnMajor, + Order::ColumnMajor => Order::RowMajor, + } + } +} From 0d094bf70c843ed3fa594dd76593ab967f831da6 Mon Sep 17 00:00:00 2001 From: bluss Date: Sat, 17 Apr 2021 12:26:49 +0200 Subject: [PATCH 2/9] shape: Add trait ShapeArg --- src/impl_methods.rs | 2 +- src/lib.rs | 2 +- src/shape_builder.rs | 31 +++++++++++++++++++++++++++++++ 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 9ef4277de..0affd9f0b 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -26,8 +26,8 @@ use crate::dimension::broadcast::co_broadcast; use crate::error::{self, ErrorKind, ShapeError, from_kind}; use crate::math_cell::MathCell; use crate::itertools::zip; -use crate::zip::{IntoNdProducer, Zip}; use crate::AxisDescription; +use crate::zip::{IntoNdProducer, Zip}; use crate::iter::{ AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, ExactChunks, ExactChunksMut, diff --git a/src/lib.rs b/src/lib.rs index d4f721224..dfce924e3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -163,7 +163,7 @@ pub use crate::stacking::{concatenate, stack, stack_new_axis}; pub use crate::math_cell::MathCell; pub use crate::impl_views::IndexLonger; -pub use crate::shape_builder::{Shape, ShapeBuilder, StrideShape}; +pub use crate::shape_builder::{Shape, ShapeBuilder, ShapeArg, StrideShape}; #[macro_use] mod macro_utils; diff --git a/src/shape_builder.rs b/src/shape_builder.rs index dcfddc1b9..470374077 100644 --- a/src/shape_builder.rs +++ b/src/shape_builder.rs @@ -1,5 +1,6 @@ use crate::dimension::IntoDimension; use crate::Dimension; +use crate::order::Order; /// A contiguous array shape of n dimensions. /// @@ -184,3 +185,33 @@ where self.dim.size() } } + + +/// Array shape argument with optional order parameter +/// +/// Shape or array dimension argument, with optional [`Order`] parameter. +/// +/// This is an argument conversion trait that is used to accept an array shape and +/// (optionally) an ordering argument. +/// +/// See for example [`.to_shape()`](crate::ArrayBase::to_shape). +pub trait ShapeArg { + type Dim: Dimension; + fn into_shape_and_order(self) -> (Self::Dim, Option); +} + +impl ShapeArg for T where T: IntoDimension { + type Dim = T::Dim; + + fn into_shape_and_order(self) -> (Self::Dim, Option) { + (self.into_dimension(), None) + } +} + +impl ShapeArg for (T, Order) where T: IntoDimension { + type Dim = T::Dim; + + fn into_shape_and_order(self) -> (Self::Dim, Option) { + (self.0.into_dimension(), Some(self.1)) + } +} From 3794952ce342ad7af866470a202cfedaed74bf49 Mon Sep 17 00:00:00 2001 From: bluss Date: Fri, 15 Jan 2021 20:58:38 +0100 Subject: [PATCH 3/9] shape: Add method .to_shape() --- src/impl_methods.rs | 87 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 0affd9f0b..cda2cd95e 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -27,6 +27,9 @@ use crate::error::{self, ErrorKind, ShapeError, from_kind}; use crate::math_cell::MathCell; use crate::itertools::zip; use crate::AxisDescription; +use crate::Layout; +use crate::order::Order; +use crate::shape_builder::ShapeArg; use crate::zip::{IntoNdProducer, Zip}; use crate::iter::{ @@ -1577,6 +1580,90 @@ where } } + /// Transform the array into `new_shape`; any shape with the same number of elements is + /// accepted. + /// + /// `order` specifies the *logical* order in which the array is to be read and reshaped. + /// The array is returned as a `CowArray`; a view if possible, otherwise an owned array. + /// + /// For example, when starting from the one-dimensional sequence 1 2 3 4 5 6, it would be + /// understood as a 2 x 3 array in row major ("C") order this way: + /// + /// ```text + /// 1 2 3 + /// 4 5 6 + /// ``` + /// + /// and as 2 x 3 in column major ("F") order this way: + /// + /// ```text + /// 1 3 5 + /// 2 4 6 + /// ``` + /// + /// This example should show that any time we "reflow" the elements in the array to a different + /// number of rows and columns (or more axes if applicable), it is important to pick an index + /// ordering, and that's the reason for the function parameter for `order`. + /// + /// **Errors** if the new shape doesn't have the same number of elements as the array's current + /// shape. + /// + /// ``` + /// use ndarray::array; + /// use ndarray::Order; + /// + /// assert!( + /// array![1., 2., 3., 4., 5., 6.].to_shape(((2, 3), Order::RowMajor)).unwrap() + /// == array![[1., 2., 3.], + /// [4., 5., 6.]] + /// ); + /// + /// assert!( + /// array![1., 2., 3., 4., 5., 6.].to_shape(((2, 3), Order::ColumnMajor)).unwrap() + /// == array![[1., 3., 5.], + /// [2., 4., 6.]] + /// ); + /// ``` + pub fn to_shape(&self, new_shape: E) -> Result, ShapeError> + where + E: ShapeArg, + A: Clone, + S: Data, + { + let (shape, order) = new_shape.into_shape_and_order(); + self.to_shape_order(shape, order.unwrap_or(Order::RowMajor)) + } + + fn to_shape_order(&self, shape: E, order: Order) + -> Result, ShapeError> + where + E: Dimension, + A: Clone, + S: Data, + { + if size_of_shape_checked(&shape) != Ok(self.dim.size()) { + return Err(error::incompatible_shapes(&self.dim, &shape)); + } + let layout = self.layout_impl(); + + unsafe { + if layout.is(Layout::CORDER) && order == Order::RowMajor { + let strides = shape.default_strides(); + Ok(CowArray::from(ArrayView::new(self.ptr, shape, strides))) + } else if layout.is(Layout::FORDER) && order == Order::ColumnMajor { + let strides = shape.fortran_strides(); + Ok(CowArray::from(ArrayView::new(self.ptr, shape, strides))) + } else { + let (shape, view) = match order { + Order::RowMajor => (shape.set_f(false), self.view()), + Order::ColumnMajor => (shape.set_f(true), self.t()), + }; + Ok(CowArray::from(Array::from_shape_trusted_iter_unchecked( + shape, view.into_iter(), A::clone))) + } + } + } + /// Transform the array into `shape`; any shape with the same number of /// elements is accepted, but the source array or view must be in standard /// or column-major (Fortran) layout. From 8032aec1b742c36a985cd69f0ef53ec6adad6f28 Mon Sep 17 00:00:00 2001 From: bluss Date: Sat, 17 Apr 2021 12:26:49 +0200 Subject: [PATCH 4/9] shape: Add tests for .to_shape() --- tests/array.rs | 60 +----------------- tests/reshape.rs | 159 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 160 insertions(+), 59 deletions(-) create mode 100644 tests/reshape.rs diff --git a/tests/array.rs b/tests/array.rs index 976824dfe..edd58adbc 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -8,7 +8,7 @@ )] use defmac::defmac; -use itertools::{enumerate, zip, Itertools}; +use itertools::{zip, Itertools}; use ndarray::prelude::*; use ndarray::{arr3, rcarr2}; use ndarray::indices; @@ -1370,64 +1370,6 @@ fn transpose_view_mut() { assert_eq!(at, arr2(&[[1, 4], [2, 5], [3, 7]])); } -#[test] -fn reshape() { - let data = [1, 2, 3, 4, 5, 6, 7, 8]; - let v = aview1(&data); - let u = v.into_shape((3, 3)); - assert!(u.is_err()); - let u = v.into_shape((2, 2, 2)); - assert!(u.is_ok()); - let u = u.unwrap(); - assert_eq!(u.shape(), &[2, 2, 2]); - let s = u.into_shape((4, 2)).unwrap(); - assert_eq!(s.shape(), &[4, 2]); - assert_eq!(s, aview2(&[[1, 2], [3, 4], [5, 6], [7, 8]])); -} - -#[test] -#[should_panic(expected = "IncompatibleShape")] -fn reshape_error1() { - let data = [1, 2, 3, 4, 5, 6, 7, 8]; - let v = aview1(&data); - let _u = v.into_shape((2, 5)).unwrap(); -} - -#[test] -#[should_panic(expected = "IncompatibleLayout")] -fn reshape_error2() { - let data = [1, 2, 3, 4, 5, 6, 7, 8]; - let v = aview1(&data); - let mut u = v.into_shape((2, 2, 2)).unwrap(); - u.swap_axes(0, 1); - let _s = u.into_shape((2, 4)).unwrap(); -} - -#[test] -fn reshape_f() { - let mut u = Array::zeros((3, 4).f()); - for (i, elt) in enumerate(u.as_slice_memory_order_mut().unwrap()) { - *elt = i as i32; - } - let v = u.view(); - println!("{:?}", v); - - // noop ok - let v2 = v.into_shape((3, 4)); - assert!(v2.is_ok()); - assert_eq!(v, v2.unwrap()); - - let u = v.into_shape((3, 2, 2)); - assert!(u.is_ok()); - let u = u.unwrap(); - println!("{:?}", u); - assert_eq!(u.shape(), &[3, 2, 2]); - let s = u.into_shape((4, 3)).unwrap(); - println!("{:?}", s); - assert_eq!(s.shape(), &[4, 3]); - assert_eq!(s, aview2(&[[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]])); -} - #[test] #[allow(clippy::cognitive_complexity)] fn insert_axis() { diff --git a/tests/reshape.rs b/tests/reshape.rs new file mode 100644 index 000000000..f03f4ccf1 --- /dev/null +++ b/tests/reshape.rs @@ -0,0 +1,159 @@ +use ndarray::prelude::*; + +use itertools::enumerate; + +use ndarray::Order; + +#[test] +fn reshape() { + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let u = v.into_shape((3, 3)); + assert!(u.is_err()); + let u = v.into_shape((2, 2, 2)); + assert!(u.is_ok()); + let u = u.unwrap(); + assert_eq!(u.shape(), &[2, 2, 2]); + let s = u.into_shape((4, 2)).unwrap(); + assert_eq!(s.shape(), &[4, 2]); + assert_eq!(s, aview2(&[[1, 2], [3, 4], [5, 6], [7, 8]])); +} + +#[test] +#[should_panic(expected = "IncompatibleShape")] +fn reshape_error1() { + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let _u = v.into_shape((2, 5)).unwrap(); +} + +#[test] +#[should_panic(expected = "IncompatibleLayout")] +fn reshape_error2() { + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let mut u = v.into_shape((2, 2, 2)).unwrap(); + u.swap_axes(0, 1); + let _s = u.into_shape((2, 4)).unwrap(); +} + +#[test] +fn reshape_f() { + let mut u = Array::zeros((3, 4).f()); + for (i, elt) in enumerate(u.as_slice_memory_order_mut().unwrap()) { + *elt = i as i32; + } + let v = u.view(); + println!("{:?}", v); + + // noop ok + let v2 = v.into_shape((3, 4)); + assert!(v2.is_ok()); + assert_eq!(v, v2.unwrap()); + + let u = v.into_shape((3, 2, 2)); + assert!(u.is_ok()); + let u = u.unwrap(); + println!("{:?}", u); + assert_eq!(u.shape(), &[3, 2, 2]); + let s = u.into_shape((4, 3)).unwrap(); + println!("{:?}", s); + assert_eq!(s.shape(), &[4, 3]); + assert_eq!(s, aview2(&[[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]])); +} + + +#[test] +fn to_shape_easy() { + // 1D -> C -> C + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let u = v.to_shape(((3, 3), Order::RowMajor)); + assert!(u.is_err()); + + let u = v.to_shape(((2, 2, 2), Order::C)); + assert!(u.is_ok()); + + let u = u.unwrap(); + assert!(u.is_view()); + assert_eq!(u.shape(), &[2, 2, 2]); + assert_eq!(u, array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); + + let s = u.to_shape((4, 2)).unwrap(); + assert_eq!(s.shape(), &[4, 2]); + assert_eq!(s, aview2(&[[1, 2], [3, 4], [5, 6], [7, 8]])); + + // 1D -> F -> F + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let u = v.to_shape(((3, 3), Order::ColumnMajor)); + assert!(u.is_err()); + + let u = v.to_shape(((2, 2, 2), Order::ColumnMajor)); + assert!(u.is_ok()); + + let u = u.unwrap(); + assert!(u.is_view()); + assert_eq!(u.shape(), &[2, 2, 2]); + assert_eq!(u, array![[[1, 5], [3, 7]], [[2, 6], [4, 8]]]); + + let s = u.to_shape(((4, 2), Order::ColumnMajor)).unwrap(); + assert_eq!(s.shape(), &[4, 2]); + assert_eq!(s, array![[1, 5], [2, 6], [3, 7], [4, 8]]); +} + +#[test] +fn to_shape_copy() { + // 1D -> C -> F + let v = ArrayView::from(&[1, 2, 3, 4, 5, 6, 7, 8]); + let u = v.to_shape(((4, 2), Order::RowMajor)).unwrap(); + assert_eq!(u.shape(), &[4, 2]); + assert_eq!(u, array![[1, 2], [3, 4], [5, 6], [7, 8]]); + + let u = u.to_shape(((2, 4), Order::ColumnMajor)).unwrap(); + assert_eq!(u.shape(), &[2, 4]); + assert_eq!(u, array![[1, 5, 2, 6], [3, 7, 4, 8]]); + + // 1D -> F -> C + let v = ArrayView::from(&[1, 2, 3, 4, 5, 6, 7, 8]); + let u = v.to_shape(((4, 2), Order::ColumnMajor)).unwrap(); + assert_eq!(u.shape(), &[4, 2]); + assert_eq!(u, array![[1, 5], [2, 6], [3, 7], [4, 8]]); + + let u = u.to_shape((2, 4)).unwrap(); + assert_eq!(u.shape(), &[2, 4]); + assert_eq!(u, array![[1, 5, 2, 6], [3, 7, 4, 8]]); +} + +#[test] +fn to_shape_add_axis() { + // 1D -> C -> C + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let u = v.to_shape(((4, 2), Order::RowMajor)).unwrap(); + + assert!(u.to_shape(((1, 4, 2), Order::RowMajor)).unwrap().is_view()); + assert!(u.to_shape(((1, 4, 2), Order::ColumnMajor)).unwrap().is_owned()); +} + + +#[test] +fn to_shape_copy_stride() { + let v = array![[1, 2, 3, 4], [5, 6, 7, 8]]; + let vs = v.slice(s![.., ..3]); + let lin1 = vs.to_shape(6).unwrap(); + assert_eq!(lin1, array![1, 2, 3, 5, 6, 7]); + assert!(lin1.is_owned()); + + let lin2 = vs.to_shape((6, Order::ColumnMajor)).unwrap(); + assert_eq!(lin2, array![1, 5, 2, 6, 3, 7]); + assert!(lin2.is_owned()); +} + +#[test] +#[should_panic(expected = "IncompatibleShape")] +fn to_shape_error1() { + let data = [1, 2, 3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let _u = v.to_shape((2, 5)).unwrap(); +} From a26bf94afcea540676b715f0b5ff4355317f2ed6 Mon Sep 17 00:00:00 2001 From: bluss Date: Sun, 9 May 2021 20:01:26 +0200 Subject: [PATCH 5/9] shape: Add benchmarks for .to_shape (to view) --- benches/to_shape.rs | 106 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 benches/to_shape.rs diff --git a/benches/to_shape.rs b/benches/to_shape.rs new file mode 100644 index 000000000..a048eb774 --- /dev/null +++ b/benches/to_shape.rs @@ -0,0 +1,106 @@ +#![feature(test)] + +extern crate test; +use test::Bencher; + +use ndarray::prelude::*; +use ndarray::Order; + +#[bench] +fn to_shape2_1(bench: &mut Bencher) { + let a = Array::::zeros((4, 5)); + let view = a.view(); + bench.iter(|| { + view.to_shape(4 * 5).unwrap() + }); +} + +#[bench] +fn to_shape2_2_same(bench: &mut Bencher) { + let a = Array::::zeros((4, 5)); + let view = a.view(); + bench.iter(|| { + view.to_shape((4, 5)).unwrap() + }); +} + +#[bench] +fn to_shape2_2_flip(bench: &mut Bencher) { + let a = Array::::zeros((4, 5)); + let view = a.view(); + bench.iter(|| { + view.to_shape((5, 4)).unwrap() + }); +} + +#[bench] +fn to_shape2_3(bench: &mut Bencher) { + let a = Array::::zeros((4, 5)); + let view = a.view(); + bench.iter(|| { + view.to_shape((2, 5, 2)).unwrap() + }); +} + +#[bench] +fn to_shape3_1(bench: &mut Bencher) { + let a = Array::::zeros((3, 4, 5)); + let view = a.view(); + bench.iter(|| { + view.to_shape(3 * 4 * 5).unwrap() + }); +} + +#[bench] +fn to_shape3_2_order(bench: &mut Bencher) { + let a = Array::::zeros((3, 4, 5)); + let view = a.view(); + bench.iter(|| { + view.to_shape((12, 5)).unwrap() + }); +} + +#[bench] +fn to_shape3_2_outoforder(bench: &mut Bencher) { + let a = Array::::zeros((3, 4, 5)); + let view = a.view(); + bench.iter(|| { + view.to_shape((4, 15)).unwrap() + }); +} + +#[bench] +fn to_shape3_3c(bench: &mut Bencher) { + let a = Array::::zeros((3, 4, 5)); + let view = a.view(); + bench.iter(|| { + view.to_shape((3, 4, 5)).unwrap() + }); +} + +#[bench] +fn to_shape3_3f(bench: &mut Bencher) { + let a = Array::::zeros((3, 4, 5).f()); + let view = a.view(); + bench.iter(|| { + view.to_shape(((3, 4, 5), Order::F)).unwrap() + }); +} + +#[bench] +fn to_shape3_4c(bench: &mut Bencher) { + let a = Array::::zeros((3, 4, 5)); + let view = a.view(); + bench.iter(|| { + view.to_shape(((2, 3, 2, 5), Order::C)).unwrap() + }); +} + +#[bench] +fn to_shape3_4f(bench: &mut Bencher) { + let a = Array::::zeros((3, 4, 5).f()); + let view = a.view(); + bench.iter(|| { + view.to_shape(((2, 3, 2, 5), Order::F)).unwrap() + }); +} From b157cc68b75f23ce7adce2910709215c468dae82 Mon Sep 17 00:00:00 2001 From: bluss Date: Sun, 9 May 2021 20:01:26 +0200 Subject: [PATCH 6/9] shape: Add trait Sequence To generalize over Forward/Reverse indexing of dimensions, add simple traits Sequence/Mut that allow indexing forwards and reversed references of dimension values. In a future change, they will support strides (with isize values) too. --- src/dimension/mod.rs | 1 + src/dimension/sequence.rs | 109 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+) create mode 100644 src/dimension/sequence.rs diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index f4f46e764..e883933b2 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -40,6 +40,7 @@ mod dynindeximpl; mod ndindex; mod ops; mod remove_axis; +mod sequence; /// Calculate offset from `Ix` stride converting sign properly #[inline(always)] diff --git a/src/dimension/sequence.rs b/src/dimension/sequence.rs new file mode 100644 index 000000000..835e00d18 --- /dev/null +++ b/src/dimension/sequence.rs @@ -0,0 +1,109 @@ +use std::ops::Index; +use std::ops::IndexMut; + +use crate::dimension::Dimension; + +pub(in crate::dimension) struct Forward(pub(crate) D); +pub(in crate::dimension) struct Reverse(pub(crate) D); + +impl Index for Forward<&D> +where + D: Dimension, +{ + type Output = usize; + + #[inline] + fn index(&self, index: usize) -> &usize { + &self.0[index] + } +} + +impl Index for Forward<&mut D> +where + D: Dimension, +{ + type Output = usize; + + #[inline] + fn index(&self, index: usize) -> &usize { + &self.0[index] + } +} + +impl IndexMut for Forward<&mut D> +where + D: Dimension, +{ + #[inline] + fn index_mut(&mut self, index: usize) -> &mut usize { + &mut self.0[index] + } +} + +impl Index for Reverse<&D> +where + D: Dimension, +{ + type Output = usize; + + #[inline] + fn index(&self, index: usize) -> &usize { + &self.0[self.len() - index - 1] + } +} + +impl Index for Reverse<&mut D> +where + D: Dimension, +{ + type Output = usize; + + #[inline] + fn index(&self, index: usize) -> &usize { + &self.0[self.len() - index - 1] + } +} + +impl IndexMut for Reverse<&mut D> +where + D: Dimension, +{ + #[inline] + fn index_mut(&mut self, index: usize) -> &mut usize { + let len = self.len(); + &mut self.0[len - index - 1] + } +} + +/// Indexable sequence with length +pub(in crate::dimension) trait Sequence: Index { + fn len(&self) -> usize; +} + +/// Indexable sequence with length (mut) +pub(in crate::dimension) trait SequenceMut: Sequence + IndexMut { } + +impl Sequence for Forward<&D> where D: Dimension { + #[inline] + fn len(&self) -> usize { self.0.ndim() } +} + +impl Sequence for Forward<&mut D> where D: Dimension { + #[inline] + fn len(&self) -> usize { self.0.ndim() } +} + +impl SequenceMut for Forward<&mut D> where D: Dimension { } + +impl Sequence for Reverse<&D> where D: Dimension { + #[inline] + fn len(&self) -> usize { self.0.ndim() } +} + +impl Sequence for Reverse<&mut D> where D: Dimension { + #[inline] + fn len(&self) -> usize { self.0.ndim() } +} + +impl SequenceMut for Reverse<&mut D> where D: Dimension { } + From 2d42e5f0ff058abaca79f9eb734975a3967d7fa5 Mon Sep 17 00:00:00 2001 From: bluss Date: Sun, 9 May 2021 20:01:26 +0200 Subject: [PATCH 7/9] shape: Add function for reshape while preserving memory layout This function can compute the strides needed and if it's possible to reshape an array with given strides into a different shape. This happens using a given Order::{C, F}. --- src/dimension/mod.rs | 6 +- src/dimension/reshape.rs | 241 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 245 insertions(+), 2 deletions(-) create mode 100644 src/dimension/reshape.rs diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index e883933b2..4aa7c6641 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -9,10 +9,10 @@ use crate::error::{from_kind, ErrorKind, ShapeError}; use crate::slice::SliceArg; use crate::{Ix, Ixs, Slice, SliceInfoElem}; +use crate::shape_builder::Strides; use num_integer::div_floor; pub use self::axes::{Axes, AxisDescription}; -pub(crate) use self::axes::axes_of; pub use self::axis::Axis; pub use self::broadcast::DimMax; pub use self::conversion::IntoDimension; @@ -23,7 +23,8 @@ pub use self::ndindex::NdIndex; pub use self::ops::DimAdd; pub use self::remove_axis::RemoveAxis; -use crate::shape_builder::Strides; +pub(crate) use self::axes::axes_of; +pub(crate) use self::reshape::reshape_dim; use std::isize; use std::mem; @@ -40,6 +41,7 @@ mod dynindeximpl; mod ndindex; mod ops; mod remove_axis; +pub(crate) mod reshape; mod sequence; /// Calculate offset from `Ix` stride converting sign properly diff --git a/src/dimension/reshape.rs b/src/dimension/reshape.rs new file mode 100644 index 000000000..c6e08848d --- /dev/null +++ b/src/dimension/reshape.rs @@ -0,0 +1,241 @@ + +use crate::{Dimension, Order, ShapeError, ErrorKind}; +use crate::dimension::sequence::{Sequence, SequenceMut, Forward, Reverse}; + +#[inline] +pub(crate) fn reshape_dim(from: &D, strides: &D, to: &E, order: Order) + -> Result +where + D: Dimension, + E: Dimension, +{ + debug_assert_eq!(from.ndim(), strides.ndim()); + let mut to_strides = E::zeros(to.ndim()); + match order { + Order::RowMajor => { + reshape_dim_c(&Forward(from), &Forward(strides), + &Forward(to), Forward(&mut to_strides))?; + } + Order::ColumnMajor => { + reshape_dim_c(&Reverse(from), &Reverse(strides), + &Reverse(to), Reverse(&mut to_strides))?; + } + } + Ok(to_strides) +} + +/// Try to reshape an array with dimensions `from_dim` and strides `from_strides` to the new +/// dimension `to_dim`, while keeping the same layout of elements in memory. The strides needed +/// if this is possible are stored into `to_strides`. +/// +/// This function uses RowMajor index ordering if the inputs are read in the forward direction +/// (index 0 is axis 0 etc) and ColumnMajor index ordering if the inputs are read in reversed +/// direction (as made possible with the Sequence trait). +/// +/// Preconditions: +/// +/// 1. from_dim and to_dim are valid dimensions (product of all non-zero axes +/// fits in isize::MAX). +/// 2. from_dim and to_dim are don't have any axes that are zero (that should be handled before +/// this function). +/// 3. `to_strides` should be an all-zeros or all-ones dimension of the right dimensionality +/// (but it will be overwritten after successful exit of this function). +/// +/// This function returns: +/// +/// - IncompatibleShape if the two shapes are not of matching number of elements +/// - IncompatibleLayout if the input shape and stride can not be remapped to the output shape +/// without moving the array data into a new memory layout. +/// - Ok if the from dim could be mapped to the new to dim. +fn reshape_dim_c(from_dim: &D, from_strides: &D, to_dim: &E, mut to_strides: E2) + -> Result<(), ShapeError> +where + D: Sequence, + E: Sequence, + E2: SequenceMut, +{ + // cursor indexes into the from and to dimensions + let mut fi = 0; // index into `from_dim` + let mut ti = 0; // index into `to_dim`. + + while fi < from_dim.len() && ti < to_dim.len() { + let mut fd = from_dim[fi]; + let mut fs = from_strides[fi] as isize; + let mut td = to_dim[ti]; + + if fd == td { + to_strides[ti] = from_strides[fi]; + fi += 1; + ti += 1; + continue + } + + if fd == 1 { + fi += 1; + continue; + } + + if td == 1 { + to_strides[ti] = 1; + ti += 1; + continue; + } + + if fd == 0 || td == 0 { + debug_assert!(false, "zero dim not handled by this function"); + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); + } + + // stride times element count is to be distributed out over a combination of axes. + let mut fstride_whole = fs * (fd as isize); + let mut fd_product = fd; // cumulative product of axis lengths in the combination (from) + let mut td_product = td; // cumulative product of axis lengths in the combination (to) + + // The two axis lengths are not a match, so try to combine multiple axes + // to get it to match up. + while fd_product != td_product { + if fd_product < td_product { + // Take another axis on the from side + fi += 1; + if fi >= from_dim.len() { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); + } + fd = from_dim[fi]; + fd_product *= fd; + if fd > 1 { + let fs_old = fs; + fs = from_strides[fi] as isize; + // check if this axis and the next are contiguous together + if fs_old != fd as isize * fs { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout)); + } + } + } else { + // Take another axis on the `to` side + // First assign the stride to the axis we leave behind + fstride_whole /= td as isize; + to_strides[ti] = fstride_whole as usize; + ti += 1; + if ti >= to_dim.len() { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); + } + + td = to_dim[ti]; + td_product *= td; + } + } + + fstride_whole /= td as isize; + to_strides[ti] = fstride_whole as usize; + + fi += 1; + ti += 1; + } + + // skip past 1-dims at the end + while fi < from_dim.len() && from_dim[fi] == 1 { + fi += 1; + } + + while ti < to_dim.len() && to_dim[ti] == 1 { + to_strides[ti] = 1; + ti += 1; + } + + if fi < from_dim.len() || ti < to_dim.len() { + return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); + } + + Ok(()) +} + +#[cfg(feature = "std")] +#[test] +fn test_reshape() { + use crate::Dim; + + macro_rules! test_reshape { + (fail $order:ident from $from:expr, $stride:expr, to $to:expr) => { + let res = reshape_dim(&Dim($from), &Dim($stride), &Dim($to), Order::$order); + println!("Reshape {:?} {:?} to {:?}, order {:?}\n => {:?}", + $from, $stride, $to, Order::$order, res); + let _res = res.expect_err("Expected failed reshape"); + }; + (ok $order:ident from $from:expr, $stride:expr, to $to:expr, $to_stride:expr) => {{ + let res = reshape_dim(&Dim($from), &Dim($stride), &Dim($to), Order::$order); + println!("Reshape {:?} {:?} to {:?}, order {:?}\n => {:?}", + $from, $stride, $to, Order::$order, res); + println!("default stride for from dim: {:?}", Dim($from).default_strides()); + println!("default stride for to dim: {:?}", Dim($to).default_strides()); + let res = res.expect("Expected successful reshape"); + assert_eq!(res, Dim($to_stride), "mismatch in strides"); + }}; + } + + test_reshape!(ok C from [1, 2, 3], [6, 3, 1], to [1, 2, 3], [6, 3, 1]); + test_reshape!(ok C from [1, 2, 3], [6, 3, 1], to [2, 3], [3, 1]); + test_reshape!(ok C from [1, 2, 3], [6, 3, 1], to [6], [1]); + test_reshape!(fail C from [1, 2, 3], [6, 3, 1], to [1]); + test_reshape!(fail F from [1, 2, 3], [6, 3, 1], to [1]); + + test_reshape!(ok C from [6], [1], to [3, 2], [2, 1]); + test_reshape!(ok C from [3, 4, 5], [20, 5, 1], to [4, 15], [15, 1]); + + test_reshape!(ok C from [4, 4, 4], [16, 4, 1], to [16, 4], [4, 1]); + + test_reshape!(ok C from [4, 4], [4, 1], to [2, 2, 4, 1], [8, 4, 1, 1]); + test_reshape!(ok C from [4, 4], [4, 1], to [2, 2, 4], [8, 4, 1]); + test_reshape!(ok C from [4, 4], [4, 1], to [2, 2, 2, 2], [8, 4, 2, 1]); + + test_reshape!(ok C from [4, 4], [4, 1], to [2, 2, 1, 4], [8, 4, 1, 1]); + + test_reshape!(ok C from [4, 4, 4], [16, 4, 1], to [16, 4], [4, 1]); + test_reshape!(ok C from [3, 4, 4], [16, 4, 1], to [3, 16], [16, 1]); + + test_reshape!(ok C from [4, 4], [8, 1], to [2, 2, 2, 2], [16, 8, 2, 1]); + + test_reshape!(fail C from [4, 4], [8, 1], to [2, 1, 4, 2]); + + test_reshape!(ok C from [16], [4], to [2, 2, 4], [32, 16, 4]); + test_reshape!(ok C from [16], [-4isize as usize], to [2, 2, 4], + [-32isize as usize, -16isize as usize, -4isize as usize]); + test_reshape!(ok F from [16], [4], to [2, 2, 4], [4, 8, 16]); + test_reshape!(ok F from [16], [-4isize as usize], to [2, 2, 4], + [-4isize as usize, -8isize as usize, -16isize as usize]); + + test_reshape!(ok C from [3, 4, 5], [20, 5, 1], to [12, 5], [5, 1]); + test_reshape!(ok C from [3, 4, 5], [20, 5, 1], to [4, 15], [15, 1]); + test_reshape!(fail F from [3, 4, 5], [20, 5, 1], to [4, 15]); + test_reshape!(ok C from [3, 4, 5, 7], [140, 35, 7, 1], to [28, 15], [15, 1]); + + // preserve stride if shape matches + test_reshape!(ok C from [10], [2], to [10], [2]); + test_reshape!(ok F from [10], [2], to [10], [2]); + test_reshape!(ok C from [2, 10], [1, 2], to [2, 10], [1, 2]); + test_reshape!(ok F from [2, 10], [1, 2], to [2, 10], [1, 2]); + test_reshape!(ok C from [3, 4, 5], [20, 5, 1], to [3, 4, 5], [20, 5, 1]); + test_reshape!(ok F from [3, 4, 5], [20, 5, 1], to [3, 4, 5], [20, 5, 1]); + + test_reshape!(ok C from [3, 4, 5], [4, 1, 1], to [12, 5], [1, 1]); + test_reshape!(ok F from [3, 4, 5], [1, 3, 12], to [12, 5], [1, 12]); + test_reshape!(ok F from [3, 4, 5], [1, 3, 1], to [12, 5], [1, 1]); + + // broadcast shapes + test_reshape!(ok C from [3, 4, 5, 7], [0, 0, 7, 1], to [12, 35], [0, 1]); + test_reshape!(fail C from [3, 4, 5, 7], [0, 0, 7, 1], to [28, 15]); + + // one-filled shapes + test_reshape!(ok C from [10], [1], to [1, 10, 1, 1, 1], [1, 1, 1, 1, 1]); + test_reshape!(ok F from [10], [1], to [1, 10, 1, 1, 1], [1, 1, 1, 1, 1]); + test_reshape!(ok C from [1, 10], [10, 1], to [1, 10, 1, 1, 1], [10, 1, 1, 1, 1]); + test_reshape!(ok F from [1, 10], [10, 1], to [1, 10, 1, 1, 1], [10, 1, 1, 1, 1]); + test_reshape!(ok C from [1, 10], [1, 1], to [1, 5, 1, 1, 2], [1, 2, 2, 2, 1]); + test_reshape!(ok F from [1, 10], [1, 1], to [1, 5, 1, 1, 2], [1, 1, 5, 5, 5]); + test_reshape!(ok C from [10, 1, 1, 1, 1], [1, 1, 1, 1, 1], to [10], [1]); + test_reshape!(ok F from [10, 1, 1, 1, 1], [1, 1, 1, 1, 1], to [10], [1]); + test_reshape!(ok C from [1, 5, 1, 2, 1], [1, 2, 1, 1, 1], to [10], [1]); + test_reshape!(fail F from [1, 5, 1, 2, 1], [1, 2, 1, 1, 1], to [10]); + test_reshape!(ok F from [1, 5, 1, 2, 1], [1, 1, 1, 5, 1], to [10], [1]); + test_reshape!(fail C from [1, 5, 1, 2, 1], [1, 1, 1, 5, 1], to [10]); +} + From c9af569acbadd91224047a957fc79a0df99f5719 Mon Sep 17 00:00:00 2001 From: bluss Date: Sun, 9 May 2021 20:01:26 +0200 Subject: [PATCH 8/9] shape: Use reshape_dim function in .to_shape() --- src/impl_methods.rs | 45 ++++++++++++++++++++++++++++----------------- tests/reshape.rs | 21 ++++++++++++++++++++- 2 files changed, 48 insertions(+), 18 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index cda2cd95e..6c51b4515 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -23,11 +23,11 @@ use crate::dimension::{ offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes, }; use crate::dimension::broadcast::co_broadcast; +use crate::dimension::reshape_dim; use crate::error::{self, ErrorKind, ShapeError, from_kind}; use crate::math_cell::MathCell; use crate::itertools::zip; use crate::AxisDescription; -use crate::Layout; use crate::order::Order; use crate::shape_builder::ShapeArg; use crate::zip::{IntoNdProducer, Zip}; @@ -1641,27 +1641,38 @@ where A: Clone, S: Data, { - if size_of_shape_checked(&shape) != Ok(self.dim.size()) { + let len = self.dim.size(); + if size_of_shape_checked(&shape) != Ok(len) { return Err(error::incompatible_shapes(&self.dim, &shape)); } - let layout = self.layout_impl(); - unsafe { - if layout.is(Layout::CORDER) && order == Order::RowMajor { - let strides = shape.default_strides(); - Ok(CowArray::from(ArrayView::new(self.ptr, shape, strides))) - } else if layout.is(Layout::FORDER) && order == Order::ColumnMajor { - let strides = shape.fortran_strides(); - Ok(CowArray::from(ArrayView::new(self.ptr, shape, strides))) - } else { - let (shape, view) = match order { - Order::RowMajor => (shape.set_f(false), self.view()), - Order::ColumnMajor => (shape.set_f(true), self.t()), - }; - Ok(CowArray::from(Array::from_shape_trusted_iter_unchecked( - shape, view.into_iter(), A::clone))) + // Create a view if the length is 0, safe because the array and new shape is empty. + if len == 0 { + unsafe { + return Ok(CowArray::from(ArrayView::from_shape_ptr(shape, self.as_ptr()))); } } + + // Try to reshape the array as a view into the existing data + match reshape_dim(&self.dim, &self.strides, &shape, order) { + Ok(to_strides) => unsafe { + return Ok(CowArray::from(ArrayView::new(self.ptr, shape, to_strides))); + } + Err(err) if err.kind() == ErrorKind::IncompatibleShape => { + return Err(error::incompatible_shapes(&self.dim, &shape)); + } + _otherwise => { } + } + + // otherwise create a new array and copy the elements + unsafe { + let (shape, view) = match order { + Order::RowMajor => (shape.set_f(false), self.view()), + Order::ColumnMajor => (shape.set_f(true), self.t()), + }; + Ok(CowArray::from(Array::from_shape_trusted_iter_unchecked( + shape, view.into_iter(), A::clone))) + } } /// Transform the array into `shape`; any shape with the same number of diff --git a/tests/reshape.rs b/tests/reshape.rs index f03f4ccf1..19f5b4ae1 100644 --- a/tests/reshape.rs +++ b/tests/reshape.rs @@ -133,7 +133,7 @@ fn to_shape_add_axis() { let u = v.to_shape(((4, 2), Order::RowMajor)).unwrap(); assert!(u.to_shape(((1, 4, 2), Order::RowMajor)).unwrap().is_view()); - assert!(u.to_shape(((1, 4, 2), Order::ColumnMajor)).unwrap().is_owned()); + assert!(u.to_shape(((1, 4, 2), Order::ColumnMajor)).unwrap().is_view()); } @@ -150,6 +150,16 @@ fn to_shape_copy_stride() { assert!(lin2.is_owned()); } + +#[test] +fn to_shape_zero_len() { + let v = array![[1, 2, 3, 4], [5, 6, 7, 8]]; + let vs = v.slice(s![.., ..0]); + let lin1 = vs.to_shape(0).unwrap(); + assert_eq!(lin1, array![]); + assert!(lin1.is_view()); +} + #[test] #[should_panic(expected = "IncompatibleShape")] fn to_shape_error1() { @@ -157,3 +167,12 @@ fn to_shape_error1() { let v = aview1(&data); let _u = v.to_shape((2, 5)).unwrap(); } + +#[test] +#[should_panic(expected = "IncompatibleShape")] +fn to_shape_error2() { + // overflow + let data = [3, 4, 5, 6, 7, 8]; + let v = aview1(&data); + let _u = v.to_shape((2, usize::MAX)).unwrap(); +} From 4467cd8b51e2b24229319aa27ede2c2483476354 Mon Sep 17 00:00:00 2001 From: bluss Date: Mon, 10 May 2021 19:50:47 +0200 Subject: [PATCH 9/9] shape: Add more .to_shape() tests --- tests/reshape.rs | 54 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/tests/reshape.rs b/tests/reshape.rs index 19f5b4ae1..21fe407ea 100644 --- a/tests/reshape.rs +++ b/tests/reshape.rs @@ -176,3 +176,57 @@ fn to_shape_error2() { let v = aview1(&data); let _u = v.to_shape((2, usize::MAX)).unwrap(); } + +#[test] +fn to_shape_discontig() { + for &create_order in &[Order::C, Order::F] { + let a = Array::from_iter(0..64); + let mut a1 = a.to_shape(((4, 4, 4), create_order)).unwrap(); + a1.slice_collapse(s![.., ..;2, ..]); // now shape (4, 2, 4) + assert!(a1.as_slice_memory_order().is_none()); + + for &order in &[Order::C, Order::F] { + let v1 = a1.to_shape(((2, 2, 2, 2, 2), order)).unwrap(); + assert!(v1.is_view()); + let v1 = a1.to_shape(((4, 1, 2, 1, 2, 2), order)).unwrap(); + assert!(v1.is_view()); + let v1 = a1.to_shape(((4, 2, 4), order)).unwrap(); + assert!(v1.is_view()); + let v1 = a1.to_shape(((8, 4), order)).unwrap(); + assert_eq!(v1.is_view(), order == create_order && create_order == Order::C, + "failed for {:?}, {:?}", create_order, order); + let v1 = a1.to_shape(((4, 8), order)).unwrap(); + assert_eq!(v1.is_view(), order == create_order && create_order == Order::F, + "failed for {:?}, {:?}", create_order, order); + let v1 = a1.to_shape((32, order)).unwrap(); + assert!(!v1.is_view()); + } + } +} + +#[test] +fn to_shape_broadcast() { + for &create_order in &[Order::C, Order::F] { + let a = Array::from_iter(0..64); + let mut a1 = a.to_shape(((4, 4, 4), create_order)).unwrap(); + a1.slice_collapse(s![.., ..1, ..]); // now shape (4, 1, 4) + let v1 = a1.broadcast((4, 4, 4)).unwrap(); // Now shape (4, 4, 4) + assert!(v1.as_slice_memory_order().is_none()); + + for &order in &[Order::C, Order::F] { + let v2 = v1.to_shape(((2, 2, 2, 2, 2, 2), order)).unwrap(); + assert_eq!(v2.strides(), match (create_order, order) { + (Order::C, Order::C) => { &[32, 16, 0, 0, 2, 1] } + (Order::C, Order::F) => { &[16, 32, 0, 0, 1, 2] } + (Order::F, Order::C) => { &[2, 1, 0, 0, 32, 16] } + (Order::F, Order::F) => { &[1, 2, 0, 0, 16, 32] } + _other => unreachable!() + }); + + let v2 = v1.to_shape(((4, 4, 4), order)).unwrap(); + assert!(v2.is_view()); + let v2 = v1.to_shape(((8, 8), order)).unwrap(); + assert!(v2.is_owned()); + } + } +} pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy