import json
import os
from datetime import datetime

import dotenv
import lancedb
import requests
from datasets import load_dataset
from fasthtml.common import *  # noqa
from huggingface_hub import login, whoami

# Base URL of the remote service that exposes `/get_pages` and `/rerank`.
API_BASE_URL = "http://47.47.180.31:80"
# Seconds `requests` waits on the remote service before raising, so that a
# stalled server cannot hang a search request indefinitely.
API_TIMEOUT = 60


def get_images(query: str, timeout: float = API_TIMEOUT):
    """Fetch page images for *query* from the remote `/get_pages` endpoint.

    Args:
        query: Free-text search query, sent as the `query` URL parameter.
        timeout: Seconds to wait for the service before giving up.

    Returns:
        The decoded JSON body of the response.
        NOTE(review): presumably a list of base64-encoded images — confirm
        against the service.

    Raises:
        requests.HTTPError: If the service answers with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeouts.
    """
    response = requests.get(
        f"{API_BASE_URL}/get_pages",
        params={"query": query},
        timeout=timeout,
    )
    # Fail loudly on an error status instead of returning an error body that
    # callers would mistake for results.
    response.raise_for_status()
    return response.json()


def rerank_api(query, docs, timeout: float = API_TIMEOUT):
    """Rerank *docs* against *query* using the remote `/rerank` endpoint.

    Args:
        query: Free-text search query.
        docs: List of document strings to rank.
        timeout: Seconds to wait for the service before giving up.

    Returns:
        The decoded JSON body of the response.
        NOTE(review): expected to contain a `ranked_doc_ids` key holding
        indices into *docs* — confirm against the service.

    Raises:
        requests.HTTPError: If the service answers with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeouts.
    """
    payload = {"query": query, "docs": docs}
    # POST with a JSON body: the document list travels in the request body
    # rather than in the query string.
    response = requests.post(
        f"{API_BASE_URL}/rerank", json=payload, timeout=timeout
    )
    response.raise_for_status()
    return response.json()


# Hugging Face authentication and dataset loading ----
dotenv.load_dotenv()
login(token=os.environ.get("HF_TOKEN"))
hf_user = whoami(os.environ.get("HF_TOKEN"))["name"]
HF_REPO_ID_TXT = f"{hf_user}/zotero-answer-ai-texts"

abstract_ds = load_dataset(HF_REPO_ID_TXT, "abstracts")["train"]
article_ds = load_dataset(HF_REPO_ID_TXT, "articles")["train"]

# Build the LanceDB table used for full-text search ----
uri = "data/zotero-fts"
db = lancedb.connect(uri)

id2abstract = {example["arxiv_id"]: example["abstract"] for example in abstract_ds}
id2content = {example["arxiv_id"]: example["contents"] for example in article_ds}
id2title = {example["arxiv_id"]: example["title"] for example in article_ds}

arxiv_ids = set(id2abstract)
data = []
skipped_ids = []
# Sorted so that the row order of the table is deterministic across runs.
for arxiv_id in sorted(arxiv_ids):
    if arxiv_id not in id2title:
        # An abstract without a parsed article has neither title nor body to
        # index; skip it rather than aborting start-up with a KeyError.
        skipped_ids.append(arxiv_id)
        continue
    title = id2title[arxiv_id]
    sections = [
        f"{item['title']}\n\n{item['content']}" for item in id2content[arxiv_id]
    ]
    # Separate the title and each section with a blank line so words at the
    # boundaries are not fused into a single token in the FTS index.
    full_text = "\n\n".join([title, *sections])
    data.append(
        {
            "arxiv_id": arxiv_id,
            "title": title,
            "abstract": id2abstract[arxiv_id],
            "full_text": full_text,
        }
    )

if skipped_ids:
    print(f"Skipped {len(skipped_ids)} abstracts without a matching article")

table = db.create_table("articles", data=data, mode="overwrite")
table.create_fts_index("full_text", replace=True)


# format results ----
def _format_results(results):
    """Convert retrieved table rows into the dicts rendered by the UI.

    Args:
        results: Iterable of row dicts with `arxiv_id`, `title` and
            `abstract` keys.

    Returns:
        A list of dicts with `title`, `url` (the arXiv abstract page) and
        `abstract` keys, in the same order as *results*.
    """

    def _strip_heading(abstract):
        # Keep only the text after the last "Abstract\n\n" marker, if present.
        if "Abstract\n\n" in abstract:
            return abstract.split("Abstract\n\n")[-1]
        return abstract

    return [
        {
            "title": result["title"],
            "url": f"https://arxiv.org/abs/{result['arxiv_id']}",
            "abstract": _strip_heading(result["abstract"]),
        }
        for result in results
    ]


def retrieve_and_rerank(query, k=3, n_fetch=25):
    """Run a full-text search for *query* and return the top reranked hits.

    Args:
        query: Free-text search query.
        k: Number of results to return after reranking.
        n_fetch: Number of candidates fetched from the full-text index
            before reranking.

    Returns:
        A list of at most *k* dicts as produced by `_format_results`; empty
        when the full-text search matches nothing.
    """
    # retrieve ---
    retrieved = (
        table.search(query, vector_column_name="", query_type="fts")
        .limit(n_fetch)
        .select(["arxiv_id", "title", "abstract"])
        .to_list()
    )
    print(f"Retrieved {len(retrieved)} documents")

    if not retrieved:
        # Nothing to rank: skip the round-trip to the rerank service.
        return []

    # re-rank ---
    docs = [f"{item['title']} {item['abstract']}" for item in retrieved]
    # The service answers with indices into `docs`, best match first.
    ranked_doc_ids = rerank_api(query, docs)["ranked_doc_ids"][:k]
    return _format_results([retrieved[idx] for idx in ranked_doc_ids])


###########################################################################
# FastHTML app -----
###########################################################################

style = Style("""
:root {
    color-scheme: dark;
}
body {
    max-width: 1200px;
    margin: 0 auto;
    padding: 20px;
    line-height: 1.6;
}
#query {
    width: 100%;
    margin-bottom: 1rem;
}
#search-form button {
    width: 100%;
}
#search-results, #log-entries {
    margin-top: 2rem;
}
.log-entry {
    border: 1px solid #ccc;
    padding: 10px;
    margin-bottom: 10px;
}
.log-entry pre {
    white-space: pre-wrap;
    word-wrap: break-word;
}
.htmx-indicator {
    display: none;
}
.htmx-request .htmx-indicator {
    display: inline-block;
}
.spinner {
    display: inline-block;
    width: 2.5em;
    height: 2.5em;
    border: 0.3em solid rgba(255,255,255,.3);
    border-radius: 50%;
    border-top-color: #fff;
    animation: spin 1s ease-in-out infinite;
    margin-left: 10px;
    vertical-align: middle;
}
@keyframes spin {
    to { transform: rotate(360deg); }
}
.searching-text {
    font-size: 1.2em;
    font-weight: bold;
    color: #fff;
    margin-right: 10px;
    vertical-align: middle;
}
.image-results {
    display: flex;
    flex-wrap: wrap;
    gap: 10px;
    margin-top: 20px;
}
.image-result {
    width: calc(33% - 10px);
    text-align: center;
}
.image-result img {
    max-width: 100%;
    height: auto;
    border-radius: 5px;
}
""")

# get the fast app and route
app, rt = fast_app(hdrs=(style,))

# Initialize a database to store search logs --
# Create the directory first so opening the database file cannot fail on a
# fresh checkout.
os.makedirs("log_data", exist_ok=True)
# NOTE: this rebinds the module-level name `db`, which until here referred to
# the LanceDB connection; the `table` handle created from it is still used.
db = database("log_data/search_logs.db")
search_logs = db.t.search_logs
if search_logs not in db.t:
    search_logs.create(
        id=int,
        timestamp=str,
        query=str,
        results=str,
        pk="id",
    )
SearchLog = search_logs.dataclass()


def insert_log_entry(log_entry):
    """Insert a log entry into the database.

    Args:
        log_entry: Dict with a `timestamp` (object providing `isoformat()`),
            a `query` string and JSON-serializable `results`.

    Returns:
        Whatever `search_logs.insert` returns for the new row.
    """
    return search_logs.insert(
        SearchLog(
            timestamp=log_entry["timestamp"].isoformat(),
            query=log_entry["query"],
            results=json.dumps(log_entry["results"]),
        )
    )


@rt("/")
async def get():
    """Render the search page: query form, results placeholder, logs link."""
    query_form = Form(
        Textarea(id="query", name="query", placeholder="Enter your query..."),
        Button("Submit", type="submit"),
        Div(
            Span("Searching...", cls="searching-text htmx-indicator"),
            Span(cls="spinner htmx-indicator"),
            cls="indicator-container",
        ),
        id="search-form",
        hx_post="/search",
        hx_target="#search-results",
        hx_indicator=".indicator-container",
    )
    results_div = Div(Div(id="search-results", cls="results-container"))
    view_logs_link = A("View Logs", href="/logs", cls="view-logs-link")
    return Titled(
        "Zotero Search", Div(query_form, results_div, view_logs_link, cls="container")
    )


def SearchResult(result):
    """Custom component for displaying a search result.

    Args:
        result: Dict with `title`, `url` and `abstract` keys.
    """
    return Card(
        H4(A(result["title"], href=result["url"], target="_blank")),
        P(result["abstract"]),
        footer=A("Read more →", href=result["url"], target="_blank"),
    )


def ImageResult(image):
    """Component showing one retrieved page image.

    Args:
        image: Base64 payload embedded unchanged in a `data:image/jpeg` URI.
    """
    return Div(
        Img(src=f"data:image/jpeg;base64,{image}", alt="arxiv image"),
        cls="image-result",
    )


def log_query_and_results(query, results):
    """Store *query* and the title/url of each result in the search log.

    Args:
        query: The query string that was searched.
        results: List of dicts with at least `title` and `url` keys.
    """
    log_entry = {
        "timestamp": datetime.now(),
        "query": query,
        "results": [{"title": r["title"], "url": r["url"]} for r in results],
    }
    insert_log_entry(log_entry)


@rt("/search")
async def post(query: str):
    """Run the image and text searches for *query* and render both.

    Args:
        query: The query string submitted by the search form.

    Returns:
        A `Div` with the image results followed by the text results.
    """
    import asyncio

    # Both look-ups are blocking and independent of each other: run them in
    # worker threads, concurrently, so the event loop stays responsive.
    image_results, results = await asyncio.gather(
        asyncio.to_thread(get_images, query),
        asyncio.to_thread(retrieve_and_rerank, query),
    )
    log_query_and_results(query, results)
    return Div(
        Br(),
        H3("Byaldi Results"),
        Div(*[ImageResult(img) for img in image_results], cls="image-results"),
        Br(),
        H3("Text Results"),
        Div(*[SearchResult(r) for r in results], id="text-results"),
        id="search-results",
    )


def LogEntry(entry):
    """Component showing one stored search-log row.

    Args:
        entry: Row object with `query`, `timestamp` and `results` attributes;
            `results` is rendered verbatim inside a `Pre`.
    """
    return Div(
        H4(f"Query: {entry.query}"),
        P(f"Timestamp: {entry.timestamp}"),
        H5("Results:"),
        Pre(entry.results),
        cls="log-entry",
    )


# NOTE(review): this handler shares the name `get` with the `/` handler above
# and rebinds the module-level name; routing presumably relies on `rt`
# registering each handler at decoration time — confirm.
@rt("/logs")
async def get():
    """Render the logs page, requesting up to 50 rows ordered by `-id`."""
    logs = search_logs(order_by="-id", limit=50)  # Get the latest 50 logs
    log_entries = [LogEntry(log) for log in logs]
    return Titled(
        "Logs",
        Div(
            H2("Recent Search Logs"),
            Div(*log_entries, id="log-entries"),
            A("Back to Search", href="/", cls="back-link"),
            cls="container",
        ),
    )


if __name__ == "__main__":
    import uvicorn

    # The port comes from the PORT environment variable, defaulting to 7860.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))