From da365be33b6b7fbe60e63aaf2e8b3dc0671ddbe8 Mon Sep 17 00:00:00 2001 From: bluss Date: Sun, 28 Mar 2021 18:50:51 +0200 Subject: [PATCH 1/3] TEST: Add benchmarks for Result<_, ShapeError> methods --- benches/error-handling.rs | 110 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 benches/error-handling.rs diff --git a/benches/error-handling.rs b/benches/error-handling.rs new file mode 100644 index 000000000..0c2b1136c --- /dev/null +++ b/benches/error-handling.rs @@ -0,0 +1,110 @@ +#![feature(test)] +#![allow( + clippy::many_single_char_names, + clippy::deref_addrof, + clippy::unreadable_literal, + clippy::many_single_char_names +)] +extern crate test; +use test::Bencher; + +use ndarray::prelude::*; +use ndarray::ErrorKind; + +// Use ZST elements to remove allocation from the benchmarks + +#[derive(Copy, Clone, Debug)] +struct Zst; + +type A4 = Array4; + +#[bench] +fn from_elem(bench: &mut Bencher) { + bench.iter(|| { + A4::from_elem((1, 2, 3, 4), Zst) + }) +} + +#[bench] +fn from_shape_vec_ok(bench: &mut Bencher) { + bench.iter(|| { + let v: Vec = vec![Zst; 1 * 2 * 3 * 4]; + let x = A4::from_shape_vec((1, 2, 3, 4).strides((24, 12, 4, 1)), v); + debug_assert!(x.is_ok(), "problem with {:?}", x); + x + }) +} + +#[bench] +fn from_shape_vec_fail(bench: &mut Bencher) { + bench.iter(|| { + let v: Vec = vec![Zst; 1 * 2 * 3 * 4]; + let x = A4::from_shape_vec((1, 2, 3, 4).strides((4, 3, 2, 1)), v); + debug_assert!(x.is_err()); + x + }) +} + +#[bench] +fn into_shape_fail(bench: &mut Bencher) { + let a = A4::from_elem((1, 2, 3, 4), Zst); + let v = a.view(); + bench.iter(|| { + v.clone().into_shape((5, 3, 2, 1)) + }) +} + +#[bench] +fn into_shape_ok_c(bench: &mut Bencher) { + let a = A4::from_elem((1, 2, 3, 4), Zst); + let v = a.view(); + bench.iter(|| { + v.clone().into_shape((4, 3, 2, 1)) + }) +} + +#[bench] +fn into_shape_ok_f(bench: &mut Bencher) { + let a = A4::from_elem((1, 2, 3, 4).f(), Zst); + let v = a.view(); + bench.iter(|| { + v.clone().into_shape((4, 3, 2, 1)) + }) +} + +#[bench] +fn stack_ok(bench: &mut Bencher) { + let a = Array::from_elem((15, 15), Zst); + let rows = a.rows().into_iter().collect::>(); + bench.iter(|| { + let res = ndarray::stack(Axis(1), &rows); + debug_assert!(res.is_ok(), "err {:?}", res); + res + }); +} + +#[bench] +fn stack_err_axis(bench: &mut Bencher) { + let a = Array::from_elem((15, 15), Zst); + let rows = a.rows().into_iter().collect::>(); + bench.iter(|| { + let res = ndarray::stack(Axis(2), &rows); + debug_assert!(res.is_err()); + res + }); +} + +#[bench] +fn stack_err_shape(bench: &mut Bencher) { + let a = Array::from_elem((15, 15), Zst); + let rows = a.rows().into_iter() + .enumerate() + .map(|(i, mut row)| { row.slice_collapse(s![..(i as isize)]); row }) + .collect::>(); + bench.iter(|| { + let res = ndarray::stack(Axis(1), &rows); + debug_assert!(res.is_err()); + debug_assert_eq!(res.clone().unwrap_err().kind(), ErrorKind::IncompatibleShape); + res + }); +} From 3a59caafa91c0266d1427f3e072619133569b622 Mon Sep 17 00:00:00 2001 From: bluss Date: Sun, 28 Mar 2021 13:33:31 +0200 Subject: [PATCH 2/3] FEAT: Encode expected/actual info in ShapeError MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a (limited) way to add specific information to a ShapeError Admittedly wonky, but space efficient. Result<(), ShapeError> used to be 1 byte, and with this change it expands to 16 bytes (2 usize on 64-bit). The remaining 15 bytes are used for optimistically packing as much of extra info into the error message as possible. For example we can store expected/actual index for errors (for example index out of bounds or axis out of bounds, these are not so commonly handled with ShapeError). With this change it is supported: - Expected/actual index with 7 bytes per index - Expected/actual shape with 7 bytes per shape supports storing shapes with one or two bytes (< 256²) per dimension, with limited ndim. --- Cargo.toml | 1 + src/error.rs | 525 +++++++++++++++++++++++++++++++++++++++++++++- tests/array.rs | 8 +- tests/stacking.rs | 6 + 4 files changed, 530 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 61c26d4cd..8288e6dc1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,6 +49,7 @@ defmac = "0.2" quickcheck = { version = "0.9", default-features = false } approx = "0.4" itertools = { version = "0.10.0", default-features = false, features = ["use_std"] } +matches = "0.1.8" [features] default = ["std"] diff --git a/src/error.rs b/src/error.rs index c45496142..177711c9b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -5,16 +5,26 @@ // , at your // option. This file may not be copied, modified, or distributed // except according to those terms. +#![allow(clippy::identity_op)] use super::Dimension; +use crate::itertools::enumerate; + #[cfg(feature = "std")] use std::error::Error; use std::fmt; +use std::mem::size_of; /// An error related to array shape or layout. +/// +/// The shape error encodes and shows expected/actual indices and shapes in some cases, which +/// is visible in the Display/Debug representation. Since this is done without allocation, it is +/// space-limited and bigger indices and shapes may not be representable. #[derive(Clone)] pub struct ShapeError { - // we want to be able to change this representation later + /// Error category repr: ErrorKind, + /// Additional info + info: InfoType, } impl ShapeError { @@ -24,9 +34,52 @@ impl ShapeError { self.repr } - /// Create a new `ShapeError` + /// Create a new `ShapeError` from the given kind pub fn from_kind(error: ErrorKind) -> Self { - from_kind(error) + Self::from_kind_info(error, info_default()) + } + + fn from_kind_info(repr: ErrorKind, info: InfoType) -> Self { + ShapeError { repr, info } + } + + pub(crate) fn invalid_axis(expected: usize, actual: usize) -> Self { + // TODO: OutOfBounds for compatibility reasons, should be more specific + Self::from_kind_info(ErrorKind::OutOfBounds, encode_indices(expected, actual)) + } + + pub(crate) fn shape_length_exceeds_data_length(expected: usize, actual: usize) -> Self { + // TODO: OutOfBounds for compatibility reasons, should be more specific + Self::from_kind_info(ErrorKind::OutOfBounds, encode_indices(expected, actual)) + } + + pub(crate) fn incompatible_layout(expected: ExpectedLayout) -> Self { + Self::from_kind_info(ErrorKind::IncompatibleLayout, encode_indices(expected as usize, 0)) + } + + pub(crate) fn incompatible_shapes(expected: &D, actual: &E) -> ShapeError + where + D: Dimension, + E: Dimension, + { + Self::from_kind_info(ErrorKind::IncompatibleShape, encode_shapes(expected, actual)) + } + + #[cfg(test)] + fn info_expected_index(&self) -> Option { + let (exp, _) = decode_indices(self.info); + exp + } + + #[cfg(test)] + fn info_actual_index(&self) -> Option { + let (_, actual) = decode_indices(self.info); + actual + } + + #[cfg(test)] + fn decode_shapes(&self) -> (Option, Option) { + decode_shapes(self.info) } } @@ -38,12 +91,15 @@ impl ShapeError { #[derive(Copy, Clone, Debug)] pub enum ErrorKind { /// incompatible shape + // encodes info: expected and actual shape IncompatibleShape = 1, /// incompatible memory layout + // encodes info: expected layout IncompatibleLayout, /// the shape does not fit inside type limits RangeLimited, /// out of bounds indexing + // encodes info: expected and actual index OutOfBounds, /// aliasing array elements Unsupported, @@ -52,8 +108,8 @@ pub enum ErrorKind { } #[inline(always)] -pub fn from_kind(k: ErrorKind) -> ShapeError { - ShapeError { repr: k } +pub fn from_kind(error: ErrorKind) -> ShapeError { + ShapeError::from_kind(error) } impl PartialEq for ErrorKind { @@ -83,7 +139,7 @@ impl fmt::Display for ShapeError { ErrorKind::Unsupported => "unsupported operation", ErrorKind::Overflow => "arithmetic overflow", }; - write!(f, "ShapeError/{:?}: {}", self.kind(), description) + write!(f, "ShapeError/{:?}: {}{}", self.kind(), description, ExtendedInfo(&self)) } } @@ -93,10 +149,463 @@ impl fmt::Debug for ShapeError { } } -pub fn incompatible_shapes(_a: &D, _b: &E) -> ShapeError +pub(crate) enum ExpectedLayout { + ContiguousCF = 1, + Unused, +} + +impl From> for ExpectedLayout { + #[inline] + fn from(x: Option) -> Self { + match x { + Some(1) => ExpectedLayout::ContiguousCF, + _ => ExpectedLayout::Unused, + } + } +} + +pub(crate) fn incompatible_shapes(_a: &D, _b: &E) -> ShapeError where D: Dimension, E: Dimension, { - from_kind(ErrorKind::IncompatibleShape) + ShapeError::incompatible_shapes(_a, _b) +} + +/// The InfoType encodes extra information per error kind, for example expected/actual axis for a +/// given error site, or expected layout for a layout error. +/// +/// It uses a custom and fixed-width (very limited) encoding; in some cases it skips filling in +/// information because it doesn't fit. +/// +/// Two bits in the first byte are reserved for EncodedInformationType, the rest is used +/// for situation-specific encoding of extra info. +/// If the first byte is zero, it shows that there is no encoded info. +type InfoType = [u8; INFO_TYPE_LEN]; + +const INFO_TYPE_LEN: usize = 15; +const INFO_BYTES: usize = INFO_TYPE_LEN - 1; + +const fn info_default() -> InfoType { [0; INFO_TYPE_LEN] } + +#[repr(u8)] +// 2 bits +enum EncodedInformationType { + Nothing = 0, + Expected = 0b1, + Actual = 0b10, +} + +const IXBYTES: usize = INFO_BYTES / 2; + +fn encode_index(x: usize) -> Option<[u8; IXBYTES]> { + let bits = size_of::() * 8; + let used_bits = bits - x.leading_zeros() as usize; + if used_bits > IXBYTES * 8 { + None + } else { + let bytes = x.to_le_bytes(); + let mut result = [0; IXBYTES]; + let len = bytes.len().min(result.len()); + result[..len].copy_from_slice(&bytes[..len]); + Some(result) + } +} + +fn decode_index(x: &[u8]) -> usize { + let mut bytes = 0usize.to_le_bytes(); + let len = x.len().min(bytes.len()); + bytes[..len].copy_from_slice(&x[..len]); + usize::from_le_bytes(bytes) +} + +fn encode_indices(expected: usize, actual: usize) -> InfoType { + let eexp = encode_index(expected); + let eact = encode_index(actual); + let mut info = info_default(); + let mut info_type = EncodedInformationType::Nothing as u8; + if eexp.is_some() { + info_type |= EncodedInformationType::Expected as u8; + } + let (ebytes, abytes) = info[1..].split_at_mut(IXBYTES); + if let Some(exp) = eexp { + ebytes.copy_from_slice(&exp); + } + if eact.is_some() { + info_type |= EncodedInformationType::Actual as u8; + } + if let Some(act) = eact { + abytes.copy_from_slice(&act); + } + info[0] = info_type; + info +} + +fn decode_indices(info: InfoType) -> (Option, Option) { + let (ebytes, abytes) = info[1..].split_at(IXBYTES); + ( + if info[0] & (EncodedInformationType::Expected as u8) != 0 { + Some(decode_index(ebytes)) + } else { None }, + if info[0] & (EncodedInformationType::Actual as u8) != 0 { + Some(decode_index(abytes)) + } else { None }, + ) + +} + +fn encode_shapes(expected: &D, actual: &E) -> InfoType +where + D: Dimension, + E: Dimension, +{ + encode_shapes_impl(expected.slice(), actual.slice()) +} + +// Shape encoding +// +// 15 bytes to use +// 1 byte: +// 1 bit expected shape is two-byte encoded yes/no (a) +// 3 bits expected shape len (len 0..=7) (b) +// 1 bit actual shape is two-byte encoded yes/no (c) +// 3 bits actual shape len (len 0..=7) (d) +// 14 bytes encoding of expected shape and actual shape: +// X bytes for expected: X = (a + 1) * (b) bytes +// then directly following: +// Y bytes for actual: Y = (c + 1) * (d) bytes + +const SHAPE_MAX_LEN: usize = (1 << 3) - 1; + +struct ShapeEncoding { + len: usize, + element_width: usize, + data: [u8; INFO_BYTES - 1], +} + +#[derive(Copy, Clone)] +enum EncodingWidth { + One = 1, + Two = 2, +} + +fn encode_shape(shape: &[usize], use_width: EncodingWidth) -> ShapeEncoding { + let mut info = [0; INFO_BYTES - 1]; + match use_width { + EncodingWidth::One => { + for (i, &d) in enumerate(shape) { + debug_assert!(d < 256); + info[i] = d as u8; + } + ShapeEncoding { + len: shape.len(), + element_width: 1, + data: info, + } + } + EncodingWidth::Two => { + for (i, &d) in enumerate(shape) { + debug_assert!(d < 256 * 256); + let dbytes = d.to_le_bytes(); + info[2 * i] = dbytes[0]; + info[2 * i + 1] = dbytes[1]; + } + ShapeEncoding { + len: shape.len(), + element_width: 2, + data: info, + } + } + } +} + +fn encode_shapes_impl(expected: &[usize], actual: &[usize]) -> InfoType { + let exp_onebyte = expected.iter().all(|&i| i < 256); + let exp_fit = exp_onebyte && expected.len() <= SHAPE_MAX_LEN || + expected.iter().all(|&i| i < 256 * 256) && expected.len() <= (INFO_BYTES - 1) / 2; + let act_onebyte = actual.iter().all(|&i| i < 256); + + let mut info = info_default(); + let mut info_type = EncodedInformationType::Nothing as u8; + let mut shape_header = 0; + + let mut remaining_len = INFO_BYTES - 1; + if exp_fit { + info_type |= EncodedInformationType::Expected as u8; + let eexp = encode_shape(expected, if exp_onebyte { EncodingWidth::One } else { EncodingWidth::Two }); + shape_header |= (!exp_onebyte as u8) << 0; + + info[2..].copy_from_slice(&eexp.data[..]); + remaining_len -= eexp.len * eexp.element_width; + shape_header |= (eexp.len as u8) << 1; + } + + if remaining_len > 0 { + if (act_onebyte && remaining_len >= actual.len()) || + remaining_len / 2 >= actual.len() + { + info_type |= EncodedInformationType::Actual as u8; + let eact = encode_shape(actual, if act_onebyte { EncodingWidth::One } else { EncodingWidth::Two }); + shape_header |= (!act_onebyte as u8) << 4; + let data_start = INFO_BYTES - 1 - remaining_len; + + info[2 + data_start..].copy_from_slice(&eact.data[..remaining_len]); + shape_header |= (eact.len as u8) << 5; + } else { + // skip encoding + } + } + info[0] = info_type; + info[1] = shape_header; + info +} + +#[derive(Default)] +#[cfg_attr(test, derive(Debug))] +struct DecodedShape { + len: usize, + shape: [usize; 8], +} + +impl DecodedShape { + fn as_slice(&self) -> &[usize] { + &self.shape[..self.len] + } +} + +fn decode_shape(data: &[u8], len: usize, width: EncodingWidth) -> DecodedShape { + debug_assert!(len * (width as usize) <= data.len(), + "Too short data when decoding shape"); + let mut shape = DecodedShape { len, ..<_>::default() }; + match width { + EncodingWidth::One => { + for (i, &d) in (0..len).zip(data) { + shape.shape[i] = d as usize; + } + } + EncodingWidth::Two => { + for i in 0..len { + let mut bytes = 0usize.to_le_bytes(); + bytes[0] = data[2 * i]; + bytes[1] = data[2 * i + 1]; + shape.shape[i] = usize::from_le_bytes(bytes); + } + } + } + shape +} + +fn decode_shapes(info: InfoType) -> (Option, Option) { + let exp_present = info[0] & (EncodedInformationType::Expected as u8) != 0; + let act_present = info[0] & (EncodedInformationType::Actual as u8) != 0; + let exp_twobyte = ((info[1] >> 0) & 0b1) != 0; + let act_twobyte = ((info[1] >> 4) & 0b1) != 0; + let exp_len_mask = if !act_present { !0u8 } else { (1u8 << 3) - 1 }; + let exp_len = ((info[1] >> 1) & exp_len_mask) as usize; + let act_len = ((info[1] >> 5) & 0b111) as usize; + let mut start = 2; + let exp = if exp_present { + let width = if !exp_twobyte { EncodingWidth::One } else { EncodingWidth::Two }; + let exp = decode_shape(&info[start..], exp_len, width); + start += exp_len * width as usize; + Some(exp) + } else { None }; + let act = if act_present { + let width = if !act_twobyte { EncodingWidth::One } else { EncodingWidth::Two }; + let act = decode_shape(&info[start..], act_len, width); + Some(act) + } else { None }; + (exp, act) +} + +#[derive(Copy, Clone)] +struct ExtendedInfo<'a>(&'a ShapeError); + +impl<'a> ExtendedInfo<'a> { + fn has_info(&self) -> bool { + self.0.info[0] != EncodedInformationType::Nothing as u8 + } +} + +impl<'a> fmt::Display for ExtendedInfo<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if !self.has_info() { + return Ok(()); + } + // Use the wording of "expected: X, but got: Y" for + // expected and actual part of the error exented info. + match self.0.kind() { + ErrorKind::IncompatibleLayout => { + let (expected, _) = decode_indices(self.0.info); + match ExpectedLayout::from(expected) { + ExpectedLayout::ContiguousCF => { + write!(f, "; expected c- or f-contiguous input")?; + } + ExpectedLayout::Unused => {} + } + } + ErrorKind::IncompatibleShape => { + let (expected, actual) = decode_shapes(self.0.info); + write!(f, "; expected compatible: ")?; + if let Some(value) = expected { + write!(f, "{:?}", value.as_slice())?; + } else { + write!(f, "unknown")?; + } + if let Some(value) = actual { + write!(f, ", but got: {:?}", value.as_slice())?; + } else { + write!(f, "unknown")?; + } + } + _otherwise => { + let (expected, actual) = decode_indices(self.0.info); + write!(f, "; expected: ")?; + if let Some(value) = expected { + write!(f, "{}", value)?; + } else { + write!(f, "unknown")?; + } + + write!(f, ", but got: ")?; + if let Some(value) = actual { + write!(f, "{}", value)?; + } else { + write!(f, "unknown")?; + } + } + } + Ok(()) + } +} + + +#[cfg(test)] +use matches::assert_matches; +#[cfg(test)] +use crate::IntoDimension; + +#[test] +fn test_sizes() { + assert!(size_of::() <= size_of::()); + assert!(size_of::() <= 16); + + assert_eq!(size_of::>(), size_of::()); +} + +#[test] +fn test_encode_decode_format() { + use alloc::string::ToString; + + assert_eq!( + ShapeError::invalid_axis(1, 0).to_string(), + "ShapeError/OutOfBounds: out of bounds indexing; expected: 1, but got: 0"); + + if size_of::() > 4 { + assert_eq!( + ShapeError::invalid_axis(usize::MAX, usize::MAX).to_string(), + "ShapeError/OutOfBounds: out of bounds indexing"); + } + + assert_eq!( + ShapeError::incompatible_shapes(&(1, 2, 3).into_dimension(), &(2, 3).into_dimension()) + .to_string(), + "ShapeError/IncompatibleShape: \ + incompatible shapes; expected compatible: [1, 2, 3], but got: [2, 3]"); +} + +#[test] +fn test_encode_decode() { + for &i in [0, 1, 2, 3, 10, 32, 256, 1736, 16300].iter() { + let err = ShapeError::invalid_axis(i, 0); + assert_eq!(err.info_expected_index(), Some(i)); + let err = ShapeError::invalid_axis(0, i); + assert_eq!(err.info_actual_index(), Some(i)); + } + + let err = ShapeError::invalid_axis(1 << 24, (1 << 24) + 1); + assert_eq!(err.info_expected_index(), Some(1 << 24)); + assert_eq!(err.info_actual_index(), Some((1 << 24) + 1)); + + if size_of::() > 4 { + // use .wrapping_shl(_) for portability + let err = ShapeError::invalid_axis(1usize.wrapping_shl(56) - 1, 0); + assert_eq!(err.info_expected_index(), Some(1usize.wrapping_shl(56) - 1)); + assert_eq!(err.info_actual_index(), Some(0)); + + let err = ShapeError::invalid_axis(1usize.wrapping_shl(56), 1usize.wrapping_shl(56)); + assert_eq!(err.info_expected_index(), None); + assert_eq!(err.info_actual_index(), None); + + let err = ShapeError::invalid_axis(usize::MAX, usize::MAX); + assert_eq!(err.info_expected_index(), None); + assert_eq!(err.info_actual_index(), None); + } +} + +#[test] +fn test_encode_decode_shape() { + let err = ShapeError::incompatible_shapes(&(1, 2).into_dimension(), &(4, 5).into_dimension()); + let (exp, act) = err.decode_shapes(); + assert_eq!(exp.unwrap().as_slice(), &[1, 2]); + assert_eq!(act.unwrap().as_slice(), &[4, 5]); + + let err = ShapeError::incompatible_shapes(&(1, 2, 3).into_dimension(), &(4, 5, 6).into_dimension()); + let (exp, act) = err.decode_shapes(); + assert_eq!(exp.unwrap().as_slice(), &[1, 2, 3]); + assert_eq!(act.unwrap().as_slice(), &[4, 5, 6]); + + let err = ShapeError::incompatible_shapes(&().into_dimension(), &().into_dimension()); + let (exp, act) = err.decode_shapes(); + assert_eq!(exp.unwrap().as_slice(), &[]); + assert_eq!(act.unwrap().as_slice(), &[]); + + let (m, n) = (256, 768); + let err = ShapeError::incompatible_shapes(&(m, n).into_dimension(), &(m + 1, n + 1).into_dimension()); + let (exp, act) = err.decode_shapes(); + assert_eq!(exp.unwrap().as_slice(), &[m, n]); + assert_eq!(act.unwrap().as_slice(), &[m + 1, n + 1]); + //assert!(act.is_none()); + + let (m, n) = (256, 768); + let err = ShapeError::incompatible_shapes(&(m, n).into_dimension(), &(m + 1).into_dimension()); + let (exp, act) = err.decode_shapes(); + assert_eq!(exp.unwrap().as_slice(), &[m, n]); + assert_eq!(act.unwrap().as_slice(), &[m + 1]); + + let (m, n) = (256, 768); + let err = ShapeError::incompatible_shapes(&m.into_dimension(), &(m + 1, n + 1).into_dimension()); + let (exp, act) = err.decode_shapes(); + assert_eq!(exp.unwrap().as_slice(), &[m]); + assert_eq!(act.unwrap().as_slice(), &[m + 1, n + 1]); + + let err = ShapeError::incompatible_shapes(&(768, 2, 1024).into_dimension(), &(4, 500, 6).into_dimension()); + let (exp, act) = err.decode_shapes(); + assert_eq!(exp.unwrap().as_slice(), &[768, 2, 1024]); + assert_eq!(act.unwrap().as_slice(), &[4, 500, 6]); + + let err = ShapeError::incompatible_shapes(&(768, 2, 1024, 3, 300).into_dimension(), &(4, 500, 6).into_dimension()); + let (exp, act) = err.decode_shapes(); + assert_eq!(exp.unwrap().as_slice(), &[768, 2, 1024, 3, 300]); + assert_matches!(act, None); + + let err = ShapeError::incompatible_shapes(&(768, 2, 1024, 3, 300).into_dimension(), &(4, 6).into_dimension()); + let (exp, act) = err.decode_shapes(); + assert_eq!(exp.unwrap().as_slice(), &[768, 2, 1024, 3, 300]); + assert_eq!(act.unwrap().as_slice(), &[4, 6]); + + let err = ShapeError::incompatible_shapes(&().into_dimension(), &(768, 2, 1024).into_dimension()); + let (exp, act) = err.decode_shapes(); + assert_eq!(exp.unwrap().as_slice(), &[]); + assert_eq!(act.unwrap().as_slice(), &[768, 2, 1024]); + + let err = ShapeError::incompatible_shapes(&[1, 2, 3, 4, 5, 6, 7, 8].into_dimension(), &().into_dimension()); + let (exp, act) = err.decode_shapes(); + assert_matches!(exp, None); + assert_eq!(act.unwrap().as_slice(), &[]); + + let err = ShapeError::incompatible_shapes(&[1, 2, 3, 4, 5, 6, 7].into_dimension(), &(1, 2).into_dimension()); + let (exp, act) = err.decode_shapes(); + assert_eq!(exp.unwrap().as_slice(), &[1, 2, 3, 4, 5, 6, 7]); + assert_eq!(act.unwrap().as_slice(), &[1, 2]); } diff --git a/tests/array.rs b/tests/array.rs index 8e084e49e..87cca71a6 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -1377,7 +1377,9 @@ fn reshape() { fn reshape_error1() { let data = [1, 2, 3, 4, 5, 6, 7, 8]; let v = aview1(&data); - let _u = v.into_shape((2, 5)).unwrap(); + let res = v.into_shape((2, 5)); + println!("{:?}", res); + res.unwrap(); } #[test] @@ -1387,7 +1389,9 @@ fn reshape_error2() { let v = aview1(&data); let mut u = v.into_shape((2, 2, 2)).unwrap(); u.swap_axes(0, 1); - let _s = u.into_shape((2, 4)).unwrap(); + let res = u.into_shape((2, 4)); + println!("{:?}", res); + res.unwrap(); } #[test] diff --git a/tests/stacking.rs b/tests/stacking.rs index 032525ffa..cd262dfe5 100644 --- a/tests/stacking.rs +++ b/tests/stacking.rs @@ -16,12 +16,15 @@ fn concatenating() { assert_eq!(d, aview1(&[2., 2., 9., 9.])); let res = ndarray::concatenate(Axis(1), &[a.view(), c.view()]); + println!("{:?}", res); assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape); let res = ndarray::concatenate(Axis(2), &[a.view(), c.view()]); + println!("{:?}", res); assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds); let res: Result, _> = ndarray::concatenate(Axis(0), &[]); + println!("{:?}", res); assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported); } @@ -36,11 +39,14 @@ fn stacking() { let c = arr2(&[[3., 2., 3.], [2., 3., 2.]]); let res = ndarray::stack(Axis(1), &[a.view(), c.view()]); + println!("{:?}", res); assert_eq!(res.unwrap_err().kind(), ErrorKind::IncompatibleShape); let res = ndarray::stack(Axis(3), &[a.view(), a.view()]); + println!("{:?}", res); assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds); let res: Result, _> = ndarray::stack::<_, Ix1>(Axis(0), &[]); + println!("{:?}", res); assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported); } From 5f34c4ac868a3089ccabe58a2e353831e289d72a Mon Sep 17 00:00:00 2001 From: bluss Date: Mon, 29 Mar 2021 20:15:05 +0200 Subject: [PATCH 3/3] FIX: Add extra information to errors where possible Where possible, add expected/actual information to the ShapeError. In many places it is identified new places where more specific ErrorKinds and error messages are needed. These are not updated here - a comment is inserted - this will be updated in a future version, when we can accept breaking changes. --- src/dimension/broadcast.rs | 1 + src/dimension/mod.rs | 16 ++++++++++++---- src/impl_methods.rs | 14 +++++++++----- src/slice.rs | 2 ++ src/stacking.rs | 24 ++++++++++++++++-------- 5 files changed, 40 insertions(+), 17 deletions(-) diff --git a/src/dimension/broadcast.rs b/src/dimension/broadcast.rs index dc1513f04..a23d24921 100644 --- a/src/dimension/broadcast.rs +++ b/src/dimension/broadcast.rs @@ -27,6 +27,7 @@ where if *out == 1 { *out = *s2 } else if *s2 != 1 { + // TODO More specific error axis length mismatch return Err(from_kind(ErrorKind::IncompatibleShape)); } } diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index f4f46e764..057e39aed 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -89,6 +89,7 @@ pub fn size_of_shape_checked(dim: &D) -> Result .try_fold(1usize, |acc, &d| acc.checked_mul(d)) .ok_or_else(|| from_kind(ErrorKind::Overflow))?; if size_nonzero > ::std::isize::MAX as usize { + // TODO More specific error Err(from_kind(ErrorKind::Overflow)) } else { Ok(dim.size()) @@ -137,7 +138,7 @@ pub(crate) fn can_index_slice_not_custom(data_len: usize, dim: &D) let len = size_of_shape_checked(dim)?; // Condition 2. if len > data_len { - return Err(from_kind(ErrorKind::OutOfBounds)); + return Err(ShapeError::shape_length_exceeds_data_length(data_len, len)); } Ok(()) } @@ -170,6 +171,7 @@ where { // Condition 1. if dim.ndim() != strides.ndim() { + // TODO More specific error for dimension stride dimensionality mismatch return Err(from_kind(ErrorKind::IncompatibleLayout)); } @@ -185,9 +187,11 @@ where let off = d.saturating_sub(1).checked_mul(s.abs() as usize)?; acc.checked_add(off) }) + // TODO More specific error .ok_or_else(|| from_kind(ErrorKind::Overflow))?; // Condition 2a. if max_offset > isize::MAX as usize { + // TODO More specific error return Err(from_kind(ErrorKind::Overflow)); } @@ -195,9 +199,11 @@ where // greatest address accessible by moving along all axes let max_offset_bytes = max_offset .checked_mul(elem_size) + // TODO More specific error .ok_or_else(|| from_kind(ErrorKind::Overflow))?; // Condition 2b. if max_offset_bytes > isize::MAX as usize { + // TODO More specific error return Err(from_kind(ErrorKind::Overflow)); } @@ -256,15 +262,16 @@ fn can_index_slice_impl( // Check condition 3. let is_empty = dim.slice().iter().any(|&d| d == 0); if is_empty && max_offset > data_len { - return Err(from_kind(ErrorKind::OutOfBounds)); + return Err(ShapeError::shape_length_exceeds_data_length(data_len, max_offset)); } if !is_empty && max_offset >= data_len { - return Err(from_kind(ErrorKind::OutOfBounds)); + return Err(ShapeError::shape_length_exceeds_data_length(data_len.wrapping_sub(1), max_offset)); } // Check condition 4. if !is_empty && dim_stride_overlap(dim, strides) { - return Err(from_kind(ErrorKind::Unsupported)); + // TODO: More specific error kind Strides result in overlapping elements + return Err(ShapeError::from_kind(ErrorKind::Unsupported)); } Ok(()) @@ -293,6 +300,7 @@ where { for &stride in strides.slice() { if (stride as isize) < 0 { + // TODO: More specific error kind Non-negative strides required return Err(from_kind(ErrorKind::Unsupported)); } } diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 4045f2b59..82cc1cc0a 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -23,7 +23,7 @@ use crate::dimension::{ offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes, }; use crate::dimension::broadcast::co_broadcast; -use crate::error::{self, ErrorKind, ShapeError, from_kind}; +use crate::error::{self, ErrorKind, ShapeError}; use crate::math_cell::MathCell; use crate::itertools::zip; use crate::zip::{IntoNdProducer, Zip}; @@ -1588,7 +1588,7 @@ where } else if self.ndim() > 1 && self.raw_view().reversed_axes().is_standard_layout() { Ok(self.with_strides_dim(shape.fortran_strides(), shape)) } else { - Err(error::from_kind(error::ErrorKind::IncompatibleLayout)) + Err(ShapeError::incompatible_layout(error::ExpectedLayout::ContiguousCF)) } } } @@ -1693,6 +1693,7 @@ where } } } + // TODO More specific error incompatible ndim Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)) } @@ -1805,11 +1806,14 @@ where { let shape = co_broadcast::>::Output>(&self.dim, &other.dim)?; if let Some(view1) = self.broadcast(shape.clone()) { - if let Some(view2) = other.broadcast(shape) { - return Ok((view1, view2)); + if let Some(view2) = other.broadcast(shape.clone()) { + Ok((view1, view2)) + } else { + Err(ShapeError::incompatible_shapes(&other.dim, &shape)) } + } else { + Err(ShapeError::incompatible_shapes(&other.dim, &shape)) } - Err(from_kind(ErrorKind::IncompatibleShape)) } /// Swap axes `ax` and `bx`. diff --git a/src/slice.rs b/src/slice.rs index 3c554a5ca..da76cda74 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -415,11 +415,13 @@ where { if let Some(in_ndim) = Din::NDIM { if in_ndim != indices.in_ndim() { + // TODO More specific error incompatible ndim return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); } } if let Some(out_ndim) = Dout::NDIM { if out_ndim != indices.out_ndim() { + // TODO More specific error incompatible ndim return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)); } } diff --git a/src/stacking.rs b/src/stacking.rs index 500ded6af..213e2ac1f 100644 --- a/src/stacking.rs +++ b/src/stacking.rs @@ -8,6 +8,7 @@ use crate::error::{from_kind, ErrorKind, ShapeError}; use crate::imp_prelude::*; +use crate::NdProducer; /// Stack arrays along the new axis. /// @@ -72,18 +73,22 @@ where D: RemoveAxis, { if arrays.is_empty() { + // TODO More specific error for empty input not supported return Err(from_kind(ErrorKind::Unsupported)); } let mut res_dim = arrays[0].raw_dim(); if axis.index() >= res_dim.ndim() { - return Err(from_kind(ErrorKind::OutOfBounds)); + return Err(ShapeError::invalid_axis(res_dim.ndim().wrapping_sub(1), axis.index())); } let common_dim = res_dim.remove_axis(axis); - if arrays - .iter() - .any(|a| a.raw_dim().remove_axis(axis) != common_dim) + if let Some(a) = arrays.iter().find_map(|a| + if a.raw_dim().remove_axis(axis) != common_dim { + Some(a) + } else { + None + }) { - return Err(from_kind(ErrorKind::IncompatibleShape)); + return Err(ShapeError::incompatible_shapes(&common_dim, &a.dim)); } let stacked_dim = arrays.iter().fold(0, |acc, a| acc + a.len_of(axis)); @@ -143,17 +148,20 @@ where D::Larger: RemoveAxis, { if arrays.is_empty() { + // TODO More specific error for empty input not supported return Err(from_kind(ErrorKind::Unsupported)); } let common_dim = arrays[0].raw_dim(); // Avoid panic on `insert_axis` call, return an Err instead of it. if axis.index() > common_dim.ndim() { - return Err(from_kind(ErrorKind::OutOfBounds)); + return Err(ShapeError::invalid_axis(common_dim.ndim(), axis.index())); } let mut res_dim = common_dim.insert_axis(axis); - if arrays.iter().any(|a| a.raw_dim() != common_dim) { - return Err(from_kind(ErrorKind::IncompatibleShape)); + if let Some(array) = arrays.iter().find_map(|array| if !array.equal_dim(&common_dim) { + Some(array) + } else { None }) { + return Err(ShapeError::incompatible_shapes(&common_dim, &array.dim)); } res_dim.set_axis(axis, arrays.len()); pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy