From a343fb1d79b9afb93f912fb721b834345d1cb9fc Mon Sep 17 00:00:00 2001
From: Steve Russo <64294847+sjrusso8@users.noreply.github.com>
Date: Mon, 22 Apr 2024 09:19:50 -0400
Subject: [PATCH 1/3] feat(catalog): implement additional methods

---
 src/catalog.rs | 331 ++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 317 insertions(+), 14 deletions(-)

diff --git a/src/catalog.rs b/src/catalog.rs
index 14a5e9e..f7af12c 100644
--- a/src/catalog.rs
+++ b/src/catalog.rs
@@ -6,6 +6,7 @@ use crate::errors::SparkError;
 use crate::plan::LogicalPlanBuilder;
 use crate::session::SparkSession;
 use crate::spark;
+use crate::storage::StorageLevel;
 
 #[derive(Debug, Clone)]
 pub struct Catalog {
@@ -17,6 +18,17 @@ impl Catalog {
         Self { spark_session }
     }
 
+    fn arrow_to_bool(record: RecordBatch) -> Result<bool, SparkError> {
+        let col = record.column(0);
+
+        let data: &arrow::array::BooleanArray = match col.data_type() {
+            arrow::datatypes::DataType::Boolean => col.as_any().downcast_ref().unwrap(),
+            _ => unimplemented!("only Boolean data types are currently handled"),
+        };
+
+        Ok(data.value(0))
+    }
+
     /// Returns the current default catalog in this session
     #[allow(non_snake_case)]
     pub async fn currentCatalog(self) -> Result<String, SparkError> {
@@ -31,18 +43,19 @@ impl Catalog {
         self.spark_session.client().to_first_value(plan).await
     }
 
-    /// Returns the current default database in this session
     #[allow(non_snake_case)]
-    pub async fn currentDatabase(self) -> Result<String, SparkError> {
-        let cat_type = Some(spark::catalog::CatType::CurrentDatabase(
-            spark::CurrentDatabase {},
+    pub async fn setCurrentCatalog(self, catalogName: &str) -> Result<(), SparkError> {
+        let cat_type = Some(spark::catalog::CatType::SetCurrentCatalog(
+            spark::SetCurrentCatalog {
+                catalog_name: catalogName.to_string(),
+            },
         ));
 
         let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
 
         let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
 
-        self.spark_session.client().to_first_value(plan).await
+        self.spark_session.client().execute_command(plan).await
     }
 
     /// Returns a list of catalogs in this session
@@ -61,6 +74,35 @@ impl Catalog {
         self.spark_session.client().to_arrow(plan).await
     }
 
+    /// Returns the current default database in this session
+    #[allow(non_snake_case)]
+    pub async fn currentDatabase(self) -> Result<String, SparkError> {
+        let cat_type = Some(spark::catalog::CatType::CurrentDatabase(
+            spark::CurrentDatabase {},
+        ));
+
+        let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
+
+        let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
+
+        self.spark_session.client().to_first_value(plan).await
+    }
+
+    #[allow(non_snake_case)]
+    pub async fn setCurrentDatabase(self, dbName: &str) -> Result<(), SparkError> {
+        let cat_type = Some(spark::catalog::CatType::SetCurrentDatabase(
+            spark::SetCurrentDatabase {
+                db_name: dbName.to_string(),
+            },
+        ));
+
+        let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
+
+        let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
+
+        self.spark_session.client().execute_command(plan).await
+    }
+
     /// Returns a list of databases in this session
     #[allow(non_snake_case)]
     pub async fn listDatabases(self, pattern: Option<&str>) -> Result<RecordBatch, SparkError> {
@@ -77,6 +119,19 @@ impl Catalog {
         self.spark_session.client().to_arrow(plan).await
     }
 
+    #[allow(non_snake_case)]
+    pub async fn getDatabase(self, dbName: &str) -> Result<RecordBatch, SparkError> {
+        let cat_type = Some(spark::catalog::CatType::GetDatabase(spark::GetDatabase {
+            db_name: dbName.to_string(),
+        }));
+
+        let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
+
+        let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
+
+        self.spark_session.client().to_arrow(plan).await
+    }
+
     /// Returns a list of tables/views in the specific database
     #[allow(non_snake_case)]
     pub async fn listTables(
         self,
         dbName: Option<&str>,
         pattern: Option<&str>,
     ) -> Result<RecordBatch, SparkError> {
         let cat_type = Some(spark::catalog::CatType::ListTables(spark::ListTables {
             db_name: dbName.map(|db| db.to_owned()),
             pattern: pattern.map(|val| val.to_owned()),
         }));
@@ -96,6 +151,76 @@ impl Catalog {
         self.spark_session.client().to_arrow(plan).await
     }
 
+    #[allow(non_snake_case)]
+    pub async fn getTable(self, tableName: &str) -> Result<RecordBatch, SparkError> {
+        let cat_type = Some(spark::catalog::CatType::GetTable(spark::GetTable {
+            table_name: tableName.to_string(),
+            db_name: None,
+        }));
+
+        let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
+
+        let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
+
+        self.spark_session.client().to_arrow(plan).await
+    }
+
+    #[allow(non_snake_case)]
+    pub async fn listFunctions(
+        self,
+        dbName: Option<&str>,
+        pattern: Option<&str>,
+    ) -> Result<RecordBatch, SparkError> {
+        let cat_type = Some(spark::catalog::CatType::ListFunctions(
+            spark::ListFunctions {
+                db_name: dbName.map(|val| val.to_owned()),
+                pattern: pattern.map(|val| val.to_owned()),
+            },
+        ));
+
+        let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
+
+        let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
+
+        self.spark_session.client().to_arrow(plan).await
+    }
+
+    #[allow(non_snake_case)]
+    pub async fn functionExists(
+        self,
+        functionName: &str,
+        dbName: Option<&str>,
+    ) -> Result<bool, SparkError> {
+        let cat_type = Some(spark::catalog::CatType::FunctionExists(
+            spark::FunctionExists {
+                function_name: functionName.to_string(),
+                db_name: dbName.map(|val| val.to_owned()),
+            },
+        ));
+
+        let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
+
+        let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
+
+        let record = self.spark_session.client().to_arrow(plan).await?;
+
+        Catalog::arrow_to_bool(record)
+    }
+
+    #[allow(non_snake_case)]
+    pub async fn getFunction(self, functionName: &str) -> Result<RecordBatch, SparkError> {
+        let cat_type = Some(spark::catalog::CatType::GetFunction(spark::GetFunction {
+            function_name: functionName.to_string(),
+            db_name: None,
+        }));
+
+        let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
+
+        let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
+
+        self.spark_session.client().to_arrow(plan).await
+    }
+
     /// Returns a list of columns for the given tables/views in the specific database
     #[allow(non_snake_case)]
     pub async fn listColumns(
@@ -114,6 +239,158 @@ impl Catalog {
         self.spark_session.client().to_arrow(plan).await
     }
+
+    #[allow(non_snake_case)]
+    pub async fn tableExists(
+        self,
+        tableName: &str,
+        dbName: Option<&str>,
+    ) -> Result<bool, SparkError> {
+        let cat_type = Some(spark::catalog::CatType::TableExists(spark::TableExists {
+            table_name: tableName.to_string(),
+            db_name: dbName.map(|val| val.to_owned()),
+        }));
+
+        let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
+
+        let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
+
+        let record = self.spark_session.client().to_arrow(plan).await?;
+
+        Catalog::arrow_to_bool(record)
+    }
+
+    #[allow(non_snake_case)]
+    pub async fn dropTempView(self, viewName: &str) -> Result<bool, SparkError> {
+        let cat_type = Some(spark::catalog::CatType::DropTempView(spark::DropTempView {
+            view_name: viewName.to_string(),
+        }));
+
+        let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
+
+        let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
+
+        let record = self.spark_session.client().to_arrow(plan).await?;
+
+        Catalog::arrow_to_bool(record)
+    }
+
+    #[allow(non_snake_case)]
+    pub async fn dropGlobalTempView(self, viewName: &str) -> Result<bool, SparkError> {
+        let cat_type = Some(spark::catalog::CatType::DropGlobalTempView(
+            spark::DropGlobalTempView {
+                view_name: viewName.to_string(),
+            },
+        ));
+
+        let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
+
+        let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
+
+        let record = self.spark_session.client().to_arrow(plan).await?;
+
+        Catalog::arrow_to_bool(record)
+    }
+
+    #[allow(non_snake_case)]
+    pub async fn isCached(self, tableName: &str) -> Result<bool, SparkError> {
+        let cat_type = Some(spark::catalog::CatType::IsCached(spark::IsCached {
+            table_name: tableName.to_string(),
+        }));
+
+        let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
+
+        let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
+
+        let record = self.spark_session.client().to_arrow(plan).await?;
+
+        Catalog::arrow_to_bool(record)
+    }
+
+    #[allow(non_snake_case)]
+    pub async fn cachedTable(
+        self,
+        tableName: &str,
+        storageLevel: Option<StorageLevel>,
+    ) -> Result<(), SparkError> {
+        let cat_type = Some(spark::catalog::CatType::CacheTable(spark::CacheTable {
+            table_name: tableName.to_string(),
+            storage_level: storageLevel.map(|val| val.to_owned().into()),
+        }));
+
+        let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
+
+        let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
+
+        self.spark_session.client().execute_command(plan).await
+    }
+
+    #[allow(non_snake_case)]
+    pub async fn uncacheTable(self, tableName: &str) -> Result<(), SparkError> {
+        let cat_type = Some(spark::catalog::CatType::UncacheTable(spark::UncacheTable {
+            table_name: tableName.to_string(),
+        }));
+
+        let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
+
+        let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
+
+        self.spark_session.client().execute_command(plan).await
+    }
+
+    #[allow(non_snake_case)]
+    pub async fn clearCache(self) -> Result<(), SparkError> {
+        let cat_type = Some(spark::catalog::CatType::ClearCache(spark::ClearCache {}));
+
+        let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
+
+        let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
+
+        self.spark_session.client().execute_command(plan).await
+    }
+
+    #[allow(non_snake_case)]
+    pub async fn refreshTable(self, tableName: &str) -> Result<(), SparkError> {
+        let cat_type = Some(spark::catalog::CatType::RefreshTable(spark::RefreshTable {
+            table_name: tableName.to_string(),
+        }));
+
+        let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
+
+        let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
+
+        self.spark_session.client().execute_command(plan).await
+    }
+
+    #[allow(non_snake_case)]
+    pub async fn recoverPartitions(self, tableName: &str) -> Result<(), SparkError> {
+        let cat_type = Some(spark::catalog::CatType::RecoverPartitions(
+            spark::RecoverPartitions {
+                table_name: tableName.to_string(),
+            },
+        ));
+
+        let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
+
+        let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
+
+        self.spark_session.client().execute_command(plan).await
+    }
+
+    #[allow(non_snake_case)]
+    pub async fn refreshByPath(self, path: &str) -> Result<(), SparkError> {
+        let cat_type = Some(spark::catalog::CatType::RefreshByPath(
+            spark::RefreshByPath {
+                path: path.to_string(),
+            },
+        ));
+
+        let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
+
+        let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
+
+        self.spark_session.client().execute_command(plan).await
+    }
 }
 
 #[cfg(test)]
@@ -144,26 +421,25 @@ mod tests {
         assert_eq!(value, "spark_catalog".to_string());
         Ok(())
     }
-
     #[tokio::test]
-    async fn test_current_database() -> Result<(), SparkError> {
+    async fn test_list_catalogs() -> Result<(), SparkError> {
         let spark = setup().await;
 
-        let value = spark.catalog().currentDatabase().await?;
+        let value = spark.catalog().listCatalogs(None).await?;
+
+        assert_eq!(2, value.num_columns());
+        assert_eq!(1, value.num_rows());
 
-        assert_eq!(value, "default".to_string());
         Ok(())
     }
 
     #[tokio::test]
-    async fn test_list_catalogs() -> Result<(), SparkError> {
+    async fn test_current_database() -> Result<(), SparkError> {
         let spark = setup().await;
 
-        let value = spark.catalog().listCatalogs(None).await?;
-
-        assert_eq!(2, value.num_columns());
-        assert_eq!(1, value.num_rows());
+        let value = spark.catalog().currentDatabase().await?;
 
+        assert_eq!(value, "default".to_string());
         Ok(())
     }
 
@@ -185,6 +461,33 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_get_database() -> Result<(), SparkError> {
+        let spark = setup().await;
+
+        spark
+            .clone()
+            .sql("CREATE SCHEMA IF NOT EXISTS spark_rust")
+            .await?;
+
+        let values = spark.catalog().getDatabase("spark_rust").await?;
+
+        println!("{:?}", values);
+
+        assert_eq!(1, 1);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_function_exists() -> Result<(), SparkError> {
+        let spark = setup().await;
+
+        let res = spark.catalog().functionExists("len", None).await?;
+
+        assert!(res);
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_list_databases_pattern() -> Result<(), SparkError> {
         let spark = setup().await;

From 08e33667bd3b539e75d2bc75d91a77ddf8430e6f Mon Sep 17 00:00:00 2001
From: Steve Russo <64294847+sjrusso8@users.noreply.github.com>
Date: Tue, 23 Apr 2024 10:24:43 -0400
Subject: [PATCH 2/3] unit tests

---
 src/catalog.rs | 188 ++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 172 insertions(+), 16 deletions(-)

diff --git a/src/catalog.rs b/src/catalog.rs
index f7af12c..a75223e 100644
--- a/src/catalog.rs
+++ b/src/catalog.rs
@@ -132,18 +132,37 @@ impl Catalog {
         self.spark_session.client().to_arrow(plan).await
     }
 
+    #[allow(non_snake_case)]
+    pub async fn databaseExists(self, dbName: &str) -> Result<bool, SparkError> {
+        let cat_type = Some(spark::catalog::CatType::DatabaseExists(
+            spark::DatabaseExists {
+                db_name: dbName.to_string(),
+            },
+        ));
+
+        let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
+
+        let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
+
+        let record = self.spark_session.client().to_arrow(plan).await?;
+
+        Catalog::arrow_to_bool(record)
+    }
+
     /// Returns a list of tables/views in the specific database
     #[allow(non_snake_case)]
     pub async fn listTables(
         self,
-        dbName: Option<&str>,
         pattern: Option<&str>,
+        dbName: Option<&str>,
     ) -> Result<RecordBatch, SparkError> {
         let cat_type = Some(spark::catalog::CatType::ListTables(spark::ListTables {
             db_name: dbName.map(|db| db.to_owned()),
             pattern: pattern.map(|val| val.to_owned()),
         }));
 
+        println!("{:?}", cat_type);
+
         let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
 
         let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
@@ -308,7 +327,7 @@ impl Catalog {
     }
 
     #[allow(non_snake_case)]
-    pub async fn cachedTable(
+    pub async fn cacheTable(
         self,
         tableName: &str,
         storageLevel: Option<StorageLevel>,
@@ -404,7 +423,7 @@ mod tests {
     async fn setup() -> SparkSession {
         println!("SparkSession Setup");
 
-        let connection = "sc://127.0.0.1:15002/;user_id=rust_catalog;session_id=f93c9562-cb73-473c-add4-c73a236e50dc";
+        let connection = "sc://127.0.0.1:15002/;user_id=rust_catalog";
 
         SparkSessionBuilder::remote(connection)
             .build()
@@ -421,6 +440,31 @@ mod tests {
         assert_eq!(value, "spark_catalog".to_string());
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_set_current_catalog() -> Result<(), SparkError> {
+        let spark = setup().await;
+
+        spark.catalog().setCurrentCatalog("spark_catalog").await?;
+
+        assert!(true);
+        Ok(())
+    }
+
+    #[tokio::test]
+    #[should_panic]
+    async fn test_set_current_catalog_panic() -> () {
+        let spark = setup().await;
+
+        spark
+            .catalog()
+            .setCurrentCatalog("not_a_real_catalog")
+            .await
+            .unwrap();
+
+        ()
+    }
+
     #[tokio::test]
     async fn test_list_catalogs() -> Result<(), SparkError> {
         let spark = setup().await;
@@ -443,6 +487,34 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_set_current_database() -> Result<(), SparkError> {
+        let spark = setup().await;
+
+        spark
+            .clone()
+            .sql("CREATE SCHEMA IF NOT EXISTS spark_rust_db")
+            .await?;
+
+        spark.catalog().setCurrentDatabase("spark_rust_db").await?;
+
+        assert!(true);
+        Ok(())
+    }
+
+    #[tokio::test]
+    #[should_panic]
+    async fn test_set_current_database_panic() -> () {
+        let spark = setup().await;
+
+        spark
+            .catalog()
+            .setCurrentDatabase("not_a_real_db")
+            .await
+            .unwrap();
+
+        ()
+    }
     #[tokio::test]
     async fn test_list_databases() -> Result<(), SparkError> {
         let spark = setup().await;
@@ -453,10 +525,15 @@ mod tests {
             .await
             .unwrap();
 
-        let value = spark.catalog().listDatabases(None).await?;
+        let res = spark.clone().catalog().listDatabases(None).await?;
+
+        assert_eq!(4, res.num_columns());
+        assert_eq!(2, res.num_rows());
+
+        let res = spark.catalog().listDatabases(Some("*rust")).await?;
 
-        assert_eq!(4, value.num_columns());
-        assert_eq!(2, value.num_rows());
+        assert_eq!(4, res.num_columns());
+        assert_eq!(1, res.num_rows());
 
         Ok(())
     }
@@ -470,11 +547,23 @@ mod tests {
             .sql("CREATE SCHEMA IF NOT EXISTS spark_rust")
             .await?;
 
-        let values = spark.catalog().getDatabase("spark_rust").await?;
+        let res = spark.catalog().getDatabase("spark_rust").await?;
 
-        println!("{:?}", values);
+        assert_eq!(res.num_rows(), 1);
+        Ok(())
+    }
 
-        assert_eq!(1, 1);
+    #[tokio::test]
+    async fn test_database_exists() -> Result<(), SparkError> {
+        let spark = setup().await;
+
+        let res = spark.clone().catalog().databaseExists("default").await?;
+
+        assert!(res);
+
+        let res = spark.clone().catalog().databaseExists("not_real").await?;
+
+        assert!(!res);
         Ok(())
     }
 
@@ -489,20 +578,87 @@ mod tests {
         let res = spark.catalog().functionExists("len", None).await?;
 
         assert!(res);
         Ok(())
     }
 
     #[tokio::test]
-    async fn test_list_databases_pattern() -> Result<(), SparkError> {
+    async fn test_list_columns() -> Result<(), SparkError> {
         let spark = setup().await;
 
+        spark.clone().sql("DROP TABLE IF EXISTS tmp_table").await?;
+
         spark
             .clone()
-            .sql("CREATE SCHEMA IF NOT EXISTS spark_rust")
-            .await
-            .unwrap();
+            .sql("CREATE TABLE tmp_table (name STRING, age INT) using parquet")
+            .await?;
 
-        let value = spark.catalog().listDatabases(Some("*rust")).await?;
+        let res = spark
+            .clone()
+            .catalog()
+            .listColumns("tmp_table", None)
+            .await?;
 
-        assert_eq!(4, value.num_columns());
-        assert_eq!(1, value.num_rows());
+        assert_eq!(res.num_rows(), 2);
+
+        spark.clone().sql("DROP TABLE IF EXISTS tmp_table").await?;
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_drop_view() -> Result<(), SparkError> {
+        let spark = setup().await;
+
+        spark
+            .clone()
+            .range(None, 2, 1, Some(1))
+            .createOrReplaceGlobalTempView("tmp_view")
+            .await?;
+
+        let res = spark
+            .clone()
+            .catalog()
+            .dropGlobalTempView("tmp_view")
+            .await?;
+
+        assert!(res);
+
+        spark
+            .clone()
+            .range(None, 2, 1, Some(1))
+            .createOrReplaceTempView("tmp_view")
+            .await?;
+
+        let res = spark.catalog().dropTempView("tmp_view").await?;
+
+        assert!(res);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_cache_table() -> Result<(), SparkError> {
+        let spark = setup().await;
+
+        spark.clone().sql("DROP TABLE IF EXISTS tmp_table").await?;
+
+        spark
+            .clone()
+            .sql("CREATE TABLE tmp_table (name STRING, age INT) using parquet")
+            .await?;
+
+        spark
+            .clone()
+            .catalog()
+            .cacheTable("tmp_table", None)
+            .await?;
+
+        let res = spark.clone().catalog().isCached("tmp_table").await?;
+
+        assert!(res);
+
+        spark.clone().catalog().uncacheTable("tmp_table").await?;
+
+        let res = spark.clone().catalog().isCached("tmp_table").await?;
+
+        assert!(!res);
+
+        spark.sql("DROP TABLE IF EXISTS tmp_table").await?;
+        Ok(())
+    }
 }

From e123e04b0ccd51dfa9a9b0e9aadb6d47341940a7 Mon Sep 17 00:00:00 2001
From: Steve Russo <64294847+sjrusso8@users.noreply.github.com>
Date: Tue, 23 Apr 2024 10:52:38 -0400
Subject: [PATCH 3/3] remove line

---
 src/catalog.rs | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/catalog.rs b/src/catalog.rs
index a75223e..02973aa 100644
--- a/src/catalog.rs
+++ b/src/catalog.rs
@@ -161,8 +161,6 @@ impl Catalog {
             pattern: pattern.map(|val| val.to_owned()),
         }));
 
-        println!("{:?}", cat_type);
-
         let rel_type = spark::relation::RelType::Catalog(spark::Catalog { cat_type });
 
         let plan = LogicalPlanBuilder::plan_root(LogicalPlanBuilder::from(rel_type));
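
Reviewer note: a minimal usage sketch of the Catalog methods this series adds.
The `SparkSessionBuilder::remote` builder, the `catalog()` accessor, and the
local endpoint are taken from the tests above; the `spark_connect_rs` crate
path, the `user_id` value, and the error handling are assumptions for
illustration, not part of the diffs.

    use spark_connect_rs::{SparkSession, SparkSessionBuilder};

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Same local Spark Connect endpoint the unit tests assume.
        let spark: SparkSession =
            SparkSessionBuilder::remote("sc://127.0.0.1:15002/;user_id=example")
                .build()
                .await?;

        // Query-style methods return data decoded from an Arrow batch.
        let current = spark.clone().catalog().currentCatalog().await?;
        println!("current catalog: {current}");

        // Existence checks are decoded to a plain bool via arrow_to_bool.
        let exists = spark.clone().catalog().databaseExists("default").await?;
        assert!(exists);

        // Command-style methods (set*, cache*, refresh*) return Result<(), _>.
        spark.catalog().setCurrentDatabase("default").await?;

        Ok(())
    }

Every Catalog method takes `self` by value, so the session is cloned before
each call, the same pattern the unit tests above use.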