Skip to main content

dfir_lang/
process_singletons.rs

1//! Utility methods for processing singleton references: `#my_var`.
2
3use itertools::Itertools;
4use proc_macro2::{Group, Ident, TokenStream, TokenTree};
5use syn::punctuated::Punctuated;
6use syn::{Expr, Token};
7
8use crate::parse::parse_terminated;
9
10/// Finds all the singleton references `#my_var` and appends them to `found_idents`. Returns the
11/// `TokenStream` but with the hashes removed from the varnames.
12///
13/// The returned tokens are used for "preflight" parsing, to check that the rest of the syntax is
14/// OK. However the returned tokens are not used in the codegen as we need to use [`postprocess_singletons`]
15/// later to substitute-in the context referencing code for each singleton
16pub fn preprocess_singletons(tokens: TokenStream, found_idents: &mut Vec<Ident>) -> TokenStream {
17    process_singletons(tokens, &mut |singleton_ident| {
18        found_idents.push(singleton_ident.clone());
19        TokenTree::Ident(singleton_ident)
20    })
21}
22
23/// Replaces singleton references `#my_var` with the code needed to actually get the value inside.
24///
25/// * `tokens` - The tokens to update singleton references within.
26/// * `resolved_exprs` - Token streams that correspond 1:1 and in the same
27///   order as the singleton references within `tokens` (found in-order via [`preprocess_singletons`]).
28///
29/// Generates `(*expr)` — an immutable place expression that prevents consumer mutation.
30pub fn postprocess_singletons(
31    tokens: TokenStream,
32    resolved_exprs: impl IntoIterator<Item = TokenStream>,
33) -> Punctuated<Expr, Token![,]> {
34    let mut resolved_exprs_iter = resolved_exprs.into_iter();
35    let processed = process_singletons(tokens, &mut |singleton_ident| {
36        let span = singleton_ident.span();
37        let expr_tokens = resolved_exprs_iter.next().unwrap();
38        // Emit `(*expr)` so consumers get an immutable place expression.
39        let deref_tokens: TokenStream = std::iter::once(TokenTree::Punct(proc_macro2::Punct::new(
40            '*',
41            proc_macro2::Spacing::Alone,
42        )))
43        .chain(expr_tokens)
44        .collect();
45        let mut group = Group::new(proc_macro2::Delimiter::Parenthesis, deref_tokens);
46        group.set_span(span);
47        TokenTree::Group(group)
48    });
49    parse_terminated(processed).unwrap()
50}
51
52/// Traverse the token stream, applying the `map_singleton_fn` whenever a singleton is found,
53/// returning the transformed token stream.
54fn process_singletons(
55    tokens: TokenStream,
56    map_singleton_fn: &mut impl FnMut(Ident) -> TokenTree,
57) -> TokenStream {
58    tokens
59        .into_iter()
60        .peekable()
61        .batching(|iter| {
62            let out = match iter.next()? {
63                TokenTree::Group(group) => {
64                    let mut new_group = Group::new(
65                        group.delimiter(),
66                        process_singletons(group.stream(), map_singleton_fn),
67                    );
68                    new_group.set_span(group.span());
69                    TokenTree::Group(new_group)
70                }
71                TokenTree::Ident(ident) => TokenTree::Ident(ident),
72                TokenTree::Punct(punct) => {
73                    if '#' == punct.as_char() && matches!(iter.peek(), Some(TokenTree::Ident(_))) {
74                        // Found a singleton.
75                        let Some(TokenTree::Ident(mut singleton_ident)) = iter.next() else {
76                            unreachable!()
77                        };
78                        {
79                            // Include the `#` in the span.
80                            let span = singleton_ident
81                                .span()
82                                .join(punct.span())
83                                .unwrap_or(singleton_ident.span());
84                            singleton_ident.set_span(span.resolved_at(singleton_ident.span()));
85                        }
86                        (map_singleton_fn)(singleton_ident)
87                    } else {
88                        TokenTree::Punct(punct)
89                    }
90                }
91                TokenTree::Literal(lit) => TokenTree::Literal(lit),
92            };
93            Some(out)
94        })
95        .collect()
96}