File size: 4,460 Bytes
2a37fe9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from pathlib import Path

import cv2
from diffusers.utils import logging
from huggingface_hub import hf_hub_download
from PIL import Image
from torch import nn

try:
    from basicsr.archs.rrdbnet_arch import RRDBNet
    from realesrgan import RealESRGANer
except ImportError as e:
    raise ImportError(
        "You tried to import realesrgan without having it installed properly. To install Real-ESRGAN, run:\n\n"
        "pip install realesrgan"
    )

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class RealESRGANModel(nn.Module):
    def __init__(self, model_path, tile=0, tile_pad=10, pre_pad=0, fp32=False):
        super().__init__()
        try:
            from basicsr.archs.rrdbnet_arch import RRDBNet
            from realesrgan import RealESRGANer
        except ImportError as e:
            raise ImportError(
                "You tried to import realesrgan without having it installed properly. To install Real-ESRGAN, run:\n\n"
                "pip install realesrgan"
            )

        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        self.upsampler = RealESRGANer(
            scale=4, model_path=model_path, model=model, tile=tile, tile_pad=tile_pad, pre_pad=pre_pad, half=not fp32
        )

    def forward(self, image, outscale=4, convert_to_pil=True):
        """Upsample an image array or path.
        Args:
            image (Union[np.ndarray, str]): Either a np array or an image path. np array is assumed to be in RGB format,
                and we convert it to BGR.
            outscale (int, optional): Amount to upscale the image. Defaults to 4.
            convert_to_pil (bool, optional): If True, return PIL image. Otherwise, return numpy array (BGR). Defaults to True.
        Returns:
            Union[np.ndarray, PIL.Image.Image]: An upsampled version of the input image.
        """
        if isinstance(image, (str, Path)):
            img = cv2.imread(image, cv2.IMREAD_UNCHANGED)
        else:
            img = image
            img = (img * 255).round().astype("uint8")
            img = img[:, :, ::-1]

        image, _ = self.upsampler.enhance(img, outscale=outscale)

        if convert_to_pil:
            image = Image.fromarray(image[:, :, ::-1])

        return image

    @classmethod
    def from_pretrained(cls, model_name_or_path="nateraw/real-esrgan"):
        """Initialize a pretrained Real-ESRGAN upsampler.
        Example:
            ```python
            >>> from stable_diffusion_videos import PipelineRealESRGAN
            >>> pipe = PipelineRealESRGAN.from_pretrained('nateraw/real-esrgan')
            >>> im_out = pipe('input_img.jpg')
            ```
        Args:
            model_name_or_path (str, optional): The Hugging Face repo ID or path to local model. Defaults to 'nateraw/real-esrgan'.
        Returns:
            stable_diffusion_videos.PipelineRealESRGAN: An instance of `PipelineRealESRGAN` instantiated from pretrained model.
        """
        # reuploaded form official ones mentioned here:
        # https://github.com/xinntao/Real-ESRGAN
        if Path(model_name_or_path).exists():
            file = model_name_or_path
        else:
            file = hf_hub_download(model_name_or_path, "RealESRGAN_x4plus.pth")
        return cls(file)

    def upsample_imagefolder(self, in_dir, out_dir, suffix="out", outfile_ext=".png", recursive=False, force=False):
        in_dir, out_dir = Path(in_dir), Path(out_dir)
        if not in_dir.exists():
            raise FileNotFoundError(f"Provided input directory {in_dir} does not exist")

        out_dir.mkdir(exist_ok=True, parents=True)

        generator = in_dir.rglob("*") if recursive else in_dir.glob("*")
        image_paths = [x for x in generator if x.suffix.lower() in [".png", ".jpg", ".jpeg"]]
        n_img = len(image_paths)
        for i, image in enumerate(image_paths):
            out_filepath = out_dir / (str(image.relative_to(in_dir).with_suffix("")) + suffix + outfile_ext)
            if not force and out_filepath.exists():
                logger.info(
                    f"[{i}/{n_img}] {out_filepath} already exists, skipping. To avoid skipping, pass force=True."
                )
                continue
            logger.info(f"[{i}/{n_img}] upscaling {image}")
            im = self(str(image))
            out_filepath.parent.mkdir(parents=True, exist_ok=True)
            im.save(out_filepath)