Skip to content

Commit

Permalink
Support weighted LinkPredNDCG metric (#9945)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Jan 15, 2025
1 parent 51a5aa0 commit 371678c
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for weighted `LinkPredNDCG` metric ([#9945](https://github.com/pyg-team/pytorch_geometric/pull/9945))
- Added `LinkPredMetricCollection` ([#9941](https://github.com/pyg-team/pytorch_geometric/pull/9941))
- Added various `GRetriever` architecture benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
- Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
Expand Down
16 changes: 16 additions & 0 deletions test/metrics/test_link_pred_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def test_map():
def test_ndcg():
pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]])
edge_label_index = torch.tensor([[0, 0, 2, 2], [0, 1, 2, 1]])
edge_label_weight = torch.tensor([1.0, 2.0, 3.0, 0.5])

metric = LinkPredNDCG(k=2)
assert str(metric) == 'LinkPredNDCG(k=2)'
Expand All @@ -126,6 +127,21 @@ def test_ndcg():
metric.compute()
metric.reset()

metric = LinkPredNDCG(k=2, weighted=True)
assert str(metric) == 'LinkPredNDCG(k=2, weighted=True)'
with pytest.raises(ValueError, match="'edge_label_weight'"):
metric.update(pred_index_mat, edge_label_index)

metric.update(pred_index_mat, edge_label_index, edge_label_weight)
result = metric.compute()

assert float(result) == pytest.approx(0.7854486)

# Test with `k > pred_index_mat.size(1)`:
metric.update(pred_index_mat[:, :1], edge_label_index, edge_label_weight)
metric.compute()
metric.reset()


def test_mrr():
pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2], [0, 1]])
Expand Down
50 changes: 41 additions & 9 deletions torch_geometric/metrics/link_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
from torch import Tensor

from torch_geometric.index import index2ptr
from torch_geometric.utils import cumsum, scatter

try:
Expand Down Expand Up @@ -58,7 +59,10 @@ def pred_rel_mat(self) -> Tensor:

pred_rel_mat = flat_label_index[pos] == flat_pred_index # Find matches
if edge_label_weight is not None:
pred_rel_mat = edge_label_weight[pos].where(pred_rel_mat, 0.0)
pred_rel_mat = edge_label_weight[pos].where(
pred_rel_mat,
pred_rel_mat.new_zeros(1),
)
pred_rel_mat = pred_rel_mat.view(self.pred_index_mat.size())

self._pred_rel_mat = pred_rel_mat
Expand Down Expand Up @@ -401,19 +405,47 @@ def __init__(self, k: int, weighted: bool = False):
self.weighted = weighted

dtype = torch.get_default_dtype()
multiplier = 1.0 / torch.arange(2, k + 2, dtype=dtype).log2()
discount = torch.arange(2, k + 2, dtype=dtype).log2()

self.multiplier: Tensor
self.register_buffer('multiplier', multiplier)
self.discount: Tensor
self.register_buffer('discount', discount)

self.idcg: Tensor
self.register_buffer('idcg', cumsum(multiplier))
if not weighted:
self.register_buffer('idcg', cumsum(1.0 / discount))
else:
self.idcg = None

def _compute(self, data: LinkPredMetricData) -> Tensor:
pred_rel_mat = data.pred_rel_mat[:, :self.k]
multiplier = self.multiplier[:pred_rel_mat.size(1)].view(1, -1)
dcg = (pred_rel_mat * multiplier).sum(dim=-1)
idcg = self.idcg[data.label_count.clamp(max=self.k)]
discount = self.discount[:pred_rel_mat.size(1)].view(1, -1)
dcg = (pred_rel_mat / discount).sum(dim=-1)

if not self.weighted:
assert self.idcg is not None
idcg = self.idcg[data.label_count.clamp(max=self.k)]
else:
assert data.edge_label_weight is not None
# Sort weights in buckets via two sorts:
weight, perm = data.edge_label_weight.sort(descending=True)
batch = data.edge_label_index[0][perm]
batch, perm = torch.sort(batch, stable=True)
weight = weight[perm]

# Shrink buckets that are larger than `k`:
arange = torch.arange(batch.size(0), device=batch.device)
ptr = index2ptr(batch, size=data.pred_index_mat.size(0))
batched_arange = arange - ptr[batch]
mask = batched_arange < self.k
batch = batch[mask]
batched_arange = batched_arange[mask]
weight = weight[mask]

# Compute ideal relevance matrix:
irel_mat = weight.new_zeros(data.pred_index_mat.size(0) * self.k)
irel_mat[batch * self.k + batched_arange] = weight
irel_mat = irel_mat.view(-1, self.k)

idcg = (irel_mat / self.discount.view(1, -1)).sum(dim=-1)

out = dcg / idcg
out[out.isnan() | out.isinf()] = 0.0
Expand Down

0 comments on commit 371678c

Please sign in to comment.