33 int num_intermediate_inputs)
34 : mode(mode), num_inputs(num_inputs),
35 num_intermediate_inputs(num_intermediate_inputs) {}
53 return {.postfix = b, .type = t};
56std::string CodeGenerator::generateExprToString(
Expr* expr) {
57 if (expr ==
nullptr) {
60 PostfixBuilder temp_builder;
62 return temp_builder.getExpression();
66 if (expr ==
nullptr) {
70 [
this, &builder](
auto& e) {
return this->handle(e, builder); },
75 if (stmt ==
nullptr) {
78 std::visit([
this, &builder](
auto& s) { this->handle(s, builder); },
83 builder.addNumber(expr.value.value);
88 std::string name = expr.name.value;
90 if (param_substitutions.contains(name)) {
91 return generate(param_substitutions.at(name), builder);
94 if (name.starts_with(
"$")) {
95 std::string base_name = name.substr(1);
96 builder.addConstant(base_name);
99 if (*buf_idx >= 0 && *buf_idx < num_intermediate_inputs) {
110 std::string var_name = name;
111 if (var_rename_map.contains(name)) {
112 var_name = var_rename_map.at(name);
113 }
else if (expr.symbol) {
114 var_name = expr.symbol->name;
117 builder.addVariableLoad(var_name);
120 return expr.symbol->type;
127 if (expr.op.type == TokenType::Minus) {
129 std::string val = num->value.value;
130 if (!val.starts_with(
"-")) {
131 val = std::format(
"-{}", val);
133 builder.addNumber(val);
138 generate(expr.right.get(), builder);
141 builder.addNumber(
"0");
146 builder.addUnaryOp(expr.op.type);
153 if (expr.op.type == TokenType::LogicalAnd ||
154 expr.op.type == TokenType::LogicalOr) {
155 builder.addNumber(
"0");
159 generate(expr.right.get(), builder);
160 builder.addNumber(
"0");
164 builder.addOp(expr.op.type);
168 generate(expr.right.get(), builder);
169 builder.addOp(expr.op.type);
175 builder.addNumber(
"0");
179 PostfixBuilder true_branch_builder;
180 generate(expr.true_expr.get(), true_branch_builder);
182 PostfixBuilder false_branch_builder;
183 generate(expr.false_expr.get(), false_branch_builder);
185 builder.append(true_branch_builder);
186 builder.append(false_branch_builder);
187 builder.addTernaryOp();
193 if (expr.resolved_signature !=
nullptr && expr.resolved_def !=
nullptr) {
194 const auto& sig = *expr.resolved_signature;
195 FunctionDef* func_def = expr.resolved_def;
197 inlineFunctionCall(sig, func_def, expr.args, expr.range, builder);
202 if (expr.resolved_builtin !=
nullptr) {
203 const BuiltinFunction* builtin = expr.resolved_builtin;
205 if (builtin->special_handler) {
206 builder.append(builtin->special_handler(
this, expr));
210 for (
size_t i = 0; i < expr.args.size(); ++i) {
212 generate(expr.args[i].get(), builder);
215 builder.addFunctionCall(expr.callee);
220 if (expr.callee.starts_with(
"nth_")) {
221 std::string n_str = expr.callee.substr(4);
222 int n = std::stoi(n_str);
223 int arg_count =
static_cast<int>(expr.args.size());
225 for (
const auto& arg : expr.args) {
228 builder.addSortN(arg_count);
229 builder.addDropN(n - 1);
230 builder.addSwapN(arg_count - n);
231 builder.addDropN(arg_count - n);
241 std::string clip_name = expr.clip.value;
242 if (param_substitutions.contains(clip_name)) {
243 auto* subst_expr = param_substitutions.at(clip_name);
244 clip_name = generateExprToString(subst_expr);
246 clip_name = clip_name.substr(1);
249 builder.addPropAccess(clip_name, expr.prop.value);
255 std::string clip_name = expr.clip.value;
256 if (param_substitutions.contains(clip_name)) {
257 auto* subst_expr = param_substitutions.at(clip_name);
258 clip_name = generateExprToString(subst_expr);
260 clip_name = clip_name.substr(1);
263 builder.addStaticPixelAccess(clip_name, expr.offset_x.value,
264 expr.offset_y.value, expr.boundary_suffix);
270 std::string plane_idx_str =
271 generateExprToString(expr.plane_index_expr.get());
272 builder.addFrameDimension(expr.dimension_name, plane_idx_str);
279 std::string array_name;
281 if (var_expr !=
nullptr) {
282 array_name = var_expr->name.value;
284 if (var_rename_map.contains(array_name)) {
285 array_name = var_rename_map.at(array_name);
286 }
else if (expr.array_symbol) {
287 array_name = expr.array_symbol->name;
289 }
else if (expr.array_symbol) {
290 array_name = expr.array_symbol->name;
293 generate(expr.index.get(), builder);
294 builder.addArrayLoad(array_name);
301 PostfixBuilder temp_builder;
302 generate(stmt.expr.get(), temp_builder);
303 checkStackEffect(temp_builder.getExpression(), 0, stmt.range);
304 builder.append(temp_builder);
311 std::string name = stmt.name.value;
313 std::string var_name = name;
314 if (var_rename_map.contains(name)) {
315 var_name = var_rename_map.at(name);
316 }
else if (stmt.symbol) {
317 var_name = stmt.symbol->name;
322 if (call_expr->callee ==
"new" || call_expr->callee ==
"resize") {
326 PostfixBuilder& b = builder;
329 auto* arg_expr = call_expr->args[0].get();
331 std::string size_value;
334 size_value = num_expr->value.value;
336 std::string var_name_arg = var_expr->name.value;
337 if (param_substitutions.contains(var_name_arg)) {
338 auto* subst_expr = param_substitutions.at(var_name_arg);
340 size_value = subst_num->value.value;
345 b.addArrayAllocStatic(var_name, size_value);
347 generate(call_expr->args[0].get(), b);
348 b.addArrayAllocDynamic(var_name);
351 checkStackEffect(b.getExpression(), 0, stmt.range);
361 b.addVariableStore(var_name);
362 checkStackEffect(b.getExpression(), 0, stmt.range);
365 generate(stmt.value.get(), builder);
366 builder.addVariableStore(var_name);
374 std::string array_name;
376 if (var_expr !=
nullptr) {
377 array_name = var_expr->name.value;
379 if (var_rename_map.contains(array_name)) {
380 array_name = var_rename_map.at(array_name);
381 }
else if (array_access->array_symbol) {
382 array_name = array_access->array_symbol->name;
384 }
else if (array_access->array_symbol) {
385 array_name = array_access->array_symbol->name;
391 generate(array_access->index.get(), b);
392 b.addArrayStore(array_name);
393 checkStackEffect(b.getExpression(), 0, stmt.range);
396 generate(stmt.value.get(), builder);
397 generate(array_access->index.get(), builder);
398 builder.addArrayStore(array_name);
403 for (
const auto& s : stmt.statements) {
409 std::string else_label = std::format(
"__internal_else_{}", label_counter++);
410 std::string endif_label =
411 std::format(
"__internal_endif_{}", label_counter++);
413 generate(stmt.condition.get(), builder);
414 builder.addNumber(
"0");
416 builder.addConditionalJump(else_label);
418 generate(stmt.then_branch.get(), builder);
420 if (stmt.else_branch) {
421 builder.addUnconditionalJump(endif_label);
422 builder.addLabel(else_label);
423 generate(stmt.else_branch.get(), builder);
424 builder.addLabel(endif_label);
426 builder.addLabel(else_label);
431 std::string start_label =
432 std::format(
"__internal_while_start_{}", label_counter++);
433 std::string end_label =
434 std::format(
"__internal_while_end_{}", label_counter++);
436 builder.addLabel(start_label);
437 generate(stmt.condition.get(), builder);
438 builder.addNumber(
"0");
440 builder.addConditionalJump(end_label);
442 builder.addUnconditionalJump(start_label);
443 builder.addLabel(end_label);
448 generate(stmt.value.get(), builder);
449 const int call_id = call_site_id_stack.back();
450 std::string ret_var_name = std::format(
"__internal_ret_{}_{}",
451 current_function->name, call_id);
452 builder.addVariableStore(ret_var_name);
455 const int call_id = call_site_id_stack.back();
456 std::string ret_label = std::format(
"__internal_ret_label_{}_{}",
457 current_function->name, call_id);
458 builder.addUnconditionalJump(ret_label);
462 std::string label_name = stmt.symbol ? stmt.symbol->name : stmt.name.value;
463 builder.addLabel(label_name);
467 std::string label_name = stmt.target_label_symbol
468 ? stmt.target_label_symbol->name
470 if (stmt.condition) {
471 generate(stmt.condition.get(), builder);
472 builder.addNumber(
"0");
475 builder.addConditionalJump(label_name);
477 builder.addUnconditionalJump(label_name);
481void CodeGenerator::handle([[maybe_unused]]
const FunctionDef& stmt,
486void CodeGenerator::handle([[maybe_unused]]
const GlobalDecl& stmt,
491void CodeGenerator::checkStackEffect([[maybe_unused]]
const std::string& s,
492 [[maybe_unused]]
int expected,
493 [[maybe_unused]]
const Range& range) {
495 int effect = computeStackEffect(s, range);
496 if (effect != expected) {
497 throw CodeGenError(std::format(
"Unbalanced stack. "
498 "Expected {}, got {}",
505int CodeGenerator::computeStackEffect(
const std::string& s,
506 const Range& range) {
515 num_intermediate_inputs);
516 }
catch (
const std::exception& e) {
517 throw CodeGenError(e.what(), range);
521Expr* CodeGenerator::resolveParamSubstitution(
Expr* expr)
const {
522 Expr* resolved = expr;
523 std::set<const Expr*> seen;
525 while (resolved !=
nullptr) {
527 if (var_expr ==
nullptr) {
531 auto it = param_substitutions.find(var_expr->name.value);
532 if (it == param_substitutions.end() || it->second ==
nullptr) {
536 if (!seen.insert(resolved).second || it->second == resolved) {
540 resolved = it->second;
546void CodeGenerator::inlineFunctionCall(
548 const std::vector<std::unique_ptr<Expr>>& args,
551 const auto& func_name = sig.name;
552 const int call_id = call_site_counter++;
553 call_site_id_stack.push_back(call_id);
556 auto saved_param_substitutions = param_substitutions;
557 auto saved_var_rename_map = var_rename_map;
558 const auto* saved_current_function = current_function;
560 std::map<std::string, Expr*> next_param_substitutions;
561 std::map<std::string, std::string> param_map;
563 PostfixBuilder param_assignments;
565 for (
size_t i = 0; i < sig.params.size(); ++i) {
566 const auto& param_info = sig.params[i];
567 const std::string& param_name = param_info.name;
568 const Type param_type = param_info.type;
570 auto* arg_expr = args[i].get();
573 next_param_substitutions[param_name] =
574 resolveParamSubstitution(arg_expr);
577 if (var_expr !=
nullptr) {
578 std::string arg_var_name = var_expr->name.value;
579 if (saved_var_rename_map.contains(arg_var_name)) {
580 param_map[param_name] =
581 saved_var_rename_map.at(arg_var_name);
582 }
else if (var_expr->symbol) {
583 param_map[param_name] = var_expr->symbol->name;
585 param_map[param_name] = arg_var_name;
589 std::string renamed_param =
590 std::format(
"__internal_func_{}_"
592 func_name, call_id, param_name);
594 generate(args[i].get(), param_assignments);
595 param_assignments.addVariableStore(renamed_param);
596 param_map[param_name] = renamed_param;
601 std::function<void(Stmt*)> collect_locals = [&](Stmt* stmt) {
606 std::string var_name = assign->name.value;
607 if (!param_map.contains(var_name)) {
608 std::string renamed_var =
614 func_name, call_id, var_name);
615 param_map[var_name] = renamed_var;
618 for (
const auto& s : block->statements) {
619 collect_locals(s.get());
622 collect_locals(if_stmt->then_branch.get());
623 if (if_stmt->else_branch) {
624 collect_locals(if_stmt->else_branch.get());
627 collect_locals(while_stmt->body.get());
631 for (
const auto& s : func_def->body->statements) {
632 collect_locals(s.get());
636 param_substitutions = std::move(next_param_substitutions);
637 var_rename_map = param_map;
638 current_function = &sig;
640 std::string label_prefix =
641 std::format(
"__internal_{}_{}_", func_name, call_id);
643 PostfixBuilder body_builder;
644 handle(*func_def->body, body_builder);
645 body_builder.prefixLabels(label_prefix);
647 std::string ret_label =
648 std::format(
"__internal_ret_label_{}_{}", func_name, call_id);
649 body_builder.addLabel(ret_label);
651 if (sig.returns_value) {
652 std::string ret_var =
653 std::format(
"__internal_ret_{}_{}", func_name, call_id);
654 body_builder.addVariableLoad(ret_var);
658 param_substitutions = saved_param_substitutions;
659 var_rename_map = saved_var_rename_map;
660 current_function = saved_current_function;
661 call_site_id_stack.pop_back();
663 PostfixBuilder inlined_builder;
664 inlined_builder.append(param_assignments);
665 inlined_builder.append(body_builder);
668 int expected_effect = sig.returns_value ? 1 : 0;
671 computeStackEffect(inlined_builder.getExpression(), call_range);
672 if (actual_effect != expected_effect) {
673 throw CodeGenError(std::format(
"Function '{}' has "
678 func_name, expected_effect,
682 }
catch (
const CodeGenError& e) {
684 std::format(
"In function '{}': {}", func_name, e.what()),
688 builder.append(inlined_builder);
std::string generate(const Program *program)
ExprResult generateExpr(Expr *expr)
CodeGenerator(Mode mode, int num_inputs, int num_intermediate_inputs=0)
std::string getExpression() const
void addVariableLoad(const std::string &var_name)
std::optional< int > get_buffer_index(const std::string &s)
auto get_if(Wrapper *wrapper) -> decltype(std::get_if< T >(&wrapper->value))
bool is_clip_name(const std::string &s)
int compute_postfix_stack_effect(const std::string &postfix_expr, PostfixMode mode, int line, int num_inputs, int num_intermediate_inputs)
std::vector< std::unique_ptr< Stmt > > statements