74 lines
3.1 KiB
C++
74 lines
3.1 KiB
C++
#include "compute_pipeline.h"

#include <cstring>
#include <deque>
#include <stdexcept>
#include <vector>

#include <vulkan/vulkan.hpp>

#include "renderer.h"
|
|
|
|
/// Releases the shader module owned by this pipeline.
///
/// vkDestroyShaderModule accepts VK_NULL_HANDLE, so this is safe even when
/// set_shader() was never called.
/// NOTE(review): pipeline_ and pipeline_layout_ are not released here; presumably
/// a base class owns them — confirm, otherwise they leak.
compute_pipeline::~compute_pipeline() {
    application::get()->get_renderer()->device.destroyShaderModule(shader_module_);
}
|
|
|
|
/// Appends a compute-stage descriptor binding to descriptor_set_layout_bindings_.
///
/// @param binding            Binding index within the descriptor set.
/// @param descriptor_type    Descriptor type of the binding.
/// @param descriptor_count   Number of descriptors in the binding.
/// @param immutable_samplers Optional sampler. When non-null it is used as the
///                           immutable sampler for each of the @p descriptor_count
///                           descriptors, and @p descriptor_count is preserved.
void compute_pipeline::add_binding(uint32_t binding, vk::DescriptorType descriptor_type, uint32_t descriptor_count,
                                   vk::Sampler immutable_samplers) {
    // Backing storage for pImmutableSamplers. The binding only keeps a raw pointer that is
    // read later, when the descriptor set layout is built, so the sampler array must outlive
    // this call.
    //
    // The previous setImmutableSamplers(immutable_samplers) call had two problems:
    //   - It captured the address of this by-value parameter, which dangles after return.
    //   - It silently reset descriptorCount to 1.
    //
    // std::deque::emplace_back never relocates existing elements, so data pointers handed
    // out earlier stay valid.
    //
    // NOTE(review): this storage is shared by all compute_pipeline instances and is not
    // synchronized. Moving it into a class member would tie its lifetime to the pipeline.
    static std::deque<std::vector<vk::Sampler>> immutable_sampler_storage;

    vk::DescriptorSetLayoutBinding descriptor_set_layout_binding;
    descriptor_set_layout_binding.setBinding(binding);
    descriptor_set_layout_binding.setDescriptorType(descriptor_type);
    descriptor_set_layout_binding.setDescriptorCount(descriptor_count);
    descriptor_set_layout_binding.setStageFlags(vk::ShaderStageFlagBits::eCompute);

    if (immutable_samplers) {
        // Vulkan reads descriptorCount samplers from pImmutableSamplers, so replicate the
        // single sampler once per descriptor.
        immutable_sampler_storage.emplace_back(descriptor_count, immutable_samplers);
        descriptor_set_layout_binding.setPImmutableSamplers(immutable_sampler_storage.back().data());
    }

    descriptor_set_layout_bindings_.push_back(descriptor_set_layout_binding);
}
|
|
|
|
void compute_pipeline::create() {
|
|
const renderer* render_vk = application::get()->get_renderer();
|
|
const vk::Device& device = render_vk->device;
|
|
|
|
create_pipeline_layout();
|
|
|
|
vk::PipelineShaderStageCreateInfo pipeline_shader_stage_create_info;
|
|
pipeline_shader_stage_create_info.setStage(vk::ShaderStageFlagBits::eCompute);
|
|
pipeline_shader_stage_create_info.setModule(shader_module_);
|
|
pipeline_shader_stage_create_info.setPName("main");
|
|
|
|
vk::ComputePipelineCreateInfo pipeline_create_info;
|
|
pipeline_create_info.setLayout(pipeline_layout_);
|
|
pipeline_create_info.setStage(pipeline_shader_stage_create_info);
|
|
const auto pipeline_result = device.createComputePipeline(VK_NULL_HANDLE, pipeline_create_info);
|
|
|
|
check_vk_result(pipeline_result.result);
|
|
pipeline_ = pipeline_result.value;
|
|
}
|
|
|
|
void compute_pipeline::dispatch(uint32_t group_count_x, uint32_t group_count_y, uint32_t group_count_z) const {
|
|
#ifdef _DEBUG
|
|
if (pipeline_ == vk::Pipeline()) {
|
|
throw std::runtime_error("Pipeline not created");
|
|
}
|
|
#endif
|
|
renderer* render_vk = application::get()->get_renderer();
|
|
|
|
const vk::CommandBuffer command_buffer = render_vk->create_command_buffer(vk::CommandBufferLevel::ePrimary, true);
|
|
|
|
command_buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline_);
|
|
command_buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute, pipeline_layout_, 0, descriptor_set_, nullptr);
|
|
command_buffer.dispatch(group_count_x, group_count_y, group_count_z);
|
|
|
|
render_vk->end_command_buffer(command_buffer, true);
|
|
}
|
|
|
|
void compute_pipeline::set_shader(const uint8_t* shader_code, size_t shader_code_size) {
|
|
const renderer* render_vk = application::get()->get_renderer();
|
|
const vk::Device& device = render_vk->device;
|
|
|
|
vk::ShaderModuleCreateInfo shader_module_create_info;
|
|
shader_module_create_info.setCodeSize(shader_code_size);
|
|
shader_module_create_info.setPCode(reinterpret_cast<const uint32_t*>(shader_code));
|
|
shader_module_ = device.createShaderModule(shader_module_create_info);
|
|
}
|