1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
// Copyright 2021 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

pub mod device;

use std::fmt::Debug;

use argh::FromArgValue;
use serde::Deserialize;
use serde_keyvalue::ErrorKind;
use serde_keyvalue::KeyValueDeserializer;

pub use self::device::*;

/// Extends any device configuration with a mandatory extra "vhost" parameter to specify the socket
/// or PCI device to use in order to communicate with a vhost client.
///
/// When parsed from a command-line argument (through the `FromArgValue` implementation), the
/// `vhost` argument must come first, followed by all the arguments required by `device`, i.e.
/// `vhost=<socket or PCI device>,<device arguments...>`.
///
/// No trait bound is placed on `T` here: the derives and impls that need `T: Debug` (or
/// `T: Deserialize`) state it themselves, so the struct does not restrict its users needlessly.
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
// TODO(b/262345003): This requires a custom `Deserialize` implementation to support configuration
// files properly. Right now the pseudo-flattening is done by the `FromArgValue` implementation,
// which is only used with command-line arguments. A good `Deserialize` implementation would allow
// the same behavior with any deserializer, but requires some serde-fu that is above my current
// skills.
pub struct VhostUserParams<T> {
    /// Value of the leading `vhost` parameter (the socket or PCI device to use).
    pub vhost: String,
    /// Device-specific configuration, parsed from the parameters following `vhost`.
    pub device: T,
}

impl<T> FromArgValue for VhostUserParams<T>
where
    T: Debug + for<'de> Deserialize<'de>,
{
    /// Parses `vhost=<value>,<device parameters...>` into a `VhostUserParams`.
    ///
    /// The leading `vhost` key is consumed by hand; everything after its trailing comma is handed
    /// to `T`'s `Deserialize` implementation, and the whole input must be consumed.
    ///
    /// # Errors
    ///
    /// Returns the deserializer's error rendered as a `String` if the first key is not `vhost`, if
    /// it is not followed by `=` and a value, if that value is followed by anything other than a
    /// comma or the end of input, if the device parameters fail to deserialize, or if trailing
    /// input remains.
    fn from_arg_value(value: &str) -> std::result::Result<Self, String> {
        // `argh` wants a `String` error while the key-value deserializer has its own error type.
        // Running the parse inside a closure lets `?` propagate the native error, which is turned
        // into a string exactly once at the end.
        let parse = move || {
            let mut parser = KeyValueDeserializer::from(value);

            // The very first key must be `vhost`, followed by `=` and its value.
            let key = parser.parse_identifier()?;
            if key != "vhost" {
                return Err(parser.error_here(ErrorKind::SerdeError(
                    "expected \"vhost\" parameter".into(),
                )));
            }
            if parser.next_char() != Some('=') {
                return Err(parser.error_here(ErrorKind::ExpectedEqual));
            }
            let vhost = parser.parse_string()?;

            // Either the input stops right after the value, or a comma introduces the
            // device-specific parameters.
            if !matches!(parser.next_char(), Some(',') | None) {
                return Err(parser.error_here(ErrorKind::ExpectedComma));
            }

            // Whatever remains belongs to the device; nothing may be left over afterwards.
            let device = T::deserialize(&mut parser)?;
            parser.finish()?;

            Ok(Self {
                vhost: vhost.into(),
                device,
            })
        };

        parse().map_err(|err| err.to_string())
    }
}

#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use argh::FromArgValue;
    use serde::Deserialize;
    use serde_keyvalue::*;

    use super::VhostUserParams;

    /// Minimal device configuration used to exercise `VhostUserParams` parsing.
    #[derive(Debug, Deserialize, PartialEq, Eq)]
    #[serde(deny_unknown_fields, rename_all = "kebab-case")]
    struct DummyDevice {
        path: PathBuf,
        #[serde(default)]
        boom_range: u32,
    }

    /// Parses `s` the same way a command-line argument would be parsed.
    fn from_arg_value(s: &str) -> Result<VhostUserParams<DummyDevice>, String> {
        VhostUserParams::<DummyDevice>::from_arg_value(s)
    }

    /// Asserts that `arg` parses successfully into `expected_vhost` and `expected_device`.
    #[track_caller]
    fn assert_parses_to(arg: &str, expected_vhost: &str, expected_device: DummyDevice) {
        let params = from_arg_value(arg).unwrap();
        assert_eq!(params.vhost, expected_vhost);
        assert_eq!(params.device, expected_device);
    }

    #[test]
    fn vhost_user_params() {
        // `vhost` first, then every device parameter.
        assert_parses_to(
            "vhost=vhost_sock,path=/path/to/dummy,boom-range=42",
            "vhost_sock",
            DummyDevice {
                path: "/path/to/dummy".into(),
                boom_range: 42,
            },
        );

        // An omitted optional device parameter takes its default value.
        assert_parses_to(
            "vhost=vhost_sock,path=/path/to/dummy",
            "vhost_sock",
            DummyDevice {
                path: "/path/to/dummy".into(),
                boom_range: Default::default(),
            },
        );

        // Unknown device parameters are refused.
        let bad_arg = "vhost=vhost_sock,path=/path/to/dummy,boom-range=42,invalid-param=10";
        assert_eq!(
            from_arg_value(bad_arg).unwrap_err(),
            "unknown field `invalid-param`, expected `path` or `boom-range`"
        );

        // A numeric-looking path must still come out as a path. This guards against flattening
        // `device`, which would route through `deserialize_any` and read `10` as an integer.
        assert_parses_to(
            "vhost=vhost_sock,path=10",
            "vhost_sock",
            DummyDevice {
                path: "10".into(),
                boom_range: Default::default(),
            },
        );

        // `vhost` anywhere but first is an error.
        assert!(from_arg_value("path=/path/to/dummy,vhost=vhost_sock,boom-range=42").is_err());
    }
}