diff --git a/.gitignore b/.gitignore
index bb30b8566..5ff0a0ffa 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,6 @@
 tests/vocab.pkl
 .idea/
+.vscode/
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/cornac/models/__init__.py b/cornac/models/__init__.py
index 7d4ef4b56..6c2ff18ab 100644
--- a/cornac/models/__init__.py
+++ b/cornac/models/__init__.py
@@ -52,6 +52,7 @@
 from .ncf import GMF
 from .ncf import MLP
 from .ncf import NeuMF
+from .ncf import GMF_PyTorch, MLP_PyTorch, NeuMF_PyTorch
 from .ngcf import NGCF
 from .nmf import NMF
 from .online_ibpr import OnlineIBPR
@@ -74,4 +75,3 @@
         "FM model is only supported on Linux.\n"
         + "Windows executable can be found at http://www.libfm.org."
     )
-
diff --git a/cornac/models/ncf/__init__.py b/cornac/models/ncf/__init__.py
index 5a94881a5..4b663b224 100644
--- a/cornac/models/ncf/__init__.py
+++ b/cornac/models/ncf/__init__.py
@@ -16,3 +16,6 @@
 from .recom_gmf import GMF
 from .recom_mlp import MLP
 from .recom_neumf import NeuMF
+from .pytorch_gmf import GMF_PyTorch
+from .pytorch_mlp import MLP_PyTorch
+from .pytorch_neumf import NeuMF_PyTorch
diff --git a/cornac/models/ncf/pytorch_gmf.py b/cornac/models/ncf/pytorch_gmf.py
new file mode 100644
index 000000000..e90757548
--- /dev/null
+++ b/cornac/models/ncf/pytorch_gmf.py
@@ -0,0 +1,267 @@
+# Copyright 2018 The Cornac Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+
+import numpy as np
+import torch
+import torch.nn as nn
+from tqdm.auto import trange
+
+from .pytorch_ncf_base import NCFBase_PyTorch
+from ...exception import ScoreException
+
+
+class GMF_PyTorch(NCFBase_PyTorch):
+    """Generalized Matrix Factorization.
+
+    Parameters
+    ----------
+    num_factors: int, optional, default: 8
+        Embedding size of MF model.
+
+    reg: float, optional, default: 0.
+        Regularization (weight decay) for user and item embeddings.
+
+    num_epochs: int, optional, default: 20
+        Number of epochs.
+
+    batch_size: int, optional, default: 256
+        Batch size.
+
+    num_neg: int, optional, default: 4
+        Number of negative instances to pair with a positive instance.
+
+    lr: float, optional, default: 0.001
+        Learning rate.
+
+    learner: str, optional, default: 'adam'
+        Specify an optimizer: adagrad, adam, rmsprop, sgd
+
+    early_stopping: {min_delta: float, patience: int}, optional, default: None
+        If `None`, no early stopping. Meaning of the arguments:
+
+        - `min_delta`: the minimum increase in monitored value on validation set to be considered as improvement, \
+          i.e. an increment of less than min_delta will count as no improvement.
+
+        - `patience`: number of epochs with no improvement after which training should be stopped.
+
+    name: string, optional, default: 'GMF-PyTorch'
+        Name of the recommender model.
+
+    trainable: boolean, optional, default: True
+        When False, the model is not trained and Cornac assumes that the model is already \
+        pre-trained.
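+
+    use_pretrain: bool, optional, default: False
+        When True, embeddings are initialized from `pretrained_GMF`.
+
+    use_NeuMF: bool, optional, default: False
+        When True, the model is used as a component of NeuMF and returns the
+        element-wise product of the embeddings instead of a prediction.
+
+    pretrained_GMF: GMF_torch, optional, default: None
+        Pre-trained GMF model, used when `use_pretrain` is True.
+
+    sinkhorn: boolean, optional, default: False
+        When True, add a Sinkhorn-divergence alignment step (via `geomloss`)
+        between the embeddings of the two user groups in `df1` and `df2`
+        after the main training loop.
+
+    alpha: float, optional, default: 1
+        Weight of the Sinkhorn loss term.
+
+    df1, df2: pandas.DataFrame, optional, default: None
+        Data frames with a `user_id` column defining the two user groups to
+        align; required when `sinkhorn` is True.
+
+    args: object, optional, default: None
+        Solver settings for the Sinkhorn loss; expected to provide `epsilon`
+        (blur) and `scaling`.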
+
+    verbose: boolean, optional, default: True
+        When True, some running logs are displayed.
+
+    seed: int, optional, default: None
+        Random seed for parameters initialization.
+
+    References
+    ----------
+    * He, X., Liao, L., Zhang, H., Nie, L., Hu, X., & Chua, T. S. (2017, April). Neural collaborative filtering. \
+      In Proceedings of the 26th international conference on world wide web (pp. 173-182).
+    """
+
+    def __init__(
+        self,
+        name="GMF-PyTorch",
+        num_factors=8,
+        num_epochs=20,
+        batch_size=256,
+        num_neg=4,
+        lr=1e-3,
+        reg=0.0,
+        learner="adam",
+        early_stopping=None,
+        trainable=True,
+        verbose=True,
+        seed=None,
+        use_pretrain: bool = False,
+        use_NeuMF: bool = False,
+        pretrained_GMF=None,
+        sinkhorn=False,
+        alpha=1,
+        df1=None,
+        df2=None,
+        args=None,
+    ):
+        super().__init__(
+            name=name,
+            num_factors=num_factors,
+            trainable=trainable,
+            verbose=verbose,
+            num_epochs=num_epochs,
+            batch_size=batch_size,
+            num_neg=num_neg,
+            lr=lr,
+            reg=reg,
+            learner=learner,
+            early_stopping=early_stopping,
+            seed=seed,
+            use_pretrain=use_pretrain,
+            use_NeuMF=use_NeuMF,
+            pretrained_GMF=pretrained_GMF,
+        )
+
+        self.sinkhorn = sinkhorn
+        self.alpha = alpha
+        self.df1 = df1
+        self.df2 = df2
+        self.args = args
+
+    def fit(self, train_set, val_set=None):
+        """Fit the model to observations.
+
+        Parameters
+        ----------
+        train_set: :obj:`cornac.data.Dataset`, required
+            User-Item preference data as well as additional modalities.
+
+        val_set: :obj:`cornac.data.Dataset`, optional, default: None
+            User-Item preference data for model selection purposes (e.g., early stopping).
+
+        Returns
+        -------
+        self : object
+        """
+        super().fit(train_set, val_set)
+
+        if self.trainable is False:
+            return self
+
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = device
+        if self.seed is not None:
+            torch.manual_seed(self.seed)
+            np.random.seed(self.seed)
+            if torch.cuda.is_available():
+                torch.cuda.manual_seed(self.seed)
+
+        from .pytorch_ncf_base import GMF_torch as GMF
+
+        self.model = GMF(
+            self.num_users,
+            self.num_items,
+            self.num_factors,
+            self.use_pretrain,
+            self.use_NeuMF,
+            self.pretrained_GMF,
+        ).to(self.device)
+
+        # squared error on the binarized interactions, summed over the batch
+        criteria = nn.MSELoss(reduction="sum")
+        optimizer = self.learner(
+            self.model.parameters(),
+            lr=self.lr,
+            weight_decay=self.reg,
+        )
+
+        loop = trange(self.num_epochs, disable=not self.verbose)
+        for _ in loop:
+            count = 0
+            sum_loss = 0
+            for batch_id, (batch_users, batch_items, batch_ratings) in enumerate(
+                self.train_set.uir_iter(
+                    self.batch_size, shuffle=True, binary=True, num_zeros=self.num_neg
+                )
+            ):
+                batch_users = torch.from_numpy(batch_users).to(self.device)
+                batch_items = torch.from_numpy(batch_items).to(self.device)
+                batch_ratings = torch.tensor(batch_ratings, dtype=torch.float).to(
+                    self.device
+                )
+
+                optimizer.zero_grad()
+                outputs = self.model(batch_users, batch_items)
+                loss = criteria(outputs, batch_ratings)
+                loss.backward()
+                optimizer.step()
+
+                count += len(batch_users)
+                sum_loss += loss.item()
+
+                if batch_id % 10 == 0:
+                    loop.set_postfix(loss=(sum_loss / count))
+
+        if self.sinkhorn:
+            df1 = self.df1
+            df2 = self.df2
+            args = self.args
+            assert df1 is not None and df2 is not None
+            import geomloss
+
+            uid_df1 = df1["user_id"].unique()
+            uid_df2 = df2["user_id"].unique()
+            uidx_1 = torch.tensor([train_set.uid_map[key] for key in uid_df1]).to(
+                device
+            )
+            uidx_2 = torch.tensor([train_set.uid_map[key] for key in uid_df2]).to(
+                device
+            )
+            sinkhorn_loss = geomloss.SamplesLoss(
+                loss="sinkhorn",
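+                # geomloss.SamplesLoss settings: `p` is the ground-cost
+                # exponent, `blur` the entropic-regularization scale, and
+                # `scaling` the annealing rate of the eps-scaling loop.
+                # `args.epsilon` and `args.scaling` are assumed to be set
+                # by the caller.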
+                p=1,
+                blur=args.epsilon,
+                scaling=args.scaling,
+            )
+            # Sinkhorn divergence between the two groups' user embeddings
+            # (`user_embedding` is the embedding table of GMF_torch)
+            l_s = self.alpha * sinkhorn_loss(
+                self.model.user_embedding(uidx_1), self.model.user_embedding(uidx_2)
+            )
+            optimizer.zero_grad()
+            l_s.backward()
+            optimizer.step()
+
+        return self
+
+    def score(self, user_idx, item_idx=None):
+        """Predict the scores/ratings of a user for an item.
+
+        Parameters
+        ----------
+        user_idx: int, required
+            The index of the user for whom to perform score prediction.
+
+        item_idx: int, optional, default: None
+            The index of the item for which to perform score prediction.
+            If None, scores for all known items will be returned.
+
+        Returns
+        -------
+        res : A scalar or a Numpy array
+            Relative scores that the user gives to the item or to all known items
+        """
+        if item_idx is None:
+            if self.train_set.is_unk_user(user_idx):
+                raise ScoreException(
+                    "Can't make score prediction for (user_id=%d)" % user_idx
+                )
+
+            item_ids = torch.from_numpy(np.arange(self.train_set.num_items)).to(
+                self.device
+            )
+            user_ids = torch.tensor(user_idx).unsqueeze(0).to(self.device)
+
+            known_item_scores = self.model.predict(user_ids, item_ids).squeeze()
+            return known_item_scores.cpu().numpy()
+        else:
+            if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item(
+                item_idx
+            ):
+                raise ScoreException(
+                    "Can't make score prediction for (user_id=%d, item_id=%d)"
+                    % (user_idx, item_idx)
+                )
+
+            user_ids = torch.tensor(user_idx).unsqueeze(0).to(self.device)
+            item_ids = torch.tensor(item_idx).unsqueeze(0).to(self.device)
+            user_pred = self.model.predict(user_ids, item_ids).squeeze()
+            return user_pred.cpu().numpy()
diff --git a/cornac/models/ncf/pytorch_mlp.py b/cornac/models/ncf/pytorch_mlp.py
new file mode 100644
index 000000000..0294aa6be
--- /dev/null
+++ b/cornac/models/ncf/pytorch_mlp.py
@@ -0,0 +1,227 @@
+# Copyright 2018 The Cornac Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import torch
+import torch.nn as nn
+from tqdm.auto import trange
+
+from .pytorch_ncf_base import NCFBase_PyTorch
+from ...exception import ScoreException
+
+
+class MLP_PyTorch(NCFBase_PyTorch):
+    """Multi-Layer Perceptron.
+
+    Parameters
+    ----------
+    layers: list, optional, default: [64, 32, 16, 8]
+        MLP layers. Note that the first layer is the concatenation of
+        user and item embeddings. So layers[0]/2 is the embedding size.
+
+    act_fn: str, default: 'relu'
+        Name of the activation function used for the MLP layers.
+        Supported functions: ['sigmoid', 'tanh', 'elu', 'relu', 'selu', 'relu6', 'leaky_relu']
+
+    reg: float, optional, default: 0.
+        Regularization (weight decay) applied to the embeddings and MLP layers.
+
+    num_epochs: int, optional, default: 20
+        Number of epochs.
+
+    batch_size: int, optional, default: 256
+        Batch size.
+
+    num_neg: int, optional, default: 4
+        Number of negative instances to pair with a positive instance.
+
+    lr: float, optional, default: 0.001
+        Learning rate.
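+
+    use_pretrain: bool, optional, default: False
+        When True, embeddings and MLP weights are initialized from `pretrained_MLP`.
+
+    use_NeuMF: bool, optional, default: False
+        When True, the model is used as a component of NeuMF and returns the
+        last MLP layer's output instead of a prediction.
+
+    pretrained_MLP: MLP_torch, optional, default: None
+        Pre-trained MLP model, used when `use_pretrain` is True.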
+
+    learner: str, optional, default: 'adam'
+        Specify an optimizer: adagrad, adam, rmsprop, sgd
+
+    early_stopping: {min_delta: float, patience: int}, optional, default: None
+        If `None`, no early stopping. Meaning of the arguments:
+
+        - `min_delta`: the minimum increase in monitored value on validation set to be considered as improvement, \
+          i.e. an increment of less than min_delta will count as no improvement.
+
+        - `patience`: number of epochs with no improvement after which training should be stopped.
+
+    name: string, optional, default: 'MLP-PyTorch'
+        Name of the recommender model.
+
+    trainable: boolean, optional, default: True
+        When False, the model is not trained and Cornac assumes that the model is already \
+        pre-trained.
+
+    verbose: boolean, optional, default: True
+        When True, some running logs are displayed.
+
+    seed: int, optional, default: None
+        Random seed for parameters initialization.
+
+    References
+    ----------
+    * He, X., Liao, L., Zhang, H., Nie, L., Hu, X., & Chua, T. S. (2017, April). Neural collaborative filtering. \
+      In Proceedings of the 26th international conference on world wide web (pp. 173-182).
+    """
+
+    def __init__(
+        self,
+        name="MLP-PyTorch",
+        num_factors=8,
+        layers=(64, 32, 16, 8),
+        act_fn="relu",
+        num_epochs=20,
+        batch_size=256,
+        num_neg=4,
+        lr=0.001,
+        reg=0.0,
+        learner="adam",
+        early_stopping=None,
+        trainable=True,
+        verbose=True,
+        seed=None,
+        use_pretrain: bool = False,
+        use_NeuMF: bool = False,
+        pretrained_MLP=None,
+    ):
+        super().__init__(
+            name=name,
+            num_factors=num_factors,
+            layers=layers,
+            act_fn=act_fn,
+            num_epochs=num_epochs,
+            batch_size=batch_size,
+            num_neg=num_neg,
+            lr=lr,
+            reg=reg,
+            learner=learner,
+            early_stopping=early_stopping,
+            trainable=trainable,
+            verbose=verbose,
+            seed=seed,
+            use_pretrain=use_pretrain,
+            use_NeuMF=use_NeuMF,
+            pretrained_MLP=pretrained_MLP,
+        )
+
+    def fit(self, train_set, val_set=None):
+        """Fit the model to observations (see `NCFBase_PyTorch.fit`)."""
+        super().fit(train_set, val_set)
+
+        if self.trainable is False:
+            return self
+
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = device
+        if self.seed is not None:
+            torch.manual_seed(self.seed)
+            np.random.seed(self.seed)
+            if torch.cuda.is_available():
+                torch.cuda.manual_seed(self.seed)
+
+        from .pytorch_ncf_base import MLP_torch as MLP
+
+        self.model = MLP(
+            num_users=self.num_users,
+            num_items=self.num_items,
+            num_factors=self.num_factors,
+            layers=self.layers,
+            act_fn=self.act_fn,
+            use_pretrain=self.use_pretrain,
+            use_NeuMF=self.use_NeuMF,
+            pretrained_MLP=self.pretrained_MLP,
+        ).to(self.device)
+
+        criteria = nn.BCELoss()
+        optimizer = self.learner(
+            self.model.parameters(),
+            lr=self.lr,
+            weight_decay=self.reg,
+        )
+
+        loop = trange(self.num_epochs, disable=not self.verbose)
+        for _ in loop:
+            count = 0
+            sum_loss = 0
+            for batch_id, (batch_users, batch_items, batch_ratings) in enumerate(
+                self.train_set.uir_iter(
+                    self.batch_size, shuffle=True, binary=True, num_zeros=self.num_neg
+                )
+            ):
+                batch_users = torch.from_numpy(batch_users).to(self.device)
+                batch_items = torch.from_numpy(batch_items).to(self.device)
+                batch_ratings = torch.tensor(batch_ratings, dtype=torch.float).to(
+                    self.device
+                )
+
+                optimizer.zero_grad()
+                outputs = self.model(batch_users, batch_items)
+                loss = criteria(outputs, batch_ratings)
+                loss.backward()
+                optimizer.step()
+
+                count += len(batch_users)
+                sum_loss += loss.item()
+
+                if batch_id % 10 == 0:
+                    loop.set_postfix(loss=(sum_loss / count))
+
+        return self
+
+    def score(self, user_idx, item_idx=None):
+        """Predict the scores/ratings of a user for an item.
+
+        Parameters
+        ----------
+        user_idx: int, required
+            The index of the user for whom to perform score prediction.
+
+        item_idx: int, optional, default: None
+            The index of the item for which to perform score prediction.
+            If None, scores for all known items will be returned.
+
+        Returns
+        -------
+        res : A scalar or a Numpy array
+            Relative scores that the user gives to the item or to all known items
+        """
+        if item_idx is None:
+            if self.train_set.is_unk_user(user_idx):
+                raise ScoreException(
+                    "Can't make score prediction for (user_id=%d)" % user_idx
+                )
+
+            item_ids = torch.from_numpy(np.arange(self.train_set.num_items)).to(
+                self.device
+            )
+            user_ids = torch.tensor(user_idx).unsqueeze(0).to(self.device)
+
+            known_item_scores = self.model.predict(user_ids, item_ids).squeeze()
+            return known_item_scores.cpu().numpy()
+        else:
+            if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item(
+                item_idx
+            ):
+                raise ScoreException(
+                    "Can't make score prediction for (user_id=%d, item_id=%d)"
+                    % (user_idx, item_idx)
+                )
+
+            user_ids = torch.tensor(user_idx).unsqueeze(0).to(self.device)
+            item_ids = torch.tensor(item_idx).unsqueeze(0).to(self.device)
+            user_pred = self.model.predict(user_ids, item_ids).squeeze()
+            return user_pred.cpu().numpy()
diff --git a/cornac/models/ncf/pytorch_ncf_base.py b/cornac/models/ncf/pytorch_ncf_base.py
new file mode 100644
index 000000000..211f09567
--- /dev/null
+++ b/cornac/models/ncf/pytorch_ncf_base.py
@@ -0,0 +1,429 @@
+# Copyright 2018 The Cornac Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ..recommender import Recommender
+from ...utils import get_rng
+
+
+class NCFBase_PyTorch(Recommender):
+    """
+    Parameters
+    ----------
+    num_epochs: int, optional, default: 20
+        Number of epochs.
+
+    batch_size: int, optional, default: 256
+        Batch size.
+
+    num_neg: int, optional, default: 4
+        Number of negative instances to pair with a positive instance.
+
+    lr: float, optional, default: 0.001
+        Learning rate.
+
+    learner: str, optional, default: 'adam'
+        Specify an optimizer: adagrad, adam, rmsprop, sgd
+
+    early_stopping: {min_delta: float, patience: int}, optional, default: None
+        If `None`, no early stopping. Meaning of the arguments:
+
+        - `min_delta`: the minimum increase in monitored value on validation set to be considered as improvement, \
+          i.e. an increment of less than min_delta will count as no improvement.
+        - `patience`: number of epochs with no improvement after which training should be stopped.
+
+    name: string, optional, default: 'NCF'
+        Name of the recommender model.
+
+    trainable: boolean, optional, default: True
+        When False, the model is not trained and Cornac assumes that the model is already \
+        pre-trained.
+
+    verbose: boolean, optional, default: True
+        When True, some running logs are displayed.
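+
+    seed: int, optional, default: None
+        Random seed for parameters initialization.
+
+    use_pretrain: bool, optional, default: False
+        When True, model weights are initialized from the given pre-trained
+        models.
+
+    use_NeuMF: bool, optional, default: False
+        When True, the model is used as a component of NeuMF.
+
+    pretrained_GMF: GMF_torch, optional, default: None
+        Pre-trained GMF model.
+
+    pretrained_MLP: MLP_torch, optional, default: None
+        Pre-trained MLP model.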
+ """ + + def __init__( + self, + name="NCF", + num_factors=8, + layers=None, + act_fn="relu", + num_epochs=20, + batch_size=256, + num_neg=4, + lr=1e-3, + reg=0.0, + learner="adam", + early_stopping=None, + trainable=True, + verbose=True, + seed=None, + use_pretrain: bool = False, + use_NeuMF: bool = False, + pretrained_GMF=None, + pretrained_MLP=None, + ): + super().__init__(name=name, trainable=trainable, verbose=verbose) + self.num_factors = num_factors + self.layers = layers + self.num_epochs = num_epochs + self.batch_size = batch_size + self.num_neg = num_neg + self.lr = lr + self.reg = reg + self.early_stopping = early_stopping + self.seed = seed + self.rng = get_rng(seed) + self.use_pretrain = use_pretrain + self.use_NeuMF = use_NeuMF + self.pretrained_GMF = pretrained_GMF + self.pretrained_MLP = pretrained_MLP + + optimizer = { + "sgd": torch.optim.SGD, + "adam": torch.optim.Adam, + "rmsprop": torch.optim.RMSprop, + "adagrad": torch.optim.Adagrad, + } + self.learner = optimizer[learner.lower()] + + activation_functions = { + "sigmoid": nn.Sigmoid(), + "tanh": nn.Tanh(), + "elu": nn.ELU(), + "selu": nn.SELU(), + "relu": nn.ReLU(), + "relu6": nn.ReLU6(), + "leakyrelu": nn.LeakyReLU(), + } + self.act_fn = activation_functions[act_fn.lower()] + + def fit(self, train_set, val_set=None): + """Fit the model to observations. + + Parameters + ---------- + train_set: :obj:`cornac.data.Dataset`, required + User-Item preference data as well as additional modalities. + + val_set: :obj:`cornac.data.Dataset`, optional, default: None + User-Item preference data for model selection purposes (e.g., early stopping). + + Returns + ------- + self : object + """ + Recommender.fit(self, train_set, val_set) + + if self.trainable: + if not hasattr(self, "user_embedding"): + self.num_users = self.train_set.num_users + self.num_items = self.train_set.num_items + + return self + + def score(self, user_idx, item_idx=None): + """Predict the scores/ratings of a user for an item. + + Parameters + ---------- + user_idx: int, required + The index of the user for whom to perform score prediction. + + item_idx: int, optional, default: None + The index of the item for which to perform score prediction. + If None, scores for all known items will be returned. + + Returns + ------- + res : A scalar or a Numpy array + Relative scores that the user gives to the item or to all known items + + """ + raise NotImplementedError("The algorithm is not able to make score prediction!") + + def monitor_value(self): + """Calculating monitored value used for early stopping on validation set (`val_set`). + This function will be called by `early_stop()` function. + + Returns + ------- + res : float + Monitored value on validation set. + Return `None` if `val_set` is `None`. 
+ """ + if self.val_set is None: + return None + + from ...metrics import Recall + from ...eval_methods import ranking_eval + + recall_20 = ranking_eval( + model=self, + metrics=[Recall(k=20)], + train_set=self.train_set, + test_set=self.val_set, + )[0][0] + + return recall_20 + + +class GMF_torch(nn.Module): + def __init__( + self, + num_users: int, + num_items: int, + num_factors: int = 8, + use_pretrain: bool = False, + use_NeuMF: bool = False, + pretrained_GMF=None, + ): + super(GMF_torch, self).__init__() + + self.pretrained_GMF = pretrained_GMF + self.num_users = num_users + self.num_items = num_items + self.use_pretrain = use_pretrain + self.use_NeuMF = use_NeuMF + + self.pretrained_GMF = pretrained_GMF + + self.user_embedding = nn.Embedding(num_users, num_factors) + self.item_embedding = nn.Embedding(num_items, num_factors) + + self.predict_layer = nn.Linear(num_factors, 1) + self.Sigmoid = nn.Sigmoid() + + if use_pretrain: + self._load_pretrained_model() + else: + self._init_weight() + + def _init_weight(self): + if not self.use_pretrain: + nn.init.normal_(self.user_embedding.weight, std=1e-2) + nn.init.normal_(self.item_embedding.weight, std=1e-2) + if not self.use_NeuMF: + nn.init.normal_(self.predict_layer.weight, std=1e-2) + + def _load_pretrained_model(self): + self.user_embedding.weight.data.copy_(self.pretrained_GMF.user_embedding.weight) + self.item_embedding.weight.data.copy_(self.pretrained_GMF.item_embedding.weight) + + def forward(self, users, items): + embedding_elementwise = self.user_embedding(users) * self.item_embedding(items) + if not self.use_NeuMF: + output = self.predict_layer(embedding_elementwise) + output = self.Sigmoid(output) + output = output.view(-1) + else: + output = embedding_elementwise + + return output + + def predict(self, users, items): + with torch.no_grad(): + preds = (self.user_embedding(users) * self.item_embedding(items)).sum( + dim=1, keepdim=True + ) + return preds.squeeze() + + +class MLP_torch(nn.Module): + def __init__( + self, + num_users: int, + num_items: int, + num_factors: int = 8, + layers=None, + act_fn=nn.ReLU(), + use_pretrain: bool = False, + use_NeuMF: bool = False, + pretrained_MLP=None, + ): + super(MLP_torch, self).__init__() + + if layers is None: + layers = [64, 32, 16, 8] + + self.pretrained_MLP = pretrained_MLP + self.num_users = num_users + self.num_items = num_items + self.use_pretrain = use_pretrain + self.user_embedding = nn.Embedding(num_users, layers[0] // 2) + self.item_embedding = nn.Embedding(num_items, layers[0] // 2) + self.use_NeuMF = use_NeuMF + MLP_layers = [] + + for idx, factor in enumerate(layers[:-1]): + # ith MLP layer (layer[i],layer[i]//2) -> #(i+1)th MLP layer (layer[i+1],layer[i+1]//2) + # ex) (64,32) -> (32,16) -> (16,8) + # MLP_layers.append(nn.Linear(factor, factor // 2)) + + MLP_layers.append(nn.Linear(factor, layers[idx + 1])) + MLP_layers.append(act_fn) + + # unpacking layers in to torch.nn.Sequential + self.MLP_model = nn.Sequential(*MLP_layers) + + self.predict_layer = nn.Linear(num_factors, 1) + self.Sigmoid = nn.Sigmoid() + + if self.use_pretrain: + self._load_pretrained_model() + else: + self._init_weight() + + def _init_weight(self): + if not self.use_pretrain: + nn.init.normal_(self.user_embedding.weight, std=1e-2) + nn.init.normal_(self.item_embedding.weight, std=1e-2) + for layer in self.MLP_model: + if isinstance(layer, nn.Linear): + nn.init.xavier_uniform_(layer.weight) + if not self.use_NeuMF: + nn.init.normal_(self.predict_layer.weight, std=1e-2) + + def 
+    def _load_pretrained_model(self):
+        self.user_embedding.weight.data.copy_(self.pretrained_MLP.user_embedding.weight)
+        self.item_embedding.weight.data.copy_(self.pretrained_MLP.item_embedding.weight)
+        for layer, pretrained_layer in zip(
+            self.MLP_model, self.pretrained_MLP.MLP_model
+        ):
+            if isinstance(layer, nn.Linear) and isinstance(pretrained_layer, nn.Linear):
+                layer.weight.data.copy_(pretrained_layer.weight)
+                layer.bias.data.copy_(pretrained_layer.bias)
+
+    def forward(self, user, item):
+        embed_user = self.user_embedding(user)
+        embed_item = self.item_embedding(item)
+        embed_input = torch.cat((embed_user, embed_item), dim=-1)
+        output = self.MLP_model(embed_input)
+
+        if not self.use_NeuMF:
+            output = self.predict_layer(output)
+            output = self.Sigmoid(output)
+            output = output.view(-1)
+
+        return output
+
+    def predict(self, users, items):
+        with torch.no_grad():
+            embed_user = self.user_embedding(users)
+            if len(users) == 1:
+                # replicate the user embedding to match len(items)
+                embed_user = embed_user.repeat(len(items), 1)
+            embed_item = self.item_embedding(items)
+            embed_input = torch.cat((embed_user, embed_item), dim=-1)
+            output = self.MLP_model(embed_input)
+
+            output = self.predict_layer(output)
+            output = self.Sigmoid(output)
+            output = output.view(-1)
+        return output
+
+
+class NeuMF_torch(nn.Module):
+    def __init__(
+        self,
+        num_users: int,
+        num_items: int,
+        num_factors: int = 8,
+        layers=None,  # layers for the MLP component
+        act_fn=nn.ReLU(),
+        use_pretrain: bool = False,
+        pretrained_GMF=None,
+        pretrained_MLP=None,
+    ):
+        super(NeuMF_torch, self).__init__()
+
+        self.use_pretrain = use_pretrain
+        self.pretrained_GMF = pretrained_GMF
+        self.pretrained_MLP = pretrained_MLP
+
+        # layers for the MLP component
+        if layers is None:
+            layers = [64, 32, 16, 8]
+
+        # the head consumes the GMF vector (num_factors units) concatenated
+        # with the last MLP layer's output (layers[-1] units)
+        self.predict_layer = nn.Linear(num_factors + layers[-1], 1)
+        self.Sigmoid = nn.Sigmoid()
+
+        self.GMF = GMF_torch(
+            num_users,
+            num_items,
+            num_factors,
+            use_pretrain=use_pretrain,
+            use_NeuMF=True,
+            pretrained_GMF=self.pretrained_GMF,
+        )
+        self.MLP = MLP_torch(
+            num_users=num_users,
+            num_items=num_items,
+            num_factors=num_factors,
+            layers=layers,
+            act_fn=act_fn,
+            use_pretrain=use_pretrain,
+            use_NeuMF=True,
+            pretrained_MLP=self.pretrained_MLP,
+        )
+
+        if self.use_pretrain:
+            self._load_pretrain_model()
+
+        if not self.use_pretrain:
+            nn.init.normal_(self.predict_layer.weight, std=1e-2)
+
+    def _load_pretrain_model(self):
+        predict_weight = torch.cat(
+            [
+                self.pretrained_GMF.predict_layer.weight,
+                self.pretrained_MLP.predict_layer.weight,
+            ],
+            dim=1,
+        )
+        predict_bias = (
+            self.pretrained_GMF.predict_layer.bias
+            + self.pretrained_MLP.predict_layer.bias
+        )
+        self.predict_layer.weight.data.copy_(0.5 * predict_weight)
+        self.predict_layer.bias.data.copy_(0.5 * predict_bias)
+
+    def forward(self, user, item):
+        before_last_layer_output = torch.cat(
+            (self.GMF(user, item), self.MLP(user, item)), dim=-1
+        )
+        output = self.predict_layer(before_last_layer_output)
+        output = self.Sigmoid(output)
+        return output.view(-1)
+
+    def predict(self, users, items):
+        with torch.no_grad():
+            if len(users) == 1:
+                # replicate the single user index to match len(items)
+                users = users.repeat(len(items))
+            before_last_layer_output = torch.cat(
+                (self.GMF(users, items), self.MLP(users, items)), dim=-1
+            )
+            preds = self.predict_layer(before_last_layer_output)
+            preds = self.Sigmoid(preds)
+            return preds.view(-1)
diff --git a/cornac/models/ncf/pytorch_neumf.py b/cornac/models/ncf/pytorch_neumf.py
new file mode 100644
index 000000000..1584c5bf2
--- /dev/null
+++ b/cornac/models/ncf/pytorch_neumf.py
@@ -0,0 +1,232 @@
+# Copyright 2018 The Cornac Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import torch
+import torch.nn as nn
+from tqdm.auto import trange
+
+from .pytorch_ncf_base import NCFBase_PyTorch
+from ...exception import ScoreException
+
+
+class NeuMF_PyTorch(NCFBase_PyTorch):
+    """Neural Matrix Factorization.
+
+    Parameters
+    ----------
+    num_factors: int, optional, default: 8
+        Embedding size of MF model.
+
+    layers: list, optional, default: [64, 32, 16, 8]
+        MLP layers. Note that the first layer is the concatenation of
+        user and item embeddings. So layers[0]/2 is the embedding size.
+
+    act_fn: str, default: 'relu'
+        Name of the activation function used for the MLP layers.
+        Supported functions: ['sigmoid', 'tanh', 'elu', 'relu', 'selu', 'relu6', 'leaky_relu']
+
+    reg: float, optional, default: 0.
+        Regularization (weight decay) applied to both the MF and MLP parts.
+
+    num_epochs: int, optional, default: 20
+        Number of epochs.
+
+    batch_size: int, optional, default: 256
+        Batch size.
+
+    num_neg: int, optional, default: 4
+        Number of negative instances to pair with a positive instance.
+
+    lr: float, optional, default: 0.001
+        Learning rate.
+
+    learner: str, optional, default: 'adam'
+        Specify an optimizer: adagrad, adam, rmsprop, sgd
+
+    early_stopping: {min_delta: float, patience: int}, optional, default: None
+        If `None`, no early stopping. Meaning of the arguments:
+
+        - `min_delta`: the minimum increase in monitored value on validation set to be considered as improvement, \
+          i.e. an increment of less than min_delta will count as no improvement.
+
+        - `patience`: number of epochs with no improvement after which training should be stopped.
+
+    name: string, optional, default: 'NeuMF-PyTorch'
+        Name of the recommender model.
+
+    trainable: boolean, optional, default: True
+        When False, the model is not trained and Cornac assumes that the model is already \
+        pre-trained.
+
+    verbose: boolean, optional, default: True
+        When True, some running logs are displayed.
+
+    seed: int, optional, default: None
+        Random seed for parameters initialization.
+
+    References
+    ----------
+    * He, X., Liao, L., Zhang, H., Nie, L., Hu, X., & Chua, T. S. (2017, April). Neural collaborative filtering. \
+      In Proceedings of the 26th international conference on world wide web (pp. 173-182).
+ """ + + def __init__( + self, + name="NeuMF-PyTorch", + num_factors=8, + layers=(64, 32, 16, 8), + act_fn="relu", + num_epochs=20, + batch_size=256, + num_neg=4, + lr=0.001, + reg=0.0, + learner="adam", + early_stopping=None, + trainable=True, + verbose=True, + seed=None, + use_pretrain: bool = False, + pretrained_GMF=None, + pretrained_MLP=None, + ): + super().__init__( + name=name, + num_factors=num_factors, + layers=layers, + act_fn=act_fn, + num_epochs=num_epochs, + batch_size=batch_size, + num_neg=num_neg, + lr=lr, + reg=reg, + learner=learner, + early_stopping=early_stopping, + trainable=trainable, + verbose=verbose, + seed=seed, + use_pretrain=use_pretrain, + pretrained_GMF=pretrained_GMF, + pretrained_MLP=pretrained_MLP, + ) + + def fit(self, train_set, val_set=None): + super().fit(train_set, val_set) + + if self.trainable is False: + return self + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = device + if self.seed is not None: + torch.manual_seed(self.seed) + np.random.seed(self.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(self.seed) + + from .pytorch_ncf_base import NeuMF_torch as NeuMF + + self.model = NeuMF( + num_users=self.num_users, + num_items=self.num_items, + num_factors=self.num_factors, + layers=self.layers, + use_pretrain=self.use_pretrain, + pretrained_GMF=self.pretrained_GMF, + pretrained_MLP=self.pretrained_MLP, + ).to(self.device) + + criteria = nn.BCELoss() + optimizer = torch.optim.Adam( + self.model.parameters(), + lr=self.lr, + weight_decay=self.reg, + ) + + loop = trange(self.num_epochs, disable=not self.verbose) + for _ in loop: + count = 0 + sum_loss = 0 + for batch_id, (batch_users, batch_items, batch_ratings) in enumerate( + self.train_set.uir_iter( + self.batch_size, shuffle=True, binary=True, num_zeros=self.num_neg + ) + ): + batch_users = torch.from_numpy(batch_users).to(self.device) + batch_items = torch.from_numpy(batch_items).to(self.device) + batch_ratings = torch.tensor(batch_ratings, dtype=torch.float).to( + self.device + ) + + optimizer.zero_grad() + outputs = self.model(batch_users, batch_items) + loss = criteria(outputs, batch_ratings) + loss.backward() + optimizer.step() + + count += len(batch_users) + sum_loss += loss.data.item() + + if batch_id % 10 == 0: + loop.set_postfix(loss=(sum_loss / count)) + + def score(self, user_idx, item_idx=None): + """Predict the scores/ratings of a user for an item. + + Parameters + ---------- + user_idx: int, required + The index of the user for whom to perform score prediction. + + item_idx: int, optional, default: None + The index of the item for which to perform score prediction. + If None, scores for all known items will be returned. 
+
+        Returns
+        -------
+        res : A scalar or a Numpy array
+            Relative scores that the user gives to the item or to all known items
+        """
+        if item_idx is None:
+            if self.train_set.is_unk_user(user_idx):
+                raise ScoreException(
+                    "Can't make score prediction for (user_id=%d)" % user_idx
+                )
+
+            item_ids = torch.from_numpy(np.arange(self.train_set.num_items)).to(
+                self.device
+            )
+            user_ids = torch.tensor(user_idx).unsqueeze(0).to(self.device)
+
+            known_item_scores = self.model.predict(user_ids, item_ids).squeeze()
+            return known_item_scores.cpu().numpy()
+        else:
+            if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item(
+                item_idx
+            ):
+                raise ScoreException(
+                    "Can't make score prediction for (user_id=%d, item_id=%d)"
+                    % (user_idx, item_idx)
+                )
+
+            user_ids = torch.tensor(user_idx).unsqueeze(0).to(self.device)
+            item_ids = torch.tensor(item_idx).unsqueeze(0).to(self.device)
+            user_pred = self.model.predict(user_ids, item_ids).squeeze()
+            return user_pred.cpu().numpy()
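
Usage note: below is a minimal sketch of how the new classes plug into a
standard Cornac experiment. The dataset, split, and metric choices are
illustrative only and are not part of this diff.

    import cornac
    from cornac.datasets import movielens
    from cornac.eval_methods import RatioSplit
    from cornac.metrics import NDCG, Recall
    from cornac.models import GMF_PyTorch, MLP_PyTorch, NeuMF_PyTorch

    # MovieLens 100K, binarized at rating >= 4.0 for implicit-feedback ranking.
    data = movielens.load_feedback(variant="100K")
    split = RatioSplit(data=data, test_size=0.2, rating_threshold=4.0,
                       exclude_unknowns=True, seed=123, verbose=True)

    models = [
        GMF_PyTorch(num_factors=8, num_epochs=10, seed=123),
        MLP_PyTorch(layers=(64, 32, 16, 8), num_epochs=10, seed=123),
        NeuMF_PyTorch(num_factors=8, layers=(64, 32, 16, 8), num_epochs=10, seed=123),
    ]

    cornac.Experiment(eval_method=split, models=models,
                      metrics=[Recall(k=20), NDCG(k=20)]).run()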