295 std::vector<llvm::Value*>& rpn_stack,
296 llvm::Type* float_ty,
298 bool use_approx_math) {
// Generic helper: pops ARITY operands off rpn_stack, applies `op` to them via
// std::apply and pushes the single result back onto rpn_stack.
// The array is filled from its last slot backwards, so args[0] is the operand
// that sat deepest on the stack and args[ARITY - 1] is the former top; `op`
// therefore receives its arguments in the order they were pushed.
// NOTE(review): no stack-depth check is visible before back()/pop_back();
// presumably the token stream is validated before code generation - confirm.
299 auto apply_stack_op = [&]<
size_t ARITY>(
auto&& op) {
300 std::array<llvm::Value*, ARITY> args{};
// Counts down from ARITY so the unsigned index never has to go below zero.
301 for (
size_t i = ARITY; i > 0; --i) {
302 args.at(i - 1) = rpn_stack.back();
303 rpn_stack.pop_back();
305 rpn_stack.push_back(std::apply(op, args));
// Handles an operator that maps onto the LLVM intrinsic `id`: delegates to
// apply_stack_op with the same ARITY, so ARITY operands are popped and one
// result is pushed.
// NOTE(review): the callable handed to apply_stack_op presumably emits a call
// to the `id` intrinsic on the popped operands - confirm against the full body.
308 auto apply_intrinsic = [&]<
size_t ARITY>(llvm::Intrinsic::ID id) {
309 apply_stack_op.operator()<ARITY>(
// Binary convenience wrapper around apply_stack_op<2>: pops two operands and
// pushes op_callable(a, b), where `a` is the deeper operand and `b` is the one
// that was on top of the stack.
313 auto apply_binary_op = [&](
auto op_callable) {
314 apply_stack_op.operator()<2>(
315 [&](
auto a,
auto b) {
return op_callable(a, b); });
// Emits a floating-point comparison of the top two stack entries using `pred`
// (deeper operand on the left, former top on the right) and pushes the outcome
// encoded as a float_ty constant: 1.0 when the comparison holds, 0.0 otherwise.
318 auto apply_binary_cmp = [&](llvm::CmpInst::Predicate pred) {
319 apply_stack_op.operator()<2>([&](
auto a,
auto b) {
320 auto cmp =
builder.CreateFCmp(pred, a, b);
// Turn the comparison result back into the float representation used on the
// RPN stack.
321 return builder.CreateSelect(cmp,
322 llvm::ConstantFP::get(float_ty, 1.0),
323 llvm::ConstantFP::get(float_ty, 0.0));
// Emits a logical operation on the top two stack entries. Each operand is
// first reduced to a boolean by an ordered "greater than 0.0" comparison, `op`
// combines the two booleans, and the result is pushed as 1.0 (true) or 0.0
// (false) in float_ty.
// Because the truth test is an ordered comparison, zero, negative and NaN
// operands all count as false.
327 auto apply_logical_op = [&](
auto op) {
328 apply_stack_op.operator()<2>([&](
auto a_val,
auto b_val) {
329 auto a_bool =
builder.CreateFCmpOGT(
330 a_val, llvm::ConstantFP::get(float_ty, 0.0));
331 auto b_bool =
builder.CreateFCmpOGT(
332 b_val, llvm::ConstantFP::get(float_ty, 0.0));
333 auto logic_res = op(a_bool, b_bool);
334 return builder.CreateSelect(logic_res,
335 llvm::ConstantFP::get(float_ty, 1.0),
336 llvm::ConstantFP::get(float_ty, 0.0));
// Emits a bitwise operation on the top two stack entries: both operands are
// converted to signed integers of type i32_ty, combined with `op`, and the
// integer result is converted back to float_ty before being pushed.
// NOTE(review): a_rounded / b_rounded are derived from `a` and `b` on lines not
// shown here - presumably a rounding step applied before the conversion;
// confirm which rounding mode is used.
340 auto apply_bitwise_op = [&](
auto op) {
341 apply_stack_op.operator()<2>([&](
auto a,
auto b) {
344 auto ai =
builder.CreateFPToSI(a_rounded, i32_ty);
345 auto bi =
builder.CreateFPToSI(b_rounded, i32_ty);
346 auto resi = op(ai, bi);
347 return builder.CreateSIToFP(resi, float_ty);
// Emits a unary or binary math operation (ARITY is restricted to 1 or 2 by the
// static_assert) with two possible lowerings selected by use_approx_math.
// Operands are popped so that args[0] is the deepest one and args[ARITY - 1]
// the former top of the stack.
// When use_approx_math is set, a call to `callee` is emitted with the popped
// operands, tagged with the builder's current fast-math flags, and pushed.
// NOTE(review): `callee` is resolved on a line not shown here - presumably an
// approximate implementation chosen by math_op; confirm. The other path pushes
// the result of a std::apply(...) whose callable is also not shown, presumably
// a call to the intrinsic_id intrinsic - confirm.
351 auto apply_approx_math_op =
352 [&]<
size_t ARITY>(
MathOp math_op, llvm::Intrinsic::ID intrinsic_id) {
353 static_assert(ARITY == 1 || ARITY == 2,
354 "Only unary or binary operations supported");
356 std::array<llvm::Value*, ARITY> args{};
// Top of stack lands in the last slot, preserving push order in `args`.
357 for (
size_t i = 0; i < ARITY; ++i) {
358 args.at(ARITY - 1 - i) = rpn_stack.back();
359 rpn_stack.pop_back();
362 if (use_approx_math) {
364 llvm::SmallVector<llvm::Value*, 2> call_args(args.begin(),
366 auto* call =
builder.CreateCall(callee, call_args);
367 call->setFastMathFlags(
builder.getFastMathFlags());
368 rpn_stack.push_back(call);
370 rpn_stack.push_back(std::apply(
378 switch (token.
type) {
380 const auto& payload = std::get<TokenPayloadNumber>(token.
payload);
381 rpn_stack.push_back(llvm::ConstantFP::get(float_ty, payload.value));
393 rpn_stack.push_back(
builder.CreateLoad(
398 rpn_stack.push_back(llvm::ConstantFP::get(float_ty, std::numbers::pi));
403 apply_binary_op([&](llvm::Value* a, llvm::Value* b) {
404 return builder.CreateFAdd(a, b);
408 apply_binary_op([&](llvm::Value* a, llvm::Value* b) {
409 return builder.CreateFSub(a, b);
413 apply_binary_op([&](llvm::Value* a, llvm::Value* b) {
414 return builder.CreateFMul(a, b);
418 apply_binary_op([&](llvm::Value* a, llvm::Value* b) {
419 return builder.CreateFDiv(a, b);
423 apply_binary_op([&](llvm::Value* a, llvm::Value* b) {
424 return builder.CreateFRem(a, b);
428 apply_intrinsic.operator()<2>(llvm::Intrinsic::pow);
432 llvm::Intrinsic::atan2);
435 apply_intrinsic.operator()<2>(llvm::Intrinsic::copysign);
438 apply_intrinsic.operator()<2>(llvm::Intrinsic::minnum);
441 apply_intrinsic.operator()<2>(llvm::Intrinsic::maxnum);
446 apply_binary_cmp(llvm::CmpInst::FCMP_OGT);
449 apply_binary_cmp(llvm::CmpInst::FCMP_OLT);
452 apply_binary_cmp(llvm::CmpInst::FCMP_OGE);
455 apply_binary_cmp(llvm::CmpInst::FCMP_OLE);
458 apply_binary_cmp(llvm::CmpInst::FCMP_OEQ);
464 [&](
auto a,
auto b) {
return builder.CreateAnd(a, b); });
468 [&](
auto a,
auto b) {
return builder.CreateOr(a, b); });
472 [&](
auto a,
auto b) {
return builder.CreateXor(a, b); });
478 [&](
auto a,
auto b) {
return builder.CreateAnd(a, b); });
482 [&](
auto a,
auto b) {
return builder.CreateOr(a, b); });
486 [&](
auto a,
auto b) {
return builder.CreateXor(a, b); });
491 auto* a = rpn_stack.back();
492 rpn_stack.pop_back();
493 auto* zero = llvm::ConstantFP::get(float_ty, 0.0);
500 apply_approx_math_op.operator()<1>(
MathOp::Exp, llvm::Intrinsic::exp);
503 apply_approx_math_op.operator()<1>(
MathOp::Log, llvm::Intrinsic::log);
506 apply_intrinsic.operator()<1>(llvm::Intrinsic::fabs);
509 apply_intrinsic.operator()<1>(llvm::Intrinsic::floor);
512 apply_intrinsic.operator()<1>(llvm::Intrinsic::ceil);
515 apply_intrinsic.operator()<1>(llvm::Intrinsic::trunc);
518 apply_intrinsic.operator()<1>(llvm::Intrinsic::round);
521 apply_approx_math_op.operator()<1>(
MathOp::Sin, llvm::Intrinsic::sin);
524 apply_approx_math_op.operator()<1>(
MathOp::Cos, llvm::Intrinsic::cos);
527 apply_approx_math_op.operator()<1>(
MathOp::Tan, llvm::Intrinsic::tan);
530 apply_approx_math_op.operator()<1>(
MathOp::Asin, llvm::Intrinsic::asin);
533 apply_approx_math_op.operator()<1>(
MathOp::Acos, llvm::Intrinsic::acos);
536 apply_approx_math_op.operator()<1>(
MathOp::Atan, llvm::Intrinsic::atan);
539 apply_intrinsic.operator()<1>(llvm::Intrinsic::exp2);
542 apply_intrinsic.operator()<1>(llvm::Intrinsic::log10);
545 apply_intrinsic.operator()<1>(llvm::Intrinsic::log2);
548 apply_intrinsic.operator()<1>(llvm::Intrinsic::sinh);
551 apply_intrinsic.operator()<1>(llvm::Intrinsic::cosh);
554 apply_intrinsic.operator()<1>(llvm::Intrinsic::tanh);
557 auto* x = rpn_stack.back();
558 rpn_stack.pop_back();
559 auto* zero = llvm::ConstantFP::get(float_ty, 0.0);
560 auto* one = llvm::ConstantFP::get(float_ty, 1.0);
561 auto* nonzero =
builder.CreateFCmpONE(x, zero);
562 auto* sign =
builder.CreateCall(
563 llvm::Intrinsic::getOrInsertDeclaration(
564 &
module, llvm::Intrinsic::copysign, {float_ty}),
566 rpn_stack.push_back(
builder.CreateSelect(nonzero, sign, zero));
570 auto* a = rpn_stack.back();
571 rpn_stack.pop_back();
572 rpn_stack.push_back(
builder.CreateFNeg(a));
576 auto* a = rpn_stack.back();
577 rpn_stack.pop_back();
578 rpn_stack.push_back(
builder.CreateSelect(
579 builder.CreateFCmpOLE(a, llvm::ConstantFP::get(float_ty, 0.0)),
580 llvm::ConstantFP::get(float_ty, 1.0),
581 llvm::ConstantFP::get(float_ty, 0.0)));
585 auto* a = rpn_stack.back();
586 rpn_stack.pop_back();
588 rpn_stack.push_back(
builder.CreateSIToFP(
596 auto* c = rpn_stack.back();
597 rpn_stack.pop_back();
598 auto* b = rpn_stack.back();
599 rpn_stack.pop_back();
600 auto* a = rpn_stack.back();
601 rpn_stack.pop_back();
602 rpn_stack.push_back(
builder.CreateSelect(
603 builder.CreateFCmpOGT(a, llvm::ConstantFP::get(float_ty, 0.0)), b,
609 auto* max_val = rpn_stack.back();
610 rpn_stack.pop_back();
611 auto* min_val = rpn_stack.back();
612 rpn_stack.pop_back();
613 auto* val = rpn_stack.back();
614 rpn_stack.pop_back();
618 rpn_stack.push_back(clamped);
622 auto* c = rpn_stack.back();
623 rpn_stack.pop_back();
624 auto* b = rpn_stack.back();
625 rpn_stack.pop_back();
626 auto* a = rpn_stack.back();
627 rpn_stack.pop_back();
628 rpn_stack.push_back(
builder.CreateCall(
629 llvm::Intrinsic::getOrInsertDeclaration(
630 &
module, llvm::Intrinsic::fma, {builder.getFloatTy()}),
637 const auto& payload = std::get<TokenPayloadStackOp>(token.
payload);
638 rpn_stack.push_back(rpn_stack[rpn_stack.size() - 1 - payload.n]);
642 const auto& payload = std::get<TokenPayloadStackOp>(token.
payload);
644 rpn_stack.resize(rpn_stack.size() - payload.n);
649 const auto& payload = std::get<TokenPayloadStackOp>(token.
payload);
650 std::swap(rpn_stack.back(),
651 rpn_stack[rpn_stack.size() - 1 - payload.n]);
655 const auto& payload = std::get<TokenPayloadStackOp>(token.
payload);
661 std::vector<llvm::Value*> values;
663 for (
int k = 0; k < n; ++k) {
664 values.push_back(rpn_stack.back());
665 rpn_stack.pop_back();
// Branch-free compare-exchange on two slots of `values`: emits an ordered
// "val_i > val_j" comparison and two selects so that afterwards values[i_idx]
// holds the smaller of the two operands and values[j_idx] the larger.
// If the comparison is false (including when either operand is NaN, since the
// comparison is ordered) both slots keep their original values.
668 auto compare_swap = [&](
int i_idx,
int j_idx) {
// Read both operands up front; the selects below overwrite the slots.
669 llvm::Value* val_i = values[i_idx];
670 llvm::Value* val_j = values[j_idx];
671 llvm::Value* cond =
builder.CreateFCmpOGT(val_i, val_j);
672 values[i_idx] =
builder.CreateSelect(cond, val_j, val_i);
673 values[j_idx] =
builder.CreateSelect(cond, val_i, val_j);
677 for (
const auto& pair : network) {
678 compare_swap(pair.first, pair.second);
681 for (
int k = n - 1; k >= 0; --k) {
682 rpn_stack.push_back(values[k]);
688 const auto& payload = std::get<TokenPayloadStackOp>(token.
payload);
692 llvm::ConstantFP::get(
builder.getFloatTy(), 0.0));
696 std::vector<llvm::Value*> values(n);
697 for (
int i = 0; i < n; ++i) {
698 values[i] = rpn_stack.back();
699 rpn_stack.pop_back();
706 std::vector<Node> current_level;
707 current_level.reserve(n);
708 for (
int i = 0; i < n; ++i) {
709 current_level.push_back(
711 llvm::ConstantFP::get(
builder.getFloatTy(),
712 static_cast<double>(n - 1 - i))});
717 while (current_level.size() > 1) {
718 std::vector<Node> next_level;
719 for (
size_t i = 0; i < current_level.size(); i += 2) {
720 if (i + 1 < current_level.size()) {
721 const auto& left = current_level[i];
722 const auto& right = current_level[i + 1];
724 llvm::Value* cmp_val =
725 is_max ?
builder.CreateFCmpOGT(left.val, right.val)
726 :
builder.CreateFCmpOLT(left.val, right.val);
728 llvm::Value* eq_val =
729 builder.CreateFCmpOEQ(left.val, right.val);
730 llvm::Value* cmp_idx =
731 builder.CreateFCmpOLT(left.idx, right.idx);
732 llvm::Value* tie_break =
builder.CreateAnd(eq_val, cmp_idx);
733 llvm::Value* cond =
builder.CreateOr(cmp_val, tie_break);
735 next_level.push_back(
736 {
builder.CreateSelect(cond, left.val, right.val),
737 builder.CreateSelect(cond, left.idx, right.idx)});
739 next_level.push_back(current_level[i]);
742 current_level = std::move(next_level);
744 rpn_stack.push_back(current_level[0].idx);
748 const auto& payload = std::get<TokenPayloadStackOp>(token.
payload);
754 rpn_stack.pop_back();
756 llvm::ConstantFP::get(
builder.getFloatTy(), 0.0));
760 std::vector<llvm::Value*> values(n);
761 std::vector<llvm::Value*> indices(n);
762 for (
int i = 0; i < n; ++i) {
763 values[i] = rpn_stack.back();
764 rpn_stack.pop_back();
765 indices[i] = llvm::ConstantFP::get(
builder.getFloatTy(),
766 static_cast<double>(n - 1 - i));
770 for (
const auto& pair : network) {
772 int i2 = pair.second;
774 llvm::Value* v1 = values[i1];
775 llvm::Value* v2 = values[i2];
776 llvm::Value* idx1 = indices[i1];
777 llvm::Value* idx2 = indices[i2];
779 llvm::Value* cmp_val =
builder.CreateFCmpOGT(v1, v2);
780 llvm::Value* eq_val =
builder.CreateFCmpOEQ(v1, v2);
781 llvm::Value* cmp_idx =
builder.CreateFCmpOGT(idx1, idx2);
782 llvm::Value* tie_break =
builder.CreateAnd(eq_val, cmp_idx);
783 llvm::Value* cond =
builder.CreateOr(cmp_val, tie_break);
785 values[i1] =
builder.CreateSelect(cond, v2, v1);
786 values[i2] =
builder.CreateSelect(cond, v1, v2);
787 indices[i1] =
builder.CreateSelect(cond, idx2, idx1);
788 indices[i2] =
builder.CreateSelect(cond, idx1, idx2);
791 for (
int i = n - 1; i >= 0; --i) {
792 rpn_stack.push_back(indices[i]);
809 llvm::Value* x_fp, llvm::Value* y_fp,
810 bool no_x_bounds_check) {
811 llvm::Type* float_ty =
builder.getFloatTy();
812 llvm::Type* i32_ty =
builder.getInt32Ty();
813 llvm::Function* parent_func =
builder.GetInsertBlock()->getParent();
815 bool use_approx_math =
false;
817 use_approx_math =
true;
820 use_approx_math =
true;
828 std::unordered_map<std::string, llvm::Value*> named_vars;
831 for (
const std::string& var_name : all_vars) {
835 std::map<int, llvm::BasicBlock*> llvm_blocks;
840 for (
int i = 0; i < static_cast<int>(cfg_blocks.size()); ++i) {
841 std::string name = std::format(
"b{}", i);
842 for (
const auto& [label_name, block_idx] : label_to_block_idx) {
843 if (block_idx == i) {
848 llvm_blocks[i] = llvm::BasicBlock::Create(
context, name, parent_func);
850 llvm::BasicBlock* exit_bb =
851 llvm::BasicBlock::Create(
context,
"exit", parent_func);
854 builder.CreateBr(llvm_blocks[0]);
857 std::map<int, std::vector<llvm::Value*>> block_initial_stacks;
858 for (
int i = 0; i < static_cast<int>(cfg_blocks.size()); ++i) {
859 if (cfg_blocks[i].predecessors.size() > 1) {
860 builder.SetInsertPoint(llvm_blocks[i]);
861 std::vector<llvm::Value*> initial_stack;
862 int depth = stack_depth_in[i];
863 initial_stack.reserve(depth);
864 for (
int j = 0; j < depth; ++j) {
865 initial_stack.push_back(
builder.CreatePHI(
866 float_ty, cfg_blocks[i].predecessors.size()));
868 block_initial_stacks[i] = initial_stack;
873 std::map<int, std::vector<llvm::Value*>> block_final_stacks;
875 for (
int i = 0; i < static_cast<int>(cfg_blocks.size()); ++i) {
876 const auto& block_info = cfg_blocks[i];
877 builder.SetInsertPoint(llvm_blocks[i]);
879 std::vector<llvm::Value*> rpn_stack;
880 if (block_info.predecessors.empty()) {
882 }
else if (block_info.predecessors.size() == 1) {
883 int pred_idx = block_info.predecessors[0];
884 if (block_final_stacks.contains(pred_idx)) {
885 rpn_stack = block_final_stacks.at(pred_idx);
888 rpn_stack = block_initial_stacks.at(i);
891 for (
int j = block_info.start_token_idx; j < block_info.end_token_idx;
893 const auto& token =
tokens[j];
903 const auto& payload = std::get<TokenPayloadVar>(token.payload);
904 llvm::Value* val_to_store = rpn_stack.back();
905 rpn_stack.pop_back();
906 llvm::Value* var_ptr = named_vars[payload.name];
907 builder.CreateStore(val_to_store, var_ptr);
911 const auto& payload = std::get<TokenPayloadVar>(token.payload);
912 llvm::Value* var_ptr = named_vars[payload.name];
913 rpn_stack.push_back(
builder.CreateLoad(float_ty, var_ptr));
919 no_x_bounds_check)) {
920 throw std::runtime_error(std::format(
921 "Unhandled token type: {}",
static_cast<int>(token.type)));
926 if (block_info.successors.empty()) {
928 }
else if (block_info.successors.size() == 1) {
929 builder.CreateBr(llvm_blocks[block_info.successors[0]]);
931 llvm::Value* cond_val = rpn_stack.back();
932 llvm::Value* cond =
builder.CreateFCmpOGT(
933 cond_val, llvm::ConstantFP::get(float_ty, 0.0));
934 builder.CreateCondBr(cond, llvm_blocks[block_info.successors[0]],
935 llvm_blocks[block_info.successors[1]]);
936 rpn_stack.pop_back();
939 block_final_stacks[i] = rpn_stack;
943 for (
int i = 0; i < static_cast<int>(cfg_blocks.size()); ++i) {
944 if (cfg_blocks[i].predecessors.size() > 1) {
945 auto& phis = block_initial_stacks.at(i);
946 for (
int pred_idx : cfg_blocks[i].predecessors) {
947 auto& incoming_stack = block_final_stacks.at(pred_idx);
948 auto* incoming_block = llvm_blocks.at(pred_idx);
949 for (
size_t j = 0; j < phis.size(); ++j) {
950 if (j < incoming_stack.size()) {
951 llvm::cast<llvm::PHINode>(phis[j])->addIncoming(
952 incoming_stack[j], incoming_block);
960 builder.SetInsertPoint(exit_bb);
961 std::vector<std::pair<llvm::Value*, llvm::BasicBlock*>> final_values;
962 for (
int i = 0; i < static_cast<int>(cfg_blocks.size()); ++i) {
963 if (cfg_blocks[i].successors.empty()) {
964 auto& stack = block_final_stacks.at(i);
965 if (!stack.empty()) {
966 final_values.emplace_back(stack.back(), llvm_blocks.at(i));
971 llvm::Value* result_val =
nullptr;
972 if (final_values.empty()) {
973 result_val = llvm::UndefValue::get(float_ty);
974 }
else if (final_values.size() == 1) {
975 result_val = final_values[0].first;
978 builder.CreatePHI(float_ty, final_values.size(),
"result_phi");
979 for (
const auto& pair : final_values) {
980 phi->addIncoming(pair.first, pair.second);