kfahn commited on
Commit
c74095e
1 Parent(s): 0ae97cf

Update app.py

Browse files

Using https://huggingface.co/spaces/jax-diffusers-event/canny_coyo1m/blob/main/app.py as a guide

Files changed (1) hide show
  1. app.py +55 -118
app.py CHANGED
@@ -1,60 +1,58 @@
1
- from PIL import Image
2
  import gradio as gr
3
- from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
4
- import torch
5
- torch.backends.cuda.matmul.allow_tf32 = True
 
 
 
 
 
6
 
7
- controlnet = ControlNetModel.from_pretrained("JFoz/dog-cat-pose", torch_dtype=torch.float16)
 
8
 
9
- pipe = StableDiffusionControlNetPipeline.from_pretrained(
10
- "runwayml/stable-diffusion-v1-5",
11
- controlnet=controlnet,
12
- torch_dtype=torch.float16,
13
- safety_checker=None,
14
  )
15
 
16
- pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
17
-
18
- pipe.enable_xformers_memory_efficient_attention()
19
- pipe.enable_model_cpu_offload()
20
- pipe.enable_attention_slicing()
21
-
22
- def infer(
23
- prompt,
24
- negative_prompt,
25
- conditioning_image,
26
- num_inference_steps=30,
27
- size=768,
28
- guidance_scale=7.0,
29
- seed=1234,
30
- ):
31
-
32
- conditioning_image_raw = Image.fromarray(conditioning_image)
33
- #conditioning_image = conditioning_image_raw.convert('L')
34
-
35
- g_cpu = torch.Generator()
36
-
37
- if seed == -1:
38
- generator = g_cpu.manual_seed(g_cpu.seed())
39
- else:
40
- generator = g_cpu.manual_seed(seed)
41
-
42
- output_image = pipe(
43
- prompt,
44
- conditioning_image,
45
- height=size,
46
- width=size,
47
- num_inference_steps=num_inference_steps,
48
- generator=generator,
49
- negative_prompt=negative_prompt,
50
- guidance_scale=guidance_scale,
51
- controlnet_conditioning_scale=1.0,
52
- ).images[0]
53
 
54
- #del conditioning_image, conditioning_image_raw
55
- #gc.collect()
56
 
57
- return output_image
 
58
 
59
  with gr.Blocks(theme=gr.themes.Default(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"])) as demo:
60
  gr.Markdown(
@@ -63,83 +61,22 @@ with gr.Blocks(theme=gr.themes.Default(font=[gr.themes.GoogleFont("Inconsolata")
63
  # This is a demo of Animal Pose Control Net, which is a model trained on runwayml/stable-diffusion-v1-5 with new type of conditioning.
64
  """)
65
 
66
- with gr.Row():
67
- with gr.Column():
68
- prompt = gr.Textbox(
69
- label="Prompt",
70
- )
71
- negative_prompt = gr.Textbox(
72
- label="Negative Prompt",
73
- )
74
- conditioning_image = gr.Image(
75
- label="Conditioning Image",
76
- )
77
- with gr.Accordion('Advanced options', open=False):
78
- with gr.Row():
79
- num_inference_steps = gr.Slider(
80
- 10, 40, 20,
81
- step=1,
82
- label="Steps",
83
- )
84
- size = gr.Slider(
85
- 256, 768, 512,
86
- step=128,
87
- label="Size",
88
- )
89
- with gr.Row():
90
- guidance_scale = gr.Slider(
91
- label='Guidance Scale',
92
- minimum=0.1,
93
- maximum=30.0,
94
- value=7.0,
95
- step=0.1
96
- )
97
- seed = gr.Slider(
98
- label='Seed',
99
- value=-1,
100
- minimum=-1,
101
- maximum=2147483647,
102
- step=1,
103
- # randomize=True
104
- )
105
- submit_btn = gr.Button(
106
- value="Submit",
107
- variant="primary"
108
- )
109
- with gr.Column(min_width=300):
110
- output = gr.Image(
111
- label="Result",
112
- )
113
-
114
- submit_btn.click(
115
- fn=infer,
116
- inputs=[
117
- prompt, negative_prompt, conditioning_image, num_inference_steps, size, guidance_scale, seed
118
- #prompt, size, seed
119
- ],
120
- outputs=output
121
- )
122
- gr.Examples(
123
  examples=[
124
  #["a tortoiseshell cat is sitting on a cushion"],
125
  #["a yellow dog standing on a lawn"],
126
  ["a tortoiseshell cat is sitting on a cushion", "https://huggingface.co/JFoz/dog-cat-pose/blob/main/images_0.png"],
127
  ["a yellow dog standing on a lawn", "https://huggingface.co/JFoz/dog-cat-pose/blob/main/images_1.png"],
128
- ],
129
- inputs=[
130
- #prompt, negative_prompt, conditioning_image
131
- prompt
132
- ],
133
- outputs=output,
134
- fn=infer,
135
  cache_examples=True,
136
  )
137
- gr.Markdown(
 
 
 
 
138
  """
139
  * [Dataset](https://huggingface.co/datasets/JFoz/dog-poses-controlnet-dataset)
140
  * [Diffusers model](), [Web UI model](https://huggingface.co/JFoz/dog-pose)
141
  * [Training Report](https://wandb.ai/john-fozard/dog-cat-pose/runs/kmwcvae5))
142
  """)
143
-
144
- #gr.Interface(infer, inputs=["text"], outputs=[output], title=title, description=description, examples=examples).queue().launch()
145
- demo.launch()
 
 
1
  import gradio as gr
2
+ import jax.numpy as jnp
3
+ import jax
4
+ import numpy as np
5
+ from flax.jax_utils import replicate
6
+ from flax.training.common_utils import shard
7
+ from PIL import Image
8
+ from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
9
+ import cv2
10
 
11
+ def create_key(seed=0):
12
+ return jax.random.PRNGKey(seed)
13
 
14
+ controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
15
+ "JFoz/dog-cat-pose", dtype=jnp.bfloat16
16
+ )
17
+ pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
18
+ "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
19
  )
20
 
21
+ def infer(prompts, negative_prompts, image):
22
+ params["controlnet"] = controlnet_params
23
+
24
+ num_samples = 1 #jax.device_count()
25
+ rng = create_key(0)
26
+ rng = jax.random.split(rng, jax.device_count())
27
+ #im = canny_filter(image)
28
+ #canny_image = Image.fromarray(im)
29
+
30
+ prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
31
+ negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
32
+ processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
33
+
34
+ p_params = replicate(params)
35
+ prompt_ids = shard(prompt_ids)
36
+ negative_prompt_ids = shard(negative_prompt_ids)
37
+ processed_image = shard(processed_image)
38
+
39
+ output = pipe(
40
+ prompt_ids=prompt_ids,
41
+ image=processed_image,
42
+ params=p_params,
43
+ prng_seed=rng,
44
+ num_inference_steps=50,
45
+ neg_prompt_ids=negative_prompt_ids,
46
+ jit=True,
47
+ ).images
48
+
49
+ output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
50
+ return output_images
 
 
 
 
 
 
 
51
 
52
+ #gr.Interface(infer, inputs=["text", "text", "image"], outputs="gallery").launch()
 
53
 
54
+ title = "Animal Pose Control Net"
55
+ description = "This is a demo on ControlNet based on canny filter."
56
 
57
  with gr.Blocks(theme=gr.themes.Default(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"])) as demo:
58
  gr.Markdown(
 
61
  # This is a demo of Animal Pose Control Net, which is a model trained on runwayml/stable-diffusion-v1-5 with new type of conditioning.
62
  """)
63
 
64
+ gr.Examples(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  examples=[
66
  #["a tortoiseshell cat is sitting on a cushion"],
67
  #["a yellow dog standing on a lawn"],
68
  ["a tortoiseshell cat is sitting on a cushion", "https://huggingface.co/JFoz/dog-cat-pose/blob/main/images_0.png"],
69
  ["a yellow dog standing on a lawn", "https://huggingface.co/JFoz/dog-cat-pose/blob/main/images_1.png"],
70
+ ]
 
 
 
 
 
 
71
  cache_examples=True,
72
  )
73
+
74
+ gr.Interface(fn = infer, inputs = ["text", "text", "image"], outputs = "image",
75
+ title = title, description = description, examples = gr.examples, theme='gradio/soft').launch()
76
+
77
+ gr.Markdown(
78
  """
79
  * [Dataset](https://huggingface.co/datasets/JFoz/dog-poses-controlnet-dataset)
80
  * [Diffusers model](), [Web UI model](https://huggingface.co/JFoz/dog-pose)
81
  * [Training Report](https://wandb.ai/john-fozard/dog-cat-pose/runs/kmwcvae5))
82
  """)