From d2c4473898135744e487b8667426a709f2dbbf1f Mon Sep 17 00:00:00 2001 From: tombaeyens Date: Fri, 28 Jun 2024 12:28:29 +0200 Subject: [PATCH] Fixing release stuff --- soda/atlan/soda/atlan/atlan_plugin.py | 6 +++- .../atlan/test_atlan_contract_push_plugin.py | 9 ++--- soda/contracts/soda/contracts/contract.py | 34 ++++--------------- .../soda/contracts/contract_verification.py | 27 ++++----------- .../other/test_data_source_configurations.py | 5 +-- 5 files changed, 23 insertions(+), 58 deletions(-) diff --git a/soda/atlan/soda/atlan/atlan_plugin.py b/soda/atlan/soda/atlan/atlan_plugin.py index a1e3dc88b..6347c959e 100644 --- a/soda/atlan/soda/atlan/atlan_plugin.py +++ b/soda/atlan/soda/atlan/atlan_plugin.py @@ -41,7 +41,11 @@ def process_contract_results(self, contract_result: ContractResult) -> None: ) return None - contract_dict: dict = contract_result.contract.contract_file.dict + contract_dict: dict = contract_result.contract.contract_file.dict.copy() + contract_dict.setdefault("type", "Table") + contract_dict.setdefault("status", "DRAFT") + contract_dict.setdefault("kind", "DataContract") + contract_json_str: str = dumps(contract_dict) self.logs.info(f"Pushing contract to Atlan: {dataset_atlan_qualified_name}") diff --git a/soda/atlan/tests/atlan/test_atlan_contract_push_plugin.py b/soda/atlan/tests/atlan/test_atlan_contract_push_plugin.py index ddfa06e24..66dce3462 100644 --- a/soda/atlan/tests/atlan/test_atlan_contract_push_plugin.py +++ b/soda/atlan/tests/atlan/test_atlan_contract_push_plugin.py @@ -10,9 +10,9 @@ ) -@pytest.mark.skip( - "Takes too long to be part of the local development test suite & depends on Atlan & Soda Cloud services" -) +#@pytest.mark.skip( +# "Takes too long to be part of the local development test suite & depends on Atlan & Soda Cloud services" +# ) def test_atlan_contract_push_plugin(): this_file_dir_path = os.path.dirname(os.path.realpath(__file__)) load_dotenv(f"{this_file_dir_path}/.env", override=True) @@ -33,9 +33,6 @@ def test_atlan_contract_push_plugin(): contract_yaml_str: str = dedent( """ - type: Table - status: DRAFT - kind: DataContract data_source: postgres_ds database: ${CONTRACTS_POSTGRES_DATABASE} schema: contracts diff --git a/soda/contracts/soda/contracts/contract.py b/soda/contracts/soda/contracts/contract.py index e859811d6..814a994fa 100644 --- a/soda/contracts/soda/contracts/contract.py +++ b/soda/contracts/soda/contracts/contract.py @@ -42,18 +42,15 @@ class Contract: def __init__( self, - data_source_name: str | None, - database_name: str | None, - schema_name: str | None, contract_file: YamlFile, logs: Logs, ): self.contract_file: YamlFile = contract_file self.logs: Logs = logs - self.data_source_name: str | None = data_source_name - self.database_name: str | None = database_name - self.schema_name: str | None = schema_name + self.data_source_name: str | None = None + self.database_name: str | None = None + self.schema_name: str | None = None self.dataset_name: str | None = None # TODO explain filter_expression_sql, default filter and named filters @@ -87,28 +84,9 @@ def __parse(self) -> None: contract_yaml_dict = self.contract_file.dict - # self.database_name comes from the contract verification API - contract_database_name: str | None = yaml_helper.read_string_opt(contract_yaml_dict, "database") - if contract_database_name is not None: - if self.database_name is None: - self.database_name = contract_database_name - elif contract_database_name != self.database_name: - self.logs.info( - f"Database name in contract YAML '{contract_database_name}' was overridden to " - f"'{self.database_name}' by contract verification parameter." - ) - - # self.schema_name comes from the contract verification API - contract_schema_name: str | None = yaml_helper.read_string_opt(contract_yaml_dict, "schema") - if contract_schema_name is not None: - if self.schema_name is None: - self.schema_name = contract_schema_name - elif contract_schema_name != self.schema_name: - self.logs.info( - f"Schema name in contract YAML '{contract_schema_name}' was overridden to " - f"'{self.schema_name}' by contract verification parameter." - ) - + self.data_source_name: str | None = yaml_helper.read_string_opt(contract_yaml_dict, "data_source") + self.database_name: str | None = yaml_helper.read_string_opt(contract_yaml_dict, "database") + self.schema_name: str | None = yaml_helper.read_string_opt(contract_yaml_dict, "schema") self.dataset_name: str | None = yaml_helper.read_string(contract_yaml_dict, "dataset") self.filter_sql: str | None = yaml_helper.read_string_opt(contract_yaml_dict, "filter_sql") self.filter: str | None = "default" if self.filter_sql else None diff --git a/soda/contracts/soda/contracts/contract_verification.py b/soda/contracts/soda/contracts/contract_verification.py index 61f7f6204..4d8d2bdb7 100644 --- a/soda/contracts/soda/contracts/contract_verification.py +++ b/soda/contracts/soda/contracts/contract_verification.py @@ -23,8 +23,6 @@ def __init__(self): self.logs: Logs = Logs() self.data_source_yaml_file: YamlFile | None = None self.spark_configuration: SparkConfiguration | None = None - self.database_name: str | None = None - self.schema_name: str | None = None self.contract_files: list[YamlFile] = [] self.soda_cloud_file: YamlFile | None = None self.plugin_files: list[YamlFile] = [] @@ -81,16 +79,6 @@ def with_data_source_spark_session( ) return self - def with_database_name(self, database_name: str) -> ContractVerificationBuilder: - assert isinstance(database_name, str) - self.database_name = database_name - return self - - def with_schema_name(self, schema_name: str) -> ContractVerificationBuilder: - assert isinstance(schema_name, str) - self.schema_name = schema_name - return self - def with_soda_cloud_yaml_file(self, soda_cloud_yaml_file_path: str) -> ContractVerificationBuilder: assert isinstance(soda_cloud_yaml_file_path, str) if self.soda_cloud_file is not None: @@ -154,8 +142,6 @@ def __init__(self, contract_verification_builder: ContractVerificationBuilder): self.logs: Logs = contract_verification_builder.logs self.variables: dict[str, str] = contract_verification_builder.variables self.data_source: DataSource | None = None - self.database_name: str | None = contract_verification_builder.database_name - self.schema_name: str | None = contract_verification_builder.schema_name self.contracts: list[Contract] = [] self.soda_cloud: SodaCloud | None = None self.plugins: list[Plugin] = [] @@ -180,9 +166,6 @@ def _initialize_contracts(self, contract_verification_builder: ContractVerificat contract_file.parse(self.variables) if contract_file.is_ok(): contract: Contract = Contract( - data_source_name=self.data_source.data_source_name if self.data_source else None, - database_name=self.database_name, - schema_name=self.schema_name, contract_file=contract_file, logs=contract_file.logs, ) @@ -259,10 +242,12 @@ def _verify(self, contract: Contract) -> ContractResult: scan._data_source_manager.data_sources[self.data_source.data_source_name] = sodacl_data_source if self.soda_cloud: - scan_definition_name = ( - f"dataset://{self.data_source.data_source_name}" - f"/{self.database_name}/{self.schema_name}/{contract.dataset_name}" - ) + parts: list[str] = [ + self.data_source.data_source_name, contract.database_name, + contract.schema_name, contract.dataset_name + ] + parts_str: str = "/".join([part for part in parts if part is not None]) + scan_definition_name = f"dataset://{parts_str}" scan.set_scan_definition_name(scan_definition_name) # noinspection PyProtectedMember scan._configuration.soda_cloud = CustomizedSodaClCloud( diff --git a/soda/contracts/tests/contracts/other/test_data_source_configurations.py b/soda/contracts/tests/contracts/other/test_data_source_configurations.py index f2c15f42c..0e29ea9e5 100644 --- a/soda/contracts/tests/contracts/other/test_data_source_configurations.py +++ b/soda/contracts/tests/contracts/other/test_data_source_configurations.py @@ -1,6 +1,7 @@ import os from textwrap import dedent +from contracts.helpers.test_data_source import TestDataSource from soda.contracts.contract_verification import ContractVerification, SodaException @@ -25,7 +26,7 @@ def test_data_source_file_variable_resolving(environ): assert "sodasql" == resolved_connection_properties["username"] -def test_invalid_database(): +def test_invalid_database(test_data_source: TestDataSource): data_source_yaml_str = dedent( """ name: postgres_ds @@ -44,7 +45,7 @@ def test_invalid_database(): assert 'database "invalid_db" does not exist' in contract_verification_str -def test_invalid_username(): +def test_invalid_username(test_data_source: TestDataSource): data_source_yaml_str = dedent( """ name: postgres_ds