
Commit

address ayush's comments
Signed-off-by: Sarah Yurick <[email protected]>
sarahyurick committed Nov 6, 2024
1 parent 645655d commit e31a3f1
Showing 3 changed files with 112 additions and 27 deletions.
40 changes: 38 additions & 2 deletions nemo_curator/datasets/doc_dataset.py
@@ -50,6 +50,19 @@ def read_json(
columns: Optional[List[str]] = None,
**kwargs,
):
"""
Read JSON or JSONL file(s).
Args:
input_files: The path of the input file(s).
backend: The backend to use for reading the data.
files_per_partition: The number of files to read per partition.
add_filename: Whether to add a "filename" column to the DataFrame.
input_meta: A dictionary or a string formatted as a dictionary, which outlines
the field names and their respective data types within the JSONL input file.
columns: If not None, only these columns will be read from the file.
"""
return cls(
_read_json_or_parquet(
input_files=input_files,
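A minimal usage sketch of the read_json reader documented in this hunk. The file path, the field names in input_meta, and the column list are hypothetical; DocumentDataset is assumed to be importable from nemo_curator.datasets as in the rest of the repository, and the .df attribute is an assumption based on how the dataset is used elsewhere in NeMo Curator.

from nemo_curator.datasets import DocumentDataset

# Read a hypothetical JSONL file on the pandas backend, declaring dtypes up
# front via input_meta and keeping only the two columns we need.
books = DocumentDataset.read_json(
    "data/books_00.jsonl",                      # hypothetical path
    backend="pandas",
    files_per_partition=1,
    add_filename=True,                          # adds a "filename" column
    input_meta={"id": "str", "text": "str"},    # hypothetical schema
    columns=["id", "text"],
)
print(books.df.head())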
@@ -73,6 +86,18 @@ def read_parquet(
columns: Optional[List[str]] = None,
**kwargs,
):
"""
Read Parquet file(s).
Args:
input_files: The path of the input file(s).
backend: The backend to use for reading the data.
files_per_partition: The number of files to read per partition.
add_filename: Whether to add a "filename" column to the DataFrame.
columns: If not None, only these columns will be read from the file.
There is a significant performance gain when specifying columns for Parquet files.
"""
return cls(
_read_json_or_parquet(
input_files=input_files,
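A similar sketch for read_parquet, mainly to illustrate the column-pruning note in the docstring above; the directory and column names are hypothetical.

from nemo_curator.datasets import DocumentDataset

# Listing only the needed columns lets the Parquet reader skip the rest of the
# file, which is the performance gain the docstring mentions.
crawl = DocumentDataset.read_parquet(
    "data/crawl_parquet/",          # hypothetical directory of .parquet shards
    backend="cudf",
    columns=["id", "text"],
)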
@@ -95,6 +120,17 @@ def read_pickle(
columns: Optional[List[str]] = None,
**kwargs,
):
"""
Read Pickle file(s).
Args:
input_files: The path of the input file(s).
backend: The backend to use for reading the data.
files_per_partition: The number of files to read per partition.
add_filename: Whether to add a "filename" column to the DataFrame.
columns: If not None, only these columns will be read from the file.
"""
return cls(
read_data(
input_files=input_files,
@@ -114,7 +150,7 @@ def to_json(
keep_filename_column: bool = False,
):
"""
See nemo_curator.utils.distributed_utils.write_to_disk docstring for other parameters.
See nemo_curator.utils.distributed_utils.write_to_disk docstring for parameters.
"""
write_to_disk(
@@ -132,7 +168,7 @@ def to_parquet(
keep_filename_column: bool = False,
):
"""
See nemo_curator.utils.distributed_utils.write_to_disk docstring for other parameters.
See nemo_curator.utils.distributed_utils.write_to_disk docstring for parameters.
"""
write_to_disk(
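A hedged sketch of the write path that the two docstring tweaks above point at. Only keep_filename_column appears in this diff; the output directory argument and write_to_filename are assumptions mirrored from the write_to_disk helper these methods delegate to.

from nemo_curator.datasets import DocumentDataset

dataset = DocumentDataset.read_json("data/books_00.jsonl", add_filename=True)

# Write one output file per original input file, then drop the helper column
# from the records that are written out.
dataset.to_json(
    "output/books_jsonl/",        # hypothetical output directory
    write_to_filename=True,       # assumed parameter, mirrored from write_to_disk
    keep_filename_column=False,
)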
58 changes: 40 additions & 18 deletions nemo_curator/utils/distributed_utils.py
@@ -57,21 +57,23 @@ def _enable_spilling():


def start_dask_gpu_local_cluster(
nvlink_only=False,
protocol="tcp",
rmm_pool_size="1024M",
enable_spilling=True,
set_torch_to_use_rmm=True,
rmm_async=True,
rmm_maximum_pool_size=None,
rmm_managed_memory=False,
rmm_release_threshold=None,
nvlink_only: bool = False,
protocol: str = "tcp",
rmm_pool_size: Optional[Union[int, str]] = "1024M",
enable_spilling: bool = True,
set_torch_to_use_rmm: bool = True,
rmm_async: bool = True,
rmm_maximum_pool_size: Optional[Union[int, str]] = None,
rmm_managed_memory: bool = False,
rmm_release_threshold: Optional[Union[int, str]] = None,
**cluster_kwargs,
) -> Client:
"""
This function sets up a Dask cluster across all the
GPUs present on the machine.
See get_client function for parameters.
"""
extra_kwargs = (
{
@@ -111,12 +113,16 @@ def start_dask_gpu_local_cluster(


def start_dask_cpu_local_cluster(
n_workers=os.cpu_count(), threads_per_worker=1, **cluster_kwargs
n_workers: Optional[int] = os.cpu_count(),
threads_per_worker: int = 1,
**cluster_kwargs,
) -> Client:
"""
This function sets up a Dask cluster across all the
CPUs present on the machine.
See get_client function for parameters.
"""
cluster = LocalCluster(
n_workers=n_workers,
@@ -130,7 +136,7 @@ def start_dask_cpu_local_cluster(


def get_client(
cluster_type="cpu",
cluster_type: str = "cpu",
scheduler_address=None,
scheduler_file=None,
n_workers=os.cpu_count(),
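A small usage sketch of get_client, whose cluster_type parameter gains a type hint in this hunk; the keyword values are illustrative only.

from nemo_curator.utils.distributed_utils import get_client

# Local CPU cluster; cluster_type="gpu" instead routes through
# start_dask_gpu_local_cluster and its RMM options shown above.
client = get_client(cluster_type="cpu", n_workers=4)
print(client.dashboard_link)
client.close()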
@@ -262,10 +268,10 @@ def _set_torch_to_use_rmm():


def read_single_partition(
files,
backend="cudf",
filetype="jsonl",
add_filename=False,
files: List[str],
backend: str = "cudf",
filetype: str = "jsonl",
add_filename: bool = False,
input_meta: Union[str, dict] = None,
columns: Optional[List[str]] = None,
**kwargs,
@@ -364,6 +370,7 @@ def read_pandas_pickle(
Args:
file: The path to the pickle file to read.
add_filename: Whether to add a "filename" column to the DataFrame.
columns: If not None, only these columns will be read from the file.
Returns:
A Pandas DataFrame.
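A sketch of the read_pandas_pickle call whose docstring gains the columns entry above; the pickle path and column names are hypothetical.

from nemo_curator.utils.distributed_utils import read_pandas_pickle

# Load only two columns from a (hypothetical) pickled DataFrame.
df = read_pandas_pickle(
    "data/dataset.pkl",
    add_filename=False,
    columns=["id", "text"],
)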
@@ -688,7 +695,7 @@ def load_object_on_worker(attr, load_object_function, load_object_kwargs):
return obj


def offload_object_on_worker(attr):
def offload_object_on_worker(attr: str):
"""
This function deletes an existing attribute from a Dask worker.
@@ -725,7 +732,7 @@ def get_current_client():
return None


def performance_report_if(path=None, report_name="dask-profile.html"):
def performance_report_if(
path: Optional[str] = None, report_name: str = "dask-profile.html"
):
"""
Generates a performance report if a valid path is provided, or returns a
no-op context manager if not.
Args:
path: The directory path where the performance report should be saved.
If None, no report is generated.
report_name: The name of the report file.
"""
if path is not None:
return performance_report(os.path.join(path, report_name))
else:
@@ -735,7 +754,10 @@ def performance_report_if(path=None, report_name="dask-profile.html"):
def performance_report_if_with_ts_suffix(
path: Optional[str] = None, report_name: str = "dask-profile"
):
"""Suffixes the report_name with the timestamp"""
"""
Same as performance_report_if, except it suffixes the report_name with the timestamp.
"""
return performance_report_if(
path=path,
report_name=f"{report_name}-{datetime.now().strftime('%Y%m%d_%H%M%S')}.html",
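A usage sketch for the two report helpers above; the profile directory and report name are hypothetical, and the no-op behaviour when path is None is taken from the docstring added in this commit.

from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix

# With a real directory this writes e.g. dedup-profile-20241106_093000.html;
# with profile_dir=None the returned context manager does nothing.
profile_dir = "./profiles"          # hypothetical
with performance_report_if_with_ts_suffix(profile_dir, "dedup-profile"):
    pass  # run the Dask work to be profiled here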
41 changes: 34 additions & 7 deletions nemo_curator/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,18 @@ def expand_outdir_and_mkdir(outdir):


def filter_files_by_extension(
files_list: str,
files_list: List[str],
filter_by: Union[str, List[str]],
):
"""
Given a list of files, filter it to only include files matching the given extension(s).
Args:
files_list: List of files.
filter_by: A string (e.g., "json") or a list of strings (e.g., ["json", "parquet"])
representing which file types to keep from files_list.
"""
filtered_files = []

if isinstance(filter_by, str):
@@ -61,7 +70,12 @@ def filter_files_by_extension(
if file.endswith(tuple(file_extensions)):
filtered_files.append(file)
else:
warnings.warn(f"Skipping read for file: {file}")
warning_flag = True

if warning_flag:
warnings.warn(
f"Skipped at least one file due to unmatched file extension(s)."
)

return filtered_files
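A quick illustration of filter_files_by_extension with made-up file names, matching the docstring added in this hunk.

from nemo_curator.utils.file_utils import filter_files_by_extension

files = ["shard_00.jsonl", "shard_01.parquet", "README.md"]   # hypothetical
# The .md file is skipped and triggers the warning emitted above.
kept = filter_files_by_extension(files, filter_by=["jsonl", "parquet"])
# kept -> ["shard_00.jsonl", "shard_01.parquet"]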

@@ -106,7 +120,10 @@ def get_all_files_paths_under(
# writing a file we can use the offset counter approach
# in jaccard shuffle as a more robust way to restart jobs
def get_remaining_files(
input_file_path, output_file_path, input_file_type, num_files=-1
input_file_path: str,
output_file_path: str,
input_file_type: str,
num_files: int = -1,
):
"""
This function returns a list of the files that still remain to be read.
@@ -148,7 +165,10 @@ def get_remaining_files(


def get_batched_files(
input_file_path, output_file_path, input_file_type, batch_size=64
input_file_path: str,
output_file_path: str,
input_file_type: str,
batch_size: int = 64,
):
"""
This function returns a batch of files that still remain to be processed.
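A sketch of how get_batched_files might be consumed, assuming (as the docstring suggests) that it yields batches of input paths that do not yet have corresponding outputs; the directory names are hypothetical.

from nemo_curator.utils.file_utils import get_batched_files

for file_batch in get_batched_files(
    "input_jsonl/", "output_jsonl/", "jsonl", batch_size=64   # hypothetical dirs
):
    print(f"processing {len(file_batch)} files")   # process and write the batch here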
@@ -331,7 +351,7 @@ def separate_by_metadata(
return delayed(reduce)(merge_counts, delayed_counts)


def parse_str_of_num_bytes(s, return_str=False):
def parse_str_of_num_bytes(s: str, return_str: bool = False):
try:
power = "kmg".find(s[-1].lower()) + 1
size = float(s[:-1]) * 1024**power
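A worked example of the arithmetic visible in this hunk: the size suffix indexes into "kmg" to pick a power of 1024. The collapsed remainder of the function presumably converts or returns this value, so the exact return type is not shown here.

# "100M": "kmg".find("m") + 1 == 2, so the parsed size is 100 * 1024**2 bytes.
power = "kmg".find("m") + 1
print(100 * 1024**power)   # 104857600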
@@ -344,7 +364,10 @@ def parse_str_of_num_bytes(


def _save_jsonl(documents, output_path, start_index=0, max_index=10000, prefix=None):
"""Worker function to write out the data to jsonl files"""
"""
Worker function to write out the data to jsonl files
"""

def _encode_text(document):
return document.strip().encode("utf-8")
@@ -377,7 +400,11 @@ def _name(start_index, npad, prefix, i):


def reshard_jsonl(
input_dir, output_dir, output_file_size="100M", start_index=0, file_prefix=""
input_dir: str,
output_dir: str,
output_file_size: str = "100M",
start_index: int = 0,
file_prefix: str = "",
):
"""
Reshards a directory of jsonl files to have a new (approximate) file size for each shard
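A usage sketch for reshard_jsonl based on the signature shown above; the directories and prefix are hypothetical, and the exact output file naming comes from the _name helper whose body is collapsed here.

from nemo_curator.utils.file_utils import reshard_jsonl

# Repack unevenly sized .jsonl shards into ~100 MB files whose names start
# with the given prefix and a zero-padded index.
reshard_jsonl(
    input_dir="data/raw_jsonl/",
    output_dir="data/resharded_jsonl/",
    output_file_size="100M",
    start_index=0,
    file_prefix="books",
)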
