Skip to content

Commit

Permalink
added docstring templating
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-sanche committed Sep 3, 2024
1 parent 792abd9 commit 18854f0
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions google/cloud/bigtable/data/_sync/cross_sync/_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ def _convert_ast_to_py(cls, ast_node: ast.expr | None) -> Any:
return ast_node.value
if isinstance(ast_node, ast.List):
return [cls._convert_ast_to_py(node) for node in ast_node.elts]
if isinstance(ast_node, ast.Tuple):
return tuple(cls._convert_ast_to_py(node) for node in ast_node.elts)
if isinstance(ast_node, ast.Dict):
return {
cls._convert_ast_to_py(k): cls._convert_ast_to_py(v)
Expand All @@ -175,6 +177,7 @@ class ExportSync(AstDecorator):
Args:
path: path to output the generated sync class
replace_symbols: a dict of symbols and replacements to use when generating sync class
docstring_format_vars: a dict of variables to replace in the docstring
mypy_ignore: set of mypy errors to ignore in the generated file
include_file_imports: if True, include top-level imports from the file in the generated sync class
add_mapping_for_name: when given, will add a new attribute to CrossSync, so the original class and its sync version can be accessed from CrossSync.<name>
Expand All @@ -185,12 +188,16 @@ def __init__(
path: str,
*,
replace_symbols: dict[str, str] | None = None,
docstring_format_vars: dict[str, tuple[str, str]] | None = None,
mypy_ignore: Sequence[str] = (),
include_file_imports: bool = True,
add_mapping_for_name: str | None = None,
):
self.path = path
self.replace_symbols = replace_symbols
docstring_format_vars = docstring_format_vars or {}
self.async_docstring_format_vars = {k: v[0] for k, v in docstring_format_vars.items()}
self.sync_docstring_format_vars = {k: v[1] for k, v in docstring_format_vars.items()}
self.mypy_ignore = mypy_ignore
self.include_file_imports = include_file_imports
self.add_mapping_for_name = add_mapping_for_name
Expand All @@ -206,6 +213,8 @@ def async_decorator(self):
def decorator(cls):
if new_mapping:
CrossSync.add_mapping(new_mapping, cls)
if self.async_docstring_format_vars:
cls.__doc__ = cls.__doc__.format(**self.async_docstring_format_vars)
return cls

return decorator
Expand Down Expand Up @@ -257,6 +266,9 @@ def sync_ast_transform(self, wrapped_node, transformers_globals):
wrapped_node = transformers_globals["CrossSyncMethodDecoratorHandler"]().visit(
wrapped_node
)
if self.sync_docstring_format_vars:
docstring = ast.get_docstring(wrapped_node)
wrapped_node.body[0].value.s = docstring.format(**self.sync_docstring_format_vars)
return wrapped_node


Expand All @@ -267,6 +279,7 @@ class Convert(AstDecorator):
Args:
sync_name: use a new name for the sync method
replace_symbols: a dict of symbols and replacements to use when generating sync method
docstring_format_vars: a dict of variables to replace in the docstring
rm_aio: if True, automatically strip all asyncio keywords from method. If False,
only the signature `async def` is stripped. Other keywords must be wrapped in
CrossSync.rm_aio() calls to be removed.
Expand All @@ -277,10 +290,14 @@ def __init__(
*,
sync_name: str | None = None,
replace_symbols: dict[str, str] | None = None,
docstring_format_vars: dict[str, tuple[str, str]] | None = None,
rm_aio: bool = False,
):
self.sync_name = sync_name
self.replace_symbols = replace_symbols
docstring_format_vars = docstring_format_vars or {}
self.async_docstring_format_vars = {k: v[0] for k, v in docstring_format_vars.items()}
self.sync_docstring_format_vars = {k: v[1] for k, v in docstring_format_vars.items()}
self.rm_aio = rm_aio

def sync_ast_transform(self, wrapped_node, transformers_globals):
Expand Down Expand Up @@ -310,8 +327,25 @@ def sync_ast_transform(self, wrapped_node, transformers_globals):
if self.replace_symbols:
replacer = transformers_globals["SymbolReplacer"]
wrapped_node = replacer(self.replace_symbols).visit(wrapped_node)
# update docstring if specified
if self.sync_docstring_format_vars:
docstring = ast.get_docstring(wrapped_node)
wrapped_node.body[0].value.s = docstring.format(**self.sync_docstring_format_vars)
return wrapped_node

def async_decorator(self):
"""
If docstring_format_vars are provided, update the docstring of the async method
"""

if self.async_docstring_format_vars:
def decorator(f):
f.__doc__ = f.__doc__.format(**self.async_docstring_format_vars)
return f
return decorator
else:
return None


class DropMethod(AstDecorator):
"""
Expand Down

0 comments on commit 18854f0

Please sign in to comment.