diff --git a/naga/src/front/wgsl/lower/construction.rs b/naga/src/front/wgsl/lower/construction.rs index e52d4776ab..930d7790f4 100644 --- a/naga/src/front/wgsl/lower/construction.rs +++ b/naga/src/front/wgsl/lower/construction.rs @@ -579,7 +579,7 @@ impl<'source> Lowerer<'source, '_> { } ast::ConstructorType::PartialVector { size } => Constructor::PartialVector { size }, ast::ConstructorType::Vector { size, ty, ty_span } => { - let ty = self.resolve_ast_type(ty, &mut ctx.as_global())?; + let ty = self.resolve_ast_type(ty, &mut ctx.as_const())?; let scalar = match ctx.module.types[ty].inner { crate::TypeInner::Scalar(sc) => sc, _ => return Err(Error::UnknownScalarType(ty_span)), @@ -596,7 +596,7 @@ impl<'source> Lowerer<'source, '_> { ty, ty_span, } => { - let ty = self.resolve_ast_type(ty, &mut ctx.as_global())?; + let ty = self.resolve_ast_type(ty, &mut ctx.as_const())?; let scalar = match ctx.module.types[ty].inner { crate::TypeInner::Scalar(sc) => sc, _ => return Err(Error::UnknownScalarType(ty_span)), @@ -613,8 +613,8 @@ impl<'source> Lowerer<'source, '_> { } ast::ConstructorType::PartialArray => Constructor::PartialArray, ast::ConstructorType::Array { base, size } => { - let base = self.resolve_ast_type(base, &mut ctx.as_global())?; - let size = self.array_size(size, &mut ctx.as_global())?; + let base = self.resolve_ast_type(base, &mut ctx.as_const())?; + let size = self.array_size(size, &mut ctx.as_const())?; self.layouter.update(ctx.module.to_ctx()).unwrap(); let stride = self.layouter[base].to_stride(); diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 24f521b176..238d8f67c2 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -237,6 +237,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> { } } + #[allow(dead_code)] fn as_global(&mut self) -> GlobalContext<'a, '_, '_> { GlobalContext { ast_expressions: self.ast_expressions, @@ -453,29 +454,28 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { .map_err(|e| Error::ConstantEvaluatorError(e.into(), span)) } - fn const_access(&self, handle: Handle) -> Option { + fn const_eval_expr_to_u32( + &self, + handle: Handle, + ) -> Result { match self.expr_type { ExpressionContextType::Runtime(ref ctx) => { if !ctx.local_expression_kind_tracker.is_const(handle) { - return None; + return Err(crate::proc::U32EvalError::NonConst); } self.module .to_ctx() .eval_expr_to_u32_from(handle, &ctx.function.expressions) - .ok() } ExpressionContextType::Constant(Some(ref ctx)) => { assert!(ctx.local_expression_kind_tracker.is_const(handle)); self.module .to_ctx() .eval_expr_to_u32_from(handle, &ctx.function.expressions) - .ok() } - ExpressionContextType::Constant(None) => { - self.module.to_ctx().eval_expr_to_u32(handle).ok() - } - ExpressionContextType::Override => None, + ExpressionContextType::Constant(None) => self.module.to_ctx().eval_expr_to_u32(handle), + ExpressionContextType::Override => Err(crate::proc::U32EvalError::NonConst), } } @@ -1057,7 +1057,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } ast::GlobalDeclKind::Var(ref v) => { let explicit_ty = - v.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx)) + v.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx.as_const())) .transpose()?; let (ty, initializer) = @@ -1093,7 +1093,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let ty; if let Some(explicit_ty) = c.ty { let explicit_ty = - self.resolve_ast_type(explicit_ty, &mut ectx.as_global())?; + self.resolve_ast_type(explicit_ty, &mut ectx.as_const())?; let explicit_ty_res = crate::proc::TypeResolution::Handle(explicit_ty); init = ectx .try_automatic_conversions(init, &explicit_ty_res, c.name.span) @@ -1125,7 +1125,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } ast::GlobalDeclKind::Override(ref o) => { let explicit_ty = - o.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx)) + o.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx.as_const())) .transpose()?; let mut ectx = ctx.as_override(); @@ -1167,7 +1167,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let ty = self.resolve_named_ast_type( alias.ty, Some(alias.name.name.to_string()), - &mut ctx, + &mut ctx.as_const(), )?; ctx.globals .insert(alias.name.name, LoweredGlobalDecl::Type(ty)); @@ -1254,7 +1254,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .iter() .enumerate() .map(|(i, arg)| -> Result<_, Error<'_>> { - let ty = self.resolve_ast_type(arg.ty, ctx)?; + let ty = self.resolve_ast_type(arg.ty, &mut ctx.as_const())?; let expr = expressions .append(crate::Expression::FunctionArgument(i as u32), arg.name.span); local_table.insert(arg.handle, Declared::Runtime(Typed::Plain(expr))); @@ -1273,7 +1273,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .result .as_ref() .map(|res| -> Result<_, Error<'_>> { - let ty = self.resolve_ast_type(res.ty, ctx)?; + let ty = self.resolve_ast_type(res.ty, &mut ctx.as_const())?; Ok(crate::FunctionResult { ty, binding: self.binding(&res.binding, ty, ctx)?, @@ -1430,9 +1430,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // optimization. ctx.local_expression_kind_tracker.force_non_const(value); - let explicit_ty = - l.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_global())) - .transpose()?; + let explicit_ty = l + .ty + .map(|ty| self.resolve_ast_type(ty, &mut ctx.as_const(block, &mut emitter))) + .transpose()?; if let Some(ty) = explicit_ty { let mut ctx = ctx.as_expression(block, &mut emitter); @@ -1459,12 +1460,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { return Ok(()); } ast::LocalDecl::Var(ref v) => { - let explicit_ty = - v.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx.as_global())) - .transpose()?; - let mut emitter = Emitter::default(); emitter.start(&ctx.function.expressions); + + let explicit_ty = + v.ty.map(|ast| { + self.resolve_ast_type(ast, &mut ctx.as_const(block, &mut emitter)) + }) + .transpose()?; + let mut ectx = ctx.as_expression(block, &mut emitter); let ty; @@ -1555,7 +1559,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { if let Some(explicit_ty) = c.ty { let explicit_ty = - self.resolve_ast_type(explicit_ty, &mut ectx.as_global())?; + self.resolve_ast_type(explicit_ty, &mut ectx.as_const())?; let explicit_ty_res = crate::proc::TypeResolution::Handle(explicit_ty); init = ectx .try_automatic_conversions(init, &explicit_ty_res, c.name.span) @@ -1614,7 +1618,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let uint = resolve_inner!(ectx, selector).scalar_kind() == Some(crate::ScalarKind::Uint); - block.extend(emitter.finish(&ctx.function.expressions)); let cases = cases .iter() @@ -1623,8 +1626,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { value: match case.value { ast::SwitchValue::Expr(expr) => { let span = ctx.ast_expressions.get_span(expr); - let expr = - self.expression(expr, &mut ctx.as_global().as_const())?; + let expr = self + .expression(expr, &mut ctx.as_const(block, &mut emitter))?; match ctx.module.to_ctx().eval_expr_to_literal(expr) { Some(crate::Literal::I32(value)) if !uint => { crate::SwitchValue::I32(value) @@ -1645,6 +1648,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }) .collect::>()?; + block.extend(emitter.finish(&ctx.function.expressions)); + crate::Statement::Switch { selector, cases } } ast::StatementKind::Loop { @@ -2010,7 +2015,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } } - lowered_base.map(|base| match ctx.const_access(index) { + lowered_base.map(|base| match ctx.const_eval_expr_to_u32(index).ok() { Some(index) => crate::Expression::AccessIndex { base, index }, None => crate::Expression::Access { base, index }, }) @@ -2088,7 +2093,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } ast::Expression::Bitcast { expr, to, ty_span } => { let expr = self.expression(expr, ctx)?; - let to_resolved = self.resolve_ast_type(to, &mut ctx.as_global())?; + let to_resolved = self.resolve_ast_type(to, &mut ctx.as_const())?; let element_scalar = match ctx.module.types[to_resolved].inner { crate::TypeInner::Scalar(scalar) => scalar, @@ -2933,7 +2938,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let offset = args .next() - .map(|arg| self.expression(arg, &mut ctx.as_global().as_const())) + .map(|arg| self.expression(arg, &mut ctx.as_const())) .ok() .transpose()?; @@ -3036,7 +3041,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let mut members = Vec::with_capacity(s.members.len()); for member in s.members.iter() { - let ty = self.resolve_ast_type(member.ty, ctx)?; + let ty = self.resolve_ast_type(member.ty, &mut ctx.as_const())?; self.layouter.update(ctx.module.to_ctx()).unwrap(); @@ -3123,7 +3128,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { fn array_size( &mut self, size: ast::ArraySize<'source>, - ctx: &mut GlobalContext<'source, '_, '_>, + ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result> { Ok(match size { ast::ArraySize::Constant(expr) => { @@ -3131,17 +3136,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let const_expr = self.expression(expr, &mut ctx.as_const()); match const_expr { Ok(value) => { - let len = - ctx.module.to_ctx().eval_expr_to_u32(value).map_err( - |err| match err { - crate::proc::U32EvalError::NonConst => { - Error::ExpectedConstExprConcreteIntegerScalar(span) - } - crate::proc::U32EvalError::Negative => { - Error::ExpectedPositiveArrayLength(span) - } - }, - )?; + let len = ctx.const_eval_expr_to_u32(value).map_err(|err| match err { + crate::proc::U32EvalError::NonConst => { + Error::ExpectedConstExprConcreteIntegerScalar(span) + } + crate::proc::U32EvalError::Negative => { + Error::ExpectedPositiveArrayLength(span) + } + })?; let size = NonZeroU32::new(len).ok_or(Error::ExpectedPositiveArrayLength(span))?; crate::ArraySize::Constant(size) @@ -3152,7 +3154,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { crate::proc::ConstantEvaluatorError::OverrideExpr => { crate::ArraySize::Pending(self.array_size_override( expr, - &mut ctx.as_override(), + &mut ctx.as_global().as_override(), span, )?) } @@ -3204,7 +3206,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { &mut self, handle: Handle>, name: Option, - ctx: &mut GlobalContext<'source, '_, '_>, + ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result, Error<'source>> { let inner = match ctx.types[handle] { ast::Type::Scalar(scalar) => scalar.to_inner_scalar(), @@ -3242,7 +3244,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { crate::TypeInner::Pointer { base, space } } ast::Type::Array { base, size } => { - let base = self.resolve_ast_type(base, ctx)?; + let base = self.resolve_ast_type(base, &mut ctx.as_const())?; let size = self.array_size(size, ctx)?; self.layouter.update(ctx.module.to_ctx()).unwrap(); @@ -3282,14 +3284,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } }; - Ok(ctx.ensure_type_exists(name, inner)) + Ok(ctx.as_global().ensure_type_exists(name, inner)) } /// Return a Naga `Handle` representing the front-end type `handle`. fn resolve_ast_type( &mut self, handle: Handle>, - ctx: &mut GlobalContext<'source, '_, '_>, + ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result, Error<'source>> { self.resolve_named_ast_type(handle, None, ctx) }