From b4059255045cdb7054a35bc338207e23c42ce068 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anton=20Gr=C3=BCbel?= Date: Thu, 21 Mar 2024 04:45:04 +0100 Subject: [PATCH] refactor: add mypy and fix typing issues (#72) * add mypy and fix typing issues Signed-off-by: gruebel * imrpove callable typing Signed-off-by: gruebel --------- Signed-off-by: gruebel --- .pre-commit-config.yaml | 10 +++ .../contrib/hook/opentelemetry/__init__.py | 6 +- mypy.ini | 17 ++++++ .../contrib/provider/flagd/config.py | 8 ++- .../contrib/provider/flagd/provider.py | 61 +++++++++---------- 5 files changed, 66 insertions(+), 36 deletions(-) create mode 100644 mypy.ini diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e59fa038..508d68c5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,3 +14,13 @@ repos: - id: check-yaml - id: trailing-whitespace - id: check-merge-conflict + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.9.0 + hooks: + - id: mypy + additional_dependencies: + - openfeature-sdk>=0.4.0 + - opentelemetry-api + - types-protobuf + exclude: proto|tests diff --git a/hooks/openfeature-hooks-opentelemetry/src/openfeature/contrib/hook/opentelemetry/__init__.py b/hooks/openfeature-hooks-opentelemetry/src/openfeature/contrib/hook/opentelemetry/__init__.py index bfe454ad..7a813bdb 100644 --- a/hooks/openfeature-hooks-opentelemetry/src/openfeature/contrib/hook/opentelemetry/__init__.py +++ b/hooks/openfeature-hooks-opentelemetry/src/openfeature/contrib/hook/opentelemetry/__init__.py @@ -16,7 +16,7 @@ class EventAttributes: class TracingHook(Hook): def after( self, hook_context: HookContext, details: FlagEvaluationDetails, hints: dict - ): + ) -> None: current_span = trace.get_current_span() variant = details.variant @@ -38,6 +38,8 @@ def after( current_span.add_event(OTEL_EVENT_NAME, event_attributes) - def error(self, hook_context: HookContext, exception: Exception, hints: dict): + def error( + self, hook_context: HookContext, exception: Exception, hints: dict + ) -> None: current_span = trace.get_current_span() current_span.record_exception(exception) diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..4c24b416 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,17 @@ +[mypy] +files = hooks,providers +exclude = proto|tests +untyped_calls_exclude = flagd.proto + +namespace_packages = True +explicit_package_bases = True +local_partial_types = True +pretty = True +strict = True +disallow_any_generics = False + +[mypy-flagd.proto.*] +follow_imports = silent + +[mypy-grpc] +ignore_missing_imports = True diff --git a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/config.py b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/config.py index dd841bcb..e2db98a5 100644 --- a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/config.py +++ b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/config.py @@ -1,12 +1,16 @@ import os import typing +T = typing.TypeVar("T") + def str_to_bool(val: str) -> bool: return val.lower() == "true" -def env_or_default(env_var, default, cast=None): +def env_or_default( + env_var: str, default: T, cast: typing.Optional[typing.Callable[[str], T]] = None +) -> typing.Union[str, T]: val = os.environ.get(env_var) if val is None: return default @@ -17,7 +21,7 @@ class Config: def __init__( self, host: typing.Optional[str] = None, - port: typing.Optional[str] = None, + port: typing.Optional[int] = None, tls: typing.Optional[bool] = None, timeout: typing.Optional[int] = None, ): diff --git a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/provider.py b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/provider.py index 570e819c..ea91a1d1 100644 --- a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/provider.py +++ b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/provider.py @@ -22,8 +22,6 @@ """ import typing -from dataclasses import dataclass -from numbers import Number import grpc from google.protobuf.struct_pb2 import Struct @@ -36,17 +34,15 @@ ParseError, TypeMismatchError, ) -from openfeature.flag_evaluation import FlagEvaluationDetails +from openfeature.flag_evaluation import FlagResolutionDetails +from openfeature.provider.metadata import Metadata from openfeature.provider.provider import AbstractProvider from .config import Config from .flag_type import FlagType from .proto.schema.v1 import schema_pb2, schema_pb2_grpc - -@dataclass -class Metadata: - name: str +T = typing.TypeVar("T") class FlagdProvider(AbstractProvider): @@ -78,10 +74,10 @@ def __init__( self.channel = channel_factory(f"{self.config.host}:{self.config.port}") self.stub = schema_pb2_grpc.ServiceStub(self.channel) - def shutdown(self): + def shutdown(self) -> None: self.channel.close() - def get_metadata(self): + def get_metadata(self) -> Metadata: """Returns provider metadata""" return Metadata(name="FlagdProvider") @@ -89,74 +85,74 @@ def resolve_boolean_details( self, key: str, default_value: bool, - evaluation_context: EvaluationContext = None, - ): + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: return self._resolve(key, FlagType.BOOLEAN, default_value, evaluation_context) def resolve_string_details( self, key: str, default_value: str, - evaluation_context: EvaluationContext = None, - ): + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: return self._resolve(key, FlagType.STRING, default_value, evaluation_context) def resolve_float_details( self, key: str, - default_value: Number, - evaluation_context: EvaluationContext = None, - ): + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: return self._resolve(key, FlagType.FLOAT, default_value, evaluation_context) def resolve_integer_details( self, key: str, - default_value: Number, - evaluation_context: EvaluationContext = None, - ): + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: return self._resolve(key, FlagType.INTEGER, default_value, evaluation_context) def resolve_object_details( self, key: str, default_value: typing.Union[dict, list], - evaluation_context: EvaluationContext = None, - ): + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[typing.Union[dict, list]]: return self._resolve(key, FlagType.OBJECT, default_value, evaluation_context) def _resolve( self, flag_key: str, flag_type: FlagType, - default_value: typing.Any, - evaluation_context: EvaluationContext, - ): + default_value: T, + evaluation_context: typing.Optional[EvaluationContext], + ) -> FlagResolutionDetails[T]: context = self._convert_context(evaluation_context) call_args = {"timeout": self.config.timeout} try: if flag_type == FlagType.BOOLEAN: - request = schema_pb2.ResolveBooleanRequest( + request = schema_pb2.ResolveBooleanRequest( # type:ignore[attr-defined] flag_key=flag_key, context=context ) response = self.stub.ResolveBoolean(request, **call_args) elif flag_type == FlagType.STRING: - request = schema_pb2.ResolveStringRequest( + request = schema_pb2.ResolveStringRequest( # type:ignore[attr-defined] flag_key=flag_key, context=context ) response = self.stub.ResolveString(request, **call_args) elif flag_type == FlagType.OBJECT: - request = schema_pb2.ResolveObjectRequest( + request = schema_pb2.ResolveObjectRequest( # type:ignore[attr-defined] flag_key=flag_key, context=context ) response = self.stub.ResolveObject(request, **call_args) elif flag_type == FlagType.FLOAT: - request = schema_pb2.ResolveFloatRequest( + request = schema_pb2.ResolveFloatRequest( # type:ignore[attr-defined] flag_key=flag_key, context=context ) response = self.stub.ResolveFloat(request, **call_args) elif flag_type == FlagType.INTEGER: - request = schema_pb2.ResolveIntRequest( + request = schema_pb2.ResolveIntRequest( # type:ignore[attr-defined] flag_key=flag_key, context=context ) response = self.stub.ResolveInt(request, **call_args) @@ -176,14 +172,15 @@ def _resolve( raise GeneralError(message) from e # Got a valid flag and valid type. Return it. - return FlagEvaluationDetails( - flag_key=flag_key, + return FlagResolutionDetails( value=response.value, reason=response.reason, variant=response.variant, ) - def _convert_context(self, evaluation_context: EvaluationContext): + def _convert_context( + self, evaluation_context: typing.Optional[EvaluationContext] + ) -> Struct: s = Struct() if evaluation_context: try: