AronaCore/core/rhi/compute_pipeline.cpp
2024-02-20 14:09:44 +08:00

74 lines
3.1 KiB
C++

#include "compute_pipeline.h"

#include <cstdint>
#include <cstring>
#include <deque>
#include <mutex>
#include <stdexcept>
#include <vector>

#include <vulkan/vulkan.hpp>

#include "renderer.h"
/// Releases the shader module created by set_shader().
/// Destroying a null handle is a valid no-op in Vulkan, so this is safe even if no shader was set.
/// NOTE(review): pipeline_ / pipeline_layout_ are not destroyed here — confirm another owner
/// (e.g. a base class destructor) releases them.
compute_pipeline::~compute_pipeline() {
    const vk::Device& vk_device = application::get()->get_renderer()->device;
    vk_device.destroyShaderModule(shader_module_);
}
void compute_pipeline::add_binding(uint32_t binding, vk::DescriptorType descriptor_type, uint32_t descriptor_count,
vk::Sampler immutable_samplers) {
const vk::ShaderStageFlags flag = vk::ShaderStageFlagBits::eCompute;
vk::DescriptorSetLayoutBinding descriptor_set_layout_binding;
descriptor_set_layout_binding.setBinding(binding);
descriptor_set_layout_binding.setDescriptorType(descriptor_type);
descriptor_set_layout_binding.setDescriptorCount(descriptor_count);
descriptor_set_layout_binding.setStageFlags(flag);
if (immutable_samplers)
descriptor_set_layout_binding.setImmutableSamplers(immutable_samplers);
descriptor_set_layout_bindings_.push_back(descriptor_set_layout_binding);
}
void compute_pipeline::create() {
const renderer* render_vk = application::get()->get_renderer();
const vk::Device& device = render_vk->device;
create_pipeline_layout();
vk::PipelineShaderStageCreateInfo pipeline_shader_stage_create_info;
pipeline_shader_stage_create_info.setStage(vk::ShaderStageFlagBits::eCompute);
pipeline_shader_stage_create_info.setModule(shader_module_);
pipeline_shader_stage_create_info.setPName("main");
vk::ComputePipelineCreateInfo pipeline_create_info;
pipeline_create_info.setLayout(pipeline_layout_);
pipeline_create_info.setStage(pipeline_shader_stage_create_info);
const auto pipeline_result = device.createComputePipeline(VK_NULL_HANDLE, pipeline_create_info);
check_vk_result(pipeline_result.result);
pipeline_ = pipeline_result.value;
}
void compute_pipeline::dispatch(uint32_t group_count_x, uint32_t group_count_y, uint32_t group_count_z) const {
#ifdef _DEBUG
if (pipeline_ == vk::Pipeline()) {
throw std::runtime_error("Pipeline not created");
}
#endif
renderer* render_vk = application::get()->get_renderer();
const vk::CommandBuffer command_buffer = render_vk->create_command_buffer(vk::CommandBufferLevel::ePrimary, true);
command_buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline_);
command_buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute, pipeline_layout_, 0, descriptor_set_, nullptr);
command_buffer.dispatch(group_count_x, group_count_y, group_count_z);
render_vk->end_command_buffer(command_buffer, true);
}
void compute_pipeline::set_shader(const uint8_t* shader_code, size_t shader_code_size) {
const renderer* render_vk = application::get()->get_renderer();
const vk::Device& device = render_vk->device;
vk::ShaderModuleCreateInfo shader_module_create_info;
shader_module_create_info.setCodeSize(shader_code_size);
shader_module_create_info.setPCode(reinterpret_cast<const uint32_t*>(shader_code));
shader_module_ = device.createShaderModule(shader_module_create_info);
}