# predict.py
# Forked from ProsusAI/finBERT.
from finbert.finbert import predict
from pytorch_pretrained_bert.modeling import BertForSequenceClassification
import argparse
from pathlib import Path
import datetime
import os
# Command-line script: read a text file, load a BERT sequence classifier from
# --model_path, and hand both to finbert's predict() so the predictions are
# written to "<text file stem>_predictions.csv" inside --output_dir.
parser = argparse.ArgumentParser(description='Sentiment analyzer')
# NOTE(review): '-a' is parsed but never read below; kept so that existing
# invocations passing it keep working.
parser.add_argument('-a', action="store_true", default=False)
# The three paths are mandatory in practice: without them the script used to
# die with a TypeError on None. Marking them required makes argparse print a
# usage error instead.
parser.add_argument('--text_path', type=str, required=True,
                    help='Path to the text file.')
parser.add_argument('--output_dir', type=str, required=True,
                    help='Where to write the results')
parser.add_argument('--model_path', type=str, required=True,
                    help='Path to classifier model')
args = parser.parse_args()

# makedirs(exist_ok=True) creates missing parent directories too and avoids
# the check-then-create race of os.path.exists() followed by os.mkdir().
os.makedirs(args.output_dir, exist_ok=True)

with open(args.text_path, 'r') as f:
    text = f.read()

# num_labels=3 must match the classification head stored at --model_path
# (presumably the three sentiment classes -- confirm against the training
# configuration).
model = BertForSequenceClassification.from_pretrained(
    args.model_path, num_labels=3, cache_dir=None)

# Name the result after the input file, e.g. "report.txt" ->
# "report_predictions.csv".
output = Path(args.text_path).stem + '_predictions.csv'
predict(text, model, write_to_csv=True,
        path=os.path.join(args.output_dir, output))