rbiswasfc commited on
Commit
5fa685d
1 Parent(s): 22534d8
Files changed (4) hide show
  1. .gitignore +6 -0
  2. Dockerfile +13 -0
  3. app.py +277 -0
  4. requirements.txt +15 -0
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ data
2
+ .sesskey
3
+ __pycache__
4
+ *.pyc
5
+ .env
6
+ log_data
Dockerfile ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10
2
+
3
+ WORKDIR /code
4
+
5
+ COPY --link --chown=1000 . .
6
+
7
+ RUN mkdir -p /tmp/cache/
8
+ RUN chmod a+rwx -R /tmp/cache/
9
+ ENV HF_HUB_CACHE=HF_HOME
10
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
11
+
12
+ ENV PYTHONUNBUFFERED=1 PORT=7860
13
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from datetime import datetime
4
+
5
+ import dotenv
6
+ import lancedb
7
+ from datasets import load_dataset
8
+ from fasthtml.common import * # noqa
9
+ from huggingface_hub import login, whoami
10
+ from rerankers import Reranker
11
+
12
+ dotenv.load_dotenv()
13
+ login(token=os.environ.get("HF_TOKEN"))
14
+
15
+ hf_user = whoami(os.environ.get("HF_TOKEN"))["name"]
16
+ HF_REPO_ID_TXT = f"{hf_user}/zotero-answer-ai-texts"
17
+
18
+ abstract_ds = load_dataset(HF_REPO_ID_TXT, "abstracts")["train"]
19
+ article_ds = load_dataset(HF_REPO_ID_TXT, "articles")["train"]
20
+
21
+ ranker = Reranker("answerdotai/answerai-colbert-small-v1", model_type="colbert")
22
+
23
+
24
+ uri = "data/zotero-fts"
25
+ db = lancedb.connect(uri)
26
+
27
+ id2abstract = {example["arxiv_id"]: example["abstract"] for example in abstract_ds}
28
+ id2content = {example["arxiv_id"]: example["contents"] for example in article_ds}
29
+ id2title = {example["arxiv_id"]: example["title"] for example in article_ds}
30
+
31
+ arxiv_ids = set(list(id2abstract.keys()))
32
+
33
+ data = []
34
+ for arxiv_id in arxiv_ids:
35
+ abstract = id2abstract[arxiv_id]
36
+ title = id2title[arxiv_id]
37
+ full_text = title
38
+
39
+ for item in id2content[arxiv_id]:
40
+ full_text += f"{item['title']}\n\n{item['content']}"
41
+
42
+ data.append(
43
+ {
44
+ "arxiv_id": arxiv_id,
45
+ "title": title,
46
+ "abstract": abstract,
47
+ "full_text": full_text,
48
+ }
49
+ )
50
+
51
+
52
+ table = db.create_table("articles", data=data, mode="overwrite")
53
+
54
+ table.create_fts_index("full_text", replace=True)
55
+
56
+
57
+ # format results ----
58
+ def _format_results(results):
59
+ ret = []
60
+
61
+ for result in results:
62
+ arx_id = result["arxiv_id"]
63
+ title = result["title"]
64
+ abstract = result["abstract"]
65
+
66
+ if "Abstract\n\n" in abstract:
67
+ abstract = abstract.split("Abstract\n\n")[-1]
68
+
69
+ this_ex = {
70
+ "title": title,
71
+ "url": f"https://arxiv.org/abs/{arx_id}",
72
+ "abstract": abstract,
73
+ }
74
+
75
+ ret.append(this_ex)
76
+
77
+ return ret
78
+
79
+
80
+ def retrieve_and_rerank(query, k=5):
81
+ # retrieve ---
82
+ n_fetch = 25
83
+
84
+ retrieved = (
85
+ table.search(query, vector_column_name="", query_type="fts")
86
+ .limit(n_fetch)
87
+ .select(["arxiv_id", "title", "abstract"])
88
+ .to_list()
89
+ )
90
+
91
+ # re-rank
92
+ docs = [f"{item['title']} {item['abstract']}" for item in retrieved]
93
+ results = ranker.rank(query=query, docs=docs)
94
+
95
+ ranked_doc_ids = []
96
+ for result in results[:k]:
97
+ ranked_doc_ids.append(result.doc_id)
98
+
99
+ final_results = [retrieved[idx] for idx in ranked_doc_ids]
100
+ final_results = _format_results(final_results)
101
+ return final_results
102
+
103
+
104
+ ###########################################################################
105
+ # FastHTML app -----
106
+ ###########################################################################
107
+
108
+ style = Style("""
109
+ :root {
110
+ color-scheme: dark;
111
+ }
112
+ body {
113
+ max-width: 1200px;
114
+ margin: 0 auto;
115
+ padding: 20px;
116
+ line-height: 1.6;
117
+ }
118
+ #query {
119
+ width: 100%;
120
+ margin-bottom: 1rem;
121
+ }
122
+ #search-form button {
123
+ width: 100%;
124
+ }
125
+ #search-results, #log-entries {
126
+ margin-top: 2rem;
127
+ }
128
+ .log-entry {
129
+ border: 1px solid #ccc;
130
+ padding: 10px;
131
+ margin-bottom: 10px;
132
+ }
133
+ .log-entry pre {
134
+ white-space: pre-wrap;
135
+ word-wrap: break-word;
136
+ }
137
+ .htmx-indicator {
138
+ display: none;
139
+ }
140
+ .htmx-request .htmx-indicator {
141
+ display: inline-block;
142
+ }
143
+ .spinner {
144
+ display: inline-block;
145
+ width: 2.5em;
146
+ height: 2.5em;
147
+ border: 0.3em solid rgba(255,255,255,.3);
148
+ border-radius: 50%;
149
+ border-top-color: #fff;
150
+ animation: spin 1s ease-in-out infinite;
151
+ margin-left: 10px;
152
+ vertical-align: middle;
153
+ }
154
+ @keyframes spin {
155
+ to { transform: rotate(360deg); }
156
+ }
157
+ .searching-text {
158
+ font-size: 1.2em;
159
+ font-weight: bold;
160
+ color: #fff;
161
+ margin-right: 10px;
162
+ vertical-align: middle;
163
+ }
164
+ """)
165
+
166
+ # get the fast app and route
167
+ app, rt = fast_app(hdrs=(style,))
168
+
169
+ # Initialize a database to store search logs --
170
+ db = database("log_data/search_logs.db")
171
+ search_logs = db.t.search_logs
172
+
173
+ if search_logs not in db.t:
174
+ search_logs.create(
175
+ id=int,
176
+ timestamp=str,
177
+ query=str,
178
+ results=str,
179
+ pk="id",
180
+ )
181
+
182
+ SearchLog = search_logs.dataclass()
183
+
184
+
185
+ def insert_log_entry(log_entry):
186
+ "Insert a log entry into the database"
187
+ return search_logs.insert(
188
+ SearchLog(
189
+ timestamp=log_entry["timestamp"].isoformat(),
190
+ query=log_entry["query"],
191
+ results=json.dumps(log_entry["results"]),
192
+ )
193
+ )
194
+
195
+
196
+ @rt("/")
197
+ async def get():
198
+ query_form = Form(
199
+ Textarea(id="query", name="query", placeholder="Enter your query..."),
200
+ Button("Submit", type="submit"),
201
+ Div(
202
+ Span("Searching...", cls="searching-text htmx-indicator"),
203
+ Span(cls="spinner htmx-indicator"),
204
+ cls="indicator-container",
205
+ ),
206
+ id="search-form",
207
+ hx_post="/search",
208
+ hx_target="#search-results",
209
+ hx_indicator=".indicator-container",
210
+ )
211
+
212
+ results_div = Div(Div(id="search-results", cls="results-container"))
213
+
214
+ view_logs_link = A("View Logs", href="/logs", cls="view-logs-link")
215
+
216
+ return Titled(
217
+ "Zotero Search", Div(query_form, results_div, view_logs_link, cls="container")
218
+ )
219
+
220
+
221
+ def SearchResult(result):
222
+ "Custom component for displaying a search result"
223
+ return Card(
224
+ H4(A(result["title"], href=result["url"], target="_blank")),
225
+ P(result["abstract"]),
226
+ footer=A("Read more →", href=result["url"], target="_blank"),
227
+ )
228
+
229
+
230
+ def log_query_and_results(query, results):
231
+ log_entry = {
232
+ "timestamp": datetime.now(),
233
+ "query": query,
234
+ "results": [{"title": r["title"], "url": r["url"]} for r in results],
235
+ }
236
+ insert_log_entry(log_entry)
237
+
238
+
239
+ @rt("/search")
240
+ async def post(query: str):
241
+ results = retrieve_and_rerank(query)
242
+ log_query_and_results(query, results)
243
+
244
+ return Div(*[SearchResult(r) for r in results], id="search-results")
245
+
246
+
247
+ def LogEntry(entry):
248
+ return Div(
249
+ H4(f"Query: {entry.query}"),
250
+ P(f"Timestamp: {entry.timestamp}"),
251
+ H5("Results:"),
252
+ Pre(entry.results),
253
+ cls="log-entry",
254
+ )
255
+
256
+
257
+ @rt("/logs")
258
+ async def get():
259
+ logs = search_logs(order_by="-id", limit=50) # Get the latest 50 logs
260
+ log_entries = [LogEntry(log) for log in logs]
261
+ return Titled(
262
+ "Logs",
263
+ Div(
264
+ H2("Recent Search Logs"),
265
+ Div(*log_entries, id="log-entries"),
266
+ A("Back to Search", href="/", cls="back-link"),
267
+ cls="container",
268
+ ),
269
+ )
270
+
271
+
272
+ if __name__ == "__main__":
273
+ import uvicorn
274
+
275
+ uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
276
+
277
+ # run_uv()
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fasthtml-hf>=0.1.1
2
+ python-fasthtml>=0.5.2
3
+ huggingface-hub>=0.20.0
4
+ uvicorn>=0.29
5
+ requests
6
+ srsly
7
+ python-dotenv
8
+ retry
9
+ pandas
10
+ datasets
11
+ tqdm
12
+ tantivy==0.22.0
13
+ lancedb
14
+ rerankers
15
+ transformers