SWHL committed on
Commit
c2ce78d
1 Parent(s): e29d8d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -80
app.py CHANGED
@@ -1,115 +1,155 @@
1
  # -*- encoding: utf-8 -*-
2
  # @Author: SWHL
3
  # @Contact: liekkaskono@163.com
4
- from pathlib import Path
 
5
 
6
  import numpy as np
 
 
7
  import streamlit as st
8
  from PIL import Image
9
  from rapid_latex_ocr import LatexOCR
10
- from streamlit_cropper import st_cropper
11
- from streamlit_image_select import image_select
12
 
13
- st.set_option("deprecation.showfileUploaderEncoding", False)
 
14
 
 
15
 
16
class RecEquation:
    """Callable wrapper around ``LatexOCR`` loaded from a local model directory."""

    def __init__(self, model_dir: str):
        """Create the recognizer from the model files found in *model_dir*.

        Args:
            model_dir: Directory holding ``image_resizer.onnx``,
                ``encoder.onnx``, ``decoder.onnx`` and ``tokenizer.json``.
        """
        root = Path(model_dir)
        # LatexOCR is given plain string paths, so each Path is converted.
        self.model = LatexOCR(
            image_resizer_path=str(root / "image_resizer.onnx"),
            encoder_path=str(root / "encoder.onnx"),
            decoder_path=str(root / "decoder.onnx"),
            tokenizer_json=str(root / "tokenizer.json"),
        )

    def __call__(self, img: np.ndarray):
        """Recognize *img* and return the model's ``(result, elapse)`` pair."""
        latex, cost = self.model(img)
        return latex, cost
 
 
 
 
34
 
35
 
if __name__ == "__main__":
    # Page heading, linking to the project repository.
    st.markdown(
        "<h1 style='text-align: center;'><a href='https://github.com/RapidAI/RapidLatexOCR' style='text-decoration: none'>Rapid Latex OCR</a></h1>",
        unsafe_allow_html=True,
    )
    # Row of project badges.
    st.markdown(
        """
    <p align="left">
        <a href=""><img src="https://img.shields.io/badge/Python->=3.6,<3.12-aff.svg"></a>
        <a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Win%2C%20Mac-pink.svg"></a>
        <a href="https://pepy.tech/project/rapid_latex_ocr"><img src="https://static.pepy.tech/personalized-badge/rapid_latex_ocr?period=total&units=abbreviation&left_color=grey&right_color=blue&left_text=Downloads"></a>
        <a href="https://pypi.org/project/rapid_latex_ocr/"><img alt="PyPI" src="https://img.shields.io/pypi/v/rapid_latex_ocr"></a>
        <a href="https://semver.org/"><img alt="SemVer2.0" src="https://img.shields.io/badge/SemVer-2.0-brightgreen"></a>
        <a href="https://github.com/psf/black"><img src="https://img.shields.io/badge/code%20style-black-000000.svg"></a>
    </p>
    """,
        unsafe_allow_html=True,
    )

    # Sidebar controls: the upload widget plus the cropper options.
    uploaded = st.sidebar.file_uploader(label="Upload a file", type=["png", "jpg"])
    live_crop = st.sidebar.checkbox(label="Update in Real Time", value=False)
    crop_box_color = st.sidebar.color_picker(label="Box Color", value="#0000FF")

    # Radio label -> aspect ratio handed to the cropper (None = unconstrained).
    ratio_by_label = {
        "Free": None,
        "1:1": (1, 1),
        "16:9": (16, 9),
        "4:3": (4, 3),
        "2:3": (2, 3),
    }
    chosen_label = st.sidebar.radio(label="Aspect Ratio", options=list(ratio_by_label))
    chosen_ratio = ratio_by_label[chosen_label]

    # Bundled example images; the selected one is used until a file is uploaded.
    example_paths = [
        f"images/{stem}.png" for stem in ("equation", "eq_2", "eq_3", "eq_4")
    ]
    with st.sidebar.container():
        img = image_select(
            label="Examples(click to select):",
            images=example_paths,
            key="equation_default",
            use_container_width=False,
        )

    rec_eq_sys = RecEquation(model_dir="models")

    # Reserved first so the cropper (if any) appears above the preview.
    crop_area = st.container()

    st.markdown("#### Select image:")
    preview = st.empty()

    # Default path: show and recognize the selected example image.
    preview.image(img, use_column_width=False)
    rec_res, elapse = rec_eq_sys(img)

    if uploaded:
        img = Image.open(uploaded)

        # Get a cropped image from the frontend
        with crop_area:
            if not live_crop:
                crop_area.markdown("#### Double click to save crop")

            img = st_cropper(
                img,
                realtime_update=live_crop,
                box_color=crop_box_color,
                aspect_ratio=chosen_ratio,
            )

        # The crop replaces the example in the preview and is recognized again.
        preview.image(img, use_column_width=False)
        rec_res, elapse = rec_eq_sys(np.array(img))

    st.markdown(f"#### Rec result (cost: {elapse:.4f}s):")
    st.latex(rec_res)

    st.markdown("#### Latex source code:")
    st.code(rec_res)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # -*- encoding: utf-8 -*-
2
  # @Author: SWHL
3
  # @Contact: liekkaskono@163.com
4
+ import hashlib
5
+ import io
6
 
7
  import numpy as np
8
+ import pandas as pd
9
+ import pypdfium2
10
  import streamlit as st
11
  from PIL import Image
12
  from rapid_latex_ocr import LatexOCR
13
+ from streamlit_drawable_canvas import st_canvas
 
14
 
15
# Largest width/height, in pixels, that resize_image() lets the displayed
# image keep; get_image_size() also falls back to these when given None.
MAX_WIDTH = 800
MAX_HEIGHT = 1000

# Wide page layout: the script later splits the page into two equal columns
# (drawing canvas on the left, recognition results on the right).
st.set_page_config(layout="wide")
19
 
 
 
 
20
 
21
@st.cache_resource()
def load_model_cached():
    """Build the ``LatexOCR`` recognizer with its default arguments.

    Decorated with ``st.cache_resource`` so Streamlit can hand back the same
    instance on later script runs instead of constructing it again.
    """
    recognizer = LatexOCR()
    return recognizer
24
+
25
+
26
def get_canvas_hash(pil_image):
    """Return a hex MD5 digest identifying *pil_image*, used as the canvas key.

    The digest covers the image mode and size as well as the raw pixel bytes.
    Hashing only ``tobytes()`` would give the same key to two images with
    identical bytes but different dimensions (for example a blank portrait
    page and a blank landscape page), so the canvas would not be rebuilt.

    Args:
        pil_image: Image exposing ``mode``, ``size`` and ``tobytes()``.

    Returns:
        A 32-character hexadecimal string.
    """
    digest = hashlib.md5()
    width, height = pil_image.size
    # Geometry header first, then the pixel payload.
    digest.update(f"{pil_image.mode}:{width}x{height}:".encode("utf-8"))
    digest.update(pil_image.tobytes())
    return digest.hexdigest()
28
+
29
+
30
def open_pdf(pdf_file):
    """Open an uploaded PDF as a ``pypdfium2.PdfDocument``.

    Args:
        pdf_file: Object whose ``getvalue()`` returns the PDF content
            (in this script, the file returned by the sidebar uploader).

    Returns:
        The document opened from an in-memory ``io.BytesIO`` copy of the data.
    """
    pdf_bytes = pdf_file.getvalue()
    return pypdfium2.PdfDocument(io.BytesIO(pdf_bytes))
33
+
34
+
35
@st.cache_data()
def page_count(pdf_file):
    """Return the number of pages in the uploaded PDF.

    Args:
        pdf_file: Uploaded PDF accepted by ``open_pdf``.

    Returns:
        The page count as reported by ``len()`` on the opened document.
    """
    doc = open_pdf(pdf_file)
    try:
        return len(doc)
    finally:
        # Only the count is needed, so release the document handle right away
        # instead of leaving it to garbage collection.
        doc.close()
39
+
40
+
41
@st.cache_data()
def get_page_image(pdf_file, page_num, dpi=96):
    """Render one page of the uploaded PDF as an RGB PIL image.

    Args:
        pdf_file: Uploaded PDF accepted by ``open_pdf``.
        page_num: 1-based page number; converted to a 0-based page index.
        dpi: Target resolution, passed to the renderer as ``dpi / 72``.

    Returns:
        The rendered page converted to ``"RGB"``.
    """
    document = open_pdf(pdf_file)
    zero_based_index = page_num - 1
    # Materialize the renderer's output; exactly one page index was requested.
    rendered_pages = list(
        document.render(
            pypdfium2.PdfBitmap.to_pil,
            page_indices=[zero_based_index],
            scale=dpi / 72,
        )
    )
    return rendered_pages[0].convert("RGB")
52
+
53
+
54
@st.cache_data()
def get_uploaded_image(in_file):
    """Return *in_file* as an RGB PIL image.

    Args:
        in_file: Either an existing ``PIL.Image.Image`` or anything
            ``Image.open`` accepts (such as the uploaded file object).
    """
    source = in_file if isinstance(in_file, Image.Image) else Image.open(in_file)
    return source.convert("RGB")
59
+
60
+
61
def resize_image(pil_image):
    """Shrink *pil_image* so it fits within ``MAX_WIDTH`` x ``MAX_HEIGHT``.

    The image object itself is modified via ``thumbnail``; nothing is
    returned. Passing ``None`` is a no-op.
    """
    if pil_image is not None:
        size_limit = (MAX_WIDTH, MAX_HEIGHT)
        pil_image.thumbnail(size_limit, Image.Resampling.LANCZOS)
65
 
66
+
67
def get_image_size(pil_image):
    """Return ``(height, width)`` of *pil_image*.

    Falls back to ``(MAX_HEIGHT, MAX_WIDTH)`` when *pil_image* is ``None``.

    This is intentionally not wrapped in ``st.cache_data``: building a cache
    key for an image argument means hashing the image, which costs far more
    than the two attribute reads done here.
    """
    if pil_image is None:
        return MAX_HEIGHT, MAX_WIDTH
    return pil_image.height, pil_image.width
73
 
74
 
def extract_rect_bboxes(json_data):
    """Convert the canvas JSON into ``[left, top, right, bottom]`` crop boxes.

    Args:
        json_data: The canvas state; a mapping with an ``"objects"`` list
            whose rectangle entries carry ``type``, ``left``, ``top``,
            ``width`` and ``height``. May be ``None`` before anything is drawn.

    Returns:
        One box per drawn rectangle, in drawing order. Objects that are not
        rectangles are ignored, and rectangles narrower or shorter than one
        pixel are skipped because cropping them yields an empty image.
    """
    if not json_data:
        return []
    bboxes = []
    for obj in json_data.get("objects") or []:
        if obj.get("type") != "rect":
            continue
        left, top = obj["left"], obj["top"]
        width, height = obj["width"], obj["height"]
        if width < 1 or height < 1:
            continue
        bboxes.append([left, top, left + width, top + height])
    return bboxes


if __name__ == "__main__":
    # Page heading, linking to the project repository.
    st.markdown(
        "<h1 style='text-align: center;'><a href='https://github.com/RapidAI/RapidLatexOCR' style='text-decoration: none'>Rapid ⚡︎ LaTeX OCR</a></h1>",
        unsafe_allow_html=True,
    )
    # Row of project badges.
    st.markdown(
        """
    <p align="center">
        <a href=""><img src="https://img.shields.io/badge/Python->=3.6,<3.12-aff.svg"></a>
        <a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Win%2C%20Mac-pink.svg"></a>
        <a href="https://pepy.tech/project/rapid_latex_ocr"><img src="https://static.pepy.tech/personalized-badge/rapid_latex_ocr?period=total&units=abbreviation&left_color=grey&right_color=blue&left_text=Downloads"></a>
        <a href="https://pypi.org/project/rapid_latex_ocr/"><img alt="PyPI" src="https://img.shields.io/pypi/v/rapid_latex_ocr"></a>
        <a href="https://semver.org/"><img alt="SemVer2.0" src="https://img.shields.io/badge/SemVer-2.0-brightgreen"></a>
        <a href="https://github.com/psf/black"><img src="https://img.shields.io/badge/code%20style-black-000000.svg"></a>
        <a href="https://github.com/RapidAI/RapidLatexOCR"><img src="https://img.shields.io/badge/Github-link-brightgreen.svg"></a>
    </p>
    """,
        unsafe_allow_html=True,
    )

    # Left column: drawing canvas. Right column: recognition results.
    col1, col2 = st.columns([0.5, 0.5])

    in_file = st.sidebar.file_uploader(
        "PDF file or image:", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"]
    )

    # Nothing to show until a file has been uploaded.
    if in_file is None:
        st.stop()

    filetype = in_file.type
    if "pdf" in filetype:
        # Kept under its own name so the page_count() function is not
        # overwritten by its result.
        total_pages = page_count(in_file)
        page_number = st.sidebar.number_input(
            f"Page number out of {total_pages}:",
            min_value=1,
            value=1,
            max_value=total_pages,
        )
        pil_image = get_page_image(in_file, page_number)
    else:
        pil_image = get_uploaded_image(in_file)

    resize_image(pil_image)
    # Keying the canvas on the image content gives a fresh canvas per image.
    canvas_hash = get_canvas_hash(pil_image) if pil_image else "canvas"

    model = load_model_cached()
    canvas_height, canvas_width = get_image_size(pil_image)
    with col1:
        canvas_result = st_canvas(
            fill_color="rgba(255, 165, 0, 0.1)",
            stroke_width=1,
            stroke_color="#FFAA00",
            background_color="#FFF",
            background_image=pil_image,
            update_streamlit=True,
            height=canvas_height,
            width=canvas_width,
            drawing_mode="rect",
            point_display_radius=0,
            key=canvas_hash,
        )

    bbox_list = extract_rect_bboxes(canvas_result.json_data)
    if bbox_list:
        with col2:
            # One numbered result per drawn rectangle: rendered formula
            # followed by its LaTeX source.
            for number, bbox in enumerate(bbox_list, start=1):
                input_img = pil_image.crop(bbox)
                rec_res, _ = model(np.array(input_img))
                st.markdown(f"#### {number}")
                st.latex(rec_res)
                st.code(rec_res)