ruka_codegen_rust/
program_cfg.rs

1use super::*;
2
/// Source location metadata attached to generated functions.
#[derive(Clone, Debug)]
struct SourceMap {
    /// Label of the source file the program was produced from.
    file: String,
    /// Line recorded for each function, keyed by function name.
    function_lines: BTreeMap<String, usize>,
}

impl SourceMap {
    /// Build a placeholder map labelled `<unknown>` with no recorded lines.
    #[cfg(test)]
    fn unknown() -> Self {
        let file = String::from("<unknown>");
        let function_lines = BTreeMap::default();
        Self {
            file,
            function_lines,
        }
    }

    /// Build a map for `file`, storing its own copy of `function_lines`.
    fn from_function_lines(file: String, function_lines: &BTreeMap<String, usize>) -> Self {
        let function_lines = function_lines.to_owned();
        Self {
            file,
            function_lines,
        }
    }

    /// Return the line recorded for `name`, or `0` when there is no entry.
    fn function_line(&self, name: &str) -> usize {
        match self.function_lines.get(name) {
            Some(line) => *line,
            None => 0,
        }
    }
}
29
30/// Error produced while generating or writing Rust source.
31#[derive(Debug, Error)]
32pub enum CodegenError {
33    /// Token rendering failed.
34    #[error("failed to format generated rust source: {0}")]
35    Format(#[from] genco::fmt::Error),
36    /// Generated source could not be parsed for pretty formatting.
37    #[error("failed to parse generated rust source for pretty formatting: {0}")]
38    PrettyFormat(#[from] syn::Error),
39    /// Writing generated source to disk failed.
40    #[error("failed to write generated rust source: {0}")]
41    Io(#[from] std::io::Error),
42}
43
44/// Emit Rust source for `program` and write it to `path`.
45pub fn emit_to_path(
46    program: &MirProgram,
47    path: &Path,
48    source_file: &Path,
49    function_lines: &BTreeMap<String, usize>,
50) -> Result<(), CodegenError> {
51    let source = emit_program_with_function_lines(program, source_file, function_lines)?;
52    std::fs::write(path, source)?;
53    Ok(())
54}
55
/// Render `program` with a placeholder source map (test builds only).
#[cfg(test)]
pub fn emit_program(program: &MirProgram) -> Result<String, CodegenError> {
    let placeholder_map = SourceMap::unknown();
    emit_program_with_map(program, &placeholder_map)
}
61
62/// Emit Rust source for a MIR program with source location metadata.
63pub fn emit_program_with_function_lines(
64    program: &MirProgram,
65    source_file: &Path,
66    function_lines: &BTreeMap<String, usize>,
67) -> Result<String, CodegenError> {
68    let source_map =
69        SourceMap::from_function_lines(source_file.display().to_string(), function_lines);
70    emit_program_with_map(program, &source_map)
71}
72
73fn emit_program_with_map(
74    program: &MirProgram,
75    source_map: &SourceMap,
76) -> Result<String, CodegenError> {
77    for (_, function) in program.functions.iter() {
78        function.assert_valid();
79    }
80
81    let program_names = ProgramNames::from_program(program);
82
83    let mut tokens: rust::Tokens = quote! {
84        #[derive(Debug, Clone)]
85        /// Error returned by the generated Rust runtime shim.
86        pub enum RuntimeError {
87            MissingMain,
88            InvalidMainArity { actual: usize },
89        }
90    };
91
92    for decl in &program.structs {
93        let decl_tokens = emit_struct_decl_tokens(decl);
94        quote_in!(tokens => $decl_tokens);
95    }
96
97    for decl in &program.enums {
98        let decl_tokens = emit_enum_decl_tokens(decl);
99        quote_in!(tokens => $decl_tokens);
100    }
101
102    for (func_id, function) in program.functions.iter() {
103        let function_names = FunctionNames::from_function(function);
104        let function_tokens = emit_function_tokens(
105            func_id,
106            function,
107            &program_names,
108            &function_names,
109            &source_map.file,
110            source_map.function_line(&function.name),
111        );
112        quote_in!(tokens => $function_tokens);
113    }
114
115    if let Some(main_id) = program.function_names.get("main") {
116        let main_arity = program.functions[*main_id].arity;
117        let main_ident = program_names.function_ident(*main_id);
118        quote_in!(tokens =>
119            /// Execute the generated `main` function when present and arity-zero.
120            pub fn run_main() -> Result<(), RuntimeError> {
121                if $main_arity != 0 {
122                    Err(RuntimeError::InvalidMainArity { actual: $main_arity })
123                } else {
124                    let _ = $main_ident();
125                    ruka_runtime::ptr::assert_no_leaks();
126                    Ok(())
127                }
128            }
129        );
130    } else {
131        quote_in!(tokens =>
132            /// Report that no callable `main` function was generated.
133            pub fn run_main() -> Result<(), RuntimeError> {
134                Err(RuntimeError::MissingMain)
135            }
136        );
137    }
138
139    let source = tokens.to_file_string()?;
140    let syntax = parse_file(&source)?;
141    Ok(prettyplease::unparse(&syntax))
142}
143
144fn emit_struct_decl_tokens(decl: &MirStructDecl) -> rust::Tokens {
145    let mut tokens = rust::Tokens::new();
146    let name = mangle_struct_ident(&decl.name);
147    let mut params = rust::Tokens::new();
148    for (index, param) in decl.type_params.iter().enumerate() {
149        if index > 0 {
150            quote_in!(params => ,);
151        }
152        let param = param.clone();
153        quote_in!(params => $param);
154    }
155
156    let mut fields = rust::Tokens::new();
157    for field in &decl.fields {
158        let field_name = field.name.clone();
159        let field_ty = emit_type_expr_tokens(&field.ty);
160        quote_in!(fields => $field_name: $field_ty,);
161    }
162
163    if decl.type_params.is_empty() {
164        quote_in!(tokens =>
165            #[allow(non_camel_case_types)]
166            #[derive(Debug, Clone)]
167            struct $name {
168                $fields
169            }
170        );
171    } else {
172        quote_in!(tokens =>
173            #[allow(non_camel_case_types)]
174            #[derive(Debug, Clone)]
175            struct $name<$params> {
176                $fields
177            }
178        );
179    }
180
181    tokens
182}
183
184fn emit_enum_decl_tokens(decl: &MirEnumDecl) -> rust::Tokens {
185    let mut tokens = rust::Tokens::new();
186    let name = mangle_struct_ident(&decl.name);
187
188    let mut params = rust::Tokens::new();
189    for (index, param) in decl.type_params.iter().enumerate() {
190        if index > 0 {
191            quote_in!(params => ,);
192        }
193        let param = param.clone();
194        quote_in!(params => $param);
195    }
196
197    let mut variants = rust::Tokens::new();
198    for variant in &decl.variants {
199        let variant_name = variant.name.clone();
200        if variant.payload.is_empty() {
201            quote_in!(variants => $variant_name,);
202        } else {
203            let mut payload_tokens = rust::Tokens::new();
204            for (index, payload) in variant.payload.iter().enumerate() {
205                if index > 0 {
206                    quote_in!(payload_tokens => ,);
207                }
208                let payload = emit_type_expr_tokens(payload);
209                quote_in!(payload_tokens => $payload);
210            }
211            quote_in!(variants => $variant_name($payload_tokens),);
212        }
213    }
214
215    if decl.type_params.is_empty() {
216        quote_in!(tokens =>
217            #[allow(non_camel_case_types)]
218            #[allow(dead_code)]
219            #[derive(Debug, Clone)]
220            enum $name {
221                $variants
222            }
223        );
224    } else {
225        quote_in!(tokens =>
226            #[allow(non_camel_case_types)]
227            #[allow(dead_code)]
228            #[derive(Debug, Clone)]
229            enum $name<$params> {
230                $variants
231            }
232        );
233    }
234
235    tokens
236}
237
238fn emit_function_tokens(
239    func_id: MirFuncId,
240    function: &MirFunction,
241    program_names: &ProgramNames,
242    function_names: &FunctionNames,
243    source_file: &str,
244    source_line: usize,
245) -> rust::Tokens {
246    let mut tokens = rust::Tokens::new();
247    let source_doc = format!("source: {}:{}", source_file, source_line);
248    let func_ident = program_names.function_ident(func_id);
249    let return_ty = emit_ty_tokens(&function.return_ty);
250
251    let structured_body = structurize_cfg_body(function);
252
253    let mut mut_locals = HashSet::new();
254    collect_mutable_locals(&structured_body, &mut mut_locals);
255
256    let mut read_locals = HashSet::new();
257    collect_read_locals(&structured_body, &mut read_locals);
258
259    let (ref_ro_locals, ref_mut_locals) = collect_ref_locals(function);
260    let slice_locals = collect_slice_locals(function);
261
262    let mut param_inits = rust::Tokens::new();
263    for binding in function.param_bindings() {
264        if binding.requires_materialization() {
265            assert!(
266                binding.expects_view(),
267                "only view params may require Rust materialization"
268            );
269            assert!(
270                !binding.local.is_place(),
271                "materialized Rust params should lower to value locals"
272            );
273        }
274        if !read_locals.contains(&binding.local_id) {
275            continue;
276        }
277        let local_name = function_names.local_ident(binding.local_id);
278        let arg_name = incoming_param_ident(function_names, binding.local_id, binding.index, true);
279        if binding.expects_view() || binding.expects_mut_borrow() {
280            if binding.materializes_view_from_owned() {
281                quote_in!(param_inits => let $local_name = (*$arg_name).clone(););
282            } else {
283                quote_in!(param_inits => let $local_name = $arg_name;);
284            }
285        } else {
286            let mut_kw = if mut_locals.contains(&binding.local_id) {
287                quote!(mut)
288            } else {
289                quote!()
290            };
291            quote_in!(param_inits => let $mut_kw $local_name = $arg_name;);
292        }
293    }
294
295    let body_tokens = emit_stmt_list_tokens(
296        function,
297        &structured_body,
298        source_file,
299        source_line,
300        &mut_locals,
301        &read_locals,
302        &ref_ro_locals,
303        &ref_mut_locals,
304        &slice_locals,
305        function_names,
306        program_names,
307    );
308
309    let params = emit_function_params(function, function_names, &read_locals);
310
311    quote_in!(tokens =>
312        #[doc = $(quoted(source_doc.as_str()))]
313        #[allow(non_snake_case)]
314        #[allow(unused_assignments)]
315        fn $func_ident($params) -> $return_ty {
316            $param_inits
317            $body_tokens
318        }
319    );
320
321    tokens
322}
323
324#[derive(Debug, Clone)]
325enum CfgEnd {
326    Jump {
327        target: MirBlockId,
328        args: Vec<MirLocalId>,
329    },
330    Return,
331}
332
333fn structurize_cfg_body(function: &MirFunction) -> Vec<MirStmt> {
334    let mut active_loops = Vec::new();
335    let (stmts, end) = structurize_from_block(function, function.entry, &mut active_loops);
336    if !matches!(end, CfgEnd::Return) {
337        panic!("cfg structurizer expected function to end in return");
338    }
339    stmts
340}
341
342fn structurize_from_block(
343    function: &MirFunction,
344    block_id: MirBlockId,
345    active_loops: &mut Vec<MirBlockId>,
346) -> (Vec<MirStmt>, CfgEnd) {
347    let block = &function.blocks[block_id];
348    let mut out = Vec::new();
349    for instr in &block.instrs {
350        out.push(MirStmt::Instr(instr.clone()));
351    }
352
353    match &block.terminator {
354        MirTerminator::Return { value } => {
355            out.push(MirStmt::Return { value: *value });
356            (out, CfgEnd::Return)
357        }
358        MirTerminator::Jump { target, args } => {
359            if active_loops.contains(target) {
360                return (
361                    out,
362                    CfgEnd::Jump {
363                        target: *target,
364                        args: args.clone(),
365                    },
366                );
367            }
368
369            if let Some((while_stmt, exit)) =
370                structurize_while_from_jump(function, *target, args, active_loops)
371            {
372                out.push(while_stmt);
373                let (rest, end) = structurize_from_block(function, exit, active_loops);
374                out.extend(rest);
375                (out, end)
376            } else {
377                (
378                    out,
379                    CfgEnd::Jump {
380                        target: *target,
381                        args: args.clone(),
382                    },
383                )
384            }
385        }
386        MirTerminator::Branch {
387            cond,
388            then_target,
389            else_target,
390            ..
391        } => {
392            let (then_body, then_end) =
393                structurize_from_block(function, *then_target, active_loops);
394            let (else_body, else_end) =
395                structurize_from_block(function, *else_target, active_loops);
396
397            out.push(MirStmt::If {
398                cond: *cond,
399                then_body,
400                else_body,
401            });
402
403            match (then_end, else_end) {
404                (CfgEnd::Return, CfgEnd::Return) => (out, CfgEnd::Return),
405                (
406                    CfgEnd::Jump {
407                        target: then_merge,
408                        args: then_args,
409                    },
410                    CfgEnd::Jump {
411                        target: else_merge,
412                        args: else_args,
413                    },
414                ) if then_merge == else_merge && then_args.is_empty() && else_args.is_empty() => {
415                    let (rest, end) = structurize_from_block(function, then_merge, active_loops);
416                    out.extend(rest);
417                    (out, end)
418                }
419                (CfgEnd::Return, CfgEnd::Jump { target, args }) if args.is_empty() => {
420                    let (rest, end) = structurize_from_block(function, target, active_loops);
421                    out.extend(rest);
422                    (out, end)
423                }
424                (CfgEnd::Jump { target, args }, CfgEnd::Return) if args.is_empty() => {
425                    let (rest, end) = structurize_from_block(function, target, active_loops);
426                    out.extend(rest);
427                    (out, end)
428                }
429                _ => panic!("unsupported cfg branch shape for structurization"),
430            }
431        }
432    }
433}
434
435fn structurize_while_from_jump(
436    function: &MirFunction,
437    header: MirBlockId,
438    init_args: &[MirLocalId],
439    active_loops: &mut Vec<MirBlockId>,
440) -> Option<(MirStmt, MirBlockId)> {
441    let header_block = &function.blocks[header];
442    let (cond, body_start, exit) = match &header_block.terminator {
443        MirTerminator::Branch {
444            cond,
445            then_target,
446            then_args,
447            else_target,
448            else_args,
449        } if then_args.is_empty() && else_args.is_empty() => (*cond, *then_target, *else_target),
450        _ => return None,
451    };
452
453    active_loops.push(header);
454    let (body, body_end) = structurize_from_block(function, body_start, active_loops);
455    let popped = active_loops.pop();
456    assert_eq!(popped, Some(header));
457    let step_args = match body_end {
458        CfgEnd::Jump { target, args } if target == header => args,
459        CfgEnd::Return => Vec::new(),
460        _ => return None,
461    };
462
463    let cond_body = header_block
464        .instrs
465        .iter()
466        .cloned()
467        .map(MirStmt::Instr)
468        .collect::<Vec<_>>();
469
470    Some((
471        MirStmt::While {
472            loop_params: header_block.params.clone(),
473            init_args: init_args.to_vec(),
474            cond_body,
475            cond,
476            body,
477            step_args,
478        },
479        exit,
480    ))
481}
482
483pub(crate) fn emit_stmt_list_tokens(
484    function: &MirFunction,
485    stmts: &[MirStmt],
486    source_file: &str,
487    source_line: usize,
488    mut_locals: &HashSet<MirLocalId>,
489    read_locals: &HashSet<MirLocalId>,
490    ref_ro_locals: &HashSet<MirLocalId>,
491    ref_mut_locals: &HashSet<MirLocalId>,
492    slice_locals: &HashSet<MirLocalId>,
493    function_names: &FunctionNames,
494    program_names: &ProgramNames,
495) -> rust::Tokens {
496    let mut tokens = rust::Tokens::new();
497    let mut index = 0;
498    while index < stmts.len() {
499        if let Some((fused_tokens, consumed)) = try_emit_pointer_if_binding_tokens(
500            function,
501            stmts,
502            index,
503            source_file,
504            source_line,
505            mut_locals,
506            read_locals,
507            ref_ro_locals,
508            ref_mut_locals,
509            slice_locals,
510            function_names,
511            program_names,
512        ) {
513            quote_in!(tokens => $fused_tokens);
514            index += consumed;
515            continue;
516        }
517
518        let stmt_tokens = emit_stmt_tokens(
519            function,
520            &stmts[index],
521            source_file,
522            source_line,
523            mut_locals,
524            read_locals,
525            ref_ro_locals,
526            ref_mut_locals,
527            slice_locals,
528            function_names,
529            program_names,
530        );
531        quote_in!(tokens => $stmt_tokens);
532        index += 1;
533    }
534    tokens
535}
536
537fn try_emit_pointer_if_binding_tokens(
538    function: &MirFunction,
539    stmts: &[MirStmt],
540    index: usize,
541    source_file: &str,
542    source_line: usize,
543    mut_locals: &HashSet<MirLocalId>,
544    read_locals: &HashSet<MirLocalId>,
545    ref_ro_locals: &HashSet<MirLocalId>,
546    ref_mut_locals: &HashSet<MirLocalId>,
547    slice_locals: &HashSet<MirLocalId>,
548    function_names: &FunctionNames,
549    program_names: &ProgramNames,
550) -> Option<(rust::Tokens, usize)> {
551    let MirStmt::Instr(MirInstr::PointerIsSome { pointer, dst }) = stmts.get(index)? else {
552        return None;
553    };
554    let MirStmt::If {
555        cond,
556        then_body,
557        else_body,
558    } = stmts.get(index + 1)?
559    else {
560        return None;
561    };
562    if cond != dst {
563        return None;
564    }
565    if then_body.len() < 2 {
566        return None;
567    }
568
569    let mut tokens = rust::Tokens::new();
570    let pointer_name = function_names.local_ident(*pointer);
571    let then_rest;
572
573    if let (
574        MirStmt::Instr(MirInstr::PointerBorrowRo {
575            pointer: borrow_pointer,
576            dst: borrowed_ref,
577        }),
578        MirStmt::Instr(MirInstr::DerefCopy {
579            src: deref_src,
580            dst: binding_local,
581        }),
582    ) = (&then_body[0], &then_body[1])
583    {
584        if borrow_pointer != pointer || deref_src != borrowed_ref {
585            return None;
586        }
587        let borrowed_ref_name = function_names.local_ident(*borrowed_ref);
588        let binding_name = function_names.local_ident(*binding_local);
589        let mut_kw = if mut_locals.contains(binding_local) {
590            quote!(mut)
591        } else {
592            quote!()
593        };
594        let binding_init = if read_locals.contains(binding_local) {
595            quote!(let $mut_kw $binding_name = (*$borrowed_ref_name).clone();)
596        } else {
597            quote!(let _ = (*$borrowed_ref_name).clone();)
598        };
599
600        then_rest = &then_body[2..];
601        let then_tokens = emit_stmt_list_tokens(
602            function,
603            then_rest,
604            source_file,
605            source_line,
606            mut_locals,
607            read_locals,
608            ref_ro_locals,
609            ref_mut_locals,
610            slice_locals,
611            function_names,
612            program_names,
613        );
614        let else_tokens = emit_stmt_list_tokens(
615            function,
616            else_body,
617            source_file,
618            source_line,
619            mut_locals,
620            read_locals,
621            ref_ro_locals,
622            ref_mut_locals,
623            slice_locals,
624            function_names,
625            program_names,
626        );
627        quote_in!(tokens =>
628            if let Some($borrowed_ref_name) = $pointer_name.as_ref() {
629                let $borrowed_ref_name = $borrowed_ref_name.borrow();
630                $binding_init
631                $then_tokens
632            } else {
633                $else_tokens
634            }
635        );
636        return Some((tokens, 2));
637    }
638
639    if let (
640        MirStmt::Instr(MirInstr::PointerBorrowRo {
641            pointer: borrow_pointer,
642            dst: borrowed_ref,
643        }),
644        MirStmt::Instr(MirInstr::Move {
645            src: move_src,
646            dst: binding_local,
647        }),
648    ) = (&then_body[0], &then_body[1])
649    {
650        if borrow_pointer != pointer || move_src != borrowed_ref {
651            return None;
652        }
653        let borrowed_ref_name = function_names.local_ident(*borrowed_ref);
654        let binding_name = if read_locals.contains(binding_local) {
655            function_names.local_ident(*binding_local).to_owned()
656        } else {
657            "_".to_owned()
658        };
659        then_rest = &then_body[2..];
660        let then_tokens = emit_stmt_list_tokens(
661            function,
662            then_rest,
663            source_file,
664            source_line,
665            mut_locals,
666            read_locals,
667            ref_ro_locals,
668            ref_mut_locals,
669            slice_locals,
670            function_names,
671            program_names,
672        );
673        let else_tokens = emit_stmt_list_tokens(
674            function,
675            else_body,
676            source_file,
677            source_line,
678            mut_locals,
679            read_locals,
680            ref_ro_locals,
681            ref_mut_locals,
682            slice_locals,
683            function_names,
684            program_names,
685        );
686        quote_in!(tokens =>
687            if let Some($borrowed_ref_name) = $pointer_name.as_ref() {
688                let $binding_name = $borrowed_ref_name.borrow();
689                $then_tokens
690            } else {
691                $else_tokens
692            }
693        );
694        return Some((tokens, 2));
695    }
696
697    if then_body.len() >= 2 {
698        if let (
699            MirStmt::Instr(MirInstr::PointerBorrowMut {
700                pointer: borrow_pointer,
701                dst: borrowed_ref,
702            }),
703            MirStmt::Instr(MirInstr::Move {
704                src: move_src,
705                dst: binding_local,
706            }),
707        ) = (&then_body[0], &then_body[1])
708        {
709            if borrow_pointer != pointer || move_src != borrowed_ref {
710                return None;
711            }
712            let binding_name = if read_locals.contains(binding_local) {
713                function_names.local_ident(*binding_local).to_owned()
714            } else {
715                "_".to_owned()
716            };
717            let binding_ptr_name = synthetic_temp_ident(&format!("{}_ptr", binding_name));
718            let binding_ptr_pattern = binding_ptr_name.clone();
719            then_rest = &then_body[2..];
720            let then_tokens = emit_stmt_list_tokens(
721                function,
722                then_rest,
723                source_file,
724                source_line,
725                mut_locals,
726                read_locals,
727                ref_ro_locals,
728                ref_mut_locals,
729                slice_locals,
730                function_names,
731                program_names,
732            );
733            let else_tokens = emit_stmt_list_tokens(
734                function,
735                else_body,
736                source_file,
737                source_line,
738                mut_locals,
739                read_locals,
740                ref_ro_locals,
741                ref_mut_locals,
742                slice_locals,
743                function_names,
744                program_names,
745            );
746            quote_in!(tokens =>
747                if let Some($binding_ptr_pattern) = $pointer_name.as_mut() {
748                    let $binding_name = $binding_ptr_name.borrow_mut();
749                    $then_tokens
750                } else {
751                    $else_tokens
752                }
753            );
754            return Some((tokens, 2));
755        }
756    }
757
758    None
759}