VapourSynth-llvmexpr
Loading...
Searching...
No Matches
Jit.cpp
Go to the documentation of this file.
1
19
20#include "Jit.hpp"
21
22#ifdef _WIN32
23#include <cmath>
24#endif
25#include <stdexcept>
26#include <string>
27#include <vector>
28
29#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
30#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
31#include "llvm/Support/Error.h"
32#include "llvm/Support/TargetSelect.h"
33#include "llvm/Support/raw_ostream.h"
34#include "llvm/Target/TargetOptions.h"
35#include "llvm/TargetParser/Host.h"
36
// Host API entry points that JIT-compiled code calls for dynamic array
// management. Their definitions live outside this file; they are declared with
// C language linkage so their unmangled names match the symbols registered in
// the OrcJit constructor.
extern "C" float* llvmexpr_ensure_buffer(const char*, int64_t);
extern "C" int64_t llvmexpr_get_buffer_size(const char*);
42
#ifdef _WIN32
namespace {

// Stores sin(x) and cos(x) through whichever output pointers are non-null.
// Shared by the float and double C-linkage wrappers below; the Real type
// selects the matching std::sin / std::cos overload.
template <typename Real>
void store_sin_and_cos(Real x, Real* sin_out, Real* cos_out) {
  if (sin_out != nullptr) {
    *sin_out = std::sin(x);
  }
  if (cos_out != nullptr) {
    *cos_out = std::cos(x);
  }
}

// Single-precision sincos replacement. The OrcJit constructor registers it
// under the name "sincosf" on Windows builds.
extern "C" void llvmexpr_sincosf(float x, float* s, float* c) {
  store_sin_and_cos(x, s, c);
}

// Double-precision sincos replacement. The OrcJit constructor registers it
// under the name "sincos" on Windows builds.
extern "C" void llvmexpr_sincos(double x, double* s, double* c) {
  store_sin_and_cos(x, s, c);
}

} // namespace
#endif
65
66OrcJit::OrcJit(bool no_nans_fp_math) {
67 static struct LLVMInitializer {
68 LLVMInitializer() { // NOLINT(modernize-use-equals-default)
69 llvm::InitializeNativeTarget();
70 llvm::InitializeNativeTargetAsmPrinter();
71 llvm::InitializeNativeTargetAsmParser();
72 }
73 LLVMInitializer(const LLVMInitializer&) = delete;
74 LLVMInitializer& operator=(const LLVMInitializer&) = delete;
75 LLVMInitializer(LLVMInitializer&&) = delete;
76 LLVMInitializer& operator=(LLVMInitializer&&) = delete;
77 ~LLVMInitializer() = default;
78 } initializer;
79
80 auto jtmb =
81 llvm::cantFail(llvm::orc::JITTargetMachineBuilder::detectHost());
82 jtmb.setCodeGenOptLevel(llvm::CodeGenOptLevel::Aggressive);
83
84 llvm::StringMap<bool> host_features = llvm::sys::getHostCPUFeatures();
85 if (host_features.size() > 0) {
86 std::vector<std::string> features;
87 for (auto& f : host_features) {
88 if (f.getValue()) {
89 features.push_back("+" + f.getKey().str());
90 }
91 }
92 jtmb.addFeatures(features);
93 }
94
95 llvm::TargetOptions opts;
96 opts.AllowFPOpFusion = llvm::FPOpFusion::Fast;
97#if LLVM_VERSION_MAJOR < 22
98 opts.UnsafeFPMath = true;
99#endif
100 opts.NoInfsFPMath = true;
101 opts.NoNaNsFPMath = no_nans_fp_math;
102 jtmb.setOptions(opts);
103
104 auto jit_builder = llvm::orc::LLJITBuilder();
105 jit_builder.setJITTargetMachineBuilder(std::move(jtmb));
106 auto temp_jit = jit_builder.create();
107 if (!temp_jit) {
108 llvm::errs() << "Failed to create LLJIT instance: "
109 << llvm::toString(temp_jit.takeError()) << "\n";
110 throw std::runtime_error("LLJIT creation failed");
111 }
112 lljit = std::move(*temp_jit);
113
114 // Register Host API symbols for dynamic array management
115 auto& main_jd = lljit->getMainJITDylib();
116 llvm::orc::SymbolMap symbols;
117
118 symbols[lljit->mangleAndIntern("llvmexpr_ensure_buffer")] =
119 llvm::orc::ExecutorSymbolDef(
120 llvm::orc::ExecutorAddr(
121 llvm::pointerToJITTargetAddress(&llvmexpr_ensure_buffer)),
122 llvm::JITSymbolFlags::Callable | llvm::JITSymbolFlags::Exported);
123
124 symbols[lljit->mangleAndIntern("llvmexpr_get_buffer_size")] =
125 llvm::orc::ExecutorSymbolDef(
126 llvm::orc::ExecutorAddr(
127 llvm::pointerToJITTargetAddress(&llvmexpr_get_buffer_size)),
128 llvm::JITSymbolFlags::Callable | llvm::JITSymbolFlags::Exported);
129
130#ifdef _WIN32
131 symbols[lljit->mangleAndIntern("sincosf")] = llvm::orc::ExecutorSymbolDef(
132 llvm::orc::ExecutorAddr(
133 llvm::pointerToJITTargetAddress(&llvmexpr_sincosf)),
134 llvm::JITSymbolFlags::Callable | llvm::JITSymbolFlags::Exported);
135
136 symbols[lljit->mangleAndIntern("sincos")] = llvm::orc::ExecutorSymbolDef(
137 llvm::orc::ExecutorAddr(
138 llvm::pointerToJITTargetAddress(&llvmexpr_sincos)),
139 llvm::JITSymbolFlags::Callable | llvm::JITSymbolFlags::Exported);
140#endif
141
142 if (auto err = main_jd.define(llvm::orc::absoluteSymbols(symbols))) {
143 llvm::errs() << "Failed to define host call symbols: "
144 << llvm::toString(std::move(err)) << "\n";
145 throw std::runtime_error("Failed to define host call symbols in JIT");
146 }
147}
148
149const llvm::DataLayout& OrcJit::getDataLayout() const {
150 return lljit->getDataLayout();
151}
152
153const llvm::Triple& OrcJit::getTargetTriple() const {
154 return lljit->getTargetTriple();
155}
156
157void OrcJit::addModule(std::unique_ptr<llvm::Module> m,
158 std::unique_ptr<llvm::LLVMContext> ctx) {
159 std::vector<llvm::Function*> functions_to_remove;
160 for (auto& f : *m) {
161 if (!f.isDeclaration()) {
162 std::string func_name = f.getName().str();
163
164 if (!func_name.starts_with("process_plane_")) {
165 if (func_name.starts_with("fast_") ||
166 func_name.find("_v4") != std::string::npos ||
167 func_name.find("_v8") != std::string::npos ||
168 func_name.find("_v16") != std::string::npos) {
169
170 // Try a test lookup to see if symbol exists
171 auto test_sym = lljit->lookup(func_name);
172 if (test_sym) {
173 // Symbol already exists, mark for removal
174 functions_to_remove.push_back(&f);
175 continue;
176 }
177 // If test lookup failed, the symbol doesn't exist, so we can add it
178 llvm::consumeError(test_sym.takeError());
179 }
180 }
181 }
182 }
183
184 // Remove duplicate functions from the module
185 for (auto* f : functions_to_remove) {
186 f->eraseFromParent();
187 }
188
189 auto tsm = llvm::orc::ThreadSafeModule(std::move(m), std::move(ctx));
190 auto err = lljit->addIRModule(std::move(tsm));
191 if (err) {
192 llvm::errs() << "Failed to add IR module: "
193 << llvm::toString(std::move(err)) << "\n";
194 throw std::runtime_error("Failed to add IR module to JIT");
195 }
196}
197
198void* OrcJit::getFunctionAddress(const std::string& name) {
199 // Try to lookup the symbol
200 auto sym = lljit->lookup(name);
201 if (!sym) {
202 llvm::errs() << "Failed to find symbol '" << name
203 << "': " << llvm::toString(sym.takeError()) << "\n";
204 return nullptr;
205 }
206 return sym->toPtr<void*>();
207}
208
209// NOLINTBEGIN(cppcoreguidelines-avoid-non-const-global-variables)
210// Global JIT instances
213
214// JIT cache
215std::unordered_map<std::string, CompiledFunction> jit_cache;
216std::mutex cache_mutex;
217// NOLINTEND(cppcoreguidelines-avoid-non-const-global-variables)
OrcJit global_jit_nan_safe(false)
int64_t llvmexpr_get_buffer_size(const char *)
float * llvmexpr_ensure_buffer(const char *, int64_t)
OrcJit global_jit_fast(true)
std::unordered_map< std::string, CompiledFunction > jit_cache
Definition Jit.cpp:215
std::mutex cache_mutex
Definition Jit.cpp:216
Definition Jit.hpp:42
void * getFunctionAddress(const std::string &name)
Definition Jit.cpp:198
const llvm::Triple & getTargetTriple() const
Definition Jit.cpp:153
OrcJit(bool no_nans_fp_math)
Definition Jit.cpp:66
const llvm::DataLayout & getDataLayout() const
Definition Jit.cpp:149
void addModule(std::unique_ptr< llvm::Module > m, std::unique_ptr< llvm::LLVMContext > ctx)
Definition Jit.cpp:157
float * llvmexpr_ensure_buffer(const char *name, int64_t requested_size)
int64_t llvmexpr_get_buffer_size(const char *name)