From 5dc6b01a95b0b4b624e7751f8ad7d9400dd61e34 Mon Sep 17 00:00:00 2001
From: Jiayi Ni
Date: Mon, 21 Aug 2023 13:02:26 +0800
Subject: [PATCH] use restful client

---
 xinference/deploy/cmdline.py | 17 ++++++-----------
 1 file changed, 6 insertions(+), 11 deletions(-)

diff --git a/xinference/deploy/cmdline.py b/xinference/deploy/cmdline.py
index 6b67c67c61..ce84010b9a 100644
--- a/xinference/deploy/cmdline.py
+++ b/xinference/deploy/cmdline.py
@@ -24,7 +24,6 @@
 
 from .. import __version__
 from ..client import (
-    Client,
     RESTfulChatglmCppChatModelHandle,
     RESTfulChatModelHandle,
     RESTfulClient,
@@ -354,9 +353,7 @@ def model_generate(
 ):
     endpoint = get_endpoint(endpoint)
     if stream:
-        # TODO: when stream=True, RestfulClient cannot generate words one by one.
-        # So use Client in temporary. The implementation needs to be changed to
-        # RestfulClient in the future.
+
         async def generate_internal():
             while True:
                 # the prompt will be written to stdout.
@@ -365,7 +362,7 @@ async def generate_internal():
                 if prompt == "":
                     break
                 print(f"Completion: {prompt}", end="", file=sys.stdout)
-                async for chunk in model.generate(
+                for chunk in model.generate(
                     prompt=prompt,
                     generate_config={"stream": stream, "max_tokens": max_tokens},
                 ):
@@ -376,7 +373,7 @@ async def generate_internal():
                         print(choice["text"], end="", flush=True, file=sys.stdout)
                 print("\n", file=sys.stdout)
 
-        client = Client(endpoint=endpoint)
+        client = RESTfulClient(base_url=endpoint)
         model = client.get_model(model_uid=model_uid)
 
         loop = asyncio.get_event_loop()
@@ -436,9 +433,7 @@ def model_chat(
     endpoint = get_endpoint(endpoint)
     chat_history: "List[ChatCompletionMessage]" = []
     if stream:
-        # TODO: when stream=True, RestfulClient cannot generate words one by one.
-        # So use Client in temporary. The implementation needs to be changed to
-        # RestfulClient in the future.
+
        async def chat_internal():
             while True:
                 # the prompt will be written to stdout.
@@ -449,7 +444,7 @@ async def chat_internal():
                 chat_history.append(ChatCompletionMessage(role="user", content=prompt))
                 print("Assistant: ", end="", file=sys.stdout)
                 response_content = ""
-                async for chunk in model.chat(
+                for chunk in model.chat(
                     prompt=prompt,
                     chat_history=chat_history,
                     generate_config={"stream": stream, "max_tokens": max_tokens},
@@ -465,7 +460,7 @@ async def chat_internal():
                     ChatCompletionMessage(role="assistant", content=response_content)
                 )
 
-        client = Client(endpoint=endpoint)
+        client = RESTfulClient(base_url=endpoint)
         model = client.get_model(model_uid=model_uid)

         loop = asyncio.get_event_loop()
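
For reference, the RESTfulClient streaming usage this patch switches to looks roughly like the sketch below. The endpoint URL and model_uid are placeholder values, and only calls visible in the patch itself (RESTfulClient, get_model, model.generate with a streaming generate_config) are used:

    from xinference.client import RESTfulClient

    # Placeholder endpoint and model_uid; substitute your own.
    client = RESTfulClient(base_url="http://127.0.0.1:9997")
    model = client.get_model(model_uid="my-model-uid")

    # With stream=True the RESTful handle yields completion chunks from a
    # plain synchronous iterator, which is why the patch replaces
    # `async for` with `for` in the CLI loops.
    for chunk in model.generate(
        prompt="Hello",
        generate_config={"stream": True, "max_tokens": 64},
    ):
        choice = chunk["choices"][0]
        if "text" in choice:
            print(choice["text"], end="", flush=True)
    print()

Because the RESTful handle streams synchronously, the remaining asyncio scaffolding in cmdline.py (the async wrappers and event loop) is no longer strictly required for streaming, though this patch keeps it in place.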