diff --git a/.github/workflows/test_doc_snippets.yml b/.github/workflows/test_doc_snippets.yml
index 6d4e6dda53..b140935d4c 100644
--- a/.github/workflows/test_doc_snippets.yml
+++ b/.github/workflows/test_doc_snippets.yml
@@ -21,6 +21,8 @@ env:
   # Slack hook for chess in production example
   RUNTIME__SLACK_INCOMING_HOOK: ${{ secrets.RUNTIME__SLACK_INCOMING_HOOK }}
+  # Path to local qdrant database
+  DESTINATION__QDRANT__CREDENTIALS__PATH: zendesk.qdb
   # detect if the workflow is executed in a repo fork
   IS_FORK: ${{ github.event.pull_request.head.repo.fork }}
diff --git a/.github/workflows/test_local_destinations.yml b/.github/workflows/test_local_destinations.yml
index 263d3f588c..f1bf6016bc 100644
--- a/.github/workflows/test_local_destinations.yml
+++ b/.github/workflows/test_local_destinations.yml
@@ -21,7 +21,7 @@ env:
   RUNTIME__SENTRY_DSN: https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752
   RUNTIME__LOG_LEVEL: ERROR
   RUNTIME__DLTHUB_TELEMETRY_ENDPOINT: ${{ secrets.RUNTIME__DLTHUB_TELEMETRY_ENDPOINT }}
-  ACTIVE_DESTINATIONS: "[\"duckdb\", \"postgres\", \"filesystem\", \"weaviate\"]"
+  ACTIVE_DESTINATIONS: "[\"duckdb\", \"postgres\", \"filesystem\", \"weaviate\", \"qdrant\"]"
   ALL_FILESYSTEM_DRIVERS: "[\"memory\", \"file\"]"
   DESTINATION__WEAVIATE__VECTORIZER: text2vec-contextionary
 
@@ -63,6 +63,11 @@ jobs:
           --health-timeout 5s
           --health-retries 5
 
+      qdrant:
+        image: qdrant/qdrant:v1.8.4
+        ports:
+          - 6333:6333
+
     steps:
       - name: Check out
        uses: actions/checkout@master
@@ -90,7 +95,7 @@ jobs:
          key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}-local-destinations
 
       - name: Install dependencies
-        run: poetry install --no-interaction -E postgres -E duckdb -E parquet -E filesystem -E cli -E weaviate --with sentry-sdk --with pipeline -E deltalake
+        run: poetry install --no-interaction -E postgres -E duckdb -E parquet -E filesystem -E cli -E weaviate -E qdrant --with sentry-sdk --with pipeline -E deltalake
 
       - name: create secrets.toml
         run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml
@@ -100,6 +105,7 @@ jobs:
         name: Run tests Linux
         env:
           DESTINATION__POSTGRES__CREDENTIALS: postgresql://loader:loader@localhost:5432/dlt_data
+          DESTINATION__QDRANT__CREDENTIALS__location: http://localhost:6333
 
       - name: Stop weaviate
         if: always()
diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py
index a735aad5cf..259389a5e9 100644
--- a/dlt/common/destination/reference.py
+++ b/dlt/common/destination/reference.py
@@ -63,6 +63,28 @@ class StorageSchemaInfo(NamedTuple):
     inserted_at: datetime.datetime
     schema: str
 
+    @classmethod
+    def from_normalized_mapping(
+        cls, normalized_doc: Dict[str, Any], naming_convention: NamingConvention
+    ) -> "StorageSchemaInfo":
+        """Instantiate this class from mapping where keys are normalized according to given naming convention
+
+        Args:
+            normalized_doc: Mapping with normalized keys (e.g. {Version: ..., SchemaName: ...})
+            naming_convention: Naming convention that was used to normalize keys
+
+        Returns:
+            StorageSchemaInfo: Instance of this class
+        """
+        return cls(
+            version_hash=normalized_doc[naming_convention.normalize_identifier("version_hash")],
+            schema_name=normalized_doc[naming_convention.normalize_identifier("schema_name")],
+            version=normalized_doc[naming_convention.normalize_identifier("version")],
+            engine_version=normalized_doc[naming_convention.normalize_identifier("engine_version")],
+            inserted_at=normalized_doc[naming_convention.normalize_identifier("inserted_at")],
+            schema=normalized_doc[naming_convention.normalize_identifier("schema")],
+        )
+
 
 @dataclasses.dataclass
 class StateInfo:
@@ -82,6 +104,29 @@ def as_doc(self) -> TPipelineStateDoc:
         doc.pop("version_hash")
         return doc
 
+    @classmethod
+    def from_normalized_mapping(
+        cls, normalized_doc: Dict[str, Any], naming_convention: NamingConvention
+    ) -> "StateInfo":
+        """Instantiate this class from mapping where keys are normalized according to given naming convention
+
+        Args:
+            normalized_doc: Mapping with normalized keys (e.g. {Version: ..., PipelineName: ...})
+            naming_convention: Naming convention that was used to normalize keys
+
+        Returns:
+            StateInfo: Instance of this class
+        """
+        return cls(
+            version=normalized_doc[naming_convention.normalize_identifier("version")],
+            engine_version=normalized_doc[naming_convention.normalize_identifier("engine_version")],
+            pipeline_name=normalized_doc[naming_convention.normalize_identifier("pipeline_name")],
+            state=normalized_doc[naming_convention.normalize_identifier("state")],
+            created_at=normalized_doc[naming_convention.normalize_identifier("created_at")],
+            version_hash=normalized_doc.get(naming_convention.normalize_identifier("version_hash")),
+            _dlt_load_id=normalized_doc.get(naming_convention.normalize_identifier("_dlt_load_id")),
+        )
+
 
 @configspec
 class DestinationClientConfiguration(BaseConfiguration):
diff --git a/dlt/destinations/impl/qdrant/configuration.py b/dlt/destinations/impl/qdrant/configuration.py
index 4d1ed1234d..baf5e5dc59 100644
--- a/dlt/destinations/impl/qdrant/configuration.py
+++ b/dlt/destinations/impl/qdrant/configuration.py
@@ -1,6 +1,6 @@
 import dataclasses
-from typing import Optional, Final
-from typing_extensions import Annotated
+from typing import Optional, Final, Any
+from typing_extensions import Annotated, TYPE_CHECKING
 
 from dlt.common.configuration import configspec, NotResolved
 from dlt.common.configuration.specs.base_configuration import (
@@ -8,11 +8,17 @@
     CredentialsConfiguration,
 )
 from dlt.common.destination.reference import DestinationClientDwhConfiguration
+from dlt.destinations.impl.qdrant.exceptions import InvalidInMemoryQdrantCredentials
+
+if TYPE_CHECKING:
+    from qdrant_client import QdrantClient
 
 
 @configspec
 class QdrantCredentials(CredentialsConfiguration):
-    # If `:memory:` - use in-memory Qdrant instance.
+    if TYPE_CHECKING:
+        _external_client: "QdrantClient"
+
     # If `str` - use it as a `url` parameter.
     # If `None` - use default values for `host` and `port`
     location: Optional[str] = None
@@ -21,6 +27,47 @@ class QdrantCredentials(CredentialsConfiguration):
     # Persistence path for QdrantLocal. Default: `None`
     path: Optional[str] = None
 
+    def is_local(self) -> bool:
+        return self.path is not None
+
+    def on_resolved(self) -> None:
+        if self.location == ":memory:":
+            raise InvalidInMemoryQdrantCredentials()
+
+    def parse_native_representation(self, native_value: Any) -> None:
+        try:
+            from qdrant_client import QdrantClient
+
+            if isinstance(native_value, QdrantClient):
+                self._external_client = native_value
+                self.resolve()
+        except ModuleNotFoundError:
+            pass
+
+        super().parse_native_representation(native_value)
+
+    def _create_client(self, model: str, **options: Any) -> "QdrantClient":
+        from qdrant_client import QdrantClient
+
+        creds = dict(self)
+        if creds["path"]:
+            del creds["location"]
+
+        client = QdrantClient(**creds, **options)
+        client.set_model(model)
+        return client
+
+    def get_client(self, model: str, **options: Any) -> "QdrantClient":
+        client = getattr(self, "_external_client", None)
+        return client or self._create_client(model, **options)
+
+    def close_client(self, client: "QdrantClient") -> None:
+        """Close client if not external"""
+        if getattr(self, "_external_client", None) is client:
+            # Do not close client created externally
+            return
+        client.close()
+
     def __str__(self) -> str:
         return self.location or "localhost"
 
@@ -81,6 +128,12 @@ class QdrantClientConfiguration(DestinationClientDwhConfiguration):
     # Find the list here. https://qdrant.github.io/fastembed/examples/Supported_Models/.
     model: str = "BAAI/bge-small-en"
 
+    def get_client(self) -> "QdrantClient":
+        return self.credentials.get_client(self.model, **dict(self.options))
+
+    def close_client(self, client: "QdrantClient") -> None:
+        self.credentials.close_client(client)
+
     def fingerprint(self) -> str:
         """Returns a fingerprint of a connection string"""
diff --git a/dlt/destinations/impl/qdrant/exceptions.py b/dlt/destinations/impl/qdrant/exceptions.py
new file mode 100644
index 0000000000..19f33f64c1
--- /dev/null
+++ b/dlt/destinations/impl/qdrant/exceptions.py
@@ -0,0 +1,11 @@
+from dlt.common.destination.exceptions import DestinationTerminalException
+
+
+class InvalidInMemoryQdrantCredentials(DestinationTerminalException):
+    def __init__(self) -> None:
+        super().__init__(
+            "To use in-memory instance of qdrant, "
+            "please instantiate it first and then pass to destination factory\n"
+            '\nclient = QdrantClient(":memory:")\n'
+            'dlt.pipeline(pipeline_name="...", destination=dlt.destinations.qdrant(client)'
+        )
diff --git a/dlt/destinations/impl/qdrant/factory.py b/dlt/destinations/impl/qdrant/factory.py
index defd29a03a..2bface0938 100644
--- a/dlt/destinations/impl/qdrant/factory.py
+++ b/dlt/destinations/impl/qdrant/factory.py
@@ -1,6 +1,8 @@
 import typing as t
 
 from dlt.common.destination import Destination, DestinationCapabilitiesContext
+from dlt.common.destination.reference import TDestinationConfig
+from dlt.common.normalizers.naming import NamingConvention
 
 from dlt.destinations.impl.qdrant.configuration import QdrantCredentials, QdrantClientConfiguration
 
@@ -26,6 +28,20 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext:
 
         return caps
 
+    @classmethod
+    def adjust_capabilities(
+        cls,
+        caps: DestinationCapabilitiesContext,
+        config: QdrantClientConfiguration,
+        naming: t.Optional[NamingConvention],
+    ) -> DestinationCapabilitiesContext:
+        caps = super(qdrant, cls).adjust_capabilities(caps, config, naming)
+        if config.credentials.is_local():
+            # Local qdrant can not load in parallel
+            caps.loader_parallelism_strategy = "sequential"
+            caps.max_parallel_load_jobs = 1
+        return caps
+
     @property
     def client_class(self) -> t.Type["QdrantClient"]:
         from dlt.destinations.impl.qdrant.qdrant_client import QdrantClient
diff --git a/dlt/destinations/impl/qdrant/qdrant_client.py b/dlt/destinations/impl/qdrant/qdrant_client.py
index 51915c5536..80c158d51a 100644
--- a/dlt/destinations/impl/qdrant/qdrant_client.py
+++ b/dlt/destinations/impl/qdrant/qdrant_client.py
@@ -1,5 +1,6 @@
 from types import TracebackType
 from typing import Optional, Sequence, List, Dict, Type, Iterable, Any
+import threading
 
 from dlt.common import logger
 from dlt.common.json import json
@@ -13,6 +14,7 @@
 )
 from dlt.common.destination import DestinationCapabilitiesContext
 from dlt.common.destination.reference import TLoadJobState, LoadJob, JobClientBase, WithStateSync
+from dlt.common.destination.exceptions import DestinationUndefinedEntity
 from dlt.common.storages import FileStorage
 from dlt.common.time import precise_time
 
@@ -46,6 +48,7 @@ def __init__(
         self.config = client_config
 
         with FileStorage.open_zipsafe_ro(local_path) as f:
+            ids: List[str]
             docs, payloads, ids = [], [], []
 
             for line in f:
@@ -53,7 +56,7 @@ def __init__(
                 point_id = (
                     self._generate_uuid(data, self.unique_identifiers, self.collection_name)
                     if self.unique_identifiers
-                    else uuid.uuid4()
+                    else str(uuid.uuid4())
                 )
                 payloads.append(data)
                 ids.append(point_id)
@@ -179,22 +182,6 @@ def dataset_name(self) -> str:
     def sentinel_collection(self) -> str:
         return self.dataset_name or "DltSentinelCollection"
 
-    @staticmethod
-    def _create_db_client(config: QdrantClientConfiguration) -> QC:
-        """Generates a Qdrant client from the 'qdrant_client' package.
-
-        Args:
-            config (QdrantClientConfiguration): Credentials and options for the Qdrant client.
-
-        Returns:
-            QdrantClient: A Qdrant client instance.
-        """
-        credentials = dict(config.credentials)
-        options = dict(config.options)
-        client = QC(**credentials, **options)
-        client.set_model(config.model)
-        return client
-
     def _make_qualified_collection_name(self, table_name: str) -> str:
         """Generates a qualified collection name.
@@ -240,14 +227,11 @@ def _create_point_no_vector(self, obj: Dict[str, Any], collection_name: str) ->
             obj (Dict[str, Any]): The arbitrary data to be inserted as payload.
             collection_name (str): The name of the collection to insert the point into.
         """
-        # we want decreased ids because the point scroll functions orders by id ASC
-        # so we want newest first
-        id_ = 2**64 - int(precise_time() * 10**6)
         self.db_client.upsert(
             collection_name,
             points=[
                 models.PointStruct(
-                    id=id_,
+                    id=str(uuid.uuid4()),
                     payload=obj,
                     vector={},
                 )
             ]
@@ -331,7 +315,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
         p_load_id = self.schema.naming.normalize_identifier("load_id")
         p_dlt_load_id = self.schema.naming.normalize_identifier("_dlt_load_id")
         p_pipeline_name = self.schema.naming.normalize_identifier("pipeline_name")
-        # p_created_at = self.schema.naming.normalize_identifier("created_at")
+        p_created_at = self.schema.naming.normalize_identifier("created_at")
 
         limit = 100
         offset = None
@@ -350,15 +334,13 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
                         )
                     ]
                 ),
-                # search by package load id which is guaranteed to increase over time
-                # order_by=models.OrderBy(
-                #     key=p_created_at,
-                #     # direction=models.Direction.DESC,
-                # ),
+                order_by=models.OrderBy(
+                    key=p_created_at,
+                    direction=models.Direction.DESC,
+                ),
                 limit=limit,
                 offset=offset,
             )
-            # print("state_r", state_records)
             if len(state_records) == 0:
                 return None
             for state_record in state_records:
@@ -378,21 +360,24 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
                         ]
                     ),
                 )
-                if load_records.count > 0:
-                    return StateInfo(**state)
-        except Exception:
-            return None
+                if load_records.count == 0:
+                    return None
+                return StateInfo.from_normalized_mapping(state, self.schema.naming)
+        except UnexpectedResponse as e:
+            if e.status_code == 404:
+                raise DestinationUndefinedEntity(str(e)) from e
+            raise
+        except ValueError as e:  # Local qdrant error
+            if "not found" in str(e):
+                raise DestinationUndefinedEntity(str(e)) from e
+            raise
 
     def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
         """Retrieves newest schema from destination storage"""
         try:
             scroll_table_name = self._make_qualified_collection_name(self.schema.version_table_name)
             p_schema_name = self.schema.naming.normalize_identifier("schema_name")
-            # this works only because we create points that have no vectors
-            # with decreasing ids. so newest (lowest ids) go first
-            # we do not use order_by because it requires and index to be created
-            # and this behavior is different for local and cloud qdrant
-            # p_inserted_at = self.schema.naming.normalize_identifier("inserted_at")
+            p_inserted_at = self.schema.naming.normalize_identifier("inserted_at")
             response = self.db_client.scroll(
                 scroll_table_name,
                 with_payload=True,
@@ -405,15 +390,23 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
                     ]
                 ),
                 limit=1,
-                # order_by=models.OrderBy(
-                #     key=p_inserted_at,
-                #     direction=models.Direction.DESC,
-                # )
+                order_by=models.OrderBy(
+                    key=p_inserted_at,
+                    direction=models.Direction.DESC,
+                ),
             )
-            record = response[0][0].payload
-            return StorageSchemaInfo(**record)
-        except Exception:
-            return None
+            if not response[0]:
+                return None
+            payload = response[0][0].payload
+            return StorageSchemaInfo.from_normalized_mapping(payload, self.schema.naming)
+        except UnexpectedResponse as e:
+            if e.status_code == 404:
+                raise DestinationUndefinedEntity(str(e)) from e
+            raise
+        except ValueError as e:  # Local qdrant error
+            if "not found" in str(e):
+                raise DestinationUndefinedEntity(str(e)) from e
+            raise
 
     def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]:
         try:
@@ -431,10 +424,18 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI
                 ),
                 limit=1,
             )
-            record = response[0][0].payload
-            return StorageSchemaInfo(**record)
-        except Exception:
-            return None
+            if not response[0]:
+                return None
+            payload = response[0][0].payload
+            return StorageSchemaInfo.from_normalized_mapping(payload, self.schema.naming)
+        except UnexpectedResponse as e:
+            if e.status_code == 404:
+                return None
+            raise
+        except ValueError as e:  # Local qdrant error
+            if "not found" in str(e):
+                return None
+            raise
 
     def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
         return LoadQdrantJob(
@@ -456,7 +457,7 @@ def complete_load(self, load_id: str) -> None:
         self._create_point_no_vector(properties, loads_table_name)
 
     def __enter__(self) -> "QdrantClient":
-        self.db_client = QdrantClient._create_db_client(self.config)
+        self.db_client = self.config.get_client()
         return self
 
     def __exit__(
@@ -466,7 +467,7 @@ def __exit__(
         exc_tb: TracebackType,
     ) -> None:
         if self.db_client:
-            self.db_client.close()
+            self.config.close_client(self.db_client)
             self.db_client = None
 
     def _update_schema_in_storage(self, schema: Schema) -> None:
@@ -485,13 +486,30 @@ def _update_schema_in_storage(self, schema: Schema) -> None:
         self._create_point_no_vector(properties, version_table_name)
 
     def _execute_schema_update(self, only_tables: Iterable[str]) -> None:
+        is_local = self.config.credentials.is_local()
         for table_name in only_tables or self.schema.tables:
             exists = self._collection_exists(table_name)
+            qualified_collection_name = self._make_qualified_collection_name(table_name)
 
             if not exists:
                 self._create_collection(
-                    full_collection_name=self._make_qualified_collection_name(table_name)
+                    full_collection_name=qualified_collection_name,
                 )
+            if not is_local:  # Indexes don't work in local Qdrant (trigger log warning)
+                # Create indexes to enable order_by in state and schema tables
+                if table_name == self.schema.state_table_name:
+                    self.db_client.create_payload_index(
+                        collection_name=qualified_collection_name,
+                        field_name=self.schema.naming.normalize_identifier("created_at"),
+                        field_schema="datetime",
+                    )
+                elif table_name == self.schema.version_table_name:
+                    self.db_client.create_payload_index(
+                        collection_name=qualified_collection_name,
+                        field_name=self.schema.naming.normalize_identifier("inserted_at"),
+                        field_schema="datetime",
+                    )
+
         self._update_schema_in_storage(self.schema)
 
     def _collection_exists(self, table_name: str, qualify_table_name: bool = True) -> bool:
diff --git a/dlt/load/load.py b/dlt/load/load.py
index 9d1d953f7f..76b4806694 100644
--- a/dlt/load/load.py
+++ b/dlt/load/load.py
@@ -196,7 +196,9 @@ def w_spool_job(
     def spool_new_jobs(self, load_id: str, schema: Schema) -> Tuple[int, List[LoadJob]]:
         # use thread based pool as jobs processing is mostly I/O and we do not want to pickle jobs
         load_files = filter_new_jobs(
-            self.load_storage.list_new_jobs(load_id), self.destination.capabilities(), self.config
+            self.load_storage.list_new_jobs(load_id),
+            self.destination.capabilities(self.destination.configuration(self.initial_client_config)),
+            self.config,
         )
         file_count = len(load_files)
         if file_count == 0:
diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py
index e4a7c7c4a8..2bfee3fd29 100644
--- a/dlt/pipeline/pipeline.py
+++ b/dlt/pipeline/pipeline.py
@@ -723,6 +723,7 @@ def sync_destination(
         try:
             try:
                 restored_schemas: Sequence[Schema] = None
+
                 remote_state = self._restore_state_from_destination()
 
                 # if remote state is newer or same
diff --git a/docs/examples/qdrant_zendesk/qdrant_zendesk.py b/docs/examples/qdrant_zendesk/qdrant_zendesk.py
index 7fb55fe842..5416f2f2d0 100644
--- a/docs/examples/qdrant_zendesk/qdrant_zendesk.py
+++ b/docs/examples/qdrant_zendesk/qdrant_zendesk.py
@@ -38,8 +38,6 @@
 from dlt.destinations.adapters import qdrant_adapter
 from qdrant_client import QdrantClient
 
-from dlt.common.configuration.inject import with_config
-
 
 # function from: https://github.com/dlt-hub/verified-sources/tree/master/sources/zendesk
 @dlt.source(max_table_nesting=2)
@@ -181,29 +179,22 @@ def get_pages(
     # make sure nothing failed
     load_info.raise_on_failed_jobs()
 
-    # running the Qdrant client to connect to your Qdrant database
-
-    @with_config(sections=("destination", "qdrant", "credentials"))
-    def get_qdrant_client(location=dlt.secrets.value, api_key=dlt.secrets.value):
-        return QdrantClient(
-            url=location,
-            api_key=api_key,
-        )
-
-    # running the Qdrant client to connect to your Qdrant database
-    qdrant_client = get_qdrant_client()
+    # getting the authenticated Qdrant client to connect to your Qdrant database
+    with pipeline.destination_client() as destination_client:
+        from qdrant_client import QdrantClient
 
-    # view Qdrant collections you'll find your dataset here:
-    print(qdrant_client.get_collections())
+        qdrant_client: QdrantClient = destination_client.db_client  # type: ignore
+        # view Qdrant collections you'll find your dataset here:
+        print(qdrant_client.get_collections())
 
-    # query Qdrant with prompt: getting tickets info close to "cancellation"
-    response = qdrant_client.query(
-        "zendesk_data_content",  # collection/dataset name with the 'content' suffix -> tickets content table
-        query_text=["cancel", "cancel subscription"],  # prompt to search
-        limit=3,  # limit the number of results to the nearest 3 embeddings
-    )
+        # query Qdrant with prompt: getting tickets info close to "cancellation"
+        response = qdrant_client.query(
+            "zendesk_data_content",  # collection/dataset name with the 'content' suffix -> tickets content table
+            query_text="cancel subscription",  # prompt to search
+            limit=3,  # limit the number of results to the nearest 3 embeddings
+        )
 
-    assert len(response) <= 3 and len(response) > 0
+        assert len(response) <= 3 and len(response) > 0
 
-    # make sure nothing failed
-    load_info.raise_on_failed_jobs()
+        # make sure nothing failed
+        load_info.raise_on_failed_jobs()
diff --git a/docs/tools/prepare_examples_tests.py b/docs/tools/prepare_examples_tests.py
index d39d311a50..dc0a3c82f9 100644
--- a/docs/tools/prepare_examples_tests.py
+++ b/docs/tools/prepare_examples_tests.py
@@ -3,6 +3,7 @@
 """
 import os
 import argparse
+from typing import List
 
 import dlt.cli.echo as fmt
 
@@ -10,7 +11,7 @@
 # settings
 SKIP_FOLDERS = ["archive", ".", "_", "local_cache"]
-SKIP_EXAMPLES = ["qdrant_zendesk"]
+SKIP_EXAMPLES: List[str] = []
 
 # the entry point for the script
 MAIN_CLAUSE = 'if __name__ == "__main__":'
diff --git a/tests/conftest.py b/tests/conftest.py
index 020487d878..669fd19c35 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -111,3 +111,6 @@ def _create_pipeline_instance_id(self) -> str:
     # disable databricks logging
     for log in ["databricks.sql.client"]:
         logging.getLogger(log).setLevel("WARNING")
+
+    # disable httpx request logging (too verbose when testing qdrant)
+    logging.getLogger("httpx").setLevel("WARNING")
diff --git a/tests/load/pipeline/test_arrow_loading.py b/tests/load/pipeline/test_arrow_loading.py
index 630d84a28c..6d78968996 100644
--- a/tests/load/pipeline/test_arrow_loading.py
+++ b/tests/load/pipeline/test_arrow_loading.py
@@ -192,7 +192,7 @@ def test_parquet_column_names_are_normalized(
     def some_data():
         yield tbl
 
-    pipeline = dlt.pipeline("arrow_" + uniq_id(), destination=destination_config.destination)
+    pipeline = destination_config.setup_pipeline("arrow_" + uniq_id())
     pipeline.extract(some_data())
 
     # Find the extracted file
diff --git a/tests/load/pipeline/test_restore_state.py b/tests/load/pipeline/test_restore_state.py
index 37f999ff86..d263f165b7 100644
--- a/tests/load/pipeline/test_restore_state.py
+++ b/tests/load/pipeline/test_restore_state.py
@@ -65,9 +65,10 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) -
     with p.destination_client(p.default_schema.name) as job_client:  # type: ignore[assignment]
         with pytest.raises(DestinationUndefinedEntity):
             load_pipeline_state_from_destination(p.pipeline_name, job_client)
-        # sync the schema
-        p.sync_schema()
-        # check if schema exists
+    # sync the schema
+    p.sync_schema()
+    # check if schema exists
+    with p.destination_client(p.default_schema.name) as job_client:  # type: ignore[assignment]
         stored_schema = job_client.get_stored_schema()
         assert stored_schema is not None
         # dataset exists, still no table
@@ -93,77 +94,87 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) -
         # so dlt in normalize stage infers _state_version table again but with different column order and the column order in schema is different
         # then in database. parquet is created in schema order and in Redshift it must exactly match the order.
         # schema.bump_version()
-        p.sync_schema()
+    p.sync_schema()
+
+    with p.destination_client(p.default_schema.name) as job_client:  # type: ignore[assignment]
         stored_schema = job_client.get_stored_schema()
         assert stored_schema is not None
         # table is there but no state
         assert load_pipeline_state_from_destination(p.pipeline_name, job_client) is None
-        # extract state
-        with p.managed_state(extract_state=True):
-            pass
-        # just run the existing extract
-        p.normalize(loader_file_format=destination_config.file_format)
-        p.load()
+
+    # extract state
+    with p.managed_state(extract_state=True):
+        pass
+    # just run the existing extract
+    p.normalize(loader_file_format=destination_config.file_format)
+    p.load()
+
+    with p.destination_client(p.default_schema.name) as job_client:  # type: ignore[assignment]
         stored_state = load_pipeline_state_from_destination(p.pipeline_name, job_client)
-        local_state = p._get_state()
-        local_state.pop("_local")
-        assert stored_state == local_state
-        # extract state again
-        with p.managed_state(extract_state=True) as managed_state:
-            # this will be saved
-            managed_state["sources"] = {"source": dict(JSON_TYPED_DICT_DECODED)}
-        p.normalize(loader_file_format=destination_config.file_format)
-        p.load()
+    local_state = p._get_state()
+    local_state.pop("_local")
+    assert stored_state == local_state
+    # extract state again
+    with p.managed_state(extract_state=True) as managed_state:
+        # this will be saved
+        managed_state["sources"] = {"source": dict(JSON_TYPED_DICT_DECODED)}
+    p.normalize(loader_file_format=destination_config.file_format)
+    p.load()
+
+    with p.destination_client(p.default_schema.name) as job_client:  # type: ignore[assignment]
         stored_state = load_pipeline_state_from_destination(p.pipeline_name, job_client)
-        assert stored_state["sources"] == {"source": JSON_TYPED_DICT_DECODED}
-        local_state = p._get_state()
-        local_state.pop("_local")
-        assert stored_state == local_state
-        # use the state context manager again but do not change state
-        with p.managed_state(extract_state=True):
-            pass
-        # version not changed
-        new_local_state = p._get_state()
-        new_local_state.pop("_local")
-        assert local_state == new_local_state
-        p.normalize(loader_file_format=destination_config.file_format)
-        info = p.load()
-        assert len(info.loads_ids) == 0
+        assert stored_state["sources"] == {"source": JSON_TYPED_DICT_DECODED}
+    local_state = p._get_state()
+    local_state.pop("_local")
+    assert stored_state == local_state
+    # use the state context manager again but do not change state
+    with p.managed_state(extract_state=True):
+        pass
+    # version not changed
+    new_local_state = p._get_state()
+    new_local_state.pop("_local")
+    assert local_state == new_local_state
+    p.normalize(loader_file_format=destination_config.file_format)
+    info = p.load()
+    assert len(info.loads_ids) == 0
+
+    with p.destination_client(p.default_schema.name) as job_client:  # type: ignore[assignment]
         new_stored_state = load_pipeline_state_from_destination(p.pipeline_name, job_client)
-        # new state should not be stored
-        assert new_stored_state == stored_state
-
-        # change the state in context manager but there's no extract
-        with p.managed_state(extract_state=False) as managed_state:
-            managed_state["sources"] = {"source": "test2"}  # type: ignore[dict-item]
-        new_local_state = p._get_state()
-        new_local_state_local = new_local_state.pop("_local")
-        assert local_state != new_local_state
-        # version increased
-        assert local_state["_state_version"] + 1 == new_local_state["_state_version"]
-        # last extracted hash does not match current version hash
-        assert new_local_state_local["_last_extracted_hash"] != new_local_state["_version_hash"]
-
-        # use the state context manager again but do not change state
-        # because _last_extracted_hash is not present (or different), the version will not change but state will be extracted anyway
-        with p.managed_state(extract_state=True):
-            pass
-        new_local_state_2 = p._get_state()
-        new_local_state_2_local = new_local_state_2.pop("_local")
-        assert new_local_state == new_local_state_2
-        # there's extraction timestamp
-        assert "_last_extracted_at" in new_local_state_2_local
-        # and extract hash is == hash
-        assert new_local_state_2_local["_last_extracted_hash"] == new_local_state_2["_version_hash"]
-        # but the version didn't change
-        assert new_local_state["_state_version"] == new_local_state_2["_state_version"]
-        p.normalize(loader_file_format=destination_config.file_format)
-        info = p.load()
-        assert len(info.loads_ids) == 1
+    # new state should not be stored
+    assert new_stored_state == stored_state
+
+    # change the state in context manager but there's no extract
+    with p.managed_state(extract_state=False) as managed_state:
+        managed_state["sources"] = {"source": "test2"}  # type: ignore[dict-item]
+    new_local_state = p._get_state()
+    new_local_state_local = new_local_state.pop("_local")
+    assert local_state != new_local_state
+    # version increased
+    assert local_state["_state_version"] + 1 == new_local_state["_state_version"]
+    # last extracted hash does not match current version hash
+    assert new_local_state_local["_last_extracted_hash"] != new_local_state["_version_hash"]
+
+    # use the state context manager again but do not change state
+    # because _last_extracted_hash is not present (or different), the version will not change but state will be extracted anyway
+    with p.managed_state(extract_state=True):
+        pass
+    new_local_state_2 = p._get_state()
+    new_local_state_2_local = new_local_state_2.pop("_local")
+    assert new_local_state == new_local_state_2
+    # there's extraction timestamp
+    assert "_last_extracted_at" in new_local_state_2_local
+    # and extract hash is == hash
+    assert new_local_state_2_local["_last_extracted_hash"] == new_local_state_2["_version_hash"]
+    # but the version didn't change
+    assert new_local_state["_state_version"] == new_local_state_2["_state_version"]
+    p.normalize(loader_file_format=destination_config.file_format)
+    info = p.load()
+    assert len(info.loads_ids) == 1
+
+    with p.destination_client(p.default_schema.name) as job_client:  # type: ignore[assignment]
         new_stored_state_2 = load_pipeline_state_from_destination(p.pipeline_name, job_client)
-        # the stored state changed to next version
-        assert new_stored_state != new_stored_state_2
-        assert new_stored_state["_state_version"] + 1 == new_stored_state_2["_state_version"]
+    # the stored state changed to next version
+    assert new_stored_state != new_stored_state_2
+    assert new_stored_state["_state_version"] + 1 == new_stored_state_2["_state_version"]
 
 
 @pytest.mark.parametrize(
@@ -224,9 +235,10 @@ def _make_dn_name(schema_name: str) -> str:
 
     default_schema = Schema("state")
     p._inject_schema(default_schema)
+
+    # just sync schema without name - will use default schema
+    p.sync_schema()
     with p.destination_client() as job_client:
-        # just sync schema without name - will use default schema
-        p.sync_schema()
         assert get_normalized_dataset_name(
             job_client
         ) == default_schema.naming.normalize_table_identifier(dataset_name)
@@ -242,9 +254,9 @@ def _make_dn_name(schema_name: str) -> str:
         ) == schema_two.naming.normalize_table_identifier(_make_dn_name("two"))
     schema_three = Schema("three")
     p._inject_schema(schema_three)
+    # sync schema with a name
+    p.sync_schema(schema_three.name)
     with p._get_destination_clients(schema_three)[0] as job_client:
-        # sync schema with a name
-        p.sync_schema(schema_three.name)
         assert get_normalized_dataset_name(
             job_client
         ) == schema_three.naming.normalize_table_identifier(_make_dn_name("three"))
diff --git a/tests/load/qdrant/test_pipeline.py b/tests/load/qdrant/test_pipeline.py
index e0cb9dab84..a33ecd2a8d 100644
--- a/tests/load/qdrant/test_pipeline.py
+++ b/tests/load/qdrant/test_pipeline.py
@@ -1,9 +1,12 @@
 import pytest
 from typing import Iterator
+from tempfile import TemporaryDirectory
+import os
 
 import dlt
 from dlt.common import json
 from dlt.common.utils import uniq_id
+from dlt.common.typing import DictStrStr
 
 from dlt.destinations.adapters import qdrant_adapter
 from dlt.destinations.impl.qdrant.qdrant_adapter import qdrant_adapter, VECTORIZE_HINT
@@ -11,6 +14,7 @@
 from tests.pipeline.utils import assert_load_info
 from tests.load.qdrant.utils import drop_active_pipeline_data, assert_collection
 from tests.load.utils import sequence_generator
+from tests.utils import preserve_environ
 
 # mark all tests as essential, do not remove
 pytestmark = pytest.mark.essential
@@ -361,3 +365,20 @@ def test_empty_dataset_allowed() -> None:
     assert client.dataset_name is None
     assert client.sentinel_collection == "DltSentinelCollection"
     assert_collection(p, "content", expected_items_count=3)
+
+
+def test_qdrant_local_parallelism_disabled(preserve_environ) -> None:
+    os.environ["DATA_WRITER__FILE_MAX_ITEMS"] = "20"
+
+    with TemporaryDirectory() as tmpdir:
+        p = dlt.pipeline(destination=dlt.destinations.qdrant(path=tmpdir))
+
+        # Data writer limit ensures that we create multiple load files to the same table
+        @dlt.resource
+        def q_data():
+            for i in range(222):
+                yield {"doc_id": i, "content": f"content {i}"}
+
+        info = p.run(q_data)
+
+        assert_load_info(info)
diff --git a/tests/load/utils.py b/tests/load/utils.py
index 00ed4e3bf3..9ee933a07a 100644
--- a/tests/load/utils.py
+++ b/tests/load/utils.py
@@ -13,6 +13,7 @@
 from dlt.common.configuration import resolve_configuration
 from dlt.common.configuration.container import Container
 from dlt.common.configuration.specs.config_section_context import ConfigSectionContext
+from dlt.common.configuration.specs import CredentialsConfiguration
 from dlt.common.destination.reference import (
     DestinationClientDwhConfiguration,
     JobClientBase,
@@ -129,6 +130,7 @@ class DestinationTestConfiguration:
     supports_dbt: bool = True
     disable_compression: bool = False
     dev_mode: bool = False
+    credentials: Optional[Union[CredentialsConfiguration, Dict[str, Any]]] = None
 
     @property
     def name(self) -> str:
@@ -166,6 +168,10 @@ def setup(self) -> None:
         if self.destination == "filesystem" or self.disable_compression:
             os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True"
 
+        if self.credentials is not None:
+            for key, value in dict(self.credentials).items():
+                os.environ[f"DESTINATION__CREDENTIALS__{key.upper()}"] = str(value)
+
     def setup_pipeline(
         self, pipeline_name: str, dataset_name: str = None, dev_mode: bool = False, **kwargs
     ) -> dlt.Pipeline:
@@ -279,6 +285,12 @@ def destinations_configs(
         destination_configs += [
             DestinationTestConfiguration(destination="weaviate"),
             DestinationTestConfiguration(destination="lancedb"),
+            DestinationTestConfiguration(
+                destination="qdrant",
+                credentials=dict(path=str(Path(FILE_BUCKET) / "qdrant_data")),
+                extra_info="local-file",
+            ),
+            DestinationTestConfiguration(destination="qdrant", extra_info="server"),
         ]
 
     if default_staging_configs or all_staging_configs: