diff --git a/agenta-cli/agenta/__init__.py b/agenta-cli/agenta/__init__.py index 53600a4a1e..71097cd836 100644 --- a/agenta-cli/agenta/__init__.py +++ b/agenta-cli/agenta/__init__.py @@ -8,7 +8,6 @@ DictInput, MultipleChoice, FloatParam, - InFile, IntParam, MultipleChoiceParam, GroupedMultipleChoiceParam, diff --git a/agenta-cli/agenta/sdk/__init__.py b/agenta-cli/agenta/sdk/__init__.py index 4fc475ef45..6d3c4da842 100644 --- a/agenta-cli/agenta/sdk/__init__.py +++ b/agenta-cli/agenta/sdk/__init__.py @@ -8,7 +8,6 @@ DictInput, MultipleChoice, FloatParam, - InFile, IntParam, MultipleChoiceParam, GroupedMultipleChoiceParam, diff --git a/agenta-cli/agenta/sdk/decorators/routing.py b/agenta-cli/agenta/sdk/decorators/routing.py index 45229ea4bc..92d9a3572f 100644 --- a/agenta-cli/agenta/sdk/decorators/routing.py +++ b/agenta-cli/agenta/sdk/decorators/routing.py @@ -31,7 +31,6 @@ from agenta.sdk.types import ( DictInput, FloatParam, - InFile, IntParam, MultipleChoiceParam, MultipleChoice, @@ -139,7 +138,6 @@ def __init__( self.config_schema = config_schema signature_parameters = signature(func).parameters - ingestible_files = self.extract_ingestible_files() config, default_parameters = self.parse_config() ### --- Middleware --- # @@ -168,15 +166,9 @@ async def run_wrapper(request: Request, *args, **kwargs) -> Any: kwargs, _ = self.split_kwargs(kwargs, default_parameters) - # TODO: Why is this not used in the run_wrapper? - # self.ingest_files(kwargs, ingestible_files) - return await self.execute_wrapper(request, False, *args, **kwargs) - self.update_run_wrapper_signature( - wrapper=run_wrapper, - ingestible_files=ingestible_files, - ) + self.update_run_wrapper_signature(wrapper=run_wrapper) run_route = f"{entrypoint._run_path}{route_path}" app.post(run_route, response_model=BaseResponse)(run_wrapper) @@ -197,14 +189,10 @@ async def test_wrapper(request: Request, *args, **kwargs) -> Any: request.state.config["parameters"] = parameters - # TODO: Why is this only used in the test_wrapper? - self.ingest_files(kwargs, ingestible_files) - return await self.execute_wrapper(request, True, *args, **kwargs) self.update_test_wrapper_signature( wrapper=test_wrapper, - ingestible_files=ingestible_files, config_class=config, config_dict=default_parameters, ) @@ -234,11 +222,7 @@ async def test_wrapper(request: Request, *args, **kwargs) -> Any: { "func": func.__name__, "endpoint": test_route, - "params": ( - {**default_parameters, **signature_parameters} - if not config - else signature_parameters - ), + "params": signature_parameters, "config": config, } ) @@ -264,14 +248,7 @@ async def test_wrapper(request: Request, *args, **kwargs) -> Any: openapi_schema = app.openapi() for _route in entrypoint.routes: - self.override_schema( - openapi_schema=openapi_schema, - func_name=_route["func"], - endpoint=_route["endpoint"], - params=_route["params"], - ) - - if _route["config"] is not None: # new SDK version + if _route["config"] is not None: self.override_config_in_schema( openapi_schema=openapi_schema, func_name=_route["func"], @@ -280,15 +257,6 @@ async def test_wrapper(request: Request, *args, **kwargs) -> Any: ) ### --------------- # - def extract_ingestible_files(self) -> Dict[str, Parameter]: - """Extract parameters annotated as InFile from function signature.""" - - return { - name: param - for name, param in signature(self.func).parameters.items() - if param.annotation is InFile - } - def parse_config(self) -> Dict[str, Any]: config = None default_parameters = ag.config.all() @@ -316,27 +284,6 @@ def split_kwargs( return arguments, parameters - def ingest_file( - self, - upfile: UploadFile, - ): - temp_file = NamedTemporaryFile(delete=False) - temp_file.write(upfile.file.read()) - temp_file.close() - - return InFile(file_name=upfile.filename, file_path=temp_file.name) - - def ingest_files( - self, - func_params: Dict[str, Any], - ingestible_files: Dict[str, Parameter], - ) -> None: - """Ingest files specified in function parameters.""" - - for name in ingestible_files: - if name in func_params and func_params[name] is not None: - func_params[name] = self.ingest_file(func_params[name]) - async def execute_wrapper( self, request: Request, @@ -541,7 +488,6 @@ def update_test_wrapper_signature( wrapper: Callable[..., Any], config_class: Type[BaseModel], # TODO: change to our type config_dict: Dict[str, Any], - ingestible_files: Dict[str, Parameter], ) -> None: """Update the function signature to include new parameters.""" @@ -550,31 +496,18 @@ def update_test_wrapper_signature( self.add_config_params_to_parser(updated_params, config_class) else: self.deprecated_add_config_params_to_parser(updated_params, config_dict) - self.add_func_params_to_parser(updated_params, ingestible_files) + self.add_func_params_to_parser(updated_params) self.update_wrapper_signature(wrapper, updated_params) self.add_request_to_signature(wrapper) def update_run_wrapper_signature( self, wrapper: Callable[..., Any], - ingestible_files: Dict[str, Parameter], ) -> None: """Update the function signature to include new parameters.""" updated_params: List[Parameter] = [] - self.add_func_params_to_parser(updated_params, ingestible_files) - for param in [ - "config", - "environment", - ]: # we add the config and environment parameters - updated_params.append( - Parameter( - name=param, - kind=Parameter.KEYWORD_ONLY, - default=Body(None), - annotation=str, - ) - ) + self.add_func_params_to_parser(updated_params) self.update_wrapper_signature(wrapper, updated_params) self.add_request_to_signature(wrapper) @@ -614,33 +547,24 @@ def deprecated_add_config_params_to_parser( ) ) - def add_func_params_to_parser( - self, - updated_params: list, - ingestible_files: Dict[str, Parameter], - ) -> None: + def add_func_params_to_parser(self, updated_params: list) -> None: """Add function parameters to function signature.""" for name, param in signature(self.func).parameters.items(): - if name in ingestible_files: - updated_params.append( - Parameter(name, param.kind, annotation=UploadFile) - ) - else: - assert ( - len(param.default.__class__.__bases__) == 1 - ), f"Inherited standard type of {param.default.__class__} needs to be one." - updated_params.append( - Parameter( - name, - Parameter.KEYWORD_ONLY, - default=Body(..., embed=True), - annotation=param.default.__class__.__bases__[ - 0 - ], # determines and get the base (parent/inheritance) type of the sdk-type at run-time. \ - # E.g __class__ is ag.MessagesInput() and accessing it parent type will return (,), \ - # thus, why we are accessing the first item. - ) + assert ( + len(param.default.__class__.__bases__) == 1 + ), f"Inherited standard type of {param.default.__class__} needs to be one." + updated_params.append( + Parameter( + name, + Parameter.KEYWORD_ONLY, + default=Body(..., embed=True), + annotation=param.default.__class__.__bases__[ + 0 + ], # determines and get the base (parent/inheritance) type of the sdk-type at run-time. \ + # E.g __class__ is ag.MessagesInput() and accessing it parent type will return (,), \ + # thus, why we are accessing the first item. ) + ) def override_config_in_schema( self, @@ -715,193 +639,3 @@ def override_config_in_schema( if isinstance(constraint, Lt) ) schema_to_override[param_name]["exclusiveMaximum"] = max_value - - def override_schema( - self, openapi_schema: dict, func_name: str, endpoint: str, params: dict - ): - """ - Overrides the default openai schema generated by fastapi with additional information about: - - The choices available for each MultipleChoiceParam instance - - The min and max values for each FloatParam instance - - The min and max values for each IntParam instance - - The default value for DictInput instance - - The default value for MessagesParam instance - - The default value for FileInputURL instance - - The default value for BinaryParam instance - - ... [PLEASE ADD AT EACH CHANGE] - - Args: - openapi_schema (dict): The openapi schema generated by fastapi - func (str): The name of the function to override - endpoint (str): The name of the endpoint to override - params (dict(param_name, param_val)): The dictionary of the parameters for the function - """ - - def find_in_schema( - schema_type_properties: dict, schema: dict, param_name: str, xparam: str - ): - """Finds a parameter in the schema based on its name and x-parameter value""" - for _, value in schema.items(): - value_title_lower = str(value.get("title")).lower() - value_title = ( - "_".join(value_title_lower.split()) - if len(value_title_lower.split()) >= 2 - else value_title_lower - ) - - if ( - isinstance(value, dict) - and schema_type_properties.get("x-parameter") == xparam - and value_title == param_name - ): - # this will update the default type schema with the properties gotten - # from the schema type (param_val) __schema_properties__ classmethod - for type_key, type_value in schema_type_properties.items(): - # BEFORE: - # value = {'temperature': {'title': 'Temperature'}} - value[type_key] = type_value - # AFTER: - # value = {'temperature': { "type": "number", "title": "Temperature", "x-parameter": "float" }} - return value - - def get_type_from_param(param_val): - param_type = "string" - annotation = param_val.annotation - - if annotation == int: - param_type = "integer" - elif annotation == float: - param_type = "number" - elif annotation == dict: - param_type = "object" - elif annotation == bool: - param_type = "boolean" - elif annotation == list: - param_type = "list" - elif annotation == str: - param_type = "string" - else: - print("ERROR, unhandled annotation:", annotation) - - return param_type - - # Goes from '/some/path' to 'some_path' - endpoint = endpoint[1:].replace("/", "_") - - schema_to_override = openapi_schema["components"]["schemas"][ - f"Body_{func_name}_{endpoint}_post" - ]["properties"] - - for param_name, param_val in params.items(): - if isinstance(param_val, GroupedMultipleChoiceParam): - subschema = find_in_schema( - param_val.__schema_type_properties__(), - schema_to_override, - param_name, - "grouped_choice", - ) - assert ( - subschema - ), f"GroupedMultipleChoiceParam '{param_name}' is in the parameters but could not be found in the openapi.json" - subschema["choices"] = param_val.choices # type: ignore - subschema["default"] = param_val.default # type: ignore - - elif isinstance(param_val, MultipleChoiceParam): - subschema = find_in_schema( - param_val.__schema_type_properties__(), - schema_to_override, - param_name, - "choice", - ) - default = str(param_val) - param_choices = param_val.choices # type: ignore - choices = ( - [default] + param_choices - if param_val not in param_choices - else param_choices - ) - subschema["enum"] = choices - subschema["default"] = ( - default if default in param_choices else choices[0] - ) - - elif isinstance(param_val, FloatParam): - subschema = find_in_schema( - param_val.__schema_type_properties__(), - schema_to_override, - param_name, - "float", - ) - subschema["minimum"] = param_val.minval # type: ignore - subschema["maximum"] = param_val.maxval # type: ignore - subschema["default"] = param_val - - elif isinstance(param_val, IntParam): - subschema = find_in_schema( - param_val.__schema_type_properties__(), - schema_to_override, - param_name, - "int", - ) - subschema["minimum"] = param_val.minval # type: ignore - subschema["maximum"] = param_val.maxval # type: ignore - subschema["default"] = param_val - - elif isinstance(param_val, Parameter) and param_val.annotation is DictInput: - subschema = find_in_schema( - param_val.annotation.__schema_type_properties__(), - schema_to_override, - param_name, - "dict", - ) - subschema["default"] = param_val.default["default_keys"] - - elif isinstance(param_val, TextParam): - subschema = find_in_schema( - param_val.__schema_type_properties__(), - schema_to_override, - param_name, - "text", - ) - subschema["default"] = param_val - - elif ( - isinstance(param_val, Parameter) - and param_val.annotation is MessagesInput - ): - subschema = find_in_schema( - param_val.annotation.__schema_type_properties__(), - schema_to_override, - param_name, - "messages", - ) - subschema["default"] = param_val.default - - elif ( - isinstance(param_val, Parameter) - and param_val.annotation is FileInputURL - ): - subschema = find_in_schema( - param_val.annotation.__schema_type_properties__(), - schema_to_override, - param_name, - "file_url", - ) - subschema["default"] = "https://example.com" - - elif isinstance(param_val, BinaryParam): - subschema = find_in_schema( - param_val.__schema_type_properties__(), - schema_to_override, - param_name, - "bool", - ) - subschema["default"] = param_val.default # type: ignore - else: - subschema = { - "title": str(param_name).capitalize(), - "type": get_type_from_param(param_val), - } - if param_val.default != _empty: - subschema["default"] = param_val.default # type: ignore - schema_to_override[param_name] = subschema diff --git a/agenta-cli/agenta/sdk/decorators/tracing.py b/agenta-cli/agenta/sdk/decorators/tracing.py index 6d31243963..f368509fc6 100644 --- a/agenta-cli/agenta/sdk/decorators/tracing.py +++ b/agenta-cli/agenta/sdk/decorators/tracing.py @@ -246,7 +246,9 @@ def _redact( not in ( ignore if isinstance(ignore, list) - else io.keys() if ignore is True else [] + else io.keys() + if ignore is True + else [] ) } diff --git a/agenta-cli/agenta/sdk/types.py b/agenta-cli/agenta/sdk/types.py index cefe92825a..8cfa027afc 100644 --- a/agenta-cli/agenta/sdk/types.py +++ b/agenta-cli/agenta/sdk/types.py @@ -13,12 +13,6 @@ class MultipleChoice: choices: Union[List[str], Dict[str, List[str]]] -class InFile: - def __init__(self, file_name: str, file_path: str): - self.file_name = file_name - self.file_path = file_path - - class LLMTokenUsage(BaseModel): completion_tokens: int prompt_tokens: int diff --git a/agenta-cli/debugging/simple-app/_app.py b/agenta-cli/debugging/simple-app/_app.py index 622e817009..9e7337c9ac 100644 --- a/agenta-cli/debugging/simple-app/_app.py +++ b/agenta-cli/debugging/simple-app/_app.py @@ -13,6 +13,7 @@ ag.init() + class Prompt(BaseModel): prompt_template: str = Field(default=default_prompt) model_config = { @@ -20,11 +21,12 @@ class Prompt(BaseModel): "x-component-type": "prompt-playground", "x-component-props": { "supportedModels": ["gpt-3", "gpt-4"], - "allowTemplating": True - } + "allowTemplating": True, + }, } } + class Message(BaseModel): role: str = Field(default="user") content: str = Field(default="") @@ -33,11 +35,12 @@ class Message(BaseModel): "x-component-type": "message", "x-component-props": { "supportedModels": ["gpt-3", "gpt-4"], - "allowTemplating": True - } + "allowTemplating": True, + }, } } + class BabyConfig(BaseModel): temperature: float = Field(default=0.2) prompt_template: str = Field(default=default_prompt) @@ -45,11 +48,10 @@ class BabyConfig(BaseModel): default="asd" ) prompt: Prompt = Field(default=Prompt()) - @ag.route("/", config_schema=BabyConfig) -def generate(country: str, gender: str, messages:Message) -> str: +def generate(country: str, gender: str, messages: Message) -> str: """ Generate a baby name based on the given country and gender. diff --git a/agenta-cli/debugging/simple-app/agenta/sdk/agenta_init.py b/agenta-cli/debugging/simple-app/agenta/sdk/agenta_init.py index 9751706b44..db0a27580a 100644 --- a/agenta-cli/debugging/simple-app/agenta/sdk/agenta_init.py +++ b/agenta-cli/debugging/simple-app/agenta/sdk/agenta_init.py @@ -10,6 +10,8 @@ from agenta.client.exceptions import APIRequestError print(".DS_Store") + + class AgentaSingleton: """Singleton class to save all the "global variables" for the sdk.""" diff --git a/agenta-cli/debugging/simple-app/agenta/sdk/prompt.py b/agenta-cli/debugging/simple-app/agenta/sdk/prompt.py index 479a85c948..ea447f034f 100644 --- a/agenta-cli/debugging/simple-app/agenta/sdk/prompt.py +++ b/agenta-cli/debugging/simple-app/agenta/sdk/prompt.py @@ -1,11 +1,22 @@ from pydantic import BaseModel, Field from typing import Optional, List, Dict, Any + class Prompt(BaseModel): """A pre-built BaseModel for prompt configuration""" + system_message: str = Field(default="", description="System message for the prompt") user_message: str = Field(default="", description="User message template") - temperature: float = Field(default=0.7, ge=0.0, le=1.0, description="Temperature for text generation") - max_tokens: Optional[int] = Field(default=None, ge=1, description="Maximum number of tokens to generate") - stop_sequences: Optional[List[str]] = Field(default=None, description="List of sequences where the model should stop generating") - model_parameters: Optional[Dict[str, Any]] = Field(default=None, description="Additional model-specific parameters") + temperature: float = Field( + default=0.7, ge=0.0, le=1.0, description="Temperature for text generation" + ) + max_tokens: Optional[int] = Field( + default=None, ge=1, description="Maximum number of tokens to generate" + ) + stop_sequences: Optional[List[str]] = Field( + default=None, + description="List of sequences where the model should stop generating", + ) + model_parameters: Optional[Dict[str, Any]] = Field( + default=None, description="Additional model-specific parameters" + ) diff --git a/services/chat-live-sdk/_app.py b/services/chat-live-sdk/_app.py index af082c8b38..f9bd72131e 100644 --- a/services/chat-live-sdk/_app.py +++ b/services/chat-live-sdk/_app.py @@ -2,12 +2,15 @@ import agenta as ag from supported_llm_models import get_all_supported_llm_models import os + # Import mock if MOCK_LLM environment variable is set if os.getenv("MOCK_LLM", True): from mock_litellm import MockLiteLLM + litellm = MockLiteLLM() else: import litellm + litellm.drop_params = True litellm.callbacks = [ag.callbacks.litellm_handler()] diff --git a/services/chat-live-sdk/main.py b/services/chat-live-sdk/main.py index fec3a17083..c17d626bc3 100644 --- a/services/chat-live-sdk/main.py +++ b/services/chat-live-sdk/main.py @@ -5,4 +5,10 @@ if __name__ == "__main__": - run("agenta:app", host="0.0.0.0", port=80, reload=True, reload_dirs=[".", "/agenta-cli"]) + run( + "agenta:app", + host="0.0.0.0", + port=80, + reload=True, + reload_dirs=[".", "/agenta-cli"], + ) diff --git a/services/chat-live-sdk/mock_litellm.py b/services/chat-live-sdk/mock_litellm.py index 1f4cd973f3..a5b57a68cc 100644 --- a/services/chat-live-sdk/mock_litellm.py +++ b/services/chat-live-sdk/mock_litellm.py @@ -1,6 +1,7 @@ from typing import Dict, Any, List from dataclasses import dataclass + @dataclass class MockUsage: prompt_tokens: int = 10 @@ -11,28 +12,39 @@ def dict(self): return { "prompt_tokens": self.prompt_tokens, "completion_tokens": self.completion_tokens, - "total_tokens": self.total_tokens + "total_tokens": self.total_tokens, } + @dataclass class MockMessage: content: str = "This is a mock response from the LLM." + @dataclass class MockChoice: message: MockMessage = MockMessage() + @dataclass class MockCompletion: choices: List[MockChoice] = None usage: MockUsage = None - + def __init__(self): self.choices = [MockChoice()] self.usage = MockUsage() + class MockLiteLLM: - async def acompletion(self, model: str, messages: List[Dict[str, Any]], temperature: float, max_tokens: int = None, **kwargs) -> MockCompletion: + async def acompletion( + self, + model: str, + messages: List[Dict[str, Any]], + temperature: float, + max_tokens: int = None, + **kwargs + ) -> MockCompletion: return MockCompletion() class cost_calculator: diff --git a/services/chat-old-sdk/_app.py b/services/chat-old-sdk/_app.py index 76698ccce0..0a0ff4f0b1 100644 --- a/services/chat-old-sdk/_app.py +++ b/services/chat-old-sdk/_app.py @@ -6,9 +6,11 @@ # Import mock if MOCK_LLM environment variable is set if os.getenv("MOCK_LLM", True): from mock_litellm import MockLiteLLM + litellm = MockLiteLLM() else: import litellm + litellm.drop_params = True litellm.callbacks = [ag.callbacks.litellm_handler()] diff --git a/services/chat-old-sdk/mock_litellm.py b/services/chat-old-sdk/mock_litellm.py index 1f4cd973f3..a5b57a68cc 100644 --- a/services/chat-old-sdk/mock_litellm.py +++ b/services/chat-old-sdk/mock_litellm.py @@ -1,6 +1,7 @@ from typing import Dict, Any, List from dataclasses import dataclass + @dataclass class MockUsage: prompt_tokens: int = 10 @@ -11,28 +12,39 @@ def dict(self): return { "prompt_tokens": self.prompt_tokens, "completion_tokens": self.completion_tokens, - "total_tokens": self.total_tokens + "total_tokens": self.total_tokens, } + @dataclass class MockMessage: content: str = "This is a mock response from the LLM." + @dataclass class MockChoice: message: MockMessage = MockMessage() + @dataclass class MockCompletion: choices: List[MockChoice] = None usage: MockUsage = None - + def __init__(self): self.choices = [MockChoice()] self.usage = MockUsage() + class MockLiteLLM: - async def acompletion(self, model: str, messages: List[Dict[str, Any]], temperature: float, max_tokens: int = None, **kwargs) -> MockCompletion: + async def acompletion( + self, + model: str, + messages: List[Dict[str, Any]], + temperature: float, + max_tokens: int = None, + **kwargs + ) -> MockCompletion: return MockCompletion() class cost_calculator: diff --git a/services/completion-live-sdk/_app.py b/services/completion-live-sdk/_app.py index 05980221aa..03dda84207 100644 --- a/services/completion-live-sdk/_app.py +++ b/services/completion-live-sdk/_app.py @@ -1,12 +1,15 @@ import agenta as ag from supported_llm_models import get_all_supported_llm_models import os + # Import mock if MOCK_LLM environment variable is set if os.getenv("MOCK_LLM", True): from mock_litellm import MockLiteLLM + litellm = MockLiteLLM() else: import litellm + litellm.drop_params = True litellm.callbacks = [ag.callbacks.litellm_handler()] diff --git a/services/completion-live-sdk/main.py b/services/completion-live-sdk/main.py index fec3a17083..c17d626bc3 100644 --- a/services/completion-live-sdk/main.py +++ b/services/completion-live-sdk/main.py @@ -5,4 +5,10 @@ if __name__ == "__main__": - run("agenta:app", host="0.0.0.0", port=80, reload=True, reload_dirs=[".", "/agenta-cli"]) + run( + "agenta:app", + host="0.0.0.0", + port=80, + reload=True, + reload_dirs=[".", "/agenta-cli"], + ) diff --git a/services/completion-live-sdk/mock_litellm.py b/services/completion-live-sdk/mock_litellm.py index 1f4cd973f3..a5b57a68cc 100644 --- a/services/completion-live-sdk/mock_litellm.py +++ b/services/completion-live-sdk/mock_litellm.py @@ -1,6 +1,7 @@ from typing import Dict, Any, List from dataclasses import dataclass + @dataclass class MockUsage: prompt_tokens: int = 10 @@ -11,28 +12,39 @@ def dict(self): return { "prompt_tokens": self.prompt_tokens, "completion_tokens": self.completion_tokens, - "total_tokens": self.total_tokens + "total_tokens": self.total_tokens, } + @dataclass class MockMessage: content: str = "This is a mock response from the LLM." + @dataclass class MockChoice: message: MockMessage = MockMessage() + @dataclass class MockCompletion: choices: List[MockChoice] = None usage: MockUsage = None - + def __init__(self): self.choices = [MockChoice()] self.usage = MockUsage() + class MockLiteLLM: - async def acompletion(self, model: str, messages: List[Dict[str, Any]], temperature: float, max_tokens: int = None, **kwargs) -> MockCompletion: + async def acompletion( + self, + model: str, + messages: List[Dict[str, Any]], + temperature: float, + max_tokens: int = None, + **kwargs + ) -> MockCompletion: return MockCompletion() class cost_calculator: diff --git a/services/completion-old-sdk/_app.py b/services/completion-old-sdk/_app.py index eabf2491f0..430aed8434 100644 --- a/services/completion-old-sdk/_app.py +++ b/services/completion-old-sdk/_app.py @@ -2,12 +2,15 @@ from supported_llm_models import get_all_supported_llm_models import os + # Import mock if MOCK_LLM environment variable is set if os.getenv("MOCK_LLM", True): from mock_litellm import MockLiteLLM + litellm = MockLiteLLM() else: import litellm + litellm.drop_params = True litellm.callbacks = [ag.callbacks.litellm_handler()] diff --git a/services/completion-old-sdk/mock_litellm.py b/services/completion-old-sdk/mock_litellm.py index 1f4cd973f3..a5b57a68cc 100644 --- a/services/completion-old-sdk/mock_litellm.py +++ b/services/completion-old-sdk/mock_litellm.py @@ -1,6 +1,7 @@ from typing import Dict, Any, List from dataclasses import dataclass + @dataclass class MockUsage: prompt_tokens: int = 10 @@ -11,28 +12,39 @@ def dict(self): return { "prompt_tokens": self.prompt_tokens, "completion_tokens": self.completion_tokens, - "total_tokens": self.total_tokens + "total_tokens": self.total_tokens, } + @dataclass class MockMessage: content: str = "This is a mock response from the LLM." + @dataclass class MockChoice: message: MockMessage = MockMessage() + @dataclass class MockCompletion: choices: List[MockChoice] = None usage: MockUsage = None - + def __init__(self): self.choices = [MockChoice()] self.usage = MockUsage() + class MockLiteLLM: - async def acompletion(self, model: str, messages: List[Dict[str, Any]], temperature: float, max_tokens: int = None, **kwargs) -> MockCompletion: + async def acompletion( + self, + model: str, + messages: List[Dict[str, Any]], + temperature: float, + max_tokens: int = None, + **kwargs + ) -> MockCompletion: return MockCompletion() class cost_calculator: diff --git a/services/completion-stateless-sdk/_app.py b/services/completion-stateless-sdk/_app.py index eabf2491f0..430aed8434 100644 --- a/services/completion-stateless-sdk/_app.py +++ b/services/completion-stateless-sdk/_app.py @@ -2,12 +2,15 @@ from supported_llm_models import get_all_supported_llm_models import os + # Import mock if MOCK_LLM environment variable is set if os.getenv("MOCK_LLM", True): from mock_litellm import MockLiteLLM + litellm = MockLiteLLM() else: import litellm + litellm.drop_params = True litellm.callbacks = [ag.callbacks.litellm_handler()] diff --git a/services/completion-stateless-sdk/agenta/sdk/decorators/tracing.py b/services/completion-stateless-sdk/agenta/sdk/decorators/tracing.py index 6d31243963..f368509fc6 100644 --- a/services/completion-stateless-sdk/agenta/sdk/decorators/tracing.py +++ b/services/completion-stateless-sdk/agenta/sdk/decorators/tracing.py @@ -246,7 +246,9 @@ def _redact( not in ( ignore if isinstance(ignore, list) - else io.keys() if ignore is True else [] + else io.keys() + if ignore is True + else [] ) } diff --git a/services/completion-stateless-sdk/mock_litellm.py b/services/completion-stateless-sdk/mock_litellm.py index 1f4cd973f3..a5b57a68cc 100644 --- a/services/completion-stateless-sdk/mock_litellm.py +++ b/services/completion-stateless-sdk/mock_litellm.py @@ -1,6 +1,7 @@ from typing import Dict, Any, List from dataclasses import dataclass + @dataclass class MockUsage: prompt_tokens: int = 10 @@ -11,28 +12,39 @@ def dict(self): return { "prompt_tokens": self.prompt_tokens, "completion_tokens": self.completion_tokens, - "total_tokens": self.total_tokens + "total_tokens": self.total_tokens, } + @dataclass class MockMessage: content: str = "This is a mock response from the LLM." + @dataclass class MockChoice: message: MockMessage = MockMessage() + @dataclass class MockCompletion: choices: List[MockChoice] = None usage: MockUsage = None - + def __init__(self): self.choices = [MockChoice()] self.usage = MockUsage() + class MockLiteLLM: - async def acompletion(self, model: str, messages: List[Dict[str, Any]], temperature: float, max_tokens: int = None, **kwargs) -> MockCompletion: + async def acompletion( + self, + model: str, + messages: List[Dict[str, Any]], + temperature: float, + max_tokens: int = None, + **kwargs + ) -> MockCompletion: return MockCompletion() class cost_calculator: diff --git a/services/test/conftest.py b/services/test/conftest.py new file mode 100644 index 0000000000..f89e1c3c09 --- /dev/null +++ b/services/test/conftest.py @@ -0,0 +1,24 @@ +import pytest +import httpx +import pytest_asyncio + + +# Configure pytest-asyncio to use strict mode +def pytest_configure(config): + config.option.asyncio_mode = "strict" + + +@pytest.fixture +def chat_url(): + return "http://localhost/chat-live-sdk" # Adjust this if your services run on different ports + + +@pytest.fixture +def completion_url(): + return "http://localhost/completion-live-sdk" + + +@pytest_asyncio.fixture +async def async_client(): + async with httpx.AsyncClient() as client: + yield client diff --git a/services/test/requirements.txt b/services/test/requirements.txt new file mode 100644 index 0000000000..9bd3895bcd --- /dev/null +++ b/services/test/requirements.txt @@ -0,0 +1,3 @@ +pytest==7.4.3 +pytest-asyncio==0.21.1 +httpx==0.25.2 diff --git a/services/test/test_chat_service.py b/services/test/test_chat_service.py new file mode 100644 index 0000000000..555f80f0e6 --- /dev/null +++ b/services/test/test_chat_service.py @@ -0,0 +1,76 @@ +import pytest +import pytest_asyncio +from typing import Dict, Any + +pytestmark = pytest.mark.asyncio + + +@pytest.mark.asyncio +async def test_generate(async_client, chat_url): + payload = { + "inputs": [ + { + "role": "user", + "content": "What are some innovative tech solutions for a startup?", + } + ] + } + response = await async_client.post(f"{chat_url}/generate", json=payload) + assert response.status_code == 200 + data = response.json() + + # Check response structure + assert "version" in data + assert "data" in data + assert "tree" in data + + # Check tree structure + tree = data["tree"] + assert "nodes" in tree + assert len(tree["nodes"]) > 0 + + # Check first node + node = tree["nodes"][0] + assert "lifecycle" in node + assert "data" in node + assert "metrics" in node + assert "meta" in node + + # Check configuration + config = node["meta"]["configuration"] + assert config["model"] == "gpt-3.5-turbo" + assert "temperature" in config + assert "prompt_system" in config + + +@pytest.mark.asyncio +async def test_run(async_client, chat_url): + payload = { + "inputs": [ + { + "role": "user", + "content": "What are the best practices for startup growth?", + } + ] + } + response = await async_client.post(f"{chat_url}/run", json=payload) + assert response.status_code == 200 + data = response.json() + + assert "version" in data + assert "data" in data + assert isinstance(data["data"], str) + + +@pytest.mark.asyncio +async def test_generate_deployed(async_client, chat_url): + payload = { + "inputs": [{"role": "user", "content": "How to build a successful tech team?"}] + } + response = await async_client.post(f"{chat_url}/generate_deployed", json=payload) + assert response.status_code == 200 + data = response.json() + + assert "version" in data + assert "data" in data + assert isinstance(data["data"], str) diff --git a/services/test/test_completion_service.py b/services/test/test_completion_service.py new file mode 100644 index 0000000000..33c1d46dfd --- /dev/null +++ b/services/test/test_completion_service.py @@ -0,0 +1,78 @@ +import pytest +import pytest_asyncio +from typing import Dict, Any + +pytestmark = pytest.mark.asyncio + + +async def test_health(async_client, completion_url): + response = await async_client.get(f"{completion_url}/health") + assert response.status_code == 200 + data = response.json() + assert data == {"status": "ok"} + + +async def test_generate(async_client, completion_url): + payload = {"inputs": {"country": "France"}} + response = await async_client.post(f"{completion_url}/generate", json=payload) + assert response.status_code == 200 + data = response.json() + + # Check response structure + assert "version" in data + assert "data" in data + assert "tree" in data + + # Check tree structure + tree = data["tree"] + assert "nodes" in tree + assert len(tree["nodes"]) > 0 + + # Check first node + node = tree["nodes"][0] + assert "lifecycle" in node + assert "data" in node + assert "metrics" in node + assert "meta" in node + + # Check configuration + config = node["meta"]["configuration"] + assert config["model"] == "gpt-3.5-turbo" + assert "temperature" in config + assert "prompt_system" in config + assert "prompt_user" in config + + +async def test_playground_run(async_client, completion_url): + payload = {"inputs": {"country": "Spain"}} + response = await async_client.post(f"{completion_url}/playground/run", json=payload) + assert response.status_code == 200 + data = response.json() + + assert "version" in data + assert "data" in data + assert isinstance(data["data"], str) + + +async def test_generate_deployed(async_client, completion_url): + payload = {"inputs": {"country": "Germany"}} + response = await async_client.post( + f"{completion_url}/generate_deployed", json=payload + ) + assert response.status_code == 200 + data = response.json() + + assert "version" in data + assert "data" in data + assert isinstance(data["data"], str) + + +async def test_run(async_client, completion_url): + payload = {"inputs": {"country": "Italy"}} + response = await async_client.post(f"{completion_url}/run", json=payload) + assert response.status_code == 200 + data = response.json() + + assert "version" in data + assert "data" in data + assert isinstance(data["data"], str)