#include "pipeline.h"

#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "renderer.h"
#include "texture.h"

/**
 * Releases the Vulkan objects owned by this pipeline using the renderer's device.
 * descriptor_set_ needs no separate release: it was allocated from descriptor_pool_,
 * and destroying a pool frees every set allocated from it.
 */
pipeline::~pipeline() {
    const renderer* const owning_renderer = application::get()->get_renderer();
    const vk::Device& dev = owning_renderer->device;

    dev.destroyDescriptorPool(descriptor_pool_);
    dev.destroyDescriptorSetLayout(descriptor_set_layout_);
    dev.destroyPipelineLayout(pipeline_layout_);
    dev.destroyPipeline(pipeline_);
}

/**
 * Registers a single uniform-buffer descriptor at `binding` and records `buf`
 * as the buffer that create_pipeline_layout() later writes into that binding.
 *
 * @param binding Shader binding index.
 * @param buf     Buffer to bind; taken by value as a sink and moved into storage
 *                so no extra (atomic) reference-count round trip is paid.
 */
void pipeline::add_uniform_buffer(uint32_t binding, std::shared_ptr<buffer_vk> buf) {
    add_binding(binding, vk::DescriptorType::eUniformBuffer, 1, nullptr);
    buffers_[binding] = std::move(buf);
}

/**
 * Registers a single storage-buffer descriptor at `binding` and records `buf`
 * as the buffer that create_pipeline_layout() later writes into that binding.
 *
 * @param binding Shader binding index.
 * @param buf     Buffer to bind; taken by value as a sink and moved into storage
 *                so no extra (atomic) reference-count round trip is paid.
 */
void pipeline::add_storage_buffer(uint32_t binding, std::shared_ptr<buffer_vk> buf) {
    add_binding(binding, vk::DescriptorType::eStorageBuffer, 1, nullptr);
    buffers_[binding] = std::move(buf);
}

/**
 * Registers a single sampled-image descriptor at `binding` and records
 * `in_texture` as the resource create_pipeline_layout() later writes there.
 *
 * @param binding    Shader binding index.
 * @param in_texture Texture to bind; must be non-null (it is dereferenced).
 *                   Taken by value as a sink and moved into storage.
 *
 * NOTE(review): the texture's sampler is forwarded to add_binding; if that becomes
 * pImmutableSamplers, Vulkan ignores it for eSampledImage — confirm intent.
 */
void pipeline::add_sampled_image(uint32_t binding, std::shared_ptr<render_resource> in_texture) {
    // The sampler must be read before `in_texture` is moved from below.
    add_binding(binding, vk::DescriptorType::eSampledImage, 1, in_texture->sampler);
    textures_[binding] = std::move(in_texture);
}

/**
 * Registers a single storage-image descriptor at `binding` and records
 * `in_texture` as the resource create_pipeline_layout() later writes there.
 *
 * @param binding    Shader binding index.
 * @param in_texture Texture to bind; must be non-null (it is dereferenced).
 *                   Taken by value as a sink and moved into storage.
 *
 * NOTE(review): the texture's sampler is forwarded to add_binding; if that becomes
 * pImmutableSamplers, Vulkan ignores it for eStorageImage — confirm intent.
 */
void pipeline::add_storage_image(uint32_t binding, std::shared_ptr<render_resource> in_texture) {
    // The sampler must be read before `in_texture` is moved from below.
    add_binding(binding, vk::DescriptorType::eStorageImage, 1, in_texture->sampler);
    textures_[binding] = std::move(in_texture);
}

/**
 * Registers a single input-attachment descriptor at `binding` with no sampler.
 * No resource is recorded for it, so create_pipeline_layout() does not write
 * this binding; it is only part of the layout and pool sizing.
 */
void pipeline::add_input_attachment(uint32_t binding) {
    const vk::DescriptorType attachment_type = vk::DescriptorType::eInputAttachment;
    add_binding(binding, attachment_type, 1, nullptr);
}

/**
 * Registers a single standalone-sampler descriptor at `binding`.
 * `immutable_samplers` is forwarded to add_binding (presumably becoming the
 * binding's immutable sampler — confirm against add_binding). No resource is
 * recorded, so create_pipeline_layout() does not write this binding.
 */
void pipeline::add_sampler(uint32_t binding, vk::Sampler immutable_samplers) {
    const vk::DescriptorType sampler_type = vk::DescriptorType::eSampler;
    add_binding(binding, sampler_type, 1, immutable_samplers);
}

/**
 * Registers a single combined image/sampler descriptor at `binding`, passing
 * the texture's sampler to add_binding, and records `in_texture` as the
 * resource create_pipeline_layout() later writes there.
 *
 * @param binding    Shader binding index.
 * @param in_texture Texture to bind; must be non-null (it is dereferenced).
 *                   Taken by value as a sink and moved into storage.
 */
void pipeline::add_combined_image(uint32_t binding, std::shared_ptr<render_resource> in_texture) {
    // The sampler must be read before `in_texture` is moved from below.
    add_binding(binding, vk::DescriptorType::eCombinedImageSampler, 1, in_texture->sampler);
    textures_[binding] = std::move(in_texture);
}

void pipeline::create_pipeline_layout() {
    const renderer* render_vk = application::get()->get_renderer();
    const vk::Device& device = render_vk->device;

    vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info;
    descriptor_set_layout_create_info.setBindings(descriptor_set_layout_bindings_);
    descriptor_set_layout_ = device.createDescriptorSetLayout(descriptor_set_layout_create_info);

    std::vector<vk::DescriptorPoolSize> pool_sizes;
    {
        std::map<vk::DescriptorType, uint32_t> temp_pool_sizes;
        for (const auto& binding: descriptor_set_layout_bindings_) {
            temp_pool_sizes[binding.descriptorType]++;
        }
        for (const auto& pair: temp_pool_sizes) {
            pool_sizes.emplace_back(pair.first, pair.second);
        }
    }

    vk::DescriptorPoolCreateInfo descriptor_pool_create_info;
    descriptor_pool_create_info.setMaxSets(1);
    descriptor_pool_create_info.setPoolSizes(pool_sizes);
    descriptor_pool_ = device.createDescriptorPool(descriptor_pool_create_info);

    vk::DescriptorSetAllocateInfo descriptor_set_allocate_info;
    descriptor_set_allocate_info.setDescriptorPool(descriptor_pool_);
    descriptor_set_allocate_info.setDescriptorSetCount(1);
    descriptor_set_allocate_info.setSetLayouts(descriptor_set_layout_);
    const std::vector<vk::DescriptorSet>& sets = device.allocateDescriptorSets(descriptor_set_allocate_info);
    descriptor_set_ = sets.front();

    std::vector<vk::WriteDescriptorSet> write_descriptor_sets;
    std::vector<std::vector<vk::DescriptorBufferInfo>> temp_buffer_infos;
    std::vector<std::vector<vk::DescriptorImageInfo>> temp_image_infoses;

    for (const auto& binding_info: descriptor_set_layout_bindings_) {
        vk::WriteDescriptorSet write_descriptor_set;
        write_descriptor_set.setDstSet(descriptor_set_);
        write_descriptor_set.setDstBinding(binding_info.binding);
        write_descriptor_set.setDescriptorType(binding_info.descriptorType);

        switch (binding_info.descriptorType) {
            case vk::DescriptorType::eCombinedImageSampler:
            case vk::DescriptorType::eSampledImage:
            case vk::DescriptorType::eStorageImage: {
                const auto& t = textures_[binding_info.binding];
                vk::DescriptorImageInfo image_info = t->get_descriptor_info();
                temp_image_infoses.push_back({image_info});
                write_descriptor_set.setImageInfo(temp_image_infoses.back());
            }
                break;
            case vk::DescriptorType::eUniformBuffer:
            case vk::DescriptorType::eStorageBuffer: {
                const auto& b = buffers_[binding_info.binding];
                vk::DescriptorBufferInfo buffer_info = b->get_descriptor_info();
                temp_buffer_infos.push_back( {buffer_info} );
                write_descriptor_set.setBufferInfo(temp_buffer_infos.back());
            }
                break;
            default:
                continue;
        }
        write_descriptor_sets.push_back(write_descriptor_set);
    }

    device.updateDescriptorSets(write_descriptor_sets, nullptr);

    vk::PipelineLayoutCreateInfo pipeline_layout_create_info;
    pipeline_layout_create_info.setSetLayouts(descriptor_set_layout_);
    pipeline_layout_ = device.createPipelineLayout(pipeline_layout_create_info);
}