yijiu commited on
Commit
383ca9f
1 Parent(s): 44207e0

feat:recover image resolution

Browse files
Files changed (2) hide show
  1. app.py +3 -3
  2. tool_utils.py +1 -1
app.py CHANGED
@@ -31,16 +31,16 @@ def predict(numpy_img):
31
  # #convert output to numpy
32
  heatmaps_pred_np=heatmaps_pred.squeeze(0).permute(1,2,0).detach().cpu().numpy()
33
  # #heatmaps to joints location
34
- coord_joints=heatmaps_to_coords(heatmaps_pred_np,resolu_out=[256,256],prob_threshold=0.1)
35
  inference_time=time.time()-start_time
36
  inference_time_text="model inference time:{:.4f}s".format(inference_time)
37
  # #draw coords on image_np
38
- img_rgb=draw_joints(img_np,coord_joints)
39
  return img_rgb,inference_time_text
40
 
41
 
42
 
43
- demo=gr.Interface(fn=predict, inputs=gr.Image(),outputs=[gr.Image(type='numpy',width=256,height=256),"text"],examples=example_list)
44
 
45
  if __name__=="__main__":
46
  demo.launch(show_api=False)
 
31
  # #convert output to numpy
32
  heatmaps_pred_np=heatmaps_pred.squeeze(0).permute(1,2,0).detach().cpu().numpy()
33
  # #heatmaps to joints location
34
+ coord_joints=heatmaps_to_coords(heatmaps_pred_np,resolu_out=[numpy_img.shape[0],numpy_img.shape[1]],prob_threshold=0.1)
35
  inference_time=time.time()-start_time
36
  inference_time_text="model inference time:{:.4f}s".format(inference_time)
37
  # #draw coords on image_np
38
+ img_rgb=draw_joints(numpy_img,coord_joints)
39
  return img_rgb,inference_time_text
40
 
41
 
42
 
43
+ demo=gr.Interface(fn=predict, inputs=gr.Image(),outputs=[gr.Image(type='numpy'),"text"],examples=example_list)
44
 
45
  if __name__=="__main__":
46
  demo.launch(show_api=False)
tool_utils.py CHANGED
@@ -344,7 +344,7 @@ def heatmaps2rgb(heatmaps):
344
  # return img
345
  def draw_joints(img, pts):
346
  # Convert the image to the range [0, 255] for visualization
347
- img_visualization = (img * 255).astype(np.uint8)
348
 
349
  # Draw lines for the body parts
350
  for i in range(10, 13 - 1):
 
344
  # return img
345
  def draw_joints(img, pts):
346
  # Convert the image to the range [0, 255] for visualization
347
+ img_visualization = (img).astype(np.uint8)
348
 
349
  # Draw lines for the body parts
350
  for i in range(10, 13 - 1):