diff --git a/berkeley-function-call-leaderboard/bfcl/model_handler/proprietary_model/claude.py b/berkeley-function-call-leaderboard/bfcl/model_handler/proprietary_model/claude.py
index 8d81ecfbc..30fa6dc17 100644
--- a/berkeley-function-call-leaderboard/bfcl/model_handler/proprietary_model/claude.py
+++ b/berkeley-function-call-leaderboard/bfcl/model_handler/proprietary_model/claude.py
@@ -3,6 +3,8 @@
 from anthropic import Anthropic
 from anthropic.types import TextBlock, ToolUseBlock
+from anthropic.types.beta.message_create_params import MessageCreateParamsNonStreaming
+from anthropic.types.beta.messages.batch_create_params import Request
 from bfcl.model_handler.base_handler import BaseHandler
 from bfcl.model_handler.constant import DEFAULT_SYSTEM_PROMPT, GORILLA_TO_OPENAPI
 from bfcl.model_handler.model_style import ModelStyle
@@ -66,23 +68,43 @@ def decode_execute(self, result):
         else:
             function_call = convert_to_function_call(result)
             return function_call
+
+    # Helper function to process the batch response and return results
+    def _process_batch_response(self, batch_response):
+        results = []
+        for result in batch_response:
+            if result.result.type == "succeeded":
+                results.append(result.result.message.content)
+            elif result.result.type == "errored":
+                print(f"Error: {result.result.error}")
+
+        return results
 
     #### FC methods ####
 
     def _query_FC(self, inference_data: dict):
-        inference_data["inference_input_log"] = {
-            "message": repr(inference_data["message"]),
-            "tools": inference_data["tools"],
-        }
+        # Initialize batch request list
+        batch_requests = []
+
+        # For each message in the inference data, add to the batch request
+        for message in inference_data["message"]:
+            batch_requests.append(
+                Request(
+                    custom_id=f"fc-{message['content'][:20]}",  # Custom ID for each request
+                    params=MessageCreateParamsNonStreaming(
+                        model=self.model_name.strip("-FC"),
+                        max_tokens=8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096,
+                        tools=inference_data["tools"],
+                        messages=[message],
+                    )
+                )
+            )
+
+        # Send the batch request
+        batch_response = self.client.beta.messages.batches.create(requests=batch_requests)
 
-        return self.client.messages.create(
-            model=self.model_name.strip("-FC"),
-            max_tokens=(
-                8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096
-            ),  # 3.5 Sonnet has a higher max token limit
-            tools=inference_data["tools"],
-            messages=inference_data["message"],
-        )
+        # Process the batch response
+        return self._process_batch_response(batch_response)
 
     def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> dict:
         for round_idx in range(len(test_entry["question"])):
@@ -183,22 +205,29 @@ def _add_execution_results_FC(
     #### Prompting methods ####
 
     def _query_prompting(self, inference_data: dict):
-        inference_data["inference_input_log"] = {
-            "message": repr(inference_data["message"]),
-            "system_prompt": inference_data["system_prompt"],
-        }
-
-        api_response = self.client.messages.create(
-            model=self.model_name,
-            max_tokens=(
-                8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096
-            ),  # 3.5 Sonnet has a higher max token limit
-            temperature=self.temperature,
-            system=inference_data["system_prompt"],
-            messages=inference_data["message"],
-        )
-
-        return api_response
+        # Initialize batch request list
+        batch_requests = []
+
+        # Add all the messages to the batch
+        for message in inference_data["message"]:
+            batch_requests.append(
+                Request(
+                    custom_id=f"prompt-{message['content'][:20]}",
+                    params=MessageCreateParamsNonStreaming(
+                        model=self.model_name,
+                        max_tokens=8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096,
+                        temperature=self.temperature,
+                        messages=[message],
+                        system=inference_data["system_prompt"],
+                    )
+                )
+            )
+
+        # Send the batch request
+        batch_response = self.client.beta.messages.batches.create(requests=batch_requests)
+
+        # Process the batch response
+        return self._process_batch_response(batch_response)
 
     def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
         functions: list = test_entry["function"]
@@ -261,4 +290,4 @@ def _add_execution_results_prompting(
             {"role": "user", "content": formatted_results_message}
         )
 
-        return inference_data
+        return inference_data
\ No newline at end of file
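
Note on retrieving batch results: client.beta.messages.batches.create(...) returns a batch handle, not the per-request results that _process_batch_response iterates over; with the Anthropic Message Batches API the batch must finish processing before results can be fetched. Below is a minimal sketch of that intermediate step, assuming the current anthropic SDK's beta batches interface (batches.retrieve and batches.results); the helper name _wait_for_batch_results and the poll interval are illustrative and not part of this diff.

import time

from anthropic import Anthropic


def _wait_for_batch_results(client: Anthropic, batch_id: str, poll_interval: float = 10.0):
    # Poll the batch until the API reports that processing has ended.
    while True:
        batch = client.beta.messages.batches.retrieve(batch_id)
        if batch.processing_status == "ended":
            break
        time.sleep(poll_interval)
    # Stream the per-request results; each item carries .custom_id and .result,
    # where .result.type is "succeeded", "errored", "canceled", or "expired",
    # which is the shape _process_batch_response in this diff expects.
    return client.beta.messages.batches.results(batch_id)


# Usage inside _query_FC / _query_prompting (sketch):
#     batch = self.client.beta.messages.batches.create(requests=batch_requests)
#     results = _wait_for_batch_results(self.client, batch.id)
#     return self._process_batch_response(results)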