main.py
import shutil
from datetime import datetime
from io import BytesIO

import boto3
import httpx
import jwt
import torch
import torchvision
from fastapi import FastAPI, File, Form, Header, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, ImageOps

from mask import mask
from predictor import predictor
from store_NFTStorage import store_NFTStorage

# Start the server with: uvicorn main:app --host 0.0.0.0 --port 8000 --reload &
app = FastAPI()
origins = [
    "http://localhost:3000",
    "http://3.133.233.81:3000",
    "https://localhost:3000",
    "https://3.133.233.81:3000",
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

#####################################################################
# ---------------------------- Configs ---------------------------- #
#####################################################################
categories = ['TOP', 'BTM', 'RIN']
# categories = ['DR', 'TOP', 'BTM', 'HEA', 'BRA', 'NEC', 'BAG', 'MAS', 'RIN']
samples_per_categories = {
    'DR': 'sphere', 'TOP': 'sphere', 'BTM': 'sphere', 'HEA': 'sphere',
    'NEC': 'torus', 'BAG': 'torus', 'MAS': 'sphere', 'RIN': 'torus',
}
args_per_categories = {
    'TOP': {
        'azi_scope': 360,
        'elev_range': '0~30',
        'dist_range': '5~7',
        'flip_dim': 2,
    },
    'BTM': {
        'azi_scope': 360,
        'elev_range': '0~30',
        'dist_range': '5~7',
        'flip_dim': 2,
    },
    'RIN': {
        'azi_scope': 360,
        'elev_range': '0~30',
        'dist_range': '4~7',
        'flip_dim': 3,
    },
}
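# A rough reading of these arguments (an assumption; they are not documented
# in this file): 'azi_scope' is the azimuth range in degrees sampled around
# the object, 'elev_range' and 'dist_range' bound the camera elevation
# (degrees) and distance, and 'flip_dim' is the tensor dimension flipped when
# the predictor exploits the category's symmetry (see its use in the
# endpoints below).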
image_size = 256
predictor_models = {}
diffRenderers = {}
origin = "/home/ec2-user/Matilda_Learning"

# Load the attribute-prediction model for each category
for category in categories:
    predictor_model_path = f'{origin}/predictor/network/models/{category}.pth'
    init_mesh_path = f"{origin}/predictor/samples/{samples_per_categories[category]}.obj"
    predictor_model, diffRenderer = predictor.get_predictor_model(
        init_mesh_path, predictor_model_path, image_size, args_per_categories[category])
    predictor_models[category] = predictor_model
    if samples_per_categories[category] not in diffRenderers:
        diffRenderers[samples_per_categories[category]] = diffRenderer

# Load the mask model (trained weights are read from this path)
model_path = f"{origin}/mask/DIS/saved_models/isnet.pth"
mask_model = mask.get_mask_model(model_path)
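# Note (assumption, inferred from the path only): isnet.pth appears to be an
# IS-Net checkpoint from the DIS (Dichotomous Image Segmentation) project,
# used here to separate the garment from its background before prediction.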
#####################################################################
# ----------------------------- Utils ----------------------------- #
#####################################################################
# Save file info into the repository via the WAS (web application server)
URL = "http://3.133.233.81:8080"


def save_fileinfo_into_repository(title: str, catCode: str, imgUrl: str, objectUrl: str, token: str) -> dict:
    # The JWT signature is deliberately not verified here; the token is only
    # parsed to recover the member number
    memberNum = jwt.decode(token, options={"verify_signature": False})['num']
    body = {
        "title": title,
        "catCode": catCode,
        "imgUrl": imgUrl,
        "memberNum": memberNum,
        "objectUrl": objectUrl,
    }
    response = httpx.post(URL + '/items/new', json=body)
    return response.json()


def get_fileinfo_from_repository(num: int, token: str) -> str:
    response = httpx.get(URL + '/objects/auth/objUrl/' + str(num),
                         headers={'X-AUTH-TOKEN': token})
    return response.text
bucketMatilda = boto3.resource('s3').Bucket('matilda.image-storage')


def save_file_into_S3(localfilePath: str, targetfilePath: str):
    bucketMatilda.upload_file(localfilePath, targetfilePath)
    return True


def get_file_from_S3(filePath: str) -> BytesIO:
    fileData = BytesIO()
    bucketMatilda.download_fileobj(filePath, fileData)
    fileData.seek(0)
    return fileData
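# Usage sketch (hypothetical paths; assumes valid AWS credentials):
#   save_file_into_S3('./temp/123/model.glb', 'items/obj/TOP/123_model.glb')
#   data = get_file_from_S3('items/obj/TOP/123_model.glb')  # -> BytesIO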
def load_into_tensor_and_resize(data, resolution, mask_model):
    img = Image.open(BytesIO(data)).convert('RGB')
    # Pad the image to a square, then resize to the target resolution
    target_height, target_width = resolution, resolution
    W, H = img.size
    desired_size = max(W, H)
    delta_w = desired_size - W
    delta_h = desired_size - H
    padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
    img = ImageOps.expand(img, padding)
    img = img.resize((target_height, target_width))
    img = torchvision.transforms.functional.to_tensor(img).cuda()
    # Composite the masked object onto a white background
    img_mask = mask.get_mask_from_image(mask_model, img)
    img = img * img_mask + torch.ones_like(img) * (1 - img_mask)
    return img
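# Usage sketch ('sample.jpg' is a hypothetical local file):
#   with open('sample.jpg', 'rb') as f:
#       tensor = load_into_tensor_and_resize(f.read(), image_size, mask_model)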
#####################################################################
# ----------------------------- APIs ------------------------------ #
#####################################################################
@app.get("/")
async def root():
    return {"message": "Welcome"}
@app.post("/convert")
async def convert(file: UploadFile = File(...), category: str = Form(...), X_AUTH_TOKEN: str = Header()):
# 파일이 주어졌는지 확인
if file is None:
return {"message": "file is not found"}
# 카테고리가 유효한지 확인
if category not in categories:
return {"message": "category not found"}
# 파일 이름 추출
title = '.'.join(file.filename.split('.')[:-1])
if len(title) > 45:
title = title[0:45]
image = load_into_tensor_and_resize(await file.read(),image_size, mask_model) # image 사이즈 조절 및 tensor로 변환
predictor = predictor_models[category] # category에 해당하는 3D 속성 예측 모델 불러오기
dib_r = diffRenderers[samples_per_categories[category]] # category에 해당하는 3D Renderer 불러오기
attributes = predictor(image.unsqueeze(0), args_per_categories[category]['flip_dim'])
# 파일 이름에 사용 할 시간 정보
now = str(int(datetime.now().timestamp()))
# 파일이 로컬에 임시로 저장될 위치
save_path = './temp/' + now + '/'
# 3D Object 생성 - 생성된 mesh, texture, lights를 통해 3D 파일(.glb) 추출하기
obj_save_path, img_save_path = dib_r.save_object(attributes, category, save_path)
# 파일이 S3에 저장될 위치
objPath = 'items/obj/' + category + '/' + now + '_' + title + '.glb'
imgPath = 'items/img/' + category + '/' + now + '_' + title + '.jpg'
# S3에 파일 저장
save_file_into_S3(obj_save_path, objPath)
save_file_into_S3(img_save_path, imgPath)
# 로컬 파일 삭제
shutil.rmtree(save_path)
# WAS로 saveUrl 전달
response = save_fileinfo_into_repository(
title, category, imgPath, objPath, X_AUTH_TOKEN)
return response
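# Example request (sketch; host, token, and file name are placeholders):
#   curl -X POST http://localhost:8000/convert \
#        -H "X-AUTH-TOKEN: <jwt>" \
#        -F "file=@shirt.jpg" -F "category=TOP"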
@app.post("/convert/twoimg")
async def convert_by_two_imgs(file1: UploadFile = File(...), file2: UploadFile = File(...), category: str = Form(...), X_AUTH_TOKEN: str = Header()):
# 파일이 주어졌는지 확인
if file1 is None or file2 is None:
return {"message": "file is not found"}
# 카테고리가 유효한지 확인
if category not in categories:
return {"message": "category not found"}
# 파일 이름 추출
title = '.'.join(file1.filename.split('.')[:-1])
if len(title) > 45:
title = title[0:45]
image1 = load_into_tensor_and_resize(await file1.read(), image_size, mask_model) # image 사이즈 조절 및 tensor로 변환
image2 = load_into_tensor_and_resize(await file2.read(), image_size, mask_model) # image 사이즈 조절 및 tensor로 변환
images = torch.cat([image1.unsqueeze(0),image2.unsqueeze(0)], dim=0)
predictor = predictor_models[category] # category에 해당하는 3D 속성 예측 모델 불러오기
dib_r = diffRenderers[samples_per_categories[category]] # category에 해당하는 3D Renderer 불러오기
attributes = predictor(images, args_per_categories[category]['flip_dim'])
vert_alpha = 0.7
attributes['vertices'] = (attributes['vertices'][0] * vert_alpha + attributes['vertices'][1] * (1-vert_alpha)).unsqueeze(0)
attributes['lights'] = attributes['lights'][0].unsqueeze(0)
tex_front = attributes['textures'][0]
# 앞,뒤 자연스럽게 이어지는 텍스처 생성
tex_back = attributes['textures'][1].flip([2])
textures = torch.cat([tex_front[:,:image_size], tex_back[:,image_size:]], dim=1)
smmoth_len = 32
for i in range(smmoth_len):
idx = image_size+i
alpha = i/smmoth_len
textures[:,idx] = tex_front[:,idx]*(1-alpha) + tex_back[:,idx] * alpha
attributes['textures'] = textures.unsqueeze(0)
attributes['distances'] = attributes['distances'][0].unsqueeze(0)
attributes['elevations'] = attributes['elevations'][0].unsqueeze(0)
attributes['azimuths'] = attributes['azimuths'][0].unsqueeze(0)
# 파일 이름에 사용 할 시간 정보
now = str(int(datetime.now().timestamp()))
# 파일이 로컬에 임시로 저장될 위치
save_path = './temp/' + now + '/'
# 3D Object 생성 - 생성된 mesh, texture, lights를 통해 3D 파일(.glb) 추출하기
obj_save_path, img_save_path = dib_r.save_object(attributes, category, save_path)
# 파일이 S3에 저장될 위치
objPath = 'items/obj/' + category + '/' + now + '_' + title + '.glb'
imgPath = 'items/img/' + category + '/' + now + '_' + title + '.jpg'
# S3에 파일 저장
save_file_into_S3(obj_save_path, objPath)
save_file_into_S3(img_save_path, imgPath)
# 로컬 파일 삭제
shutil.rmtree(save_path)
# WAS로 saveUrl 전달
response = save_fileinfo_into_repository(
title, category, imgPath, objPath, X_AUTH_TOKEN)
return response
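# Example request for the two-image variant (sketch; placeholders as above):
#   curl -X POST http://localhost:8000/convert/twoimg \
#        -H "X-AUTH-TOKEN: <jwt>" \
#        -F "file1=@front.jpg" -F "file2=@back.jpg" -F "category=TOP"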
@app.post("/getCID")
async def getCID(num: int = Form(...), X_AUTH_TOKEN: str = Header()):
# 유효 확인
# WAS를 통해 파일 정보 호출
filePath = get_fileinfo_from_repository(num, X_AUTH_TOKEN)
# S3로부터 파일 다운로드
fileData = get_file_from_S3(filePath)
# NFT.Storage에 파일 저장, CID 획득
cid = store_NFTStorage(fileData)
# FE로 cid 정보 반환
return cid
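# Example request (sketch; num refers to an existing item, token is a placeholder):
#   curl -X POST http://localhost:8000/getCID \
#        -H "X-AUTH-TOKEN: <jwt>" -F "num=1"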