kfahn committed
Commit 8118b09
1 Parent(s): cf1cacd

Update app.py

Adding control image

Files changed (1)
  1. app.py +12 -20
app.py CHANGED
@@ -18,24 +18,24 @@ pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
  "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
  )
 
- #def infer(prompt, image):
- def infer(prompt):
+ def infer(prompt, image):
+ #def infer(prompt):
  params["controlnet"] = controlnet_params
 
  num_samples = 1 #jax.device_count()
  rng = create_key(0)
  rng = jax.random.split(rng, jax.device_count())
- #im = image
- #image = Image.fromarray(im)
+ im = image
+ image = Image.fromarray(im)
 
  prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
  #negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
- #processed_image = pipe.prepare_image_inputs([image] * num_samples)
+ processed_image = pipe.prepare_image_inputs([image] * num_samples)
 
  p_params = replicate(params)
  prompt_ids = shard(prompt_ids)
  #negative_prompt_ids = shard(negative_prompt_ids)
- #processed_image = shard(processed_image)
+ processed_image = shard(processed_image)
 
  output = pipe(
  prompt_ids=prompt_ids,
@@ -53,13 +53,6 @@ def infer(prompt):
  #gr.Interface(infer, inputs=["text", "text", "image"], outputs="gallery").launch()
 
 
-
-
-
-
-
-
-
  title = "Animal Pose Control Net"
  description = "This is a demo of Animal Pose ControlNet, which is a model trained on runwayml/stable-diffusion-v1-5 with new type of conditioning."
 
@@ -75,14 +68,13 @@ description = "This is a demo of Animal Pose ControlNet, which is a model traine
  # button_primary_background_fill_hover="*primary_300",
  #)
 
- #gr.Interface(fn = infer, inputs = ["text", "text", "image"], outputs = "image",
- gr.Interface(fn = infer, inputs = ["text"], outputs = "image",
- title = title, description = description, theme='gradio/soft').launch()
+ #gr.Interface(fn = infer, inputs = ["text"], outputs = "image",
+ # title = title, description = description, theme='gradio/soft').launch()
 
- #gr.Interface(fn = infer, inputs = ["text", "text", "image"], outputs = "gallery",
- # title = title, description = description, theme='gradio/soft',
- # examples=[["a Labrador crossing the road", "low quality", "pose_256.jpg"]]
- #).launch()
+ gr.Interface(fn = infer, inputs = ["text", "text", "image"], outputs = "gallery",
+ title = title, description = description, theme='gradio/soft',
+ examples=[["a Labrador crossing the road", "low quality", "image_control.png"]]
+ ).launch()
 
  gr.Markdown(
  """