VapourSynth-llvmexpr
Loading...
Searching...
No Matches
Math.hpp
Go to the documentation of this file.
1
47
48#ifndef LLVMEXPR_CODEGEN_LLVM_MATH_HPP
49#define LLVMEXPR_CODEGEN_LLVM_MATH_HPP
50
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <format>
#include <functional>
#include <iterator>
#include <map>
#include <numbers>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/Verifier.h"
68
69enum class MathOp : std::uint8_t {
79};
80
/// Compile-time description of one math operation.
struct MathopInfo {
    /// How many float operands the generated helper takes.
    int arity;
    /// Base symbol name of the helper; the generator appends a vector-width
    /// suffix (e.g. "_v8") for non-scalar variants.
    const char* name;
};
85
87 switch (op) {
88 case MathOp::Exp:
89 return {.arity = 1, .name = "fast_exp"};
90 case MathOp::Log:
91 return {.arity = 1, .name = "fast_log"};
92 case MathOp::Sin:
93 return {.arity = 1, .name = "fast_sin"};
94 case MathOp::Cos:
95 return {.arity = 1, .name = "fast_cos"};
96 case MathOp::Tan:
97 return {.arity = 1, .name = "fast_tan"};
98 case MathOp::Atan:
99 return {.arity = 1, .name = "fast_atan"};
100 case MathOp::Atan2:
101 return {.arity = 2, .name = "fast_atan2"};
102 case MathOp::Acos:
103 return {.arity = 1, .name = "fast_acos"};
104 case MathOp::Asin:
105 return {.arity = 1, .name = "fast_asin"};
106 }
107}
108
109#ifdef __x86_64__
111 std::integer_sequence<int, 4,
112 8, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
113 16>; // NOLINT(cppcoreguidelines-avoid-magic-numbers)
114#elif defined(__ARM_NEON__)
116 std::integer_sequence<int,
117 4>; // NOLINT(cppcoreguidelines-avoid-magic-numbers)
118#else
119using SupportedVectorWidths = std::integer_sequence<int>;
120#endif
121
122template <int VectorWidth> class MathFunctionGenerator {
123 public:
124 MathFunctionGenerator(llvm::Module* module, llvm::LLVMContext& context)
125 : module(module), context(context), builder(context) {}
126
127 template <MathOp op> llvm::Function* getOrCreate();
128
129 private:
130 template <int, MathOp> friend struct MathFunctionImpl;
131
132 llvm::Module* module;
133 llvm::LLVMContext& context;
134 llvm::IRBuilder<> builder;
135
136 llvm::Type* getFloatType() {
137 auto* ty = llvm::Type::getFloatTy(context);
138 if (VectorWidth == 1) {
139 return ty;
140 }
141 return llvm::VectorType::get(ty, VectorWidth, false);
142 }
143
144 llvm::Type* getInt32Type() {
145 auto* ty = llvm::Type::getInt32Ty(context);
146 if (VectorWidth == 1) {
147 return ty;
148 }
149 return llvm::VectorType::get(ty, VectorWidth, false);
150 }
151
152 llvm::Value* getConstant(double val) {
153 auto* scalar_const =
154 llvm::ConstantFP::get(llvm::Type::getFloatTy(context), val);
155 return (VectorWidth == 1)
156 ? (llvm::Value*)scalar_const
157 : builder.CreateVectorSplat(VectorWidth, scalar_const);
158 }
159
160 llvm::Value* getInt32Constant(int32_t val) {
161 auto* scalar_const =
162 llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), val);
163 return (VectorWidth == 1)
164 ? (llvm::Value*)scalar_const
165 : builder.CreateVectorSplat(VectorWidth, scalar_const);
166 }
167
168 std::string getFunctionName(const std::string& base_name) {
169 if (VectorWidth == 1) {
170 return base_name;
171 }
172 return base_name + "_v" + std::to_string(VectorWidth);
173 }
174
175 llvm::Value* createIntrinsicCall(llvm::Intrinsic::ID intrinsic_id,
176 llvm::ArrayRef<llvm::Value*> args) {
177 auto* intrinsic = llvm::Intrinsic::getOrInsertDeclaration(
178 module, intrinsic_id, getFloatType());
179 return builder.CreateCall(intrinsic, args);
180 }
181
182 llvm::Function* createFunction(
183 const std::string& base_name, int arity,
184 const std::function<llvm::Value*(llvm::ArrayRef<llvm::Value*>)>&
185 body_generator) {
186 std::string func_name = getFunctionName(base_name);
187 if (auto* existing_func = module->getFunction(func_name)) {
188 return existing_func;
189 }
190
191 auto last_ip = builder.saveIP();
192
193 auto* float_ty = getFloatType();
194 std::vector<llvm::Type*> arg_types(arity, float_ty);
195 auto* func_ty = llvm::FunctionType::get(float_ty, arg_types, false);
196 auto* func = llvm::Function::Create(
197 func_ty, llvm::Function::ExternalLinkage, func_name, module);
198
199 auto* entry_bb = llvm::BasicBlock::Create(context, "entry", func);
200 builder.SetInsertPoint(entry_bb);
201
202 std::vector<llvm::Value*> args;
203 std::ranges::transform(func->args(), std::back_inserter(args),
204 [](auto& arg) { return &arg; });
205 if (arity > 0) {
206 args[0]->setName("x");
207 if (arity > 1) {
208 args[1]->setName("y");
209 }
210 if (arity > 2) {
211 args[2]->setName("z");
212 }
213 }
214
215 llvm::Value* result = body_generator(args);
216
217 builder.CreateRet(result);
218 builder.restoreIP(last_ip);
219
220 if (llvm::verifyFunction(*func, &llvm::errs())) {
221 func->eraseFromParent();
222 return nullptr;
223 }
224
225 return func;
226 }
227};
228
/// Per-operation IR emitter. The primary template is only declared: each
/// supported MathOp has a partial specialisation that provides
///   static llvm::Function* generate(MathFunctionGenerator<VectorWidth>* gen);
/// Using an op that has no specialisation fails to compile (incomplete type).
template <int VectorWidth, MathOp op> struct MathFunctionImpl;
230
/// Fast single-precision exp(x), ported from VapourSynth's x86 Expr JIT
/// (Cephes-style range reduction plus polynomial).
///
/// Outline of the emitted IR:
///  1. Clamp x to [exp_lo, exp_hi] with minnum/maxnum.
///  2. n = floor(x * log2(e) + 0.5).
///  3. r = x - n * ln(2), with ln(2) applied as two constants
///     (0.693359375 - 2.12194440e-4) to reduce rounding error.
///  4. y = 1 + r + r^2 * P(r), P of degree 5 (exp_p0..exp_p5).
///  5. result = y * 2^n, where 2^n is built by writing (n + 127) into the
///     float exponent field.
template <int VectorWidth> struct MathFunctionImpl<VectorWidth, MathOp::Exp> {
    static llvm::Function* generate(MathFunctionGenerator<VectorWidth>* gen) {
        constexpr auto OP_INFO = get_math_op_info(MathOp::Exp);
        // https://github.com/vapoursynth/vapoursynth/blob/2a3d3657320ca505c784b98f10e7cd9649d6169a/src/core/expr/jitcompiler_x86.cpp#L635
        return gen->createFunction(
            OP_INFO.name, OP_INFO.arity,
            [gen](llvm::ArrayRef<llvm::Value*> args) -> llvm::Value* {
                auto* x = args[0];
                // NOLINTBEGIN(cppcoreguidelines-avoid-magic-numbers)
                auto* exp_hi = gen->getConstant(88.3762626647949F);
                auto* exp_lo = gen->getConstant(-88.3762626647949F);
                auto* log2e = gen->getConstant(std::numbers::log2e_v<float>);
                auto* exp_p0 = gen->getConstant(1.9875691500E-4F);
                auto* exp_p1 = gen->getConstant(1.3981999507E-3F);
                auto* exp_p2 = gen->getConstant(8.3334519073E-3F);
                auto* exp_p3 = gen->getConstant(4.1665795894E-2F);
                auto* exp_p4 = gen->getConstant(1.6666665459E-1F);
                auto* exp_p5 = gen->getConstant(5.0000001201E-1F);
                auto* half = gen->getConstant(0.5F);
                auto* one = gen->getConstant(1.0F);
                auto* neg_exp_c1 = gen->getConstant(-0.693359375F);
                auto* neg_exp_c2 = gen->getConstant(2.12194440e-4F);
                auto* const_0x7f = gen->getInt32Constant(0x7F);
                auto* const_23 = gen->getInt32Constant(23);
                // NOLINTEND(cppcoreguidelines-avoid-magic-numbers)
                // Step 1: clamp to the range where 2^n is representable.
                x = gen->createIntrinsicCall(llvm::Intrinsic::minnum,
                                             {x, exp_hi});
                x = gen->createIntrinsicCall(llvm::Intrinsic::maxnum,
                                             {x, exp_lo});
                // Step 2: fx = x * log2(e) + 0.5, then floor(fx) computed as
                // nearbyint(fx) minus 1.0 in lanes where rounding went up.
                auto* fx = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                                    {log2e, x, half});
                auto* etmp =
                    gen->createIntrinsicCall(llvm::Intrinsic::nearbyint, {fx});
                auto* cmp_gt = gen->builder.CreateFCmpOGT(etmp, fx);
                // Bit-select 1.0f where cmp_gt is true, 0.0f elsewhere.
                auto* ext_cmp =
                    gen->builder.CreateSExt(cmp_gt, gen->getInt32Type());
                auto* one_int =
                    gen->builder.CreateBitCast(one, gen->getInt32Type());
                auto* mask_int = gen->builder.CreateAnd(ext_cmp, one_int);
                auto* mask =
                    gen->builder.CreateBitCast(mask_int, gen->getFloatType());
                fx = gen->builder.CreateFSub(etmp, mask);
                // Step 3: r = x - n*ln(2), applied in two parts.
                x = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                             {fx, neg_exp_c1, x});
                x = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                             {fx, neg_exp_c2, x});
                // Step 4: Horner evaluation, then y = P(r) * r^2 + r + 1.
                auto* z = gen->builder.CreateFMul(x, x);
                llvm::Value* y = exp_p0;
                y = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                             {y, x, exp_p1});
                y = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                             {y, x, exp_p2});
                y = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                             {y, x, exp_p3});
                y = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                             {y, x, exp_p4});
                y = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                             {y, x, exp_p5});
                y = gen->createIntrinsicCall(llvm::Intrinsic::fma, {y, z, x});
                y = gen->builder.CreateFAdd(y, one);
                // Step 5: build 2^n as float bits ((n + 127) << 23).
                auto* emm0_float =
                    gen->createIntrinsicCall(llvm::Intrinsic::nearbyint, {fx});
                auto* emm0 =
                    gen->builder.CreateFPToSI(emm0_float, gen->getInt32Type());
                emm0 = gen->builder.CreateAdd(emm0, const_0x7f);
                emm0 = gen->builder.CreateShl(emm0, const_23);
                auto* emm0_as_float =
                    gen->builder.CreateBitCast(emm0, gen->getFloatType());
                x = gen->builder.CreateFMul(y, emm0_as_float);
                return x;
            });
    }
};
304
/// Fast single-precision natural logarithm, ported from VapourSynth's x86
/// Expr JIT (Cephes-style logf).
///
/// Outline of the emitted IR:
///  1. Record whether the input equals 1.0 exactly; the result is forced to
///     +0 in those lanes at the end.
///  2. maxnum against the smallest positive normal float (bits 0x00800000).
///     Inputs <= 0 are clamped to that bound instead of producing NaN/-inf.
///     Per LLVM's maxnum semantics, a NaN input is also replaced by the bound.
///  3. Split x = m * 2^e with m in [0.5, 1) using exponent/mantissa bit ops.
///  4. If m < sqrt(1/2): e -= 1 and f = 2m - 1; otherwise f = m - 1.
///  5. log(x) ~= f - f^2/2 + f^3 * P(f) + e * ln(2), with ln(2) applied as
///     log_q2 + log_q1 for extra precision.
template <int VectorWidth> struct MathFunctionImpl<VectorWidth, MathOp::Log> {
    static llvm::Function* generate(MathFunctionGenerator<VectorWidth>* gen) {
        constexpr auto OP_INFO = get_math_op_info(MathOp::Log);
        // https://github.com/vapoursynth/vapoursynth/blob/2a3d3657320ca505c784b98f10e7cd9649d6169a/src/core/expr/jitcompiler_x86.cpp#L671
        return gen->createFunction(
            OP_INFO.name, OP_INFO.arity,
            [gen](llvm::ArrayRef<llvm::Value*> args) -> llvm::Value* {
                auto* x = args[0];
                // NOLINTBEGIN(cppcoreguidelines-avoid-magic-numbers)
                auto* min_norm_pos = gen->getInt32Constant(0x00800000);
                auto* inv_mant_mask = gen->getInt32Constant(~0x7F800000);
                auto* sqrt_1_2 = gen->getConstant(0.707106781186547524F);
                auto* log_p0 = gen->getConstant(7.0376836292E-2F);
                auto* log_p1 = gen->getConstant(-1.1514610310E-1F);
                auto* log_p2 = gen->getConstant(1.1676998740E-1F);
                auto* log_p3 = gen->getConstant(-1.2420140846E-1F);
                auto* log_p4 = gen->getConstant(1.4249322787E-1F);
                auto* log_p5 = gen->getConstant(-1.6668057665E-1F);
                auto* log_p6 = gen->getConstant(2.0000714765E-1F);
                auto* log_p7 = gen->getConstant(-2.4999993993E-1F);
                auto* log_p8 = gen->getConstant(3.3333331174E-1F);
                auto* log_q2 = gen->getConstant(0.693359375F);
                auto* log_q1 = gen->getConstant(-2.12194440e-4F);
                auto* one = gen->getConstant(1.0F);
                auto* neg_half = gen->getConstant(-0.5F);
                auto* const_0x7f = gen->getInt32Constant(0x7F);
                auto* const_23 = gen->getInt32Constant(23);
                // NOLINTEND(cppcoreguidelines-avoid-magic-numbers)
                // Step 1: remember exact-one lanes (compared before clamping).
                auto* is_one = gen->builder.CreateFCmpOEQ(x, one);
                // Step 2: clamp to the smallest positive normal float.
                auto* min_norm_pos_float = gen->builder.CreateBitCast(
                    min_norm_pos, gen->getFloatType());
                x = gen->createIntrinsicCall(llvm::Intrinsic::maxnum,
                                             {x, min_norm_pos_float});
                // Step 3: biased exponent in emm0i, mantissa with the exponent
                // of 0.5 so that m lands in [0.5, 1).
                auto* x_as_int =
                    gen->builder.CreateBitCast(x, gen->getInt32Type());
                auto* emm0i = gen->builder.CreateLShr(x_as_int, const_23);
                auto* x_masked =
                    gen->builder.CreateAnd(x_as_int, inv_mant_mask);
                auto* half_as_int = gen->builder.CreateBitCast(
                    gen->getConstant(
                        0.5F), // NOLINT(cppcoreguidelines-avoid-magic-numbers)
                    gen->getInt32Type()); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
                x_masked = gen->builder.CreateOr(x_masked, half_as_int);
                x = gen->builder.CreateBitCast(x_masked, gen->getFloatType());
                // e = (biased - 127) + 1, matching m in [0.5, 1).
                emm0i = gen->builder.CreateSub(emm0i, const_0x7f);
                auto* emm0 =
                    gen->builder.CreateSIToFP(emm0i, gen->getFloatType());
                emm0 = gen->builder.CreateFAdd(emm0, one);
                // Step 4: lanes with m < sqrt(1/2) get f = 2m - 1 and e - 1;
                // the adjustments are bit-selected with the sign-extended mask.
                auto* mask = gen->builder.CreateFCmpOLT(x, sqrt_1_2);
                auto* ext_mask =
                    gen->builder.CreateSExt(mask, gen->getInt32Type());
                x_as_int = gen->builder.CreateBitCast(x, gen->getInt32Type());
                auto* etmp_as_int = gen->builder.CreateAnd(ext_mask, x_as_int);
                auto* etmp = gen->builder.CreateBitCast(etmp_as_int,
                                                        gen->getFloatType());
                x = gen->builder.CreateFSub(x, one);
                auto* one_as_int =
                    gen->builder.CreateBitCast(one, gen->getInt32Type());
                auto* maskf_as_int =
                    gen->builder.CreateAnd(ext_mask, one_as_int);
                auto* maskf = gen->builder.CreateBitCast(maskf_as_int,
                                                         gen->getFloatType());
                emm0 = gen->builder.CreateFSub(emm0, maskf);
                x = gen->builder.CreateFAdd(x, etmp);
                // Step 5: y = f^3 * P(f), then add e*log_q1 and -f^2/2.
                auto* z = gen->builder.CreateFMul(x, x);
                llvm::Value* y = log_p0;
                y = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                             {y, x, log_p1});
                y = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                             {y, x, log_p2});
                y = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                             {y, x, log_p3});
                y = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                             {y, x, log_p4});
                y = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                             {y, x, log_p5});
                y = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                             {y, x, log_p6});
                y = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                             {y, x, log_p7});
                y = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                             {y, x, log_p8});
                y = gen->builder.CreateFMul(y, x);
                y = gen->builder.CreateFMul(y, z);
                y = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                             {emm0, log_q1, y});
                y = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                             {z, neg_half, y});
                x = gen->builder.CreateFAdd(x, y);
                x = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                             {emm0, log_q2, x});
                // Zero the result bits in lanes whose input was exactly 1.0.
                x_as_int = gen->builder.CreateBitCast(x, gen->getInt32Type());
                auto* ext_is_one =
                    gen->builder.CreateSExt(is_one, gen->getInt32Type());
                auto* not_ext_is_one = gen->builder.CreateNot(ext_is_one);
                auto* result_as_int =
                    gen->builder.CreateAnd(not_ext_is_one, x_as_int);
                x = gen->builder.CreateBitCast(result_as_int,
                                               gen->getFloatType());
                return x;
            });
    }
};
408
/// Fast single-precision sin(x), ported from VapourSynth's x86 Expr JIT.
///
/// Outline of the emitted IR:
///  1. Save the sign bit of x and continue with |x|.
///  2. n = nearbyint(|x| / pi). Shifting n left by 31 moves its low bit into
///     the sign position, so an odd n flips the saved sign.
///  3. r = |x| - n*pi, with pi split into four parts (float_pi1..4) to keep
///     the reduction accurate. This leaves r roughly in [-pi/2, pi/2].
///  4. sin(r) ~= r + r^3 * (c3 + c5 r^2 + c7 r^4 + c9 r^6).
///  5. Apply the accumulated sign by XOR-ing it into the result bits.
///
/// NOTE(review): per the LLVM LangRef, fptosi yields poison when
/// |x| / pi does not fit in i32, so very large arguments are not handled.
template <int VectorWidth> struct MathFunctionImpl<VectorWidth, MathOp::Sin> {
    static llvm::Function* generate(MathFunctionGenerator<VectorWidth>* gen) {
        constexpr auto OP_INFO = get_math_op_info(MathOp::Sin);
        // https://github.com/vapoursynth/vapoursynth/blob/2a3d3657320ca505c784b98f10e7cd9649d6169a/src/core/expr/jitcompiler_x86.cpp#L813
        return gen->createFunction(
            OP_INFO.name, OP_INFO.arity,
            [gen](llvm::ArrayRef<llvm::Value*> args) -> llvm::Value* {
                auto* x = args[0];
                auto* float_ty = gen->getFloatType();
                auto* int32_ty = gen->getInt32Type();
                auto* float_invpi =
                    gen->getConstant(std::numbers::inv_pi_v<float>);
                // NOLINTBEGIN(cppcoreguidelines-avoid-magic-numbers)
                auto* float_pi1 = gen->getConstant(3.140625F);
                auto* float_pi2 = gen->getConstant(0.0009670257568359375F);
                auto* float_pi3 = gen->getConstant(1.984187252998352e-07F);
                auto* float_pi4 = gen->getConstant(1.273533813134432e-11F);
                auto* float_sin_c3 = gen->getConstant(-0.1666666567325592F);
                auto* float_sin_c5 = gen->getConstant(0.00833307858556509F);
                auto* float_sin_c7 = gen->getConstant(-0.00019807418575510383F);
                auto* float_sin_c9 = gen->getConstant(2.6019030363451748e-06F);
                auto* signmask = gen->getInt32Constant(0x80000000);
                // NOLINTEND(cppcoreguidelines-avoid-magic-numbers)
                // Step 1: sign of the input; sin is odd.
                llvm::Value* sign = gen->builder.CreateBitCast(x, int32_ty);
                sign = gen->builder.CreateAnd(sign, signmask);
                llvm::Value* t1 =
                    gen->createIntrinsicCall(llvm::Intrinsic::fabs, {x});
                // Step 2: n = round(|x| / pi); odd n negates the result.
                llvm::Value* t2 = gen->builder.CreateFMul(t1, float_invpi);
                llvm::Value* t2_rounded =
                    gen->createIntrinsicCall(llvm::Intrinsic::nearbyint, {t2});
                llvm::Value* t2i =
                    gen->builder.CreateFPToSI(t2_rounded, int32_ty);
                llvm::Value* t4 = gen->builder.CreateShl(
                    t2i, 31); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
                sign = gen->builder.CreateXor(sign, t4);
                // Step 3: r = |x| - n*pi in four steps.
                t2 = gen->builder.CreateSIToFP(t2i, float_ty);
                t1 = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {t2, gen->builder.CreateFNeg(float_pi1), t1});
                t1 = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {t2, gen->builder.CreateFNeg(float_pi2), t1});
                t1 = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {t2, gen->builder.CreateFNeg(float_pi3), t1});
                t1 = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {t2, gen->builder.CreateFNeg(float_pi4), t1});
                // Step 4: odd polynomial in r (t2 = r^2).
                t2 = gen->builder.CreateFMul(t1, t1);
                llvm::Value* t3 = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma, {t2, float_sin_c9, float_sin_c7});
                t3 = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                              {t3, t2, float_sin_c5});
                t3 = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                              {t3, t2, float_sin_c3});
                t3 = gen->builder.CreateFMul(t3, t2);
                t3 = gen->builder.CreateFMul(t3, t1);
                t1 = gen->builder.CreateFAdd(t1, t3);
                // Step 5: XOR the sign bit into the result.
                llvm::Value* t1_as_int =
                    gen->builder.CreateBitCast(t1, int32_ty);
                llvm::Value* result_as_int =
                    gen->builder.CreateXor(sign, t1_as_int);
                return gen->builder.CreateBitCast(result_as_int, float_ty);
            });
    }
};
475
/// Fast single-precision cos(x), using the same range reduction as the Sin
/// specialisation (ported from VapourSynth's x86 Expr JIT).
///
/// Outline of the emitted IR:
///  1. cos is even, so the input sign is dropped (sign starts as 0) and the
///     computation uses |x|.
///  2. n = nearbyint(|x| / pi). Since cos(r + n*pi) = (-1)^n cos(r), n << 31
///     sets the sign bit for odd n.
///  3. r = |x| - n*pi, with pi split into four parts (float_pi1..4).
///  4. cos(r) ~= 1 + r^2 * (c2 + c4 r^2 + c6 r^4 + c8 r^6).
///  5. XOR the sign into the result bits.
///
/// NOTE(review): the xor with the constant-zero `sign` is a no-op. IRBuilder
/// only folds a constant RHS, so it is left for later passes to remove.
/// As with Sin, fptosi yields poison when |x| / pi does not fit in i32.
template <int VectorWidth> struct MathFunctionImpl<VectorWidth, MathOp::Cos> {
    static llvm::Function* generate(MathFunctionGenerator<VectorWidth>* gen) {
        constexpr auto OP_INFO = get_math_op_info(MathOp::Cos);
        // https://github.com/vapoursynth/vapoursynth/blob/2a3d3657320ca505c784b98f10e7cd9649d6169a/src/core/expr/jitcompiler_x86.cpp#L813
        return gen->createFunction(
            OP_INFO.name, OP_INFO.arity,
            [gen](llvm::ArrayRef<llvm::Value*> args) -> llvm::Value* {
                auto* x = args[0];
                auto* float_ty = gen->getFloatType();
                auto* int32_ty = gen->getInt32Type();
                auto* float_invpi =
                    gen->getConstant(std::numbers::inv_pi_v<float>);
                // NOLINTBEGIN(cppcoreguidelines-avoid-magic-numbers)
                auto* float_pi1 = gen->getConstant(3.140625F);
                auto* float_pi2 = gen->getConstant(0.0009670257568359375F);
                auto* float_pi3 = gen->getConstant(1.984187252998352e-07F);
                auto* float_pi4 = gen->getConstant(1.273533813134432e-11F);
                auto* float_cos_c2 = gen->getConstant(-0.4999999701976776F);
                auto* float_cos_c4 = gen->getConstant(0.04166652262210846F);
                auto* float_cos_c6 = gen->getConstant(-0.001388676579343155F);
                auto* float_cos_c8 = gen->getConstant(2.4390448881604243e-05F);
                // NOLINTEND(cppcoreguidelines-avoid-magic-numbers)
                auto* one_float = gen->getConstant(1.0F);
                // Step 1: even function, no input sign to carry.
                llvm::Value* sign = gen->getInt32Constant(0);
                llvm::Value* t1 =
                    gen->createIntrinsicCall(llvm::Intrinsic::fabs, {x});
                // Step 2: n = round(|x| / pi); odd n negates the result.
                llvm::Value* t2 = gen->builder.CreateFMul(t1, float_invpi);
                llvm::Value* t2_rounded =
                    gen->createIntrinsicCall(llvm::Intrinsic::nearbyint, {t2});
                llvm::Value* t2i =
                    gen->builder.CreateFPToSI(t2_rounded, int32_ty);
                llvm::Value* t4 = gen->builder.CreateShl(
                    t2i, 31); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
                sign = gen->builder.CreateXor(sign, t4);
                // Step 3: r = |x| - n*pi in four steps.
                t2 = gen->builder.CreateSIToFP(t2i, float_ty);
                t1 = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {t2, gen->builder.CreateFNeg(float_pi1), t1});
                t1 = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {t2, gen->builder.CreateFNeg(float_pi2), t1});
                t1 = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {t2, gen->builder.CreateFNeg(float_pi3), t1});
                t1 = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {t2, gen->builder.CreateFNeg(float_pi4), t1});
                // Step 4: even polynomial in r (t2 = r^2).
                t2 = gen->builder.CreateFMul(t1, t1);
                llvm::Value* t3 = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma, {t2, float_cos_c8, float_cos_c6});
                t3 = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                              {t3, t2, float_cos_c4});
                t3 = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                              {t3, t2, float_cos_c2});
                t1 = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                              {t3, t2, one_float});
                // Step 5: XOR the sign bit into the result.
                llvm::Value* t1_as_int =
                    gen->builder.CreateBitCast(t1, int32_ty);
                llvm::Value* result_as_int =
                    gen->builder.CreateXor(sign, t1_as_int);
                return gen->builder.CreateBitCast(result_as_int, float_ty);
            });
    }
};
540
541template <int VectorWidth> struct MathFunctionImpl<VectorWidth, MathOp::Tan> {
542 static llvm::Function* generate(MathFunctionGenerator<VectorWidth>* gen) {
543 constexpr auto OP_INFO = get_math_op_info(MathOp::Tan);
544 return gen->createFunction(
545 OP_INFO.name, OP_INFO.arity,
546 [gen](llvm::ArrayRef<llvm::Value*> args) -> llvm::Value* {
547 auto* x = args[0];
548 llvm::Function* sin_func =
549 MathFunctionImpl<VectorWidth, MathOp::Sin>::generate(gen);
550 llvm::Function* cos_func =
551 MathFunctionImpl<VectorWidth, MathOp::Cos>::generate(gen);
552 llvm::Value* sin_x = gen->builder.CreateCall(sin_func, {x});
553 llvm::Value* cos_x = gen->builder.CreateCall(cos_func, {x});
554 return gen->builder.CreateFDiv(sin_x, cos_x);
555 });
556 }
557};
558
/// Fast single-precision atan(x), using the minimax polynomial from
/// https://stackoverflow.com/a/23097989.
///
/// Outline of the emitted IR:
///  1. z = |x|; reduce to a = (z > 1) ? 1/z : z, so a is in [0, 1].
///  2. With s = a^2 and q = s^2, run two independent Horner chains in q
///     (p and t), merge them with p = p*s + t, then finish in s.
///     This gives r = a + a*s*P(s).
///  3. If z > 1, r = pi/2 - r (from atan(1/z) = pi/2 - atan(z) for z > 0).
///     Finally, copy the sign of x onto r.
template <int VectorWidth> struct MathFunctionImpl<VectorWidth, MathOp::Atan> {
    static llvm::Function* generate(MathFunctionGenerator<VectorWidth>* gen) {
        constexpr auto OP_INFO = get_math_op_info(MathOp::Atan);
        // https://stackoverflow.com/a/23097989
        return gen->createFunction(
            OP_INFO.name, OP_INFO.arity,
            [gen](llvm::ArrayRef<llvm::Value*> args) -> llvm::Value* {
                auto* var = args[0];
                // NOLINTBEGIN(cppcoreguidelines-avoid-magic-numbers)
                auto* one = gen->getConstant(1.0F);
                auto* pi_div_2 = gen->getConstant(1.5707963267948966F);
                // Step 1: argument reduction to [0, 1].
                auto* z =
                    gen->createIntrinsicCall(llvm::Intrinsic::fabs, {var});
                auto* z_gt_1 = gen->builder.CreateFCmpOGT(z, one);
                auto* one_div_zz = gen->builder.CreateFDiv(one, z);
                auto* a = gen->builder.CreateSelect(z_gt_1, one_div_zz, z);
                auto* s = gen->builder.CreateFMul(a, a);
                auto* q = gen->builder.CreateFMul(s, s);
                // Step 2: two interleaved Horner chains in q (p and t).
                llvm::Value* p = gen->getConstant(-2.0258553044340116e-5F);
                llvm::Value* t = gen->getConstant(2.2302240345710764e-4F);
                p = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {p, q, gen->getConstant(-1.1640717779912220e-3F)});
                t = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {t, q, gen->getConstant(3.8559749383656407e-3F)});
                p = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {p, q, gen->getConstant(-9.1845592187222193e-3F)});
                t = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {t, q, gen->getConstant(1.6978035834594660e-2F)});
                p = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {p, q, gen->getConstant(-2.5826796814492296e-2F)});
                t = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {t, q, gen->getConstant(3.4067811082715810e-2F)});
                p = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {p, q, gen->getConstant(-4.0926382420509999e-2F)});
                t = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {t, q, gen->getConstant(4.6739496199158334e-2F)});
                p = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {p, q, gen->getConstant(-5.2392330054601366e-2F)});
                t = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {t, q, gen->getConstant(5.8773077721790683e-2F)});
                p = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {p, q, gen->getConstant(-6.6658603633512892e-2F)});
                t = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {t, q, gen->getConstant(7.6922129305867892e-2F)});
                // Merge the chains, then finish the Horner scheme in s.
                p = gen->createIntrinsicCall(llvm::Intrinsic::fma, {p, s, t});
                p = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {p, s, gen->getConstant(-9.0909012354005267e-2F)});
                p = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {p, s, gen->getConstant(1.1111110678749421e-1F)});
                p = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {p, s, gen->getConstant(-1.4285714271334810e-1F)});
                p = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {p, s, gen->getConstant(1.9999999999755005e-1F)});
                p = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma,
                    {p, s, gen->getConstant(-3.3333333333331838e-1F)});
                // NOLINTEND(cppcoreguidelines-avoid-magic-numbers)
                // r = a + (P * s) * a
                auto* pp_mul_ss = gen->builder.CreateFMul(p, s);
                p = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                             {pp_mul_ss, a, a});
                // Step 3: undo the reciprocal reduction, then restore sign.
                auto* rr_if_gt_1 = gen->builder.CreateFSub(pi_div_2, p);
                auto* rr = gen->builder.CreateSelect(z_gt_1, rr_if_gt_1, p);
                return gen->createIntrinsicCall(llvm::Intrinsic::copysign,
                                                {rr, var});
            });
    }
};
642
643template <int VectorWidth> struct MathFunctionImpl<VectorWidth, MathOp::Atan2> {
644 static llvm::Function* generate(MathFunctionGenerator<VectorWidth>* gen) {
645 constexpr auto OP_INFO = get_math_op_info(MathOp::Atan2);
646 return gen->createFunction(
647 OP_INFO.name, OP_INFO.arity,
648 [gen](llvm::ArrayRef<llvm::Value*> args) -> llvm::Value* {
649 auto* var_y = args[0];
650 auto* var_x = args[1];
651 auto* atan_func =
652 MathFunctionImpl<VectorWidth, MathOp::Atan>::generate(gen);
653 auto* zero = gen->getConstant(0.0F);
654 auto* pi = gen->getConstant(std::numbers::pi_v<float>);
655 auto* pi_div_2 = gen->getConstant(
656 1.5707963267948966F); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
657 auto* y_div_x = gen->builder.CreateFDiv(var_y, var_x);
658 auto* atan_y_div_x =
659 gen->builder.CreateCall(atan_func, {y_div_x});
660 auto* res_x_gt_0 = atan_y_div_x;
661 auto* signed_pi = gen->createIntrinsicCall(
662 llvm::Intrinsic::copysign, {pi, var_y});
663 auto* res_x_lt_0 =
664 gen->builder.CreateFAdd(atan_y_div_x, signed_pi);
665 auto* res_x_eq_0 = gen->createIntrinsicCall(
666 llvm::Intrinsic::copysign, {pi_div_2, var_y});
667 auto* x_gt_0 = gen->builder.CreateFCmpOGT(var_x, zero);
668 auto* x_lt_0 = gen->builder.CreateFCmpOLT(var_x, zero);
669 auto* result =
670 gen->builder.CreateSelect(x_gt_0, res_x_gt_0, res_x_lt_0);
671 result = gen->builder.CreateSelect(
672 x_lt_0, res_x_lt_0,
673 gen->builder.CreateSelect(x_gt_0, res_x_gt_0, res_x_eq_0));
674 auto* x_is_zero = gen->builder.CreateFCmpOEQ(var_x, zero);
675 auto* y_is_zero = gen->builder.CreateFCmpOEQ(var_y, zero);
676 auto* both_zero = gen->builder.CreateAnd(x_is_zero, y_is_zero);
677 result = gen->builder.CreateSelect(both_zero, zero, result);
678 return result;
679 });
680 }
681};
682
/// Fast approximate acos(x), from https://forwardscattering.org/post/66:
///   acos(|x|) ~= (-0.124605335*|x| + 0.1570634) * (0.99418175 - |x|)
///                + sqrt(2 - 2*|x|)
/// and acos(x) = pi - acos(|x|) for x < 0.
/// For |x| > 1 the sqrt operand is negative, so the result is presumably NaN.
template <int VectorWidth> struct MathFunctionImpl<VectorWidth, MathOp::Acos> {
    static llvm::Function* generate(MathFunctionGenerator<VectorWidth>* gen) {
        constexpr auto OP_INFO = get_math_op_info(MathOp::Acos);
        // https://forwardscattering.org/post/66
        // TODO: Switch to another implementation that doesn't has licensing issues.
        return gen->createFunction(
            OP_INFO.name, OP_INFO.arity,
            [gen](llvm::ArrayRef<llvm::Value*> args) -> llvm::Value* {
                auto* x = args[0];
                auto* pi = gen->getConstant(std::numbers::pi_v<float>);
                auto* ax = gen->createIntrinsicCall(llvm::Intrinsic::fabs, {x});
                // NOLINTBEGIN(cppcoreguidelines-avoid-magic-numbers)
                // term1 = -0.124605335*|x| + 0.1570634
                auto* term1_mul = gen->getConstant(-0.124605335F);
                auto* term1_add = gen->getConstant(0.1570634F);
                auto* term1 = gen->createIntrinsicCall(
                    llvm::Intrinsic::fma, {ax, term1_mul, term1_add});

                // term2 = 0.99418175 - |x|
                auto* term2_sub = gen->getConstant(0.99418175F);
                auto* term2 = gen->builder.CreateFSub(term2_sub, ax);

                auto* poly_part = gen->builder.CreateFMul(term1, term2);

                auto* two = gen->getConstant(2.0F);
                auto* neg_two = gen->getConstant(-2.0F);
                // NOLINTEND(cppcoreguidelines-avoid-magic-numbers)
                // sqrt(2 - 2*|x|)
                auto* sqrt_arg = gen->createIntrinsicCall(llvm::Intrinsic::fma,
                                                          {ax, neg_two, two});
                auto* sqrt_part =
                    gen->createIntrinsicCall(llvm::Intrinsic::sqrt, {sqrt_arg});

                // acos(|x|)
                auto* res_pos = gen->builder.CreateFAdd(poly_part, sqrt_part);

                auto* zero = gen->getConstant(0.0F);
                auto* is_neg = gen->builder.CreateFCmpOLT(x, zero);

                // acos(-t) = pi - acos(t)
                auto* res_neg = gen->builder.CreateFSub(pi, res_pos);

                return gen->builder.CreateSelect(is_neg, res_neg, res_pos);
            });
    }
};
724
725template <int VectorWidth> struct MathFunctionImpl<VectorWidth, MathOp::Asin> {
726 static llvm::Function* generate(MathFunctionGenerator<VectorWidth>* gen) {
727 constexpr auto OP_INFO = get_math_op_info(MathOp::Asin);
728 // asin(x) = pi/2 - acos(x)
729 return gen->createFunction(
730 OP_INFO.name, OP_INFO.arity,
731 [gen](llvm::ArrayRef<llvm::Value*> args) -> llvm::Value* {
732 auto* x = args[0];
733 auto* pi_div_2 = gen->getConstant(
734 1.5707963267948966F); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
735 auto* acos_func =
736 MathFunctionImpl<VectorWidth, MathOp::Acos>::generate(gen);
737 auto* acos_x = gen->builder.CreateCall(acos_func, {x});
738 return gen->builder.CreateFSub(pi_div_2, acos_x);
739 });
740 }
741};
742
743template <int VectorWidth>
744template <MathOp op>
748
750 std::tuple<std::integral_constant<MathOp, MathOp::Exp>,
751 std::integral_constant<MathOp, MathOp::Log>,
752 std::integral_constant<MathOp, MathOp::Sin>,
753 std::integral_constant<MathOp, MathOp::Cos>,
754 std::integral_constant<MathOp, MathOp::Tan>,
755 std::integral_constant<MathOp, MathOp::Atan>,
756 std::integral_constant<MathOp, MathOp::Atan2>,
757 std::integral_constant<MathOp, MathOp::Acos>,
758 std::integral_constant<MathOp, MathOp::Asin>>;
759
761 public:
762 MathLibraryManager(llvm::Module* module, llvm::LLVMContext& context)
763 : module(module), context(context) {}
764
765 llvm::Function* getFunction(MathOp op) {
766 if (auto it = func_cache.find(op); it != func_cache.end()) {
767 return it->second;
768 }
769 return generateAndCache(op);
770 }
771
772 private:
773 llvm::Module* module;
774 llvm::LLVMContext& context;
775 std::map<MathOp, llvm::Function*> func_cache;
776
777 template <MathOp op, int VectorWidth> llvm::Function* dispatch() {
778 MathFunctionGenerator<VectorWidth> generator(module, context);
779 return generator.template getOrCreate<op>();
780 }
781
782 template <MathOp op> llvm::Function* generateAndCacheImpl() {
783 llvm::Function* scalar_func = dispatch<op, 1>();
784
785 if (!scalar_func) {
786 return nullptr;
787 }
788
789 // https://llvm.org/docs/LangRef.html#id1998
790 auto link_vectors = [&]<int... vlen>(
791 std::integer_sequence<int, vlen...>) {
792 (
793 [&] {
794 const llvm::Function* vec_func = dispatch<op, vlen>();
795
796 if (vec_func) {
797#if defined(__x86_64__) || defined(__ARM_NEON__)
798 std::string isa;
799#ifdef __x86_64__
800 if constexpr (vlen == 4) {
801 isa = "b"; // SSE
802 } else if constexpr (
803 vlen ==
804 8) { // NOLINT(cppcoreguidelines-avoid-magic-numbers)
805 isa = "d"; // AVX2
806 } else if constexpr (
807 vlen ==
808 16) { // NOLINT(cppcoreguidelines-avoid-magic-numbers)
809 isa = "e"; // AVX512
810 }
811#elif defined(__ARM_NEON__)
812 if constexpr (vlen == 4) {
813 isa = "n"; // Armv8 Advanced SIMD
814 } else {
815 isa = "s"; // SVE
816 }
817#endif
818 constexpr auto OP_INFO = get_math_op_info(op);
819 std::string parameters(OP_INFO.arity, 'v');
820 std::string mask = "N";
821 std::string abi_string = std::format(
822 "_ZGV{}{}{}{}_{}({})", isa, mask, vlen, parameters,
823 scalar_func->getName().str(),
824 vec_func->getName().str());
825
826 scalar_func->addFnAttr(llvm::Attribute::get(
827 context, "vector-function-abi-variant",
828 abi_string));
829#endif
830 }
831 }(),
832 ...);
833 };
834
835 link_vectors(SupportedVectorWidths{});
836
837 func_cache[op] = scalar_func;
838 return scalar_func;
839 }
840
841 llvm::Function* generateAndCache(MathOp op) {
842 llvm::Function* result = nullptr;
843 std::apply(
844 [&, this](auto... op_constant) {
845 auto dispatcher = [&, this](auto op_c) {
846 if (op_c.value == op) {
847 result = this->generateAndCacheImpl<op_c.value>();
848 }
849 };
850 (dispatcher(op_constant), ...);
851 },
853 return result;
854 }
855};
856
857#endif // LLVMEXPR_CODEGEN_LLVM_MATH_HPP
constexpr MathopInfo get_math_op_info(MathOp op)
Definition Math.hpp:86
std::tuple< std::integral_constant< MathOp, MathOp::Exp >, std::integral_constant< MathOp, MathOp::Log >, std::integral_constant< MathOp, MathOp::Sin >, std::integral_constant< MathOp, MathOp::Cos >, std::integral_constant< MathOp, MathOp::Tan >, std::integral_constant< MathOp, MathOp::Atan >, std::integral_constant< MathOp, MathOp::Atan2 >, std::integral_constant< MathOp, MathOp::Acos >, std::integral_constant< MathOp, MathOp::Asin > > SupportedMathOpsTuple
Definition Math.hpp:749
std::integer_sequence< int > SupportedVectorWidths
Definition Math.hpp:119
MathOp
Definition Math.hpp:69
@ Sin
Definition Math.hpp:72
@ Tan
Definition Math.hpp:74
@ Atan2
Definition Math.hpp:76
@ Asin
Definition Math.hpp:78
@ Atan
Definition Math.hpp:75
@ Exp
Definition Math.hpp:70
@ Log
Definition Math.hpp:71
@ Acos
Definition Math.hpp:77
@ Cos
Definition Math.hpp:73
MathFunctionGenerator(llvm::Module *module, llvm::LLVMContext &context)
Definition Math.hpp:124
friend struct MathFunctionImpl
Definition Math.hpp:130
llvm::Function * getOrCreate()
Definition Math.hpp:745
llvm::Function * getFunction(MathOp op)
Definition Math.hpp:765
MathLibraryManager(llvm::Module *module, llvm::LLVMContext &context)
Definition Math.hpp:762
static llvm::Function * generate(MathFunctionGenerator< VectorWidth > *gen)
Definition Math.hpp:684
static llvm::Function * generate(MathFunctionGenerator< VectorWidth > *gen)
Definition Math.hpp:726
static llvm::Function * generate(MathFunctionGenerator< VectorWidth > *gen)
Definition Math.hpp:644
static llvm::Function * generate(MathFunctionGenerator< VectorWidth > *gen)
Definition Math.hpp:560
static llvm::Function * generate(MathFunctionGenerator< VectorWidth > *gen)
Definition Math.hpp:477
static llvm::Function * generate(MathFunctionGenerator< VectorWidth > *gen)
Definition Math.hpp:232
static llvm::Function * generate(MathFunctionGenerator< VectorWidth > *gen)
Definition Math.hpp:306
static llvm::Function * generate(MathFunctionGenerator< VectorWidth > *gen)
Definition Math.hpp:410
static llvm::Function * generate(MathFunctionGenerator< VectorWidth > *gen)
Definition Math.hpp:542
int arity
Definition Math.hpp:82
const char * name
Definition Math.hpp:83