1use serde::de::{self, MapAccess, SeqAccess, Visitor};
9use serde::ser::{SerializeSeq, SerializeStruct};
10use serde::{Deserialize, Deserializer, Serialize, Serializer};
11
12use alloc::format;
13#[cfg(not(feature = "std"))]
14use alloc::vec::Vec;
15use std::fmt;
16use std::marker::PhantomData;
17
18use crate::imp_prelude::*;
19
20use super::arraytraits::ARRAY_FORMAT_VERSION;
21use super::Iter;
22use crate::IntoDimension;
23
24pub fn verify_version<E>(v: u8) -> Result<(), E>
27where E: de::Error
28{
29 if v != ARRAY_FORMAT_VERSION {
30 let err_msg = format!("unknown array version: {}", v);
31 Err(de::Error::custom(err_msg))
32 } else {
33 Ok(())
34 }
35}
36
37impl<I> Serialize for Dim<I>
39where I: Serialize
40{
41 fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
42 where Se: Serializer
43 {
44 self.ix().serialize(serializer)
45 }
46}
47
48impl<'de, I> Deserialize<'de> for Dim<I>
50where I: Deserialize<'de>
51{
52 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
53 where D: Deserializer<'de>
54 {
55 I::deserialize(deserializer).map(Dim::new)
56 }
57}
58
59impl Serialize for IxDyn
61{
62 fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
63 where Se: Serializer
64 {
65 self.ix().serialize(serializer)
66 }
67}
68
69impl<'de> Deserialize<'de> for IxDyn
71{
72 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
73 where D: Deserializer<'de>
74 {
75 let v = Vec::<Ix>::deserialize(deserializer)?;
76 Ok(v.into_dimension())
77 }
78}
79
80impl<A, D, S> Serialize for ArrayBase<S, D>
82where
83 A: Serialize,
84 D: Dimension + Serialize,
85 S: Data<Elem = A>,
86{
87 fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
88 where Se: Serializer
89 {
90 let mut state = serializer.serialize_struct("Array", 3)?;
91 state.serialize_field("v", &ARRAY_FORMAT_VERSION)?;
92 state.serialize_field("dim", &self.raw_dim())?;
93 state.serialize_field("data", &Sequence(self.iter()))?;
94 state.end()
95 }
96}
97
98struct Sequence<'a, A, D>(Iter<'a, A, D>);
100
101impl<'a, A, D> Serialize for Sequence<'a, A, D>
102where
103 A: Serialize,
104 D: Dimension + Serialize,
105{
106 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
107 where S: Serializer
108 {
109 let iter = &self.0;
110 let mut seq = serializer.serialize_seq(Some(iter.len()))?;
111 for elt in iter.clone() {
112 seq.serialize_element(elt)?;
113 }
114 seq.end()
115 }
116}
117
/// Serde visitor that rebuilds an owned `ArrayBase<S, Di>` from its
/// serialized form.
///
/// It stores no data; the two markers only carry the storage type `S` and
/// the dimension type `Di` for the `Visitor` implementation.
struct ArrayVisitor<S, Di>
{
    /// Marker for the storage type.
    _marker_a: PhantomData<S>,
    /// Marker for the dimension type.
    _marker_b: PhantomData<Di>,
}
123
/// Identifies a field of the serialized array struct.
enum ArrayField
{
    /// The `"v"` key: the serialization format version.
    Version,
    /// The `"dim"` key: the array shape.
    Dim,
    /// The `"data"` key: the array elements.
    Data,
}
130
131impl<S, Di> ArrayVisitor<S, Di>
132{
133 pub fn new() -> Self
134 {
135 ArrayVisitor {
136 _marker_a: PhantomData,
137 _marker_b: PhantomData,
138 }
139 }
140}
141
142static ARRAY_FIELDS: &[&str] = &["v", "dim", "data"];
143
144impl<'de, A, Di, S> Deserialize<'de> for ArrayBase<S, Di>
146where
147 A: Deserialize<'de>,
148 Di: Deserialize<'de> + Dimension,
149 S: DataOwned<Elem = A>,
150{
151 fn deserialize<D>(deserializer: D) -> Result<ArrayBase<S, Di>, D::Error>
152 where D: Deserializer<'de>
153 {
154 deserializer.deserialize_struct("Array", ARRAY_FIELDS, ArrayVisitor::new())
155 }
156}
157
158impl<'de> Deserialize<'de> for ArrayField
159{
160 fn deserialize<D>(deserializer: D) -> Result<ArrayField, D::Error>
161 where D: Deserializer<'de>
162 {
163 struct ArrayFieldVisitor;
164
165 impl<'de> Visitor<'de> for ArrayFieldVisitor
166 {
167 type Value = ArrayField;
168
169 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result
170 {
171 formatter.write_str(r#""v", "dim", or "data""#)
172 }
173
174 fn visit_str<E>(self, value: &str) -> Result<ArrayField, E>
175 where E: de::Error
176 {
177 match value {
178 "v" => Ok(ArrayField::Version),
179 "dim" => Ok(ArrayField::Dim),
180 "data" => Ok(ArrayField::Data),
181 other => Err(de::Error::unknown_field(other, ARRAY_FIELDS)),
182 }
183 }
184
185 fn visit_bytes<E>(self, value: &[u8]) -> Result<ArrayField, E>
186 where E: de::Error
187 {
188 match value {
189 b"v" => Ok(ArrayField::Version),
190 b"dim" => Ok(ArrayField::Dim),
191 b"data" => Ok(ArrayField::Data),
192 other => Err(de::Error::unknown_field(&format!("{:?}", other), ARRAY_FIELDS)),
193 }
194 }
195 }
196
197 deserializer.deserialize_identifier(ArrayFieldVisitor)
198 }
199}
200
201impl<'de, A, Di, S> Visitor<'de> for ArrayVisitor<S, Di>
202where
203 A: Deserialize<'de>,
204 Di: Deserialize<'de> + Dimension,
205 S: DataOwned<Elem = A>,
206{
207 type Value = ArrayBase<S, Di>;
208
209 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result
210 {
211 formatter.write_str("ndarray representation")
212 }
213
214 fn visit_seq<V>(self, mut visitor: V) -> Result<ArrayBase<S, Di>, V::Error>
215 where V: SeqAccess<'de>
216 {
217 let v: u8 = match visitor.next_element()? {
218 Some(value) => value,
219 None => {
220 return Err(de::Error::invalid_length(0, &self));
221 }
222 };
223
224 verify_version(v)?;
225
226 let dim: Di = match visitor.next_element()? {
227 Some(value) => value,
228 None => {
229 return Err(de::Error::invalid_length(1, &self));
230 }
231 };
232
233 let data: Vec<A> = match visitor.next_element()? {
234 Some(value) => value,
235 None => {
236 return Err(de::Error::invalid_length(2, &self));
237 }
238 };
239
240 if let Ok(array) = ArrayBase::from_shape_vec(dim, data) {
241 Ok(array)
242 } else {
243 Err(de::Error::custom("data and dimension must match in size"))
244 }
245 }
246
247 fn visit_map<V>(self, mut visitor: V) -> Result<ArrayBase<S, Di>, V::Error>
248 where V: MapAccess<'de>
249 {
250 let mut v: Option<u8> = None;
251 let mut data: Option<Vec<A>> = None;
252 let mut dim: Option<Di> = None;
253
254 while let Some(key) = visitor.next_key()? {
255 match key {
256 ArrayField::Version => {
257 let val = visitor.next_value()?;
258 verify_version(val)?;
259 v = Some(val);
260 }
261 ArrayField::Data => {
262 data = Some(visitor.next_value()?);
263 }
264 ArrayField::Dim => {
265 dim = Some(visitor.next_value()?);
266 }
267 }
268 }
269
270 let _v = match v {
271 Some(v) => v,
272 None => return Err(de::Error::missing_field("v")),
273 };
274
275 let data = match data {
276 Some(data) => data,
277 None => return Err(de::Error::missing_field("data")),
278 };
279
280 let dim = match dim {
281 Some(dim) => dim,
282 None => return Err(de::Error::missing_field("dim")),
283 };
284
285 if let Ok(array) = ArrayBase::from_shape_vec(dim, data) {
286 Ok(array)
287 } else {
288 Err(de::Error::custom("data and dimension must match in size"))
289 }
290 }
291}