From 5593851b9588876548978cf04d9e45461b825ae7 Mon Sep 17 00:00:00 2001 From: wlinator Date: Sat, 5 Oct 2024 13:33:18 -0400 Subject: [PATCH 1/7] feat(redis): Add basic Redis impl. with level cooldown caching --- .env.example | 2 ++ docker-compose.yml | 14 +++++++- poetry.lock | 17 ++++++++- pyproject.toml | 1 + tux/bot.py | 10 +++++- tux/cogs/services/levels.py | 45 ++++++++++++++++++++---- tux/database/redis.py | 70 +++++++++++++++++++++++++++++++++++++ tux/utils/constants.py | 3 ++ 8 files changed, 152 insertions(+), 10 deletions(-) create mode 100644 tux/database/redis.py diff --git a/.env.example b/.env.example index ab715c73..2e96f490 100644 --- a/.env.example +++ b/.env.example @@ -17,6 +17,8 @@ DEV_TOKEN="" # # Optional # +REDIS_ENABLED=true +REDIS_URL=redis://localhost:6379 SENTRY_URL="" diff --git a/docker-compose.yml b/docker-compose.yml index 7d794d12..45137f5f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -12,4 +12,16 @@ services: ignore: - .venv/ env_file: - - .env \ No newline at end of file + - .env + + redis: + image: redis:latest + container_name: redis + restart: always + ports: + - "6379:6379" + volumes: + - redis_data:/data + +volumes: + redis_data: {} diff --git a/poetry.lock b/poetry.lock index fdaa878b..45a560ef 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1905,6 +1905,21 @@ files = [ [package.dependencies] "discord.py" = ">=2.0.0" +[[package]] +name = "redis" +version = "5.1.1" +description = "Python client for Redis database and key-value store" +optional = false +python-versions = ">=3.8" +files = [ + {file = "redis-5.1.1-py3-none-any.whl", hash = "sha256:f8ea06b7482a668c6475ae202ed8d9bcaa409f6e87fb77ed1043d912afd62e24"}, + {file = "redis-5.1.1.tar.gz", hash = "sha256:f6c997521fedbae53387307c5d0bf784d9acc28d9f1d058abeac566ec4dbed72"}, +] + +[package.extras] +hiredis = ["hiredis (>=3.0.0)"] +ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"] + [[package]] name = "regex" version = "2024.9.11" @@ -2494,4 +2509,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.12,<4" -content-hash = "6ab3fe78575b8027fe7032997cebda8a035a760dc57e21a4ed0a7325d46be874" +content-hash = "ea9d242987cee0e9b5b6227acdefe0212e4a68f9abd291c665c395459134bc77" diff --git a/pyproject.toml b/pyproject.toml index fe3848cc..5fe4c9c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ types-psutil = "^6.0.0.20240621" types-pytz = "^2024.2.0.20240913" types-pyyaml = "^6.0.12.20240808" typing-extensions = "^4.12.2" +redis = "^5.1.1" [tool.poetry.group.docs.dependencies] mkdocs-material = "^9.5.30" diff --git a/tux/bot.py b/tux/bot.py index 4581f61b..3be20955 100644 --- a/tux/bot.py +++ b/tux/bot.py @@ -6,6 +6,8 @@ from tux.cog_loader import CogLoader from tux.database.client import db +from tux.database.redis import redis_manager +from tux.utils.constants import CONST class Tux(commands.Bot): @@ -13,6 +15,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.setup_task = asyncio.create_task(self.setup()) self.is_shutting_down = False + self.redis = redis_manager async def setup(self) -> None: """ @@ -26,11 +29,15 @@ async def setup(self) -> None: logger.info(f"Prisma client connected: {db.is_connected()}") logger.info(f"Prisma client registered: {db.is_registered()}") + # Connect to Redis + logger.info("Setting up Redis client...") + await self.redis.connect(CONST.REDIS_URL) + except Exception as e: logger.critical(f"An error occurred while connecting to the database: {e}") return - # Load Jishaku for debugging + # Load Jishaku for debuggings await self.load_extension("jishaku") # Load cogs via CogLoader await self.load_cogs() @@ -82,6 +89,7 @@ async def shutdown(self) -> None: try: logger.info("Closing database connections.") await db.disconnect() + await self.redis.disconnect() except Exception as e: logger.critical(f"Error during database disconnection: {e}") diff --git a/tux/cogs/services/levels.py b/tux/cogs/services/levels.py index ea4216f9..cb49bee3 100644 --- a/tux/cogs/services/levels.py +++ b/tux/cogs/services/levels.py @@ -25,6 +25,7 @@ def __init__(self, bot: Tux) -> None: self.levels_exponent = self.settings.get("LEVELS_EXPONENT") self.xp_roles = {role["level"]: role["role_id"] for role in self.settings["XP_ROLES"]} self.xp_multipliers = {role["role_id"]: role["multiplier"] for role in self.settings["XP_MULTIPLIERS"]} + self.redis = bot.redis @commands.Cog.listener("on_message") async def xp_listener(self, message: discord.Message) -> None: @@ -63,7 +64,7 @@ async def process_xp_gain(self, member: discord.Member, guild: discord.Guild) -> if await self.levels_controller.is_blacklisted(member.id, guild.id): return - last_message_time = await self.levels_controller.get_last_message_time(member.id, guild.id) + last_message_time = await self.get_last_message_time(member.id, guild.id) if last_message_time and self.is_on_cooldown(last_message_time): return @@ -73,17 +74,47 @@ async def process_xp_gain(self, member: discord.Member, guild: discord.Guild) -> new_xp = current_xp + xp_increment new_level = self.calculate_level(new_xp) + await self.update_xp_and_level(member.id, guild.id, new_xp, new_level) + + if new_level > current_level: + logger.debug(f"User {member.name} leveled up from {current_level} to {new_level} in guild {guild.name}") + await self.handle_level_up(member, guild, new_level) + + async def get_last_message_time(self, user_id: int, guild_id: int) -> datetime.datetime | None: + """ + Retrieves the last message time for a member from Redis. + + Parameters + ---------- + user_id : int + The ID of the member. + guild_id : int + The ID of the guild. + + Returns + ------- + datetime.datetime | None + The last message time if cached, otherwise None. + """ + cache_key = f"last_message_time:{user_id}:{guild_id}" + cached_time = await self.redis.get(cache_key) + + if cached_time and cached_time != "None": + return datetime.datetime.fromtimestamp(float(cached_time), tz=datetime.UTC) + + return None + + async def update_xp_and_level(self, user_id: int, guild_id: int, new_xp: float, new_level: int) -> None: await self.levels_controller.update_xp_and_level( - member.id, - guild.id, + user_id, + guild_id, new_xp, new_level, datetime.datetime.fromtimestamp(time.time(), tz=datetime.UTC), ) - if new_level > current_level: - logger.debug(f"User {member.name} leveled up from {current_level} to {new_level} in guild {guild.name}") - await self.handle_level_up(member, guild, new_level) + last_message_time_key = f"last_message_time:{user_id}:{guild_id}" + await self.redis.set(last_message_time_key, str(time.time()), expiration=self.xp_cooldown) def is_on_cooldown(self, last_message_time: datetime.datetime) -> bool: """ @@ -141,7 +172,7 @@ async def update_roles(self, member: discord.Member, guild: discord.Guild, new_l roles_to_remove = [r for r in member.roles if r.id in self.xp_roles.values() and r != highest_role] await member.remove_roles(*roles_to_remove) logger.debug( - f"Assigned role {highest_role.name if highest_role else "None"} to member {member} and removed roles {", ".join(r.name for r in roles_to_remove)}", + f"Assigned role {highest_role.name if highest_role else 'None'} to member {member} and removed roles {', '.join(r.name for r in roles_to_remove)}", ) @staticmethod diff --git a/tux/database/redis.py b/tux/database/redis.py new file mode 100644 index 00000000..d1846c40 --- /dev/null +++ b/tux/database/redis.py @@ -0,0 +1,70 @@ +from typing import Any + +import redis.asyncio as redis +from loguru import logger + + +class RedisManager: + def __init__(self): + self.redis: redis.Redis | None = None + self.is_connected: bool = False + + async def connect(self, url: str) -> None: + try: + self.redis = redis.from_url(url, decode_responses=True) # type: ignore + await self.redis.ping() # type: ignore + self.is_connected = True + logger.info("Successfully connected to Redis") + except redis.ConnectionError as e: + logger.warning(f"Failed to connect to Redis: {e}") + self.is_connected = False + + async def disconnect(self) -> None: + if self.redis: + await self.redis.close() + self.is_connected = False + logger.info("Disconnected from Redis") + + async def get(self, key: str) -> str | None: + if self.redis and self.is_connected: + value = await self.redis.get(key) + logger.info(f"Retrieved key '{key}' with value '{value}' from Redis") + return value + logger.warning(f"Failed to retrieve key '{key}' from Redis") + return None + + async def set(self, key: str, value: str, expiration: int | None = None) -> None: + if self.redis and self.is_connected: + await self.redis.set(key, value, ex=expiration) + logger.info(f"Set key '{key}' with value '{value}' in Redis with expiration '{expiration}'") + + async def delete(self, key: str) -> None: + if self.redis and self.is_connected: + await self.redis.delete(key) + logger.info(f"Deleted key '{key}' from Redis") + + async def increment(self, key: str, amount: int = 1) -> int | None: + if self.redis and self.is_connected: + new_value = await self.redis.incr(key, amount) + logger.info(f"Incremented key '{key}' by '{amount}', new value is '{new_value}'") + return new_value + logger.warning(f"Failed to increment key '{key}' by '{amount}'") + return None + + async def zadd(self, key: str, mapping: dict[str, float]) -> None: + if self.redis and self.is_connected: + await self.redis.zadd(key, mapping) + logger.info(f"Added to sorted set '{key}' with mapping '{mapping}'") + + async def zrange(self, key: str, start: int, end: int, desc: bool = False, withscores: bool = False) -> list[Any]: + if self.redis and self.is_connected: + result = await self.redis.zrange(key, start, end, desc=desc, withscores=withscores) # type: ignore + logger.info( + f"Retrieved range from sorted set '{key}' from '{start}' to '{end}', desc='{desc}', withscores='{withscores}': {result}", + ) + return result + logger.warning(f"Failed to retrieve range from sorted set '{key}' from '{start}' to '{end}'") + return [] + + +redis_manager = RedisManager() diff --git a/tux/utils/constants.py b/tux/utils/constants.py index 76bc2cbd..eea61452 100644 --- a/tux/utils/constants.py +++ b/tux/utils/constants.py @@ -30,6 +30,9 @@ class Constants: DEFAULT_DEV_PREFIX: Final[str] = config["DEFAULT_PREFIX"]["DEV"] DEV_COG_IGNORE_LIST: Final[set[str]] = set(os.getenv("DEV_COG_IGNORE_LIST", "").split(",")) + # Redis constants + REDIS_URL: Final[str] = os.getenv("REDIS_URL", "redis://localhost:6379") + # Debug env constants DEBUG: Final[bool] = bool(os.getenv("DEBUG", "True")) From f2c7125a8612f3224d68eb07570ffcd9ec82d837 Mon Sep 17 00:00:00 2001 From: wlinator Date: Sat, 5 Oct 2024 13:35:12 -0400 Subject: [PATCH 2/7] patch(redis): Clean up .env.example --- .env.example | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.env.example b/.env.example index 2e96f490..aecf6bd9 100644 --- a/.env.example +++ b/.env.example @@ -14,11 +14,11 @@ PROD_TOKEN="" DEV_DATABASE_URL="" DEV_TOKEN="" +REDIS_URL=redis://localhost:6379 + # # Optional # -REDIS_ENABLED=true -REDIS_URL=redis://localhost:6379 SENTRY_URL="" From d3c7e2195f46660afeef74d85e52d56a36c0c12d Mon Sep 17 00:00:00 2001 From: wlinator Date: Sat, 5 Oct 2024 13:42:06 -0400 Subject: [PATCH 3/7] chore(redis): Enable optional debug logging in RedisManager --- .env.example | 2 +- tux/database/redis.py | 22 +++++++++++++--------- tux/utils/constants.py | 1 + 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/.env.example b/.env.example index aecf6bd9..48073d13 100644 --- a/.env.example +++ b/.env.example @@ -19,8 +19,8 @@ REDIS_URL=redis://localhost:6379 # # Optional # - SENTRY_URL="" +REDIS_DEBUG_LOG=false PROD_COG_IGNORE_LIST= DEV_COG_IGNORE_LIST= diff --git a/tux/database/redis.py b/tux/database/redis.py index d1846c40..4471e998 100644 --- a/tux/database/redis.py +++ b/tux/database/redis.py @@ -1,13 +1,17 @@ +from collections.abc import Callable from typing import Any import redis.asyncio as redis from loguru import logger +from tux.utils.constants import CONST + class RedisManager: def __init__(self): self.redis: redis.Redis | None = None self.is_connected: bool = False + self.debug_log: Callable[[str], None] = logger.debug if CONST.REDIS_DEBUG_LOG else lambda msg: None async def connect(self, url: str) -> None: try: @@ -28,42 +32,42 @@ async def disconnect(self) -> None: async def get(self, key: str) -> str | None: if self.redis and self.is_connected: value = await self.redis.get(key) - logger.info(f"Retrieved key '{key}' with value '{value}' from Redis") + self.debug_log(f"Retrieved key '{key}' with value '{value}' from Redis") return value - logger.warning(f"Failed to retrieve key '{key}' from Redis") + self.debug_log(f"Failed to retrieve key '{key}' from Redis") return None async def set(self, key: str, value: str, expiration: int | None = None) -> None: if self.redis and self.is_connected: await self.redis.set(key, value, ex=expiration) - logger.info(f"Set key '{key}' with value '{value}' in Redis with expiration '{expiration}'") + self.debug_log(f"Set key '{key}' with value '{value}' in Redis with expiration '{expiration}'") async def delete(self, key: str) -> None: if self.redis and self.is_connected: await self.redis.delete(key) - logger.info(f"Deleted key '{key}' from Redis") + self.debug_log(f"Deleted key '{key}' from Redis") async def increment(self, key: str, amount: int = 1) -> int | None: if self.redis and self.is_connected: new_value = await self.redis.incr(key, amount) - logger.info(f"Incremented key '{key}' by '{amount}', new value is '{new_value}'") + self.debug_log(f"Incremented key '{key}' by '{amount}', new value is '{new_value}'") return new_value - logger.warning(f"Failed to increment key '{key}' by '{amount}'") + self.debug_log(f"Failed to increment key '{key}' by '{amount}'") return None async def zadd(self, key: str, mapping: dict[str, float]) -> None: if self.redis and self.is_connected: await self.redis.zadd(key, mapping) - logger.info(f"Added to sorted set '{key}' with mapping '{mapping}'") + self.debug_log(f"Added to sorted set '{key}' with mapping '{mapping}'") async def zrange(self, key: str, start: int, end: int, desc: bool = False, withscores: bool = False) -> list[Any]: if self.redis and self.is_connected: result = await self.redis.zrange(key, start, end, desc=desc, withscores=withscores) # type: ignore - logger.info( + self.debug_log( f"Retrieved range from sorted set '{key}' from '{start}' to '{end}', desc='{desc}', withscores='{withscores}': {result}", ) return result - logger.warning(f"Failed to retrieve range from sorted set '{key}' from '{start}' to '{end}'") + self.debug_log(f"Failed to retrieve range from sorted set '{key}' from '{start}' to '{end}'") return [] diff --git a/tux/utils/constants.py b/tux/utils/constants.py index eea61452..a55dd862 100644 --- a/tux/utils/constants.py +++ b/tux/utils/constants.py @@ -32,6 +32,7 @@ class Constants: # Redis constants REDIS_URL: Final[str] = os.getenv("REDIS_URL", "redis://localhost:6379") + REDIS_DEBUG_LOG: Final[bool] = os.getenv("REDIS_DEBUG_LOG", "false").lower() == "true" # Debug env constants DEBUG: Final[bool] = bool(os.getenv("DEBUG", "True")) From 0ce6338afddfa7b22151550b64d642e281362dae Mon Sep 17 00:00:00 2001 From: wlinator Date: Sat, 5 Oct 2024 13:44:32 -0400 Subject: [PATCH 4/7] chore(redis): Add docs to RedisManager --- tux/database/redis.py | 90 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/tux/database/redis.py b/tux/database/redis.py index 4471e998..3c0972a5 100644 --- a/tux/database/redis.py +++ b/tux/database/redis.py @@ -14,6 +14,14 @@ def __init__(self): self.debug_log: Callable[[str], None] = logger.debug if CONST.REDIS_DEBUG_LOG else lambda msg: None async def connect(self, url: str) -> None: + """ + Connect to the Redis server. + + Parameters + ---------- + url : str + The URL of the Redis server to connect to. + """ try: self.redis = redis.from_url(url, decode_responses=True) # type: ignore await self.redis.ping() # type: ignore @@ -24,12 +32,28 @@ async def connect(self, url: str) -> None: self.is_connected = False async def disconnect(self) -> None: + """ + Disconnect from the Redis server. + """ if self.redis: await self.redis.close() self.is_connected = False logger.info("Disconnected from Redis") async def get(self, key: str) -> str | None: + """ + Retrieve a value from Redis by key. + + Parameters + ---------- + key : str + The key to retrieve the value for. + + Returns + ------- + str | None + The value associated with the key, or None if the key does not exist. + """ if self.redis and self.is_connected: value = await self.redis.get(key) self.debug_log(f"Retrieved key '{key}' with value '{value}' from Redis") @@ -38,16 +62,51 @@ async def get(self, key: str) -> str | None: return None async def set(self, key: str, value: str, expiration: int | None = None) -> None: + """ + Set a value in Redis with an optional expiration time. + + Parameters + ---------- + key : str + The key to set the value for. + value : str + The value to set. + expiration : int, optional + The expiration time in seconds, by default None. + """ if self.redis and self.is_connected: await self.redis.set(key, value, ex=expiration) self.debug_log(f"Set key '{key}' with value '{value}' in Redis with expiration '{expiration}'") async def delete(self, key: str) -> None: + """ + Delete a key from Redis. + + Parameters + ---------- + key : str + The key to delete. + """ if self.redis and self.is_connected: await self.redis.delete(key) self.debug_log(f"Deleted key '{key}' from Redis") async def increment(self, key: str, amount: int = 1) -> int | None: + """ + Increment the value of a key in Redis by a specified amount. + + Parameters + ---------- + key : str + The key to increment the value for. + amount : int, optional + The amount to increment by, by default 1. + + Returns + ------- + int | None + The new value after incrementing, or None if the operation failed. + """ if self.redis and self.is_connected: new_value = await self.redis.incr(key, amount) self.debug_log(f"Incremented key '{key}' by '{amount}', new value is '{new_value}'") @@ -56,11 +115,42 @@ async def increment(self, key: str, amount: int = 1) -> int | None: return None async def zadd(self, key: str, mapping: dict[str, float]) -> None: + """ + Add one or more members to a sorted set in Redis, or update its score if it already exists. + + Parameters + ---------- + key : str + The key of the sorted set. + mapping : dict[str, float] + A dictionary of member-score pairs to add to the sorted set. + """ if self.redis and self.is_connected: await self.redis.zadd(key, mapping) self.debug_log(f"Added to sorted set '{key}' with mapping '{mapping}'") async def zrange(self, key: str, start: int, end: int, desc: bool = False, withscores: bool = False) -> list[Any]: + """ + Retrieve a range of members in a sorted set by index. + + Parameters + ---------- + key : str + The key of the sorted set. + start : int + The start index. + end : int + The end index. + desc : bool, optional + Whether to sort the results in descending order, by default False. + withscores : bool, optional + Whether to include the scores of the members, by default False. + + Returns + ------- + list[Any] + A list of members in the specified range, optionally with their scores. + """ if self.redis and self.is_connected: result = await self.redis.zrange(key, start, end, desc=desc, withscores=withscores) # type: ignore self.debug_log( From 4db7aa1acd1ed47e43af3609df0d1f9a5a622b9f Mon Sep 17 00:00:00 2001 From: wlinator Date: Sat, 5 Oct 2024 13:47:39 -0400 Subject: [PATCH 5/7] chore(redis): Move REDIS_DEBUG_LOG to settings.yml --- .env.example | 1 - config/settings.yml.example | 2 ++ tux/utils/constants.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.env.example b/.env.example index 48073d13..7701bc2f 100644 --- a/.env.example +++ b/.env.example @@ -20,7 +20,6 @@ REDIS_URL=redis://localhost:6379 # Optional # SENTRY_URL="" -REDIS_DEBUG_LOG=false PROD_COG_IGNORE_LIST= DEV_COG_IGNORE_LIST= diff --git a/config/settings.yml.example b/config/settings.yml.example index 479828e8..22f4bde6 100644 --- a/config/settings.yml.example +++ b/config/settings.yml.example @@ -13,6 +13,8 @@ USER_IDS: TEMPVC_CATEGORY_ID: 123456789012345679 TEMPVC_CHANNEL_ID: 123456789012345679 +REDIS_DEBUG_LOG: false + XP_BLACKLIST_CHANNEL: - 123456789012345679 - 123456789012345679 diff --git a/tux/utils/constants.py b/tux/utils/constants.py index a55dd862..00708a14 100644 --- a/tux/utils/constants.py +++ b/tux/utils/constants.py @@ -32,7 +32,7 @@ class Constants: # Redis constants REDIS_URL: Final[str] = os.getenv("REDIS_URL", "redis://localhost:6379") - REDIS_DEBUG_LOG: Final[bool] = os.getenv("REDIS_DEBUG_LOG", "false").lower() == "true" + REDIS_DEBUG_LOG: Final[bool] = config["REDIS_DEBUG_LOG"] # Debug env constants DEBUG: Final[bool] = bool(os.getenv("DEBUG", "True")) From 8f0fd33a5a6904160a2712705c171624ecbd4fa6 Mon Sep 17 00:00:00 2001 From: wlinator Date: Sat, 5 Oct 2024 13:53:47 -0400 Subject: [PATCH 6/7] chore(levels): Refactor cooldown check in LevelsService --- tux/cogs/services/levels.py | 42 ++++++++----------------------------- tux/database/redis.py | 24 ++++++++++++++++++++- 2 files changed, 32 insertions(+), 34 deletions(-) diff --git a/tux/cogs/services/levels.py b/tux/cogs/services/levels.py index cb49bee3..d792b97f 100644 --- a/tux/cogs/services/levels.py +++ b/tux/cogs/services/levels.py @@ -64,8 +64,7 @@ async def process_xp_gain(self, member: discord.Member, guild: discord.Guild) -> if await self.levels_controller.is_blacklisted(member.id, guild.id): return - last_message_time = await self.get_last_message_time(member.id, guild.id) - if last_message_time and self.is_on_cooldown(last_message_time): + if await self.is_on_cooldown(member.id, guild.id): return current_xp, current_level = await self.levels_controller.get_xp_and_level(member.id, guild.id) @@ -80,9 +79,9 @@ async def process_xp_gain(self, member: discord.Member, guild: discord.Guild) -> logger.debug(f"User {member.name} leveled up from {current_level} to {new_level} in guild {guild.name}") await self.handle_level_up(member, guild, new_level) - async def get_last_message_time(self, user_id: int, guild_id: int) -> datetime.datetime | None: + async def is_on_cooldown(self, user_id: int, guild_id: int) -> bool: """ - Retrieves the last message time for a member from Redis. + Checks if the member is on cooldown. Parameters ---------- @@ -93,16 +92,11 @@ async def get_last_message_time(self, user_id: int, guild_id: int) -> datetime.d Returns ------- - datetime.datetime | None - The last message time if cached, otherwise None. + bool + True if the member is on cooldown, False otherwise. """ - cache_key = f"last_message_time:{user_id}:{guild_id}" - cached_time = await self.redis.get(cache_key) - - if cached_time and cached_time != "None": - return datetime.datetime.fromtimestamp(float(cached_time), tz=datetime.UTC) - - return None + cache_key = f"xp_cooldown:{user_id}:{guild_id}" + return await self.redis.exists(cache_key) async def update_xp_and_level(self, user_id: int, guild_id: int, new_xp: float, new_level: int) -> None: await self.levels_controller.update_xp_and_level( @@ -113,26 +107,8 @@ async def update_xp_and_level(self, user_id: int, guild_id: int, new_xp: float, datetime.datetime.fromtimestamp(time.time(), tz=datetime.UTC), ) - last_message_time_key = f"last_message_time:{user_id}:{guild_id}" - await self.redis.set(last_message_time_key, str(time.time()), expiration=self.xp_cooldown) - - def is_on_cooldown(self, last_message_time: datetime.datetime) -> bool: - """ - Checks if the member is on cooldown. - - Parameters - ---------- - last_message_time : datetime.datetime - The time of the last message. - - Returns - ------- - bool - True if the member is on cooldown, False otherwise. - """ - return (datetime.datetime.fromtimestamp(time.time(), tz=datetime.UTC) - last_message_time) < datetime.timedelta( - seconds=self.xp_cooldown, - ) + cooldown_key = f"xp_cooldown:{user_id}:{guild_id}" + await self.redis.set(cooldown_key, "1", expiration=self.xp_cooldown) async def handle_level_up(self, member: discord.Member, guild: discord.Guild, new_level: int) -> None: """ diff --git a/tux/database/redis.py b/tux/database/redis.py index 3c0972a5..f7d44a79 100644 --- a/tux/database/redis.py +++ b/tux/database/redis.py @@ -9,7 +9,7 @@ class RedisManager: def __init__(self): - self.redis: redis.Redis | None = None + self.redis: redis.Redis = redis.Redis() self.is_connected: bool = False self.debug_log: Callable[[str], None] = logger.debug if CONST.REDIS_DEBUG_LOG else lambda msg: None @@ -160,5 +160,27 @@ async def zrange(self, key: str, start: int, end: int, desc: bool = False, withs self.debug_log(f"Failed to retrieve range from sorted set '{key}' from '{start}' to '{end}'") return [] + async def exists(self, key: str) -> bool: + """ + Check if a key exists in Redis. + + Parameters + ---------- + key : str + The key to check for existence. + + Returns + ------- + bool + True if the key exists, False otherwise. + """ + if self.redis and self.is_connected: + exists = await self.redis.exists(key) + self.debug_log(f"Checked existence of key '{key}', exists: {exists}") + return exists > 0 + + self.debug_log(f"Failed to check existence of key '{key}'") + return False + redis_manager = RedisManager() From 5da7f746b9379ccc706f83d9f35af6bc73f9b7c4 Mon Sep 17 00:00:00 2001 From: wlinator Date: Sat, 5 Oct 2024 14:05:30 -0400 Subject: [PATCH 7/7] feat(redis): Simplify Redis impl. --- tux/bot.py | 9 +- tux/cogs/services/levels.py | 4 +- tux/database/redis.py | 163 +----------------------------------- 3 files changed, 10 insertions(+), 166 deletions(-) diff --git a/tux/bot.py b/tux/bot.py index 3be20955..4d729519 100644 --- a/tux/bot.py +++ b/tux/bot.py @@ -35,9 +35,12 @@ async def setup(self) -> None: except Exception as e: logger.critical(f"An error occurred while connecting to the database: {e}") - return + # You might want to exit the program here if the database connection fails + import sys + + sys.exit(1) - # Load Jishaku for debuggings + # Load Jishaku for debugging await self.load_extension("jishaku") # Load cogs via CogLoader await self.load_cogs() @@ -89,7 +92,7 @@ async def shutdown(self) -> None: try: logger.info("Closing database connections.") await db.disconnect() - await self.redis.disconnect() + await self.redis.interface.close() except Exception as e: logger.critical(f"Error during database disconnection: {e}") diff --git a/tux/cogs/services/levels.py b/tux/cogs/services/levels.py index d792b97f..7991635c 100644 --- a/tux/cogs/services/levels.py +++ b/tux/cogs/services/levels.py @@ -25,7 +25,7 @@ def __init__(self, bot: Tux) -> None: self.levels_exponent = self.settings.get("LEVELS_EXPONENT") self.xp_roles = {role["level"]: role["role_id"] for role in self.settings["XP_ROLES"]} self.xp_multipliers = {role["role_id"]: role["multiplier"] for role in self.settings["XP_MULTIPLIERS"]} - self.redis = bot.redis + self.redis = bot.redis.interface @commands.Cog.listener("on_message") async def xp_listener(self, message: discord.Message) -> None: @@ -108,7 +108,7 @@ async def update_xp_and_level(self, user_id: int, guild_id: int, new_xp: float, ) cooldown_key = f"xp_cooldown:{user_id}:{guild_id}" - await self.redis.set(cooldown_key, "1", expiration=self.xp_cooldown) + await self.redis.set(cooldown_key, "1", ex=self.xp_cooldown) async def handle_level_up(self, member: discord.Member, guild: discord.Guild, new_level: int) -> None: """ diff --git a/tux/database/redis.py b/tux/database/redis.py index f7d44a79..9ae44507 100644 --- a/tux/database/redis.py +++ b/tux/database/redis.py @@ -1,17 +1,10 @@ -from collections.abc import Callable -from typing import Any - import redis.asyncio as redis from loguru import logger -from tux.utils.constants import CONST - class RedisManager: def __init__(self): - self.redis: redis.Redis = redis.Redis() - self.is_connected: bool = False - self.debug_log: Callable[[str], None] = logger.debug if CONST.REDIS_DEBUG_LOG else lambda msg: None + self.interface: redis.Redis = redis.Redis() async def connect(self, url: str) -> None: """ @@ -25,162 +18,10 @@ async def connect(self, url: str) -> None: try: self.redis = redis.from_url(url, decode_responses=True) # type: ignore await self.redis.ping() # type: ignore - self.is_connected = True logger.info("Successfully connected to Redis") + except redis.ConnectionError as e: logger.warning(f"Failed to connect to Redis: {e}") - self.is_connected = False - - async def disconnect(self) -> None: - """ - Disconnect from the Redis server. - """ - if self.redis: - await self.redis.close() - self.is_connected = False - logger.info("Disconnected from Redis") - - async def get(self, key: str) -> str | None: - """ - Retrieve a value from Redis by key. - - Parameters - ---------- - key : str - The key to retrieve the value for. - - Returns - ------- - str | None - The value associated with the key, or None if the key does not exist. - """ - if self.redis and self.is_connected: - value = await self.redis.get(key) - self.debug_log(f"Retrieved key '{key}' with value '{value}' from Redis") - return value - self.debug_log(f"Failed to retrieve key '{key}' from Redis") - return None - - async def set(self, key: str, value: str, expiration: int | None = None) -> None: - """ - Set a value in Redis with an optional expiration time. - - Parameters - ---------- - key : str - The key to set the value for. - value : str - The value to set. - expiration : int, optional - The expiration time in seconds, by default None. - """ - if self.redis and self.is_connected: - await self.redis.set(key, value, ex=expiration) - self.debug_log(f"Set key '{key}' with value '{value}' in Redis with expiration '{expiration}'") - - async def delete(self, key: str) -> None: - """ - Delete a key from Redis. - - Parameters - ---------- - key : str - The key to delete. - """ - if self.redis and self.is_connected: - await self.redis.delete(key) - self.debug_log(f"Deleted key '{key}' from Redis") - - async def increment(self, key: str, amount: int = 1) -> int | None: - """ - Increment the value of a key in Redis by a specified amount. - - Parameters - ---------- - key : str - The key to increment the value for. - amount : int, optional - The amount to increment by, by default 1. - - Returns - ------- - int | None - The new value after incrementing, or None if the operation failed. - """ - if self.redis and self.is_connected: - new_value = await self.redis.incr(key, amount) - self.debug_log(f"Incremented key '{key}' by '{amount}', new value is '{new_value}'") - return new_value - self.debug_log(f"Failed to increment key '{key}' by '{amount}'") - return None - - async def zadd(self, key: str, mapping: dict[str, float]) -> None: - """ - Add one or more members to a sorted set in Redis, or update its score if it already exists. - - Parameters - ---------- - key : str - The key of the sorted set. - mapping : dict[str, float] - A dictionary of member-score pairs to add to the sorted set. - """ - if self.redis and self.is_connected: - await self.redis.zadd(key, mapping) - self.debug_log(f"Added to sorted set '{key}' with mapping '{mapping}'") - - async def zrange(self, key: str, start: int, end: int, desc: bool = False, withscores: bool = False) -> list[Any]: - """ - Retrieve a range of members in a sorted set by index. - - Parameters - ---------- - key : str - The key of the sorted set. - start : int - The start index. - end : int - The end index. - desc : bool, optional - Whether to sort the results in descending order, by default False. - withscores : bool, optional - Whether to include the scores of the members, by default False. - - Returns - ------- - list[Any] - A list of members in the specified range, optionally with their scores. - """ - if self.redis and self.is_connected: - result = await self.redis.zrange(key, start, end, desc=desc, withscores=withscores) # type: ignore - self.debug_log( - f"Retrieved range from sorted set '{key}' from '{start}' to '{end}', desc='{desc}', withscores='{withscores}': {result}", - ) - return result - self.debug_log(f"Failed to retrieve range from sorted set '{key}' from '{start}' to '{end}'") - return [] - - async def exists(self, key: str) -> bool: - """ - Check if a key exists in Redis. - - Parameters - ---------- - key : str - The key to check for existence. - - Returns - ------- - bool - True if the key exists, False otherwise. - """ - if self.redis and self.is_connected: - exists = await self.redis.exists(key) - self.debug_log(f"Checked existence of key '{key}', exists: {exists}") - return exists > 0 - - self.debug_log(f"Failed to check existence of key '{key}'") - return False redis_manager = RedisManager()