diff --git a/services/completion-new-sdk-prompt/__init__.py b/services/completion-new-sdk-prompt/__init__.py
new file mode 100644
index 0000000000..8b13789179
--- /dev/null
+++ b/services/completion-new-sdk-prompt/__init__.py
@@ -0,0 +1 @@
+
diff --git a/services/completion-new-sdk-prompt/_app.py b/services/completion-new-sdk-prompt/_app.py
index a316f9442b..ae8b4a0c15 100644
--- a/services/completion-new-sdk-prompt/_app.py
+++ b/services/completion-new-sdk-prompt/_app.py
@@ -1,4 +1,4 @@
-from typing import Annotated, List, Union, Optional, Dict, Literal
+from typing import Annotated, List, Union, Optional, Dict, Literal, Any
 from pydantic import BaseModel, Field, root_validator
 
 import agenta as ag
@@ -42,7 +42,50 @@ class ResponseFormat(BaseModel):
     type: Literal["text", "json_object"] = "text"
     schema: Optional[Dict] = None
 
-class Prompts(BaseModel):
+class PromptTemplateError(Exception):
+    """Base exception for all PromptTemplate errors"""
+    pass
+
+class InputValidationError(PromptTemplateError):
+    """Raised when input validation fails"""
+    def __init__(self, message: str, missing: Optional[set] = None, extra: Optional[set] = None):
+        self.missing = missing
+        self.extra = extra
+        super().__init__(message)
+
+class TemplateFormatError(PromptTemplateError):
+    """Raised when template formatting fails"""
+    def __init__(self, message: str, original_error: Optional[Exception] = None):
+        self.original_error = original_error
+        super().__init__(message)
+
+class ModelConfig(BaseModel):
+    """Configuration for model parameters"""
+    model: Annotated[str, ag.MultipleChoice(choices=supported_llm_models)] = Field(
+        default="gpt-3.5-turbo",
+        description="The model to use for completion"
+    )
+    temperature: float = Field(default=1.0, ge=0.0, le=2.0)
+    max_tokens: int = Field(default=-1, ge=-1, description="Maximum tokens to generate. -1 means no limit")
+    top_p: float = Field(default=1.0, ge=0.0, le=1.0)
+    frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
+    presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
+    response_format: Optional[ResponseFormat] = Field(
+        default=None,
+        description="Specify the format of the response (text or JSON)"
+    )
+    stream: bool = Field(default=False)
+    tools: Optional[List[Dict]] = Field(
+        default=None,
+        description="List of tools/functions the model can use"
+    )
+    tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = Field(
+        default="auto",
+        description="Control which tool the model should use"
+    )
+
+class PromptTemplate(BaseModel):
+    """A template for generating prompts with formatting capabilities"""
     messages: List[Message] = Field(
         default=[
             Message(role="system", content=prompts["system_prompt"]),
@@ -55,17 +98,13 @@ class Prompts(BaseModel):
         default="fstring",
         description="Format type for template variables: fstring {var}, jinja2 {{ var }}, or curly {{var}}"
     )
-    response_format: Optional[ResponseFormat] = Field(
-        default=None,
-        description="Specify the format of the response (text or JSON)"
-    )
-    tools: Optional[List[Dict]] = Field(
+    input_keys: Optional[List[str]] = Field(
         default=None,
-        description="List of tools/functions the model can use"
+        description="Optional list of input keys for validation. If not provided, any inputs will be accepted"
     )
-    tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = Field(
-        default="auto",
-        description="Control which tool the model should use"
+    llm_config: ModelConfig = Field(
+        default_factory=ModelConfig,
+        description="Configuration for the model parameters"
     )
 
     class Config:
@@ -86,74 +125,149 @@ def init_messages(cls, values):
             values["messages"] = messages
         return values
 
-class MyConfig(BaseModel):
-    prompt: Prompts = Field(default=Prompts())
-
+    def _format_with_template(self, content: str, kwargs: Dict[str, Any]) -> str:
+        """Internal method to format content based on template_format"""
+        try:
+            if self.template_format == "fstring":
+                return content.format(**kwargs)
+            elif self.template_format == "jinja2":
+                from jinja2 import Template, TemplateError
+                try:
+                    return Template(content).render(**kwargs)
+                except TemplateError as e:
+                    raise TemplateFormatError(
+                        f"Jinja2 template error in content: '{content}'. Error: {str(e)}",
+                        original_error=e
+                    )
+            elif self.template_format == "curly":
+                import re
+                result = content
+                for key, value in kwargs.items():
+                    result = re.sub(r'\{\{' + key + r'\}\}', str(value), result)
+                if re.search(r'\{\{.*?\}\}', result):
+                    unreplaced = re.findall(r'\{\{(.*?)\}\}', result)
+                    raise TemplateFormatError(
+                        f"Unreplaced variables in curly template: {unreplaced}"
+                    )
+                return result
+            else:
+                raise TemplateFormatError(f"Unknown template format: {self.template_format}")
+        except KeyError as e:
+            key = str(e).strip("'")
+            raise TemplateFormatError(
+                f"Missing required variable '{key}' in template: '{content}'"
+            )
+        except Exception as e:
+            raise TemplateFormatError(
+                f"Error formatting template '{content}': {str(e)}",
+                original_error=e
+            )
+
+    def format(self, **kwargs) -> 'PromptTemplate':
+        """
+        Format the template with provided inputs.
+        Only validates against input_keys if they are specified.
 
-@ag.instrument(spankind="llm")
-async def llm_call(prompt_system: str, prompt_user: str):
-    config = ag.ConfigManager.get_from_route(schema=MyConfig)
-    response_format = (
-        {"type": "json_object"}
-        if config.force_json and config.model in GPT_FORMAT_RESPONSE
-        else {"type": "text"}
-    )
+        Raises:
+            InputValidationError: If input validation fails
+            TemplateFormatError: If template formatting fails
+        """
+        # Validate inputs if input_keys is set
+        if self.input_keys is not None:
+            missing = set(self.input_keys) - set(kwargs.keys())
+            extra = set(kwargs.keys()) - set(self.input_keys)
+
+            error_parts = []
+            if missing:
+                error_parts.append(f"Missing required inputs: {', '.join(sorted(missing))}")
+            if extra:
+                error_parts.append(f"Unexpected inputs: {', '.join(sorted(extra))}")
+
+            if error_parts:
+                raise InputValidationError(
+                    " | ".join(error_parts),
+                    missing=missing if missing else None,
+                    extra=extra if extra else None
+                )
+
+        new_messages = []
+        for i, msg in enumerate(self.messages):
+            if msg.content:
+                try:
+                    new_content = self._format_with_template(msg.content, kwargs)
+                except TemplateFormatError as e:
+                    raise TemplateFormatError(
+                        f"Error in message {i} ({msg.role}): {str(e)}",
+                        original_error=e.original_error
+                    )
+            else:
+                new_content = None
+
+            new_messages.append(Message(
+                role=msg.role,
+                content=new_content,
+                name=msg.name,
+                tool_calls=msg.tool_calls,
+                tool_call_id=msg.tool_call_id
+            ))
+
+        return PromptTemplate(
+            messages=new_messages,
+            template_format=self.template_format,
+            llm_config=self.llm_config,
+            input_keys=self.input_keys
+        )
 
-    max_tokens = config.max_tokens if config.max_tokens != -1 else None
-
-    # Include frequency_penalty and presence_penalty only if supported
-    completion_params = {}
-    if config.model in GPT_FORMAT_RESPONSE:
-        completion_params["frequency_penalty"] = config.frequence_penalty
-        completion_params["presence_penalty"] = config.presence_penalty
-
-    response = await litellm.acompletion(
-        **{
-            "model": config.model,
-            "messages": config.prompt.messages,
-            "temperature": config.temperature,
-            "max_tokens": max_tokens,
-            "top_p": config.top_p,
-            "response_format": response_format,
-            **completion_params,
+    def to_openai_kwargs(self) -> dict:
+        """Convert the prompt template to kwargs compatible with litellm/openai"""
+        kwargs = {
+            "model": self.llm_config.model,
+            "messages": [msg.dict(exclude_none=True) for msg in self.messages],
+            "temperature": self.llm_config.temperature,
+            "top_p": self.llm_config.top_p,
+            "stream": self.llm_config.stream,
         }
-    )
-    token_usage = response.usage.dict()
-    return {
-        "message": response.choices[0].message.content,
-        "usage": token_usage,
-        "cost": litellm.cost_calculator.completion_cost(
-            completion_response=response, model=config.model
-        ),
-    }
+
+        # Add optional parameters only if they have non-default values
+        if self.llm_config.max_tokens != -1:
+            kwargs["max_tokens"] = self.llm_config.max_tokens
+
+        if self.llm_config.frequency_penalty != 0:
+            kwargs["frequency_penalty"] = self.llm_config.frequency_penalty
+
+        if self.llm_config.presence_penalty != 0:
+            kwargs["presence_penalty"] = self.llm_config.presence_penalty
+
+        if self.llm_config.response_format:
+            kwargs["response_format"] = self.llm_config.response_format.dict()
+
+        if self.llm_config.tools:
+            kwargs["tools"] = self.llm_config.tools
+
+        if self.llm_config.tool_choice and self.llm_config.tool_choice != "auto":
+            kwargs["tool_choice"] = self.llm_config.tool_choice
+
+        return kwargs
+
+class MyConfig(BaseModel):
+    prompt: PromptTemplate = Field(default=PromptTemplate())
 
 @ag.route("/", config_schema=MyConfig)
 @ag.instrument()
 async def generate(
-    inputs: ag.DictInput = ag.DictInput(default_keys=["country"]),
+    inputs: Dict[str, str],
 ):
     config = ag.ConfigManager.get_from_route(schema=MyConfig)
-    print("popo", config)
-    try:
-        prompt_user = config.prompt_user.format(**inputs)
-    except Exception as e:
-        prompt_user = config.prompt_user
+
     try:
-        prompt_system = config.prompt_system.format(**inputs)
+        # Format the prompt template with the inputs
+        formatted_prompt = config.prompt.format(**inputs)
+        config.prompt = formatted_prompt
+    except (InputValidationError, TemplateFormatError) as e:
+        raise ValueError(f"Error formatting prompt template: {str(e)}")
     except Exception as e:
-        prompt_system = config.prompt_system
+        raise ValueError(f"Unexpected error formatting prompt: {str(e)}")
 
-    # SET MAX TOKENS - via completion()
-    if config.force_json and config.model not in GPT_FORMAT_RESPONSE:
-        raise ValueError(
-            "Model {} does not support JSON response format".format(config.model)
-        )
-
-    response = await llm_call(prompt_system=prompt_system, prompt_user=prompt_user)
-    return {
-        "message": response["message"],
-        "usage": response.get("usage", None),
-        "cost": response.get("cost", None),
-    }
+    response = await litellm.acompletion(**config.prompt.to_openai_kwargs())
+
+    return response.choices[0].message.content
diff --git a/services/test/__init__.py b/services/test/__init__.py
new file mode 100644
index 0000000000..8b13789179
--- /dev/null
+++ b/services/test/__init__.py
@@ -0,0 +1 @@
+
diff --git a/services/test/mock_agenta.py b/services/test/mock_agenta.py
new file mode 100644
index 0000000000..429135a296
--- /dev/null
+++ b/services/test/mock_agenta.py
@@ -0,0 +1,24 @@
+"""Mock agenta module for testing"""
+from typing import Any, Dict, Type, TypeVar, Optional
+from dataclasses import dataclass
+
+T = TypeVar('T')
+
+@dataclass
+class ConfigManager:
+    """Mock ConfigManager"""
+    @staticmethod
+    def get_from_route(schema: Type[T]) -> T:
+        return schema()
+
+def route(path: str = "", config_schema: Optional[Type[Any]] = None):
+    """Mock route decorator"""
+    def decorator(func):
+        return func
+    return decorator
+
+def instrument():
+    """Mock instrument decorator"""
+    def decorator(func):
+        return func
+    return decorator
diff --git a/services/test/mock_litellm.py b/services/test/mock_litellm.py
new file mode 100644
index 0000000000..7e8916d10f
--- /dev/null
+++ b/services/test/mock_litellm.py
@@ -0,0 +1,70 @@
+import pytest
+from typing import List, Optional, Dict, Any
+from dataclasses import dataclass
+
+@dataclass
+class Message:
+    role: str
+    content: str
+    tool_calls: Optional[List[Dict[str, Any]]] = None
+
+@dataclass
+class Choice:
+    message: Message
+    index: int = 0
+    finish_reason: str = "stop"
+
+@dataclass
+class Response:
+    choices: List[Choice]
+    model: str = "gpt-4"
+    id: str = "mock-response-id"
+
+class MockLiteLLM:
+    """Mock LiteLLM for testing"""
+
+    async def acompletion(self, **kwargs):
+        """Mock async completion"""
+        model = kwargs.get("model", "gpt-4")
+        messages = kwargs.get("messages", [])
+        tools = kwargs.get("tools", [])
+        response_format = kwargs.get("response_format", None)
+
+        # Simulate different response types based on input
+        if tools:
+            # Function calling response
+            tool_calls = [{
+                "id": "call_123",
+                "type": "function",
+                "function": {
+                    "name": "get_weather",
+                    "arguments": '{"location": "London", "unit": "celsius"}'
+                }
+            }]
+            message = Message(
+                role="assistant",
+                content=None,
+                tool_calls=tool_calls
+            )
+        elif response_format and response_format["type"] == "json_object":
+            # JSON response
+            message = Message(
+                role="assistant",
+                content='{"colors": ["red", "blue", "green"]}'
+            )
+        else:
+            # Regular text response
+            message = Message(
+                role="assistant",
+                content="This is a mock response"
+            )
+
+        return Response(
+            choices=[Choice(message=message)],
+            model=model
+        )
+
+@pytest.fixture
+def mock_litellm():
+    """Fixture to provide mock LiteLLM instance"""
+    return MockLiteLLM()
diff --git a/services/test/test_prompt_template.py b/services/test/test_prompt_template.py
new file mode 100644
index 0000000000..08f80b729b
--- /dev/null
+++ b/services/test/test_prompt_template.py
@@ -0,0 +1,251 @@
+import pytest
+import sys
+import os
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from typing import Dict, List
+from pydantic import ValidationError
+
+sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "completion-new-sdk-prompt"))
+from _app import (
+    PromptTemplate,
+    ModelConfig,
+    Message,
+    InputValidationError,
+    TemplateFormatError,
+    ResponseFormat
+)
+from .mock_litellm import MockLiteLLM, mock_litellm
+
+# Test Data
+BASIC_MESSAGES = [
+    Message(role="system", content="You are a {type} assistant"),
+    Message(role="user", content="Help me with {task}")
+]
+
+TOOL_MESSAGES = [
+    Message(role="system", content="You are a function calling assistant"),
+    Message(role="user", content="Get the weather for {location}")
+]
+
+WEATHER_TOOL = {
+    "type": "function",
+    "function": {
+        "name": "get_weather",
+        "description": "Get the current weather",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "location": {"type": "string"},
+                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}
+            },
+            "required": ["location"]
+        }
+    }
+}
+
+class TestPromptTemplateBasics:
+    """Test basic functionality of PromptTemplate"""
+
+    def test_create_template(self):
+        """Test creating a basic template"""
+        template = PromptTemplate(messages=BASIC_MESSAGES)
+        assert len(template.messages) == 2
+        assert template.messages[0].role == "system"
+        assert template.messages[1].role == "user"
+
+    def test_create_template_with_model_config(self):
+        """Test creating template with custom model config"""
+        model_config = ModelConfig(
+            model="gpt-4",
+            temperature=0.7,
+            max_tokens=100
+        )
+        template = PromptTemplate(
+            messages=BASIC_MESSAGES,
+            llm_config=model_config
+        )
+        assert template.llm_config.model == "gpt-4"
+        assert template.llm_config.temperature == 0.7
+        assert template.llm_config.max_tokens == 100
+
+    def test_invalid_model_config(self):
+        """Test validation errors for invalid model config"""
+        with pytest.raises(ValidationError):
+            ModelConfig(temperature=3.0)  # temperature > 2.0
+
+        with pytest.raises(ValidationError):
+            ModelConfig(max_tokens=-2)  # max_tokens < -1
+
+class TestPromptFormatting:
+    """Test template formatting functionality"""
+
+    def test_basic_format(self):
+        """Test basic formatting with valid inputs"""
+        template = PromptTemplate(messages=BASIC_MESSAGES)
+        formatted = template.format(type="coding", task="Python")
+        assert formatted.messages[0].content == "You are a coding assistant"
+        assert formatted.messages[1].content == "Help me with Python"
+
+    def test_format_with_validation(self):
+        """Test formatting with input validation"""
+        template = PromptTemplate(
+            messages=BASIC_MESSAGES,
+            input_keys=["type", "task"]
+        )
+        # Valid inputs
+        formatted = template.format(type="coding", task="Python")
+        assert formatted.messages[0].content == "You are a coding assistant"
+
+        # Missing input
+        with pytest.raises(InputValidationError) as exc:
+            template.format(type="coding")
+        assert "Missing required inputs: task" in str(exc.value)
+
+        # Extra input
+        with pytest.raises(InputValidationError) as exc:
+            template.format(type="coding", task="Python", extra="value")
+        assert "Unexpected inputs: extra" in str(exc.value)
+
+    @pytest.mark.parametrize("template_format,template_string,inputs,expected", [
+        ("fstring", "Hello {name}", {"name": "World"}, "Hello World"),
+        ("jinja2", "Hello {{ name }}", {"name": "World"}, "Hello World"),
+        ("curly", "Hello {{name}}", {"name": "World"}, "Hello World"),
+    ])
+    def test_format_types(self, template_format, template_string, inputs, expected):
+        """Test different format types"""
+        template = PromptTemplate(
+            messages=[Message(role="user", content=template_string)],
+            template_format=template_format
+        )
+        formatted = template.format(**inputs)
+        assert formatted.messages[0].content == expected
+
+    def test_format_errors(self):
+        """Test formatting error cases"""
+        template = PromptTemplate(messages=BASIC_MESSAGES)
+
+        # Missing variable
+        with pytest.raises(TemplateFormatError) as exc:
+            template.format(type="coding")  # missing 'task'
+        assert "Missing required variable" in str(exc.value)
+
+        # Invalid template
+        bad_template = PromptTemplate(
+            messages=[Message(role="user", content="Hello {")]
+        )
+        with pytest.raises(TemplateFormatError):
+            bad_template.format(name="World")
+
+class TestOpenAIIntegration:
+    """Test OpenAI/LiteLLM integration features"""
+
+    def test_basic_openai_kwargs(self):
+        """Test basic OpenAI kwargs generation"""
+        template = PromptTemplate(
+            messages=BASIC_MESSAGES,
+            llm_config=ModelConfig(
+                model="gpt-4",
+                temperature=0.7,
+                max_tokens=100
+            )
+        )
+        kwargs = template.to_openai_kwargs()
+        assert kwargs["model"] == "gpt-4"
+        assert kwargs["temperature"] == 0.7
+        assert kwargs["max_tokens"] == 100
+        assert len(kwargs["messages"]) == 2
+
+    def test_tools_openai_kwargs(self):
+        """Test OpenAI kwargs with tools"""
+        template = PromptTemplate(
+            messages=TOOL_MESSAGES,
+            llm_config=ModelConfig(
+                model="gpt-4",
+                tools=[WEATHER_TOOL],
+                tool_choice="auto"
+            )
+        )
+        kwargs = template.to_openai_kwargs()
+        assert len(kwargs["tools"]) == 1
+        assert kwargs["tools"][0]["type"] == "function"
+        assert "tool_choice" not in kwargs  # "auto" is the default and is omitted
+
+    def test_json_mode_openai_kwargs(self):
+        """Test OpenAI kwargs with JSON mode"""
+        template = PromptTemplate(
+            messages=BASIC_MESSAGES,
+            llm_config=ModelConfig(
+                model="gpt-4",
+                response_format=ResponseFormat(type="json_object")
+            )
+        )
+        kwargs = template.to_openai_kwargs()
+        assert kwargs["response_format"]["type"] == "json_object"
+
+    def test_optional_params_openai_kwargs(self):
+        """Test that optional params are only included when non-default"""
+        template = PromptTemplate(
+            messages=BASIC_MESSAGES,
+            llm_config=ModelConfig(
+                model="gpt-4",
+                frequency_penalty=0.0,  # default value
+                presence_penalty=0.5  # non-default value
+            )
+        )
+        kwargs = template.to_openai_kwargs()
+        assert "frequency_penalty" not in kwargs
+        assert kwargs["presence_penalty"] == 0.5
+
+class TestEndToEndScenarios:
+    """Test end-to-end scenarios"""
+
+    @pytest.mark.asyncio
+    async def test_chat_completion(self, mock_litellm):
+        """Test chat completion with basic prompt"""
+        template = PromptTemplate(
+            messages=[
+                Message(role="user", content="Say hello to {name}")
+            ],
+            llm_config=ModelConfig(model="gpt-3.5-turbo")
+        )
+        formatted = template.format(name="World")
+        kwargs = formatted.to_openai_kwargs()
+
+        response = await mock_litellm.acompletion(**kwargs)
+        assert response.choices[0].message.content is not None
+
+    @pytest.mark.asyncio
+    async def test_function_calling(self, mock_litellm):
+        """Test function calling scenario"""
+        template = PromptTemplate(
+            messages=TOOL_MESSAGES,
+            llm_config=ModelConfig(
+                model="gpt-4",
+                tools=[WEATHER_TOOL],
+                tool_choice="auto"
+            )
+        )
+        formatted = template.format(location="London")
+        kwargs = formatted.to_openai_kwargs()
+
+        response = await mock_litellm.acompletion(**kwargs)
+        assert response.choices[0].message.tool_calls is not None
+
+    @pytest.mark.asyncio
+    async def test_json_mode(self, mock_litellm):
+        """Test JSON mode response"""
+        template = PromptTemplate(
+            messages=[
+                Message(role="user", content="List 3 colors in JSON")
+            ],
+            llm_config=ModelConfig(
+                model="gpt-4",
+                response_format=ResponseFormat(type="json_object")
+            )
+        )
+        kwargs = template.to_openai_kwargs()
+
+        response = await mock_litellm.acompletion(**kwargs)
+        assert response.choices[0].message.content.startswith("{")
+        assert response.choices[0].message.content.endswith("}")
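
For reference, a minimal usage sketch of the `PromptTemplate` API introduced in `_app.py` above. This is illustrative only and not part of the patch; it assumes the module and its agenta/litellm dependencies are importable from the service directory and that LLM credentials are configured.

```python
# Illustrative usage sketch -- module path and credential setup are assumptions.
import asyncio

import litellm
from _app import Message, ModelConfig, PromptTemplate

template = PromptTemplate(
    messages=[
        Message(role="system", content="You are a {tone} assistant"),
        Message(role="user", content="Write a greeting for {name}"),
    ],
    input_keys=["tone", "name"],  # optional: enables strict input validation
    llm_config=ModelConfig(model="gpt-3.5-turbo", temperature=0.7),
)

async def main():
    # format() returns a new PromptTemplate with the variables substituted
    formatted = template.format(tone="friendly", name="World")
    # to_openai_kwargs() builds the kwargs passed to litellm.acompletion
    response = await litellm.acompletion(**formatted.to_openai_kwargs())
    print(response.choices[0].message.content)

asyncio.run(main())
```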