owaiskaifi commited on
Commit
50c3b64
1 Parent(s): 95fe902

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -157
app.py CHANGED
@@ -1,53 +1,26 @@
1
- import torch
2
- import gradio as gr
3
- from PIL import Image
4
- import qrcode
5
- from pathlib import Path
6
- from multiprocessing import cpu_count
7
- import requests
8
- import io
9
- import os
10
-
11
-
12
  from PIL import Image
 
13
  from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, DDIMScheduler
14
  from diffusers.utils import load_image
15
 
 
16
 
17
- from diffusers import (
18
- StableDiffusionPipeline,
19
- StableDiffusionControlNetImg2ImgPipeline,
20
- ControlNetModel,
21
- DDIMScheduler,
22
- DPMSolverMultistepScheduler,
23
- DEISMultistepScheduler,
24
- HeunDiscreteScheduler,
25
- EulerDiscreteScheduler,
26
- )
27
-
28
- qrcode_generator = qrcode.QRCode(
29
- version=1,
30
- error_correction=qrcode.ERROR_CORRECT_H,
31
- box_size=10,
32
- border=4,
33
- )
34
-
35
- controlnet = ControlNetModel.from_pretrained(
36
- "DionTimmer/controlnet_qrcode-control_v1p_sd15", torch_dtype=torch.float16
37
- )
38
 
39
  pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
40
  "runwayml/stable-diffusion-v1-5",
41
  controlnet=controlnet,
42
  safety_checker=None,
43
- torch_dtype=torch.float16,
44
- ) #.to("cuda")
 
45
  pipe.enable_xformers_memory_efficient_attention()
46
  pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
47
  pipe.enable_model_cpu_offload()
48
 
49
-
50
- def resize_for_condition_image(input_image: Image.Image, resolution: int):
51
  input_image = input_image.convert("RGB")
52
  W, H = input_image.size
53
  k = float(resolution) / min(H, W)
@@ -58,126 +31,38 @@ def resize_for_condition_image(input_image: Image.Image, resolution: int):
58
  img = input_image.resize((W, H), resample=Image.LANCZOS)
59
  return img
60
 
61
-
62
- SAMPLER_MAP = {
63
- "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(config, use_karras=True, algorithm_type="sde-dpmsolver++"),
64
- "DPM++ Karras": lambda config: DPMSolverMultistepScheduler.from_config(config, use_karras=True),
65
- "Heun": lambda config: HeunDiscreteScheduler.from_config(config),
66
- "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
67
- "DDIM": lambda config: DDIMScheduler.from_config(config),
68
- "DEIS": lambda config: DEISMultistepScheduler.from_config(config),
69
- }
70
-
71
-
72
- def inference(
73
- qr_code_content: str,
74
- prompt: str,
75
- negative_prompt: str,
76
- guidance_scale: float = 10.0,
77
- controlnet_conditioning_scale: float = 2.0,
78
- strength: float = 0.8,
79
- seed: int = -1,
80
- init_image: Image.Image | None = None,
81
- qrcode_image: Image.Image | None = None,
82
- use_qr_code_as_init_image = True,
83
- sampler = "DPM++ Karras SDE",
84
- ):
85
- if prompt is None or prompt == "":
86
- raise gr.Error("Prompt is required")
87
-
88
- if qrcode_image is None and qr_code_content == "":
89
- raise gr.Error("QR Code Image or QR Code Content is required")
90
-
91
- pipe.scheduler = SAMPLER_MAP[sampler](pipe.scheduler.config)
92
-
93
- generator = torch.manual_seed(seed) if seed != -1 else torch.Generator()
94
-
95
- if qr_code_content != "" or qrcode_image.size == (1, 1):
96
- qr = qrcode.QRCode(
97
- version=1,
98
- error_correction=qrcode.constants.ERROR_CORRECT_H,
99
- box_size=10,
100
- border=4,
101
- )
102
- qr.add_data(qr_code_content)
103
- qr.make(fit=True)
104
- qrcode_image = qr.make_image(fill_color="black", back_color="white")
105
-
106
- if init_image is None:
107
- if use_qr_code_as_init_image:
108
- init_image = qrcode_image.convert("RGB")
109
-
110
- resolution = controlnet.config.resolution
111
- qrcode_image = resize_for_condition_image(qrcode_image, resolution)
112
- if init_image is not None:
113
- init_image = init_image.convert("RGB")
114
- init_image = resize_for_condition_image(init_image, resolution)
115
- init_image = torch.nn.functional.interpolate(
116
- torch.nn.functional.to_tensor(init_image).unsqueeze(0),
117
- size=(resolution, resolution),
118
- mode="bilinear",
119
- align_corners=False,
120
- )[0].unsqueeze(0)
121
- else:
122
- init_image = torch.zeros(
123
- (1, 3, resolution, resolution), device=pipe.device
124
- ).to(dtype=torch.float32)
125
-
126
- with torch.no_grad():
127
- result_image = pipe(
128
- qr_code_condition=qrcode_image,
129
- prompt=prompt,
130
- negative_prompt=negative_prompt,
131
- init_image=init_image,
132
- strength=strength,
133
- guidance_scale=guidance_scale,
134
- controlnet_conditioning_scale=controlnet_conditioning_scale,
135
- disable_progress_bar=True,
136
- seed=generator,
137
- ).cpu()
138
-
139
- result_image = (
140
- result_image.clamp(-1, 1).squeeze().permute(1, 2, 0).numpy() * 255
141
- )
142
- result_image = Image.fromarray(result_image.astype("uint8"))
143
-
144
- return result_image
145
-
146
-
147
- app = Flask(__name__)
148
-
149
- @app.route('/generate_qr_code', methods=['POST'])
150
- def generate_qr_code():
151
- qr_code_content = request.json['qr_code_content']
152
- prompt = request.json['prompt']
153
- negative_prompt = request.json['negative_prompt']
154
- guidance_scale = float(request.json.get('guidance_scale', 10.0))
155
- controlnet_conditioning_scale = float(request.json.get('controlnet_conditioning_scale', 2.0))
156
- strength = float(request.json.get('strength', 0.8))
157
- seed = int(request.json.get('seed', -1))
158
- init_image = None
159
- qrcode_image = None
160
- use_qr_code_as_init_image = request.json.get('use_qr_code_as_init_image', True)
161
- sampler = request.json.get('sampler', 'DPM++ Karras SDE')
162
-
163
- try:
164
- result_image = inference(qr_code_content, prompt, negative_prompt, guidance_scale,
165
- controlnet_conditioning_scale, strength, seed, init_image,
166
- qrcode_image, use_qr_code_as_init_image, sampler)
167
-
168
- image_bytes = io.BytesIO()
169
- result_image.save(image_bytes, format='PNG')
170
- image_base64 = base64.b64encode(image_bytes.getvalue()).decode('utf-8')
171
-
172
- return jsonify({'image_base64': image_base64})
173
- except Exception as e:
174
- return jsonify({'error': str(e)}), 500
175
-
176
-
177
- @app.route('/health', methods=['GET'])
178
- def health_check():
179
- return 'OK'
180
-
181
 
182
  if __name__ == '__main__':
183
- app.run(host='0.0.0.0', port=7860)
 
1
+ from flask import Flask, request, jsonify
 
 
 
 
 
 
 
 
 
 
2
  from PIL import Image
3
+ import torch
4
  from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, DDIMScheduler
5
  from diffusers.utils import load_image
6
 
7
app = Flask(__name__)

# QR-code ControlNet weights; fp16 halves memory and matches the pipeline dtype below.
controlnet = ControlNetModel.from_pretrained("DionTimmer/controlnet_qrcode-control_v1p_sd15",
                                             torch_dtype=torch.float16)

# Img2img Stable Diffusion pipeline conditioned by the ControlNet above.
# NOTE(review): safety_checker=None disables the NSFW filter entirely — confirm intended.
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    safety_checker=None,
    torch_dtype=torch.float16
)

# Memory-efficient attention via xformers to reduce VRAM use during inference.
pipe.enable_xformers_memory_efficient_attention()
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
# Keep idle sub-models on the CPU and move them to the accelerator only when needed.
pipe.enable_model_cpu_offload()
22
 
23
+ def resize_for_condition_image(input_image: Image, resolution: int):
 
24
  input_image = input_image.convert("RGB")
25
  W, H = input_image.size
26
  k = float(resolution) / min(H, W)
 
31
  img = input_image.resize((W, H), resample=Image.LANCZOS)
32
  return img
33
 
34
@app.route('/generate_image', methods=['POST'])
def generate_image():
    """Generate a ControlNet-guided image from two input image URLs.

    Expects a JSON body with:
        prompt (str): text prompt for generation (required).
        negative_prompt (str): text to steer generation away from.
        image_url (str): URL of the image used as the ControlNet condition (required).
        control_image_url (str): URL of the init image for img2img (required).

    Returns:
        200 with JSON {'image': <base64-encoded PNG>} on success,
        400 with JSON {'error': ...} on missing fields or unloadable images.
    """
    import base64
    import io

    data = request.get_json(silent=True) or {}
    prompt = data.get('prompt')
    negative_prompt = data.get('negative_prompt')
    image_url = data.get('image_url')
    control_image_url = data.get('control_image_url')

    # Fail fast with a clear 400 instead of crashing inside load_image.
    if not prompt or not image_url or not control_image_url:
        return jsonify({'error': 'prompt, image_url and control_image_url are required'}), 400

    try:
        # NOTE(review): image_url feeds the ControlNet condition while
        # control_image_url feeds the img2img init image — the naming looks
        # swapped, but this preserves the original mapping; confirm with callers.
        source_image = load_image(image_url)
        init_image = load_image(control_image_url)
    except Exception as e:
        return jsonify({'error': f'failed to load input image: {e}'}), 400

    # Resize both inputs to the pipeline's 768px working resolution.
    condition_image = resize_for_condition_image(source_image, 768)
    init_image = resize_for_condition_image(init_image, 768)

    # Fixed seed keeps the endpoint deterministic for identical inputs.
    generator = torch.manual_seed(123121231)
    result = pipe(prompt=prompt,
                  negative_prompt=negative_prompt,
                  image=init_image,
                  control_image=condition_image,
                  width=768,
                  height=768,
                  guidance_scale=20,
                  controlnet_conditioning_scale=1.5,
                  generator=generator,
                  strength=0.9,
                  num_inference_steps=150)

    # BUG FIX: a PIL Image is not JSON-serializable, so jsonify({'image': ...})
    # raised on every successful generation. Encode the result as a base64 PNG,
    # matching the behaviour of the pre-refactor endpoint.
    buffer = io.BytesIO()
    result.images[0].save(buffer, format='PNG')
    image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return jsonify({'image': image_base64})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
if __name__ == '__main__':
    # Bind to all interfaces on port 7860 — the address this app served on
    # before this commit. The bare app.run() introduced here listens only on
    # localhost:5000, which makes the service unreachable from outside its host.
    app.run(host='0.0.0.0', port=7860)