devices/virtio/
tpm.rs

// Copyright 2018 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

5use std::collections::BTreeMap;
6use std::io;
7use std::io::Read;
8use std::io::Write;
9use std::ops::BitOrAssign;
10
11use anyhow::anyhow;
12use anyhow::Context;
13use base::error;
14use base::Event;
15use base::EventToken;
16use base::RawDescriptor;
17use base::WaitContext;
18use base::WorkerThread;
19use remain::sorted;
20use thiserror::Error;
21use vm_memory::GuestMemory;
22
23use super::DescriptorChain;
24use super::DeviceType;
25use super::Interrupt;
26use super::Queue;
27use super::VirtioDevice;
28
/// Size of the device's single virtqueue. The guest kernel driver enqueues
/// one descriptor chain at a time, containing one command buffer and one
/// response buffer.
const QUEUE_SIZE: u16 = 2;
/// Maximum sizes of all virtqueues exposed by this device (there is only one).
const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE];

/// Largest command or response message, in bytes, that this device accepts.
/// The name mirrors the equivalent constant in Linux's tpm.h; keeping the
/// same value is sensible but not a hard requirement.
const TPM_BUFSIZE: usize = 4 * 1024;
39
40struct Worker {
41    queue: Queue,
42    backend: Box<dyn TpmBackend>,
43}
44
/// A TPM implementation that the virtio-tpm device forwards guest commands to.
///
/// Implementors must be `Send` because the device moves its backend onto a
/// dedicated worker thread when it is activated.
pub trait TpmBackend: Send {
    /// Executes one raw TPM command and returns the raw response bytes.
    ///
    /// The returned slice borrows from `self`, so it is only valid until the
    /// backend is used again. The device rejects responses longer than its
    /// 4096-byte message limit.
    fn execute_command(&mut self, command: &[u8]) -> &[u8];
}
48
49impl Worker {
50    fn perform_work(&mut self, desc: &mut DescriptorChain) -> Result<u32> {
51        let available_bytes = desc.reader.available_bytes();
52        if available_bytes > TPM_BUFSIZE {
53            return Err(Error::CommandTooLong {
54                size: available_bytes,
55            });
56        }
57
58        let mut command = vec![0u8; available_bytes];
59        desc.reader.read_exact(&mut command).map_err(Error::Read)?;
60
61        let response = self.backend.execute_command(&command);
62
63        if response.len() > TPM_BUFSIZE {
64            return Err(Error::ResponseTooLong {
65                size: response.len(),
66            });
67        }
68
69        let writer_len = desc.writer.available_bytes();
70        if response.len() > writer_len {
71            return Err(Error::BufferTooSmall {
72                size: writer_len,
73                required: response.len(),
74            });
75        }
76
77        desc.writer.write_all(response).map_err(Error::Write)?;
78
79        Ok(desc.writer.bytes_written() as u32)
80    }
81
82    fn process_queue(&mut self) -> NeedsInterrupt {
83        let mut needs_interrupt = NeedsInterrupt::No;
84        while let Some(mut avail_desc) = self.queue.pop() {
85            let len = match self.perform_work(&mut avail_desc) {
86                Ok(len) => len,
87                Err(err) => {
88                    error!("{}", err);
89                    0
90                }
91            };
92
93            self.queue.add_used_with_bytes_written(avail_desc, len);
94            needs_interrupt = NeedsInterrupt::Yes;
95        }
96
97        needs_interrupt
98    }
99
100    fn run(mut self, kill_evt: Event) -> anyhow::Result<()> {
101        #[derive(EventToken, Debug)]
102        enum Token {
103            // A request is ready on the queue.
104            QueueAvailable,
105            // The parent thread requested an exit.
106            Kill,
107        }
108
109        let wait_ctx = WaitContext::build_with(&[
110            (self.queue.event(), Token::QueueAvailable),
111            (&kill_evt, Token::Kill),
112        ])
113        .context("WaitContext::build_with")?;
114
115        loop {
116            let events = wait_ctx.wait().context("WaitContext::wait")?;
117            let mut needs_interrupt = NeedsInterrupt::No;
118            for event in events.iter().filter(|e| e.is_readable) {
119                match event.token {
120                    Token::QueueAvailable => {
121                        self.queue.event().wait().context("Event::wait")?;
122                        needs_interrupt |= self.process_queue();
123                    }
124                    Token::Kill => return Ok(()),
125                }
126            }
127            if needs_interrupt == NeedsInterrupt::Yes {
128                self.queue.trigger_interrupt();
129            }
130        }
131    }
132}
133
134/// Virtio vTPM device.
135pub struct Tpm {
136    backend: Option<Box<dyn TpmBackend>>,
137    worker_thread: Option<WorkerThread<()>>,
138    features: u64,
139}
140
141impl Tpm {
142    pub fn new(backend: Box<dyn TpmBackend>, base_features: u64) -> Tpm {
143        Tpm {
144            backend: Some(backend),
145            worker_thread: None,
146            features: base_features,
147        }
148    }
149}
150
151impl VirtioDevice for Tpm {
152    fn keep_rds(&self) -> Vec<RawDescriptor> {
153        Vec::new()
154    }
155
156    fn device_type(&self) -> DeviceType {
157        DeviceType::Tpm
158    }
159
160    fn queue_max_sizes(&self) -> &[u16] {
161        QUEUE_SIZES
162    }
163
164    fn features(&self) -> u64 {
165        self.features
166    }
167
168    fn activate(
169        &mut self,
170        _mem: GuestMemory,
171        _interrupt: Interrupt,
172        mut queues: BTreeMap<usize, Queue>,
173    ) -> anyhow::Result<()> {
174        if queues.len() != 1 {
175            return Err(anyhow!("expected 1 queue, got {}", queues.len()));
176        }
177        let queue = queues.pop_first().unwrap().1;
178
179        let backend = self.backend.take().context("no backend in vtpm")?;
180
181        let worker = Worker { queue, backend };
182
183        self.worker_thread = Some(WorkerThread::start("v_tpm", |kill_evt| {
184            if let Err(e) = worker.run(kill_evt) {
185                error!("virtio-tpm worker failed: {:#}", e);
186            }
187        }));
188
189        Ok(())
190    }
191}
192
/// Whether the worker should raise an interrupt after processing the queue.
#[derive(PartialEq, Eq)]
enum NeedsInterrupt {
    Yes,
    No,
}

impl BitOrAssign for NeedsInterrupt {
    /// Accumulates requests: once `Yes` has been or-ed in, the value stays `Yes`.
    fn bitor_assign(&mut self, rhs: NeedsInterrupt) {
        match rhs {
            NeedsInterrupt::Yes => *self = NeedsInterrupt::Yes,
            NeedsInterrupt::No => {}
        }
    }
}
206
207type Result<T> = std::result::Result<T, Error>;
208
209#[sorted]
210#[derive(Error, Debug)]
211enum Error {
212    #[error("vtpm response buffer is too small: {size} < {required} bytes")]
213    BufferTooSmall { size: usize, required: usize },
214    #[error("vtpm command is too long: {size} > {} bytes", TPM_BUFSIZE)]
215    CommandTooLong { size: usize },
216    #[error("vtpm failed to read from guest memory: {0}")]
217    Read(io::Error),
218    #[error(
219        "vtpm simulator generated a response that is unexpectedly long: {size} > {} bytes",
220        TPM_BUFSIZE
221    )]
222    ResponseTooLong { size: usize },
223    #[error("vtpm failed to write to guest memory: {0}")]
224    Write(io::Error),
225}