-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
71 lines (52 loc) · 1.89 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import numpy as np
from fastai.text import Path, warnings, load_learner, torch, ClassificationInterpretation, TextClassificationInterpretation
from flask import Flask, jsonify, request
from flask_restful import Resource, Api, reqparse
from flask_cors import CORS
# Module setup: paths, Flask app, and the model artifacts loaded once at import.
#
# Class labels in the index order of the model's output vector -- the same
# index -> label mapping the Predict endpoint uses:
# classes = ['GCUP-EC-GC', 'GCs', 'GP', 'GS', 'GVOX']
path = Path(__file__).parent
models_dir = path / 'models'
# Create models/ if it is missing so the path exists; the artifacts loaded
# below must still be placed there, otherwise those loads fail.
models_dir.mkdir(parents=True, exist_ok=True)
# Silence every warning process-wide. Valid actions are "error", "ignore",
# "always", "default", "module" or "once".
warnings.filterwarnings('ignore')
app = Flask(__name__)
api = Api(app)
# Enable cross-origin requests using flask_cors's default settings.
CORS(app)
# Define parser and request args.
# NOTE(review): `parser` is never used in this file -- Predict.post reads
# request.get_json() directly. Consider removing it or using it to validate.
parser = reqparse.RequestParser()
parser.add_argument('input_text', type=str)
# Load the exported fastai learner from models/ (presumably fastai's default
# export filename -- confirm against the training/export script).
learn_c = load_learner(models_dir)
# Saved predictions, targets and losses; here they are only used to construct
# the interpretation objects below.
# NOTE(review): torch.load unpickles its input, so only ship trusted files in
# models/. If these tensors were saved from a GPU, loading on a CPU-only host
# may require map_location='cpu' -- confirm.
preds = torch.load(models_dir / 'preds.pt')
y = torch.load(models_dir / 'y.pt')
losses = torch.load(models_dir / 'losses.pt')
# NOTE(review): `ci` is never used in this file.
ci = ClassificationInterpretation(learn_c, preds, y, losses)
# Used by Predict.post to compute intrinsic_attention for a request's text.
txt_ci = TextClassificationInterpretation(learn_c, preds, y, losses)
class status(Resource):
    """Health-check resource, registered at ``/``.

    The lowercase class name is kept because the route registration refers
    to it by this name.
    """

    def get(self):
        """Report that the API process is up and serving requests.

        Returns:
            dict: ``{'data': 'Api is Running'}``, which flask_restful
            serializes to a JSON response.
        """
        # No try/except: returning a literal dict cannot raise, so an error
        # branch here would be unreachable (and a bare ``except:`` would also
        # swallow SystemExit/KeyboardInterrupt if anything were ever added).
        return {'data': 'Api is Running'}
class Predict(Resource):
    """Classify a text and explain the prediction with per-token attention.

    Registered at ``/predict``. Expects a JSON body ``{"input_text": "..."}``.
    """

    # Output labels in the index order of the model's prediction vector; this
    # is the same index -> label mapping the endpoint has always returned.
    LABELS = ('GCUP-EC-GC', 'GCs', 'GP', 'GS', 'GVOX')

    def post(self):
        """Run the classifier and the attention explainer on ``input_text``.

        Returns:
            On success, a JSON response with:

            * ``attention`` -- the scores returned by
              ``txt_ci.intrinsic_attention``, converted via
              ``np.array(...).tolist()``;
            * ``text`` -- the ``.text`` of the tokens returned with them;
            * ``preds`` -- a mapping from each label in ``LABELS`` to the
              matching entry of ``learn_c.predict(...)[2]`` multiplied by 100
              (presumably class probabilities as percentages -- confirm).

            If the body is not a JSON object holding a non-empty string
            ``input_text``, a ``({'message': ...}, 400)`` tuple instead of
            letting a KeyError/TypeError surface as a 500 error.
        """
        # force=True parses the body as JSON regardless of Content-Type.
        json_data = request.get_json(force=True)
        input_text = (
            json_data.get('input_text') if isinstance(json_data, dict) else None
        )
        if not isinstance(input_text, str) or not input_text.strip():
            return {'message': "'input_text' must be a non-empty string"}, 400

        scores = (learn_c.predict(input_text)[2] * 100).tolist()
        # NOTE(review): txt_ci wraps the same learner as learn_c. If
        # intrinsic_attention mutates model state (gradients, train/eval
        # mode), concurrent requests could interfere -- confirm.
        tokens, attention = txt_ci.intrinsic_attention(text=input_text)
        return jsonify(
            attention=np.array(attention).tolist(),
            text=tokens.text,
            # Index lookup (rather than zip) still fails loudly if the model
            # ever emits fewer outputs than there are labels.
            preds={label: scores[i] for i, label in enumerate(self.LABELS)},
        )
# Attach each resource to its URL rule, health check first.
for _resource, _url in ((status, '/'), (Predict, '/predict')):
    api.add_resource(_resource, _url)

if __name__ == '__main__':
    # Start Flask's built-in development server when run as a script.
    app.run()