From 4bb623decb15a82a053b54c2ee0b8ce0987dc3d8 Mon Sep 17 00:00:00 2001
From: Shoyu Vanilla <modulo641@gmail.com>
Date: Thu, 4 Jul 2024 23:31:55 +0900
Subject: [PATCH] Disallow nested impl traits

---
 crates/hir-def/src/hir/type_ref.rs |  8 +++-
 crates/hir-def/src/lower.rs        | 30 ++++++++++++
 crates/hir-def/src/path/lower.rs   |  2 +
 crates/hir-ty/src/tests/traits.rs  | 73 ++++++++++++++++++++++++++++++
 4 files changed, 112 insertions(+), 1 deletion(-)

diff --git a/crates/hir-def/src/hir/type_ref.rs b/crates/hir-def/src/hir/type_ref.rs
index ec207a7f9651..741ae41c7434 100644
--- a/crates/hir-def/src/hir/type_ref.rs
+++ b/crates/hir-def/src/hir/type_ref.rs
@@ -245,7 +245,13 @@ impl TypeRef {
             // for types are close enough for our purposes to the inner type for now...
             ast::Type::ForType(inner) => TypeRef::from_ast_opt(ctx, inner.ty()),
             ast::Type::ImplTraitType(inner) => {
-                TypeRef::ImplTrait(type_bounds_from_ast(ctx, inner.type_bound_list()))
+                if ctx.outer_impl_trait() {
+                    // Disallow nested impl traits
+                    TypeRef::Error
+                } else {
+                    let _guard = ctx.outer_impl_trait_scope(true);
+                    TypeRef::ImplTrait(type_bounds_from_ast(ctx, inner.type_bound_list()))
+                }
             }
             ast::Type::DynTraitType(inner) => {
                 TypeRef::DynTrait(type_bounds_from_ast(ctx, inner.type_bound_list()))
diff --git a/crates/hir-def/src/lower.rs b/crates/hir-def/src/lower.rs
index ecd8d79f20be..e4786a1dd40e 100644
--- a/crates/hir-def/src/lower.rs
+++ b/crates/hir-def/src/lower.rs
@@ -18,6 +18,26 @@ pub struct LowerCtx<'a> {
     span_map: OnceCell<SpanMap>,
     ast_id_map: OnceCell<Arc<AstIdMap>>,
     impl_trait_bounds: RefCell<Vec<Vec<Interned<TypeBound>>>>,
+    // Prevent nested impl traits like `impl Foo<impl Bar>`.
+    outer_impl_trait: RefCell<bool>,
+}
+
+pub(crate) struct OuterImplTraitGuard<'a> {
+    ctx: &'a LowerCtx<'a>,
+    old: bool,
+}
+
+impl<'a> OuterImplTraitGuard<'a> {
+    fn new(ctx: &'a LowerCtx<'a>, impl_trait: bool) -> Self {
+        let old = ctx.outer_impl_trait.replace(impl_trait);
+        Self { ctx, old }
+    }
+}
+
+impl<'a> Drop for OuterImplTraitGuard<'a> {
+    fn drop(&mut self) {
+        self.ctx.outer_impl_trait.replace(self.old);
+    }
 }
 
 impl<'a> LowerCtx<'a> {
@@ -28,6 +48,7 @@ impl<'a> LowerCtx<'a> {
             span_map: OnceCell::new(),
             ast_id_map: OnceCell::new(),
             impl_trait_bounds: RefCell::new(Vec::new()),
+            outer_impl_trait: RefCell::default(),
         }
     }
 
@@ -42,6 +63,7 @@ impl<'a> LowerCtx<'a> {
             span_map,
             ast_id_map: OnceCell::new(),
             impl_trait_bounds: RefCell::new(Vec::new()),
+            outer_impl_trait: RefCell::default(),
         }
     }
 
@@ -67,4 +89,12 @@ impl<'a> LowerCtx<'a> {
     pub fn take_impl_traits_bounds(&self) -> Vec<Vec<Interned<TypeBound>>> {
         self.impl_trait_bounds.take()
     }
+
+    pub(crate) fn outer_impl_trait(&self) -> bool {
+        *self.outer_impl_trait.borrow()
+    }
+
+    pub(crate) fn outer_impl_trait_scope(&'a self, impl_trait: bool) -> OuterImplTraitGuard<'a> {
+        OuterImplTraitGuard::new(self, impl_trait)
+    }
 }
diff --git a/crates/hir-def/src/path/lower.rs b/crates/hir-def/src/path/lower.rs
index 2b555b3998a7..a710c2dacaab 100644
--- a/crates/hir-def/src/path/lower.rs
+++ b/crates/hir-def/src/path/lower.rs
@@ -202,6 +202,8 @@ pub(super) fn lower_generic_args(
                     continue;
                 }
                 if let Some(name_ref) = assoc_type_arg.name_ref() {
+                    // Nested impl traits like `impl Foo<Assoc = impl Bar>` are allowed
+                    let _guard = lower_ctx.outer_impl_trait_scope(false);
                     let name = name_ref.as_name();
                     let args = assoc_type_arg
                         .generic_arg_list()
diff --git a/crates/hir-ty/src/tests/traits.rs b/crates/hir-ty/src/tests/traits.rs
index 18fc8afd183d..fb07e718d102 100644
--- a/crates/hir-ty/src/tests/traits.rs
+++ b/crates/hir-ty/src/tests/traits.rs
@@ -4824,3 +4824,76 @@ fn foo() {
 "#,
     )
 }
+
+#[test]
+fn nested_impl_traits() {
+    check_infer(
+        r#"
+//- minicore: fn
+trait Foo {}
+
+trait Bar<T> {}
+
+trait Baz {
+    type Assoc;
+}
+
+struct Qux<T> {
+    qux: T,
+}
+
+struct S;
+
+impl Foo for S {}
+
+fn not_allowed1(f: impl Fn(impl Foo)) {
+    let foo = S;
+    f(foo);
+}
+
+// This caused stack overflow in #17498
+fn not_allowed2(f: impl Fn(&impl Foo)) {
+    let foo = S;
+    f(&foo);
+}
+
+fn not_allowed3(bar: impl Bar<impl Foo>) {}
+
+// This also caused stack overflow
+fn not_allowed4(bar: impl Bar<&impl Foo>) {}
+
+fn allowed1(baz: impl Baz<Assoc = impl Foo>) {}
+
+fn allowed2<'a>(baz: impl Baz<Assoc = &'a (impl Foo + 'a)>) {}
+
+fn allowed3(baz: impl Baz<Assoc = Qux<impl Foo>>) {}
+"#,
+        expect![[r#"
+            139..140 'f': impl Fn({unknown}) + ?Sized
+            161..193 '{     ...oo); }': ()
+            171..174 'foo': S
+            177..178 'S': S
+            184..185 'f': impl Fn({unknown}) + ?Sized
+            184..190 'f(foo)': ()
+            186..189 'foo': S
+            251..252 'f': impl Fn(&'? {unknown}) + ?Sized
+            274..307 '{     ...oo); }': ()
+            284..287 'foo': S
+            290..291 'S': S
+            297..298 'f': impl Fn(&'? {unknown}) + ?Sized
+            297..304 'f(&foo)': ()
+            299..303 '&foo': &'? S
+            300..303 'foo': S
+            325..328 'bar': impl Bar<{unknown}> + ?Sized
+            350..352 '{}': ()
+            405..408 'bar': impl Bar<&'? {unknown}> + ?Sized
+            431..433 '{}': ()
+            447..450 'baz': impl Baz<Assoc = impl Foo + ?Sized> + ?Sized
+            480..482 '{}': ()
+            500..503 'baz': impl Baz<Assoc = &'a impl Foo + 'a + ?Sized> + ?Sized
+            544..546 '{}': ()
+            560..563 'baz': impl Baz<Assoc = Qux<impl Foo + ?Sized>> + ?Sized
+            598..600 '{}': ()
+        "#]],
+    )
+}