// ruka_codegen_wasm/lower_function.rs

1use super::*;
2
/// Lower one MIR function body into its pre-declared WASM function.
///
/// Control flow is lowered as a *pc-dispatch loop*: every MIR block is given
/// an integer pc, and the whole body is a single infinite `loop` that tests
/// `pc_local` against each block's pc and executes the matching block.
/// Terminators either store a new pc and branch back to the loop head, or emit
/// a WASM `return` (see `lower_terminator`).
pub(crate) fn lower_function(
    module: &mut walrus::Module,
    wasm_func_id: FunctionId,
    function: &ruka_mir::MirFunction,
    module_ctx: &ModuleLowerCtx<'_>,
) -> Result<(), LowerError> {
    // The function shell was declared earlier; clear any placeholder
    // instructions from its entry block before emitting the real body.
    let local_function = module.funcs.get_mut(wasm_func_id).kind.unwrap_local_mut();
    let params = local_function.args.clone();
    let entry_block = local_function.entry_block();
    local_function.block_mut(entry_block).instrs.clear();

    // Per-MIR-local bookkeeping, keyed by MIR local id. `None` entries in
    // `local_indices` / `local_runtime_types` mark unit-typed locals that have
    // no runtime representation at all.
    let mut local_indices = BTreeMap::<u32, Option<LocalId>>::new();
    let mut local_runtime_types = BTreeMap::<u32, Option<ValType>>::new();
    let mut local_tys = BTreeMap::<u32, Ty>::new();
    let mut local_reprs = BTreeMap::<u32, ruka_mir::MirLocalRepr>::new();
    let mut local_heap_ownership = BTreeMap::<u32, ruka_mir::MirHeapOwnership>::new();
    let mut passthrough_place_locals = BTreeSet::<u32>::new();
    // When the return value travels through a caller-provided out slot, WASM
    // param 0 is the slot pointer, so MIR params start at runtime index 1.
    let mut param_index = if function_returns_via_out_slot(&function.return_ty) {
        1
    } else {
        0
    };
    for binding in function.param_bindings() {
        if binding.requires_materialization() {
            assert!(
                binding.expects_view(),
                "only view params may require WASM materialization"
            );
            assert!(
                !binding.local.is_place(),
                "materialized WASM params should lower to value locals"
            );
        }
        let _ = local_tys.insert(binding.local_id.as_u32(), binding.local.ty.clone());
        let _ = local_reprs.insert(binding.local_id.as_u32(), binding.local.repr);
        let _ =
            local_heap_ownership.insert(binding.local_id.as_u32(), binding.local.heap_ownership);
        if matches!(binding.local.ty, Ty::Unit) {
            // Unit-typed params consume no WASM param slot.
            local_indices.insert(binding.local_id.as_u32(), None);
            local_runtime_types.insert(binding.local_id.as_u32(), None);
        } else {
            let runtime_local = params[param_index];
            // Inout borrowed params are passed by value and handed back as
            // extra results (see the trailing pushes in `lower_terminator`'s
            // Return arm). Mutable-borrow inout params additionally keep the
            // semantic (pointee) value type rather than a pointer.
            let inout_param = borrowed_param_uses_inout(binding);
            let passthrough_place_param = inout_param && binding.expects_mut_borrow();
            let runtime_ty = if passthrough_place_param {
                ty_to_valtype(binding.semantic_ty())
            } else {
                ty_to_valtype(&binding.local.ty)
            };
            let source_ty = if inout_param {
                ty_to_valtype(binding.semantic_ty())
            } else {
                param_binding_valtype(binding)
            };
            // The declared WASM signature and the local's runtime view must
            // agree, or loads/stores through the local would be mistyped.
            assert_eq!(
                runtime_ty, source_ty,
                "WASM param local/runtime type mismatch for `{}`",
                function.name
            );
            assert!(
                !binding.local.is_place() || passthrough_place_param || runtime_ty == ValType::I32,
                "place-shaped locals must lower to pointer-sized runtime values"
            );
            if (binding.expects_view() && binding.local.is_place()) || passthrough_place_param {
                let _ = passthrough_place_locals.insert(binding.local_id.as_u32());
            }
            local_indices.insert(binding.local_id.as_u32(), Some(runtime_local));
            local_runtime_types.insert(binding.local_id.as_u32(), Some(runtime_ty));
            param_index += 1;
        }
    }

    // Declare a fresh WASM local for every non-param MIR local (unit-typed
    // locals again get `None` placeholders).
    for (local_id, info) in function.locals.iter() {
        if function.params.contains(&local_id) {
            continue;
        }
        let _ = local_tys.insert(local_id.as_u32(), info.ty.clone());
        let _ = local_reprs.insert(local_id.as_u32(), info.repr);
        let _ = local_heap_ownership.insert(local_id.as_u32(), info.heap_ownership);
        if matches!(info.ty, Ty::Unit) {
            local_indices.insert(local_id.as_u32(), None);
            local_runtime_types.insert(local_id.as_u32(), None);
            continue;
        }
        let runtime_local = module.locals.add(ty_to_valtype(&info.ty));
        local_indices.insert(local_id.as_u32(), Some(runtime_local));
        local_runtime_types.insert(local_id.as_u32(), Some(ty_to_valtype(&info.ty)));
    }

    // Lay out shadow-stack slots for locals that need addressable aggregate
    // storage. For references to slices the slot is sized from the referenced
    // slice aggregate itself, not the reference. Each slot is 8-byte aligned
    // and reserves `ARRAY_DATA_OFFSET` bytes ahead of the payload —
    // presumably an array-style header; confirm against the runtime layout.
    let mut shadow_stack_offsets = BTreeMap::<u32, u32>::new();
    let mut shadow_stack_frame_bytes = 0_u32;
    for (local_id, info) in function.locals.iter() {
        if function.params.contains(&local_id) || !should_shadow_stack_local(info) {
            continue;
        }
        let aggregate_ty = match &info.ty {
            Ty::RefRo(inner) | Ty::RefMut(inner) if matches!(inner.as_ref(), Ty::Slice(_)) => {
                inner.as_ref()
            }
            _ => &info.ty,
        };
        let payload_bytes =
            aggregate::aggregate_payload_bytes(aggregate_ty, module_ctx.structs, module_ctx.enums)?;
        let slot_bytes = align_up_u32(ARRAY_DATA_OFFSET.saturating_add(payload_bytes), 8);
        shadow_stack_frame_bytes = align_up_u32(shadow_stack_frame_bytes, 8);
        let _ = shadow_stack_offsets.insert(local_id.as_u32(), shadow_stack_frame_bytes);
        shadow_stack_frame_bytes = shadow_stack_frame_bytes.saturating_add(slot_bytes);
    }

    // Scratch locals shared by instruction/terminator lowering, plus the pc
    // register driving the dispatch loop.
    let scratch_i64_local = module.locals.add(ValType::I64);
    let scratch_i32_local = module.locals.add(ValType::I32);
    let scratch_i32_local_b = module.locals.add(ValType::I32);
    let scratch_i32_local_c = module.locals.add(ValType::I32);
    let scratch_i32_local_d = module.locals.add(ValType::I32);
    let scratch_i32_local_e = module.locals.add(ValType::I32);
    let pc_local = module.locals.add(ValType::I32);
    // Size the parallel-copy staging pools (used by `assign_block_params`) to
    // the widest block-param list of each runtime type in this function.
    let max_parallel_i64_params = function
        .blocks
        .values()
        .map(|block| {
            block
                .params
                .iter()
                .filter(|local| {
                    local_runtime_types
                        .get(&local.as_u32())
                        .and_then(|ty| *ty)
                        .map(|ty| ty == ValType::I64)
                        .unwrap_or(false)
                })
                .count()
        })
        .max()
        .unwrap_or(0);
    let max_parallel_i32_params = function
        .blocks
        .values()
        .map(|block| {
            block
                .params
                .iter()
                .filter(|local| {
                    local_runtime_types
                        .get(&local.as_u32())
                        .and_then(|ty| *ty)
                        .map(|ty| ty == ValType::I32)
                        .unwrap_or(false)
                })
                .count()
        })
        .max()
        .unwrap_or(0);
    let mut parallel_tmp_i64_locals = Vec::<LocalId>::with_capacity(max_parallel_i64_params);
    for _ in 0..max_parallel_i64_params {
        parallel_tmp_i64_locals.push(module.locals.add(ValType::I64));
    }
    let mut parallel_tmp_i32_locals = Vec::<LocalId>::with_capacity(max_parallel_i32_params);
    for _ in 0..max_parallel_i32_params {
        parallel_tmp_i32_locals.push(module.locals.add(ValType::I32));
    }

    // Assign each MIR block a dense pc in iteration order.
    let block_id_to_pc = function
        .blocks
        .keys()
        .enumerate()
        .map(|(index, block_id)| (block_id.as_u32(), index as i32))
        .collect::<BTreeMap<u32, i32>>();

    // Errors inside the walrus builder closures below cannot be propagated
    // with `?`; they are stashed here and checked after building finishes.
    let mut lower_error = None::<LowerError>;
    let returns_via_out_slot = function_returns_via_out_slot(&function.return_ty);
    let out_slot_param_local = if returns_via_out_slot {
        Some(params[0])
    } else {
        None
    };
    let out_slot_payload_bytes = if returns_via_out_slot {
        aggregate::aggregate_payload_bytes(
            &function.return_ty,
            module_ctx.structs,
            module_ctx.enums,
        )?
    } else {
        0
    };
    let lower_ctx = LowerCtx {
        function,
        local_indices: &local_indices,
        local_runtime_types: &local_runtime_types,
        local_tys: &local_tys,
        local_reprs: &local_reprs,
        local_heap_ownership: &local_heap_ownership,
        passthrough_place_locals: &passthrough_place_locals,
        shadow_stack_offsets: &shadow_stack_offsets,
        func_id_by_mir: module_ctx.func_id_by_mir,
        callee_param_runtime_mask: module_ctx.callee_param_runtime_mask,
        callee_param_runtime_types: module_ctx.callee_param_runtime_types,
        callee_param_inout_mask: module_ctx.callee_param_inout_mask,
        callee_returns_via_out_slot: module_ctx.callee_returns_via_out_slot,
        callee_returns_runtime: module_ctx.callee_returns_runtime,
        runtime: module_ctx.runtime,
        pointer_drop_functions: module_ctx.pointer_drop_functions,
        memory_id: module_ctx.memory_id,
        string_literal_offsets: module_ctx.string_literal_offsets,
        structs: module_ctx.structs,
        enums: module_ctx.enums,
        // Source line for diagnostics; 0 when unknown or out of i32 range.
        function_line: module_ctx
            .function_lines
            .get(&function.name)
            .copied()
            .and_then(|line| i32::try_from(line).ok())
            .unwrap_or(0),
        scratch_i64_local,
        scratch_i32_local,
        scratch_i32_local_b,
        scratch_i32_local_c,
        scratch_i32_local_d,
        scratch_i32_local_e,
    };
    let mut body = local_function.builder_mut().instr_seq(entry_block);
    // Prologue: reserve this function's shadow-stack frame and point each
    // shadow-stacked local at its slot (frame base + per-local offset).
    if shadow_stack_frame_bytes > 0 {
        let reserve =
            runtime_function(module_ctx.runtime, wasm_api::RT_SHADOW_STACK_RESERVE_SYMBOL)?;
        body.i32_const(int32_from_u32(
            shadow_stack_frame_bytes,
            "shadow stack frame bytes",
        )?)
        .call(reserve.function_id)
        .local_set(scratch_i32_local);
        for (local_id, offset) in &shadow_stack_offsets {
            let local = runtime_local_index(&local_indices, *local_id, "shadow stack local")?;
            body.local_get(scratch_i32_local)
                .i32_const(int32_from_u32(*offset, "shadow stack local offset")?)
                .binop(BinaryOp::I32Add)
                .local_set(local);
        }
    }
    // Seed the pc with the MIR entry block, then emit the dispatch loop: one
    // `if pc == N` arm per block, each ending in a terminator that either
    // branches back to the loop head with a new pc or returns.
    body.i32_const(block_pc(&block_id_to_pc, function.entry)?)
        .local_set(pc_local);
    body.block(None, |outer| {
        outer.loop_(None, |loop_| {
            let loop_id = loop_.id();
            for (block_id, block) in function.blocks.iter() {
                // Once an error is recorded, stop emitting further blocks.
                if lower_error.is_some() {
                    return;
                }
                let pc = match block_pc(&block_id_to_pc, block_id) {
                    Ok(pc) => pc,
                    Err(error) => {
                        lower_error = Some(error);
                        return;
                    }
                };
                loop_
                    .local_get(pc_local)
                    .i32_const(pc)
                    .binop(BinaryOp::I32Eq)
                    .if_else(
                        None,
                        |then_| {
                            for instr in &block.instrs {
                                if let Err(error) = lower_instr(instr, then_, &lower_ctx) {
                                    lower_error = Some(error);
                                    return;
                                }
                            }

                            let terminator_ctx = TerminatorCtx {
                                function,
                                local_indices: &local_indices,
                                local_runtime_types: &local_runtime_types,
                                local_heap_ownership: &local_heap_ownership,
                                runtime: module_ctx.runtime,
                                pointer_drop_functions: module_ctx.pointer_drop_functions,
                                memory_id: module_ctx.memory_id,
                                structs: module_ctx.structs,
                                enums: module_ctx.enums,
                                block_id_to_pc: &block_id_to_pc,
                                pc_local,
                                // "break" from a terminator re-enters the loop
                                // head so the next pc gets dispatched.
                                break_target: loop_id,
                                out_slot_param_local,
                                out_slot_payload_bytes,
                                scratch_i32_local,
                                scratch_i32_local_b,
                                scratch_i32_local_c,
                                scratch_i32_local_d,
                                scratch_i64_local,
                                shadow_stack_frame_bytes,
                                parallel_tmp_i64_locals: &parallel_tmp_i64_locals,
                                parallel_tmp_i32_locals: &parallel_tmp_i32_locals,
                            };
                            if let Err(error) =
                                lower_terminator(&block.terminator, then_, &terminator_ctx)
                            {
                                lower_error = Some(error);
                            }
                        },
                        |_else_| {},
                    );
            }
            loop_.br(loop_id);
        });
    });
    // Unreachable filler: the dispatch loop above only exits via an explicit
    // WASM `return`, so execution never falls through to here. These zero
    // constants mirror the function's declared results (inout param results
    // first, then the direct return value) — NOTE(review): presumably needed
    // to keep the fallthrough path well-typed for walrus/WASM validation
    // rather than relying on unreachable-code polymorphism; confirm.
    for binding in function.param_bindings() {
        if !borrowed_param_uses_inout(binding) {
            continue;
        }
        match ty_to_valtype(binding.semantic_ty()) {
            ValType::I32 => {
                body.i32_const(0);
            }
            ValType::I64 => {
                body.i64_const(0);
            }
            ValType::F32 => {
                body.f32_const(0.0);
            }
            ValType::F64 => {
                body.f64_const(0.0);
            }
            _ => {}
        }
    }
    if function.return_ty != Ty::Unit && !returns_via_out_slot {
        match ty_to_valtype(&function.return_ty) {
            ValType::I32 => {
                body.i32_const(0);
            }
            ValType::I64 => {
                body.i64_const(0);
            }
            ValType::F32 => {
                body.f32_const(0.0);
            }
            ValType::F64 => {
                body.f64_const(0.0);
            }
            _ => {}
        }
    }

    // Surface any error captured inside the builder closures.
    if let Some(error) = lower_error {
        return Err(error);
    }
    Ok(())
}
348
/// Lower a MIR terminator into WALRUS instructions.
///
/// All block-to-block control flow goes through the pc-dispatch loop built by
/// `lower_function`: jumps and branches store the destination block's pc into
/// `ctx.pc_local` and branch to `ctx.break_target` (the loop head); returns
/// run cleanup (heap releases, inout result pushes, shadow-stack release) and
/// then emit a WASM `return`.
pub(crate) fn lower_terminator(
    terminator: &ruka_mir::MirTerminator,
    body: &mut walrus::InstrSeqBuilder,
    ctx: &TerminatorCtx<'_>,
) -> Result<(), LowerError> {
    match terminator {
        ruka_mir::MirTerminator::Jump { target, args } => {
            // Move jump args into the target block's params (in parallel),
            // then hop to the target via the dispatch loop.
            assign_block_params(
                ctx.function,
                body,
                ctx.local_indices,
                ctx.local_runtime_types,
                *target,
                args,
                ctx.parallel_tmp_i64_locals,
                ctx.parallel_tmp_i32_locals,
            )?;
            body.i32_const(block_pc(ctx.block_id_to_pc, *target)?)
                .local_set(ctx.pc_local)
                .br(ctx.break_target);
            Ok(())
        }
        ruka_mir::MirTerminator::Branch {
            cond,
            then_target,
            then_args,
            else_target,
            else_args,
        } => {
            let cond_local =
                runtime_local_index(ctx.local_indices, cond.as_u32(), "branch condition")?;
            let then_pc = block_pc(ctx.block_id_to_pc, *then_target)?;
            let else_pc = block_pc(ctx.block_id_to_pc, *else_target)?;
            body.local_get(cond_local).if_else(
                None,
                |then_body| {
                    let then_result = assign_block_params(
                        ctx.function,
                        then_body,
                        ctx.local_indices,
                        ctx.local_runtime_types,
                        *then_target,
                        then_args,
                        ctx.parallel_tmp_i64_locals,
                        ctx.parallel_tmp_i32_locals,
                    );
                    // Closures cannot propagate `Result`; on failure the arm
                    // simply omits its pc-store/branch (error surfaced below).
                    if then_result.is_ok() {
                        then_body
                            .i32_const(then_pc)
                            .local_set(ctx.pc_local)
                            .br(ctx.break_target);
                    }
                },
                |else_body| {
                    let else_result = assign_block_params(
                        ctx.function,
                        else_body,
                        ctx.local_indices,
                        ctx.local_runtime_types,
                        *else_target,
                        else_args,
                        ctx.parallel_tmp_i64_locals,
                        ctx.parallel_tmp_i32_locals,
                    );
                    if else_result.is_ok() {
                        else_body
                            .i32_const(else_pc)
                            .local_set(ctx.pc_local)
                            .br(ctx.break_target);
                    }
                },
            );

            // Re-run both arms' param assignments outside the closures so any
            // error swallowed above is surfaced to the caller via `?`. On the
            // success path both `if_else` arms have already branched away, so
            // the instructions emitted here never execute.
            assign_block_params(
                ctx.function,
                body,
                ctx.local_indices,
                ctx.local_runtime_types,
                *then_target,
                then_args,
                ctx.parallel_tmp_i64_locals,
                ctx.parallel_tmp_i32_locals,
            )?;
            assign_block_params(
                ctx.function,
                body,
                ctx.local_indices,
                ctx.local_runtime_types,
                *else_target,
                else_args,
                ctx.parallel_tmp_i64_locals,
                ctx.parallel_tmp_i32_locals,
            )?;
            Ok(())
        }
        ruka_mir::MirTerminator::Return { value } => {
            // Aggregate returns: copy the return value's payload bytes into
            // the caller-provided out slot (WASM param 0).
            if function_returns_via_out_slot(&ctx.function.return_ty) {
                let src_ptr =
                    runtime_local_index(ctx.local_indices, value.as_u32(), "return value ptr")?;
                let dst_ptr =
                    ctx.out_slot_param_local
                        .ok_or(LowerError::UnsupportedInstruction(
                            "missing out-slot return param",
                        ))?;
                emit_copy_bytes(
                    body,
                    ctx.memory_id,
                    src_ptr,
                    dst_ptr,
                    ctx.out_slot_payload_bytes,
                    ctx.scratch_i32_local,
                    ctx.scratch_i32_local_b,
                )?;
            }
            // Release heap-owning locals before returning. A value returned
            // directly (not via out slot) is skipped so ownership transfers to
            // the caller. NOTE(review): with an out-slot return the source
            // local is NOT skipped and is still released after the byte copy
            // above — confirm the copied payload cannot share allocations
            // freed here.
            for (local_id, info) in ctx.function.locals.iter() {
                if ctx.function.return_ty != Ty::Unit
                    && !function_returns_via_out_slot(&ctx.function.return_ty)
                    && local_id == *value
                {
                    continue;
                }
                let ownership = local_heap_ownership(
                    ctx.local_heap_ownership,
                    local_id.as_u32(),
                    "return local ownership",
                )?;
                if !ownership.uses_heap_ops() {
                    continue;
                }
                // Unit locals have no runtime representation — nothing to drop.
                let Some(runtime_local) = runtime_local(ctx.local_indices, local_id.as_u32())?
                else {
                    continue;
                };
                match &info.ty {
                    // NOTE(review): scratch_i32_local_d is passed before _c
                    // here, unlike the other release helpers below — confirm
                    // this matches emit_pointer_release's parameter order.
                    Ty::Pointer(item) => emit_pointer_release(
                        body,
                        ctx.runtime,
                        ctx.memory_id,
                        runtime_local,
                        item.as_ref(),
                        ctx.structs,
                        ctx.enums,
                        ctx.pointer_drop_functions,
                        ctx.scratch_i32_local,
                        ctx.scratch_i32_local_b,
                        ctx.scratch_i32_local_d,
                        ctx.scratch_i32_local_c,
                        ctx.scratch_i64_local,
                    )?,
                    Ty::String => emit_string_release(
                        body,
                        ctx.runtime,
                        ctx.memory_id,
                        runtime_local,
                        ctx.scratch_i32_local,
                        ctx.scratch_i32_local_b,
                        ctx.scratch_i32_local_c,
                    )?,
                    // Shallow ownership frees only the array storage itself;
                    // deep ownership also releases each element.
                    Ty::Array { item, .. } | Ty::Slice(item) => {
                        if matches!(ownership, ruka_mir::MirHeapOwnership::OwnedShallow) {
                            emit_array_release_shallow(
                                body,
                                ctx.runtime,
                                ctx.memory_id,
                                runtime_local,
                                ctx.scratch_i32_local,
                                ctx.scratch_i32_local_b,
                                ctx.scratch_i32_local_c,
                                ctx.scratch_i32_local_d,
                                ctx.scratch_i64_local,
                            )?;
                        } else {
                            emit_array_release(
                                body,
                                ctx.runtime,
                                ctx.memory_id,
                                runtime_local,
                                item.as_ref(),
                                ctx.structs,
                                ctx.enums,
                                ctx.pointer_drop_functions,
                                ctx.scratch_i32_local,
                                ctx.scratch_i32_local_b,
                                ctx.scratch_i32_local_c,
                                ctx.scratch_i32_local_d,
                                ctx.scratch_i64_local,
                            )?;
                        }
                    }
                    Ty::Enum { .. } => emit_enum_release(
                        body,
                        ctx.runtime,
                        ctx.memory_id,
                        runtime_local,
                        &info.ty,
                        ctx.structs,
                        ctx.enums,
                        ctx.pointer_drop_functions,
                        ctx.scratch_i32_local,
                        ctx.scratch_i32_local_b,
                        ctx.scratch_i32_local_c,
                        ctx.scratch_i32_local_d,
                        ctx.scratch_i64_local,
                    )?,
                    _ => {}
                }
            }
            // Push current values of inout borrowed params: they are returned
            // as extra results ahead of the direct return value.
            for binding in ctx.function.param_bindings() {
                if !borrowed_param_uses_inout(binding) {
                    continue;
                }
                if matches!(binding.local.ty, Ty::Unit) {
                    continue;
                }
                let local = runtime_local_index(
                    ctx.local_indices,
                    binding.local_id.as_u32(),
                    "inout return local",
                )?;
                body.local_get(local);
            }
            // Push the direct return value (non-unit, non-out-slot returns).
            if ctx.function.return_ty != Ty::Unit
                && !function_returns_via_out_slot(&ctx.function.return_ty)
            {
                let value_index =
                    runtime_local_index(ctx.local_indices, value.as_u32(), "return value")?;
                body.local_get(value_index);
            }
            // Epilogue: release this function's shadow-stack frame, mirroring
            // the reserve call emitted in the prologue.
            if ctx.shadow_stack_frame_bytes > 0 {
                let release =
                    runtime_function(ctx.runtime, wasm_api::RT_SHADOW_STACK_RELEASE_SYMBOL)?;
                body.i32_const(int32_from_u32(
                    ctx.shadow_stack_frame_bytes,
                    "shadow stack frame bytes",
                )?)
                .call(release.function_id);
            }
            body.instr(walrus::ir::Return {});
            Ok(())
        }
    }
}
592
593/// Assign jump or branch arguments to destination block parameters in parallel.
594pub(crate) fn assign_block_params(
595    function: &ruka_mir::MirFunction,
596    body: &mut walrus::InstrSeqBuilder,
597    local_indices: &BTreeMap<u32, Option<LocalId>>,
598    local_runtime_types: &BTreeMap<u32, Option<ValType>>,
599    target: ruka_mir::MirBlockId,
600    args: &[ruka_mir::MirLocalId],
601    parallel_tmp_i64_locals: &[LocalId],
602    parallel_tmp_i32_locals: &[LocalId],
603) -> Result<(), LowerError> {
604    let target_block = function
605        .blocks
606        .get(target)
607        .ok_or(LowerError::MissingBlock(target.as_u32()))?;
608    if target_block.params.len() != args.len() {
609        return Err(LowerError::BlockParamArityMismatch);
610    }
611
612    let mut runtime_pairs_i64 = Vec::<(LocalId, LocalId)>::new();
613    let mut runtime_pairs_i32 = Vec::<(LocalId, LocalId)>::new();
614    for (param, arg) in target_block.params.iter().zip(args.iter()) {
615        let dst = local_indices.get(&param.as_u32()).copied().ok_or(
616            LowerError::UnsupportedInstruction("missing block param local"),
617        )?;
618        if let Some(dst_local) = dst {
619            let src_local = runtime_local_index(local_indices, arg.as_u32(), "jump arg")?;
620            match local_runtime_types
621                .get(&param.as_u32())
622                .and_then(|ty| *ty)
623                .ok_or(LowerError::UnsupportedInstruction(
624                    "missing block param type",
625                ))? {
626                ValType::I64 => runtime_pairs_i64.push((src_local, dst_local)),
627                ValType::I32 => runtime_pairs_i32.push((src_local, dst_local)),
628                _ => {
629                    return Err(LowerError::UnsupportedInstruction(
630                        "unsupported block param type",
631                    ));
632                }
633            }
634        }
635    }
636
637    if runtime_pairs_i64.len() > parallel_tmp_i64_locals.len() {
638        return Err(LowerError::UnsupportedInstruction(
639            "insufficient i64 parallel copy locals",
640        ));
641    }
642    if runtime_pairs_i32.len() > parallel_tmp_i32_locals.len() {
643        return Err(LowerError::UnsupportedInstruction(
644            "insufficient i32 parallel copy locals",
645        ));
646    }
647
648    for (index, (src, _)) in runtime_pairs_i64.iter().enumerate() {
649        body.local_get(*src)
650            .local_set(parallel_tmp_i64_locals[index]);
651    }
652    for (index, (src, _)) in runtime_pairs_i32.iter().enumerate() {
653        body.local_get(*src)
654            .local_set(parallel_tmp_i32_locals[index]);
655    }
656    for (index, (_, dst)) in runtime_pairs_i64.iter().enumerate() {
657        body.local_get(parallel_tmp_i64_locals[index])
658            .local_set(*dst);
659    }
660    for (index, (_, dst)) in runtime_pairs_i32.iter().enumerate() {
661        body.local_get(parallel_tmp_i32_locals[index])
662            .local_set(*dst);
663    }
664
665    Ok(())
666}