Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate ChatGPT function to openai v1.0 #1368

Merged
merged 4 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions evadb/functions/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,19 +115,21 @@ def setup(
)
def forward(self, text_df):
try_to_import_openai()
import openai
from openai import OpenAI

@retry(tries=6, delay=20)
def completion_with_backoff(**kwargs):
return openai.ChatCompletion.create(**kwargs)

openai.api_key = self.openai_api_key
if len(openai.api_key) == 0:
openai.api_key = os.environ.get("OPENAI_API_KEY", "")
api_key = self.openai_api_key
if len(self.openai_api_key) == 0:
api_key = os.environ.get("OPENAI_API_KEY", "")
assert (
len(openai.api_key) != 0
len(api_key) != 0
), "Please set your OpenAI API key using SET OPENAI_API_KEY = 'sk-' or environment variable (OPENAI_API_KEY)"

client = OpenAI(api_key=api_key)

@retry(tries=6, delay=20)
def completion_with_backoff(**kwargs):
return client.chat.completions.create(**kwargs)

queries = text_df[text_df.columns[0]]
content = text_df[text_df.columns[0]]
if len(text_df.columns) > 1:
Expand Down
19 changes: 10 additions & 9 deletions evadb/functions/dalle.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,24 +56,25 @@ def setup(self, openai_api_key="") -> None:
)
def forward(self, text_df):
try_to_import_openai()
import openai
from openai import OpenAI

openai.api_key = self.openai_api_key
# If not found, try OS Environment Variable
if len(openai.api_key) == 0:
openai.api_key = os.environ.get("OPENAI_API_KEY", "")
api_key = self.openai_api_key
if len(self.openai_api_key) == 0:
api_key = os.environ.get("OPENAI_API_KEY", "")
assert (
len(openai.api_key) != 0
), "Please set your OpenAI API key using SET OPENAI_API_KEY = 'sk-' or environment variable (OPENAI_API_KEY)"
len(api_key) != 0
), "Please set your OpenAI API key using SET OPENAI_API_KEY = 'sk-' or environment variable (OPENAI_API_KEY)"

client = OpenAI(api_key=api_key)

def generate_image(text_df: PandasDataframe):
results = []
queries = text_df[text_df.columns[0]]
for query in queries:
response = openai.Image.create(prompt=query, n=1, size="1024x1024")
response = client.images.generate(prompt=query, n=1, size="1024x1024")

# Download the image from the link
image_response = requests.get(response["data"][0]["url"])
image_response = requests.get(response.data[0].url)
image = Image.open(BytesIO(image_response.content))

# Convert the image to an array format suitable for the DataFrame
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def read(path, encoding="utf-8"):
"sentence-transformers",
"protobuf",
"bs4",
"openai==0.28", # CHATGPT
"openai>=1.0", # CHATGPT
"gpt4all", # PRIVATE GPT
"sentencepiece", # TRANSFORMERS
]
Expand Down
13 changes: 7 additions & 6 deletions test/integration_tests/long/functions/test_chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest
from test.markers import chatgpt_skip_marker
from test.util import get_evadb_for_testing
Expand All @@ -22,9 +23,8 @@
from evadb.server.command_handler import execute_query_fetch_all


def create_dummy_csv_file(config) -> str:
tmp_dir_from_config = config.get_value("storage", "tmp_dir")

def create_dummy_csv_file(catalog) -> str:
tmp_dir_from_config = catalog.get_configuration_catalog_value("tmp_dir")
df_dict = [
{
"prompt": "summarize",
Expand All @@ -49,17 +49,18 @@ def setUp(self) -> None:
);"""
execute_query_fetch_all(self.evadb, create_table_query)

self.csv_file_path = create_dummy_csv_file(self.evadb.config)
self.csv_file_path = create_dummy_csv_file(self.evadb.catalog())

csv_query = f"""LOAD CSV '{self.csv_file_path}' INTO MyTextCSV;"""
execute_query_fetch_all(self.evadb, csv_query)
os.environ["OPENAI_API_KEY"] = "sk-..."

def tearDown(self) -> None:
    """Drop the ``MyTextCSV`` table after each test."""
    cleanup_query = "DROP TABLE IF EXISTS MyTextCSV;"
    execute_query_fetch_all(self.evadb, cleanup_query)

@chatgpt_skip_marker
def test_openai_chat_completion_function(self):
function_name = "OpenAIChatCompletion"
function_name = "ChatGPT"
execute_query_fetch_all(self.evadb, f"DROP FUNCTION IF EXISTS {function_name};")

create_function_query = f"""CREATE FUNCTION IF NOT EXISTS{function_name}
Expand All @@ -69,4 +70,4 @@ def test_openai_chat_completion_function(self):

gpt_query = f"SELECT {function_name}('summarize', content) FROM MyTextCSV;"
output_batch = execute_query_fetch_all(self.evadb, gpt_query)
self.assertEqual(output_batch.columns, ["openaichatcompletion.response"])
self.assertEqual(output_batch.columns, ["chatgpt.response"])
35 changes: 30 additions & 5 deletions test/unit_tests/test_dalle.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,26 @@
import unittest
from io import BytesIO
from test.util import get_evadb_for_testing
from typing import List, Optional
from unittest.mock import MagicMock, patch

from PIL import Image
from PIL import Image as PILImage
from pydantic import AnyUrl, BaseModel

from evadb.server.command_handler import execute_query_fetch_all


class Image(BaseModel):
    """Stand-in for one entry in the mocked ``images.generate`` response.

    Only ``url`` must be supplied; the remaining fields default to ``None``.
    """

    # Explicit ``None`` defaults keep these fields optional under pydantic v2,
    # where a bare ``Optional[...]`` annotation still declares a required field.
    b64_json: Optional[str] = None
    revised_prompt: Optional[str] = None
    url: AnyUrl


class ImagesResponse(BaseModel):
    """Stand-in for the object returned by the mocked ``images.generate`` call.

    Exposes a ``data`` list of ``Image`` entries; ``created`` is optional.
    """

    # Defaulted to ``None`` so the model can be built from ``data`` alone.
    # Without a default, pydantic v2 treats ``Optional[int]`` as required and
    # rejects ``ImagesResponse(data=[...])``.
    created: Optional[int] = None
    data: List[Image]


class DallEFunctionTest(unittest.TestCase):
def setUp(self) -> None:
self.evadb = get_evadb_for_testing()
Expand All @@ -43,10 +56,10 @@ def tearDown(self) -> None:

@patch.dict("os.environ", {"OPENAI_API_KEY": "mocked_openai_key"})
@patch("requests.get")
@patch("openai.Image.create", return_value={"data": [{"url": "mocked_url"}]})
def test_dalle_image_generation(self, mock_openai_create, mock_requests_get):
@patch("openai.OpenAI")
def test_dalle_image_generation(self, mock_openai, mock_requests_get):
# Generate a 1x1 white pixel PNG image in memory
img = Image.new("RGB", (1, 1), color="white")
img = PILImage.new("RGB", (1, 1), color="white")
img_byte_array = BytesIO()
img.save(img_byte_array, format="PNG")
mock_image_content = img_byte_array.getvalue()
Expand All @@ -55,6 +68,18 @@ def test_dalle_image_generation(self, mock_openai_create, mock_requests_get):
mock_response.content = mock_image_content
mock_requests_get.return_value = mock_response

# Set up the mock for OpenAI instance
mock_openai_instance = mock_openai.return_value
mock_openai_instance.images.generate.return_value = ImagesResponse(
data=[
Image(
b64_json=None,
revised_prompt=None,
url="https://images.openai.com/1234.png",
)
]
)

function_name = "DallE"

execute_query_fetch_all(self.evadb, f"DROP FUNCTION IF EXISTS {function_name};")
Expand All @@ -67,6 +92,6 @@ def test_dalle_image_generation(self, mock_openai_create, mock_requests_get):
gpt_query = f"SELECT {function_name}(prompt) FROM ImageGen;"
execute_query_fetch_all(self.evadb, gpt_query)

mock_openai_create.assert_called_once_with(
mock_openai_instance.images.generate.assert_called_once_with(
prompt="a surreal painting of a cat", n=1, size="1024x1024"
)