base/sys/linux/
netlink.rs

1// Copyright 2021 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::alloc::Layout;
6use std::mem::MaybeUninit;
7use std::os::unix::io::AsRawFd;
8use std::str;
9
10use libc::EINVAL;
11use log::error;
12use zerocopy::FromBytes;
13use zerocopy::Immutable;
14use zerocopy::IntoBytes;
15use zerocopy::KnownLayout;
16
17use super::errno_result;
18use super::getpid;
19use super::Error;
20use super::RawDescriptor;
21use super::Result;
22use crate::alloc::LayoutAllocation;
23use crate::descriptor::AsRawDescriptor;
24use crate::descriptor::FromRawDescriptor;
25use crate::descriptor::SafeDescriptor;
26
/// Debug-print helper for the netlink parsing code.
///
/// Every invocation expands to an empty block, so the arguments are accepted as raw tokens and
/// never evaluated or type-checked.
macro_rules! debug_pr {
    // Debug output is suppressed by default. To enable it, replace the rule below with:
    // ($($args:tt)+) => (println!($($args)+))
    ($($args:tt)+) => {};
}
32
33const NLMSGHDR_SIZE: usize = std::mem::size_of::<NlMsgHdr>();
34const GENL_HDRLEN: usize = std::mem::size_of::<GenlMsgHdr>();
35const NLA_HDRLEN: usize = std::mem::size_of::<NlAttr>();
36const NLATTR_ALIGN_TO: usize = 4;
37
38#[repr(C)]
39#[derive(Copy, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)]
40struct NlMsgHdr {
41    pub nlmsg_len: u32,
42    pub nlmsg_type: u16,
43    pub nlmsg_flags: u16,
44    pub nlmsg_seq: u32,
45    pub nlmsg_pid: u32,
46}
47
48/// Netlink attribute struct, can be used by netlink consumer
49#[repr(C)]
50#[derive(Copy, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)]
51pub struct NlAttr {
52    pub len: u16,
53    pub _type: u16,
54}
55
56/// Generic netlink header struct, can be used by netlink consumer
57#[repr(C)]
58#[derive(Copy, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)]
59pub struct GenlMsgHdr {
60    pub cmd: u8,
61    pub version: u8,
62    pub reserved: u16,
63}
/// A single netlink message, including its header and data.
///
/// The scalar fields are copied out of the message's `nlmsghdr`; `data` borrows the bytes that
/// follow the header.
pub struct NetlinkMessage<'a> {
    /// `nlmsg_type` of the header.
    pub _type: u16,
    /// `nlmsg_flags` of the header.
    pub flags: u16,
    /// `nlmsg_seq` of the header.
    pub seq: u32,
    /// `nlmsg_pid` of the header.
    pub pid: u32,
    /// Message payload, header excluded.
    pub data: &'a [u8],
}
72
/// A netlink attribute decoded by `NetlinkGenericDataIter`.
pub struct NlAttrWithData<'a> {
    /// Declared attribute length (`nla_len`), header included.
    pub len: u16,
    /// Raw attribute type (`nla_type`).
    pub _type: u16,
    /// Attribute payload: the bytes after the header, up to `len`.
    pub data: &'a [u8],
}
78
79/// Iterator over `struct NlAttr` as received from a netlink socket.
80pub struct NetlinkGenericDataIter<'a> {
81    // `data` must be properly aligned for NlAttr.
82    data: &'a [u8],
83}
84
85impl<'a> Iterator for NetlinkGenericDataIter<'a> {
86    type Item = NlAttrWithData<'a>;
87
88    fn next(&mut self) -> Option<Self::Item> {
89        let (nl_hdr, _) = NlAttr::read_from_prefix(self.data).ok()?;
90        let nl_data_len = nl_hdr.len as usize;
91        let data = self.data.get(NLA_HDRLEN..nl_data_len)?;
92
93        // Get next NlAttr
94        let next_hdr = nl_data_len.next_multiple_of(NLATTR_ALIGN_TO);
95        self.data = self.data.get(next_hdr..).unwrap_or(&[]);
96
97        Some(NlAttrWithData {
98            _type: nl_hdr._type,
99            len: nl_hdr.len,
100            data,
101        })
102    }
103}
104
105/// Iterator over `struct nlmsghdr` as received from a netlink socket.
106pub struct NetlinkMessageIter<'a> {
107    // `data` must be properly aligned for nlmsghdr.
108    data: &'a [u8],
109}
110
111impl<'a> Iterator for NetlinkMessageIter<'a> {
112    type Item = NetlinkMessage<'a>;
113
114    fn next(&mut self) -> Option<Self::Item> {
115        let (hdr, _) = NlMsgHdr::read_from_prefix(self.data).ok()?;
116        let msg_len = hdr.nlmsg_len as usize;
117        let data = self.data.get(NLMSGHDR_SIZE..msg_len)?;
118
119        // NLMSG_NEXT
120        let next_hdr = msg_len.next_multiple_of(std::mem::align_of::<NlMsgHdr>());
121        self.data = self.data.get(next_hdr..).unwrap_or(&[]);
122
123        Some(NetlinkMessage {
124            _type: hdr.nlmsg_type,
125            flags: hdr.nlmsg_flags,
126            seq: hdr.nlmsg_seq,
127            pid: hdr.nlmsg_pid,
128            data,
129        })
130    }
131}
132
133/// Safe wrapper for `NETLINK_GENERIC` netlink sockets.
134pub struct NetlinkGenericSocket {
135    sock: SafeDescriptor,
136}
137
138impl AsRawDescriptor for NetlinkGenericSocket {
139    fn as_raw_descriptor(&self) -> RawDescriptor {
140        self.sock.as_raw_descriptor()
141    }
142}
143
144impl NetlinkGenericSocket {
145    /// Create and bind a new `NETLINK_GENERIC` socket.
146    pub fn new(nl_groups: u32) -> Result<Self> {
147        // SAFETY:
148        // Safe because we check the return value and convert the raw fd into a SafeDescriptor.
149        let sock = unsafe {
150            let fd = libc::socket(
151                libc::AF_NETLINK,
152                libc::SOCK_RAW | libc::SOCK_CLOEXEC,
153                libc::NETLINK_GENERIC,
154            );
155            if fd < 0 {
156                return errno_result();
157            }
158
159            SafeDescriptor::from_raw_descriptor(fd)
160        };
161
162        // SAFETY:
163        // This MaybeUninit dance is needed because sockaddr_nl has a private padding field and
164        // doesn't implement Default. Safe because all 0s is valid data for sockaddr_nl.
165        let mut sa = unsafe { MaybeUninit::<libc::sockaddr_nl>::zeroed().assume_init() };
166        sa.nl_family = libc::AF_NETLINK as libc::sa_family_t;
167        sa.nl_groups = nl_groups;
168
169        // SAFETY:
170        // Safe because we pass a descriptor that we own and valid pointer/size for sockaddr.
171        unsafe {
172            let res = libc::bind(
173                sock.as_raw_fd(),
174                &sa as *const libc::sockaddr_nl as *const libc::sockaddr,
175                std::mem::size_of_val(&sa) as libc::socklen_t,
176            );
177            if res < 0 {
178                return errno_result();
179            }
180        }
181
182        Ok(NetlinkGenericSocket { sock })
183    }
184
185    /// Receive messages from the netlink socket.
186    pub fn recv(&self) -> Result<NetlinkGenericRead> {
187        let buf_size = 8192; // TODO(dverkamp): make this configurable?
188
189        // Create a buffer with sufficient alignment for nlmsghdr.
190        let layout = Layout::from_size_align(buf_size, std::mem::align_of::<NlMsgHdr>())
191            .map_err(|_| Error::new(EINVAL))?;
192        let allocation = LayoutAllocation::uninitialized(layout);
193
194        // SAFETY:
195        // Safe because we pass a valid, owned socket fd and a valid pointer/size for the buffer.
196        let bytes_read = unsafe {
197            let res = libc::recv(self.sock.as_raw_fd(), allocation.as_ptr(), buf_size, 0);
198            if res < 0 {
199                return errno_result();
200            }
201            res as usize
202        };
203
204        Ok(NetlinkGenericRead {
205            allocation,
206            len: bytes_read,
207        })
208    }
209
210    pub fn family_name_query(&self, family_name: String) -> Result<NetlinkGenericRead> {
211        let buf_size = 1024;
212        debug_pr!(
213            "preparing query for family name {}, len {}",
214            family_name,
215            family_name.len()
216        );
217
218        // Create a buffer with sufficient alignment for nlmsghdr.
219        let layout = Layout::from_size_align(buf_size, std::mem::align_of::<NlMsgHdr>())
220            .map_err(|_| Error::new(EINVAL))
221            .unwrap();
222        let mut allocation = LayoutAllocation::zeroed(layout);
223
224        // SAFETY:
225        // Safe because the data in allocation was initialized up to `buf_size` and is
226        // sufficiently aligned.
227        let data = unsafe { allocation.as_mut_slice(buf_size) };
228
229        // Prepare the netlink message header
230        let (hdr, genl_hdr) = NlMsgHdr::mut_from_prefix(data).expect("failed to unwrap");
231        hdr.nlmsg_len = NLMSGHDR_SIZE as u32 + GENL_HDRLEN as u32;
232        hdr.nlmsg_len += NLA_HDRLEN as u32 + family_name.len() as u32 + 1;
233        hdr.nlmsg_flags = libc::NLM_F_REQUEST as u16;
234        hdr.nlmsg_type = libc::GENL_ID_CTRL as u16;
235        hdr.nlmsg_pid = getpid() as u32;
236
237        // Prepare generic netlink message header
238        let (genl_hdr, nlattr) =
239            GenlMsgHdr::mut_from_prefix(genl_hdr).expect("unable to get GenlMsgHdr from slice");
240        genl_hdr.cmd = libc::CTRL_CMD_GETFAMILY as u8;
241        genl_hdr.version = 0x1;
242
243        // Netlink attributes
244        let (nl_attr, payload) =
245            NlAttr::mut_from_prefix(nlattr).expect("unable to get NlAttr from slice");
246        nl_attr._type = libc::CTRL_ATTR_FAMILY_NAME as u16;
247        nl_attr.len = family_name.len() as u16 + 1 + NLA_HDRLEN as u16;
248
249        // Fill the message payload with the family name
250        payload[..family_name.len()].copy_from_slice(family_name.as_bytes());
251
252        let len = NLMSGHDR_SIZE + GENL_HDRLEN + NLA_HDRLEN + family_name.len() + 1;
253
254        // SAFETY:
255        // Safe because we pass a valid, owned socket fd and a valid pointer/size for the buffer.
256        unsafe {
257            let res = libc::send(self.sock.as_raw_fd(), allocation.as_ptr(), len, 0);
258            if res < 0 {
259                error!("failed to send get_family_cmd");
260                return errno_result();
261            }
262        };
263
264        // Return the answer
265        match self.recv() {
266            Ok(msg) => Ok(msg),
267            Err(e) => {
268                error!("recv get_family returned with error {}", e);
269                Err(e)
270            }
271        }
272    }
273}
274
275fn parse_ctrl_group_name_and_id(
276    nested_nl_attr_data: NetlinkGenericDataIter,
277    group_name: &str,
278) -> Option<u32> {
279    let mut mcast_group_id: Option<u32> = None;
280
281    for nested_nl_attr in nested_nl_attr_data {
282        debug_pr!(
283            "\t\tmcast_grp: nlattr type {}, len {}",
284            nested_nl_attr._type,
285            nested_nl_attr.len
286        );
287
288        if nested_nl_attr._type == libc::CTRL_ATTR_MCAST_GRP_ID as u16 {
289            mcast_group_id = Some(u32::from_ne_bytes(nested_nl_attr.data.try_into().unwrap()));
290            debug_pr!("\t\t mcast group_id {}", mcast_group_id?);
291        }
292
293        if nested_nl_attr._type == libc::CTRL_ATTR_MCAST_GRP_NAME as u16 {
294            debug_pr!(
295                "\t\t mcast group name {}",
296                strip_padding(&nested_nl_attr.data)
297            );
298
299            // If the group name match and the group_id was set in previous iteration, return,
300            // valid for group_name, group_id
301            if group_name.eq(strip_padding(nested_nl_attr.data)) && mcast_group_id.is_some() {
302                debug_pr!(
303                    "\t\t Got what we were looking for group_id = {} for {}",
304                    mcast_group_id?,
305                    group_name
306                );
307
308                return mcast_group_id;
309            }
310        }
311    }
312
313    None
314}
315
316/// Parse CTRL_ATTR_MCAST_GROUPS data in order to get multicast group id
317///
318/// On success, returns group_id for a given `group_name`
319///
320/// # Arguments
321///
322/// * `nl_attr_area`
323///
324///   Nested attributes area (CTRL_ATTR_MCAST_GROUPS data), where nl_attr's corresponding to
325///   specific groups are embed
326///
327/// * `group_name`
328///
329///     String with group_name for which we are looking group_id
330///
331/// the CTRL_ATTR_MCAST_GROUPS data has nested attributes. Each of nested attribute is per
332/// multicast group attributes, which have another nested attributes: CTRL_ATTR_MCAST_GRP_NAME and
333/// CTRL_ATTR_MCAST_GRP_ID. Need to parse all of them to get mcast group id for a given group_name..
334///
335/// Illustrated layout:
336/// CTRL_ATTR_MCAST_GROUPS:
337///   GR1 (nl_attr._type = 1):
338///       CTRL_ATTR_MCAST_GRP_ID,
339///       CTRL_ATTR_MCAST_GRP_NAME,
340///   GR2 (nl_attr._type = 2):
341///       CTRL_ATTR_MCAST_GRP_ID,
342///       CTRL_ATTR_MCAST_GRP_NAME,
343///   ..
344///
345/// Unfortunately kernel implementation uses `nla_nest_start_noflag` for that
346/// purpose, which means that it never marked their nest attributes with NLA_F_NESTED flag.
347/// Therefore all this nesting stages need to be deduced based on specific nl_attr type.
348fn parse_ctrl_mcast_group_id(
349    nl_attr_area: NetlinkGenericDataIter,
350    group_name: &str,
351) -> Option<u32> {
352    // There may be multiple nested multicast groups, go through all of them.
353    // Each of nested group, has other nested nlattr:
354    //  CTRL_ATTR_MCAST_GRP_ID
355    //  CTRL_ATTR_MCAST_GRP_NAME
356    //
357    //  which are further proceed by parse_ctrl_group_name_and_id
358    for nested_gr_nl_attr in nl_attr_area {
359        debug_pr!(
360            "\tmcast_groups: nlattr type(gr_nr) {}, len {}",
361            nested_gr_nl_attr._type,
362            nested_gr_nl_attr.len
363        );
364
365        let netlink_nested_attr = NetlinkGenericDataIter {
366            data: nested_gr_nl_attr.data,
367        };
368
369        if let Some(mcast_group_id) = parse_ctrl_group_name_and_id(netlink_nested_attr, group_name)
370        {
371            return Some(mcast_group_id);
372        }
373    }
374
375    None
376}
377
// Returns the text in `b` before the first '\0' byte as a &str, discarding that byte and
// everything after it. This is similar to `CStr::from_bytes_until_nul` followed by `to_str`.
// Panics if `b` contains no '\0' byte, or if the bytes before it are not valid UTF-8.
fn strip_padding(b: &[u8]) -> &str {
    match b.iter().position(|&byte| byte == 0) {
        Some(nul_at) => str::from_utf8(&b[..nul_at]).unwrap(),
        None => panic!("`b` doesn't contain any nul bytes"),
    }
}
389
390pub struct NetlinkGenericRead {
391    allocation: LayoutAllocation,
392    len: usize,
393}
394
395impl NetlinkGenericRead {
396    pub fn iter(&self) -> NetlinkMessageIter {
397        // SAFETY:
398        // Safe because the data in allocation was initialized up to `self.len` by `recv()` and is
399        // sufficiently aligned.
400        let data = unsafe { &self.allocation.as_slice(self.len) };
401        NetlinkMessageIter { data }
402    }
403
404    /// Parse NetlinkGeneric response in order to get multicast group id
405    ///
406    /// On success, returns group_id for a given `group_name`
407    ///
408    /// # Arguments
409    ///
410    /// * `group_name` - String with group_name for which we are looking group_id
411    ///
412    /// Response from family_name_query (CTRL_CMD_GETFAMILY) is a netlink message with multiple
413    /// attributes encapsulated (some of them are nested). An example response layout is
414    /// illustrated below:
415    ///
416    ///  {
417    ///    CTRL_ATTR_FAMILY_NAME
418    ///    CTRL_ATTR_FAMILY_ID
419    ///    CTRL_ATTR_VERSION
420    ///    ...
421    ///    CTRL_ATTR_MCAST_GROUPS {
422    ///      GR1 (nl_attr._type = 1) {
423    ///          CTRL_ATTR_MCAST_GRP_ID    *we need parse this attr to obtain group id used for
424    ///                                     the group mask
425    ///          CTRL_ATTR_MCAST_GRP_NAME  *group_name that we need to match with
426    ///      }
427    ///      GR2 (nl_attr._type = 2) {
428    ///          CTRL_ATTR_MCAST_GRP_ID
429    ///          CTRL_ATTR_MCAST_GRP_NAME
430    ///      }
431    ///      ...
432    ///     }
433    ///   }
434    pub fn get_multicast_group_id(&self, group_name: String) -> Option<u32> {
435        for netlink_msg in self.iter() {
436            debug_pr!(
437                "received type: {}, flags {}, pid {}, data {:?}",
438                netlink_msg._type,
439                netlink_msg.flags,
440                netlink_msg.pid,
441                netlink_msg.data
442            );
443
444            if netlink_msg._type != libc::GENL_ID_CTRL as u16 {
445                error!("Received not a generic netlink controller msg");
446                return None;
447            }
448
449            let netlink_data = NetlinkGenericDataIter {
450                data: &netlink_msg.data[GENL_HDRLEN..],
451            };
452            for nl_attr in netlink_data {
453                debug_pr!("nl_attr type {}, len {}", nl_attr._type, nl_attr.len);
454
455                if nl_attr._type == libc::CTRL_ATTR_MCAST_GROUPS as u16 {
456                    let netlink_nested_attr = NetlinkGenericDataIter { data: nl_attr.data };
457
458                    if let Some(mcast_group_id) =
459                        parse_ctrl_mcast_group_id(netlink_nested_attr, &group_name)
460                    {
461                        return Some(mcast_group_id);
462                    }
463                }
464            }
465        }
466        None
467    }
468}