.gitignore → CodeFormer/.gitignore RENAMED
@@ -5,9 +5,9 @@ version.py
 
 # ignored files with suffix
 *.html
-*.png
-*.jpeg
-*.jpg
+# *.png
+# *.jpeg
+# *.jpg
 *.pt
 *.gif
 *.pth
@@ -122,8 +122,7 @@ venv.bak/
 .mypy_cache/
 
 # project
-CodeFormer/results/
-output/
+results/
 dlib/
 *.pth
 *_old*
CodeFormer/basicsr/utils/misc.py CHANGED
@@ -1,36 +1,13 @@
+import numpy as np
 import os
-import re
 import random
 import time
 import torch
-import numpy as np
 from os import path as osp
 
 from .dist_util import master_only
 from .logger import get_root_logger
 
-IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
-                torch.__version__)[0][:3])] >= [1, 12, 0]
-
-def gpu_is_available():
-    if IS_HIGH_VERSION:
-        if torch.backends.mps.is_available():
-            return True
-    return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
-
-def get_device(gpu_id=None):
-    if gpu_id is None:
-        gpu_str = ''
-    elif isinstance(gpu_id, int):
-        gpu_str = f':{gpu_id}'
-    else:
-        raise TypeError('Input should be int value.')
-
-    if IS_HIGH_VERSION:
-        if torch.backends.mps.is_available():
-            return torch.device('mps'+gpu_str)
-    return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
-
 
 def set_random_seed(seed):
     """Set random seeds."""
@@ -154,4 +131,4 @@ def sizeof_fmt(size, suffix='B'):
         if abs(size) < 1024.0:
             return f'{size:3.1f} {unit}{suffix}'
         size /= 1024.0
-    return f'{size:3.1f} Y{suffix}'
+    return f'{size:3.1f} Y{suffix}'
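For context on what this revert removes from `basicsr/utils/misc.py`: the deleted `IS_HIGH_VERSION` flag regex-parses `torch.__version__` into a `[major, minor, patch]` list and compares it against `[1, 12, 0]`, because `torch.backends.mps` (the Apple-silicon backend that `gpu_is_available` and `get_device` consulted) only exists from PyTorch 1.12 onward. A minimal standalone sketch of that version gate, assuming nothing beyond stock PyTorch (variable names here are illustrative):

```python
import re
import torch

# Python compares lists element-wise, so [1, 13, 0] >= [1, 12, 0] and
# [2, 0, 1] >= [1, 12, 0] both hold -- the same trick the deleted code used.
match = re.match(r'^(\d+)\.(\d+)\.(\d+)', torch.__version__)
is_high_version = [int(g) for g in match.groups()] >= [1, 12, 0]

# Prefer MPS when available (torch >= 1.12), else CUDA, else CPU.
if is_high_version and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```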
CodeFormer/basicsr/version.py CHANGED
@@ -1,5 +1,5 @@
 # GENERATED VERSION FILE
-# TIME: Sat Sep 21 15:31:46 2024
+# TIME: Sun Aug 7 15:14:26 2022
 __version__ = '1.3.2'
-__gitsha__ = '1.3.2'
+__gitsha__ = '6f94023'
 version_info = (1, 3, 2)
CodeFormer/facelib/utils/face_restoration_helper.py CHANGED
@@ -6,14 +6,8 @@ from torchvision.transforms.functional import normalize
 
 from facelib.detection import init_detection_model
 from facelib.parsing import init_parsing_model
-from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray, adain_npy
-from basicsr.utils.download_util import load_file_from_url
-from basicsr.utils.misc import get_device
+from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray
 
-dlib_model_url = {
-    'face_detector': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/mmod_human_face_detector-4cb19393.dat',
-    'shape_predictor_5': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/shape_predictor_5_face_landmarks-c4b1e980.dat'
-}
 
 def get_largest_face(det_faces, h, w):
 
@@ -70,15 +64,8 @@ class FaceRestoreHelper(object):
         self.crop_ratio = crop_ratio  # (h, w)
         assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1'
         self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
-        self.det_model = det_model
-
-        if self.det_model == 'dlib':
-            # standard 5 landmarks for FFHQ faces with 1024 x 1024
-            self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
-                                           [337.91089109, 488.38613861], [437.95049505, 493.51485149],
-                                           [513.58415842, 678.5049505]])
-            self.face_template = self.face_template / (1024 // face_size)
-        elif self.template_3points:
+
+        if self.template_3points:
             self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
         else:
             # standard 5 landmarks for FFHQ faces with 512 x 512
@@ -90,6 +77,7 @@ class FaceRestoreHelper(object):
            # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
            #                                [198.22603, 372.82502], [313.91018, 372.75659]])
 
+
         self.face_template = self.face_template * (face_size / 512.0)
         if self.crop_ratio[0] > 1:
             self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
@@ -109,16 +97,12 @@ class FaceRestoreHelper(object):
         self.pad_input_imgs = []
 
         if device is None:
-            # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-            self.device = get_device()
+            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         else:
             self.device = device
 
         # init face detection model
-        if self.det_model == 'dlib':
-            self.face_detector, self.shape_predictor_5 = self.init_dlib(dlib_model_url['face_detector'], dlib_model_url['shape_predictor_5'])
-        else:
-            self.face_detector = init_detection_model(det_model, half=False, device=self.device)
+        self.face_det = init_detection_model(det_model, half=False, device=self.device)
 
         # init face parsing model
         self.use_parse = use_parse
@@ -141,7 +125,7 @@ class FaceRestoreHelper(object):
             img = img[:, :, 0:3]
 
         self.input_img = img
-        self.is_gray = is_gray(img, threshold=10)
+        self.is_gray = is_gray(img, threshold=5)
         if self.is_gray:
             print('Grayscale input: True')
 
@@ -149,72 +133,25 @@
             f = 512.0/min(self.input_img.shape[:2])
             self.input_img = cv2.resize(self.input_img, (0,0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)
 
-    def init_dlib(self, detection_path, landmark5_path):
-        """Initialize the dlib detectors and predictors."""
-        try:
-            import dlib
-        except ImportError:
-            print('Please install dlib by running:' 'conda install -c conda-forge dlib')
-        detection_path = load_file_from_url(url=detection_path, model_dir='weights/dlib', progress=True, file_name=None)
-        landmark5_path = load_file_from_url(url=landmark5_path, model_dir='weights/dlib', progress=True, file_name=None)
-        face_detector = dlib.cnn_face_detection_model_v1(detection_path)
-        shape_predictor_5 = dlib.shape_predictor(landmark5_path)
-        return face_detector, shape_predictor_5
-
-    def get_face_landmarks_5_dlib(self,
-                                  only_keep_largest=False,
-                                  scale=1):
-        det_faces = self.face_detector(self.input_img, scale)
-
-        if len(det_faces) == 0:
-            print('No face detected. Try to increase upsample_num_times.')
-            return 0
-        else:
-            if only_keep_largest:
-                print('Detect several faces and only keep the largest.')
-                face_areas = []
-                for i in range(len(det_faces)):
-                    face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * (
-                        det_faces[i].rect.bottom() - det_faces[i].rect.top())
-                    face_areas.append(face_area)
-                largest_idx = face_areas.index(max(face_areas))
-                self.det_faces = [det_faces[largest_idx]]
-            else:
-                self.det_faces = det_faces
-
-        if len(self.det_faces) == 0:
-            return 0
-
-        for face in self.det_faces:
-            shape = self.shape_predictor_5(self.input_img, face.rect)
-            landmark = np.array([[part.x, part.y] for part in shape.parts()])
-            self.all_landmarks_5.append(landmark)
-
-        return len(self.all_landmarks_5)
-
-
     def get_face_landmarks_5(self,
                              only_keep_largest=False,
                              only_center_face=False,
                              resize=None,
                              blur_ratio=0.01,
                              eye_dist_threshold=None):
-        if self.det_model == 'dlib':
-            return self.get_face_landmarks_5_dlib(only_keep_largest)
-
         if resize is None:
             scale = 1
             input_img = self.input_img
         else:
             h, w = self.input_img.shape[0:2]
             scale = resize / min(h, w)
-            # scale = max(1, scale) # always scale up; comment this out for HD images, e.g., AIGC faces.
+            scale = max(1, scale) # always scale up
             h, w = int(h * scale), int(w * scale)
             interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
             input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)
 
         with torch.no_grad():
-            bboxes = self.face_detector.detect_faces(input_img)
+            bboxes = self.face_det.detect_faces(input_img)
 
         if bboxes is None or bboxes.shape[0] == 0:
             return 0
@@ -361,12 +298,10 @@
             torch.save(inverse_affine, save_path)
 
 
-    def add_restored_face(self, restored_face, input_face=None):
+    def add_restored_face(self, face):
         if self.is_gray:
-            restored_face = bgr2gray(restored_face) # convert img into grayscale
-            if input_face is not None:
-                restored_face = adain_npy(restored_face, input_face) # transfer the color
-        self.restored_faces.append(restored_face)
+            face = bgr2gray(face) # convert img into grayscale
+        self.restored_faces.append(face)
 
 
     def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
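The `face_template` arrays above are target positions for five landmarks (two eyes, nose tip, two mouth corners) in FFHQ's 512x512 layout, scaled by `face_size / 512.0`; the helper then warps each detected face so its landmarks land on the template before restoration. A hedged sketch of that alignment step using OpenCV (a standalone function for illustration, not the class's exact code):

```python
import cv2
import numpy as np

def align_face(img, landmarks_5, face_template, face_size=(512, 512)):
    """Warp img so landmarks_5 (a 5x2 float array) maps onto face_template."""
    # Similarity transform: rotation + uniform scale + translation.
    affine_matrix, _ = cv2.estimateAffinePartial2D(
        np.asarray(landmarks_5, dtype=np.float32),
        np.asarray(face_template, dtype=np.float32),
        method=cv2.LMEDS)
    # dsize is (width, height); pixels outside the source stay black.
    return cv2.warpAffine(img, affine_matrix, face_size, flags=cv2.INTER_LINEAR)
```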
CodeFormer/facelib/utils/misc.py CHANGED
@@ -7,13 +7,13 @@ import torch
 from torch.hub import download_url_to_file, get_dir
 from urllib.parse import urlparse
 # from basicsr.utils.download_util import download_file_from_google_drive
+# import gdown
+
 
 ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
 
 def download_pretrained_models(file_ids, save_path_root):
-    import gdown
-
     os.makedirs(save_path_root, exist_ok=True)
 
     for file_name, file_id in file_ids.items():
@@ -23,7 +23,7 @@ def download_pretrained_models(file_ids, save_path_root):
             user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n')
             if user_response.lower() == 'y':
                 print(f'Covering {file_name} to {save_path}')
-                gdown.download(file_url, save_path, quiet=False)
+                # gdown.download(file_url, save_path, quiet=False)
                 # download_file_from_google_drive(file_id, save_path)
             elif user_response.lower() == 'n':
                 print(f'Skipping {file_name}')
@@ -31,7 +31,7 @@ def download_pretrained_models(file_ids, save_path_root):
                 raise ValueError('Wrong input. Only accepts Y/N.')
         else:
             print(f'Downloading {file_name} to {save_path}')
-            gdown.download(file_url, save_path, quiet=False)
+            # gdown.download(file_url, save_path, quiet=False)
             # download_file_from_google_drive(file_id, save_path)
 
 
@@ -172,31 +172,3 @@ def bgr2gray(img, out_channel=3):
     if out_channel == 3:
         gray = gray[:,:,np.newaxis].repeat(3, axis=2)
     return gray
-
-
-def calc_mean_std(feat, eps=1e-5):
-    """
-    Args:
-        feat (numpy): 3D [w h c]s
-    """
-    size = feat.shape
-    assert len(size) == 3, 'The input feature should be 3D tensor.'
-    c = size[2]
-    feat_var = feat.reshape(-1, c).var(axis=0) + eps
-    feat_std = np.sqrt(feat_var).reshape(1, 1, c)
-    feat_mean = feat.reshape(-1, c).mean(axis=0).reshape(1, 1, c)
-    return feat_mean, feat_std
-
-
-def adain_npy(content_feat, style_feat):
-    """Adaptive instance normalization for numpy.
-
-    Args:
-        content_feat (numpy): The input feature.
-        style_feat (numpy): The reference feature.
-    """
-    size = content_feat.shape
-    style_mean, style_std = calc_mean_std(style_feat)
-    content_mean, content_std = calc_mean_std(content_feat)
-    normalized_feat = (content_feat - np.broadcast_to(content_mean, size)) / np.broadcast_to(content_std, size)
-    return normalized_feat * np.broadcast_to(style_std, size) + np.broadcast_to(style_mean, size)
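The `calc_mean_std`/`adain_npy` pair deleted above is AdaIN-style color transfer: subtract the content image's per-channel mean, divide by its std, then re-scale with the reference image's statistics. Before this revert, `add_restored_face(restored_face, input_face)` used it on grayscale inputs to keep the restored face's tones close to the original crop. A self-contained sketch of the same operation (the random arrays are stand-ins for real face crops):

```python
import numpy as np

def calc_mean_std(feat, eps=1e-5):
    # Per-channel mean and std over an H x W x C image, shaped for broadcasting.
    c = feat.shape[2]
    flat = feat.reshape(-1, c)
    return flat.mean(axis=0).reshape(1, 1, c), np.sqrt(flat.var(axis=0) + eps).reshape(1, 1, c)

restored = np.random.rand(64, 64, 3)   # stand-in for the restored face
reference = np.random.rand(64, 64, 3)  # stand-in for the original input crop
ref_mean, ref_std = calc_mean_std(reference)
out_mean, out_std = calc_mean_std(restored)
# Shift the restored image's channel statistics onto the reference's.
color_matched = (restored - out_mean) / out_std * ref_std + ref_mean
```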
CodeFormer/inference_codeformer.py CHANGED
@@ -1,3 +1,4 @@
+# Modified by Shangchen Zhou from: https://github.com/TencentARC/GFPGAN/blob/master/inference_gfpgan.py
 import os
 import cv2
 import argparse
@@ -6,9 +7,8 @@ import torch
 from torchvision.transforms.functional import normalize
 from basicsr.utils import imwrite, img2tensor, tensor2img
 from basicsr.utils.download_util import load_file_from_url
-from basicsr.utils.misc import gpu_is_available, get_device
 from facelib.utils.face_restoration_helper import FaceRestoreHelper
-from facelib.utils.misc import is_gray
+import torch.nn.functional as F
 
 from basicsr.utils.registry import ARCH_REGISTRY
 
@@ -17,104 +17,51 @@ pretrain_model_url = {
 }
 
 def set_realesrgan():
-    from basicsr.archs.rrdbnet_arch import RRDBNet
-    from basicsr.utils.realesrgan_utils import RealESRGANer
-
-    use_half = False
-    if torch.cuda.is_available(): # set False in CPU/MPS mode
-        no_half_gpu_list = ['1650', '1660'] # set False for GPUs that don't support f16
-        if not True in [gpu in torch.cuda.get_device_name(0) for gpu in no_half_gpu_list]:
-            use_half = True
-
-    model = RRDBNet(
-        num_in_ch=3,
-        num_out_ch=3,
-        num_feat=64,
-        num_block=23,
-        num_grow_ch=32,
-        scale=2,
-    )
-    upsampler = RealESRGANer(
-        scale=2,
-        model_path="https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth",
-        model=model,
-        tile=args.bg_tile,
-        tile_pad=40,
-        pre_pad=0,
-        half=use_half
-    )
-
-    if not gpu_is_available(): # CPU
+    if not torch.cuda.is_available(): # CPU
         import warnings
-        warnings.warn('Running on CPU now! Make sure your PyTorch version matches your CUDA.'
-                      'The unoptimized RealESRGAN is slow on CPU. '
-                      'If you want to disable it, please remove `--bg_upsampler` and `--face_upsample` in command.',
+        warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
+                      'If you really want to use it, please modify the corresponding codes.',
                       category=RuntimeWarning)
-    return upsampler
+        bg_upsampler = None
+    else:
+        from basicsr.archs.rrdbnet_arch import RRDBNet
+        from basicsr.utils.realesrgan_utils import RealESRGANer
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+        bg_upsampler = RealESRGANer(
+            scale=2,
+            model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
+            model=model,
+            tile=args.bg_tile,
+            tile_pad=40,
+            pre_pad=0,
+            half=True) # need to set False in CPU mode
+    return bg_upsampler
 
 if __name__ == '__main__':
-    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    device = get_device()
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    parser = argparse.ArgumentParser()
 
-    parser.add_argument('-i', '--input_path', type=str, default='./inputs/whole_imgs',
-            help='Input image, video or folder. Default: inputs/whole_imgs')
-    parser.add_argument('-o', '--output_path', type=str, default=None,
-            help='Output folder. Default: results/<input_name>_<w>')
-    parser.add_argument('-w', '--fidelity_weight', type=float, default=0.5,
-            help='Balance the quality and fidelity. Default: 0.5')
-    parser.add_argument('-s', '--upscale', type=int, default=2,
-            help='The final upsampling scale of the image. Default: 2')
-    parser.add_argument('--has_aligned', action='store_true', help='Input are cropped and aligned faces. Default: False')
-    parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face. Default: False')
-    parser.add_argument('--draw_box', action='store_true', help='Draw the bounding box for the detected faces. Default: False')
+    parser.add_argument('--w', type=float, default=0.5, help='Balance the quality and fidelity')
+    parser.add_argument('--upscale', type=int, default=2, help='The final upsampling scale of the image. Default: 2')
+    parser.add_argument('--test_path', type=str, default='./inputs/cropped_faces')
+    parser.add_argument('--has_aligned', action='store_true', help='Input are cropped and aligned faces')
+    parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face')
     # large det_model: 'YOLOv5l', 'retinaface_resnet50'
     # small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
-    parser.add_argument('--detection_model', type=str, default='retinaface_resnet50',
-            help='Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n, dlib. \
-                Default: retinaface_resnet50')
-    parser.add_argument('--bg_upsampler', type=str, default='None', help='Background upsampler. Optional: realesrgan')
-    parser.add_argument('--face_upsample', action='store_true', help='Face upsampler after enhancement. Default: False')
+    parser.add_argument('--detection_model', type=str, default='retinaface_resnet50')
+    parser.add_argument('--draw_box', action='store_true')
+    parser.add_argument('--bg_upsampler', type=str, default='None', help='background upsampler. Optional: realesrgan')
+    parser.add_argument('--face_upsample', action='store_true', help='face upsampler after enhancement.')
     parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for background sampler. Default: 400')
-    parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces. Default: None')
-    parser.add_argument('--save_video_fps', type=float, default=None, help='Frame rate for saving video. Default: None')
 
     args = parser.parse_args()
 
     # ------------------------ input & output ------------------------
-    w = args.fidelity_weight
-    input_video = False
-    if args.input_path.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')): # input single img path
-        input_img_list = [args.input_path]
-        result_root = f'results/test_img_{w}'
-    elif args.input_path.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path
-        from basicsr.utils.video_util import VideoReader, VideoWriter
-        input_img_list = []
-        vidreader = VideoReader(args.input_path)
-        image = vidreader.get_frame()
-        while image is not None:
-            input_img_list.append(image)
-            image = vidreader.get_frame()
-        audio = vidreader.get_audio()
-        fps = vidreader.get_fps() if args.save_video_fps is None else args.save_video_fps
-        video_name = os.path.basename(args.input_path)[:-4]
-        result_root = f'results/{video_name}_{w}'
-        input_video = True
-        vidreader.close()
-    else: # input img folder
-        if args.input_path.endswith('/'): # solve when path ends with /
-            args.input_path = args.input_path[:-1]
-        # scan all the jpg and png images
-        input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[jpJP][pnPN]*[gG]')))
-        result_root = f'results/{os.path.basename(args.input_path)}_{w}'
-
-    if not args.output_path is None: # set output path
-        result_root = args.output_path
+    if args.test_path.endswith('/'): # solve when path ends with /
+        args.test_path = args.test_path[:-1]
 
-    test_img_num = len(input_img_list)
-    if test_img_num == 0:
-        raise FileNotFoundError('No input image/video is found...\n'
-                                '\tNote that --input_path for video should end with .mp4|.mov|.avi')
+    w = args.w
+    result_root = f'results/{os.path.basename(args.test_path)}_{w}'
 
     # ------------------ set up background upsampler ------------------
     if args.bg_upsampler == 'realesrgan':
@@ -162,27 +109,19 @@ if __name__ == '__main__':
        device=device)
 
     # -------------------- start to processing ---------------------
-    for i, img_path in enumerate(input_img_list):
+    # scan all the jpg and png images
+    for img_path in sorted(glob.glob(os.path.join(args.test_path, '*.[jp][pn]g'))):
         # clean all the intermediate results to process the next image
         face_helper.clean_all()
 
-        if isinstance(img_path, str):
-            img_name = os.path.basename(img_path)
-            basename, ext = os.path.splitext(img_name)
-            print(f'[{i+1}/{test_img_num}] Processing: {img_name}')
-            img = cv2.imread(img_path, cv2.IMREAD_COLOR)
-        else: # for video processing
-            basename = str(i).zfill(6)
-            img_name = f'{video_name}_{basename}' if input_video else basename
-            print(f'[{i+1}/{test_img_num}] Processing: {img_name}')
-            img = img_path
+        img_name = os.path.basename(img_path)
+        print(f'Processing: {img_name}')
+        basename, ext = os.path.splitext(img_name)
+        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
 
         if args.has_aligned:
             # the input faces are already cropped and aligned
             img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
-            face_helper.is_gray = is_gray(img, threshold=10)
-            if face_helper.is_gray:
-                print('Grayscale input: True')
             face_helper.cropped_faces = [img]
         else:
             face_helper.read_image(img)
@@ -211,7 +150,7 @@ if __name__ == '__main__':
                restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
 
            restored_face = restored_face.astype('uint8')
-            face_helper.add_restored_face(restored_face, cropped_face)
+            face_helper.add_restored_face(restored_face)
 
        # paste_back
        if not args.has_aligned:
@@ -239,36 +178,12 @@ if __name__ == '__main__':
                save_face_name = f'{basename}.png'
            else:
                save_face_name = f'{basename}_{idx:02d}.png'
-            if args.suffix is not None:
-                save_face_name = f'{save_face_name[:-4]}_{args.suffix}.png'
            save_restore_path = os.path.join(result_root, 'restored_faces', save_face_name)
            imwrite(restored_face, save_restore_path)
 
        # save restored img
        if not args.has_aligned and restored_img is not None:
-            if args.suffix is not None:
-                basename = f'{basename}_{args.suffix}'
            save_restore_path = os.path.join(result_root, 'final_results', f'{basename}.png')
            imwrite(restored_img, save_restore_path)
 
-    # save enhanced video
-    if input_video:
-        print('Video Saving...')
-        # load images
-        video_frames = []
-        img_list = sorted(glob.glob(os.path.join(result_root, 'final_results', '*.[jp][pn]g')))
-        for img_path in img_list:
-            img = cv2.imread(img_path)
-            video_frames.append(img)
-        # write images to video
-        height, width = video_frames[0].shape[:2]
-        if args.suffix is not None:
-            video_name = f'{video_name}_{args.suffix}.png'
-        save_restore_path = os.path.join(result_root, f'{video_name}.mp4')
-        vidwriter = VideoWriter(save_restore_path, height, width, fps, audio)
-
-        for f in video_frames:
-            vidwriter.write_frame(f)
-        vidwriter.close()
-
-    print(f'\nAll results are saved in {result_root}')
+    print(f'\nAll results are saved in {result_root}')
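One detail of the restored scanning loop: `*.[jp][pn]g` is a shell-style character-class glob, so it matches `.jpg` and `.png` (plus the unlikely `.jng`/`.ppg`) but not `.jpeg`, and `glob` is case-sensitive on Linux, so `.JPG` files are skipped too. A quick check with the standard library:

```python
from fnmatch import fnmatchcase

pattern = '*.[jp][pn]g'
for name in ('face.jpg', 'face.png', 'face.jpeg', 'face.JPG'):
    print(name, fnmatchcase(name, pattern))
# face.jpg True / face.png True / face.jpeg False / face.JPG False
```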
README.md CHANGED
@@ -4,9 +4,9 @@ emoji: 🐼
 colorFrom: blue
 colorTo: green
 sdk: gradio
-sdk_version: 4.37.2
+sdk_version: 3.36.1
 app_file: app.py
 pinned: false
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -16,9 +16,9 @@ from torchvision.transforms.functional import normalize
 from basicsr.utils import imwrite, img2tensor, tensor2img
 from basicsr.utils.download_util import load_file_from_url
 from facelib.utils.face_restoration_helper import FaceRestoreHelper
+from facelib.utils.misc import is_gray
 from basicsr.archs.rrdbnet_arch import RRDBNet
 from basicsr.utils.realesrgan_utils import RealESRGANer
-from facelib.utils.misc import is_gray
 
 from basicsr.utils.registry import ARCH_REGISTRY
 
@@ -57,9 +57,6 @@ torch.hub.download_url_to_file(
 torch.hub.download_url_to_file(
     'https://replicate.com/api/models/sczhou/codeformer/files/7cf19c2c-e0cf-4712-9af8-cf5bdbb8d0ee/012.jpg',
     '05.jpg')
-torch.hub.download_url_to_file(
-    'https://raw.githubusercontent.com/sczhou/CodeFormer/master/inputs/cropped_faces/0729.png',
-    '06.png')
 
 def imread(img_path):
     img = cv2.imread(img_path)
@@ -104,23 +101,20 @@ codeformer_net.eval()
 
 os.makedirs('output', exist_ok=True)
 
-def inference(image, face_align, background_enhance, face_upsample, upscale, codeformer_fidelity):
+def inference(image, background_enhance, face_upsample, upscale, codeformer_fidelity):
     """Run a single prediction on the model"""
     try: # global try
         # take the default setting for the demo
+        has_aligned = False
         only_center_face = False
         draw_box = False
         detection_model = "retinaface_resnet50"
-
         print('Inp:', image, background_enhance, face_upsample, upscale, codeformer_fidelity)
-        face_align = face_align if face_align is not None else True
+
         background_enhance = background_enhance if background_enhance is not None else True
         face_upsample = face_upsample if face_upsample is not None else True
        upscale = upscale if (upscale is not None and upscale > 0) else 2
 
-        has_aligned = not face_align
-        upscale = 1 if has_aligned else upscale
-
        img = cv2.imread(str(image), cv2.IMREAD_COLOR)
        print('\timage size:', img.shape)
 
@@ -166,7 +160,9 @@ def inference(image, background_enhance, face_upsample, upscale, codeformer_fidelity):
        # face restoration for each cropped face
        for idx, cropped_face in enumerate(face_helper.cropped_faces):
            # prepare data
-            cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
+            cropped_face_t = img2tensor(
+                cropped_face / 255.0, bgr2rgb=True, float32=True
+            )
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
 
@@ -180,10 +176,12 @@ def inference(image, background_enhance, face_upsample, upscale, codeformer_fidelity):
                torch.cuda.empty_cache()
            except RuntimeError as error:
                print(f"Failed inference for CodeFormer: {error}")
-                restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
+                restored_face = tensor2img(
+                    cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
+                )
 
            restored_face = restored_face.astype("uint8")
-            face_helper.add_restored_face(restored_face, cropped_face)
+            face_helper.add_restored_face(restored_face)
 
            # paste_back
            if not has_aligned:
@@ -205,8 +203,6 @@ def inference(image, background_enhance, face_upsample, upscale, codeformer_fidelity):
            restored_img = face_helper.paste_faces_to_input_image(
                upsample_img=bg_img, draw_box=draw_box
            )
-        else:
-            restored_img = restored_face
 
        # save restored img
        save_path = f'output/out.png'
@@ -260,12 +256,6 @@ If you have any questions, please feel free to reach me out at <b>shangchenzhou@
 td {
     padding-right: 0px !important;
 }
-
-.gradio-container-4-37-2 .prose table, .gradio-container-4-37-2 .prose tr, .gradio-container-4-37-2 .prose td, .gradio-container-4-37-2 .prose th {
-    border: 0px solid #ffffff;
-    border-bottom: 0px solid #ffffff;
-}
-
 </style>
 
 <table>
@@ -281,28 +271,24 @@ td {
 demo = gr.Interface(
     inference, [
         gr.Image(type="filepath", label="Input"),
-        gr.Checkbox(value=True, label="Pre_Face_Align"),
         gr.Checkbox(value=True, label="Background_Enhance"),
         gr.Checkbox(value=True, label="Face_Upsample"),
         gr.Number(value=2, label="Rescaling_Factor (up to 4)"),
         gr.Slider(0, 1, value=0.5, step=0.01, label='Codeformer_Fidelity (0 for better quality, 1 for better identity)')
     ], [
-        gr.Image(type="numpy", label="Output")
+        gr.Image(type="numpy", label="Output").style(height='auto')
     ],
     title=title,
     description=description,
     article=article,
     examples=[
-        ['01.png', True, True, True, 2, 0.7],
-        ['02.jpg', True, True, True, 2, 0.7],
-        ['03.jpg', True, True, True, 2, 0.7],
-        ['04.jpg', True, True, True, 2, 0.1],
-        ['05.jpg', True, True, True, 2, 0.1],
-        ['06.png', False, True, True, 1, 0.5]
-    ],
-    concurrency_limit=2
-    )
+        ['01.png', True, True, 2, 0.7],
+        ['02.jpg', True, True, 2, 0.7],
+        ['03.jpg', True, True, 2, 0.7],
+        ['04.jpg', True, True, 2, 0.1],
+        ['05.jpg', True, True, 2, 0.1]
+    ])
 
 DEBUG = os.getenv('DEBUG') == '1'
-# demo.launch(debug=DEBUG)
-demo.launch(debug=DEBUG, share=True)
+demo.queue(api_open=False, concurrency_count=2, max_size=10)
+demo.launch(debug=DEBUG)