-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathchat.cpp
executable file
·547 lines (483 loc) · 19.3 KB
/
chat.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
//===----------------------------------------------------------------------===//
//
// Copyright (C) 2023 Sophgo Technologies Inc. All rights reserved.
//
// TPU-MLIR is licensed under the 2-Clause BSD License except for the
// third-party components.
//
//===----------------------------------------------------------------------===//
#include <assert.h>
#include <getopt.h>
#include <inttypes.h>
#include <stdio.h>

#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <random>
#include <stdexcept>
#include <string>
#include <vector>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "bmruntime_interface.h"
#include "memory.h"
#include "utils.h"
// Fill value for masked-out attention positions. init() converts it to the
// fp16/bf16 bit pattern used by the model and stores it in Qwen::mask_value.
static constexpr float ATTENTION_MASK = -10000.0f;

/// Qwen LLM inference on Sophgo TPU devices through bmruntime.
///
/// Call sequence (as driven from Python through the `chat` module): init(),
/// then forward_first() or forward_prompt_first() to prefill, forward_next()
/// once per further token (or generate() for the whole loop), and finally
/// deinit().
class Qwen {
public:
  void init(const std::vector<int> &devid, std::string model_path);
  void deinit();
  int forward_first(std::vector<int> &tokens);
  int forward_next();
  int forward_prompt_first(std::vector<int> &tokens);
  std::vector<int> generate(std::vector<int> &history_tokens, int EOS);
  std::mt19937 sgen; // RNG used by penalty_sample to draw a candidate
  Qwen() : sgen(std::random_device()()) {}

private:
  void net_launch(const bm_net_info_t *net, int stage_idx = 0);
  inline void d2d(bm_device_mem_t &dst, bm_device_mem_t &src);
  inline void d2d(bm_device_mem_t &dst, bm_device_mem_t &src, int &offset);
  void head_launch(const bm_net_info_t *net, bm_device_mem_t &logits_mem);
  int greedy_search(const bm_net_info_t *net, bm_device_mem_t &logits_mem);
  int penalty_sample(const bm_net_info_t *net, bm_device_mem_t &logits_mem);

public:
  // Every scalar/pointer member has a default initializer so that reading it
  // before init() (or before Python assigns it) is well-defined rather than
  // reading an indeterminate value.
  int token_length = 0;        // number of valid entries in visited_tokens
  int prompt_length = 0;       // length of the cached prompt (prompt_tokens)
  int SEQLEN = 0;              // read from bmodel
  int NUM_LAYERS = 0;          // read from bmodel
  int MAX_PROMPT_LENGTH = 0;   // read from bmodel
  int MAX_UNPROMPT_LENGTH = 0; // read from bmodel
  bool io_alone = false;       // true when the bmodel's addr_mode == 1
  std::vector<int> visited_tokens;
  std::vector<int> prompt_tokens;
  uint16_t mask_value = 0; // ATTENTION_MASK encoded as fp16/bf16 bits
  // generation
  float temperature = 0.0f;
  float top_p = 0.0f;
  float repeat_penalty = 0.0f;
  int repeat_last_n = 0;
  int max_new_tokens = 0;
  std::string generation_mode;
  std::string prompt_mode;

private:
  std::vector<bm_handle_t> handles;
  bm_handle_t bm_handle = nullptr; // first device handle; used for all copies
  void *p_bmrt = nullptr;
  // Per-layer networks: full prefill (block_i), single-token decode
  // (block_cache_i) and prefill against a cached prompt (block_prompt_cache_i).
  std::vector<const bm_net_info_t *> net_blocks;
  std::vector<const bm_net_info_t *> net_blocks_cache;
  std::vector<const bm_net_info_t *> net_blocks_prompt_cache;
  const bm_net_info_t *net_embed = nullptr;
  const bm_net_info_t *net_embed_cache = nullptr;
  const bm_net_info_t *net_lm = nullptr;
  const bm_net_info_t *net_greedy_head = nullptr;
  const bm_net_info_t *net_penalty_sample_head = nullptr;
  // Per-layer KV cache. With io_alone these alias the block_cache networks'
  // input memory; otherwise init() allocates them separately.
  std::vector<bm_device_mem_t> past_key;
  std::vector<bm_device_mem_t> past_value;
};
void Qwen::net_launch(const bm_net_info_t *net, int stage_idx) {
std::vector<bm_tensor_t> in_tensors(net->input_num);
std::vector<bm_tensor_t> out_tensors(net->output_num);
for (int i = 0; i < net->input_num; i++) {
bmrt_tensor_with_device(
&in_tensors[i], net->stages[stage_idx].input_mems[i],
net->input_dtypes[i], net->stages[stage_idx].input_shapes[i]);
}
for (int i = 0; i < net->output_num; i++) {
bmrt_tensor_with_device(
&out_tensors[i], net->stages[stage_idx].output_mems[i],
net->output_dtypes[i], net->stages[stage_idx].output_shapes[i]);
}
auto ret = bmrt_launch_tensor_ex(p_bmrt, net->name, in_tensors.data(),
net->input_num, out_tensors.data(),
net->output_num, true, false);
assert(ret);
bm_thread_sync(bm_handle);
}
/// Device-to-device copy of bm_mem_get_device_size(dst) bytes from the start
/// of `src` to the start of `dst`.
void Qwen::d2d(bm_device_mem_t &dst, bm_device_mem_t &src) {
  const auto dst_bytes = bm_mem_get_device_size(dst);
  bm_memcpy_d2d_byte(bm_handle, dst, 0, src, 0, dst_bytes);
}
/// Device-to-device copy of all of `src` (bm_mem_get_device_size(src) bytes)
/// into `dst`, starting at byte `offset` of `dst`. `offset` is only read.
void Qwen::d2d(bm_device_mem_t &dst, bm_device_mem_t &src, int &offset) {
  const auto src_bytes = bm_mem_get_device_size(src);
  bm_memcpy_d2d_byte(bm_handle, dst, offset, src, 0, src_bytes);
}
void Qwen::init(const std::vector<int> &devices, std::string model_path) {
// request bm_handle
std::cout << "Device [ ";
for (auto d : devices) {
std::cout << d << " ";
}
std::cout << "] loading ....\n";
for (auto d : devices) {
bm_handle_t h;
bm_status_t status = bm_dev_request(&h, d);
assert(BM_SUCCESS == status);
handles.push_back(h);
}
bm_handle = handles[0];
// create bmruntime
#ifdef SOC_TARGET
p_bmrt = bmrt_create(handles[0]);
#else
p_bmrt = bmrt_create_ex(handles.data(), handles.size());
#endif
assert(NULL != p_bmrt);
// load bmodel by file
printf("Model[%s] loading ....\n", model_path.c_str());
bool ret = bmrt_load_bmodel(p_bmrt, model_path.c_str());
assert(true == ret);
printf("Done!\n");
// net embed and lm_head
net_embed = bmrt_get_network_info(p_bmrt, "embedding");
net_embed_cache = bmrt_get_network_info(p_bmrt, "embedding_cache");
net_lm = bmrt_get_network_info(p_bmrt, "lm_head");
net_greedy_head = bmrt_get_network_info(p_bmrt, "greedy_head");
net_penalty_sample_head = bmrt_get_network_info(p_bmrt, "penalty_sample_head");
SEQLEN = net_embed->stages[0].input_shapes[0].dims[1]; // real seqlen
auto num_nets = bmrt_get_network_number(p_bmrt);
NUM_LAYERS = (num_nets - 5) / 3;
// resize
visited_tokens.resize(SEQLEN);
// net blocks
for (int i = 0; i < NUM_LAYERS; i++) {
auto block_name = "block_" + std::to_string(i);
auto cache_name = "block_cache_" + std::to_string(i);
auto prompt_cache_name = "block_prompt_cache_" + std::to_string(i);
net_blocks.emplace_back(bmrt_get_network_info(p_bmrt, block_name.c_str()));
net_blocks_cache.emplace_back(
bmrt_get_network_info(p_bmrt, cache_name.c_str()));
net_blocks_prompt_cache.emplace_back(
bmrt_get_network_info(p_bmrt, prompt_cache_name.c_str()));
}
MAX_UNPROMPT_LENGTH = net_blocks_prompt_cache[0]->stages[0].input_shapes[0].dims[1]; // real seqlen
MAX_PROMPT_LENGTH = net_blocks_prompt_cache[0]->stages[0].input_shapes[3].dims[1]; // real seqlen
// convert attention to uint16_t
if (net_blocks_cache[0]->input_dtypes[2] == BM_FLOAT16) {
mask_value = fp32_to_fp16_bits(ATTENTION_MASK);
} else if (net_blocks_cache[0]->input_dtypes[2] == BM_BFLOAT16) {
mask_value = fp32_to_bf16_bits(ATTENTION_MASK);
} else {
std::cerr << "\nError: Invalid attention dtype\n";
std::cerr << "Supported dtype are 'BM_FLOAT16' or 'BM_BFLOAT16'\n";
throw std::runtime_error("Invalid attention dtype");
}
// kv cache
past_key.resize(NUM_LAYERS);
past_value.resize(NUM_LAYERS);
auto addr_mode = net_blocks_cache[0]->addr_mode;
io_alone = addr_mode == 1;
for (int i = 0; i < NUM_LAYERS; i++) {
assert(addr_mode == net_blocks_cache[i]->addr_mode);
if (io_alone) {
past_key[i] = net_blocks_cache[i]->stages[0].input_mems[3];
past_value[i] = net_blocks_cache[i]->stages[0].input_mems[4];
} else {
auto ret = bm_malloc_device_byte(bm_handle, &past_key[i],
net_blocks_cache[i]->max_input_bytes[3]);
assert(BM_SUCCESS == ret);
ret = bm_malloc_device_byte(bm_handle, &past_value[i],
net_blocks_cache[i]->max_input_bytes[4]);
assert(BM_SUCCESS == ret);
}
}
}
void Qwen::deinit() {
if (false == io_alone) {
for (int i = 0; i < NUM_LAYERS; i++) {
bm_free_device(bm_handle, past_key[i]);
bm_free_device(bm_handle, past_value[i]);
}
}
bmrt_destroy(p_bmrt);
for (auto h : handles) {
bm_dev_free(h);
}
}
void Qwen::head_launch(const bm_net_info_t *net, bm_device_mem_t &logits_mem) {
std::vector<bm_tensor_t> in_tensors(net->input_num);
std::vector<bm_tensor_t> out_tensors(net->output_num);
bmrt_tensor_with_device(
&in_tensors[0], logits_mem,
net->input_dtypes[0], net->stages[0].input_shapes[0]);
for (int i = 1; i < net->input_num; i++) {
bmrt_tensor_with_device(
&in_tensors[i], net->stages[0].input_mems[i],
net->input_dtypes[i], net->stages[0].input_shapes[i]);
}
for (int i = 0; i < net->output_num; i++) {
bmrt_tensor_with_device(
&out_tensors[i], net->stages[0].output_mems[i],
net->output_dtypes[i], net->stages[0].output_shapes[i]);
}
auto ret = bmrt_launch_tensor_ex(p_bmrt, net->name, in_tensors.data(),
net->input_num, out_tensors.data(),
net->output_num, true, false);
assert(ret);
bm_thread_sync(bm_handle);
}
/// Runs the greedy head on `logits_mem` and returns the token id that the
/// head writes to its first output.
int Qwen::greedy_search(const bm_net_info_t *net, bm_device_mem_t &logits_mem) {
  head_launch(net, logits_mem);
  int best_token = 0;
  bm_memcpy_d2s(bm_handle, static_cast<void *>(&best_token),
                net->stages[0].output_mems[0]);
  return best_token;
}
/// Samples the next token using the penalty-sampling head network.
///
/// Inputs 1-4 of `net` receive the recent-token window, top_p, temperature
/// and repeat_penalty. The head returns candidate probabilities (output 0)
/// and candidate token ids (output 1); one candidate is drawn with sgen,
/// weighted by those probabilities.
///
/// @param net         penalty-sampling head network
/// @param logits_mem  device memory holding the lm_head output
/// @return the sampled token id
int Qwen::penalty_sample(const bm_net_info_t *net, bm_device_mem_t &logits_mem) {
  auto &in1_mem = net->stages[0].input_mems[1];
  auto &in2_mem = net->stages[0].input_mems[2];
  auto &in3_mem = net->stages[0].input_mems[3];
  auto &in4_mem = net->stages[0].input_mems[4];
  auto &out0_mem = net->stages[0].output_mems[0];
  auto &out1_mem = net->stages[0].output_mems[1];

  // Clamp into a local so the caller-configured repeat_last_n member is not
  // overwritten. The member persists across calls and generations, so
  // writing it back would permanently shrink the window after a short context.
  const int last_n = std::max(0, std::min(repeat_last_n, token_length));

  // repeat_penalty + top_p + top_k + temperature
  // Window of the last `last_n` tokens. Unused slots are filled with the most
  // recent token, which is already in the window.
  std::vector<int> generated_tokens(SEQLEN, visited_tokens[token_length - 1]);
  std::copy(visited_tokens.begin() + (token_length - last_n),
            visited_tokens.begin() + token_length, generated_tokens.begin());
  bm_memcpy_s2d(bm_handle, in1_mem, (void *)generated_tokens.data());
  bm_memcpy_s2d(bm_handle, in2_mem, (void *)&top_p);
  bm_memcpy_s2d(bm_handle, in3_mem, (void *)&temperature);
  bm_memcpy_s2d(bm_handle, in4_mem, (void *)&repeat_penalty);

  // inference
  head_launch(net, logits_mem);

  // get probabilities & candidate tokens
  int candidate_num = net->stages[0].output_shapes[0].dims[1];
  std::vector<float> probs(candidate_num);
  bm_memcpy_d2s(bm_handle, probs.data(), out0_mem);
  std::vector<int> tokens(candidate_num);
  bm_memcpy_d2s(bm_handle, tokens.data(), out1_mem);

  // draw one candidate weighted by its probability
  std::discrete_distribution<> dist(probs.begin(), probs.end());
  return tokens[dist(sgen)];
}
int Qwen::forward_first(std::vector<int> &tokens) {
std::vector<int> position_id(SEQLEN, 0);
std::vector<uint16_t> attention_mask(SEQLEN * SEQLEN, mask_value);
std::copy(tokens.begin(), tokens.end(), visited_tokens.data());
token_length = tokens.size();
for (int i = 0; i < token_length; i++) {
position_id[i] = i;
}
for (int i = 0; i < token_length; i++) {
for (int j = 0; j < SEQLEN; j++) {
if (j <= i) {
attention_mask[i * SEQLEN + j] = 0;
}
}
}
// forward embeding
auto &in_mem = net_embed->stages[0].input_mems[0];
auto &out_mem = net_embed->stages[0].output_mems[0];
bm_memcpy_s2d(bm_handle, in_mem, (void *)visited_tokens.data());
net_launch(net_embed); // prefil embedding
// forward blocks
for (int idx = 0; idx < NUM_LAYERS; idx++) {
auto &in0_mem = net_blocks[idx]->stages[0].input_mems[0];
auto &in1_mem = net_blocks[idx]->stages[0].input_mems[1];
auto &in2_mem = net_blocks[idx]->stages[0].input_mems[2];
d2d(in0_mem, out_mem);
if (idx == 0) {
// only first time need copy
bm_memcpy_s2d(bm_handle, in1_mem, (void *)position_id.data());
bm_memcpy_s2d(bm_handle, in2_mem, (void *)attention_mask.data());
}
net_launch(net_blocks[idx]);
out_mem = net_blocks[idx]->stages[0].output_mems[0];
d2d(past_key[idx], net_blocks[idx]->stages[0].output_mems[1]);
d2d(past_value[idx], net_blocks[idx]->stages[0].output_mems[2]);
}
// forward lmhead
int bytes = out_mem.size / SEQLEN;
auto &lm_in_mem = net_lm->stages[0].input_mems[0];
auto &lm_out_mem = net_lm->stages[0].output_mems[0];
bm_memcpy_d2d_byte(bm_handle, lm_in_mem, 0, out_mem,
(token_length - 1) * bytes, bytes);
net_launch(net_lm);
int token = 0;
if (generation_mode == "greedy") {
token = greedy_search(net_greedy_head, lm_out_mem);
} else if (generation_mode == "penalty_sample") {
token = penalty_sample(net_penalty_sample_head, lm_out_mem);
}
visited_tokens[token_length] = token;
token_length += 1;
return token;
}
int Qwen::forward_next() {
int cur_token = visited_tokens[token_length - 1];
std::vector<uint16_t> attention_mask(SEQLEN + 1, 0);
for (int i = token_length - 1; i < SEQLEN; i++) {
attention_mask[i] = mask_value;
}
int32_t position_id = token_length - 1;
// embedding
auto &in_mem = net_embed_cache->stages[0].input_mems[0];
auto &out_mem = net_embed_cache->stages[0].output_mems[0];
bm_memcpy_s2d(bm_handle, in_mem, (void *)&cur_token);
net_launch(net_embed_cache);
// blocks
int bytes =
bm_mem_get_device_size(net_blocks_cache[0]->stages[0].output_mems[1]);
int token_offset = (token_length - 1) * bytes;
for (int idx = 0; idx < NUM_LAYERS; idx++) {
auto &in0_mem = net_blocks_cache[idx]->stages[0].input_mems[0];
auto &in1_mem = net_blocks_cache[idx]->stages[0].input_mems[1];
auto &in2_mem = net_blocks_cache[idx]->stages[0].input_mems[2];
auto &in3_mem = net_blocks_cache[idx]->stages[0].input_mems[3];
auto &in4_mem = net_blocks_cache[idx]->stages[0].input_mems[4];
auto &out0_mem = net_blocks_cache[idx]->stages[0].output_mems[0];
auto &out1_mem = net_blocks_cache[idx]->stages[0].output_mems[1];
auto &out2_mem = net_blocks_cache[idx]->stages[0].output_mems[2];
d2d(in0_mem, out_mem);
if (io_alone) {
if (idx == 0) {
bm_memcpy_s2d(bm_handle, in1_mem, (void *)&position_id);
bm_memcpy_s2d(bm_handle, in2_mem, (void *)attention_mask.data());
} else {
d2d(in1_mem, net_blocks_cache[0]->stages[0].input_mems[1]);
d2d(in2_mem, net_blocks_cache[0]->stages[0].input_mems[2]);
}
} else {
if (idx == 0) {
bm_memcpy_s2d(bm_handle, in1_mem, (void *)&position_id);
bm_memcpy_s2d(bm_handle, in2_mem, (void *)attention_mask.data());
}
d2d(in3_mem, past_key[idx]);
d2d(in4_mem, past_value[idx]);
}
net_launch(net_blocks_cache[idx]);
out_mem = out0_mem;
bm_memcpy_d2d_byte(bm_handle, past_key[idx], token_offset, out1_mem, 0,
bytes);
bm_memcpy_d2d_byte(bm_handle, past_value[idx], token_offset, out2_mem, 0,
bytes);
}
// forward lmhead
auto &lm_in_mem = net_lm->stages[0].input_mems[0];
auto &lm_out_mem = net_lm->stages[0].output_mems[0];
d2d(lm_in_mem, out_mem);
net_launch(net_lm);
int token = 0;
if (generation_mode == "greedy") {
token = greedy_search(net_greedy_head, lm_out_mem);
} else if (generation_mode == "penalty_sample") {
token = penalty_sample(net_penalty_sample_head, lm_out_mem);
}
visited_tokens[token_length] = token;
token_length += 1;
return token;
}
int Qwen::forward_prompt_first(std::vector<int> &tokens) {
visited_tokens.clear();
visited_tokens.resize(SEQLEN);
std::vector<int> position_id(MAX_UNPROMPT_LENGTH, 0);
std::vector<uint16_t> attention_mask(MAX_UNPROMPT_LENGTH * SEQLEN, mask_value);
std::copy(tokens.begin(), tokens.end(), visited_tokens.data());
token_length = tokens.size();
for (int i = 0; i < token_length; i++) {
position_id[i] = i + prompt_length;
}
for (int i = 0; i < token_length; i++) {
for (int j = 0; j < prompt_length; j++) {
attention_mask[i * SEQLEN + j] = 0;
}
for (int j = MAX_PROMPT_LENGTH; j < SEQLEN; j++) {
if (j - MAX_PROMPT_LENGTH <= i) {
attention_mask[i * SEQLEN + j] = 0;
}
}
}
// forward embeding
auto &embed_in_mem = net_embed->stages[0].input_mems[0];
auto &embed_out_mem = net_embed->stages[0].output_mems[0];
bm_memcpy_s2d(bm_handle, embed_in_mem, (void *)visited_tokens.data());
net_launch(net_embed); // prefil embedding
auto &out_mem = net_blocks_prompt_cache[0]->stages[0].output_mems[0];
// forward blocks
int prompt_bytes =
bm_mem_get_device_size(net_blocks_prompt_cache[0]->stages[0].input_mems[3]) / MAX_PROMPT_LENGTH;
int prompt_offset = prompt_bytes * prompt_length;
for (int idx = 0; idx < NUM_LAYERS; idx++) {
auto &in0_mem = net_blocks_prompt_cache[idx]->stages[0].input_mems[0];
auto &in1_mem = net_blocks_prompt_cache[idx]->stages[0].input_mems[1];
auto &in2_mem = net_blocks_prompt_cache[idx]->stages[0].input_mems[2];
auto &in3_mem = net_blocks_prompt_cache[idx]->stages[0].input_mems[3];
auto &in4_mem = net_blocks_prompt_cache[idx]->stages[0].input_mems[4];
if (io_alone) {
if (idx == 0) {
d2d(in0_mem, embed_out_mem);
bm_memcpy_s2d(bm_handle, in1_mem, (void *)position_id.data());
bm_memcpy_s2d(bm_handle, in2_mem, (void *)attention_mask.data());
} else {
d2d(in0_mem, net_blocks_prompt_cache[idx-1]->stages[0].output_mems[0]);
d2d(in1_mem, net_blocks_prompt_cache[0]->stages[0].input_mems[1]);
d2d(in2_mem, net_blocks_prompt_cache[0]->stages[0].input_mems[2]);
}
d2d(in3_mem, past_key[idx]);
d2d(in4_mem, past_value[idx]);
} else {
throw std::runtime_error("Only support io_alone");
}
net_launch(net_blocks_prompt_cache[idx]);
d2d(past_key[idx], net_blocks_prompt_cache[idx]->stages[0].output_mems[1], prompt_offset);
d2d(past_value[idx], net_blocks_prompt_cache[idx]->stages[0].output_mems[2], prompt_offset);
}
// forward lmhead
int bytes = out_mem.size / MAX_UNPROMPT_LENGTH;
auto &lm_in_mem = net_lm->stages[0].input_mems[0];
auto &lm_out_mem = net_lm->stages[0].output_mems[0];
bm_memcpy_d2d_byte(bm_handle, lm_in_mem, 0,
net_blocks_prompt_cache[NUM_LAYERS-1]->stages[0].output_mems[0],
(token_length - 1) * bytes, bytes);
net_launch(net_lm);
int token = 0;
if (generation_mode == "greedy") {
token = greedy_search(net_greedy_head, lm_out_mem);
} else if (generation_mode == "penalty_sample") {
token = penalty_sample(net_penalty_sample_head, lm_out_mem);
}
visited_tokens[token_length] = token;
visited_tokens.insert(visited_tokens.begin(), prompt_tokens.begin(), prompt_tokens.end());
visited_tokens.erase(visited_tokens.end() - prompt_length, visited_tokens.end());
token_length += 1;
token_length += prompt_length;
return token;
}
/// Full generation loop: prefill with `history_tokens`, then decode until
/// `EOS` is produced or the context (SEQLEN slots) is full.
///
/// @param history_tokens prompt token ids
/// @param EOS            token id that ends generation (not returned)
/// @return the generated tokens, excluding EOS. If the input is empty or
///         longer than SEQLEN - 10, a message is printed and an empty vector
///         is returned.
std::vector<int> Qwen::generate(std::vector<int> &history_tokens, int EOS) {
  if (history_tokens.empty()) {
    printf("Sorry: your question is empty!!\n");
    history_tokens.clear();
    return {};
  }
  // make sure token not too large
  if (static_cast<int>(history_tokens.size()) > SEQLEN - 10) {
    history_tokens.clear();
    printf("Error: your question is too large!\n");
    return {};
  }
  std::vector<int> result_tokens;
  int token = forward_first(history_tokens);
  // Keep every non-EOS token, including the one that fills the last context
  // slot. Stop decoding once the context is full, because forward_next has
  // no slot left to store another token.
  while (token != EOS) {
    result_tokens.emplace_back(token);
    if (token_length >= SEQLEN) {
      break;
    }
    token = forward_next();
  }
  return result_tokens;
}
// Python bindings: the `chat` extension module exposes the Qwen class, its
// inference entry points, and its settings as read/write attributes.
PYBIND11_MODULE(chat, m) {
  pybind11::class_<Qwen> qwen(m, "Qwen");

  // constructor, lifecycle and inference entry points
  qwen.def(pybind11::init<>())
      .def("init", &Qwen::init)
      .def("forward_first", &Qwen::forward_first)
      .def("forward_next", &Qwen::forward_next)
      .def("forward_prompt_first", &Qwen::forward_prompt_first)
      .def("generate", &Qwen::generate)
      .def("deinit", &Qwen::deinit);

  // state and generation settings
  qwen.def_readwrite("SEQLEN", &Qwen::SEQLEN) // read SEQLEN in pipeline.py
      .def_readwrite("token_length", &Qwen::token_length)
      .def_readwrite("prompt_length", &Qwen::prompt_length)
      .def_readwrite("prompt_tokens", &Qwen::prompt_tokens)
      .def_readwrite("temperature", &Qwen::temperature)
      .def_readwrite("top_p", &Qwen::top_p)
      .def_readwrite("repeat_penalty", &Qwen::repeat_penalty)
      .def_readwrite("repeat_last_n", &Qwen::repeat_last_n)
      .def_readwrite("max_new_tokens", &Qwen::max_new_tokens)
      .def_readwrite("generation_mode", &Qwen::generation_mode)
      .def_readwrite("prompt_mode", &Qwen::prompt_mode);
}