-
Notifications
You must be signed in to change notification settings - Fork 230
/
densenet_service.py
123 lines (95 loc) · 3.55 KB
/
densenet_service.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import io
import json
import numpy as np
from PIL import Image
import torch
from torch.autograd import Variable
from torchvision import transforms
import torch.nn.functional as F
class PyTorchImageClassifier:
    """
    PyTorchImageClassifier service class. This service takes a flower
    image and returns the name of that flower.

    Call ``initialize(context)`` once to load the checkpoint and the
    index-to-name mapping, then run each request through
    ``preprocess`` -> ``inference`` -> ``postprocess``.
    """

    def __init__(self):
        # Path of the loaded checkpoint; filled in by initialize().
        self.checkpoint_file_path = None
        # The loaded torch model; filled in by initialize().
        self.model = None
        # Parsed index_to_name.json; inference() looks names up by str(index).
        self.mapping = None
        # Device that the model and the inference inputs are moved to.
        self.device = "cpu"
        self.initialized = False

    @staticmethod
    def _load_checkpoint(path):
        """
        Load the checkpoint dict stored at *path*, mapping tensors to the CPU.

        The checkpoint stores a whole ``nn.Module`` under ``'model'``, which
        requires full unpickling. PyTorch >= 2.6 defaults ``torch.load`` to
        ``weights_only=True`` and would reject it, so opt out explicitly.

        SECURITY: unpickling can execute arbitrary code -- only load model
        archives that come from a trusted source.
        """
        try:
            return torch.load(path, map_location='cpu', weights_only=False)
        except TypeError:
            # torch < 1.13 has no ``weights_only`` parameter.
            return torch.load(path, map_location='cpu')

    def initialize(self, context):
        """
        Load the model and mapping file to perform inference.

        :param context: serving context whose
            ``system_properties["model_dir"]`` names the directory holding
            ``model.pth`` and ``index_to_name.json``.
        :raises RuntimeError: if either of those files is missing.
        """
        model_dir = context.system_properties.get("model_dir")

        # Read checkpoint file
        checkpoint_file_path = os.path.join(model_dir, "model.pth")
        if not os.path.isfile(checkpoint_file_path):
            raise RuntimeError("Missing model.pth file.")
        self.checkpoint_file_path = checkpoint_file_path

        # Prepare the model: rebuild it from the checkpoint pieces, freeze
        # all parameters and switch to eval mode once, up front.
        checkpoint = self._load_checkpoint(checkpoint_file_path)
        model = checkpoint['model']
        model.classifier = checkpoint['classifier']
        model.load_state_dict(checkpoint['state_dict'])
        model.class_to_idx = checkpoint['class_to_idx']
        for param in model.parameters():
            param.requires_grad = False
        # Keep the model on the same device that inference() sends inputs to.
        model.to(self.device)
        model.eval()
        self.model = model

        # Read the mapping file, index to flower
        mapping_file_path = os.path.join(model_dir, "index_to_name.json")
        if not os.path.isfile(mapping_file_path):
            raise RuntimeError("Missing the mapping file")
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(mapping_file_path, encoding="utf-8") as f:
            self.mapping = json.load(f)

        self.initialized = True

    def preprocess(self, data):
        """
        Decode, scale, center-crop and normalize the request image.

        :param data: request batch; only ``data[0]`` is used. Its raw image
            bytes are read from the ``"data"`` key, falling back to ``"body"``.
        :returns: a float tensor of shape ``(3, 224, 224)`` (not a NumPy array).
        """
        image = data[0].get("data")
        if image is None:
            image = data[0].get("body")

        # Mean/std look like the standard ImageNet channel statistics;
        # presumably what the checkpoint was trained with -- confirm.
        my_preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        # Force 3 channels: grayscale, palette or RGBA images would otherwise
        # not match the 3-channel Normalize statistics above.
        image = Image.open(io.BytesIO(image)).convert("RGB")
        return my_preprocess(image)

    def inference(self, img, topk=5):
        """
        Predict the most likely classes of one preprocessed image.

        :param img: a single preprocessed image without a batch dimension
            (a tensor as returned by ``preprocess``, or an array-like).
        :param topk: maximum number of classes to return; clamped to the
            number of classes the model outputs.
        :returns: a one-element list (one entry per request) holding a list
            of ``{class_name: probability}`` dicts, most probable first.
        """
        # Add the leading batch dimension the model expects.
        inputs = torch.as_tensor(img).unsqueeze(0).to(self.device)
        # Also done in initialize(); repeated in case self.model was assigned
        # directly.
        self.model.eval()
        with torch.no_grad():
            logits = self.model(inputs)
        probs = F.softmax(logits, dim=1)

        k = min(topk, probs.shape[1])
        top_probs, top_classes = probs.cpu().topk(k, dim=1)
        # Take row 0 explicitly instead of squeeze(): squeeze() collapses the
        # result to a scalar when k == 1, which breaks iteration.
        results = [
            {self.mapping[str(class_idx)]: prob}
            for prob, class_idx in zip(top_probs[0].tolist(),
                                       top_classes[0].tolist())
        ]
        return [results]

    def postprocess(self, inference_output):
        """Return the inference output unchanged."""
        return inference_output
# Module-level service instance and entry point. Neither is needed when the
# service class itself defines `handle(self, data, context)`.
_service = PyTorchImageClassifier()


def handle(data, context):
    """
    Serve one request through the shared service instance.

    The service is initialized lazily on the first call. When ``data`` is
    None the call returns None; otherwise the request is run through
    preprocess -> inference -> postprocess and the result is returned.
    """
    service = _service
    if not service.initialized:
        service.initialize(context)

    if data is None:
        return None

    return service.postprocess(
        service.inference(
            service.preprocess(data)
        )
    )