embedding-haraj / app.py
ilhamsyahids's picture
normalized cosine sim to be between 0 and 1
be74b55
raw
history blame contribute delete
No virus
2.91 kB
import requests
import gradio as gr
from numpy import dot
from numpy.linalg import norm
from os import environ
# How many "Sentence N" textboxes the UI renders; every gr.Examples row must
# supply exactly this many sentences after the model and the input text.
num_of_sentences = 2

# Embedding-service connection settings, read once at import time from the
# environment. Each falls back to an empty string when the variable is unset.
endpoint, api_key = (environ.get(var, "") for var in ("ENDPOINT", "API_KEY"))
def cos_sim(a, b):
    """Return the cosine similarity of vectors ``a`` and ``b`` as a float.

    The result lies in [-1, 1] (up to floating-point rounding).

    If either vector has zero length, the angle between them is undefined.
    In that case 0.0 ("unrelated") is returned instead of letting the
    division produce ``nan`` and a RuntimeWarning.
    """
    denom = norm(a) * norm(b)
    if denom == 0:
        return 0.0
    return float(dot(a, b) / denom)
def get_embeddings(model, text, timeout=30):
    """Fetch the embedding vector for ``text`` from the remote service.

    Sends ``POST {endpoint}/{model}`` with JSON body ``{"text": text}``. The
    bearer token comes from the module-level ``api_key``. Returns the
    ``"vector"`` field of the JSON response.

    Args:
        model: Model name appended to the endpoint URL path (the UI offers
            "roberta" and "ada").
        text: Text to embed.
        timeout: Seconds to wait for the service before giving up. Without a
            timeout, ``requests`` can block indefinitely on a stalled server.

    Returns:
        The value of the response's ``"vector"`` key. Presumably this is a
        list of floats; confirm against the service's API.

    Raises:
        requests.HTTPError: If the service answers with a 4xx/5xx status.
            This replaces failing later with an opaque ``KeyError`` or
            JSON decode error.
        requests.Timeout: If the service does not answer within ``timeout``.
    """
    # Strip a trailing slash so ENDPOINT="https://host/" does not yield
    # "https://host//roberta".
    url = f"{endpoint.rstrip('/')}/{model}"
    response = requests.post(
        url,
        json={"text": text},
        headers={"Authorization": f"Bearer {api_key}"},
        timeout=timeout,
    )
    response.raise_for_status()
    return response.json()["vector"]
def calculate_similarities(model, text, *sentences):
    """Score each sentence by how similar its embedding is to that of ``text``.

    ``text`` and every sentence are embedded with ``model`` via
    ``get_embeddings``. Their cosine similarity (range [-1, 1]) is then mapped
    linearly onto [0, 1] so it can be displayed as a confidence by
    ``gr.Label``.

    Args:
        model: Embedding model name forwarded to ``get_embeddings``.
        text: Reference text that every sentence is compared against.
        *sentences: Sentences to score, one per UI textbox.

    Returns:
        Dict mapping each non-blank sentence to its normalized similarity in
        [0, 1], in input order. A repeated sentence appears once. Returns an
        empty dict, without calling the embedding service, when ``text`` is
        blank or every sentence is blank.
    """
    # Empty Gradio textboxes arrive as "". Embedding them wastes an API call
    # and yields a meaningless score, so drop them up front.
    candidates = [s for s in sentences if s and s.strip()]
    if not text or not text.strip() or not candidates:
        return {}

    text_embedding = get_embeddings(model, text)

    similarities = {}
    for sentence in candidates:
        if sentence in similarities:
            # An identical sentence has an identical score; skip the repeat
            # API call (the dict key would be overwritten with the same value).
            continue
        sentence_embedding = get_embeddings(model, sentence)
        sim = float(cos_sim(text_embedding, sentence_embedding))
        # Clamp floating-point overshoot (e.g. 1.0000000002) so the mapped
        # score never leaves [0, 1].
        sim = min(max(sim, -1.0), 1.0)
        # Linear map [-1, 1] -> [0, 1]: (sim - min) / (max - min) with
        # min=-1, max=1.
        similarities[sentence] = (sim + 1.0) / 2.0
    return similarities
# Build the Gradio UI. The left column holds a model selector, the input text
# and `num_of_sentences` comparison sentences. The right column shows the
# per-sentence similarity scores in a gr.Label.
demo = gr.Blocks()

with demo:
    with gr.Row():
        with gr.Column():
            # gr.inputs.* (and its `default=` argument) is the deprecated
            # legacy API and is gone in Gradio 4. gr.Radio with `value=`
            # works on Gradio 3.x and 4.x alike.
            model = gr.Radio(["roberta", "ada"], value="roberta", label="Model")
            text = gr.Textbox(lines=3, label="Input Text")
            inp_sentences = [
                gr.Textbox(lines=3, label=f"Sentence {i + 1}")
                for i in range(num_of_sentences)
            ]
            # A Button's caption is its `value` parameter; `text=` is not a
            # documented Button argument.
            btn = gr.Button(value="Submit")
        with gr.Column():
            output = gr.Label(label="Output", show_label=False)

    # One shared input list for the button and the examples. Its order must
    # match calculate_similarities(model, text, *sentences).
    inputs = [model, text, *inp_sentences]

    btn.click(calculate_similarities, inputs=inputs, outputs=output)

    # Each example row is [model, text, sentence_1, ..., sentence_N], so it
    # must have 2 + num_of_sentences entries.
    gr.Examples(
        examples=[
            ["roberta", "This is happy person", "هذا شخص سعيد", "هذه قطة سعيدة"],
            ["ada", "This is happy person", "هذا شخص سعيد", "هذه قطة سعيدة"],
            ["roberta", "هذا شخص سعيد", "هذه قطة سعيدة", "This is happy person"],
            ["ada", "هذا شخص سعيد", "هذه قطة سعيدة", "This is happy person"],
            ["roberta", "car", "camry", "toyota"],
            ["ada", "camry", "toy", "toyota"],
            ["roberta", "ihpone for sale", "iphone for sale", "camry for sale"],
            ["ada", "ihpone for sale", "iphone for sale", "camry for sale"],
        ],
        inputs=inputs,
        outputs=output,
        fn=calculate_similarities,
        # Examples are deliberately not cached. cache_examples=True would run
        # calculate_similarities on every row at startup, hitting the remote
        # embedding endpoint.
    )

if __name__ == "__main__":
    demo.launch()