Skip to content

Commit

Permalink
Weighted LinkPredRecall (#9947)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Jan 15, 2025
1 parent 8db0702 commit 50f5622
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for weighted `LinkPredRecall` metric ([#9947](https://github.com/pyg-team/pytorch_geometric/pull/9947))
- Added support for weighted `LinkPredNDCG` metric ([#9945](https://github.com/pyg-team/pytorch_geometric/pull/9945))
- Added `LinkPredMetricCollection` ([#9941](https://github.com/pyg-team/pytorch_geometric/pull/9941))
- Added various `GRetriever` architecture benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
Expand Down
16 changes: 15 additions & 1 deletion test/metrics/test_link_pred_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,19 +64,33 @@ def test_precision(num_src_nodes, num_dst_nodes, num_edges, batch_size, k):
def test_recall():
pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]])
edge_label_index = torch.tensor([[0, 0, 0, 2, 2], [0, 1, 2, 2, 1]])
edge_label_weight = torch.tensor([4.0, 1.0, 2.0, 3.0, 0.5])

metric = LinkPredRecall(k=2)
assert str(metric) == 'LinkPredRecall(k=2)'
metric.update(pred_index_mat, edge_label_index)
result = metric.compute()

assert float(result) == pytest.approx(0.5 * (2 / 3 + 0.5))

# Test with `k > pred_index_mat.size(1)`:
metric.update(pred_index_mat[:, :1], edge_label_index)
metric.compute()
metric.reset()

metric = LinkPredRecall(k=2, weighted=True)
assert str(metric) == 'LinkPredRecall(k=2, weighted=True)'
with pytest.raises(ValueError, match="'edge_label_weight'"):
metric.update(pred_index_mat, edge_label_index)

metric.update(pred_index_mat, edge_label_index, edge_label_weight)
result = metric.compute()
assert float(result) == pytest.approx(0.5 * (5.0 / 7.0 + 3.0 / 3.5))

# Test with `k > pred_index_mat.size(1)`:
metric.update(pred_index_mat[:, :1], edge_label_index, edge_label_weight)
metric.compute()
metric.reset()


def test_f1():
pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]])
Expand Down
81 changes: 63 additions & 18 deletions torch_geometric/metrics/link_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,51 @@ def label_count(self) -> Tensor:
self._label_count = label_count
return label_count

@property
def label_weight_sum(self) -> Tensor:
r"""The sum of edge label weights for every example."""
if self.edge_label_weight is None:
return self.label_count

if hasattr(self, '_label_weight_sum'):
return self._label_weight_sum # type: ignore

label_weight_sum = scatter(
self.edge_label_weight,
self.edge_label_index[0],
dim=0,
dim_size=self.pred_index_mat.size(0),
reduce='sum',
)

self._label_weight_sum = label_weight_sum
return label_weight_sum

@property
def edge_label_weight_pos(self) -> Optional[Tensor]:
r"""Returns the position of edge label weights in descending order
within example-wise buckets.
"""
if self.edge_label_weight is None:
return None

if hasattr(self, '_edge_label_weight_pos'):
return self._edge_label_weight_pos # type: ignore

# Get the permutation via two sorts: One globally on the weights,
# followed by a (stable) sort on the example indices.
perm1 = self.edge_label_weight.argsort(descending=True)
perm2 = self.edge_label_index[0][perm1].argsort(stable=True)
perm = perm1[perm2]
# Invert the permutation to get the final position:
pos = torch.empty_like(perm)
pos[perm] = torch.arange(perm.size(0), device=perm.device)
# Normalize position to zero within all buckets:
pos = pos - cumsum(self.label_count)[self.edge_label_index[0]]

self._edge_label_weight_pos = pos
return pos


class LinkPredMetric(BaseMetric):
r"""An abstract class for computing link prediction retrieval metrics.
Expand Down Expand Up @@ -231,7 +276,9 @@ def __init__(

if isinstance(metrics, (list, tuple)):
metrics = {
f'{metric.__class__.__name__}@{metric.k}': metric
(f'{"Weighted" if metric.weighted else ""}'
f'{metric.__class__.__name__}@{metric.k}'):
metric
for metric in metrics
}
assert len(metrics) > 0
Expand Down Expand Up @@ -301,6 +348,10 @@ def update( # type: ignore
data.edge_label_weight = None
if hasattr(data, '_pred_rel_mat'):
data._pred_rel_mat = data._pred_rel_mat != 0.0
if hasattr(data, '_label_weight_sum'):
del data._label_weight_sum
if hasattr(data, '_edge_label_weight_pos'):
del data._edge_label_weight_pos

for metric in self.values():
if not metric.weighted:
Expand Down Expand Up @@ -343,11 +394,14 @@ class LinkPredRecall(LinkPredMetric):
k (int): The number of top-:math:`k` predictions to evaluate against.
"""
higher_is_better: bool = True
weighted: bool = False

def __init__(self, k: int, weighted: bool = False):
super().__init__(k=k)
self.weighted = weighted

def _compute(self, data: LinkPredMetricData) -> Tensor:
pred_rel_mat = data.pred_rel_mat[:, :self.k]
return pred_rel_mat.sum(dim=-1) / data.label_count.clamp(min=1e-7)
return pred_rel_mat.sum(dim=-1) / data.label_weight_sum.clamp(min=1e-7)


class LinkPredF1(LinkPredMetric):
Expand Down Expand Up @@ -397,7 +451,6 @@ class LinkPredNDCG(LinkPredMetric):
:obj:`edge_label_weight`. (default: :obj:`False`)
"""
higher_is_better: bool = True
weighted: bool = False

def __init__(self, k: int, weighted: bool = False):
super().__init__(k=k)
Expand All @@ -424,26 +477,18 @@ def _compute(self, data: LinkPredMetricData) -> Tensor:
idcg = self.idcg[data.label_count.clamp(max=self.k)]
else:
assert data.edge_label_weight is not None
# Sort weights within example-wise buckets via two sorts to get the
# local index order within buckets:
weight, batch = data.edge_label_weight, data.edge_label_index[0]
perm1 = weight.argsort(descending=True)
perm2 = batch[perm1].argsort(stable=True)
global_index = torch.empty_like(perm1)
global_index[perm1[perm2]] = torch.arange(
global_index.size(0), device=global_index.device)
local_index = global_index - cumsum(data.label_count)[batch]

# Get the discount per local index:
pos = data.edge_label_weight_pos
assert pos is not None

discount = torch.cat([
self.discount,
self.discount.new_full((1, ), fill_value=float('inf')),
])
discount = discount[local_index.clamp(max=self.k + 1)]
discount = discount[pos.clamp(max=self.k + 1)]

idcg = scatter( # Apply discount and aggregate:
weight / discount,
batch,
data.edge_label_weight / discount,
data.edge_label_index[0],
dim_size=data.pred_index_mat.size(0),
reduce='sum',
)
Expand Down

0 comments on commit 50f5622

Please sign in to comment.