kmnis commited on
Commit
c9bfa1f
1 Parent(s): ddfa01f

Added app.py and data loader

Browse files
Files changed (2) hide show
  1. app.py +46 -0
  2. data_loader.py +127 -0
app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ import os
4
+
5
+ from PIL import Image
6
+
7
+ import tensorflow as tf
8
+ from tensorflow.keras.models import load_model
9
+ from tensorflow.keras.utils import array_to_img
10
+
11
+ import sys
12
+ sys.path.append(".")
13
+ from data_loader import preprocess_test_image
14
+
15
+ import warnings
16
+
17
+ warnings.filterwarnings("ignore")
18
+
19
+ @st.cache_data
20
+ def get_model():
21
+ model_path = "models/pix2pix.keras"
22
+ if not os.path.exists(model_path):
23
+ model_path = "../saved_models/pix2pix/pix2pix.keras"
24
+
25
+ with st.spinner('Loading the model...'):
26
+ pix2pix = load_model(model_path)
27
+ return pix2pix
28
+
29
+ st.markdown("<center><h1>ComicBooks.AI</h1></center>", unsafe_allow_html=True)
30
+ st.caption("<center>Upload your photo to see how a comic book version of yourself would look!</center>", unsafe_allow_html=True)
31
+
32
+ uploaded_file = st.file_uploader("Upload an image")
33
+
34
+ if uploaded_file is not None:
35
+ img = Image.open(uploaded_file)
36
+ img.save("uploaded_image.png")
37
+ st.image(uploaded_file)
38
+
39
+ img = preprocess_test_image("uploaded_image.png")
40
+ img = tf.expand_dims(img, axis=0)
41
+
42
+ pix2pix = get_model()
43
+ st.write("Model Loaded!!! Processing the image...")
44
+ pred = array_to_img(pix2pix.predict(img)[0] * 0.5 + 0.5)
45
+ st.image(pred)
46
+ _ = os.system("rm uploaded_image.png")
data_loader.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tensorflow as tf
3
+
4
+ # Define Training variable
5
+ BUFFER_SIZE = 400
6
+ BATCH_SIZE = 32
7
+ IMG_WIDTH = 256
8
+ IMG_HEIGHT = 256
9
+ AUTOTUNE = tf.data.AUTOTUNE
10
+
11
+
12
+ def load_images(image_file):
13
+ image = tf.io.read_file(image_file)
14
+ image = tf.image.decode_jpeg(image)
15
+
16
+ image = tf.cast(image, tf.float32)
17
+ return image
18
+
19
+
20
+ def resize(content_image, style_image, height, width):
21
+ content_image = tf.image.resize(content_image, [height, width],
22
+ method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
23
+ if style_image is not None:
24
+ style_image = tf.image.resize(style_image, [height, width],
25
+ method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
26
+
27
+ return content_image, style_image
28
+
29
+
30
+ def random_crop(content_image, style_image):
31
+ stacked_image = tf.stack([content_image, style_image], axis=0)
32
+ cropped_image = tf.image.random_crop(
33
+ stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])
34
+
35
+ return cropped_image[0], cropped_image[1]
36
+
37
+
38
+ def normalize(content_image, style_image):
39
+ content_image = (content_image / 127.5) - 1
40
+
41
+ if style_image is not None:
42
+ style_image = (style_image / 127.5) - 1
43
+
44
+ return content_image, style_image
45
+
46
+
47
+ @tf.function()
48
+ def random_jitter(content_image, style_image):
49
+ # resizing to 286 x 286 x 3
50
+ content_image, style_image = resize(content_image, style_image, 286, 286)
51
+
52
+ # randomly cropping to 256 x 256 x 3
53
+ content_image, style_image = random_crop(content_image, style_image)
54
+
55
+ if tf.random.uniform(()) > 0.5:
56
+ # random mirroring
57
+ content_image = tf.image.flip_left_right(content_image)
58
+ style_image = tf.image.flip_left_right(style_image)
59
+
60
+ return content_image, style_image
61
+
62
+
63
+ def preprocess_train_image(content_path, style_path):
64
+ content_image = load_images(content_path)
65
+ style_image = load_images(style_path)
66
+
67
+ content_image, style_image = random_jitter(content_image, style_image)
68
+ content_image, style_image = normalize(content_image, style_image)
69
+
70
+ return content_image, style_image
71
+
72
+
73
+ def preprocess_test_image(content_path, style_path=None):
74
+ content_image = load_images(content_path)
75
+
76
+ if style_path is None:
77
+ style_image = None
78
+ else:
79
+ style_image = load_images(style_path)
80
+
81
+ content_image, style_image = resize(content_image, style_image,
82
+ IMG_HEIGHT, IMG_WIDTH)
83
+ content_image, style_image = normalize(content_image, style_image)
84
+
85
+ if style_image is None:
86
+ return content_image
87
+ else:
88
+ return content_image, style_image
89
+
90
+
91
+ def create_image_loader(path):
92
+ images = os.listdir(path)
93
+ images = [os.path.join(path, p) for p in images]
94
+ images.sort()
95
+
96
+ # split the images in train and test
97
+ total_images = len(images)
98
+ train = images[: int(0.9 * total_images)]
99
+ test = images[int(0.9 * total_images):]
100
+
101
+ # Build the tf.data datasets.
102
+ train_ds = tf.data.Dataset.from_tensor_slices(train)
103
+ test_ds = tf.data.Dataset.from_tensor_slices(test)
104
+
105
+ return train_ds, test_ds
106
+
107
+
108
+ def data_loader(content_path="../data/face", style_path="../data/comics"):
109
+ train_content_ds, test_content_ds = create_image_loader(content_path)
110
+ train_style_ds, test_style_ds = create_image_loader(style_path)
111
+
112
+ # Zipping the style and content datasets.
113
+ train_ds = (
114
+ tf.data.Dataset.zip((train_content_ds, train_style_ds))
115
+ .map(preprocess_train_image)
116
+ .shuffle(BUFFER_SIZE)
117
+ .batch(BATCH_SIZE)
118
+ .prefetch(AUTOTUNE)
119
+ )
120
+
121
+ test_ds = (
122
+ tf.data.Dataset.zip((test_content_ds, test_style_ds))
123
+ .map(preprocess_test_image)
124
+ .batch(BATCH_SIZE)
125
+ .prefetch(AUTOTUNE)
126
+ )
127
+ return train_ds, test_ds