diff --git a/src/models/capture_group_patterns.rs b/src/models/capture_group_patterns.rs index 1929d6bee5..946dcac5a3 100644 --- a/src/models/capture_group_patterns.rs +++ b/src/models/capture_group_patterns.rs @@ -14,6 +14,7 @@ Copyright (c) 2023 Uber Technologies, Inc. use crate::{ models::Validator, utilities::{ + regex_utilities::get_all_matches_for_regex, tree_sitter_utilities::{get_all_matches_for_query, get_ts_query_parser, number_of_errors}, Instantiate, }, @@ -38,12 +39,20 @@ impl CGPattern { pub(crate) fn pattern(&self) -> String { self.0.to_string() } + + pub(crate) fn extract_regex(&self) -> String { + let mut _val = &self.pattern()[4..]; + _val.to_string() + } } impl Validator for CGPattern { fn validate(&self) -> Result<(), String> { if self.pattern().starts_with("rgx ") { - panic!("Regex not supported") + let mut _val = &self.pattern()[4..]; + return Regex::new(_val) + .map(|_| Ok(())) + .unwrap_or(Err(format!("Cannot parse the regex - {_val}"))); } let mut parser = get_ts_query_parser(); parser @@ -103,7 +112,9 @@ impl CompiledCGPattern { replace_node, replace_node_idx, ), - CompiledCGPattern::R(_) => panic!("Regex is not yet supported!!!"), + CompiledCGPattern::R(regex) => { + get_all_matches_for_regex(node, source_code, regex, recursive, replace_node) + } } } } diff --git a/src/models/matches.rs b/src/models/matches.rs index 202ce8ac52..193169ddab 100644 --- a/src/models/matches.rs +++ b/src/models/matches.rs @@ -55,6 +55,18 @@ pub(crate) struct Match { gen_py_str_methods!(Match); impl Match { + pub(crate) fn from_regex( + mtch: ®ex::Match, matches: HashMap, source_code: &str, + ) -> Self { + Match { + matched_string: mtch.as_str().to_string(), + range: Range::from_regex_match(mtch, source_code), + matches, + associated_comma: None, + associated_comments: Vec::new(), + } + } + pub(crate) fn new( matched_string: String, range: tree_sitter::Range, matches: HashMap, ) -> Self { @@ -231,7 +243,7 @@ impl Match { serde_derive::Serialize, Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, )] #[pyclass] -struct Range { +pub(crate) struct Range { #[pyo3(get)] start_byte: usize, #[pyo3(get)] @@ -260,6 +272,32 @@ impl From for Range { } gen_py_str_methods!(Range); +impl Range { + pub(crate) fn from_regex_match(mtch: ®ex::Match, source_code: &str) -> Self { + Self { + start_byte: mtch.start(), + end_byte: mtch.end(), + start_point: position_for_offset(source_code.as_bytes(), mtch.start()), + end_point: position_for_offset(source_code.as_bytes(), mtch.end()), + } + } +} + +// Finds the position (col and row number) for a given offset. +// Copied from tree-sitter tests [https://github.com/tree-sitter/tree-sitter/blob/d0029a15273e526925a764033e9b7f18f96a7ce5/cli/src/parse.rs#L364] +fn position_for_offset(input: &[u8], offset: usize) -> Point { + let mut result = Point { row: 0, column: 0 }; + for c in &input[0..offset] { + if *c as char == '\n' { + result.row += 1; + result.column = 0; + } else { + result.column += 1; + } + } + result +} + /// A range of positions in a multi-line text document, both in terms of bytes and of /// rows and columns. #[derive( diff --git a/src/models/rule_store.rs b/src/models/rule_store.rs index d91780fc92..9de4c2b5f9 100644 --- a/src/models/rule_store.rs +++ b/src/models/rule_store.rs @@ -79,7 +79,10 @@ impl RuleStore { pub(crate) fn query(&mut self, cg_pattern: &CGPattern) -> &CompiledCGPattern { let pattern = cg_pattern.pattern(); if pattern.starts_with("rgx ") { - panic!("Regex not supported.") + return &*self + .rule_query_cache + .entry(pattern) + .or_insert_with(|| CompiledCGPattern::R(Regex::new(&cg_pattern.extract_regex()).unwrap())); } &*self diff --git a/src/models/unit_tests/rule_graph_validation_test.rs b/src/models/unit_tests/rule_graph_validation_test.rs index f12d9880a9..9102ff837f 100644 --- a/src/models/unit_tests/rule_graph_validation_test.rs +++ b/src/models/unit_tests/rule_graph_validation_test.rs @@ -118,13 +118,3 @@ fn test_filter_bad_arg_contains_n_sibling() { .sibling_count(2) .build(); } - -#[test] -#[should_panic(expected = "Regex not supported")] -fn test_unsupported_regex() { - RuleGraphBuilder::default() - .rules(vec![ - piranha_rule! {name = "Test rule", query = "rgx (\\w+) (\\w)+"}, - ]) - .build(); -} diff --git a/src/tests/test_piranha_java.rs b/src/tests/test_piranha_java.rs index 6340b8d9d4..56c3a023e8 100644 --- a/src/tests/test_piranha_java.rs +++ b/src/tests/test_piranha_java.rs @@ -66,6 +66,7 @@ create_rewrite_tests! { test_new_line_character_used_in_string_literal: "new_line_character_used_in_string_literal", 1; test_java_delete_method_invocation_argument: "delete_method_invocation_argument", 1; test_java_delete_method_invocation_argument_no_op: "delete_method_invocation_argument_no_op", 0; + test_regex_based_matcher: "regex_based_matcher", 1; } create_match_tests! { diff --git a/src/utilities/mod.rs b/src/utilities/mod.rs index 7035b8567e..7ca6df0296 100644 --- a/src/utilities/mod.rs +++ b/src/utilities/mod.rs @@ -11,6 +11,7 @@ Copyright (c) 2023 Uber Technologies, Inc. limitations under the License. */ +pub(crate) mod regex_utilities; pub(crate) mod tree_sitter_utilities; use std::collections::HashMap; use std::error::Error; diff --git a/src/utilities/regex_utilities.rs b/src/utilities/regex_utilities.rs new file mode 100644 index 0000000000..c4ea25e4b4 --- /dev/null +++ b/src/utilities/regex_utilities.rs @@ -0,0 +1,70 @@ +/* + Copyright (c) 2023 Uber Technologies, Inc. + +

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file + except in compliance with the License. You may obtain a copy of the License at +

http://www.apache.org/licenses/LICENSE-2.0 + +

Unless required by applicable law or agreed to in writing, software distributed under the + License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + express or implied. See the License for the specific language governing permissions and + limitations under the License. +*/ + +use crate::models::matches::Match; +use itertools::Itertools; +use regex::Regex; +use std::collections::HashMap; +use tree_sitter::Node; + +/// Applies the query upon the given `node`, and gets all the matches +/// # Arguments +/// * `node` - the root node to apply the query upon +/// * `source_code` - the corresponding source code string for the node. +/// * `recursive` - if `true` it matches the query to `self` and `self`'s sub-ASTs, else it matches the `query` only to `self`. +/// * `replace_node` - node to replace +/// +/// # Returns +/// The range of the match in the source code and the corresponding mapping from tags to code snippets. +pub(crate) fn get_all_matches_for_regex( + node: &Node, source_code: String, regex: &Regex, recursive: bool, replace_node: Option, +) -> Vec { + // let code_snippet = node.utf8_text(source_code.as_bytes()).unwrap(); + let all_captures = regex.captures_iter(&source_code).collect_vec(); + let names = regex.capture_names().collect_vec(); + let mut all_matches = vec![]; + for captures in all_captures { + // Check if the range of the self (node), and the range of outermost node captured by the query are equal. + let range_matches_node = node.start_byte() == captures.get(0).unwrap().start() + && node.end_byte() == captures.get(0).unwrap().end(); + let range_matches_inside_node = node.start_byte() <= captures.get(0).unwrap().start() + && node.end_byte() >= captures.get(0).unwrap().end(); + if (recursive && range_matches_inside_node) || range_matches_node { + let group_by_tag = if let Some(ref rn) = replace_node { + captures + .name(rn) + .unwrap_or_else(|| panic!("the tag {rn} provided in the replace node is not present")) + } else { + captures.get(0).unwrap() + }; + let matches = extract_captures(&captures, &names); + all_matches.push(Match::from_regex(&group_by_tag, matches, &source_code)); + } + } + all_matches +} + +// Creates an hashmap from the capture group(name) to the corresponding code snippet. +fn extract_captures( + captures: ®ex::Captures<'_>, names: &Vec>, +) -> HashMap { + names + .iter() + .flatten() + .flat_map(|x| { + captures + .name(x) + .map(|v| (x.to_string(), v.as_str().to_string())) + }) + .collect() +} diff --git a/src/utilities/tree_sitter_utilities.rs b/src/utilities/tree_sitter_utilities.rs index 6cbd5832b6..176d0872af 100644 --- a/src/utilities/tree_sitter_utilities.rs +++ b/src/utilities/tree_sitter_utilities.rs @@ -57,7 +57,7 @@ pub(crate) fn get_all_matches_for_query( // If `recursive` it allows matches to the subtree of self (Node) // Else it ensure that the query perfectly matches the node (`self`). if recursive || range_matches_self { - let mut replace_node_range = captured_node_range; + let mut replace_node_range: Range = captured_node_range; if let Some(replace_node_name) = &replace_node { if let Some(r) = get_range_for_replace_node(query, &query_matches, replace_node_name, replace_node_idx) diff --git a/test-resources/java/regex_based_matcher/configurations/edges.toml b/test-resources/java/regex_based_matcher/configurations/edges.toml new file mode 100644 index 0000000000..183e25ebe2 --- /dev/null +++ b/test-resources/java/regex_based_matcher/configurations/edges.toml @@ -0,0 +1,20 @@ +# Copyright (c) 2023 Uber Technologies, Inc. +# +#

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file +# except in compliance with the License. You may obtain a copy of the License at +#

http://www.apache.org/licenses/LICENSE-2.0 +# +#

Unless required by applicable law or agreed to in writing, software distributed under the +# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing permissions and +# limitations under the License. + +[[edges]] +scope = "File" +from = "update_import" +to = ["update_list_int"] + +[[edges]] +scope = "Method" +from = "update_list_int" +to = ["update_add"] diff --git a/test-resources/java/regex_based_matcher/configurations/rules.toml b/test-resources/java/regex_based_matcher/configurations/rules.toml new file mode 100644 index 0000000000..0ea6f676c9 --- /dev/null +++ b/test-resources/java/regex_based_matcher/configurations/rules.toml @@ -0,0 +1,59 @@ +# Copyright (c) 2023 Uber Technologies, Inc. +# +#

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file +# except in compliance with the License. You may obtain a copy of the License at +#

http://www.apache.org/licenses/LICENSE-2.0 +# +#

Unless required by applicable law or agreed to in writing, software distributed under the +# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing permissions and +# limitations under the License. + +# Replace foo().bar().baz() with `true` inside methods not nnotated as @DoNotCleanup +[[rules]] +name = "replace_call" +query = """rgx (?Pfoo\\(\\)\\.bar\\(\\)\\.baz\\(\\))""" +replace_node = "n1" +replace = "true" +groups = ["replace_expression_with_boolean_literal"] +[[rules.filters]] +enclosing_node = """(method_declaration) @md""" +not_contains = ["""rgx @DoNotCleanup"""] + +# Before: +# abc().def().ghi() +# abc().fed().ghi() +[[rules]] +name = "replace_call_def_fed" +query = """rgx (?Pabc\\(\\)\\.(?Pdef)\\(\\)\\.ghi\\(\\))""" +replace_node = "m_def" +replace = "fed" + + +# The below three rules do a dummy type migration from List to NewList + +# Updates the import statement from `java.util.List` to `com.uber.NEwList` +[[rules]] +name = "update_import" +query = """rgx (?Pjava\\.util\\.List)""" +replace_node = "n" +replace = "com.uber.NewList" + +# Updates the type of local variables from `List` to `com.uber.NewList` +[[rules]] +name = "update_list_int" +query = """rgx (?P(?PList)\\s*(?P\\w+)\\s*=.*;)""" +replace_node = "type" +replace = "NewList" +is_seed_rule = false +[[rules.filter]] +enclosing_node = "(method_declaration) @cmd" + +# Updates the relevant callsite from `add` to `addToNewList` +[[rules]] +name = "update_add" +query = """rgx (?P@name\\.(?Padd)\\(\\w+\\))""" +replace_node = "m_name" +replace = "addToNewList" +holes = ["name"] +is_seed_rule = false diff --git a/test-resources/java/regex_based_matcher/expected/Sample.java b/test-resources/java/regex_based_matcher/expected/Sample.java new file mode 100644 index 0000000000..a20ab4428f --- /dev/null +++ b/test-resources/java/regex_based_matcher/expected/Sample.java @@ -0,0 +1,34 @@ +package com.uber.piranha; + +import com.uber.NewList; + +class A { + + void foobar() { + System.out.println("Hello World!"); + System.out.println(true); + } + + @DoNotCleanup + void barfn() { + boolean b = foo().bar().baz(); + System.out.println(b); + } + + void foofn() { + int total = abc().fed().ghi(); + } + + void someTypeChange() { + // Will get updated + NewList a = getList(); + Integer item = getItem(); + a.addToNewList(item); + + // Will not get updated + List b = getListStr(); + Integer item = getItemStr(); + b.add(item); + } + +} diff --git a/test-resources/java/regex_based_matcher/input/Sample.java b/test-resources/java/regex_based_matcher/input/Sample.java new file mode 100644 index 0000000000..01d8b22ca7 --- /dev/null +++ b/test-resources/java/regex_based_matcher/input/Sample.java @@ -0,0 +1,37 @@ +package com.uber.piranha; + +import java.util.List; + +class A { + + void foobar() { + boolean b = foo().bar().baz(); + if (b) { + System.out.println("Hello World!"); + } + System.out.println(b); + } + + @DoNotCleanup + void barfn() { + boolean b = foo().bar().baz(); + System.out.println(b); + } + + void foofn() { + int total = abc().def().ghi(); + } + + void someTypeChange() { + // Will get updated + List a = getList(); + Integer item = getItem(); + a.add(item); + + // Will not get updated + List b = getListStr(); + Integer item = getItemStr(); + b.add(item); + } + +}