Skip to content

Commit

Permalink
[spv-in] add support for specialization constants
Browse files Browse the repository at this point in the history
teoxoy authored and jimblandy committed Apr 11, 2024
1 parent b7519bb commit 9df6819
Showing 10 changed files with 1,610 additions and 93 deletions.
6 changes: 4 additions & 2 deletions naga/src/front/spv/error.rs
Original file line number Diff line number Diff line change
@@ -118,8 +118,8 @@ pub enum Error {
ControlFlowGraphCycle(crate::front::spv::BlockId),
#[error("recursive function call %{0}")]
FunctionCallCycle(spirv::Word),
#[error("invalid array size {0:?}")]
InvalidArraySize(Handle<crate::Constant>),
#[error("invalid array size %{0}")]
InvalidArraySize(spirv::Word),
#[error("invalid barrier scope %{0}")]
InvalidBarrierScope(spirv::Word),
#[error("invalid barrier memory semantics %{0}")]
@@ -130,6 +130,8 @@ pub enum Error {
come from a binding)"
)]
NonBindingArrayOfImageOrSamplers,
#[error("naga only supports specialization constant IDs up to 65535 but was given {0}")]
SpecIdTooHigh(u32),
}

impl Error {
7 changes: 5 additions & 2 deletions naga/src/front/spv/function.rs
Original file line number Diff line number Diff line change
@@ -59,8 +59,11 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> {
})
},
local_variables: Arena::new(),
expressions: self
.make_expression_storage(&module.global_variables, &module.constants),
expressions: self.make_expression_storage(
&module.global_variables,
&module.constants,
&module.overrides,
),
named_expressions: crate::NamedExpressions::default(),
body: crate::Block::new(),
}
13 changes: 8 additions & 5 deletions naga/src/front/spv/image.rs
Original file line number Diff line number Diff line change
@@ -507,11 +507,14 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> {
}
spirv::ImageOperands::CONST_OFFSET => {
let offset_constant = self.next()?;
let offset_handle = self.lookup_constant.lookup(offset_constant)?.handle;
let offset_handle = ctx.global_expressions.append(
crate::Expression::Constant(offset_handle),
Default::default(),
);
let offset_expr = self
.lookup_constant
.lookup(offset_constant)?
.inner
.to_expr();
let offset_handle = ctx
.global_expressions
.append(offset_expr, Default::default());
offset = Some(offset_handle);
words_left -= 1;
}
172 changes: 88 additions & 84 deletions naga/src/front/spv/mod.rs
Original file line number Diff line number Diff line change
@@ -196,6 +196,7 @@ struct Decoration {
location: Option<spirv::Word>,
desc_set: Option<spirv::Word>,
desc_index: Option<spirv::Word>,
specialization_constant_id: Option<spirv::Word>,
storage_buffer: bool,
offset: Option<spirv::Word>,
array_stride: Option<NonZeroU32>,
@@ -277,9 +278,24 @@ struct LookupType {
base_id: Option<spirv::Word>,
}

#[derive(Debug)]
enum Constant {
Constant(Handle<crate::Constant>),
Override(Handle<crate::Override>),
}

impl Constant {
const fn to_expr(&self) -> crate::Expression {
match *self {
Self::Constant(c) => crate::Expression::Constant(c),
Self::Override(o) => crate::Expression::Override(o),
}
}
}

#[derive(Debug)]
struct LookupConstant {
handle: Handle<crate::Constant>,
inner: Constant,
type_id: spirv::Word,
}

@@ -751,6 +767,9 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
spirv::Decoration::RowMajor => {
dec.matrix_major = Some(Majority::Row);
}
spirv::Decoration::SpecId => {
dec.specialization_constant_id = Some(self.next()?);
}
other => {
log::warn!("Unknown decoration {:?}", other);
for _ in base_words + 1..inst.wc {
@@ -1385,10 +1404,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
inst.expect(5)?;
let init_id = self.next()?;
let lconst = self.lookup_constant.lookup(init_id)?;
Some(
ctx.expressions
.append(crate::Expression::Constant(lconst.handle), span),
)
Some(ctx.expressions.append(lconst.inner.to_expr(), span))
} else {
None
};
@@ -3642,9 +3658,9 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
let semantics_const = self.lookup_constant.lookup(semantics_id)?;

let exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle)
let exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner)
.ok_or(Error::InvalidBarrierScope(exec_scope_id))?;
let semantics = resolve_constant(ctx.gctx(), semantics_const.handle)
let semantics = resolve_constant(ctx.gctx(), &semantics_const.inner)
.ok_or(Error::InvalidBarrierMemorySemantics(semantics_id))?;

if exec_scope == spirv::Scope::Workgroup as u32 {
@@ -3705,6 +3721,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
&mut self,
globals: &Arena<crate::GlobalVariable>,
constants: &Arena<crate::Constant>,
overrides: &Arena<crate::Override>,
) -> Arena<crate::Expression> {
let mut expressions = Arena::new();
#[allow(clippy::panic)]
@@ -3729,8 +3746,11 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
}
// register constants
for (&id, con) in self.lookup_constant.iter() {
let span = constants.get_span(con.handle);
let handle = expressions.append(crate::Expression::Constant(con.handle), span);
let (expr, span) = match con.inner {
Constant::Constant(c) => (crate::Expression::Constant(c), constants.get_span(c)),
Constant::Override(o) => (crate::Expression::Override(o), overrides.get_span(o)),
};
let handle = expressions.append(expr, span);
self.lookup_expression.insert(
id,
LookupExpression {
@@ -3935,11 +3955,17 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
Op::TypeImage => self.parse_type_image(inst, &mut module),
Op::TypeSampledImage => self.parse_type_sampled_image(inst),
Op::TypeSampler => self.parse_type_sampler(inst, &mut module),
Op::Constant => self.parse_constant(inst, &mut module),
Op::ConstantComposite => self.parse_composite_constant(inst, &mut module),
Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module),
Op::ConstantComposite | Op::SpecConstantComposite => {
self.parse_composite_constant(inst, &mut module)
}
Op::ConstantNull | Op::Undef => self.parse_null_constant(inst, &mut module),
Op::ConstantTrue => self.parse_bool_constant(inst, true, &mut module),
Op::ConstantFalse => self.parse_bool_constant(inst, false, &mut module),
Op::ConstantTrue | Op::SpecConstantTrue => {
self.parse_bool_constant(inst, true, &mut module)
}
Op::ConstantFalse | Op::SpecConstantFalse => {
self.parse_bool_constant(inst, false, &mut module)
}
Op::Variable => self.parse_global_variable(inst, &mut module),
Op::Function => {
self.switch(ModuleState::Function, inst.op)?;
@@ -4496,9 +4522,9 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let length_id = self.next()?;
let length_const = self.lookup_constant.lookup(length_id)?;

let size = resolve_constant(module.to_ctx(), length_const.handle)
let size = resolve_constant(module.to_ctx(), &length_const.inner)
.and_then(NonZeroU32::new)
.ok_or(Error::InvalidArraySize(length_const.handle))?;
.ok_or(Error::InvalidArraySize(length_id))?;

let decor = self.future_decor.remove(&id).unwrap_or_default();
let base = self.lookup_type.lookup(type_id)?.handle;
@@ -4911,28 +4937,13 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
_ => return Err(Error::UnsupportedType(type_lookup.handle)),
};

let decor = self.future_decor.remove(&id).unwrap_or_default();

let span = self.span_from_with_op(start);

let init = module
.global_expressions
.append(crate::Expression::Literal(literal), span);
self.lookup_constant.insert(
id,
LookupConstant {
handle: module.constants.append(
crate::Constant {
name: decor.name,
ty,
init,
},
span,
),
type_id,
},
);
Ok(())

self.insert_parsed_constant(module, id, type_id, ty, init, span)
}

fn parse_composite_constant(
@@ -4957,32 +4968,17 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let constant = self.lookup_constant.lookup(component_id)?;
let expr = module
.global_expressions
.append(crate::Expression::Constant(constant.handle), span);
.append(constant.inner.to_expr(), span);
components.push(expr);
}

let decor = self.future_decor.remove(&id).unwrap_or_default();

let span = self.span_from_with_op(start);

let init = module
.global_expressions
.append(crate::Expression::Compose { ty, components }, span);
self.lookup_constant.insert(
id,
LookupConstant {
handle: module.constants.append(
crate::Constant {
name: decor.name,
ty,
init,
},
span,
),
type_id,
},
);
Ok(())

self.insert_parsed_constant(module, id, type_id, ty, init, span)
}

fn parse_null_constant(
@@ -5000,22 +4996,11 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let type_lookup = self.lookup_type.lookup(type_id)?;
let ty = type_lookup.handle;

let decor = self.future_decor.remove(&id).unwrap_or_default();

let init = module
.global_expressions
.append(crate::Expression::ZeroValue(ty), span);
let handle = module.constants.append(
crate::Constant {
name: decor.name,
ty,
init,
},
span,
);
self.lookup_constant
.insert(id, LookupConstant { handle, type_id });
Ok(())

self.insert_parsed_constant(module, id, type_id, ty, init, span)
}

fn parse_bool_constant(
@@ -5034,26 +5019,44 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let type_lookup = self.lookup_type.lookup(type_id)?;
let ty = type_lookup.handle;

let decor = self.future_decor.remove(&id).unwrap_or_default();

let init = module.global_expressions.append(
crate::Expression::Literal(crate::Literal::Bool(value)),
span,
);
self.lookup_constant.insert(
id,
LookupConstant {
handle: module.constants.append(
crate::Constant {
name: decor.name,
ty,
init,
},
span,
),
type_id,
},
);

self.insert_parsed_constant(module, id, type_id, ty, init, span)
}

fn insert_parsed_constant(
&mut self,
module: &mut crate::Module,
id: u32,
type_id: u32,
ty: Handle<crate::Type>,
init: Handle<crate::Expression>,
span: crate::Span,
) -> Result<(), Error> {
let decor = self.future_decor.remove(&id).unwrap_or_default();

let inner = if let Some(id) = decor.specialization_constant_id {
let o = crate::Override {
name: decor.name,
id: Some(id.try_into().map_err(|_| Error::SpecIdTooHigh(id))?),
ty,
init: Some(init),
};
Constant::Override(module.overrides.append(o, span))
} else {
let c = crate::Constant {
name: decor.name,
ty,
init,
};
Constant::Constant(module.constants.append(c, span))
};

self.lookup_constant
.insert(id, LookupConstant { inner, type_id });
Ok(())
}

@@ -5076,7 +5079,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let lconst = self.lookup_constant.lookup(init_id)?;
let expr = module
.global_expressions
.append(crate::Expression::Constant(lconst.handle), span);
.append(lconst.inner.to_expr(), span);
Some(expr)
} else {
None
@@ -5291,10 +5294,11 @@ fn make_index_literal(
Ok(expr)
}

fn resolve_constant(
gctx: crate::proc::GlobalCtx,
constant: Handle<crate::Constant>,
) -> Option<u32> {
fn resolve_constant(gctx: crate::proc::GlobalCtx, constant: &Constant) -> Option<u32> {
let constant = match *constant {
Constant::Constant(constant) => constant,
Constant::Override(_) => return None,
};
match gctx.global_expressions[gctx.constants[constant].init] {
crate::Expression::Literal(crate::Literal::U32(id)) => Some(id),
crate::Expression::Literal(crate::Literal::I32(id)) => Some(id as u32),
Binary file added naga/tests/in/spv/spec-constants.spv
Binary file not shown.
143 changes: 143 additions & 0 deletions naga/tests/in/spv/spec-constants.spvasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
; SPIR-V
; Version: 1.0
; Generator: Google Shaderc over Glslang; 11
; Bound: 74
; Schema: 0
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Vertex %main "main" %v_Uv %Vertex_Uv %Vertex_Position %__0 %Vertex_Normal
OpSource GLSL 450
OpSourceExtension "GL_GOOGLE_cpp_style_line_directive"
OpSourceExtension "GL_GOOGLE_include_directive"
OpName %main "main"
OpName %test_constant "test_constant"
OpName %TEST_CONSTANT "TEST_CONSTANT"
OpName %TEST_CONSTANT_TRUE "TEST_CONSTANT_TRUE"
OpName %TEST_CONSTANT_FALSE "TEST_CONSTANT_FALSE"
OpName %v_Uv "v_Uv"
OpName %Vertex_Uv "Vertex_Uv"
OpName %position "position"
OpName %Vertex_Position "Vertex_Position"
OpName %Sprite_size "Sprite_size"
OpMemberName %Sprite_size 0 "size"
OpName %_ ""
OpName %gl_PerVertex "gl_PerVertex"
OpMemberName %gl_PerVertex 0 "gl_Position"
OpMemberName %gl_PerVertex 1 "gl_PointSize"
OpMemberName %gl_PerVertex 2 "gl_ClipDistance"
OpMemberName %gl_PerVertex 3 "gl_CullDistance"
OpName %__0 ""
OpName %Camera "Camera"
OpMemberName %Camera 0 "ViewProj"
OpName %__1 ""
OpName %Transform "Transform"
OpMemberName %Transform 0 "Model"
OpName %__2 ""
OpName %Vertex_Normal "Vertex_Normal"
OpDecorate %TEST_CONSTANT SpecId 0
OpDecorate %TEST_CONSTANT_TRUE SpecId 1
OpDecorate %TEST_CONSTANT_FALSE SpecId 2
OpDecorate %v_Uv Location 0
OpDecorate %Vertex_Uv Location 2
OpDecorate %Vertex_Position Location 0
OpMemberDecorate %Sprite_size 0 Offset 0
OpDecorate %Sprite_size Block
OpDecorate %_ DescriptorSet 2
OpDecorate %_ Binding 1
OpMemberDecorate %gl_PerVertex 0 BuiltIn Position
OpMemberDecorate %gl_PerVertex 1 BuiltIn PointSize
OpMemberDecorate %gl_PerVertex 2 BuiltIn ClipDistance
OpMemberDecorate %gl_PerVertex 3 BuiltIn CullDistance
OpDecorate %gl_PerVertex Block
OpMemberDecorate %Camera 0 ColMajor
OpMemberDecorate %Camera 0 Offset 0
OpMemberDecorate %Camera 0 MatrixStride 16
OpDecorate %Camera Block
OpDecorate %__1 DescriptorSet 0
OpDecorate %__1 Binding 0
OpMemberDecorate %Transform 0 ColMajor
OpMemberDecorate %Transform 0 Offset 0
OpMemberDecorate %Transform 0 MatrixStride 16
OpDecorate %Transform Block
OpDecorate %__2 DescriptorSet 2
OpDecorate %__2 Binding 0
OpDecorate %Vertex_Normal Location 1
%void = OpTypeVoid
%3 = OpTypeFunction %void
%float = OpTypeFloat 32
%_ptr_Function_float = OpTypePointer Function %float
%TEST_CONSTANT = OpSpecConstant %float 64
%bool = OpTypeBool
%TEST_CONSTANT_TRUE = OpSpecConstantTrue %bool
%float_0 = OpConstant %float 0
%float_1 = OpConstant %float 1
%TEST_CONSTANT_FALSE = OpSpecConstantFalse %bool
%v2float = OpTypeVector %float 2
%_ptr_Output_v2float = OpTypePointer Output %v2float
%v_Uv = OpVariable %_ptr_Output_v2float Output
%_ptr_Input_v2float = OpTypePointer Input %v2float
%Vertex_Uv = OpVariable %_ptr_Input_v2float Input
%v3float = OpTypeVector %float 3
%_ptr_Function_v3float = OpTypePointer Function %v3float
%_ptr_Input_v3float = OpTypePointer Input %v3float
%Vertex_Position = OpVariable %_ptr_Input_v3float Input
%Sprite_size = OpTypeStruct %v2float
%_ptr_Uniform_Sprite_size = OpTypePointer Uniform %Sprite_size
%_ = OpVariable %_ptr_Uniform_Sprite_size Uniform
%int = OpTypeInt 32 1
%int_0 = OpConstant %int 0
%_ptr_Uniform_v2float = OpTypePointer Uniform %v2float
%v4float = OpTypeVector %float 4
%uint = OpTypeInt 32 0
%uint_1 = OpConstant %uint 1
%_arr_float_uint_1 = OpTypeArray %float %uint_1
%gl_PerVertex = OpTypeStruct %v4float %float %_arr_float_uint_1 %_arr_float_uint_1
%_ptr_Output_gl_PerVertex = OpTypePointer Output %gl_PerVertex
%__0 = OpVariable %_ptr_Output_gl_PerVertex Output
%mat4v4float = OpTypeMatrix %v4float 4
%Camera = OpTypeStruct %mat4v4float
%_ptr_Uniform_Camera = OpTypePointer Uniform %Camera
%__1 = OpVariable %_ptr_Uniform_Camera Uniform
%_ptr_Uniform_mat4v4float = OpTypePointer Uniform %mat4v4float
%Transform = OpTypeStruct %mat4v4float
%_ptr_Uniform_Transform = OpTypePointer Uniform %Transform
%__2 = OpVariable %_ptr_Uniform_Transform Uniform
%_ptr_Output_v4float = OpTypePointer Output %v4float
%Vertex_Normal = OpVariable %_ptr_Input_v3float Input
%main = OpFunction %void None %3
%5 = OpLabel
%test_constant = OpVariable %_ptr_Function_float Function
%position = OpVariable %_ptr_Function_v3float Function
%14 = OpSelect %float %TEST_CONSTANT_TRUE %float_1 %float_0
%15 = OpFMul %float %TEST_CONSTANT %14
%17 = OpSelect %float %TEST_CONSTANT_FALSE %float_1 %float_0
%18 = OpFMul %float %15 %17
OpStore %test_constant %18
%24 = OpLoad %v2float %Vertex_Uv
OpStore %v_Uv %24
%30 = OpLoad %v3float %Vertex_Position
%37 = OpAccessChain %_ptr_Uniform_v2float %_ %int_0
%38 = OpLoad %v2float %37
%39 = OpCompositeExtract %float %38 0
%40 = OpCompositeExtract %float %38 1
%41 = OpCompositeConstruct %v3float %39 %40 %float_1
%42 = OpFMul %v3float %30 %41
OpStore %position %42
%55 = OpAccessChain %_ptr_Uniform_mat4v4float %__1 %int_0
%56 = OpLoad %mat4v4float %55
%60 = OpAccessChain %_ptr_Uniform_mat4v4float %__2 %int_0
%61 = OpLoad %mat4v4float %60
%62 = OpMatrixTimesMatrix %mat4v4float %56 %61
%63 = OpLoad %v3float %position
%64 = OpCompositeExtract %float %63 0
%65 = OpCompositeExtract %float %63 1
%66 = OpCompositeExtract %float %63 2
%67 = OpCompositeConstruct %v4float %64 %65 %66 %float_1
%68 = OpMatrixTimesVector %v4float %62 %67
%69 = OpLoad %float %test_constant
%70 = OpVectorTimesScalar %v4float %68 %69
%72 = OpAccessChain %_ptr_Output_v4float %__0 %int_0
OpStore %72 %70
OpReturn
OpFunctionEnd
31 changes: 31 additions & 0 deletions naga/tests/in/spv/spec-constants.vert
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#version 450

layout (constant_id = 0) const float TEST_CONSTANT = 64.0;
layout (constant_id = 1) const bool TEST_CONSTANT_TRUE = true;
layout (constant_id = 2) const bool TEST_CONSTANT_FALSE = false;
// layout (constant_id = 3) const vec2 TEST_CONSTANT_COMPOSITE = vec2(TEST_CONSTANT, 3.0);
// glslc error: 'constant_id' : can only be applied to a scalar

layout(location = 0) in vec3 Vertex_Position;
layout(location = 1) in vec3 Vertex_Normal;
layout(location = 2) in vec2 Vertex_Uv;

layout(location = 0) out vec2 v_Uv;

layout(set = 0, binding = 0) uniform Camera {
mat4 ViewProj;
};
layout(set = 2, binding = 0) uniform Transform {
mat4 Model;
};
layout(set = 2, binding = 1) uniform Sprite_size {
vec2 size;
};

void main() {
float test_constant = TEST_CONSTANT * float(TEST_CONSTANT_TRUE) * float(TEST_CONSTANT_FALSE)
;//* TEST_CONSTANT_COMPOSITE.x * TEST_CONSTANT_COMPOSITE.y;
v_Uv = Vertex_Uv;
vec3 position = Vertex_Position * vec3(size, 1.0);
gl_Position = ViewProj * Model * vec4(position, 1.0) * test_constant;
}
612 changes: 612 additions & 0 deletions naga/tests/out/ir/spec-constants.compact.ron

Large diffs are not rendered by default.

718 changes: 718 additions & 0 deletions naga/tests/out/ir/spec-constants.ron

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions naga/tests/snapshots.rs
Original file line number Diff line number Diff line change
@@ -956,6 +956,7 @@ fn convert_spv_all() {
Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
);
convert_spv("builtin-accessed-outside-entrypoint", true, Targets::WGSL);
convert_spv("spec-constants", true, Targets::IR);
}

#[cfg(feature = "glsl-in")]

0 comments on commit 9df6819

Please sign in to comment.