Skip to content

Commit

Permalink
Fix facebook support, add postfmt command
Browse files Browse the repository at this point in the history
  • Loading branch information
amadejkastelic committed Jul 14, 2024
1 parent 6f51412 commit b091896
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 4 deletions.
1 change: 1 addition & 0 deletions bot/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class Store(enum.Enum):
SERVER = 'server'
SERVER_POST_COUNT = 'server_post_count'
SERVER_USER_BANNED = 'server_user_banned'
SERVER_INTEGRATION_POST_FORMAT = 'srv_int_post_fmt'


NO_HIT = -1
Expand Down
2 changes: 1 addition & 1 deletion bot/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Integration(object):
def __str__(self) -> str:
return INTEGRATION_INFO_FORMAT.format(
name=self.integration.value.capitalize(),
enabled=self.enabled,
enabled='Enabled' if self.enabled else 'Disabled',
)


Expand Down
6 changes: 5 additions & 1 deletion bot/downloader/facebook/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
import typing
from urllib import parse as urllib_parse

import facebook_scraper
from django.conf import settings
Expand Down Expand Up @@ -35,11 +36,14 @@ def __init__(self, cookies_path: str):
self.cookies_path = cookies_path

async def get_integration_data(self, url: str) -> typing.Tuple[constants.Integration, str, typing.Optional[int]]:
if url.split('?')[0].endswith('/watch') and 'v=' in url:
return self.INTEGRATION, urllib_parse.parse_qs(urllib_parse.urlparse(url).query).get('v', None), None

return self.INTEGRATION, url.split('?')[0].split('/')[-1], None

async def get_post(self, url: str) -> domain.Post:
kwargs = {}
if os.path.exists(self.cookies_path):
if self.cookies_path and os.path.exists(self.cookies_path):
kwargs['cookies'] = 'cookies.txt'

fb_post = next(facebook_scraper.get_posts(post_urls=[url], **kwargs))
Expand Down
2 changes: 2 additions & 0 deletions bot/downloader/facebook/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@ class FacebookConfig(base.BaseClientConfig):


class FacebookConfigSchema(base.BaseClientConfigSchema):
_CONFIG_CLASS = FacebookConfig

cookies_file_path = marshmallow_fields.Str(allow_none=True, load_default=None)
20 changes: 19 additions & 1 deletion bot/integrations/discord/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ def __init__(self, *, intents: discord.Intents, **options: typing.Any) -> None:
description='(Un)Ban a user from using embed commands',
callback=self.silence_user,
),
app_commands.Command(
name='postfmt',
description='Fetches post format for specified site',
callback=self.get_post_format,
),
]

self.tree = app_commands.CommandTree(client=self)
Expand Down Expand Up @@ -96,7 +101,7 @@ async def on_message(self, message: discord.Message): # noqa: C901
),
new_message.add_reaction('❌'),
)
return
raise e

try:
msg = await self._send_post(post=post, send_func=message.channel.send, author=message.author)
Expand Down Expand Up @@ -194,6 +199,19 @@ async def silence_user(self, interaction: discord.Interaction, member: discord.M
ephemeral=True,
)

@checks.has_permissions(administrator=True)
async def get_post_format(self, interaction: discord.Interaction, site: constants.Integration) -> None:
await interaction.response.defer(ephemeral=True)

await interaction.followup.send(
content=service.get_post_format(
server_vendor=constants.ServerVendor.DISCORD,
server_uid=interaction.guild.id,
integration=site,
),
ephemeral=True,
)

async def _send_post(
self,
post: domain.Post,
Expand Down
32 changes: 32 additions & 0 deletions bot/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,38 @@ def update_post_format(
cache.delete(store=cache.Store.SERVER, key=f'{vendor.value}_{vendor_uid}')


def get_post_format(
vendor: constants.ServerVendor,
vendor_uid: str,
integration: constants.Integration,
) -> str:
cache_key = f'{vendor.value}_{vendor_uid}_{integration.value}'
post_format = cache.get(
store=cache.Store.SERVER_INTEGRATION_POST_FORMAT,
key=cache_key,
)
if post_format != cache.NO_HIT:
return post_format

server_integration = (
models.ServerIntegration.objects.filter(
server__vendor=vendor,
server__vendor_uid=vendor_uid,
integration=integration,
)
.only('post_format')
.first()
)
if server_integration is not None and server_integration.post_format is not None:
post_format = server_integration.post_format
else:
post_format = domain.DEFAULT_POST_FORMAT

cache.set(store=cache.Store.SERVER_INTEGRATION_POST_FORMAT, key=cache_key, value=post_format)

return post_format


def get_server(
vendor: constants.ServerVendor,
vendor_uid: str,
Expand Down
14 changes: 13 additions & 1 deletion bot/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@ def change_server_member_banned_status(
)


def get_post_format(
server_vendor: constants.ServerVendor,
server_uid: str,
integration: constants.Integration,
) -> str:
return repository.get_post_format(
vendor=server_vendor,
vendor_uid=server_uid,
integration=integration,
)


async def get_post(
url: str,
server_vendor: constants.ServerVendor,
Expand All @@ -59,7 +71,7 @@ async def get_post(
)
if not server:
logging.info(f'Server {server_uid} not configured, creating a default config')
repository.create_server(vendor=server_vendor, vendor_uid=server_uid)
server = repository.create_server(vendor=server_vendor, vendor_uid=server_uid)

num_posts_in_server = repository.get_number_of_posts_in_server_from_datetime(
server_id=server._internal_id,
Expand Down

0 comments on commit b091896

Please sign in to comment.