Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[naga] Use const ctx instead of global ctx for const resolution #6935

Draft
wants to merge 3 commits into
base: trunk
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions naga/src/front/wgsl/lower/construction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ impl<'source> Lowerer<'source, '_> {
}
ast::ConstructorType::PartialVector { size } => Constructor::PartialVector { size },
ast::ConstructorType::Vector { size, ty, ty_span } => {
let ty = self.resolve_ast_type(ty, &mut ctx.as_global())?;
let ty = self.resolve_ast_type(ty, &mut ctx.as_const())?;
let scalar = match ctx.module.types[ty].inner {
crate::TypeInner::Scalar(sc) => sc,
_ => return Err(Error::UnknownScalarType(ty_span)),
Expand All @@ -596,7 +596,7 @@ impl<'source> Lowerer<'source, '_> {
ty,
ty_span,
} => {
let ty = self.resolve_ast_type(ty, &mut ctx.as_global())?;
let ty = self.resolve_ast_type(ty, &mut ctx.as_const())?;
let scalar = match ctx.module.types[ty].inner {
crate::TypeInner::Scalar(sc) => sc,
_ => return Err(Error::UnknownScalarType(ty_span)),
Expand All @@ -613,8 +613,8 @@ impl<'source> Lowerer<'source, '_> {
}
ast::ConstructorType::PartialArray => Constructor::PartialArray,
ast::ConstructorType::Array { base, size } => {
let base = self.resolve_ast_type(base, &mut ctx.as_global())?;
let size = self.array_size(size, &mut ctx.as_global())?;
let base = self.resolve_ast_type(base, &mut ctx.as_const())?;
let size = self.array_size(size, &mut ctx.as_const())?;

self.layouter.update(ctx.module.to_ctx()).unwrap();
let stride = self.layouter[base].to_stride();
Expand Down
94 changes: 48 additions & 46 deletions naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
}
}

#[allow(dead_code)]
fn as_global(&mut self) -> GlobalContext<'a, '_, '_> {
GlobalContext {
ast_expressions: self.ast_expressions,
Expand Down Expand Up @@ -453,29 +454,28 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
.map_err(|e| Error::ConstantEvaluatorError(e.into(), span))
}

fn const_access(&self, handle: Handle<crate::Expression>) -> Option<u32> {
fn const_eval_expr_to_u32(
&self,
handle: Handle<crate::Expression>,
) -> Result<u32, crate::proc::U32EvalError> {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => {
if !ctx.local_expression_kind_tracker.is_const(handle) {
return None;
return Err(crate::proc::U32EvalError::NonConst);
}

self.module
.to_ctx()
.eval_expr_to_u32_from(handle, &ctx.function.expressions)
.ok()
}
ExpressionContextType::Constant(Some(ref ctx)) => {
assert!(ctx.local_expression_kind_tracker.is_const(handle));
self.module
.to_ctx()
.eval_expr_to_u32_from(handle, &ctx.function.expressions)
.ok()
}
ExpressionContextType::Constant(None) => {
self.module.to_ctx().eval_expr_to_u32(handle).ok()
}
ExpressionContextType::Override => None,
ExpressionContextType::Constant(None) => self.module.to_ctx().eval_expr_to_u32(handle),
ExpressionContextType::Override => Err(crate::proc::U32EvalError::NonConst),
}
}

Expand Down Expand Up @@ -1057,7 +1057,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
ast::GlobalDeclKind::Var(ref v) => {
let explicit_ty =
v.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx))
v.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx.as_const()))
.transpose()?;

let (ty, initializer) =
Expand Down Expand Up @@ -1093,7 +1093,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let ty;
if let Some(explicit_ty) = c.ty {
let explicit_ty =
self.resolve_ast_type(explicit_ty, &mut ectx.as_global())?;
self.resolve_ast_type(explicit_ty, &mut ectx.as_const())?;
let explicit_ty_res = crate::proc::TypeResolution::Handle(explicit_ty);
init = ectx
.try_automatic_conversions(init, &explicit_ty_res, c.name.span)
Expand Down Expand Up @@ -1125,7 +1125,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
ast::GlobalDeclKind::Override(ref o) => {
let explicit_ty =
o.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx))
o.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx.as_const()))
.transpose()?;

let mut ectx = ctx.as_override();
Expand Down Expand Up @@ -1167,7 +1167,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let ty = self.resolve_named_ast_type(
alias.ty,
Some(alias.name.name.to_string()),
&mut ctx,
&mut ctx.as_const(),
)?;
ctx.globals
.insert(alias.name.name, LoweredGlobalDecl::Type(ty));
Expand Down Expand Up @@ -1254,7 +1254,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.iter()
.enumerate()
.map(|(i, arg)| -> Result<_, Error<'_>> {
let ty = self.resolve_ast_type(arg.ty, ctx)?;
let ty = self.resolve_ast_type(arg.ty, &mut ctx.as_const())?;
let expr = expressions
.append(crate::Expression::FunctionArgument(i as u32), arg.name.span);
local_table.insert(arg.handle, Declared::Runtime(Typed::Plain(expr)));
Expand All @@ -1273,7 +1273,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.result
.as_ref()
.map(|res| -> Result<_, Error<'_>> {
let ty = self.resolve_ast_type(res.ty, ctx)?;
let ty = self.resolve_ast_type(res.ty, &mut ctx.as_const())?;
Ok(crate::FunctionResult {
ty,
binding: self.binding(&res.binding, ty, ctx)?,
Expand Down Expand Up @@ -1430,9 +1430,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
// optimization.
ctx.local_expression_kind_tracker.force_non_const(value);

let explicit_ty =
l.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_global()))
.transpose()?;
let explicit_ty = l
.ty
.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_const(block, &mut emitter)))
.transpose()?;

if let Some(ty) = explicit_ty {
let mut ctx = ctx.as_expression(block, &mut emitter);
Expand All @@ -1459,12 +1460,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
return Ok(());
}
ast::LocalDecl::Var(ref v) => {
let explicit_ty =
v.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx.as_global()))
.transpose()?;

let mut emitter = Emitter::default();
emitter.start(&ctx.function.expressions);

let explicit_ty =
v.ty.map(|ast| {
self.resolve_ast_type(ast, &mut ctx.as_const(block, &mut emitter))
})
.transpose()?;

let mut ectx = ctx.as_expression(block, &mut emitter);

let ty;
Expand Down Expand Up @@ -1555,7 +1559,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {

if let Some(explicit_ty) = c.ty {
let explicit_ty =
self.resolve_ast_type(explicit_ty, &mut ectx.as_global())?;
self.resolve_ast_type(explicit_ty, &mut ectx.as_const())?;
let explicit_ty_res = crate::proc::TypeResolution::Handle(explicit_ty);
init = ectx
.try_automatic_conversions(init, &explicit_ty_res, c.name.span)
Expand Down Expand Up @@ -1614,7 +1618,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {

let uint =
resolve_inner!(ectx, selector).scalar_kind() == Some(crate::ScalarKind::Uint);
block.extend(emitter.finish(&ctx.function.expressions));

let cases = cases
.iter()
Expand All @@ -1623,8 +1626,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
value: match case.value {
ast::SwitchValue::Expr(expr) => {
let span = ctx.ast_expressions.get_span(expr);
let expr =
self.expression(expr, &mut ctx.as_global().as_const())?;
let expr = self
.expression(expr, &mut ctx.as_const(block, &mut emitter))?;
match ctx.module.to_ctx().eval_expr_to_literal(expr) {
Some(crate::Literal::I32(value)) if !uint => {
crate::SwitchValue::I32(value)
Expand All @@ -1645,6 +1648,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
})
.collect::<Result<_, _>>()?;

block.extend(emitter.finish(&ctx.function.expressions));

crate::Statement::Switch { selector, cases }
}
ast::StatementKind::Loop {
Expand Down Expand Up @@ -2010,7 +2015,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
}

lowered_base.map(|base| match ctx.const_access(index) {
lowered_base.map(|base| match ctx.const_eval_expr_to_u32(index).ok() {
Some(index) => crate::Expression::AccessIndex { base, index },
None => crate::Expression::Access { base, index },
})
Expand Down Expand Up @@ -2088,7 +2093,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
ast::Expression::Bitcast { expr, to, ty_span } => {
let expr = self.expression(expr, ctx)?;
let to_resolved = self.resolve_ast_type(to, &mut ctx.as_global())?;
let to_resolved = self.resolve_ast_type(to, &mut ctx.as_const())?;

let element_scalar = match ctx.module.types[to_resolved].inner {
crate::TypeInner::Scalar(scalar) => scalar,
Expand Down Expand Up @@ -2933,7 +2938,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {

let offset = args
.next()
.map(|arg| self.expression(arg, &mut ctx.as_global().as_const()))
.map(|arg| self.expression(arg, &mut ctx.as_const()))
.ok()
.transpose()?;

Expand Down Expand Up @@ -3036,7 +3041,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let mut members = Vec::with_capacity(s.members.len());

for member in s.members.iter() {
let ty = self.resolve_ast_type(member.ty, ctx)?;
let ty = self.resolve_ast_type(member.ty, &mut ctx.as_const())?;

self.layouter.update(ctx.module.to_ctx()).unwrap();

Expand Down Expand Up @@ -3123,25 +3128,22 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
fn array_size(
&mut self,
size: ast::ArraySize<'source>,
ctx: &mut GlobalContext<'source, '_, '_>,
ctx: &mut ExpressionContext<'source, '_, '_>,
) -> Result<crate::ArraySize, Error<'source>> {
Ok(match size {
ast::ArraySize::Constant(expr) => {
let span = ctx.ast_expressions.get_span(expr);
let const_expr = self.expression(expr, &mut ctx.as_const());
match const_expr {
Ok(value) => {
let len =
ctx.module.to_ctx().eval_expr_to_u32(value).map_err(
|err| match err {
crate::proc::U32EvalError::NonConst => {
Error::ExpectedConstExprConcreteIntegerScalar(span)
}
crate::proc::U32EvalError::Negative => {
Error::ExpectedPositiveArrayLength(span)
}
},
)?;
let len = ctx.const_eval_expr_to_u32(value).map_err(|err| match err {
crate::proc::U32EvalError::NonConst => {
Error::ExpectedConstExprConcreteIntegerScalar(span)
}
crate::proc::U32EvalError::Negative => {
Error::ExpectedPositiveArrayLength(span)
}
})?;
let size =
NonZeroU32::new(len).ok_or(Error::ExpectedPositiveArrayLength(span))?;
crate::ArraySize::Constant(size)
Expand All @@ -3152,7 +3154,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
crate::proc::ConstantEvaluatorError::OverrideExpr => {
crate::ArraySize::Pending(self.array_size_override(
expr,
&mut ctx.as_override(),
&mut ctx.as_global().as_override(),
span,
)?)
}
Expand Down Expand Up @@ -3204,7 +3206,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
&mut self,
handle: Handle<ast::Type<'source>>,
name: Option<String>,
ctx: &mut GlobalContext<'source, '_, '_>,
ctx: &mut ExpressionContext<'source, '_, '_>,
) -> Result<Handle<crate::Type>, Error<'source>> {
let inner = match ctx.types[handle] {
ast::Type::Scalar(scalar) => scalar.to_inner_scalar(),
Expand Down Expand Up @@ -3242,7 +3244,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
crate::TypeInner::Pointer { base, space }
}
ast::Type::Array { base, size } => {
let base = self.resolve_ast_type(base, ctx)?;
let base = self.resolve_ast_type(base, &mut ctx.as_const())?;
let size = self.array_size(size, ctx)?;

self.layouter.update(ctx.module.to_ctx()).unwrap();
Expand Down Expand Up @@ -3282,14 +3284,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
};

Ok(ctx.ensure_type_exists(name, inner))
Ok(ctx.as_global().ensure_type_exists(name, inner))
}

/// Return a Naga `Handle<Type>` representing the front-end type `handle`.
fn resolve_ast_type(
&mut self,
handle: Handle<ast::Type<'source>>,
ctx: &mut GlobalContext<'source, '_, '_>,
ctx: &mut ExpressionContext<'source, '_, '_>,
) -> Result<Handle<crate::Type>, Error<'source>> {
self.resolve_named_ast_type(handle, None, ctx)
}
Expand Down
Loading