diff --git a/hiku/cache.py b/hiku/cache.py
index 4c0c8057..c01d916b 100644
--- a/hiku/cache.py
+++ b/hiku/cache.py
@@ -51,7 +51,7 @@
     labelnames=["graph", "query_name", "node", "field"],
 )
 
-CACHE_VERSION = "1"
+CACHE_VERSION = "2"
 
 
 class Hasher(Protocol):
diff --git a/hiku/denormalize/base.py b/hiku/denormalize/base.py
index 9317295f..70910b7e 100644
--- a/hiku/denormalize/base.py
+++ b/hiku/denormalize/base.py
@@ -10,6 +10,7 @@
     Field as GraphField,
 )
 from ..query import (
+    Fragment,
     QueryVisitor,
     Link,
     Field,
@@ -56,6 +57,7 @@ def __init__(self, graph: Graph, result: Proxy) -> None:
         self._unions = graph.unions_map
         self._enums = graph.enums_map
         self._result = result
+        self._index = result.__idx__
         self._type: t.Deque[
             t.Union[t.Type[Record], Union, Interface, BaseEnum]
         ] = deque([self._types["__root__"]])
@@ -72,16 +74,22 @@ def visit_node(self, obj: Node) -> t.Any:
         for item in obj.fields:
             self.visit(item)
 
+        for fr in obj.fragments:
+            self.visit_fragment(fr)
+
+    def visit_fragment(self, obj: Fragment) -> None:
         type_name = None
         if isinstance(self._data[-1], Proxy):
             type_name = self._data[-1].__ref__.node
 
-        for fr in obj.fragments:
-            if type_name is not None and type_name != fr.type_name:
-                # do not visit fragment if type specified and not match
-                continue
+        if type_name is not None and type_name != obj.type_name:
+            # for unions we must visit only fragments with the same type
+            # as the node
+            return
 
-            self.visit(fr)
+        self._data.append(Proxy(self._index, self._data[-1].__ref__, obj.node))
+        for item in obj.node.fields:
+            self.visit(item)
+        self._data.pop()
 
     def visit_field(self, obj: Field) -> None:
         if isinstance(self._data[-1], Proxy):
@@ -92,9 +100,10 @@ def visit_field(self, obj: Field) -> None:
             node = self._graph.nodes_map[type_name]
             graph_field = node.fields_map[obj.name]
 
-            self._res[-1][obj.result_key] = serialize_value(
-                self._graph, graph_field, self._data[-1][obj.result_key]
-            )
+            if obj.result_key not in self._res[-1]:
+                self._res[-1][obj.result_key] = serialize_value(
+                    self._graph, graph_field, self._data[-1][obj.result_key]
+                )
         else:
             # Record type itself does not have custom serialization
             # TODO: support Scalar/Enum types in Record
@@ -112,7 +121,13 @@ def visit_link(self, obj: Link) -> None:
 
         if isinstance(type_, RefMeta):
             self._type.append(get_type(self._types, type_))
-            self._res.append({})
+            # if we already visited this link, just reuse the result
+            if obj.result_key not in self._res[-1]:
+                self._res.append({})
+            else:
+                res = self._res[-1][obj.result_key]
+                self._res.append(res)
+
             self._data.append(self._data[-1][obj.result_key])
             super().visit_link(obj)
             self._data.pop()
diff --git a/hiku/engine.py b/hiku/engine.py
index 8ac358f5..ccd882fd 100644
--- a/hiku/engine.py
+++ b/hiku/engine.py
@@ -34,11 +34,13 @@
 from .executors.base import SyncAsyncExecutor
 from .operation import Operation, OperationType
 from .query import (
+    Fragment,
     Node as QueryNode,
     Field as QueryField,
     Link as QueryLink,
     QueryTransformer,
     QueryVisitor,
+    merge_links,
 )
 from .graph import (
     FieldType,
@@ -171,10 +173,16 @@ def visit_link(self, obj: QueryLink) -> QueryLink:
         return obj.copy(node=node, options=options)
 
 
-# query.Link is considered a complex Field if present in tuple
-FieldGroup = Tuple[Field, Union[QueryField, QueryLink]]
-CallableFieldGroup = Tuple[Callable, Field, Union[QueryField, QueryLink]]
-LinkGroup = Tuple[Link, QueryLink]
+@dataclasses.dataclass
+class FieldInfo:
+    graph_field: Field
+    query_field: Union[QueryField, QueryLink]
+
+
+@dataclasses.dataclass
+class LinkInfo:
+    graph_link: Link
+    query_link: QueryLink
 
 
 class SplitQuery(QueryVisitor):
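The dataclasses above replace the positional tuples `FieldGroup`, `CallableFieldGroup`, and `LinkGroup`. A minimal sketch of the difference at the call sites, using stand-in types rather than the real hiku classes:

```python
# Sketch only: hypothetical stand-ins, not hiku's actual Field/QueryLink.
import dataclasses
from typing import Tuple

@dataclasses.dataclass
class LinkInfo:
    graph_link: str
    query_link: str

# Before: a positional tuple, unpacked at every call site.
step_item_old: Tuple[str, str] = ("company", "company")
graph_link, query_link = step_item_old

# After: named attributes, self-documenting where they are used.
step_item_new = LinkInfo(graph_link="company", query_link="company")
assert step_item_new.graph_link == graph_link
```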
@@ -184,21 +192,38 @@ class SplitQuery(QueryVisitor):
 
     def __init__(self, graph_node: Node) -> None:
         self._node = graph_node
-        self._fields: List[CallableFieldGroup] = []
-        self._links: List[LinkGroup] = []
+        self.links_map: Dict[str, List[LinkInfo]] = {}
+        self.fields_map: Dict[str, List[Tuple[Callable, FieldInfo]]] = {}
 
-    def split(
-        self, query_node: QueryNode
-    ) -> Tuple[List[CallableFieldGroup], List[LinkGroup]]:
+    def split(self, query_node: QueryNode) -> "SplitQuery":
         for item in query_node.fields:
             self.visit(item)
 
         for fr in query_node.fragments:
-            if fr.type_name != self._node.name:
-                continue
-            self.visit(fr)
+            # node fragments can have different type_names
+            # if the node is a union or an interface
+            if fr.type_name == self._node.name:
+                self.visit(fr)
+
+        for field, fields in self.fields_map.items():
+            if len(set([f.query_field.index_key for _, f in fields])) > 1:
+                raise ValueError(
+                    f"Cannot use the same field '{field}' with "
+                    "different arguments."
+                    " Use different field names (aliases) or arguments."
+                )
+
+        for link, links in self.links_map.items():
+            if len(set([ln.query_link.index_key for ln in links])) > 1:
+                raise ValueError(
+                    f"Cannot use the same link '{link}' with "
+                    "different arguments."
+                    " Use different field names (aliases) or arguments."
+                )
+
+        return self
 
-        return self._fields, self._links
+    def visit_fragment(self, obj: Fragment) -> None:
+        self.visit(obj.node)
 
     def visit_node(self, obj: QueryNode) -> None:
         for item in obj.fields:
@@ -210,7 +235,9 @@ def visit_field(self, obj: QueryField) -> None:
 
         graph_obj = self._node.fields_map[obj.name]
         func = getattr(graph_obj.func, "__subquery__", graph_obj.func)
-        self._fields.append((func, graph_obj, obj))
+        self.fields_map.setdefault(obj.name, []).append(
+            (func, FieldInfo(graph_obj, obj))
+        )
 
     def visit_link(self, obj: QueryLink) -> None:
         graph_obj = self._node.fields_map[obj.name]
@@ -221,24 +248,28 @@ def visit_link(self, obj: QueryLink) -> None:
                     self.visit(QueryField(r))
             else:
                 self.visit(QueryField(graph_obj.requires))
-            self._links.append((graph_obj, obj))
+            self.links_map.setdefault(obj.name, []).append(
+                LinkInfo(graph_link=graph_obj, query_link=obj)
+            )
         else:
             assert isinstance(graph_obj, Field), type(graph_obj)
             # `obj` here is a link, but this link is treated as a complex field
             func = getattr(graph_obj.func, "__subquery__", graph_obj.func)
-            self._fields.append((func, graph_obj, obj))
+            self.fields_map.setdefault(obj.name, []).append(
+                (func, FieldInfo(graph_obj, obj))
+            )
 
 
 class GroupQuery(QueryVisitor):
     def __init__(self, node: Node) -> None:
         self._node = node
         self._funcs: List[Callable] = []
-        self._groups: List[Union[List[FieldGroup], LinkGroup]] = []
+        self._groups: List[Union[List[FieldInfo], LinkInfo]] = []
         self._current_func = None
 
     def group(
         self, node: QueryNode
-    ) -> List[Tuple[Callable, Union[List[FieldGroup], LinkGroup]]]:
+    ) -> List[Tuple[Callable, Union[List[FieldInfo], LinkInfo]]]:
         for item in node.fields:
             self.visit(item)
         return list(zip(self._funcs, self._groups))
@@ -251,9 +282,9 @@ def visit_field(self, obj: QueryField) -> None:
         func = getattr(graph_obj.func, "__subquery__", graph_obj.func)
         if func == self._current_func:
             assert isinstance(self._groups[-1], list)
-            self._groups[-1].append((graph_obj, obj))
+            self._groups[-1].append(FieldInfo(graph_obj, obj))
         else:
-            self._groups.append([(graph_obj, obj)])
+            self._groups.append([FieldInfo(graph_obj, obj)])
             self._funcs.append(func)
             self._current_func = func
 
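`SplitQuery` now buckets repeated selections by field name and validates them. A hedged sketch of the two situations (illustrative query strings, not taken from the test suite): selecting the same field twice with identical arguments collapses into one resolver call, while the same field with different arguments makes `split` raise.

```python
# Same field reached directly and via a fragment: both occurrences share
# one index_key, so the engine resolves it once.
deduplicated = """
{
  user {
    id
    ...UserId
  }
}
fragment UserId on User { id }
"""

# Same field with different arguments: distinct index_keys, so split()
# raises ValueError("Cannot use the same field ...").
conflicting = """
{
  user {
    photo(size: 50)
    photo(size: 100)
  }
}
"""
```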
@@ -265,7 +296,7 @@ def visit_link(self, obj: QueryLink) -> None:
                 self.visit(QueryField(r))
         else:
             self.visit(QueryField(graph_obj.requires))
-        self._groups.append((graph_obj, obj))
+        self._groups.append(LinkInfo(graph_obj, obj))
         self._funcs.append(graph_obj.func)
         self._current_func = None
 
@@ -645,7 +676,9 @@ def _process_node_ordered(
         proc_steps = GroupQuery(node).group(query)
 
         # recursively and sequentially schedule fields and links
-        def proc(steps: List) -> None:
+        def proc(
+            steps: List[Tuple[Callable, Union[List[FieldInfo], LinkInfo]]]
+        ) -> None:
             step_func, step_item = steps.pop(0)
             if isinstance(step_item, list):
                 self._track(path)
@@ -653,10 +686,9 @@ def proc(steps: List) -> None:
                     path, node, step_func, step_item, ids
                 )
             else:
-                graph_link, query_link = step_item
                 self._track(path)
                 dep = self._schedule_link(
-                    path, node, graph_link, query_link, ids
+                    path, node, step_item.graph_link, step_item.query_link, ids
                 )
 
             if steps:
@@ -679,27 +711,38 @@ def process_node(
             self._process_node_ordered(path, node, query, ids)
             return
 
-        fields, links = SplitQuery(node).split(query)
+        fields = SplitQuery(node).split(query)
 
         to_func: Dict[str, Callable] = {}
-        from_func: DefaultDict[Callable, List[FieldGroup]] = defaultdict(list)
-        for func, graph_field, query_field in fields:
-            to_func[graph_field.name] = func
-            from_func[func].append((graph_field, query_field))
+        from_func: DefaultDict[Callable, List[FieldInfo]] = defaultdict(list)
+        for field_name, fields_info in fields.fields_map.items():
+            func, field_info = fields_info[0]
+            to_func[field_info.graph_field.name] = func
+            from_func[func].append(field_info)
 
-        # schedule fields resolve
         to_dep: Dict[Callable, Dep] = {}
-        for func, func_fields in from_func.items():
+        for func, func_fields_info in from_func.items():
             self._track(path)
             to_dep[func] = self._schedule_fields(
-                path, node, func, func_fields, ids
+                path, node, func, func_fields_info, ids
             )
 
         # schedule link resolve
-        for graph_link, query_link in links:
+        for link_name, links_info in fields.links_map.items():
+            query_links = [info.query_link for info in links_info]
+            graph_link = links_info[0].graph_link
+
+            # recursively collect and merge the leaf fields from all
+            # occurrences of this link
+            link = merge_links(query_links)
+
             self._track(path)
             schedule = partial(
-                self._schedule_link, path, node, graph_link, query_link, ids
+                self._schedule_link,
+                path,
+                node,
+                graph_link,
+                link,
+                ids,
            )
 
             if graph_link.requires:
                 if isinstance(graph_link.requires, list):
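`process_node` now merges every occurrence of a link before scheduling it once. A small sketch of `merge_links`, using the `Link(name, node)` / `Field(name)` constructors as they appear in the tests below:

```python
from hiku.query import Field, Link, Node, merge_links

a = Link("info", Node([Field("email")]))
b = Link("info", Node([Field("phone")]))

# One link is scheduled; its node selects the union of the leaf fields.
merged = merge_links([a, b])
assert [f.name for f in merged.node.fields] == ["email", "phone"]
```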
@@ -787,15 +830,16 @@ def _schedule_fields(
         path: NodePath,
         node: Node,
         func: Callable,
-        fields: List[FieldGroup],
+        fields_info: List[FieldInfo],
         ids: Optional[Any],
     ) -> Union[SubmitRes, TaskSet]:
-        query_fields = [qf for _, qf in fields]
+        query_fields = [f.query_field for f in fields_info]
 
         dep: Union[TaskSet, SubmitRes]
         if hasattr(func, "__subquery__"):
             assert ids is not None
             dep = self._queue.fork(self._task_set)
+            fields = [(f.graph_field, f.query_field) for f in fields_info]
             proc = func(fields, ids, self._queue, self._ctx, dep)
         else:
             if ids is None:
@@ -854,7 +898,12 @@ def callback() -> None:
 
             if ids:
                 self._schedule_link(
-                    path, node, graph_link, query_link, ids, skip_cache=True
+                    path,
+                    node,
+                    graph_link,
+                    query_link,
+                    ids,
+                    skip_cache=True,
                 )
 
         self._queue.add_callback(dep, callback)
@@ -882,6 +931,7 @@ def _schedule_link(
         """
         args = []
         if graph_link.requires:
+            # collect link requirements from the store
             reqs: Any = link_reqs(self._index, node, graph_link, ids)
 
             if (
@@ -902,7 +952,12 @@ def _schedule_link(
 
         def callback() -> None:
             return self.process_link(
-                path, node, graph_link, query_link, ids, dep.result()
+                path,
+                node,
+                graph_link,
+                query_link,
+                ids,
+                dep.result(),
             )
 
         self._queue.add_callback(dep, callback)
diff --git a/hiku/query.py b/hiku/query.py
index a4033e6c..c32f965f 100644
--- a/hiku/query.py
+++ b/hiku/query.py
@@ -133,7 +133,6 @@ def index_key(self) -> str:
         else:
             return self.name
 
-    # TODO: test this hash
     def __hash__(self) -> int:
         return hash(self.index_key)
 
@@ -247,7 +246,10 @@ def fields_map(
 
     @cached_property
     def fragments_map(self) -> FragmentMap:
-        return OrderedDict((f.type_name, f) for f in self.fragments)
+        """Only named fragments"""
+        return OrderedDict(
+            (f.name, f) for f in self.fragments if f.name is not None
+        )
 
     @cached_property
     def result_map(self) -> OrderedDict:
@@ -261,9 +263,12 @@ def __hash__(self) -> int:
 
 
 class Fragment(Base):
-    __attrs__ = ("type_name", "node")
+    __attrs__ = ("name", "type_name", "node")
 
-    def __init__(self, type_name: str, fields: t.List[FieldOrLink]) -> None:
+    def __init__(
+        self, name: t.Optional[str], type_name: str, fields: t.List[FieldOrLink]
+    ) -> None:
+        self.name = name  # if None, it's an inline fragment
         self.type_name = type_name
         self.node = Node(fields)
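`Fragment` now takes an optional name as its first argument, with `None` marking an inline fragment, and `Node.fragments_map` is keyed by that name. For example, mirroring the updated tests below:

```python
from hiku.query import Field, Fragment, Node

named = Fragment("UserFields", "User", [Field("email")])  # fragment UserFields on User
inline = Fragment(None, "User", [Field("email")])         # ... on User { email }

node = Node([], [named, inline])
assert list(node.fragments_map) == ["UserFields"]  # inline fragments excluded
```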
@@ -285,11 +290,11 @@ def _merge(
     nodes: t.Iterable[Node],
 ) -> t.Iterator[t.Union[FieldOrLink, Fragment]]:
     visited_fields = set()
-    visited_fragments = set()
     links = {}
     link_directives: t.DefaultDict[t.Tuple, t.List] = defaultdict(list)
     to_merge = OrderedDict()
     fields_iter = chain.from_iterable(e.fields for e in nodes)
+    fragments_iter = chain.from_iterable(e.fragments for e in nodes)
 
     for field in fields_iter:
         key = field_key(field)
@@ -307,42 +312,8 @@ def _merge(
                 visited_fields.add(key)
                 yield field
 
-    if not visited_fields and not to_merge:
-        for fr in chain.from_iterable(e.fragments for e in nodes):
-            yield fr
-    else:
-        for node in nodes:
-            for fr in node.fragments:
-                fr_fields: t.List[FieldOrLink] = []
-                for field in fr.node.fields:
-                    key = (field.name, field.options_hash, field.alias)
-
-                    if field.__class__ is Link:
-                        field = t.cast(Link, field)
-
-                        # If fragment field not exists in node fields, we
-                        # can skip merging it with node fields and just
-                        # leave it in a fragment.
-                        # Field's own node will be merged as usuall
-                        if field.name not in node.fields_map:
-                            fr_fields.append(_merge_link(field))
-                            continue
-
-                        if key not in to_merge:
-                            to_merge[key] = [field.node]
-                            links[key] = field
-                        else:
-                            to_merge[key].append(field.node)
-                        link_directives[key].extend(field.directives)
-                    else:
-                        if key not in visited_fields:
-                            fr_fields.append(field)
-
-                fr_key = (fr.type_name, tuple(field_key(f) for f in fr_fields))
-                if fr_key not in visited_fragments:
-                    visited_fragments.add(fr_key)
-                    if fr_fields:
-                        yield Fragment(fr.type_name, fr_fields)
+    for fr in fragments_iter:
+        yield fr
 
     for key, values in to_merge.items():
         link = links[key]
@@ -350,11 +321,6 @@ def _merge(
         yield link.copy(node=merge(values), directives=tuple(directives))
 
 
-def _merge_link(link: Link) -> Link:
-    """Recursively merge link node fields and return new link"""
-    return link.copy(node=merge([link.node]))
-
-
 def merge(nodes: t.Iterable[Node]) -> Node:
     """Merges multiple queries into one query
@@ -374,6 +340,71 @@
     return Node(fields=fields, fragments=fragments, ordered=ordered)
 
 
+def merge_links(links: t.List[Link]) -> Link:
+    """Recursively merge link node fields and return a new link"""
+    if len(links) == 1:
+        return links[0]
+
+    # directives will be the same as in the first link
+    return links[0].copy(node=collect_fields([link.node for link in links]))
+
+
+def collect_fields(nodes: t.List[Node]) -> Node:
+    """Collect fields from multiple nodes and return a new node with them.
+    The main difference from `merge` is that it collects fields
+    from fragments and drops the fragments.
+    """
+    assert isinstance(nodes, Sequence), type(nodes)
+    ordered = any(n.ordered for n in nodes)
+    fields = []
+    visited_fields: t.Set[KeyT] = set()
+    fragments = []
+    for item in _collect_fields(nodes, visited_fields):
+        if isinstance(item, Fragment):
+            fragments.append(item)
+        else:
+            fields.append(item)
+
+    return Node(fields=fields, fragments=fragments, ordered=ordered)
+
+
+def _collect_fields(
+    nodes: t.Iterable[Node],
+    visited_fields: t.Set[KeyT],
+) -> t.Iterator[t.Union[FieldOrLink, Fragment]]:
+    links = {}
+    link_directives: t.DefaultDict[t.Tuple, t.List] = defaultdict(list)
+    to_merge = OrderedDict()
+    fields_iter = chain.from_iterable(e.fields for e in nodes)
+    fragments_iter = chain.from_iterable(e.fragments for e in nodes)
+
+    for field in fields_iter:
+        key = field_key(field)
+
+        if field.__class__ is Link:
+            field = t.cast(Link, field)
+            if key not in to_merge:
+                to_merge[key] = [field.node]
+                links[key] = field
+            else:
+                to_merge[key].append(field.node)
+            link_directives[key].extend(field.directives)
+        else:
+            if key not in visited_fields:
+                visited_fields.add(key)
+                yield field
+
+    for fr in fragments_iter:
+        yield from _collect_fields([fr.node], visited_fields)
+
+    for key, values in to_merge.items():
+        link = links[key]
+        directives = link_directives[key]
+        yield link.copy(
+            node=collect_fields(values), directives=tuple(directives)
+        )
+
+
 class QueryVisitor:
     def visit(self, obj: t.Any) -> t.Any:
         return obj.accept(self)
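The split between `merge` and the new `collect_fields` in one sketch: `merge` now passes fragments through untouched, while `collect_fields` flattens fragment selections into plain fields and drops the fragments themselves:

```python
from hiku.query import Field, Fragment, Node, collect_fields, merge

node = Node([Field("id")], [Fragment(None, "User", [Field("email")])])

merged = merge([node])
assert len(merged.fragments) == 1  # fragment preserved as-is

flat = collect_fields([node])
assert [f.name for f in flat.fields] == ["id", "email"]
assert not flat.fragments  # fragment selections flattened away
```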
diff --git a/hiku/readers/graphql.py b/hiku/readers/graphql.py
index 1c659891..77b06f4a 100644
--- a/hiku/readers/graphql.py
+++ b/hiku/readers/graphql.py
@@ -1,4 +1,3 @@
-from collections import defaultdict
 from typing import Any, cast, Dict, Iterator, List, Optional, Set, Tuple, Union
 
 from graphql.language import ast
@@ -233,36 +232,31 @@ def _collect_fields(
         else:
             fragments_map = self.fragments_transformer.fragments_map  # type: ignore[attr-defined]  # noqa: E501
 
-        shared_fields = []
-        type_fields = defaultdict(list)
+        fields = []
+        fragments = []
 
         for item in obj.selection_set.selections:
             type_name = None
-            selection_set = None
 
             if isinstance(item, ast.InlineFragmentNode):
                 type_name = item.type_condition.name.value
-                selection_set = item.selection_set
+                fr_fields = list(self.visit(item))  # type: ignore[attr-defined]
+                fragments.append(Fragment(None, type_name, fr_fields))
             elif isinstance(item, ast.FragmentSpreadNode):
-                if item.name.value not in fragments_map:
-                    raise TypeError(f'Undefined fragment: "{item.name.value}"')
+                fragment_name = item.name.value
+                if fragment_name not in fragments_map:
+                    raise TypeError(f'Undefined fragment: "{fragment_name}"')
 
-                fragment = fragments_map[item.name.value]
+                fragment = fragments_map[fragment_name]
                 type_name = fragment.type_condition.name.value
-                selection_set = fragment.selection_set
-            else:
-                shared_fields.extend(list(self.visit(item)))  # type: ignore[attr-defined]  # noqa: E501
 
-            if type_name and selection_set:
-                type_fields[type_name].extend(
-                    list(self.visit(selection_set))  # type: ignore[attr-defined]  # noqa: E501
-                )
-
-        node_fragments = []
-        for type_name, fields in type_fields.items():
-            node_fragments.append(Fragment(type_name, fields))
+                fr_fields = list(self.visit(item))  # type: ignore[attr-defined]
+                fragments.append(Fragment(fragment_name, type_name, fr_fields))
+            else:
+                res = list(self.visit(item))  # type: ignore[attr-defined]
+                fields.extend(res)
 
-        return shared_fields, node_fragments
+        return fields, fragments
 
     def visit_field(self, obj: ast.FieldNode) -> Iterator[Union[Field, Link]]:
         if self._should_skip(obj):
diff --git a/hiku/result.py b/hiku/result.py
index db2d71a0..76dda051 100644
--- a/hiku/result.py
+++ b/hiku/result.py
@@ -96,15 +96,6 @@ def __getitem__(self, item: str) -> t.Any:
                 "Field {!r} wasn't requested in the query".format(item)
             )
 
-        try:
-            field = self.__node__.fragments_map[
-                self.__ref__.node
-            ].node.result_map[item]
-        except KeyError:
-            raise KeyError(
-                "Field {!r} wasn't requested in the query".format(item)
-            )
-
         try:
             obj: t.Dict = self.__idx__[self.__ref__.node][self.__ref__.ident]
         except KeyError:
diff --git a/hiku/sources/graph.py b/hiku/sources/graph.py
index f4149304..20542a69 100644
--- a/hiku/sources/graph.py
+++ b/hiku/sources/graph.py
@@ -2,6 +2,7 @@
 from typing import (
     NoReturn,
     List,
+    Tuple,
     Union,
     Callable,
     Iterator,
@@ -21,11 +22,15 @@
     Field,
 )
 from ..types import TypeRef, Sequence
-from ..query import merge, Node as QueryNode, Field as QueryField
+from ..query import (
+    merge,
+    Node as QueryNode,
+    Field as QueryField,
+    Link as QueryLink,
+)
 from ..types import Any
 from ..engine import (
     Query,
-    FieldGroup,
     Context,
 )
 from ..expr.refs import RequirementsExtractor
@@ -40,6 +45,7 @@
 from ..expr.compiler import ExpressionCompiler
 
 
+FieldGroup = Tuple[Field, Union[QueryField, QueryLink]]
 Expr: TypeAlias = Union[_Func, DotHandler]
 
 
@@ -176,7 +182,12 @@ def __call__(
 
         q = Query(queue, task_set, self.graph, reqs, ctx)
         q.process_link(
-            path, self.graph.root, this_graph_link, this_query_link, None, ids
+            path,
+            self.graph.root,
+            this_graph_link,
+            this_query_link,
+            None,
+            ids,
         )
         q.process_node(path, self.graph.root, other_reqs, None)
         return _create_result_proc(q, procs, option_values)
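The net effect of the reader changes, sketched with the `read` entry point as it is used in the tests (the query string is illustrative): every fragment spread keeps its name and every inline fragment becomes an anonymous `Fragment`, instead of all selections being merged per type condition:

```python
from hiku.readers.graphql import read

query = read("""
{
  media {
    ... on Audio { duration }
    ...VideoId
  }
}
fragment VideoId on Video { id }
""")

media = query.fields_map["media"]
assert [(fr.name, fr.type_name) for fr in media.node.fragments] == [
    (None, "Audio"),
    ("VideoId", "Video"),
]
```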
diff --git a/tests/benchmarks/test_read_graphql.py b/tests/benchmarks/test_read_graphql.py
index 33ce5821..25d58903 100644
--- a/tests/benchmarks/test_read_graphql.py
+++ b/tests/benchmarks/test_read_graphql.py
@@ -51,7 +51,7 @@ def test_link_fragment(benchmark):
             Field("id"),
         ],
         [
-            Fragment("User", [
+            Fragment(None, "User", [
                 Field("name"),
             ])
         ]
diff --git a/tests/test_cache.py b/tests/test_cache.py
index cb49194b..bd76e3c0 100644
--- a/tests/test_cache.py
+++ b/tests/test_cache.py
@@ -24,7 +24,7 @@
     define,
     S,
 )
-from hiku.query import FieldOrLink, _compute_hash, Link as QueryLink, Node as QueryNode
+from hiku.query import FieldOrLink, Link as QueryLink, Node as QueryNode, merge_links
 from hiku.result import Reference
 from hiku.sources.graph import SubGraph
 from hiku.sources.sqlalchemy import (
@@ -554,9 +554,9 @@ def get_product_query(product_id: int) -> str:
             }
           }
         }
-
+
         fragment ProductInfo on Product {
-          company {
+          company @cached(ttl: 20) {
             id
             name
             address { city }
@@ -570,7 +570,7 @@ def get_product_query(product_id: int) -> str:
             ...CompanyInfo
           }
         }
-
+
         fragment CompanyInfo on Company {
           logoImage(size: 100)
         }
@@ -601,7 +601,7 @@ def get_products_query() -> str:
             }
           }
         }
         fragment ProductInfo on Product {
-          company {
+          company @cached(ttl: 20) {
             id
             name
             address { city }
@@ -619,31 +619,64 @@ def get_products_query() -> str:
     """
 
 
-def assert_dict_equal(got, exp):
-    assert _compute_hash(got) == _compute_hash(exp)
+
+def assert_deep_equal(got, exp):
+    if isinstance(got, dict):
+        for k, v in got.items():
+            assert k in exp
+            assert_deep_equal(v, exp[k])
+    elif isinstance(got, list) or isinstance(got, tuple):
+        for i, item in enumerate(got):
+            assert_deep_equal(item, exp[i])
+    elif isinstance(got, Reference):
+        assert hash(got) == hash(exp)
+    else:
+        assert got == exp
 
 
 def get_field(query: QueryNode, path: t.List[str]) -> FieldOrLink:
-    node = query
+    cur = query
     path_size = len(path)
 
     def last(idx: int):
         return idx + 1 == path_size
 
+    field = None
+
     for idx, name in enumerate(path):
-        if name in node.fields_map:
-            node = node.fields_map[name]
+        fields = []
+        if name in cur.fields_map:
+            field = cur.fields_map[name]
             if last(idx):
-                return node
+                return field
 
-            if isinstance(node, QueryLink):
-                node = node.node
-        else:
-            for fr in node.fragments:
+            fields.append(field)
+
+        for fr in cur.fragments:
             if name in fr.node.fields_map:
-                node = fr.node.fields_map[name]
+                fields.append(fr.node.fields_map[name])
+
+        if len(fields) == 1:
+            if isinstance(fields[0], QueryLink):
+                cur = fields[0].node
+        else:
+            peek_path = path[idx + 1]
+            for f in fields:
+                if isinstance(f, QueryLink) and peek_path in f.node.fields_map:
+                    cur = f.node
+                    break
 
+            for f in fields:
+                if isinstance(f, QueryLink):
+                    for fr in f.node.fragments:
+                        if peek_path in fr.node.fields_map:
+                            cur = fr.node
+                            break
 
-    return node
+    if not field or field.name != path[-1]:
+        raise KeyError(f"Field {path[-1]} not found in query")
+
+    return field
 
 
 def test_cached_link_one__sqlalchemy(sync_graph_sqlalchemy):
@@ -667,7 +700,14 @@ def execute(q):
         return DenormalizeGraphQL(graph, proxy, "query").process(q)
 
     query = read(get_product_query(1))
-    company_link = get_field(query, ['product', 'company'])
+
+    product_link = get_field(query, ['product'])
+
+    company_link = merge_links([
+        product_link.node.fields_map['company'],
+        product_link.node.fragments_map['ProductInfo'].node.fields_map['company'],
+    ])
+
     attributes_link = get_field(query, ['product', 'attributes'])
     photo_field = get_field(query, ["product", "company", "owner", "photo"])
 
@@ -753,12 +793,27 @@ def execute(q):
 
     assert cache.get_many.call_count == 2
 
-    calls = {
-        **cache.set_many.mock_calls[0][1][0],
-        **cache.set_many.mock_calls[1][1][0],
-    }
-    calls_expected = {attributes_key: attributes_cache, company_key: company_cache}
-    assert_dict_equal(calls, calls_expected)
+    call1 = cache.set_many.call_args_list[0][0]
+    call2 = cache.set_many.call_args_list[1][0]
+
+    company_call = None
+    attributes_call = None
+
+    if company_key in call1[0]:
+        company_call = call1
+        attributes_call = call2
+    else:
+        company_call = call2
+        attributes_call = call1
+
+    if not company_call or not attributes_call:
+        pytest.fail("Expected cache.set_many call")
+
+    assert_deep_equal(company_call[0], {company_key: company_cache})
+    assert company_call[1] == 10
+
+    assert_deep_equal(attributes_call[0], {attributes_key: attributes_cache})
+    assert attributes_call[1] == 15
 
     cache.reset_mock()
 
@@ -799,7 +854,13 @@ def execute(q):
 
     query = read(get_products_query())
 
-    company_link = get_field(query, ['products', 'company'])
+    products_link = get_field(query, ['products'])
+
+    company_link = merge_links([
+        products_link.node.fields_map['company'],
+        products_link.node.fragments_map['ProductInfo'].node.fields_map['company'],
+    ])
+
     attributes_link = get_field(query, ['products', 'attributes'])
     photo_field = get_field(query, ["products", "company", "owner", "photo"])
 
@@ -935,21 +996,34 @@ def execute(q):
     check_result(execute(query), expected_result)
 
     assert cache.get_many.call_count == 2
-    calls = {
-        **cache.set_many.mock_calls[0][1][0],
-        **cache.set_many.mock_calls[1][1][0],
-    }
-    calls_expected = {
-        attributes11_12_key: attributes11_12_cache,
-        attributes_none_key: attributes_none_cache,
-        company10_key: company10_cache,
-        company20_key: company20_cache,
-    }
-    assert_dict_equal(calls, calls_expected)
+
+    call1 = cache.set_many.call_args_list[0][0]
+    call2 = cache.set_many.call_args_list[1][0]
+
+    company_call = None
+    attributes_call = None
+
+    # calls can come in a different order, so we first determine which
+    # call is which
+    if company10_key in call1[0] or company20_key in call1[0]:
+        company_call = call1
+        attributes_call = call2
+    else:
+        company_call = call2
+        attributes_call = call1
+
+    if not company_call or not attributes_call:
+        pytest.fail("Expected cache.set_many call")
+
+    assert_deep_equal(company_call[0], {company10_key: company10_cache, company20_key: company20_cache})
+    assert company_call[1] == 10
+
+    assert_deep_equal(attributes_call[0], {attributes11_12_key: attributes11_12_cache, attributes_none_key: attributes_none_cache})
+    assert attributes_call[1] == 15
 
     cache.reset_mock()
 
     check_result(execute(query), expected_result)
+
     assert set(*cache.get_many.mock_calls[0][1]) == {
         attributes11_12_key,
         attributes_none_key,
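A note on the replacement helper: `assert_deep_equal` walks `got` recursively and compares `Reference` objects by hash, so expected cache payloads no longer need `_compute_hash`. One caveat: only keys present in `got` are checked, so extra keys on the expected side pass silently:

```python
# Uses assert_deep_equal as defined in test_cache.py above.
assert_deep_equal({"a": [1, 2]}, {"a": [1, 2], "unchecked": True})  # passes
```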
diff --git a/tests/test_engine.py b/tests/test_engine.py
index f8a79816..782b1100 100644
--- a/tests/test_engine.py
+++ b/tests/test_engine.py
@@ -1604,3 +1604,117 @@ def link_user(ids):
         },
     ]
 }
+
+
+def test_merge_fields__should_execute_each_field_once() -> None:
+    num_link_user = 0
+    num_link_info = 0
+    num_resolve_id = 0
+    num_resolve_name = 0
+
+    def resolve_user(fields, ids) -> List[Any]:
+        def get_field(f, id_) -> Any:
+            if f.name == "name":
+                nonlocal num_resolve_name
+                num_resolve_name += 1
+                return "John"
+            elif f.name == "id":
+                nonlocal num_resolve_id
+                num_resolve_id += 1
+                return id_
+
+        return [[get_field(f, id_) for f in fields] for id_ in ids]
+
+    def resolve_info(fields, ids) -> List[Any]:
+        def get_field(f, id_) -> Any:
+            if f.name == "email":
+                return "john@example.com"
+            elif f.name == "phone":
+                return "+1234567890"
+
+        return [[get_field(f, id_) for f in fields] for id_ in ids]
+
+    def link_user() -> int:
+        nonlocal num_link_user
+        num_link_user += 1
+        return 1
+
+    def link_info() -> int:
+        nonlocal num_link_info
+        num_link_info += 1
+        return 100
+
+    graph = Graph(
+        [
+            Node(
+                "User",
+                [
+                    Field("id", String, resolve_user),
+                    Field("name", String, resolve_user),
+                    Link("info", TypeRef["Info"], link_info, requires=None)
+                ],
+            ),
+            Node(
+                "Info",
+                [
+                    Field("email", String, resolve_info),
+                    Field("phone", String, resolve_info),
+                ],
+            ),
+            Node(
+                "Context",
+                [
+                    Link("user", TypeRef["User"], link_user, requires=None)
+                ],
+            ),
+            Root(
+                [Link("context", TypeRef["Context"], lambda: 100, requires=None)]
+            ),
+        ]
+    )
+
+    query = """
+    query GetUser {
+        context {
+            user {
+                id
+                ...UserFragmentA
+                ...UserFragmentB
+                ... on User {
+                    id
+                }
+            }
+            ...ContextFragment
+        }
+    }
+
+    fragment ContextFragment on Context {
+        user {
+            id
+            name
+        }
+    }
+
+    fragment UserFragmentA on User {
+        id
+        info {
+            email
+        }
+    }
+
+    fragment UserFragmentB on User {
+        id
+        name
+        info {
+            phone
+        }
+    }
+    """
+
+    data = execute_endpoint(graph, query)["data"]
+
+    assert num_link_user == 1
+    assert num_link_info == 1
+    assert num_resolve_id == 1
+    assert num_resolve_name == 1
+    assert data == {"context": {"user": {"id": 1, "name": "John", "info": {"email": "john@example.com", "phone": "+1234567890"}}}}
diff --git a/tests/test_federation/test_engine.py b/tests/test_federation/test_engine.py
index 2300a5d9..1d30f239 100644
--- a/tests/test_federation/test_engine.py
+++ b/tests/test_federation/test_engine.py
@@ -1,7 +1,6 @@
 import pytest
 
 from hiku.graph import Graph
-from hiku.context import create_execution_context
 from hiku.query import Node, Field, Link
 from hiku.executors.asyncio import AsyncIOExecutor
 from hiku.federation.endpoint import denormalize_entities
@@ -43,7 +42,6 @@ async def execute_async(query: Node, graph: Graph, ctx=None):
         ]
     }
 }
-QUERY = read(ENTITIES_QUERY['query'], ENTITIES_QUERY['variables'])
 
 
 SDL_QUERY = Node(fields=[
@@ -52,15 +50,17 @@ async def execute_async(query: Node, graph: Graph, ctx=None):
 
 
 def test_validate_entities_query():
-    errors = validate(GRAPH, QUERY)
+    query = read(ENTITIES_QUERY['query'], ENTITIES_QUERY['variables'])
+    errors = validate(GRAPH, query)
     assert errors == []
 
 
 def test_execute_sync_executor():
-    result = execute(QUERY, GRAPH)
+    query = read(ENTITIES_QUERY['query'], ENTITIES_QUERY['variables'])
+    result = execute(query, GRAPH)
     data = denormalize_entities(
         GRAPH,
-        QUERY,
+        query,
         result,
     )
 
@@ -73,10 +73,11 @@ def test_execute_sync_executor():
 
 @pytest.mark.asyncio
 async def test_execute_async_executor():
-    result = await execute_async(QUERY, ASYNC_GRAPH)
+    query = read(ENTITIES_QUERY['query'], ENTITIES_QUERY['variables'])
+    result = await execute_async(query, ASYNC_GRAPH)
     data = denormalize_entities(
         GRAPH,
-        QUERY,
+        query,
         result,
     )
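The expected query trees below change because `merge` now keeps named fragments in `Node.fragments` instead of inlining them; duplicated selections inside fragments survive parsing and are deduplicated later, in the engine. For instance:

```python
from hiku.query import Field, Fragment, Node, merge

merged = merge([Node([], [Fragment("Meer", "Torsion", [Field("rusk")])])])
assert "Meer" in merged.fragments_map
```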
Field("id"), Field("duration"), ]), - Fragment('Video', [ + Fragment('VideoId', 'Video', [ Field("id"), + ]), + Fragment(None, 'Video', [ Field("thumbnailUrl"), ]), ]), @@ -634,7 +643,7 @@ def test_parse_union_with_one_fragment(): Node([ Field("__typename"), ], [ - Fragment('Audio', [ + Fragment(None, 'Audio', [ Field("id"), Field("duration"), ]), @@ -672,10 +681,10 @@ def test_parse_interface_with_two_fragments(): Field("id"), Field("duration"), ], [ - Fragment('Audio', [ + Fragment(None, 'Audio', [ Field("album"), ]), - Fragment('Video', [ + Fragment(None, 'Video', [ Field("thumbnailUrl"), ]), ]) @@ -708,7 +717,7 @@ def test_parse_interface_with_one_fragment(): Field("id"), Field("duration"), ], [ - Fragment('Audio', [ + Fragment(None, 'Audio', [ Field("album"), ]), ]), @@ -752,11 +761,21 @@ def test_merge_node_with_fragment_on_node() -> None: Field("id"), Field("name"), ], [ - Fragment('User', [ + Fragment(None, 'User', [ + Field("id"), Field("email"), ]), ])), - ], []), + ], [ + Fragment(None, 'Context', [ + Link("user", Node([], [ + Fragment(None, 'User', [ + Field("id"), + Field("email"), + ]), + ])), + ]), + ]), ) ] ), @@ -799,11 +818,21 @@ def test_merge_fragment_for_union() -> None: Field("id"), Field("name"), ], [ - Fragment('User', [ + Fragment(None, 'User', [ + Field("id"), Field("email"), ]), ])), - ], []), + ], [ + Fragment(None, 'Context', [ + Link("user", Node([], [ + Fragment(None, 'User', [ + Field("id"), + Field("email"), + ]), + ])), + ]), + ]), ) ] ),