import os
import random
import base64
import gradio as gr
import numpy as np
import PIL.Image
from PIL import ImageOps
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
import torchvision.transforms.functional as TF
from diffusers import (
    AutoencoderKL,
    EulerAncestralDiscreteScheduler,
    StableDiffusionXLAdapterPipeline,
    T2IAdapter,
)
import urllib.parse
import requests
from io import BytesIO
import json
from pathlib import Path
import uuid
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient, BlobClient, ContainerClient
from datetime import datetime


class DEFAULTS:
    NEGATIVE_PROMPT = " extra digit, fewer digits, cropped, worst quality, low quality, glitch, deformed, mutated, ugly, disfigured"
    REWRITING_PROMPT = (
        "Rewrite the image caption by making it shorter (but retain all information about relative position), "
        "remove information about style of objects or colors of background and foreground, and, most importantly, remove all details "
        "that suggest it is a sketch. Write it as a Google image search query:"
    )
    MOONDREAM_PROMPT = "Describe this image."
    NUM_STEPS = 25
    GUIDANCE_SCALE = 5
    ADAPTER_CONDITIONING_SCALE = 0.8
    ADAPTER_CONDITIONING_FACTOR = 0.8
    SEED = 1231245
    RANDOMIZE_SEED = True


DESCRIPTION = '''# Sketch to Image/Caption to Bing Search :)

This is a test space for the Sketch to Image/Caption to Bing Search model.
You can draw a sketch on the left, provide a prompt, and select a style.
The model will generate an image based on your sketch and prompt, and provide a Bing search query based on the generated image.
'''

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

style_list = [
    {
        "name": "(No style)",
        "prompt": "{prompt}",
        "negative_prompt": "",
    },
    {
        "name": "Cinematic",
        "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
        "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
    },
    {
        "name": "3D Model",
        "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
        "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
    },
    {
        "name": "Anime",
        "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
        "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
    },
    {
        "name": "Digital Art",
        "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
        "negative_prompt": "photo, photorealistic, realism, ugly",
    },
    {
        "name": "Photographic",
        "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
        "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
    },
    {
        "name": "Pixel art",
        "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
        "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
    },
    {
        "name": "Fantasy art",
        "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
        "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
    },
    {
        "name": "Neonpunk",
        "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
        "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
    },
    {
        "name": "Manga",
        "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
        "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
    },
]

styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "Photographic"  # "(No style)"


def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
    p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    return p.replace("{prompt}", positive), n + negative
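# Illustrative use of apply_style (output derived from the "Photographic" entry above;
# the sketch prompt "a red bicycle" is made up for illustration):
#   apply_style("Photographic", "a red bicycle")
#   -> ("cinematic photo a red bicycle . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
#       "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly")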
if os.path.exists("azure_connection_string.txt"):
    with open("azure_connection_string.txt", "r") as f:
        CONNECTION_STRING = f.read().strip()
else:
    CONNECTION_STRING = os.getenv("AZURE_CONNECTION_STRING")


def upload_pil_image_to_azure(image, connection_string=CONNECTION_STRING):
    image_name = f"{uuid.uuid4()}.png"
    image_bytes = BytesIO()
    image.save(image_bytes, format="PNG")
    image_bytes.seek(0)
    try:
        # Create the BlobServiceClient object
        blob_service_client = BlobServiceClient.from_connection_string(connection_string)
        # Create a blob client using a random UUID as the name for the blob
        blob_client = blob_service_client.get_blob_client(container="blob-image-hosting", blob=image_name)
        # Upload the image and retrieve the URL
        blob_client.upload_blob(image_bytes)
        file_url = blob_client.url
    except Exception as ex:
        print('Exception:')
        print(ex)
        file_url = None
    # file_url stays None if the upload failed
    return file_url
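# Quick manual check of the upload helper (illustrative; requires a valid connection string
# and an existing "blob-image-hosting" container; the URL shape shown is the standard Azure
# Blob Storage pattern and is an assumption, not taken from this code):
#   test_url = upload_pil_image_to_azure(PIL.Image.new("RGB", (64, 64), "white"))
#   print(test_url)  # e.g. https://<account>.blob.core.windows.net/blob-image-hosting/<uuid>.png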
if torch.cuda.is_available():
    if torch.cuda.device_count() > 1:
        device_0, device_1 = torch.device("cuda:0"), torch.device("cuda:1")
    else:
        device_0, device_1 = torch.device("cuda:0"), torch.device("cuda:0")
else:
    device_0, device_1 = torch.device("cpu"), torch.device("cpu")
# device_1 = 'cuda:0'

if torch.cuda.is_available():
    model_id = "stabilityai/stable-diffusion-xl-base-1.0"
    adapter = T2IAdapter.from_pretrained(
        "TencentARC/t2i-adapter-sketch-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
    )
    scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
    pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
        model_id,
        vae=AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16),
        adapter=adapter,
        scheduler=scheduler,
        torch_dtype=torch.float16,
        variant="fp16",
    )
    pipe.to(device_0)
else:
    pipe = None

MAX_SEED = np.iinfo(np.int32).max


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

vlmodel_id = "vikhyatk/moondream2"
vlmodel_revision = "2024-07-23"
vlmodel = AutoModelForCausalLM.from_pretrained(
    vlmodel_id,
    trust_remote_code=True,
    revision=vlmodel_revision,
    device_map={"": device_1},
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)
vltokenizer = AutoTokenizer.from_pretrained(vlmodel_id, revision=vlmodel_revision)

rewrite_model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
rewrite_model = AutoModelForCausalLM.from_pretrained(
    rewrite_model_name,
    device_map={"": device_1},
    quantization_config=nf4_config,
    # load_in_8bit=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
rewrite_tokenizer = AutoTokenizer.from_pretrained(rewrite_model_name)


def caption_image_with_recaption(pil_image, moondream_prompt, rewriting_prompt, user_prompt=""):
    enc_image = vlmodel.encode_image(pil_image)
    img_caption = vlmodel.answer_question(enc_image, moondream_prompt, vltokenizer)
    rewritten_caption = rewrite_prompt(img_caption, rewriting_prompt, user_prompt=user_prompt)
    rewritten_caption = rewritten_caption.strip('"').replace("\n", " ")
    return img_caption, rewritten_caption


def rewrite_prompt(image_cap: str, guide: str, user_prompt: str = "") -> str:
    prompt = f"{guide}\n{image_cap}"
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    text = rewrite_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = rewrite_tokenizer([text], return_tensors="pt").to(device_1)
    generated_ids = rewrite_model.generate(model_inputs.input_ids, max_new_tokens=128)
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = rewrite_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response
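# Illustrative end-to-end use of the two captioning steps above (the caption and query shown
# are hypothetical; actual outputs depend on the moondream2 and Llama 3.1 models):
#   caption, query = caption_image_with_recaption(
#       sketch_image, DEFAULTS.MOONDREAM_PROMPT, DEFAULTS.REWRITING_PROMPT
#   )
#   # caption -> e.g. "A simple line drawing of a cat sitting on a round table."
#   # query   -> e.g. "cat sitting on a round table"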
def run_full(
    image,
    user_prompt: str,
    negative_prompt: str,
    rewriting_prompt: str,
    moondream_prompt: str,
    style_name: str = DEFAULT_STYLE_NAME,
    num_steps: int = 25,
    guidance_scale: float = 5,
    adapter_conditioning_scale: float = 0.8,
    adapter_conditioning_factor: float = 0.8,
    seed: int = 0,
    progress=None,
) -> dict:
    # image is a white background with black sketch
    image = ImageOps.invert(image)
    # resize to 1024x1024
    image = image.resize((1024, 1024))
    # Threshold the image to get a binary sketch
    image = TF.to_tensor(image) > 0.5
    image = TF.to_pil_image(image.to(torch.float32))

    full_log = []
    if user_prompt == "":
        # No user prompt: caption the sketch first and draw from the rewritten caption
        pre_caption = True
        start_time = datetime.now()
        img_caption, rewritten_caption = caption_image_with_recaption(
            pil_image=image, rewriting_prompt=rewriting_prompt, moondream_prompt=moondream_prompt)
        full_log.append(f"Combined captioning time: {datetime.now() - start_time}")
        full_log.append(f"img_caption (pre): {img_caption}")
        full_log.append(f"rewritten_caption (pre): {rewritten_caption}")
        drawing_prompt = rewritten_caption
    else:
        pre_caption = False
        drawing_prompt = user_prompt
    full_log.append(f"Pre-caption: {pre_caption}")

    # Generate image
    start_time = datetime.now()
    drawing_prompt, negative_prompt = apply_style(style_name, drawing_prompt, negative_prompt)
    generator = torch.Generator(device=device_0).manual_seed(seed)
    out_img = pipe(
        prompt=drawing_prompt,
        negative_prompt=negative_prompt,
        image=image,
        num_inference_steps=num_steps,
        generator=generator,
        guidance_scale=guidance_scale,
        adapter_conditioning_scale=adapter_conditioning_scale,
        adapter_conditioning_factor=adapter_conditioning_factor,
    ).images[0]
    full_log.append(f"Image generation time: {datetime.now() - start_time}")

    if not pre_caption:
        # A user prompt was given: caption the generated image instead of the sketch
        start_time = datetime.now()
        img_caption, rewritten_caption = caption_image_with_recaption(
            pil_image=out_img, rewriting_prompt=rewriting_prompt, moondream_prompt=moondream_prompt,
            user_prompt=user_prompt)
        full_log.append(f"Combined captioning time: {datetime.now() - start_time}")
        full_log.append(f"img_caption (post): {img_caption}")
        full_log.append(f"rewritten_caption (post): {rewritten_caption}")

    # SERP query
    bing_serp_query = f"https://www.bing.com/images/search?q={urllib.parse.quote(rewritten_caption)}"
    md_text = f"### Bing search query\n[{bing_serp_query}]({bing_serp_query})\n"

    # Visual Search query
    out_img_azure_url = upload_pil_image_to_azure(out_img)
    if out_img_azure_url is None:
        md_text += "### Bing Visual Search\n**Error:** Failed to upload image to Azure Blob Storage\n"
        bing_image_search_url = "https://www.bing.com/images"
    else:
        azure_url_quote = urllib.parse.quote(out_img_azure_url)
        bing_image_search_url = f"https://www.bing.com/images/search?view=detailv2&iss=SBI&form=SBIIRP&q=imgurl:{azure_url_quote}"
        md_text += f"### Bing Visual Search\n[{bing_image_search_url}]({bing_image_search_url})\n"

    # Debug info
    md_text += f"### Debug: sketch caption\n{img_caption}\n\n### Debug: rewritten caption\n{rewritten_caption}\n"
    # Full log dump
    log_dump = "\n".join(full_log)
    md_text += f"### Debug: full log\n{log_dump}"

    # return dict
    return {
        "image": out_img,
        "text_search_url": bing_serp_query,
        "visual_search_url": bing_image_search_url,
        "logs": md_text,
    }
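# For reference, run_full returns a dict shaped like this (URL values are illustrative;
# the query is percent-encoded by urllib.parse.quote):
#   {
#       "image": <PIL.Image.Image>,
#       "text_search_url": "https://www.bing.com/images/search?q=cat%20sitting%20on%20a%20round%20table",
#       "visual_search_url": "https://www.bing.com/images/search?view=detailv2&iss=SBI&form=SBIIRP&q=imgurl:<url-encoded blob URL>",
#       "logs": "<markdown with both queries, the captions, and the timing log>",
#   }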
def run_full_gradio(
    image,
    user_prompt: str,
    negative_prompt: str,
    rewriting_prompt: str,
    moondream_prompt: str,
    style_name: str = DEFAULT_STYLE_NAME,
    num_steps: int = 25,
    guidance_scale: float = 5,
    adapter_conditioning_scale: float = 0.8,
    adapter_conditioning_factor: float = 0.8,
    seed: int = 0,
    progress=gr.Progress(track_tqdm=True),
) -> tuple[PIL.Image.Image, str]:
    # Sketchpad returns a dict of layers; flatten the composite onto a white background
    image = image['composite']
    background = PIL.Image.new('RGBA', image.size, (255, 255, 255))
    alpha_composite = PIL.Image.alpha_composite(background, image)
    image = alpha_composite.convert("RGB")
    results = run_full(
        image=image,
        user_prompt=user_prompt,
        negative_prompt=negative_prompt,
        rewriting_prompt=rewriting_prompt,
        moondream_prompt=moondream_prompt,
        style_name=style_name,
        num_steps=num_steps,
        guidance_scale=guidance_scale,
        adapter_conditioning_scale=adapter_conditioning_scale,
        adapter_conditioning_factor=adapter_conditioning_factor,
        seed=seed,
        progress=progress,
    )
    # return the generated image and the markdown output
    return results["image"], results["logs"]


def run_full_api(
    image_url: str,
    image_bytes: str,
    user_prompt: str,
    progress=gr.Progress(track_tqdm=True),
) -> tuple[str, str, str]:
    seed = randomize_seed_fn(0, True)
    if image_url:
        image = PIL.Image.open(BytesIO(requests.get(image_url).content))
    elif image_bytes:
        decoded_image = base64.b64decode(image_bytes)
        image = PIL.Image.open(BytesIO(decoded_image))
    else:
        raise ValueError("Either image_url or image_bytes must be provided")
    # if image is RGBA, convert to RGB
    if image.mode == "RGBA":
        background = PIL.Image.new('RGBA', image.size, (255, 255, 255))
        alpha_composite = PIL.Image.alpha_composite(background, image)
        image = alpha_composite.convert("RGB")
    results = run_full(
        image=image,
        user_prompt=user_prompt,
        negative_prompt=DEFAULTS.NEGATIVE_PROMPT,
        rewriting_prompt=DEFAULTS.REWRITING_PROMPT,
        moondream_prompt=DEFAULTS.MOONDREAM_PROMPT,
        style_name=DEFAULT_STYLE_NAME,
        num_steps=DEFAULTS.NUM_STEPS,
        guidance_scale=DEFAULTS.GUIDANCE_SCALE,
        adapter_conditioning_scale=DEFAULTS.ADAPTER_CONDITIONING_SCALE,
        adapter_conditioning_factor=DEFAULTS.ADAPTER_CONDITIONING_FACTOR,
        seed=seed)
    return results["text_search_url"], results["visual_search_url"], results["logs"]


def run_caponly(
    image,
    rewriting_prompt: str,
    moondream_prompt: str,
    seed: int = 0,
    progress=None,
) -> dict:
    # image is a white background with black sketch
    image = ImageOps.invert(image)
    # resize to 1024x1024
    image = image.resize((1024, 1024))
    # Threshold the image to get a binary sketch
    image = TF.to_tensor(image) > 0.5
    image = TF.to_pil_image(image.to(torch.float32))

    full_log = []
    start_time = datetime.now()
    img_caption, rewritten_caption = caption_image_with_recaption(
        pil_image=image, rewriting_prompt=rewriting_prompt, moondream_prompt=moondream_prompt)
    full_log.append(f"Combined captioning time: {datetime.now() - start_time}")
    full_log.append(f"img_caption (pre): {img_caption}")
    full_log.append(f"rewritten_caption (pre): {rewritten_caption}")
    final_prompt = rewritten_caption

    # SERP query
    bing_serp_query = f"https://www.bing.com/images/search?q={urllib.parse.quote(rewritten_caption)}"
    md_text = f"### Bing search query\n[{bing_serp_query}]({bing_serp_query})\n"
    # Debug info
    md_text += f"### Debug: sketch caption\n{img_caption}\n\n### Debug: rewritten caption\n{rewritten_caption}\n"
    # Full log dump
    log_dump = "\n".join(full_log)
    md_text += f"### Debug: full log\n{log_dump}"

    # return dict
    return {
        "text_search_url": bing_serp_query,
        "logs": md_text,
    }


def run_caponly_api(
    image_url: str,
    image_bytes: str,
    progress=gr.Progress(track_tqdm=True),
) -> tuple[str, str]:
    seed = randomize_seed_fn(0, True)
    if image_url:
        image = PIL.Image.open(BytesIO(requests.get(image_url).content))
    elif image_bytes:
        decoded_image = base64.b64decode(image_bytes)
        image = PIL.Image.open(BytesIO(decoded_image))
    else:
        raise ValueError("Either image_url or image_bytes must be provided")
    # if image is RGBA, convert to RGB
    if image.mode == "RGBA":
        background = PIL.Image.new('RGBA', image.size, (255, 255, 255))
        alpha_composite = PIL.Image.alpha_composite(background, image)
        image = alpha_composite.convert("RGB")
    results = run_caponly(
        image=image,
        rewriting_prompt=DEFAULTS.REWRITING_PROMPT,
        moondream_prompt=DEFAULTS.MOONDREAM_PROMPT,
        seed=seed)
    return results["text_search_url"], results["logs"]
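# Illustrative client-side helper (an assumption about usage, not wired into the app):
# the `image_bytes` field of the two API endpoints is decoded with base64.b64decode above,
# so a caller could prepare its value like this.
def encode_image_for_api(img: PIL.Image.Image) -> str:
    buf = BytesIO()
    img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("ascii")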
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION, elem_id="description")
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Row():
        with gr.Column():
            with gr.Group():
                image = gr.Sketchpad(
                    # sources=["canvas"],
                    # tool="sketch",
                    type="pil",
                    image_mode="RGBA",
                    # invert_colors=True,
                    layers=False,
                    canvas_size=(1024, 1024),
                    brush=gr.Brush(
                        default_color="black",
                        colors=None,
                        default_size=4,
                        color_mode="fixed",
                    ),
                    eraser=gr.Eraser(),
                    height=440,
                )
                prompt = gr.Textbox(label="Prompt")
                style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
                run_button = gr.Button("Run")
            with gr.Accordion("Advanced options", open=False):
                negative_prompt = gr.Textbox(
                    label="Negative prompt",
                    value=DEFAULTS.NEGATIVE_PROMPT,
                )
                rewriting_prompt = gr.Textbox(
                    label="Rewriting prompt",
                    value=DEFAULTS.REWRITING_PROMPT,
                )
                moondream_prompt = gr.Textbox(
                    label="Moondream prompt",
                    value=DEFAULTS.MOONDREAM_PROMPT,
                )
                num_steps = gr.Slider(
                    label="Number of steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=DEFAULTS.NUM_STEPS,
                )
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.1,
                    maximum=10.0,
                    step=0.1,
                    value=DEFAULTS.GUIDANCE_SCALE,
                )
                adapter_conditioning_scale = gr.Slider(
                    label="Adapter conditioning scale",
                    minimum=0.5,
                    maximum=1,
                    step=0.1,
                    value=DEFAULTS.ADAPTER_CONDITIONING_SCALE,
                )
                adapter_conditioning_factor = gr.Slider(
                    label="Adapter conditioning factor",
                    info="Fraction of timesteps for which adapter should be applied",
                    minimum=0.5,
                    maximum=1,
                    step=0.1,
                    value=DEFAULTS.ADAPTER_CONDITIONING_FACTOR,
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Column():
            result_img = gr.Image(label="Result", height=400, interactive=False)
            result_caption = gr.Markdown(label="Image caption")
            result = [result_img, result_caption]
    with gr.Row():
        gr.Markdown(
            "# API endpoints\nThe fields below are only used to test the served API endpoints of this space.",
            elem_id="description",
        )
    with gr.Row():
        with gr.Column():
            with gr.Accordion("Full Experience API", open=False):
                api_fullexp_image_url = gr.Textbox(label="Image URL")
                api_fullexp_image_bytes = gr.Textbox(label="Image Base64 bytes")
                api_fullexp_user_prompt = gr.Textbox(label="User prompt")
                api_fullexp_run_button = gr.Button("Run API")
                api_fullexp_text_search_url = gr.Textbox(label="Text search URL")
                api_fullexp_visual_search_url = gr.Textbox(label="Visual search URL")
                api_fullexp_logs = gr.Markdown(label="Logs")
        with gr.Column():
            with gr.Accordion("Caption Only API", open=False):
                api_caponly_image_url = gr.Textbox(label="Image URL")
                api_caponly_image_bytes = gr.Textbox(label="Image Base64 bytes")
                api_caponly_run_button = gr.Button("Run API")
                api_caponly_text_search_url = gr.Textbox(label="Text search URL")
                api_caponly_logs = gr.Markdown(label="Logs")

    # Gradio components interconnections
    inputs = [
        image,
        prompt,
        negative_prompt,
        rewriting_prompt,
        moondream_prompt,
        style,
        num_steps,
        guidance_scale,
        adapter_conditioning_scale,
        adapter_conditioning_factor,
        seed,
    ]
    prompt.submit(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=run_full_gradio,
        inputs=inputs,
        outputs=result,
        api_name=False,
    )
    negative_prompt.submit(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=run_full_gradio,
        inputs=inputs,
        outputs=result,
        api_name=False,
    )
    run_button.click(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=run_full_gradio,
        inputs=inputs,
        outputs=result,
        api_name=False,
    )

    # API interconnections
    api_fullexp_run_button.click(
        fn=run_full_api,
        inputs=[api_fullexp_image_url, api_fullexp_image_bytes, api_fullexp_user_prompt],
        outputs=[api_fullexp_text_search_url, api_fullexp_visual_search_url, api_fullexp_logs],
        api_name="full_experience",
    )
    api_caponly_run_button.click(
        fn=run_caponly_api,
        inputs=[api_caponly_image_url, api_caponly_image_bytes],
        outputs=[api_caponly_text_search_url, api_caponly_logs],
        api_name="caption_only",
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
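# Illustrative client-side calls against the two named endpoints above (a sketch, assuming the
# `gradio_client` package and a reachable Space URL; replace <space-url> with the actual URL):
#   from gradio_client import Client
#   client = Client("<space-url>")
#   text_url, visual_url, logs = client.predict(
#       "https://example.com/sketch.png",  # image_url
#       "",                                # image_bytes (base64-encoded PNG)
#       "a cat on a table",                # user_prompt
#       api_name="/full_experience",
#   )
#   text_url, logs = client.predict(
#       "https://example.com/sketch.png", "", api_name="/caption_only",
#   )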