Skip to content

Commit

Permalink
Generate url for polls
Browse files Browse the repository at this point in the history
  • Loading branch information
AiroPi committed Sep 3, 2024
1 parent 6aea935 commit 296234f
Show file tree
Hide file tree
Showing 9 changed files with 241 additions and 71 deletions.
56 changes: 56 additions & 0 deletions alembic/versions/a0556697d480_add_url_to_poll_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# type: ignore

"""Add url to poll table
Revision ID: a0556697d480
Revises: 82e8adf72f35
Create Date: 2024-09-03 21:58:21.606304
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.orm import Session

try:
import fastnanoid
except ImportError:
fastnanoid = None

# revision identifiers, used by Alembic.
revision = 'a0556697d480'
down_revision = '82e8adf72f35'
branch_labels = None
depends_on = None


def upgrade() -> None:
op.add_column('poll', sa.Column('url', sa.VARCHAR(length=21), nullable=True))

connection = op.get_bind()
session = Session(bind=connection)

if fastnanoid is None:
return

try:
rows = session.execute(sa.select([sa.table('poll')])).fetchall()

for row in rows:
session.execute(
sa.update(sa.table('poll'))
.where(sa.text('id = :id'))
.values(url=fastnanoid.generate()),
{'id': row['id']}
)

session.commit()
finally:
session.close()

op.alter_column('poll', 'url', nullable=False)


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('poll', 'url')
# ### end Alembic commands ###
1 change: 1 addition & 0 deletions bin/alembic.sh
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
docker compose --progress quiet up database -d --quiet-pull
docker compose --progress quiet build mybot
docker compose --progress quiet run --rm -t -v "${PWD}/alembic:/app/alembic" mybot alembic "$@"
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies = [
"aiohttp",
"python-dateutil",
"typer",
"fastnanoid",
]
requires-python = ">=3.12"

Expand Down
70 changes: 15 additions & 55 deletions src/cogs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,22 @@
import logging
from collections.abc import Awaitable, Callable
from functools import partial
from os import getpid
from typing import TYPE_CHECKING, Concatenate, ParamSpec, TypeVar, cast
from typing import TYPE_CHECKING, Concatenate

from aiohttp import hdrs, web
from psutil import Process

from core import ExtendedCog, config, db
from core.db.queries.poll import get_poll_answers

if TYPE_CHECKING:
from mybot import MyBot


logger = logging.getLogger(__name__)

P = ParamSpec("P")
S = TypeVar("S", bound="ExtendedCog")


def route(method: str, path: str):
def wrap(func: Callable[Concatenate[S, web.Request, P], Awaitable[web.Response]]):
def wrap[C: ExtendedCog](func: Callable[Concatenate[C, web.Request, ...], Awaitable[web.Response]]):
func.__route__ = (method, path) # type: ignore
return func

Expand Down Expand Up @@ -52,71 +48,35 @@ async def start(self) -> None:

await self.runner.setup()
site = web.TCPSite(self.runner, "0.0.0.0", 8080) # noqa: S104 # in a docker container
print("Server started")
logger.info("Server started on address 0.0.0.0:8080")
await site.start()

async def cog_unload(self) -> None:
await self.app.shutdown()
await self.runner.cleanup()

@route(hdrs.METH_GET, "/memory")
async def test(self, request: web.Request):
rss = cast(int, Process(getpid()).memory_info().rss) # pyright: ignore[reportUnknownMemberType]
return web.Response(text=f"{round(rss / 1024 / 1024, 2)} MB")

@route(hdrs.METH_GET, r"/poll/{poll_message_id:\d+}/")
@route(hdrs.METH_GET, r"/poll/{poll_url}/")
async def poll(self, request: web.Request):
poll_message_id = int(request.match_info["poll_message_id"])
result = await db.get_poll_informations(self.bot)(poll_message_id)
poll_url = request.match_info["poll_url"]
result = await db.get_poll_informations(self.bot)(poll_url)
if result is None:
return web.Response(status=404)
poll, values = result

return web.Response(
text=json.dumps(
{
"poll_id": poll.id,
"title": poll.title,
"description": poll.description,
"type": poll.type.name,
"values": values,
}
)
)

@route(hdrs.METH_GET, r"/poll/{poll_message_id:\d+}/{choice_id:\d+}/")

return web.Response(text=json.dumps(result))

@route(hdrs.METH_GET, r"/poll/{poll_url}/{choice_id:\d+}/")
async def poll_votes(self, request: web.Request):
message_id = int(request.match_info["poll_message_id"])
poll_url = request.match_info["poll_url"]
choice_id = int(request.match_info["choice_id"])
try:
from_ = int(request.query.get("from", 0))
number = max(int(request.query.get("number", 10)), 100)
except ValueError:
return web.Response(status=400)

async with self.bot.async_session() as session:
result = await session.execute(
db.select(db.PollAnswer)
.join(db.Poll)
.where(db.Poll.message_id == message_id)
.where(db.PollAnswer.poll_id == db.Poll.id, db.PollAnswer.value == str(choice_id))
.limit(number)
.offset(from_)
)
votes = result.scalars().all()

return web.Response(
text=json.dumps(
[
{
"id": vote.id,
"user_id": vote.user_id,
"anonymous": vote.anonymous,
}
for vote in votes
]
)
)
votes = await get_poll_answers(self.bot)(poll_url, choice_id, from_, number)

return web.Response(text=json.dumps(votes))


async def setup(bot: MyBot):
Expand Down
2 changes: 2 additions & 0 deletions src/cogs/poll/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import TYPE_CHECKING, Self, cast

import discord
import fastnanoid
from discord import app_commands, ui
from discord.app_commands import locale_str as __
from sqlalchemy.orm import selectinload
Expand Down Expand Up @@ -59,6 +60,7 @@ async def callback(self, inter: Interaction, poll_type: db.PollType) -> None:
author_id=inter.user.id,
type=db.PollType(poll_type.value),
creation_date=inter.created_at,
url=fastnanoid.generate(),
)

poll_menu_from_type: dict[db.PollType, type[PollModal]] = {
Expand Down
2 changes: 1 addition & 1 deletion src/cogs/poll/vote_menus.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def __init__(self, cog: PollCog, poll: db.Poll | None = None):
ui.Button(
style=discord.ButtonStyle.url,
label=_("Results", _silent=True),
url=f"http://localhost:8000/poll/{poll.message_id}",
url=f"http://localhost:8000/poll/{poll.url}",
)
)

Expand Down
112 changes: 99 additions & 13 deletions src/core/db/queries/poll.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,50 @@
from typing import TYPE_CHECKING, Any
from __future__ import annotations

from typing import TYPE_CHECKING, Literal, TypedDict

import sqlalchemy as sql
from sqlalchemy import orm

from ..tables import Poll, PollAnswer, PollChoice, PollType
from ..tables import Poll, PollAnswer, PollChoice, PollType, UserDB
from ..utils import with_session

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession


class PollInformation(TypedDict):
poll_id: int
title: str
description: str | None
type: int
values: list[ChoiceInformation]


class ChoiceInformation(TypedDict):
id: int
label: str
count: int
answers_preview: list[AnswerInformation]


class AnonAnswerInformation(TypedDict):
username: None
avatar: None
anonymous: Literal[True]


class PublicAnswerInformation(TypedDict):
username: str
avatar: str
anonymous: Literal[False]


AnswerInformation = AnonAnswerInformation | PublicAnswerInformation


@with_session
async def get_poll_informations(session: AsyncSession, message_id: int):
query = sql.select(Poll).where(Poll.message_id == message_id).options(orm.noload(Poll.choices))
async def get_poll_informations(session: AsyncSession, poll_url: str) -> PollInformation | None:
query = sql.select(Poll).where(Poll.url == poll_url).options(orm.noload(Poll.choices))
result = await session.execute(query)
poll = result.scalar_one_or_none()
if poll is None:
Expand All @@ -21,14 +53,16 @@ async def get_poll_informations(session: AsyncSession, message_id: int):
answer_count_subquery = (
sql.select(
sql.cast(PollAnswer.value, sql.Integer).label("choice_id"),
sql.func.count(sql.PollAnswer.id).label("choice_count"),
sql.func.count(PollAnswer.id).label("choice_count"),
)
.where(PollAnswer.poll_id == poll.id)
.group_by(PollAnswer.value)
.subquery()
)
user_ids_subquery = (
sql.select(sql.cast(PollAnswer.value, sql.Integer).label("choice_id"), PollAnswer.user_id)
sql.select(
sql.cast(PollAnswer.value, sql.Integer).label("choice_id"), PollAnswer.user_id, PollAnswer.anonymous
)
.where(PollAnswer.poll_id == poll.id)
.limit(3)
.subquery()
Expand All @@ -37,26 +71,78 @@ async def get_poll_informations(session: AsyncSession, message_id: int):
sql.select(
PollChoice,
sql.func.coalesce(answer_count_subquery.c.choice_count, 0).label("choice_count"),
sql.func.array_agg(user_ids_subquery.c.user_id).label("user_ids"),
sql.func.array_agg(UserDB.username).label("usernames"),
sql.func.array_agg(UserDB.avatar).label("avatars"),
sql.func.array_agg(user_ids_subquery.c.anonymous).label("anon"),
)
.outerjoin(
answer_count_subquery,
sql.PollChoice.id == answer_count_subquery.c.choice_id,
PollChoice.id == answer_count_subquery.c.choice_id,
)
.outerjoin(
user_ids_subquery,
sql.PollChoice.id == user_ids_subquery.c.choice_id,
PollChoice.id == user_ids_subquery.c.choice_id,
)
.outerjoin(UserDB, user_ids_subquery.c.user_id == UserDB.user_id)
.where(PollChoice.poll_id == poll.id)
.group_by(PollChoice.id, answer_count_subquery.c.choice_count)
)
result = await session.execute(query)
choices = result.all()
values: list[dict[str, Any]] = [
{"id": c.id, "label": c.label, "count": nb, "users_preview": (users if users != [None] else [])}
for c, nb, users in choices
values: list[ChoiceInformation] = [
{
"id": c.id,
"label": c.label,
"count": nb,
"answers_preview": [
{
"username": username if not an else None,
"avatar": avatar if not an else None,
"anonymous": an,
}
for username, avatar, an in zip(usernames, avatars, anon)
]
# Postgres return [NULL] instead of an empty array, so we replace [None] with [].
# There is a minor scenario where this is a problem: if there is exactly one vote from a user why is not in the database.
# This minor case is ignored for now.
if usernames != [None]
else [],
}
for c, nb, usernames, avatars, anon in choices
]
else:
values = []

return poll, values
return PollInformation(
poll_id=poll.id, title=poll.title, description=poll.description, type=poll.type.value, values=values
)


@with_session
async def get_poll_answers(
session: AsyncSession, poll_url: str, choice_id: int, from_: int, number: int
) -> list[AnswerInformation]:
result = await session.execute(
sql.select(
PollAnswer.anonymous,
UserDB.username,
UserDB.avatar,
)
.join(Poll, Poll.id == PollAnswer.poll_id)
.outerjoin(UserDB, PollAnswer.user_id == UserDB.user_id)
.where(
Poll.url == poll_url,
PollAnswer.poll_id == Poll.id,
PollAnswer.value == str(choice_id),
)
.limit(number)
.offset(from_)
)
return [
{
"username": username if not anon else None,
"avatar": avatar if not anon else None,
"anonymous": anon,
}
for anon, username, avatar in result.all()
]
1 change: 1 addition & 0 deletions src/core/db/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class Poll(Base, kw_only=True):
public_results: Mapped[bool] = mapped_column(default=True)
closed: Mapped[bool] = mapped_column(default=False)
anonymous_allowed: Mapped[bool] = mapped_column(default=False)
url: Mapped[str] = mapped_column(VARCHAR(21))
allowed_roles: Mapped[list[Snowflake]] = _mapped_column(
MutableList.as_mutable(ARRAY(BigInteger)), default_factory=list
)
Expand Down
Loading

0 comments on commit 296234f

Please sign in to comment.