diff --git a/python/cuvs_bench/cuvs_bench/run/run.py b/python/cuvs_bench/cuvs_bench/run/run.py index 0159d2c19..d7827a096 100644 --- a/python/cuvs_bench/cuvs_bench/run/run.py +++ b/python/cuvs_bench/cuvs_bench/run/run.py @@ -158,6 +158,7 @@ def gather_algorithm_configs( def load_algorithms_conf( algos_conf_fs: list, allowed_algos: Optional[list], + allowed_groups: Optional[list], allowed_algo_groups: Optional[tuple], ) -> dict: """ @@ -187,8 +188,11 @@ def load_algorithms_conf( continue if allowed_algos and algo["name"] not in allowed_algos: continue + groups = algo.get("groups", {}) + if allowed_groups: + groups = {k: v for k, v in groups.items() if k in allowed_groups} algos_conf[algo["name"]] = { - "groups": algo.get("groups", {}), + "groups": groups, "constraints": algo.get("constraints", {}), } if allowed_algo_groups and algo["name"] in allowed_algo_groups[0]: @@ -643,6 +647,7 @@ def run_benchmark( algos_conf_fs = gather_algorithm_configs(scripts_path, configuration) allowed_algos = algorithms.split(",") if algorithms else None + allowed_groups = groups.split(",") if groups else None allowed_algo_groups = ( [algo_group.split(".") for algo_group in algo_groups.split(",")] if algo_groups @@ -651,7 +656,8 @@ def run_benchmark( algos_conf = load_algorithms_conf( algos_conf_fs, allowed_algos, - list(zip(*allowed_algo_groups)) if allowed_algo_groups else None, + allowed_groups, + tuple(zip(*allowed_algo_groups)) if allowed_algo_groups else None, ) executables_to_run = prepare_executables(