diff --git a/easyocr/easyocr.py b/easyocr/easyocr.py index 681c05b3c..0fab10842 100644 --- a/easyocr/easyocr.py +++ b/easyocr/easyocr.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -from .recognition import get_recognizer, get_text +from .recognition import get_recognizer, get_text, get_recognizer_attn from .utils import group_text_box, get_image_list, calculate_md5, get_paragraph,\ download_and_unzip, printProgressBar, diff, reformat_input,\ make_rotated_img_list, set_result_with_confidence,\ @@ -228,9 +228,15 @@ def __init__(self, lang_list, gpu=True, model_storage_directory=None, } else: network_params = recog_config['network_params'] - self.recognizer, self.converter = get_recognizer(recog_network, network_params,\ + if "Prediction" in recog_config: + if recog_config["Prediction"] == "Attn": + self.recognizer, self.converter = get_recognizer_attn(recog_network, network_params,\ self.character, separator_list,\ dict_list, model_path, device = self.device, quantize=quantize) + else: + self.recognizer, self.converter = get_recognizer(recog_network, network_params,\ + self.character, separator_list,\ + dict_list, model_path, device = self.device, quantize=quantize) def getDetectorPath(self, detect_network): if detect_network in self.support_detection_network: diff --git a/easyocr/recognition.py b/easyocr/recognition.py index 147370f7d..f10850c28 100644 --- a/easyocr/recognition.py +++ b/easyocr/recognition.py @@ -8,6 +8,7 @@ from collections import OrderedDict import importlib from .utils import CTCLabelConverter +from ..trainer.utils import AttnLabelConverter import math def custom_mean(x): @@ -121,21 +122,26 @@ def recognizer_predict(model, converter, test_loader, batch_max_length,\ preds_prob = preds_prob/np.expand_dims(pred_norm, axis=-1) preds_prob = torch.from_numpy(preds_prob).float().to(device) - if decoder == 'greedy': - # Select max probabilty (greedy decoding) then decode index to character + if isinstance(converter, AttnLabelConverter): _, preds_index = 
preds_prob.max(2) - preds_index = preds_index.view(-1) - preds_str = converter.decode_greedy(preds_index.data.cpu().detach().numpy(), preds_size.data) - elif decoder == 'beamsearch': - k = preds_prob.cpu().detach().numpy() - preds_str = converter.decode_beamsearch(k, beamWidth=beamWidth) - elif decoder == 'wordbeamsearch': - k = preds_prob.cpu().detach().numpy() - preds_str = converter.decode_wordbeamsearch(k, beamWidth=beamWidth) + preds_str = converter.decode(preds_index, preds_size.data) + else: + if decoder == 'greedy': + # Select max probability (greedy decoding) then decode index to character + _, preds_index = preds_prob.max(2) + preds_index = preds_index.view(-1) + preds_str = converter.decode_greedy(preds_index.data.cpu().detach().numpy(), preds_size.data) + elif decoder == 'beamsearch': + k = preds_prob.cpu().detach().numpy() + preds_str = converter.decode_beamsearch(k, beamWidth=beamWidth) + elif decoder == 'wordbeamsearch': + k = preds_prob.cpu().detach().numpy() + preds_str = converter.decode_wordbeamsearch(k, beamWidth=beamWidth) preds_prob = preds_prob.cpu().detach().numpy() values = preds_prob.max(axis=2) indices = preds_prob.argmax(axis=2) + preds_max_prob = [] for v,i in zip(values, indices): max_probs = v[i!=0] @@ -145,7 +151,15 @@ def recognizer_predict(model, converter, test_loader, batch_max_length,\ preds_max_prob.append(np.array([0])) for pred, pred_max_prob in zip(preds_str, preds_max_prob): - confidence_score = custom_mean(pred_max_prob) + + + if isinstance(converter, AttnLabelConverter): + pred_EOS = pred.find('[s]') + pred = pred[:pred_EOS] + confidence_score = custom_mean(pred_max_prob[:pred_EOS]) + else: + confidence_score = custom_mean(pred_max_prob) + result.append([pred, confidence_score]) return result @@ -153,7 +167,6 @@ def recognizer_predict(model, converter, test_loader, batch_max_length,\ def get_recognizer(recog_network, network_params, character,\ separator_list, dict_list, model_path,\ device = 'cpu', quantize = True): -
converter = CTCLabelConverter(character, separator_list, dict_list) num_class = len(converter.character) @@ -183,6 +196,38 @@ def get_recognizer(recog_network, network_params, character,\ return model, converter +def get_recognizer_attn(recog_network, network_params, character,\ + separator_list, dict_list, model_path,\ + device = 'cpu', quantize = True): + converter = AttnLabelConverter(character, device) + num_class = len(converter.character) + + if recog_network == 'generation1': + model_pkg = importlib.import_module("easyocr.model.model") + elif recog_network == 'generation2': + model_pkg = importlib.import_module("easyocr.model.vgg_model") + else: + model_pkg = importlib.import_module(recog_network) + model = model_pkg.Model(num_class=num_class, **network_params) + + if device == 'cpu': + state_dict = torch.load(model_path, map_location=device) + new_state_dict = OrderedDict() + for key, value in state_dict.items(): + new_key = key[7:] + new_state_dict[new_key] = value + model.load_state_dict(new_state_dict) + if quantize: + try: + torch.quantization.quantize_dynamic(model, dtype=torch.qint8, inplace=True) + except: + pass + else: + model = torch.nn.DataParallel(model).to(device) + model.load_state_dict(torch.load(model_path, map_location=device)) + + return model, converter + def get_text(character, imgH, imgW, recognizer, converter, image_list,\ ignore_char = '',decoder = 'greedy', beamWidth =5, batch_size=1, contrast_ths=0.1,\ adjust_contrast=0.5, filter_ths = 0.003, workers = 1, device = 'cpu'): diff --git a/recognize_function.py b/recognize_function.py new file mode 100644 index 000000000..819bb1a6f --- /dev/null +++ b/recognize_function.py @@ -0,0 +1,41 @@ +import easyocr +import os +import cv2 +import numpy as np +from easyocr.easyocr import Reader +from easyocr.trainer.utils import AttnLabelConverter + +def recognize_text_from_images(image_pieces, models_directory, recog_network='best_accuracy', gpu=False): + """ + Recognizes text from a list of image 
pieces using EasyOCR. + + Parameters: + - image_pieces (list): List of image pieces as PIL Image objects. + - models_directory (str): Path to the models directory. + - recog_network (str): Recognition network to use (default is 'best_accuracy'). + - gpu (bool): Whether to use GPU for OCR (default is False). + + Returns: + - List of recognized texts. + """ + model_storage_directory = os.path.join(models_directory, "model") + user_network_directory = os.path.join(models_directory, "user_network") + + # Initialize EasyOCR reader + reader = Reader(['ru'], recog_network=recog_network, gpu=gpu, + model_storage_directory=model_storage_directory, + user_network_directory=user_network_directory) + + + recognized_texts = [] + for image_piece in image_pieces: + # Convert PIL Image to OpenCV format + image_cv = cv2.cvtColor(np.array(image_piece), cv2.COLOR_RGB2BGR) + # Perform text recognition + if isinstance(reader.converter, AttnLabelConverter): + result = reader.readtext(image_cv, detail=0, decoder="beamsearch") + else: + result = reader.readtext(image_cv, detail=0) + recognized_texts.append(" ".join(result)) + + return recognized_texts \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 000000000..5a00586f0 --- /dev/null +++ b/test.py @@ -0,0 +1,59 @@ +import easyocr +import matplotlib.pyplot as plt +import cv2 +from PIL import Image, ImageDraw, ImageFont +import numpy as np +import os + + +# Load the image +image_path = "test2 (1).jpg" +image = cv2.imread(image_path) +image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + +# Convert to PIL Image for better font rendering +image_pil = Image.fromarray(image_rgb) +draw = ImageDraw.Draw(image_pil) + +# Load a font (make sure the path to the font file is correct) +font_path = "arial.ttf" # Replace with the path to your .ttf font file +font = ImageFont.truetype(font_path, 60) # Increased font size to 60 + +models_directory = "C:/Users/user/Desktop/modelsocr" +model_storage_directory =
os.path.join(models_directory, "model") +user_network_directory = os.path.join(models_directory, "user_network") +# Initialize EasyOCR reader +reader = easyocr.Reader(['ru'], recog_network='best_accuracy', gpu = False, + model_storage_directory=model_storage_directory, + user_network_directory=user_network_directory) + +# Perform text detection and recognition +results = reader.readtext(image_path) + +# Print the results +print(results) + +# Draw bounding boxes and the recognized text on the image +for (bbox, text, prob) in results: + # NOTE: no text replacement is performed here; the recognized text is drawn as-is + + # Unpack the bounding box coordinates + (top_left, top_right, bottom_right, bottom_left) = bbox + top_left = tuple([int(val) for val in top_left]) + bottom_right = tuple([int(val) for val in bottom_right]) + + # Draw the bounding box on the image + draw.rectangle([top_left, bottom_right], outline="green", width=2) + + # Draw the recognized text above the bounding box + draw.text((top_left[0], top_left[1] - 40), text, font=font, fill="green") # Adjusted y-coordinate for larger font + +# Convert back to OpenCV format +image_rgb = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR) + +# Display the image with bounding boxes and recognized text +plt.figure(figsize=(60, 60)) # Adjusted the figure size for better visibility +plt.imshow(cv2.cvtColor(image_rgb, cv2.COLOR_BGR2RGB)) +plt.axis('off') +plt.savefig("text.png") +plt.show() \ No newline at end of file diff --git a/trainer/modules/prediction.py b/trainer/modules/prediction.py index c8a40af0e..2b11f24e6 100644 --- a/trainer/modules/prediction.py +++ b/trainer/modules/prediction.py @@ -1,22 +1,23 @@ import torch import torch.nn as nn import torch.nn.functional as F -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') class Attention(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): + def __init__(self, input_size, hidden_size, num_classes, device):
super(Attention, self).__init__() self.attention_cell = AttentionCell(input_size, hidden_size, num_classes) self.hidden_size = hidden_size self.num_classes = num_classes self.generator = nn.Linear(hidden_size, num_classes) + self.device = device def _char_to_onehot(self, input_char, onehot_dim=38): input_char = input_char.unsqueeze(1) batch_size = input_char.size(0) - one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(device) + one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(self.device) one_hot = one_hot.scatter_(1, input_char, 1) return one_hot @@ -30,9 +31,9 @@ def forward(self, batch_H, text, is_train=True, batch_max_length=25): batch_size = batch_H.size(0) num_steps = batch_max_length + 1 # +1 for [s] at end of sentence. - output_hiddens = torch.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0).to(device) - hidden = (torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device), - torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device)) + output_hiddens = torch.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0).to(self.device) + hidden = (torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(self.device), + torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(self.device)) if is_train: for i in range(num_steps): @@ -44,8 +45,8 @@ def forward(self, batch_H, text, is_train=True, batch_max_length=25): probs = self.generator(output_hiddens) else: - targets = torch.LongTensor(batch_size).fill_(0).to(device) # [GO] token - probs = torch.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0).to(device) + targets = torch.LongTensor(batch_size).fill_(0).to(self.device) # [GO] token + probs = torch.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0).to(self.device) for i in range(num_steps): char_onehots = self._char_to_onehot(targets, onehot_dim=self.num_classes) diff --git a/trainer/utils.py b/trainer/utils.py index 4cf7dac8c..75b021329 100644 --- a/trainer/utils.py +++ 
b/trainer/utils.py @@ -1,7 +1,7 @@ import torch import pickle import numpy as np -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + class AttrDict(dict): def __init__(self, *args, **kwargs): @@ -298,12 +298,13 @@ def decode_wordbeamsearch(self, mat, beamWidth=5): class AttnLabelConverter(object): """ Convert between text-label and text-index """ - def __init__(self, character): + def __init__(self, character, device): # character (str): set of the possible characters. # [GO] for the start token of the attention decoder. [s] for end-of-sentence token. list_token = ['[GO]', '[s]'] # ['[s]','[UNK]','[PAD]','[GO]'] list_character = list(character) self.character = list_token + list_character + self.device = device self.dict = {} for i, char in enumerate(self.character): @@ -331,7 +332,7 @@ def encode(self, text, batch_max_length=25): text.append('[s]') text = [self.dict[char] for char in text] batch_text[i][1:1 + len(text)] = torch.LongTensor(text) # batch_text[:, 0] = [GO] token - return (batch_text.to(device), torch.IntTensor(length).to(device)) + return (batch_text.to(self.device), torch.IntTensor(length).to(self.device)) def decode(self, text_index, length): """ convert text-index into text-label. """