diff --git a/integrations/llama_cpp/tests/test_chat_generator.py b/integrations/llama_cpp/tests/test_chat_generator.py index 332ee788b..12c2f416f 100644 --- a/integrations/llama_cpp/tests/test_chat_generator.py +++ b/integrations/llama_cpp/tests/test_chat_generator.py @@ -334,9 +334,9 @@ def get_current_temperature(self, location): @pytest.fixture def generator(self, model_path, capsys): gguf_model_path = ( - "https://huggingface.co/bartowski/functionary-small-v3.1-GGUF/blob/main/functionary-small-v3.1-Q4_K_M.gguf" + "https://huggingface.co/meetkai/functionary-small-v2.4-GGUF/resolve/main/functionary-small-v2.4.Q4_0.gguf" ) - filename = "functionary-small-v3.1-Q4_K_M.gguf" + filename = "functionary-small-v2.4.Q4_0.gguf" download_file(gguf_model_path, str(model_path / filename), capsys) model_path = str(model_path / filename) hf_tokenizer_path = "meetkai/functionary-small-v2.4-GGUF" @@ -344,6 +344,7 @@ def generator(self, model_path, capsys): model=model_path, n_ctx=8192, n_batch=512, + generation_kwargs={"max_tokens": 256}, model_kwargs={ "chat_format": "functionary-v2", "hf_tokenizer_path": hf_tokenizer_path, @@ -399,7 +400,6 @@ def test_function_call_and_execute(self, generator): "type": "string", "description": "The city and state, e.g. San Francisco, CA", }, - "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, }, "required": ["location"], }, @@ -407,7 +407,8 @@ def test_function_call_and_execute(self, generator): } ] - response = generator.run(messages=messages, generation_kwargs={"tools": tools}) + tool_choice = {"type": "function", "function": {"name": "get_current_temperature"}} + response = generator.run(messages=messages, generation_kwargs={"tools": tools, "tool_choice": tool_choice}) available_functions = { "get_current_temperature": self.get_current_temperature, @@ -417,6 +418,7 @@ def test_function_call_and_execute(self, generator): assert len(response["replies"]) > 0 first_reply = response["replies"][0] + print(first_reply) assert "tool_calls" in first_reply.meta tool_calls = first_reply.meta["tool_calls"]