diff --git a/procrastinate/sql/__init__.py b/procrastinate/sql/__init__.py index 73c002540..10539224d 100644 --- a/procrastinate/sql/__init__.py +++ b/procrastinate/sql/__init__.py @@ -1,6 +1,8 @@ import re import sys -from typing import Dict +from typing import Dict, cast + +from typing_extensions import LiteralString # https://github.com/pypa/twine/pull/551 if sys.version_info[:2] < (3, 9): # coverage: exclude @@ -11,7 +13,7 @@ QUERIES_REGEX = re.compile(r"(?:\n|^)-- ([a-z0-9_]+) --\n(?:-- .+\n)*", re.MULTILINE) -def parse_query_file(query_file: str) -> Dict["str", "str"]: +def parse_query_file(query_file: str) -> Dict["str", LiteralString]: split = iter(QUERIES_REGEX.split(query_file)) next(split) # Consume the header of the file result = {} @@ -19,13 +21,16 @@ def parse_query_file(query_file: str) -> Dict["str", "str"]: while True: key = next(split) value = next(split).strip() - result[key] = value + # procrastinate takes full responsibility for the queries, we + # can safely vouch for them being as safe as if they were + # defined in the code itself. + result[key] = cast(LiteralString, value) except StopIteration: pass return result -def get_queries() -> Dict["str", "str"]: +def get_queries() -> Dict["str", LiteralString]: return parse_query_file( (importlib_resources.files("procrastinate.sql") / "queries.sql").read_text( encoding="utf-8"