Skip to content

Commit

Permalink
Refactor (sdk/types): improve sdk types to be compatible with pydanti…
Browse files Browse the repository at this point in the history
…c v2
  • Loading branch information
aybruhm committed May 25, 2024
1 parent 5641fd2 commit c0be44e
Showing 1 changed file with 65 additions and 83 deletions.
148 changes: 65 additions & 83 deletions agenta-cli/agenta/sdk/types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from typing import Any, Dict, List, Optional
from typing import Dict, List, Optional

from pydantic import BaseModel, Extra, HttpUrl, Field
from pydantic import ConfigDict, BaseModel, HttpUrl


class InFile:
Expand All @@ -24,111 +24,97 @@ class FuncResponse(BaseModel):


class DictInput(dict):
def __new__(cls, default_keys=None):
def __new__(cls, default_keys: Optional[List[str]] = None):
instance = super().__new__(cls, default_keys)
if default_keys is None:
default_keys = []
instance.data = [key for key in default_keys]
instance.data = [key for key in default_keys] # type: ignore
return instance

@classmethod
def __modify_schema__(cls, field_schema):
field_schema.update({"x-parameter": "dict"})
def __schema__(cls) -> dict:
return {"x-parameter": "dict"}


class TextParam(str):
@classmethod
def __modify_schema__(cls, field_schema):
field_schema.update({"x-parameter": "text"})
def __schema__(cls) -> dict:
return {"x-parameter": "text", "type": "string"}


class BinaryParam(int):
def __new__(cls, value: bool = False):
instance = super().__new__(cls, int(value))
instance.default = value
instance.default = value # type: ignore
return instance

@classmethod
def __modify_schema__(cls, field_schema):
field_schema.update(
{
"x-parameter": "bool",
"type": "boolean",
}
)
def __schema__(cls) -> dict:
return {
"x-parameter": "bool",
"type": "boolean",
}


class IntParam(int):
def __new__(cls, default: int = 6, minval: float = 1, maxval: float = 10):
instance = super().__new__(cls, default)
instance.minval = minval
instance.maxval = maxval
instance.minval = minval # type: ignore
instance.maxval = maxval # type: ignore
return instance

@classmethod
def __modify_schema__(cls, field_schema):
field_schema.update(
{
"x-parameter": "int",
"type": "integer",
"minimum": 1,
"maximum": 10,
}
)
def __schema__(cls) -> dict:
return {"x-parameter": "int", "type": "integer"}


class FloatParam(float):
def __new__(cls, default: float = 0.5, minval: float = 0.0, maxval: float = 1.0):
instance = super().__new__(cls, default)
instance.minval = minval
instance.maxval = maxval
instance.default = default # type: ignore
instance.minval = minval # type: ignore
instance.maxval = maxval # type: ignore
return instance

@classmethod
def __modify_schema__(cls, field_schema):
field_schema.update(
{
"x-parameter": "float",
"type": "number",
"minimum": 0.0,
"maximum": 1.0,
}
)
def __schema__(cls) -> dict:
return {"x-parameter": "float", "type": "number"}


class MultipleChoiceParam(str):
def __new__(cls, default: str = None, choices: List[str] = None):
if type(default) is list:
def __new__(
cls, default: Optional[str] = None, choices: Optional[List[str]] = None
):
if default is not None and type(default) is list:
raise ValueError(
"The order of the parameters for MultipleChoiceParam is wrong! It's MultipleChoiceParam(default, choices) and not the opposite"
)
if default is None and choices:

if not default and choices is not None:
# if a default value is not provided,
# uset the first value in the choices list
# set the first value in the choices list
default = choices[0]

if default is None and not choices:
# raise error if no default value or choices is provided
raise ValueError("You must provide either a default value or choices")

instance = super().__new__(cls, default)
instance.choices = choices
instance.default = default
instance.choices = choices # type: ignore
instance.default = default # type: ignore
return instance

@classmethod
def __modify_schema__(cls, field_schema: dict[str, Any]):
field_schema.update(
{
"x-parameter": "choice",
"type": "string",
"enum": [],
}
)
def __schema__(cls) -> dict:
return {"x-parameter": "choice", "type": "string", "enum": []}


class GroupedMultipleChoiceParam(str):
def __new__(cls, default: str = None, choices: Dict[str, List[str]] = None):
def __new__(
cls,
default: Optional[str] = None,
choices: Optional[Dict[str, List[str]]] = None,
):
if choices is None:
choices = {}

Expand All @@ -143,31 +129,23 @@ def __new__(cls, default: str = None, choices: Dict[str, List[str]] = None):
)

if not default:
for choices in choices.values():
if choices:
default = choices[0]
break
default_selected_choice = next(
(choices for choices in choices.values()), None
)
if default_selected_choice:
default = default_selected_choice[0]

instance = super().__new__(cls, default)
instance.choices = choices
instance.default = default
instance.choices = choices # type: ignore
instance.default = default # type: ignore
return instance

@classmethod
def __modify_schema__(cls, field_schema: dict[str, Any], **kwargs):
choices = kwargs.get("choices", {})
field_schema.update(
{
"x-parameter": "grouped_choice",
"type": "string",
"choices": choices,
}
)


class Message(BaseModel):
role: str
content: str
def __schema__(cls) -> dict:
return {
"x-parameter": "grouped_choice",
"type": "string",
}


class MessagesInput(list):
Expand All @@ -182,28 +160,32 @@ class MessagesInput(list):
"""

def __new__(cls, messages: List[Dict[str, str]] = None):
instance = super().__new__(cls, messages)
instance.default = messages
def __new__(cls, messages: List[Dict[str, str]] = []):
instance = super().__new__(cls)
instance.default = messages # type: ignore
return instance

@classmethod
def __modify_schema__(cls, field_schema: dict[str, Any]):
field_schema.update({"x-parameter": "messages", "type": "array"})
def __schema__(cls) -> dict:
return {"x-parameter": "messages", "type": "array"}


class FileInputURL(HttpUrl):
def __new__(cls, url: str):
instance = super().__new__(cls, url)
instance.default = url # type: ignore
return instance

@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
field_schema.update({"x-parameter": "file_url", "type": "string"})
def __schema__(cls) -> dict:
return {"x-parameter": "file_url", "type": "string"}


class Context(BaseModel):
class Config:
extra = Extra.allow
model_config = ConfigDict(extra="allow")

def to_json(self):
return self.json()
return self.model_dump()

@classmethod
def from_json(cls, json_str: str):
Expand Down

0 comments on commit c0be44e

Please sign in to comment.