From ed57008510b35338a9d1eff0fbadf11be3fc32c2 Mon Sep 17 00:00:00 2001
From: Lukas Wirth <lukastw97@gmail.com>
Date: Tue, 13 Feb 2024 12:33:51 +0100
Subject: [PATCH] fix: Validate literals in proc-macro-srv
 FreeFunctions::literal_from_str

---
 Cargo.lock                                    |  2 +-
 crates/proc-macro-srv/Cargo.toml              |  2 +-
 crates/proc-macro-srv/src/lib.rs              |  5 ++
 crates/proc-macro-srv/src/server.rs           | 27 --------
 .../src/server/rust_analyzer_span.rs          | 66 +++++++++++++------
 crates/proc-macro-srv/src/server/token_id.rs  | 66 +++++++++++++------
 crates/proc-macro-srv/src/tests/mod.rs        | 18 +++--
 crates/syntax/src/ast/token_ext.rs            | 10 ---
 crates/syntax/src/lib.rs                      | 22 -------
 9 files changed, 110 insertions(+), 108 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 0fdb366c1f9d..7b29d7bb798d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1329,9 +1329,9 @@ dependencies = [
  "paths",
  "proc-macro-api",
  "proc-macro-test",
+ "ra-ap-rustc_lexer",
  "span",
  "stdx",
- "syntax",
  "tt",
 ]
 
diff --git a/crates/proc-macro-srv/Cargo.toml b/crates/proc-macro-srv/Cargo.toml
index d0cdc51c3c22..bd7a31654584 100644
--- a/crates/proc-macro-srv/Cargo.toml
+++ b/crates/proc-macro-srv/Cargo.toml
@@ -29,7 +29,7 @@ paths.workspace = true
 base-db.workspace = true
 span.workspace = true
 proc-macro-api.workspace = true
-syntax.workspace = true
+ra-ap-rustc_lexer.workspace = true
 
 [dev-dependencies]
 expect-test = "1.4.0"
diff --git a/crates/proc-macro-srv/src/lib.rs b/crates/proc-macro-srv/src/lib.rs
index 460a96c07f36..831632c64c0a 100644
--- a/crates/proc-macro-srv/src/lib.rs
+++ b/crates/proc-macro-srv/src/lib.rs
@@ -20,6 +20,11 @@ extern crate proc_macro;
 #[cfg(feature = "in-rust-tree")]
 extern crate rustc_driver as _;
 
+#[cfg(not(feature = "in-rust-tree"))]
+extern crate ra_ap_rustc_lexer as rustc_lexer;
+#[cfg(feature = "in-rust-tree")]
+extern crate rustc_lexer;
+
 mod dylib;
 mod proc_macros;
 mod server;
diff --git a/crates/proc-macro-srv/src/server.rs b/crates/proc-macro-srv/src/server.rs
index bb49dc14f96a..ff8fd295d884 100644
--- a/crates/proc-macro-srv/src/server.rs
+++ b/crates/proc-macro-srv/src/server.rs
@@ -17,7 +17,6 @@ pub mod rust_analyzer_span;
 mod symbol;
 pub mod token_id;
 pub use symbol::*;
-use syntax::ast::{self, IsString};
 use tt::Spacing;
 
 fn delim_to_internal<S>(d: proc_macro::Delimiter, span: bridge::DelimSpan<S>) -> tt::Delimiter<S> {
@@ -55,32 +54,6 @@ fn spacing_to_external(spacing: Spacing) -> proc_macro::Spacing {
     }
 }
 
-fn literal_to_external(literal_kind: ast::LiteralKind) -> Option<proc_macro::bridge::LitKind> {
-    match literal_kind {
-        ast::LiteralKind::String(data) => Some(if data.is_raw() {
-            bridge::LitKind::StrRaw(data.raw_delimiter_count()?)
-        } else {
-            bridge::LitKind::Str
-        }),
-
-        ast::LiteralKind::ByteString(data) => Some(if data.is_raw() {
-            bridge::LitKind::ByteStrRaw(data.raw_delimiter_count()?)
-        } else {
-            bridge::LitKind::ByteStr
-        }),
-        ast::LiteralKind::CString(data) => Some(if data.is_raw() {
-            bridge::LitKind::CStrRaw(data.raw_delimiter_count()?)
-        } else {
-            bridge::LitKind::CStr
-        }),
-        ast::LiteralKind::IntNumber(_) => Some(bridge::LitKind::Integer),
-        ast::LiteralKind::FloatNumber(_) => Some(bridge::LitKind::Float),
-        ast::LiteralKind::Char(_) => Some(bridge::LitKind::Char),
-        ast::LiteralKind::Byte(_) => Some(bridge::LitKind::Byte),
-        ast::LiteralKind::Bool(_) => None,
-    }
-}
-
 struct LiteralFormatter<S>(bridge::Literal<S, Symbol>);
 
 impl<S> LiteralFormatter<S> {
diff --git a/crates/proc-macro-srv/src/server/rust_analyzer_span.rs b/crates/proc-macro-srv/src/server/rust_analyzer_span.rs
index cf6e816d599a..17159ffbb887 100644
--- a/crates/proc-macro-srv/src/server/rust_analyzer_span.rs
+++ b/crates/proc-macro-srv/src/server/rust_analyzer_span.rs
@@ -13,11 +13,10 @@ use std::{
 use ::tt::{TextRange, TextSize};
 use proc_macro::bridge::{self, server};
 use span::{Span, FIXUP_ERASED_FILE_AST_ID_MARKER};
-use syntax::ast::{self, IsString};
 
 use crate::server::{
-    delim_to_external, delim_to_internal, literal_to_external, token_stream::TokenStreamBuilder,
-    LiteralFormatter, Symbol, SymbolInternerRef, SYMBOL_INTERNER,
+    delim_to_external, delim_to_internal, token_stream::TokenStreamBuilder, LiteralFormatter,
+    Symbol, SymbolInternerRef, SYMBOL_INTERNER,
 };
 mod tt {
     pub use ::tt::*;
@@ -71,32 +70,62 @@ impl server::FreeFunctions for RaSpanServer {
         &mut self,
         s: &str,
     ) -> Result<bridge::Literal<Self::Span, Self::Symbol>, ()> {
-        let literal = ast::Literal::parse(s).ok_or(())?;
-        let literal = literal.tree();
+        use proc_macro::bridge::LitKind;
+        use rustc_lexer::{LiteralKind, Token, TokenKind};
+
+        let mut tokens = rustc_lexer::tokenize(s);
+        let minus_or_lit = tokens.next().unwrap_or(Token { kind: TokenKind::Eof, len: 0 });
+
+        // A leading `-` is accepted only in front of an integer or float literal.
+        let lit = if minus_or_lit.kind == TokenKind::Minus {
+            let lit = tokens.next().ok_or(())?;
+            if !matches!(
+                lit.kind,
+                TokenKind::Literal {
+                    kind: LiteralKind::Int { .. } | LiteralKind::Float { .. },
+                    ..
+                }
+            ) {
+                return Err(());
+            }
+            lit
+        } else {
+            minus_or_lit
+        };
 
-        let kind = literal_to_external(literal.kind()).ok_or(())?;
+        // Reject input that has any further token after the literal.
+        if tokens.next().is_some() {
+            return Err(());
+        }
 
-        // FIXME: handle more than just int and float suffixes
-        let suffix = match literal.kind() {
-            ast::LiteralKind::FloatNumber(num) => num.suffix().map(ToString::to_string),
-            ast::LiteralKind::IntNumber(num) => num.suffix().map(ToString::to_string),
-            _ => None,
+        let TokenKind::Literal { kind, suffix_start } = lit.kind else { return Err(()) };
+        let kind = match kind {
+            LiteralKind::Int { .. } => LitKind::Integer,
+            LiteralKind::Float { .. } => LitKind::Float,
+            LiteralKind::Char { .. } => LitKind::Char,
+            LiteralKind::Byte { .. } => LitKind::Byte,
+            LiteralKind::Str { .. } => LitKind::Str,
+            LiteralKind::ByteStr { .. } => LitKind::ByteStr,
+            LiteralKind::CStr { .. } => LitKind::CStr,
+            LiteralKind::RawStr { n_hashes } => LitKind::StrRaw(n_hashes.unwrap_or_default()),
+            LiteralKind::RawByteStr { n_hashes } => {
+                LitKind::ByteStrRaw(n_hashes.unwrap_or_default())
+            }
+            LiteralKind::RawCStr { n_hashes } => LitKind::CStrRaw(n_hashes.unwrap_or_default()),
         };
 
-        let text = match literal.kind() {
-            ast::LiteralKind::String(data) => data.text_without_quotes().to_string(),
-            ast::LiteralKind::ByteString(data) => data.text_without_quotes().to_string(),
-            ast::LiteralKind::CString(data) => data.text_without_quotes().to_string(),
-            _ => s.to_string(),
+        // `suffix_start` is an offset into the literal token, not into `s`. The literal is
+        // the last token, so skip the optional leading `-`: `-4i64` gives `-4` and `i64`.
+        let lit_start = s.len() - lit.len as usize;
+        let (lit, suffix) = s.split_at(lit_start + suffix_start as usize);
+        let suffix = match suffix {
+            "" | "_" => None,
+            suffix => Some(Symbol::intern(self.interner, suffix)),
         };
-        let text = if let Some(ref suffix) = suffix { text.strip_suffix(suffix) } else { None }
-            .unwrap_or(&text);
-
-        let suffix = suffix.map(|suffix| Symbol::intern(self.interner, &suffix));
 
         Ok(bridge::Literal {
             kind,
-            symbol: Symbol::intern(self.interner, text),
+            symbol: Symbol::intern(self.interner, lit),
             suffix,
             span: self.call_site,
         })
diff --git a/crates/proc-macro-srv/src/server/token_id.rs b/crates/proc-macro-srv/src/server/token_id.rs
index 70e577f576fe..eddd6b1e6b9d 100644
--- a/crates/proc-macro-srv/src/server/token_id.rs
+++ b/crates/proc-macro-srv/src/server/token_id.rs
@@ -6,11 +6,10 @@ use std::{
 };
 
 use proc_macro::bridge::{self, server};
-use syntax::ast::{self, IsString};
 
 use crate::server::{
-    delim_to_external, delim_to_internal, literal_to_external, token_stream::TokenStreamBuilder,
-    LiteralFormatter, Symbol, SymbolInternerRef, SYMBOL_INTERNER,
+    delim_to_external, delim_to_internal, token_stream::TokenStreamBuilder, LiteralFormatter,
+    Symbol, SymbolInternerRef, SYMBOL_INTERNER,
 };
 mod tt {
     pub use proc_macro_api::msg::TokenId;
@@ -63,32 +62,62 @@ impl server::FreeFunctions for TokenIdServer {
         &mut self,
         s: &str,
     ) -> Result<bridge::Literal<Self::Span, Self::Symbol>, ()> {
-        let literal = ast::Literal::parse(s).ok_or(())?;
-        let literal = literal.tree();
+        use proc_macro::bridge::LitKind;
+        use rustc_lexer::{LiteralKind, Token, TokenKind};
+
+        let mut tokens = rustc_lexer::tokenize(s);
+        let minus_or_lit = tokens.next().unwrap_or(Token { kind: TokenKind::Eof, len: 0 });
+
+        // A leading `-` is accepted only in front of an integer or float literal.
+        let lit = if minus_or_lit.kind == TokenKind::Minus {
+            let lit = tokens.next().ok_or(())?;
+            if !matches!(
+                lit.kind,
+                TokenKind::Literal {
+                    kind: LiteralKind::Int { .. } | LiteralKind::Float { .. },
+                    ..
+                }
+            ) {
+                return Err(());
+            }
+            lit
+        } else {
+            minus_or_lit
+        };
 
-        let kind = literal_to_external(literal.kind()).ok_or(())?;
+        // Reject input that has any further token after the literal.
+        if tokens.next().is_some() {
+            return Err(());
+        }
 
-        // FIXME: handle more than just int and float suffixes
-        let suffix = match literal.kind() {
-            ast::LiteralKind::FloatNumber(num) => num.suffix().map(ToString::to_string),
-            ast::LiteralKind::IntNumber(num) => num.suffix().map(ToString::to_string),
-            _ => None,
+        let TokenKind::Literal { kind, suffix_start } = lit.kind else { return Err(()) };
+        let kind = match kind {
+            LiteralKind::Int { .. } => LitKind::Integer,
+            LiteralKind::Float { .. } => LitKind::Float,
+            LiteralKind::Char { .. } => LitKind::Char,
+            LiteralKind::Byte { .. } => LitKind::Byte,
+            LiteralKind::Str { .. } => LitKind::Str,
+            LiteralKind::ByteStr { .. } => LitKind::ByteStr,
+            LiteralKind::CStr { .. } => LitKind::CStr,
+            LiteralKind::RawStr { n_hashes } => LitKind::StrRaw(n_hashes.unwrap_or_default()),
+            LiteralKind::RawByteStr { n_hashes } => {
+                LitKind::ByteStrRaw(n_hashes.unwrap_or_default())
+            }
+            LiteralKind::RawCStr { n_hashes } => LitKind::CStrRaw(n_hashes.unwrap_or_default()),
         };
 
-        let text = match literal.kind() {
-            ast::LiteralKind::String(data) => data.text_without_quotes().to_string(),
-            ast::LiteralKind::ByteString(data) => data.text_without_quotes().to_string(),
-            ast::LiteralKind::CString(data) => data.text_without_quotes().to_string(),
-            _ => s.to_string(),
+        // `suffix_start` is an offset into the literal token, not into `s`. The literal is
+        // the last token, so skip the optional leading `-`: `-4i64` gives `-4` and `i64`.
+        let lit_start = s.len() - lit.len as usize;
+        let (lit, suffix) = s.split_at(lit_start + suffix_start as usize);
+        let suffix = match suffix {
+            "" | "_" => None,
+            suffix => Some(Symbol::intern(self.interner, suffix)),
         };
-        let text = if let Some(ref suffix) = suffix { text.strip_suffix(suffix) } else { None }
-            .unwrap_or(&text);
-
-        let suffix = suffix.map(|suffix| Symbol::intern(self.interner, &suffix));
 
         Ok(bridge::Literal {
             kind,
-            symbol: Symbol::intern(self.interner, text),
+            symbol: Symbol::intern(self.interner, lit),
             suffix,
             span: self.call_site,
         })
diff --git a/crates/proc-macro-srv/src/tests/mod.rs b/crates/proc-macro-srv/src/tests/mod.rs
index 87d832cc76fa..e5bfe5ee92cd 100644
--- a/crates/proc-macro-srv/src/tests/mod.rs
+++ b/crates/proc-macro-srv/src/tests/mod.rs
@@ -169,8 +169,8 @@ fn test_fn_like_mk_idents() {
 fn test_fn_like_macro_clone_literals() {
     assert_expand(
         "fn_like_clone_tokens",
-        r#"1u16, 2_u32, -4i64, 3.14f32, "hello bridge""#,
-        expect![[r#"
+        r###"1u16, 2_u32, -4i64, 3.14f32, "hello bridge", "suffixed"suffix, r##"raw"##"###,
+        expect![[r###"
             SUBTREE $$ 1 1
               LITERAL 1u16 1
               PUNCH   , [alone] 1
@@ -181,8 +181,12 @@ fn test_fn_like_macro_clone_literals() {
               PUNCH   , [alone] 1
               LITERAL 3.14f32 1
               PUNCH   , [alone] 1
-              LITERAL "hello bridge" 1"#]],
-        expect![[r#"
+              LITERAL ""hello bridge"" 1
+              PUNCH   , [alone] 1
+              LITERAL ""suffixed""suffix 1
+              PUNCH   , [alone] 1
+              LITERAL r##"r##"raw"##"## 1"###]],
+        expect![[r###"
             SUBTREE $$ SpanData { range: 0..100, anchor: SpanAnchor(FileId(42), 2), ctx: SyntaxContextId(0) } SpanData { range: 0..100, anchor: SpanAnchor(FileId(42), 2), ctx: SyntaxContextId(0) }
               LITERAL 1u16 SpanData { range: 0..4, anchor: SpanAnchor(FileId(42), 2), ctx: SyntaxContextId(0) }
               PUNCH   , [alone] SpanData { range: 4..5, anchor: SpanAnchor(FileId(42), 2), ctx: SyntaxContextId(0) }
@@ -193,7 +197,11 @@ fn test_fn_like_macro_clone_literals() {
               PUNCH   , [alone] SpanData { range: 18..19, anchor: SpanAnchor(FileId(42), 2), ctx: SyntaxContextId(0) }
               LITERAL 3.14f32 SpanData { range: 20..27, anchor: SpanAnchor(FileId(42), 2), ctx: SyntaxContextId(0) }
               PUNCH   , [alone] SpanData { range: 27..28, anchor: SpanAnchor(FileId(42), 2), ctx: SyntaxContextId(0) }
-              LITERAL "hello bridge" SpanData { range: 29..43, anchor: SpanAnchor(FileId(42), 2), ctx: SyntaxContextId(0) }"#]],
+              LITERAL ""hello bridge"" SpanData { range: 29..43, anchor: SpanAnchor(FileId(42), 2), ctx: SyntaxContextId(0) }
+              PUNCH   , [alone] SpanData { range: 43..44, anchor: SpanAnchor(FileId(42), 2), ctx: SyntaxContextId(0) }
+              LITERAL ""suffixed""suffix SpanData { range: 45..61, anchor: SpanAnchor(FileId(42), 2), ctx: SyntaxContextId(0) }
+              PUNCH   , [alone] SpanData { range: 61..62, anchor: SpanAnchor(FileId(42), 2), ctx: SyntaxContextId(0) }
+              LITERAL r##"r##"raw"##"## SpanData { range: 63..73, anchor: SpanAnchor(FileId(42), 2), ctx: SyntaxContextId(0) }"###]],
     );
 }
 
diff --git a/crates/syntax/src/ast/token_ext.rs b/crates/syntax/src/ast/token_ext.rs
index c93391a9792c..7cd1f1550b98 100644
--- a/crates/syntax/src/ast/token_ext.rs
+++ b/crates/syntax/src/ast/token_ext.rs
@@ -204,16 +204,6 @@ pub trait IsString: AstToken {
         assert!(TextRange::up_to(contents_range.len()).contains_range(range));
         Some(range + contents_range.start())
     }
-    fn raw_delimiter_count(&self) -> Option<u8> {
-        let text = self.text();
-        let quote_range = self.text_range_between_quotes()?;
-        let range_start = self.syntax().text_range().start();
-        text[TextRange::up_to((quote_range - range_start).start())]
-            .matches('#')
-            .count()
-            .try_into()
-            .ok()
-    }
 }
 
 impl IsString for ast::String {
diff --git a/crates/syntax/src/lib.rs b/crates/syntax/src/lib.rs
index f562da150372..b755de86d32c 100644
--- a/crates/syntax/src/lib.rs
+++ b/crates/syntax/src/lib.rs
@@ -182,28 +182,6 @@ impl SourceFile {
     }
 }
 
-impl ast::Literal {
-    pub fn parse(text: &str) -> Option<Parse<ast::Literal>> {
-        let lexed = parser::LexedStr::new(text);
-        let parser_input = lexed.to_input();
-        let parser_output = parser::TopEntryPoint::Expr.parse(&parser_input);
-        let (green, mut errors, _) = parsing::build_tree(lexed, parser_output);
-        let root = SyntaxNode::new_root(green.clone());
-
-        errors.extend(validation::validate(&root));
-
-        if root.kind() == SyntaxKind::LITERAL {
-            Some(Parse {
-                green,
-                errors: if errors.is_empty() { None } else { Some(errors.into()) },
-                _ty: PhantomData,
-            })
-        } else {
-            None
-        }
-    }
-}
-
 impl ast::TokenTree {
     pub fn reparse_as_comma_separated_expr(self) -> Parse<ast::MacroEagerInput> {
         let tokens = self.syntax().descendants_with_tokens().filter_map(NodeOrToken::into_token);