Skip to content

Commit

Permalink
use ov threading
Browse files Browse the repository at this point in the history
  • Loading branch information
mzegla committed Nov 21, 2024
1 parent 59a4e6d commit 89b54bd
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 57 deletions.
4 changes: 1 addition & 3 deletions src/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,13 @@ if(TARGET openvino_tokenizers)
endif()
add_library(openvino::genai ALIAS ${TARGET_NAME})

find_package(TBB REQUIRED)

target_include_directories(${TARGET_NAME}
PUBLIC "$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>" "$<INSTALL_INTERFACE:runtime/include>"
PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/src")

target_include_directories(${TARGET_NAME} SYSTEM PRIVATE "${safetensors.h_SOURCE_DIR}")

target_link_libraries(${TARGET_NAME} PUBLIC openvino::runtime TBB::tbb PRIVATE openvino::threading nlohmann_json::nlohmann_json jinja2cpp)
target_link_libraries(${TARGET_NAME} PUBLIC openvino::runtime openvino::threading PRIVATE nlohmann_json::nlohmann_json jinja2cpp)

target_compile_features(${TARGET_NAME} PUBLIC cxx_std_17)

Expand Down
118 changes: 64 additions & 54 deletions src/cpp/src/sampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0

#include <future>
#include "oneapi/tbb.h"
#include "openvino/core/parallel.hpp"
#include "sampler.hpp"

namespace ov::genai {
Expand Down Expand Up @@ -879,24 +879,23 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
ov::Shape logits_shape = logits.get_shape();
OPENVINO_ASSERT(logits_shape.size() == 3);
size_t batch_seq_len = logits_shape[1], vocab_size = logits_shape[2];
std::cout << "Tokens in the tensor: " << logits_shape[0] << std::endl;

SamplerOutput sampler_output;
//std::mutex sampler_output_mutex;
tbb::spin_mutex sampler_output_mutex;
std::unordered_map<size_t, size_t> sequence_group_offsets;
std::unordered_map<size_t, SamplerOutput> sequence_group_sampler_outputs;

// First sequential pass to collect metadata and prepare for parallel processing
size_t last_request_id = 0;
for (size_t sequence_group_id = 0; sequence_group_id < sequence_groups.size(); ++sequence_group_id) {
SequenceGroup::Ptr sequence_group = sequence_groups[sequence_group_id];
for (size_t i = 0; i < sequence_groups.size(); ++i) {
SequenceGroup::Ptr sequence_group = sequence_groups[i];
if (!sequence_group->is_scheduled())
continue;

size_t num_running_sequences = sequence_group->num_running_seqs();
size_t actual_seq_len = sequence_group->get_num_scheduled_tokens(); // points to a token which needs to be sampled
size_t padded_amount_of_processed_tokens = std::max(actual_seq_len, batch_seq_len);
const auto request_id = sequence_group->get_request_id();
if (sequence_group_id == 0) {
if (i == 0) {
sequence_group_offsets[request_id] = 0;
last_request_id = request_id;
} else {
Expand All @@ -907,62 +906,73 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
const ov::genai::GenerationConfig& sampling_params = sequence_group->get_sampling_parameters();
m_logit_processors.insert({request_id, LogitProcessor(sampling_params, sequence_group->get_prompt_ids())});
}

sequence_group_sampler_outputs[request_id] = SamplerOutput{};
}

// Parallel sampling execution
tbb::parallel_for(tbb::blocked_range<size_t>(0, sequence_groups.size()), [&](const tbb::blocked_range<size_t>& r) {
for (size_t sequence_group_id = r.begin(); sequence_group_id != r.end(); ++sequence_group_id) {
SequenceGroup::Ptr sequence_group = sequence_groups[sequence_group_id];
if (!sequence_group->is_scheduled())
continue;
for (const auto& entry : sequence_group_offsets) {
std::cout << entry.first << ": " << entry.second << ", ";
}
std::cout << std::endl;

size_t num_running_sequences = sequence_group->num_running_seqs();
size_t actual_seq_len = sequence_group->get_num_scheduled_tokens(); // points to a token which needs to be sampled
// Parallel sampling execution
// ov::parallel_for(kvcache_compiled.outputs().size() - 1, [&](size_t i)
// tbb::parallel_for(tbb::blocked_range<size_t>(0, sequence_groups.size()), [&](const tbb::blocked_range<size_t>& r) {
ov::parallel_for(sequence_groups.size(), [&](size_t i) {
SequenceGroup::Ptr sequence_group = sequence_groups[i];
if (!sequence_group->is_scheduled())
return;

if (sequence_group->requires_sampling()) {
const auto request_id = sequence_group->get_request_id();
const void * sequence_group_logits_data = logits_data + vocab_size * sequence_group_offsets[request_id];
ov::Tensor sequence_group_logits(ov::element::f32, ov::Shape{num_running_sequences, actual_seq_len, vocab_size}, (void *)sequence_group_logits_data);

// Call sample_from_sequence_group synchronously
auto sequence_group_sampling_info = sample_from_sequence_group(sequence_group, sequence_group_logits,
m_logit_processors.at(request_id), is_validation_mode_enabled);

// Merge sampler output from sequence group to the main one
{
tbb::spin_mutex::scoped_lock lock(sampler_output_mutex);
sampler_output.m_dropped_sequences.insert(
sampler_output.m_dropped_sequences.end(),
sequence_group_sampling_info.sampler_output.m_dropped_sequences.begin(),
sequence_group_sampling_info.sampler_output.m_dropped_sequences.end()
);

for (const auto& forked_seq : sequence_group_sampling_info.sampler_output.m_forked_sequences) {
sampler_output.m_forked_sequences[forked_seq.first].insert(
sampler_output.m_forked_sequences[forked_seq.first].end(),
forked_seq.second.begin(),
forked_seq.second.end()
);
}
}
size_t num_running_sequences = sequence_group->num_running_seqs();
size_t actual_seq_len = sequence_group->get_num_scheduled_tokens(); // points to a token which needs to be sampled

// NOTE: it should be before 'get_num_scheduled_tokens' is used
// update internal state of sequence group to reset scheduler tokens and update currently processed ones
sequence_group->finish_iteration();
// decrease sequence_group context in case of candidates generated by draft_model were not accepted by main_model
if (sequence_group_sampling_info.max_removed_tokens_per_request) {
auto min_processed_tokens = sequence_group->get_prompt_len() + sequence_group_sampling_info.min_generated_len - 1;
sequence_group->update_processed_tokens_num(min_processed_tokens);
auto& logit_processor = m_logit_processors.at(sequence_group->get_request_id());
logit_processor.update_generated_len(min_processed_tokens);
}
} else {
// update internal state of sequence group to reset scheduler tokens and update currently processed ones
sequence_group->finish_iteration();
if (sequence_group->requires_sampling()) {
const auto request_id = sequence_group->get_request_id();
const void * sequence_group_logits_data = logits_data + vocab_size * sequence_group_offsets[request_id];
ov::Tensor sequence_group_logits(ov::element::f32, ov::Shape{num_running_sequences, actual_seq_len, vocab_size}, (void *)sequence_group_logits_data);

// Call sample_from_sequence_group synchronously
auto sequence_group_sampling_info = sample_from_sequence_group(sequence_group, sequence_group_logits,
m_logit_processors.at(request_id), is_validation_mode_enabled);

// Store sampler output from sequence group in the map
sequence_group_sampler_outputs[request_id] = sequence_group_sampling_info.sampler_output;

// NOTE: it should be before 'get_num_scheduled_tokens' is used
// update internal state of sequence group to reset scheduler tokens and update currently processed ones
sequence_group->finish_iteration();
// decrease sequence_group context in case of candidates generated by draft_model were not accepted by main_model
if (sequence_group_sampling_info.max_removed_tokens_per_request) {
auto min_processed_tokens = sequence_group->get_prompt_len() + sequence_group_sampling_info.min_generated_len - 1;
sequence_group->update_processed_tokens_num(min_processed_tokens);
auto& logit_processor = m_logit_processors.at(sequence_group->get_request_id());
logit_processor.update_generated_len(min_processed_tokens);
}
} else {
// update internal state of sequence group to reset scheduler tokens and update currently processed ones
sequence_group->finish_iteration();
}
});

// Merge sampler outputs from the map into the main sampler output
SamplerOutput sampler_output;
for (const auto& entry : sequence_group_sampler_outputs) {
const auto& group_output = entry.second;
sampler_output.m_dropped_sequences.insert(
sampler_output.m_dropped_sequences.end(),
group_output.m_dropped_sequences.begin(),
group_output.m_dropped_sequences.end()
);

for (const auto& forked_seq : group_output.m_forked_sequences) {
sampler_output.m_forked_sequences[forked_seq.first].insert(
sampler_output.m_forked_sequences[forked_seq.first].end(),
forked_seq.second.begin(),
forked_seq.second.end()
);
}
}

return sampler_output;
}

Expand Down

0 comments on commit 89b54bd

Please sign in to comment.