Skip to content

Commit

Permalink
refactor the checkers system
Browse files Browse the repository at this point in the history
  • Loading branch information
AiroPi committed Mar 30, 2024
1 parent c4bd67b commit 8d2cf07
Show file tree
Hide file tree
Showing 12 changed files with 177 additions and 159 deletions.
4 changes: 3 additions & 1 deletion src/cogs/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from discord import app_commands

from core import ExtendedCog, config
from core.checkers import is_me

if TYPE_CHECKING:
from discord import Interaction
Expand All @@ -15,9 +16,10 @@
logger = logging.getLogger(__name__)


class Admin(ExtendedCog): # TODO(airo.pi_): add checkers
class Admin(ExtendedCog):
@app_commands.command()
@app_commands.guilds(config.support_guild_id)
@is_me
async def reload_extension(self, inter: Interaction, extension: str):
await self.bot.reload_extension(extension)
await inter.response.send_message(f"Extension [{extension}] reloaded successfully")
Expand Down
2 changes: 1 addition & 1 deletion src/cogs/clear/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self, bot: MyBot):

self.clear_max_concurrency = checkers.MaxConcurrency(1, key=channel_bucket, wait=False)

@checkers.app.bot_required_permissions(
@checkers.bot_required_permissions(
manage_messages=True, read_message_history=True, read_messages=True, connect=True
)
@app_commands.command(
Expand Down
7 changes: 3 additions & 4 deletions src/cogs/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@

from core import ExtendedCog
from core._config import config
from core.checkers.app import is_me
from core.checkers.base import is_me_bool
from core.checkers import is_me, is_me_test
from core.utils import size_text

if TYPE_CHECKING:
Expand All @@ -32,7 +31,7 @@

class Eval(ExtendedCog):
@commands.command(name="+eval")
@commands.check(lambda ctx: is_me_bool(ctx.author.id))
@commands.check(lambda ctx: is_me_test(ctx.author.id))
async def add_eval(self, ctx: commands.Context[MyBot]) -> None:
try:
self.bot.tree.add_command(self._eval, guild=ctx.guild)
Expand All @@ -44,7 +43,7 @@ async def add_eval(self, ctx: commands.Context[MyBot]) -> None:
await ctx.send("Command added.")

@commands.command(name="-eval")
@commands.check(lambda ctx: is_me_bool(ctx.author.id))
@commands.check(lambda ctx: is_me_test(ctx.author.id))
async def remove_eval(self, ctx: commands.Context[MyBot]) -> None:
if self.bot.tree.remove_command("eval", guild=ctx.guild) is None:
await ctx.send("Command not registered. Cleaning eventual leftovers...")
Expand Down
2 changes: 1 addition & 1 deletion src/cogs/poll/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from sqlalchemy.orm import selectinload

from core import ExtendedGroupCog, db
from core.checkers.app import bot_required_permissions
from core.checkers import bot_required_permissions
from core.errors import NonSpecificError
from core.i18n import _

Expand Down
12 changes: 9 additions & 3 deletions src/cogs/restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import re
from typing import TYPE_CHECKING

from discord import Interaction, app_commands

from core import ExtendedCog, MiscCommandContext, misc_command
from core.checkers.misc import bot_required_permissions, is_activated, is_user_authorized, misc_check
from core.checkers import bot_required_permissions, check, is_activated_predicate, is_user_authorized_predicate

if TYPE_CHECKING:
from discord import Message
Expand All @@ -27,11 +29,15 @@ def contains_message_link(self, message: Message) -> bool:
trigger_condition=contains_message_link,
)
@bot_required_permissions(manage_webhooks=True)
@misc_check(is_activated)
@misc_check(is_user_authorized)
@check(is_activated_predicate)
@check(is_user_authorized_predicate)
async def on_message(self, ctx: MiscCommandContext[MyBot], message: Message) -> None:
raise NotImplementedError("Restore is not implemented.")

@app_commands.command()
async def test(self, inter: Interaction):
pass


async def setup(bot: MyBot):
await bot.add_cog(Restore(bot))
6 changes: 3 additions & 3 deletions src/cogs/translate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from discord.app_commands import locale_str as __

from core import ExtendedCog, MiscCommandContext, ResponseType, TemporaryCache, db, misc_command, response_constructor
from core.checkers.misc import bot_required_permissions, is_activated, is_user_authorized, misc_check
from core.checkers import bot_required_permissions, check, is_activated_predicate, is_user_authorized_predicate
from core.constants import EmbedsCharLimits
from core.errors import BadArgument, NonSpecificError
from core.i18n import _
Expand Down Expand Up @@ -280,8 +280,8 @@ async def translate_misc_condition(self, payload: RawReactionActionEvent) -> boo
trigger_condition=translate_misc_condition,
)
@bot_required_permissions(send_messages=True, embed_links=True)
@misc_check(is_activated)
@misc_check(is_user_authorized)
@check(is_activated_predicate)
@check(is_user_authorized_predicate)
async def translate_misc_command(self, ctx: MiscCommandContext[MyBot], payload: RawReactionActionEvent):
channel = await self.bot.getch_channel(payload.channel_id)
if channel is None:
Expand Down
157 changes: 156 additions & 1 deletion src/core/checkers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,157 @@
from . import app as app
from __future__ import annotations

import inspect
import logging
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, cast

import discord
from discord import Interaction
from discord.app_commands import Command, ContextMenu, check as app_check

from .._config import config
from .._types import CoroT
from ..errors import BotMissingPermissions, NotAllowedUser
from ..extended_commands import MiscCommandContext, check as misc_check
from ..utils import CommandType
from .max_concurrency import MaxConcurrency as MaxConcurrency

if TYPE_CHECKING:
from mybot import MyBot

type Context = Interaction | MiscCommandContext[Any]


logger = logging.getLogger(__name__)


def _determine_type(obj: Any) -> CommandType:
"""This function will determine the type of the command.
It makes some assumptions about the type of the command based on the annotations of the function.
"""
if isinstance(obj, Command | ContextMenu):
return CommandType.APP
if hasattr(obj, "__listener_as_command__"):
return CommandType.MISC
else:
annotations = inspect.get_annotations(obj)
target = next(iter(annotations.values())) # get the first annotation
if target is MiscCommandContext:
return CommandType.MISC
if target is Interaction:
return CommandType.APP
if isinstance(target, str):
# I don't know how to handle this case properly because MyBot is not imported in this file
if target.startswith("Interaction"):
return CommandType.APP
if target.startswith("MiscCommandContext"):
return CommandType.MISC
raise TypeError("Could not determine the type of the command.")


def _add_extra[T](type_: CommandType, func: T, name: str, value: Any) -> T:
copy_func = func # typing behavior
if type_ is CommandType.APP:
if isinstance(func, Command | ContextMenu):
func.extras[name] = value
else:
logger.critical(
"Because we need to add extras, this decorator must be above the command decorator. "
"(Command should already be defined)"
)
elif type_ is CommandType.MISC:
if hasattr(func, "__listener_as_command__"):
command: Command[Any, ..., Any] = getattr(func, "__listener_as_command__")
command.extras[name] = value
else:
if not hasattr(func, "__misc_commands_extras__"):
setattr(func, "__misc_commands_extras__", {})
getattr(func, "__misc_commands_extras__")[name] = value
return copy_func


def check[C: Interaction | MiscCommandContext[Any], F](
predicate: Callable[[C], bool | CoroT[bool]],
) -> Callable[[F], F]:
def decorator(func: F) -> F:
match _determine_type(func):
case CommandType.APP:
p = cast(Callable[[Interaction], bool | CoroT[bool]], predicate)
return app_check(p)(func)
case CommandType.MISC:
p = cast(Callable[[MiscCommandContext[Any]], bool | CoroT[bool]], predicate)
return misc_check(p)(func)

return decorator


def _bot_required_permissions_test(perms: dict[str, bool]) -> Callable[..., bool]:
def predicate(ctx: Context):
match ctx:
case discord.Interaction():
permissions = ctx.app_permissions
case MiscCommandContext():
permissions = ctx.bot_permissions

missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value]

if not missing:
return True

raise BotMissingPermissions(missing)

return predicate


def bot_required_permissions[T](**perms: bool) -> Callable[[T], T]:
invalid = set(perms) - set(discord.Permissions.VALID_FLAGS)
if invalid:
raise TypeError(f"Invalid permission(s): {", ".join(invalid)}")

def decorator(func: T) -> T:
type_ = _determine_type(func)
_add_extra(
type_,
func,
"bot_required_permissions",
[perm for perm, value in perms.items() if value is True],
)
match type_:
case CommandType.APP:
return app_check(_bot_required_permissions_test(perms))(func)
case CommandType.MISC:
return misc_check(_bot_required_permissions_test(perms))(func)

return decorator


async def is_user_authorized_predicate(context: MiscCommandContext[MyBot]) -> bool:
del context # unused
# TODO(airo.pi_): check using the database if the user is authorized
return True


is_user_authorized = check(is_user_authorized_predicate) # misc commands only


async def is_activated_predicate(context: MiscCommandContext[MyBot]) -> bool:
del context # unused
# TODO(airo.pi_): check using the database if the misc command is activated
return True


is_activated = check(is_activated_predicate) # misc commands only


def allowed_users_test(*user_ids: int) -> Callable[..., bool]:
def inner(user_id: int) -> bool:
if user_id not in user_ids:
raise NotAllowedUser(user_id)
return True

return inner


is_me_test = allowed_users_test(*config.owners_ids) # test function used for eval commands
is_me = check(lambda ctx: is_me_test(ctx.user.id))
16 changes: 0 additions & 16 deletions src/core/checkers/app.py

This file was deleted.

100 changes: 0 additions & 100 deletions src/core/checkers/base.py

This file was deleted.

Loading

0 comments on commit 8d2cf07

Please sign in to comment.