Skip to content

Commit

Permalink
Add GPT Visual support in testing suite (#372)
Browse files Browse the repository at this point in the history
* Add GPT Visual support in testing suite

* Support for local storage backend
  • Loading branch information
Ashesh3 authored Oct 18, 2023
1 parent 62a413b commit b07e083
Show file tree
Hide file tree
Showing 13 changed files with 260 additions and 28 deletions.
7 changes: 5 additions & 2 deletions ayushma/api_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ayushma.views.auth import AuthViewSet
from ayushma.views.chat import ChatFeedbackViewSet, ChatViewSet
from ayushma.views.document import DocumentViewSet
from ayushma.views.document import ProjectDocumentViewSet, TestQuestionDocumentViewSet
from ayushma.views.orphan import OrphanChatViewSet
from ayushma.views.project import ProjectViewSet
from ayushma.views.service import TempTokenViewSet
Expand Down Expand Up @@ -35,12 +35,14 @@

router.register(r"projects", ProjectViewSet)
projects_router = NestedRouter(router, r"projects", lookup="project")
projects_router.register(r"documents", DocumentViewSet)
projects_router.register(r"documents", ProjectDocumentViewSet)
projects_router.register(r"chats", ChatViewSet)

router.register(r"tests/suites", TestSuiteViewSet)
tests_router = NestedRouter(router, r"tests/suites", lookup="test_suite")
tests_router.register(r"questions", TestQuestionViewSet)
test_question_router = NestedRouter(tests_router, r"questions", lookup="test_question")
test_question_router.register(r"documents", TestQuestionDocumentViewSet)
tests_router.register(r"runs", TestRunViewSet)
runs_router = NestedRouter(tests_router, r"runs", lookup="run")
runs_router.register(r"feedback", FeedbackViewSet)
Expand All @@ -49,5 +51,6 @@
path(r"", include(router.urls)),
path(r"", include(projects_router.urls)),
path(r"", include(tests_router.urls)),
path(r"", include(test_question_router.urls)),
path(r"", include(runs_router.urls)),
]
17 changes: 17 additions & 0 deletions ayushma/migrations/0046_testquestion_attachments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Generated by Django 4.2.5 on 2023-10-15 10:23

from django.db import migrations, models


class Migration(migrations.Migration):
    """Add an optional ``attachments`` M2M from TestQuestion to Document.

    Part of the GPT-Visual testing-suite work: lets a test question carry
    supporting documents (presumably images for vision-capable models —
    the follow-up migration reworks this into a FK link).
    """

    # Builds on the TestRun.references migration this branch was based on.
    dependencies = [
        ("ayushma", "0045_testrun_references"),
    ]

    operations = [
        migrations.AddField(
            model_name="testquestion",
            name="attachments",
            # blank=True: attachments are optional in forms/validation.
            field=models.ManyToManyField(blank=True, to="ayushma.document"),
        ),
    ]
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Generated by Django 4.2.5 on 2023-10-15 10:47

from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):
    """Rework question attachments into a Document↔TestQuestion link.

    Renames ``TestQuestion.attachments`` to ``documents``, gives Document an
    optional ``test_question`` FK, and relaxes ``Document.project`` to be
    nullable — a document may now belong to a project OR to a test question.
    """

    dependencies = [
        ("ayushma", "0046_testquestion_attachments"),
    ]

    operations = [
        migrations.RenameField(
            model_name="testquestion",
            old_name="attachments",
            new_name="documents",
        ),
        migrations.AddField(
            model_name="document",
            name="test_question",
            # PROTECT: a TestQuestion cannot be deleted while documents
            # still reference it.
            field=models.ForeignKey(
                blank=True,
                null=True,
                on_delete=django.db.models.deletion.PROTECT,
                to="ayushma.testquestion",
            ),
        ),
        migrations.AlterField(
            model_name="document",
            name="project",
            # Now nullable so test-question documents need no project.
            field=models.ForeignKey(
                null=True,
                on_delete=django.db.models.deletion.PROTECT,
                to="ayushma.project",
            ),
        ),
    ]
41 changes: 41 additions & 0 deletions ayushma/migrations/0048_alter_chat_model_alter_project_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Generated by Django 4.2.5 on 2023-10-18 05:50

from django.db import migrations, models


class Migration(migrations.Migration):
    """Register the new GPT-4 Visual model choice (value 5).

    Rewrites the ``choices`` lists on ``Chat.model`` and ``Project.model``
    to include ``(5, "Gpt 4 Visual")``, matching the ``GPT_4_VISUAL = 5``
    member added to the ModelType enum. Choices-only changes do not alter
    the database schema; this records the new state for the migration graph.
    """

    dependencies = [
        ("ayushma", "0047_rename_attachments_testquestion_documents_and_more"),
    ]

    operations = [
        migrations.AlterField(
            model_name="chat",
            name="model",
            # Chat.model stays optional (falls back to the project's model).
            field=models.IntegerField(
                blank=True,
                choices=[
                    (1, "Gpt 3 5"),
                    (2, "Gpt 3 5 16K"),
                    (3, "Gpt 4"),
                    (4, "Gpt 4 32K"),
                    (5, "Gpt 4 Visual"),
                ],
                null=True,
            ),
        ),
        migrations.AlterField(
            model_name="project",
            name="model",
            # Projects default to GPT-3.5 (choice value 1).
            field=models.IntegerField(
                choices=[
                    (1, "Gpt 3 5"),
                    (2, "Gpt 3 5 16K"),
                    (3, "Gpt 4"),
                    (4, "Gpt 4 32K"),
                    (5, "Gpt 4 Visual"),
                ],
                default=1,
            ),
        ),
    ]
5 changes: 4 additions & 1 deletion ayushma/models/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ class Document(BaseModel):
)
file = models.FileField(null=True, blank=True)
text_content = models.TextField(null=True, blank=True)
project = models.ForeignKey(Project, on_delete=models.PROTECT)
project = models.ForeignKey(Project, on_delete=models.PROTECT, null=True)
test_question = models.ForeignKey(
"TestQuestion", on_delete=models.PROTECT, null=True, blank=True
)
uploading = models.BooleanField(default=True)

def __str__(self) -> str:
Expand Down
1 change: 1 addition & 0 deletions ayushma/models/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class ModelType(IntegerChoices):
GPT_3_5_16K = 2
GPT_4 = 3
GPT_4_32K = 4
GPT_4_VISUAL = 5


class StatusChoices(IntegerChoices):
Expand Down
1 change: 1 addition & 0 deletions ayushma/models/testsuite.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class TestQuestion(BaseModel):
test_suite = ForeignKey(TestSuite, on_delete=CASCADE)
question = TextField()
human_answer = TextField()
documents = models.ManyToManyField(Document, blank=True)
language = models.CharField(max_length=10, blank=False, default="en")


Expand Down
17 changes: 16 additions & 1 deletion ayushma/serializers/testsuite.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from rest_framework import serializers

from ayushma.models import Feedback, TestQuestion, TestResult, TestRun, TestSuite
from ayushma.models import (
Document,
Feedback,
TestQuestion,
TestResult,
TestRun,
TestSuite,
)
from ayushma.serializers.document import DocumentSerializer
from ayushma.serializers.project import ProjectSerializer
from ayushma.serializers.users import UserSerializer
Expand All @@ -22,6 +29,12 @@ class Meta:


class TestQuestionSerializer(serializers.ModelSerializer):
documents = DocumentSerializer(many=True, read_only=True)

def get_documents(self, obj):
documents = Document.objects.filter(test_question__external_id=obj.external_id)
return DocumentSerializer(documents, many=True).data

class Meta:
model = TestQuestion
fields = (
Expand All @@ -31,6 +44,7 @@ class Meta:
"modified_at",
"human_answer",
"external_id",
"documents",
)
read_only_fields = ("external_id", "created_at", "modified_at")

Expand Down Expand Up @@ -59,6 +73,7 @@ class Meta:
class TestResultSerializer(serializers.ModelSerializer):
feedback = FeedbackSerializer(source="feedback_set", many=True, read_only=True)
references = DocumentSerializer(many=True, read_only=True)
test_question = TestQuestionSerializer(read_only=True)

class Meta:
model = TestResult
Expand Down
1 change: 1 addition & 0 deletions ayushma/tasks/testrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def mark_test_run_as_completed(self, test_run_id):
stream=False,
generate_audio=False,
fetch_references=test_run.references,
documents=test_question.documents.all(),
)
)

Expand Down
86 changes: 70 additions & 16 deletions ayushma/utils/langchain.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any, Literal

import openai
from django.conf import settings
from langchain import LLMChain, PromptTemplate
Expand All @@ -6,15 +8,17 @@
from langchain.llms import AzureOpenAI
from langchain.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
)
from langchain.prompts.chat import BaseStringMessagePromptTemplate
from langchain.schema import SystemMessage
from langchain.schema.messages import HumanMessage

from ayushma.models.enums import ModelType
from ayushma.utils.stream_callback import StreamingQueueCallbackHandler
from core.settings.base import AI_NAME
from utils.helpers import get_base64_document


def get_model_name(model_type: ModelType):
Expand All @@ -30,12 +34,44 @@ def get_model_name(model_type: ModelType):
return "gpt-4"
elif model_type == ModelType.GPT_4_32K:
return "gpt-4-32k"
elif model_type == ModelType.GPT_4_VISUAL:
return "gpt-4-visual"
else:
if settings.OPENAI_API_TYPE == "azure":
return settings.AZURE_CHAT_MODEL
return "gpt-3.5-turbo"


class GenericHumanMessage(HumanMessage):
    """A human message whose ``content`` may be any type.

    The base HumanMessage presumably constrains ``content`` to ``str`` —
    this subclass widens the annotation to ``Any`` so a vision-style
    payload (a list mixing the text prompt and image dicts) can be sent
    as a single user turn. TODO confirm against the installed langchain
    version's message schema.
    """

    content: Any
    """The contents of the message."""

    example: bool = False
    """Whether this Message is being passed in to the model as part of an example conversation."""

    # Fixed discriminator so the message still serializes/dispatches as a
    # regular human turn.
    type: Literal["human"] = "human"


class GenericHumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
    """Human-message prompt template that passes content through untouched.

    Unlike the stock human template, the looked-up value is handed to
    GenericHumanMessage verbatim, so it may be a non-string payload
    (e.g. a list of text and image parts).
    """

    def format(self, **kwargs: Any):
        """Format the prompt template.

        Args:
            **kwargs: Formatting variables; only the template's first
                declared input variable is consumed as the message content.

        Returns:
            A GenericHumanMessage carrying the looked-up value as-is.
        """
        content_value = kwargs[self.input_variables[0]]
        return GenericHumanMessage(
            content=content_value,
            additional_kwargs=self.additional_kwargs,
        )


class LangChainHelper:
def __init__(
self,
Expand All @@ -53,6 +89,7 @@ def __init__(
"openai_api_key": openai_api_key,
"model_name": get_model_name(model),
"request_timeout": 180,
"max_tokens": "4096",
}
if stream:
llm_args["streaming"] = True
Expand Down Expand Up @@ -107,8 +144,9 @@ def __init__(
human_prompt = PromptTemplate(
template="{user_msg}", input_variables=["user_msg"]
)
human_message_prompt = HumanMessagePromptTemplate(
human_message_prompt = GenericHumanMessagePromptTemplate(
prompt=human_prompt,
template_format=None,
)

message_prompt = MessagesPlaceholder(variable_name="chat_history")
Expand All @@ -124,16 +162,24 @@ def __init__(
self.chain = LLMChain(llm=llm, prompt=chat_prompt, verbose=True)

async def get_aresponse(
self, job_done, error, token_queue, user_msg, reference, chat_history
self, job_done, error, token_queue, user_msg, reference, chat_history, documents
):
chat_history.append(
SystemMessage(
content="Remember to only answer the question if it can be answered with the given references"
)
)
system_message = "Remember to only answer the question if it can be answered with the given references"
user_message = user_msg

if documents:
user_message = [user_msg]
system_message = f"Image Capabilities: Enabled\n${system_message}"
for document in documents:
encoded_document = get_base64_document(document)
if encoded_document:
user_message.append({"image": encoded_document, "resize": None})

chat_history.append(SystemMessage(content=system_message))

try:
async_response = await self.chain.apredict(
user_msg=user_msg,
user_msg=user_message,
reference=reference,
chat_history=chat_history,
)
Expand All @@ -143,14 +189,22 @@ async def get_aresponse(
print(e)
token_queue.put((error, e))

def get_response(self, user_msg, reference, chat_history):
chat_history.append(
SystemMessage(
content="Remember to only answer the question if it can be answered with the given references"
)
)
def get_response(self, user_msg, reference, chat_history, documents):
    """Run the chain synchronously and return the model's reply.

    Args:
        user_msg: The user's message text.
        reference: Reference material interpolated into the prompt.
        chat_history: Mutable list of prior messages; a SystemMessage
            reminder is appended to it as a side effect.
        documents: Optional iterable of documents whose contents are
            base64-encoded and attached as image parts for vision models.

    Returns:
        The chain's predicted response string.
    """
    system_message = "Remember to only answer the question if it can be answered with the given references"
    user_message = user_msg

    if documents:
        # Vision payload: content becomes [text, {"image": ...}, ...] and
        # the system reminder advertises image capabilities.
        user_message = [user_msg]
        system_message = f"Image Capabilities: Enabled\n{system_message}"
        for document in documents:
            encoded_document = get_base64_document(document)
            if encoded_document:
                user_message.append({"image": encoded_document, "resize": None})

    chat_history.append(SystemMessage(content=system_message))

    # Fix: the pasted diff carried BOTH the pre-change `user_msg=user_msg`
    # and post-change `user_msg=user_message` lines — a duplicate-keyword
    # SyntaxError. Only the post-change kwarg (which includes the image
    # parts when documents are present) is kept.
    return self.chain.predict(
        user_msg=user_message,
        reference=reference,
        chat_history=chat_history,
    )
9 changes: 8 additions & 1 deletion ayushma/utils/openaiapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def converse(
generate_audio=True,
noonce=None,
fetch_references=True,
documents=None,
):
if not openai_key:
raise Exception("OpenAI-Key header is required to create a chat or converse")
Expand Down Expand Up @@ -306,6 +307,9 @@ def converse(

prompt = chat.prompt or (chat.project and chat.project.prompt)

if documents or chat.project.model == ModelType.GPT_4_VISUAL:
prompt = "Image Capabilities: Enabled\n" + prompt

# excluding the latest query since it is not a history
previous_messages = (
ChatMessage.objects.filter(chat=chat)
Expand All @@ -329,7 +333,9 @@ def converse(
or ModelType.GPT_3_5,
temperature=temperature,
)
response = lang_chain_helper.get_response(english_text, reference, chat_history)
response = lang_chain_helper.get_response(
english_text, reference, chat_history, documents
)
chat_response = response.replace("Ayushma: ", "")
stats["response_end_time"] = time.time()
translated_chat_response, url, chat_message = handle_post_response(
Expand Down Expand Up @@ -372,6 +378,7 @@ def converse(
english_text,
reference,
chat_history,
documents,
)
chat_response = ""
skip_token = len(f"{AI_NAME}: ")
Expand Down
Loading

0 comments on commit b07e083

Please sign in to comment.