diff --git a/src/snowflake/snowpark/_internal/data_source_utils.py b/src/snowflake/snowpark/_internal/data_source_utils.py new file mode 100644 index 00000000000..be09f559a30 --- /dev/null +++ b/src/snowflake/snowpark/_internal/data_source_utils.py @@ -0,0 +1,344 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +import datetime +import logging +from enum import Enum +from typing import List, Any, Tuple, Protocol, Union +from snowflake.connector.options import pandas as pd + +from snowflake.snowpark.exceptions import SnowparkDataframeReaderException +from snowflake.snowpark.types import ( + StringType, + GeographyType, + VariantType, + BinaryType, + GeometryType, + TimestampType, + DecimalType, + FloatType, + ShortType, + IntegerType, + BooleanType, + LongType, + TimeType, + ByteType, + DateType, + TimestampTimeZone, + StructType, + StructField, +) + +_logger = logging.getLogger(__name__) + +SQL_SERVER_TYPE_TO_SNOW_TYPE = { + "bigint": LongType, + "bit": BooleanType, + "decimal": DecimalType, + "float": FloatType, + "int": IntegerType, + "money": DecimalType, + "real": FloatType, + "smallint": ShortType, + "smallmoney": DecimalType, + "tinyint": ByteType, + "numeric": DecimalType, + "date": DateType, + "datetime2": TimestampType, + "datetime": TimestampType, + "datetimeoffset": TimestampType, + "smalldatetime": TimestampType, + "time": TimeType, + "timestamp": TimestampType, + "char": StringType, + "text": StringType, + "varchar": StringType, + "nchar": StringType, + "ntext": StringType, + "nvarchar": StringType, + "binary": BinaryType, + "varbinary": BinaryType, + "image": BinaryType, + "sql_variant": VariantType, + "geography": GeographyType, + "geometry": GeometryType, + "uniqueidentifier": StringType, + "xml": StringType, + "sysname": StringType, +} +ORACLEDB_TYPE_TO_SNOW_TYPE = { + "bfile": BinaryType, + "varchar2": StringType, + "varchar": StringType, + "nvarchar2": StringType, + # TODO: SNOW-1922043 Investigation on handling number type in oracle db + "number": DecimalType, + "float": FloatType, + "long": LongType, + "date": DateType, + "intervalyeartomonth": StringType, + "intervaldaytosecond": StringType, + "json": VariantType, + "binary_float": FloatType, + "binary_double": FloatType, + "timestamp": TimestampType, + "longraw": BinaryType, + "raw": BinaryType, + "clob": StringType, + "nclob": StringType, + "blob": BinaryType, + "char": StringType, + "nchar": StringType, + "rowid": StringType, + "sys.anydata": VariantType, + "uritype": VariantType, + "urowid": StringType, + "xmltype": VariantType, +} + + +STATEMENT_PARAMS_DATA_SOURCE = "SNOWPARK_PYTHON_DATASOURCE" +DATA_SOURCE_DBAPI_SIGNATURE = "DataFrameReader.dbapi" +DATA_SOURCE_SQL_COMMENT = ( + f"/* Python:snowflake.snowpark.{DATA_SOURCE_DBAPI_SIGNATURE} */" +) + + +class DBMS_TYPE(Enum): + SQL_SERVER_DB = "SQL_SERVER_DB" + ORACLE_DB = "ORACLE_DB" + SQLITE_DB = "SQLITE3_DB" + + +def detect_dbms(dbapi2_conn) -> Union[DBMS_TYPE, str]: + """Detects the DBMS type from a DBAPI2 connection.""" + + # Get the Python driver name + python_driver_name = type(dbapi2_conn).__module__.lower() + + # Dictionary-based lookup for known DBMS + dbms_mapping = { + "pyodbc": detect_dbms_pyodbc, + "cx_oracle": lambda conn: DBMS_TYPE.ORACLE_DB, + "oracledb": lambda conn: DBMS_TYPE.ORACLE_DB, + "sqlite3": lambda conn: DBMS_TYPE.SQLITE_DB, + } + + if python_driver_name in dbms_mapping: + return dbms_mapping[python_driver_name](dbapi2_conn), python_driver_name + + _logger.debug(f"Unsupported database driver: 
{python_driver_name}") + return None + + +def detect_dbms_pyodbc(dbapi2_conn): + """Detects the DBMS type for a pyodbc connection.""" + import pyodbc + + dbms_name = dbapi2_conn.getinfo(pyodbc.SQL_DBMS_NAME).lower() + + # Set-based lookup for SQL Server + sqlserver_keywords = {"sql server", "mssql", "sqlserver"} + if any(keyword in dbms_name for keyword in sqlserver_keywords): + return DBMS_TYPE.SQL_SERVER_DB + + _logger.debug(f"Unsupported DBMS for pyodbc: {dbms_name}") + return None + + +class Connection(Protocol): + """External datasource connection created from user-input create_connection function.""" + + def cursor(self) -> "Cursor": + pass + + def close(self): + pass + + def commit(self): + pass + + def rollback(self): + pass + + +class Cursor(Protocol): + """Cursor created from external datasource connection""" + + def execute(self, sql: str, *params: Any) -> "Cursor": + pass + + def fetchall(self) -> List[Tuple]: + pass + + def fetchone(self): + pass + + def close(self): + pass + + +def sql_server_to_snowpark_type(schema: List[tuple]) -> StructType: + """ + This is used to convert sql server raw schema to snowpark structtype. + Each tuple in the list represent a column and values are as follows: + column name: str + data type: str + precision: int + scale: int + nullable: bool + """ + fields = [] + for column in schema: + snow_type = SQL_SERVER_TYPE_TO_SNOW_TYPE.get(column[1].lower(), None) + if snow_type is None: + # TODO: SNOW-1912068 support types that we don't have now + raise NotImplementedError(f"sql server type not supported: {column[1]}") + if column[1].lower() in ["datetime2", "datetime", "smalldatetime"]: + data_type = snow_type(TimestampTimeZone.NTZ) + elif column[1].lower() == "datetimeoffset": + data_type = snow_type(TimestampTimeZone.LTZ) + elif snow_type == DecimalType: + data_type = snow_type( + column[2] if column[2] is not None else 38, + column[3] if column[3] is not None else 0, + ) + else: + data_type = snow_type() + fields.append(StructField(column[0], data_type, column[4])) + + return StructType(fields) + + +def oracledb_to_snowpark_type(schema: List[tuple]) -> StructType: + """ + This is used to convert oracledb raw schema to snowpark structtype. 
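+    A raw row fetched from USER_TAB_COLUMNS looks like, e.g., ("NUMBER_COL", "NUMBER", 10, 2, "Y").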
+ Each tuple in the list represent a column and values are as follows: + column name: str + data type: str + precision: int + scale: int + nullable: str + """ + fields = [] + for column in schema: + remove_space_column_name = column[1].lower().replace(" ", "") + processed_column_name = ( + "timestamp" + if remove_space_column_name.startswith("timestamp") + else remove_space_column_name + ) + snow_type = ORACLEDB_TYPE_TO_SNOW_TYPE.get(processed_column_name, None) + if snow_type is None: + # TODO: SNOW-1912068 support types that we don't have now + raise NotImplementedError(f"oracledb type not supported: {column[1]}") + if "withtimezone" in remove_space_column_name: + data_type = snow_type(TimestampTimeZone.TZ) + elif "withlocaltimezone" in remove_space_column_name: + data_type = snow_type(TimestampTimeZone.LTZ) + elif snow_type == DecimalType: + data_type = snow_type( + column[2] if column[2] is not None else 38, + column[3] if column[3] is not None else 0, + ) + else: + data_type = snow_type() + fields.append(StructField(column[0], data_type, bool(column[4].lower() == "y"))) + + return StructType(fields) + + +def infer_data_source_schema(conn: Connection, table: str) -> StructType: + try: + current_db, driver_info = detect_dbms(conn) + cursor = conn.cursor() + if current_db == DBMS_TYPE.SQL_SERVER_DB: + query = f""" + SELECT COLUMN_NAME, DATA_TYPE, NUMERIC_PRECISION, NUMERIC_SCALE, IS_NULLABLE + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_NAME = '{table}' + """ + cursor.execute(query) + raw_schema = cursor.fetchall() + return sql_server_to_snowpark_type(raw_schema) + elif current_db == DBMS_TYPE.ORACLE_DB: + query = f""" + SELECT COLUMN_NAME, DATA_TYPE, DATA_PRECISION, DATA_SCALE, NULLABLE + FROM USER_TAB_COLUMNS + WHERE TABLE_NAME = '{table}' + """ + cursor.execute(query) + raw_schema = cursor.fetchall() + return oracledb_to_snowpark_type(raw_schema) + else: + raise NotImplementedError( + f"currently supported drivers are pyodbc and oracledb, got: {driver_info}" + ) + except Exception as exc: + raise SnowparkDataframeReaderException( + f"Unable to infer schema from table {table}" + ) from exc + + +def data_source_data_to_pandas_df( + data: List[Any], schema: StructType, current_db: str, driver_info: str +) -> pd.DataFrame: + columns = [col.name for col in schema.fields] + df = pd.DataFrame.from_records(data, columns=columns) + + # convert timestamp and date to string to work around SNOW-1911989 + df = df.map( + lambda x: x.isoformat() + if isinstance(x, (datetime.datetime, datetime.date)) + else x + ) + # convert binary type to object type to work around SNOW-1912094 + df = df.map(lambda x: x.hex() if isinstance(x, (bytearray, bytes)) else x) + if current_db == DBMS_TYPE.SQL_SERVER_DB or current_db == DBMS_TYPE.SQLITE_DB: + return df + elif current_db == DBMS_TYPE.ORACLE_DB: + # apply read to LOB object, we currently have FakeOracleLOB because CLOB and BLOB is represented by an + # oracledb object and we cannot add it as our dependency in test, so we fake it in this way + # TODO: SNOW-1923698 remove FakeOracleLOB after we have test environment + df = df.map(lambda x: x.read() if (type(x).__name__.lower() == "lob") else x) + + else: + raise NotImplementedError( + f"currently supported drivers are pyodbc and oracledb, got: {driver_info}" + ) + return df + + +def generate_select_query(table: str, schema: StructType, conn: Connection) -> str: + current_db, driver_info = detect_dbms(conn) + if current_db == DBMS_TYPE.ORACLE_DB: + cols = [] + for field in schema.fields: + if ( + 
isinstance(field.datatype, TimestampType) + and field.datatype.tz == TimestampTimeZone.TZ + ): + cols.append( + f"""TO_CHAR({field.name}, 'YYYY-MM-DD HH24:MI:SS.FF9 TZHTZM')""" + ) + elif ( + isinstance(field.datatype, TimestampType) + and field.datatype.tz == TimestampTimeZone.LTZ + ): + cols.append( + f"""TO_CHAR({field.name} AT TIME ZONE SESSIONTIMEZONE, 'YYYY-MM-DD HH24:MI:SS.FF9 TZHTZM')""" + ) + else: + cols.append(field.name) + return f"""select {" , ".join(cols)} from {table}""" + elif current_db == DBMS_TYPE.SQL_SERVER_DB or current_db == DBMS_TYPE.SQLITE_DB: + return f"select * from {table}" + else: + raise NotImplementedError( + f"currently supported drivers are pyodbc and oracledb, got: {driver_info}" + ) + + +def generate_sql_with_predicates(select_query: str, predicates: List[str]): + return [select_query + f" WHERE {predicate}" for predicate in predicates] diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py index 6bb59b3f176..67845554610 100644 --- a/src/snowflake/snowpark/_internal/telemetry.py +++ b/src/snowflake/snowpark/_internal/telemetry.py @@ -83,6 +83,7 @@ class TelemetryField(Enum): FUNC_CAT_CREATE = "create" # performance categories PERF_CAT_UPLOAD_FILE = "upload_file" + PERF_CAT_DATA_SOURCE = "data_source" # optimizations SESSION_ID = "session_id" SQL_SIMPLIFIER_ENABLED = "sql_simplifier_enabled" @@ -369,23 +370,51 @@ def send_session_created_telemetry(self, created_by_snowpark: bool): self.send(message) @safe_telemetry - def send_upload_file_perf_telemetry( - self, func_name: str, duration: float, sfqid: str + def send_performance_telemetry( + self, category: str, func_name: str, duration: float, sfqid: str = None ): + """ + Sends performance telemetry data. + + Parameters: + category (str): The category of the telemetry (upload file or data source). + func_name (str): The name of the function. + duration (float): The duration of the operation. + sfqid (str, optional): The SFQID for upload file category. Defaults to None. 
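+        Note: callers are expected to pass TelemetryField.PERF_CAT_UPLOAD_FILE.value or TelemetryField.PERF_CAT_DATA_SOURCE.value as the category.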
+ """ message = { **self._create_basic_telemetry_data( TelemetryField.TYPE_PERFORMANCE_DATA.value ), TelemetryField.KEY_DATA.value: { - PCTelemetryField.KEY_SFQID.value: sfqid, - TelemetryField.KEY_CATEGORY.value: TelemetryField.PERF_CAT_UPLOAD_FILE.value, + TelemetryField.KEY_CATEGORY.value: category, TelemetryField.KEY_FUNC_NAME.value: func_name, TelemetryField.KEY_DURATION.value: duration, TelemetryField.THREAD_IDENTIFIER.value: threading.get_ident(), + **({PCTelemetryField.KEY_SFQID.value: sfqid} if sfqid else {}), }, } self.send(message) + @safe_telemetry + def send_upload_file_perf_telemetry( + self, func_name: str, duration: float, sfqid: str + ): + self.send_performance_telemetry( + category=TelemetryField.PERF_CAT_UPLOAD_FILE.value, + func_name=func_name, + duration=duration, + sfqid=sfqid, + ) + + @safe_telemetry + def send_data_source_perf_telemetry(self, func_name: str, duration: float): + self.send_performance_telemetry( + category=TelemetryField.PERF_CAT_DATA_SOURCE.value, + func_name=func_name, + duration=duration, + ) + @safe_telemetry def send_function_usage_telemetry( self, diff --git a/src/snowflake/snowpark/_internal/type_utils.py b/src/snowflake/snowpark/_internal/type_utils.py index c42771f5e80..4019e107ed1 100644 --- a/src/snowflake/snowpark/_internal/type_utils.py +++ b/src/snowflake/snowpark/_internal/type_utils.py @@ -29,6 +29,7 @@ Union, get_args, get_origin, + Protocol, ) import snowflake.snowpark.context as context @@ -341,6 +342,7 @@ def convert_sp_to_sf_type(datatype: DataType, nullable_override=None) -> str: datetime.time: TimeType, bytes: BinaryType, } + if installed_pandas: import numpy diff --git a/src/snowflake/snowpark/dataframe_reader.py b/src/snowflake/snowpark/dataframe_reader.py index d05df693abe..8a26e5f7caf 100644 --- a/src/snowflake/snowpark/dataframe_reader.py +++ b/src/snowflake/snowpark/dataframe_reader.py @@ -1,10 +1,26 @@ # # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. 
# +import datetime +import decimal +import os +import tempfile +import time +import traceback +from _decimal import ROUND_HALF_EVEN, ROUND_HALF_UP +from concurrent.futures import ( + ProcessPoolExecutor, + wait, + ALL_COMPLETED, + ThreadPoolExecutor, + as_completed, +) +import pytz +from dateutil import parser import sys from logging import getLogger -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Callable import snowflake.snowpark import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto @@ -13,6 +29,7 @@ drop_file_format_if_exists_statement, infer_schema_statement, quote_name_without_upper_casing, + TEMPORARY_STRING_SET, ) from snowflake.snowpark._internal.analyzer.expression import Attribute from snowflake.snowpark._internal.analyzer.unary_expression import Alias @@ -28,6 +45,19 @@ ColumnOrName, convert_sf_to_sp_type, convert_sp_to_sf_type, + type_string_to_type_object, +) +from snowflake.snowpark._internal.data_source_utils import ( + data_source_data_to_pandas_df, + Connection, + infer_data_source_schema, + generate_select_query, + DATA_SOURCE_DBAPI_SIGNATURE, + detect_dbms, + DBMS_TYPE, + STATEMENT_PARAMS_DATA_SOURCE, + DATA_SOURCE_SQL_COMMENT, + generate_sql_with_predicates, ) from snowflake.snowpark._internal.utils import ( INFER_SCHEMA_FORMAT_TYPES, @@ -37,15 +67,29 @@ get_copy_into_table_options, parse_positional_args_to_list_variadic, publicapi, - random_name_for_temp_object, + get_temp_type_for_object, + normalize_local_file, ) from snowflake.snowpark.column import METADATA_COLUMN_TYPES, Column, _to_col_if_str from snowflake.snowpark.dataframe import DataFrame -from snowflake.snowpark.exceptions import SnowparkSessionException +from snowflake.snowpark.exceptions import ( + SnowparkSessionException, + SnowparkDataframeReaderException, +) from snowflake.snowpark.functions import sql_expr from snowflake.snowpark.mock._connection import MockServerConnection from snowflake.snowpark.table import Table -from snowflake.snowpark.types import StructType, VariantType +from snowflake.snowpark.types import ( + StructType, + VariantType, + DateType, + DataType, + _NumericType, + TimestampType, +) +from snowflake.snowpark._internal.utils import ( + random_name_for_temp_object, +) # Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable # Python 3.9 can use both @@ -72,6 +116,8 @@ "TIMESTAMPFORMAT": "TIMESTAMP_FORMAT", } +MAX_RETRY_TIME = 3 + def _validate_stage_path(path: str) -> str: stripped_path = path.strip("\"'") @@ -1023,3 +1069,417 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame: df._reader = self set_api_call_source(df, f"DataFrameReader.{format.lower()}") return df + + @publicapi + def dbapi( + self, + create_connection: Callable[[], "Connection"], + table: str, + *, + column: Optional[str] = None, + lower_bound: Optional[Union[str, int]] = None, + upper_bound: Optional[Union[str, int]] = None, + num_partitions: Optional[int] = None, + max_workers: Optional[int] = None, + query_timeout: Optional[int] = 0, + fetch_size: Optional[int] = 0, + custom_schema: Optional[Union[str, StructType]] = None, + predicates: Optional[List[str]] = None, + session_init_statement: Optional[str] = None, + ) -> DataFrame: + """Reads data from a database table using a DBAPI connection.""" + statements_params_for_telemetry = {STATEMENT_PARAMS_DATA_SOURCE: "1"} + start_time = time.perf_counter() + conn = create_connection() + current_db, 
driver_info = detect_dbms(conn) + if custom_schema is None: + struct_schema = infer_data_source_schema(conn, table) + else: + if isinstance(custom_schema, str): + struct_schema = type_string_to_type_object(custom_schema) + if not isinstance(struct_schema, StructType): + raise ValueError( + f"Invalid schema string: {custom_schema}. " + f"You should provide a valid schema string representing a struct type." + ) + elif isinstance(custom_schema, StructType): + struct_schema = custom_schema + else: + raise TypeError(f"Invalid schema type: {type(custom_schema)}. ") + + select_query = generate_select_query(table, struct_schema, conn) + if column is None: + if ( + lower_bound is not None + or upper_bound is not None + or num_partitions is not None + ): + raise ValueError( + "when column is not specified, lower_bound, upper_bound, num_partitions are expected to be None" + ) + if predicates is None: + partitioned_queries = [select_query] + else: + partitioned_queries = generate_sql_with_predicates( + select_query, predicates + ) + else: + if lower_bound is None or upper_bound is None or num_partitions is None: + raise ValueError( + "when column is specified, lower_bound, upper_bound, num_partitions must be specified" + ) + + column_type = None + for field in struct_schema.fields: + if field.name.lower() == column.lower(): + column_type = field.datatype + if column_type is None: + raise ValueError("Column does not exist") + + if not isinstance(column_type, _NumericType) and not isinstance( + column_type, DateType + ): + raise ValueError(f"unsupported type {column_type}") + partitioned_queries = self._generate_partition( + select_query, + column_type, + column, + lower_bound, + upper_bound, + num_partitions, + ) + with tempfile.TemporaryDirectory() as tmp_dir: + # create temp table + snowflake_table_type = "temporary" + snowflake_table_name = random_name_for_temp_object(TempObjectType.TABLE) + create_table_sql = ( + "CREATE " + f"{get_temp_type_for_object(self._session._use_scoped_temp_objects, True) if snowflake_table_type.lower() in TEMPORARY_STRING_SET else snowflake_table_type} " + "TABLE " + f"{snowflake_table_name} " + f"""({" , ".join([f'"{field.name}" {convert_sp_to_sf_type(field.datatype)} {"NOT NULL" if not field.nullable else ""}' for field in struct_schema.fields])})""" + f"""{DATA_SOURCE_SQL_COMMENT}""" + ) + self._session.sql(create_table_sql).collect( + statement_params=statements_params_for_telemetry + ) + res_df = self.table(snowflake_table_name) + + # create temp stage + snowflake_stage_name = random_name_for_temp_object(TempObjectType.STAGE) + sql_create_temp_stage = ( + f"create {get_temp_type_for_object(self._session._use_scoped_temp_objects, True)} stage" + f" if not exists {snowflake_stage_name} {DATA_SOURCE_SQL_COMMENT}" + ) + self._session.sql(sql_create_temp_stage).collect( + statement_params=statements_params_for_telemetry + ) + with ProcessPoolExecutor( + max_workers=max_workers + ) as process_executor, ThreadPoolExecutor( + max_workers=max_workers + ) as thread_executor: + thread_pool_futures = [] + process_pool_futures = [ + process_executor.submit( + _task_fetch_from_data_source_with_retry, + create_connection, + query, + struct_schema, + i, + tmp_dir, + current_db, + driver_info, + query_timeout, + fetch_size, + session_init_statement, + ) + for i, query in enumerate(partitioned_queries) + ] + for future in as_completed(process_pool_futures): + try: + future.result() + except BaseException: + process_executor.shutdown(wait=False) + thread_executor.shutdown(wait=False) + 
raise + else: + thread_pool_futures.append( + thread_executor.submit( + self._upload_and_copy_into_table_with_retry, + future.result(), + snowflake_stage_name, + snowflake_table_name, + "abort_statement", + statements_params_for_telemetry, + ) + ) + completed_futures = wait(thread_pool_futures, return_when=ALL_COMPLETED) + for f in completed_futures.done: + try: + f.result() + except BaseException: + process_executor.shutdown(wait=False) + thread_executor.shutdown(wait=False) + raise + self._session._conn._telemetry_client.send_data_source_perf_telemetry( + DATA_SOURCE_DBAPI_SIGNATURE, time.perf_counter() - start_time + ) + set_api_call_source(res_df, DATA_SOURCE_DBAPI_SIGNATURE) + return res_df + + def _generate_partition( + self, + select_query: str, + column_type: DataType, + column: Optional[str] = None, + lower_bound: Optional[Union[str, int]] = None, + upper_bound: Optional[Union[str, int]] = None, + num_partitions: Optional[int] = None, + ) -> List[str]: + processed_lower_bound = self._to_internal_value(lower_bound, column_type) + processed_upper_bound = self._to_internal_value(upper_bound, column_type) + if processed_lower_bound > processed_upper_bound: + raise ValueError("lower_bound cannot be greater than upper_bound") + + if processed_lower_bound == processed_upper_bound or num_partitions <= 1: + return [select_query] + + if (processed_upper_bound - processed_lower_bound) >= num_partitions or ( + processed_upper_bound - processed_lower_bound + ) < 0: + actual_num_partitions = num_partitions + else: + actual_num_partitions = processed_upper_bound - processed_lower_bound + logger.warning( + "The number of partitions is reduced because the specified number of partitions is less than the difference between upper bound and lower bound." + ) + + # decide stride length + upper_stride = ( + processed_upper_bound / decimal.Decimal(actual_num_partitions) + ).quantize(decimal.Decimal("1e-18"), rounding=ROUND_HALF_EVEN) + lower_stride = ( + processed_lower_bound / decimal.Decimal(actual_num_partitions) + ).quantize(decimal.Decimal("1e-18"), rounding=ROUND_HALF_EVEN) + preciseStride = upper_stride - lower_stride + stride = int(preciseStride) + + lost_num_of_strides = ( + (preciseStride - decimal.Decimal(stride)) + * decimal.Decimal(actual_num_partitions) + / decimal.Decimal(stride) + ) + lower_bound_with_stride_alignment = processed_lower_bound + int( + (lost_num_of_strides / 2 * decimal.Decimal(stride)).quantize( + decimal.Decimal("1"), rounding=ROUND_HALF_UP + ) + ) + + current_value = lower_bound_with_stride_alignment + + partition_queries = [] + for i in range(actual_num_partitions): + l_bound = ( + f"{column} >= '{self._to_external_value(current_value, column_type)}'" + if i != 0 + else "" + ) + current_value += stride + u_bound = ( + f"{column} < '{self._to_external_value(current_value, column_type)}'" + if i != actual_num_partitions - 1 + else "" + ) + + if u_bound == "": + where_clause = l_bound + elif l_bound == "": + where_clause = f"{u_bound} OR {column} is null" + else: + where_clause = f"{l_bound} AND {u_bound}" + + partition_queries.append(select_query + f" WHERE {where_clause}") + + return partition_queries + + # this function is only used in data source API for SQL server + def _to_internal_value(self, value: Union[int, str, float], column_type: DataType): + if isinstance(column_type, _NumericType): + return int(value) + elif isinstance(column_type, (TimestampType, DateType)): + # TODO: SNOW-1909315: support timezone + dt = parser.parse(value) + return 
int(dt.replace(tzinfo=pytz.UTC).timestamp()) + else: + raise TypeError(f"unsupported column type for partition: {column_type}") + + # this function is only used in data source API for SQL server + def _to_external_value(self, value: Union[int, str, float], column_type: DataType): + if isinstance(column_type, _NumericType): + return value + elif isinstance(column_type, (TimestampType, DateType)): + # TODO: SNOW-1909315: support timezone + return datetime.datetime.fromtimestamp(value, tz=pytz.UTC) + else: + raise TypeError(f"unsupported column type for partition: {column_type}") + + def _upload_and_copy_into_table( + self, + local_file: str, + snowflake_stage_name: str, + snowflake_table_name: Optional[str] = None, + on_error: Optional[str] = "abort_statement", + statements_params: Optional[Dict[str, str]] = None, + ): + file_name = os.path.basename(local_file) + put_query = ( + f"PUT {normalize_local_file(local_file)} " + f"@{snowflake_stage_name} OVERWRITE=TRUE {DATA_SOURCE_SQL_COMMENT}" + ) + copy_into_table_query = f""" + COPY INTO {snowflake_table_name} FROM @{snowflake_stage_name}/{file_name} + FILE_FORMAT = (TYPE = PARQUET) + MATCH_BY_COLUMN_NAME=CASE_INSENSITIVE + PURGE=TRUE + ON_ERROR={on_error} + {DATA_SOURCE_SQL_COMMENT} + """ + self._session.sql(put_query).collect(statement_params=statements_params) + self._session.sql(copy_into_table_query).collect( + statement_params=statements_params + ) + + def _upload_and_copy_into_table_with_retry( + self, + local_file: str, + snowflake_stage_name: str, + snowflake_table_name: str, + on_error: Optional[str] = "abort_statement", + statements_params: Optional[Dict[str, str]] = None, + ): + retry_count = 0 + last_error = None + error_trace = "" + while retry_count < MAX_RETRY_TIME: + try: + self._upload_and_copy_into_table( + local_file, + snowflake_stage_name, + snowflake_table_name, + on_error, + statements_params, + ) + return + except Exception as e: + last_error = e + error_trace = traceback.format_exc() + retry_count += 1 + logger.debug( + f"Attempt {retry_count}/{MAX_RETRY_TIME} failed with {type(last_error).__name__}: {str(last_error)}. Retrying..." 
+ ) + + final_error = SnowparkDataframeReaderException( + message=( + f"Failed to load data to snowflake after {MAX_RETRY_TIME} attempts.\n" + f"Last error: [{type(last_error).__name__}] {str(last_error)}\n" + f"Traceback:\n{error_trace}" + ) + ) + logger.error( + f"Failed to load data to snowflake after {MAX_RETRY_TIME} attempts.\n" + f"Last encountered error: [{type(last_error).__name__}] {str(last_error)}" + ) + raise final_error + + +def _task_fetch_from_data_source( + create_connection: Callable[[], "Connection"], + query: str, + schema: StructType, + i: int, + tmp_dir: str, + current_db: DBMS_TYPE, + driver_info: str, + query_timeout: int = 0, + fetch_size: int = 0, + session_init_statement: Optional[str] = None, +) -> str: + conn = create_connection() + # this is specified to pyodbc, need other way to manage timeout on other drivers + if current_db == DBMS_TYPE.SQL_SERVER_DB: + conn.timeout = query_timeout + result = [] + cursor = conn.cursor() + if session_init_statement: + cursor.execute(session_init_statement) + if fetch_size == 0: + cursor.execute(query) + result = cursor.fetchall() + elif fetch_size > 0: + cursor = cursor.execute(query) + rows = cursor.fetchmany(fetch_size) + while rows: + result.extend(rows) + rows = cursor.fetchmany(fetch_size) + else: + raise ValueError("fetch size cannot be smaller than 0") + + df = data_source_data_to_pandas_df(result, schema, current_db, driver_info) + path = os.path.join(tmp_dir, f"data_{i}.parquet") + df.to_parquet(path) + return path + + +def _task_fetch_from_data_source_with_retry( + create_connection: Callable[[], "Connection"], + query: str, + schema: StructType, + i: int, + tmp_dir: str, + current_db: DBMS_TYPE, + driver_info: str, + query_timeout: int = 0, + fetch_size: int = 0, + session_init_statement: Optional[str] = None, +) -> str: + retry_count = 0 + last_error = None + error_trace = "" + while retry_count < MAX_RETRY_TIME: + try: + path = _task_fetch_from_data_source( + create_connection, + query, + schema, + i, + tmp_dir, + current_db, + driver_info, + query_timeout, + fetch_size, + session_init_statement, + ) + return path + except Exception as e: + last_error = e + error_trace = traceback.format_exc() + retry_count += 1 + logger.debug( + f"Attempt {retry_count}/{MAX_RETRY_TIME} failed with {type(last_error).__name__}: {str(last_error)}. Retrying..." 
+ ) + + final_error = SnowparkDataframeReaderException( + message=( + f"Failed to fetch from data source after {MAX_RETRY_TIME} attempts.\n" + f"Last error: [{type(last_error).__name__}] {str(last_error)}\n" + f"Traceback:\n{error_trace}" + ) + ) + + logger.error( + f"Failed to fetch from data source after {MAX_RETRY_TIME} attempts.\n" + f"Last encountered error: [{type(last_error).__name__}] {str(last_error)}" + ) + + raise final_error diff --git a/src/snowflake/snowpark/dataframe_writer.py b/src/snowflake/snowpark/dataframe_writer.py index c3da66389ec..02a3d328203 100644 --- a/src/snowflake/snowpark/dataframe_writer.py +++ b/src/snowflake/snowpark/dataframe_writer.py @@ -25,6 +25,10 @@ DATAFRAME_AST_PARAMETER, build_sp_table_name, ) +from snowflake.snowpark._internal.data_source_utils import ( + DATA_SOURCE_DBAPI_SIGNATURE, + STATEMENT_PARAMS_DATA_SOURCE, +) from snowflake.snowpark._internal.open_telemetry import open_telemetry_context_manager from snowflake.snowpark._internal.telemetry import ( add_api_call, @@ -95,6 +99,24 @@ def __init__( self.__format: Optional[str] = None self._ast_stmt = _ast_stmt + @staticmethod + def _track_data_source_statement_params( + dataframe, statement_params: Optional[Dict] = None + ) -> Optional[Dict]: + """ + Helper method to initialize and update data source tracking statement_params based on dataframe attributes. + """ + statement_params = statement_params or {} + if ( + dataframe._plan + and dataframe._plan.api_calls + and dataframe._plan.api_calls[0].get("name") == DATA_SOURCE_DBAPI_SIGNATURE + ): + # Track data source ingestion + statement_params[STATEMENT_PARAMS_DATA_SOURCE] = "1" + + return statement_params if statement_params else None + @publicapi def mode(self, save_mode: str, _emit_ast: bool = True) -> "DataFrameWriter": """Set the save mode of this :class:`DataFrameWriter`. 
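For orientation, here is a minimal usage sketch tying the reader and writer changes together. The DSN, credentials, and table names are hypothetical placeholders; only DataFrameReader.dbapi and the SNOWPARK_PYTHON_DATASOURCE statement parameter come from this change, and session is assumed to be an existing Snowpark Session:

    import pyodbc  # assumed driver; oracledb and sqlite3 connections are also recognized by detect_dbms

    def create_connection():
        # zero-argument callable returning a DBAPI2 connection; the DSN below is a placeholder
        return pyodbc.connect("DSN=my_sql_server;UID=user;PWD=***")

    df = session.read.dbapi(create_connection, "MY_TABLE", max_workers=4)

    # Because the plan's first API call is DataFrameReader.dbapi, the
    # _track_data_source_statement_params helper added above tags the statement
    # parameters with SNOWPARK_PYTHON_DATASOURCE="1" when the hunks below run
    # save_as_table or copy_into_location.
    df.write.save_as_table("my_ingested_table")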
@@ -444,6 +466,9 @@ def save_as_table( else: table_exists = None + statement_params = self._track_data_source_statement_params( + self._dataframe, statement_params or self._dataframe._statement_params + ) create_table_logic_plan = SnowflakeCreateTable( table_name, column_names, @@ -461,10 +486,11 @@ def save_as_table( iceberg_config, table_exists, ) + snowflake_plan = session._analyzer.resolve(create_table_logic_plan) result = session._conn.execute( snowflake_plan, - _statement_params=statement_params or self._dataframe._statement_params, + _statement_params=statement_params, block=block, data_type=_AsyncResultType.NO_RESULT, **kwargs, @@ -624,6 +650,10 @@ def copy_into_location( cur_format_type_options.update(format_type_aliased_options) + statement_params = self._track_data_source_statement_params( + self._dataframe, statement_params or self._dataframe._statement_params + ) + df = self._dataframe._with_plan( CopyIntoLocationNode( self._dataframe._plan, @@ -638,7 +668,7 @@ def copy_into_location( ) add_api_call(df, "DataFrameWriter.copy_into_location") return df._internal_collect_with_tag( - statement_params=statement_params or self._dataframe._statement_params, + statement_params=statement_params, block=block, **kwargs, ) diff --git a/src/snowflake/snowpark/exceptions.py b/src/snowflake/snowpark/exceptions.py index 2035d2ed6da..6a437435b33 100644 --- a/src/snowflake/snowpark/exceptions.py +++ b/src/snowflake/snowpark/exceptions.py @@ -35,6 +35,9 @@ def __repr__(self): def __str__(self): return self._pretty_msg + def __reduce__(self): + return (self.__class__, (self.message,), {"error_code": self.error_code}) + class _SnowparkInternalException(SnowparkClientException): """Exception for internal errors. For internal use only. diff --git a/tests/integ/test_data_source_api.py b/tests/integ/test_data_source_api.py new file mode 100644 index 00000000000..05defe49e0b --- /dev/null +++ b/tests/integ/test_data_source_api.py @@ -0,0 +1,551 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. 
+# +import functools +import os +import tempfile +import time +import datetime +from unittest import mock +import pytest + +from snowflake.snowpark._internal.utils import ( + TempObjectType, +) +from snowflake.snowpark.dataframe_reader import ( + _task_fetch_from_data_source_with_retry, + MAX_RETRY_TIME, +) +from snowflake.snowpark._internal.data_source_utils import ( + DATA_SOURCE_DBAPI_SIGNATURE, + DATA_SOURCE_SQL_COMMENT, + STATEMENT_PARAMS_DATA_SOURCE, + DBMS_TYPE, + generate_sql_with_predicates, +) +from snowflake.snowpark.exceptions import SnowparkDataframeReaderException +from snowflake.snowpark.types import ( + StructType, + StructField, + IntegerType, + DateType, + MapType, + FloatType, + StringType, + BinaryType, + NullType, + TimestampType, + TimeType, + ShortType, + LongType, + DoubleType, + DecimalType, + ArrayType, + VariantType, +) +from tests.resources.test_data_source_dir.test_data_source_data import ( + sql_server_all_type_data, + sql_server_all_type_small_data, + sql_server_create_connection, + sql_server_create_connection_small_data, + sqlite3_db, + create_connection_to_sqlite3_db, + oracledb_all_type_data_result, + oracledb_create_connection, + oracledb_all_type_small_data_result, + oracledb_create_connection_small_data, + fake_detect_dbms_pyodbc, +) +from tests.utils import Utils, IS_WINDOWS + +pytestmark = pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="feature not available in local testing", +) + +SQL_SERVER_TABLE_NAME = "AllDataTypesTable" +ORACLEDB_TABLE_NAME = "ALL_TYPES_TABLE" + + +def fake_task_fetch_from_data_source_with_retry( + create_connection, + query, + schema, + i, + tmp_dir, + current_db, + driver_info, + query_timeout, + fetch_size, + session_init_statement, +): + time.sleep(2) + + +def upload_and_copy_into_table_with_retry( + self, + local_file, + snowflake_stage_name, + snowflake_table_name, + on_error, +): + time.sleep(2) + + +def test_dbapi_with_temp_table(session): + with mock.patch( + "snowflake.snowpark._internal.data_source_utils.detect_dbms_pyodbc", + new=fake_detect_dbms_pyodbc, + ): + df = session.read.dbapi( + sql_server_create_connection, SQL_SERVER_TABLE_NAME, max_workers=4 + ) + assert df.collect() == sql_server_all_type_data + + +def test_dbapi_oracledb(session): + with mock.patch( + "snowflake.snowpark._internal.data_source_utils.detect_dbms_pyodbc", + new=fake_detect_dbms_pyodbc, + ): + df = session.read.dbapi( + oracledb_create_connection, ORACLEDB_TABLE_NAME, max_workers=4 + ) + assert df.collect() == oracledb_all_type_data_result + + +def test_dbapi_batch_fetch_oracledb(session): + with mock.patch( + "snowflake.snowpark._internal.data_source_utils.detect_dbms_pyodbc", + new=fake_detect_dbms_pyodbc, + ): + df = session.read.dbapi( + oracledb_create_connection, ORACLEDB_TABLE_NAME, max_workers=4, fetch_size=1 + ) + assert df.collect() == oracledb_all_type_data_result + + df = session.read.dbapi( + oracledb_create_connection, ORACLEDB_TABLE_NAME, max_workers=4, fetch_size=3 + ) + assert df.collect() == oracledb_all_type_data_result + + df = session.read.dbapi( + oracledb_create_connection_small_data, + ORACLEDB_TABLE_NAME, + max_workers=4, + fetch_size=1, + ) + assert df.collect() == oracledb_all_type_small_data_result + + df = session.read.dbapi( + oracledb_create_connection_small_data, + ORACLEDB_TABLE_NAME, + max_workers=4, + fetch_size=3, + ) + assert df.collect() == oracledb_all_type_small_data_result + + +def test_dbapi_batch_fetch(session): + with mock.patch( + 
"snowflake.snowpark._internal.data_source_utils.detect_dbms_pyodbc", + new=fake_detect_dbms_pyodbc, + ): + df = session.read.dbapi( + sql_server_create_connection, + SQL_SERVER_TABLE_NAME, + max_workers=4, + fetch_size=1, + ) + assert df.collect() == sql_server_all_type_data + + df = session.read.dbapi( + sql_server_create_connection, + SQL_SERVER_TABLE_NAME, + max_workers=4, + fetch_size=3, + ) + assert df.collect() == sql_server_all_type_data + + df = session.read.dbapi( + sql_server_create_connection_small_data, + SQL_SERVER_TABLE_NAME, + max_workers=4, + fetch_size=1, + ) + assert df.collect() == sql_server_all_type_small_data + + df = session.read.dbapi( + sql_server_create_connection_small_data, + SQL_SERVER_TABLE_NAME, + max_workers=4, + fetch_size=3, + ) + assert df.collect() == sql_server_all_type_small_data + + +def test_dbapi_retry(session): + + with mock.patch( + "snowflake.snowpark._internal.data_source_utils.detect_dbms_pyodbc", + new=fake_detect_dbms_pyodbc, + ), mock.patch( + "snowflake.snowpark.dataframe_reader._task_fetch_from_data_source", + side_effect=RuntimeError("Test error"), + ) as mock_task: + with pytest.raises( + SnowparkDataframeReaderException, match="\\[RuntimeError\\] Test error" + ): + _task_fetch_from_data_source_with_retry( + create_connection=sql_server_create_connection, + query="SELECT * FROM test_table", + schema=StructType([StructField("col1", IntegerType(), False)]), + i=0, + tmp_dir="/tmp", + current_db=DBMS_TYPE.SQL_SERVER_DB, + driver_info="pyodbc", + ) + assert mock_task.call_count == MAX_RETRY_TIME + + with mock.patch( + "snowflake.snowpark._internal.data_source_utils.detect_dbms_pyodbc", + new=fake_detect_dbms_pyodbc, + ), mock.patch( + "snowflake.snowpark.dataframe_reader.DataFrameReader._upload_and_copy_into_table", + side_effect=RuntimeError("Test error"), + ) as mock_task: + with pytest.raises( + SnowparkDataframeReaderException, match="\\[RuntimeError\\] Test error" + ): + session.read._upload_and_copy_into_table_with_retry( + local_file="fake_file", + snowflake_stage_name="fake_stage", + snowflake_table_name="fake_table", + ) + assert mock_task.call_count == MAX_RETRY_TIME + + +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="feature not available in local testing", +) +def test_parallel(session): + num_partitions = 3 + # this test meant to test whether ingest is fully parallelized + # we cannot mock this function as process pool does not all mock object + with mock.patch( + "snowflake.snowpark.dataframe_reader._task_fetch_from_data_source_with_retry", + new=fake_task_fetch_from_data_source_with_retry, + ), mock.patch( + "snowflake.snowpark._internal.data_source_utils.detect_dbms_pyodbc", + new=fake_detect_dbms_pyodbc, + ), mock.patch( + "snowflake.snowpark.dataframe_reader.DataFrameReader._upload_and_copy_into_table_with_retry", + wrap=upload_and_copy_into_table_with_retry, + ) as mock_upload_and_copy: + + start = time.time() + session.read.dbapi( + sql_server_create_connection, + SQL_SERVER_TABLE_NAME, + column="Id", + upper_bound=100, + lower_bound=0, + num_partitions=num_partitions, + max_workers=4, + ) + end = time.time() + # totally time without parallel is 12 seconds + assert end - start < 12 + # verify that mocked function is called for each partition + assert mock_upload_and_copy.call_count == num_partitions + + +def test_partition_logic(session): + expected_queries1 = [ + "SELECT * FROM fake_table WHERE ID < '8' OR ID is null", + "SELECT * FROM fake_table WHERE ID >= '8' AND ID < '10'", 
+ "SELECT * FROM fake_table WHERE ID >= '10' AND ID < '12'", + "SELECT * FROM fake_table WHERE ID >= '12'", + ] + + queries = session.read._generate_partition( + select_query="SELECT * FROM fake_table", + column_type=IntegerType(), + column="ID", + lower_bound=5, + upper_bound=15, + num_partitions=4, + ) + for r, expected_r in zip(queries, expected_queries1): + assert r == expected_r + + expected_queries2 = [ + "SELECT * FROM fake_table WHERE ID < '-2' OR ID is null", + "SELECT * FROM fake_table WHERE ID >= '-2' AND ID < '0'", + "SELECT * FROM fake_table WHERE ID >= '0' AND ID < '2'", + "SELECT * FROM fake_table WHERE ID >= '2'", + ] + + queries = session.read._generate_partition( + select_query="SELECT * FROM fake_table", + column_type=IntegerType(), + column="ID", + lower_bound=-5, + upper_bound=5, + num_partitions=4, + ) + for r, expected_r in zip(queries, expected_queries2): + assert r == expected_r + + expected_queries3 = [ + "SELECT * FROM fake_table", + ] + + queries = session.read._generate_partition( + select_query="SELECT * FROM fake_table", + column_type=IntegerType(), + column="ID", + lower_bound=5, + upper_bound=15, + num_partitions=1, + ) + for r, expected_r in zip(queries, expected_queries3): + assert r == expected_r + + expected_queries4 = [ + "SELECT * FROM fake_table WHERE ID < '6' OR ID is null", + "SELECT * FROM fake_table WHERE ID >= '6' AND ID < '7'", + "SELECT * FROM fake_table WHERE ID >= '7' AND ID < '8'", + "SELECT * FROM fake_table WHERE ID >= '8' AND ID < '9'", + "SELECT * FROM fake_table WHERE ID >= '9' AND ID < '10'", + "SELECT * FROM fake_table WHERE ID >= '10' AND ID < '11'", + "SELECT * FROM fake_table WHERE ID >= '11' AND ID < '12'", + "SELECT * FROM fake_table WHERE ID >= '12' AND ID < '13'", + "SELECT * FROM fake_table WHERE ID >= '13' AND ID < '14'", + "SELECT * FROM fake_table WHERE ID >= '14'", + ] + + queries = session.read._generate_partition( + select_query="SELECT * FROM fake_table", + column_type=IntegerType(), + column="ID", + lower_bound=5, + upper_bound=15, + num_partitions=10, + ) + for r, expected_r in zip(queries, expected_queries4): + assert r == expected_r + + expected_queries5 = [ + "SELECT * FROM fake_table WHERE ID < '8' OR ID is null", + "SELECT * FROM fake_table WHERE ID >= '8' AND ID < '11'", + "SELECT * FROM fake_table WHERE ID >= '11'", + ] + + queries = session.read._generate_partition( + select_query="SELECT * FROM fake_table", + column_type=IntegerType(), + column="ID", + lower_bound=5, + upper_bound=15, + num_partitions=3, + ) + for r, expected_r in zip(queries, expected_queries5): + assert r == expected_r + + +def test_partition_date_timestamp(session): + expected_queries1 = [ + "SELECT * FROM fake_table WHERE DATE < '2020-07-30 18:00:00+00:00' OR DATE is null", + "SELECT * FROM fake_table WHERE DATE >= '2020-07-30 18:00:00+00:00' AND DATE < '2020-09-14 12:00:00+00:00'", + "SELECT * FROM fake_table WHERE DATE >= '2020-09-14 12:00:00+00:00' AND DATE < '2020-10-30 06:00:00+00:00'", + "SELECT * FROM fake_table WHERE DATE >= '2020-10-30 06:00:00+00:00'", + ] + queries = session.read._generate_partition( + select_query="SELECT * FROM fake_table", + column_type=DateType(), + column="DATE", + lower_bound=str(datetime.date(2020, 6, 15)), + upper_bound=str(datetime.date(2020, 12, 15)), + num_partitions=4, + ) + + for r, expected_r in zip(queries, expected_queries1): + assert r == expected_r + + expected_queries2 = [ + "SELECT * FROM fake_table WHERE DATE < '2020-07-31 05:06:13+00:00' OR DATE is null", + "SELECT * FROM fake_table WHERE 
DATE >= '2020-07-31 05:06:13+00:00' AND DATE < '2020-09-14 21:46:55+00:00'", + "SELECT * FROM fake_table WHERE DATE >= '2020-09-14 21:46:55+00:00' AND DATE < '2020-10-30 14:27:37+00:00'", + "SELECT * FROM fake_table WHERE DATE >= '2020-10-30 14:27:37+00:00'", + ] + queries = session.read._generate_partition( + select_query="SELECT * FROM fake_table", + column_type=DateType(), + column="DATE", + lower_bound=str(datetime.datetime(2020, 6, 15, 12, 25, 30)), + upper_bound=str(datetime.datetime(2020, 12, 15, 7, 8, 20)), + num_partitions=4, + ) + + for r, expected_r in zip(queries, expected_queries2): + assert r == expected_r + + +def test_partition_unsupported_type(session): + with pytest.raises(TypeError, match="unsupported column type for partition:"): + session.read._generate_partition( + select_query="SELECT * FROM fake_table", + column_type=MapType(), + column="DATE", + lower_bound=0, + upper_bound=1, + num_partitions=4, + ) + + +def test_telemetry_tracking(caplog, session): + original_func = session._conn.run_query + called, comment_showed = 0, 0 + + def assert_datasource_statement_params_run_query(*args, **kwargs): + # assert we set statement_parameters to track datasourcee api usage + nonlocal comment_showed + statement_parameters = kwargs.get("_statement_params") + query = args[0] + assert statement_parameters[STATEMENT_PARAMS_DATA_SOURCE] == "1" + if "select" not in query.lower(): + assert DATA_SOURCE_SQL_COMMENT in query + comment_showed += 1 + nonlocal called + called += 1 + return original_func(*args, **kwargs) + + with mock.patch( + "snowflake.snowpark._internal.server_connection.ServerConnection.run_query", + side_effect=assert_datasource_statement_params_run_query, + ), mock.patch( + "snowflake.snowpark._internal.data_source_utils.detect_dbms_pyodbc", + new=fake_detect_dbms_pyodbc, + ), mock.patch( + "snowflake.snowpark._internal.telemetry.TelemetryClient.send_performance_telemetry" + ) as mock_telemetry: + df = session.read.dbapi(sql_server_create_connection, SQL_SERVER_TABLE_NAME) + assert df._plan.api_calls == [{"name": DATA_SOURCE_DBAPI_SIGNATURE}] + assert ( + called == 4 and comment_showed == 4 + ) # 4 queries: create table, create stage, put file, copy into + assert mock_telemetry.called + assert df.collect() == sql_server_all_type_data + + # assert when we save/copy, the statement_params is added + temp_table = Utils.random_name_for_temp_object(TempObjectType.TABLE) + temp_stage = Utils.random_name_for_temp_object(TempObjectType.STAGE) + Utils.create_stage(session, temp_stage, is_temporary=True) + called = 0 + with mock.patch( + "snowflake.snowpark._internal.server_connection.ServerConnection.run_query", + side_effect=assert_datasource_statement_params_run_query, + ): + df.write.save_as_table(temp_table) + df.write.copy_into_location( + f"{temp_stage}/test.parquet", + file_format_type="parquet", + header=True, + overwrite=True, + single=True, + ) + assert called == 2 + + +@pytest.mark.skipif( + IS_WINDOWS, + reason="sqlite3 file can not be shared accorss processes on windows", +) +@pytest.mark.parametrize( + "custom_schema", + [ + "id INTEGER, int_col INTEGER, real_col FLOAT, text_col STRING, blob_col BINARY, null_col STRING, ts_col TIMESTAMP, date_col DATE, time_col TIME, short_col SHORT, long_col LONG, double_col DOUBLE, decimal_col DECIMAL, map_col MAP, array_col ARRAY, var_col VARIANT", + StructType( + [ + StructField("id", IntegerType()), + StructField("int_col", IntegerType()), + StructField("real_col", FloatType()), + StructField("text_col", StringType()), + 
StructField("blob_col", BinaryType()), + StructField("null_col", NullType()), + StructField("ts_col", TimestampType()), + StructField("date_col", DateType()), + StructField("time_col", TimeType()), + StructField("short_col", ShortType()), + StructField("long_col", LongType()), + StructField("double_col", DoubleType()), + StructField("decimal_col", DecimalType()), + StructField("map_col", MapType()), + StructField("array_col", ArrayType()), + StructField("var_col", VariantType()), + ] + ), + ], +) +def test_custom_schema(session, custom_schema): + with tempfile.TemporaryDirectory() as temp_dir: + dbpath = os.path.join(temp_dir, "testsqlite3.db") + table_name, columns, example_data, assert_data = sqlite3_db(dbpath) + + df = session.read.dbapi( + functools.partial(create_connection_to_sqlite3_db, dbpath), + table_name, + custom_schema=custom_schema, + ) + assert df.columns == [col.upper() for col in columns] + assert df.collect() == assert_data + + with pytest.raises( + SnowparkDataframeReaderException, match="Unable to infer schema" + ): + session.read.dbapi( + functools.partial(create_connection_to_sqlite3_db, dbpath), + table_name, + ) + + +def test_predicates(): + select_query = "select * from fake_table" + predicates = ["id > 1 AND id <= 1000", "id > 1001 AND id <= 2000", "id > 2001"] + expected_result = [ + "select * from fake_table WHERE id > 1 AND id <= 1000", + "select * from fake_table WHERE id > 1001 AND id <= 2000", + "select * from fake_table WHERE id > 2001", + ] + res = generate_sql_with_predicates(select_query, predicates) + assert res == expected_result + + +@pytest.mark.skipif( + IS_WINDOWS, + reason="sqlite3 file can not be shared across processes on windows", +) +def test_session_init_statement(session): + with tempfile.TemporaryDirectory() as temp_dir: + dbpath = os.path.join(temp_dir, "testsqlite3.db") + table_name, _, _, assert_data = sqlite3_db(dbpath) + + df = session.read.dbapi( + functools.partial(create_connection_to_sqlite3_db, dbpath), + table_name, + custom_schema="id INTEGER, int_col INTEGER, real_col FLOAT, text_col STRING, blob_col BINARY, null_col STRING, ts_col TIMESTAMP, date_col DATE, time_col TIME, short_col SHORT, long_col LONG, double_col DOUBLE, decimal_col DECIMAL, map_col MAP, array_col ARRAY, var_col VARIANT", + session_init_statement="SELECT 1;", + ) + assert df.collect() == assert_data + + with pytest.raises( + SnowparkDataframeReaderException, match='near "FROM": syntax error' + ): + session.read.dbapi( + functools.partial(create_connection_to_sqlite3_db, dbpath), + table_name, + custom_schema="id INTEGER", + session_init_statement="SELECT FROM NOTHING;", + ) diff --git a/tests/resources/test_data_source_dir/test_data_source_data.py b/tests/resources/test_data_source_dir/test_data_source_data.py new file mode 100644 index 00000000000..a77aa6c00f1 --- /dev/null +++ b/tests/resources/test_data_source_dir/test_data_source_data.py @@ -0,0 +1,823 @@ +import datetime +import sqlite3 +from _decimal import Decimal +from dateutil import parser + +from snowflake.snowpark._internal.data_source_utils import DBMS_TYPE + + +# we manually mock these objects because mock object cannot be used in multi-process as they are not pickleable +class FakeConnection: + def __init__(self, data, schema, connection_type) -> None: + self.__class__.__module__ = connection_type + self.sql = "" + self.start_index = 0 + self.data = data + self.schema = schema + + def cursor(self): + return self + + def execute(self, sql: str): + self.sql = sql + return self + + def 
fetchall(self): + if "INFORMATION_SCHEMA" in self.sql or "USER_TAB_COLUMNS" in self.sql: + return self.schema + else: + return self.data + + def fetchmany(self, row_count: int): + end_index = self.start_index + row_count + res = ( + self.data[self.start_index : end_index] + if end_index < len(self.data) + else self.data[self.start_index :] + ) + self.start_index = end_index + return res + + def getinfo(self, sql_dbms_name): + return "sqlserver" + + +class LOB: + def __init__(self, value) -> None: + self.value = value + + def read(self, offset: int = 1, amount: int = None): + return self.value + + +def fake_detect_dbms_pyodbc(conn): + return DBMS_TYPE.SQL_SERVER_DB + + +sql_server_all_type_schema = ( + ("Id", "int", 10, 0, "NO"), + ("SmallIntCol", "smallint", 5, 0, "YES"), + ("TinyIntCol", "tinyint", 3, 0, "YES"), + ("BigIntCol", "bigint", 19, 0, "YES"), + ("DecimalCol", "decimal", 10, 2, "YES"), + ("FloatCol", "float", 53, None, "YES"), + ("RealCol", "real", 24, None, "YES"), + ("MoneyCol", "money", 19, 4, "YES"), + ("SmallMoneyCol", "smallmoney", 10, 4, "YES"), + ("CharCol", "char", None, None, "YES"), + ("VarCharCol", "varchar", None, None, "YES"), + ("TextCol", "text", None, None, "YES"), + ("NCharCol", "nchar", None, None, "YES"), + ("NVarCharCol", "nvarchar", None, None, "YES"), + ("NTextCol", "ntext", None, None, "YES"), + ("DateCol", "date", None, None, "YES"), + ("TimeCol", "time", None, None, "YES"), + ("DateTimeCol", "datetime", None, None, "YES"), + ("DateTime2Col", "datetime2", None, None, "YES"), + ("SmallDateTimeCol", "smalldatetime", None, None, "YES"), + ("BinaryCol", "binary", None, None, "YES"), + ("VarBinaryCol", "varbinary", None, None, "YES"), + ("BitCol", "bit", None, None, "YES"), + ("UniqueIdentifierCol", "uniqueidentifier", None, None, "YES"), +) + +oracledb_all_type_schema = ( + ("ID", "NUMBER", None, None, "N"), + ("NUMBER_COL", "NUMBER", 10, 2, "Y"), + ("BINARY_FLOAT_COL", "BINARY_FLOAT", None, None, "Y"), + ("BINARY_DOUBLE_COL", "BINARY_DOUBLE", None, None, "Y"), + ("VARCHAR2_COL", "VARCHAR2", None, None, "Y"), + ("CHAR_COL", "CHAR", None, None, "Y"), + ("CLOB_COL", "CLOB", None, None, "Y"), + ("NCHAR_COL", "NCHAR", None, None, "Y"), + ("NVARCHAR2_COL", "NVARCHAR2", None, None, "Y"), + ("NCLOB_COL", "NCLOB", None, None, "Y"), + ("DATE_COL", "DATE", None, None, "Y"), + ("TIMESTAMP_COL", "TIMESTAMP(6)", None, 6, "Y"), + ("TIMESTAMP_TZ_COL", "TIMESTAMP(6) WITH TIME ZONE", None, 6, "Y"), + ("TIMESTAMP_LTZ_COL", "TIMESTAMP(6) WITH LOCAL TIME ZONE", None, 6, "Y"), + ("BLOB_COL", "BLOB", None, None, "Y"), + ("RAW_COL", "RAW", None, None, "Y"), +) + +oracledb_all_type_data = [ + ( + 1, + 123.45, + 123.0, + 12345678900.0, + "Sample1", + "Char1 ", + LOB("Large text data 1"), + "Hello ", + "World", + LOB("sample text 1"), + datetime.datetime(2024, 1, 1, 0, 0), + datetime.datetime(2024, 1, 1, 12, 0), + "2024-01-01 12:00:00.000000000 -0800", + "2024-01-01 12:00:00.000000000 -0800", + None, + b"Binary1", + ), + ( + 2, + 234.56, + 234.0, + 234567890000.0, + "Sample2", + "Char2 ", + LOB("Large text data 2"), + "Goodbye ", + "Everyone", + LOB("sample text 2"), + datetime.datetime(2024, 1, 2, 0, 0), + datetime.datetime(2024, 1, 2, 13, 30), + "2024-01-02 13:30:00.000000000 -0800", + "2024-01-02 13:30:00.000000000 -0800", + None, + b"Binary2", + ), + ( + 3, + 345.67, + 345.0, + 3456789000000.0, + "Sample3", + "Char3 ", + LOB("Large text data 3"), + "Morning ", + "Sunrise", + LOB("sample text 3"), + datetime.datetime(2024, 1, 3, 0, 0), + datetime.datetime(2024, 1, 3, 8, 15), + 
"2024-01-03 08:15:00.000000000 -0800", + "2024-01-03 08:15:00.000000000 -0800", + None, + b"Binary3", + ), + ( + 4, + 456.78, + 456.0, + 45678900000000.0, + "Sample4", + "Char4 ", + LOB("Large text data 4"), + "Afternoon ", + "Clouds", + LOB("sample text 4"), + datetime.datetime(2024, 1, 4, 0, 0), + datetime.datetime(2024, 1, 4, 14, 45), + "2024-01-04 14:45:00.000000000 -0800", + "2024-01-04 14:45:00.000000000 -0800", + None, + b"Binary4", + ), + ( + 5, + 567.89, + 567.0, + 567890000000000.0, + "Sample5", + "Char5 ", + LOB("Large text data 5"), + "Evening ", + "Stars", + LOB("sample text 5"), + datetime.datetime(2024, 1, 5, 0, 0), + datetime.datetime(2024, 1, 5, 19, 0), + "2024-01-05 19:00:00.000000000 -0800", + "2024-01-05 19:00:00.000000000 -0800", + None, + b"Binary5", + ), + ( + 6, + 678.9, + 678.0, + 6789000000000000.0, + "Sample6", + "Char6 ", + LOB("Large text data 6"), + "Night ", + "Moon", + LOB("sample text 6"), + datetime.datetime(2024, 1, 6, 0, 0), + datetime.datetime(2024, 1, 6, 23, 59), + "2024-01-06 23:59:00.000000000 -0800", + "2024-01-06 23:59:00.000000000 -0800", + None, + b"Binary6", + ), + ( + 7, + 789.01, + 789.0, + 7.89e16, + "Sample7", + "Char7 ", + LOB("Large text data 7"), + "Dawn ", + "Mist", + LOB("sample text 7"), + datetime.datetime(2024, 1, 7, 0, 0), + datetime.datetime(2024, 1, 7, 4, 30), + "2024-01-07 04:30:00.000000000 -0800", + "2024-01-07 04:30:00.000000000 -0800", + None, + b"Binary7", + ), + ( + 8, + 890.12, + 890.0, + 8.9e17, + "Sample8", + "Char8 ", + LOB("Large text data 8"), + "Midday ", + "Heat", + LOB("sample text 8"), + datetime.datetime(2024, 1, 8, 0, 0), + datetime.datetime(2024, 1, 8, 12, 0), + "2024-01-08 12:00:00.000000000 -0800", + "2024-01-08 12:00:00.000000000 -0800", + None, + b"Binary8", + ), + ( + 9, + 901.23, + 901.0, + 9.01e18, + "Sample9", + "Char9 ", + LOB("Large text data 9"), + "Sunset ", + "Horizon", + LOB("sample text 9"), + datetime.datetime(2024, 1, 9, 0, 0), + datetime.datetime(2024, 1, 9, 18, 45), + "2024-01-09 18:45:00.000000000 -0800", + "2024-01-09 18:45:00.000000000 -0800", + None, + b"Binary9", + ), + ( + 10, + 1012.34, + 1010.0, + 1.01e19, + "Sample10", + "Char10 ", + LOB("Large text data 10"), + "Twilight ", + "Calm", + LOB("sample text 10"), + datetime.datetime(2024, 1, 10, 0, 0), + datetime.datetime(2024, 1, 10, 21, 15), + "2024-01-10 21:15:00.000000000 -0800", + "2024-01-10 21:15:00.000000000 -0800", + None, + b"Binary10", + ), +] + +oracledb_all_type_data_result = [] +for row in oracledb_all_type_data: + new_row = [] + for i, item in enumerate(row): + if i == 6 or i == 9: + new_row.append(item.read()) + elif i == 1: + new_row.append(Decimal(str(item))) + elif i == 12 or i == 13: + new_row.append(parser.parse(item)) + elif i == 10: + new_row.append(item.date()) + else: + new_row.append(item) + oracledb_all_type_data_result.append(tuple(new_row)) +sql_server_all_type_data = [ + ( + 1, + 100, + 10, + 100000, + Decimal("12345.67"), + 1.23, + 0.4560000002384186, + Decimal("1234.5600"), + Decimal("12.3400"), + "FixedStr1 ", + "VarStr1", + "Text1", + "UniFix1 ", + "UniVar1", + "UniText1", + datetime.date(2023, 1, 1), + datetime.time(12, 0), + datetime.datetime(2023, 1, 1, 12, 0), + datetime.datetime(2023, 1, 1, 12, 0, 0, 123000), + datetime.datetime(2023, 1, 1, 12, 0), + b"\x01\x02\x03\x04\x05", + b"\x01\x02\x03\x04", + True, + "06D48351-6EA7-4E64-81A2-9921F0EC42A5", + ), + ( + 2, + 200, + 20, + 200000, + Decimal("23456.78"), + 2.34, + 1.5670000314712524, + Decimal("2345.6700"), + Decimal("23.4500"), + "FixedStr2 ", + 
"VarStr2", + "Text2", + "UniFix2 ", + "UniVar2", + "UniText2", + datetime.date(2023, 2, 1), + datetime.time(13, 0), + datetime.datetime(2023, 2, 1, 13, 0), + datetime.datetime(2023, 2, 1, 13, 0, 0, 234000), + datetime.datetime(2023, 2, 1, 13, 0), + b"\x02\x03\x04\x05\x06", + b"\x02\x03\x04\x05", + False, + "41B116E8-7D42-420B-A28A-98D53C782C79", + ), + ( + 3, + 300, + 30, + 300000, + Decimal("34567.89"), + 3.45, + 2.677999973297119, + Decimal("3456.7800"), + Decimal("34.5600"), + "FixedStr3 ", + "VarStr3", + "Text3", + "UniFix3 ", + "UniVar3", + "UniText3", + datetime.date(2023, 3, 1), + datetime.time(14, 0), + datetime.datetime(2023, 3, 1, 14, 0), + datetime.datetime(2023, 3, 1, 14, 0, 0, 345000), + datetime.datetime(2023, 3, 1, 14, 0), + b"\x03\x04\x05\x06\x07", + b"\x03\x04\x05\x06", + True, + "F418999E-15F9-4FB0-9161-3383E0BC1B3E", + ), + ( + 4, + 400, + 40, + 400000, + Decimal("45678.90"), + 4.56, + 3.7890000343322754, + Decimal("4567.8900"), + Decimal("45.6700"), + "FixedStr4 ", + "VarStr4", + "Text4", + "UniFix4 ", + "UniVar4", + "UniText4", + datetime.date(2023, 4, 1), + datetime.time(15, 0), + datetime.datetime(2023, 4, 1, 15, 0), + datetime.datetime(2023, 4, 1, 15, 0, 0, 456000), + datetime.datetime(2023, 4, 1, 15, 0), + b"\x04\x05\x06\x07\x08", + b"\x04\x05\x06\x07", + False, + "13DF4C45-682A-4C17-81BA-7B00C77E3F9C", + ), + ( + 5, + 500, + 50, + 500000, + Decimal("56789.01"), + 5.67, + 4.889999866485596, + Decimal("5678.9000"), + Decimal("56.7800"), + "FixedStr5 ", + "VarStr5", + "Text5", + "UniFix5 ", + "UniVar5", + "UniText5", + datetime.date(2023, 5, 1), + datetime.time(16, 0), + datetime.datetime(2023, 5, 1, 16, 0), + datetime.datetime(2023, 5, 1, 16, 0, 0, 567000), + datetime.datetime(2023, 5, 1, 16, 0), + b"\x05\x06\x07\x08\t", + b"\x05\x06\x07\x08", + True, + "16592D8F-D876-4629-B8E5-C9C882A23C9D", + ), + ( + 5, + 500, + 50, + 500000, + Decimal("56789.01"), + 5.67, + 4.889999866485596, + Decimal("5678.9000"), + Decimal("56.7800"), + "FixedStr5 ", + "VarStr5", + "Text5", + "UniFix5 ", + "UniVar5", + "UniText5", + datetime.date(2023, 5, 1), + datetime.time(16, 0), + datetime.datetime(2023, 5, 1, 16, 0), + datetime.datetime(2023, 5, 1, 16, 0, 0, 567000), + datetime.datetime(2023, 5, 1, 16, 0), + b"\x05\x06\x07\x08\t", + b"\x05\x06\x07\x08", + True, + "16592D8F-D876-4629-B8E5-C9C882A23C9D", + ), + ( + 6, + 600, + 60, + 600000, + Decimal("67890.12"), + 6.78, + 5.999999866485596, + Decimal("6789.0100"), + Decimal("67.8900"), + "FixedStr6 ", + "VarStr6", + "Text6", + "UniFix6 ", + "UniVar6", + "UniText6", + datetime.date(2023, 6, 1), + datetime.time(17, 0), + datetime.datetime(2023, 6, 1, 17, 0), + datetime.datetime(2023, 6, 1, 17, 0, 0, 678000), + datetime.datetime(2023, 6, 1, 17, 0), + b"\x06\x07\x08\t\n", + b"\x06\x07\x08\t", + False, + "26592D8F-D876-4629-B8E5-C9C882A23C9D", + ), + ( + 7, + 700, + 70, + 700000, + Decimal("78901.23"), + 7.89, + 7.099999866485596, + Decimal("7890.1200"), + Decimal("78.9000"), + "FixedStr7 ", + "VarStr7", + "Text7", + "UniFix7 ", + "UniVar7", + "UniText7", + datetime.date(2023, 7, 1), + datetime.time(18, 0), + datetime.datetime(2023, 7, 1, 18, 0), + datetime.datetime(2023, 7, 1, 18, 0, 0, 789000), + datetime.datetime(2023, 7, 1, 18, 0), + b"\x07\x08\t\n\x0b", + b"\x07\x08\t\n", + True, + "36592D8F-D876-4629-B8E5-C9C882A23C9D", + ), + ( + 8, + 800, + 80, + 800000, + Decimal("89012.34"), + 8.90, + 8.199999866485596, + Decimal("8901.2300"), + Decimal("89.0100"), + "FixedStr8 ", + "VarStr8", + "Text8", + "UniFix8 ", + "UniVar8", + "UniText8", + 
datetime.date(2023, 8, 1), + datetime.time(19, 0), + datetime.datetime(2023, 8, 1, 19, 0), + datetime.datetime(2023, 8, 1, 19, 0, 0, 890000), + datetime.datetime(2023, 8, 1, 19, 0), + b"\x08\t\n\x0b\x0c", + b"\x08\t\n\x0b", + False, + "46592D8F-D876-4629-B8E5-C9C882A23C9D", + ), + ( + 9, + 900, + 90, + 900000, + Decimal("90123.45"), + 9.01, + 9.299999866485596, + Decimal("9012.3400"), + Decimal("90.1200"), + "FixedStr9 ", + "VarStr9", + "Text9", + "UniFix9 ", + "UniVar9", + "UniText9", + datetime.date(2023, 9, 1), + datetime.time(20, 0), + datetime.datetime(2023, 9, 1, 20, 0), + datetime.datetime(2023, 9, 1, 20, 0, 0, 901000), + datetime.datetime(2023, 9, 1, 20, 0), + b"\t\n\x0b\x0c\r", + b"\t\n\x0b\x0c", + True, + "56592D8F-D876-4629-B8E5-C9C882A23C9D", + ), + ( + 10, + 1000, + 100, + 1000000, + Decimal("12345.67"), + 10.12, + 10.399999866485596, + Decimal("1234.5600"), + Decimal("12.3400"), + "FixedStr10", + "VarStr10", + "Text10", + "UniFix10 ", + "UniVar10", + "UniText10", + datetime.date(2023, 10, 1), + datetime.time(21, 0), + datetime.datetime(2023, 10, 1, 21, 0), + datetime.datetime(2023, 10, 1, 21, 0, 0, 123000), + datetime.datetime(2023, 10, 1, 21, 0), + b"\n\x0b\x0c\r\x0e", + b"\n\x0b\x0c\r", + False, + "66592D8F-D876-4629-B8E5-C9C882A23C9D", + ), +] + +sql_server_all_type_small_data = sql_server_all_type_data[5:] +oracledb_all_type_small_data = oracledb_all_type_data[5:] +oracledb_all_type_small_data_result = oracledb_all_type_data_result[5:] + + +def sql_server_create_connection(): + return FakeConnection( + sql_server_all_type_data, sql_server_all_type_schema, "pyodbc" + ) + + +def sql_server_create_connection_small_data(): + return FakeConnection( + sql_server_all_type_small_data, sql_server_all_type_schema, "pyodbc" + ) + + +def sqlite3_db(db_path): + conn = create_connection_to_sqlite3_db(db_path) + cursor = conn.cursor() + table_name = "PrimitiveTypes" + columns = [ + "id", + "int_col", + "real_col", + "text_col", + "blob_col", + "null_col", + "ts_col", + "date_col", + "time_col", + "short_col", + "long_col", + "double_col", + "decimal_col", + "map_col", + "array_col", + "var_col", + ] + # Create a table with different primitive types + # sqlite3 only supports 5 types: NULL, INTEGER, REAL, TEXT, BLOB + cursor.execute( + f""" + CREATE TABLE IF NOT EXISTS {table_name} ( + id INTEGER PRIMARY KEY, -- Auto-incrementing primary key + int_col INTEGER, -- Integer column + real_col REAL, -- Floating point column + text_col TEXT, -- String column + blob_col BLOB, -- Binary data column + null_col NULL, -- Explicit NULL type (for testing purposes) + ts_col TEXT, -- Timestamp column in TEXT format + date_col TEXT, -- Date column in TEXT format + time_col TEXT, -- Time column in TEXT format + short_col INTEGER, -- Short integer column + long_col INTEGER, -- Long integer column + double_col REAL, -- Double column + decimal_col REAL, -- Decimal column + map_col TEXT, -- Map column in TEXT format + array_col TEXT, -- Array column in TEXT format + var_col TEXT -- Variant column in TEXT format + ) + """ + ) + test_datetime = datetime.datetime(2021, 1, 2, 12, 34, 56) + test_date = test_datetime.date() + test_time = test_datetime.time() + example_data = [ + ( + 1, + 42, + 3.14, + "Hello, world!", + b"\x00\x01\x02\x03", + None, + test_datetime.isoformat(), + test_date.isoformat(), + test_time.isoformat(), + 1, + 2, + 3.0, + 4.0, + '{"a": 1, "b": 2}', + "[1, 2, 3]", + "1", + ), + ( + 2, + -10, + 2.718, + "SQLite", + b"\x04\x05\x06\x07", + None, + test_datetime.isoformat(), + 
test_date.isoformat(), + test_time.isoformat(), + 1, + 2, + 3.0, + 4.0, + '{"a": 1, "b": 2}', + "[1, 2, 3]", + "2", + ), + ( + 3, + 9999, + -0.99, + "Python", + b"\x08\x09\x0A\x0B", + None, + test_datetime.isoformat(), + test_date.isoformat(), + test_time.isoformat(), + 1, + 2, + 3.0, + 4.0, + '{"a": 1, "b": 2}', + "[1, 2, 3]", + "3", + ), + ( + 4, + 0, + 123.456, + "Data", + b"\x0C\x0D\x0E\x0F", + None, + test_datetime.isoformat(), + test_date.isoformat(), + test_time.isoformat(), + 1, + 2, + 3.0, + 4.0, + '{"a": 1, "b": 2}', + "[1, 2, 3]", + "4", + ), + ] + assert_data = [ + ( + 1, + 42, + 3.14, + "Hello, world!", + b"\x00\x01\x02\x03", + None, + test_datetime, + test_date, + test_time, + 1, + 2, + 3.0, + 4.0, + '{\n "a": 1,\n "b": 2\n}', + '[\n "[1, 2, 3]"\n]', + '"1"', + ), + ( + 2, + -10, + 2.718, + "SQLite", + b"\x04\x05\x06\x07", + None, + test_datetime, + test_date, + test_time, + 1, + 2, + 3.0, + 4.0, + '{\n "a": 1,\n "b": 2\n}', + '[\n "[1, 2, 3]"\n]', + '"2"', + ), + ( + 3, + 9999, + -0.99, + "Python", + b"\x08\x09\x0A\x0B", + None, + test_datetime, + test_date, + test_time, + 1, + 2, + 3.0, + 4.0, + '{\n "a": 1,\n "b": 2\n}', + '[\n "[1, 2, 3]"\n]', + '"3"', + ), + ( + 4, + 0, + 123.456, + "Data", + b"\x0C\x0D\x0E\x0F", + None, + test_datetime, + test_date, + test_time, + 1, + 2, + 3.0, + 4.0, + '{\n "a": 1,\n "b": 2\n}', + '[\n "[1, 2, 3]"\n]', + '"4"', + ), + ] + cursor.executemany( + f"INSERT INTO {table_name} VALUES ({','.join('?' * 16)})", example_data + ) + conn.commit() + conn.close() + return table_name, columns, example_data, assert_data + + +def create_connection_to_sqlite3_db(db_path): + return sqlite3.connect(db_path) + + +def oracledb_create_connection(): + return FakeConnection(oracledb_all_type_data, oracledb_all_type_schema, "oracledb") + + +def oracledb_create_connection_small_data(): + return FakeConnection( + oracledb_all_type_small_data, oracledb_all_type_schema, "oracledb" + ) diff --git a/tests/unit/scala/test_utils_suite.py b/tests/unit/scala/test_utils_suite.py index b057f4eea03..7f613d2f866 100644 --- a/tests/unit/scala/test_utils_suite.py +++ b/tests/unit/scala/test_utils_suite.py @@ -306,6 +306,8 @@ def check_zip_files_and_close_stream(input_stream, expected_files): "resources/test_excel.xlsx", "resources/test_sas.sas7bdat", "resources/test_sas.xpt", + "resources/test_data_source_dir/", + "resources/test_data_source_dir/test_data_source_data.py", "resources/test_sp_dir/", "resources/test_sp_dir/test_sp_file.py", "resources/test_sp_dir/test_sp_mod3_file.py",
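A minimal usage sketch (not part of the diff) for the sqlite3_db helper added above; it relies only on Python's built-in sqlite3 module, and the temporary database path and the final assertion are illustrative assumptions rather than code taken from the change:

    import os
    import sqlite3
    import tempfile

    # Create the fixture table in a throwaway database file (hypothetical path).
    db_path = os.path.join(tempfile.mkdtemp(), "test_data_source.db")
    table_name, columns, example_data, assert_data = sqlite3_db(db_path)

    # Re-open the database and confirm every example row was inserted.
    conn = sqlite3.connect(db_path)
    rows = conn.execute(f"SELECT * FROM {table_name}").fetchall()
    conn.close()
    assert len(rows) == len(example_data)

The same pattern presumably applies to the fake SQL Server and Oracle connection factories defined above, which return in-memory FakeConnection objects instead of touching a real database.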