EnriqueVega1995 commited on
Commit
8fb43c2
1 Parent(s): 98d9696
Files changed (2) hide show
  1. app.py +42 -21
  2. requirements.txt +4 -3
app.py CHANGED
@@ -1,29 +1,50 @@
1
- import gradio as gr
2
- from diffusers import DDPMPipeline
3
  import torch
 
 
 
 
 
 
 
 
4
 
5
- # Define la función para generar y mostrar la imagen de un gato
6
- def greet_and_generate(name):
7
- # Saludo
8
- greeting = "Hello " + name + "!!"
 
 
9
 
10
- # Carga el modelo DDPM
11
- ddpm = DDPMPipeline.from_pretrained("google/ddpm-cat-256", use_safetensors=True).to("cpu")
12
 
13
- # Genera la imagen
14
- with torch.no_grad(): # Desactiva el cálculo de gradientes para ahorrar memoria
15
- image = ddpm(num_inference_steps=25).images[0]
 
 
 
 
16
 
17
- # Convierte la imagen a un formato que Gradio puede mostrar (PIL.Image)
18
- image = image.cpu().detach().convert("RGBA")
 
 
 
 
 
 
 
19
 
20
- return greeting, image
21
 
22
- # Configura la interfaz de Gradio
23
- iface = gr.Interface(fn=greet_and_generate,
24
- inputs="text",
25
- outputs=["text", "image"],
26
- examples=[["John"], ["Jane"], ["Alex"]])
 
 
27
 
28
- # Lanza la interfaz de Gradio
29
- iface.launch()
 
 
 
1
  import torch
2
+ import torchvision
3
+ from torchvision import models, transforms
4
+ import gradio as gr
5
+ from PIL import Image
6
+
7
+ # Cargar el modelo preentrenado
8
+ model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
9
+ model.eval()
10
 
11
+ # Función para realizar la detección de objetos
12
+ def object_detection(image):
13
+ # Transformaciones necesarias para la imagen
14
+ transform = transforms.Compose([
15
+ transforms.ToTensor(),
16
+ ])
17
 
18
+ image = transform(image).unsqueeze(0)
19
+ preds = model(image)
20
 
21
+ # Procesar las predicciones
22
+ pred_classes = [torchvision.models.detection._utils.BoxCoder.get_class(i) for i in list(preds[0]['labels'].numpy())] # Nombres de las clases detectadas
23
+ pred_boxes = [[(i[0], i[1]), (i[2], i[3])] for i in list(preds[0]['boxes'].detach().numpy())] # Coordenadas de los cuadros de detección
24
+ pred_scores = list(preds[0]['scores'].detach().numpy())
25
+
26
+ # Filtrar las predicciones con una puntuación baja
27
+ pred_t = [pred_scores.index(x) for x in pred_scores if x > 0.5] # Umbral de puntuación
28
 
29
+ if len(pred_t) != 0:
30
+ pred_t = pred_t[-1]
31
+ pred_boxes = pred_boxes[:pred_t+1]
32
+ pred_classes = pred_classes[:pred_t+1]
33
+ pred_scores = pred_scores[:pred_t+1]
34
+ else:
35
+ pred_boxes = []
36
+ pred_classes = []
37
+ pred_scores = []
38
 
39
+ return image, pred_boxes, pred_classes, pred_scores
40
 
41
+ # Interfaz de Gradio
42
+ gr_interface = gr.Interface(fn=object_detection,
43
+ inputs=gr.inputs.Image(shape=(512, 512)),
44
+ outputs=[gr.outputs.Image(type="pil"),
45
+ gr.outputs.Label(num_top_classes=3)],
46
+ title="Detección de Objetos",
47
+ description="Modelo de detección de objetos utilizando un Faster R-CNN ResNet50 preentrenado.")
48
 
49
+ if __name__ == "__main__":
50
+ gr_interface.launch()
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
- diffusers==0.26.3
2
- gradio==4.20.1
3
- torchvision==0.17.1
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ Pillow