From 3434a173a2851e77900ee4e7956b62aed3a8158d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=91=D1=80=D0=B0=D0=BD=D0=B8=D0=BC=D0=B8=D1=80=20=D0=9A?= =?UTF-8?q?=D0=B0=D1=80=D0=B0=D1=9F=D0=B8=D1=9B?= Date: Sat, 28 Dec 2024 22:41:09 -0800 Subject: [PATCH] Updated spirv-cross. --- 3rdparty/spirv-cross/spirv_common.hpp | 10 +- 3rdparty/spirv-cross/spirv_cross.cpp | 11 +- .../spirv-cross/spirv_cross_parsed_ir.cpp | 4 + 3rdparty/spirv-cross/spirv_glsl.cpp | 56 +- 3rdparty/spirv-cross/spirv_hlsl.cpp | 88 ++- 3rdparty/spirv-cross/spirv_hlsl.hpp | 2 + 3rdparty/spirv-cross/spirv_msl.cpp | 650 +++++++++++++++++- 3rdparty/spirv-cross/spirv_msl.hpp | 13 + 8 files changed, 780 insertions(+), 54 deletions(-) diff --git a/3rdparty/spirv-cross/spirv_common.hpp b/3rdparty/spirv-cross/spirv_common.hpp index 93b26697709..b70536d9ecc 100644 --- a/3rdparty/spirv-cross/spirv_common.hpp +++ b/3rdparty/spirv-cross/spirv_common.hpp @@ -578,7 +578,9 @@ struct SPIRType : IVariant // Keep internal types at the end. ControlPointArray, Interpolant, - Char + Char, + // MSL specific type, that is used by 'object'(analog of 'task' from glsl) shader. + MeshGridProperties }; // Scalar/vector/matrix support. @@ -746,6 +748,10 @@ struct SPIRExpression : IVariant // A list of expressions which this expression depends on. SmallVector expression_dependencies; + // Similar as expression dependencies, but does not stop the tracking for force-temporary variables. + // We need to know the full chain from store back to any SSA variable. + SmallVector invariance_dependencies; + // By reading this expression, we implicitly read these expressions as well. // Used by access chain Store and Load since we read multiple expressions in this case. 
SmallVector implied_read_expressions; @@ -1598,6 +1604,8 @@ struct AccessChainMeta bool flattened_struct = false; bool relaxed_precision = false; bool access_meshlet_position_y = false; + bool chain_is_builtin = false; + spv::BuiltIn builtin = {}; }; enum ExtendedDecorations diff --git a/3rdparty/spirv-cross/spirv_cross.cpp b/3rdparty/spirv-cross/spirv_cross.cpp index 5471b351501..3492f0b3ed9 100644 --- a/3rdparty/spirv-cross/spirv_cross.cpp +++ b/3rdparty/spirv-cross/spirv_cross.cpp @@ -2569,6 +2569,15 @@ void Compiler::add_active_interface_variable(uint32_t var_id) void Compiler::inherit_expression_dependencies(uint32_t dst, uint32_t source_expression) { + auto *ptr_e = maybe_get(dst); + + if (is_position_invariant() && ptr_e && maybe_get(source_expression)) + { + auto &deps = ptr_e->invariance_dependencies; + if (std::find(deps.begin(), deps.end(), source_expression) == deps.end()) + deps.push_back(source_expression); + } + // Don't inherit any expression dependencies if the expression in dst // is not a forwarded temporary. 
if (forwarded_temporaries.find(dst) == end(forwarded_temporaries) || @@ -2577,7 +2586,7 @@ void Compiler::inherit_expression_dependencies(uint32_t dst, uint32_t source_exp return; } - auto &e = get(dst); + auto &e = *ptr_e; auto *phi = maybe_get(source_expression); if (phi && phi->phi_variable) { diff --git a/3rdparty/spirv-cross/spirv_cross_parsed_ir.cpp b/3rdparty/spirv-cross/spirv_cross_parsed_ir.cpp index 188c0ae65d5..b05afeb3f57 100644 --- a/3rdparty/spirv-cross/spirv_cross_parsed_ir.cpp +++ b/3rdparty/spirv-cross/spirv_cross_parsed_ir.cpp @@ -928,6 +928,8 @@ void ParsedIR::reset_all_of_type(Types type) void ParsedIR::add_typed_id(Types type, ID id) { + assert(id < ids.size()); + if (loop_iteration_depth_hard != 0) SPIRV_CROSS_THROW("Cannot add typed ID while looping over it."); @@ -1030,6 +1032,8 @@ ParsedIR::LoopLock &ParsedIR::LoopLock::operator=(LoopLock &&other) SPIRV_CROSS_ void ParsedIR::make_constant_null(uint32_t id, uint32_t type, bool add_to_typed_id_set) { + assert(id < ids.size()); + auto &constant_type = get(type); if (constant_type.pointer) diff --git a/3rdparty/spirv-cross/spirv_glsl.cpp b/3rdparty/spirv-cross/spirv_glsl.cpp index d8d509f47c1..6c1d5208b98 100644 --- a/3rdparty/spirv-cross/spirv_glsl.cpp +++ b/3rdparty/spirv-cross/spirv_glsl.cpp @@ -6438,7 +6438,7 @@ string CompilerGLSL::constant_expression_vector(const SPIRConstant &c, uint32_t if (splat) { res += convert_to_string(c.scalar(vector, 0)); - if (is_legacy()) + if (is_legacy() && !has_extension("GL_EXT_gpu_shader4")) { // Fake unsigned constant literals with signed ones if possible. // Things like array sizes, etc, tend to be unsigned even though they could just as easily be signed. @@ -6457,7 +6457,7 @@ string CompilerGLSL::constant_expression_vector(const SPIRConstant &c, uint32_t else { res += convert_to_string(c.scalar(vector, i)); - if (is_legacy()) + if (is_legacy() && !has_extension("GL_EXT_gpu_shader4")) { // Fake unsigned constant literals with signed ones if possible. 
// Things like array sizes, etc, tend to be unsigned even though they could just as easily be signed. @@ -10210,6 +10210,8 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice bool pending_array_enclose = false; bool dimension_flatten = false; bool access_meshlet_position_y = false; + bool chain_is_builtin = false; + spv::BuiltIn chained_builtin = {}; if (auto *base_expr = maybe_get(base)) { @@ -10367,6 +10369,9 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice auto builtin = ir.meta[base].decoration.builtin_type; bool mesh_shader = get_execution_model() == ExecutionModelMeshEXT; + chain_is_builtin = true; + chained_builtin = builtin; + switch (builtin) { case BuiltInCullDistance: @@ -10502,6 +10507,9 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice { access_meshlet_position_y = true; } + + chain_is_builtin = true; + chained_builtin = builtin; } else { @@ -10721,6 +10729,8 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice meta->storage_physical_type = physical_type; meta->relaxed_precision = relaxed_precision; meta->access_meshlet_position_y = access_meshlet_position_y; + meta->chain_is_builtin = chain_is_builtin; + meta->builtin = chained_builtin; } return expr; @@ -11766,13 +11776,13 @@ void CompilerGLSL::disallow_forwarding_in_expression_chain(const SPIRExpression // Allow trivially forwarded expressions like OpLoad or trivial shuffles, // these will be marked as having suppressed usage tracking. // Our only concern is to make sure arithmetic operations are done in similar ways. 
- if (expression_is_forwarded(expr.self) && !expression_suppresses_usage_tracking(expr.self) && - forced_invariant_temporaries.count(expr.self) == 0) + if (forced_invariant_temporaries.count(expr.self) == 0) { - force_temporary_and_recompile(expr.self); + if (!expression_suppresses_usage_tracking(expr.self)) + force_temporary_and_recompile(expr.self); forced_invariant_temporaries.insert(expr.self); - for (auto &dependent : expr.expression_dependencies) + for (auto &dependent : expr.invariance_dependencies) disallow_forwarding_in_expression_chain(get(dependent)); } } @@ -12336,6 +12346,8 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) flattened_structs[ops[1]] = true; if (meta.relaxed_precision && backend.requires_relaxed_precision_analysis) set_decoration(ops[1], DecorationRelaxedPrecision); + if (meta.chain_is_builtin) + set_decoration(ops[1], DecorationBuiltIn, meta.builtin); // If we have some expression dependencies in our access chain, this access chain is technically a forwarded // temporary which could be subject to invalidation. @@ -13229,13 +13241,24 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) uint32_t op0 = ops[2]; uint32_t op1 = ops[3]; - // Needs special handling. + auto &out_type = get(result_type); + bool forward = should_forward(op0) && should_forward(op1); - auto expr = join(to_enclosed_expression(op0), " - ", to_enclosed_expression(op1), " * ", "(", - to_enclosed_expression(op0), " / ", to_enclosed_expression(op1), ")"); + string cast_op0, cast_op1; + auto expected_type = binary_op_bitcast_helper(cast_op0, cast_op1, int_type, op0, op1, false); + + // Needs special handling. 
+ auto expr = join(cast_op0, " - ", cast_op1, " * ", "(", cast_op0, " / ", cast_op1, ")"); if (implicit_integer_promotion) + { expr = join(type_to_glsl(get(result_type)), '(', expr, ')'); + } + else if (out_type.basetype != int_type) + { + expected_type.basetype = int_type; + expr = join(bitcast_glsl_op(out_type, expected_type), '(', expr, ')'); + } emit_op(result_type, result_id, expr, forward); inherit_expression_dependencies(result_id, op0); @@ -15668,7 +15691,16 @@ string CompilerGLSL::argument_decl(const SPIRFunction::Parameter &arg) if (type.pointer) { - if (arg.write_count && arg.read_count) + // If we're passing around block types to function, we really mean reference in a pointer sense, + // but DXC does not like inout for mesh blocks, so workaround that. out is technically not correct, + // but it works in practice due to legalization. It's ... not great, but you gotta do what you gotta do. + // GLSL will never hit this case since it's not valid. + if (type.storage == StorageClassOutput && get_execution_model() == ExecutionModelMeshEXT && + has_decoration(type.self, DecorationBlock) && is_builtin_type(type) && arg.write_count) + { + direction = "out "; + } + else if (arg.write_count && arg.read_count) direction = "inout "; else if (arg.write_count) direction = "out "; @@ -15945,7 +15977,7 @@ string CompilerGLSL::image_type_glsl(const SPIRType &type, uint32_t id, bool /*m case DimBuffer: if (options.es && options.version < 320) require_extension_internal("GL_EXT_texture_buffer"); - else if (!options.es && options.version < 300) + else if (!options.es && options.version < 140) require_extension_internal("GL_EXT_texture_buffer_object"); res += "Buffer"; break; @@ -16488,6 +16520,8 @@ void CompilerGLSL::emit_function(SPIRFunction &func, const Bitset &return_flags) { auto &var = get(v); var.deferred_declaration = false; + if (var.storage == StorageClassTaskPayloadWorkgroupEXT) + continue; if (variable_decl_is_remapped_storage(var, StorageClassWorkgroup)) { 
diff --git a/3rdparty/spirv-cross/spirv_hlsl.cpp b/3rdparty/spirv-cross/spirv_hlsl.cpp index 560f177b070..de370873763 100644 --- a/3rdparty/spirv-cross/spirv_hlsl.cpp +++ b/3rdparty/spirv-cross/spirv_hlsl.cpp @@ -4775,13 +4775,13 @@ void CompilerHLSL::emit_load(const Instruction &instruction) { auto ops = stream(instruction); - auto *chain = maybe_get(ops[2]); + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t ptr = ops[2]; + + auto *chain = maybe_get(ptr); if (chain) { - uint32_t result_type = ops[0]; - uint32_t id = ops[1]; - uint32_t ptr = ops[2]; - auto &type = get(result_type); bool composite_load = !type.array.empty() || type.basetype == SPIRType::Struct; @@ -4819,7 +4819,36 @@ void CompilerHLSL::emit_load(const Instruction &instruction) } } else - CompilerGLSL::emit_instruction(instruction); + { + // Very special case where we cannot rely on IO lowering. + // Mesh shader clip/cull arrays ... Cursed. + auto &res_type = get(result_type); + if (get_execution_model() == ExecutionModelMeshEXT && + has_decoration(ptr, DecorationBuiltIn) && + (get_decoration(ptr, DecorationBuiltIn) == BuiltInClipDistance || + get_decoration(ptr, DecorationBuiltIn) == BuiltInCullDistance) && + is_array(res_type) && !is_array(get(res_type.parent_type)) && + to_array_size_literal(res_type) > 1) + { + track_expression_read(ptr); + string load_expr = "{ "; + uint32_t num_elements = to_array_size_literal(res_type); + for (uint32_t i = 0; i < num_elements; i++) + { + load_expr += join(to_expression(ptr), ".", index_to_swizzle(i)); + if (i + 1 < num_elements) + load_expr += ", "; + } + load_expr += " }"; + emit_op(result_type, id, load_expr, false); + register_read(id, ptr, false); + inherit_expression_dependencies(id, ptr); + } + else + { + CompilerGLSL::emit_instruction(instruction); + } + } } void CompilerHLSL::write_access_chain_array(const SPIRAccessChain &chain, uint32_t value, @@ -6903,3 +6932,50 @@ bool CompilerHLSL::is_user_type_structured(uint32_t id) const } 
return false; } + +void CompilerHLSL::cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type) +{ + // Loading a full array of ClipDistance needs special consideration in mesh shaders + // since we cannot lower them by wrapping the variables in global statics. + // Fortunately, clip/cull is a proper vector in HLSL so we can lower with simple rvalue casts. + if (get_execution_model() != ExecutionModelMeshEXT || + !has_decoration(target_id, DecorationBuiltIn) || + !is_array(expr_type)) + { + CompilerGLSL::cast_to_variable_store(target_id, expr, expr_type); + return; + } + + auto builtin = BuiltIn(get_decoration(target_id, DecorationBuiltIn)); + if (builtin != BuiltInClipDistance && builtin != BuiltInCullDistance) + { + CompilerGLSL::cast_to_variable_store(target_id, expr, expr_type); + return; + } + + // Array of array means one thread is storing clip distance for all vertices. Nonsensical? + if (is_array(get(expr_type.parent_type))) + SPIRV_CROSS_THROW("Attempting to store all mesh vertices in one go. This is not supported."); + + uint32_t num_clip = to_array_size_literal(expr_type); + if (num_clip > 4) + SPIRV_CROSS_THROW("Number of clip or cull distances exceeds 4, this will not work with mesh shaders."); + + if (num_clip == 1) + { + // We already emit array here. 
+ CompilerGLSL::cast_to_variable_store(target_id, expr, expr_type); + return; + } + + auto unrolled_expr = join("float", num_clip, "("); + for (uint32_t i = 0; i < num_clip; i++) + { + unrolled_expr += join(expr, "[", i, "]"); + if (i + 1 < num_clip) + unrolled_expr += ", "; + } + + unrolled_expr += ")"; + expr = std::move(unrolled_expr); +} diff --git a/3rdparty/spirv-cross/spirv_hlsl.hpp b/3rdparty/spirv-cross/spirv_hlsl.hpp index bec458c6122..3dc89cc683f 100644 --- a/3rdparty/spirv-cross/spirv_hlsl.hpp +++ b/3rdparty/spirv-cross/spirv_hlsl.hpp @@ -408,6 +408,8 @@ class CompilerHLSL : public CompilerGLSL std::vector composite_selection_workaround_types; std::string get_inner_entry_point_name() const; + + void cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type) override; }; } // namespace SPIRV_CROSS_NAMESPACE diff --git a/3rdparty/spirv-cross/spirv_msl.cpp b/3rdparty/spirv-cross/spirv_msl.cpp index aba26be8f3c..50b215e267e 100644 --- a/3rdparty/spirv-cross/spirv_msl.cpp +++ b/3rdparty/spirv-cross/spirv_msl.cpp @@ -202,6 +202,9 @@ uint32_t CompilerMSL::get_resource_array_size(const SPIRType &type, uint32_t id) { uint32_t array_size = to_array_size_literal(type); + if (id == 0) + return array_size; + // If we have argument buffers, we need to honor the ABI by using the correct array size // from the layout. Only use shader declared size if we're not using argument buffers. 
uint32_t desc_set = get_decoration(id, DecorationDescriptorSet); @@ -269,7 +272,7 @@ void CompilerMSL::build_implicit_builtins() (active_input_builtins.get(BuiltInVertexId) || active_input_builtins.get(BuiltInVertexIndex) || active_input_builtins.get(BuiltInBaseVertex) || active_input_builtins.get(BuiltInInstanceId) || active_input_builtins.get(BuiltInInstanceIndex) || active_input_builtins.get(BuiltInBaseInstance)); - bool need_local_invocation_index = msl_options.emulate_subgroups && active_input_builtins.get(BuiltInSubgroupId); + bool need_local_invocation_index = (msl_options.emulate_subgroups && active_input_builtins.get(BuiltInSubgroupId)) || is_mesh_shader(); bool need_workgroup_size = msl_options.emulate_subgroups && active_input_builtins.get(BuiltInNumSubgroups); bool force_frag_depth_passthrough = get_execution_model() == ExecutionModelFragment && !uses_explicit_early_fragment_test() && need_subpass_input && @@ -278,7 +281,7 @@ void CompilerMSL::build_implicit_builtins() if (need_subpass_input || need_sample_pos || need_subgroup_mask || need_vertex_params || need_tesc_params || need_tese_params || need_multiview || need_dispatch_base || need_vertex_base_params || need_grid_params || needs_sample_id || needs_subgroup_invocation_id || needs_subgroup_size || needs_helper_invocation || - has_additional_fixed_sample_mask() || need_local_invocation_index || need_workgroup_size || force_frag_depth_passthrough) + has_additional_fixed_sample_mask() || need_local_invocation_index || need_workgroup_size || force_frag_depth_passthrough || is_mesh_shader()) { bool has_frag_coord = false; bool has_sample_id = false; @@ -325,6 +328,13 @@ void CompilerMSL::build_implicit_builtins() } } + if (builtin == BuiltInPrimitivePointIndicesEXT || + builtin == BuiltInPrimitiveLineIndicesEXT || + builtin == BuiltInPrimitiveTriangleIndicesEXT) + { + builtin_mesh_primitive_indices_id = var.self; + } + if (var.storage != StorageClassInput) return; @@ -1057,6 +1067,53 @@ void 
CompilerMSL::build_implicit_builtins() set_decoration(var_id, DecorationBuiltIn, BuiltInPosition); mark_implicit_builtin(StorageClassOutput, BuiltInPosition, var_id); } + + if (is_mesh_shader()) + { + uint32_t offset = ir.increase_bound_by(2); + uint32_t type_ptr_id = offset; + uint32_t var_id = offset + 1; + + // Create variable to store meshlet size. + uint32_t type_id = build_extended_vector_type(get_uint_type_id(), 2); + SPIRType uint_type_ptr = get(type_id); + uint_type_ptr.op = OpTypePointer; + uint_type_ptr.pointer = true; + uint_type_ptr.pointer_depth++; + uint_type_ptr.parent_type = type_id; + uint_type_ptr.storage = StorageClassWorkgroup; + + auto &ptr_type = set(type_ptr_id, uint_type_ptr); + ptr_type.self = type_id; + set(var_id, type_ptr_id, StorageClassWorkgroup); + set_name(var_id, "spvMeshSizes"); + builtin_mesh_sizes_id = var_id; + } + + if (get_execution_model() == spv::ExecutionModelTaskEXT) + { + uint32_t offset = ir.increase_bound_by(3); + uint32_t type_id = offset; + uint32_t type_ptr_id = offset + 1; + uint32_t var_id = offset + 2; + + SPIRType mesh_grid_type { OpTypeStruct }; + mesh_grid_type.basetype = SPIRType::MeshGridProperties; + set(type_id, mesh_grid_type); + + SPIRType mesh_grid_type_ptr = mesh_grid_type; + mesh_grid_type_ptr.op = spv::OpTypePointer; + mesh_grid_type_ptr.pointer = true; + mesh_grid_type_ptr.pointer_depth++; + mesh_grid_type_ptr.parent_type = type_id; + mesh_grid_type_ptr.storage = StorageClassOutput; + + auto &ptr_in_type = set(type_ptr_id, mesh_grid_type_ptr); + ptr_in_type.self = type_id; + set(var_id, type_ptr_id, StorageClassOutput); + set_name(var_id, "spvMgp"); + builtin_task_grid_id = var_id; + } } // Checks if the specified builtin variable (e.g. gl_InstanceIndex) is marked as active. @@ -1509,6 +1566,10 @@ void CompilerMSL::emit_entry_point_declarations() statement(CompilerGLSL::variable_decl(var), ";"); var.deferred_declaration = false; } + + // Holds SetMeshOutputsEXT information. 
Threadgroup since first thread wins. + if (processing_entry_point && is_mesh_shader()) + statement("threadgroup uint2 spvMeshSizes;"); } string CompilerMSL::compile() @@ -1544,6 +1605,8 @@ string CompilerMSL::compile() backend.native_pointers = true; backend.nonuniform_qualifier = ""; backend.support_small_type_sampling_result = true; + backend.force_merged_mesh_block = false; + backend.force_gl_in_out_block = get_execution_model() == ExecutionModelMeshEXT; backend.supports_empty_struct = true; backend.support_64bit_switch = true; backend.boolean_in_struct_remapped_type = SPIRType::Short; @@ -1559,6 +1622,9 @@ string CompilerMSL::compile() capture_output_to_buffer = msl_options.capture_output_to_buffer; is_rasterization_disabled = msl_options.disable_rasterization || capture_output_to_buffer; + if (is_mesh_shader() && !get_entry_point().flags.get(ExecutionModeOutputPoints)) + msl_options.enable_point_size_builtin = false; + // Initialize array here rather than constructor, MSVC 2013 workaround. for (auto &id : next_metal_resource_ids) id = 0; @@ -1566,6 +1632,11 @@ string CompilerMSL::compile() fixup_anonymous_struct_names(); fixup_type_alias(); replace_illegal_names(); + if (get_execution_model() == ExecutionModelMeshEXT) + { + // Emit proxy entry-point for the sake of copy-pass + emit_mesh_entry_point(); + } sync_entry_point_aliases_and_names(); build_function_control_flow_graphs_and_analyze(); @@ -1617,9 +1688,17 @@ string CompilerMSL::compile() // Create structs to hold input, output and uniform variables. // Do output first to ensure out. is declared at top of entry function. 
qual_pos_var_name = ""; - stage_out_var_id = add_interface_block(StorageClassOutput); - patch_stage_out_var_id = add_interface_block(StorageClassOutput, true); - stage_in_var_id = add_interface_block(StorageClassInput); + if (is_mesh_shader()) + { + fixup_implicit_builtin_block_names(get_execution_model()); + } + else + { + stage_out_var_id = add_interface_block(StorageClassOutput); + patch_stage_out_var_id = add_interface_block(StorageClassOutput, true); + stage_in_var_id = add_interface_block(StorageClassInput); + } + if (is_tese_shader()) patch_stage_in_var_id = add_interface_block(StorageClassInput, true); @@ -1628,6 +1707,12 @@ string CompilerMSL::compile() if (is_tessellation_shader()) stage_in_ptr_var_id = add_interface_block_pointer(stage_in_var_id, StorageClassInput); + if (is_mesh_shader()) + { + mesh_out_per_vertex = add_meshlet_block(false); + mesh_out_per_primitive = add_meshlet_block(true); + } + // Metal vertex functions that define no output must disable rasterization and return void. 
if (!stage_out_var_id) is_rasterization_disabled = true; @@ -1762,12 +1847,18 @@ void CompilerMSL::localize_global_variables() { uint32_t v_id = *iter; auto &var = get(v_id); - if (var.storage == StorageClassPrivate || var.storage == StorageClassWorkgroup) + if (var.storage == StorageClassPrivate || var.storage == StorageClassWorkgroup || + var.storage == StorageClassTaskPayloadWorkgroupEXT) { if (!variable_is_lut(var)) entry_func.add_local_variable(v_id); iter = global_variables.erase(iter); } + else if (var.storage == StorageClassOutput && is_mesh_shader()) + { + entry_func.add_local_variable(v_id); + iter = global_variables.erase(iter); + } else iter++; } @@ -2105,6 +2196,15 @@ void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std:: break; } + case OpSetMeshOutputsEXT: + { + if (builtin_local_invocation_index_id != 0) + added_arg_ids.insert(builtin_local_invocation_index_id); + if (builtin_mesh_sizes_id != 0) + added_arg_ids.insert(builtin_mesh_sizes_id); + break; + } + default: break; } @@ -2117,6 +2217,9 @@ void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std:: // We should consider a more unified system here to reduce boiler-plate. // This kind of analysis is done in several places ... 
} + + if (b.terminator == SPIRBlock::EmitMeshTasks && builtin_task_grid_id != 0) + added_arg_ids.insert(builtin_task_grid_id); } function_global_vars[func_id] = added_arg_ids; @@ -2206,6 +2309,17 @@ void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std:: if (is_tese_shader() && msl_options.raw_buffer_tese_input && var.storage == StorageClassInput) set_decoration(next_id, DecorationNonWritable); } + else if (is_builtin && is_mesh_shader()) + { + uint32_t next_id = ir.increase_bound_by(1); + func.add_parameter(type_id, next_id, true); + auto &v = set(next_id, type_id, StorageClassFunction, 0, arg_id); + v.storage = StorageClassWorkgroup; + + // Ensure the existing variable has a valid name and the new variable has all the same meta info + set_name(arg_id, ensure_valid_name(to_name(arg_id), "v")); + ir.meta[next_id] = ir.meta[arg_id]; + } else if (is_builtin && has_decoration(p_type->self, DecorationBlock)) { // Get the pointee type @@ -4490,6 +4604,42 @@ uint32_t CompilerMSL::add_interface_block_pointer(uint32_t ib_var_id, StorageCla return ib_ptr_var_id; } +uint32_t CompilerMSL::add_meshlet_block(bool per_primitive) +{ + // Accumulate the variables that should appear in the interface struct. 
+ SmallVector vars; + + ir.for_each_typed_id([&](uint32_t, SPIRVariable &var) { + if (var.storage != StorageClassOutput || var.self == builtin_mesh_primitive_indices_id) + return; + if (is_per_primitive_variable(var) != per_primitive) + return; + vars.push_back(&var); + }); + + if (vars.empty()) + return 0; + + uint32_t next_id = ir.increase_bound_by(1); + auto &type = set(next_id, SPIRType(OpTypeStruct)); + type.basetype = SPIRType::Struct; + + InterfaceBlockMeta meta; + for (auto *p_var : vars) + { + meta.strip_array = true; + meta.allow_local_declaration = false; + add_variable_to_interface_block(StorageClassOutput, "", type, *p_var, meta); + } + + if (per_primitive) + set_name(type.self, "spvPerPrimitive"); + else + set_name(type.self, "spvPerVertex"); + + return next_id; +} + // Ensure that the type is compatible with the builtin. // If it is, simply return the given type ID. // Otherwise, create a new type, and return it's ID. @@ -5482,6 +5632,19 @@ void CompilerMSL::emit_custom_templates() begin_scope(); statement("return elements[pos];"); end_scope(); + if (get_execution_model() == spv::ExecutionModelMeshEXT || + get_execution_model() == spv::ExecutionModelTaskEXT) + { + statement(""); + statement("object_data T& operator [] (size_t pos) object_data"); + begin_scope(); + statement("return elements[pos];"); + end_scope(); + statement("constexpr const object_data T& operator [] (size_t pos) const object_data"); + begin_scope(); + statement("return elements[pos];"); + end_scope(); + } end_scope_decl(); statement(""); break; @@ -7609,6 +7772,18 @@ void CompilerMSL::emit_custom_functions() statement(""); break; + case SPVFuncImplSetMeshOutputsEXT: + statement("void spvSetMeshOutputsEXT(uint gl_LocalInvocationIndex, threadgroup uint2& spvMeshSizes, uint vertexCount, uint primitiveCount)"); + begin_scope(); + statement("if (gl_LocalInvocationIndex == 0)"); + begin_scope(); + statement("spvMeshSizes.x = vertexCount;"); + statement("spvMeshSizes.y = 
primitiveCount;"); + end_scope(); + end_scope(); + statement(""); + break; + default: break; } @@ -7710,6 +7885,23 @@ void CompilerMSL::emit_resources() emit_interface_block(patch_stage_out_var_id); emit_interface_block(stage_in_var_id); emit_interface_block(patch_stage_in_var_id); + + if (get_execution_model() == ExecutionModelMeshEXT) + { + auto &execution = get_entry_point(); + const char *topology = ""; + if (execution.flags.get(ExecutionModeOutputTrianglesEXT)) + topology = "topology::triangle"; + else if (execution.flags.get(ExecutionModeOutputLinesEXT)) + topology = "topology::line"; + else if (execution.flags.get(ExecutionModeOutputPoints)) + topology = "topology::point"; + + const char *per_primitive = mesh_out_per_primitive ? "spvPerPrimitive" : "void"; + statement("using spvMesh_t = mesh<", "spvPerVertex, ", per_primitive, ", ", execution.output_vertices, ", ", + execution.output_primitives, ", ", topology, ">;"); + statement(""); + } } // Emit declarations for the specialization Metal function constants @@ -7733,7 +7925,7 @@ void CompilerMSL::emit_specialization_constants_and_structs() mark_scalar_layout_structs(type); }); - bool builtin_block_type_is_required = false; + bool builtin_block_type_is_required = is_mesh_shader(); // Very special case. If gl_PerVertex is initialized as an array (tessellation) // we have to potentially emit the gl_PerVertex struct type so that we can emit a constant LUT. 
ir.for_each_typed_id([&](uint32_t, SPIRConstant &c) { @@ -9925,6 +10117,14 @@ void CompilerMSL::emit_instruction(const Instruction &instruction) break; } + case OpSetMeshOutputsEXT: + { + flush_variable_declaration(builtin_mesh_primitive_indices_id); + add_spv_func_and_recompile(SPVFuncImplSetMeshOutputsEXT); + statement("spvSetMeshOutputsEXT(gl_LocalInvocationIndex, spvMeshSizes, ", to_unpacked_expression(ops[0]), ", ", to_unpacked_expression(ops[1]), ");"); + break; + } + default: CompilerGLSL::emit_instruction(instruction); break; @@ -9966,8 +10166,13 @@ void CompilerMSL::emit_texture_op(const Instruction &i, bool sparse) void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uint32_t id_mem_sem) { - if (get_execution_model() != ExecutionModelGLCompute && !is_tesc_shader()) + auto model = get_execution_model(); + + if (model != ExecutionModelGLCompute && model != ExecutionModelTaskEXT && + model != ExecutionModelMeshEXT && !is_tesc_shader()) + { return; + } uint32_t exe_scope = id_exe_scope ? evaluate_constant_u32(id_exe_scope) : uint32_t(ScopeInvocation); uint32_t mem_scope = id_mem_scope ? 
evaluate_constant_u32(id_mem_scope) : uint32_t(ScopeInvocation); @@ -11008,6 +11213,21 @@ void CompilerMSL::emit_function_prototype(SPIRFunction &func, const Bitset &) if (ir.ids[initializer].get_type() == TypeNone || ir.ids[initializer].get_type() == TypeExpression) set(ed_var.initializer, "{}", ed_var.basetype, true); } + + // add `taskPayloadSharedEXT` variable to entry-point arguments + for (auto &v : func.local_variables) + { + auto &var = get(v); + if (var.storage != StorageClassTaskPayloadWorkgroupEXT) + continue; + + add_local_variable_name(v); + SPIRFunction::Parameter arg = {}; + arg.id = v; + arg.type = var.basetype; + arg.alias_global_variable = true; + decl += join(", ", argument_decl(arg), " [[payload]]"); + } } for (auto &arg : func.arguments) @@ -11325,7 +11545,7 @@ string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool if (args.has_array_offsets) { forward = forward && should_forward(args.offset); - farg_str += ", " + to_expression(args.offset); + farg_str += ", " + to_unpacked_expression(args.offset); } // Const offsets gather or swizzled gather puts the component before the other args. 
@@ -11338,7 +11558,7 @@ string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool // Texture coordinates forward = forward && should_forward(args.coord); - auto coord_expr = to_enclosed_expression(args.coord); + auto coord_expr = to_enclosed_unpacked_expression(args.coord); auto &coord_type = expression_type(args.coord); bool coord_is_fp = type_is_floating_point(coord_type); bool is_cube_fetch = false; @@ -11462,14 +11682,14 @@ string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool if (type.basetype != SPIRType::UInt) tex_coords += join(" + uint2(", bitcast_expression(SPIRType::UInt, args.offset), ", 0)"); else - tex_coords += join(" + uint2(", to_enclosed_expression(args.offset), ", 0)"); + tex_coords += join(" + uint2(", to_enclosed_unpacked_expression(args.offset), ", 0)"); } else { if (type.basetype != SPIRType::UInt) tex_coords += " + " + bitcast_expression(SPIRType::UInt, args.offset); else - tex_coords += " + " + to_enclosed_expression(args.offset); + tex_coords += " + " + to_enclosed_unpacked_expression(args.offset); } } @@ -11556,10 +11776,10 @@ string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool string dref_expr; if (args.base.is_proj) - dref_expr = join(to_enclosed_expression(args.dref), " / ", + dref_expr = join(to_enclosed_unpacked_expression(args.dref), " / ", to_extract_component_expression(args.coord, alt_coord_component)); else - dref_expr = to_expression(args.dref); + dref_expr = to_unpacked_expression(args.dref); if (sampling_type_needs_f32_conversion(dref_type)) dref_expr = convert_to_f32(dref_expr, 1); @@ -11610,7 +11830,7 @@ string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool if (bias && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D)) { forward = forward && should_forward(bias); - farg_str += ", bias(" + to_expression(bias) + ")"; + farg_str += ", bias(" + to_unpacked_expression(bias) + ")"; } // Metal does not support LOD for 
1D textures. @@ -11619,7 +11839,7 @@ string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool forward = forward && should_forward(lod); if (args.base.is_fetch) { - farg_str += ", " + to_expression(lod); + farg_str += ", " + to_unpacked_expression(lod); } else if (msl_options.sample_dref_lod_array_as_grad && args.dref && imgtype.image.arrayed) { @@ -11676,12 +11896,12 @@ string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool extent = "float3(1.0)"; break; } - farg_str += join(", ", grad_opt, "(", grad_coord, "exp2(", to_expression(lod), " - 0.5) / ", extent, - ", exp2(", to_expression(lod), " - 0.5) / ", extent, ")"); + farg_str += join(", ", grad_opt, "(", grad_coord, "exp2(", to_unpacked_expression(lod), " - 0.5) / ", extent, + ", exp2(", to_unpacked_expression(lod), " - 0.5) / ", extent, ")"); } else { - farg_str += ", level(" + to_expression(lod) + ")"; + farg_str += ", level(" + to_unpacked_expression(lod) + ")"; } } else if (args.base.is_fetch && !lod && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D) && @@ -11727,7 +11947,7 @@ string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool grad_opt = "unsupported_gradient_dimension"; break; } - farg_str += join(", ", grad_opt, "(", grad_coord, to_expression(grad_x), ", ", to_expression(grad_y), ")"); + farg_str += join(", ", grad_opt, "(", grad_coord, to_unpacked_expression(grad_x), ", ", to_unpacked_expression(grad_y), ")"); } if (args.min_lod) @@ -11736,7 +11956,7 @@ string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool SPIRV_CROSS_THROW("min_lod_clamp() is only supported in MSL 2.2+ and up."); forward = forward && should_forward(args.min_lod); - farg_str += ", min_lod_clamp(" + to_expression(args.min_lod) + ")"; + farg_str += ", min_lod_clamp(" + to_unpacked_expression(args.min_lod) + ")"; } // Add offsets @@ -11745,7 +11965,7 @@ string CompilerMSL::to_function_args(const TextureFunctionArguments &args, 
bool if (args.offset && !args.base.is_fetch && !args.has_array_offsets) { forward = forward && should_forward(args.offset); - offset_expr = to_expression(args.offset); + offset_expr = to_unpacked_expression(args.offset); offset_type = &expression_type(args.offset); } @@ -11811,7 +12031,7 @@ string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool { forward = forward && should_forward(args.sample); farg_str += ", "; - farg_str += to_expression(args.sample); + farg_str += to_unpacked_expression(args.sample); } *p_forward = forward; @@ -12463,12 +12683,50 @@ string CompilerMSL::to_struct_member(const SPIRType &type, uint32_t member_type_ ((stage_out_var_id && get_stage_out_struct_type().self == type.self && variable_storage_requires_stage_io(StorageClassOutput)) || (stage_in_var_id && get_stage_in_struct_type().self == type.self && - variable_storage_requires_stage_io(StorageClassInput))); + variable_storage_requires_stage_io(StorageClassInput))) || + is_mesh_shader(); if (is_ib_in_out && is_member_builtin(type, index, &builtin)) is_using_builtin_array = true; array_type = type_to_array_glsl(physical_type, orig_id); } + if (is_mesh_shader()) + { + BuiltIn builtin = BuiltInMax; + if (is_member_builtin(type, index, &builtin)) + { + if (builtin == BuiltInPrimitiveShadingRateKHR) + { + // not supported in metal 3.0 + is_using_builtin_array = false; + return ""; + } + + SPIRType metallic_type = *declared_type; + if (builtin == BuiltInCullPrimitiveEXT) + metallic_type.basetype = SPIRType::Boolean; + else if (builtin == BuiltInPrimitiveId || builtin == BuiltInLayer || builtin == BuiltInViewportIndex) + metallic_type.basetype = SPIRType::UInt; + + is_using_builtin_array = true; + std::string result; + if (has_member_decoration(type.self, orig_id, DecorationBuiltIn)) + { + // avoid '_RESERVED_IDENTIFIER_FIXUP_' in variable name + result = join(type_to_glsl(metallic_type, orig_id, false), " ", qualifier, + builtin_to_glsl(builtin, StorageClassOutput), 
member_attribute_qualifier(type, index), + array_type, ";"); + } + else + { + result = join(type_to_glsl(metallic_type, orig_id, false), " ", qualifier, + to_member_name(type, index), member_attribute_qualifier(type, index), array_type, ";"); + } + is_using_builtin_array = false; + return result; + } + } + if (orig_id) { auto *data_type = declared_type; @@ -12522,6 +12780,16 @@ void CompilerMSL::emit_struct_member(const SPIRType &type, uint32_t member_type_ statement("char _m", index, "_pad", "[", pad_len, "];"); } + BuiltIn builtin = BuiltInMax; + if (is_mesh_shader() && is_member_builtin(type, index, &builtin)) + { + if (!has_active_builtin(builtin, StorageClassOutput) && !has_active_builtin(builtin, StorageClassInput)) + { + // Do not emit unused builtins in mesh-output blocks + return; + } + } + // Handle HLSL-style 0-based vertex/instance index. builtin_declaration = true; statement(to_struct_member(type, member_type_id, index, qualifier)); @@ -12595,9 +12863,11 @@ string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t in return string(" [[attribute(") + convert_to_string(locn) + ")]]"; } - // Vertex and tessellation evaluation function outputs - if (((execution.model == ExecutionModelVertex && !msl_options.vertex_for_tessellation) || is_tese_shader()) && - type.storage == StorageClassOutput) + bool use_semantic_stage_output = is_mesh_shader() || is_tese_shader() || + (execution.model == ExecutionModelVertex && !msl_options.vertex_for_tessellation); + + // Vertex, mesh and tessellation evaluation function outputs + if ((type.storage == StorageClassOutput || is_mesh_shader()) && use_semantic_stage_output) { if (is_builtin) { @@ -12616,6 +12886,9 @@ string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t in /* fallthrough */ case BuiltInPosition: case BuiltInLayer: + case BuiltInCullPrimitiveEXT: + case BuiltInPrimitiveShadingRateKHR: + case BuiltInPrimitiveId: return string(" [[") + builtin_qualifier(builtin) + "]]" 
+ (mbr_type.array.empty() ? "" : " "); case BuiltInClipDistance: @@ -12790,7 +13063,11 @@ string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t in { if (!quals.empty()) quals += ", "; - if (has_member_decoration(type.self, index, DecorationNoPerspective) || builtin == BuiltInBaryCoordNoPerspKHR) + + if (builtin == BuiltInBaryCoordNoPerspKHR || builtin == BuiltInBaryCoordKHR) + SPIRV_CROSS_THROW("Centroid interpolation not supported for barycentrics in MSL."); + + if (has_member_decoration(type.self, index, DecorationNoPerspective)) quals += "centroid_no_perspective"; else quals += "centroid_perspective"; @@ -12799,7 +13076,11 @@ string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t in { if (!quals.empty()) quals += ", "; - if (has_member_decoration(type.self, index, DecorationNoPerspective) || builtin == BuiltInBaryCoordNoPerspKHR) + + if (builtin == BuiltInBaryCoordNoPerspKHR || builtin == BuiltInBaryCoordKHR) + SPIRV_CROSS_THROW("Sample interpolation not supported for barycentrics in MSL."); + + if (has_member_decoration(type.self, index, DecorationNoPerspective)) quals += "sample_no_perspective"; else quals += "sample_perspective"; @@ -13078,6 +13359,12 @@ string CompilerMSL::func_type_decl(SPIRType &type) case ExecutionModelKernel: entry_type = "kernel"; break; + case ExecutionModelMeshEXT: + entry_type = "[[mesh]]"; + break; + case ExecutionModelTaskEXT: + entry_type = "[[object]]"; + break; default: entry_type = "unknown"; break; @@ -13096,6 +13383,11 @@ bool CompilerMSL::is_tese_shader() const return get_execution_model() == ExecutionModelTessellationEvaluation; } +bool CompilerMSL::is_mesh_shader() const +{ + return get_execution_model() == spv::ExecutionModelMeshEXT; +} + bool CompilerMSL::uses_explicit_early_fragment_test() { auto &ep_flags = get_entry_point().flags; @@ -13211,6 +13503,16 @@ string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bo if (!addr_space) addr_space = 
"device"; } + + if (is_mesh_shader()) + addr_space = "threadgroup"; + break; + + case StorageClassTaskPayloadWorkgroupEXT: + if (is_mesh_shader()) + addr_space = "const object_data"; + else + addr_space = "object_data"; break; default: @@ -13612,6 +13914,20 @@ void CompilerMSL::entry_point_args_builtin(string &ep_args) " [[buffer(", convert_to_string(msl_options.shader_input_buffer_index), ")]]"); } } + + if (is_mesh_shader()) + { + if (!ep_args.empty()) + ep_args += ", "; + ep_args += join("spvMesh_t spvMesh"); + } + + if (get_execution_model() == ExecutionModelTaskEXT) + { + if (!ep_args.empty()) + ep_args += ", "; + ep_args += join("mesh_grid_properties spvMgp"); + } } string CompilerMSL::entry_point_args_argument_buffer(bool append_comma) @@ -14050,6 +14366,14 @@ void CompilerMSL::fix_up_shader_inputs_outputs() }); } + if (is_mesh_shader()) + { + // If shader doesn't call SetMeshOutputsEXT, nothing should be rendered. + // No need to barrier after this, because only thread 0 writes to this later. + entry_func.fixup_hooks_in.push_back([this]() { statement("if (gl_LocalInvocationIndex == 0) spvMeshSizes.y = 0u;"); }); + entry_func.fixup_hooks_out.push_back([this]() { emit_mesh_outputs(); }); + } + // Look for sampled images and buffer. Add hooks to set up the swizzle constants or array lengths. ir.for_each_typed_id([&](uint32_t, SPIRVariable &var) { auto &type = get_variable_data_type(var); @@ -14850,7 +15174,8 @@ string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg) if (var.basevariable && (var.basevariable == stage_in_ptr_var_id || var.basevariable == stage_out_ptr_var_id)) decl = join(cv_qualifier, type_to_glsl(type, arg.id)); - else if (builtin) + else if (builtin && builtin_type != spv::BuiltInPrimitiveTriangleIndicesEXT && + builtin_type != spv::BuiltInPrimitiveLineIndicesEXT && builtin_type != spv::BuiltInPrimitivePointIndicesEXT) { // Only use templated array for Clip/Cull distance when feasible. 
// In other scenarios, we need need to override array length for tess levels (if used as outputs), @@ -15487,6 +15812,9 @@ string CompilerMSL::to_qualifiers_glsl(uint32_t id) auto *var = maybe_get(id); auto &type = expression_type(id); + if (type.storage == StorageClassTaskPayloadWorkgroupEXT) + quals += "object_data "; + if (type.storage == StorageClassWorkgroup || (var && variable_decl_is_remapped_storage(*var, StorageClassWorkgroup))) quals += "threadgroup "; @@ -15671,6 +15999,8 @@ string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id, bool member) break; case SPIRType::RayQuery: return "raytracing::intersection_query"; + case SPIRType::MeshGridProperties: + return "mesh_grid_properties"; default: return "unknown_type"; @@ -15785,6 +16115,9 @@ bool CompilerMSL::variable_decl_is_remapped_storage(const SPIRVariable &variable return true; } + if (is_mesh_shader()) + return variable.storage == StorageClassOutput; + return variable.storage == StorageClassOutput && is_tesc_shader() && is_stage_output_variable_masked(variable); } else if (storage == StorageClassStorageBuffer) @@ -16554,6 +16887,8 @@ string CompilerMSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage) case BuiltInLayer: if (is_tesc_shader()) break; + if (is_mesh_shader()) + break; if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point) && !is_stage_output_builtin_masked(builtin)) return stage_out_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage); @@ -16611,6 +16946,9 @@ string CompilerMSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage) // In SPIR-V 1.6 with Volatile HelperInvocation, we cannot emit a fixup early. 
return "simd_is_helper_thread()"; + case BuiltInPrimitiveId: + return "gl_PrimitiveID"; + default: break; } @@ -16644,6 +16982,8 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin) // Vertex function out case BuiltInClipDistance: return "clip_distance"; + case BuiltInCullDistance: + return "cull_distance"; case BuiltInPointSize: return "point_size"; case BuiltInPosition: @@ -16691,6 +17031,8 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin) else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 2)) SPIRV_CROSS_THROW("PrimitiveId on macOS requires MSL 2.2."); return "primitive_id"; + case ExecutionModelMeshEXT: + return "primitive_id"; default: SPIRV_CROSS_THROW("PrimitiveId is not supported in this execution model."); } @@ -16720,7 +17062,7 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin) // Shouldn't be reached. SPIRV_CROSS_THROW("Sample position is retrieved by a function in MSL."); case BuiltInViewIndex: - if (execution.model != ExecutionModelFragment) + if (execution.model != ExecutionModelFragment && execution.model != ExecutionModelMeshEXT) SPIRV_CROSS_THROW("ViewIndex is handled specially outside fragment shaders."); // The ViewIndex was implicitly used in the prior stages to set the render_target_array_index, // so we can get it from there. 
@@ -16825,6 +17167,9 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin) SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS."); return "barycentric_coord"; + case BuiltInCullPrimitiveEXT: + return "primitive_culled"; + default: return "unsupported-built-in"; } @@ -16941,6 +17286,13 @@ string CompilerMSL::builtin_type_decl(BuiltIn builtin, uint32_t id) case BuiltInDeviceIndex: return "int"; + case BuiltInPrimitivePointIndicesEXT: + return "uint"; + case BuiltInPrimitiveLineIndicesEXT: + return "uint2"; + case BuiltInPrimitiveTriangleIndicesEXT: + return "uint3"; + default: return "unsupported-built-in-type"; } @@ -18452,6 +18804,7 @@ void CompilerMSL::analyze_argument_buffers() uint32_t member_index = 0; uint32_t next_arg_buff_index = 0; + uint32_t prev_was_scalar_on_array_offset = 0; for (auto &resource : resources) { auto &var = *resource.var; @@ -18464,7 +18817,9 @@ void CompilerMSL::analyze_argument_buffers() // member_index and next_arg_buff_index are incremented when padding members are added. if (msl_options.pad_argument_buffer_resources && resource.plane == 0 && resource.overlapping_var_id == 0) { - auto rez_bind = get_argument_buffer_resource(desc_set, next_arg_buff_index); + auto rez_bind = get_argument_buffer_resource(desc_set, next_arg_buff_index - prev_was_scalar_on_array_offset); + rez_bind.count -= prev_was_scalar_on_array_offset; + while (resource.index > next_arg_buff_index) { switch (rez_bind.basetype) @@ -18503,12 +18858,19 @@ void CompilerMSL::analyze_argument_buffers() // After padding, retrieve the resource again. It will either be more padding, or the actual resource. rez_bind = get_argument_buffer_resource(desc_set, next_arg_buff_index); + prev_was_scalar_on_array_offset = 0; } + uint32_t count = rez_bind.count; + + // If the current resource is an array in the descriptor, but is a scalar + // in the shader, only the first element will be consumed. 
The next pass + // will add a padding member to consume the remaining array elements. + if (count > 1 && type.array.empty()) + count = prev_was_scalar_on_array_offset = 1; + // Adjust the number of slots consumed by current member itself. - // Use the count value from the app, instead of the shader, in case the - // shader is only accessing part, or even one element, of the array. - next_arg_buff_index += resource.plane_count * rez_bind.count; + next_arg_buff_index += resource.plane_count * count; } string mbr_name = ensure_valid_name(resource.name, "m"); @@ -18799,6 +19161,224 @@ void CompilerMSL::emit_block_hints(const SPIRBlock &) { } +void CompilerMSL::emit_mesh_entry_point() +{ + auto &ep = get_entry_point(); + auto &f = get(ir.default_entry_point); + + const uint32_t func_id = ir.increase_bound_by(3); + const uint32_t block_id = func_id + 1; + const uint32_t ret_id = func_id + 2; + auto &wrapped_main = set(func_id, f.return_type, f.function_type); + + wrapped_main.blocks.push_back(block_id); + wrapped_main.entry_block = block_id; + + auto &wrapped_entry = set(block_id); + wrapped_entry.terminator = SPIRBlock::Return; + + // Push call to original 'main' + Instruction ix = {}; + ix.op = OpFunctionCall; + ix.offset = uint32_t(ir.spirv.size()); + ix.length = 3; + + ir.spirv.push_back(f.return_type); + ir.spirv.push_back(ret_id); + ir.spirv.push_back(ep.self); + + wrapped_entry.ops.push_back(ix); + + // relace entry-point for new one + SPIREntryPoint proxy_ep = ep; + proxy_ep.self = func_id; + ir.entry_points.insert(std::make_pair(func_id, proxy_ep)); + ir.meta[func_id] = ir.meta[ir.default_entry_point]; + ir.meta[ir.default_entry_point].decoration.alias.clear(); + + ir.default_entry_point = func_id; +} + +void CompilerMSL::emit_mesh_outputs() +{ + auto &mode = get_entry_point(); + + // predefined thread count or zero, if specialization constant is in use + uint32_t num_invocations = 0; + if (mode.workgroup_size.id_x == 0 && mode.workgroup_size.id_y == 0 && 
mode.workgroup_size.id_z == 0) + num_invocations = mode.workgroup_size.x * mode.workgroup_size.y * mode.workgroup_size.z; + + statement("threadgroup_barrier(mem_flags::mem_threadgroup);"); + statement("if (spvMeshSizes.y == 0)"); + begin_scope(); + statement("return;"); + end_scope(); + statement("spvMesh.set_primitive_count(spvMeshSizes.y);"); + + statement("const uint spvThreadCount [[maybe_unused]] = (gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z);"); + + if (mesh_out_per_vertex != 0) + { + auto &type_vert = get(mesh_out_per_vertex); + + if (num_invocations < mode.output_vertices) + { + statement("for (uint spvVI = gl_LocalInvocationIndex; spvVI < spvMeshSizes.x; spvVI += spvThreadCount)"); + } + else + { + statement("const uint spvVI = gl_LocalInvocationIndex;"); + statement("if (gl_LocalInvocationIndex < spvMeshSizes.x)"); + } + + begin_scope(); + + statement("spvPerVertex spvV = {};"); + for (uint32_t index = 0; index < uint32_t(type_vert.member_types.size()); ++index) + { + uint32_t orig_var = get_extended_member_decoration(type_vert.self, index, SPIRVCrossDecorationInterfaceOrigID); + uint32_t orig_id = get_extended_member_decoration(type_vert.self, index, SPIRVCrossDecorationInterfaceMemberIndex); + + // Clip/cull distances are special-case + if (orig_var == 0 && orig_id == (~0u)) + continue; + + auto &orig = get(orig_var); + auto &orig_type = get(orig.basetype); + + // FIXME: Need to deal with complex composite IO types. These may need extra unroll, etc. + + BuiltIn builtin = BuiltInMax; + std::string access; + if (orig_type.basetype == SPIRType::Struct) + { + if (has_member_decoration(orig_type.self, orig_id, DecorationBuiltIn)) + builtin = BuiltIn(get_member_decoration(orig_type.self, orig_id, DecorationBuiltIn)); + + switch (builtin) + { + case BuiltInPosition: + case BuiltInPointSize: + case BuiltInClipDistance: + case BuiltInCullDistance: + access = "." + builtin_to_glsl(builtin, StorageClassOutput); + break; + default: + access = "." 
+ to_member_name(orig_type, orig_id); + break; + } + + if (has_member_decoration(type_vert.self, index, DecorationIndex)) + { + // Declare the Clip/CullDistance as [[user(clip/cullN)]]. + const uint32_t orig_index = get_member_decoration(type_vert.self, index, DecorationIndex); + access += "[" + to_string(orig_index) + "]"; + statement("spvV.", builtin_to_glsl(builtin, StorageClassOutput), "[", orig_index, "] = ", to_name(orig_var), "[spvVI]", access, ";"); + } + } + + statement("spvV.", to_member_name(type_vert, index), " = ", to_name(orig_var), "[spvVI]", access, ";"); + if (options.vertex.flip_vert_y && builtin == BuiltInPosition) + { + statement("spvV.", to_member_name(type_vert, index), ".y = -(", "spvV.", + to_member_name(type_vert, index), ".y);", " // Invert Y-axis for Metal"); + } + } + statement("spvMesh.set_vertex(spvVI, spvV);"); + end_scope(); + } + + if (mesh_out_per_primitive != 0 || builtin_mesh_primitive_indices_id != 0) + { + if (num_invocations < mode.output_primitives) + { + statement("for (uint spvPI = gl_LocalInvocationIndex; spvPI < spvMeshSizes.y; spvPI += spvThreadCount)"); + } + else + { + statement("const uint spvPI = gl_LocalInvocationIndex;"); + statement("if (gl_LocalInvocationIndex < spvMeshSizes.y)"); + } + + // FIXME: Need to deal with complex composite IO types. These may need extra unroll, etc. 
+ + begin_scope(); + + if (builtin_mesh_primitive_indices_id != 0) + { + if (mode.flags.get(ExecutionModeOutputTrianglesEXT)) + { + statement("spvMesh.set_index(spvPI * 3u + 0u, gl_PrimitiveTriangleIndicesEXT[spvPI].x);"); + statement("spvMesh.set_index(spvPI * 3u + 1u, gl_PrimitiveTriangleIndicesEXT[spvPI].y);"); + statement("spvMesh.set_index(spvPI * 3u + 2u, gl_PrimitiveTriangleIndicesEXT[spvPI].z);"); + } + else if (mode.flags.get(ExecutionModeOutputLinesEXT)) + { + statement("spvMesh.set_index(spvPI * 2u + 0u, gl_PrimitiveLineIndicesEXT[spvPI].x);"); + statement("spvMesh.set_index(spvPI * 2u + 1u, gl_PrimitiveLineIndicesEXT[spvPI].y);"); + } + else + { + statement("spvMesh.set_index(spvPI, gl_PrimitivePointIndicesEXT[spvPI]);"); + } + } + + if (mesh_out_per_primitive != 0) + { + auto &type_prim = get(mesh_out_per_primitive); + statement("spvPerPrimitive spvP = {};"); + for (uint32_t index = 0; index < uint32_t(type_prim.member_types.size()); ++index) + { + uint32_t orig_var = + get_extended_member_decoration(type_prim.self, index, SPIRVCrossDecorationInterfaceOrigID); + uint32_t orig_id = + get_extended_member_decoration(type_prim.self, index, SPIRVCrossDecorationInterfaceMemberIndex); + auto &orig = get(orig_var); + auto &orig_type = get(orig.basetype); + + BuiltIn builtin = BuiltInMax; + std::string access; + if (orig_type.basetype == SPIRType::Struct) + { + if (has_member_decoration(orig_type.self, orig_id, DecorationBuiltIn)) + builtin = BuiltIn(get_member_decoration(orig_type.self, orig_id, DecorationBuiltIn)); + + switch (builtin) + { + case BuiltInPrimitiveId: + case BuiltInLayer: + case BuiltInViewportIndex: + case BuiltInCullPrimitiveEXT: + case BuiltInPrimitiveShadingRateKHR: + access = "." + builtin_to_glsl(builtin, StorageClassOutput); + break; + default: + access = "." 
+ to_member_name(orig_type, orig_id); + } + } + statement("spvP.", to_member_name(type_prim, index), " = ", to_name(orig_var), "[spvPI]", access, ";"); + } + statement("spvMesh.set_primitive(spvPI, spvP);"); + } + + end_scope(); + } +} + +void CompilerMSL::emit_mesh_tasks(SPIRBlock &block) +{ + // GLSL: Once this instruction is called, the workgroup must be terminated immediately, and the mesh shaders are launched. + // TODO: find a reliable and clean way of terminating the shader. + flush_variable_declaration(builtin_task_grid_id); + statement("spvMgp.set_threadgroups_per_grid(uint3(", to_unpacked_expression(block.mesh.groups[0]), ", ", + to_unpacked_expression(block.mesh.groups[1]), ", ", to_unpacked_expression(block.mesh.groups[2]), "));"); + // This is correct if EmitMeshTasks is called in the entry function of the shader. + // Only viable solutions would be: + // - Caller ensures the SPIR-V is inlined, then this always holds true. + // - Pass down a "should terminate" bool to leaf functions and chain return (horrible and disgusting, let's not).
+ statement("return;"); +} + string CompilerMSL::additional_fixed_sample_mask_str() const { char print_buffer[32]; diff --git a/3rdparty/spirv-cross/spirv_msl.hpp b/3rdparty/spirv-cross/spirv_msl.hpp index 14cd84b0f71..4aaad01a892 100644 --- a/3rdparty/spirv-cross/spirv_msl.hpp +++ b/3rdparty/spirv-cross/spirv_msl.hpp @@ -840,6 +840,7 @@ class CompilerMSL : public CompilerGLSL SPVFuncImplImageFence, SPVFuncImplTextureCast, SPVFuncImplMulExtended, + SPVFuncImplSetMeshOutputsEXT, }; // If the underlying resource has been used for comparison then duplicate loads of that resource must be too @@ -868,6 +869,9 @@ class CompilerMSL : public CompilerGLSL std::string type_to_glsl(const SPIRType &type, uint32_t id, bool member); std::string type_to_glsl(const SPIRType &type, uint32_t id = 0) override; void emit_block_hints(const SPIRBlock &block) override; + void emit_mesh_entry_point(); + void emit_mesh_outputs(); + void emit_mesh_tasks(SPIRBlock &block) override; // Allow Metal to use the array template to make arrays a value type std::string type_to_array_glsl(const SPIRType &type, uint32_t variable_id) override; @@ -919,6 +923,7 @@ class CompilerMSL : public CompilerGLSL bool is_tesc_shader() const; bool is_tese_shader() const; + bool is_mesh_shader() const; void preprocess_op_codes(); void localize_global_variables(); @@ -933,6 +938,7 @@ class CompilerMSL : public CompilerGLSL std::unordered_set &processed_func_ids); uint32_t add_interface_block(spv::StorageClass storage, bool patch = false); uint32_t add_interface_block_pointer(uint32_t ib_var_id, spv::StorageClass storage); + uint32_t add_meshlet_block(bool per_primitive); struct InterfaceBlockMeta { @@ -1104,12 +1110,17 @@ class CompilerMSL : public CompilerGLSL uint32_t builtin_stage_input_size_id = 0; uint32_t builtin_local_invocation_index_id = 0; uint32_t builtin_workgroup_size_id = 0; + uint32_t builtin_mesh_primitive_indices_id = 0; + uint32_t builtin_mesh_sizes_id = 0; + uint32_t builtin_task_grid_id = 0; 
uint32_t builtin_frag_depth_id = 0; uint32_t swizzle_buffer_id = 0; uint32_t buffer_size_buffer_id = 0; uint32_t view_mask_buffer_id = 0; uint32_t dynamic_offsets_buffer_id = 0; uint32_t uint_type_id = 0; + uint32_t shared_uint_type_id = 0; + uint32_t meshlet_type_id = 0; uint32_t argument_buffer_padding_buffer_type_id = 0; uint32_t argument_buffer_padding_image_type_id = 0; uint32_t argument_buffer_padding_sampler_type_id = 0; @@ -1174,6 +1185,8 @@ class CompilerMSL : public CompilerGLSL VariableID stage_out_ptr_var_id = 0; VariableID tess_level_inner_var_id = 0; VariableID tess_level_outer_var_id = 0; + VariableID mesh_out_per_vertex = 0; + VariableID mesh_out_per_primitive = 0; VariableID stage_out_masked_builtin_type_id = 0; // Handle HLSL-style 0-based vertex/instance index.