diff --git a/sqlness/Cargo.toml b/sqlness/Cargo.toml index ca0345e..5314b3c 100644 --- a/sqlness/Cargo.toml +++ b/sqlness/Cargo.toml @@ -12,10 +12,12 @@ readme = { workspace = true } [dependencies] async-trait = "0.1" derive_builder = "0.11" +minijinja = "1" mysql = { version = "23.0.1", optional = true } postgres = { version = "0.19.7", optional = true } prettydiff = { version = "0.6.2", default_features = false } regex = "1.7.1" +serde_json = "1" thiserror = "1.0" toml = "0.5" walkdir = "2.3" diff --git a/sqlness/examples/interceptor-replace/simple/replace.result b/sqlness/examples/interceptor-replace/simple/replace.result index 60d17d6..138b705 100644 --- a/sqlness/examples/interceptor-replace/simple/replace.result +++ b/sqlness/examples/interceptor-replace/simple/replace.result @@ -19,3 +19,29 @@ SELECT 1; 03/14/2012, 01/01/2013 and 07/05/2014; +-- SQLNESS TEMPLATE {"name": "test"} +SELECT * FROM table where name = "{{name}}"; + +SELECT * FROM table where name = "test"; + +-- SQLNESS TEMPLATE {"aggr": ["sum", "avg", "count"]} +{% for item in aggr %} +SELECT {{item}}(c) from t {%if not loop.last %} {{sql_delimiter()}} {% endif %} +{% endfor %} +; + +SELECT sum(c) from t ; + + SELECT avg(c) from t ; + + SELECT count(c) from t ; + +-- SQLNESS TEMPLATE +INSERT INTO t (c) VALUES +{% for num in range(1, 5) %} +({{ num }}) {%if not loop.last %} , {% endif %} +{% endfor %} +; + +INSERT INTO t (c) VALUES(1) , (2) , (3) , (4) ; + diff --git a/sqlness/examples/interceptor-replace/simple/replace.sql b/sqlness/examples/interceptor-replace/simple/replace.sql index f47f540..290c98d 100644 --- a/sqlness/examples/interceptor-replace/simple/replace.sql +++ b/sqlness/examples/interceptor-replace/simple/replace.sql @@ -10,3 +10,19 @@ SELECT 0; -- example of capture group replacement -- SQLNESS REPLACE (?P\d{4})-(?P\d{2})-(?P\d{2}) $m/$d/$y 2012-03-14, 2013-01-01 and 2014-07-05; + +-- SQLNESS TEMPLATE {"name": "test"} +SELECT * FROM table where name = "{{name}}"; + +-- SQLNESS TEMPLATE {"aggr": ["sum", "avg", "count"]} +{% for item in aggr %} +SELECT {{item}}(c) from t {%if not loop.last %} {{sql_delimiter()}} {% endif %} +{% endfor %} +; + +-- SQLNESS TEMPLATE +INSERT INTO t (c) VALUES +{% for num in range(1, 5) %} +({{ num }}) {%if not loop.last %} , {% endif %} +{% endfor %} +; diff --git a/sqlness/src/case.rs b/sqlness/src/case.rs index ed581eb..328145f 100644 --- a/sqlness/src/case.rs +++ b/sqlness/src/case.rs @@ -16,6 +16,7 @@ use crate::{ }; const COMMENT_PREFIX: &str = "--"; +const QUERY_DELIMITER: char = ';'; pub(crate) struct TestCase { name: String, @@ -55,7 +56,7 @@ impl TestCase { query.append_query_line(&line); // SQL statement ends with ';' - if line.ends_with(';') { + if line.ends_with(QUERY_DELIMITER) { queries.push(query); query = Query::with_interceptor_factories(cfg.interceptor_factories.clone()); } else { @@ -88,7 +89,7 @@ impl Display for TestCase { } /// A String-to-String map used as query context. -#[derive(Default, Debug)] +#[derive(Default, Debug, Clone)] pub struct QueryContext { pub context: HashMap, } @@ -137,14 +138,27 @@ impl Query { W: Write, { let context = self.before_execute_intercept(); - - let mut result = db - .query(context, self.concat_query_lines()) - .await - .to_string(); - - self.after_execute_intercept(&mut result); - self.write_result(writer, result)?; + for comment in &self.comment_lines { + writer.write_all(comment.as_bytes())?; + writer.write_all("\n".as_bytes())?; + } + for comment in &self.display_query { + writer.write_all(comment.as_bytes())?; + } + writer.write_all("\n\n".as_bytes())?; + + let sql = self.concat_query_lines(); + // An intercetor may generate multiple SQLs, so we need to split them. + for sql in sql.split(QUERY_DELIMITER) { + if !sql.trim().is_empty() { + let mut result = db + .query(context.clone(), format!("{sql};")) + .await + .to_string(); + self.after_execute_intercept(&mut result); + self.write_result(writer, result)?; + } + } Ok(()) } @@ -183,14 +197,6 @@ impl Query { where W: Write, { - for comment in &self.comment_lines { - writer.write_all(comment.as_bytes())?; - writer.write("\n".as_bytes())?; - } - for line in &self.display_query { - writer.write_all(line.as_bytes())?; - } - writer.write("\n\n".as_bytes())?; writer.write_all(result.as_bytes())?; writer.write("\n\n".as_bytes())?; diff --git a/sqlness/src/interceptor.rs b/sqlness/src/interceptor.rs index 094e998..2f30c25 100644 --- a/sqlness/src/interceptor.rs +++ b/sqlness/src/interceptor.rs @@ -8,7 +8,7 @@ use crate::{ case::QueryContext, interceptor::{ arg::ArgInterceptorFactory, env::EnvInterceptorFactory, replace::ReplaceInterceptorFactory, - sort_result::SortResultInterceptorFactory, + sort_result::SortResultInterceptorFactory, template::TemplateInterceptorFactory, }, }; @@ -16,6 +16,7 @@ pub mod arg; pub mod env; pub mod replace; pub mod sort_result; +pub mod template; pub type InterceptorRef = Box; @@ -40,5 +41,6 @@ pub fn builtin_interceptors() -> Vec { Arc::new(ReplaceInterceptorFactory {}), Arc::new(EnvInterceptorFactory {}), Arc::new(SortResultInterceptorFactory {}), + Arc::new(TemplateInterceptorFactory {}), ] } diff --git a/sqlness/src/interceptor/template.rs b/sqlness/src/interceptor/template.rs new file mode 100644 index 00000000..c66b00e --- /dev/null +++ b/sqlness/src/interceptor/template.rs @@ -0,0 +1,159 @@ +// Copyright 2024 CeresDB Project Authors. Licensed under Apache-2.0. + +use minijinja::Environment; +use serde_json::Value; + +use super::{Interceptor, InterceptorFactory, InterceptorRef}; + +pub struct TemplateInterceptorFactory; + +const PREFIX: &str = "TEMPLATE"; + +/// Templated query, powered by [minijinja](https://github.com/mitsuhiko/minijinja). +/// The template syntax can be found [here](https://docs.rs/minijinja/latest/minijinja/syntax/index.html). +/// +/// Grammar: +/// ``` text +/// -- SQLNESS TEMPLATE +/// ``` +/// +/// `json` define data bindings passed to template, it should be a valid JSON string. +/// +/// # Example +/// `.sql` file: +/// ``` sql +/// -- SQLNESS TEMPLATE {"name": "test"} +/// SELECT * FROM table where name = "{{name}}" +/// ``` +/// +/// `.result` file: +/// ``` sql +/// -- SQLNESS TEMPLATE {"name": "test"} +/// SELECT * FROM table where name = "test"; +/// ``` +/// +/// In order to generate multiple queries, you can use the builtin function +/// `sql_delimiter()` to insert a delimiter. +/// +#[derive(Debug)] +pub struct TemplateInterceptor { + json_ctx: String, +} + +fn sql_delimiter() -> Result { + Ok(";".to_string()) +} + +impl Interceptor for TemplateInterceptor { + fn before_execute(&self, execute_query: &mut Vec, _context: &mut crate::QueryContext) { + let input = execute_query.join("\n"); + let mut env = Environment::new(); + env.add_function("sql_delimiter", sql_delimiter); + env.add_template("sql", &input).unwrap(); + let tmpl = env.get_template("sql").unwrap(); + let bindings: Value = if self.json_ctx.is_empty() { + serde_json::from_str("{}").unwrap() + } else { + serde_json::from_str(&self.json_ctx).unwrap() + }; + let rendered = tmpl.render(bindings).unwrap(); + *execute_query = rendered + .split('\n') + .map(|v| v.to_string()) + .collect::>(); + } + + fn after_execute(&self, _result: &mut String) {} +} + +impl InterceptorFactory for TemplateInterceptorFactory { + fn try_new(&self, interceptor: &str) -> Option { + Self::try_new_from_str(interceptor).map(|i| Box::new(i) as _) + } +} + +impl TemplateInterceptorFactory { + fn try_new_from_str(interceptor: &str) -> Option { + if interceptor.starts_with(PREFIX) { + let json_ctx = interceptor.trim_start_matches(PREFIX).to_string(); + Some(TemplateInterceptor { json_ctx }) + } else { + None + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn basic_template() { + let interceptor = TemplateInterceptorFactory + .try_new(r#"TEMPLATE {"name": "test"}"#) + .unwrap(); + + let mut input = vec!["SELECT * FROM table where name = '{{name}}'".to_string()]; + interceptor.before_execute(&mut input, &mut crate::QueryContext::default()); + + assert_eq!(input, vec!["SELECT * FROM table where name = 'test'"]); + } + + #[test] + fn vector_template() { + let interceptor = TemplateInterceptorFactory + .try_new(r#"TEMPLATE {"aggr": ["sum", "count", "avg"]}"#) + .unwrap(); + + let mut input = [ + "{%- for item in aggr %}", + "SELECT {{item}}(c) from t;", + "{%- endfor %}", + ] + .map(|v| v.to_string()) + .to_vec(); + interceptor.before_execute(&mut input, &mut crate::QueryContext::default()); + + assert_eq!( + input, + [ + "", + "SELECT sum(c) from t;", + "SELECT count(c) from t;", + "SELECT avg(c) from t;" + ] + .map(|v| v.to_string()) + .to_vec() + ); + } + + #[test] + fn range_template() { + let interceptor = TemplateInterceptorFactory.try_new(r#"TEMPLATE"#).unwrap(); + + let mut input = [ + "INSERT INTO t (c) VALUES", + "{%- for num in range(1, 5) %}", + "({{ num }}){%if not loop.last %}, {% endif %}", + "{%- endfor %}", + ";", + ] + .map(|v| v.to_string()) + .to_vec(); + interceptor.before_execute(&mut input, &mut crate::QueryContext::default()); + + assert_eq!( + input, + [ + "INSERT INTO t (c) VALUES", + "(1), ", + "(2), ", + "(3), ", + "(4)", + ";" + ] + .map(|v| v.to_string()) + .to_vec() + ); + } +}