From 688c06bd0184f14439874a1b2a0cda5ce49c2123 Mon Sep 17 00:00:00 2001 From: Jamie Nicol Date: Thu, 9 Jan 2025 12:46:16 +0000 Subject: [PATCH] [naga wgsl-in] Return error if wgsl parser recurses too deeply. It's currently trivial to write a shader that causes the wgsl parser to recurse too deeply and overflow the stack. This patch makes the parser return an error when recursing too deeply, before the stack overflows. It makes use of a new function Parser::track_recursion(). This increments a counter returning an error if the value is too high, before calling the user-provided function and returning its return value after decrementing the counter again. Any recursively-called functions can simply be modified to call track_recursion(), providing their previous contents in a closure as the argument. All instances of recursion during parsing call through either Parser::statement(), Parser::unary_expression(), or Parser::type_decl(), so only these functions have been updated as described in order to keep the patch as unobtrusive as possible. A value of 256 has been chosen as the recursion limit, but can be later tweaked if required. This avoids the stack overflow in the testcase attached to issue #5757. --- naga/src/front/wgsl/parse/mod.rs | 821 ++++++++++++++++--------------- 1 file changed, 424 insertions(+), 397 deletions(-) diff --git a/naga/src/front/wgsl/parse/mod.rs b/naga/src/front/wgsl/parse/mod.rs index 4ed88efb34..498abd526a 100644 --- a/naga/src/front/wgsl/parse/mod.rs +++ b/naga/src/front/wgsl/parse/mod.rs @@ -261,15 +261,20 @@ impl<'a> BindingParser<'a> { pub struct Parser { rules: Vec<(Rule, usize)>, + recursion_depth: u32, } impl Parser { pub const fn new() -> Self { - Parser { rules: Vec::new() } + Parser { + rules: Vec::new(), + recursion_depth: 0, + } } fn reset(&mut self) { self.rules.clear(); + self.recursion_depth = 0; } fn push_rule_span(&mut self, rule: Rule, lexer: &mut Lexer<'_>) { @@ -296,6 +301,19 @@ impl Parser { ) } + fn track_recursion<'a, F, R>(&mut self, f: F) -> Result> + where + F: FnOnce(&mut Self) -> Result>, + { + self.recursion_depth += 1; + if self.recursion_depth >= 256 { + return Err(Error::Internal("Parser recursion limit exceeded")); + } + let ret = f(self); + self.recursion_depth -= 1; + ret + } + fn switch_value<'a>( &mut self, lexer: &mut Lexer<'a>, @@ -860,58 +878,60 @@ impl Parser { lexer: &mut Lexer<'a>, ctx: &mut ExpressionContext<'a, '_, '_>, ) -> Result>, Error<'a>> { - self.push_rule_span(Rule::UnaryExpr, lexer); - //TODO: refactor this to avoid backing up - let expr = match lexer.peek().0 { - Token::Operation('-') => { - let _ = lexer.next(); - let expr = self.unary_expression(lexer, ctx)?; - let expr = ast::Expression::Unary { - op: crate::UnaryOperator::Negate, - expr, - }; - let span = self.peek_rule_span(lexer); - ctx.expressions.append(expr, span) - } - Token::Operation('!') => { - let _ = lexer.next(); - let expr = self.unary_expression(lexer, ctx)?; - let expr = ast::Expression::Unary { - op: crate::UnaryOperator::LogicalNot, - expr, - }; - let span = self.peek_rule_span(lexer); - ctx.expressions.append(expr, span) - } - Token::Operation('~') => { - let _ = lexer.next(); - let expr = self.unary_expression(lexer, ctx)?; - let expr = ast::Expression::Unary { - op: crate::UnaryOperator::BitwiseNot, - expr, - }; - let span = self.peek_rule_span(lexer); - ctx.expressions.append(expr, span) - } - Token::Operation('*') => { - let _ = lexer.next(); - let expr = self.unary_expression(lexer, ctx)?; - let expr = ast::Expression::Deref(expr); - let span = self.peek_rule_span(lexer); - ctx.expressions.append(expr, span) - } - Token::Operation('&') => { - let _ = lexer.next(); - let expr = self.unary_expression(lexer, ctx)?; - let expr = ast::Expression::AddrOf(expr); - let span = self.peek_rule_span(lexer); - ctx.expressions.append(expr, span) - } - _ => self.singular_expression(lexer, ctx)?, - }; + self.track_recursion(|this| { + this.push_rule_span(Rule::UnaryExpr, lexer); + //TODO: refactor this to avoid backing up + let expr = match lexer.peek().0 { + Token::Operation('-') => { + let _ = lexer.next(); + let expr = this.unary_expression(lexer, ctx)?; + let expr = ast::Expression::Unary { + op: crate::UnaryOperator::Negate, + expr, + }; + let span = this.peek_rule_span(lexer); + ctx.expressions.append(expr, span) + } + Token::Operation('!') => { + let _ = lexer.next(); + let expr = this.unary_expression(lexer, ctx)?; + let expr = ast::Expression::Unary { + op: crate::UnaryOperator::LogicalNot, + expr, + }; + let span = this.peek_rule_span(lexer); + ctx.expressions.append(expr, span) + } + Token::Operation('~') => { + let _ = lexer.next(); + let expr = this.unary_expression(lexer, ctx)?; + let expr = ast::Expression::Unary { + op: crate::UnaryOperator::BitwiseNot, + expr, + }; + let span = this.peek_rule_span(lexer); + ctx.expressions.append(expr, span) + } + Token::Operation('*') => { + let _ = lexer.next(); + let expr = this.unary_expression(lexer, ctx)?; + let expr = ast::Expression::Deref(expr); + let span = this.peek_rule_span(lexer); + ctx.expressions.append(expr, span) + } + Token::Operation('&') => { + let _ = lexer.next(); + let expr = this.unary_expression(lexer, ctx)?; + let expr = ast::Expression::AddrOf(expr); + let span = this.peek_rule_span(lexer); + ctx.expressions.append(expr, span) + } + _ => this.singular_expression(lexer, ctx)?, + }; - self.pop_rule_span(lexer); - Ok(expr) + this.pop_rule_span(lexer); + Ok(expr) + }) } /// Parse a `singular_expression`. @@ -1644,25 +1664,27 @@ impl Parser { lexer: &mut Lexer<'a>, ctx: &mut ExpressionContext<'a, '_, '_>, ) -> Result>, Error<'a>> { - self.push_rule_span(Rule::TypeDecl, lexer); - - let (name, span) = lexer.next_ident_with_span()?; - - let ty = match self.type_decl_impl(lexer, name, ctx)? { - Some(ty) => ty, - None => { - ctx.unresolved.insert(ast::Dependency { - ident: name, - usage: span, - }); - ast::Type::User(ast::Ident { name, span }) - } - }; + self.track_recursion(|this| { + this.push_rule_span(Rule::TypeDecl, lexer); + + let (name, span) = lexer.next_ident_with_span()?; + + let ty = match this.type_decl_impl(lexer, name, ctx)? { + Some(ty) => ty, + None => { + ctx.unresolved.insert(ast::Dependency { + ident: name, + usage: span, + }); + ast::Type::User(ast::Ident { name, span }) + } + }; - self.pop_rule_span(lexer); + this.pop_rule_span(lexer); - let handle = ctx.types.append(ty, Span::UNDEFINED); - Ok(handle) + let handle = ctx.types.append(ty, Span::UNDEFINED); + Ok(handle) + }) } fn assignment_op_and_rhs<'a>( @@ -1806,291 +1828,235 @@ impl Parser { block: &mut ast::Block<'a>, brace_nesting_level: u8, ) -> Result<(), Error<'a>> { - self.push_rule_span(Rule::Statement, lexer); - match lexer.peek() { - (Token::Separator(';'), _) => { - let _ = lexer.next(); - self.pop_rule_span(lexer); - } - (Token::Paren('{') | Token::Attribute, _) => { - let (inner, span) = self.block(lexer, ctx, brace_nesting_level)?; - block.stmts.push(ast::Statement { - kind: ast::StatementKind::Block(inner), - span, - }); - self.pop_rule_span(lexer); - } - (Token::Word(word), _) => { - let kind = match word { - "_" => { - let _ = lexer.next(); - lexer.expect(Token::Operation('='))?; - let expr = self.general_expression(lexer, ctx)?; - lexer.expect(Token::Separator(';'))?; + self.track_recursion(|this| { + this.push_rule_span(Rule::Statement, lexer); + match lexer.peek() { + (Token::Separator(';'), _) => { + let _ = lexer.next(); + this.pop_rule_span(lexer); + } + (Token::Paren('{') | Token::Attribute, _) => { + let (inner, span) = this.block(lexer, ctx, brace_nesting_level)?; + block.stmts.push(ast::Statement { + kind: ast::StatementKind::Block(inner), + span, + }); + this.pop_rule_span(lexer); + } + (Token::Word(word), _) => { + let kind = match word { + "_" => { + let _ = lexer.next(); + lexer.expect(Token::Operation('='))?; + let expr = this.general_expression(lexer, ctx)?; + lexer.expect(Token::Separator(';'))?; + + ast::StatementKind::Phony(expr) + } + "let" => { + let _ = lexer.next(); + let name = lexer.next_ident()?; + + let given_ty = if lexer.skip(Token::Separator(':')) { + let ty = this.type_decl(lexer, ctx)?; + Some(ty) + } else { + None + }; + lexer.expect(Token::Operation('='))?; + let expr_id = this.general_expression(lexer, ctx)?; + lexer.expect(Token::Separator(';'))?; + + let handle = ctx.declare_local(name)?; + ast::StatementKind::LocalDecl(ast::LocalDecl::Let(ast::Let { + name, + ty: given_ty, + init: expr_id, + handle, + })) + } + "const" => { + let _ = lexer.next(); + let name = lexer.next_ident()?; + + let given_ty = if lexer.skip(Token::Separator(':')) { + let ty = this.type_decl(lexer, ctx)?; + Some(ty) + } else { + None + }; + lexer.expect(Token::Operation('='))?; + let expr_id = this.general_expression(lexer, ctx)?; + lexer.expect(Token::Separator(';'))?; + + let handle = ctx.declare_local(name)?; + ast::StatementKind::LocalDecl(ast::LocalDecl::Const(ast::LocalConst { + name, + ty: given_ty, + init: expr_id, + handle, + })) + } + "var" => { + let _ = lexer.next(); + + let name = lexer.next_ident()?; + let ty = if lexer.skip(Token::Separator(':')) { + let ty = this.type_decl(lexer, ctx)?; + Some(ty) + } else { + None + }; - ast::StatementKind::Phony(expr) - } - "let" => { - let _ = lexer.next(); - let name = lexer.next_ident()?; - - let given_ty = if lexer.skip(Token::Separator(':')) { - let ty = self.type_decl(lexer, ctx)?; - Some(ty) - } else { - None - }; - lexer.expect(Token::Operation('='))?; - let expr_id = self.general_expression(lexer, ctx)?; - lexer.expect(Token::Separator(';'))?; + let init = if lexer.skip(Token::Operation('=')) { + let init = this.general_expression(lexer, ctx)?; + Some(init) + } else { + None + }; - let handle = ctx.declare_local(name)?; - ast::StatementKind::LocalDecl(ast::LocalDecl::Let(ast::Let { - name, - ty: given_ty, - init: expr_id, - handle, - })) - } - "const" => { - let _ = lexer.next(); - let name = lexer.next_ident()?; - - let given_ty = if lexer.skip(Token::Separator(':')) { - let ty = self.type_decl(lexer, ctx)?; - Some(ty) - } else { - None - }; - lexer.expect(Token::Operation('='))?; - let expr_id = self.general_expression(lexer, ctx)?; - lexer.expect(Token::Separator(';'))?; + lexer.expect(Token::Separator(';'))?; - let handle = ctx.declare_local(name)?; - ast::StatementKind::LocalDecl(ast::LocalDecl::Const(ast::LocalConst { - name, - ty: given_ty, - init: expr_id, - handle, - })) - } - "var" => { - let _ = lexer.next(); - - let name = lexer.next_ident()?; - let ty = if lexer.skip(Token::Separator(':')) { - let ty = self.type_decl(lexer, ctx)?; - Some(ty) - } else { - None - }; - - let init = if lexer.skip(Token::Operation('=')) { - let init = self.general_expression(lexer, ctx)?; - Some(init) - } else { - None - }; + let handle = ctx.declare_local(name)?; + ast::StatementKind::LocalDecl(ast::LocalDecl::Var(ast::LocalVariable { + name, + ty, + init, + handle, + })) + } + "return" => { + let _ = lexer.next(); + let value = if lexer.peek().0 != Token::Separator(';') { + let handle = this.general_expression(lexer, ctx)?; + Some(handle) + } else { + None + }; + lexer.expect(Token::Separator(';'))?; + ast::StatementKind::Return { value } + } + "if" => { + let _ = lexer.next(); + let condition = this.general_expression(lexer, ctx)?; - lexer.expect(Token::Separator(';'))?; + let accept = this.block(lexer, ctx, brace_nesting_level)?.0; - let handle = ctx.declare_local(name)?; - ast::StatementKind::LocalDecl(ast::LocalDecl::Var(ast::LocalVariable { - name, - ty, - init, - handle, - })) - } - "return" => { - let _ = lexer.next(); - let value = if lexer.peek().0 != Token::Separator(';') { - let handle = self.general_expression(lexer, ctx)?; - Some(handle) - } else { - None - }; - lexer.expect(Token::Separator(';'))?; - ast::StatementKind::Return { value } - } - "if" => { - let _ = lexer.next(); - let condition = self.general_expression(lexer, ctx)?; + let mut elsif_stack = Vec::new(); + let mut elseif_span_start = lexer.start_byte_offset(); + let mut reject = loop { + if !lexer.skip(Token::Word("else")) { + break ast::Block::default(); + } - let accept = self.block(lexer, ctx, brace_nesting_level)?.0; + if !lexer.skip(Token::Word("if")) { + // ... else { ... } + break this.block(lexer, ctx, brace_nesting_level)?.0; + } - let mut elsif_stack = Vec::new(); - let mut elseif_span_start = lexer.start_byte_offset(); - let mut reject = loop { - if !lexer.skip(Token::Word("else")) { - break ast::Block::default(); - } + // ... else if (...) { ... } + let other_condition = this.general_expression(lexer, ctx)?; + let other_block = this.block(lexer, ctx, brace_nesting_level)?; + elsif_stack.push((elseif_span_start, other_condition, other_block)); + elseif_span_start = lexer.start_byte_offset(); + }; - if !lexer.skip(Token::Word("if")) { - // ... else { ... } - break self.block(lexer, ctx, brace_nesting_level)?.0; + // reverse-fold the else-if blocks + //Note: we may consider uplifting this to the IR + for (other_span_start, other_cond, other_block) in + elsif_stack.into_iter().rev() + { + let sub_stmt = ast::StatementKind::If { + condition: other_cond, + accept: other_block.0, + reject, + }; + reject = ast::Block::default(); + let span = lexer.span_from(other_span_start); + reject.stmts.push(ast::Statement { + kind: sub_stmt, + span, + }) } - // ... else if (...) { ... } - let other_condition = self.general_expression(lexer, ctx)?; - let other_block = self.block(lexer, ctx, brace_nesting_level)?; - elsif_stack.push((elseif_span_start, other_condition, other_block)); - elseif_span_start = lexer.start_byte_offset(); - }; - - // reverse-fold the else-if blocks - //Note: we may consider uplifting this to the IR - for (other_span_start, other_cond, other_block) in - elsif_stack.into_iter().rev() - { - let sub_stmt = ast::StatementKind::If { - condition: other_cond, - accept: other_block.0, + ast::StatementKind::If { + condition, + accept, reject, - }; - reject = ast::Block::default(); - let span = lexer.span_from(other_span_start); - reject.stmts.push(ast::Statement { - kind: sub_stmt, - span, - }) - } - - ast::StatementKind::If { - condition, - accept, - reject, + } } - } - "switch" => { - let _ = lexer.next(); - let selector = self.general_expression(lexer, ctx)?; - let brace_span = lexer.expect_span(Token::Paren('{'))?; - let brace_nesting_level = - Self::increase_brace_nesting(brace_nesting_level, brace_span)?; - let mut cases = Vec::new(); - - loop { - // cases + default - match lexer.next() { - (Token::Word("case"), _) => { - // parse a list of values - let value = loop { - let value = self.switch_value(lexer, ctx)?; - if lexer.skip(Token::Separator(',')) { - if lexer.skip(Token::Separator(':')) { + "switch" => { + let _ = lexer.next(); + let selector = this.general_expression(lexer, ctx)?; + let brace_span = lexer.expect_span(Token::Paren('{'))?; + let brace_nesting_level = + Self::increase_brace_nesting(brace_nesting_level, brace_span)?; + let mut cases = Vec::new(); + + loop { + // cases + default + match lexer.next() { + (Token::Word("case"), _) => { + // parse a list of values + let value = loop { + let value = this.switch_value(lexer, ctx)?; + if lexer.skip(Token::Separator(',')) { + if lexer.skip(Token::Separator(':')) { + break value; + } + } else { + lexer.skip(Token::Separator(':')); break value; } - } else { - lexer.skip(Token::Separator(':')); - break value; - } + cases.push(ast::SwitchCase { + value, + body: ast::Block::default(), + fall_through: true, + }); + }; + + let body = this.block(lexer, ctx, brace_nesting_level)?.0; + cases.push(ast::SwitchCase { value, - body: ast::Block::default(), - fall_through: true, + body, + fall_through: false, }); - }; - - let body = self.block(lexer, ctx, brace_nesting_level)?.0; - - cases.push(ast::SwitchCase { - value, - body, - fall_through: false, - }); - } - (Token::Word("default"), _) => { - lexer.skip(Token::Separator(':')); - let body = self.block(lexer, ctx, brace_nesting_level)?.0; - cases.push(ast::SwitchCase { - value: ast::SwitchValue::Default, - body, - fall_through: false, - }); - } - (Token::Paren('}'), _) => break, - (_, span) => { - return Err(Error::Unexpected(span, ExpectedToken::SwitchItem)) + } + (Token::Word("default"), _) => { + lexer.skip(Token::Separator(':')); + let body = this.block(lexer, ctx, brace_nesting_level)?.0; + cases.push(ast::SwitchCase { + value: ast::SwitchValue::Default, + body, + fall_through: false, + }); + } + (Token::Paren('}'), _) => break, + (_, span) => { + return Err(Error::Unexpected( + span, + ExpectedToken::SwitchItem, + )) + } } } - } - - ast::StatementKind::Switch { selector, cases } - } - "loop" => self.r#loop(lexer, ctx, brace_nesting_level)?, - "while" => { - let _ = lexer.next(); - let mut body = ast::Block::default(); - - let (condition, span) = - lexer.capture_span(|lexer| self.general_expression(lexer, ctx))?; - let mut reject = ast::Block::default(); - reject.stmts.push(ast::Statement { - kind: ast::StatementKind::Break, - span, - }); - - body.stmts.push(ast::Statement { - kind: ast::StatementKind::If { - condition, - accept: ast::Block::default(), - reject, - }, - span, - }); - let (block, span) = self.block(lexer, ctx, brace_nesting_level)?; - body.stmts.push(ast::Statement { - kind: ast::StatementKind::Block(block), - span, - }); - - ast::StatementKind::Loop { - body, - continuing: ast::Block::default(), - break_if: None, + ast::StatementKind::Switch { selector, cases } } - } - "for" => { - let _ = lexer.next(); - lexer.expect(Token::Paren('('))?; - - ctx.local_table.push_scope(); - - if !lexer.skip(Token::Separator(';')) { - let num_statements = block.stmts.len(); - let (_, span) = { - let ctx = &mut *ctx; - let block = &mut *block; - lexer.capture_span(|lexer| { - self.statement(lexer, ctx, block, brace_nesting_level) - })? - }; + "loop" => this.r#loop(lexer, ctx, brace_nesting_level)?, + "while" => { + let _ = lexer.next(); + let mut body = ast::Block::default(); - if block.stmts.len() != num_statements { - match block.stmts.last().unwrap().kind { - ast::StatementKind::Call { .. } - | ast::StatementKind::Assign { .. } - | ast::StatementKind::LocalDecl(_) => {} - _ => return Err(Error::InvalidForInitializer(span)), - } - } - }; - - let mut body = ast::Block::default(); - if !lexer.skip(Token::Separator(';')) { let (condition, span) = - lexer.capture_span(|lexer| -> Result<_, Error<'_>> { - let condition = self.general_expression(lexer, ctx)?; - lexer.expect(Token::Separator(';'))?; - Ok(condition) - })?; + lexer.capture_span(|lexer| this.general_expression(lexer, ctx))?; let mut reject = ast::Block::default(); reject.stmts.push(ast::Statement { kind: ast::StatementKind::Break, span, }); + body.stmts.push(ast::Statement { kind: ast::StatementKind::If { condition, @@ -2099,88 +2065,149 @@ impl Parser { }, span, }); - }; - - let mut continuing = ast::Block::default(); - if !lexer.skip(Token::Paren(')')) { - self.function_call_or_assignment_statement( - lexer, - ctx, - &mut continuing, - )?; - lexer.expect(Token::Paren(')'))?; + + let (block, span) = this.block(lexer, ctx, brace_nesting_level)?; + body.stmts.push(ast::Statement { + kind: ast::StatementKind::Block(block), + span, + }); + + ast::StatementKind::Loop { + body, + continuing: ast::Block::default(), + break_if: None, + } } + "for" => { + let _ = lexer.next(); + lexer.expect(Token::Paren('('))?; + + ctx.local_table.push_scope(); + + if !lexer.skip(Token::Separator(';')) { + let num_statements = block.stmts.len(); + let (_, span) = { + let ctx = &mut *ctx; + let block = &mut *block; + lexer.capture_span(|lexer| { + this.statement(lexer, ctx, block, brace_nesting_level) + })? + }; + + if block.stmts.len() != num_statements { + match block.stmts.last().unwrap().kind { + ast::StatementKind::Call { .. } + | ast::StatementKind::Assign { .. } + | ast::StatementKind::LocalDecl(_) => {} + _ => return Err(Error::InvalidForInitializer(span)), + } + } + }; - let (block, span) = self.block(lexer, ctx, brace_nesting_level)?; - body.stmts.push(ast::Statement { - kind: ast::StatementKind::Block(block), - span, - }); + let mut body = ast::Block::default(); + if !lexer.skip(Token::Separator(';')) { + let (condition, span) = + lexer.capture_span(|lexer| -> Result<_, Error<'_>> { + let condition = this.general_expression(lexer, ctx)?; + lexer.expect(Token::Separator(';'))?; + Ok(condition) + })?; + let mut reject = ast::Block::default(); + reject.stmts.push(ast::Statement { + kind: ast::StatementKind::Break, + span, + }); + body.stmts.push(ast::Statement { + kind: ast::StatementKind::If { + condition, + accept: ast::Block::default(), + reject, + }, + span, + }); + }; - ctx.local_table.pop_scope(); + let mut continuing = ast::Block::default(); + if !lexer.skip(Token::Paren(')')) { + this.function_call_or_assignment_statement( + lexer, + ctx, + &mut continuing, + )?; + lexer.expect(Token::Paren(')'))?; + } - ast::StatementKind::Loop { - body, - continuing, - break_if: None, + let (block, span) = this.block(lexer, ctx, brace_nesting_level)?; + body.stmts.push(ast::Statement { + kind: ast::StatementKind::Block(block), + span, + }); + + ctx.local_table.pop_scope(); + + ast::StatementKind::Loop { + body, + continuing, + break_if: None, + } } - } - "break" => { - let (_, span) = lexer.next(); - // Check if the next token is an `if`, this indicates - // that the user tried to type out a `break if` which - // is illegal in this position. - let (peeked_token, peeked_span) = lexer.peek(); - if let Token::Word("if") = peeked_token { - let span = span.until(&peeked_span); - return Err(Error::InvalidBreakIf(span)); + "break" => { + let (_, span) = lexer.next(); + // Check if the next token is an `if`, this indicates + // that the user tried to type out a `break if` which + // is illegal in this position. + let (peeked_token, peeked_span) = lexer.peek(); + if let Token::Word("if") = peeked_token { + let span = span.until(&peeked_span); + return Err(Error::InvalidBreakIf(span)); + } + lexer.expect(Token::Separator(';'))?; + ast::StatementKind::Break } - lexer.expect(Token::Separator(';'))?; - ast::StatementKind::Break - } - "continue" => { - let _ = lexer.next(); - lexer.expect(Token::Separator(';'))?; - ast::StatementKind::Continue - } - "discard" => { - let _ = lexer.next(); - lexer.expect(Token::Separator(';'))?; - ast::StatementKind::Kill - } - // https://www.w3.org/TR/WGSL/#const-assert-statement - "const_assert" => { - let _ = lexer.next(); - // parentheses are optional - let paren = lexer.skip(Token::Paren('(')); + "continue" => { + let _ = lexer.next(); + lexer.expect(Token::Separator(';'))?; + ast::StatementKind::Continue + } + "discard" => { + let _ = lexer.next(); + lexer.expect(Token::Separator(';'))?; + ast::StatementKind::Kill + } + // https://www.w3.org/TR/WGSL/#const-assert-statement + "const_assert" => { + let _ = lexer.next(); + // parentheses are optional + let paren = lexer.skip(Token::Paren('(')); - let condition = self.general_expression(lexer, ctx)?; + let condition = this.general_expression(lexer, ctx)?; - if paren { - lexer.expect(Token::Paren(')'))?; + if paren { + lexer.expect(Token::Paren(')'))?; + } + lexer.expect(Token::Separator(';'))?; + ast::StatementKind::ConstAssert(condition) } - lexer.expect(Token::Separator(';'))?; - ast::StatementKind::ConstAssert(condition) - } - // assignment or a function call - _ => { - self.function_call_or_assignment_statement(lexer, ctx, block)?; - lexer.expect(Token::Separator(';'))?; - self.pop_rule_span(lexer); - return Ok(()); - } - }; + // assignment or a function call + _ => { + this.function_call_or_assignment_statement(lexer, ctx, block)?; + lexer.expect(Token::Separator(';'))?; + this.pop_rule_span(lexer); + return Ok(()); + } + }; - let span = self.pop_rule_span(lexer); - block.stmts.push(ast::Statement { kind, span }); - } - _ => { - self.assignment_statement(lexer, ctx, block)?; - lexer.expect(Token::Separator(';'))?; - self.pop_rule_span(lexer); + let span = this.pop_rule_span(lexer); + block.stmts.push(ast::Statement { kind, span }); + } + _ => { + this.assignment_statement(lexer, ctx, block)?; + lexer.expect(Token::Separator(';'))?; + this.pop_rule_span(lexer); + } } - } - Ok(()) + Ok(()) + }) } fn r#loop<'a>(