Added n support for non-stream requests (unoptimized implementation)
guocuimi committed Jan 4, 2024
1 parent 298f246 commit 92ad14a
Showing 2 changed files with 177 additions and 116 deletions.
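The user-visible effect, sketched as a hypothetical client-side request (the ChatRequest type name is an assumption; the setters mirror the has_n/n and has_stream/stream getters used by the handler below):

// Hypothetical sketch, not part of this commit.
ChatRequest req;
req.set_stream(false);  // non-stream path: n > 1 is now accepted
req.set_n(3);           // ask for three completions in one request
// The server replies with a single ChatResponse whose repeated "choices"
// field carries three entries, each with its own index and finish_reason.
// With stream=true, the same request is still rejected as UNIMPLEMENTED.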
157 changes: 96 additions & 61 deletions src/server/handlers/chat_handler.cpp
@@ -31,10 +31,16 @@ bool verify_request_arguments(ChatCallData* call_data) {
return false;
}

// n and best_of are not implemented yet
if (request.has_n() && request.n() > 1) {
call_data->finish_with_error(grpc::StatusCode::UNIMPLEMENTED,
"n > 1 is not supported yet");
const bool stream = request.has_stream() ? request.stream() : false;
// n is not implemented for stream requests.
// N.B. a gRPC stream requires messages to be written in sequence to
// guarantee ordering. Supporting n > 1 for stream requests would require
// extra support in call_data to serialize the writes explicitly and work
// around this limitation.
if (stream && request.has_n() && request.n() > 1) {
call_data->finish_with_error(
grpc::StatusCode::UNIMPLEMENTED,
"n != 1 is not supported yet for stream requests");
return false;
}
// temperature between [0.0, 2.0]
@@ -76,6 +82,70 @@ bool verify_request_arguments(ChatCallData* call_data) {
return true;
}

bool send_delta_to_client(ChatCallData* call_data,
Request* request,
uint32_t index,
bool first_message,
const std::string& delta,
FinishReason reason) {
ChatResponse response;
response.set_object("chat.completion.chunk");
response.set_id(request->id);
response.set_created(request->created_time);
// response.set_model(request->model);
auto* choice = response.add_choices();
choice->set_index(index);
// add message
auto* message = choice->mutable_delta();
// only set role for first message
if (first_message) {
message->set_role("assistant");
}
message->set_content(delta);
if (reason != FinishReason::NONE) {
choice->set_finish_reason(finish_reason_to_string(reason));
}
return call_data->write(response);
}
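// finish_reason_to_string() is used above but not defined in this diff; a
// minimal sketch of such a helper, assuming OpenAI-style reason strings
// (the enumerators other than FinishReason::NONE are assumptions, not part
// of this commit):
std::string finish_reason_to_string(FinishReason reason) {
  switch (reason) {
    case FinishReason::STOP:
      return "stop";    // a stop token or stop sequence was generated
    case FinishReason::LENGTH:
      return "length";  // max_tokens or the context window was reached
    default:
      return "";        // FinishReason::NONE: no reason recorded
  }
}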

bool send_result_to_client(ChatCallData* call_data,
Request* request,
const std::vector<SequenceResult>& seq_results,
const Status& /*status*/,
const Statistics& stats) {
ChatResponse response;
response.set_object("chat.completion");
response.set_id(request->id);
response.set_created(request->created_time);
// response.set_model(request->model);

// add choices into response
for (uint32_t i = 0; i < seq_results.size(); ++i) {
const auto& seq_result = seq_results[i];
auto* choice = response.add_choices();
choice->set_index(i);
auto* message = choice->mutable_message();
message->set_role("assistant");
message->set_content(seq_result.output_text);
if (seq_result.finish_reason != FinishReason::NONE) {
choice->set_finish_reason(
finish_reason_to_string(seq_result.finish_reason));
}
}

// add usage statistics
auto* usage = response.mutable_usage();
usage->set_prompt_tokens(static_cast<int32_t>(stats.num_prompt_tokens));
usage->set_completion_tokens(
static_cast<int32_t>(stats.num_generated_tokens));
usage->set_total_tokens(static_cast<int32_t>(stats.num_total_tokens));

// TODO: combine write and finish
call_data->write(response);
// TODO: mapping status to grpc status
return call_data->finish();
}
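// The usage block above implies roughly this shape for Statistics; this is
// an inference from the field accesses, not the actual definition (types
// and extra members may differ):
struct Statistics {
  size_t num_prompt_tokens = 0;     // tokens in the tokenized prompt
  size_t num_generated_tokens = 0;  // tokens generated across all sequences
  size_t num_total_tokens = 0;      // expected: prompt + generated tokens
};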

// NOLINTNEXTLINE(readability-function-cognitive-complexity)
std::unique_ptr<Request> grpc_request_to_request(ChatCallData* call_data,
const Tokenizer& tokenizer,
@@ -180,72 +250,37 @@ std::unique_ptr<Request> grpc_request_to_request(ChatCallData* call_data,
request->echo = false;

// add on_stream and on_finish callbacks
const uint32_t num_seqs = grpc_request.has_n() ? grpc_request.n() : 1;
if (request->stream) {
auto on_stream = [call_data, request = request.get(), first_message = true](
const std::string& delta,
FinishReason reason) mutable -> bool {
ChatResponse response;
response.set_object("chat.completion.chunk");
response.set_id(request->id);
response.set_created(request->created_time);
// response.set_model(request->model);
auto* choice = response.add_choices();
auto* message = choice->mutable_delta();
// only set role for first message
if (first_message) {
message->set_role("assistant");
first_message = false;
}
message->set_content(delta);
choice->set_index(0);
if (reason != FinishReason::NONE) {
choice->set_finish_reason(finish_reason_to_string(reason));
}
return call_data->write(response);
};

request->add_sequence(on_stream);
// add sequences with on_stream callback
for (uint32_t i = 0; i < num_seqs; ++i) {
request->add_sequence(
[call_data, request = request.get(), i, first_message = true](
const std::string& delta, FinishReason reason) mutable -> bool {
bool ret = send_delta_to_client(
call_data, request, i, first_message, delta, reason);
first_message = false;
return ret;
});
}

// set callback for stream request
request->on_stream_finish = [call_data](const Status& /*status*/) -> bool {
return call_data->finish();
};
} else {
request->add_sequence();
// add sequences
for (uint32_t i = 0; i < num_seqs; ++i) {
request->add_sequence();
}

// set callback for non-stream request
request->on_finish = [call_data, request = request.get()](
const std::vector<SequenceResult>& seq_results,
const Status& /*status*/,
const Status& status,
const Statistics& stats) -> bool {
ChatResponse response;
response.set_object("chat.completion");
response.set_id(request->id);
response.set_created(request->created_time);
// response.set_model(request->model);

// add choices into response
for (uint32_t i = 0; i < seq_results.size(); ++i) {
const auto& seq_result = seq_results[i];
auto* choice = response.add_choices();
choice->set_index(i);
auto* message = choice->mutable_message();
message->set_role("assistant");
message->set_content(seq_result.output_text);
if (seq_result.finish_reason != FinishReason::NONE) {
choice->set_finish_reason(
finish_reason_to_string(seq_result.finish_reason));
}
}

// add usage statistics
auto* usage = response.mutable_usage();
usage->set_prompt_tokens(static_cast<int32_t>(stats.num_prompt_tokens));
usage->set_completion_tokens(
static_cast<int32_t>(stats.num_generated_tokens));
usage->set_total_tokens(static_cast<int32_t>(stats.num_total_tokens));

// TODO: combine write and finish
call_data->write(response);
// TODO: mapping status to grpc status
return call_data->finish();
return send_result_to_client(
call_data, request, seq_results, status, stats);
};
}
return request;
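Taken together, the call sites in this file imply roughly the following callback shapes on Request. This is a reconstruction from usage for readability, not the actual declarations; the alias names and exact types are assumptions:

#include <functional>
#include <string>
#include <vector>

// Per-sequence streaming callback; returning false stops further writes.
using OnStream = std::function<bool(const std::string& delta, FinishReason reason)>;
// Fires once for a stream request after all sequences have finished.
using OnStreamFinish = std::function<bool(const Status& status)>;
// Fires once for a non-stream request with the final result of every sequence.
using OnFinish = std::function<bool(const std::vector<SequenceResult>& seq_results,
                                    const Status& status,
                                    const Statistics& stats)>;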
136 changes: 81 additions & 55 deletions src/server/handlers/completion_handler.cpp
@@ -23,15 +23,18 @@ std::string generate_request_id() {
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
bool verify_request_arguments(CompletionCallData* call_data) {
const auto& request = call_data->request();
// n and best_of are not implemented yet
if (request.has_n() && request.n() > 1) {
// n is not implemented yet for stream requests
const bool stream = request.has_stream() ? request.stream() : false;
const uint32_t n = request.has_n() ? request.n() : 1;
if (stream && n > 1) {
call_data->finish_with_error(
    grpc::StatusCode::UNIMPLEMENTED,
    "n > 1 is not supported yet for stream requests");
return false;
}
if (request.has_best_of() && request.best_of() > 1) {

if (request.has_best_of() && request.best_of() != n) {
call_data->finish_with_error(grpc::StatusCode::UNIMPLEMENTED,
"best_of > 1 is not supported yet");
"best_of != n is not supported yet");
return false;
}
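// Illustrative outcomes of the two checks above (a sketch, not part of this
// commit):
//   stream=false, n=3                 -> accepted; the response carries 3 choices
//   stream=true,  n=3                 -> rejected (UNIMPLEMENTED)
//   n=3, best_of=3 or best_of unset   -> accepted
//   n=3, best_of=5                    -> rejected: best_of != n not supported yet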

@@ -96,6 +99,61 @@ bool verify_request_arguments(CompletionCallData* call_data) {
return true;
}

bool send_delta_to_client(CompletionCallData* call_data,
Request* request,
uint32_t index,
const std::string& delta,
FinishReason reason) {
CompletionResponse response;
response.set_object("text_completion");
response.set_id(request->id);
response.set_created(request->created_time);
// response.set_model(request->model);
auto* choice = response.add_choices();
choice->set_index(index);
choice->set_text(delta);
if (reason != FinishReason::NONE) {
choice->set_finish_reason(finish_reason_to_string(reason));
}
return call_data->write(response);
}

bool send_result_to_client(CompletionCallData* call_data,
Request* request,
const std::vector<SequenceResult>& seq_results,
const Status& /*status*/,
const Statistics& stats) {
CompletionResponse response;
response.set_object("text_completion");
response.set_id(request->id);
response.set_created(request->created_time);
// response.set_model(request->model);

// add choices into response
for (uint32_t i = 0; i < seq_results.size(); ++i) {
const auto& seq_result = seq_results[i];
auto* choice = response.add_choices();
choice->set_index(i);
choice->set_text(seq_result.output_text);
// choice->set_logprobs(0);
if (seq_result.finish_reason != FinishReason::NONE) {
choice->set_finish_reason(
finish_reason_to_string(seq_result.finish_reason));
}
}

// add usage statistics
auto* usage = response.mutable_usage();
usage->set_prompt_tokens(static_cast<int32_t>(stats.num_prompt_tokens));
usage->set_completion_tokens(
static_cast<int32_t>(stats.num_generated_tokens));
usage->set_total_tokens(static_cast<int32_t>(stats.num_total_tokens));
// TODO: combine write and finish
call_data->write(response);
// TODO: mapping status to grpc status
return call_data->finish();
}
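// The loop above implies roughly this shape for SequenceResult; an inference
// from the field accesses, not the actual definition:
struct SequenceResult {
  std::string output_text;      // final decoded text for this sequence
  FinishReason finish_reason;   // FinishReason::NONE when no reason was recorded
};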

std::unique_ptr<Request> grpc_request_to_request(CompletionCallData* call_data,
const Tokenizer& tokenizer,
const ModelArgs& model_args) {
@@ -166,66 +224,34 @@ std::unique_ptr<Request> grpc_request_to_request(CompletionCallData* call_data,
}

// add on_stream and on_finish callbacks
const uint32_t num_seqs = grpc_request.has_n() ? grpc_request.n() : 1;
if (request->stream) {
auto on_stream = [call_data, request = request.get()](
const std::string& delta,
FinishReason reason) -> bool {
CompletionResponse response;
response.set_object("text_completion");
response.set_id(request->id);
response.set_created(request->created_time);
// response.set_model(request->model);
auto* choice = response.add_choices();
choice->set_text(delta);
// choice->set_logprobs(0);
choice->set_index(0);
if (reason != FinishReason::NONE) {
choice->set_finish_reason(finish_reason_to_string(reason));
}
return call_data->write(response);
};

request->add_sequence(on_stream);
// add sequences with on_stream callback
for (uint32_t i = 0; i < num_seqs; ++i) {
request->add_sequence(
[call_data, request = request.get(), i](const std::string& delta,
FinishReason reason) -> bool {
return send_delta_to_client(call_data, request, i, delta, reason);
});
}

// add on_stream_finish callback
request->on_stream_finish = [call_data](const Status& /*status*/) -> bool {
// TODO: mapping status to grpc status
return call_data->finish();
};
} else {
request->add_sequence();
// add sequences
for (uint32_t i = 0; i < num_seqs; ++i) {
request->add_sequence();
}

// add on_finish callback
request->on_finish = [call_data, request = request.get()](
const std::vector<SequenceResult>& seq_results,
const Status& /*status*/,
const Status& status,
const Statistics& stats) -> bool {
CompletionResponse response;
response.set_object("text_completion");
response.set_id(request->id);
response.set_created(request->created_time);
// response.set_model(request->model);

// add choices into response
for (uint32_t i = 0; i < seq_results.size(); ++i) {
const auto& seq_result = seq_results[i];
auto* choice = response.add_choices();
choice->set_index(i);
choice->set_text(seq_result.output_text);
// choice->set_logprobs(0);
if (seq_result.finish_reason != FinishReason::NONE) {
choice->set_finish_reason(
finish_reason_to_string(seq_result.finish_reason));
}
}

// add usage statistics
auto* usage = response.mutable_usage();
usage->set_prompt_tokens(static_cast<int32_t>(stats.num_prompt_tokens));
usage->set_completion_tokens(
static_cast<int32_t>(stats.num_generated_tokens));
usage->set_total_tokens(static_cast<int32_t>(stats.num_total_tokens));
// TODO: combine write and finish
call_data->write(response);
// TODO: mapping status to grpc status
return call_data->finish();
return send_result_to_client(
call_data, request, seq_results, status, stats);
};
}
return request;
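For a non-stream request with n > 1, the single CompletionResponse now carries one choice per sequence. A hypothetical client-side loop, assuming the protobuf-generated accessors that correspond to the setters used above (choices(), index(), text(), finish_reason(), usage()):

#include <iostream>

// Hypothetical sketch, not part of this commit.
void print_completion(const CompletionResponse& response) {
  for (const auto& choice : response.choices()) {
    std::cout << "choice " << choice.index() << " [" << choice.finish_reason()
              << "]: " << choice.text() << "\n";
  }
  std::cout << "total tokens: " << response.usage().total_tokens() << "\n";
}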
