Skip to content

Commit

Permalink
Add support for Specialization constants for shader options (#88)
Browse files Browse the repository at this point in the history
* Add support for specialization constants
* Add support for specialization constants on Metal
* Add tests
* Update lint files

Signed-off-by: Akio Gaule <[email protected]>
  • Loading branch information
akioCL authored Jun 28, 2024
1 parent 219aa96 commit 58c32ff
Show file tree
Hide file tree
Showing 24 changed files with 360 additions and 21 deletions.
2 changes: 1 addition & 1 deletion EditorsLint/AZSL_notepad++_bright.xml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
<Keywords name="Keywords1">break continue if else switch case return for while do typedef namespace true false compile discard inline const struct class interface static extern register volatile inline nointerpolation shared uniform row_major column_major snorm unorm cbuffer groupshared SamplerState in out inout point triangle line lineadj triangleadj rootconstant ShaderResourceGroup ShaderResourceGroupSemantic ShaderVariantFallback FrequencyId associatedtype typealias typeof __azslc_print_message __azslc_print_symbol __azslc_prtsym_fully_qualified __azslc_prtsym_least_qualified __azslc_prtsym_constint_value enum option</Keywords>
<Keywords name="Keywords2">void bool bool1 bool2 bool3 bool4 int int1 int2 int3 int4 uint uint1 uint2 uint3 uint4 half half1 half2 half3 half4 float float1 float2 float3 float4 double double1 double2 double3 double4 matrix bool1x1 bool1x2 bool1x3 bool1x4 bool2x1 bool2x2 bool2x3 bool2x4 bool3x1 bool3x2 bool3x3 bool3x4 bool4x1 bool4x2 bool4x3 bool4x4 int1x1 int1x2 int1x3 int1x4 int2x1 int2x2 int2x3 int2x4 int3x1 int3x2 int3x3 int3x4 int4x1 int4x2 int4x3 int4x4 uint1x1 uint1x2 uint1x3 uint1x4 uint2x1 uint2x2 uint2x3 uint2x4 uint3x1 uint3x2 uint3x3 uint3x4 uint4x1 uint4x2 uint4x3 uint4x4 dword dword1 dword2 dword3 dword4 dword1x1 dword1x2 dword1x3 dword1x4 dword2x1 dword2x2 dword2x3 dword2x4 dword3x1 dword3x2 dword3x3 dword3x4 dword4x1 dword4x2 dword4x3 dword4x4 half1x1 half1x2 half1x3 half1x4 half2x1 half2x2 half2x3 half2x4 half3x1 half3x2 half3x3 half3x4 half4x1 half4x2 half4x3 half4x4 float1x1 float1x2 float1x3 float1x4 float2x1 float2x2 float2x3 float2x4 float3x1 float3x2 float3x3 float3x4 float4x1 float4x2 float4x3 float4x4 double1x1 double1x2 double1x3 double1x4 double2x1 double2x2 double2x3 double2x4 double3x1 double3x2 double3x3 double3x4 double4x1 double4x2 double4x3 double4x4 vector matrix</Keywords>
<Keywords name="Keywords3">Texture Texture1D Texture1DArray Texture2D Texture2DArray Texture2DMS Texture2DMSArray Texture3D TextureCube RWTexture1D RWTexture1DArray RWTexture2D RWTexture2DArray RWTexture3D Buffer StructuredBuffer AppendStructuredBuffer ConsumeStructuredBuffer RWBuffer RWStructuredBuffer ByteAddressBuffer RWByteAddressBuffer PointStream TriangleStream LineStream InputPatch OutputPatch SubpassInput SubpassInputMS ConstantBuffer Sampler sampler RaytracingAccelerationStructure BuiltInTriangleIntersectionAttributes RayDesc RAY_FLAG</Keywords>
<Keywords name="Keywords4">unroll loop flatten branch earlydepthstencil domain instance maxtessfactor outputcontrolpoints outputtopology partitioning patchconstantfunc numthreads maxvertexcount precise input_attachment_index</Keywords>
<Keywords name="Keywords4">unroll loop flatten branch earlydepthstencil domain instance maxtessfactor outputcontrolpoints outputtopology partitioning patchconstantfunc numthreads maxvertexcount precise input_attachment_index no_specialization</Keywords>
<Keywords name="Keywords5">SV_DispatchThreadID SV_DomainLocation SV_GroupID SV_GroupIndex SV_GroupThreadID SV_GSInstanceID SV_InsideTessFactor SV_OutputControlPointID SV_Coverage SV_Depth SV_Position SV_IsFrontFace SV_RenderTargetArrayIndex SV_SampleIndex SV_ViewportArrayIndex SV_InstanceID SV_PrimitiveID SV_VertexID</Keywords>
<Keywords name="Keywords6">SV_ClipDistance SV_CullDistance SV_Target BINORMAL BLENDINDICES BLENDWEIGHT COLOR NORMAL POSITION POSITIONT PSIZE TANGENT TEXCOORD FOG TESSFACTOR TEXCOORD VFACE VPOS DEPTH</Keywords>
<Keywords name="Keywords7">abs acos all AllMemoryBarrier AllMemoryBarrierWithGroupSync any asdouble asfloat asin asint asuint atan atan2 ceil clamp clip cos cosh countbits cross ddx ddx_coarse ddx_fine ddy ddy_coards ddy_fine degrees determinant DeviceMemoryBarrier DeviceMemoryBarrierWithGroupSync distance dot dst EvaluateAttributeAtCentroid EvaluateAttributeAtSample EvaluateAttributeSnapped exp exp2 f16tof32 f32tof16 faceforward firstbithigh firstbitlow floor fmod frac frexp fwidth GetRenderTargetSampleCount GetRenderTargetSamplePosition GroupMemoryBarrier GroupMemoryBarrierWithGroupSync InterlockedAdd InterlockedAnd InterlockedCompareExchange InterlockedExchange InterlockedMax InterlockedMin IntterlockedOr InterlockedXor isfinite isinf isnan ldexp length lerp lit log log10 log2 mad max min modf mul normalize pow Process2DQuadTessFactorsAvg Process2DQuadTessFactorsMax Process2DQuadTessFactorsMin ProcessIsolineTessFactors ProcessQuadTessFactorsAvg ProcessQuadTessFactorsMax ProcessQuadTessFactorsMin ProcessTriTessFactorsAvg ProcessTriTessFactorsMax ProcessTriTessFactorsMin radians rcp reflect refract reversebits round rsqrt saturate sign sin sincos sinh smoothstep sqrt step tan tanh transpose trunc Append RestartStrip CalculateLevelOfDetail CalculateLevelOfDetailUnclamped GetDimensions GetSamplePosition Load Sample SampleBias SampleCmp SampleCmpLevelZero SampleGrad SampleLevel Load2 Load3 Load4 Consume Store Store2 Store3 Store4 DecrementCounter IncrementCounter mips Gather GatherRed GatherGreen GatherBlue GatherAlpha GatherCmp GatherCmpRed GatherCmpGreen GatherCmpBlue GatherCmpAlpha AddressU AddressV AddressW BorderColor MinFilter MagFilter MipFilter MaxAnisotropy MaxLOD MinLOD MipLODBias ComparisonFunc ReductionType Point Linear Filter Comparison Minimum Maximum Wrap Mirror Clamp Border MirrorOnce Never Less Equal LessEqual Greater NotEqual GreaterEqual Always OpaqueBlack TransparentBlack OpaqueWhite SubpassLoad TraceRay ReportHit IgnoreHit CallShader AcceptHitAndEndSearch DispatchRaysIndex DispatchRaysDimensions PrimitiveIndex</Keywords>
Expand Down
2 changes: 1 addition & 1 deletion EditorsLint/AZSL_notepad++_dark.xml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
<Keywords name="Keywords1">break continue if else switch case return for while do typedef namespace true false compile discard inline const struct class interface static extern register volatile inline nointerpolation shared uniform row_major column_major snorm unorm cbuffer groupshared SamplerState in out inout point triangle line lineadj triangleadj rootconstant ShaderResourceGroup ShaderResourceGroupSemantic ShaderVariantFallback FrequencyId associatedtype typealias typeof __azslc_print_message __azslc_print_symbol __azslc_prtsym_fully_qualified __azslc_prtsym_least_qualified __azslc_prtsym_constint_value enum option</Keywords>
<Keywords name="Keywords2">void bool bool1 bool2 bool3 bool4 int int1 int2 int3 int4 uint uint1 uint2 uint3 uint4 half half1 half2 half3 half4 float float1 float2 float3 float4 double double1 double2 double3 double4 matrix bool1x1 bool1x2 bool1x3 bool1x4 bool2x1 bool2x2 bool2x3 bool2x4 bool3x1 bool3x2 bool3x3 bool3x4 bool4x1 bool4x2 bool4x3 bool4x4 int1x1 int1x2 int1x3 int1x4 int2x1 int2x2 int2x3 int2x4 int3x1 int3x2 int3x3 int3x4 int4x1 int4x2 int4x3 int4x4 uint1x1 uint1x2 uint1x3 uint1x4 uint2x1 uint2x2 uint2x3 uint2x4 uint3x1 uint3x2 uint3x3 uint3x4 uint4x1 uint4x2 uint4x3 uint4x4 dword dword1 dword2 dword3 dword4 dword1x1 dword1x2 dword1x3 dword1x4 dword2x1 dword2x2 dword2x3 dword2x4 dword3x1 dword3x2 dword3x3 dword3x4 dword4x1 dword4x2 dword4x3 dword4x4 half1x1 half1x2 half1x3 half1x4 half2x1 half2x2 half2x3 half2x4 half3x1 half3x2 half3x3 half3x4 half4x1 half4x2 half4x3 half4x4 float1x1 float1x2 float1x3 float1x4 float2x1 float2x2 float2x3 float2x4 float3x1 float3x2 float3x3 float3x4 float4x1 float4x2 float4x3 float4x4 double1x1 double1x2 double1x3 double1x4 double2x1 double2x2 double2x3 double2x4 double3x1 double3x2 double3x3 double3x4 double4x1 double4x2 double4x3 double4x4 vector matrix</Keywords>
<Keywords name="Keywords3">Texture Texture1D Texture1DArray Texture2D Texture2DArray Texture2DMS Texture2DMSArray Texture3D TextureCube RWTexture1D RWTexture1DArray RWTexture2D RWTexture2DArray RWTexture3D Buffer StructuredBuffer AppendStructuredBuffer ConsumeStructuredBuffer RWBuffer RWStructuredBuffer ByteAddressBuffer RWByteAddressBuffer PointStream TriangleStream LineStream InputPatch OutputPatch SubpassInput SubpassInputMS ConstantBuffer Sampler sampler RaytracingAccelerationStructure BuiltInTriangleIntersectionAttributes RayDesc RAY_FLAG </Keywords>
<Keywords name="Keywords4">unroll loop flatten branch earlydepthstencil domain instance maxtessfactor outputcontrolpoints outputtopology partitioning patchconstantfunc numthreads maxvertexcount precise input_attachment_index</Keywords>
<Keywords name="Keywords4">unroll loop flatten branch earlydepthstencil domain instance maxtessfactor outputcontrolpoints outputtopology partitioning patchconstantfunc numthreads maxvertexcount precise input_attachment_index no_specialization</Keywords>
<Keywords name="Keywords5">SV_DispatchThreadID SV_DomainLocation SV_GroupID SV_GroupIndex SV_GroupThreadID SV_GSInstanceID SV_InsideTessFactor SV_OutputControlPointID SV_Coverage SV_Depth SV_Position SV_IsFrontFace SV_RenderTargetArrayIndex SV_SampleIndex SV_ViewportArrayIndex SV_InstanceID SV_PrimitiveID SV_VertexID</Keywords>
<Keywords name="Keywords6">SV_ClipDistance SV_CullDistance SV_Target BINORMAL BLENDINDICES BLENDWEIGHT COLOR NORMAL POSITION POSITIONT PSIZE TANGENT TEXCOORD FOG TESSFACTOR TEXCOORD VFACE VPOS DEPTH</Keywords>
<Keywords name="Keywords7">abs acos all AllMemoryBarrier AllMemoryBarrierWithGroupSync any asdouble asfloat asin asint asuint atan atan2 ceil clamp clip cos cosh countbits cross ddx ddx_coarse ddx_fine ddy ddy_coards ddy_fine degrees determinant DeviceMemoryBarrier DeviceMemoryBarrierWithGroupSync distance dot dst EvaluateAttributeAtCentroid EvaluateAttributeAtSample EvaluateAttributeSnapped exp exp2 f16tof32 f32tof16 faceforward firstbithigh firstbitlow floor fmod frac frexp fwidth GetRenderTargetSampleCount GetRenderTargetSamplePosition GroupMemoryBarrier GroupMemoryBarrierWithGroupSync InterlockedAdd InterlockedAnd InterlockedCompareExchange InterlockedExchange InterlockedMax InterlockedMin IntterlockedOr InterlockedXor isfinite isinf isnan ldexp length lerp lit log log10 log2 mad max min modf mul normalize pow Process2DQuadTessFactorsAvg Process2DQuadTessFactorsMax Process2DQuadTessFactorsMin ProcessIsolineTessFactors ProcessQuadTessFactorsAvg ProcessQuadTessFactorsMax ProcessQuadTessFactorsMin ProcessTriTessFactorsAvg ProcessTriTessFactorsMax ProcessTriTessFactorsMin radians rcp reflect refract reversebits round rsqrt saturate sign sin sincos sinh smoothstep sqrt step tan tanh transpose trunc Append RestartStrip CalculateLevelOfDetail CalculateLevelOfDetailUnclamped GetDimensions GetSamplePosition Load Sample SampleBias SampleCmp SampleCmpLevelZero SampleGrad SampleLevel Load2 Load3 Load4 Consume Store Store2 Store3 Store4 DecrementCounter IncrementCounter mips Gather GatherRed GatherGreen GatherBlue GatherAlpha GatherCmp GatherCmpRed GatherCmpGreen GatherCmpBlue GatherCmpAlpha AddressU AddressV AddressW BorderColor MinFilter MagFilter MipFilter MaxAnisotropy MaxLOD MinLOD MipLODBias ComparisonFunc ReductionType Point Linear Filter Comparison Minimum Maximum Wrap Mirror Clamp Border MirrorOnce Never Less Equal LessEqual Greater NotEqual GreaterEqual Always OpaqueBlack TransparentBlack OpaqueWhite SubpassLoad TraceRay ReportHit IgnoreHit CallShader AcceptHitAndEndSearch DispatchRaysIndex DispatchRaysDimensions PrimitiveIndex</Keywords>
Expand Down
47 changes: 47 additions & 0 deletions Platform/Common/src/CommonVulkanPlatformEmitter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright (c) Contributors to the Open 3D Engine Project.
* For complete copyright and license terms please see the LICENSE at the root of this distribution.
*
* SPDX-License-Identifier: Apache-2.0 OR MIT
*
*/

#include <AzslcEmitter.h>
#include "CommonVulkanPlatformEmitter.h"

namespace AZ::ShaderCompiler
{
string CommonVulkanPlatformEmitter::GetSpecializationConstant(const CodeEmitter& codeEmitter, const IdentifierUID& symbolUid, const Options& options) const
{
std::stringstream stream;
auto* ir = codeEmitter.GetIR();
auto* varInfo = ir->GetSymbolSubAs<VarInfo>(symbolUid.GetName());
auto retInfo = varInfo->GetTypeRefInfo();
string typeAsStr = codeEmitter.GetTranslatedName(retInfo, UsageContext::ReferenceSite);
string defaultValue = codeEmitter.GetInitializerClause(varInfo);
Modifiers forbidden = StorageFlag::Static;
assert(varInfo->m_specializationId >= 0);
stream << "[[vk::constant_id(" << varInfo->m_specializationId << ")]]\n";
if (retInfo.m_typeClass == TypeClass::Enum)
{
// Enums are not a valid type for specialization constant, so we use the underlaying scalar type.
// The we add a global static variable with the enum type that cast from the specialization constant.
auto* enumClassInfo = ir->GetSymbolSubAs<ClassInfo>(retInfo.m_typeId.GetName());
auto& enumerators = enumClassInfo->GetOrderedMembers();
auto scalarType = enumClassInfo->Get<EnumerationInfo>()->m_underlyingType.m_arithmeticInfo.UnderlyingScalarToStr();
string scName = JoinAllNestedNamesWithUnderscore(symbolUid.m_name) + "_SC_OPTION";
stream << "const " << scalarType << " " << scName;
// TODO: if using a default value, emit it as the underlying scalar type, since enums are not valid default values for
// specialization constants (casting is also not allowed). We set default values for shader options at runtime so it's not
// a problem.
stream << " = (" << scalarType << ")" << "0;\n";
// Set the default value for the global static variable as the value of the specialization constant.
defaultValue = std::move(scName);
forbidden = Modifiers{};
}
stream << codeEmitter.GetTranslatedName(varInfo->m_typeInfoExt, UsageContext::ReferenceSite, options, forbidden) + " ";
stream << codeEmitter.GetTranslatedName(symbolUid.m_name, UsageContext::DeclarationSite) + " = ";
stream << "(" << typeAsStr << ")" << (defaultValue.empty() ? "0" : defaultValue) << "; \n";
return stream.str();
}
}
24 changes: 24 additions & 0 deletions Platform/Common/src/CommonVulkanPlatformEmitter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright (c) Contributors to the Open 3D Engine Project.
* For complete copyright and license terms please see the LICENSE at the root of this distribution.
*
* SPDX-License-Identifier: Apache-2.0 OR MIT
*
*/
#pragma once

#include <AzslcPlatformEmitter.h>

namespace AZ::ShaderCompiler
{
// PlatformEmitter is not a Backend by design. It's a supplement to CodeEmitter, not a replacement
struct CommonVulkanPlatformEmitter : PlatformEmitter
{
public:
[[nodiscard]]
string GetSpecializationConstant(const CodeEmitter& codeEmitter, const IdentifierUID& symbol, const Options& options) const override;

protected:
CommonVulkanPlatformEmitter() : PlatformEmitter {} {};
};
}
30 changes: 30 additions & 0 deletions Platform/Windows/src/DirectX12PlatformEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,34 @@ namespace AZ::ShaderCompiler

return Decorate("#define sig ", Join(rootAttrList, ", \" \\\n"), "\"\n\n");
}

string DirectX12PlatformEmitter::GetSpecializationConstant(const CodeEmitter& codeEmitter, const IdentifierUID& symbolUid, const Options& options) const
{
// Specialization constants will be represented by a volatile variable that will be patched later.
std::stringstream stream;
auto* ir = codeEmitter.GetIR();
auto* varInfo = ir->GetSymbolSubAs<VarInfo>(symbolUid.GetName());
auto retInfo = varInfo->GetTypeRefInfo();

// Volatile is not allowed for global variables, so we create a function to wrap it.
std:string varName = "sc_" + JoinAllNestedNamesWithUnderscore(symbolUid.m_name);
std::string retType = codeEmitter.GetTranslatedName(retInfo, UsageContext::ReferenceSite);
std::string functionName = "GetSpecializationConstant_" + JoinAllNestedNamesWithUnderscore(symbolUid.m_name) + "()";

// Emit the function
assert(varInfo->m_specializationId >= 0);
stream << retType << " " << functionName;
stream << "\n{\n";
stream << " volatile int " << varName;
stream << " = " << varInfo->m_specializationId << ";\n";
stream << " return (" << retType << ") " << varName << ";\n";
stream << "}\n\n";

// Emit the global variable that is going to be the shader option. It's default value will be the return value
// of the function we just created.
stream << codeEmitter.GetTranslatedName(varInfo->m_typeInfoExt, UsageContext::ReferenceSite, options) + " ";
stream << codeEmitter.GetTranslatedName(symbolUid.m_name, UsageContext::DeclarationSite);
stream << " = " << functionName << ";\n";
return stream.str();
}
}
2 changes: 2 additions & 0 deletions Platform/Windows/src/DirectX12PlatformEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ namespace AZ::ShaderCompiler

bool RequiresUniqueSpaceForUnboundedArrays() const override {return true;}

[[nodiscard]]
string GetSpecializationConstant(const CodeEmitter& codeEmitter, const IdentifierUID& symbol, const Options& options) const override;

private:
DirectX12PlatformEmitter() : PlatformEmitter {} {};
Expand Down
1 change: 0 additions & 1 deletion Platform/Windows/src/VulkanPlatformEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,4 @@ namespace AZ::ShaderCompiler
}
return { stream.str(), registerString };
}

}
5 changes: 3 additions & 2 deletions Platform/Windows/src/VulkanPlatformEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
#pragma once

#include <AzslcPlatformEmitter.h>
#include <CommonVulkanPlatformEmitter.h>

namespace AZ::ShaderCompiler
{
// PlatformEmitter is not a Backend by design. It's a supplement to CodeEmitter, not a replacement
struct VulkanPlatformEmitter : PlatformEmitter
struct VulkanPlatformEmitter : CommonVulkanPlatformEmitter
{
public:
//! This method will be called once and only once when the platform emitter registers itself to the system.
Expand All @@ -26,6 +27,6 @@ namespace AZ::ShaderCompiler
std::pair<string, string> GetDataViewHeaderFooter(const CodeEmitter& codeEmitter, const IdentifierUID& symbol, uint32_t bindInfoRegisterIndex, string_view registerTypeLetter, optional<string> stringifiedLogicalSpace) const override final;

private:
VulkanPlatformEmitter() : PlatformEmitter {} {};
VulkanPlatformEmitter() : CommonVulkanPlatformEmitter {} {};
};
}
5 changes: 3 additions & 2 deletions Platform/iOS/src/MetalPlatformEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
#pragma once

#include <AzslcPlatformEmitter.h>
#include <CommonVulkanPlatformEmitter.h>

namespace AZ::ShaderCompiler
{
// PlatformEmitter is not a Backend by design. It's a supplement to CodeEmitter, not a replacement
struct MetalPlatformEmitter : PlatformEmitter
struct MetalPlatformEmitter : CommonVulkanPlatformEmitter
{
public:
//! This method will be called once and only once when the platform emitter registers itself to the system.
Expand All @@ -25,6 +26,6 @@ namespace AZ::ShaderCompiler
uint32_t AlignRootConstants(uint32_t size) const override final;

private:
MetalPlatformEmitter() : PlatformEmitter {} {};
MetalPlatformEmitter() : CommonVulkanPlatformEmitter {} {};
};
}
Loading

0 comments on commit 58c32ff

Please sign in to comment.