Utils.py
from NonResidualAttention import PromptModelSetup
from DataManagement import PromptGenerator
import transformers


def loadPaperGPT2LargeModel():
    # Sequence and prompt length limits used for the GPT2-Large setup.
    maxSeqLen, maxPromptLen = 32, 128

    # Base causal LM plus identifiers for the pre-trained prompt model and
    # post-transformation weights.
    base = 'gpt2-large'
    clmBasePath = 'gpt2-large'
    promptModelBasePath = 'Non-Residual-Prompting/GPT2-Large'
    postWeightsPath = 'Non-Residual-Prompting/GPT2-Large-Post-Transformation'

    tokenizer = transformers.AutoTokenizer.from_pretrained(base)
    model = PromptModelSetup.PromptModelSetup(base, CLMWeightsPath=clmBasePath,
                                              promptModelWeightsPath=promptModelBasePath,
                                              postWeightsPath=postWeightsPath)

    # Prompt generator bound to the tokenizer, capped at maxPromptLen tokens.
    promptGenerator = PromptGenerator.PromptGenerator(tokenizer, maxPromptLen)
    return model, tokenizer, maxSeqLen, promptGenerator


# Registry mapping model names to their loader functions.
AVAILABLE_MODELS = {
    'gpt2-large': loadPaperGPT2LargeModel
}


def loadModel(modelName):
    # Look up and invoke the loader for the requested model name.
    return AVAILABLE_MODELS[modelName]()
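

# Minimal usage sketch (an assumption, not part of the loaders above): it
# presumes the repository's NonResidualAttention and DataManagement packages
# are importable, e.g. when run from the repository root, and only calls the
# functions defined in this file.
if __name__ == '__main__':
    # 'gpt2-large' is the only key registered in AVAILABLE_MODELS.
    model, tokenizer, maxSeqLen, promptGenerator = loadModel('gpt2-large')
    print(model.__class__.__name__, maxSeqLen)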