Commit 298f246: added usage statistics into response
guocuimi committed Jan 4, 2024
1 parent e551884

Showing 10 changed files with 208 additions and 169 deletions.
4 changes: 2 additions & 2 deletions src/engine/utils.cpp
@@ -93,7 +93,7 @@ void Utils::prepare_inputs(const std::vector<Sequence*>& batch,
auto& ids = token_ids_vec.emplace_back();
auto& counts = token_counts_vec.emplace_back();

const auto& seq_token_counts = sequence->token_counts();
const auto& seq_token_counts = sequence->token_to_count_map();
const auto unique_tokens = seq_token_counts.size();
ids.reserve(unique_tokens);
counts.reserve(unique_tokens);
@@ -136,7 +136,7 @@ void Utils::prepare_inputs(const std::vector<Sequence*>& batch,
auto& ids = token_ids_vec.emplace_back();
auto& counts = token_counts_vec.emplace_back();

const auto& seq_token_counts = sequence->token_counts();
const auto& seq_token_counts = sequence->token_to_count_map();
const auto unique_tokens = seq_token_counts.size();
ids.reserve(unique_tokens);
counts.reserve(unique_tokens);
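
The only functional change in this file is the rename of the accessor from token_counts() to token_to_count_map(). For context, a minimal sketch (not part of the commit; the free-standing helper name is hypothetical) of how such a token-id-to-count map is flattened into the parallel id/count vectors that prepare_inputs builds:

#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Illustrative sketch, not part of the commit: flatten a per-sequence
// token-id -> occurrence-count map into parallel id/count vectors.
void flatten_token_counts(
    const std::unordered_map<int32_t, int32_t>& token_to_count_map,
    std::vector<int32_t>& ids,
    std::vector<int32_t>& counts) {
  const size_t unique_tokens = token_to_count_map.size();
  ids.reserve(unique_tokens);
  counts.reserve(unique_tokens);
  for (const auto& [token_id, count] : token_to_count_map) {
    ids.push_back(token_id);
    counts.push_back(count);
  }
}
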
28 changes: 15 additions & 13 deletions src/request/request.cpp
@@ -1,26 +1,28 @@
#include "request.h"

#include <absl/time/clock.h>
#include <absl/time/time.h>

#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "sampling_parameter.h"
#include "common/logging.h"
#include "sequence.h"
#include "status.h"
#include "stopping_criteria.h"

namespace llm {

void Request::add_sequence(std::string prompt,
std::vector<int32_t> token_ids,
OnStream on_stream) {
sequences.emplace_back(std::move(prompt),
std::move(token_ids),
&sampling_param,
&stopping_criteria,
on_stream,
echo);
Request::Request(const std::string& id,
const std::vector<int32_t>& prompt_tokens)
: id(id),
created_time(absl::ToUnixSeconds(absl::Now())),
prompt_tokens(prompt_tokens) {}

void Request::add_sequence(OnStream on_stream) {
if (stream) {
GCHECK(on_stream) << "on_stream should not be null if stream is true";
}
sequences.emplace_back(*this, on_stream);
}

bool Request::is_finished() const {
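
A minimal construction sketch (not part of the commit), showing how the new constructor and the simplified add_sequence() might be used by a caller. It assumes the Request members referenced elsewhere in this commit (stream, echo, stopping_criteria); the helper function and the concrete values are hypothetical:

#include <memory>
#include <string>
#include <vector>

#include "request.h"

// Illustrative sketch, not part of the commit: build a request with n
// completion choices; each Sequence pulls its parameters from the Request.
std::unique_ptr<llm::Request> build_request(
    const std::string& request_id,
    const std::vector<int32_t>& prompt_tokens,
    size_t num_choices) {
  auto request = std::make_unique<llm::Request>(request_id, prompt_tokens);
  request->stream = false;  // non-streaming, so on_stream can stay null
  request->echo = false;    // do not echo the prompt back in the output
  request->stopping_criteria.max_tokens = 128;  // hypothetical value
  for (size_t i = 0; i < num_choices; ++i) {
    request->add_sequence();  // uses the default on_stream = nullptr
  }
  return request;
}
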
58 changes: 41 additions & 17 deletions src/request/request.h
@@ -2,7 +2,6 @@

#include <cstdint>
#include <deque>
#include <memory>
#include <string>
#include <vector>

@@ -37,36 +36,60 @@ enum class ScheduleStatus {
CANCELLED,
};

struct Statistics {
// the number of tokens in the prompt.
size_t num_prompt_tokens = 0;
// the number of tokens in the generated completion.
size_t num_generated_tokens = 0;
// the total number of tokens used in the request (prompt + completion).
size_t num_total_tokens = 0;
};

// Priority of the request.
// The higher the priority, the sooner the request is processed.
enum class RequestPriority { HIGH = 0, MEDIUM, LOW };

using OnFinish = std::function<bool(const std::string& output_text,
FinishReason finish_reason,
const Status& status)>;
struct SequenceResult {
std::string output_text;

FinishReason finish_reason;
};

// Function to call when a request is finished.
using OnFinish =
std::function<bool(const std::vector<SequenceResult>& seq_results,
const Status& status,
const Statistics& stats)>;

// Function to call when a stream request is finished.
using OnStreamFinish = std::function<bool(const Status& status)>;

// A request is a data structure that encapsulates all the necessary
// information required to process a request efficiently. It acts as a
// container, holding essential data, such as input parameters, configuration
// settings, and any additional context-specific details related to the
// request's handling.
struct Request {
struct Request final {
public:
Request(std::string id) : id(std::move(id)){};
Request(const std::string& id, const std::vector<int32_t>& prompt_tokens);

void add_sequence(std::string prompt,
std::vector<int32_t> token_ids,
OnStream on_stream);
void add_sequence(OnStream on_stream = nullptr);

bool is_finished() const;

size_t num_prompt_tokens() const { return prompt_tokens.size(); }

// The unique id of the request.
// NOLINTNEXTLINE
const std::string id;

// list of sequences to generate completions for the prompt
// use deque instead of vector to avoid no-copy move for Sequence
std::deque<Sequence> sequences;
// Scheduled time of the request.
// NOLINTNEXTLINE
const int64_t created_time;

// the token ids from request's prompt.
// NOLINTNEXTLINE
const std::vector<int32_t> prompt_tokens;

// sampling parameters
SamplingParameter sampling_param;
@@ -80,17 +103,18 @@ struct Request {
// Whether to echo back the prompt in the output.
bool echo = true;

// The status of the request.
// ScheduleStatus status = ScheduleStatus::WAITING;

// the priority of the request.
RequestPriority priority = RequestPriority::MEDIUM;

// Scheduled time of the request.
int64_t created_time = 0;
// list of sequences to generate completions for the prompt
// use deque instead of vector to avoid no-copy move for Sequence
std::deque<Sequence> sequences;

// function to call when the request is finished.
OnFinish on_finish;

// function to call when a stream request is finished.
OnStreamFinish on_stream_finish;
};

// Compare two request contexts based on priority then scheduled time.
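
The Statistics struct, SequenceResult, and the widened OnFinish signature are the core of this commit. A sketch (not part of the commit) of a callback that surfaces the per-request usage counters; printing stands in for whatever response object the server actually fills:

#include <iostream>
#include <vector>

#include "request.h"

// Illustrative sketch, not part of the commit: install an OnFinish callback
// that reports the generated text and the new usage statistics.
void install_on_finish(llm::Request& request) {
  request.on_finish = [](const std::vector<llm::SequenceResult>& seq_results,
                         const llm::Status& status,
                         const llm::Statistics& stats) -> bool {
    if (!status.ok()) {
      return false;
    }
    for (const auto& result : seq_results) {
      std::cout << result.output_text << "\n";
    }
    // these counters map onto an OpenAI-style usage block:
    // prompt_tokens / completion_tokens / total_tokens
    std::cout << "prompt_tokens: " << stats.num_prompt_tokens
              << ", completion_tokens: " << stats.num_generated_tokens
              << ", total_tokens: " << stats.num_total_tokens << "\n";
    return true;
  };
}
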
49 changes: 24 additions & 25 deletions src/request/sequence.cpp
@@ -6,39 +6,33 @@
#include <string>
#include <vector>

#include "request.h"
#include "tokenizer/tokenizer.h"

namespace llm {

// NOLINTNEXTLINE
std::atomic<int64_t> Sequence::next_id_{1};

Sequence::Sequence(std::string prompt,
std::vector<int32_t> token_ids,
const SamplingParameter* sampling_param,
const StoppingCriteria* stopping_criteria,
OnStream on_stream,
bool echo)
: id_(next_id_.fetch_add(1)),
prompt_(std::move(prompt)),
token_ids_(std::move(token_ids)),
num_prompt_tokens_(token_ids_.size()),
sampling_param_(sampling_param),
stopping_criteria_(stopping_criteria),
on_stream_(on_stream) {
Sequence::Sequence(const Request& request, OnStream on_stream)
: id_(next_id_.fetch_add(1)), request_(request), on_stream_(on_stream) {
const auto& prompt_tokens = request_.prompt_tokens;
// reserve enough space for the token ids to avoid reallocation
// so that the token ids are not invalidated
const size_t max_tokens = stopping_criteria_->max_tokens;
const size_t max_tokens = request_.stopping_criteria.max_tokens;
token_ids_.reserve(max_tokens + token_ids_.size());
token_ids_ = prompt_tokens;
num_prompt_tokens_ = prompt_tokens.size();

// if echo is true, set prefix_offset_ and output_offset_ to 0 to print the
// whole sequence, otherwise set them to the length of the prompt to skip the
// prompt.
prefix_offset_ = echo ? 0 : token_ids_.size();
output_offset_ = echo ? 0 : token_ids_.size();
prefix_offset_ = request_.echo ? 0 : token_ids_.size();
output_offset_ = request_.echo ? 0 : token_ids_.size();

// calculate the token counts
for (const int32_t token_id : token_ids_) {
token_counts_[token_id]++;
token_to_count_map_[token_id]++;
}
}

@@ -47,15 +41,17 @@ bool Sequence::append_new_token_id(int32_t next_token_id) {
return false;
}

const auto& stopping_criteria = request_.stopping_criteria;

// check eos and stop tokens ids first
if (!stopping_criteria_->ignore_eos_token &&
next_token_id == stopping_criteria_->eos_token_id) {
if (!stopping_criteria.ignore_eos_token &&
next_token_id == stopping_criteria.eos_token_id) {
finish_reason_ = FinishReason::STOP;
is_finished_ = true;
return false;
}
// check against stop tokens ids
if (stopping_criteria_->stop_token_ids.count(next_token_id) > 0) {
if (stopping_criteria.stop_token_ids.count(next_token_id) > 0) {
finish_reason_ = FinishReason::STOP;
is_finished_ = true;
return false;
@@ -64,12 +60,11 @@ bool Sequence::append_new_token_id(int32_t next_token_id) {
// all tokens before pos should be processed and cached.
cache_pos_ = token_ids_.size();
token_ids_.push_back(next_token_id);
token_counts_[next_token_id]++;
token_to_count_map_[next_token_id]++;

// check against max tokens
const size_t generated_tokens = token_ids_.size() - num_prompt_tokens_;
const size_t max_new_tokens = stopping_criteria_->max_tokens;
if (max_new_tokens > 0 && generated_tokens >= max_new_tokens) {
const size_t max_new_tokens = stopping_criteria.max_tokens;
if (max_new_tokens > 0 && num_generated_tokens() >= max_new_tokens) {
finish_reason_ = FinishReason::LENGTH;
is_finished_ = true;
return false;
@@ -96,4 +91,8 @@ std::string Sequence::decode_delta_text(size_t end,
return "";
}

const SamplingParameter& Sequence::sampling_param() const {
return request_.sampling_param;
}

} // namespace llm
37 changes: 15 additions & 22 deletions src/request/sequence.h
@@ -11,6 +11,7 @@
#include "tokenizer/tokenizer.h"

namespace llm {
struct Request;

// "stop" - the model hit a natural stop point or a provided stop sequence.
// "length" - the maximum number of tokens specified in the request was reached.
@@ -30,22 +31,17 @@ using OnStream =
// current position in generating tokens, etc.
class Sequence final {
public:
Sequence(std::string prompt,
std::vector<int32_t> token_ids,
const SamplingParameter* sampling_param,
const StoppingCriteria* stopping_criteria,
OnStream on_stream,
bool echo);
Sequence(const Request& request, OnStream on_stream);

// get the id of the sequence
int64_t id() const { return id_; }

// get token ids
const std::vector<int32_t>& token_ids() const { return token_ids_; }

// get token ids count
const std::unordered_map<int32_t, int32_t>& token_counts() const {
return token_counts_;
// get token ids to count map
const std::unordered_map<int32_t, int32_t>& token_to_count_map() const {
return token_to_count_map_;
}

// get the total number of tokens
@@ -54,11 +50,13 @@ class Sequence final {
// get the number of prompt tokens
size_t num_prompt_tokens() const { return num_prompt_tokens_; }

// get the prompt string
const std::string& prompt() const { return prompt_; }
// get the number of generated tokens
size_t num_generated_tokens() const {
return num_tokens() - num_prompt_tokens();
}

// get the sampling parameters
const SamplingParameter& sampling_param() const { return *sampling_param_; }
const SamplingParameter& sampling_param() const;

// whether the sequence is in prefill stage, no kv cache has been generated
bool is_prefill() const { return cache_pos_ == 0; }
@@ -125,25 +123,20 @@
}

// global unique id for the sequence
int64_t id_ = 0;
const int64_t id_;

// prompt to generate completions for
std::string prompt_;
// The request that the sequence belongs to.
const Request& request_;

// token ids generated from p
// token ids generated for the sequence
std::vector<int32_t> token_ids_;

// the count of each token id
std::unordered_map<int32_t, int32_t> token_counts_;
std::unordered_map<int32_t, int32_t> token_to_count_map_;

// the length of the prompt tokens
size_t num_prompt_tokens_ = 0;

// sampling parameters
const SamplingParameter* sampling_param_ = nullptr;

const StoppingCriteria* stopping_criteria_ = nullptr;

// the cache position.
// all tokens before pos should be processed and cached.
size_t cache_pos_ = 0;
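
With the new num_prompt_tokens()/num_generated_tokens() accessors, the engine can aggregate the Statistics it hands to OnFinish. A sketch (not part of the commit; the free function is hypothetical) of how those counters might be collected from a finished request:

#include "request.h"
#include "sequence.h"

// Illustrative sketch, not part of the commit: aggregate usage statistics
// across all sequences of a request once generation has finished.
llm::Statistics collect_statistics(const llm::Request& request) {
  llm::Statistics stats;
  stats.num_prompt_tokens = request.num_prompt_tokens();
  for (const auto& sequence : request.sequences) {
    stats.num_generated_tokens += sequence.num_generated_tokens();
  }
  stats.num_total_tokens = stats.num_prompt_tokens + stats.num_generated_tokens;
  return stats;
}
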
13 changes: 7 additions & 6 deletions src/request/status.h
@@ -1,6 +1,7 @@
#pragma once

#include <cstdint>
#include <ostream>
#include <string>

namespace llm {
@@ -36,17 +37,17 @@ class Status final {
StatusCode error_code() const { return code_; }
const std::string& error_msg() const { return msg_; }

bool ok() const { return code_ == StatusCode::OK;}
bool ok() const { return code_ == StatusCode::OK; }

private:
StatusCode code_ = StatusCode::OK;
std::string msg_;
};

// inline std::ostream& operator<<(std::ostream& os, const Status& status) {
// os << "Status, code: " << status.error_code() << ", message: " << status.error_msg();
// return os;
// }

inline std::ostream& operator<<(std::ostream& os, const Status& status) {
os << "Status, code: " << static_cast<uint8_t>(status.error_code())
<< ", message: " << status.error_msg();
return os;
}

} // namespace llm
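
The previously commented-out stream operator for Status is now enabled. A small usage sketch (not part of the commit; the surrounding function is hypothetical):

#include <iostream>

#include "status.h"

// Illustrative sketch, not part of the commit: log a failed status using the
// newly enabled operator<<, which formats the code and the error message.
void report_failure(const llm::Status& status) {
  if (!status.ok()) {
    std::cerr << status << "\n";
  }
}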