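# Flask front-end for semantic search over two Qdrant collections:
# "jks" (a single unnamed vector per point) and "tils" (named "title" / "context" vectors),
# with queries embedded by the intfloat/e5-base-v2 model.
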
from flask import Flask, render_template, request, jsonify
from qdrant_client import QdrantClient, models
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
import os, time

app = Flask(__name__)

# Initialize Qdrant Client and other required settings
qdrant_api_key = os.environ.get("qdrant_api_key")
qdrant_url = os.environ.get("qdrant_url")

client = QdrantClient(url=qdrant_url, port=443, api_key=qdrant_api_key, prefer_grpc=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    # Mean-pool token embeddings, zeroing out padding positions via the attention mask.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-base-v2')
model = AutoModel.from_pretrained('intfloat/e5-base-v2').to(device)

def e5embed(query):
    # Embed a query string with e5-base-v2: tokenize, mean-pool, then L2-normalize.
    batch_dict = tokenizer(query, max_length=512, padding=True, truncation=True, return_tensors='pt')
    batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
    with torch.no_grad():
        outputs = model(**batch_dict)
    embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
    embeddings = F.normalize(embeddings, p=2, dim=1)
    return embeddings.cpu().numpy().flatten().tolist()
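
# The indexing side is not part of this file; the /search and /delete_joke routes assume
# points were upserted with integer ids and a payload carrying 'text' and 'date' fields.
# A minimal, hypothetical sketch of such an upsert for the "jks" collection:
#
#   from qdrant_client.models import PointStruct
#   client.upsert(
#       collection_name="jks",
#       points=[PointStruct(id=1, vector=e5embed("example joke text"),
#                           payload={"text": "example joke text", "date": 20230101})],
#   )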

@app.route("/")
def index():
    return render_template("index.html")

@app.route("/search", methods=["POST"])
def search():
    query = request.form["query"]
    topN = 200  # maximum number of results returned for the "jks" collection

    print('QUERY:', query)
    if query.strip().startswith('tilc:'):
        collection_name = 'tils'
        qvector = "context"
        query = query.replace('tilc:', '')
    elif query.strip().startswith('til:'):
        collection_name = 'tils'
        qvector = "title"
        query = query.replace('til:', '')
    else:
        collection_name = 'jks'

    timh = time.time()
    sq = e5embed(query)
    print('EMBEDDING TIME:', time.time() - timh)

    timh = time.time()
    if collection_name == "jks":
        results = client.search(collection_name=collection_name, query_vector=sq, with_payload=True, limit=topN)
    else:
        # The 'tils' collection stores named vectors, so search against the selected one.
        results = client.search(collection_name=collection_name, query_vector=(qvector, sq), with_payload=True, limit=100)
    print('SEARCH TIME:', time.time() - timh)
    
    if results:
        print(results[0].payload['text'].split('\n'))
    try:
        results = [{"text": x.payload['text'], "date": str(int(x.payload['date'])), "id": x.id} for x in results]
        return jsonify(results)
    except Exception:
        # Fall back to an empty result set if a payload is missing the expected fields.
        return jsonify([])
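
# Example request (the form field "query" may carry an optional "til:" / "tilc:" prefix):
#   curl -X POST -d "query=til: flask routing" http://localhost:7860/search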

@app.route("/delete_joke", methods=["POST"])
def delete_joke():
    joke_id = request.form["id"]
    print('Deleting joke no', joke_id)
    client.delete(collection_name="jks", points_selector=models.PointIdsList(points=[int(joke_id)]))
    return jsonify({"deleted": True})

if __name__ == "__main__":
    app.run(host="0.0.0.0", debug=True, port=7860)