
Commit

use restful client
jiayini1119 committed Aug 21, 2023
1 parent 302eca9 · commit 5dc6b01
Showing 1 changed file with 6 additions and 11 deletions.
17 changes: 6 additions & 11 deletions xinference/deploy/cmdline.py
@@ -24,7 +24,6 @@
 
 from .. import __version__
 from ..client import (
-    Client,
     RESTfulChatglmCppChatModelHandle,
     RESTfulChatModelHandle,
     RESTfulClient,
@@ -354,9 +353,7 @@ def model_generate(
 ):
     endpoint = get_endpoint(endpoint)
     if stream:
-        # TODO: when stream=True, RestfulClient cannot generate words one by one.
-        # So use Client in temporary. The implementation needs to be changed to
-        # RestfulClient in the future.
+
         async def generate_internal():
             while True:
                 # the prompt will be written to stdout.
@@ -365,7 +362,7 @@ async def generate_internal():
                 if prompt == "":
                     break
                 print(f"Completion: {prompt}", end="", file=sys.stdout)
-                async for chunk in model.generate(
+                for chunk in model.generate(
                     prompt=prompt,
                     generate_config={"stream": stream, "max_tokens": max_tokens},
                 ):
@@ -376,7 +373,7 @@ async def generate_internal():
                     print(choice["text"], end="", flush=True, file=sys.stdout)
                 print("\n", file=sys.stdout)
 
-        client = Client(endpoint=endpoint)
+        client = RESTfulClient(base_url=endpoint)
         model = client.get_model(model_uid=model_uid)
 
         loop = asyncio.get_event_loop()
@@ -436,9 +433,7 @@ def model_chat(
     endpoint = get_endpoint(endpoint)
     chat_history: "List[ChatCompletionMessage]" = []
     if stream:
-        # TODO: when stream=True, RestfulClient cannot generate words one by one.
-        # So use Client in temporary. The implementation needs to be changed to
-        # RestfulClient in the future.
+
         async def chat_internal():
             while True:
                 # the prompt will be written to stdout.
@@ -449,7 +444,7 @@ async def chat_internal():
                 chat_history.append(ChatCompletionMessage(role="user", content=prompt))
                 print("Assistant: ", end="", file=sys.stdout)
                 response_content = ""
-                async for chunk in model.chat(
+                for chunk in model.chat(
                     prompt=prompt,
                     chat_history=chat_history,
                     generate_config={"stream": stream, "max_tokens": max_tokens},
@@ -465,7 +460,7 @@ async def chat_internal():
                     ChatCompletionMessage(role="assistant", content=response_content)
                 )
 
-        client = Client(endpoint=endpoint)
+        client = RESTfulClient(base_url=endpoint)
         model = client.get_model(model_uid=model_uid)
 
         loop = asyncio.get_event_loop()
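The net effect of the commit: the CLI now talks to the server over HTTP via RESTfulClient, whose streaming responses arrive as a plain synchronous iterator, which is why `async for` becomes `for`. Below is a minimal sketch of the new generate pattern outside the CLI. It assumes a running Xinference server; the base URL and model UID are placeholders, while RESTfulClient, get_model, and the choices/"text" chunk shape come from the diff above.

# Minimal sketch of the streaming pattern this commit adopts (not the CLI itself).
# Assumptions: a running Xinference server at the URL below, and a model already
# launched under the hypothetical UID "my-model".
from xinference.client import RESTfulClient

client = RESTfulClient(base_url="http://127.0.0.1:9997")  # placeholder endpoint
model = client.get_model(model_uid="my-model")  # hypothetical model UID

# The RESTful client yields completion chunks from a plain (synchronous)
# iterator when stream=True, hence `for` instead of the old `async for`.
for chunk in model.generate(
    prompt="Hello",
    generate_config={"stream": True, "max_tokens": 64},
):
    choice = chunk["choices"][0]
    if "text" in choice:
        print(choice["text"], end="", flush=True)
print()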
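The chat path follows the same shape, plus the running chat_history the CLI maintains. A companion sketch, with the same placeholder endpoint and UID; the message dicts mirror the ChatCompletionMessage(role=..., content=...) shape in the diff, and the delta/"content" chunk layout is an assumption based on the OpenAI-style streaming format, since the diff elides those lines.

# Companion sketch for streaming chat; endpoint, model UID, and the
# delta/"content" chunk layout are assumptions (see note above).
from xinference.client import RESTfulClient

client = RESTfulClient(base_url="http://127.0.0.1:9997")  # placeholder endpoint
model = client.get_model(model_uid="my-model")  # hypothetical model UID

chat_history = []
prompt = "Hi there"
chat_history.append({"role": "user", "content": prompt})

response_content = ""
for chunk in model.chat(
    prompt=prompt,
    chat_history=chat_history,
    generate_config={"stream": True, "max_tokens": 64},
):
    delta = chunk["choices"][0].get("delta", {})
    if "content" in delta:
        response_content += delta["content"]
        print(delta["content"], end="", flush=True)
print()

# Keep the history growing so the next turn has full context.
chat_history.append({"role": "assistant", "content": response_content})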
