devices/usb/xhci/
scatter_gather_buffer.rs1use bit_field::Error as BitFieldError;
6use remain::sorted;
7use thiserror::Error;
8use vm_memory::GuestAddress;
9use vm_memory::GuestMemory;
10use vm_memory::GuestMemoryError;
11
12use super::xhci_abi::AddressedTrb;
13use super::xhci_abi::Error as TrbError;
14use super::xhci_abi::NormalTrb;
15use super::xhci_abi::TransferDescriptor;
16use super::xhci_abi::TrbCast;
17use super::xhci_abi::TrbType;
18
/// Errors produced while building or using a `ScatterGatherBuffer`.
#[sorted]
#[derive(Error, Debug)]
pub enum Error {
    /// A TRB in the transfer descriptor is not a Normal, DataStage or Isoch TRB.
    #[error("should not build buffer from trb type: {0:?}")]
    BadTrbType(TrbType),
    /// A TRB could not be cast to a `NormalTrb`.
    #[error("cannot cast trb: {0}")]
    CastTrb(TrbError),
    /// A TRB carrying immediate data declares a transfer length above 8 bytes; the payload is
    /// the declared length.
    #[error("immediate data longer than allowed: {0}")]
    ImmediateDataTooLong(usize),
    /// Reading from guest memory failed.
    #[error("cannot read guest memory: {0}")]
    ReadGuestMemory(GuestMemoryError),
    /// The type field of a TRB could not be decoded.
    #[error("unknown trb type: {0}")]
    UnknownTrbType(BitFieldError),
    /// Writing to guest memory failed.
    #[error("cannot write guest memory: {0}")]
    WriteGuestMemory(GuestMemoryError),
}
35
/// Result type used throughout this module, with this module's `Error` as the error type.
type Result<T> = std::result::Result<T, Error>;
37
/// Presents the data buffers referenced by the TRBs of one transfer descriptor as a single
/// logical byte buffer that can be read from or written to in TRB order.
pub struct ScatterGatherBuffer {
    /// Guest memory in which the TRB data buffers live.
    mem: GuestMemory,
    /// The TRBs whose data buffers make up this scatter-gather buffer.
    td: TransferDescriptor,
}
44
45impl ScatterGatherBuffer {
46 pub fn new(mem: GuestMemory, td: TransferDescriptor) -> Result<ScatterGatherBuffer> {
48 for atrb in &td {
49 let trb_type = atrb.trb.get_trb_type().map_err(Error::UnknownTrbType)?;
50 if trb_type != TrbType::Normal
51 && trb_type != TrbType::DataStage
52 && trb_type != TrbType::Isoch
53 {
54 return Err(Error::BadTrbType(trb_type));
55 }
56 }
57 Ok(ScatterGatherBuffer { mem, td })
58 }
59
60 pub fn len(&self) -> Result<usize> {
62 let mut total_len = 0usize;
63 for atrb in &self.td {
64 total_len += atrb
65 .trb
66 .cast::<NormalTrb>()
67 .map_err(Error::CastTrb)?
68 .get_trb_transfer_length() as usize;
69 }
70 Ok(total_len)
71 }
72
73 pub fn is_empty(&self) -> Result<bool> {
74 Ok(self.len()? == 0)
75 }
76
77 fn get_trb_data(&self, atrb: &AddressedTrb) -> Result<(GuestAddress, usize)> {
81 let normal_trb = atrb.trb.cast::<NormalTrb>().map_err(Error::CastTrb)?;
82 let len = normal_trb.get_trb_transfer_length() as usize;
83 let addr = if normal_trb.get_immediate_data() == 1 {
84 if len > 8 {
86 return Err(Error::ImmediateDataTooLong(len));
87 }
88 atrb.gpa
89 } else {
90 normal_trb.get_data_buffer_pointer()
91 };
92 Ok((GuestAddress(addr), len))
93 }
94
95 pub fn read(&self, buffer: &mut [u8]) -> Result<usize> {
97 let mut total_size = 0usize;
98 let mut offset = 0;
99 for atrb in &self.td {
100 let (guest_address, len) = self.get_trb_data(atrb)?;
101 let buffer_len = {
102 if offset == buffer.len() {
103 return Ok(total_size);
104 }
105 if buffer.len() > offset + len {
106 len
107 } else {
108 buffer.len() - offset
109 }
110 };
111 let buffer_end = offset + buffer_len;
112 let cur_buffer = &mut buffer[offset..buffer_end];
113 offset = buffer_end;
114 total_size += self
115 .mem
116 .read_at_addr(cur_buffer, guest_address)
117 .map_err(Error::ReadGuestMemory)?;
118 }
119 Ok(total_size)
120 }
121
122 pub fn write(&self, buffer: &[u8]) -> Result<usize> {
124 let mut total_size = 0usize;
125 let mut offset = 0;
126 for atrb in &self.td {
127 let (guest_address, len) = self.get_trb_data(atrb)?;
128 let buffer_len = {
129 if offset == buffer.len() {
130 return Ok(total_size);
131 }
132 if buffer.len() > offset + len {
133 len
134 } else {
135 buffer.len() - offset
136 }
137 };
138 let buffer_end = offset + buffer_len;
139 let cur_buffer = &buffer[offset..buffer_end];
140 offset = buffer_end;
141 total_size += self
142 .mem
143 .write_at_addr(cur_buffer, guest_address)
144 .map_err(Error::WriteGuestMemory)?;
145 }
146 Ok(total_size)
147 }
148}
149
#[cfg(test)]
mod test {
    use base::pagesize;

    use super::*;
    use crate::usb::xhci::xhci_abi::AddressedTrb;
    use crate::usb::xhci::xhci_abi::Trb;

    /// Writes through a three-TRB scatter-gather buffer, checks where each piece landed in guest
    /// memory, then reads everything back.
    #[test]
    fn scatter_gather_buffer_test() {
        let memory = GuestMemory::new(&[(GuestAddress(0), pagesize() as u64)]).unwrap();

        // One Normal TRB per (data buffer pointer, transfer length) pair.
        let mut td = TransferDescriptor::new();
        for (pointer, length) in [(0x100, 4), (0x200, 2), (0x300, 1)] {
            let mut trb = Trb::new();
            let ntrb = trb.cast_mut::<NormalTrb>().unwrap();
            ntrb.set_trb_type(TrbType::Normal);
            ntrb.set_data_buffer_pointer(pointer);
            ntrb.set_trb_transfer_length(length);
            td.push(AddressedTrb { trb, gpa: 0 });
        }

        let sg_buffer = ScatterGatherBuffer::new(memory.clone(), td).unwrap();
        assert_eq!(sg_buffer.len().unwrap(), 7);

        let data_to_write: [u8; 7] = [7, 6, 5, 4, 3, 2, 1];
        sg_buffer.write(&data_to_write).unwrap();

        // Each TRB's region holds its share of the data; bytes past the TRB length stay zero.
        let expected_regions = [
            (0x100, [7, 6, 5, 4]),
            (0x200, [3, 2, 0, 0]),
            (0x300, [1, 0, 0, 0]),
        ];
        for (address, expected) in expected_regions {
            let mut region = [0; 4];
            memory
                .read_exact_at_addr(&mut region, GuestAddress(address))
                .unwrap();
            assert_eq!(region, expected);
        }

        let mut data_read = [0; 7];
        sg_buffer.read(&mut data_read).unwrap();
        assert_eq!(data_to_write, data_read);
    }

    /// Reads a TRB flagged as immediate data and expects the bytes stored in the TRB itself.
    #[test]
    fn immediate_data_test() {
        let memory = GuestMemory::new(&[(GuestAddress(0), pagesize() as u64)]).unwrap();

        let expected_immediate_data: [u8; 8] = [0xDE, 0xAD, 0xBE, 0xEF, 0xF0, 0x0D, 0xCA, 0xFE];

        // The payload is stored in the data buffer pointer field of the TRB.
        let mut trb = Trb::new();
        let ntrb = trb.cast_mut::<NormalTrb>().unwrap();
        ntrb.set_trb_type(TrbType::Normal);
        ntrb.set_data_buffer_pointer(u64::from_le_bytes(expected_immediate_data));
        ntrb.set_trb_transfer_length(8);
        ntrb.set_immediate_data(1);

        // Place the TRB in guest memory at the address recorded as its `gpa`.
        let trb_gpa = 0xC00;
        memory
            .write_obj_at_addr(trb, GuestAddress(trb_gpa))
            .unwrap();

        let mut td = TransferDescriptor::new();
        td.push(AddressedTrb { trb, gpa: trb_gpa });

        let sg_buffer = ScatterGatherBuffer::new(memory, td).unwrap();

        let mut data_read = [0; 8];
        sg_buffer.read(&mut data_read).unwrap();
        assert_eq!(data_read, expected_immediate_data);
    }
}