Skip to main content

dfir_lang/graph/
meta_graph.rs

1#![warn(missing_docs)]
2
3extern crate proc_macro;
4
5use std::collections::{BTreeMap, BTreeSet};
6use std::fmt::Debug;
7use std::iter::FusedIterator;
8
9use itertools::Itertools;
10use proc_macro2::{Ident, Literal, Span, TokenStream};
11use quote::{ToTokens, format_ident, quote, quote_spanned};
12use serde::{Deserialize, Serialize};
13use slotmap::{Key, SecondaryMap, SlotMap, SparseSecondaryMap};
14use syn::spanned::Spanned;
15
16use super::graph_write::{Dot, GraphWrite, Mermaid};
17use super::ops::{
18    DelayType, OPERATORS, OperatorWriteOutput, WriteContextArgs, find_op_op_constraints,
19    null_write_iterator_fn,
20};
21use super::{
22    CONTEXT, Color, DiMulGraph, GRAPH, GraphEdgeId, GraphLoopId, GraphNode, GraphNodeId,
23    GraphSubgraphId, HANDOFF_NODE_STR, HandoffKind, MODULE_BOUNDARY_NODE_STR, OperatorInstance,
24    PortIndexValue, SINGLETON_SLOT_NODE_STR, Varname, change_spans, get_operator_generics,
25};
26use crate::diagnostic::{Diagnostic, Diagnostics, Level};
27use crate::pretty_span::{PrettyRowCol, PrettySpan};
28use crate::process_singletons;
29
30/// An abstract "meta graph" representation of a DFIR graph.
31///
32/// Can be with or without subgraph partitioning, stratification, and handoff insertion. This is
33/// the meta graph used for generating Rust source code in macros from DFIR sytnax.
34///
35/// This struct has a lot of methods for manipulating the graph, vaguely grouped together in
36/// separate `impl` blocks. You might notice a few particularly specific arbitray-seeming methods
37/// in here--those are just what was needed for the compilation algorithms. If you need another
38/// method then add it.
39#[derive(Default, Debug, Serialize, Deserialize)]
40pub struct DfirGraph {
41    /// Each node type (operator or handoff).
42    nodes: SlotMap<GraphNodeId, GraphNode>,
43
44    /// Instance data corresponding to each operator node.
45    /// This field will be empty after deserialization.
46    #[serde(skip)]
47    operator_instances: SecondaryMap<GraphNodeId, OperatorInstance>,
48    /// Debugging/tracing tag for each operator node.
49    operator_tag: SecondaryMap<GraphNodeId, String>,
50    /// Graph data structure (two-way adjacency list).
51    graph: DiMulGraph<GraphNodeId, GraphEdgeId>,
52    /// Input and output port for each edge.
53    ports: SecondaryMap<GraphEdgeId, (PortIndexValue, PortIndexValue)>,
54
55    /// Which loop a node belongs to (or none for top-level).
56    node_loops: SecondaryMap<GraphNodeId, GraphLoopId>,
57    /// Which nodes belong to each loop.
58    loop_nodes: SlotMap<GraphLoopId, Vec<GraphNodeId>>,
59    /// For the loop, what is its parent (`None` for top-level).
60    loop_parent: SparseSecondaryMap<GraphLoopId, GraphLoopId>,
61    /// What loops are at the root.
62    root_loops: Vec<GraphLoopId>,
63    /// For the loop, what are its child loops.
64    loop_children: SecondaryMap<GraphLoopId, Vec<GraphLoopId>>,
65
66    /// Which subgraph each node belongs to.
67    node_subgraph: SecondaryMap<GraphNodeId, GraphSubgraphId>,
68
69    /// Which nodes belong to each subgraph.
70    subgraph_nodes: SlotMap<GraphSubgraphId, Vec<GraphNodeId>>,
71
72    /// Resolved singletons varnames references, per node.
73    node_singleton_references: SparseSecondaryMap<GraphNodeId, Vec<Option<GraphNodeId>>>,
74    /// What variable name each graph node belongs to (if any). For debugging (graph writing) purposes only.
75    node_varnames: SparseSecondaryMap<GraphNodeId, Varname>,
76
77    /// Delay type for handoff nodes that represent tick-boundary back-edges.
78    /// Set by `order_subgraphs` for `defer_tick` / `defer_tick_lazy`, either on handoff nodes
79    /// it injects or on existing handoff nodes that it marks as tick-boundary back-edges.
80    handoff_delay_type: SparseSecondaryMap<GraphNodeId, DelayType>,
81}
82
83/// Basic methods.
84impl DfirGraph {
85    /// Create a new empty graph.
86    pub fn new() -> Self {
87        Default::default()
88    }
89}
90
91/// Node methods.
92impl DfirGraph {
93    /// Get a node with its operator instance (if applicable).
94    pub fn node(&self, node_id: GraphNodeId) -> &GraphNode {
95        self.nodes.get(node_id).expect("Node not found.")
96    }
97
98    /// Get the `OperatorInstance` for a given node. Node must be an operator and have an
99    /// `OperatorInstance` present, otherwise will return `None`.
100    ///
101    /// Note that no operator instances will be persent after deserialization.
102    pub fn node_op_inst(&self, node_id: GraphNodeId) -> Option<&OperatorInstance> {
103        self.operator_instances.get(node_id)
104    }
105
106    /// Get the debug variable name attached to a graph node.
107    pub fn node_varname(&self, node_id: GraphNodeId) -> Option<&Varname> {
108        self.node_varnames.get(node_id)
109    }
110
111    /// Get subgraph for node.
112    pub fn node_subgraph(&self, node_id: GraphNodeId) -> Option<GraphSubgraphId> {
113        self.node_subgraph.get(node_id).copied()
114    }
115
116    /// Degree into a node, i.e. the number of predecessors.
117    pub fn node_degree_in(&self, node_id: GraphNodeId) -> usize {
118        self.graph.degree_in(node_id)
119    }
120
121    /// Degree out of a node, i.e. the number of successors.
122    pub fn node_degree_out(&self, node_id: GraphNodeId) -> usize {
123        self.graph.degree_out(node_id)
124    }
125
126    /// Successors, iterator of `(GraphEdgeId, GraphNodeId)` of outgoing edges.
127    pub fn node_successors(
128        &self,
129        src: GraphNodeId,
130    ) -> impl '_
131    + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
132    + ExactSizeIterator
133    + FusedIterator
134    + Clone
135    + Debug {
136        self.graph.successors(src)
137    }
138
139    /// Predecessors, iterator of `(GraphEdgeId, GraphNodeId)` of incoming edges.
140    pub fn node_predecessors(
141        &self,
142        dst: GraphNodeId,
143    ) -> impl '_
144    + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
145    + ExactSizeIterator
146    + FusedIterator
147    + Clone
148    + Debug {
149        self.graph.predecessors(dst)
150    }
151
152    /// Successor edges, iterator of `GraphEdgeId` of outgoing edges.
153    pub fn node_successor_edges(
154        &self,
155        src: GraphNodeId,
156    ) -> impl '_
157    + DoubleEndedIterator<Item = GraphEdgeId>
158    + ExactSizeIterator
159    + FusedIterator
160    + Clone
161    + Debug {
162        self.graph.successor_edges(src)
163    }
164
165    /// Predecessor edges, iterator of `GraphEdgeId` of incoming edges.
166    pub fn node_predecessor_edges(
167        &self,
168        dst: GraphNodeId,
169    ) -> impl '_
170    + DoubleEndedIterator<Item = GraphEdgeId>
171    + ExactSizeIterator
172    + FusedIterator
173    + Clone
174    + Debug {
175        self.graph.predecessor_edges(dst)
176    }
177
178    /// Successor nodes, iterator of `GraphNodeId`.
179    pub fn node_successor_nodes(
180        &self,
181        src: GraphNodeId,
182    ) -> impl '_
183    + DoubleEndedIterator<Item = GraphNodeId>
184    + ExactSizeIterator
185    + FusedIterator
186    + Clone
187    + Debug {
188        self.graph.successor_vertices(src)
189    }
190
191    /// Predecessor nodes, iterator of `GraphNodeId`.
192    pub fn node_predecessor_nodes(
193        &self,
194        dst: GraphNodeId,
195    ) -> impl '_
196    + DoubleEndedIterator<Item = GraphNodeId>
197    + ExactSizeIterator
198    + FusedIterator
199    + Clone
200    + Debug {
201        self.graph.predecessor_vertices(dst)
202    }
203
204    /// Iterator of node IDs `GraphNodeId`.
205    pub fn node_ids(&self) -> slotmap::basic::Keys<'_, GraphNodeId, GraphNode> {
206        self.nodes.keys()
207    }
208
209    /// Iterator over `(GraphNodeId, &Node)` pairs.
210    pub fn nodes(&self) -> slotmap::basic::Iter<'_, GraphNodeId, GraphNode> {
211        self.nodes.iter()
212    }
213
214    /// Insert a node, assigning the given varname.
215    pub fn insert_node(
216        &mut self,
217        node: GraphNode,
218        varname_opt: Option<Ident>,
219        loop_opt: Option<GraphLoopId>,
220    ) -> GraphNodeId {
221        let node_id = self.nodes.insert(node);
222        if let Some(varname) = varname_opt {
223            self.node_varnames.insert(node_id, Varname(varname));
224        }
225        if let Some(loop_id) = loop_opt {
226            self.node_loops.insert(node_id, loop_id);
227            self.loop_nodes[loop_id].push(node_id);
228        }
229        node_id
230    }
231
232    /// Insert an operator instance for the given node. Panics if already set.
233    pub fn insert_node_op_inst(&mut self, node_id: GraphNodeId, op_inst: OperatorInstance) {
234        assert!(matches!(
235            self.nodes.get(node_id),
236            Some(GraphNode::Operator(_))
237        ));
238        let old_inst = self.operator_instances.insert(node_id, op_inst);
239        assert!(old_inst.is_none());
240    }
241
242    /// Assign all operator instances if not set. Write diagnostic messages/errors into `diagnostics`.
243    pub fn insert_node_op_insts_all(&mut self, diagnostics: &mut Diagnostics) {
244        // Handle all nodes in two phases, since the helper methods take total ownership of `&self`.
245        // Possible to do in one phase, but would require accessing fields directly for partial mutable ownership.
246
247        // Collect operator instances, then assign.
248        let mut op_insts = Vec::new();
249        // Collect nodes that should be lowered to handoffs (the `handoff()`/`singleton()` pseudo-operators).
250        let mut handoff_nodes: Vec<(GraphNodeId, HandoffKind, Span)> = Vec::new();
251
252        for (node_id, node) in self.nodes() {
253            let GraphNode::Operator(operator) = node else {
254                continue;
255            };
256            if self.node_op_inst(node_id).is_some() {
257                continue;
258            };
259
260            // Recognize `handoff()`/`singleton()` pseudo-operators and lower to GraphNode::Handoff.
261            let handoff_kind = match &*operator.name_string() {
262                "handoff" => Some(HandoffKind::Vec),
263                "singleton" => Some(HandoffKind::Option),
264                _ => None,
265            };
266            if let Some(kind) = handoff_kind {
267                if !operator.args.is_empty() {
268                    diagnostics.push(Diagnostic::spanned(
269                        operator.path.span(),
270                        Level::Error,
271                        format!("`{}` takes no arguments.", operator.name_string()),
272                    ));
273                }
274                if operator.type_arguments().is_some() {
275                    diagnostics.push(Diagnostic::spanned(
276                        operator.path.span(),
277                        Level::Error,
278                        format!("`{}` takes no generic arguments.", operator.name_string()),
279                    ));
280                }
281                handoff_nodes.push((node_id, kind, operator.path.span()));
282                continue;
283            }
284
285            // Op constraints.
286            let Some(op_constraints) = find_op_op_constraints(operator) else {
287                diagnostics.push(Diagnostic::spanned(
288                    operator.path.span(),
289                    Level::Error,
290                    format!("Unknown operator `{}`", operator.name_string()),
291                ));
292                continue;
293            };
294
295            // Input and output ports.
296            let (input_ports, output_ports) = {
297                let mut input_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
298                    .node_predecessors(node_id)
299                    .map(|(edge_id, pred_id)| (self.edge_ports(edge_id).1, pred_id))
300                    .collect();
301                // Ensure sorted by port index.
302                input_edges.sort();
303                let input_ports: Vec<PortIndexValue> = input_edges
304                    .into_iter()
305                    .map(|(port, _pred)| port)
306                    .cloned()
307                    .collect();
308
309                // Collect output arguments (successors).
310                let mut output_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
311                    .node_successors(node_id)
312                    .map(|(edge_id, succ)| (self.edge_ports(edge_id).0, succ))
313                    .collect();
314                // Ensure sorted by port index.
315                output_edges.sort();
316                let output_ports: Vec<PortIndexValue> = output_edges
317                    .into_iter()
318                    .map(|(port, _succ)| port)
319                    .cloned()
320                    .collect();
321
322                (input_ports, output_ports)
323            };
324
325            // Generic arguments.
326            let generics = get_operator_generics(diagnostics, operator);
327            // Generic argument errors.
328            {
329                // Span of `generic_args` (if it exists), otherwise span of the operator name.
330                let generics_span = generics
331                    .generic_args
332                    .as_ref()
333                    .map(Spanned::span)
334                    .unwrap_or_else(|| operator.path.span());
335
336                if !op_constraints
337                    .persistence_args
338                    .contains(&generics.persistence_args.len())
339                {
340                    diagnostics.push(Diagnostic::spanned(
341                        generics.persistence_args_span().unwrap_or(generics_span),
342                        Level::Error,
343                        format!(
344                            "`{}` should have {} persistence lifetime arguments, actually has {}.",
345                            op_constraints.name,
346                            op_constraints.persistence_args.human_string(),
347                            generics.persistence_args.len()
348                        ),
349                    ));
350                }
351                if !op_constraints.type_args.contains(&generics.type_args.len()) {
352                    diagnostics.push(Diagnostic::spanned(
353                        generics.type_args_span().unwrap_or(generics_span),
354                        Level::Error,
355                        format!(
356                            "`{}` should have {} generic type arguments, actually has {}.",
357                            op_constraints.name,
358                            op_constraints.type_args.human_string(),
359                            generics.type_args.len()
360                        ),
361                    ));
362                }
363            }
364
365            op_insts.push((
366                node_id,
367                OperatorInstance {
368                    op_constraints,
369                    input_ports,
370                    output_ports,
371                    singletons_referenced: operator.singletons_referenced.clone(),
372                    generics,
373                    arguments_pre: operator.args.clone(),
374                    arguments_raw: operator.args_raw.clone(),
375                },
376            ));
377        }
378
379        for (node_id, op_inst) in op_insts {
380            self.insert_node_op_inst(node_id, op_inst);
381        }
382
383        // Replace pseudo-operator nodes with GraphNode::Handoff.
384        for (node_id, kind, span) in handoff_nodes {
385            self.nodes[node_id] = GraphNode::Handoff {
386                kind,
387                src_span: span,
388                dst_span: span,
389            };
390        }
391    }
392
393    /// Inserts a node between two existing nodes connected by the given `edge_id`.
394    ///
395    /// `edge`: (src, dst, dst_idx)
396    ///
397    /// Before: A (src) ------------> B (dst)
398    /// After:  A (src) -> X (new) -> B (dst)
399    ///
400    /// Returns the ID of X & ID of edge OUT of X.
401    ///
402    /// Note that both the edges will be new and `edge_id` will be removed. Both new edges will
403    /// get the edge type of the original edge.
404    pub fn insert_intermediate_node(
405        &mut self,
406        edge_id: GraphEdgeId,
407        new_node: GraphNode,
408    ) -> (GraphNodeId, GraphEdgeId) {
409        let span = Some(new_node.span());
410
411        // Make corresponding operator instance (if `node` is an operator).
412        let op_inst_opt = 'oc: {
413            let GraphNode::Operator(operator) = &new_node else {
414                break 'oc None;
415            };
416            let Some(op_constraints) = find_op_op_constraints(operator) else {
417                break 'oc None;
418            };
419            let (input_port, output_port) = self.ports.get(edge_id).cloned().unwrap();
420
421            let mut dummy_diagnostics = Diagnostics::new();
422            let generics = get_operator_generics(&mut dummy_diagnostics, operator);
423            assert!(dummy_diagnostics.is_empty());
424
425            Some(OperatorInstance {
426                op_constraints,
427                input_ports: vec![input_port],
428                output_ports: vec![output_port],
429                singletons_referenced: operator.singletons_referenced.clone(),
430                generics,
431                arguments_pre: operator.args.clone(),
432                arguments_raw: operator.args_raw.clone(),
433            })
434        };
435
436        // Insert new `node`.
437        let node_id = self.nodes.insert(new_node);
438        // Insert corresponding `OperatorInstance` if applicable.
439        if let Some(op_inst) = op_inst_opt {
440            self.operator_instances.insert(node_id, op_inst);
441        }
442        // Update edges to insert node within `edge_id`.
443        let (e0, e1) = self
444            .graph
445            .insert_intermediate_vertex(node_id, edge_id)
446            .unwrap();
447
448        // Update corresponding ports.
449        let (src_idx, dst_idx) = self.ports.remove(edge_id).unwrap();
450        self.ports
451            .insert(e0, (src_idx, PortIndexValue::Elided(span)));
452        self.ports
453            .insert(e1, (PortIndexValue::Elided(span), dst_idx));
454
455        (node_id, e1)
456    }
457
458    /// Remove the node `node_id` but preserves and connects the single predecessor and single successor.
459    /// Panics if the node does not have exactly one predecessor and one successor, or is not in the graph.
460    pub fn remove_intermediate_node(&mut self, node_id: GraphNodeId) {
461        assert_eq!(
462            1,
463            self.node_degree_in(node_id),
464            "Removed intermediate node must have one predecessor"
465        );
466        assert_eq!(
467            1,
468            self.node_degree_out(node_id),
469            "Removed intermediate node must have one successor"
470        );
471        assert!(
472            self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
473            "Should not remove intermediate node after subgraph partitioning"
474        );
475
476        assert!(self.nodes.remove(node_id).is_some());
477        let (new_edge_id, (pred_edge_id, succ_edge_id)) =
478            self.graph.remove_intermediate_vertex(node_id).unwrap();
479        self.operator_instances.remove(node_id);
480        self.node_varnames.remove(node_id);
481
482        let (src_port, _) = self.ports.remove(pred_edge_id).unwrap();
483        let (_, dst_port) = self.ports.remove(succ_edge_id).unwrap();
484        self.ports.insert(new_edge_id, (src_port, dst_port));
485    }
486
487    /// Helper method: determine the "color" (pull vs push) of a node based on its in and out degree,
488    /// excluding reference edges. If linear (1 in, 1 out), color is `None`, indicating it can be
489    /// either push or pull.
490    ///
491    /// Note that this does NOT consider `DelayType` barriers (which generally implies `Pull`).
492    pub(crate) fn node_color(&self, node_id: GraphNodeId) -> Option<Color> {
493        if matches!(self.node(node_id), GraphNode::Handoff { .. }) {
494            return Some(Color::Hoff);
495        }
496
497        // TODO(shadaj): this is a horrible hack
498        if let GraphNode::Operator(op) = self.node(node_id)
499            && (op.name_string() == "resolve_futures_blocking"
500                || op.name_string() == "resolve_futures_blocking_ordered")
501        {
502            return Some(Color::Push);
503        }
504
505        // In-degree, excluding ref-edges.
506        let inn_degree = self.node_predecessor_nodes(node_id).len();
507        // Out-degree excluding ref-edges.
508        let out_degree = self.node_successor_nodes(node_id).len();
509
510        match (inn_degree, out_degree) {
511            (0, 0) => None, // Generally should not happen, "Degenerate subgraph detected".
512            (0, 1) => Some(Color::Pull),
513            (1, 0) => Some(Color::Push),
514            (1, 1) => None, // Linear, can be either push or pull.
515            (_many, 0 | 1) => Some(Color::Pull),
516            (0 | 1, _many) => Some(Color::Push),
517            (_many, _to_many) => Some(Color::Comp),
518        }
519    }
520
521    /// Set the operator tag (for debugging/tracing).
522    pub fn set_operator_tag(&mut self, node_id: GraphNodeId, tag: String) {
523        self.operator_tag.insert(node_id, tag);
524    }
525}
526
527/// Singleton references.
528impl DfirGraph {
529    /// Set the singletons referenced for the `node_id` operator. Each reference corresponds to the
530    /// same index in the [`crate::parse::Operator::singletons_referenced`] vec.
531    pub fn set_node_singleton_references(
532        &mut self,
533        node_id: GraphNodeId,
534        singletons_referenced: Vec<Option<GraphNodeId>>,
535    ) -> Option<Vec<Option<GraphNodeId>>> {
536        self.node_singleton_references
537            .insert(node_id, singletons_referenced)
538    }
539
540    /// Gets the singletons referenced by a node. Returns an empty iterator for non-operators and
541    /// operators that do not reference singletons.
542    pub fn node_singleton_references(&self, node_id: GraphNodeId) -> &[Option<GraphNodeId>] {
543        self.node_singleton_references
544            .get(node_id)
545            .map(std::ops::Deref::deref)
546            .unwrap_or_default()
547    }
548}
549
550/// Module methods.
551impl DfirGraph {
552    /// When modules are imported into a flat graph, they come with an input and output ModuleBoundary node.
553    /// The partitioner doesn't understand these nodes and will panic if it encounters them.
554    /// merge_modules removes them from the graph, stitching the input and ouput sides of the ModuleBondaries based on their ports
555    /// For example:
556    ///     source_iter([]) -> \[myport\]ModuleBoundary(input)\[my_port\] -> map(|x| x) -> ModuleBoundary(output) -> null();
557    /// in the above eaxmple, the \[myport\] port will be used to connect the source_iter with the map that is inside of the module.
558    /// The output module boundary has elided ports, this is also used to match up the input/output across the module boundary.
559    pub fn merge_modules(&mut self) -> Result<(), Diagnostic> {
560        let mod_bound_nodes = self
561            .nodes()
562            .filter(|(_nid, node)| matches!(node, GraphNode::ModuleBoundary { .. }))
563            .map(|(nid, _node)| nid)
564            .collect::<Vec<_>>();
565
566        for mod_bound_node in mod_bound_nodes {
567            self.remove_module_boundary(mod_bound_node)?;
568        }
569
570        Ok(())
571    }
572
573    /// see `merge_modules`
574    /// This function removes a singular module boundary from the graph and performs the necessary stitching to fix the graph afterward.
575    /// `merge_modules` calls this function for each module boundary in the graph.
576    fn remove_module_boundary(&mut self, mod_bound_node: GraphNodeId) -> Result<(), Diagnostic> {
577        assert!(
578            self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
579            "Should not remove intermediate node after subgraph partitioning"
580        );
581
582        let mut mod_pred_ports = BTreeMap::new();
583        let mut mod_succ_ports = BTreeMap::new();
584
585        for mod_out_edge in self.node_predecessor_edges(mod_bound_node) {
586            let (pred_port, succ_port) = self.edge_ports(mod_out_edge);
587            mod_pred_ports.insert(succ_port.clone(), (mod_out_edge, pred_port.clone()));
588        }
589
590        for mod_inn_edge in self.node_successor_edges(mod_bound_node) {
591            let (pred_port, succ_port) = self.edge_ports(mod_inn_edge);
592            mod_succ_ports.insert(pred_port.clone(), (mod_inn_edge, succ_port.clone()));
593        }
594
595        if mod_pred_ports.keys().collect::<BTreeSet<_>>()
596            != mod_succ_ports.keys().collect::<BTreeSet<_>>()
597        {
598            // get module boundary node
599            let GraphNode::ModuleBoundary { input, import_expr } = self.node(mod_bound_node) else {
600                panic!();
601            };
602
603            if *input {
604                return Err(Diagnostic {
605                    span: *import_expr,
606                    level: Level::Error,
607                    message: format!(
608                        "The ports into the module did not match. input: {:?}, expected: {:?}",
609                        mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
610                        mod_succ_ports.keys().map(|x| x.to_string()).join(", ")
611                    ),
612                });
613            } else {
614                return Err(Diagnostic {
615                    span: *import_expr,
616                    level: Level::Error,
617                    message: format!(
618                        "The ports out of the module did not match. output: {:?}, expected: {:?}",
619                        mod_succ_ports.keys().map(|x| x.to_string()).join(", "),
620                        mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
621                    ),
622                });
623            }
624        }
625
626        for (port, (pred_edge, pred_port)) in mod_pred_ports {
627            let (succ_edge, succ_port) = mod_succ_ports.remove(&port).unwrap();
628
629            let (src, _) = self.edge(pred_edge);
630            let (_, dst) = self.edge(succ_edge);
631            self.remove_edge(pred_edge);
632            self.remove_edge(succ_edge);
633
634            let new_edge_id = self.graph.insert_edge(src, dst);
635            self.ports.insert(new_edge_id, (pred_port, succ_port));
636        }
637
638        self.graph.remove_vertex(mod_bound_node);
639        self.nodes.remove(mod_bound_node);
640
641        Ok(())
642    }
643}
644
645/// Edge methods.
646impl DfirGraph {
647    /// Get the `src` and `dst` for an edge: `(src GraphNodeId, dst GraphNodeId)`.
648    pub fn edge(&self, edge_id: GraphEdgeId) -> (GraphNodeId, GraphNodeId) {
649        let (src, dst) = self.graph.edge(edge_id).expect("Edge not found.");
650        (src, dst)
651    }
652
653    /// Get the source and destination ports for an edge: `(src &PortIndexValue, dst &PortIndexValue)`.
654    pub fn edge_ports(&self, edge_id: GraphEdgeId) -> (&PortIndexValue, &PortIndexValue) {
655        let (src_port, dst_port) = self.ports.get(edge_id).expect("Edge not found.");
656        (src_port, dst_port)
657    }
658
659    /// Iterator of all edge IDs `GraphEdgeId`.
660    pub fn edge_ids(&self) -> slotmap::basic::Keys<'_, GraphEdgeId, (GraphNodeId, GraphNodeId)> {
661        self.graph.edge_ids()
662    }
663
664    /// Iterator over all edges: `(GraphEdgeId, (src GraphNodeId, dst GraphNodeId))`.
665    pub fn edges(
666        &self,
667    ) -> impl '_
668    + ExactSizeIterator<Item = (GraphEdgeId, (GraphNodeId, GraphNodeId))>
669    + FusedIterator
670    + Clone
671    + Debug {
672        self.graph.edges()
673    }
674
675    /// Insert an edge between nodes thru the given ports.
676    pub fn insert_edge(
677        &mut self,
678        src: GraphNodeId,
679        src_port: PortIndexValue,
680        dst: GraphNodeId,
681        dst_port: PortIndexValue,
682    ) -> GraphEdgeId {
683        let edge_id = self.graph.insert_edge(src, dst);
684        self.ports.insert(edge_id, (src_port, dst_port));
685        edge_id
686    }
687
688    /// Removes an edge and its corresponding ports and edge type info.
689    pub fn remove_edge(&mut self, edge: GraphEdgeId) {
690        let (_src, _dst) = self.graph.remove_edge(edge).unwrap();
691        let (_src_port, _dst_port) = self.ports.remove(edge).unwrap();
692    }
693}
694
695/// Subgraph methods.
696impl DfirGraph {
697    /// Nodes belonging to the given subgraph.
698    pub fn subgraph(&self, subgraph_id: GraphSubgraphId) -> &Vec<GraphNodeId> {
699        self.subgraph_nodes
700            .get(subgraph_id)
701            .expect("Subgraph not found.")
702    }
703
704    /// Iterator over all subgraph IDs.
705    pub fn subgraph_ids(&self) -> slotmap::basic::Keys<'_, GraphSubgraphId, Vec<GraphNodeId>> {
706        self.subgraph_nodes.keys()
707    }
708
709    /// Iterator over all subgraphs, ID and members: `(GraphSubgraphId, Vec<GraphNodeId>)`.
710    pub fn subgraphs(&self) -> slotmap::basic::Iter<'_, GraphSubgraphId, Vec<GraphNodeId>> {
711        self.subgraph_nodes.iter()
712    }
713
714    /// Create a subgraph consisting of `node_ids`. Returns an error if any of the nodes are already in a subgraph.
715    pub fn insert_subgraph(
716        &mut self,
717        node_ids: Vec<GraphNodeId>,
718    ) -> Result<GraphSubgraphId, (GraphNodeId, GraphSubgraphId)> {
719        // Check none are already in subgraphs
720        for &node_id in node_ids.iter() {
721            if let Some(&old_sg_id) = self.node_subgraph.get(node_id) {
722                return Err((node_id, old_sg_id));
723            }
724        }
725        let subgraph_id = self.subgraph_nodes.insert_with_key(|sg_id| {
726            for &node_id in node_ids.iter() {
727                self.node_subgraph.insert(node_id, sg_id);
728            }
729            node_ids
730        });
731
732        Ok(subgraph_id)
733    }
734
735    /// Removes a node from its subgraph. Returns true if the node was in a subgraph.
736    pub fn remove_from_subgraph(&mut self, node_id: GraphNodeId) -> bool {
737        if let Some(old_sg_id) = self.node_subgraph.remove(node_id) {
738            self.subgraph_nodes[old_sg_id].retain(|&other_node_id| other_node_id != node_id);
739            true
740        } else {
741            false
742        }
743    }
744
745    /// Gets the delay type for a handoff node, if set.
746    pub fn handoff_delay_type(&self, node_id: GraphNodeId) -> Option<DelayType> {
747        self.handoff_delay_type.get(node_id).copied()
748    }
749
750    /// Sets the delay type for a handoff node.
751    pub fn set_handoff_delay_type(&mut self, node_id: GraphNodeId, delay_type: DelayType) {
752        self.handoff_delay_type.insert(node_id, delay_type);
753    }
754
755    /// Helper: finds the first index in `subgraph_nodes` where it transitions from pull to push.
756    fn find_pull_to_push_idx(&self, subgraph_nodes: &[GraphNodeId]) -> usize {
757        subgraph_nodes
758            .iter()
759            .position(|&node_id| {
760                self.node_color(node_id)
761                    .is_some_and(|color| Color::Pull != color)
762            })
763            .unwrap_or(subgraph_nodes.len())
764    }
765}
766
767/// Display/output methods.
768impl DfirGraph {
769    /// Helper to generate a deterministic `Ident` for the given node.
770    fn node_as_ident(&self, node_id: GraphNodeId, is_pred: bool) -> Ident {
771        let name = match &self.nodes[node_id] {
772            GraphNode::Operator(_) => format!("op_{:?}", node_id.data()),
773            GraphNode::Handoff {
774                kind: HandoffKind::Vec,
775                ..
776            } => format!(
777                "hoff_{:?}_{}",
778                node_id.data(),
779                if is_pred { "recv" } else { "send" }
780            ),
781            GraphNode::Handoff {
782                kind: HandoffKind::Option,
783                ..
784            } => format!(
785                "singleton_{:?}_{}",
786                node_id.data(),
787                if is_pred { "recv" } else { "send" }
788            ),
789            GraphNode::ModuleBoundary { .. } => panic!(),
790        };
791        let span = match (is_pred, &self.nodes[node_id]) {
792            (_, GraphNode::Operator(operator)) => operator.span(),
793            (true, &GraphNode::Handoff { src_span, .. }) => src_span,
794            (false, &GraphNode::Handoff { dst_span, .. }) => dst_span,
795            (_, GraphNode::ModuleBoundary { .. }) => panic!(),
796        };
797        Ident::new(&name, span)
798    }
799
800    /// Helper to generate the main buffer `Ident` for a handoff node.
801    fn hoff_buf_ident(&self, hoff_id: GraphNodeId, span: Span) -> Ident {
802        Ident::new(&format!("hoff_{:?}_buf", hoff_id.data()), span)
803    }
804
805    /// Helper to generate the back (double-buffer) `Ident` for a handoff node.
806    fn hoff_back_ident(&self, hoff_id: GraphNodeId, span: Span) -> Ident {
807        Ident::new(&format!("hoff_{:?}_back", hoff_id.data()), span)
808    }
809
810    /// For per-node singleton references. Helper to generate a deterministic `Ident` for the given node.
811    fn node_as_singleton_ident(&self, node_id: GraphNodeId, span: Span) -> Ident {
812        Ident::new(&format!("singleton_op_{:?}", node_id.data()), span)
813    }
814
815    /// Resolve the singletons via [`Self::node_singleton_references`] for the given `node_id`.
816    /// Returns token streams for each reference:
817    /// - For stateful operators: `&singleton_op_XXX` (borrow the operator's state)
818    /// - For HandoffKind::Option: `&(hoff_XXX_buf.as_ref().unwrap())` (intentionally produce `&&T`
819    ///   so the later `(*expr)` deref yields `&T`) - TODO(mingwei)
820    fn helper_resolve_singletons(&self, node_id: GraphNodeId, span: Span) -> Vec<TokenStream> {
821        self.node_singleton_references(node_id)
822            .iter()
823            .map(|singleton_node_id| {
824                // TODO(mingwei): this `expect` should be caught in error checking
825                let ref_node_id = singleton_node_id
826                    .expect("Expected singleton to be resolved but was not, this is a bug.");
827                if matches!(
828                    self.node(ref_node_id),
829                    GraphNode::Handoff {
830                        kind: HandoffKind::Option,
831                        ..
832                    }
833                ) {
834                    let buf_ident = self.hoff_buf_ident(ref_node_id, span);
835                    // Wrapping in &(...) produces &&T so that postprocess_singletons'
836                    // (*expr) deref gives &T — matching `type O = &'a T`.
837                    // TODO(mingwei): Make postprocess_singletons not deref, remove old singletons (the `else` case below).
838                    quote_spanned! {span=> &(#buf_ident.as_ref().unwrap()) }
839                } else {
840                    let singleton_ident = self.node_as_singleton_ident(ref_node_id, span);
841                    quote_spanned! {span=> &#singleton_ident }
842                }
843            })
844            .collect::<Vec<_>>()
845    }
846
847    /// Returns each subgraph's receive and send handoffs.
848    /// `Map<GraphSubgraphId, (recv handoffs, send handoffs)>`
849    fn helper_collect_subgraph_handoffs(
850        &self,
851    ) -> SecondaryMap<GraphSubgraphId, (Vec<GraphNodeId>, Vec<GraphNodeId>)> {
852        // Get data on handoff src and dst subgraphs.
853        let mut subgraph_handoffs: SecondaryMap<
854            GraphSubgraphId,
855            (Vec<GraphNodeId>, Vec<GraphNodeId>),
856        > = self
857            .subgraph_nodes
858            .keys()
859            .map(|k| (k, Default::default()))
860            .collect();
861
862        // For each handoff/singleton node, add it to the `send`/`recv` lists for the corresponding subgraphs.
863        for (hoff_id, hoff) in self.nodes() {
864            if !matches!(hoff, GraphNode::Handoff { .. }) {
865                continue;
866            }
867            // Receivers from the handoff. (Should really only be one).
868            for (_edge, succ_id) in self.node_successors(hoff_id) {
869                let succ_sg = self.node_subgraph(succ_id).unwrap();
870                subgraph_handoffs[succ_sg].0.push(hoff_id);
871            }
872            // Senders into the handoff. (Should really only be one).
873            for (_edge, pred_id) in self.node_predecessors(hoff_id) {
874                let pred_sg = self.node_subgraph(pred_id).unwrap();
875                subgraph_handoffs[pred_sg].1.push(hoff_id);
876            }
877        }
878
879        subgraph_handoffs
880    }
881
882    /// Emit this graph as runnable Rust source code tokens that execute inline.
883    /// Generates a flat `async move |df: &mut Context|` closure where subgraph
884    /// blocks are inlined in topological order, using local `Vec<T>` buffers
885    /// instead of runtime handoffs. Each call to the closure runs one tick.
886    ///
887    /// The generated code block evaluates to a `Dfir` instance wrapping the
888    /// closure. Operator prologues run at construction time on the `Context`
889    /// before it is moved into `Dfir::new`. `Dfir` provides the `Context`
890    /// to the closure on each tick run.
891    ///
892    /// # Errors
893    ///
894    /// Returns all diagnostics as `Err(diagnostics)` if any are errors
895    /// (leaving `&mut diagnostics` empty).
896    pub fn as_code(
897        &self,
898        root: &TokenStream,
899        include_type_guards: bool,
900        prefix: TokenStream,
901        diagnostics: &mut Diagnostics,
902    ) -> Result<TokenStream, Diagnostics> {
903        self.as_code_with_options(root, include_type_guards, true, prefix, diagnostics)
904    }
905
    /// Like [`Self::as_code`], but with `include_meta` controlling whether
    /// the runtime meta graph + diagnostics JSON blobs are baked into the
    /// generated `Dfir::new(...)` call.
    ///
    /// The simulator calls `Dfir::new()` on each iteration, and as part of that
    /// it parses the meta graph and diagnostics blobs. Parsing one of them allocates
    /// spans, and each span allocation increments a thread-local `u32`; on a long
    /// simulator run, that `u32` overflows and panics.
914    pub fn as_code_with_options(
915        &self,
916        root: &TokenStream,
917        include_type_guards: bool,
918        include_meta: bool,
919        prefix: TokenStream,
920        diagnostics: &mut Diagnostics,
921    ) -> Result<TokenStream, Diagnostics> {
922        let df = Ident::new(GRAPH, Span::call_site());
923        let context = Ident::new(CONTEXT, Span::call_site());
924
925        // 1. Generate local buffers for each handoff node (Vec for streams, Option for singletons).
926        let handoff_nodes: Vec<_> = self
927            .nodes
928            .iter()
929            .filter_map(|(node_id, node)| match node {
930                &GraphNode::Handoff {
931                    kind,
932                    src_span,
933                    dst_span,
934                } => Some((node_id, kind, (src_span, dst_span))),
935                GraphNode::Operator(_) => None,
936                GraphNode::ModuleBoundary { .. } => panic!(),
937            })
938            .collect();
939
940        let buffer_code: Vec<TokenStream> = handoff_nodes
941            .iter()
942            .map(|&(node_id, kind, (src_span, dst_span))| {
943                let span = src_span.join(dst_span).unwrap_or(src_span);
944                let buf_ident = self.hoff_buf_ident(node_id, span);
945                match kind {
946                    HandoffKind::Vec => quote_spanned! {span=>
947                        let mut #buf_ident = ::std::vec::Vec::new();
948                    },
949                    HandoffKind::Option => quote_spanned! {span=>
950                        let mut #buf_ident = ::std::option::Option::None;
951                    },
952                }
953            })
954            .collect();
955
956        // For tick-boundary handoffs (`defer_tick` / `defer_tick_lazy`), declare a
957        // second "back" buffer for double-buffering. At the start of each tick, the
958        // main buffer and back buffer are swapped so the consumer reads last tick's
959        // data while the producer writes to a fresh buffer.
960        let back_buffer_code: Vec<TokenStream> = handoff_nodes
961            .iter()
962            .filter(|(node_id, _kind, _)| self.handoff_delay_type(*node_id).is_some())
963            .map(|&(node_id, kind, (src_span, dst_span))| {
964                assert!(
965                    matches!(kind, HandoffKind::Vec),
966                    "bug: only Vec handoffs should have delay types"
967                );
968                let span = src_span.join(dst_span).unwrap_or(src_span);
969                let back_ident = self.hoff_back_ident(node_id, span);
970                quote_spanned! {span=>
971                    let mut #back_ident: Vec<_> = Vec::new();
972                }
973            })
974            .collect();
975
976        // 2. Collect subgraph handoffs (same as as_code).
977        let subgraph_handoffs = self.helper_collect_subgraph_handoffs();
978
979        // 3. Sort subgraphs topologically and collect non-lazy defer_tick buffer idents.
980        //
981        // Handoffs marked with a `DelayType` (Tick/TickLazy) are tick-boundary back-edges.
982        // These are excluded from the topo sort (no ordering constraint). Double-buffering
983        // ensures data written by the producer in tick N is only visible to the consumer
984        // in tick N+1, regardless of execution order.
985        //
986        // While iterating handoffs, we also collect buffer idents for non-lazy tick-boundary
987        // edges (defer_tick). When these buffers are non-empty at end of tick, we set
988        // can_start_tick so that run_available continues ticking.
989        let mut defer_tick_buf_idents: Vec<Ident> = Vec::new();
990        let mut back_edge_hoff_ids: BTreeSet<GraphNodeId> = BTreeSet::new();
991        let all_subgraphs = {
992            // Build predecessor map for subgraphs.
993            let mut sg_preds: SecondaryMap<GraphSubgraphId, Vec<GraphSubgraphId>> =
994                SecondaryMap::<_, Vec<_>>::with_capacity(self.subgraph_nodes.len());
995            for (hoff_id, hoff) in self.nodes() {
996                if !matches!(hoff, GraphNode::Handoff { .. }) {
997                    // Not a handoff; skip.
998                    continue;
999                }
1000                if 0 == self.node_successors(hoff_id).len() {
1001                    // Is a handoff only used by reference, not consumed.
1002                    continue;
1003                }
1004                assert_eq!(1, self.node_successors(hoff_id).len());
1005                assert_eq!(1, self.node_predecessors(hoff_id).len());
1006                let (_edge_id, pred) = self.node_predecessors(hoff_id).next().unwrap();
1007                let (_edge_id, succ) = self.node_successors(hoff_id).next().unwrap();
1008                let pred_sg = self.node_subgraph(pred).unwrap();
1009                let succ_sg = self.node_subgraph(succ).unwrap();
1010                if pred_sg == succ_sg {
1011                    panic!("bug: unexpected subgraph self-handoff cycle");
1012                }
1013                if let Some(delay_type) = self.handoff_delay_type(hoff_id) {
1014                    debug_assert!(matches!(delay_type, DelayType::Tick | DelayType::TickLazy));
1015                    // Tick/back-edge handoff: no ordering constraint. Double-buffering
1016                    // handles the tick deferral regardless of execution order.
1017                    back_edge_hoff_ids.insert(hoff_id);
1018
1019                    // Non-lazy tick-boundary: defer_tick (not defer_tick_lazy).
1020                    if !matches!(delay_type, DelayType::TickLazy) {
1021                        defer_tick_buf_idents.push(self.hoff_buf_ident(hoff_id, hoff.span()));
1022                    }
1023                } else {
1024                    sg_preds.entry(succ_sg).unwrap().or_default().push(pred_sg);
1025                }
1026            }
1027
1028            // Include singleton reference edges: if node A references the
1029            // singleton output of node B, then A's subgraph must run after B's.
1030            for dst_id in self.node_ids() {
1031                for src_ref_id in self
1032                    .node_singleton_references(dst_id)
1033                    .iter()
1034                    .copied()
1035                    .flatten()
1036                {
1037                    // For handoff nodes (no subgraph), use the predecessor's subgraph.
1038                    let src_sg = if let Some(sg) = self.node_subgraph(src_ref_id) {
1039                        sg
1040                    } else {
1041                        let (_edge, pred) = self
1042                            .node_predecessors(src_ref_id)
1043                            .next()
1044                            .expect("handoff must have a predecessor");
1045                        self.node_subgraph(pred).unwrap()
1046                    };
1047                    let dst_sg = self
1048                        .node_subgraph(dst_id)
1049                        .expect("bug: singleton ref consumer must belong to a subgraph");
1050                    if src_sg != dst_sg {
1051                        sg_preds.entry(dst_sg).unwrap().or_default().push(src_sg);
1052                    }
1053
1054                    // Ensure the borrower runs before the pipe consumer
1055                    // (which takes/drains the value).
1056                    // All handoffs should have at most one successor.
1057                    if self.node_subgraph(src_ref_id).is_none() {
1058                        assert!(
1059                            self.node_degree_out(src_ref_id) <= 1,
1060                            "handoff should have at most one successor"
1061                        );
1062                        if let Some((_edge, succ_id)) = self.node_successors(src_ref_id).next()
1063                            && let Some(consumer_sg) = self.node_subgraph(succ_id)
1064                            && consumer_sg != dst_sg
1065                        {
1066                            sg_preds
1067                                .entry(consumer_sg)
1068                                .unwrap()
1069                                .or_default()
1070                                .push(dst_sg);
1071                        }
1072                    }
1073                }
1074            }
1075
1076            let topo_sort = super::graph_algorithms::topo_sort(self.subgraph_ids(), |sg_id| {
1077                sg_preds.get(sg_id).into_iter().flatten().copied()
1078            })
1079            .expect("bug: unexpected cycle between subgraphs within the tick");
1080
1081            topo_sort
1082                .into_iter()
1083                .map(|sg_id| (sg_id, self.subgraph(sg_id)))
1084                .collect::<Vec<_>>()
1085        };
1086
1087        // Generate swap code for tick-boundary (defer_tick / defer_tick_lazy) handoffs.
1088        // At the start of each tick, swap the main buffer and back buffer so the
1089        // consumer reads last tick's data from the back buffer.
1090        let back_edge_swap_code: Vec<TokenStream> = back_edge_hoff_ids
1091            .iter()
1092            .map(|&hoff_id| {
1093                let span = self.nodes[hoff_id].span();
1094                let buf_ident = self.hoff_buf_ident(hoff_id, span);
1095                let back_ident = self.hoff_back_ident(hoff_id, span);
1096                quote_spanned! {span=>
1097                    ::std::mem::swap(&mut #buf_ident, &mut #back_ident);
1098                }
1099            })
1100            .collect();
1101
1102        // Generate drain code for handoffs with no pipe consumer (0 successors).
1103        // These are only accessed via #var references and must be cleared each tick.
1104        let no_consumer_drain_code: Vec<TokenStream> = handoff_nodes
1105            .iter()
1106            .filter(|&&(node_id, _, _)| self.node_degree_out(node_id) == 0)
1107            .map(|&(node_id, kind, (src_span, dst_span))| {
1108                let span = src_span.join(dst_span).unwrap_or(src_span);
1109                let buf_ident = self.hoff_buf_ident(node_id, span);
1110                match kind {
1111                    HandoffKind::Option => quote_spanned! {span=> #buf_ident.take(); },
1112                    HandoffKind::Vec => quote_spanned! {span=> #buf_ident.clear(); },
1113                }
1114            })
1115            .collect();
1116
1117        let mut op_prologue_code = Vec::new();
1118        let mut op_tick_end_code = Vec::new();
1119        let mut subgraph_blocks = Vec::new();
1120        {
1121            for &(subgraph_id, subgraph_nodes) in all_subgraphs.iter() {
1122                let sg_metrics_ffi = subgraph_id.data().as_ffi();
1123                let (recv_hoffs, send_hoffs) = &subgraph_handoffs[subgraph_id];
1124
1125                // Generate buffer ident helpers for this subgraph's handoffs.
1126                let recv_port_idents: Vec<Ident> = recv_hoffs
1127                    .iter()
1128                    .map(|&hoff_id| self.node_as_ident(hoff_id, true))
1129                    .collect();
1130                let send_port_idents: Vec<Ident> = send_hoffs
1131                    .iter()
1132                    .map(|&hoff_id| self.node_as_ident(hoff_id, false))
1133                    .collect();
1134
1135                // Map handoff node IDs to buffer idents.
1136                let recv_buf_idents: Vec<Ident> = recv_hoffs
1137                    .iter()
1138                    .map(|&hoff_id| self.hoff_buf_ident(hoff_id, self.nodes[hoff_id].span()))
1139                    .collect();
1140                let send_buf_idents: Vec<Ident> = send_hoffs
1141                    .iter()
1142                    .map(|&hoff_id| self.hoff_buf_ident(hoff_id, self.nodes[hoff_id].span()))
1143                    .collect();
1144
1145                // Recv port code: drain from buffer into iterator, tracking if non-empty.
1146                // For back-edge (defer_tick) handoffs, drain from the back buffer instead.
1147                // Also update handoff metrics (measured at recv, not send — see graph.rs).
1148                let recv_port_code: Vec<TokenStream> = recv_port_idents
1149                    .iter()
1150                    .zip(recv_buf_idents.iter())
1151                    .zip(recv_hoffs.iter())
1152                    .map(|((port_ident, buf_ident), &hoff_id)| {
1153                        let hoff_ffi = hoff_id.data().as_ffi();
1154                        // Use call_site span for internal identifiers to avoid
1155                        // hygiene issues when invoked through declarative macros
1156                        // (e.g. dfir_expect_warnings!). TODO(#2781): define these once.
1157                        let work_done = Ident::new("__dfir_work_done", Span::call_site());
1158                        let metrics = Ident::new("__dfir_metrics", Span::call_site());
1159
1160                        let GraphNode::Handoff { kind, .. } = self.node(hoff_id) else {
1161                            unreachable!()
1162                        };
1163
1164                        // Compute len and drain expressions based on handoff kind.
1165                        let (len_expr, drain_expr) = match kind {
1166                            HandoffKind::Option => (
1167                                quote! { if #buf_ident.is_some() { 1usize } else { 0usize } },
1168                                quote! { #root::dfir_pipes::pull::iter(#buf_ident.take().into_iter()) },
1169                            ),
1170                            HandoffKind::Vec => {
1171                                let drain_ident = if back_edge_hoff_ids.contains(&hoff_id) {
1172                                    self.hoff_back_ident(hoff_id, buf_ident.span())
1173                                } else {
1174                                    buf_ident.clone()
1175                                };
1176                                (
1177                                    quote! { #drain_ident.len() },
1178                                    quote! { #root::dfir_pipes::pull::iter(#drain_ident.drain(..)) },
1179                                )
1180                            }
1181                        };
1182
1183                        quote_spanned! {port_ident.span()=>
1184                            {
1185                                let hoff_len = #len_expr;
1186                                if hoff_len > 0 {
1187                                    #work_done = true;
1188                                }
1189                                let hoff_metrics = &#metrics.handoffs[
1190                                    #root::slotmap::KeyData::from_ffi(#hoff_ffi).into()
1191                                ];
1192                                hoff_metrics.total_items_count.update(|x| x + hoff_len);
1193                                hoff_metrics.curr_items_count.set(hoff_len);
1194                            }
1195                            let #port_ident = #drain_expr;
1196                        }
1197                    })
1198                    .collect();
1199
1200                // Send port code: push into buffer.
1201                let send_port_code: Vec<TokenStream> = send_port_idents
1202                    .iter()
1203                    .zip(send_buf_idents.iter())
1204                    .zip(send_hoffs.iter())
1205                    .map(|((port_ident, buf_ident), &hoff_id)| {
1206                        let GraphNode::Handoff { kind, .. } = self.node(hoff_id) else {
1207                            unreachable!()
1208                        };
1209                        match kind {
1210                            HandoffKind::Option => {
1211                                // Singleton slot: store exactly one item, panic on duplicate.
1212                                quote_spanned! {port_ident.span()=>
1213                                    let #port_ident = #root::dfir_pipes::push::for_each(|__item| {
1214                                        if #buf_ident.replace(__item).is_some() {
1215                                            panic!("singleton() received more than one item");
1216                                        }
1217                                    });
1218                                }
1219                            }
1220                            HandoffKind::Vec => {
1221                                quote_spanned! {port_ident.span()=>
1222                                    let #port_ident = #root::dfir_pipes::push::vec_push(&mut #buf_ident);
1223                                }
1224                            }
1225                        }
1226                    })
1227                    .collect();
1228
1229                // All nodes in a subgraph should be in the same loop.
1230                let loop_id = self.node_loop(subgraph_nodes[0]);
1231
1232                let mut subgraph_op_iter_code = Vec::new();
1233                let mut subgraph_op_iter_after_code = Vec::new();
1234                {
1235                    let pull_to_push_idx = self.find_pull_to_push_idx(subgraph_nodes);
1236
1237                    let (pull_half, push_half) = subgraph_nodes.split_at(pull_to_push_idx);
1238                    let nodes_iter = pull_half.iter().chain(push_half.iter().rev());
1239
1240                    for (idx, &node_id) in nodes_iter.enumerate() {
1241                        let node = &self.nodes[node_id];
1242                        assert!(
1243                            matches!(node, GraphNode::Operator(_)),
1244                            "Handoffs are not part of subgraphs."
1245                        );
1246                        let op_inst = &self.operator_instances[node_id];
1247
1248                        let op_span = node.span();
1249                        let op_name = op_inst.op_constraints.name;
1250                        // Use op's span for root. #root is expected to be correct, any errors should span back to the op gen.
1251                        let root = change_spans(root.clone(), op_span);
1252                        let op_constraints = OPERATORS
1253                            .iter()
1254                            .find(|op| op_name == op.name)
1255                            .unwrap_or_else(|| panic!("Failed to find op: {}", op_name));
1256
1257                        let ident = self.node_as_ident(node_id, false);
1258
1259                        {
1260                            // TODO clean this up.
1261                            // Collect input arguments (predecessors).
1262                            let mut input_edges = self
1263                                .graph
1264                                .predecessor_edges(node_id)
1265                                .map(|edge_id| (self.edge_ports(edge_id).1, edge_id))
1266                                .collect::<Vec<_>>();
1267                            // Ensure sorted by port index.
1268                            input_edges.sort();
1269
1270                            let inputs = input_edges
1271                                .iter()
1272                                .map(|&(_port, edge_id)| {
1273                                    let (pred, _) = self.edge(edge_id);
1274                                    self.node_as_ident(pred, true)
1275                                })
1276                                .collect::<Vec<_>>();
1277
1278                            // Collect output arguments (successors).
1279                            let mut output_edges = self
1280                                .graph
1281                                .successor_edges(node_id)
1282                                .map(|edge_id| (&self.ports[edge_id].0, edge_id))
1283                                .collect::<Vec<_>>();
1284                            // Ensure sorted by port index.
1285                            output_edges.sort();
1286
1287                            let outputs = output_edges
1288                                .iter()
1289                                .map(|&(_port, edge_id)| {
1290                                    let (_, succ) = self.edge(edge_id);
1291                                    self.node_as_ident(succ, false)
1292                                })
1293                                .collect::<Vec<_>>();
1294
1295                            let is_pull = idx < pull_to_push_idx;
1296
1297                            let singleton_output_ident = &if op_constraints.has_singleton_output {
1298                                self.node_as_singleton_ident(node_id, op_span)
1299                            } else {
1300                                // This ident *should* go unused.
1301                                Ident::new(&format!("{}_has_no_singleton_output", op_name), op_span)
1302                            };
1303
1304                            // There's a bit of dark magic hidden in `Span`s... you'd think it's just a `file:line:column`,
1305                            // but it has one extra bit of info for _name resolution_, used for `Ident`s. `Span::call_site()`
1306                            // has the (unhygienic) resolution we want, an ident is just solely determined by its string name,
1307                            // which is what you'd expect out of unhygienic proc macros like this. Meanwhile, declarative macros
1308                            // use `Span::mixed_site()` which is weird and I don't understand it. It turns out that if you call
1309                            // the dfir syntax proc macro from _within_ a declarative macro then `op_span` will have the
1310                            // bad `Span::mixed_site()` name resolution and cause "Cannot find value `df/context`" errors. So
1311                            // we call `.resolved_at()` to fix resolution back to `Span::call_site()`. -Mingwei
1312                            let df_local = &Ident::new(GRAPH, op_span.resolved_at(df.span()));
1313                            let context = &Ident::new(CONTEXT, op_span.resolved_at(context.span()));
1314
1315                            let singletons_resolved =
1316                                self.helper_resolve_singletons(node_id, op_span);
1317
1318                            let arguments = &process_singletons::postprocess_singletons(
1319                                op_inst.arguments_raw.clone(),
1320                                singletons_resolved,
1321                            );
1322
1323                            let source_tag = 'a: {
1324                                if let Some(tag) = self.operator_tag.get(node_id).cloned() {
1325                                    break 'a tag;
1326                                }
1327
1328                                #[cfg(nightly)]
1329                                if proc_macro::is_available() {
1330                                    let op_span = op_span.unwrap();
1331                                    break 'a format!(
1332                                        "loc_{}_{}_{}_{}_{}",
1333                                        crate::pretty_span::make_source_path_relative(
1334                                            &op_span.file()
1335                                        )
1336                                        .display()
1337                                        .to_string()
1338                                        .replace(|x: char| !x.is_ascii_alphanumeric(), "_"),
1339                                        op_span.start().line(),
1340                                        op_span.start().column(),
1341                                        op_span.end().line(),
1342                                        op_span.end().column(),
1343                                    );
1344                                }
1345
1346                                format!(
1347                                    "loc_nopath_{}_{}_{}_{}",
1348                                    op_span.start().line,
1349                                    op_span.start().column,
1350                                    op_span.end().line,
1351                                    op_span.end().column
1352                                )
1353                            };
1354
1355                            let work_fn = format_ident!(
1356                                "{}__{}__{}",
1357                                ident,
1358                                op_name,
1359                                source_tag,
1360                                span = op_span
1361                            );
1362                            let work_fn_async = format_ident!("{}__async", work_fn, span = op_span);
1363
1364                            let context_args = WriteContextArgs {
1365                                root: &root,
1366                                df_ident: df_local,
1367                                context,
1368                                subgraph_id,
1369                                node_id,
1370                                loop_id,
1371                                op_span,
1372                                op_tag: self.operator_tag.get(node_id).cloned(),
1373                                work_fn: &work_fn,
1374                                work_fn_async: &work_fn_async,
1375                                ident: &ident,
1376                                is_pull,
1377                                inputs: &inputs,
1378                                outputs: &outputs,
1379                                singleton_output_ident,
1380                                op_name,
1381                                op_inst,
1382                                arguments,
1383                            };
1384
1385                            let write_result =
1386                                (op_constraints.write_fn)(&context_args, diagnostics);
1387                            let OperatorWriteOutput {
1388                                write_prologue,
1389                                write_iterator,
1390                                write_iterator_after,
1391                                write_tick_end,
1392                            } = write_result.unwrap_or_else(|()| {
1393                                assert!(
1394                                    diagnostics.has_error(),
1395                                    "Operator `{}` returned `Err` but emitted no diagnostics, this is a bug.",
1396                                    op_name,
1397                                );
1398                                OperatorWriteOutput {
1399                                    write_iterator: null_write_iterator_fn(&context_args),
1400                                    ..Default::default()
1401                                }
1402                            });
1403
1404                            op_prologue_code.push(syn::parse_quote! {
1405                                #[allow(non_snake_case)]
1406                                #[inline(always)]
1407                                fn #work_fn<T>(thunk: impl ::std::ops::FnOnce() -> T) -> T {
1408                                    thunk()
1409                                }
1410
1411                                #[allow(non_snake_case)]
1412                                #[inline(always)]
1413                                async fn #work_fn_async<T>(
1414                                    thunk: impl ::std::future::Future<Output = T>,
1415                                ) -> T {
1416                                    thunk.await
1417                                }
1418                            });
1419                            op_prologue_code.push(write_prologue);
1420                            op_tick_end_code.push(write_tick_end);
1421                            subgraph_op_iter_code.push(write_iterator);
1422
1423                            if include_type_guards {
1424                                let type_guard = if is_pull {
1425                                    quote_spanned! {op_span=>
1426                                        let #ident = {
1427                                            #[allow(non_snake_case)]
1428                                            #[inline(always)]
1429                                            pub fn #work_fn<Item, Input>(input: Input)
1430                                                -> impl #root::dfir_pipes::pull::Pull<Item = Item, Meta = (), CanPend = Input::CanPend, CanEnd = Input::CanEnd>
1431                                            where
1432                                                Input: #root::dfir_pipes::pull::Pull<Item = Item, Meta = ()>,
1433                                            {
1434                                                #root::pin_project_lite::pin_project! {
1435                                                    #[repr(transparent)]
1436                                                    struct Pull<Item, Input: #root::dfir_pipes::pull::Pull<Item = Item>> {
1437                                                        #[pin]
1438                                                        inner: Input
1439                                                    }
1440                                                }
1441
1442                                                impl<Item, Input> #root::dfir_pipes::pull::Pull for Pull<Item, Input>
1443                                                where
1444                                                    Input: #root::dfir_pipes::pull::Pull<Item = Item>,
1445                                                {
1446                                                    type Ctx<'ctx> = Input::Ctx<'ctx>;
1447
1448                                                    type Item = Item;
1449                                                    type Meta = Input::Meta;
1450                                                    type CanPend = Input::CanPend;
1451                                                    type CanEnd = Input::CanEnd;
1452
1453                                                    #[inline(always)]
1454                                                    fn pull(
1455                                                        self: ::std::pin::Pin<&mut Self>,
1456                                                        ctx: &mut Self::Ctx<'_>,
1457                                                    ) -> #root::dfir_pipes::pull::PullStep<Self::Item, Self::Meta, Self::CanPend, Self::CanEnd> {
1458                                                        #root::dfir_pipes::pull::Pull::pull(self.project().inner, ctx)
1459                                                    }
1460
1461                                                    #[inline(always)]
1462                                                    fn size_hint(&self) -> (usize, Option<usize>) {
1463                                                        #root::dfir_pipes::pull::Pull::size_hint(&self.inner)
1464                                                    }
1465                                                }
1466
1467                                                Pull {
1468                                                    inner: input
1469                                                }
1470                                            }
1471                                            #work_fn::<_, _>( #ident )
1472                                        };
1473                                    }
1474                                } else {
1475                                    quote_spanned! {op_span=>
1476                                        let #ident = {
1477                                            #[allow(non_snake_case)]
1478                                            #[inline(always)]
1479                                            pub fn #work_fn<Item, Psh>(psh: Psh) -> impl #root::dfir_pipes::push::Push<Item, (), CanPend = Psh::CanPend>
1480                                            where
1481                                                Psh: #root::dfir_pipes::push::Push<Item, ()>
1482                                            {
1483                                                #root::pin_project_lite::pin_project! {
1484                                                    #[repr(transparent)]
1485                                                    struct PushGuard<Psh> {
1486                                                        #[pin]
1487                                                        inner: Psh,
1488                                                    }
1489                                                }
1490
1491                                                impl<Item, Psh> #root::dfir_pipes::push::Push<Item, ()> for PushGuard<Psh>
1492                                                where
1493                                                    Psh: #root::dfir_pipes::push::Push<Item, ()>,
1494                                                {
1495                                                    type Ctx<'ctx> = Psh::Ctx<'ctx>;
1496
1497                                                    type CanPend = Psh::CanPend;
1498
1499                                                    #[inline(always)]
1500                                                    fn poll_ready(
1501                                                        self: ::std::pin::Pin<&mut Self>,
1502                                                        ctx: &mut Self::Ctx<'_>,
1503                                                    ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
1504                                                        #root::dfir_pipes::push::Push::poll_ready(self.project().inner, ctx)
1505                                                    }
1506
1507                                                    #[inline(always)]
1508                                                    fn start_send(
1509                                                        self: ::std::pin::Pin<&mut Self>,
1510                                                        item: Item,
1511                                                        meta: (),
1512                                                    ) {
1513                                                        #root::dfir_pipes::push::Push::start_send(self.project().inner, item, meta)
1514                                                    }
1515
1516                                                    #[inline(always)]
1517                                                    fn poll_finalize(
1518                                                        self: ::std::pin::Pin<&mut Self>,
1519                                                        ctx: &mut Self::Ctx<'_>,
1520                                                    ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
1521                                                        #root::dfir_pipes::push::Push::poll_finalize(self.project().inner, ctx)
1522                                                    }
1523
1524                                                    #[inline(always)]
1525                                                    fn size_hint(
1526                                                        self: ::std::pin::Pin<&mut Self>,
1527                                                        hint: (usize, Option<usize>),
1528                                                    ) {
1529                                                        #root::dfir_pipes::push::Push::size_hint(self.project().inner, hint)
1530                                                    }
1531                                                }
1532
1533                                                PushGuard {
1534                                                    inner: psh
1535                                                }
1536                                            }
1537                                            #work_fn( #ident )
1538                                        };
1539                                    }
1540                                };
1541                                subgraph_op_iter_code.push(type_guard);
1542                            }
1543                            subgraph_op_iter_after_code.push(write_iterator_after);
1544                        }
1545                    }
1546
1547                    {
1548                        // Determine pull and push halves of the `Pivot`.
1549                        let pull_ident = if 0 < pull_to_push_idx {
1550                            self.node_as_ident(subgraph_nodes[pull_to_push_idx - 1], false)
1551                        } else {
1552                            // Entire subgraph is push (with a single recv/pull handoff input).
1553                            recv_port_idents[0].clone()
1554                        };
1555
1556                        #[rustfmt::skip]
1557                        let push_ident = if let Some(&node_id) =
1558                            subgraph_nodes.get(pull_to_push_idx)
1559                        {
1560                            self.node_as_ident(node_id, false)
1561                        } else if 1 == send_port_idents.len() {
1562                            // Entire subgraph is pull (with a single send/push handoff output).
1563                            send_port_idents[0].clone()
1564                        } else {
1565                            diagnostics.push(Diagnostic::spanned(
1566                                pull_ident.span(),
1567                                Level::Error,
1568                                "Degenerate subgraph detected, is there a disconnected `null()` or other degenerate pipeline somewhere?",
1569                            ));
1570                            continue;
1571                        };
1572
1573                        // Pivot span is combination of pull and push spans (or if not possible, just take the push).
1574                        let pivot_span = pull_ident
1575                            .span()
1576                            .join(push_ident.span())
1577                            .unwrap_or_else(|| push_ident.span());
1578                        let pivot_fn_ident = Ident::new(
1579                            &format!("pivot_run_sg_{:?}", subgraph_id.data()),
1580                            pivot_span,
1581                        );
1582                        let root = change_spans(root.clone(), pivot_span);
1583                        subgraph_op_iter_code.push(quote_spanned! {pivot_span=>
1584                            #[inline(always)]
1585                            fn #pivot_fn_ident<Pul, Psh, Item>(pull: Pul, push: Psh)
1586                                -> impl ::std::future::Future<Output = ()>
1587                            where
1588                                Pul: #root::dfir_pipes::pull::Pull<Item = Item>,
1589                                Psh: #root::dfir_pipes::push::Push<Item, Pul::Meta>,
1590                            {
1591                                #root::dfir_pipes::pull::Pull::send_push(pull, push)
1592                            }
1593                            (#pivot_fn_ident)(#pull_ident, #push_ident).await;
1594                        });
1595                    }
1596                };
1597
1598                // Each subgraph block is an async block so it can be individually instrumented.
1599                // Note: this ident is for the subgraph future, not a runtime SubgraphId binding
1600                // (unlike the scheduled path's `sg_ident`).
1601                let sg_fut_ident = subgraph_id.as_ident(Span::call_site());
1602
1603                // Generate send-side curr_items_count updates (after subgraph runs).
1604                let send_metrics_code: Vec<TokenStream> = send_hoffs
1605                    .iter()
1606                    .zip(send_buf_idents.iter())
1607                    .map(|(&hoff_id, buf_ident)| {
1608                        let hoff_ffi = hoff_id.data().as_ffi();
1609                        let GraphNode::Handoff { kind, .. } = self.node(hoff_id) else {
1610                            unreachable!()
1611                        };
1612                        let len_expr = match kind {
1613                            HandoffKind::Option => {
1614                                quote! { if #buf_ident.is_some() { 1 } else { 0 } }
1615                            }
1616                            HandoffKind::Vec => {
1617                                quote! { #buf_ident.len() }
1618                            }
1619                        };
1620                        quote! {
1621                            __dfir_metrics.handoffs[
1622                                #root::slotmap::KeyData::from_ffi(#hoff_ffi).into()
1623                            ].curr_items_count.set(#len_expr);
1624                        }
1625                    })
1626                    .collect();
1627
1628                subgraph_blocks.push(quote! {
1629                    let #sg_fut_ident = async {
1630                        let #context = &#df;
1631                        #( #recv_port_code )*
1632                        #( #send_port_code )*
1633                        #( #subgraph_op_iter_code )*
1634                        #( #subgraph_op_iter_after_code )*
1635                    };
1636                    {
1637                        let sg_metrics = &__dfir_metrics.subgraphs[
1638                            #root::slotmap::KeyData::from_ffi(#sg_metrics_ffi).into()
1639                        ];
1640                        #root::scheduled::metrics::InstrumentSubgraph::new(
1641                            #sg_fut_ident, sg_metrics
1642                        ).await;
1643                        sg_metrics.total_run_count.update(|x| x + 1);
1644                    }
1645                    #( #send_metrics_code )*
1646                });
1647
1648                // Collect per-subgraph prologues into the main prologue lists.
1649                // (They are already pushed above in the operator loop.)
1650            }
1651        }
1652
1653        if diagnostics.has_error() {
1654            return Err(std::mem::take(diagnostics));
1655        }
1656        let _ = diagnostics; // Ensure no more diagnostics may be added after checking for errors.
1657
1658        let (meta_graph_arg, diagnostics_arg) = if include_meta {
1659            let meta_graph_json = serde_json::to_string(&self).unwrap();
1660            let meta_graph_json = Literal::string(&meta_graph_json);
1661
1662            let serde_diagnostics: Vec<_> = diagnostics.iter().map(Diagnostic::to_serde).collect();
1663            let diagnostics_json = serde_json::to_string(&*serde_diagnostics).unwrap();
1664            let diagnostics_json = Literal::string(&diagnostics_json);
1665
1666            (
1667                quote! { Some(#meta_graph_json) },
1668                quote! { Some(#diagnostics_json) },
1669            )
1670        } else {
1671            (quote! { None }, quote! { None })
1672        };
1673
1674        // Generate metrics initialization: one entry per handoff and per subgraph.
1675        let metrics_init_code = {
1676            let handoff_inits = handoff_nodes.iter().map(|&(node_id, _, _)| {
1677                let ffi = node_id.data().as_ffi();
1678                quote! {
1679                    dfir_metrics.handoffs.insert(
1680                        #root::slotmap::KeyData::from_ffi(#ffi).into(),
1681                        ::std::default::Default::default(),
1682                    );
1683                }
1684            });
1685            let subgraph_inits = all_subgraphs.iter().map(|&(sg_id, _)| {
1686                let ffi = sg_id.data().as_ffi();
1687                quote! {
1688                    dfir_metrics.subgraphs.insert(
1689                        #root::slotmap::KeyData::from_ffi(#ffi).into(),
1690                        ::std::default::Default::default(),
1691                    );
1692                }
1693            });
1694            handoff_inits.chain(subgraph_inits).collect::<Vec<_>>()
1695        };
1696
1697        // Prologues and buffer declarations persist across ticks (outside the closure).
1698        // Subgraph blocks run each tick (inside the closure).
1699        Ok(quote! {
1700            {
1701                #prefix
1702
1703                use #root::{var_expr, var_args};
1704
1705                let __dfir_wake_state = ::std::sync::Arc::new(
1706                    #root::scheduled::context::WakeState::default()
1707                );
1708
1709                let __dfir_metrics = {
1710                    let mut dfir_metrics = #root::scheduled::metrics::DfirMetrics::default();
1711                    #( #metrics_init_code )*
1712                    ::std::rc::Rc::new(dfir_metrics)
1713                };
1714
1715                #[allow(unused_mut)]
1716                let mut #df = #root::scheduled::context::Context::new(
1717                    ::std::clone::Clone::clone(&__dfir_wake_state),
1718                    __dfir_metrics,
1719                );
1720
1721                #( #buffer_code )*
1722                #( #back_buffer_code )*
1723                #( #op_prologue_code )*
1724
1725                // Pre-set to true so the first tick always returns true
1726                // (matching Dfir pre-scheduling behavior). Subsequent ticks
1727                // start false (from take()) and are set true by recv port code
1728                // if any handoff buffer has data.
1729                let mut __dfir_work_done = true;
1730                #[allow(unused_qualifications, unused_mut, unused_variables, clippy::await_holding_refcell_ref, clippy::deref_addrof)]
1731                let __dfir_inline_tick = async move |#df: &mut #root::scheduled::context::Context| {
1732                    let __dfir_metrics = #df.metrics();
1733                    // Double-buffer swap for defer_tick handoffs: move last tick's
1734                    // producer output into the back buffer for the consumer to drain.
1735                    #( #back_edge_swap_code )*
1736                    #( #subgraph_blocks )*
1737
1738                    // For non-lazy defer_tick: if any deferred buffer has data,
1739                    // signal that another tick should run.
1740                    if false #( || !#defer_tick_buf_idents.is_empty() )* {
1741                        #df.schedule_subgraph(true);
1742                    }
1743
1744                    // End-of-tick state reset (e.g. 'tick persistence).
1745                    #( #op_tick_end_code )*
1746
1747                    // Drain handoff buffers that have no pipe consumer (e.g. singleton
1748                    // used only via #var reference). Without this, the value would
1749                    // persist across ticks and cause panics on the next write.
1750                    #( #no_consumer_drain_code )*
1751
1752                    #df.__end_tick();
1753                    ::std::mem::take(&mut __dfir_work_done)
1754                };
1755                #root::scheduled::context::Dfir::new(
1756                    __dfir_inline_tick,
1757                    #df,
1758                    #meta_graph_arg,
1759                    #diagnostics_arg,
1760                )
1761            }
1762        })
1763    }
1764
1765    /// Color mode (pull vs. push, handoff vs. comp) for nodes. Some nodes can be push *OR* pull;
1766    /// those nodes will not be set in the returned map.
1767    pub fn node_color_map(&self) -> SparseSecondaryMap<GraphNodeId, Color> {
1768        let mut node_color_map: SparseSecondaryMap<GraphNodeId, Color> = self
1769            .node_ids()
1770            .filter_map(|node_id| {
1771                let op_color = self.node_color(node_id)?;
1772                Some((node_id, op_color))
1773            })
1774            .collect();
1775
1776        // Fill in rest via subgraphs.
1777        for sg_nodes in self.subgraph_nodes.values() {
1778            let pull_to_push_idx = self.find_pull_to_push_idx(sg_nodes);
1779
1780            for (idx, node_id) in sg_nodes.iter().copied().enumerate() {
1781                let is_pull = idx < pull_to_push_idx;
1782                node_color_map.insert(node_id, if is_pull { Color::Pull } else { Color::Push });
1783            }
1784        }
1785
1786        node_color_map
1787    }
1788
1789    /// Writes this graph as mermaid into a string.
1790    pub fn to_mermaid(&self, write_config: &WriteConfig) -> String {
1791        let mut output = String::new();
1792        self.write_mermaid(&mut output, write_config).unwrap();
1793        output
1794    }
1795
1796    /// Writes this graph as mermaid into the given `Write`.
1797    pub fn write_mermaid(
1798        &self,
1799        output: impl std::fmt::Write,
1800        write_config: &WriteConfig,
1801    ) -> std::fmt::Result {
1802        let mut graph_write = Mermaid::new(output);
1803        self.write_graph(&mut graph_write, write_config)
1804    }
1805
1806    /// Writes this graph as DOT (graphviz) into a string.
1807    pub fn to_dot(&self, write_config: &WriteConfig) -> String {
1808        let mut output = String::new();
1809        let mut graph_write = Dot::new(&mut output);
1810        self.write_graph(&mut graph_write, write_config).unwrap();
1811        output
1812    }
1813
1814    /// Writes this graph as DOT (graphviz) into the given `Write`.
1815    pub fn write_dot(
1816        &self,
1817        output: impl std::fmt::Write,
1818        write_config: &WriteConfig,
1819    ) -> std::fmt::Result {
1820        let mut graph_write = Dot::new(output);
1821        self.write_graph(&mut graph_write, write_config)
1822    }
1823
    /// Write out this graph using the given `GraphWrite`, e.g. `Mermaid` or `Dot`.
1825    pub(crate) fn write_graph<W>(
1826        &self,
1827        mut graph_write: W,
1828        write_config: &WriteConfig,
1829    ) -> Result<(), W::Err>
1830    where
1831        W: GraphWrite,
1832    {
1833        fn helper_edge_label(
1834            src_port: &PortIndexValue,
1835            dst_port: &PortIndexValue,
1836        ) -> Option<String> {
1837            let src_label = match src_port {
1838                PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1839                PortIndexValue::Int(index) => Some(index.value.to_string()),
1840                _ => None,
1841            };
1842            let dst_label = match dst_port {
1843                PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1844                PortIndexValue::Int(index) => Some(index.value.to_string()),
1845                _ => None,
1846            };
1847            let label = match (src_label, dst_label) {
1848                (Some(l1), Some(l2)) => Some(format!("{}\n{}", l1, l2)),
1849                (Some(l1), None) => Some(l1),
1850                (None, Some(l2)) => Some(l2),
1851                (None, None) => None,
1852            };
1853            label
1854        }
1855
1856        // Make node color map one time.
1857        let node_color_map = self.node_color_map();
1858
1859        // Write prologue.
1860        graph_write.write_prologue()?;
1861
1862        // Define nodes.
1863        let mut skipped_handoffs = BTreeSet::new();
1864        let mut subgraph_handoffs = <BTreeMap<GraphSubgraphId, Vec<GraphNodeId>>>::new();
1865        for (node_id, node) in self.nodes() {
1866            if matches!(node, GraphNode::Handoff { .. }) {
1867                if write_config.no_handoffs {
1868                    skipped_handoffs.insert(node_id);
1869                    continue;
1870                } else {
1871                    let pred_node = self.node_predecessor_nodes(node_id).next().unwrap();
1872                    let pred_sg = self.node_subgraph(pred_node);
1873                    let succ_node = self.node_successor_nodes(node_id).next();
1874                    let succ_sg = succ_node.and_then(|n| self.node_subgraph(n));
1875                    if let Some((pred_sg, succ_sg)) = pred_sg.zip(succ_sg)
1876                        && pred_sg == succ_sg
1877                    {
1878                        subgraph_handoffs.entry(pred_sg).or_default().push(node_id);
1879                    }
1880                }
1881            }
1882            graph_write.write_node_definition(
1883                node_id,
1884                &if write_config.op_short_text {
1885                    node.to_name_string()
1886                } else if write_config.op_text_no_imports {
1887                    // Remove any lines that start with "use" (imports)
1888                    let full_text = node.to_pretty_string();
1889                    let mut output = String::new();
1890                    for sentence in full_text.split('\n') {
1891                        if sentence.trim().starts_with("use") {
1892                            continue;
1893                        }
1894                        output.push('\n');
1895                        output.push_str(sentence);
1896                    }
1897                    output.into()
1898                } else {
1899                    node.to_pretty_string()
1900                },
1901                if write_config.no_pull_push {
1902                    None
1903                } else {
1904                    node_color_map.get(node_id).copied()
1905                },
1906            )?;
1907        }
1908
1909        // Write edges.
1910        for (edge_id, (src_id, mut dst_id)) in self.edges() {
1911            // Handling for if `write_config.no_handoffs` true.
1912            if skipped_handoffs.contains(&src_id) {
1913                continue;
1914            }
1915
1916            let (src_port, mut dst_port) = self.edge_ports(edge_id);
1917            if skipped_handoffs.contains(&dst_id) {
1918                let mut handoff_succs = self.node_successors(dst_id);
1919                assert_eq!(1, handoff_succs.len());
1920                let (succ_edge, succ_node) = handoff_succs.next().unwrap();
1921                dst_id = succ_node;
1922                dst_port = self.edge_ports(succ_edge).1;
1923            }
1924
1925            let label = helper_edge_label(src_port, dst_port);
1926            let delay_type = self
1927                .node_op_inst(dst_id)
1928                .and_then(|op_inst| (op_inst.op_constraints.input_delaytype_fn)(dst_port));
1929            graph_write.write_edge(src_id, dst_id, delay_type, label.as_deref(), false)?;
1930        }
1931
1932        // Write reference edges.
1933        if !write_config.no_references {
1934            for dst_id in self.node_ids() {
1935                for src_ref_id in self
1936                    .node_singleton_references(dst_id)
1937                    .iter()
1938                    .copied()
1939                    .flatten()
1940                {
1941                    let delay_type = Some(DelayType::Stratum);
1942                    let label = None;
1943                    graph_write.write_edge(src_ref_id, dst_id, delay_type, label, true)?;
1944                }
1945            }
1946        }
1947
1948        // The following code is a little bit tricky. Generally, the graph has the hierarchy:
1949        // `loop -> subgraph -> varname -> node`. However, each of these can be disabled via the `write_config`. To
1950        // handle both the enabled and disabled case, this code is structured as a series of nested loops. If the layer
1951        // is disabled, then the HashMap<Option<KEY>, Vec<VALUE>> will only have a single key (`None`) with a
1952        // corresponding `Vec` value containing everything. This way no special handling is needed for the next layer.
1953
1954        // Loop -> Subgraphs
1955        let loop_subgraphs = self.subgraph_ids().map(|sg_id| {
1956            let loop_id = if write_config.no_loops {
1957                None
1958            } else {
1959                self.subgraph_loop(sg_id)
1960            };
1961            (loop_id, sg_id)
1962        });
1963        let loop_subgraphs = into_group_map(loop_subgraphs);
1964        for (loop_id, subgraph_ids) in loop_subgraphs {
1965            if let Some(loop_id) = loop_id {
1966                graph_write.write_loop_start(loop_id)?;
1967            }
1968
1969            // Subgraph -> Varnames.
1970            let subgraph_varnames_nodes = subgraph_ids.into_iter().flat_map(|sg_id| {
1971                self.subgraph(sg_id).iter().copied().map(move |node_id| {
1972                    let opt_sg_id = if write_config.no_subgraphs {
1973                        None
1974                    } else {
1975                        Some(sg_id)
1976                    };
1977                    (opt_sg_id, (self.node_varname(node_id), node_id))
1978                })
1979            });
1980            let subgraph_varnames_nodes = into_group_map(subgraph_varnames_nodes);
1981            for (sg_id, varnames) in subgraph_varnames_nodes {
1982                if let Some(sg_id) = sg_id {
1983                    graph_write.write_subgraph_start(sg_id)?;
1984                }
1985
1986                // Varnames -> Nodes.
1987                let varname_nodes = varnames.into_iter().map(|(varname, node)| {
1988                    let varname = if write_config.no_varnames {
1989                        None
1990                    } else {
1991                        varname
1992                    };
1993                    (varname, node)
1994                });
1995                let varname_nodes = into_group_map(varname_nodes);
1996                for (varname, node_ids) in varname_nodes {
1997                    if let Some(varname) = varname {
1998                        graph_write.write_varname_start(&varname.0.to_string(), sg_id)?;
1999                    }
2000
2001                    // Write all nodes.
2002                    for node_id in node_ids {
2003                        graph_write.write_node(node_id)?;
2004                    }
2005
2006                    if varname.is_some() {
2007                        graph_write.write_varname_end()?;
2008                    }
2009                }
2010
2011                if sg_id.is_some() {
2012                    graph_write.write_subgraph_end()?;
2013                }
2014            }
2015
2016            if loop_id.is_some() {
2017                graph_write.write_loop_end()?;
2018            }
2019        }
2020
2021        // Write epilogue.
2022        graph_write.write_epilogue()?;
2023
2024        Ok(())
2025    }
2026
2027    /// Convert back into surface syntax.
2028    pub fn surface_syntax_string(&self) -> String {
2029        let mut string = String::new();
2030        self.write_surface_syntax(&mut string).unwrap();
2031        string
2032    }
2033
2034    /// Convert back into surface syntax.
2035    pub fn write_surface_syntax(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
2036        for (key, node) in self.nodes.iter() {
2037            match node {
2038                GraphNode::Operator(op) => {
2039                    writeln!(write, "{:?} = {};", key.data(), op.to_token_stream())?;
2040                }
2041                GraphNode::Handoff {
2042                    kind: HandoffKind::Vec,
2043                    ..
2044                } => {
2045                    writeln!(write, "{:?} = handoff();", key.data())?;
2046                }
2047                GraphNode::Handoff {
2048                    kind: HandoffKind::Option,
2049                    ..
2050                } => {
2051                    writeln!(write, "{:?} = singleton();", key.data())?;
2052                }
2053                GraphNode::ModuleBoundary { .. } => panic!(),
2054            }
2055        }
2056        writeln!(write)?;
2057        for (_e, (src_key, dst_key)) in self.graph.edges() {
2058            writeln!(write, "{:?} -> {:?};", src_key.data(), dst_key.data())?;
2059        }
2060        Ok(())
2061    }
2062
2063    /// Convert into a [mermaid](https://mermaid-js.github.io/) graph. Ignores subgraphs.
2064    pub fn mermaid_string_flat(&self) -> String {
2065        let mut string = String::new();
2066        self.write_mermaid_flat(&mut string).unwrap();
2067        string
2068    }
2069
2070    /// Convert into a [mermaid](https://mermaid-js.github.io/) graph. Ignores subgraphs.
2071    pub fn write_mermaid_flat(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
2072        writeln!(write, "flowchart TB")?;
2073        for (key, node) in self.nodes.iter() {
2074            match node {
2075                GraphNode::Operator(operator) => writeln!(
2076                    write,
2077                    "    %% {span}\n    {id:?}[\"{row_col} <tt>{code}</tt>\"]",
2078                    span = PrettySpan(node.span()),
2079                    id = key.data(),
2080                    row_col = PrettyRowCol(node.span()),
2081                    code = operator
2082                        .to_token_stream()
2083                        .to_string()
2084                        .replace('&', "&amp;")
2085                        .replace('<', "&lt;")
2086                        .replace('>', "&gt;")
2087                        .replace('"', "&quot;")
2088                        .replace('\n', "<br>"),
2089                ),
2090                GraphNode::Handoff {
2091                    kind: HandoffKind::Vec,
2092                    ..
2093                } => {
2094                    writeln!(write, r#"    {:?}{{"{}"}}"#, key.data(), HANDOFF_NODE_STR)
2095                }
2096                GraphNode::Handoff {
2097                    kind: HandoffKind::Option,
2098                    ..
2099                } => {
2100                    writeln!(
2101                        write,
2102                        r#"    {:?}{{"{}"}}"#,
2103                        key.data(),
2104                        SINGLETON_SLOT_NODE_STR
2105                    )
2106                }
2107                GraphNode::ModuleBoundary { .. } => {
2108                    writeln!(
2109                        write,
2110                        r#"    {:?}{{"{}"}}"#,
2111                        key.data(),
2112                        MODULE_BOUNDARY_NODE_STR
2113                    )
2114                }
2115            }?;
2116        }
2117        writeln!(write)?;
2118        for (_e, (src_key, dst_key)) in self.graph.edges() {
2119            writeln!(write, "    {:?}-->{:?}", src_key.data(), dst_key.data())?;
2120        }
2121        Ok(())
2122    }
2123}
2124
2125/// Loops
2126impl DfirGraph {
2127    /// Iterator over all loop IDs.
2128    pub fn loop_ids(&self) -> slotmap::basic::Keys<'_, GraphLoopId, Vec<GraphNodeId>> {
2129        self.loop_nodes.keys()
2130    }
2131
2132    /// Iterator over all loops, ID and members: `(GraphLoopId, Vec<GraphNodeId>)`.
2133    pub fn loops(&self) -> slotmap::basic::Iter<'_, GraphLoopId, Vec<GraphNodeId>> {
2134        self.loop_nodes.iter()
2135    }
2136
2137    /// Create a new loop context, with the given parent loop (or `None`).
2138    pub fn insert_loop(&mut self, parent_loop: Option<GraphLoopId>) -> GraphLoopId {
2139        let loop_id = self.loop_nodes.insert(Vec::new());
2140        self.loop_children.insert(loop_id, Vec::new());
2141        if let Some(parent_loop) = parent_loop {
2142            self.loop_parent.insert(loop_id, parent_loop);
2143            self.loop_children
2144                .get_mut(parent_loop)
2145                .unwrap()
2146                .push(loop_id);
2147        } else {
2148            self.root_loops.push(loop_id);
2149        }
2150        loop_id
2151    }
2152
2153    /// Get a node's loop context (or `None` for root).
2154    pub fn node_loop(&self, node_id: GraphNodeId) -> Option<GraphLoopId> {
2155        self.node_loops.get(node_id).copied()
2156    }
2157
2158    /// Get a subgraph's loop context (or `None` for root).
2159    pub fn subgraph_loop(&self, subgraph_id: GraphSubgraphId) -> Option<GraphLoopId> {
2160        let &node_id = self.subgraph(subgraph_id).first().unwrap();
2161        let out = self.node_loop(node_id);
2162        debug_assert!(
2163            self.subgraph(subgraph_id)
2164                .iter()
2165                .all(|&node_id| self.node_loop(node_id) == out),
2166            "Subgraph nodes should all have the same loop context."
2167        );
2168        out
2169    }
2170
2171    /// Get a loop context's parent loop context (or `None` for root).
2172    pub fn loop_parent(&self, loop_id: GraphLoopId) -> Option<GraphLoopId> {
2173        self.loop_parent.get(loop_id).copied()
2174    }
2175
2176    /// Get a loop context's child loops.
2177    pub fn loop_children(&self, loop_id: GraphLoopId) -> &Vec<GraphLoopId> {
2178        self.loop_children.get(loop_id).unwrap()
2179    }
2180}
2181
/// Configuration for writing graphs.
// NOTE: with the `clap-derive` feature enabled this struct derives `clap::Args`, and clap uses each
// field's `///` doc comment as its `--help` text. Edit those doc comments as user-facing strings;
// the `//` notes below are for maintainers only.
//
// In the graph writer, disabling a grouping layer (loops, subgraphs, varnames) collapses that
// layer into a single `None` group. No start/end markers are then written for it.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "clap-derive", derive(clap::Args))]
pub struct WriteConfig {
    /// Subgraphs will not be rendered if set.
    // When set, all nodes are grouped under `None` instead of their subgraph ID, so
    // `write_subgraph_start`/`write_subgraph_end` are never called.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_subgraphs: bool,
    /// Variable names will not be rendered if set.
    // When set, nodes are not grouped by varname, so `write_varname_start`/`write_varname_end`
    // are never called.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_varnames: bool,
    /// Will not render pull/push shapes if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_pull_push: bool,
    /// Will not render handoffs if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_handoffs: bool,
    /// Will not render singleton references if set.
    // When unset, each singleton reference is written as an extra edge from the referenced node
    // to the referencing node, with `DelayType::Stratum`.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_references: bool,
    /// Will not render loops if set.
    // When set, subgraphs are not grouped by loop, so `write_loop_start`/`write_loop_end` are
    // never called.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_loops: bool,

    /// Op text will only be their name instead of the whole source.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_short_text: bool,
    /// Op text will exclude any line that starts with "use".
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_text_no_imports: bool,
}
2212
/// Enum for choosing between mermaid and dot graph writing.
// NOTE: with the `clap-derive` feature, `clap::ValueEnum` derives the accepted CLI values from the
// variant names and uses the variant `///` docs as help text. Renaming a variant or editing
// those doc comments therefore changes user-visible CLI behavior.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "clap-derive", derive(clap::Parser, clap::ValueEnum))]
pub enum WriteGraphType {
    /// Mermaid graphs.
    // Presumably selects the `Mermaid` writer from `graph_write`. TODO confirm at the call site.
    Mermaid,
    /// Dot (Graphviz) graphs.
    // Presumably selects the `Dot` writer from `graph_write`. TODO confirm at the call site.
    Dot,
}
2222
/// Groups `(key, value)` pairs by key, like [`itertools::Itertools::into_group_map`] but
/// collecting into a [`BTreeMap`].
///
/// Groups are therefore visited in sorted key order. Within each group, values keep the order in
/// which they were yielded by `iter`.
fn into_group_map<K, V>(iter: impl IntoIterator<Item = (K, V)>) -> BTreeMap<K, Vec<V>>
where
    K: Ord,
{
    iter.into_iter()
        .fold(BTreeMap::<K, Vec<V>>::new(), |mut groups, (key, value)| {
            groups.entry(key).or_insert_with(Vec::new).push(value);
            groups
        })
}