Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add allowed_operations init parameter to OpenAPITool #39

Merged
merged 10 commits into from
Sep 10, 2024
9 changes: 6 additions & 3 deletions haystack_experimental/components/tools/openapi/_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def __init__(
credentials: Optional[str] = None,
request_sender: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
llm_provider: LLMProvider = LLMProvider.OPENAI,
operations_filter: Optional[Callable[[Dict[str, Any]], bool]] = None,
): # noqa: PLR0913
"""
Initialize a ClientConfiguration instance.
Expand All @@ -143,12 +144,14 @@ def __init__(
:param credentials: The credentials to use for authentication.
:param request_sender: The function to use for sending requests.
:param llm_provider: The LLM provider to use for generating tools definitions.
:param operations_filter: A function to filter the functions to register with LLMs.
:raises ValueError: If the OpenAPI specification format is invalid.
"""
self.openapi_spec = openapi_spec
self.credentials = credentials
self.request_sender = request_sender or send_request
self.llm_provider: LLMProvider = llm_provider
self.operation_filter = operations_filter

def get_auth_function(self) -> Callable[[Dict[str, Any], Dict[str, Any]], Any]:
"""
Expand Down Expand Up @@ -180,10 +183,10 @@ def get_tools_definitions(self) -> List[Dict[str, Any]]:
{
LLMProvider.ANTHROPIC: anthropic_converter,
LLMProvider.COHERE: cohere_converter,
}
},
)
converter = provider_to_converter[self.llm_provider]
return converter(self.openapi_spec)
return converter(self.openapi_spec, self.operation_filter)

def get_payload_extractor(self) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
"""
Expand All @@ -197,7 +200,7 @@ def get_payload_extractor(self) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
{
LLMProvider.ANTHROPIC: "input",
LLMProvider.COHERE: "parameters",
}
},
)
arguments_field_name = provider_to_arguments_field_name[self.llm_provider]
return create_function_payload_extractor(arguments_field_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,65 +16,79 @@
logger = logging.getLogger(__name__)


def openai_converter(schema: OpenAPISpecification) -> List[Dict[str, Any]]:
def openai_converter(
schema: OpenAPISpecification,
operation_filter: Optional[Callable[[Dict[str, Any]], bool]] = None,
) -> List[Dict[str, Any]]:
"""
Converts an OpenAPI specification to a list of functions suitable for OpenAI LLM function calling.

See https://platform.openai.com/docs/guides/function-calling for more information about OpenAI's function schema.
:param schema: The OpenAPI specification to convert.
:param operation_filter: A function to filter operations to register with LLMs.
:returns: A list of dictionaries, each dictionary representing an OpenAI function definition.
"""
fn_definitions = _openapi_to_functions(
schema.spec_dict, "parameters", _parse_endpoint_spec_openai
schema.spec_dict, "parameters", _parse_endpoint_spec_openai, operation_filter
)
return [{"type": "function", "function": fn} for fn in fn_definitions]


def anthropic_converter(schema: OpenAPISpecification) -> List[Dict[str, Any]]:
def anthropic_converter(
schema: OpenAPISpecification,
operation_filter: Optional[Callable[[Dict[str, Any]], bool]] = None,
) -> List[Dict[str, Any]]:
"""
Converts an OpenAPI specification to a list of function definitions for Anthropic LLM function calling.

See https://docs.anthropic.com/en/docs/tool-use for more information about Anthropic's function schema.

:param schema: The OpenAPI specification to convert.
:param operation_filter: A function to filter operations to register with LLMs.
:returns: A list of dictionaries, each dictionary representing an Anthropic function definition.
"""

return _openapi_to_functions(
schema.spec_dict, "input_schema", _parse_endpoint_spec_openai
schema.spec_dict, "input_schema", _parse_endpoint_spec_openai, operation_filter
)


def cohere_converter(schema: OpenAPISpecification) -> List[Dict[str, Any]]:
def cohere_converter(
schema: OpenAPISpecification,
operation_filter: Optional[Callable[[Dict[str, Any]], bool]] = None,
) -> List[Dict[str, Any]]:
"""
Converts an OpenAPI specification to a list of function definitions for Cohere LLM function calling.

See https://docs.cohere.com/docs/tool-use for more information about Cohere's function schema.

:param schema: The OpenAPI specification to convert.
:param operation_filter: A function to filter operations to register with LLMs.
:returns: A list of dictionaries, each representing a Cohere style function definition.
"""
return _openapi_to_functions(
schema.spec_dict,"not important for cohere",_parse_endpoint_spec_cohere
schema.spec_dict,"not important for cohere",_parse_endpoint_spec_cohere, operation_filter
)


def _openapi_to_functions(
service_openapi_spec: Dict[str, Any],
parameters_name: str,
parse_endpoint_fn: Callable[[Dict[str, Any], str], Dict[str, Any]],
operation_filter: Optional[Callable[[Dict[str, Any]], bool]] = None,
) -> List[Dict[str, Any]]:
"""
Extracts functions from the OpenAPI specification, converts them into a function schema.
Extracts operations from the OpenAPI specification, converts them into a function schema.

:param service_openapi_spec: The OpenAPI specification to extract functions from.
:param service_openapi_spec: The OpenAPI specification to extract operations from.
:param parameters_name: The name of the parameters field in the function schema.
:param parse_endpoint_fn: The function to parse the endpoint specification.
:param operation_filter: A function to filter operations to register with LLMs.
:returns: A list of dictionaries, each dictionary representing a function schema.
"""

# Doesn't enforce rigid spec validation because that would require a lot of dependencies
# We check the version and require minimal fields to be present, so we can extract functions
# We check the version and require minimal fields to be present, so we can extract operations
spec_version = service_openapi_spec.get("openapi")
if not spec_version:
raise ValueError(
Expand All @@ -87,16 +101,22 @@ def _openapi_to_functions(
f"Invalid OpenAPI spec version {service_openapi_spec_version}. Must be "
f"at least {MIN_REQUIRED_OPENAPI_SPEC_VERSION}."
)
functions: List[Dict[str, Any]] = []
operations: List[Dict[str, Any]] = []
for path, path_value in service_openapi_spec["paths"].items():
for path_key, operation_spec in path_value.items():
if path_key.lower() in VALID_HTTP_METHODS:
if "operationId" not in operation_spec:
operation_spec["operationId"] = path_to_operation_id(path, path_key)
function_dict = parse_endpoint_fn(operation_spec, parameters_name)
if function_dict:
functions.append(function_dict)
return functions

# Apply the filter to the operation spec (its operationId is always set by this point) before parsing the operation
if operation_filter and not operation_filter(operation_spec):
continue

# parse (and register) this operation as it passed the filter
ops_dict = parse_endpoint_fn(operation_spec, parameters_name)
if ops_dict:
operations.append(ops_dict)
return operations


def _parse_endpoint_spec_openai(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
generator_api_params: Optional[Dict[str, Any]] = None,
spec: Optional[Union[str, Path]] = None,
credentials: Optional[Secret] = None,
allowed_operations: Optional[List[str]] = None,
):
"""
Initialize the OpenAPITool component.
Expand All @@ -72,6 +73,9 @@ def __init__(
:param spec: OpenAPI specification for the tool/service. This can be a URL, a local file path, or
an OpenAPI service specification provided as a string.
:param credentials: Credentials for the tool/service.
:param allowed_operations: A list of operations to register with LLMs via the LLM tools parameter. Use the
operationId field of each operation in the OpenAPI spec to name the operations to register. If not specified,
all operations found in the OpenAPI spec will be registered with LLMs.
"""
self.generator_api = generator_api
self.generator_api_params = generator_api_params or {} # store the generator API parameters for serialization
Expand All @@ -80,6 +84,7 @@ def __init__(
self.open_api_service: Optional[OpenAPIServiceClient] = None
self.spec = spec # store the spec for serialization
self.credentials = credentials # store the credentials for serialization
self.allowed_operations = allowed_operations
if spec:
if os.path.isfile(spec):
openapi_spec = OpenAPISpecification.from_file(spec)
Expand All @@ -91,6 +96,7 @@ def __init__(
openapi_spec=openapi_spec,
credentials=credentials.resolve_value() if credentials else None,
llm_provider=generator_api,
operations_filter=(lambda f: f["operationId"] in allowed_operations) if allowed_operations else None,
)
self.open_api_service = OpenAPIServiceClient(self.config_openapi)

Expand Down Expand Up @@ -184,6 +190,7 @@ def to_dict(self) -> Dict[str, Any]:
generator_api_params=self.generator_api_params,
spec=self.spec,
credentials=self.credentials.to_dict() if self.credentials else None,
allowed_operations=self.allowed_operations,
)

@classmethod
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ disable = [
"cyclic-import",
"import-outside-toplevel",
"deprecated-method",
"too-many-arguments", # sometimes we need to pass more than 5 arguments
"too-many-instance-attributes" # sometimes we need to have a class with more than 7 attributes
]

[tool.pytest.ini_options]
Expand Down Expand Up @@ -217,6 +219,7 @@ ignore = [
"SIM108", # if-else-block-instead-of-if-exp
"SIM115", # open-file-with-context-handler
"SIM118", # in-dict-keys
"PLR0913", # too-many-arguments
]

[tool.ruff.lint.mccabe]
Expand Down
48 changes: 48 additions & 0 deletions test/components/tools/openapi/test_openapi_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def test_to_dict(self, monkeypatch):
},
spec=openapi_spec_url,
credentials=Secret.from_env_var("SERPERDEV_API_KEY"),
allowed_operations=["someOperationId", "someOtherOperationId"],
)

data = tool.to_dict()
Expand All @@ -40,6 +41,7 @@ def test_to_dict(self, monkeypatch):
},
"spec": openapi_spec_url,
"credentials": {"env_vars": ["SERPERDEV_API_KEY"], "strict": True, "type": "env_var"},
"allowed_operations": ["someOperationId", "someOtherOperationId"],
},
}

Expand All @@ -57,6 +59,7 @@ def test_from_dict(self, monkeypatch):
},
"spec": openapi_spec_url,
"credentials": {"env_vars": ["SERPERDEV_API_KEY"], "strict": True, "type": "env_var"},
"allowed_operations": None,
},
}

Expand Down Expand Up @@ -211,3 +214,48 @@ def test_run_live_meteo_forecast(self, provider: str):
assert "hourly" in json_response
except json.JSONDecodeError:
pytest.fail("Response content is not valid JSON")

@pytest.mark.integration
def test_allowed_operations(self):
    """
    Check that only the operations listed in ``allowed_operations`` appear in the generated tool definitions.

    Each case builds an OpenAPITool from the firecrawl OpenAPI spec with a different allow-list and compares
    the names of the resulting tool definitions, in order, against the expected names. Names that do not
    match any operation in the spec are expected to contribute nothing.
    """
    firecrawl_spec = "https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json"

    # (allowed_operations passed to the tool, names expected in the tool definitions)
    cases = [
        # a single operation
        (["scrape"], ["scrape"]),
        # two operations
        (["scrape", "crawlUrls"], ["scrape", "crawlUrls"]),
        # one existing and one non-existent operation
        (["scrape", "non-existent-operation"], ["scrape"]),
        # only non-existent operations
        (["non-existent-operation", "non-existent-operation-2"], []),
    ]

    for allowed, expected_names in cases:
        tool = OpenAPITool(
            generator_api=LLMProvider.OPENAI,
            spec=firecrawl_spec,
            allowed_operations=allowed,
        )
        definitions = tool.config_openapi.get_tools_definitions()
        # Comparing the full name list checks both the count and each position's name.
        assert [definition["function"]["name"] for definition in definitions] == expected_names
Loading