diff --git a/crates/ide-diagnostics/src/handlers/type_mismatch.rs b/crates/ide-diagnostics/src/handlers/type_mismatch.rs index 93fe9374a3ee..bfdda5374053 100644 --- a/crates/ide-diagnostics/src/handlers/type_mismatch.rs +++ b/crates/ide-diagnostics/src/handlers/type_mismatch.rs @@ -1,12 +1,16 @@ use either::Either; -use hir::{db::ExpandDatabase, ClosureStyle, HirDisplay, HirFileIdExt, InFile, Type}; -use ide_db::text_edit::TextEdit; -use ide_db::{famous_defs::FamousDefs, source_change::SourceChange}; +use hir::{db::ExpandDatabase, CallableKind, ClosureStyle, HirDisplay, HirFileIdExt, InFile, Type}; +use ide_db::{ + famous_defs::FamousDefs, + source_change::{SourceChange, SourceChangeBuilder}, + text_edit::TextEdit, +}; use syntax::{ ast::{ self, edit::{AstNodeEdit, IndentLevel}, - BlockExpr, Expr, ExprStmt, + syntax_factory::SyntaxFactory, + BlockExpr, Expr, ExprStmt, HasArgList, }, AstNode, AstPtr, TextSize, }; @@ -63,6 +67,7 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::TypeMismatch) -> Option, + d: &hir::TypeMismatch, + expr_ptr: &InFile>, + acc: &mut Vec, +) -> Option<()> { + let db = ctx.sema.db; + let root = db.parse_or_expand(expr_ptr.file_id); + let expr = expr_ptr.value.to_node(&root); + let expr = ctx.sema.original_ast_node(expr.clone())?; + + let Expr::CallExpr(call_expr) = expr else { + return None; + }; + + let callable = ctx.sema.resolve_expr_as_callable(&call_expr.expr()?)?; + let CallableKind::TupleEnumVariant(variant) = callable.kind() else { + return None; + }; + + let actual_enum = d.actual.as_adt()?.as_enum()?; + let famous_defs = FamousDefs(&ctx.sema, ctx.sema.scope(call_expr.syntax())?.krate()); + let core_option = famous_defs.core_option_Option(); + let core_result = famous_defs.core_result_Result(); + if Some(actual_enum) != core_option && Some(actual_enum) != core_result { + return None; + } + + let inner_type = variant.fields(db).first()?.ty_with_args(db, d.actual.type_arguments()); + if !d.expected.could_unify_with(db, &inner_type) { + return None; + } + + let inner_arg = call_expr.arg_list()?.args().next()?; + + let file_id = expr_ptr.file_id.original_file(db); + let mut builder = SourceChangeBuilder::new(file_id); + let mut editor; + match inner_arg { + // We're returning `()` + Expr::TupleExpr(tup) if tup.fields().next().is_none() => { + let parent = call_expr + .syntax() + .parent() + .and_then(Either::::cast)?; + + editor = builder.make_editor(parent.syntax()); + let make = SyntaxFactory::new(); + + match parent { + Either::Left(ret_expr) => { + editor.replace(ret_expr.syntax(), make.expr_return(None).syntax()); + } + Either::Right(stmt_list) => { + let new_block = if stmt_list.statements().next().is_none() { + make.expr_empty_block() + } else { + make.block_expr(stmt_list.statements(), None) + }; + + editor.replace(stmt_list.syntax().parent()?, new_block.syntax()); + } + } + + editor.add_mappings(make.finish_with_mappings()); + } + _ => { + editor = builder.make_editor(call_expr.syntax()); + editor.replace(call_expr.syntax(), inner_arg.syntax()); + } + } + + builder.add_file_edits(file_id, editor); + let name = format!("Remove unnecessary {}() wrapper", variant.name(db).as_str()); + acc.push(fix( + "remove_unnecessary_wrapper", + &name, + builder.finish(), + call_expr.syntax().text_range(), + )); + Some(()) +} + fn remove_semicolon( ctx: &DiagnosticsContext<'_>, d: &hir::TypeMismatch, @@ -243,7 +331,7 @@ fn str_ref_to_owned( #[cfg(test)] mod tests { use crate::tests::{ - check_diagnostics, check_diagnostics_with_disabled, check_fix, check_no_fix, + check_diagnostics, check_diagnostics_with_disabled, check_fix, check_has_fix, check_no_fix, }; #[test] @@ -260,7 +348,7 @@ fn test(_arg: &i32) {} } #[test] - fn test_add_reference_to_int() { + fn add_reference_to_int() { check_fix( r#" fn main() { @@ -278,7 +366,7 @@ fn test(_arg: &i32) {} } #[test] - fn test_add_mutable_reference_to_int() { + fn add_mutable_reference_to_int() { check_fix( r#" fn main() { @@ -296,7 +384,7 @@ fn test(_arg: &mut i32) {} } #[test] - fn test_add_reference_to_array() { + fn add_reference_to_array() { check_fix( r#" //- minicore: coerce_unsized @@ -315,7 +403,7 @@ fn test(_arg: &[i32]) {} } #[test] - fn test_add_reference_with_autoderef() { + fn add_reference_with_autoderef() { check_fix( r#" //- minicore: coerce_unsized, deref @@ -348,7 +436,7 @@ fn test(_arg: &Bar) {} } #[test] - fn test_add_reference_to_method_call() { + fn add_reference_to_method_call() { check_fix( r#" fn main() { @@ -372,7 +460,7 @@ impl Test { } #[test] - fn test_add_reference_to_let_stmt() { + fn add_reference_to_let_stmt() { check_fix( r#" fn main() { @@ -388,7 +476,7 @@ fn main() { } #[test] - fn test_add_reference_to_macro_call() { + fn add_reference_to_macro_call() { check_fix( r#" macro_rules! thousand { @@ -416,7 +504,7 @@ fn main() { } #[test] - fn test_add_mutable_reference_to_let_stmt() { + fn add_mutable_reference_to_let_stmt() { check_fix( r#" fn main() { @@ -431,29 +519,6 @@ fn main() { ); } - #[test] - fn test_wrap_return_type_option() { - check_fix( - r#" -//- minicore: option, result -fn div(x: i32, y: i32) -> Option { - if y == 0 { - return None; - } - x / y$0 -} -"#, - r#" -fn div(x: i32, y: i32) -> Option { - if y == 0 { - return None; - } - Some(x / y) -} -"#, - ); - } - #[test] fn const_generic_type_mismatch() { check_diagnostics( @@ -487,59 +552,82 @@ fn div(x: i32, y: i32) -> Option { } #[test] - fn test_wrap_return_type_option_tails() { + fn wrap_return_type() { + check_fix( + r#" +//- minicore: option, result +fn div(x: i32, y: i32) -> Result { + if y == 0 { + return Err(()); + } + x / y$0 +} +"#, + r#" +fn div(x: i32, y: i32) -> Result { + if y == 0 { + return Err(()); + } + Ok(x / y) +} +"#, + ); + } + + #[test] + fn wrap_return_type_option() { check_fix( r#" //- minicore: option, result fn div(x: i32, y: i32) -> Option { if y == 0 { - Some(0) - } else if true { - 100$0 - } else { - None + return None; } + x / y$0 } "#, r#" fn div(x: i32, y: i32) -> Option { if y == 0 { - Some(0) - } else if true { - Some(100) - } else { - None + return None; } + Some(x / y) } "#, ); } #[test] - fn test_wrap_return_type() { + fn wrap_return_type_option_tails() { check_fix( r#" //- minicore: option, result -fn div(x: i32, y: i32) -> Result { +fn div(x: i32, y: i32) -> Option { if y == 0 { - return Err(()); + Some(0) + } else if true { + 100$0 + } else { + None } - x / y$0 } "#, r#" -fn div(x: i32, y: i32) -> Result { +fn div(x: i32, y: i32) -> Option { if y == 0 { - return Err(()); + Some(0) + } else if true { + Some(100) + } else { + None } - Ok(x / y) } "#, ); } #[test] - fn test_wrap_return_type_handles_generic_functions() { + fn wrap_return_type_handles_generic_functions() { check_fix( r#" //- minicore: option, result @@ -562,7 +650,7 @@ fn div(x: T) -> Result { } #[test] - fn test_wrap_return_type_handles_type_aliases() { + fn wrap_return_type_handles_type_aliases() { check_fix( r#" //- minicore: option, result @@ -589,7 +677,7 @@ fn div(x: i32, y: i32) -> MyResult { } #[test] - fn test_wrapped_unit_as_block_tail_expr() { + fn wrapped_unit_as_block_tail_expr() { check_fix( r#" //- minicore: result @@ -619,7 +707,7 @@ fn foo() -> Result<(), ()> { } #[test] - fn test_wrapped_unit_as_return_expr() { + fn wrapped_unit_as_return_expr() { check_fix( r#" //- minicore: result @@ -642,7 +730,7 @@ fn foo(b: bool) -> Result<(), String> { } #[test] - fn test_in_const_and_static() { + fn wrap_in_const_and_static() { check_fix( r#" //- minicore: option, result @@ -664,7 +752,7 @@ const _: Option<()> = {Some(())}; } #[test] - fn test_wrap_return_type_not_applicable_when_expr_type_does_not_match_ok_type() { + fn wrap_return_type_not_applicable_when_expr_type_does_not_match_ok_type() { check_no_fix( r#" //- minicore: option, result @@ -674,7 +762,7 @@ fn foo() -> Result<(), i32> { 0$0 } } #[test] - fn test_wrap_return_type_not_applicable_when_return_type_is_not_result_or_option() { + fn wrap_return_type_not_applicable_when_return_type_is_not_result_or_option() { check_no_fix( r#" //- minicore: option, result @@ -685,6 +773,254 @@ fn foo() -> SomeOtherEnum { 0$0 } ); } + #[test] + fn unwrap_return_type() { + check_fix( + r#" +//- minicore: option, result +fn div(x: i32, y: i32) -> i32 { + if y == 0 { + panic!(); + } + Ok(x / y)$0 +} +"#, + r#" +fn div(x: i32, y: i32) -> i32 { + if y == 0 { + panic!(); + } + x / y +} +"#, + ); + } + + #[test] + fn unwrap_return_type_option() { + check_fix( + r#" +//- minicore: option, result +fn div(x: i32, y: i32) -> i32 { + if y == 0 { + panic!(); + } + Some(x / y)$0 +} +"#, + r#" +fn div(x: i32, y: i32) -> i32 { + if y == 0 { + panic!(); + } + x / y +} +"#, + ); + } + + #[test] + fn unwrap_return_type_option_tails() { + check_fix( + r#" +//- minicore: option, result +fn div(x: i32, y: i32) -> i32 { + if y == 0 { + 42 + } else if true { + Some(100)$0 + } else { + 0 + } +} +"#, + r#" +fn div(x: i32, y: i32) -> i32 { + if y == 0 { + 42 + } else if true { + 100 + } else { + 0 + } +} +"#, + ); + } + + #[test] + fn unwrap_return_type_option_tail_unit() { + check_fix( + r#" +//- minicore: option, result +fn div(x: i32, y: i32) { + if y == 0 { + panic!(); + } + + Ok(())$0 +} +"#, + r#" +fn div(x: i32, y: i32) { + if y == 0 { + panic!(); + } +} +"#, + ); + } + + #[test] + fn unwrap_return_type_handles_generic_functions() { + check_fix( + r#" +//- minicore: option, result +fn div(x: T) -> T { + if x == 0 { + panic!(); + } + $0Ok(x) +} +"#, + r#" +fn div(x: T) -> T { + if x == 0 { + panic!(); + } + x +} +"#, + ); + } + + #[test] + fn unwrap_return_type_handles_type_aliases() { + check_fix( + r#" +//- minicore: option, result +type MyResult = T; + +fn div(x: i32, y: i32) -> MyResult { + if y == 0 { + panic!(); + } + Ok(x $0/ y) +} +"#, + r#" +type MyResult = T; + +fn div(x: i32, y: i32) -> MyResult { + if y == 0 { + panic!(); + } + x / y +} +"#, + ); + } + + #[test] + fn unwrap_tail_expr() { + check_fix( + r#" +//- minicore: result +fn foo() -> () { + println!("Hello, world!"); + Ok(())$0 +} + "#, + r#" +fn foo() -> () { + println!("Hello, world!"); +} + "#, + ); + } + + #[test] + fn unwrap_to_empty_block() { + check_fix( + r#" +//- minicore: result +fn foo() -> () { + Ok(())$0 +} + "#, + r#" +fn foo() -> () {} + "#, + ); + } + + #[test] + fn unwrap_to_return_expr() { + check_has_fix( + r#" +//- minicore: result +fn foo(b: bool) -> () { + if b { + return $0Ok(()); + } + + panic!("oh dear"); +}"#, + r#" +fn foo(b: bool) -> () { + if b { + return; + } + + panic!("oh dear"); +}"#, + ); + } + + #[test] + fn unwrap_in_const_and_static() { + check_fix( + r#" +//- minicore: option, result +static A: () = {Some(($0))}; + "#, + r#" +static A: () = {}; + "#, + ); + check_fix( + r#" +//- minicore: option, result +const _: () = {Some(($0))}; + "#, + r#" +const _: () = {}; + "#, + ); + } + + #[test] + fn unwrap_return_type_not_applicable_when_inner_type_does_not_match_return_type() { + check_no_fix( + r#" +//- minicore: result +fn foo() -> i32 { $0Ok(()) } +"#, + ); + } + + #[test] + fn unwrap_return_type_not_applicable_when_wrapper_type_is_not_result_or_option() { + check_no_fix( + r#" +//- minicore: option, result +enum SomeOtherEnum { Ok(i32), Err(String) } + +fn foo() -> i32 { SomeOtherEnum::Ok($042) } +"#, + ); + } + #[test] fn remove_semicolon() { check_fix(r#"fn f() -> i32 { 92$0; }"#, r#"fn f() -> i32 { 92 }"#); diff --git a/crates/syntax/src/ast/syntax_factory/constructors.rs b/crates/syntax/src/ast/syntax_factory/constructors.rs index aa894ef633a1..e86c291f76cc 100644 --- a/crates/syntax/src/ast/syntax_factory/constructors.rs +++ b/crates/syntax/src/ast/syntax_factory/constructors.rs @@ -66,28 +66,41 @@ impl SyntaxFactory { tail_expr: Option, ) -> ast::BlockExpr { let stmts = stmts.into_iter().collect_vec(); - let input = stmts.iter().map(|it| it.syntax().clone()).collect_vec(); + let mut input = stmts.iter().map(|it| it.syntax().clone()).collect_vec(); let ast = make::block_expr(stmts, tail_expr.clone()).clone_for_update(); - if let Some((mut mapping, stmt_list)) = self.mappings().zip(ast.stmt_list()) { + if let Some(mut mapping) = self.mappings() { + let stmt_list = ast.stmt_list().unwrap(); let mut builder = SyntaxMappingBuilder::new(stmt_list.syntax().clone()); + if let Some(input) = tail_expr { + builder.map_node( + input.syntax().clone(), + stmt_list.tail_expr().unwrap().syntax().clone(), + ); + } else if let Some(ast_tail) = stmt_list.tail_expr() { + // The parser interpreted the last statement (probably a statement with a block) as an Expr + let last_stmt = input.pop().unwrap(); + + builder.map_node(last_stmt, ast_tail.syntax().clone()); + } + builder.map_children( input.into_iter(), stmt_list.statements().map(|it| it.syntax().clone()), ); - if let Some((input, output)) = tail_expr.zip(stmt_list.tail_expr()) { - builder.map_node(input.syntax().clone(), output.syntax().clone()); - } - builder.finish(&mut mapping); } ast } + pub fn expr_empty_block(&self) -> ast::BlockExpr { + ast::BlockExpr { syntax: make::expr_empty_block().syntax().clone_for_update() } + } + pub fn expr_bin(&self, lhs: ast::Expr, op: ast::BinaryOp, rhs: ast::Expr) -> ast::BinExpr { let ast::Expr::BinExpr(ast) = make::expr_bin_op(lhs.clone(), op, rhs.clone()).clone_for_update() @@ -134,6 +147,22 @@ impl SyntaxFactory { ast.into() } + pub fn expr_return(&self, expr: Option) -> ast::ReturnExpr { + let ast::Expr::ReturnExpr(ast) = make::expr_return(expr.clone()).clone_for_update() else { + unreachable!() + }; + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + if let Some(input) = expr { + builder.map_node(input.syntax().clone(), ast.expr().unwrap().syntax().clone()); + } + builder.finish(&mut mapping); + } + + ast + } + pub fn let_stmt( &self, pattern: ast::Pat,