research / lrt /lrt.py
haoqi7's picture
Upload 47 files
16188ba
raw
history blame contribute delete
No virus
5.37 kB
import contextlib
from typing import List, Optional, Sequence, Tuple

# The star import is kept ahead of the explicit imports below so that those
# explicit names still take precedence over anything `.clustering` re-exports.
from .clustering import *

import streamlit as st
import textdistance as td
from tokenizers import Tokenizer

from .utils import UnionFind, ArticleList
from .academic_query import AcademicQuery
from .clustering.clusters import KeyphraseCount
class LiteratureResearchTool:
    """Search academic platforms for a query and cluster the returned abstracts.

    For each requested platform, articles are fetched, their abstracts are
    clustered with a ``ClusterPipeline``, and every cluster is labelled with
    its most frequent keyphrases.
    """

    # Number of keyphrases kept per cluster (stored in `cluster.top_5_keyphrases`).
    _TOP_K_KEYPHRASES = 5
    # Two keyphrases whose Ratcliff/Obershelp similarity (textdistance) exceeds
    # this value satisfy the union-find merge condition.
    _KEYPHRASE_SIMILARITY_THRESHOLD = 0.8

    def __init__(self, cluster_config: Optional[Configuration] = None):
        """
        :param cluster_config: configuration forwarded unchanged to
            ``ClusterPipeline`` (``None`` is passed through as-is; its meaning
            is defined by ``ClusterPipeline``).
        """
        # NOTE(review): this stores the AcademicQuery class itself, not an
        # instance; its `ieee` / `arxiv` / `paper_with_code` methods are
        # presumably static or class methods — confirm.
        self.literature_search = AcademicQuery
        self.cluster_pipeline = ClusterPipeline(cluster_config)

    def __postprocess_clusters__(self, clusters: ClusterList, query: str) -> ClusterList:
        """Set ``cluster.top_5_keyphrases`` on every cluster.

        Per cluster:
          1. drop keyphrases that are ``None``, empty, whitespace-only,
             a single character long, or exactly equal to ``query``;
          2. group the remaining ones with ``UnionFind`` using a pairwise
             condition (Ratcliff/Obershelp similarity above
             ``_KEYPHRASE_SIMILARITY_THRESHOLD``);
          3. collapse each group into one ``KeyphraseCount`` via
             ``KeyphraseCount.reduce``;
          4. keep the ``_TOP_K_KEYPHRASES`` groups with the highest ``count``
             and store their keyphrase strings.

        :param clusters: clusters produced by the cluster pipeline; each one is
            mutated in place.
        :param query: the search query; keyphrases identical to it (exact,
            case-sensitive match) are discarded.
        :return: the same ``clusters`` object.
        """
        threshold = self._KEYPHRASE_SIMILARITY_THRESHOLD
        top_k = self._TOP_K_KEYPHRASES

        def condition(x: KeyphraseCount, y: KeyphraseCount) -> bool:
            return td.ratcliff_obershelp(x.keyphrase, y.keyphrase) > threshold

        def valid_keyphrase(x: KeyphraseCount) -> bool:
            phrase = x.keyphrase
            return (
                phrase is not None
                and phrase != ''
                and not phrase.isspace()
                and len(phrase) != 1
                and phrase != query
            )

        for cluster in clusters:
            candidates = [kc for kc in cluster.get_keyphrases() if valid_keyphrase(kc)]
            unionfind = UnionFind(candidates, condition)
            unionfind.union_step()
            # get_unions(): dict(root_id -> [kc]); reduce each group to one kc.
            merged = [KeyphraseCount.reduce(group) for group in unionfind.get_unions().values()]
            merged.sort(key=lambda kc: kc.count, reverse=True)
            cluster.top_5_keyphrases = [kc.keyphrase for kc in merged[:top_k]]
        return clusters

    def __call__(self,
                 query: str,
                 num_papers: int,
                 start_year: int,
                 end_year: int,
                 max_k: int,
                 platforms: Sequence[str] = ('IEEE', 'Arxiv', 'Paper with Code'),
                 loading_ctx_manager=None,
                 standardization=False,
                 ):
        """Run the search-and-cluster pipeline once per platform.

        This is a generator: each platform's result is yielded as soon as it is
        ready, in the order given by ``platforms``.

        :param query: search query sent to every platform.
        :param num_papers: number of papers requested from each platform.
        :param start_year: lower year bound (only forwarded for 'IEEE').
        :param end_year: upper year bound (only forwarded for 'IEEE').
        :param max_k: forwarded to the cluster pipeline.
        :param platforms: platform names, each one of 'IEEE', 'Arxiv',
            'Paper with Code'. The default is a tuple so no mutable default
            object is shared between calls; any sequence of names is accepted.
        :param loading_ctx_manager: optional zero-argument callable returning a
            context manager that is entered around each platform's run
            (presumably a UI loading indicator).
        :param standardization: forwarded to the cluster pipeline.
        :return: yields ``(clusters, articles)`` per platform, with ``clusters``
            sorted in place via ``clusters.sort()``.
        :raises RuntimeError: for an unsupported platform name — only once
            iteration reaches that platform, after earlier ones were yielded.
        """
        for platform in platforms:
            context = loading_ctx_manager() if loading_ctx_manager else contextlib.nullcontext()
            with context:
                clusters, articles = self.__platformPipeline__(
                    platform, query, num_papers, start_year, end_year, max_k, standardization)
            # The result comes from st.cache(allow_output_mutation=True), so this
            # sorts the cached object itself rather than a copy.
            clusters.sort()
            yield clusters, articles

    def _cluster_articles(self, articles: ArticleList, query: str, max_k: int, standardization):
        """Cluster the abstracts of ``articles`` and label the clusters.

        :return: ``(clusters, articles)`` with ``articles`` passed through unchanged.
        """
        abstracts = articles.getAbstracts()  # List[str]
        clusters = self.cluster_pipeline(abstracts, max_k, standardization)
        clusters = self.__postprocess_clusters__(clusters, query)
        return clusters, articles

    def __platformPipeline__(self, platforn_name: str,
                             query: str,
                             num_papers: int,
                             start_year: int,
                             end_year: int,
                             max_k: int,
                             standardization
                             ) -> Tuple[ClusterList, ArticleList]:
        """Fetch, cluster and label the articles of a single platform.

        Results are memoized with Streamlit's legacy ``st.cache``.

        :param platforn_name: (sic) one of 'IEEE', 'Arxiv', 'Paper with Code'.
        :param query: search query.
        :param num_papers: number of papers to request.
        :param start_year: lower year bound (IEEE only).
        :param end_year: upper year bound (IEEE only).
        :param max_k: forwarded to the cluster pipeline.
        :param standardization: forwarded to the cluster pipeline.
        :return: ``(clusters, articles)`` for that platform.
        :raises RuntimeError: if ``platforn_name`` is not supported.
        """
        # Per the legacy st.cache docs, the cache key covers the call arguments,
        # the function body and the external values it references — here `self`,
        # `max_k` and `standardization`, which is why they are read from the
        # closure inside each cached function. `hash_funcs` tells Streamlit how
        # to hash Tokenizer objects (presumably reachable from `self` through the
        # cluster pipeline — confirm).
        @st.cache(hash_funcs={Tokenizer: Tokenizer.__hash__}, allow_output_mutation=True)
        def ieee_process(query: str, num_papers: int, start_year: int, end_year: int):
            articles = ArticleList.parse_ieee_articles(
                self.literature_search.ieee(query, start_year, end_year, num_papers))
            return self._cluster_articles(articles, query, max_k, standardization)

        @st.cache(hash_funcs={Tokenizer: Tokenizer.__hash__}, allow_output_mutation=True)
        def arxiv_process(query: str, num_papers: int):
            articles = ArticleList.parse_arxiv_articles(
                self.literature_search.arxiv(query, num_papers))
            return self._cluster_articles(articles, query, max_k, standardization)

        @st.cache(hash_funcs={Tokenizer: Tokenizer.__hash__}, allow_output_mutation=True)
        def pwc_process(query: str, num_papers: int):
            articles = ArticleList.parse_pwc_articles(
                self.literature_search.paper_with_code(query, num_papers))
            return self._cluster_articles(articles, query, max_k, standardization)

        if platforn_name == 'IEEE':
            return ieee_process(query, num_papers, start_year, end_year)
        if platforn_name == 'Arxiv':
            return arxiv_process(query, num_papers)
        if platforn_name == 'Paper with Code':
            return pwc_process(query, num_papers)
        raise RuntimeError('This platform is not supported. Please open an issue on the GitHub.')