Skip to content

Commit

Permalink
⚡ perf(sqla): single layer scoped session
Browse files Browse the repository at this point in the history
  • Loading branch information
ProgramRipper committed Feb 17, 2024
1 parent 28326e3 commit 477c0f4
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 31 deletions.
38 changes: 8 additions & 30 deletions nonebot_plugin_orm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,18 @@
import logging
from typing import Any
from argparse import Namespace
from contextlib import suppress
from functools import wraps, lru_cache

import click
from nonebot.rule import Rule
from nonebot.adapters import Event
import sqlalchemy.ext.asyncio as sa_async
from nonebot.permission import Permission
from sqlalchemy.util import greenlet_spawn
from sqlalchemy import URL, Table, MetaData
from nonebot.message import run_postprocessor
from nonebot.params import Depends, DefaultParam
from nonebot.plugin import Plugin, PluginMetadata
from sqlalchemy.util import ScopedRegistry, greenlet_spawn
from sqlalchemy.log import Identified, _qual_logger_name_for_cls
from nonebot.message import run_postprocessor, event_postprocessor
from nonebot.matcher import Matcher, current_event, current_matcher
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from nonebot import logger, require, get_driver, get_plugin_by_module_name
Expand Down Expand Up @@ -66,7 +64,7 @@
_metadatas: dict[str, MetaData]
_plugins: dict[str, Plugin]
_session_factory: sa_async.async_sessionmaker[sa_async.AsyncSession]
_scoped_sessions: ScopedRegistry[sa_async.async_scoped_session[sa_async.AsyncSession]]
_scoped_sessions: sa_async.async_scoped_session[sa_async.AsyncSession]

_data_dir = get_data_dir(__plugin_meta__.name)
_driver = get_driver()
Expand Down Expand Up @@ -104,16 +102,12 @@ def _init_orm():
**plugin_config.sqlalchemy_session_options,
}
)
_scoped_sessions = ScopedRegistry(
lambda: sa_async.async_scoped_session(
_session_factory, lambda: current_matcher.get(None)
),
lambda: id(current_event.get(None)),
_scoped_sessions = sa_async.async_scoped_session(
_session_factory,
lambda: (id(current_event.get(None)), current_matcher.get(None)),
)

# XXX: workaround for https://github.com/nonebot/nonebot2/issues/2475
event_postprocessor(_clear_scoped_session)
run_postprocessor(_close_scoped_session)
run_postprocessor(_scoped_sessions.remove)


@wraps(lambda: None) # NOTE: for dependency injection
Expand All @@ -129,7 +123,7 @@ def get_session(**local_kw: Any) -> sa_async.AsyncSession:

async def get_scoped_session() -> sa_async.async_scoped_session[sa_async.AsyncSession]:
try:
return _scoped_sessions()
return _scoped_sessions
except NameError:
raise RuntimeError("nonebot-plugin-orm 未初始化") from None

Expand All @@ -139,22 +133,6 @@ async def get_scoped_session() -> sa_async.async_scoped_session[sa_async.AsyncSe
]


# @event_postprocessor
def _clear_scoped_session(event: Event) -> None:
with suppress(KeyError):
del _scoped_sessions.registry[id(event)]


# @run_postprocessor
async def _close_scoped_session(event: Event, matcher: Matcher) -> None:
with suppress(KeyError):
session: sa_async.AsyncSession = _scoped_sessions.registry[
id(event)
].registry.registry[matcher]
del _scoped_sessions.registry[id(event)].registry.registry[matcher]
await session.close()


def _create_engine(engine: str | URL | dict[str, Any] | AsyncEngine) -> AsyncEngine:
if isinstance(engine, AsyncEngine):
return engine
Expand Down
4 changes: 3 additions & 1 deletion nonebot_plugin_orm/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,9 @@ def _check_param(

if depends_inner is not None:
dependency = compile_dependency(depends_inner.dependency, option)
elif all(map(isclass, models)) and all(map(issubclass, models, repeat(Model))):
elif all(map(isclass, models)) and all(
map(issubclass, cast(Tuple[type, ...], models), repeat(Model))
):
models = cast(Tuple[Type[Model], ...], models)
dependency = compile_dependency(
select(*models).where(
Expand Down

0 comments on commit 477c0f4

Please sign in to comment.