VapourSynth-llvmexpr
Loading...
Searching...
No Matches
Compiler.cpp
Go to the documentation of this file.
1
19
20#include "Compiler.hpp"
21
22#include <algorithm>
23#include <cctype>
24#include <cstdlib>
25#include <memory>
26#include <stdexcept>
27#include <string>
28#include <system_error>
29#include <utility>
30#include <vector>
31
32#include "llvm/IR/Attributes.h"
33#include "llvm/IR/LLVMContext.h"
34#include "llvm/IR/LegacyPassManager.h"
35#include "llvm/IR/Module.h"
36#include "llvm/IR/PassInstrumentation.h"
37#include "llvm/IR/PassTimingInfo.h"
38#include "llvm/IR/Verifier.h"
39#include "llvm/Passes/PassBuilder.h"
40#include "llvm/Support/FileSystem.h"
41#include "llvm/Support/raw_ostream.h"
42
46
namespace {

/// Interprets the raw value of an environment variable as an on/off switch.
///
/// All whitespace characters (anywhere in the value, not only at the ends)
/// are discarded and the remainder is lower-cased before comparison.
///
/// @param value Raw value as returned by std::getenv; may be nullptr.
/// @return false when the variable is unset, empty, whitespace-only, or one
///         of "0", "false", "off", "no" (case-insensitive); true otherwise.
bool is_env_var_enabled(const char* value) {
  if (value == nullptr) {
    return false;
  }

  std::string folded;
  for (const char* cursor = value; *cursor != '\0'; ++cursor) {
    // Route through unsigned char: passing a negative char to <cctype> is UB.
    const auto byte = static_cast<unsigned char>(*cursor);
    if (std::isspace(byte) == 0) {
      folded += static_cast<char>(std::tolower(byte));
    }
  }

  if (folded.empty()) {
    return false;
  }

  static constexpr const char* kDisabledSpellings[] = {"0", "false", "off",
                                                       "no"};
  return std::none_of(std::begin(kDisabledSpellings),
                      std::end(kDisabledSpellings),
                      [&folded](const char* spelling) {
                        return folded == spelling;
                      });
}

/// Reports whether per-pass timing was requested through the
/// LLVMEXPR_TIME_PASSES environment variable.
bool is_time_passes_enabled() {
  const char* const raw_setting = std::getenv("LLVMEXPR_TIME_PASSES");
  return is_env_var_enabled(raw_setting);
}

} // namespace
76
78 std::vector<Token> tokens_in, const VSVideoInfo* out_vi,
79 const std::vector<const VSVideoInfo*>& in_vi, int width_in, int height_in,
80 bool mirror, std::string dump_path,
81 const std::map<std::pair<int, std::string>, int>& p_map,
82 std::string function_name, int opt_level_in, int approx_math_in,
83 const analysis::ExpressionAnalysisResults& analysis_results_in,
84 int tile_x_in, int tile_y_in, ExprMode mode,
85 const std::vector<std::string>& output_props)
86 : tokens(std::move(tokens_in)), vo(out_vi), vi(in_vi),
87 num_inputs(static_cast<int>(in_vi.size())), width(width_in),
88 height(height_in), mirror_boundary(mirror),
89 dump_ir_path(std::move(dump_path)), prop_map(p_map),
90 func_name(std::move(function_name)), opt_level(opt_level_in),
91 approx_math(approx_math_in), tile_x(tile_x_in), tile_y(tile_y_in),
92 expr_mode(mode), output_props(output_props),
93 analysis_results(analysis_results_in) {}
94
96 if (approx_math == 2) {
97 return compileWithApproxMath(1);
98 }
99 return compileWithApproxMath(approx_math);
100}
101
102CompiledFunction Compiler::compileWithApproxMath(int actual_approx_math) {
103 bool needs_nans = false;
104 if (expr_mode == ExprMode::Expr) {
105 needs_nans = std::ranges::any_of(tokens, [](const auto& token) {
106 return token.type == TokenType::ExitNoWrite ||
107 token.type == TokenType::PropExists;
108 });
109 } else if (expr_mode == ExprMode::SingleExpr) {
110 needs_nans = std::ranges::any_of(tokens, [](const auto& token) {
111 return token.type == TokenType::PropStore ||
112 token.type == TokenType::PropExists;
113 });
114 }
115
116 OrcJit& jit = needs_nans ? global_jit_nan_safe : global_jit_fast;
117
118 VectorizationDiagnosticHandler diagnostic_handler;
119 diagnostic_handler.reset();
120
121 // Create LLVM context and module
122 auto context = std::make_unique<llvm::LLVMContext>();
123 context->setDiagnosticHandlerCallBack(
125 &diagnostic_handler);
126
127 auto module = std::make_unique<llvm::Module>("ExprJITModule", *context);
128 module->setDataLayout(jit.getDataLayout());
129
130 // Set up fast math flags
131 llvm::IRBuilder<> builder(*context);
132 llvm::FastMathFlags fmf;
133 fmf.setFast();
134 fmf.setNoNaNs(!needs_nans);
135 builder.setFastMathFlags(fmf);
136
137 // Create math library manager
138 MathLibraryManager math_manager(module.get(), *context);
139
140 // Create IR generator and generate code
141 std::unique_ptr<IRGeneratorBase> ir_gen;
142 if (expr_mode == ExprMode::Expr) {
143 ir_gen = std::make_unique<ExprIRGenerator>(
144 tokens, vo, vi, width, height, mirror_boundary, prop_map,
145 analysis_results, *context, *module, builder, math_manager,
146 func_name, actual_approx_math, tile_x, tile_y);
147 } else {
148 ir_gen = std::make_unique<SingleExprIRGenerator>(
149 tokens, vo, vi, mirror_boundary, prop_map, output_props,
150 analysis_results, *context, *module, builder, math_manager,
151 func_name, actual_approx_math);
152 }
153 ir_gen->generate();
154
155 // Get the generated function and set attributes
156 llvm::Function* func = module->getFunction(func_name);
157 if (func == nullptr) {
158 throw std::runtime_error("Failed to find generated function");
159 }
160
161 llvm::AttrBuilder func_attrs(func->getContext());
162 if (fmf.allowContract()) {
163 func_attrs.addAttribute("fp-contract", "fast");
164 }
165 if (fmf.approxFunc()) {
166 func_attrs.addAttribute("approx-func-fp-math", "true");
167 }
168 if (fmf.noInfs()) {
169 func_attrs.addAttribute("no-infs-fp-math", "true");
170 }
171 if (fmf.noNaNs()) {
172 func_attrs.addAttribute("no-nans-fp-math", "true");
173 }
174 if (fmf.noSignedZeros()) {
175 func_attrs.addAttribute("no-signed-zeros-fp-math", "true");
176 }
177 if (fmf.allowReciprocal()) {
178 func_attrs.addAttribute("allow-reciprocal-fp-math", "true");
179 }
180#ifdef _WIN32
181 // Fix for missing ___chkstk_ms symbol
182 func_attrs.addAttribute("no-stack-arg-probe", "true");
183#endif
184 func_attrs.addAttribute(llvm::Attribute::NoUnwind);
185 func_attrs.addAttribute(llvm::Attribute::WillReturn);
186 func->addFnAttrs(func_attrs);
187
188 // Verify module before optimization
189 if (llvm::verifyModule(*module, &llvm::errs())) {
190 module->print(llvm::errs(), nullptr);
191 throw std::runtime_error("LLVM module verification failed (pre-opt).");
192 }
193
194 // Dump pre-optimization IR if requested
195 std::string plane_specific_dump_path;
196 if (!dump_ir_path.empty()) {
197 plane_specific_dump_path = dump_ir_path;
198 size_t dot_pos = plane_specific_dump_path.rfind('.');
199 if (dot_pos != std::string::npos) {
200 plane_specific_dump_path.insert(dot_pos, "." + func_name);
201 } else {
202 plane_specific_dump_path += "." + func_name;
203 }
204
205 std::error_code ec;
206 std::string pre_path = plane_specific_dump_path + ".pre.ll";
207 llvm::raw_fd_ostream dest_pre(pre_path, ec, llvm::sys::fs::OF_None);
208 if (!ec) {
209 module->print(dest_pre, nullptr);
210 dest_pre.flush();
211 }
212 }
213
214 // Run optimization passes
215 {
216 llvm::LoopAnalysisManager lam;
217 llvm::FunctionAnalysisManager fam;
218 llvm::CGSCCAnalysisManager cgam;
219 llvm::ModuleAnalysisManager mam;
220
221 const bool time_passes_enabled = is_time_passes_enabled();
222 llvm::PassInstrumentationCallbacks pic;
223 llvm::TimePassesHandler time_passes_handler(time_passes_enabled);
224
225 if (time_passes_enabled) {
226 time_passes_handler.registerCallbacks(pic);
227 }
228
229 llvm::PassBuilder pb(nullptr, llvm::PipelineTuningOptions(), {}, &pic);
230 pb.registerModuleAnalyses(mam);
231 pb.registerFunctionAnalyses(fam);
232 pb.registerCGSCCAnalyses(cgam);
233 pb.registerLoopAnalyses(lam);
234 pb.crossRegisterProxies(lam, fam, cgam, mam);
235
236 llvm::ModulePassManager mpm;
237 std::string pipeline;
238 if (opt_level > 0) {
239 pipeline = "default<O3>";
240 for (int i = 1; i < opt_level; ++i) {
241 pipeline += ",default<O3>";
242 }
243 }
244 if (auto err = pb.parsePassPipeline(mpm, pipeline)) {
245 llvm::errs() << "Failed to parse '" << pipeline
246 << "' pipeline: " << llvm::toString(std::move(err))
247 << "\n";
248 throw std::runtime_error(
249 "Failed to create default optimization pipeline.");
250 }
251 mpm.run(*module, mam);
252
253 if (time_passes_enabled) {
254 time_passes_handler.print();
255 }
256 }
257
258 // Verify module after optimization
259 if (llvm::verifyModule(*module, &llvm::errs())) {
260 module->print(llvm::errs(), nullptr);
261 throw std::runtime_error("LLVM module verification failed.");
262 }
263
264 // Dump post-optimization IR if requested
265 if (!plane_specific_dump_path.empty()) {
266 std::error_code ec;
267 llvm::raw_fd_ostream dest(plane_specific_dump_path, ec,
268 llvm::sys::fs::OF_None);
269 if (ec) {
270 throw std::runtime_error("Could not open file: " + ec.message() +
271 " for writing IR to " +
272 plane_specific_dump_path);
273 }
274 module->print(dest, nullptr);
275 dest.flush();
276 }
277
278 // Handle vectorization fallback
279 if (diagnostic_handler.hasVectorizationFailed() && approx_math == 2 &&
280 actual_approx_math == 1) {
281 Compiler fallback_compiler(
282 std::vector<Token>(tokens), vo, vi, width, height, mirror_boundary,
283 dump_ir_path, prop_map, func_name, opt_level, approx_math,
284 analysis_results, tile_x, tile_y, expr_mode, output_props);
285 return fallback_compiler.compileWithApproxMath(0);
286 }
287
288 // Add module to JIT and get function address
289 jit.addModule(std::move(module), std::move(context));
290 void* func_addr = jit.getFunctionAddress(func_name);
291
292 if (func_addr == nullptr) {
293 throw std::runtime_error("Failed to get JIT'd function address.");
294 }
295
296 CompiledFunction compiled;
297 compiled.func_ptr =
298 reinterpret_cast< // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast)
299 ProcessProc>(func_addr);
300 return compiled;
301}
OrcJit global_jit_nan_safe(false)
OrcJit global_jit_fast(true)
void(*)(void *context, uint8_t **rwptrs, const int *strides, float *props) ProcessProc
Definition Jit.hpp:35
ExprMode
CompiledFunction compile()
Definition Compiler.cpp:95
Compiler(std::vector< Token > tokens_in, const VSVideoInfo *out_vi, const std::vector< const VSVideoInfo * > &in_vi, int width_in, int height_in, bool mirror, std::string dump_path, const std::map< std::pair< int, std::string >, int > &p_map, std::string function_name, int opt_level_in, int approx_math_in, const analysis::ExpressionAnalysisResults &analysis_results_in, int tile_x_in=0, int tile_y_in=0, ExprMode mode=ExprMode::Expr, const std::vector< std::string > &output_props={})
Definition Compiler.cpp:77
void * getFunctionAddress(const std::string &name)
Definition Jit.cpp:198
void addModule(std::unique_ptr< llvm::Module > m, std::unique_ptr< llvm::LLVMContext > ctx)
Definition Jit.cpp:157
static void diagnosticHandlerCallback(const llvm::DiagnosticInfo *di, void *context)
ProcessProc func_ptr
Definition Jit.hpp:39