MuhammadHanif commited on
Commit
cbb6863
1 Parent(s): ff05647

sd high res

Browse files
Files changed (1) hide show
  1. app.py +48 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import jax
3
+ import numpy as np
4
+ import jax.numpy as jnp
5
+ from flax.jax_utils import replicate
6
+ from flax.training.common_utils import shard
7
+ from PIL import Image
8
+ from diffusers import FlaxStableDiffusionPipeline
9
+
10
+ def create_key(seed=0):
11
+ return jax.random.PRNGKey(seed)
12
+
13
+
14
+ pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
15
+ "MuhammadHanif/stable-diffusion-v1-5-high-res",
16
+ controlnet=controlnet, revision="flax",
17
+ dtype=jnp.bfloat16,
18
+ use_memory_efficient_attention=True
19
+ )
20
+
21
+ def infer(prompts, negative_prompts):
22
+
23
+ num_samples = 1 #jax.device_count()
24
+ rng = create_key(0)
25
+ rng = jax.random.split(rng, jax.device_count())
26
+
27
+ prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
28
+ negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
29
+
30
+ p_params = replicate(params)
31
+ prompt_ids = shard(prompt_ids)
32
+ negative_prompt_ids = shard(negative_prompt_ids)
33
+
34
+ output = pipe(
35
+ prompt_ids=prompt_ids,
36
+ params=p_params,
37
+ height=1088,
38
+ width=1088,
39
+ prng_seed=rng,
40
+ num_inference_steps=50,
41
+ neg_prompt_ids=negative_prompt_ids,
42
+ jit=True,
43
+ ).images
44
+
45
+ output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
46
+ return output_images
47
+
48
+ gr.Interface(infer, inputs=["text", "text"], outputs="gallery").launch()