Skip to content

Commit

Permalink
Merge pull request #52 from MITLibraries/TIMX-417-read-from-dataset
Browse files Browse the repository at this point in the history
TIMX 417 - read from dataset
  • Loading branch information
ghukill authored Jan 3, 2025
2 parents 41b70e3 + 81cdccf commit 47e74c3
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 114 deletions.
18 changes: 17 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,23 @@ def fixed_local_dataset(tmp_path) -> TIMDEXDataset:
method.
"""
timdex_dataset = TIMDEXDataset(str(tmp_path / "fixed_local_dataset/"))
timdex_dataset.write(generate_sample_records(num_records=5_000, run_id="abc123"))
for source, run_id in [
("alma", "abc123"),
("dspace", "def456"),
("aspace", "ghi789"),
("libguides", "jkl123"),
("gismit", "mno456"),
]:
timdex_dataset.write(
generate_sample_records(
num_records=1_000,
timdex_record_id_prefix=source,
source=source,
run_date="2024-12-01",
run_id=run_id,
)
)
timdex_dataset.load()
return timdex_dataset


Expand Down
91 changes: 91 additions & 0 deletions tests/test_dataset_read.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# ruff: noqa: PLR2004, PD901

import pandas as pd
import pyarrow as pa
import pytest

# Column names the read tests in this module expect on every record batch
# and on every per-record dictionary when no column subset is requested.
DATASET_COLUMNS_SET = {
    # record identifier and payloads
    "timdex_record_id",
    "source_record",
    "transformed_record",
    # source and run metadata
    "source",
    "run_date",
    "run_type",
    "run_id",
    "action",
    # date parts (presumably derived from run_date -- confirm against the dataset schema)
    "year",
    "month",
    "day",
}


def test_read_batches_yields_pyarrow_record_batches(fixed_local_dataset):
    """The first item from ``read_batches_iter()`` is a ``pyarrow.RecordBatch``."""
    first_batch = next(fixed_local_dataset.read_batches_iter())
    assert isinstance(first_batch, pa.RecordBatch)


def test_read_batches_all_columns_by_default(fixed_local_dataset):
    """Without a ``columns`` argument, a batch carries every expected column."""
    first_batch = next(fixed_local_dataset.read_batches_iter())
    assert set(first_batch.column_names) == DATASET_COLUMNS_SET


def test_read_batches_filter_columns(fixed_local_dataset):
    """Passing ``columns`` limits a batch to exactly the requested columns."""
    requested_columns = ["source", "transformed_record"]
    first_batch = next(
        fixed_local_dataset.read_batches_iter(columns=requested_columns)
    )
    assert set(first_batch.column_names) == set(requested_columns)


def test_read_batches_no_filters_gets_full_dataset(fixed_local_dataset):
    """Unfiltered batches, combined into one table, cover every dataset row."""
    full_table = pa.Table.from_batches(fixed_local_dataset.read_batches_iter())
    assert len(full_table) == fixed_local_dataset.row_count


def test_read_batches_with_filters_gets_subset_of_dataset(fixed_local_dataset):
    """Filtered reads return a strict subset and leave the loaded dataset intact."""
    read_filters = {
        "source": "libguides",
        "run_date": "2024-12-01",
        "run_type": "daily",
        "action": "index",
    }
    filtered_table = pa.Table.from_batches(
        fixed_local_dataset.read_batches_iter(**read_filters)
    )

    filtered_row_count = len(filtered_table)
    assert filtered_row_count == 1_000
    assert filtered_row_count < fixed_local_dataset.row_count

    # filtering inside a read method must not change the loaded dataset itself
    assert fixed_local_dataset.row_count == 5_000


def test_read_dataframe_batches_yields_dataframes(fixed_local_dataset):
    """The first item from ``read_dataframes_iter()`` is a 1,000-row DataFrame."""
    first_frame = next(fixed_local_dataset.read_dataframes_iter())
    assert isinstance(first_frame, pd.DataFrame)
    assert len(first_frame) == 1_000


def test_read_dataframe_reads_all_dataset_rows_after_filtering(fixed_local_dataset):
    """``read_dataframe()`` returns one DataFrame with a row per dataset record."""
    # NOTE(review): despite the test name, no filter arguments are passed here.
    whole_frame = fixed_local_dataset.read_dataframe()
    assert isinstance(whole_frame, pd.DataFrame)
    assert len(whole_frame) == fixed_local_dataset.row_count


def test_read_dicts_yields_dictionary_for_each_dataset_record(fixed_local_dataset):
    """The first item from ``read_dicts_iter()`` is a dict keyed by every column."""
    first_record = next(fixed_local_dataset.read_dicts_iter())
    assert isinstance(first_record, dict)
    assert set(first_record) == DATASET_COLUMNS_SET


def test_read_batches_filter_to_none_returns_empty_list(fixed_local_dataset):
    """A filter that matches no records produces no batches at all."""
    unmatched_batches = fixed_local_dataset.read_batches_iter(
        source="not-gonna-find-me"
    )
    assert list(unmatched_batches) == []


def test_read_dicts_filter_to_none_stopiteration_immediately(fixed_local_dataset):
    """With a filter that matches nothing, the dict iterator is exhausted at once."""
    records = fixed_local_dataset.read_dicts_iter(source="not-gonna-find-me")
    with pytest.raises(StopIteration):
        next(records)
2 changes: 1 addition & 1 deletion timdex_dataset_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from timdex_dataset_api.dataset import TIMDEXDataset
from timdex_dataset_api.record import DatasetRecord

__version__ = "0.4.0"
__version__ = "0.5.0"

__all__ = [
"DatasetRecord",
Expand Down
Loading

0 comments on commit 47e74c3

Please sign in to comment.