rishh76 committed on
Commit 5383233
1 Parent(s): 9c0ae66

Update app.py

Files changed (1)
  1. app.py +207 -0
app.py CHANGED
@@ -0,0 +1,207 @@
+ from typing import Tuple, Dict
+ import requests
+ import random
+ import numpy as np
+ import gradio as gr
+ import torch
+ from PIL import Image
+ from diffusers import FluxInpaintPipeline
+
+ # Constants
+ MARKDOWN_TEXT = """
+ # FLUX.1 Inpainting 🔥
+ Shoutout to [Black Forest Labs](https://huggingface.co/black-forest-labs) for
+ creating this amazing model, and a big thanks to [Gothos](https://github.com/Gothos)
+ for taking it to the next level by enabling inpainting with FLUX.
+ """
+
+ MAX_SEED_VALUE = np.iinfo(np.int32).max
+ DEFAULT_IMAGE_SIZE = 1024
+ DEVICE_TYPE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Model initialization
+ pipeline = FluxInpaintPipeline.from_pretrained(
+     "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE_TYPE)
+
+ def adjust_image_size(
+     original_size: Tuple[int, int], max_dimension: int = DEFAULT_IMAGE_SIZE
+ ) -> Tuple[int, int]:
+     # Scale the image so its longer side equals max_dimension, then round
+     # both dimensions down to the nearest multiple of 32 for the pipeline.
+     width, height = original_size
+     scaling_factor = max_dimension / max(width, height)
+     new_width = int(width * scaling_factor) - (int(width * scaling_factor) % 32)
+     new_height = int(height * scaling_factor) - (int(height * scaling_factor) % 32)
+     return new_width, new_height
+
+ def process_images(
+     input_data: Dict,
+     prompt: str,
+     seed: int,
+     randomize_seed: bool,
+     strength: float,
+     num_steps: int,
+     progress=gr.Progress(track_tqdm=True)
+ ):
+     if not prompt:
+         gr.Info("Please enter a text prompt.")
+         return None, None
+
+     # The ImageEditor returns a dict with the uploaded background and the
+     # painted layers; the first layer (if any) is used as the inpainting mask.
+     background_img = input_data['background']
+     mask_img = input_data['layers'][0] if input_data['layers'] else None
+
+     if background_img is None:
+         gr.Info("Please upload an image.")
+         return None, None
+
+     if mask_img is None:
+         gr.Info("Please draw a mask on the image.")
+         return None, None
+
+     new_width, new_height = adjust_image_size(background_img.size)
+     resized_bg = background_img.resize((new_width, new_height), Image.LANCZOS)
+     resized_mask = mask_img.resize((new_width, new_height), Image.LANCZOS)
+
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED_VALUE)
+     generator = torch.Generator().manual_seed(seed)
+
+     result_image = pipeline(
+         prompt=prompt,
+         image=resized_bg,
+         mask_image=resized_mask,
+         width=new_width,
+         height=new_height,
+         strength=strength,
+         generator=generator,
+         num_inference_steps=num_steps
+     ).images[0]
+
+     return result_image, resized_mask
+
+ # Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown(MARKDOWN_TEXT)
+
+     with gr.Row():
+         with gr.Column():
+             img_editor = gr.ImageEditor(
+                 label='Image',
+                 type='pil',
+                 sources=["upload", "webcam"],
+                 image_mode='RGB',
+                 layers=False,
+                 brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed")
+             )
+
+             with gr.Row():
+                 text_input = gr.Text(
+                     label="Prompt",
+                     show_label=False,
+                     max_lines=1,
+                     placeholder="Enter your prompt",
+                     container=False
+                 )
+                 submit_btn = gr.Button(
+                     value='Submit', variant='primary', scale=0
+                 )
+
+             with gr.Accordion("Advanced Settings", open=False):
+                 seed_slider = gr.Slider(
+                     label="Seed",
+                     minimum=0,
+                     maximum=MAX_SEED_VALUE,
+                     step=1,
+                     value=42
+                 )
+                 random_seed_chkbox = gr.Checkbox(
+                     label="Randomize seed", value=True
+                 )
+
+                 with gr.Row():
+                     strength_slider = gr.Slider(
+                         label="Strength",
+                         info="Indicates extent to transform the reference `image`.",
+                         minimum=0,
+                         maximum=1,
+                         step=0.01,
+                         value=0.85
+                     )
+                     steps_slider = gr.Slider(
+                         label="Number of inference steps",
+                         info="The number of denoising steps.",
+                         minimum=1,
+                         maximum=50,
+                         step=1,
+                         value=20
+                     )
+
+         with gr.Column():
+             output_img = gr.Image(
+                 type='pil', image_mode='RGB', label='Generated Image', format="png"
+             )
+             with gr.Accordion("Debug", open=False):
+                 output_mask = gr.Image(
+                     type='pil', image_mode='RGB', label='Input Mask', format="png"
+                 )
+
+     gr.Examples(
+         fn=process_images,
+         examples=[
+             [
+                 {
+                     "background": Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-image.png", stream=True).raw),
+                     "layers": [Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-mask-2.png", stream=True).raw).convert("RGBA")],
+                     "composite": Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-composite-2.png", stream=True).raw),
+                 },
+                 "little lion",
+                 42,
+                 False,
+                 0.85,
+                 30
+             ],
+             [
+                 {
+                     "background": Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-image.png", stream=True).raw),
+                     "layers": [Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-mask-3.png", stream=True).raw).convert("RGBA")],
+                     "composite": Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-composite-3.png", stream=True).raw),
+                 },
+                 "tribal tattoos",
+                 42,
+                 False,
+                 0.85,
+                 30
+             ]
+         ],
+         inputs=[
+             img_editor,
+             text_input,
+             seed_slider,
+             random_seed_chkbox,
+             strength_slider,
+             steps_slider
+         ],
+         outputs=[
+             output_img,
+             output_mask
+         ],
+         run_on_click=True,
+         cache_examples=True
+     )
+
+     submit_btn.click(
+         fn=process_images,
+         inputs=[
+             img_editor,
+             text_input,
+             seed_slider,
+             random_seed_chkbox,
+             strength_slider,
+             steps_slider
+         ],
+         outputs=[
+             output_img,
+             output_mask
+         ]
+     )
+
+ demo.launch(debug=False, show_error=True)
+
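
For reference, here is a small standalone sketch of the resizing arithmetic used by `adjust_image_size` above: the longer side of the upload is scaled to `DEFAULT_IMAGE_SIZE` (1024) and both sides are then rounded down to the nearest multiple of 32 before the image and mask reach the pipeline. The `preview_resize` name is illustrative, not part of app.py.

```python
from typing import Tuple


def preview_resize(original_size: Tuple[int, int], max_dimension: int = 1024) -> Tuple[int, int]:
    """Mirror adjust_image_size: scale the longer side to max_dimension,
    then round both sides down to the nearest multiple of 32."""
    width, height = original_size
    scale = max_dimension / max(width, height)

    def snap(value: int) -> int:
        # Round down to a multiple of 32, as the app does with `% 32`.
        return value - (value % 32)

    return snap(int(width * scale)), snap(int(height * scale))


print(preview_resize((3000, 2000)))   # (1024, 672): 2000 scales to ~682, snapped down to 672
print(preview_resize((1080, 1920)))   # (576, 1024): portrait uploads keep their orientation
```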
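
The same inpainting call can also be sketched without the Gradio UI, mirroring the arguments `process_images` passes to the pipeline; the input and output file names below are placeholders rather than files shipped with this Space.

```python
import torch
from PIL import Image
from diffusers import FluxInpaintPipeline

# Same model and dtype as in app.py; falls back to CPU when CUDA is unavailable.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to(device)

# Placeholder inputs: an RGB image plus a mask whose white region marks the
# area to repaint, both resized to dimensions that are multiples of 32.
image = Image.open("input.png").convert("RGB").resize((1024, 1024), Image.LANCZOS)
mask = Image.open("mask.png").convert("L").resize((1024, 1024), Image.LANCZOS)

result = pipe(
    prompt="little lion",
    image=image,
    mask_image=mask,
    width=1024,
    height=1024,
    strength=0.85,
    generator=torch.Generator().manual_seed(42),
    num_inference_steps=30,
).images[0]
result.save("inpainted.png")
```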