From ec4220ae440454f85488a4beb8098fc8f0d32b93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=91=D1=80=D0=B0=D0=BD=D0=B8=D0=BC=D0=B8=D1=80=20=D0=9A?= =?UTF-8?q?=D0=B0=D1=80=D0=B0=D1=9F=D0=B8=D1=9B?= Date: Fri, 30 Aug 2024 20:21:18 -0700 Subject: [PATCH] Updated spirv-cross. --- 3rdparty/spirv-cross/spirv_cross.cpp | 5 + 3rdparty/spirv-cross/spirv_cross_c.cpp | 48 ++++ 3rdparty/spirv-cross/spirv_cross_c.h | 23 +- .../spirv-cross/spirv_cross_parsed_ir.cpp | 2 + 3rdparty/spirv-cross/spirv_glsl.cpp | 67 ++++- 3rdparty/spirv-cross/spirv_glsl.hpp | 8 +- 3rdparty/spirv-cross/spirv_hlsl.cpp | 57 +++- 3rdparty/spirv-cross/spirv_msl.cpp | 252 ++++++++++++------ 3rdparty/spirv-cross/spirv_msl.hpp | 2 + 9 files changed, 360 insertions(+), 104 deletions(-) diff --git a/3rdparty/spirv-cross/spirv_cross.cpp b/3rdparty/spirv-cross/spirv_cross.cpp index 8c3e7d3812..5471b35150 100644 --- a/3rdparty/spirv-cross/spirv_cross.cpp +++ b/3rdparty/spirv-cross/spirv_cross.cpp @@ -1850,6 +1850,11 @@ const SmallVector &Compiler::get_case_list(const SPIRBlock &blo const auto &type = get(constant->constant_type); width = type.width; } + else if (const auto *op = maybe_get(block.condition)) + { + const auto &type = get(op->basetype); + width = type.width; + } else if (const auto *var = maybe_get(block.condition)) { const auto &type = get(var->basetype); diff --git a/3rdparty/spirv-cross/spirv_cross_c.cpp b/3rdparty/spirv-cross/spirv_cross_c.cpp index e0cc68ca98..b506ceeeb5 100644 --- a/3rdparty/spirv-cross/spirv_cross_c.cpp +++ b/3rdparty/spirv-cross/spirv_cross_c.cpp @@ -516,6 +516,10 @@ spvc_result spvc_compiler_options_set_uint(spvc_compiler_options options, spvc_c case SPVC_COMPILER_OPTION_HLSL_FLATTEN_MATRIX_VERTEX_INPUT_SEMANTICS: options->hlsl.flatten_matrix_vertex_input_semantics = value != 0; break; + + case SPVC_COMPILER_OPTION_HLSL_USE_ENTRY_POINT_NAME: + options->hlsl.use_entry_point_name = value != 0; + break; #endif #if SPIRV_CROSS_C_API_MSL @@ -1355,6 +1359,34 @@ spvc_result spvc_compiler_msl_add_resource_binding(spvc_compiler compiler, #endif } +spvc_result spvc_compiler_msl_add_resource_binding_2(spvc_compiler compiler, + const spvc_msl_resource_binding_2 *binding) +{ +#if SPIRV_CROSS_C_API_MSL + if (compiler->backend != SPVC_BACKEND_MSL) + { + compiler->context->report_error("MSL function used on a non-MSL backend."); + return SPVC_ERROR_INVALID_ARGUMENT; + } + + auto &msl = *static_cast(compiler->compiler.get()); + MSLResourceBinding bind; + bind.binding = binding->binding; + bind.desc_set = binding->desc_set; + bind.stage = static_cast(binding->stage); + bind.msl_buffer = binding->msl_buffer; + bind.msl_texture = binding->msl_texture; + bind.msl_sampler = binding->msl_sampler; + bind.count = binding->count; + msl.add_msl_resource_binding(bind); + return SPVC_SUCCESS; +#else + (void)binding; + compiler->context->report_error("MSL function used on a non-MSL backend."); + return SPVC_ERROR_INVALID_ARGUMENT; +#endif +} + spvc_result spvc_compiler_msl_add_dynamic_buffer(spvc_compiler compiler, unsigned desc_set, unsigned binding, unsigned index) { #if SPIRV_CROSS_C_API_MSL @@ -2811,6 +2843,22 @@ void spvc_msl_resource_binding_init(spvc_msl_resource_binding *binding) #endif } +void spvc_msl_resource_binding_init_2(spvc_msl_resource_binding_2 *binding) +{ +#if SPIRV_CROSS_C_API_MSL + MSLResourceBinding binding_default; + binding->desc_set = binding_default.desc_set; + binding->binding = binding_default.binding; + binding->msl_buffer = binding_default.msl_buffer; + binding->msl_texture = binding_default.msl_texture; + binding->msl_sampler = binding_default.msl_sampler; + binding->stage = static_cast(binding_default.stage); + binding->count = 0; +#else + memset(binding, 0, sizeof(*binding)); +#endif +} + void spvc_hlsl_resource_binding_init(spvc_hlsl_resource_binding *binding) { #if SPIRV_CROSS_C_API_HLSL diff --git a/3rdparty/spirv-cross/spirv_cross_c.h b/3rdparty/spirv-cross/spirv_cross_c.h index acae93558d..a25b0b5b9e 100644 --- a/3rdparty/spirv-cross/spirv_cross_c.h +++ b/3rdparty/spirv-cross/spirv_cross_c.h @@ -40,7 +40,7 @@ extern "C" { /* Bumped if ABI or API breaks backwards compatibility. */ #define SPVC_C_API_VERSION_MAJOR 0 /* Bumped if APIs or enumerations are added in a backwards compatible way. */ -#define SPVC_C_API_VERSION_MINOR 60 +#define SPVC_C_API_VERSION_MINOR 62 /* Bumped if internal implementation details change. */ #define SPVC_C_API_VERSION_PATCH 0 @@ -380,7 +380,8 @@ typedef struct spvc_msl_shader_interface_var_2 */ SPVC_PUBLIC_API void spvc_msl_shader_interface_var_init_2(spvc_msl_shader_interface_var_2 *var); -/* Maps to C++ API. */ +/* Maps to C++ API. + * Deprecated. Use spvc_msl_resource_binding_2. */ typedef struct spvc_msl_resource_binding { SpvExecutionModel stage; @@ -391,11 +392,24 @@ typedef struct spvc_msl_resource_binding unsigned msl_sampler; } spvc_msl_resource_binding; +typedef struct spvc_msl_resource_binding_2 +{ + SpvExecutionModel stage; + unsigned desc_set; + unsigned binding; + unsigned count; + unsigned msl_buffer; + unsigned msl_texture; + unsigned msl_sampler; +} spvc_msl_resource_binding_2; + /* * Initializes the resource binding struct. * The defaults are non-zero. + * Deprecated: Use spvc_msl_resource_binding_init_2. */ SPVC_PUBLIC_API void spvc_msl_resource_binding_init(spvc_msl_resource_binding *binding); +SPVC_PUBLIC_API void spvc_msl_resource_binding_init_2(spvc_msl_resource_binding_2 *binding); #define SPVC_MSL_PUSH_CONSTANT_DESC_SET (~(0u)) #define SPVC_MSL_PUSH_CONSTANT_BINDING (0) @@ -730,6 +744,8 @@ typedef enum spvc_compiler_option SPVC_COMPILER_OPTION_MSL_AGX_MANUAL_CUBE_GRAD_FIXUP = 88 | SPVC_COMPILER_OPTION_MSL_BIT, SPVC_COMPILER_OPTION_MSL_FORCE_FRAGMENT_WITH_SIDE_EFFECTS_EXECUTION = 89 | SPVC_COMPILER_OPTION_MSL_BIT, + SPVC_COMPILER_OPTION_HLSL_USE_ENTRY_POINT_NAME = 90 | SPVC_COMPILER_OPTION_HLSL_BIT, + SPVC_COMPILER_OPTION_INT_MAX = 0x7fffffff } spvc_compiler_option; @@ -836,8 +852,11 @@ SPVC_PUBLIC_API spvc_bool spvc_compiler_msl_needs_patch_output_buffer(spvc_compi SPVC_PUBLIC_API spvc_bool spvc_compiler_msl_needs_input_threadgroup_mem(spvc_compiler compiler); SPVC_PUBLIC_API spvc_result spvc_compiler_msl_add_vertex_attribute(spvc_compiler compiler, const spvc_msl_vertex_attribute *attrs); +/* Deprecated; use spvc_compiler_msl_add_resource_binding_2(). */ SPVC_PUBLIC_API spvc_result spvc_compiler_msl_add_resource_binding(spvc_compiler compiler, const spvc_msl_resource_binding *binding); +SPVC_PUBLIC_API spvc_result spvc_compiler_msl_add_resource_binding_2(spvc_compiler compiler, + const spvc_msl_resource_binding_2 *binding); /* Deprecated; use spvc_compiler_msl_add_shader_input_2(). */ SPVC_PUBLIC_API spvc_result spvc_compiler_msl_add_shader_input(spvc_compiler compiler, const spvc_msl_shader_interface_var *input); diff --git a/3rdparty/spirv-cross/spirv_cross_parsed_ir.cpp b/3rdparty/spirv-cross/spirv_cross_parsed_ir.cpp index c6ddb6a45e..3072cd8abb 100644 --- a/3rdparty/spirv-cross/spirv_cross_parsed_ir.cpp +++ b/3rdparty/spirv-cross/spirv_cross_parsed_ir.cpp @@ -783,6 +783,8 @@ uint32_t ParsedIR::get_member_decoration(TypeID id, uint32_t index, Decoration d return dec.stream; case DecorationSpecId: return dec.spec_id; + case DecorationMatrixStride: + return dec.matrix_stride; case DecorationIndex: return dec.index; default: diff --git a/3rdparty/spirv-cross/spirv_glsl.cpp b/3rdparty/spirv-cross/spirv_glsl.cpp index 3f13febcce..fad1132e82 100644 --- a/3rdparty/spirv-cross/spirv_glsl.cpp +++ b/3rdparty/spirv-cross/spirv_glsl.cpp @@ -5213,7 +5213,8 @@ string CompilerGLSL::to_enclosed_unpacked_expression(uint32_t id, bool register_ string CompilerGLSL::to_dereferenced_expression(uint32_t id, bool register_expression_read) { auto &type = expression_type(id); - if (type.pointer && should_dereference(id)) + + if (is_pointer(type) && should_dereference(id)) return dereference_expression(type, to_enclosed_expression(id, register_expression_read)); else return to_expression(id, register_expression_read); @@ -5222,7 +5223,7 @@ string CompilerGLSL::to_dereferenced_expression(uint32_t id, bool register_expre string CompilerGLSL::to_pointer_expression(uint32_t id, bool register_expression_read) { auto &type = expression_type(id); - if (type.pointer && expression_is_lvalue(id) && !should_dereference(id)) + if (is_pointer(type) && expression_is_lvalue(id) && !should_dereference(id)) return address_of_expression(to_enclosed_expression(id, register_expression_read)); else return to_unpacked_expression(id, register_expression_read); @@ -5231,7 +5232,7 @@ string CompilerGLSL::to_pointer_expression(uint32_t id, bool register_expression string CompilerGLSL::to_enclosed_pointer_expression(uint32_t id, bool register_expression_read) { auto &type = expression_type(id); - if (type.pointer && expression_is_lvalue(id) && !should_dereference(id)) + if (is_pointer(type) && expression_is_lvalue(id) && !should_dereference(id)) return address_of_expression(to_enclosed_expression(id, register_expression_read)); else return to_enclosed_unpacked_expression(id, register_expression_read); @@ -10286,7 +10287,40 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice } else { - append_index(index, is_literal, true); + if (flags & ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT) + { + SPIRType tmp_type(OpTypeInt); + tmp_type.basetype = SPIRType::UInt64; + tmp_type.width = 64; + tmp_type.vecsize = 1; + tmp_type.columns = 1; + + TypeID ptr_type_id = expression_type_id(base); + const SPIRType &ptr_type = get(ptr_type_id); + const SPIRType &pointee_type = get_pointee_type(ptr_type); + + // This only runs in native pointer backends. + // Can replace reinterpret_cast with a backend string if ever needed. + // We expect this to count as a de-reference. + // This leaks some MSL details, but feels slightly overkill to + // add yet another virtual interface just for this. + auto intptr_expr = join("reinterpret_cast<", type_to_glsl(tmp_type), ">(", expr, ")"); + intptr_expr += join(" + ", to_enclosed_unpacked_expression(index), " * ", + get_decoration(ptr_type_id, DecorationArrayStride)); + + if (flags & ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT) + { + is_packed = true; + expr = join("*reinterpret_cast(", intptr_expr, ")"); + } + else + { + expr = join("*reinterpret_cast<", type_to_glsl(ptr_type), ">(", intptr_expr, ")"); + } + } + else + append_index(index, is_literal, true); } if (type->basetype == SPIRType::ControlPointArray) @@ -10706,6 +10740,11 @@ string CompilerGLSL::to_flattened_struct_member(const string &basename, const SP return ret; } +uint32_t CompilerGLSL::get_physical_type_stride(const SPIRType &) const +{ + SPIRV_CROSS_THROW("Invalid to call get_physical_type_stride on a backend without native pointer support."); +} + string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32_t count, const SPIRType &target_type, AccessChainMeta *meta, bool ptr_chain) { @@ -10755,7 +10794,27 @@ string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32 { AccessChainFlags flags = ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT; if (ptr_chain) + { flags |= ACCESS_CHAIN_PTR_CHAIN_BIT; + // PtrAccessChain could get complicated. + TypeID type_id = expression_type_id(base); + if (backend.native_pointers && has_decoration(type_id, DecorationArrayStride)) + { + // If there is a mismatch we have to go via 64-bit pointer arithmetic :'( + // Using packed hacks only gets us so far, and is not designed to deal with pointer to + // random values. It works for structs though. + auto &pointee_type = get_pointee_type(get(type_id)); + uint32_t physical_stride = get_physical_type_stride(pointee_type); + uint32_t requested_stride = get_decoration(type_id, DecorationArrayStride); + if (physical_stride != requested_stride) + { + flags |= ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT; + if (is_vector(pointee_type)) + flags |= ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT; + } + } + } + return access_chain_internal(base, indices, count, flags, meta); } } diff --git a/3rdparty/spirv-cross/spirv_glsl.hpp b/3rdparty/spirv-cross/spirv_glsl.hpp index f3e545e9f5..8a00263234 100644 --- a/3rdparty/spirv-cross/spirv_glsl.hpp +++ b/3rdparty/spirv-cross/spirv_glsl.hpp @@ -66,7 +66,9 @@ enum AccessChainFlagBits ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT = 1 << 3, ACCESS_CHAIN_LITERAL_MSB_FORCE_ID = 1 << 4, ACCESS_CHAIN_FLATTEN_ALL_MEMBERS_BIT = 1 << 5, - ACCESS_CHAIN_FORCE_COMPOSITE_BIT = 1 << 6 + ACCESS_CHAIN_FORCE_COMPOSITE_BIT = 1 << 6, + ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT = 1 << 7, + ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT = 1 << 8 }; typedef uint32_t AccessChainFlags; @@ -753,6 +755,10 @@ class CompilerGLSL : public Compiler std::string access_chain_internal(uint32_t base, const uint32_t *indices, uint32_t count, AccessChainFlags flags, AccessChainMeta *meta); + // Only meaningful on backends with physical pointer support ala MSL. + // Relevant for PtrAccessChain / BDA. + virtual uint32_t get_physical_type_stride(const SPIRType &type) const; + spv::StorageClass get_expression_effective_storage_class(uint32_t ptr); virtual bool access_chain_needs_stage_io_builtin_translation(uint32_t base); diff --git a/3rdparty/spirv-cross/spirv_hlsl.cpp b/3rdparty/spirv-cross/spirv_hlsl.cpp index ac1d262af4..46fc176886 100644 --- a/3rdparty/spirv-cross/spirv_hlsl.cpp +++ b/3rdparty/spirv-cross/spirv_hlsl.cpp @@ -849,9 +849,23 @@ void CompilerHLSL::emit_builtin_inputs_in_struct() case BuiltInSubgroupLeMask: case BuiltInSubgroupGtMask: case BuiltInSubgroupGeMask: + // Handled specially. + break; + case BuiltInBaseVertex: + if (hlsl_options.shader_model >= 68) + { + type = "uint"; + semantic = "SV_StartVertexLocation"; + } + break; + case BuiltInBaseInstance: - // Handled specially. + if (hlsl_options.shader_model >= 68) + { + type = "uint"; + semantic = "SV_StartInstanceLocation"; + } break; case BuiltInHelperInvocation: @@ -1231,7 +1245,7 @@ void CompilerHLSL::emit_builtin_variables() case BuiltInVertexIndex: case BuiltInInstanceIndex: type = "int"; - if (hlsl_options.support_nonzero_base_vertex_base_instance) + if (hlsl_options.support_nonzero_base_vertex_base_instance || hlsl_options.shader_model >= 68) base_vertex_info.used = true; break; @@ -1353,7 +1367,7 @@ void CompilerHLSL::emit_builtin_variables() } }); - if (base_vertex_info.used) + if (base_vertex_info.used && hlsl_options.shader_model < 68) { string binding_info; if (base_vertex_info.explicit_binding) @@ -3136,23 +3150,39 @@ void CompilerHLSL::emit_hlsl_entry_point() case BuiltInVertexIndex: case BuiltInInstanceIndex: // D3D semantics are uint, but shader wants int. - if (hlsl_options.support_nonzero_base_vertex_base_instance) + if (hlsl_options.support_nonzero_base_vertex_base_instance || hlsl_options.shader_model >= 68) { - if (static_cast(i) == BuiltInInstanceIndex) - statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseInstance;"); + if (hlsl_options.shader_model >= 68) + { + if (static_cast(i) == BuiltInInstanceIndex) + statement(builtin, " = int(stage_input.", builtin, " + stage_input.gl_BaseInstanceARB);"); + else + statement(builtin, " = int(stage_input.", builtin, " + stage_input.gl_BaseVertexARB);"); + } else - statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseVertex;"); + { + if (static_cast(i) == BuiltInInstanceIndex) + statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseInstance;"); + else + statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseVertex;"); + } } else statement(builtin, " = int(stage_input.", builtin, ");"); break; case BuiltInBaseVertex: - statement(builtin, " = SPIRV_Cross_BaseVertex;"); + if (hlsl_options.shader_model >= 68) + statement(builtin, " = stage_input.gl_BaseVertexARB;"); + else + statement(builtin, " = SPIRV_Cross_BaseVertex;"); break; case BuiltInBaseInstance: - statement(builtin, " = SPIRV_Cross_BaseInstance;"); + if (hlsl_options.shader_model >= 68) + statement(builtin, " = stage_input.gl_BaseInstanceARB;"); + else + statement(builtin, " = SPIRV_Cross_BaseInstance;"); break; case BuiltInInstanceId: @@ -6714,6 +6744,15 @@ string CompilerHLSL::compile() if (need_subpass_input) active_input_builtins.set(BuiltInFragCoord); + // Need to offset by BaseVertex/BaseInstance in SM 6.8+. + if (hlsl_options.shader_model >= 68) + { + if (active_input_builtins.get(BuiltInVertexIndex)) + active_input_builtins.set(BuiltInBaseVertex); + if (active_input_builtins.get(BuiltInInstanceIndex)) + active_input_builtins.set(BuiltInBaseInstance); + } + uint32_t pass_count = 0; do { diff --git a/3rdparty/spirv-cross/spirv_msl.cpp b/3rdparty/spirv-cross/spirv_msl.cpp index 562bb62a52..acc66eef9a 100644 --- a/3rdparty/spirv-cross/spirv_msl.cpp +++ b/3rdparty/spirv-cross/spirv_msl.cpp @@ -1361,14 +1361,14 @@ void CompilerMSL::emit_entry_point_declarations() if (is_array(type)) { - if (!type.array[type.array.size() - 1]) - SPIRV_CROSS_THROW("Runtime arrays with dynamic offsets are not supported yet."); - is_using_builtin_array = true; statement(get_argument_address_space(var), " ", type_to_glsl(type), "* ", to_restrict(var_id, true), name, type_to_array_glsl(type, var_id), " ="); - uint32_t array_size = to_array_size_literal(type); + uint32_t array_size = get_resource_array_size(type, var_id); + if (array_size == 0) + SPIRV_CROSS_THROW("Size of runtime array with dynamic offset could not be determined from resource bindings."); + begin_scope(); for (uint32_t i = 0; i < array_size; i++) @@ -1576,8 +1576,7 @@ string CompilerMSL::compile() preprocess_op_codes(); build_implicit_builtins(); - if (needs_manual_helper_invocation_updates() && - (active_input_builtins.get(BuiltInHelperInvocation) || needs_helper_invocation)) + if (needs_manual_helper_invocation_updates() && needs_helper_invocation) { string builtin_helper_invocation = builtin_to_glsl(BuiltInHelperInvocation, StorageClassInput); string discard_expr = join(builtin_helper_invocation, " = true, discard_fragment()"); @@ -1721,7 +1720,7 @@ void CompilerMSL::preprocess_op_codes() (is_sample_rate() && (active_input_builtins.get(BuiltInFragCoord) || (need_subpass_input_ms && !msl_options.use_framebuffer_fetch_subpasses)))) needs_sample_id = true; - if (preproc.needs_helper_invocation) + if (preproc.needs_helper_invocation || active_input_builtins.get(BuiltInHelperInvocation)) needs_helper_invocation = true; // OpKill is removed by the parser, so we need to identify those by inspecting @@ -2058,8 +2057,7 @@ void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std:: } case OpDemoteToHelperInvocation: - if (needs_manual_helper_invocation_updates() && - (active_input_builtins.get(BuiltInHelperInvocation) || needs_helper_invocation)) + if (needs_manual_helper_invocation_updates() && needs_helper_invocation) added_arg_ids.insert(builtin_helper_invocation_id); break; @@ -2112,7 +2110,7 @@ void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std:: } if (needs_manual_helper_invocation_updates() && b.terminator == SPIRBlock::Kill && - (active_input_builtins.get(BuiltInHelperInvocation) || needs_helper_invocation)) + needs_helper_invocation) added_arg_ids.insert(builtin_helper_invocation_id); // TODO: Add all other operations which can affect memory. @@ -4803,7 +4801,7 @@ bool CompilerMSL::validate_member_packing_rules_msl(const SPIRType &type, uint32 return false; } - if (!mbr_type.array.empty()) + if (is_array(mbr_type)) { // If we have an array type, array stride must match exactly with SPIR-V. @@ -5615,6 +5613,10 @@ void CompilerMSL::emit_custom_templates() // otherwise they will cause problems when linked together in a single Metallib. void CompilerMSL::emit_custom_functions() { + // Use when outputting overloaded functions to cover different address spaces. + static const char *texture_addr_spaces[] = { "device", "constant", "thread" }; + static uint32_t texture_addr_space_count = sizeof(texture_addr_spaces) / sizeof(char*); + if (spv_function_implementations.count(SPVFuncImplArrayCopyMultidim)) spv_function_implementations.insert(SPVFuncImplArrayCopy); @@ -6264,54 +6266,62 @@ void CompilerMSL::emit_custom_functions() break; case SPVFuncImplGatherConstOffsets: - statement("// Wrapper function that processes a texture gather with a constant offset array."); - statement("template class Tex, " - "typename Toff, typename... Tp>"); - statement("inline vec spvGatherConstOffsets(const thread Tex& t, sampler s, " - "Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c)"); - begin_scope(); - statement("vec rslts[4];"); - statement("for (uint i = 0; i < 4; i++)"); - begin_scope(); - statement("switch (c)"); - begin_scope(); - // Work around texture::gather() requiring its component parameter to be a constant expression - statement("case component::x:"); - statement(" rslts[i] = t.gather(s, spvForward(params)..., coffsets[i], component::x);"); - statement(" break;"); - statement("case component::y:"); - statement(" rslts[i] = t.gather(s, spvForward(params)..., coffsets[i], component::y);"); - statement(" break;"); - statement("case component::z:"); - statement(" rslts[i] = t.gather(s, spvForward(params)..., coffsets[i], component::z);"); - statement(" break;"); - statement("case component::w:"); - statement(" rslts[i] = t.gather(s, spvForward(params)..., coffsets[i], component::w);"); - statement(" break;"); - end_scope(); - end_scope(); - // Pull all values from the i0j0 component of each gather footprint - statement("return vec(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);"); - end_scope(); - statement(""); + // Because we are passing a texture reference, we have to output an overloaded version of this function for each address space. + for (uint32_t i = 0; i < texture_addr_space_count; i++) + { + statement("// Wrapper function that processes a ", texture_addr_spaces[i], " texture gather with a constant offset array."); + statement("template class Tex, " + "typename Toff, typename... Tp>"); + statement("inline vec spvGatherConstOffsets(const ", texture_addr_spaces[i], " Tex& t, sampler s, " + "Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c)"); + begin_scope(); + statement("vec rslts[4];"); + statement("for (uint i = 0; i < 4; i++)"); + begin_scope(); + statement("switch (c)"); + begin_scope(); + // Work around texture::gather() requiring its component parameter to be a constant expression + statement("case component::x:"); + statement(" rslts[i] = t.gather(s, spvForward(params)..., coffsets[i], component::x);"); + statement(" break;"); + statement("case component::y:"); + statement(" rslts[i] = t.gather(s, spvForward(params)..., coffsets[i], component::y);"); + statement(" break;"); + statement("case component::z:"); + statement(" rslts[i] = t.gather(s, spvForward(params)..., coffsets[i], component::z);"); + statement(" break;"); + statement("case component::w:"); + statement(" rslts[i] = t.gather(s, spvForward(params)..., coffsets[i], component::w);"); + statement(" break;"); + end_scope(); + end_scope(); + // Pull all values from the i0j0 component of each gather footprint + statement("return vec(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);"); + end_scope(); + statement(""); + } break; case SPVFuncImplGatherCompareConstOffsets: - statement("// Wrapper function that processes a texture gather with a constant offset array."); - statement("template class Tex, " - "typename Toff, typename... Tp>"); - statement("inline vec spvGatherCompareConstOffsets(const thread Tex& t, sampler s, " - "Toff coffsets, Tp... params)"); - begin_scope(); - statement("vec rslts[4];"); - statement("for (uint i = 0; i < 4; i++)"); - begin_scope(); - statement(" rslts[i] = t.gather_compare(s, spvForward(params)..., coffsets[i]);"); - end_scope(); - // Pull all values from the i0j0 component of each gather footprint - statement("return vec(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);"); - end_scope(); - statement(""); + // Because we are passing a texture reference, we have to output an overloaded version of this function for each address space. + for (uint32_t i = 0; i < texture_addr_space_count; i++) + { + statement("// Wrapper function that processes a ", texture_addr_spaces[i], " texture gather with a constant offset array."); + statement("template class Tex, " + "typename Toff, typename... Tp>"); + statement("inline vec spvGatherCompareConstOffsets(const ", texture_addr_spaces[i], " Tex& t, sampler s, " + "Toff coffsets, Tp... params)"); + begin_scope(); + statement("vec rslts[4];"); + statement("for (uint i = 0; i < 4; i++)"); + begin_scope(); + statement(" rslts[i] = t.gather_compare(s, spvForward(params)..., coffsets[i]);"); + end_scope(); + // Pull all values from the i0j0 component of each gather footprint + statement("return vec(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);"); + end_scope(); + statement(""); + } break; case SPVFuncImplSubgroupBroadcast: @@ -9246,18 +9256,40 @@ void CompilerMSL::emit_instruction(const Instruction &instruction) uint32_t coord_id = ops[3]; emit_uninitialized_temporary_expression(result_type, id); + std::string coord_expr = to_expression(coord_id); auto sampler_expr = to_sampler_expression(image_id); auto *combined = maybe_get(image_id); auto image_expr = combined ? to_expression(combined->image) : to_expression(image_id); + const SPIRType &image_type = expression_type(image_id); + const SPIRType &coord_type = expression_type(coord_id); + + switch (image_type.image.dim) + { + case Dim1D: + if (!msl_options.texture_1D_as_2D) + SPIRV_CROSS_THROW("ImageQueryLod is not supported on 1D textures."); + [[fallthrough]]; + case Dim2D: + if (coord_type.vecsize > 2) + coord_expr = enclose_expression(coord_expr) + ".xy"; + break; + case DimCube: + case Dim3D: + if (coord_type.vecsize > 3) + coord_expr = enclose_expression(coord_expr) + ".xyz"; + break; + default: + SPIRV_CROSS_THROW("Bad image type given to OpImageQueryLod"); + } // TODO: It is unclear if calculcate_clamped_lod also conditionally rounds // the reported LOD based on the sampler. NEAREST miplevel should // round the LOD, but LINEAR miplevel should not round. // Let's hope this does not become an issue ... statement(to_expression(id), ".x = ", image_expr, ".calculate_clamped_lod(", sampler_expr, ", ", - to_expression(coord_id), ");"); + coord_expr, ");"); statement(to_expression(id), ".y = ", image_expr, ".calculate_unclamped_lod(", sampler_expr, ", ", - to_expression(coord_id), ");"); + coord_expr, ");"); register_control_dependent_expression(id); break; } @@ -12167,21 +12199,26 @@ string CompilerMSL::to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_ string CompilerMSL::to_sampler_expression(uint32_t id) { auto *combined = maybe_get(id); - auto expr = to_expression(combined ? combined->image : VariableID(id)); - auto index = expr.find_first_of('['); + if (combined && combined->sampler) + return to_expression(combined->sampler); - uint32_t samp_id = 0; - if (combined) - samp_id = combined->sampler; + uint32_t expr_id = combined ? uint32_t(combined->image) : id; - if (index == string::npos) - return samp_id ? to_expression(samp_id) : expr + sampler_name_suffix; - else + // Constexpr samplers are declared as local variables, + // so exclude any qualifier names on the image expression. + if (auto *var = maybe_get_backing_variable(expr_id)) { - auto image_expr = expr.substr(0, index); - auto array_expr = expr.substr(index); - return samp_id ? to_expression(samp_id) : (image_expr + sampler_name_suffix + array_expr); + uint32_t img_id = var->basevariable ? var->basevariable : VariableID(var->self); + if (find_constexpr_sampler(img_id)) + return Compiler::to_name(img_id) + sampler_name_suffix; } + + auto img_expr = to_expression(expr_id); + auto index = img_expr.find_first_of('['); + if (index == string::npos) + return img_expr + sampler_name_suffix; + else + return img_expr.substr(0, index) + sampler_name_suffix + img_expr.substr(index); } string CompilerMSL::to_swizzle_expression(uint32_t id) @@ -13176,7 +13213,10 @@ string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bo addr_space = type.pointer || (argument && type.basetype == SPIRType::ControlPointArray) ? "thread" : ""; } - return join(decoration_flags_signal_volatile(flags) ? "volatile " : "", addr_space); + if (decoration_flags_signal_volatile(flags) && 0 != strcmp(addr_space, "thread")) + return join("volatile ", addr_space); + else + return addr_space; } const char *CompilerMSL::to_restrict(uint32_t id, bool space) @@ -13602,7 +13642,13 @@ string CompilerMSL::entry_point_args_argument_buffer(bool append_comma) claimed_bindings.set(buffer_binding); - ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(id, true) + to_name(id); + ep_args += get_argument_address_space(var) + " "; + + if (recursive_inputs.count(type.self)) + ep_args += string("void* ") + to_restrict(id, true) + to_name(id) + "_vp"; + else + ep_args += type_to_glsl(type) + "& " + to_restrict(id, true) + to_name(id); + ep_args += " [[buffer(" + convert_to_string(buffer_binding) + ")]]"; next_metal_resource_index_buffer = max(next_metal_resource_index_buffer, buffer_binding + 1); @@ -14040,7 +14086,7 @@ void CompilerMSL::fix_up_shader_inputs_outputs() statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id), is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]), ".spvBufferSizeConstants", "[", - convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];"); + convert_to_string(get_metal_resource_index(var, SPIRType::UInt)), "];"); } else { @@ -14053,7 +14099,8 @@ void CompilerMSL::fix_up_shader_inputs_outputs() } } - if (msl_options.replace_recursive_inputs && type_contains_recursion(type) && + if (!msl_options.argument_buffers && + msl_options.replace_recursive_inputs && type_contains_recursion(type) && (var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant || var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer)) { @@ -17026,13 +17073,21 @@ uint32_t CompilerMSL::get_declared_struct_size_msl(const SPIRType &struct_type, return msl_size; } +uint32_t CompilerMSL::get_physical_type_stride(const SPIRType &type) const +{ + // This should only be relevant for plain types such as scalars and vectors? + // If we're pointing to a struct, it will recursively pick up packed/row-major state. + return get_declared_type_size_msl(type, false, false); +} + // Returns the byte size of a struct member. uint32_t CompilerMSL::get_declared_type_size_msl(const SPIRType &type, bool is_packed, bool row_major) const { // Pointers take 8 bytes each + // Match both pointer and array-of-pointer here. if (type.pointer && type.storage == StorageClassPhysicalStorageBuffer) { - uint32_t type_size = 8 * (type.vecsize == 3 ? 4 : type.vecsize); + uint32_t type_size = 8; // Work our way through potentially layered arrays, // stopping when we hit a pointer that is not also an array. @@ -17107,9 +17162,10 @@ uint32_t CompilerMSL::get_declared_input_size_msl(const SPIRType &type, uint32_t // Returns the byte alignment of a type. uint32_t CompilerMSL::get_declared_type_alignment_msl(const SPIRType &type, bool is_packed, bool row_major) const { - // Pointers aligns on multiples of 8 bytes + // Pointers align on multiples of 8 bytes. + // Deliberately ignore array-ness here. It's not relevant for alignment. if (type.pointer && type.storage == StorageClassPhysicalStorageBuffer) - return 8 * (type.vecsize == 3 ? 4 : type.vecsize); + return 8; switch (type.basetype) { @@ -18134,6 +18190,13 @@ void CompilerMSL::emit_argument_buffer_aliased_descriptor(const SPIRVariable &al } else { + // This alias may have already been used to emit an entry point declaration. If there is a mismatch, we need a recompile. + // Moving this code to be run earlier will also conflict, + // because we need the qualified alias for the base resource, + // so forcing recompile until things sync up is the least invasive method for now. + if (ir.meta[aliased_var.self].decoration.qualified_alias != name) + force_recompile(); + // This will get wrapped in a separate temporary when a spvDescriptorArray wrapper is emitted. set_qualified_name(aliased_var.self, name); } @@ -18158,6 +18221,7 @@ void CompilerMSL::analyze_argument_buffers() string name; SPIRType::BaseType basetype; uint32_t index; + uint32_t plane_count; uint32_t plane; uint32_t overlapping_var_id; }; @@ -18208,14 +18272,14 @@ void CompilerMSL::analyze_argument_buffers() { uint32_t image_resource_index = get_metal_resource_index(var, SPIRType::Image, i); resources_in_set[desc_set].push_back( - { &var, to_name(var_id), SPIRType::Image, image_resource_index, i, 0 }); + { &var, to_name(var_id), SPIRType::Image, image_resource_index, plane_count, i, 0 }); } if (type.image.dim != DimBuffer && !constexpr_sampler) { uint32_t sampler_resource_index = get_metal_resource_index(var, SPIRType::Sampler); resources_in_set[desc_set].push_back( - { &var, to_sampler_expression(var_id), SPIRType::Sampler, sampler_resource_index, 0, 0 }); + { &var, to_sampler_expression(var_id), SPIRType::Sampler, sampler_resource_index, 1, 0, 0 }); } } else if (inline_uniform_blocks.count(SetBindingPair{ desc_set, binding })) @@ -18231,14 +18295,14 @@ void CompilerMSL::analyze_argument_buffers() uint32_t resource_index = get_metal_resource_index(var, type.basetype); resources_in_set[desc_set].push_back( - { &var, to_name(var_id), type.basetype, resource_index, 0, 0 }); + { &var, to_name(var_id), type.basetype, resource_index, 1, 0, 0 }); // Emulate texture2D atomic operations if (atomic_image_vars_emulated.count(var.self)) { uint32_t buffer_resource_index = get_metal_resource_index(var, SPIRType::AtomicCounter, 0); resources_in_set[desc_set].push_back( - { &var, to_name(var_id) + "_atomic", SPIRType::Struct, buffer_resource_index, 0, 0 }); + { &var, to_name(var_id) + "_atomic", SPIRType::Struct, buffer_resource_index, 1, 0, 0 }); } } @@ -18286,7 +18350,7 @@ void CompilerMSL::analyze_argument_buffers() set_decoration(var_id, DecorationDescriptorSet, desc_set); set_decoration(var_id, DecorationBinding, kSwizzleBufferBinding); resources_in_set[desc_set].push_back( - { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 0, 0 }); + { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 1, 0, 0 }); } if (set_needs_buffer_sizes[desc_set]) @@ -18297,7 +18361,7 @@ void CompilerMSL::analyze_argument_buffers() set_decoration(var_id, DecorationDescriptorSet, desc_set); set_decoration(var_id, DecorationBinding, kBufferSizeBufferBinding); resources_in_set[desc_set].push_back( - { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 0, 0 }); + { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 1, 0, 0 }); } } } @@ -18309,7 +18373,7 @@ void CompilerMSL::analyze_argument_buffers() uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet); add_resource_name(var_id); resources_in_set[desc_set].push_back( - { &var, to_name(var_id), SPIRType::Struct, get_metal_resource_index(var, SPIRType::Struct), 0, 0 }); + { &var, to_name(var_id), SPIRType::Struct, get_metal_resource_index(var, SPIRType::Struct), 1, 0, 0 }); } for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++) @@ -18340,7 +18404,8 @@ void CompilerMSL::analyze_argument_buffers() else buffer_type.storage = StorageClassUniform; - set_name(type_id, join("spvDescriptorSetBuffer", desc_set)); + auto buffer_type_name = join("spvDescriptorSetBuffer", desc_set); + set_name(type_id, buffer_type_name); auto &ptr_type = set(ptr_type_id, OpTypePointer); ptr_type = buffer_type; @@ -18350,8 +18415,9 @@ void CompilerMSL::analyze_argument_buffers() ptr_type.parent_type = type_id; uint32_t buffer_variable_id = next_id; - set(buffer_variable_id, ptr_type_id, StorageClassUniform); - set_name(buffer_variable_id, join("spvDescriptorSet", desc_set)); + auto &buffer_var = set(buffer_variable_id, ptr_type_id, StorageClassUniform); + auto buffer_name = join("spvDescriptorSet", desc_set); + set_name(buffer_variable_id, buffer_name); // Ids must be emitted in ID order. stable_sort(begin(resources), end(resources), [&](const Resource &lhs, const Resource &rhs) -> bool { @@ -18386,7 +18452,7 @@ void CompilerMSL::analyze_argument_buffers() // If needed, synthesize and add padding members. // member_index and next_arg_buff_index are incremented when padding members are added. - if (msl_options.pad_argument_buffer_resources && resource.overlapping_var_id == 0) + if (msl_options.pad_argument_buffer_resources && resource.plane == 0 && resource.overlapping_var_id == 0) { auto rez_bind = get_argument_buffer_resource(desc_set, next_arg_buff_index); while (resource.index > next_arg_buff_index) @@ -18432,7 +18498,7 @@ void CompilerMSL::analyze_argument_buffers() // Adjust the number of slots consumed by current member itself. // Use the count value from the app, instead of the shader, in case the // shader is only accessing part, or even one element, of the array. - next_arg_buff_index += rez_bind.count; + next_arg_buff_index += resource.plane_count * rez_bind.count; } string mbr_name = ensure_valid_name(resource.name, "m"); @@ -18559,6 +18625,16 @@ void CompilerMSL::analyze_argument_buffers() set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationOverlappingBinding); member_index++; } + + if (msl_options.replace_recursive_inputs && type_contains_recursion(buffer_type)) + { + recursive_inputs.insert(type_id); + auto &entry_func = this->get(ir.default_entry_point); + auto addr_space = get_argument_address_space(buffer_var); + entry_func.fixup_hooks_in.push_back([this, addr_space, buffer_name, buffer_type_name]() { + statement(addr_space, " auto& ", buffer_name, " = *(", addr_space, " ", buffer_type_name, "*)", buffer_name, "_vp;"); + }); + } } } diff --git a/3rdparty/spirv-cross/spirv_msl.hpp b/3rdparty/spirv-cross/spirv_msl.hpp index 9a1715808d..2d970c0da5 100644 --- a/3rdparty/spirv-cross/spirv_msl.hpp +++ b/3rdparty/spirv-cross/spirv_msl.hpp @@ -1028,6 +1028,8 @@ class CompilerMSL : public CompilerGLSL uint32_t get_physical_tess_level_array_size(spv::BuiltIn builtin) const; + uint32_t get_physical_type_stride(const SPIRType &type) const override; + // MSL packing rules. These compute the effective packing rules as observed by the MSL compiler in the MSL output. // These values can change depending on various extended decorations which control packing rules. // We need to make these rules match up with SPIR-V declared rules.