Skip to content

Instantly share code, notes, and snippets.

@sslotin
Created August 25, 2019 15:23
Show Gist options
  • Select an option

  • Save sslotin/3a431c367726e1f1c022fbee9eb6c212 to your computer and use it in GitHub Desktop.

Select an option

Save sslotin/3a431c367726e1f1c022fbee9eb6c212 to your computer and use it in GitHub Desktop.
from argparse import ArgumentParser
from flask import Flask, jsonify, request
import numpy as np
import embedlib
from flask_cors import CORS

app = Flask(__name__)

# Cross-Origin Resource Sharing: Flask-CORS adds Access-Control-Allow-* headers
# to responses so that JavaScript served from a different origin (e.g. a
# separate web front-end) is allowed by the browser to call /classify and
# /rank. Called with no options it applies to every route and, per the
# Flask-CORS documentation, accepts any origin; pass origins=[...] to restrict.
CORS(app)
@app.route("/classify", methods=["GET"])
def classify():
    """Return the label of the stored question most similar to ``?text=``.

    Query parameters:
        text: the user's question (defaults to the empty string).

    Returns:
        JSON object ``{"label": ..., "confidence": ...}``. ``confidence`` is the
        dot product between the query embedding and the best-scoring stored
        question embedding, serialized as a string.

    Reads the module-level ``embed``, ``embeddings`` and ``labels`` that are set
    up in the ``__main__`` section before the server starts.
    """
    text = request.args.get("text", type=str, default="")
    vec = embed(text)
    # One similarity score per stored question (one per row of `embeddings`).
    scores = np.dot(embeddings, vec)
    best = int(scores.argmax())
    return jsonify({
        'label': labels[best],
        # Reuse the score that won the argmax instead of recomputing it.
        'confidence': str(scores[best]),
    })
@app.route("/rank", methods=["GET"])
def rank():
    """Return one answer per label, ordered by similarity to ``?text=``.

    Each label is represented by its highest-scoring example question; lower
    scoring examples of an already-listed label are skipped.

    Query parameters:
        text: the user's question (defaults to the empty string).

    Returns:
        JSON list of ``{"question", "answer", "confidence"}`` objects, best
        match first. ``confidence`` is the dot-product score as a string.

    Reads the module-level ``embed``, ``embeddings``, ``labels``, ``questions``
    and ``answers`` that are set up in the ``__main__`` section.
    """
    text = request.args.get("text", type=str, default="")
    vec = embed(text)
    # One similarity score per stored question; computed once and reused below.
    scores = np.dot(embeddings, vec)
    response = []
    # Labels come from YAML mapping keys, so they are hashable; a set keeps the
    # "already listed?" test O(1) per question instead of O(len(list)).
    seen_labels = set()
    for k in reversed(scores.argsort()):  # highest score first
        label = labels[k]
        if label in seen_labels:
            continue
        seen_labels.add(label)
        response.append({
            # str.capitalize() also lower-cases the rest of the string.
            'question': questions[k].capitalize(),
            # NOTE(review): raises KeyError (HTTP 500) if the answers file has
            # no entry for this label; confirm both data files stay in sync.
            'answer': answers[label],
            'confidence': str(scores[k]),
        })
    return jsonify(response)
if __name__ == "__main__":
    # PyYAML is only needed to load the data files at start-up; it must be
    # imported explicitly or yaml.safe_load below raises NameError.
    import yaml

    parser = ArgumentParser(description="Serve /classify and /rank over HTTP.")
    parser.add_argument('--host', default='0.0.0.0',
                        help="interface to bind (default: all interfaces)")
    parser.add_argument('--port', type=int, default=2000)
    parser.add_argument('--checkpoint', default='checkpoint')
    parser.add_argument('--questions', default='data/cerebra.questions')
    parser.add_argument('--answers', default='data/cerebra.answers')
    args = parser.parse_args()

    # Module-level globals read by the request handlers above.
    embed = embedlib.Embedder('bert-base-en', args.checkpoint)

    # answers: label -> answer; questions_dict: label -> list of example
    # questions. safe_load (not yaml.load) avoids constructing arbitrary objects.
    with open(args.answers, 'r', encoding='utf-8') as file:
        answers = yaml.safe_load(file)
    with open(args.questions, 'r', encoding='utf-8') as file:
        # An empty YAML document loads as None; treat it as "no questions".
        questions_dict = yaml.safe_load(file) or {}

    # Flatten into parallel sequences: row i of `embeddings` is the vector for
    # questions[i], whose label is labels[i].
    labels = []
    embeddings = []
    questions = []
    for key, examples in questions_dict.items():
        for example in examples:
            labels.append(key)
            questions.append(example)
            embeddings.append(embed(example))
    if not embeddings:
        # np.vstack([]) would fail with an opaque ValueError.
        parser.error("no questions found in " + args.questions)
    embeddings = np.vstack(embeddings)

    app.run(host=args.host, port=args.port)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment