Ported sampler from Stateless to Stateful pipeline
AsyaPronina committed Jan 8, 2025
1 parent 5ab58ca commit 7b1a495
Showing 2 changed files with 48 additions and 22 deletions.
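For context, a minimal caller-side sketch (not part of this commit) of what the port enables: the stateful static pipeline now routes token selection through the shared Sampler, so multinomial sampling works in addition to greedy decoding. The model path, device, and parameter values below are placeholders, and the code assumes the standard openvino_genai C++ API.

#include "openvino/genai/llm_pipeline.hpp"
#include <iostream>
#include <string>

int main() {
    // Hypothetical model directory; "NPU" selects the static pipeline touched by this commit.
    ov::genai::LLMPipeline pipe("TinyLlama-1.1B-Chat-v1.0", "NPU");

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 64;
    config.do_sample = true;          // multinomial decoding, newly allowed by this change
    config.top_k = 50;
    config.top_p = 0.9f;
    config.temperature = 0.7f;
    config.num_return_sequences = 1;  // the new assert requires exactly one sequence
    config.rng_seed = 42;             // forwarded to Sampler::set_seed in the ported constructors

    std::string out = pipe.generate("What is OpenVINO?", config);
    std::cout << out << std::endl;
    return 0;
}
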
67 changes: 45 additions & 22 deletions src/cpp/src/llm_pipeline_static.cpp
@@ -686,14 +686,16 @@ StatefulLLMPipeline::StatefulLLMPipeline(
const std::string&,
const ov::AnyMap& config
) : LLMPipelineImplBase(tokenizer,
utils::from_config_json_if_exists(models_path)) {
utils::from_config_json_if_exists(models_path)),
m_sampler(m_tokenizer) {

auto model = genai::utils::singleton_core().read_model(models_path / "openvino_model.xml", {}, config);
ModelConfigDesc model_desc = get_modeldesc_from_json(models_path / "config.json");
ov::AnyMap properties = config;

auto compiled = setupAndCompileModel(model, model_desc, properties);
m_request = compiled->create_infer_request();
m_sampler.set_seed(m_generation_config.rng_seed);
}


@@ -704,10 +706,12 @@ StatefulLLMPipeline::StatefulLLMPipeline(
const std::string&,
const ov::AnyMap& properties,
const ov::genai::GenerationConfig& generation_config
) : LLMPipelineImplBase(tokenizer, generation_config) {
) : LLMPipelineImplBase(tokenizer, generation_config),
m_sampler(m_tokenizer) {
ov::AnyMap properties_copy = properties;
auto compiled = setupAndCompileModel(model, model_desc, properties_copy);
m_request = compiled->create_infer_request();
m_sampler.set_seed(m_generation_config.rng_seed);
}

std::shared_ptr<ov::CompiledModel> StatefulLLMPipeline::setupAndCompileModel(
@@ -816,7 +820,9 @@ EncodedResults StatefulLLMPipeline::generate(
attention_mask = data->attention_mask;
}

OPENVINO_ASSERT(input_ids.get_shape().at(0) == 1u, "Currently only batch size=1 is supported");
ov::Shape prompts_shape = input_ids.get_shape();
const size_t batch_size = prompts_shape[0];
OPENVINO_ASSERT(batch_size == 1u, "Currently only batch size=1 is supported");

GenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config;
// If eos_token_id was not provided, take value from default m_generation_config
@@ -833,10 +839,12 @@ EncodedResults StatefulLLMPipeline::generate(
streamer_ptr = std::make_shared<TextCallbackStreamer>(m_tokenizer, *callback);
}

OPENVINO_ASSERT(config.is_greedy_decoding(), "Currently only greedy decoding is supported");
OPENVINO_ASSERT(config.is_greedy_decoding() || config.is_multinomial(),
"Currently only greedy and multinomial decoding are supported");

OPENVINO_ASSERT(config.num_return_sequences == 1u,
"Currently only \"num_return_sequences\" equal to 1 is supported!");

ov::Shape prompts_shape = input_ids.get_shape();
const size_t batch_size = prompts_shape[0];
ov::genai::EncodedResults results;
auto& raw_perf_counters = results.perf_metrics.raw_metrics;
// NB: Only batch=1 is supported now
@@ -856,26 +864,41 @@ EncodedResults StatefulLLMPipeline::generate(

m_request.infer();

int64_t last_token = utils::argmax(m_request.get_tensor("logits"), 0);
auto logits = m_request.get_tensor("logits");
int64_t output_sequence_len = logits.get_shape().at(1);

results.tokens[0].push_back(last_token);
if (streamer_ptr && streamer_ptr->put(last_token)) {
return results;
}
// Swap max_new_token to get_max_new_token()
auto sequence_group = std::make_shared<SequenceGroup>(
0 /* request_id */, input_ids, config, 1 /* block_size */);
sequence_group->update_processed_tokens_num(input_ids.get_size());
sequence_group->schedule_tokens(output_sequence_len);

// NB: Controls what tokens are ready to be pushed into the streamer
// Set max_new_tokens here via get_max_new_token(prompt)
GenerationHandle handle = std::make_shared<GenerationHandleImpl>(
sequence_group->get_generation_stream(), sequence_group->get_sampling_parameters());

SamplerOutput sampler_output = m_sampler.sample({sequence_group}, logits);
stream_generated_tokens(streamer_ptr, handle);

int64_t input_ids_data = -1;
int64_t position_ids_data = prompt_len - 1;
std::vector<int64_t> attention_mask_data(prompt_len - 1, 1);
m_request.set_tensor("input_ids", ov::Tensor(ov::element::i64, ov::Shape{1,1}, reinterpret_cast<void*>(&input_ids_data)));
m_request.set_tensor("position_ids", ov::Tensor(ov::element::i64, ov::Shape{1,1}, reinterpret_cast<void*>(&position_ids_data)));

const size_t max_tokens = config.get_max_new_tokens(prompt_len);
for (int i = 0; i < max_tokens - 1; ++i) {
while (sequence_group->is_running()) {
// KV Cache is full, no further generation is possible
if (position_ids_data + 1 == m_kvcache_total) {
sequence_group->set_out_of_memory();
break;
}

sequence_group->schedule_tokens(1);
const auto running_sequences = sequence_group->get_running_sequences();
OPENVINO_ASSERT(running_sequences.size() == 1u);
auto last_token = running_sequences.front()->get_generated_ids().back();

// Just change the variables here, as pointers to them are already set to corresponding tensors
input_ids_data = last_token;
++position_ids_data;
@@ -885,24 +908,24 @@ EncodedResults StatefulLLMPipeline::generate(

m_request.infer();

last_token = utils::argmax(m_request.get_tensor("logits"), 0);
results.tokens[0].push_back(last_token);

raw_perf_counters.m_new_token_times.emplace_back(std::chrono::steady_clock::now());
raw_perf_counters.m_batch_sizes.emplace_back(batch_size);
if (streamer_ptr && streamer_ptr->put(last_token)) {
break;
}

if (last_token == config.eos_token_id && !config.ignore_eos) {
break;
}
SamplerOutput sampler_output = m_sampler.sample(
{sequence_group}, m_request.get_tensor("logits"));
stream_generated_tokens(streamer_ptr, handle);
}

if (streamer_ptr) {
streamer_ptr->end();
}

OPENVINO_ASSERT(sequence_group->get_finished_sequences().size() == 1u);
auto sequence = sequence_group->get_finished_sequences().front();
results.tokens[0] = sequence->get_generated_ids();
results.scores[0] = sequence->get_cumulative_log_prob();
m_sampler.clear_request_info(sequence_group->get_request_id());

auto stop_time = std::chrono::steady_clock::now();
// If called without tokenization, that stat will not be reported.
auto& metrics = results.perf_metrics;
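The ported loop above calls a stream_generated_tokens() helper that is referenced but not defined in the hunks shown. A plausible sketch, modeled on the equivalent streaming code in the stateless/continuous-batching pipelines, follows; the exact name, signature, and placement in llm_pipeline_static.cpp are assumptions, not part of this diff.

// Assumed helper (not shown in this diff): forwards tokens that the Sampler has
// marked as ready from the GenerationHandle to the streamer, and drops the handle
// if the streamer requests an early stop. Relies on headers already included by
// llm_pipeline_static.cpp (sampler, generation handle, streamer base).
static void stream_generated_tokens(std::shared_ptr<ov::genai::StreamerBase> streamer_ptr,
                                    ov::genai::GenerationHandle& handle) {
    if (!streamer_ptr || !handle->can_read()) {
        return;
    }
    // Only one sequence group / one running sequence is expected here (batch=1).
    std::unordered_map<uint64_t, ov::genai::GenerationOutput> output = handle->back();
    for (int64_t token_id : output.begin()->second.generated_ids) {
        if (streamer_ptr->put(token_id)) {
            // The streamer asked to stop; dropping the handle signals this back to the loop.
            handle->drop();
            break;
        }
    }
}
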
3 changes: 3 additions & 0 deletions src/cpp/src/llm_pipeline_static.hpp
@@ -77,6 +77,9 @@ class StatefulLLMPipeline : public LLMPipelineImplBase {
private:
uint32_t m_kvcache_total = 0u;
ov::InferRequest m_request;

Sampler m_sampler;

bool m_is_chat_conversation = false;
ChatHistory m_history;
};
