"""Modified from https://github.com/JaidedAI/EasyOCR/blob/803b907/easyocr/detection.py.

1. Disable DataParallel.
"""
import torch
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from PIL import Image
from collections import OrderedDict

import cv2
import numpy as np

from .craft_utils import getDetBoxes, adjustResultCoordinates
from .imgproc import resize_aspect_ratio, normalizeMeanVariance
from .craft import CRAFT


def copyStateDict(state_dict):
    """Strip the ``module.`` prefix that ``DataParallel`` prepends to state-dict keys.

    Args:
        state_dict: A model state dict, possibly saved from a ``DataParallel``
            wrapper (keys then start with ``"module."``).

    Returns:
        OrderedDict: the same tensors keyed without the ``module.`` prefix
        (keys are returned unchanged when no prefix is present).
    """
    if list(state_dict.keys())[0].startswith("module"):
        start_idx = 1
    else:
        start_idx = 0
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = ".".join(k.split(".")[start_idx:])
        new_state_dict[name] = v
    return new_state_dict


def test_net(canvas_size, mag_ratio, net, image, text_threshold, link_threshold,
             low_text, poly, device, estimate_num_chars=False):
    """Run the CRAFT detector on one image or a batch and post-process the output.

    Args:
        canvas_size: Maximum side length used when resizing inputs.
        mag_ratio: Magnification ratio applied during the aspect-ratio resize.
        net: The CRAFT network (already on ``device`` and in eval mode).
        image: A single HxWxC numpy image, or a 4-D numpy batch of them.
        text_threshold: Score threshold for text regions.
        link_threshold: Score threshold for links between characters.
        low_text: Low-bound text score used when expanding detected regions.
        poly: If True, return polygons where available instead of boxes.
        device: Torch device string/object for the forward pass.
        estimate_num_chars: If True, pair each box with its estimated
            character count from ``getDetBoxes``'s mapper.

    Returns:
        tuple[list, list]: per-image lists of boxes and polygons.
    """
    if isinstance(image, np.ndarray) and len(image.shape) == 4:
        # image is a batch of numpy arrays
        image_arrs = image
    else:
        # image is a single numpy array
        image_arrs = [image]

    # Resize each input to fit the canvas while keeping aspect ratio.
    img_resized_list = []
    for img in image_arrs:
        img_resized, target_ratio, size_heatmap = resize_aspect_ratio(
            img, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio)
        img_resized_list.append(img_resized)
    # NOTE(review): the ratio of the *last* image is applied to every image in
    # the batch — correct only when all batch members share one size. Verify
    # against callers before batching mixed-size images.
    ratio_h = ratio_w = 1 / target_ratio

    # Preprocessing: normalize and move channels first, then stack into a batch.
    x = [np.transpose(normalizeMeanVariance(n_img), (2, 0, 1))
         for n_img in img_resized_list]
    x = torch.from_numpy(np.array(x))
    x = x.to(device)

    # Forward pass (inference only).
    with torch.no_grad():
        y, feature = net(x)

    boxes_list, polys_list = [], []
    for out in y:
        # Split the two-channel output into text-score and link-score maps.
        # (.numpy() is safe here: tensors produced under no_grad carry no graph.)
        score_text = out[:, :, 0].cpu().numpy()
        score_link = out[:, :, 1].cpu().numpy()

        # Post-processing: raw score maps -> boxes/polygons.
        boxes, polys, mapper = getDetBoxes(
            score_text, score_link, text_threshold, link_threshold, low_text,
            poly, estimate_num_chars)

        # Map coordinates back to the original image scale.
        boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = adjustResultCoordinates(polys, ratio_w, ratio_h)
        if estimate_num_chars:
            boxes = list(boxes)
            polys = list(polys)
        for k in range(len(polys)):
            if estimate_num_chars:
                # Attach the estimated character count to each box.
                boxes[k] = (boxes[k], mapper[k])
            if polys[k] is None:
                # Fall back to the rectangular box when no polygon was found.
                polys[k] = boxes[k]
        boxes_list.append(boxes)
        polys_list.append(polys)

    return boxes_list, polys_list


def get_detector(trained_model, device='cpu', quantize=True, cudnn_benchmark=False):
    """Load a trained CRAFT detector and prepare it for inference.

    Args:
        trained_model: Path to the saved state dict.
        device: Target device (``'cpu'`` or a CUDA device).
        quantize: On CPU, attempt dynamic int8 quantization (best-effort).
        cudnn_benchmark: Value assigned to ``cudnn.benchmark``.

    Returns:
        CRAFT: the network on ``device``, in eval mode.
    """
    net = CRAFT()

    if device == 'cpu':
        net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device)))
        if quantize:
            # Best-effort: dynamic quantization is unsupported on some
            # builds/platforms, so failures are deliberately ignored.
            try:
                torch.quantization.quantize_dynamic(net, dtype=torch.qint8, inplace=True)
            except Exception:
                pass
    else:
        net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device)))
        # net = torch.nn.DataParallel(net).to(device)
        net = net.to(device)
        cudnn.benchmark = cudnn_benchmark

    net.eval()
    return net


def get_textbox(detector, image, canvas_size, mag_ratio, text_threshold,
                link_threshold, low_text, poly, device, optimal_num_chars=None,
                **kwargs):
    """Detect text regions and return them as flattened coordinate arrays.

    Args:
        detector: A CRAFT network as returned by :func:`get_detector`.
        image: A single numpy image or a 4-D numpy batch.
        canvas_size, mag_ratio, text_threshold, link_threshold, low_text, poly,
            device: Forwarded to :func:`test_net`.
        optimal_num_chars: If given, results are sorted by how close each
            region's estimated character count is to this value.
        **kwargs: Ignored (accepted for interface compatibility).

    Returns:
        list[list[np.ndarray]]: per image, a list of int32 arrays of shape
        (2*npoints,) — flattened x,y polygon coordinates.
    """
    result = []
    estimate_num_chars = optimal_num_chars is not None
    bboxes_list, polys_list = test_net(
        canvas_size, mag_ratio, detector, image, text_threshold, link_threshold,
        low_text, poly, device, estimate_num_chars)

    if estimate_num_chars:
        # Each entry is (polygon, char_count); keep polygons ordered by
        # closeness of char_count to the requested optimum.
        polys_list = [[p for p, _ in sorted(polys,
                                            key=lambda x: abs(optimal_num_chars - x[1]))]
                      for polys in polys_list]

    for polys in polys_list:
        single_img_result = []
        for i, box in enumerate(polys):
            # Flatten the polygon to a 1-D int32 array of x,y pairs.
            # (local renamed from `poly` to avoid shadowing the parameter)
            flat_poly = np.array(box).astype(np.int32).reshape((-1))
            single_img_result.append(flat_poly)
        result.append(single_img_result)

    return result