1#![deny(missing_docs)]
6
7use std::fs::read_to_string;
8use std::num::ParseIntError;
9use std::path::Path;
10use std::str::FromStr;
11use std::thread::sleep;
12use std::time::Duration;
13
14use anyhow::anyhow;
15use anyhow::bail;
16use anyhow::Context;
17use anyhow::Result;
18use base::linux::getpid;
19use base::linux::kill;
20use base::linux::Signal;
21use base::Pid;
22
23pub struct ProcessesGuard {
37 pids: Vec<Pid>,
38}
39
40pub fn freeze_child_processes(monitor_pid: Pid) -> Result<ProcessesGuard> {
46 let mut guard = ProcessesGuard {
47 pids: load_descendants(getpid(), monitor_pid)?,
48 };
49
50 for _ in 0..3 {
51 guard.stop_the_world().context("stop the world")?;
52 let pids_after = load_descendants(getpid(), monitor_pid)?;
53 if pids_after == guard.pids {
54 return Ok(guard);
55 }
56 guard.pids = pids_after;
57 }
58
59 bail!("new processes forked while freezing");
60}
61
62impl ProcessesGuard {
63 fn stop_the_world(&self) -> Result<()> {
65 for pid in &self.pids {
66 unsafe { kill(*pid, Signal::Stop as i32) }.context("failed to stop process")?;
69 }
70 for pid in &self.pids {
71 wait_process_stopped(*pid).context("wait process stopped")?;
72 }
73 Ok(())
74 }
75
76 fn continue_the_world(&self) {
78 for pid in &self.pids {
79 let _ = unsafe { kill(*pid, Signal::Continue as i32) };
84 }
85 }
86}
87
88impl Drop for ProcessesGuard {
89 fn drop(&mut self) {
90 self.continue_the_world();
91 }
92}
93
94fn load_descendants(current_pid: Pid, monitor_pid: Pid) -> Result<Vec<Pid>> {
96 let children = read_to_string(format!("/proc/{current_pid}/task/{current_pid}/children"))
98 .context("read children")?;
99 let children = children.trim();
100 if children.is_empty() {
102 return Ok(Vec::new());
103 }
104 let pids: std::result::Result<Vec<i32>, ParseIntError> = children
105 .split(' ')
106 .map(i32::from_str)
107 .filter(|pid| match pid {
109 Ok(pid) => *pid != monitor_pid,
110 _ => true,
111 })
112 .collect();
113 let pids = pids.context("parse pids")?;
114 let mut result = Vec::new();
115 for pid in pids {
116 result.push(pid);
117 let pids = load_descendants(pid, monitor_pid)?;
118 result.extend(pids);
119 }
120 Ok(result)
121}
122
/// Extracts the one-character process state from the contents of a
/// `/proc/<pid>/stat` (or `/proc/<pid>/task/<tid>/stat`) file.
///
/// A stat line starts with `pid (comm) state ...`. `comm` is set by the
/// process itself (e.g. via `prctl(PR_SET_NAME)`) and may contain spaces and
/// `)`. The state is therefore taken from after the *last* `)` on the line,
/// skipping any spaces. Using the first `)` would let a process named e.g.
/// `x) T` masquerade as stopped.
///
/// Returns `None` if the text has no `)` or nothing but spaces follows it.
fn parse_process_state(text: &str) -> Option<char> {
    // Every field after comm is numeric or a single state letter, so the last
    // `)` always closes comm.
    let (_, after_comm) = text.rsplit_once(')')?;
    after_comm.trim_start_matches(' ').chars().next()
}
150
151fn wait_for_task_stopped(task_path: &Path) -> Result<()> {
152 for _ in 0..10 {
153 let stat = read_to_string(task_path.join("stat")).context("read process status")?;
154 if let Some(state) = parse_process_state(&stat) {
155 if state == 'T' {
156 return Ok(());
157 }
158 }
159 sleep(Duration::from_millis(50));
160 }
161 Err(anyhow!("time out"))
162}
163
164fn wait_process_stopped(pid: Pid) -> Result<()> {
165 let all_tasks = std::fs::read_dir(format!("/proc/{pid}/task")).context("read tasks")?;
166 for task in all_tasks {
167 wait_for_task_stopped(&task.context("read task entry")?.path()).context("wait for task")?;
168 }
169 Ok(())
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_process_state_tests() {
        assert_eq!(parse_process_state("1234 (crosvm) T 0 0 0"), Some('T'));
        assert_eq!(parse_process_state("1234 (crosvm) R 0 0 0"), Some('R'));
        // Several spaces between the command name and the state.
        assert_eq!(parse_process_state("1234 (crosvm)   T 0 0 0"), Some('T'));
        // No space at all after the command name.
        assert_eq!(parse_process_state("1234 (crosvm)T 0 0 0"), Some('T'));
        // Spaces inside the command name.
        assert_eq!(
            parse_process_state("1234 (crosvm --test) T 0 0 0"),
            Some('T')
        );
        // Nothing after the command name.
        assert!(parse_process_state("1234 (crosvm)").is_none());
        // Not a stat line at all.
        assert!(parse_process_state("").is_none());
    }
}