
Commit c7a9fa6

Fix huggingface inference endpoint name (NVIDIA#1011)

jmartin-tech committed Nov 19, 2024
2 parents 0bfff87 + ca2e050
Showing 4 changed files with 75 additions and 6 deletions.
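The fix itself is two one-line changes: both constructors previously built the endpoint URI from their local `name` argument, which is empty when the model name is supplied through config rather than as a constructor argument, so the resulting URI lacked the model path. Using `self.name`, as resolved by the base constructor, fixes that. A minimal sketch of the failure mode (hypothetical `Base` and `InferenceLike` classes, not garak's real hierarchy; only the name-resolution pattern follows the diffs below):

    class Base:
        def __init__(self, name, config_root=None):
            # hypothetical stand-in: the base constructor may resolve an
            # empty name from configuration, so self.name and the local
            # `name` argument can diverge
            self.name = name or (config_root or {}).get("name", "")

    class InferenceLike(Base):
        URI = "https://api-inference.huggingface.co/models/"

        def __init__(self, name="", config_root=None):
            super().__init__(name, config_root=config_root)
            self.uri = self.URI + self.name  # fixed: uses the resolved name
            # buggy form: self.URI + name, a bare base URI when name == ""

    g = InferenceLike(config_root={"name": "gpt2"})
    assert g.uri == "https://api-inference.huggingface.co/models/gpt2"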
4 changes: 2 additions & 2 deletions garak/generators/huggingface.py
@@ -256,7 +256,7 @@ def __init__(self, name="", config_root=_config):
         self.name = name
         super().__init__(self.name, config_root=config_root)

-        self.uri = self.URI + name
+        self.uri = self.URI + self.name

         # special case for api token requirement this also reserves `headers` as not configurable
         if self.api_key:
@@ -376,7 +376,7 @@ class InferenceEndpoint(InferenceAPI):

     def __init__(self, name="", config_root=_config):
         super().__init__(name, config_root=config_root)
-        self.uri = name
+        self.uri = self.name

     @backoff.on_exception(
         backoff.fibo,
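The `@backoff.on_exception(backoff.fibo, ...)` context above is garak's retry wrapper on the endpoint call. A hedged sketch of the pattern as the `backoff` library defines it; the exception class and `max_value` below are illustrative, not copied from garak:

    import backoff
    import requests

    # Retry the call whenever it raises, sleeping along a Fibonacci
    # sequence (1, 1, 2, 3, 5, ... seconds) capped at max_value.
    @backoff.on_exception(
        backoff.fibo, requests.exceptions.RequestException, max_value=70
    )
    def call_endpoint(uri: str) -> requests.Response:
        resp = requests.post(uri, timeout=30)
        resp.raise_for_status()  # HTTP errors raise, triggering a retry
        return resp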
7 changes: 7 additions & 0 deletions tests/generators/conftest.py
@@ -11,3 +11,10 @@ def openai_compat_mocks():
     """Mock responses for OpenAI compatible endpoints"""
     with open(pathlib.Path(__file__).parents[0] / "openai.json") as mock_openai:
         return json.load(mock_openai)
+
+
+@pytest.fixture
+def hf_endpoint_mocks():
+    """Mock responses for Huggingface InferenceAPI based endpoints"""
+    with open(pathlib.Path(__file__).parents[0] / "hf_inference.json") as mock_openai:
+        return json.load(mock_openai)
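Since the fixture lives in `tests/generators/conftest.py`, pytest discovers it automatically, and any test module in the directory can request it by parameter name, as the rewritten tests below do. A tiny illustrative consumer, not part of the commit:

    def test_mock_shape(hf_endpoint_mocks):
        # pytest injects the fixture from conftest.py by argument name
        assert hf_endpoint_mocks["hf_inference"]["code"] == 200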
10 changes: 10 additions & 0 deletions tests/generators/hf_inference.json
@@ -0,0 +1,10 @@
+{
+    "hf_inference": {
+        "code": 200,
+        "json": [
+            {
+                "generated_text": "restricted by their policy,"
+            }
+        ]
+    }
+}
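This file backs the `hf_mock_response` fixture added below: the `json` entry is serialized into the private `_content` attribute of a `requests.Response`, which is all `Response.json()` needs to return the parsed payload. A minimal sketch of that mechanism, using the same payload:

    import json
    import requests

    mock = requests.Response()
    mock.status_code = 200
    # Response.json() reads the raw bytes stored in _content
    mock._content = json.dumps(
        [{"generated_text": "restricted by their policy,"}]
    ).encode("UTF-8")
    assert mock.json()[0]["generated_text"] == "restricted by their policy,"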
60 changes: 56 additions & 4 deletions tests/generators/test_huggingface.py
@@ -1,4 +1,5 @@
 import pytest
+import requests
 import transformers
 import garak.generators.huggingface
 from garak._config import GarakSubConfig
@@ -8,6 +9,7 @@
 def hf_generator_config():
     gen_config = {
         "huggingface": {
+            "api_key": "fake",
             "hf_args": {
                 "device": "cpu",
                 "torch_dtype": "float32",
@@ -19,6 +21,17 @@ def hf_generator_config():
     return config_root


+@pytest.fixture
+def hf_mock_response(hf_endpoint_mocks):
+    import json
+
+    mock_resp_data = hf_endpoint_mocks["hf_inference"]
+    mock_resp = requests.Response()
+    mock_resp.status_code = mock_resp_data["code"]
+    mock_resp._content = json.dumps(mock_resp_data["json"]).encode("UTF-8")
+    return mock_resp
+
+
 def test_pipeline(hf_generator_config):
     generations = 10
     g = garak.generators.huggingface.Pipeline("gpt2", config_root=hf_generator_config)
@@ -37,16 +50,55 @@ def test_pipeline(hf_generator_config):
         assert isinstance(item, str)


-def test_inference():
-    return  # slow w/o key
-    g = garak.generators.huggingface.InferenceAPI("gpt2")
-    assert g.name == "gpt2"
+def test_inference(mocker, hf_mock_response, hf_generator_config):
+    model_name = "gpt2"
+    mock_request = mocker.patch.object(
+        requests, "request", return_value=hf_mock_response
+    )
+
+    g = garak.generators.huggingface.InferenceAPI(
+        model_name, config_root=hf_generator_config
+    )
+    assert g.name == model_name
+    assert model_name in g.uri
+
+    hf_generator_config.generators["huggingface"]["name"] = model_name
+    g = garak.generators.huggingface.InferenceAPI(config_root=hf_generator_config)
+    assert g.name == model_name
+    assert model_name in g.uri
+    assert isinstance(g.max_tokens, int)
+    g.max_tokens = 99
+    assert g.max_tokens == 99
+    g.temperature = 0.1
+    assert g.temperature == 0.1
+    output = g.generate("")
+    mock_request.assert_called_once()
+    assert len(output) == 1  # 1 generation by default
+    for item in output:
+        assert isinstance(item, str)
+
+
+def test_endpoint(mocker, hf_mock_response, hf_generator_config):
+    model_name = "https://localhost:8000/gpt2"
+    mock_request = mocker.patch.object(requests, "post", return_value=hf_mock_response)
+
+    g = garak.generators.huggingface.InferenceEndpoint(
+        model_name, config_root=hf_generator_config
+    )
+    assert g.name == model_name
+    assert g.uri == model_name
+
+    hf_generator_config.generators["huggingface"]["name"] = model_name
+    g = garak.generators.huggingface.InferenceEndpoint(config_root=hf_generator_config)
+    assert g.name == model_name
+    assert g.uri == model_name
     assert isinstance(g.max_tokens, int)
     g.max_tokens = 99
     assert g.max_tokens == 99
     g.temperature = 0.1
     assert g.temperature == 0.1
     output = g.generate("")
+    mock_request.assert_called_once()
     assert len(output) == 1  # 1 generation by default
     for item in output:
         assert isinstance(item, str)
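Two details worth noting in the rewritten tests. First, they encode the asymmetry the fix preserves: `InferenceAPI` appends the model name to a base URI (hence `model_name in g.uri`), while `InferenceEndpoint` treats the name as the full endpoint URL (hence `g.uri == model_name`), and the two classes are mocked at different call sites (`requests.request` versus `requests.post`). Second, with the responses mocked and a fake `api_key` in the fixture config, both tests now run offline; previously `test_inference` bailed out immediately with `return  # slow w/o key`. The `mocker` fixture is presumably supplied by the pytest-mock plugin, which these tests assume is installed.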
