From 477c0f4860b7842fe512798ac246c49c16fcd103 Mon Sep 17 00:00:00 2001 From: ProgramRipper Date: Sat, 17 Feb 2024 13:28:27 +0000 Subject: [PATCH] :zap: perf(sqla): single layer scoped session --- nonebot_plugin_orm/__init__.py | 38 +++++++--------------------------- nonebot_plugin_orm/param.py | 4 +++- 2 files changed, 11 insertions(+), 31 deletions(-) diff --git a/nonebot_plugin_orm/__init__.py b/nonebot_plugin_orm/__init__.py index b955100..2823cff 100644 --- a/nonebot_plugin_orm/__init__.py +++ b/nonebot_plugin_orm/__init__.py @@ -4,20 +4,18 @@ import logging from typing import Any from argparse import Namespace -from contextlib import suppress from functools import wraps, lru_cache import click from nonebot.rule import Rule -from nonebot.adapters import Event import sqlalchemy.ext.asyncio as sa_async from nonebot.permission import Permission +from sqlalchemy.util import greenlet_spawn from sqlalchemy import URL, Table, MetaData +from nonebot.message import run_postprocessor from nonebot.params import Depends, DefaultParam from nonebot.plugin import Plugin, PluginMetadata -from sqlalchemy.util import ScopedRegistry, greenlet_spawn from sqlalchemy.log import Identified, _qual_logger_name_for_cls -from nonebot.message import run_postprocessor, event_postprocessor from nonebot.matcher import Matcher, current_event, current_matcher from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from nonebot import logger, require, get_driver, get_plugin_by_module_name @@ -66,7 +64,7 @@ _metadatas: dict[str, MetaData] _plugins: dict[str, Plugin] _session_factory: sa_async.async_sessionmaker[sa_async.AsyncSession] -_scoped_sessions: ScopedRegistry[sa_async.async_scoped_session[sa_async.AsyncSession]] +_scoped_sessions: sa_async.async_scoped_session[sa_async.AsyncSession] _data_dir = get_data_dir(__plugin_meta__.name) _driver = get_driver() @@ -104,16 +102,12 @@ def _init_orm(): **plugin_config.sqlalchemy_session_options, } ) - _scoped_sessions = ScopedRegistry( - lambda: sa_async.async_scoped_session( - _session_factory, lambda: current_matcher.get(None) - ), - lambda: id(current_event.get(None)), + _scoped_sessions = sa_async.async_scoped_session( + _session_factory, + lambda: (id(current_event.get(None)), current_matcher.get(None)), ) - # XXX: workaround for https://github.com/nonebot/nonebot2/issues/2475 - event_postprocessor(_clear_scoped_session) - run_postprocessor(_close_scoped_session) + run_postprocessor(_scoped_sessions.remove) @wraps(lambda: None) # NOTE: for dependency injection @@ -129,7 +123,7 @@ def get_session(**local_kw: Any) -> sa_async.AsyncSession: async def get_scoped_session() -> sa_async.async_scoped_session[sa_async.AsyncSession]: try: - return _scoped_sessions() + return _scoped_sessions except NameError: raise RuntimeError("nonebot-plugin-orm 未初始化") from None @@ -139,22 +133,6 @@ async def get_scoped_session() -> sa_async.async_scoped_session[sa_async.AsyncSe ] -# @event_postprocessor -def _clear_scoped_session(event: Event) -> None: - with suppress(KeyError): - del _scoped_sessions.registry[id(event)] - - -# @run_postprocessor -async def _close_scoped_session(event: Event, matcher: Matcher) -> None: - with suppress(KeyError): - session: sa_async.AsyncSession = _scoped_sessions.registry[ - id(event) - ].registry.registry[matcher] - del _scoped_sessions.registry[id(event)].registry.registry[matcher] - await session.close() - - def _create_engine(engine: str | URL | dict[str, Any] | AsyncEngine) -> AsyncEngine: if isinstance(engine, AsyncEngine): return engine diff --git a/nonebot_plugin_orm/param.py b/nonebot_plugin_orm/param.py index 5077f27..3bc8a7a 100644 --- a/nonebot_plugin_orm/param.py +++ b/nonebot_plugin_orm/param.py @@ -165,7 +165,9 @@ def _check_param( if depends_inner is not None: dependency = compile_dependency(depends_inner.dependency, option) - elif all(map(isclass, models)) and all(map(issubclass, models, repeat(Model))): + elif all(map(isclass, models)) and all( + map(issubclass, cast(Tuple[type, ...], models), repeat(Model)) + ): models = cast(Tuple[Type[Model], ...], models) dependency = compile_dependency( select(*models).where(