File size: 3,966 Bytes
258d8c9
c74095e
975dc6e
c74095e
 
 
 
 
b2b2e3c
258d8c9
c74095e
 
c4964ee
430340e
b2f817c
430340e
c4964ee
3eed896
 
 
 
9dedaff
 
b2f817c
3eed896
24db23e
 
48c7266
258d8c9
c74095e
d186b30
c74095e
 
d186b30
258d8c9
 
090c9fa
 
c74095e
 
 
 
 
090c9fa
c74095e
febb26d
090c9fa
8118b09
c74095e
 
febb26d
090c9fa
8118b09
c74095e
 
 
 
 
 
 
090c9fa
c74095e
3bcd02d
c74095e
65e93e6
b2b2e3c
 
 
3bcd02d
6f00ba2
258d8c9
c131c56
70c9759
6d16f98
 
 
 
 
466cd5e
 
b2f817c
1252ec5
bf379ef
a950a05
466cd5e
430340e
65e93e6
 
 
641777b
b2f817c
 
 
 
 
 
0771b33
 
 
bd9a3cd
65e93e6
b547a14
466cd5e
a600f9f
 
466cd5e
6303c40
280a8d0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import functools
import gc
import html

import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

def create_key(seed=0):
    """Build a JAX PRNG key from an integer seed.

    Args:
        seed: Integer seed; defaults to 0.

    Returns:
        The key produced by ``jax.random.PRNGKey(seed)``.
    """
    key = jax.random.PRNGKey(seed)
    return key

def addp5sketch(url):
   iframe = f'<iframe src ={url} style="border:none;height:300px;width:100%"/frame>'
   return gr.HTML(iframe)

def wandb_report(url):
    iframe = f'<iframe src ={url} style="border:none;height:1024px;width:100%"/frame>'
    return gr.HTML(iframe)

# Title and one-line description of the demo.
mytitle = 'Animal Pose Control Net'
mydescription = (
    'This is a demo of Animal Pose ControlNet, which is a model trained on '
    'runwayml/stable-diffusion-v1-5 with new type of conditioning.'
)

# External pages: the W&B training run and the p5.js sketch.
report_url = 'https://wandb.ai/john-fozard/dog-cat-pose/runs/kmwcvae5'
sketch_url = 'https://editor.p5js.org/kfahn/full/Ntzq9HWhx'

# NOTE(review): sample conditioning image path; not read by the live UI code.
control_img = 'myimage.jpg'

# Load the pose ControlNet, then the Stable Diffusion v1.5 Flax pipeline
# built around it; both are requested in bfloat16.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "JFoz/dog-cat-pose",
    dtype=jnp.bfloat16,
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    revision="flax",  # selects the repo's "flax" revision
    dtype=jnp.bfloat16,
)

@functools.lru_cache(maxsize=None)
def _replicated_params():
    """Attach the ControlNet weights to the pipeline params and replicate them once.

    ``replicate`` places a copy of every parameter on each local device. The
    parameters never change between requests, so the replicated tree is built
    on first use and reused instead of being re-copied for every call.
    """
    params["controlnet"] = controlnet_params
    return replicate(params)

def infer(prompts, negative_prompts, image, seed=0, num_inference_steps=50):
    """Generate one image from a text prompt conditioned on a pose image.

    Args:
        prompts: Positive text prompt (a single string).
        negative_prompts: Negative text prompt (a single string).
        image: Conditioning image as an array accepted by ``Image.fromarray``
            (as delivered by ``gr.Image``), or ``None`` if none was provided.
        seed: PRNG seed; defaults to 0, so identical inputs give identical output.
        num_inference_steps: Number of denoising steps passed to the pipeline.

    Returns:
        The first generated image as a float32 NumPy array.

    Raises:
        gr.Error: If no conditioning image was provided.
    """
    if image is None:
        raise gr.Error("Please provide a conditioning image.")

    # ``shard`` reshapes each batch to (local_device_count, -1, ...), so the
    # batch must be a multiple of the device count: use one sample per device.
    # A fixed batch of 1 fails to reshape on multi-device hosts.
    num_devices = jax.local_device_count()
    num_samples = num_devices
    rng = jax.random.split(create_key(seed), num_devices)

    pil_image = Image.fromarray(image)

    prompt_ids = shard(pipe.prepare_text_inputs([prompts] * num_samples))
    negative_prompt_ids = shard(pipe.prepare_text_inputs([negative_prompts] * num_samples))
    processed_image = shard(pipe.prepare_image_inputs([pil_image] * num_samples))

    # ``images`` is indexed as [device, per-device sample]; keep the first one.
    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=_replicated_params(),
        prng_seed=rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images[0, 0]

    # Drop the PIL copy of the input and run a collection pass between requests.
    del pil_image
    gc.collect()

    # Convert to a plain float32 NumPy array (the models were loaded in bfloat16).
    return np.array(output, dtype=np.float32)

with gr.Blocks(theme='kfahn/AnimalPose') as demo:
    # Page header, built from the module-level title/description constants so
    # the text is defined in one place.
    gr.Markdown(f"# {mytitle}\n\n### {mydescription}")
    with gr.Row():
        with gr.Column():
            prompts = gr.Textbox(label="Prompt", placeholder="yellow dog standing on a lawn, best quality, highres")
            negative_prompts = gr.Textbox(label="Negative Prompt", value="lowres, bad muzzle, bad anatomy, missing ears, missing paws")
            conditioning_image = gr.Image(label="Conditioning Image")
            run_btn = gr.Button("Run")
        with gr.Column():
            # Embedded p5.js sketch (presumably for drawing pose keypoints).
            keypoint_tool = addp5sketch(sketch_url)
            output = gr.Image(label="Result")
    # Resource links; the trailing two spaces are Markdown line breaks. The
    # model link points at the ControlNet repo this app actually loads.
    gr.Markdown(
        "[Dataset](https://huggingface.co/datasets/JFoz/dog-poses-controlnet-dataset)  \n"
        "[Diffusers model](https://huggingface.co/JFoz/dog-cat-pose)  \n"
        "[Github](https://github.com/fi4cr/animalpose)  \n"
        f"[Training Report]({report_url})"
    )

    run_btn.click(fn=infer, inputs=[prompts, negative_prompts, conditioning_image], outputs=output)
    
# Serve the ``demo`` Blocks app (debug mode).
demo.launch(debug=True)