Skip to content

Commit

Permalink
Merge pull request #3 from EagleChen/master
Browse files Browse the repository at this point in the history
feat: 支持stream效果
  • Loading branch information
yokonsan committed May 21, 2023
2 parents 4fadaff + df4f3c8 commit be00042
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
16 changes: 16 additions & 0 deletions claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from fastapi import FastAPI, Depends, Header, HTTPException, status
from pydantic import BaseModel
from fastapi.responses import StreamingResponse

from slack import client

Expand Down Expand Up @@ -33,6 +34,21 @@ async def chat(body: ClaudeChatPrompt):
"claude": await client.get_reply()
}

# add --no-buffer to see the effect of streaming
# curl -X 'POST' --no-buffer \
# 'http://127.0.0.1:8088/claude/stream_chat' \
# -H 'accept: text/plain' \
# -H 'Content-Type: application/json' \
# -d '{
# "prompt": "今天天气很不错吧"}'
@app.post("/claude/stream_chat", dependencies=[Depends(must_token)])
async def chat(body: ClaudeChatPrompt):
await client.open_channel()
await client.chat(body.prompt)

sr = client.get_stream_reply()

return StreamingResponse(sr, media_type="text/plain")

@app.post("/claude/reset", dependencies=[Depends(must_token)])
async def chat():
Expand Down
22 changes: 21 additions & 1 deletion slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ async def chat(self, text):

async def open_channel(self):
if not self.CHANNEL_ID:
print(111)
response = await self.conversations_open(users=CLAUDE_BOT_ID)
self.CHANNEL_ID = response["channel"]["id"]

Expand All @@ -43,6 +42,27 @@ async def get_reply(self):

raise Exception("Get replay timeout")

async def get_stream_reply(self):
l = 0
for _ in range(150):
try:
resp = await self.conversations_history(channel=self.CHANNEL_ID, oldest=self.LAST_TS, limit=2)
msg = [msg["text"] for msg in resp["messages"] if msg["user"] == CLAUDE_BOT_ID]
if msg:
last_msg = msg[-1]
more = False
if msg[-1].endswith("Typing…_"):
last_msg = str(msg[-1])[:-11] # remove typing…
more = True
diff = last_msg[l:]
l = len(last_msg)
yield diff
if not more:
break
except (SlackApiError, KeyError) as e:
print(f"Get reply error: {e}")

await asyncio.sleep(2)

client = SlackClient(token=getenv("SLACK_USER_TOKEN"))

Expand Down

0 comments on commit be00042

Please sign in to comment.