-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: replaced openai client with chat completion api
- Loading branch information
1 parent
eb5db7c
commit 400303d
Showing
9 changed files
with
104 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,25 +1,45 @@ | ||
"""openai client""" | ||
"""CHAT_COMPLETION_API client""" | ||
import json | ||
import logging | ||
|
||
import openai | ||
import requests | ||
from requests.exceptions import ConnectTimeout | ||
|
||
from django.conf import settings | ||
|
||
openai.api_key = settings.OPENAI_API_KEY | ||
log = logging.getLogger(__name__) | ||
|
||
|
||
def chat_completion(prompt): | ||
""" | ||
Use chatGPT https://api.openai.com/v1/chat/completions endpoint to generate a response. | ||
Pass message list to chat endpoint, as defined by the CHAT_COMPLETION_API setting. | ||
Arguments: | ||
prompt (str): chatGPT prompt | ||
""" | ||
response = openai.ChatCompletion.create( | ||
model="gpt-3.5-turbo", | ||
messages=[ | ||
{"role": "user", "content": prompt}, | ||
] | ||
) | ||
completion_endpoint = getattr(settings, 'CHAT_COMPLETION_API', None) | ||
completion_endpoint_key = getattr(settings, 'CHAT_COMPLETION_API_KEY', None) | ||
if completion_endpoint and completion_endpoint_key: | ||
headers = {'Content-Type': 'application/json', 'x-api-key': completion_endpoint_key} | ||
connect_timeout = getattr(settings, 'CHAT_COMPLETION_API_CONNECT_TIMEOUT', 1) | ||
read_timeout = getattr(settings, 'CHAT_COMPLETION_API_READ_TIMEOUT', 15) | ||
body = {'message_list': [{'role': 'assistant', 'content': prompt},]} | ||
try: | ||
response = requests.post( | ||
completion_endpoint, | ||
headers=headers, | ||
data=json.dumps(body), | ||
timeout=(connect_timeout, read_timeout) | ||
) | ||
chat = response.json().get('content') | ||
except (ConnectTimeout, ConnectionError) as e: | ||
error_message = str(e) | ||
connection_message = 'Failed to connect to chat completion API.' | ||
log.error( | ||
'%(connection_message)s %(error)s', | ||
{'connection_message': connection_message, 'error': error_message} | ||
) | ||
chat = connection_message | ||
else: | ||
chat = 'Completion endpoint is not defined.' | ||
|
||
content = response['choices'][0]['message']['content'] | ||
return content | ||
return chat |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
""" | ||
Tests for chat completion client. | ||
""" | ||
import responses | ||
from mock import patch | ||
|
||
from django.conf import settings | ||
|
||
from taxonomy.openai.client import chat_completion | ||
from test_utils.testcase import TaxonomyTestCase | ||
|
||
|
||
class TestChatCompletionClient(TaxonomyTestCase): | ||
""" | ||
Validate chat_completion client. | ||
""" | ||
@responses.activate | ||
def test_client(self): | ||
""" | ||
Test that the chat completion client works as expected. | ||
""" | ||
chat_prompt = 'how many courses are offered by edx in the data science area' | ||
expected_chat_response = { | ||
"role": "assistant", | ||
"content": "edx offers 500 courses in the data science area" | ||
} | ||
responses.add( | ||
method=responses.POST, | ||
url=settings.CHAT_COMPLETION_API, | ||
json=expected_chat_response, | ||
) | ||
chat_response = chat_completion(chat_prompt) | ||
self.assertEqual(chat_response, expected_chat_response['content']) | ||
|
||
@patch('taxonomy.openai.client.requests.post') | ||
def test_client_exceptions(self, post_mock): | ||
""" | ||
Test that the chat completion client handles exceptions as expected. | ||
""" | ||
chat_prompt = 'how many courses are offered by edx in the data science area' | ||
post_mock.side_effect = ConnectionError() | ||
chat_response = chat_completion(chat_prompt) | ||
self.assertEqual(chat_response, 'Failed to connect to chat completion API.') | ||
|
||
def test_client_missing_settings(self): | ||
""" | ||
Test that the chat completion client handles missing settings as expected. | ||
""" | ||
chat_prompt = 'how many courses are offered by edx in the data science area' | ||
settings.CHAT_COMPLETION_API_KEY = None | ||
chat_response = chat_completion(chat_prompt) | ||
self.assertEqual(chat_response, 'Completion endpoint is not defined.') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters