Add flag to allow or reject datasets containing unsafe types (i.e., Pickle) (#519)

* Add allow_unsafe_types.

* is_mds_encodings_safe.

* usefixtures.

* Fix lint.
knighton authored Dec 7, 2023
1 parent 4134ec7 commit 8b0f1df
Showing 7 changed files with 114 additions and 7 deletions.
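In short, the new allow_unsafe_types flag threads from StreamingDataset (and SimulationDataset) down to each shard reader, which rejects Pickle-encoded columns by default. A minimal sketch of the caller-facing behavior, mirroring the tests added below (the output path is hypothetical):

    from streaming import MDSWriter, StreamingDataset

    # 'pkl' columns are Pickle-encoded; unpickling can execute arbitrary code.
    with MDSWriter(out='/tmp/pkl_dataset', columns={'num': 'pkl'}) as out:
        for num in range(100):
            out.write({'num': num})

    # Default (allow_unsafe_types=False): loading raises
    # ValueError: Column num contains an unsafe type: pkl. ...
    # StreamingDataset(local='/tmp/pkl_dataset')

    # Opting in accepts the shard despite the unsafe encoding.
    dataset = StreamingDataset(local='/tmp/pkl_dataset', allow_unsafe_types=True)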
9 changes: 7 additions & 2 deletions simulation/core/sim_dataset.py
@@ -99,6 +99,9 @@ class SimulationDataset(StreamingDataset):
Defaults to ``1``.
batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
``per_stream``. Defaults to ``random``.
allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code
execution during deserialization, whether to keep going if ``True`` or raise an error
if ``False``. Defaults to ``False``.
"""

def __init__(self,
@@ -125,7 +128,8 @@ def __init__(self,
shuffle_block_size: Optional[int] = None,
sampling_method: str = 'balanced',
sampling_granularity: int = 1,
- batching_method: str = 'random') -> None:
+ batching_method: str = 'random',
+ allow_unsafe_types: bool = False) -> None:

# Time how long it takes for StreamingDataset instantiation
t0 = time.time()
@@ -146,6 +150,7 @@ def __init__(self,
self.sampling_granularity = sampling_granularity
self.batching_method = batching_method
self.num_canonical_nodes = num_canonical_nodes
self.allow_unsafe_types = allow_unsafe_types

self.initial_physical_nodes = nodes

@@ -265,7 +270,7 @@ def __init__(self,
local_foldernames = []
for stream_id, stream in enumerate(self.streams):
logger.info(f' Processing index file for stream {stream_id + 1}')
- stream_shards = stream.get_shards(self.world)
+ stream_shards = stream.get_shards(self.world, self.allow_unsafe_types)
num_stream_samples = sum(map(len, stream_shards))
index_filename = os.path.join(stream.local, stream.split, get_index_basename())
index_filenames.append(index_filename)
9 changes: 7 additions & 2 deletions streaming/base/dataset.py
@@ -303,6 +303,9 @@ class StreamingDataset(Array, IterableDataset):
``None``.
batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
``per_stream``. Defaults to ``random``.
allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code
execution during deserialization, whether to keep going if ``True`` or raise an error
if ``False``. Defaults to ``False``.
"""

def __init__(self,
@@ -327,7 +330,8 @@ def __init__(self,
shuffle_algo: str = 'py1e',
shuffle_seed: int = 9176,
shuffle_block_size: Optional[int] = None,
- batching_method: str = 'random') -> None:
+ batching_method: str = 'random',
+ allow_unsafe_types: bool = False) -> None:
# Global arguments (which do not live in Streams).
self.predownload = predownload
self.cache_limit = cache_limit
@@ -341,6 +345,7 @@ def __init__(self,
self.shuffle_seed = shuffle_seed
self.shuffle_block_size = shuffle_block_size
self.batching_method = batching_method
self.allow_unsafe_types = allow_unsafe_types

# Initialize initial_physical_nodes to None. If we are resuming, then we will set it to the
# number of physical nodes of the initial run in the _resume function.
@@ -452,7 +457,7 @@ def __init__(self,
self.sample_offset_per_stream = np.zeros(self.num_streams, np.int64)
self.samples_per_stream = np.zeros(self.num_streams, np.int64)
for stream_id, stream in enumerate(self.streams):
- stream_shards = stream.get_shards(world)
+ stream_shards = stream.get_shards(world, self.allow_unsafe_types)
num_stream_samples = sum(map(len, stream_shards))
if not num_stream_samples:
index_filename = os.path.join(stream.local, stream.split, get_index_basename())
10 changes: 10 additions & 0 deletions streaming/base/format/base/reader.py
@@ -69,6 +69,16 @@ def __init__(

self.file_pairs = []

def validate(self, allow_unsafe_types: bool) -> None:
    """Check whether this shard is acceptable to be part of some Stream.

    Args:
        allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code
            execution during deserialization, whether to keep going if ``True`` or raise an
            error if ``False``.
    """
    pass

@property
def size(self):
"""Get the number of samples in this shard.
17 changes: 16 additions & 1 deletion streaming/base/format/mds/encodings.py
@@ -17,7 +17,8 @@
from typing_extensions import Self

__all__ = [
- 'get_mds_encoded_size', 'get_mds_encodings', 'is_mds_encoding', 'mds_decode', 'mds_encode'
+ 'get_mds_encoded_size', 'get_mds_encodings', 'is_mds_encoding', 'mds_decode', 'mds_encode',
+ 'is_mds_encoding_safe'
]


@@ -543,6 +544,8 @@ def _is_valid(self, original: Any, converted: Any) -> None:
'json': JSON,
}

_unsafe_encodings = {'pkl'}


def get_mds_encodings() -> Set[str]:
"""List supported encodings.
@@ -586,6 +589,18 @@ def is_mds_encoding(encoding: str) -> bool:
return coder is not None


def is_mds_encoding_safe(encoding: str) -> bool:
    """Get whether the given encoding is safe (does not allow arbitrary code execution).

    Args:
        encoding (str): Encoding.

    Returns:
        bool: Whether the encoding is safe.
    """
    return encoding not in _unsafe_encodings


def mds_encode(encoding: str, obj: Any) -> bytes:
"""Encode the given data from the original object to bytes.
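As a quick illustration of the new predicate: it is a denylist over _unsafe_encodings, so it says nothing about whether an encoding is valid, only whether it is known-unsafe. A sketch of the expected behavior:

    from streaming.base.format.mds.encodings import is_mds_encoding, is_mds_encoding_safe

    assert is_mds_encoding('pkl')           # 'pkl' is a supported MDS encoding...
    assert not is_mds_encoding_safe('pkl')  # ...but it is flagged as unsafe.
    assert is_mds_encoding_safe('int')      # primitive encodings remain safe.
    assert is_mds_encoding_safe('bogus')    # note: unknown encodings also test as 'safe'.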
17 changes: 16 additions & 1 deletion streaming/base/format/mds/reader.py
@@ -11,7 +11,7 @@
from typing_extensions import Self

from streaming.base.format.base.reader import FileInfo, JointReader
- from streaming.base.format.mds.encodings import mds_decode
+ from streaming.base.format.mds.encodings import is_mds_encoding_safe, mds_decode

__all__ = ['MDSReader']

@@ -84,6 +84,21 @@ def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> Self:
args[key] = FileInfo(**arg) if arg else None
return cls(**args)

def validate(self, allow_unsafe_types: bool) -> None:
    """Check whether this shard is acceptable to be part of some Stream.

    Args:
        allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code
            execution during deserialization, whether to keep going if ``True`` or raise an
            error if ``False``.
    """
    if not allow_unsafe_types:
        for column_id, encoding in enumerate(self.column_encodings):
            if not is_mds_encoding_safe(encoding):
                name = self.column_names[column_id]
                raise ValueError(f'Column {name} contains an unsafe type: {encoding}. To ' +
                                 f'proceed anyway, set ``allow_unsafe_types=True``.')

def decode_sample(self, data: bytes) -> Dict[str, Any]:
"""Decode a sample dict from bytes.
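The MDS override fails fast at index-load time, before any sample is decoded. A sketch of its behavior, assuming shard is an MDSReader whose schema includes a 'pkl' column (hypothetical instance):

    shard.validate(allow_unsafe_types=True)   # returns None: caller opted in.
    shard.validate(allow_unsafe_types=False)  # ValueError: Column num contains an unsafe type: pkl. ...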
6 changes: 5 additions & 1 deletion streaming/base/stream.py
@@ -421,11 +421,14 @@ def prepare_shard(self, shard: Reader) -> int:
delta += self._prepare_shard_part(raw_info, zip_info, shard.compression)
return delta

- def get_shards(self, world: World) -> List[Reader]:
+ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]:
      """Load this Stream's index, retrieving its shard readers.

      Args:
          world (World): Distributed context.
          allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code
              execution during deserialization, whether to keep going if ``True`` or raise an
              error.

      Returns:
          List[Reader]: Shard readers.
@@ -469,6 +472,7 @@ def get_shards(self, world: World) -> List[Reader]:
shards = []
for info in obj['shards']:
shard = reader_from_json(self.local, self.split, info)
shard.validate(allow_unsafe_types)
shards.append(shard)

return shards
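Taken together, validation happens exactly once per shard, while the stream's index is loaded. A condensed, non-literal sketch of the call chain the flag travels:

    # StreamingDataset.__init__(..., allow_unsafe_types=False)
    #   -> Stream.get_shards(world, allow_unsafe_types)   # per stream, while loading its index
    #        -> reader_from_json(...)                     # build each shard's Reader
    #        -> shard.validate(allow_unsafe_types)        # reject unsafe encodings up front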
53 changes: 53 additions & 0 deletions tests/test_unsafe_types.py
@@ -0,0 +1,53 @@
# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

from typing import Tuple

import pytest

from streaming import MDSWriter, StreamingDataset


@pytest.mark.usefixtures('local_remote_dir')
def test_do_allow_unsafe_types_safe(local_remote_dir: Tuple[str, str]):
    local, _ = local_remote_dir
    columns = {'num': 'int'}
    with MDSWriter(out=local, columns=columns) as out:
        for num in range(100):
            sample = {'num': num}
            out.write(sample)
    StreamingDataset(local=local, allow_unsafe_types=True)


@pytest.mark.usefixtures('local_remote_dir')
def test_do_allow_unsafe_types_unsafe(local_remote_dir: Tuple[str, str]):
    local, _ = local_remote_dir
    columns = {'num': 'pkl'}
    with MDSWriter(out=local, columns=columns) as out:
        for num in range(100):
            sample = {'num': num}
            out.write(sample)
    StreamingDataset(local=local, allow_unsafe_types=True)


@pytest.mark.usefixtures('local_remote_dir')
def test_dont_allow_unsafe_types_safe(local_remote_dir: Tuple[str, str]):
    local, _ = local_remote_dir
    columns = {'num': 'int'}
    with MDSWriter(out=local, columns=columns) as out:
        for num in range(100):
            sample = {'num': num}
            out.write(sample)
    StreamingDataset(local=local)


@pytest.mark.usefixtures('local_remote_dir')
def test_dont_allow_unsafe_types_unsafe(local_remote_dir: Tuple[str, str]):
    local, _ = local_remote_dir
    columns = {'num': 'pkl'}
    with MDSWriter(out=local, columns=columns) as out:
        for num in range(100):
            sample = {'num': num}
            out.write(sample)
    with pytest.raises(ValueError, match='.*contains an unsafe type.*'):
        StreamingDataset(local=local)
