base/
descriptor_reflection.rs

1// Copyright 2020 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Provides infrastructure for de/serializing descriptors embedded in Rust data structures.
6//!
7//! # Example
8//!
9//! ```
10//! use serde_json::to_string;
11//! use base::{
12//!     FileSerdeWrapper, FromRawDescriptor, SafeDescriptor, SerializeDescriptors,
13//!     deserialize_with_descriptors,
14//! };
15//! use tempfile::tempfile;
16//!
17//! let tmp_f = tempfile().unwrap();
18//!
19//! // Uses a simple wrapper to serialize a File because we can't implement Serialize for File.
20//! let data = FileSerdeWrapper(tmp_f);
21//!
22//! // Wraps Serialize types to collect side channel descriptors as Serialize is called.
23//! let data_wrapper = SerializeDescriptors::new(&data);
24//!
//! // Use the wrapper with any serializer to serialize data as normal, grabbing descriptors
//! // as the data structures are serialized by the serializer.
27//! let out_json = serde_json::to_string(&data_wrapper).expect("failed to serialize");
28//!
29//! // If data_wrapper contains any side channel descriptor refs
30//! // (it contains tmp_f in this case), we can retrieve the actual descriptors
31//! // from the side channel using into_descriptors().
32//! let out_descriptors = data_wrapper.into_descriptors();
33//!
34//! // When sending out_json over some transport, also send out_descriptors.
35//!
36//! // For this example, we aren't really transporting data across the process, but we do need to
37//! // convert the descriptor type.
38//! let mut safe_descriptors = out_descriptors
39//!     .iter()
40//!     .map(|&v| unsafe { SafeDescriptor::from_raw_descriptor(v) });
41//! std::mem::forget(data); // Prevent double drop of tmp_f.
42//!
//! // The deserialize_with_descriptors function is used to give the descriptor deserializers
//! // access to side channel descriptors.
45//! let res: FileSerdeWrapper =
46//!     deserialize_with_descriptors(|| serde_json::from_str(&out_json), safe_descriptors)
47//!        .expect("failed to deserialize");
48//! ```
49
50use std::cell::Cell;
51use std::cell::RefCell;
52use std::convert::TryInto;
53use std::fmt;
54use std::fs::File;
55use std::ops::Deref;
56use std::ops::DerefMut;
57use std::panic::catch_unwind;
58use std::panic::resume_unwind;
59use std::panic::AssertUnwindSafe;
60
61use serde::de;
62use serde::de::Error;
63use serde::de::Visitor;
64use serde::ser;
65use serde::Deserialize;
66use serde::Deserializer;
67use serde::Serialize;
68use serde::Serializer;
69
70use super::RawDescriptor;
71use crate::descriptor::SafeDescriptor;
72
73thread_local! {
74    static DESCRIPTOR_DST: RefCell<Option<Vec<RawDescriptor>>> = Default::default();
75}
76
77/// Initializes the thread local storage for descriptor serialization. Fails if it was already
78/// initialized without an intervening `take_descriptor_dst` on this thread.
79fn init_descriptor_dst() -> Result<(), &'static str> {
80    DESCRIPTOR_DST.with(|d| {
81        let mut descriptors = d.borrow_mut();
82        if descriptors.is_some() {
83            return Err(
84                "attempt to initialize descriptor destination that was already initialized",
85            );
86        }
87        *descriptors = Some(Default::default());
88        Ok(())
89    })
90}
91
92/// Takes the thread local storage for descriptor serialization. Fails if there wasn't a prior call
93/// to `init_descriptor_dst` on this thread.
94fn take_descriptor_dst() -> Result<Vec<RawDescriptor>, &'static str> {
95    match DESCRIPTOR_DST.with(|d| d.replace(None)) {
96        Some(d) => Ok(d),
97        None => Err("attempt to take descriptor destination before it was initialized"),
98    }
99}
100
101/// Pushes a descriptor on the thread local destination of descriptors, returning the index in which
102/// the descriptor was pushed.
103//
104/// Returns Err if the thread local destination was not already initialized.
105fn push_descriptor(rd: RawDescriptor) -> Result<usize, &'static str> {
106    DESCRIPTOR_DST.with(|d| {
107        d.borrow_mut()
108            .as_mut()
109            .ok_or("attempt to serialize descriptor without descriptor destination")
110            .map(|descriptors| {
111                let index = descriptors.len();
112                descriptors.push(rd);
113                index
114            })
115    })
116}
117
118/// Serializes a descriptor for later retrieval in a parent `SerializeDescriptors` struct.
119///
120/// If there is no parent `SerializeDescriptors` being serialized, this will return an error.
121///
122/// For convenience, it is recommended to use the `with_raw_descriptor` module in a `#[serde(with =
123/// "...")]` attribute which will make use of this function.
124pub fn serialize_descriptor<S: Serializer>(
125    rd: &RawDescriptor,
126    se: S,
127) -> std::result::Result<S::Ok, S::Error> {
128    let index = push_descriptor(*rd).map_err(ser::Error::custom)?;
129    se.serialize_u32(
130        index
131            .try_into()
132            .map_err(|_| ser::Error::custom("attempt to serialize too many descriptors at once"))?,
133    )
134}
135
136/// Wrapper for a `Serialize` value which will capture any descriptors exported by the value when
137/// given to an ordinary `Serializer`.
138///
139/// This is the corresponding type to use for serialization before using
140/// `deserialize_with_descriptors`.
141///
142/// # Examples
143///
144/// ```
145/// use serde_json::to_string;
146/// use base::{FileSerdeWrapper, SerializeDescriptors};
147/// use tempfile::tempfile;
148///
149/// let tmp_f = tempfile().unwrap();
150/// let data = FileSerdeWrapper(tmp_f);
151/// let data_wrapper = SerializeDescriptors::new(&data);
152///
153/// // Serializes `v` as normal...
154/// let out_json = serde_json::to_string(&data_wrapper).expect("failed to serialize");
155/// // If `serialize_descriptor` was called, we can capture the descriptors from here.
156/// let out_descriptors = data_wrapper.into_descriptors();
157/// ```
158pub struct SerializeDescriptors<'a, T: Serialize>(&'a T, Cell<Vec<RawDescriptor>>);
159
160impl<'a, T: Serialize> SerializeDescriptors<'a, T> {
161    pub fn new(inner: &'a T) -> Self {
162        Self(inner, Default::default())
163    }
164
165    pub fn into_descriptors(self) -> Vec<RawDescriptor> {
166        self.1.into_inner()
167    }
168}
169
170impl<T: Serialize> Serialize for SerializeDescriptors<'_, T> {
171    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
172    where
173        S: Serializer,
174    {
175        init_descriptor_dst().map_err(ser::Error::custom)?;
176
177        // catch_unwind is used to ensure that init_descriptor_dst is always balanced with a call to
178        // take_descriptor_dst afterwards.
179        let res = catch_unwind(AssertUnwindSafe(|| self.0.serialize(serializer)));
180        self.1.set(take_descriptor_dst().unwrap());
181        match res {
182            Ok(r) => r,
183            Err(e) => resume_unwind(e),
184        }
185    }
186}
187
188thread_local! {
189    static DESCRIPTOR_SRC: RefCell<Option<Vec<Option<SafeDescriptor>>>> = Default::default();
190}
191
192/// Sets the thread local storage of descriptors for deserialization. Fails if this was already
193/// called without a call to `take_descriptor_src` on this thread.
194///
195/// This is given as a collection of `Option` so that unused descriptors can be returned.
196fn set_descriptor_src(descriptors: Vec<Option<SafeDescriptor>>) -> Result<(), &'static str> {
197    DESCRIPTOR_SRC.with(|d| {
198        let mut src = d.borrow_mut();
199        if src.is_some() {
200            return Err("attempt to set descriptor source that was already set");
201        }
202        *src = Some(descriptors);
203        Ok(())
204    })
205}
206
207/// Takes the thread local storage of descriptors for deserialization. Fails if the storage was
208/// already taken or never set with `set_descriptor_src`.
209///
210/// If deserialization was done, the descriptors will mostly come back as `None` unless some of them
211/// were unused.
212fn take_descriptor_src() -> Result<Vec<Option<SafeDescriptor>>, &'static str> {
213    DESCRIPTOR_SRC.with(|d| {
214        d.replace(None)
215            .ok_or("attempt to take descriptor source which was never set")
216    })
217}
218
219/// Takes a descriptor at the given index from the thread local source of descriptors.
220//
221/// Returns None if the thread local source was not already initialized.
222fn take_descriptor(index: usize) -> Result<SafeDescriptor, &'static str> {
223    DESCRIPTOR_SRC.with(|d| {
224        d.borrow_mut()
225            .as_mut()
226            .ok_or("attempt to deserialize descriptor without descriptor source")?
227            .get_mut(index)
228            .ok_or("attempt to deserialize out of bounds descriptor")?
229            .take()
230            .ok_or("attempt to deserialize descriptor that was already taken")
231    })
232}
233
234/// Deserializes a descriptor provided via `deserialize_with_descriptors`.
235///
236/// If `deserialize_with_descriptors` is not in the call chain, this will return an error.
237///
238/// For convenience, it is recommended to use the `with_raw_descriptor` module in a `#[serde(with =
239/// "...")]` attribute which will make use of this function.
240pub fn deserialize_descriptor<'de, D>(de: D) -> std::result::Result<SafeDescriptor, D::Error>
241where
242    D: Deserializer<'de>,
243{
244    struct DescriptorVisitor;
245
246    impl Visitor<'_> for DescriptorVisitor {
247        type Value = u32;
248
249        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
250            formatter.write_str("an integer which fits into a u32")
251        }
252
253        fn visit_u8<E: de::Error>(self, value: u8) -> Result<Self::Value, E> {
254            Ok(value as _)
255        }
256
257        fn visit_u16<E: de::Error>(self, value: u16) -> Result<Self::Value, E> {
258            Ok(value as _)
259        }
260
261        fn visit_u32<E: de::Error>(self, value: u32) -> Result<Self::Value, E> {
262            Ok(value)
263        }
264
265        fn visit_u64<E: de::Error>(self, value: u64) -> Result<Self::Value, E> {
266            value.try_into().map_err(E::custom)
267        }
268
269        fn visit_u128<E: de::Error>(self, value: u128) -> Result<Self::Value, E> {
270            value.try_into().map_err(E::custom)
271        }
272
273        fn visit_i8<E: de::Error>(self, value: i8) -> Result<Self::Value, E> {
274            value.try_into().map_err(E::custom)
275        }
276
277        fn visit_i16<E: de::Error>(self, value: i16) -> Result<Self::Value, E> {
278            value.try_into().map_err(E::custom)
279        }
280
281        fn visit_i32<E: de::Error>(self, value: i32) -> Result<Self::Value, E> {
282            value.try_into().map_err(E::custom)
283        }
284
285        fn visit_i64<E: de::Error>(self, value: i64) -> Result<Self::Value, E> {
286            value.try_into().map_err(E::custom)
287        }
288
289        fn visit_i128<E: de::Error>(self, value: i128) -> Result<Self::Value, E> {
290            value.try_into().map_err(E::custom)
291        }
292    }
293
294    let index = de.deserialize_u32(DescriptorVisitor)? as usize;
295    take_descriptor(index).map_err(D::Error::custom)
296}
297
298/// Allows the use of any serde deserializer within a closure while providing access to the a set of
299/// descriptors for use in `deserialize_descriptor`.
300///
301/// This is the corresponding call to use deserialize after using `SerializeDescriptors`.
302///
303/// If `deserialize_with_descriptors` is called anywhere within the given closure, it return an
304/// error.
305pub fn deserialize_with_descriptors<F, T, E>(
306    f: F,
307    descriptors: impl IntoIterator<Item = SafeDescriptor>,
308) -> Result<T, E>
309where
310    F: FnOnce() -> Result<T, E>,
311    E: de::Error,
312{
313    let descriptor_src = descriptors.into_iter().map(Option::Some).collect();
314    set_descriptor_src(descriptor_src).map_err(E::custom)?;
315
316    // catch_unwind is used to ensure that set_descriptor_src is always balanced with a call to
317    // take_descriptor_src afterwards.
318    let res = catch_unwind(AssertUnwindSafe(f));
319
320    // unwrap is used because set_descriptor_src is always called before this, so it should never
321    // panic.
322    let empty_descriptors = take_descriptor_src().unwrap();
323
324    // The deserializer should have consumed every descriptor.
325    debug_assert!(empty_descriptors.into_iter().all(|d| d.is_none()));
326
327    match res {
328        Ok(r) => r,
329        Err(e) => resume_unwind(e),
330    }
331}
332
333/// Module that exports `serialize`/`deserialize` functions for use with `#[serde(with = "...")]`
334/// attribute. It only works with fields with `RawDescriptor` type.
335///
336/// # Examples
337///
338/// ```
339/// use serde::{Deserialize, Serialize};
340/// use base::RawDescriptor;
341///
342/// #[derive(Serialize, Deserialize)]
343/// struct RawContainer {
344///     #[serde(with = "base::with_raw_descriptor")]
345///     rd: RawDescriptor,
346/// }
347/// ```
348pub mod with_raw_descriptor {
349    use serde::Deserializer;
350
351    use super::super::RawDescriptor;
352    pub use super::serialize_descriptor as serialize;
353    use crate::descriptor::IntoRawDescriptor;
354
355    pub fn deserialize<'de, D>(de: D) -> std::result::Result<RawDescriptor, D::Error>
356    where
357        D: Deserializer<'de>,
358    {
359        super::deserialize_descriptor(de).map(IntoRawDescriptor::into_raw_descriptor)
360    }
361}
362
363/// Module that exports `serialize`/`deserialize` functions for use with `#[serde(with = "...")]`
364/// attribute.
365///
366/// # Examples
367///
368/// ```
369/// use std::fs::File;
370/// use serde::{Deserialize, Serialize};
371/// use base::RawDescriptor;
372///
373/// #[derive(Serialize, Deserialize)]
374/// struct FileContainer {
375///     #[serde(with = "base::with_as_descriptor")]
376///     file: File,
377/// }
378/// ```
379pub mod with_as_descriptor {
380    use serde::Deserializer;
381    use serde::Serializer;
382
383    use crate::descriptor::AsRawDescriptor;
384    use crate::descriptor::FromRawDescriptor;
385    use crate::descriptor::IntoRawDescriptor;
386
387    pub fn serialize<S: Serializer>(
388        rd: &dyn AsRawDescriptor,
389        se: S,
390    ) -> std::result::Result<S::Ok, S::Error> {
391        super::serialize_descriptor(&rd.as_raw_descriptor(), se)
392    }
393
394    pub fn deserialize<'de, D, T>(de: D) -> std::result::Result<T, D::Error>
395    where
396        D: Deserializer<'de>,
397        T: FromRawDescriptor,
398    {
399        super::deserialize_descriptor(de)
400            .map(IntoRawDescriptor::into_raw_descriptor)
401            .map(|rd|
402                // SAFETY: rd is expected to be valid for the duration of the call.
403                unsafe { T::from_raw_descriptor(rd) })
404    }
405}
406
407/// A simple wrapper around `File` that implements `Serialize`/`Deserialize`, which is useful when
408/// the `#[serde(with = "with_as_descriptor")]` trait is infeasible, such as for a field with type
409/// `Option<File>`.
410#[derive(Serialize, Deserialize)]
411#[serde(transparent)]
412pub struct FileSerdeWrapper(#[serde(with = "with_as_descriptor")] pub File);
413
414impl fmt::Debug for FileSerdeWrapper {
415    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
416        self.0.fmt(f)
417    }
418}
419
420impl From<File> for FileSerdeWrapper {
421    fn from(file: File) -> Self {
422        FileSerdeWrapper(file)
423    }
424}
425
426impl From<FileSerdeWrapper> for File {
427    fn from(f: FileSerdeWrapper) -> File {
428        f.0
429    }
430}
431
432impl Deref for FileSerdeWrapper {
433    type Target = File;
434    fn deref(&self) -> &Self::Target {
435        &self.0
436    }
437}
438
439impl DerefMut for FileSerdeWrapper {
440    fn deref_mut(&mut self) -> &mut Self::Target {
441        &mut self.0
442    }
443}
444
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::fs::File;
    use std::mem::ManuallyDrop;

    use serde::de::DeserializeOwned;
    use serde::Deserialize;
    use serde::Serialize;
    use tempfile::tempfile;

    use super::super::deserialize_with_descriptors;
    use super::super::with_as_descriptor;
    use super::super::with_raw_descriptor;
    use super::super::AsRawDescriptor;
    use super::super::FileSerdeWrapper;
    use super::super::FromRawDescriptor;
    use super::super::RawDescriptor;
    use super::super::SafeDescriptor;
    use super::super::SerializeDescriptors;

    /// Serializes `value` to JSON through `SerializeDescriptors` and returns the JSON text
    /// together with the descriptors captured along the way.
    fn to_json_with_descriptors<T: Serialize>(value: &T) -> (String, Vec<RawDescriptor>) {
        let wrapper = SerializeDescriptors::new(value);
        let json = serde_json::to_string(&wrapper).unwrap();
        (json, wrapper.into_descriptors())
    }

    /// Deserializes `json` with `descriptors` made available as the side channel.
    fn deserialize<T: DeserializeOwned>(json: &str, descriptors: &[RawDescriptor]) -> T {
        let owned = descriptors.iter().map(|&raw| {
            // SAFETY: each raw descriptor passed in by the tests is expected to be valid.
            unsafe { SafeDescriptor::from_raw_descriptor(raw) }
        });

        deserialize_with_descriptors(|| serde_json::from_str(json), owned).unwrap()
    }

    #[test]
    fn raw() {
        #[derive(Serialize, Deserialize, PartialEq, Debug)]
        struct RawContainer {
            #[serde(with = "with_raw_descriptor")]
            rd: RawDescriptor,
        }
        // Specifically chosen to not overlap a real descriptor to avoid having to allocate any
        // descriptors for this test.
        let fake_rd = 5_123_457_i32;
        let original = RawContainer {
            rd: fake_rd as RawDescriptor,
        };
        let (json, descriptors) = to_json_with_descriptors(&original);
        let restored: RawContainer = deserialize(&json, &descriptors);
        assert_eq!(original, restored);
    }

    #[test]
    fn file() {
        #[derive(Serialize, Deserialize)]
        struct FileContainer {
            #[serde(with = "with_as_descriptor")]
            file: File,
        }

        let original = FileContainer {
            file: tempfile().unwrap(),
        };
        let (json, descriptors) = to_json_with_descriptors(&original);
        // The descriptor now belongs to `restored`; keep `original` from closing it as well.
        let original = ManuallyDrop::new(original);
        let restored: FileContainer = deserialize(&json, &descriptors);
        assert_eq!(
            original.file.as_raw_descriptor(),
            restored.file.as_raw_descriptor()
        );
    }

    #[test]
    fn option() {
        #[derive(Serialize, Deserialize)]
        struct TestOption {
            a: Option<FileSerdeWrapper>,
            b: Option<FileSerdeWrapper>,
        }

        let original = TestOption {
            a: None,
            b: Some(tempfile().unwrap().into()),
        };
        let (json, descriptors) = to_json_with_descriptors(&original);
        // The descriptor now belongs to `restored`; keep `original` from closing it as well.
        let original = ManuallyDrop::new(original);
        let restored: TestOption = deserialize(&json, &descriptors);
        assert!(restored.a.is_none());
        assert!(restored.b.is_some());
        assert_eq!(
            original.b.as_ref().unwrap().as_raw_descriptor(),
            restored.b.unwrap().as_raw_descriptor()
        );
    }

    #[test]
    fn map() {
        let mut original: HashMap<String, FileSerdeWrapper> = HashMap::new();
        for key in ["a", "b", "c"] {
            original.insert(key.into(), tempfile().unwrap().into());
        }
        let (json, descriptors) = to_json_with_descriptors(&original);
        // Prevent the files in `original` from dropping while allowing the HashMap itself to
        // drop. It is done this way to prevent a double close of the files (which should reside
        // in `restored`) without triggering the leak sanitizer on the HashMap heap memory.
        let original: HashMap<_, _> = original
            .into_iter()
            .map(|(key, file)| (key, ManuallyDrop::new(file)))
            .collect();
        let restored: HashMap<String, FileSerdeWrapper> = deserialize(&json, &descriptors);

        assert_eq!(original.len(), restored.len());
        for (key, file) in &original {
            assert_eq!(
                restored.get(key).unwrap().as_raw_descriptor(),
                file.as_raw_descriptor()
            );
        }
    }
}