1use std::cmp::min;
8use std::fs::File;
9use std::io;
10use std::io::ErrorKind;
11use std::io::Read;
12use std::io::Seek;
13use std::sync::Arc;
14use std::sync::RwLock;
15
16use anyhow::bail;
17use anyhow::Context;
18use async_trait::async_trait;
19use base::AsRawDescriptor;
20use base::FileAllocate;
21use base::FileReadWriteAtVolatile;
22use base::FileSetLen;
23use base::RawDescriptor;
24use base::VolatileSlice;
25use cros_async::BackingMemory;
26use cros_async::Executor;
27use cros_async::IoSource;
28
29use crate::AsyncDisk;
30use crate::DiskFile;
31use crate::DiskGetLen;
32use crate::Error as DiskError;
33use crate::Result as DiskResult;
34use crate::ToAsyncDisk;
35
/// Magic number that begins every regular Zstandard frame.
pub const ZSTD_FRAME_MAGIC: u32 = 0xFD2F_B528;

/// Lowest magic number of the Zstandard skippable-frame range.
pub const ZSTD_SKIPPABLE_MAGIC_LOW: u32 = 0x184D_2A50;
/// Highest magic number of the Zstandard skippable-frame range.
pub const ZSTD_SKIPPABLE_MAGIC_HIGH: u32 = 0x184D_2A5F;
/// Magic number stored in the last four bytes of a seekable-format seek table footer.
pub const ZSTD_SEEK_TABLE_MAGIC: u32 = 0x8F92_EAB1;
43
/// Capacity, in bytes (128 KiB), of the buffer each frame is decompressed into.
pub const ZSTD_DEFAULT_FRAME_SIZE: usize = 128 * 1024;

/// Frame index of a zstd seekable-format file, stored as running totals.
///
/// Both vectors start with `0`; entry `i + 1` minus entry `i` is the size of frame `i`, and the
/// last entry is the total size.
#[derive(Clone, Debug)]
pub struct ZstdSeekTable {
    /// Running totals of decompressed frame sizes, i.e. the decompressed offset at which each
    /// frame starts, followed by the total decompressed length.
    cumulative_decompressed_sizes: Vec<u64>,
    /// Running totals of compressed frame sizes, i.e. the file offset at which each frame
    /// starts, followed by the total compressed length.
    cumulative_compressed_sizes: Vec<u64>,
}
55
56impl ZstdSeekTable {
57 pub fn from_footer(
59 seek_table_entries: &[u8],
60 num_frames: u32,
61 checksum_flag: bool,
62 ) -> anyhow::Result<ZstdSeekTable> {
63 let mut cumulative_decompressed_size: u64 = 0;
64 let mut cumulative_compressed_size: u64 = 0;
65 let mut cumulative_decompressed_sizes = Vec::with_capacity(num_frames as usize + 1);
66 let mut cumulative_compressed_sizes = Vec::with_capacity(num_frames as usize + 1);
67 let mut offset = 0;
68 cumulative_decompressed_sizes.push(0);
69 cumulative_compressed_sizes.push(0);
70 for _ in 0..num_frames {
71 let compressed_size = u32::from_le_bytes(
72 seek_table_entries
73 .get(offset..offset + 4)
74 .context("failed to parse seektable entry")?
75 .try_into()?,
76 );
77 let decompressed_size = u32::from_le_bytes(
78 seek_table_entries
79 .get(offset + 4..offset + 8)
80 .context("failed to parse seektable entry")?
81 .try_into()?,
82 );
83 cumulative_decompressed_size += decompressed_size as u64;
84 cumulative_compressed_size += compressed_size as u64;
85 cumulative_decompressed_sizes.push(cumulative_decompressed_size);
86 cumulative_compressed_sizes.push(cumulative_compressed_size);
87 offset += 8 + (checksum_flag as usize * 4);
88 }
89 cumulative_decompressed_sizes.push(cumulative_decompressed_size);
90 cumulative_compressed_sizes.push(cumulative_compressed_size);
91
92 Ok(ZstdSeekTable {
93 cumulative_decompressed_sizes,
94 cumulative_compressed_sizes,
95 })
96 }
97
98 pub fn find_frame_index(&self, decompressed_offset: u64) -> Option<usize> {
100 if self.cumulative_decompressed_sizes.is_empty()
101 || decompressed_offset >= *self.cumulative_decompressed_sizes.last().unwrap()
102 {
103 return None;
104 }
105 self.cumulative_decompressed_sizes
106 .partition_point(|&size| size <= decompressed_offset)
107 .checked_sub(1)
108 }
109}
110
111#[derive(Debug)]
112pub struct ZstdDisk {
113 file: File,
114 seek_table: ZstdSeekTable,
115 cache: RwLock<Option<ZstdFrameCache>>,
116}
117
/// One fully decompressed frame kept for reuse by later reads.
#[derive(Debug)]
struct ZstdFrameCache {
    /// Index of the cached frame within the seek table.
    frame_index: usize,
    /// Decompressed contents of that frame.
    data: Vec<u8>,
}
123
124impl ZstdDisk {
125 pub fn from_file(mut file: File) -> anyhow::Result<ZstdDisk> {
126 if file.metadata()?.len() < 17 {
128 return Err(anyhow::anyhow!("File too small to contain zstd seek table"));
129 }
130
131 let mut seektable_footer = [0u8; 9];
133 file.seek(std::io::SeekFrom::End(-9))?;
134 file.read_exact(&mut seektable_footer)?;
135
136 if u32::from_le_bytes(seektable_footer[5..9].try_into()?) != ZSTD_SEEK_TABLE_MAGIC {
138 return Err(anyhow::anyhow!("Invalid zstd seek table magic"));
139 }
140
141 let num_frames = u32::from_le_bytes(seektable_footer[0..4].try_into()?);
143
144 let checksum_flag = (seektable_footer[4] >> 7) & 1 != 0;
146 if (seektable_footer[4] & 0x7C) != 0 {
147 bail!(
148 "This zstd seekable decoder cannot parse seek table with non-zero reserved flags"
149 );
150 }
151
152 let seek_table_entries_size = num_frames * (8 + (checksum_flag as u32 * 4));
153
154 file.seek(std::io::SeekFrom::End(
156 -(9 + seek_table_entries_size as i64),
157 ))?;
158
159 let mut seek_table_entries: Vec<u8> = vec![0u8; seek_table_entries_size as usize];
161 file.read_exact(&mut seek_table_entries)?;
162
163 let seek_table =
164 ZstdSeekTable::from_footer(&seek_table_entries, num_frames, checksum_flag)?;
165
166 Ok(ZstdDisk {
167 file,
168 seek_table,
169 cache: RwLock::new(None),
170 })
171 }
172}
173
174impl DiskGetLen for ZstdDisk {
175 fn get_len(&self) -> std::io::Result<u64> {
176 self.seek_table
177 .cumulative_decompressed_sizes
178 .last()
179 .copied()
180 .ok_or(io::ErrorKind::InvalidData.into())
181 }
182}
183
184impl FileSetLen for ZstdDisk {
185 fn set_len(&self, _len: u64) -> std::io::Result<()> {
186 Err(io::Error::new(
187 io::ErrorKind::PermissionDenied,
188 "unsupported operation",
189 ))
190 }
191}
192
193impl AsRawDescriptor for ZstdDisk {
194 fn as_raw_descriptor(&self) -> RawDescriptor {
195 self.file.as_raw_descriptor()
196 }
197}
198
/// Describes where one compressed frame lives in the backing file.
struct CompressedReadInstruction {
    /// Index of the frame within the seek table.
    frame_index: usize,
    /// Offset in the compressed file at which the frame starts.
    read_offset: u64,
    /// Length of the compressed frame in bytes.
    read_size: u64,
}
206
207fn compresed_frame_read_instruction(
208 seek_table: &ZstdSeekTable,
209 offset: u64,
210) -> anyhow::Result<CompressedReadInstruction> {
211 let frame_index = seek_table
212 .find_frame_index(offset)
213 .with_context(|| format!("no frame for offset {offset}"))?;
214 let compressed_offset = seek_table.cumulative_compressed_sizes[frame_index];
215 let next_compressed_offset = seek_table
216 .cumulative_compressed_sizes
217 .get(frame_index + 1)
218 .context("Offset out of range (next_compressed_offset overflow)")?;
219 let compressed_size = next_compressed_offset - compressed_offset;
220 Ok(CompressedReadInstruction {
221 frame_index,
222 read_offset: compressed_offset,
223 read_size: compressed_size,
224 })
225}
226
227fn copy_to_volatile_slice(src: &[u8], dst: VolatileSlice) -> io::Result<usize> {
228 let read_len = min(dst.size(), src.len());
229 let data_to_copy = &src[..read_len];
230 dst.sub_slice(0, read_len)
231 .map_err(io::Error::other)?
232 .copy_from(data_to_copy);
233 Ok(data_to_copy.len())
234}
235
236impl FileReadWriteAtVolatile for ZstdDisk {
237 fn read_at_volatile(&self, slice: VolatileSlice, offset: u64) -> io::Result<usize> {
238 let read_instruction = compresed_frame_read_instruction(&self.seek_table, offset)
239 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
240
241 if let Some(cache) = self.cache.try_read().ok().as_ref().and_then(|g| g.as_ref()) {
243 if cache.frame_index == read_instruction.frame_index {
244 let decompressed_offset_in_frame = offset
246 - self.seek_table.cumulative_decompressed_sizes[read_instruction.frame_index];
247 return copy_to_volatile_slice(
248 &cache.data[decompressed_offset_in_frame as usize..],
249 slice,
250 );
251 }
252 }
253
254 let mut compressed_data = vec![0u8; read_instruction.read_size as usize];
255
256 let compressed_frame_slice = VolatileSlice::new(compressed_data.as_mut_slice());
257
258 self.file
259 .read_at_volatile(compressed_frame_slice, read_instruction.read_offset)
260 .map_err(io::Error::other)?;
261
262 let mut decompressor: zstd::bulk::Decompressor<'_> = zstd::bulk::Decompressor::new()?;
263 let mut decompressed_data = Vec::with_capacity(ZSTD_DEFAULT_FRAME_SIZE);
264 let decoded_size =
265 decompressor.decompress_to_buffer(&compressed_data, &mut decompressed_data)?;
266
267 let decompressed_offset_in_frame =
268 offset - self.seek_table.cumulative_decompressed_sizes[read_instruction.frame_index];
269
270 if decompressed_offset_in_frame >= decoded_size as u64 {
271 return Err(io::Error::new(
272 io::ErrorKind::InvalidData,
273 "BUG: Frame offset larger than decoded size",
274 ));
275 }
276
277 let updated_cache = ZstdFrameCache {
278 frame_index: read_instruction.frame_index,
279 data: decompressed_data,
280 };
281
282 let result = copy_to_volatile_slice(
283 &updated_cache.data[decompressed_offset_in_frame as usize..],
284 slice,
285 );
286
287 if let Ok(mut cache) = self.cache.try_write() {
288 *cache = Some(updated_cache);
289 };
290 result
291 }
292
293 fn write_at_volatile(&self, _slice: VolatileSlice, _offset: u64) -> io::Result<usize> {
294 Err(io::Error::new(
295 io::ErrorKind::PermissionDenied,
296 "unsupported operation",
297 ))
298 }
299}
300
301pub struct AsyncZstdDisk {
302 inner: IoSource<File>,
303 seek_table: ZstdSeekTable,
304 cache: RwLock<Option<ZstdFrameCache>>,
305}
306
307impl ToAsyncDisk for ZstdDisk {
308 fn to_async_disk(self: Box<Self>, ex: &Executor) -> DiskResult<Box<dyn AsyncDisk>> {
309 Ok(Box::new(AsyncZstdDisk {
310 inner: ex.async_from(self.file).map_err(DiskError::ToAsync)?,
311 seek_table: self.seek_table,
312 cache: RwLock::new(None),
313 }))
314 }
315}
316
317impl DiskGetLen for AsyncZstdDisk {
318 fn get_len(&self) -> io::Result<u64> {
319 self.seek_table
320 .cumulative_decompressed_sizes
321 .last()
322 .copied()
323 .ok_or(io::ErrorKind::InvalidData.into())
324 }
325}
326
327impl FileSetLen for AsyncZstdDisk {
328 fn set_len(&self, _len: u64) -> io::Result<()> {
329 Err(io::Error::new(
330 io::ErrorKind::PermissionDenied,
331 "unsupported operation",
332 ))
333 }
334}
335
336impl FileAllocate for AsyncZstdDisk {
337 fn allocate(&self, _offset: u64, _length: u64) -> io::Result<()> {
338 Err(io::Error::new(
339 io::ErrorKind::PermissionDenied,
340 "unsupported operation",
341 ))
342 }
343}
344
345fn copy_to_mem(
346 decompressed_data: &[u8],
347 mem: Arc<dyn BackingMemory + Send + Sync>,
348 mem_offsets: cros_async::MemRegionIter,
349) -> DiskResult<usize> {
350 let mut total_copied = 0;
352 for mem_region in mem_offsets {
353 let src_slice = &decompressed_data[total_copied..];
354 let dst_slice = mem
355 .get_volatile_slice(mem_region)
356 .map_err(DiskError::GuestMemory)?;
357
358 let to_copy = min(src_slice.len(), dst_slice.size());
359
360 if to_copy > 0 {
361 dst_slice
362 .sub_slice(0, to_copy)
363 .map_err(|e| DiskError::ReadingData(io::Error::other(e)))?
364 .copy_from(&src_slice[..to_copy]);
365
366 total_copied += to_copy;
367
368 if total_copied == dst_slice.size() {
370 break;
371 }
372 }
373 }
374
375 Ok(total_copied)
376}
377
378#[async_trait(?Send)]
379impl AsyncDisk for AsyncZstdDisk {
380 async fn flush(&self) -> DiskResult<()> {
381 Ok(())
383 }
384
385 async fn fsync(&self) -> DiskResult<()> {
386 Ok(())
388 }
389
390 async fn fdatasync(&self) -> DiskResult<()> {
391 Ok(())
393 }
394
395 async fn read_to_mem<'a>(
400 &'a self,
401 file_offset: u64,
402 mem: Arc<dyn BackingMemory + Send + Sync>,
403 mem_offsets: cros_async::MemRegionIter<'a>,
404 ) -> DiskResult<usize> {
405 let read_instruction = compresed_frame_read_instruction(&self.seek_table, file_offset)
406 .map_err(|e| DiskError::ReadingData(io::Error::new(io::ErrorKind::InvalidData, e)))?;
407
408 if let Some(cache) = self.cache.try_read().ok().as_ref().and_then(|g| g.as_ref()) {
410 if cache.frame_index == read_instruction.frame_index {
411 let decompressed_offset_in_frame = file_offset
413 - self.seek_table.cumulative_decompressed_sizes[read_instruction.frame_index];
414 return copy_to_mem(
415 &cache.data[decompressed_offset_in_frame as usize..],
416 mem,
417 mem_offsets,
418 );
419 }
420 }
421
422 let compressed_data = vec![0u8; read_instruction.read_size as usize];
423
424 let (compressed_read_size, compressed_data) = self
425 .inner
426 .read_to_vec(Some(read_instruction.read_offset), compressed_data)
427 .await
428 .map_err(|e| DiskError::ReadingData(io::Error::other(e)))?;
429
430 if compressed_read_size != read_instruction.read_size as usize {
431 return Err(DiskError::ReadingData(io::Error::new(
432 ErrorKind::UnexpectedEof,
433 "Read from compressed data result in wrong length",
434 )));
435 }
436
437 let mut decompressor: zstd::bulk::Decompressor<'_> =
438 zstd::bulk::Decompressor::new().map_err(DiskError::ReadingData)?;
439 let mut decompressed_data = Vec::with_capacity(ZSTD_DEFAULT_FRAME_SIZE);
440 let decoded_size = decompressor
441 .decompress_to_buffer(&compressed_data, &mut decompressed_data)
442 .map_err(DiskError::ReadingData)?;
443
444 let decompressed_offset_in_frame = file_offset
445 - self.seek_table.cumulative_decompressed_sizes[read_instruction.frame_index];
446
447 if decompressed_offset_in_frame as usize > decoded_size {
448 return Err(DiskError::ReadingData(io::Error::new(
449 ErrorKind::InvalidData,
450 "BUG: Frame offset larger than decoded size",
451 )));
452 }
453
454 let result = copy_to_mem(
456 &decompressed_data[decompressed_offset_in_frame as usize..],
457 mem,
458 mem_offsets,
459 );
460
461 let updated_cache = ZstdFrameCache {
462 frame_index: read_instruction.frame_index,
463 data: decompressed_data,
464 };
465
466 if let Ok(mut cache) = self.cache.try_write() {
467 *cache = Some(updated_cache);
468 };
469 result
470 }
471
472 async fn write_from_mem<'a>(
473 &'a self,
474 _file_offset: u64,
475 _mem: Arc<dyn BackingMemory + Send + Sync>,
476 _mem_offsets: cros_async::MemRegionIter<'a>,
477 ) -> DiskResult<usize> {
478 Err(DiskError::UnsupportedOperation)
479 }
480
481 async fn punch_hole(&self, _file_offset: u64, _length: u64) -> DiskResult<()> {
482 Err(DiskError::UnsupportedOperation)
483 }
484
485 async fn write_zeroes_at(&self, _file_offset: u64, _length: u64) -> DiskResult<()> {
486 Err(DiskError::UnsupportedOperation)
487 }
488}
489
490impl DiskFile for ZstdDisk {}
491
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a seek table directly from its two running-total vectors.
    fn seek_table(decompressed: &[u64], compressed: &[u64]) -> ZstdSeekTable {
        ZstdSeekTable {
            cumulative_decompressed_sizes: decompressed.to_vec(),
            cumulative_compressed_sizes: compressed.to_vec(),
        }
    }

    /// Checks every `(offset, expected frame index)` pair against `table`.
    fn assert_lookups(table: &ZstdSeekTable, cases: &[(u64, Option<usize>)]) {
        for &(offset, expected) in cases {
            assert_eq!(table.find_frame_index(offset), expected);
        }
    }

    #[test]
    fn test_find_frame_index_empty() {
        let table = seek_table(&[0], &[0]);
        assert_lookups(&table, &[(0, None), (5, None)]);
    }

    #[test]
    fn test_find_frame_index_single_frame() {
        let table = seek_table(&[0, 100], &[0, 50]);
        assert_lookups(
            &table,
            &[(0, Some(0)), (50, Some(0)), (99, Some(0)), (100, None)],
        );
    }

    #[test]
    fn test_find_frame_index_multiple_frames() {
        let table = seek_table(&[0, 100, 300, 500], &[0, 50, 120, 200]);
        assert_lookups(
            &table,
            &[
                (0, Some(0)),
                (99, Some(0)),
                (100, Some(1)),
                (299, Some(1)),
                (300, Some(2)),
                (499, Some(2)),
                (500, None),
                (1000, None),
            ],
        );
    }

    #[test]
    fn test_find_frame_index_with_skippable_frames() {
        // Frames 1 and 2 have zero decompressed size, so offset 100 belongs to frame 3.
        let table = seek_table(&[0, 100, 100, 100, 300], &[0, 50, 60, 70, 150]);
        assert_lookups(
            &table,
            &[
                (0, Some(0)),
                (99, Some(0)),
                (100, Some(3)),
                (299, Some(3)),
                (300, None),
            ],
        );
    }

    #[test]
    fn test_find_frame_index_with_last_skippable_frame() {
        // The final frame has zero decompressed size and must never be selected.
        let table = seek_table(
            &[0, 20, 40, 40, 60, 60, 80, 80],
            &[0, 10, 20, 30, 40, 50, 60, 70],
        );
        assert_lookups(
            &table,
            &[
                (0, Some(0)),
                (20, Some(1)),
                (21, Some(1)),
                (79, Some(5)),
                (80, None),
                (300, None),
            ],
        );
    }
}