VapourSynth-llvmexpr
Loading...
Searching...
No Matches
VulkanComputePipeline.cpp
Go to the documentation of this file.
1
19
21#include "VulkanContext.hpp"
22#include "VulkanMemory.hpp"
23
24#include <algorithm>
25#include <memory>
26#include <mutex>
27#include <shaderc/shaderc.hpp>
28#include <stdexcept>
29#include <unordered_map>
30
31namespace vkexpr {
32
33namespace {
// Number of pixels handled per workgroup when sizing a dispatch: the pixel
// count is divided (rounding up) by this value to obtain the workgroup count.
// NOTE(review): presumably has to match the local_size_x declared in the
// generated GLSL - confirm against the shader generator.
constexpr uint32_t WORKGROUP_SIZE{256U};
35
// Deleter whose call operator does nothing. A std::unique_ptr that uses it
// never frees its pointee, so the pointed-to object outlives the smart
// pointer that refers to it.
template <class T> struct NoDestroyDeleter {
    void operator()(T* /*leaked_on_purpose*/) const noexcept {
        // Intentionally empty.
    }
};
39
40std::unordered_map<std::string, std::vector<uint32_t>>& shader_cache() {
41 using Cache = std::unordered_map<std::string, std::vector<uint32_t>>;
42 static auto cache = []() {
43 auto owned = std::make_unique<Cache>();
44 return std::unique_ptr<Cache, NoDestroyDeleter<Cache>>(owned.release());
45 }();
46 return *cache;
47}
48
49std::mutex& shader_cache_mutex() {
50 static auto mutex = []() {
51 auto owned = std::make_unique<std::mutex>();
52 return std::unique_ptr<std::mutex, NoDestroyDeleter<std::mutex>>(
53 owned.release());
54 }();
55 return *mutex;
56}
57
58} // namespace
59
61 const std::string& glsl_source,
62 uint32_t num_input_buffers,
63 uint32_t num_props_floats)
64 : context(ctx), num_inputs(num_input_buffers),
65 has_props_buffer(num_props_floats > 0) {
66 compileShader(glsl_source);
67 createDescriptorSetLayout(num_input_buffers, has_props_buffer);
68 createPipeline();
69}
70
72
73void VulkanComputePipeline::compileShader(const std::string& glsl_source) {
74 std::lock_guard<std::mutex> lock(shader_cache_mutex());
75
76 auto& cache = shader_cache();
77 if (cache.contains(glsl_source)) {
78 spirv_code = cache[glsl_source];
79 } else {
80 shaderc::Compiler compiler;
81 shaderc::CompileOptions options;
82
83 options.SetOptimizationLevel(shaderc_optimization_level_performance);
84 options.SetTargetEnvironment(shaderc_target_env_vulkan,
85 shaderc_env_version_vulkan_1_2);
86 options.SetTargetSpirv(shaderc_spirv_version_1_5);
87
88 auto result = compiler.CompileGlslToSpv(
89 glsl_source, shaderc_glsl_compute_shader, "compute.glsl", options);
90
91 if (result.GetCompilationStatus() !=
92 shaderc_compilation_status_success) {
93 throw std::runtime_error("Shader compilation failed: " +
94 std::string(result.GetErrorMessage()));
95 }
96
97 spirv_code = {result.cbegin(), result.cend()};
98 cache[glsl_source] = spirv_code;
99 }
100
101 vk::ShaderModuleCreateInfo module_info;
102 module_info.setCode(spirv_code);
103 shader_module = vk::raii::ShaderModule(context.getDevice(), module_info);
104}
105
106void VulkanComputePipeline::createDescriptorSetLayout(
107 uint32_t num_input_buffers, bool with_props_buffer) {
108 std::vector<vk::DescriptorSetLayoutBinding> bindings;
109 uint32_t binding_index = 0;
110
111 // Input buffer(s) - bindings 0 to numInputs-1
112 for (uint32_t i = 0; i < num_input_buffers; ++i) {
113 vk::DescriptorSetLayoutBinding input_binding;
114 input_binding.binding = binding_index++;
115 input_binding.descriptorType = vk::DescriptorType::eStorageBuffer;
116 input_binding.descriptorCount = 1;
117 input_binding.stageFlags = vk::ShaderStageFlagBits::eCompute;
118 bindings.push_back(input_binding);
119 }
120
121 // Output buffer
122 vk::DescriptorSetLayoutBinding output_binding;
123 output_binding.binding = binding_index++;
124 output_binding.descriptorType = vk::DescriptorType::eStorageBuffer;
125 output_binding.descriptorCount = 1;
126 output_binding.stageFlags = vk::ShaderStageFlagBits::eCompute;
127 bindings.push_back(output_binding);
128
129 // Props buffer (if needed)
130 if (with_props_buffer) {
131 vk::DescriptorSetLayoutBinding props_binding;
132 props_binding.binding = binding_index++;
133 props_binding.descriptorType = vk::DescriptorType::eStorageBuffer;
134 props_binding.descriptorCount = 1;
135 props_binding.stageFlags = vk::ShaderStageFlagBits::eCompute;
136 bindings.push_back(props_binding);
137 }
138
139 vk::DescriptorSetLayoutCreateInfo layout_info;
140 layout_info.setBindings(bindings);
141 descriptor_set_layout =
142 vk::raii::DescriptorSetLayout(context.getDevice(), layout_info);
143
144 // Create pipeline layout with push constants
145 vk::PushConstantRange push_const_range;
146 push_const_range.stageFlags = vk::ShaderStageFlagBits::eCompute;
147 push_const_range.offset = 0;
148 push_const_range.size = sizeof(PushConstants);
149
150 vk::PipelineLayoutCreateInfo pipeline_layout_info;
151 pipeline_layout_info.setSetLayouts(*descriptor_set_layout);
152 pipeline_layout_info.setPushConstantRanges(push_const_range);
153 pipeline_layout =
154 vk::raii::PipelineLayout(context.getDevice(), pipeline_layout_info);
155
156 // Create descriptor pool
157 vk::DescriptorPoolSize pool_size;
158 pool_size.type = vk::DescriptorType::eStorageBuffer;
159 // inputs + output + optional props
160 pool_size.descriptorCount =
161 num_input_buffers + 1 + (with_props_buffer ? 1 : 0);
162
163 vk::DescriptorPoolCreateInfo pool_info;
164 pool_info.flags = vk::DescriptorPoolCreateFlagBits::eFreeDescriptorSet;
165 pool_info.maxSets = 1;
166 pool_info.setPoolSizes(pool_size);
167 descriptor_pool = vk::raii::DescriptorPool(context.getDevice(), pool_info);
168
169 // Allocate descriptor set
170 vk::DescriptorSetAllocateInfo alloc_info;
171 alloc_info.descriptorPool = *descriptor_pool;
172 alloc_info.setSetLayouts(*descriptor_set_layout);
173 auto sets = vk::raii::DescriptorSets(context.getDevice(), alloc_info);
174 descriptor_set = std::move(sets[0]);
175}
176
177void VulkanComputePipeline::createPipeline() {
178 vk::PipelineShaderStageCreateInfo stage_info;
179 stage_info.stage = vk::ShaderStageFlagBits::eCompute;
180 stage_info.module = *shader_module;
181 stage_info.pName = "main";
182
183 vk::ComputePipelineCreateInfo pipeline_info;
184 pipeline_info.stage = stage_info;
185 pipeline_info.layout = *pipeline_layout;
186
187 pipeline = vk::raii::Pipeline(context.getDevice(), nullptr, pipeline_info);
188}
189
190void VulkanComputePipeline::updateDescriptorSets(
191 const std::vector<VulkanBuffer*>& input_buffers,
192 VulkanBuffer& output_buffer, VulkanBuffer* props_buffer) {
193
194 // Check if we can skip update
195 bool inputs_changed = false;
196 if (cached_input_buffers.size() != input_buffers.size()) {
197 inputs_changed = true;
198 } else {
199 for (size_t i = 0; i < input_buffers.size(); ++i) {
200 if (cached_input_buffers[i] != input_buffers[i]->buffer) {
201 inputs_changed = true;
202 break;
203 }
204 }
205 }
206
207 VkBuffer new_props_buffer_handle =
208 (props_buffer != nullptr) ? props_buffer->buffer : VK_NULL_HANDLE;
209
210 if (!inputs_changed && cached_output_buffer == output_buffer.buffer &&
211 cached_props_buffer == new_props_buffer_handle) {
212 return;
213 }
214
215 std::vector<vk::WriteDescriptorSet> writes;
216 std::vector<vk::DescriptorBufferInfo> buffer_infos;
217
218 size_t num_buffers =
219 input_buffers.size() + 1 + (props_buffer != nullptr ? 1 : 0);
220 buffer_infos.reserve(num_buffers);
221
222 // Update cache
223 cached_input_buffers.clear();
224 cached_input_buffers.resize(input_buffers.size());
225 std::ranges::transform(input_buffers, cached_input_buffers.begin(),
226 [](const auto* buf) { return buf->buffer; });
227 cached_output_buffer = output_buffer.buffer;
228 cached_props_buffer = new_props_buffer_handle;
229
230 // Input buffers
231 for (auto* input_buffer : input_buffers) {
232 vk::DescriptorBufferInfo buf_info;
233 buf_info.buffer = input_buffer->buffer;
234 buf_info.offset = 0;
235 buf_info.range = VK_WHOLE_SIZE;
236 buffer_infos.push_back(buf_info);
237 }
238
239 // Output buffer
240 vk::DescriptorBufferInfo out_buf_info;
241 out_buf_info.buffer = output_buffer.buffer;
242 out_buf_info.offset = 0;
243 out_buf_info.range = VK_WHOLE_SIZE;
244 buffer_infos.push_back(out_buf_info);
245
246 // Props buffer
247 if (props_buffer != nullptr) {
248 vk::DescriptorBufferInfo props_buf_info;
249 props_buf_info.buffer = props_buffer->buffer;
250 props_buf_info.offset = 0;
251 props_buf_info.range = VK_WHOLE_SIZE;
252 buffer_infos.push_back(props_buf_info);
253 }
254
255 // Create write descriptor sets
256 for (size_t i = 0; i < buffer_infos.size(); ++i) {
257 vk::WriteDescriptorSet write;
258 write.dstSet = *descriptor_set;
259 write.dstBinding = static_cast<uint32_t>(i);
260 write.dstArrayElement = 0;
261 write.descriptorCount = 1;
262 write.descriptorType = vk::DescriptorType::eStorageBuffer;
263 write.setBufferInfo(buffer_infos[i]);
264 writes.push_back(write);
265 }
266
267 context.getDevice().updateDescriptorSets(writes, {});
268}
269
271 vk::raii::CommandBuffer& command_buffer,
272 const std::vector<VulkanBuffer*>& input_buffers,
273 VulkanBuffer& output_buffer, VulkanBuffer* props_buffer, uint32_t width,
274 uint32_t height, int32_t frame_number) {
275 updateDescriptorSets(input_buffers, output_buffer, props_buffer);
276
277 command_buffer.bindPipeline(vk::PipelineBindPoint::eCompute, *pipeline);
278 command_buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
279 *pipeline_layout, 0, *descriptor_set, {});
280
281 PushConstants pc = {.width = width,
282 .height = height,
283 .num_inputs = num_inputs,
284 .frame_number = frame_number};
285 command_buffer.pushConstants<PushConstants>(
286 *pipeline_layout, vk::ShaderStageFlagBits::eCompute, 0, pc);
287
288 uint32_t total_pixels = width * height;
289 uint32_t num_workgroups =
290 (total_pixels + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
291 command_buffer.dispatch(num_workgroups, 1, 1);
292}
293
294} // namespace vkexpr
VulkanComputePipeline(VulkanContext &ctx, const std::string &glsl_source, uint32_t num_input_buffers=1, uint32_t num_props_floats=1)
void recordDispatch(vk::raii::CommandBuffer &command_buffer, const std::vector< VulkanBuffer * > &input_buffers, VulkanBuffer &output_buffer, VulkanBuffer *props_buffer, uint32_t width, uint32_t height, int32_t frame_number)