retrAIced / pages /Question Answering.py
JavierGon12's picture
Remove unnecessary libraries and clean code a bit
cd03817
raw
history blame contribute delete
No virus
3.71 kB
import re
import streamlit as st
from transformers import DonutProcessor, VisionEncoderDecoderModel
import torch
import os
from PIL import Image
import PyPDF2
from pypdf.errors import PdfReadError
from pypdf import PdfReader
import pypdfium2 as pdfium
# Hugging Face checkpoint used for document visual question answering.
DONUT_CHECKPOINT = "naver-clova-ix/donut-base-finetuned-docvqa"
device = "cpu"


@st.cache_resource
def _load_donut(checkpoint, target_device):
    """Load the Donut processor and model, cached across Streamlit reruns.

    Streamlit re-executes this whole script on every widget interaction;
    without ``st.cache_resource`` the weights would be reloaded each time.

    Args:
        checkpoint: Hugging Face model id passed to ``from_pretrained``.
        target_device: torch device string the model is moved to.

    Returns:
        ``(processor, model)`` with the model already on ``target_device``.
    """
    donut_processor = DonutProcessor.from_pretrained(checkpoint)
    donut_model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
    donut_model.to(target_device)
    return donut_processor, donut_model


processor, model = _load_donut(DONUT_CHECKPOINT, device)
# Input widgets: the document to explore and the question to ask about it.
document = st.file_uploader(
    label="Upload the document you want to explore",
    type=["png", "jpg", "jpeg", "pdf"],
)
question = st.text_input("Insert your question here")
# Donut DocVQA task prompt. "{user_input}" is substituted with str.replace
# (not str.format) so braces typed by the user cannot break the template.
DOCVQA_TASK_PROMPT = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"


def _load_document_image(uploaded_file):
    """Return the uploaded document as a PIL image.

    PDF detection works by attempting to parse the upload with pypdf. For a
    PDF, only the first page is rendered, at scale 300/72 (300 DPI for PDF's
    72-units-per-inch space). Anything pypdf rejects with ``PdfReadError`` is
    opened with PIL and converted to RGB.

    Args:
        uploaded_file: the seekable file-like object from ``st.file_uploader``.

    Returns:
        A ``PIL.Image.Image`` of the document (first page for PDFs).
    """
    try:
        PdfReader(uploaded_file)
    except PdfReadError:
        # Not a PDF: pypdf moved the cursor, so rewind before handing to PIL.
        uploaded_file.seek(0)
        # The Donut image processor needs 3-channel input; PNG/JPEG uploads
        # can be RGBA, palette or grayscale.
        return Image.open(uploaded_file).convert("RGB")
    uploaded_file.seek(0)
    pdf = pdfium.PdfDocument(uploaded_file)
    page = pdf.get_page(0)
    return page.render(scale=300 / 72).to_pil()


def _answer_question(image, user_question):
    """Ask Donut ``user_question`` about ``image`` and return the parsed answer.

    Relies on the module-level ``processor``, ``model`` and ``device``.

    Args:
        image: PIL image of the document page.
        user_question: free-text question inserted into the DocVQA prompt.

    Returns:
        The result of ``processor.token2json`` on the generated sequence.
    """
    prompt = DOCVQA_TASK_PROMPT.replace("{user_input}", user_question)
    decoder_input_ids = processor.tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids
    pixel_values = processor(image, return_tensors="pt").pixel_values
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        # Forbid generating the tokenizer's unknown token.
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, ""
    )
    # Remove the first tag (the task start token) before parsing to JSON.
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    return processor.token2json(sequence)


if document is None:
    st.write("Please upload the document in the box above")
else:
    document_image = _load_document_image(document)
    # Show the document before inference so it is visible while generating.
    st.image(document_image, "Document uploaded")
    if not question.strip():
        # Skip the slow generation call until there is a question to answer.
        st.write("Please type a question about the document")
    else:
        st.write(_answer_question(document_image, question))