From b07e0833ca4dbf51090da7d4d17681d4e01d6f64 Mon Sep 17 00:00:00 2001
From: Ashesh <3626859+Ashesh3@users.noreply.github.com>
Date: Thu, 19 Oct 2023 01:12:24 +0530
Subject: [PATCH] Add GPT Visual support in testing suite (#372)

* Add GPT Visual support in testing suite

* Support for local storage backend
---
 ayushma/api_router.py                         |  7 +-
 .../0046_testquestion_attachments.py          | 17 ++++
 ...chments_testquestion_documents_and_more.py | 37 ++++++++
 ...48_alter_chat_model_alter_project_model.py | 41 +++++++++
 ayushma/models/document.py                    |  5 +-
 ayushma/models/enums.py                       |  1 +
 ayushma/models/testsuite.py                   |  1 +
 ayushma/serializers/testsuite.py              | 17 +++-
 ayushma/tasks/testrun.py                      |  1 +
 ayushma/utils/langchain.py                    | 86 +++++++++++++++----
 ayushma/utils/openaiapi.py                    |  9 +-
 ayushma/views/document.py                     | 45 ++++++++--
 utils/helpers.py                              | 21 +++++
 13 files changed, 260 insertions(+), 28 deletions(-)
 create mode 100644 ayushma/migrations/0046_testquestion_attachments.py
 create mode 100644 ayushma/migrations/0047_rename_attachments_testquestion_documents_and_more.py
 create mode 100644 ayushma/migrations/0048_alter_chat_model_alter_project_model.py

diff --git a/ayushma/api_router.py b/ayushma/api_router.py
index a0139121..7f401ce8 100644
--- a/ayushma/api_router.py
+++ b/ayushma/api_router.py
@@ -4,7 +4,7 @@

 from ayushma.views.auth import AuthViewSet
 from ayushma.views.chat import ChatFeedbackViewSet, ChatViewSet
-from ayushma.views.document import DocumentViewSet
+from ayushma.views.document import ProjectDocumentViewSet, TestQuestionDocumentViewSet
 from ayushma.views.orphan import OrphanChatViewSet
 from ayushma.views.project import ProjectViewSet
 from ayushma.views.service import TempTokenViewSet
@@ -35,12 +35,14 @@
 router.register(r"projects", ProjectViewSet)

 projects_router = NestedRouter(router, r"projects", lookup="project")
-projects_router.register(r"documents", DocumentViewSet)
+projects_router.register(r"documents", ProjectDocumentViewSet)
 projects_router.register(r"chats", ChatViewSet)

 router.register(r"tests/suites", TestSuiteViewSet)
 tests_router = NestedRouter(router, r"tests/suites", lookup="test_suite")
 tests_router.register(r"questions", TestQuestionViewSet)
+test_question_router = NestedRouter(tests_router, r"questions", lookup="test_question")
+test_question_router.register(r"documents", TestQuestionDocumentViewSet)
 tests_router.register(r"runs", TestRunViewSet)
 runs_router = NestedRouter(tests_router, r"runs", lookup="run")
 runs_router.register(r"feedback", FeedbackViewSet)
@@ -49,5 +51,6 @@
     path(r"", include(router.urls)),
     path(r"", include(projects_router.urls)),
     path(r"", include(tests_router.urls)),
+    path(r"", include(test_question_router.urls)),
     path(r"", include(runs_router.urls)),
 ]
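
The router changes above nest a documents collection under each test question, next to the existing per-project documents route. A minimal sketch of exercising the new endpoint with requests follows; the host, the URL prefix, the auth header, and any extra DocumentSerializer fields are assumptions and should be adjusted to the actual deployment.

    import requests

    BASE = "http://localhost:8000/api"  # assumed mount point for the routers above
    SUITE_ID = "<test-suite-external-id>"
    QUESTION_ID = "<test-question-external-id>"

    with open("chest_xray.png", "rb") as fp:
        resp = requests.post(
            f"{BASE}/tests/suites/{SUITE_ID}/questions/{QUESTION_ID}/documents/",
            headers={"Authorization": "Token <admin-token>"},  # the viewset is admin-only
            files={"file": fp},  # DocumentSerializer may require more fields (e.g. a title)
        )
    resp.raise_for_status()
    print(resp.json())
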
diff --git a/ayushma/migrations/0046_testquestion_attachments.py b/ayushma/migrations/0046_testquestion_attachments.py
new file mode 100644
index 00000000..378ded22
--- /dev/null
+++ b/ayushma/migrations/0046_testquestion_attachments.py
@@ -0,0 +1,17 @@
+# Generated by Django 4.2.5 on 2023-10-15 10:23
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("ayushma", "0045_testrun_references"),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name="testquestion",
+            name="attachments",
+            field=models.ManyToManyField(blank=True, to="ayushma.document"),
+        ),
+    ]
diff --git a/ayushma/migrations/0047_rename_attachments_testquestion_documents_and_more.py b/ayushma/migrations/0047_rename_attachments_testquestion_documents_and_more.py
new file mode 100644
index 00000000..9fadd9aa
--- /dev/null
+++ b/ayushma/migrations/0047_rename_attachments_testquestion_documents_and_more.py
@@ -0,0 +1,37 @@
+# Generated by Django 4.2.5 on 2023-10-15 10:47
+
+from django.db import migrations, models
+import django.db.models.deletion
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("ayushma", "0046_testquestion_attachments"),
+    ]
+
+    operations = [
+        migrations.RenameField(
+            model_name="testquestion",
+            old_name="attachments",
+            new_name="documents",
+        ),
+        migrations.AddField(
+            model_name="document",
+            name="test_question",
+            field=models.ForeignKey(
+                blank=True,
+                null=True,
+                on_delete=django.db.models.deletion.PROTECT,
+                to="ayushma.testquestion",
+            ),
+        ),
+        migrations.AlterField(
+            model_name="document",
+            name="project",
+            field=models.ForeignKey(
+                null=True,
+                on_delete=django.db.models.deletion.PROTECT,
+                to="ayushma.project",
+            ),
+        ),
+    ]
diff --git a/ayushma/migrations/0048_alter_chat_model_alter_project_model.py b/ayushma/migrations/0048_alter_chat_model_alter_project_model.py
new file mode 100644
index 00000000..6a87bcdd
--- /dev/null
+++ b/ayushma/migrations/0048_alter_chat_model_alter_project_model.py
@@ -0,0 +1,41 @@
+# Generated by Django 4.2.5 on 2023-10-18 05:50
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("ayushma", "0047_rename_attachments_testquestion_documents_and_more"),
+    ]
+
+    operations = [
+        migrations.AlterField(
+            model_name="chat",
+            name="model",
+            field=models.IntegerField(
+                blank=True,
+                choices=[
+                    (1, "Gpt 3 5"),
+                    (2, "Gpt 3 5 16K"),
+                    (3, "Gpt 4"),
+                    (4, "Gpt 4 32K"),
+                    (5, "Gpt 4 Visual"),
+                ],
+                null=True,
+            ),
+        ),
+        migrations.AlterField(
+            model_name="project",
+            name="model",
+            field=models.IntegerField(
+                choices=[
+                    (1, "Gpt 3 5"),
+                    (2, "Gpt 3 5 16K"),
+                    (3, "Gpt 4"),
+                    (4, "Gpt 4 32K"),
+                    (5, "Gpt 4 Visual"),
+                ],
+                default=1,
+            ),
+        ),
+    ]
diff --git a/ayushma/models/document.py b/ayushma/models/document.py
index 54bc4228..55c2d1be 100644
--- a/ayushma/models/document.py
+++ b/ayushma/models/document.py
@@ -14,7 +14,10 @@ class Document(BaseModel):
     )
     file = models.FileField(null=True, blank=True)
     text_content = models.TextField(null=True, blank=True)
-    project = models.ForeignKey(Project, on_delete=models.PROTECT)
+    project = models.ForeignKey(Project, on_delete=models.PROTECT, null=True)
+    test_question = models.ForeignKey(
+        "TestQuestion", on_delete=models.PROTECT, null=True, blank=True
+    )
     uploading = models.BooleanField(default=True)

     def __str__(self) -> str:
diff --git a/ayushma/models/enums.py b/ayushma/models/enums.py
index 24a73400..6b2301b3 100644
--- a/ayushma/models/enums.py
+++ b/ayushma/models/enums.py
@@ -33,6 +33,7 @@ class ModelType(IntegerChoices):
     GPT_3_5_16K = 2
     GPT_4 = 3
     GPT_4_32K = 4
+    GPT_4_VISUAL = 5


 class StatusChoices(IntegerChoices):
diff --git a/ayushma/models/testsuite.py b/ayushma/models/testsuite.py
index 8d63c7cc..d93dee4e 100644
--- a/ayushma/models/testsuite.py
+++ b/ayushma/models/testsuite.py
@@ -25,6 +25,7 @@ class TestQuestion(BaseModel):
     test_suite = ForeignKey(TestSuite, on_delete=CASCADE)
     question = TextField()
     human_answer = TextField()
+    documents = models.ManyToManyField(Document, blank=True)
     language = models.CharField(max_length=10, blank=False, default="en")

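
Taken together, migrations 0046-0048 and the model changes above make a Document attachable to either a Project or a TestQuestion: both foreign keys are now nullable, and TestQuestion gains a documents many-to-many. A rough Django-shell sketch of the new relationships, assuming an existing test question and ignoring any other required Document fields:

    from ayushma.models import Document, TestQuestion

    question = TestQuestion.objects.first()  # any existing test question

    # Attach a placeholder document directly to the question (file upload omitted here).
    doc = Document.objects.create(test_question=question, uploading=False)
    question.documents.add(doc)  # the M2M added in migrations 0046/0047

    # A document now belongs to either a project or a test question, so both FKs are nullable.
    assert doc.project is None and doc.test_question == question
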
diff --git a/ayushma/serializers/testsuite.py b/ayushma/serializers/testsuite.py
index 0110c8a8..48921f7a 100644
--- a/ayushma/serializers/testsuite.py
+++ b/ayushma/serializers/testsuite.py
@@ -1,6 +1,13 @@
 from rest_framework import serializers

-from ayushma.models import Feedback, TestQuestion, TestResult, TestRun, TestSuite
+from ayushma.models import (
+    Document,
+    Feedback,
+    TestQuestion,
+    TestResult,
+    TestRun,
+    TestSuite,
+)
 from ayushma.serializers.document import DocumentSerializer
 from ayushma.serializers.project import ProjectSerializer
 from ayushma.serializers.users import UserSerializer
@@ -22,6 +29,12 @@ class Meta:


 class TestQuestionSerializer(serializers.ModelSerializer):
+    documents = DocumentSerializer(many=True, read_only=True)
+
+    def get_documents(self, obj):
+        documents = Document.objects.filter(test_question__external_id=obj.external_id)
+        return DocumentSerializer(documents, many=True).data
+
     class Meta:
         model = TestQuestion
         fields = (
@@ -31,6 +44,7 @@ class Meta:
             "modified_at",
             "human_answer",
             "external_id",
+            "documents",
         )
         read_only_fields = ("external_id", "created_at", "modified_at")

@@ -59,6 +73,7 @@ class Meta:
 class TestResultSerializer(serializers.ModelSerializer):
     feedback = FeedbackSerializer(source="feedback_set", many=True, read_only=True)
     references = DocumentSerializer(many=True, read_only=True)
+    test_question = TestQuestionSerializer(read_only=True)

     class Meta:
         model = TestResult
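
With the serializer changes above, TestQuestion payloads inline the attached documents (the declared documents field reads straight off the new many-to-many), and each TestResult now carries its full question, documents included. An illustrative sketch only; the exact document keys depend on DocumentSerializer, which this patch does not change:

    from ayushma.models import TestQuestion
    from ayushma.serializers.testsuite import TestQuestionSerializer

    question = TestQuestion.objects.prefetch_related("documents").first()
    data = TestQuestionSerializer(question).data

    # data["documents"] is a list of DocumentSerializer payloads, one per attachment.
    for document in data["documents"]:
        print(document)
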
+ """ + text = kwargs[self.input_variables[0]] + return GenericHumanMessage( + content=text, additional_kwargs=self.additional_kwargs + ) + + class LangChainHelper: def __init__( self, @@ -53,6 +89,7 @@ def __init__( "openai_api_key": openai_api_key, "model_name": get_model_name(model), "request_timeout": 180, + "max_tokens": "4096", } if stream: llm_args["streaming"] = True @@ -107,8 +144,9 @@ def __init__( human_prompt = PromptTemplate( template="{user_msg}", input_variables=["user_msg"] ) - human_message_prompt = HumanMessagePromptTemplate( + human_message_prompt = GenericHumanMessagePromptTemplate( prompt=human_prompt, + template_format=None, ) message_prompt = MessagesPlaceholder(variable_name="chat_history") @@ -124,16 +162,24 @@ def __init__( self.chain = LLMChain(llm=llm, prompt=chat_prompt, verbose=True) async def get_aresponse( - self, job_done, error, token_queue, user_msg, reference, chat_history + self, job_done, error, token_queue, user_msg, reference, chat_history, documents ): - chat_history.append( - SystemMessage( - content="Remember to only answer the question if it can be answered with the given references" - ) - ) + system_message = "Remember to only answer the question if it can be answered with the given references" + user_message = user_msg + + if documents: + user_message = [user_msg] + system_message = f"Image Capabilities: Enabled\n${system_message}" + for document in documents: + encoded_document = get_base64_document(document) + if encoded_document: + user_message.append({"image": encoded_document, "resize": None}) + + chat_history.append(SystemMessage(content=system_message)) + try: async_response = await self.chain.apredict( - user_msg=user_msg, + user_msg=user_message, reference=reference, chat_history=chat_history, ) @@ -143,14 +189,22 @@ async def get_aresponse( print(e) token_queue.put((error, e)) - def get_response(self, user_msg, reference, chat_history): - chat_history.append( - SystemMessage( - content="Remember to only answer the question if it can be answered with the given references" - ) - ) + def get_response(self, user_msg, reference, chat_history, documents): + system_message = "Remember to only answer the question if it can be answered with the given references" + user_message = user_msg + + if documents: + user_message = [user_msg] + system_message = f"Image Capabilities: Enabled\n{system_message}" + for document in documents: + encoded_document = get_base64_document(document) + if encoded_document: + user_message.append({"image": encoded_document, "resize": None}) + + chat_history.append(SystemMessage(content=system_message)) + return self.chain.predict( - user_msg=user_msg, + user_msg=user_message, reference=reference, chat_history=chat_history, ) diff --git a/ayushma/utils/openaiapi.py b/ayushma/utils/openaiapi.py index f8807dda..573145cf 100644 --- a/ayushma/utils/openaiapi.py +++ b/ayushma/utils/openaiapi.py @@ -266,6 +266,7 @@ def converse( generate_audio=True, noonce=None, fetch_references=True, + documents=None, ): if not openai_key: raise Exception("OpenAI-Key header is required to create a chat or converse") @@ -306,6 +307,9 @@ def converse( prompt = chat.prompt or (chat.project and chat.project.prompt) + if documents or chat.project.model == ModelType.GPT_4_VISUAL: + prompt = "Image Capabilities: Enabled\n" + prompt + # excluding the latest query since it is not a history previous_messages = ( ChatMessage.objects.filter(chat=chat) @@ -329,7 +333,9 @@ def converse( or ModelType.GPT_3_5, temperature=temperature, ) - response = 
diff --git a/ayushma/utils/openaiapi.py b/ayushma/utils/openaiapi.py
index f8807dda..573145cf 100644
--- a/ayushma/utils/openaiapi.py
+++ b/ayushma/utils/openaiapi.py
@@ -266,6 +266,7 @@ def converse(
     generate_audio=True,
     noonce=None,
     fetch_references=True,
+    documents=None,
 ):
     if not openai_key:
         raise Exception("OpenAI-Key header is required to create a chat or converse")
@@ -306,6 +307,9 @@ def converse(

     prompt = chat.prompt or (chat.project and chat.project.prompt)

+    if documents or chat.project.model == ModelType.GPT_4_VISUAL:
+        prompt = "Image Capabilities: Enabled\n" + prompt
+
     # excluding the latest query since it is not a history
     previous_messages = (
         ChatMessage.objects.filter(chat=chat)
@@ -329,7 +333,9 @@ def converse(
             or ModelType.GPT_3_5,
            temperature=temperature,
        )
-        response = lang_chain_helper.get_response(english_text, reference, chat_history)
+        response = lang_chain_helper.get_response(
+            english_text, reference, chat_history, documents
+        )
        chat_response = response.replace("Ayushma: ", "")
        stats["response_end_time"] = time.time()
        translated_chat_response, url, chat_message = handle_post_response(
@@ -372,6 +378,7 @@ def converse(
                english_text,
                reference,
                chat_history,
+                documents,
            )
            chat_response = ""
            skip_token = len(f"{AI_NAME}: ")
diff --git a/ayushma/views/document.py b/ayushma/views/document.py
index 62a33afc..c9a5901d 100644
--- a/ayushma/views/document.py
+++ b/ayushma/views/document.py
@@ -1,8 +1,4 @@
-import json
-from ast import Delete
-
 from django.conf import settings
-from drf_spectacular.utils import extend_schema, extend_schema_view
 from rest_framework.exceptions import ValidationError
 from rest_framework.mixins import (
     CreateModelMixin,
@@ -14,15 +10,15 @@
 from rest_framework.permissions import IsAdminUser
 from rest_framework.response import Response

-from ayushma.models import Document, DocumentType, Project
+from ayushma.models import Document, Project
+from ayushma.models.testsuite import TestQuestion
 from ayushma.serializers.document import DocumentSerializer, DocumentUpdateSerializer
 from ayushma.tasks.upsertdoc import upsert_doc
-from ayushma.utils.upsert import upsert
 from utils.views.base import BaseModelViewSet
 from utils.views.mixins import PartialUpdateModelMixin


-class DocumentViewSet(
+class ProjectDocumentViewSet(
     BaseModelViewSet,
     PartialUpdateModelMixin,
     CreateModelMixin,
@@ -63,6 +59,7 @@ def perform_create(self, serializer):
         try:
             doc_url = self.request.build_absolute_uri(document.file.url)
         except Exception as e:
+            print(e)
             pass

         upsert_doc.delay(document.external_id, doc_url)
@@ -81,3 +78,37 @@ def perform_destroy(self, instance):
                 status=400,
             )
         return super().perform_destroy(instance)
+
+
+class TestQuestionDocumentViewSet(
+    BaseModelViewSet,
+    CreateModelMixin,
+    RetrieveModelMixin,
+    DestroyModelMixin,
+    ListModelMixin,
+):
+    queryset = Document.objects.all()
+    serializer_action_classes = {
+        "list": DocumentSerializer,
+        "retrieve": DocumentSerializer,
+        "create": DocumentSerializer,
+    }
+    permission_classes = (IsAdminUser,)
+    parser_classes = (MultiPartParser,)
+    lookup_field = "external_id"
+
+    def get_queryset(self):
+        queryset = self.queryset.filter(
+            test_question__external_id=self.kwargs["test_question_external_id"]
+        )
+        return queryset
+
+    def perform_create(self, serializer):
+        external_id = self.kwargs["test_question_external_id"]
+        test_question = TestQuestion.objects.get(external_id=external_id)
+
+        document = serializer.save(test_question=test_question, uploading=False)
+        test_question.documents.add(document)
+        test_question.save()
+
+        return document
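
TestQuestionDocumentViewSet is admin-only, accepts multipart uploads, is keyed by external_id, and links each created document to its question through both the new foreign key and the many-to-many. A rough test sketch under assumed fixtures (an is_staff user, a suite, a question) and an assumed /api/ URL prefix:

    from rest_framework.test import APIClient

    client = APIClient()
    client.force_authenticate(user=admin_user)  # assumed staff-user fixture

    with open("sample_image.png", "rb") as fp:
        response = client.post(
            f"/api/tests/suites/{suite.external_id}/questions/{question.external_id}/documents/",
            {"file": fp},  # other DocumentSerializer fields may also be required
            format="multipart",
        )

    assert response.status_code == 201
    assert question.documents.count() == 1
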
diff --git a/utils/helpers.py b/utils/helpers.py
index e5f90871..6709640d 100644
--- a/utils/helpers.py
+++ b/utils/helpers.py
@@ -1,9 +1,12 @@
+import base64
 import random
 import string

 import requests
 from django.conf import settings

+from ayushma.models.document import Document
+

 def get_random_string(length: int) -> str:
     return "".join(random.choices(string.hexdigits, k=length))
@@ -42,3 +45,21 @@ def validatecaptcha(recaptcha_response):
     result = captcha_response.json()

     return result.get("success", False)
+
+
+def get_base64_document(document: Document):
+    try:
+        if document.file.path:
+            with open(document.file.path, "rb") as f:
+                image_data = f.read()
+    except NotImplementedError:
+        try:
+            response = requests.get(document.file.url)
+            response.raise_for_status()
+            image_data = response.content
+        except Exception:
+            return None
+    except Exception:
+        return None
+
+    return base64.b64encode(image_data).decode("utf-8")
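
get_base64_document is what covers the "local storage backend" part of this change: with local FileSystemStorage, document.file.path resolves and the bytes are read from disk, while S3-style backends raise NotImplementedError for .path, so the helper falls back to fetching document.file.url over HTTP; any other failure returns None. A small usage sketch:

    from ayushma.models import Document
    from utils.helpers import get_base64_document

    document = Document.objects.filter(test_question__isnull=False).first()
    encoded = get_base64_document(document)

    if encoded is None:
        print("file missing or unreachable")
    else:
        print(f"{len(encoded)} base64 characters ready to attach to the prompt")
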