File size: 5,271 Bytes
aafffbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import html
import os
import re
import tempfile

import streamlit as st
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.ocr import run_ocr


# Hide every GPU from CUDA so all model work runs on CPU regardless of host hardware.
os.environ["CUDA_VISIBLE_DEVICES"] = ""


st.set_page_config(page_title="OCR Application", page_icon="🖼️", layout="wide")



# Force CPU if CUDA is not available
# NOTE(review): with CUDA_VISIBLE_DEVICES cleared above, torch.cuda.is_available()
# should normally return False, so this resolves to "cpu"; the "cuda" branch is
# effectively dead unless the env line is removed.
device = "cuda" if torch.cuda.is_available() else "cpu"



@st.cache_resource
def load_surya_models():
    """Load and cache the Surya detection and recognition models.

    Cached by Streamlit (``st.cache_resource``) so the models are loaded
    once per server process, not on every script rerun.

    Returns:
        Tuple of (det_processor, det_model, rec_model, rec_processor),
        with both models moved to the module-level ``device``.
    """
    detection_processor = load_det_processor()
    detection_model = load_det_model()
    detection_model.to(device)

    recognition_model = load_rec_model()
    recognition_processor = load_rec_processor()
    recognition_model.to(device)

    return detection_processor, detection_model, recognition_model, recognition_processor



@st.cache_resource
def load_got_ocr_model():
    """Load and cache the GOT-OCR2 tokenizer and model.

    Cached by Streamlit (``st.cache_resource``) so the large checkpoint is
    downloaded and initialized once per server process.

    Returns:
        Tuple of (tokenizer, model), with the model in eval mode on ``device``.

    NOTE(review): ``trust_remote_code=True`` executes model code fetched from
    the Hugging Face Hub for the 'ucaslcl/GOT-OCR2_0' checkpoint — keep the
    checkpoint name pinned.
    """
    tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
    model = AutoModel.from_pretrained(
        'ucaslcl/GOT-OCR2_0',
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        device_map=device,
        use_safetensors=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    model.eval().to(device)

    # The GOT-OCR remote code calls .half()/.cuda() internally; on a CPU-only
    # host those calls fail or produce unsupported fp16 CPU ops, so they are
    # neutralized here. Fix: only apply these global monkey-patches when we
    # are actually on CPU — the original patched unconditionally, which would
    # silently break real half-precision/GPU behavior on a CUDA machine.
    if device == "cpu":
        torch.Tensor.half = lambda x: x.float()
        torch.Tensor.cuda = lambda x, **kwargs: x.cpu()

    return tokenizer, model



# Instantiate all models once at startup (both loaders are st.cache_resource-cached,
# so reruns of this script reuse the same objects).
det_processor, det_model, rec_model, rec_processor = load_surya_models()

tokenizer, got_model = load_got_ocr_model()



st.title("OCR Application  (Aarish Shah Mohsin)")

st.write("Upload an image for OCR processing. Using GOT-OCR for English translations, Picked Surya OCR Model for English+Hindi Translations")


# Sidebar: model picker — the selected label is compared verbatim in the
# prediction handler below.
st.sidebar.header("Configuration")

model_choice = st.sidebar.selectbox("Select OCR Model:", ("For English + Hindi", "For English (GOT-OCR)"))



# Store the uploaded image and extracted text in session state
# (st.session_state survives Streamlit reruns, so the upload and OCR result
# persist across widget interactions like typing in the search box).
if 'uploaded_image' not in st.session_state:

    st.session_state.uploaded_image = None

if 'extracted_text' not in st.session_state:

    st.session_state.extracted_text = ""



uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])



# Update the session state if a new file is uploaded
if uploaded_file is not None:

    st.session_state.uploaded_image = uploaded_file



predict_button = st.sidebar.button("Predict", key="predict")


# Two-column layout: col1 (wider) holds the image preview; col2 is unused spacing.
col1, col2 = st.columns([2, 1])



# Display the image preview if it's already uploaded (drawn on every rerun,
# since the file handle is kept in session state).
if st.session_state.uploaded_image:

    image = Image.open(st.session_state.uploaded_image)



    with col1:

        # Display a smaller preview of the uploaded image (set width to 300px)
        col1.image(image, caption='Uploaded Image', use_column_width=False, width=300)



# Handle predictions

# Handle predictions: runs only when the sidebar button was clicked and an
# image is present in session state.
if predict_button and st.session_state.uploaded_image:

    with st.spinner("Processing..."):
        # Persist the upload to disk first: GOT-OCR's chat() API takes a file
        # path rather than an in-memory image.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            temp_file.write(st.session_state.uploaded_image.getvalue())
            temp_file_path = temp_file.name

        # Fix: the original deleted the temp file only on the success path,
        # leaking a file in the temp directory whenever OCR raised. try/finally
        # guarantees cleanup either way.
        try:
            image = Image.open(temp_file_path)
            image = image.convert("RGB")  # both models expect 3-channel input

            if model_choice == "For English + Hindi":
                langs = ["en", "hi"]
                predictions = run_ocr([image], [langs], det_model, det_processor, rec_model, rec_processor)
                # run_ocr returns structured result objects; pull the text='...'
                # fields out of their repr (brittle, but preserves the original
                # extraction behavior exactly).
                text_list = re.findall(r"text='(.*?)'", str(predictions[0]))
                extracted_text = ' '.join(text_list)

                st.session_state.extracted_text = extracted_text  # Save extracted text in session state

            elif model_choice == "For English (GOT-OCR)":
                image_file = temp_file_path
                res = got_model.chat(tokenizer, image_file, ocr_type='ocr')

                st.session_state.extracted_text = res  # Save extracted text in session state
        finally:
            # Delete the temporary file after processing — even if OCR failed.
            if os.path.exists(temp_file_path):
                os.remove(temp_file_path)



# Search functionality

# Search functionality over the extracted text.
if st.session_state.extracted_text:

    search_query = st.text_input("Search in extracted text:", key="search_query", placeholder="Type to search...")

    # Security fix: the OCR output comes from an arbitrary uploaded image and
    # is rendered with unsafe_allow_html=True. Escape it so stray <, >, & in
    # the recognized text cannot inject HTML into the page — only our own
    # highlight <span> markup below stays live.
    safe_text = html.escape(st.session_state.extracted_text)

    if search_query:
        # Escape the query the same way so it matches the escaped text
        # (e.g. searching "&" finds the "&amp;" entity in safe_text),
        # then match case-insensitively.
        pattern = re.compile(re.escape(html.escape(search_query)), re.IGNORECASE)

        # Replace matching text with highlighted version (bright green)
        highlighted_text = pattern.sub(
            lambda m: f"<span style='background-color: limegreen;'>{m.group(0)}</span>",
            safe_text,
        )

        st.markdown("### Highlighted Search Results:")
        st.markdown(highlighted_text, unsafe_allow_html=True)
    else:
        # If no search query, show the (escaped) extracted text as-is.
        st.markdown("### Extracted Text:")
        st.markdown(safe_text, unsafe_allow_html=True)