-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrunPopK.py
37 lines (26 loc) · 1.38 KB
/
runPopK.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from utils.arg_parser import extract_args_from_json
from utils.data_provider import split_dataset
from utils.reset_seed import set_seeds
from dataloaders.TestDataLoader import NoAdditionalInfoTestDataLoader
from models.PopK import PopularKSlateGeneration
from utils.experiment_builder_plain import ExperimentBuilderPlain
from torch.utils.data import DataLoader
class ExperimentBuilderPopK(ExperimentBuilderPlain):
    """Experiment builder used with the ``PopularKSlateGeneration`` model.

    Only the evaluation step is overridden; construction and the experiment
    loop are inherited unchanged from ``ExperimentBuilderPlain``.
    """

    def eval_iteration(self):
        """Run a single evaluation step and return whatever the model produces.

        ``self.model.forward`` is invoked with no arguments, so no batch data is
        handed to the model from this method.
        """
        # NOTE(review): presumably the PopK model's output does not depend on
        # the current test batch; confirm against how ExperimentBuilderPlain
        # uses the value returned here.
        forward = self.model.forward
        return forward()
def experiments_run():
    """Evaluate a ``PopularKSlateGeneration`` model for every configured slate size.

    Steps:
      1. load the configuration via ``extract_args_from_json`` and print it;
      2. seed the RNGs with ``configs['seed']``;
      3. split the data with ``split_dataset`` and wrap the test split in a
         single ``DataLoader`` that is reused by every run;
      4. for each value in ``configs['slate_size']`` build a fresh model and
         run it through ``ExperimentBuilderPopK``.
    """
    configs = extract_args_from_json()
    print(configs)
    set_seeds(configs['seed'])

    (df_train, df_test,
     df_train_matrix, df_test_matrix,
     movies_categories, titles) = split_dataset(configs)

    # NOTE(review): shuffle=True combined with drop_last=True means the samples
    # that fall into the dropped final (incomplete) batch can differ from one
    # pass over the loader to the next -- confirm this is intended for a test set.
    test_loader = DataLoader(
        NoAdditionalInfoTestDataLoader(df_test, df_test_matrix),
        batch_size=configs['test_batch_size'],
        shuffle=True,
        num_workers=4,
        drop_last=True,
    )

    for k in configs['slate_size']:
        print(f'Test for {k}')
        popk_model = PopularKSlateGeneration(
            k, df_train, df_train_matrix, configs['test_batch_size']
        )
        # The third argument is the number of columns of the training matrix.
        builder = ExperimentBuilderPopK(
            popk_model,
            test_loader,
            len(df_train_matrix.columns),
            movies_categories,
            titles,
            configs,
        )
        builder.run_experiment()
# Script entry point: run the experiments only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    experiments_run()