Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add OpenAPI support, OpenAPITool component #8

Merged
merged 42 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
472b2de
Initial openapi impl
vblagoje May 28, 2024
3fd99ac
Refactoring step 1
vblagoje Jun 5, 2024
20be1e5
Refactoring step 2
vblagoje Jun 5, 2024
adb96c5
Refactor step 3
vblagoje Jun 5, 2024
221df1b
Refactoring step 4
vblagoje Jun 5, 2024
e2ecd8d
Add OpenAPITool initial impl
vblagoje Jun 5, 2024
992ef6e
Add headers
vblagoje Jun 5, 2024
910a7b2
Refactoring step 5 - move things around
vblagoje Jun 5, 2024
14d38e9
Fix linting
vblagoje Jun 5, 2024
406aa1b
Refactoring step 6 - simplify generator factory
vblagoje Jun 5, 2024
e5db2d0
Cosmetics
vblagoje Jun 5, 2024
4b35fdf
Add model_kwargs to OpenAPITool init
vblagoje Jun 5, 2024
7bca18f
Fix double ClientConfiguration creation
vblagoje Jun 5, 2024
72922a0
Update internal pydoc
vblagoje Jun 6, 2024
d996cc5
PR feedback
vblagoje Jun 6, 2024
2f01180
Remove lazy imports
vblagoje Jun 6, 2024
4d90cee
Typing fixes
vblagoje Jun 6, 2024
e57ac8a
Add lazy imports
vblagoje Jun 6, 2024
1b87f82
Expose LLMProvider
vblagoje Jun 6, 2024
7e09480
Avoid circular deps
vblagoje Jun 6, 2024
a7fcadc
Add header for types.py
vblagoje Jun 6, 2024
190f113
Improve pydoc
vblagoje Jun 7, 2024
d21bcad
Add back in http bearer auth
vblagoje Jun 17, 2024
e264ec2
Add firecrawl openapi conversion tests
vblagoje Jun 18, 2024
57bb344
PR feedback
vblagoje Jun 18, 2024
df71f3a
Fix header
vblagoje Jun 18, 2024
bfdcf94
PR feedback - details
vblagoje Jun 18, 2024
c6cca91
Lift up OpenAPISpecification
vblagoje Jun 18, 2024
05e38f2
Update OpenAPITool
vblagoje Jun 18, 2024
5de7003
Final touches
vblagoje Jun 18, 2024
32d8ca0
Final touches - pydoc
vblagoje Jun 18, 2024
4f55c96
Minor detail around ClientConfiguration LLMProvider setting
vblagoje Jun 18, 2024
dfcae5c
Make use of OpenAPISpecification explicit
vblagoje Jun 18, 2024
503e71f
First batch of OpenAPITool unit and integration tests
vblagoje Jun 18, 2024
d371dba
Add serde and unit tests
vblagoje Jun 18, 2024
a87d446
Merge branch 'main' into openapi
vblagoje Jun 18, 2024
e5c76c7
Skip github test
vblagoje Jun 18, 2024
ecdc37d
Skip github tests
vblagoje Jun 18, 2024
a5db73b
Increase default request timeout to 30 sec
vblagoje Jun 19, 2024
37e6e54
PR review
vblagoje Jun 19, 2024
7eb140f
Add to Experiments catalog
vblagoje Jun 24, 2024
d714a6f
Merge branch 'main' into openapi
vblagoje Jun 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def send_request(request: Dict[str, Any]) -> Dict[str, Any]:
params=request.get("params", {}),
json=request.get("json"),
auth=request.get("auth"),
timeout=10,
timeout=30,
)
response.raise_for_status()
return response.json()
Expand Down
43 changes: 40 additions & 3 deletions haystack_experimental/components/tools/openapi/openapi_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from haystack import component, logging
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage, ChatRole
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.url_validation import is_valid_http_url

from haystack_experimental.components.tools.openapi._openapi import (
Expand Down Expand Up @@ -73,6 +73,7 @@ def __init__(
:param credentials: Credentials for the tool/service.
"""
self.generator_api = generator_api
self.generator_api_params = generator_api_params or {} # store the generator API parameters for serialization
self.chat_generator = self._init_generator(generator_api, generator_api_params or {})
self.config_openapi: Optional[ClientConfiguration] = None
self.open_api_service: Optional[OpenAPIServiceClient] = None
Expand All @@ -83,7 +84,8 @@ def __init__(
openapi_spec = OpenAPISpecification.from_url(str(spec))
else:
raise ValueError(f"Invalid OpenAPI specification source {spec}. Expected valid file path or URL")

self.spec = spec # store the spec for serialization
self.credentials = credentials # store the credentials for serialization
shadeMe marked this conversation as resolved.
Show resolved Hide resolved
self.config_openapi = ClientConfiguration(
openapi_spec=openapi_spec,
credentials=credentials.resolve_value() if credentials else None,
Expand Down Expand Up @@ -167,6 +169,41 @@ def run(

return {"service_response": response_messages}

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.

:returns:
The serialized component as a dictionary.
"""
if "api_key" in self.generator_api_params:
self.generator_api_params["api_key"] = self.generator_api_params["api_key"].to_dict()
shadeMe marked this conversation as resolved.
Show resolved Hide resolved

return default_to_dict(
self,
generator_api=self.generator_api.value,
generator_api_params=self.generator_api_params,
spec=self.spec,
credentials=self.credentials.to_dict() if self.credentials else None,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "OpenAPITool":
"""
Deserialize this component from a dictionary.

:param data: The dictionary representation of this component.
:returns:
The deserialized component instance.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["credentials"])
if "generator_api_params" in data["init_parameters"]:
shadeMe marked this conversation as resolved.
Show resolved Hide resolved
deserialize_secrets_inplace(data["init_parameters"]["generator_api_params"], keys=["api_key"])
init_params = data.get("init_parameters", {})
generator_api = init_params.get("generator_api")
data["init_parameters"]["generator_api"] = LLMProvider(generator_api)
shadeMe marked this conversation as resolved.
Show resolved Hide resolved
return default_from_dict(cls, data)

def _init_generator(self, generator_api: LLMProvider, generator_api_params: Dict[str, Any]):
"""
Initialize the chat generator based on the specified API provider and parameters.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def test_serperdev(self, test_files_path):
assert "invention" in str(response)

@pytest.mark.integration
@pytest.mark.skip("This test hits rate limit on Github API. Skip for now.")
shadeMe marked this conversation as resolved.
Show resolved Hide resolved
def test_github(self, test_files_path):
config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "github_compare.yml"))
api = OpenAPIServiceClient(config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def test_serperdev(self, test_files_path):

@pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set")
@pytest.mark.integration
@pytest.mark.skip("This test hits rate limit on Github API. Skip for now.")
def test_github(self, test_files_path):
config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "github_compare.yml"))
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
Expand Down
181 changes: 181 additions & 0 deletions test/components/tools/openapi/test_openapi_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import json
import os

from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.utils import Secret

from haystack_experimental.components.tools.openapi import LLMProvider
from haystack_experimental.components.tools.openapi.openapi_tool import OpenAPITool

import pytest


class TestOpenAPITool:

def test_to_dict(self, monkeypatch):
shadeMe marked this conversation as resolved.
Show resolved Hide resolved
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
monkeypatch.setenv("SERPERDEV_API_KEY", "fake-api-key")

openapi_spec_url = "https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json"

tool = OpenAPITool(
generator_api=LLMProvider.OPENAI,
generator_api_params={
"model": "gpt-3.5-turbo",
"api_key": Secret.from_env_var("OPENAI_API_KEY"),
},
spec=openapi_spec_url,
credentials=Secret.from_env_var("SERPERDEV_API_KEY"),
)

data = tool.to_dict()
assert data == {
"type": "haystack_experimental.components.tools.openapi.openapi_tool.OpenAPITool",
"init_parameters": {
"generator_api": "openai",
"generator_api_params": {
"model": "gpt-3.5-turbo",
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
},
"spec": openapi_spec_url,
"credentials": {"env_vars": ["SERPERDEV_API_KEY"], "strict": True, "type": "env_var"},
},
}

def test_from_dict(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
monkeypatch.setenv("SERPERDEV_API_KEY", "fake-api-key")
openapi_spec_url = "https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json"
data = {
"type": "haystack_experimental.components.tools.openapi.openapi_tool.OpenAPITool",
"init_parameters": {
"generator_api": "openai",
"generator_api_params": {
"model": "gpt-3.5-turbo",
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
},
"spec": openapi_spec_url,
"credentials": {"env_vars": ["SERPERDEV_API_KEY"], "strict": True, "type": "env_var"},
},
}

tool = OpenAPITool.from_dict(data)

assert tool.generator_api == LLMProvider.OPENAI
assert tool.generator_api_params == {
"model": "gpt-3.5-turbo",
"api_key": Secret.from_env_var("OPENAI_API_KEY")
}
assert tool.spec == openapi_spec_url
assert tool.credentials == Secret.from_env_var("SERPERDEV_API_KEY")

def test_initialize_with_valid_openapi_spec_url_and_credentials(self):
openapi_spec_url = "https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json"
credentials = Secret.from_token("<your-tool-token>")
tool = OpenAPITool(
generator_api=LLMProvider.OPENAI,
generator_api_params={
"model": "gpt-3.5-turbo",
"api_key": Secret.from_token("not_needed"),
},
spec=openapi_spec_url,
credentials=credentials,
)

assert tool.generator_api == LLMProvider.OPENAI
assert isinstance(tool.chat_generator, OpenAIChatGenerator)
assert tool.config_openapi is not None
assert tool.open_api_service is not None

@pytest.mark.skipif(
"SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set"
)
@pytest.mark.skipif(
"OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set"
)
@pytest.mark.integration
def test_run_live_openai(self):
tool = OpenAPITool(
generator_api=LLMProvider.OPENAI,
spec="https://bit.ly/serper_dev_spec_yaml",
credentials=Secret.from_env_var("SERPERDEV_API_KEY"),
)

user_message = ChatMessage.from_user(
"Scrape URL: https://news.ycombinator.com/"
)

results = tool.run(messages=[user_message])

assert isinstance(results["service_response"], list)
assert len(results["service_response"]) == 1
assert isinstance(results["service_response"][0], ChatMessage)

try:
json_response = json.loads(results["service_response"][0].content)
assert isinstance(json_response, dict)
except json.JSONDecodeError:
pytest.fail("Response content is not valid JSON")

@pytest.mark.skipif(
"SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set"
)
@pytest.mark.skipif(
"ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY not set"
)
@pytest.mark.integration
def test_run_live_anthropic(self):
tool = OpenAPITool(
generator_api=LLMProvider.ANTHROPIC,
generator_api_params={"model": "claude-3-opus-20240229"},
spec="https://bit.ly/serper_dev_spec_yaml",
credentials=Secret.from_env_var("SERPERDEV_API_KEY"),
)

user_message = ChatMessage.from_user(
"Scrape URL: https://news.ycombinator.com/"
)

results = tool.run(messages=[user_message])

assert isinstance(results["service_response"], list)
assert len(results["service_response"]) == 1
assert isinstance(results["service_response"][0], ChatMessage)

try:
json_response = json.loads(results["service_response"][0].content)
assert isinstance(json_response, dict)
except json.JSONDecodeError:
pytest.fail("Response content is not valid JSON")

@pytest.mark.skipif(
"SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set"
)
@pytest.mark.skipif(
"COHERE_API_KEY" not in os.environ, reason="COHERE_API_KEY not set"
)
@pytest.mark.integration
def test_run_live_cohere(self):
tool = OpenAPITool(
generator_api=LLMProvider.COHERE,
generator_api_params={"model": "command-r"},
spec="https://bit.ly/serper_dev_spec_yaml",
credentials=Secret.from_env_var("SERPERDEV_API_KEY"),
)

user_message = ChatMessage.from_user(
"Scrape URL: https://news.ycombinator.com/"
)

results = tool.run(messages=[user_message])

assert isinstance(results["service_response"], list)
assert len(results["service_response"]) == 1
assert isinstance(results["service_response"][0], ChatMessage)

try:
json_response = json.loads(results["service_response"][0].content)
assert isinstance(json_response, dict)
except json.JSONDecodeError:
pytest.fail("Response content is not valid JSON")
Loading