import gradio as gr
from pathlib import Path
import torch
import torchvision.transforms as T
from PIL import Image
from utils import SSLModule
from io import BytesIO
import matplotlib.pyplot as plt
import os

# Load the model once at import time, on CPU, in eval (inference) mode.
checkpoints_dir = Path("saved_checkpoints")
checkpoint = "SSLhuge_satellite.pth"
device = "cpu"
ckpt_path = checkpoints_dir / checkpoint
model = SSLModule(ssl_path=str(ckpt_path))
model.to(device)
model = model.eval()

# Per-channel normalisation applied to the RGB input before it reaches the model.
# NOTE(review): these look like the per-channel mean/std of the training
# imagery — confirm against the checkpoint's training configuration.
norm = T.Normalize((0.420, 0.411, 0.296), (0.213, 0.156, 0.143))
norm = norm.to(device)


def _image_to_batch(image):
    """Turn an (H, W) or (H, W, C) uint8 image array into a normalised (1, 3, H, W) batch.

    Inputs with fewer than three channels (grayscale, grayscale+alpha) have
    their first channel replicated to RGB; any channel past the third (e.g.
    alpha) is dropped, so the tensor always matches the 3-value normalisation.
    """
    image_t = torch.as_tensor(image)
    if image_t.ndim == 2:  # (H, W) grayscale -> (H, W, 1)
        image_t = image_t.unsqueeze(-1)
    if image_t.shape[-1] < 3:  # replicate luminance into R, G and B
        image_t = image_t[..., :1].expand(-1, -1, 3)
    # Keep RGB only, HWC -> CHW, then scale 0..255 to 0..1.
    image_t = image_t[..., :3].permute(2, 0, 1).float().to(device) / 255
    return norm(image_t.unsqueeze(0))


def _render_height_map(pred_np):
    """Colourise a 2-D prediction with the "Greens" colormap and return it as a PIL image."""
    buffer = BytesIO()
    # Explicit format: a file-like target has no extension to infer it from.
    # NOTE(review): imsave appears to auto-scale colours to this map's own
    # min/max unless vmin/vmax are given, so shades are relative per image.
    plt.imsave(buffer, pred_np, cmap="Greens", format="png")
    buffer.seek(0)  # rewind so PIL reads the PNG from its first byte
    return Image.open(buffer)


def predict(image):
    """Estimate canopy height for one uploaded satellite image.

    Args:
        image: The value produced by ``gr.Image(type="numpy")`` — an image
            array (alpha or grayscale inputs are tolerated), or ``None`` when
            the user submits without uploading anything.

    Returns:
        A PIL image of the predicted height map rendered with the "Greens"
        colormap, or ``None`` if no image was provided.
    """
    if image is None:
        return None
    with torch.no_grad():
        pred = model(_image_to_batch(image))
        # Clamp negative predictions to zero (presumably because a negative
        # canopy height is meaningless).
        pred = pred.cpu().detach().relu()
    # NOTE(review): assumes the model output is (batch, channel, H, W); this
    # takes the first channel of the single batch item — confirm.
    pred_np = pred[0, 0].numpy()
    return _render_height_map(pred_np)


# Build the Gradio interface.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy", label="Upload a Satellite Image"),
    outputs=gr.Image(label="Estimated Canopy Height"),
    title="Estimate 🌳 Canopy Height from Satellite Images 🛰️",
    description="""

This application uses a pre-trained model to estimate canopy height from satellite images. Upload an image and see the result! (You can upload a screenshot from Google Maps, for example).

""",
    examples=[
        ["examples/image.png"],
        ["examples/image2.png"],
        ["examples/image3.png"],
    ],
    # NOTE(review): "here" carries no link — the URL appears to have been
    # lost; restore it (Gradio renders Markdown/HTML in `article`).
    article="""

Find more information here.

""",
    # Gradio takes "never" / "auto" / "manual"; a boolean is deprecated.
    allow_flagging="never",
)

# Launch the interface.
demo.launch()