From 7cd21f3d9f400f405c61371f8ab9bd12c74744a6 Mon Sep 17 00:00:00 2001 From: Scott Lessans Date: Thu, 21 Dec 2023 12:58:40 -0800 Subject: [PATCH 1/2] fix to_snake_case for all caps --- ariadne_codegen/utils.py | 4 +++- tests/test_utils.py | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/ariadne_codegen/utils.py b/ariadne_codegen/utils.py index 19433ed8..063fb50a 100644 --- a/ariadne_codegen/utils.py +++ b/ariadne_codegen/utils.py @@ -39,8 +39,10 @@ def str_to_snake_case(name: str) -> str: # upper-case letters, excluding last letter if it is followed by a lower-case letter uppercase_words = r"[A-Z]+(?=[A-Z][a-z]|\d|\W|$)" numbers = r"\d+" + # match upper-case only segments, this is placed after uppercase_words so that match gets priority + uppercase_only = r"[A-Z]+" - words = re.findall(rf"{lowercase_words}|{uppercase_words}|{numbers}", name) + words = re.findall(rf"{lowercase_words}|{uppercase_words}|{numbers}|{uppercase_only}", name) return "_".join(map(str.lower, words)) diff --git a/tests/test_utils.py b/tests/test_utils.py index 7190e721..6681f5ad 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -115,6 +115,9 @@ class TestClass(Efg): ("testWORD123", "test_word_123"), ("TestWORD123", "test_word_123"), ("TESTWord123", "test_word_123"), + ("ASC_NULLS_FIRST", "asc_nulls_first"), + ("DESC", "desc"), + ("DESC_NULLS_LAST", "desc_nulls_last"), ], ) def test_str_to_snake_case_returns_correct_string(name, expected_result): From 7eae2168fc6f85c79a2218f1a9d458bd0642b4ba Mon Sep 17 00:00:00 2001 From: Scott Lessans Date: Thu, 21 Dec 2023 12:59:18 -0800 Subject: [PATCH 2/2] added convert_enums_names_to_upper_snake_case option and modified enum generator to also call process_name --- ariadne_codegen/client_generators/enums.py | 21 ++++++++++++++++---- ariadne_codegen/client_generators/package.py | 7 ++++++- ariadne_codegen/settings.py | 7 +++++++ 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/ariadne_codegen/client_generators/enums.py b/ariadne_codegen/client_generators/enums.py index 8ae34d73..39ce6da4 100644 --- a/ariadne_codegen/client_generators/enums.py +++ b/ariadne_codegen/client_generators/enums.py @@ -12,15 +12,21 @@ generate_module, ) from ..plugins.manager import PluginManager +from ..utils import process_name from .constants import ENUM_CLASS, ENUM_MODULE class EnumsGenerator: def __init__( - self, schema: GraphQLSchema, plugin_manager: Optional[PluginManager] = None + self, schema: GraphQLSchema, + convert_to_snake_case: bool = False, + convert_to_upper_snake_case: bool = False, + plugin_manager: Optional[PluginManager] = None ) -> None: self.schema = schema self.plugin_manager = plugin_manager + self.convert_to_snake_case = convert_to_snake_case + self.convert_to_upper_snake_case = convert_to_upper_snake_case self._generated_public_names: List[str] = [] self._imports: List[ast.ImportFrom] = [ @@ -32,14 +38,14 @@ def __init__( def generate(self, types_to_include: Optional[List[str]] = None) -> ast.Module: class_defs = self._filter_class_defs(types_to_include) - self._generated_public_names = [class_def.name for class_def in class_defs] + self._generated_public_names = [class_def.name for class_def in class_defs] module = generate_module( body=cast(List[ast.stmt], self._imports) + cast(List[ast.stmt], class_defs) ) if self.plugin_manager: - module = self.plugin_manager.generate_enums_module(module) + module = self.plugin_manager.generate_enums_module(module) return module @@ -58,7 +64,14 @@ def _parse_enum_definition(self, definition: GraphQLEnumType) -> ast.ClassDef: for lineno, (val_name, val_def) in enumerate( definition.values.items(), start=1 ): - name = val_name if not iskeyword(val_name) else val_name + "_" + name = process_name( + val_name, + convert_to_snake_case=self.convert_to_snake_case or self.convert_to_upper_snake_case, + plugin_manager=self.plugin_manager, + node=val_def.ast_node, + ) + if self.convert_to_upper_snake_case: + name = name.upper() fields.append( generate_assign([name], generate_constant(val_def.value), lineno) ) diff --git a/ariadne_codegen/client_generators/package.py b/ariadne_codegen/client_generators/package.py index e17ec5e6..0d613342 100644 --- a/ariadne_codegen/client_generators/package.py +++ b/ariadne_codegen/client_generators/package.py @@ -364,7 +364,12 @@ def get_package_generator( custom_scalars=settings.scalars, plugin_manager=plugin_manager, ) - enums_generator = EnumsGenerator(schema=schema, plugin_manager=plugin_manager) + enums_generator = EnumsGenerator( + schema=schema, + plugin_manager=plugin_manager, + convert_to_snake_case=settings.convert_to_snake_case, + convert_to_upper_snake_case=settings.convert_enums_names_to_upper_snake_case + ) input_types_generator = InputTypesGenerator( schema=schema, enums_module=settings.enums_module_name, diff --git a/ariadne_codegen/settings.py b/ariadne_codegen/settings.py index 96c967a0..a1d2d74d 100644 --- a/ariadne_codegen/settings.py +++ b/ariadne_codegen/settings.py @@ -65,6 +65,7 @@ class ClientSettings(BaseSettings): fragments_module_name: str = "fragments" include_comments: CommentsStrategy = field(default=CommentsStrategy.STABLE) convert_to_snake_case: bool = True + convert_enums_names_to_upper_snake_case: bool = False include_all_inputs: bool = True include_all_enums: bool = True async_client: bool = True @@ -148,6 +149,11 @@ def used_settings_message(self) -> str: if self.convert_to_snake_case else "Not converting fields and arguments name to snake case." ) + enum_upper_snake_case_msg = ( + "Converting enum names to upper snake case." + if self.convert_enums_names_to_upper_snake_case + else "Not converting enum names to upper snake case." + ) async_client_msg = ( "Generating async client." if self.async_client @@ -180,6 +186,7 @@ def used_settings_message(self) -> str: Generating fragments into '{self.fragments_module_name}.py'. Comments type: {self.include_comments.value} {snake_case_msg} + {enum_upper_snake_case_msg} {async_client_msg} {files_to_include_msg} {plugins_msg}