# prompt.py
"""Interactive REPL: load a Hugging Face causal LM and inspect its
next-token distribution for arbitrary prompts."""
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils import WEIGHTS, top_vals, format_token
import torch

with torch.no_grad():
    # Load the model named on stdin, picking a sensible dtype per device.
    model = input("Model: ")
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model)
    tokenizer.pad_token = tokenizer.eos_token
    gpt = AutoModelForCausalLM.from_pretrained(
        model,
        revision="main",
        torch_dtype=WEIGHTS.get(model, torch.bfloat16) if device == "cuda:0" else torch.float32,
    ).to(device)

    # Read prompts forever; for each, show its tokenization and the
    # model's most likely next tokens.
    while True:
        text = input("Text: ")
        text = tokenizer(text, return_tensors="pt").to(device)
        print([format_token(tokenizer, i) for i in text.input_ids[0]])
        logits = gpt(**text).logits[0, -1]  # logits at the final position
        probs = logits.softmax(-1)
        top_vals(tokenizer, probs)
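
# The helpers imported above come from a local `utils` module that is not
# shown in this file. A minimal sketch of what it might look like follows;
# the names match the imports, but the bodies and signatures here are
# assumptions, not the repository's actual implementation.
#
#   # utils.py (hypothetical sketch)
#   import torch
#
#   # Preferred load dtype per model name; any model absent from this dict
#   # falls back to the default chosen in prompt.py.
#   WEIGHTS = {"gpt2": torch.float32}
#
#   def format_token(tokenizer, tok_id):
#       """Decode one token id into a readable string."""
#       return tokenizer.decode(tok_id)
#
#   def top_vals(tokenizer, probs, n=10):
#       """Print the n highest-probability tokens with their probabilities."""
#       top_probs, top_ids = probs.topk(n)
#       for p, i in zip(top_probs, top_ids):
#           print(f"{format_token(tokenizer, i)!r}: {p.item():.4f}")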