diff --git a/agenta-cli/agenta/__init__.py b/agenta-cli/agenta/__init__.py index b73eb24d60..8a683f19e1 100644 --- a/agenta-cli/agenta/__init__.py +++ b/agenta-cli/agenta/__init__.py @@ -10,6 +10,7 @@ MessagesInput, TextParam, FileInputURL, + BinaryParam, ) from .sdk.utils.preinit import PreInitObject from .sdk.agenta_init import Config, init diff --git a/agenta-cli/agenta/sdk/__init__.py b/agenta-cli/agenta/sdk/__init__.py index b10b8c1e17..ebd87f40ba 100644 --- a/agenta-cli/agenta/sdk/__init__.py +++ b/agenta-cli/agenta/sdk/__init__.py @@ -12,6 +12,7 @@ TextParam, MessagesInput, FileInputURL, + BinaryParam, ) from .agenta_init import Config, init diff --git a/agenta-cli/agenta/sdk/agenta_decorator.py b/agenta-cli/agenta/sdk/agenta_decorator.py index d47a45f719..677b9c0ae2 100644 --- a/agenta-cli/agenta/sdk/agenta_decorator.py +++ b/agenta-cli/agenta/sdk/agenta_decorator.py @@ -26,6 +26,7 @@ TextParam, MessagesInput, FileInputURL, + BinaryParam, ) app = FastAPI() @@ -352,6 +353,7 @@ def override_schema(openapi_schema: dict, func_name: str, endpoint: str, params: - The default value for DictInput instance - The default value for MessagesParam instance - The default value for FileInputURL instance + - The default value for BinaryParam instance - ... 
Args: @@ -424,3 +426,6 @@ def find_in_schema(schema: dict, param_name: str, xparam: str): ): subschema = find_in_schema(schema_to_override, param_name, "file_url") subschema["default"] = "https://example.com" + if isinstance(param_val, BinaryParam): + subschema = find_in_schema(schema_to_override, param_name, "bool") + subschema["default"] = param_val.default diff --git a/agenta-cli/agenta/sdk/types.py b/agenta-cli/agenta/sdk/types.py index 8c22032bf8..3dc07cb6ef 100644 --- a/agenta-cli/agenta/sdk/types.py +++ b/agenta-cli/agenta/sdk/types.py @@ -1,7 +1,7 @@ import json from typing import Any, Dict, List -from pydantic import BaseModel, Extra, HttpUrl +from pydantic import BaseModel, Extra, HttpUrl, Field class InFile: @@ -29,6 +29,33 @@ def __modify_schema__(cls, field_schema): field_schema.update({"x-parameter": "text"}) +class BoolMeta(type): + """ + This meta class handles the behavior of a boolean without + directly inheriting from it (avoiding the conflict + that comes from inheriting bool). 
+ """ + + def __new__(cls, name: str, bases: tuple, namespace: dict): + if "default" in namespace and namespace["default"] not in [0, 1]: + raise ValueError("Must provide either 0 or 1") + namespace["default"] = bool(namespace.get("default", 0)) + instance = super().__new__(cls, name, bases, namespace) + instance.default = 0 + return instance + + +class BinaryParam(int, metaclass=BoolMeta): + @classmethod + def __modify_schema__(cls, field_schema): + field_schema.update( + { + "x-parameter": "bool", + "type": "boolean", + } + ) + + class IntParam(int): def __new__(cls, default: int = 6, minval: float = 1, maxval: float = 10): instance = super().__new__(cls, default) diff --git a/agenta-web/src/components/Playground/Views/ParametersCards.tsx b/agenta-web/src/components/Playground/Views/ParametersCards.tsx index e18a6a570d..8a298e0d70 100644 --- a/agenta-web/src/components/Playground/Views/ParametersCards.tsx +++ b/agenta-web/src/components/Playground/Views/ParametersCards.tsx @@ -1,8 +1,8 @@ -import {Row, Card, Slider, Select, InputNumber, Col, Input, Button} from "antd" import React from "react" -import {Parameter, InputParameter} from "@/lib/Types" -import {renameVariables} from "@/lib/helpers/utils" import {createUseStyles} from "react-jss" +import {renameVariables} from "@/lib/helpers/utils" +import {Parameter, InputParameter} from "@/lib/Types" +import {Row, Card, Slider, Select, InputNumber, Col, Input, Button, Switch} from "antd" const useStyles = createUseStyles({ row1: { @@ -72,6 +72,10 @@ export const ModelParameters: React.FC = ({ handleParamChange, }) => { const classes = useStyles() + const handleCheckboxChange = (paramName: string, checked: boolean) => { + const value = checked ? 
1 : 0 + handleParamChange(paramName, value) + } return ( <> {optParams?.some((param) => !param.input && param.type === "number") && ( @@ -80,10 +84,11 @@ export const ModelParameters: React.FC = ({ {optParams ?.filter( (param) => - !param.input && - (param.type === "number" || - param.type === "integer" || - param.type === "array"), + (!param.input && + (param.type === "number" || + param.type === "integer" || + param.type === "array")) || + param.type === "boolean", ) .map((param, index) => ( @@ -136,6 +141,14 @@ export const ModelParameters: React.FC = ({ ))} )} + {param.type === "boolean" && ( + + handleCheckboxChange(param.name, checked) + } + /> + )} {param.type === "number" && ( diff --git a/agenta-web/src/lib/helpers/openapi_parser.ts b/agenta-web/src/lib/helpers/openapi_parser.ts index 54d7dd34b7..02a3100d68 100644 --- a/agenta-web/src/lib/helpers/openapi_parser.ts +++ b/agenta-web/src/lib/helpers/openapi_parser.ts @@ -63,6 +63,8 @@ const determineType = (xParam: any): string => { return "number" case "dict": return "object" + case "bool": + return "boolean" case "int": return "integer" case "file_url": diff --git a/examples/chat_json_format/app.py b/examples/chat_json_format/app.py new file mode 100644 index 0000000000..8f9d234480 --- /dev/null +++ b/examples/chat_json_format/app.py @@ -0,0 +1,43 @@ +import agenta as ag +from agenta.sdk.types import BinaryParam +from openai import OpenAI + +client = OpenAI() + +SYSTEM_PROMPT = "You have expertise in offering technical ideas to startups. Responses should be in json." 
+GPT_FORMAT_RESPONSE = ["gpt-3.5-turbo-1106", "gpt-4-1106-preview"] +CHAT_LLM_GPT = [ + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0301", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", + "gpt-4", +] + GPT_FORMAT_RESPONSE + +ag.init() +ag.config.default( + temperature=ag.FloatParam(0.2), + model=ag.MultipleChoiceParam("gpt-3.5-turbo", CHAT_LLM_GPT), + max_tokens=ag.IntParam(-1, -1, 4000), + prompt_system=ag.TextParam(SYSTEM_PROMPT), + force_json_response=BinaryParam(), +) + + +@ag.entrypoint +def chat(inputs: ag.MessagesInput = ag.MessagesInput()): + messages = [{"role": "system", "content": ag.config.prompt_system}] + inputs + max_tokens = ag.config.max_tokens if ag.config.max_tokens != -1 else None + response_format = ( + {"type": "json_object"} + if ag.config.force_json_response and ag.config.model in GPT_FORMAT_RESPONSE + else {"type": "text"} + ) + chat_completion = client.chat.completions.create( + model=ag.config.model, + messages=messages, + temperature=ag.config.temperature, + max_tokens=max_tokens, + response_format=response_format, + ) + return chat_completion.choices[0].message.content diff --git a/examples/chat_json_format/requirements.txt b/examples/chat_json_format/requirements.txt new file mode 100644 index 0000000000..310f162cec --- /dev/null +++ b/examples/chat_json_format/requirements.txt @@ -0,0 +1,2 @@ +agenta +openai \ No newline at end of file