updates to the new prompt
mmabrouk committed Dec 16, 2024
1 parent f3e04b1 commit f69b901
Showing 6 changed files with 529 additions and 68 deletions.
1 change: 1 addition & 0 deletions services/completion-new-sdk-prompt/__init__.py
@@ -0,0 +1 @@

250 changes: 182 additions & 68 deletions services/completion-new-sdk-prompt/_app.py
@@ -1,4 +1,4 @@
-from typing import Annotated, List, Union, Optional, Dict, Literal
+from typing import Annotated, List, Union, Optional, Dict, Literal, Any
 from pydantic import BaseModel, Field, root_validator
 
 import agenta as ag
@@ -42,7 +42,50 @@ class ResponseFormat(BaseModel):
     type: Literal["text", "json_object"] = "text"
     schema: Optional[Dict] = None
 
-class Prompts(BaseModel):
+class PromptTemplateError(Exception):
+    """Base exception for all PromptTemplate errors"""
+    pass
+
+class InputValidationError(PromptTemplateError):
+    """Raised when input validation fails"""
+    def __init__(self, message: str, missing: Optional[set] = None, extra: Optional[set] = None):
+        self.missing = missing
+        self.extra = extra
+        super().__init__(message)
+
+class TemplateFormatError(PromptTemplateError):
+    """Raised when template formatting fails"""
+    def __init__(self, message: str, original_error: Optional[Exception] = None):
+        self.original_error = original_error
+        super().__init__(message)
+
+class ModelConfig(BaseModel):
+    """Configuration for model parameters"""
+    model: Annotated[str, ag.MultipleChoice(choices=supported_llm_models)] = Field(
+        default="gpt-3.5-turbo",
+        description="The model to use for completion"
+    )
+    temperature: float = Field(default=1.0, ge=0.0, le=2.0)
+    max_tokens: int = Field(default=-1, ge=-1, description="Maximum tokens to generate. -1 means no limit")
+    top_p: float = Field(default=1.0, ge=0.0, le=1.0)
+    frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
+    presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
+    response_format: Optional[ResponseFormat] = Field(
+        default=None,
+        description="Specify the format of the response (text or JSON)"
+    )
+    stream: bool = Field(default=False)
+    tools: Optional[List[Dict]] = Field(
+        default=None,
+        description="List of tools/functions the model can use"
+    )
+    tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = Field(
+        default="auto",
+        description="Control which tool the model should use"
+    )
+
+class PromptTemplate(BaseModel):
+    """A template for generating prompts with formatting capabilities"""
     messages: List[Message] = Field(
         default=[
             Message(role="system", content=prompts["system_prompt"]),
@@ -55,17 +98,13 @@ class Prompts(BaseModel):
         default="fstring",
         description="Format type for template variables: fstring {var}, jinja2 {{ var }}, or curly {{var}}"
     )
-    response_format: Optional[ResponseFormat] = Field(
-        default=None,
-        description="Specify the format of the response (text or JSON)"
-    )
-    tools: Optional[List[Dict]] = Field(
+    input_keys: Optional[List[str]] = Field(
         default=None,
-        description="List of tools/functions the model can use"
+        description="Optional list of input keys for validation. If not provided, any inputs will be accepted"
     )
-    tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = Field(
-        default="auto",
-        description="Control which tool the model should use"
+    llm_config: ModelConfig = Field(
+        default_factory=ModelConfig,
+        description="Configuration for the model parameters"
     )
 
     class Config:
@@ -86,74 +125,149 @@ def init_messages(cls, values):
         values["messages"] = messages
         return values
 
-class MyConfig(BaseModel):
-    prompt: Prompts = Field(default=Prompts())
-
+    def _format_with_template(self, content: str, kwargs: Dict[str, Any]) -> str:
+        """Internal method to format content based on template_format"""
+        try:
+            if self.template_format == "fstring":
+                return content.format(**kwargs)
+            elif self.template_format == "jinja2":
+                from jinja2 import Template, TemplateError
+                try:
+                    return Template(content).render(**kwargs)
+                except TemplateError as e:
+                    raise TemplateFormatError(
+                        f"Jinja2 template error in content: '{content}'. Error: {str(e)}",
+                        original_error=e
+                    )
+            elif self.template_format == "curly":
+                import re
+                result = content
+                for key, value in kwargs.items():
+                    result = re.sub(r'\{\{' + key + r'\}\}', str(value), result)
+                if re.search(r'\{\{.*?\}\}', result):
+                    unreplaced = re.findall(r'\{\{(.*?)\}\}', result)
+                    raise TemplateFormatError(
+                        f"Unreplaced variables in curly template: {unreplaced}"
+                    )
+                return result
+            else:
+                raise TemplateFormatError(f"Unknown template format: {self.template_format}")
+        except KeyError as e:
+            key = str(e).strip("'")
+            raise TemplateFormatError(
+                f"Missing required variable '{key}' in template: '{content}'"
+            )
+        except Exception as e:
+            raise TemplateFormatError(
+                f"Error formatting template '{content}': {str(e)}",
+                original_error=e
+            )
+
+    def format(self, **kwargs) -> 'PromptTemplate':
+        """
+        Format the template with provided inputs.
+        Only validates against input_keys if they are specified.
-@ag.instrument(spankind="llm")
-async def llm_call(prompt_system: str, prompt_user: str):
-    config = ag.ConfigManager.get_from_route(schema=MyConfig)
-    response_format = (
-        {"type": "json_object"}
-        if config.force_json and config.model in GPT_FORMAT_RESPONSE
-        else {"type": "text"}
-    )
+
+        Raises:
+            InputValidationError: If input validation fails
+            TemplateFormatError: If template formatting fails
+        """
+        # Validate inputs if input_keys is set
+        if self.input_keys is not None:
+            missing = set(self.input_keys) - set(kwargs.keys())
+            extra = set(kwargs.keys()) - set(self.input_keys)
+
+            error_parts = []
+            if missing:
+                error_parts.append(f"Missing required inputs: {', '.join(sorted(missing))}")
+            if extra:
+                error_parts.append(f"Unexpected inputs: {', '.join(sorted(extra))}")
+
+            if error_parts:
+                raise InputValidationError(
+                    " | ".join(error_parts),
+                    missing=missing if missing else None,
+                    extra=extra if extra else None
+                )
+
+        new_messages = []
+        for i, msg in enumerate(self.messages):
+            if msg.content:
+                try:
+                    new_content = self._format_with_template(msg.content, kwargs)
+                except TemplateFormatError as e:
+                    raise TemplateFormatError(
+                        f"Error in message {i} ({msg.role}): {str(e)}",
+                        original_error=e.original_error
+                    )
+            else:
+                new_content = None
+
+            new_messages.append(Message(
+                role=msg.role,
+                content=new_content,
+                name=msg.name,
+                tool_calls=msg.tool_calls,
+                tool_call_id=msg.tool_call_id
+            ))
+
+        return PromptTemplate(
+            messages=new_messages,
+            template_format=self.template_format,
+            llm_config=self.llm_config,
+            input_keys=self.input_keys
+        )
 
-    max_tokens = config.max_tokens if config.max_tokens != -1 else None
-
-    # Include frequency_penalty and presence_penalty only if supported
-    completion_params = {}
-    if config.model in GPT_FORMAT_RESPONSE:
-        completion_params["frequency_penalty"] = config.frequence_penalty
-        completion_params["presence_penalty"] = config.presence_penalty
-
-    response = await litellm.acompletion(
-        **{
-            "model": config.model,
-            "messages": config.prompt.messages,
-            "temperature": config.temperature,
-            "max_tokens": max_tokens,
-            "top_p": config.top_p,
-            "response_format": response_format,
-            **completion_params,
+    def to_openai_kwargs(self) -> dict:
+        """Convert the prompt template to kwargs compatible with litellm/openai"""
+        kwargs = {
+            "model": self.llm_config.model,
+            "messages": [msg.dict(exclude_none=True) for msg in self.messages],
+            "temperature": self.llm_config.temperature,
+            "top_p": self.llm_config.top_p,
+            "stream": self.llm_config.stream,
         }
-    )
-    token_usage = response.usage.dict()
-    return {
-        "message": response.choices[0].message.content,
-        "usage": token_usage,
-        "cost": litellm.cost_calculator.completion_cost(
-            completion_response=response, model=config.model
-        ),
-    }
 
+        # Add optional parameters only if they have non-default values
+        if self.llm_config.max_tokens != -1:
+            kwargs["max_tokens"] = self.llm_config.max_tokens
+
+        if self.llm_config.frequency_penalty != 0:
+            kwargs["frequency_penalty"] = self.llm_config.frequency_penalty
+
+        if self.llm_config.presence_penalty != 0:
+            kwargs["presence_penalty"] = self.llm_config.presence_penalty
+
+        if self.llm_config.response_format:
+            kwargs["response_format"] = self.llm_config.response_format.dict()
+
+        if self.llm_config.tools:
+            kwargs["tools"] = self.llm_config.tools
+
+        if self.llm_config.tool_choice and self.llm_config.tool_choice != "auto":
+            kwargs["tool_choice"] = self.llm_config.tool_choice
+
+        return kwargs
+
+class MyConfig(BaseModel):
+    prompt: PromptTemplate = Field(default=PromptTemplate())
 
 @ag.route("/", config_schema=MyConfig)
 @ag.instrument()
 async def generate(
-    inputs: ag.DictInput = ag.DictInput(default_keys=["country"]),
+    inputs: Dict[str, str],
 ):
     config = ag.ConfigManager.get_from_route(schema=MyConfig)
-    print("popo", config)
     try:
-        prompt_user = config.prompt_user.format(**inputs)
-    except Exception as e:
-        prompt_user = config.prompt_user
-
-    try:
-        prompt_system = config.prompt_system.format(**inputs)
+        # Format the prompt template with the inputs
+        formatted_prompt = config.prompt.format(**inputs)
+        config.prompt = formatted_prompt
+    except (InputValidationError, TemplateFormatError) as e:
+        raise ValueError(f"Error formatting prompt template: {str(e)}")
     except Exception as e:
-        prompt_system = config.prompt_system
+        raise ValueError(f"Unexpected error formatting prompt: {str(e)}")
 
-    # SET MAX TOKENS - via completion()
-    if config.force_json and config.model not in GPT_FORMAT_RESPONSE:
-        raise ValueError(
-            "Model {} does not support JSON response format".format(config.model)
-        )
-
-    response = await llm_call(prompt_system=prompt_system, prompt_user=prompt_user)
-    return {
-        "message": response["message"],
-        "usage": response.get("usage", None),
-        "cost": response.get("cost", None),
-    }
+    response = await litellm.acompletion(**config.prompt.to_openai_kwargs())
+
+    return response.choices[0].message.content
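
For reviewers, a minimal sketch of how the new PromptTemplate API from this diff might be exercised outside the route handler. The template text, the country input, and the direct litellm call below are illustrative only; format() and to_openai_kwargs() are the methods added above, and the Message constructor is assumed to accept the role/content keywords used elsewhere in the file.

# Illustrative only -- not part of this commit.
template = PromptTemplate(
    messages=[
        Message(role="system", content="You are a helpful assistant."),
        Message(role="user", content="Tell me one fact about {country}."),
    ],
    template_format="fstring",
    input_keys=["country"],
)

formatted = template.format(country="France")   # raises InputValidationError if "country" is missing or extra keys are passed
kwargs = formatted.to_openai_kwargs()           # model, messages, temperature, top_p, stream, ...
# response = await litellm.acompletion(**kwargs)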
1 change: 1 addition & 0 deletions services/test/__init__.py
@@ -0,0 +1 @@

24 changes: 24 additions & 0 deletions services/test/mock_agenta.py
@@ -0,0 +1,24 @@
"""Mock agenta module for testing"""
from typing import Any, Dict, Type, TypeVar, Optional
from dataclasses import dataclass

T = TypeVar('T')

@dataclass
class ConfigManager:
"""Mock ConfigManager"""
@staticmethod
def get_from_route(schema: Type[T]) -> T:
return schema()

def route(path: str = "", config_schema: Optional[Type[Any]] = None):
"""Mock route decorator"""
def decorator(func):
return func
return decorator

def instrument():
"""Mock instrument decorator"""
def decorator(func):
return func
return decorator
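
One plausible way to wire this mock in (not shown in the commit) is to register it under the agenta module name before the service code is imported; note that it only stubs ConfigManager, route, and instrument. The module path below is an assumption.

# Hypothetical test setup; the import path is an assumption.
import sys
import mock_agenta

sys.modules["agenta"] = mock_agenta  # any later "import agenta as ag" now resolves to the mock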
70 changes: 70 additions & 0 deletions services/test/mock_litellm.py
@@ -0,0 +1,70 @@
import pytest
from typing import List, Optional, Dict, Any
from dataclasses import dataclass

@dataclass
class Message:
    role: str
    content: str
    tool_calls: Optional[List[Dict[str, Any]]] = None

@dataclass
class Choice:
    message: Message
    index: int = 0
    finish_reason: str = "stop"

@dataclass
class Response:
    choices: List[Choice]
    model: str = "gpt-4"
    id: str = "mock-response-id"

class MockLiteLLM:
    """Mock LiteLLM for testing"""

    async def acompletion(self, **kwargs):
        """Mock async completion"""
        model = kwargs.get("model", "gpt-4")
        messages = kwargs.get("messages", [])
        tools = kwargs.get("tools", [])
        response_format = kwargs.get("response_format", None)

        # Simulate different response types based on input
        if tools:
            # Function calling response
            tool_calls = [{
                "id": "call_123",
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": '{"location": "London", "unit": "celsius"}'
                }
            }]
            message = Message(
                role="assistant",
                content=None,
                tool_calls=tool_calls
            )
        elif response_format and response_format["type"] == "json_object":
            # JSON response
            message = Message(
                role="assistant",
                content='{"colors": ["red", "blue", "green"]}'
            )
        else:
            # Regular text response
            message = Message(
                role="assistant",
                content="This is a mock response"
            )

        return Response(
            choices=[Choice(message=message)],
            model=model
        )

@pytest.fixture
def mock_litellm():
    """Fixture to provide mock LiteLLM instance"""
    return MockLiteLLM()
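
A minimal sketch of how the mock_litellm fixture might be consumed in a test (no such test is included in this commit); the async test style assumes pytest-asyncio, and the assertion relies on the canned JSON response above.

# Hypothetical test; assumes pytest-asyncio is installed.
import pytest
from mock_litellm import mock_litellm  # imported so pytest can discover the fixture

@pytest.mark.asyncio
async def test_json_response(mock_litellm):
    response = await mock_litellm.acompletion(
        model="gpt-4",
        messages=[{"role": "user", "content": "List three colors as JSON"}],
        response_format={"type": "json_object"},
    )
    assert response.choices[0].message.content == '{"colors": ["red", "blue", "green"]}'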