-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathneural_collaborative_filtering.py
30 lines (27 loc) · 1.15 KB
/
neural_collaborative_filtering.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
class NeuralCollaborativeFiltering(torch.nn.Module):
    """
    A pytorch implementation of Neural Collaborative Filtering.

    Every input field is embedded; the flattened embeddings go through an
    MLP, the user and item embeddings are multiplied element-wise (the
    ``gmf`` branch), and both branches are concatenated and mapped to a
    single sigmoid-activated score per sample.

    Reference:
        X He, et al. Neural Collaborative Filtering, 2017.
    """

    def __init__(self, field_dims, user_field_idx, item_field_idx, embed_dim, mlp_dims, dropout):
        """
        :param field_dims: per-field sizes, forwarded to ``FeaturesEmbedding``;
            ``len(field_dims)`` is the number of input fields
        :param user_field_idx: index selecting the user field along dim 1 of
            the embedded input (an int or a length-1 index)
        :param item_field_idx: index selecting the item field along dim 1 of
            the embedded input (an int or a length-1 index)
        :param embed_dim: embedding width of a single field
        :param mlp_dims: forwarded to ``MultiLayerPerceptron``; its last entry
            is taken as the width of the MLP output
        :param dropout: forwarded to ``MultiLayerPerceptron``
        """
        super().__init__()
        self.user_field_idx = user_field_idx
        self.item_field_idx = item_field_idx
        # NOTE(review): FeaturesEmbedding and MultiLayerPerceptron are not
        # imported in the visible part of this file -- confirm they are in scope.
        self.embedding = FeaturesEmbedding(field_dims, embed_dim)
        self.embed_output_dim = len(field_dims) * embed_dim
        self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout, output_layer=False)
        # The final layer consumes the MLP output concatenated with the
        # element-wise user*item product, hence ``mlp_dims[-1] + embed_dim``.
        self.fc = torch.nn.Linear(mlp_dims[-1] + embed_dim, 1)

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)`` where
            ``num_fields == len(field_dims)``
        :return: tensor of size ``(batch_size,)`` holding sigmoid-activated scores
        """
        embed_x = self.embedding(x)
        embed_dim = embed_x.size(-1)
        # Reshape instead of ``squeeze(1)``: an integer index already drops the
        # field axis, and squeezing would then remove the embedding axis when
        # ``embed_dim == 1``. Reshaping handles int and length-1 indices alike.
        user_x = embed_x[:, self.user_field_idx].reshape(-1, embed_dim)
        item_x = embed_x[:, self.item_field_idx].reshape(-1, embed_dim)
        # ``reshape`` (rather than ``view``) also accepts non-contiguous input.
        mlp_out = self.mlp(embed_x.reshape(-1, self.embed_output_dim))
        gmf = user_x * item_x
        combined = torch.cat([gmf, mlp_out], dim=1)
        logits = self.fc(combined).squeeze(1)
        return torch.sigmoid(logits)