Skip to content

Commit

Permalink
remove shell logic from sdk
Browse files Browse the repository at this point in the history
  • Loading branch information
mmabrouk committed Dec 16, 2024
1 parent bebc810 commit 45a089d
Showing 1 changed file with 31 additions and 102 deletions.
133 changes: 31 additions & 102 deletions agenta-cli/agenta/sdk/decorators/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,13 @@ def __init__(self, path, config_schema: BaseModel):
self.route_path = path

def __call__(self, f):
# If not running in Agenta, return the original function unchanged
if environ.get("AGENTA_RUNTIME") != "true":
return f

self.e = entrypoint(
f, route_path=self.route_path, config_schema=self.config_schema
)

return f


Expand Down Expand Up @@ -120,6 +123,23 @@ def __init__(
route_path="",
config_schema: Optional[BaseModel] = None,
):
self.func = func
self.route_path = route_path
self.config_schema = config_schema

def __call__(self, func=None):
if func is None:
func = self.func
route_path = self.route_path
config_schema = self.config_schema
else:
route_path = ""
config_schema = None

# If not running in Agenta, return the original function unchanged
if environ.get("AGENTA_RUNTIME") != "true":
return func

### --- Update Middleware --- #
try:
global _MIDDLEWARES # pylint: disable=global-statement
Expand Down Expand Up @@ -172,6 +192,11 @@ def __init__(
### --- Playground --- #
@wraps(func)
async def wrapper(*args, **kwargs) -> Any:
if environ.get("AGENTA_RUNTIME") != "true":
raise HTTPException(
status_code=403,
detail="This endpoint is only available when running in Agenta environment",
)
func_params, api_config_params = self.split_kwargs(kwargs, config_params)
self.ingest_files(func_params, ingestible_files)
if not config_schema:
Expand Down Expand Up @@ -234,6 +259,11 @@ async def wrapper(*args, **kwargs) -> Any:
### --- Deployed --- #
@wraps(func)
async def wrapper_deployed(*args, **kwargs) -> Any:
if environ.get("AGENTA_RUNTIME") != "true":
raise HTTPException(
status_code=403,
detail="This endpoint is only available when running in Agenta environment",
)
func_params = {
k: v
for k, v in kwargs.items()
Expand Down Expand Up @@ -306,14 +336,6 @@ async def wrapper_deployed(*args, **kwargs) -> Any:
config=route["config"],
)

if self.is_main_script(func) and route_path == "":
self.handle_terminal_run(
func,
func_signature.parameters, # type: ignore
config_params,
ingestible_files,
)

def extract_ingestible_files(
self,
func_signature: Signature,
Expand Down Expand Up @@ -615,99 +637,6 @@ def is_main_script(self, func: Callable) -> bool:
"""
return func.__module__ == "__main__"

def handle_terminal_run(
self,
func: Callable,
func_params: Dict[str, Parameter],
config_params: Dict[str, Any],
ingestible_files: Dict,
):
"""
Parses command line arguments and sets configuration when script is run from the terminal.
Args:
func_params (dict): A dictionary containing the function parameters and their annotations.
config_params (dict): A dictionary containing the configuration parameters.
ingestible_files (dict): A dictionary containing the files that should be ingested.
"""

# For required parameters, we add them as arguments
parser = ArgumentParser()
for name, param in func_params.items():
if name in ingestible_files:
parser.add_argument(name, type=str)
else:
parser.add_argument(name, type=param.annotation)

for name, param in config_params.items():
if type(param) is MultipleChoiceParam:
parser.add_argument(
f"--{name}",
type=str,
default=param.default,
choices=param.choices, # type: ignore
)
else:
parser.add_argument(
f"--{name}",
type=type(param),
default=param,
)

args = parser.parse_args()

# split the arg list into the arg in the app_param and
# the args from the sig.parameter
args_config_params = {k: v for k, v in vars(args).items() if k in config_params}
args_func_params = {
k: v for k, v in vars(args).items() if k not in config_params
}
for name in ingestible_files:
args_func_params[name] = InFile(
file_name=Path(args_func_params[name]).stem,
file_path=args_func_params[name],
)

# Update args_config_params with default values from config_params if not provided in command line arguments
args_config_params.update(
{
key: value
for key, value in config_params.items()
if key not in args_config_params
}
)

loop = get_event_loop()

with routing_context_manager(config=args_config_params):
result = loop.run_until_complete(
self.execute_function(
func,
True, # inline trace: True
**{"params": args_func_params, "config_params": args_config_params},
)
)

if result.trace:
log.info("\n========= Result =========\n")

log.info(f"trace_id: {result.trace['trace_id']}")
log.info(f"latency: {result.trace.get('latency')}")
log.info(f"cost: {result.trace.get('cost')}")
log.info(f"usage: {list(result.trace.get('usage', {}).values())}")

log.info(" ")
log.info("data:")
log.info(dumps(result.data, indent=2))

log.info(" ")
log.info("trace:")
log.info("----------------")
log.info(dumps(result.trace.get("spans", []), indent=2))
log.info("----------------")

log.info("\n==========================\n")

def override_config_in_schema(
self,
openapi_schema: dict,
Expand Down

0 comments on commit 45a089d

Please sign in to comment.