ndarray/numeric/
impl_numeric.rs

1// Copyright 2014-2016 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9#[cfg(feature = "std")]
10use num_traits::Float;
11use num_traits::One;
12use num_traits::{FromPrimitive, Zero};
13use std::ops::{Add, Div, Mul};
14
15use crate::imp_prelude::*;
16use crate::numeric_util;
17
18/// # Numerical Methods for Arrays
19impl<A, S, D> ArrayBase<S, D>
20where
21    S: Data<Elem = A>,
22    D: Dimension,
23{
24    /// Return the sum of all elements in the array.
25    ///
26    /// ```
27    /// use ndarray::arr2;
28    ///
29    /// let a = arr2(&[[1., 2.],
30    ///                [3., 4.]]);
31    /// assert_eq!(a.sum(), 10.);
32    /// ```
33    pub fn sum(&self) -> A
34    where A: Clone + Add<Output = A> + num_traits::Zero
35    {
36        if let Some(slc) = self.as_slice_memory_order() {
37            return numeric_util::unrolled_fold(slc, A::zero, A::add);
38        }
39        let mut sum = A::zero();
40        for row in self.rows() {
41            if let Some(slc) = row.as_slice() {
42                sum = sum + numeric_util::unrolled_fold(slc, A::zero, A::add);
43            } else {
44                sum = sum + row.iter().fold(A::zero(), |acc, elt| acc + elt.clone());
45            }
46        }
47        sum
48    }
49
50    /// Returns the [arithmetic mean] x̅ of all elements in the array:
51    ///
52    /// ```text
53    ///     1   n
54    /// x̅ = ―   ∑ xᵢ
55    ///     n  i=1
56    /// ```
57    ///
58    /// If the array is empty, `None` is returned.
59    ///
60    /// **Panics** if `A::from_usize()` fails to convert the number of elements in the array.
61    ///
62    /// [arithmetic mean]: https://en.wikipedia.org/wiki/Arithmetic_mean
63    pub fn mean(&self) -> Option<A>
64    where A: Clone + FromPrimitive + Add<Output = A> + Div<Output = A> + Zero
65    {
66        let n_elements = self.len();
67        if n_elements == 0 {
68            None
69        } else {
70            let n_elements = A::from_usize(n_elements).expect("Converting number of elements to `A` must not fail.");
71            Some(self.sum() / n_elements)
72        }
73    }
74
75    /// Return the product of all elements in the array.
76    ///
77    /// ```
78    /// use ndarray::arr2;
79    ///
80    /// let a = arr2(&[[1., 2.],
81    ///                [3., 4.]]);
82    /// assert_eq!(a.product(), 24.);
83    /// ```
84    pub fn product(&self) -> A
85    where A: Clone + Mul<Output = A> + num_traits::One
86    {
87        if let Some(slc) = self.as_slice_memory_order() {
88            return numeric_util::unrolled_fold(slc, A::one, A::mul);
89        }
90        let mut sum = A::one();
91        for row in self.rows() {
92            if let Some(slc) = row.as_slice() {
93                sum = sum * numeric_util::unrolled_fold(slc, A::one, A::mul);
94            } else {
95                sum = sum * row.iter().fold(A::one(), |acc, elt| acc * elt.clone());
96            }
97        }
98        sum
99    }
100
101    /// Return variance of elements in the array.
102    ///
103    /// The variance is computed using the [Welford one-pass
104    /// algorithm](https://www.jstor.org/stable/1266577).
105    ///
106    /// The parameter `ddof` specifies the "delta degrees of freedom". For
107    /// example, to calculate the population variance, use `ddof = 0`, or to
108    /// calculate the sample variance, use `ddof = 1`.
109    ///
110    /// The variance is defined as:
111    ///
112    /// ```text
113    ///               1       n
114    /// variance = ――――――――   ∑ (xᵢ - x̅)²
115    ///            n - ddof  i=1
116    /// ```
117    ///
118    /// where
119    ///
120    /// ```text
121    ///     1   n
122    /// x̅ = ―   ∑ xᵢ
123    ///     n  i=1
124    /// ```
125    ///
126    /// and `n` is the length of the array.
127    ///
128    /// **Panics** if `ddof` is less than zero or greater than `n`
129    ///
130    /// # Example
131    ///
132    /// ```
133    /// use ndarray::array;
134    /// use approx::assert_abs_diff_eq;
135    ///
136    /// let a = array![1., -4.32, 1.14, 0.32];
137    /// let var = a.var(1.);
138    /// assert_abs_diff_eq!(var, 6.7331, epsilon = 1e-4);
139    /// ```
140    #[track_caller]
141    #[cfg(feature = "std")]
142    pub fn var(&self, ddof: A) -> A
143    where A: Float + FromPrimitive
144    {
145        let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
146        let n = A::from_usize(self.len()).expect("Converting length to `A` must not fail.");
147        assert!(
148            !(ddof < zero || ddof > n),
149            "`ddof` must not be less than zero or greater than the length of \
150             the axis",
151        );
152        let dof = n - ddof;
153        let mut mean = A::zero();
154        let mut sum_sq = A::zero();
155        let mut i = 0;
156        self.for_each(|&x| {
157            let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail.");
158            let delta = x - mean;
159            mean = mean + delta / count;
160            sum_sq = (x - mean).mul_add(delta, sum_sq);
161            i += 1;
162        });
163        sum_sq / dof
164    }
165
166    /// Return standard deviation of elements in the array.
167    ///
168    /// The standard deviation is computed from the variance using
169    /// the [Welford one-pass algorithm](https://www.jstor.org/stable/1266577).
170    ///
171    /// The parameter `ddof` specifies the "delta degrees of freedom". For
172    /// example, to calculate the population standard deviation, use `ddof = 0`,
173    /// or to calculate the sample standard deviation, use `ddof = 1`.
174    ///
175    /// The standard deviation is defined as:
176    ///
177    /// ```text
178    ///               ⎛    1       n          ⎞
179    /// stddev = sqrt ⎜ ――――――――   ∑ (xᵢ - x̅)²⎟
180    ///               ⎝ n - ddof  i=1         ⎠
181    /// ```
182    ///
183    /// where
184    ///
185    /// ```text
186    ///     1   n
187    /// x̅ = ―   ∑ xᵢ
188    ///     n  i=1
189    /// ```
190    ///
191    /// and `n` is the length of the array.
192    ///
193    /// **Panics** if `ddof` is less than zero or greater than `n`
194    ///
195    /// # Example
196    ///
197    /// ```
198    /// use ndarray::array;
199    /// use approx::assert_abs_diff_eq;
200    ///
201    /// let a = array![1., -4.32, 1.14, 0.32];
202    /// let stddev = a.std(1.);
203    /// assert_abs_diff_eq!(stddev, 2.59483, epsilon = 1e-4);
204    /// ```
205    #[track_caller]
206    #[cfg(feature = "std")]
207    pub fn std(&self, ddof: A) -> A
208    where A: Float + FromPrimitive
209    {
210        self.var(ddof).sqrt()
211    }
212
213    /// Return sum along `axis`.
214    ///
215    /// ```
216    /// use ndarray::{aview0, aview1, arr2, Axis};
217    ///
218    /// let a = arr2(&[[1., 2., 3.],
219    ///                [4., 5., 6.]]);
220    /// assert!(
221    ///     a.sum_axis(Axis(0)) == aview1(&[5., 7., 9.]) &&
222    ///     a.sum_axis(Axis(1)) == aview1(&[6., 15.]) &&
223    ///
224    ///     a.sum_axis(Axis(0)).sum_axis(Axis(0)) == aview0(&21.)
225    /// );
226    /// ```
227    ///
228    /// **Panics** if `axis` is out of bounds.
229    #[track_caller]
230    pub fn sum_axis(&self, axis: Axis) -> Array<A, D::Smaller>
231    where
232        A: Clone + Zero + Add<Output = A>,
233        D: RemoveAxis,
234    {
235        let min_stride_axis = self.dim.min_stride_axis(&self.strides);
236        if axis == min_stride_axis {
237            crate::Zip::from(self.lanes(axis)).map_collect(|lane| lane.sum())
238        } else {
239            let mut res = Array::zeros(self.raw_dim().remove_axis(axis));
240            for subview in self.axis_iter(axis) {
241                res = res + &subview;
242            }
243            res
244        }
245    }
246
247    /// Return product along `axis`.
248    ///
249    /// The product of an empty array is 1.
250    ///
251    /// ```
252    /// use ndarray::{aview0, aview1, arr2, Axis};
253    ///
254    /// let a = arr2(&[[1., 2., 3.],
255    ///                [4., 5., 6.]]);
256    ///
257    /// assert!(
258    ///     a.product_axis(Axis(0)) == aview1(&[4., 10., 18.]) &&
259    ///     a.product_axis(Axis(1)) == aview1(&[6., 120.]) &&
260    ///
261    ///     a.product_axis(Axis(0)).product_axis(Axis(0)) == aview0(&720.)
262    /// );
263    /// ```
264    ///
265    /// **Panics** if `axis` is out of bounds.
266    #[track_caller]
267    pub fn product_axis(&self, axis: Axis) -> Array<A, D::Smaller>
268    where
269        A: Clone + One + Mul<Output = A>,
270        D: RemoveAxis,
271    {
272        let min_stride_axis = self.dim.min_stride_axis(&self.strides);
273        if axis == min_stride_axis {
274            crate::Zip::from(self.lanes(axis)).map_collect(|lane| lane.product())
275        } else {
276            let mut res = Array::ones(self.raw_dim().remove_axis(axis));
277            for subview in self.axis_iter(axis) {
278                res = res * &subview;
279            }
280            res
281        }
282    }
283
284    /// Return mean along `axis`.
285    ///
286    /// Return `None` if the length of the axis is zero.
287    ///
288    /// **Panics** if `axis` is out of bounds or if `A::from_usize()`
289    /// fails for the axis length.
290    ///
291    /// ```
292    /// use ndarray::{aview0, aview1, arr2, Axis};
293    ///
294    /// let a = arr2(&[[1., 2., 3.],
295    ///                [4., 5., 6.]]);
296    /// assert!(
297    ///     a.mean_axis(Axis(0)).unwrap() == aview1(&[2.5, 3.5, 4.5]) &&
298    ///     a.mean_axis(Axis(1)).unwrap() == aview1(&[2., 5.]) &&
299    ///
300    ///     a.mean_axis(Axis(0)).unwrap().mean_axis(Axis(0)).unwrap() == aview0(&3.5)
301    /// );
302    /// ```
303    #[track_caller]
304    pub fn mean_axis(&self, axis: Axis) -> Option<Array<A, D::Smaller>>
305    where
306        A: Clone + Zero + FromPrimitive + Add<Output = A> + Div<Output = A>,
307        D: RemoveAxis,
308    {
309        let axis_length = self.len_of(axis);
310        if axis_length == 0 {
311            None
312        } else {
313            let axis_length = A::from_usize(axis_length).expect("Converting axis length to `A` must not fail.");
314            let sum = self.sum_axis(axis);
315            Some(sum / aview0(&axis_length))
316        }
317    }
318
319    /// Return variance along `axis`.
320    ///
321    /// The variance is computed using the [Welford one-pass
322    /// algorithm](https://www.jstor.org/stable/1266577).
323    ///
324    /// The parameter `ddof` specifies the "delta degrees of freedom". For
325    /// example, to calculate the population variance, use `ddof = 0`, or to
326    /// calculate the sample variance, use `ddof = 1`.
327    ///
328    /// The variance is defined as:
329    ///
330    /// ```text
331    ///               1       n
332    /// variance = ――――――――   ∑ (xᵢ - x̅)²
333    ///            n - ddof  i=1
334    /// ```
335    ///
336    /// where
337    ///
338    /// ```text
339    ///     1   n
340    /// x̅ = ―   ∑ xᵢ
341    ///     n  i=1
342    /// ```
343    ///
344    /// and `n` is the length of the axis.
345    ///
346    /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
347    /// is out of bounds, or if `A::from_usize()` fails for any any of the
348    /// numbers in the range `0..=n`.
349    ///
350    /// # Example
351    ///
352    /// ```
353    /// use ndarray::{aview1, arr2, Axis};
354    ///
355    /// let a = arr2(&[[1., 2.],
356    ///                [3., 4.],
357    ///                [5., 6.]]);
358    /// let var = a.var_axis(Axis(0), 1.);
359    /// assert_eq!(var, aview1(&[4., 4.]));
360    /// ```
361    #[track_caller]
362    #[cfg(feature = "std")]
363    pub fn var_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
364    where
365        A: Float + FromPrimitive,
366        D: RemoveAxis,
367    {
368        let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
369        let n = A::from_usize(self.len_of(axis)).expect("Converting length to `A` must not fail.");
370        assert!(
371            !(ddof < zero || ddof > n),
372            "`ddof` must not be less than zero or greater than the length of \
373             the axis",
374        );
375        let dof = n - ddof;
376        let mut mean = Array::<A, _>::zeros(self.dim.remove_axis(axis));
377        let mut sum_sq = Array::<A, _>::zeros(self.dim.remove_axis(axis));
378        for (i, subview) in self.axis_iter(axis).enumerate() {
379            let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail.");
380            azip!((mean in &mut mean, sum_sq in &mut sum_sq, &x in &subview) {
381                let delta = x - *mean;
382                *mean = *mean + delta / count;
383                *sum_sq = (x - *mean).mul_add(delta, *sum_sq);
384            });
385        }
386        sum_sq.mapv_into(|s| s / dof)
387    }
388
389    /// Return standard deviation along `axis`.
390    ///
391    /// The standard deviation is computed from the variance using
392    /// the [Welford one-pass algorithm](https://www.jstor.org/stable/1266577).
393    ///
394    /// The parameter `ddof` specifies the "delta degrees of freedom". For
395    /// example, to calculate the population standard deviation, use `ddof = 0`,
396    /// or to calculate the sample standard deviation, use `ddof = 1`.
397    ///
398    /// The standard deviation is defined as:
399    ///
400    /// ```text
401    ///               ⎛    1       n          ⎞
402    /// stddev = sqrt ⎜ ――――――――   ∑ (xᵢ - x̅)²⎟
403    ///               ⎝ n - ddof  i=1         ⎠
404    /// ```
405    ///
406    /// where
407    ///
408    /// ```text
409    ///     1   n
410    /// x̅ = ―   ∑ xᵢ
411    ///     n  i=1
412    /// ```
413    ///
414    /// and `n` is the length of the axis.
415    ///
416    /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
417    /// is out of bounds, or if `A::from_usize()` fails for any any of the
418    /// numbers in the range `0..=n`.
419    ///
420    /// # Example
421    ///
422    /// ```
423    /// use ndarray::{aview1, arr2, Axis};
424    ///
425    /// let a = arr2(&[[1., 2.],
426    ///                [3., 4.],
427    ///                [5., 6.]]);
428    /// let stddev = a.std_axis(Axis(0), 1.);
429    /// assert_eq!(stddev, aview1(&[2., 2.]));
430    /// ```
431    #[track_caller]
432    #[cfg(feature = "std")]
433    pub fn std_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
434    where
435        A: Float + FromPrimitive,
436        D: RemoveAxis,
437    {
438        self.var_axis(axis, ddof).mapv_into(|x| x.sqrt())
439    }
440}
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy