From 623606de608d72f77ad18d098a9bfc39fb3c7341 Mon Sep 17 00:00:00 2001 From: Sylvain Lesage Date: Tue, 12 Apr 2022 10:15:54 +0200 Subject: [PATCH] Simplify cache by dropping two collections (#202) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * docs: โœ๏ธ add backup/restore to migration instructions * feat: ๐ŸŽธ pass the max number of rows to the worker * feat: ๐ŸŽธ delete the 'rows' and 'columns' collections instead of keeping a large collection of rows and columns, then compute the response on every endpoint call, possibly truncating the response, we now pre-compute the response and store it in the cache. We lose the ability to get the original data, but we don't need it. It fixes #197. See https://github.com/huggingface/datasets-preview-backend/issues/197#issuecomment-1092700076. BREAKING CHANGE: ๐Ÿงจ the cache database structure has been modified. Run 20220408_cache_remove_dbrow_dbcolumn.py to migrate the database. * style: ๐Ÿ’„ fix types and style * docs: โœ๏ธ add parameter to avoid error in mongodump * docs: โœ๏ธ mark ROWS_MAX_BYTES and ROWS_MIN_NUMBER as worker vars --- .env.example | 18 +- README.md | 5 +- src/datasets_preview_backend/config.py | 8 +- src/datasets_preview_backend/constants.py | 6 +- src/datasets_preview_backend/io/cache.py | 301 +++--------------- .../20220406_cache_dbrow_status_and_since.py | 21 +- .../20220408_cache_remove_dbrow_dbcolumn.py | 168 ++++++++++ .../io/migrations/README.md | 21 +- .../io/migrations/_utils.py | 12 +- .../io/migrations/validate.py | 5 + src/datasets_preview_backend/models/row.py | 19 +- src/datasets_preview_backend/models/split.py | 176 +++++++++- .../models/typed_row.py | 5 +- src/datasets_preview_backend/routes/rows.py | 10 +- src/datasets_preview_backend/worker.py | 10 + tests/io/test_cache.py | 24 +- tests/models/test_row.py | 25 +- tests/models/test_split.py | 9 +- tests/models/test_typed_row.py | 27 +- tests/test_app.py | 53 ++- 20 files changed, 562 insertions(+), 361 deletions(-) create mode 100644 src/datasets_preview_backend/io/migrations/20220408_cache_remove_dbrow_dbcolumn.py create mode 100644 src/datasets_preview_backend/io/migrations/validate.py diff --git a/.env.example b/.env.example index 6e289c410b..cde82ffcf5 100644 --- a/.env.example +++ b/.env.example @@ -35,15 +35,6 @@ # URL to connect to mongo db # MONGO_URL="mongodb://localhost:27018" -# Max size of the /rows endpoint response in bytes -# ROWS_MAX_BYTES=1_000_000 - -# Max number of rows in the /rows endpoint response -# ROWS_MAX_NUMBER=100 - -# Min number of rows in the /rows endpoint response -# ROWS_MIN_NUMBER=10 - # Number of uvicorn workers # WEB_CONCURRENCY = 2 @@ -66,6 +57,15 @@ # Max size (in bytes) of the dataset to fallback in normal mode if streaming fails # MAX_SIZE_FALLBACK = 100_000_000 +# Max size of the /rows endpoint response in bytes +# ROWS_MAX_BYTES=1_000_000 + +# Max number of rows in the /rows endpoint response +# ROWS_MAX_NUMBER=100 + +# Min number of rows in the /rows endpoint response +# ROWS_MIN_NUMBER=10 + # Number of seconds a worker will sleep before trying to process a new job # WORKER_SLEEP_SECONDS = 5 diff --git a/README.md b/README.md index aa60bf19b2..2359398814 100644 --- a/README.md +++ b/README.md @@ -38,9 +38,6 @@ Set environment variables to configure the following aspects: - `MONGO_CACHE_DATABASE`: the name of the database used for storing the cache. Defaults to `"datasets_preview_cache"`. - `MONGO_QUEUE_DATABASE`: the name of the database used for storing the queue. 
Defaults to `"datasets_preview_queue"`. - `MONGO_URL`: the URL used to connect to the mongo db server. Defaults to `"mongodb://localhost:27018"`. -- `ROWS_MAX_BYTES`: max size of the /rows endpoint response in bytes. Defaults to `1_000_000` (1 MB). -- `ROWS_MAX_NUMBER`: max number of rows in the /rows endpoint response. Defaults to `100`. -- `ROWS_MIN_NUMBER`: min number of rows in the /rows endpoint response. Defaults to `10`. - `WEB_CONCURRENCY`: the number of workers. For now, it's ignored and hardcoded to 1 because the cache is not shared yet. Defaults to `2`. For example: @@ -71,6 +68,8 @@ Also specify `HF_TOKEN` with an App Access Token (ask moonlanding administrators Also specify `MAX_SIZE_FALLBACK` with the maximum size in bytes of the dataset to fallback in normal mode if streaming fails. Note that it requires to have the size in the info metadata. Set to `0` to disable the fallback. Defaults to `100_000_000`. +`ROWS_MIN_NUMBER` is the min number (defaults to `10`) and `ROWS_MAX_NUMBER` the max number (defaults to `100`) of rows fetched by the worker for the split, and provided in the /rows endpoint response. `ROWS_MAX_BYTES` is the max size of the /rows endpoint response in bytes. Defaults to `1_000_000` (1 MB). + The `WORKER_QUEUE` variable specifies which jobs queue the worker will pull jobs from. It can be equal to `datasets` (default) or `splits`. The `datasets` jobs should be a lot faster than the `splits` ones, so that we should need a lot more workers for `splits` than for `datasets`. To warm the cache, ie. add all the missing Hugging Face datasets to the queue: diff --git a/src/datasets_preview_backend/config.py b/src/datasets_preview_backend/config.py index bba7fa3c68..704c6b4455 100644 --- a/src/datasets_preview_backend/config.py +++ b/src/datasets_preview_backend/config.py @@ -44,10 +44,12 @@ MONGO_CACHE_DATABASE = get_str_value(d=os.environ, key="MONGO_CACHE_DATABASE", default=DEFAULT_MONGO_CACHE_DATABASE) MONGO_QUEUE_DATABASE = get_str_value(d=os.environ, key="MONGO_QUEUE_DATABASE", default=DEFAULT_MONGO_QUEUE_DATABASE) MONGO_URL = get_str_value(d=os.environ, key="MONGO_URL", default=DEFAULT_MONGO_URL) -ROWS_MAX_BYTES = get_int_value(d=os.environ, key="ROWS_MAX_BYTES", default=DEFAULT_ROWS_MAX_BYTES) -ROWS_MAX_NUMBER = get_int_value(d=os.environ, key="ROWS_MAX_NUMBER", default=DEFAULT_ROWS_MAX_NUMBER) -ROWS_MIN_NUMBER = get_int_value(d=os.environ, key="ROWS_MIN_NUMBER", default=DEFAULT_ROWS_MIN_NUMBER) WEB_CONCURRENCY = get_int_value(d=os.environ, key="WEB_CONCURRENCY", default=DEFAULT_WEB_CONCURRENCY) # Ensure datasets library uses the expected revision for canonical datasets os.environ["HF_SCRIPTS_VERSION"] = DATASETS_REVISION + +# for tests - to be removed +ROWS_MAX_BYTES = get_int_value(d=os.environ, key="ROWS_MAX_BYTES", default=DEFAULT_ROWS_MAX_BYTES) +ROWS_MAX_NUMBER = get_int_value(d=os.environ, key="ROWS_MAX_NUMBER", default=DEFAULT_ROWS_MAX_NUMBER) +ROWS_MIN_NUMBER = get_int_value(d=os.environ, key="ROWS_MIN_NUMBER", default=DEFAULT_ROWS_MIN_NUMBER) diff --git a/src/datasets_preview_backend/constants.py b/src/datasets_preview_backend/constants.py index f1a8c4b047..ac90a77271 100644 --- a/src/datasets_preview_backend/constants.py +++ b/src/datasets_preview_backend/constants.py @@ -11,9 +11,6 @@ DEFAULT_MONGO_CACHE_DATABASE: str = "datasets_preview_cache" DEFAULT_MONGO_QUEUE_DATABASE: str = "datasets_preview_queue" DEFAULT_MONGO_URL: str = "mongodb://localhost:27018" -DEFAULT_ROWS_MAX_BYTES: int = 1_000_000 -DEFAULT_ROWS_MAX_NUMBER: int = 100 
-DEFAULT_ROWS_MIN_NUMBER: int = 10 DEFAULT_WEB_CONCURRENCY: int = 2 DEFAULT_HF_TOKEN: Optional[str] = None @@ -21,6 +18,9 @@ DEFAULT_MAX_LOAD_PCT: int = 50 DEFAULT_MAX_MEMORY_PCT: int = 60 DEFAULT_MAX_SIZE_FALLBACK: int = 100_000_000 +DEFAULT_ROWS_MAX_BYTES: int = 1_000_000 +DEFAULT_ROWS_MAX_NUMBER: int = 100 +DEFAULT_ROWS_MIN_NUMBER: int = 10 DEFAULT_WORKER_SLEEP_SECONDS: int = 5 DEFAULT_WORKER_QUEUE: str = "datasets" diff --git a/src/datasets_preview_backend/io/cache.py b/src/datasets_preview_backend/io/cache.py index 6c6441d2df..3db895b188 100644 --- a/src/datasets_preview_backend/io/cache.py +++ b/src/datasets_preview_backend/io/cache.py @@ -1,6 +1,5 @@ import enum import logging -import sys import types from datetime import datetime from typing import ( @@ -26,25 +25,19 @@ StringField, ) from mongoengine.queryset.queryset import QuerySet +from pymongo.errors import DocumentTooLarge from datasets_preview_backend.config import MONGO_CACHE_DATABASE, MONGO_URL -from datasets_preview_backend.constants import DEFAULT_MIN_CELL_BYTES from datasets_preview_backend.exceptions import ( Status400Error, Status500Error, StatusError, ) -from datasets_preview_backend.models.column import ( - ClassLabelColumn, - ColumnDict, - ColumnType, -) from datasets_preview_backend.models.dataset import ( SplitFullName, get_dataset_split_full_names, ) from datasets_preview_backend.models.split import Split, get_split -from datasets_preview_backend.utils import orjson_dumps # START monkey patching ### hack ### # see https://github.com/sbdchd/mongo-types#install @@ -102,6 +95,10 @@ class SplitsResponse(TypedDict): splits: List[SplitItem] +def get_empty_rows_response() -> Dict[str, Any]: + return {"columns": [], "rows": []} + + class DbSplit(Document): dataset_name = StringField(required=True, unique_with=["config_name", "split_name"]) config_name = StringField(required=True) @@ -109,10 +106,11 @@ class DbSplit(Document): split_idx = IntField(required=True, min_value=0) # used to maintain the order num_bytes = IntField(min_value=0) num_examples = IntField(min_value=0) + rows_response = DictField(required=True) status = EnumField(Status, default=Status.EMPTY) since = DateTimeField(default=datetime.utcnow) - def to_item(self) -> SplitItem: + def to_split_item(self) -> SplitItem: return { "dataset": self.dataset_name, "config": self.config_name, @@ -128,86 +126,6 @@ def to_split_full_name(self) -> SplitFullName: objects = QuerySetManager["DbSplit"]() -class RowItem(TypedDict): - dataset: str - config: str - split: str - row_idx: int - row: Dict[str, Any] - truncated_cells: List[str] - - -class DbRow(Document): - dataset_name = StringField(required=True, unique_with=["config_name", "split_name", "row_idx"]) - config_name = StringField(required=True) - split_name = StringField(required=True) - row_idx = IntField(required=True, min_value=0) - row = DictField(required=True) - status = EnumField(Status, default=Status.EMPTY) - since = DateTimeField(default=datetime.utcnow) - - def to_item(self) -> RowItem: - if self.status == Status.VALID: - return { - "dataset": self.dataset_name, - "config": self.config_name, - "split": self.split_name, - "row_idx": self.row_idx, - "row": self.row, - "truncated_cells": [], - } - else: - return { - "dataset": self.dataset_name, - "config": self.config_name, - "split": self.split_name, - "row_idx": self.row_idx, - "row": self.row, - "truncated_cells": list(self.row.keys()), - } - - meta = {"collection": "rows", "db_alias": "cache"} - objects = QuerySetManager["DbRow"]() - - -class 
ColumnItem(TypedDict): - dataset: str - config: str - split: str - column_idx: int - column: ColumnDict - - -class RowsResponse(TypedDict): - columns: List[ColumnItem] - rows: List[RowItem] - - -class DbColumn(Document): - dataset_name = StringField(required=True, unique_with=["config_name", "split_name", "name"]) - config_name = StringField(required=True) - split_name = StringField(required=True) - column_idx = IntField(required=True, min_value=0) - name = StringField(required=True) - type = EnumField(ColumnType, required=True) - labels = ListField(StringField()) - - def to_item(self) -> ColumnItem: - column: ColumnDict = {"name": self.name, "type": self.type.name} - if self.labels: - column["labels"] = self.labels - return { - "dataset": self.dataset_name, - "config": self.config_name, - "split": self.split_name, - "column_idx": self.column_idx, - "column": column, - } - - meta = {"collection": "columns", "db_alias": "cache"} - objects = QuerySetManager["DbColumn"]() - - class _BaseErrorItem(TypedDict): status_code: int exception: str @@ -269,8 +187,6 @@ def to_item(self) -> ErrorItem: def upsert_dataset_error(dataset_name: str, error: StatusError) -> None: DbSplit.objects(dataset_name=dataset_name).delete() - DbRow.objects(dataset_name=dataset_name).delete() - DbColumn.objects(dataset_name=dataset_name).delete() DbDataset.objects(dataset_name=dataset_name).upsert_one(status=Status.ERROR) DbDatasetError.objects(dataset_name=dataset_name).upsert_one( status_code=error.status_code, @@ -300,8 +216,6 @@ def upsert_dataset(dataset_name: str, new_split_full_names: List[SplitFullName]) def upsert_split_error(dataset_name: str, config_name: str, split_name: str, error: StatusError) -> None: - DbRow.objects(dataset_name=dataset_name, config_name=config_name, split_name=split_name).delete() - DbColumn.objects(dataset_name=dataset_name, config_name=config_name, split_name=split_name).delete() DbSplit.objects(dataset_name=dataset_name, config_name=config_name, split_name=split_name).upsert_one( status=Status.ERROR ) @@ -315,58 +229,29 @@ def upsert_split_error(dataset_name: str, config_name: str, split_name: str, err ) -def upsert_split(dataset_name: str, config_name: str, split_name: str, split: Split) -> None: - rows = split["rows"] - columns = split["columns"] - - DbSplit.objects(dataset_name=dataset_name, config_name=config_name, split_name=split_name).upsert_one( - status=Status.VALID, num_bytes=split["num_bytes"], num_examples=split["num_examples"] - ) - DbSplitError.objects(dataset_name=dataset_name, config_name=config_name, split_name=split_name).delete() - - DbRow.objects(dataset_name=dataset_name, config_name=config_name, split_name=split_name).delete() - for row_idx, row in enumerate(rows): - try: - DbRow( - dataset_name=dataset_name, - config_name=config_name, - split_name=split_name, - row_idx=row_idx, - row=row, - status=Status.VALID, - ).save() - except Exception: - DbRow( - dataset_name=dataset_name, - config_name=config_name, - split_name=split_name, - row_idx=row_idx, - row={column_name: "" for column_name in row.keys()}, - # ^ truncated to empty string - status=Status.ERROR, - ).save() - - DbColumn.objects(dataset_name=dataset_name, config_name=config_name, split_name=split_name).delete() - for column_idx, column in enumerate(columns): - db_column = DbColumn( - dataset_name=dataset_name, - config_name=config_name, - split_name=split_name, - column_idx=column_idx, - name=column.name, - type=column.type, +def upsert_split( + dataset_name: str, + config_name: str, + split_name: str, 
+ split: Split, +) -> None: + try: + DbSplit.objects(dataset_name=dataset_name, config_name=config_name, split_name=split_name).upsert_one( + status=Status.VALID, + num_bytes=split["num_bytes"], + num_examples=split["num_examples"], + rows_response=split["rows_response"], # TODO: a class method + ) + DbSplitError.objects(dataset_name=dataset_name, config_name=config_name, split_name=split_name).delete() + except DocumentTooLarge as err: + upsert_split_error( + dataset_name, config_name, split_name, Status500Error("could not store the rows/ cache entry.", err) ) - # TODO: seems like suboptimal code, introducing unnecessary coupling - if isinstance(column, ClassLabelColumn): - db_column.labels = column.labels - db_column.save() def delete_dataset_cache(dataset_name: str) -> None: DbDataset.objects(dataset_name=dataset_name).delete() DbSplit.objects(dataset_name=dataset_name).delete() - DbRow.objects(dataset_name=dataset_name).delete() - DbColumn.objects(dataset_name=dataset_name).delete() DbDatasetError.objects(dataset_name=dataset_name).delete() DbSplitError.objects(dataset_name=dataset_name).delete() @@ -374,8 +259,6 @@ def delete_dataset_cache(dataset_name: str) -> None: def clean_database() -> None: DbDataset.drop_collection() # type: ignore DbSplit.drop_collection() # type: ignore - DbRow.drop_collection() # type: ignore - DbColumn.drop_collection() # type: ignore DbDatasetError.drop_collection() # type: ignore DbSplitError.drop_collection() # type: ignore @@ -400,10 +283,8 @@ def delete_split(split_full_name: SplitFullName): dataset_name = split_full_name["dataset_name"] config_name = split_full_name["config_name"] split_name = split_full_name["split_name"] - DbRow.objects(dataset_name=dataset_name, config_name=config_name, split_name=split_name).delete() - DbColumn.objects(dataset_name=dataset_name, config_name=config_name, split_name=split_name).delete() DbSplit.objects(dataset_name=dataset_name, config_name=config_name, split_name=split_name).delete() - # DbSplitError.objects(dataset_name=dataset_name, config_name=config_name, split_name=split_name).delete() + DbSplitError.objects(dataset_name=dataset_name, config_name=config_name, split_name=split_name).delete() logger.debug(f"dataset '{dataset_name}': deleted split {split_name} from config {config_name}") @@ -417,6 +298,7 @@ def create_empty_split(split_full_name: SplitFullName, split_idx: int): split_name=split_name, status=Status.EMPTY, split_idx=split_idx, + rows_response=get_empty_rows_response(), ).save() logger.debug(f"dataset '{dataset_name}': created empty split {split_name} in config {config_name}") @@ -467,10 +349,20 @@ def refresh_split( split_name: str, hf_token: Optional[str] = None, max_size_fallback: Optional[int] = None, + rows_max_bytes: Optional[int] = None, + rows_max_number: Optional[int] = None, + rows_min_number: Optional[int] = None, ): try: split = get_split( - dataset_name, config_name, split_name, hf_token=hf_token, max_size_fallback=max_size_fallback + dataset_name, + config_name, + split_name, + hf_token=hf_token, + max_size_fallback=max_size_fallback, + rows_max_bytes=rows_max_bytes, + rows_max_number=rows_max_number, + rows_min_number=rows_min_number, ) upsert_split(dataset_name, config_name, split_name, split) logger.debug( @@ -525,113 +417,18 @@ def get_splits_response(dataset_name: str) -> Tuple[Union[SplitsResponse, None], return None, dataset_error.to_item(), dataset_error.status_code splits_response: SplitsResponse = { - "splits": [split.to_item() for split in 
DbSplit.objects(dataset_name=dataset_name).order_by("+split_idx")] + "splits": [ + split.to_split_item() for split in DbSplit.objects(dataset_name=dataset_name).order_by("+split_idx") + ] } return splits_response, None, 200 -def get_size_in_bytes(obj: Any): - return sys.getsizeof(orjson_dumps(obj)) - # ^^ every row is transformed here in a string, because it corresponds to - # the size the row will contribute in the JSON response to /rows endpoint. - # The size of the string is measured in bytes. - # An alternative would have been to look at the memory consumption (pympler) but it's - # less related to what matters here (size of the JSON, number of characters in the - # dataset viewer table on the hub) - - -def truncate_cell(cell: Any, min_cell_bytes: int) -> str: - return orjson_dumps(cell)[:min_cell_bytes].decode("utf8", "ignore") - - -# Mutates row_item, and returns it anyway -def truncate_row_item(row_item: RowItem) -> RowItem: - min_cell_bytes = DEFAULT_MIN_CELL_BYTES - row = {} - for column_name, cell in row_item["row"].items(): - # for now: all the cells, but the smallest ones, are truncated - cell_bytes = get_size_in_bytes(cell) - if cell_bytes > min_cell_bytes: - row_item["truncated_cells"].append(column_name) - row[column_name] = truncate_cell(cell, min_cell_bytes) - else: - row[column_name] = cell - row_item["row"] = row - return row_item - - -# Mutates row_items, and returns them anyway -def truncate_row_items(row_items: List[RowItem], rows_max_bytes: int) -> List[RowItem]: - # compute the current size - rows_bytes = sum(get_size_in_bytes(row_item) for row_item in row_items) - - # Loop backwards, so that the last rows are truncated first - for row_item in reversed(row_items): - previous_size = get_size_in_bytes(row_item) - row_item = truncate_row_item(row_item) - new_size = get_size_in_bytes(row_item) - rows_bytes += new_size - previous_size - row_idx = row_item["row_idx"] - logger.debug(f"the size of the rows is now ({rows_bytes}) after truncating row idx={row_idx}") - if rows_bytes < rows_max_bytes: - break - return row_items - - -def to_row_items( - rows: QuerySet[DbRow], rows_max_bytes: Optional[int], rows_min_number: Optional[int] -) -> List[RowItem]: - row_items = [] - rows_bytes = 0 - if rows_min_number is None: - rows_min_number = 0 - else: - logger.debug(f"min number of rows in the response: '{rows_min_number}'") - if rows_max_bytes is not None: - logger.debug(f"max number of bytes in the response: '{rows_max_bytes}'") - - # two restrictions must be enforced: - # - at least rows_min_number rows - # - at most rows_max_bytes bytes - # To enforce this: - # 1. first get the first rows_min_number rows - for row in rows[:rows_min_number]: - row_item = row.to_item() - if rows_max_bytes is not None: - rows_bytes += get_size_in_bytes(row_item) - row_items.append(row_item) - - # 2. if the total is over the bytes limit, truncate the values, iterating backwards starting - # from the last rows, until getting under the threshold - if rows_max_bytes is not None and rows_bytes >= rows_max_bytes: - logger.debug( - f"the size of the first {rows_min_number} rows ({rows_bytes}) is above the max number of bytes" - f" ({rows_max_bytes}), they will be truncated" - ) - return truncate_row_items(row_items, rows_max_bytes) - - # 3. 
else: add the remaining rows until the end, or until the bytes threshold - for idx, row in enumerate(rows[rows_min_number:]): - row_item = row.to_item() - if rows_max_bytes is not None: - rows_bytes += get_size_in_bytes(row_item) - if rows_bytes >= rows_max_bytes: - logger.debug( - f"the rows in the split have been truncated to {rows_min_number + idx} row(s) to keep the size" - f" ({rows_bytes}) under the limit ({rows_max_bytes})" - ) - break - row_items.append(row_item) - return row_items - - def get_rows_response( dataset_name: str, config_name: str, split_name: str, - rows_max_bytes: Optional[int] = None, - rows_min_number: Optional[int] = None, -) -> Tuple[Union[RowsResponse, None], Union[ErrorItem, None], int]: +) -> Tuple[Union[Dict[str, Any], None], Union[ErrorItem, None], int]: try: split = DbSplit.objects(dataset_name=dataset_name, config_name=config_name, split_name=split_name).get() except DoesNotExist as e: @@ -649,21 +446,7 @@ def get_rows_response( # ^ can raise DoesNotExist or MultipleObjectsReturned, which should not occur -> we let the exception raise return None, split_error.to_item(), split_error.status_code - # TODO: if status is Status.STALLED, mention it in the response? - columns = DbColumn.objects(dataset_name=dataset_name, config_name=config_name, split_name=split_name).order_by( - "+column_idx" - ) - # TODO: on some datasets, such as "edbeeching/decision_transformer_gym_replay", it takes a long time, and we - # truncate it anyway in to_row_items(). We might optimize here - rows = DbRow.objects(dataset_name=dataset_name, config_name=config_name, split_name=split_name).order_by( - "+row_idx" - ) - row_items = to_row_items(rows, rows_max_bytes, rows_min_number) - rows_response: RowsResponse = { - "columns": [column.to_item() for column in columns], - "rows": row_items, - } - return rows_response, None, 200 + return split.rows_response, None, 200 # special reports diff --git a/src/datasets_preview_backend/io/migrations/20220406_cache_dbrow_status_and_since.py b/src/datasets_preview_backend/io/migrations/20220406_cache_dbrow_status_and_since.py index 74241e5f40..fe6f827f92 100644 --- a/src/datasets_preview_backend/io/migrations/20220406_cache_dbrow_status_and_since.py +++ b/src/datasets_preview_backend/io/migrations/20220406_cache_dbrow_status_and_since.py @@ -1,19 +1,14 @@ from datetime import datetime -from datasets_preview_backend.io.cache import DbRow, Status, connect_to_cache -from datasets_preview_backend.io.migrations._utils import check_documents +from pymongo import MongoClient -# connect -connect_to_cache() - -# migrate -DbRow.objects().update(status=Status.VALID, since=datetime.utcnow) +from datasets_preview_backend.config import MONGO_CACHE_DATABASE, MONGO_URL +from datasets_preview_backend.io.cache import Status +client = MongoClient(MONGO_URL) +db = client[MONGO_CACHE_DATABASE] -# validate -def custom_validation(row: DbRow) -> None: - if row.status != Status.VALID: - raise ValueError(f"row status should be '{Status.VALID}', got '{row.status}'") - -check_documents(DbRow, 100, custom_validation) +# migrate +rows_coll = db.rows +rows_coll.update_many({}, {"$set": {"status": Status.VALID.value, "since": datetime.utcnow}}) diff --git a/src/datasets_preview_backend/io/migrations/20220408_cache_remove_dbrow_dbcolumn.py b/src/datasets_preview_backend/io/migrations/20220408_cache_remove_dbrow_dbcolumn.py new file mode 100644 index 0000000000..62625f7e4c --- /dev/null +++ b/src/datasets_preview_backend/io/migrations/20220408_cache_remove_dbrow_dbcolumn.py @@ -0,0 
+1,168 @@ +import base64 +import sys +from enum import Enum, auto +from typing import Any, Dict, List, TypedDict + +import orjson +from pymongo import MongoClient + +from datasets_preview_backend.config import MONGO_CACHE_DATABASE, MONGO_URL +from datasets_preview_backend.io.cache import Status + +client = MongoClient(MONGO_URL) +db = client[MONGO_CACHE_DATABASE] + + +# copy code required for the migration (it might disappear in next iterations) +class RowItem(TypedDict): + dataset: str + config: str + split: str + row_idx: int + row: Dict[str, Any] + truncated_cells: List[str] + + +class ColumnType(Enum): + JSON = auto() # default + BOOL = auto() + INT = auto() + FLOAT = auto() + STRING = auto() + IMAGE_URL = auto() + RELATIVE_IMAGE_URL = auto() + AUDIO_RELATIVE_SOURCES = auto() + CLASS_LABEL = auto() + + +def get_empty_rows_response() -> Dict[str, Any]: + return {"columns": [], "rows": []} + + +def to_column_item(column: Dict[str, Any]) -> Dict[str, Any]: + column_field = { + "name": column["name"], + "type": ColumnType(column["type"]).name, + } + if "labels" in column and len(column["labels"]) > 0: + column_field["labels"] = column["labels"] + + return { + "dataset": column["dataset_name"], + "config": column["config_name"], + "split": column["split_name"], + "column_idx": column["column_idx"], + "column": column_field, + } + + +# orjson is used to get rid of errors with datetime (see allenai/c4) +def orjson_default(obj: Any) -> Any: + if isinstance(obj, bytes): + return base64.b64encode(obj).decode("utf-8") + raise TypeError + + +def orjson_dumps(content: Any) -> bytes: + return orjson.dumps(content, option=orjson.OPT_UTC_Z, default=orjson_default) + + +def get_size_in_bytes(obj: Any): + return sys.getsizeof(orjson_dumps(obj)) + # ^^ every row is transformed here in a string, because it corresponds to + # the size the row will contribute in the JSON response to /rows endpoint. + # The size of the string is measured in bytes. 
+ # An alternative would have been to look at the memory consumption (pympler) but it's + # less related to what matters here (size of the JSON, number of characters in the + # dataset viewer table on the hub) + + +def truncate_cell(cell: Any, min_cell_bytes: int) -> str: + return orjson_dumps(cell)[:min_cell_bytes].decode("utf8", "ignore") + + +DEFAULT_MIN_CELL_BYTES = 100 + + +# Mutates row_item, and returns it anyway +def truncate_row_item(row_item: RowItem) -> RowItem: + min_cell_bytes = DEFAULT_MIN_CELL_BYTES + row = {} + for column_name, cell in row_item["row"].items(): + # for now: all the cells, but the smallest ones, are truncated + cell_bytes = get_size_in_bytes(cell) + if cell_bytes > min_cell_bytes: + row_item["truncated_cells"].append(column_name) + row[column_name] = truncate_cell(cell, min_cell_bytes) + else: + row[column_name] = cell + row_item["row"] = row + return row_item + + +# Mutates row_items, and returns them anyway +def truncate_row_items(row_items: List[RowItem], rows_max_bytes: int) -> List[RowItem]: + # compute the current size + rows_bytes = sum(get_size_in_bytes(row_item) for row_item in row_items) + + # Loop backwards, so that the last rows are truncated first + for row_item in reversed(row_items): + if rows_bytes < rows_max_bytes: + break + previous_size = get_size_in_bytes(row_item) + row_item = truncate_row_item(row_item) + new_size = get_size_in_bytes(row_item) + rows_bytes += new_size - previous_size + return row_items + + +def to_row_item(row: Dict[str, Any]) -> RowItem: + return { + "dataset": row["dataset_name"], + "config": row["config_name"], + "split": row["split_name"], + "row_idx": row["row_idx"], + "row": row["row"], + "truncated_cells": [], + } + + +# migrate +rows_max_bytes = 1_000_000 +splits_coll = db.splits +rows_coll = db.rows +columns_coll = db.columns +splits_coll.update_many({}, {"$set": {"rows_response": get_empty_rows_response()}}) +# ^ add the new field to all the splits +for split in splits_coll.find({"status": {"$in": [Status.VALID.value, Status.STALLED.value]}}): + print(f"update split {split}") + columns = list( + columns_coll.find( + { + "dataset_name": split["dataset_name"], + "config_name": split["config_name"], + "split_name": split["split_name"], + } + ) + ) + print(f"found {len(columns)} columns") + rows = list( + rows_coll.find( + { + "dataset_name": split["dataset_name"], + "config_name": split["config_name"], + "split_name": split["split_name"], + } + ) + ) + print(f"found {len(rows)} rows") + column_items = [to_column_item(column) for column in sorted(columns, key=lambda d: d["column_idx"])] + row_items = truncate_row_items( + [to_row_item(row) for row in sorted(rows, key=lambda d: d["row_idx"])], rows_max_bytes + ) + rows_response = {"columns": column_items, "rows": row_items} + splits_coll.update_one({"_id": split["_id"]}, {"$set": {"rows_response": rows_response}}) + +# ^ fill the rows_response field, only for VALID and STALLED +db["rows"].drop() +db["columns"].drop() diff --git a/src/datasets_preview_backend/io/migrations/README.md b/src/datasets_preview_backend/io/migrations/README.md index fc32eb3a30..7c810ccb82 100644 --- a/src/datasets_preview_backend/io/migrations/README.md +++ b/src/datasets_preview_backend/io/migrations/README.md @@ -8,10 +8,29 @@ When the structure of a database is changed, the data stored in the database mus The commit, and the release, MUST always give the list of migration scripts that must be applied to migrate. 
+Before applying the migration script, be sure to **back up** the database, in case of failure.
+
+```shell
+mongodump --forceTableScan --uri=mongodb://localhost:27018 --archive=dump.bson
+```
+
 To run a script, for example [20220406_cache_dbrow_status_and_since.py](./20220406_cache_dbrow_status_and_since.py):
 
 ```shell
-poetry run python src/datasets_preview_backend/io/migrations/20220406_cache_dbrow_status_and_since.py
+poetry run python src/datasets_preview_backend/io/migrations/<script_name>.py
+```
+
+Then, validate with:
+
+```shell
+poetry run python src/datasets_preview_backend/io/migrations/validate.py
+```
+
+In case of **error**, restore the database; otherwise, remove the dump file:
+
+```shell
+# only in case of error!
+mongorestore --drop --uri=mongodb://localhost:27018 --archive=dump.bson
 ```
 
 ## Write a migration script
diff --git a/src/datasets_preview_backend/io/migrations/_utils.py b/src/datasets_preview_backend/io/migrations/_utils.py
index 0e6786a8fe..898cc4e26c 100644
--- a/src/datasets_preview_backend/io/migrations/_utils.py
+++ b/src/datasets_preview_backend/io/migrations/_utils.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, ClassVar, Iterator, List, Optional, Type, TypeVar
+from typing import Callable, Iterator, List, Optional, Type, TypeVar
 
 from mongoengine import Document
 from pymongo.collection import Collection
@@ -9,11 +9,7 @@ class DocumentWithId(Document):
     id: str
 
 
-class ExtendedDocument(DocumentWithId):
-    objects: ClassVar[Callable[[Any], DocumentWithId]]
-
-
-U = TypeVar("U", bound=ExtendedDocument)
+U = TypeVar("U", bound=DocumentWithId)
 DocumentClass = Type[U]
 CustomValidation = Callable[[U], None]
 # --- end
@@ -27,10 +23,10 @@ def get_random_oids(collection: Collection, sample_size: int) -> List[int]:
 def get_random_documents(DocCls: DocumentClass, sample_size: int) -> Iterator[DocumentWithId]:
     doc_collection = DocCls._get_collection()
     random_oids = get_random_oids(doc_collection, sample_size)
-    return DocCls.objects(id__in=random_oids)
+    return DocCls.objects(id__in=random_oids)  # type: ignore
 
 
-def check_documents(DocCls: DocumentClass, sample_size: int, custom_validation: Optional[CustomValidation]):
+def check_documents(DocCls: DocumentClass, sample_size: int, custom_validation: Optional[CustomValidation] = None):
     for doc in get_random_documents(DocCls, sample_size):
         # general validation (types and values)
         doc.validate()
diff --git a/src/datasets_preview_backend/io/migrations/validate.py b/src/datasets_preview_backend/io/migrations/validate.py
new file mode 100644
index 0000000000..726f1ca90c
--- /dev/null
+++ b/src/datasets_preview_backend/io/migrations/validate.py
@@ -0,0 +1,5 @@
+from datasets_preview_backend.io.cache import DbSplit, connect_to_cache
+from datasets_preview_backend.io.migrations._utils import check_documents
+
+connect_to_cache()
+check_documents(DbSplit, 100)
diff --git a/src/datasets_preview_backend/models/row.py b/src/datasets_preview_backend/models/row.py
index 5e9e5480a1..f0c4c723be 100644
--- a/src/datasets_preview_backend/models/row.py
+++ b/src/datasets_preview_backend/models/row.py
@@ -4,7 +4,7 @@
 
 from datasets import Dataset, DownloadMode, IterableDataset, load_dataset
 
-from datasets_preview_backend.config import ROWS_MAX_NUMBER
+from datasets_preview_backend.constants import DEFAULT_ROWS_MAX_NUMBER
 from datasets_preview_backend.utils import retry
 
 logger = logging.getLogger(__name__)
@@ -15,8 +15,15 @@
 
 @retry(logger=logger)
 def get_rows(
-    dataset_name: str, config_name: str, split_name: str, hf_token: Optional[str] = None, streaming: bool = 
True + dataset_name: str, + config_name: str, + split_name: str, + hf_token: Optional[str] = None, + streaming: bool = True, + rows_max_number: Optional[int] = None, ) -> List[Row]: + if rows_max_number is None: + rows_max_number = DEFAULT_ROWS_MAX_NUMBER dataset = load_dataset( dataset_name, name=config_name, @@ -30,10 +37,10 @@ def get_rows( raise TypeError("load_dataset should return an IterableDataset") elif not isinstance(dataset, Dataset): raise TypeError("load_dataset should return a Dataset") - rows_plus_one = list(itertools.islice(dataset, ROWS_MAX_NUMBER + 1)) + rows_plus_one = list(itertools.islice(dataset, rows_max_number + 1)) # ^^ to be able to detect if a split has exactly ROWS_MAX_NUMBER rows - if len(rows_plus_one) <= ROWS_MAX_NUMBER: + if len(rows_plus_one) <= rows_max_number: logger.debug(f"all the rows in the split have been fetched ({len(rows_plus_one)})") else: - logger.debug(f"the rows in the split have been truncated ({ROWS_MAX_NUMBER} rows)") - return rows_plus_one[:ROWS_MAX_NUMBER] + logger.debug(f"the rows in the split have been truncated ({rows_max_number} rows)") + return rows_plus_one[:rows_max_number] diff --git a/src/datasets_preview_backend/models/split.py b/src/datasets_preview_backend/models/split.py index 840ffb3474..237c7a9660 100644 --- a/src/datasets_preview_backend/models/split.py +++ b/src/datasets_preview_backend/models/split.py @@ -1,29 +1,180 @@ import logging -from typing import List, Optional, TypedDict +import sys +from typing import Any, Dict, List, Optional, TypedDict +from datasets_preview_backend.constants import DEFAULT_MIN_CELL_BYTES from datasets_preview_backend.models._guard import guard_blocked_datasets -from datasets_preview_backend.models.column import Column +from datasets_preview_backend.models.column import Column, ColumnDict from datasets_preview_backend.models.info import get_info from datasets_preview_backend.models.row import Row from datasets_preview_backend.models.typed_row import get_typed_rows_and_columns +from datasets_preview_backend.utils import orjson_dumps logger = logging.getLogger(__name__) +class RowItem(TypedDict): + dataset: str + config: str + split: str + row_idx: int + row: Dict[str, Any] + truncated_cells: List[str] + + +class ColumnItem(TypedDict): + dataset: str + config: str + split: str + column_idx: int + column: ColumnDict + + +class RowsResponse(TypedDict): + columns: List[ColumnItem] + rows: List[RowItem] + + class Split(TypedDict): split_name: str - rows: List[Row] - columns: List[Column] + rows_response: RowsResponse num_bytes: Optional[int] num_examples: Optional[int] +def get_size_in_bytes(obj: Any): + return sys.getsizeof(orjson_dumps(obj)) + # ^^ every row is transformed here in a string, because it corresponds to + # the size the row will contribute in the JSON response to /rows endpoint. + # The size of the string is measured in bytes. 
+ # An alternative would have been to look at the memory consumption (pympler) but it's + # less related to what matters here (size of the JSON, number of characters in the + # dataset viewer table on the hub) + + +def truncate_cell(cell: Any, min_cell_bytes: int) -> str: + return orjson_dumps(cell)[:min_cell_bytes].decode("utf8", "ignore") + + +# Mutates row_item, and returns it anyway +def truncate_row_item(row_item: RowItem) -> RowItem: + min_cell_bytes = DEFAULT_MIN_CELL_BYTES + row = {} + for column_name, cell in row_item["row"].items(): + # for now: all the cells, but the smallest ones, are truncated + cell_bytes = get_size_in_bytes(cell) + if cell_bytes > min_cell_bytes: + row_item["truncated_cells"].append(column_name) + row[column_name] = truncate_cell(cell, min_cell_bytes) + else: + row[column_name] = cell + row_item["row"] = row + return row_item + + +# Mutates row_items, and returns them anyway +def truncate_row_items(row_items: List[RowItem], rows_max_bytes: int) -> List[RowItem]: + # compute the current size + rows_bytes = sum(get_size_in_bytes(row_item) for row_item in row_items) + + # Loop backwards, so that the last rows are truncated first + for row_item in reversed(row_items): + if rows_bytes < rows_max_bytes: + break + previous_size = get_size_in_bytes(row_item) + row_item = truncate_row_item(row_item) + new_size = get_size_in_bytes(row_item) + rows_bytes += new_size - previous_size + row_idx = row_item["row_idx"] + logger.debug(f"the size of the rows is now ({rows_bytes}) after truncating row idx={row_idx}") + return row_items + + +def to_row_item(dataset_name: str, config_name: str, split_name: str, row_idx: int, row: Row) -> RowItem: + return { + "dataset": dataset_name, + "config": config_name, + "split": split_name, + "row_idx": row_idx, + "row": row, + "truncated_cells": [], + } + + +def to_column_item( + dataset_name: str, config_name: str, split_name: str, column_idx: int, column: Column +) -> ColumnItem: + return { + "dataset": dataset_name, + "config": config_name, + "split": split_name, + "column_idx": column_idx, + "column": column.as_dict(), + } + + +def create_truncated_row_items( + dataset_name: str, + config_name: str, + split_name: str, + rows: List[Row], + rows_max_bytes: Optional[int] = None, + rows_min_number: Optional[int] = None, +) -> List[RowItem]: + row_items = [] + rows_bytes = 0 + if rows_min_number is None: + rows_min_number = 0 + else: + logger.debug(f"min number of rows in the response: '{rows_min_number}'") + if rows_max_bytes is not None: + logger.debug(f"max number of bytes in the response: '{rows_max_bytes}'") + + # two restrictions must be enforced: + # - at least rows_min_number rows + # - at most rows_max_bytes bytes + # To enforce this: + # 1. first get the first rows_min_number rows + for row_idx, row in enumerate(rows[:rows_min_number]): + row_item = to_row_item(dataset_name, config_name, split_name, row_idx, row) + if rows_max_bytes is not None: + rows_bytes += get_size_in_bytes(row_item) + row_items.append(row_item) + + # 2. if the total is over the bytes limit, truncate the values, iterating backwards starting + # from the last rows, until getting under the threshold + if rows_max_bytes is not None and rows_bytes >= rows_max_bytes: + logger.debug( + f"the size of the first {rows_min_number} rows ({rows_bytes}) is above the max number of bytes" + f" ({rows_max_bytes}), they will be truncated" + ) + return truncate_row_items(row_items, rows_max_bytes) + + # 3. 
else: add the remaining rows until the end, or until the bytes threshold + for idx, row in enumerate(rows[rows_min_number:]): + row_idx = rows_min_number + idx + row_item = to_row_item(dataset_name, config_name, split_name, row_idx, row) + if rows_max_bytes is not None: + rows_bytes += get_size_in_bytes(row_item) + if rows_bytes >= rows_max_bytes: + logger.debug( + f"the rows in the split have been truncated to {row_idx} row(s) to keep the size" + f" ({rows_bytes}) under the limit ({rows_max_bytes})" + ) + break + row_items.append(row_item) + return row_items + + def get_split( dataset_name: str, config_name: str, split_name: str, hf_token: Optional[str] = None, max_size_fallback: Optional[int] = None, + rows_max_bytes: Optional[int] = None, + rows_max_number: Optional[int] = None, + rows_min_number: Optional[int] = None, ) -> Split: logger.info(f"get split '{split_name}' for config '{config_name}' of dataset '{dataset_name}'") guard_blocked_datasets(dataset_name) @@ -31,7 +182,19 @@ def get_split( fallback = ( max_size_fallback is not None and info.size_in_bytes is not None and info.size_in_bytes < max_size_fallback ) - typed_rows, columns = get_typed_rows_and_columns(dataset_name, config_name, split_name, info, hf_token, fallback) + typed_rows, columns = get_typed_rows_and_columns( + dataset_name, config_name, split_name, info, hf_token, fallback, rows_max_number + ) + row_items = create_truncated_row_items( + dataset_name, config_name, split_name, typed_rows, rows_max_bytes, rows_min_number + ) + rows_response: RowsResponse = { + "columns": [ + to_column_item(dataset_name, config_name, split_name, column_idx, column) + for column_idx, column in enumerate(columns) + ], + "rows": row_items, + } try: if info.splits is None: raise Exception("no splits in info") @@ -42,8 +205,7 @@ def get_split( num_examples = None return { "split_name": split_name, - "rows": typed_rows, - "columns": columns, + "rows_response": rows_response, "num_bytes": num_bytes, "num_examples": num_examples, } diff --git a/src/datasets_preview_backend/models/typed_row.py b/src/datasets_preview_backend/models/typed_row.py index c0f0fae1aa..4dfe5efba0 100644 --- a/src/datasets_preview_backend/models/typed_row.py +++ b/src/datasets_preview_backend/models/typed_row.py @@ -36,13 +36,14 @@ def get_typed_rows_and_columns( info: DatasetInfo, hf_token: Optional[str] = None, fallback: Optional[bool] = False, + rows_max_number: Optional[int] = None, ) -> Tuple[List[Row], List[Column]]: try: try: - rows = get_rows(dataset_name, config_name, split_name, hf_token, streaming=True) + rows = get_rows(dataset_name, config_name, split_name, hf_token, True, rows_max_number) except Exception: if fallback: - rows = get_rows(dataset_name, config_name, split_name, hf_token, streaming=False) + rows = get_rows(dataset_name, config_name, split_name, hf_token, False, rows_max_number) else: raise except Exception as err: diff --git a/src/datasets_preview_backend/routes/rows.py b/src/datasets_preview_backend/routes/rows.py index 33180cac53..a7dfaf79bb 100644 --- a/src/datasets_preview_backend/routes/rows.py +++ b/src/datasets_preview_backend/routes/rows.py @@ -3,11 +3,7 @@ from starlette.requests import Request from starlette.responses import Response -from datasets_preview_backend.config import ( - MAX_AGE_LONG_SECONDS, - ROWS_MAX_BYTES, - ROWS_MIN_NUMBER, -) +from datasets_preview_backend.config import MAX_AGE_LONG_SECONDS from datasets_preview_backend.exceptions import ( Status400Error, Status500Error, @@ -34,9 +30,7 @@ async def 
rows_endpoint(request: Request) -> Response: or not isinstance(split_name, str) ): raise StatusError("Parameters 'dataset', 'config' and 'split' are required", 400) - rows_response, rows_error, status_code = get_rows_response( - dataset_name, config_name, split_name, ROWS_MAX_BYTES, ROWS_MIN_NUMBER - ) + rows_response, rows_error, status_code = get_rows_response(dataset_name, config_name, split_name) return get_response(rows_response or rows_error, status_code, MAX_AGE_LONG_SECONDS) except StatusError as err: if err.message == "The split does not exist." and is_dataset_in_queue(dataset_name): diff --git a/src/datasets_preview_backend/worker.py b/src/datasets_preview_backend/worker.py index d7b401fda3..b8a59a5548 100644 --- a/src/datasets_preview_backend/worker.py +++ b/src/datasets_preview_backend/worker.py @@ -13,6 +13,9 @@ DEFAULT_MAX_LOAD_PCT, DEFAULT_MAX_MEMORY_PCT, DEFAULT_MAX_SIZE_FALLBACK, + DEFAULT_ROWS_MAX_BYTES, + DEFAULT_ROWS_MAX_NUMBER, + DEFAULT_ROWS_MIN_NUMBER, DEFAULT_WORKER_QUEUE, DEFAULT_WORKER_SLEEP_SECONDS, ) @@ -46,7 +49,11 @@ worker_sleep_seconds = get_int_value(os.environ, "WORKER_SLEEP_SECONDS", DEFAULT_WORKER_SLEEP_SECONDS) hf_token = get_str_or_none_value(d=os.environ, key="HF_TOKEN", default=DEFAULT_HF_TOKEN) max_size_fallback = get_int_value(os.environ, "MAX_SIZE_FALLBACK", DEFAULT_MAX_SIZE_FALLBACK) +rows_max_bytes = get_int_value(os.environ, "ROWS_MAX_BYTES", DEFAULT_ROWS_MAX_BYTES) +rows_max_number = get_int_value(os.environ, "ROWS_MAX_NUMBER", DEFAULT_ROWS_MAX_NUMBER) +rows_min_number = get_int_value(os.environ, "ROWS_MIN_NUMBER", DEFAULT_ROWS_MIN_NUMBER) worker_queue = get_str_value(os.environ, "WORKER_QUEUE", DEFAULT_WORKER_QUEUE) + # Ensure datasets library uses the expected revision for canonical datasets os.environ["HF_SCRIPTS_VERSION"] = get_str_value( d=os.environ, key="DATASETS_REVISION", default=DEFAULT_DATASETS_REVISION @@ -105,6 +112,9 @@ def process_next_split_job() -> bool: split_name=split_name, hf_token=hf_token, max_size_fallback=max_size_fallback, + rows_max_bytes=rows_max_bytes, + rows_max_number=rows_max_number, + rows_min_number=rows_min_number, ) success = True except Status400Error: diff --git a/tests/io/test_cache.py b/tests/io/test_cache.py index a47fee58c8..a84ac3108d 100644 --- a/tests/io/test_cache.py +++ b/tests/io/test_cache.py @@ -16,7 +16,7 @@ upsert_split, ) from datasets_preview_backend.models.dataset import get_dataset_split_full_names -from datasets_preview_backend.models.split import Split +from datasets_preview_backend.models.split import RowItem, Split @pytest.fixture(autouse=True, scope="module") @@ -113,18 +113,24 @@ def test_big_row() -> None: dataset_name = "test_dataset" config_name = "test_config" split_name = "test_split" - big_row = {"col": "a" * 100_000_000} + big_row: RowItem = { + "dataset": dataset_name, + "config": config_name, + "split": split_name, + "row_idx": 0, + "row": {"col": "a" * 100_000_000}, + "truncated_cells": [], + } split: Split = { "split_name": split_name, - "rows": [big_row], - "columns": [], + "rows_response": {"rows": [big_row], "columns": []}, "num_bytes": None, "num_examples": None, } upsert_split(dataset_name, config_name, split_name, split) rows_response, error, status_code = get_rows_response(dataset_name, config_name, split_name) - assert status_code == 200 - assert error is None - assert rows_response is not None - assert rows_response["rows"][0]["row"]["col"] == "" - assert rows_response["rows"][0]["truncated_cells"] == ["col"] + assert status_code == 500 + assert error is not None 
+ assert rows_response is None + assert error["message"] == "could not store the rows/ cache entry." + assert error["cause_exception"] == "DocumentTooLarge" diff --git a/tests/models/test_row.py b/tests/models/test_row.py index 3b6e5736f9..0f885a319f 100644 --- a/tests/models/test_row.py +++ b/tests/models/test_row.py @@ -6,72 +6,71 @@ # get_rows def test_get_rows() -> None: - rows = get_rows("acronym_identification", "default", "train") + rows = get_rows("acronym_identification", "default", "train", rows_max_number=ROWS_MAX_NUMBER) assert len(rows) == ROWS_MAX_NUMBER assert rows[0]["tokens"][0] == "What" def test_class_label() -> None: - rows = get_rows("glue", "cola", "train") + rows = get_rows("glue", "cola", "train", rows_max_number=ROWS_MAX_NUMBER) assert rows[0]["label"] == 1 def test_mnist() -> None: - rows = get_rows("mnist", "mnist", "train") + rows = get_rows("mnist", "mnist", "train", rows_max_number=ROWS_MAX_NUMBER) assert len(rows) == ROWS_MAX_NUMBER assert isinstance(rows[0]["image"], Image.Image) def test_cifar() -> None: - rows = get_rows("cifar10", "plain_text", "train") + rows = get_rows("cifar10", "plain_text", "train", rows_max_number=ROWS_MAX_NUMBER) assert len(rows) == ROWS_MAX_NUMBER assert isinstance(rows[0]["img"], Image.Image) def test_iter_archive() -> None: - rows = get_rows("food101", "default", "train") + rows = get_rows("food101", "default", "train", rows_max_number=ROWS_MAX_NUMBER) assert len(rows) == ROWS_MAX_NUMBER assert isinstance(rows[0]["image"], Image.Image) def test_dl_1_suffix() -> None: # see https://github.com/huggingface/datasets/pull/2843 - rows = get_rows("discovery", "discovery", "train") + rows = get_rows("discovery", "discovery", "train", rows_max_number=ROWS_MAX_NUMBER) assert len(rows) == ROWS_MAX_NUMBER def test_txt_zip() -> None: # see https://github.com/huggingface/datasets/pull/2856 - rows = get_rows("bianet", "en_to_ku", "train") + rows = get_rows("bianet", "en_to_ku", "train", rows_max_number=ROWS_MAX_NUMBER) assert len(rows) == ROWS_MAX_NUMBER def test_pathlib() -> None: # see https://github.com/huggingface/datasets/issues/2866 - rows = get_rows("counter", "counter", "train") + rows = get_rows("counter", "counter", "train", rows_max_number=ROWS_MAX_NUMBER) assert len(rows) == ROWS_MAX_NUMBER def test_community_with_no_config() -> None: - rows = get_rows("Check/region_1", "Check--region_1", "train") + rows = get_rows("Check/region_1", "Check--region_1", "train", rows_max_number=ROWS_MAX_NUMBER) # it's not correct: here this is the number of splits, not the number of rows assert len(rows) == 2 # see https://github.com/huggingface/datasets-preview-backend/issues/78 - get_rows("Check/region_1", "Check--region_1", "train") def test_audio_dataset() -> None: - rows = get_rows("abidlabs/test-audio-1", "test", "train") + rows = get_rows("abidlabs/test-audio-1", "test", "train", rows_max_number=ROWS_MAX_NUMBER) assert len(rows) == 1 assert rows[0]["Output"]["sampling_rate"] == 48000 def test_libsndfile() -> None: # see https://github.com/huggingface/datasets-preview-backend/issues/194 - rows = get_rows("polinaeterna/ml_spoken_words", "ar_opus", "train") + rows = get_rows("polinaeterna/ml_spoken_words", "ar_opus", "train", rows_max_number=ROWS_MAX_NUMBER) assert len(rows) == ROWS_MAX_NUMBER assert rows[0]["audio"]["sampling_rate"] == 48000 - rows = get_rows("polinaeterna/ml_spoken_words", "ar_wav", "train") + rows = get_rows("polinaeterna/ml_spoken_words", "ar_wav", "train", rows_max_number=ROWS_MAX_NUMBER) assert len(rows) == ROWS_MAX_NUMBER 
assert rows[0]["audio"]["sampling_rate"] == 16000 diff --git a/tests/models/test_split.py b/tests/models/test_split.py index 7f0c5f723c..0a0cee9c87 100644 --- a/tests/models/test_split.py +++ b/tests/models/test_split.py @@ -18,7 +18,10 @@ def test_gated() -> None: dataset_name = "severo/dummy_gated" config_name = "severo--embellishments" split_name = "train" - split = get_split(dataset_name, config_name, split_name, HF_TOKEN) + split = get_split(dataset_name, config_name, split_name, HF_TOKEN, rows_max_number=ROWS_MAX_NUMBER) - assert len(split["rows"]) == ROWS_MAX_NUMBER - assert split["rows"][0]["year"] == "1855" + assert len(split["rows_response"]["rows"]) == ROWS_MAX_NUMBER + assert split["rows_response"]["rows"][0]["row"]["year"] == "1855" + + +# TODO: test the truncation diff --git a/tests/models/test_typed_row.py b/tests/models/test_typed_row.py index d1ce733f3a..f23e25d122 100644 --- a/tests/models/test_typed_row.py +++ b/tests/models/test_typed_row.py @@ -4,16 +4,19 @@ from datasets_preview_backend.models.typed_row import get_typed_rows_and_columns +# TODO: this is slow: change the tested dataset? def test_detect_types_from_typed_rows() -> None: info = get_info("allenai/c4", "allenai--c4") - typed_rows, columns = get_typed_rows_and_columns("allenai/c4", "allenai--c4", "train", info) + typed_rows, columns = get_typed_rows_and_columns( + "allenai/c4", "allenai--c4", "train", info, rows_max_number=ROWS_MAX_NUMBER + ) assert len(typed_rows) == ROWS_MAX_NUMBER assert columns[0].type == ColumnType.STRING def test_class_label() -> None: info = get_info("glue", "cola") - typed_rows, columns = get_typed_rows_and_columns("glue", "cola", "train", info) + typed_rows, columns = get_typed_rows_and_columns("glue", "cola", "train", info, rows_max_number=ROWS_MAX_NUMBER) column = columns[1] assert isinstance(column, ClassLabelColumn) assert column.type == ColumnType.CLASS_LABEL @@ -23,7 +26,7 @@ def test_class_label() -> None: def test_mnist() -> None: info = get_info("mnist", "mnist") - typed_rows, columns = get_typed_rows_and_columns("mnist", "mnist", "train", info) + typed_rows, columns = get_typed_rows_and_columns("mnist", "mnist", "train", info, rows_max_number=ROWS_MAX_NUMBER) assert len(typed_rows) == ROWS_MAX_NUMBER assert typed_rows[0]["image"] == "assets/mnist/--/mnist/train/0/image/image.jpg" assert columns[0].type == ColumnType.RELATIVE_IMAGE_URL @@ -31,7 +34,9 @@ def test_mnist() -> None: def test_cifar() -> None: info = get_info("cifar10", "plain_text") - typed_rows, columns = get_typed_rows_and_columns("cifar10", "plain_text", "train", info) + typed_rows, columns = get_typed_rows_and_columns( + "cifar10", "plain_text", "train", info, rows_max_number=ROWS_MAX_NUMBER + ) assert len(typed_rows) == ROWS_MAX_NUMBER assert typed_rows[0]["img"] == "assets/cifar10/--/plain_text/train/0/img/image.jpg" assert columns[0].type == ColumnType.RELATIVE_IMAGE_URL @@ -39,7 +44,7 @@ def test_cifar() -> None: def test_head_qa() -> None: info = get_info("head_qa", "es") - typed_rows, columns = get_typed_rows_and_columns("head_qa", "es", "train", info) + typed_rows, columns = get_typed_rows_and_columns("head_qa", "es", "train", info, rows_max_number=ROWS_MAX_NUMBER) assert len(typed_rows) == ROWS_MAX_NUMBER assert typed_rows[0]["image"] is None assert columns[6].name == "image" @@ -48,21 +53,27 @@ def test_head_qa() -> None: def test_iter_archive() -> None: info = get_info("food101", "default") - typed_rows, columns = get_typed_rows_and_columns("food101", "default", "train", info) + typed_rows, columns 
= get_typed_rows_and_columns( + "food101", "default", "train", info, rows_max_number=ROWS_MAX_NUMBER + ) assert len(typed_rows) == ROWS_MAX_NUMBER assert columns[0].type == ColumnType.RELATIVE_IMAGE_URL def test_image_url() -> None: info = get_info("severo/wit", "default") - typed_rows, columns = get_typed_rows_and_columns("severo/wit", "default", "train", info) + typed_rows, columns = get_typed_rows_and_columns( + "severo/wit", "default", "train", info, rows_max_number=ROWS_MAX_NUMBER + ) assert len(typed_rows) == ROWS_MAX_NUMBER assert columns[2].type == ColumnType.IMAGE_URL def test_audio_dataset() -> None: info = get_info("abidlabs/test-audio-1", "test") - typed_rows, columns = get_typed_rows_and_columns("abidlabs/test-audio-1", "test", "train", info) + typed_rows, columns = get_typed_rows_and_columns( + "abidlabs/test-audio-1", "test", "train", info, rows_max_number=ROWS_MAX_NUMBER + ) assert len(typed_rows) == 1 assert columns[1].type == ColumnType.AUDIO_RELATIVE_SOURCES assert len(typed_rows[0]["Output"]) == 2 diff --git a/tests/test_app.py b/tests/test_app.py index 1e7e0a7a98..720ac3a8ed 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -2,7 +2,13 @@ from starlette.testclient import TestClient from datasets_preview_backend.app import create_app -from datasets_preview_backend.config import MONGO_CACHE_DATABASE, MONGO_QUEUE_DATABASE +from datasets_preview_backend.config import ( + MONGO_CACHE_DATABASE, + MONGO_QUEUE_DATABASE, + ROWS_MAX_BYTES, + ROWS_MAX_NUMBER, + ROWS_MIN_NUMBER, +) from datasets_preview_backend.exceptions import Status400Error from datasets_preview_backend.io.cache import clean_database as clean_cache_database from datasets_preview_backend.io.cache import ( @@ -85,7 +91,14 @@ def test_get_is_valid(client: TestClient) -> None: dataset = "acronym_identification" split_full_names = refresh_dataset_split_full_names(dataset) for split_full_name in split_full_names: - refresh_split(split_full_name["dataset_name"], split_full_name["config_name"], split_full_name["split_name"]) + refresh_split( + split_full_name["dataset_name"], + split_full_name["config_name"], + split_full_name["split_name"], + rows_max_bytes=ROWS_MAX_BYTES, + rows_max_number=ROWS_MAX_NUMBER, + rows_min_number=ROWS_MIN_NUMBER, + ) response = client.get("/is-valid", params={"dataset": "acronym_identification"}) assert response.status_code == 200 json = response.json() @@ -154,7 +167,14 @@ def test_get_rows(client: TestClient) -> None: dataset = "acronym_identification" config = "default" split = "train" - refresh_split(dataset, config, split) + refresh_split( + dataset, + config, + split, + rows_max_bytes=ROWS_MAX_BYTES, + rows_max_number=ROWS_MAX_NUMBER, + rows_min_number=ROWS_MIN_NUMBER, + ) response = client.get("/rows", params={"dataset": dataset, "config": config, "split": split}) assert response.status_code == 200 json = response.json() @@ -195,7 +215,14 @@ def test_datetime_content(client: TestClient) -> None: response = client.get("/rows", params={"dataset": dataset, "config": config, "split": split}) assert response.status_code == 400 - refresh_split(dataset, config, split) + refresh_split( + dataset, + config, + split, + rows_max_bytes=ROWS_MAX_BYTES, + rows_max_number=ROWS_MAX_NUMBER, + rows_min_number=ROWS_MIN_NUMBER, + ) response = client.get("/rows", params={"dataset": dataset, "config": config, "split": split}) assert response.status_code == 200 @@ -205,7 +232,14 @@ def test_bytes_limit(client: TestClient) -> None: dataset = "edbeeching/decision_transformer_gym_replay" config = 
"hopper-expert-v2" split = "train" - refresh_split(dataset, config, split) + refresh_split( + dataset, + config, + split, + rows_max_bytes=ROWS_MAX_BYTES, + rows_max_number=ROWS_MAX_NUMBER, + rows_min_number=ROWS_MIN_NUMBER, + ) response = client.get("/rows", params={"dataset": dataset, "config": config, "split": split}) assert response.status_code == 200 json = response.json() @@ -296,7 +330,14 @@ def test_error_messages(client: TestClient) -> None: # curl http://localhost:8000/rows\?dataset\=acronym_identification\&config\=default\&split\=train assert response.json()["message"] == "The split is being processed. Retry later." - refresh_split(dataset_name=dataset, config_name=config, split_name=split) + refresh_split( + dataset_name=dataset, + config_name=config, + split_name=split, + rows_max_bytes=ROWS_MAX_BYTES, + rows_max_number=ROWS_MAX_NUMBER, + rows_min_number=ROWS_MIN_NUMBER, + ) finish_split_job(job_id, success=True) # ^ equivalent to # WORKER_QUEUE=splits make worker