diff --git a/backend/petkinApp/__pycache__/main.cpython-311.pyc b/backend/petkinApp/__pycache__/main.cpython-311.pyc index d59565d..f11bc04 100644 Binary files a/backend/petkinApp/__pycache__/main.cpython-311.pyc and b/backend/petkinApp/__pycache__/main.cpython-311.pyc differ diff --git a/backend/petkinApp/routers/prediction.py b/backend/petkinApp/routers/prediction.py index bf27131..ef94143 100644 --- a/backend/petkinApp/routers/prediction.py +++ b/backend/petkinApp/routers/prediction.py @@ -1,7 +1,6 @@ import json import os from datetime import datetime - import pandas as pd import gdown from fastapi import FastAPI, File, UploadFile, HTTPException, APIRouter, Form, Depends @@ -9,12 +8,12 @@ from sqlalchemy.orm import Session, session from torch import nn import torch +from torch.nn import Module from torchvision import transforms, models from PIL import Image from sklearn.preprocessing import MinMaxScaler from httpx import AsyncClient import torch.nn.functional as F - from petkinApp.database import get_db from petkinApp.models import AIResult, DiseasePredictionRecord from petkinApp.routers.pets import get_pet @@ -84,7 +83,7 @@ def load_model(model_path="model.pt"): # 모델 파일이 없으면 Google Drive에서 다운로드 if not os.path.exists(model_path): print(f"{model_path} not found. Downloading from Google Drive...") - download_model_from_google_drive(file_id="1vF5oXignqWuLBE7uQUthx4sTc8w4RydA", output_path=model_path) + download_model_from_google_drive(file_id="1YKt3OayICzlpXLYFVYD5TOZ2Lg8bnSJ0", output_path=model_path) # 모델 로드 print("Loading model...") @@ -172,15 +171,8 @@ async def predict_api( print("Model is None!") raise HTTPException(status_code=500, detail="Model is not loaded properly.") - # 펫 정보 추출 및 전처리 - df = pd.DataFrame([{ - "breed": pet_info.breed, - "gender": pet_info.gender, - "lesions": None, # 필요한 경우 기본값 설정 - "age": pet_info.age - }]) # PetDetailResponse를 DataFrame으로 변환 - processed_features = preprocess_dataframe(df) - additional_features_tensor = torch.tensor(processed_features.values, dtype=torch.float32) + # 펫 정보 전처리 + additional_features_tensor = preprocess_pet_info(pet_info) # 이미지 처리 image_tensor = image_transform(image.file) @@ -188,9 +180,9 @@ async def predict_api( # 모델 예측 with torch.no_grad(): - logits = model(image_tensor, additional_features_tensor) - probabilities = F.softmax(logits, dim=1).squeeze(0).tolist() # 1D 리스트로 변환 - print("Logits with features:", logits) + logit = model(image_tensor, additional_features_tensor) + probabilities = F.softmax(logit, dim=1).squeeze(0).tolist() # 1D 리스트로 변환 + print("Logits with features:", logit) print("Probabilities with features:", probabilities) # 이미지 저장 및 URL 생성 @@ -211,29 +203,25 @@ async def predict_api( image_path = os.path.join(upload_dir, image_filename) # 저장 경로 try: - from PIL import Image + # 이미지를 RGB 형식으로 변환 + pil_image = Image.open(image.file).convert("RGB") + + # 확장자에 따른 이미지 저장 포맷 결정 + save_format = "JPEG" if file_extension in [".jpg", ".jpeg"] else file_extension[1:].upper() - # `Pillow`를 이용한 이미지 변환 및 저장 - pil_image = Image.open(image.file).convert("RGB") # `webp` 포함 모든 파일을 RGB로 변환 - pil_image.save(image_path, format=file_extension[1:].upper()) # 확장자에 맞는 형식으로 저장 + # 이미지 저장 + pil_image.save(image_path, format=save_format) + print(f"Image saved successfully at {image_path}") # 저장 성공 로그 추가 - # 저장된 이미지 URL 생성 - image_url = f"/static/uploads/{image_filename}" except Exception as e: + print(f"Failed to save image: {str(e)}") # 에러 로그 출력 raise HTTPException(status_code=500, detail=f"Failed to process the image: {str(e)}") - # 가장 높은 클래스 인덱스 및 이름 가져오기 - class_mapping = { - 0: "A1 구진/플라크", - 1: "A2 비듬/각질/상피성잔고리", - 2: "A3 태선화 과다색소침착", - 3: "A4 농포/여드름", - 4: "A5 미란/궤양", - 5: "A6 결정/종괴", - 6: "A7 무증상" - } - predicted_class_index = torch.argmax(torch.tensor(probabilities)).item() # 가장 높은 인덱스 - predicted_class_label = class_mapping[predicted_class_index] + # 저장된 이미지 URL 생성 + image_url = f"/static/uploads/{image_filename}" + + # 클래스 라벨 반환 + predicted_class_label, predicted_class_index = get_predicted_class_label(probabilities) # 모델 이름 (예: EfficientNet) model_name = "MultimodalModel-EfficientNetB0" @@ -436,3 +424,35 @@ async def get_result_detail_by_analysis_id( status_code=500, detail=f"Failed to retrieve AIResult: {str(e)}" ) + + +def get_predicted_class_label(probabilities): + """ + 확률 값을 바탕으로 예측된 클래스 라벨을 반환합니다. + """ + class_mapping = { + 0: "A1 구진/플라크", + 1: "A2 비듬/각질/상피성잔고리", + 2: "A3 태선화 과다색소침착", + 3: "A4 농포/여드름", + 4: "A5 미란/궤양", + 5: "A6 결정/종괴", + 6: "A7 무증상" + } + predicted_class_index = torch.argmax(torch.tensor(probabilities)).item() + return class_mapping[predicted_class_index], predicted_class_index + + +def preprocess_pet_info(pet_info: PetDetailResponse): + """ + 펫 정보를 받아 DataFrame으로 변환하고 전처리를 수행합니다. + """ + df = pd.DataFrame([{ + "breed": pet_info.breed, + "gender": pet_info.gender, + "lesions": None, # 기본값 설정 + "age": pet_info.age + }]) + processed_features = preprocess_dataframe(df) + additional_features_tensor = torch.tensor(processed_features.values, dtype=torch.float32) + return additional_features_tensor