diff --git a/crates/hir-def/src/body.rs b/crates/hir-def/src/body.rs index 9535b5aea7c7..684eaf1c3b08 100644 --- a/crates/hir-def/src/body.rs +++ b/crates/hir-def/src/body.rs @@ -10,6 +10,7 @@ use std::ops::{Deref, Index}; use base_db::CrateId; use cfg::{CfgExpr, CfgOptions}; +use either::Either; use hir_expand::{name::Name, ExpandError, InFile}; use la_arena::{Arena, ArenaMap, Idx, RawIdx}; use rustc_hash::FxHashMap; @@ -22,7 +23,8 @@ use crate::{ db::DefDatabase, expander::Expander, hir::{ - dummy_expr_id, Binding, BindingId, Expr, ExprId, Label, LabelId, Pat, PatId, RecordFieldPat, + dummy_expr_id, Array, AsmOperand, Binding, BindingId, Expr, ExprId, ExprOrPatId, Label, + LabelId, Pat, PatId, RecordFieldPat, Statement, }, item_tree::AttrOwner, nameres::DefMap, @@ -67,9 +69,12 @@ pub type LabelSource = InFile; pub type FieldPtr = AstPtr; pub type FieldSource = InFile; -pub type PatFieldPtr = AstPtr; +pub type PatFieldPtr = AstPtr>; pub type PatFieldSource = InFile; +pub type ExprOrPatPtr = AstPtr>; +pub type ExprOrPatSource = InFile; + /// An item body together with the mapping from syntax nodes to HIR expression /// IDs. This is needed to go from e.g. a position in a file to the HIR /// expression containing it; but for type inference etc., we want to operate on @@ -83,11 +88,13 @@ pub type PatFieldSource = InFile; /// this properly for macros. #[derive(Default, Debug, Eq, PartialEq)] pub struct BodySourceMap { - expr_map: FxHashMap, + // AST expressions can create patterns in destructuring assignments. Therefore, `ExprSource` can also map + // to `PatId`, and `PatId` can also map to `ExprSource` (the other way around is unaffected). + expr_map: FxHashMap, expr_map_back: ArenaMap, pat_map: FxHashMap, - pat_map_back: ArenaMap, + pat_map_back: ArenaMap, label_map: FxHashMap, label_map_back: ArenaMap, @@ -286,7 +293,8 @@ impl Body { | Pat::Path(..) | Pat::ConstBlock(..) | Pat::Wild - | Pat::Missing => {} + | Pat::Missing + | Pat::Expr(_) => {} &Pat::Bind { subpat, .. } => { if let Some(subpat) = subpat { f(subpat); @@ -322,6 +330,143 @@ impl Body { None => true, } } + + pub fn walk_child_exprs(&self, expr_id: ExprId, mut f: impl FnMut(ExprId)) { + let expr = &self[expr_id]; + match expr { + Expr::Continue { .. } + | Expr::Const(_) + | Expr::Missing + | Expr::Path(_) + | Expr::OffsetOf(_) + | Expr::Literal(_) + | Expr::Underscore => {} + Expr::InlineAsm(it) => it.operands.iter().for_each(|(_, op)| match op { + AsmOperand::In { expr, .. } + | AsmOperand::Out { expr: Some(expr), .. } + | AsmOperand::InOut { expr, .. } => f(*expr), + AsmOperand::SplitInOut { in_expr, out_expr, .. } => { + f(*in_expr); + if let Some(out_expr) = out_expr { + f(*out_expr); + } + } + AsmOperand::Out { expr: None, .. } + | AsmOperand::Const(_) + | AsmOperand::Label(_) + | AsmOperand::Sym(_) => (), + }), + Expr::If { condition, then_branch, else_branch } => { + f(*condition); + f(*then_branch); + if let &Some(else_branch) = else_branch { + f(else_branch); + } + } + Expr::Let { expr, .. } => { + f(*expr); + } + Expr::Block { statements, tail, .. } + | Expr::Unsafe { statements, tail, .. } + | Expr::Async { statements, tail, .. } => { + for stmt in statements.iter() { + match stmt { + Statement::Let { initializer, else_branch, pat, .. } => { + if let &Some(expr) = initializer { + f(expr); + } + if let &Some(expr) = else_branch { + f(expr); + } + self.walk_exprs_in_pat(*pat, &mut f); + } + Statement::Expr { expr: expression, .. } => f(*expression), + Statement::Item => (), + } + } + if let &Some(expr) = tail { + f(expr); + } + } + Expr::Loop { body, .. } => f(*body), + Expr::Call { callee, args, .. } => { + f(*callee); + args.iter().copied().for_each(f); + } + Expr::MethodCall { receiver, args, .. } => { + f(*receiver); + args.iter().copied().for_each(f); + } + Expr::Match { expr, arms } => { + f(*expr); + arms.iter().map(|arm| arm.expr).for_each(f); + } + Expr::Break { expr, .. } + | Expr::Return { expr } + | Expr::Yield { expr } + | Expr::Yeet { expr } => { + if let &Some(expr) = expr { + f(expr); + } + } + Expr::Become { expr } => f(*expr), + Expr::RecordLit { fields, spread, .. } => { + for field in fields.iter() { + f(field.expr); + } + if let &Some(expr) = spread { + f(expr); + } + } + Expr::Closure { body, .. } => { + f(*body); + } + Expr::BinaryOp { lhs, rhs, .. } => { + f(*lhs); + f(*rhs); + } + Expr::Range { lhs, rhs, .. } => { + if let &Some(lhs) = rhs { + f(lhs); + } + if let &Some(rhs) = lhs { + f(rhs); + } + } + Expr::Index { base, index, .. } => { + f(*base); + f(*index); + } + Expr::Field { expr, .. } + | Expr::Await { expr } + | Expr::Cast { expr, .. } + | Expr::Ref { expr, .. } + | Expr::UnaryOp { expr, .. } + | Expr::Box { expr } => { + f(*expr); + } + Expr::Tuple { exprs, .. } => exprs.iter().copied().for_each(f), + Expr::Array(a) => match a { + Array::ElementList { elements, .. } => elements.iter().copied().for_each(f), + Array::Repeat { initializer, repeat } => { + f(*initializer); + f(*repeat) + } + }, + &Expr::Assignment { target, value } => { + self.walk_exprs_in_pat(target, &mut f); + f(value); + } + } + } + + pub fn walk_exprs_in_pat(&self, pat_id: PatId, f: &mut impl FnMut(ExprId)) { + self.walk_pats(pat_id, &mut |pat| { + if let Pat::Expr(expr) | Pat::ConstBlock(expr) = self[pat] { + f(expr); + } + }); + } } impl Default for Body { @@ -375,11 +520,18 @@ impl Index for Body { // FIXME: Change `node_` prefix to something more reasonable. // Perhaps `expr_syntax` and `expr_id`? impl BodySourceMap { + pub fn expr_or_pat_syntax(&self, id: ExprOrPatId) -> Result { + match id { + ExprOrPatId::ExprId(id) => self.expr_syntax(id).map(|it| it.map(AstPtr::wrap_left)), + ExprOrPatId::PatId(id) => self.pat_syntax(id), + } + } + pub fn expr_syntax(&self, expr: ExprId) -> Result { self.expr_map_back.get(expr).cloned().ok_or(SyntheticSyntax) } - pub fn node_expr(&self, node: InFile<&ast::Expr>) -> Option { + pub fn node_expr(&self, node: InFile<&ast::Expr>) -> Option { let src = node.map(AstPtr::new); self.expr_map.get(&src).cloned() } @@ -395,7 +547,7 @@ impl BodySourceMap { self.expansions.iter().map(|(&a, &b)| (a, b)) } - pub fn pat_syntax(&self, pat: PatId) -> Result { + pub fn pat_syntax(&self, pat: PatId) -> Result { self.pat_map_back.get(pat).cloned().ok_or(SyntheticSyntax) } @@ -428,7 +580,7 @@ impl BodySourceMap { self.pat_field_map_back[&pat] } - pub fn macro_expansion_expr(&self, node: InFile<&ast::MacroExpr>) -> Option { + pub fn macro_expansion_expr(&self, node: InFile<&ast::MacroExpr>) -> Option { let src = node.map(AstPtr::new).map(AstPtr::upcast::).map(AstPtr::upcast); self.expr_map.get(&src).copied() } @@ -444,7 +596,11 @@ impl BodySourceMap { node: InFile<&ast::FormatArgsExpr>, ) -> Option<&[(syntax::TextRange, Name)]> { let src = node.map(AstPtr::new).map(AstPtr::upcast::); - self.template_map.as_ref()?.0.get(self.expr_map.get(&src)?).map(std::ops::Deref::deref) + self.template_map + .as_ref()? + .0 + .get(&self.expr_map.get(&src)?.as_expr()?) + .map(std::ops::Deref::deref) } pub fn asm_template_args( @@ -452,8 +608,8 @@ impl BodySourceMap { node: InFile<&ast::AsmExpr>, ) -> Option<(ExprId, &[Vec<(syntax::TextRange, usize)>])> { let src = node.map(AstPtr::new).map(AstPtr::upcast::); - let expr = self.expr_map.get(&src)?; - Some(*expr).zip(self.template_map.as_ref()?.1.get(expr).map(std::ops::Deref::deref)) + let expr = self.expr_map.get(&src)?.as_expr()?; + Some(expr).zip(self.template_map.as_ref()?.1.get(&expr).map(std::ops::Deref::deref)) } /// Get a reference to the body source map's diagnostics. diff --git a/crates/hir-def/src/body/lower.rs b/crates/hir-def/src/body/lower.rs index 9c547574ecb1..4b74028b83a6 100644 --- a/crates/hir-def/src/body/lower.rs +++ b/crates/hir-def/src/body/lower.rs @@ -70,7 +70,6 @@ pub(super) fn lower( body: Body::default(), expander, current_try_block_label: None, - is_lowering_assignee_expr: false, is_lowering_coroutine: false, label_ribs: Vec::new(), current_binding_owner: None, @@ -89,7 +88,6 @@ struct ExprCollector<'a> { body: Body, source_map: BodySourceMap, - is_lowering_assignee_expr: bool, is_lowering_coroutine: bool, current_try_block_label: Option, @@ -359,14 +357,7 @@ impl ExprCollector<'_> { } else { Box::default() }; - self.alloc_expr( - Expr::Call { - callee, - args, - is_assignee_expr: self.is_lowering_assignee_expr, - }, - syntax_ptr, - ) + self.alloc_expr(Expr::Call { callee, args }, syntax_ptr) } } ast::Expr::MethodCallExpr(e) => { @@ -433,7 +424,7 @@ impl ExprCollector<'_> { let inner = self.collect_expr_opt(e.expr()); // make the paren expr point to the inner expression as well for IDE resolution let src = self.expander.in_file(syntax_ptr); - self.source_map.expr_map.insert(src, inner); + self.source_map.expr_map.insert(src, inner.into()); inner } ast::Expr::ReturnExpr(e) => { @@ -457,7 +448,6 @@ impl ExprCollector<'_> { ast::Expr::RecordExpr(e) => { let path = e.path().and_then(|path| self.expander.parse_path(self.db, path)).map(Box::new); - let is_assignee_expr = self.is_lowering_assignee_expr; let record_lit = if let Some(nfl) = e.record_expr_field_list() { let fields = nfl .fields() @@ -476,16 +466,9 @@ impl ExprCollector<'_> { }) .collect(); let spread = nfl.spread().map(|s| self.collect_expr(s)); - let ellipsis = nfl.dotdot_token().is_some(); - Expr::RecordLit { path, fields, spread, ellipsis, is_assignee_expr } + Expr::RecordLit { path, fields, spread } } else { - Expr::RecordLit { - path, - fields: Box::default(), - spread: None, - ellipsis: false, - is_assignee_expr, - } + Expr::RecordLit { path, fields: Box::default(), spread: None } }; self.alloc_expr(record_lit, syntax_ptr) @@ -602,12 +585,14 @@ impl ExprCollector<'_> { ast::Expr::BinExpr(e) => { let op = e.op_kind(); if let Some(ast::BinaryOp::Assignment { op: None }) = op { - self.is_lowering_assignee_expr = true; + let target = self.collect_expr_as_pat_opt(e.lhs()); + let value = self.collect_expr_opt(e.rhs()); + self.alloc_expr(Expr::Assignment { target, value }, syntax_ptr) + } else { + let lhs = self.collect_expr_opt(e.lhs()); + let rhs = self.collect_expr_opt(e.rhs()); + self.alloc_expr(Expr::BinaryOp { lhs, rhs, op }, syntax_ptr) } - let lhs = self.collect_expr_opt(e.lhs()); - self.is_lowering_assignee_expr = false; - let rhs = self.collect_expr_opt(e.rhs()); - self.alloc_expr(Expr::BinaryOp { lhs, rhs, op }, syntax_ptr) } ast::Expr::TupleExpr(e) => { let mut exprs: Vec<_> = e.fields().map(|expr| self.collect_expr(expr)).collect(); @@ -617,13 +602,7 @@ impl ExprCollector<'_> { exprs.insert(0, self.missing_expr()); } - self.alloc_expr( - Expr::Tuple { - exprs: exprs.into_boxed_slice(), - is_assignee_expr: self.is_lowering_assignee_expr, - }, - syntax_ptr, - ) + self.alloc_expr(Expr::Tuple { exprs: exprs.into_boxed_slice() }, syntax_ptr) } ast::Expr::ArrayExpr(e) => { let kind = e.kind(); @@ -631,13 +610,7 @@ impl ExprCollector<'_> { match kind { ArrayExprKind::ElementList(e) => { let elements = e.map(|expr| self.collect_expr(expr)).collect(); - self.alloc_expr( - Expr::Array(Array::ElementList { - elements, - is_assignee_expr: self.is_lowering_assignee_expr, - }), - syntax_ptr, - ) + self.alloc_expr(Expr::Array(Array::ElementList { elements }), syntax_ptr) } ArrayExprKind::Repeat { initializer, repeat } => { let initializer = self.collect_expr_opt(initializer); @@ -664,8 +637,7 @@ impl ExprCollector<'_> { ast::Expr::IndexExpr(e) => { let base = self.collect_expr_opt(e.base()); let index = self.collect_expr_opt(e.index()); - let is_assignee_expr = self.is_lowering_assignee_expr; - self.alloc_expr(Expr::Index { base, index, is_assignee_expr }, syntax_ptr) + self.alloc_expr(Expr::Index { base, index }, syntax_ptr) } ast::Expr::RangeExpr(e) => { let lhs = e.start().map(|lhs| self.collect_expr(lhs)); @@ -688,7 +660,7 @@ impl ExprCollector<'_> { // Make the macro-call point to its expanded expression so we can query // semantics on syntax pointers to the macro let src = self.expander.in_file(syntax_ptr); - self.source_map.expr_map.insert(src, id); + self.source_map.expr_map.insert(src, id.into()); id } None => self.alloc_expr(Expr::Missing, syntax_ptr), @@ -705,6 +677,179 @@ impl ExprCollector<'_> { }) } + fn collect_expr_as_pat_opt(&mut self, expr: Option) -> PatId { + match expr { + Some(expr) => self.collect_expr_as_pat(expr), + _ => self.missing_pat(), + } + } + + fn collect_expr_as_pat(&mut self, expr: ast::Expr) -> PatId { + self.maybe_collect_expr_as_pat(&expr).unwrap_or_else(|| { + let src = self.expander.in_file(AstPtr::new(&expr).wrap_left()); + let expr = self.collect_expr(expr); + // Do not use `alloc_pat_from_expr()` here, it will override the entry in `expr_map`. + let id = self.body.pats.alloc(Pat::Expr(expr)); + self.source_map.pat_map_back.insert(id, src); + id + }) + } + + fn maybe_collect_expr_as_pat(&mut self, expr: &ast::Expr) -> Option { + self.check_cfg(expr)?; + let syntax_ptr = AstPtr::new(expr); + + let result = match expr { + ast::Expr::UnderscoreExpr(_) => self.alloc_pat_from_expr(Pat::Wild, syntax_ptr), + ast::Expr::ParenExpr(e) => { + // We special-case `(..)` for consistency with patterns. + if let Some(ast::Expr::RangeExpr(range)) = e.expr() { + if range.is_range_full() { + return Some(self.alloc_pat_from_expr( + Pat::Tuple { args: Box::default(), ellipsis: Some(0) }, + syntax_ptr, + )); + } + } + return e.expr().and_then(|expr| self.maybe_collect_expr_as_pat(&expr)); + } + ast::Expr::TupleExpr(e) => { + let (ellipsis, args) = collect_tuple(self, e.fields()); + self.alloc_pat_from_expr(Pat::Tuple { args, ellipsis }, syntax_ptr) + } + ast::Expr::ArrayExpr(e) => { + if e.semicolon_token().is_some() { + return None; + } + + let mut elements = e.exprs(); + let prefix = elements + .by_ref() + .map_while(|elem| collect_possibly_rest(self, elem).left()) + .collect(); + let suffix = elements.map(|elem| self.collect_expr_as_pat(elem)).collect(); + self.alloc_pat_from_expr(Pat::Slice { prefix, slice: None, suffix }, syntax_ptr) + } + ast::Expr::CallExpr(e) => { + let path = collect_path(self, e.expr()?)?; + let path = path + .path() + .and_then(|path| self.expander.parse_path(self.db, path)) + .map(Box::new); + let (ellipsis, args) = collect_tuple(self, e.arg_list()?.args()); + self.alloc_pat_from_expr(Pat::TupleStruct { path, args, ellipsis }, syntax_ptr) + } + ast::Expr::PathExpr(e) => { + let path = Box::new(self.expander.parse_path(self.db, e.path()?)?); + self.alloc_pat_from_expr(Pat::Path(path), syntax_ptr) + } + ast::Expr::MacroExpr(e) => { + let e = e.macro_call()?; + let macro_ptr = AstPtr::new(&e); + let src = self.expander.in_file(AstPtr::new(expr)); + let id = self.collect_macro_call(e, macro_ptr, true, |this, expansion| { + this.collect_expr_as_pat_opt(expansion) + }); + self.source_map.expr_map.insert(src, id.into()); + id + } + ast::Expr::RecordExpr(e) => { + let path = + e.path().and_then(|path| self.expander.parse_path(self.db, path)).map(Box::new); + let record_field_list = e.record_expr_field_list()?; + let ellipsis = record_field_list.dotdot_token().is_some(); + // FIXME: Report an error here if `record_field_list.spread().is_some()`. + let args = record_field_list + .fields() + .filter_map(|f| { + self.check_cfg(&f)?; + let field_expr = f.expr()?; + let pat = self.collect_expr_as_pat(field_expr); + let name = f.field_name()?.as_name(); + let src = self.expander.in_file(AstPtr::new(&f).wrap_left()); + self.source_map.pat_field_map_back.insert(pat, src); + Some(RecordFieldPat { name, pat }) + }) + .collect(); + self.alloc_pat_from_expr(Pat::Record { path, args, ellipsis }, syntax_ptr) + } + _ => return None, + }; + return Some(result); + + fn collect_path(this: &mut ExprCollector<'_>, expr: ast::Expr) -> Option { + match expr { + ast::Expr::PathExpr(e) => Some(e), + ast::Expr::MacroExpr(mac) => { + let call = mac.macro_call()?; + { + let macro_ptr = AstPtr::new(&call); + this.collect_macro_call(call, macro_ptr, true, |this, expanded_path| { + collect_path(this, expanded_path?) + }) + } + } + _ => None, + } + } + + fn collect_possibly_rest( + this: &mut ExprCollector<'_>, + expr: ast::Expr, + ) -> Either { + match &expr { + ast::Expr::RangeExpr(e) if e.is_range_full() => Either::Right(()), + ast::Expr::MacroExpr(mac) => match mac.macro_call() { + Some(call) => { + let macro_ptr = AstPtr::new(&call); + let pat = this.collect_macro_call( + call, + macro_ptr, + true, + |this, expanded_expr| match expanded_expr { + Some(expanded_pat) => collect_possibly_rest(this, expanded_pat), + None => Either::Left(this.missing_pat()), + }, + ); + if let Either::Left(pat) = pat { + let src = this.expander.in_file(AstPtr::new(&expr).wrap_left()); + this.source_map.pat_map_back.insert(pat, src); + } + pat + } + None => { + let ptr = AstPtr::new(&expr); + Either::Left(this.alloc_pat_from_expr(Pat::Missing, ptr)) + } + }, + _ => Either::Left(this.collect_expr_as_pat(expr)), + } + } + + fn collect_tuple( + this: &mut ExprCollector<'_>, + fields: ast::AstChildren, + ) -> (Option, Box<[la_arena::Idx]>) { + let mut ellipsis = None; + let args = fields + .enumerate() + .filter_map(|(idx, elem)| { + match collect_possibly_rest(this, elem) { + Either::Left(pat) => Some(pat), + Either::Right(()) => { + if ellipsis.is_none() { + ellipsis = Some(idx as u32); + } + // FIXME: Report an error here otherwise. + None + } + } + }) + .collect(); + (ellipsis, args) + } + } + fn initialize_binding_owner( &mut self, syntax_ptr: AstPtr, @@ -755,17 +900,13 @@ impl ExprCollector<'_> { let callee = self.alloc_expr_desugared_with_ptr(Expr::Path(try_from_output), ptr); let next_tail = match btail { - Some(tail) => self.alloc_expr_desugared_with_ptr( - Expr::Call { callee, args: Box::new([tail]), is_assignee_expr: false }, - ptr, - ), + Some(tail) => self + .alloc_expr_desugared_with_ptr(Expr::Call { callee, args: Box::new([tail]) }, ptr), None => { - let unit = self.alloc_expr_desugared_with_ptr( - Expr::Tuple { exprs: Box::new([]), is_assignee_expr: false }, - ptr, - ); + let unit = + self.alloc_expr_desugared_with_ptr(Expr::Tuple { exprs: Box::new([]) }, ptr); self.alloc_expr_desugared_with_ptr( - Expr::Call { callee, args: Box::new([unit]), is_assignee_expr: false }, + Expr::Call { callee, args: Box::new([unit]) }, ptr, ) } @@ -851,11 +992,7 @@ impl ExprCollector<'_> { let head = self.collect_expr_opt(e.iterable()); let into_iter_fn_expr = self.alloc_expr(Expr::Path(into_iter_fn), syntax_ptr); let iterator = self.alloc_expr( - Expr::Call { - callee: into_iter_fn_expr, - args: Box::new([head]), - is_assignee_expr: false, - }, + Expr::Call { callee: into_iter_fn_expr, args: Box::new([head]) }, syntax_ptr, ); let none_arm = MatchArm { @@ -884,11 +1021,7 @@ impl ExprCollector<'_> { ); let iter_next_fn_expr = self.alloc_expr(Expr::Path(iter_next_fn), syntax_ptr); let iter_next_expr = self.alloc_expr( - Expr::Call { - callee: iter_next_fn_expr, - args: Box::new([iter_expr_mut]), - is_assignee_expr: false, - }, + Expr::Call { callee: iter_next_fn_expr, args: Box::new([iter_expr_mut]) }, syntax_ptr, ); let loop_inner = self.alloc_expr( @@ -942,10 +1075,8 @@ impl ExprCollector<'_> { }; let operand = self.collect_expr_opt(e.expr()); let try_branch = self.alloc_expr(Expr::Path(try_branch), syntax_ptr); - let expr = self.alloc_expr( - Expr::Call { callee: try_branch, args: Box::new([operand]), is_assignee_expr: false }, - syntax_ptr, - ); + let expr = self + .alloc_expr(Expr::Call { callee: try_branch, args: Box::new([operand]) }, syntax_ptr); let continue_name = Name::generate_new_name(self.body.bindings.len()); let continue_binding = self.alloc_binding(continue_name.clone(), BindingAnnotation::Unannotated); @@ -975,10 +1106,8 @@ impl ExprCollector<'_> { expr: { let it = self.alloc_expr(Expr::Path(Path::from(break_name)), syntax_ptr); let callee = self.alloc_expr(Expr::Path(try_from_residual), syntax_ptr); - let result = self.alloc_expr( - Expr::Call { callee, args: Box::new([it]), is_assignee_expr: false }, - syntax_ptr, - ); + let result = + self.alloc_expr(Expr::Call { callee, args: Box::new([it]) }, syntax_ptr); self.alloc_expr( match self.current_try_block_label { Some(label) => Expr::Break { expr: Some(result), label: Some(label) }, @@ -1108,7 +1237,7 @@ impl ExprCollector<'_> { // Make the macro-call point to its expanded expression so we can query // semantics on syntax pointers to the macro let src = self.expander.in_file(syntax_ptr); - self.source_map.expr_map.insert(src, tail); + self.source_map.expr_map.insert(src, tail.into()); }) } @@ -1372,7 +1501,7 @@ impl ExprCollector<'_> { let ast_pat = f.pat()?; let pat = self.collect_pat(ast_pat, binding_list); let name = f.field_name()?.as_name(); - let src = self.expander.in_file(AstPtr::new(&f)); + let src = self.expander.in_file(AstPtr::new(&f).wrap_right()); self.source_map.pat_field_map_back.insert(pat, src); Some(RecordFieldPat { name, pat }) }) @@ -1723,10 +1852,8 @@ impl ExprCollector<'_> { } }) .collect(); - let lit_pieces = self.alloc_expr_desugared(Expr::Array(Array::ElementList { - elements: lit_pieces, - is_assignee_expr: false, - })); + let lit_pieces = + self.alloc_expr_desugared(Expr::Array(Array::ElementList { elements: lit_pieces })); let lit_pieces = self.alloc_expr_desugared(Expr::Ref { expr: lit_pieces, rawness: Rawness::Ref, @@ -1743,10 +1870,7 @@ impl ExprCollector<'_> { Some(self.make_format_spec(placeholder, &mut argmap)) }) .collect(); - let array = self.alloc_expr_desugared(Expr::Array(Array::ElementList { - elements, - is_assignee_expr: false, - })); + let array = self.alloc_expr_desugared(Expr::Array(Array::ElementList { elements })); self.alloc_expr_desugared(Expr::Ref { expr: array, rawness: Rawness::Ref, @@ -1756,10 +1880,8 @@ impl ExprCollector<'_> { let arguments = &*fmt.arguments.arguments; let args = if arguments.is_empty() { - let expr = self.alloc_expr_desugared(Expr::Array(Array::ElementList { - elements: Box::default(), - is_assignee_expr: false, - })); + let expr = self + .alloc_expr_desugared(Expr::Array(Array::ElementList { elements: Box::default() })); self.alloc_expr_desugared(Expr::Ref { expr, rawness: Rawness::Ref, @@ -1786,10 +1908,8 @@ impl ExprCollector<'_> { self.make_argument(arg, ty) }) .collect(); - let array = self.alloc_expr_desugared(Expr::Array(Array::ElementList { - elements: args, - is_assignee_expr: false, - })); + let array = + self.alloc_expr_desugared(Expr::Array(Array::ElementList { elements: args })); self.alloc_expr_desugared(Expr::Ref { expr: array, rawness: Rawness::Ref, @@ -1822,11 +1942,8 @@ impl ExprCollector<'_> { let new_v1_formatted = self.alloc_expr_desugared(Expr::Path(new_v1_formatted)); let unsafe_arg_new = self.alloc_expr_desugared(Expr::Path(unsafe_arg_new)); - let unsafe_arg_new = self.alloc_expr_desugared(Expr::Call { - callee: unsafe_arg_new, - args: Box::default(), - is_assignee_expr: false, - }); + let unsafe_arg_new = + self.alloc_expr_desugared(Expr::Call { callee: unsafe_arg_new, args: Box::default() }); let unsafe_arg_new = self.alloc_expr_desugared(Expr::Unsafe { id: None, // We collect the unused expressions here so that we still infer them instead of @@ -1843,7 +1960,6 @@ impl ExprCollector<'_> { Expr::Call { callee: new_v1_formatted, args: Box::new([lit_pieces, args, format_options, unsafe_arg_new]), - is_assignee_expr: false, }, syntax_ptr, ); @@ -1938,7 +2054,6 @@ impl ExprCollector<'_> { self.alloc_expr_desugared(Expr::Call { callee: format_placeholder_new, args: Box::new([position, fill, align, flags, precision, width]), - is_assignee_expr: false, }) } @@ -1980,11 +2095,7 @@ impl ExprCollector<'_> { Some(count_is) => self.alloc_expr_desugared(Expr::Path(count_is)), None => self.missing_expr(), }; - self.alloc_expr_desugared(Expr::Call { - callee: count_is, - args: Box::new([args]), - is_assignee_expr: false, - }) + self.alloc_expr_desugared(Expr::Call { callee: count_is, args: Box::new([args]) }) } Some(FormatCount::Argument(arg)) => { if let Ok(arg_index) = arg.index { @@ -2005,7 +2116,6 @@ impl ExprCollector<'_> { self.alloc_expr_desugared(Expr::Call { callee: count_param, args: Box::new([args]), - is_assignee_expr: false, }) } else { // FIXME: This drops arg causing it to potentially not be resolved/type checked @@ -2054,11 +2164,7 @@ impl ExprCollector<'_> { Some(new_fn) => self.alloc_expr_desugared(Expr::Path(new_fn)), None => self.missing_expr(), }; - self.alloc_expr_desugared(Expr::Call { - callee: new_fn, - args: Box::new([arg]), - is_assignee_expr: false, - }) + self.alloc_expr_desugared(Expr::Call { callee: new_fn, args: Box::new([arg]) }) } // endregion: format @@ -2082,7 +2188,7 @@ impl ExprCollector<'_> { let src = self.expander.in_file(ptr); let id = self.body.exprs.alloc(expr); self.source_map.expr_map_back.insert(id, src); - self.source_map.expr_map.insert(src, id); + self.source_map.expr_map.insert(src, id.into()); id } // FIXME: desugared exprs don't have ptr, that's wrong and should be fixed. @@ -2110,10 +2216,17 @@ impl ExprCollector<'_> { binding } + fn alloc_pat_from_expr(&mut self, pat: Pat, ptr: ExprPtr) -> PatId { + let src = self.expander.in_file(ptr); + let id = self.body.pats.alloc(pat); + self.source_map.expr_map.insert(src, id.into()); + self.source_map.pat_map_back.insert(id, src.map(AstPtr::wrap_left)); + id + } fn alloc_pat(&mut self, pat: Pat, ptr: PatPtr) -> PatId { let src = self.expander.in_file(ptr); let id = self.body.pats.alloc(pat); - self.source_map.pat_map_back.insert(id, src); + self.source_map.pat_map_back.insert(id, src.map(AstPtr::wrap_right)); self.source_map.pat_map.insert(src, id); id } diff --git a/crates/hir-def/src/body/pretty.rs b/crates/hir-def/src/body/pretty.rs index 37167fcb8155..2419862b7617 100644 --- a/crates/hir-def/src/body/pretty.rs +++ b/crates/hir-def/src/body/pretty.rs @@ -277,7 +277,7 @@ impl Printer<'_> { w!(self, "loop "); self.print_expr(*body); } - Expr::Call { callee, args, is_assignee_expr: _ } => { + Expr::Call { callee, args } => { self.print_expr(*callee); w!(self, "("); if !args.is_empty() { @@ -372,7 +372,7 @@ impl Printer<'_> { self.print_expr(*expr); } } - Expr::RecordLit { path, fields, spread, ellipsis, is_assignee_expr: _ } => { + Expr::RecordLit { path, fields, spread } => { match path { Some(path) => self.print_path(path), None => w!(self, "�"), @@ -391,9 +391,6 @@ impl Printer<'_> { p.print_expr(*spread); wln!(p); } - if *ellipsis { - wln!(p, ".."); - } }); w!(self, "}}"); } @@ -466,7 +463,7 @@ impl Printer<'_> { w!(self, ") "); } } - Expr::Index { base, index, is_assignee_expr: _ } => { + Expr::Index { base, index } => { self.print_expr(*base); w!(self, "["); self.print_expr(*index); @@ -507,7 +504,7 @@ impl Printer<'_> { self.whitespace(); self.print_expr(*body); } - Expr::Tuple { exprs, is_assignee_expr: _ } => { + Expr::Tuple { exprs } => { w!(self, "("); for expr in exprs.iter() { self.print_expr(*expr); @@ -519,7 +516,7 @@ impl Printer<'_> { w!(self, "["); if !matches!(arr, Array::ElementList { elements, .. } if elements.is_empty()) { self.indented(|p| match arr { - Array::ElementList { elements, is_assignee_expr: _ } => { + Array::ElementList { elements } => { for elem in elements.iter() { p.print_expr(*elem); w!(p, ", "); @@ -551,6 +548,11 @@ impl Printer<'_> { Expr::Const(id) => { w!(self, "const {{ /* {id:?} */ }}"); } + &Expr::Assignment { target, value } => { + self.print_pat(target); + w!(self, " = "); + self.print_expr(value); + } } } @@ -719,6 +721,9 @@ impl Printer<'_> { w!(self, "const "); self.print_expr(*c); } + Pat::Expr(expr) => { + self.print_expr(*expr); + } } } diff --git a/crates/hir-def/src/body/scope.rs b/crates/hir-def/src/body/scope.rs index bf201ca83479..c6967961b33f 100644 --- a/crates/hir-def/src/body/scope.rs +++ b/crates/hir-def/src/body/scope.rs @@ -282,7 +282,7 @@ fn compute_expr_scopes( *scope = scopes.new_scope(*scope); scopes.add_pat_bindings(body, *scope, pat); } - e => e.walk_child_exprs(|e| compute_expr_scopes(scopes, e, scope)), + _ => body.walk_child_exprs(expr, |e| compute_expr_scopes(scopes, e, scope)), }; } @@ -333,6 +333,8 @@ mod tests { let expr_id = source_map .node_expr(InFile { file_id: file_id.into(), value: &marker.into() }) + .unwrap() + .as_expr() .unwrap(); let scope = scopes.scope_for(expr_id); @@ -488,8 +490,11 @@ fn foo() { let expr_scope = { let expr_ast = name_ref.syntax().ancestors().find_map(ast::Expr::cast).unwrap(); - let expr_id = - source_map.node_expr(InFile { file_id: file_id.into(), value: &expr_ast }).unwrap(); + let expr_id = source_map + .node_expr(InFile { file_id: file_id.into(), value: &expr_ast }) + .unwrap() + .as_expr() + .unwrap(); scopes.scope_for(expr_id).unwrap() }; diff --git a/crates/hir-def/src/body/tests.rs b/crates/hir-def/src/body/tests.rs index dd3e79c874d8..3b29d98d198f 100644 --- a/crates/hir-def/src/body/tests.rs +++ b/crates/hir-def/src/body/tests.rs @@ -370,3 +370,37 @@ fn f(a: i32, b: u32) -> String { }"#]] .assert_eq(&body.pretty_print(&db, def, Edition::CURRENT)) } + +#[test] +fn destructuring_assignment_tuple_macro() { + // This is a funny one. `let m!()() = Bar()` is an error in rustc, because `m!()()` isn't a valid pattern, + // but in destructuring assignment it is valid, because `m!()()` is a valid expression, and destructuring + // assignments start their lives as expressions. So we have to do the same. + + let (db, body, def) = lower( + r#" +struct Bar(); + +macro_rules! m { + () => { Bar }; +} + +fn foo() { + m!()() = Bar(); +} +"#, + ); + + let (_, source_map) = db.body_with_source_map(def); + assert_eq!(source_map.diagnostics(), &[]); + + for (_, def_map) in body.blocks(&db) { + assert_eq!(def_map.diagnostics(), &[]); + } + + expect![[r#" + fn foo() -> () { + Bar() = Bar(); + }"#]] + .assert_eq(&body.pretty_print(&db, def, Edition::CURRENT)) +} diff --git a/crates/hir-def/src/hir.rs b/crates/hir-def/src/hir.rs index d9358a28822e..a575a2d19913 100644 --- a/crates/hir-def/src/hir.rs +++ b/crates/hir-def/src/hir.rs @@ -48,6 +48,22 @@ pub enum ExprOrPatId { ExprId(ExprId), PatId(PatId), } + +impl ExprOrPatId { + pub fn as_expr(self) -> Option { + match self { + Self::ExprId(v) => Some(v), + _ => None, + } + } + + pub fn as_pat(self) -> Option { + match self { + Self::PatId(v) => Some(v), + _ => None, + } + } +} stdx::impl_from!(ExprId, PatId for ExprOrPatId); #[derive(Debug, Clone, Eq, PartialEq)] @@ -204,7 +220,6 @@ pub enum Expr { Call { callee: ExprId, args: Box<[ExprId]>, - is_assignee_expr: bool, }, MethodCall { receiver: ExprId, @@ -239,8 +254,6 @@ pub enum Expr { path: Option>, fields: Box<[RecordLitField]>, spread: Option, - ellipsis: bool, - is_assignee_expr: bool, }, Field { expr: ExprId, @@ -265,11 +278,17 @@ pub enum Expr { expr: ExprId, op: UnaryOp, }, + /// `op` cannot be bare `=` (but can be `op=`), these are lowered to `Assignment` instead. BinaryOp { lhs: ExprId, rhs: ExprId, op: Option, }, + // Assignments need a special treatment because of destructuring assignment. + Assignment { + target: PatId, + value: ExprId, + }, Range { lhs: Option, rhs: Option, @@ -278,7 +297,6 @@ pub enum Expr { Index { base: ExprId, index: ExprId, - is_assignee_expr: bool, }, Closure { args: Box<[PatId]>, @@ -290,7 +308,6 @@ pub enum Expr { }, Tuple { exprs: Box<[ExprId]>, - is_assignee_expr: bool, }, Array(Array), Literal(Literal), @@ -446,7 +463,7 @@ pub enum Movability { #[derive(Debug, Clone, Eq, PartialEq)] pub enum Array { - ElementList { elements: Box<[ExprId]>, is_assignee_expr: bool }, + ElementList { elements: Box<[ExprId]> }, Repeat { initializer: ExprId, repeat: ExprId }, } @@ -480,130 +497,6 @@ pub enum Statement { Item, } -impl Expr { - pub fn walk_child_exprs(&self, mut f: impl FnMut(ExprId)) { - match self { - Expr::Missing => {} - Expr::Path(_) | Expr::OffsetOf(_) => {} - Expr::InlineAsm(it) => it.operands.iter().for_each(|(_, op)| match op { - AsmOperand::In { expr, .. } - | AsmOperand::Out { expr: Some(expr), .. } - | AsmOperand::InOut { expr, .. } => f(*expr), - AsmOperand::SplitInOut { in_expr, out_expr, .. } => { - f(*in_expr); - if let Some(out_expr) = out_expr { - f(*out_expr); - } - } - AsmOperand::Out { expr: None, .. } - | AsmOperand::Const(_) - | AsmOperand::Label(_) - | AsmOperand::Sym(_) => (), - }), - Expr::If { condition, then_branch, else_branch } => { - f(*condition); - f(*then_branch); - if let &Some(else_branch) = else_branch { - f(else_branch); - } - } - Expr::Let { expr, .. } => { - f(*expr); - } - Expr::Const(_) => (), - Expr::Block { statements, tail, .. } - | Expr::Unsafe { statements, tail, .. } - | Expr::Async { statements, tail, .. } => { - for stmt in statements.iter() { - match stmt { - Statement::Let { initializer, else_branch, .. } => { - if let &Some(expr) = initializer { - f(expr); - } - if let &Some(expr) = else_branch { - f(expr); - } - } - Statement::Expr { expr: expression, .. } => f(*expression), - Statement::Item => (), - } - } - if let &Some(expr) = tail { - f(expr); - } - } - Expr::Loop { body, .. } => f(*body), - Expr::Call { callee, args, .. } => { - f(*callee); - args.iter().copied().for_each(f); - } - Expr::MethodCall { receiver, args, .. } => { - f(*receiver); - args.iter().copied().for_each(f); - } - Expr::Match { expr, arms } => { - f(*expr); - arms.iter().map(|arm| arm.expr).for_each(f); - } - Expr::Continue { .. } => {} - Expr::Break { expr, .. } - | Expr::Return { expr } - | Expr::Yield { expr } - | Expr::Yeet { expr } => { - if let &Some(expr) = expr { - f(expr); - } - } - Expr::Become { expr } => f(*expr), - Expr::RecordLit { fields, spread, .. } => { - for field in fields.iter() { - f(field.expr); - } - if let &Some(expr) = spread { - f(expr); - } - } - Expr::Closure { body, .. } => { - f(*body); - } - Expr::BinaryOp { lhs, rhs, .. } => { - f(*lhs); - f(*rhs); - } - Expr::Range { lhs, rhs, .. } => { - if let &Some(lhs) = rhs { - f(lhs); - } - if let &Some(rhs) = lhs { - f(rhs); - } - } - Expr::Index { base, index, .. } => { - f(*base); - f(*index); - } - Expr::Field { expr, .. } - | Expr::Await { expr } - | Expr::Cast { expr, .. } - | Expr::Ref { expr, .. } - | Expr::UnaryOp { expr, .. } - | Expr::Box { expr } => { - f(*expr); - } - Expr::Tuple { exprs, .. } => exprs.iter().copied().for_each(f), - Expr::Array(a) => match a { - Array::ElementList { elements, .. } => elements.iter().copied().for_each(f), - Array::Repeat { initializer, repeat } => { - f(*initializer); - f(*repeat) - } - }, - Expr::Literal(_) => {} - Expr::Underscore => {} - } - } -} - /// Explicit binding annotations given in the HIR for a binding. Note /// that this is not the final binding *mode* that we infer after type /// inference. @@ -665,18 +558,49 @@ pub struct RecordFieldPat { pub enum Pat { Missing, Wild, - Tuple { args: Box<[PatId]>, ellipsis: Option }, + Tuple { + args: Box<[PatId]>, + ellipsis: Option, + }, Or(Box<[PatId]>), - Record { path: Option>, args: Box<[RecordFieldPat]>, ellipsis: bool }, - Range { start: Option>, end: Option> }, - Slice { prefix: Box<[PatId]>, slice: Option, suffix: Box<[PatId]> }, + Record { + path: Option>, + args: Box<[RecordFieldPat]>, + ellipsis: bool, + }, + Range { + start: Option>, + end: Option>, + }, + Slice { + prefix: Box<[PatId]>, + slice: Option, + suffix: Box<[PatId]>, + }, + /// This might refer to a variable if a single segment path (specifically, on destructuring assignment). Path(Box), Lit(ExprId), - Bind { id: BindingId, subpat: Option }, - TupleStruct { path: Option>, args: Box<[PatId]>, ellipsis: Option }, - Ref { pat: PatId, mutability: Mutability }, - Box { inner: PatId }, + Bind { + id: BindingId, + subpat: Option, + }, + TupleStruct { + path: Option>, + args: Box<[PatId]>, + ellipsis: Option, + }, + Ref { + pat: PatId, + mutability: Mutability, + }, + Box { + inner: PatId, + }, ConstBlock(ExprId), + /// An expression inside a pattern. That can only occur inside assignments. + /// + /// E.g. in `(a, *b) = (1, &mut 2)`, `*b` is an expression. + Expr(ExprId), } impl Pat { @@ -687,7 +611,8 @@ impl Pat { | Pat::Path(..) | Pat::ConstBlock(..) | Pat::Wild - | Pat::Missing => {} + | Pat::Missing + | Pat::Expr(_) => {} Pat::Bind { subpat, .. } => { subpat.iter().copied().for_each(f); } diff --git a/crates/hir-def/src/test_db.rs b/crates/hir-def/src/test_db.rs index 4db21eb46bd5..0c36c88fb093 100644 --- a/crates/hir-def/src/test_db.rs +++ b/crates/hir-def/src/test_db.rs @@ -198,7 +198,10 @@ impl TestDB { .filter_map(|node| { let block = ast::BlockExpr::cast(node)?; let expr = ast::Expr::from(block); - let expr_id = source_map.node_expr(InFile::new(position.file_id.into(), &expr))?; + let expr_id = source_map + .node_expr(InFile::new(position.file_id.into(), &expr))? + .as_expr() + .unwrap(); let scope = scopes.scope_for(expr_id).unwrap(); Some(scope) }); diff --git a/crates/hir-ty/src/consteval.rs b/crates/hir-ty/src/consteval.rs index e41058aac2a9..6a8598884404 100644 --- a/crates/hir-ty/src/consteval.rs +++ b/crates/hir-ty/src/consteval.rs @@ -319,7 +319,7 @@ pub(crate) fn eval_to_const( return true; } let mut r = false; - body[expr].walk_child_exprs(|idx| r |= has_closure(body, idx)); + body.walk_child_exprs(expr, |idx| r |= has_closure(body, idx)); r } if has_closure(ctx.body, expr) { diff --git a/crates/hir-ty/src/diagnostics/expr.rs b/crates/hir-ty/src/diagnostics/expr.rs index f8b5c7d0ce2c..75dce7783163 100644 --- a/crates/hir-ty/src/diagnostics/expr.rs +++ b/crates/hir-ty/src/diagnostics/expr.rs @@ -546,10 +546,7 @@ pub fn record_literal_missing_fields( expr: &Expr, ) -> Option<(VariantId, Vec, /*exhaustive*/ bool)> { let (fields, exhaustive) = match expr { - Expr::RecordLit { fields, spread, ellipsis, is_assignee_expr, .. } => { - let exhaustive = if *is_assignee_expr { !*ellipsis } else { spread.is_none() }; - (fields, exhaustive) - } + Expr::RecordLit { fields, spread, .. } => (fields, spread.is_none()), _ => return None, }; diff --git a/crates/hir-ty/src/diagnostics/unsafe_check.rs b/crates/hir-ty/src/diagnostics/unsafe_check.rs index bcfc37c86711..492262c7a4cf 100644 --- a/crates/hir-ty/src/diagnostics/unsafe_check.rs +++ b/crates/hir-ty/src/diagnostics/unsafe_check.rs @@ -3,7 +3,7 @@ use hir_def::{ body::Body, - hir::{Expr, ExprId, UnaryOp}, + hir::{Expr, ExprId, ExprOrPatId, Pat, UnaryOp}, resolver::{resolver_for_expr, ResolveValueResult, Resolver, ValueNs}, type_ref::Rawness, DefWithBodyId, @@ -16,7 +16,7 @@ use crate::{ /// Returns `(unsafe_exprs, fn_is_unsafe)`. /// /// If `fn_is_unsafe` is false, `unsafe_exprs` are hard errors. If true, they're `unsafe_op_in_unsafe_fn`. -pub fn missing_unsafe(db: &dyn HirDatabase, def: DefWithBodyId) -> (Vec, bool) { +pub fn missing_unsafe(db: &dyn HirDatabase, def: DefWithBodyId) -> (Vec, bool) { let _p = tracing::info_span!("missing_unsafe").entered(); let mut res = Vec::new(); @@ -32,7 +32,7 @@ pub fn missing_unsafe(db: &dyn HirDatabase, def: DefWithBodyId) -> (Vec, let infer = db.infer(def); unsafe_expressions(db, &infer, def, &body, body.body_expr, &mut |expr| { if !expr.inside_unsafe_block { - res.push(expr.expr); + res.push(expr.node); } }); @@ -40,7 +40,7 @@ pub fn missing_unsafe(db: &dyn HirDatabase, def: DefWithBodyId) -> (Vec, } pub struct UnsafeExpr { - pub expr: ExprId, + pub node: ExprOrPatId, pub inside_unsafe_block: bool, } @@ -75,26 +75,28 @@ fn walk_unsafe( inside_unsafe_block: bool, unsafe_expr_cb: &mut dyn FnMut(UnsafeExpr), ) { + let mut mark_unsafe_path = |path, node| { + let g = resolver.update_to_inner_scope(db.upcast(), def, current); + let value_or_partial = resolver.resolve_path_in_value_ns(db.upcast(), path); + if let Some(ResolveValueResult::ValueNs(ValueNs::StaticId(id), _)) = value_or_partial { + let static_data = db.static_data(id); + if static_data.mutable || (static_data.is_extern && !static_data.has_safe_kw) { + unsafe_expr_cb(UnsafeExpr { node, inside_unsafe_block }); + } + } + resolver.reset_to_guard(g); + }; + let expr = &body.exprs[current]; match expr { &Expr::Call { callee, .. } => { if let Some(func) = infer[callee].as_fn_def(db) { if is_fn_unsafe_to_call(db, func) { - unsafe_expr_cb(UnsafeExpr { expr: current, inside_unsafe_block }); - } - } - } - Expr::Path(path) => { - let g = resolver.update_to_inner_scope(db.upcast(), def, current); - let value_or_partial = resolver.resolve_path_in_value_ns(db.upcast(), path); - if let Some(ResolveValueResult::ValueNs(ValueNs::StaticId(id), _)) = value_or_partial { - let static_data = db.static_data(id); - if static_data.mutable || (static_data.is_extern && !static_data.has_safe_kw) { - unsafe_expr_cb(UnsafeExpr { expr: current, inside_unsafe_block }); + unsafe_expr_cb(UnsafeExpr { node: current.into(), inside_unsafe_block }); } } - resolver.reset_to_guard(g); } + Expr::Path(path) => mark_unsafe_path(path, current.into()), Expr::Ref { expr, rawness: Rawness::RawPtr, mutability: _ } => { if let Expr::Path(_) = body.exprs[*expr] { // Do not report unsafe for `addr_of[_mut]!(EXTERN_OR_MUT_STATIC)`, @@ -108,23 +110,30 @@ fn walk_unsafe( .map(|(func, _)| is_fn_unsafe_to_call(db, func)) .unwrap_or(false) { - unsafe_expr_cb(UnsafeExpr { expr: current, inside_unsafe_block }); + unsafe_expr_cb(UnsafeExpr { node: current.into(), inside_unsafe_block }); } } Expr::UnaryOp { expr, op: UnaryOp::Deref } => { if let TyKind::Raw(..) = &infer[*expr].kind(Interner) { - unsafe_expr_cb(UnsafeExpr { expr: current, inside_unsafe_block }); + unsafe_expr_cb(UnsafeExpr { node: current.into(), inside_unsafe_block }); } } Expr::Unsafe { .. } => { - return expr.walk_child_exprs(|child| { + return body.walk_child_exprs(current, |child| { walk_unsafe(db, infer, body, resolver, def, child, true, unsafe_expr_cb); }); } + &Expr::Assignment { target, value: _ } => { + body.walk_pats(target, &mut |pat| { + if let Pat::Path(path) = &body[pat] { + mark_unsafe_path(path, pat.into()); + } + }); + } _ => {} } - expr.walk_child_exprs(|child| { + body.walk_child_exprs(current, |child| { walk_unsafe(db, infer, body, resolver, def, child, inside_unsafe_block, unsafe_expr_cb); }); } diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs index 88334b492d5a..db16899d9f94 100644 --- a/crates/hir-ty/src/infer.rs +++ b/crates/hir-ty/src/infer.rs @@ -228,7 +228,7 @@ pub enum InferenceDiagnostic { id: ExprOrPatId, }, UnresolvedIdent { - expr: ExprId, + id: ExprOrPatId, }, // FIXME: This should be emitted in body lowering BreakOutsideOfLoop { @@ -482,12 +482,27 @@ impl InferenceResult { pub fn variant_resolution_for_pat(&self, id: PatId) -> Option { self.variant_resolutions.get(&id.into()).copied() } + pub fn variant_resolution_for_expr_or_pat(&self, id: ExprOrPatId) -> Option { + match id { + ExprOrPatId::ExprId(id) => self.variant_resolution_for_expr(id), + ExprOrPatId::PatId(id) => self.variant_resolution_for_pat(id), + } + } pub fn assoc_resolutions_for_expr(&self, id: ExprId) -> Option<(AssocItemId, Substitution)> { self.assoc_resolutions.get(&id.into()).cloned() } pub fn assoc_resolutions_for_pat(&self, id: PatId) -> Option<(AssocItemId, Substitution)> { self.assoc_resolutions.get(&id.into()).cloned() } + pub fn assoc_resolutions_for_expr_or_pat( + &self, + id: ExprOrPatId, + ) -> Option<(AssocItemId, Substitution)> { + match id { + ExprOrPatId::ExprId(id) => self.assoc_resolutions_for_expr(id), + ExprOrPatId::PatId(id) => self.assoc_resolutions_for_pat(id), + } + } pub fn type_mismatch_for_expr(&self, expr: ExprId) -> Option<&TypeMismatch> { self.type_mismatches.get(&expr.into()) } @@ -506,6 +521,12 @@ impl InferenceResult { pub fn closure_info(&self, closure: &ClosureId) -> &(Vec, FnTrait) { self.closure_info.get(closure).unwrap() } + pub fn type_of_expr_or_pat(&self, id: ExprOrPatId) -> Option<&Ty> { + match id { + ExprOrPatId::ExprId(id) => self.type_of_expr.get(id), + ExprOrPatId::PatId(id) => self.type_of_pat.get(id), + } + } } impl Index for InferenceResult { @@ -524,6 +545,14 @@ impl Index for InferenceResult { } } +impl Index for InferenceResult { + type Output = Ty; + + fn index(&self, id: ExprOrPatId) -> &Ty { + self.type_of_expr_or_pat(id).unwrap_or(&self.standard_types.unknown) + } +} + impl Index for InferenceResult { type Output = Ty; @@ -561,6 +590,9 @@ pub(crate) struct InferenceContext<'a> { diverges: Diverges, breakables: Vec, + /// Whether we are inside the pattern of a destructuring assignment. + inside_assignment: bool, + deferred_cast_checks: Vec, // fields related to closure capture @@ -656,6 +688,7 @@ impl<'a> InferenceContext<'a> { current_closure: None, deferred_closures: FxHashMap::default(), closure_dependencies: FxHashMap::default(), + inside_assignment: false, } } diff --git a/crates/hir-ty/src/infer/closure.rs b/crates/hir-ty/src/infer/closure.rs index e9825cf09988..91ba2af85e9e 100644 --- a/crates/hir-ty/src/infer/closure.rs +++ b/crates/hir-ty/src/infer/closure.rs @@ -11,11 +11,12 @@ use either::Either; use hir_def::{ data::adt::VariantData, hir::{ - Array, AsmOperand, BinaryOp, BindingId, CaptureBy, Expr, ExprId, Pat, PatId, Statement, - UnaryOp, + Array, AsmOperand, BinaryOp, BindingId, CaptureBy, Expr, ExprId, ExprOrPatId, Pat, PatId, + Statement, UnaryOp, }, lang_item::LangItem, - resolver::{resolver_for_expr, ResolveValueResult, ValueNs}, + path::Path, + resolver::ValueNs, DefWithBodyId, FieldId, HasModule, TupleFieldId, TupleId, VariantId, }; use hir_expand::name::Name; @@ -508,18 +509,37 @@ impl InferenceContext<'_> { apply_adjusts_to_place(&mut self.current_capture_span_stack, r, adjustments) } + /// Pushes the span into `current_capture_span_stack`, *without clearing it first*. + fn path_place(&mut self, path: &Path, id: ExprOrPatId) -> Option { + if path.type_anchor().is_some() { + return None; + } + let result = self.resolver.resolve_path_in_value_ns_fully(self.db.upcast(), path).and_then( + |result| match result { + ValueNs::LocalBinding(binding) => { + let mir_span = match id { + ExprOrPatId::ExprId(id) => MirSpan::ExprId(id), + ExprOrPatId::PatId(id) => MirSpan::PatId(id), + }; + self.current_capture_span_stack.push(mir_span); + Some(HirPlace { local: binding, projections: Vec::new() }) + } + _ => None, + }, + ); + result + } + /// Changes `current_capture_span_stack` to contain the stack of spans for this expr. fn place_of_expr_without_adjust(&mut self, tgt_expr: ExprId) -> Option { self.current_capture_span_stack.clear(); match &self.body[tgt_expr] { Expr::Path(p) => { - let resolver = resolver_for_expr(self.db.upcast(), self.owner, tgt_expr); - if let Some(ResolveValueResult::ValueNs(ValueNs::LocalBinding(b), _)) = - resolver.resolve_path_in_value_ns(self.db.upcast(), p) - { - self.current_capture_span_stack.push(MirSpan::ExprId(tgt_expr)); - return Some(HirPlace { local: b, projections: vec![] }); - } + let resolver_guard = + self.resolver.update_to_inner_scope(self.db.upcast(), self.owner, tgt_expr); + let result = self.path_place(p, tgt_expr.into()); + self.resolver.reset_to_guard(resolver_guard); + return result; } Expr::Field { expr, name: _ } => { let mut place = self.place_of_expr(*expr)?; @@ -590,6 +610,16 @@ impl InferenceContext<'_> { } } + fn mutate_path_pat(&mut self, path: &Path, id: PatId) { + if let Some(place) = self.path_place(path, id.into()) { + self.add_capture( + place, + CaptureKind::ByRef(BorrowKind::Mut { kind: MutBorrowKind::Default }), + ); + self.current_capture_span_stack.pop(); // Remove the pattern span. + } + } + fn mutate_expr(&mut self, expr: ExprId, place: Option) { if let Some(place) = place { self.add_capture( @@ -722,7 +752,7 @@ impl InferenceContext<'_> { self.consume_expr(*tail); } } - Expr::Call { callee, args, is_assignee_expr: _ } => { + Expr::Call { callee, args } => { self.consume_expr(*callee); self.consume_exprs(args.iter().copied()); } @@ -838,7 +868,7 @@ impl InferenceContext<'_> { self.consume_expr(expr); } } - Expr::Index { base, index, is_assignee_expr: _ } => { + Expr::Index { base, index } => { self.select_from_expr(*base); self.consume_expr(*index); } @@ -862,10 +892,30 @@ impl InferenceContext<'_> { })); self.current_captures = cc; } - Expr::Array(Array::ElementList { elements: exprs, is_assignee_expr: _ }) - | Expr::Tuple { exprs, is_assignee_expr: _ } => { + Expr::Array(Array::ElementList { elements: exprs }) | Expr::Tuple { exprs } => { self.consume_exprs(exprs.iter().copied()) } + &Expr::Assignment { target, value } => { + self.walk_expr(value); + let resolver_guard = + self.resolver.update_to_inner_scope(self.db.upcast(), self.owner, tgt_expr); + match self.place_of_expr(value) { + Some(rhs_place) => { + self.inside_assignment = true; + self.consume_with_pat(rhs_place, target); + self.inside_assignment = false; + } + None => self.body.walk_pats(target, &mut |pat| match &self.body[pat] { + Pat::Path(path) => self.mutate_path_pat(path, pat), + &Pat::Expr(expr) => { + let place = self.place_of_expr(expr); + self.mutate_expr(expr, place); + } + _ => {} + }), + } + self.resolver.reset_to_guard(resolver_guard); + } Expr::Missing | Expr::Continue { .. } @@ -903,6 +953,7 @@ impl InferenceContext<'_> { | Pat::Missing | Pat::Wild | Pat::Tuple { .. } + | Pat::Expr(_) | Pat::Or(_) => (), Pat::TupleStruct { .. } | Pat::Record { .. } => { if let Some(variant) = self.result.variant_resolution_for_pat(p) { @@ -1122,11 +1173,15 @@ impl InferenceContext<'_> { } } } - Pat::Range { .. } - | Pat::Slice { .. } - | Pat::ConstBlock(_) - | Pat::Path(_) - | Pat::Lit(_) => self.consume_place(place), + Pat::Range { .. } | Pat::Slice { .. } | Pat::ConstBlock(_) | Pat::Lit(_) => { + self.consume_place(place) + } + Pat::Path(path) => { + if self.inside_assignment { + self.mutate_path_pat(path, tgt_pat); + } + self.consume_place(place); + } &Pat::Bind { id, subpat: _ } => { let mode = self.result.binding_modes[tgt_pat]; let capture_kind = match mode { @@ -1180,6 +1235,15 @@ impl InferenceContext<'_> { self.current_capture_span_stack.pop(); } Pat::Box { .. } => (), // not supported + &Pat::Expr(expr) => { + self.consume_place(place); + let pat_capture_span_stack = mem::take(&mut self.current_capture_span_stack); + let old_inside_assignment = mem::replace(&mut self.inside_assignment, false); + let lhs_place = self.place_of_expr(expr); + self.mutate_expr(expr, lhs_place); + self.inside_assignment = old_inside_assignment; + self.current_capture_span_stack = pat_capture_span_stack; + } } } self.current_capture_span_stack diff --git a/crates/hir-ty/src/infer/expr.rs b/crates/hir-ty/src/infer/expr.rs index 657e4d779661..5c822fd22e1c 100644 --- a/crates/hir-ty/src/infer/expr.rs +++ b/crates/hir-ty/src/infer/expr.rs @@ -9,8 +9,8 @@ use chalk_ir::{cast::Cast, fold::Shift, DebruijnIndex, Mutability, TyVariableKin use either::Either; use hir_def::{ hir::{ - ArithOp, Array, AsmOperand, AsmOptions, BinaryOp, ClosureKind, Expr, ExprId, LabelId, - Literal, Pat, PatId, Statement, UnaryOp, + ArithOp, Array, AsmOperand, AsmOptions, BinaryOp, ClosureKind, Expr, ExprId, ExprOrPatId, + LabelId, Literal, Pat, PatId, Statement, UnaryOp, }, lang_item::{LangItem, LangItemTarget}, path::{GenericArg, GenericArgs, Path}, @@ -188,6 +188,9 @@ impl InferenceContext<'_> { | Pat::ConstBlock(_) | Pat::Record { .. } | Pat::Missing => true, + Pat::Expr(_) => unreachable!( + "we don't call pat_guaranteed_to_constitute_read_for_never() with assignments" + ), } } @@ -223,6 +226,7 @@ impl InferenceContext<'_> { | Expr::Const(..) | Expr::UnaryOp { .. } | Expr::BinaryOp { .. } + | Expr::Assignment { .. } | Expr::Yield { .. } | Expr::Cast { .. } | Expr::Async { .. } @@ -609,23 +613,7 @@ impl InferenceContext<'_> { coerce.complete(self) } } - Expr::Path(p) => { - let g = self.resolver.update_to_inner_scope(self.db.upcast(), self.owner, tgt_expr); - let ty = match self.infer_path(p, tgt_expr.into()) { - Some(ty) => ty, - None => { - if matches!(p, Path::Normal { mod_path, .. } if mod_path.is_ident() || mod_path.is_self()) - { - self.push_diagnostic(InferenceDiagnostic::UnresolvedIdent { - expr: tgt_expr, - }); - } - self.err_ty() - } - }; - self.resolver.reset_to_guard(g); - ty - } + Expr::Path(p) => self.infer_expr_path(p, tgt_expr.into(), tgt_expr), &Expr::Continue { label } => { if find_continuable(&mut self.breakables, label).is_none() { self.push_diagnostic(InferenceDiagnostic::BreakOutsideOfLoop { @@ -892,36 +880,6 @@ impl InferenceContext<'_> { } } Expr::BinaryOp { lhs, rhs, op } => match op { - Some(BinaryOp::Assignment { op: None }) => { - let lhs = *lhs; - let is_ordinary = match &self.body[lhs] { - Expr::Array(_) - | Expr::RecordLit { .. } - | Expr::Tuple { .. } - | Expr::Underscore => false, - Expr::Call { callee, .. } => !matches!(&self.body[*callee], Expr::Path(_)), - _ => true, - }; - - // In ordinary (non-destructuring) assignments, the type of - // `lhs` must be inferred first so that the ADT fields - // instantiations in RHS can be coerced to it. Note that this - // cannot happen in destructuring assignments because of how - // they are desugared. - if is_ordinary { - // LHS of assignment doesn't constitute reads. - let lhs_ty = self.infer_expr(lhs, &Expectation::none(), ExprIsRead::No); - self.infer_expr_coerce( - *rhs, - &Expectation::has_type(lhs_ty), - ExprIsRead::No, - ); - } else { - let rhs_ty = self.infer_expr(*rhs, &Expectation::none(), ExprIsRead::Yes); - self.infer_assignee_expr(lhs, &rhs_ty); - } - self.result.standard_types.unit.clone() - } Some(BinaryOp::LogicOp(_)) => { let bool_ty = self.result.standard_types.bool_.clone(); self.infer_expr_coerce( @@ -942,6 +900,35 @@ impl InferenceContext<'_> { Some(op) => self.infer_overloadable_binop(*lhs, *op, *rhs, tgt_expr), _ => self.err_ty(), }, + &Expr::Assignment { target, value } => { + // In ordinary (non-destructuring) assignments, the type of + // `lhs` must be inferred first so that the ADT fields + // instantiations in RHS can be coerced to it. Note that this + // cannot happen in destructuring assignments because of how + // they are desugared. + let lhs_ty = match &self.body[target] { + // LHS of assignment doesn't constitute reads. + &Pat::Expr(expr) => { + Some(self.infer_expr(expr, &Expectation::none(), ExprIsRead::No)) + } + Pat::Path(path) => Some(self.infer_expr_path(path, target.into(), tgt_expr)), + _ => None, + }; + + if let Some(lhs_ty) = lhs_ty { + self.write_pat_ty(target, lhs_ty.clone()); + self.infer_expr_coerce(value, &Expectation::has_type(lhs_ty), ExprIsRead::No); + } else { + let rhs_ty = self.infer_expr(value, &Expectation::none(), ExprIsRead::Yes); + let resolver_guard = + self.resolver.update_to_inner_scope(self.db.upcast(), self.owner, tgt_expr); + self.inside_assignment = true; + self.infer_top_pat(target, &rhs_ty); + self.inside_assignment = false; + self.resolver.reset_to_guard(resolver_guard); + } + self.result.standard_types.unit.clone() + } Expr::Range { lhs, rhs, range_type } => { let lhs_ty = lhs.map(|e| self.infer_expr_inner(e, &Expectation::none(), ExprIsRead::Yes)); @@ -981,7 +968,7 @@ impl InferenceContext<'_> { (RangeOp::Inclusive, _, None) => self.err_ty(), } } - Expr::Index { base, index, is_assignee_expr } => { + Expr::Index { base, index } => { let base_ty = self.infer_expr_inner(*base, &Expectation::none(), ExprIsRead::Yes); let index_ty = self.infer_expr(*index, &Expectation::none(), ExprIsRead::Yes); @@ -1017,23 +1004,11 @@ impl InferenceContext<'_> { self.write_method_resolution(tgt_expr, func, subst); } let assoc = self.resolve_ops_index_output(); - let res = self.resolve_associated_type_with_params( + self.resolve_associated_type_with_params( self_ty.clone(), assoc, &[index_ty.clone().cast(Interner)], - ); - - if *is_assignee_expr { - if let Some(index_trait) = self.resolve_lang_trait(LangItem::IndexMut) { - let trait_ref = TyBuilder::trait_ref(self.db, index_trait) - .push(self_ty) - .fill(|_| index_ty.clone().cast(Interner)) - .build(); - self.push_obligation(trait_ref.cast(Interner)); - } - } - - res + ) } else { self.err_ty() } @@ -1151,9 +1126,7 @@ impl InferenceContext<'_> { }, }, Expr::Underscore => { - // Underscore expressions may only appear in assignee expressions, - // which are handled by `infer_assignee_expr()`. - // Any other underscore expression is an error, we render a specialized diagnostic + // Underscore expression is an error, we render a specialized diagnostic // to let the user know what type is expected though. let expected = expected.to_option(&mut self.table).unwrap_or_else(|| self.err_ty()); self.push_diagnostic(InferenceDiagnostic::TypedHole { @@ -1232,6 +1205,22 @@ impl InferenceContext<'_> { ty } + fn infer_expr_path(&mut self, path: &Path, id: ExprOrPatId, scope_id: ExprId) -> Ty { + let g = self.resolver.update_to_inner_scope(self.db.upcast(), self.owner, scope_id); + let ty = match self.infer_path(path, id) { + Some(ty) => ty, + None => { + if matches!(path, Path::Normal { mod_path, .. } if mod_path.is_ident() || mod_path.is_self()) + { + self.push_diagnostic(InferenceDiagnostic::UnresolvedIdent { id }); + } + self.err_ty() + } + }; + self.resolver.reset_to_guard(g); + ty + } + fn infer_async_block( &mut self, tgt_expr: ExprId, @@ -1482,107 +1471,6 @@ impl InferenceContext<'_> { } } - pub(super) fn infer_assignee_expr(&mut self, lhs: ExprId, rhs_ty: &Ty) -> Ty { - let is_rest_expr = |expr| { - matches!( - &self.body[expr], - Expr::Range { lhs: None, rhs: None, range_type: RangeOp::Exclusive }, - ) - }; - - let rhs_ty = self.resolve_ty_shallow(rhs_ty); - - let ty = match &self.body[lhs] { - Expr::Tuple { exprs, .. } => { - // We don't consider multiple ellipses. This is analogous to - // `hir_def::body::lower::ExprCollector::collect_tuple_pat()`. - let ellipsis = exprs.iter().position(|e| is_rest_expr(*e)).map(|it| it as u32); - let exprs: Vec<_> = exprs.iter().filter(|e| !is_rest_expr(**e)).copied().collect(); - - self.infer_tuple_pat_like(&rhs_ty, (), ellipsis, &exprs) - } - Expr::Call { callee, args, .. } => { - // Tuple structs - let path = match &self.body[*callee] { - Expr::Path(path) => Some(path), - _ => None, - }; - - // We don't consider multiple ellipses. This is analogous to - // `hir_def::body::lower::ExprCollector::collect_tuple_pat()`. - let ellipsis = args.iter().position(|e| is_rest_expr(*e)).map(|it| it as u32); - let args: Vec<_> = args.iter().filter(|e| !is_rest_expr(**e)).copied().collect(); - - self.infer_tuple_struct_pat_like(path, &rhs_ty, (), lhs, ellipsis, &args) - } - Expr::Array(Array::ElementList { elements, .. }) => { - let elem_ty = match rhs_ty.kind(Interner) { - TyKind::Array(st, _) => st.clone(), - _ => self.err_ty(), - }; - - // There's no need to handle `..` as it cannot be bound. - let sub_exprs = elements.iter().filter(|e| !is_rest_expr(**e)); - - for e in sub_exprs { - self.infer_assignee_expr(*e, &elem_ty); - } - - match rhs_ty.kind(Interner) { - TyKind::Array(_, _) => rhs_ty.clone(), - // Even when `rhs_ty` is not an array type, this assignee - // expression is inferred to be an array (of unknown element - // type and length). This should not be just an error type, - // because we are to compute the unifiability of this type and - // `rhs_ty` in the end of this function to issue type mismatches. - _ => TyKind::Array( - self.err_ty(), - crate::consteval::usize_const(self.db, None, self.resolver.krate()), - ) - .intern(Interner), - } - } - Expr::RecordLit { path, fields, .. } => { - let subs = fields.iter().map(|f| (f.name.clone(), f.expr)); - - self.infer_record_pat_like(path.as_deref(), &rhs_ty, (), lhs, subs) - } - Expr::Underscore => rhs_ty.clone(), - _ => { - // `lhs` is a place expression, a unit struct, or an enum variant. - // LHS of assignment doesn't constitute reads. - let lhs_ty = self.infer_expr_inner(lhs, &Expectation::none(), ExprIsRead::No); - - // This is the only branch where this function may coerce any type. - // We are returning early to avoid the unifiability check below. - let lhs_ty = self.insert_type_vars_shallow(lhs_ty); - let ty = match self.coerce(None, &rhs_ty, &lhs_ty, CoerceNever::Yes) { - Ok(ty) => ty, - Err(_) => { - self.result.type_mismatches.insert( - lhs.into(), - TypeMismatch { expected: rhs_ty.clone(), actual: lhs_ty.clone() }, - ); - // `rhs_ty` is returned so no further type mismatches are - // reported because of this mismatch. - rhs_ty - } - }; - self.write_expr_ty(lhs, ty.clone()); - return ty; - } - }; - - let ty = self.insert_type_vars_shallow(ty); - if !self.unify(&ty, &rhs_ty) { - self.result - .type_mismatches - .insert(lhs.into(), TypeMismatch { expected: rhs_ty.clone(), actual: ty.clone() }); - } - self.write_expr_ty(lhs, ty.clone()); - ty - } - fn infer_overloadable_binop( &mut self, lhs: ExprId, diff --git a/crates/hir-ty/src/infer/mutability.rs b/crates/hir-ty/src/infer/mutability.rs index 6a0daee6ea9f..9fef582d85de 100644 --- a/crates/hir-ty/src/infer/mutability.rs +++ b/crates/hir-ty/src/infer/mutability.rs @@ -4,7 +4,8 @@ use chalk_ir::{cast::Cast, Mutability}; use hir_def::{ hir::{ - Array, AsmOperand, BinaryOp, BindingAnnotation, Expr, ExprId, PatId, Statement, UnaryOp, + Array, AsmOperand, BinaryOp, BindingAnnotation, Expr, ExprId, Pat, PatId, Statement, + UnaryOp, }, lang_item::LangItem, }; @@ -96,7 +97,7 @@ impl InferenceContext<'_> { } } Expr::MethodCall { receiver: it, method_name: _, args, generic_args: _ } - | Expr::Call { callee: it, args, is_assignee_expr: _ } => { + | Expr::Call { callee: it, args } => { self.infer_mut_not_expr_iter(args.iter().copied().chain(Some(*it))); } Expr::Match { expr, arms } => { @@ -120,10 +121,10 @@ impl InferenceContext<'_> { Expr::Become { expr } => { self.infer_mut_expr(*expr, Mutability::Not); } - Expr::RecordLit { path: _, fields, spread, ellipsis: _, is_assignee_expr: _ } => { + Expr::RecordLit { path: _, fields, spread } => { self.infer_mut_not_expr_iter(fields.iter().map(|it| it.expr).chain(*spread)) } - &Expr::Index { base, index, is_assignee_expr } => { + &Expr::Index { base, index } => { if mutability == Mutability::Mut { if let Some((f, _)) = self.result.method_resolutions.get_mut(&tgt_expr) { if let Some(index_trait) = self @@ -148,11 +149,8 @@ impl InferenceContext<'_> { target, }) = base_adjustments { - // For assignee exprs `IndexMut` obligations are already applied - if !is_assignee_expr { - if let TyKind::Ref(_, _, ty) = target.kind(Interner) { - base_ty = Some(ty.clone()); - } + if let TyKind::Ref(_, _, ty) = target.kind(Interner) { + base_ty = Some(ty.clone()); } *mutability = Mutability::Mut; } @@ -233,6 +231,14 @@ impl InferenceContext<'_> { self.infer_mut_expr(*lhs, Mutability::Mut); self.infer_mut_expr(*rhs, Mutability::Not); } + &Expr::Assignment { target, value } => { + self.body.walk_pats(target, &mut |pat| match self.body[pat] { + Pat::Expr(expr) => self.infer_mut_expr(expr, Mutability::Mut), + Pat::ConstBlock(block) => self.infer_mut_expr(block, Mutability::Not), + _ => {} + }); + self.infer_mut_expr(value, Mutability::Not); + } Expr::Array(Array::Repeat { initializer: lhs, repeat: rhs }) | Expr::BinaryOp { lhs, rhs, op: _ } | Expr::Range { lhs: Some(lhs), rhs: Some(rhs), range_type: _ } => { @@ -242,8 +248,7 @@ impl InferenceContext<'_> { Expr::Closure { body, .. } => { self.infer_mut_expr(*body, Mutability::Not); } - Expr::Tuple { exprs, is_assignee_expr: _ } - | Expr::Array(Array::ElementList { elements: exprs, is_assignee_expr: _ }) => { + Expr::Tuple { exprs } | Expr::Array(Array::ElementList { elements: exprs }) => { self.infer_mut_not_expr_iter(exprs.iter().copied()); } // These don't need any action, as they don't have sub expressions diff --git a/crates/hir-ty/src/infer/pat.rs b/crates/hir-ty/src/infer/pat.rs index fee6755408ea..50e761196ec1 100644 --- a/crates/hir-ty/src/infer/pat.rs +++ b/crates/hir-ty/src/infer/pat.rs @@ -4,7 +4,7 @@ use std::iter::repeat_with; use hir_def::{ body::Body, - hir::{Binding, BindingAnnotation, BindingId, Expr, ExprId, ExprOrPatId, Literal, Pat, PatId}, + hir::{Binding, BindingAnnotation, BindingId, Expr, ExprId, Literal, Pat, PatId}, path::Path, }; use hir_expand::name::Name; @@ -12,63 +12,28 @@ use stdx::TupleExt; use crate::{ consteval::{try_const_usize, usize_const}, - infer::{expr::ExprIsRead, BindingMode, Expectation, InferenceContext, TypeMismatch}, + infer::{ + coerce::CoerceNever, expr::ExprIsRead, BindingMode, Expectation, InferenceContext, + TypeMismatch, + }, lower::lower_to_chalk_mutability, primitive::UintTy, static_lifetime, InferenceDiagnostic, Interner, Mutability, Scalar, Substitution, Ty, TyBuilder, TyExt, TyKind, }; -/// Used to generalize patterns and assignee expressions. -pub(super) trait PatLike: Into + Copy { - type BindingMode: Copy; - - fn infer( - this: &mut InferenceContext<'_>, - id: Self, - expected_ty: &Ty, - default_bm: Self::BindingMode, - ) -> Ty; -} - -impl PatLike for ExprId { - type BindingMode = (); - - fn infer( - this: &mut InferenceContext<'_>, - id: Self, - expected_ty: &Ty, - (): Self::BindingMode, - ) -> Ty { - this.infer_assignee_expr(id, expected_ty) - } -} - -impl PatLike for PatId { - type BindingMode = BindingMode; - - fn infer( - this: &mut InferenceContext<'_>, - id: Self, - expected_ty: &Ty, - default_bm: Self::BindingMode, - ) -> Ty { - this.infer_pat(id, expected_ty, default_bm) - } -} - impl InferenceContext<'_> { /// Infers type for tuple struct pattern or its corresponding assignee expression. /// /// Ellipses found in the original pattern or expression must be filtered out. - pub(super) fn infer_tuple_struct_pat_like( + pub(super) fn infer_tuple_struct_pat_like( &mut self, path: Option<&Path>, expected: &Ty, - default_bm: T::BindingMode, - id: T, + default_bm: BindingMode, + id: PatId, ellipsis: Option, - subs: &[T], + subs: &[PatId], ) -> Ty { let (ty, def) = self.resolve_variant(path, true); let var_data = def.map(|it| it.variant_data(self.db.upcast())); @@ -127,13 +92,13 @@ impl InferenceContext<'_> { } }; - T::infer(self, subpat, &expected_ty, default_bm); + self.infer_pat(subpat, &expected_ty, default_bm); } } None => { let err_ty = self.err_ty(); for &inner in subs { - T::infer(self, inner, &err_ty, default_bm); + self.infer_pat(inner, &err_ty, default_bm); } } } @@ -142,13 +107,13 @@ impl InferenceContext<'_> { } /// Infers type for record pattern or its corresponding assignee expression. - pub(super) fn infer_record_pat_like( + pub(super) fn infer_record_pat_like( &mut self, path: Option<&Path>, expected: &Ty, - default_bm: T::BindingMode, - id: T, - subs: impl ExactSizeIterator, + default_bm: BindingMode, + id: PatId, + subs: impl ExactSizeIterator, ) -> Ty { let (ty, def) = self.resolve_variant(path, false); if let Some(variant) = def { @@ -197,13 +162,13 @@ impl InferenceContext<'_> { } }; - T::infer(self, inner, &expected_ty, default_bm); + self.infer_pat(inner, &expected_ty, default_bm); } } None => { let err_ty = self.err_ty(); for (_, inner) in subs { - T::infer(self, inner, &err_ty, default_bm); + self.infer_pat(inner, &err_ty, default_bm); } } } @@ -214,12 +179,12 @@ impl InferenceContext<'_> { /// Infers type for tuple pattern or its corresponding assignee expression. /// /// Ellipses found in the original pattern or expression must be filtered out. - pub(super) fn infer_tuple_pat_like( + pub(super) fn infer_tuple_pat_like( &mut self, expected: &Ty, - default_bm: T::BindingMode, + default_bm: BindingMode, ellipsis: Option, - subs: &[T], + subs: &[PatId], ) -> Ty { let expected = self.resolve_ty_shallow(expected); let expectations = match expected.as_tuple() { @@ -244,18 +209,20 @@ impl InferenceContext<'_> { // Process pre for (ty, pat) in inner_tys.iter_mut().zip(pre) { - *ty = T::infer(self, *pat, ty, default_bm); + *ty = self.infer_pat(*pat, ty, default_bm); } // Process post for (ty, pat) in inner_tys.iter_mut().skip(pre.len() + n_uncovered_patterns).zip(post) { - *ty = T::infer(self, *pat, ty, default_bm); + *ty = self.infer_pat(*pat, ty, default_bm); } TyKind::Tuple(inner_tys.len(), Substitution::from_iter(Interner, inner_tys)) .intern(Interner) } + /// The resolver needs to be updated to the surrounding expression when inside assignment + /// (because there, `Pat::Path` can refer to a variable). pub(super) fn infer_top_pat(&mut self, pat: PatId, expected: &Ty) { self.infer_pat(pat, expected, BindingMode::default()); } @@ -263,7 +230,14 @@ impl InferenceContext<'_> { fn infer_pat(&mut self, pat: PatId, expected: &Ty, mut default_bm: BindingMode) -> Ty { let mut expected = self.resolve_ty_shallow(expected); - if self.is_non_ref_pat(self.body, pat) { + if matches!(&self.body[pat], Pat::Ref { .. }) || self.inside_assignment { + cov_mark::hit!(match_ergonomics_ref); + // When you encounter a `&pat` pattern, reset to Move. + // This is so that `w` is by value: `let (_, &w) = &(1, &2);` + // Destructuring assignments also reset the binding mode and + // don't do match ergonomics. + default_bm = BindingMode::Move; + } else if self.is_non_ref_pat(self.body, pat) { let mut pat_adjustments = Vec::new(); while let Some((inner, _lifetime, mutability)) = expected.as_reference() { pat_adjustments.push(expected.clone()); @@ -279,11 +253,6 @@ impl InferenceContext<'_> { pat_adjustments.shrink_to_fit(); self.result.pat_adjustments.insert(pat, pat_adjustments); } - } else if let Pat::Ref { .. } = &self.body[pat] { - cov_mark::hit!(match_ergonomics_ref); - // When you encounter a `&pat` pattern, reset to Move. - // This is so that `w` is by value: `let (_, &w) = &(1, &2);` - default_bm = BindingMode::Move; } // Lose mutability. @@ -320,8 +289,34 @@ impl InferenceContext<'_> { self.infer_record_pat_like(p.as_deref(), &expected, default_bm, pat, subs) } Pat::Path(path) => { - // FIXME update resolver for the surrounding expression - self.infer_path(path, pat.into()).unwrap_or_else(|| self.err_ty()) + let ty = self.infer_path(path, pat.into()).unwrap_or_else(|| self.err_ty()); + let ty_inserted_vars = self.insert_type_vars_shallow(ty.clone()); + match self.table.coerce(&expected, &ty_inserted_vars, CoerceNever::Yes) { + Ok((adjustments, coerced_ty)) => { + if !adjustments.is_empty() { + self.result + .pat_adjustments + .entry(pat) + .or_default() + .extend(adjustments.into_iter().map(|adjust| adjust.target)); + } + self.write_pat_ty(pat, coerced_ty); + return self.pat_ty_after_adjustment(pat); + } + Err(_) => { + self.result.type_mismatches.insert( + pat.into(), + TypeMismatch { + expected: expected.clone(), + actual: ty_inserted_vars.clone(), + }, + ); + self.write_pat_ty(pat, ty); + // We return `expected` to prevent cascading errors. I guess an alternative is to + // not emit type mismatches for error types and emit an error type here. + return expected; + } + } } Pat::Bind { id, subpat } => { return self.infer_bind_pat(pat, *id, default_bm, *subpat, &expected); @@ -361,7 +356,40 @@ impl InferenceContext<'_> { None => self.err_ty(), }, Pat::ConstBlock(expr) => { - self.infer_expr(*expr, &Expectation::has_type(expected.clone()), ExprIsRead::Yes) + let old_inside_assign = std::mem::replace(&mut self.inside_assignment, false); + let result = self.infer_expr( + *expr, + &Expectation::has_type(expected.clone()), + ExprIsRead::Yes, + ); + self.inside_assignment = old_inside_assign; + result + } + Pat::Expr(expr) => { + let old_inside_assign = std::mem::replace(&mut self.inside_assignment, false); + // LHS of assignment doesn't constitute reads. + let result = self.infer_expr_coerce( + *expr, + &Expectation::has_type(expected.clone()), + ExprIsRead::No, + ); + // We are returning early to avoid the unifiability check below. + let lhs_ty = self.insert_type_vars_shallow(result); + let ty = match self.coerce(None, &expected, &lhs_ty, CoerceNever::Yes) { + Ok(ty) => ty, + Err(_) => { + self.result.type_mismatches.insert( + pat.into(), + TypeMismatch { expected: expected.clone(), actual: lhs_ty.clone() }, + ); + // `rhs_ty` is returned so no further type mismatches are + // reported because of this mismatch. + expected + } + }; + self.write_pat_ty(pat, ty.clone()); + self.inside_assignment = old_inside_assign; + return ty; } Pat::Missing => self.err_ty(), }; @@ -517,9 +545,12 @@ impl InferenceContext<'_> { body[*expr], Expr::Literal(Literal::String(..) | Literal::CString(..) | Literal::ByteString(..)) ), - Pat::Wild | Pat::Bind { .. } | Pat::Ref { .. } | Pat::Box { .. } | Pat::Missing => { - false - } + Pat::Wild + | Pat::Bind { .. } + | Pat::Ref { .. } + | Pat::Box { .. } + | Pat::Missing + | Pat::Expr(_) => false, } } } diff --git a/crates/hir-ty/src/mir/lower.rs b/crates/hir-ty/src/mir/lower.rs index 16994cdd0c65..a8a927c2c5c7 100644 --- a/crates/hir-ty/src/mir/lower.rs +++ b/crates/hir-ty/src/mir/lower.rs @@ -13,7 +13,7 @@ use hir_def::{ }, lang_item::{LangItem, LangItemTarget}, path::Path, - resolver::{resolver_for_expr, HasResolver, ResolveValueResult, ValueNs}, + resolver::{HasResolver, ResolveValueResult, Resolver, ValueNs}, AdtId, DefWithBodyId, EnumVariantId, GeneralConstId, HasModule, ItemContainerId, LocalFieldId, Lookup, TraitId, TupleId, TypeOrConstParamId, }; @@ -76,6 +76,7 @@ struct MirLowerCtx<'a> { db: &'a dyn HirDatabase, body: &'a Body, infer: &'a InferenceResult, + resolver: Resolver, drop_scopes: Vec, } @@ -278,6 +279,7 @@ impl<'ctx> MirLowerCtx<'ctx> { owner, closures: vec![], }; + let resolver = owner.resolver(db.upcast()); MirLowerCtx { result: mir, @@ -285,6 +287,7 @@ impl<'ctx> MirLowerCtx<'ctx> { infer, body, owner, + resolver, current_loop_blocks: None, labeled_loop_blocks: Default::default(), discr_temp: None, @@ -410,43 +413,48 @@ impl<'ctx> MirLowerCtx<'ctx> { Err(MirLowerError::IncompleteExpr) } Expr::Path(p) => { - let pr = - if let Some((assoc, subst)) = self.infer.assoc_resolutions_for_expr(expr_id) { - match assoc { - hir_def::AssocItemId::ConstId(c) => { - self.lower_const( - c.into(), - current, - place, - subst, - expr_id.into(), - self.expr_ty_without_adjust(expr_id), - )?; - return Ok(Some(current)); - } - hir_def::AssocItemId::FunctionId(_) => { - // FnDefs are zero sized, no action is needed. - return Ok(Some(current)); - } - hir_def::AssocItemId::TypeAliasId(_) => { - // FIXME: If it is unreachable, use proper error instead of `not_supported`. - not_supported!("associated functions and types") - } + let pr = if let Some((assoc, subst)) = + self.infer.assoc_resolutions_for_expr(expr_id) + { + match assoc { + hir_def::AssocItemId::ConstId(c) => { + self.lower_const( + c.into(), + current, + place, + subst, + expr_id.into(), + self.expr_ty_without_adjust(expr_id), + )?; + return Ok(Some(current)); } - } else if let Some(variant) = self.infer.variant_resolution_for_expr(expr_id) { - match variant { - VariantId::EnumVariantId(e) => ValueNs::EnumVariantId(e), - VariantId::StructId(s) => ValueNs::StructId(s), - VariantId::UnionId(_) => implementation_error!("Union variant as path"), + hir_def::AssocItemId::FunctionId(_) => { + // FnDefs are zero sized, no action is needed. + return Ok(Some(current)); } - } else { - let unresolved_name = - || MirLowerError::unresolved_path(self.db, p, self.edition()); - let resolver = resolver_for_expr(self.db.upcast(), self.owner, expr_id); - resolver - .resolve_path_in_value_ns_fully(self.db.upcast(), p) - .ok_or_else(unresolved_name)? - }; + hir_def::AssocItemId::TypeAliasId(_) => { + // FIXME: If it is unreachable, use proper error instead of `not_supported`. + not_supported!("associated functions and types") + } + } + } else if let Some(variant) = self.infer.variant_resolution_for_expr(expr_id) { + match variant { + VariantId::EnumVariantId(e) => ValueNs::EnumVariantId(e), + VariantId::StructId(s) => ValueNs::StructId(s), + VariantId::UnionId(_) => implementation_error!("Union variant as path"), + } + } else { + let resolver_guard = + self.resolver.update_to_inner_scope(self.db.upcast(), self.owner, expr_id); + let result = self + .resolver + .resolve_path_in_value_ns_fully(self.db.upcast(), p) + .ok_or_else(|| { + MirLowerError::unresolved_path(self.db, p, self.edition()) + })?; + self.resolver.reset_to_guard(resolver_guard); + result + }; match pr { ValueNs::LocalBinding(_) | ValueNs::StaticId(_) => { let Some((temp, current)) = @@ -553,8 +561,11 @@ impl<'ctx> MirLowerCtx<'ctx> { return Ok(None); }; self.push_fake_read(current, cond_place, expr_id.into()); + let resolver_guard = + self.resolver.update_to_inner_scope(self.db.upcast(), self.owner, expr_id); let (then_target, else_target) = self.pattern_match(current, None, cond_place, *pat)?; + self.resolver.reset_to_guard(resolver_guard); self.write_bytes_to_place( then_target, place, @@ -688,6 +699,8 @@ impl<'ctx> MirLowerCtx<'ctx> { }; self.push_fake_read(current, cond_place, expr_id.into()); let mut end = None; + let resolver_guard = + self.resolver.update_to_inner_scope(self.db.upcast(), self.owner, expr_id); for MatchArm { pat, guard, expr } in arms.iter() { let (then, mut otherwise) = self.pattern_match(current, None, cond_place, *pat)?; @@ -721,6 +734,7 @@ impl<'ctx> MirLowerCtx<'ctx> { } } } + self.resolver.reset_to_guard(resolver_guard); if self.is_unterminated(current) { self.set_terminator(current, TerminatorKind::Unreachable, expr_id.into()); } @@ -795,7 +809,7 @@ impl<'ctx> MirLowerCtx<'ctx> { } Expr::Become { .. } => not_supported!("tail-calls"), Expr::Yield { .. } => not_supported!("yield"), - Expr::RecordLit { fields, path, spread, ellipsis: _, is_assignee_expr: _ } => { + Expr::RecordLit { fields, path, spread } => { let spread_place = match spread { &Some(it) => { let Some((p, c)) = self.lower_expr_as_place(current, it, true)? else { @@ -1010,35 +1024,28 @@ impl<'ctx> MirLowerCtx<'ctx> { ); } } - if let hir_def::hir::BinaryOp::Assignment { op } = op { - if let Some(op) = op { - // last adjustment is `&mut` which we don't want it. - let adjusts = self - .infer - .expr_adjustments - .get(lhs) - .and_then(|it| it.split_last()) - .map(|it| it.1) - .ok_or(MirLowerError::TypeError( - "adjustment of binary op was missing", - ))?; - let Some((lhs_place, current)) = - self.lower_expr_as_place_with_adjust(current, *lhs, false, adjusts)? - else { - return Ok(None); - }; - let Some((rhs_op, current)) = - self.lower_expr_to_some_operand(*rhs, current)? - else { - return Ok(None); - }; - let r_value = - Rvalue::CheckedBinaryOp(op.into(), Operand::Copy(lhs_place), rhs_op); - self.push_assignment(current, lhs_place, r_value, expr_id.into()); - return Ok(Some(current)); - } else { - return self.lower_assignment(current, *lhs, *rhs, expr_id.into()); - } + if let hir_def::hir::BinaryOp::Assignment { op: Some(op) } = op { + // last adjustment is `&mut` which we don't want it. + let adjusts = self + .infer + .expr_adjustments + .get(lhs) + .and_then(|it| it.split_last()) + .map(|it| it.1) + .ok_or(MirLowerError::TypeError("adjustment of binary op was missing"))?; + let Some((lhs_place, current)) = + self.lower_expr_as_place_with_adjust(current, *lhs, false, adjusts)? + else { + return Ok(None); + }; + let Some((rhs_op, current)) = self.lower_expr_to_some_operand(*rhs, current)? + else { + return Ok(None); + }; + let r_value = + Rvalue::CheckedBinaryOp(op.into(), Operand::Copy(lhs_place), rhs_op); + self.push_assignment(current, lhs_place, r_value, expr_id.into()); + return Ok(Some(current)); } let Some((lhs_op, current)) = self.lower_expr_to_some_operand(*lhs, current)? else { @@ -1097,6 +1104,18 @@ impl<'ctx> MirLowerCtx<'ctx> { ); Ok(Some(current)) } + &Expr::Assignment { target, value } => { + let Some((value, mut current)) = self.lower_expr_as_place(current, value, true)? + else { + return Ok(None); + }; + self.push_fake_read(current, value, expr_id.into()); + let resolver_guard = + self.resolver.update_to_inner_scope(self.db.upcast(), self.owner, expr_id); + current = self.pattern_match_assignment(current, value, target)?; + self.resolver.reset_to_guard(resolver_guard); + Ok(Some(current)) + } &Expr::Range { lhs, rhs, range_type: _ } => { let ty = self.expr_ty_without_adjust(expr_id); let Some((adt, subst)) = ty.as_adt() else { @@ -1213,7 +1232,7 @@ impl<'ctx> MirLowerCtx<'ctx> { ); Ok(Some(current)) } - Expr::Tuple { exprs, is_assignee_expr: _ } => { + Expr::Tuple { exprs } => { let Some(values) = exprs .iter() .map(|it| { @@ -1291,73 +1310,6 @@ impl<'ctx> MirLowerCtx<'ctx> { } } - fn lower_destructing_assignment( - &mut self, - mut current: BasicBlockId, - lhs: ExprId, - rhs: Place, - span: MirSpan, - ) -> Result> { - match &self.body.exprs[lhs] { - Expr::Tuple { exprs, is_assignee_expr: _ } => { - for (i, expr) in exprs.iter().enumerate() { - let rhs = rhs.project( - ProjectionElem::Field(Either::Right(TupleFieldId { - tuple: TupleId(!0), // Dummy this as its unused - index: i as u32, - })), - &mut self.result.projection_store, - ); - let Some(c) = self.lower_destructing_assignment(current, *expr, rhs, span)? - else { - return Ok(None); - }; - current = c; - } - Ok(Some(current)) - } - Expr::Underscore => Ok(Some(current)), - _ => { - let Some((lhs_place, current)) = self.lower_expr_as_place(current, lhs, false)? - else { - return Ok(None); - }; - self.push_assignment(current, lhs_place, Operand::Copy(rhs).into(), span); - Ok(Some(current)) - } - } - } - - fn lower_assignment( - &mut self, - current: BasicBlockId, - lhs: ExprId, - rhs: ExprId, - span: MirSpan, - ) -> Result> { - let Some((rhs_op, current)) = self.lower_expr_to_some_operand(rhs, current)? else { - return Ok(None); - }; - if matches!(&self.body.exprs[lhs], Expr::Underscore) { - self.push_fake_read_for_operand(current, rhs_op, span); - return Ok(Some(current)); - } - if matches!( - &self.body.exprs[lhs], - Expr::Tuple { .. } | Expr::RecordLit { .. } | Expr::Call { .. } - ) { - let temp = self.temp(self.expr_ty_after_adjustments(rhs), current, rhs.into())?; - let temp = Place::from(temp); - self.push_assignment(current, temp, rhs_op.into(), span); - return self.lower_destructing_assignment(current, lhs, temp, span); - } - let Some((lhs_place, current)) = self.lower_expr_as_place(current, lhs, false)? else { - return Ok(None); - }; - self.push_assignment(current, lhs_place, rhs_op.into(), span); - Ok(Some(current)) - } - fn placeholder_subst(&mut self) -> Substitution { match self.owner.as_generic_def_id(self.db.upcast()) { Some(it) => TyBuilder::placeholder_subst(self.db, it), @@ -1407,8 +1359,8 @@ impl<'ctx> MirLowerCtx<'ctx> { let edition = self.edition(); let unresolved_name = || MirLowerError::unresolved_path(self.db, c.as_ref(), edition); - let resolver = self.owner.resolver(self.db.upcast()); - let pr = resolver + let pr = self + .resolver .resolve_path_in_value_ns(self.db.upcast(), c.as_ref()) .ok_or_else(unresolved_name)?; match pr { @@ -1632,12 +1584,6 @@ impl<'ctx> MirLowerCtx<'ctx> { self.push_statement(block, StatementKind::FakeRead(p).with_span(span)); } - fn push_fake_read_for_operand(&mut self, block: BasicBlockId, operand: Operand, span: MirSpan) { - if let Operand::Move(p) | Operand::Copy(p) = operand { - self.push_fake_read(block, p, span); - } - } - fn push_assignment( &mut self, block: BasicBlockId, @@ -1791,8 +1737,16 @@ impl<'ctx> MirLowerCtx<'ctx> { }; current = c; self.push_fake_read(current, init_place, span); + // Using the initializer for the resolver scope is good enough for us, as it cannot create new declarations + // and has all declarations of the `let`. + let resolver_guard = self.resolver.update_to_inner_scope( + self.db.upcast(), + self.owner, + *expr_id, + ); (current, else_block) = self.pattern_match(current, None, init_place, *pat)?; + self.resolver.reset_to_guard(resolver_guard); match (else_block, else_branch) { (None, _) => (), (Some(else_block), None) => { @@ -2066,11 +2020,13 @@ pub fn mir_body_for_closure_query( let Some(sig) = ClosureSubst(substs).sig_ty().callable_sig(db) else { implementation_error!("closure has not callable sig"); }; + let resolver_guard = ctx.resolver.update_to_inner_scope(db.upcast(), owner, expr); let current = ctx.lower_params_and_bindings( args.iter().zip(sig.params().iter()).map(|(it, y)| (*it, y.clone())), None, |_| true, )?; + ctx.resolver.reset_to_guard(resolver_guard); if let Some(current) = ctx.lower_expr_to_place(*root, return_slot().into(), current)? { let current = ctx.pop_drop_scope_assert_finished(current, root.into())?; ctx.set_terminator(current, TerminatorKind::Return, (*root).into()); diff --git a/crates/hir-ty/src/mir/lower/as_place.rs b/crates/hir-ty/src/mir/lower/as_place.rs index 424ee1160c82..91086e805798 100644 --- a/crates/hir-ty/src/mir/lower/as_place.rs +++ b/crates/hir-ty/src/mir/lower/as_place.rs @@ -135,8 +135,11 @@ impl MirLowerCtx<'_> { }; match &self.body.exprs[expr_id] { Expr::Path(p) => { - let resolver = resolver_for_expr(self.db.upcast(), self.owner, expr_id); - let Some(pr) = resolver.resolve_path_in_value_ns_fully(self.db.upcast(), p) else { + let resolver_guard = + self.resolver.update_to_inner_scope(self.db.upcast(), self.owner, expr_id); + let resolved = self.resolver.resolve_path_in_value_ns_fully(self.db.upcast(), p); + self.resolver.reset_to_guard(resolver_guard); + let Some(pr) = resolved else { return try_rvalue(self); }; match pr { @@ -216,7 +219,7 @@ impl MirLowerCtx<'_> { self.push_field_projection(&mut r, expr_id)?; Ok(Some((r, current))) } - Expr::Index { base, index, is_assignee_expr: _ } => { + Expr::Index { base, index } => { let base_ty = self.expr_ty_after_adjustments(*base); let index_ty = self.expr_ty_after_adjustments(*index); if index_ty != TyBuilder::usize() diff --git a/crates/hir-ty/src/mir/lower/pattern_matching.rs b/crates/hir-ty/src/mir/lower/pattern_matching.rs index b1c0d1f2b390..63bb3367771f 100644 --- a/crates/hir-ty/src/mir/lower/pattern_matching.rs +++ b/crates/hir-ty/src/mir/lower/pattern_matching.rs @@ -1,6 +1,6 @@ //! MIR lowering for patterns -use hir_def::{hir::LiteralOrConst, resolver::HasResolver, AssocItemId}; +use hir_def::{hir::LiteralOrConst, AssocItemId}; use crate::{ mir::{ @@ -46,6 +46,8 @@ enum MatchingMode { Check, /// Assume that this pattern matches, fill bindings Bind, + /// Assume that this pattern matches, assign to existing variables. + Assign, } impl MirLowerCtx<'_> { @@ -82,6 +84,17 @@ impl MirLowerCtx<'_> { Ok((current, current_else)) } + pub(super) fn pattern_match_assignment( + &mut self, + current: BasicBlockId, + value: Place, + pattern: PatId, + ) -> Result { + let (current, _) = + self.pattern_match_inner(current, None, value, pattern, MatchingMode::Assign)?; + Ok(current) + } + pub(super) fn match_self_param( &mut self, id: BindingId, @@ -155,14 +168,8 @@ impl MirLowerCtx<'_> { *pat, MatchingMode::Check, )?; - if mode == MatchingMode::Bind { - (next, _) = self.pattern_match_inner( - next, - None, - cond_place, - *pat, - MatchingMode::Bind, - )?; + if mode != MatchingMode::Check { + (next, _) = self.pattern_match_inner(next, None, cond_place, *pat, mode)?; } self.set_goto(next, then_target, pattern.into()); match next_else { @@ -176,11 +183,11 @@ impl MirLowerCtx<'_> { } } if !finished { - if mode == MatchingMode::Bind { - self.set_terminator(current, TerminatorKind::Unreachable, pattern.into()); - } else { + if mode == MatchingMode::Check { let ce = *current_else.get_or_insert_with(|| self.new_basic_block()); self.set_goto(current, ce, pattern.into()); + } else { + self.set_terminator(current, TerminatorKind::Unreachable, pattern.into()); } } (then_target, current_else) @@ -300,7 +307,7 @@ impl MirLowerCtx<'_> { self.pattern_match_inner(current, current_else, next_place, pat, mode)?; } if let &Some(slice) = slice { - if mode == MatchingMode::Bind { + if mode != MatchingMode::Check { if let Pat::Bind { id, subpat: _ } = self.body[slice] { let next_place = cond_place.project( ProjectionElem::Subslice { @@ -342,17 +349,34 @@ impl MirLowerCtx<'_> { mode, )?, None => { - // The path is not a variant, so it is a const - if mode != MatchingMode::Check { - // A const don't bind anything. Only needs check. - return Ok((current, current_else)); - } let unresolved_name = || MirLowerError::unresolved_path(self.db, p, self.edition()); - let resolver = self.owner.resolver(self.db.upcast()); - let pr = resolver + let pr = self + .resolver .resolve_path_in_value_ns(self.db.upcast(), p) .ok_or_else(unresolved_name)?; + + if let ( + MatchingMode::Assign, + ResolveValueResult::ValueNs(ValueNs::LocalBinding(binding), _), + ) = (mode, &pr) + { + let local = self.binding_local(*binding)?; + self.push_match_assignment( + current, + local, + BindingMode::Move, + cond_place, + pattern.into(), + ); + return Ok((current, current_else)); + } + + // The path is not a variant or a local, so it is a const + if mode != MatchingMode::Check { + // A const don't bind anything. Only needs check. + return Ok((current, current_else)); + } let (c, subst) = 'b: { if let Some(x) = self.infer.assoc_resolutions_for_pat(pattern) { if let AssocItemId::ConstId(c) = x.0 { @@ -415,7 +439,7 @@ impl MirLowerCtx<'_> { (current, current_else) = self.pattern_match_inner(current, current_else, cond_place, *subpat, mode)? } - if mode == MatchingMode::Bind { + if mode != MatchingMode::Check { let mode = self.infer.binding_modes[pattern]; self.pattern_match_binding( *id, @@ -448,6 +472,23 @@ impl MirLowerCtx<'_> { cond_place.project(ProjectionElem::Deref, &mut self.result.projection_store); self.pattern_match_inner(current, current_else, cond_place, *pat, mode)? } + &Pat::Expr(expr) => { + stdx::always!( + mode == MatchingMode::Assign, + "Pat::Expr can only come in destructuring assignments" + ); + let Some((lhs_place, current)) = self.lower_expr_as_place(current, expr, false)? + else { + return Ok((current, current_else)); + }; + self.push_assignment( + current, + lhs_place, + Operand::Copy(cond_place).into(), + expr.into(), + ); + (current, current_else) + } Pat::Box { .. } => not_supported!("box pattern"), Pat::ConstBlock(_) => not_supported!("const block pattern"), }) @@ -464,6 +505,18 @@ impl MirLowerCtx<'_> { ) -> Result<(BasicBlockId, Option)> { let target_place = self.binding_local(id)?; self.push_storage_live(id, current)?; + self.push_match_assignment(current, target_place, mode, cond_place, span); + Ok((current, current_else)) + } + + fn push_match_assignment( + &mut self, + current: BasicBlockId, + target_place: LocalId, + mode: BindingMode, + cond_place: Place, + span: MirSpan, + ) { self.push_assignment( current, target_place.into(), @@ -476,7 +529,6 @@ impl MirLowerCtx<'_> { }, span, ); - Ok((current, current_else)) } fn pattern_match_const( diff --git a/crates/hir-ty/src/tests/simple.rs b/crates/hir-ty/src/tests/simple.rs index a8170b606060..e6ef56b80d93 100644 --- a/crates/hir-ty/src/tests/simple.rs +++ b/crates/hir-ty/src/tests/simple.rs @@ -3418,11 +3418,11 @@ struct TS(usize); fn main() { let x; [x,] = &[1,]; - //^^^^expected &'? [i32; 1], got [{unknown}; _] + //^^^^expected &'? [i32; 1], got [{unknown}] let x; [(x,),] = &[(1,),]; - //^^^^^^^expected &'? [(i32,); 1], got [{unknown}; _] + //^^^^^^^expected &'? [(i32,); 1], got [{unknown}] let x; ((x,),) = &((1,),); diff --git a/crates/hir/src/diagnostics.rs b/crates/hir/src/diagnostics.rs index 0b3cdb2f3790..2fedffe047db 100644 --- a/crates/hir/src/diagnostics.rs +++ b/crates/hir/src/diagnostics.rs @@ -246,7 +246,7 @@ pub struct UnresolvedAssocItem { #[derive(Debug)] pub struct UnresolvedIdent { - pub expr: InFile>, + pub expr_or_pat: InFile>>, } #[derive(Debug)] @@ -257,7 +257,7 @@ pub struct PrivateField { #[derive(Debug)] pub struct MissingUnsafe { - pub expr: InFile>, + pub expr: InFile>>, /// If true, the diagnostics is an `unsafe_op_in_unsafe_fn` lint instead of a hard error. pub only_lint: bool, } @@ -398,56 +398,46 @@ impl AnyDiagnostic { .map(|idx| variant_data.fields()[idx].name.clone()) .collect(); - match record { - Either::Left(record_expr) => match source_map.expr_syntax(record_expr) { - Ok(source_ptr) => { - let root = source_ptr.file_syntax(db.upcast()); - if let ast::Expr::RecordExpr(record_expr) = - source_ptr.value.to_node(&root) - { - if record_expr.record_expr_field_list().is_some() { - let field_list_parent_path = - record_expr.path().map(|path| AstPtr::new(&path)); - return Some( - MissingFields { - file: source_ptr.file_id, - field_list_parent: AstPtr::new(&Either::Left( - record_expr, - )), - field_list_parent_path, - missed_fields, - } - .into(), - ); + let record = match record { + Either::Left(record_expr) => { + source_map.expr_syntax(record_expr).ok()?.map(AstPtr::wrap_left) + } + Either::Right(record_pat) => source_map.pat_syntax(record_pat).ok()?, + }; + let file = record.file_id; + let root = record.file_syntax(db.upcast()); + match record.value.to_node(&root) { + Either::Left(ast::Expr::RecordExpr(record_expr)) => { + if record_expr.record_expr_field_list().is_some() { + let field_list_parent_path = + record_expr.path().map(|path| AstPtr::new(&path)); + return Some( + MissingFields { + file, + field_list_parent: AstPtr::new(&Either::Left(record_expr)), + field_list_parent_path, + missed_fields, } - } + .into(), + ); } - Err(SyntheticSyntax) => (), - }, - Either::Right(record_pat) => match source_map.pat_syntax(record_pat) { - Ok(source_ptr) => { - if let Some(ptr) = source_ptr.value.cast::() { - let root = source_ptr.file_syntax(db.upcast()); - let record_pat = ptr.to_node(&root); - if record_pat.record_pat_field_list().is_some() { - let field_list_parent_path = - record_pat.path().map(|path| AstPtr::new(&path)); - return Some( - MissingFields { - file: source_ptr.file_id, - field_list_parent: AstPtr::new(&Either::Right( - record_pat, - )), - field_list_parent_path, - missed_fields, - } - .into(), - ); + } + Either::Right(ast::Pat::RecordPat(record_pat)) => { + if record_pat.record_pat_field_list().is_some() { + let field_list_parent_path = + record_pat.path().map(|path| AstPtr::new(&path)); + return Some( + MissingFields { + file, + field_list_parent: AstPtr::new(&Either::Right(record_pat)), + field_list_parent_path, + missed_fields, } - } + .into(), + ); } - Err(SyntheticSyntax) => (), - }, + } + _ => {} } } BodyValidationDiagnostic::ReplaceFilterMapNextWithFindMap { method_call_expr } => { @@ -541,15 +531,17 @@ impl AnyDiagnostic { let pat_syntax = |pat| { source_map.pat_syntax(pat).inspect_err(|_| tracing::error!("synthetic syntax")).ok() }; + let expr_or_pat_syntax = |id| match id { + ExprOrPatId::ExprId(expr) => expr_syntax(expr).map(|it| it.map(AstPtr::wrap_left)), + ExprOrPatId::PatId(pat) => pat_syntax(pat), + }; Some(match d { &InferenceDiagnostic::NoSuchField { field: expr, private, variant } => { let expr_or_pat = match expr { ExprOrPatId::ExprId(expr) => { source_map.field_syntax(expr).map(AstPtr::wrap_left) } - ExprOrPatId::PatId(pat) => { - source_map.pat_field_syntax(pat).map(AstPtr::wrap_right) - } + ExprOrPatId::PatId(pat) => source_map.pat_field_syntax(pat), }; NoSuchField { field: expr_or_pat, private, variant }.into() } @@ -562,10 +554,7 @@ impl AnyDiagnostic { PrivateField { expr, field }.into() } &InferenceDiagnostic::PrivateAssocItem { id, item } => { - let expr_or_pat = match id { - ExprOrPatId::ExprId(expr) => expr_syntax(expr)?.map(AstPtr::wrap_left), - ExprOrPatId::PatId(pat) => pat_syntax(pat)?.map(AstPtr::wrap_right), - }; + let expr_or_pat = expr_or_pat_syntax(id)?; let item = item.into(); PrivateAssocItem { expr_or_pat, item }.into() } @@ -609,15 +598,12 @@ impl AnyDiagnostic { .into() } &InferenceDiagnostic::UnresolvedAssocItem { id } => { - let expr_or_pat = match id { - ExprOrPatId::ExprId(expr) => expr_syntax(expr)?.map(AstPtr::wrap_left), - ExprOrPatId::PatId(pat) => pat_syntax(pat)?.map(AstPtr::wrap_right), - }; + let expr_or_pat = expr_or_pat_syntax(id)?; UnresolvedAssocItem { expr_or_pat }.into() } - &InferenceDiagnostic::UnresolvedIdent { expr } => { - let expr = expr_syntax(expr)?; - UnresolvedIdent { expr }.into() + &InferenceDiagnostic::UnresolvedIdent { id } => { + let expr_or_pat = expr_or_pat_syntax(id)?; + UnresolvedIdent { expr_or_pat }.into() } &InferenceDiagnostic::BreakOutsideOfLoop { expr, is_break, bad_value_break } => { let expr = expr_syntax(expr)?; diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs index 30e023e1a472..4a18795e7d62 100644 --- a/crates/hir/src/lib.rs +++ b/crates/hir/src/lib.rs @@ -1885,7 +1885,7 @@ impl DefWithBody { let (unafe_exprs, only_lint) = hir_ty::diagnostics::missing_unsafe(db, self.into()); for expr in unafe_exprs { - match source_map.expr_syntax(expr) { + match source_map.expr_or_pat_syntax(expr) { Ok(expr) => acc.push(MissingUnsafe { expr, only_lint }.into()), Err(SyntheticSyntax) => { // FIXME: Here and elsewhere in this file, the `expr` was @@ -3481,7 +3481,7 @@ impl Local { LocalSource { local: self, source: src.map(|ast| match ast.to_node(&root) { - ast::Pat::IdentPat(it) => Either::Left(it), + Either::Right(ast::Pat::IdentPat(it)) => Either::Left(it), _ => unreachable!("local with non ident-pattern"), }), } @@ -3510,7 +3510,7 @@ impl Local { LocalSource { local: self, source: src.map(|ast| match ast.to_node(&root) { - ast::Pat::IdentPat(it) => Either::Left(it), + Either::Right(ast::Pat::IdentPat(it)) => Either::Left(it), _ => unreachable!("local with non ident-pattern"), }), } @@ -4235,10 +4235,7 @@ impl CaptureUsages { } mir::MirSpan::PatId(pat) => { if let Ok(pat) = source_map.pat_syntax(pat) { - result.push(CaptureUsageSource { - is_ref, - source: pat.map(AstPtr::wrap_right), - }); + result.push(CaptureUsageSource { is_ref, source: pat }); } } mir::MirSpan::BindingId(binding) => result.extend( @@ -4246,10 +4243,7 @@ impl CaptureUsages { .patterns_for_binding(binding) .iter() .filter_map(|&pat| source_map.pat_syntax(pat).ok()) - .map(|pat| CaptureUsageSource { - is_ref, - source: pat.map(AstPtr::wrap_right), - }), + .map(|pat| CaptureUsageSource { is_ref, source: pat }), ), mir::MirSpan::SelfParam | mir::MirSpan::Unknown => { unreachable!("invalid capture usage span") diff --git a/crates/hir/src/semantics.rs b/crates/hir/src/semantics.rs index b27f1fbb5db1..d210c2671435 100644 --- a/crates/hir/src/semantics.rs +++ b/crates/hir/src/semantics.rs @@ -11,7 +11,7 @@ use std::{ use either::Either; use hir_def::{ - hir::Expr, + hir::{Expr, ExprOrPatId}, lower::LowerCtx, nameres::{MacroSubNs, ModuleOrigin}, path::ModPath, @@ -1755,7 +1755,9 @@ impl<'db> SemanticsImpl<'db> { } if let Some(parent) = ast::Expr::cast(parent.clone()) { - if let Some(expr_id) = source_map.node_expr(InFile { file_id, value: &parent }) { + if let Some(ExprOrPatId::ExprId(expr_id)) = + source_map.node_expr(InFile { file_id, value: &parent }) + { if let Expr::Unsafe { .. } = body[expr_id] { break true; } diff --git a/crates/hir/src/semantics/source_to_def.rs b/crates/hir/src/semantics/source_to_def.rs index fd6d52d6c9df..e53a7da7edb3 100644 --- a/crates/hir/src/semantics/source_to_def.rs +++ b/crates/hir/src/semantics/source_to_def.rs @@ -306,7 +306,7 @@ impl SourceToDefCtx<'_, '_> { .position(|it| it == *src.value)?; let container = self.find_pat_or_label_container(src.syntax_ref())?; let (_, source_map) = self.db.body_with_source_map(container); - let expr = source_map.node_expr(src.with_value(&ast::Expr::AsmExpr(asm)))?; + let expr = source_map.node_expr(src.with_value(&ast::Expr::AsmExpr(asm)))?.as_expr()?; Some(InlineAsmOperand { owner: container, expr, index }) } @@ -350,7 +350,8 @@ impl SourceToDefCtx<'_, '_> { let break_or_continue = ast::Expr::cast(src.value.syntax().parent()?)?; let container = self.find_pat_or_label_container(src.syntax_ref())?; let (body, source_map) = self.db.body_with_source_map(container); - let break_or_continue = source_map.node_expr(src.with_value(&break_or_continue))?; + let break_or_continue = + source_map.node_expr(src.with_value(&break_or_continue))?.as_expr()?; let (Expr::Break { label, .. } | Expr::Continue { label }) = body[break_or_continue] else { return None; }; diff --git a/crates/hir/src/source_analyzer.rs b/crates/hir/src/source_analyzer.rs index 3da67ae23f83..f2f27517fd66 100644 --- a/crates/hir/src/source_analyzer.rs +++ b/crates/hir/src/source_analyzer.rs @@ -13,7 +13,7 @@ use hir_def::{ scope::{ExprScopes, ScopeId}, Body, BodySourceMap, }, - hir::{BindingId, ExprId, Pat, PatId}, + hir::{BindingId, ExprId, ExprOrPatId, Pat, PatId}, lang_item::LangItem, lower::LowerCtx, nameres::MacroSubNs, @@ -120,7 +120,7 @@ impl SourceAnalyzer { self.def.as_ref().map(|(_, body, _)| &**body) } - fn expr_id(&self, db: &dyn HirDatabase, expr: &ast::Expr) -> Option { + fn expr_id(&self, db: &dyn HirDatabase, expr: &ast::Expr) -> Option { let src = match expr { ast::Expr::MacroExpr(expr) => { self.expand_expr(db, InFile::new(self.file_id, expr.macro_call()?))?.into() @@ -174,7 +174,9 @@ impl SourceAnalyzer { db: &dyn HirDatabase, expr: &ast::Expr, ) -> Option<&[Adjustment]> { - let expr_id = self.expr_id(db, expr)?; + // It is safe to omit destructuring assignments here because they have no adjustments (neither + // expressions nor patterns). + let expr_id = self.expr_id(db, expr)?.as_expr()?; let infer = self.infer.as_ref()?; infer.expr_adjustments.get(&expr_id).map(|v| &**v) } @@ -186,9 +188,9 @@ impl SourceAnalyzer { ) -> Option<(Type, Option)> { let expr_id = self.expr_id(db, expr)?; let infer = self.infer.as_ref()?; - let coerced = infer - .expr_adjustments - .get(&expr_id) + let coerced = expr_id + .as_expr() + .and_then(|expr_id| infer.expr_adjustments.get(&expr_id)) .and_then(|adjusts| adjusts.last().map(|adjust| adjust.target.clone())); let ty = infer[expr_id].clone(); let mk_ty = |ty| Type::new_with_resolver(db, &self.resolver, ty); @@ -268,7 +270,7 @@ impl SourceAnalyzer { db: &dyn HirDatabase, call: &ast::MethodCallExpr, ) -> Option { - let expr_id = self.expr_id(db, &call.clone().into())?; + let expr_id = self.expr_id(db, &call.clone().into())?.as_expr()?; let (func, substs) = self.infer.as_ref()?.method_resolution(expr_id)?; let ty = db.value_ty(func.into())?.substitute(Interner, &substs); let ty = Type::new_with_resolver(db, &self.resolver, ty); @@ -282,7 +284,7 @@ impl SourceAnalyzer { db: &dyn HirDatabase, call: &ast::MethodCallExpr, ) -> Option { - let expr_id = self.expr_id(db, &call.clone().into())?; + let expr_id = self.expr_id(db, &call.clone().into())?.as_expr()?; let (f_in_trait, substs) = self.infer.as_ref()?.method_resolution(expr_id)?; Some(self.resolve_impl_method_or_trait_def(db, f_in_trait, substs).into()) @@ -293,7 +295,7 @@ impl SourceAnalyzer { db: &dyn HirDatabase, call: &ast::MethodCallExpr, ) -> Option> { - let expr_id = self.expr_id(db, &call.clone().into())?; + let expr_id = self.expr_id(db, &call.clone().into())?.as_expr()?; let inference_result = self.infer.as_ref()?; match inference_result.method_resolution(expr_id) { Some((f_in_trait, substs)) => Some(Either::Left( @@ -322,7 +324,7 @@ impl SourceAnalyzer { field: &ast::FieldExpr, ) -> Option> { let &(def, ..) = self.def.as_ref()?; - let expr_id = self.expr_id(db, &field.clone().into())?; + let expr_id = self.expr_id(db, &field.clone().into())?.as_expr()?; self.infer.as_ref()?.field_resolution(expr_id).map(|it| { it.map_either(Into::into, |f| TupleField { owner: def, tuple: f.tuple, index: f.index }) }) @@ -334,7 +336,7 @@ impl SourceAnalyzer { field: &ast::FieldExpr, ) -> Option, Function>> { let &(def, ..) = self.def.as_ref()?; - let expr_id = self.expr_id(db, &field.clone().into())?; + let expr_id = self.expr_id(db, &field.clone().into())?.as_expr()?; let inference_result = self.infer.as_ref()?; match inference_result.field_resolution(expr_id) { Some(field) => Some(Either::Left(field.map_either(Into::into, |f| TupleField { @@ -403,7 +405,7 @@ impl SourceAnalyzer { self.infer .as_ref() .and_then(|infer| { - let expr = self.expr_id(db, &prefix_expr.clone().into())?; + let expr = self.expr_id(db, &prefix_expr.clone().into())?.as_expr()?; let (func, _) = infer.method_resolution(expr)?; let (deref_mut_trait, deref_mut) = self.lang_trait_fn( db, @@ -449,7 +451,7 @@ impl SourceAnalyzer { .infer .as_ref() .and_then(|infer| { - let expr = self.expr_id(db, &index_expr.clone().into())?; + let expr = self.expr_id(db, &index_expr.clone().into())?.as_expr()?; let (func, _) = infer.method_resolution(expr)?; let (index_mut_trait, index_mut_fn) = self.lang_trait_fn( db, @@ -537,8 +539,8 @@ impl SourceAnalyzer { _ => None, } }; - let (_, subst) = self.infer.as_ref()?.type_of_expr.get(expr_id)?.as_adt()?; - let variant = self.infer.as_ref()?.variant_resolution_for_expr(expr_id)?; + let (_, subst) = self.infer.as_ref()?.type_of_expr_or_pat(expr_id)?.as_adt()?; + let variant = self.infer.as_ref()?.variant_resolution_for_expr_or_pat(expr_id)?; let variant_data = variant.variant_data(db.upcast()); let field = FieldId { parent: variant, local_id: variant_data.field(&local_name)? }; let field_ty = @@ -606,10 +608,10 @@ impl SourceAnalyzer { let infer = self.infer.as_deref()?; if let Some(path_expr) = parent().and_then(ast::PathExpr::cast) { let expr_id = self.expr_id(db, &path_expr.into())?; - if let Some((assoc, subs)) = infer.assoc_resolutions_for_expr(expr_id) { + if let Some((assoc, subs)) = infer.assoc_resolutions_for_expr_or_pat(expr_id) { let assoc = match assoc { AssocItemId::FunctionId(f_in_trait) => { - match infer.type_of_expr.get(expr_id) { + match infer.type_of_expr_or_pat(expr_id) { None => assoc, Some(func_ty) => { if let TyKind::FnDef(_fn_def, subs) = func_ty.kind(Interner) { @@ -634,7 +636,7 @@ impl SourceAnalyzer { return Some(PathResolution::Def(AssocItem::from(assoc).into())); } if let Some(VariantId::EnumVariantId(variant)) = - infer.variant_resolution_for_expr(expr_id) + infer.variant_resolution_for_expr_or_pat(expr_id) { return Some(PathResolution::Def(ModuleDef::Variant(variant.into()))); } @@ -658,7 +660,7 @@ impl SourceAnalyzer { } else if let Some(rec_lit) = parent().and_then(ast::RecordExpr::cast) { let expr_id = self.expr_id(db, &rec_lit.into())?; if let Some(VariantId::EnumVariantId(variant)) = - infer.variant_resolution_for_expr(expr_id) + infer.variant_resolution_for_expr_or_pat(expr_id) { return Some(PathResolution::Def(ModuleDef::Variant(variant.into()))); } @@ -790,10 +792,16 @@ impl SourceAnalyzer { let infer = self.infer.as_ref()?; let expr_id = self.expr_id(db, &literal.clone().into())?; - let substs = infer.type_of_expr[expr_id].as_adt()?.1; + let substs = infer[expr_id].as_adt()?.1; - let (variant, missing_fields, _exhaustive) = - record_literal_missing_fields(db, infer, expr_id, &body[expr_id])?; + let (variant, missing_fields, _exhaustive) = match expr_id { + ExprOrPatId::ExprId(expr_id) => { + record_literal_missing_fields(db, infer, expr_id, &body[expr_id])? + } + ExprOrPatId::PatId(pat_id) => { + record_pattern_missing_fields(db, infer, pat_id, &body[pat_id])? + } + }; let res = self.missing_fields(db, substs, variant, missing_fields); Some(res) } @@ -856,7 +864,7 @@ impl SourceAnalyzer { ) -> Option { let infer = self.infer.as_ref()?; let expr_id = self.expr_id(db, &record_lit.into())?; - infer.variant_resolution_for_expr(expr_id) + infer.variant_resolution_for_expr_or_pat(expr_id) } pub(crate) fn is_unsafe_macro_call_expr( @@ -867,14 +875,24 @@ impl SourceAnalyzer { if let (Some((def, body, sm)), Some(infer)) = (&self.def, &self.infer) { if let Some(expanded_expr) = sm.macro_expansion_expr(macro_expr) { let mut is_unsafe = false; - unsafe_expressions( - db, - infer, - *def, - body, - expanded_expr, - &mut |UnsafeExpr { inside_unsafe_block, .. }| is_unsafe |= !inside_unsafe_block, - ); + let mut walk_expr = |expr_id| { + unsafe_expressions( + db, + infer, + *def, + body, + expr_id, + &mut |UnsafeExpr { inside_unsafe_block, .. }| { + is_unsafe |= !inside_unsafe_block + }, + ) + }; + match expanded_expr { + ExprOrPatId::ExprId(expanded_expr) => walk_expr(expanded_expr), + ExprOrPatId::PatId(expanded_pat) => { + body.walk_exprs_in_pat(expanded_pat, &mut walk_expr) + } + } return is_unsafe; } } @@ -991,7 +1009,7 @@ impl SourceAnalyzer { } fn ty_of_expr(&self, db: &dyn HirDatabase, expr: &ast::Expr) -> Option<&Ty> { - self.infer.as_ref()?.type_of_expr.get(self.expr_id(db, expr)?) + self.infer.as_ref()?.type_of_expr_or_pat(self.expr_id(db, expr)?) } } @@ -1004,7 +1022,7 @@ fn scope_for( node.ancestors_with_macros(db.upcast()) .take_while(|it| !ast::Item::can_cast(it.kind()) || ast::MacroCall::can_cast(it.kind())) .filter_map(|it| it.map(ast::Expr::cast).transpose()) - .filter_map(|it| source_map.node_expr(it.as_ref())) + .filter_map(|it| source_map.node_expr(it.as_ref())?.as_expr()) .find_map(|it| scopes.scope_for(it)) } diff --git a/crates/ide-diagnostics/src/handlers/missing_fields.rs b/crates/ide-diagnostics/src/handlers/missing_fields.rs index 86c237f7b5ec..3a622c69681e 100644 --- a/crates/ide-diagnostics/src/handlers/missing_fields.rs +++ b/crates/ide-diagnostics/src/handlers/missing_fields.rs @@ -308,22 +308,27 @@ struct T(S); fn regular(a: S) { let s; S { s, .. } = a; + _ = s; } fn nested(a: S2) { let s; S2 { s: S { s, .. }, .. } = a; + _ = s; } fn in_tuple(a: (S,)) { let s; (S { s, .. },) = a; + _ = s; } fn in_array(a: [S;1]) { let s; [S { s, .. },] = a; + _ = s; } fn in_tuple_struct(a: T) { let s; T(S { s, .. }) = a; + _ = s; } ", ); diff --git a/crates/ide-diagnostics/src/handlers/missing_unsafe.rs b/crates/ide-diagnostics/src/handlers/missing_unsafe.rs index cc0f4bfccc9a..98063bf4fef8 100644 --- a/crates/ide-diagnostics/src/handlers/missing_unsafe.rs +++ b/crates/ide-diagnostics/src/handlers/missing_unsafe.rs @@ -32,7 +32,8 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingUnsafe) -> Option, d: &hir::NeedMut) -> Option { - if d.span.file_id.macro_file().is_some() { - // FIXME: Our infra can't handle allow from within macro expansions rn - return None; - } + let root = ctx.sema.db.parse_or_expand(d.span.file_id); + let node = d.span.value.to_node(&root); + let mut span = d.span; + if let Some(parent) = node.parent() { + if ast::BinExpr::can_cast(parent.kind()) { + // In case of an assignment, the diagnostic is provided on the variable name. + // We want to expand it to include the whole assignment, but only when this + // is an ordinary assignment, not a destructuring assignment. So, the direct + // parent is an assignment expression. + span = d.span.with_value(SyntaxNodePtr::new(&parent)); + } + }; + let fixes = (|| { if d.local.is_ref(ctx.sema.db) { // There is no simple way to add `mut` to `ref x` and `ref mut x` return None; } - let file_id = d.span.file_id.file_id()?; + let file_id = span.file_id.file_id()?; let mut edit_builder = TextEdit::builder(); - let use_range = d.span.value.text_range(); + let use_range = span.value.text_range(); for source in d.local.sources(ctx.sema.db) { let Some(ast) = source.name() else { continue }; // FIXME: macros @@ -33,6 +43,7 @@ pub(crate) fn need_mut(ctx: &DiagnosticsContext<'_>, d: &hir::NeedMut) -> Option use_range, )]) })(); + Some( Diagnostic::new_with_syntax_node_ptr( ctx, @@ -42,7 +53,7 @@ pub(crate) fn need_mut(ctx: &DiagnosticsContext<'_>, d: &hir::NeedMut) -> Option "cannot mutate immutable variable `{}`", d.local.name(ctx.sema.db).display(ctx.sema.db, ctx.edition) ), - d.span, + span, ) .with_fixes(fixes), ) @@ -53,10 +64,6 @@ pub(crate) fn need_mut(ctx: &DiagnosticsContext<'_>, d: &hir::NeedMut) -> Option // This diagnostic is triggered when a mutable variable isn't actually mutated. pub(crate) fn unused_mut(ctx: &DiagnosticsContext<'_>, d: &hir::UnusedMut) -> Option { let ast = d.local.primary_source(ctx.sema.db).syntax_ptr(); - if ast.file_id.macro_file().is_some() { - // FIXME: Our infra can't handle allow from within macro expansions rn - return None; - } let fixes = (|| { let file_id = ast.file_id.file_id()?; let mut edit_builder = TextEdit::builder(); @@ -937,7 +944,6 @@ fn fn_once(mut x: impl FnOnce(u8) -> u8) -> u8 { #[test] fn closure() { - // FIXME: Diagnostic spans are inconsistent inside and outside closure check_diagnostics( r#" //- minicore: copy, fn @@ -950,11 +956,11 @@ fn fn_once(mut x: impl FnOnce(u8) -> u8) -> u8 { fn f() { let x = 5; let closure1 = || { x = 2; }; - //^ 💡 error: cannot mutate immutable variable `x` + //^^^^^ 💡 error: cannot mutate immutable variable `x` let _ = closure1(); //^^^^^^^^ 💡 error: cannot mutate immutable variable `closure1` let closure2 = || { x = x; }; - //^ 💡 error: cannot mutate immutable variable `x` + //^^^^^ 💡 error: cannot mutate immutable variable `x` let closure3 = || { let x = 2; x = 5; @@ -996,7 +1002,7 @@ fn f() { || { let x = 2; || { || { x = 5; } } - //^ 💡 error: cannot mutate immutable variable `x` + //^^^^^ 💡 error: cannot mutate immutable variable `x` } } }; @@ -1283,4 +1289,19 @@ fn main() { "#, ); } + + #[test] + fn destructuring_assignment_needs_mut() { + check_diagnostics( + r#" +//- minicore: fn + +fn main() { + let mut var = 1; + let mut func = || (var,) = (2,); + func(); +} + "#, + ); + } } diff --git a/crates/ide-diagnostics/src/handlers/unresolved_ident.rs b/crates/ide-diagnostics/src/handlers/unresolved_ident.rs index 9a81682aaeba..68f14a97f594 100644 --- a/crates/ide-diagnostics/src/handlers/unresolved_ident.rs +++ b/crates/ide-diagnostics/src/handlers/unresolved_ident.rs @@ -11,7 +11,7 @@ pub(crate) fn unresolved_ident( ctx, DiagnosticCode::RustcHardError("E0425"), "no such value in this scope", - d.expr.map(Into::into), + d.expr_or_pat.map(Into::into), ) .experimental() } diff --git a/crates/syntax/src/ast/expr_ext.rs b/crates/syntax/src/ast/expr_ext.rs index 6ed205e2856f..f3053f59836f 100644 --- a/crates/syntax/src/ast/expr_ext.rs +++ b/crates/syntax/src/ast/expr_ext.rs @@ -232,6 +232,10 @@ impl ast::RangeExpr { Some((ix, token, bin_op)) }) } + + pub fn is_range_full(&self) -> bool { + support::children::(&self.syntax).next().is_none() + } } impl RangeItem for ast::RangeExpr {