aarishshahmohsin committed on
Commit
e2da896
1 Parent(s): 8716202

added needed gpu support

Browse files
Files changed (3) hide show
  1. app copy.py +97 -0
  2. app.py +61 -42
  3. requirements.txt +1 -0
app copy.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ from surya.ocr import run_ocr
4
+ from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
5
+ from surya.model.recognition.model import load_model as load_rec_model
6
+ from surya.model.recognition.processor import load_processor as load_rec_processor
7
+ import re
8
+ from transformers import AutoModel, AutoTokenizer
9
+ import torch
10
+ import tempfile
11
+ import os
12
+
13
+ st.set_page_config(page_title="OCR Application", page_icon="🖼️", layout="wide")
14
+
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ # device = "cpu"
17
+
18
@st.cache_resource
def load_surya_models():
    """Load and cache the Surya OCR pipeline components.

    Returns:
        tuple: (det_processor, det_model, rec_model, rec_processor), with both
        models moved to the module-level ``device``. Cached by Streamlit so the
        models are loaded only once per server process.
    """
    detection_processor = load_det_processor()
    detection_model = load_det_model()
    detection_model.to(device)

    recognition_model = load_rec_model()
    recognition_processor = load_rec_processor()
    recognition_model.to(device)

    return detection_processor, detection_model, recognition_model, recognition_processor
25
+
26
@st.cache_resource
def load_got_ocr_model():
    """Load and cache the GOT-OCR2 tokenizer and model.

    Returns:
        tuple: (tokenizer, model) for ``ucaslcl/GOT-OCR2_0``, placed on the
        module-level ``device`` and switched to eval mode.

    Note:
        ``device_map=device`` already places the weights on the target device
        at load time, so the original trailing ``.to(device)`` was redundant —
        and calling ``.to()`` on an accelerate-dispatched model can raise a
        ValueError. It is therefore removed here.
    """
    tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
    model = AutoModel.from_pretrained(
        'ucaslcl/GOT-OCR2_0',
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        device_map=device,
        use_safetensors=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    model.eval()  # inference only: disable dropout / norm updates
    return tokenizer, model
32
+
33
# --- Application body: instantiate cached models and build the Streamlit UI ---
det_processor, det_model, rec_model, rec_processor = load_surya_models()
tokenizer, got_model = load_got_ocr_model()

st.title("OCR Application (Aarish Shah Mohsin)")
st.write("Upload an image for OCR processing. Using GOT-OCR for English translations, Picked Surya OCR Model for English+Hindi Translations")

st.sidebar.header("Configuration")
model_choice = st.sidebar.selectbox("Select OCR Model:", ("For English + Hindi", "For English (GOT-OCR)"))

# Store the uploaded image in session state so it survives Streamlit reruns
if 'uploaded_image' not in st.session_state:
    st.session_state.uploaded_image = None

uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])

# Update the session state if a new file is uploaded
if uploaded_file is not None:
    st.session_state.uploaded_image = uploaded_file

predict_button = st.sidebar.button("Predict", key="predict")

col1, col2 = st.columns([2, 1])

# Display the image preview if it's already uploaded
if st.session_state.uploaded_image:
    image = Image.open(st.session_state.uploaded_image)

    with col1:
        # Display a smaller preview of the uploaded image (set width to 300px)
        col1.image(image, caption='Uploaded Image', use_column_width=False, width=300)

if predict_button and st.session_state.uploaded_image:
    with col2:
        with st.spinner("Processing..."):
            # Save the uploaded file temporarily: GOT-OCR's chat() takes a file
            # path, so both backends read from the same temp file.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
                temp_file.write(st.session_state.uploaded_image.getvalue())
                temp_file_path = temp_file.name

            image = Image.open(temp_file_path)
            image = image.convert("RGB")  # normalize mode before OCR

            if model_choice == "For English + Hindi":
                langs = ["en", "hi"]
                predictions = run_ocr([image], [langs], det_model, det_processor, rec_model, rec_processor)
                # NOTE(review): this parses the repr() of the prediction object
                # with a regex — fragile if surya changes its string format;
                # prefer structured attribute access if the API exposes one.
                text_list = re.findall(r"text='(.*?)'", str(predictions[0]))
                extracted_text = ' '.join(text_list)

                with col2:
                    st.subheader("Extracted Text (Surya):")
                    st.write(extracted_text)

            elif model_choice == "For English (GOT-OCR)":
                image_file = temp_file_path
                res = got_model.chat(tokenizer, image_file, ocr_type='ocr')

                with col2:
                    st.subheader("Extracted Text (GOT-OCR):")
                    st.write(res)

            # Delete the temporary file after processing
            if os.path.exists(temp_file_path):
                os.remove(temp_file_path)
# else:
#     st.sidebar.warning("Please upload an image before predicting.")
app.py CHANGED
@@ -12,8 +12,7 @@ import os
12
 
13
  st.set_page_config(page_title="OCR Application", page_icon="🖼️", layout="wide")
14
 
15
- # device = "cuda" if torch.cuda.is_available() else "cpu"
16
- device = 'cpu'
17
 
18
  @st.cache_resource
19
  def load_surya_models():
@@ -25,13 +24,9 @@ def load_surya_models():
25
 
26
  @st.cache_resource
27
  def load_got_ocr_model():
28
- # tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
29
- # model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map=device, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
30
- # model.eval().to(device)
31
- # return tokenizer, model
32
  tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
33
- model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
34
- model.eval()
35
  return tokenizer, model
36
 
37
  det_processor, det_model, rec_model, rec_processor = load_surya_models()
@@ -43,9 +38,11 @@ st.write("Upload an image for OCR processing. Using GOT-OCR for English translat
43
  st.sidebar.header("Configuration")
44
  model_choice = st.sidebar.selectbox("Select OCR Model:", ("For English + Hindi", "For English (GOT-OCR)"))
45
 
46
- # Store the uploaded image in session state
47
  if 'uploaded_image' not in st.session_state:
48
  st.session_state.uploaded_image = None
 
 
49
 
50
  uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
51
 
@@ -65,37 +62,59 @@ if st.session_state.uploaded_image:
65
  # Display a smaller preview of the uploaded image (set width to 300px)
66
  col1.image(image, caption='Uploaded Image', use_column_width=False, width=300)
67
 
 
68
  if predict_button and st.session_state.uploaded_image:
69
- with col2:
70
- with st.spinner("Processing..."):
71
- # Save the uploaded file temporarily
72
- with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
73
- temp_file.write(st.session_state.uploaded_image.getvalue())
74
- temp_file_path = temp_file.name
75
-
76
- image = Image.open(temp_file_path)
77
- image = image.convert("RGB")
78
-
79
- if model_choice == "For English + Hindi":
80
- langs = ["en", "hi"]
81
- predictions = run_ocr([image], [langs], det_model, det_processor, rec_model, rec_processor)
82
- text_list = re.findall(r"text='(.*?)'", str(predictions[0]))
83
- extracted_text = ' '.join(text_list)
84
-
85
- with col2:
86
- st.subheader("Extracted Text (Surya):")
87
- st.write(extracted_text)
88
-
89
- elif model_choice == "For English (GOT-OCR)":
90
- image_file = temp_file_path
91
- res = got_model.chat(tokenizer, image_file, ocr_type='ocr')
92
-
93
- with col2:
94
- st.subheader("Extracted Text (GOT-OCR):")
95
- st.write(res)
96
-
97
- # Delete the temporary file after processing
98
- if os.path.exists(temp_file_path):
99
- os.remove(temp_file_path)
100
- # else:
101
- # st.sidebar.warning("Please upload an image before predicting.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  st.set_page_config(page_title="OCR Application", page_icon="🖼️", layout="wide")
14
 
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
 
16
 
17
  @st.cache_resource
18
  def load_surya_models():
 
24
 
25
  @st.cache_resource
26
  def load_got_ocr_model():
 
 
 
 
27
  tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
28
+ model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map=device, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
29
+ model.eval().to(device)
30
  return tokenizer, model
31
 
32
  det_processor, det_model, rec_model, rec_processor = load_surya_models()
 
38
  st.sidebar.header("Configuration")
39
  model_choice = st.sidebar.selectbox("Select OCR Model:", ("For English + Hindi", "For English (GOT-OCR)"))
40
 
41
+ # Store the uploaded image and extracted text in session state
42
  if 'uploaded_image' not in st.session_state:
43
  st.session_state.uploaded_image = None
44
+ if 'extracted_text' not in st.session_state:
45
+ st.session_state.extracted_text = ""
46
 
47
  uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
48
 
 
62
  # Display a smaller preview of the uploaded image (set width to 300px)
63
  col1.image(image, caption='Uploaded Image', use_column_width=False, width=300)
64
 
65
# Handle predictions: run the selected OCR backend and stash the result in
# session state so the search box below keeps working across reruns.
if predict_button and st.session_state.uploaded_image:
    # with col2:
    with st.spinner("Processing..."):
        # Save the uploaded file temporarily: GOT-OCR's chat() takes a file
        # path, so both backends read from the same temp file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            temp_file.write(st.session_state.uploaded_image.getvalue())
            temp_file_path = temp_file.name

        image = Image.open(temp_file_path)
        image = image.convert("RGB")  # normalize mode before OCR

        if model_choice == "For English + Hindi":
            langs = ["en", "hi"]
            predictions = run_ocr([image], [langs], det_model, det_processor, rec_model, rec_processor)
            # NOTE(review): regex over repr() of the prediction object —
            # fragile if surya changes its string format.
            text_list = re.findall(r"text='(.*?)'", str(predictions[0]))
            extracted_text = ' '.join(text_list)

            st.session_state.extracted_text = extracted_text  # Save extracted text in session state

            # with col2:
            #     st.subheader("Extracted Text (Surya):")
            #     st.write(extracted_text)

        elif model_choice == "For English (GOT-OCR)":
            image_file = temp_file_path
            res = got_model.chat(tokenizer, image_file, ocr_type='ocr')

            st.session_state.extracted_text = res  # Save extracted text in session state

            # with col2:
            #     st.subheader("Extracted Text (GOT-OCR):")
            #     st.write(res)

        # Delete the temporary file after processing
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)

# Search functionality: case-insensitive highlight of the query inside the
# last extracted text.
if st.session_state.extracted_text:
    search_query = st.text_input("Search in extracted text:", key="search_query", placeholder="Type to search...")

    # Create a pattern to find the search query in a case-insensitive way
    if search_query:
        # re.escape makes the user's query safe to embed in the pattern
        pattern = re.compile(re.escape(search_query), re.IGNORECASE)
        highlighted_text = st.session_state.extracted_text

        # Replace matching text with highlighted version (bright green)
        highlighted_text = pattern.sub(lambda m: f"<span style='background-color: limegreen;'>{m.group(0)}</span>", highlighted_text)

        # NOTE(review): unsafe_allow_html=True renders OCR-derived text as raw
        # HTML — any markup present in the scanned image is injected into the
        # page; consider html.escape() on the text before highlighting.
        st.markdown("### Highlighted Search Results:")
        st.markdown(highlighted_text, unsafe_allow_html=True)
    else:
        # If no search query, show the original extracted text
        st.markdown("### Extracted Text:")
        st.markdown(st.session_state.extracted_text, unsafe_allow_html=True)
requirements.txt CHANGED
@@ -7,3 +7,4 @@ tiktoken
7
  torchvision
8
  verovio
9
  accelerate
 
 
7
  torchvision
8
  verovio
9
  accelerate
10
+ rapidfuzz