1use std::collections::HashMap;
4use std::fmt::{Debug, Display};
5use std::ops::{Bound, RangeBounds};
6use std::sync::OnceLock;
7
8use documented::DocumentedVariants;
9use proc_macro2::{Ident, Literal, Span, TokenStream};
10use quote::quote_spanned;
11use serde::{Deserialize, Serialize};
12use slotmap::Key;
13use syn::punctuated::Punctuated;
14use syn::{Expr, Token, parse_quote_spanned};
15
16use super::{
17 GraphLoopId, GraphNode, GraphNodeId, GraphSubgraphId, OpInstGenerics, OperatorInstance,
18 PortIndexValue,
19};
20use crate::diagnostic::{Diagnostic, Diagnostics, Level};
21use crate::parse::{Operator, PortIndex};
22
/// Kind of delay an operator may request for one of its inputs, as returned by
/// [`OperatorConstraints::input_delaytype_fn`].
///
/// `PartialOrd`/`Ord` are derived, so the variants compare in declaration order; reordering them
/// changes comparison results and the serialized form.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub enum DelayType {
    /// NOTE(review): presumably the input must wait for a stratum boundary — confirm.
    Stratum,
    /// NOTE(review): presumably a monotone-accumulation input — confirm exact semantics.
    MonotoneAccum,
    /// NOTE(review): presumably the input is delayed until the next tick — confirm.
    Tick,
    /// NOTE(review): presumably like `Tick`, but "lazy"; what that means is not visible here —
    /// confirm against the scheduler.
    TickLazy,
}
35
/// Describes the ports of an operator, as produced by [`OperatorConstraints::ports_inn`] and
/// [`OperatorConstraints::ports_out`].
pub enum PortListSpec {
    /// No fixed list of ports is given.
    /// NOTE(review): presumably any number of ports is accepted — confirm.
    Variadic,
    /// An explicit, comma-separated list of port indices.
    Fixed(Punctuated<PortIndex, Token![,]>),
}
43
/// Static description of one operator: its name, arity limits, flags, and code generator.
///
/// One value per operator is collected in `OPERATORS` and indexed by `name` in
/// `operator_lookup`.
pub struct OperatorConstraints {
    /// Operator name; the key under which `operator_lookup` stores this value.
    pub name: &'static str,
    /// Categories this operator belongs to.
    pub categories: &'static [OperatorCategory],

    /// "Hard" range for the number of inputs.
    /// NOTE(review): presumably violating a hard range is an error and a soft range only a
    /// warning — confirm at the place the ranges are checked.
    pub hard_range_inn: &'static dyn RangeTrait<usize>,
    /// "Soft" range for the number of inputs (see the note on `hard_range_inn`).
    pub soft_range_inn: &'static dyn RangeTrait<usize>,
    /// "Hard" range for the number of outputs (see the note on `hard_range_inn`).
    pub hard_range_out: &'static dyn RangeTrait<usize>,
    /// "Soft" range for the number of outputs (see the note on `hard_range_inn`).
    pub soft_range_out: &'static dyn RangeTrait<usize>,
    /// Number of arguments. NOTE(review): presumably the exact count of call arguments the
    /// operator expects — confirm.
    pub num_args: usize,
    /// Allowed number of persistence arguments.
    pub persistence_args: &'static dyn RangeTrait<usize>,
    /// Allowed number of type arguments.
    pub type_args: &'static dyn RangeTrait<usize>,
    /// Whether the operator is an external input.
    /// NOTE(review): scheduling consequences are not visible here — confirm.
    pub is_external_input: bool,
    /// Whether the operator has a singleton output.
    pub has_singleton_output: bool,
    /// Optional [`FloType`] classification of the operator.
    pub flo_type: Option<FloType>,

    /// Optional function returning the specification of the input ports.
    pub ports_inn: Option<fn() -> PortListSpec>,
    /// Optional function returning the specification of the output ports.
    pub ports_out: Option<fn() -> PortListSpec>,

    /// Returns the delay type required for the given input port, or `None` if there is none.
    pub input_delaytype_fn: fn(&PortIndexValue) -> Option<DelayType>,
    /// The operator's code generator; see [`WriteFn`].
    pub write_fn: WriteFn,
}
88
/// Signature of an operator's code generator: reads the [`WriteContextArgs`], may push onto the
/// [`Diagnostics`], and returns the generated [`OperatorWriteOutput`].
///
/// NOTE(review): the error type is `()`; presumably `Err(())` means the failure was already
/// reported through `Diagnostics` — confirm with the callers.
pub type WriteFn = fn(&WriteContextArgs<'_>, &mut Diagnostics) -> Result<OperatorWriteOutput, ()>;
91
92impl Debug for OperatorConstraints {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 f.debug_struct("OperatorConstraints")
95 .field("name", &self.name)
96 .field("hard_range_inn", &self.hard_range_inn)
97 .field("soft_range_inn", &self.soft_range_inn)
98 .field("hard_range_out", &self.hard_range_out)
99 .field("soft_range_out", &self.soft_range_out)
100 .field("num_args", &self.num_args)
101 .field("persistence_args", &self.persistence_args)
102 .field("type_args", &self.type_args)
103 .field("is_external_input", &self.is_external_input)
104 .field("ports_inn", &self.ports_inn)
105 .field("ports_out", &self.ports_out)
106 .finish()
110 }
111}
112
/// Code produced by an operator's [`WriteFn`], split into separate token streams.
///
/// All fields default to empty token streams; the write functions in this file fill only
/// `write_iterator`.
#[derive(Default)]
pub struct OperatorWriteOutput {
    /// NOTE(review): presumably one-time setup code emitted before the operator runs — confirm
    /// where the caller splices it.
    pub write_prologue: TokenStream,
    /// The operator's main code; the write functions in this file bind the operator's `ident`
    /// here with a `let` statement.
    pub write_iterator: TokenStream,
    /// NOTE(review): presumably code emitted after `write_iterator` — confirm exact placement.
    pub write_iterator_after: TokenStream,
    /// NOTE(review): presumably code run at the end of each tick — confirm.
    pub write_tick_end: TokenStream,
}
134
135pub const RANGE_ANY: &'static dyn RangeTrait<usize> = &(0..);
137pub const RANGE_0: &'static dyn RangeTrait<usize> = &(0..=0);
139pub const RANGE_1: &'static dyn RangeTrait<usize> = &(1..=1);
141
142pub fn identity_write_iterator_fn(
145 &WriteContextArgs {
146 root,
147 op_span,
148 ident,
149 inputs,
150 outputs,
151 is_pull,
152 op_inst:
153 OperatorInstance {
154 generics: OpInstGenerics { type_args, .. },
155 ..
156 },
157 ..
158 }: &WriteContextArgs,
159) -> TokenStream {
160 let generic_type = type_args
161 .first()
162 .map(quote::ToTokens::to_token_stream)
163 .unwrap_or(quote_spanned!(op_span=> _));
164
165 if is_pull {
166 let input = &inputs[0];
167 quote_spanned! {op_span=>
168 let #ident = {
169 fn check_input<Pull, Item>(pull: Pull) -> impl #root::dfir_pipes::pull::Pull<Item = Item, Meta = Pull::Meta, CanPend = Pull::CanPend, CanEnd = Pull::CanEnd>
170 where
171 Pull: #root::dfir_pipes::pull::Pull<Item = Item>,
172 {
173 pull
174 }
175 check_input::<_, #generic_type>(#input)
176 };
177 }
178 } else {
179 let output = &outputs[0];
180 quote_spanned! {op_span=>
181 let #ident = {
182 fn check_output<Psh, Item>(push: Psh) -> impl #root::dfir_pipes::push::Push<Item, (), CanPend = Psh::CanPend>
183 where
184 Psh: #root::dfir_pipes::push::Push<Item, ()>,
185 {
186 push
187 }
188 check_output::<_, #generic_type>(#output)
189 };
190 }
191 }
192}
193
194pub const IDENTITY_WRITE_FN: WriteFn = |write_context_args, _| {
196 let write_iterator = identity_write_iterator_fn(write_context_args);
197 Ok(OperatorWriteOutput {
198 write_iterator,
199 ..Default::default()
200 })
201};
202
203pub fn null_write_iterator_fn(
206 &WriteContextArgs {
207 root,
208 op_span,
209 ident,
210 inputs,
211 outputs,
212 is_pull,
213 op_inst:
214 OperatorInstance {
215 generics: OpInstGenerics { type_args, .. },
216 ..
217 },
218 ..
219 }: &WriteContextArgs,
220) -> TokenStream {
221 let default_type = parse_quote_spanned! {op_span=> _};
222 let iter_type = type_args.first().unwrap_or(&default_type);
223
224 if is_pull {
225 quote_spanned! {op_span=>
226 let #ident = #root::dfir_pipes::pull::poll_fn({
227 #(
228 let mut #inputs = ::std::boxed::Box::pin(#inputs);
229 )*
230 move |_cx| {
231 #(
235 let #inputs = #root::dfir_pipes::pull::Pull::pull(
236 ::std::pin::Pin::as_mut(&mut #inputs),
237 <_ as #root::dfir_pipes::Context>::from_task(_cx),
238 );
239 )*
240 #(
241 if let #root::dfir_pipes::pull::PullStep::Pending(_) = #inputs {
242 return #root::dfir_pipes::pull::PullStep::Pending(#root::dfir_pipes::Yes);
243 }
244 )*
245 #root::dfir_pipes::pull::PullStep::<_, _, #root::dfir_pipes::Yes, _>::Ended(#root::dfir_pipes::Yes)
246 }
247 });
248 }
249 } else {
250 quote_spanned! {op_span=>
251 #[allow(clippy::let_unit_value)]
252 let _ = (#(#outputs),*);
253 let #ident = #root::dfir_pipes::push::for_each::<_, #iter_type>(::std::mem::drop::<#iter_type>);
254 }
255 }
256}
257
258pub const NULL_WRITE_FN: WriteFn = |write_context_args, _| {
261 let write_iterator = null_write_iterator_fn(write_context_args);
262 Ok(OperatorWriteOutput {
263 write_iterator,
264 ..Default::default()
265 })
266};
267
/// Declares the operator modules and the table of their constraints.
///
/// Takes a comma-terminated list of `module::CONSTANT` pairs. For every pair it emits
/// `pub(crate) mod module;`, and it gathers all `module::CONSTANT` values, in the order given,
/// into the `OPERATORS` slice.
macro_rules! declare_ops {
    ( $( $mod:ident :: $op:ident, )* ) => {
        $( pub(crate) mod $mod; )*
        /// Constraints of every declared operator, in declaration order.
        pub const OPERATORS: &[OperatorConstraints] = &[
            $( $mod :: $op, )*
        ];
    };
}
// Operator registry. Each entry is `module::CONSTANT`: the module is declared as a submodule and
// the constant (an `OperatorConstraints`) is appended to `OPERATORS`. The order of this list is
// the order of `OPERATORS`.
declare_ops![
    all_iterations::ALL_ITERATIONS,
    all_once::ALL_ONCE,
    anti_join::ANTI_JOIN,
    assert::ASSERT,
    assert_eq::ASSERT_EQ,
    batch::BATCH,
    chain::CHAIN,
    chain_first_n::CHAIN_FIRST_N,
    _counter::_COUNTER,
    cross_join::CROSS_JOIN,
    cross_join_multiset::CROSS_JOIN_MULTISET,
    cross_singleton::CROSS_SINGLETON,
    demux_enum::DEMUX_ENUM,
    dest_file::DEST_FILE,
    dest_sink::DEST_SINK,
    dest_sink_serde::DEST_SINK_SERDE,
    difference::DIFFERENCE,
    enumerate::ENUMERATE,
    filter::FILTER,
    filter_map::FILTER_MAP,
    flat_map::FLAT_MAP,
    flat_map_stream_blocking::FLAT_MAP_STREAM_BLOCKING,
    flatten::FLATTEN,
    flatten_stream_blocking::FLATTEN_STREAM_BLOCKING,
    fold::FOLD,
    fold_no_replay::FOLD_NO_REPLAY,
    for_each::FOR_EACH,
    identity::IDENTITY,
    initialize::INITIALIZE,
    inspect::INSPECT,
    join::JOIN,
    join_fused::JOIN_FUSED,
    join_fused_lhs::JOIN_FUSED_LHS,
    join_fused_rhs::JOIN_FUSED_RHS,
    join_multiset::JOIN_MULTISET,
    join_multiset_half::JOIN_MULTISET_HALF,
    fold_keyed::FOLD_KEYED,
    reduce_keyed::REDUCE_KEYED,
    repeat_n::REPEAT_N,
    lattice_bimorphism::LATTICE_BIMORPHISM,
    _lattice_fold_batch::_LATTICE_FOLD_BATCH,
    lattice_fold::LATTICE_FOLD,
    _lattice_join_fused_join::_LATTICE_JOIN_FUSED_JOIN,
    lattice_reduce::LATTICE_REDUCE,
    map::MAP,
    union::UNION,
    multiset_delta::MULTISET_DELTA,
    next_iteration::NEXT_ITERATION,
    defer_signal::DEFER_SIGNAL,
    defer_tick::DEFER_TICK,
    defer_tick_lazy::DEFER_TICK_LAZY,
    null::NULL,
    partition::PARTITION,
    persist::PERSIST,
    persist_mut::PERSIST_MUT,
    persist_mut_keyed::PERSIST_MUT_KEYED,
    prefix::PREFIX,
    resolve_futures::RESOLVE_FUTURES,
    resolve_futures_blocking::RESOLVE_FUTURES_BLOCKING,
    resolve_futures_blocking_ordered::RESOLVE_FUTURES_BLOCKING_ORDERED,
    resolve_futures_ordered::RESOLVE_FUTURES_ORDERED,
    reduce::REDUCE,
    reduce_no_replay::REDUCE_NO_REPLAY,
    scan::SCAN,
    scan_async_blocking::SCAN_ASYNC_BLOCKING,
    spin::SPIN,
    sort::SORT,
    sort_by_key::SORT_BY_KEY,
    source_file::SOURCE_FILE,
    source_interval::SOURCE_INTERVAL,
    source_iter::SOURCE_ITER,
    source_json::SOURCE_JSON,
    source_stdin::SOURCE_STDIN,
    source_stream::SOURCE_STREAM,
    source_stream_serde::SOURCE_STREAM_SERDE,
    state::STATE,
    state_by::STATE_BY,
    tee::TEE,
    unique::UNIQUE,
    unzip::UNZIP,
    zip::ZIP,
    zip_longest::ZIP_LONGEST,
];
362
363pub fn operator_lookup() -> &'static HashMap<&'static str, &'static OperatorConstraints> {
365 pub static OPERATOR_LOOKUP: OnceLock<HashMap<&'static str, &'static OperatorConstraints>> =
366 OnceLock::new();
367 OPERATOR_LOOKUP.get_or_init(|| OPERATORS.iter().map(|op| (op.name, op)).collect())
368}
369pub fn find_node_op_constraints(node: &GraphNode) -> Option<&'static OperatorConstraints> {
371 if let GraphNode::Operator(operator) = node {
372 find_op_op_constraints(operator)
373 } else {
374 None
375 }
376}
377pub fn find_op_op_constraints(operator: &Operator) -> Option<&'static OperatorConstraints> {
379 let name = &*operator.name_string();
380 operator_lookup().get(name).copied()
381}
382
/// Everything an operator's [`WriteFn`] receives when generating code for one operator
/// instance.
#[derive(Clone)]
pub struct WriteContextArgs<'a> {
    /// Tokens used as the path prefix of runtime items in generated code
    /// (e.g. `#root::dfir_pipes::...`).
    pub root: &'a TokenStream,
    /// NOTE(review): presumably the identifier of the runtime context variable in generated
    /// code — confirm.
    pub context: &'a Ident,
    /// NOTE(review): presumably the identifier of the dataflow instance in generated code —
    /// confirm.
    pub df_ident: &'a Ident,
    /// Subgraph id; included in the names produced by `make_ident`.
    pub subgraph_id: GraphSubgraphId,
    /// Node id; included in the names produced by `make_ident`.
    pub node_id: GraphNodeId,
    /// Loop id, if any. When present, `persistence_args_disallow_mutable` defaults to
    /// `Persistence::None` instead of `Persistence::Tick`.
    pub loop_id: Option<GraphLoopId>,
    /// Span attached to generated tokens, generated identifiers, and diagnostics.
    pub op_span: Span,
    /// Optional tag string. NOTE(review): its use is not visible in this file — confirm.
    pub op_tag: Option<String>,
    /// NOTE(review): presumably the identifier of a helper that wraps synchronous work —
    /// confirm.
    pub work_fn: &'a Ident,
    /// NOTE(review): presumably the asynchronous counterpart of `work_fn` — confirm.
    pub work_fn_async: &'a Ident,

    /// Identifier the operator's generated code binds with `let #ident = ...`.
    pub ident: &'a Ident,
    /// `true` selects pull-side code generation, `false` push-side.
    pub is_pull: bool,
    /// Identifiers of the operator's inputs.
    pub inputs: &'a [Ident],
    /// Identifiers of the operator's outputs.
    pub outputs: &'a [Ident],
    /// NOTE(review): presumably the identifier for the operator's singleton output, relevant
    /// when `has_singleton_output` is set — confirm.
    pub singleton_output_ident: &'a Ident,

    /// Operator name, used in diagnostic messages.
    pub op_name: &'static str,
    /// The operator instance, including its generic arguments (`type_args`,
    /// `persistence_args`).
    pub op_inst: &'a OperatorInstance,
    /// Comma-separated argument expressions.
    /// NOTE(review): presumably the arguments written in the operator call — confirm.
    pub arguments: &'a Punctuated<Expr, Token![,]>,
}
432impl WriteContextArgs<'_> {
433 pub fn make_ident(&self, suffix: impl AsRef<str>) -> Ident {
439 Ident::new(
440 &format!(
441 "sg_{:?}_node_{:?}_{}",
442 self.subgraph_id.data(),
443 self.node_id.data(),
444 suffix.as_ref(),
445 ),
446 self.op_span,
447 )
448 }
449
450 pub fn persistence_args_disallow_mutable<const N: usize>(
452 &self,
453 diagnostics: &mut Diagnostics,
454 ) -> [Persistence; N] {
455 let len = self.op_inst.generics.persistence_args.len();
456 if 0 != len && 1 != len && N != len {
457 diagnostics.push(Diagnostic::spanned(
458 self.op_span,
459 Level::Error,
460 format!(
461 "The operator `{}` only accepts 0, 1, or {} persistence arguments",
462 self.op_name, N
463 ),
464 ));
465 }
466
467 let default_persistence = if self.loop_id.is_some() {
468 Persistence::None
469 } else {
470 Persistence::Tick
471 };
472 let mut out = [default_persistence; N];
473 self.op_inst
474 .generics
475 .persistence_args
476 .iter()
477 .copied()
478 .cycle() .take(N)
480 .enumerate()
481 .filter(|&(_i, p)| {
482 if p == Persistence::Mutable {
483 diagnostics.push(Diagnostic::spanned(
484 self.op_span,
485 Level::Error,
486 format!(
487 "An implementation of `'{}` does not exist",
488 p.to_str_lowercase()
489 ),
490 ));
491 false
492 } else {
493 true
494 }
495 })
496 .for_each(|(i, p)| {
497 out[i] = p;
498 });
499 out
500 }
501}
502
/// Object-safe view of a range: its two bounds, a containment test, and a human-readable
/// description. Implementors must also be `Send + Sync + Debug`.
pub trait RangeTrait<T>: Send + Sync + Debug
where
    T: ?Sized,
{
    /// Lower bound of the range.
    fn start_bound(&self) -> Bound<&T>;
    /// Upper bound of the range.
    fn end_bound(&self) -> Bound<&T>;
    /// Returns `true` if `item` lies within the range.
    fn contains(&self, item: &T) -> bool
    where
        T: PartialOrd<T>;

    /// Describes the range in words, e.g. `"exactly 1"`, `"at least 2"`,
    /// `"more than 1 and at most 4"`, or `"any number of"` when both ends are unbounded.
    fn human_string(&self) -> String
    where
        T: Display + PartialEq,
    {
        let lower = self.start_bound();
        let upper = self.end_bound();

        // Two equal inclusive bounds describe a single value.
        if let (Bound::Included(lo), Bound::Included(hi)) = (lower, upper) {
            if lo == hi {
                return format!("exactly {}", lo);
            }
        }

        let lower_text = match lower {
            Bound::Included(lo) => Some(format!("at least {}", lo)),
            Bound::Excluded(lo) => Some(format!("more than {}", lo)),
            Bound::Unbounded => None,
        };
        let upper_text = match upper {
            Bound::Included(hi) => Some(format!("at most {}", hi)),
            Bound::Excluded(hi) => Some(format!("less than {}", hi)),
            Bound::Unbounded => None,
        };

        match (lower_text, upper_text) {
            (Some(lo), Some(hi)) => format!("{} and {}", lo, hi),
            (Some(only), None) | (None, Some(only)) => only,
            (None, None) => "any number of".to_owned(),
        }
    }
}
547
548impl<R, T> RangeTrait<T> for R
549where
550 R: RangeBounds<T> + Send + Sync + Debug,
551{
552 fn start_bound(&self) -> Bound<&T> {
553 self.start_bound()
554 }
555
556 fn end_bound(&self) -> Bound<&T> {
557 self.end_bound()
558 }
559
560 fn contains(&self, item: &T) -> bool
561 where
562 T: PartialOrd<T>,
563 {
564 self.contains(item)
565 }
566}
567
/// Persistence setting of an operator, given through its persistence arguments.
///
/// `PartialOrd`/`Ord` are derived, so the variants compare in declaration order
/// (`None < Loop < Tick < Static < Mutable`); reordering them changes comparison results.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub enum Persistence {
    /// Rendered as `"none"`. The default chosen by
    /// `WriteContextArgs::persistence_args_disallow_mutable` inside a loop.
    None,
    /// Rendered as `"loop"`.
    /// NOTE(review): presumably state kept for the duration of a loop — confirm.
    Loop,
    /// Rendered as `"tick"`. The default chosen by
    /// `WriteContextArgs::persistence_args_disallow_mutable` outside of loops.
    Tick,
    /// Rendered as `"static"`.
    /// NOTE(review): presumably state kept across all ticks — confirm.
    Static,
    /// Rendered as `"mutable"`. Rejected with an error diagnostic by
    /// `WriteContextArgs::persistence_args_disallow_mutable`.
    Mutable,
}
582impl Persistence {
583 pub fn to_str_lowercase(self) -> &'static str {
585 match self {
586 Persistence::None => "none",
587 Persistence::Tick => "tick",
588 Persistence::Loop => "loop",
589 Persistence::Static => "static",
590 Persistence::Mutable => "mutable",
591 }
592 }
593}
594
595fn make_missing_runtime_msg(op_name: &str) -> Literal {
597 Literal::string(&format!(
598 "`{}()` must be used within a Tokio runtime. For example, use `#[dfir_rs::main]` on your main method.",
599 op_name
600 ))
601}
602
/// Category an operator can be filed under (see [`OperatorConstraints::categories`]).
///
/// `name()` and `description()` split each variant's documentation string at the first `:` and
/// panic if there is none, so every variant's docs must have the form `Name: description`.
// NOTE(review): `get_variant_docs()` presumably comes from the `DocumentedVariants` derive and
// reads the `///` comment on each variant. No per-variant doc comments are visible in this view,
// and none are added here, because their text is runtime data: changing it changes the output of
// `name()` / `description()`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, DocumentedVariants)]
pub enum OperatorCategory {
    Map,
    Filter,
    Flatten,
    Fold,
    KeyedFold,
    LatticeFold,
    Persistence,
    MultiIn,
    MultiOut,
    Source,
    Sink,
    Control,
    CompilerFusionOperator,
    Windowing,
    Unwindowing,
}
637impl OperatorCategory {
638 pub fn name(self) -> &'static str {
640 self.get_variant_docs().split_once(":").unwrap().0
641 }
642 pub fn description(self) -> &'static str {
644 self.get_variant_docs().split_once(":").unwrap().1
645 }
646}
647
/// Optional classification of an operator, stored in [`OperatorConstraints::flo_type`].
///
/// `PartialOrd`/`Ord` are derived, so the variants compare in declaration order.
/// NOTE(review): the variant notes below are inferred from names only — confirm with the code
/// that consumes `flo_type`.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug)]
pub enum FloType {
    /// Presumably marks a source operator — TODO confirm.
    Source,
    /// Presumably marks a windowing operator — TODO confirm.
    Windowing,
    /// Presumably marks an un-windowing operator — TODO confirm.
    Unwindowing,
    /// Presumably marks an operator that passes data to the next iteration — TODO confirm.
    NextIteration,
}