Fix submodule #1261

Open · wants to merge 4 commits into base: master
10 changes: 8 additions & 2 deletions easyocr/easyocr.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

-from .recognition import get_recognizer, get_text
+from .recognition import get_recognizer, get_text, get_recognizer_attn
from .utils import group_text_box, get_image_list, calculate_md5, get_paragraph,\
download_and_unzip, printProgressBar, diff, reformat_input,\
make_rotated_img_list, set_result_with_confidence,\
@@ -228,9 +228,15 @@ def __init__(self, lang_list, gpu=True, model_storage_directory=None,
}
else:
network_params = recog_config['network_params']
-self.recognizer, self.converter = get_recognizer(recog_network, network_params,\
+if recog_config.get("Prediction") == "Attn":
+    self.recognizer, self.converter = get_recognizer_attn(recog_network, network_params,\
+                                                           self.character, separator_list,\
+                                                           dict_list, model_path, device = self.device, quantize=quantize)
+else:
+    self.recognizer, self.converter = get_recognizer(recog_network, network_params,\
+                                                      self.character, separator_list,\
+                                                      dict_list, model_path, device = self.device, quantize=quantize)

def getDetectorPath(self, detect_network):
if detect_network in self.support_detection_network:
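The branch above keys off a new "Prediction" entry in the custom model's recognition config. A minimal sketch of the dict it inspects, assuming the usual custom-network layout; the network_params values are illustrative only and not part of this PR:

# Assumed shape of the recog_config dict inspected above; only the
# "Prediction" key is introduced by this change, the rest follows the
# existing custom-network config layout.
recog_config = {
    'network_params': {
        'input_channel': 1,
        'output_channel': 256,
        'hidden_size': 256,
    },
    'Prediction': 'Attn',   # any other value, or omitting the key, keeps the CTC recognizer
}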
69 changes: 57 additions & 12 deletions easyocr/recognition.py
@@ -8,6 +8,7 @@
from collections import OrderedDict
import importlib
from .utils import CTCLabelConverter
from ..trainer.utils import AttnLabelConverter
import math

def custom_mean(x):
@@ -121,21 +122,26 @@ def recognizer_predict(model, converter, test_loader, batch_max_length,\
preds_prob = preds_prob/np.expand_dims(pred_norm, axis=-1)
preds_prob = torch.from_numpy(preds_prob).float().to(device)

-if decoder == 'greedy':
-    # Select max probabilty (greedy decoding) then decode index to character
-    _, preds_index = preds_prob.max(2)
-    preds_index = preds_index.view(-1)
-    preds_str = converter.decode_greedy(preds_index.data.cpu().detach().numpy(), preds_size.data)
-elif decoder == 'beamsearch':
-    k = preds_prob.cpu().detach().numpy()
-    preds_str = converter.decode_beamsearch(k, beamWidth=beamWidth)
-elif decoder == 'wordbeamsearch':
-    k = preds_prob.cpu().detach().numpy()
-    preds_str = converter.decode_wordbeamsearch(k, beamWidth=beamWidth)
+if isinstance(converter, AttnLabelConverter):
+    _, preds_index = preds_prob.max(2)
+    preds_str = converter.decode(preds_index, preds_size.data)
+else:
+    if decoder == 'greedy':
+        # Select max probabilty (greedy decoding) then decode index to character
+        _, preds_index = preds_prob.max(2)
+        preds_index = preds_index.view(-1)
+        preds_str = converter.decode_greedy(preds_index.data.cpu().detach().numpy(), preds_size.data)
+    elif decoder == 'beamsearch':
+        k = preds_prob.cpu().detach().numpy()
+        preds_str = converter.decode_beamsearch(k, beamWidth=beamWidth)
+    elif decoder == 'wordbeamsearch':
+        k = preds_prob.cpu().detach().numpy()
+        preds_str = converter.decode_wordbeamsearch(k, beamWidth=beamWidth)

preds_prob = preds_prob.cpu().detach().numpy()
values = preds_prob.max(axis=2)
indices = preds_prob.argmax(axis=2)

preds_max_prob = []
for v,i in zip(values, indices):
max_probs = v[i!=0]
@@ -145,15 +151,22 @@ def recognizer_predict(model, converter, test_loader, batch_max_length,\
preds_max_prob.append(np.array([0]))

for pred, pred_max_prob in zip(preds_str, preds_max_prob):
-    confidence_score = custom_mean(pred_max_prob)
+    if isinstance(converter, AttnLabelConverter):
+        pred_EOS = pred.find('[s]')
+        pred = pred[:pred_EOS]
+        confidence_score = custom_mean(pred_max_prob[:pred_EOS])
+    else:
+        confidence_score = custom_mean(pred_max_prob)

result.append([pred, confidence_score])

return result
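With an attention converter, the prediction string carries an explicit end-of-sentence token, so both the text and its confidence are trimmed at '[s]' in the loop above. A rough illustration of that trimming, with made-up values and a plain mean standing in for the custom_mean helper:

import numpy as np

pred = 'привет[s][GO][GO]'                          # raw attention output for one crop (made up)
pred_max_prob = np.array([0.9, 0.8, 0.95, 0.9, 0.85, 0.9, 0.99, 0.1, 0.1])

pred_EOS = pred.find('[s]')                         # position of the end token (6 here)
pred = pred[:pred_EOS]                              # 'привет'
confidence_score = pred_max_prob[:pred_EOS].mean()  # score only the characters kept before '[s]'
print(pred, confidence_score)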

def get_recognizer(recog_network, network_params, character,\
separator_list, dict_list, model_path,\
device = 'cpu', quantize = True):

converter = CTCLabelConverter(character, separator_list, dict_list)
num_class = len(converter.character)

@@ -183,6 +196,38 @@ def get_recognizer(recog_network, network_params, character,\

return model, converter

def get_recognizer_attn(recog_network, network_params, character,\
                        separator_list, dict_list, model_path,\
                        device = 'cpu', quantize = True):
    converter = AttnLabelConverter(character, device)
    num_class = len(converter.character)

    if recog_network == 'generation1':
        model_pkg = importlib.import_module("easyocr.model.model")
    elif recog_network == 'generation2':
        model_pkg = importlib.import_module("easyocr.model.vgg_model")
    else:
        model_pkg = importlib.import_module(recog_network)
    model = model_pkg.Model(num_class=num_class, **network_params)

    if device == 'cpu':
        state_dict = torch.load(model_path, map_location=device)
        new_state_dict = OrderedDict()
        for key, value in state_dict.items():
            new_key = key[7:]
            new_state_dict[new_key] = value
        model.load_state_dict(new_state_dict)
        if quantize:
            try:
                torch.quantization.quantize_dynamic(model, dtype=torch.qint8, inplace=True)
            except:
                pass
    else:
        model = torch.nn.DataParallel(model).to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))

    return model, converter

def get_text(character, imgH, imgW, recognizer, converter, image_list,\
ignore_char = '',decoder = 'greedy', beamWidth =5, batch_size=1, contrast_ths=0.1,\
adjust_contrast=0.5, filter_ths = 0.003, workers = 1, device = 'cpu'):
41 changes: 41 additions & 0 deletions recognize_function.py
@@ -0,0 +1,41 @@
import easyocr
import os
import cv2
import numpy as np
from easyocr.easyocr import Reader
from easyocr.trainer.utils import AttnLabelConverter

def recognize_text_from_images(image_pieces, models_directory, recog_network='best_accuracy', gpu=False):
    """
    Recognizes text from a list of image pieces using EasyOCR.

    Parameters:
    - image_pieces (list): List of image pieces as PIL Image objects.
    - models_directory (str): Path to the models directory.
    - recog_network (str): Recognition network to use (default is 'best_accuracy').
    - gpu (bool): Whether to use GPU for OCR (default is False).

    Returns:
    - List of recognized texts.
    """
    model_storage_directory = os.path.join(models_directory, "model")
    user_network_directory = os.path.join(models_directory, "user_network")

    # Initialize EasyOCR reader
    reader = Reader(['ru'], recog_network=recog_network, gpu=gpu,
                    model_storage_directory=model_storage_directory,
                    user_network_directory=user_network_directory)

    recognized_texts = []
    for image_piece in image_pieces:
        # Convert PIL Image to OpenCV format
        image_cv = cv2.cvtColor(np.array(image_piece), cv2.COLOR_RGB2BGR)
        # Perform text recognition
        if isinstance(reader.converter, AttnLabelConverter):
            result = reader.readtext(image_cv, detail=0, decoder="beamsearch")
        else:
            result = reader.readtext(image_cv, detail=0)
        recognized_texts.append(" ".join(result))

    return recognized_texts
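A quick usage sketch for this helper; the crop filenames are placeholders, and the models directory is assumed to follow the same layout as in test.py below:

from PIL import Image
from recognize_function import recognize_text_from_images

# Hypothetical pre-cropped text regions; any PIL images work here.
pieces = [Image.open("crop_0.png"), Image.open("crop_1.png")]
texts = recognize_text_from_images(pieces, models_directory="C:/Users/user/Desktop/modelsocr", gpu=False)
for line in texts:
    print(line)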
59 changes: 59 additions & 0 deletions test.py
@@ -0,0 +1,59 @@
import easyocr
import matplotlib.pyplot as plt
import cv2
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import os


# Load the image
image_path = "test2 (1).jpg"
image = cv2.imread(image_path)
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Convert to PIL Image for better font rendering
image_pil = Image.fromarray(image_rgb)
draw = ImageDraw.Draw(image_pil)

# Load a font (make sure the path to the font file is correct)
font_path = "arial.ttf" # Replace with the path to your .ttf font file
font = ImageFont.truetype(font_path, 60)  # Increased font size to 60

models_directory = "C:/Users/user/Desktop/modelsocr"
model_storage_directory = os.path.join(models_directory, "model")
user_network_directory = os.path.join(models_directory, "user_network")
# Initialize EasyOCR reader
reader = easyocr.Reader(['ru'], recog_network='best_accuracy', gpu = False,
model_storage_directory=model_storage_directory,
user_network_directory=user_network_directory)

# Perform text detection and recognition
results = reader.readtext(image_path)

# Print the results
print(results)

# Draw bounding boxes and the recognized text on the image
for (bbox, text, prob) in results:
    # Unpack the bounding box coordinates
    (top_left, top_right, bottom_right, bottom_left) = bbox
    top_left = tuple([int(val) for val in top_left])
    bottom_right = tuple([int(val) for val in bottom_right])

    # Draw the bounding box on the image
    draw.rectangle([top_left, bottom_right], outline="green", width=2)

    # Draw the recognized text above the box
    draw.text((top_left[0], top_left[1] - 40), text, font=font, fill="green")  # y-offset leaves room for the larger font

# Convert back to OpenCV format
image_rgb = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)

# Display the image with bounding boxes and recognized text
plt.figure(figsize=(60, 60)) # Adjusted the figure size for better visibility
plt.imshow(cv2.cvtColor(image_rgb, cv2.COLOR_BGR2RGB))
plt.axis('off')
plt.savefig("text.png")
plt.show()
17 changes: 9 additions & 8 deletions trainer/modules/prediction.py
@@ -1,22 +1,23 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Attention(nn.Module):

-def __init__(self, input_size, hidden_size, num_classes):
+def __init__(self, input_size, hidden_size, num_classes, device):
super(Attention, self).__init__()
self.attention_cell = AttentionCell(input_size, hidden_size, num_classes)
self.hidden_size = hidden_size
self.num_classes = num_classes
self.generator = nn.Linear(hidden_size, num_classes)
self.device = device

def _char_to_onehot(self, input_char, onehot_dim=38):
input_char = input_char.unsqueeze(1)
batch_size = input_char.size(0)
-one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(device)
+one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(self.device)
one_hot = one_hot.scatter_(1, input_char, 1)
return one_hot

@@ -30,9 +31,9 @@ def forward(self, batch_H, text, is_train=True, batch_max_length=25):
batch_size = batch_H.size(0)
num_steps = batch_max_length + 1 # +1 for [s] at end of sentence.

-output_hiddens = torch.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0).to(device)
-hidden = (torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device),
-          torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device))
+output_hiddens = torch.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0).to(self.device)
+hidden = (torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(self.device),
+          torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(self.device))

if is_train:
for i in range(num_steps):
@@ -44,8 +45,8 @@ def forward(self, batch_H, text, is_train=True, batch_max_length=25):
probs = self.generator(output_hiddens)

else:
-targets = torch.LongTensor(batch_size).fill_(0).to(device)  # [GO] token
-probs = torch.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0).to(device)
+targets = torch.LongTensor(batch_size).fill_(0).to(self.device)  # [GO] token
+probs = torch.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0).to(self.device)

for i in range(num_steps):
char_onehots = self._char_to_onehot(targets, onehot_dim=self.num_classes)
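With the module-level device constant gone, callers construct Attention with an explicit device. A minimal sketch under that API, assuming the trainer package is importable; the feature and class sizes are illustrative only:

import torch
from trainer.modules.prediction import Attention

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 96 classes and 256-dim features are placeholders, not values from this PR.
attn = Attention(input_size=256, hidden_size=256, num_classes=96, device=device).to(device)

batch_H = torch.zeros(2, 26, 256, device=device)             # contextual features from the sequence stage
text = torch.zeros(2, 26, dtype=torch.long, device=device)   # [GO]-prefixed target indices for teacher forcing
probs = attn(batch_H, text, is_train=True, batch_max_length=25)
print(probs.shape)  # torch.Size([2, 26, 96])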
7 changes: 4 additions & 3 deletions trainer/utils.py
@@ -1,7 +1,7 @@
import torch
import pickle
import numpy as np
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class AttrDict(dict):
def __init__(self, *args, **kwargs):
@@ -298,12 +298,13 @@ def decode_wordbeamsearch(self, mat, beamWidth=5):
class AttnLabelConverter(object):
""" Convert between text-label and text-index """

-def __init__(self, character):
+def __init__(self, character, device):
# character (str): set of the possible characters.
# [GO] for the start token of the attention decoder. [s] for end-of-sentence token.
list_token = ['[GO]', '[s]'] # ['[s]','[UNK]','[PAD]','[GO]']
list_character = list(character)
self.character = list_token + list_character
self.device = device

self.dict = {}
for i, char in enumerate(self.character):
@@ -331,7 +332,7 @@ def encode(self, text, batch_max_length=25):
text.append('[s]')
text = [self.dict[char] for char in text]
batch_text[i][1:1 + len(text)] = torch.LongTensor(text) # batch_text[:, 0] = [GO] token
-return (batch_text.to(device), torch.IntTensor(length).to(device))
+return (batch_text.to(self.device), torch.IntTensor(length).to(self.device))

def decode(self, text_index, length):
""" convert text-index into text-label. """
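Likewise, AttnLabelConverter now takes the device explicitly instead of relying on a module-level constant. A small sketch of the updated constructor and an encode call, assuming trainer/ is importable and using a toy alphabet:

import torch
from trainer.utils import AttnLabelConverter  # import path assumes trainer/ is on the Python path

# Toy alphabet; a real model passes its full character set here.
converter = AttnLabelConverter("абв", device=torch.device('cpu'))
batch_text, length = converter.encode(["абв"], batch_max_length=25)
print(batch_text, length)  # [GO]/[s]-framed index tensor plus per-string lengths, both on the chosen device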