28#include <system_error>
32#include "llvm/IR/Attributes.h"
33#include "llvm/IR/LLVMContext.h"
34#include "llvm/IR/LegacyPassManager.h"
35#include "llvm/IR/Module.h"
36#include "llvm/IR/PassInstrumentation.h"
37#include "llvm/IR/PassTimingInfo.h"
38#include "llvm/IR/Verifier.h"
39#include "llvm/Passes/PassBuilder.h"
40#include "llvm/Support/FileSystem.h"
41#include "llvm/Support/raw_ostream.h"
// Interprets the text of an environment variable as an on/off switch.
//
// value: NUL-terminated C string, or nullptr.
//
// NOTE(review): this listing carries fused source line numbers and skips
// several source lines (51-53, 58-59, 61-62, 64-66 and everything after 67),
// so the bodies of the guarded branches and the tail of the return expression
// are not visible here. Comments below describe only the visible statements.
49bool is_env_var_enabled(
const char* value) {
// Guard for a null pointer or an empty string; the branch body is missing
// from this listing (presumably it returns false - confirm).
50 if (value ==
nullptr || *value ==
'\0') {
// Normalized copy of the input that the final comparison is made against.
54 std::string normalized;
// Walk the C string one character at a time up to the terminating NUL.
55 for (
const char* ptr = value; *ptr !=
'\0'; ++ptr) {
// Convert to unsigned char first: std::isspace and std::tolower require an
// argument representable as unsigned char (or EOF).
56 const auto ch =
static_cast<unsigned char>(*ptr);
// Whitespace test; the branch body is missing from this listing
// (presumably the character is skipped - confirm).
57 if (std::isspace(ch) != 0) {
// Append the lower-cased character to the normalized copy.
60 normalized.push_back(
static_cast<char>(std::tolower(ch)));
// Guard for a normalized string that ended up empty; branch body not visible.
63 if (normalized.empty()) {
// Visible part of the result: the normalized text must differ from the
// literals 0, false and off. The expression continues with a further operand
// on a source line that is missing from this listing.
67 return normalized !=
"0" && normalized !=
"false" && normalized !=
"off" &&
// Reports whether the LLVMEXPR_TIME_PASSES environment variable holds an
// enabling value, as decided by is_env_var_enabled.
// std::getenv yields nullptr when the variable is not set; that pointer is
// forwarded unchanged to is_env_var_enabled.
// NOTE(review): the closing brace of this function is not part of this listing.
71bool is_time_passes_enabled() {
72 return is_env_var_enabled(std::getenv(
"LLVMEXPR_TIME_PASSES"));
// Parameter list and member-initializer list of a constructor whose opening
// line (source line 77) is missing from this listing - presumably
// Compiler::Compiler; confirm against the full file.
// The body is empty: all work is done in the initializer list.
//
// tokens_in, dump_path and function_name are taken by value and moved into
// their members; the other arguments initialize their members directly.
78 std::vector<Token> tokens_in,
const VSVideoInfo* out_vi,
79 const std::vector<const VSVideoInfo*>& in_vi,
int width_in,
int height_in,
80 bool mirror, std::string dump_path,
81 const std::map<std::pair<int, std::string>,
int>& p_map,
82 std::string function_name,
int opt_level_in,
int approx_math_in,
// NOTE(review): source line 83 is missing here; analysis_results_in, which is
// used in the initializer list below, is presumably declared on it.
84 int tile_x_in,
int tile_y_in,
ExprMode mode,
85 const std::vector<std::string>& output_props)
86 : tokens(std::move(tokens_in)), vo(out_vi), vi(in_vi),
// num_inputs is derived from the number of entries in in_vi.
87 num_inputs(static_cast<int>(in_vi.size())), width(width_in),
88 height(height_in), mirror_boundary(mirror),
89 dump_ir_path(std::move(dump_path)), prop_map(p_map),
90 func_name(std::move(function_name)), opt_level(opt_level_in),
91 approx_math(approx_math_in), tile_x(tile_x_in), tile_y(tile_y_in),
// In output_props(output_props) the name inside the parentheses is the
// parameter, which shadows the member of the same name.
92 expr_mode(mode), output_props(output_props),
93 analysis_results(analysis_results_in) {}
// Body fragment of a function whose signature line (source line 95) is
// missing from this listing - presumably Compiler::compile; confirm.
// An approx_math setting of 2 is translated into compileWithApproxMath(1);
// every other setting is passed through unchanged.
// NOTE(review): the meaning of mode 2 is not stated in the visible text. A
// later fragment of this file retries with compileWithApproxMath(0) under a
// condition that includes actual_approx_math == 1, so 2 may select an
// approximate attempt with a precise fallback - confirm.
96 if (approx_math == 2) {
97 return compileWithApproxMath(1);
99 return compileWithApproxMath(approx_math);
// Body fragment of a member function whose signature (source line 102) is
// missing from this listing and which continues past the last visible
// statement. It reads actual_approx_math, which is not declared in the visible
// text - presumably the parameter of compileWithApproxMath; confirm.
//
// Visible stages: build an LLVM module through an IR generator, attach
// function attributes, verify, optionally dump IR, run an optimization
// pipeline, verify again, optionally dump again, optionally retry through a
// second Compiler, then hand the module to the JIT and look up the function.
// Many source lines are absent; comments describe only what is visible.

// needs_nans starts false and is assigned from two scans over tokens. The
// predicates and the conditions that choose between the scans are not visible.
103 bool needs_nans =
false;
105 needs_nans = std::ranges::any_of(tokens, [](
const auto& token) {
110 needs_nans = std::ranges::any_of(tokens, [](
const auto& token) {
// A local diagnostic handler is reset and its address is passed as the last
// visible argument of the context diagnostic callback registration. The
// callback argument itself (source line 124) is not visible.
118 VectorizationDiagnosticHandler diagnostic_handler;
119 diagnostic_handler.
reset();
122 auto context = std::make_unique<llvm::LLVMContext>();
123 context->setDiagnosticHandlerCallBack(
125 &diagnostic_handler);
// The module is created in the fresh context and takes its data layout from jit.
127 auto module = std::make_unique<llvm::Module>("ExprJITModule", *context);
128 module->setDataLayout(jit.getDataLayout());
// Fast-math flags for the builder. The no-NaNs flag is set only when no token
// was found to need NaNs. Source line 133 (any other flag setup) is not visible.
131 llvm::IRBuilder<> builder(*context);
132 llvm::FastMathFlags fmf;
134 fmf.setNoNaNs(!needs_nans);
135 builder.setFastMathFlags(fmf);
138 MathLibraryManager math_manager(module.get(), *context);
// One of two generator types is created. The condition choosing between them
// (source lines 142 and 147) is not visible - presumably expr_mode; confirm.
// ExprIRGenerator additionally receives width, height, tile_x and tile_y;
// SingleExprIRGenerator receives output_props instead.
141 std::unique_ptr<IRGeneratorBase> ir_gen;
143 ir_gen = std::make_unique<ExprIRGenerator>(
144 tokens, vo, vi, width, height, mirror_boundary, prop_map,
145 analysis_results, *context, *module, builder, math_manager,
146 func_name, actual_approx_math, tile_x, tile_y);
148 ir_gen = std::make_unique<SingleExprIRGenerator>(
149 tokens, vo, vi, mirror_boundary, prop_map, output_props,
150 analysis_results, *context, *module, builder, math_manager,
151 func_name, actual_approx_math);
// The generator call itself (source lines 152-155) is not visible. Afterwards
// the function is looked up by func_name; a missing function is reported as
// std::runtime_error.
156 llvm::Function* func =
module->getFunction(func_name);
157 if (func ==
nullptr) {
158 throw std::runtime_error(
"Failed to find generated function");
// Function attributes are collected in a builder and applied in one call.
// Each string attribute with a visible guard is added when the matching
// fast-math flag is set on fmf.
161 llvm::AttrBuilder func_attrs(func->getContext());
162 if (fmf.allowContract()) {
163 func_attrs.addAttribute(
"fp-contract",
"fast");
165 if (fmf.approxFunc()) {
166 func_attrs.addAttribute(
"approx-func-fp-math",
"true");
// The guards for the next two attributes (source lines 168 and 171) are not
// visible in this listing.
169 func_attrs.addAttribute(
"no-infs-fp-math",
"true");
172 func_attrs.addAttribute(
"no-nans-fp-math",
"true");
174 if (fmf.noSignedZeros()) {
175 func_attrs.addAttribute(
"no-signed-zeros-fp-math",
"true");
177 if (fmf.allowReciprocal()) {
178 func_attrs.addAttribute(
"allow-reciprocal-fp-math",
"true");
// These three attributes have no visible guard.
182 func_attrs.addAttribute(
"no-stack-arg-probe",
"true");
184 func_attrs.addAttribute(llvm::Attribute::NoUnwind);
185 func_attrs.addAttribute(llvm::Attribute::WillReturn);
186 func->addFnAttrs(func_attrs);
// Pre-optimization check. llvm::verifyModule returns true when the module is
// broken; the module text is then printed to llvm::errs() before throwing.
189 if (llvm::verifyModule(*module, &llvm::errs())) {
190 module->print(llvm::errs(), nullptr);
191 throw std::runtime_error(
"LLVM module verification failed (pre-opt).");
// Dump path: when dump_ir_path is set, a dot plus func_name is inserted before
// the last dot of the path, or appended (the else line is not visible) when
// the path has no dot.
195 std::string plane_specific_dump_path;
196 if (!dump_ir_path.empty()) {
197 plane_specific_dump_path = dump_ir_path;
198 size_t dot_pos = plane_specific_dump_path.rfind(
'.');
199 if (dot_pos != std::string::npos) {
200 plane_specific_dump_path.insert(dot_pos,
"." + func_name);
202 plane_specific_dump_path +=
"." + func_name;
// The unoptimized IR is written to the dump path with a .pre.ll suffix.
// The declaration of ec and any check of it (source lines 205 and 208) are
// not visible.
206 std::string pre_path = plane_specific_dump_path +
".pre.ll";
207 llvm::raw_fd_ostream dest_pre(pre_path, ec, llvm::sys::fs::OF_None);
209 module->print(dest_pre, nullptr);
// Analysis managers for the pass pipeline.
216 llvm::LoopAnalysisManager lam;
217 llvm::FunctionAnalysisManager fam;
218 llvm::CGSCCAnalysisManager cgam;
219 llvm::ModuleAnalysisManager mam;
// Pass timing: the handler is always constructed with the enabled flag, but
// its callbacks are registered on pic only when timing is enabled.
221 const bool time_passes_enabled = is_time_passes_enabled();
222 llvm::PassInstrumentationCallbacks pic;
223 llvm::TimePassesHandler time_passes_handler(time_passes_enabled);
225 if (time_passes_enabled) {
226 time_passes_handler.registerCallbacks(pic);
// The PassBuilder is created with nullptr as its first argument (the
// TargetMachine pointer in the LLVM PassBuilder constructor) and with pic as
// its instrumentation callbacks.
229 llvm::PassBuilder pb(
nullptr, llvm::PipelineTuningOptions(), {}, &pic);
230 pb.registerModuleAnalyses(mam);
231 pb.registerFunctionAnalyses(fam);
232 pb.registerCGSCCAnalyses(cgam);
233 pb.registerLoopAnalyses(lam);
234 pb.crossRegisterProxies(lam, fam, cgam, mam);
// Pipeline text: one default<O3> entry, plus one more comma-separated entry
// for each i from 1 up to opt_level - 1. Source line 238, which may guard
// this, is not visible.
236 llvm::ModulePassManager mpm;
237 std::string pipeline;
239 pipeline =
"default<O3>";
240 for (
int i = 1; i < opt_level; ++i) {
241 pipeline +=
",default<O3>";
// parsePassPipeline returns an llvm::Error that converts to true on failure;
// the message is written to llvm::errs() and a std::runtime_error is thrown.
244 if (
auto err = pb.parsePassPipeline(mpm, pipeline)) {
245 llvm::errs() <<
"Failed to parse '" << pipeline
246 <<
"' pipeline: " << llvm::toString(std::move(err))
248 throw std::runtime_error(
249 "Failed to create default optimization pipeline.");
251 mpm.run(*module, mam);
// The timing report is printed only when timing was enabled.
253 if (time_passes_enabled) {
254 time_passes_handler.print();
// Post-optimization check, same pattern as the pre-optimization one.
259 if (llvm::verifyModule(*module, &llvm::errs())) {
260 module->print(llvm::errs(), nullptr);
261 throw std::runtime_error(
"LLVM module verification failed.");
// The optimized IR is written to the dump path when one was computed. The
// condition guarding the throw (source line 269) is not visible - presumably
// a test of ec; confirm.
265 if (!plane_specific_dump_path.empty()) {
267 llvm::raw_fd_ostream dest(plane_specific_dump_path, ec,
268 llvm::sys::fs::OF_None);
270 throw std::runtime_error(
"Could not open file: " + ec.message() +
271 " for writing IR to " +
272 plane_specific_dump_path);
274 module->print(dest, nullptr);
// Retry path. Only the second half of the condition is visible
// (actual_approx_math == 1); the first half (source line 279) presumably
// consults diagnostic_handler - confirm. A second compiler is built from a
// copy of tokens and the current members, and its compileWithApproxMath(0)
// result is returned.
280 actual_approx_math == 1) {
282 std::vector<Token>(tokens), vo, vi, width, height, mirror_boundary,
283 dump_ir_path, prop_map, func_name, opt_level, approx_math,
284 analysis_results, tile_x, tile_y, expr_mode, output_props);
285 return fallback_compiler.compileWithApproxMath(0);
// Ownership of both the module and its context is moved into jit.
289 jit.
addModule(std::move(module), std::move(context));
// The lookup that produces func_addr (source lines 290-291) is not visible;
// a null address is reported as std::runtime_error.
292 if (func_addr ==
nullptr) {
293 throw std::runtime_error(
"Failed to get JIT'd function address.");
// The result object is declared here; the statements that fill and return it
// lie beyond the end of this listing.
296 CompiledFunction compiled;
OrcJit global_jit_nan_safe(false)
OrcJit global_jit_fast(true)
void(*)(void *context, uint8_t **rwptrs, const int *strides, float *props) ProcessProc
CompiledFunction compile()
Compiler(std::vector< Token > tokens_in, const VSVideoInfo *out_vi, const std::vector< const VSVideoInfo * > &in_vi, int width_in, int height_in, bool mirror, std::string dump_path, const std::map< std::pair< int, std::string >, int > &p_map, std::string function_name, int opt_level_in, int approx_math_in, const analysis::ExpressionAnalysisResults &analysis_results_in, int tile_x_in=0, int tile_y_in=0, ExprMode mode=ExprMode::Expr, const std::vector< std::string > &output_props={})
void * getFunctionAddress(const std::string &name)
void addModule(std::unique_ptr< llvm::Module > m, std::unique_ptr< llvm::LLVMContext > ctx)
static void diagnosticHandlerCallback(const llvm::DiagnosticInfo *di, void *context)
bool hasVectorizationFailed() const