# CatVTON / model /SCHP /simple_extractor_multi.py
# ZhengChong
# chore: Add SCHP model and detectron2 dependencies
# 6a6227f
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@Author : Peike Li
@Contact : peike.li@yahoo.com
@File : simple_extractor.py
@Time : 8/30/19 8:59 PM
@Desc : Simple Extractor
@License : This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import threading
from queue import Queue
from tqdm import tqdm
import os
import torch
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm
import cv2
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import networks
from utils.transforms import transform_logits
from datasets.simple_extractor_dataset import SimpleFolderDataset
from utils.transforms import get_affine_transform
# Per-dataset configuration for the SCHP checkpoints.
# For each dataset: network input size as [height, width], number of output
# classes, and the class names indexed by label id.
dataset_settings = {
    'lip': dict(
        input_size=[473, 473],
        num_classes=20,
        label=[
            'Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', 'Upper-clothes', 'Dress', 'Coat',
            'Socks', 'Pants', 'Jumpsuits', 'Scarf', 'Skirt', 'Face', 'Left-arm', 'Right-arm',
            'Left-leg', 'Right-leg', 'Left-shoe', 'Right-shoe',
        ],
    ),
    'atr': dict(
        input_size=[512, 512],
        num_classes=18,
        label=[
            'Background', 'Hat', 'Hair', 'Sunglasses', 'Upper-clothes', 'Skirt', 'Pants', 'Dress', 'Belt',
            'Left-shoe', 'Right-shoe', 'Face', 'Left-leg', 'Right-leg', 'Left-arm', 'Right-arm', 'Bag', 'Scarf',
        ],
    ),
    'pascal': dict(
        input_size=[512, 512],
        num_classes=7,
        label=['Background', 'Head', 'Torso', 'Upper Arms', 'Lower Arms', 'Upper Legs', 'Lower Legs'],
    ),
}
def _box2cs(box, input_size):
    """Convert a box given as ``[x, y, w, h, ...]`` into a (center, scale) pair.

    Only the first four entries of ``box`` are used; the actual conversion is
    done by ``_xywh2cs`` with the same ``input_size``.
    """
    left, top, width, height = box[:4]
    return _xywh2cs(left, top, width, height, input_size)
def _xywh2cs(x, y, w, h,input_size):
aspect_ratio = input_size[1] * 1.0 / input_size[0]
center = np.zeros((2), dtype=np.float32)
center[0] = x + w * 0.5
center[1] = y + h * 0.5
if w > aspect_ratio * h:
h = w * 1.0 / aspect_ratio
elif w < aspect_ratio * h:
w = h * aspect_ratio
scale = np.array([w, h], dtype=np.float32)
return center, scale
def get_arguments(argv=None):
    """Parse the command-line options of the SCHP extractor.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, i.e. the real command line.

    Returns:
        argparse.Namespace with the attributes ``dataset``, ``model_restore``,
        ``gpu``, ``input_dir``, ``output_dir`` and ``logits``.

    NOTE(review): ``--logits`` is parsed but nothing in this file reads
    ``args.logits``; logits are never saved.
    """
    parser = argparse.ArgumentParser(description="Self Correction for Human Parsing")
    parser.add_argument("--dataset", type=str, default='atr', choices=['lip', 'atr', 'pascal'])
    parser.add_argument("--model-restore", type=str,
                        default='/data1/chongzheng/zhangwq/Self-Correction-Human-Parsing-master/exp-schp-201908301523-atr.pth',
                        help="restore pretrained model parameters.")
    parser.add_argument("--gpu", type=str, default='0', help="choose gpu device.")
    parser.add_argument("--input-dir", type=str, default='/home/chongzheng_p23/data/Datasets/UniFashion/YOOX/YOOX-Images', help="path of input image folder.")
    parser.add_argument("--output-dir", type=str, default='/home/chongzheng_p23/data/Datasets/UniFashion/YOOX/YOOX-SCHP', help="path of output image folder.")
    parser.add_argument("--logits", action='store_true', default=False, help="whether to save the logits.")
    return parser.parse_args(argv)
def get_palette(num_cls):
    """Build the flat RGB color map used to visualise a label mask.

    Label ``k`` gets a colour whose R/G/B bits come from interleaving the
    bits of ``k`` (bit 0 -> R, bit 1 -> G, bit 2 -> B), filled from the most
    significant colour bit downward. Label 0 is black.

    Args:
        num_cls: Number of classes.

    Returns:
        A list of ``num_cls * 3`` ints, ``[r0, g0, b0, r1, g1, b1, ...]``,
        suitable for ``PIL.Image.putpalette``.
    """
    palette = []
    for label_id in range(num_cls):
        rgb = [0, 0, 0]
        bit = 7
        code = label_id
        while code:
            for channel in range(3):
                rgb[channel] |= ((code >> channel) & 1) << bit
            bit -= 1
            code >>= 3
        palette.extend(rgb)
    return palette
def schp_process(image_queue, model, progress_bar, input_size, transform, palette=None):
    """Worker loop: run SCHP human parsing on image paths taken from a queue.

    Intended as the target of a ``threading.Thread``. For every path taken
    from ``image_queue`` the parsing map is written as a palettised PNG to
    the path obtained by replacing ``YOOX-Images`` with ``YOOX-SCHP`` and
    swapping the file extension for ``.png``. Images whose output already
    exists, or that OpenCV cannot decode, are skipped. A ``None`` item is the
    stop signal. ``task_done()`` is called once per item after it has been
    fully handled, and the progress bar advances once per image.

    Args:
        image_queue: ``queue.Queue`` of image path strings, ended by ``None``.
        model: SCHP network, already on the GPU and in eval mode.
        progress_bar: shared tqdm-like object (needs ``update`` and ``write``).
        input_size: network input size indexed as ``[height, width]``.
        transform: callable mapping the warped HxWx3 image to a CHW tensor.
        palette: flat RGB palette for the output PNG. When ``None`` it is
            built once with ``get_palette`` from the number of output channels.
    """
    # Built once per worker instead of once per image.
    upsample = torch.nn.Upsample(size=input_size, mode='bilinear', align_corners=True)
    while True:
        img_path = image_queue.get()
        if img_path is None:  # stop signal from the producer
            image_queue.task_done()
            break
        try:
            # splitext only touches the extension (a plain ".jpg" replace could
            # also rewrite directory names containing ".jpg").
            save_path = os.path.splitext(img_path.replace("YOOX-Images", "YOOX-SCHP"))[0] + ".png"
            if os.path.exists(save_path):
                continue  # already processed
            img = cv2.imread(img_path, cv2.IMREAD_COLOR)
            if img is None:
                continue  # unreadable / not an image
            h, w, _ = img.shape

            # Use the whole image as the person box.
            person_center, s = _box2cs([0, 0, w - 1, h - 1], input_size)
            trans = get_affine_transform(person_center, s, 0, input_size)
            warped = cv2.warpAffine(
                img,
                trans,
                (int(input_size[1]), int(input_size[0])),
                flags=cv2.INTER_LINEAR,
                borderMode=cv2.BORDER_CONSTANT,
                borderValue=(0, 0, 0))
            # The model takes a batch: add the leading batch dimension.
            batch = transform(warped).unsqueeze(0).cuda()

            # Grad mode is thread-local, so a no_grad() active in the thread
            # that started this worker does not apply here.
            with torch.no_grad():
                output = model(batch)
                # presumably output[0][-1] is the final parsing logits as
                # (N, C, h', w'); [0] selects the single sample -> (C, h', w').
                logits = upsample(output[0][-1][0].unsqueeze(0)).squeeze(0)
            logits = logits.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC

            # Map logits back to the original image frame using the full
            # (2,) centre, not just its x coordinate.
            logits_result = transform_logits(logits, person_center, s, w, h, input_size=input_size)
            parsing_result = np.argmax(logits_result, axis=2)

            if palette is None:
                palette = get_palette(logits.shape[2])
            save_dir = os.path.dirname(save_path)
            if save_dir:
                # exist_ok avoids a race between worker threads.
                os.makedirs(save_dir, exist_ok=True)
            output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
            output_img.putpalette(palette)
            output_img.save(save_path)
        except Exception as exc:  # keep the worker alive; report and move on
            progress_bar.write(f"SCHP failed on {img_path}: {exc!r}")
        finally:
            image_queue.task_done()
            progress_bar.update(1)
def main():
    """Run SCHP human parsing over every ``.jpg`` file under ``--input-dir``.

    Loads the checkpoint selected by ``--model-restore`` into a ResNet-101
    SCHP network and queues every JPEG path found below ``--input-dir``.
    The queue is then drained by a small pool of ``schp_process`` worker
    threads that share the model. The function returns once all workers
    have stopped.

    Raises:
        ValueError: if ``--gpu`` lists anything other than exactly one id.
    """
    args = get_arguments()
    # Parse GPU ids only when a GPU was requested; int('None') would crash.
    if args.gpu != 'None':
        gpus = [int(i) for i in args.gpu.split(',')]
        if len(gpus) != 1:
            raise ValueError(f"exactly one GPU id is supported, got --gpu={args.gpu!r}")
        # Only takes effect if CUDA has not been initialised yet.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    settings = dataset_settings[args.dataset]
    num_classes = settings['num_classes']
    input_size = settings['input_size']
    label = settings['label']
    print("Evaluating total class number {} with {}".format(num_classes, label))

    model = networks.init_model('resnet101', num_classes=num_classes, pretrained=None)
    state_dict = torch.load(args.model_restore)['state_dict']
    # Strip the "module." prefix (typically added by nn.DataParallel) only
    # where it is actually present, instead of blindly dropping 7 characters.
    prefix = 'module.'
    model.load_state_dict({
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    })
    model.cuda()
    model.eval()

    transform = transforms.Compose([
        transforms.ToTensor(),
        # ImageNet statistics listed in BGR order, matching cv2.imread's channel order.
        transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
    ])

    image_queue = Queue()
    for root, _dirs, files in os.walk(args.input_dir):
        for file in files:
            if file.endswith(".jpg"):
                image_queue.put(os.path.join(root, file))
    os.makedirs(args.output_dir, exist_ok=True)

    # Count the images before the stop signals are enqueued.
    progress_bar = tqdm(total=image_queue.qsize(), desc="Processing SCHP")

    num_workers = 2
    # torch.no_grad() is thread-local, so it is not wrapped around thread
    # creation here; the inference threads have to disable autograd themselves.
    workers = [
        threading.Thread(target=schp_process,
                         args=(image_queue, model, progress_bar, input_size, transform))
        for _ in range(num_workers)
    ]
    # One None per worker, queued after every image (FIFO), so each worker
    # drains real work first and then exits instead of blocking forever.
    for _ in workers:
        image_queue.put(None)
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
    progress_bar.close()
# Script entry point: run the extractor only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    main()