From 10a862b4044874e576d5cee9166b7a1e217b765b Mon Sep 17 00:00:00 2001 From: Ameya Ketkar <94497232+ketkarameya@users.noreply.github.com> Date: Wed, 5 Jul 2023 07:43:45 -0700 Subject: [PATCH] Introduce CompiledCGPatterns capturing the TS-Query and regex (placeholder) ghstack-source-id: cf2baf162fcbfc57e097e1fbc6e90c5ae2869b74 Pull Request resolved: https://github.com/uber/piranha/pull/527 --- src/models/capture_group_patterns.rs | 61 +++++++++++++++++++++++++- src/models/filter.rs | 21 +++------ src/models/matches.rs | 10 ++--- src/models/rule_store.rs | 20 ++++++--- src/models/scopes.rs | 9 +--- src/models/source_code_unit.rs | 12 ++--- src/utilities/tree_sitter_utilities.rs | 22 ---------- 7 files changed, 86 insertions(+), 69 deletions(-) diff --git a/src/models/capture_group_patterns.rs b/src/models/capture_group_patterns.rs index 9ac618a4ab..4495bd6f9b 100644 --- a/src/models/capture_group_patterns.rs +++ b/src/models/capture_group_patterns.rs @@ -14,13 +14,17 @@ Copyright (c) 2023 Uber Technologies, Inc. use crate::{ models::Validator, utilities::{ - tree_sitter_utilities::{get_ts_query_parser, number_of_errors}, + tree_sitter_utilities::{get_all_matches_for_query, get_ts_query_parser, number_of_errors}, Instantiate, }, }; use pyo3::prelude::pyclass; +use regex::Regex; use serde_derive::Deserialize; use std::collections::HashMap; +use tree_sitter::{Node, Query}; + +use super::matches::Match; #[pyclass] #[derive(Deserialize, Debug, Clone, Default, PartialEq, Hash, Eq)] @@ -38,12 +42,18 @@ impl CGPattern { impl Validator for CGPattern { fn validate(&self) -> Result<(), String> { + if self.pattern().starts_with("rgx ") { + panic!("Regex not supported") + } let mut parser = get_ts_query_parser(); parser .parse(self.pattern(), None) .filter(|x| number_of_errors(&x.root_node()) == 0) .map(|_| Ok(())) - .unwrap_or(Err(format!("Cannot parse - {}", self.pattern()))) + .unwrap_or(Err(format!( + "Cannot parse the tree-sitter query - {}", + self.pattern() + ))) } } @@ -56,3 +66,50 @@ impl Instantiate for CGPattern { CGPattern::new(self.pattern().instantiate(&substitutions)) } } + +#[derive(Debug)] +pub(crate) enum CompiledCGPattern { + Q(Query), + R(Regex), // Regex is not yet supported +} + +impl CompiledCGPattern { + /// Applies the query upon the given node, and gets all the matches + /// # Arguments + /// * `node` - the root node to apply the query upon + /// * `source_code` - the corresponding source code string for the node. + /// * `query` - the query to be applied + /// * `recursive` - if `true` it matches the query to `self` and `self`'s sub-ASTs, else it matches the `query` only to `self`. + /// + /// # Returns + /// A vector of `tuples` containing the range of the matches in the source code and the corresponding mapping for the tags (to code snippets). + /// By default it returns the range of the outermost node for each query match. + /// If `replace_node` is provided in the rule, it returns the range of the node corresponding to that tag. + pub(crate) fn get_match(&self, node: &Node, source_code: &str, recursive: bool) -> Option { + if let Some(m) = self + .get_matches(node, source_code.to_string(), recursive, None, None) + .first() + { + return Some(m.clone()); + } + None + } + + /// Applies the pattern upon the given `node`, and gets all the matches + pub(crate) fn get_matches( + &self, node: &Node, source_code: String, recursive: bool, replace_node: Option, + replace_node_idx: Option, + ) -> Vec { + match self { + CompiledCGPattern::Q(query) => get_all_matches_for_query( + node, + source_code, + query, + recursive, + replace_node, + replace_node_idx, + ), + CompiledCGPattern::R(_) => panic!("Regex is not yet supported!!!"), + } + } +} diff --git a/src/models/filter.rs b/src/models/filter.rs index 962df6ebd1..f8ebbcf415 100644 --- a/src/models/filter.rs +++ b/src/models/filter.rs @@ -22,10 +22,7 @@ use pyo3::prelude::{pyclass, pymethods}; use serde_derive::Deserialize; use tree_sitter::Node; -use crate::utilities::{ - gen_py_str_methods, - tree_sitter_utilities::{get_all_matches_for_query, get_match_for_query, get_node_for_range}, -}; +use crate::utilities::{gen_py_str_methods, tree_sitter_utilities::get_node_for_range}; use super::{ capture_group_patterns::CGPattern, default_configs::default_child_count, @@ -415,9 +412,8 @@ impl SourceCodeUnit { } while let Some(parent) = current_node.parent() { - if let Some(p_match) = - get_match_for_query(&parent, self.code(), rule_store.query(ts_query), false) - { + let pattern = rule_store.query(ts_query); + if let Some(p_match) = pattern.get_match(&parent, self.code(), false) { let matched_ancestor = get_node_for_range( self.root_node(), p_match.range().start_byte, @@ -442,14 +438,7 @@ impl SourceCodeUnit { // Retrieve all matches within the ancestor node let contains_query = &rule_store.query(filter.contains()); - let matches = get_all_matches_for_query( - ancestor, - self.code().to_string(), - contains_query, - true, - None, - None, - ); + let matches = contains_query.get_matches(ancestor, self.code().to_string(), true, None, None); let at_least = filter.at_least as usize; let at_most = filter.at_most as usize; // Validate if the count of matches falls within the expected range @@ -464,7 +453,7 @@ impl SourceCodeUnit { // Check if there's a match within the scope node // If one of the filters is not satisfied, return false let query = &rule_store.query(ts_query); - if get_match_for_query(ancestor, self.code(), query, true).is_some() { + if query.get_match(ancestor, self.code(), true).is_some() { return false; } } diff --git a/src/models/matches.rs b/src/models/matches.rs index bf77cb95fe..202ce8ac52 100644 --- a/src/models/matches.rs +++ b/src/models/matches.rs @@ -20,10 +20,7 @@ use pyo3::prelude::{pyclass, pymethods}; use serde_derive::{Deserialize, Serialize}; use tree_sitter::Node; -use crate::utilities::{ - gen_py_str_methods, - tree_sitter_utilities::{get_all_matches_for_query, get_node_for_range}, -}; +use crate::utilities::{gen_py_str_methods, tree_sitter_utilities::get_node_for_range}; use super::{ piranha_arguments::PiranhaArguments, rule::InstantiatedRule, rule_store::RuleStore, @@ -291,10 +288,11 @@ impl SourceCodeUnit { } else { (rule.replace_node(), rule.replace_idx()) }; - let mut all_query_matches = get_all_matches_for_query( + + let pattern = rule_store.query(&rule.query()); + let mut all_query_matches = pattern.get_matches( &node, self.code().to_string(), - rule_store.query(&rule.query()), recursive, replace_node_tag, replace_node_idx, diff --git a/src/models/rule_store.rs b/src/models/rule_store.rs index fe239da8c8..d91780fc92 100644 --- a/src/models/rule_store.rs +++ b/src/models/rule_store.rs @@ -22,21 +22,22 @@ use itertools::Itertools; use jwalk::WalkDir; use log::{debug, trace}; use regex::Regex; -use tree_sitter::Query; use crate::{ models::capture_group_patterns::CGPattern, models::piranha_arguments::PiranhaArguments, models::scopes::ScopeQueryGenerator, utilities::read_file, }; -use super::{language::PiranhaLanguage, rule::InstantiatedRule}; +use super::{ + capture_group_patterns::CompiledCGPattern, language::PiranhaLanguage, rule::InstantiatedRule, +}; use glob::Pattern; /// This maintains the state for Piranha. #[derive(Debug, Getters, Default)] pub(crate) struct RuleStore { // Caches the compiled tree-sitter queries. - rule_query_cache: HashMap, + rule_query_cache: HashMap, // Current global rules to be applied. #[get = "pub"] global_rules: Vec, @@ -75,11 +76,16 @@ impl RuleStore { /// Get the compiled query for the `query_str` from the cache /// else compile it, add it to the cache and return it. - pub(crate) fn query(&mut self, query_str: &CGPattern) -> &Query { - self + pub(crate) fn query(&mut self, cg_pattern: &CGPattern) -> &CompiledCGPattern { + let pattern = cg_pattern.pattern(); + if pattern.starts_with("rgx ") { + panic!("Regex not supported.") + } + + &*self .rule_query_cache - .entry(query_str.pattern()) - .or_insert_with(|| self.language.create_query(query_str.pattern())) + .entry(pattern.to_string()) + .or_insert_with(|| CompiledCGPattern::Q(self.language.create_query(pattern))) } // For the given scope level, get the ScopeQueryGenerator from the `scope_config.toml` file diff --git a/src/models/scopes.rs b/src/models/scopes.rs index c494f74352..7393d1776b 100644 --- a/src/models/scopes.rs +++ b/src/models/scopes.rs @@ -13,7 +13,6 @@ Copyright (c) 2023 Uber Technologies, Inc. use super::capture_group_patterns::CGPattern; use super::{rule_store::RuleStore, source_code_unit::SourceCodeUnit}; -use crate::utilities::tree_sitter_utilities::get_match_for_query; use crate::utilities::tree_sitter_utilities::get_node_for_range; use crate::utilities::Instantiate; use derive_builder::Builder; @@ -65,12 +64,8 @@ impl SourceCodeUnit { changed_node.kind() ); for m in &scope_enclosing_nodes { - if let Some(p_match) = get_match_for_query( - &changed_node, - self.code(), - rules_store.query(m.enclosing_node()), - false, - ) { + let pattern = rules_store.query(m.enclosing_node()); + if let Some(p_match) = pattern.get_match(&changed_node, self.code(), false) { // Generate the scope query for the specific context by substituting the // the tags with code snippets appropriately in the `generator` query. return m.scope().instantiate(p_match.matches()); diff --git a/src/models/source_code_unit.rs b/src/models/source_code_unit.rs index 93230262a0..71c173e074 100644 --- a/src/models/source_code_unit.rs +++ b/src/models/source_code_unit.rs @@ -25,8 +25,7 @@ use crate::{ models::capture_group_patterns::CGPattern, models::rule_graph::{GLOBAL, PARENT}, utilities::tree_sitter_utilities::{ - get_match_for_query, get_node_for_range, get_replace_range, get_tree_sitter_edit, - number_of_errors, + get_node_for_range, get_replace_range, get_tree_sitter_edit, number_of_errors, }, }; @@ -299,13 +298,8 @@ impl SourceCodeUnit { // let mut scope_node = self.root_node(); if let Some(query_str) = scope_query { // Apply the scope query in the source code and get the appropriate node - let tree_sitter_scope_query = rules_store.query(query_str); - if let Some(p_match) = get_match_for_query( - &self.root_node(), - self.code(), - tree_sitter_scope_query, - true, - ) { + let scope_pattern = rules_store.query(query_str); + if let Some(p_match) = scope_pattern.get_match(&self.root_node(), self.code(), true) { return get_node_for_range( self.root_node(), p_match.range().start_byte, diff --git a/src/utilities/tree_sitter_utilities.rs b/src/utilities/tree_sitter_utilities.rs index 97ef334a3b..6cbd5832b6 100644 --- a/src/utilities/tree_sitter_utilities.rs +++ b/src/utilities/tree_sitter_utilities.rs @@ -25,28 +25,6 @@ use std::collections::HashMap; use tree_sitter::{InputEdit, Node, Parser, Point, Query, QueryCapture, QueryCursor, Range}; use tree_sitter_traversal::{traverse, Order}; -/// Applies the query upon the given node, and gets all the matches -/// # Arguments -/// * `node` - the root node to apply the query upon -/// * `source_code` - the corresponding source code string for the node. -/// * `query` - the query to be applied -/// * `recursive` - if `true` it matches the query to `self` and `self`'s sub-ASTs, else it matches the `query` only to `self`. -/// -/// # Returns -/// A vector of `tuples` containing the range of the matches in the source code and the corresponding mapping for the tags (to code snippets). -/// By default it returns the range of the outermost node for each query match. -/// If `replace_node` is provided in the rule, it returns the range of the node corresponding to that tag. -pub(crate) fn get_match_for_query( - node: &Node, source_code: &str, query: &Query, recursive: bool, -) -> Option { - if let Some(m) = - get_all_matches_for_query(node, source_code.to_string(), query, recursive, None, None).first() - { - return Some(m.clone()); - } - None -} - /// Applies the query upon the given `node`, and gets the first match /// # Arguments /// * `node` - the root node to apply the query upon