Skip to content

Commit

Permalink
Merge pull request #118 from 2024-SummerBootcamp-Team/develop
Browse files Browse the repository at this point in the history
[main] main Merge
  • Loading branch information
dlwhsk0 authored Jul 25, 2024
2 parents 016d77a + 2a043da commit 51c9ff2
Show file tree
Hide file tree
Showing 11 changed files with 127 additions and 22 deletions.
14 changes: 14 additions & 0 deletions app/config/langChain/langChainSetting.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,17 @@ def get_session_history(session_id):
]
)
topic_chain = RunnablePassthrough() | prompt_topic | llm

# 채팅방 매운맛 분석
prompt_spicy = ChatPromptTemplate.from_messages(
[
("system", """
대화 내용을 주면 그 내용을 분석해서 대화 내용 중 독한말의 정도를 정수 1부터 10까지 중 한가지 선택해서 말해줘 숫자가 클수록 대화 내용의 독한말 정도가 큰거야 .
이유는 붙이지 말고 정수 결과만 딱 반환해줘.
"""
),
("human", "{input}") # 사용자 메시지
]
)
spicy_chain = RunnablePassthrough() | prompt_spicy | llm
3 changes: 2 additions & 1 deletion app/models/chat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey
from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey, Float
from sqlalchemy.orm import relationship

from ..database.session import Base
Expand All @@ -15,6 +15,7 @@ class Chat(Base):
updated_at = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now)
name = Column(String(45), nullable=False)
topic = Column(String(10), nullable=True)
spicy = Column(Float, nullable=True)

character = relationship("Character", back_populates="chats")
bubbles = relationship("Bubble", back_populates="chat")
4 changes: 2 additions & 2 deletions app/routers/characters.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ def read_characters(db: Session = Depends(get_db), skip: int = 0, limit: int = 1


# 대시보드
@router.get("/dashboard/total", response_model=ResultResponseModel, summary="전체 캐릭터 통계", description="전체 캐릭터에 대한 통계를 조회합니다.")
@router.get("/dashboards/total", response_model=ResultResponseModel, summary="전체 캐릭터 통계", description="전체 캐릭터에 대한 통계를 조회합니다.")
def read_dashboard_total(db: Session = Depends(get_db)):
result = character_service.get_dashboard_total(db)
return ResultResponseModel(code=200, message="전체 캐릭터에 대한 통계를 조회했습니다.", data=result)


@router.get("/dashboard/{character_name}", response_model=ResultResponseModel, summary="캐릭터 별 통계", description="캐릭터 별 통계를 조회합니다.")
@router.get("/dashboards/{character_name}", response_model=ResultResponseModel, summary="캐릭터 별 통계", description="캐릭터 별 통계를 조회합니다.")
def read_dashboard_character(character_name: str, db: Session = Depends(get_db)):
result = character_service.get_dashboard_character(db, character_name)
return ResultResponseModel(code=200, message="캐릭터에 대한 통계를 조회했습니다.", data=result)
2 changes: 2 additions & 0 deletions app/routers/chats.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session

from ..models import Bubble
from ..schemas.response import ResultResponseModel
from ..services import chat_service
from ..database.session import get_db
Expand Down
2 changes: 2 additions & 0 deletions app/routers/tests.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import random
from io import BytesIO
from datetime import timedelta

Expand All @@ -8,6 +9,7 @@
from app.config.elevenlabs.text_to_speech_stream import text_to_speech_stream
from app.config.redis.config import Config
from app.database.session import get_db
from app.models import Chat
from app.schemas.response import ResultResponseModel
from app.schemas.voice import VoiceCreateRequest
from app.services import bubble_service, voice_service, chat_service
Expand Down
6 changes: 5 additions & 1 deletion app/schemas/bubble.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,8 @@ class BubbleRequest(BaseModel):
content: str

class Config:
from_attributes = True
from_attributes = True

# class BubbleRequest(BaseModel)
# content: str
# spicy_score: int # AI로부터 반환되는 spicy_score 필드 추가
1 change: 1 addition & 0 deletions app/schemas/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class ChatRoomBase(BaseModel):
character_id: int
character_name: str
topic: Optional[str] = None
spicy: Optional[int] = None
created_at: datetime
name: str

Expand Down
9 changes: 5 additions & 4 deletions app/schemas/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,18 @@ class TopicFrequency(BaseModel):
연애: int = Field(alias="연애")

class SpicyFrequency(BaseModel):
level_1_2: int = Field(alias="1-2")
level_0_2: int = Field(alias="0-2")
level_2_4: int = Field(alias="2-4")
level_5_6: int = Field(alias="5-6")
level_7_8: int = Field(alias="7-8")
level_9_10: int = Field(alias="9-10")
level_4_6: int = Field(alias="4-6")
level_6_8: int = Field(alias="6-8")
level_8_10: int = Field(alias="8-10")

class Config:
populate_by_name = True # 필드 이름을 기준으로 값을 채움

class CharacterStats(BaseModel):
name: str
chat_count: int
topic_frequency: TopicFrequency
spicy_frequency: SpicyFrequency

Expand Down
7 changes: 7 additions & 0 deletions app/services/bubble_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,5 +128,12 @@ async def create_bubble(chat_id: int, content: str, db: Session):
bubble_count = db.query(func.count(Bubble.id)).filter(Bubble.chat_id == chat_id).scalar()
if bubble_count == 2 or bubble_count % 10 == 0:
topic = chat_service.get_chat_topic(db, chat_id)
spicy = chat_service.get_chat_spicy(db, chat_id)
print("topic", topic)
print("spicy", spicy)
# yield f"data: {json.dumps({'topic': topic})}\n\n"


# 대화 내용 최신순으로 가져오기
def get_recent_bubbles(db: Session, chat_id: int, limit: int):
return db.query(Bubble).filter(Bubble.chat_id == chat_id).order_by(Bubble.created_at.desc()).limit(limit).all()
63 changes: 53 additions & 10 deletions app/services/character_service.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import List

from fastapi import HTTPException
from sqlalchemy import func
from sqlalchemy.orm import Session

from . import image_service, voice_service
from . import image_service, voice_service, chat_service
from ..models import Chat
from ..models.character import Character
from ..schemas.character import CharacterDetail
Expand All @@ -20,42 +22,83 @@ def get_character_by_name(db: Session, character_name: str):

def get_characters(db: Session, skip: int = 0, limit: int = 100):
characters = db.query(Character).filter(Character.is_deleted == False).offset(skip).limit(limit).all()
character_list = [CharacterDetail(id=character.id, name=character.name, image_url=character.image_url) for character in characters]
character_list = [CharacterDetail(id=character.id, name=character.name, image_url=character.image_url) for character
in characters]
return character_list


# 대시보드 관련
def get_topic_count(db: Session, character_id: int, topic: str) -> int:
return db.query(func.count(Chat.id)).filter(Chat.character_id == character_id, Chat.topic == topic).scalar()


# 매운맛 점수 구간 개수 구하기
def get_spicy_frequency(db: Session, character_id: int) -> SpicyFrequency:
chats = db.query(Chat).filter(Chat.character_id == character_id).all()
spicy_count = get_spicy_count(chats)
return SpicyFrequency(
level_0_2=spicy_count["0-2"],
level_2_4=spicy_count["2-4"],
level_4_6=spicy_count["4-6"],
level_6_8=spicy_count["6-8"],
level_8_10=spicy_count["8-10"]
)


def get_spicy_count(chats) -> dict:
spicy_count = {
"0-2": 0,
"2-4": 0,
"4-6": 0,
"6-8": 0,
"8-10": 0
}
for chat in chats:
spicy = chat.spicy
if 0 <= spicy < 2:
spicy_count["0-2"] += 1
elif 2 <= spicy < 4:
spicy_count["2-4"] += 1
elif 4 <= spicy < 6:
spicy_count["4-6"] += 1
elif 6 <= spicy < 8:
spicy_count["6-8"] += 1
elif 8 <= spicy <= 10:
spicy_count["8-10"] += 1
return spicy_count


# 매운맛 점수 평균 구하기
def get_spicy_average(db: Session, character_id: int):
return -1
chats = chat_service.get_chat_by_character_id(db, character_id)
if not chats: # chats가 비어있으면 0.0을 반환합니다.
return 0.0
total_spicy = sum(chat.spicy if chat.spicy is not None else 0 for chat in chats)
return total_spicy / len(chats)


# 전체 통계
def get_dashboard_total(db: Session):
characters = get_characters(db)
return DashboardTotal(
characters=[
CharacterStats(
name=character.name,
chat_count=chat_service.get_chat_count(db, character.id),
topic_frequency=TopicFrequency(
취업=get_topic_count(db, character.id, "취업"),
학업=get_topic_count(db, character.id, "학업"),
인간관계=get_topic_count(db, character.id, "인간관계"),
연애=get_topic_count(db, character.id, "연애")
),
spicy_frequency=SpicyFrequency(
level_1_2=7,
level_2_4=15,
level_5_6=10,
level_7_8=10,
level_9_10=8
)
spicy_frequency=get_spicy_frequency(db, character.id)
)
for character in characters
]
)


# 캐릭터 별 통계
def get_dashboard_character(db: Session, character_name: str):
character = get_character_by_name(db, character_name)
if not character:
Expand Down
38 changes: 34 additions & 4 deletions app/services/chat_service.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from fastapi import HTTPException
from sqlalchemy import func

from app.config.langChain.langChainSetting import topic_chain
from app.config.langChain.langChainSetting import topic_chain, spicy_chain
from app.models import Character
from app.models.chat import Chat
from sqlalchemy.orm import Session
from app.models.voice import Voice
from app.models.image import Image
from app.schemas.bubble import ChatBubble, ChatBubbleList
from app.schemas.chat import ChatRoomBase
from app.models.bubble import Bubble
from app.services import bubble_service


# 채팅방 조회
def get_chat(db: Session, chat_id: int) -> Chat:
Expand All @@ -26,7 +29,8 @@ def get_chat_room(db: Session, chat_id: int):
character_name=chat.character.name,
topic=chat.topic,
created_at=chat.created_at,
name=chat.name
name=chat.name,
spicy=chat.spicy
)


Expand Down Expand Up @@ -58,12 +62,20 @@ def create_chat_room(db: Session, chat_name: str, character_id: int):
return chat


# 캐릭터 별 채팅방 조회
def get_chat_by_character_id(db: Session, character_id: int):
return db.query(Chat).filter(Chat.character_id == character_id, Chat.is_deleted == False).all()


# 캐릭터 별 채팅방 개수 조회
def get_chat_count(db: Session, character_id: int):
return db.query(func.count(Chat.id)).filter(Chat.character_id == character_id, Chat.is_deleted == False).scalar()

# 채팅방 토픽 분석 및 업데이트
def get_chat_topic(db: Session, chat_id: int):
# 최신순으로 10개 가져오기 (5쌍)
bubbles = db.query(Bubble).filter(Bubble.chat_id == chat_id).order_by(Bubble.created_at.desc()).limit(10).all()
bubbles = bubble_service.get_recent_bubbles(db, chat_id, 10)
content = "\n\n".join(bubble.content for bubble in bubbles)
print("content: ", content)
topic = topic_chain.invoke({"input": content})
# 갱신된 topic 저장
chat = get_chat(db, chat_id)
Expand All @@ -74,3 +86,21 @@ def get_chat_topic(db: Session, chat_id: int):
db.commit()
db.refresh(chat)
return chat.topic


# 채팅방 매운맛 점수 분석 및 업데이트
def get_chat_spicy(db: Session, chat_id: int):
# 최신순으로 10개 가져오기 (5쌍)
bubbles = bubble_service.get_recent_bubbles(db, chat_id, 10)
content = "\n\n".join(bubble.content for bubble in bubbles) # 가져온 대화 10개 하나의 긴 문장으로 합치기
# 합쳐진 문장에서 매운맛 찾기
spicy = spicy_chain.invoke({"input": content})
# 채팅방 정보 가져오기
chat = get_chat(db, chat_id)
if not chat:
raise HTTPException(status_code=404, detail="채팅방 정보를 불러오는데 실패했습니다.")
elif chat.spicy != spicy.content: # 갱신된 매운맛 점수가 기존과 다를 경우 업데이트
chat.spicy = spicy.content
db.commit()
db.refresh(chat)
return chat.spicy # 업데이트 된 spicy 반환

0 comments on commit 51c9ff2

Please sign in to comment.