Skip to content

Commit

Permalink
✨ plugin's scope
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Dec 26, 2024
1 parent 9a9c9aa commit eb9b5b2
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 54 deletions.
4 changes: 3 additions & 1 deletion arclet/entari/command/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ..event.command import CommandExecute
from ..event.config import ConfigReload
from ..message import MessageChain
from ..plugin import RootlessPlugin
from ..plugin import RootlessPlugin, _current_plugin
from ..session import Session
from .argv import MessageArgv # noqa: F401
from .model import CommandResult, Match, Query
Expand Down Expand Up @@ -130,6 +130,8 @@ def on(
meta: Optional[CommandMeta] = None,
) -> Callable[[TTarget[Optional[TM]]], Subscriber[Optional[TM]]]:
auxiliaries = auxiliaries or []
if plg := _current_plugin.get():
auxiliaries.extend(plg._scope.auxiliaries)
providers = providers or []

def wrapper(func: TTarget[Optional[TM]]) -> Subscriber[Optional[TM]]:
Expand Down
31 changes: 16 additions & 15 deletions arclet/entari/command/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .provider import AlconnaProviderFactory, AlconnaSuppiler, Assign, MessageJudges, _seminal

exec_pub = es.define(CommandExecute)
exec_pub.bind(AlconnaProviderFactory())


class AlconnaPluginDispatcher(PluginDispatcher):
Expand All @@ -26,11 +27,17 @@ def __init__(
):
self.supplier = AlconnaSuppiler(command)
super().__init__(plugin, MessageCreatedEvent, command.path)
self.scope.bind(
self.auxiliaries.append(
MessageJudges(need_reply_me, need_notice_me, use_config_prefix),
self.supplier,
AlconnaProviderFactory(),
)
self.auxiliaries.append(self.supplier)
self.providers.append(AlconnaProviderFactory())

@plugin.collect
def dispose():
command_manager.delete(self.supplier.cmd)
del self.supplier.cmd
del self.supplier

def assign(
self,
Expand All @@ -45,19 +52,17 @@ def assign(
_auxiliaries.append(Assign(path, value, or_not))
return self.register(priority=priority, auxiliaries=_auxiliaries, providers=providers)

def dispose(self):
super().dispose()
command_manager.delete(self.supplier.cmd)
del self.supplier.cmd
del self.supplier

def on_execute(
self,
priority: int = 16,
auxiliaries: list[BaseAuxiliary] | None = None,
providers: list[Provider | type[Provider] | ProviderFactory | type[ProviderFactory]] | None = None,
):
return self.scope.register(priority=priority, auxiliaries=auxiliaries, providers=providers, publisher=exec_pub)
_auxiliaries = auxiliaries or []
_auxiliaries.append(self.supplier)
return self.plugin._scope.register(
priority=priority, auxiliaries=_auxiliaries, providers=providers, publisher=exec_pub
)

Match = Match
Query = Query
Expand All @@ -71,8 +76,4 @@ def mount(
) -> AlconnaPluginDispatcher:
if not (plugin := Plugin.current()):
raise LookupError("no plugin context found")
disp = AlconnaPluginDispatcher(plugin, cmd, need_reply_me, need_notice_me, use_config_prefix)
if disp.scope.id in plugin.dispatchers:
return plugin.dispatchers[disp.scope.id] # type: ignore
plugin.dispatchers[disp.scope.id] = disp
return disp
return AlconnaPluginDispatcher(plugin, cmd, need_reply_me, need_notice_me, use_config_prefix)
2 changes: 1 addition & 1 deletion arclet/entari/plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def declare_static():
if not (plugin := _current_plugin.get(None)):
raise LookupError("no plugin context found")
plugin.is_static = True
if plugin.dispatchers:
if plugin._scope.subscribers:
raise StaticPluginDispatchError("static plugin cannot dispatch events")


Expand Down
41 changes: 13 additions & 28 deletions arclet/entari/plugin/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
from typing import Any, Callable, TypeVar, overload
from weakref import ProxyType, finalize, proxy

from arclet.letoderea import BaseAuxiliary, Provider, ProviderFactory, StepOut, Subscriber, es
from arclet.letoderea import BaseAuxiliary, Provider, ProviderFactory, Scope, StepOut, Subscriber, es
from arclet.letoderea.publisher import Publisher, _publishers
from arclet.letoderea.scope import _scopes
from arclet.letoderea.typing import TTarget
from creart import it
from launart import Launart, Service
Expand Down Expand Up @@ -42,14 +41,11 @@ def __init__(
event: type[TE],
name: str | None = None,
):
self.publisher = es.define(event)
self.publisher = es.define(event, name)
self.plugin = plugin
self._event = event
scope_id = f"{plugin.id}#{name or self.publisher.id}"
if scope_id in _scopes:
self.scope = _scopes[scope_id]
else:
self.scope = es.scope(f"{self.plugin.id}#{name or self.publisher.id}")
self.providers: list[Provider[Any] | ProviderFactory] = []
self.auxiliaries: list[BaseAuxiliary] = []

def waiter(
self,
Expand All @@ -67,9 +63,6 @@ def wrapper(func: TTarget[R]):

return wrapper

def dispose(self):
self.scope.dispose()

@overload
def register(
self,
Expand Down Expand Up @@ -106,10 +99,12 @@ def register(
) = None,
temporary: bool = False,
):
wrapper = self.scope.register(
_auxiliaries = auxiliaries or []
_providers = providers or []
wrapper = self.plugin._scope.register(
priority=priority,
auxiliaries=auxiliaries,
providers=providers,
auxiliaries=[*self.auxiliaries, *_auxiliaries],
providers=[*self.providers, *_providers],
temporary=temporary,
publisher=self.publisher,
)
Expand Down Expand Up @@ -191,14 +186,14 @@ class Plugin:
id: str
module: ModuleType

dispatchers: dict[str, PluginDispatcher] = field(default_factory=dict)
subplugins: set[str] = field(default_factory=set)
config: dict[str, Any] = field(default_factory=dict)
is_static: bool = False
_metadata: PluginMetadata | None = None
_is_disposed: bool = False
_services: dict[str, Service] = field(init=False, default_factory=dict)
_dispose_callbacks: list[Callable[[], None]] = field(init=False, default_factory=list)
_scope: Scope = field(init=False)

@property
def available(self) -> bool:
Expand Down Expand Up @@ -239,7 +234,7 @@ def update_filter(self, allow: dict, deny: dict):
plugin_service.filters[self.id] = fter

def __post_init__(self):

self._scope = es.scope(self.id)
plugin_service.plugins[self.id] = self
self.update_filter(self.config.pop("$allow", {}), self.config.pop("$deny", {}))
if "$static" in self.config:
Expand Down Expand Up @@ -287,20 +282,14 @@ def dispose(self):
log.plugin.opt(colors=True).error(f"failed to dispose sub-plugin <r>{subplug}</r> caused by {e!r}")
plugin_service.plugins.pop(subplug, None)
self.subplugins.clear()
for disp in self.dispatchers.values():
disp.dispose()
self.dispatchers.clear()
self._scope.dispose()
del plugin_service.plugins[self.id]
del self.module

def dispatch(self, event: type[TE], name: str | None = None):
if self.is_static:
raise StaticPluginDispatchError("static plugin cannot dispatch events")
disp = PluginDispatcher(self, event, name=name)
if disp.scope.id in self.dispatchers:
return self.dispatchers[disp.scope.id]
self.dispatchers[disp.scope.id] = disp
return disp
return PluginDispatcher(self, event, name=name)

@overload
def use(
Expand Down Expand Up @@ -349,10 +338,6 @@ def use(
if pid not in _publishers:
raise LookupError(f"no publisher found: {pid}")
disp = PluginDispatcher(self, _publishers[pid].target)
if disp.scope.id in self.dispatchers:
disp = self.dispatchers[disp.scope.id]
else:
self.dispatchers[disp.scope.id] = disp
if func:
return disp.register(func=func, priority=priority, auxiliaries=auxiliaries, providers=providers)
return disp.register(priority=priority, auxiliaries=auxiliaries, providers=providers)
Expand Down
16 changes: 8 additions & 8 deletions arclet/entari/plugin/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from types import ModuleType
from typing import Optional

from arclet.letoderea import global_auxiliaries
from arclet.letoderea.context import scope_ctx

from ..config import EntariConfig
from ..logger import log
Expand Down Expand Up @@ -235,24 +235,24 @@ def exec_module(self, module: ModuleType, config: Optional[dict[str, str]] = Non
setattr(module, "__getattr_or_import__", getattr_or_import)
setattr(module, "__plugin_service__", plugin_service)

aux = AccessAuxiliary(plugin.id)

# enter plugin context
token = _current_plugin.set(plugin)
try:
if not plugin.is_static:
if not is_sub:
global_auxiliaries.append(aux)
plugin._scope.auxiliaries.append(AccessAuxiliary(plugin.id))
token1 = scope_ctx.set(plugin._scope)
try:
super().exec_module(module)
except Exception:
plugin.dispose()
raise
finally:
if not is_sub:
global_auxiliaries.remove(aux)
# leave plugin context
delattr(module, "__cached__")
delattr(module, "__plugin_service__")
sys.modules.pop(module.__name__, None)
if not plugin.is_static:
scope_ctx.reset(token1) # type: ignore
_current_plugin.reset(token)

# get plugin metadata
Expand All @@ -262,7 +262,7 @@ def exec_module(self, module: ModuleType, config: Optional[dict[str, str]] = Non
return


class _NamespacePath(_bootstrap_external._NamespacePath):
class _NamespacePath(_bootstrap_external._NamespacePath): # type: ignore
def _get_parent_path(self):
parent_module_name, path_attr_name = self._find_parent_path_names()
if parent_module_name in plugin_service.plugins:
Expand Down
2 changes: 1 addition & 1 deletion example_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ async def show(session: Session):

TEST = 5

print([*Plugin.current().dispatchers.keys()])
print([*Plugin.current()._scope.subscribers])
print(Plugin.current().subplugins)
print(local_data.get_temp_dir())

Expand Down

0 comments on commit eb9b5b2

Please sign in to comment.