owaiskaifi commited on
Commit
1474dee
1 Parent(s): 07c3544

add main files

Browse files
Files changed (3) hide show
  1. Dockerfile +20 -0
  2. app.py +176 -0
  3. requirements.txt +9 -0
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Base image
2
+ FROM python:3.9-slim-buster
3
+
4
+ # Set the working directory
5
+ WORKDIR /app
6
+
7
+ # Copy the requirements.txt file
8
+ COPY requirements.txt .
9
+
10
+ # Install the dependencies
11
+ RUN pip install --no-cache-dir -r requirements.txt
12
+
13
+ # Copy the code into the container
14
+ COPY . .
15
+
16
+ # Expose the port
17
+ EXPOSE 7860
18
+
19
+ # Run the application
20
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from PIL import Image
4
+ import qrcode
5
+ from pathlib import Path
6
+ from multiprocessing import cpu_count
7
+ import requests
8
+ import io
9
+ import os
10
+ from PIL import Image
11
+
12
+ from diffusers import (
13
+ StableDiffusionPipeline,
14
+ StableDiffusionControlNetImg2ImgPipeline,
15
+ ControlNetModel,
16
+ DDIMScheduler,
17
+ DPMSolverMultistepScheduler,
18
+ DEISMultistepScheduler,
19
+ HeunDiscreteScheduler,
20
+ EulerDiscreteScheduler,
21
+ )
22
+
23
+ qrcode_generator = qrcode.QRCode(
24
+ version=1,
25
+ error_correction=qrcode.ERROR_CORRECT_H,
26
+ box_size=10,
27
+ border=4,
28
+ )
29
+
30
+ controlnet = ControlNetModel.from_pretrained(
31
+ "DionTimmer/controlnet_qrcode-control_v1p_sd15", torch_dtype=torch.float16
32
+ )
33
+
34
+ pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
35
+ "runwayml/stable-diffusion-v1-5",
36
+ controlnet=controlnet,
37
+ safety_checker=None,
38
+ torch_dtype=torch.float16,
39
+ ).to("cuda")
40
+ pipe.enable_xformers_memory_efficient_attention()
41
+
42
+
43
+ def resize_for_condition_image(input_image: Image.Image, resolution: int):
44
+ input_image = input_image.convert("RGB")
45
+ W, H = input_image.size
46
+ k = float(resolution) / min(H, W)
47
+ H *= k
48
+ W *= k
49
+ H = int(round(H / 64.0)) * 64
50
+ W = int(round(W / 64.0)) * 64
51
+ img = input_image.resize((W, H), resample=Image.LANCZOS)
52
+ return img
53
+
54
+
55
+ SAMPLER_MAP = {
56
+ "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(config, use_karras=True, algorithm_type="sde-dpmsolver++"),
57
+ "DPM++ Karras": lambda config: DPMSolverMultistepScheduler.from_config(config, use_karras=True),
58
+ "Heun": lambda config: HeunDiscreteScheduler.from_config(config),
59
+ "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
60
+ "DDIM": lambda config: DDIMScheduler.from_config(config),
61
+ "DEIS": lambda config: DEISMultistepScheduler.from_config(config),
62
+ }
63
+
64
+
65
+ def inference(
66
+ qr_code_content: str,
67
+ prompt: str,
68
+ negative_prompt: str,
69
+ guidance_scale: float = 10.0,
70
+ controlnet_conditioning_scale: float = 2.0,
71
+ strength: float = 0.8,
72
+ seed: int = -1,
73
+ init_image: Image.Image | None = None,
74
+ qrcode_image: Image.Image | None = None,
75
+ use_qr_code_as_init_image = True,
76
+ sampler = "DPM++ Karras SDE",
77
+ ):
78
+ if prompt is None or prompt == "":
79
+ raise gr.Error("Prompt is required")
80
+
81
+ if qrcode_image is None and qr_code_content == "":
82
+ raise gr.Error("QR Code Image or QR Code Content is required")
83
+
84
+ pipe.scheduler = SAMPLER_MAP[sampler](pipe.scheduler.config)
85
+
86
+ generator = torch.manual_seed(seed) if seed != -1 else torch.Generator()
87
+
88
+ if qr_code_content != "" or qrcode_image.size == (1, 1):
89
+ qr = qrcode.QRCode(
90
+ version=1,
91
+ error_correction=qrcode.constants.ERROR_CORRECT_H,
92
+ box_size=10,
93
+ border=4,
94
+ )
95
+ qr.add_data(qr_code_content)
96
+ qr.make(fit=True)
97
+ qrcode_image = qr.make_image(fill_color="black", back_color="white")
98
+
99
+ if init_image is None:
100
+ if use_qr_code_as_init_image:
101
+ init_image = qrcode_image.convert("RGB")
102
+
103
+ resolution = controlnet.config.resolution
104
+ qrcode_image = resize_for_condition_image(qrcode_image, resolution)
105
+ if init_image is not None:
106
+ init_image = init_image.convert("RGB")
107
+ init_image = resize_for_condition_image(init_image, resolution)
108
+ init_image = torch.nn.functional.interpolate(
109
+ torch.nn.functional.to_tensor(init_image).unsqueeze(0),
110
+ size=(resolution, resolution),
111
+ mode="bilinear",
112
+ align_corners=False,
113
+ )[0].unsqueeze(0)
114
+ else:
115
+ init_image = torch.zeros(
116
+ (1, 3, resolution, resolution), device=pipe.device
117
+ ).to(dtype=torch.float32)
118
+
119
+ with torch.no_grad():
120
+ result_image = pipe(
121
+ qr_code_condition=qrcode_image,
122
+ prompt=prompt,
123
+ negative_prompt=negative_prompt,
124
+ init_image=init_image,
125
+ strength=strength,
126
+ guidance_scale=guidance_scale,
127
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
128
+ disable_progress_bar=True,
129
+ seed=generator,
130
+ ).cpu()
131
+
132
+ result_image = (
133
+ result_image.clamp(-1, 1).squeeze().permute(1, 2, 0).numpy() * 255
134
+ )
135
+ result_image = Image.fromarray(result_image.astype("uint8"))
136
+
137
+ return result_image
138
+
139
+
140
+ app = Flask(__name__)
141
+
142
+ @app.route('/generate_qr_code', methods=['POST'])
143
+ def generate_qr_code():
144
+ qr_code_content = request.json['qr_code_content']
145
+ prompt = request.json['prompt']
146
+ negative_prompt = request.json['negative_prompt']
147
+ guidance_scale = float(request.json.get('guidance_scale', 10.0))
148
+ controlnet_conditioning_scale = float(request.json.get('controlnet_conditioning_scale', 2.0))
149
+ strength = float(request.json.get('strength', 0.8))
150
+ seed = int(request.json.get('seed', -1))
151
+ init_image = None
152
+ qrcode_image = None
153
+ use_qr_code_as_init_image = request.json.get('use_qr_code_as_init_image', True)
154
+ sampler = request.json.get('sampler', 'DPM++ Karras SDE')
155
+
156
+ try:
157
+ result_image = inference(qr_code_content, prompt, negative_prompt, guidance_scale,
158
+ controlnet_conditioning_scale, strength, seed, init_image,
159
+ qrcode_image, use_qr_code_as_init_image, sampler)
160
+
161
+ image_bytes = io.BytesIO()
162
+ result_image.save(image_bytes, format='PNG')
163
+ image_base64 = base64.b64encode(image_bytes.getvalue()).decode('utf-8')
164
+
165
+ return jsonify({'image_base64': image_base64})
166
+ except Exception as e:
167
+ return jsonify({'error': str(e)}), 500
168
+
169
+
170
+ @app.route('/health', methods=['GET'])
171
+ def health_check():
172
+ return 'OK'
173
+
174
+
175
+ if __name__ == '__main__':
176
+ app.run(host='0.0.0.0', port=7860)
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch==1.9.0
2
+ gradio==2.0.8
3
+ Pillow==8.3.2
4
+ qrcode==7.3
5
+ pathlib==1.0.1
6
+ requests==2.26.0
7
+ diffusers==0.0.1
8
+ Flask==2.0.2
9
+ uvicorn