Skip to content

Commit

Permalink
Updated spirv-cross.
Browse files Browse the repository at this point in the history
  • Loading branch information
bkaradzic committed Aug 31, 2024
1 parent 7cda7c9 commit ec4220a
Show file tree
Hide file tree
Showing 9 changed files with 360 additions and 104 deletions.
5 changes: 5 additions & 0 deletions 3rdparty/spirv-cross/spirv_cross.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1850,6 +1850,11 @@ const SmallVector<SPIRBlock::Case> &Compiler::get_case_list(const SPIRBlock &blo
const auto &type = get<SPIRType>(constant->constant_type);
width = type.width;
}
else if (const auto *op = maybe_get<SPIRConstantOp>(block.condition))
{
const auto &type = get<SPIRType>(op->basetype);
width = type.width;
}
else if (const auto *var = maybe_get<SPIRVariable>(block.condition))
{
const auto &type = get<SPIRType>(var->basetype);
Expand Down
48 changes: 48 additions & 0 deletions 3rdparty/spirv-cross/spirv_cross_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,10 @@ spvc_result spvc_compiler_options_set_uint(spvc_compiler_options options, spvc_c
case SPVC_COMPILER_OPTION_HLSL_FLATTEN_MATRIX_VERTEX_INPUT_SEMANTICS:
options->hlsl.flatten_matrix_vertex_input_semantics = value != 0;
break;

case SPVC_COMPILER_OPTION_HLSL_USE_ENTRY_POINT_NAME:
options->hlsl.use_entry_point_name = value != 0;
break;
#endif

#if SPIRV_CROSS_C_API_MSL
Expand Down Expand Up @@ -1355,6 +1359,34 @@ spvc_result spvc_compiler_msl_add_resource_binding(spvc_compiler compiler,
#endif
}

spvc_result spvc_compiler_msl_add_resource_binding_2(spvc_compiler compiler,
const spvc_msl_resource_binding_2 *binding)
{
#if SPIRV_CROSS_C_API_MSL
if (compiler->backend != SPVC_BACKEND_MSL)
{
compiler->context->report_error("MSL function used on a non-MSL backend.");
return SPVC_ERROR_INVALID_ARGUMENT;
}

auto &msl = *static_cast<CompilerMSL *>(compiler->compiler.get());
MSLResourceBinding bind;
bind.binding = binding->binding;
bind.desc_set = binding->desc_set;
bind.stage = static_cast<spv::ExecutionModel>(binding->stage);
bind.msl_buffer = binding->msl_buffer;
bind.msl_texture = binding->msl_texture;
bind.msl_sampler = binding->msl_sampler;
bind.count = binding->count;
msl.add_msl_resource_binding(bind);
return SPVC_SUCCESS;
#else
(void)binding;
compiler->context->report_error("MSL function used on a non-MSL backend.");
return SPVC_ERROR_INVALID_ARGUMENT;
#endif
}

spvc_result spvc_compiler_msl_add_dynamic_buffer(spvc_compiler compiler, unsigned desc_set, unsigned binding, unsigned index)
{
#if SPIRV_CROSS_C_API_MSL
Expand Down Expand Up @@ -2811,6 +2843,22 @@ void spvc_msl_resource_binding_init(spvc_msl_resource_binding *binding)
#endif
}

void spvc_msl_resource_binding_init_2(spvc_msl_resource_binding_2 *binding)
{
#if SPIRV_CROSS_C_API_MSL
MSLResourceBinding binding_default;
binding->desc_set = binding_default.desc_set;
binding->binding = binding_default.binding;
binding->msl_buffer = binding_default.msl_buffer;
binding->msl_texture = binding_default.msl_texture;
binding->msl_sampler = binding_default.msl_sampler;
binding->stage = static_cast<SpvExecutionModel>(binding_default.stage);
binding->count = 0;
#else
memset(binding, 0, sizeof(*binding));
#endif
}

void spvc_hlsl_resource_binding_init(spvc_hlsl_resource_binding *binding)
{
#if SPIRV_CROSS_C_API_HLSL
Expand Down
23 changes: 21 additions & 2 deletions 3rdparty/spirv-cross/spirv_cross_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ extern "C" {
/* Bumped if ABI or API breaks backwards compatibility. */
#define SPVC_C_API_VERSION_MAJOR 0
/* Bumped if APIs or enumerations are added in a backwards compatible way. */
#define SPVC_C_API_VERSION_MINOR 60
#define SPVC_C_API_VERSION_MINOR 62
/* Bumped if internal implementation details change. */
#define SPVC_C_API_VERSION_PATCH 0

Expand Down Expand Up @@ -380,7 +380,8 @@ typedef struct spvc_msl_shader_interface_var_2
*/
SPVC_PUBLIC_API void spvc_msl_shader_interface_var_init_2(spvc_msl_shader_interface_var_2 *var);

/* Maps to C++ API. */
/* Maps to C++ API.
* Deprecated. Use spvc_msl_resource_binding_2. */
typedef struct spvc_msl_resource_binding
{
SpvExecutionModel stage;
Expand All @@ -391,11 +392,24 @@ typedef struct spvc_msl_resource_binding
unsigned msl_sampler;
} spvc_msl_resource_binding;

typedef struct spvc_msl_resource_binding_2
{
SpvExecutionModel stage;
unsigned desc_set;
unsigned binding;
unsigned count;
unsigned msl_buffer;
unsigned msl_texture;
unsigned msl_sampler;
} spvc_msl_resource_binding_2;

/*
* Initializes the resource binding struct.
* The defaults are non-zero.
* Deprecated: Use spvc_msl_resource_binding_init_2.
*/
SPVC_PUBLIC_API void spvc_msl_resource_binding_init(spvc_msl_resource_binding *binding);
SPVC_PUBLIC_API void spvc_msl_resource_binding_init_2(spvc_msl_resource_binding_2 *binding);

#define SPVC_MSL_PUSH_CONSTANT_DESC_SET (~(0u))
#define SPVC_MSL_PUSH_CONSTANT_BINDING (0)
Expand Down Expand Up @@ -730,6 +744,8 @@ typedef enum spvc_compiler_option
SPVC_COMPILER_OPTION_MSL_AGX_MANUAL_CUBE_GRAD_FIXUP = 88 | SPVC_COMPILER_OPTION_MSL_BIT,
SPVC_COMPILER_OPTION_MSL_FORCE_FRAGMENT_WITH_SIDE_EFFECTS_EXECUTION = 89 | SPVC_COMPILER_OPTION_MSL_BIT,

SPVC_COMPILER_OPTION_HLSL_USE_ENTRY_POINT_NAME = 90 | SPVC_COMPILER_OPTION_HLSL_BIT,

SPVC_COMPILER_OPTION_INT_MAX = 0x7fffffff
} spvc_compiler_option;

Expand Down Expand Up @@ -836,8 +852,11 @@ SPVC_PUBLIC_API spvc_bool spvc_compiler_msl_needs_patch_output_buffer(spvc_compi
SPVC_PUBLIC_API spvc_bool spvc_compiler_msl_needs_input_threadgroup_mem(spvc_compiler compiler);
SPVC_PUBLIC_API spvc_result spvc_compiler_msl_add_vertex_attribute(spvc_compiler compiler,
const spvc_msl_vertex_attribute *attrs);
/* Deprecated; use spvc_compiler_msl_add_resource_binding_2(). */
SPVC_PUBLIC_API spvc_result spvc_compiler_msl_add_resource_binding(spvc_compiler compiler,
const spvc_msl_resource_binding *binding);
SPVC_PUBLIC_API spvc_result spvc_compiler_msl_add_resource_binding_2(spvc_compiler compiler,
const spvc_msl_resource_binding_2 *binding);
/* Deprecated; use spvc_compiler_msl_add_shader_input_2(). */
SPVC_PUBLIC_API spvc_result spvc_compiler_msl_add_shader_input(spvc_compiler compiler,
const spvc_msl_shader_interface_var *input);
Expand Down
2 changes: 2 additions & 0 deletions 3rdparty/spirv-cross/spirv_cross_parsed_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,8 @@ uint32_t ParsedIR::get_member_decoration(TypeID id, uint32_t index, Decoration d
return dec.stream;
case DecorationSpecId:
return dec.spec_id;
case DecorationMatrixStride:
return dec.matrix_stride;
case DecorationIndex:
return dec.index;
default:
Expand Down
67 changes: 63 additions & 4 deletions 3rdparty/spirv-cross/spirv_glsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5213,7 +5213,8 @@ string CompilerGLSL::to_enclosed_unpacked_expression(uint32_t id, bool register_
string CompilerGLSL::to_dereferenced_expression(uint32_t id, bool register_expression_read)
{
auto &type = expression_type(id);
if (type.pointer && should_dereference(id))

if (is_pointer(type) && should_dereference(id))
return dereference_expression(type, to_enclosed_expression(id, register_expression_read));
else
return to_expression(id, register_expression_read);
Expand All @@ -5222,7 +5223,7 @@ string CompilerGLSL::to_dereferenced_expression(uint32_t id, bool register_expre
string CompilerGLSL::to_pointer_expression(uint32_t id, bool register_expression_read)
{
auto &type = expression_type(id);
if (type.pointer && expression_is_lvalue(id) && !should_dereference(id))
if (is_pointer(type) && expression_is_lvalue(id) && !should_dereference(id))
return address_of_expression(to_enclosed_expression(id, register_expression_read));
else
return to_unpacked_expression(id, register_expression_read);
Expand All @@ -5231,7 +5232,7 @@ string CompilerGLSL::to_pointer_expression(uint32_t id, bool register_expression
string CompilerGLSL::to_enclosed_pointer_expression(uint32_t id, bool register_expression_read)
{
auto &type = expression_type(id);
if (type.pointer && expression_is_lvalue(id) && !should_dereference(id))
if (is_pointer(type) && expression_is_lvalue(id) && !should_dereference(id))
return address_of_expression(to_enclosed_expression(id, register_expression_read));
else
return to_enclosed_unpacked_expression(id, register_expression_read);
Expand Down Expand Up @@ -10286,7 +10287,40 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
}
else
{
append_index(index, is_literal, true);
if (flags & ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT)
{
SPIRType tmp_type(OpTypeInt);
tmp_type.basetype = SPIRType::UInt64;
tmp_type.width = 64;
tmp_type.vecsize = 1;
tmp_type.columns = 1;

TypeID ptr_type_id = expression_type_id(base);
const SPIRType &ptr_type = get<SPIRType>(ptr_type_id);
const SPIRType &pointee_type = get_pointee_type(ptr_type);

// This only runs in native pointer backends.
// Can replace reinterpret_cast with a backend string if ever needed.
// We expect this to count as a de-reference.
// This leaks some MSL details, but feels slightly overkill to
// add yet another virtual interface just for this.
auto intptr_expr = join("reinterpret_cast<", type_to_glsl(tmp_type), ">(", expr, ")");
intptr_expr += join(" + ", to_enclosed_unpacked_expression(index), " * ",
get_decoration(ptr_type_id, DecorationArrayStride));

if (flags & ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT)
{
is_packed = true;
expr = join("*reinterpret_cast<device packed_", type_to_glsl(pointee_type),
" *>(", intptr_expr, ")");
}
else
{
expr = join("*reinterpret_cast<", type_to_glsl(ptr_type), ">(", intptr_expr, ")");
}
}
else
append_index(index, is_literal, true);
}

if (type->basetype == SPIRType::ControlPointArray)
Expand Down Expand Up @@ -10706,6 +10740,11 @@ string CompilerGLSL::to_flattened_struct_member(const string &basename, const SP
return ret;
}

uint32_t CompilerGLSL::get_physical_type_stride(const SPIRType &) const
{
SPIRV_CROSS_THROW("Invalid to call get_physical_type_stride on a backend without native pointer support.");
}

string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32_t count, const SPIRType &target_type,
AccessChainMeta *meta, bool ptr_chain)
{
Expand Down Expand Up @@ -10755,7 +10794,27 @@ string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32
{
AccessChainFlags flags = ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT;
if (ptr_chain)
{
flags |= ACCESS_CHAIN_PTR_CHAIN_BIT;
// PtrAccessChain could get complicated.
TypeID type_id = expression_type_id(base);
if (backend.native_pointers && has_decoration(type_id, DecorationArrayStride))
{
// If there is a mismatch we have to go via 64-bit pointer arithmetic :'(
// Using packed hacks only gets us so far, and is not designed to deal with pointer to
// random values. It works for structs though.
auto &pointee_type = get_pointee_type(get<SPIRType>(type_id));
uint32_t physical_stride = get_physical_type_stride(pointee_type);
uint32_t requested_stride = get_decoration(type_id, DecorationArrayStride);
if (physical_stride != requested_stride)
{
flags |= ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT;
if (is_vector(pointee_type))
flags |= ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT;
}
}
}

return access_chain_internal(base, indices, count, flags, meta);
}
}
Expand Down
8 changes: 7 additions & 1 deletion 3rdparty/spirv-cross/spirv_glsl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ enum AccessChainFlagBits
ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT = 1 << 3,
ACCESS_CHAIN_LITERAL_MSB_FORCE_ID = 1 << 4,
ACCESS_CHAIN_FLATTEN_ALL_MEMBERS_BIT = 1 << 5,
ACCESS_CHAIN_FORCE_COMPOSITE_BIT = 1 << 6
ACCESS_CHAIN_FORCE_COMPOSITE_BIT = 1 << 6,
ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT = 1 << 7,
ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT = 1 << 8
};
typedef uint32_t AccessChainFlags;

Expand Down Expand Up @@ -753,6 +755,10 @@ class CompilerGLSL : public Compiler
std::string access_chain_internal(uint32_t base, const uint32_t *indices, uint32_t count, AccessChainFlags flags,
AccessChainMeta *meta);

// Only meaningful on backends with physical pointer support ala MSL.
// Relevant for PtrAccessChain / BDA.
virtual uint32_t get_physical_type_stride(const SPIRType &type) const;

spv::StorageClass get_expression_effective_storage_class(uint32_t ptr);
virtual bool access_chain_needs_stage_io_builtin_translation(uint32_t base);

Expand Down
57 changes: 48 additions & 9 deletions 3rdparty/spirv-cross/spirv_hlsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -849,9 +849,23 @@ void CompilerHLSL::emit_builtin_inputs_in_struct()
case BuiltInSubgroupLeMask:
case BuiltInSubgroupGtMask:
case BuiltInSubgroupGeMask:
// Handled specially.
break;

case BuiltInBaseVertex:
if (hlsl_options.shader_model >= 68)
{
type = "uint";
semantic = "SV_StartVertexLocation";
}
break;

case BuiltInBaseInstance:
// Handled specially.
if (hlsl_options.shader_model >= 68)
{
type = "uint";
semantic = "SV_StartInstanceLocation";
}
break;

case BuiltInHelperInvocation:
Expand Down Expand Up @@ -1231,7 +1245,7 @@ void CompilerHLSL::emit_builtin_variables()
case BuiltInVertexIndex:
case BuiltInInstanceIndex:
type = "int";
if (hlsl_options.support_nonzero_base_vertex_base_instance)
if (hlsl_options.support_nonzero_base_vertex_base_instance || hlsl_options.shader_model >= 68)
base_vertex_info.used = true;
break;

Expand Down Expand Up @@ -1353,7 +1367,7 @@ void CompilerHLSL::emit_builtin_variables()
}
});

if (base_vertex_info.used)
if (base_vertex_info.used && hlsl_options.shader_model < 68)
{
string binding_info;
if (base_vertex_info.explicit_binding)
Expand Down Expand Up @@ -3136,23 +3150,39 @@ void CompilerHLSL::emit_hlsl_entry_point()
case BuiltInVertexIndex:
case BuiltInInstanceIndex:
// D3D semantics are uint, but shader wants int.
if (hlsl_options.support_nonzero_base_vertex_base_instance)
if (hlsl_options.support_nonzero_base_vertex_base_instance || hlsl_options.shader_model >= 68)
{
if (static_cast<BuiltIn>(i) == BuiltInInstanceIndex)
statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseInstance;");
if (hlsl_options.shader_model >= 68)
{
if (static_cast<BuiltIn>(i) == BuiltInInstanceIndex)
statement(builtin, " = int(stage_input.", builtin, " + stage_input.gl_BaseInstanceARB);");
else
statement(builtin, " = int(stage_input.", builtin, " + stage_input.gl_BaseVertexARB);");
}
else
statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseVertex;");
{
if (static_cast<BuiltIn>(i) == BuiltInInstanceIndex)
statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseInstance;");
else
statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseVertex;");
}
}
else
statement(builtin, " = int(stage_input.", builtin, ");");
break;

case BuiltInBaseVertex:
statement(builtin, " = SPIRV_Cross_BaseVertex;");
if (hlsl_options.shader_model >= 68)
statement(builtin, " = stage_input.gl_BaseVertexARB;");
else
statement(builtin, " = SPIRV_Cross_BaseVertex;");
break;

case BuiltInBaseInstance:
statement(builtin, " = SPIRV_Cross_BaseInstance;");
if (hlsl_options.shader_model >= 68)
statement(builtin, " = stage_input.gl_BaseInstanceARB;");
else
statement(builtin, " = SPIRV_Cross_BaseInstance;");
break;

case BuiltInInstanceId:
Expand Down Expand Up @@ -6714,6 +6744,15 @@ string CompilerHLSL::compile()
if (need_subpass_input)
active_input_builtins.set(BuiltInFragCoord);

// Need to offset by BaseVertex/BaseInstance in SM 6.8+.
if (hlsl_options.shader_model >= 68)
{
if (active_input_builtins.get(BuiltInVertexIndex))
active_input_builtins.set(BuiltInBaseVertex);
if (active_input_builtins.get(BuiltInInstanceIndex))
active_input_builtins.set(BuiltInBaseInstance);
}

uint32_t pass_count = 0;
do
{
Expand Down
Loading

0 comments on commit ec4220a

Please sign in to comment.