-
Notifications
You must be signed in to change notification settings - Fork 0
/
rec_server.py
53 lines (43 loc) · 1.54 KB
/
rec_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os
import re
import numpy as np
from scipy.sparse import coo_matrix
from recoder.model import Recoder
from recoder.nn import DynamicAutoencoder
from recoder.data import UsersInteractions
from flask import request, Flask, jsonify, abort
from flask_cors import CORS
from typing import Dict
# Directory that load_models() scans for "<dataset>.model" files, taken from
# the environment; None when the variable is not set.
MODELS_DIR = os.environ.get('MODELS_DIR')

# The Flask application, with flask_cors applied using its default settings.
app = Flask(import_name=__name__)
CORS(app)
def load_models() -> Dict[str, Recoder]:
    """Load a Recoder for every ``<dataset>.model`` file in ``MODELS_DIR``.

    Files are visited in sorted name order so the load order is
    deterministic. Each matching file gets a fresh ``DynamicAutoencoder``
    wrapped in a ``Recoder`` and initialised via ``init_from_model_file``.

    Returns:
        Mapping from dataset name (file name minus the ``.model`` suffix)
        to its initialised ``Recoder``.

    Raises:
        RuntimeError: if ``MODELS_DIR`` is unset or empty.
        OSError: propagated from ``os.listdir`` if the directory cannot be read.
    """
    if not MODELS_DIR:
        # os.listdir(None) would silently list the working directory, and
        # os.path.join(None, ...) would then fail with an opaque TypeError.
        raise RuntimeError('MODELS_DIR environment variable is not set')

    # '.+' rather than '.*' so a bare '.model' file cannot yield an empty name.
    model_re = re.compile(r'^(?P<ds>.+)\.model$')
    recoders = {}
    for filename in sorted(os.listdir(MODELS_DIR)):
        match = model_re.match(filename)
        if match is None:
            continue
        recoder = Recoder(DynamicAutoencoder())
        recoder.init_from_model_file(os.path.join(MODELS_DIR, filename))
        recoders[match.group('ds')] = recoder
    return recoders


# Loaded once at import time; the request handler looks models up here by
# dataset name.
models = load_models()
def percentile(values: np.ndarray):
    """Return each entry's fractional rank within ``values``.

    Entry ``i`` maps to ``rank_i / len(values)``, where ``rank_i`` is its
    0-based position in ascending sort order. The smallest value therefore
    becomes 0.0 and the largest ``(len(values) - 1) / len(values)``.

    Ties are broken by ``argsort``'s default (non-stable) ordering, so equal
    inputs may receive different ranks. The function is intended for 1-D
    input: ``argsort`` works along the last axis, while ``len`` counts the
    first.
    """
    # The argsort of a sorting permutation is its inverse, i.e. each
    # element's position in sorted order.
    sort_order = np.argsort(values)
    position = np.argsort(sort_order)
    total = len(values)
    return position / total
@app.route('/', methods=['POST'])
def create_task():
    """Return percentile-ranked model scores for one user's interactions.

    Expects a JSON object body with:
      * ``ds``   - name of a dataset present in ``models``;
      * ``idxs`` - list of integer item indices the user interacted with.

    Responds ``201`` with ``{"recs": [...]}``: the model's scores for the
    user, passed through ``percentile``.

    Aborts with ``400`` in these cases:
      * the body is missing, not JSON, or not an object;
      * either key is absent;
      * ``ds`` is not a string;
      * ``idxs`` is not a list of integers within ``[0, num_items)``.

    Aborts with ``404`` when ``ds`` names no loaded model.
    """
    # silent=True returns None instead of raising for a missing or malformed
    # JSON body, so every bad payload takes the 400 path below.
    payload = request.get_json(silent=True)
    if (not isinstance(payload, dict)
            or 'ds' not in payload
            or 'idxs' not in payload):
        abort(400)

    ds = payload['ds']
    if not isinstance(ds, str):
        abort(400, description="'ds' must be a string")
    model = models.get(ds)
    if model is None:
        # An unknown dataset is a client error, not a server fault.
        abort(404, description=f'unknown dataset {ds!r}')

    raw_idxs = payload['idxs']
    if not isinstance(raw_idxs, list):
        abort(400, description="'idxs' must be a list of item indices")
    try:
        idxs = np.array([int(x) for x in raw_idxs], dtype=np.int64)
    except (TypeError, ValueError, OverflowError):
        abort(400, description="'idxs' must contain only integers")

    # Each listed item is recorded with value 1. Repeated indices are
    # dropped, because the COO->CSR conversion would sum them into larger
    # values. np.unique also sorts, so the bounds check needs only the ends.
    idxs = np.unique(idxs)
    if idxs.size and (idxs[0] < 0 or idxs[-1] >= model.num_items):
        abort(400, description=f"'idxs' must lie in [0, {model.num_items})")

    n = idxs.size
    interactions = coo_matrix(
        (np.ones(n), (np.zeros(n, dtype=np.int64), idxs)),
        shape=(1, model.num_items),
    ).tocsr()
    # A single-row interaction matrix; the user is given row id 0.
    users = UsersInteractions(np.array([0]), interactions)

    # NOTE(review): predict()[0] is assumed to be a (1, num_items) torch
    # tensor of scores; confirm against recoder. .cpu() is a no-op for CPU
    # tensors and lets .numpy() succeed if the model lives on a GPU.
    scores = model.predict(users)[0].detach().cpu().squeeze(0).numpy()
    recs = percentile(scores)
    return jsonify({'recs': recs.tolist()}), 201