Skip to content

Commit

Permalink
Fix OpenAI assistant API related bugs and add tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
notsyncing committed Jan 5, 2025
1 parent 3300fce commit a1b3490
Show file tree
Hide file tree
Showing 13 changed files with 517 additions and 36 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ See [CHANGELOG](./CHANGELOG.md) for more details.
|Assistants (v2)|Assistants|☑️|`response_format` not implemented yet|
|Assistants (v2)|Threads|||
|Assistants (v2)|Messages|||
|Assistants (v2)|Runs & Run steps|🚧|`include[]`, `response_format` not implemented yet, tools not implemented yet, `stream` not implemented yet|
|Assistants (v2)|Runs & Run steps|🚧|`include[]`, `response_format` not implemented yet, tools not implemented yet, `stream` not implemented yet, some other small parts may also not be implemented yet|
|Assistants (v2)|Vector stores||Vector store bytes used is estimated|
|Assistants (v2)|Vector store files||Vector store bytes used is estimated|
|Assistants (v2)|Vector store file batches||Vector store bytes used is estimated|
Expand Down
32 changes: 29 additions & 3 deletions src/azarrot/agents/chat_task_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
from typing_extensions import override

from azarrot.agents.common_data import (
AgentChatTaskAutoThreadHistoryStrategyParams,
AgentChatTaskDetailToolCallItem,
AgentChatTaskInfo,
AgentChatTaskLastMessageThreadHistoryStrategyParams,
AgentChatTaskMessageDetailsData,
AgentChatTaskToolCallDetailsData,
AgentToolRequest,
Expand Down Expand Up @@ -49,6 +51,7 @@
AgentChatTaskRequiredAction,
AgentChatTaskStatus,
)
from azarrot.config import ServerConfig
from azarrot.database_schemas import AgentChatTask, AgentChatTaskDetail, AgentChatTaskTool
from azarrot.file_store import FileStore
from azarrot.frontends.backend_pipe import BackendPipe
Expand All @@ -65,6 +68,7 @@ class AgentChatTaskExecutor(ChatThreadMessageListener):
_log: Logger = logging.getLogger(__name__)

_database: Engine
_server_config: ServerConfig
_chat_template_manager: ChatTemplateManager
_agent_manager: AgentManager
_model_manager: ModelManager
Expand All @@ -78,6 +82,7 @@ class AgentChatTaskExecutor(ChatThreadMessageListener):
def __init__(
self,
database: Engine,
server_config: ServerConfig,
chat_template_manager: ChatTemplateManager,
agent_manager: AgentManager,
model_manager: ModelManager,
Expand All @@ -86,6 +91,7 @@ def __init__(
backend_pipe: BackendPipe,
) -> None:
self._database = database
self._server_config = server_config
self._chat_template_manager = chat_template_manager
self._agent_manager = agent_manager
self._model_manager = model_manager
Expand Down Expand Up @@ -261,8 +267,23 @@ def __to_backend_generation_messages(self, messages: list[ChatMessageItem]) -> l

return results

def __determine_fetch_msg_count(self, task: AgentChatTaskInfo) -> int:
    """Resolve how many latest thread messages should be fetched for *task*.

    The count comes from the task's thread history strategy parameters:
    the tail-preserve count for the "auto" strategy, or the configured
    count for "last_messages". Any other strategy yields -1, which the
    caller treats as "no limit".
    """
    strategy = task.thread_history_strategy
    params = task.thread_history_strategy_params

    if strategy == "auto":
        # TODO: Implement truncating with head_preserve_count
        assert isinstance(params, AgentChatTaskAutoThreadHistoryStrategyParams)
        return params.tail_preserve_count

    if strategy == "last_messages":
        assert isinstance(params, AgentChatTaskLastMessageThreadHistoryStrategyParams)
        return params.count

    return -1

def __execute_task(self, task: AgentChatTaskInfo) -> None:
latest_messages = self._chat_thread_manager.get_latest_messages(task.thread_id)
latest_messages = self._chat_thread_manager.get_latest_messages(
task.thread_id, count=self.__determine_fetch_msg_count(task)
)

if len(latest_messages) <= 0:
raise ValueError(f"No message on thread {task.thread_id}")
Expand Down Expand Up @@ -382,12 +403,15 @@ def __run_generation(
if result is None:
result = content

if self._server_config.log_generation_details:
self._log.info("Generation response: %s", result)

gen_stats.end_time = datetime.now()
self._log.info(gen_stats.to_stats_text())

self.__handle_generation_result(task, result)
self.__handle_generation_result(task, result, gen_stats)

def __handle_generation_result(self, task: AgentChatTaskInfo, result: Any) -> None:
def __handle_generation_result(self, task: AgentChatTaskInfo, result: Any, gen_stats: GenerationStatistics) -> None:
task_id = UUID(task.id)

if isinstance(result, str):
Expand All @@ -403,6 +427,7 @@ def __handle_generation_result(self, task: AgentChatTaskInfo, result: Any) -> No
detail_type="message",
detail_status="completed",
data=AgentChatTaskMessageDetailsData(message_id=msg.id),
generation_statistics=gen_stats,
)

self.__update_task_status(task_id, "in_progress", "completed")
Expand Down Expand Up @@ -438,6 +463,7 @@ def __handle_generation_result(self, task: AgentChatTaskInfo, result: Any) -> No
for c in result.tool_requests
]
),
generation_statistics=gen_stats,
)

self.__update_task_required_action(task_id, "tool_call_request", result)
Expand Down
30 changes: 24 additions & 6 deletions src/azarrot/agents/chat_task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from azarrot.common_types import (
AgentChatTaskDetailStatus,
AgentChatTaskDetailType,
AgentChatTaskStatus,
AgentChatTaskThreadHistoryStrategy,
)
from azarrot.database_schemas import (
Expand All @@ -53,6 +54,7 @@ class AgentChatTaskCreationRequest:
thread_id: str | uuid.UUID
model_id: str | None = None
model_instruction: str | None = None
status: AgentChatTaskStatus | None = None
generation_parameters: AgentGenerationParameters | None = None
thread_history_strategy: AgentChatTaskThreadHistoryStrategy | None = None
thread_history_strategy_params: AgentChatTaskThreadHistoryStrategyParams | None = None
Expand Down Expand Up @@ -163,25 +165,36 @@ def create_task(self, request: AgentChatTaskCreationRequest) -> AgentChatTaskInf
gen_params = None

if request.generation_parameters is not None:
gen_params = json.dumps(request.generation_parameters)
gen_params = json.dumps(dataclass_wizard.asdict(request.generation_parameters))

thread_history_strategy = request.thread_history_strategy

if thread_history_strategy is None:
thread_history_strategy = "auto"

if request.thread_history_strategy_params is not None:
thread_history_strategy_params = json.dumps(request.thread_history_strategy_params)
thread_history_strategy_params = json.dumps(
dataclass_wizard.asdict(request.thread_history_strategy_params)
)
else:
thread_history_strategy_params = json.dumps(
dataclass_wizard.asdict(AgentChatTaskAutoThreadHistoryStrategyParams())
)

if request.tools_info is not None:
tools_info = json.dumps(dataclass_wizard.asdict(request.tools_info))
else:
thread_history_strategy_params = json.dumps(AgentChatTaskAutoThreadHistoryStrategyParams())
tools_info = None

status = request.status if request.status is not None else "pending"

db_task = AgentChatTask(
id=task_id,
agent_id=agent_id,
thread_id=thread_id,
model_id=model_id,
model_instruction=model_instruction,
status="pending",
status=status,
current_required_action=None,
current_required_action_data=None,
start_time=None,
Expand All @@ -191,7 +204,7 @@ def create_task(self, request: AgentChatTaskCreationRequest) -> AgentChatTaskInf
thread_history_strategy=thread_history_strategy,
thread_history_strategy_params=thread_history_strategy_params,
max_tokens=request.max_tokens if request.max_tokens is not None else -1,
tools_info=json.dumps(request.tools_info) if request.tools_info is not None else None,
tools_info=tools_info,
parallel_tool_calling=request.parallel_tool_calling,
additional_data=json.dumps(request.additional_data) if request.additional_data is not None else None,
create_time=now,
Expand All @@ -209,7 +222,10 @@ def create_task(self, request: AgentChatTaskCreationRequest) -> AgentChatTaskInf
db.commit()

info = AgentChatTaskInfo.from_db(db_task, agent_tools)
self._executor.add_task(info)

if status == "pending":
self._executor.add_task(info)

return info

def get_current_tasks_by_messages(self, message_id_list: Sequence[str | uuid.UUID]) -> dict[str, AgentChatTaskInfo]:
Expand Down Expand Up @@ -371,6 +387,8 @@ def update(
return True

def cancel(self, thread_id: str | uuid.UUID, agent_chat_task_id: str | uuid.UUID) -> bool:
# TODO: Implement cancelling in_progress run

thread_id = sanitize_uuid(thread_id)
agent_chat_task_id = sanitize_uuid(agent_chat_task_id)

Expand Down
7 changes: 3 additions & 4 deletions src/azarrot/agents/common_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from azarrot.common_data import (
CallableToolsInfo,
GenerationStatistics,
ToolCallRequestMessageContent,
ToolCallRequestMessageContents,
)
from azarrot.common_types import (
Expand Down Expand Up @@ -109,11 +108,11 @@ def from_db(dbo: AgentChatTask, db_tools: Sequence[AgentChatTaskTool] | None = N
current_req_action = dbo.current_required_action

if current_req_action == "tool_call_request" and dbo.current_required_action_data is not None:
req_list = dataclass_wizard.fromlist(
ToolCallRequestMessageContent, json.loads(dbo.current_required_action_data)
reqs = dataclass_wizard.fromdict(
ToolCallRequestMessageContents, json.loads(dbo.current_required_action_data)
)

current_req_action_data = ToolCallRequestMessageContents(req_list)
current_req_action_data = ToolCallRequestMessageContents(reqs.tool_requests)
else:
current_req_action_data = None

Expand Down
5 changes: 5 additions & 0 deletions src/azarrot/backends/backend_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,8 @@ def stop(self) -> None:


class BaseBackend(ABC):
_log: logging.Logger = logging.getLogger(__name__)

_device_queues: ClassVar[dict[str, DeviceWorker]] = {}
_device_queue_lock: ClassVar[threading.Lock] = threading.Lock()

Expand Down Expand Up @@ -318,6 +320,9 @@ def __submit_task_to_device(self, task: BackendGenerationTask) -> TaskReference:
def generate(
self, request: TextGenerationRequest, generation_handlers: GenerationHandlers
) -> tuple[CustomTextIteratorStreamer, GenerationStatistics]:
if self._server_config.log_generation_details:
self._log.info("Generation request: %s", request)

task, streamer, gen_stats = self._generate(request, generation_handlers)
self.__submit_task_to_device(task)

Expand Down
5 changes: 5 additions & 0 deletions src/azarrot/backends/transformers_based_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,11 @@ def __generate_normal(
}
)

seed = request.seed

if seed is None:
seed = self._server_config.default_seed

return TransformersGenerationMethods(
model=loaded_model.model, streamer=streamer, seed=request.seed, generation_kwargs=generation_kwargs
)
Expand Down
20 changes: 10 additions & 10 deletions src/azarrot/chats/thread_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,17 +475,17 @@ def get_message(
def get_latest_messages(self, thread_id: str | uuid.UUID, count: int = 1) -> list[ChatMessageItem]:
thread_id = sanitize_uuid(thread_id)

query = (
select(ChatMessage.id)
.where(and_(ChatMessage.thread_id == thread_id, ChatMessage.deleted == False))
.order_by(ChatMessage.create_time.desc(), ChatMessage.order.desc())
)

if count > 0:
query = query.limit(count)

with Session(self._database) as db:
latest_msg_id_list = (
db.execute(
select(ChatMessage.id)
.where(and_(ChatMessage.thread_id == thread_id, ChatMessage.deleted == False))
.order_by(ChatMessage.create_time.desc(), ChatMessage.order.desc())
.limit(count)
)
.scalars()
.all()
)
latest_msg_id_list = db.execute(query).scalars().all()

if latest_msg_id_list is None:
return []
Expand Down
2 changes: 2 additions & 0 deletions src/azarrot/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class ServerConfig:
single_token_generation_timeout: int = 60000
auto_batch_threshold: int = 100
auto_batch_max_size: int = 8
default_seed: int | None = None
log_generation_details: bool = False

partial_file_expire_time: int = 3600 * 1000

Expand Down
21 changes: 12 additions & 9 deletions src/azarrot/frontends/openai_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
ToolCallResponseMessageContent,
WorkingDirectories,
)
from azarrot.config import DEFAULT_MAX_TOKENS, OpenAIFrontendConfig
from azarrot.config import DEFAULT_MAX_TOKENS, ServerConfig
from azarrot.file_store import FileStore
from azarrot.frontends.backend_pipe import BackendPipe
from azarrot.frontends.base import Frontend
Expand Down Expand Up @@ -58,7 +58,7 @@

class OpenAIFrontend(Frontend):
_log = logging.getLogger(__name__)
_openai_config: OpenAIFrontendConfig
_server_config: ServerConfig
_model_manager: ModelManager
_backend_pipe: BackendPipe
_working_dirs: WorkingDirectories
Expand All @@ -72,7 +72,7 @@ class OpenAIFrontend(Frontend):

def __init__(
self,
openai_config: OpenAIFrontendConfig,
server_config: ServerConfig,
model_manager: ModelManager,
backend_pipe: BackendPipe,
file_store: FileStore,
Expand All @@ -83,19 +83,19 @@ def __init__(
api: FastAPI,
working_dirs: WorkingDirectories,
) -> None:
self._openai_config = openai_config
self._server_config = server_config
self._model_manager = model_manager
self._working_dirs = working_dirs
self._backend_pipe = backend_pipe
self._openai_files = OpenAIFiles(file_store)
self._assistants = OpenAIAssistants(openai_config, agent_manager, vector_store, model_manager)
self._assistants = OpenAIAssistants(server_config.openai_configs, agent_manager, vector_store, model_manager)

self._threads = OpenAIAssistantThreads(
openai_config, chat_thread_manager, vector_store, model_manager, file_store
server_config.openai_configs, chat_thread_manager, vector_store, model_manager, file_store
)

self._messages = OpenAIAssistantMessages(file_store, chat_thread_manager, agent_chat_task_manager)
self._vstores = OpenAIVectorStores(openai_config, model_manager, vector_store)
self._vstores = OpenAIVectorStores(server_config.openai_configs, model_manager, vector_store)
self._runs = OpenAIAssistantRuns(agent_chat_task_manager, chat_thread_manager, self._threads)
self._api = api

Expand Down Expand Up @@ -140,6 +140,7 @@ def __init_routes(self) -> None: # noqa: PLR0915

# Assistants - Threads API
router.add_api_route("/v1/threads", self._threads.create_thread, methods=["POST"])
router.add_api_route("/v1/threads/runs", self._runs.run_assistant_with_thread, methods=["POST"])
router.add_api_route("/v1/threads/{thread_id}", self._threads.get_thread, methods=["GET"])
router.add_api_route("/v1/threads/{thread_id}", self._threads.update_thread, methods=["POST"])
router.add_api_route("/v1/threads/{thread_id}", self._threads.delete_thread, methods=["DELETE"])
Expand All @@ -153,9 +154,8 @@ def __init_routes(self) -> None: # noqa: PLR0915

r_url = "/v1/threads/{tid}/runs"

# Assistants - Run API
# Assistants - Runs API
router.add_api_route(r_url, self._runs.run_assistant, methods=["POST"])
router.add_api_route("/v1/threads/runs", self._runs.run_assistant_with_thread, methods=["POST"])
router.add_api_route(r_url, self._runs.get_run_list, methods=["GET"])
router.add_api_route(r_url + "/{rid}", self._runs.get_run, methods=["GET"])
router.add_api_route(r_url + "/{rid}", self._runs.update_run, methods=["POST"])
Expand Down Expand Up @@ -411,6 +411,9 @@ def chat_completions(self, request: ChatCompletionRequest) -> dict | StreamingRe
if result is None:
result = content

if self._server_config.log_generation_details:
self._log.info("Generation response: %s", result)

gen_stats.end_time = datetime.now()
self.__log_generation_statistics(gen_stats)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ class OpenAIAssistantToolCallsRunStep:


@dataclass
class OpenAIAssistantRunStep(BaseModel):
class OpenAIAssistantRunStep:
id: str
created_at: int
assistant_id: str
Expand Down
Loading

0 comments on commit a1b3490

Please sign in to comment.