From 0ab8300d75c3ce3c40b983f3fb21854a9f592373 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Tue, 4 Feb 2025 14:03:09 -0800 Subject: [PATCH] Add option to set cache precision in TBE benchmark Summary: - Add option to set cache precision in TBE benchmark Reviewed By: sryap Differential Revision: D69134252 --- .../bench/split_table_batched_embeddings_benchmark.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py index b06ac8b2b2..79c2e4b8c4 100644 --- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py @@ -122,6 +122,7 @@ def cli() -> None: @click.option("--batch-size", default=512) @click.option("--embedding-dim", default=128) @click.option("--weights-precision", type=SparseType, default=SparseType.FP32) +@click.option("--cache-precision", type=SparseType, default=None) @click.option("--stoc", is_flag=True, default=False) @click.option("--iters", default=100) @click.option("--warmup-runs", default=0) @@ -174,6 +175,7 @@ def device( # noqa C901 batch_size: int, embedding_dim: int, weights_precision: SparseType, + cache_precision: Optional[SparseType], stoc: bool, iters: int, warmup_runs: int, @@ -317,7 +319,9 @@ def device( # noqa C901 ) for d in Ds ], - cache_precision=weights_precision, + cache_precision=( + weights_precision if cache_precision is None else cache_precision + ), cache_algorithm=CacheAlgorithm.LRU, cache_load_factor=cache_load_factor, **common_split_args,