Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-org the ocr agent files #64

Merged
merged 1 commit into from
Sep 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/layoutparser/ocr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .gcv_agent import GCVAgent, GCVFeatureType
from .tesseract_agent import TesseractAgent, TesseractFeatureType
67 changes: 67 additions & 0 deletions src/layoutparser/ocr/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from abc import ABC, abstractmethod
from enum import IntEnum
import importlib


class BaseOCRElementType(IntEnum):
@property
@abstractmethod
def attr_name(self):
pass


class BaseOCRAgent(ABC):
@property
@abstractmethod
def DEPENDENCIES(self):
"""DEPENDENCIES lists all necessary dependencies for the class."""
pass

@property
@abstractmethod
def MODULES(self):
"""MODULES instructs how to import these necessary libraries.

Note:
Sometimes a python module have different installation name and module name (e.g.,
`pip install tensorflow-gpu` when installing and `import tensorflow` when using
). And sometimes we only need to import a submodule but not whole module. MODULES
is designed for this purpose.

Returns:
:obj: list(dict): A list of dict indicate how the model is imported.

Example::

[{
"import_name": "_vision",
"module_path": "google.cloud.vision"
}]

is equivalent to self._vision = importlib.import_module("google.cloud.vision")
"""
pass

@classmethod
def _import_module(cls):
for m in cls.MODULES:
if importlib.util.find_spec(m["module_path"]):
setattr(
cls, m["import_name"], importlib.import_module(m["module_path"])
)
else:
raise ModuleNotFoundError(
f"\n "
f"\nPlease install the following libraries to support the class {cls.__name__}:"
f"\n pip install {' '.join(cls.DEPENDENCIES)}"
f"\n "
)

def __new__(cls, *args, **kwargs):

cls._import_module()
return super().__new__(cls)

@abstractmethod
def detect(self, image):
pass
240 changes: 3 additions & 237 deletions src/layoutparser/ocr.py → src/layoutparser/ocr/gcv_agent.py
Original file line number Diff line number Diff line change
@@ -1,91 +1,19 @@
from abc import ABC, abstractmethod
from enum import IntEnum
import importlib
import io
import os
import json
import csv
import warnings
import pickle

import numpy as np
import pandas as pd
from cv2 import imencode

from .elements import *
from .io import load_dataframe

__all__ = ["GCVFeatureType", "GCVAgent", "TesseractFeatureType", "TesseractAgent"]
from .base import BaseOCRAgent, BaseOCRElementType
from ..elements import Layout, TextBlock, Quadrilateral, TextBlock


def _cvt_GCV_vertices_to_points(vertices):
return np.array([[vertex.x, vertex.y] for vertex in vertices])


class BaseOCRElementType(IntEnum):
@property
@abstractmethod
def attr_name(self):
pass


class BaseOCRAgent(ABC):
@property
@abstractmethod
def DEPENDENCIES(self):
"""DEPENDENCIES lists all necessary dependencies for the class."""
pass

@property
@abstractmethod
def MODULES(self):
"""MODULES instructs how to import these necessary libraries.

Note:
Sometimes a python module have different installation name and module name (e.g.,
`pip install tensorflow-gpu` when installing and `import tensorflow` when using
). And sometimes we only need to import a submodule but not whole module. MODULES
is designed for this purpose.

Returns:
:obj: list(dict): A list of dict indicate how the model is imported.

Example::

[{
"import_name": "_vision",
"module_path": "google.cloud.vision"
}]

is equivalent to self._vision = importlib.import_module("google.cloud.vision")
"""
pass

@classmethod
def _import_module(cls):
for m in cls.MODULES:
if importlib.util.find_spec(m["module_path"]):
setattr(
cls, m["import_name"], importlib.import_module(m["module_path"])
)
else:
raise ModuleNotFoundError(
f"\n "
f"\nPlease install the following libraries to support the class {cls.__name__}:"
f"\n pip install {' '.join(cls.DEPENDENCIES)}"
f"\n "
)

def __new__(cls, *args, **kwargs):

cls._import_module()
return super().__new__(cls)

@abstractmethod
def detect(self, image):
pass


class GCVFeatureType(BaseOCRElementType):
"""
The element types from Google Cloud Vision API
Expand Down Expand Up @@ -341,166 +269,4 @@ def save_response(self, res, file_name):

with open(file_name, "w") as f:
json_file = json.loads(res)
json.dump(json_file, f)


class TesseractFeatureType(BaseOCRElementType):
"""
The element types for Tesseract Detection API
"""

PAGE = 0
BLOCK = 1
PARA = 2
LINE = 3
WORD = 4

@property
def attr_name(self):
name_cvt = {
TesseractFeatureType.PAGE: "page_num",
TesseractFeatureType.BLOCK: "block_num",
TesseractFeatureType.PARA: "par_num",
TesseractFeatureType.LINE: "line_num",
TesseractFeatureType.WORD: "word_num",
}
return name_cvt[self]

@property
def group_levels(self):
levels = ["page_num", "block_num", "par_num", "line_num", "word_num"]
return levels[: self + 1]


class TesseractAgent(BaseOCRAgent):
"""
A wrapper for `Tesseract <https://github.com/tesseract-ocr/tesseract>`_ Text
Detection APIs based on `PyTesseract <https://github.com/tesseract-ocr/tesseract>`_.
"""

DEPENDENCIES = ["pytesseract"]
MODULES = [{"import_name": "_pytesseract", "module_path": "pytesseract"}]

def __init__(self, languages="eng", **kwargs):
"""Create a Tesseract OCR Agent.

Args:
languages (:obj:`list` or :obj:`str`, optional):
You can specify the language code(s) of the documents to detect to improve
accuracy. The supported language and their code can be found on
`its github repo <https://github.com/tesseract-ocr/langdata>`_.
It supports two formats: 1) you can pass in the languages code as a string
of format like `"eng+fra"`, or 2) you can pack them as a list of strings
`["eng", "fra"]`.
Defaults to 'eng'.
"""
self.lang = languages if isinstance(languages, str) else "+".join(languages)
self.configs = kwargs

@classmethod
def with_tesseract_executable(cls, tesseract_cmd_path, **kwargs):

cls._pytesseract.pytesseract.tesseract_cmd = tesseract_cmd_path
return cls(**kwargs)

def _detect(self, img_content):
res = {}
res["text"] = self._pytesseract.image_to_string(
img_content, lang=self.lang, **self.configs
)
_data = self._pytesseract.image_to_data(
img_content, lang=self.lang, **self.configs
)
res["data"] = pd.read_csv(
io.StringIO(_data), quoting=csv.QUOTE_NONE, encoding="utf-8", sep="\t"
)
return res

def detect(
self, image, return_response=False, return_only_text=True, agg_output_level=None
):
"""Send the input image for OCR.

Args:
image (:obj:`np.ndarray` or :obj:`str`):
The input image array or the name of the image file
return_response (:obj:`bool`, optional):
Whether directly return all output (string and boxes
info) from Tesseract.
Defaults to `False`.
return_only_text (:obj:`bool`, optional):
Whether return only the texts in the OCR results.
Defaults to `False`.
agg_output_level (:obj:`~TesseractFeatureType`, optional):
When set, aggregate the GCV output with respect to the
specified aggregation level. Defaults to `None`.
"""

res = self._detect(image)

if return_response:
return res

if return_only_text:
return res["text"]

if agg_output_level is not None:
return self.gather_data(res, agg_output_level)

return res["text"]

@staticmethod
def gather_data(response, agg_level):
"""
Gather the OCR'ed text, bounding boxes, and confidence
in a given aggeragation level.
"""
assert isinstance(
agg_level, TesseractFeatureType
), f"Invalid agg_level {agg_level}"
res = response["data"]
df = (
res[~res.text.isna()]
.groupby(agg_level.group_levels)
.apply(
lambda gp: pd.Series(
[
gp["left"].min(),
gp["top"].min(),
gp["width"].max(),
gp["height"].max(),
gp["conf"].mean(),
gp["text"].str.cat(sep=" "),
]
)
)
.reset_index(drop=True)
.reset_index()
.rename(
columns={
0: "x_1",
1: "y_1",
2: "w",
3: "h",
4: "score",
5: "text",
"index": "id",
}
)
.assign(x_2=lambda x: x.x_1 + x.w, y_2=lambda x: x.y_1 + x.h, block_type="rectangle")
.drop(columns=["w", "h"])
)

return load_dataframe(df)

@staticmethod
def load_response(filename):
with open(filename, "rb") as fp:
res = pickle.load(fp)
return res

@staticmethod
def save_response(res, file_name):

with open(file_name, "wb") as fp:
pickle.dump(res, fp, protocol=pickle.HIGHEST_PROTOCOL)
json.dump(json_file, f)
Loading