"""Streamlit front-end for Real-ESRGAN image super-resolution.

Lets the user upload an image, pick a Real-ESRGAN model and inference
settings, optionally run GFPGAN face enhancement on top, and download the
enhanced result.
"""
import os

import cv2
import numpy as np
import streamlit as st
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from gfpgan import GFPGANer
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact

# Directory (relative to the working directory) where model weights are cached.
WEIGHTS_DIR = 'weights'
# Scratch directory for the uploaded input and the enhanced output.
TEMP_DIR = 'temp'

# Release URLs of every weight file each supported model needs. Each file is
# stored under its URL basename, which for the single-file models equals
# "<model_name>.pth". realesr-general-x4v3 additionally needs the companion
# "wdn" weights, which load_model blends in via DNI (see denoise_strength).
MODEL_URLS = {
    'RealESRGAN_x4plus': [
        'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
    ],
    'RealESRGAN_x4plus_anime_6B': [
        'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth',
    ],
    'RealESRGAN_x2plus': [
        'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
    ],
    'realesr-general-x4v3': [
        'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
        'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
    ],
}

GFPGAN_MODEL_URL = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'


def _unsupported(model_name):
    """Build the error raised for a model name that is not in MODEL_URLS."""
    return ValueError(
        f"Unsupported model name: {model_name!r}; expected one of {sorted(MODEL_URLS)}")


def _build_network(model_name):
    """Return ``(network, netscale)`` for ``model_name``; raise ValueError if unknown."""
    if model_name == 'RealESRGAN_x4plus':
        net = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        return net, 4
    if model_name == 'RealESRGAN_x4plus_anime_6B':
        net = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        return net, 4
    if model_name == 'RealESRGAN_x2plus':
        net = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        return net, 2
    if model_name == 'realesr-general-x4v3':
        net = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        return net, 4
    raise _unsupported(model_name)


def _download_weights(model_name, weights_dir):
    """Ensure every weight file for ``model_name`` exists in ``weights_dir``.

    Missing files are downloaded and stored under their URL basename; files
    already on disk are left untouched.

    Returns:
        Path of the model's main weight file, ``<weights_dir>/<model_name>.pth``.

    Raises:
        ValueError: if ``model_name`` is not in MODEL_URLS.
    """
    urls = MODEL_URLS.get(model_name)
    if urls is None:
        raise _unsupported(model_name)
    os.makedirs(weights_dir, exist_ok=True)
    for url in urls:
        # Keep each file's own name so multi-file models don't overwrite each other.
        file_name = os.path.basename(url)
        if not os.path.isfile(os.path.join(weights_dir, file_name)):
            load_file_from_url(url=url, model_dir=weights_dir, progress=True, file_name=file_name)
    return os.path.join(weights_dir, model_name + '.pth')


def load_model(model_name, model_path, denoise_strength, tile, tile_pad, pre_pad, fp32, gpu_id):
    """Build a RealESRGANer upsampler for ``model_name``.

    Args:
        model_name: one of the keys of MODEL_URLS.
        model_path: path to the main weight file, or None to download the
            weights into WEIGHTS_DIR (under the working directory).
        denoise_strength: only used by ``realesr-general-x4v3``. When it is
            not 1, the general weights and the sibling
            ``realesr-general-wdn-x4v3`` weights (expected in the same
            directory as ``model_path``) are interpolated via DNI with
            weights ``[denoise_strength, 1 - denoise_strength]``.
        tile, tile_pad, pre_pad: forwarded unchanged to RealESRGANer.
        fp32: if False, the upsampler runs in half precision.
        gpu_id: forwarded unchanged to RealESRGANer (None lets it choose).

    Returns:
        The configured RealESRGANer instance.

    Raises:
        ValueError: if ``model_name`` is not supported.
    """
    model, netscale = _build_network(model_name)

    if model_path is None:
        model_path = _download_weights(model_name, os.path.join(os.getcwd(), WEIGHTS_DIR))

    # Use DNI (network interpolation) to control the denoise strength.
    dni_weight = None
    if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
        # Rewrite only the file name, not directory components of the path.
        wdn_model_path = os.path.join(
            os.path.dirname(model_path),
            os.path.basename(model_path).replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3'))
        model_path = [model_path, wdn_model_path]
        dni_weight = [denoise_strength, 1 - denoise_strength]

    return RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=tile,
        tile_pad=tile_pad,
        pre_pad=pre_pad,
        half=not fp32,
        gpu_id=gpu_id)


def ensure_model_weights(model_name):
    """Download the weights for ``model_name`` into WEIGHTS_DIR if not present.

    Returns:
        Path of the main weight file (``weights/<model_name>.pth``).

    Raises:
        ValueError: if ``model_name`` is not supported.
    """
    return _download_weights(model_name, WEIGHTS_DIR)


# ---------------------------------------------------------------- Streamlit app
st.title("Real-ESRGAN Image Enhancement")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

# User selects model name, denoise strength, and other parameters
model_name = st.selectbox("Model Name", list(MODEL_URLS))
denoise_strength = st.slider(
    "Denoise Strength", 0.0, 1.0, 0.5,
    help="Only used by the realesr-general-x4v3 model.")
outscale = st.slider("Output Scale", 1, 4, 2)
tile = st.slider("Tile Size", 0, 512, 256)  # 0 disables tiling
tile_pad = 10
pre_pad = 0
face_enhance = st.checkbox("Face Enhance")
fp32 = st.checkbox("Use FP32 Precision")
gpu_id = None  # or set to 0, 1, etc. if you have multiple GPUs

if uploaded_file is not None:
    col1, col2 = st.columns(2)
    with col1:
        st.write("### Original Image")
        st.image(uploaded_file, use_column_width=True)

    run_button = st.button("Run")

    # getvalue() returns the whole upload regardless of the stream position
    # that earlier consumers (e.g. st.image) may have left behind.
    input_bytes = uploaded_file.getvalue()

    # Save uploaded image to disk
    input_image_path = os.path.join(TEMP_DIR, "input_image.png")
    os.makedirs(TEMP_DIR, exist_ok=True)
    with open(input_image_path, "wb") as f:
        f.write(input_bytes)

    if not run_button:
        st.warning("Click the 'Run' button to start the enhancement process.")
    else:
        # Ensure model weights are downloaded, then load the model
        model_path = ensure_model_weights(model_name)
        upsampler = load_model(model_name, model_path, denoise_strength, tile, tile_pad, pre_pad, fp32, gpu_id)

        img = cv2.imdecode(np.frombuffer(input_bytes, np.uint8), cv2.IMREAD_UNCHANGED)
        if img is None:
            st.error("Error loading image. Please try again.")
        else:
            try:
                if face_enhance:
                    # GFPGAN restores faces and uses the Real-ESRGAN upsampler for the background.
                    face_enhancer = GFPGANer(
                        model_path=GFPGAN_MODEL_URL,
                        upscale=outscale,
                        arch='clean',
                        channel_multiplier=2,
                        bg_upsampler=upsampler)
                    _, _, output = face_enhancer.enhance(
                        img, has_aligned=False, only_center_face=False, paste_back=True)
                else:
                    output, _ = upsampler.enhance(img, outscale=outscale)
            except RuntimeError as error:
                st.error(f"Error: {error}")
                st.error('If you encounter CUDA out of memory, try to set a smaller tile size.')
            else:
                # Save and display the output image
                output_image_path = os.path.join(TEMP_DIR, "output_image.png")
                if not cv2.imwrite(output_image_path, output):
                    st.error("Error saving the enhanced image.")
                else:
                    with col2:
                        st.write("### Enhanced Image")
                        st.image(output_image_path, use_column_width=True)
                    with open(output_image_path, "rb") as f:
                        output_bytes = f.read()
                    st.download_button(
                        "Download Enhanced Image",
                        data=output_bytes,
                        file_name="output_image.png",
                        mime="image/png")