ruka_codegen_wasm/codegen/wasm/
linker.rs

1use std::collections::BTreeMap;
2
3use ruka_runtime::wasm_api::{LoweringMode, RuntimeFnDescriptor, runtime_wasm_functions};
4use walrus::ir::Value;
5use walrus::{ConstExpr, DataKind, FunctionId, GlobalId, MemoryId, Module, ModuleConfig};
6
7use super::LowerError;
8
9/// Runtime function metadata and function IDs resolved from `ruka_runtime.wasm`.
10#[derive(Debug, Clone)]
11pub(crate) struct RuntimeFunctions {
12    by_symbol: BTreeMap<&'static str, RuntimeFunction>,
13}
14
15/// Resolved runtime function along with its ABI descriptor.
16#[derive(Debug, Clone, Copy)]
17pub(crate) struct RuntimeFunction {
18    pub(crate) function_id: FunctionId,
19    pub(crate) descriptor: RuntimeFnDescriptor,
20}
21
22/// Linked runtime module plus resolved runtime metadata.
23#[derive(Debug)]
24pub(crate) struct LinkedRuntime {
25    pub(crate) module: Module,
26    pub(crate) runtime: RuntimeFunctions,
27    pub(crate) memory_id: MemoryId,
28    pub(crate) string_literal_offsets: BTreeMap<String, u32>,
29}
30
31/// Parse runtime module and install string literals.
32pub(crate) fn link_runtime_with_literals(
33    string_literals: &[String],
34) -> Result<LinkedRuntime, LowerError> {
35    let mut module = load_runtime_module()?;
36    let runtime = resolve_runtime_functions(&module)?;
37    let memory_id = resolve_runtime_memory(&module)?;
38    let heap_base_global = resolve_exported_global(&module, "__heap_base")?;
39    let original_heap_base = read_const_i32_global(&module, heap_base_global, "__heap_base")?;
40    let (string_literal_offsets, heap_base) =
41        install_string_literals(&mut module, memory_id, original_heap_base, string_literals);
42    update_heap_base_global(&mut module, heap_base_global, heap_base)?;
43    Ok(LinkedRuntime {
44        module,
45        runtime,
46        memory_id,
47        string_literal_offsets,
48    })
49}
50
51/// Read one exported constant `i32` global value.
52fn read_const_i32_global(
53    module: &Module,
54    global_id: GlobalId,
55    export_name: &'static str,
56) -> Result<u32, LowerError> {
57    let global = module.globals.get(global_id);
58    let walrus::GlobalKind::Local(ConstExpr::Value(Value::I32(value))) = &global.kind else {
59        return Err(LowerError::MissingRuntimeGlobal(export_name));
60    };
61    u32::try_from(*value).map_err(|_| LowerError::Int32Overflow(export_name))
62}
63
64/// Load and parse the pre-generated runtime wasm module.
65fn load_runtime_module() -> Result<Module, LowerError> {
66    let bytes = include_bytes!(concat!(
67        env!("CARGO_MANIFEST_DIR"),
68        "/../rukalang_wasm/generated/ruka_runtime.wasm"
69    ));
70    ModuleConfig::new()
71        .parse(bytes)
72        .map_err(|error| LowerError::ParseRuntime(error.to_string()))
73}
74
75/// Resolve required runtime exports from the loaded runtime module.
76fn resolve_runtime_functions(module: &Module) -> Result<RuntimeFunctions, LowerError> {
77    let mut by_symbol = BTreeMap::<&'static str, RuntimeFunction>::new();
78    for descriptor in runtime_wasm_functions() {
79        if descriptor.lowering != LoweringMode::Direct {
80            continue;
81        }
82        let function_id = find_exported_function(module, descriptor.export_name).ok_or(
83            LowerError::MissingRuntimeExport {
84                symbol: descriptor.symbol.to_owned(),
85                export_name: descriptor.export_name.to_owned(),
86            },
87        )?;
88        let _ = by_symbol.insert(
89            descriptor.symbol,
90            RuntimeFunction {
91                function_id,
92                descriptor,
93            },
94        );
95    }
96    Ok(RuntimeFunctions { by_symbol })
97}
98
99/// Resolve the runtime module memory used for literal data segments.
100fn resolve_runtime_memory(module: &Module) -> Result<MemoryId, LowerError> {
101    if let Some(memory_id) = module.exports.iter().find_map(|export| {
102        if export.name == "memory" {
103            if let walrus::ExportItem::Memory(memory_id) = export.item {
104                return Some(memory_id);
105            }
106        }
107        None
108    }) {
109        return Ok(memory_id);
110    }
111
112    module
113        .memories
114        .iter()
115        .next()
116        .map(|memory| memory.id())
117        .ok_or(LowerError::MissingRuntimeMemory)
118}
119
120/// Resolve an exported runtime global by name.
121fn resolve_exported_global(
122    module: &Module,
123    export_name: &'static str,
124) -> Result<GlobalId, LowerError> {
125    module
126        .exports
127        .iter()
128        .find_map(|export| {
129            if export.name != export_name {
130                return None;
131            }
132            if let walrus::ExportItem::Global(global_id) = export.item {
133                Some(global_id)
134            } else {
135                None
136            }
137        })
138        .ok_or(LowerError::MissingRuntimeGlobal(export_name))
139}
140
141/// Update the exported heap base after injecting new literal data segments.
142fn update_heap_base_global(
143    module: &mut Module,
144    heap_base_global: GlobalId,
145    heap_base: u32,
146) -> Result<(), LowerError> {
147    let heap_base = i32::try_from(heap_base).map_err(|_| LowerError::Int32Overflow("heap base"))?;
148    let global = module.globals.get_mut(heap_base_global);
149    let walrus::GlobalKind::Local(init) = &mut global.kind else {
150        return Err(LowerError::MissingRuntimeGlobal("__heap_base"));
151    };
152    *init = ConstExpr::Value(Value::I32(heap_base));
153    Ok(())
154}
155
156/// Install UTF-8 string literal data segments and return pointer offsets.
157fn install_string_literals(
158    module: &mut Module,
159    memory: MemoryId,
160    reserved_start: u32,
161    string_literals: &[String],
162) -> (BTreeMap<String, u32>, u32) {
163    let mut cursor = next_data_offset(module).max(reserved_start);
164    let mut offsets = BTreeMap::<String, u32>::new();
165    for literal in string_literals {
166        if offsets.contains_key(literal) {
167            continue;
168        }
169        let bytes = literal.as_bytes();
170        let len = u32::try_from(bytes.len()).expect("string literal length should fit u32");
171        let mut payload = Vec::<u8>::with_capacity(8 + bytes.len());
172        payload.extend_from_slice(&0_u32.to_le_bytes());
173        payload.extend_from_slice(&len.to_le_bytes());
174        payload.extend_from_slice(bytes);
175        let offset = cursor;
176        module.data.add(
177            DataKind::Active {
178                memory,
179                offset: ConstExpr::Value(Value::I32(
180                    i32::try_from(offset).expect("string literal offset should fit i32"),
181                )),
182            },
183            payload,
184        );
185        let _ = offsets.insert(literal.clone(), offset);
186        cursor = align_up(offset.saturating_add(8).saturating_add(len), 8);
187    }
188    (offsets, cursor)
189}
190
191/// Compute the next safe offset for adding active data segments.
192fn next_data_offset(module: &Module) -> u32 {
193    let mut end = 0_u32;
194    for data in module.data.iter() {
195        if let DataKind::Active {
196            memory: _,
197            offset: ConstExpr::Value(Value::I32(start)),
198        } = &data.kind
199        {
200            let start = u32::try_from(*start).unwrap_or(0);
201            let len = u32::try_from(data.value.len()).expect("data segment length should fit u32");
202            end = end.max(start.saturating_add(len));
203        }
204    }
205    align_up(end.max(0x1000), 8)
206}
207
/// Round `offset` up to the next multiple of `alignment`.
///
/// `alignment` is expected to be a power of two, because the rounding is done with a bit
/// mask. An `alignment` of `0` leaves `offset` unchanged.
///
/// When rounding up would exceed `u32::MAX` the result saturates to `u32::MAX` instead of
/// overflowing. That keeps the value monotonic with the saturating arithmetic used by
/// callers, rather than panicking (debug) or wrapping around to a small offset (release).
fn align_up(offset: u32, alignment: u32) -> u32 {
    if alignment == 0 {
        return offset;
    }
    let mask = alignment - 1;
    match offset.checked_add(mask) {
        Some(bumped) => bumped & !mask,
        None => u32::MAX,
    }
}
216
217/// Find a function export by name and return its function ID.
218fn find_exported_function(module: &Module, export_name: &str) -> Option<FunctionId> {
219    module.exports.iter().find_map(|export| {
220        if export.name != export_name {
221            return None;
222        }
223        if let walrus::ExportItem::Function(function_id) = export.item {
224            Some(function_id)
225        } else {
226            None
227        }
228    })
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// Injecting a literal must move the exported heap base past the literal's bytes.
    #[test]
    fn link_runtime_advances_exported_heap_base_for_literals() {
        let literal = "count={}, retries={}";

        // Heap base of the untouched runtime; used as the lower bound for the checks below.
        let pristine = load_runtime_module().expect("runtime should load");
        let pristine_global =
            resolve_exported_global(&pristine, "__heap_base").expect("original heap base");
        let original_heap_base = read_const_i32_global(&pristine, pristine_global, "__heap_base")
            .expect("original heap base");

        let linked = link_runtime_with_literals(&[literal.to_owned()])
            .expect("runtime should link with literals");
        let linked_global =
            resolve_exported_global(&linked.module, "__heap_base").expect("heap base export");
        let actual_heap_base = match &linked.module.globals.get(linked_global).kind {
            walrus::GlobalKind::Local(ConstExpr::Value(Value::I32(value))) => *value,
            _ => panic!("heap base should remain a local const global"),
        };

        // The heap base must sit right after the literal's 8-byte header and text, aligned to 8.
        let literal_offset = linked
            .string_literal_offsets
            .get(literal)
            .copied()
            .expect("literal offset should be recorded");
        let literal_len = u32::try_from(literal.len()).expect("literal length should fit u32");
        let expected_heap_base = align_up(literal_offset.saturating_add(8 + literal_len), 8);

        assert_eq!(
            u32::try_from(actual_heap_base).ok(),
            Some(expected_heap_base)
        );
        for offset in linked.string_literal_offsets.values() {
            assert!(
                *offset >= original_heap_base,
                "literal offset should stay beyond the runtime reserved heap start"
            );
        }
        assert!(u32::try_from(actual_heap_base).ok().unwrap_or(0) >= original_heap_base);
    }
}
273
274impl RuntimeFunctions {
275    /// Resolve a runtime function by symbol.
276    pub(crate) fn resolve(&self, symbol: &str) -> Option<RuntimeFunction> {
277        self.by_symbol.get(symbol).copied()
278    }
279}