olamidegoriola's picture
perform vector search using FAISS (#7)
2373fa4
raw
history blame
No virus
2.28 kB
# search_content.py
import functools

import faiss
import pandas as pd
from sentence_transformers import SentenceTransformer
# Define paths for model, Faiss index, and data file.
# NOTE(review): all three are relative paths, so they resolve against the
# process's current working directory, not this file's directory — confirm the
# script is always launched from the project root.
# Sentence-transformer model handed to SentenceTransformer.load() by
# load_transformer_model. NOTE(review): the ".pkl" suffix suggests a pickle,
# but SentenceTransformer.load is normally given a model directory or model
# name — confirm what is actually stored at this path.
MODEL_SAVE_PATH = "all-distilroberta-v1-model.pkl"
# Faiss index read by load_faiss_index. search_content looks up the ids this
# index returns as "id" labels in the frame built by load_data (the CSV row
# positions), so the index must have been built from the same CSV row order.
FAISS_INDEX_FILE_PATH = "index.faiss"
# Q&A CSV read by load_data; it must provide "Questions" and "Answers" columns.
DATA_FILE_PATH = "omdena_qna_dataset/omdena_faq_training_data.csv"
def load_transformer_model(model_file):
    """Restore a sentence-transformer model from disk.

    Args:
        model_file: Location passed unchanged to ``SentenceTransformer.load``.

    Returns:
        The model object produced by ``SentenceTransformer.load``.
    """
    model = SentenceTransformer.load(model_file)
    return model
def load_faiss_index(filename):
    """Read a serialized Faiss index back into memory.

    Args:
        filename: Path of the index file, passed unchanged to
            ``faiss.read_index``.

    Returns:
        The index object produced by ``faiss.read_index``.
    """
    index = faiss.read_index(filename)
    return index
def load_data(file_path):
    """Load the Q&A CSV and add the columns used by the search.

    Args:
        file_path: Path (or file-like object) accepted by ``pandas.read_csv``.
            The CSV must contain ``Questions`` and ``Answers`` columns.

    Returns:
        The data frame with two extra columns, indexed by ``id`` while also
        keeping ``id`` as a regular column:

        * ``id``: the 0-based row position. search_content looks up Faiss
          result ids against these labels.
        * ``QNA``: the combined text ``"Question: <q>, Answer: <a>"``.
    """
    data_frame = pd.read_csv(file_path)
    # read_csv yields a default RangeIndex, so this is the 0-based row position.
    data_frame["id"] = data_frame.index
    # Build the combined text column-wise rather than with apply(axis=1). This
    # avoids materialising a Series per row, and it keeps each column's own
    # dtype: a row-wise apply upcasts all-numeric rows to float, which would
    # turn e.g. 1 into "1.0".
    data_frame["QNA"] = [
        f"Question: {question}, Answer: {answer}"
        for question, answer in zip(data_frame["Questions"], data_frame["Answers"])
    ]
    return data_frame.set_index(["id"], drop=False)
def search_content(query, data_frame_indexed, transformer_model, faiss_index, k=5):
    """Search the content using a query and return the top k results.

    Args:
        query: Free-text query to embed with ``transformer_model``.
        data_frame_indexed: Frame from ``load_data``, indexed by ``id`` labels
            that correspond to the ids stored in ``faiss_index``.
        transformer_model: Object with an ``encode(list_of_str)`` method (the
            model returned by ``load_transformer_model``).
        faiss_index: Faiss index to query.
        k: Maximum number of neighbours to return.

    Returns:
        A copy of the matching rows of ``data_frame_indexed``, in the order the
        index returns them, with an extra ``similarities`` column holding the
        index's scores. Fewer than ``k`` rows are returned when the index holds
        fewer than ``k`` vectors.

    Raises:
        KeyError: if the index returns an id that is not a label in
            ``data_frame_indexed``.
    """
    # Wrap the query in a list so encode() returns a one-row batch, which is
    # what Faiss search expects.
    query_vector = transformer_model.encode([query])
    # In-place L2 normalisation. NOTE(review): presumably the indexed vectors
    # were normalised the same way (cosine similarity via inner product);
    # confirm against the index-building code.
    faiss.normalize_L2(query_vector)
    scores, labels = faiss_index.search(query_vector, k)
    # Faiss fills unused slots with id -1 when the index has fewer than k
    # vectors. .loc[[-1]] would raise KeyError, so drop those slots while
    # keeping ids and scores paired.
    hits = [
        (int(doc_id), float(score))
        for doc_id, score in zip(labels[0], scores[0])
        if doc_id != -1
    ]
    ids = [doc_id for doc_id, _ in hits]
    similarities = [score for _, score in hits]
    # Explicit copy: the new column must not be written through to the caller's
    # frame, and pandas should not see this as chained assignment.
    results = data_frame_indexed.loc[ids].copy()
    results["similarities"] = similarities
    return results
@functools.lru_cache(maxsize=1)
def _load_search_resources():
    """Load the model, Faiss index and indexed data frame once, then reuse them.

    Without this cache every main_search call re-read all three from disk.
    Because of the cache, later changes to those files are not seen until the
    process restarts or ``_load_search_resources.cache_clear()`` is called.
    """
    transformer_model = load_transformer_model(MODEL_SAVE_PATH)
    faiss_index = load_faiss_index(FAISS_INDEX_FILE_PATH)
    data_frame_indexed = load_data(DATA_FILE_PATH)
    return transformer_model, faiss_index, data_frame_indexed


def main_search(query, k=5):
    """Main function to execute the search.

    Args:
        query: Free-text search string.
        k: Number of nearest neighbours to request. The default of 5 matches
            search_content's default.

    Returns:
        The ``QNA`` column of the matching rows, i.e. a pandas Series of
        ``"Question: ..., Answer: ..."`` strings in the order the index
        returned them.
    """
    transformer_model, faiss_index, data_frame_indexed = _load_search_resources()
    results = search_content(query, data_frame_indexed, transformer_model, faiss_index, k=k)
    return results["QNA"]
if __name__ == "__main__":
    # Demo: when executed directly, run one sample query and print the matches.
    query = "school courses"
    matches = main_search(query)
    print(matches)