Skip to content

Commit

Permalink
rewrite fragments
Browse files Browse the repository at this point in the history
In this PR we addressed the issue with broken fragments merging. Historically, hiku has not supported unions/interfaces so fragments merging were ok. But after we added unions/intefaces support + fragments support, things went in the wrong direction.

Fragments merging (fields merging to be more accurate) has its own rules in graphql spec and we did not complied. So why thought ?

Hiku arhitecture was build without fragments in mind, that is - engine and denormalization modules can not work with fragments. So to avoid complex rewriting for engine and denormalization we decided to merge fragments to avoid fields duplication - the most hard thing to support in engine.

In this PR, we try to mimic graphql-py behavior or merging fields:
- We remove complext merging in query parsing step, now all fields and fragments are stored in hiku ast as is
- In engine we adapted SplitQuery to group fields and link:
  - Fields are leaf nodes so it is safe to take first field from fields_info (list of field instances) as long as field args are the same
  - Same link is collected and will be merged before `schedule_link` call, so that no duplicated links will be processed
  - Cache works as previous but hashes are changed sihce link node fields are merged
  - Denormalize adapted to work with links/fields resolved multiple times and checks that field already presents in result skipping its serialization

So basically we moved handling of fields merging to engine/denormalization stage and simplified parsing.

Some other changes:
 - Refactor SplitQuery types, introduce FieldInfo and LinkInfo instead of tuples
 - Refactor GroupQuery to use FieldInfo/LinkInfo
 - Fix cache tests
 - Drop fragments hack from result.py:Proxy since we now provide proper Proxy to index for each fragment
 - Add name for Fragment, if name is None - this is an InlineFragment
 - Node.fragments_map only returns named fragments map
  • Loading branch information
m.kindritskiy committed Jun 12, 2024
1 parent 593d052 commit 9dcf896
Show file tree
Hide file tree
Showing 11 changed files with 495 additions and 182 deletions.
33 changes: 24 additions & 9 deletions hiku/denormalize/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Field as GraphField,
)
from ..query import (
Fragment,
QueryVisitor,
Link,
Field,
Expand Down Expand Up @@ -56,6 +57,7 @@ def __init__(self, graph: Graph, result: Proxy) -> None:
self._unions = graph.unions_map
self._enums = graph.enums_map
self._result = result
self._index = result.__idx__
self._type: t.Deque[
t.Union[t.Type[Record], Union, Interface, BaseEnum]
] = deque([self._types["__root__"]])
Expand All @@ -72,16 +74,22 @@ def visit_node(self, obj: Node) -> t.Any:
for item in obj.fields:
self.visit(item)

for fr in obj.fragments:
self.visit_fragment(fr)

def visit_fragment(self, obj: Fragment) -> None:
type_name = None
if isinstance(self._data[-1], Proxy):
type_name = self._data[-1].__ref__.node

for fr in obj.fragments:
if type_name is not None and type_name != fr.type_name:
# do not visit fragment if type specified and not match
continue
if type_name is not None and type_name != obj.type_name:
# for unions we must visit only fragments with same type as node
return

self.visit(fr)
self._data.append(Proxy(self._index, self._data[-1].__ref__, obj.node))
for item in obj.node.fields:
self.visit(item)
self._data.pop()

def visit_field(self, obj: Field) -> None:
if isinstance(self._data[-1], Proxy):
Expand All @@ -92,9 +100,10 @@ def visit_field(self, obj: Field) -> None:
node = self._graph.nodes_map[type_name]
graph_field = node.fields_map[obj.name]

self._res[-1][obj.result_key] = serialize_value(
self._graph, graph_field, self._data[-1][obj.result_key]
)
if obj.result_key not in self._res[-1]:
self._res[-1][obj.result_key] = serialize_value(
self._graph, graph_field, self._data[-1][obj.result_key]
)
else:
# Record type itself does not have custom serialization
# TODO: support Scalar/Enum types in Record
Expand All @@ -112,7 +121,13 @@ def visit_link(self, obj: Link) -> None:

if isinstance(type_, RefMeta):
self._type.append(get_type(self._types, type_))
self._res.append({})
# if we already visited this link, just reuse the result
if obj.result_key not in self._res[-1]:
self._res.append({})
else:
res = self._res[-1][obj.result_key]
self._res.append(res)

self._data.append(self._data[-1][obj.result_key])
super().visit_link(obj)
self._data.pop()
Expand Down
131 changes: 93 additions & 38 deletions hiku/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,13 @@
from .executors.base import SyncAsyncExecutor
from .operation import Operation, OperationType
from .query import (
Fragment,
Node as QueryNode,
Field as QueryField,
Link as QueryLink,
QueryTransformer,
QueryVisitor,
merge_links,
)
from .graph import (
FieldType,
Expand Down Expand Up @@ -171,10 +173,16 @@ def visit_link(self, obj: QueryLink) -> QueryLink:
return obj.copy(node=node, options=options)


# query.Link is considered a complex Field if present in tuple
FieldGroup = Tuple[Field, Union[QueryField, QueryLink]]
CallableFieldGroup = Tuple[Callable, Field, Union[QueryField, QueryLink]]
LinkGroup = Tuple[Link, QueryLink]
@dataclasses.dataclass
class FieldInfo:
graph_field: Field
query_field: Union[QueryField, QueryLink]


@dataclasses.dataclass
class LinkInfo:
graph_link: Link
query_link: QueryLink


class SplitQuery(QueryVisitor):
Expand All @@ -184,21 +192,38 @@ class SplitQuery(QueryVisitor):

def __init__(self, graph_node: Node) -> None:
self._node = graph_node
self._fields: List[CallableFieldGroup] = []
self._links: List[LinkGroup] = []
self.links_map: Dict[str, List[LinkInfo]] = {}
self.fields_map: Dict[str, List[Tuple[Callable, FieldInfo]]] = {}

def split(
self, query_node: QueryNode
) -> Tuple[List[CallableFieldGroup], List[LinkGroup]]:
def split(self, query_node: QueryNode) -> "SplitQuery":
for item in query_node.fields:
self.visit(item)

for fr in query_node.fragments:
if fr.type_name != self._node.name:
continue
self.visit(fr)
# node fragments can have different type_names
# if node is union or inteface
if fr.type_name == self._node.name:
self.visit(fr)

for field, fields in self.fields_map.items():
if len(set([f.query_field.index_key for _, f in fields])) > 1:
raise ValueError(
f"Can not use same field '{field}' with "
"different arguments."
" Use different field names (aliases) or arguments."
)

for link, links in self.links_map.items():
if len(set([ln.query_link.index_key for ln in links])) > 1:
raise ValueError(
f"Can not use same field '{field}' with "
"different arguments."
" Use different field names (aliases) or arguments."
)
return self

return self._fields, self._links
def visit_fragment(self, obj: Fragment) -> None:
self.visit(obj.node)

def visit_node(self, obj: QueryNode) -> None:
for item in obj.fields:
Expand All @@ -210,7 +235,9 @@ def visit_field(self, obj: QueryField) -> None:

graph_obj = self._node.fields_map[obj.name]
func = getattr(graph_obj.func, "__subquery__", graph_obj.func)
self._fields.append((func, graph_obj, obj))
self.fields_map.setdefault(obj.name, []).append(
(func, FieldInfo(graph_obj, obj))
)

def visit_link(self, obj: QueryLink) -> None:
graph_obj = self._node.fields_map[obj.name]
Expand All @@ -221,24 +248,28 @@ def visit_link(self, obj: QueryLink) -> None:
self.visit(QueryField(r))
else:
self.visit(QueryField(graph_obj.requires))
self._links.append((graph_obj, obj))
self.links_map.setdefault(obj.name, []).append(
LinkInfo(graph_link=graph_obj, query_link=obj)
)
else:
assert isinstance(graph_obj, Field), type(graph_obj)
# `obj` here is a link, but this link is treated as a complex field
func = getattr(graph_obj.func, "__subquery__", graph_obj.func)
self._fields.append((func, graph_obj, obj))
self.fields_map.setdefault(obj.name, []).append(
(func, FieldInfo(graph_obj, obj))
)


class GroupQuery(QueryVisitor):
def __init__(self, node: Node) -> None:
self._node = node
self._funcs: List[Callable] = []
self._groups: List[Union[List[FieldGroup], LinkGroup]] = []
self._groups: List[Union[List[FieldInfo], LinkInfo]] = []
self._current_func = None

def group(
self, node: QueryNode
) -> List[Tuple[Callable, Union[List[FieldGroup], LinkGroup]]]:
) -> List[Tuple[Callable, Union[List[FieldInfo], LinkInfo]]]:
for item in node.fields:
self.visit(item)
return list(zip(self._funcs, self._groups))
Expand All @@ -251,9 +282,9 @@ def visit_field(self, obj: QueryField) -> None:
func = getattr(graph_obj.func, "__subquery__", graph_obj.func)
if func == self._current_func:
assert isinstance(self._groups[-1], list)
self._groups[-1].append((graph_obj, obj))
self._groups[-1].append(FieldInfo(graph_obj, obj))
else:
self._groups.append([(graph_obj, obj)])
self._groups.append([FieldInfo(graph_obj, obj)])
self._funcs.append(func)
self._current_func = func

Expand All @@ -265,7 +296,7 @@ def visit_link(self, obj: QueryLink) -> None:
self.visit(QueryField(r))
else:
self.visit(QueryField(graph_obj.requires))
self._groups.append((graph_obj, obj))
self._groups.append(LinkInfo(graph_obj, obj))
self._funcs.append(graph_obj.func)
self._current_func = None

Expand Down Expand Up @@ -645,18 +676,19 @@ def _process_node_ordered(
proc_steps = GroupQuery(node).group(query)

# recursively and sequentially schedule fields and links
def proc(steps: List) -> None:
def proc(
steps: List[Tuple[Callable, Union[List[FieldInfo], LinkInfo]]]
) -> None:
step_func, step_item = steps.pop(0)
if isinstance(step_item, list):
self._track(path)
dep = self._schedule_fields(
path, node, step_func, step_item, ids
)
else:
graph_link, query_link = step_item
self._track(path)
dep = self._schedule_link(
path, node, graph_link, query_link, ids
path, node, step_item.graph_link, step_item.query_link, ids
)

if steps:
Expand All @@ -679,27 +711,38 @@ def process_node(
self._process_node_ordered(path, node, query, ids)
return

fields, links = SplitQuery(node).split(query)
fields = SplitQuery(node).split(query)

to_func: Dict[str, Callable] = {}
from_func: DefaultDict[Callable, List[FieldGroup]] = defaultdict(list)
for func, graph_field, query_field in fields:
to_func[graph_field.name] = func
from_func[func].append((graph_field, query_field))
from_func: DefaultDict[Callable, List[FieldInfo]] = defaultdict(list)
for field_name, fields_info in fields.fields_map.items():
func, field_info = fields_info[0]
to_func[field_info.graph_field.name] = func
from_func[func].append(field_info)

# schedule fields resolve
to_dep: Dict[Callable, Dep] = {}
for func, func_fields in from_func.items():
for func, func_fields_info in from_func.items():
self._track(path)
to_dep[func] = self._schedule_fields(
path, node, func, func_fields, ids
path, node, func, func_fields_info, ids
)

# schedule link resolve
for graph_link, query_link in links:
for link_name, links_info in fields.links_map.items():
query_links = [info.query_link for info in links_info]
graph_link = links_info[0].graph_link

# recursively we collect and resolve leaf fields of all links fields
link = merge_links(query_links)

self._track(path)
schedule = partial(
self._schedule_link, path, node, graph_link, query_link, ids
self._schedule_link,
path,
node,
graph_link,
link,
ids,
)
if graph_link.requires:
if isinstance(graph_link.requires, list):
Expand Down Expand Up @@ -787,15 +830,16 @@ def _schedule_fields(
path: NodePath,
node: Node,
func: Callable,
fields: List[FieldGroup],
fields_info: List[FieldInfo],
ids: Optional[Any],
) -> Union[SubmitRes, TaskSet]:
query_fields = [qf for _, qf in fields]
query_fields = [f.query_field for f in fields_info]

dep: Union[TaskSet, SubmitRes]
if hasattr(func, "__subquery__"):
assert ids is not None
dep = self._queue.fork(self._task_set)
fields = [(f.graph_field, f.query_field) for f in fields_info]
proc = func(fields, ids, self._queue, self._ctx, dep)
else:
if ids is None:
Expand Down Expand Up @@ -854,7 +898,12 @@ def callback() -> None:

if ids:
self._schedule_link(
path, node, graph_link, query_link, ids, skip_cache=True
path,
node,
graph_link,
query_link,
ids,
skip_cache=True,
)

self._queue.add_callback(dep, callback)
Expand Down Expand Up @@ -882,6 +931,7 @@ def _schedule_link(
"""
args = []
if graph_link.requires:
# collect_link_requires from store
reqs: Any = link_reqs(self._index, node, graph_link, ids)

if (
Expand All @@ -902,7 +952,12 @@ def _schedule_link(

def callback() -> None:
return self.process_link(
path, node, graph_link, query_link, ids, dep.result()
path,
node,
graph_link,
query_link,
ids,
dep.result(),
)

self._queue.add_callback(dep, callback)
Expand Down
Loading

0 comments on commit 9dcf896

Please sign in to comment.