diff --git a/Cargo.lock b/Cargo.lock index 4922b93f..a4e6153d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" @@ -3220,6 +3220,7 @@ name = "synth" version = "0.6.9" dependencies = [ "anyhow", + "async-lock 2.8.0", "async-std", "async-trait", "backtrace", diff --git a/core/src/db_utils.rs b/core/src/db_utils.rs index f38d0a3b..ba3bb303 100644 --- a/core/src/db_utils.rs +++ b/core/src/db_utils.rs @@ -3,4 +3,5 @@ use uriparse::URI; pub struct DataSourceParams<'a> { pub uri: URI<'a>, pub schema: Option<String>, // PostgreSQL + pub concurrency: usize, } diff --git a/core/src/schema/content/datasource.rs b/core/src/schema/content/datasource.rs index f92caf61..a58b6b0b 100644 --- a/core/src/schema/content/datasource.rs +++ b/core/src/schema/content/datasource.rs @@ -16,6 +16,7 @@ impl Compile for DatasourceContent { let params = DataSourceParams { uri: URI::try_from(self.path.as_str())?, schema: None, + concurrency: 1, }; let iter = get_iter(params).map(|i| -> Box<dyn Iterator<Item = Value>> { if !self.cycle { diff --git a/synth/Cargo.toml b/synth/Cargo.toml index ec7e53c6..c81140c1 100644 --- a/synth/Cargo.toml +++ b/synth/Cargo.toml @@ -51,6 +51,7 @@ strsim = "0.10.0" async-std = { version = "1.12", features = [ "attributes", "unstable" ] } async-trait = "0.1.50" +async-lock = "2.6" futures = "0.3.15" fs2 = "0.4.3" diff --git a/synth/benches/bench.rs b/synth/benches/bench.rs index d5edf83c..286fbf1b 100644 --- a/synth/benches/bench.rs +++ b/synth/benches/bench.rs @@ -29,6 +29,7 @@ fn bench_generate_n_to_stdout(size: usize) { seed: Some(0), random: false, schema: None, + concurrency: 3, }); let output = io::stdout(); Cli::new().unwrap().run(args, output).await.unwrap() diff --git a/synth/src/cli/export.rs b/synth/src/cli/export.rs index b013c643..8471209f 100644 --- a/synth/src/cli/export.rs +++ b/synth/src/cli/export.rs @@ -63,12 +63,15 @@ where
"postgres" | "postgresql" => Box::new(PostgresExportStrategy { uri_string: params.uri.to_string(), schema: params.schema, + concurrency: params.concurrency, }), "mongodb" => Box::new(MongoExportStrategy { uri_string: params.uri.to_string(), + concurrency: params.concurrency, }), "mysql" | "mariadb" => Box::new(MySqlExportStrategy { uri_string: params.uri.to_string(), + concurrency: params.concurrency, }), "json" => { if params.uri.path() == "" { diff --git a/synth/src/cli/mod.rs b/synth/src/cli/mod.rs index f12ad7fd..c989ae6c 100644 --- a/synth/src/cli/mod.rs +++ b/synth/src/cli/mod.rs @@ -123,6 +123,7 @@ impl<'w> Cli { uri: URI::try_from(cmd.from.as_str()) .with_context(|| format!("Parsing import URI '{}'", cmd.from))?, schema: cmd.schema, + concurrency: 1, } .try_into()?; @@ -176,6 +177,7 @@ impl<'w> Cli { uri: URI::try_from(cmd.to.as_str()) .with_context(|| format!("Parsing generation URI '{}'", cmd.to))?, schema: cmd.schema, + concurrency: cmd.concurrency, } .try_into()?; @@ -289,6 +291,13 @@ pub struct GenerateCommand { )] #[serde(skip)] pub schema: Option<String>, + #[structopt( long, help = "The maximum number of concurrent tasks writing to the database.", default_value = "3" )] + #[serde(skip)] + pub concurrency: usize, } #[derive(StructOpt, Serialize)] diff --git a/synth/src/cli/mongo.rs b/synth/src/cli/mongo.rs index af313480..f750815a 100644 --- a/synth/src/cli/mongo.rs +++ b/synth/src/cli/mongo.rs @@ -20,6 +20,7 @@ use synth_core::{Content, Namespace, Value}; #[derive(Clone, Debug)] pub struct MongoExportStrategy { pub uri_string: String, + pub concurrency: usize, } #[derive(Clone, Debug)] @@ -152,7 +153,12 @@ fn bson_to_content(bson: &Bson) -> Content { impl ExportStrategy for MongoExportStrategy { fn export(&self, _namespace: Namespace, sample: SamplerOutput) -> Result<()> { - let mut client = Client::with_uri_str(&self.uri_string)?; + let mut client_options = ClientOptions::parse(&self.uri_string)?; + client_options.max_pool_size =
Some(self.concurrency.try_into().unwrap()); + + info!("Connecting to database at {} ...", &self.uri_string); + + let mut client = Client::with_options(client_options)?; match sample { SamplerOutput::Collection(name, value) => { diff --git a/synth/src/cli/mysql.rs b/synth/src/cli/mysql.rs index 91ed7905..07702622 100644 --- a/synth/src/cli/mysql.rs +++ b/synth/src/cli/mysql.rs @@ -1,7 +1,7 @@ use crate::cli::export::{create_and_insert_values, ExportStrategy}; use crate::cli::import::ImportStrategy; use crate::cli::import_utils::build_namespace_import; -use crate::datasource::mysql_datasource::MySqlDataSource; +use crate::datasource::mysql_datasource::{MySqlConnectParams, MySqlDataSource}; use crate::datasource::DataSource; use crate::sampler::SamplerOutput; use anyhow::Result; @@ -10,11 +10,17 @@ use synth_core::schema::Namespace; #[derive(Clone, Debug)] pub struct MySqlExportStrategy { pub uri_string: String, + pub concurrency: usize, } impl ExportStrategy for MySqlExportStrategy { fn export(&self, _namespace: Namespace, sample: SamplerOutput) -> Result<()> { - let datasource = MySqlDataSource::new(&self.uri_string)?; + let connect_params = MySqlConnectParams { + uri: self.uri_string.clone(), + concurrency: self.concurrency, + }; + + let datasource = MySqlDataSource::new(&connect_params)?; create_and_insert_values(sample, &datasource) } @@ -27,7 +33,11 @@ pub struct MySqlImportStrategy { impl ImportStrategy for MySqlImportStrategy { fn import(&self) -> Result<Namespace> { - let datasource = MySqlDataSource::new(&self.uri_string)?; + let connect_params = MySqlConnectParams { + uri: self.uri_string.clone(), + concurrency: 1, + }; + let datasource = MySqlDataSource::new(&connect_params)?; build_namespace_import(&datasource) } diff --git a/synth/src/cli/postgres.rs b/synth/src/cli/postgres.rs index c69d5da6..806d2ca4 100644 --- a/synth/src/cli/postgres.rs +++ b/synth/src/cli/postgres.rs @@ -11,6 +11,7 @@ use synth_core::schema::Namespace; pub struct PostgresExportStrategy { pub
uri_string: String, pub schema: Option<String>, + pub concurrency: usize, } impl ExportStrategy for PostgresExportStrategy { @@ -18,6 +19,7 @@ impl ExportStrategy for PostgresExportStrategy { let connect_params = PostgresConnectParams { uri: self.uri_string.clone(), schema: self.schema.clone(), + concurrency: self.concurrency, }; let datasource = PostgresDataSource::new(&connect_params)?; @@ -37,6 +39,7 @@ impl ImportStrategy for PostgresImportStrategy { let connect_params = PostgresConnectParams { uri: self.uri_string.clone(), schema: self.schema.clone(), + concurrency: 1, }; let datasource = PostgresDataSource::new(&connect_params)?; diff --git a/synth/src/datasource/mysql_datasource.rs b/synth/src/datasource/mysql_datasource.rs index 516c858e..31d3dd1b 100644 --- a/synth/src/datasource/mysql_datasource.rs +++ b/synth/src/datasource/mysql_datasource.rs @@ -22,22 +22,31 @@ use synth_gen::prelude::*; /// - MySql aliases bool and boolean data types as tinyint. We currently define all tinyint as i8. /// Ideally, the user can define a way to force certain fields as bool rather than i8. +pub struct MySqlConnectParams { + pub(crate) uri: String, + pub(crate) concurrency: usize, +} + pub struct MySqlDataSource { pool: Pool<MySql>, + concurrency: usize, } #[async_trait] impl DataSource for MySqlDataSource { - type ConnectParams = String; + type ConnectParams = MySqlConnectParams; fn new(connect_params: &Self::ConnectParams) -> Result<Self> { task::block_on(async { let pool = MySqlPoolOptions::new() - .max_connections(3) //TODO expose this as a user config?
- .connect(connect_params.as_str()) + .max_connections(connect_params.concurrency.try_into().unwrap()) + .connect(connect_params.uri.as_str()) .await?; - Ok::<Self, anyhow::Error>(MySqlDataSource { pool }) + Ok::<Self, anyhow::Error>(MySqlDataSource { + pool, + concurrency: connect_params.concurrency, + }) }) } @@ -53,6 +62,17 @@ impl SqlxDataSource for MySqlDataSource { const IDENTIFIER_QUOTE: char = '`'; + fn clone(&self) -> Self { + Self { + pool: Pool::clone(&self.pool), + concurrency: self.concurrency, + } + } + + fn get_concurrency(&self) -> usize { + self.concurrency + } + fn get_pool(&self) -> Pool<MySql> { Pool::clone(&self.pool) } diff --git a/synth/src/datasource/postgres_datasource.rs b/synth/src/datasource/postgres_datasource.rs index 12578289..25db63f6 100644 --- a/synth/src/datasource/postgres_datasource.rs +++ b/synth/src/datasource/postgres_datasource.rs @@ -21,12 +21,14 @@ use synth_core::{Content, Value}; pub struct PostgresConnectParams { pub(crate) uri: String, pub(crate) schema: Option<String>, + pub(crate) concurrency: usize, } pub struct PostgresDataSource { pool: Pool<Postgres>, single_thread_pool: Pool<Postgres>, schema: String, // consider adding a type schema + concurrency: usize, } #[async_trait] impl DataSource for PostgresDataSource { @@ -42,7 +44,7 @@ let mut arc = Arc::new(schema.clone()); let pool = PgPoolOptions::new() - .max_connections(3) //TODO expose this as a user config?
+ .max_connections(connect_params.concurrency.try_into().unwrap()) .after_connect(move |conn, _meta| { let schema = arc.clone(); Box::pin(async move { @@ -76,6 +78,7 @@ pool, single_thread_pool, schema, + concurrency: connect_params.concurrency, }) }) } @@ -119,6 +122,19 @@ impl SqlxDataSource for PostgresDataSource { const IDENTIFIER_QUOTE: char = '\"'; + fn clone(&self) -> Self { + Self { + pool: Pool::clone(&self.pool), + single_thread_pool: Pool::clone(&self.single_thread_pool), + schema: self.schema.clone(), + concurrency: self.concurrency, + } + } + + fn get_concurrency(&self) -> usize { + self.concurrency + } + fn get_pool(&self) -> Pool<Postgres> { Pool::clone(&self.single_thread_pool) } diff --git a/synth/src/datasource/relational_datasource.rs b/synth/src/datasource/relational_datasource.rs index 99670b69..020a0071 100644 --- a/synth/src/datasource/relational_datasource.rs +++ b/synth/src/datasource/relational_datasource.rs @@ -1,11 +1,14 @@ use crate::datasource::DataSource; use anyhow::Result; +use async_lock::Semaphore; +use async_std::task; use async_trait::async_trait; use beau_collector::BeauCollector; use futures::future::join_all; use sqlx::{ query::Query, Arguments, Connection, Database, Encode, Executor, IntoArguments, Pool, Type, }; +use std::sync::Arc; use synth_core::{Content, Value}; use synth_gen::value::Number; @@ -52,12 +55,17 @@ pub trait SqlxDataSource: DataSource { const IDENTIFIER_QUOTE: char; + fn clone(&self) -> Self; /// Gets a pool to execute queries with fn get_pool(&self) -> Pool<Self::DB>; /// Gets a multithread pool to execute queries with fn get_multithread_pool(&self) -> Pool<Self::DB>; + /// Get the maximum concurrency for the data source + fn get_concurrency(&self) -> usize; + /// Prepare a single query with data source specifics fn query<'q>(&self, query: &'q str) -> Query<'q, Self::DB, Self::Arguments> { sqlx::query(query) } @@ -117,8 +125,8 @@ ) -> Result<<Self::DB as Database>::QueryResult>
where for<'c> &'c mut Self::Connection: Executor<'c, Database = Self::DB>, - Value: Type<Self::DB>, - for<'d> Value: Encode<'d, Self::DB>, + Value: Type<Self::DB> + Send + Sync, + for<'d> Value: Encode<'d, Self::DB> + Send + Sync, { let mut query = sqlx::query::<Self::DB>(query.as_str()); @@ -132,7 +140,7 @@ } } -pub async fn insert_relational_data<T: SqlxDataSource>( +pub async fn insert_relational_data<T: SqlxDataSource + Send + Sync + 'static>( datasource: &T, collection_name: &str, collection: &[Value], @@ -146,6 +154,7 @@ where for<'d> Value: Encode<'d, T::DB>, { let batch_size = DEFAULT_INSERT_BATCH_SIZE; + let max_concurrency = datasource.get_concurrency(); if collection.is_empty() { println!("Collection {collection_name} generated 0 values. Skipping insertion...",); @@ -208,9 +217,19 @@ .collect::<Vec<String>>() .join(","); - let mut futures = Vec::with_capacity(collection.len()); + let collection_chunks = collection.chunks(batch_size); + let mut futures = Vec::with_capacity(collection_chunks.len()); + + info!( + "Inserting {} rows for {}...", + collection.len(), + collection_name + ); + + let semaphore = Arc::new(Semaphore::new(max_concurrency)); + for rows in collection_chunks { + let permit = semaphore.clone().acquire_arc().await; - for rows in collection.chunks(batch_size) { let table_name = datasource.get_table_name_for_insert(collection_name); let mut query = format!("INSERT INTO {table_name} ({column_names}) VALUES \n"); @@ -233,7 +252,13 @@ query.push_str(",\n"); } } - let future = datasource.execute_query(query, query_params); + let datasource = datasource.clone(); + let future = task::spawn(async move { + let result = datasource.execute_query(query, query_params).await; + drop(permit); + result + }); + futures.push(future); } diff --git a/synth/tests/helpers/mod.rs b/synth/tests/helpers/mod.rs index 2a89d656..f8020250 100644 --- a/synth/tests/helpers/mod.rs +++ b/synth/tests/helpers/mod.rs @@ -16,6 +16,7 @@ pub async fn generate_scenario(namespace: &str, scenario: Option<String>) -> Res  seed: Some(5), size:
10, to: "json:".to_string(), + concurrency: 3, })) .await }