Skip to content

Commit

Permalink
write embeddings to jsonl
Browse files Browse the repository at this point in the history
Signed-off-by: Sarah Yurick <[email protected]>
  • Loading branch information
sarahyurick committed Oct 28, 2024
1 parent 08349d5 commit 080563d
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 16 deletions.
3 changes: 2 additions & 1 deletion config/sem_dedup_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
cache_dir: "semdedup_cache"
num_files: 16
id_col_name: "id"
id_col_type: "int"
id_col_type: "int" # or "str"
input_column: "text"

# Embeddings configuration
embeddings_save_loc: "embeddings"
embedding_model_name_or_path: "sentence-transformers/all-MiniLM-L6-v2"
embedding_batch_size: 128
embedding_max_mem_gb: 25
input_file_type: "parquet" # or "jsonl"

# Clustering configuration
clustering_save_loc: "clustering_results"
Expand Down
4 changes: 3 additions & 1 deletion docs/user-guide/semdedup.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ Semantic deduplication in NeMo Curator can be configured using a YAML file. Here
embedding_model_name_or_path: "sentence-transformers/all-MiniLM-L6-v2"
embedding_batch_size: 128
embedding_max_mem_gb: 25
input_file_type: "jsonl" # or "parquet"
# Clustering configuration
clustering_save_loc: "clustering_results"
Expand Down Expand Up @@ -177,7 +178,8 @@ Use Individual Components
embedding_batch_size=128,
embedding_output_dir="path/to/output/embeddings",
input_column="text",
logger="path/to/log/dir"
input_file_type="jsonl",
logger="path/to/log/dir",
)
embeddings_dataset = embedding_creator(dataset)
Expand Down
2 changes: 2 additions & 0 deletions nemo_curator/modules/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ class SemDedupConfig(BaseConfig):
embedding_model_name_or_path (str): Model name or path for embeddings.
embedding_batch_size (int): Inital Batch size for processing embeddings.
embedding_max_mem_gb (int): Maximum memory in GB for embeddings.
input_file_type (str): File type of input data, can be "parquet" or "jsonl".
clustering_save_loc (str): Location to save clustering results.
n_clusters (int): Number of clusters.
seed (int): Seed for clustering.
Expand All @@ -151,6 +152,7 @@ class SemDedupConfig(BaseConfig):
embedding_model_name_or_path: str = "sentence-transformers/all-MiniLM-L6-v2"
embedding_batch_size: int = 128
embedding_max_mem_gb: int = 25
input_file_type: str = "parquet"

# Clustering config
clustering_save_loc: str = "clustering_results"
Expand Down
33 changes: 27 additions & 6 deletions nemo_curator/modules/semantic_dedup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@
performance_report_if_with_ts_suffix,
write_to_disk,
)
from nemo_curator.utils.file_utils import expand_outdir_and_mkdir
from nemo_curator.utils.file_utils import (
expand_outdir_and_mkdir,
get_all_files_paths_under,
)
from nemo_curator.utils.semdedup_utils import (
assign_and_sort_clusters,
extract_dedup_data,
Expand Down Expand Up @@ -130,6 +133,7 @@ def __init__(
embedding_column: str = "embeddings",
write_embeddings_to_disk: bool = True,
write_to_filename: bool = False,
input_file_type: str = "parquet",
logger: Union[logging.Logger, str] = "./",
profile_dir: Optional[str] = None,
):
Expand All @@ -146,6 +150,7 @@ def __init__(
We recommend setting this to False when you have a delayed pipeline.
Setting it to False can lead to more memory overhead.
write_to_filename (bool): If True, saves the embeddings to the same filename as input files, defaults to False.
input_file_type (str): Whether a Parquet or JSON file type is being read.
logger (Union[logging.Logger, str]): Logger object or path to store logs, defaults to "./".
profile_dir (str): If specified directory to write dask profile. Default is None.
Expand All @@ -157,6 +162,7 @@ def __init__(
input_column (str): Input column for data processing.
model (EmbeddingCrossFitModel): Model instance for embedding generation.
write_to_filename (bool): If True, saves the embeddings to the same filename as input files, defaults to False.
output_file_type (str): Whether to write to a Parquet or JSON file type.
"""

self.embeddings_config = EmbeddingConfig(
Expand All @@ -171,6 +177,9 @@ def __init__(
self.model = EmbeddingCrossFitModel(self.embeddings_config)
self.write_embeddings_to_disk = write_embeddings_to_disk
self.write_to_filename = write_to_filename
if input_file_type == "json":
input_file_type = "jsonl"
self.output_file_type = input_file_type.lower()
self.profile_dir = profile_dir

def _setup_logger(self, logger):
Expand Down Expand Up @@ -216,13 +225,24 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
embedding_ddf,
self.embedding_output_dir,
write_to_filename=self.write_to_filename,
output_type="parquet",
output_type=self.output_file_type,
)

if self.output_file_type == "jsonl":
embedding_files = get_all_files_paths_under(self.embedding_output_dir)
ddf = DocumentDataset(
dask_cudf.read_json(
embedding_files, blocksize="2GB"
)
)
ddf = DocumentDataset(
dask_cudf.read_parquet(
self.embedding_output_dir, blocksize="2GB", aggregate_files=True
elif self.output_file_type == "parquet":
ddf = DocumentDataset(
dask_cudf.read_parquet(
self.embedding_output_dir, blocksize="2GB", aggregate_files=True
)
)
)
else:
raise ValueError(f"Unknown output type: {self.output_file_type}")
else:
ddf = DocumentDataset(embedding_ddf)

Expand Down Expand Up @@ -595,6 +615,7 @@ def __init__(
embedding_batch_size=config.embedding_batch_size,
input_column=config.input_column,
embedding_output_dir=os.path.join(cache_dir, config.embeddings_save_loc),
input_file_type=config.input_file_type,
logger=logger,
profile_dir=self.config.profile_dir,
)
Expand Down
2 changes: 1 addition & 1 deletion nemo_curator/scripts/semdedup/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Please edit `config/sem_dedup_config.yaml` to configure the pipeline and run it

3) Clustering
```sh
semdedup_clustering --config-file "$CONFIG_FILE"
semdedup_clustering --input-file-type "jsonl" --config-file "$CONFIG_FILE"
```
**Input:** Output from step (2) and YAML file from step (1)

Expand Down
34 changes: 29 additions & 5 deletions nemo_curator/scripts/semdedup/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from nemo_curator.modules.config import SemDedupConfig
from nemo_curator.modules.semantic_dedup import ClusteringModel
from nemo_curator.utils.distributed_utils import get_client
from nemo_curator.utils.file_utils import expand_outdir_and_mkdir
from nemo_curator.utils.file_utils import (
expand_outdir_and_mkdir,
get_all_files_paths_under,
)
from nemo_curator.utils.script_utils import ArgumentHelper


Expand Down Expand Up @@ -55,10 +58,29 @@ def main(args):
clustering_output_dir = os.path.join(
semdedup_config.cache_dir, semdedup_config.clustering_save_loc
)
# Switch to https://github.com/NVIDIA/NeMo-Curator/issues/50
# When we fix that
embedding_df = dask_cudf.read_parquet(embedding_fp, blocksize="2GB")
embedding_dataset = DocumentDataset(embedding_df)

if args.input_file_extension is not None:
input_file_extension = args.input_file_extension
elif args.input_file_type is not None:
input_file_extension = args.input_file_type
else:
# Set default
input_file_extension = "parquet"

if input_file_extension in ["json", "jsonl"]:
embedding_files = get_all_files_paths_under(embedding_fp)
embedding_dataset = DocumentDataset(
dask_cudf.read_json(
embedding_files, blocksize="2GB"
)
)
elif input_file_extension == "parquet":
# Switch to https://github.com/NVIDIA/NeMo-Curator/issues/50
# When we fix that
embedding_df = dask_cudf.read_parquet(embedding_fp, blocksize="2GB")
embedding_dataset = DocumentDataset(embedding_df)
else:
raise RuntimeError("Could not read embeddings, please check file type")

clustering_model = ClusteringModel(
id_col=semdedup_config.id_col_name,
Expand All @@ -67,6 +89,7 @@ def main(args):
clustering_output_dir=clustering_output_dir,
logger=logger,
)

clustered_embeddings = clustering_model(embedding_dataset)
clustered_embeddings.df.head(10)
dt2 = datetime.now()
Expand Down Expand Up @@ -95,6 +118,7 @@ def attach_args():
" kmeans_with_cos_dist for using KMeans with cosine distance,"
),
add_input_args=False,
add_file_type_args=True,
)
return parser

Expand Down
10 changes: 8 additions & 2 deletions nemo_curator/scripts/semdedup/compute_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,30 +40,34 @@ def main(args):
output_data_dir = os.path.join(
semdedup_config.cache_dir, semdedup_config.embeddings_save_loc
)

# Some time jsonl files are stored as .json
# So to handle that case we can pass the input_file_extension
if args.input_file_extension is not None:
input_file_extension = args.input_file_extension
else:
input_file_extension = args.input_file_type
print("input_file_extension", input_file_extension)

st = time.time()
input_files = get_remaining_files(
input_file_path=args.input_data_dir,
output_file_path=output_data_dir,
input_file_type=input_file_extension,
num_files=semdedup_config.num_files,
)

logger.info(f"Processing {len(input_files)} files")
if len(input_files) == 0:
logger.info("No files to process")
return

ddf = read_data(
input_files=input_files, file_type=args.input_file_type, add_filename=False
input_files=input_files, file_type=args.input_file_type, add_filename=True
)
ddf = ddf.reset_index(drop=True)
dataset = DocumentDataset(ddf)

# Can repartition here if needed
# ddf = ddf.repartition(partition_size="64MB")
embedding_creator = EmbeddingCreator(
Expand All @@ -75,8 +79,10 @@ def main(args):
),
input_column=semdedup_config.input_column,
logger=logger,
write_to_filename=False,
write_to_filename=True,
input_file_type=input_file_extension,
)

embedding_dataset = embedding_creator(dataset=dataset)
print(embedding_dataset.df.head())
logger.info(f"Time taken: {time.time() - st}")
Expand Down
4 changes: 4 additions & 0 deletions nemo_curator/utils/script_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,7 @@ def parse_gpu_dedup_args(description: str) -> argparse.ArgumentParser:
@staticmethod
def parse_semdedup_args(
add_input_args=False,
add_file_type_args=False,
description="Default argument parser for semantic deduplication",
) -> argparse.ArgumentParser:
"""
Expand All @@ -560,6 +561,9 @@ def parse_semdedup_args(
argumentHelper.add_arg_input_file_extension()
argumentHelper.add_arg_input_file_type()
argumentHelper.add_arg_input_text_field()
elif add_file_type_args:
argumentHelper.add_arg_input_file_extension()
argumentHelper.add_arg_input_file_type()

argumentHelper.parser.add_argument(
"--config-file",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ embeddings_save_loc: "embeddings"
embedding_model_name_or_path: "sentence-transformers/all-MiniLM-L6-v2"
embedding_batch_size: 128
embedding_max_mem_gb: 20
input_file_type: "jsonl"

# Clustering configuration
clustering_save_loc: "clustering_results"
Expand Down

0 comments on commit 080563d

Please sign in to comment.