Commit
Merge branch 'main' into CLOUD-8059/sql-server-custom-params
jzalucki authored Jul 17, 2024
2 parents 110c1a4 + 2293890 commit aaed83f
Showing 7 changed files with 80 additions and 18 deletions.
8 changes: 8 additions & 0 deletions soda/core/soda/common/parser.py
@@ -1,4 +1,7 @@
from __future__ import annotations

import logging
from textwrap import dedent
from typing import List

from ruamel.yaml import YAML
@@ -101,6 +104,11 @@ def __get_value(self, key: str, required: bool, value_type: type):
return None
return value

def _sanitize_query(self, query: str | None) -> str | None:
    if query is None:
        return None
    return dedent(query).strip()

def _resolve_jinja(self, value: str, variables: dict = None):
from soda.common.jinja import Jinja

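The helper exists because SodaCL queries written as YAML block scalars arrive with the indentation of the surrounding document. A minimal standalone sketch of the same dedent-and-strip behavior (the table name is illustrative):

from textwrap import dedent

def sanitize_query(query: str | None) -> str | None:
    # Mirrors Parser._sanitize_query: remove the common leading whitespace,
    # then trim blank lines around the statement.
    if query is None:
        return None
    return dedent(query).strip()

block = """
    SELECT distinct(country)
    FROM customers
"""
print(sanitize_query(block))  # prints exactly two lines, with no surrounding blanks
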
6 changes: 6 additions & 0 deletions soda/core/soda/sampler/sample_context.py
@@ -37,3 +37,9 @@ def get_sample_file_name(self):
self.sample_name,
]
return "_".join([part for part in parts if part])

def scan_context_get(self, key: str, default: any = None) -> any:
return self.scan.scan_context_get(key, default)

def scan_context_set(self, key: str | list, value: any, overwrite: bool = True):
return self.scan.scan_context_set(key, value, overwrite)
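
Both methods delegate to the owning Scan, so a custom sampler can exchange state with checks and the host application. A hedged sketch assuming the documented custom-sampler hook; ContextAwareSampler and the run_id key are illustrative:

from soda.sampler.sample_context import SampleContext
from soda.sampler.sampler import Sampler

class ContextAwareSampler(Sampler):
    def store_sample(self, sample_context: SampleContext):
        # Read a value the host application stored on the scan earlier.
        run_id = sample_context.scan_context_get("run_id", default="unknown")
        # Record which run handled this sample, under a nested key.
        sample_context.scan_context_set(["samples", sample_context.sample_name], run_id)
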
25 changes: 25 additions & 0 deletions soda/core/soda/scan.py
@@ -67,6 +67,7 @@ def __init__(self):
self._sample_tables_result_tables: list[SampleTablesResultTable] = []
self._logs.info(f"Soda Core {SODA_CORE_VERSION}")
self.scan_results: dict = {}
self.scan_context: dict = {}

def build_scan_results(self) -> dict:
checks = [check.get_dict() for check in self._checks if check.outcome is not None and check.archetype is None]
@@ -958,3 +959,27 @@ def get_all_checks_text(self) -> str | None:

def has_soda_cloud_connection(self):
return self._configuration.soda_cloud is not None

def scan_context_get(self, key: str, default: any = None) -> any:
return self.scan_context.get(key, default)

def scan_context_set(self, key: str | list, value: any, overwrite: bool = True):
    dic = self.scan_context
    dic_key = key

    if isinstance(key, list):
        # A list key addresses a nested path; create intermediate dicts as needed.
        for k in key[:-1]:
            if not isinstance(dic, dict):
                raise ValueError(
                    f"Value '{dic}' is not a dictionary, but the nested key '{key}' in the scan context uses it as one."
                )
            dic = dic.setdefault(k, {})
        dic_key = key[-1]

    # Check the dictionary the value will actually land in, so overwrite
    # protection also works for nested keys.
    if dic_key in dic and not overwrite:
        raise ValueError(f"Key '{key}' already exists in scan context")

    dic[dic_key] = value
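
To illustrate the semantics of the new scan context (key names and values are illustrative):

from soda.scan import Scan

scan = Scan()

# A plain string key writes at the top level of scan.scan_context.
scan.scan_context_set("run_id", "2024-07-17")

# A list key writes into nested dictionaries, creating intermediate levels.
scan.scan_context_set(["group_evolution", "last_groups"], ["BE", "NL"])

scan.scan_context_get("run_id")        # '2024-07-17'
scan.scan_context["group_evolution"]   # {'last_groups': ['BE', 'NL']}

# With overwrite=False, setting an existing key raises ValueError
# instead of replacing the value:
# scan.scan_context_set("run_id", "other", overwrite=False)
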
19 changes: 9 additions & 10 deletions soda/core/soda/sodacl/sodacl_parser.py
@@ -7,7 +7,6 @@
import re
from datetime import timedelta
from numbers import Number
from textwrap import dedent
from typing import List

from antlr4 import CommonTokenStream, InputStream
@@ -394,7 +393,7 @@ def parse_group_by_cfg(self, check_configurations, check_str, header_str):

try:
group_limit = self._get_optional("group_limit", int) or 1000
query = self._get_required("query", str)
query = self._sanitize_query(self._get_required("query", str))
fields = self._get_required("fields", list)
check_cfgs = self._get_required("checks", list)
if check_cfgs:
@@ -444,7 +443,7 @@ def parse_user_defined_failed_rows_check_cfg(self, check_configurations, check_s
fail_condition_sql_expr = self._get_optional(FAIL_CONDITION, str)
samples_limit = self._get_optional(SAMPLES_LIMIT, int)
samples_columns = self._get_optional(SAMPLES_COLUMNS, list)
fail_query = self._get_optional(FAIL_QUERY, str)
fail_query = self._sanitize_query(self._get_optional(FAIL_QUERY, str))

fail_threshold_condition_str = self._get_optional(FAIL, str)
fail_threshold_cfg = self.__parse_configuration_threshold_condition(fail_threshold_condition_str)
@@ -478,7 +477,7 @@ def parse_user_defined_failed_rows_check_cfg(self, check_configurations, check_s
samples_columns=samples_columns,
)
else:
fail_query = self._get_optional(FAIL_QUERY, str)
fail_query = self._sanitize_query(self._get_optional(FAIL_QUERY, str))
if fail_query:
return UserDefinedFailedRowsCheckCfg(
source_header=header_str,
@@ -542,7 +541,7 @@ def parse_failed_rows_data_source_query_check(
self._push_path_element(check_str, check_configurations)
try:
name = self._get_optional(NAME, str)
query = self._get_required(FAIL_QUERY, str)
query = self._sanitize_query(self._get_required(FAIL_QUERY, str))
samples_limit = self._get_optional(SAMPLES_LIMIT, int)
samples_columns = self._get_optional(SAMPLES_COLUMNS, list)
fail_threshold_condition_str = self._get_optional(FAIL, str)
@@ -663,22 +662,22 @@ def __parse_metric_check(
if configuration_key.endswith("sql_file"):
fs = file_system()
sql_file_path = fs.join(fs.dirname(self.path_stack.file_path), configuration_value.strip())
failed_rows_query = dedent(fs.file_read_as_str(sql_file_path)).strip()
failed_rows_query = self._sanitize_query(fs.file_read_as_str(sql_file_path))
else:
failed_rows_query = dedent(configuration_value).strip()
failed_rows_query = self._sanitize_query(configuration_value)
elif configuration_key.endswith("query") or configuration_key.endswith("sql_file"):
if configuration_key.endswith("sql_file"):
fs = file_system()
sql_file_path = fs.join(fs.dirname(self.path_stack.file_path), configuration_value.strip())
metric_query = dedent(fs.file_read_as_str(sql_file_path)).strip()
metric_query = self._sanitize_query(fs.file_read_as_str(sql_file_path))
configuration_metric_name = (
configuration_key[: -len(" sql_file")]
if len(configuration_key) > len(" sql_file")
else None
)

else:
metric_query = dedent(configuration_value).strip()
metric_query = self._sanitize_query(configuration_value)

configuration_metric_name = (
configuration_key[: -len(" query")] if len(configuration_key) > len(" query") else None
@@ -1041,7 +1040,7 @@ def __parse_group_evolution_check(
f'Invalid group evolution check configuration key "{configuration_key}"', location=self.location
)
name = self._get_optional(NAME, str)
query = self._get_required("query", str)
query = self._sanitize_query(self._get_required("query", str))
group_evolution_check_cfg = GroupEvolutionCheckCfg(
source_header=header_str,
source_line=check_str,
36 changes: 28 additions & 8 deletions soda/core/tests/data_source/test_group_evolution.py
@@ -1,24 +1,20 @@
import pytest
from helpers.common_test_tables import customers_test_table
from helpers.data_source_fixture import DataSourceFixture
from helpers.fixtures import test_data_source


@pytest.mark.skipif(
test_data_source not in ["postgres", "bigquery", "spark_df"],
reason="Need to make tests work with lower and upper case values for column names",
)
def test_group_evolution(data_source_fixture: DataSourceFixture):
table_name = data_source_fixture.ensure_test_table(customers_test_table)
qualified_table_name = data_source_fixture.data_source.qualified_table_name(table_name)
casify = data_source_fixture.data_source.default_casify_column_name

scan = data_source_fixture.create_test_scan()
scan.add_sodacl_yaml_str(
f"""
checks for {table_name}:
- group evolution:
query: |
SELECT distinct(country)
FROM {table_name}
SELECT distinct({casify('country')})
FROM {qualified_table_name}
fail:
when required group missing: ["BE"]
when forbidden group present: ["US"]
@@ -27,3 +23,27 @@ def test_group_evolution(data_source_fixture: DataSourceFixture):
scan.execute()

scan.assert_all_checks_pass()


def test_group_evolution_query_multiline(data_source_fixture: DataSourceFixture):
table_name = data_source_fixture.ensure_test_table(customers_test_table)
qualified_table_name = data_source_fixture.data_source.qualified_table_name(table_name)
casify = data_source_fixture.data_source.default_casify_column_name

scan = data_source_fixture.create_test_scan()
scan.add_sodacl_yaml_str(
f"""
checks for {table_name}:
- group evolution:
query: |
SELECT distinct({casify('country')})
FROM {qualified_table_name}
fail:
when required group missing: ["BE"]
when forbidden group present: ["US"]
"""
)
scan.execute()

# _sanitize_query has dedented the query and stripped the trailing newline
assert scan._queries[0].sql == f"""SELECT distinct({casify('country')})\nFROM {qualified_table_name}"""
2 changes: 2 additions & 0 deletions soda/denodo/soda/data_sources/denodo_data_source.py
@@ -18,6 +18,7 @@ def __init__(self, logs: Logs, data_source_name: str, data_source_properties: di
self.password = data_source_properties.get("password")
self.username = data_source_properties.get("username")
self.connection_timeout = data_source_properties.get("connection_timeout")
self.sslmode = data_source_properties.get("sslmode", "prefer")

def connect(self):
import psycopg2
@@ -29,6 +30,7 @@ def connect(self):
port=self.port,
connect_timeout=self.connection_timeout,
database=self.database,
sslmode=self.sslmode,
)
return self.connection

2 changes: 2 additions & 0 deletions soda/postgres/soda/data_sources/postgres_data_source.py
@@ -15,6 +15,7 @@ def __init__(self, logs: Logs, data_source_name: str, data_source_properties: di
self.password = data_source_properties.get("password")
self.username = data_source_properties.get("username")
self.connection_timeout = data_source_properties.get("connection_timeout")
self.sslmode = data_source_properties.get("sslmode", "prefer")

def connect(self):
import psycopg2
@@ -36,6 +37,7 @@ def connect(self):
connect_timeout=self.connection_timeout,
database=self.database,
options=options,
sslmode=self.sslmode,
)
else:
raise ConnectionError(f"Invalid postgres connection properties: invalid host: {self.host}")
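Both connectors pass the new property straight through to psycopg2.connect() and fall back to "prefer" when it is not configured. A hedged configuration sketch (the data source name and connection details are placeholders):

from soda.scan import Scan

scan = Scan()
scan.set_data_source_name("my_postgres")
scan.add_configuration_yaml_str(
    """
    data_source my_postgres:
      type: postgres
      host: localhost
      username: soda
      password: secret
      database: sodacore
      sslmode: require
    """
)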
