# demo.py
import argparse
import base64
import time
from mlx_engine.generate import load_model, load_draft_model, create_generator, tokenize
from mlx_engine.model_kit import VALID_KV_BITS, VALID_KV_GROUP_SIZE

DEFAULT_PROMPT = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Explain the rules of chess in one sentence.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

DEFAULT_TEMP = 0.8


def setup_arg_parser():
    """Set up and return the argument parser."""
    parser = argparse.ArgumentParser(
        description="LM Studio mlx-engine inference script"
    )
    parser.add_argument(
        "--model",
        required=True,
        type=str,
        help="The file system path to the model",
    )
    parser.add_argument(
        "--prompt",
        default=DEFAULT_PROMPT,
        type=str,
        help="Message to be processed by the model",
    )
    parser.add_argument(
        "--images",
        type=str,
        nargs="+",
        help="Path of the images to process",
    )
    parser.add_argument(
        "--temp",
        default=DEFAULT_TEMP,
        type=float,
        help="Sampling temperature",
    )
    parser.add_argument(
        "--stop-strings",
        type=str,
        nargs="+",
        help="Strings that will stop the generation",
    )
    parser.add_argument(
        "--top-logprobs",
        type=int,
        default=0,
        help="Number of top logprobs to return",
    )
    parser.add_argument(
        "--max-kv-size",
        type=int,
        help="Max context size of the model",
    )
    parser.add_argument(
        "--kv-bits",
        type=int,
        choices=VALID_KV_BITS,
        help="Number of bits for KV cache quantization. Must be between 3 and 8 (inclusive)",
    )
    parser.add_argument(
        "--kv-group-size",
        type=int,
        choices=VALID_KV_GROUP_SIZE,
        help="Group size for KV cache quantization",
    )
    parser.add_argument(
        "--quantized-kv-start",
        type=int,
        help="When --kv-bits is set, start quantizing the KV cache from this step onwards",
    )
    parser.add_argument(
        "--draft-model",
        type=str,
        help="The file system path to the draft model for speculative decoding.",
    )
    parser.add_argument(
        "--num-draft-tokens",
        type=int,
        help="Number of tokens to draft when using speculative decoding.",
    )
    parser.add_argument(
        "--print-prompt-progress",
        action="store_true",
        help="Enable printed prompt processing progress callback",
    )
    return parser
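

# Example invocations (the paths and flag values below are placeholders, not part
# of the script; point them at your own local model directories):
#
#   python demo.py --model /path/to/mlx-model --prompt "Hello" --temp 0.7
#   python demo.py --model /path/to/mlx-vision-model --images cat.png --prompt "Describe this image"
#   python demo.py --model /path/to/mlx-model --draft-model /path/to/draft-model --num-draft-tokens 4
#   python demo.py --model /path/to/mlx-model --kv-bits 4 --quantized-kv-start 1024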


def image_to_base64(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


class GenerationStatsCollector:
    def __init__(self):
        self.start_time = time.time()
        self.first_token_time = None
        self.total_tokens = 0

    def add_tokens(self, tokens):
        """Record new tokens and their timing."""
        if self.first_token_time is None:
            self.first_token_time = time.time()
        self.total_tokens += len(tokens)

    def print_stats(self):
        """Print generation statistics."""
        end_time = time.time()
        total_time = end_time - self.start_time
        print(f"\n\nGeneration stats:")
        print(f" - Time to first token: {self.first_token_time - self.start_time:.2f}s")
        print(f" - Total tokens generated: {self.total_tokens}")
        print(f" - Total time: {total_time:.2f}s")
        print(f" - Tokens per second: {self.total_tokens / total_time:.2f}")
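

# Minimal standalone usage sketch for GenerationStatsCollector (illustrative only;
# the __main__ block below feeds it real token batches from the generator):
#
#   collector = GenerationStatsCollector()
#   collector.add_tokens([101, 42, 7])  # a batch of generated token ids
#   collector.print_stats()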


if __name__ == "__main__":
    # Parse arguments
    parser = setup_arg_parser()
    args = parser.parse_args()
    if isinstance(args.images, str):
        args.images = [args.images]

    # Set up prompt processing callback
    def prompt_progress_callback(percent):
        if args.print_prompt_progress:
            width = 40  # bar width
            filled = int(width * percent / 100)
            bar = "█" * filled + "░" * (width - filled)
            print(f"\rProcessing prompt: |{bar}| ({percent:.1f}%)", end="", flush=True)
            if percent >= 100:
                print()  # new line when done
        else:
            pass
    # Load the model
    model_path = args.model
    model_kit = load_model(
        str(model_path),
        max_kv_size=args.max_kv_size,
        trust_remote_code=False,
        kv_bits=args.kv_bits,
        kv_group_size=args.kv_group_size,
        quantized_kv_start=args.quantized_kv_start,
    )

    # Load draft model if requested
    if args.draft_model:
        load_draft_model(model_kit=model_kit, path=args.draft_model)

    # Tokenize the prompt
    prompt = args.prompt
    prompt_tokens = tokenize(model_kit, prompt)

    # Handle optional images
    images_base64 = []
    if args.images:
        if isinstance(args.images, str):
            args.images = [args.images]
        images_base64 = [image_to_base64(img_path) for img_path in args.images]

    # Record top logprobs
    logprobs_list = []

    # Initialize generation stats collector
    stats_collector = GenerationStatsCollector()
    # Generate the response
    generator = create_generator(
        model_kit,
        prompt_tokens,
        images_b64=images_base64,
        stop_strings=args.stop_strings,
        max_tokens=1024,
        top_logprobs=args.top_logprobs,
        prompt_progress_callback=prompt_progress_callback,
        num_draft_tokens=args.num_draft_tokens,
        temp=args.temp,
    )
    for generation_result in generator:
        print(generation_result.text, end="", flush=True)
        stats_collector.add_tokens(generation_result.tokens)
        logprobs_list.extend(generation_result.top_logprobs)

        if generation_result.stop_condition:
            stats_collector.print_stats()
            print(
                f"\nStopped generation due to: {generation_result.stop_condition.stop_reason}"
            )
            if generation_result.stop_condition.stop_string:
                print(f"Stop string: {generation_result.stop_condition.stop_string}")

    if args.top_logprobs:
        for logprob in logprobs_list:
            print(logprob)
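
# Example of the stats block printed at the end of a run (numbers are
# illustrative, not measured):
#
#   Generation stats:
#    - Time to first token: 0.41s
#    - Total tokens generated: 57
#    - Total time: 1.92s
#    - Tokens per second: 29.69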