From 36296e913bbeaf4b271863045d6361f0f26c5f88 Mon Sep 17 00:00:00 2001
From: Serhii
Date: Thu, 12 Sep 2024 19:49:01 +0300
Subject: [PATCH] feature_ai_translate

---
 frontend/src/api/api.ts           |   8 ++
 frontend/src/components/image.tsx |  22 ++--
 frontend/src/pages/Collection.tsx | 161 +++++++++++++++++++++++-------
 frontend/src/types/model.ts       |   9 +-
 linguaphoto/ai/cli.py             |  10 +-
 linguaphoto/ai/transcribe.py      |  56 +++++------
 linguaphoto/ai/tts.py             |   3 +-
 linguaphoto/api/image.py          |  15 +++
 linguaphoto/crud/image.py         |  62 ++++++++++++
 linguaphoto/models.py             |  16 ++-
 linguaphoto/requirements.txt      |   3 +
 11 files changed, 281 insertions(+), 84 deletions(-)

diff --git a/frontend/src/api/api.ts b/frontend/src/api/api.ts
index 5b46e48..2b1e631 100644
--- a/frontend/src/api/api.ts
+++ b/frontend/src/api/api.ts
@@ -59,4 +59,12 @@ export class Api {
     );
     return response.data;
   }
+  public async translateImages(images: Array<string>): Promise<Array<Image>> {
+    const response = await this.api.post(
+      "/translate",
+      { images },
+      { timeout: 300000 },
+    );
+    return response.data;
+  }
 }
diff --git a/frontend/src/components/image.tsx b/frontend/src/components/image.tsx
index 10ac9d4..5e6c8d6 100644
--- a/frontend/src/components/image.tsx
+++ b/frontend/src/components/image.tsx
@@ -1,13 +1,16 @@
 import React from "react";
 import { CheckCircleFill, LockFill, PencilFill } from "react-bootstrap-icons";
 import { Image } from "types/model";
-
-const ImageComponent: React.FC<Image> = ({
-  // id,
+// Extend the existing Image interface to include the new function
+interface ImageWithFunction extends Image {
+  handleTranslateOneImage: (image_id: string) => void;
+}
+const ImageComponent: React.FC<ImageWithFunction> = ({
+  id,
   is_translated,
   image_url,
-  // audio_url,
-  transcript,
+  transcriptions,
+  handleTranslateOneImage,
 }) => {
   return (
= ({
          The image has been translated
-          {transcript}
+          {transcriptions.map((transcription, index) => (
+            {transcription.text}
+          ))}
        ) : (
@@ -41,7 +46,10 @@ const ImageComponent: React.FC<Image> = ({
            Edit
          ) : (
-
          )}
diff --git a/frontend/src/pages/Collection.tsx b/frontend/src/pages/Collection.tsx
index 9f06707..155a339 100644
--- a/frontend/src/pages/Collection.tsx
+++ b/frontend/src/pages/Collection.tsx
@@ -5,9 +5,15 @@
 import Modal from "components/modal";
 import UploadContent from "components/UploadContent";
 import { useAuth } from "contexts/AuthContext";
 import { useLoading } from "contexts/LoadingContext";
-import React, { useEffect, useState } from "react";
+import React, { useEffect, useMemo, useState } from "react";
 import { Col, Row } from "react-bootstrap";
-import { ArrowLeft } from "react-bootstrap-icons";
+import {
+  ArrowLeft,
+  CaretLeft,
+  CaretRight,
+  SkipBackward,
+  SkipForward,
+} from "react-bootstrap-icons";
 import { useLocation, useNavigate, useParams } from "react-router-dom";
 import { Collection, Image } from "types/model";
 
@@ -18,6 +24,7 @@ const CollectionPage: React.FC = () => {
   const [title, setTitle] = useState("");
   const [description, setDescription] = useState("");
   const [currentImageIndex, setCurrentImageIndex] = useState(0);
+  const [currentTranscriptionIndex, setCurrentTranscriptionIndex] = useState(0);
   const [currentImage, setCurrentImage] = useState<Image | null>(null);
   const [collection, setCollection] = useState<Collection | null>(null);
   const { auth, is_auth } = useAuth();
@@ -25,29 +32,44 @@ const CollectionPage: React.FC = () => {
   const [showModal, setShowModal] = useState(false);
   const [images, setImages] = useState<Array<Image> | null>([]);
 
-  const apiClient: AxiosInstance = axios.create({
-    baseURL: process.env.REACT_APP_BACKEND_URL, // Base URL for all requests
-    timeout: 10000, // Request timeout (in milliseconds)
-    headers: {
-      "Content-Type": "application/json",
-      Authorization: `Bearer ${auth?.token}`, // Add any default headers you need
-    },
-  });
-  const apiClient1: AxiosInstance = axios.create({
-    baseURL: process.env.REACT_APP_BACKEND_URL, // Base URL for all requests
-    timeout: 1000000, // Request timeout (in milliseconds)
-    headers: {
-      "Content-Type": "multipart/form-data",
-      Authorization: `Bearer ${auth?.token}`, // Add any default headers you need
-    },
-  });
-  const API = new Api(apiClient);
-  const API_Uploader = new Api(apiClient1);
+  const apiClient: AxiosInstance = useMemo(
+    () =>
+      axios.create({
+        baseURL: process.env.REACT_APP_BACKEND_URL, // Base URL for all requests
+        timeout: 10000, // Request timeout (in milliseconds)
+        headers: {
+          "Content-Type": "application/json",
+          Authorization: `Bearer ${auth?.token}`, // Add any default headers you need
+        },
+      }),
+    [auth?.token],
+  );
+  const apiClient1: AxiosInstance = useMemo(
+    () =>
+      axios.create({
+        baseURL: process.env.REACT_APP_BACKEND_URL,
+        timeout: 1000000,
+        headers: {
+          "Content-Type": "multipart/form-data",
+          Authorization: `Bearer ${auth?.token}`,
+        },
+      }),
+    [auth?.token],
+  );
+  const API = useMemo(() => new Api(apiClient), [apiClient]);
+  const API_Uploader = useMemo(() => new Api(apiClient1), [apiClient1]);
   // Helper to check if it's an edit action
-  const isEditAction = location.search.includes("Action=edit");
+  const isEditAction = useMemo(
+    () => location.search.includes("Action=edit"),
+    [location.search],
+  );
   // Get translated images
-  let translatedImages: Array<Image> = [];
+  const translatedImages = useMemo(() => {
+    // Get translated images
+    if (images) return images.filter((img) => img.is_translated);
+    return [];
+  }, [images]);
 
   // Simulate fetching data for the edit page (mocking API call)
   useEffect(() => {
@@ -62,11 +84,6 @@ const CollectionPage: React.FC = () => {
     }
   }, [id, is_auth]);
 
-  useEffect(() => {
-    // Get translated images
-    if (images) translatedImages = images.filter((img) => img.is_translated);
-  }, [images]);
-
   useEffect(() => {
     if (translatedImages.length > 0) {
       setCurrentImage(translatedImages[currentImageIndex]);
@@ -96,15 +113,31 @@ const CollectionPage: React.FC = () => {
   const handleNext = () => {
     if (currentImageIndex < translatedImages.length - 1) {
       setCurrentImageIndex(currentImageIndex + 1);
+      setCurrentTranscriptionIndex(0);
     }
   };
   const handlePrev = () => {
     if (currentImageIndex > 0) {
       setCurrentImageIndex(currentImageIndex - 1);
+      setCurrentTranscriptionIndex(0);
+    }
+  };
+  // Navigate transcriptions
+  const handleTranscriptionNext = () => {
+    if (
+      currentImage?.transcriptions &&
+      currentTranscriptionIndex < currentImage?.transcriptions.length - 1
+    ) {
+      setCurrentTranscriptionIndex(currentTranscriptionIndex + 1);
     }
   };
+  const handleTranscriptionPrev = () => {
+    if (currentTranscriptionIndex > 0) {
+      setCurrentTranscriptionIndex(currentTranscriptionIndex - 1);
+    }
+  };
   // Return button handler
   const handleReturn = () => {
     navigate("/collections");
@@ -133,6 +166,16 @@ const CollectionPage: React.FC = () => {
       }
     }
   };
+  const handleTranslateOneImage = async (image_id: string) => {
+    if (images) {
+      startLoading();
+      const image_response = await API.translateImages([image_id]);
+      const i = images?.findIndex((image) => image.id == image_id);
+      images[i] = image_response[0];
+      setImages([...images]);
+      stopLoading();
+    }
+  };
   // Custom Return Button (fixed top-left with border)
   const ReturnButton = () => (
+
+
diff --git a/frontend/src/types/model.ts b/frontend/src/types/model.ts
index 4c99a37..a648b88 100644
--- a/frontend/src/types/model.ts
+++ b/frontend/src/types/model.ts
@@ -5,11 +5,16 @@ export interface Collection {
   images: Array;
 }
 
+interface Transcription {
+  text: string;
+  pinyin: string;
+  translation: string;
+  audio_url: string;
+}
 export interface Image {
   id: string;
   is_translated: boolean;
   collection: string;
   image_url: string;
-  audio_url: string;
-  transcript: string;
+  transcriptions: Array<Transcription>;
 }
diff --git a/linguaphoto/ai/cli.py b/linguaphoto/ai/cli.py
index 21255aa..c5a4b95 100644
--- a/linguaphoto/ai/cli.py
+++ b/linguaphoto/ai/cli.py
@@ -7,9 +7,8 @@
 from openai import AsyncOpenAI
 from PIL import Image
-
-from linguaphoto.ai.transcribe import transcribe_image
-from linguaphoto.ai.tts import synthesize_text
+from transcribe import transcribe_image
+from tts import synthesize_text
 
 logger = logging.getLogger(__name__)
 
@@ -26,8 +25,11 @@ async def main() -> None:
     # Transcribes the image.
     image = Image.open(args.image)
-    client = AsyncOpenAI()
+    # The OpenAI API key is read from the OPENAI_API_KEY environment variable;
+    # avoid hard-coding secrets in source control.
+    client = AsyncOpenAI()
     transcription_response = await transcribe_image(image, client)
+    print(transcription_response.model_dump_json(indent=2))
     with open(root_dir / "transcription.json", "w") as file:
         file.write(transcription_response.model_dump_json(indent=2))
     logger.info("Transcription saved to %s", args.output)
diff --git a/linguaphoto/ai/transcribe.py b/linguaphoto/ai/transcribe.py
index 3bae671..9e18f81 100644
--- a/linguaphoto/ai/transcribe.py
+++ b/linguaphoto/ai/transcribe.py
@@ -1,15 +1,16 @@
 """Uses the OpenAI API to transcribe an image to text."""
 
-import argparse
-import asyncio
+# import argparse
+
+# import asyncio
 import base64
 import logging
 from io import BytesIO
 
 import aiohttp
+from models import TranscriptionResponse
 from openai import AsyncOpenAI
 from PIL import Image
-from pydantic import BaseModel
 
 logger = logging.getLogger(__name__)
 
@@ -23,12 +24,14 @@
     {
       "text": "你好,我朋友!",
       "pinyin": "nǐhǎo, wǒ péngyǒu!",
-      "translation": "Hello, my friend!"
+      "translation": "Hello, my friend!",
+      "audio_url":""
     },
     {
      "text": "我找到了工作。",
      "pinyin": "wǒ zhǎodàole gōngzuò.",
-     "translation": "I found a job."
+     "translation": "I found a job.",
+     "audio_url":""
     }
   ]
 }
@@ -48,17 +51,8 @@ def encode_image(image: Image.Image) -> str:
     return base64.b64encode(buffered.getvalue()).decode("utf-8")
 
 
-class Transcription(BaseModel):
-    text: str
-    pinyin: str
-    translation: str
-
-
-class TranscriptionResponse(BaseModel):
-    transcriptions: list[Transcription]
-
-
-async def transcribe_image(image: Image.Image, client: AsyncOpenAI) -> TranscriptionResponse:
+async def transcribe_image(image_source: BytesIO, client: AsyncOpenAI) -> TranscriptionResponse:
+    image = Image.open(image_source)
     """Transcribes the image to text.
 
     Args:
@@ -101,23 +95,23 @@ async def transcribe_image(image: Image.Image, client: AsyncOpenAI) -> Transcrip
     return transcription_response
 
 
-async def run_adhoc_test() -> None:
-    logging.basicConfig(level=logging.INFO)
+# async def run_adhoc_test() -> None:
+#     logging.basicConfig(level=logging.INFO)
 
-    parser = argparse.ArgumentParser(description="Transcribe an image to text.")
-    parser.add_argument("image", type=str, help="The path to the image to transcribe.")
-    parser.add_argument("output", type=str, help="The path to save the transcription.")
-    args = parser.parse_args()
+#     parser = argparse.ArgumentParser(description="Transcribe an image to text.")
+#     parser.add_argument("image", type=str, help="The path to the image to transcribe.")
+#     parser.add_argument("output", type=str, help="The path to save the transcription.")
+#     args = parser.parse_args()
 
-    image = Image.open(args.image)
-    client = AsyncOpenAI()
-    transcription_response = await transcribe_image(image, client)
+#     image = Image.open(args.image)
+#     client = AsyncOpenAI()
+#     transcription_response = await transcribe_image(image, client)
 
-    with open(args.output, "w") as file:
-        file.write(transcription_response.model_dump_json(indent=2))
-    logger.info("Transcription saved to %s", args.output)
+#     with open(args.output, "w") as file:
+#         file.write(transcription_response.model_dump_json(indent=2))
+#     logger.info("Transcription saved to %s", args.output)
 
 
-if __name__ == "__main__":
-    # python -m linguaphoto.ai.transcribe
-    asyncio.run(run_adhoc_test())
+# if __name__ == "__main__":
+#     # python -m linguaphoto.ai.transcribe
+#     asyncio.run(run_adhoc_test())
diff --git a/linguaphoto/ai/tts.py b/linguaphoto/ai/tts.py
index 6ad7e90..33700ec 100644
--- a/linguaphoto/ai/tts.py
+++ b/linguaphoto/ai/tts.py
@@ -6,10 +6,9 @@
 from pathlib import Path
 from typing import AsyncIterator
 
+from ai.transcribe import TranscriptionResponse
 from openai import AsyncOpenAI
 
-from linguaphoto.ai.transcribe import TranscriptionResponse
-
 logger = logging.getLogger(__name__)
diff --git a/linguaphoto/api/image.py b/linguaphoto/api/image.py
index 3d92536..ca9619f 100644
--- a/linguaphoto/api/image.py
+++ b/linguaphoto/api/image.py
@@ -5,8 +5,14 @@
 from crud.image import ImageCrud
 from fastapi import APIRouter, Depends, File, Form, UploadFile
 from models import Image
+from pydantic import BaseModel
 from utils.auth import get_current_user_id
 
+
+class TranslateFragment(BaseModel):
+    images: List[str]
+
+
 router = APIRouter()
 
 
@@ -30,3 +36,12 @@ async def get_images(
     async with image_crud:
         images = await image_crud.get_images(collection_id=collection_id, user_id=user_id)
     return images
+
+
+@router.post("/translate", response_model=List[Image])
+async def translate(
+    data: TranslateFragment, user_id: str = Depends(get_current_user_id), image_crud: ImageCrud = Depends()
+) -> List[Image]:
+    async with image_crud:
+        images = await image_crud.translate(data.images, user_id=user_id)
+    return images
diff --git a/linguaphoto/crud/image.py b/linguaphoto/crud/image.py
index 9f77ce2..ee47e00 100644
--- a/linguaphoto/crud/image.py
+++ b/linguaphoto/crud/image.py
@@ -2,13 +2,18 @@
 import os
 import uuid
+from io import BytesIO
 from typing import List
 
+import requests
+from ai.transcribe import transcribe_image
+from ai.tts import synthesize_text
 from boto3.dynamodb.conditions import Key
 from crud.base import BaseCrud
 from errors import ItemNotFoundError
 from fastapi import HTTPException, UploadFile
 from models import Collection, Image
+from openai import AsyncOpenAI
 from settings import settings
 from utils.cloudfront_url_signer import CloudFrontUrlSigner
@@ -43,6 +48,63 @@ async def create_image(self, file: UploadFile, user_id: str, collection_id: str)
             return new_image
         raise ItemNotFoundError
 
+    # Handles audio file creation by synthesizing and uploading to S3
+    async def create_audio(self, audio_source: BytesIO) -> str:
+        # Generate a unique file name with a '.mp3' extension (or other extension based on the format)
+        unique_filename = f"{uuid.uuid4()}.mp3"  # You can change the extension based on the actual audio format
+
+        # Create an instance of CloudFrontUrlSigner
+        private_key_path = os.path.abspath("private_key.pem")
+        cfs = CloudFrontUrlSigner(str(key_pair_id), private_key_path)
+
+        # Generate a signed URL
+        url = f"{media_hosting_server}/{unique_filename}"
+        custom_policy = cfs.create_custom_policy(url, expire_days=100)
+        s3_url = cfs.generate_presigned_url(url, custom_policy)
+
+        # Upload the audio source to S3
+        await self._upload_to_s3(audio_source, unique_filename)
+
+        # Return the signed S3 URL
+        return s3_url
+
     async def get_images(self, collection_id: str, user_id: str) -> List[Image]:
         images = await self._get_items_from_secondary_index("user", user_id, Image, Key("collection").eq(collection_id))
         return images
+
+    # Translates the images to text and synthesizes audio for the transcriptions
+    async def translate(self, images: List[str], user_id: str) -> List[Image]:
+        image_instances = []
+        for id in images:
+            # Retrieve image metadata and download the image content
+            image_instance = await self._get_item(id, Image, True)
+            response = requests.get(image_instance.image_url)
+            if response.status_code == 200:
+                img_source = BytesIO(response.content)
+                # Initialize OpenAI client for transcription and speech synthesis.
+                # The API key is read from the OPENAI_API_KEY environment variable;
+                # avoid hard-coding secrets in source control.
+                client = AsyncOpenAI()
+                transcription_response = await transcribe_image(img_source, client)
+                # Process each transcription and generate corresponding audio
+                for i, transcription in enumerate(transcription_response.transcriptions):
+                    audio_buffer = BytesIO()
+                    text = transcription.text
+                    # Synthesize text and write the chunks directly into the in-memory buffer
+                    async for chunk in await synthesize_text(text, client):
+                        audio_buffer.write(chunk)
+                    # Set buffer position to the start
+                    audio_buffer.seek(0)
+                    audio_url = await self.create_audio(audio_buffer)
+                    # Attach the audio URL to the transcription
+                    transcription.audio_url = audio_url
+                image_instance.transcriptions = transcription_response.transcriptions
+                image_instance.is_translated = True
+                await self._update_item(
+                    id,
+                    Image,
+                    {"transcriptions": transcription_response.model_dump()["transcriptions"], "is_translated": True},
+                )
+            if image_instance:
+                image_instances.append(image_instance)
+        return image_instances
diff --git a/linguaphoto/models.py b/linguaphoto/models.py
index ed78312..3e94648 100644
--- a/linguaphoto/models.py
+++ b/linguaphoto/models.py
@@ -62,12 +62,22 @@ def create(cls, title: str, description: str, user_id: str) -> Self:
         return cls(id=str(uuid4()), title=title, description=description, user=user_id)
 
 
+class Transcription(BaseModel):
+    text: str
+    pinyin: str
+    translation: str
+    audio_url: str
+
+
+class TranscriptionResponse(BaseModel):
+    transcriptions: list[Transcription]
+
+
 class Image(LinguaBaseModel):
-    is_traslated: bool = False
-    transcript: str | None = None
+    is_translated: bool = False
+    transcriptions: list[Transcription] = []
     collection: str | None = None
     image_url: str
-    audio_url: str | None = None
     user: str
 
     @classmethod
diff --git a/linguaphoto/requirements.txt b/linguaphoto/requirements.txt
index f49bbc9..3914a90 100644
--- a/linguaphoto/requirements.txt
+++ b/linguaphoto/requirements.txt
@@ -34,3 +34,6 @@
 numpy-stl
 types-aioboto3[dynamodb, s3]
 python-dotenv
+
+# AI
+openai
\ No newline at end of file
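
A minimal usage sketch, not part of the patch: calling the new POST /translate endpoint added in linguaphoto/api/image.py. The endpoint path, the {"images": [...]} request body (TranslateFragment), the Bearer-token header, and the long timeout mirror translateImages() in frontend/src/api/api.ts; the base URL, token, and image id below are placeholder assumptions.

import requests  # already a backend dependency introduced by this patch

BACKEND_URL = "http://localhost:8080"  # assumption: local dev server
TOKEN = "<jwt-from-login>"             # assumption: token issued by the existing auth flow
IMAGE_ID = "<image-uuid>"              # placeholder image id

resp = requests.post(
    f"{BACKEND_URL}/translate",
    json={"images": [IMAGE_ID]},
    headers={"Authorization": f"Bearer {TOKEN}"},
    timeout=300,  # transcription + TTS can take a while, matching the frontend's 300 s timeout
)
resp.raise_for_status()
# The endpoint returns List[Image]; each image carries its transcriptions.
for image in resp.json():
    for t in image["transcriptions"]:
        print(t["text"], "|", t["pinyin"], "|", t["translation"], "|", t["audio_url"])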