rukalang/elab/bind_template_calls.rs

1use super::helpers::{is_type_param, ty_to_type_expr};
2use super::prelude::*;
3
4/// Bind runtime template call arguments into specialization inputs.
5///
6/// Input invariant: `args` belongs to one runtime template call.
7/// Output invariant: returned binding contains a deterministic specialization key
8/// for the same argument sequence.
9///
10/// This pass resolves type arguments, binds meta parameters, materializes
11/// variadic packs, and builds a stable specialization key used by runtime
12/// template instantiation cache lookup.
13///
14/// Pass name: `elab.bind_template_call_args`.
15struct BindTemplateCallArgsPass<'a> {
16    elaborator: &'a mut Elaborator,
17    callee_name: &'a str,
18    template: &'a RuntimeFunctionTemplate,
19    locals: &'a mut Vec<(String, Ty)>,
20}
21
22impl<'a> Pass for BindTemplateCallArgsPass<'a> {
23    type In = Vec<CallArg>;
24    type Out = TemplateCallBinding;
25    type Error = ElabError;
26
27    const NAME: &'static str = "elab.bind_template_call_args";
28
29    fn run(&mut self, mut args: Self::In, _cx: &mut PassContext) -> Result<Self::Out, Self::Error> {
30        let mut type_bindings = BTreeMap::new();
31        let mut meta_bindings = BTreeMap::new();
32        let mut runtime_args = Vec::new();
33        let mut specialization_values = Vec::new();
34        let mut arg_index = 0usize;
35        let type_params = self
36            .template
37            .type_params
38            .iter()
39            .cloned()
40            .collect::<BTreeSet<_>>();
41
42        for param in &self.template.function.params {
43            if is_type_param(param) {
44                let Some(arg) = args.get_mut(arg_index) else {
45                    return Err(ElabError::ArityMismatch {
46                        function: self.callee_name.to_owned(),
47                        expected: self.template.function.params.len(),
48                        actual: args.len(),
49                    });
50                };
51                let CallArg::Type(ty) = arg else {
52                    return Err(ElabError::UnsupportedTypeExpr);
53                };
54                let resolved = self.elaborator.resolve_type_expr(
55                    ty,
56                    &BTreeMap::new(),
57                    &BTreeSet::new(),
58                    None,
59                )?;
60                self.elaborator.ensure_ty_instantiated(&resolved)?;
61                *ty = ty_to_type_expr(&resolved);
62                type_bindings.insert(param.name.clone(), resolved.clone());
63                specialization_values.push(SpecializationValue::Type(resolved));
64                arg_index += 1;
65            } else if param.is_meta {
66                let Some(arg) = args.get_mut(arg_index) else {
67                    return Err(ElabError::ArityMismatch {
68                        function: self.callee_name.to_owned(),
69                        expected: self.template.function.params.len(),
70                        actual: args.len(),
71                    });
72                };
73                let CallArg::Expr(expr) = arg else {
74                    return Err(ElabError::UnsupportedTypeExpr);
75                };
76                let binding = self
77                    .elaborator
78                    .bind_meta_arg(self.callee_name, param, expr)?;
79                specialization_values.push(match &binding {
80                    TemplateMetaBinding::Int(value) => SpecializationValue::Int(*value),
81                    TemplateMetaBinding::String(value) => {
82                        SpecializationValue::String(value.clone())
83                    }
84                    TemplateMetaBinding::Type(_) => {
85                        return Err(ElabError::UnsupportedMetaParamType {
86                            name: self.callee_name.to_owned(),
87                            param: param.name.clone(),
88                        });
89                    }
90                });
91                meta_bindings.insert(param.name.clone(), binding);
92                arg_index += 1;
93            } else if param.is_variadic {
94                let mut pack_exprs = Vec::new();
95                let mut pack_tys = Vec::new();
96                while arg_index < args.len() {
97                    let CallArg::Expr(expr) = &mut args[arg_index] else {
98                        return Err(ElabError::UnsupportedTypeExpr);
99                    };
100                    let ty = self.elaborator.elaborate_expr(expr, None, self.locals)?;
101                    pack_tys.push(ty);
102                    pack_exprs.push(expr.clone());
103                    arg_index += 1;
104                }
105                let tuple_ty = if pack_tys.is_empty() {
106                    Ty::Unit
107                } else {
108                    Ty::Tuple(pack_tys.clone())
109                };
110                runtime_args.extend(pack_exprs);
111                meta_bindings.insert(
112                    format!("{}__types", param.name),
113                    TemplateMetaBinding::Type(ty_to_type_expr(&tuple_ty)),
114                );
115                specialization_values.push(SpecializationValue::Pack(pack_tys));
116                return Ok(TemplateCallBinding {
117                    type_bindings,
118                    meta_bindings,
119                    runtime_args,
120                    variadic_pack_tys: match tuple_ty {
121                        Ty::Unit => Vec::new(),
122                        Ty::Tuple(items) => items,
123                        _ => Vec::new(),
124                    },
125                    specialization_key: SpecializationKey {
126                        function_name: self.callee_name.to_owned(),
127                        values: specialization_values,
128                    },
129                });
130            } else {
131                let Some(arg) = args.get_mut(arg_index) else {
132                    return Err(ElabError::ArityMismatch {
133                        function: self.callee_name.to_owned(),
134                        expected: self.template.function.params.len(),
135                        actual: args.len(),
136                    });
137                };
138                let CallArg::Expr(expr) = arg else {
139                    return Err(ElabError::UnsupportedTypeExpr);
140                };
141                let expected_ty = self.elaborator.resolve_type_expr(
142                    &param.ty.ty,
143                    &type_bindings,
144                    &type_params,
145                    None,
146                )?;
147                let _ = self
148                    .elaborator
149                    .elaborate_expr(expr, Some(&expected_ty), self.locals)?;
150                runtime_args.push(expr.clone());
151                arg_index += 1;
152            }
153        }
154
155        if arg_index != args.len() {
156            return Err(ElabError::ArityMismatch {
157                function: self.callee_name.to_owned(),
158                expected: arg_index,
159                actual: args.len(),
160            });
161        }
162
163        Ok(TemplateCallBinding {
164            type_bindings,
165            meta_bindings,
166            runtime_args,
167            variadic_pack_tys: Vec::new(),
168            specialization_key: SpecializationKey {
169                function_name: self.callee_name.to_owned(),
170                values: specialization_values,
171            },
172        })
173    }
174}
175
176impl Elaborator {
177    /// Bind one runtime template call argument list to specialization inputs.
178    pub(super) fn bind_template_call_args(
179        &mut self,
180        callee_name: &str,
181        template: &RuntimeFunctionTemplate,
182        args: &mut Vec<CallArg>,
183        locals: &mut Vec<(String, Ty)>,
184    ) -> Result<TemplateCallBinding, ElabError> {
185        let mut pass = BindTemplateCallArgsPass {
186            elaborator: self,
187            callee_name,
188            template,
189            locals,
190        };
191        let (output, timings) = Elaborator::run_subpass(&mut pass, std::mem::take(args))?;
192        self.record_subpass_timings(&timings);
193        Ok(output)
194    }
195}