1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
// Copyright 2022 The Chromium OS Authors. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Unix-specific code that keeps the rest of the crate platform independent.

use std::os::unix::io::IntoRawFd;

use base::AsRawDescriptor;
use base::FromRawDescriptor;
use base::RawDescriptor;
use base::SafeDescriptor;

use crate::master_req_handler::MasterReqHandler;
use crate::Result;
use crate::VhostUserMasterReqHandler;

impl<S: VhostUserMasterReqHandler> AsRawDescriptor for MasterReqHandler<S> {
    /// Returns the raw descriptor of the handler's `sub_sock`.
    ///
    /// Intended for polling: callers can wait on this descriptor before calling into the
    /// handler.
    fn as_raw_descriptor(&self) -> RawDescriptor {
        let sock = &self.sub_sock;
        sock.as_raw_descriptor()
    }
}

impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
    /// Builds a `MasterReqHandler` around `backend` that uses a Unix stream internally.
    ///
    /// Delegates to `Self::new`, passing a closure that turns the stream it is given into a
    /// `SafeDescriptor` by consuming the stream and wrapping its raw fd.
    ///
    /// # Errors
    ///
    /// Returns whatever error `Self::new` reports.
    pub fn with_stream(backend: S) -> Result<Self> {
        Self::new(
            backend,
            Box::new(|tx_stream| {
                // `into_raw_fd` consumes the stream, so nothing else will close this fd.
                let raw_fd = tx_stream.into_raw_fd();
                // SAFETY:
                // Safe because we own the raw fd: it was just released by the stream above and
                // is handed straight to the `SafeDescriptor`.
                unsafe { SafeDescriptor::from_raw_descriptor(raw_fd) }
            }),
        )
    }
}

#[cfg(test)]
mod tests {
    use base::AsRawDescriptor;
    use base::Descriptor;
    use base::FromRawDescriptor;
    use base::INVALID_DESCRIPTOR;

    use super::*;
    use crate::message::VhostUserFSSlaveMsg;
    use crate::HandlerResult;
    use crate::Slave;
    use crate::SystemStream;
    use crate::VhostUserMasterReqHandler;

    struct MockMasterReqHandler {}

    impl VhostUserMasterReqHandler for MockMasterReqHandler {
        /// Handle virtio-fs map file requests from the slave.
        fn fs_slave_map(
            &mut self,
            _fs: &VhostUserFSSlaveMsg,
            _fd: &dyn AsRawDescriptor,
        ) -> HandlerResult<u64> {
            Ok(0)
        }

        /// Handle virtio-fs unmap file requests from the slave.
        fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
            Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
        }
    }

    /// A freshly built handler exposes a usable tx descriptor and a valid pollable descriptor.
    #[test]
    fn test_new_master_req_handler() {
        let mut req_handler = MasterReqHandler::with_stream(MockMasterReqHandler {}).unwrap();

        // The tx side handed out by the handler must be a real (non-negative) descriptor.
        let tx = req_handler.take_tx_descriptor();
        assert!(tx.as_raw_descriptor() >= 0);

        // The handler's own descriptor must be valid as well.
        assert!(req_handler.as_raw_descriptor() != INVALID_DESCRIPTOR);
    }

    /// Sends a map and an unmap request through a `Slave` to the handler without REPLY_ACK.
    ///
    /// The handler thread is joined at the end so that a failed assertion inside it fails the
    /// test instead of being lost with a dropped `JoinHandle`.
    #[test]
    fn test_master_slave_req_handler() {
        let backend = MockMasterReqHandler {};
        let mut handler = MasterReqHandler::with_stream(backend).unwrap();

        let tx_descriptor = handler.take_tx_descriptor();
        // SAFETY: return value of dup is checked.
        let fd = unsafe { libc::dup(tx_descriptor.as_raw_descriptor()) };
        assert!(fd >= 0, "failed to duplicated tx fd!");
        // SAFETY: fd is created above and is valid
        let stream = unsafe { SystemStream::from_raw_descriptor(fd) };
        let mut fs_cache = Slave::from_stream(stream);

        let handler_thread = std::thread::spawn(move || {
            // First request is the map, which the mock backend answers with `Ok(0)`.
            let res = handler.handle_request().unwrap();
            assert_eq!(res, 0);
            // Second request is the unmap, which the mock backend rejects.
            handler.handle_request().unwrap_err();
        });

        fs_cache
            .fs_slave_map(&VhostUserFSSlaveMsg::default(), &Descriptor(fd))
            .unwrap();
        // When REPLY_ACK has not been negotiated, the sending side has no way to detect that
        // the handler failed the request, so the unmap still reports success here.
        fs_cache
            .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
            .unwrap();

        // Both requests have been sent, so the handler thread can finish; propagate any panic
        // from its assertions into this test.
        handler_thread.join().expect("handler thread panicked");
    }

    /// Sends a map and an unmap request through a `Slave` to the handler with REPLY_ACK enabled
    /// on both sides, so the handler's failure of the unmap request is reported to the sender.
    ///
    /// The handler thread is joined at the end so that a failed assertion inside it fails the
    /// test instead of being lost with a dropped `JoinHandle`.
    #[test]
    fn test_master_slave_req_handler_with_ack() {
        let backend = MockMasterReqHandler {};
        let mut handler = MasterReqHandler::with_stream(backend).unwrap();
        handler.set_reply_ack_flag(true);

        let tx_descriptor = handler.take_tx_descriptor();
        // SAFETY: return value of dup is checked.
        let fd = unsafe { libc::dup(tx_descriptor.as_raw_descriptor()) };
        assert!(fd >= 0, "failed to duplicated tx fd!");

        // SAFETY: fd is created above and is valid
        let stream = unsafe { SystemStream::from_raw_descriptor(fd) };
        let mut fs_cache = Slave::from_stream(stream);

        let handler_thread = std::thread::spawn(move || {
            // First request is the map, which the mock backend answers with `Ok(0)`.
            let res = handler.handle_request().unwrap();
            assert_eq!(res, 0);
            // Second request is the unmap, which the mock backend rejects.
            handler.handle_request().unwrap_err();
        });

        fs_cache.set_reply_ack_flag(true);
        fs_cache
            .fs_slave_map(&VhostUserFSSlaveMsg::default(), &Descriptor(fd))
            .unwrap();
        // With REPLY_ACK enabled the sender sees the handler's failure of the unmap request.
        fs_cache
            .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
            .unwrap_err();

        // Both requests have been answered, so the handler thread can finish; propagate any
        // panic from its assertions into this test.
        handler_thread.join().expect("handler thread panicked");
    }
}