Skip to content

Commit

Permalink
Refactor (agenta_decorator.py): clean up code that had FuncResponse type
Browse files Browse the repository at this point in the history
  • Loading branch information
aybruhm committed May 20, 2024
1 parent 0b783fb commit abaa09c
Showing 1 changed file with 8 additions and 13 deletions.
21 changes: 8 additions & 13 deletions agenta-cli/agenta/sdk/agenta_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
TextParam,
MessagesInput,
FileInputURL,
FuncResponse,
BinaryParam,
)

Expand Down Expand Up @@ -97,7 +96,7 @@ async def wrapper(*args, **kwargs) -> Any:

# End trace recording
tracing.end_recording(
outputs=llm_result.dict(),
outputs={"message": llm_result},
span=tracing.active_trace,
)
return llm_result
Expand Down Expand Up @@ -130,22 +129,22 @@ async def wrapper_deployed(*args, **kwargs) -> Any:

# End trace recording
tracing.end_recording(
outputs=llm_result.dict(),
outputs={"message": llm_result},
span=tracing.active_trace,
)
return llm_result

update_function_signature(wrapper, func_signature, config_params, ingestible_files)
route = f"/{endpoint_name}"
app.post(route, response_model=FuncResponse)(wrapper)
app.post(route)(wrapper)

update_deployed_function_signature(
wrapper_deployed,
func_signature,
ingestible_files,
)
route_deployed = f"/{endpoint_name}_deployed"
app.post(route_deployed, response_model=FuncResponse)(wrapper_deployed)
app.post(route_deployed)(wrapper_deployed)
override_schema(
openapi_schema=app.openapi(),
func_name=func.__name__,
Expand Down Expand Up @@ -205,24 +204,20 @@ async def execute_function(func: Callable[..., Any], *args, **func_params):
it awaits their execution.
"""
is_coroutine_function = inspect.iscoroutinefunction(func)
start_time = time.perf_counter()
if is_coroutine_function:
result = await func(*args, **func_params["params"])
else:
result = func(*args, **func_params["params"])

end_time = time.perf_counter()
latency = end_time - start_time

if isinstance(result, Context):
save_context(result)
if isinstance(result, Dict):
return FuncResponse(**result, latency=round(latency, 4))
if isinstance(result, Dict) and "message" in result:
return result["message"]
if isinstance(result, str):
return FuncResponse(message=result, latency=round(latency, 4)) # type: ignore
return result
except Exception as e:
handle_exception(e)
return FuncResponse(message="Unexpected error occurred", latency=0) # type: ignore
return "Unexpected error occurred"


def handle_exception(e: Exception):
Expand Down

0 comments on commit abaa09c

Please sign in to comment.