# bert_preprocess.py
# Compute sentence-BERT embeddings for each tweet in the PAN 2019
# author-profiling corpus and cache them as JSON, keyed by account.
import json
from collections import defaultdict
from pathlib import Path

import numpy as np
from sentence_transformers import SentenceTransformer
from spacy.lang.en import English

from dataset import Dataset
# Pre-trained sentence-BERT model (mean pooling over BERT token embeddings).
model = SentenceTransformer('bert-base-nli-mean-tokens')
# PAN 2019 author-profiling corpus: train/test tweet directories plus label files.
dataset = Dataset(Path("/home/ianic/tar/pan19-author-profiling-training-2019-02-18/en"),
                  Path("/home/ianic/tar/pan19-author-profiling-training-2019-02-18/en_labels/truth.txt"),
                  Path("/home/ianic/tar/pan19-author-profiling-test-2019-04-29/en"),
                  Path("/home/ianic/tar/pan19-author-profiling-test-2019-04-29/truth.txt"))
train_data, train_labels, test_data, test_labels = dataset.get_data()
print(len(train_data), len(train_labels))
print(len(test_data), len(test_labels))
# Blank English pipeline (tokenizer only, no statistical model) with a
# rule-based sentencizer for sentence splitting.
nlp = English()
sentencizer = nlp.create_pipe("sentencizer")  # spaCy v2 API; in v3: nlp.add_pipe("sentencizer")
nlp.add_pipe(sentencizer)
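# A quick illustration on a hypothetical sample string: the sentencizer
# splits on sentence-final punctuation.
# print([str(s) for s in nlp("Great game last night. Can't wait for the rematch!").sents])
# -> ['Great game last night.', "Can't wait for the rematch!"]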
# For every account in the training split, embed each tweet: split it into
# sentences, encode each sentence with sentence-BERT, then mean-pool the
# sentence vectors into a single fixed-size tweet embedding.
tweet_embeddings = defaultdict(list)
for cnt, account in enumerate(train_data):
    print(f"account: {cnt}")
    tweets = train_data[account]
    for tweet in tweets:
        doc = nlp(tweet)
        sentences = [str(sentence) for sentence in doc.sents]
        if not sentences:  # guard against tweets that yield no sentences
            continue
        sentence_embeddings = model.encode(sentences)
        # Mean-pool sentence vectors (equivalent to the original running average).
        tweet_embed = np.mean(sentence_embeddings, axis=0)
        tweet_embeddings[account].append(tweet_embed.tolist())
# Cache the embeddings so downstream models can skip the expensive encoding step.
with open('bert_embeddings_train.json', 'w') as fp:
    json.dump(tweet_embeddings, fp)
# Sanity check: reload the cached embeddings and inspect their structure.
# `account` is whichever account the loop above processed last.
# with open('bert_embeddings_train.json', 'r') as fp:
#     tweet_embeddings = json.load(fp)
# print(type(tweet_embeddings))             # dict after the JSON round-trip
# print(type(tweet_embeddings[account]))    # list of tweet embeddings
# print(len(tweet_embeddings[account]))     # number of tweets for this account
# print(len(tweet_embeddings[account][0]))  # embedding dimensionality (768 for BERT-base)
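# The test split is loaded above but never embedded in this script. A
# symmetric pass could look like the sketch below; the output filename
# 'bert_embeddings_test.json' is an assumption chosen to mirror the train
# file, everything else reuses objects defined above.
# test_embeddings = defaultdict(list)
# for account in test_data:
#     for tweet in test_data[account]:
#         sentences = [str(s) for s in nlp(tweet).sents]
#         if sentences:
#             test_embeddings[account].append(
#                 np.mean(model.encode(sentences), axis=0).tolist())
# with open('bert_embeddings_test.json', 'w') as fp:
#     json.dump(test_embeddings, fp)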