1use std::collections::BTreeMap;
6use std::io;
7use std::io::Read;
8use std::io::Write;
9use std::ops::BitOrAssign;
10
11use anyhow::anyhow;
12use anyhow::Context;
13use base::error;
14use base::Event;
15use base::EventToken;
16use base::RawDescriptor;
17use base::WaitContext;
18use base::WorkerThread;
19use remain::sorted;
20use thiserror::Error;
21use vm_memory::GuestMemory;
22
23use super::DescriptorChain;
24use super::DeviceType;
25use super::Interrupt;
26use super::Queue;
27use super::VirtioDevice;
28
/// Number of descriptors in the device's single virtqueue.
// NOTE(review): presumably sized for one in-flight request (command buffer +
// response buffer); confirm against the guest driver before changing.
const QUEUE_SIZE: u16 = 2;
/// Per-queue maximum sizes returned from `VirtioDevice::queue_max_sizes`:
/// exactly one queue of `QUEUE_SIZE` entries.
const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE; 1];

/// Upper bound, in bytes, on both a guest TPM command and a backend response.
const TPM_BUFSIZE: usize = 4 * 1024;
39
/// State moved onto the virtio-tpm worker thread when the device is activated:
/// the device's only virtqueue and the backend that executes TPM commands.
struct Worker {
    // Descriptor chains are popped from this queue and completed with the
    // number of response bytes written.
    queue: Queue,
    // Executes each TPM command read from the guest.
    backend: Box<dyn TpmBackend>,
}
44
/// A TPM implementation to which the virtio-tpm device forwards guest commands.
///
/// The backend is moved into the device's worker thread on activation, hence
/// the `Send` supertrait.
pub trait TpmBackend: Send {
    /// Executes one raw TPM command and returns the raw response bytes.
    ///
    /// The returned slice borrows from `self` (lifetime elision ties it to the
    /// `&mut self` receiver). It therefore stays valid only until the backend is
    /// next used.
    fn execute_command(&mut self, command: &[u8]) -> &[u8];
}
48
49impl Worker {
50 fn perform_work(&mut self, desc: &mut DescriptorChain) -> Result<u32> {
51 let available_bytes = desc.reader.available_bytes();
52 if available_bytes > TPM_BUFSIZE {
53 return Err(Error::CommandTooLong {
54 size: available_bytes,
55 });
56 }
57
58 let mut command = vec![0u8; available_bytes];
59 desc.reader.read_exact(&mut command).map_err(Error::Read)?;
60
61 let response = self.backend.execute_command(&command);
62
63 if response.len() > TPM_BUFSIZE {
64 return Err(Error::ResponseTooLong {
65 size: response.len(),
66 });
67 }
68
69 let writer_len = desc.writer.available_bytes();
70 if response.len() > writer_len {
71 return Err(Error::BufferTooSmall {
72 size: writer_len,
73 required: response.len(),
74 });
75 }
76
77 desc.writer.write_all(response).map_err(Error::Write)?;
78
79 Ok(desc.writer.bytes_written() as u32)
80 }
81
82 fn process_queue(&mut self) -> NeedsInterrupt {
83 let mut needs_interrupt = NeedsInterrupt::No;
84 while let Some(mut avail_desc) = self.queue.pop() {
85 let len = match self.perform_work(&mut avail_desc) {
86 Ok(len) => len,
87 Err(err) => {
88 error!("{}", err);
89 0
90 }
91 };
92
93 self.queue.add_used_with_bytes_written(avail_desc, len);
94 needs_interrupt = NeedsInterrupt::Yes;
95 }
96
97 needs_interrupt
98 }
99
100 fn run(mut self, kill_evt: Event) -> anyhow::Result<()> {
101 #[derive(EventToken, Debug)]
102 enum Token {
103 QueueAvailable,
105 Kill,
107 }
108
109 let wait_ctx = WaitContext::build_with(&[
110 (self.queue.event(), Token::QueueAvailable),
111 (&kill_evt, Token::Kill),
112 ])
113 .context("WaitContext::build_with")?;
114
115 loop {
116 let events = wait_ctx.wait().context("WaitContext::wait")?;
117 let mut needs_interrupt = NeedsInterrupt::No;
118 for event in events.iter().filter(|e| e.is_readable) {
119 match event.token {
120 Token::QueueAvailable => {
121 self.queue.event().wait().context("Event::wait")?;
122 needs_interrupt |= self.process_queue();
123 }
124 Token::Kill => return Ok(()),
125 }
126 }
127 if needs_interrupt == NeedsInterrupt::Yes {
128 self.queue.trigger_interrupt();
129 }
130 }
131 }
132}
133
/// Virtio TPM device that forwards guest TPM commands to a `TpmBackend`.
pub struct Tpm {
    // Backend waiting to be handed to the worker; `activate` takes it, leaving
    // `None`.
    backend: Option<Box<dyn TpmBackend>>,
    // Handle to the worker thread started by `activate`; `None` before that.
    worker_thread: Option<WorkerThread<()>>,
    // Feature bits returned by `features()`, taken from `base_features` in `new`.
    features: u64,
}
140
141impl Tpm {
142 pub fn new(backend: Box<dyn TpmBackend>, base_features: u64) -> Tpm {
143 Tpm {
144 backend: Some(backend),
145 worker_thread: None,
146 features: base_features,
147 }
148 }
149}
150
151impl VirtioDevice for Tpm {
152 fn keep_rds(&self) -> Vec<RawDescriptor> {
153 Vec::new()
154 }
155
156 fn device_type(&self) -> DeviceType {
157 DeviceType::Tpm
158 }
159
160 fn queue_max_sizes(&self) -> &[u16] {
161 QUEUE_SIZES
162 }
163
164 fn features(&self) -> u64 {
165 self.features
166 }
167
168 fn activate(
169 &mut self,
170 _mem: GuestMemory,
171 _interrupt: Interrupt,
172 mut queues: BTreeMap<usize, Queue>,
173 ) -> anyhow::Result<()> {
174 if queues.len() != 1 {
175 return Err(anyhow!("expected 1 queue, got {}", queues.len()));
176 }
177 let queue = queues.pop_first().unwrap().1;
178
179 let backend = self.backend.take().context("no backend in vtpm")?;
180
181 let worker = Worker { queue, backend };
182
183 self.worker_thread = Some(WorkerThread::start("v_tpm", |kill_evt| {
184 if let Err(e) = worker.run(kill_evt) {
185 error!("virtio-tpm worker failed: {:#}", e);
186 }
187 }));
188
189 Ok(())
190 }
191}
192
/// Whether the guest should be signalled after a round of queue processing.
///
/// Results are combined with `|=`. Once any step reports `Yes`, the
/// accumulated value stays `Yes`.
// Fieldless enum: derive the common traits so it can be copied freely and shown
// in `{:?}` diagnostics and `assert_eq!`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum NeedsInterrupt {
    Yes,
    No,
}

impl BitOrAssign for NeedsInterrupt {
    /// Logical OR: `Yes` is sticky; `No` leaves the current value unchanged.
    fn bitor_assign(&mut self, rhs: NeedsInterrupt) {
        if matches!(rhs, NeedsInterrupt::Yes) {
            *self = NeedsInterrupt::Yes;
        }
    }
}
206
/// Result alias used by the worker's per-descriptor processing (`perform_work`).
type Result<T> = std::result::Result<T, Error>;

/// Failures that can occur while servicing a single virtio-tpm request.
///
/// `#[sorted]` requires the variants to stay in alphabetical order.
#[sorted]
#[derive(Error, Debug)]
enum Error {
    // The writable part of the descriptor chain (`size`) is smaller than the
    // response (`required`).
    #[error("vtpm response buffer is too small: {size} < {required} bytes")]
    BufferTooSmall { size: usize, required: usize },
    // The guest's command is larger than TPM_BUFSIZE.
    #[error("vtpm command is too long: {size} > {} bytes", TPM_BUFSIZE)]
    CommandTooLong { size: usize },
    // Reading the command out of the descriptor chain failed.
    #[error("vtpm failed to read from guest memory: {0}")]
    Read(io::Error),
    // The backend returned a response larger than TPM_BUFSIZE.
    #[error(
        "vtpm simulator generated a response that is unexpectedly long: {size} > {} bytes",
        TPM_BUFSIZE
    )]
    ResponseTooLong { size: usize },
    // Writing the response into the descriptor chain failed.
    #[error("vtpm failed to write to guest memory: {0}")]
    Write(io::Error),
}