kfahn committed
Commit ae56ea3
1 Parent(s): f2056b3

Update app.py


Troubleshooting

Files changed (1)
  1. app.py +10 -11
app.py CHANGED

@@ -18,22 +18,22 @@ pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
     "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
 )
 
-def infer(prompts, negative_prompts, image):
+def infer(prompt, image):
     params["controlnet"] = controlnet_params
 
     num_samples = 1 #jax.device_count()
     rng = create_key(0)
     rng = jax.random.split(rng, jax.device_count())
-    #im = canny_filter(image)
-    #canny_image = Image.fromarray(im)
+    im = image
+    image = Image.fromarray(im)
 
-    prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
-    negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
-    processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
+    #prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
+    #negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
+    processed_image = pipe.prepare_image_inputs([image] * num_samples)
 
     p_params = replicate(params)
-    prompt_ids = shard(prompt_ids)
-    negative_prompt_ids = shard(negative_prompt_ids)
+    #prompt_ids = shard(prompt_ids)
+    #negative_prompt_ids = shard(negative_prompt_ids)
     processed_image = shard(processed_image)
 
     output = pipe(
@@ -42,7 +42,7 @@ def infer(prompts, negative_prompts, image):
         params=p_params,
         prng_seed=rng,
         num_inference_steps=50,
-        neg_prompt_ids=negative_prompt_ids,
+        #neg_prompt_ids=negative_prompt_ids,
         jit=True,
     ).images
 
@@ -52,7 +52,7 @@ def infer(prompts, negative_prompts, image):
 #gr.Interface(infer, inputs=["text", "text", "image"], outputs="gallery").launch()
 
 title = "Animal Pose Control Net"
-description = "This is a demo on ControlNet use animal keypoints."
+description = "This is a demo of Animal Pose ControlNet, which is a model trained on runwayml/stable-diffusion-v1-5 with new type of conditioning."
 
 with gr.Blocks(theme=gr.themes.Default(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"])) as demo:
     gr.Markdown(
@@ -68,7 +68,6 @@ gr.Examples(
     ["a tortoiseshell cat is sitting on a cushion", "https://huggingface.co/JFoz/dog-cat-pose/blob/main/images_0.png"],
     ["a yellow dog standing on a lawn", "https://huggingface.co/JFoz/dog-cat-pose/blob/main/images_1.png"],
     ]
-    cache_examples=True,
 )
 
 gr.Interface(fn = infer, inputs = ["text", "text", "image"], outputs = "image",
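Note: after this commit, infer takes only (prompt, image), while the gr.Interface call at the bottom of the file still declares three inputs (["text", "text", "image"]). A minimal sketch of how that call could be aligned with the new two-argument signature is below; it is not code from this commit, and the stub infer is only a placeholder so the snippet is self-contained rather than the Flax ControlNet pipeline code in app.py.

# Sketch only (assumed wiring, not the author's code): a Gradio interface that
# matches the new infer(prompt, image) signature from this commit.
import gradio as gr
from PIL import Image


def infer(prompt, image):
    # Placeholder for the pipeline call in app.py; it just echoes the uploaded
    # image (a NumPy array from the Gradio image component) back as a PIL image.
    return Image.fromarray(image)


demo = gr.Interface(
    fn=infer,
    inputs=["text", "image"],   # one text prompt plus one conditioning image
    outputs="image",
    title="Animal Pose Control Net",
)

if __name__ == "__main__":
    demo.launch()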