diff --git a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py index 2788ffec4..bf4d855de 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/_decorators.py @@ -160,6 +160,8 @@ def _convert_ast_to_py(cls, ast_node: ast.expr | None) -> Any: return ast_node.value if isinstance(ast_node, ast.List): return [cls._convert_ast_to_py(node) for node in ast_node.elts] + if isinstance(ast_node, ast.Tuple): + return tuple(cls._convert_ast_to_py(node) for node in ast_node.elts) if isinstance(ast_node, ast.Dict): return { cls._convert_ast_to_py(k): cls._convert_ast_to_py(v) @@ -175,6 +177,7 @@ class ExportSync(AstDecorator): Args: path: path to output the generated sync class replace_symbols: a dict of symbols and replacements to use when generating sync class + docstring_format_vars: a dict of variables to replace in the docstring mypy_ignore: set of mypy errors to ignore in the generated file include_file_imports: if True, include top-level imports from the file in the generated sync class add_mapping_for_name: when given, will add a new attribute to CrossSync, so the original class and its sync version can be accessed from CrossSync. @@ -185,12 +188,16 @@ def __init__( path: str, *, replace_symbols: dict[str, str] | None = None, + docstring_format_vars: dict[str, tuple[str, str]] | None = None, mypy_ignore: Sequence[str] = (), include_file_imports: bool = True, add_mapping_for_name: str | None = None, ): self.path = path self.replace_symbols = replace_symbols + docstring_format_vars = docstring_format_vars or {} + self.async_docstring_format_vars = {k: v[0] for k, v in docstring_format_vars.items()} + self.sync_docstring_format_vars = {k: v[1] for k, v in docstring_format_vars.items()} self.mypy_ignore = mypy_ignore self.include_file_imports = include_file_imports self.add_mapping_for_name = add_mapping_for_name @@ -206,6 +213,8 @@ def async_decorator(self): def decorator(cls): if new_mapping: CrossSync.add_mapping(new_mapping, cls) + if self.async_docstring_format_vars: + cls.__doc__ = cls.__doc__.format(**self.async_docstring_format_vars) return cls return decorator @@ -257,6 +266,9 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): wrapped_node = transformers_globals["CrossSyncMethodDecoratorHandler"]().visit( wrapped_node ) + if self.sync_docstring_format_vars: + docstring = ast.get_docstring(wrapped_node) + wrapped_node.body[0].value.s = docstring.format(**self.sync_docstring_format_vars) return wrapped_node @@ -267,6 +279,7 @@ class Convert(AstDecorator): Args: sync_name: use a new name for the sync method replace_symbols: a dict of symbols and replacements to use when generating sync method + docstring_format_vars: a dict of variables to replace in the docstring rm_aio: if True, automatically strip all asyncio keywords from method. If False, only the signature `async def` is stripped. Other keywords must be wrapped in CrossSync.rm_aio() calls to be removed. @@ -277,10 +290,14 @@ def __init__( *, sync_name: str | None = None, replace_symbols: dict[str, str] | None = None, + docstring_format_vars: dict[str, tuple[str, str]] | None = None, rm_aio: bool = False, ): self.sync_name = sync_name self.replace_symbols = replace_symbols + docstring_format_vars = docstring_format_vars or {} + self.async_docstring_format_vars = {k: v[0] for k, v in docstring_format_vars.items()} + self.sync_docstring_format_vars = {k: v[1] for k, v in docstring_format_vars.items()} self.rm_aio = rm_aio def sync_ast_transform(self, wrapped_node, transformers_globals): @@ -310,8 +327,25 @@ def sync_ast_transform(self, wrapped_node, transformers_globals): if self.replace_symbols: replacer = transformers_globals["SymbolReplacer"] wrapped_node = replacer(self.replace_symbols).visit(wrapped_node) + # update docstring if specified + if self.sync_docstring_format_vars: + docstring = ast.get_docstring(wrapped_node) + wrapped_node.body[0].value.s = docstring.format(**self.sync_docstring_format_vars) return wrapped_node + def async_decorator(self): + """ + If docstring_format_vars are provided, update the docstring of the async method + """ + + if self.async_docstring_format_vars: + def decorator(f): + f.__doc__ = f.__doc__.format(**self.async_docstring_format_vars) + return f + return decorator + else: + return None + class DropMethod(AstDecorator): """