forked from ggerganov/llama.cpp
-
Notifications
You must be signed in to change notification settings - Fork 387
/
model_adapter.h
120 lines (100 loc) · 3.51 KB
/
model_adapter.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#pragma once
#include <cassert>
#include <cstring>
#include <fstream>
#include <regex>
#include <iostream>
#include <iterator>
#include <queue>
#include <string>
#include <math.h>
#include <vector>
#include "expose.h"
// Identifies the file format of a model file.
// Every enumerator has an explicit numeric value; each model family
// occupies its own block of values (0-99, 100s, 200s, 300s, 400s, 500s).
enum FileFormat
{
    BADFORMAT = 0,    // unknown, uninit, or failed to load

    // --- llama family ---
    GGML         = 1, // original llama ggml, alpaca, GPT4ALL, GPTJ header
    GGHF         = 2, // llama ggmf
    GGJT         = 3, // llama ggjt
    GGJT_2       = 4, // newer llama format, unshuffled
    GGJT_3       = 5, // using 16bit scalar
    GGUF_GENERIC = 6, // GGUF (llama newest ver)

    // --- GPT-J family ---
    GPTJ_1 = 100,     // the very first super old GPTJ format
    GPTJ_2 = 101,     // pygmalion, uses old ggml lib
    GPTJ_3 = 102,     // uses new ggml lib
    GPTJ_4 = 103,     // unshuffled
    GPTJ_5 = 104,     // using 16bit scalar

    // --- GPT-2 family ---
    GPT2_1 = 200,
    GPT2_2 = 201,
    GPT2_3 = 202,     // unshuffled
    GPT2_4 = 203,     // using 16bit scalar

    // --- RWKV family ---
    RWKV_1 = 300,
    RWKV_2 = 301,

    // --- NeoX family ---
    NEOX_1 = 400,
    NEOX_2 = 401,
    NEOX_3 = 402,     // redpajama
    NEOX_4 = 403,     // unshuffled
    NEOX_5 = 404,     // unshuffled redpajama
    NEOX_6 = 405,     // using 16bit scalar
    NEOX_7 = 406,     // using 16bit scalar redpajama

    // --- MPT family ---
    MPT_1 = 500,      // first supported mpt version
};
// Architecture tag for GGUF models.
// Values are explicit and consecutive, starting from ARCH_DEFAULT = 0.
enum GGUFArch
{
    ARCH_DEFAULT = 0, // used for llama3 and other generic gguf
    ARCH_FALCON  = 1,
    ARCH_PHI     = 2,
    ARCH_MAMBA   = 3,
    ARCH_SOLAR   = 4,
    ARCH_QWEN2   = 5,
    ARCH_RWKV    = 6,
    ARCH_QWEN2VL = 7,
};
// Extra metadata that accompanies a FileFormat value; check_file_format()
// takes a pointer to one of these. Every field has a default, so a
// default-constructed instance is fully initialized.
struct FileFormatExtraMeta
{
    int n_ctx_train{2048};                               // default 2048; presumably the training context length - confirm
    int fileversion{0};                                  // default 0
    GGUFArch model_architecture{GGUFArch::ARCH_DEFAULT}; // default: generic GGUF architecture
    int n_expert_count{0};                               // default 0
};
// Data about one selected token together with a list of candidate tokens.
// The four vectors are presumably parallel (index i describes the same
// candidate in each) - confirm against the code that fills them.
//
// The scalar members carry zero defaults so that a default-constructed
// instance never holds indeterminate values (the same convention as
// FileFormatExtraMeta). The zeros match what value-initialization
// (`TopPicksData{}`) already produced.
struct TopPicksData
{
    std::string selected_token;         // text of the selected token (empty by default)
    int32_t selected_tokenid = 0;       // id of the selected token
    float selected_logprob = 0.0f;      // log-probability of the selected token
    float selected_probability = 0.0f;  // probability of the selected token
    std::vector<std::string> tokens;    // candidate token texts
    std::vector<int> tokenid;           // candidate token ids
    std::vector<float> logprobs;        // candidate log-probabilities
    std::vector<float> p;               // presumably candidate probabilities - confirm
};
// Outcome of a model load attempt; this is the return type of gpttype_load_model().
enum ModelLoadResult
{
    FAIL       = 0,
    SUCCESS    = 1,
    RETRY_LOAD = 2, // used if it's suspected that the model is an older format
};
// Text-model entry points (declarations only; definitions live elsewhere).
// load_model_inputs, generation_inputs and generation_outputs are not declared
// in this header - presumably they come from "expose.h"; confirm.
//
// Loads a model from the given inputs, file format and extra format metadata.
// All three arguments are taken by value. Returns FAIL, SUCCESS or RETRY_LOAD.
ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in_file_format, FileFormatExtraMeta file_format_meta);
// Takes the generation inputs by value and returns a generation_outputs.
// NOTE(review): presumably runs one text-generation request - confirm.
generation_outputs gpttype_generate(const generation_inputs inputs);
// NOTE(review): presumably requests that an in-progress generation stop; the
// meaning of the bool result is not visible from this header - confirm.
bool gpttype_generate_abort();
// Returns a chat template as a string (by value).
// NOTE(review): presumably the template of the currently loaded model - confirm.
std::string gpttype_get_chat_template();
// Returns a const reference to a string, so the referenced object must outlive
// the call. NOTE(review): lifetime/ownership of that string is not visible here - confirm.
const std::string & gpttype_get_pending_output();
// Converts `input` to a vector of ints (presumably token ids - confirm).
// NOTE(review): `addbos` presumably controls whether a BOS token is prepended - confirm.
std::vector<int> gpttype_get_token_arr(const std::string & input, bool addbos);
// Converts a vector of ints (presumably token ids - confirm) to a string.
// NOTE(review): `render_special` presumably controls whether special tokens are rendered as text - confirm.
std::string gpttype_detokenize(const std::vector<int> & input, bool render_special);
// Returns a vector of TopPicksData by value.
// NOTE(review): the top-level const on the by-value return type prevents callers
// from moving out of the result; dropping it would need the definition changed too.
const std::vector<TopPicksData> gpttype_get_top_picks_data();
// "sd" entry points (declarations only). The sd_* structs are not declared in
// this header - presumably they come from "expose.h"; confirm.
// NOTE(review): "sd" presumably stands for Stable Diffusion image generation - confirm.
//
// Loads a model from the given inputs (taken by value).
// NOTE(review): the bool result presumably signals success - confirm.
bool sdtype_load_model(const sd_load_model_inputs inputs);
// Takes the generation inputs by value and returns an sd_generation_outputs.
sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs);
// "whisper" entry points (declarations only). The whisper_* structs are not
// declared in this header - presumably they come from "expose.h"; confirm.
//
// Loads a model from the given inputs (taken by value).
// NOTE(review): the bool result presumably signals success - confirm.
bool whispertype_load_model(const whisper_load_model_inputs inputs);
// Takes the generation inputs by value and returns a whisper_generation_outputs.
whisper_generation_outputs whispertype_generate(const whisper_generation_inputs inputs);
// Timing helpers (declarations only).
// NOTE(review): presumably timer_start() records a start time and timer_check()
// returns the time elapsed since then; the unit of the returned double is not
// visible from this header - confirm against the implementation.
void timer_start();
double timer_check();
// Debug-printing helpers (declarations only), overloaded for int and float
// vectors, plus a variant for string vectors.
// NOTE(review): the vectors are taken by non-const reference even though the
// names suggest read-only printing - confirm the implementations do not modify them.
void print_tok_vec(std::vector<int> &embd);
void print_tok_vec(std::vector<float> &embd);
void print_vec(std::vector<std::string> &embd);
// Integer-sequence helpers (declarations only). All vector arguments are taken
// by value, so each call copies both inputs.
//
// NOTE(review): presumably returns the longest common subsequence (or
// substring) of x and y - confirm which against the implementation.
std::vector<int> LongestCommonSubseq(const std::vector<int> x, const std::vector<int> y);
// NOTE(review): presumably true when targetArray begins with searchSeq - confirm.
bool ArrStartWith(const std::vector<int> targetArray, const std::vector<int> searchSeq);
// NOTE(review): presumably the index of the first occurrence of searchSeq in
// targetArray; the not-found return value is not visible from this header - confirm.
int ArrFindIndexOf(const std::vector<int> targetArray, const std::vector<int> searchSeq);
// Returns the FileFormat for the file named by `fname` (declaration only).
// `fileformatmeta` is a pointer to a FileFormatExtraMeta.
// NOTE(review): presumably an output parameter filled in during detection, and
// BADFORMAT is presumably returned on failure; whether nullptr is accepted is
// not visible from this header - confirm.
FileFormat check_file_format(const std::string & fname, FileFormatExtraMeta * fileformatmeta);
// Declaration of ContextFastForward (defined elsewhere).
//
// Parameters:
//   current_context_tokens, embd_inp, n_past, last_n_tokens, smartcontext
//       - non-const references, so the callee may modify them.
//   nctx              - passed by value. NOTE(review): presumably the context size in tokens - confirm.
//   useSmartContext   - passed by value. NOTE(review): presumably enables use of `smartcontext` - confirm.
//   requireFullSubset - passed by value. NOTE(review): exact meaning is not visible from this header - confirm.
//
// NOTE(review): presumably reuses the already-processed prefix shared between
// the current context and the new input so it need not be re-evaluated - confirm
// against the implementation.
void ContextFastForward(std::vector<int> &current_context_tokens, std::vector<int> &embd_inp,
                        int &n_past, std::vector<int> &last_n_tokens, const int nctx, std::vector<int> &smartcontext,
                        const bool useSmartContext, const bool requireFullSubset);