zl程序教程

您现在的位置是:首页 >  前端

当前栏目

实现了一个基于 OneBot v11 的开发框架,聊聊其中的细节

框架开发 实现 一个 基于 细节 聊聊 其中
2023-06-13 09:17:12 时间

前言

写了一个基于 OneBot v11 的机器人开发框架,或者说 SDK,其中包含了蛮多东西,所以单独写篇文章聊聊。

整个项目说起来还是比较复杂,所以这里只捡出几个核心实现。这个框架只是用来学习原理,并没有投入生产环境的打算。当然如果你愿意或者喜欢本项目的风格,部署到生产环境是没问题的 :)

项目地址:https://github.com/kifuan/shirasu

使用方法

还是比较直观的,这里把 README 里的示例粘贴过来。

import asyncio
from shirasu import AddonPool, OneBotClient


if __name__ == '__main__':
    pool = AddonPool.from_modules(
        'shirasu.addons.echo',
        'shirasu.addons.help',
    )
    asyncio.run(OneBotClient.listen(pool=pool))

至于插件的定义方法下文会介绍,接下来在入口文件同级目录下创建 shirasu.yml 配置 WebSocket 地址。

# The WebSocket server URL(not reverse WebSocket).
ws: ws://127.0.0.1:8080

随后打开一个 OneBot v11 具体实现,如 go-cqhttp,以正向ws的方式进行连接。给机器人发 /echo hello 它就会正常回复一个 hello 了。

Task 隐患

根据 python/cpython#91887,使用 asyncio.create_task 创建的任务都只有一个弱引用。也就是说,如果你不手动将创建的任务储存起来,GC 可以把没执行的任务直接回收掉。所以就有了这个 issue,下方回复中给出可以使用 asyncio.TaskGroup 但是需要 Python 3.11+。我的开发环境是 3.10,所以手动处理一下这个问题。

解决方式很简单,用一个 set 把没完成的任务存起来就行,大致代码如下:

def __init__(self):
    self._tasks: set[asyncio.Task] = set()

def use(self):
    task = asyncio.create_task(...)
    self._tasks.add(task)
    task.add_done_callback(self._tasks.discard)

这样就可以保证没完成的任务不会被 GC 删除,扫除了一个隐患。

获取 WS 响应数据

使用 Python 中的 websockets 库,每次收到消息时都会使用 asyncio.create_task 创建一个 task 并行处理,所以这里就会有一个问题,就是当你发送之后并不能确定下一次收到的是否为当前调用结果,所以 OneBot 定义了一个 echo 字段,同一个请求和返回时这个字段将是相同的。

这里提一下,在 OneBot v11 中并没有规定 echo 的数据类型,而 v12 中规定是 string。至于具体实现,参考 go-cqhttp 源码发现传什么数据都无所谓。下面的 jgjson.Parse 返回的对象。

ret := c.apiCaller.Call(t, j.Get("params"))
if j.Get("echo").Exists() {
	ret["echo"] = j.Get("echo").Value()
}

但是最新标准规定要传字符串,那就用字符串了。

据此,我们需要创建一个工具类,利用 asyncio.Future 实现数据的延时获取,下方源码位于 shirasu/util/future_table.py

import sys
import asyncio
from typing import Any


class FutureTable:
    def __init__(self) -> None:
        self._future_id = 0
        self._futures: dict[int, asyncio.Future] = {}

    def register(self) -> int:
        self._future_id = (self._future_id + 1) % sys.maxsize
        self._futures[self._future_id] = asyncio.get_event_loop().create_future()
        return self._future_id

    def set(self, echo: int, data: dict[str, Any]) -> None:
        if future := self._futures.get(echo):
            future.set_result(data)

    async def get(self, future_id: int, timeout: float) -> dict[str, Any]:
        if not (future := self._futures.get(future_id)):
            raise KeyError(f'future id {future_id} does not exist')
        try:
            return await asyncio.wait_for(future, timeout)
        finally:
            del self._futures[future_id]

在每一次收数据时,如果检查到存在 echo 字段,就将这个 FutureTable 中设置上对应的相应值,从而达到调用 API 并获取响应数据的目的。这里附上核心代码,源码位于shirasu/client/onebot.py

async def call_action(self, action: str, **params: Any) -> dict[str, Any]:
    logger.info(f'Calling action {action}.')
    future_id = self._futures.register()
    await self._ws.send(ujson.dumps({
        'action': action,
        'params': params,
        'echo': str(future_id),
    }))

    data = await self._futures.get(future_id, self._global_config.action_timeout)
    if data.get('status') == 'failed':
        raise ClientActionError(data)
        
    return data.get('data', {})

async def _handle(self, data: dict[str, Any]) -> None:
    if echo := data.get('echo'):
        self._futures.set(int(echo), data)
        return
    ...

这个 _handle 就是每次获取到数据时建立的 task

简单依赖注入框架

我个人比较喜欢 Spring 那种依赖注入。

@Autowired
private FooService fooService;

但是在 FastAPI 中使用时往往需要你这么做:

def use(foo: Foo = Depends(get_foo)) -> None:
    ...

其实也有别的框架,可以做到 Spring 那种类型的,但是我还是按照我的想法实现了一个简单的框架,源码在 shirasu/di.py,这里就贴上使用方法:

import asyncio
from datetime import datetime
from shirasu.di import inject, provide


@provide('now')
async def provide_now() -> datetime:
    return datetime.now()


@provide('today')
async def provide_today(now: datetime) -> int:
    await asyncio.sleep(.1)
    return now.day


@inject()
async def use_today(today: int) -> None:
    print(today)


@inject()
async def use_now(now: datetime) -> None:
    await asyncio.sleep(.1)
    print(now.year)


# 1
await use_today()

# 2023
await use_now()

为了保持一致性和方便使用,这里均采用异步函数,为了方便在 provider 中进行异步操作。现在看上去不太方便因为没有真正到应用场景。

我个人认为,依赖注入最重要的就是你的 provider 也可以有自己的依赖项,比如上方代码的 today 就依赖于 now,如果后期它还有别的依赖项,可以只修改它本身,其它地方均无需修改

此外,这个框架还会根据你提供的 type hint 来判断 provider 返回的类型和你在参数后面写的类型是否一致,当然如果是子类也可以,如果不一致它就会打印一条警告信息。你可以直接不标注类型来跳过类型检测


接下来聊聊实现的核心逻辑,它是根据你 provide 的时候提供的字符串来判断。以下几个函数是实现的核心,如果读者想看源码可以直接去文章开头提到的仓库下 shirasu/di.py 找到完整代码,其中 ATAwaitable[T] 的别名。

class DependencyInjector:
    """
    Dependency injector based on parameter names.
    Note: positional-only arguments are not supported.
    """

    def __init__(self) -> None:
        self._providers: dict[str, Callable[..., AT]] = {}

    async def _inject_func_args(self, func: Callable[..., AT], *inject_for: str) -> dict[str, Any]:
        params = inspect.signature(func).parameters

        # Check unknown dependencies.
        if unknown_deps := [dep for dep in params if dep not in self._providers]:
            raise UnknownDependencyError(unknown_deps)

        # Check circular dependencies.
        if circular_deps := [dep for dep in params if dep in inject_for]:
            raise CircularDependencyError(circular_deps)

        args = {
            dep: await self._apply(self._providers[dep], dep, *inject_for)
            for dep in params
        }

        # Check types of injected parameters.
        for dep, param in params.items():
            anno = param.annotation

            # Skip untyped parameters.
            if anno == inspect.Parameter.empty:
                continue

            if not isinstance(val := args[dep], expected := anno):
                module = inspect.getmodule(func)
                module_name = module.__name__ if module else '<unknown module>'
                module_func_name = f'{module_name}:{func.__name__}'
                logger.warning(f'type mismatch for parameter {dep} in function {module_func_name}, '
                               f'real type: {type(val).__name__}, expected: {expected.__name__}')

        return args

    async def _apply(self, func: Callable[..., AT], *apply_for: str) -> AT:
        injected_args = await self._inject_func_args(func, *apply_for)
        return await func(**injected_args)

通过可变长参数 apply_forinject_for 来判断是否存在循环依赖,还在获取依赖的时候递归调用 _apply 从而达到前文提到的 provider 也可以有依赖项的目的。被 @inject() 包装的代码其实就是调用 _apply(func)

插件系统

这相当于面向用户最常用的接口,源码位于 shirasu/addon 下,还有一部分内置插件位于 shirasu/addons,基本都是为了我测试而写的。

实现起来没什么难度,这里就简单贴一下使用方法,以 shirasu/addons/square.py 为例,这是一个计算平方的插件。

from pydantic import BaseModel
from shirasu import Client, Addon, MessageEvent, command


class SquareConfig(BaseModel):
    precision: int = 2


square = Addon(
    name='square',
    usage='/square number',
    description='Calculates the square of given number.',
    config_model=SquareConfig,
)


@square.receive(command('square'))
async def handle_square(client: Client, event: MessageEvent, config: SquareConfig) -> None:
    arg = event.arg

    try:
        result = round(float(arg) ** 2, config.precision)
        await client.send(f'{result:g}')
    except ValueError:
        await client.reject(f'Invalid number: {arg}')

可以看到内部使用 pydantic 进行数据配置,这个在下个章节细说,这里先跳过。

@square.receive(...) 装饰的函数默认就也会被注入,所以不需要手动写 @inject()

在加载插件的时候,使用了 importlib 这个内置库,原理就是扫描整个模块,把 Addon 的实例加载进来而已,源码位于 shirasu/addon/pool.py

def load_module(self, module_name: str) -> 'AddonPool':
    try:
        module = importlib.import_module(module_name)
    except ImportError as e:
        raise LoadAddonError(f'failed to load addons in module {module_name}') from e

    addons = [p for p in module.__dict__.values() if isinstance(p, Addon)]
    if not addons:
        raise LoadAddonError(f'no addons in module {module_name}')

    for addon in addons:
        self.load(addon)

    return self

配置系统

本项目采用 YAMLpydantic 进行配置,相对于 .env 的优势不必多说,我本人也是比较喜欢 YAML 或者 JSON 的配置模式的。

为了加载 yml 文件,需要依赖 pyyaml 这个库,以下源码位于 shirasu/config.py

import yaml
from typing import Any
from pathlib import Path
from pydantic import BaseModel


class GlobalConfig(BaseModel):
    """
    The global configuration.
    """

    ws: str = 'ws://127.0.0.1:8080'
    addons: dict[str, dict[str, Any]] = {}
    command_prefixes: list[str] = ['/']
    action_timeout: float = 30.


def load_config(path: str | Path) -> GlobalConfig:
    return GlobalConfig.parse_obj(yaml.safe_load(Path(path).read_text('utf8')))

对于 addons 这个配置项,这是每个插件具体的配置,如 square 这个插件配置它的精度,就在 yml 里这么写:

addons:
  square:
    precision: 3

shirasu/addon/addon.py 中,会为每个插件的 receiver 注入一个它自身的配置,如下:

def _provide_config(self) -> None:
    async def provide(global_config: GlobalConfig) -> Any:
        return self._config_model.parse_obj(global_config.addons.get(self._name, {}))
    di.provide('config', provide, check_duplicate=False)

在调用 matcherreceiver 前都会先调用 _provide_config 来为它们提供插件的配置,其中 self._config_modelType[pydantic.BaseModel],当插件被定义的时候需要指定 config_model,就像上文中的 square 插件做到的那样。

这里顺带提一下,@provide() 装饰器只是对 di.provide 这个方法进行了包装,在框架内部有大量代码都直接使用 di.provide 来提供依赖。

测试系统

核心步骤就是发一条假信息,收一条信息,判断是否符合要求。

本项目提供了一个简单的方法进行单元测试,推荐使用 pytest + pytest-asyncio 进行单元测试。

import pytest
import asyncio
from shirasu import MockClient, AddonPool


@pytest.mark.asyncio
async def test_square():
    pool = AddonPool.from_modules('shirasu.addons.square')
    client = MockClient(pool)

    await client.post_message('/square 2')
    square2_msg = await client.get_message()
    assert square2_msg.plain_text == '4'

    await client.post_message('/square a')
    rejected_msg = await client.get_message_event()
    assert rejected_msg.is_rejected


@pytest.mark.asyncio
async def test_echo():
    pool = AddonPool.from_modules('shirasu.addons.echo')
    client = MockClient(pool)

    await client.post_message('/echo hello')
    echo_msg = await client.get_message()
    assert echo_msg.plain_text == 'hello'

    await client.post_message('echo hello')
    with pytest.raises(asyncio.TimeoutError):
        await client.get_message()

如果不希望它有任何回复,可以捕捉 asyncio.TimeoutError,下面是 shirasu/client/mock.py 中的核心代码。

_message_event_queue: asyncio.queues.Queue[MessageEvent]

async def post_event(self, event: Event) -> None:
    self.curr_event = event
    await self.apply_addons()

async def get_message_event(self, timeout: float = .1) -> MessageEvent:    
    return await asyncio.wait_for(self._message_event_queue.get(), timeout)

通过 asyncio 提供的 queueasyncio.wait_for 来实现一个超时自动报错的效果。

特别感谢

  • nonebot-adapter-onebot:部分代码参考 NB 实现,NB 就是 NB!
  • go-cqhttp:开发过程中主要参考 go-cqhttp 的文档。
  • voidbot:最核心的功能都在这200行代码中实现了,对我帮助很大,但他写的是同步逻辑。
  • arashi:这个仓库目的也是提供一个功能比较完善的最小实现。