devices/virtio/vhost_user_backend/
params.rs

1// Copyright 2021 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::fmt::Debug;
6
7use argh::FromArgValue;
8use serde::Deserialize;
9use serde_keyvalue::ErrorKind;
10use serde_keyvalue::KeyValueDeserializer;
11
12/// Extends any device configuration with a mandatory extra "vhost" parameter to specify the socket
13/// or PCI device to use in order to communicate with a vhost client.
14///
15/// The `vhost` argument must come first, followed by all the arguments required by `device`.
16#[derive(Debug, Deserialize)]
17#[serde(deny_unknown_fields)]
18// TODO(b/262345003): This requires a custom `Deserialize` implementation to support configuration
19// files properly. Right now the pseudo-flattening is done by the `FromArgValue` implementation,
20// which is only used with command-line arguments. A good `Deserialize` implementation would allow
21// the same behavior with any deserializer, but requires some serde-fu that is above my current
22// skills.
23pub struct VhostUserParams<T: Debug> {
24    pub vhost: String,
25    pub device: T,
26}
27
28impl<T> FromArgValue for VhostUserParams<T>
29where
30    T: Debug + for<'de> Deserialize<'de>,
31{
32    fn from_arg_value(value: &str) -> std::result::Result<Self, String> {
33        // `from_arg_value` returns a `String` as error, but our deserializer API defines its own
34        // error type. Perform parsing from a closure so we can easily map returned errors.
35        let builder = move || {
36            let mut deserializer = KeyValueDeserializer::from(value);
37
38            // Parse the "vhost" parameter
39            let id = deserializer.parse_identifier()?;
40            if id != "vhost" {
41                return Err(deserializer
42                    .error_here(ErrorKind::SerdeError("expected \"vhost\" parameter".into())));
43            }
44            if deserializer.next_char() != Some('=') {
45                return Err(deserializer.error_here(ErrorKind::ExpectedEqual));
46            }
47            let vhost = deserializer.parse_string()?;
48            match deserializer.next_char() {
49                Some(',') | None => (),
50                _ => return Err(deserializer.error_here(ErrorKind::ExpectedComma)),
51            }
52
53            // Parse the device-specific parameters and finish
54            let device = T::deserialize(&mut deserializer)?;
55            deserializer.finish()?;
56
57            Ok(Self {
58                vhost: vhost.into(),
59                device,
60            })
61        };
62
63        builder().map_err(|e| e.to_string())
64    }
65}
66
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use argh::FromArgValue;
    use serde::Deserialize;
    use serde_keyvalue::*;

    use super::VhostUserParams;

    /// Toy device configuration: a mandatory path plus an optional numeric field. Keys are
    /// kebab-case and unknown keys are rejected.
    #[derive(Debug, Deserialize, PartialEq, Eq)]
    #[serde(deny_unknown_fields, rename_all = "kebab-case")]
    struct DummyDevice {
        path: PathBuf,
        #[serde(default)]
        boom_range: u32,
    }

    /// Shorthand for parsing `s` through the `FromArgValue` impl of
    /// `VhostUserParams<DummyDevice>`.
    fn from_arg_value(s: &str) -> Result<VhostUserParams<DummyDevice>, String> {
        VhostUserParams::<DummyDevice>::from_arg_value(s)
    }

    /// Asserts that `arg` parses successfully into the given `vhost` value and device config.
    fn check_ok(arg: &str, vhost: &str, device: DummyDevice) {
        let parsed = from_arg_value(arg).unwrap();
        assert_eq!(parsed.vhost.as_str(), vhost);
        assert_eq!(parsed.device, device);
    }

    #[test]
    fn vhost_user_params() {
        // Every device parameter supplied explicitly.
        check_ok(
            "vhost=vhost_sock,path=/path/to/dummy,boom-range=42",
            "vhost_sock",
            DummyDevice {
                path: "/path/to/dummy".into(),
                boom_range: 42,
            },
        );

        // Optional `boom-range` omitted: it falls back to its default value.
        check_ok(
            "vhost=vhost_sock,path=/path/to/dummy",
            "vhost_sock",
            DummyDevice {
                path: "/path/to/dummy".into(),
                boom_range: Default::default(),
            },
        );

        // An unknown device parameter is caught by `deny_unknown_fields`.
        let err = from_arg_value(
            "vhost=vhost_sock,path=/path/to/dummy,boom-range=42,invalid-param=10",
        )
        .unwrap_err();
        assert_eq!(
            err,
            "unknown field `invalid-param`, expected `path` or `boom-range`".to_string()
        );

        // A path that looks like a number must still be read as a path. This guards against the
        // `device` member being flattened, which would route through `deserialize_any` and
        // mistake `path` for an integer.
        check_ok(
            "vhost=vhost_sock,path=10",
            "vhost_sock",
            DummyDevice {
                path: "10".into(),
                boom_range: Default::default(),
            },
        );

        // `vhost` must come first; placing it later is an error.
        assert!(from_arg_value("path=/path/to/dummy,vhost=vhost_sock,boom-range=42").is_err());
    }
}