wip: fixing tests
davidsbatista committed Aug 13, 2024
1 parent: 8d138fa · commit: 5221171
Showing 3 changed files with 21 additions and 5 deletions.
@@ -334,7 +334,7 @@ def __init__(self, generation_kwargs: Dict[str, Any]):
         # b) we can use apply_chat_template with the template above to delineate ChatMessages
         # Mistral models are gated on HF Hub. If no HF_TOKEN is found we use a non-gated alternative tokenizer model.
         tokenizer: PreTrainedTokenizer
-        if os.environ.get("HF_TOKEN"):
+        if os.environ.get("HF_TOKEN") or os.environ.get("HF_API_TOKEN"):
             tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
         else:
             tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
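For context, the gated-model fallback this hunk touches can be exercised on its own. A minimal sketch, assuming `transformers` is installed and using the same model names as the diff (the helper name `load_mistral_tokenizer` is hypothetical, not part of the codebase):

    import os

    from transformers import AutoTokenizer, PreTrainedTokenizer

    def load_mistral_tokenizer() -> PreTrainedTokenizer:
        # Prefer the gated Mistral tokenizer when either HF token variable is set;
        # otherwise fall back to a non-gated, format-compatible alternative.
        if os.environ.get("HF_TOKEN") or os.environ.get("HF_API_TOKEN"):
            return AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
        return AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")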
@@ -206,10 +206,18 @@ def run(
         # rename the meta keys to be in line with OpenAI meta output keys
         for response in replies:
             if response.meta is not None:
-                if "prompt_token_count" in response.meta:
-                    response.meta["prompt_tokens"] = response.meta.pop("prompt_token_count")
-                if "generation_token_count" in response.meta:
-                    response.meta["completion_tokens"] = response.meta.pop("generation_token_count")
+                if "usage" not in response.meta:
+                    if "prompt_token_count" in response.meta:
+                        response.meta["prompt_tokens"] = response.meta.pop("prompt_token_count")
+                    if "generation_token_count" in response.meta:
+                        response.meta["completion_tokens"] = response.meta.pop("generation_token_count")
+                elif "usage" in response.meta:
+                    if "input_tokens" in response.meta["usage"]:
+                        response.meta["usage"]["prompt_tokens"] = response.meta["usage"].pop("input_tokens")
+                    if "output_tokens" in response.meta["usage"]:
+                        response.meta["usage"]["completion_tokens"] = response.meta["usage"].pop("output_tokens")
+                else:
+                    print("DEBUG", response.meta)

         return {"replies": replies}

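To make the new branching concrete, here is a self-contained sketch that applies the same renaming to fabricated meta payloads (the dict contents are hypothetical, for illustration only; `normalize_meta` is not part of the codebase):

    from typing import Any, Dict

    def normalize_meta(meta: Dict[str, Any]) -> Dict[str, Any]:
        # Mirrors the diff: top-level Bedrock token counts are renamed directly,
        # while counts nested under "usage" are renamed in place.
        if "usage" not in meta:
            if "prompt_token_count" in meta:
                meta["prompt_tokens"] = meta.pop("prompt_token_count")
            if "generation_token_count" in meta:
                meta["completion_tokens"] = meta.pop("generation_token_count")
        else:
            if "input_tokens" in meta["usage"]:
                meta["usage"]["prompt_tokens"] = meta["usage"].pop("input_tokens")
            if "output_tokens" in meta["usage"]:
                meta["usage"]["completion_tokens"] = meta["usage"].pop("output_tokens")
        return meta

    print(normalize_meta({"prompt_token_count": 12, "generation_token_count": 34}))
    # {'prompt_tokens': 12, 'completion_tokens': 34}
    print(normalize_meta({"usage": {"input_tokens": 12, "output_tokens": 34}}))
    # {'usage': {'prompt_tokens': 12, 'completion_tokens': 34}}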
integrations/amazon_bedrock/tests/test_chat_generator.py (8 additions, 0 deletions)
@@ -358,6 +358,14 @@ def test_default_inference_params(self, model_name, chat_messages):
assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'"
assert first_reply.meta, "First reply has no metadata"

if first_reply.meta and 'usage' in first_reply.meta:
assert 'prompt_tokens' in first_reply.meta['usage']
assert 'completion_tokens' in first_reply.meta['usage']

if first_reply.meta and 'usage' not in first_reply.meta:
assert 'prompt_tokens' in first_reply.meta
assert 'completion_tokens' in first_reply.meta

@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
@pytest.mark.integration
def test_default_inference_with_streaming(self, model_name, chat_messages):
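The new assertions cover both meta shapes handled by the generator change above: a nested "usage" dict and the flat top-level keys. They run only under the integration marker used throughout this file; assuming the repository's standard pytest setup, they can be selected with `pytest -m integration integrations/amazon_bedrock/tests/test_chat_generator.py`.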
