VapourSynth-llvmexpr
Loading...
Searching...
No Matches
CodeGenerator.cpp
Go to the documentation of this file.
1
19
20#include "CodeGenerator.hpp"
21#include "Builtins.hpp"
22#include "PostfixBuilder.hpp"
23#include "PostfixHelper.hpp"
24#include "types.hpp"
25
26#include <format>
27#include <functional>
28#include <utility>
29
30namespace infix2postfix {
31
// NOTE(review): the first line of this constructor definition (source line 32)
// is missing from this extract. The Doxygen index lists the declaration as
// CodeGenerator(Mode mode, int num_inputs, int num_intermediate_inputs=0).
// Stores the generation mode and the input counts. `mode` selects Expr/VkExpr
// behaviour (RESULT load, static array allocation, PostfixMode::Expr), and the
// two counts are forwarded to compute_postfix_stack_effect. num_intermediate_inputs
// also bounds which `$` buffer indices are typed as clips.
33 int num_intermediate_inputs)
34 : mode(mode), num_inputs(num_inputs),
35 num_intermediate_inputs(num_intermediate_inputs) {}
36
37std::string CodeGenerator::generate(const Program* program) {
38 PostfixBuilder main_builder;
39 for (const auto& stmt : program->statements) {
40 generate(stmt.get(), main_builder);
41 }
42
43 if (mode == Mode::Expr || mode == Mode::VkExpr) {
44 main_builder.addVariableLoad("RESULT");
45 }
46
47 return main_builder.getExpression();
48}
49
// NOTE(review): source lines 50-51 are missing from this extract. They should
// hold the signature (listed by Doxygen as `ExprResult generateExpr(Expr *expr)`)
// and, presumably, the declaration of the builder `b` used below.
// Generates `expr` into `b` and returns the builder together with the type
// reported by generate().
52 Type t = generate(expr, b);
53 return {.postfix = b, .type = t};
54}
55
56std::string CodeGenerator::generateExprToString(Expr* expr) {
57 if (expr == nullptr) {
58 return "";
59 }
60 PostfixBuilder temp_builder;
61 generate(expr, temp_builder);
62 return temp_builder.getExpression();
63}
64
// NOTE(review): the signature line of this definition (source line 65) is
// missing from this extract. From the call sites it is presumably
// `Type CodeGenerator::generate(Expr* expr, PostfixBuilder& builder)`.
// Dispatches an expression node to the matching handle() overload via
// std::visit and returns the type that handler reports. A null expression
// emits nothing and is typed Type::Value.
66 if (expr == nullptr) {
67 return Type::Value;
68 }
69 return std::visit(
70 [this, &builder](auto& e) { return this->handle(e, builder); },
71 expr->value);
72}
73
74void CodeGenerator::generate(Stmt* stmt, PostfixBuilder& builder) {
75 if (stmt == nullptr) {
76 return;
77 }
78 std::visit([this, &builder](auto& s) { this->handle(s, builder); },
79 stmt->value);
80}
81
82Type CodeGenerator::handle(const NumberExpr& expr, PostfixBuilder& builder) {
83 builder.addNumber(expr.value.value);
84 return Type::Literal;
85}
86
87Type CodeGenerator::handle(const VariableExpr& expr, PostfixBuilder& builder) {
88 std::string name = expr.name.value;
89
90 if (param_substitutions.contains(name)) {
91 return generate(param_substitutions.at(name), builder);
92 }
93
94 if (name.starts_with("$")) {
95 std::string base_name = name.substr(1);
96 builder.addConstant(base_name);
97
98 if (auto buf_idx = get_buffer_index(base_name)) {
99 if (*buf_idx >= 0 && *buf_idx < num_intermediate_inputs) {
100 return Type::Clip;
101 }
102 }
103
104 if (is_clip_name(base_name)) {
105 return Type::Clip;
106 }
107 return Type::Value;
108 }
109
110 std::string var_name = name;
111 if (var_rename_map.contains(name)) {
112 var_name = var_rename_map.at(name);
113 } else if (expr.symbol) {
114 var_name = expr.symbol->name;
115 }
116
117 builder.addVariableLoad(var_name);
118
119 if (expr.symbol) {
120 return expr.symbol->type;
121 }
122
123 return Type::Value;
124}
125
/**
 * Emit postfix code for a unary expression.
 *
 * - `-<numeric literal>` is folded into one signed literal token (typed
 *   Type::Literal) instead of emitting a runtime negation. The sign is
 *   toggled, so a literal that already carries a leading '-' becomes
 *   positive.
 * - `!x` is lowered to `x 0 =`.
 * - Any other operator is emitted with addUnaryOp after its operand.
 */
Type CodeGenerator::handle(const UnaryExpr& expr, PostfixBuilder& builder) {
    if (expr.op.type == TokenType::Minus) {
        if (auto* num = get_if<NumberExpr>(expr.right.get())) {
            const std::string& literal = num->value.value;
            // Negating an already-negative literal must cancel its sign
            // (-(-5) == 5) rather than leave it unchanged.
            std::string negated = literal.starts_with("-")
                                      ? literal.substr(1)
                                      : std::format("-{}", literal);
            builder.addNumber(negated);
            return Type::Literal;
        }
    }

    generate(expr.right.get(), builder);

    if (expr.op.type == TokenType::Not) {
        // Logical not: compare the operand with zero.
        builder.addNumber("0");
        builder.addOp(TokenType::Eq);
        return Type::Value;
    }

    builder.addUnaryOp(expr.op.type);
    return Type::Value;
}
149
150Type CodeGenerator::handle(const BinaryExpr& expr, PostfixBuilder& builder) {
151 generate(expr.left.get(), builder);
152
153 if (expr.op.type == TokenType::LogicalAnd ||
154 expr.op.type == TokenType::LogicalOr) {
155 builder.addNumber("0");
156 builder.addOp(TokenType::Eq);
157 builder.addUnaryOp(TokenType::Not);
158
159 generate(expr.right.get(), builder);
160 builder.addNumber("0");
161 builder.addOp(TokenType::Eq);
162 builder.addUnaryOp(TokenType::Not);
163
164 builder.addOp(expr.op.type);
165 return Type::Value;
166 }
167
168 generate(expr.right.get(), builder);
169 builder.addOp(expr.op.type);
170 return Type::Value;
171}
172
173Type CodeGenerator::handle(const TernaryExpr& expr, PostfixBuilder& builder) {
174 generate(expr.cond.get(), builder);
175 builder.addNumber("0");
176 builder.addOp(TokenType::Eq);
177 builder.addUnaryOp(TokenType::Not);
178
179 PostfixBuilder true_branch_builder;
180 generate(expr.true_expr.get(), true_branch_builder);
181
182 PostfixBuilder false_branch_builder;
183 generate(expr.false_expr.get(), false_branch_builder);
184
185 builder.append(true_branch_builder);
186 builder.append(false_branch_builder);
187 builder.addTernaryOp();
188 return Type::Value;
189}
190
191Type CodeGenerator::handle(const CallExpr& expr, PostfixBuilder& builder) {
192 // User-defined functions are inlined
193 if (expr.resolved_signature != nullptr && expr.resolved_def != nullptr) {
194 const auto& sig = *expr.resolved_signature;
195 FunctionDef* func_def = expr.resolved_def;
196
197 inlineFunctionCall(sig, func_def, expr.args, expr.range, builder);
198 return Type::Value;
199 }
200
201 // Built-in functions
202 if (expr.resolved_builtin != nullptr) {
203 const BuiltinFunction* builtin = expr.resolved_builtin;
204
205 if (builtin->special_handler) {
206 builder.append(builtin->special_handler(this, expr));
207 return Type::Value;
208 }
209
210 for (size_t i = 0; i < expr.args.size(); ++i) {
211 if (builtin->param_types[i] != Type::LiteralString) {
212 generate(expr.args[i].get(), builder);
213 }
214 }
215 builder.addFunctionCall(expr.callee);
216 return Type::Value;
217 }
218
219 // nth_N functions
220 if (expr.callee.starts_with("nth_")) {
221 std::string n_str = expr.callee.substr(4);
222 int n = std::stoi(n_str);
223 int arg_count = static_cast<int>(expr.args.size());
224
225 for (const auto& arg : expr.args) {
226 generate(arg.get(), builder);
227 }
228 builder.addSortN(arg_count);
229 builder.addDropN(n - 1);
230 builder.addSwapN(arg_count - n);
231 builder.addDropN(arg_count - n);
232
233 return Type::Value;
234 }
235
236 std::unreachable();
237}
238
239Type CodeGenerator::handle(const PropAccessExpr& expr,
240 PostfixBuilder& builder) {
241 std::string clip_name = expr.clip.value;
242 if (param_substitutions.contains(clip_name)) {
243 auto* subst_expr = param_substitutions.at(clip_name);
244 clip_name = generateExprToString(subst_expr);
245 } else {
246 clip_name = clip_name.substr(1);
247 }
248
249 builder.addPropAccess(clip_name, expr.prop.value);
250 return Type::Value;
251}
252
253Type CodeGenerator::handle(const StaticRelPixelAccessExpr& expr,
254 PostfixBuilder& builder) {
255 std::string clip_name = expr.clip.value;
256 if (param_substitutions.contains(clip_name)) {
257 auto* subst_expr = param_substitutions.at(clip_name);
258 clip_name = generateExprToString(subst_expr);
259 } else {
260 clip_name = clip_name.substr(1);
261 }
262
263 builder.addStaticPixelAccess(clip_name, expr.offset_x.value,
264 expr.offset_y.value, expr.boundary_suffix);
265 return Type::Value;
266}
267
268Type CodeGenerator::handle(const FrameDimensionExpr& expr,
269 PostfixBuilder& builder) {
270 std::string plane_idx_str =
271 generateExprToString(expr.plane_index_expr.get());
272 builder.addFrameDimension(expr.dimension_name, plane_idx_str);
273 return Type::Value;
274}
275
276Type CodeGenerator::handle(const ArrayAccessExpr& expr,
277 PostfixBuilder& builder) {
278 // Get the array variable name
279 std::string array_name;
280 auto* var_expr = get_if<VariableExpr>(expr.array.get());
281 if (var_expr != nullptr) {
282 array_name = var_expr->name.value;
283
284 if (var_rename_map.contains(array_name)) {
285 array_name = var_rename_map.at(array_name);
286 } else if (expr.array_symbol) {
287 array_name = expr.array_symbol->name;
288 }
289 } else if (expr.array_symbol) {
290 array_name = expr.array_symbol->name;
291 }
292
293 generate(expr.index.get(), builder);
294 builder.addArrayLoad(array_name);
295
296 return Type::Value;
297}
298
299void CodeGenerator::handle(const ExprStmt& stmt, PostfixBuilder& builder) {
300#ifndef NDEBUG
301 PostfixBuilder temp_builder;
302 generate(stmt.expr.get(), temp_builder);
303 checkStackEffect(temp_builder.getExpression(), 0, stmt.range);
304 builder.append(temp_builder);
305#else
306 generate(stmt.expr.get(), builder);
307#endif
308}
309
/**
 * Emit code for `name = value`.
 *
 * The target name is resolved through the inlining rename map first, then
 * the resolved symbol, then the raw source spelling.
 *
 * `name = new(size)` and `name = resize(size)` allocate an array instead of
 * storing a value:
 *   - In Expr/VkExpr modes the size must be known at generation time. It is
 *     taken from a numeric literal, possibly reached through parameter
 *     substitution, and emitted with addArrayAllocStatic.
 *   - In other modes the size expression is generated and
 *     addArrayAllocDynamic is emitted.
 * Debug builds verify that the emitted statement has a net stack effect of 0.
 *
 * @throws CodeGenError if an allocation call has no size argument, or (in
 *         Expr/VkExpr modes) if its size is not a numeric literal.
 */
void CodeGenerator::handle(const AssignStmt& stmt, PostfixBuilder& builder) {
    const std::string& name = stmt.name.value;

    std::string var_name = name;
    if (var_rename_map.contains(name)) {
        var_name = var_rename_map.at(name);
    } else if (stmt.symbol) {
        var_name = stmt.symbol->name;
    }

    // Check if this is a new() or resize() call for array allocation
    if (auto* call_expr = get_if<CallExpr>(stmt.value.get())) {
        if (call_expr->callee == "new" || call_expr->callee == "resize") {
            if (call_expr->args.empty()) {
                throw CodeGenError(
                    std::format("'{}' requires a size argument",
                                call_expr->callee),
                    stmt.range);
            }
#ifndef NDEBUG
            PostfixBuilder b;
#else
            PostfixBuilder& b = builder;
#endif
            if (mode == Mode::Expr || mode == Mode::VkExpr) {
                // Follow parameter substitutions so that a size passed via a
                // Literal parameter of an inlined function resolves to the
                // caller's literal. Null substitutions are handled there.
                Expr* size_expr =
                    resolveParamSubstitution(call_expr->args[0].get());
                auto* num_expr = size_expr != nullptr
                                     ? get_if<NumberExpr>(size_expr)
                                     : nullptr;
                if (num_expr == nullptr) {
                    // Without this check an empty size token would be
                    // emitted, producing an invalid allocation.
                    throw CodeGenError(
                        std::format("Size of array '{}' must be a numeric "
                                    "literal in Expr mode",
                                    name),
                        stmt.range);
                }
                b.addArrayAllocStatic(var_name, num_expr->value.value);
            } else {
                generate(call_expr->args[0].get(), b);
                b.addArrayAllocDynamic(var_name);
            }
#ifndef NDEBUG
            checkStackEffect(b.getExpression(), 0, stmt.range);
            builder.append(b);
#endif
            return;
        }
    }

#ifndef NDEBUG
    PostfixBuilder b;
    generate(stmt.value.get(), b);
    b.addVariableStore(var_name);
    checkStackEffect(b.getExpression(), 0, stmt.range);
    builder.append(b);
#else
    generate(stmt.value.get(), builder);
    builder.addVariableStore(var_name);
#endif
}
369
370void CodeGenerator::handle(const ArrayAssignStmt& stmt,
371 PostfixBuilder& builder) {
372 auto* array_access = get_if<ArrayAccessExpr>(stmt.target.get());
373
374 std::string array_name;
375 auto* var_expr = get_if<VariableExpr>(array_access->array.get());
376 if (var_expr != nullptr) {
377 array_name = var_expr->name.value;
378
379 if (var_rename_map.contains(array_name)) {
380 array_name = var_rename_map.at(array_name);
381 } else if (array_access->array_symbol) {
382 array_name = array_access->array_symbol->name;
383 }
384 } else if (array_access->array_symbol) {
385 array_name = array_access->array_symbol->name;
386 }
387
388#ifndef NDEBUG
389 PostfixBuilder b;
390 generate(stmt.value.get(), b);
391 generate(array_access->index.get(), b);
392 b.addArrayStore(array_name);
393 checkStackEffect(b.getExpression(), 0, stmt.range);
394 builder.append(b);
395#else
396 generate(stmt.value.get(), builder);
397 generate(array_access->index.get(), builder);
398 builder.addArrayStore(array_name);
399#endif
400}
401
402void CodeGenerator::handle(const BlockStmt& stmt, PostfixBuilder& builder) {
403 for (const auto& s : stmt.statements) {
404 generate(s.get(), builder);
405 }
406}
407
/**
 * Emit an if/else using labels and jumps.
 *
 * Layout: `cond 0 =`, conditional jump to else-label, then-branch, and, only
 * when an else branch exists, an unconditional jump to the endif-label,
 * else-label, else-branch, endif-label. Without an else branch only the
 * else-label follows the then-branch. Both label numbers are taken from
 * label_counter before any nested code is generated.
 */
void CodeGenerator::handle(const IfStmt& stmt, PostfixBuilder& builder) {
    const auto else_id = label_counter++;
    const auto endif_id = label_counter++;
    const std::string else_label = std::format("__internal_else_{}", else_id);
    const std::string endif_label =
        std::format("__internal_endif_{}", endif_id);

    generate(stmt.condition.get(), builder);
    builder.addNumber("0");
    builder.addOp(TokenType::Eq);
    builder.addConditionalJump(else_label);

    generate(stmt.then_branch.get(), builder);

    if (!stmt.else_branch) {
        builder.addLabel(else_label);
        return;
    }

    builder.addUnconditionalJump(endif_label);
    builder.addLabel(else_label);
    generate(stmt.else_branch.get(), builder);
    builder.addLabel(endif_label);
}
429
/**
 * Emit a while loop.
 *
 * Layout: start-label, `cond 0 =`, conditional jump to end-label, body,
 * unconditional jump back to start-label, end-label. Both label numbers are
 * taken from label_counter before any nested code is generated.
 */
void CodeGenerator::handle(const WhileStmt& stmt, PostfixBuilder& builder) {
    const auto start_id = label_counter++;
    const auto end_id = label_counter++;
    const std::string loop_head =
        std::format("__internal_while_start_{}", start_id);
    const std::string loop_exit =
        std::format("__internal_while_end_{}", end_id);

    builder.addLabel(loop_head);

    generate(stmt.condition.get(), builder);
    builder.addNumber("0");
    builder.addOp(TokenType::Eq);
    builder.addConditionalJump(loop_exit);

    generate(stmt.body.get(), builder);
    builder.addUnconditionalJump(loop_head);
    builder.addLabel(loop_exit);
}
445
/**
 * Emit a `return` inside an inlined function body.
 *
 * The innermost entry of call_site_id_stack identifies the call site being
 * inlined. A returned value is stored to `__internal_ret_<function>_<id>`.
 * Control then jumps to `__internal_ret_label_<function>_<id>`, which
 * inlineFunctionCall places after the inlined body.
 *
 * @throws CodeGenError if no function call is currently being inlined.
 */
void CodeGenerator::handle(const ReturnStmt& stmt, PostfixBuilder& builder) {
    // Both are set only while inlineFunctionCall is generating a body. Reading
    // them otherwise would be undefined behaviour (back() on an empty vector,
    // null dereference).
    if (call_site_id_stack.empty() || current_function == nullptr) {
        throw CodeGenError("'return' used outside of a function body",
                           stmt.range);
    }

    // Nested inlined calls inside the return value push and pop their own
    // ids, so the innermost id is stable across generating the value.
    const int call_id = call_site_id_stack.back();
    const auto& func_name = current_function->name;

    if (stmt.value) {
        generate(stmt.value.get(), builder);
        std::string ret_var_name =
            std::format("__internal_ret_{}_{}", func_name, call_id);
        builder.addVariableStore(ret_var_name);
    }

    std::string ret_label =
        std::format("__internal_ret_label_{}_{}", func_name, call_id);
    builder.addUnconditionalJump(ret_label);
}
460
461void CodeGenerator::handle(const LabelStmt& stmt, PostfixBuilder& builder) {
462 std::string label_name = stmt.symbol ? stmt.symbol->name : stmt.name.value;
463 builder.addLabel(label_name);
464}
465
466void CodeGenerator::handle(const GotoStmt& stmt, PostfixBuilder& builder) {
467 std::string label_name = stmt.target_label_symbol
468 ? stmt.target_label_symbol->name
469 : stmt.label.value;
470 if (stmt.condition) {
471 generate(stmt.condition.get(), builder);
472 builder.addNumber("0");
473 builder.addOp(TokenType::Eq);
474 builder.addUnaryOp(TokenType::Not);
475 builder.addConditionalJump(label_name);
476 } else {
477 builder.addUnconditionalJump(label_name);
478 }
479}
480
481void CodeGenerator::handle([[maybe_unused]] const FunctionDef& stmt,
482 [[maybe_unused]] PostfixBuilder& builder) {
483 // Inlined at call sites
484}
485
486void CodeGenerator::handle([[maybe_unused]] const GlobalDecl& stmt,
487 [[maybe_unused]] PostfixBuilder& builder) {
488 // Global declarations don't generate code
489}
490
/**
 * Debug-build check that postfix text `s` has a net stack effect of
 * `expected`. In NDEBUG builds this does nothing.
 *
 * @throws CodeGenError (debug builds only) on a mismatch, located at `range`.
 */
void CodeGenerator::checkStackEffect([[maybe_unused]] const std::string& s,
                                     [[maybe_unused]] int expected,
                                     [[maybe_unused]] const Range& range) {
#ifndef NDEBUG
    const int actual = computeStackEffect(s, range);
    if (actual == expected) {
        return;
    }
    throw CodeGenError(
        std::format("Unbalanced stack. Expected {}, got {}", expected, actual),
        range);
#endif
}
504
// Compute the net stack effect of postfix text `s` by delegating to
// compute_postfix_stack_effect. In Expr/VkExpr modes PostfixMode::Expr is
// selected. The source line (range.start.line), num_inputs and
// num_intermediate_inputs are forwarded. Any std::exception from the helper is
// rethrown as CodeGenError located at `range`.
// NOTE(review): source line 507 is missing from this extract. It presumably
// declares `postfix_mode` with its default for the remaining modes.
505int CodeGenerator::computeStackEffect(const std::string& s,
506 const Range& range) {
508 if (mode == Mode::Expr || mode == Mode::VkExpr) {
509 postfix_mode = PostfixMode::Expr;
510 }
511
512 try {
513 return compute_postfix_stack_effect(s, postfix_mode, range.start.line,
514 num_inputs,
515 num_intermediate_inputs);
516 } catch (const std::exception& e) {
517 throw CodeGenError(e.what(), range);
518 }
519}
520
521Expr* CodeGenerator::resolveParamSubstitution(Expr* expr) const {
522 Expr* resolved = expr;
523 std::set<const Expr*> seen;
524
525 while (resolved != nullptr) {
526 auto* var_expr = get_if<VariableExpr>(resolved);
527 if (var_expr == nullptr) {
528 break;
529 }
530
531 auto it = param_substitutions.find(var_expr->name.value);
532 if (it == param_substitutions.end() || it->second == nullptr) {
533 break;
534 }
535
536 if (!seen.insert(resolved).second || it->second == resolved) {
537 break;
538 }
539
540 resolved = it->second;
541 }
542
543 return resolved;
544}
545
/**
 * Inline a call to a user-defined function at the current position.
 *
 * Each call site gets a unique id from call_site_counter that is embedded in
 * every generated name, so separate inlined copies never share state:
 *   - Value parameters are generated in the caller's context and stored to
 *     `__internal_func_<name>_<id>_<param>`.
 *   - Literal/Clip parameters are substituted by their (resolved) argument
 *     expression instead of being stored.
 *   - Array parameters whose argument is a variable are aliased to the
 *     caller's array name.
 *   - Non-parameter variables assigned in the body are renamed to
 *     `__internal_func_<name>_<id>_<var>`.
 *   - Labels in the body are rewritten via prefixLabels with the prefix
 *     `__internal_<name>_<id>_`.
 *   - The label `__internal_ret_label_<name>_<id>` is appended after the body.
 *     For value-returning functions, `__internal_ret_<name>_<id>` is then
 *     loaded.
 *
 * The caller's context is restored before returning, including when code
 * generation throws. This covers param_substitutions, var_rename_map,
 * current_function and call_site_id_stack.
 *
 * @throws CodeGenError (debug builds) if the inlined code's stack effect is
 *         not 1 for a value-returning function or 0 otherwise.
 */
void CodeGenerator::inlineFunctionCall(
    const FunctionSignature& sig, FunctionDef* func_def,
    const std::vector<std::unique_ptr<Expr>>& args,
    [[maybe_unused]] const Range& call_range, PostfixBuilder& builder) {

    const auto& func_name = sig.name;
    const int call_id = call_site_counter++;
    call_site_id_stack.push_back(call_id);

    // Save the caller's context.
    auto saved_param_substitutions = param_substitutions;
    auto saved_var_rename_map = var_rename_map;
    const auto* saved_current_function = current_function;

    // Restores the caller's context. Runs exactly once, either on the normal
    // path or just before rethrowing, so an error while generating the callee
    // cannot leave the callee's renames/substitutions active or
    // call_site_id_stack unbalanced.
    auto restore_state = [&]() {
        param_substitutions = std::move(saved_param_substitutions);
        var_rename_map = std::move(saved_var_rename_map);
        current_function = saved_current_function;
        call_site_id_stack.pop_back();
    };

    PostfixBuilder param_assignments;
    PostfixBuilder body_builder;

    try {
        std::map<std::string, Expr*> next_param_substitutions;
        std::map<std::string, std::string> param_map;

        // Bind parameters. The caller's context is still active here, so the
        // argument expressions resolve against the caller's names.
        for (size_t i = 0; i < sig.params.size(); ++i) {
            const auto& param_info = sig.params[i];
            const std::string& param_name = param_info.name;
            const Type param_type = param_info.type;

            auto* arg_expr = args[i].get();

            if (param_type == Type::Literal || param_type == Type::Clip) {
                // Substituted by expression; never stored to a variable.
                next_param_substitutions[param_name] =
                    resolveParamSubstitution(arg_expr);
            } else if (param_type == Type::Array) {
                // Arrays are passed by name: alias the caller's array.
                if (auto* var_expr = get_if<VariableExpr>(arg_expr)) {
                    const std::string& arg_var_name = var_expr->name.value;
                    if (saved_var_rename_map.contains(arg_var_name)) {
                        param_map[param_name] =
                            saved_var_rename_map.at(arg_var_name);
                    } else if (var_expr->symbol) {
                        param_map[param_name] = var_expr->symbol->name;
                    } else {
                        param_map[param_name] = arg_var_name;
                    }
                }
            } else {
                // Value parameter: evaluate once into a per-call variable.
                std::string renamed_param =
                    std::format("__internal_func_{}_{}_{}", func_name,
                                call_id, param_name);
                generate(arg_expr, param_assignments);
                param_assignments.addVariableStore(renamed_param);
                param_map[param_name] = std::move(renamed_param);
            }
        }

        // Give every non-parameter variable assigned in the body a
        // call-site-private name.
        std::function<void(Stmt*)> collect_locals = [&](Stmt* stmt) {
            if (!stmt) {
                return;
            }
            if (auto* assign = get_if<AssignStmt>(stmt)) {
                const std::string& var_name = assign->name.value;
                if (!param_map.contains(var_name)) {
                    param_map[var_name] =
                        std::format("__internal_func_{}_{}_{}", func_name,
                                    call_id, var_name);
                }
            } else if (auto* block = get_if<BlockStmt>(stmt)) {
                for (const auto& s : block->statements) {
                    collect_locals(s.get());
                }
            } else if (auto* if_stmt = get_if<IfStmt>(stmt)) {
                collect_locals(if_stmt->then_branch.get());
                if (if_stmt->else_branch) {
                    collect_locals(if_stmt->else_branch.get());
                }
            } else if (auto* while_stmt = get_if<WhileStmt>(stmt)) {
                collect_locals(while_stmt->body.get());
            }
        };

        for (const auto& s : func_def->body->statements) {
            collect_locals(s.get());
        }

        // Switch to the callee's context for body generation.
        param_substitutions = std::move(next_param_substitutions);
        var_rename_map = std::move(param_map);
        current_function = &sig;

        const std::string label_prefix =
            std::format("__internal_{}_{}_", func_name, call_id);

        handle(*func_def->body, body_builder);
        body_builder.prefixLabels(label_prefix);

        // Jump target of every `return` emitted by handle(ReturnStmt).
        const std::string ret_label =
            std::format("__internal_ret_label_{}_{}", func_name, call_id);
        body_builder.addLabel(ret_label);

        if (sig.returns_value) {
            const std::string ret_var =
                std::format("__internal_ret_{}_{}", func_name, call_id);
            body_builder.addVariableLoad(ret_var);
        }
    } catch (...) {
        restore_state();
        throw;
    }
    restore_state();

    PostfixBuilder inlined_builder;
    inlined_builder.append(param_assignments);
    inlined_builder.append(body_builder);

#ifndef NDEBUG
    int expected_effect = sig.returns_value ? 1 : 0;
    try {
        int actual_effect =
            computeStackEffect(inlined_builder.getExpression(), call_range);
        if (actual_effect != expected_effect) {
            throw CodeGenError(
                std::format("Function '{}' has unbalanced stack. "
                            "Expected effect: {}, actual: {}",
                            func_name, expected_effect, actual_effect),
                sig.range);
        }
    } catch (const CodeGenError& e) {
        throw CodeGenError(
            std::format("In function '{}': {}", func_name, e.what()),
            sig.range);
    }
#endif
    builder.append(inlined_builder);
}
690
691} // namespace infix2postfix
std::string generate(const Program *program)
ExprResult generateExpr(Expr *expr)
CodeGenerator(Mode mode, int num_inputs, int num_intermediate_inputs=0)
std::string getExpression() const
void addVariableLoad(const std::string &var_name)
std::optional< int > get_buffer_index(const std::string &s)
Definition types.hpp:267
auto get_if(Wrapper *wrapper) -> decltype(std::get_if< T >(&wrapper->value))
Definition AST.hpp:442
bool is_clip_name(const std::string &s)
Definition types.hpp:263
int compute_postfix_stack_effect(const std::string &postfix_expr, PostfixMode mode, int line, int num_inputs, int num_intermediate_inputs)
std::vector< std::unique_ptr< Stmt > > statements
Definition AST.hpp:415