Created
July 10, 2024 07:13
-
-
Save MathiasToftas/dae1165a22da14a34f0813163b333be5 to your computer and use it in GitHub Desktop.
cuda support
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Module containing methods of detecting tables from whole pdf pages. | |
Whenever possible, classes (like :class:`TableDetector`) should be imported from the top-level module, not from this module, | |
as the exact paths may change in future versions. | |
Example: | |
>>> from gmft import TableDetector | |
""" | |
import copy | |
from typing import Generator, Union | |
import PIL.Image | |
from PIL.Image import Image as PILImage | |
import torch | |
import transformers | |
from transformers import AutoImageProcessor, TableTransformerForObjectDetection | |
from gmft.common import Rect | |
from gmft.pdf_bindings.common import BasePage, ImageOnlyPage | |
from gmft.table_visualization import plot_results_unwr | |
def position_words(words: Generator[tuple[int, int, int, int, str], None, None], y_gap=3):
    """
    Helper function to convert a sequence of positioned words to a string.

    Words are assumed to already be in reading order (left to right, then top to bottom).
    Consecutive words are joined by a space, unless the bottom edge (y1) moved by at
    least ``y_gap``, in which case a newline is inserted instead.

    :param words: iterable (generator, list, ...) of ``(x0, y0, x1, y1, text, ...)`` tuples.
        Only the first five fields are used; extra trailing fields are ignored.
    :param y_gap: minimum change in y1 between consecutive words that counts as a new line.
    :return: the reconstructed text, or ``""`` if there are no words.
    """
    # iter() lets callers pass lists as well as generators.
    word_iter = iter(words)
    first = next(word_iter, None)
    if first is None:
        return ""
    prev_bottom = first[3]
    parts = [first[4]]
    for word in word_iter:
        # Slice-free indexing: tolerate tuples that carry more than five fields.
        bottom, text = word[3], word[4]
        if abs(bottom - prev_bottom) >= y_gap:
            parts.append(f"\n{text}")
        else:
            parts.append(f" {text}")
        prev_bottom = bottom
    return "".join(parts)
class CroppedTable:
    """
    A pdf selection, cropped to include just a table.
    Created by :class:`~gmft.TableDetector`.
    """

    # Most recent render produced by :meth:`image`, and the resolved settings used for it.
    _img: PILImage
    _img_dpi: int
    _img_padding: tuple[int, int, int, int]
    _img_margin: tuple[int, int, int, int]

    def __init__(self, page: BasePage, bbox: tuple[int, int, int, int] | Rect, confidence_score: float, label=0):
        """
        :param page: the page containing the table
        :param bbox: table boundary in pdf units, either a :class:`Rect` or ``(xmin, ymin, xmax, ymax)``
        :param confidence_score: detector confidence for this table
        :param label: detector label id for this table
        """
        self.page = page
        self.rect = bbox if isinstance(bbox, Rect) else Rect(bbox)
        self.confidence_score = confidence_score
        self._img = None
        self._img_dpi = None
        self._img_padding = None
        self._img_margin = None
        self.label = label

    def image(self, dpi: int = None, padding: str | tuple[int, int, int, int] = None, margin: str | tuple[int, int, int, int] = None) -> PILImage:
        """
        Return the image of the cropped table.

        Following pypdfium2, scaling_factor = (dpi / 72).
        Therefore, dpi=72 is the default, and dpi=144 is x2 zoom.

        The rendered image and the resolved dpi / padding / margin are stored on the
        instance (``_img``, ``_img_dpi``, ``_img_padding``, ``_img_margin``).

        :param dpi: dots per inch. Defaults to 72.
        :param padding: padding (**blank pixels**) to add to the image. Tuple of (left, top, right, bottom).
            Padding (blank pixels) is added after the crop.
            Padding is important for subsequent row/column detection; see https://github.com/microsoft/table-transformer/issues/68 for discussion.
            If padding = 'auto', the padding is automatically set to 10% of the larger of {width, height}.
            Default is no padding.
        :param margin: add content (in **pdf units**) from the original pdf beyond the detected table bbox boundary.
            Tuple of (left, top, right, bottom). If margin = 'auto', 30 pdf units are added on every side.
            Default is no margin.
        :return: image of the cropped table
        """
        # This module only imports PIL.Image; import ImageOps explicitly rather than
        # relying on another library having loaded PIL.ImageOps.
        from PIL import ImageOps

        dpi = 72 if dpi is None else dpi
        if padding == 'auto':
            width = self.rect.width * dpi / 72
            height = self.rect.height * dpi / 72
            pad = int(max(width, height) * 0.1)
            padding = (pad, pad, pad, pad)
        elif padding is None:
            padding = (0, 0, 0, 0)

        rect = self.rect
        if margin == 'auto':
            margin = (30, 30, 30, 30)  # from the paper
        if margin is not None:
            rect = Rect((rect.xmin - margin[0], rect.ymin - margin[1],
                         rect.xmax + margin[2], rect.ymax + margin[3]))

        img = self.page.get_image(dpi=dpi, rect=rect)
        img = ImageOps.expand(img, padding, fill="white")

        self._img = img
        self._img_dpi = dpi
        self._img_padding = padding
        self._img_margin = margin
        return self._img

    def text_positions(self, remove_table_offset: bool = False, outside: bool = False) -> Generator[tuple[int, int, int, int, str], None, None]:
        """
        Return the text positions of the cropped table.
        Any words that intersect the table are captured, even if they are not fully contained.

        :param remove_table_offset: if True, the positions are adjusted to be relative to the top-left corner of the table.
        :param outside: if True, returns the **complement** of the table: all the text positions outside the table.
            By default, it returns the text positions inside the table.
        :return: generator of text positions, each a tuple ``(x0, y0, x1, y1, "string")``
        """
        for w in self.page.get_positions_and_text():
            # `!= outside` selects intersecting words normally, and non-intersecting ones when outside=True.
            if Rect(w[:4]).is_intersecting(self.rect) != outside:
                if remove_table_offset:
                    yield (w[0] - self.rect.xmin, w[1] - self.rect.ymin,
                           w[2] - self.rect.xmin, w[3] - self.rect.ymin, w[4])
                else:
                    yield w

    def text(self):
        """
        Return the text of the cropped table.
        Any words that intersect the table are captured, even if they are not fully contained.

        :return: text of the cropped table
        """
        return position_words(self.text_positions())

    def visualize(self, show_text=False, **kwargs):
        """
        Visualize the cropped table on top of the full page image.

        :param show_text: if True, also draw the bounding box of every word on the page (label -1).
        :param kwargs: forwarded to ``plot_results_unwr``
        """
        img = self.page.get_image()
        confidences = [self.confidence_score]
        labels = [self.label]
        bboxes = [self.rect.bbox]
        if show_text:
            text_positions = [w[:4] for w in self.page.get_positions_and_text()]
            confidences += [0.9] * len(text_positions)
            labels += [-1] * len(text_positions)
            bboxes += text_positions
        plot_results_unwr(img, confidence=confidences, labels=labels, boxes=bboxes, id2label=None, **kwargs)

    def to_dict(self):
        """
        Serialize into a dict; see :meth:`from_dict`.
        """
        return {
            "filename": self.page.get_filename(),
            "page_no": self.page.page_number,
            "bbox": self.rect.bbox,
            "confidence_score": self.confidence_score,
            "label": self.label
        }

    @staticmethod
    def from_dict(d: dict, page: BasePage):
        """
        Deserialize a CroppedTable object from dict.
        Because file locations may change, require the user to provide the original page -
        but as a helper method see PyPDFium2Utils.load_page_from_dict and PyPDFium2Utils.reload

        :param d: dict
        :param page: BasePage
        :return: CroppedTable object (a RotatedCroppedTable if ``d`` has an ``'angle'`` key)
        """
        if 'angle' in d:
            return RotatedCroppedTable.from_dict(d, page)
        return CroppedTable(page, d['bbox'], d['confidence_score'], d['label'])

    @staticmethod
    def from_image_only(img: PILImage) -> 'CroppedTable':
        """
        Create a :class:`~gmft.CroppedTable` object from an image only.

        :param img: PIL image
        :return: CroppedTable object whose bbox spans the entire image
        """
        page = ImageOnlyPage(img)
        bbox = (0, 0, img.width, img.height)
        table = CroppedTable(page, bbox, confidence_score=1.0, label=0)
        table._img = img
        table._img_dpi = 72
        return table

    @property
    def bbox(self):
        """The table boundary as stored on ``self.rect``."""
        return self.rect.bbox
class TableDetectorConfig:
    """
    Configuration for the :class:`~gmft.TableDetector` class.
    """
    image_processor_path: str = "microsoft/table-transformer-detection"
    detector_path: str = "microsoft/table-transformer-detection"
    no_timm: bool = True  # huggingface revision
    warn_uninitialized_weights: bool = False
    torch_device: torch.device = torch.device("cpu")
    detector_base_threshold: float = 0.9
    """Minimum confidence score required for a table"""

    @property
    def confidence_score_threshold(self):
        raise DeprecationWarning("Use detector_base_threshold instead.")

    @confidence_score_threshold.setter
    def confidence_score_threshold(self, value):
        raise DeprecationWarning("Use detector_base_threshold instead.")

    def __init__(self, image_processor_path: str = None, detector_path: str = None, torch_device: torch.device = None, **kwargs):
        """
        Create a config; arguments left as None keep the class-level defaults.

        :param image_processor_path: huggingface path of the image processor
        :param detector_path: huggingface path of the detection model
        :param torch_device: device the detector is loaded onto
        :param kwargs: any other setting declared on this class (e.g. ``detector_base_threshold``,
            ``no_timm``). This also lets ``TableDetectorConfig(**some_dict)`` accept such keys.
        :raises TypeError: if a keyword does not name a setting of this class
        :raises DeprecationWarning: if a deprecated setting is given
        """
        if image_processor_path is not None:
            self.image_processor_path = image_processor_path
        if detector_path is not None:
            self.detector_path = detector_path
        if torch_device is not None:
            self.torch_device = torch_device
        for name, value in kwargs.items():
            if name.startswith('__') or not hasattr(type(self), name):
                raise TypeError(f"{type(self).__name__}() got an unexpected keyword argument '{name}'")
            # Deprecated names are properties whose setter raises DeprecationWarning.
            setattr(self, name, value)
class TableDetector:
    """
    Detects tables in a pdf page. Default implementation uses TableTransformerForObjectDetection.
    """

    def __init__(self, config: TableDetectorConfig = None, default_implementation=True):
        """
        Initialize the TableDetector.

        :param config: TableDetectorConfig, or a dict of keyword arguments for one
        :param default_implementation: Should be True, unless you are writing a custom subclass for TableDetector.
        """
        # future-proofing: allow subclasses for TableDetector to have different architectures
        if not default_implementation:
            return

        if config is None:
            config = TableDetectorConfig()
        elif isinstance(config, dict):
            config = TableDetectorConfig(**config)

        if not config.warn_uninitialized_weights:
            previous_verbosity = transformers.logging.get_verbosity()
            transformers.logging.set_verbosity(transformers.logging.ERROR)
        try:
            self.image_processor = AutoImageProcessor.from_pretrained(config.image_processor_path)
            revision = "no_timm" if config.no_timm else None
            self.detector = TableTransformerForObjectDetection.from_pretrained(
                config.detector_path, revision=revision
            ).to(config.torch_device)
        finally:
            # Restore the caller's logging verbosity even if model loading fails.
            if not config.warn_uninitialized_weights:
                transformers.logging.set_verbosity(previous_verbosity)
        self.config = config

    def extract(self, page: BasePage, config_overrides: TableDetectorConfig = None) -> list[CroppedTable]:
        """
        Detect tables in a page.

        :param page: BasePage
        :param config_overrides: override the config for this call only
        :return: list of CroppedTable objects (RotatedCroppedTable, at 90 degrees, for label 1)
        """
        if config_overrides is not None:
            config = copy.deepcopy(self.config)
            config.__dict__.update(config_overrides.__dict__)
        else:
            config = self.config

        img = page.get_image(72)  # use standard dpi = 72, which means we don't need any scaling
        # The model was moved to self.config.torch_device at construction time, so the inputs
        # must follow it; a per-call override cannot move the model.
        encoding = self.image_processor(img, return_tensors="pt").to(self.config.torch_device)
        with torch.no_grad():
            outputs = self.detector(**encoding)

        # PIL reports size as (width, height); reverse it for post-processing.
        target_sizes = torch.tensor([img.size[::-1]])
        # keep only predictions at or above the configured threshold (excluding no-object class)
        threshold = config.detector_base_threshold
        results = self.image_processor.post_process_object_detection(
            outputs, threshold=threshold, target_sizes=target_sizes
        )[0]

        tables = []
        for box, score, label_tensor in zip(results["boxes"], results["scores"], results["labels"]):
            bbox = box.tolist()
            confidence_score = score.item()
            label = label_tensor.item()
            if label == 1:
                tables.append(RotatedCroppedTable(page, bbox, confidence_score, 90, label))
            else:
                tables.append(CroppedTable(page, bbox, confidence_score, label))
        return tables
class TATRTableDetector(TableDetector):
    """
    Table detector backed by TableTransformerForObjectDetection (TATR).

    Currently identical to :class:`TableDetector`: :meth:`extract` returns a list of
    :class:`CroppedTable` objects (not formatted tables). Pass those to a table formatter
    such as :class:`~gmft.TATRTableFormatter` to obtain a :class:`~gmft.FormattedTable`.
    """
class RotatedCroppedTable(CroppedTable):
    """
    Table that has been rotated.

    Note: ``self.bbox`` and ``self.rect`` are in coordinates of the original pdf.
    But text_positions() can possibly give transformed coordinates.

    Currently, only 0, 90, 180, and 270 degree rotations are supported.
    An angle of 90 would mean that a 90 degree cc rotation has been applied to a level image.
    In practice, the majority of rotated tables are rotated by 90 degrees.
    """

    def __init__(self, page: BasePage, bbox: tuple[int, int, int, int], confidence_score: float, angle: float, label=0):
        """
        Currently, only 0, 90, 180, and 270 degree rotations are supported.

        :param page: BasePage
        :param bbox: table boundary in pdf units
        :param confidence_score: detector confidence for this table
        :param angle: angle in degrees, counterclockwise.
            That is, 90 would mean that a 90 degree cc rotation has been applied to a level image.
            In practice, the majority of rotated tables are rotated by 90 degrees.
        :param label: detector label id
        :raises ValueError: if angle is not one of 0, 90, 180, 270
        """
        super().__init__(page, bbox, confidence_score, label)
        if angle not in [0, 90, 180, 270]:
            raise ValueError("Only 0, 90, 180, 270 are supported.")
        self.angle = angle

    def image(self, dpi: int = None, padding: str | tuple[int, int, int, int] = None,
              margin: str | tuple[int, int, int, int] = None, **kwargs) -> PILImage:
        """
        Return the image of the cropped table, rotated back to a level orientation.

        Rendering (crop, margin, padding) is delegated to :meth:`CroppedTable.image`;
        extra keyword arguments are forwarded to it unchanged.
        """
        img = super().image(dpi=dpi, padding=padding, margin=margin, **kwargs)
        if self.angle != 0:
            # rotate by negative angle to get back to original orientation
            img = img.rotate(-self.angle, expand=True)
        return img

    def text_positions(self, remove_table_offset: bool = False, outside: bool = False) -> Generator[tuple[int, int, int, int, str], None, None]:
        """
        Return the text positions of the cropped table.

        If remove_table_offset is False, positions are relative to the top-left corner of the pdf (no adjustment for rotation).
        If remove_table_offset is True, positions are relative to a hypothetical pdf where the text in the table is perfectly level, and
        pdf's top-left corner is also the table's top-left corner (both at 0, 0).

        :param remove_table_offset: if True, the positions are adjusted to be relative to the top-left corner of the table.
        :param outside: if True, returns the **complement** of the table: all the text positions outside the table.
        :return: generator of text positions, which are tuples of (xmin, ymin, xmax, ymax, "string")
        """
        if self.angle == 0 or not remove_table_offset:
            yield from super().text_positions(remove_table_offset=remove_table_offset, outside=outside)
            return

        width = self.rect.width
        height = self.rect.height
        # Each branch maps table-relative coordinates into the frame of the level
        # (rotated-back) table, mirroring the rotation applied in image().
        for x0, y0, x1, y1, text in super().text_positions(remove_table_offset=True, outside=outside):
            if self.angle == 90:
                yield (height - y1, x0, height - y0, x1, text)
            elif self.angle == 180:
                yield (width - x1, height - y1, width - x0, height - y0, text)
            elif self.angle == 270:
                yield (y0, width - x1, y1, width - x0, text)

    def to_dict(self):
        """Serialize into a dict; same as :meth:`CroppedTable.to_dict` plus ``'angle'``."""
        d = super().to_dict()
        d['angle'] = self.angle
        return d

    @staticmethod
    def from_dict(d: dict, page: BasePage) -> Union[CroppedTable, 'RotatedCroppedTable']:
        """
        Create a :class:`RotatedCroppedTable` object from dict.
        Falls back to a plain :class:`CroppedTable` when ``d`` has no ``'angle'`` key.
        """
        if 'angle' not in d:
            return CroppedTable.from_dict(d, page)
        return RotatedCroppedTable(page, d['bbox'], d['confidence_score'], d['angle'], d['label'])
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Module containing methods of formatting tables: structural analysis, data extraction, and converting them into pandas dataframes. | |
Whenever possible, classes (like :class:`AutoTableFormatter`) should be imported from the top-level module, not from this module, | |
as the exact paths may change in future versions. | |
Example: | |
>>> from gmft import AutoTableFormatter | |
""" | |
from abc import ABC, abstractmethod | |
import copy | |
from gmft.pdf_bindings.common import BasePage | |
import torch | |
from gmft.table_detection import CroppedTable, RotatedCroppedTable | |
import transformers | |
from transformers import AutoImageProcessor, TableTransformerForObjectDetection | |
import pandas as pd | |
from gmft.table_function_algorithm import extract_to_df | |
from gmft.table_visualization import plot_results_unwr | |
class FormattedTable(RotatedCroppedTable):
    """
    A table that has been "formatted": structural analysis has attached header and
    data information to it, so it can be exported as a dataframe, csv, etc.

    Note: this class does not use ``ABCMeta``, so the ``@abstractmethod`` markers below
    are informational only. They are not enforced at instantiation, and the placeholder
    implementations simply raise NotImplementedError.
    """

    def __init__(self, cropped_table: CroppedTable, df: pd.DataFrame = None):
        """
        Build a formatted table as a shallow copy of ``cropped_table``.

        :param cropped_table: the detected table; its rotation angle is kept when it has one
        :param df: optional precomputed dataframe
        """
        self._df = df

        has_angle = 'angle' in cropped_table.__dict__
        super().__init__(
            page=cropped_table.page,
            bbox=cropped_table.rect.bbox,
            confidence_score=cropped_table.confidence_score,
            angle=cropped_table.angle if has_angle else 0,
            label=cropped_table.label,
        )

        # Carry over the source table's last render (as an independent copy) and its settings.
        source_img = cropped_table._img
        self._img = None if source_img is None else source_img.copy()
        self._img_dpi = cropped_table._img_dpi
        self._img_padding = cropped_table._img_padding
        self._img_margin = cropped_table._img_margin

    def df(self):
        """
        Return the table as a pandas dataframe.
        """
        return self._df

    @abstractmethod
    def visualize(self):
        """
        Visualize the table. Subclasses provide the implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def to_dict(self):
        """
        Serialize self into dict. Subclasses provide the implementation.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def from_dict(d: dict, page: BasePage):
        """
        Deserialize from dict. Subclasses provide the implementation.
        """
        raise NotImplementedError
class TableFormatter(ABC):
    """
    Interface for turning a :class:`~gmft.CroppedTable` into a :class:`~gmft.FormattedTable`,
    which supports export to csv, df, etc.
    """

    @abstractmethod
    def extract(self, table: CroppedTable) -> FormattedTable:
        """
        Run structural analysis on ``table``.

        :param table: the cropped table to format
        :return: a :class:`~gmft.FormattedTable`, from which data can be exported in csv, html, etc.
        """
        raise NotImplementedError
class TATRFormatConfig:
    """
    Configuration for :class:`~gmft.TATRTableFormatter`.
    """

    # ---- model settings ----

    warn_uninitialized_weights: bool = False
    image_processor_path: str = "microsoft/table-transformer-detection"
    formatter_path: str = "microsoft/table-transformer-structure-recognition"
    no_timm: bool = True  # use a model which uses AutoBackbone.
    torch_device: torch.device = torch.device("cpu")
    # https://huggingface.co/microsoft/table-transformer-structure-recognition/discussions/5
    # "microsoft/table-transformer-structure-recognition-v1.1-all"

    verbosity: int = 1
    """
    0: errors only
    1: print warnings
    2: print warnings and info
    3: print warnings, info, and debug
    """

    formatter_base_threshold: float = 0.3
    """Base threshold for the confidence demanded of a table feature (row/column).
    Note that a low threshold is actually better, because overzealous rows means that
    generally, numbers are still aligned and there are just many empty rows
    (having fewer rows than expected merges cells, which is bad).
    """

    # Class-level dict shared by every instance that does not assign its own;
    # to customize, assign a new dict rather than mutating this one in place.
    cell_required_confidence = {
        0: 0.3,  # table
        1: 0.3,  # column
        2: 0.3,  # row
        3: 0.3,  # column header
        4: 0.5,  # projected row header
        5: 0.5,  # spanning cell
        6: 99    # no object
    }
    """Confidences required (>=) for a row/column feature to be considered good. See TATRFormattedTable.id2label

    But low confidences may be better than too high confidence (see formatter_base_threshold)
    """

    # ---- df() settings ----

    # ---- options ----

    remove_null_rows = True
    """remove rows with no text"""

    enable_multi_header = False
    """Enable multi-indices in the dataframe.
    If false, then multiple headers will be merged column-wise."""

    semantic_spanning_cells = False
    """
    [Experimental] Enable semantic spanning cells, which often encode hierarchical multi-level indices.
    """

    semantic_hierarchical_left_fill = 'algorithm'
    """
    [Experimental] When semantic spanning cells is enabled, when a left header is detected which might
    represent a group of rows, that same value is reduplicated for each row.
    Possible values: 'algorithm', 'deep', None
    """

    # ---- large table ----

    large_table_if_n_rows_removed = 8
    """
    If >= n rows are removed due to non-maxima suppression (NMS), then this table is classified as a large table.
    """

    large_table_threshold = 10
    """with large tables, table transformer struggles with placing too many overlapping rows
    luckily, with more rows, we have more info on the usual size of text, which we can use to make
    a guess on the height such that no rows are merged or overlapping

    large table assumption is only applied when (# of rows > large_table_threshold) AND (total overlap > large_table_row_overlap_threshold)
    set 9999 to disable, set 0 to force large table assumption to run every time"""

    large_table_row_overlap_threshold = 0.2

    large_table_maximum_rows = 1000
    """If the table predicts a large number of rows, refuse to proceed. Therefore prevent memory issues for super small text."""

    force_large_table_assumption = None
    """True: force large table assumption to be applied to all tables
    False: force large table assumption to not be applied to any tables
    None: heuristically apply large table assumption according to threshold and overlap"""

    # ---- rejection and warnings ----

    total_overlap_reject_threshold = 0.2
    """reject if total overlap is > 20% of table area"""

    total_overlap_warn_threshold = 0.05
    """warn if total overlap is > 5% of table area"""

    nms_warn_threshold = 5
    """warn if non maxima suppression removes > 5 rows"""

    iob_reject_threshold = 0.05
    """reject if iob between textbox and cell is < 5%"""

    iob_warn_threshold = 0.5
    """warn if iob between textbox and cell is < 50%"""

    # ---- technical ----

    _nms_overlap_threshold = 0.1
    """Non-maxima suppression: if two rows overlap by > threshold (default: 10%), then the one with the lower confidence is removed.
    A subsequent technique is able to fill in gaps created by NMS."""

    _large_table_merge_distance = 0.2
    """In the large_table method, if two means are within (20% * text_height) of each other, then they are merged.
    This may be useful to adjust if text is being split due to subscripts/superscripts."""

    _smallest_supported_text_height = 0.1
    """The smallest supported text height. Text smaller than this height will be ignored.
    Helps prevent very small text from creating huge arrays under large table assumption."""

    # ---- deprecated ----
    # Each deprecated name is a property that raises on both read and write.

    @property
    def aggregate_spanning_cells(self):
        raise DeprecationWarning("aggregate_spanning_cells has been removed.")

    @aggregate_spanning_cells.setter
    def aggregate_spanning_cells(self, value):
        raise DeprecationWarning("aggregate_spanning_cells has been removed.")

    @property
    def corner_clip_outlier_threshold(self):
        raise DeprecationWarning("corner_clip_outlier_threshold has been removed.")

    @corner_clip_outlier_threshold.setter
    def corner_clip_outlier_threshold(self, value):
        raise DeprecationWarning("corner_clip_outlier_threshold has been removed.")

    @property
    def spanning_cell_minimum_width(self):
        raise DeprecationWarning("spanning_cell_minimum_width has been removed.")

    @spanning_cell_minimum_width.setter
    def spanning_cell_minimum_width(self, value):
        raise DeprecationWarning("spanning_cell_minimum_width has been removed.")

    @property
    def deduplication_iob_threshold(self):
        raise DeprecationWarning("deduplication_iob_threshold is deprecated. See _nms_overlap_threshold instead.")

    @deduplication_iob_threshold.setter
    def deduplication_iob_threshold(self, value):
        raise DeprecationWarning("deduplication_iob_threshold is deprecated. See _nms_overlap_threshold instead.")

    def __init__(self, torch_device: torch.device = None, **kwargs):
        """
        Create a config; anything not passed keeps the class-level default.

        :param torch_device: device the structure-recognition model is loaded onto
        :param kwargs: any other setting declared on this class, e.g. ``formatter_base_threshold=0.4``
        :raises TypeError: if a keyword does not name a setting of this class
        :raises DeprecationWarning: if a deprecated setting is given
        """
        if torch_device is not None:
            self.torch_device = torch_device
        for name, value in kwargs.items():
            if name.startswith('__') or not hasattr(type(self), name):
                raise TypeError(f"{type(self).__name__}() got an unexpected keyword argument '{name}'")
            # Deprecated names are properties whose setter raises DeprecationWarning.
            setattr(self, name, value)
class TATRFormattedTable(FormattedTable):
    """
    FormattedTable, as seen by a Table Transformer (TATR).
    See :class:`~gmft.TATRTableFormatter`.
    """

    _POSSIBLE_ROWS = ['table row', 'table spanning cell', 'table projected row header']  # , 'table column header']
    _POSSIBLE_PROJECTING_ROWS = ['table projected row header']  # , 'table spanning cell']
    _POSSIBLE_COLUMN_HEADERS = ['table column header']
    _POSSIBLE_COLUMNS = ['table column']
    id2label = {
        0: 'table',
        1: 'table column',
        2: 'table row',
        3: 'table column header',
        4: 'table projected row header',
        5: 'table spanning cell',
        6: 'no object',
    }
    label2id = {v: k for k, v in id2label.items()}

    config: TATRFormatConfig
    outliers: dict[str, bool]

    effective_rows: list[tuple]
    "Rows as seen by the image --> df algorithm, which may differ from what the table transformer sees."

    effective_columns: list[tuple]
    "Columns as seen by the image --> df algorithm, which may differ from what the table transformer sees."

    effective_headers: list[tuple]
    "Headers as seen by the image --> df algorithm."

    effective_projecting: list[tuple]
    "Projected rows as seen by the image --> df algorithm."

    effective_spanning: list[tuple]
    "Spanning cells as seen by the image --> df algorithm."

    def __init__(self, cropped_table: CroppedTable, fctn_results: dict,
                 config: TATRFormatConfig = None):
        """
        :param cropped_table: the detected table
        :param fctn_results: structure-recognition results (``scores``, ``labels``, ``boxes``),
            with boxes already normalized relative to the table
        :param config: formatting config; a default TATRFormatConfig is used if None
        """
        super().__init__(cropped_table)
        self.fctn_results = fctn_results
        self.config = TATRFormatConfig() if config is None else config
        self.outliers = None

    def df(self, config_overrides: TATRFormatConfig = None):
        """
        Return the table as a pandas dataframe.

        Without overrides, the result is computed once and cached.
        With ``config_overrides``, the dataframe is always recomputed using the merged
        config, and the result is returned without replacing the cache.

        :param config_overrides: override the config settings for this call only
        """
        if config_overrides is None:
            if self._df is None:
                self._df = extract_to_df(self, config=self.config)
            return self._df
        config = copy.deepcopy(self.config)
        config.__dict__.update(config_overrides.__dict__)
        return extract_to_df(self, config=config)

    def visualize(self, filter=None, dpi=None, padding=None, margin=(10, 10, 10, 10), effective=False, **kwargs):
        """
        Visualize the table.

        :param filter: filter the labels to visualize. See TATRFormattedTable.id2label
        :param dpi: Sets the dpi. If none, then the dpi of the cached image is used (72 if nothing is cached).
        :param padding: padding passed to :meth:`image`
        :param margin: margin (pdf units) passed to :meth:`image`
        :param effective: if True, visualize the effective rows and columns, which may differ from the table transformer's output.
        :param kwargs: forwarded to ``plot_results_unwr``
        """
        if dpi is None:
            dpi = self._img_dpi if self._img_dpi is not None else 72
        if self._df is None:
            # NOTE(review): the effective_* attributes are read below but not set in this
            # class; presumably extract_to_df populates them — confirm.
            self._df = self.df()
        scale_by = dpi / 72

        if effective:
            vis = (self.effective_rows + self.effective_columns + self.effective_headers
                   + self.effective_projecting + self.effective_spanning)
            _to_visualize = {
                "scores": [x['confidence'] for x in vis],
                "labels": [self.label2id[x['label']] for x in vis],
                # materialize as lists: generators would be single-use
                "boxes": [[c * scale_by for c in x['bbox']] for x in vis],
            }
        else:
            # boxes are stored in table-relative pdf units; scale them to image pixels
            _to_visualize = {
                "scores": self.fctn_results["scores"],
                "labels": self.fctn_results["labels"],
                "boxes": [[c * scale_by for c in bbox] for bbox in self.fctn_results["boxes"]],
            }

        img = self.image(dpi=dpi, padding=padding, margin=margin)
        # image() records the resolved padding/margin (e.g. 'auto' -> tuple); margin may be None.
        margin_pts = self._img_margin if self._img_margin is not None else (0, 0, 0, 0)
        true_margin = [x * scale_by for x in margin_pts]
        return plot_results_unwr(img, _to_visualize['scores'], _to_visualize['labels'], _to_visualize['boxes'],
                                 TATRFormattedTable.id2label,
                                 filter=filter, padding=self._img_padding, margin=true_margin, **kwargs)

    def to_dict(self):
        """
        Serialize self into dict
        """
        if self.angle != 0:
            parent = RotatedCroppedTable.to_dict(self)
        else:
            parent = CroppedTable.to_dict(self)
        return {**parent, **{
            'config': self.config.__dict__,
            'outliers': self.outliers,
            'fctn_results': self.fctn_results,
        }}

    @staticmethod
    def from_dict(d: dict, page: BasePage):
        """
        Deserialize from dict.
        A page is required partly because of memory management, since having this open a page may cause memory issues.

        :raises ValueError: if ``d`` has no ``'fctn_results'``
        """
        cropped_table = CroppedTable.from_dict(d, page)
        if 'fctn_results' not in d:
            raise ValueError("fctn_results not found in dict -- dict may be a CroppedTable but not a TATRFormattedTable.")

        config = TATRFormatConfig()
        for k, v in d.get('config', {}).items():
            if v is not None and config.__dict__.get(k) != v:
                setattr(config, k, v)

        results = d['fctn_results']
        if 'fctn_scale_factor' in d or 'scale_factor' in d or 'fctn_padding' in d or 'padding' in d:
            # deprecated: this is for backwards compatibility
            scale_factor = d.get('fctn_scale_factor', d.get('scale_factor', 1))
            padding = tuple(d.get('fctn_padding', d.get('padding', (0, 0))))
            # Build a new results dict so the caller's dict is not mutated
            # (re-loading the same dict must not normalize twice).
            results = {**results, "boxes": [
                _normalize_bbox(bbox, used_scale_factor=scale_factor, used_padding=padding)
                for bbox in results["boxes"]
            ]}

        table = TATRFormattedTable(cropped_table, results, config=config)
        table.outliers = d.get('outliers', None)
        return table
class TATRTableFormatter(TableFormatter):
    """
    Uses a TableTransformerForObjectDetection for small/medium tables, and a custom algorithm for large tables.

    Using :meth:`extract`, a :class:`~gmft.FormattedTable` is produced, which can be exported to csv, df, etc.
    """

    def __init__(self, config: TATRFormatConfig = None):
        """
        Load the image processor and structure-recognition model.

        :param config: TATRFormatConfig; a default one is used if None
        """
        if config is None:
            config = TATRFormatConfig()
        # Capture the verbosity exactly once, before lowering it, so the caller's
        # setting (not ERROR) is what gets restored.
        if not config.warn_uninitialized_weights:
            previous_verbosity = transformers.logging.get_verbosity()
            transformers.logging.set_verbosity(transformers.logging.ERROR)
        try:
            self.image_processor = AutoImageProcessor.from_pretrained(config.image_processor_path)
            revision = "no_timm" if config.no_timm else None
            self.structor = TableTransformerForObjectDetection.from_pretrained(
                config.formatter_path, revision=revision
            ).to(config.torch_device)
        finally:
            if not config.warn_uninitialized_weights:
                transformers.logging.set_verbosity(previous_verbosity)
        self.config = config

    def extract(self, table: CroppedTable, dpi=144, padding='auto', margin=None, config_overrides=None) -> FormattedTable:
        """
        Extract the data from the table.

        :param table: the cropped table to format
        :param dpi: render resolution passed to ``table.image``
        :param padding: padding passed to ``table.image`` ('auto', None, or a 4-tuple of pixels)
        :param margin: margin passed to ``table.image`` ('auto', None, or a 4-tuple in pdf units)
        :param config_overrides: override the config for this call only
        :return: a :class:`TATRFormattedTable`
        """
        if config_overrides is not None:
            config = copy.deepcopy(self.config)
            config.__dict__.update(config_overrides.__dict__)
        else:
            config = self.config

        image = table.image(dpi=dpi, padding=padding, margin=margin)
        # image() records the values it actually used (e.g. 'auto' padding resolved
        # to a tuple, dpi=None resolved to 72).
        padding = table._img_padding
        margin = table._img_margin
        scale_factor = table._img_dpi / 72

        # The model lives on self.config.torch_device (set at construction), so inputs follow it.
        encoding = self.image_processor(image, return_tensors="pt").to(self.config.torch_device)
        with torch.no_grad():
            outputs = self.structor(**encoding)

        # PIL reports size as (width, height); post-processing expects (height, width).
        target_sizes = [image.size[::-1]]
        # A LOW threshold is preferred: the model tends to over-predict rows, but since
        # text is assigned to the highest-intersecting row, same-row elements still stay
        # together. Too high a threshold yields fewer rows than expected, merging cells
        # and losing information.
        results = self.image_processor.post_process_object_detection(
            outputs, threshold=config.formatter_base_threshold, target_sizes=target_sizes
        )[0]

        results = {k: v.tolist() for k, v in results.items()}
        # normalize results w.r.t. padding, margin and scale factor
        results["boxes"] = [
            _normalize_bbox(bbox, used_scale_factor=scale_factor, used_padding=padding, used_margin=margin)
            for bbox in results["boxes"]
        ]

        return TATRFormattedTable(table, results, config=config)
def _normalize_bbox(bbox: tuple[float, float, float, float], used_scale_factor: float, | |
used_padding: tuple[float, float], used_margin: tuple[float, float] =None): | |
""" | |
Normalize bbox such that: | |
1. padding is removed (so (0, 0) is the top-left of the cropped table) | |
2. scale factor is normalized (dpi=72) | |
3. margin is removed (so (0, 0) is the start of the original detected bbox) | |
""" | |
# print("Margin: ", used_margin) | |
if used_margin is None: | |
used_margin = (0, 0) | |
bbox = [bbox[0] - used_padding[0], bbox[1] - used_padding[1], bbox[2] - used_padding[0], bbox[3] - used_padding[1]] | |
bbox = [bbox[0] / used_scale_factor, bbox[1] / used_scale_factor, bbox[2] / used_scale_factor, bbox[3] / used_scale_factor] | |
bbox = [bbox[0] - used_margin[0], bbox[1] - used_margin[1], bbox[2] - used_margin[0], bbox[3] - used_margin[1]] | |
return bbox | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytest | |
import torch | |
from gmft.pdf_bindings import PyPDFium2Document | |
from gmft.table_detection import TableDetector, TableDetectorConfig | |
from gmft.table_function import TATRFormatConfig, TATRTableFormatter | |
@pytest.fixture
def doc_tiny():
    """Provide the small sample document ``test/samples/tiny.pdf``; it is closed during fixture teardown."""
    pdf = PyPDFium2Document("test/samples/tiny.pdf")
    yield pdf
    # Teardown: runs after the test that requested this fixture has finished.
    pdf.close()
def test_cuda(doc_tiny):
    """
    Run table detection and formatting end-to-end with CUDA as the torch device.

    The test is skipped rather than failed when no CUDA device is available: a missing GPU is an
    environment limitation, not a regression.
    """
    if not torch.cuda.is_available():
        pytest.skip("cannot test device settings without cuda")
    from gmft.table_function import TATRFormattedTable

    page = doc_tiny[0]
    detector = TableDetector(TableDetectorConfig(torch_device=torch.device("cuda")))
    formatter = TATRTableFormatter(TATRFormatConfig(torch_device=torch.device("cuda")))
    # The formatter moves its structure model to config.torch_device on construction.
    assert formatter.structor.device.type == "cuda"

    tables = detector.extract(page)
    assert tables, "expected at least one table in test/samples/tiny.pdf"
    ft = formatter.extract(tables[0])
    assert isinstance(ft, TATRFormattedTable)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment