added projection kwarg
zilto committed Jan 12, 2025
1 parent ed0f5a4 commit 7d52e1b
Showing 3 changed files with 200 additions and 15 deletions.
11 changes: 10 additions & 1 deletion sources/mongodb/__init__.py
@@ -1,6 +1,6 @@
"""Source that loads collections form any a mongo database, supports incremental loads."""

-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Union, Mapping

import dlt
from dlt.common.data_writers import TDataItemFormat
@@ -73,6 +73,7 @@ def mongodb(
parallel=parallel,
limit=limit,
filter_=filter_ or {},
projection=None,
)


@@ -90,6 +91,7 @@ def mongodb_collection(
chunk_size: Optional[int] = 10000,
data_item_format: Optional[TDataItemFormat] = "object",
filter_: Optional[Dict[str, Any]] = None,
projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
) -> Any:
"""
A DLT source which loads a collection from a mongo database using PyMongo.
@@ -109,6 +111,12 @@
object - Python objects (dicts, lists).
arrow - Apache Arrow tables.
filter_ (Optional[Dict[str, Any]]): The filter to apply to the collection.
projection (Optional[Union[Mapping[str, Any], Iterable[str]]]): The projection to select columns
when loading the collection. Supported inputs:
include (list) - ["year", "title"]
include (dict) - {"year": 1, "title": 1}
exclude (dict) - {"released": 0, "runtime": 0}
Note: include and exclude statements can't be mixed, e.g. `{"title": 1, "released": 0}` is invalid.
Returns:
Iterable[DltResource]: A list of DLT resources for each collection to be loaded.
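
A minimal usage sketch of the new kwarg, mirroring the tests added below (assumptions, not part of the diff: the MongoDB connection URL comes from dlt config/secrets, `duckdb` stands in as the destination, and the import path may differ):

import dlt
from sources.mongodb import mongodb_collection  # import path assumed

# Load only `title` and `poster`; MongoDB always returns `_id` unless excluded,
# and dlt adds `_dlt_id`/`_dlt_load_id` during the load.
movies = mongodb_collection(
    collection="movies",
    projection=["title", "poster"],  # equivalent to {"title": 1, "poster": 1}
    limit=2,
)

pipeline = dlt.pipeline(
    pipeline_name="mongodb_test",
    destination="duckdb",
    dataset_name="mongodb_test_data",
)
pipeline.run(movies)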
@@ -136,4 +144,5 @@ def mongodb_collection(
chunk_size=chunk_size,
data_item_format=data_item_format,
filter_=filter_ or {},
projection=projection,
)
104 changes: 90 additions & 14 deletions sources/mongodb/helpers.py
@@ -1,7 +1,7 @@
"""Mongo database source helpers"""

from itertools import islice
-from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union, Iterable, Mapping

import dlt
from bson.decimal128 import Decimal128
@@ -18,6 +18,7 @@
from pymongo import ASCENDING, DESCENDING, MongoClient
from pymongo.collection import Collection
from pymongo.cursor import Cursor
from pymongo.helpers import _fields_list_to_dict


if TYPE_CHECKING:
@@ -106,6 +107,41 @@ def _filter_op(self) -> Dict[str, Any]:
filt[self.cursor_field]["$gt"] = self.incremental.end_value

return filt

def _projection_op(
self, projection: Optional[Union[Mapping[str, Any], Iterable[str]]]
) -> Optional[Dict[str, Any]]:
"""Build a projection operator.
Accepts an iterable of fields to include or a dict specifying fields to include or exclude.
The incremental `primary_key` needs to be handled differently for inclusion
and exclusion projections.
Returns:
Optional[Dict[str, Any]]: A dictionary with the projection operator,
or None if no projection was given.
"""
if projection is None:
return None

projection_dict = dict(_fields_list_to_dict(projection, "projection"))

# NOTE: we can still filter on primary_key even if it's excluded from the projection
if self.incremental:
# this is an inclusion projection
if any(v == 1 for v in projection_dict.values()):
# ensure primary_key is included
projection_dict.update({self.incremental.primary_key: 1})
# this is an exclusion projection
else:
try:
# ensure primary_key isn't excluded
projection_dict.pop(self.incremental.primary_key)
except KeyError:
pass # primary_key wasn't in the exclusion projection, as expected
else:
dlt.common.logger.warning(
f"Primary key `{self.incremental.primary_key}` was removed from the exclusion projection"
)

return projection_dict
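# Illustration (not part of the commit), assuming an incremental whose primary_key is "_id":
#   _projection_op(["title", "poster"])        -> {"title": 1, "poster": 1, "_id": 1}
#   _projection_op({"title": 1, "poster": 1})  -> {"title": 1, "poster": 1, "_id": 1}
#   _projection_op({"_id": 0, "runtime": 0})   -> {"runtime": 0}, with a warning that
#                                                 the `_id` exclusion was dropped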

def _limit(self, cursor: Cursor, limit: Optional[int] = None) -> TCursor: # type: ignore
"""Apply a limit to the cursor, if needed.
@@ -128,7 +164,10 @@ def _limit(self, cursor: Cursor, limit: Optional[int] = None) -> TCursor: # type: ignore
return cursor

def load_documents(
-self, filter_: Dict[str, Any], limit: Optional[int] = None
+self,
+filter_: Dict[str, Any],
+limit: Optional[int] = None,
+projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
) -> Iterator[TDataItem]:
"""Construct the query and load the documents from the collection.
@@ -143,7 +182,9 @@ def load_documents(
_raise_if_intersection(filter_op, filter_)
filter_op.update(filter_)

-cursor = self.collection.find(filter=filter_op)
+projection_op = self._projection_op(projection)
+
+cursor = self.collection.find(filter=filter_op, projection=projection_op)
if self._sort_op:
cursor = cursor.sort(self._sort_op)

@@ -171,7 +212,11 @@ def _create_batches(self, limit: Optional[int] = None) -> List[Dict[str, int]]:

return batches

-def _get_cursor(self, filter_: Dict[str, Any]) -> TCursor:
+def _get_cursor(
+self,
+filter_: Dict[str, Any],
+projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
+) -> TCursor:
"""Get a reading cursor for the collection.
Args:
@@ -184,7 +229,9 @@ def _get_cursor(self, filter_: Dict[str, Any]) -> TCursor:
_raise_if_intersection(filter_op, filter_)
filter_op.update(filter_)

-cursor = self.collection.find(filter=filter_op)
+projection_op = self._projection_op(projection)
+
+cursor = self.collection.find(filter=filter_op, projection=projection_op)
if self._sort_op:
cursor = cursor.sort(self._sort_op)

@@ -201,7 +248,10 @@ def _run_batch(self, cursor: TCursor, batch: Dict[str, int]) -> TDataItem:
return data

def _get_all_batches(
-self, filter_: Dict[str, Any], limit: Optional[int] = None
+self,
+filter_: Dict[str, Any],
+limit: Optional[int] = None,
+projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
) -> Iterator[TDataItem]:
"""Load all documents from the collection in parallel batches.
@@ -213,13 +263,16 @@
Iterator[TDataItem]: An iterator of the loaded documents.
"""
batches = self._create_batches(limit=limit)
-cursor = self._get_cursor(filter_=filter_)
+cursor = self._get_cursor(filter_=filter_, projection=projection)

for batch in batches:
yield self._run_batch(cursor=cursor, batch=batch)

def load_documents(
-self, filter_: Dict[str, Any], limit: Optional[int] = None
+self,
+filter_: Dict[str, Any],
+limit: Optional[int] = None,
+projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
) -> Iterator[TDataItem]:
"""Load documents from the collection in parallel.
@@ -230,7 +283,9 @@
Yields:
Iterator[TDataItem]: An iterator of the loaded documents.
"""
-for document in self._get_all_batches(limit=limit, filter_=filter_):
+for document in self._get_all_batches(
+limit=limit, filter_=filter_, projection=projection
+):
yield document


@@ -241,7 +296,10 @@ class CollectionArrowLoader(CollectionLoader):
"""

def load_documents(
-self, filter_: Dict[str, Any], limit: Optional[int] = None
+self,
+filter_: Dict[str, Any],
+limit: Optional[int] = None,
+projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
) -> Iterator[Any]:
"""
Load documents from the collection in Apache Arrow format.
@@ -264,7 +322,12 @@ def load_documents(
_raise_if_intersection(filter_op, filter_)
filter_op.update(filter_)

-cursor = self.collection.find_raw_batches(filter_, batch_size=self.chunk_size)
+projection_op = self._projection_op(projection)
+
+cursor = self.collection.find_raw_batches(
+filter_op, batch_size=self.chunk_size, projection=projection_op
+)
if self._sort_op:
cursor = cursor.sort(self._sort_op) # type: ignore

@@ -283,7 +346,11 @@ class CollectionArrowLoaderParallel(CollectionLoaderParallel):
Apache Arrow for data processing.
"""

-def _get_cursor(self, filter_: Dict[str, Any]) -> TCursor:
+def _get_cursor(
+self,
+filter_: Dict[str, Any],
+projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None,
+) -> TCursor:
"""Get a reading cursor for the collection.
Args:
@@ -296,8 +363,10 @@ def _get_cursor(self, filter_: Dict[str, Any]) -> TCursor:
_raise_if_intersection(filter_op, filter_)
filter_op.update(filter_)

projection_op = self._projection_op(projection)

cursor = self.collection.find_raw_batches(
-filter=filter_op, batch_size=self.chunk_size
+filter=filter_op, batch_size=self.chunk_size, projection=projection_op
)
if self._sort_op:
cursor = cursor.sort(self._sort_op) # type: ignore
@@ -326,6 +395,7 @@ def collection_documents(
client: TMongoClient,
collection: TCollection,
filter_: Dict[str, Any],
projection: Union[Dict[str, Any], List[str]], # TODO kwargs reserved for dlt?
incremental: Optional[dlt.sources.incremental[Any]] = None,
parallel: bool = False,
limit: Optional[int] = None,
@@ -348,6 +418,12 @@
Supported formats:
object - Python objects (dicts, lists).
arrow - Apache Arrow tables.
projection (Optional[Union[Mapping[str, Any], Iterable[str]]]): The projection to select columns
when loading the collection. Supported inputs:
include (list) - ["year", "title"]
include (dict) - {"year": 1, "title": 1}
exclude (dict) - {"released": 0, "runtime": 0}
Note: include and exclude statements can't be mixed, e.g. `{"title": 1, "released": 0}` is invalid.
Returns:
Iterable[DltResource]: A list of DLT resources for each collection to be loaded.
@@ -372,7 +448,7 @@ def collection_documents(
loader = LoaderClass(
client, collection, incremental=incremental, chunk_size=chunk_size
)
-for data in loader.load_documents(limit=limit, filter_=filter_):
+for data in loader.load_documents(limit=limit, filter_=filter_, projection=projection):
yield data
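
The hunk above shows only the tail of `collection_documents`; the code that picks `LoaderClass` is collapsed. A plausible sketch of that dispatch, inferred from the four loader classes and the `parallel`/`data_item_format` parameters visible in this diff (an assumption, not the verbatim source):

if data_item_format == "arrow":
    LoaderClass = CollectionArrowLoaderParallel if parallel else CollectionArrowLoader
else:
    LoaderClass = CollectionLoaderParallel if parallel else CollectionLoader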


Expand Down
100 changes: 100 additions & 0 deletions tests/mongodb/test_mongodb_source.py
@@ -409,6 +409,106 @@ def test_filter_intersect(destination_name):
pipeline.run(movies)


@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS)
def test_projection_list_inclusion(destination_name):
pipeline = dlt.pipeline(
pipeline_name="mongodb_test",
destination=destination_name,
dataset_name="mongodb_test_data",
full_refresh=True,
)
collection_name = "movies"
projection = ["title", "poster"]
expected_columns = projection + ["_id", "_dlt_id", "_dlt_load_id"]

movies = mongodb_collection(
collection=collection_name,
projection=projection,
limit=2
)
pipeline.run(movies)
loaded_columns = pipeline.default_schema.get_table_columns(collection_name).keys()

assert set(loaded_columns) == set(expected_columns)


@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS)
def test_projection_dict_inclusion(destination_name):
pipeline = dlt.pipeline(
pipeline_name="mongodb_test",
destination=destination_name,
dataset_name="mongodb_test_data",
full_refresh=True,
)
collection_name = "movies"
projection = {"title": 1, "poster": 1}
expected_columns = list(projection.keys()) + ["_id", "_dlt_id", "_dlt_load_id"]

movies = mongodb_collection(
collection=collection_name,
projection=projection,
limit=2
)
pipeline.run(movies)
loaded_columns = pipeline.default_schema.get_table_columns(collection_name).keys()

assert set(loaded_columns) == set(expected_columns)


@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS)
def test_projection_dict_exclusion(destination_name):
pipeline = dlt.pipeline(
pipeline_name="mongodb_test",
destination=destination_name,
dataset_name="mongodb_test_data",
full_refresh=True,
)
collection_name = "movies"
columns_to_exclude = [
"runtime", "released", "year", "plot", "fullplot", "lastupdated", "type",
"directors", "imdb", "cast", "countries", "genres", "tomatoes", "num_mflix_comments",
"rated", "awards"
]
projection = {col: 0 for col in columns_to_exclude}
expected_columns = ["title", "poster", "_id", "_dlt_id", "_dlt_load_id"]

movies = mongodb_collection(
collection=collection_name,
projection=projection,
limit=2
)
pipeline.run(movies)
loaded_columns = pipeline.default_schema.get_table_columns(collection_name).keys()

assert set(loaded_columns) == set(expected_columns)


@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS)
def test_projection_nested_field(destination_name):
pipeline = dlt.pipeline(
pipeline_name="mongodb_test",
destination=destination_name,
dataset_name="mongodb_test_data",
full_refresh=True,
)
collection_name = "movies"
projection = ["imdb.votes", "poster"]
expected_columns = ["imdb__votes", "poster", "_id", "_dlt_id", "_dlt_load_id"]
# other fields nested under `imdb` shouldn't be loaded
not_expected_columns = ["imdb__rating", "imdb__id"]

movies = mongodb_collection(
collection=collection_name,
projection=projection,
limit=2
)
pipeline.run(movies)
loaded_columns = pipeline.default_schema.get_table_columns(collection_name).keys()

assert set(loaded_columns) == set(expected_columns)
assert len(set(loaded_columns).intersection(not_expected_columns)) == 0


@pytest.mark.parametrize("destination_name", ALL_DESTINATIONS)
@pytest.mark.parametrize("data_item_format", ["object", "arrow"])
def test_mongodb_without_pymongoarrow(
