diff --git a/src/catalog.rs b/src/catalog.rs index 02973aa..4a550de 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -489,14 +489,18 @@ mod tests { async fn test_set_current_database() -> Result<(), SparkError> { let spark = setup().await; + spark.clone().sql("CREATE SCHEMA current_db").await?; + spark .clone() - .sql("CREATE SCHEMA IF NOT EXISTS spark_rust_db") + .catalog() + .setCurrentDatabase("current_db") .await?; - spark.catalog().setCurrentDatabase("spark_rust_db").await?; - assert!(true); + + spark.clone().sql("DROP SCHEMA current_db").await?; + Ok(()) } @@ -513,41 +517,19 @@ mod tests { () } - #[tokio::test] - async fn test_list_databases() -> Result<(), SparkError> { - let spark = setup().await; - - spark - .clone() - .sql("CREATE SCHEMA IF NOT EXISTS spark_rust") - .await - .unwrap(); - - let res = spark.clone().catalog().listDatabases(None).await?; - - assert_eq!(4, res.num_columns()); - assert_eq!(2, res.num_rows()); - - let res = spark.catalog().listDatabases(Some("*rust")).await?; - - assert_eq!(4, res.num_columns()); - assert_eq!(1, res.num_rows()); - - Ok(()) - } #[tokio::test] async fn test_get_database() -> Result<(), SparkError> { let spark = setup().await; - spark - .clone() - .sql("CREATE SCHEMA IF NOT EXISTS spark_rust") - .await?; + spark.clone().sql("CREATE SCHEMA get_db").await?; - let res = spark.catalog().getDatabase("spark_rust").await?; + let res = spark.clone().catalog().getDatabase("get_db").await?; assert_eq!(res.num_rows(), 1); + + spark.clone().sql("DROP SCHEMA get_db").await?; + Ok(()) } @@ -633,30 +615,28 @@ mod tests { async fn test_cache_table() -> Result<(), SparkError> { let spark = setup().await; - spark.clone().sql("DROP TABLE IF EXISTS tmp_table").await?; - spark .clone() - .sql("CREATE TABLE tmp_table (name STRING, age INT) using parquet") + .sql("CREATE TABLE cache_table (name STRING, age INT) using parquet") .await?; spark .clone() .catalog() - .cacheTable("tmp_table", None) + .cacheTable("cache_table", None) .await?; - let res = spark.clone().catalog().isCached("tmp_table").await?; + let res = spark.clone().catalog().isCached("cache_table").await?; assert!(res); - spark.clone().catalog().uncacheTable("tmp_table").await?; + spark.clone().catalog().uncacheTable("cache_table").await?; - let res = spark.clone().catalog().isCached("tmp_table").await?; + let res = spark.clone().catalog().isCached("cache_table").await?; assert!(!res); - spark.sql("DROP TABLE IF EXISTS tmp_table").await?; + spark.sql("DROP TABLE cache_table").await?; Ok(()) } } diff --git a/src/client/mod.rs b/src/client/mod.rs index 8a8b3f6..583b228 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,10 +1,10 @@ use std::collections::HashMap; +use std::env; use std::str::FromStr; use std::sync::Arc; use crate::errors::SparkError; use crate::spark; -use crate::SparkSession; use spark::execute_plan_response::ResponseType; use spark::spark_connect_service_client::SparkConnectServiceClient; @@ -14,26 +14,29 @@ use arrow::error::ArrowError; use arrow::record_batch::RecordBatch; use arrow_ipc::reader::StreamReader; -use parking_lot::Mutex; +use parking_lot::RwLock; use tonic::codegen::{Body, Bytes, StdError}; use tonic::metadata::{ Ascii, AsciiMetadataValue, KeyAndValueRef, MetadataKey, MetadataMap, MetadataValue, }; use tonic::service::Interceptor; -use tonic::transport::{Endpoint, Error}; use tonic::Status; use url::Url; use uuid::Uuid; +type Host = String; +type Port = u16; +type UrlParse = (Host, Port, Option>); + /// ChannelBuilder validates a connection string /// based 
on the requirements from [Spark Documentation](https://github.com/apache/spark/blob/master/connector/connect/docs/client-connection-string.md) #[derive(Clone, Debug)] pub struct ChannelBuilder { - host: String, - port: u16, + host: Host, + port: Port, session_id: Uuid, token: Option, user_id: Option, @@ -44,38 +47,89 @@ pub struct ChannelBuilder { impl Default for ChannelBuilder { fn default() -> Self { - ChannelBuilder::create("sc://127.0.0.1:15002").unwrap() + let connection = match env::var("SPARK_REMOTE") { + Ok(conn) => conn.to_string(), + Err(_) => "sc://localhost:15002".to_string(), + }; + + ChannelBuilder::create(&connection).unwrap() } } impl ChannelBuilder { - /// create and Validate a connnection string - #[allow(unreachable_code)] - pub fn create(connection: &str) -> Result { - let url = Url::parse(connection).map_err(|_| "Failed to parse the url.".to_string())?; + pub fn new() -> Self { + ChannelBuilder::default() + } + + pub fn endpoint(&self) -> String { + format!("https://{}:{}", self.host, self.port) + } + + pub fn token(&self) -> Option { + self.token.to_owned() + } + + pub fn headers(&self) -> Option { + self.headers.to_owned() + } + + fn create_user_agent(user_agent: Option<&str>) -> Option { + let user_agent = user_agent.unwrap_or("_SPARK_CONNECT_RUST"); + let pkg_version = env!("CARGO_PKG_VERSION"); + let os = env::consts::OS.to_lowercase(); + + Some(format!( + "{} os/{} spark_connect_rs/{}", + user_agent, os, pkg_version + )) + } + + fn create_user_id(user_id: Option<&str>) -> Option { + match user_id { + Some(user_id) => Some(user_id.to_string()), + None => match env::var("USER") { + Ok(user) => Some(user), + Err(_) => None, + }, + } + } + + pub fn parse_connection_string(connection: &str) -> Result { + let url = Url::parse(connection).map_err(|_| { + SparkError::InvalidConnectionUrl("Failed to parse the connection URL".to_string()) + })?; if url.scheme() != "sc" { - return Err("Scheme is not set to 'sc'".to_string()); + return Err(SparkError::InvalidConnectionUrl( + "The URL must start with 'sc://'. Please update the URL to follow the correct format, e.g., 'sc://hostname:port'".to_string(), + )); }; let host = url .host_str() - .ok_or("Missing host in the URL.".to_string())? + .ok_or_else(|| { + SparkError::InvalidConnectionUrl( + "The hostname must not be empty. Please update + the URL to follow the correct format, e.g., 'sc://hostname:port'." + .to_string(), + ) + })? .to_string(); - let port = url.port().ok_or("Missing port in the URL.".to_string())?; + let port = url.port().ok_or_else(|| { + SparkError::InvalidConnectionUrl( + "The port must not be empty. Please update + the URL to follow the correct format, e.g., 'sc://hostname:port'." 
+ .to_string(), + ) + })?; - let mut channel_builder = ChannelBuilder { - host, - port, - session_id: Uuid::new_v4(), - token: None, - user_id: None, - user_agent: Some("_SPARK_CONNECT_RUST".to_string()), - use_ssl: false, - headers: None, - }; + let headers = ChannelBuilder::parse_headers(url); + + Ok((host, port, headers)) + } + pub fn parse_headers(url: Url) -> Option> { let path: Vec<&str> = url .path() .split(';') @@ -83,10 +137,10 @@ impl ChannelBuilder { .collect(); if path.is_empty() || (path.len() == 1 && (path[0].is_empty() || path[0] == "/")) { - return Ok(channel_builder); + return None; } - let mut headers: HashMap = path + let headers: HashMap = path .iter() .copied() .map(|pair| { @@ -99,23 +153,51 @@ impl ChannelBuilder { .collect(); if headers.is_empty() { - return Ok(channel_builder); + return None; } + Some(headers) + } + + /// Create and validate a connnection string + #[allow(unreachable_code)] + pub fn create(connection: &str) -> Result { + let (host, port, headers) = ChannelBuilder::parse_connection_string(connection)?; + + let mut channel_builder = ChannelBuilder { + host, + port, + session_id: Uuid::new_v4(), + token: None, + user_id: ChannelBuilder::create_user_id(None), + user_agent: ChannelBuilder::create_user_agent(None), + use_ssl: false, + headers: None, + }; + + let mut headers = match headers { + Some(headers) => headers, + None => return Ok(channel_builder), + }; + + channel_builder.user_id = headers + .remove("user_id") + .map(|user_id| ChannelBuilder::create_user_id(Some(&user_id))) + .unwrap_or_else(|| ChannelBuilder::create_user_id(None)); + + channel_builder.user_agent = headers + .remove("user_agent") + .map(|user_agent| ChannelBuilder::create_user_agent(Some(&user_agent))) + .unwrap_or_else(|| ChannelBuilder::create_user_agent(None)); + if let Some(token) = headers.remove("token") { channel_builder.token = Some(format!("Bearer {token}")); } - // !TODO try to grab the user id from the system if not provided - // when connecting to Databricks User ID is required to be populated - if let Some(user_id) = headers.remove("user_id") { - channel_builder.user_id = Some(user_id) - } - if let Some(user_agent) = headers.remove("user_agent") { - channel_builder.user_agent = Some(user_agent) - } + if let Some(session_id) = headers.remove("session_id") { channel_builder.session_id = Uuid::from_str(&session_id).unwrap() } + if let Some(use_ssl) = headers.remove("use_ssl") { if use_ssl.to_lowercase() == "true" { #[cfg(not(feature = "tls"))] @@ -132,31 +214,6 @@ impl ChannelBuilder { Ok(channel_builder) } - - async fn create_client(&self) -> Result { - let endpoint = format!("https://{}:{}", self.host, self.port); - - let channel = Endpoint::from_shared(endpoint)?.connect().await?; - - let service_client = SparkConnectServiceClient::with_interceptor( - channel, - MetadataInterceptor { - token: self.token.clone(), - metadata: self.headers.clone(), - }, - ); - - let client = Arc::new(Mutex::new(service_client)); - - let spark_connnect_client = SparkConnectClient { - stub: client.clone(), - builder: self.clone(), - handler: ResponseHandler::new(), - analyzer: AnalyzeHandler::new(), - }; - - Ok(SparkSession::new(spark_connnect_client)) - } } #[derive(Clone, Debug)] @@ -181,6 +238,12 @@ impl Interceptor for MetadataInterceptor { } } +impl MetadataInterceptor { + pub fn new(token: Option, metadata: Option) -> Self { + MetadataInterceptor { token, metadata } + } +} + fn metadata_builder(headers: &HashMap) -> MetadataMap { let mut metadata_map = MetadataMap::new(); for (key, 
val) in headers.iter() { @@ -213,45 +276,6 @@ where } } -/// SparkSessionBuilder creates a remote Spark Session a connection string. -/// -/// The connection string is define based on the requirements from [Spark Documentation](https://github.com/apache/spark/blob/master/connector/connect/docs/client-connection-string.md) -#[derive(Clone, Debug)] -pub struct SparkSessionBuilder { - pub channel_builder: ChannelBuilder, -} - -/// Default connects a Spark cluster running at `sc://127.0.0.1:15002/` -impl Default for SparkSessionBuilder { - fn default() -> Self { - let channel_builder = ChannelBuilder::default(); - - Self { channel_builder } - } -} - -impl SparkSessionBuilder { - fn new(connection: &str) -> Self { - let channel_builder = ChannelBuilder::create(connection).unwrap(); - - Self { channel_builder } - } - - /// Validate a connect string for a remote Spark Session - /// - /// String must conform to the [Spark Documentation](https://github.com/apache/spark/blob/master/connector/connect/docs/client-connection-string.md) - pub fn remote(connection: &str) -> Self { - Self::new(connection) - } - - /// Attempt to connect to a remote Spark Session - /// - /// and return a [SparkSession] - pub async fn build(self) -> Result { - self.channel_builder.create_client().await - } -} - #[allow(dead_code)] #[derive(Default, Debug, Clone)] pub struct ResponseHandler { @@ -321,7 +345,7 @@ impl AnalyzeHandler { #[derive(Clone, Debug)] pub struct SparkConnectClient { - stub: Arc>>, + stub: Arc>>, builder: ChannelBuilder, pub handler: ResponseHandler, pub analyzer: AnalyzeHandler, @@ -334,6 +358,15 @@ where T::ResponseBody: Body + Send + 'static, ::Error: Into + Send, { + pub fn new(stub: Arc>>, builder: ChannelBuilder) -> Self { + SparkConnectClient { + stub, + builder, + handler: ResponseHandler::new(), + analyzer: AnalyzeHandler::new(), + } + } + pub fn session_id(&self) -> String { self.builder.session_id.to_string() } @@ -372,7 +405,7 @@ where &mut self, req: spark::ExecutePlanRequest, ) -> Result<(), SparkError> { - let mut client = self.stub.lock(); + let mut client = self.stub.write(); let mut resp = client.execute_plan(req).await?.into_inner(); @@ -402,7 +435,7 @@ where req.analyze = Some(analyze); - let mut client = self.stub.lock(); + let mut client = self.stub.write(); // clear out any prior responses self.analyzer = AnalyzeHandler::new(); @@ -630,37 +663,32 @@ mod tests { #[test] fn test_channel_builder_default() { - let expected_url = "127.0.0.1:15002"; + let expected_url = "https://localhost:15002".to_string(); let cb = ChannelBuilder::default(); - let output_url = format!("{}:{}", cb.host, cb.port); - - assert_eq!(expected_url, output_url) + assert_eq!(expected_url, cb.endpoint()) } #[test] - #[should_panic(expected = "Scheme is not set to 'sc")] fn test_panic_incorrect_url_scheme() { let connection = "http://127.0.0.1:15002"; - ChannelBuilder::create(&connection).unwrap(); + assert!(ChannelBuilder::create(connection).is_err()) } #[test] - #[should_panic(expected = "Failed to parse the url.")] fn test_panic_missing_url_host() { let connection = "sc://:15002"; - ChannelBuilder::create(&connection).unwrap(); + assert!(ChannelBuilder::create(connection).is_err()) } #[test] - #[should_panic(expected = "Missing port in the URL")] fn test_panic_missing_url_port() { let connection = "sc://127.0.0.1"; - ChannelBuilder::create(&connection).unwrap(); + assert!(ChannelBuilder::create(connection).is_err()) } #[test] @@ -672,36 +700,4 @@ mod tests { ChannelBuilder::create(&connection).unwrap(); } - - 
#[test] - fn test_spark_session_builder() { - let connection = "sc://myhost.com:443/;token=ABCDEFG;user_agent=some_agent;user_id=user123"; - - let ssbuilder = SparkSessionBuilder::remote(connection); - - assert_eq!("myhost.com".to_string(), ssbuilder.channel_builder.host); - assert_eq!(443, ssbuilder.channel_builder.port); - assert_eq!( - "Bearer ABCDEFG".to_string(), - ssbuilder.channel_builder.token.unwrap() - ); - assert_eq!( - "user123".to_string(), - ssbuilder.channel_builder.user_id.unwrap() - ); - assert_eq!( - Some("some_agent".to_string()), - ssbuilder.channel_builder.user_agent - ); - } - - #[tokio::test] - async fn test_spark_session_create() { - let connection = - "sc://localhost:15002/;token=ABCDEFG;user_agent=some_agent;user_id=user123"; - - let spark = SparkSessionBuilder::remote(connection).build().await; - - assert!(spark.is_ok()); - } } diff --git a/src/dataframe.rs b/src/dataframe.rs index 69c867e..2fadf76 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -376,12 +376,8 @@ impl DataFrame { explain_mode: explain_mode.into(), }); - let explain = self - .spark_session - .client() - .analyze(analyze) - .await? - .explain()?; + let mut client = self.spark_session.client(); + let explain = client.analyze(analyze).await?.explain()?; println!("{}", explain); @@ -446,11 +442,9 @@ impl DataFrame { }, ); - self.spark_session - .client() - .analyze(input_files) - .await? - .input_files() + let mut client = self.spark_session.client(); + + client.analyze(input_files).await?.input_files() } /// Return a new DataFrame containing rows only in both this DataFrame and another DataFrame. @@ -488,11 +482,9 @@ impl DataFrame { }, ); - self.spark_session - .client() - .analyze(is_streaming) - .await? - .is_streaming() + let mut client = self.spark_session.client(); + + client.analyze(is_streaming).await?.is_streaming() } /// Joins with another DataFrame, using the given join expression. @@ -563,12 +555,9 @@ impl DataFrame { storage_level: Some(storage_level.into()), }); - self.spark_session - .clone() - .client() - .analyze(analyze) - .await - .unwrap(); + let mut client = self.spark_session.clone().client(); + + client.analyze(analyze).await.unwrap(); DataFrame::new(self.spark_session, self.logical_plan) } @@ -582,14 +571,10 @@ impl DataFrame { level, }, ); - let tree = self - .spark_session - .client() - .analyze(tree_string) - .await? - .tree_string()?; - Ok(tree) + let mut client = self.spark_session.client(); + + client.analyze(tree_string).await?.tree_string() } /// Returns a new [DataFrame] partitioned by the given partition number and shuffle option @@ -624,14 +609,10 @@ impl DataFrame { other_plan, }, ); - let same_semantics = self - .spark_session - .client() - .analyze(same_semantics) - .await? - .same_semantics()?; - Ok(same_semantics) + let mut client = self.spark_session.client(); + + client.analyze(same_semantics).await?.same_semantics() } /// Returns a sampled subset of this [DataFrame] @@ -659,14 +640,9 @@ impl DataFrame { plan: Some(plan), }); - let data_type = self - .spark_session - .client() - .analyze(schema) - .await? - .schema()?; + let mut client = self.spark_session.client(); - Ok(data_type) + client.analyze(schema).await?.schema() } /// Projects a set of expressions and returns a new [DataFrame] @@ -711,14 +687,9 @@ impl DataFrame { spark::analyze_plan_request::SemanticHash { plan: Some(plan) }, ); - let semantic_hash = self - .spark_session - .client() - .analyze(semantic_hash) - .await? 
- .semantic_hash()?; + let mut client = self.spark_session.client(); - Ok(semantic_hash) + client.analyze(semantic_hash).await?.semantic_hash() } /// Prints the first `n` rows to the console @@ -769,14 +740,10 @@ impl DataFrame { }, ); - let storage = self - .spark_session - .client() - .analyze(storage_level) - .await? - .get_storage_level()?; + let mut client = self.spark_session.client(); + let storage = client.analyze(storage_level).await?.get_storage_level(); - Ok(storage.into()) + Ok(storage?.into()) } pub fn subtract(self, other: DataFrame) -> DataFrame { @@ -851,12 +818,9 @@ impl DataFrame { }, ); - self.spark_session - .clone() - .client() - .analyze(unpersist) - .await - .unwrap(); + let mut client = self.spark_session.clone().client(); + + client.analyze(unpersist).await.unwrap(); DataFrame::new(self.spark_session, self.logical_plan) } diff --git a/src/errors.rs b/src/errors.rs index 9a25c9e..4880c5d 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -16,6 +16,7 @@ pub enum SparkError { AnalysisException(String), IoError(String, std::io::Error), ArrowError(ArrowError), + InvalidConnectionUrl(String), } impl SparkError { @@ -75,6 +76,7 @@ impl Display for SparkError { SparkError::IoError(desc, _) => write!(f, "Io error: {desc}"), SparkError::ArrowError(desc) => write!(f, "Apache Arrow error: {desc}"), SparkError::NotYetImplemented(source) => write!(f, "Not yet implemented: {source}"), + SparkError::InvalidConnectionUrl(val) => write!(f, "Invalid URL error: {val}"), } } } diff --git a/src/group.rs b/src/group.rs index 65f463f..65b9b81 100644 --- a/src/group.rs +++ b/src/group.rs @@ -1,8 +1,7 @@ //! A DataFrame created with an aggregate statement -use crate::column::Column; use crate::dataframe::DataFrame; -use crate::expressions::{ToLiteral, ToVecExpr}; +use crate::expressions::{ToExpr, ToLiteral, ToVecExpr}; use crate::plan::LogicalPlanBuilder; use crate::functions::lit; @@ -16,7 +15,7 @@ pub struct GroupedData { df: DataFrame, group_type: GroupType, grouping_cols: Vec, - pivot_col: Option, + pivot_col: Option, pivot_vals: Option>, } @@ -25,7 +24,7 @@ impl GroupedData { df: DataFrame, group_type: GroupType, grouping_cols: Vec, - pivot_col: Option, + pivot_col: Option, pivot_vals: Option>, ) -> GroupedData { Self { @@ -84,7 +83,7 @@ impl GroupedData { self.df, GroupType::Pivot, self.grouping_cols, - Some(Column::from(col)), + Some(col.to_expr()), pivot_vals, ) } diff --git a/src/lib.rs b/src/lib.rs index f5b0614..3f9a045 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -133,7 +133,6 @@ pub mod types; mod utils; pub mod window; -pub use arrow; pub use dataframe::{DataFrame, DataFrameReader, DataFrameWriter}; pub use session::{SparkSession, SparkSessionBuilder}; diff --git a/src/plan.rs b/src/plan.rs index eccd189..835ee0c 100644 --- a/src/plan.rs +++ b/src/plan.rs @@ -3,7 +3,6 @@ use std::collections::HashMap; use std::sync::Mutex; -use crate::column::Column; use crate::errors::SparkError; use crate::expressions::{ToExpr, ToFilterExpr, ToVecExpr}; use crate::spark; @@ -118,12 +117,12 @@ impl LogicalPlanBuilder { group_type: GroupType, grouping_cols: Vec, agg_expression: T, - pivot_col: Option, + pivot_col: Option, pivot_vals: Option>, ) -> LogicalPlanBuilder { let pivot = match group_type { GroupType::Pivot => Some(spark::aggregate::Pivot { - col: pivot_col.map(|col| col.expression), + col: pivot_col, values: pivot_vals.unwrap_or_default(), }), _ => None, @@ -560,9 +559,10 @@ impl LogicalPlanBuilder { LogicalPlanBuilder::from(rel_type) } - pub fn sort(self, cols: I) -> LogicalPlanBuilder 
+ pub fn sort(self, cols: I) -> LogicalPlanBuilder where - I: IntoIterator, + T: ToExpr, + I: IntoIterator, { let order = sort_order(cols); let sort_type = RelType::Sort(Box::new(spark::Sort { @@ -575,7 +575,7 @@ impl LogicalPlanBuilder { } #[allow(non_snake_case)] - pub fn withColumn(self, colName: &str, col: Column) -> LogicalPlanBuilder { + pub fn withColumn(self, colName: &str, col: T) -> LogicalPlanBuilder { let aliases: Vec = vec![spark::expression::Alias { expr: Some(Box::new(col.to_expr())), name: vec![colName.to_string()], @@ -591,10 +591,11 @@ impl LogicalPlanBuilder { } #[allow(non_snake_case)] - pub fn withColumns(self, colMap: I) -> LogicalPlanBuilder + pub fn withColumns(self, colMap: I) -> LogicalPlanBuilder where - I: IntoIterator, + T: ToExpr, K: ToString, + I: IntoIterator, { let aliases: Vec = colMap .into_iter() diff --git a/src/readwriter.rs b/src/readwriter.rs index 615f34d..9e228f8 100644 --- a/src/readwriter.rs +++ b/src/readwriter.rs @@ -353,6 +353,7 @@ mod tests { let path = "/opt/spark/examples/src/main/rust/employees/"; df.write() + .mode(SaveMode::Overwrite) .format("csv") .option("header", "true") .save(path) diff --git a/src/session.rs b/src/session.rs index 05cf905..52ea5aa 100644 --- a/src/session.rs +++ b/src/session.rs @@ -1,20 +1,86 @@ //! Spark Session containing the remote gRPC client use std::collections::HashMap; +use std::sync::Arc; use crate::catalog::Catalog; -pub use crate::client::SparkSessionBuilder; +use crate::client::ChannelBuilder; use crate::client::{MetadataInterceptor, SparkConnectClient}; use crate::dataframe::{DataFrame, DataFrameReader}; use crate::errors::SparkError; use crate::plan::LogicalPlanBuilder; use crate::spark; use crate::streaming::DataStreamReader; +use spark::spark_connect_service_client::SparkConnectServiceClient; use arrow::record_batch::RecordBatch; +use parking_lot::RwLock; + use tonic::service::interceptor::InterceptedService; -use tonic::transport::Channel; +use tonic::transport::{Channel, Endpoint}; + +/// SparkSessionBuilder creates a remote Spark Session a connection string. 
+///
+/// The connection string format is defined by the requirements in the [Spark Documentation](https://github.com/apache/spark/blob/master/connector/connect/docs/client-connection-string.md)
+#[derive(Clone, Debug)]
+pub struct SparkSessionBuilder {
+    pub channel_builder: ChannelBuilder,
+}
+
+/// Default connects to a Spark cluster at `sc://localhost:15002`, or to the address set in the `SPARK_REMOTE` environment variable
+impl Default for SparkSessionBuilder {
+    fn default() -> Self {
+        let channel_builder = ChannelBuilder::default();
+
+        Self { channel_builder }
+    }
+}
+
+impl SparkSessionBuilder {
+    fn new(connection: &str) -> Self {
+        let channel_builder = ChannelBuilder::create(connection).unwrap();
+
+        Self { channel_builder }
+    }
+
+    /// Validate a connection string for a remote Spark Session
+    ///
+    /// The string must conform to the [Spark Documentation](https://github.com/apache/spark/blob/master/connector/connect/docs/client-connection-string.md)
+    pub fn remote(connection: &str) -> Self {
+        Self::new(connection)
+    }
+
+    async fn create_client(&self) -> Result<SparkSession, SparkError> {
+        let channel = Endpoint::from_shared(self.channel_builder.endpoint())
+            .expect("Failed to create endpoint")
+            .connect()
+            .await
+            .expect("Failed to create channel");
+
+        let service_client = SparkConnectServiceClient::with_interceptor(
+            channel,
+            MetadataInterceptor::new(
+                self.channel_builder.token().to_owned(),
+                self.channel_builder.headers().to_owned(),
+            ),
+        );
+
+        let client = Arc::new(RwLock::new(service_client));
+
+        let spark_connect_client =
+            SparkConnectClient::new(client.clone(), self.channel_builder.clone());
+
+        Ok(SparkSession::new(spark_connect_client))
+    }
+
+    /// Attempt to connect to a remote Spark Session
+    ///
+    /// and return a [SparkSession]
+    pub async fn build(self) -> Result<SparkSession, SparkError> {
+        self.create_client().await
+    }
+}
+
 /// The entry point to connecting to a Spark Cluster
 /// using the Spark Connection gRPC protocol.
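// A minimal sketch of how the relocated SparkSessionBuilder is exercised, mirroring the
// session tests further below in this diff. The re-exports `SparkSession` and
// `SparkSessionBuilder` come from lib.rs in this patch; the `errors` module path, the
// tokio runtime attribute, and the schema name are assumptions for illustration only.
use spark_connect_rs::errors::SparkError;
use spark_connect_rs::{SparkSession, SparkSessionBuilder};

#[tokio::main]
async fn main() -> Result<(), SparkError> {
    // Calling `SparkSessionBuilder::default()` instead of `remote(...)` would resolve the
    // address from `SPARK_REMOTE` (falling back to `sc://localhost:15002`) via
    // `ChannelBuilder::default`, as shown in the client changes above.
    let spark: SparkSession =
        SparkSessionBuilder::remote("sc://localhost:15002/;user_id=user123;user_agent=example")
            .build()
            .await?;

    // `sql` is the same entry point the catalog tests in this diff use.
    let _df = spark.sql("CREATE SCHEMA IF NOT EXISTS example_db").await?;

    Ok(())
}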
@@ -115,3 +181,34 @@ impl SparkSession { self.client } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_spark_session_builder() { + let connection = "sc://myhost.com:443/;token=ABCDEFG;user_agent=some_agent;user_id=user123"; + + let ssbuilder = SparkSessionBuilder::remote(connection); + + assert_eq!( + "https://myhost.com:443".to_string(), + ssbuilder.channel_builder.endpoint() + ); + assert_eq!( + "Bearer ABCDEFG".to_string(), + ssbuilder.channel_builder.token().unwrap() + ); + } + + #[tokio::test] + async fn test_spark_session_create() { + let connection = + "sc://localhost:15002/;token=ABCDEFG;user_agent=some_agent;user_id=user123"; + + let spark = SparkSessionBuilder::remote(connection).build().await; + + assert!(spark.is_ok()); + } +} diff --git a/src/utils.rs b/src/utils.rs index 878c818..10bc925 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,7 +1,7 @@ use crate::spark; use crate::column::Column; -use crate::expressions::ToVecExpr; +use crate::expressions::{ToExpr, ToVecExpr}; pub fn invoke_func(name: &str, args: T) -> Column { Column::from(spark::Expression { @@ -16,15 +16,16 @@ pub fn invoke_func(name: &str, args: T) -> Column { }) } -pub fn sort_order(cols: I) -> Vec +pub fn sort_order(cols: I) -> Vec where - I: IntoIterator, + T: ToExpr, + I: IntoIterator, { cols.into_iter() - .map(|col| match col.clone().expression.expr_type.unwrap() { + .map(|col| match col.to_expr().expr_type.unwrap() { spark::expression::ExprType::SortOrder(ord) => *ord, _ => spark::expression::SortOrder { - child: Some(Box::new(col.expression)), + child: Some(Box::new(col.to_expr())), direction: 1, null_ordering: 1, }, diff --git a/src/window.rs b/src/window.rs index 85bd3b0..b7f1ca8 100644 --- a/src/window.rs +++ b/src/window.rs @@ -1,8 +1,6 @@ //! Utility structs for defining a window over a DataFrame -use crate::column::Column; -use crate::expressions::ToVecExpr; -use crate::functions::lit; +use crate::expressions::{ToExpr, ToLiteralExpr, ToVecExpr}; use crate::utils::sort_order; use crate::spark; @@ -37,9 +35,10 @@ impl WindowSpec { } #[allow(non_snake_case)] - pub fn orderBy(self, cols: I) -> WindowSpec + pub fn orderBy(self, cols: I) -> WindowSpec where - I: IntoIterator, + T: ToExpr, + I: IntoIterator, { let order_spec = sort_order(cols); @@ -80,10 +79,10 @@ impl WindowSpec { // !TODO - I don't like casting this to i32 // however, the window boundary is expecting an INT and not a BIGINT // i64 is a BIGINT (i.e. Long) - let value = lit(value as i32).expression; + let expr = (value as i32).to_literal_expr(); let boundary = Some(window::window_frame::frame_boundary::Boundary::Value( - Box::new(value), + Box::new(expr), )); Some(Box::new(window::window_frame::FrameBoundary { boundary })) @@ -150,9 +149,10 @@ impl Window { /// Creates a [WindowSpec] with the ordering defined #[allow(non_snake_case)] - pub fn orderBy(mut self, cols: I) -> WindowSpec + pub fn orderBy(mut self, cols: I) -> WindowSpec where - I: IntoIterator, + T: ToExpr, + I: IntoIterator, { self.spec = self.spec.orderBy(cols);
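// A small sketch of the connection-string handling consolidated into ChannelBuilder by this
// patch, based on `create`, `endpoint`, and `token` above and the assertions in the session
// tests. The import path `spark_connect_rs::client::ChannelBuilder` is an assumption about
// the crate's public module layout.
use spark_connect_rs::client::ChannelBuilder;

fn main() {
    // `parse_connection_string` splits the URL into (host, port, optional headers);
    // `create` then resolves user_id/user_agent defaults and consumes the known keys
    // (token, user_id, user_agent, session_id, use_ssl) from the header map.
    let cb = ChannelBuilder::create(
        "sc://myhost.com:443/;token=ABCDEFG;user_agent=some_agent;user_id=user123",
    )
    .unwrap();

    // The gRPC endpoint is rebuilt as https://host:port and the bare token is wrapped as a
    // Bearer credential for the metadata interceptor.
    assert_eq!("https://myhost.com:443", cb.endpoint());
    assert_eq!(Some("Bearer ABCDEFG".to_string()), cb.token());

    // Any scheme other than "sc://" now surfaces as SparkError::InvalidConnectionUrl
    // instead of panicking, which is what the rewritten client tests check.
    assert!(ChannelBuilder::create("http://myhost.com:443").is_err());
}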