meornabil's picture
Update app.py
164e0dd verified
raw
history blame contribute delete
No virus
6.77 kB
import streamlit as st
import cv2
import os
import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from gfpgan import GFPGANer
# Function to load the model
def load_model(model_name, model_path, denoise_strength, tile, tile_pad, pre_pad, fp32, gpu_id):
    """Build a ``RealESRGANer`` upsampler for the requested model.

    Args:
        model_name: One of ``'RealESRGAN_x4plus'``, ``'RealESRGAN_x4plus_anime_6B'``,
            ``'RealESRGAN_x2plus'`` or ``'realesr-general-x4v3'``.
        model_path: Path to the weights file. When ``None``, ``weights/<model_name>.pth``
            is used and, if that file is missing, the weights are downloaded into
            ``<cwd>/weights``.
        denoise_strength: Only used for ``'realesr-general-x4v3'``. When it is not 1,
            the regular and the "wdn" weights are blended with DNI weights
            ``[denoise_strength, 1 - denoise_strength]``.
        tile: Tile size forwarded to ``RealESRGANer``.
        tile_pad: Tile padding forwarded to ``RealESRGANer``.
        pre_pad: Pre-padding forwarded to ``RealESRGANer``.
        fp32: When True the model runs in full precision; otherwise ``half=True``
            is passed to ``RealESRGANer``.
        gpu_id: Forwarded to ``RealESRGANer``.

    Returns:
        The configured ``RealESRGANer`` instance.

    Raises:
        ValueError: If ``model_name`` is not one of the supported models.
    """
    if model_name == 'RealESRGAN_x4plus':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
    elif model_name == 'RealESRGAN_x4plus_anime_6B':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
    elif model_name == 'RealESRGAN_x2plus':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
    elif model_name == 'realesr-general-x4v3':
        # Compact VGG-style x4 model. The "wdn" weights are listed first so the
        # download loop below leaves model_path on the regular weights, which
        # the DNI step then pairs with the wdn file.
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        netscale = 4
        file_url = [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
        ]
    else:
        # Without this, model/netscale/file_url would be unbound further down.
        raise ValueError(f"Unsupported model name: {model_name!r}")

    # Resolve the weights path; an explicit model_path is used as given.
    if model_path is None:
        model_path = os.path.join('weights', model_name + '.pth')
        if not os.path.isfile(model_path):
            weights_dir = os.path.join(os.getcwd(), 'weights')
            for url in file_url:
                # Keep each file's own name so multi-file models do not overwrite
                # one another; the last URL is the model's own weights file.
                model_path = load_file_from_url(
                    url=url, model_dir=weights_dir, progress=True, file_name=os.path.basename(url))

    # Use DNI (deep network interpolation) to control the denoise strength by
    # blending the regular weights with the "wdn" weights.
    dni_weight = None
    if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
        wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
        model_path = [model_path, wdn_model_path]
        dni_weight = [denoise_strength, 1 - denoise_strength]

    # Restorer
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=tile,
        tile_pad=tile_pad,
        pre_pad=pre_pad,
        half=not fp32,
        gpu_id=gpu_id)
    return upsampler
# Function to download model weights if not present
def ensure_model_weights(model_name):
    """Return the local weights path for ``model_name``, downloading it if missing.

    The weights live at ``weights/<model_name>.pth`` relative to the current
    working directory; the ``weights`` directory is created if needed. If the
    file already exists its path is returned as-is, without checking
    ``model_name`` against the known models.

    Args:
        model_name: Model identifier, e.g. ``'RealESRGAN_x4plus'``.

    Returns:
        Path to the weights file: ``weights/<model_name>.pth`` if it already
        exists, otherwise the path returned by ``load_file_from_url``.

    Raises:
        ValueError: If the file is missing and no download URL is known for
            ``model_name``.
    """
    weights_dir = 'weights'
    model_file = f"{model_name}.pth"
    model_path = os.path.join(weights_dir, model_file)
    os.makedirs(weights_dir, exist_ok=True)
    if os.path.isfile(model_path):
        return model_path

    download_urls = {
        'RealESRGAN_x4plus': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
        'RealESRGAN_x4plus_anime_6B': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth',
        'RealESRGAN_x2plus': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
    }
    try:
        file_url = download_urls[model_name]
    except KeyError:
        # Previously this surfaced as an UnboundLocalError on file_url.
        raise ValueError(f"No download URL known for model {model_name!r}") from None
    return load_file_from_url(url=file_url, model_dir=weights_dir, progress=True, file_name=model_file)
# Streamlit app
st.title("Real-ESRGAN Image Enhancement")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

# User selects model name, denoise strength, and other parameters
model_name = st.selectbox("Model Name", ['RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B', 'RealESRGAN_x2plus'])
# NOTE: load_model only applies denoise_strength for 'realesr-general-x4v3'
# (via DNI), which is not offered above, so this slider currently has no effect.
denoise_strength = st.slider("Denoise Strength", 0.0, 1.0, 0.5)
outscale = st.slider("Output Scale", 1, 4, 2)  # default output scale of 2
tile = st.slider("Tile Size", 0, 512, 256)  # smaller tiles help with CUDA out-of-memory
tile_pad = 10
pre_pad = 0
face_enhance = st.checkbox("Face Enhance")
fp32 = st.checkbox("Use FP32 Precision")
gpu_id = None  # or set to 0, 1, etc. if you have multiple GPUs

if uploaded_file is not None:
    col1, col2 = st.columns(2)
    with col1:
        st.write("### Original Image")
        st.image(uploaded_file, use_column_width=True)

    run_button = st.button("Run")
    if not run_button:
        st.warning("Click the 'Run' button to start the enhancement process.")
    else:
        # Ensure model weights are downloaded, then build the upsampler.
        model_path = ensure_model_weights(model_name)
        upsampler = load_model(model_name, model_path, denoise_strength, tile, tile_pad, pre_pad, fp32, gpu_id)

        # getvalue() returns the whole upload regardless of the stream position;
        # read() would return b'' if the stream had already been consumed, and
        # cv2.imdecode raises on an empty buffer instead of returning None.
        img = cv2.imdecode(np.frombuffer(uploaded_file.getvalue(), np.uint8), cv2.IMREAD_UNCHANGED)
        if img is None:
            st.error("Error loading image. Please try again.")
        else:
            try:
                if face_enhance:
                    face_enhancer = GFPGANer(
                        model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
                        upscale=outscale,
                        arch='clean',
                        channel_multiplier=2,
                        bg_upsampler=upsampler)
                    _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
                else:
                    output, _ = upsampler.enhance(img, outscale=outscale)
            except RuntimeError as error:
                st.error(f"Error: {error}")
                st.error('If you encounter CUDA out of memory, try to set a smaller tile size.')
            else:
                # Encode in memory rather than via a fixed temp file, which
                # concurrent sessions could overwrite; the same bytes feed both
                # the preview and the download button.
                ok, encoded = cv2.imencode(".png", output)
                if not ok:
                    st.error("Error encoding the enhanced image.")
                else:
                    png_bytes = encoded.tobytes()
                    with col2:
                        st.write("### Enhanced Image")
                        st.image(png_bytes, use_column_width=True)
                        st.download_button("Download Enhanced Image", data=png_bytes, file_name="output_image.png", mime="image/png")