import os
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import spacy
import re
from typing import List
# Set environment variables for writable directories
os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
os.environ['MPLCONFIGDIR'] = '/tmp/.matplotlib'
# Initialize FastAPI app
app = FastAPI()
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust the origins as needed
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the spaCy models once
nlp = spacy.load("en_core_web_sm")
nlp_coref = spacy.load("en_coreference_web_trf")
REPLACE_PRONOUNS = {"he", "his", "she", "her", "they", "He", "His", "She", "Her", "They"}
class CorefRequest(BaseModel):
    text: str
    main_characters: List[str]
def extract_core_name(mention_text, main_characters):
    # Return the main character whose name appears in the mention text;
    # otherwise fall back to the mention's last word.
    words = mention_text.split()
    for character in main_characters:
        if character.lower() in mention_text.lower():
            return character
    return words[-1]
def calculate_pronoun_density(text):
    # Ratio of replaceable third-person pronouns to PERSON entities in the text.
    doc = nlp(text)
    pronoun_count = sum(1 for token in doc if token.pos_ == "PRON" and token.text in REPLACE_PRONOUNS)
    named_entity_count = sum(1 for ent in doc.ents if ent.label_ == "PERSON")
    return pronoun_count / max(named_entity_count, 1), named_entity_count
def resolve_coreferences_across_text(text, main_characters):
    doc = nlp_coref(text)
    coref_mapping = {}

    # Map token indices of replaceable pronouns to the main character they refer to.
    for key, cluster in doc.spans.items():
        if re.match(r"coref_clusters_*", key):
            main_mention = cluster[0]
            core_name = extract_core_name(main_mention.text, main_characters)
            if core_name in main_characters:
                for mention in cluster:
                    for token in mention:
                        if token.text in REPLACE_PRONOUNS:
                            core_name_final = core_name if token.text.istitle() else core_name.lower()
                            coref_mapping[token.i] = core_name_final

    # Rebuild the text, swapping a mapped pronoun for the character's name only the
    # first time that name would appear in the sentence; later pronouns in the same
    # sentence are left unchanged to avoid repetitive phrasing.
    resolved_tokens = []
    current_sentence_characters = set()
    current_sentence = []
    for i, token in enumerate(doc):
        if token.is_sent_start and current_sentence:
            resolved_tokens.extend(current_sentence)
            current_sentence_characters.clear()
            current_sentence = []
        if i in coref_mapping:
            core_name = coref_mapping[i]
            if core_name not in current_sentence_characters and core_name.lower() not in [t.lower() for t in current_sentence]:
                current_sentence.append(core_name)
                current_sentence_characters.add(core_name)
            else:
                current_sentence.append(token.text)
        else:
            current_sentence.append(token.text)
    resolved_tokens.extend(current_sentence)

    resolved_text = " ".join(resolved_tokens)
    return remove_consecutive_duplicate_phrases(resolved_text)
def remove_consecutive_duplicate_phrases(text):
    # Collapse word sequences that are immediately repeated, which can appear
    # after pronoun substitution.
    words = text.split()
    i = 0
    while i < len(words) - 1:
        j = i + 1
        while j < len(words):
            # If the phrase words[i:j] is repeated directly after itself, drop the repeat.
            if words[i:j] == words[j:j + (j - i)]:
                del words[j:j + (j - i)]
            else:
                j += 1
        i += 1
    return " ".join(words)
def process_text(text, main_characters):
    # Run coreference resolution only when the text contains replaceable pronouns.
    pronoun_density, named_entity_count = calculate_pronoun_density(text)
    min_named_entities = len(main_characters)  # computed but not used below
    if pronoun_density > 0:
        return resolve_coreferences_across_text(text, main_characters)
    else:
        return text
@app.post("/predict")
async def predict(coref_request: CorefRequest):
    resolved_text = process_text(coref_request.text, coref_request.main_characters)
    if resolved_text:
        return {"resolved_text": resolved_text}
    raise HTTPException(status_code=400, detail="Coreference resolution failed")
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))
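# A minimal sketch of how the endpoint can be exercised once the server is running
# (the URL assumes the default host/port above; the payload fields follow the
# CorefRequest model; the example text and characters are hypothetical):
#
#   curl -X POST http://localhost:7860/predict \
#     -H "Content-Type: application/json" \
#     -d '{"text": "Alice met Bob. She smiled.", "main_characters": ["Alice", "Bob"]}'
#
# The response has the shape {"resolved_text": "..."}.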