Skip to content

Commit

Permalink
General improvements of misc commands and the extended cog (#148)
Browse files Browse the repository at this point in the history
  • Loading branch information
AiroPi authored Mar 29, 2024
1 parent be4fd92 commit c4bd67b
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 60 deletions.
15 changes: 12 additions & 3 deletions src/cogs/restore.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import logging
import re
from typing import TYPE_CHECKING

from core import ExtendedCog, misc_command
from core import ExtendedCog, MiscCommandContext, misc_command
from core.checkers.misc import bot_required_permissions, is_activated, is_user_authorized, misc_check

if TYPE_CHECKING:
Expand All @@ -16,11 +17,19 @@


class Restore(ExtendedCog):
@misc_command("restore", description="Send a message back in chat if a link is send.", extras={"soon": True})
def contains_message_link(self, message: Message) -> bool:
return bool(re.search(r"<?https://(?:.+\.)?discord(?:app)?\.com/channels/(\d+)/(\d+)/(\d+)", message.content))

@misc_command(
"restore",
description="Send a message back in chat if a link is send.",
extras={"soon": True},
trigger_condition=contains_message_link,
)
@bot_required_permissions(manage_webhooks=True)
@misc_check(is_activated)
@misc_check(is_user_authorized)
async def on_message(self, message: Message) -> None:
async def on_message(self, ctx: MiscCommandContext[MyBot], message: Message) -> None:
raise NotImplementedError("Restore is not implemented.")


Expand Down
12 changes: 4 additions & 8 deletions src/cogs/translate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from discord import Embed, Message, app_commands, ui
from discord.app_commands import locale_str as __

from core import ExtendedCog, ResponseType, TemporaryCache, db, misc_command, response_constructor
from core import ExtendedCog, MiscCommandContext, ResponseType, TemporaryCache, db, misc_command, response_constructor
from core.checkers.misc import bot_required_permissions, is_activated, is_user_authorized, misc_check
from core.constants import EmbedsCharLimits
from core.errors import BadArgument, NonSpecificError
Expand Down Expand Up @@ -282,11 +282,7 @@ async def translate_misc_condition(self, payload: RawReactionActionEvent) -> boo
@bot_required_permissions(send_messages=True, embed_links=True)
@misc_check(is_activated)
@misc_check(is_user_authorized)
async def translate_misc_command(self, payload: RawReactionActionEvent):
user = await self.bot.getch_user(payload.user_id)
if not user or user.bot: # TODO(airo.pi_): automatically ignore bots
return

async def translate_misc_command(self, ctx: MiscCommandContext[MyBot], payload: RawReactionActionEvent):
channel = await self.bot.getch_channel(payload.channel_id)
if channel is None:
return
Expand All @@ -304,12 +300,12 @@ async def public_pre_strategy():
await channel.typing()

async def private_pre_strategy():
await user.typing()
await ctx.user.typing()

if await self.public_translations(payload.guild_id):
strategies = Strategies(pre=public_pre_strategy, send=partial(channel.send, reference=message))
else:
strategies = Strategies(pre=private_pre_strategy, send=user.send)
strategies = Strategies(pre=private_pre_strategy, send=ctx.user.send)

await self.process(
payload.user_id,
Expand Down
1 change: 1 addition & 0 deletions src/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .extended_commands import (
ExtendedCog as ExtendedCog,
ExtendedGroupCog as ExtendedGroupCog,
MiscCommandContext as MiscCommandContext,
cog_property as cog_property,
misc_command as misc_command,
)
Expand Down
4 changes: 2 additions & 2 deletions src/core/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
T = TypeVar("T")

CogT = TypeVar("CogT", bound="ExtendedCog")
BotT = TypeVar("BotT", bound="commands.Bot | commands.AutoShardedBot")

type Bot = commands.Bot | commands.AutoShardedBot
BotT = TypeVar("BotT", bound=Bot)

Snowflake = int

Expand Down
84 changes: 39 additions & 45 deletions src/core/extended_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
TYPE_CHECKING,
Any,
ClassVar,
Concatenate,
Generic,
Literal,
ParamSpec,
Protocol,
Self,
TypeVar,
Expand All @@ -23,23 +20,20 @@
from discord.ext import commands
from discord.utils import maybe_coroutine

from ._types import BotT, CogT
from .errors import MiscCheckFailure, MiscCommandError, MiscNoPrivateMessage, UnexpectedError

if TYPE_CHECKING:
from discord.abc import MessageableChannel, Snowflake
from discord.ext.commands.bot import AutoShardedBot, Bot, BotBase # pyright: ignore[reportMissingTypeStubs]
from discord.ext.commands.bot import BotBase # pyright: ignore[reportMissingTypeStubs]

from mybot import MyBot

from ._types import CoroT, UnresolvedContext, UnresolvedContextT
from ._types import Bot, CogT, CoroT, UnresolvedContext, UnresolvedContextT

ConditionCallback = Callable[Concatenate["CogT", UnresolvedContextT, "P"], CoroT[bool] | bool]
Callback = Callable[Concatenate["CogT", UnresolvedContextT, "P"], CoroT["T"]]
type ConditionCallback[CogT, UnresolvedContextT] = Callable[[CogT, UnresolvedContextT], CoroT[bool] | bool]
type Callback[CogT, UnresolvedContext, R] = Callable[[CogT, "MiscCommandContext[Any]", UnresolvedContext], CoroT[R]]

P = ParamSpec("P")
T = TypeVar("T")
C = TypeVar("C", bound="commands.Cog")
R = TypeVar("R")


LiteralNames = Literal["raw_reaction_add", "message"]
Expand All @@ -57,7 +51,7 @@ class MiscCommandsType(Enum):


class ExtendedCog(commands.Cog):
__cog_misc_commands__: list[MiscCommand[Any, ..., Any]]
__cog_misc_commands__: list[MiscCommand[Any, Any]]
bot: MyBot

def __new__(cls, *args: Any, **kwargs: Any) -> Self:
Expand All @@ -74,7 +68,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Self:
def __init__(self, bot: MyBot) -> None:
self.bot = bot

def get_misc_commands(self) -> list[MiscCommand[Any, ..., Any]]:
def get_misc_commands(self) -> list[MiscCommand[Any, Any]]:
"""Return all the misc commands in this cog."""
return list(self.__cog_misc_commands__)

Expand All @@ -93,18 +87,18 @@ class ExtendedGroupCog(ExtendedCog):
__cog_is_app_commands_group__: ClassVar[bool] = True


class MiscCommand(Generic[CogT, P, T]):
bot: Bot | AutoShardedBot
class MiscCommand[CogT: ExtendedCog, R]:
bot: Bot

def __init__(
self,
name: str,
callback: Callback[CogT, UnresolvedContextT, P, T],
callback: Callback[CogT, UnresolvedContextT, R],
description: str,
nsfw: bool,
type: MiscCommandsType,
extras: dict[Any, Any],
trigger_condition: Callable[Concatenate[CogT, UnresolvedContext, P], bool | CoroT[bool]] | None,
trigger_condition: Callable[[CogT, UnresolvedContextT], bool | CoroT[bool]] | None,
) -> None:
self.name = name
self.type = type
Expand All @@ -123,14 +117,12 @@ def __init__(
)
self._callback = callback

async def do_call(self, cog: CogT, context: UnresolvedContext, *args: P.args, **kwargs: P.kwargs) -> T:
async def do_call(self, cog: CogT, context: UnresolvedContext) -> R:
if self.trigger_condition:
trigger_condition = await discord.utils.maybe_coroutine(
self.trigger_condition,
self.trigger_condition, # type: ignore
cog,
context,
*args,
**kwargs, # type: ignore
context, # type: ignore
)
if not trigger_condition:
return None # type: ignore
Expand All @@ -143,12 +135,12 @@ async def do_call(self, cog: CogT, context: UnresolvedContext, *args: P.args, **
self.bot.dispatch("misc_command_error", resolved_context, e)
return None # type: ignore

return await self._callback(cog, context, *args, **kwargs) # type: ignore
return await self._callback(cog, resolved_context, context) # type: ignore

def add_check(self, predicate: Callable[[MiscCommandContext[Any]], CoroT[bool] | bool]) -> None:
self.checks.append(predicate)

async def condition(self, func: ConditionCallback[CogT, UnresolvedContextT, P]) -> None:
async def condition(self, func: ConditionCallback[CogT, UnresolvedContextT]) -> None:
self.trigger_condition = func


Expand All @@ -159,8 +151,8 @@ def misc_command(
nsfw: bool = False,
listener_name: LiteralNames | None = None,
extras: dict[Any, Any] | None = None,
trigger_condition: ConditionCallback[CogT, UnresolvedContextT, P] | None = None,
) -> Callable[[Callback[CogT, UnresolvedContextT, P, T]], Callback[CogT, UnresolvedContextT, P, T]]:
trigger_condition: ConditionCallback[CogT, UnresolvedContextT] | None = None,
) -> Callable[[Callback[CogT, UnresolvedContextT, R]], Callable[[CogT, UnresolvedContext], CoroT[R]]]:
"""Register an event listener as a "command" that can be retrieved from the feature exporter.
Checkers will be called within the second argument of the function (right after the Cog (self))
Expand All @@ -179,10 +171,12 @@ def misc_command(
A wrapped function, bound with a MiscCommand.
"""

def inner(func: Callback[CogT, UnresolvedContextT, P, T]) -> Callback[CogT, UnresolvedContextT, P, T]:
def inner(
func: Callback[CogT, UnresolvedContextT, R],
) -> Callable[[CogT, UnresolvedContext], CoroT[R]]:
true_listener_name = "on_" + listener_name if listener_name else func.__name__

misc_command = MiscCommand[CogT, P, T](
misc_command = MiscCommand["CogT", R](
name=name,
callback=func,
description=description,
Expand All @@ -193,8 +187,8 @@ def inner(func: Callback[CogT, UnresolvedContextT, P, T]) -> Callback[CogT, Unre
)

@wraps(func)
async def inner(cog: CogT, context: UnresolvedContext, *args: P.args, **kwargs: P.kwargs) -> T:
return await misc_command.do_call(cog, context, *args, **kwargs)
async def inner(cog: CogT, context: UnresolvedContext) -> R:
return await misc_command.do_call(cog, context)

setattr(inner, "__listener_as_command__", misc_command)

Expand All @@ -218,21 +212,21 @@ class MiscCommandContextFilled(Protocol):
user: discord.User


class MiscCommandContext(Generic[BotT]):
class MiscCommandContext[B: Bot]:
def __init__(
self,
bot: BotT,
bot: B,
channel: MessageableChannel,
user: User | Member,
command: MiscCommand[Any, ..., Any],
command: MiscCommand[Any, Any],
) -> None:
self.channel: MessageableChannel = channel
self.user: User | Member = user
self.bot: BotT = bot
self.command: MiscCommand[Any, ..., Any] = command
self.bot: B = bot
self.command: MiscCommand[Any, Any] = command

@classmethod
async def resolve(cls, bot: BotT, context: UnresolvedContext, command: MiscCommand[Any, ..., Any]) -> Self:
async def resolve(cls, bot: B, context: UnresolvedContext, command: MiscCommand[Any, Any]) -> Self:
channel: MessageableChannel
user: User | Member

Expand Down Expand Up @@ -268,15 +262,15 @@ def bot_permissions(self) -> Permissions:
return channel.permissions_for(me)


def misc_guild_only() -> Callable[[T], T]:
def misc_guild_only() -> Callable[[R], R]:
def predicate(ctx: MiscCommandContext[Any]) -> bool:
if ctx.channel.guild is None:
raise MiscNoPrivateMessage
return True

def decorator(func: T) -> T:
def decorator(func: R) -> R:
if hasattr(func, "__listener_as_command__"):
misc_command: MiscCommand[Any, ..., Any] = getattr(func, "__listener_as_command__")
misc_command: MiscCommand[Any, Any] = getattr(func, "__listener_as_command__")
misc_command.add_check(predicate)
misc_command.guild_only = True
else:
Expand All @@ -291,10 +285,10 @@ def decorator(func: T) -> T:
return decorator


def misc_check(predicate: Callable[[MiscCommandContext[Any]], CoroT[bool] | bool]) -> Callable[[T], T]:
def decorator(func: T) -> T:
def misc_check(predicate: Callable[[MiscCommandContext[Any]], CoroT[bool] | bool]) -> Callable[[R], R]:
def decorator(func: R) -> R:
if hasattr(func, "__listener_as_command__"):
misc_command: MiscCommand[Any, ..., Any] = getattr(func, "__listener_as_command__")
misc_command: MiscCommand[Any, Any] = getattr(func, "__listener_as_command__")
misc_command.add_check(predicate)
else:
if not hasattr(func, "__misc_commands_checks__"):
Expand All @@ -314,10 +308,10 @@ def cog_property(cog_name: str):
cog_name: the cog name to return
"""

def inner(_: Callable[..., C]) -> C:
def inner(_: Callable[..., CogT]) -> CogT:
@property
def cog_getter(self: Any) -> C: # self is a cog within the .bot attribute (because every Cog should have it)
cog: C | None = self.bot.get_cog(cog_name)
def cog_getter(self: Any) -> CogT: # self is a cog within the .bot attribute (because every Cog should have it)
cog: CogT | None = self.bot.get_cog(cog_name)
if cog is None:
raise UnexpectedError(f"Cog named {cog_name} is not loaded.")
return cog
Expand Down
2 changes: 1 addition & 1 deletion src/features_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


FeatureCodebaseTypes = (
app_commands.Command[Any, ..., Any] | app_commands.Group | app_commands.ContextMenu | MiscCommand[Any, ..., Any]
app_commands.Command[Any, ..., Any] | app_commands.Group | app_commands.ContextMenu | MiscCommand[Any, Any]
)


Expand Down
3 changes: 2 additions & 1 deletion src/mybot.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(self, startup_sync: bool = False) -> None:
intents.reactions = True
intents.guilds = True
intents.messages = True
intents.message_content = True
logger.debug("Intents : %s", ", ".join(flag[0] for flag in intents if flag[1]))

super().__init__(
Expand Down Expand Up @@ -318,7 +319,7 @@ def misc_commands(self):
Returns:
_type_: the list of misc commands
"""
misc_commands: list[MiscCommand[Any, ..., Any]] = []
misc_commands: list[MiscCommand[Any, Any]] = []
for cog in self.cogs.values():
if isinstance(cog, ExtendedCog):
misc_commands.extend(cog.get_misc_commands())
Expand Down

0 comments on commit c4bd67b

Please sign in to comment.