#include "compute_pipeline.h"
#include "renderer.h"
#include <stdexcept>

/// @brief Releases the shader module owned by this pipeline wrapper.
///
/// destroyShaderModule on a null handle is a no-op, so this is safe even if set_shader() was never called.
compute_pipeline::~compute_pipeline()
{
    const renderer* render_vk = application::get()->get_renderer();
    const vk::Device& device = render_vk->device;

    device.destroyShaderModule(shader_module_);

    // NOTE(review): pipeline_ (created in create()) and pipeline_layout_ are not destroyed here. If no base class
    // or other owner releases them, they leak — confirm against the class definition before adding destroy calls.
}

/// @brief Records a descriptor set layout binding visible to the compute stage.
///
/// The binding is only appended to descriptor_set_layout_bindings_ here; presumably it is consumed by
/// create_pipeline_layout() when create() runs — verify against that function.
///
/// @param binding            Binding number referenced by the shader.
/// @param descriptor_type    Descriptor type at this binding.
/// @param descriptor_count   Number of descriptors in the binding (see note below when a sampler is supplied).
/// @param immutable_samplers Optional immutable sampler; ignored when it is a null handle.
void compute_pipeline::add_binding(uint32_t binding, vk::DescriptorType descriptor_type, uint32_t descriptor_count, vk::Sampler immutable_samplers)
{
    const vk::ShaderStageFlags flag = vk::ShaderStageFlagBits::eCompute;

    vk::DescriptorSetLayoutBinding descriptor_set_layout_binding;
    descriptor_set_layout_binding.setBinding(binding);
    descriptor_set_layout_binding.setDescriptorType(descriptor_type);
    descriptor_set_layout_binding.setDescriptorCount(descriptor_count);
    descriptor_set_layout_binding.setStageFlags(flag);

    if (immutable_samplers)
    {
        // NOTE(review): the array-proxy setImmutableSamplers() stores the ADDRESS of the by-value parameter
        // `immutable_samplers` and overwrites descriptorCount with 1. That parameter dies when this function
        // returns, so pImmutableSamplers dangles by the time the layout is built. A proper fix needs storage that
        // outlives this call (e.g. a sampler container member declared in compute_pipeline.h).
        descriptor_set_layout_binding.setImmutableSamplers(immutable_samplers);
    }

    descriptor_set_layout_bindings_.push_back(descriptor_set_layout_binding);
}

/// @brief Builds the pipeline layout and the compute pipeline from the current shader module.
///
/// Uses "main" as the shader entry point.
/// @throws std::runtime_error if set_shader() has not been called yet.
void compute_pipeline::create()
{
    // Creating a pipeline with a null shader module is invalid usage; fail early with a clear message instead.
    if (shader_module_ == vk::ShaderModule())
    {
        throw std::runtime_error("Shader not set");
    }

    const renderer* render_vk = application::get()->get_renderer();
    const vk::Device& device = render_vk->device;

    create_pipeline_layout();

    vk::PipelineShaderStageCreateInfo pipeline_shader_stage_create_info;
    pipeline_shader_stage_create_info.setStage(vk::ShaderStageFlagBits::eCompute);
    pipeline_shader_stage_create_info.setModule(shader_module_);
    pipeline_shader_stage_create_info.setPName("main");

    vk::ComputePipelineCreateInfo pipeline_create_info;
    pipeline_create_info.setLayout(pipeline_layout_);
    pipeline_create_info.setStage(pipeline_shader_stage_create_info);

    // No pipeline cache is used.
    const auto pipeline_result = device.createComputePipeline(VK_NULL_HANDLE, pipeline_create_info);
    check_vk_result(pipeline_result.result);
    pipeline_ = pipeline_result.value;
}

/// @brief Records a one-off compute dispatch using this pipeline and descriptor_set_.
///
/// Obtains a primary command buffer from the renderer, binds the pipeline and descriptor set (set 0),
/// dispatches, then hands the buffer to renderer::end_command_buffer() (presumably submit/cleanup — verify).
/// @throws std::runtime_error in _DEBUG builds if create() has not been called.
void compute_pipeline::dispatch(uint32_t group_count_x, uint32_t group_count_y, uint32_t group_count_z) const
{
#ifdef _DEBUG
    if (pipeline_ == vk::Pipeline())
    {
        throw std::runtime_error("Pipeline not created");
    }
#endif

    renderer* render_vk = application::get()->get_renderer();

    const vk::CommandBuffer command_buffer = render_vk->create_command_buffer(vk::CommandBufferLevel::ePrimary, true);
    command_buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline_);
    command_buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute, pipeline_layout_, 0, descriptor_set_, nullptr);
    command_buffer.dispatch(group_count_x, group_count_y, group_count_z);
    render_vk->end_command_buffer(command_buffer, true);
}

/// @brief Creates the shader module from SPIR-V bytes, replacing (and destroying) any previous one.
///
/// @param shader_code      SPIR-V bytecode; read as 32-bit words.
/// @param shader_code_size Size of shader_code in bytes; must be non-zero and a multiple of 4.
/// @throws std::runtime_error if the code pointer or size is invalid.
void compute_pipeline::set_shader(const uint8_t* shader_code, size_t shader_code_size)
{
    // Vulkan requires a non-null pCode and a codeSize that is a non-zero multiple of 4.
    if (shader_code == nullptr || shader_code_size == 0 || shader_code_size % sizeof(uint32_t) != 0)
    {
        throw std::runtime_error("Invalid SPIR-V shader code");
    }

    const renderer* render_vk = application::get()->get_renderer();
    const vk::Device& device = render_vk->device;

    vk::ShaderModuleCreateInfo shader_module_create_info;
    shader_module_create_info.setCodeSize(shader_code_size);
    // NOTE(review): assumes shader_code is 4-byte aligned, since pCode is read as uint32_t words — confirm with callers.
    shader_module_create_info.setPCode(reinterpret_cast<const uint32_t*>(shader_code));

    // Create the replacement first so a failure leaves the current module untouched.
    const vk::ShaderModule new_module = device.createShaderModule(shader_module_create_info);

    // Destroy any module from a previous call (no-op for a null handle); pipelines already created from it
    // remain valid, because Vulkan permits destroying a shader module after pipeline creation.
    device.destroyShaderModule(shader_module_);
    shader_module_ = new_module;
}