1use std::collections::{BTreeMap, BTreeSet};
4
5use proc_macro2::Span;
6use slotmap::{SecondaryMap, SparseSecondaryMap};
7
8use super::meta_graph::DfirGraph;
9use super::ops::{DelayType, FloType};
10use super::{
11 Color, GraphEdgeId, GraphNode, GraphNodeId, GraphSubgraphId, HandoffKind, graph_algorithms,
12};
13use crate::diagnostic::{Diagnostic, Level};
14use crate::union_find::UnionFind;
15
16struct BarrierCrossers {
18 pub edge_barrier_crossers: SecondaryMap<GraphEdgeId, DelayType>,
20 pub singleton_barrier_crossers: Vec<(GraphNodeId, GraphNodeId)>,
22}
23impl BarrierCrossers {
24 fn iter_node_pairs<'a>(
26 &'a self,
27 partitioned_graph: &'a DfirGraph,
28 ) -> impl 'a + Iterator<Item = ((GraphNodeId, GraphNodeId), DelayType)> {
29 let edge_pairs_iter = self
30 .edge_barrier_crossers
31 .iter()
32 .map(|(edge_id, &delay_type)| {
33 let src_dst = partitioned_graph.edge(edge_id);
34 (src_dst, delay_type)
35 });
36 let singleton_pairs_iter = self
37 .singleton_barrier_crossers
38 .iter()
39 .map(|&src_dst| (src_dst, DelayType::Stratum));
40 edge_pairs_iter.chain(singleton_pairs_iter)
41 }
42
43 fn replace_edge(&mut self, old_edge_id: GraphEdgeId, new_edge_id: GraphEdgeId) {
45 if let Some(delay_type) = self.edge_barrier_crossers.remove(old_edge_id) {
46 self.edge_barrier_crossers.insert(new_edge_id, delay_type);
47 }
48 }
49}
50
51fn find_barrier_crossers(partitioned_graph: &DfirGraph) -> BarrierCrossers {
53 let edge_barrier_crossers = partitioned_graph
54 .edges()
55 .filter(|&(_, (_src, dst))| {
56 partitioned_graph.node_loop(dst).is_none()
58 })
59 .filter_map(|(edge_id, (_src, dst))| {
60 let (_src_port, dst_port) = partitioned_graph.edge_ports(edge_id);
61 let op_constraints = partitioned_graph.node_op_inst(dst)?.op_constraints;
62 let input_barrier = (op_constraints.input_delaytype_fn)(dst_port)?;
63 Some((edge_id, input_barrier))
64 })
65 .collect();
66 let singleton_barrier_crossers = partitioned_graph
67 .node_ids()
68 .flat_map(|dst| {
69 partitioned_graph
70 .node_singleton_references(dst)
71 .iter()
72 .flatten()
73 .map(move |&src_ref| (src_ref, dst))
74 })
75 .collect();
76 BarrierCrossers {
77 edge_barrier_crossers,
78 singleton_barrier_crossers,
79 }
80}
81
82fn find_subgraph_unionfind(
83 partitioned_graph: &DfirGraph,
84 barrier_crossers: &BarrierCrossers,
85) -> (UnionFind<GraphNodeId>, BTreeSet<GraphEdgeId>) {
86 let mut node_color = partitioned_graph
91 .node_ids()
92 .filter_map(|node_id| {
93 let op_color = partitioned_graph.node_color(node_id)?;
94 Some((node_id, op_color))
95 })
96 .collect::<SparseSecondaryMap<_, _>>();
97
98 let mut subgraph_unionfind: UnionFind<GraphNodeId> =
99 UnionFind::with_capacity(partitioned_graph.nodes().len());
100
101 let mut handoff_edges: BTreeSet<GraphEdgeId> = partitioned_graph.edge_ids().collect();
104 let mut progress = true;
113 while progress {
114 progress = false;
115 for (edge_id, (src, dst)) in partitioned_graph.edges().collect::<Vec<_>>() {
117 if subgraph_unionfind.same_set(src, dst) {
119 continue;
122 }
123
124 if barrier_crossers
126 .iter_node_pairs(partitioned_graph)
127 .any(|((x_src, x_dst), _)| {
128 (subgraph_unionfind.same_set(x_src, src)
129 && subgraph_unionfind.same_set(x_dst, dst))
130 || (subgraph_unionfind.same_set(x_src, dst)
131 && subgraph_unionfind.same_set(x_dst, src))
132 })
133 {
134 continue;
135 }
136
137 if partitioned_graph.node_loop(src) != partitioned_graph.node_loop(dst) {
139 continue;
140 }
141 if partitioned_graph.node_op_inst(dst).is_some_and(|op_inst| {
143 Some(FloType::NextIteration) == op_inst.op_constraints.flo_type
144 }) {
145 continue;
146 }
147
148 if can_connect_colorize(&mut node_color, src, dst) {
149 subgraph_unionfind.union(src, dst);
152 assert!(handoff_edges.remove(&edge_id));
153 progress = true;
154 }
155 }
156 }
157
158 (subgraph_unionfind, handoff_edges)
159}
160
161fn make_subgraph_collect(
165 partitioned_graph: &DfirGraph,
166 mut subgraph_unionfind: UnionFind<GraphNodeId>,
167) -> SecondaryMap<GraphNodeId, Vec<GraphNodeId>> {
168 let topo_sort = graph_algorithms::topo_sort(
172 partitioned_graph
173 .nodes()
174 .filter(|&(_, node)| !matches!(node, GraphNode::Handoff { .. }))
175 .map(|(node_id, _)| node_id),
176 |v| {
177 partitioned_graph
178 .node_predecessor_nodes(v)
179 .filter(|&pred_id| {
180 let pred = partitioned_graph.node(pred_id);
181 !matches!(pred, GraphNode::Handoff { .. })
182 })
183 },
184 )
185 .expect("Subgraphs are in-out trees.");
186
187 let mut grouped_nodes: SecondaryMap<GraphNodeId, Vec<GraphNodeId>> = Default::default();
188 for node_id in topo_sort {
189 let repr_node = subgraph_unionfind.find(node_id);
190 if !grouped_nodes.contains_key(repr_node) {
191 grouped_nodes.insert(repr_node, Default::default());
192 }
193 grouped_nodes[repr_node].push(node_id);
194 }
195 grouped_nodes
196}
197
198fn make_subgraphs(partitioned_graph: &mut DfirGraph, barrier_crossers: &mut BarrierCrossers) {
202 let (subgraph_unionfind, handoff_edges) =
211 find_subgraph_unionfind(partitioned_graph, barrier_crossers);
212
213 for edge_id in handoff_edges {
215 let (src_id, dst_id) = partitioned_graph.edge(edge_id);
216
217 let src_node = partitioned_graph.node(src_id);
219 let dst_node = partitioned_graph.node(dst_id);
220 if matches!(src_node, GraphNode::Handoff { .. })
221 || matches!(dst_node, GraphNode::Handoff { .. })
222 {
223 continue;
224 }
225
226 let hoff = GraphNode::Handoff {
227 kind: HandoffKind::Vec,
228 src_span: src_node.span(),
229 dst_span: dst_node.span(),
230 };
231 let (_node_id, out_edge_id) = partitioned_graph.insert_intermediate_node(edge_id, hoff);
232
233 barrier_crossers.replace_edge(edge_id, out_edge_id);
235 }
236
237 let grouped_nodes = make_subgraph_collect(partitioned_graph, subgraph_unionfind);
241 for (_repr_node, member_nodes) in grouped_nodes {
242 partitioned_graph.insert_subgraph(member_nodes).unwrap();
243 }
244}
245
246fn can_connect_colorize(
252 node_color: &mut SparseSecondaryMap<GraphNodeId, Color>,
253 src: GraphNodeId,
254 dst: GraphNodeId,
255) -> bool {
256 let can_connect = match (node_color.get(src), node_color.get(dst)) {
261 (None, None) => false,
264
265 (None, Some(Color::Pull | Color::Comp)) => {
267 node_color.insert(src, Color::Pull);
268 true
269 }
270 (None, Some(Color::Push | Color::Hoff)) => {
271 node_color.insert(src, Color::Push);
272 true
273 }
274
275 (Some(Color::Pull | Color::Hoff), None) => {
277 node_color.insert(dst, Color::Pull);
278 true
279 }
280 (Some(Color::Comp | Color::Push), None) => {
281 node_color.insert(dst, Color::Push);
282 true
283 }
284
285 (Some(Color::Pull), Some(Color::Pull)) => true,
287 (Some(Color::Pull), Some(Color::Comp)) => true,
288 (Some(Color::Pull), Some(Color::Push)) => true,
289
290 (Some(Color::Comp), Some(Color::Pull)) => false,
291 (Some(Color::Comp), Some(Color::Comp)) => false,
292 (Some(Color::Comp), Some(Color::Push)) => true,
293
294 (Some(Color::Push), Some(Color::Pull)) => false,
295 (Some(Color::Push), Some(Color::Comp)) => false,
296 (Some(Color::Push), Some(Color::Push)) => true,
297
298 (Some(Color::Hoff), Some(_)) => false,
300 (Some(_), Some(Color::Hoff)) => false,
301 };
302 can_connect
303}
304
305fn order_subgraphs(
311 partitioned_graph: &mut DfirGraph,
312 barrier_crossers: &BarrierCrossers,
313) -> Result<(), Diagnostic> {
314 let mut sg_preds: BTreeMap<GraphSubgraphId, Vec<GraphSubgraphId>> = Default::default();
316
317 let mut tick_edges: Vec<(GraphEdgeId, DelayType)> = Vec::new();
319
320 for (hoff_id, hoff) in partitioned_graph.nodes() {
322 if !matches!(hoff, GraphNode::Handoff { .. }) {
323 continue;
324 }
325
326 if partitioned_graph.node_degree_out(hoff_id) == 0 {
328 continue;
329 }
330 assert_eq!(1, partitioned_graph.node_degree_out(hoff_id));
331
332 let (succ_edge, succ) = partitioned_graph.node_successors(hoff_id).next().unwrap();
333
334 let succ_edge_delaytype = barrier_crossers
335 .edge_barrier_crossers
336 .get(succ_edge)
337 .copied();
338 if let Some(delay_type @ (DelayType::Tick | DelayType::TickLazy)) = succ_edge_delaytype {
340 tick_edges.push((succ_edge, delay_type));
341 continue;
342 }
343
344 assert_eq!(1, partitioned_graph.node_degree_in(hoff_id));
345 let (_edge_id, pred) = partitioned_graph.node_predecessors(hoff_id).next().unwrap();
346
347 let pred_sg = partitioned_graph
348 .node_subgraph(pred)
349 .expect("Handoff pred not in subgraph, may be a doubled/adjacent handoff");
350 let succ_sg = partitioned_graph
351 .node_subgraph(succ)
352 .expect("Handoff succ not in subgraph, may be a doubled/adjacent handoff");
353
354 sg_preds.entry(succ_sg).or_default().push(pred_sg);
355 }
356 for &(pred, succ) in barrier_crossers.singleton_barrier_crossers.iter() {
358 assert_ne!(pred, succ);
359 let pred_sg = if let Some(sg) = partitioned_graph.node_subgraph(pred) {
361 sg
362 } else {
363 let (_edge, pred_pred) = partitioned_graph
365 .node_predecessors(pred)
366 .next()
367 .expect("handoff must have a predecessor");
368 partitioned_graph.node_subgraph(pred_pred).unwrap()
369 };
370 let succ_sg = partitioned_graph.node_subgraph(succ).unwrap();
371 if pred_sg == succ_sg {
372 continue;
373 }
374 sg_preds.entry(succ_sg).or_default().push(pred_sg);
375
376 if matches!(partitioned_graph.node(pred), GraphNode::Handoff { .. }) {
379 assert!(
380 partitioned_graph.node_degree_out(pred) <= 1,
381 "handoff should have at most one successor"
382 );
383 if let Some((_edge, consumer)) = partitioned_graph.node_successors(pred).next() {
384 let consumer_sg = partitioned_graph.node_subgraph(consumer).unwrap();
385 if consumer_sg != succ_sg {
386 sg_preds.entry(consumer_sg).or_default().push(succ_sg);
387 }
388 }
389 }
390 }
391
392 if let Err(cycle) = graph_algorithms::topo_sort(partitioned_graph.subgraph_ids(), |v| {
394 sg_preds.get(&v).into_iter().flatten().copied()
395 }) {
396 let span = cycle
397 .first()
398 .and_then(|&sg_id| partitioned_graph.subgraph(sg_id).first().copied())
399 .map(|n| partitioned_graph.node(n).span())
400 .unwrap_or_else(Span::call_site);
401 return Err(Diagnostic::spanned(
402 span,
403 Level::Error,
404 "Cyclical dataflow within a tick is not supported. Use `defer_tick()` or `defer_tick_lazy()` to break the cycle across ticks.",
405 ));
406 }
407
408 for (edge_id, delay_type) in tick_edges {
413 let (hoff, _dst) = partitioned_graph.edge(edge_id);
414 assert!(matches!(
415 partitioned_graph.node(hoff),
416 GraphNode::Handoff {
417 kind: HandoffKind::Vec,
418 ..
419 }
420 ));
421 partitioned_graph.set_handoff_delay_type(hoff, delay_type);
422 }
423 Ok(())
424}
425
426pub fn partition_graph(flat_graph: DfirGraph) -> Result<DfirGraph, Diagnostic> {
430 let mut barrier_crossers = find_barrier_crossers(&flat_graph);
432 let mut partitioned_graph = flat_graph;
433
434 make_subgraphs(&mut partitioned_graph, &mut barrier_crossers);
436
437 order_subgraphs(&mut partitioned_graph, &barrier_crossers)?;
439
440 Ok(partitioned_graph)
441}