disk/zstd.rs

// Copyright 2024 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! Use a seekable zstd archive of a raw disk image as a read-only disk.
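//!
//! Layout notes, per the zstd seekable format: the raw image is compressed as a series of
//! independent zstd frames, and the archive ends with a skippable frame containing a seek
//! table. Each seek table entry records a frame's compressed and decompressed size, and the
//! table ends with a 9-byte footer: a 4-byte frame count, a 1-byte descriptor (checksum flag
//! plus reserved bits), and the 4-byte seek table magic `0x8F92EAB1`. Reading at an arbitrary
//! decompressed offset therefore only requires decompressing the single frame that contains
//! it; the decompression buffers below are sized for frames holding up to
//! `ZSTD_DEFAULT_FRAME_SIZE` (128 KiB) of uncompressed data.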

use std::cmp::min;
use std::fs::File;
use std::io;
use std::io::ErrorKind;
use std::io::Read;
use std::io::Seek;
use std::sync::Arc;
use std::sync::RwLock;

use anyhow::bail;
use anyhow::Context;
use async_trait::async_trait;
use base::AsRawDescriptor;
use base::FileAllocate;
use base::FileReadWriteAtVolatile;
use base::FileSetLen;
use base::RawDescriptor;
use base::VolatileSlice;
use cros_async::BackingMemory;
use cros_async::Executor;
use cros_async::IoSource;

use crate::AsyncDisk;
use crate::DiskFile;
use crate::DiskGetLen;
use crate::Error as DiskError;
use crate::Result as DiskResult;
use crate::ToAsyncDisk;

// Zstandard frame magic
pub const ZSTD_FRAME_MAGIC: u32 = 0xFD2FB528;

// Skippable frame magic can be anything between [0x184D2A50, 0x184D2A5F]
pub const ZSTD_SKIPPABLE_MAGIC_LOW: u32 = 0x184D2A50;
pub const ZSTD_SKIPPABLE_MAGIC_HIGH: u32 = 0x184D2A5F;
pub const ZSTD_SEEK_TABLE_MAGIC: u32 = 0x8F92EAB1;

pub const ZSTD_DEFAULT_FRAME_SIZE: usize = 128 << 10; // 128KB

#[derive(Clone, Debug)]
pub struct ZstdSeekTable {
    // Cumulative sum of decompressed sizes of all frames before the indexed frame.
    // The last element is the total decompressed size of the zstd archive.
    cumulative_decompressed_sizes: Vec<u64>,
    // Cumulative sum of compressed sizes of all frames before the indexed frame.
    // The last element is the total compressed size of the zstd archive.
    cumulative_compressed_sizes: Vec<u64>,
}

impl ZstdSeekTable {
    /// Parses the raw seek table entry bytes in `seek_table_entries` into a `ZstdSeekTable`.
    pub fn from_footer(
        seek_table_entries: &[u8],
        num_frames: u32,
        checksum_flag: bool,
    ) -> anyhow::Result<ZstdSeekTable> {
        let mut cumulative_decompressed_size: u64 = 0;
        let mut cumulative_compressed_size: u64 = 0;
        let mut cumulative_decompressed_sizes = Vec::with_capacity(num_frames as usize + 1);
        let mut cumulative_compressed_sizes = Vec::with_capacity(num_frames as usize + 1);
        let mut offset = 0;
        cumulative_decompressed_sizes.push(0);
        cumulative_compressed_sizes.push(0);
        for _ in 0..num_frames {
            let compressed_size = u32::from_le_bytes(
                seek_table_entries
                    .get(offset..offset + 4)
                    .context("failed to parse seektable entry")?
                    .try_into()?,
            );
            let decompressed_size = u32::from_le_bytes(
                seek_table_entries
                    .get(offset + 4..offset + 8)
                    .context("failed to parse seektable entry")?
                    .try_into()?,
            );
            cumulative_decompressed_size += decompressed_size as u64;
            cumulative_compressed_size += compressed_size as u64;
            cumulative_decompressed_sizes.push(cumulative_decompressed_size);
            cumulative_compressed_sizes.push(cumulative_compressed_size);
            offset += 8 + (checksum_flag as usize * 4);
        }
        cumulative_decompressed_sizes.push(cumulative_decompressed_size);
        cumulative_compressed_sizes.push(cumulative_compressed_size);

        Ok(ZstdSeekTable {
            cumulative_decompressed_sizes,
            cumulative_compressed_sizes,
        })
    }

    /// Returns the index of the frame that contains the given decompressed offset.
    pub fn find_frame_index(&self, decompressed_offset: u64) -> Option<usize> {
        if self.cumulative_decompressed_sizes.is_empty()
            || decompressed_offset >= *self.cumulative_decompressed_sizes.last().unwrap()
        {
            return None;
        }
        self.cumulative_decompressed_sizes
            .partition_point(|&size| size <= decompressed_offset)
            .checked_sub(1)
    }
}

/// A read-only disk backed by a zstd seekable archive of a raw disk image.
#[derive(Debug)]
pub struct ZstdDisk {
    file: File,
    seek_table: ZstdSeekTable,
    cache: RwLock<Option<ZstdFrameCache>>,
}

/// The most recently decompressed frame, kept so that consecutive reads within the same frame
/// do not decompress it again.
#[derive(Debug)]
struct ZstdFrameCache {
    frame_index: usize,
    data: Vec<u8>,
}

impl ZstdDisk {
    /// Parses the seek table from `file` and returns a `ZstdDisk` ready for reads.
    pub fn from_file(mut file: File) -> anyhow::Result<ZstdDisk> {
        // Verify the file is large enough to contain a seek table: a skippable frame header
        // (8 bytes) plus the 9-byte seek table footer.
        if file.metadata()?.len() < 17 {
            return Err(anyhow::anyhow!("File too small to contain zstd seek table"));
        }

        // Read the last 9 bytes as the seek table footer
        let mut seektable_footer = [0u8; 9];
        file.seek(std::io::SeekFrom::End(-9))?;
        file.read_exact(&mut seektable_footer)?;

        // Verify the last 4 bytes of the footer are the seek table magic
        if u32::from_le_bytes(seektable_footer[5..9].try_into()?) != ZSTD_SEEK_TABLE_MAGIC {
            return Err(anyhow::anyhow!("Invalid zstd seek table magic"));
        }

        // Get the number of frames from the footer
        let num_frames = u32::from_le_bytes(seektable_footer[0..4].try_into()?);

        // Read flags from the seek table descriptor
        let checksum_flag = (seektable_footer[4] >> 7) & 1 != 0;
        if (seektable_footer[4] & 0x7C) != 0 {
            bail!(
                "This zstd seekable decoder cannot parse seek table with non-zero reserved flags"
            );
        }

        let seek_table_entries_size = num_frames * (8 + (checksum_flag as u32 * 4));

        // Seek to the beginning of the seek table entries
        file.seek(std::io::SeekFrom::End(
            -(9 + seek_table_entries_size as i64),
        ))?;

        // Read the seek table entries
        let mut seek_table_entries: Vec<u8> = vec![0u8; seek_table_entries_size as usize];
        file.read_exact(&mut seek_table_entries)?;

        let seek_table =
            ZstdSeekTable::from_footer(&seek_table_entries, num_frames, checksum_flag)?;

        Ok(ZstdDisk {
            file,
            seek_table,
            cache: RwLock::new(None),
        })
    }
}

impl DiskGetLen for ZstdDisk {
    fn get_len(&self) -> std::io::Result<u64> {
        self.seek_table
            .cumulative_decompressed_sizes
            .last()
            .copied()
            .ok_or(io::ErrorKind::InvalidData.into())
    }
}

impl FileSetLen for ZstdDisk {
    fn set_len(&self, _len: u64) -> std::io::Result<()> {
        Err(io::Error::new(
            io::ErrorKind::PermissionDenied,
            "unsupported operation",
        ))
    }
}

impl AsRawDescriptor for ZstdDisk {
    fn as_raw_descriptor(&self) -> RawDescriptor {
        self.file.as_raw_descriptor()
    }
}

/// The compressed byte range that must be read to decompress the frame containing a given
/// decompressed offset.
struct CompressedReadInstruction {
    frame_index: usize,
    // Byte offset into the compressed file at which to start reading.
    read_offset: u64,
    // Number of bytes to read from the compressed file.
    read_size: u64,
}

fn compressed_frame_read_instruction(
    seek_table: &ZstdSeekTable,
    offset: u64,
) -> anyhow::Result<CompressedReadInstruction> {
    let frame_index = seek_table
        .find_frame_index(offset)
        .with_context(|| format!("no frame for offset {offset}"))?;
    let compressed_offset = seek_table.cumulative_compressed_sizes[frame_index];
    let next_compressed_offset = seek_table
        .cumulative_compressed_sizes
        .get(frame_index + 1)
        .context("Offset out of range (next_compressed_offset overflow)")?;
    let compressed_size = next_compressed_offset - compressed_offset;
    Ok(CompressedReadInstruction {
        frame_index,
        read_offset: compressed_offset,
        read_size: compressed_size,
    })
}

/// Copies as many bytes of `src` as fit into `dst`, returning the number of bytes copied.
fn copy_to_volatile_slice(src: &[u8], dst: VolatileSlice) -> io::Result<usize> {
    let read_len = min(dst.size(), src.len());
    let data_to_copy = &src[..read_len];
    dst.sub_slice(0, read_len)
        .map_err(io::Error::other)?
        .copy_from(data_to_copy);
    Ok(data_to_copy.len())
}

impl FileReadWriteAtVolatile for ZstdDisk {
    fn read_at_volatile(&self, slice: VolatileSlice, offset: u64) -> io::Result<usize> {
        let read_instruction = compressed_frame_read_instruction(&self.seek_table, offset)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;

        // Try to obtain a read lock on the cache; if it is contended, fall through and
        // decompress the frame again.
        if let Some(cache) = self.cache.try_read().ok().as_ref().and_then(|g| g.as_ref()) {
            if cache.frame_index == read_instruction.frame_index {
                // Cache hit
                let decompressed_offset_in_frame = offset
                    - self.seek_table.cumulative_decompressed_sizes[read_instruction.frame_index];
                return copy_to_volatile_slice(
                    &cache.data[decompressed_offset_in_frame as usize..],
                    slice,
                );
            }
        }

        let mut compressed_data = vec![0u8; read_instruction.read_size as usize];

        let compressed_frame_slice = VolatileSlice::new(compressed_data.as_mut_slice());

        self.file
            .read_at_volatile(compressed_frame_slice, read_instruction.read_offset)
            .map_err(io::Error::other)?;

        let mut decompressor: zstd::bulk::Decompressor<'_> = zstd::bulk::Decompressor::new()?;
        let mut decompressed_data = Vec::with_capacity(ZSTD_DEFAULT_FRAME_SIZE);
        let decoded_size =
            decompressor.decompress_to_buffer(&compressed_data, &mut decompressed_data)?;

        let decompressed_offset_in_frame =
            offset - self.seek_table.cumulative_decompressed_sizes[read_instruction.frame_index];

        if decompressed_offset_in_frame >= decoded_size as u64 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "BUG: Frame offset larger than decoded size",
            ));
        }

        let updated_cache = ZstdFrameCache {
            frame_index: read_instruction.frame_index,
            data: decompressed_data,
        };

        let result = copy_to_volatile_slice(
            &updated_cache.data[decompressed_offset_in_frame as usize..],
            slice,
        );

        // Best effort: skip updating the cache if the lock is contended.
        if let Ok(mut cache) = self.cache.try_write() {
            *cache = Some(updated_cache);
        };
        result
    }

    fn write_at_volatile(&self, _slice: VolatileSlice, _offset: u64) -> io::Result<usize> {
        Err(io::Error::new(
            io::ErrorKind::PermissionDenied,
            "unsupported operation",
        ))
    }
}

pub struct AsyncZstdDisk {
    inner: IoSource<File>,
    seek_table: ZstdSeekTable,
    cache: RwLock<Option<ZstdFrameCache>>,
}

impl ToAsyncDisk for ZstdDisk {
    fn to_async_disk(self: Box<Self>, ex: &Executor) -> DiskResult<Box<dyn AsyncDisk>> {
        Ok(Box::new(AsyncZstdDisk {
            inner: ex.async_from(self.file).map_err(DiskError::ToAsync)?,
            seek_table: self.seek_table,
            cache: RwLock::new(None),
        }))
    }
}

impl DiskGetLen for AsyncZstdDisk {
    fn get_len(&self) -> io::Result<u64> {
        self.seek_table
            .cumulative_decompressed_sizes
            .last()
            .copied()
            .ok_or(io::ErrorKind::InvalidData.into())
    }
}

impl FileSetLen for AsyncZstdDisk {
    fn set_len(&self, _len: u64) -> io::Result<()> {
        Err(io::Error::new(
            io::ErrorKind::PermissionDenied,
            "unsupported operation",
        ))
    }
}

impl FileAllocate for AsyncZstdDisk {
    fn allocate(&self, _offset: u64, _length: u64) -> io::Result<()> {
        Err(io::Error::new(
            io::ErrorKind::PermissionDenied,
            "unsupported operation",
        ))
    }
}

/// Copies bytes from `decompressed_data` into the memory regions described by `mem_offsets`,
/// returning the number of bytes copied.
fn copy_to_mem(
    decompressed_data: &[u8],
    mem: Arc<dyn BackingMemory + Send + Sync>,
    mem_offsets: cros_async::MemRegionIter,
) -> DiskResult<usize> {
    // Copy the decompressed data to the provided memory regions.
    let mut total_copied = 0;
    for mem_region in mem_offsets {
        let src_slice = &decompressed_data[total_copied..];
        let dst_slice = mem
            .get_volatile_slice(mem_region)
            .map_err(DiskError::GuestMemory)?;

        let to_copy = min(src_slice.len(), dst_slice.size());

        if to_copy > 0 {
            dst_slice
                .sub_slice(0, to_copy)
                .map_err(|e| DiskError::ReadingData(io::Error::other(e)))?
                .copy_from(&src_slice[..to_copy]);

            total_copied += to_copy;

            // If the destination buffers are fully copied, break out of the loop.
            if total_copied == dst_slice.size() {
                break;
            }
        }
    }

    Ok(total_copied)
}

#[async_trait(?Send)]
impl AsyncDisk for AsyncZstdDisk {
    async fn flush(&self) -> DiskResult<()> {
        // zstd is read-only, nothing to flush.
        Ok(())
    }

    async fn fsync(&self) -> DiskResult<()> {
        // Do nothing because it's read-only.
        Ok(())
    }

    async fn fdatasync(&self) -> DiskResult<()> {
        // Do nothing because it's read-only.
        Ok(())
    }

    /// Reads data starting at `file_offset` of the decompressed disk image, up to the end of
    /// the containing zstd frame, and writes it into memory `mem` at `mem_offsets`. This
    /// behaves like running `preadv()` on the decompressed image with the array of `iovec`s
    /// specified by `mem` and `mem_offsets`.
    async fn read_to_mem<'a>(
        &'a self,
        file_offset: u64,
        mem: Arc<dyn BackingMemory + Send + Sync>,
        mem_offsets: cros_async::MemRegionIter<'a>,
    ) -> DiskResult<usize> {
        let read_instruction = compressed_frame_read_instruction(&self.seek_table, file_offset)
            .map_err(|e| DiskError::ReadingData(io::Error::new(io::ErrorKind::InvalidData, e)))?;

        // Try to obtain a read lock on the cache; if it is contended, fall through and
        // decompress the frame again.
        if let Some(cache) = self.cache.try_read().ok().as_ref().and_then(|g| g.as_ref()) {
            if cache.frame_index == read_instruction.frame_index {
                // Cache hit
                let decompressed_offset_in_frame = file_offset
                    - self.seek_table.cumulative_decompressed_sizes[read_instruction.frame_index];
                return copy_to_mem(
                    &cache.data[decompressed_offset_in_frame as usize..],
                    mem,
                    mem_offsets,
                );
            }
        }

        let compressed_data = vec![0u8; read_instruction.read_size as usize];

        let (compressed_read_size, compressed_data) = self
            .inner
            .read_to_vec(Some(read_instruction.read_offset), compressed_data)
            .await
            .map_err(|e| DiskError::ReadingData(io::Error::other(e)))?;

        if compressed_read_size != read_instruction.read_size as usize {
            return Err(DiskError::ReadingData(io::Error::new(
                ErrorKind::UnexpectedEof,
                "Read from compressed data resulted in wrong length",
            )));
        }

        let mut decompressor: zstd::bulk::Decompressor<'_> =
            zstd::bulk::Decompressor::new().map_err(DiskError::ReadingData)?;
        let mut decompressed_data = Vec::with_capacity(ZSTD_DEFAULT_FRAME_SIZE);
        let decoded_size = decompressor
            .decompress_to_buffer(&compressed_data, &mut decompressed_data)
            .map_err(DiskError::ReadingData)?;

        let decompressed_offset_in_frame = file_offset
            - self.seek_table.cumulative_decompressed_sizes[read_instruction.frame_index];

        if decompressed_offset_in_frame as usize >= decoded_size {
            return Err(DiskError::ReadingData(io::Error::new(
                ErrorKind::InvalidData,
                "BUG: Frame offset larger than decoded size",
            )));
        }

        // Copy the decompressed data to the provided memory regions.
        let result = copy_to_mem(
            &decompressed_data[decompressed_offset_in_frame as usize..],
            mem,
            mem_offsets,
        );

        let updated_cache = ZstdFrameCache {
            frame_index: read_instruction.frame_index,
            data: decompressed_data,
        };

        // Best effort: skip updating the cache if the lock is contended.
        if let Ok(mut cache) = self.cache.try_write() {
            *cache = Some(updated_cache);
        };
        result
    }

    async fn write_from_mem<'a>(
        &'a self,
        _file_offset: u64,
        _mem: Arc<dyn BackingMemory + Send + Sync>,
        _mem_offsets: cros_async::MemRegionIter<'a>,
    ) -> DiskResult<usize> {
        Err(DiskError::UnsupportedOperation)
    }

    async fn punch_hole(&self, _file_offset: u64, _length: u64) -> DiskResult<()> {
        Err(DiskError::UnsupportedOperation)
    }

    async fn write_zeroes_at(&self, _file_offset: u64, _length: u64) -> DiskResult<()> {
        Err(DiskError::UnsupportedOperation)
    }
}

impl DiskFile for ZstdDisk {}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_find_frame_index_empty() {
        let seek_table = ZstdSeekTable {
            cumulative_decompressed_sizes: vec![0],
            cumulative_compressed_sizes: vec![0],
        };
        assert_eq!(seek_table.find_frame_index(0), None);
        assert_eq!(seek_table.find_frame_index(5), None);
    }

    #[test]
    fn test_find_frame_index_single_frame() {
        let seek_table = ZstdSeekTable {
            cumulative_decompressed_sizes: vec![0, 100],
            cumulative_compressed_sizes: vec![0, 50],
        };
        assert_eq!(seek_table.find_frame_index(0), Some(0));
        assert_eq!(seek_table.find_frame_index(50), Some(0));
        assert_eq!(seek_table.find_frame_index(99), Some(0));
        assert_eq!(seek_table.find_frame_index(100), None);
    }

    #[test]
    fn test_find_frame_index_multiple_frames() {
        let seek_table = ZstdSeekTable {
            cumulative_decompressed_sizes: vec![0, 100, 300, 500],
            cumulative_compressed_sizes: vec![0, 50, 120, 200],
        };
        assert_eq!(seek_table.find_frame_index(0), Some(0));
        assert_eq!(seek_table.find_frame_index(99), Some(0));
        assert_eq!(seek_table.find_frame_index(100), Some(1));
        assert_eq!(seek_table.find_frame_index(299), Some(1));
        assert_eq!(seek_table.find_frame_index(300), Some(2));
        assert_eq!(seek_table.find_frame_index(499), Some(2));
        assert_eq!(seek_table.find_frame_index(500), None);
        assert_eq!(seek_table.find_frame_index(1000), None);
    }

    #[test]
    fn test_find_frame_index_with_skippable_frames() {
        let seek_table = ZstdSeekTable {
            cumulative_decompressed_sizes: vec![0, 100, 100, 100, 300],
            cumulative_compressed_sizes: vec![0, 50, 60, 70, 150],
        };
        assert_eq!(seek_table.find_frame_index(0), Some(0));
        assert_eq!(seek_table.find_frame_index(99), Some(0));
        // Correctly skips the skippable frames.
        assert_eq!(seek_table.find_frame_index(100), Some(3));
        assert_eq!(seek_table.find_frame_index(299), Some(3));
        assert_eq!(seek_table.find_frame_index(300), None);
    }

    #[test]
    fn test_find_frame_index_with_last_skippable_frame() {
        let seek_table = ZstdSeekTable {
            cumulative_decompressed_sizes: vec![0, 20, 40, 40, 60, 60, 80, 80],
            cumulative_compressed_sizes: vec![0, 10, 20, 30, 40, 50, 60, 70],
        };
        assert_eq!(seek_table.find_frame_index(0), Some(0));
        assert_eq!(seek_table.find_frame_index(20), Some(1));
        assert_eq!(seek_table.find_frame_index(21), Some(1));
        assert_eq!(seek_table.find_frame_index(79), Some(5));
        assert_eq!(seek_table.find_frame_index(80), None);
        assert_eq!(seek_table.find_frame_index(300), None);
    }
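
    // The tests below are additional sketches against the helpers defined in this file. The
    // seek table entry bytes are hand-built in the layout `from_footer` parses: a 4-byte LE
    // compressed size, a 4-byte LE decompressed size, and an optional 4-byte checksum that
    // the parser skips over.
    #[test]
    fn test_from_footer_without_checksums() {
        let mut entries = Vec::new();
        // Frame 0: 50 bytes compressed, 100 bytes decompressed.
        entries.extend_from_slice(&50u32.to_le_bytes());
        entries.extend_from_slice(&100u32.to_le_bytes());
        // Frame 1: 70 bytes compressed, 200 bytes decompressed.
        entries.extend_from_slice(&70u32.to_le_bytes());
        entries.extend_from_slice(&200u32.to_le_bytes());

        let seek_table = ZstdSeekTable::from_footer(&entries, 2, false).unwrap();
        assert_eq!(seek_table.find_frame_index(0), Some(0));
        assert_eq!(seek_table.find_frame_index(99), Some(0));
        assert_eq!(seek_table.find_frame_index(100), Some(1));
        assert_eq!(seek_table.find_frame_index(299), Some(1));
        assert_eq!(seek_table.find_frame_index(300), None);
    }

    #[test]
    fn test_from_footer_with_checksums_skips_checksum_bytes() {
        let mut entries = Vec::new();
        // Frame 0: 50 bytes compressed, 100 bytes decompressed, dummy checksum.
        entries.extend_from_slice(&50u32.to_le_bytes());
        entries.extend_from_slice(&100u32.to_le_bytes());
        entries.extend_from_slice(&0xDEADBEEFu32.to_le_bytes());
        // Frame 1: 70 bytes compressed, 200 bytes decompressed, dummy checksum.
        entries.extend_from_slice(&70u32.to_le_bytes());
        entries.extend_from_slice(&200u32.to_le_bytes());
        entries.extend_from_slice(&0xDEADBEEFu32.to_le_bytes());

        let seek_table = ZstdSeekTable::from_footer(&entries, 2, true).unwrap();
        assert_eq!(seek_table.find_frame_index(150), Some(1));
        assert_eq!(seek_table.find_frame_index(300), None);
    }

    #[test]
    fn test_copy_to_volatile_slice_truncates_to_destination() {
        let src = [1u8, 2, 3, 4, 5];
        let mut dst_buf = [0u8; 3];
        // Only the first three bytes fit into the destination slice.
        let copied = copy_to_volatile_slice(&src, VolatileSlice::new(&mut dst_buf)).unwrap();
        assert_eq!(copied, 3);
        assert_eq!(dst_buf, [1u8, 2, 3]);
    }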
561}