ndarray/
impl_ops.rs

// Copyright 2014-2016 bluss and ndarray developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
8
9use crate::dimension::DimMax;
10use crate::Zip;
11use num_complex::Complex;
12
13/// Elements that can be used as direct operands in arithmetic with arrays.
14///
15/// For example, `f64` is a `ScalarOperand` which means that for an array `a`,
16/// arithmetic like `a + 1.0`, and, `a * 2.`, and `a += 3.` are allowed.
17///
18/// In the description below, let `A` be an array or array view,
19/// let `B` be an array with owned data,
20/// and let `C` be an array with mutable data.
21///
22/// `ScalarOperand` determines for which scalars `K` operations `&A @ K`, and `B @ K`,
23/// and `C @= K` are defined, as ***right hand side operands***, for applicable
24/// arithmetic operators (denoted `@`).
25///
26/// ***Left hand side*** scalar operands are not related to this trait
27/// (they need one `impl` per concrete scalar type); but they are still
28/// implemented for the same types, allowing operations
29/// `K @ &A`, and `K @ B` for primitive numeric types `K`.
30///
31/// This trait ***does not*** limit which elements can be stored in an array in general.
32/// Non-`ScalarOperand` types can still participate in arithmetic as array elements in
33/// in array-array operations.
34pub trait ScalarOperand: 'static + Clone {}
35impl ScalarOperand for bool {}
36impl ScalarOperand for i8 {}
37impl ScalarOperand for u8 {}
38impl ScalarOperand for i16 {}
39impl ScalarOperand for u16 {}
40impl ScalarOperand for i32 {}
41impl ScalarOperand for u32 {}
42impl ScalarOperand for i64 {}
43impl ScalarOperand for u64 {}
44impl ScalarOperand for i128 {}
45impl ScalarOperand for u128 {}
46impl ScalarOperand for isize {}
47impl ScalarOperand for usize {}
48impl ScalarOperand for f32 {}
49impl ScalarOperand for f64 {}
50impl ScalarOperand for Complex<f32> {}
51impl ScalarOperand for Complex<f64> {}
52
// Implements the binary operator trait `$trt` (method `$mth`, infix token
// `$operator`) for arrays, covering:
//
// - owned array @ owned array
// - owned array @ &array
// - &array @ owned array
// - &array @ &array
// - owned array @ scalar
// - &array @ scalar
//
// where a scalar is any type implementing `ScalarOperand`. `$doc` names the
// operation in the generated rustdoc. `$iop` (the compound-assignment token,
// e.g. `+=`) is accepted so all call sites stay uniform, but it is not used in
// this expansion.
macro_rules! impl_binary_op(
    ($trt:ident, $operator:tt, $mth:ident, $iop:tt, $doc:expr) => (
/// Perform elementwise
#[doc=$doc]
/// between `self` and `rhs`,
/// and return the result.
///
/// `self` must be an `Array` or `ArcArray`.
///
/// If their shapes disagree, `self` and `rhs` are broadcast to their common
/// broadcast shape. The storage of `self` is reused when its shape already
/// equals the broadcast shape; otherwise a new array is allocated.
///
/// **Panics** if broadcasting isn’t possible.
impl<A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=A>,
    B: Clone,
    S: DataOwned<Elem=A> + DataMut,
    S2: Data<Elem=B>,
    D: Dimension + DimMax<E>,
    E: Dimension,
{
    type Output = ArrayBase<S, <D as DimMax<E>>::Output>;
    #[track_caller]
    fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
    {
        // Delegate to the `owned @ &array` impl below.
        self.$mth(&rhs)
    }
}

/// Perform elementwise
#[doc=$doc]
/// between `self` and reference `rhs`,
/// and return the result.
///
/// `self` must be an `Array` or `ArcArray`.
///
/// If their shapes disagree, `self` and `rhs` are broadcast to their common
/// broadcast shape. The storage of `self` is reused when its shape already
/// equals the broadcast shape; otherwise a new array is allocated.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=A>,
    B: Clone,
    S: DataOwned<Elem=A> + DataMut,
    S2: Data<Elem=B>,
    D: Dimension + DimMax<E>,
    E: Dimension,
{
    type Output = ArrayBase<S, <D as DimMax<E>>::Output>;
    #[track_caller]
    fn $mth(self, rhs: &ArrayBase<S2, E>) -> Self::Output
    {
        if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
            // Identical shapes: update `self` in place.
            let mut out = self.into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
            out.zip_mut_with_same_shape(rhs, clone_iopf(A::$mth));
            out
        } else {
            let (lhs_view, rhs_view) = self.broadcast_with(rhs).unwrap();
            if lhs_view.shape() == self.shape() {
                // `self` already has the broadcast shape: update it in place
                // against the broadcast view of `rhs`.
                let mut out = self.into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
                out.zip_mut_with_same_shape(&rhs_view, clone_iopf(A::$mth));
                out
            } else {
                // `self` is smaller than the broadcast shape: allocate a new array.
                Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$mth))
            }
        }
    }
}

/// Perform elementwise
#[doc=$doc]
/// between reference `self` and `rhs`,
/// and return the result.
///
/// `rhs` must be an `Array` or `ArcArray`.
///
/// If their shapes disagree, `self` and `rhs` are broadcast to their common
/// broadcast shape. The storage of `rhs` is reused when its shape already
/// equals the broadcast shape; otherwise a new array is allocated.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for &'a ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=B>,
    B: Clone,
    S: Data<Elem=A>,
    S2: DataOwned<Elem=B> + DataMut,
    D: Dimension,
    E: Dimension + DimMax<D>,
{
    type Output = ArrayBase<S2, <E as DimMax<D>>::Output>;
    #[track_caller]
    fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
    {
        if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
            // Identical shapes: update `rhs` in place, keeping operand order
            // `self @ rhs` via the reversed adapter.
            let mut out = rhs.into_dimensionality::<<E as DimMax<D>>::Output>().unwrap();
            out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$mth));
            out
        } else {
            let (rhs_view, lhs_view) = rhs.broadcast_with(self).unwrap();
            if rhs_view.shape() == rhs.shape() {
                // `rhs` already has the broadcast shape: update it in place.
                let mut out = rhs.into_dimensionality::<<E as DimMax<D>>::Output>().unwrap();
                out.zip_mut_with_same_shape(&lhs_view, clone_iopf_rev(A::$mth));
                out
            } else {
                // `rhs` is smaller than the broadcast shape: allocate a new array.
                Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$mth))
            }
        }
    }
}

/// Perform elementwise
#[doc=$doc]
/// between references `self` and `rhs`,
/// and return the result as a new `Array`.
///
/// If their shapes disagree, `self` and `rhs` are broadcast to their common
/// broadcast shape; the result is always newly allocated.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for &'a ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=A>,
    B: Clone,
    S: Data<Elem=A>,
    S2: Data<Elem=B>,
    D: Dimension + DimMax<E>,
    E: Dimension,
{
    type Output = Array<A, <D as DimMax<E>>::Output>;
    #[track_caller]
    fn $mth(self, rhs: &'a ArrayBase<S2, E>) -> Self::Output {
        let (lhs, rhs) = if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
            let lhs = self.view().into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
            let rhs = rhs.view().into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
            (lhs, rhs)
        } else {
            self.broadcast_with(rhs).unwrap()
        };
        Zip::from(lhs).and(rhs).map_collect(clone_opf(A::$mth))
    }
}

/// Perform elementwise
#[doc=$doc]
/// between `self` and the scalar `x`,
/// and return the result (based on `self`).
///
/// `self` must be an `Array` or `ArcArray`.
impl<A, S, D, B> $trt<B> for ArrayBase<S, D>
    where A: Clone + $trt<B, Output=A>,
          S: DataOwned<Elem=A> + DataMut,
          D: Dimension,
          B: ScalarOperand,
{
    type Output = ArrayBase<S, D>;
    fn $mth(mut self, x: B) -> ArrayBase<S, D> {
        self.map_inplace(move |elt| {
            *elt = elt.clone() $operator x.clone();
        });
        self
    }
}

/// Perform elementwise
#[doc=$doc]
/// between the reference `self` and the scalar `x`,
/// and return the result as a new `Array`.
impl<'a, A, S, D, B> $trt<B> for &'a ArrayBase<S, D>
    where A: Clone + $trt<B, Output=A>,
          S: Data<Elem=A>,
          D: Dimension,
          B: ScalarOperand,
{
    type Output = Array<A, D>;
    fn $mth(self, x: B) -> Self::Output {
        self.map(move |elt| elt.clone() $operator x.clone())
    }
}
    );
);
235
// Choose one of two expressions during macro expansion. `Ordered` keeps the
// second expression (`$b`), for operators whose operand order matters.
// `Commute` keeps the first (`$a`), for commutative operators. Only the chosen
// expression is emitted; the other is dropped and never type-checked.
macro_rules! if_commutative {
    (Ordered { $a:expr } or { $b:expr }) => { $b };
    (Commute { $a:expr } or { $b:expr }) => { $a };
}
245
macro_rules! impl_scalar_lhs_op {
    // Implements `$scalar @ array` (scalar on the left) for the operator trait
    // `$trt` / method `$mth` / infix token `$operator`.
    //
    // `$commutative` is `Commute` or `Ordered`:
    // - `Commute` reuses the `array @ scalar` impls generated by
    //   `impl_binary_op!`, swapping the operands.
    // - `Ordered` computes `scalar @ elt` explicitly for every element.
    //
    // The ordered owned branch reads `*elt` by value, so `$scalar` must be
    // `Copy`; every type this macro is invoked with in this file is.
    // `$doc` is accepted so call sites stay uniform, but it is unused because
    // these impls carry no rustdoc.
    ($scalar:ty, $commutative:ident, $operator:tt, $trt:ident, $mth:ident, $doc:expr) => (
// These have no doc; they are not visible in rustdoc.
// Perform elementwise
// between the scalar `self` and array `rhs`,
// and return the result (reusing the storage of `rhs`).
impl<S, D> $trt<ArrayBase<S, D>> for $scalar
    where S: DataOwned<Elem=$scalar> + DataMut,
          D: Dimension,
{
    type Output = ArrayBase<S, D>;
    fn $mth(self, rhs: ArrayBase<S, D>) -> ArrayBase<S, D> {
        if_commutative!($commutative {
            rhs.$mth(self)
        } or {{
            let mut rhs = rhs;
            rhs.map_inplace(move |elt| {
                *elt = self $operator *elt;
            });
            rhs
        }})
    }
}

// Perform elementwise
// between the scalar `self` and array `rhs`,
// and return the result as a new `Array`.
impl<'a, S, D> $trt<&'a ArrayBase<S, D>> for $scalar
    where S: Data<Elem=$scalar>,
          D: Dimension,
{
    type Output = Array<$scalar, D>;
    fn $mth(self, rhs: &ArrayBase<S, D>) -> Self::Output {
        if_commutative!($commutative {
            rhs.$mth(self)
        } or {
            rhs.map(move |elt| self.clone() $operator elt.clone())
        })
    }
}
    );
}
290
291mod arithmetic_ops
292{
293    use super::*;
294    use crate::imp_prelude::*;
295
296    use std::ops::*;
297
298    fn clone_opf<A: Clone, B: Clone, C>(f: impl Fn(A, B) -> C) -> impl FnMut(&A, &B) -> C
299    {
300        move |x, y| f(x.clone(), y.clone())
301    }
302
303    fn clone_iopf<A: Clone, B: Clone>(f: impl Fn(A, B) -> A) -> impl FnMut(&mut A, &B)
304    {
305        move |x, y| *x = f(x.clone(), y.clone())
306    }
307
308    fn clone_iopf_rev<A: Clone, B: Clone>(f: impl Fn(A, B) -> B) -> impl FnMut(&mut B, &A)
309    {
310        move |x, y| *x = f(y.clone(), x.clone())
311    }
312
313    impl_binary_op!(Add, +, add, +=, "addition");
314    impl_binary_op!(Sub, -, sub, -=, "subtraction");
315    impl_binary_op!(Mul, *, mul, *=, "multiplication");
316    impl_binary_op!(Div, /, div, /=, "division");
317    impl_binary_op!(Rem, %, rem, %=, "remainder");
318    impl_binary_op!(BitAnd, &, bitand, &=, "bit and");
319    impl_binary_op!(BitOr, |, bitor, |=, "bit or");
320    impl_binary_op!(BitXor, ^, bitxor, ^=, "bit xor");
321    impl_binary_op!(Shl, <<, shl, <<=, "left shift");
322    impl_binary_op!(Shr, >>, shr, >>=, "right shift");
323
324    macro_rules! all_scalar_ops {
325        ($int_scalar:ty) => (
326            impl_scalar_lhs_op!($int_scalar, Commute, +, Add, add, "addition");
327            impl_scalar_lhs_op!($int_scalar, Ordered, -, Sub, sub, "subtraction");
328            impl_scalar_lhs_op!($int_scalar, Commute, *, Mul, mul, "multiplication");
329            impl_scalar_lhs_op!($int_scalar, Ordered, /, Div, div, "division");
330            impl_scalar_lhs_op!($int_scalar, Ordered, %, Rem, rem, "remainder");
331            impl_scalar_lhs_op!($int_scalar, Commute, &, BitAnd, bitand, "bit and");
332            impl_scalar_lhs_op!($int_scalar, Commute, |, BitOr, bitor, "bit or");
333            impl_scalar_lhs_op!($int_scalar, Commute, ^, BitXor, bitxor, "bit xor");
334            impl_scalar_lhs_op!($int_scalar, Ordered, <<, Shl, shl, "left shift");
335            impl_scalar_lhs_op!($int_scalar, Ordered, >>, Shr, shr, "right shift");
336        );
337    }
338    all_scalar_ops!(i8);
339    all_scalar_ops!(u8);
340    all_scalar_ops!(i16);
341    all_scalar_ops!(u16);
342    all_scalar_ops!(i32);
343    all_scalar_ops!(u32);
344    all_scalar_ops!(i64);
345    all_scalar_ops!(u64);
346    all_scalar_ops!(isize);
347    all_scalar_ops!(usize);
348    all_scalar_ops!(i128);
349    all_scalar_ops!(u128);
350
351    impl_scalar_lhs_op!(bool, Commute, &, BitAnd, bitand, "bit and");
352    impl_scalar_lhs_op!(bool, Commute, |, BitOr, bitor, "bit or");
353    impl_scalar_lhs_op!(bool, Commute, ^, BitXor, bitxor, "bit xor");
354
355    impl_scalar_lhs_op!(f32, Commute, +, Add, add, "addition");
356    impl_scalar_lhs_op!(f32, Ordered, -, Sub, sub, "subtraction");
357    impl_scalar_lhs_op!(f32, Commute, *, Mul, mul, "multiplication");
358    impl_scalar_lhs_op!(f32, Ordered, /, Div, div, "division");
359    impl_scalar_lhs_op!(f32, Ordered, %, Rem, rem, "remainder");
360
361    impl_scalar_lhs_op!(f64, Commute, +, Add, add, "addition");
362    impl_scalar_lhs_op!(f64, Ordered, -, Sub, sub, "subtraction");
363    impl_scalar_lhs_op!(f64, Commute, *, Mul, mul, "multiplication");
364    impl_scalar_lhs_op!(f64, Ordered, /, Div, div, "division");
365    impl_scalar_lhs_op!(f64, Ordered, %, Rem, rem, "remainder");
366
367    impl_scalar_lhs_op!(Complex<f32>, Commute, +, Add, add, "addition");
368    impl_scalar_lhs_op!(Complex<f32>, Ordered, -, Sub, sub, "subtraction");
369    impl_scalar_lhs_op!(Complex<f32>, Commute, *, Mul, mul, "multiplication");
370    impl_scalar_lhs_op!(Complex<f32>, Ordered, /, Div, div, "division");
371
372    impl_scalar_lhs_op!(Complex<f64>, Commute, +, Add, add, "addition");
373    impl_scalar_lhs_op!(Complex<f64>, Ordered, -, Sub, sub, "subtraction");
374    impl_scalar_lhs_op!(Complex<f64>, Commute, *, Mul, mul, "multiplication");
375    impl_scalar_lhs_op!(Complex<f64>, Ordered, /, Div, div, "division");
376
377    impl<A, S, D> Neg for ArrayBase<S, D>
378    where
379        A: Clone + Neg<Output = A>,
380        S: DataOwned<Elem = A> + DataMut,
381        D: Dimension,
382    {
383        type Output = Self;
384        /// Perform an elementwise negation of `self` and return the result.
385        fn neg(mut self) -> Self
386        {
387            self.map_inplace(|elt| {
388                *elt = -elt.clone();
389            });
390            self
391        }
392    }
393
394    impl<'a, A, S, D> Neg for &'a ArrayBase<S, D>
395    where
396        &'a A: 'a + Neg<Output = A>,
397        S: Data<Elem = A>,
398        D: Dimension,
399    {
400        type Output = Array<A, D>;
401        /// Perform an elementwise negation of reference `self` and return the
402        /// result as a new `Array`.
403        fn neg(self) -> Array<A, D>
404        {
405            self.map(Neg::neg)
406        }
407    }
408
409    impl<A, S, D> Not for ArrayBase<S, D>
410    where
411        A: Clone + Not<Output = A>,
412        S: DataOwned<Elem = A> + DataMut,
413        D: Dimension,
414    {
415        type Output = Self;
416        /// Perform an elementwise unary not of `self` and return the result.
417        fn not(mut self) -> Self
418        {
419            self.map_inplace(|elt| {
420                *elt = !elt.clone();
421            });
422            self
423        }
424    }
425
426    impl<'a, A, S, D> Not for &'a ArrayBase<S, D>
427    where
428        &'a A: 'a + Not<Output = A>,
429        S: Data<Elem = A>,
430        D: Dimension,
431    {
432        type Output = Array<A, D>;
433        /// Perform an elementwise unary not of reference `self` and return the
434        /// result as a new `Array`.
435        fn not(self) -> Array<A, D>
436        {
437            self.map(Not::not)
438        }
439    }
440}
441
442mod assign_ops
443{
444    use super::*;
445    use crate::imp_prelude::*;
446
447    macro_rules! impl_assign_op {
448        ($trt:ident, $method:ident, $doc:expr) => {
449            use std::ops::$trt;
450
451            #[doc=$doc]
452            /// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
453            ///
454            /// **Panics** if broadcasting isn’t possible.
455            impl<'a, A, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
456            where
457                A: Clone + $trt<A>,
458                S: DataMut<Elem = A>,
459                S2: Data<Elem = A>,
460                D: Dimension,
461                E: Dimension,
462            {
463                #[track_caller]
464                fn $method(&mut self, rhs: &ArrayBase<S2, E>) {
465                    self.zip_mut_with(rhs, |x, y| {
466                        x.$method(y.clone());
467                    });
468                }
469            }
470
471            #[doc=$doc]
472            impl<A, S, D> $trt<A> for ArrayBase<S, D>
473            where
474                A: ScalarOperand + $trt<A>,
475                S: DataMut<Elem = A>,
476                D: Dimension,
477            {
478                fn $method(&mut self, rhs: A) {
479                    self.map_inplace(move |elt| {
480                        elt.$method(rhs.clone());
481                    });
482                }
483            }
484        };
485    }
486
487    impl_assign_op!(
488        AddAssign,
489        add_assign,
490        "Perform `self += rhs` as elementwise addition (in place).\n"
491    );
492    impl_assign_op!(
493        SubAssign,
494        sub_assign,
495        "Perform `self -= rhs` as elementwise subtraction (in place).\n"
496    );
497    impl_assign_op!(
498        MulAssign,
499        mul_assign,
500        "Perform `self *= rhs` as elementwise multiplication (in place).\n"
501    );
502    impl_assign_op!(
503        DivAssign,
504        div_assign,
505        "Perform `self /= rhs` as elementwise division (in place).\n"
506    );
507    impl_assign_op!(
508        RemAssign,
509        rem_assign,
510        "Perform `self %= rhs` as elementwise remainder (in place).\n"
511    );
512    impl_assign_op!(
513        BitAndAssign,
514        bitand_assign,
515        "Perform `self &= rhs` as elementwise bit and (in place).\n"
516    );
517    impl_assign_op!(
518        BitOrAssign,
519        bitor_assign,
520        "Perform `self |= rhs` as elementwise bit or (in place).\n"
521    );
522    impl_assign_op!(
523        BitXorAssign,
524        bitxor_assign,
525        "Perform `self ^= rhs` as elementwise bit xor (in place).\n"
526    );
527    impl_assign_op!(
528        ShlAssign,
529        shl_assign,
530        "Perform `self <<= rhs` as elementwise left shift (in place).\n"
531    );
532    impl_assign_op!(
533        ShrAssign,
534        shr_assign,
535        "Perform `self >>= rhs` as elementwise right shift (in place).\n"
536    );
537}
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy