Skip to content

Commit

Permalink
feat: 🎸 revert double limit on the rows size (reverts #162) (#179)
Browse files Browse the repository at this point in the history
  • Loading branch information
severo authored Mar 14, 2022
1 parent f406c0d commit 155843f
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 38 deletions.
31 changes: 2 additions & 29 deletions src/datasets_preview_backend/models/typed_row.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import logging
import sys
from typing import List, Optional, Tuple

from datasets import DatasetInfo

from datasets_preview_backend.config import ROWS_MAX_BYTES
from datasets_preview_backend.exceptions import Status400Error
from datasets_preview_backend.models.column import Column, get_columns
from datasets_preview_backend.models.row import Row, get_rows
from datasets_preview_backend.utils import orjson_dumps

logger = logging.getLogger(__name__)

Expand All @@ -22,38 +19,14 @@ def get_typed_row(
}


def get_size_in_bytes(row: Row):
return sys.getsizeof(orjson_dumps(row))
# ^^ every row is transformed here in a string, because it corresponds to
# the size the row will contribute in the JSON response to /rows endpoint.
# The size of the string is measured in bytes.
# An alternative would have been to look at the memory consumption (pympler) but it's
# less related to what matters here (size of the JSON, number of characters in the
# dataset viewer table on the hub)


def get_typed_rows(
dataset_name: str,
config_name: str,
split_name: str,
rows: List[Row],
columns: List[Column],
rows_max_bytes: Optional[int] = None,
) -> List[Row]:
typed_rows = []
bytes = 0
for idx, row in enumerate(rows):
typed_row = get_typed_row(dataset_name, config_name, split_name, row, idx, columns)
if rows_max_bytes is not None:
bytes += get_size_in_bytes(typed_row)
if bytes >= rows_max_bytes:
logger.debug(
f"the rows in the split have been truncated to {idx} row(s) to keep the size ({bytes}) under the"
f" limit ({rows_max_bytes})"
)
break
typed_rows.append(typed_row)
return typed_rows
return [get_typed_row(dataset_name, config_name, split_name, row, idx, columns) for idx, row in enumerate(rows)]


def get_typed_rows_and_columns(
Expand All @@ -76,5 +49,5 @@ def get_typed_rows_and_columns(
raise Status400Error("Cannot get the first rows for the split.", err) from err

columns = get_columns(info, rows)
typed_rows = get_typed_rows(dataset_name, config_name, split_name, rows, columns, ROWS_MAX_BYTES)
typed_rows = get_typed_rows(dataset_name, config_name, split_name, rows, columns)
return typed_rows, columns
9 changes: 0 additions & 9 deletions tests/models/test_typed_row.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,3 @@ def test_image_url() -> None:
# assert typed_rows[0]["audio"][0]["type"] == "audio/mpeg"
# assert typed_rows[0]["audio"][1]["type"] == "audio/wav"
# assert typed_rows[0]["audio"][0]["src"] == "assets/common_voice/--/tr/train/0/audio/audio.mp3"


def test_bytes_limit() -> None:
dataset = "edbeeching/decision_transformer_gym_replay"
config = "hopper-expert-v2"
split = "train"
info = get_info(dataset, config)
typed_rows, columns = get_typed_rows_and_columns(dataset, config, split, info)
assert len(typed_rows) == 3

0 comments on commit 155843f

Please sign in to comment.