File size: 6,766 Bytes
cae212d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164e0dd
 
cae212d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164e0dd
cae212d
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import streamlit as st
import cv2
import os
import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from gfpgan import GFPGANer

def load_model(model_name, model_path, denoise_strength, tile, tile_pad, pre_pad, fp32, gpu_id):
    """Build a ``RealESRGANer`` upsampler for the named Real-ESRGAN model.

    Args:
        model_name: One of 'RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B',
            'RealESRGAN_x2plus' or 'realesr-general-x4v3'.
        model_path: Path to the ``.pth`` weights. When None, the function looks
            for ``weights/<model_name>.pth``. If that file is missing, it
            downloads the release weights into ``<cwd>/weights``.
        denoise_strength: Used only for 'realesr-general-x4v3'. When != 1, the
            regular and the "wdn" weights are blended with DNI weights
            ``[denoise_strength, 1 - denoise_strength]``.
        tile, tile_pad, pre_pad: Forwarded unchanged to ``RealESRGANer``.
        fp32: If True, run in full precision; otherwise ``half=True`` is passed.
        gpu_id: Forwarded unchanged to ``RealESRGANer``.

    Returns:
        The configured ``RealESRGANer`` instance.

    Raises:
        ValueError: If ``model_name`` is not one of the supported models.
    """
    if model_name == 'RealESRGAN_x4plus':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
    elif model_name == 'RealESRGAN_x4plus_anime_6B':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
    elif model_name == 'RealESRGAN_x2plus':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
    elif model_name == 'realesr-general-x4v3':
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        netscale = 4
        # The "wdn" weights come first, so the regular weights are the last
        # download and end up in model_path. The DNI block below derives the
        # wdn path from it.
        file_url = [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
        ]
    else:
        # Fail clearly instead of hitting an UnboundLocalError on `model` below.
        raise ValueError(f"Unsupported model name: {model_name!r}")

    # Determine model paths (an explicit model_path is used as-is).
    if model_path is None:
        model_path = os.path.join('weights', model_name + '.pth')
        if not os.path.isfile(model_path):
            for url in file_url:
                # file_name=None keeps each URL's own basename, so multi-file
                # models don't overwrite each other. For the single-file models
                # that basename equals model_name + '.pth'. model_path ends up
                # pointing at the last download.
                model_path = load_file_from_url(
                    url=url, model_dir=os.path.join(os.getcwd(), 'weights'), progress=True, file_name=None)

    # Use DNI (deep network interpolation) to control the denoise strength.
    dni_weight = None
    if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
        wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
        model_path = [model_path, wdn_model_path]
        dni_weight = [denoise_strength, 1 - denoise_strength]

    # Restorer
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=tile,
        tile_pad=tile_pad,
        pre_pad=pre_pad,
        half=not fp32,
        gpu_id=gpu_id)

    return upsampler

def ensure_model_weights(model_name):
    """Return the local path of ``weights/<model_name>.pth``, downloading it if needed.

    The ``weights`` directory (relative to the current working directory) is
    created if absent. An existing file is returned without any download.

    Args:
        model_name: Model identifier. Download URLs are known for
            'RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B' and
            'RealESRGAN_x2plus'.

    Returns:
        ``weights/<model_name>.pth`` if it already exists. Otherwise, the path
        returned by ``load_file_from_url`` after downloading.

    Raises:
        ValueError: If the weights file is missing and no download URL is known
            for ``model_name``.
    """
    weights_dir = 'weights'
    model_file = f"{model_name}.pth"
    model_path = os.path.join(weights_dir, model_file)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(weights_dir, exist_ok=True)

    if os.path.isfile(model_path):
        return model_path

    release_urls = {
        'RealESRGAN_x4plus': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
        'RealESRGAN_x4plus_anime_6B': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth',
        'RealESRGAN_x2plus': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
    }
    file_url = release_urls.get(model_name)
    if file_url is None:
        # Previously this surfaced as an UnboundLocalError on `file_url`.
        raise ValueError(f"No download URL known for model {model_name!r}")

    return load_file_from_url(
        url=file_url, model_dir=weights_dir, progress=True, file_name=model_file)

# Streamlit app: upload an image, pick a Real-ESRGAN model and options,
# then run the enhancement and show/download the result.
st.title("Real-ESRGAN Image Enhancement")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

# User selects model name, denoise strength, and other parameters
model_name = st.selectbox("Model Name", ['RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B', 'RealESRGAN_x2plus'])
# load_model only uses denoise_strength for 'realesr-general-x4v3'.
denoise_strength = st.slider("Denoise Strength", 0.0, 1.0, 0.5)
outscale = st.slider("Output Scale", 1, 4, 2)  # Reduce output scale to 2
tile = st.slider("Tile Size", 0, 512, 256)  # Add tile size slider
tile_pad = 10
pre_pad = 0
face_enhance = st.checkbox("Face Enhance")
fp32 = st.checkbox("Use FP32 Precision")
gpu_id = None  # or set to 0, 1, etc. if you have multiple GPUs

if uploaded_file is not None:
    col1, col2 = st.columns(2)
    with col1:
        st.write("### Original Image")
        st.image(uploaded_file, use_column_width=True)
        run_button = st.button("Run")

    if not run_button:
        st.warning("Click the 'Run' button to start the enhancement process.")
    else:
        # Ensure model weights are downloaded, then build the upsampler.
        model_path = ensure_model_weights(model_name)
        upsampler = load_model(model_name, model_path, denoise_strength, tile, tile_pad, pre_pad, fp32, gpu_id)

        # getvalue() returns the whole upload regardless of the stream position.
        # st.image() above may already have consumed the stream, and then read()
        # would return b"" and make imdecode fail.
        raw_bytes = uploaded_file.getvalue()
        img = cv2.imdecode(np.frombuffer(raw_bytes, np.uint8), cv2.IMREAD_UNCHANGED)
        if img is None:
            st.error("Error loading image. Please try again.")
        else:
            try:
                if face_enhance:
                    face_enhancer = GFPGANer(
                        model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
                        upscale=outscale,
                        arch='clean',
                        channel_multiplier=2,
                        bg_upsampler=upsampler)
                    _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
                else:
                    output, _ = upsampler.enhance(img, outscale=outscale)
            except RuntimeError as error:
                st.error(f"Error: {error}")
                st.error('If you encounter CUDA out of memory, try to set a smaller tile size.')
            else:
                # Encode the result in memory rather than through a fixed file
                # under temp/: a shared path is overwritten by concurrent
                # sessions, and this also avoids reopening (and leaking) the
                # file handle for the download button.
                ok, encoded = cv2.imencode(".png", output)
                if not ok:
                    st.error("Error encoding the enhanced image.")
                else:
                    png_bytes = encoded.tobytes()
                    with col2:
                        st.write("### Enhanced Image")
                        st.image(png_bytes, use_column_width=True)
                        st.download_button("Download Enhanced Image", data=png_bytes, file_name="output_image.png", mime="image/png")