From 389acdf52c3aaea9470b1307a017bdec21a524e2 Mon Sep 17 00:00:00 2001
From: Simon Willison
Date: Mon, 28 Oct 2024 17:40:40 -0700
Subject: [PATCH] Track usage on OpenAI stream requests, closes #591

---
 llm/default_plugins/openai_models.py | 22 +++++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py
index bce1fb04..5cbb02bb 100644
--- a/llm/default_plugins/openai_models.py
+++ b/llm/default_plugins/openai_models.py
@@ -346,7 +346,7 @@ def execute(self, prompt, stream, response, conversation=None):
                 )
             messages.append({"role": "user", "content": attachment_message})
-        kwargs = self.build_kwargs(prompt)
+        kwargs = self.build_kwargs(prompt, stream)
         client = self.get_client()
         if stream:
             completion = client.chat.completions.create(
@@ -358,7 +358,10 @@
             chunks = []
             for chunk in completion:
                 chunks.append(chunk)
-                content = chunk.choices[0].delta.content
+                try:
+                    content = chunk.choices[0].delta.content
+                except IndexError:
+                    content = None
                 if content is not None:
                     yield content
             response.response_json = remove_dict_none_values(combine_chunks(chunks))
@@ -395,13 +398,15 @@ def get_client(self):
             kwargs["http_client"] = logging_client()
         return openai.OpenAI(**kwargs)
 
-    def build_kwargs(self, prompt):
+    def build_kwargs(self, prompt, stream):
         kwargs = dict(not_nulls(prompt.options))
         json_object = kwargs.pop("json_object", None)
         if "max_tokens" not in kwargs and self.default_max_tokens is not None:
             kwargs["max_tokens"] = self.default_max_tokens
         if json_object:
             kwargs["response_format"] = {"type": "json_object"}
+        if stream:
+            kwargs["stream_options"] = {"include_usage": True}
         return kwargs
 
 
@@ -431,7 +436,7 @@ def execute(self, prompt, stream, response, conversation=None):
                 messages.append(prev_response.prompt.prompt)
                 messages.append(prev_response.text())
         messages.append(prompt.prompt)
-        kwargs = self.build_kwargs(prompt)
+        kwargs = self.build_kwargs(prompt, stream)
         client = self.get_client()
         if stream:
             completion = client.completions.create(
@@ -443,7 +448,10 @@
             chunks = []
             for chunk in completion:
                 chunks.append(chunk)
-                content = chunk.choices[0].text
+                try:
+                    content = chunk.choices[0].text
+                except IndexError:
+                    content = None
                 if content is not None:
                     yield content
             combined = combine_chunks(chunks)
@@ -472,8 +480,11 @@ def combine_chunks(chunks: List) -> dict:
     # If any of them have log probability, we're going to persist
     # those later on
     logprobs = []
+    usage = {}
 
     for item in chunks:
+        if item.usage:
+            usage = dict(item.usage)
         for choice in item.choices:
             if choice.logprobs and hasattr(choice.logprobs, "top_logprobs"):
                 logprobs.append(
@@ -497,6 +508,7 @@
         "content": content,
         "role": role,
         "finish_reason": finish_reason,
+        "usage": usage,
     }
     if logprobs:
         combined["logprobs"] = logprobs
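
Note (reviewer sketch, not part of the patch): the change hinges on one piece of
OpenAI streaming behaviour. With stream_options={"include_usage": True} (supported
in the openai Python SDK from 1.26 onwards), the server sends one extra final chunk
whose choices list is empty and whose usage field is populated; that is why the
patch guards chunk.choices[0] with try/except IndexError and collects item.usage in
combine_chunks(). The standalone sketch below demonstrates that behaviour directly;
the model name and prompt are illustrative only.

    # Sketch of the streaming behaviour the patch relies on; assumes
    # openai >= 1.26 and an illustrative model/prompt.
    from openai import OpenAI

    client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

    stream = client.chat.completions.create(
        model="gpt-4o-mini",  # illustrative model choice
        messages=[{"role": "user", "content": "Say hello"}],
        stream=True,
        stream_options={"include_usage": True},
    )

    usage = None
    for chunk in stream:
        # The final usage-only chunk has an empty choices list, which is
        # exactly the case the patch's IndexError guard handles.
        if chunk.choices:
            content = chunk.choices[0].delta.content
            if content is not None:
                print(content, end="", flush=True)
        if chunk.usage is not None:  # only set on that final chunk
            usage = chunk.usage

    print()
    print(usage)  # e.g. CompletionUsage(completion_tokens=..., prompt_tokens=..., total_tokens=...)

Guarding with `if chunk.choices:` and the patch's try/except IndexError are
equivalent here; the patch's version also keeps the existing `content is not None`
flow unchanged for ordinary delta chunks.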