ruka_codegen_wasm/codegen/wasm/
linker.rs1use std::collections::BTreeMap;
2
3use ruka_runtime::wasm_api::{LoweringMode, RuntimeFnDescriptor, runtime_wasm_functions};
4use walrus::ir::Value;
5use walrus::{ConstExpr, DataKind, FunctionId, GlobalId, MemoryId, Module, ModuleConfig};
6
7use super::LowerError;
8
9#[derive(Debug, Clone)]
11pub(crate) struct RuntimeFunctions {
12 by_symbol: BTreeMap<&'static str, RuntimeFunction>,
13}
14
15#[derive(Debug, Clone, Copy)]
17pub(crate) struct RuntimeFunction {
18 pub(crate) function_id: FunctionId,
19 pub(crate) descriptor: RuntimeFnDescriptor,
20}
21
22#[derive(Debug)]
24pub(crate) struct LinkedRuntime {
25 pub(crate) module: Module,
26 pub(crate) runtime: RuntimeFunctions,
27 pub(crate) memory_id: MemoryId,
28 pub(crate) string_literal_offsets: BTreeMap<String, u32>,
29}
30
31pub(crate) fn link_runtime_with_literals(
33 string_literals: &[String],
34) -> Result<LinkedRuntime, LowerError> {
35 let mut module = load_runtime_module()?;
36 let runtime = resolve_runtime_functions(&module)?;
37 let memory_id = resolve_runtime_memory(&module)?;
38 let heap_base_global = resolve_exported_global(&module, "__heap_base")?;
39 let original_heap_base = read_const_i32_global(&module, heap_base_global, "__heap_base")?;
40 let (string_literal_offsets, heap_base) =
41 install_string_literals(&mut module, memory_id, original_heap_base, string_literals);
42 update_heap_base_global(&mut module, heap_base_global, heap_base)?;
43 Ok(LinkedRuntime {
44 module,
45 runtime,
46 memory_id,
47 string_literal_offsets,
48 })
49}
50
51fn read_const_i32_global(
53 module: &Module,
54 global_id: GlobalId,
55 export_name: &'static str,
56) -> Result<u32, LowerError> {
57 let global = module.globals.get(global_id);
58 let walrus::GlobalKind::Local(ConstExpr::Value(Value::I32(value))) = &global.kind else {
59 return Err(LowerError::MissingRuntimeGlobal(export_name));
60 };
61 u32::try_from(*value).map_err(|_| LowerError::Int32Overflow(export_name))
62}
63
64fn load_runtime_module() -> Result<Module, LowerError> {
66 let bytes = include_bytes!(concat!(
67 env!("CARGO_MANIFEST_DIR"),
68 "/../rukalang_wasm/generated/ruka_runtime.wasm"
69 ));
70 ModuleConfig::new()
71 .parse(bytes)
72 .map_err(|error| LowerError::ParseRuntime(error.to_string()))
73}
74
75fn resolve_runtime_functions(module: &Module) -> Result<RuntimeFunctions, LowerError> {
77 let mut by_symbol = BTreeMap::<&'static str, RuntimeFunction>::new();
78 for descriptor in runtime_wasm_functions() {
79 if descriptor.lowering != LoweringMode::Direct {
80 continue;
81 }
82 let function_id = find_exported_function(module, descriptor.export_name).ok_or(
83 LowerError::MissingRuntimeExport {
84 symbol: descriptor.symbol.to_owned(),
85 export_name: descriptor.export_name.to_owned(),
86 },
87 )?;
88 let _ = by_symbol.insert(
89 descriptor.symbol,
90 RuntimeFunction {
91 function_id,
92 descriptor,
93 },
94 );
95 }
96 Ok(RuntimeFunctions { by_symbol })
97}
98
99fn resolve_runtime_memory(module: &Module) -> Result<MemoryId, LowerError> {
101 if let Some(memory_id) = module.exports.iter().find_map(|export| {
102 if export.name == "memory" {
103 if let walrus::ExportItem::Memory(memory_id) = export.item {
104 return Some(memory_id);
105 }
106 }
107 None
108 }) {
109 return Ok(memory_id);
110 }
111
112 module
113 .memories
114 .iter()
115 .next()
116 .map(|memory| memory.id())
117 .ok_or(LowerError::MissingRuntimeMemory)
118}
119
120fn resolve_exported_global(
122 module: &Module,
123 export_name: &'static str,
124) -> Result<GlobalId, LowerError> {
125 module
126 .exports
127 .iter()
128 .find_map(|export| {
129 if export.name != export_name {
130 return None;
131 }
132 if let walrus::ExportItem::Global(global_id) = export.item {
133 Some(global_id)
134 } else {
135 None
136 }
137 })
138 .ok_or(LowerError::MissingRuntimeGlobal(export_name))
139}
140
141fn update_heap_base_global(
143 module: &mut Module,
144 heap_base_global: GlobalId,
145 heap_base: u32,
146) -> Result<(), LowerError> {
147 let heap_base = i32::try_from(heap_base).map_err(|_| LowerError::Int32Overflow("heap base"))?;
148 let global = module.globals.get_mut(heap_base_global);
149 let walrus::GlobalKind::Local(init) = &mut global.kind else {
150 return Err(LowerError::MissingRuntimeGlobal("__heap_base"));
151 };
152 *init = ConstExpr::Value(Value::I32(heap_base));
153 Ok(())
154}
155
156fn install_string_literals(
158 module: &mut Module,
159 memory: MemoryId,
160 reserved_start: u32,
161 string_literals: &[String],
162) -> (BTreeMap<String, u32>, u32) {
163 let mut cursor = next_data_offset(module).max(reserved_start);
164 let mut offsets = BTreeMap::<String, u32>::new();
165 for literal in string_literals {
166 if offsets.contains_key(literal) {
167 continue;
168 }
169 let bytes = literal.as_bytes();
170 let len = u32::try_from(bytes.len()).expect("string literal length should fit u32");
171 let mut payload = Vec::<u8>::with_capacity(8 + bytes.len());
172 payload.extend_from_slice(&0_u32.to_le_bytes());
173 payload.extend_from_slice(&len.to_le_bytes());
174 payload.extend_from_slice(bytes);
175 let offset = cursor;
176 module.data.add(
177 DataKind::Active {
178 memory,
179 offset: ConstExpr::Value(Value::I32(
180 i32::try_from(offset).expect("string literal offset should fit i32"),
181 )),
182 },
183 payload,
184 );
185 let _ = offsets.insert(literal.clone(), offset);
186 cursor = align_up(offset.saturating_add(8).saturating_add(len), 8);
187 }
188 (offsets, cursor)
189}
190
191fn next_data_offset(module: &Module) -> u32 {
193 let mut end = 0_u32;
194 for data in module.data.iter() {
195 if let DataKind::Active {
196 memory: _,
197 offset: ConstExpr::Value(Value::I32(start)),
198 } = &data.kind
199 {
200 let start = u32::try_from(*start).unwrap_or(0);
201 let len = u32::try_from(data.value.len()).expect("data segment length should fit u32");
202 end = end.max(start.saturating_add(len));
203 }
204 }
205 align_up(end.max(0x1000), 8)
206}
207
/// Rounds `offset` up to the next multiple of `alignment`.
///
/// `alignment` is expected to be a power of two; the mask arithmetic is only
/// meaningful in that case. An `alignment` of `0` returns `offset` unchanged.
///
/// The addition saturates instead of overflowing: when rounding up would
/// exceed `u32::MAX`, the result is the largest aligned value a `u32` can
/// hold. This avoids a debug-build panic and a silent wrap-around to a tiny
/// offset in release builds.
fn align_up(offset: u32, alignment: u32) -> u32 {
    if alignment == 0 {
        return offset;
    }
    let mask = alignment - 1;
    offset.saturating_add(mask) & !mask
}
216
217fn find_exported_function(module: &Module, export_name: &str) -> Option<FunctionId> {
219 module.exports.iter().find_map(|export| {
220 if export.name != export_name {
221 return None;
222 }
223 if let walrus::ExportItem::Function(function_id) = export.item {
224 Some(function_id)
225 } else {
226 None
227 }
228 })
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// Linking with a literal must keep the literal at or beyond the
    /// runtime's original `__heap_base` and move the exported heap base to
    /// the aligned end of the literal record.
    #[test]
    fn link_runtime_advances_exported_heap_base_for_literals() {
        let literal = "count={}, retries={}";

        // Heap base of the untouched runtime module.
        let pristine = load_runtime_module().expect("runtime should load");
        let original_heap_base = resolve_exported_global(&pristine, "__heap_base")
            .and_then(|global_id| read_const_i32_global(&pristine, global_id, "__heap_base"))
            .expect("original heap base");

        // Heap base after linking with one literal.
        let linked = link_runtime_with_literals(&[literal.to_owned()])
            .expect("runtime should link with literals");
        let linked_global_id =
            resolve_exported_global(&linked.module, "__heap_base").expect("heap base export");
        let actual_heap_base = match &linked.module.globals.get(linked_global_id).kind {
            walrus::GlobalKind::Local(ConstExpr::Value(Value::I32(value))) => {
                u32::try_from(*value).ok()
            }
            _ => panic!("heap base should remain a local const global"),
        };

        let literal_offset = *linked
            .string_literal_offsets
            .get(literal)
            .expect("literal offset should be recorded");
        let literal_bytes = u32::try_from(literal.len()).expect("literal length should fit u32");
        let expected_heap_base = align_up(literal_offset.saturating_add(8 + literal_bytes), 8);

        assert_eq!(actual_heap_base, Some(expected_heap_base));
        assert!(
            linked
                .string_literal_offsets
                .values()
                .all(|offset| *offset >= original_heap_base),
            "literal offset should stay beyond the runtime reserved heap start"
        );
        assert!(actual_heap_base.unwrap_or(0) >= original_heap_base);
    }
}
273
274impl RuntimeFunctions {
275 pub(crate) fn resolve(&self, symbol: &str) -> Option<RuntimeFunction> {
277 self.by_symbol.get(symbol).copied()
278 }
279}