Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Specialization constants for shader options #88

Merged
merged 5 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions Platform/Common/src/CommonVulkanPlatformEmitter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright (c) Contributors to the Open 3D Engine Project.
* For complete copyright and license terms please see the LICENSE at the root of this distribution.
*
* SPDX-License-Identifier: Apache-2.0 OR MIT
*
*/

#include <AzslcEmitter.h>
#include "CommonVulkanPlatformEmitter.h"

namespace AZ::ShaderCompiler
{
string CommonVulkanPlatformEmitter::GetSpecializationConstant(const CodeEmitter& codeEmitter, const IdentifierUID& symbolUid, const Options& options) const
{
std::stringstream stream;
auto* ir = codeEmitter.GetIR();
auto* varInfo = ir->GetSymbolSubAs<VarInfo>(symbolUid.GetName());
auto retInfo = varInfo->GetTypeRefInfo();
string typeAsStr = codeEmitter.GetTranslatedName(retInfo, UsageContext::ReferenceSite);
string defaultValue = codeEmitter.GetInitializerClause(varInfo);
Modifiers forbidden = StorageFlag::Static;
assert(varInfo->m_specializationId >= 0);
stream << "[[vk::constant_id(" << varInfo->m_specializationId << ")]]\n";
if (retInfo.m_typeClass == TypeClass::Enum)
{
// Enums are not a valid type for specialization constant, so we use the underlaying scalar type.
// The we add a global static variable with the enum type that cast from the specialization constant.
auto* enumClassInfo = ir->GetSymbolSubAs<ClassInfo>(retInfo.m_typeId.GetName());
auto& enumerators = enumClassInfo->GetOrderedMembers();
auto scalarType = enumClassInfo->Get<EnumerationInfo>()->m_underlyingType.m_arithmeticInfo.UnderlyingScalarToStr();
string scName = JoinAllNestedNamesWithUnderscore(symbolUid.m_name) + "_SC_OPTION";
stream << "const " << scalarType << " " << scName;
// TODO: if using a default value, emit it as the underlying scalar type, since enums are not valid default values for
// specialization constants (casting is also not allowed). We set default values for shader options at runtime so it's not
// a problem.
stream << " = (" << scalarType << ")" << "0;\n";
// Set the default value for the global static variable as the value of the specialization constant.
defaultValue = std::move(scName);
forbidden = Modifiers{};
}
stream << codeEmitter.GetTranslatedName(varInfo->m_typeInfoExt, UsageContext::ReferenceSite, options, forbidden) + " ";
stream << codeEmitter.GetTranslatedName(symbolUid.m_name, UsageContext::DeclarationSite) + " = ";
stream << "(" << typeAsStr << ")" << (defaultValue.empty() ? "0" : defaultValue) << "; \n";
return stream.str();
}
}
24 changes: 24 additions & 0 deletions Platform/Common/src/CommonVulkanPlatformEmitter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright (c) Contributors to the Open 3D Engine Project.
* For complete copyright and license terms please see the LICENSE at the root of this distribution.
*
* SPDX-License-Identifier: Apache-2.0 OR MIT
*
*/
#pragma once

#include <AzslcPlatformEmitter.h>

namespace AZ::ShaderCompiler
{
// PlatformEmitter is not a Backend by design. It's a supplement to CodeEmitter, not a replacement
struct CommonVulkanPlatformEmitter : PlatformEmitter
{
public:
[[nodiscard]]
string GetSpecializationConstant(const CodeEmitter& codeEmitter, const IdentifierUID& symbol, const Options& options) const override;

protected:
CommonVulkanPlatformEmitter() : PlatformEmitter {} {};
};
}
30 changes: 30 additions & 0 deletions Platform/Windows/src/DirectX12PlatformEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,34 @@ namespace AZ::ShaderCompiler

return Decorate("#define sig ", Join(rootAttrList, ", \" \\\n"), "\"\n\n");
}

string DirectX12PlatformEmitter::GetSpecializationConstant(const CodeEmitter& codeEmitter, const IdentifierUID& symbolUid, const Options& options) const
{
// Specialization constants will be represented by a volatile variable that will be patched later.
std::stringstream stream;
auto* ir = codeEmitter.GetIR();
auto* varInfo = ir->GetSymbolSubAs<VarInfo>(symbolUid.GetName());
auto retInfo = varInfo->GetTypeRefInfo();

// Volatile is not allowed for global variables, so we create a function to wrap it.
std:string varName = "sc_" + JoinAllNestedNamesWithUnderscore(symbolUid.m_name);
std::string retType = codeEmitter.GetTranslatedName(retInfo, UsageContext::ReferenceSite);
std::string functionName = "GetSpecializationConstant_" + JoinAllNestedNamesWithUnderscore(symbolUid.m_name) + "()";

// Emit the function
assert(varInfo->m_specializationId >= 0);
stream << retType << " " << functionName;
stream << "\n{\n";
stream << " volatile int " << varName;
stream << " = " << varInfo->m_specializationId << ";\n";
stream << " return (" << retType << ") " << varName << ";\n";
stream << "}\n\n";

// Emit the global variable that is going to be the shader option. It's default value will be the return value
// of the function we just created.
stream << codeEmitter.GetTranslatedName(varInfo->m_typeInfoExt, UsageContext::ReferenceSite, options) + " ";
stream << codeEmitter.GetTranslatedName(symbolUid.m_name, UsageContext::DeclarationSite);
stream << " = " << functionName << ";\n";
return stream.str();
}
}
2 changes: 2 additions & 0 deletions Platform/Windows/src/DirectX12PlatformEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ namespace AZ::ShaderCompiler

bool RequiresUniqueSpaceForUnboundedArrays() const override {return true;}

[[nodiscard]]
string GetSpecializationConstant(const CodeEmitter& codeEmitter, const IdentifierUID& symbol, const Options& options) const override;

private:
DirectX12PlatformEmitter() : PlatformEmitter {} {};
Expand Down
1 change: 0 additions & 1 deletion Platform/Windows/src/VulkanPlatformEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,4 @@ namespace AZ::ShaderCompiler
}
return { stream.str(), registerString };
}

}
5 changes: 3 additions & 2 deletions Platform/Windows/src/VulkanPlatformEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
#pragma once

#include <AzslcPlatformEmitter.h>
#include <CommonVulkanPlatformEmitter.h>

namespace AZ::ShaderCompiler
{
// PlatformEmitter is not a Backend by design. It's a supplement to CodeEmitter, not a replacement
struct VulkanPlatformEmitter : PlatformEmitter
struct VulkanPlatformEmitter : CommonVulkanPlatformEmitter
{
public:
//! This method will be called once and only once when the platform emitter registers itself to the system.
Expand All @@ -26,6 +27,6 @@ namespace AZ::ShaderCompiler
std::pair<string, string> GetDataViewHeaderFooter(const CodeEmitter& codeEmitter, const IdentifierUID& symbol, uint32_t bindInfoRegisterIndex, string_view registerTypeLetter, optional<string> stringifiedLogicalSpace) const override final;

private:
VulkanPlatformEmitter() : PlatformEmitter {} {};
VulkanPlatformEmitter() : CommonVulkanPlatformEmitter {} {};
};
}
5 changes: 3 additions & 2 deletions Platform/iOS/src/MetalPlatformEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
#pragma once

#include <AzslcPlatformEmitter.h>
#include <CommonVulkanPlatformEmitter.h>

namespace AZ::ShaderCompiler
{
// PlatformEmitter is not a Backend by design. It's a supplement to CodeEmitter, not a replacement
struct MetalPlatformEmitter : PlatformEmitter
struct MetalPlatformEmitter : CommonVulkanPlatformEmitter
{
public:
//! This method will be called once and only once when the platform emitter registers itself to the system.
Expand All @@ -25,6 +26,6 @@ namespace AZ::ShaderCompiler
uint32_t AlignRootConstants(uint32_t size) const override final;

private:
MetalPlatformEmitter() : PlatformEmitter {} {};
MetalPlatformEmitter() : CommonVulkanPlatformEmitter {} {};
};
}
18 changes: 18 additions & 0 deletions src/AzslcBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,8 @@ namespace AZ::ShaderCompiler
Json::Value varRoot(Json::objectValue);
varRoot["meta"] = "Variant options list exported by AZSLc";

bool useSpecializationConstants = false;

Json::Value shaderOptions(Json::arrayValue);
uint32_t keyOffsetBits = 0;

Expand All @@ -454,6 +456,8 @@ namespace AZ::ShaderCompiler
bool isUdt = IsUserDefined(varInfo->GetTypeClass());
assert(isUdt || IsPredefinedType(varInfo->GetTypeClass()));
shaderOption["kind"] = isUdt ? "user-defined" : "predefined";
shaderOption["specializationId"] = varInfo->m_specializationId;
useSpecializationConstants |= varInfo->m_specializationId >= 0;

AppendOptionRange(shaderOption, uid, varInfo, options);

Expand Down Expand Up @@ -482,11 +486,25 @@ namespace AZ::ShaderCompiler
}
}

varRoot["specializationConstants"] = options.m_useSpecializationConstantsForOptions && useSpecializationConstants;
varRoot["ShaderOptions"] = shaderOptions;

return varRoot;
}

void Backend::SetupOptionsSpecializationId(const Options& options) const
{
uint32_t specializationId = 0;
for (auto& [uid, varInfo, kindInfo] : m_ir->m_symbols.GetOrderedSymbolsOfSubType_3<VarInfo>())
{
if (varInfo->CheckHasStorageFlag(StorageFlag::Option) &&
!m_ir->m_symbols.GetAttribute(uid, "no_specialization"))
{
varInfo->m_specializationId = specializationId++;
}
}
}

// little check utility
static void CheckHasOneFoldedDimensionOrThrow(const ArrayDimensions& dims, string_view callSite)
{
Expand Down
8 changes: 6 additions & 2 deletions src/AzslcBackend.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ namespace AZ::ShaderCompiler
int m_rootConstantsMaxSize = std::numeric_limits<int>::max(); //!< Indicates the number of root constants to be allowed, 0 means root constants not enabled
Packing::Layout m_packConstantBuffers = Packing::Layout::DirectXPacking; //!< Packing standard for constant buffers (uniform)
Packing::Layout m_packDataBuffers = Packing::Layout::CStylePacking; //!< Packing standard for data buffer views
bool m_useSpecializationConstantsForOptions = false; //!< Use specialization constants for shader options
};

struct Binding
Expand Down Expand Up @@ -148,6 +149,9 @@ namespace AZ::ShaderCompiler
//! Get HLSL form of in/out modifiers
static const char* GetInputModifier(const TypeQualifiers& typeQualifier);

//! Get the initialization clause as a string. Returns an empty string if it doesn't have any initialization.
string GetInitializerClause(const AZ::ShaderCompiler::VarInfo* varInfo) const;

//! Fabricate a HLSL snippet that represents the type stored in typeInfo. Relevant options relate to matrix qualifiers.
//! \param banned is the Flag you can setup to list a collection of type qualifiers you don't want to reproduce.
string GetExtendedTypeInfo(const ExtendedTypeInfo& extTypeInfo, const Options& options, Modifiers banned, std::function<string(const TypeRefInfo&)> translator) const;
Expand All @@ -163,8 +167,6 @@ namespace AZ::ShaderCompiler

string GetTranspiledTokens(misc::Interval interval) const;

string GetInitializerClause(const AZ::ShaderCompiler::VarInfo* varInfo) const;

uint32_t GetNumberOf32BitConstants(const Options& options, const IdentifierUID& uid) const;

RootSigDesc BuildSignatureDescription(const Options& options, int num32BitConst) const;
Expand All @@ -177,6 +179,8 @@ namespace AZ::ShaderCompiler

Json::Value GetVariantList(const Options& options, bool includeEmpty = false) const;

void SetupOptionsSpecializationId(const Options& options) const;

IntermediateRepresentation* m_ir;
TokenStream* m_tokens;

Expand Down
40 changes: 28 additions & 12 deletions src/AzslcEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ namespace AZ::ShaderCompiler
const RootSigDesc rootSig = BuildSignatureDescription(options, numOf32bitConst);

SetupScopeMigrations(options);
SetupOptionsSpecializationId(options);

// Emit global attributes
for (const auto& attr : m_ir->m_symbols.GetGlobalAttributeList())
Expand Down Expand Up @@ -379,17 +380,24 @@ namespace AZ::ShaderCompiler
{
assert(m_ir->GetKind(symbolUid) == Kind::Variable);
assert(IsTopLevelThroughTranslation(symbolUid));

auto* varInfo = m_ir->GetSymbolSubAs<VarInfo>(symbolUid.GetName());

EmitGetShaderKeyFunctionDeclaration(symbolUid, varInfo->GetTypeRefInfo());
m_out << ";\n\n";
if (options.m_useSpecializationConstantsForOptions && varInfo->m_specializationId >= 0)
{
m_out << GetPlatformEmitter().GetSpecializationConstant(*this, symbolUid, options);
}
else
{

m_out << "#if defined(" + JoinAllNestedNamesWithUnderscore(symbolUid.m_name) + "_OPTION_DEF)\n";
EmitVariableDeclaration(*varInfo, symbolUid, options, VarDeclHasFlag(VarDeclHas::OptionDefine));
m_out << "_OPTION_DEF ;\n#else\n";
EmitVariableDeclaration(*varInfo, symbolUid, options, VarDeclHasFlag(VarDeclHas::OptionDefine) | VarDeclHas::Initializer);
m_out << ";\n#endif\n";
EmitGetShaderKeyFunctionDeclaration(symbolUid, varInfo->GetTypeRefInfo());
m_out << ";\n\n";

m_out << "#if defined(" + JoinAllNestedNamesWithUnderscore(symbolUid.m_name) + "_OPTION_DEF)\n";
EmitVariableDeclaration(*varInfo, symbolUid, options, VarDeclHasFlag(VarDeclHas::OptionDefine));
m_out << "_OPTION_DEF ;\n#else\n";
EmitVariableDeclaration(*varInfo, symbolUid, options, VarDeclHasFlag(VarDeclHas::OptionDefine) | VarDeclHas::Initializer);
m_out << ";\n#endif\n";
}
}

void CodeEmitter::EmitShaderVariantOptionGetters(const Options& options) const
Expand All @@ -413,15 +421,18 @@ namespace AZ::ShaderCompiler
m_out << "// Generated code: ShaderVariantOptions fallback value getters:\n";

auto shaderOptions = GetVariantList(options, true);
auto shaderOptionIndex = 0;

for (const auto& [uid, varInfo] : symbols)
for (uint32_t shaderOptionIndex = 0; shaderOptionIndex < symbols.size(); ++shaderOptionIndex)
{
const auto& [uid, varInfo] = symbols[shaderOptionIndex];
if (options.m_useSpecializationConstantsForOptions && varInfo->m_specializationId >= 0)
{
continue;
}

const auto keySizeInBits = shaderOptions["ShaderOptions"][shaderOptionIndex]["keySize"].asUInt();
const auto keyOffsetBits = shaderOptions["ShaderOptions"][shaderOptionIndex]["keyOffset"].asUInt();
const auto defaultValue = shaderOptions["ShaderOptions"][shaderOptionIndex]["defaultValue"].asString();

shaderOptionIndex++;
EmitGetShaderKeyFunction(m_shaderVariantFallbackUid, uid, keySizeInBits, keyOffsetBits, defaultValue, varInfo->GetTypeRefInfo());
}
}
Expand Down Expand Up @@ -693,6 +704,11 @@ namespace AZ::ShaderCompiler
// Reserved for integer type option variables. Do not re-emit
outstream << "// original attribute: [[" << attrInfo << "]]\n ";
}
else if (attrInfo.m_attribute == "no_specialization")
{
// Reserved for avoiding specialization of a shader option. Do not re-emit
outstream << "// original attribute: [[" << attrInfo << "]]\n ";
}

else
{
Expand Down
1 change: 1 addition & 0 deletions src/AzslcKindInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ namespace AZ::ShaderCompiler
optional<SamplerStateDesc> m_samplerState;
ExtendedTypeInfo m_typeInfoExt;
int m_estimatedCostImpact = -1; //!< Cached value calculated by AnalyzeOptionRanks
int m_specializationId= -1; //< id of the specialization. -1 means no specialization.
};

// VarInfo methods definitions
Expand Down
4 changes: 4 additions & 0 deletions src/AzslcMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,9 @@ int main(int argc, const char* argv[])
int maxSpaces = std::numeric_limits<int>::max();
auto maxSpacesOpt = cli.add_option("--max-spaces", maxSpaces, "Will choose register spaces that do not extend past this limit.");

bool useSpecializationConstants = false;
cli.add_flag("--sc-options", useSpecializationConstants, "Use specialization constants for shader options.");

std::array<bool, Warn::EndEnumeratorSentinel_> warningOpts;
for (const auto e : Warn::Enumerate{})
{
Expand Down Expand Up @@ -575,6 +578,7 @@ int main(int argc, const char* argv[])
emitOptions.m_emitRootSig = rootSig;
emitOptions.m_padRootConstantCB = padRootConst;
emitOptions.m_skipAlignmentValidation = noAlignmentValidation;
emitOptions.m_useSpecializationConstantsForOptions = useSpecializationConstants;

if (*rootConstOpt)
{
Expand Down
5 changes: 5 additions & 0 deletions src/AzslcPlatformEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,9 @@ namespace AZ::ShaderCompiler
{
return size;
}

string PlatformEmitter::GetSpecializationConstant(const CodeEmitter& codeEmitter, const IdentifierUID& symbol, const Options& options) const
{
return "";
}
}
3 changes: 3 additions & 0 deletions src/AzslcPlatformEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,8 @@ namespace AZ::ShaderCompiler
virtual uint32_t AlignRootConstants(uint32_t size) const;

virtual bool RequiresUniqueSpaceForUnboundedArrays() const {return false;}

[[nodiscard]]
virtual string GetSpecializationConstant(const CodeEmitter& codeEmitter, const IdentifierUID& symbol, const Options& options) const;
};
}
1 change: 1 addition & 0 deletions src/AzslcReflection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,7 @@ namespace AZ::ShaderCompiler
void CodeReflection::DumpVariantList(const Options& options) const
{
AnalyzeOptionRanks();
SetupOptionsSpecializationId(options);
m_out << GetVariantList(options);
m_out << "\n";
}
Expand Down
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ target_include_directories(
PRIVATE
${PROJECT_SOURCE_DIR}
${PROJECT_SOURCE_DIR}/external
${PROJECT_SOURCE_DIR}/../Platform/Common/src
${ANTLR4CPP_INCLUDE_DIRS}
${MPARK_VARIANT_INCLUDE_DIRS}
${TINY_OPTIONAL_INCLUDE_DIRS}
Expand Down
Loading
Loading