AronaCore/core/rhi/compute_pipeline.cpp
2024-02-20 10:06:45 +08:00

99 lines
4.2 KiB
C++

#include "compute_pipeline.h"
#include "renderer.h"
#include <vulkan/vulkan.hpp>
/// Registers one descriptor binding for the compute stage.
///
/// The binding is only recorded in descriptor_set_layout_bindings_; no Vulkan
/// object is created here.
///
/// @param binding            Binding index used by the shader.
/// @param descriptor_type    Kind of descriptor bound at that index.
/// @param descriptor_count   Number of descriptors in the binding (array size).
/// @param immutable_samplers Optional immutable sampler array, forwarded as-is
///                           (may be nullptr).
void compute_pipeline::add_binding(uint32_t binding, vk::DescriptorType descriptor_type, uint32_t descriptor_count,
                                   const vk::Sampler* immutable_samplers) {
    const auto layout_binding = vk::DescriptorSetLayoutBinding()
                                    .setBinding(binding)
                                    .setDescriptorType(descriptor_type)
                                    .setDescriptorCount(descriptor_count)
                                    // Compute pipelines only ever expose bindings to the compute stage.
                                    .setStageFlags(vk::ShaderStageFlagBits::eCompute)
                                    .setPImmutableSamplers(immutable_samplers);

    descriptor_set_layout_bindings_.push_back(layout_binding);
}
void compute_pipeline::create() {
const renderer* render_vk = application::get()->get_renderer();
const vk::Device& device = render_vk->device;
vk::PipelineCacheCreateInfo pipeline_cache_create_info;
pipeline_cache_create_info.setInitialDataSize(0);
pipeline_cache_ = device.createPipelineCache(pipeline_cache_create_info);
create_pipeline_layout();
vk::PipelineShaderStageCreateInfo pipeline_shader_stage_create_info;
pipeline_shader_stage_create_info.setStage(vk::ShaderStageFlagBits::eCompute);
pipeline_shader_stage_create_info.setModule(shader_module_);
pipeline_shader_stage_create_info.setPName("main");
vk::ComputePipelineCreateInfo pipeline_create_info;
pipeline_create_info.setLayout(pipeline_layout_);
pipeline_create_info.setStage(pipeline_shader_stage_create_info);
const auto pipeline_result = device.createComputePipeline(pipeline_cache_, pipeline_create_info);
check_vk_result(pipeline_result.result);
pipeline_ = pipeline_result.value;
}
void compute_pipeline::destroy() {
const renderer* render_vk = application::get()->get_renderer();
const vk::Device& device = render_vk->device;
device.destroyPipeline(pipeline_);
device.destroyPipelineCache(pipeline_cache_);
device.destroyPipelineLayout(pipeline_layout_);
}
/// Records and submits a single compute dispatch, then blocks until the
/// renderer's queue is idle.
///
/// A one-time-submit primary command buffer is allocated from the renderer's
/// command pool, the pipeline and descriptor_set_ are bound, the dispatch is
/// recorded, and the buffer is submitted without a fence. After
/// queue.waitIdle() returns the command buffer is freed again.
///
/// @param group_count_x Number of workgroups along X.
/// @param group_count_y Number of workgroups along Y.
/// @param group_count_z Number of workgroups along Z.
/// @throws std::runtime_error in _DEBUG builds when the pipeline has not been
///         created yet.
void compute_pipeline::dispatch(uint32_t group_count_x, uint32_t group_count_y, uint32_t group_count_z) const {
#ifdef _DEBUG
    if (pipeline_ == vk::Pipeline()) {
        throw std::runtime_error("Pipeline not created");
    }
#endif
    const renderer* render_vk = application::get()->get_renderer();
    const vk::Device& device = render_vk->device;
    // The pool is borrowed from the renderer; this function only allocates from it.
    const vk::CommandPool& command_pool = render_vk->get_command_pool();

    vk::CommandBufferAllocateInfo command_buffer_allocate_info;
    command_buffer_allocate_info.setCommandPool(command_pool);
    command_buffer_allocate_info.setLevel(vk::CommandBufferLevel::ePrimary);
    command_buffer_allocate_info.setCommandBufferCount(1);
    const vk::CommandBuffer command_buffer = device.allocateCommandBuffers(command_buffer_allocate_info)[0];

    vk::CommandBufferBeginInfo command_buffer_begin_info;
    command_buffer_begin_info.setFlags(vk::CommandBufferUsageFlagBits::eOneTimeSubmit);
    command_buffer.begin(command_buffer_begin_info);
    command_buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline_);
    command_buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute, pipeline_layout_, 0, descriptor_set_, nullptr);
    command_buffer.dispatch(group_count_x, group_count_y, group_count_z);
    command_buffer.end();

    vk::SubmitInfo submit_info;
    submit_info.setCommandBufferCount(1);
    submit_info.setPCommandBuffers(&command_buffer);
    render_vk->queue.submit(submit_info, nullptr);
    // No fence is used, so wait for the whole queue before releasing the buffer.
    render_vk->queue.waitIdle();

    device.freeCommandBuffers(command_pool, command_buffer);
    // The command pool is intentionally NOT destroyed here: it was obtained from
    // the renderer rather than created by this function, and destroying it would
    // leave the renderer holding a dangling pool handle after the first dispatch.
}
void compute_pipeline::set_shader(const uint8_t* shader_code, size_t shader_code_size) {
const renderer* render_vk = application::get()->get_renderer();
const vk::Device& device = render_vk->device;
vk::ShaderModuleCreateInfo shader_module_create_info;
shader_module_create_info.setCodeSize(shader_code_size);
shader_module_create_info.setPCode(reinterpret_cast<const uint32_t*>(shader_code));
shader_module_ = device.createShaderModule(shader_module_create_info);
}