Skip to content

Commit

Permalink
Merge pull request #139 from evo-company/refactor-hiku-extensions
Browse files Browse the repository at this point in the history
pass execution_context as an argment, drop ExtensionFactory
  • Loading branch information
kindermax authored Nov 2, 2023
2 parents c8d7512 + 00bb89e commit 3aedc7d
Show file tree
Hide file tree
Showing 10 changed files with 146 additions and 196 deletions.
2 changes: 0 additions & 2 deletions hiku/extensions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from .base_extension import Extension
from .base_extension import ExtensionFactory

__all__ = [
"Extension",
"ExtensionFactory",
]
114 changes: 56 additions & 58 deletions hiku/extensions/base_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
import contextlib
import inspect
from asyncio import iscoroutinefunction

from types import TracebackType
from typing import (
Any,
TYPE_CHECKING,
AsyncIterator,
Awaitable,
Callable,
Expand All @@ -15,7 +14,6 @@
NamedTuple,
Optional,
Sequence,
TYPE_CHECKING,
Type,
TypeVar,
Union,
Expand All @@ -29,18 +27,15 @@

AwaitableOrValue = Union[Awaitable[T], T]
AsyncIteratorOrIterator = Union[AsyncIterator[T], Iterator[T]]
Hook = Callable[["Extension"], AsyncIteratorOrIterator[None]]
Hook = Callable[
["Extension", "ExecutionContext"], AsyncIteratorOrIterator[None]
]


class Extension:
execution_context: ExecutionContext

def __init__(self, *, execution_context: Optional[ExecutionContext] = None):
# execution_context will be set during ExtensionsManager initialization
# it is safe to assume that it will be not None
self.execution_context = execution_context # type: ignore[assignment]

def on_graph(self) -> AsyncIteratorOrIterator[None]:
def on_graph(
self, execution_context: ExecutionContext
) -> AsyncIteratorOrIterator[None]:
"""Called before and after the graph (transformation) step.
Graph transformation step is a step where we applying transformations
Expand All @@ -51,7 +46,9 @@ def on_graph(self) -> AsyncIteratorOrIterator[None]:
"""
yield None

def on_dispatch(self) -> AsyncIteratorOrIterator[None]:
def on_dispatch(
self, execution_context: ExecutionContext
) -> AsyncIteratorOrIterator[None]:
"""Called before and after the dispatch step.
Dispatch step is a step where the query is dispatched by to the endpoint
Expand All @@ -70,7 +67,9 @@ def on_dispatch(self) -> AsyncIteratorOrIterator[None]:
"""
yield None

def on_operation(self) -> AsyncIteratorOrIterator[None]:
def on_operation(
self, execution_context: ExecutionContext
) -> AsyncIteratorOrIterator[None]:
"""Called before and after the operation step.
Operation step is a step where the graphql ast is transformed into
Expand All @@ -79,15 +78,19 @@ def on_operation(self) -> AsyncIteratorOrIterator[None]:
"""
yield None

def on_parse(self) -> AsyncIteratorOrIterator[None]:
def on_parse(
self, execution_context: ExecutionContext
) -> AsyncIteratorOrIterator[None]:
"""Called before and after the parsing step.
Parse step is when query string is parsed into graphql ast
and will be assigned to the execution_context.graphql_document.
"""
yield None

def on_validate(self) -> AsyncIteratorOrIterator[None]:
def on_validate(
self, execution_context: ExecutionContext
) -> AsyncIteratorOrIterator[None]:
"""Called before and after the validation step.
Validation step is when hiku query is validated.
Expand All @@ -96,7 +99,9 @@ def on_validate(self) -> AsyncIteratorOrIterator[None]:
"""
yield None

def on_execute(self) -> AsyncIteratorOrIterator[None]:
def on_execute(
self, execution_context: ExecutionContext
) -> AsyncIteratorOrIterator[None]:
"""Called before and after the execution step.
Execution step is when hiku query is executed by hiku engine.
Expand All @@ -106,25 +111,12 @@ def on_execute(self) -> AsyncIteratorOrIterator[None]:
yield None


class ExtensionFactory:
"""Lazy extension factory.
Remembers arguments and keyword arguments and creates an extension
instance when ExtensionsManager is created.
"""

ext_class: Type[Extension]

def __init__(self, *args: Any, **kwargs: Any):
self._args = args
self._kwargs = kwargs

def create(self, execution_context: ExecutionContext) -> Extension:
extension = self.ext_class(*self._args, **self._kwargs)
extension.execution_context = execution_context
return extension
class ExtensionsManager:
"""ExtensionManager is a per/dispatch extensions manager.
It is used to call excensions hooks in the right order.
"""

class ExtensionsManager:
def __init__(
self,
execution_context: ExecutionContext,
Expand All @@ -138,55 +130,49 @@ def __init__(
init_extensions: List[Extension] = []

for extension in extensions:
if isinstance(extension, ExtensionFactory):
init_extensions.append(extension.create(execution_context))
elif isinstance(extension, Extension):
raise ValueError(
f"Extension {extension} must be a class, "
"not an instance. Use ExtensionFactory if your extension "
"has custom arguments."
)
if isinstance(extension, Extension):
init_extensions.append(extension)
else:
init_extensions.append(
extension(execution_context=execution_context)
)
init_extensions.append(extension())

self.extensions = init_extensions

def graph(self) -> "ExtensionContextManagerBase":
return ExtensionContextManagerBase(
Extension.on_graph.__name__,
self.extensions,
Extension.on_graph.__name__, self.extensions, self.execution_context
)

def dispatch(self) -> "ExtensionContextManagerBase":
return ExtensionContextManagerBase(
Extension.on_dispatch.__name__,
self.extensions,
self.execution_context,
)

def parsing(self) -> "ExtensionContextManagerBase":
return ExtensionContextManagerBase(
Extension.on_parse.__name__,
self.extensions,
Extension.on_parse.__name__, self.extensions, self.execution_context
)

def operation(self) -> "ExtensionContextManagerBase":
return ExtensionContextManagerBase(
Extension.on_operation.__name__,
self.extensions,
self.execution_context,
)

def validation(self) -> "ExtensionContextManagerBase":
return ExtensionContextManagerBase(
Extension.on_validate.__name__,
self.extensions,
self.execution_context,
)

def execution(self) -> "ExtensionContextManagerBase":
return ExtensionContextManagerBase(
Extension.on_execute.__name__,
self.extensions,
self.execution_context,
)


Expand All @@ -199,28 +185,39 @@ class WrappedHook(NamedTuple):
class ExtensionContextManagerBase:
__slots__ = ("hook_name", "hooks", "deprecation_message", "default_hook")

def __init__(self, hook_name: str, extensions: List[Extension]):
def __init__(
self,
hook_name: str,
extensions: List[Extension],
execution_context: ExecutionContext,
):
self.hook_name = hook_name
self.hooks: List[WrappedHook] = []
self.default_hook: Hook = getattr(Extension, self.hook_name)
for extension in extensions:
hook = self.get_hook(extension)
hook = self.get_hook(extension, execution_context)
if hook:
self.hooks.append(hook)

def get_hook(self, extension: Extension) -> Optional[WrappedHook]:
def get_hook(
self, extension: Extension, execution_context: ExecutionContext
) -> Optional[WrappedHook]:
hook_fn: Optional[Hook] = getattr(type(extension), self.hook_name)
hook_fn = hook_fn if hook_fn is not self.default_hook else None

if hook_fn:
if inspect.isgeneratorfunction(hook_fn):
return WrappedHook(extension, hook_fn(extension), False)
return WrappedHook(
extension, hook_fn(extension, execution_context), False
)

if inspect.isasyncgenfunction(hook_fn):
return WrappedHook(extension, hook_fn(extension), True)
return WrappedHook(
extension, hook_fn(extension, execution_context), True
)

if callable(hook_fn):
return self.from_callable(extension, hook_fn)
return self.from_callable(extension, hook_fn, execution_context)

raise ValueError(
f"Hook {self.hook_name} on {extension} "
Expand All @@ -232,20 +229,21 @@ def get_hook(self, extension: Extension) -> Optional[WrappedHook]:
@staticmethod
def from_callable(
extension: Extension,
func: Callable[[Extension], AwaitableOrValue[Any]],
func: Hook,
execution_context: ExecutionContext,
) -> WrappedHook:
if iscoroutinefunction(func):

async def async_iterator(): # type: ignore[no-untyped-def]
await func(extension)
await func(extension, execution_context)
yield

hook = async_iterator()
return WrappedHook(extension, hook, True)
else:

def iterator(): # type: ignore[no-untyped-def]
func(extension)
func(extension, execution_context)
yield

hook = iterator()
Expand Down
17 changes: 4 additions & 13 deletions hiku/extensions/context.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,16 @@
from typing import Callable, Dict, Iterator

from hiku.context import ExecutionContext
from hiku.extensions.base_extension import Extension, ExtensionFactory
from hiku.extensions.base_extension import Extension


class _CustomContextImpl(Extension):
class CustomContext(Extension):
def __init__(
self,
get_context: Callable[[ExecutionContext], Dict],
):
self.get_context = get_context

def on_execute(self) -> Iterator[None]:
self.execution_context.context = self.get_context(
self.execution_context
)
def on_execute(self, execution_context: ExecutionContext) -> Iterator[None]:
execution_context.context = self.get_context(execution_context)
yield


class CustomContext(ExtensionFactory):
ext_class = _CustomContextImpl

def __init__(self, get_context: Callable[[ExecutionContext], Dict]):
super().__init__(get_context)
25 changes: 9 additions & 16 deletions hiku/extensions/prometheus.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@

from prometheus_client.metrics import MetricWrapperBase

from hiku.context import ExecutionContext
from hiku.extensions.base_extension import Extension
from hiku.telemetry.prometheus import (
AsyncGraphMetrics,
GraphMetrics,
GraphMetricsBase,
)
from hiku.extensions.base_extension import Extension, ExtensionFactory


class _PrometheusMetricsImpl(Extension):
class PrometheusMetrics(Extension):
def __init__(
self,
name: str,
Expand All @@ -27,26 +28,22 @@ def __init__(
self._name, metric=self._metric, ctx_var=ctx_var
)

def on_graph(self) -> Iterator[None]:
self.execution_context.transformers = (
self.execution_context.transformers + (self._transformer,)
def on_graph(self, execution_context: ExecutionContext) -> Iterator[None]:
execution_context.transformers = execution_context.transformers + (
self._transformer,
)
yield

def on_execute(self) -> Iterator[None]:
def on_execute(self, execution_context: ExecutionContext) -> Iterator[None]:
if self._ctx_var is None:
yield
else:
token = self._ctx_var.set(self.execution_context.context)
token = self._ctx_var.set(execution_context.context)
yield
self._ctx_var.reset(token)


class PrometheusMetrics(ExtensionFactory):
ext_class = _PrometheusMetricsImpl


class _PrometheusMetricsAsyncImpl(_PrometheusMetricsImpl):
class PrometheusMetricsAsync(PrometheusMetrics):
def __init__(
self,
name: str,
Expand All @@ -60,7 +57,3 @@ def __init__(
ctx_var=ctx_var,
transformer_cls=AsyncGraphMetrics,
)


class PrometheusMetricsAsync(ExtensionFactory):
ext_class = _PrometheusMetricsAsyncImpl
Loading

0 comments on commit 3aedc7d

Please sign in to comment.