import os
import random
import base64
import gradio as gr
import numpy as np
import PIL.Image
from PIL import ImageOps
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
import torchvision.transforms.functional as TF
from diffusers import (
AutoencoderKL,
EulerAncestralDiscreteScheduler,
StableDiffusionXLAdapterPipeline,
T2IAdapter,
)
import urllib.parse
import requests
from io import BytesIO
import uuid
from azure.storage.blob import BlobServiceClient
from datetime import datetime
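# Default prompts and generation parameters, shared by the UI and the API endpoints.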
class DEFAULTS:
    NEGATIVE_PROMPT = "extra digit, fewer digits, cropped, worst quality, low quality, glitch, deformed, mutated, ugly, disfigured"
REWRITING_PROMPT = (
"Rewrite the image caption by making it shorter (but retain all information about relative position), "
"remove information about style of objects or colors of background and foreground, and, most importantly, remove all details "
"that suggests it is a sketch. Write it as a Google image search query:"
)
MOONDREAM_PROMPT = "Describe this image."
NUM_STEPS = 25
GUIDANCE_SCALE = 5
ADAPTER_CONDITIONING_SCALE = 0.8
ADAPTER_CONDITIONING_FACTOR = 0.8
SEED = 1231245
RANDOMIZE_SEED = True
DESCRIPTION = '''# Sketch to Image/Caption to Bing Search :)
This is a test space for the Sketch to Image/Caption to Bing Search pipeline. Draw a sketch on the left, provide a prompt, and select a style. The model generates an image from your sketch and prompt, then builds a Bing search query from the generated image.
'''
if not torch.cuda.is_available():
DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
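# Per-style prompt templates; "{prompt}" is replaced with the drawing prompt in apply_style.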
style_list = [
{
"name": "(No style)",
"prompt": "{prompt}",
"negative_prompt": "",
},
{
"name": "Cinematic",
"prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
"negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
},
{
"name": "3D Model",
"prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
"negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
},
{
"name": "Anime",
"prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
"negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
},
{
"name": "Digital Art",
"prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
"negative_prompt": "photo, photorealistic, realism, ugly",
},
{
"name": "Photographic",
"prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
"negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
},
{
"name": "Pixel art",
"prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
"negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
},
{
"name": "Fantasy art",
"prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
"negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
},
{
"name": "Neonpunk",
"prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
"negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
},
{
"name": "Manga",
"prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
"negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
},
]
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "Photographic"
def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
    p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    # Join the style's negative prompt with the user-supplied one, comma-separated.
    negative = ", ".join(s.strip() for s in (n, negative) if s.strip())
    return p.replace("{prompt}", positive), negative
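# Azure Blob Storage connection string: read from a local file if present,
# otherwise from the AZURE_CONNECTION_STRING environment variable.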
if os.path.exists("azure_connection_string.txt"):
with open("azure_connection_string.txt", "r") as f:
CONNECTION_STRING = f.read().strip()
else:
CONNECTION_STRING = os.getenv("AZURE_CONNECTION_STRING")
def upload_pil_image_to_azure(image, connection_string=CONNECTION_STRING):
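    """Upload a PIL image to the blob-image-hosting container and return its URL, or None on failure."""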
image_name = f"{uuid.uuid4()}.png"
image_bytes = BytesIO()
image.save(image_bytes, format="PNG")
image_bytes.seek(0)
try:
# Create the BlobServiceClient object
blob_service_client = BlobServiceClient.from_connection_string(connection_string)
# Create a blob client using the local file name as the name for the blob
blob_client = blob_service_client.get_blob_client(container="blob-image-hosting", blob=image_name)
# Upload the created file and retrieve the URL
blob_client.upload_blob(image_bytes)
file_url = blob_client.url
    except Exception as ex:
        print("Exception while uploading to Azure Blob Storage:")
        print(ex)
        file_url = None
    # Return the blob URL on success, or None if the upload failed.
    return file_url
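# Place the SDXL pipeline on device_0 and the captioning/rewriting models on
# device_1 (a second GPU when available, otherwise the same device).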
if torch.cuda.is_available():
if torch.cuda.device_count() > 1:
device_0, device_1 = torch.device("cuda:0"), torch.device("cuda:1")
else:
device_0, device_1 = torch.device("cuda:0"), torch.device("cuda:0")
else:
device_0, device_1 = torch.device("cpu"), torch.device("cpu")
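# Load the SDXL base pipeline with the TencentARC sketch T2I-Adapter and the fp16-safe VAE.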
if torch.cuda.is_available():
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
adapter = T2IAdapter.from_pretrained(
"TencentARC/t2i-adapter-sketch-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
)
scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
model_id,
vae=AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16),
adapter=adapter,
scheduler=scheduler,
torch_dtype=torch.float16,
variant="fp16",
)
pipe.to(device_0)
else:
pipe = None
MAX_SEED = np.iinfo(np.int32).max
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
if randomize_seed:
seed = random.randint(0, MAX_SEED)
return seed
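# 4-bit NF4 quantization config (bitsandbytes) for the caption-rewriting LLM.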
nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.bfloat16
)
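# Moondream2 vision-language model for captioning sketches and generated images.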
vlmodel_id = "vikhyatk/moondream2"
vlmodel_revision = "2024-07-23"
vlmodel = AutoModelForCausalLM.from_pretrained(
vlmodel_id, trust_remote_code=True, revision=vlmodel_revision, device_map={"": device_1},
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
)
vltokenizer = AutoTokenizer.from_pretrained(vlmodel_id, revision=vlmodel_revision)
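# Llama 3.1 8B Instruct, loaded in 4-bit NF4, rewrites captions into search queries.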
rewrite_model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
rewrite_model = AutoModelForCausalLM.from_pretrained(
rewrite_model_name,
device_map={"": device_1},
quantization_config=nf4_config,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
)
rewrite_tokenizer = AutoTokenizer.from_pretrained(rewrite_model_name)
def caption_image_with_recaption(pil_image, moondream_prompt, rewriting_prompt, user_prompt=""):
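    """Caption the image with Moondream, then rewrite the caption into a search query."""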
enc_image = vlmodel.encode_image(pil_image)
img_caption = vlmodel.answer_question(enc_image, moondream_prompt, vltokenizer)
rewritten_caption = rewrite_prompt(img_caption, rewriting_prompt, user_prompt=user_prompt)
rewritten_caption = rewritten_caption.strip('"').replace("\n", " ")
return img_caption, rewritten_caption
def rewrite_prompt(image_cap: str, guide: str, user_prompt: str = "") -> str:
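    """Ask the rewriter LLM to condense an image caption into a search query.

    NOTE: user_prompt is accepted by callers but is not currently used here.
    """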
prompt = f"{guide}\n{image_cap}"
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt}
]
text = rewrite_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
model_inputs = rewrite_tokenizer([text], return_tensors="pt").to(device_1)
generated_ids = rewrite_model.generate(model_inputs.input_ids, max_new_tokens=128)
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]
response = rewrite_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
return response
def run_full(
image,
user_prompt: str,
negative_prompt: str,
rewriting_prompt: str,
moondream_prompt: str,
style_name: str = DEFAULT_STYLE_NAME,
num_steps: int = 25,
guidance_scale: float = 5,
adapter_conditioning_scale: float = 0.8,
adapter_conditioning_factor: float = 0.8,
seed: int = 0,
progress=None,
) -> dict:
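    """Preprocess the sketch, caption it if no user prompt is given, generate an
    image with SDXL + T2I-Adapter, and build Bing text/visual search URLs.

    Returns a dict with keys "image", "text_search_url", "visual_search_url", and "logs".
    """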
# image is a white background with black sketch
image = ImageOps.invert(image)
# resize to 1024x1024
image = image.resize((1024, 1024))
# Threshold the image to get a binary sketch
image = TF.to_tensor(image) > 0.5
image = TF.to_pil_image(image.to(torch.float32))
full_log = []
if user_prompt == "":
pre_caption = True
start_time = datetime.now()
img_caption, rewritten_caption = caption_image_with_recaption(
pil_image=image, rewriting_prompt=rewriting_prompt, moondream_prompt=moondream_prompt)
full_log.append(f"Combined captioning time: {datetime.now() - start_time}")
full_log.append(f"img_caption (pre): {img_caption}")
full_log.append(f"rewritten_caption (pre): {rewritten_caption}")
drawing_prompt = rewritten_caption
else:
pre_caption = False
drawing_prompt = user_prompt
full_log.append(f"Pre-caption: {pre_caption}")
# Generate image
start_time = datetime.now()
drawing_prompt, negative_prompt = apply_style(style_name, drawing_prompt, negative_prompt)
generator = torch.Generator(device=device_0).manual_seed(seed)
out_img = pipe(
prompt=drawing_prompt,
negative_prompt=negative_prompt,
image=image,
num_inference_steps=num_steps,
generator=generator,
guidance_scale=guidance_scale,
adapter_conditioning_scale=adapter_conditioning_scale,
adapter_conditioning_factor=adapter_conditioning_factor,
).images[0]
full_log.append(f"Image generation time: {datetime.now() - start_time}")
if not pre_caption:
start_time = datetime.now()
img_caption, rewritten_caption = caption_image_with_recaption(
pil_image=out_img,
rewriting_prompt=rewriting_prompt,
moondream_prompt=moondream_prompt,
user_prompt=user_prompt)
full_log.append(f"Combined captioning time: {datetime.now() - start_time}")
full_log.append(f"img_caption (post): {img_caption}")
full_log.append(f"rewritten_caption (post): {rewritten_caption}")
# SERP query
bing_serp_query = f"https://www.bing.com/images/search?q={urllib.parse.quote(rewritten_caption)}"
md_text = f"### Bing search query\n[{bing_serp_query}]({bing_serp_query})\n"
# Visual Search query
    out_img_blob_url = upload_pil_image_to_azure(out_img)
    if out_img_blob_url is None:
        md_text += "### Bing Visual Search\n**Error:** Failed to upload image to Azure Blob Storage\n"
        bing_image_search_url = "https://www.bing.com/images"
    else:
        blob_url_quote = urllib.parse.quote(out_img_blob_url)
        bing_image_search_url = f"https://www.bing.com/images/search?view=detailv2&iss=SBI&form=SBIIRP&q=imgurl:{blob_url_quote}"
md_text += f"### Bing Visual Search\n[{bing_image_search_url}]({bing_image_search_url})\n"
# Debug info
md_text += f"### Debug: sketch caption\n{img_caption}\n\n### Debug: rewritten caption\n{rewritten_caption}\n"
# Full log dump
md_text += f"### Debug: full log\n{'<br>'.join(full_log)}"
# return dict
return {
"image": out_img,
"text_search_url": bing_serp_query,
"visual_search_url": bing_image_search_url,
"logs": md_text,
}
def run_full_gradio(
image,
user_prompt: str,
negative_prompt: str,
rewriting_prompt: str,
moondream_prompt: str,
style_name: str = DEFAULT_STYLE_NAME,
num_steps: int = 25,
guidance_scale: float = 5,
adapter_conditioning_scale: float = 0.8,
adapter_conditioning_factor: float = 0.8,
seed: int = 0,
progress=gr.Progress(track_tqdm=True),
) -> tuple[PIL.Image.Image, str]:
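    # Flatten the sketchpad's RGBA composite onto a white background before processing.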
image = image['composite']
background = PIL.Image.new('RGBA', image.size, (255, 255, 255))
alpha_composite = PIL.Image.alpha_composite(background, image)
image = alpha_composite.convert("RGB")
results = run_full(
image=image,
user_prompt=user_prompt,
negative_prompt=negative_prompt,
rewriting_prompt=rewriting_prompt,
moondream_prompt=moondream_prompt,
style_name=style_name,
num_steps=num_steps,
guidance_scale=guidance_scale,
adapter_conditioning_scale=adapter_conditioning_scale,
adapter_conditioning_factor=adapter_conditioning_factor,
seed=seed,
progress=progress,
)
    # Return the generated image and the markdown log text.
return results["image"], results["logs"]
def run_full_api(
image_url: str,
image_bytes: str,
user_prompt: str,
progress=gr.Progress(track_tqdm=True),
) -> tuple[str, str, str]:
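    """API endpoint: run the full pipeline on an image passed by URL or as base64 bytes."""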
seed = randomize_seed_fn(0, True)
    if image_url:
        image = PIL.Image.open(BytesIO(requests.get(image_url).content))
    elif image_bytes:
        decoded_image = base64.b64decode(image_bytes)
        image = PIL.Image.open(BytesIO(decoded_image))
    else:
        raise ValueError("Either image_url or image_bytes must be provided.")
# if image is RGBA, convert to RGB
if image.mode == "RGBA":
background = PIL.Image.new('RGBA', image.size, (255, 255, 255))
alpha_composite = PIL.Image.alpha_composite(background, image)
image = alpha_composite.convert("RGB")
results = run_full(
image=image, user_prompt=user_prompt,
negative_prompt=DEFAULTS.NEGATIVE_PROMPT,
rewriting_prompt=DEFAULTS.REWRITING_PROMPT,
moondream_prompt=DEFAULTS.MOONDREAM_PROMPT,
style_name=DEFAULT_STYLE_NAME,
num_steps=DEFAULTS.NUM_STEPS,
guidance_scale=DEFAULTS.GUIDANCE_SCALE,
adapter_conditioning_scale=DEFAULTS.ADAPTER_CONDITIONING_SCALE,
adapter_conditioning_factor=DEFAULTS.ADAPTER_CONDITIONING_FACTOR,
seed=seed)
return results["text_search_url"], results["visual_search_url"], results["logs"]
def run_caponly(
image,
rewriting_prompt: str,
moondream_prompt: str,
seed: int = 0,
progress=None,
) -> dict:
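    """Caption the sketch and build a Bing text-search URL, without image generation.

    Returns a dict with keys "text_search_url" and "logs".
    """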
# image is a white background with black sketch
image = ImageOps.invert(image)
# resize to 1024x1024
image = image.resize((1024, 1024))
# Threshold the image to get a binary sketch
image = TF.to_tensor(image) > 0.5
image = TF.to_pil_image(image.to(torch.float32))
full_log = []
start_time = datetime.now()
img_caption, rewritten_caption = caption_image_with_recaption(
pil_image=image, rewriting_prompt=rewriting_prompt, moondream_prompt=moondream_prompt)
full_log.append(f"Combined captioning time: {datetime.now() - start_time}")
full_log.append(f"img_caption (pre): {img_caption}")
full_log.append(f"rewritten_caption (pre): {rewritten_caption}")
# SERP query
bing_serp_query = f"https://www.bing.com/images/search?q={urllib.parse.quote(rewritten_caption)}"
md_text = f"### Bing search query\n[{bing_serp_query}]({bing_serp_query})\n"
# Debug info
md_text += f"### Debug: sketch caption\n{img_caption}\n\n### Debug: rewritten caption\n{rewritten_caption}\n"
# Full log dump
md_text += f"### Debug: full log\n{'<br>'.join(full_log)}"
# return dict
return {
"text_search_url": bing_serp_query,
"logs": md_text,
}
def run_caponly_api(
image_url: str,
image_bytes: str,
progress=gr.Progress(track_tqdm=True),
) -> tuple[str, str]:
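    """API endpoint: run the caption-only pipeline on an image passed by URL or as base64 bytes."""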
seed = randomize_seed_fn(0, True)
    if image_url:
        image = PIL.Image.open(BytesIO(requests.get(image_url).content))
    elif image_bytes:
        decoded_image = base64.b64decode(image_bytes)
        image = PIL.Image.open(BytesIO(decoded_image))
    else:
        raise ValueError("Either image_url or image_bytes must be provided.")
# if image is RGBA, convert to RGB
if image.mode == "RGBA":
background = PIL.Image.new('RGBA', image.size, (255, 255, 255))
alpha_composite = PIL.Image.alpha_composite(background, image)
image = alpha_composite.convert("RGB")
results = run_caponly(
image=image,
rewriting_prompt=DEFAULTS.REWRITING_PROMPT,
moondream_prompt=DEFAULTS.MOONDREAM_PROMPT,
seed=seed)
return results["text_search_url"], results["logs"]
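# Gradio UI: sketchpad and controls on the left, generated image and logs on the right.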
with gr.Blocks(css="style.css") as demo:
gr.Markdown(DESCRIPTION, elem_id="description")
gr.DuplicateButton(
value="Duplicate Space for private use",
elem_id="duplicate-button",
visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
)
with gr.Row():
with gr.Column():
with gr.Group():
image = gr.Sketchpad(
type="pil",
image_mode="RGBA",
layers=False,
canvas_size=(1024, 1024),
brush=gr.Brush(
default_color="black",
colors=None,
default_size=4,
color_mode="fixed",
),
eraser=gr.Eraser(),
height=440,
)
prompt = gr.Textbox(label="Prompt")
style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
run_button = gr.Button("Run")
with gr.Accordion("Advanced options", open=False):
negative_prompt = gr.Textbox(
label="Negative prompt",
value=DEFAULTS.NEGATIVE_PROMPT,
)
rewriting_prompt = gr.Textbox(
label="Rewriting prompt",
value=DEFAULTS.REWRITING_PROMPT,
)
moondream_prompt = gr.Textbox(
label="Moondream prompt",
value=DEFAULTS.MOONDREAM_PROMPT,
)
num_steps = gr.Slider(
label="Number of steps",
minimum=1,
maximum=50,
step=1,
value=DEFAULTS.NUM_STEPS,
)
guidance_scale = gr.Slider(
label="Guidance scale",
minimum=0.1,
maximum=10.0,
step=0.1,
value=DEFAULTS.GUIDANCE_SCALE,
)
adapter_conditioning_scale = gr.Slider(
label="Adapter conditioning scale",
minimum=0.5,
maximum=1,
step=0.1,
value=DEFAULTS.ADAPTER_CONDITIONING_SCALE,
)
adapter_conditioning_factor = gr.Slider(
label="Adapter conditioning factor",
info="Fraction of timesteps for which adapter should be applied",
minimum=0.5,
maximum=1,
step=0.1,
value=DEFAULTS.ADAPTER_CONDITIONING_FACTOR,
)
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Column():
result_img = gr.Image(label="Result", height=400, interactive=False)
result_caption = gr.Markdown(label="Image caption")
result = [result_img, result_caption]
with gr.Row():
gr.Markdown("# API endpoints\nThe fields below are only used to test the served API endpoints of this space.", elem_id="description")
with gr.Row():
with gr.Column():
with gr.Accordion("Full Experience API", open=False):
api_fullexp_image_url = gr.Textbox(label="Image URL")
api_fullexp_image_bytes = gr.Textbox(label="Image Base64 bytes")
api_fullexp_user_prompt = gr.Textbox(label="User prompt")
api_fullexp_run_button = gr.Button("Run API")
api_fullexp_text_search_url = gr.Textbox(label="Text search URL")
api_fullexp_visual_search_url = gr.Textbox(label="Visual search URL")
api_fullexp_logs = gr.Markdown(label="Logs")
with gr.Column():
with gr.Accordion("Caption Only API", open=False):
api_caponly_image_url = gr.Textbox(label="Image URL")
api_caponly_image_bytes = gr.Textbox(label="Image Base64 bytes")
api_caponly_run_button = gr.Button("Run API")
api_caponly_text_search_url = gr.Textbox(label="Text search URL")
api_caponly_logs = gr.Markdown(label="Logs")
# Gradio components interconnections
inputs = [
image,
prompt,
negative_prompt,
rewriting_prompt,
moondream_prompt,
style,
num_steps,
guidance_scale,
adapter_conditioning_scale,
adapter_conditioning_factor,
seed,
]
prompt.submit(
fn=randomize_seed_fn,
inputs=[seed, randomize_seed],
outputs=seed,
queue=False,
api_name=False,
).then(
fn=run_full_gradio,
inputs=inputs,
outputs=result,
api_name=False,
)
negative_prompt.submit(
fn=randomize_seed_fn,
inputs=[seed, randomize_seed],
outputs=seed,
queue=False,
api_name=False,
).then(
fn=run_full_gradio,
inputs=inputs,
outputs=result,
api_name=False,
)
run_button.click(
fn=randomize_seed_fn,
inputs=[seed, randomize_seed],
outputs=seed,
queue=False,
api_name=False,
).then(
fn=run_full_gradio,
inputs=inputs,
outputs=result,
api_name=False,
)
# API interconnections
api_fullexp_run_button.click(
fn=run_full_api,
inputs=[api_fullexp_image_url, api_fullexp_image_bytes, api_fullexp_user_prompt],
outputs=[api_fullexp_text_search_url, api_fullexp_visual_search_url, api_fullexp_logs],
api_name="full_experience",
)
api_caponly_run_button.click(
fn=run_caponly_api,
inputs=[api_caponly_image_url, api_caponly_image_bytes],
outputs=[api_caponly_text_search_url, api_caponly_logs],
api_name="caption_only",
)
if __name__ == "__main__":
demo.queue(max_size=20).launch()