ndarray/
array_serde.rs

1// Copyright 2014-2016 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8use serde::de::{self, MapAccess, SeqAccess, Visitor};
9use serde::ser::{SerializeSeq, SerializeStruct};
10use serde::{Deserialize, Deserializer, Serialize, Serializer};
11
12use alloc::format;
13#[cfg(not(feature = "std"))]
14use alloc::vec::Vec;
15use std::fmt;
16use std::marker::PhantomData;
17
18use crate::imp_prelude::*;
19
20use super::arraytraits::ARRAY_FORMAT_VERSION;
21use super::Iter;
22use crate::IntoDimension;
23
24/// Verifies that the version of the deserialized array matches the current
25/// `ARRAY_FORMAT_VERSION`.
26pub fn verify_version<E>(v: u8) -> Result<(), E>
27where E: de::Error
28{
29    if v != ARRAY_FORMAT_VERSION {
30        let err_msg = format!("unknown array version: {}", v);
31        Err(de::Error::custom(err_msg))
32    } else {
33        Ok(())
34    }
35}
36
37/// **Requires crate feature `"serde"`**
38impl<I> Serialize for Dim<I>
39where I: Serialize
40{
41    fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
42    where Se: Serializer
43    {
44        self.ix().serialize(serializer)
45    }
46}
47
48/// **Requires crate feature `"serde"`**
49impl<'de, I> Deserialize<'de> for Dim<I>
50where I: Deserialize<'de>
51{
52    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
53    where D: Deserializer<'de>
54    {
55        I::deserialize(deserializer).map(Dim::new)
56    }
57}
58
59/// **Requires crate feature `"serde"`**
60impl Serialize for IxDyn
61{
62    fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
63    where Se: Serializer
64    {
65        self.ix().serialize(serializer)
66    }
67}
68
69/// **Requires crate feature `"serde"`**
70impl<'de> Deserialize<'de> for IxDyn
71{
72    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
73    where D: Deserializer<'de>
74    {
75        let v = Vec::<Ix>::deserialize(deserializer)?;
76        Ok(v.into_dimension())
77    }
78}
79
80/// **Requires crate feature `"serde"`**
81impl<A, D, S> Serialize for ArrayBase<S, D>
82where
83    A: Serialize,
84    D: Dimension + Serialize,
85    S: Data<Elem = A>,
86{
87    fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
88    where Se: Serializer
89    {
90        let mut state = serializer.serialize_struct("Array", 3)?;
91        state.serialize_field("v", &ARRAY_FORMAT_VERSION)?;
92        state.serialize_field("dim", &self.raw_dim())?;
93        state.serialize_field("data", &Sequence(self.iter()))?;
94        state.end()
95    }
96}
97
98// private iterator wrapper
99struct Sequence<'a, A, D>(Iter<'a, A, D>);
100
101impl<'a, A, D> Serialize for Sequence<'a, A, D>
102where
103    A: Serialize,
104    D: Dimension + Serialize,
105{
106    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
107    where S: Serializer
108    {
109        let iter = &self.0;
110        let mut seq = serializer.serialize_seq(Some(iter.len()))?;
111        for elt in iter.clone() {
112            seq.serialize_element(elt)?;
113        }
114        seq.end()
115    }
116}
117
/// Serde visitor that rebuilds an `ArrayBase<S, Di>`. It stores no data; the
/// marker only records which storage (`S`) and dimension (`Di`) types to build.
struct ArrayVisitor<S, Di>
{
    _marker: PhantomData<(S, Di)>,
}

impl<S, Di> ArrayVisitor<S, Di>
{
    pub fn new() -> Self
    {
        Self { _marker: PhantomData }
    }
}

/// The keys recognised when an array is read from its map (struct) form.
enum ArrayField
{
    /// Key `"v"`: the serialization format version.
    Version,
    /// Key `"dim"`: the array shape.
    Dim,
    /// Key `"data"`: the flat list of elements.
    Data,
}

/// Field names of the serialized `"Array"` struct, in serialization order.
static ARRAY_FIELDS: &[&str] = &["v", "dim", "data"];
143
144/// **Requires crate feature `"serde"`**
145impl<'de, A, Di, S> Deserialize<'de> for ArrayBase<S, Di>
146where
147    A: Deserialize<'de>,
148    Di: Deserialize<'de> + Dimension,
149    S: DataOwned<Elem = A>,
150{
151    fn deserialize<D>(deserializer: D) -> Result<ArrayBase<S, Di>, D::Error>
152    where D: Deserializer<'de>
153    {
154        deserializer.deserialize_struct("Array", ARRAY_FIELDS, ArrayVisitor::new())
155    }
156}
157
158impl<'de> Deserialize<'de> for ArrayField
159{
160    fn deserialize<D>(deserializer: D) -> Result<ArrayField, D::Error>
161    where D: Deserializer<'de>
162    {
163        struct ArrayFieldVisitor;
164
165        impl<'de> Visitor<'de> for ArrayFieldVisitor
166        {
167            type Value = ArrayField;
168
169            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result
170            {
171                formatter.write_str(r#""v", "dim", or "data""#)
172            }
173
174            fn visit_str<E>(self, value: &str) -> Result<ArrayField, E>
175            where E: de::Error
176            {
177                match value {
178                    "v" => Ok(ArrayField::Version),
179                    "dim" => Ok(ArrayField::Dim),
180                    "data" => Ok(ArrayField::Data),
181                    other => Err(de::Error::unknown_field(other, ARRAY_FIELDS)),
182                }
183            }
184
185            fn visit_bytes<E>(self, value: &[u8]) -> Result<ArrayField, E>
186            where E: de::Error
187            {
188                match value {
189                    b"v" => Ok(ArrayField::Version),
190                    b"dim" => Ok(ArrayField::Dim),
191                    b"data" => Ok(ArrayField::Data),
192                    other => Err(de::Error::unknown_field(&format!("{:?}", other), ARRAY_FIELDS)),
193                }
194            }
195        }
196
197        deserializer.deserialize_identifier(ArrayFieldVisitor)
198    }
199}
200
201impl<'de, A, Di, S> Visitor<'de> for ArrayVisitor<S, Di>
202where
203    A: Deserialize<'de>,
204    Di: Deserialize<'de> + Dimension,
205    S: DataOwned<Elem = A>,
206{
207    type Value = ArrayBase<S, Di>;
208
209    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result
210    {
211        formatter.write_str("ndarray representation")
212    }
213
214    fn visit_seq<V>(self, mut visitor: V) -> Result<ArrayBase<S, Di>, V::Error>
215    where V: SeqAccess<'de>
216    {
217        let v: u8 = match visitor.next_element()? {
218            Some(value) => value,
219            None => {
220                return Err(de::Error::invalid_length(0, &self));
221            }
222        };
223
224        verify_version(v)?;
225
226        let dim: Di = match visitor.next_element()? {
227            Some(value) => value,
228            None => {
229                return Err(de::Error::invalid_length(1, &self));
230            }
231        };
232
233        let data: Vec<A> = match visitor.next_element()? {
234            Some(value) => value,
235            None => {
236                return Err(de::Error::invalid_length(2, &self));
237            }
238        };
239
240        if let Ok(array) = ArrayBase::from_shape_vec(dim, data) {
241            Ok(array)
242        } else {
243            Err(de::Error::custom("data and dimension must match in size"))
244        }
245    }
246
247    fn visit_map<V>(self, mut visitor: V) -> Result<ArrayBase<S, Di>, V::Error>
248    where V: MapAccess<'de>
249    {
250        let mut v: Option<u8> = None;
251        let mut data: Option<Vec<A>> = None;
252        let mut dim: Option<Di> = None;
253
254        while let Some(key) = visitor.next_key()? {
255            match key {
256                ArrayField::Version => {
257                    let val = visitor.next_value()?;
258                    verify_version(val)?;
259                    v = Some(val);
260                }
261                ArrayField::Data => {
262                    data = Some(visitor.next_value()?);
263                }
264                ArrayField::Dim => {
265                    dim = Some(visitor.next_value()?);
266                }
267            }
268        }
269
270        let _v = match v {
271            Some(v) => v,
272            None => return Err(de::Error::missing_field("v")),
273        };
274
275        let data = match data {
276            Some(data) => data,
277            None => return Err(de::Error::missing_field("data")),
278        };
279
280        let dim = match dim {
281            Some(dim) => dim,
282            None => return Err(de::Error::missing_field("dim")),
283        };
284
285        if let Ok(array) = ArrayBase::from_shape_vec(dim, data) {
286            Ok(array)
287        } else {
288            Err(de::Error::custom("data and dimension must match in size"))
289        }
290    }
291}
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy