argh_helpers/
lib.rs

1// Copyright 2022 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::fmt::Write;
6
7use quote::quote;
8
9/// A helper derive proc macro to flatten multiple subcommand enums into one
10/// Note that it is unable to check for duplicate commands and they will be
11/// tried in order of declaration
12#[proc_macro_derive(FlattenSubcommand)]
13pub fn flatten_subcommand(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
14    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
15    let de = match ast.data {
16        syn::Data::Enum(v) => v,
17        _ => unreachable!(),
18    };
19    let name = &ast.ident;
20
21    // An enum variant like `<name>(<ty>)`
22    struct SubCommandVariant<'a> {
23        name: &'a syn::Ident,
24        ty: &'a syn::Type,
25    }
26
27    let variants: Vec<SubCommandVariant<'_>> = de
28        .variants
29        .iter()
30        .map(|variant| {
31            let name = &variant.ident;
32            let ty = match &variant.fields {
33                syn::Fields::Unnamed(field) => {
34                    if field.unnamed.len() != 1 {
35                        unreachable!()
36                    }
37
38                    &field.unnamed.first().unwrap().ty
39                }
40                _ => unreachable!(),
41            };
42            SubCommandVariant { name, ty }
43        })
44        .collect();
45
46    let variant_ty = variants.iter().map(|x| x.ty).collect::<Vec<_>>();
47    let variant_names = variants.iter().map(|x| x.name).collect::<Vec<_>>();
48
49    (quote! {
50        impl argh::FromArgs for #name {
51            fn from_args(command_name: &[&str], args: &[&str])
52                -> std::result::Result<Self, argh::EarlyExit>
53            {
54                let subcommand_name = if let Some(subcommand_name) = command_name.last() {
55                    *subcommand_name
56                } else {
57                    return Err(argh::EarlyExit::from("no subcommand name".to_owned()));
58                };
59
60                #(
61                    if <#variant_ty as argh::SubCommands>::COMMANDS
62                    .iter()
63                    .find(|ci| ci.name.eq(subcommand_name))
64                    .is_some()
65                    {
66                        return <#variant_ty as argh::FromArgs>::from_args(command_name, args)
67                            .map(|v| Self::#variant_names(v));
68                    }
69                )*
70
71                Err(argh::EarlyExit::from("no subcommand matched".to_owned()))
72            }
73
74            fn redact_arg_values(command_name: &[&str], args: &[&str]) -> std::result::Result<Vec<String>, argh::EarlyExit> {
75                let subcommand_name = if let Some(subcommand_name) = command_name.last() {
76                    *subcommand_name
77                } else {
78                    return Err(argh::EarlyExit::from("no subcommand name".to_owned()));
79                };
80
81                #(
82                    if <#variant_ty as argh::SubCommands>::COMMANDS
83                    .iter()
84                    .find(|ci| ci.name.eq(subcommand_name))
85                    .is_some()
86                    {
87                        return <#variant_ty as argh::FromArgs>::redact_arg_values(
88                            command_name,
89                            args,
90                        );
91                    }
92
93                )*
94
95                Err(argh::EarlyExit::from("no subcommand matched".to_owned()))
96            }
97        }
98
99        impl argh::SubCommands for #name {
100            const COMMANDS: &'static [&'static argh::CommandInfo] = {
101                const TOTAL_LEN: usize = #(<#variant_ty as argh::SubCommands>::COMMANDS.len())+*;
102                const COMMANDS: [&'static argh::CommandInfo; TOTAL_LEN] = {
103                    let slices = &[#(<#variant_ty as argh::SubCommands>::COMMANDS,)*];
104                    // Its not possible for slices[0][0] to be invalid
105                    let mut output = [slices[0][0]; TOTAL_LEN];
106
107                    let mut output_index = 0;
108                    let mut which_slice = 0;
109                    while which_slice < slices.len() {
110                        let slice = &slices[which_slice];
111                        let mut index_in_slice = 0;
112                        while index_in_slice < slice.len() {
113                            output[output_index] = slice[index_in_slice];
114                            output_index += 1;
115                            index_in_slice += 1;
116                        }
117                        which_slice += 1;
118                    }
119                    output
120                };
121                &COMMANDS
122            };
123        }
124    })
125    .into()
126}
127
128/// A helper proc macro to pad strings so that argh would break them at intended points
129#[proc_macro_attribute]
130pub fn pad_description_for_argh(
131    _attr: proc_macro::TokenStream,
132    item: proc_macro::TokenStream,
133) -> proc_macro::TokenStream {
134    let mut item = syn::parse_macro_input!(item as syn::Item);
135    if let syn::Item::Struct(s) = &mut item {
136        if let syn::Fields::Named(fields) = &mut s.fields {
137            for f in fields.named.iter_mut() {
138                for a in f.attrs.iter_mut() {
139                    if a.path()
140                        .get_ident()
141                        .map(|i| i.to_string())
142                        .unwrap_or_default()
143                        == *"doc"
144                    {
145                        if let syn::Meta::NameValue(syn::MetaNameValue {
146                            value:
147                                syn::Expr::Lit(syn::ExprLit {
148                                    lit: syn::Lit::Str(s),
149                                    ..
150                                }),
151                            ..
152                        }) = &a.meta
153                        {
154                            let doc = s.value().lines().fold(String::new(), |mut output, s| {
155                                let _ = write!(output, "{s: <61}");
156                                output
157                            });
158                            *a = syn::parse_quote! { #[doc= #doc] };
159                        }
160                    }
161                }
162            }
163        } else {
164            unreachable!()
165        }
166    } else {
167        unreachable!()
168    }
169    quote! {
170        #item
171    }
172    .into()
173}