diff --git a/bdfr/__main__.py b/bdfr/__main__.py index 1117a70a..c26f5775 100644 --- a/bdfr/__main__.py +++ b/bdfr/__main__.py @@ -13,53 +13,54 @@ logger = logging.getLogger() _common_options = [ - click.argument('directory', type=str), - click.option('--authenticate', is_flag=True, default=None), - click.option('--config', type=str, default=None), - click.option('--opts', type=str, default=None), - click.option('--disable-module', multiple=True, default=None, type=str), - click.option('--exclude-id', default=None, multiple=True), - click.option('--exclude-id-file', default=None, multiple=True), - click.option('--file-scheme', default=None, type=str), - click.option('--folder-scheme', default=None, type=str), - click.option('--ignore-user', type=str, multiple=True, default=None), - click.option('--include-id-file', multiple=True, default=None), - click.option('--log', type=str, default=None), - click.option('--saved', is_flag=True, default=None), - click.option('--search', default=None, type=str), - click.option('--submitted', is_flag=True, default=None), - click.option('--subscribed', is_flag=True, default=None), - click.option('--time-format', type=str, default=None), - click.option('--upvoted', is_flag=True, default=None), - click.option('-L', '--limit', default=None, type=int), - click.option('-l', '--link', multiple=True, default=None, type=str), - click.option('-m', '--multireddit', multiple=True, default=None, type=str), - click.option('-S', '--sort', type=click.Choice(('hot', 'top', 'new', 'controversial', 'rising', 'relevance')), - default=None), - click.option('-s', '--subreddit', multiple=True, default=None, type=str), - click.option('-t', '--time', type=click.Choice(('all', 'hour', 'day', 'week', 'month', 'year')), default=None), - click.option('-u', '--user', type=str, multiple=True, default=None), - click.option('-v', '--verbose', default=None, count=True), + click.argument("directory", type=str), + click.option("--authenticate", is_flag=True, default=None), + click.option("--config", type=str, default=None), + click.option("--opts", type=str, default=None), + click.option("--disable-module", multiple=True, default=None, type=str), + click.option("--exclude-id", default=None, multiple=True), + click.option("--exclude-id-file", default=None, multiple=True), + click.option("--file-scheme", default=None, type=str), + click.option("--folder-scheme", default=None, type=str), + click.option("--ignore-user", type=str, multiple=True, default=None), + click.option("--include-id-file", multiple=True, default=None), + click.option("--log", type=str, default=None), + click.option("--saved", is_flag=True, default=None), + click.option("--search", default=None, type=str), + click.option("--submitted", is_flag=True, default=None), + click.option("--subscribed", is_flag=True, default=None), + click.option("--time-format", type=str, default=None), + click.option("--upvoted", is_flag=True, default=None), + click.option("-L", "--limit", default=None, type=int), + click.option("-l", "--link", multiple=True, default=None, type=str), + click.option("-m", "--multireddit", multiple=True, default=None, type=str), + click.option( + "-S", "--sort", type=click.Choice(("hot", "top", "new", "controversial", "rising", "relevance")), default=None + ), + click.option("-s", "--subreddit", multiple=True, default=None, type=str), + click.option("-t", "--time", type=click.Choice(("all", "hour", "day", "week", "month", "year")), default=None), + click.option("-u", "--user", type=str, multiple=True, default=None), + click.option("-v", "--verbose", default=None, count=True), ] _downloader_options = [ - click.option('--make-hard-links', is_flag=True, default=None), - click.option('--max-wait-time', type=int, default=None), - click.option('--no-dupes', is_flag=True, default=None), - click.option('--search-existing', is_flag=True, default=None), - click.option('--skip', default=None, multiple=True), - click.option('--skip-domain', default=None, multiple=True), - click.option('--skip-subreddit', default=None, multiple=True), - click.option('--min-score', type=int, default=None), - click.option('--max-score', type=int, default=None), - click.option('--min-score-ratio', type=float, default=None), - click.option('--max-score-ratio', type=float, default=None), + click.option("--make-hard-links", is_flag=True, default=None), + click.option("--max-wait-time", type=int, default=None), + click.option("--no-dupes", is_flag=True, default=None), + click.option("--search-existing", is_flag=True, default=None), + click.option("--skip", default=None, multiple=True), + click.option("--skip-domain", default=None, multiple=True), + click.option("--skip-subreddit", default=None, multiple=True), + click.option("--min-score", type=int, default=None), + click.option("--max-score", type=int, default=None), + click.option("--min-score-ratio", type=float, default=None), + click.option("--max-score-ratio", type=float, default=None), ] _archiver_options = [ - click.option('--all-comments', is_flag=True, default=None), - click.option('--comment-context', is_flag=True, default=None), - click.option('-f', '--format', type=click.Choice(('xml', 'json', 'yaml')), default=None), + click.option("--all-comments", is_flag=True, default=None), + click.option("--comment-context", is_flag=True, default=None), + click.option("-f", "--format", type=click.Choice(("xml", "json", "yaml")), default=None), ] @@ -68,6 +69,7 @@ def wrap(func): for opt in opts: func = opt(func) return func + return wrap @@ -76,7 +78,7 @@ def cli(): pass -@cli.command('download') +@cli.command("download") @_add_options(_common_options) @_add_options(_downloader_options) @click.pass_context @@ -88,13 +90,13 @@ def cli_download(context: click.Context, **_): reddit_downloader = RedditDownloader(config) reddit_downloader.download() except Exception: - logger.exception('Downloader exited unexpectedly') + logger.exception("Downloader exited unexpectedly") raise else: - logger.info('Program complete') + logger.info("Program complete") -@cli.command('archive') +@cli.command("archive") @_add_options(_common_options) @_add_options(_archiver_options) @click.pass_context @@ -106,13 +108,13 @@ def cli_archive(context: click.Context, **_): reddit_archiver = Archiver(config) reddit_archiver.download() except Exception: - logger.exception('Archiver exited unexpectedly') + logger.exception("Archiver exited unexpectedly") raise else: - logger.info('Program complete') + logger.info("Program complete") -@cli.command('clone') +@cli.command("clone") @_add_options(_common_options) @_add_options(_archiver_options) @_add_options(_downloader_options) @@ -125,10 +127,10 @@ def cli_clone(context: click.Context, **_): reddit_scraper = RedditCloner(config) reddit_scraper.download() except Exception: - logger.exception('Scraper exited unexpectedly') + logger.exception("Scraper exited unexpectedly") raise else: - logger.info('Program complete') + logger.info("Program complete") def setup_logging(verbosity: int): @@ -141,7 +143,7 @@ def filter(self, record: logging.LogRecord) -> bool: stream = logging.StreamHandler(sys.stdout) stream.addFilter(StreamExceptionFilter()) - formatter = logging.Formatter('[%(asctime)s - %(name)s - %(levelname)s] - %(message)s') + formatter = logging.Formatter("[%(asctime)s - %(name)s - %(levelname)s] - %(message)s") stream.setFormatter(formatter) logger.addHandler(stream) @@ -151,10 +153,10 @@ def filter(self, record: logging.LogRecord) -> bool: stream.setLevel(logging.DEBUG) else: stream.setLevel(9) - logging.getLogger('praw').setLevel(logging.CRITICAL) - logging.getLogger('prawcore').setLevel(logging.CRITICAL) - logging.getLogger('urllib3').setLevel(logging.CRITICAL) + logging.getLogger("praw").setLevel(logging.CRITICAL) + logging.getLogger("prawcore").setLevel(logging.CRITICAL) + logging.getLogger("urllib3").setLevel(logging.CRITICAL) -if __name__ == '__main__': +if __name__ == "__main__": cli() diff --git a/bdfr/archive_entry/base_archive_entry.py b/bdfr/archive_entry/base_archive_entry.py index 57e36f84..49ea58a8 100644 --- a/bdfr/archive_entry/base_archive_entry.py +++ b/bdfr/archive_entry/base_archive_entry.py @@ -19,21 +19,21 @@ def compile(self) -> dict: @staticmethod def _convert_comment_to_dict(in_comment: Comment) -> dict: out_dict = { - 'author': in_comment.author.name if in_comment.author else 'DELETED', - 'id': in_comment.id, - 'score': in_comment.score, - 'subreddit': in_comment.subreddit.display_name, - 'author_flair': in_comment.author_flair_text, - 'submission': in_comment.submission.id, - 'stickied': in_comment.stickied, - 'body': in_comment.body, - 'is_submitter': in_comment.is_submitter, - 'distinguished': in_comment.distinguished, - 'created_utc': in_comment.created_utc, - 'parent_id': in_comment.parent_id, - 'replies': [], + "author": in_comment.author.name if in_comment.author else "DELETED", + "id": in_comment.id, + "score": in_comment.score, + "subreddit": in_comment.subreddit.display_name, + "author_flair": in_comment.author_flair_text, + "submission": in_comment.submission.id, + "stickied": in_comment.stickied, + "body": in_comment.body, + "is_submitter": in_comment.is_submitter, + "distinguished": in_comment.distinguished, + "created_utc": in_comment.created_utc, + "parent_id": in_comment.parent_id, + "replies": [], } in_comment.replies.replace_more(limit=None) for reply in in_comment.replies: - out_dict['replies'].append(BaseArchiveEntry._convert_comment_to_dict(reply)) + out_dict["replies"].append(BaseArchiveEntry._convert_comment_to_dict(reply)) return out_dict diff --git a/bdfr/archive_entry/comment_archive_entry.py b/bdfr/archive_entry/comment_archive_entry.py index 1bb5c180..1c72811a 100644 --- a/bdfr/archive_entry/comment_archive_entry.py +++ b/bdfr/archive_entry/comment_archive_entry.py @@ -17,5 +17,5 @@ def __init__(self, comment: praw.models.Comment): def compile(self) -> dict: self.source.refresh() self.post_details = self._convert_comment_to_dict(self.source) - self.post_details['submission_title'] = self.source.submission.title + self.post_details["submission_title"] = self.source.submission.title return self.post_details diff --git a/bdfr/archive_entry/submission_archive_entry.py b/bdfr/archive_entry/submission_archive_entry.py index c124e0f0..92f326ee 100644 --- a/bdfr/archive_entry/submission_archive_entry.py +++ b/bdfr/archive_entry/submission_archive_entry.py @@ -18,32 +18,32 @@ def compile(self) -> dict: comments = self._get_comments() self._get_post_details() out = self.post_details - out['comments'] = comments + out["comments"] = comments return out def _get_post_details(self): self.post_details = { - 'title': self.source.title, - 'name': self.source.name, - 'url': self.source.url, - 'selftext': self.source.selftext, - 'score': self.source.score, - 'upvote_ratio': self.source.upvote_ratio, - 'permalink': self.source.permalink, - 'id': self.source.id, - 'author': self.source.author.name if self.source.author else 'DELETED', - 'link_flair_text': self.source.link_flair_text, - 'num_comments': self.source.num_comments, - 'over_18': self.source.over_18, - 'spoiler': self.source.spoiler, - 'pinned': self.source.pinned, - 'locked': self.source.locked, - 'distinguished': self.source.distinguished, - 'created_utc': self.source.created_utc, + "title": self.source.title, + "name": self.source.name, + "url": self.source.url, + "selftext": self.source.selftext, + "score": self.source.score, + "upvote_ratio": self.source.upvote_ratio, + "permalink": self.source.permalink, + "id": self.source.id, + "author": self.source.author.name if self.source.author else "DELETED", + "link_flair_text": self.source.link_flair_text, + "num_comments": self.source.num_comments, + "over_18": self.source.over_18, + "spoiler": self.source.spoiler, + "pinned": self.source.pinned, + "locked": self.source.locked, + "distinguished": self.source.distinguished, + "created_utc": self.source.created_utc, } def _get_comments(self) -> list[dict]: - logger.debug(f'Retrieving full comment tree for submission {self.source.id}') + logger.debug(f"Retrieving full comment tree for submission {self.source.id}") comments = [] self.source.comments.replace_more(limit=None) for top_level_comment in self.source.comments: diff --git a/bdfr/archiver.py b/bdfr/archiver.py index 809af964..3d0d31b6 100644 --- a/bdfr/archiver.py +++ b/bdfr/archiver.py @@ -30,26 +30,28 @@ def download(self): for generator in self.reddit_lists: for submission in generator: try: - if (submission.author and submission.author.name in self.args.ignore_user) or \ - (submission.author is None and 'DELETED' in self.args.ignore_user): + if (submission.author and submission.author.name in self.args.ignore_user) or ( + submission.author is None and "DELETED" in self.args.ignore_user + ): logger.debug( - f'Submission {submission.id} in {submission.subreddit.display_name} skipped' - f' due to {submission.author.name if submission.author else "DELETED"} being an ignored user') + f"Submission {submission.id} in {submission.subreddit.display_name} skipped" + f' due to {submission.author.name if submission.author else "DELETED"} being an ignored user' + ) continue if submission.id in self.excluded_submission_ids: - logger.debug(f'Object {submission.id} in exclusion list, skipping') + logger.debug(f"Object {submission.id} in exclusion list, skipping") continue - logger.debug(f'Attempting to archive submission {submission.id}') + logger.debug(f"Attempting to archive submission {submission.id}") self.write_entry(submission) except prawcore.PrawcoreException as e: - logger.error(f'Submission {submission.id} failed to be archived due to a PRAW exception: {e}') + logger.error(f"Submission {submission.id} failed to be archived due to a PRAW exception: {e}") def get_submissions_from_link(self) -> list[list[praw.models.Submission]]: supplied_submissions = [] for sub_id in self.args.link: if len(sub_id) == 6: supplied_submissions.append(self.reddit_instance.submission(id=sub_id)) - elif re.match(r'^\w{7}$', sub_id): + elif re.match(r"^\w{7}$", sub_id): supplied_submissions.append(self.reddit_instance.comment(id=sub_id)) else: supplied_submissions.append(self.reddit_instance.submission(url=sub_id)) @@ -60,7 +62,7 @@ def get_user_data(self) -> list[Iterator]: if self.args.user and self.args.all_comments: sort = self.determine_sort_function() for user in self.args.user: - logger.debug(f'Retrieving comments of user {user}') + logger.debug(f"Retrieving comments of user {user}") results.append(sort(self.reddit_instance.redditor(user).comments, limit=self.args.limit)) return results @@ -71,43 +73,44 @@ def _pull_lever_entry_factory(praw_item: Union[praw.models.Submission, praw.mode elif isinstance(praw_item, praw.models.Comment): return CommentArchiveEntry(praw_item) else: - raise ArchiverError(f'Factory failed to classify item of type {type(praw_item).__name__}') + raise ArchiverError(f"Factory failed to classify item of type {type(praw_item).__name__}") def write_entry(self, praw_item: Union[praw.models.Submission, praw.models.Comment]): if self.args.comment_context and isinstance(praw_item, praw.models.Comment): - logger.debug(f'Converting comment {praw_item.id} to submission {praw_item.submission.id}') + logger.debug(f"Converting comment {praw_item.id} to submission {praw_item.submission.id}") praw_item = praw_item.submission archive_entry = self._pull_lever_entry_factory(praw_item) - if self.args.format == 'json': + if self.args.format == "json": self._write_entry_json(archive_entry) - elif self.args.format == 'xml': + elif self.args.format == "xml": self._write_entry_xml(archive_entry) - elif self.args.format == 'yaml': + elif self.args.format == "yaml": self._write_entry_yaml(archive_entry) else: - raise ArchiverError(f'Unknown format {self.args.format} given') - logger.info(f'Record for entry item {praw_item.id} written to disk') + raise ArchiverError(f"Unknown format {self.args.format} given") + logger.info(f"Record for entry item {praw_item.id} written to disk") def _write_entry_json(self, entry: BaseArchiveEntry): - resource = Resource(entry.source, '', lambda: None, '.json') + resource = Resource(entry.source, "", lambda: None, ".json") content = json.dumps(entry.compile()) self._write_content_to_disk(resource, content) def _write_entry_xml(self, entry: BaseArchiveEntry): - resource = Resource(entry.source, '', lambda: None, '.xml') - content = dict2xml.dict2xml(entry.compile(), wrap='root') + resource = Resource(entry.source, "", lambda: None, ".xml") + content = dict2xml.dict2xml(entry.compile(), wrap="root") self._write_content_to_disk(resource, content) def _write_entry_yaml(self, entry: BaseArchiveEntry): - resource = Resource(entry.source, '', lambda: None, '.yaml') + resource = Resource(entry.source, "", lambda: None, ".yaml") content = yaml.dump(entry.compile()) self._write_content_to_disk(resource, content) def _write_content_to_disk(self, resource: Resource, content: str): file_path = self.file_name_formatter.format_path(resource, self.download_directory) file_path.parent.mkdir(exist_ok=True, parents=True) - with open(file_path, 'w', encoding="utf-8") as file: + with open(file_path, "w", encoding="utf-8") as file: logger.debug( - f'Writing entry {resource.source_submission.id} to file in {resource.extension[1:].upper()}' - f' format at {file_path}') + f"Writing entry {resource.source_submission.id} to file in {resource.extension[1:].upper()}" + f" format at {file_path}" + ) file.write(content) diff --git a/bdfr/cloner.py b/bdfr/cloner.py index 47e03f8b..c26d17b5 100644 --- a/bdfr/cloner.py +++ b/bdfr/cloner.py @@ -23,4 +23,4 @@ def download(self): self._download_submission(submission) self.write_entry(submission) except prawcore.PrawcoreException as e: - logger.error(f'Submission {submission.id} failed to be cloned due to a PRAW exception: {e}') + logger.error(f"Submission {submission.id} failed to be cloned due to a PRAW exception: {e}") diff --git a/bdfr/configuration.py b/bdfr/configuration.py index c15e429d..a2a53105 100644 --- a/bdfr/configuration.py +++ b/bdfr/configuration.py @@ -1,28 +1,29 @@ #!/usr/bin/env python3 # coding=utf-8 +import logging from argparse import Namespace from pathlib import Path from typing import Optional -import logging import click import yaml logger = logging.getLogger(__name__) + class Configuration(Namespace): def __init__(self): super(Configuration, self).__init__() self.authenticate = False self.config = None self.opts: Optional[str] = None - self.directory: str = '.' + self.directory: str = "." self.disable_module: list[str] = [] self.exclude_id = [] self.exclude_id_file = [] - self.file_scheme: str = '{REDDITOR}_{TITLE}_{POSTID}' - self.folder_scheme: str = '{SUBREDDIT}' + self.file_scheme: str = "{REDDITOR}_{TITLE}_{POSTID}" + self.folder_scheme: str = "{SUBREDDIT}" self.ignore_user = [] self.include_id_file = [] self.limit: Optional[int] = None @@ -42,11 +43,11 @@ def __init__(self): self.max_score = None self.min_score_ratio = None self.max_score_ratio = None - self.sort: str = 'hot' + self.sort: str = "hot" self.submitted: bool = False self.subscribed: bool = False self.subreddit: list[str] = [] - self.time: str = 'all' + self.time: str = "all" self.time_format = None self.upvoted: bool = False self.user: list[str] = [] @@ -54,15 +55,15 @@ def __init__(self): # Archiver-specific options self.all_comments = False - self.format = 'json' + self.format = "json" self.comment_context: bool = False def process_click_arguments(self, context: click.Context): - if context.params.get('opts') is not None: - self.parse_yaml_options(context.params['opts']) + if context.params.get("opts") is not None: + self.parse_yaml_options(context.params["opts"]) for arg_key in context.params.keys(): if not hasattr(self, arg_key): - logger.warning(f'Ignoring an unknown CLI argument: {arg_key}') + logger.warning(f"Ignoring an unknown CLI argument: {arg_key}") continue val = context.params[arg_key] if val is None or val == (): @@ -73,16 +74,16 @@ def process_click_arguments(self, context: click.Context): def parse_yaml_options(self, file_path: str): yaml_file_loc = Path(file_path) if not yaml_file_loc.exists(): - logger.error(f'No YAML file found at {yaml_file_loc}') + logger.error(f"No YAML file found at {yaml_file_loc}") return with yaml_file_loc.open() as file: try: opts = yaml.load(file, Loader=yaml.FullLoader) except yaml.YAMLError as e: - logger.error(f'Could not parse YAML options file: {e}') + logger.error(f"Could not parse YAML options file: {e}") return for arg_key, val in opts.items(): if not hasattr(self, arg_key): - logger.warning(f'Ignoring an unknown YAML argument: {arg_key}') + logger.warning(f"Ignoring an unknown YAML argument: {arg_key}") continue setattr(self, arg_key, val) diff --git a/bdfr/connector.py b/bdfr/connector.py index 3b359e8b..ea970db0 100644 --- a/bdfr/connector.py +++ b/bdfr/connector.py @@ -41,18 +41,18 @@ class SortType(Enum): TOP = auto() class TimeType(Enum): - ALL = 'all' - DAY = 'day' - HOUR = 'hour' - MONTH = 'month' - WEEK = 'week' - YEAR = 'year' + ALL = "all" + DAY = "day" + HOUR = "hour" + MONTH = "month" + WEEK = "week" + YEAR = "year" class RedditConnector(metaclass=ABCMeta): def __init__(self, args: Configuration): self.args = args - self.config_directories = appdirs.AppDirs('bdfr', 'BDFR') + self.config_directories = appdirs.AppDirs("bdfr", "BDFR") self.run_time = datetime.now().isoformat() self._setup_internal_objects() @@ -68,13 +68,13 @@ def _setup_internal_objects(self): self.parse_disabled_modules() self.download_filter = self.create_download_filter() - logger.log(9, 'Created download filter') + logger.log(9, "Created download filter") self.time_filter = self.create_time_filter() - logger.log(9, 'Created time filter') + logger.log(9, "Created time filter") self.sort_filter = self.create_sort_filter() - logger.log(9, 'Created sort filter') + logger.log(9, "Created sort filter") self.file_name_formatter = self.create_file_name_formatter() - logger.log(9, 'Create file name formatter') + logger.log(9, "Create file name formatter") self.create_reddit_instance() self.args.user = list(filter(None, [self.resolve_user_name(user) for user in self.args.user])) @@ -88,7 +88,7 @@ def _setup_internal_objects(self): self.master_hash_list = {} self.authenticator = self.create_authenticator() - logger.log(9, 'Created site authenticator') + logger.log(9, "Created site authenticator") self.args.skip_subreddit = self.split_args_input(self.args.skip_subreddit) self.args.skip_subreddit = {sub.lower() for sub in self.args.skip_subreddit} @@ -96,18 +96,18 @@ def _setup_internal_objects(self): def read_config(self): """Read any cfg values that need to be processed""" if self.args.max_wait_time is None: - self.args.max_wait_time = self.cfg_parser.getint('DEFAULT', 'max_wait_time', fallback=120) - logger.debug(f'Setting maximum download wait time to {self.args.max_wait_time} seconds') + self.args.max_wait_time = self.cfg_parser.getint("DEFAULT", "max_wait_time", fallback=120) + logger.debug(f"Setting maximum download wait time to {self.args.max_wait_time} seconds") if self.args.time_format is None: - option = self.cfg_parser.get('DEFAULT', 'time_format', fallback='ISO') - if re.match(r'^[\s\'\"]*$', option): - option = 'ISO' - logger.debug(f'Setting datetime format string to {option}') + option = self.cfg_parser.get("DEFAULT", "time_format", fallback="ISO") + if re.match(r"^[\s\'\"]*$", option): + option = "ISO" + logger.debug(f"Setting datetime format string to {option}") self.args.time_format = option if not self.args.disable_module: - self.args.disable_module = [self.cfg_parser.get('DEFAULT', 'disabled_modules', fallback='')] + self.args.disable_module = [self.cfg_parser.get("DEFAULT", "disabled_modules", fallback="")] # Update config on disk - with open(self.config_location, 'w') as file: + with open(self.config_location, "w") as file: self.cfg_parser.write(file) def parse_disabled_modules(self): @@ -119,48 +119,48 @@ def parse_disabled_modules(self): def create_reddit_instance(self): if self.args.authenticate: - logger.debug('Using authenticated Reddit instance') - if not self.cfg_parser.has_option('DEFAULT', 'user_token'): - logger.log(9, 'Commencing OAuth2 authentication') - scopes = self.cfg_parser.get('DEFAULT', 'scopes', fallback='identity, history, read, save') + logger.debug("Using authenticated Reddit instance") + if not self.cfg_parser.has_option("DEFAULT", "user_token"): + logger.log(9, "Commencing OAuth2 authentication") + scopes = self.cfg_parser.get("DEFAULT", "scopes", fallback="identity, history, read, save") scopes = OAuth2Authenticator.split_scopes(scopes) oauth2_authenticator = OAuth2Authenticator( scopes, - self.cfg_parser.get('DEFAULT', 'client_id'), - self.cfg_parser.get('DEFAULT', 'client_secret'), + self.cfg_parser.get("DEFAULT", "client_id"), + self.cfg_parser.get("DEFAULT", "client_secret"), ) token = oauth2_authenticator.retrieve_new_token() - self.cfg_parser['DEFAULT']['user_token'] = token - with open(self.config_location, 'w') as file: + self.cfg_parser["DEFAULT"]["user_token"] = token + with open(self.config_location, "w") as file: self.cfg_parser.write(file, True) token_manager = OAuth2TokenManager(self.cfg_parser, self.config_location) self.authenticated = True self.reddit_instance = praw.Reddit( - client_id=self.cfg_parser.get('DEFAULT', 'client_id'), - client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'), + client_id=self.cfg_parser.get("DEFAULT", "client_id"), + client_secret=self.cfg_parser.get("DEFAULT", "client_secret"), user_agent=socket.gethostname(), token_manager=token_manager, ) else: - logger.debug('Using unauthenticated Reddit instance') + logger.debug("Using unauthenticated Reddit instance") self.authenticated = False self.reddit_instance = praw.Reddit( - client_id=self.cfg_parser.get('DEFAULT', 'client_id'), - client_secret=self.cfg_parser.get('DEFAULT', 'client_secret'), + client_id=self.cfg_parser.get("DEFAULT", "client_id"), + client_secret=self.cfg_parser.get("DEFAULT", "client_secret"), user_agent=socket.gethostname(), ) def retrieve_reddit_lists(self) -> list[praw.models.ListingGenerator]: master_list = [] master_list.extend(self.get_subreddits()) - logger.log(9, 'Retrieved subreddits') + logger.log(9, "Retrieved subreddits") master_list.extend(self.get_multireddits()) - logger.log(9, 'Retrieved multireddits') + logger.log(9, "Retrieved multireddits") master_list.extend(self.get_user_data()) - logger.log(9, 'Retrieved user data') + logger.log(9, "Retrieved user data") master_list.extend(self.get_submissions_from_link()) - logger.log(9, 'Retrieved submissions for given links') + logger.log(9, "Retrieved submissions for given links") return master_list def determine_directories(self): @@ -178,37 +178,37 @@ def load_config(self): self.config_location = cfg_path return possible_paths = [ - Path('./config.cfg'), - Path('./default_config.cfg'), - Path(self.config_directory, 'config.cfg'), - Path(self.config_directory, 'default_config.cfg'), + Path("./config.cfg"), + Path("./default_config.cfg"), + Path(self.config_directory, "config.cfg"), + Path(self.config_directory, "default_config.cfg"), ] self.config_location = None for path in possible_paths: if path.resolve().expanduser().exists(): self.config_location = path - logger.debug(f'Loading configuration from {path}') + logger.debug(f"Loading configuration from {path}") break if not self.config_location: - with importlib.resources.path('bdfr', 'default_config.cfg') as path: + with importlib.resources.path("bdfr", "default_config.cfg") as path: self.config_location = path - shutil.copy(self.config_location, Path(self.config_directory, 'default_config.cfg')) + shutil.copy(self.config_location, Path(self.config_directory, "default_config.cfg")) if not self.config_location: - raise errors.BulkDownloaderException('Could not find a configuration file to load') + raise errors.BulkDownloaderException("Could not find a configuration file to load") self.cfg_parser.read(self.config_location) def create_file_logger(self): main_logger = logging.getLogger() if self.args.log is None: - log_path = Path(self.config_directory, 'log_output.txt') + log_path = Path(self.config_directory, "log_output.txt") else: log_path = Path(self.args.log).resolve().expanduser() if not log_path.parent.exists(): - raise errors.BulkDownloaderException(f'Designated location for logfile does not exist') - backup_count = self.cfg_parser.getint('DEFAULT', 'backup_log_count', fallback=3) + raise errors.BulkDownloaderException(f"Designated location for logfile does not exist") + backup_count = self.cfg_parser.getint("DEFAULT", "backup_log_count", fallback=3) file_handler = logging.handlers.RotatingFileHandler( log_path, - mode='a', + mode="a", backupCount=backup_count, ) if log_path.exists(): @@ -216,10 +216,11 @@ def create_file_logger(self): file_handler.doRollover() except PermissionError: logger.critical( - 'Cannot rollover logfile, make sure this is the only ' - 'BDFR process or specify alternate logfile location') + "Cannot rollover logfile, make sure this is the only " + "BDFR process or specify alternate logfile location" + ) raise - formatter = logging.Formatter('[%(asctime)s - %(name)s - %(levelname)s] - %(message)s') + formatter = logging.Formatter("[%(asctime)s - %(name)s - %(levelname)s] - %(message)s") file_handler.setFormatter(formatter) file_handler.setLevel(0) @@ -227,16 +228,16 @@ def create_file_logger(self): @staticmethod def sanitise_subreddit_name(subreddit: str) -> str: - pattern = re.compile(r'^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)/?$') + pattern = re.compile(r"^(?:https://www\.reddit\.com/)?(?:r/)?(.*?)/?$") match = re.match(pattern, subreddit) if not match: - raise errors.BulkDownloaderException(f'Could not find subreddit name in string {subreddit}') + raise errors.BulkDownloaderException(f"Could not find subreddit name in string {subreddit}") return match.group(1) @staticmethod def split_args_input(entries: list[str]) -> set[str]: all_entries = [] - split_pattern = re.compile(r'[,;]\s?') + split_pattern = re.compile(r"[,;]\s?") for entry in entries: results = re.split(split_pattern, entry) all_entries.extend([RedditConnector.sanitise_subreddit_name(name) for name in results]) @@ -251,13 +252,13 @@ def get_subreddits(self) -> list[praw.models.ListingGenerator]: subscribed_subreddits = list(self.reddit_instance.user.subreddits(limit=None)) subscribed_subreddits = {s.display_name for s in subscribed_subreddits} except prawcore.InsufficientScope: - logger.error('BDFR has insufficient scope to access subreddit lists') + logger.error("BDFR has insufficient scope to access subreddit lists") else: - logger.error('Cannot find subscribed subreddits without an authenticated instance') + logger.error("Cannot find subscribed subreddits without an authenticated instance") if self.args.subreddit or subscribed_subreddits: for reddit in self.split_args_input(self.args.subreddit) | subscribed_subreddits: - if reddit == 'friends' and self.authenticated is False: - logger.error('Cannot read friends subreddit without an authenticated instance') + if reddit == "friends" and self.authenticated is False: + logger.error("Cannot read friends subreddit without an authenticated instance") continue try: reddit = self.reddit_instance.subreddit(reddit) @@ -267,26 +268,29 @@ def get_subreddits(self) -> list[praw.models.ListingGenerator]: logger.error(e) continue if self.args.search: - out.append(reddit.search( - self.args.search, - sort=self.sort_filter.name.lower(), - limit=self.args.limit, - time_filter=self.time_filter.value, - )) + out.append( + reddit.search( + self.args.search, + sort=self.sort_filter.name.lower(), + limit=self.args.limit, + time_filter=self.time_filter.value, + ) + ) logger.debug( - f'Added submissions from subreddit {reddit} with the search term "{self.args.search}"') + f'Added submissions from subreddit {reddit} with the search term "{self.args.search}"' + ) else: out.append(self.create_filtered_listing_generator(reddit)) - logger.debug(f'Added submissions from subreddit {reddit}') + logger.debug(f"Added submissions from subreddit {reddit}") except (errors.BulkDownloaderException, praw.exceptions.PRAWException) as e: - logger.error(f'Failed to get submissions for subreddit {reddit}: {e}') + logger.error(f"Failed to get submissions for subreddit {reddit}: {e}") return out def resolve_user_name(self, in_name: str) -> str: - if in_name == 'me': + if in_name == "me": if self.authenticated: resolved_name = self.reddit_instance.user.me().name - logger.log(9, f'Resolved user to {resolved_name}') + logger.log(9, f"Resolved user to {resolved_name}") return resolved_name else: logger.warning('To use "me" as a user, an authenticated Reddit instance must be used') @@ -318,7 +322,7 @@ def determine_sort_function(self) -> Callable: def get_multireddits(self) -> list[Iterator]: if self.args.multireddit: if len(self.args.user) != 1: - logger.error(f'Only 1 user can be supplied when retrieving from multireddits') + logger.error(f"Only 1 user can be supplied when retrieving from multireddits") return [] out = [] for multi in self.split_args_input(self.args.multireddit): @@ -327,9 +331,9 @@ def get_multireddits(self) -> list[Iterator]: if not multi.subreddits: raise errors.BulkDownloaderException out.append(self.create_filtered_listing_generator(multi)) - logger.debug(f'Added submissions from multireddit {multi}') + logger.debug(f"Added submissions from multireddit {multi}") except (errors.BulkDownloaderException, praw.exceptions.PRAWException, prawcore.PrawcoreException) as e: - logger.error(f'Failed to get submissions for multireddit {multi}: {e}') + logger.error(f"Failed to get submissions for multireddit {multi}: {e}") return out else: return [] @@ -344,7 +348,7 @@ def create_filtered_listing_generator(self, reddit_source) -> Iterator: def get_user_data(self) -> list[Iterator]: if any([self.args.submitted, self.args.upvoted, self.args.saved]): if not self.args.user: - logger.warning('At least one user must be supplied to download user data') + logger.warning("At least one user must be supplied to download user data") return [] generators = [] for user in self.args.user: @@ -354,18 +358,20 @@ def get_user_data(self) -> list[Iterator]: logger.error(e) continue if self.args.submitted: - logger.debug(f'Retrieving submitted posts of user {self.args.user}') - generators.append(self.create_filtered_listing_generator( - self.reddit_instance.redditor(user).submissions, - )) + logger.debug(f"Retrieving submitted posts of user {self.args.user}") + generators.append( + self.create_filtered_listing_generator( + self.reddit_instance.redditor(user).submissions, + ) + ) if not self.authenticated and any((self.args.upvoted, self.args.saved)): - logger.warning('Accessing user lists requires authentication') + logger.warning("Accessing user lists requires authentication") else: if self.args.upvoted: - logger.debug(f'Retrieving upvoted posts of user {self.args.user}') + logger.debug(f"Retrieving upvoted posts of user {self.args.user}") generators.append(self.reddit_instance.redditor(user).upvoted(limit=self.args.limit)) if self.args.saved: - logger.debug(f'Retrieving saved posts of user {self.args.user}') + logger.debug(f"Retrieving saved posts of user {self.args.user}") generators.append(self.reddit_instance.redditor(user).saved(limit=self.args.limit)) return generators else: @@ -377,10 +383,10 @@ def check_user_existence(self, name: str): if user.id: return except prawcore.exceptions.NotFound: - raise errors.BulkDownloaderException(f'Could not find user {name}') + raise errors.BulkDownloaderException(f"Could not find user {name}") except AttributeError: - if hasattr(user, 'is_suspended'): - raise errors.BulkDownloaderException(f'User {name} is banned') + if hasattr(user, "is_suspended"): + raise errors.BulkDownloaderException(f"User {name} is banned") def create_file_name_formatter(self) -> FileNameFormatter: return FileNameFormatter(self.args.file_scheme, self.args.folder_scheme, self.args.time_format) @@ -409,7 +415,7 @@ def download(self): @staticmethod def check_subreddit_status(subreddit: praw.models.Subreddit): - if subreddit.display_name in ('all', 'friends'): + if subreddit.display_name in ("all", "friends"): return try: assert subreddit.id @@ -418,7 +424,7 @@ def check_subreddit_status(subreddit: praw.models.Subreddit): except prawcore.Redirect: raise errors.BulkDownloaderException(f"Source {subreddit.display_name} does not exist") except prawcore.Forbidden: - raise errors.BulkDownloaderException(f'Source {subreddit.display_name} is private and cannot be scraped') + raise errors.BulkDownloaderException(f"Source {subreddit.display_name} is private and cannot be scraped") @staticmethod def read_id_files(file_locations: list[str]) -> set[str]: @@ -426,9 +432,9 @@ def read_id_files(file_locations: list[str]) -> set[str]: for id_file in file_locations: id_file = Path(id_file).resolve().expanduser() if not id_file.exists(): - logger.warning(f'ID file at {id_file} does not exist') + logger.warning(f"ID file at {id_file} does not exist") continue - with id_file.open('r') as file: + with id_file.open("r") as file: for line in file: out.append(line.strip()) return set(out) diff --git a/bdfr/download_filter.py b/bdfr/download_filter.py index 28053bed..9019cc98 100644 --- a/bdfr/download_filter.py +++ b/bdfr/download_filter.py @@ -33,8 +33,8 @@ def check_resource(self, res: Resource) -> bool: def _check_extension(self, resource_extension: str) -> bool: if not self.excluded_extensions: return True - combined_extensions = '|'.join(self.excluded_extensions) - pattern = re.compile(r'.*({})$'.format(combined_extensions)) + combined_extensions = "|".join(self.excluded_extensions) + pattern = re.compile(r".*({})$".format(combined_extensions)) if re.match(pattern, resource_extension): logger.log(9, f'Url "{resource_extension}" matched with "{pattern}"') return False @@ -44,8 +44,8 @@ def _check_extension(self, resource_extension: str) -> bool: def _check_domain(self, url: str) -> bool: if not self.excluded_domains: return True - combined_domains = '|'.join(self.excluded_domains) - pattern = re.compile(r'https?://.*({}).*'.format(combined_domains)) + combined_domains = "|".join(self.excluded_domains) + pattern = re.compile(r"https?://.*({}).*".format(combined_domains)) if re.match(pattern, url): logger.log(9, f'Url "{url}" matched with "{pattern}"') return False diff --git a/bdfr/downloader.py b/bdfr/downloader.py index 6f269377..fa5d10c0 100644 --- a/bdfr/downloader.py +++ b/bdfr/downloader.py @@ -25,7 +25,7 @@ def _calc_hash(existing_file: Path): chunk_size = 1024 * 1024 md5_hash = hashlib.md5() - with existing_file.open('rb') as file: + with existing_file.open("rb") as file: chunk = file.read(chunk_size) while chunk: md5_hash.update(chunk) @@ -46,28 +46,32 @@ def download(self): try: self._download_submission(submission) except prawcore.PrawcoreException as e: - logger.error(f'Submission {submission.id} failed to download due to a PRAW exception: {e}') + logger.error(f"Submission {submission.id} failed to download due to a PRAW exception: {e}") def _download_submission(self, submission: praw.models.Submission): if submission.id in self.excluded_submission_ids: - logger.debug(f'Object {submission.id} in exclusion list, skipping') + logger.debug(f"Object {submission.id} in exclusion list, skipping") return elif submission.subreddit.display_name.lower() in self.args.skip_subreddit: - logger.debug(f'Submission {submission.id} in {submission.subreddit.display_name} in skip list') + logger.debug(f"Submission {submission.id} in {submission.subreddit.display_name} in skip list") return - elif (submission.author and submission.author.name in self.args.ignore_user) or \ - (submission.author is None and 'DELETED' in self.args.ignore_user): + elif (submission.author and submission.author.name in self.args.ignore_user) or ( + submission.author is None and "DELETED" in self.args.ignore_user + ): logger.debug( - f'Submission {submission.id} in {submission.subreddit.display_name} skipped' - f' due to {submission.author.name if submission.author else "DELETED"} being an ignored user') + f"Submission {submission.id} in {submission.subreddit.display_name} skipped" + f' due to {submission.author.name if submission.author else "DELETED"} being an ignored user' + ) return elif self.args.min_score and submission.score < self.args.min_score: logger.debug( - f"Submission {submission.id} filtered due to score {submission.score} < [{self.args.min_score}]") + f"Submission {submission.id} filtered due to score {submission.score} < [{self.args.min_score}]" + ) return elif self.args.max_score and self.args.max_score < submission.score: logger.debug( - f"Submission {submission.id} filtered due to score {submission.score} > [{self.args.max_score}]") + f"Submission {submission.id} filtered due to score {submission.score} > [{self.args.max_score}]" + ) return elif (self.args.min_score_ratio and submission.upvote_ratio < self.args.min_score_ratio) or ( self.args.max_score_ratio and self.args.max_score_ratio < submission.upvote_ratio @@ -75,47 +79,48 @@ def _download_submission(self, submission: praw.models.Submission): logger.debug(f"Submission {submission.id} filtered due to score ratio ({submission.upvote_ratio})") return elif not isinstance(submission, praw.models.Submission): - logger.warning(f'{submission.id} is not a submission') + logger.warning(f"{submission.id} is not a submission") return elif not self.download_filter.check_url(submission.url): - logger.debug(f'Submission {submission.id} filtered due to URL {submission.url}') + logger.debug(f"Submission {submission.id} filtered due to URL {submission.url}") return - logger.debug(f'Attempting to download submission {submission.id}') + logger.debug(f"Attempting to download submission {submission.id}") try: downloader_class = DownloadFactory.pull_lever(submission.url) downloader = downloader_class(submission) - logger.debug(f'Using {downloader_class.__name__} with url {submission.url}') + logger.debug(f"Using {downloader_class.__name__} with url {submission.url}") except errors.NotADownloadableLinkError as e: - logger.error(f'Could not download submission {submission.id}: {e}') + logger.error(f"Could not download submission {submission.id}: {e}") return if downloader_class.__name__.lower() in self.args.disable_module: - logger.debug(f'Submission {submission.id} skipped due to disabled module {downloader_class.__name__}') + logger.debug(f"Submission {submission.id} skipped due to disabled module {downloader_class.__name__}") return try: content = downloader.find_resources(self.authenticator) except errors.SiteDownloaderError as e: - logger.error(f'Site {downloader_class.__name__} failed to download submission {submission.id}: {e}') + logger.error(f"Site {downloader_class.__name__} failed to download submission {submission.id}: {e}") return for destination, res in self.file_name_formatter.format_resource_paths(content, self.download_directory): if destination.exists(): - logger.debug(f'File {destination} from submission {submission.id} already exists, continuing') + logger.debug(f"File {destination} from submission {submission.id} already exists, continuing") continue elif not self.download_filter.check_resource(res): - logger.debug(f'Download filter removed {submission.id} file with URL {submission.url}') + logger.debug(f"Download filter removed {submission.id} file with URL {submission.url}") continue try: - res.download({'max_wait_time': self.args.max_wait_time}) + res.download({"max_wait_time": self.args.max_wait_time}) except errors.BulkDownloaderException as e: - logger.error(f'Failed to download resource {res.url} in submission {submission.id} ' - f'with downloader {downloader_class.__name__}: {e}') + logger.error( + f"Failed to download resource {res.url} in submission {submission.id} " + f"with downloader {downloader_class.__name__}: {e}" + ) return resource_hash = res.hash.hexdigest() destination.parent.mkdir(parents=True, exist_ok=True) if resource_hash in self.master_hash_list: if self.args.no_dupes: - logger.info( - f'Resource hash {resource_hash} from submission {submission.id} downloaded elsewhere') + logger.info(f"Resource hash {resource_hash} from submission {submission.id} downloaded elsewhere") return elif self.args.make_hard_links: try: @@ -123,29 +128,30 @@ def _download_submission(self, submission: praw.models.Submission): except AttributeError: self.master_hash_list[resource_hash].link_to(destination) logger.info( - f'Hard link made linking {destination} to {self.master_hash_list[resource_hash]}' - f' in submission {submission.id}') + f"Hard link made linking {destination} to {self.master_hash_list[resource_hash]}" + f" in submission {submission.id}" + ) return try: - with destination.open('wb') as file: + with destination.open("wb") as file: file.write(res.content) - logger.debug(f'Written file to {destination}') + logger.debug(f"Written file to {destination}") except OSError as e: logger.exception(e) - logger.error(f'Failed to write file in submission {submission.id} to {destination}: {e}') + logger.error(f"Failed to write file in submission {submission.id} to {destination}: {e}") return creation_time = time.mktime(datetime.fromtimestamp(submission.created_utc).timetuple()) os.utime(destination, (creation_time, creation_time)) self.master_hash_list[resource_hash] = destination - logger.debug(f'Hash added to master list: {resource_hash}') - logger.info(f'Downloaded submission {submission.id} from {submission.subreddit.display_name}') + logger.debug(f"Hash added to master list: {resource_hash}") + logger.info(f"Downloaded submission {submission.id} from {submission.subreddit.display_name}") @staticmethod def scan_existing_files(directory: Path) -> dict[str, Path]: files = [] for (dirpath, dirnames, filenames) in os.walk(directory): files.extend([Path(dirpath, file) for file in filenames]) - logger.info(f'Calculating hashes for {len(files)} files') + logger.info(f"Calculating hashes for {len(files)} files") pool = Pool(15) results = pool.map(_calc_hash, files) diff --git a/bdfr/exceptions.py b/bdfr/exceptions.py index 91fda2c1..1757cd99 100644 --- a/bdfr/exceptions.py +++ b/bdfr/exceptions.py @@ -1,5 +1,6 @@ #!/usr/bin/env + class BulkDownloaderException(Exception): pass diff --git a/bdfr/file_name_formatter.py b/bdfr/file_name_formatter.py index 4a039c9d..684c6261 100644 --- a/bdfr/file_name_formatter.py +++ b/bdfr/file_name_formatter.py @@ -18,20 +18,20 @@ class FileNameFormatter: key_terms = ( - 'date', - 'flair', - 'postid', - 'redditor', - 'subreddit', - 'title', - 'upvotes', + "date", + "flair", + "postid", + "redditor", + "subreddit", + "title", + "upvotes", ) def __init__(self, file_format_string: str, directory_format_string: str, time_format_string: str): if not self.validate_string(file_format_string): raise BulkDownloaderException(f'"{file_format_string}" is not a valid format string') self.file_format_string = file_format_string - self.directory_format_string: list[str] = directory_format_string.split('/') + self.directory_format_string: list[str] = directory_format_string.split("/") self.time_format_string = time_format_string def _format_name(self, submission: Union[Comment, Submission], format_string: str) -> str: @@ -40,108 +40,111 @@ def _format_name(self, submission: Union[Comment, Submission], format_string: st elif isinstance(submission, Comment): attributes = self._generate_name_dict_from_comment(submission) else: - raise BulkDownloaderException(f'Cannot name object {type(submission).__name__}') + raise BulkDownloaderException(f"Cannot name object {type(submission).__name__}") result = format_string for key in attributes.keys(): - if re.search(fr'(?i).*{{{key}}}.*', result): - key_value = str(attributes.get(key, 'unknown')) + if re.search(rf"(?i).*{{{key}}}.*", result): + key_value = str(attributes.get(key, "unknown")) key_value = FileNameFormatter._convert_unicode_escapes(key_value) - key_value = key_value.replace('\\', '\\\\') - result = re.sub(fr'(?i){{{key}}}', key_value, result) + key_value = key_value.replace("\\", "\\\\") + result = re.sub(rf"(?i){{{key}}}", key_value, result) - result = result.replace('/', '') + result = result.replace("/", "") - if platform.system() == 'Windows': + if platform.system() == "Windows": result = FileNameFormatter._format_for_windows(result) return result @staticmethod def _convert_unicode_escapes(in_string: str) -> str: - pattern = re.compile(r'(\\u\d{4})') + pattern = re.compile(r"(\\u\d{4})") matches = re.search(pattern, in_string) if matches: for match in matches.groups(): - converted_match = bytes(match, 'utf-8').decode('unicode-escape') + converted_match = bytes(match, "utf-8").decode("unicode-escape") in_string = in_string.replace(match, converted_match) return in_string def _generate_name_dict_from_submission(self, submission: Submission) -> dict: submission_attributes = { - 'title': submission.title, - 'subreddit': submission.subreddit.display_name, - 'redditor': submission.author.name if submission.author else 'DELETED', - 'postid': submission.id, - 'upvotes': submission.score, - 'flair': submission.link_flair_text, - 'date': self._convert_timestamp(submission.created_utc), + "title": submission.title, + "subreddit": submission.subreddit.display_name, + "redditor": submission.author.name if submission.author else "DELETED", + "postid": submission.id, + "upvotes": submission.score, + "flair": submission.link_flair_text, + "date": self._convert_timestamp(submission.created_utc), } return submission_attributes def _convert_timestamp(self, timestamp: float) -> str: input_time = datetime.datetime.fromtimestamp(timestamp) - if self.time_format_string.upper().strip() == 'ISO': + if self.time_format_string.upper().strip() == "ISO": return input_time.isoformat() else: return input_time.strftime(self.time_format_string) def _generate_name_dict_from_comment(self, comment: Comment) -> dict: comment_attributes = { - 'title': comment.submission.title, - 'subreddit': comment.subreddit.display_name, - 'redditor': comment.author.name if comment.author else 'DELETED', - 'postid': comment.id, - 'upvotes': comment.score, - 'flair': '', - 'date': self._convert_timestamp(comment.created_utc), + "title": comment.submission.title, + "subreddit": comment.subreddit.display_name, + "redditor": comment.author.name if comment.author else "DELETED", + "postid": comment.id, + "upvotes": comment.score, + "flair": "", + "date": self._convert_timestamp(comment.created_utc), } return comment_attributes def format_path( - self, - resource: Resource, - destination_directory: Path, - index: Optional[int] = None, + self, + resource: Resource, + destination_directory: Path, + index: Optional[int] = None, ) -> Path: subfolder = Path( destination_directory, *[self._format_name(resource.source_submission, part) for part in self.directory_format_string], ) - index = f'_{index}' if index else '' + index = f"_{index}" if index else "" if not resource.extension: - raise BulkDownloaderException(f'Resource from {resource.url} has no extension') + raise BulkDownloaderException(f"Resource from {resource.url} has no extension") file_name = str(self._format_name(resource.source_submission, self.file_format_string)) - file_name = re.sub(r'\n', ' ', file_name) + file_name = re.sub(r"\n", " ", file_name) - if not re.match(r'.*\.$', file_name) and not re.match(r'^\..*', resource.extension): - ending = index + '.' + resource.extension + if not re.match(r".*\.$", file_name) and not re.match(r"^\..*", resource.extension): + ending = index + "." + resource.extension else: ending = index + resource.extension try: file_path = self.limit_file_name_length(file_name, ending, subfolder) except TypeError: - raise BulkDownloaderException(f'Could not determine path name: {subfolder}, {index}, {resource.extension}') + raise BulkDownloaderException(f"Could not determine path name: {subfolder}, {index}, {resource.extension}") return file_path @staticmethod def limit_file_name_length(filename: str, ending: str, root: Path) -> Path: root = root.resolve().expanduser() - possible_id = re.search(r'((?:_\w{6})?$)', filename) + possible_id = re.search(r"((?:_\w{6})?$)", filename) if possible_id: ending = possible_id.group(1) + ending - filename = filename[:possible_id.start()] + filename = filename[: possible_id.start()] max_path = FileNameFormatter.find_max_path_length() max_file_part_length_chars = 255 - len(ending) - max_file_part_length_bytes = 255 - len(ending.encode('utf-8')) + max_file_part_length_bytes = 255 - len(ending.encode("utf-8")) max_path_length = max_path - len(ending) - len(str(root)) - 1 out = Path(root, filename + ending) - while any([len(filename) > max_file_part_length_chars, - len(filename.encode('utf-8')) > max_file_part_length_bytes, - len(str(out)) > max_path_length, - ]): + while any( + [ + len(filename) > max_file_part_length_chars, + len(filename.encode("utf-8")) > max_file_part_length_bytes, + len(str(out)) > max_path_length, + ] + ): filename = filename[:-1] out = Path(root, filename + ending) @@ -150,44 +153,46 @@ def limit_file_name_length(filename: str, ending: str, root: Path) -> Path: @staticmethod def find_max_path_length() -> int: try: - return int(subprocess.check_output(['getconf', 'PATH_MAX', '/'])) + return int(subprocess.check_output(["getconf", "PATH_MAX", "/"])) except (ValueError, subprocess.CalledProcessError, OSError): - if platform.system() == 'Windows': + if platform.system() == "Windows": return 260 else: return 4096 def format_resource_paths( - self, - resources: list[Resource], - destination_directory: Path, + self, + resources: list[Resource], + destination_directory: Path, ) -> list[tuple[Path, Resource]]: out = [] if len(resources) == 1: try: out.append((self.format_path(resources[0], destination_directory, None), resources[0])) except BulkDownloaderException as e: - logger.error(f'Could not generate file path for resource {resources[0].url}: {e}') - logger.exception('Could not generate file path') + logger.error(f"Could not generate file path for resource {resources[0].url}: {e}") + logger.exception("Could not generate file path") else: for i, res in enumerate(resources, start=1): - logger.log(9, f'Formatting filename with index {i}') + logger.log(9, f"Formatting filename with index {i}") try: out.append((self.format_path(res, destination_directory, i), res)) except BulkDownloaderException as e: - logger.error(f'Could not generate file path for resource {res.url}: {e}') - logger.exception('Could not generate file path') + logger.error(f"Could not generate file path for resource {res.url}: {e}") + logger.exception("Could not generate file path") return out @staticmethod def validate_string(test_string: str) -> bool: if not test_string: return False - result = any([f'{{{key}}}' in test_string.lower() for key in FileNameFormatter.key_terms]) + result = any([f"{{{key}}}" in test_string.lower() for key in FileNameFormatter.key_terms]) if result: - if 'POSTID' not in test_string: - logger.warning('Some files might not be downloaded due to name conflicts as filenames are' - ' not guaranteed to be be unique without {POSTID}') + if "POSTID" not in test_string: + logger.warning( + "Some files might not be downloaded due to name conflicts as filenames are" + " not guaranteed to be be unique without {POSTID}" + ) return True else: return False @@ -196,11 +201,11 @@ def validate_string(test_string: str) -> bool: def _format_for_windows(input_string: str) -> str: invalid_characters = r'<>:"\/|?*' for char in invalid_characters: - input_string = input_string.replace(char, '') + input_string = input_string.replace(char, "") input_string = FileNameFormatter._strip_emojis(input_string) return input_string @staticmethod def _strip_emojis(input_string: str) -> str: - result = input_string.encode('ascii', errors='ignore').decode('utf-8') + result = input_string.encode("ascii", errors="ignore").decode("utf-8") return result diff --git a/bdfr/oauth2.py b/bdfr/oauth2.py index bd60c9b8..60f21692 100644 --- a/bdfr/oauth2.py +++ b/bdfr/oauth2.py @@ -17,7 +17,6 @@ class OAuth2Authenticator: - def __init__(self, wanted_scopes: set[str], client_id: str, client_secret: str): self._check_scopes(wanted_scopes) self.scopes = wanted_scopes @@ -26,39 +25,41 @@ def __init__(self, wanted_scopes: set[str], client_id: str, client_secret: str): @staticmethod def _check_scopes(wanted_scopes: set[str]): - response = requests.get('https://www.reddit.com/api/v1/scopes.json', - headers={'User-Agent': 'fetch-scopes test'}) + response = requests.get( + "https://www.reddit.com/api/v1/scopes.json", headers={"User-Agent": "fetch-scopes test"} + ) known_scopes = [scope for scope, data in response.json().items()] - known_scopes.append('*') + known_scopes.append("*") for scope in wanted_scopes: if scope not in known_scopes: - raise BulkDownloaderException(f'Scope {scope} is not known to reddit') + raise BulkDownloaderException(f"Scope {scope} is not known to reddit") @staticmethod def split_scopes(scopes: str) -> set[str]: - scopes = re.split(r'[,: ]+', scopes) + scopes = re.split(r"[,: ]+", scopes) return set(scopes) def retrieve_new_token(self) -> str: reddit = praw.Reddit( - redirect_uri='http://localhost:7634', - user_agent='obtain_refresh_token for BDFR', + redirect_uri="http://localhost:7634", + user_agent="obtain_refresh_token for BDFR", client_id=self.client_id, - client_secret=self.client_secret) + client_secret=self.client_secret, + ) state = str(random.randint(0, 65000)) - url = reddit.auth.url(self.scopes, state, 'permanent') - logger.warning('Authentication action required before the program can proceed') - logger.warning(f'Authenticate at {url}') + url = reddit.auth.url(self.scopes, state, "permanent") + logger.warning("Authentication action required before the program can proceed") + logger.warning(f"Authenticate at {url}") client = self.receive_connection() - data = client.recv(1024).decode('utf-8') - param_tokens = data.split(' ', 2)[1].split('?', 1)[1].split('&') - params = {key: value for (key, value) in [token.split('=') for token in param_tokens]} + data = client.recv(1024).decode("utf-8") + param_tokens = data.split(" ", 2)[1].split("?", 1)[1].split("&") + params = {key: value for (key, value) in [token.split("=") for token in param_tokens]} - if state != params['state']: + if state != params["state"]: self.send_message(client) raise RedditAuthenticationError(f'State mismatch in OAuth2. Expected: {state} Received: {params["state"]}') - elif 'error' in params: + elif "error" in params: self.send_message(client) raise RedditAuthenticationError(f'Error in OAuth2: {params["error"]}') @@ -70,19 +71,19 @@ def retrieve_new_token(self) -> str: def receive_connection() -> socket.socket: server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server.bind(('0.0.0.0', 7634)) - logger.log(9, 'Server listening on 0.0.0.0:7634') + server.bind(("0.0.0.0", 7634)) + logger.log(9, "Server listening on 0.0.0.0:7634") server.listen(1) client = server.accept()[0] server.close() - logger.log(9, 'Server closed') + logger.log(9, "Server closed") return client @staticmethod - def send_message(client: socket.socket, message: str = ''): - client.send(f'HTTP/1.1 200 OK\r\n\r\n{message}'.encode('utf-8')) + def send_message(client: socket.socket, message: str = ""): + client.send(f"HTTP/1.1 200 OK\r\n\r\n{message}".encode("utf-8")) client.close() @@ -94,14 +95,14 @@ def __init__(self, config: configparser.ConfigParser, config_location: Path): def pre_refresh_callback(self, authorizer: praw.reddit.Authorizer): if authorizer.refresh_token is None: - if self.config.has_option('DEFAULT', 'user_token'): - authorizer.refresh_token = self.config.get('DEFAULT', 'user_token') - logger.log(9, 'Loaded OAuth2 token for authoriser') + if self.config.has_option("DEFAULT", "user_token"): + authorizer.refresh_token = self.config.get("DEFAULT", "user_token") + logger.log(9, "Loaded OAuth2 token for authoriser") else: - raise RedditAuthenticationError('No auth token loaded in configuration') + raise RedditAuthenticationError("No auth token loaded in configuration") def post_refresh_callback(self, authorizer: praw.reddit.Authorizer): - self.config.set('DEFAULT', 'user_token', authorizer.refresh_token) - with open(self.config_location, 'w') as file: + self.config.set("DEFAULT", "user_token", authorizer.refresh_token) + with open(self.config_location, "w") as file: self.config.write(file, True) - logger.log(9, f'Written OAuth2 token from authoriser to {self.config_location}') + logger.log(9, f"Written OAuth2 token from authoriser to {self.config_location}") diff --git a/bdfr/resource.py b/bdfr/resource.py index 68a42e1f..0f5404c5 100644 --- a/bdfr/resource.py +++ b/bdfr/resource.py @@ -39,7 +39,7 @@ def download(self, download_parameters: Optional[dict] = None): try: content = self.download_function(download_parameters) except requests.exceptions.ConnectionError as e: - raise BulkDownloaderException(f'Could not download resource: {e}') + raise BulkDownloaderException(f"Could not download resource: {e}") except BulkDownloaderException: raise if content: @@ -51,7 +51,7 @@ def create_hash(self): self.hash = hashlib.md5(self.content) def _determine_extension(self) -> Optional[str]: - extension_pattern = re.compile(r'.*(\..{3,5})$') + extension_pattern = re.compile(r".*(\..{3,5})$") stripped_url = urllib.parse.urlsplit(self.url).path match = re.search(extension_pattern, stripped_url) if match: @@ -59,27 +59,28 @@ def _determine_extension(self) -> Optional[str]: @staticmethod def http_download(url: str, download_parameters: dict) -> Optional[bytes]: - headers = download_parameters.get('headers') + headers = download_parameters.get("headers") current_wait_time = 60 - if 'max_wait_time' in download_parameters: - max_wait_time = download_parameters['max_wait_time'] + if "max_wait_time" in download_parameters: + max_wait_time = download_parameters["max_wait_time"] else: max_wait_time = 300 while True: try: response = requests.get(url, headers=headers) - if re.match(r'^2\d{2}', str(response.status_code)) and response.content: + if re.match(r"^2\d{2}", str(response.status_code)) and response.content: return response.content elif response.status_code in (408, 429): - raise requests.exceptions.ConnectionError(f'Response code {response.status_code}') + raise requests.exceptions.ConnectionError(f"Response code {response.status_code}") else: raise BulkDownloaderException( - f'Unrecoverable error requesting resource: HTTP Code {response.status_code}') + f"Unrecoverable error requesting resource: HTTP Code {response.status_code}" + ) except (requests.exceptions.ConnectionError, requests.exceptions.ChunkedEncodingError) as e: - logger.warning(f'Error occured downloading from {url}, waiting {current_wait_time} seconds: {e}') + logger.warning(f"Error occured downloading from {url}, waiting {current_wait_time} seconds: {e}") time.sleep(current_wait_time) if current_wait_time < max_wait_time: current_wait_time += 60 else: - logger.error(f'Max wait time exceeded for resource at url {url}') + logger.error(f"Max wait time exceeded for resource at url {url}") raise diff --git a/bdfr/site_downloaders/base_downloader.py b/bdfr/site_downloaders/base_downloader.py index 10787b81..f3ecec5f 100644 --- a/bdfr/site_downloaders/base_downloader.py +++ b/bdfr/site_downloaders/base_downloader.py @@ -31,7 +31,7 @@ def retrieve_url(url: str, cookies: dict = None, headers: dict = None) -> reques res = requests.get(url, cookies=cookies, headers=headers) except requests.exceptions.RequestException as e: logger.exception(e) - raise SiteDownloaderError(f'Failed to get page {url}') + raise SiteDownloaderError(f"Failed to get page {url}") if res.status_code != 200: - raise ResourceNotFound(f'Server responded with {res.status_code} to {url}') + raise ResourceNotFound(f"Server responded with {res.status_code} to {url}") return res diff --git a/bdfr/site_downloaders/delay_for_reddit.py b/bdfr/site_downloaders/delay_for_reddit.py index 149e403e..33807316 100644 --- a/bdfr/site_downloaders/delay_for_reddit.py +++ b/bdfr/site_downloaders/delay_for_reddit.py @@ -5,8 +5,8 @@ from praw.models import Submission -from bdfr.site_authenticator import SiteAuthenticator from bdfr.resource import Resource +from bdfr.site_authenticator import SiteAuthenticator from bdfr.site_downloaders.base_downloader import BaseDownloader logger = logging.getLogger(__name__) diff --git a/bdfr/site_downloaders/direct.py b/bdfr/site_downloaders/direct.py index 833acae4..4a6ac92e 100644 --- a/bdfr/site_downloaders/direct.py +++ b/bdfr/site_downloaders/direct.py @@ -4,8 +4,8 @@ from praw.models import Submission -from bdfr.site_authenticator import SiteAuthenticator from bdfr.resource import Resource +from bdfr.site_authenticator import SiteAuthenticator from bdfr.site_downloaders.base_downloader import BaseDownloader diff --git a/bdfr/site_downloaders/download_factory.py b/bdfr/site_downloaders/download_factory.py index 75beeae3..638316fb 100644 --- a/bdfr/site_downloaders/download_factory.py +++ b/bdfr/site_downloaders/download_factory.py @@ -26,62 +26,63 @@ class DownloadFactory: @staticmethod def pull_lever(url: str) -> Type[BaseDownloader]: sanitised_url = DownloadFactory.sanitise_url(url) - if re.match(r'(i\.|m\.)?imgur', sanitised_url): + if re.match(r"(i\.|m\.)?imgur", sanitised_url): return Imgur - elif re.match(r'(i\.)?(redgifs|gifdeliverynetwork)', sanitised_url): + elif re.match(r"(i\.)?(redgifs|gifdeliverynetwork)", sanitised_url): return Redgifs - elif re.match(r'.*/.*\.\w{3,4}(\?[\w;&=]*)?$', sanitised_url) and \ - not DownloadFactory.is_web_resource(sanitised_url): + elif re.match(r".*/.*\.\w{3,4}(\?[\w;&=]*)?$", sanitised_url) and not DownloadFactory.is_web_resource( + sanitised_url + ): return Direct - elif re.match(r'erome\.com.*', sanitised_url): + elif re.match(r"erome\.com.*", sanitised_url): return Erome - elif re.match(r'delayforreddit\.com', sanitised_url): + elif re.match(r"delayforreddit\.com", sanitised_url): return DelayForReddit - elif re.match(r'reddit\.com/gallery/.*', sanitised_url): + elif re.match(r"reddit\.com/gallery/.*", sanitised_url): return Gallery - elif re.match(r'patreon\.com.*', sanitised_url): + elif re.match(r"patreon\.com.*", sanitised_url): return Gallery - elif re.match(r'gfycat\.', sanitised_url): + elif re.match(r"gfycat\.", sanitised_url): return Gfycat - elif re.match(r'reddit\.com/r/', sanitised_url): + elif re.match(r"reddit\.com/r/", sanitised_url): return SelfPost - elif re.match(r'(m\.)?youtu\.?be', sanitised_url): + elif re.match(r"(m\.)?youtu\.?be", sanitised_url): return Youtube - elif re.match(r'i\.redd\.it.*', sanitised_url): + elif re.match(r"i\.redd\.it.*", sanitised_url): return Direct - elif re.match(r'v\.redd\.it.*', sanitised_url): + elif re.match(r"v\.redd\.it.*", sanitised_url): return VReddit - elif re.match(r'pornhub\.com.*', sanitised_url): + elif re.match(r"pornhub\.com.*", sanitised_url): return PornHub - elif re.match(r'vidble\.com', sanitised_url): + elif re.match(r"vidble\.com", sanitised_url): return Vidble elif YtdlpFallback.can_handle_link(sanitised_url): return YtdlpFallback else: - raise NotADownloadableLinkError(f'No downloader module exists for url {url}') + raise NotADownloadableLinkError(f"No downloader module exists for url {url}") @staticmethod def sanitise_url(url: str) -> str: - beginning_regex = re.compile(r'\s*(www\.?)?') + beginning_regex = re.compile(r"\s*(www\.?)?") split_url = urllib.parse.urlsplit(url) split_url = split_url.netloc + split_url.path - split_url = re.sub(beginning_regex, '', split_url) + split_url = re.sub(beginning_regex, "", split_url) return split_url @staticmethod def is_web_resource(url: str) -> bool: web_extensions = ( - 'asp', - 'aspx', - 'cfm', - 'cfml', - 'css', - 'htm', - 'html', - 'js', - 'php', - 'php3', - 'xhtml', + "asp", + "aspx", + "cfm", + "cfml", + "css", + "htm", + "html", + "js", + "php", + "php3", + "xhtml", ) if re.match(rf'(?i).*/.*\.({"|".join(web_extensions)})$', url): return True diff --git a/bdfr/site_downloaders/erome.py b/bdfr/site_downloaders/erome.py index 6250415f..26469bc4 100644 --- a/bdfr/site_downloaders/erome.py +++ b/bdfr/site_downloaders/erome.py @@ -23,34 +23,34 @@ def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> l links = self._get_links(self.post.url) if not links: - raise SiteDownloaderError('Erome parser could not find any links') + raise SiteDownloaderError("Erome parser could not find any links") out = [] for link in links: - if not re.match(r'https?://.*', link): - link = 'https://' + link + if not re.match(r"https?://.*", link): + link = "https://" + link out.append(Resource(self.post, link, self.erome_download(link))) return out @staticmethod def _get_links(url: str) -> set[str]: page = Erome.retrieve_url(url) - soup = bs4.BeautifulSoup(page.text, 'html.parser') - front_images = soup.find_all('img', attrs={'class': 'lasyload'}) - out = [im.get('data-src') for im in front_images] + soup = bs4.BeautifulSoup(page.text, "html.parser") + front_images = soup.find_all("img", attrs={"class": "lasyload"}) + out = [im.get("data-src") for im in front_images] - videos = soup.find_all('source') - out.extend([vid.get('src') for vid in videos]) + videos = soup.find_all("source") + out.extend([vid.get("src") for vid in videos]) return set(out) @staticmethod def erome_download(url: str) -> Callable: download_parameters = { - 'headers': { - 'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)' - ' Chrome/88.0.4324.104 Safari/537.36', - 'Referer': 'https://www.erome.com/', + "headers": { + "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)" + " Chrome/88.0.4324.104 Safari/537.36", + "Referer": "https://www.erome.com/", }, } return lambda global_params: Resource.http_download(url, global_params | download_parameters) diff --git a/bdfr/site_downloaders/fallback_downloaders/fallback_downloader.py b/bdfr/site_downloaders/fallback_downloaders/fallback_downloader.py index deeb213f..3bc615dd 100644 --- a/bdfr/site_downloaders/fallback_downloaders/fallback_downloader.py +++ b/bdfr/site_downloaders/fallback_downloaders/fallback_downloader.py @@ -7,7 +7,6 @@ class BaseFallbackDownloader(BaseDownloader, ABC): - @staticmethod @abstractmethod def can_handle_link(url: str) -> bool: diff --git a/bdfr/site_downloaders/fallback_downloaders/ytdlp_fallback.py b/bdfr/site_downloaders/fallback_downloaders/ytdlp_fallback.py index 1225624d..6109b7ac 100644 --- a/bdfr/site_downloaders/fallback_downloaders/ytdlp_fallback.py +++ b/bdfr/site_downloaders/fallback_downloaders/ytdlp_fallback.py @@ -9,7 +9,9 @@ from bdfr.exceptions import NotADownloadableLinkError from bdfr.resource import Resource from bdfr.site_authenticator import SiteAuthenticator -from bdfr.site_downloaders.fallback_downloaders.fallback_downloader import BaseFallbackDownloader +from bdfr.site_downloaders.fallback_downloaders.fallback_downloader import ( + BaseFallbackDownloader, +) from bdfr.site_downloaders.youtube import Youtube logger = logging.getLogger(__name__) @@ -24,7 +26,7 @@ def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> l self.post, self.post.url, super()._download_video({}), - super().get_video_attributes(self.post.url)['ext'], + super().get_video_attributes(self.post.url)["ext"], ) return [out] diff --git a/bdfr/site_downloaders/gallery.py b/bdfr/site_downloaders/gallery.py index eeb9e0f8..278932f0 100644 --- a/bdfr/site_downloaders/gallery.py +++ b/bdfr/site_downloaders/gallery.py @@ -20,27 +20,27 @@ def __init__(self, post: Submission): def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]: try: - image_urls = self._get_links(self.post.gallery_data['items']) + image_urls = self._get_links(self.post.gallery_data["items"]) except (AttributeError, TypeError): try: - image_urls = self._get_links(self.post.crosspost_parent_list[0]['gallery_data']['items']) + image_urls = self._get_links(self.post.crosspost_parent_list[0]["gallery_data"]["items"]) except (AttributeError, IndexError, TypeError, KeyError): - logger.error(f'Could not find gallery data in submission {self.post.id}') - logger.exception('Gallery image find failure') - raise SiteDownloaderError('No images found in Reddit gallery') + logger.error(f"Could not find gallery data in submission {self.post.id}") + logger.exception("Gallery image find failure") + raise SiteDownloaderError("No images found in Reddit gallery") if not image_urls: - raise SiteDownloaderError('No images found in Reddit gallery') + raise SiteDownloaderError("No images found in Reddit gallery") return [Resource(self.post, url, Resource.retry_download(url)) for url in image_urls] - @ staticmethod + @staticmethod def _get_links(id_dict: list[dict]) -> list[str]: out = [] for item in id_dict: - image_id = item['media_id'] - possible_extensions = ('.jpg', '.png', '.gif', '.gifv', '.jpeg') + image_id = item["media_id"] + possible_extensions = (".jpg", ".png", ".gif", ".gifv", ".jpeg") for extension in possible_extensions: - test_url = f'https://i.redd.it/{image_id}{extension}' + test_url = f"https://i.redd.it/{image_id}{extension}" response = requests.head(test_url) if response.status_code == 200: out.append(test_url) diff --git a/bdfr/site_downloaders/gfycat.py b/bdfr/site_downloaders/gfycat.py index c8da9df4..7862d338 100644 --- a/bdfr/site_downloaders/gfycat.py +++ b/bdfr/site_downloaders/gfycat.py @@ -22,21 +22,23 @@ def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> l @staticmethod def _get_link(url: str) -> set[str]: - gfycat_id = re.match(r'.*/(.*?)/?$', url).group(1) - url = 'https://gfycat.com/' + gfycat_id + gfycat_id = re.match(r".*/(.*?)/?$", url).group(1) + url = "https://gfycat.com/" + gfycat_id response = Gfycat.retrieve_url(url) - if re.search(r'(redgifs|gifdeliverynetwork)', response.url): + if re.search(r"(redgifs|gifdeliverynetwork)", response.url): url = url.lower() # Fixes error with old gfycat/redgifs links return Redgifs._get_link(url) - soup = BeautifulSoup(response.text, 'html.parser') - content = soup.find('script', attrs={'data-react-helmet': 'true', 'type': 'application/ld+json'}) + soup = BeautifulSoup(response.text, "html.parser") + content = soup.find("script", attrs={"data-react-helmet": "true", "type": "application/ld+json"}) try: - out = json.loads(content.contents[0])['video']['contentUrl'] + out = json.loads(content.contents[0])["video"]["contentUrl"] except (IndexError, KeyError, AttributeError) as e: - raise SiteDownloaderError(f'Failed to download Gfycat link {url}: {e}') + raise SiteDownloaderError(f"Failed to download Gfycat link {url}: {e}") except json.JSONDecodeError as e: - raise SiteDownloaderError(f'Did not receive valid JSON data: {e}') - return {out,} + raise SiteDownloaderError(f"Did not receive valid JSON data: {e}") + return { + out, + } diff --git a/bdfr/site_downloaders/imgur.py b/bdfr/site_downloaders/imgur.py index 0688b105..f91e34f2 100644 --- a/bdfr/site_downloaders/imgur.py +++ b/bdfr/site_downloaders/imgur.py @@ -14,7 +14,6 @@ class Imgur(BaseDownloader): - def __init__(self, post: Submission): super().__init__(post) self.raw_data = {} @@ -23,63 +22,63 @@ def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> l self.raw_data = self._get_data(self.post.url) out = [] - if 'album_images' in self.raw_data: - images = self.raw_data['album_images'] - for image in images['images']: + if "album_images" in self.raw_data: + images = self.raw_data["album_images"] + for image in images["images"]: out.append(self._compute_image_url(image)) else: out.append(self._compute_image_url(self.raw_data)) return out def _compute_image_url(self, image: dict) -> Resource: - ext = self._validate_extension(image['ext']) - if image.get('prefer_video', False): - ext = '.mp4' + ext = self._validate_extension(image["ext"]) + if image.get("prefer_video", False): + ext = ".mp4" - image_url = 'https://i.imgur.com/' + image['hash'] + ext + image_url = "https://i.imgur.com/" + image["hash"] + ext return Resource(self.post, image_url, Resource.retry_download(image_url)) @staticmethod def _get_data(link: str) -> dict: try: - imgur_id = re.match(r'.*/(.*?)(\..{0,})?$', link).group(1) - gallery = 'a/' if re.search(r'.*/(.*?)(gallery/|a/)', link) else '' - link = f'https://imgur.com/{gallery}{imgur_id}' + imgur_id = re.match(r".*/(.*?)(\..{0,})?$", link).group(1) + gallery = "a/" if re.search(r".*/(.*?)(gallery/|a/)", link) else "" + link = f"https://imgur.com/{gallery}{imgur_id}" except AttributeError: - raise SiteDownloaderError(f'Could not extract Imgur ID from {link}') + raise SiteDownloaderError(f"Could not extract Imgur ID from {link}") - res = Imgur.retrieve_url(link, cookies={'over18': '1', 'postpagebeta': '0'}) + res = Imgur.retrieve_url(link, cookies={"over18": "1", "postpagebeta": "0"}) - soup = bs4.BeautifulSoup(res.text, 'html.parser') - scripts = soup.find_all('script', attrs={'type': 'text/javascript'}) - scripts = [script.string.replace('\n', '') for script in scripts if script.string] + soup = bs4.BeautifulSoup(res.text, "html.parser") + scripts = soup.find_all("script", attrs={"type": "text/javascript"}) + scripts = [script.string.replace("\n", "") for script in scripts if script.string] - script_regex = re.compile(r'\s*\(function\(widgetFactory\)\s*{\s*widgetFactory\.mergeConfig\(\'gallery\'') + script_regex = re.compile(r"\s*\(function\(widgetFactory\)\s*{\s*widgetFactory\.mergeConfig\(\'gallery\'") chosen_script = list(filter(lambda s: re.search(script_regex, s), scripts)) if len(chosen_script) != 1: - raise SiteDownloaderError(f'Could not read page source from {link}') + raise SiteDownloaderError(f"Could not read page source from {link}") chosen_script = chosen_script[0] - outer_regex = re.compile(r'widgetFactory\.mergeConfig\(\'gallery\', ({.*})\);') - inner_regex = re.compile(r'image\s*:(.*),\s*group') + outer_regex = re.compile(r"widgetFactory\.mergeConfig\(\'gallery\', ({.*})\);") + inner_regex = re.compile(r"image\s*:(.*),\s*group") try: image_dict = re.search(outer_regex, chosen_script).group(1) image_dict = re.search(inner_regex, image_dict).group(1) except AttributeError: - raise SiteDownloaderError(f'Could not find image dictionary in page source') + raise SiteDownloaderError(f"Could not find image dictionary in page source") try: image_dict = json.loads(image_dict) except json.JSONDecodeError as e: - raise SiteDownloaderError(f'Could not parse received dict as JSON: {e}') + raise SiteDownloaderError(f"Could not parse received dict as JSON: {e}") return image_dict @staticmethod def _validate_extension(extension_suffix: str) -> str: - extension_suffix = re.sub(r'\?.*', '', extension_suffix) - possible_extensions = ('.jpg', '.png', '.mp4', '.gif') + extension_suffix = re.sub(r"\?.*", "", extension_suffix) + possible_extensions = (".jpg", ".png", ".mp4", ".gif") selection = [ext for ext in possible_extensions if ext == extension_suffix] if len(selection) == 1: return selection[0] diff --git a/bdfr/site_downloaders/pornhub.py b/bdfr/site_downloaders/pornhub.py index 748454ee..db377207 100644 --- a/bdfr/site_downloaders/pornhub.py +++ b/bdfr/site_downloaders/pornhub.py @@ -20,11 +20,11 @@ def __init__(self, post: Submission): def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]: ytdl_options = { - 'format': 'best', - 'nooverwrites': True, + "format": "best", + "nooverwrites": True, } if video_attributes := super().get_video_attributes(self.post.url): - extension = video_attributes['ext'] + extension = video_attributes["ext"] else: raise SiteDownloaderError() diff --git a/bdfr/site_downloaders/redgifs.py b/bdfr/site_downloaders/redgifs.py index dd194139..625cf7d3 100644 --- a/bdfr/site_downloaders/redgifs.py +++ b/bdfr/site_downloaders/redgifs.py @@ -2,9 +2,9 @@ import json import re -import requests from typing import Optional +import requests from praw.models import Submission from bdfr.exceptions import SiteDownloaderError @@ -24,52 +24,53 @@ def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> l @staticmethod def _get_link(url: str) -> set[str]: try: - redgif_id = re.match(r'.*/(.*?)(\..{0,})?$', url).group(1) + redgif_id = re.match(r".*/(.*?)(\..{0,})?$", url).group(1) except AttributeError: - raise SiteDownloaderError(f'Could not extract Redgifs ID from {url}') + raise SiteDownloaderError(f"Could not extract Redgifs ID from {url}") - auth_token = json.loads(Redgifs.retrieve_url('https://api.redgifs.com/v2/auth/temporary').text)['token'] + auth_token = json.loads(Redgifs.retrieve_url("https://api.redgifs.com/v2/auth/temporary").text)["token"] if not auth_token: - raise SiteDownloaderError('Unable to retrieve Redgifs API token') + raise SiteDownloaderError("Unable to retrieve Redgifs API token") headers = { - 'referer': 'https://www.redgifs.com/', - 'origin': 'https://www.redgifs.com', - 'content-type': 'application/json', - 'Authorization': f'Bearer {auth_token}', + "referer": "https://www.redgifs.com/", + "origin": "https://www.redgifs.com", + "content-type": "application/json", + "Authorization": f"Bearer {auth_token}", } - content = Redgifs.retrieve_url(f'https://api.redgifs.com/v2/gifs/{redgif_id}', headers=headers) + content = Redgifs.retrieve_url(f"https://api.redgifs.com/v2/gifs/{redgif_id}", headers=headers) if content is None: - raise SiteDownloaderError('Could not read the page source') + raise SiteDownloaderError("Could not read the page source") try: response_json = json.loads(content.text) except json.JSONDecodeError as e: - raise SiteDownloaderError(f'Received data was not valid JSON: {e}') + raise SiteDownloaderError(f"Received data was not valid JSON: {e}") out = set() try: - if response_json['gif']['type'] == 1: # type 1 is a video - if requests.get(response_json['gif']['urls']['hd'], headers=headers).ok: - out.add(response_json['gif']['urls']['hd']) + if response_json["gif"]["type"] == 1: # type 1 is a video + if requests.get(response_json["gif"]["urls"]["hd"], headers=headers).ok: + out.add(response_json["gif"]["urls"]["hd"]) else: - out.add(response_json['gif']['urls']['sd']) - elif response_json['gif']['type'] == 2: # type 2 is an image - if response_json['gif']['gallery']: + out.add(response_json["gif"]["urls"]["sd"]) + elif response_json["gif"]["type"] == 2: # type 2 is an image + if response_json["gif"]["gallery"]: content = Redgifs.retrieve_url( - f'https://api.redgifs.com/v2/gallery/{response_json["gif"]["gallery"]}') + f'https://api.redgifs.com/v2/gallery/{response_json["gif"]["gallery"]}' + ) response_json = json.loads(content.text) - out = {p['urls']['hd'] for p in response_json['gifs']} + out = {p["urls"]["hd"] for p in response_json["gifs"]} else: - out.add(response_json['gif']['urls']['hd']) + out.add(response_json["gif"]["urls"]["hd"]) else: raise KeyError except (KeyError, AttributeError): - raise SiteDownloaderError('Failed to find JSON data in page') + raise SiteDownloaderError("Failed to find JSON data in page") # Update subdomain if old one is returned - out = {re.sub('thumbs2', 'thumbs3', link) for link in out} - out = {re.sub('thumbs3', 'thumbs4', link) for link in out} + out = {re.sub("thumbs2", "thumbs3", link) for link in out} + out = {re.sub("thumbs3", "thumbs4", link) for link in out} return out diff --git a/bdfr/site_downloaders/self_post.py b/bdfr/site_downloaders/self_post.py index 6e4ce0e3..1b76b922 100644 --- a/bdfr/site_downloaders/self_post.py +++ b/bdfr/site_downloaders/self_post.py @@ -17,27 +17,29 @@ def __init__(self, post: Submission): super().__init__(post) def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]: - out = Resource(self.post, self.post.url, lambda: None, '.txt') - out.content = self.export_to_string().encode('utf-8') + out = Resource(self.post, self.post.url, lambda: None, ".txt") + out.content = self.export_to_string().encode("utf-8") out.create_hash() return [out] def export_to_string(self) -> str: """Self posts are formatted here""" - content = ("## [" - + self.post.fullname - + "](" - + self.post.url - + ")\n" - + self.post.selftext - + "\n\n---\n\n" - + "submitted to [r/" - + self.post.subreddit.title - + "](https://www.reddit.com/r/" - + self.post.subreddit.title - + ") by [u/" - + (self.post.author.name if self.post.author else "DELETED") - + "](https://www.reddit.com/user/" - + (self.post.author.name if self.post.author else "DELETED") - + ")") + content = ( + "## [" + + self.post.fullname + + "](" + + self.post.url + + ")\n" + + self.post.selftext + + "\n\n---\n\n" + + "submitted to [r/" + + self.post.subreddit.title + + "](https://www.reddit.com/r/" + + self.post.subreddit.title + + ") by [u/" + + (self.post.author.name if self.post.author else "DELETED") + + "](https://www.reddit.com/user/" + + (self.post.author.name if self.post.author else "DELETED") + + ")" + ) return content diff --git a/bdfr/site_downloaders/vidble.py b/bdfr/site_downloaders/vidble.py index 5cea0cbf..a79ee25d 100644 --- a/bdfr/site_downloaders/vidble.py +++ b/bdfr/site_downloaders/vidble.py @@ -25,30 +25,30 @@ def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> l try: res = self.get_links(self.post.url) except AttributeError: - raise SiteDownloaderError(f'Could not read page at {self.post.url}') + raise SiteDownloaderError(f"Could not read page at {self.post.url}") if not res: - raise SiteDownloaderError(rf'No resources found at {self.post.url}') + raise SiteDownloaderError(rf"No resources found at {self.post.url}") res = [Resource(self.post, r, Resource.retry_download(r)) for r in res] return res @staticmethod def get_links(url: str) -> set[str]: - if not re.search(r'vidble.com/(show/|album/|watch\?v)', url): - url = re.sub(r'/(\w*?)$', r'/show/\1', url) + if not re.search(r"vidble.com/(show/|album/|watch\?v)", url): + url = re.sub(r"/(\w*?)$", r"/show/\1", url) page = requests.get(url) - soup = bs4.BeautifulSoup(page.text, 'html.parser') - content_div = soup.find('div', attrs={'id': 'ContentPlaceHolder1_divContent'}) - images = content_div.find_all('img') - images = [i.get('src') for i in images] - videos = content_div.find_all('source', attrs={'type': 'video/mp4'}) - videos = [v.get('src') for v in videos] + soup = bs4.BeautifulSoup(page.text, "html.parser") + content_div = soup.find("div", attrs={"id": "ContentPlaceHolder1_divContent"}) + images = content_div.find_all("img") + images = [i.get("src") for i in images] + videos = content_div.find_all("source", attrs={"type": "video/mp4"}) + videos = [v.get("src") for v in videos] resources = filter(None, itertools.chain(images, videos)) - resources = ['https://www.vidble.com' + r for r in resources] + resources = ["https://www.vidble.com" + r for r in resources] resources = [Vidble.change_med_url(r) for r in resources] return set(resources) @staticmethod def change_med_url(url: str) -> str: - out = re.sub(r'_med(\..{3,4})$', r'\1', url) + out = re.sub(r"_med(\..{3,4})$", r"\1", url) return out diff --git a/bdfr/site_downloaders/vreddit.py b/bdfr/site_downloaders/vreddit.py index ad526b40..a71d3504 100644 --- a/bdfr/site_downloaders/vreddit.py +++ b/bdfr/site_downloaders/vreddit.py @@ -22,18 +22,18 @@ def __init__(self, post: Submission): def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]: ytdl_options = { - 'playlistend': 1, - 'nooverwrites': True, + "playlistend": 1, + "nooverwrites": True, } download_function = self._download_video(ytdl_options) - extension = self.get_video_attributes(self.post.url)['ext'] + extension = self.get_video_attributes(self.post.url)["ext"] res = Resource(self.post, self.post.url, download_function, extension) return [res] @staticmethod def get_video_attributes(url: str) -> dict: result = VReddit.get_video_data(url) - if 'ext' in result: + if "ext" in result: return result else: try: @@ -41,4 +41,4 @@ def get_video_attributes(url: str) -> dict: return result except Exception as e: logger.exception(e) - raise NotADownloadableLinkError(f'Video info extraction failed for {url}') + raise NotADownloadableLinkError(f"Video info extraction failed for {url}") diff --git a/bdfr/site_downloaders/youtube.py b/bdfr/site_downloaders/youtube.py index 315fd0ad..f4f8622c 100644 --- a/bdfr/site_downloaders/youtube.py +++ b/bdfr/site_downloaders/youtube.py @@ -22,57 +22,62 @@ def __init__(self, post: Submission): def find_resources(self, authenticator: Optional[SiteAuthenticator] = None) -> list[Resource]: ytdl_options = { - 'format': 'best', - 'playlistend': 1, - 'nooverwrites': True, + "format": "best", + "playlistend": 1, + "nooverwrites": True, } download_function = self._download_video(ytdl_options) - extension = self.get_video_attributes(self.post.url)['ext'] + extension = self.get_video_attributes(self.post.url)["ext"] res = Resource(self.post, self.post.url, download_function, extension) return [res] def _download_video(self, ytdl_options: dict) -> Callable: - yt_logger = logging.getLogger('youtube-dl') + yt_logger = logging.getLogger("youtube-dl") yt_logger.setLevel(logging.CRITICAL) - ytdl_options['quiet'] = True - ytdl_options['logger'] = yt_logger + ytdl_options["quiet"] = True + ytdl_options["logger"] = yt_logger def download(_: dict) -> bytes: with tempfile.TemporaryDirectory() as temp_dir: download_path = Path(temp_dir).resolve() - ytdl_options['outtmpl'] = str(download_path) + '/' + 'test.%(ext)s' + ytdl_options["outtmpl"] = str(download_path) + "/" + "test.%(ext)s" try: with yt_dlp.YoutubeDL(ytdl_options) as ydl: ydl.download([self.post.url]) except yt_dlp.DownloadError as e: - raise SiteDownloaderError(f'Youtube download failed: {e}') + raise SiteDownloaderError(f"Youtube download failed: {e}") downloaded_files = list(download_path.iterdir()) if downloaded_files: downloaded_file = downloaded_files[0] else: raise NotADownloadableLinkError(f"No media exists in the URL {self.post.url}") - with downloaded_file.open('rb') as file: + with downloaded_file.open("rb") as file: content = file.read() return content + return download @staticmethod def get_video_data(url: str) -> dict: - yt_logger = logging.getLogger('youtube-dl') + yt_logger = logging.getLogger("youtube-dl") yt_logger.setLevel(logging.CRITICAL) - with yt_dlp.YoutubeDL({'logger': yt_logger, }) as ydl: + with yt_dlp.YoutubeDL( + { + "logger": yt_logger, + } + ) as ydl: try: result = ydl.extract_info(url, download=False) except Exception as e: logger.exception(e) - raise NotADownloadableLinkError(f'Video info extraction failed for {url}') + raise NotADownloadableLinkError(f"Video info extraction failed for {url}") return result @staticmethod def get_video_attributes(url: str) -> dict: result = Youtube.get_video_data(url) - if 'ext' in result: + if "ext" in result: return result else: - raise NotADownloadableLinkError(f'Video info extraction failed for {url}') + raise NotADownloadableLinkError(f"Video info extraction failed for {url}") diff --git a/dev-requirements.txt b/dev-requirements.txt new file mode 100644 index 00000000..af48d1b6 --- /dev/null +++ b/dev-requirements.txt @@ -0,0 +1,4 @@ +black +isort +pytest +tox diff --git a/dev_requirements.txt b/dev_requirements.txt deleted file mode 100644 index e079f8a6..00000000 --- a/dev_requirements.txt +++ /dev/null @@ -1 +0,0 @@ -pytest diff --git a/scripts/tests/bats b/scripts/tests/bats index e8c840b5..ce5ca280 160000 --- a/scripts/tests/bats +++ b/scripts/tests/bats @@ -1 +1 @@ -Subproject commit e8c840b58f0833e23461c682655fe540aa923f85 +Subproject commit ce5ca2802fabe5dc38393240cd40e20f8928d3b0 diff --git a/scripts/tests/test_helper/bats-assert b/scripts/tests/test_helper/bats-assert index 78fa631d..e0de84e9 160000 --- a/scripts/tests/test_helper/bats-assert +++ b/scripts/tests/test_helper/bats-assert @@ -1 +1 @@ -Subproject commit 78fa631d1370562d2cd4a1390989e706158e7bf0 +Subproject commit e0de84e9c011223e7f88b7ccf1c929f4327097ba diff --git a/tests/archive_entry/test_comment_archive_entry.py b/tests/archive_entry/test_comment_archive_entry.py index e453d276..8e6f2249 100644 --- a/tests/archive_entry/test_comment_archive_entry.py +++ b/tests/archive_entry/test_comment_archive_entry.py @@ -9,15 +9,21 @@ @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_comment_id', 'expected_dict'), ( - ('gstd4hk', { - 'author': 'james_pic', - 'subreddit': 'Python', - 'submission': 'mgi4op', - 'submission_title': '76% Faster CPython', - 'distinguished': None, - }), -)) +@pytest.mark.parametrize( + ("test_comment_id", "expected_dict"), + ( + ( + "gstd4hk", + { + "author": "james_pic", + "subreddit": "Python", + "submission": "mgi4op", + "submission_title": "76% Faster CPython", + "distinguished": None, + }, + ), + ), +) def test_get_comment_details(test_comment_id: str, expected_dict: dict, reddit_instance: praw.Reddit): comment = reddit_instance.comment(id=test_comment_id) test_entry = CommentArchiveEntry(comment) @@ -27,13 +33,16 @@ def test_get_comment_details(test_comment_id: str, expected_dict: dict, reddit_i @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_comment_id', 'expected_min_comments'), ( - ('gstd4hk', 4), - ('gsvyste', 3), - ('gsxnvvb', 5), -)) +@pytest.mark.parametrize( + ("test_comment_id", "expected_min_comments"), + ( + ("gstd4hk", 4), + ("gsvyste", 3), + ("gsxnvvb", 5), + ), +) def test_get_comment_replies(test_comment_id: str, expected_min_comments: int, reddit_instance: praw.Reddit): comment = reddit_instance.comment(id=test_comment_id) test_entry = CommentArchiveEntry(comment) result = test_entry.compile() - assert len(result.get('replies')) >= expected_min_comments + assert len(result.get("replies")) >= expected_min_comments diff --git a/tests/archive_entry/test_submission_archive_entry.py b/tests/archive_entry/test_submission_archive_entry.py index 045eabd6..666eec35 100644 --- a/tests/archive_entry/test_submission_archive_entry.py +++ b/tests/archive_entry/test_submission_archive_entry.py @@ -9,9 +9,7 @@ @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_submission_id', 'min_comments'), ( - ('m3reby', 27), -)) +@pytest.mark.parametrize(("test_submission_id", "min_comments"), (("m3reby", 27),)) def test_get_comments(test_submission_id: str, min_comments: int, reddit_instance: praw.Reddit): test_submission = reddit_instance.submission(id=test_submission_id) test_archive_entry = SubmissionArchiveEntry(test_submission) @@ -21,21 +19,27 @@ def test_get_comments(test_submission_id: str, min_comments: int, reddit_instanc @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_submission_id', 'expected_dict'), ( - ('m3reby', { - 'author': 'sinjen-tos', - 'id': 'm3reby', - 'link_flair_text': 'image', - 'pinned': False, - 'spoiler': False, - 'over_18': False, - 'locked': False, - 'distinguished': None, - 'created_utc': 1615583837, - 'permalink': '/r/australia/comments/m3reby/this_little_guy_fell_out_of_a_tree_and_in_front/' - }), - # TODO: add deleted user test case -)) +@pytest.mark.parametrize( + ("test_submission_id", "expected_dict"), + ( + ( + "m3reby", + { + "author": "sinjen-tos", + "id": "m3reby", + "link_flair_text": "image", + "pinned": False, + "spoiler": False, + "over_18": False, + "locked": False, + "distinguished": None, + "created_utc": 1615583837, + "permalink": "/r/australia/comments/m3reby/this_little_guy_fell_out_of_a_tree_and_in_front/", + }, + ), + # TODO: add deleted user test case + ), +) def test_get_post_details(test_submission_id: str, expected_dict: dict, reddit_instance: praw.Reddit): test_submission = reddit_instance.submission(id=test_submission_id) test_archive_entry = SubmissionArchiveEntry(test_submission) diff --git a/tests/conftest.py b/tests/conftest.py index a61d8d51..3f871a39 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,29 +11,29 @@ from bdfr.oauth2 import OAuth2TokenManager -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def reddit_instance(): rd = praw.Reddit( - client_id='U-6gk4ZCh3IeNQ', - client_secret='7CZHY6AmKweZME5s50SfDGylaPg', - user_agent='test', + client_id="U-6gk4ZCh3IeNQ", + client_secret="7CZHY6AmKweZME5s50SfDGylaPg", + user_agent="test", ) return rd -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def authenticated_reddit_instance(): - test_config_path = Path('./tests/test_config.cfg') + test_config_path = Path("./tests/test_config.cfg") if not test_config_path.exists(): - pytest.skip('Refresh token must be provided to authenticate with OAuth2') + pytest.skip("Refresh token must be provided to authenticate with OAuth2") cfg_parser = configparser.ConfigParser() cfg_parser.read(test_config_path) - if not cfg_parser.has_option('DEFAULT', 'user_token'): - pytest.skip('Refresh token must be provided to authenticate with OAuth2') + if not cfg_parser.has_option("DEFAULT", "user_token"): + pytest.skip("Refresh token must be provided to authenticate with OAuth2") token_manager = OAuth2TokenManager(cfg_parser, test_config_path) reddit_instance = praw.Reddit( - client_id=cfg_parser.get('DEFAULT', 'client_id'), - client_secret=cfg_parser.get('DEFAULT', 'client_secret'), + client_id=cfg_parser.get("DEFAULT", "client_id"), + client_secret=cfg_parser.get("DEFAULT", "client_secret"), user_agent=socket.gethostname(), token_manager=token_manager, ) diff --git a/tests/integration_tests/test_archive_integration.py b/tests/integration_tests/test_archive_integration.py index caf6fcb6..f10f37ca 100644 --- a/tests/integration_tests/test_archive_integration.py +++ b/tests/integration_tests/test_archive_integration.py @@ -10,67 +10,78 @@ from bdfr.__main__ import cli -does_test_config_exist = Path('./tests/test_config.cfg').exists() +does_test_config_exist = Path("./tests/test_config.cfg").exists() def copy_test_config(run_path: Path): - shutil.copy(Path('./tests/test_config.cfg'), Path(run_path, 'test_config.cfg')) + shutil.copy(Path("./tests/test_config.cfg"), Path(run_path, "test_config.cfg")) def create_basic_args_for_archive_runner(test_args: list[str], run_path: Path): copy_test_config(run_path) out = [ - 'archive', + "archive", str(run_path), - '-v', - '--config', str(Path(run_path, 'test_config.cfg')), - '--log', str(Path(run_path, 'test_log.txt')), + "-v", + "--config", + str(Path(run_path, "test_config.cfg")), + "--log", + str(Path(run_path, "test_log.txt")), ] + test_args return out @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['-l', 'gstd4hk'], - ['-l', 'm2601g', '-f', 'yaml'], - ['-l', 'n60t4c', '-f', 'xml'], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize( + "test_args", + ( + ["-l", "gstd4hk"], + ["-l", "m2601g", "-f", "yaml"], + ["-l", "n60t4c", "-f", "xml"], + ), +) def test_cli_archive_single(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_archive_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert re.search(r'Writing entry .*? to file in .*? format', result.output) + assert re.search(r"Writing entry .*? to file in .*? format", result.output) @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['--subreddit', 'Mindustry', '-L', 25], - ['--subreddit', 'Mindustry', '-L', 25, '--format', 'xml'], - ['--subreddit', 'Mindustry', '-L', 25, '--format', 'yaml'], - ['--subreddit', 'Mindustry', '-L', 25, '--sort', 'new'], - ['--subreddit', 'Mindustry', '-L', 25, '--time', 'day'], - ['--subreddit', 'Mindustry', '-L', 25, '--time', 'day', '--sort', 'new'], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize( + "test_args", + ( + ["--subreddit", "Mindustry", "-L", 25], + ["--subreddit", "Mindustry", "-L", 25, "--format", "xml"], + ["--subreddit", "Mindustry", "-L", 25, "--format", "yaml"], + ["--subreddit", "Mindustry", "-L", 25, "--sort", "new"], + ["--subreddit", "Mindustry", "-L", 25, "--time", "day"], + ["--subreddit", "Mindustry", "-L", 25, "--time", "day", "--sort", "new"], + ), +) def test_cli_archive_subreddit(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_archive_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert re.search(r'Writing entry .*? to file in .*? format', result.output) + assert re.search(r"Writing entry .*? to file in .*? format", result.output) @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['--user', 'me', '--authenticate', '--all-comments', '-L', '10'], - ['--user', 'me', '--user', 'djnish', '--authenticate', '--all-comments', '-L', '10'], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize( + "test_args", + ( + ["--user", "me", "--authenticate", "--all-comments", "-L", "10"], + ["--user", "me", "--user", "djnish", "--authenticate", "--all-comments", "-L", "10"], + ), +) def test_cli_archive_all_user_comments(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_archive_runner(test_args, tmp_path) @@ -80,89 +91,88 @@ def test_cli_archive_all_user_comments(test_args: list[str], tmp_path: Path): @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['--comment-context', '--link', 'gxqapql'], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize("test_args", (["--comment-context", "--link", "gxqapql"],)) def test_cli_archive_full_context(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_archive_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert 'Converting comment' in result.output + assert "Converting comment" in result.output @pytest.mark.online @pytest.mark.reddit @pytest.mark.slow -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['--subreddit', 'all', '-L', 100], - ['--subreddit', 'all', '-L', 100, '--sort', 'new'], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize( + "test_args", + ( + ["--subreddit", "all", "-L", 100], + ["--subreddit", "all", "-L", 100, "--sort", "new"], + ), +) def test_cli_archive_long(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_archive_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert re.search(r'Writing entry .*? to file in .*? format', result.output) + assert re.search(r"Writing entry .*? to file in .*? format", result.output) @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['--ignore-user', 'ArjanEgges', '-l', 'm3hxzd'], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize("test_args", (["--ignore-user", "ArjanEgges", "-l", "m3hxzd"],)) def test_cli_archive_ignore_user(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_archive_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert 'being an ignored user' in result.output - assert 'Attempting to archive submission' not in result.output + assert "being an ignored user" in result.output + assert "Attempting to archive submission" not in result.output @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['--file-scheme', '{TITLE}', '-l', 'suy011'], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize("test_args", (["--file-scheme", "{TITLE}", "-l", "suy011"],)) def test_cli_archive_file_format(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_archive_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert 'Attempting to archive submission' in result.output - assert re.search('format at /.+?/Judge says Trump and two adult', result.output) + assert "Attempting to archive submission" in result.output + assert re.search("format at /.+?/Judge says Trump and two adult", result.output) @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['-l', 'm2601g', '--exclude-id', 'm2601g'], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize("test_args", (["-l", "m2601g", "--exclude-id", "m2601g"],)) def test_cli_archive_links_exclusion(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_archive_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert 'in exclusion list' in result.output - assert 'Attempting to archive' not in result.output + assert "in exclusion list" in result.output + assert "Attempting to archive" not in result.output + @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['-l', 'ijy4ch'], # user deleted post - ['-l', 'kw4wjm'], # post from banned subreddit -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize( + "test_args", + ( + ["-l", "ijy4ch"], # user deleted post + ["-l", "kw4wjm"], # post from banned subreddit + ), +) def test_cli_archive_soft_fail(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_archive_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert 'failed to be archived due to a PRAW exception' in result.output - assert 'Attempting to archive' not in result.output + assert "failed to be archived due to a PRAW exception" in result.output + assert "Attempting to archive" not in result.output diff --git a/tests/integration_tests/test_clone_integration.py b/tests/integration_tests/test_clone_integration.py index 8046687b..e8dc0088 100644 --- a/tests/integration_tests/test_clone_integration.py +++ b/tests/integration_tests/test_clone_integration.py @@ -9,54 +9,62 @@ from bdfr.__main__ import cli -does_test_config_exist = Path('./tests/test_config.cfg').exists() +does_test_config_exist = Path("./tests/test_config.cfg").exists() def copy_test_config(run_path: Path): - shutil.copy(Path('./tests/test_config.cfg'), Path(run_path, 'test_config.cfg')) + shutil.copy(Path("./tests/test_config.cfg"), Path(run_path, "test_config.cfg")) def create_basic_args_for_cloner_runner(test_args: list[str], tmp_path: Path): copy_test_config(tmp_path) out = [ - 'clone', + "clone", str(tmp_path), - '-v', - '--config', str(Path(tmp_path, 'test_config.cfg')), - '--log', str(Path(tmp_path, 'test_log.txt')), + "-v", + "--config", + str(Path(tmp_path, "test_config.cfg")), + "--log", + str(Path(tmp_path, "test_log.txt")), ] + test_args return out @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['-l', '6l7778'], - ['-s', 'TrollXChromosomes/', '-L', 1], - ['-l', 'eiajjw'], - ['-l', 'xl0lhi'], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize( + "test_args", + ( + ["-l", "6l7778"], + ["-s", "TrollXChromosomes/", "-L", 1], + ["-l", "eiajjw"], + ["-l", "xl0lhi"], + ), +) def test_cli_scrape_general(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_cloner_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert 'Downloaded submission' in result.output - assert 'Record for entry item' in result.output + assert "Downloaded submission" in result.output + assert "Record for entry item" in result.output @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['-l', 'ijy4ch'], # user deleted post - ['-l', 'kw4wjm'], # post from banned subreddit -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize( + "test_args", + ( + ["-l", "ijy4ch"], # user deleted post + ["-l", "kw4wjm"], # post from banned subreddit + ), +) def test_cli_scrape_soft_fail(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_cloner_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert 'Downloaded submission' not in result.output - assert 'Record for entry item' not in result.output + assert "Downloaded submission" not in result.output + assert "Record for entry item" not in result.output diff --git a/tests/integration_tests/test_download_integration.py b/tests/integration_tests/test_download_integration.py index 83f972db..2ab38a06 100644 --- a/tests/integration_tests/test_download_integration.py +++ b/tests/integration_tests/test_download_integration.py @@ -9,97 +9,107 @@ from bdfr.__main__ import cli -does_test_config_exist = Path('./tests/test_config.cfg').exists() +does_test_config_exist = Path("./tests/test_config.cfg").exists() def copy_test_config(run_path: Path): - shutil.copy(Path('./tests/test_config.cfg'), Path(run_path, './test_config.cfg')) + shutil.copy(Path("./tests/test_config.cfg"), Path(run_path, "./test_config.cfg")) def create_basic_args_for_download_runner(test_args: list[str], run_path: Path): copy_test_config(run_path) out = [ - 'download', str(run_path), - '-v', - '--config', str(Path(run_path, './test_config.cfg')), - '--log', str(Path(run_path, 'test_log.txt')), + "download", + str(run_path), + "-v", + "--config", + str(Path(run_path, "./test_config.cfg")), + "--log", + str(Path(run_path, "test_log.txt")), ] + test_args return out @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['-s', 'Mindustry', '-L', 3], - ['-s', 'r/Mindustry', '-L', 3], - ['-s', 'r/mindustry', '-L', 3], - ['-s', 'mindustry', '-L', 3], - ['-s', 'https://www.reddit.com/r/TrollXChromosomes/', '-L', 3], - ['-s', 'r/TrollXChromosomes/', '-L', 3], - ['-s', 'TrollXChromosomes/', '-L', 3], - ['-s', 'trollxchromosomes', '-L', 3], - ['-s', 'trollxchromosomes,mindustry,python', '-L', 3], - ['-s', 'trollxchromosomes, mindustry, python', '-L', 3], - ['-s', 'trollxchromosomes', '-L', 3, '--time', 'day'], - ['-s', 'trollxchromosomes', '-L', 3, '--sort', 'new'], - ['-s', 'trollxchromosomes', '-L', 3, '--time', 'day', '--sort', 'new'], - ['-s', 'trollxchromosomes', '-L', 3, '--search', 'women'], - ['-s', 'trollxchromosomes', '-L', 3, '--time', 'day', '--search', 'women'], - ['-s', 'trollxchromosomes', '-L', 3, '--sort', 'new', '--search', 'women'], - ['-s', 'trollxchromosomes', '-L', 3, '--time', 'day', '--sort', 'new', '--search', 'women'], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize( + "test_args", + ( + ["-s", "Mindustry", "-L", 3], + ["-s", "r/Mindustry", "-L", 3], + ["-s", "r/mindustry", "-L", 3], + ["-s", "mindustry", "-L", 3], + ["-s", "https://www.reddit.com/r/TrollXChromosomes/", "-L", 3], + ["-s", "r/TrollXChromosomes/", "-L", 3], + ["-s", "TrollXChromosomes/", "-L", 3], + ["-s", "trollxchromosomes", "-L", 3], + ["-s", "trollxchromosomes,mindustry,python", "-L", 3], + ["-s", "trollxchromosomes, mindustry, python", "-L", 3], + ["-s", "trollxchromosomes", "-L", 3, "--time", "day"], + ["-s", "trollxchromosomes", "-L", 3, "--sort", "new"], + ["-s", "trollxchromosomes", "-L", 3, "--time", "day", "--sort", "new"], + ["-s", "trollxchromosomes", "-L", 3, "--search", "women"], + ["-s", "trollxchromosomes", "-L", 3, "--time", "day", "--search", "women"], + ["-s", "trollxchromosomes", "-L", 3, "--sort", "new", "--search", "women"], + ["-s", "trollxchromosomes", "-L", 3, "--time", "day", "--sort", "new", "--search", "women"], + ), +) def test_cli_download_subreddits(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_download_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert 'Added submissions from subreddit ' in result.output - assert 'Downloaded submission' in result.output + assert "Added submissions from subreddit " in result.output + assert "Downloaded submission" in result.output @pytest.mark.online @pytest.mark.reddit @pytest.mark.slow @pytest.mark.authenticated -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['-s', 'hentai', '-L', 10, '--search', 'red', '--authenticate'], - ['--authenticate', '--subscribed', '-L', 10], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize( + "test_args", + ( + ["-s", "hentai", "-L", 10, "--search", "red", "--authenticate"], + ["--authenticate", "--subscribed", "-L", 10], + ), +) def test_cli_download_search_subreddits_authenticated(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_download_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert 'Added submissions from subreddit ' in result.output - assert 'Downloaded submission' in result.output + assert "Added submissions from subreddit " in result.output + assert "Downloaded submission" in result.output @pytest.mark.online @pytest.mark.reddit @pytest.mark.authenticated -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['--subreddit', 'friends', '-L', 10, '--authenticate'], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize("test_args", (["--subreddit", "friends", "-L", 10, "--authenticate"],)) def test_cli_download_user_specific_subreddits(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_download_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert 'Added submissions from subreddit ' in result.output + assert "Added submissions from subreddit " in result.output @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['-l', '6l7778'], - ['-l', 'https://reddit.com/r/EmpireDidNothingWrong/comments/6l7778/technically_true/'], - ['-l', 'm3hxzd'], # Really long title used to overflow filename limit - ['-l', 'm5bqkf'], # Resource leading to a 404 -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize( + "test_args", + ( + ["-l", "6l7778"], + ["-l", "https://reddit.com/r/EmpireDidNothingWrong/comments/6l7778/technically_true/"], + ["-l", "m3hxzd"], # Really long title used to overflow filename limit + ["-l", "m5bqkf"], # Resource leading to a 404 + ), +) def test_cli_download_links(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_download_runner(test_args, tmp_path) @@ -109,64 +119,66 @@ def test_cli_download_links(test_args: list[str], tmp_path: Path): @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['--user', 'helen_darten', '-m', 'cuteanimalpics', '-L', 10], - ['--user', 'helen_darten', '-m', 'cuteanimalpics', '-L', 10, '--sort', 'rising'], - ['--user', 'helen_darten', '-m', 'cuteanimalpics', '-L', 10, '--time', 'week'], - ['--user', 'helen_darten', '-m', 'cuteanimalpics', '-L', 10, '--time', 'week', '--sort', 'rising'], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize( + "test_args", + ( + ["--user", "helen_darten", "-m", "cuteanimalpics", "-L", 10], + ["--user", "helen_darten", "-m", "cuteanimalpics", "-L", 10, "--sort", "rising"], + ["--user", "helen_darten", "-m", "cuteanimalpics", "-L", 10, "--time", "week"], + ["--user", "helen_darten", "-m", "cuteanimalpics", "-L", 10, "--time", "week", "--sort", "rising"], + ), +) def test_cli_download_multireddit(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_download_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert 'Added submissions from multireddit ' in result.output + assert "Added submissions from multireddit " in result.output @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['--user', 'helen_darten', '-m', 'xxyyzzqwerty', '-L', 10], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize("test_args", (["--user", "helen_darten", "-m", "xxyyzzqwerty", "-L", 10],)) def test_cli_download_multireddit_nonexistent(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_download_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert 'Failed to get submissions for multireddit' in result.output - assert 'received 404 HTTP response' in result.output + assert "Failed to get submissions for multireddit" in result.output + assert "received 404 HTTP response" in result.output @pytest.mark.online @pytest.mark.reddit @pytest.mark.authenticated -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['--user', 'djnish', '--submitted', '--user', 'FriesWithThat', '-L', 10], - ['--user', 'me', '--upvoted', '--authenticate', '-L', 10], - ['--user', 'me', '--saved', '--authenticate', '-L', 10], - ['--user', 'me', '--submitted', '--authenticate', '-L', 10], - ['--user', 'djnish', '--submitted', '-L', 10], - ['--user', 'djnish', '--submitted', '-L', 10, '--time', 'month'], - ['--user', 'djnish', '--submitted', '-L', 10, '--sort', 'controversial'], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize( + "test_args", + ( + ["--user", "djnish", "--submitted", "--user", "FriesWithThat", "-L", 10], + ["--user", "me", "--upvoted", "--authenticate", "-L", 10], + ["--user", "me", "--saved", "--authenticate", "-L", 10], + ["--user", "me", "--submitted", "--authenticate", "-L", 10], + ["--user", "djnish", "--submitted", "-L", 10], + ["--user", "djnish", "--submitted", "-L", 10, "--time", "month"], + ["--user", "djnish", "--submitted", "-L", 10, "--sort", "controversial"], + ), +) def test_cli_download_user_data_good(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_download_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert 'Downloaded submission ' in result.output + assert "Downloaded submission " in result.output @pytest.mark.online @pytest.mark.reddit @pytest.mark.authenticated -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['--user', 'me', '-L', 10, '--folder-scheme', ''], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize("test_args", (["--user", "me", "-L", 10, "--folder-scheme", ""],)) def test_cli_download_user_data_bad_me_unauthenticated(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_download_runner(test_args, tmp_path) @@ -177,42 +189,41 @@ def test_cli_download_user_data_bad_me_unauthenticated(test_args: list[str], tmp @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['--subreddit', 'python', '-L', 1, '--search-existing'], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize("test_args", (["--subreddit", "python", "-L", 1, "--search-existing"],)) def test_cli_download_search_existing(test_args: list[str], tmp_path: Path): - Path(tmp_path, 'test.txt').touch() + Path(tmp_path, "test.txt").touch() runner = CliRunner() test_args = create_basic_args_for_download_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert 'Calculating hashes for' in result.output + assert "Calculating hashes for" in result.output @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['--subreddit', 'tumblr', '-L', '25', '--skip', 'png', '--skip', 'jpg'], - ['--subreddit', 'MaliciousCompliance', '-L', '25', '--skip', 'txt'], - ['--subreddit', 'tumblr', '-L', '10', '--skip-domain', 'i.redd.it'], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize( + "test_args", + ( + ["--subreddit", "tumblr", "-L", "25", "--skip", "png", "--skip", "jpg"], + ["--subreddit", "MaliciousCompliance", "-L", "25", "--skip", "txt"], + ["--subreddit", "tumblr", "-L", "10", "--skip-domain", "i.redd.it"], + ), +) def test_cli_download_download_filters(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_download_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert any((string in result.output for string in ('Download filter removed ', 'filtered due to URL'))) + assert any((string in result.output for string in ("Download filter removed ", "filtered due to URL"))) @pytest.mark.online @pytest.mark.reddit @pytest.mark.slow -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['--subreddit', 'all', '-L', '100', '--sort', 'new'], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize("test_args", (["--subreddit", "all", "-L", "100", "--sort", "new"],)) def test_cli_download_long(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_download_runner(test_args, tmp_path) @@ -223,34 +234,40 @@ def test_cli_download_long(test_args: list[str], tmp_path: Path): @pytest.mark.online @pytest.mark.reddit @pytest.mark.slow -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['--user', 'sdclhgsolgjeroij', '--submitted', '-L', 10], - ['--user', 'me', '--upvoted', '-L', 10], - ['--user', 'sdclhgsolgjeroij', '--upvoted', '-L', 10], - ['--subreddit', 'submitters', '-L', 10], # Private subreddit - ['--subreddit', 'donaldtrump', '-L', 10], # Banned subreddit - ['--user', 'djnish', '--user', 'helen_darten', '-m', 'cuteanimalpics', '-L', 10], - ['--subreddit', 'friends', '-L', 10], - ['-l', 'ijy4ch'], # user deleted post - ['-l', 'kw4wjm'], # post from banned subreddit -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize( + "test_args", + ( + ["--user", "sdclhgsolgjeroij", "--submitted", "-L", 10], + ["--user", "me", "--upvoted", "-L", 10], + ["--user", "sdclhgsolgjeroij", "--upvoted", "-L", 10], + ["--subreddit", "submitters", "-L", 10], # Private subreddit + ["--subreddit", "donaldtrump", "-L", 10], # Banned subreddit + ["--user", "djnish", "--user", "helen_darten", "-m", "cuteanimalpics", "-L", 10], + ["--subreddit", "friends", "-L", 10], + ["-l", "ijy4ch"], # user deleted post + ["-l", "kw4wjm"], # post from banned subreddit + ), +) def test_cli_download_soft_fail(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_download_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert 'Downloaded' not in result.output + assert "Downloaded" not in result.output @pytest.mark.online @pytest.mark.reddit @pytest.mark.slow -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['--time', 'random'], - ['--sort', 'random'], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize( + "test_args", + ( + ["--time", "random"], + ["--sort", "random"], + ), +) def test_cli_download_hard_fail(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_download_runner(test_args, tmp_path) @@ -260,114 +277,122 @@ def test_cli_download_hard_fail(test_args: list[str], tmp_path: Path): def test_cli_download_use_default_config(tmp_path: Path): runner = CliRunner() - test_args = ['download', '-vv', str(tmp_path)] + test_args = ["download", "-vv", str(tmp_path)] result = runner.invoke(cli, test_args) assert result.exit_code == 0 @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['-l', '6l7778', '--exclude-id', '6l7778'], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize("test_args", (["-l", "6l7778", "--exclude-id", "6l7778"],)) def test_cli_download_links_exclusion(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_download_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert 'in exclusion list' in result.output - assert 'Downloaded submission ' not in result.output + assert "in exclusion list" in result.output + assert "Downloaded submission " not in result.output @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['-l', '6l7778', '--skip-subreddit', 'EmpireDidNothingWrong'], - ['-s', 'trollxchromosomes', '--skip-subreddit', 'trollxchromosomes', '-L', '3'], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize( + "test_args", + ( + ["-l", "6l7778", "--skip-subreddit", "EmpireDidNothingWrong"], + ["-s", "trollxchromosomes", "--skip-subreddit", "trollxchromosomes", "-L", "3"], + ), +) def test_cli_download_subreddit_exclusion(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_download_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert 'in skip list' in result.output - assert 'Downloaded submission ' not in result.output + assert "in skip list" in result.output + assert "Downloaded submission " not in result.output @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['--file-scheme', '{TITLE}'], - ['--file-scheme', '{TITLE}_test_{SUBREDDIT}'], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize( + "test_args", + ( + ["--file-scheme", "{TITLE}"], + ["--file-scheme", "{TITLE}_test_{SUBREDDIT}"], + ), +) def test_cli_download_file_scheme_warning(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_download_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert 'Some files might not be downloaded due to name conflicts' in result.output + assert "Some files might not be downloaded due to name conflicts" in result.output @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['-l', 'n9w9fo', '--disable-module', 'SelfPost'], - ['-l', 'nnb9vs', '--disable-module', 'VReddit'], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize( + "test_args", + ( + ["-l", "n9w9fo", "--disable-module", "SelfPost"], + ["-l", "nnb9vs", "--disable-module", "VReddit"], + ), +) def test_cli_download_disable_modules(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_download_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert 'skipped due to disabled module' in result.output - assert 'Downloaded submission' not in result.output + assert "skipped due to disabled module" in result.output + assert "Downloaded submission" not in result.output @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") def test_cli_download_include_id_file(tmp_path: Path): - test_file = Path(tmp_path, 'include.txt') - test_args = ['--include-id-file', str(test_file)] - test_file.write_text('odr9wg\nody576') + test_file = Path(tmp_path, "include.txt") + test_args = ["--include-id-file", str(test_file)] + test_file.write_text("odr9wg\nody576") runner = CliRunner() test_args = create_basic_args_for_download_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert 'Downloaded submission' in result.output + assert "Downloaded submission" in result.output @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize('test_args', ( - ['--ignore-user', 'ArjanEgges', '-l', 'm3hxzd'], -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize("test_args", (["--ignore-user", "ArjanEgges", "-l", "m3hxzd"],)) def test_cli_download_ignore_user(test_args: list[str], tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_download_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert 'Downloaded submission' not in result.output - assert 'being an ignored user' in result.output + assert "Downloaded submission" not in result.output + assert "being an ignored user" in result.output @pytest.mark.online @pytest.mark.reddit -@pytest.mark.skipif(not does_test_config_exist, reason='A test config file is required for integration tests') -@pytest.mark.parametrize(('test_args', 'was_filtered'), ( - (['-l', 'ljyy27', '--min-score', '50'], True), - (['-l', 'ljyy27', '--min-score', '1'], False), - (['-l', 'ljyy27', '--max-score', '1'], True), - (['-l', 'ljyy27', '--max-score', '100'], False), -)) +@pytest.mark.skipif(not does_test_config_exist, reason="A test config file is required for integration tests") +@pytest.mark.parametrize( + ("test_args", "was_filtered"), + ( + (["-l", "ljyy27", "--min-score", "50"], True), + (["-l", "ljyy27", "--min-score", "1"], False), + (["-l", "ljyy27", "--max-score", "1"], True), + (["-l", "ljyy27", "--max-score", "100"], False), + ), +) def test_cli_download_score_filter(test_args: list[str], was_filtered: bool, tmp_path: Path): runner = CliRunner() test_args = create_basic_args_for_download_runner(test_args, tmp_path) result = runner.invoke(cli, test_args) assert result.exit_code == 0 - assert ('filtered due to score' in result.output) == was_filtered + assert ("filtered due to score" in result.output) == was_filtered diff --git a/tests/site_downloaders/fallback_downloaders/test_ytdlp_fallback.py b/tests/site_downloaders/fallback_downloaders/test_ytdlp_fallback.py index 29e72c54..9823d081 100644 --- a/tests/site_downloaders/fallback_downloaders/test_ytdlp_fallback.py +++ b/tests/site_downloaders/fallback_downloaders/test_ytdlp_fallback.py @@ -10,22 +10,23 @@ @pytest.mark.online -@pytest.mark.parametrize(('test_url', 'expected'), ( - ('https://www.reddit.com/r/specializedtools/comments/n2nw5m/bamboo_splitter/', True), - ('https://www.youtube.com/watch?v=P19nvJOmqCc', True), - ('https://www.example.com/test', False), - ('https://milesmatrix.bandcamp.com/album/la-boum/', False), - ('https://v.redd.it/dlr54z8p182a1', True), -)) +@pytest.mark.parametrize( + ("test_url", "expected"), + ( + ("https://www.reddit.com/r/specializedtools/comments/n2nw5m/bamboo_splitter/", True), + ("https://www.youtube.com/watch?v=P19nvJOmqCc", True), + ("https://www.example.com/test", False), + ("https://milesmatrix.bandcamp.com/album/la-boum/", False), + ("https://v.redd.it/dlr54z8p182a1", True), + ), +) def test_can_handle_link(test_url: str, expected: bool): result = YtdlpFallback.can_handle_link(test_url) assert result == expected @pytest.mark.online -@pytest.mark.parametrize('test_url', ( - 'https://milesmatrix.bandcamp.com/album/la-boum/', -)) +@pytest.mark.parametrize("test_url", ("https://milesmatrix.bandcamp.com/album/la-boum/",)) def test_info_extraction_bad(test_url: str): with pytest.raises(NotADownloadableLinkError): YtdlpFallback.get_video_attributes(test_url) @@ -33,12 +34,18 @@ def test_info_extraction_bad(test_url: str): @pytest.mark.online @pytest.mark.slow -@pytest.mark.parametrize(('test_url', 'expected_hash'), ( - ('https://streamable.com/dt46y', 'b7e465adaade5f2b6d8c2b4b7d0a2878'), - ('https://streamable.com/t8sem', '49b2d1220c485455548f1edbc05d4ecf'), - ('https://www.reddit.com/r/specializedtools/comments/n2nw5m/bamboo_splitter/', '6c6ff46e04b4e33a755ae2a9b5a45ac5'), - ('https://v.redd.it/9z1dnk3xr5k61', '226cee353421c7aefb05c92424cc8cdd'), -)) +@pytest.mark.parametrize( + ("test_url", "expected_hash"), + ( + ("https://streamable.com/dt46y", "b7e465adaade5f2b6d8c2b4b7d0a2878"), + ("https://streamable.com/t8sem", "49b2d1220c485455548f1edbc05d4ecf"), + ( + "https://www.reddit.com/r/specializedtools/comments/n2nw5m/bamboo_splitter/", + "6c6ff46e04b4e33a755ae2a9b5a45ac5", + ), + ("https://v.redd.it/9z1dnk3xr5k61", "226cee353421c7aefb05c92424cc8cdd"), + ), +) def test_find_resources(test_url: str, expected_hash: str): test_submission = MagicMock() test_submission.url = test_url diff --git a/tests/site_downloaders/test_delay_for_reddit.py b/tests/site_downloaders/test_delay_for_reddit.py index 5e0e1c8c..65d080c4 100644 --- a/tests/site_downloaders/test_delay_for_reddit.py +++ b/tests/site_downloaders/test_delay_for_reddit.py @@ -10,10 +10,13 @@ @pytest.mark.online -@pytest.mark.parametrize(('test_url', 'expected_hash'), ( - ('https://www.delayforreddit.com/dfr/calvin6123/MjU1Njc5NQ==', '3300f28c2f9358d05667985c9c04210d'), - ('https://www.delayforreddit.com/dfr/RoXs_26/NDAwMzAyOQ==', '09b7b01719dff45ab197bdc08b90f78a'), -)) +@pytest.mark.parametrize( + ("test_url", "expected_hash"), + ( + ("https://www.delayforreddit.com/dfr/calvin6123/MjU1Njc5NQ==", "3300f28c2f9358d05667985c9c04210d"), + ("https://www.delayforreddit.com/dfr/RoXs_26/NDAwMzAyOQ==", "09b7b01719dff45ab197bdc08b90f78a"), + ), +) def test_download_resource(test_url: str, expected_hash: str): mock_submission = Mock() mock_submission.url = test_url diff --git a/tests/site_downloaders/test_direct.py b/tests/site_downloaders/test_direct.py index 56f90fc5..b652d9ac 100644 --- a/tests/site_downloaders/test_direct.py +++ b/tests/site_downloaders/test_direct.py @@ -10,10 +10,13 @@ @pytest.mark.online -@pytest.mark.parametrize(('test_url', 'expected_hash'), ( - ('https://giant.gfycat.com/DefinitiveCanineCrayfish.mp4', '48f9bd4dbec1556d7838885612b13b39'), - ('https://giant.gfycat.com/DazzlingSilkyIguana.mp4', '808941b48fc1e28713d36dd7ed9dc648'), -)) +@pytest.mark.parametrize( + ("test_url", "expected_hash"), + ( + ("https://giant.gfycat.com/DefinitiveCanineCrayfish.mp4", "48f9bd4dbec1556d7838885612b13b39"), + ("https://giant.gfycat.com/DazzlingSilkyIguana.mp4", "808941b48fc1e28713d36dd7ed9dc648"), + ), +) def test_download_resource(test_url: str, expected_hash: str): mock_submission = Mock() mock_submission.url = test_url diff --git a/tests/site_downloaders/test_download_factory.py b/tests/site_downloaders/test_download_factory.py index bcfc7040..581656da 100644 --- a/tests/site_downloaders/test_download_factory.py +++ b/tests/site_downloaders/test_download_factory.py @@ -21,67 +21,82 @@ @pytest.mark.online -@pytest.mark.parametrize(('test_submission_url', 'expected_class'), ( - ('https://www.reddit.com/r/TwoXChromosomes/comments/lu29zn/i_refuse_to_live_my_life' - '_in_anything_but_comfort/', SelfPost), - ('https://i.redd.it/affyv0axd5k61.png', Direct), - ('https://i.imgur.com/bZx1SJQ.jpg', Imgur), - ('https://imgur.com/BuzvZwb.gifv', Imgur), - ('https://imgur.com/a/MkxAzeg', Imgur), - ('https://m.imgur.com/a/py3RW0j', Imgur), - ('https://www.reddit.com/gallery/lu93m7', Gallery), - ('https://gfycat.com/concretecheerfulfinwhale', Gfycat), - ('https://www.erome.com/a/NWGw0F09', Erome), - ('https://youtube.com/watch?v=Gv8Wz74FjVA', Youtube), - ('https://redgifs.com/watch/courageousimpeccablecanvasback', Redgifs), - ('https://www.gifdeliverynetwork.com/repulsivefinishedandalusianhorse', Redgifs), - ('https://youtu.be/DevfjHOhuFc', Youtube), - ('https://m.youtube.com/watch?v=kr-FeojxzUM', Youtube), - ('https://dynasty-scans.com/system/images_images/000/017/819/original/80215103_p0.png?1612232781', Direct), - ('https://v.redd.it/9z1dnk3xr5k61', VReddit), - ('https://streamable.com/dt46y', YtdlpFallback), - ('https://vimeo.com/channels/31259/53576664', YtdlpFallback), - ('http://video.pbs.org/viralplayer/2365173446/', YtdlpFallback), - ('https://www.pornhub.com/view_video.php?viewkey=ph5a2ee0461a8d0', PornHub), - ('https://www.patreon.com/posts/minecart-track-59346560', Gallery), -)) +@pytest.mark.parametrize( + ("test_submission_url", "expected_class"), + ( + ( + "https://www.reddit.com/r/TwoXChromosomes/comments/lu29zn/i_refuse_to_live_my_life" + "_in_anything_but_comfort/", + SelfPost, + ), + ("https://i.redd.it/affyv0axd5k61.png", Direct), + ("https://i.imgur.com/bZx1SJQ.jpg", Imgur), + ("https://imgur.com/BuzvZwb.gifv", Imgur), + ("https://imgur.com/a/MkxAzeg", Imgur), + ("https://m.imgur.com/a/py3RW0j", Imgur), + ("https://www.reddit.com/gallery/lu93m7", Gallery), + ("https://gfycat.com/concretecheerfulfinwhale", Gfycat), + ("https://www.erome.com/a/NWGw0F09", Erome), + ("https://youtube.com/watch?v=Gv8Wz74FjVA", Youtube), + ("https://redgifs.com/watch/courageousimpeccablecanvasback", Redgifs), + ("https://www.gifdeliverynetwork.com/repulsivefinishedandalusianhorse", Redgifs), + ("https://youtu.be/DevfjHOhuFc", Youtube), + ("https://m.youtube.com/watch?v=kr-FeojxzUM", Youtube), + ("https://dynasty-scans.com/system/images_images/000/017/819/original/80215103_p0.png?1612232781", Direct), + ("https://v.redd.it/9z1dnk3xr5k61", VReddit), + ("https://streamable.com/dt46y", YtdlpFallback), + ("https://vimeo.com/channels/31259/53576664", YtdlpFallback), + ("http://video.pbs.org/viralplayer/2365173446/", YtdlpFallback), + ("https://www.pornhub.com/view_video.php?viewkey=ph5a2ee0461a8d0", PornHub), + ("https://www.patreon.com/posts/minecart-track-59346560", Gallery), + ), +) def test_factory_lever_good(test_submission_url: str, expected_class: BaseDownloader, reddit_instance: praw.Reddit): result = DownloadFactory.pull_lever(test_submission_url) assert result is expected_class -@pytest.mark.parametrize('test_url', ( - 'random.com', - 'bad', - 'https://www.google.com/', - 'https://www.google.com', - 'https://www.google.com/test', - 'https://www.google.com/test/', -)) +@pytest.mark.parametrize( + "test_url", + ( + "random.com", + "bad", + "https://www.google.com/", + "https://www.google.com", + "https://www.google.com/test", + "https://www.google.com/test/", + ), +) def test_factory_lever_bad(test_url: str): with pytest.raises(NotADownloadableLinkError): DownloadFactory.pull_lever(test_url) -@pytest.mark.parametrize(('test_url', 'expected'), ( - ('www.test.com/test.png', 'test.com/test.png'), - ('www.test.com/test.png?test_value=random', 'test.com/test.png'), - ('https://youtube.com/watch?v=Gv8Wz74FjVA', 'youtube.com/watch'), - ('https://i.imgur.com/BuzvZwb.gifv', 'i.imgur.com/BuzvZwb.gifv'), -)) +@pytest.mark.parametrize( + ("test_url", "expected"), + ( + ("www.test.com/test.png", "test.com/test.png"), + ("www.test.com/test.png?test_value=random", "test.com/test.png"), + ("https://youtube.com/watch?v=Gv8Wz74FjVA", "youtube.com/watch"), + ("https://i.imgur.com/BuzvZwb.gifv", "i.imgur.com/BuzvZwb.gifv"), + ), +) def test_sanitise_url(test_url: str, expected: str): result = DownloadFactory.sanitise_url(test_url) assert result == expected -@pytest.mark.parametrize(('test_url', 'expected'), ( - ('www.example.com/test.asp', True), - ('www.example.com/test.html', True), - ('www.example.com/test.js', True), - ('www.example.com/test.xhtml', True), - ('www.example.com/test.mp4', False), - ('www.example.com/test.png', False), -)) +@pytest.mark.parametrize( + ("test_url", "expected"), + ( + ("www.example.com/test.asp", True), + ("www.example.com/test.html", True), + ("www.example.com/test.js", True), + ("www.example.com/test.xhtml", True), + ("www.example.com/test.mp4", False), + ("www.example.com/test.png", False), + ), +) def test_is_web_resource(test_url: str, expected: bool): result = DownloadFactory.is_web_resource(test_url) assert result == expected diff --git a/tests/site_downloaders/test_erome.py b/tests/site_downloaders/test_erome.py index 2f3701d5..1baeb661 100644 --- a/tests/site_downloaders/test_erome.py +++ b/tests/site_downloaders/test_erome.py @@ -9,31 +9,38 @@ @pytest.mark.online -@pytest.mark.parametrize(('test_url', 'expected_urls'), ( - ('https://www.erome.com/a/vqtPuLXh', ( - r'https://[a-z]\d+.erome.com/\d{3}/vqtPuLXh/KH2qBT99_480p.mp4', - )), - ('https://www.erome.com/a/ORhX0FZz', ( - r'https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/9IYQocM9_480p.mp4', - r'https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/9eEDc8xm_480p.mp4', - r'https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/EvApC7Rp_480p.mp4', - r'https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/LruobtMs_480p.mp4', - r'https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/TJNmSUU5_480p.mp4', - r'https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/X11Skh6Z_480p.mp4', - r'https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/bjlTkpn7_480p.mp4' - )), -)) +@pytest.mark.parametrize( + ("test_url", "expected_urls"), + ( + ("https://www.erome.com/a/vqtPuLXh", (r"https://[a-z]\d+.erome.com/\d{3}/vqtPuLXh/KH2qBT99_480p.mp4",)), + ( + "https://www.erome.com/a/ORhX0FZz", + ( + r"https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/9IYQocM9_480p.mp4", + r"https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/9eEDc8xm_480p.mp4", + r"https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/EvApC7Rp_480p.mp4", + r"https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/LruobtMs_480p.mp4", + r"https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/TJNmSUU5_480p.mp4", + r"https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/X11Skh6Z_480p.mp4", + r"https://[a-z]\d+.erome.com/\d{3}/ORhX0FZz/bjlTkpn7_480p.mp4", + ), + ), + ), +) def test_get_link(test_url: str, expected_urls: tuple[str]): - result = Erome. _get_links(test_url) + result = Erome._get_links(test_url) assert all([any([re.match(p, r) for r in result]) for p in expected_urls]) @pytest.mark.online @pytest.mark.slow -@pytest.mark.parametrize(('test_url', 'expected_hashes_len'), ( - ('https://www.erome.com/a/vqtPuLXh', 1), - ('https://www.erome.com/a/4tP3KI6F', 1), -)) +@pytest.mark.parametrize( + ("test_url", "expected_hashes_len"), + ( + ("https://www.erome.com/a/vqtPuLXh", 1), + ("https://www.erome.com/a/4tP3KI6F", 1), + ), +) def test_download_resource(test_url: str, expected_hashes_len: int): # Can't compare hashes for this test, Erome doesn't return the exact same file from request to request so the hash # will change back and forth randomly diff --git a/tests/site_downloaders/test_gallery.py b/tests/site_downloaders/test_gallery.py index e9c401f0..57d055bb 100644 --- a/tests/site_downloaders/test_gallery.py +++ b/tests/site_downloaders/test_gallery.py @@ -9,30 +9,39 @@ @pytest.mark.online -@pytest.mark.parametrize(('test_ids', 'expected'), ( - ([ - {'media_id': '18nzv9ch0hn61'}, - {'media_id': 'jqkizcch0hn61'}, - {'media_id': 'k0fnqzbh0hn61'}, - {'media_id': 'm3gamzbh0hn61'}, - ], { - 'https://i.redd.it/18nzv9ch0hn61.jpg', - 'https://i.redd.it/jqkizcch0hn61.jpg', - 'https://i.redd.it/k0fnqzbh0hn61.jpg', - 'https://i.redd.it/m3gamzbh0hn61.jpg' - }), - ([ - {'media_id': '04vxj25uqih61'}, - {'media_id': '0fnx83kpqih61'}, - {'media_id': '7zkmr1wqqih61'}, - {'media_id': 'u37k5gxrqih61'}, - ], { - 'https://i.redd.it/04vxj25uqih61.png', - 'https://i.redd.it/0fnx83kpqih61.png', - 'https://i.redd.it/7zkmr1wqqih61.png', - 'https://i.redd.it/u37k5gxrqih61.png' - }), -)) +@pytest.mark.parametrize( + ("test_ids", "expected"), + ( + ( + [ + {"media_id": "18nzv9ch0hn61"}, + {"media_id": "jqkizcch0hn61"}, + {"media_id": "k0fnqzbh0hn61"}, + {"media_id": "m3gamzbh0hn61"}, + ], + { + "https://i.redd.it/18nzv9ch0hn61.jpg", + "https://i.redd.it/jqkizcch0hn61.jpg", + "https://i.redd.it/k0fnqzbh0hn61.jpg", + "https://i.redd.it/m3gamzbh0hn61.jpg", + }, + ), + ( + [ + {"media_id": "04vxj25uqih61"}, + {"media_id": "0fnx83kpqih61"}, + {"media_id": "7zkmr1wqqih61"}, + {"media_id": "u37k5gxrqih61"}, + ], + { + "https://i.redd.it/04vxj25uqih61.png", + "https://i.redd.it/0fnx83kpqih61.png", + "https://i.redd.it/7zkmr1wqqih61.png", + "https://i.redd.it/u37k5gxrqih61.png", + }, + ), + ), +) def test_gallery_get_links(test_ids: list[dict], expected: set[str]): results = Gallery._get_links(test_ids) assert set(results) == expected @@ -40,32 +49,47 @@ def test_gallery_get_links(test_ids: list[dict], expected: set[str]): @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_submission_id', 'expected_hashes'), ( - ('m6lvrh', { - '5c42b8341dd56eebef792e86f3981c6a', - '8f38d76da46f4057bf2773a778e725ca', - 'f5776f8f90491c8b770b8e0a6bfa49b3', - 'fa1a43c94da30026ad19a9813a0ed2c2', - }), - ('ljyy27', { - '359c203ec81d0bc00e675f1023673238', - '79262fd46bce5bfa550d878a3b898be4', - '808c35267f44acb523ce03bfa5687404', - 'ec8b65bdb7f1279c4b3af0ea2bbb30c3', - }), - ('obkflw', { - '65163f685fb28c5b776e0e77122718be', - '2a337eb5b13c34d3ca3f51b5db7c13e9', - }), - ('rb3ub6', { # patreon post - '748a976c6cedf7ea85b6f90e7cb685c7', - '839796d7745e88ced6355504e1f74508', - 'bcdb740367d0f19f97a77e614b48a42d', - '0f230b8c4e5d103d35a773fab9814ec3', - 'e5192d6cb4f84c4f4a658355310bf0f9', - '91cbe172cd8ccbcf049fcea4204eb979', - }) -)) +@pytest.mark.parametrize( + ("test_submission_id", "expected_hashes"), + ( + ( + "m6lvrh", + { + "5c42b8341dd56eebef792e86f3981c6a", + "8f38d76da46f4057bf2773a778e725ca", + "f5776f8f90491c8b770b8e0a6bfa49b3", + "fa1a43c94da30026ad19a9813a0ed2c2", + }, + ), + ( + "ljyy27", + { + "359c203ec81d0bc00e675f1023673238", + "79262fd46bce5bfa550d878a3b898be4", + "808c35267f44acb523ce03bfa5687404", + "ec8b65bdb7f1279c4b3af0ea2bbb30c3", + }, + ), + ( + "obkflw", + { + "65163f685fb28c5b776e0e77122718be", + "2a337eb5b13c34d3ca3f51b5db7c13e9", + }, + ), + ( + "rb3ub6", + { # patreon post + "748a976c6cedf7ea85b6f90e7cb685c7", + "839796d7745e88ced6355504e1f74508", + "bcdb740367d0f19f97a77e614b48a42d", + "0f230b8c4e5d103d35a773fab9814ec3", + "e5192d6cb4f84c4f4a658355310bf0f9", + "91cbe172cd8ccbcf049fcea4204eb979", + }, + ), + ), +) def test_gallery_download(test_submission_id: str, expected_hashes: set[str], reddit_instance: praw.Reddit): test_submission = reddit_instance.submission(id=test_submission_id) gallery = Gallery(test_submission) @@ -75,10 +99,13 @@ def test_gallery_download(test_submission_id: str, expected_hashes: set[str], re assert set(hashes) == expected_hashes -@pytest.mark.parametrize('test_id', ( - 'n0pyzp', - 'nxyahw', -)) +@pytest.mark.parametrize( + "test_id", + ( + "n0pyzp", + "nxyahw", + ), +) def test_gallery_download_raises_right_error(test_id: str, reddit_instance: praw.Reddit): test_submission = reddit_instance.submission(id=test_id) gallery = Gallery(test_submission) diff --git a/tests/site_downloaders/test_gfycat.py b/tests/site_downloaders/test_gfycat.py index 3b408402..d4366361 100644 --- a/tests/site_downloaders/test_gfycat.py +++ b/tests/site_downloaders/test_gfycat.py @@ -10,20 +10,26 @@ @pytest.mark.online -@pytest.mark.parametrize(('test_url', 'expected_url'), ( - ('https://gfycat.com/definitivecaninecrayfish', 'https://giant.gfycat.com/DefinitiveCanineCrayfish.mp4'), - ('https://gfycat.com/dazzlingsilkyiguana', 'https://giant.gfycat.com/DazzlingSilkyIguana.mp4'), -)) +@pytest.mark.parametrize( + ("test_url", "expected_url"), + ( + ("https://gfycat.com/definitivecaninecrayfish", "https://giant.gfycat.com/DefinitiveCanineCrayfish.mp4"), + ("https://gfycat.com/dazzlingsilkyiguana", "https://giant.gfycat.com/DazzlingSilkyIguana.mp4"), + ), +) def test_get_link(test_url: str, expected_url: str): result = Gfycat._get_link(test_url) assert result.pop() == expected_url @pytest.mark.online -@pytest.mark.parametrize(('test_url', 'expected_hash'), ( - ('https://gfycat.com/definitivecaninecrayfish', '48f9bd4dbec1556d7838885612b13b39'), - ('https://gfycat.com/dazzlingsilkyiguana', '808941b48fc1e28713d36dd7ed9dc648'), -)) +@pytest.mark.parametrize( + ("test_url", "expected_hash"), + ( + ("https://gfycat.com/definitivecaninecrayfish", "48f9bd4dbec1556d7838885612b13b39"), + ("https://gfycat.com/dazzlingsilkyiguana", "808941b48fc1e28713d36dd7ed9dc648"), + ), +) def test_download_resource(test_url: str, expected_hash: str): mock_submission = Mock() mock_submission.url = test_url diff --git a/tests/site_downloaders/test_imgur.py b/tests/site_downloaders/test_imgur.py index 00419bad..38dbdc54 100644 --- a/tests/site_downloaders/test_imgur.py +++ b/tests/site_downloaders/test_imgur.py @@ -11,166 +11,167 @@ @pytest.mark.online -@pytest.mark.parametrize(('test_url', 'expected_gen_dict', 'expected_image_dict'), ( +@pytest.mark.parametrize( + ("test_url", "expected_gen_dict", "expected_image_dict"), ( - 'https://imgur.com/a/xWZsDDP', - {'num_images': '1', 'id': 'xWZsDDP', 'hash': 'xWZsDDP'}, - [ - {'hash': 'ypa8YfS', 'title': '', 'ext': '.png', 'animated': False} - ] - ), - ( - 'https://imgur.com/gallery/IjJJdlC', - {'num_images': 1, 'id': 384898055, 'hash': 'IjJJdlC'}, - [ - {'hash': 'CbbScDt', - 'description': 'watch when he gets it', - 'ext': '.gif', - 'animated': True, - 'has_sound': False - } - ], - ), - ( - 'https://imgur.com/a/dcc84Gt', - {'num_images': '4', 'id': 'dcc84Gt', 'hash': 'dcc84Gt'}, - [ - {'hash': 'ylx0Kle', 'ext': '.jpg', 'title': ''}, - {'hash': 'TdYfKbK', 'ext': '.jpg', 'title': ''}, - {'hash': 'pCxGbe8', 'ext': '.jpg', 'title': ''}, - {'hash': 'TSAkikk', 'ext': '.jpg', 'title': ''}, - ] - ), - ( - 'https://m.imgur.com/a/py3RW0j', - {'num_images': '1', 'id': 'py3RW0j', 'hash': 'py3RW0j', }, - [ - {'hash': 'K24eQmK', 'has_sound': False, 'ext': '.jpg'} - ], + ( + "https://imgur.com/a/xWZsDDP", + {"num_images": "1", "id": "xWZsDDP", "hash": "xWZsDDP"}, + [{"hash": "ypa8YfS", "title": "", "ext": ".png", "animated": False}], + ), + ( + "https://imgur.com/gallery/IjJJdlC", + {"num_images": 1, "id": 384898055, "hash": "IjJJdlC"}, + [ + { + "hash": "CbbScDt", + "description": "watch when he gets it", + "ext": ".gif", + "animated": True, + "has_sound": False, + } + ], + ), + ( + "https://imgur.com/a/dcc84Gt", + {"num_images": "4", "id": "dcc84Gt", "hash": "dcc84Gt"}, + [ + {"hash": "ylx0Kle", "ext": ".jpg", "title": ""}, + {"hash": "TdYfKbK", "ext": ".jpg", "title": ""}, + {"hash": "pCxGbe8", "ext": ".jpg", "title": ""}, + {"hash": "TSAkikk", "ext": ".jpg", "title": ""}, + ], + ), + ( + "https://m.imgur.com/a/py3RW0j", + { + "num_images": "1", + "id": "py3RW0j", + "hash": "py3RW0j", + }, + [{"hash": "K24eQmK", "has_sound": False, "ext": ".jpg"}], + ), ), -)) +) def test_get_data_album(test_url: str, expected_gen_dict: dict, expected_image_dict: list[dict]): result = Imgur._get_data(test_url) assert all([result.get(key) == expected_gen_dict[key] for key in expected_gen_dict.keys()]) # Check if all the keys from the test dict are correct in at least one of the album entries - assert any([all([image.get(key) == image_dict[key] for key in image_dict.keys()]) - for image_dict in expected_image_dict for image in result['album_images']['images']]) + assert any( + [ + all([image.get(key) == image_dict[key] for key in image_dict.keys()]) + for image_dict in expected_image_dict + for image in result["album_images"]["images"] + ] + ) @pytest.mark.online -@pytest.mark.parametrize(('test_url', 'expected_image_dict'), ( - ( - 'https://i.imgur.com/dLk3FGY.gifv', - {'hash': 'dLk3FGY', 'title': '', 'ext': '.mp4', 'animated': True} - ), +@pytest.mark.parametrize( + ("test_url", "expected_image_dict"), ( - 'https://imgur.com/65FqTpT.gifv', - { - 'hash': '65FqTpT', - 'title': '', - 'description': '', - 'animated': True, - 'mimetype': 'video/mp4' - }, + ("https://i.imgur.com/dLk3FGY.gifv", {"hash": "dLk3FGY", "title": "", "ext": ".mp4", "animated": True}), + ( + "https://imgur.com/65FqTpT.gifv", + {"hash": "65FqTpT", "title": "", "description": "", "animated": True, "mimetype": "video/mp4"}, + ), ), -)) +) def test_get_data_gif(test_url: str, expected_image_dict: dict): result = Imgur._get_data(test_url) assert all([result.get(key) == expected_image_dict[key] for key in expected_image_dict.keys()]) -@pytest.mark.parametrize('test_extension', ( - '.gif', - '.png', - '.jpg', - '.mp4' -)) +@pytest.mark.parametrize("test_extension", (".gif", ".png", ".jpg", ".mp4")) def test_imgur_extension_validation_good(test_extension: str): result = Imgur._validate_extension(test_extension) assert result == test_extension -@pytest.mark.parametrize('test_extension', ( - '.jpeg', - 'bad', - '.avi', - '.test', - '.flac', -)) +@pytest.mark.parametrize( + "test_extension", + ( + ".jpeg", + "bad", + ".avi", + ".test", + ".flac", + ), +) def test_imgur_extension_validation_bad(test_extension: str): with pytest.raises(SiteDownloaderError): Imgur._validate_extension(test_extension) @pytest.mark.online -@pytest.mark.parametrize(('test_url', 'expected_hashes'), ( +@pytest.mark.parametrize( + ("test_url", "expected_hashes"), ( - 'https://imgur.com/a/xWZsDDP', - ('f551d6e6b0fef2ce909767338612e31b',) - ), - ( - 'https://imgur.com/gallery/IjJJdlC', - ('740b006cf9ec9d6f734b6e8f5130bdab',), - ), - ( - 'https://imgur.com/a/dcc84Gt', + ("https://imgur.com/a/xWZsDDP", ("f551d6e6b0fef2ce909767338612e31b",)), ( - 'cf1158e1de5c3c8993461383b96610cf', - '28d6b791a2daef8aa363bf5a3198535d', - '248ef8f2a6d03eeb2a80d0123dbaf9b6', - '029c475ce01b58fdf1269d8771d33913', + "https://imgur.com/gallery/IjJJdlC", + ("740b006cf9ec9d6f734b6e8f5130bdab",), ), - ), - ( - 'https://imgur.com/a/eemHCCK', ( - '9cb757fd8f055e7ef7aa88addc9d9fa5', - 'b6cb6c918e2544e96fb7c07d828774b5', - 'fb6c913d721c0bbb96aa65d7f560d385', + "https://imgur.com/a/dcc84Gt", + ( + "cf1158e1de5c3c8993461383b96610cf", + "28d6b791a2daef8aa363bf5a3198535d", + "248ef8f2a6d03eeb2a80d0123dbaf9b6", + "029c475ce01b58fdf1269d8771d33913", + ), + ), + ( + "https://imgur.com/a/eemHCCK", + ( + "9cb757fd8f055e7ef7aa88addc9d9fa5", + "b6cb6c918e2544e96fb7c07d828774b5", + "fb6c913d721c0bbb96aa65d7f560d385", + ), + ), + ( + "https://i.imgur.com/lFJai6i.gifv", + ("01a6e79a30bec0e644e5da12365d5071",), + ), + ( + "https://i.imgur.com/ywSyILa.gifv?", + ("56d4afc32d2966017c38d98568709b45",), + ), + ( + "https://imgur.com/ubYwpbk.GIFV", + ("d4a774aac1667783f9ed3a1bd02fac0c",), + ), + ( + "https://i.imgur.com/j1CNCZY.gifv", + ("58e7e6d972058c18b7ecde910ca147e3",), + ), + ( + "https://i.imgur.com/uTvtQsw.gifv", + ("46c86533aa60fc0e09f2a758513e3ac2",), + ), + ( + "https://i.imgur.com/OGeVuAe.giff", + ("77389679084d381336f168538793f218",), + ), + ( + "https://i.imgur.com/OGeVuAe.gift", + ("77389679084d381336f168538793f218",), + ), + ( + "https://i.imgur.com/3SKrQfK.jpg?1", + ("aa299e181b268578979cad176d1bd1d0",), + ), + ( + "https://i.imgur.com/cbivYRW.jpg?3", + ("7ec6ceef5380cb163a1d498c359c51fd",), + ), + ( + "http://i.imgur.com/s9uXxlq.jpg?5.jpg", + ("338de3c23ee21af056b3a7c154e2478f",), ), ), - ( - 'https://i.imgur.com/lFJai6i.gifv', - ('01a6e79a30bec0e644e5da12365d5071',), - ), - ( - 'https://i.imgur.com/ywSyILa.gifv?', - ('56d4afc32d2966017c38d98568709b45',), - ), - ( - 'https://imgur.com/ubYwpbk.GIFV', - ('d4a774aac1667783f9ed3a1bd02fac0c',), - ), - ( - 'https://i.imgur.com/j1CNCZY.gifv', - ('58e7e6d972058c18b7ecde910ca147e3',), - ), - ( - 'https://i.imgur.com/uTvtQsw.gifv', - ('46c86533aa60fc0e09f2a758513e3ac2',), - ), - ( - 'https://i.imgur.com/OGeVuAe.giff', - ('77389679084d381336f168538793f218',), - ), - ( - 'https://i.imgur.com/OGeVuAe.gift', - ('77389679084d381336f168538793f218',), - ), - ( - 'https://i.imgur.com/3SKrQfK.jpg?1', - ('aa299e181b268578979cad176d1bd1d0',), - ), - ( - 'https://i.imgur.com/cbivYRW.jpg?3', - ('7ec6ceef5380cb163a1d498c359c51fd',), - ), - ( - 'http://i.imgur.com/s9uXxlq.jpg?5.jpg', - ('338de3c23ee21af056b3a7c154e2478f',), - ), -)) +) def test_find_resources(test_url: str, expected_hashes: list[str]): mock_download = Mock() mock_download.url = test_url diff --git a/tests/site_downloaders/test_pornhub.py b/tests/site_downloaders/test_pornhub.py index e0933b0d..42ca5a0d 100644 --- a/tests/site_downloaders/test_pornhub.py +++ b/tests/site_downloaders/test_pornhub.py @@ -12,9 +12,10 @@ @pytest.mark.online @pytest.mark.slow -@pytest.mark.parametrize(('test_url', 'expected_hash'), ( - ('https://www.pornhub.com/view_video.php?viewkey=ph6074c59798497', 'ad52a0f4fce8f99df0abed17de1d04c7'), -)) +@pytest.mark.parametrize( + ("test_url", "expected_hash"), + (("https://www.pornhub.com/view_video.php?viewkey=ph6074c59798497", "ad52a0f4fce8f99df0abed17de1d04c7"),), +) def test_hash_resources_good(test_url: str, expected_hash: str): test_submission = MagicMock() test_submission.url = test_url @@ -27,9 +28,7 @@ def test_hash_resources_good(test_url: str, expected_hash: str): @pytest.mark.online -@pytest.mark.parametrize('test_url', ( - 'https://www.pornhub.com/view_video.php?viewkey=ph5ede121f0d3f8', -)) +@pytest.mark.parametrize("test_url", ("https://www.pornhub.com/view_video.php?viewkey=ph5ede121f0d3f8",)) def test_find_resources_good(test_url: str): test_submission = MagicMock() test_submission.url = test_url diff --git a/tests/site_downloaders/test_redgifs.py b/tests/site_downloaders/test_redgifs.py index 9a6d132a..0e1a497c 100644 --- a/tests/site_downloaders/test_redgifs.py +++ b/tests/site_downloaders/test_redgifs.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 # coding=utf-8 -from unittest.mock import Mock import re +from unittest.mock import Mock import pytest @@ -11,45 +11,55 @@ @pytest.mark.online -@pytest.mark.parametrize(('test_url', 'expected'), ( - ('https://redgifs.com/watch/frighteningvictorioussalamander', - {'FrighteningVictoriousSalamander.mp4'}), - ('https://redgifs.com/watch/springgreendecisivetaruca', - {'SpringgreenDecisiveTaruca.mp4'}), - ('https://www.redgifs.com/watch/palegoldenrodrawhalibut', - {'PalegoldenrodRawHalibut.mp4'}), - ('https://redgifs.com/watch/hollowintentsnowyowl', - {'HollowIntentSnowyowl-large.jpg'}), - ('https://www.redgifs.com/watch/lustrousstickywaxwing', - {'EntireEnchantingHypsilophodon-large.jpg', - 'FancyMagnificentAdamsstaghornedbeetle-large.jpg', - 'LustrousStickyWaxwing-large.jpg', - 'ParchedWindyArmyworm-large.jpg', - 'ThunderousColorlessErmine-large.jpg', - 'UnripeUnkemptWoodpecker-large.jpg'}), -)) +@pytest.mark.parametrize( + ("test_url", "expected"), + ( + ("https://redgifs.com/watch/frighteningvictorioussalamander", {"FrighteningVictoriousSalamander.mp4"}), + ("https://redgifs.com/watch/springgreendecisivetaruca", {"SpringgreenDecisiveTaruca.mp4"}), + ("https://www.redgifs.com/watch/palegoldenrodrawhalibut", {"PalegoldenrodRawHalibut.mp4"}), + ("https://redgifs.com/watch/hollowintentsnowyowl", {"HollowIntentSnowyowl-large.jpg"}), + ( + "https://www.redgifs.com/watch/lustrousstickywaxwing", + { + "EntireEnchantingHypsilophodon-large.jpg", + "FancyMagnificentAdamsstaghornedbeetle-large.jpg", + "LustrousStickyWaxwing-large.jpg", + "ParchedWindyArmyworm-large.jpg", + "ThunderousColorlessErmine-large.jpg", + "UnripeUnkemptWoodpecker-large.jpg", + }, + ), + ), +) def test_get_link(test_url: str, expected: set[str]): result = Redgifs._get_link(test_url) result = list(result) - patterns = [r'https://thumbs\d\.redgifs\.com/' + e + r'.*' for e in expected] + patterns = [r"https://thumbs\d\.redgifs\.com/" + e + r".*" for e in expected] assert all([re.match(p, r) for p in patterns] for r in result) @pytest.mark.online -@pytest.mark.parametrize(('test_url', 'expected_hashes'), ( - ('https://redgifs.com/watch/frighteningvictorioussalamander', {'4007c35d9e1f4b67091b5f12cffda00a'}), - ('https://redgifs.com/watch/springgreendecisivetaruca', {'8dac487ac49a1f18cc1b4dabe23f0869'}), - ('https://redgifs.com/watch/leafysaltydungbeetle', {'076792c660b9c024c0471ef4759af8bd'}), - ('https://www.redgifs.com/watch/palegoldenrodrawhalibut', {'46d5aa77fe80c6407de1ecc92801c10e'}), - ('https://redgifs.com/watch/hollowintentsnowyowl', {'5ee51fa15e0a58e98f11dea6a6cca771'}), - ('https://www.redgifs.com/watch/lustrousstickywaxwing', - {'b461e55664f07bed8d2f41d8586728fa', - '30ba079a8ed7d7adf17929dc3064c10f', - '0d4f149d170d29fc2f015c1121bab18b', - '53987d99cfd77fd65b5fdade3718f9f1', - 'fb2e7d972846b83bf4016447d3060d60', - '44fb28f72ec9a5cca63fa4369ab4f672'}), -)) +@pytest.mark.parametrize( + ("test_url", "expected_hashes"), + ( + ("https://redgifs.com/watch/frighteningvictorioussalamander", {"4007c35d9e1f4b67091b5f12cffda00a"}), + ("https://redgifs.com/watch/springgreendecisivetaruca", {"8dac487ac49a1f18cc1b4dabe23f0869"}), + ("https://redgifs.com/watch/leafysaltydungbeetle", {"076792c660b9c024c0471ef4759af8bd"}), + ("https://www.redgifs.com/watch/palegoldenrodrawhalibut", {"46d5aa77fe80c6407de1ecc92801c10e"}), + ("https://redgifs.com/watch/hollowintentsnowyowl", {"5ee51fa15e0a58e98f11dea6a6cca771"}), + ( + "https://www.redgifs.com/watch/lustrousstickywaxwing", + { + "b461e55664f07bed8d2f41d8586728fa", + "30ba079a8ed7d7adf17929dc3064c10f", + "0d4f149d170d29fc2f015c1121bab18b", + "53987d99cfd77fd65b5fdade3718f9f1", + "fb2e7d972846b83bf4016447d3060d60", + "44fb28f72ec9a5cca63fa4369ab4f672", + }, + ), + ), +) def test_download_resource(test_url: str, expected_hashes: set[str]): mock_submission = Mock() mock_submission.url = test_url @@ -62,18 +72,30 @@ def test_download_resource(test_url: str, expected_hashes: set[str]): @pytest.mark.online -@pytest.mark.parametrize(('test_url', 'expected_link', 'expected_hash'), ( - ('https://redgifs.com/watch/flippantmemorablebaiji', {'FlippantMemorableBaiji-mobile.mp4'}, - {'41a5fb4865367ede9f65fc78736f497a'}), - ('https://redgifs.com/watch/thirstyunfortunatewaterdragons', {'thirstyunfortunatewaterdragons-mobile.mp4'}, - {'1a51dad8fedb594bdd84f027b3cbe8af'}), - ('https://redgifs.com/watch/conventionalplainxenopterygii', {'conventionalplainxenopterygii-mobile.mp4'}, - {'2e1786b3337da85b80b050e2c289daa4'}) -)) +@pytest.mark.parametrize( + ("test_url", "expected_link", "expected_hash"), + ( + ( + "https://redgifs.com/watch/flippantmemorablebaiji", + {"FlippantMemorableBaiji-mobile.mp4"}, + {"41a5fb4865367ede9f65fc78736f497a"}, + ), + ( + "https://redgifs.com/watch/thirstyunfortunatewaterdragons", + {"thirstyunfortunatewaterdragons-mobile.mp4"}, + {"1a51dad8fedb594bdd84f027b3cbe8af"}, + ), + ( + "https://redgifs.com/watch/conventionalplainxenopterygii", + {"conventionalplainxenopterygii-mobile.mp4"}, + {"2e1786b3337da85b80b050e2c289daa4"}, + ), + ), +) def test_hd_soft_fail(test_url: str, expected_link: set[str], expected_hash: set[str]): link = Redgifs._get_link(test_url) link = list(link) - patterns = [r'https://thumbs\d\.redgifs\.com/' + e + r'.*' for e in expected_link] + patterns = [r"https://thumbs\d\.redgifs\.com/" + e + r".*" for e in expected_link] assert all([re.match(p, r) for p in patterns] for r in link) mock_submission = Mock() mock_submission.url = test_url diff --git a/tests/site_downloaders/test_self_post.py b/tests/site_downloaders/test_self_post.py index e3363bbe..104fb3bf 100644 --- a/tests/site_downloaders/test_self_post.py +++ b/tests/site_downloaders/test_self_post.py @@ -10,11 +10,14 @@ @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_submission_id', 'expected_hash'), ( - ('ltmivt', '7d2c9e4e989e5cf2dca2e55a06b1c4f6'), - ('ltoaan', '221606386b614d6780c2585a59bd333f'), - ('d3sc8o', 'c1ff2b6bd3f6b91381dcd18dfc4ca35f'), -)) +@pytest.mark.parametrize( + ("test_submission_id", "expected_hash"), + ( + ("ltmivt", "7d2c9e4e989e5cf2dca2e55a06b1c4f6"), + ("ltoaan", "221606386b614d6780c2585a59bd333f"), + ("d3sc8o", "c1ff2b6bd3f6b91381dcd18dfc4ca35f"), + ), +) def test_find_resource(test_submission_id: str, expected_hash: str, reddit_instance: praw.Reddit): submission = reddit_instance.submission(id=test_submission_id) downloader = SelfPost(submission) diff --git a/tests/site_downloaders/test_vidble.py b/tests/site_downloaders/test_vidble.py index f6ddd56b..16b5a3b9 100644 --- a/tests/site_downloaders/test_vidble.py +++ b/tests/site_downloaders/test_vidble.py @@ -8,55 +8,83 @@ from bdfr.site_downloaders.vidble import Vidble -@pytest.mark.parametrize(('test_url', 'expected'), ( - ('/RDFbznUvcN_med.jpg', '/RDFbznUvcN.jpg'), -)) +@pytest.mark.parametrize(("test_url", "expected"), (("/RDFbznUvcN_med.jpg", "/RDFbznUvcN.jpg"),)) def test_change_med_url(test_url: str, expected: str): result = Vidble.change_med_url(test_url) assert result == expected @pytest.mark.online -@pytest.mark.parametrize(('test_url', 'expected'), ( - ('https://www.vidble.com/show/UxsvAssYe5', { - 'https://www.vidble.com/UxsvAssYe5.gif', - }), - ('https://vidble.com/show/RDFbznUvcN', { - 'https://www.vidble.com/RDFbznUvcN.jpg', - }), - ('https://vidble.com/album/h0jTLs6B', { - 'https://www.vidble.com/XG4eAoJ5JZ.jpg', - 'https://www.vidble.com/IqF5UdH6Uq.jpg', - 'https://www.vidble.com/VWuNsnLJMD.jpg', - 'https://www.vidble.com/sMmM8O650W.jpg', - }), - ('https://www.vidble.com/pHuwWkOcEb', { - 'https://www.vidble.com/pHuwWkOcEb.jpg', - }), -)) +@pytest.mark.parametrize( + ("test_url", "expected"), + ( + ( + "https://www.vidble.com/show/UxsvAssYe5", + { + "https://www.vidble.com/UxsvAssYe5.gif", + }, + ), + ( + "https://vidble.com/show/RDFbznUvcN", + { + "https://www.vidble.com/RDFbznUvcN.jpg", + }, + ), + ( + "https://vidble.com/album/h0jTLs6B", + { + "https://www.vidble.com/XG4eAoJ5JZ.jpg", + "https://www.vidble.com/IqF5UdH6Uq.jpg", + "https://www.vidble.com/VWuNsnLJMD.jpg", + "https://www.vidble.com/sMmM8O650W.jpg", + }, + ), + ( + "https://www.vidble.com/pHuwWkOcEb", + { + "https://www.vidble.com/pHuwWkOcEb.jpg", + }, + ), + ), +) def test_get_links(test_url: str, expected: set[str]): results = Vidble.get_links(test_url) assert results == expected @pytest.mark.online -@pytest.mark.parametrize(('test_url', 'expected_hashes'), ( - ('https://www.vidble.com/show/UxsvAssYe5', { - '0ef2f8e0e0b45936d2fb3e6fbdf67e28', - }), - ('https://vidble.com/show/RDFbznUvcN', { - 'c2dd30a71e32369c50eed86f86efff58', - }), - ('https://vidble.com/album/h0jTLs6B', { - '3b3cba02e01c91f9858a95240b942c71', - 'dd6ecf5fc9e936f9fb614eb6a0537f99', - 'b31a942cd8cdda218ed547bbc04c3a27', - '6f77c570b451eef4222804bd52267481', - }), - ('https://www.vidble.com/pHuwWkOcEb', { - '585f486dd0b2f23a57bddbd5bf185bc7', - }), -)) +@pytest.mark.parametrize( + ("test_url", "expected_hashes"), + ( + ( + "https://www.vidble.com/show/UxsvAssYe5", + { + "0ef2f8e0e0b45936d2fb3e6fbdf67e28", + }, + ), + ( + "https://vidble.com/show/RDFbznUvcN", + { + "c2dd30a71e32369c50eed86f86efff58", + }, + ), + ( + "https://vidble.com/album/h0jTLs6B", + { + "3b3cba02e01c91f9858a95240b942c71", + "dd6ecf5fc9e936f9fb614eb6a0537f99", + "b31a942cd8cdda218ed547bbc04c3a27", + "6f77c570b451eef4222804bd52267481", + }, + ), + ( + "https://www.vidble.com/pHuwWkOcEb", + { + "585f486dd0b2f23a57bddbd5bf185bc7", + }, + ), + ), +) def test_find_resources(test_url: str, expected_hashes: set[str]): mock_download = Mock() mock_download.url = test_url diff --git a/tests/site_downloaders/test_vreddit.py b/tests/site_downloaders/test_vreddit.py index 54ffcf8f..6e79ba02 100644 --- a/tests/site_downloaders/test_vreddit.py +++ b/tests/site_downloaders/test_vreddit.py @@ -12,9 +12,10 @@ @pytest.mark.online @pytest.mark.slow -@pytest.mark.parametrize(('test_url', 'expected_hash'), ( - ('https://reddit.com/r/Unexpected/comments/z4xsuj/omg_thats_so_cute/', '1ffab5e5c0cc96db18108e4f37e8ca7f'), -)) +@pytest.mark.parametrize( + ("test_url", "expected_hash"), + (("https://reddit.com/r/Unexpected/comments/z4xsuj/omg_thats_so_cute/", "1ffab5e5c0cc96db18108e4f37e8ca7f"),), +) def test_find_resources_good(test_url: str, expected_hash: str): test_submission = MagicMock() test_submission.url = test_url @@ -27,10 +28,13 @@ def test_find_resources_good(test_url: str, expected_hash: str): @pytest.mark.online -@pytest.mark.parametrize('test_url', ( - 'https://www.polygon.com/disney-plus/2020/5/14/21249881/gargoyles-animated-series-disney-plus-greg-weisman' - '-interview-oj-simpson-goliath-chronicles', -)) +@pytest.mark.parametrize( + "test_url", + ( + "https://www.polygon.com/disney-plus/2020/5/14/21249881/gargoyles-animated-series-disney-plus-greg-weisman" + "-interview-oj-simpson-goliath-chronicles", + ), +) def test_find_resources_bad(test_url: str): test_submission = MagicMock() test_submission.url = test_url diff --git a/tests/site_downloaders/test_youtube.py b/tests/site_downloaders/test_youtube.py index 14c6648a..7a45a3c9 100644 --- a/tests/site_downloaders/test_youtube.py +++ b/tests/site_downloaders/test_youtube.py @@ -12,10 +12,13 @@ @pytest.mark.online @pytest.mark.slow -@pytest.mark.parametrize(('test_url', 'expected_hash'), ( - ('https://www.youtube.com/watch?v=uSm2VDgRIUs', '2d60b54582df5b95ec72bb00b580d2ff'), - ('https://www.youtube.com/watch?v=GcI7nxQj7HA', '5db0fc92a0a7fb9ac91e63505eea9cf0'), -)) +@pytest.mark.parametrize( + ("test_url", "expected_hash"), + ( + ("https://www.youtube.com/watch?v=uSm2VDgRIUs", "2d60b54582df5b95ec72bb00b580d2ff"), + ("https://www.youtube.com/watch?v=GcI7nxQj7HA", "5db0fc92a0a7fb9ac91e63505eea9cf0"), + ), +) def test_find_resources_good(test_url: str, expected_hash: str): test_submission = MagicMock() test_submission.url = test_url @@ -28,10 +31,13 @@ def test_find_resources_good(test_url: str, expected_hash: str): @pytest.mark.online -@pytest.mark.parametrize('test_url', ( - 'https://www.polygon.com/disney-plus/2020/5/14/21249881/gargoyles-animated-series-disney-plus-greg-weisman' - '-interview-oj-simpson-goliath-chronicles', -)) +@pytest.mark.parametrize( + "test_url", + ( + "https://www.polygon.com/disney-plus/2020/5/14/21249881/gargoyles-animated-series-disney-plus-greg-weisman" + "-interview-oj-simpson-goliath-chronicles", + ), +) def test_find_resources_bad(test_url: str): test_submission = MagicMock() test_submission.url = test_url diff --git a/tests/test_archiver.py b/tests/test_archiver.py index 627caeed..932a2ab1 100644 --- a/tests/test_archiver.py +++ b/tests/test_archiver.py @@ -12,15 +12,18 @@ @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_submission_id', 'test_format'), ( - ('m3reby', 'xml'), - ('m3reby', 'json'), - ('m3reby', 'yaml'), -)) +@pytest.mark.parametrize( + ("test_submission_id", "test_format"), + ( + ("m3reby", "xml"), + ("m3reby", "json"), + ("m3reby", "yaml"), + ), +) def test_write_submission_json(test_submission_id: str, tmp_path: Path, test_format: str, reddit_instance: praw.Reddit): archiver_mock = MagicMock() archiver_mock.args.format = test_format - test_path = Path(tmp_path, 'test') + test_path = Path(tmp_path, "test") test_submission = reddit_instance.submission(id=test_submission_id) archiver_mock.file_name_formatter.format_path.return_value = test_path Archiver.write_entry(archiver_mock, test_submission) diff --git a/tests/test_configuration.py b/tests/test_configuration.py index 060f1456..652c4017 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -8,13 +8,16 @@ from bdfr.configuration import Configuration -@pytest.mark.parametrize('arg_dict', ( - {'directory': 'test_dir'}, - { - 'directory': 'test_dir', - 'no_dupes': True, - }, -)) +@pytest.mark.parametrize( + "arg_dict", + ( + {"directory": "test_dir"}, + { + "directory": "test_dir", + "no_dupes": True, + }, + ), +) def test_process_click_context(arg_dict: dict): test_config = Configuration() test_context = MagicMock() @@ -25,9 +28,9 @@ def test_process_click_context(arg_dict: dict): def test_yaml_file_read(): - file = './tests/yaml_test_configuration.yaml' + file = "./tests/yaml_test_configuration.yaml" test_config = Configuration() test_config.parse_yaml_options(file) - assert test_config.subreddit == ['EarthPorn', 'TwoXChromosomes', 'Mindustry'] - assert test_config.sort == 'new' + assert test_config.subreddit == ["EarthPorn", "TwoXChromosomes", "Mindustry"] + assert test_config.sort == "new" assert test_config.limit == 10 diff --git a/tests/test_connector.py b/tests/test_connector.py index 4c9e52d5..01b6a92b 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -20,7 +20,7 @@ @pytest.fixture() def args() -> Configuration: args = Configuration() - args.time_format = 'ISO' + args.time_format = "ISO" return args @@ -30,7 +30,8 @@ def downloader_mock(args: Configuration): downloader_mock.args = args downloader_mock.sanitise_subreddit_name = RedditConnector.sanitise_subreddit_name downloader_mock.create_filtered_listing_generator = lambda x: RedditConnector.create_filtered_listing_generator( - downloader_mock, x) + downloader_mock, x + ) downloader_mock.split_args_input = RedditConnector.split_args_input downloader_mock.master_hash_list = {} return downloader_mock @@ -55,16 +56,22 @@ def assert_all_results_are_submissions_or_comments(result_limit: int, results: l def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock): - downloader_mock.args.directory = tmp_path / 'test' + downloader_mock.args.directory = tmp_path / "test" downloader_mock.config_directories.user_config_dir = tmp_path RedditConnector.determine_directories(downloader_mock) - assert Path(tmp_path / 'test').exists() - - -@pytest.mark.parametrize(('skip_extensions', 'skip_domains'), ( - ([], []), - (['.test'], ['test.com'],), -)) + assert Path(tmp_path / "test").exists() + + +@pytest.mark.parametrize( + ("skip_extensions", "skip_domains"), + ( + ([], []), + ( + [".test"], + ["test.com"], + ), + ), +) def test_create_download_filter(skip_extensions: list[str], skip_domains: list[str], downloader_mock: MagicMock): downloader_mock.args.skip = skip_extensions downloader_mock.args.skip_domain = skip_domains @@ -75,14 +82,17 @@ def test_create_download_filter(skip_extensions: list[str], skip_domains: list[s assert result.excluded_extensions == skip_extensions -@pytest.mark.parametrize(('test_time', 'expected'), ( - ('all', 'all'), - ('hour', 'hour'), - ('day', 'day'), - ('week', 'week'), - ('random', 'all'), - ('', 'all'), -)) +@pytest.mark.parametrize( + ("test_time", "expected"), + ( + ("all", "all"), + ("hour", "hour"), + ("day", "day"), + ("week", "week"), + ("random", "all"), + ("", "all"), + ), +) def test_create_time_filter(test_time: str, expected: str, downloader_mock: MagicMock): downloader_mock.args.time = test_time result = RedditConnector.create_time_filter(downloader_mock) @@ -91,12 +101,15 @@ def test_create_time_filter(test_time: str, expected: str, downloader_mock: Magi assert result.name.lower() == expected -@pytest.mark.parametrize(('test_sort', 'expected'), ( - ('', 'hot'), - ('hot', 'hot'), - ('controversial', 'controversial'), - ('new', 'new'), -)) +@pytest.mark.parametrize( + ("test_sort", "expected"), + ( + ("", "hot"), + ("hot", "hot"), + ("controversial", "controversial"), + ("new", "new"), + ), +) def test_create_sort_filter(test_sort: str, expected: str, downloader_mock: MagicMock): downloader_mock.args.sort = test_sort result = RedditConnector.create_sort_filter(downloader_mock) @@ -105,13 +118,16 @@ def test_create_sort_filter(test_sort: str, expected: str, downloader_mock: Magi assert result.name.lower() == expected -@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), ( - ('{POSTID}', '{SUBREDDIT}'), - ('{REDDITOR}_{TITLE}_{POSTID}', '{SUBREDDIT}'), - ('{POSTID}', 'test'), - ('{POSTID}', ''), - ('{POSTID}', '{SUBREDDIT}/{REDDITOR}'), -)) +@pytest.mark.parametrize( + ("test_file_scheme", "test_folder_scheme"), + ( + ("{POSTID}", "{SUBREDDIT}"), + ("{REDDITOR}_{TITLE}_{POSTID}", "{SUBREDDIT}"), + ("{POSTID}", "test"), + ("{POSTID}", ""), + ("{POSTID}", "{SUBREDDIT}/{REDDITOR}"), + ), +) def test_create_file_name_formatter(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock): downloader_mock.args.file_scheme = test_file_scheme downloader_mock.args.folder_scheme = test_folder_scheme @@ -119,14 +135,17 @@ def test_create_file_name_formatter(test_file_scheme: str, test_folder_scheme: s assert isinstance(result, FileNameFormatter) assert result.file_format_string == test_file_scheme - assert result.directory_format_string == test_folder_scheme.split('/') + assert result.directory_format_string == test_folder_scheme.split("/") -@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), ( - ('', ''), - ('', '{SUBREDDIT}'), - ('test', '{SUBREDDIT}'), -)) +@pytest.mark.parametrize( + ("test_file_scheme", "test_folder_scheme"), + ( + ("", ""), + ("", "{SUBREDDIT}"), + ("test", "{SUBREDDIT}"), + ), +) def test_create_file_name_formatter_bad(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock): downloader_mock.args.file_scheme = test_file_scheme downloader_mock.args.folder_scheme = test_folder_scheme @@ -141,15 +160,17 @@ def test_create_authenticator(downloader_mock: MagicMock): @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize('test_submission_ids', ( - ('lvpf4l',), - ('lvpf4l', 'lvqnsn'), - ('lvpf4l', 'lvqnsn', 'lvl9kd'), -)) +@pytest.mark.parametrize( + "test_submission_ids", + ( + ("lvpf4l",), + ("lvpf4l", "lvqnsn"), + ("lvpf4l", "lvqnsn", "lvl9kd"), + ), +) def test_get_submissions_from_link( - test_submission_ids: list[str], - reddit_instance: praw.Reddit, - downloader_mock: MagicMock): + test_submission_ids: list[str], reddit_instance: praw.Reddit, downloader_mock: MagicMock +): downloader_mock.args.link = test_submission_ids downloader_mock.reddit_instance = reddit_instance results = RedditConnector.get_submissions_from_link(downloader_mock) @@ -159,25 +180,28 @@ def test_get_submissions_from_link( @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_subreddits', 'limit', 'sort_type', 'time_filter', 'max_expected_len'), ( - (('Futurology',), 10, 'hot', 'all', 10), - (('Futurology', 'Mindustry, Python'), 10, 'hot', 'all', 30), - (('Futurology',), 20, 'hot', 'all', 20), - (('Futurology', 'Python'), 10, 'hot', 'all', 20), - (('Futurology',), 100, 'hot', 'all', 100), - (('Futurology',), 0, 'hot', 'all', 0), - (('Futurology',), 10, 'top', 'all', 10), - (('Futurology',), 10, 'top', 'week', 10), - (('Futurology',), 10, 'hot', 'week', 10), -)) +@pytest.mark.parametrize( + ("test_subreddits", "limit", "sort_type", "time_filter", "max_expected_len"), + ( + (("Futurology",), 10, "hot", "all", 10), + (("Futurology", "Mindustry, Python"), 10, "hot", "all", 30), + (("Futurology",), 20, "hot", "all", 20), + (("Futurology", "Python"), 10, "hot", "all", 20), + (("Futurology",), 100, "hot", "all", 100), + (("Futurology",), 0, "hot", "all", 0), + (("Futurology",), 10, "top", "all", 10), + (("Futurology",), 10, "top", "week", 10), + (("Futurology",), 10, "hot", "week", 10), + ), +) def test_get_subreddit_normal( - test_subreddits: list[str], - limit: int, - sort_type: str, - time_filter: str, - max_expected_len: int, - downloader_mock: MagicMock, - reddit_instance: praw.Reddit, + test_subreddits: list[str], + limit: int, + sort_type: str, + time_filter: str, + max_expected_len: int, + downloader_mock: MagicMock, + reddit_instance: praw.Reddit, ): downloader_mock.args.limit = limit downloader_mock.args.sort = sort_type @@ -197,26 +221,29 @@ def test_get_subreddit_normal( @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_time', 'test_delta'), ( - ('hour', timedelta(hours=1)), - ('day', timedelta(days=1)), - ('week', timedelta(days=7)), - ('month', timedelta(days=31)), - ('year', timedelta(days=365)), -)) +@pytest.mark.parametrize( + ("test_time", "test_delta"), + ( + ("hour", timedelta(hours=1)), + ("day", timedelta(days=1)), + ("week", timedelta(days=7)), + ("month", timedelta(days=31)), + ("year", timedelta(days=365)), + ), +) def test_get_subreddit_time_verification( - test_time: str, - test_delta: timedelta, - downloader_mock: MagicMock, - reddit_instance: praw.Reddit, + test_time: str, + test_delta: timedelta, + downloader_mock: MagicMock, + reddit_instance: praw.Reddit, ): downloader_mock.args.limit = 10 - downloader_mock.args.sort = 'top' + downloader_mock.args.sort = "top" downloader_mock.args.time = test_time downloader_mock.time_filter = RedditConnector.create_time_filter(downloader_mock) downloader_mock.sort_filter = RedditConnector.create_sort_filter(downloader_mock) downloader_mock.determine_sort_function.return_value = RedditConnector.determine_sort_function(downloader_mock) - downloader_mock.args.subreddit = ['all'] + downloader_mock.args.subreddit = ["all"] downloader_mock.reddit_instance = reddit_instance results = RedditConnector.get_subreddits(downloader_mock) results = [sub for res1 in results for sub in res1] @@ -230,20 +257,23 @@ def test_get_subreddit_time_verification( @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit', 'time_filter', 'max_expected_len'), ( - (('Python',), 'scraper', 10, 'all', 10), - (('Python',), '', 10, 'all', 0), - (('Python',), 'djsdsgewef', 10, 'all', 0), - (('Python',), 'scraper', 10, 'year', 10), -)) +@pytest.mark.parametrize( + ("test_subreddits", "search_term", "limit", "time_filter", "max_expected_len"), + ( + (("Python",), "scraper", 10, "all", 10), + (("Python",), "", 10, "all", 0), + (("Python",), "djsdsgewef", 10, "all", 0), + (("Python",), "scraper", 10, "year", 10), + ), +) def test_get_subreddit_search( - test_subreddits: list[str], - search_term: str, - time_filter: str, - limit: int, - max_expected_len: int, - downloader_mock: MagicMock, - reddit_instance: praw.Reddit, + test_subreddits: list[str], + search_term: str, + time_filter: str, + limit: int, + max_expected_len: int, + downloader_mock: MagicMock, + reddit_instance: praw.Reddit, ): downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock.args.limit = limit @@ -265,17 +295,20 @@ def test_get_subreddit_search( @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_user', 'test_multireddits', 'limit'), ( - ('helen_darten', ('cuteanimalpics',), 10), - ('korfor', ('chess',), 100), -)) +@pytest.mark.parametrize( + ("test_user", "test_multireddits", "limit"), + ( + ("helen_darten", ("cuteanimalpics",), 10), + ("korfor", ("chess",), 100), + ), +) # Good sources at https://www.reddit.com/r/multihub/ def test_get_multireddits_public( - test_user: str, - test_multireddits: list[str], - limit: int, - reddit_instance: praw.Reddit, - downloader_mock: MagicMock, + test_user: str, + test_multireddits: list[str], + limit: int, + reddit_instance: praw.Reddit, + downloader_mock: MagicMock, ): downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock.sort_filter = RedditTypes.SortType.HOT @@ -283,11 +316,10 @@ def test_get_multireddits_public( downloader_mock.args.multireddit = test_multireddits downloader_mock.args.user = [test_user] downloader_mock.reddit_instance = reddit_instance - downloader_mock.create_filtered_listing_generator.return_value = \ - RedditConnector.create_filtered_listing_generator( - downloader_mock, - reddit_instance.multireddit(redditor=test_user, name=test_multireddits[0]), - ) + downloader_mock.create_filtered_listing_generator.return_value = RedditConnector.create_filtered_listing_generator( + downloader_mock, + reddit_instance.multireddit(redditor=test_user, name=test_multireddits[0]), + ) results = RedditConnector.get_multireddits(downloader_mock) results = [sub for res in results for sub in res] assert all([isinstance(res, praw.models.Submission) for res in results]) @@ -297,11 +329,14 @@ def test_get_multireddits_public( @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_user', 'limit'), ( - ('danigirl3694', 10), - ('danigirl3694', 50), - ('CapitanHam', None), -)) +@pytest.mark.parametrize( + ("test_user", "limit"), + ( + ("danigirl3694", 10), + ("danigirl3694", 50), + ("CapitanHam", None), + ), +) def test_get_user_submissions(test_user: str, limit: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit): downloader_mock.args.limit = limit downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot @@ -310,11 +345,10 @@ def test_get_user_submissions(test_user: str, limit: int, downloader_mock: Magic downloader_mock.args.user = [test_user] downloader_mock.authenticated = False downloader_mock.reddit_instance = reddit_instance - downloader_mock.create_filtered_listing_generator.return_value = \ - RedditConnector.create_filtered_listing_generator( - downloader_mock, - reddit_instance.redditor(test_user).submissions, - ) + downloader_mock.create_filtered_listing_generator.return_value = RedditConnector.create_filtered_listing_generator( + downloader_mock, + reddit_instance.redditor(test_user).submissions, + ) results = RedditConnector.get_user_data(downloader_mock) results = assert_all_results_are_submissions(limit, results) assert all([res.author.name == test_user for res in results]) @@ -324,21 +358,24 @@ def test_get_user_submissions(test_user: str, limit: int, downloader_mock: Magic @pytest.mark.online @pytest.mark.reddit @pytest.mark.authenticated -@pytest.mark.parametrize('test_flag', ( - 'upvoted', - 'saved', -)) +@pytest.mark.parametrize( + "test_flag", + ( + "upvoted", + "saved", + ), +) def test_get_user_authenticated_lists( - test_flag: str, - downloader_mock: MagicMock, - authenticated_reddit_instance: praw.Reddit, + test_flag: str, + downloader_mock: MagicMock, + authenticated_reddit_instance: praw.Reddit, ): downloader_mock.args.__dict__[test_flag] = True downloader_mock.reddit_instance = authenticated_reddit_instance downloader_mock.args.limit = 10 downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot downloader_mock.sort_filter = RedditTypes.SortType.HOT - downloader_mock.args.user = [RedditConnector.resolve_user_name(downloader_mock, 'me')] + downloader_mock.args.user = [RedditConnector.resolve_user_name(downloader_mock, "me")] results = RedditConnector.get_user_data(downloader_mock) assert_all_results_are_submissions_or_comments(10, results) @@ -359,54 +396,63 @@ def test_get_subscribed_subreddits(downloader_mock: MagicMock, authenticated_red assert results -@pytest.mark.parametrize(('test_name', 'expected'), ( - ('Mindustry', 'Mindustry'), - ('Futurology', 'Futurology'), - ('r/Mindustry', 'Mindustry'), - ('TrollXChromosomes', 'TrollXChromosomes'), - ('r/TrollXChromosomes', 'TrollXChromosomes'), - ('https://www.reddit.com/r/TrollXChromosomes/', 'TrollXChromosomes'), - ('https://www.reddit.com/r/TrollXChromosomes', 'TrollXChromosomes'), - ('https://www.reddit.com/r/Futurology/', 'Futurology'), - ('https://www.reddit.com/r/Futurology', 'Futurology'), -)) +@pytest.mark.parametrize( + ("test_name", "expected"), + ( + ("Mindustry", "Mindustry"), + ("Futurology", "Futurology"), + ("r/Mindustry", "Mindustry"), + ("TrollXChromosomes", "TrollXChromosomes"), + ("r/TrollXChromosomes", "TrollXChromosomes"), + ("https://www.reddit.com/r/TrollXChromosomes/", "TrollXChromosomes"), + ("https://www.reddit.com/r/TrollXChromosomes", "TrollXChromosomes"), + ("https://www.reddit.com/r/Futurology/", "Futurology"), + ("https://www.reddit.com/r/Futurology", "Futurology"), + ), +) def test_sanitise_subreddit_name(test_name: str, expected: str): result = RedditConnector.sanitise_subreddit_name(test_name) assert result == expected -@pytest.mark.parametrize(('test_subreddit_entries', 'expected'), ( - (['test1', 'test2', 'test3'], {'test1', 'test2', 'test3'}), - (['test1,test2', 'test3'], {'test1', 'test2', 'test3'}), - (['test1, test2', 'test3'], {'test1', 'test2', 'test3'}), - (['test1; test2', 'test3'], {'test1', 'test2', 'test3'}), - (['test1, test2', 'test1,test2,test3', 'test4'], {'test1', 'test2', 'test3', 'test4'}), - ([''], {''}), - (['test'], {'test'}), -)) +@pytest.mark.parametrize( + ("test_subreddit_entries", "expected"), + ( + (["test1", "test2", "test3"], {"test1", "test2", "test3"}), + (["test1,test2", "test3"], {"test1", "test2", "test3"}), + (["test1, test2", "test3"], {"test1", "test2", "test3"}), + (["test1; test2", "test3"], {"test1", "test2", "test3"}), + (["test1, test2", "test1,test2,test3", "test4"], {"test1", "test2", "test3", "test4"}), + ([""], {""}), + (["test"], {"test"}), + ), +) def test_split_subreddit_entries(test_subreddit_entries: list[str], expected: set[str]): results = RedditConnector.split_args_input(test_subreddit_entries) assert results == expected def test_read_submission_ids_from_file(downloader_mock: MagicMock, tmp_path: Path): - test_file = tmp_path / 'test.txt' - test_file.write_text('aaaaaa\nbbbbbb') + test_file = tmp_path / "test.txt" + test_file.write_text("aaaaaa\nbbbbbb") results = RedditConnector.read_id_files([str(test_file)]) - assert results == {'aaaaaa', 'bbbbbb'} + assert results == {"aaaaaa", "bbbbbb"} @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize('test_redditor_name', ( - 'nasa', - 'crowdstrike', - 'HannibalGoddamnit', -)) +@pytest.mark.parametrize( + "test_redditor_name", + ( + "nasa", + "crowdstrike", + "HannibalGoddamnit", + ), +) def test_check_user_existence_good( - test_redditor_name: str, - reddit_instance: praw.Reddit, - downloader_mock: MagicMock, + test_redditor_name: str, + reddit_instance: praw.Reddit, + downloader_mock: MagicMock, ): downloader_mock.reddit_instance = reddit_instance RedditConnector.check_user_existence(downloader_mock, test_redditor_name) @@ -414,42 +460,46 @@ def test_check_user_existence_good( @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize('test_redditor_name', ( - 'lhnhfkuhwreolo', - 'adlkfmnhglojh', -)) +@pytest.mark.parametrize( + "test_redditor_name", + ( + "lhnhfkuhwreolo", + "adlkfmnhglojh", + ), +) def test_check_user_existence_nonexistent( - test_redditor_name: str, - reddit_instance: praw.Reddit, - downloader_mock: MagicMock, + test_redditor_name: str, + reddit_instance: praw.Reddit, + downloader_mock: MagicMock, ): downloader_mock.reddit_instance = reddit_instance - with pytest.raises(BulkDownloaderException, match='Could not find'): + with pytest.raises(BulkDownloaderException, match="Could not find"): RedditConnector.check_user_existence(downloader_mock, test_redditor_name) @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize('test_redditor_name', ( - 'Bree-Boo', -)) +@pytest.mark.parametrize("test_redditor_name", ("Bree-Boo",)) def test_check_user_existence_banned( - test_redditor_name: str, - reddit_instance: praw.Reddit, - downloader_mock: MagicMock, + test_redditor_name: str, + reddit_instance: praw.Reddit, + downloader_mock: MagicMock, ): downloader_mock.reddit_instance = reddit_instance - with pytest.raises(BulkDownloaderException, match='is banned'): + with pytest.raises(BulkDownloaderException, match="is banned"): RedditConnector.check_user_existence(downloader_mock, test_redditor_name) @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_subreddit_name', 'expected_message'), ( - ('donaldtrump', 'cannot be found'), - ('submitters', 'private and cannot be scraped'), - ('lhnhfkuhwreolo', 'does not exist') -)) +@pytest.mark.parametrize( + ("test_subreddit_name", "expected_message"), + ( + ("donaldtrump", "cannot be found"), + ("submitters", "private and cannot be scraped"), + ("lhnhfkuhwreolo", "does not exist"), + ), +) def test_check_subreddit_status_bad(test_subreddit_name: str, expected_message: str, reddit_instance: praw.Reddit): test_subreddit = reddit_instance.subreddit(test_subreddit_name) with pytest.raises(BulkDownloaderException, match=expected_message): @@ -458,12 +508,15 @@ def test_check_subreddit_status_bad(test_subreddit_name: str, expected_message: @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize('test_subreddit_name', ( - 'Python', - 'Mindustry', - 'TrollXChromosomes', - 'all', -)) +@pytest.mark.parametrize( + "test_subreddit_name", + ( + "Python", + "Mindustry", + "TrollXChromosomes", + "all", + ), +) def test_check_subreddit_status_good(test_subreddit_name: str, reddit_instance: praw.Reddit): test_subreddit = reddit_instance.subreddit(test_subreddit_name) RedditConnector.check_subreddit_status(test_subreddit) diff --git a/tests/test_download_filter.py b/tests/test_download_filter.py index ce1b2602..07b7d670 100644 --- a/tests/test_download_filter.py +++ b/tests/test_download_filter.py @@ -11,55 +11,67 @@ @pytest.fixture() def download_filter() -> DownloadFilter: - return DownloadFilter(['mp4', 'mp3'], ['test.com', 'reddit.com', 'img.example.com']) + return DownloadFilter(["mp4", "mp3"], ["test.com", "reddit.com", "img.example.com"]) -@pytest.mark.parametrize(('test_extension', 'expected'), ( - ('.mp4', False), - ('.avi', True), - ('.random.mp3', False), - ('mp4', False), -)) +@pytest.mark.parametrize( + ("test_extension", "expected"), + ( + (".mp4", False), + (".avi", True), + (".random.mp3", False), + ("mp4", False), + ), +) def test_filter_extension(test_extension: str, expected: bool, download_filter: DownloadFilter): result = download_filter._check_extension(test_extension) assert result == expected -@pytest.mark.parametrize(('test_url', 'expected'), ( - ('test.mp4', True), - ('http://reddit.com/test.mp4', False), - ('http://reddit.com/test.gif', False), - ('https://www.example.com/test.mp4', True), - ('https://www.example.com/test.png', True), - ('https://i.example.com/test.png', True), - ('https://img.example.com/test.png', False), - ('https://i.test.com/test.png', False), -)) +@pytest.mark.parametrize( + ("test_url", "expected"), + ( + ("test.mp4", True), + ("http://reddit.com/test.mp4", False), + ("http://reddit.com/test.gif", False), + ("https://www.example.com/test.mp4", True), + ("https://www.example.com/test.png", True), + ("https://i.example.com/test.png", True), + ("https://img.example.com/test.png", False), + ("https://i.test.com/test.png", False), + ), +) def test_filter_domain(test_url: str, expected: bool, download_filter: DownloadFilter): result = download_filter._check_domain(test_url) assert result == expected -@pytest.mark.parametrize(('test_url', 'expected'), ( - ('test.mp4', False), - ('test.gif', True), - ('https://www.example.com/test.mp4', False), - ('https://www.example.com/test.png', True), - ('http://reddit.com/test.mp4', False), - ('http://reddit.com/test.gif', False), -)) +@pytest.mark.parametrize( + ("test_url", "expected"), + ( + ("test.mp4", False), + ("test.gif", True), + ("https://www.example.com/test.mp4", False), + ("https://www.example.com/test.png", True), + ("http://reddit.com/test.mp4", False), + ("http://reddit.com/test.gif", False), + ), +) def test_filter_all(test_url: str, expected: bool, download_filter: DownloadFilter): test_resource = Resource(MagicMock(), test_url, lambda: None) result = download_filter.check_resource(test_resource) assert result == expected -@pytest.mark.parametrize('test_url', ( - 'test.mp3', - 'test.mp4', - 'http://reddit.com/test.mp4', - 't', -)) +@pytest.mark.parametrize( + "test_url", + ( + "test.mp3", + "test.mp4", + "http://reddit.com/test.mp4", + "t", + ), +) def test_filter_empty_filter(test_url: str): download_filter = DownloadFilter() test_resource = Resource(MagicMock(), test_url, lambda: None) diff --git a/tests/test_downloader.py b/tests/test_downloader.py index e92d870a..7b81a853 100644 --- a/tests/test_downloader.py +++ b/tests/test_downloader.py @@ -18,7 +18,7 @@ @pytest.fixture() def args() -> Configuration: args = Configuration() - args.time_format = 'ISO' + args.time_format = "ISO" return args @@ -32,29 +32,32 @@ def downloader_mock(args: Configuration): return downloader_mock -@pytest.mark.parametrize(('test_ids', 'test_excluded', 'expected_len'), ( - (('aaaaaa',), (), 1), - (('aaaaaa',), ('aaaaaa',), 0), - ((), ('aaaaaa',), 0), - (('aaaaaa', 'bbbbbb'), ('aaaaaa',), 1), - (('aaaaaa', 'bbbbbb', 'cccccc'), ('aaaaaa',), 2), -)) -@patch('bdfr.site_downloaders.download_factory.DownloadFactory.pull_lever') +@pytest.mark.parametrize( + ("test_ids", "test_excluded", "expected_len"), + ( + (("aaaaaa",), (), 1), + (("aaaaaa",), ("aaaaaa",), 0), + ((), ("aaaaaa",), 0), + (("aaaaaa", "bbbbbb"), ("aaaaaa",), 1), + (("aaaaaa", "bbbbbb", "cccccc"), ("aaaaaa",), 2), + ), +) +@patch("bdfr.site_downloaders.download_factory.DownloadFactory.pull_lever") def test_excluded_ids( - mock_function: MagicMock, - test_ids: tuple[str], - test_excluded: tuple[str], - expected_len: int, - downloader_mock: MagicMock, + mock_function: MagicMock, + test_ids: tuple[str], + test_excluded: tuple[str], + expected_len: int, + downloader_mock: MagicMock, ): downloader_mock.excluded_submission_ids = test_excluded mock_function.return_value = MagicMock() - mock_function.return_value.__name__ = 'test' + mock_function.return_value.__name__ = "test" test_submissions = [] for test_id in test_ids: m = MagicMock() m.id = test_id - m.subreddit.display_name.return_value = 'https://www.example.com/' + m.subreddit.display_name.return_value = "https://www.example.com/" m.__class__ = praw.models.Submission test_submissions.append(m) downloader_mock.reddit_lists = [test_submissions] @@ -65,32 +68,27 @@ def test_excluded_ids( @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize('test_submission_id', ( - 'm1hqw6', -)) +@pytest.mark.parametrize("test_submission_id", ("m1hqw6",)) def test_mark_hard_link( - test_submission_id: str, - downloader_mock: MagicMock, - tmp_path: Path, - reddit_instance: praw.Reddit + test_submission_id: str, downloader_mock: MagicMock, tmp_path: Path, reddit_instance: praw.Reddit ): downloader_mock.reddit_instance = reddit_instance downloader_mock.args.make_hard_links = True downloader_mock.download_directory = tmp_path - downloader_mock.args.folder_scheme = '' - downloader_mock.args.file_scheme = '{POSTID}' + downloader_mock.args.folder_scheme = "" + downloader_mock.args.file_scheme = "{POSTID}" downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) submission = downloader_mock.reddit_instance.submission(id=test_submission_id) - original = Path(tmp_path, f'{test_submission_id}.png') + original = Path(tmp_path, f"{test_submission_id}.png") RedditDownloader._download_submission(downloader_mock, submission) assert original.exists() - downloader_mock.args.file_scheme = 'test2_{POSTID}' + downloader_mock.args.file_scheme = "test2_{POSTID}" downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) RedditDownloader._download_submission(downloader_mock, submission) test_file_1_stats = original.stat() - test_file_2_inode = Path(tmp_path, f'test2_{test_submission_id}.png').stat().st_ino + test_file_2_inode = Path(tmp_path, f"test2_{test_submission_id}.png").stat().st_ino assert test_file_1_stats.st_nlink == 2 assert test_file_1_stats.st_ino == test_file_2_inode @@ -98,20 +96,18 @@ def test_mark_hard_link( @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_submission_id', 'test_creation_date'), ( - ('ndzz50', 1621204841.0), -)) +@pytest.mark.parametrize(("test_submission_id", "test_creation_date"), (("ndzz50", 1621204841.0),)) def test_file_creation_date( - test_submission_id: str, - test_creation_date: float, - downloader_mock: MagicMock, - tmp_path: Path, - reddit_instance: praw.Reddit + test_submission_id: str, + test_creation_date: float, + downloader_mock: MagicMock, + tmp_path: Path, + reddit_instance: praw.Reddit, ): downloader_mock.reddit_instance = reddit_instance downloader_mock.download_directory = tmp_path - downloader_mock.args.folder_scheme = '' - downloader_mock.args.file_scheme = '{POSTID}' + downloader_mock.args.folder_scheme = "" + downloader_mock.args.file_scheme = "{POSTID}" downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) submission = downloader_mock.reddit_instance.submission(id=test_submission_id) @@ -123,27 +119,25 @@ def test_file_creation_date( def test_search_existing_files(): - results = RedditDownloader.scan_existing_files(Path('.')) + results = RedditDownloader.scan_existing_files(Path(".")) assert len(results.keys()) != 0 @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_submission_id', 'test_hash'), ( - ('m1hqw6', 'a912af8905ae468e0121e9940f797ad7'), -)) +@pytest.mark.parametrize(("test_submission_id", "test_hash"), (("m1hqw6", "a912af8905ae468e0121e9940f797ad7"),)) def test_download_submission_hash_exists( - test_submission_id: str, - test_hash: str, - downloader_mock: MagicMock, - reddit_instance: praw.Reddit, - tmp_path: Path, - capsys: pytest.CaptureFixture + test_submission_id: str, + test_hash: str, + downloader_mock: MagicMock, + reddit_instance: praw.Reddit, + tmp_path: Path, + capsys: pytest.CaptureFixture, ): setup_logging(3) downloader_mock.reddit_instance = reddit_instance downloader_mock.download_filter.check_url.return_value = True - downloader_mock.args.folder_scheme = '' + downloader_mock.args.folder_scheme = "" downloader_mock.args.no_dupes = True downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) downloader_mock.download_directory = tmp_path @@ -153,47 +147,44 @@ def test_download_submission_hash_exists( folder_contents = list(tmp_path.iterdir()) output = capsys.readouterr() assert not folder_contents - assert re.search(r'Resource hash .*? downloaded elsewhere', output.out) + assert re.search(r"Resource hash .*? downloaded elsewhere", output.out) @pytest.mark.online @pytest.mark.reddit def test_download_submission_file_exists( - downloader_mock: MagicMock, - reddit_instance: praw.Reddit, - tmp_path: Path, - capsys: pytest.CaptureFixture + downloader_mock: MagicMock, reddit_instance: praw.Reddit, tmp_path: Path, capsys: pytest.CaptureFixture ): setup_logging(3) downloader_mock.reddit_instance = reddit_instance downloader_mock.download_filter.check_url.return_value = True - downloader_mock.args.folder_scheme = '' + downloader_mock.args.folder_scheme = "" downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) downloader_mock.download_directory = tmp_path - submission = downloader_mock.reddit_instance.submission(id='m1hqw6') - Path(tmp_path, 'Arneeman_Metagaming isn\'t always a bad thing_m1hqw6.png').touch() + submission = downloader_mock.reddit_instance.submission(id="m1hqw6") + Path(tmp_path, "Arneeman_Metagaming isn't always a bad thing_m1hqw6.png").touch() RedditDownloader._download_submission(downloader_mock, submission) folder_contents = list(tmp_path.iterdir()) output = capsys.readouterr() assert len(folder_contents) == 1 - assert 'Arneeman_Metagaming isn\'t always a bad thing_m1hqw6.png'\ - ' from submission m1hqw6 already exists' in output.out + assert ( + "Arneeman_Metagaming isn't always a bad thing_m1hqw6.png" " from submission m1hqw6 already exists" in output.out + ) @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_submission_id', 'expected_files_len'), ( - ('ljyy27', 4), -)) +@pytest.mark.parametrize(("test_submission_id", "expected_files_len"), (("ljyy27", 4),)) def test_download_submission( - test_submission_id: str, - expected_files_len: int, - downloader_mock: MagicMock, - reddit_instance: praw.Reddit, - tmp_path: Path): + test_submission_id: str, + expected_files_len: int, + downloader_mock: MagicMock, + reddit_instance: praw.Reddit, + tmp_path: Path, +): downloader_mock.reddit_instance = reddit_instance downloader_mock.download_filter.check_url.return_value = True - downloader_mock.args.folder_scheme = '' + downloader_mock.args.folder_scheme = "" downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) downloader_mock.download_directory = tmp_path submission = downloader_mock.reddit_instance.submission(id=test_submission_id) @@ -204,103 +195,95 @@ def test_download_submission( @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_submission_id', 'min_score'), ( - ('ljyy27', 1), -)) +@pytest.mark.parametrize(("test_submission_id", "min_score"), (("ljyy27", 1),)) def test_download_submission_min_score_above( - test_submission_id: str, - min_score: int, - downloader_mock: MagicMock, - reddit_instance: praw.Reddit, - tmp_path: Path, - capsys: pytest.CaptureFixture, + test_submission_id: str, + min_score: int, + downloader_mock: MagicMock, + reddit_instance: praw.Reddit, + tmp_path: Path, + capsys: pytest.CaptureFixture, ): setup_logging(3) downloader_mock.reddit_instance = reddit_instance downloader_mock.download_filter.check_url.return_value = True - downloader_mock.args.folder_scheme = '' + downloader_mock.args.folder_scheme = "" downloader_mock.args.min_score = min_score downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) downloader_mock.download_directory = tmp_path submission = downloader_mock.reddit_instance.submission(id=test_submission_id) RedditDownloader._download_submission(downloader_mock, submission) output = capsys.readouterr() - assert 'filtered due to score' not in output.out + assert "filtered due to score" not in output.out @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_submission_id', 'min_score'), ( - ('ljyy27', 25), -)) +@pytest.mark.parametrize(("test_submission_id", "min_score"), (("ljyy27", 25),)) def test_download_submission_min_score_below( - test_submission_id: str, - min_score: int, - downloader_mock: MagicMock, - reddit_instance: praw.Reddit, - tmp_path: Path, - capsys: pytest.CaptureFixture, + test_submission_id: str, + min_score: int, + downloader_mock: MagicMock, + reddit_instance: praw.Reddit, + tmp_path: Path, + capsys: pytest.CaptureFixture, ): setup_logging(3) downloader_mock.reddit_instance = reddit_instance downloader_mock.download_filter.check_url.return_value = True - downloader_mock.args.folder_scheme = '' + downloader_mock.args.folder_scheme = "" downloader_mock.args.min_score = min_score downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) downloader_mock.download_directory = tmp_path submission = downloader_mock.reddit_instance.submission(id=test_submission_id) RedditDownloader._download_submission(downloader_mock, submission) output = capsys.readouterr() - assert 'filtered due to score' in output.out + assert "filtered due to score" in output.out @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_submission_id', 'max_score'), ( - ('ljyy27', 25), -)) +@pytest.mark.parametrize(("test_submission_id", "max_score"), (("ljyy27", 25),)) def test_download_submission_max_score_below( - test_submission_id: str, - max_score: int, - downloader_mock: MagicMock, - reddit_instance: praw.Reddit, - tmp_path: Path, - capsys: pytest.CaptureFixture, + test_submission_id: str, + max_score: int, + downloader_mock: MagicMock, + reddit_instance: praw.Reddit, + tmp_path: Path, + capsys: pytest.CaptureFixture, ): setup_logging(3) downloader_mock.reddit_instance = reddit_instance downloader_mock.download_filter.check_url.return_value = True - downloader_mock.args.folder_scheme = '' + downloader_mock.args.folder_scheme = "" downloader_mock.args.max_score = max_score downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) downloader_mock.download_directory = tmp_path submission = downloader_mock.reddit_instance.submission(id=test_submission_id) RedditDownloader._download_submission(downloader_mock, submission) output = capsys.readouterr() - assert 'filtered due to score' not in output.out + assert "filtered due to score" not in output.out @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_submission_id', 'max_score'), ( - ('ljyy27', 1), -)) +@pytest.mark.parametrize(("test_submission_id", "max_score"), (("ljyy27", 1),)) def test_download_submission_max_score_above( - test_submission_id: str, - max_score: int, - downloader_mock: MagicMock, - reddit_instance: praw.Reddit, - tmp_path: Path, - capsys: pytest.CaptureFixture, + test_submission_id: str, + max_score: int, + downloader_mock: MagicMock, + reddit_instance: praw.Reddit, + tmp_path: Path, + capsys: pytest.CaptureFixture, ): setup_logging(3) downloader_mock.reddit_instance = reddit_instance downloader_mock.download_filter.check_url.return_value = True - downloader_mock.args.folder_scheme = '' + downloader_mock.args.folder_scheme = "" downloader_mock.args.max_score = max_score downloader_mock.file_name_formatter = RedditConnector.create_file_name_formatter(downloader_mock) downloader_mock.download_directory = tmp_path submission = downloader_mock.reddit_instance.submission(id=test_submission_id) RedditDownloader._download_submission(downloader_mock, submission) output = capsys.readouterr() - assert 'filtered due to score' in output.out + assert "filtered due to score" in output.out diff --git a/tests/test_file_name_formatter.py b/tests/test_file_name_formatter.py index 0492536d..c04e07de 100644 --- a/tests/test_file_name_formatter.py +++ b/tests/test_file_name_formatter.py @@ -22,26 +22,26 @@ @pytest.fixture() def submission() -> MagicMock: test = MagicMock() - test.title = 'name' - test.subreddit.display_name = 'randomreddit' - test.author.name = 'person' - test.id = '12345' + test.title = "name" + test.subreddit.display_name = "randomreddit" + test.author.name = "person" + test.id = "12345" test.score = 1000 - test.link_flair_text = 'test_flair' + test.link_flair_text = "test_flair" test.created_utc = datetime(2021, 4, 21, 9, 30, 0).timestamp() test.__class__ = praw.models.Submission return test def do_test_string_equality(result: Union[Path, str], expected: str) -> bool: - if platform.system() == 'Windows': + if platform.system() == "Windows": expected = FileNameFormatter._format_for_windows(expected) return str(result).endswith(expected) def do_test_path_equality(result: Path, expected: str) -> bool: - if platform.system() == 'Windows': - expected = expected.split('/') + if platform.system() == "Windows": + expected = expected.split("/") expected = [FileNameFormatter._format_for_windows(part) for part in expected] expected = Path(*expected) else: @@ -49,35 +49,41 @@ def do_test_path_equality(result: Path, expected: str) -> bool: return str(result).endswith(str(expected)) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def reddit_submission(reddit_instance: praw.Reddit) -> praw.models.Submission: - return reddit_instance.submission(id='w22m5l') - - -@pytest.mark.parametrize(('test_format_string', 'expected'), ( - ('{SUBREDDIT}', 'randomreddit'), - ('{REDDITOR}', 'person'), - ('{POSTID}', '12345'), - ('{UPVOTES}', '1000'), - ('{FLAIR}', 'test_flair'), - ('{DATE}', '2021-04-21T09:30:00'), - ('{REDDITOR}_{TITLE}_{POSTID}', 'person_name_12345'), -)) + return reddit_instance.submission(id="w22m5l") + + +@pytest.mark.parametrize( + ("test_format_string", "expected"), + ( + ("{SUBREDDIT}", "randomreddit"), + ("{REDDITOR}", "person"), + ("{POSTID}", "12345"), + ("{UPVOTES}", "1000"), + ("{FLAIR}", "test_flair"), + ("{DATE}", "2021-04-21T09:30:00"), + ("{REDDITOR}_{TITLE}_{POSTID}", "person_name_12345"), + ), +) def test_format_name_mock(test_format_string: str, expected: str, submission: MagicMock): - test_formatter = FileNameFormatter(test_format_string, '', 'ISO') + test_formatter = FileNameFormatter(test_format_string, "", "ISO") result = test_formatter._format_name(submission, test_format_string) assert do_test_string_equality(result, expected) -@pytest.mark.parametrize(('test_string', 'expected'), ( - ('', False), - ('test', False), - ('{POSTID}', True), - ('POSTID', False), - ('{POSTID}_test', True), - ('test_{TITLE}', True), - ('TITLE_POSTID', False), -)) +@pytest.mark.parametrize( + ("test_string", "expected"), + ( + ("", False), + ("test", False), + ("{POSTID}", True), + ("POSTID", False), + ("{POSTID}_test", True), + ("test_{TITLE}", True), + ("TITLE_POSTID", False), + ), +) def test_check_format_string_validity(test_string: str, expected: bool): result = FileNameFormatter.validate_string(test_string) assert result == expected @@ -85,84 +91,98 @@ def test_check_format_string_validity(test_string: str, expected: bool): @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_format_string', 'expected'), ( - ('{SUBREDDIT}', 'formula1'), - ('{REDDITOR}', 'Kirsty-Blue'), - ('{POSTID}', 'w22m5l'), - ('{FLAIR}', 'Social Media rall'), - ('{SUBREDDIT}_{TITLE}', 'formula1_George Russel acknowledges the Twitter trend about him'), - ('{REDDITOR}_{TITLE}_{POSTID}', 'Kirsty-Blue_George Russel acknowledges the Twitter trend about him_w22m5l') -)) +@pytest.mark.parametrize( + ("test_format_string", "expected"), + ( + ("{SUBREDDIT}", "formula1"), + ("{REDDITOR}", "Kirsty-Blue"), + ("{POSTID}", "w22m5l"), + ("{FLAIR}", "Social Media rall"), + ("{SUBREDDIT}_{TITLE}", "formula1_George Russel acknowledges the Twitter trend about him"), + ("{REDDITOR}_{TITLE}_{POSTID}", "Kirsty-Blue_George Russel acknowledges the Twitter trend about him_w22m5l"), + ), +) def test_format_name_real(test_format_string: str, expected: str, reddit_submission: praw.models.Submission): - test_formatter = FileNameFormatter(test_format_string, '', '') + test_formatter = FileNameFormatter(test_format_string, "", "") result = test_formatter._format_name(reddit_submission, test_format_string) assert do_test_string_equality(result, expected) @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('format_string_directory', 'format_string_file', 'expected'), ( - ( - '{SUBREDDIT}', - '{POSTID}', - 'test/formula1/w22m5l.png', - ), - ( - '{SUBREDDIT}', - '{TITLE}_{POSTID}', - 'test/formula1/George Russel acknowledges the Twitter trend about him_w22m5l.png', - ), +@pytest.mark.parametrize( + ("format_string_directory", "format_string_file", "expected"), ( - '{SUBREDDIT}', - '{REDDITOR}_{TITLE}_{POSTID}', - 'test/formula1/Kirsty-Blue_George Russel acknowledges the Twitter trend about him_w22m5l.png', + ( + "{SUBREDDIT}", + "{POSTID}", + "test/formula1/w22m5l.png", + ), + ( + "{SUBREDDIT}", + "{TITLE}_{POSTID}", + "test/formula1/George Russel acknowledges the Twitter trend about him_w22m5l.png", + ), + ( + "{SUBREDDIT}", + "{REDDITOR}_{TITLE}_{POSTID}", + "test/formula1/Kirsty-Blue_George Russel acknowledges the Twitter trend about him_w22m5l.png", + ), ), -)) +) def test_format_full( - format_string_directory: str, - format_string_file: str, - expected: str, - reddit_submission: praw.models.Submission): - test_resource = Resource(reddit_submission, 'i.reddit.com/blabla.png', lambda: None) - test_formatter = FileNameFormatter(format_string_file, format_string_directory, 'ISO') - result = test_formatter.format_path(test_resource, Path('test')) + format_string_directory: str, format_string_file: str, expected: str, reddit_submission: praw.models.Submission +): + test_resource = Resource(reddit_submission, "i.reddit.com/blabla.png", lambda: None) + test_formatter = FileNameFormatter(format_string_file, format_string_directory, "ISO") + result = test_formatter.format_path(test_resource, Path("test")) assert do_test_path_equality(result, expected) @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('format_string_directory', 'format_string_file'), ( - ('{SUBREDDIT}', '{POSTID}'), - ('{SUBREDDIT}', '{UPVOTES}'), - ('{SUBREDDIT}', '{UPVOTES}{POSTID}'), -)) +@pytest.mark.parametrize( + ("format_string_directory", "format_string_file"), + ( + ("{SUBREDDIT}", "{POSTID}"), + ("{SUBREDDIT}", "{UPVOTES}"), + ("{SUBREDDIT}", "{UPVOTES}{POSTID}"), + ), +) def test_format_full_conform( - format_string_directory: str, - format_string_file: str, - reddit_submission: praw.models.Submission): - test_resource = Resource(reddit_submission, 'i.reddit.com/blabla.png', lambda: None) - test_formatter = FileNameFormatter(format_string_file, format_string_directory, 'ISO') - test_formatter.format_path(test_resource, Path('test')) + format_string_directory: str, format_string_file: str, reddit_submission: praw.models.Submission +): + test_resource = Resource(reddit_submission, "i.reddit.com/blabla.png", lambda: None) + test_formatter = FileNameFormatter(format_string_file, format_string_directory, "ISO") + test_formatter.format_path(test_resource, Path("test")) @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('format_string_directory', 'format_string_file', 'index', 'expected'), ( - ('{SUBREDDIT}', '{POSTID}', None, 'test/formula1/w22m5l.png'), - ('{SUBREDDIT}', '{POSTID}', 1, 'test/formula1/w22m5l_1.png'), - ('{SUBREDDIT}', '{POSTID}', 2, 'test/formula1/w22m5l_2.png'), - ('{SUBREDDIT}', '{TITLE}_{POSTID}', 2, 'test/formula1/George Russel acknowledges the Twitter trend about him_w22m5l_2.png'), -)) +@pytest.mark.parametrize( + ("format_string_directory", "format_string_file", "index", "expected"), + ( + ("{SUBREDDIT}", "{POSTID}", None, "test/formula1/w22m5l.png"), + ("{SUBREDDIT}", "{POSTID}", 1, "test/formula1/w22m5l_1.png"), + ("{SUBREDDIT}", "{POSTID}", 2, "test/formula1/w22m5l_2.png"), + ( + "{SUBREDDIT}", + "{TITLE}_{POSTID}", + 2, + "test/formula1/George Russel acknowledges the Twitter trend about him_w22m5l_2.png", + ), + ), +) def test_format_full_with_index_suffix( - format_string_directory: str, - format_string_file: str, - index: Optional[int], - expected: str, - reddit_submission: praw.models.Submission, + format_string_directory: str, + format_string_file: str, + index: Optional[int], + expected: str, + reddit_submission: praw.models.Submission, ): - test_resource = Resource(reddit_submission, 'i.reddit.com/blabla.png', lambda: None) - test_formatter = FileNameFormatter(format_string_file, format_string_directory, 'ISO') - result = test_formatter.format_path(test_resource, Path('test'), index) + test_resource = Resource(reddit_submission, "i.reddit.com/blabla.png", lambda: None) + test_formatter = FileNameFormatter(format_string_file, format_string_directory, "ISO") + result = test_formatter.format_path(test_resource, Path("test"), index) assert do_test_path_equality(result, expected) @@ -170,99 +190,114 @@ def test_format_multiple_resources(): mocks = [] for i in range(1, 5): new_mock = MagicMock() - new_mock.url = 'https://example.com/test.png' - new_mock.extension = '.png' - new_mock.source_submission.title = 'test' + new_mock.url = "https://example.com/test.png" + new_mock.extension = ".png" + new_mock.source_submission.title = "test" new_mock.source_submission.__class__ = praw.models.Submission mocks.append(new_mock) - test_formatter = FileNameFormatter('{TITLE}', '', 'ISO') - results = test_formatter.format_resource_paths(mocks, Path('.')) + test_formatter = FileNameFormatter("{TITLE}", "", "ISO") + results = test_formatter.format_resource_paths(mocks, Path(".")) results = set([str(res[0].name) for res in results]) - expected = {'test_1.png', 'test_2.png', 'test_3.png', 'test_4.png'} + expected = {"test_1.png", "test_2.png", "test_3.png", "test_4.png"} assert results == expected -@pytest.mark.parametrize(('test_filename', 'test_ending'), ( - ('A' * 300, '.png'), - ('A' * 300, '_1.png'), - ('a' * 300, '_1000.jpeg'), - ('πŸ˜πŸ’•βœ¨' * 100, '_1.png'), -)) +@pytest.mark.parametrize( + ("test_filename", "test_ending"), + ( + ("A" * 300, ".png"), + ("A" * 300, "_1.png"), + ("a" * 300, "_1000.jpeg"), + ("πŸ˜πŸ’•βœ¨" * 100, "_1.png"), + ), +) def test_limit_filename_length(test_filename: str, test_ending: str): - result = FileNameFormatter.limit_file_name_length(test_filename, test_ending, Path('.')) + result = FileNameFormatter.limit_file_name_length(test_filename, test_ending, Path(".")) assert len(result.name) <= 255 - assert len(result.name.encode('utf-8')) <= 255 + assert len(result.name.encode("utf-8")) <= 255 assert len(str(result)) <= FileNameFormatter.find_max_path_length() assert isinstance(result, Path) -@pytest.mark.parametrize(('test_filename', 'test_ending', 'expected_end'), ( - ('test_aaaaaa', '_1.png', 'test_aaaaaa_1.png'), - ('test_aataaa', '_1.png', 'test_aataaa_1.png'), - ('test_abcdef', '_1.png', 'test_abcdef_1.png'), - ('test_aaaaaa', '.png', 'test_aaaaaa.png'), - ('test', '_1.png', 'test_1.png'), - ('test_m1hqw6', '_1.png', 'test_m1hqw6_1.png'), - ('A' * 300 + '_bbbccc', '.png', '_bbbccc.png'), - ('A' * 300 + '_bbbccc', '_1000.jpeg', '_bbbccc_1000.jpeg'), - ('πŸ˜πŸ’•βœ¨' * 100 + '_aaa1aa', '_1.png', '_aaa1aa_1.png'), -)) +@pytest.mark.parametrize( + ("test_filename", "test_ending", "expected_end"), + ( + ("test_aaaaaa", "_1.png", "test_aaaaaa_1.png"), + ("test_aataaa", "_1.png", "test_aataaa_1.png"), + ("test_abcdef", "_1.png", "test_abcdef_1.png"), + ("test_aaaaaa", ".png", "test_aaaaaa.png"), + ("test", "_1.png", "test_1.png"), + ("test_m1hqw6", "_1.png", "test_m1hqw6_1.png"), + ("A" * 300 + "_bbbccc", ".png", "_bbbccc.png"), + ("A" * 300 + "_bbbccc", "_1000.jpeg", "_bbbccc_1000.jpeg"), + ("πŸ˜πŸ’•βœ¨" * 100 + "_aaa1aa", "_1.png", "_aaa1aa_1.png"), + ), +) def test_preserve_id_append_when_shortening(test_filename: str, test_ending: str, expected_end: str): - result = FileNameFormatter.limit_file_name_length(test_filename, test_ending, Path('.')) + result = FileNameFormatter.limit_file_name_length(test_filename, test_ending, Path(".")) assert len(result.name) <= 255 - assert len(result.name.encode('utf-8')) <= 255 + assert len(result.name.encode("utf-8")) <= 255 assert result.name.endswith(expected_end) assert len(str(result)) <= FileNameFormatter.find_max_path_length() -@pytest.mark.skipif(sys.platform == 'win32', reason='Test broken on windows github') +@pytest.mark.skipif(sys.platform == "win32", reason="Test broken on windows github") def test_shorten_filename_real(submission: MagicMock, tmp_path: Path): - submission.title = 'A' * 500 - submission.author.name = 'test' - submission.subreddit.display_name = 'test' - submission.id = 'BBBBBB' - test_resource = Resource(submission, 'www.example.com/empty', lambda: None, '.jpeg') - test_formatter = FileNameFormatter('{REDDITOR}_{TITLE}_{POSTID}', '{SUBREDDIT}', 'ISO') + submission.title = "A" * 500 + submission.author.name = "test" + submission.subreddit.display_name = "test" + submission.id = "BBBBBB" + test_resource = Resource(submission, "www.example.com/empty", lambda: None, ".jpeg") + test_formatter = FileNameFormatter("{REDDITOR}_{TITLE}_{POSTID}", "{SUBREDDIT}", "ISO") result = test_formatter.format_path(test_resource, tmp_path) result.parent.mkdir(parents=True) result.touch() -@pytest.mark.parametrize(('test_name', 'test_ending'), ( - ('a', 'b'), - ('a', '_bbbbbb.jpg'), - ('a' * 20, '_bbbbbb.jpg'), - ('a' * 50, '_bbbbbb.jpg'), - ('a' * 500, '_bbbbbb.jpg'), -)) +@pytest.mark.parametrize( + ("test_name", "test_ending"), + ( + ("a", "b"), + ("a", "_bbbbbb.jpg"), + ("a" * 20, "_bbbbbb.jpg"), + ("a" * 50, "_bbbbbb.jpg"), + ("a" * 500, "_bbbbbb.jpg"), + ), +) def test_shorten_path(test_name: str, test_ending: str, tmp_path: Path): result = FileNameFormatter.limit_file_name_length(test_name, test_ending, tmp_path) assert len(str(result.name)) <= 255 - assert len(str(result.name).encode('UTF-8')) <= 255 - assert len(str(result.name).encode('cp1252')) <= 255 + assert len(str(result.name).encode("UTF-8")) <= 255 + assert len(str(result.name).encode("cp1252")) <= 255 assert len(str(result)) <= FileNameFormatter.find_max_path_length() -@pytest.mark.parametrize(('test_string', 'expected'), ( - ('test', 'test'), - ('test😍', 'test'), - ('test.png', 'test.png'), - ('test*', 'test'), - ('test**', 'test'), - ('test?*', 'test'), - ('test_???.png', 'test_.png'), - ('test_???😍.png', 'test_.png'), -)) +@pytest.mark.parametrize( + ("test_string", "expected"), + ( + ("test", "test"), + ("test😍", "test"), + ("test.png", "test.png"), + ("test*", "test"), + ("test**", "test"), + ("test?*", "test"), + ("test_???.png", "test_.png"), + ("test_???😍.png", "test_.png"), + ), +) def test_format_file_name_for_windows(test_string: str, expected: str): result = FileNameFormatter._format_for_windows(test_string) assert result == expected -@pytest.mark.parametrize(('test_string', 'expected'), ( - ('test', 'test'), - ('test😍', 'test'), - ('😍', ''), -)) +@pytest.mark.parametrize( + ("test_string", "expected"), + ( + ("test", "test"), + ("test😍", "test"), + ("😍", ""), + ), +) def test_strip_emojies(test_string: str, expected: str): result = FileNameFormatter._strip_emojis(test_string) assert result == expected @@ -270,121 +305,151 @@ def test_strip_emojies(test_string: str, expected: str): @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_submission_id', 'expected'), ( - ('mfuteh', { - 'title': 'Why Do Interviewers Ask Linked List Questions?', - 'redditor': 'mjgardner', - }), -)) +@pytest.mark.parametrize( + ("test_submission_id", "expected"), + ( + ( + "mfuteh", + { + "title": "Why Do Interviewers Ask Linked List Questions?", + "redditor": "mjgardner", + }, + ), + ), +) def test_generate_dict_for_submission(test_submission_id: str, expected: dict, reddit_instance: praw.Reddit): test_submission = reddit_instance.submission(id=test_submission_id) - test_formatter = FileNameFormatter('{TITLE}', '', 'ISO') + test_formatter = FileNameFormatter("{TITLE}", "", "ISO") result = test_formatter._generate_name_dict_from_submission(test_submission) assert all([result.get(key) == expected[key] for key in expected.keys()]) @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_comment_id', 'expected'), ( - ('gsq0yuw', { - 'title': 'Why Do Interviewers Ask Linked List Questions?', - 'redditor': 'Doctor-Dapper', - 'postid': 'gsq0yuw', - 'flair': '', - }), -)) +@pytest.mark.parametrize( + ("test_comment_id", "expected"), + ( + ( + "gsq0yuw", + { + "title": "Why Do Interviewers Ask Linked List Questions?", + "redditor": "Doctor-Dapper", + "postid": "gsq0yuw", + "flair": "", + }, + ), + ), +) def test_generate_dict_for_comment(test_comment_id: str, expected: dict, reddit_instance: praw.Reddit): test_comment = reddit_instance.comment(id=test_comment_id) - test_formatter = FileNameFormatter('{TITLE}', '', 'ISO') + test_formatter = FileNameFormatter("{TITLE}", "", "ISO") result = test_formatter._generate_name_dict_from_comment(test_comment) assert all([result.get(key) == expected[key] for key in expected.keys()]) @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme', 'test_comment_id', 'expected_name'), ( - ('{POSTID}', '', 'gsoubde', 'gsoubde.json'), - ('{REDDITOR}_{POSTID}', '', 'gsoubde', 'DELETED_gsoubde.json'), -)) +@pytest.mark.parametrize( + ("test_file_scheme", "test_folder_scheme", "test_comment_id", "expected_name"), + ( + ("{POSTID}", "", "gsoubde", "gsoubde.json"), + ("{REDDITOR}_{POSTID}", "", "gsoubde", "DELETED_gsoubde.json"), + ), +) def test_format_archive_entry_comment( - test_file_scheme: str, - test_folder_scheme: str, - test_comment_id: str, - expected_name: str, - tmp_path: Path, - reddit_instance: praw.Reddit, + test_file_scheme: str, + test_folder_scheme: str, + test_comment_id: str, + expected_name: str, + tmp_path: Path, + reddit_instance: praw.Reddit, ): test_comment = reddit_instance.comment(id=test_comment_id) - test_formatter = FileNameFormatter(test_file_scheme, test_folder_scheme, 'ISO') - test_entry = Resource(test_comment, '', lambda: None, '.json') + test_formatter = FileNameFormatter(test_file_scheme, test_folder_scheme, "ISO") + test_entry = Resource(test_comment, "", lambda: None, ".json") result = test_formatter.format_path(test_entry, tmp_path) assert do_test_string_equality(result, expected_name) -@pytest.mark.parametrize(('test_folder_scheme', 'expected'), ( - ('{REDDITOR}/{SUBREDDIT}', 'person/randomreddit'), - ('{POSTID}/{SUBREDDIT}/{REDDITOR}', '12345/randomreddit/person'), -)) +@pytest.mark.parametrize( + ("test_folder_scheme", "expected"), + ( + ("{REDDITOR}/{SUBREDDIT}", "person/randomreddit"), + ("{POSTID}/{SUBREDDIT}/{REDDITOR}", "12345/randomreddit/person"), + ), +) def test_multilevel_folder_scheme( - test_folder_scheme: str, - expected: str, - tmp_path: Path, - submission: MagicMock, + test_folder_scheme: str, + expected: str, + tmp_path: Path, + submission: MagicMock, ): - test_formatter = FileNameFormatter('{POSTID}', test_folder_scheme, 'ISO') + test_formatter = FileNameFormatter("{POSTID}", test_folder_scheme, "ISO") test_resource = MagicMock() test_resource.source_submission = submission - test_resource.extension = '.png' + test_resource.extension = ".png" result = test_formatter.format_path(test_resource, tmp_path) result = result.relative_to(tmp_path) assert do_test_path_equality(result.parent, expected) - assert len(result.parents) == (len(expected.split('/')) + 1) + assert len(result.parents) == (len(expected.split("/")) + 1) -@pytest.mark.parametrize(('test_name_string', 'expected'), ( - ('test', 'test'), - ('😍', '😍'), - ('test😍', 'test😍'), - ('test😍 ’', 'test😍 ’'), - ('test😍 \\u2019', 'test😍 ’'), - ('Using that real good [1\\4]', 'Using that real good [1\\4]'), -)) +@pytest.mark.parametrize( + ("test_name_string", "expected"), + ( + ("test", "test"), + ("😍", "😍"), + ("test😍", "test😍"), + ("test😍 ’", "test😍 ’"), + ("test😍 \\u2019", "test😍 ’"), + ("Using that real good [1\\4]", "Using that real good [1\\4]"), + ), +) def test_preserve_emojis(test_name_string: str, expected: str, submission: MagicMock): submission.title = test_name_string - test_formatter = FileNameFormatter('{TITLE}', '', 'ISO') - result = test_formatter._format_name(submission, '{TITLE}') + test_formatter = FileNameFormatter("{TITLE}", "", "ISO") + result = test_formatter._format_name(submission, "{TITLE}") assert do_test_string_equality(result, expected) -@pytest.mark.parametrize(('test_string', 'expected'), ( - ('test \\u2019', 'test ’'), - ('My cat\\u2019s paws are so cute', 'My cat’s paws are so cute'), -)) +@pytest.mark.parametrize( + ("test_string", "expected"), + ( + ("test \\u2019", "test ’"), + ("My cat\\u2019s paws are so cute", "My cat’s paws are so cute"), + ), +) def test_convert_unicode_escapes(test_string: str, expected: str): result = FileNameFormatter._convert_unicode_escapes(test_string) assert result == expected -@pytest.mark.parametrize(('test_datetime', 'expected'), ( - (datetime(2020, 1, 1, 8, 0, 0), '2020-01-01T08:00:00'), - (datetime(2020, 1, 1, 8, 0), '2020-01-01T08:00:00'), - (datetime(2021, 4, 21, 8, 30, 21), '2021-04-21T08:30:21'), -)) +@pytest.mark.parametrize( + ("test_datetime", "expected"), + ( + (datetime(2020, 1, 1, 8, 0, 0), "2020-01-01T08:00:00"), + (datetime(2020, 1, 1, 8, 0), "2020-01-01T08:00:00"), + (datetime(2021, 4, 21, 8, 30, 21), "2021-04-21T08:30:21"), + ), +) def test_convert_timestamp(test_datetime: datetime, expected: str): test_timestamp = test_datetime.timestamp() - test_formatter = FileNameFormatter('{POSTID}', '', 'ISO') + test_formatter = FileNameFormatter("{POSTID}", "", "ISO") result = test_formatter._convert_timestamp(test_timestamp) assert result == expected -@pytest.mark.parametrize(('test_time_format', 'expected'), ( - ('ISO', '2021-05-02T13:33:00'), - ('%Y_%m', '2021_05'), - ('%Y-%m-%d', '2021-05-02'), -)) +@pytest.mark.parametrize( + ("test_time_format", "expected"), + ( + ("ISO", "2021-05-02T13:33:00"), + ("%Y_%m", "2021_05"), + ("%Y-%m-%d", "2021-05-02"), + ), +) def test_time_string_formats(test_time_format: str, expected: str): test_time = datetime(2021, 5, 2, 13, 33) - test_formatter = FileNameFormatter('{TITLE}', '', test_time_format) + test_formatter = FileNameFormatter("{TITLE}", "", test_time_format) result = test_formatter._convert_timestamp(test_time.timestamp()) assert result == expected @@ -395,29 +460,32 @@ def test_get_max_path_length(): def test_windows_max_path(tmp_path: Path): - with unittest.mock.patch('platform.system', return_value='Windows'): - with unittest.mock.patch('bdfr.file_name_formatter.FileNameFormatter.find_max_path_length', return_value=260): - result = FileNameFormatter.limit_file_name_length('test' * 100, '_1.png', tmp_path) + with unittest.mock.patch("platform.system", return_value="Windows"): + with unittest.mock.patch("bdfr.file_name_formatter.FileNameFormatter.find_max_path_length", return_value=260): + result = FileNameFormatter.limit_file_name_length("test" * 100, "_1.png", tmp_path) assert len(str(result)) <= 260 assert len(result.name) <= (260 - len(str(tmp_path))) @pytest.mark.online @pytest.mark.reddit -@pytest.mark.parametrize(('test_reddit_id', 'test_downloader', 'expected_names'), ( - ('gphmnr', YtdlpFallback, {'He has a lot to say today.mp4'}), - ('d0oir2', YtdlpFallback, {"Crunk's finest moment. Welcome to the new subreddit!.mp4"}), - ('jiecu', SelfPost, {'[deleted by user].txt'}), -)) +@pytest.mark.parametrize( + ("test_reddit_id", "test_downloader", "expected_names"), + ( + ("gphmnr", YtdlpFallback, {"He has a lot to say today.mp4"}), + ("d0oir2", YtdlpFallback, {"Crunk's finest moment. Welcome to the new subreddit!.mp4"}), + ("jiecu", SelfPost, {"[deleted by user].txt"}), + ), +) def test_name_submission( - test_reddit_id: str, - test_downloader: Type[BaseDownloader], - expected_names: set[str], - reddit_instance: praw.reddit.Reddit, + test_reddit_id: str, + test_downloader: Type[BaseDownloader], + expected_names: set[str], + reddit_instance: praw.reddit.Reddit, ): test_submission = reddit_instance.submission(id=test_reddit_id) test_resources = test_downloader(test_submission).find_resources() - test_formatter = FileNameFormatter('{TITLE}', '', '') - results = test_formatter.format_resource_paths(test_resources, Path('.')) + test_formatter = FileNameFormatter("{TITLE}", "", "") + results = test_formatter.format_resource_paths(test_resources, Path(".")) results = set([r[0].name for r in results]) assert results == expected_names diff --git a/tests/test_oauth2.py b/tests/test_oauth2.py index 71bdca13..3014c374 100644 --- a/tests/test_oauth2.py +++ b/tests/test_oauth2.py @@ -14,38 +14,58 @@ @pytest.fixture() def example_config() -> configparser.ConfigParser: out = configparser.ConfigParser() - config_dict = {'DEFAULT': {'user_token': 'example'}} + config_dict = {"DEFAULT": {"user_token": "example"}} out.read_dict(config_dict) return out @pytest.mark.online -@pytest.mark.parametrize('test_scopes', ( - {'history', }, - {'history', 'creddits'}, - {'account', 'flair'}, - {'*', }, -)) +@pytest.mark.parametrize( + "test_scopes", + ( + { + "history", + }, + {"history", "creddits"}, + {"account", "flair"}, + { + "*", + }, + ), +) def test_check_scopes(test_scopes: set[str]): OAuth2Authenticator._check_scopes(test_scopes) -@pytest.mark.parametrize(('test_scopes', 'expected'), ( - ('history', {'history', }), - ('history creddits', {'history', 'creddits'}), - ('history, creddits, account', {'history', 'creddits', 'account'}), - ('history,creddits,account,flair', {'history', 'creddits', 'account', 'flair'}), -)) +@pytest.mark.parametrize( + ("test_scopes", "expected"), + ( + ( + "history", + { + "history", + }, + ), + ("history creddits", {"history", "creddits"}), + ("history, creddits, account", {"history", "creddits", "account"}), + ("history,creddits,account,flair", {"history", "creddits", "account", "flair"}), + ), +) def test_split_scopes(test_scopes: str, expected: set[str]): result = OAuth2Authenticator.split_scopes(test_scopes) assert result == expected @pytest.mark.online -@pytest.mark.parametrize('test_scopes', ( - {'random', }, - {'scope', 'another_scope'}, -)) +@pytest.mark.parametrize( + "test_scopes", + ( + { + "random", + }, + {"scope", "another_scope"}, + ), +) def test_check_scopes_bad(test_scopes: set[str]): with pytest.raises(BulkDownloaderException): OAuth2Authenticator._check_scopes(test_scopes) @@ -56,16 +76,16 @@ def test_token_manager_read(example_config: configparser.ConfigParser): mock_authoriser.refresh_token = None test_manager = OAuth2TokenManager(example_config, MagicMock()) test_manager.pre_refresh_callback(mock_authoriser) - assert mock_authoriser.refresh_token == example_config.get('DEFAULT', 'user_token') + assert mock_authoriser.refresh_token == example_config.get("DEFAULT", "user_token") def test_token_manager_write(example_config: configparser.ConfigParser, tmp_path: Path): - test_path = tmp_path / 'test.cfg' + test_path = tmp_path / "test.cfg" mock_authoriser = MagicMock() - mock_authoriser.refresh_token = 'changed_token' + mock_authoriser.refresh_token = "changed_token" test_manager = OAuth2TokenManager(example_config, test_path) test_manager.post_refresh_callback(mock_authoriser) - assert example_config.get('DEFAULT', 'user_token') == 'changed_token' - with test_path.open('r') as file: + assert example_config.get("DEFAULT", "user_token") == "changed_token" + with test_path.open("r") as file: file_contents = file.read() - assert 'user_token = changed_token' in file_contents + assert "user_token = changed_token" in file_contents diff --git a/tests/test_resource.py b/tests/test_resource.py index f3bbc9ac..146d9a02 100644 --- a/tests/test_resource.py +++ b/tests/test_resource.py @@ -8,18 +8,21 @@ from bdfr.resource import Resource -@pytest.mark.parametrize(('test_url', 'expected'), ( - ('test.png', '.png'), - ('another.mp4', '.mp4'), - ('test.jpeg', '.jpeg'), - ('http://www.random.com/resource.png', '.png'), - ('https://www.resource.com/test/example.jpg', '.jpg'), - ('hard.png.mp4', '.mp4'), - ('https://preview.redd.it/7zkmr1wqqih61.png?width=237&format=png&auto=webp&s=19de214e634cbcad99', '.png'), - ('test.jpg#test', '.jpg'), - ('test.jpg?width=247#test', '.jpg'), - ('https://www.test.com/test/test2/example.png?random=test#thing', '.png'), -)) +@pytest.mark.parametrize( + ("test_url", "expected"), + ( + ("test.png", ".png"), + ("another.mp4", ".mp4"), + ("test.jpeg", ".jpeg"), + ("http://www.random.com/resource.png", ".png"), + ("https://www.resource.com/test/example.jpg", ".jpg"), + ("hard.png.mp4", ".mp4"), + ("https://preview.redd.it/7zkmr1wqqih61.png?width=237&format=png&auto=webp&s=19de214e634cbcad99", ".png"), + ("test.jpg#test", ".jpg"), + ("test.jpg?width=247#test", ".jpg"), + ("https://www.test.com/test/test2/example.png?random=test#thing", ".png"), + ), +) def test_resource_get_extension(test_url: str, expected: str): test_resource = Resource(MagicMock(), test_url, lambda: None) result = test_resource._determine_extension() @@ -27,9 +30,10 @@ def test_resource_get_extension(test_url: str, expected: str): @pytest.mark.online -@pytest.mark.parametrize(('test_url', 'expected_hash'), ( - ('https://www.iana.org/_img/2013.1/iana-logo-header.svg', '426b3ac01d3584c820f3b7f5985d6623'), -)) +@pytest.mark.parametrize( + ("test_url", "expected_hash"), + (("https://www.iana.org/_img/2013.1/iana-logo-header.svg", "426b3ac01d3584c820f3b7f5985d6623"),), +) def test_download_online_resource(test_url: str, expected_hash: str): test_resource = Resource(MagicMock(), test_url, Resource.retry_download(test_url)) test_resource.download() diff --git a/tox.ini b/tox.ini new file mode 100644 index 00000000..e5dce995 --- /dev/null +++ b/tox.ini @@ -0,0 +1,16 @@ +[tox] +envlist = + format + +[testenv:format] +deps = + isort + black +skip_install = True +commands = + isort bdfr tests + black bdfr tests --line-length 120 + +[isort] +profile = black +multi_line_output = 3 \ No newline at end of file