Skip to content

Commit

Permalink
Merge pull request #950 from lss233/sourcery/pull-949
Browse files Browse the repository at this point in the history
V3初步适配mirai,适配QQ频道私域机器人 (Sourcery refactored)
  • Loading branch information
Haibersut committed Jun 13, 2023
2 parents 9a99e05 + 2ac2fde commit a2f16cd
Show file tree
Hide file tree
Showing 8 changed files with 228 additions and 85 deletions.
4 changes: 4 additions & 0 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@ def setup_cloudflared(app: Quart):
from framework.platforms.discord_bot import start_task

bots.append(start_task())
if constants.config.qqchannel:
logger.info("检测到 QQChannel 配置,将启动 QQChannel 模式……")
from framework.platforms.qqchannel_bot import start_task

bots.append(start_task())

async def setup_web_service():
from framework.platforms.onebot_bot import bot, start_http_app
Expand Down
12 changes: 12 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,17 @@ class WecomBot(BaseModel):
description="企业微信应用 API 令牌 的 EncodingAESKey",
)

class QQChannel(BaseModel):
appid: str = Field(
title="Appid",
description="QQ Channel 的 App ID",
default=None
)
token: str = Field(
title="Token",
description="QQ Channel 的 Token",
default=None
)

class OpenAIGPT3Params(BaseModel):
temperature: float = Field(
Expand Down Expand Up @@ -876,6 +887,7 @@ class Config(BaseModel):
discord: Optional[DiscordBot] = None
http: Optional[HttpService] = HttpService()
wecom: Optional[WecomBot] = None
qqchannel: Optional[QQChannel] = None

# === Account Settings ===
accounts: AccountsModel = AccountsModel()
Expand Down
1 change: 1 addition & 0 deletions constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@


class BotPlatform(Enum):
QQChannelBot = "qq"
AriadneBot = "mirai"
DiscordBot = "discord"
Onebot = "onebot"
Expand Down
2 changes: 1 addition & 1 deletion framework/middlewares/baiducloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ async def handle_respond(self, request: Request, response: Response, _next: Call
if not config.baiducloud.check:
return await _next(request, response)
# 不处理没有文字的信息
if not response.body.has(Plain) or not response.text:
if not response.text:
return await _next(request, response)

try:
Expand Down
7 changes: 6 additions & 1 deletion framework/middlewares/ratelimit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@ async def handle_request(self, request: Request, response: Response, _next: Call

async def handle_respond_completed(self, request: Request, response: Response):
key = '好友' if request.session_id.startswith("friend-") else '群组'
msg_id = request.session_id.split('-', 1)[1]

if config.qqchannel:
msg_id = request.session_id
else:
msg_id = request.session_id.split('-', 1)[1]

manager.increment_usage(key, msg_id)
rate_usage = manager.check_exceed(key, msg_id)
if rate_usage >= config.ratelimit.warning_rate:
Expand Down
196 changes: 113 additions & 83 deletions framework/platforms/ariadne_bot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import functools
import time
from typing import Union
from typing import Optional, Union

from charset_normalizer import from_bytes
from graia.amnesia.builtins.aiohttp import AiohttpServerService
Expand Down Expand Up @@ -29,6 +30,9 @@
from framework.middlewares.ratelimit import manager as ratelimit_manager
from framework.universal import handle_message
from framework.utils.text_to_img import to_image
from framework.request import Request, Response
from framework.messages import ImageElement
from framework.tts.tts import TTSResponse

# Refer to https://graia.readthedocs.io/ariadne/quickstart/
if config.mirai.reverse_ws_port:
Expand All @@ -51,72 +55,94 @@
),
)

def response(target: Union[Friend, Group], source: Source):
async def respond(msg: AriadneBaseModel):
# 如果是非字符串
if not isinstance(msg, Plain) and not isinstance(msg, str):
event = await app.send_message(
target,
msg,
quote=source if config.response.quote else False
)
# 如果开启了强制转图片
elif config.text_to_image.always and not isinstance(msg, Voice):
event = await app.send_message(
target,
await to_image(str(msg)),
quote=source if config.response.quote else False
)

async def send_message_with_quote(target: Union[Friend, Group], msg: AriadneBaseModel,
quote: Optional[Source] = None):
return await app.send_message(
target,
msg,
quote=quote if config.response.quote else False
)


async def handle_message_failure(target: Union[Friend, Group], msg: MessageChain, quote: Optional[Source] = None):
logger.warning("原始消息发送失败,尝试通过图片发送")
new_elems = []
for elem in msg:
if not new_elems:
new_elems.append(elem)
elif isinstance(new_elems[-1], Plain) and isinstance(elem, Plain):
new_elems[-1].text = new_elems[-1].text + '\n' + elem.text
else:
new_elems.append(elem)
rendered_elems = []
for elem in new_elems:
if isinstance(elem, Plain):
rendered_elems.append(await to_image(elem))
else:
event = await app.send_message(
target,
msg,
quote=source if config.response.quote else False
)
if event.source.id < 0:
event = await app.send_message(
target,
MessageChain(
Forward(
[
ForwardNode(
target=config.mirai.qq,
name="ChatGPT",
message=msg,
time=datetime.datetime.now()
)
]
rendered_elems.append(elem)
return await send_message_with_quote(target, MessageChain(rendered_elems), quote)


async def respond(target: Union[Friend, Group], source: Source, msg: AriadneBaseModel):
# 如果是非字符串
if not isinstance(msg, Plain) and not isinstance(msg, str):
event = await send_message_with_quote(target, msg, source)
# 如果开启了强制转图片
elif config.text_to_image.always and not isinstance(msg, Voice):
event = await send_message_with_quote(target, await to_image(str(msg)), source)
else:
event = await send_message_with_quote(target, msg, source)

if event.source.id < 0:
logger.warning("原始消息发送失败,尝试通过转发发送")
event = await send_message_with_quote(target, MessageChain(
Forward(
[
ForwardNode(
target=config.mirai.qq,
name="ChatGPT",
message=msg,
time=datetime.datetime.now()
)
)
)
if event.source.id < 0:
await app.send_message(
target,
"消息发送失败,被TX吞了,尝试转成图片再试一次,请稍等",
quote=source if config.response.quote else False
)
new_elems = []
for elem in msg:
if not new_elems:
new_elems.append(elem)
elif isinstance(new_elems[-1], Plain) and isinstance(elem, Plain):
new_elems[-1].text = new_elems[-1].text + '\n' + elem.text
else:
new_elems.append(elem)
rendered_elems = []
for elem in new_elems:
if isinstance(elem, Plain):
rendered_elems.append(await to_image(elem))
else:
rendered_elems.append(elem)
event = await app.send_message(
target,
MessageChain(rendered_elems),
quote=source if config.response.quote else False
)
return event

return respond
])
))

if event.source.id < 0:
event = await handle_message_failure(target, msg, source)

return event


async def response(target: Union[Friend, Group], source: Source, chain: MessageChain = None, text: str = None,
voice: TTSResponse = None, image: ImageElement = None):
try:
if chain:
logger.debug(f"[Mirai] 尝试发送消息:{str(chain)}")
return await respond(target, source, chain)
elif text:
return await respond(target, source, MessageChain([Plain(text)]))
elif voice:
return await respond(target, source, Voice(voice))
elif image:
return await respond(target, source, MessageChain([image]))
else:
raise ValueError("没有为response函数提供有效输入")

except Exception as e:
logger.error(f"处理响应时发生错误: {e}")


def create_request(user_id, target_id, platform, is_manager, chain, nickname, session_prefix):
request = Request()
request.user_id = user_id
request.group_id = target_id
request.session_id = f"{session_prefix}-{target_id}"
request.message = chain
request.platform = platform
request.is_manager = is_manager
request.nickname = nickname
return request


FriendTrigger = DetectPrefix(config.trigger.prefix + config.trigger.prefix_friend)
Expand All @@ -136,14 +162,17 @@ async def friend_message_listener(app: Ariadne, target: Friend, source: Source,
if chain.display.startswith("."):
return

await handle_message(
response(target, source),
f"friend-{target.id}",
chain.display,
chain,
is_manager=target.id == config.mirai.manager_qq,
nickname=target.nickname
)
request = create_request(target.id, target.id, constants.BotPlatform.AriadneBot,
target.id == config.mirai.manager_qq,
chain, target.nickname, "friend")

respond_partial = functools.partial(response, target, source)
response_obj = Response(respond_partial)

try:
await handle_message(request, response_obj)
except Exception as e:
logger.exception(e)


GroupTrigger = Annotated[MessageChain, MentionMe(config.trigger.require_mention != "at"), DetectPrefix(
Expand All @@ -156,14 +185,17 @@ async def group_message_listener(target: Group, source: Source, chain: GroupTrig
if chain.display.startswith("."):
return

await handle_message(
response(target, source),
f"group-{target.id}",
chain.display,
chain,
is_manager=member.id == config.mirai.manager_qq,
nickname=member.name
)
request = create_request(member.id, target.id, constants.BotPlatform.AriadneBot,
member.id == config.mirai.manager_qq,
chain, member.name, "group")

respond_partial = functools.partial(response, target, source)
response_obj = Response(respond_partial)

try:
await handle_message(request, response_obj)
except Exception as e:
logger.exception(e)


@app.broadcast.receiver("NewFriendRequestEvent")
Expand Down Expand Up @@ -240,7 +272,6 @@ async def update_rate(app: Ariadne, event: MessageEvent, sender: Union[Friend, M
raise ExecutionStop()



@cmd.command(".查看 {msg_type: str} {msg_id: str} 的使用情况")
async def show_rate(app: Ariadne, event: MessageEvent, msg_type: str, msg_id: str):
try:
Expand Down Expand Up @@ -281,7 +312,6 @@ async def show_rate(app: Ariadne, event: MessageEvent, msg_type: str, msg_id: st
raise ExecutionStop()



@cmd.command(".预设列表")
async def presets_list(app: Ariadne, event: MessageEvent, sender: Union[Friend, Member]):
try:
Expand Down
Loading

0 comments on commit a2f16cd

Please sign in to comment.