import gradio as gr import jax import jax.numpy as jnp import numpy as np from flax.jax_utils import replicate from flax.training.common_utils import shard from PIL import Image from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel import gc def create_key(seed=0): return jax.random.PRNGKey(seed) def addp5sketch(url): iframe = f'