From a470d5c877b6007728d7660f1925ff7a763cf755 Mon Sep 17 00:00:00 2001 From: darrylong Date: Tue, 17 Oct 2023 13:53:19 +0800 Subject: [PATCH] Optimize LightGCN Model (#531) * Generated model base from LightGCN * wip * wip example * add self-connection * refactor code * added sanity check * Changed train batch size in example to 1024 * Updated readme for example folder * Update Readme * update docs * Update block comment * WIP * Updated validation metric * Updated message handling * Added legacy lightgcn for comparison purposes * Changed to follow 'a_k = 1/(k+1)', k instead of i * Changed early stopping technique to follow NGCF * remove test_batchsize, early stop verbose to false * Changed parameters to align with paper and ngcf * refractor codes * update docstring * change param name to 'batch_size' * Fix paper reference --------- Co-authored-by: tqtg Co-authored-by: Quoc-Tuan Truong --- cornac/models/lightgcn/lightgcn.py | 159 +++++++++++++---------- cornac/models/lightgcn/recom_lightgcn.py | 115 ++++++---------- examples/lightgcn_example.py | 6 +- 3 files changed, 132 insertions(+), 148 deletions(-) diff --git a/cornac/models/lightgcn/lightgcn.py b/cornac/models/lightgcn/lightgcn.py index 22983ab3b..eedd72b42 100644 --- a/cornac/models/lightgcn/lightgcn.py +++ b/cornac/models/lightgcn/lightgcn.py @@ -1,9 +1,14 @@ import torch import torch.nn as nn +import torch.nn.functional as F import dgl import dgl.function as fn +USER_KEY = "user" +ITEM_KEY = "item" + + def construct_graph(data_set): """ Generates graph given a cornac data set @@ -14,89 +19,109 @@ def construct_graph(data_set): The data set as provided by cornac """ user_indices, item_indices, _ = data_set.uir_tuple - user_nodes, item_nodes = ( - torch.from_numpy(user_indices), - torch.from_numpy( - item_indices + data_set.total_users - ), # increment item node idx by num users - ) - u = torch.cat([user_nodes, item_nodes], dim=0) - v = torch.cat([item_nodes, user_nodes], dim=0) + data_dict = { + (USER_KEY, "user_item", ITEM_KEY): (user_indices, item_indices), + (ITEM_KEY, "item_user", USER_KEY): (item_indices, user_indices), + } + num_dict = {USER_KEY: data_set.total_users, ITEM_KEY: data_set.total_items} - g = dgl.graph((u, v), num_nodes=(data_set.total_users + data_set.total_items)) - return g + return dgl.heterograph(data_dict, num_nodes_dict=num_dict) class GCNLayer(nn.Module): - def __init__(self): + def __init__(self, norm_dict): super(GCNLayer, self).__init__() - def forward(self, graph, src_embedding, dst_embedding): - with graph.local_scope(): - inner_product = torch.cat((src_embedding, dst_embedding), dim=0) - - out_degs = graph.out_degrees().to(src_embedding.device).float().clamp(min=1) - norm_out_degs = torch.pow(out_degs, -0.5).view(-1, 1) # D^-1/2 - - inner_product = inner_product * norm_out_degs - - graph.ndata["h"] = inner_product - graph.update_all( - message_func=fn.copy_u("h", "m"), reduce_func=fn.sum("m", "h") - ) - - res = graph.ndata["h"] - - in_degs = graph.in_degrees().to(src_embedding.device).float().clamp(min=1) - norm_in_degs = torch.pow(in_degs, -0.5).view(-1, 1) # D^-1/2 - - res = res * norm_in_degs - return res + # norm + self.norm_dict = norm_dict + + def forward(self, g, feat_dict): + funcs = {} # message and reduce functions dict + # for each type of edges, compute messages and reduce them all + for srctype, etype, dsttype in g.canonical_etypes: + src, dst = g.edges(etype=(srctype, etype, dsttype)) + norm = self.norm_dict[(srctype, etype, dsttype)] + # TODO: CHECK HERE + messages = norm * feat_dict[srctype][src] # compute messages + g.edges[(srctype, etype, dsttype)].data[ + etype + ] = messages # store in edata + funcs[(srctype, etype, dsttype)] = ( + fn.copy_e(etype, "m"), + fn.sum("m", "h"), + ) # define message and reduce functions + + g.multi_update_all( + funcs, "sum" + ) # update all, reduce by first type-wisely then across different types + feature_dict = {} + for ntype in g.ntypes: + h = F.normalize(g.nodes[ntype].data["h"], dim=1, p=2) # l2 normalize + feature_dict[ntype] = h + return feature_dict class Model(nn.Module): - def __init__(self, user_size, item_size, hidden_size, num_layers=3, device=None): + def __init__(self, g, in_size, num_layers, lambda_reg, device=None): super(Model, self).__init__() - self.user_size = user_size - self.item_size = item_size - self.hidden_size = hidden_size - self.embedding_weights = self._init_weights() - self.layers = nn.ModuleList([GCNLayer() for _ in range(num_layers)]) + self.norm_dict = dict() + self.lambda_reg = lambda_reg self.device = device - def forward(self, graph): - user_embedding = self.embedding_weights["user_embedding"] - item_embedding = self.embedding_weights["item_embedding"] + for srctype, etype, dsttype in g.canonical_etypes: + src, dst = g.edges(etype=(srctype, etype, dsttype)) + dst_degree = g.in_degrees( + dst, etype=(srctype, etype, dsttype) + ).float() # obtain degrees + src_degree = g.out_degrees(src, etype=(srctype, etype, dsttype)).float() + norm = torch.pow(src_degree * dst_degree, -0.5).unsqueeze(1) # compute norm + self.norm_dict[(srctype, etype, dsttype)] = norm - for i, layer in enumerate(self.layers, start=1): - if i == 1: - embeddings = layer(graph, user_embedding, item_embedding) - else: - embeddings = layer( - graph, embeddings[: self.user_size], embeddings[self.user_size:] - ) + self.layers = nn.ModuleList([GCNLayer(self.norm_dict) for _ in range(num_layers)]) - user_embedding = user_embedding + embeddings[: self.user_size] * ( - 1 / (i + 1) - ) - item_embedding = item_embedding + embeddings[self.user_size:] * ( - 1 / (i + 1) - ) - - return user_embedding, item_embedding + self.initializer = nn.init.xavier_uniform_ - def _init_weights(self): - initializer = nn.init.xavier_uniform_ - - weights_dict = nn.ParameterDict( + # embeddings for different types of nodes + self.feature_dict = nn.ParameterDict( { - "user_embedding": nn.Parameter( - initializer(torch.empty(self.user_size, self.hidden_size)) - ), - "item_embedding": nn.Parameter( - initializer(torch.empty(self.item_size, self.hidden_size)) - ), + ntype: nn.Parameter( + self.initializer(torch.empty(g.num_nodes(ntype), in_size)) + ) + for ntype in g.ntypes } ) - return weights_dict + + def forward(self, g, users=None, pos_items=None, neg_items=None): + h_dict = {ntype: self.feature_dict[ntype] for ntype in g.ntypes} + # obtain features of each layer and concatenate them all + user_embeds = h_dict[USER_KEY] + item_embeds = h_dict[ITEM_KEY] + + for k, layer in enumerate(self.layers): + h_dict = layer(g, h_dict) + user_embeds = user_embeds + (h_dict[USER_KEY] * 1 / (k + 1)) + item_embeds = item_embeds + (h_dict[ITEM_KEY] * 1 / (k + 1)) + + u_g_embeddings = user_embeds if users is None else user_embeds[users, :] + pos_i_g_embeddings = item_embeds if pos_items is None else item_embeds[pos_items, :] + neg_i_g_embeddings = item_embeds if neg_items is None else item_embeds[neg_items, :] + + return u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings + + def loss_fn(self, users, pos_items, neg_items): + pos_scores = (users * pos_items).sum(1) + neg_scores = (users * neg_items).sum(1) + + bpr_loss = F.softplus(neg_scores - pos_scores).mean() + reg_loss = ( + (1 / 2) + * ( + torch.norm(users) ** 2 + + torch.norm(pos_items) ** 2 + + torch.norm(neg_items) ** 2 + ) + / len(users) + ) + + return bpr_loss + self.lambda_reg * reg_loss, bpr_loss, reg_loss diff --git a/cornac/models/lightgcn/recom_lightgcn.py b/cornac/models/lightgcn/recom_lightgcn.py index 669b4ea62..635fb67a3 100644 --- a/cornac/models/lightgcn/recom_lightgcn.py +++ b/cornac/models/lightgcn/recom_lightgcn.py @@ -28,21 +28,18 @@ class LightGCN(Recommender): name: string, default: 'LightGCN' The name of the recommender model. + emb_size: int, default: 64 + Size of the node embeddings. + num_epochs: int, default: 1000 - Maximum number of iterations or the number of epochs + Maximum number of iterations or the number of epochs. learning_rate: float, default: 0.001 The learning rate that determines the step size at each iteration - train_batch_size: int, default: 1024 + batch_size: int, default: 1024 Mini-batch size used for train set - test_batch_size: int, default: 100 - Mini-batch size used for test set - - hidden_dim: int, default: 64 - The embedding size of the model - num_layers: int, default: 3 Number of LightGCN Layers @@ -80,11 +77,10 @@ class LightGCN(Recommender): def __init__( self, name="LightGCN", + emb_size=64, num_epochs=1000, learning_rate=0.001, - train_batch_size=1024, - test_batch_size=100, - hidden_dim=64, + batch_size=1024, num_layers=3, early_stopping=None, lambda_reg=1e-4, @@ -93,13 +89,11 @@ def __init__( seed=2020, ): super().__init__(name=name, trainable=trainable, verbose=verbose) - + self.emb_size = emb_size self.num_epochs = num_epochs self.learning_rate = learning_rate - self.hidden_dim = hidden_dim + self.batch_size = batch_size self.num_layers = num_layers - self.train_batch_size = train_batch_size - self.test_batch_size = test_batch_size self.early_stopping = early_stopping self.lambda_reg = lambda_reg self.seed = seed @@ -135,19 +129,15 @@ def fit(self, train_set, val_set=None): if torch.cuda.is_available(): torch.cuda.manual_seed_all(self.seed) + graph = construct_graph(train_set).to(self.device) model = Model( - train_set.total_users, - train_set.total_items, - self.hidden_dim, + graph, + self.emb_size, self.num_layers, + self.lambda_reg, ).to(self.device) - graph = construct_graph(train_set).to(self.device) - - optimizer = torch.optim.Adam( - model.parameters(), lr=self.learning_rate, weight_decay=self.lambda_reg - ) - loss_fn = torch.nn.BCELoss(reduction="sum") + optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate) # model training pbar = trange( @@ -163,35 +153,26 @@ def fit(self, train_set, val_set=None): accum_loss = 0.0 for batch_u, batch_i, batch_j in tqdm( train_set.uij_iter( - batch_size=self.train_batch_size, + batch_size=self.batch_size, shuffle=True, ), desc="Epoch", - total=train_set.num_batches(self.train_batch_size), + total=train_set.num_batches(self.batch_size), leave=False, position=1, disable=not self.verbose, ): - user_embeddings, item_embeddings = model(graph) - - batch_u = torch.from_numpy(batch_u).long().to(self.device) - batch_i = torch.from_numpy(batch_i).long().to(self.device) - batch_j = torch.from_numpy(batch_j).long().to(self.device) - - user_embed = user_embeddings[batch_u] - positive_item_embed = item_embeddings[batch_i] - negative_item_embed = item_embeddings[batch_j] - - ui_scores = (user_embed * positive_item_embed).sum(dim=1) - uj_scores = (user_embed * negative_item_embed).sum(dim=1) + u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings = model( + graph, batch_u, batch_i, batch_j + ) - loss = loss_fn( - torch.sigmoid(ui_scores - uj_scores), torch.ones_like(ui_scores) + batch_loss, batch_bpr_loss, batch_reg_loss = model.loss_fn( + u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings ) - accum_loss += loss.cpu().item() + accum_loss += batch_loss.cpu().item() * len(batch_u) optimizer.zero_grad() - loss.backward() + batch_loss.backward() optimizer.step() accum_loss /= len(train_set.uir_tuple[0]) # normalize over all observations @@ -199,17 +180,16 @@ def fit(self, train_set, val_set=None): # store user and item embedding matrices for prediction model.eval() - self.U, self.V = model(graph) + u_embs, i_embs, _ = model(graph) + # we will use numpy for faster prediction in the score function, no need torch + self.U = u_embs.cpu().detach().numpy() + self.V = i_embs.cpu().detach().numpy() if self.early_stopping is not None and self.early_stop( **self.early_stopping ): break - # we will use numpy for faster prediction in the score function, no need torch - self.U = self.U.cpu().detach().numpy() - self.V = self.V.cpu().detach().numpy() - def monitor_value(self): """Calculating monitored value used for early stopping on validation set (`val_set`). This function will be called by `early_stop()` function. @@ -223,38 +203,17 @@ def monitor_value(self): if self.val_set is None: return None - import torch + from ...metrics import Recall + from ...eval_methods import ranking_eval - loss_fn = torch.nn.BCELoss(reduction="sum") - accum_loss = 0.0 - pbar = tqdm( - self.val_set.uij_iter(batch_size=self.test_batch_size), - desc="Validation", - total=self.val_set.num_batches(self.test_batch_size), - leave=False, - position=1, - disable=not self.verbose, - ) - for batch_u, batch_i, batch_j in pbar: - batch_u = torch.from_numpy(batch_u).long().to(self.device) - batch_i = torch.from_numpy(batch_i).long().to(self.device) - batch_j = torch.from_numpy(batch_j).long().to(self.device) - - user_embed = self.U[batch_u] - positive_item_embed = self.V[batch_i] - negative_item_embed = self.V[batch_j] - - ui_scores = (user_embed * positive_item_embed).sum(dim=1) - uj_scores = (user_embed * negative_item_embed).sum(dim=1) - - loss = loss_fn( - torch.sigmoid(ui_scores - uj_scores), torch.ones_like(ui_scores) - ) - accum_loss += loss.cpu().item() - pbar.set_postfix(val_loss=accum_loss) - - accum_loss /= len(self.val_set.uir_tuple[0]) - return -accum_loss # higher is better -> smaller loss is better + recall_20 = ranking_eval( + model=self, + metrics=[Recall(k=20)], + train_set=self.train_set, + test_set=self.val_set + )[0][0] + + return recall_20 # Section 4.1.2 in the paper, same strategy as NGCF. def score(self, user_idx, item_idx=None): """Predict the scores/ratings of a user for an item. diff --git a/examples/lightgcn_example.py b/examples/lightgcn_example.py index 11efb6869..48789ebea 100644 --- a/examples/lightgcn_example.py +++ b/examples/lightgcn_example.py @@ -36,10 +36,10 @@ # Instantiate the LightGCN model lightgcn = cornac.models.LightGCN( seed=123, - num_epochs=2000, + num_epochs=1000, num_layers=3, - early_stopping={"min_delta": 1e-4, "patience": 3}, - train_batch_size=256, + early_stopping={"min_delta": 1e-4, "patience": 50}, + batch_size=1024, learning_rate=0.001, lambda_reg=1e-4, verbose=True