diff --git a/aws_embedded_metrics/metric_scope/__init__.py b/aws_embedded_metrics/metric_scope/__init__.py index 47044bc..a970460 100644 --- a/aws_embedded_metrics/metric_scope/__init__.py +++ b/aws_embedded_metrics/metric_scope/__init__.py @@ -11,13 +11,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Callable, TypeVar, cast from aws_embedded_metrics.logger.metrics_logger_factory import create_metrics_logger import inspect import asyncio from functools import wraps +F = TypeVar('F', bound=Callable[..., Any]) -def metric_scope(fn): # type: ignore + +def metric_scope(fn: F) -> F: if asyncio.iscoroutinefunction(fn): @@ -33,7 +36,7 @@ async def wrapper(*args, **kwargs): # type: ignore finally: await logger.flush() - return wrapper + return cast(F, wrapper) else: @wraps(fn) @@ -49,4 +52,4 @@ def wrapper(*args, **kwargs): # type: ignore loop = asyncio.get_event_loop() loop.run_until_complete(logger.flush()) - return wrapper + return cast(F, wrapper)