Skip to content

Commit

Permalink
simplify readers/graphql.py - collect fragments as is (no transformat…
Browse files Browse the repository at this point in the history
…ions)

- simplify graphql reader - we no longer transform framgments in any way, just map them 1:1 to hiku ast
- add fragments support to export/graphql.py
- implement proper fragments merging in QueryMerger - fragments, union fragments and interface fragments are merged properly
- add tests for query merger
  • Loading branch information
m.kindritskiy committed Jun 17, 2024
1 parent 60ef8da commit 89244a0
Show file tree
Hide file tree
Showing 10 changed files with 588 additions and 416 deletions.
4 changes: 2 additions & 2 deletions hiku/endpoint/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ def _init_execution_context(
execution_context.query = execution_context.operation.query

# TODO: move this into read operation
collector = QueryMerger(execution_context.graph)
execution_context.query = collector.merge(execution_context.query)
merger = QueryMerger(execution_context.graph)
execution_context.query = merger.merge(execution_context.query)

op = execution_context.operation
if op.type not in (OperationType.QUERY, OperationType.MUTATION):
Expand Down
26 changes: 23 additions & 3 deletions hiku/export/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from graphql.language import ast

from hiku.query import Fragment

from ..query import (
QueryVisitor,
Field,
Expand Down Expand Up @@ -69,10 +71,28 @@ def visit_link(self, obj: Link) -> ast.FieldNode:
selection_set=self.visit(obj.node),
)

def visit_fragment(self, obj: Fragment) -> Any:
if obj.name is None:
return ast.InlineFragmentNode(
type_condition=ast.NamedTypeNode(
name=_name(obj.type_name),
)
if obj.type_name is not None
else None,
selection_set=self.visit(obj.node),
)

return ast.FragmentSpreadNode(name=_name(obj.name))

def visit_node(self, obj: Node) -> ast.SelectionSetNode:
return ast.SelectionSetNode(
selections=[self.visit(f) for f in obj.fields],
)
selections = []
for f in obj.fields:
selections.append(self.visit(f))

for fr in obj.fragments:
selections.append(self.visit(fr))

return ast.SelectionSetNode(selections=selections)


def export(query: Node) -> ast.DocumentNode:
Expand Down
206 changes: 149 additions & 57 deletions hiku/merge.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
from collections import deque
from contextlib import contextmanager
import typing as t

from collections import deque
from contextlib import contextmanager
from collections.abc import Sequence

from hiku.enum import BaseEnum
from hiku.graph import Graph, Interface, Link as GraphLink, LinkType, Union
from hiku.graph import Graph, Interface, Union
from hiku.query import Link, Field, Fragment, Node, QueryVisitor
from hiku.types import (
InterfaceRef,
InterfaceRefMeta,
OptionalMeta,
Record,
RecordMeta,
RefMeta,
RefMetaTypes,
SequenceMeta,
TypeRef,
TypeRefMeta,
UnionRef,
UnionRefMeta,
get_type,
)


# TODO simplify
def get_ref_type(types, type_, name):
if isinstance(type_, RecordMeta):
type_ = type_.__field_types__[name]
Expand All @@ -37,10 +42,8 @@ def get_ref_type(types, type_, name):
return type_.__type__
return get_ref_type(types, type_.__field_types__[name], name)
elif isinstance(type_, UnionRefMeta):
# TODO: finish
return type_
elif isinstance(type_, InterfaceRefMeta):
# TODO: finish
return type_
elif isinstance(type_, TypeRefMeta):
type_ = get_type(types, type_)
Expand All @@ -49,6 +52,53 @@ def get_ref_type(types, type_, name):
raise AssertionError(repr(type_))


def is_fragment_condition_match(
graph: Graph,
parent_type: RefMetaTypes,
fragment: Fragment,
) -> bool:
"""Check if a fragment is applicable to the given type."""
type_name = fragment.type_name
if not type_name:
return True

if isinstance(parent_type, TypeRefMeta):
if type_name == parent_type.__type_name__:
return True

if isinstance(parent_type, (UnionRefMeta, InterfaceRefMeta)):
return False

return False


def is_abstract_type(
type_,
) -> bool:
return isinstance(type_, (InterfaceRefMeta, UnionRefMeta))


def is_interface(
type_,
) -> bool:
return isinstance(type_, InterfaceRefMeta)


def is_union(
type_,
) -> bool:
return isinstance(type_, UnionRefMeta)


def is_match_type(
type_: t.Union[TypeRef, UnionRef, InterfaceRef], fragment: Fragment
) -> bool:
if fragment.type_name is None:
return True # TODO: test this case

return fragment.type_name == type_.__type_name__


class QueryMerger(QueryVisitor):
def __init__(self, graph: Graph):
self.graph = graph
Expand All @@ -57,8 +107,6 @@ def __init__(self, graph: Graph):
t.Union[t.Type[Record], Union, Interface, BaseEnum]
] = deque([self._types["__root__"]])

self._visited_fields = deque([set()])

def merge(self, query: Node) -> Node:
return self.visit(query)

Expand All @@ -84,35 +132,113 @@ def _collect_fields(self, node: Node, fields, links, fragments) -> None:
name = field.alias or field.name

if isinstance(field, Field):
if name not in self._visited_fields[-1]:
if name not in fields:
fields[name] = field
elif isinstance(field, Link):
links.setdefault(name, []).append(field)

fragments_by_type_name = {}

# TODO: determine fragments parsing rules
fragments_to_process = []
for fr in node.fragments:
if is_fragment_condition_match(self.graph, self.parent_type, fr):
self._collect_fields(fr.node, fields, links, fragments)
if is_match_type(self.parent_type, fr):
self._expand_fragment(fr, fields, links, fragments_to_process)
else:
fragments_by_type_name.setdefault(fr.type_name, []).append(fr)

for frs in fragments_by_type_name.values():
fragments.append(self._merge_fragments(frs))
fragments_to_process.append(fr)

if isinstance(self.parent_type, InterfaceRefMeta):
self._merge_interface_fragments(
self.graph.interfaces_map[self.parent_type.__type_name__],
fragments_to_process,
fields,
links,
fragments,
)
else:
fragments_by_type_name = {}

for fr in fragments_to_process:
if is_fragment_condition_match(
self.graph, self.parent_type, fr
):
self._collect_fields(fr.node, fields, links, fragments)
else:
fragments_by_type_name.setdefault(fr.type_name, []).append(
fr
)

for frs in fragments_by_type_name.values():
fragments.append(self._merge_fragments(frs))

def visit_link(self, obj: Link) -> t.Any:
with self.with_type_info(obj):
return super().visit_link(obj)

def _merge_interface_fragments(
self,
interface: Interface,
fragments_to_process: t.List[Fragment],
fields,
links,
fragments,
):
"""For each shared field inside fragment, move it to the interface node
leaving only the unique fields in the fragment.
"""

fragments_by_type_name = {}

for fragment in fragments_to_process:
fragment_fields = []
for field in fragment.node.fields:
name = field.alias or field.name
is_interface_field = name in interface.fields_map
if isinstance(field, Field):
if is_interface_field:
if name not in fields:
fields[name] = field
else:
fragment_fields.append(field)
elif isinstance(field, Link):
if is_interface_field:
links.setdefault(name, []).append(field)
else:
fragment_fields.append(field)

# if fragment still has own fields, create a new fragment with them
if fragment_fields:
new_fragment = fragment.copy(node=Node(fragment_fields))
fragments_by_type_name.setdefault(
fragment.type_name, []
).append(new_fragment)

for frs in fragments_by_type_name.values():
fragments.append(self._merge_fragments(frs))

def _merge_fragments(self, fragments: t.List[Fragment]) -> Fragment:
fr = fragments[0]
new = Fragment(
fr.name,
fr.type_name,
self._merge_nodes([fr.node for fr in fragments]).fields,
return fragments[0].copy(
node=self._merge_nodes([fr.node for fr in fragments]),
)
return new

def _expand_fragment(self, fragment: Fragment, fields, links, fragments):
"""Given a fragment, expand it and collect fields, links and fragments.
Fragment is disposed.
Example, expand QueryFragment:
Given:
query { ...QueryFragment }
fragment QueryFragment on Query { id ... on Query { name } }
Result:
query { id name }
"""
for field in fragment.node.fields:
name = field.alias or field.name

if isinstance(field, Field):
if name not in fields:
fields[name] = field
elif isinstance(field, Link):
links.setdefault(name, []).append(field)

for fr in fragment.node.fragments:
fragments.append(fr)

def _merge_nodes(self, nodes: t.List[Node]) -> Node:
"""Collect fields from multiple nodes and return new node with them.
Expand Down Expand Up @@ -151,37 +277,3 @@ def _merge_links(self, links: t.List[Link]) -> Link:
node=self._merge_nodes([link.node for link in links]),
directives=tuple(directives),
)


def is_abstract_link(graphql_link: GraphLink) -> bool:
return graphql_link.type_info.type_enum in (
LinkType.UNION,
LinkType.INTERFACE,
)


def is_fragment_condition_match(
graph: Graph,
runtime_type,
fragment: Fragment,
) -> bool:
"""Check if a fragment is applicable to the given type."""
type_name = fragment.type_name
if not type_name:
return True

if isinstance(runtime_type, TypeRefMeta):
if type_name == runtime_type.__type_name__:
return True

if isinstance(runtime_type, UnionRefMeta):
# We do not merge abstract type fragmetns, but we must merge same cocrete type fragments
return False
union = graph.unions_map[runtime_type.__type_name__]
return type_name in union.types
if isinstance(runtime_type, InterfaceRefMeta):
return False
interface_types = graph.interfaces_types[runtime_type.__type_name__]
return type_name in interface_types

return False
5 changes: 2 additions & 3 deletions hiku/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@

from .directives import Directive
from .utils import cached_property
from hiku import directives

T = t.TypeVar("T", bound="Base")

Expand Down Expand Up @@ -274,11 +273,11 @@ def __init__(
self,
name: t.Optional[str],
type_name: t.Optional[str],
fields: t.List[FieldOrLink],
node: Node,
) -> None:
self.name = name # if None, it's an inline fragment
self.type_name = type_name
self.node = Node(fields)
self.node = node

def accept(self, visitor: "QueryVisitor") -> t.Any:
return visitor.visit_fragment(self)
Expand Down
Loading

0 comments on commit 89244a0

Please sign in to comment.