-
Notifications
You must be signed in to change notification settings - Fork 8
/
classify_ingredients.py
47 lines (40 loc) · 1.09 KB
/
classify_ingredients.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import cv_model
import utils
import argparse
from imports import*
# Command-line entry point: classify the food and its ingredients in one image
# with an MLflow-packaged PyTorch model, then display and print the result.
#
# Usage (positional, same order as the original sys.argv interface):
#     python classify_ingredients.py INPUT_PATH [LOAD_MODEL_PATH] [DEVICE]
#
# An empty or omitted LOAD_MODEL_PATH falls back to DEFAULT_MODEL_PATH; an
# omitted DEVICE selects 'cuda' when available, otherwise 'cpu'.
DEFAULT_MODEL_PATH = 'mlflow_pretrained_models/ingredients101_model'

parser = argparse.ArgumentParser(
    description='Classify the food and its ingredients in an image.')
parser.add_argument('input_path', help='path to the image to classify')
parser.add_argument(
    'load_model_path', nargs='?', default='',
    help='MLflow model path; empty or omitted uses the bundled '
         'ingredients101 model')
parser.add_argument(
    'device', nargs='?', default=None,
    help="torch device string such as 'cpu' or 'cuda:0' "
         "(default: cuda if available, else cpu)")
# Kept as a dict so the module-level `args[...]` access style is unchanged.
args = vars(parser.parse_args())

print()
print('+------------------------------------+')
print('| Dream AI |')
print('+------------------------------------+')
print()

if args['device'] is None:
    # No device given: prefer the GPU but still work on CPU-only machines.
    args['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(args['device'])

input_path = args['input_path']
img = plt.imread(input_path)  # kept only for display alongside the prediction

# '' is accepted as "use the default" so callers can pass an empty model path
# and still supply DEVICE positionally.
load_model_path = args['load_model_path'] or DEFAULT_MODEL_PATH

net = mlflow.pytorch.load_model(load_model_path, map_location=device)
# NOTE(review): the model carries its own `device` attribute, presumably read
# inside net.classify; keep it in sync with where the weights are moved.
net.device = device
net = net.to(device)

batch = utils.get_test_input(paths=[input_path]).to(device)
food, ing = net.classify(batch, thresh=0.4)

# Only one image was submitted, so take the first ingredient result.
ing = ing[0]

# classify may return a bare string instead of a list; wrap it so join() does
# not put a comma between every character.
if isinstance(ing, str):
    ing = [ing]
if isinstance(food, str):
    food = [food]

food = ', '.join(food).title() if food else 'Unknown'
ing = ', '.join(ing).title() if ing else 'Unknown'

pred = f'Food Name: {food}\nFood Ingredients: {ing}'
plt.imshow(img)
plt.title(pred)
plt.show()
print(pred)
print()
del net
del batch