From 81a2084cbb4d9bdbfff592c6fe25f3e19e7111a8 Mon Sep 17 00:00:00 2001
From: Richard Blythman
Date: Thu, 8 Feb 2024 09:45:31 +0000
Subject: [PATCH] fix bug in tools

---
 tools/prediction_request/prediction_request.py     |  8 ++++----
 .../prediction_request_sme.py                      | 14 +++++++-------
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/tools/prediction_request/prediction_request.py b/tools/prediction_request/prediction_request.py
index 0cb0c61d..60f53886 100644
--- a/tools/prediction_request/prediction_request.py
+++ b/tools/prediction_request/prediction_request.py
@@ -289,8 +289,8 @@ def fetch_additional_information(
         texts.append(extract_text(html=source_link, num_words=num_words))
     if counter_callback:
         counter_callback(
-            input_tokens=response["usage"]["prompt_tokens"],
-            output_tokens=response["usage"]["completion_tokens"],
+            input_tokens=response.usage.prompt_tokens,
+            output_tokens=response.usage.completion_tokens,
             model=engine,
         )
     return "\n".join(["- " + text for text in texts]), counter_callback
@@ -418,8 +418,8 @@ def run(**kwargs) -> Tuple[Optional[str], Optional[Dict[str, Any]], Any]:
         )
     if counter_callback is not None:
         counter_callback(
-            input_tokens=response["usage"]["prompt_tokens"],
-            output_tokens=response["usage"]["completion_tokens"],
+            input_tokens=response.usage.prompt_tokens,
+            output_tokens=response.usage.completion_tokens,
             model=engine,
         )
     return response.choices[0].message.content, prediction_prompt, counter_callback
diff --git a/tools/prediction_request_sme/prediction_request_sme.py b/tools/prediction_request_sme/prediction_request_sme.py
index 20f4c67e..351423a6 100644
--- a/tools/prediction_request_sme/prediction_request_sme.py
+++ b/tools/prediction_request_sme/prediction_request_sme.py
@@ -307,8 +307,8 @@ def fetch_additional_information(
         texts.append(extract_text(html=source_link, num_words=num_words))
     if counter_callback:
         counter_callback(
-            input_tokens=response["usage"]["prompt_tokens"],
-            output_tokens=response["usage"]["completion_tokens"],
+            input_tokens=response.usage.prompt_tokens,
+            output_tokens=response.usage.completion_tokens,
             model=engine,
         )
     return "\n".join(["- " + text for text in texts]), counter_callback
@@ -339,9 +339,9 @@ def get_sme_role(
     sme = json.loads(generated_sme_roles)[0]
     if counter_callback is not None:
         counter_callback(
-            input_tokens=response["usage"]["prompt_tokens"],
-            output_tokens=response["usage"]["completion_tokens"],
-            total_tokens=response["usage"]["total_tokens"],
+            input_tokens=response.usage.prompt_tokens,
+            output_tokens=response.usage.completion_tokens,
+            total_tokens=response.usage.total_tokens,
             model=engine,
         )
     return sme["sme"], sme["sme_introduction"], counter_callback
@@ -421,8 +421,8 @@ def run(**kwargs) -> Tuple[str, Optional[str], Optional[Dict[str, Any]], Any]:
         )
     if counter_callback is not None:
         counter_callback(
-            input_tokens=response["usage"]["prompt_tokens"],
-            output_tokens=response["usage"]["completion_tokens"],
+            input_tokens=response.usage.prompt_tokens,
+            output_tokens=response.usage.completion_tokens,
             model=engine,
         )
     return response.choices[0].message.content, prediction_prompt, counter_callback