"""FastAPI service exposing an extractive question-answering model.

Serves the pre-built React front end from ``react-qa/build`` and a single
JSON endpoint, ``POST /question-answer``, which runs the
``deepset/roberta-base-squad2`` model through ``TransformerQADecode``.
"""

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
import uvicorn

from transformer_qa_decode import TransformerQADecode

# The tokenizer and model are loaded once at import time and shared by all
# requests.
tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
qahl = TransformerQADecode(model=model, tokenizer=tokenizer)

app = FastAPI()
app.mount("/static", StaticFiles(directory="react-qa/build/static"), name="static")

# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; consider listing the real front-end origins before deploying.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class QAItem(BaseModel):
    """Request body for ``POST /question-answer``."""

    # The question to answer.
    question: str
    # The passage of text the question is asked against.
    context: str


@app.get("/")
def read_root():
    """Return the React application's ``index.html`` as an HTML response.

    The file is re-read on every request, so a rebuilt front end is picked up
    without restarting the server.
    """
    # Use a context manager so the file handle is always closed, and decode
    # explicitly as UTF-8 rather than relying on the platform default.
    with open('react-qa/build/index.html', 'r', encoding='utf-8') as html_file:
        html_content = html_file.read()
    return HTMLResponse(content=html_content, status_code=200)


@app.post("/question-answer")
def read_item(item: QAItem):
    """Run the QA decoder on ``item`` and return its results as plain dicts.

    ``qahl`` is called with the question and the context. Its result is
    treated as a sequence of sequences whose elements expose ``_asdict()``
    (presumably namedtuples -- confirm against ``TransformerQADecode``).

    Returns:
        A list of lists of dicts, one dict per element returned by ``qahl``.
    """
    result = qahl(item.question, item.context)
    # Build a fresh structure instead of overwriting the decoder's return
    # value in place, so whatever ``qahl`` handed back is left untouched.
    return [[entry._asdict() for entry in group] for group in result]


if __name__ == "__main__":
    uvicorn.run("app:app", port=7680, reload=True)