Add option to evaluate every next item in next-item evaluation #580

Merged: 4 commits, Jan 12, 2024
Changes from 2 commits
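For context, a minimal usage sketch of the new option (not part of this PR's diff). The import paths are assumed from the package layout, and the toy log assumes fmt="USIT" rows are (user, session, item, timestamp) tuples, mirroring the updated tests below:

from cornac.eval_methods import NextItemEvaluation
from cornac.metrics import HitRatio, Recall
from cornac.models import SPop

# Illustrative toy log; each row is (user, session, item, timestamp), an order
# assumed from the "USIT" format string. Replace with real interaction data.
data = [
    ("u1", "s1", "i1", 1), ("u1", "s1", "i2", 2), ("u1", "s1", "i3", 3),
    ("u2", "s2", "i2", 4), ("u2", "s2", "i3", 5), ("u2", "s2", "i1", 6),
]

# mode="last" (the default) scores only the final item of each test session;
# mode="next" scores every position from the second item onward.
next_item_eval = NextItemEvaluation.from_splits(
    train_data=data[:3], test_data=data[3:], fmt="USIT", mode="next"
)
result = next_item_eval.evaluate(SPop(), [HitRatio(k=2), Recall(k=2)], user_based=False)
print(result[0].metric_avg_results)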
4 changes: 4 additions & 0 deletions cornac/eval_methods/base_method.py
@@ -272,6 +272,7 @@ def __init__(
self.val_set = None
self.rating_threshold = rating_threshold
self.exclude_unknowns = exclude_unknowns
self.mode = kwargs.get("mode", None)
self.verbose = verbose
self.seed = seed
self.rng = get_rng(seed)
@@ -663,6 +664,7 @@ def eval(
rating_metrics,
ranking_metrics,
verbose,
**kwargs,
):
"""Running evaluation for rating and ranking metrics respectively."""
metric_avg_results = OrderedDict()
@@ -754,6 +756,7 @@ def evaluate(self, model, metrics, user_based, show_validation=True):
rating_metrics=rating_metrics,
ranking_metrics=ranking_metrics,
user_based=user_based,
mode=self.mode,
verbose=self.verbose,
)
test_time = time.time() - start
@@ -774,6 +777,7 @@ def evaluate(self, model, metrics, user_based, show_validation=True):
rating_metrics=rating_metrics,
ranking_metrics=ranking_metrics,
user_based=user_based,
mode=self.mode,
verbose=self.verbose,
)
val_time = time.time() - start
108 changes: 66 additions & 42 deletions cornac/eval_methods/next_item_evaluation.py
@@ -23,13 +23,19 @@
from . import BaseMethod


EVALUATION_MODES = [
"last",
"next",
]

def ranking_eval(
model,
metrics,
train_set,
test_set,
user_based=False,
exclude_unknowns=True,
mode="last",
verbose=False,
):
"""Evaluate model on provided ranking metrics.
@@ -68,62 +74,68 @@ def ranking_eval(
return [], []

avg_results = []
session_results = [{} for _ in enumerate(metrics)]
session_results = [defaultdict(list) for _ in enumerate(metrics)]
user_results = [defaultdict(list) for _ in enumerate(metrics)]

user_sessions = defaultdict(list)
session_ids = []
for [sid], [mapped_ids], [session_items] in tqdm(
test_set.si_iter(batch_size=1, shuffle=False),
total=len(test_set.sessions),
desc="Ranking",
disable=not verbose,
miniters=100,
):
test_pos_items = session_items[-1:] # last item in the session
if len(test_pos_items) == 0:
if len(session_items) < 2: # exclude all sessions with fewer than 2 items
continue
user_idx = test_set.uir_tuple[0][mapped_ids[0]]
if user_based:
user_sessions[user_idx].append(sid)
# binary mask for ground-truth positive items
u_gt_pos_mask = np.zeros(test_set.num_items, dtype="int")
u_gt_pos_mask[test_pos_items] = 1

# binary mask for ground-truth negative items, removing all positive items
u_gt_neg_mask = np.ones(test_set.num_items, dtype="int")
u_gt_neg_mask[test_pos_items] = 0

# filter items being considered for evaluation
if exclude_unknowns:
u_gt_pos_mask = u_gt_pos_mask[: train_set.num_items]
u_gt_neg_mask = u_gt_neg_mask[: train_set.num_items]

u_gt_pos_items = np.nonzero(u_gt_pos_mask)[0]
u_gt_neg_items = np.nonzero(u_gt_neg_mask)[0]
item_indices = np.nonzero(u_gt_pos_mask + u_gt_neg_mask)[0]

item_rank, item_scores = model.rank(
user_idx,
item_indices,
history_items=session_items[:-1],
history_mapped_ids=mapped_ids[:-1],
sessions=test_set.sessions,
session_indices=test_set.session_indices,
extra_data=test_set.extra_data,
)

for i, mt in enumerate(metrics):
mt_score = mt.compute(
gt_pos=u_gt_pos_items,
gt_neg=u_gt_neg_items,
pd_rank=item_rank,
pd_scores=item_scores,
item_indices=item_indices,
session_ids.append(sid)

start_pos = 1 if mode == "next" else len(session_items) - 1
for test_pos in range(start_pos, len(session_items), 1):
test_pos_items = session_items[test_pos]

# binary mask for ground-truth positive items
u_gt_pos_mask = np.zeros(test_set.num_items, dtype="int")
u_gt_pos_mask[test_pos_items] = 1

# binary mask for ground-truth negative items, removing all positive items
u_gt_neg_mask = np.ones(test_set.num_items, dtype="int")
u_gt_neg_mask[test_pos_items] = 0

# filter items being considered for evaluation
if exclude_unknowns:
u_gt_pos_mask = u_gt_pos_mask[: train_set.num_items]
u_gt_neg_mask = u_gt_neg_mask[: train_set.num_items]

u_gt_pos_items = np.nonzero(u_gt_pos_mask)[0]
u_gt_neg_items = np.nonzero(u_gt_neg_mask)[0]
item_indices = np.nonzero(u_gt_pos_mask + u_gt_neg_mask)[0]

item_rank, item_scores = model.rank(
user_idx,
item_indices,
history_items=session_items[:test_pos],
history_mapped_ids=mapped_ids[:test_pos],
sessions=test_set.sessions,
session_indices=test_set.session_indices,
extra_data=test_set.extra_data,
)
if user_based:
user_results[i][user_idx].append(mt_score)
else:
session_results[i][sid] = mt_score

for i, mt in enumerate(metrics):
mt_score = mt.compute(
gt_pos=u_gt_pos_items,
gt_neg=u_gt_neg_items,
pd_rank=item_rank,
pd_scores=item_scores,
item_indices=item_indices,
)
if user_based:
user_results[i][user_idx].append(mt_score)
else:
session_results[i][sid].append(mt_score)

# avg results of ranking metrics
for i, mt in enumerate(metrics):
@@ -132,7 +144,8 @@ def ranking_eval(
user_avg_results = [np.mean(user_results[i][user_idx]) for user_idx in user_ids]
avg_results.append(np.mean(user_avg_results))
else:
avg_results.append(sum(session_results[i].values()) / len(session_results[i]))
session_result = [score for sid in session_ids for score in session_results[i][sid]]
avg_results.append(np.mean(session_result))
return avg_results, user_results
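
As a side note for reviewers, a small standalone illustration (not library code) of which session positions the loop above scores under each mode, following the start_pos logic:

# Standalone sketch of the start_pos logic in ranking_eval above.
session_items = [11, 42, 7, 3]  # a test session with 4 items

def evaluated_positions(session_items, mode="last"):
    # "next": score positions 1..len-1; "last": score only the final position.
    start_pos = 1 if mode == "next" else len(session_items) - 1
    return list(range(start_pos, len(session_items)))

print(evaluated_positions(session_items, mode="last"))  # [3]
print(evaluated_positions(session_items, mode="next"))  # [1, 2, 3]
# At each position p, session_items[:p] is passed to model.rank() as history
# and session_items[p] is the ground-truth positive item; per-position scores
# are collected per session (or per user) and averaged at the end.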


@@ -163,6 +176,11 @@ class NextItemEvaluation(BaseMethod):
seed: int, optional, default: None
Random seed for reproducibility.

mode: str, optional, default: 'last'
Evaluation mode is either 'next' or 'last'.
If 'last', only evaluate the last item.
If 'next', evaluate every next item in the sequence.

exclude_unknowns: bool, optional, default: True
If `True`, unknown items will be ignored during model evaluation.

@@ -178,6 +196,7 @@ def __init__(
val_size=0.0,
fmt="SIT",
seed=None,
mode="last",
exclude_unknowns=True,
verbose=False,
**kwargs,
@@ -191,8 +210,11 @@ def __init__(
seed=seed,
exclude_unknowns=exclude_unknowns,
verbose=verbose,
mode=mode,
**kwargs,
)
assert mode in EVALUATION_MODES
self.mode = mode
self.global_sid_map = kwargs.get("global_sid_map", OrderedDict())

def _build_datasets(self, train_data, test_data, val_data=None):
@@ -263,6 +285,7 @@ def eval(
ranking_metrics,
user_based=False,
verbose=False,
mode="last",
**kwargs,
):
metric_avg_results = OrderedDict()
@@ -275,6 +298,7 @@ def eval(
test_set=test_set,
user_based=user_based,
exclude_unknowns=exclude_unknowns,
mode=mode,
verbose=verbose,
)

15 changes: 15 additions & 0 deletions tests/cornac/eval_methods/test_next_item_evaluation.py
@@ -40,13 +40,28 @@ def test_evaluate(self):
)
self.assertEqual(result[0].metric_avg_results.get('HitRatio@2'), 0)
self.assertEqual(result[0].metric_avg_results.get('Recall@2'), 0)

next_item_eval = NextItemEvaluation.from_splits(train_data=self.data[:50], test_data=self.data[50:], fmt="USIT")
result = next_item_eval.evaluate(
SPop(), [HitRatio(k=5), Recall(k=5)], user_based=True
)
self.assertEqual(result[0].metric_avg_results.get('HitRatio@5'), 2/3)
self.assertEqual(result[0].metric_avg_results.get('Recall@5'), 2/3)

next_item_eval = NextItemEvaluation.from_splits(train_data=self.data[:50], test_data=self.data[50:], fmt="USIT", mode="next")
result = next_item_eval.evaluate(
SPop(), [HitRatio(k=2), Recall(k=2)], user_based=False
)

self.assertEqual(result[0].metric_avg_results.get('HitRatio@2'), 1/8)
self.assertEqual(result[0].metric_avg_results.get('Recall@2'), 1/8)

next_item_eval = NextItemEvaluation.from_splits(train_data=self.data[:50], test_data=self.data[50:], fmt="USIT", mode="next")
result = next_item_eval.evaluate(
SPop(), [HitRatio(k=5), Recall(k=5)], user_based=True
)
self.assertEqual(result[0].metric_avg_results.get('HitRatio@5'), 3/4)
self.assertEqual(result[0].metric_avg_results.get('Recall@5'), 3/4)

if __name__ == "__main__":
unittest.main()