Skip to content

Commit

Permalink
Fixing release stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
tombaeyens committed Jun 28, 2024
1 parent 8078665 commit d2c4473
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 58 deletions.
6 changes: 5 additions & 1 deletion soda/atlan/soda/atlan/atlan_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@ def process_contract_results(self, contract_result: ContractResult) -> None:
)
return None

contract_dict: dict = contract_result.contract.contract_file.dict
contract_dict: dict = contract_result.contract.contract_file.dict.copy()
contract_dict.setdefault("type", "Table")
contract_dict.setdefault("status", "DRAFT")
contract_dict.setdefault("kind", "DataContract")

contract_json_str: str = dumps(contract_dict)

self.logs.info(f"Pushing contract to Atlan: {dataset_atlan_qualified_name}")
Expand Down
9 changes: 3 additions & 6 deletions soda/atlan/tests/atlan/test_atlan_contract_push_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
)


@pytest.mark.skip(
"Takes too long to be part of the local development test suite & depends on Atlan & Soda Cloud services"
)
#@pytest.mark.skip(
# "Takes too long to be part of the local development test suite & depends on Atlan & Soda Cloud services"
# )
def test_atlan_contract_push_plugin():
this_file_dir_path = os.path.dirname(os.path.realpath(__file__))
load_dotenv(f"{this_file_dir_path}/.env", override=True)
Expand All @@ -33,9 +33,6 @@ def test_atlan_contract_push_plugin():

contract_yaml_str: str = dedent(
"""
type: Table
status: DRAFT
kind: DataContract
data_source: postgres_ds
database: ${CONTRACTS_POSTGRES_DATABASE}
schema: contracts
Expand Down
34 changes: 6 additions & 28 deletions soda/contracts/soda/contracts/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,15 @@ class Contract:

def __init__(
self,
data_source_name: str | None,
database_name: str | None,
schema_name: str | None,
contract_file: YamlFile,
logs: Logs,
):
self.contract_file: YamlFile = contract_file
self.logs: Logs = logs

self.data_source_name: str | None = data_source_name
self.database_name: str | None = database_name
self.schema_name: str | None = schema_name
self.data_source_name: str | None = None
self.database_name: str | None = None
self.schema_name: str | None = None
self.dataset_name: str | None = None

# TODO explain filter_expression_sql, default filter and named filters
Expand Down Expand Up @@ -87,28 +84,9 @@ def __parse(self) -> None:

contract_yaml_dict = self.contract_file.dict

# self.database_name comes from the contract verification API
contract_database_name: str | None = yaml_helper.read_string_opt(contract_yaml_dict, "database")
if contract_database_name is not None:
if self.database_name is None:
self.database_name = contract_database_name
elif contract_database_name != self.database_name:
self.logs.info(
f"Database name in contract YAML '{contract_database_name}' was overridden to "
f"'{self.database_name}' by contract verification parameter."
)

# self.schema_name comes from the contract verification API
contract_schema_name: str | None = yaml_helper.read_string_opt(contract_yaml_dict, "schema")
if contract_schema_name is not None:
if self.schema_name is None:
self.schema_name = contract_schema_name
elif contract_schema_name != self.schema_name:
self.logs.info(
f"Schema name in contract YAML '{contract_schema_name}' was overridden to "
f"'{self.schema_name}' by contract verification parameter."
)

self.data_source_name: str | None = yaml_helper.read_string_opt(contract_yaml_dict, "data_source")
self.database_name: str | None = yaml_helper.read_string_opt(contract_yaml_dict, "database")
self.schema_name: str | None = yaml_helper.read_string_opt(contract_yaml_dict, "schema")
self.dataset_name: str | None = yaml_helper.read_string(contract_yaml_dict, "dataset")
self.filter_sql: str | None = yaml_helper.read_string_opt(contract_yaml_dict, "filter_sql")
self.filter: str | None = "default" if self.filter_sql else None
Expand Down
27 changes: 6 additions & 21 deletions soda/contracts/soda/contracts/contract_verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ def __init__(self):
self.logs: Logs = Logs()
self.data_source_yaml_file: YamlFile | None = None
self.spark_configuration: SparkConfiguration | None = None
self.database_name: str | None = None
self.schema_name: str | None = None
self.contract_files: list[YamlFile] = []
self.soda_cloud_file: YamlFile | None = None
self.plugin_files: list[YamlFile] = []
Expand Down Expand Up @@ -81,16 +79,6 @@ def with_data_source_spark_session(
)
return self

def with_database_name(self, database_name: str) -> ContractVerificationBuilder:
assert isinstance(database_name, str)
self.database_name = database_name
return self

def with_schema_name(self, schema_name: str) -> ContractVerificationBuilder:
assert isinstance(schema_name, str)
self.schema_name = schema_name
return self

def with_soda_cloud_yaml_file(self, soda_cloud_yaml_file_path: str) -> ContractVerificationBuilder:
assert isinstance(soda_cloud_yaml_file_path, str)
if self.soda_cloud_file is not None:
Expand Down Expand Up @@ -154,8 +142,6 @@ def __init__(self, contract_verification_builder: ContractVerificationBuilder):
self.logs: Logs = contract_verification_builder.logs
self.variables: dict[str, str] = contract_verification_builder.variables
self.data_source: DataSource | None = None
self.database_name: str | None = contract_verification_builder.database_name
self.schema_name: str | None = contract_verification_builder.schema_name
self.contracts: list[Contract] = []
self.soda_cloud: SodaCloud | None = None
self.plugins: list[Plugin] = []
Expand All @@ -180,9 +166,6 @@ def _initialize_contracts(self, contract_verification_builder: ContractVerificat
contract_file.parse(self.variables)
if contract_file.is_ok():
contract: Contract = Contract(
data_source_name=self.data_source.data_source_name if self.data_source else None,
database_name=self.database_name,
schema_name=self.schema_name,
contract_file=contract_file,
logs=contract_file.logs,
)
Expand Down Expand Up @@ -259,10 +242,12 @@ def _verify(self, contract: Contract) -> ContractResult:
scan._data_source_manager.data_sources[self.data_source.data_source_name] = sodacl_data_source

if self.soda_cloud:
scan_definition_name = (
f"dataset://{self.data_source.data_source_name}"
f"/{self.database_name}/{self.schema_name}/{contract.dataset_name}"
)
parts: list[str] = [
self.data_source.data_source_name, contract.database_name,
contract.schema_name, contract.dataset_name
]
parts_str: str = "/".join([part for part in parts if part is not None])
scan_definition_name = f"dataset://{parts_str}"
scan.set_scan_definition_name(scan_definition_name)
# noinspection PyProtectedMember
scan._configuration.soda_cloud = CustomizedSodaClCloud(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from textwrap import dedent

from contracts.helpers.test_data_source import TestDataSource
from soda.contracts.contract_verification import ContractVerification, SodaException


Expand All @@ -25,7 +26,7 @@ def test_data_source_file_variable_resolving(environ):
assert "sodasql" == resolved_connection_properties["username"]


def test_invalid_database():
def test_invalid_database(test_data_source: TestDataSource):
data_source_yaml_str = dedent(
"""
name: postgres_ds
Expand All @@ -44,7 +45,7 @@ def test_invalid_database():
assert 'database "invalid_db" does not exist' in contract_verification_str


def test_invalid_username():
def test_invalid_username(test_data_source: TestDataSource):
data_source_yaml_str = dedent(
"""
name: postgres_ds
Expand Down

0 comments on commit d2c4473

Please sign in to comment.