Track usage on OpenAI stream requests, closes #591
simonw committed Oct 29, 2024
1 parent ba1ccb3 commit 389acdf
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions llm/default_plugins/openai_models.py
@@ -346,7 +346,7 @@ def execute(self, prompt, stream, response, conversation=None):
                 )
             messages.append({"role": "user", "content": attachment_message})
 
-        kwargs = self.build_kwargs(prompt)
+        kwargs = self.build_kwargs(prompt, stream)
         client = self.get_client()
         if stream:
             completion = client.chat.completions.create(
@@ -358,7 +358,10 @@ def execute(self, prompt, stream, response, conversation=None):
             chunks = []
             for chunk in completion:
                 chunks.append(chunk)
-                content = chunk.choices[0].delta.content
+                try:
+                    content = chunk.choices[0].delta.content
+                except IndexError:
+                    content = None
                 if content is not None:
                     yield content
             response.response_json = remove_dict_none_values(combine_chunks(chunks))
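
Note: with stream_options={"include_usage": True} set (added in build_kwargs below), the API streams one final chunk whose choices list is empty and which carries only the usage totals, so chunk.choices[0] raises IndexError on that chunk; the try/except above guards exactly that case. A minimal equivalent sketch of the loop using a truthiness check instead of try/except:

    for chunk in completion:
        chunks.append(chunk)
        # the usage-only final chunk arrives with chunk.choices == []
        content = chunk.choices[0].delta.content if chunk.choices else None
        if content is not None:
            yield content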
@@ -395,13 +398,15 @@ def get_client(self):
             kwargs["http_client"] = logging_client()
         return openai.OpenAI(**kwargs)
 
-    def build_kwargs(self, prompt):
+    def build_kwargs(self, prompt, stream):
         kwargs = dict(not_nulls(prompt.options))
         json_object = kwargs.pop("json_object", None)
         if "max_tokens" not in kwargs and self.default_max_tokens is not None:
             kwargs["max_tokens"] = self.default_max_tokens
         if json_object:
             kwargs["response_format"] = {"type": "json_object"}
+        if stream:
+            kwargs["stream_options"] = {"include_usage": True}
         return kwargs
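
For context: stream_options is the openai-python parameter this commit relies on (available in 1.26+), and {"include_usage": True} asks the API to append a usage-only chunk at the end of the stream. A minimal standalone sketch of the behavior, outside this plugin (model name and prompt are illustrative, not from the commit; assumes OPENAI_API_KEY is set):

    import openai

    client = openai.OpenAI()
    completion = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Say hi"}],
        stream=True,
        stream_options={"include_usage": True},
    )
    usage = None
    for chunk in completion:
        if chunk.usage:  # None on every chunk except the final usage-only one
            usage = chunk.usage
        if chunk.choices and chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="")
    print(usage)  # e.g. CompletionUsage(completion_tokens=3, prompt_tokens=10, total_tokens=13)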


@@ -431,7 +436,7 @@ def execute(self, prompt, stream, response, conversation=None):
                 messages.append(prev_response.prompt.prompt)
                 messages.append(prev_response.text())
         messages.append(prompt.prompt)
-        kwargs = self.build_kwargs(prompt)
+        kwargs = self.build_kwargs(prompt, stream)
         client = self.get_client()
         if stream:
             completion = client.completions.create(
@@ -443,7 +448,10 @@ def execute(self, prompt, stream, response, conversation=None):
             chunks = []
             for chunk in completion:
                 chunks.append(chunk)
-                content = chunk.choices[0].text
+                try:
+                    content = chunk.choices[0].text
+                except IndexError:
+                    content = None
                 if content is not None:
                     yield content
             combined = combine_chunks(chunks)
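
The legacy completions endpoint gets the same treatment: the streamed delta lives on choices[0].text rather than choices[0].delta.content, but the usage-only final chunk has the same empty choices shape, so the identical IndexError guard applies.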
@@ -472,8 +480,11 @@ def combine_chunks(chunks: List) -> dict:
     # If any of them have log probability, we're going to persist
     # those later on
     logprobs = []
+    usage = {}
 
     for item in chunks:
+        if item.usage:
+            usage = dict(item.usage)
         for choice in item.choices:
             if choice.logprobs and hasattr(choice.logprobs, "top_logprobs"):
                 logprobs.append(
@@ -497,6 +508,7 @@ def combine_chunks(chunks: List) -> dict:
         "content": content,
         "role": role,
         "finish_reason": finish_reason,
+        "usage": usage,
     }
     if logprobs:
         combined["logprobs"] = logprobs
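
For reference, a sketch of the combined dict a streamed chat response now produces (values illustrative; dict(item.usage) works because the usage object is a pydantic model that iterates as field/value pairs, and newer openai-python versions may add extra detail fields to it):

    {
        "content": "Hi there!",
        "role": "assistant",
        "finish_reason": "stop",
        "usage": {
            "completion_tokens": 3,
            "prompt_tokens": 10,
            "total_tokens": 13,
        },
    }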
