#include "compute_pipeline.h"
#include "renderer.h"

#include <stdexcept> // std::runtime_error

/// Appends one descriptor-set-layout binding, visible to the compute stage only,
/// to descriptor_set_layout_bindings_.
/// Presumably create_pipeline_layout() turns these into the descriptor set layout — confirm in the header.
/// @param binding            binding number within the descriptor set
/// @param descriptor_type    Vulkan descriptor type of this binding
/// @param descriptor_count   number of descriptors in this binding
/// @param immutable_samplers optional immutable samplers (may be nullptr). Only the pointer is
///                           stored, so the array must stay alive until the layout is created.
void compute_pipeline::add_binding(uint32_t binding, vk::DescriptorType descriptor_type, uint32_t descriptor_count,
                                   const vk::Sampler* immutable_samplers) {
    const vk::ShaderStageFlags flag = vk::ShaderStageFlagBits::eCompute;

    vk::DescriptorSetLayoutBinding descriptor_set_layout_binding;
    descriptor_set_layout_binding.setBinding(binding);
    descriptor_set_layout_binding.setDescriptorType(descriptor_type);
    descriptor_set_layout_binding.setDescriptorCount(descriptor_count);
    descriptor_set_layout_binding.setStageFlags(flag);
    descriptor_set_layout_binding.setPImmutableSamplers(immutable_samplers);
    descriptor_set_layout_bindings_.push_back(descriptor_set_layout_binding);
}

/// Builds the Vulkan objects for this compute pipeline:
/// an empty pipeline cache, the pipeline layout (via create_pipeline_layout()),
/// and a compute pipeline using shader_module_ with entry point "main".
/// set_shader() must have been called first.
/// Calling create() again without destroy() overwrites the existing handles without releasing them.
/// @throws std::runtime_error (debug builds) if no shader module has been set; Vulkan failures are
///         reported through vulkan-hpp and check_vk_result().
void compute_pipeline::create() {
#ifdef _DEBUG
    if (shader_module_ == vk::ShaderModule()) {
        throw std::runtime_error("Shader not set");
    }
#endif
    const renderer* render_vk = application::get()->get_renderer();
    const vk::Device& device = render_vk->device;

    // No previously serialized cache data is supplied; the cache starts empty.
    vk::PipelineCacheCreateInfo pipeline_cache_create_info;
    pipeline_cache_create_info.setInitialDataSize(0);
    pipeline_cache_ = device.createPipelineCache(pipeline_cache_create_info);

    create_pipeline_layout();

    vk::PipelineShaderStageCreateInfo pipeline_shader_stage_create_info;
    pipeline_shader_stage_create_info.setStage(vk::ShaderStageFlagBits::eCompute);
    pipeline_shader_stage_create_info.setModule(shader_module_);
    pipeline_shader_stage_create_info.setPName("main");

    vk::ComputePipelineCreateInfo pipeline_create_info;
    pipeline_create_info.setLayout(pipeline_layout_);
    pipeline_create_info.setStage(pipeline_shader_stage_create_info);

    const auto pipeline_result = device.createComputePipeline(pipeline_cache_, pipeline_create_info);
    check_vk_result(pipeline_result.result);
    pipeline_ = pipeline_result.value;
}

/// Destroys the pipeline, pipeline cache, pipeline layout and shader module owned by this object.
/// The handles are then reset to null, for two reasons:
///   - dispatch()'s debug "not created" check fires on use after destroy();
///   - a repeated destroy() is harmless, because destroying VK_NULL_HANDLE is a no-op in Vulkan.
/// NOTE(review): any descriptor set layout / descriptor set made by create_pipeline_layout() is not
/// released here — confirm where that happens.
void compute_pipeline::destroy() {
    const renderer* render_vk = application::get()->get_renderer();
    const vk::Device& device = render_vk->device;

    device.destroyPipeline(pipeline_);
    device.destroyPipelineCache(pipeline_cache_);
    device.destroyPipelineLayout(pipeline_layout_);
    device.destroyShaderModule(shader_module_);

    pipeline_ = nullptr;
    pipeline_cache_ = nullptr;
    pipeline_layout_ = nullptr;
    shader_module_ = nullptr;
}

/// Runs the compute shader synchronously.
/// Steps:
///   - allocates a one-time-submit primary command buffer;
///   - binds pipeline_ and descriptor_set_ (first set 0);
///   - records dispatch(x, y, z) and submits it to the renderer's queue;
///   - blocks until that queue is idle.
/// @param group_count_x number of workgroups in X
/// @param group_count_y number of workgroups in Y
/// @param group_count_z number of workgroups in Z
/// @throws std::runtime_error (debug builds) if create() has not been called (or destroy() has).
void compute_pipeline::dispatch(uint32_t group_count_x, uint32_t group_count_y, uint32_t group_count_z) const {
#ifdef _DEBUG
    if (pipeline_ == vk::Pipeline()) {
        throw std::runtime_error("Pipeline not created");
    }
#endif
    const renderer* render_vk = application::get()->get_renderer();
    const vk::Device& device = render_vk->device;
    // NOTE(review): this pool is destroyed at the end of this function. That is only correct if
    // get_command_pool() hands out a fresh pool per call; if it returns the renderer's shared
    // pool, the destroyCommandPool() below is a bug — confirm in renderer.
    const vk::CommandPool& command_pool = render_vk->get_command_pool();

    vk::CommandBufferAllocateInfo command_buffer_allocate_info;
    command_buffer_allocate_info.setCommandPool(command_pool);
    command_buffer_allocate_info.setLevel(vk::CommandBufferLevel::ePrimary);
    command_buffer_allocate_info.setCommandBufferCount(1);
    const vk::CommandBuffer command_buffer = device.allocateCommandBuffers(command_buffer_allocate_info)[0];

    vk::CommandBufferBeginInfo command_buffer_begin_info;
    command_buffer_begin_info.setFlags(vk::CommandBufferUsageFlagBits::eOneTimeSubmit);
    command_buffer.begin(command_buffer_begin_info);
    command_buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline_);
    command_buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute, pipeline_layout_, 0, descriptor_set_, nullptr);
    command_buffer.dispatch(group_count_x, group_count_y, group_count_z);
    command_buffer.end();

    vk::SubmitInfo submit_info;
    submit_info.setCommandBufferCount(1);
    submit_info.setPCommandBuffers(&command_buffer);
    render_vk->queue.submit(submit_info, nullptr);
    // Synchronous: wait for the whole queue (not just this submission) before freeing resources.
    render_vk->queue.waitIdle();

    device.freeCommandBuffers(command_pool, command_buffer);
    device.destroyCommandPool(command_pool);
}

/// Creates the shader module used by create() from SPIR-V code, replacing any previously set module.
/// The new module is created before the old one is destroyed. If creation throws, the previous
/// module stays valid and owned.
/// @param shader_code      SPIR-V bytes. Vulkan reads them as uint32_t words, so this presumably
///                         must be 4-byte aligned — TODO confirm at call sites.
/// @param shader_code_size size in bytes; Vulkan requires a non-zero multiple of 4
/// @throws std::runtime_error (debug builds) on a null pointer or an invalid size.
void compute_pipeline::set_shader(const uint8_t* shader_code, size_t shader_code_size) {
#ifdef _DEBUG
    if (shader_code == nullptr || shader_code_size == 0 || shader_code_size % sizeof(uint32_t) != 0) {
        throw std::runtime_error("Invalid shader code: size must be a non-zero multiple of 4");
    }
#endif
    const renderer* render_vk = application::get()->get_renderer();
    const vk::Device& device = render_vk->device;

    vk::ShaderModuleCreateInfo shader_module_create_info;
    shader_module_create_info.setCodeSize(shader_code_size);
    shader_module_create_info.setPCode(reinterpret_cast<const uint32_t*>(shader_code));
    const vk::ShaderModule new_module = device.createShaderModule(shader_module_create_info);

    // Release the module from a previous call (no-op for VK_NULL_HANDLE). Pipelines already
    // created from it stay valid, since Vulkan does not need the module after pipeline creation.
    device.destroyShaderModule(shader_module_);
    shader_module_ = new_module;
}