Skip to content

Commit

Permalink
[bug fix] avoid returning stop tokens in response.
Browse files Browse the repository at this point in the history
  • Loading branch information
guocuimi committed Dec 3, 2023
1 parent 7f1679f commit bb329ec
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 30 deletions.
29 changes: 19 additions & 10 deletions src/request/sequence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,31 +34,40 @@ Sequence::Sequence(std::string prompt,
output_offset_ = echo ? 0 : token_ids_.size();
}

bool Sequence::check_stopping_creteria() {
bool Sequence::append_new_token_id(int32_t next_token_id) {
if (is_finished_) {
return true;
return false;
}
// check against stopping criteria
const size_t generated_tokens = token_ids_.size() - num_prompt_tokens_;
const size_t max_new_tokens = stopping_criteria_->max_tokens;
if (max_new_tokens > 0 && generated_tokens >= max_new_tokens) {
if (max_new_tokens > 0 && (generated_tokens + 1) >= max_new_tokens) {
// add the last token then mark the sequence as finished
cache_pos_ = token_ids_.size();
token_ids_.push_back(next_token_id);

finish_reason_ = FinishReason::LENGTH;
return is_finished_ = true;
is_finished_ = true;
return false;
}

const auto last_token_id = token_ids_.back();
if (!stopping_criteria_->ignore_eos_token &&
last_token_id == stopping_criteria_->eos_token_id) {
next_token_id == stopping_criteria_->eos_token_id) {
finish_reason_ = FinishReason::STOP;
return is_finished_ = true;
is_finished_ = true;
return false;
}
// check against stop token ids
if (stopping_criteria_->stop_token_ids.count(last_token_id) > 0) {
if (stopping_criteria_->stop_token_ids.count(next_token_id) > 0) {
finish_reason_ = FinishReason::STOP;
return is_finished_ = true;
is_finished_ = true;
return false;
}

return false;
// all tokens before pos should be processed and cached.
cache_pos_ = token_ids_.size();
token_ids_.push_back(next_token_id);
return true;
}

// decode the sequence to get delta text using the tokenizer
Expand Down
12 changes: 3 additions & 9 deletions src/request/sequence.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,9 @@ class Sequence final {
// whether the sequence is in prefill stage, no kv cache has been generated
bool is_prefill() const { return cache_pos_ == 0; }

// add a new token id to the sequence
void append_new_token_id(int token_id) {
// all tokens before pos should be processed and cached.
cache_pos_ = token_ids_.size();
token_ids_.push_back(token_id);
}
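// add a new token id to the sequence and check if the sequence is finished.
// returns false if the sequence is finished (including when it was already
// finished before the call); returns true if the token was appended and
// generation can continue.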
bool append_new_token_id(int32_t next_token_id);

// add new cache blocks
void append_blocks(const std::vector<int32_t>& new_blocks) {
Expand Down Expand Up @@ -98,9 +95,6 @@ class Sequence final {
// get the reason why the sequence is finished
FinishReason finish_reason() const { return finish_reason_; }

// check stopping criteria
bool check_stopping_creteria();

// decode the tokens till end to get delta text using the tokenizer
std::string decode_delta_text(size_t end, const Tokenizer& tokenizer);

Expand Down
8 changes: 3 additions & 5 deletions src/scheduler/continuous_batching_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,9 @@ void ContinuousBatchingScheduler::step(const absl::Duration& timeout) {
// process sequence in batch
for (int64_t i = 0; i < num_seqs; ++i) {
Sequence* seq = sequences_batch_[i];
// append new token id to the sequence
seq->append_new_token_id(static_cast<int>(new_token_ids[i]));

// check if the sequence is finished and update its status
seq->check_stopping_creteria();
const int32_t next_token_id = static_cast<int32_t>(new_token_ids[i]);
// add the next token to sequence and check if the sequence is finished
seq->append_new_token_id(next_token_id);

// stream delta to client if streaming is enabled
if (seq->is_streaming()) {
Expand Down
11 changes: 5 additions & 6 deletions src/server/simple.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,16 +114,15 @@ int main(int argc, char* argv[]) {
const auto flat_tensor = next_token.view({-1});

// add the next token to the list of tokens
const auto next_token_scalar = static_cast<int>(flat_tensor.item<int>());
sequence.append_new_token_id(next_token_scalar);
const auto next_token_id = static_cast<int>(flat_tensor.item<int>());
if (!sequence.append_new_token_id(next_token_id)) {
// sequence is finished
break;
}

// decode the output and print delta
std::cout << sequence.decode_delta_text(sequence.num_tokens(), *tokenizer)
<< std::flush;

if (sequence.check_stopping_creteria()) {
break;
}
}

// release the slots for the sequence
Expand Down

0 comments on commit bb329ec

Please sign in to comment.