rukalang/pass/
mod.rs

1//! Shared pass orchestration and provenance tracking.
2
3use std::time::Duration;
4
5#[cfg(not(target_arch = "wasm32"))]
6use std::time::Instant;
7
8use cranelift_entity::{entity_impl, PrimaryMap};
9use ruka_frontend::SourceSpan;
10
11/// Generic pass interface for one compiler transformation.
12pub trait Pass {
13    /// Input value consumed by this pass.
14    type In;
15    /// Output value produced by this pass.
16    type Out;
17    /// Error type emitted by this pass.
18    type Error;
19
20    /// Stable pass name used for stats and diagnostics.
21    const NAME: &'static str;
22
23    /// Execute one pass invocation.
24    fn run(&mut self, input: Self::In, cx: &mut PassContext) -> Result<Self::Out, Self::Error>;
25}
26
27/// Stable identifier for one pass in a pipeline.
28#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
29pub struct PassId(u32);
30entity_impl!(PassId);
31
32/// Stable identifier for one source file in shared provenance tables.
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
34pub struct SourceFileId(u32);
35entity_impl!(SourceFileId);
36
37/// Stable identifier for one source span in shared provenance tables.
38#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
39pub struct SpanId(u32);
40entity_impl!(SpanId);
41
42/// Stable identifier for one provenance node in shared provenance tables.
43#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
44pub struct OriginId(u32);
45entity_impl!(OriginId);
46
/// Timing sample captured for a single pass execution.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PassTiming {
    /// Name of the pass that was timed.
    pub name: &'static str,
    /// Wall-clock time spent inside the pass.
    pub elapsed: Duration,
}
55
56/// One pass snapshot emitted by the driver pipeline.
57#[derive(Debug, Clone, PartialEq, Eq)]
58pub struct PassSnapshot {
59    /// Stable semantic kind for machine-readable snapshot consumers.
60    pub kind: PassSnapshotKind,
61    /// Pass name associated with this snapshot.
62    pub name: &'static str,
63    /// Human-readable snapshot detail string.
64    pub detail: String,
65    /// Structured fields for machine-readable snapshot consumers.
66    pub fields: Vec<PassSnapshotField>,
67}
68
69/// One structured key/value field in a pass snapshot.
70#[derive(Debug, Clone, PartialEq, Eq)]
71pub struct PassSnapshotField {
72    /// Field key.
73    pub key: &'static str,
74    /// Field value.
75    pub value: PassSnapshotValue,
76}
77
/// Typed value stored in one snapshot field.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PassSnapshotValue {
    /// An unsigned 64-bit integer.
    U64(u64),
    /// A signed 64-bit integer.
    I64(i64),
    /// A boolean flag.
    Bool(bool),
    /// An owned string.
    Text(String),
}
90
/// Machine-readable category of a snapshot, one per pipeline stage.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PassSnapshotKind {
    /// Captured once meta expansion has finished.
    MetaProgram,
    /// Captured once elaboration has finished.
    ElabProgram,
    /// Captured once lowering to HIR has finished.
    HirProgram,
    /// Captured once semantic checking has finished.
    CheckProgram,
    /// Captured once lowering to MIR has finished.
    MirProgram,
    /// Captured once Rust code generation has finished.
    CodegenRust,
    /// Captured once WASM code generation has finished.
    CodegenWasm,
}

impl PassSnapshotKind {
    /// Stable `snake_case` tag for this kind, as written into JSON output.
    pub fn as_str(self) -> &'static str {
        let tag = match self {
            PassSnapshotKind::MetaProgram => "meta_program",
            PassSnapshotKind::ElabProgram => "elab_program",
            PassSnapshotKind::HirProgram => "hir_program",
            PassSnapshotKind::CheckProgram => "check_program",
            PassSnapshotKind::MirProgram => "mir_program",
            PassSnapshotKind::CodegenRust => "codegen_rust",
            PassSnapshotKind::CodegenWasm => "codegen_wasm",
        };
        tag
    }
}
124
/// Snapshot of provenance table sizes, broken down by origin variant.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ProvenanceSummary {
    /// How many spans are stored.
    pub span_count: usize,
    /// How many provenance nodes are stored, across all variants.
    pub origin_count: usize,
    /// How many nodes are `Parsed` origins.
    pub parsed_count: usize,
    /// How many nodes are `Expanded` origins.
    pub expanded_count: usize,
    /// How many nodes are `Lowered` origins.
    pub lowered_count: usize,
    /// How many nodes are `Synthesized` origins.
    pub synthesized_count: usize,
}
141
142/// Provenance node that maps transformed entities back to source.
143#[derive(Debug, Clone, PartialEq, Eq)]
144pub enum Origin {
145    /// Node parsed directly from source.
146    Parsed {
147        /// Interned source span.
148        span: SpanId,
149    },
150    /// Node produced by one expansion pass.
151    Expanded {
152        /// Input provenance for this transformation.
153        from: OriginId,
154        /// Pass that produced this node.
155        phase: PassId,
156    },
157    /// Node produced by one lowering pass.
158    Lowered {
159        /// Input provenance for this transformation.
160        from: OriginId,
161        /// Pass that produced this node.
162        phase: PassId,
163    },
164    /// Node synthesized without a direct one-to-one input.
165    Synthesized {
166        /// Human-readable reason for synthesis.
167        reason: &'static str,
168        /// Parent provenance when available.
169        parent: Option<OriginId>,
170    },
171}
172
173/// Arena-backed provenance tables shared by all passes.
174#[derive(Debug, Default)]
175pub struct ProvenanceTables {
176    spans: PrimaryMap<SpanId, SourceSpan>,
177    origins: PrimaryMap<OriginId, Origin>,
178}
179
180impl ProvenanceTables {
181    /// Intern a source span and return its stable id.
182    pub fn intern_span(&mut self, span: SourceSpan) -> SpanId {
183        self.spans.push(span)
184    }
185
186    /// Record one parsed origin from a source span.
187    pub fn record_parsed(&mut self, span: SourceSpan) -> OriginId {
188        let span_id = self.intern_span(span);
189        self.origins.push(Origin::Parsed { span: span_id })
190    }
191
192    /// Record one expanded origin from an existing origin id.
193    pub fn record_expanded(&mut self, from: OriginId, phase: PassId) -> OriginId {
194        self.origins.push(Origin::Expanded { from, phase })
195    }
196
197    /// Record one lowered origin from an existing origin id.
198    pub fn record_lowered(&mut self, from: OriginId, phase: PassId) -> OriginId {
199        self.origins.push(Origin::Lowered { from, phase })
200    }
201
202    /// Record one synthesized origin.
203    pub fn record_synthesized(
204        &mut self,
205        reason: &'static str,
206        parent: Option<OriginId>,
207    ) -> OriginId {
208        self.origins.push(Origin::Synthesized { reason, parent })
209    }
210
211    /// Return one span by id.
212    pub fn span(&self, span_id: SpanId) -> &SourceSpan {
213        &self.spans[span_id]
214    }
215
216    /// Return one origin by id.
217    pub fn origin(&self, origin_id: OriginId) -> &Origin {
218        &self.origins[origin_id]
219    }
220
221    /// Return aggregate counters for provenance tables.
222    pub fn summary(&self) -> ProvenanceSummary {
223        let mut parsed_count = 0usize;
224        let mut expanded_count = 0usize;
225        let mut lowered_count = 0usize;
226        let mut synthesized_count = 0usize;
227
228        for (_, origin) in self.origins.iter() {
229            match origin {
230                Origin::Parsed { .. } => parsed_count += 1,
231                Origin::Expanded { .. } => expanded_count += 1,
232                Origin::Lowered { .. } => lowered_count += 1,
233                Origin::Synthesized { .. } => synthesized_count += 1,
234            }
235        }
236
237        ProvenanceSummary {
238            span_count: self.spans.len(),
239            origin_count: self.origins.len(),
240            parsed_count,
241            expanded_count,
242            lowered_count,
243            synthesized_count,
244        }
245    }
246}
247
248/// Shared context passed through compiler passes.
249#[derive(Debug, Default)]
250pub struct PassContext {
251    pass_ids: PrimaryMap<PassId, &'static str>,
252    pass_timings: Vec<PassTiming>,
253    provenance: ProvenanceTables,
254}
255
256impl PassContext {
257    /// Construct an empty pass context.
258    pub fn new() -> Self {
259        Self::default()
260    }
261
262    /// Run one pass, measure its wall-clock duration, and store timing stats.
263    pub fn run_timed<T, E, F>(&mut self, pass_name: &'static str, run: F) -> Result<T, E>
264    where
265        F: FnOnce(&mut PassContext, PassId) -> Result<T, E>,
266    {
267        let pass_id = self.pass_ids.push(pass_name);
268        let (result, elapsed) = time_pass(|| run(self, pass_id));
269        self.pass_timings.push(PassTiming {
270            name: pass_name,
271            elapsed,
272        });
273        result
274    }
275
276    /// Run one typed pass, measure duration, and store timing stats.
277    pub fn run_pass<P>(&mut self, pass: &mut P, input: P::In) -> Result<P::Out, P::Error>
278    where
279        P: Pass,
280    {
281        self.run_timed(P::NAME, |cx, _pass_id| pass.run(input, cx))
282    }
283
284    /// Run one typed pass and return its pass id with output.
285    pub fn run_pass_with_id<P>(
286        &mut self,
287        pass: &mut P,
288        input: P::In,
289    ) -> Result<(PassId, P::Out), P::Error>
290    where
291        P: Pass,
292    {
293        let pass_id = self.pass_ids.push(P::NAME);
294        let (result, elapsed) = time_pass(|| pass.run(input, self));
295        self.pass_timings.push(PassTiming {
296            name: P::NAME,
297            elapsed,
298        });
299        result.map(|output| (pass_id, output))
300    }
301
302    /// Run one infallible pass, measure duration, and store timing stats.
303    pub fn run_timed_value<T, F>(&mut self, pass_name: &'static str, run: F) -> T
304    where
305        F: FnOnce(&mut PassContext, PassId) -> T,
306    {
307        let pass_id = self.pass_ids.push(pass_name);
308        let (output, elapsed) = time_pass(|| run(self, pass_id));
309        self.pass_timings.push(PassTiming {
310            name: pass_name,
311            elapsed,
312        });
313        output
314    }
315
316    /// Return recorded pass timings in execution order.
317    pub fn pass_timings(&self) -> &[PassTiming] {
318        &self.pass_timings
319    }
320
321    /// Append externally collected pass timing samples.
322    pub fn extend_pass_timings<I>(&mut self, timings: I)
323    where
324        I: IntoIterator<Item = PassTiming>,
325    {
326        self.pass_timings.extend(timings);
327    }
328
329    /// Return shared provenance tables.
330    pub fn provenance(&self) -> &ProvenanceTables {
331        &self.provenance
332    }
333
334    /// Return mutable shared provenance tables.
335    pub fn provenance_mut(&mut self) -> &mut ProvenanceTables {
336        &mut self.provenance
337    }
338}
339
/// Invoke `run` exactly once and return its output together with the elapsed time.
///
/// On `wasm32` targets the clock is never consulted, and the reported duration
/// is always [`Duration::ZERO`].
fn time_pass<T, F>(run: F) -> (T, Duration)
where
    F: FnOnce() -> T,
{
    #[cfg(not(target_arch = "wasm32"))]
    let started = Instant::now();

    let output = run();

    #[cfg(not(target_arch = "wasm32"))]
    let elapsed = started.elapsed();
    #[cfg(target_arch = "wasm32")]
    let elapsed = Duration::ZERO;

    (output, elapsed)
}
357
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use cranelift_entity::EntityRef;
    use ruka_frontend::{FileId, SourceLocation, SourceSpan};

    use super::{Origin, Pass, PassContext, PassTiming, ProvenanceSummary};

    /// Infallible typed pass that increments its input.
    struct AddOnePass;

    impl Pass for AddOnePass {
        type In = i32;
        type Out = i32;
        type Error = ();

        const NAME: &'static str = "pass.add_one";

        fn run(
            &mut self,
            input: Self::In,
            _cx: &mut PassContext,
        ) -> Result<Self::Out, Self::Error> {
            Ok(input + 1)
        }
    }

    /// Typed pass that always fails, used to exercise the error path.
    struct FailPass;

    impl Pass for FailPass {
        type In = ();
        type Out = ();
        type Error = &'static str;

        const NAME: &'static str = "pass.fail_typed";

        fn run(
            &mut self,
            _input: Self::In,
            _cx: &mut PassContext,
        ) -> Result<Self::Out, Self::Error> {
            Err("typed failure")
        }
    }

    #[test]
    fn records_timing_for_run_timed() {
        let mut cx = PassContext::new();
        let value = cx
            .run_timed("pass.alpha", |_cx, _pass_id| Ok::<_, ()>(42))
            .expect("pass should succeed");
        assert_eq!(value, 42);
        assert_eq!(cx.pass_timings().len(), 1);
        assert_eq!(cx.pass_timings()[0].name, "pass.alpha");
    }

    #[test]
    fn records_timing_even_when_pass_fails() {
        let mut cx = PassContext::new();
        let result = cx.run_timed("pass.fail", |_cx, _pass_id| Err::<(), _>("boom"));
        assert_eq!(result, Err("boom"));
        assert_eq!(cx.pass_timings().len(), 1);
        assert_eq!(cx.pass_timings()[0].name, "pass.fail");
    }

    #[test]
    fn runs_typed_pass_and_records_timing() {
        let mut cx = PassContext::new();
        let mut pass = AddOnePass;
        let value = cx.run_pass(&mut pass, 41).expect("pass should succeed");
        assert_eq!(value, 42);
        assert_eq!(cx.pass_timings().len(), 1);
        assert_eq!(cx.pass_timings()[0].name, AddOnePass::NAME);
    }

    #[test]
    fn runs_typed_pass_with_id_and_records_timing() {
        let mut cx = PassContext::new();
        let mut pass = AddOnePass;
        let (pass_id, value) = cx
            .run_pass_with_id(&mut pass, 1)
            .expect("pass should succeed");
        assert_eq!(pass_id.index(), 0);
        assert_eq!(value, 2);
        assert_eq!(cx.pass_timings().len(), 1);
        assert_eq!(cx.pass_timings()[0].name, AddOnePass::NAME);
    }

    #[test]
    fn typed_pass_with_id_failure_still_records_timing() {
        let mut cx = PassContext::new();
        let mut pass = FailPass;
        assert_eq!(cx.run_pass_with_id(&mut pass, ()), Err("typed failure"));
        assert_eq!(cx.pass_timings().len(), 1);
        assert_eq!(cx.pass_timings()[0].name, FailPass::NAME);
    }

    #[test]
    fn run_timed_value_records_timing_and_allocates_sequential_ids() {
        let mut cx = PassContext::new();
        let first = cx.run_timed_value("pass.first", |_cx, pass_id| pass_id);
        let second = cx.run_timed_value("pass.second", |_cx, pass_id| pass_id);
        assert_eq!(first.index(), 0);
        assert_eq!(second.index(), 1);
        let names: Vec<&'static str> = cx.pass_timings().iter().map(|t| t.name).collect();
        assert_eq!(names, vec!["pass.first", "pass.second"]);
    }

    #[test]
    fn extend_pass_timings_appends_samples() {
        let mut cx = PassContext::new();
        let external = PassTiming {
            name: "pass.external",
            elapsed: Duration::from_millis(3),
        };
        cx.extend_pass_timings(vec![external.clone()]);
        assert_eq!(cx.pass_timings(), &[external][..]);
    }

    #[test]
    fn records_parsed_origin_and_span() {
        let mut cx = PassContext::new();
        let span = SourceSpan::new(
            FileId::from_u32(0),
            SourceLocation::start(),
            SourceLocation {
                byte: 3,
                line: 1,
                column: 4,
            },
        );
        let origin_id = cx.provenance_mut().record_parsed(span);
        let origin = cx.provenance().origin(origin_id);

        match origin {
            Origin::Parsed { span: span_id } => {
                assert_eq!(cx.provenance().span(*span_id), &span);
            }
            _ => panic!("expected parsed origin"),
        }
    }

    #[test]
    fn reports_provenance_summary_counts() {
        let mut cx = PassContext::new();
        let first_span = SourceSpan::new(
            FileId::from_u32(0),
            SourceLocation::start(),
            SourceLocation {
                byte: 1,
                line: 1,
                column: 2,
            },
        );
        let second_span = SourceSpan::new(
            FileId::from_u32(0),
            SourceLocation {
                byte: 2,
                line: 1,
                column: 3,
            },
            SourceLocation {
                byte: 4,
                line: 1,
                column: 5,
            },
        );

        let parsed = cx.provenance_mut().record_parsed(first_span);
        let _ = cx.provenance_mut().record_parsed(second_span);
        let pass_id = cx
            .run_timed("pass.summary", |_cx, pass_id| Ok::<_, ()>(pass_id))
            .expect("pass should succeed");
        let _ = cx.provenance_mut().record_expanded(parsed, pass_id);
        let _ = cx.provenance_mut().record_lowered(parsed, pass_id);
        let _ = cx.provenance_mut().record_synthesized("test", Some(parsed));

        let summary = cx.provenance().summary();
        assert_eq!(
            summary,
            ProvenanceSummary {
                span_count: 2,
                origin_count: 5,
                parsed_count: 2,
                expanded_count: 1,
                lowered_count: 1,
                synthesized_count: 1,
            }
        );
    }
}