新增dx11根据反射信息自动绑定到目标binding point

This commit is contained in:
Nanako 2024-02-07 22:59:26 +08:00
parent 93980885a6
commit 3a3d3caa1a
15 changed files with 248 additions and 173 deletions

View File

@ -22,7 +22,7 @@ public:
// param setters
virtual void set_cbuffer(const char* name, void* buffer, int size) = 0;
virtual void set_uav_buffer(const char* name, void* buffer, int size, int element_size) {}
virtual void set_uav_buffer(const char* name, void* buffer, int count, int element_size) {}
virtual void set_texture(const char* name, std::shared_ptr<texture> in_texture) = 0;
virtual void set_render_target(const char* name, std::shared_ptr<render_target> in_render_target) = 0;
@ -40,12 +40,12 @@ public:
template<typename T>
void set_uav_buffer(const char* name, const T& buffer)
{
set_uav_buffer(name, (void*)&buffer, sizeof(T));
set_uav_buffer(name, (void*)&buffer, sizeof(buffer), sizeof(T));
}
template<typename T>
void set_uav_buffer(const char* name, const std::vector<T>& buffer)
{
set_uav_buffer(name, (void*)buffer.data(), sizeof(T) * buffer.size());
set_uav_buffer(name, (void*)buffer.data(), buffer.size(), sizeof(T));
}
protected:
std::shared_ptr<slang_handle> handle_;

View File

@ -1,6 +1,17 @@
#pragma once
#include <d3dcompiler.h>
typedef HRESULT(WINAPI* pD3DReflect)
(__in_bcount(SrcDataSize) LPCVOID pSrcData,
__in SIZE_T SrcDataSize,
__in REFIID pInterface,
__out void** ppReflector);
#define DEFINE_GUID_FOR_CURRENT_COMPILER(name, l, w1, w2, b1, b2, b3, b4, b5, b6, b7, b8) \
static const GUID name = { l, w1, w2, { b1, b2, b3, b4, b5, b6, b7, b8 } }
DEFINE_GUID_FOR_CURRENT_COMPILER(IID_ID3D11ShaderReflectionForCurrentCompiler, 0x8d536ca1, 0x0cca, 0x4956, 0xa8, 0x37, 0x78, 0x69, 0x63, 0x75, 0x55, 0x84);
inline HMODULE get_compiler_module()
{
static HMODULE compiler_dll = nullptr;
@ -25,3 +36,15 @@ inline pD3DCompile get_d3d_compile_func()
return nullptr;
}
// @return pointer to the D3DCompile function
inline pD3DReflect get_d3d_reflect_func()
{
static HMODULE CompilerDLL = get_compiler_module();
if (CompilerDLL)
{
return (pD3DReflect)(void*)GetProcAddress(CompilerDLL, "D3DReflect");
}
return nullptr;
}

View File

@ -23,6 +23,8 @@ public:
void unlock() override;
void release();
int binding_point = -1;
protected:
void on_resize(int width, int height) override;
private:

View File

@ -47,6 +47,8 @@ public:
memcpy(data, buffer_data_, element_size_);
unlock();
}
int binding_point = -1;
private:
[[nodiscard]] void* lock() const
{

View File

@ -20,48 +20,35 @@ void shader_cs_dx11::bind()
g_d3d11_device_context->CSGetShader(&prev_shader_, prev_class_instances_, &prev_class_instances_num_);
g_d3d11_device_context->CSSetShader(compute_shader_, nullptr, 0);
if (const unsigned int constant_num = constant_buffers_.size(); constant_num > 0)
for (const auto& buffer : constant_buffers_ | std::views::values)
{
std::vector<ID3D11Buffer*> buffers;
buffers.reserve(constant_num);
for (const auto& buffer : constant_buffers_ | std::views::values)
{
buffers.push_back(buffer->get_resource());
}
g_d3d11_device_context->CSSetConstantBuffers(0, constant_num, buffers.data());
ref_count_ptr<ID3D11Buffer> cb = buffer->get_resource();
ID3D11Buffer* b = cb.get_reference();
g_d3d11_device_context->CSSetConstantBuffers(buffer->binding_point, 1, &b);
}
if (const size_t uav_buffer_num = uav_buffers_.size(); uav_buffer_num > 0)
for (const auto& buffer : uav_buffers_ | std::views::values)
{
std::vector<ID3D11UnorderedAccessView*> uavs;
uavs.reserve(uav_buffer_num);
for (const auto& buffer : uav_buffers_ | std::views::values)
{
buffer->update_buffer();
uavs.push_back(buffer->get_uav());
}
g_d3d11_device_context->CSSetUnorderedAccessViews(0, uav_buffer_num, uavs.data(), nullptr);
buffer->update_buffer();
auto u = buffer->get_uav();
auto p = u.get_reference();
g_d3d11_device_context->CSSetUnorderedAccessViews(buffer->binding_point, 1, &p, nullptr);
}
for (const auto& texture : textures_ | std::views::values)
{
const unsigned int texture_num = textures_.size();
const unsigned int render_target_num = render_targets_.size();
if (const unsigned int num = texture_num + render_target_num; num > 0)
{
std::vector<ID3D11ShaderResourceView*> srvs;
srvs.reserve(num);
for (const auto& texture : textures_ | std::views::values)
{
srvs.push_back(texture->get_srv());
}
for (const auto& render_target : render_targets_ | std::views::values)
{
srvs.push_back(render_target->get_srv());
}
g_d3d11_device_context->CSSetShaderResources(0, texture_num, srvs.data());
}
auto srv = texture->get_srv();
auto p = srv.get_reference();
g_d3d11_device_context->CSSetShaderResources(texture->binding_point, 1, &p);
}
for (const auto& texture : render_targets_ | std::views::values)
{
auto srv = texture->get_srv();
auto p = srv.get_reference();
g_d3d11_device_context->CSSetShaderResources(texture->binding_point, 1, &p);
}
}
@ -84,10 +71,18 @@ HRESULT shader_cs_dx11::create_shader(ID3DBlob* blob, ID3D11Device* device)
void shader_cs_dx11::set_uav_buffer(const char* name, void* buffer, int size, int element_size)
{
const auto& binding = bindings_.find(name);
if (binding == bindings_.end())
{
spdlog::warn("uav buffer {} not found in shader", name);
return;
}
const auto find = uav_buffers_.find(name);
if (find == uav_buffers_.end())
{
auto uav_buffer = std::make_shared<uav_buffer_dx11>();
uav_buffer->binding_point = binding->second.binding;
uav_buffer->create(buffer, size, element_size);
uav_buffers_.insert({name, uav_buffer});
return;
@ -98,15 +93,26 @@ void shader_cs_dx11::set_uav_buffer(const char* name, void* buffer, int size, in
void shader_cs_dx11::set_render_target(const char* name, std::shared_ptr<render_target> in_render_target)
{
const auto& binding = bindings_.find(name);
if (binding == bindings_.end())
{
spdlog::warn("render target {} not found in shader", name);
return;
}
std::shared_ptr<render_target_dx11> rt = std::static_pointer_cast<render_target_dx11>(in_render_target);
const auto find = uav_buffers_.find(name);
if (find == uav_buffers_.end())
{
auto uav_buffer = std::make_shared<uav_buffer_dx11>();
uav_buffer->create_from_render_target(rt);
if (in_render_target)
uav_buffer->create_from_render_target(rt);
uav_buffer->binding_point = binding->second.binding;
uav_buffers_.insert({name, uav_buffer});
return;
}
if (in_render_target)
find->second->create_from_render_target(rt);
}
void shader_cs_dx11::compute(int x, int y, int z)

View File

@ -18,35 +18,26 @@ void shader_ds_dx11::bind()
g_d3d11_device_context->DSGetShader(&prev_shader_, prev_class_instances_, &prev_class_instances_num_);
g_d3d11_device_context->DSSetShader(domain_shader_, nullptr, 0);
if (const unsigned int constant_num = constant_buffers_.size(); constant_num > 0)
for (const auto& buffer : constant_buffers_ | std::views::values)
{
std::vector<ID3D11Buffer*> buffers;
buffers.reserve(constant_num);
for (const auto& buffer : constant_buffers_ | std::views::values)
{
buffers.push_back(buffer->get_resource());
}
g_d3d11_device_context->DSSetConstantBuffers(0, constant_num, buffers.data());
ref_count_ptr<ID3D11Buffer> cb = buffer->get_resource();
ID3D11Buffer* b = cb.get_reference();
g_d3d11_device_context->DSSetConstantBuffers(buffer->binding_point, 1, &b);
}
for (const auto& texture : textures_ | std::views::values)
{
const size_t render_target_num = render_targets_.size();
const size_t texture_num = textures_.size();
if (const unsigned int num = render_target_num + texture_num; num > 0)
{
std::vector<ID3D11ShaderResourceView*> srvs;
srvs.reserve(num);
for (const auto& texture : textures_ | std::views::values)
{
srvs.push_back(texture->get_srv());
}
for (const auto& render_target : render_targets_ | std::views::values)
{
srvs.push_back(render_target->get_srv());
}
g_d3d11_device_context->DSSetShaderResources(0, texture_num, srvs.data());
}
auto srv = texture->get_srv();
auto p = srv.get_reference();
g_d3d11_device_context->DSSetShaderResources(texture->binding_point, 1, &p);
}
for (const auto& texture : render_targets_ | std::views::values)
{
auto srv = texture->get_srv();
auto p = srv.get_reference();
g_d3d11_device_context->DSSetShaderResources(texture->binding_point, 1, &p);
}
}

View File

@ -28,6 +28,12 @@ bool shader_dx11::init()
spdlog::critical("slang: get D3DCompile function failed");
return false;
}
const auto reflect_func = get_d3d_reflect_func();
if (!reflect_func)
{
spdlog::critical("slang: get D3DReflect function failed");
return false;
}
unsigned int shader_flags = D3DCOMPILE_ENABLE_STRICTNESS;
#if _DEBUG
@ -62,16 +68,34 @@ bool shader_dx11::init()
return false;
}
Slang::ComPtr<ID3D11ShaderReflection> reflector;
hr = reflect_func(kernel_blob->GetBufferPointer(), kernel_blob->GetBufferSize(), IID_ID3D11ShaderReflectionForCurrentCompiler, (void**)reflector.writeRef());
if (FAILED(hr))
{
spdlog::error("reflect shader failed: {:x}", hr);
return false;
}
init_shader_bindings(reflector);
return true;
}
void shader_dx11::set_cbuffer(const char* name, void* buffer, int size)
{
const auto& binding = bindings_.find(name);
if (binding == bindings_.end())
{
spdlog::warn("constant buffer {} not found in shader", name);
return;
}
auto find = constant_buffers_.find(name);
if (find == constant_buffers_.end())
{
auto constant_buffer = std::make_shared<constant_buffer_dx11>();
constant_buffer->create(size);
constant_buffer->binding_point = binding->second.binding;
constant_buffers_.insert({name, constant_buffer});
find = constant_buffers_.find(name);
}
@ -81,10 +105,18 @@ void shader_dx11::set_cbuffer(const char* name, void* buffer, int size)
void shader_dx11::set_texture(const char* name, std::shared_ptr<texture> in_texture)
{
const auto& binding = bindings_.find(name);
if (binding == bindings_.end())
{
spdlog::warn("texture {} not found in shader", name);
return;
}
std::shared_ptr<texture_dx11> dx11_t = std::static_pointer_cast<texture_dx11>(in_texture);
const auto find = textures_.find(name);
if (find == textures_.end())
{
dx11_t->binding_point = binding->second.binding;
textures_.insert({name, dx11_t});
return;
}
@ -92,12 +124,46 @@ void shader_dx11::set_texture(const char* name, std::shared_ptr<texture> in_text
void shader_dx11::set_render_target(const char* name, std::shared_ptr<render_target> in_render_target)
{
const auto& binding = bindings_.find(name);
if (binding == bindings_.end())
{
spdlog::warn("render target {} not found in shader", name);
return;
}
std::shared_ptr<render_target_dx11> rt = std::static_pointer_cast<render_target_dx11>(in_render_target);
const auto find = render_targets_.find(name);
if (find == render_targets_.end())
{
rt->binding_point = binding->second.binding;
render_targets_.insert({name, rt});
return;
}
}
void shader_dx11::init_shader_bindings(ID3D11ShaderReflection* reflector)
{
D3D11_SHADER_DESC shader_desc;
auto hr = reflector->GetDesc(&shader_desc);
for (unsigned int i = 0; i < shader_desc.BoundResources; ++i)
{
D3D11_SHADER_INPUT_BIND_DESC desc;
hr = reflector->GetResourceBindingDesc(i, &desc);
dx11_binding_data binding_data;
binding_data.binding = desc.BindPoint;
binding_data.type = desc.Type;
binding_data.return_type = desc.ReturnType;
binding_data.dimension = desc.Dimension;
binding_data.num_samples = desc.NumSamples;
std::string clean_name = desc.Name;
if (const size_t clean_index = clean_name.find('_'); clean_index != std::string::npos)
{
clean_name = clean_name.substr(0, clean_index);
}
bindings_.insert(std::pair(clean_name, binding_data));
}
}

View File

@ -1,6 +1,7 @@
#pragma once
#include "rhi/shader.h"
#include <d3d11.h>
#include <d3d11shader.h>
#include <map>
class render_target_dx11;
@ -8,6 +9,15 @@ class texture_dx11;
class uav_buffer_dx11;
class constant_buffer_dx11;
struct dx11_binding_data
{
D3D_SHADER_INPUT_TYPE type; // Type of resource (e.g. texture, cbuffer, etc.)
unsigned int binding; // Starting bind point
D3D_RESOURCE_RETURN_TYPE return_type; // Return type (if texture)
D3D_SRV_DIMENSION dimension; // Dimension (if texture)
unsigned int num_samples ;// Number of samples (0 if not MS texture)
};
class shader_dx11 : public shader
{
public:
@ -26,8 +36,13 @@ public:
void set_cbuffer(const char* name, void* buffer, int size) override;
void set_texture(const char* name, std::shared_ptr<texture> in_texture) override;
void set_render_target(const char* name, std::shared_ptr<render_target> in_render_target) override;
const auto& get_bindings() const { return bindings_; }
protected:
std::map<std::string, std::shared_ptr<constant_buffer_dx11>> constant_buffers_;
std::map<std::string, std::shared_ptr<texture_dx11>> textures_;
std::map<std::string, std::shared_ptr<render_target_dx11>> render_targets_;
std::map<std::string, dx11_binding_data> bindings_;
private:
void init_shader_bindings(ID3D11ShaderReflection* reflector);
};

View File

@ -18,35 +18,26 @@ void shader_gs_dx11::bind()
g_d3d11_device_context->GSGetShader(&prev_shader_, prev_class_instances_, &prev_class_instances_num_);
g_d3d11_device_context->GSSetShader(geometry_shader_, nullptr, 0);
if (const unsigned int constant_num = constant_buffers_.size(); constant_num > 0)
for (const auto& buffer : constant_buffers_ | std::views::values)
{
std::vector<ID3D11Buffer*> buffers;
buffers.reserve(constant_num);
for (const auto& buffer : constant_buffers_ | std::views::values)
{
buffers.push_back(buffer->get_resource());
}
g_d3d11_device_context->GSSetConstantBuffers(0, constant_num, buffers.data());
ref_count_ptr<ID3D11Buffer> cb = buffer->get_resource();
ID3D11Buffer* b = cb.get_reference();
g_d3d11_device_context->GSSetConstantBuffers(buffer->binding_point, 1, &b);
}
for (const auto& texture : textures_ | std::views::values)
{
const size_t render_target_num = render_targets_.size();
const size_t texture_num = textures_.size();
if (const unsigned int num = render_target_num + texture_num; num > 0)
{
std::vector<ID3D11ShaderResourceView*> srvs;
srvs.reserve(num);
for (const auto& texture : textures_ | std::views::values)
{
srvs.push_back(texture->get_srv());
}
for (const auto& render_target : render_targets_ | std::views::values)
{
srvs.push_back(render_target->get_srv());
}
g_d3d11_device_context->GSSetShaderResources(0, texture_num, srvs.data());
}
auto srv = texture->get_srv();
auto p = srv.get_reference();
g_d3d11_device_context->GSSetShaderResources(texture->binding_point, 1, &p);
}
for (const auto& texture : render_targets_ | std::views::values)
{
auto srv = texture->get_srv();
auto p = srv.get_reference();
g_d3d11_device_context->GSSetShaderResources(texture->binding_point, 1, &p);
}
}

View File

@ -18,35 +18,26 @@ void shader_hs_dx11::bind()
g_d3d11_device_context->HSGetShader(&prev_shader_, prev_class_instances_, &prev_class_instances_num_);
g_d3d11_device_context->HSSetShader(hull_shader_, nullptr, 0);
if (const unsigned int constant_num = constant_buffers_.size(); constant_num > 0)
for (const auto& buffer : constant_buffers_ | std::views::values)
{
std::vector<ID3D11Buffer*> buffers;
buffers.reserve(constant_num);
for (const auto& buffer : constant_buffers_ | std::views::values)
{
buffers.push_back(buffer->get_resource());
}
g_d3d11_device_context->HSSetConstantBuffers(0, constant_num, buffers.data());
ref_count_ptr<ID3D11Buffer> cb = buffer->get_resource();
ID3D11Buffer* b = cb.get_reference();
g_d3d11_device_context->HSSetConstantBuffers(buffer->binding_point, 1, &b);
}
for (const auto& texture : textures_ | std::views::values)
{
const size_t render_target_num = render_targets_.size();
const size_t texture_num = textures_.size();
if (const unsigned int num = render_target_num + texture_num; num > 0)
{
std::vector<ID3D11ShaderResourceView*> srvs;
srvs.reserve(num);
for (const auto& texture : textures_ | std::views::values)
{
srvs.push_back(texture->get_srv());
}
for (const auto& render_target : render_targets_ | std::views::values)
{
srvs.push_back(render_target->get_srv());
}
g_d3d11_device_context->HSSetShaderResources(0, texture_num, srvs.data());
}
auto srv = texture->get_srv();
auto p = srv.get_reference();
g_d3d11_device_context->HSSetShaderResources(texture->binding_point, 1, &p);
}
for (const auto& texture : render_targets_ | std::views::values)
{
auto srv = texture->get_srv();
auto p = srv.get_reference();
g_d3d11_device_context->HSSetShaderResources(texture->binding_point, 1, &p);
}
}

View File

@ -22,35 +22,26 @@ void shader_ps_dx11::bind()
g_d3d11_device_context->PSGetShader(&prev_shader_, prev_class_instances_, &prev_class_instances_num_);
g_d3d11_device_context->PSSetShader(pixel_shader_, nullptr, 0);
if (const unsigned int constant_num = constant_buffers_.size(); constant_num > 0)
{
std::vector<ID3D11Buffer*> buffers;
buffers.reserve(constant_num);
for (const auto& buffer : constant_buffers_ | std::views::values)
{
buffers.push_back(buffer->get_resource());
}
g_d3d11_device_context->PSSetConstantBuffers(0, constant_num, buffers.data());
}
for (const auto& buffer : constant_buffers_ | std::views::values)
{
const size_t render_target_num = render_targets_.size();
const size_t texture_num = textures_.size();
if (const unsigned int num = render_target_num + texture_num; num > 0)
{
std::vector<ID3D11ShaderResourceView*> srvs;
srvs.reserve(num);
for (const auto& texture : textures_ | std::views::values)
{
srvs.push_back(texture->get_srv());
}
for (const auto& render_target : render_targets_ | std::views::values)
{
srvs.push_back(render_target->get_srv());
}
g_d3d11_device_context->PSSetShaderResources(0, texture_num, srvs.data());
}
ref_count_ptr<ID3D11Buffer> cb = buffer->get_resource();
ID3D11Buffer* b = cb.get_reference();
g_d3d11_device_context->PSSetConstantBuffers(buffer->binding_point, 1, &b);
}
for (const auto& texture : textures_ | std::views::values)
{
auto srv = texture->get_srv();
auto p = srv.get_reference();
g_d3d11_device_context->PSSetShaderResources(texture->binding_point, 1, &p);
}
for (const auto& texture : render_targets_ | std::views::values)
{
auto srv = texture->get_srv();
auto p = srv.get_reference();
g_d3d11_device_context->PSSetShaderResources(texture->binding_point, 1, &p);
}
}

View File

@ -20,35 +20,26 @@ void shader_vs_dx11::bind()
g_d3d11_device_context->IASetInputLayout(input_layout_);
g_d3d11_device_context->VSSetShader(vertex_shader_, nullptr, 0);
if (const unsigned int constant_num = constant_buffers_.size(); constant_num > 0)
for (const auto& buffer : constant_buffers_ | std::views::values)
{
std::vector<ID3D11Buffer*> buffers;
buffers.reserve(constant_num);
for (const auto& buffer : constant_buffers_ | std::views::values)
{
buffers.push_back(buffer->get_resource());
}
g_d3d11_device_context->VSSetConstantBuffers(0, constant_num, buffers.data());
ref_count_ptr<ID3D11Buffer> cb = buffer->get_resource();
ID3D11Buffer* b = cb.get_reference();
g_d3d11_device_context->VSSetConstantBuffers(buffer->binding_point, 1, &b);
}
for (const auto& texture : textures_ | std::views::values)
{
const size_t render_target_num = render_targets_.size();
const size_t texture_num = textures_.size();
if (const unsigned int num = render_target_num + texture_num; num > 0)
{
std::vector<ID3D11ShaderResourceView*> srvs;
srvs.reserve(num);
for (const auto& texture : textures_ | std::views::values)
{
srvs.push_back(texture->get_srv());
}
for (const auto& render_target : render_targets_ | std::views::values)
{
srvs.push_back(render_target->get_srv());
}
g_d3d11_device_context->VSSetShaderResources(0, texture_num, srvs.data());
}
auto srv = texture->get_srv();
auto p = srv.get_reference();
g_d3d11_device_context->VSSetShaderResources(texture->binding_point, 1, &p);
}
for (const auto& texture : render_targets_ | std::views::values)
{
auto srv = texture->get_srv();
auto p = srv.get_reference();
g_d3d11_device_context->VSSetShaderResources(texture->binding_point, 1, &p);
}
}

View File

@ -15,8 +15,10 @@ void uav_buffer_dx11::create(const void* in_init_data, int in_count, int element
HRESULT hr;
if (in_init_data)
{
D3D11_SUBRESOURCE_DATA init_data;
D3D11_SUBRESOURCE_DATA init_data = {};
init_data.pSysMem = in_init_data;
init_data.SysMemPitch = 0;
init_data.SysMemSlicePitch = 0;
hr = g_d3d11_device->CreateBuffer(&buffer_desc, &init_data, buffer_.get_init_reference());
}
else

View File

@ -21,6 +21,8 @@ public:
ref_count_ptr<ID3D11Buffer> get_resource() { return buffer_; }
ref_count_ptr<ID3D11UnorderedAccessView> get_uav() { return uav_; }
ref_count_ptr<ID3D11ShaderResourceView> get_srv() { return srv_; }
int binding_point = -1;
private:
ref_count_ptr<ID3D11Buffer> buffer_; // nullptr when created from render target

View File

@ -11,6 +11,8 @@ public:
[[nodiscard]] ref_count_ptr<ID3D11ShaderResourceView> get_srv() { return shader_resource_view_; }
bool init_data(const unsigned char* data, int width, int height) override;
[[nodiscard]] bool is_valid() const override { return shader_resource_view_.is_valid(); }
int binding_point = -1;
private:
ref_count_ptr<ID3D11ShaderResourceView> shader_resource_view_;
};