From effb3092339b483ac264a427385054c95a249c37 Mon Sep 17 00:00:00 2001 From: "m.kindritskiy" Date: Wed, 1 Nov 2023 11:04:02 +0200 Subject: [PATCH] export scalars in sdl --- hiku/federation/scalars.py | 15 +++++++++++++- hiku/federation/sdl.py | 29 ++++++++++++++++++++++------ tests/test_federation/test_sdl_v1.py | 2 ++ tests/test_federation/test_sdl_v2.py | 24 ++++++++++++++++++++++- 4 files changed, 62 insertions(+), 8 deletions(-) diff --git a/hiku/federation/scalars.py b/hiku/federation/scalars.py index 15d77279..b1f7c99b 100644 --- a/hiku/federation/scalars.py +++ b/hiku/federation/scalars.py @@ -1,11 +1,21 @@ -from typing import Any +from typing import Any, Callable, List, Type from hiku.scalar import Scalar, scalar +def federation_version(versions: List[int]) -> Callable[[Type], Type]: + def decorator(cls: Type[_Scalar]) -> Type[_Scalar]: + cls.__federation_versions__ = versions + return cls + + return decorator + + class _Scalar(Scalar): """Implements dummy `parse` and `serialize` methods for scalars.""" + __federation_versions__: List[int] + @classmethod def parse(cls, value: str) -> Any: return value @@ -15,15 +25,18 @@ def serialize(cls, value: str) -> Any: return value +@federation_version([2]) class _Any(_Scalar): ... +@federation_version([1, 2]) @scalar(name="_FieldSet") class FieldSet(_Scalar): ... +@federation_version([2]) @scalar(name="link__Import") class LinkImport(_Scalar): ... diff --git a/hiku/federation/sdl.py b/hiku/federation/sdl.py index 14abe641..e6bbec6a 100644 --- a/hiku/federation/sdl.py +++ b/hiku/federation/sdl.py @@ -53,13 +53,16 @@ GenericMeta, ) + +def _name(value: t.Optional[str]) -> t.Optional[ast.NameNode]: + return ast.NameNode(value=value) if value is not None else None + + _BUILTIN_DIRECTIVES_NAMES = { directive.__directive_info__.name for directive in _BUILTIN_DIRECTIVES } - -def _name(value: t.Optional[str]) -> t.Optional[ast.NameNode]: - return ast.NameNode(value=value) if value is not None else None +_BUILTIN_SCALARS = [ast.ScalarTypeDefinitionNode(name=_name("Any"))] @t.overload @@ -237,7 +240,7 @@ def visit_graph(self, graph: Graph) -> List[ast.DefinitionNode]: if self.mutation_graph else [] ), - self.get_any_scalar(), + *self.export_scalars(), *self.export_enums(), *self.export_unions(), self.get_service_type(), @@ -373,8 +376,22 @@ def get_custom_directives(self) -> t.List[ast.DirectiveDefinitionNode]: ) return directives - def get_any_scalar(self) -> ast.ScalarTypeDefinitionNode: - return ast.ScalarTypeDefinitionNode(name=_name("Any")) + def export_scalars(self) -> t.List[ast.ScalarTypeDefinitionNode]: + scalars = [] + for scalar in self.graph.scalars: + if hasattr(scalar, "__federation_versions__"): + if ( + self.federation_version + not in scalar.__federation_versions__ + ): # noqa: E501 + continue + + scalars.append( + ast.ScalarTypeDefinitionNode( + name=_name(scalar.__type_name__), + ) + ) + return _BUILTIN_SCALARS + scalars def export_enums(self) -> t.List[ast.EnumTypeDefinitionNode]: enums = [] diff --git a/tests/test_federation/test_sdl_v1.py b/tests/test_federation/test_sdl_v1.py index c826751a..528349b9 100644 --- a/tests/test_federation/test_sdl_v1.py +++ b/tests/test_federation/test_sdl_v1.py @@ -100,6 +100,8 @@ def execute(graph, query_string): } scalar Any + + scalar _FieldSet union Bucket = Cart diff --git a/tests/test_federation/test_sdl_v2.py b/tests/test_federation/test_sdl_v2.py index 351edb96..d1afe367 100644 --- a/tests/test_federation/test_sdl_v2.py +++ b/tests/test_federation/test_sdl_v2.py @@ -1,4 +1,5 @@ import textwrap +from typing import Any from hiku.directives import Location from hiku.enum import Enum @@ -16,6 +17,7 @@ TypeRef, Optional, ) +from hiku.scalar import Scalar from hiku.graph import apply from hiku.federation.graph import FederatedNode, Graph @@ -41,6 +43,16 @@ class Custom(FederationSchemaDirective): ... +class Long(Scalar): + @classmethod + def parse(cls, value: Any) -> int: + return int(value) + + @classmethod + def serialize(cls, value: Any) -> int: + return int(value) + + SaveOrderResultNode = Node( "SaveOrderResult", [ @@ -62,6 +74,7 @@ class Custom(FederationSchemaDirective): ], directives=[Key('id')]), FederatedNode('CartItem', [ Field('id', Integer, field_resolver), + Field('productId', Long, field_resolver), ], directives=[Key('id', resolvable=False)]), Root([ Link( @@ -83,7 +96,7 @@ class Custom(FederationSchemaDirective): Union('Bucket', ['Cart']) ], enums=[ Enum('Currency', ['UAH', 'USD']) -]) +], scalars=[Long]) MUTATION_GRAPH = Graph.from_graph( @@ -135,6 +148,7 @@ class Custom(FederationSchemaDirective): type CartItem @key(fields: "id", resolvable: false) { id: Int! + productId: Long! } extend type Query { @@ -143,6 +157,14 @@ class Custom(FederationSchemaDirective): %s scalar Any + scalar Long + + scalar _Any + + scalar _FieldSet + + scalar link__Import + enum Currency { UAH USD