Skip to content

Commit

Permalink
refactor: add mypy and fix typing issues (#72)
Browse files Browse the repository at this point in the history
* add mypy and fix typing issues

Signed-off-by: gruebel <[email protected]>

* imrpove callable typing

Signed-off-by: gruebel <[email protected]>

---------

Signed-off-by: gruebel <[email protected]>
  • Loading branch information
gruebel authored Mar 21, 2024
1 parent 9411d0f commit b405925
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 36 deletions.
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,13 @@ repos:
- id: check-yaml
- id: trailing-whitespace
- id: check-merge-conflict

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.9.0
hooks:
- id: mypy
additional_dependencies:
- openfeature-sdk>=0.4.0
- opentelemetry-api
- types-protobuf
exclude: proto|tests
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class EventAttributes:
class TracingHook(Hook):
def after(
self, hook_context: HookContext, details: FlagEvaluationDetails, hints: dict
):
) -> None:
current_span = trace.get_current_span()

variant = details.variant
Expand All @@ -38,6 +38,8 @@ def after(

current_span.add_event(OTEL_EVENT_NAME, event_attributes)

def error(self, hook_context: HookContext, exception: Exception, hints: dict):
def error(
self, hook_context: HookContext, exception: Exception, hints: dict
) -> None:
current_span = trace.get_current_span()
current_span.record_exception(exception)
17 changes: 17 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[mypy]
files = hooks,providers
exclude = proto|tests
untyped_calls_exclude = flagd.proto

namespace_packages = True
explicit_package_bases = True
local_partial_types = True
pretty = True
strict = True
disallow_any_generics = False

[mypy-flagd.proto.*]
follow_imports = silent

[mypy-grpc]
ignore_missing_imports = True
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import os
import typing

T = typing.TypeVar("T")


def str_to_bool(val: str) -> bool:
return val.lower() == "true"


def env_or_default(env_var, default, cast=None):
def env_or_default(
env_var: str, default: T, cast: typing.Optional[typing.Callable[[str], T]] = None
) -> typing.Union[str, T]:
val = os.environ.get(env_var)
if val is None:
return default
Expand All @@ -17,7 +21,7 @@ class Config:
def __init__(
self,
host: typing.Optional[str] = None,
port: typing.Optional[str] = None,
port: typing.Optional[int] = None,
tls: typing.Optional[bool] = None,
timeout: typing.Optional[int] = None,
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
"""

import typing
from dataclasses import dataclass
from numbers import Number

import grpc
from google.protobuf.struct_pb2 import Struct
Expand All @@ -36,17 +34,15 @@
ParseError,
TypeMismatchError,
)
from openfeature.flag_evaluation import FlagEvaluationDetails
from openfeature.flag_evaluation import FlagResolutionDetails
from openfeature.provider.metadata import Metadata
from openfeature.provider.provider import AbstractProvider

from .config import Config
from .flag_type import FlagType
from .proto.schema.v1 import schema_pb2, schema_pb2_grpc


@dataclass
class Metadata:
name: str
T = typing.TypeVar("T")


class FlagdProvider(AbstractProvider):
Expand Down Expand Up @@ -78,85 +74,85 @@ def __init__(
self.channel = channel_factory(f"{self.config.host}:{self.config.port}")
self.stub = schema_pb2_grpc.ServiceStub(self.channel)

def shutdown(self):
def shutdown(self) -> None:
self.channel.close()

def get_metadata(self):
def get_metadata(self) -> Metadata:
"""Returns provider metadata"""
return Metadata(name="FlagdProvider")

def resolve_boolean_details(
self,
key: str,
default_value: bool,
evaluation_context: EvaluationContext = None,
):
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[bool]:
return self._resolve(key, FlagType.BOOLEAN, default_value, evaluation_context)

def resolve_string_details(
self,
key: str,
default_value: str,
evaluation_context: EvaluationContext = None,
):
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[str]:
return self._resolve(key, FlagType.STRING, default_value, evaluation_context)

def resolve_float_details(
self,
key: str,
default_value: Number,
evaluation_context: EvaluationContext = None,
):
default_value: float,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[float]:
return self._resolve(key, FlagType.FLOAT, default_value, evaluation_context)

def resolve_integer_details(
self,
key: str,
default_value: Number,
evaluation_context: EvaluationContext = None,
):
default_value: int,
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[int]:
return self._resolve(key, FlagType.INTEGER, default_value, evaluation_context)

def resolve_object_details(
self,
key: str,
default_value: typing.Union[dict, list],
evaluation_context: EvaluationContext = None,
):
evaluation_context: typing.Optional[EvaluationContext] = None,
) -> FlagResolutionDetails[typing.Union[dict, list]]:
return self._resolve(key, FlagType.OBJECT, default_value, evaluation_context)

def _resolve(
self,
flag_key: str,
flag_type: FlagType,
default_value: typing.Any,
evaluation_context: EvaluationContext,
):
default_value: T,
evaluation_context: typing.Optional[EvaluationContext],
) -> FlagResolutionDetails[T]:
context = self._convert_context(evaluation_context)
call_args = {"timeout": self.config.timeout}
try:
if flag_type == FlagType.BOOLEAN:
request = schema_pb2.ResolveBooleanRequest(
request = schema_pb2.ResolveBooleanRequest( # type:ignore[attr-defined]
flag_key=flag_key, context=context
)
response = self.stub.ResolveBoolean(request, **call_args)
elif flag_type == FlagType.STRING:
request = schema_pb2.ResolveStringRequest(
request = schema_pb2.ResolveStringRequest( # type:ignore[attr-defined]
flag_key=flag_key, context=context
)
response = self.stub.ResolveString(request, **call_args)
elif flag_type == FlagType.OBJECT:
request = schema_pb2.ResolveObjectRequest(
request = schema_pb2.ResolveObjectRequest( # type:ignore[attr-defined]
flag_key=flag_key, context=context
)
response = self.stub.ResolveObject(request, **call_args)
elif flag_type == FlagType.FLOAT:
request = schema_pb2.ResolveFloatRequest(
request = schema_pb2.ResolveFloatRequest( # type:ignore[attr-defined]
flag_key=flag_key, context=context
)
response = self.stub.ResolveFloat(request, **call_args)
elif flag_type == FlagType.INTEGER:
request = schema_pb2.ResolveIntRequest(
request = schema_pb2.ResolveIntRequest( # type:ignore[attr-defined]
flag_key=flag_key, context=context
)
response = self.stub.ResolveInt(request, **call_args)
Expand All @@ -176,14 +172,15 @@ def _resolve(
raise GeneralError(message) from e

# Got a valid flag and valid type. Return it.
return FlagEvaluationDetails(
flag_key=flag_key,
return FlagResolutionDetails(
value=response.value,
reason=response.reason,
variant=response.variant,
)

def _convert_context(self, evaluation_context: EvaluationContext):
def _convert_context(
self, evaluation_context: typing.Optional[EvaluationContext]
) -> Struct:
s = Struct()
if evaluation_context:
try:
Expand Down

0 comments on commit b405925

Please sign in to comment.