test01 / app.py
EnriqueVega1995's picture
test
4e46657 unverified
raw
history blame
990 Bytes
import gradio as gr
from diffusers import DDPMPipeline
import torch
# Cargar el modelo DDPM preentrenado
ddpm = DDPMPipeline.from_pretrained("google/ddpm-cat-256", use_safetensors=True).to("cuda")
def generate_cat_image(num_inference_steps):
# Generar una imagen de gato
with torch.no_grad():
image = ddpm(num_inference_steps=num_inference_steps)["sample"][0]
# Convertir la imagen de tensor a PIL para mostrarla en Gradio
image = image.permute(1, 2, 0).cpu().numpy()
return image
# Interfaz de Gradio
gr_interface = gr.Interface(fn=generate_cat_image,
inputs=gr.inputs.Slider(minimum=10, maximum=100, step=1, default=50, label="Número de Pasos de Inferencia"),
outputs="image",
title="Generador de Imágenes de Gatos",
description="Modelo DDPM para generar imágenes de gatos.")
if __name__ == "__main__":
gr_interface.launch()