1use super::*;
2
/// Maps generated Rust functions back to a location in the original source file.
///
/// Lookups are by MIR function name; a name with no recorded line maps to line 0.
#[derive(Debug, Clone)]
struct SourceMap {
    /// Display form of the source path; it ends up in each generated function's
    /// `#[doc = "source: <file>:<line>"]` attribute.
    file: String,
    /// Recorded source line per function name (presumably the definition line —
    /// confirm against the caller that builds this table).
    function_lines: BTreeMap<String, usize>,
}

impl SourceMap {
    /// Placeholder map for tests: file `<unknown>` with no recorded lines.
    #[cfg(test)]
    fn unknown() -> Self {
        Self::from_function_lines(String::from("<unknown>"), &BTreeMap::new())
    }

    /// Builds a map for `file` from a copy of the caller's name-to-line table.
    fn from_function_lines(file: String, function_lines: &BTreeMap<String, usize>) -> Self {
        let function_lines = function_lines.clone();
        Self {
            file,
            function_lines,
        }
    }

    /// Returns the recorded line for the function called `name`, or 0 when the
    /// table has no entry for it.
    fn function_line(&self, name: &str) -> usize {
        match self.function_lines.get(name) {
            Some(&line) => line,
            None => 0,
        }
    }
}
29
/// Errors produced while rendering, pretty-printing, or writing generated Rust source.
#[derive(Debug, Error)]
pub enum CodegenError {
    /// genco failed to render the token stream to text (`Tokens::to_file_string`).
    #[error("failed to format generated rust source: {0}")]
    Format(#[from] genco::fmt::Error),
    /// The rendered text could not be parsed by `syn::parse_file`, which must
    /// succeed before `prettyplease` can re-print it.
    #[error("failed to parse generated rust source for pretty formatting: {0}")]
    PrettyFormat(#[from] syn::Error),
    /// Writing the generated source to disk failed.
    #[error("failed to write generated rust source: {0}")]
    Io(#[from] std::io::Error),
}
43
44pub fn emit_to_path(
46 program: &MirProgram,
47 path: &Path,
48 source_file: &Path,
49 function_lines: &BTreeMap<String, usize>,
50) -> Result<(), CodegenError> {
51 let source = emit_program_with_function_lines(program, source_file, function_lines)?;
52 std::fs::write(path, source)?;
53 Ok(())
54}
55
/// Test-only entry point: emits `program` without real source information.
///
/// Uses a placeholder map, so every generated function's doc attribute reads
/// `source: <unknown>:0`.
#[cfg(test)]
pub fn emit_program(program: &MirProgram) -> Result<String, CodegenError> {
    let placeholder = SourceMap::unknown();
    emit_program_with_map(program, &placeholder)
}
61
62pub fn emit_program_with_function_lines(
64 program: &MirProgram,
65 source_file: &Path,
66 function_lines: &BTreeMap<String, usize>,
67) -> Result<String, CodegenError> {
68 let source_map =
69 SourceMap::from_function_lines(source_file.display().to_string(), function_lines);
70 emit_program_with_map(program, &source_map)
71}
72
73fn emit_program_with_map(
74 program: &MirProgram,
75 source_map: &SourceMap,
76) -> Result<String, CodegenError> {
77 for (_, function) in program.functions.iter() {
78 function.assert_valid();
79 }
80
81 let program_names = ProgramNames::from_program(program);
82
83 let mut tokens: rust::Tokens = quote! {
84 #[derive(Debug, Clone)]
85 pub enum RuntimeError {
87 MissingMain,
88 InvalidMainArity { actual: usize },
89 }
90 };
91
92 for decl in &program.structs {
93 let decl_tokens = emit_struct_decl_tokens(decl);
94 quote_in!(tokens => $decl_tokens);
95 }
96
97 for decl in &program.enums {
98 let decl_tokens = emit_enum_decl_tokens(decl);
99 quote_in!(tokens => $decl_tokens);
100 }
101
102 for (func_id, function) in program.functions.iter() {
103 let function_names = FunctionNames::from_function(function);
104 let function_tokens = emit_function_tokens(
105 func_id,
106 function,
107 &program_names,
108 &function_names,
109 &source_map.file,
110 source_map.function_line(&function.name),
111 );
112 quote_in!(tokens => $function_tokens);
113 }
114
115 if let Some(main_id) = program.function_names.get("main") {
116 let main_arity = program.functions[*main_id].arity;
117 let main_ident = program_names.function_ident(*main_id);
118 quote_in!(tokens =>
119 pub fn run_main() -> Result<(), RuntimeError> {
121 if $main_arity != 0 {
122 Err(RuntimeError::InvalidMainArity { actual: $main_arity })
123 } else {
124 let _ = $main_ident();
125 ruka_runtime::ptr::assert_no_leaks();
126 Ok(())
127 }
128 }
129 );
130 } else {
131 quote_in!(tokens =>
132 pub fn run_main() -> Result<(), RuntimeError> {
134 Err(RuntimeError::MissingMain)
135 }
136 );
137 }
138
139 let source = tokens.to_file_string()?;
140 let syntax = parse_file(&source)?;
141 Ok(prettyplease::unparse(&syntax))
142}
143
144fn emit_struct_decl_tokens(decl: &MirStructDecl) -> rust::Tokens {
145 let mut tokens = rust::Tokens::new();
146 let name = mangle_struct_ident(&decl.name);
147 let mut params = rust::Tokens::new();
148 for (index, param) in decl.type_params.iter().enumerate() {
149 if index > 0 {
150 quote_in!(params => ,);
151 }
152 let param = param.clone();
153 quote_in!(params => $param);
154 }
155
156 let mut fields = rust::Tokens::new();
157 for field in &decl.fields {
158 let field_name = field.name.clone();
159 let field_ty = emit_type_expr_tokens(&field.ty);
160 quote_in!(fields => $field_name: $field_ty,);
161 }
162
163 if decl.type_params.is_empty() {
164 quote_in!(tokens =>
165 #[allow(non_camel_case_types)]
166 #[derive(Debug, Clone)]
167 struct $name {
168 $fields
169 }
170 );
171 } else {
172 quote_in!(tokens =>
173 #[allow(non_camel_case_types)]
174 #[derive(Debug, Clone)]
175 struct $name<$params> {
176 $fields
177 }
178 );
179 }
180
181 tokens
182}
183
184fn emit_enum_decl_tokens(decl: &MirEnumDecl) -> rust::Tokens {
185 let mut tokens = rust::Tokens::new();
186 let name = mangle_struct_ident(&decl.name);
187
188 let mut params = rust::Tokens::new();
189 for (index, param) in decl.type_params.iter().enumerate() {
190 if index > 0 {
191 quote_in!(params => ,);
192 }
193 let param = param.clone();
194 quote_in!(params => $param);
195 }
196
197 let mut variants = rust::Tokens::new();
198 for variant in &decl.variants {
199 let variant_name = variant.name.clone();
200 if variant.payload.is_empty() {
201 quote_in!(variants => $variant_name,);
202 } else {
203 let mut payload_tokens = rust::Tokens::new();
204 for (index, payload) in variant.payload.iter().enumerate() {
205 if index > 0 {
206 quote_in!(payload_tokens => ,);
207 }
208 let payload = emit_type_expr_tokens(payload);
209 quote_in!(payload_tokens => $payload);
210 }
211 quote_in!(variants => $variant_name($payload_tokens),);
212 }
213 }
214
215 if decl.type_params.is_empty() {
216 quote_in!(tokens =>
217 #[allow(non_camel_case_types)]
218 #[allow(dead_code)]
219 #[derive(Debug, Clone)]
220 enum $name {
221 $variants
222 }
223 );
224 } else {
225 quote_in!(tokens =>
226 #[allow(non_camel_case_types)]
227 #[allow(dead_code)]
228 #[derive(Debug, Clone)]
229 enum $name<$params> {
230 $variants
231 }
232 );
233 }
234
235 tokens
236}
237
238fn emit_function_tokens(
239 func_id: MirFuncId,
240 function: &MirFunction,
241 program_names: &ProgramNames,
242 function_names: &FunctionNames,
243 source_file: &str,
244 source_line: usize,
245) -> rust::Tokens {
246 let mut tokens = rust::Tokens::new();
247 let source_doc = format!("source: {}:{}", source_file, source_line);
248 let func_ident = program_names.function_ident(func_id);
249 let return_ty = emit_ty_tokens(&function.return_ty);
250
251 let structured_body = structurize_cfg_body(function);
252
253 let mut mut_locals = HashSet::new();
254 collect_mutable_locals(&structured_body, &mut mut_locals);
255
256 let mut read_locals = HashSet::new();
257 collect_read_locals(&structured_body, &mut read_locals);
258
259 let (ref_ro_locals, ref_mut_locals) = collect_ref_locals(function);
260 let slice_locals = collect_slice_locals(function);
261
262 let mut param_inits = rust::Tokens::new();
263 for binding in function.param_bindings() {
264 if binding.requires_materialization() {
265 assert!(
266 binding.expects_view(),
267 "only view params may require Rust materialization"
268 );
269 assert!(
270 !binding.local.is_place(),
271 "materialized Rust params should lower to value locals"
272 );
273 }
274 if !read_locals.contains(&binding.local_id) {
275 continue;
276 }
277 let local_name = function_names.local_ident(binding.local_id);
278 let arg_name = incoming_param_ident(function_names, binding.local_id, binding.index, true);
279 if binding.expects_view() || binding.expects_mut_borrow() {
280 if binding.materializes_view_from_owned() {
281 quote_in!(param_inits => let $local_name = (*$arg_name).clone(););
282 } else {
283 quote_in!(param_inits => let $local_name = $arg_name;);
284 }
285 } else {
286 let mut_kw = if mut_locals.contains(&binding.local_id) {
287 quote!(mut)
288 } else {
289 quote!()
290 };
291 quote_in!(param_inits => let $mut_kw $local_name = $arg_name;);
292 }
293 }
294
295 let body_tokens = emit_stmt_list_tokens(
296 function,
297 &structured_body,
298 source_file,
299 source_line,
300 &mut_locals,
301 &read_locals,
302 &ref_ro_locals,
303 &ref_mut_locals,
304 &slice_locals,
305 function_names,
306 program_names,
307 );
308
309 let params = emit_function_params(function, function_names, &read_locals);
310
311 quote_in!(tokens =>
312 #[doc = $(quoted(source_doc.as_str()))]
313 #[allow(non_snake_case)]
314 #[allow(unused_assignments)]
315 fn $func_ident($params) -> $return_ty {
316 $param_inits
317 $body_tokens
318 }
319 );
320
321 tokens
322}
323
/// How a run of blocks structurized by `structurize_from_block` ends.
#[derive(Debug, Clone)]
enum CfgEnd {
    /// Control leaves the run with an unstructured jump to `target`, carrying
    /// `args` for the target block's parameters. The caller decides whether this
    /// is a loop back-edge (it becomes a `while`'s `step_args`) or an `if` merge.
    Jump {
        target: MirBlockId,
        args: Vec<MirLocalId>,
    },
    /// The run ends in a `MirStmt::Return`; nothing continues after it.
    Return,
}
332
333fn structurize_cfg_body(function: &MirFunction) -> Vec<MirStmt> {
334 let mut active_loops = Vec::new();
335 let (stmts, end) = structurize_from_block(function, function.entry, &mut active_loops);
336 if !matches!(end, CfgEnd::Return) {
337 panic!("cfg structurizer expected function to end in return");
338 }
339 stmts
340}
341
342fn structurize_from_block(
343 function: &MirFunction,
344 block_id: MirBlockId,
345 active_loops: &mut Vec<MirBlockId>,
346) -> (Vec<MirStmt>, CfgEnd) {
347 let block = &function.blocks[block_id];
348 let mut out = Vec::new();
349 for instr in &block.instrs {
350 out.push(MirStmt::Instr(instr.clone()));
351 }
352
353 match &block.terminator {
354 MirTerminator::Return { value } => {
355 out.push(MirStmt::Return { value: *value });
356 (out, CfgEnd::Return)
357 }
358 MirTerminator::Jump { target, args } => {
359 if active_loops.contains(target) {
360 return (
361 out,
362 CfgEnd::Jump {
363 target: *target,
364 args: args.clone(),
365 },
366 );
367 }
368
369 if let Some((while_stmt, exit)) =
370 structurize_while_from_jump(function, *target, args, active_loops)
371 {
372 out.push(while_stmt);
373 let (rest, end) = structurize_from_block(function, exit, active_loops);
374 out.extend(rest);
375 (out, end)
376 } else {
377 (
378 out,
379 CfgEnd::Jump {
380 target: *target,
381 args: args.clone(),
382 },
383 )
384 }
385 }
386 MirTerminator::Branch {
387 cond,
388 then_target,
389 else_target,
390 ..
391 } => {
392 let (then_body, then_end) =
393 structurize_from_block(function, *then_target, active_loops);
394 let (else_body, else_end) =
395 structurize_from_block(function, *else_target, active_loops);
396
397 out.push(MirStmt::If {
398 cond: *cond,
399 then_body,
400 else_body,
401 });
402
403 match (then_end, else_end) {
404 (CfgEnd::Return, CfgEnd::Return) => (out, CfgEnd::Return),
405 (
406 CfgEnd::Jump {
407 target: then_merge,
408 args: then_args,
409 },
410 CfgEnd::Jump {
411 target: else_merge,
412 args: else_args,
413 },
414 ) if then_merge == else_merge && then_args.is_empty() && else_args.is_empty() => {
415 let (rest, end) = structurize_from_block(function, then_merge, active_loops);
416 out.extend(rest);
417 (out, end)
418 }
419 (CfgEnd::Return, CfgEnd::Jump { target, args }) if args.is_empty() => {
420 let (rest, end) = structurize_from_block(function, target, active_loops);
421 out.extend(rest);
422 (out, end)
423 }
424 (CfgEnd::Jump { target, args }, CfgEnd::Return) if args.is_empty() => {
425 let (rest, end) = structurize_from_block(function, target, active_loops);
426 out.extend(rest);
427 (out, end)
428 }
429 _ => panic!("unsupported cfg branch shape for structurization"),
430 }
431 }
432 }
433}
434
435fn structurize_while_from_jump(
436 function: &MirFunction,
437 header: MirBlockId,
438 init_args: &[MirLocalId],
439 active_loops: &mut Vec<MirBlockId>,
440) -> Option<(MirStmt, MirBlockId)> {
441 let header_block = &function.blocks[header];
442 let (cond, body_start, exit) = match &header_block.terminator {
443 MirTerminator::Branch {
444 cond,
445 then_target,
446 then_args,
447 else_target,
448 else_args,
449 } if then_args.is_empty() && else_args.is_empty() => (*cond, *then_target, *else_target),
450 _ => return None,
451 };
452
453 active_loops.push(header);
454 let (body, body_end) = structurize_from_block(function, body_start, active_loops);
455 let popped = active_loops.pop();
456 assert_eq!(popped, Some(header));
457 let step_args = match body_end {
458 CfgEnd::Jump { target, args } if target == header => args,
459 CfgEnd::Return => Vec::new(),
460 _ => return None,
461 };
462
463 let cond_body = header_block
464 .instrs
465 .iter()
466 .cloned()
467 .map(MirStmt::Instr)
468 .collect::<Vec<_>>();
469
470 Some((
471 MirStmt::While {
472 loop_params: header_block.params.clone(),
473 init_args: init_args.to_vec(),
474 cond_body,
475 cond,
476 body,
477 step_args,
478 },
479 exit,
480 ))
481}
482
483pub(crate) fn emit_stmt_list_tokens(
484 function: &MirFunction,
485 stmts: &[MirStmt],
486 source_file: &str,
487 source_line: usize,
488 mut_locals: &HashSet<MirLocalId>,
489 read_locals: &HashSet<MirLocalId>,
490 ref_ro_locals: &HashSet<MirLocalId>,
491 ref_mut_locals: &HashSet<MirLocalId>,
492 slice_locals: &HashSet<MirLocalId>,
493 function_names: &FunctionNames,
494 program_names: &ProgramNames,
495) -> rust::Tokens {
496 let mut tokens = rust::Tokens::new();
497 let mut index = 0;
498 while index < stmts.len() {
499 if let Some((fused_tokens, consumed)) = try_emit_pointer_if_binding_tokens(
500 function,
501 stmts,
502 index,
503 source_file,
504 source_line,
505 mut_locals,
506 read_locals,
507 ref_ro_locals,
508 ref_mut_locals,
509 slice_locals,
510 function_names,
511 program_names,
512 ) {
513 quote_in!(tokens => $fused_tokens);
514 index += consumed;
515 continue;
516 }
517
518 let stmt_tokens = emit_stmt_tokens(
519 function,
520 &stmts[index],
521 source_file,
522 source_line,
523 mut_locals,
524 read_locals,
525 ref_ro_locals,
526 ref_mut_locals,
527 slice_locals,
528 function_names,
529 program_names,
530 );
531 quote_in!(tokens => $stmt_tokens);
532 index += 1;
533 }
534 tokens
535}
536
537fn try_emit_pointer_if_binding_tokens(
538 function: &MirFunction,
539 stmts: &[MirStmt],
540 index: usize,
541 source_file: &str,
542 source_line: usize,
543 mut_locals: &HashSet<MirLocalId>,
544 read_locals: &HashSet<MirLocalId>,
545 ref_ro_locals: &HashSet<MirLocalId>,
546 ref_mut_locals: &HashSet<MirLocalId>,
547 slice_locals: &HashSet<MirLocalId>,
548 function_names: &FunctionNames,
549 program_names: &ProgramNames,
550) -> Option<(rust::Tokens, usize)> {
551 let MirStmt::Instr(MirInstr::PointerIsSome { pointer, dst }) = stmts.get(index)? else {
552 return None;
553 };
554 let MirStmt::If {
555 cond,
556 then_body,
557 else_body,
558 } = stmts.get(index + 1)?
559 else {
560 return None;
561 };
562 if cond != dst {
563 return None;
564 }
565 if then_body.len() < 2 {
566 return None;
567 }
568
569 let mut tokens = rust::Tokens::new();
570 let pointer_name = function_names.local_ident(*pointer);
571 let then_rest;
572
573 if let (
574 MirStmt::Instr(MirInstr::PointerBorrowRo {
575 pointer: borrow_pointer,
576 dst: borrowed_ref,
577 }),
578 MirStmt::Instr(MirInstr::DerefCopy {
579 src: deref_src,
580 dst: binding_local,
581 }),
582 ) = (&then_body[0], &then_body[1])
583 {
584 if borrow_pointer != pointer || deref_src != borrowed_ref {
585 return None;
586 }
587 let borrowed_ref_name = function_names.local_ident(*borrowed_ref);
588 let binding_name = function_names.local_ident(*binding_local);
589 let mut_kw = if mut_locals.contains(binding_local) {
590 quote!(mut)
591 } else {
592 quote!()
593 };
594 let binding_init = if read_locals.contains(binding_local) {
595 quote!(let $mut_kw $binding_name = (*$borrowed_ref_name).clone();)
596 } else {
597 quote!(let _ = (*$borrowed_ref_name).clone();)
598 };
599
600 then_rest = &then_body[2..];
601 let then_tokens = emit_stmt_list_tokens(
602 function,
603 then_rest,
604 source_file,
605 source_line,
606 mut_locals,
607 read_locals,
608 ref_ro_locals,
609 ref_mut_locals,
610 slice_locals,
611 function_names,
612 program_names,
613 );
614 let else_tokens = emit_stmt_list_tokens(
615 function,
616 else_body,
617 source_file,
618 source_line,
619 mut_locals,
620 read_locals,
621 ref_ro_locals,
622 ref_mut_locals,
623 slice_locals,
624 function_names,
625 program_names,
626 );
627 quote_in!(tokens =>
628 if let Some($borrowed_ref_name) = $pointer_name.as_ref() {
629 let $borrowed_ref_name = $borrowed_ref_name.borrow();
630 $binding_init
631 $then_tokens
632 } else {
633 $else_tokens
634 }
635 );
636 return Some((tokens, 2));
637 }
638
639 if let (
640 MirStmt::Instr(MirInstr::PointerBorrowRo {
641 pointer: borrow_pointer,
642 dst: borrowed_ref,
643 }),
644 MirStmt::Instr(MirInstr::Move {
645 src: move_src,
646 dst: binding_local,
647 }),
648 ) = (&then_body[0], &then_body[1])
649 {
650 if borrow_pointer != pointer || move_src != borrowed_ref {
651 return None;
652 }
653 let borrowed_ref_name = function_names.local_ident(*borrowed_ref);
654 let binding_name = if read_locals.contains(binding_local) {
655 function_names.local_ident(*binding_local).to_owned()
656 } else {
657 "_".to_owned()
658 };
659 then_rest = &then_body[2..];
660 let then_tokens = emit_stmt_list_tokens(
661 function,
662 then_rest,
663 source_file,
664 source_line,
665 mut_locals,
666 read_locals,
667 ref_ro_locals,
668 ref_mut_locals,
669 slice_locals,
670 function_names,
671 program_names,
672 );
673 let else_tokens = emit_stmt_list_tokens(
674 function,
675 else_body,
676 source_file,
677 source_line,
678 mut_locals,
679 read_locals,
680 ref_ro_locals,
681 ref_mut_locals,
682 slice_locals,
683 function_names,
684 program_names,
685 );
686 quote_in!(tokens =>
687 if let Some($borrowed_ref_name) = $pointer_name.as_ref() {
688 let $binding_name = $borrowed_ref_name.borrow();
689 $then_tokens
690 } else {
691 $else_tokens
692 }
693 );
694 return Some((tokens, 2));
695 }
696
697 if then_body.len() >= 2 {
698 if let (
699 MirStmt::Instr(MirInstr::PointerBorrowMut {
700 pointer: borrow_pointer,
701 dst: borrowed_ref,
702 }),
703 MirStmt::Instr(MirInstr::Move {
704 src: move_src,
705 dst: binding_local,
706 }),
707 ) = (&then_body[0], &then_body[1])
708 {
709 if borrow_pointer != pointer || move_src != borrowed_ref {
710 return None;
711 }
712 let binding_name = if read_locals.contains(binding_local) {
713 function_names.local_ident(*binding_local).to_owned()
714 } else {
715 "_".to_owned()
716 };
717 let binding_ptr_name = synthetic_temp_ident(&format!("{}_ptr", binding_name));
718 let binding_ptr_pattern = binding_ptr_name.clone();
719 then_rest = &then_body[2..];
720 let then_tokens = emit_stmt_list_tokens(
721 function,
722 then_rest,
723 source_file,
724 source_line,
725 mut_locals,
726 read_locals,
727 ref_ro_locals,
728 ref_mut_locals,
729 slice_locals,
730 function_names,
731 program_names,
732 );
733 let else_tokens = emit_stmt_list_tokens(
734 function,
735 else_body,
736 source_file,
737 source_line,
738 mut_locals,
739 read_locals,
740 ref_ro_locals,
741 ref_mut_locals,
742 slice_locals,
743 function_names,
744 program_names,
745 );
746 quote_in!(tokens =>
747 if let Some($binding_ptr_pattern) = $pointer_name.as_mut() {
748 let $binding_name = $binding_ptr_name.borrow_mut();
749 $then_tokens
750 } else {
751 $else_tokens
752 }
753 );
754 return Some((tokens, 2));
755 }
756 }
757
758 None
759}