Skip to content

Commit

Permalink
Add option to evaluate every next item in next-item evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
lthoang committed Jan 9, 2024
1 parent f2d44ce commit dad914e
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 42 deletions.
4 changes: 4 additions & 0 deletions cornac/eval_methods/base_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ def __init__(
self.val_set = None
self.rating_threshold = rating_threshold
self.exclude_unknowns = exclude_unknowns
self.mode = kwargs.get("mode", None)
self.verbose = verbose
self.seed = seed
self.rng = get_rng(seed)
Expand Down Expand Up @@ -663,6 +664,7 @@ def eval(
rating_metrics,
ranking_metrics,
verbose,
**kwargs,
):
"""Running evaluation for rating and ranking metrics respectively."""
metric_avg_results = OrderedDict()
Expand Down Expand Up @@ -754,6 +756,7 @@ def evaluate(self, model, metrics, user_based, show_validation=True):
rating_metrics=rating_metrics,
ranking_metrics=ranking_metrics,
user_based=user_based,
mode=self.mode,
verbose=self.verbose,
)
test_time = time.time() - start
Expand All @@ -774,6 +777,7 @@ def evaluate(self, model, metrics, user_based, show_validation=True):
rating_metrics=rating_metrics,
ranking_metrics=ranking_metrics,
user_based=user_based,
mode=self.mode,
verbose=self.verbose,
)
val_time = time.time() - start
Expand Down
108 changes: 66 additions & 42 deletions cornac/eval_methods/next_item_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,19 @@
from . import BaseMethod


EVALUATION_MODES = [
"last",
"next",
]

def ranking_eval(
model,
metrics,
train_set,
test_set,
user_based=False,
exclude_unknowns=True,
mode="last",
verbose=False,
):
"""Evaluate model on provided ranking metrics.
Expand Down Expand Up @@ -68,62 +74,68 @@ def ranking_eval(
return [], []

avg_results = []
session_results = [{} for _ in enumerate(metrics)]
session_results = [defaultdict(list) for _ in enumerate(metrics)]
user_results = [defaultdict(list) for _ in enumerate(metrics)]

user_sessions = defaultdict(list)
session_ids = []
for [sid], [mapped_ids], [session_items] in tqdm(
test_set.si_iter(batch_size=1, shuffle=False),
total=len(test_set.sessions),
desc="Ranking",
disable=not verbose,
miniters=100,
):
test_pos_items = session_items[-1:] # last item in the session
if len(test_pos_items) == 0:
if len(session_items) < 2: # exclude all session with size smaller than 2
continue
user_idx = test_set.uir_tuple[0][mapped_ids[0]]
if user_based:
user_sessions[user_idx].append(sid)
# binary mask for ground-truth positive items
u_gt_pos_mask = np.zeros(test_set.num_items, dtype="int")
u_gt_pos_mask[test_pos_items] = 1

# binary mask for ground-truth negative items, removing all positive items
u_gt_neg_mask = np.ones(test_set.num_items, dtype="int")
u_gt_neg_mask[test_pos_items] = 0

# filter items being considered for evaluation
if exclude_unknowns:
u_gt_pos_mask = u_gt_pos_mask[: train_set.num_items]
u_gt_neg_mask = u_gt_neg_mask[: train_set.num_items]

u_gt_pos_items = np.nonzero(u_gt_pos_mask)[0]
u_gt_neg_items = np.nonzero(u_gt_neg_mask)[0]
item_indices = np.nonzero(u_gt_pos_mask + u_gt_neg_mask)[0]

item_rank, item_scores = model.rank(
user_idx,
item_indices,
history_items=session_items[:-1],
history_mapped_ids=mapped_ids[:-1],
sessions=test_set.sessions,
session_indices=test_set.session_indices,
extra_data=test_set.extra_data,
)

for i, mt in enumerate(metrics):
mt_score = mt.compute(
gt_pos=u_gt_pos_items,
gt_neg=u_gt_neg_items,
pd_rank=item_rank,
pd_scores=item_scores,
item_indices=item_indices,
session_ids.append(sid)

start_pos = 1 if mode == "next" else len(session_items) - 1
for test_pos in range(start_pos, len(session_items), 1):
test_pos_items = session_items[test_pos]

# binary mask for ground-truth positive items
u_gt_pos_mask = np.zeros(test_set.num_items, dtype="int")
u_gt_pos_mask[test_pos_items] = 1

# binary mask for ground-truth negative items, removing all positive items
u_gt_neg_mask = np.ones(test_set.num_items, dtype="int")
u_gt_neg_mask[test_pos_items] = 0

# filter items being considered for evaluation
if exclude_unknowns:
u_gt_pos_mask = u_gt_pos_mask[: train_set.num_items]
u_gt_neg_mask = u_gt_neg_mask[: train_set.num_items]

u_gt_pos_items = np.nonzero(u_gt_pos_mask)[0]
u_gt_neg_items = np.nonzero(u_gt_neg_mask)[0]
item_indices = np.nonzero(u_gt_pos_mask + u_gt_neg_mask)[0]

item_rank, item_scores = model.rank(
user_idx,
item_indices,
history_items=session_items[:test_pos],
history_mapped_ids=mapped_ids[:test_pos],
sessions=test_set.sessions,
session_indices=test_set.session_indices,
extra_data=test_set.extra_data,
)
if user_based:
user_results[i][user_idx].append(mt_score)
else:
session_results[i][sid] = mt_score

for i, mt in enumerate(metrics):
mt_score = mt.compute(
gt_pos=u_gt_pos_items,
gt_neg=u_gt_neg_items,
pd_rank=item_rank,
pd_scores=item_scores,
item_indices=item_indices,
)
if user_based:
user_results[i][user_idx].append(mt_score)
else:
session_results[i][sid].append(mt_score)

# avg results of ranking metrics
for i, mt in enumerate(metrics):
Expand All @@ -132,7 +144,8 @@ def ranking_eval(
user_avg_results = [np.mean(user_results[i][user_idx]) for user_idx in user_ids]
avg_results.append(np.mean(user_avg_results))
else:
avg_results.append(sum(session_results[i].values()) / len(session_results[i]))
session_avg_results = [np.mean(session_results[i][sid] for sid in session_ids)]
avg_results.append(np.mean(session_avg_results))
return avg_results, user_results


Expand Down Expand Up @@ -163,6 +176,11 @@ class NextItemEvaluation(BaseMethod):
seed: int, optional, default: None
Random seed for reproducibility.
mode: str, optional, default: 'last'
Evaluation mode is either 'next' or 'last'.
If 'last', only evaluate the last item.
If 'next', evaluate every next item in the sequence,
exclude_unknowns: bool, optional, default: True
If `True`, unknown items will be ignored during model evaluation.
Expand All @@ -178,6 +196,7 @@ def __init__(
val_size=0.0,
fmt="SIT",
seed=None,
mode="last",
exclude_unknowns=True,
verbose=False,
**kwargs,
Expand All @@ -191,8 +210,11 @@ def __init__(
seed=seed,
exclude_unknowns=exclude_unknowns,
verbose=verbose,
mode=mode,
**kwargs,
)
assert mode in EVALUATION_MODES
self.mode = mode
self.global_sid_map = kwargs.get("global_sid_map", OrderedDict())

def _build_datasets(self, train_data, test_data, val_data=None):
Expand Down Expand Up @@ -263,6 +285,7 @@ def eval(
ranking_metrics,
user_based=False,
verbose=False,
mode="last",
**kwargs,
):
metric_avg_results = OrderedDict()
Expand All @@ -275,6 +298,7 @@ def eval(
test_set=test_set,
user_based=user_based,
exclude_unknowns=exclude_unknowns,
mode=mode,
verbose=verbose,
)

Expand Down

0 comments on commit dad914e

Please sign in to comment.