Skip to content

Commit

Permalink
Fix: Backend 사진 저장 코드 수정
Browse files Browse the repository at this point in the history
  • Loading branch information
aengzu committed Dec 17, 2024
1 parent be3cb3d commit a809802
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 33 deletions.
Binary file modified backend/petkinApp/__pycache__/main.cpython-311.pyc
Binary file not shown.
86 changes: 53 additions & 33 deletions backend/petkinApp/routers/prediction.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
import json
import os
from datetime import datetime

import pandas as pd
import gdown
from fastapi import FastAPI, File, UploadFile, HTTPException, APIRouter, Form, Depends
from pytz import timezone
from sqlalchemy.orm import Session, session
from torch import nn
import torch
from torch.nn import Module
from torchvision import transforms, models
from PIL import Image
from sklearn.preprocessing import MinMaxScaler
from httpx import AsyncClient
import torch.nn.functional as F

from petkinApp.database import get_db
from petkinApp.models import AIResult, DiseasePredictionRecord
from petkinApp.routers.pets import get_pet
Expand Down Expand Up @@ -84,7 +83,7 @@ def load_model(model_path="model.pt"):
# 모델 파일이 없으면 Google Drive에서 다운로드
if not os.path.exists(model_path):
print(f"{model_path} not found. Downloading from Google Drive...")
download_model_from_google_drive(file_id="1vF5oXignqWuLBE7uQUthx4sTc8w4RydA", output_path=model_path)
download_model_from_google_drive(file_id="1YKt3OayICzlpXLYFVYD5TOZ2Lg8bnSJ0", output_path=model_path)

# 모델 로드
print("Loading model...")
Expand Down Expand Up @@ -172,25 +171,18 @@ async def predict_api(
print("Model is None!")
raise HTTPException(status_code=500, detail="Model is not loaded properly.")

# 펫 정보 추출 및 전처리
df = pd.DataFrame([{
"breed": pet_info.breed,
"gender": pet_info.gender,
"lesions": None, # 필요한 경우 기본값 설정
"age": pet_info.age
}]) # PetDetailResponse를 DataFrame으로 변환
processed_features = preprocess_dataframe(df)
additional_features_tensor = torch.tensor(processed_features.values, dtype=torch.float32)
# 펫 정보 전처리
additional_features_tensor = preprocess_pet_info(pet_info)

# 이미지 처리
image_tensor = image_transform(image.file)
print(f"Input tensor shape: {image_tensor.shape}")

# 모델 예측
with torch.no_grad():
logits = model(image_tensor, additional_features_tensor)
probabilities = F.softmax(logits, dim=1).squeeze(0).tolist() # 1D 리스트로 변환
print("Logits with features:", logits)
logit = model(image_tensor, additional_features_tensor)
probabilities = F.softmax(logit, dim=1).squeeze(0).tolist() # 1D 리스트로 변환
print("Logits with features:", logit)
print("Probabilities with features:", probabilities)

# 이미지 저장 및 URL 생성
Expand All @@ -211,29 +203,25 @@ async def predict_api(
image_path = os.path.join(upload_dir, image_filename) # 저장 경로

try:
from PIL import Image
# 이미지를 RGB 형식으로 변환
pil_image = Image.open(image.file).convert("RGB")

# 확장자에 따른 이미지 저장 포맷 결정
save_format = "JPEG" if file_extension in [".jpg", ".jpeg"] else file_extension[1:].upper()

# `Pillow`를 이용한 이미지 변환 및 저장
pil_image = Image.open(image.file).convert("RGB") # `webp` 포함 모든 파일을 RGB로 변환
pil_image.save(image_path, format=file_extension[1:].upper()) # 확장자에 맞는 형식으로 저장
# 이미지 저장
pil_image.save(image_path, format=save_format)
print(f"Image saved successfully at {image_path}") # 저장 성공 로그 추가

# 저장된 이미지 URL 생성
image_url = f"/static/uploads/{image_filename}"
except Exception as e:
print(f"Failed to save image: {str(e)}") # 에러 로그 출력
raise HTTPException(status_code=500, detail=f"Failed to process the image: {str(e)}")

# 가장 높은 클래스 인덱스 및 이름 가져오기
class_mapping = {
0: "A1 구진/플라크",
1: "A2 비듬/각질/상피성잔고리",
2: "A3 태선화 과다색소침착",
3: "A4 농포/여드름",
4: "A5 미란/궤양",
5: "A6 결정/종괴",
6: "A7 무증상"
}
predicted_class_index = torch.argmax(torch.tensor(probabilities)).item() # 가장 높은 인덱스
predicted_class_label = class_mapping[predicted_class_index]
# 저장된 이미지 URL 생성
image_url = f"/static/uploads/{image_filename}"

# 클래스 라벨 반환
predicted_class_label, predicted_class_index = get_predicted_class_label(probabilities)

# 모델 이름 (예: EfficientNet)
model_name = "MultimodalModel-EfficientNetB0"
Expand Down Expand Up @@ -436,3 +424,35 @@ async def get_result_detail_by_analysis_id(
status_code=500,
detail=f"Failed to retrieve AIResult: {str(e)}"
)


def get_predicted_class_label(probabilities):
"""
확률 값을 바탕으로 예측된 클래스 라벨을 반환합니다.
"""
class_mapping = {
0: "A1 구진/플라크",
1: "A2 비듬/각질/상피성잔고리",
2: "A3 태선화 과다색소침착",
3: "A4 농포/여드름",
4: "A5 미란/궤양",
5: "A6 결정/종괴",
6: "A7 무증상"
}
predicted_class_index = torch.argmax(torch.tensor(probabilities)).item()
return class_mapping[predicted_class_index], predicted_class_index


def preprocess_pet_info(pet_info: PetDetailResponse):
"""
펫 정보를 받아 DataFrame으로 변환하고 전처리를 수행합니다.
"""
df = pd.DataFrame([{
"breed": pet_info.breed,
"gender": pet_info.gender,
"lesions": None, # 기본값 설정
"age": pet_info.age
}])
processed_features = preprocess_dataframe(df)
additional_features_tensor = torch.tensor(processed_features.values, dtype=torch.float32)
return additional_features_tensor

0 comments on commit a809802

Please sign in to comment.