Skip to content

Commit b70085c

Browse files
committed
shape: Use reshape_dim function in .to_shape()
1 parent f96977f commit b70085c

File tree

2 files changed

+48
-18
lines changed

2 files changed

+48
-18
lines changed

src/impl_methods.rs

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ use crate::dimension::{
2323
offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes,
2424
};
2525
use crate::dimension::broadcast::co_broadcast;
26+
use crate::dimension::reshape_dim;
2627
use crate::error::{self, ErrorKind, ShapeError, from_kind};
2728
use crate::math_cell::MathCell;
2829
use crate::itertools::zip;
2930
use crate::AxisDescription;
30-
use crate::Layout;
3131
use crate::order::Order;
3232
use crate::shape_builder::ShapeArg;
3333
use crate::zip::{IntoNdProducer, Zip};
@@ -1641,27 +1641,38 @@ where
16411641
A: Clone,
16421642
S: Data,
16431643
{
1644-
if size_of_shape_checked(&shape) != Ok(self.dim.size()) {
1644+
let len = self.dim.size();
1645+
if size_of_shape_checked(&shape) != Ok(len) {
16451646
return Err(error::incompatible_shapes(&self.dim, &shape));
16461647
}
1647-
let layout = self.layout_impl();
16481648

1649-
unsafe {
1650-
if layout.is(Layout::CORDER) && order == Order::RowMajor {
1651-
let strides = shape.default_strides();
1652-
Ok(CowArray::from(ArrayView::new(self.ptr, shape, strides)))
1653-
} else if layout.is(Layout::FORDER) && order == Order::ColumnMajor {
1654-
let strides = shape.fortran_strides();
1655-
Ok(CowArray::from(ArrayView::new(self.ptr, shape, strides)))
1656-
} else {
1657-
let (shape, view) = match order {
1658-
Order::RowMajor => (shape.set_f(false), self.view()),
1659-
Order::ColumnMajor => (shape.set_f(true), self.t()),
1660-
};
1661-
Ok(CowArray::from(Array::from_shape_trusted_iter_unchecked(
1662-
shape, view.into_iter(), A::clone)))
1649+
// Create a view if the length is 0, safe because the array and new shape is empty.
1650+
if len == 0 {
1651+
unsafe {
1652+
return Ok(CowArray::from(ArrayView::from_shape_ptr(shape, self.as_ptr())));
16631653
}
16641654
}
1655+
1656+
// Try to reshape the array as a view into the existing data
1657+
match reshape_dim(&self.dim, &self.strides, &shape, order) {
1658+
Ok(to_strides) => unsafe {
1659+
return Ok(CowArray::from(ArrayView::new(self.ptr, shape, to_strides)));
1660+
}
1661+
Err(err) if err.kind() == ErrorKind::IncompatibleShape => {
1662+
return Err(error::incompatible_shapes(&self.dim, &shape));
1663+
}
1664+
_otherwise => { }
1665+
}
1666+
1667+
// otherwise create a new array and copy the elements
1668+
unsafe {
1669+
let (shape, view) = match order {
1670+
Order::RowMajor => (shape.set_f(false), self.view()),
1671+
Order::ColumnMajor => (shape.set_f(true), self.t()),
1672+
};
1673+
Ok(CowArray::from(Array::from_shape_trusted_iter_unchecked(
1674+
shape, view.into_iter(), A::clone)))
1675+
}
16651676
}
16661677

16671678
/// Transform the array into `shape`; any shape with the same number of

tests/reshape.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ fn to_shape_add_axis() {
133133
let u = v.to_shape(((4, 2), Order::RowMajor)).unwrap();
134134

135135
assert!(u.to_shape(((1, 4, 2), Order::RowMajor)).unwrap().is_view());
136-
assert!(u.to_shape(((1, 4, 2), Order::ColumnMajor)).unwrap().is_owned());
136+
assert!(u.to_shape(((1, 4, 2), Order::ColumnMajor)).unwrap().is_view());
137137
}
138138

139139

@@ -150,10 +150,29 @@ fn to_shape_copy_stride() {
150150
assert!(lin2.is_owned());
151151
}
152152

153+
154+
#[test]
155+
fn to_shape_zero_len() {
156+
let v = array![[1, 2, 3, 4], [5, 6, 7, 8]];
157+
let vs = v.slice(s![.., ..0]);
158+
let lin1 = vs.to_shape(0).unwrap();
159+
assert_eq!(lin1, array![]);
160+
assert!(lin1.is_view());
161+
}
162+
153163
#[test]
154164
#[should_panic(expected = "IncompatibleShape")]
155165
fn to_shape_error1() {
156166
let data = [1, 2, 3, 4, 5, 6, 7, 8];
157167
let v = aview1(&data);
158168
let _u = v.to_shape((2, 5)).unwrap();
159169
}
170+
171+
#[test]
172+
#[should_panic(expected = "IncompatibleShape")]
173+
fn to_shape_error2() {
174+
// overflow
175+
let data = [3, 4, 5, 6, 7, 8];
176+
let v = aview1(&data);
177+
let _u = v.to_shape((2, usize::MAX)).unwrap();
178+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy