From 654e1dcdcac2ecd21221202fe59b3adbf656915d Mon Sep 17 00:00:00 2001
From: Sarah Yurick <53962159+sarahyurick@users.noreply.github.com>
Date: Tue, 18 Feb 2025 10:33:44 -0800
Subject: [PATCH] Adjust `keep_cols` logic (#109)

* Adjust `keep_cols` logic

Signed-off-by: Sarah Yurick

* flake8 lint

Signed-off-by: Sarah Yurick

---------

Signed-off-by: Sarah Yurick
---
 crossfit/data/sparse/core.py     |  4 ++--
 crossfit/op/label.py             |  5 +++++
 examples/dask_aggregate_bench.py |  6 +++---
 tests/pytrec_utils.py            | 10 +++++-----
 4 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/crossfit/data/sparse/core.py b/crossfit/data/sparse/core.py
index 5c22fd36..c2d0c5ac 100644
--- a/crossfit/data/sparse/core.py
+++ b/crossfit/data/sparse/core.py
@@ -172,12 +172,12 @@ def to_pytrec(self, is_run=False):
 
         qrel = {}
         for i in range(self.indices.shape[0]):
-            query_id = f"q{i+1}"
+            query_id = f"q{i + 1}"
             qrel[query_id] = {}
             row = sparse_matrix[i]
 
             for j, score in zip(row.indices, row.data):
-                doc_id = f"d{j+1}"
+                doc_id = f"d{j + 1}"
                 qrel[query_id][doc_id] = int(score) if is_run else float(score)
 
         return qrel
diff --git a/crossfit/op/label.py b/crossfit/op/label.py
index c8ed06ab..1f6419dd 100644
--- a/crossfit/op/label.py
+++ b/crossfit/op/label.py
@@ -16,6 +16,11 @@ def __init__(
         suffix: str = "labels",
         axis=-1,
     ):
+        if keep_cols is not None and suffix in keep_cols:
+            # suffix is already kept as a column
+            # and will raise an error if it is in keep_cols
+            keep_cols.remove(suffix)
+
         super().__init__(pre=pre, cols=cols, keep_cols=keep_cols)
         self.labels = labels
         self.suffix = suffix
diff --git a/examples/dask_aggregate_bench.py b/examples/dask_aggregate_bench.py
index 472dbd84..392ad425 100644
--- a/examples/dask_aggregate_bench.py
+++ b/examples/dask_aggregate_bench.py
@@ -65,7 +65,7 @@
     t0 = time.time()
     result = aggregate(ddf, agg, to_frame=True)
     tf = time.time()
-    print(f"\nWall Time: {tf-t0} seconds\n")
+    print(f"\nWall Time: {tf - t0} seconds\n")
 
     # View result
     print(f"Result:\n{result}\n")
@@ -76,12 +76,12 @@
         t0 = time.time()
         std = ddf.groupby(groupby).std().compute()
         tf = time.time()
-        print(f"\nddf.groupby().std() takes {tf-t0} seconds, and returns:\n")
+        print(f"\nddf.groupby().std() takes {tf - t0} seconds, and returns:\n")
         print(f"\n{std}\n")
     else:
         # Compare to ddf.std()
         t0 = time.time()
         std = ddf.std().compute()
         tf = time.time()
-        print(f"\nddf.std() takes {tf-t0} seconds, and returns:\n")
+        print(f"\nddf.std() takes {tf - t0} seconds, and returns:\n")
         print(f"\n{std}\n")
diff --git a/tests/pytrec_utils.py b/tests/pytrec_utils.py
index e2dbac24..a06f21b2 100644
--- a/tests/pytrec_utils.py
+++ b/tests/pytrec_utils.py
@@ -24,13 +24,13 @@ def create_qrel(relevance_scores, ids=None):
 
     qrel = {}
     for i, query_scores in enumerate(relevance_scores):
-        query_id = ids[i] if ids is not None else f"q{i+1}"
+        query_id = ids[i] if ids is not None else f"q{i + 1}"
         qrel[query_id] = {}
 
         for j, score in enumerate(query_scores):
             _score = int(score.item())
             if _score > 0:
-                doc_id = f"d{j+1}"
+                doc_id = f"d{j + 1}"
                 qrel[query_id][doc_id] = int(score.item())
 
     return qrel
@@ -41,10 +41,10 @@ def create_run(predicted_scores, ids=None):
 
     run = {}
     for i, query_scores in enumerate(predicted_scores):
-        query_id = ids[i] if ids is not None else f"q{i+1}"
+        query_id = ids[i] if ids is not None else f"q{i + 1}"
         run[query_id] = {}
         for j, score in enumerate(query_scores):
-            doc_id = f"d{j+1}"
+            doc_id = f"d{j + 1}"
             run[query_id][doc_id] = float(score.item())
 
     return run
@@ -60,6 +60,6 @@ def create_results(metric_arrays):
         for k, v in metric_arrays.items():
             q_out[k] = float(v[i])
 
-        outputs[f"q{i+1}"] = q_out
+        outputs[f"q{i + 1}"] = q_out
 
     return outputs
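
A minimal standalone sketch of the `keep_cols` guard added in crossfit/op/label.py above.
The class and usage here are illustrative assumptions, not the actual crossfit API; they only
reproduce the behavior of dropping the suffix name from keep_cols so the op's own output
column is not requested twice.

class LabelerSketch:
    def __init__(self, labels, cols=None, keep_cols=None, suffix="labels"):
        # The op writes its output under `suffix`, so leaving that name in
        # keep_cols would duplicate the column and raise an error downstream.
        if keep_cols is not None and suffix in keep_cols:
            keep_cols.remove(suffix)
        self.labels = labels
        self.cols = cols
        self.keep_cols = keep_cols
        self.suffix = suffix


op = LabelerSketch(labels=["neg", "pos"], cols=["scores"], keep_cols=["id", "labels"])
print(op.keep_cols)  # ['id'] -- "labels" matches the suffix and is removed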