Skip to content

Instantly share code, notes, and snippets.

@MathiasToftas
Created July 10, 2024 07:13
Show Gist options
  • Save MathiasToftas/dae1165a22da14a34f0813163b333be5 to your computer and use it in GitHub Desktop.
Save MathiasToftas/dae1165a22da14a34f0813163b333be5 to your computer and use it in GitHub Desktop.
cuda support
"""
Module containing methods of detecting tables from whole pdf pages.
Whenever possible, classes (like :class:`TableDetector`) should be imported from the top-level module, not from this module,
as the exact paths may change in future versions.
Example:
>>> from gmft import TableDetector
"""
import copy
from typing import Generator, Union
import PIL.Image
from PIL.Image import Image as PILImage
import torch
import transformers
from transformers import AutoImageProcessor, TableTransformerForObjectDetection
from gmft.common import Rect
from gmft.pdf_bindings.common import BasePage, ImageOnlyPage
from gmft.table_visualization import plot_results_unwr
def position_words(words: Generator[tuple[int, int, int, int, str], None, None], y_gap=3):
    """
    Helper function to convert a sequence of positioned words into a string.

    Words are assumed to already be in reading order (left to right, then top to bottom).
    Consecutive words are joined with a space. When the bottom coordinate changes by at
    least ``y_gap`` relative to the previous word, they are joined with a newline instead.

    :param words: iterable (generator, list, ...) of ``(x0, y0, x1, y1, text, ...)`` tuples;
        fields after ``text`` are ignored
    :param y_gap: minimum change of the bottom coordinate (y1) that starts a new line
    :return: the joined text, or ``""`` if there are no words
    """
    # Accept any iterable, not only iterators: next() below needs an iterator.
    words = iter(words)
    first = next(words, None)
    if first is None:
        return ""
    # Slice every word the same way, so extra trailing fields are tolerated on the first word too.
    prev_bottom, first_text = first[3], first[4]
    parts = [first_text]
    for word in words:
        bottom, text = word[3], word[4]
        if abs(bottom - prev_bottom) >= y_gap:
            parts.append(f"\n{text}")
        else:
            parts.append(f" {text}")
        prev_bottom = bottom
    return "".join(parts)
class CroppedTable:
    """
    A pdf selection, cropped to include just a table.
    Created by :class:`~gmft.TableDetector`.
    """
    _img: PILImage
    _img_dpi: int
    _img_padding: tuple[int, int, int, int]
    _img_margin: tuple[int, int, int, int]

    def __init__(self, page: BasePage, bbox: tuple[int, int, int, int] | Rect, confidence_score: float, label=0):
        """
        :param page: the page the table was detected on
        :param bbox: table bounding box in pdf units as ``(xmin, ymin, xmax, ymax)``, or a :class:`Rect`
        :param confidence_score: confidence reported by the detector
        :param label: detector label id
        """
        self.page = page
        self.rect = bbox if isinstance(bbox, Rect) else Rect(bbox)
        self.confidence_score = confidence_score
        # The most recently rendered image, plus the settings used to render it (filled by image()).
        self._img = None
        self._img_dpi = None
        self._img_padding = None
        self._img_margin = None
        self.label = label

    def image(self, dpi: int = None, padding: str | tuple[int, int, int, int] = None, margin: str | tuple[int, int, int, int] = None) -> PILImage:
        """
        Return the image of the cropped table.

        Following pypdfium2, scaling_factor = (dpi / 72).
        Therefore, dpi=72 is the default, and dpi=144 is x2 zoom.

        :param dpi: dots per inch. Defaults to 72.
        :param padding: padding (**blank pixels**) to add to the image. Tuple of (left, top, right, bottom).
            Padding (blank pixels) is added after the crop and rotation.
            Padding is important for subsequent row/column detection; see https://github.com/microsoft/table-transformer/issues/68 for discussion.
            If padding = 'auto', the padding is automatically set to 10% of the larger of {width, height}.
            Default is no padding.
        :param margin: add content (in **pdf units**) from the original pdf beyond the detected table bbox boundary.
            Tuple of (left, top, right, bottom). If margin = 'auto', 30 pdf units are added on every side.
        :return: image of the cropped table

        Side effect: the image and the resolved dpi/padding/margin are stored on the instance
        (``_img``, ``_img_dpi``, ``_img_padding``, ``_img_margin``).
        """
        # ImageOps is a separate PIL submodule; ``import PIL.Image`` alone does not guarantee
        # that ``PIL.ImageOps`` has been loaded, so import it explicitly.
        from PIL import ImageOps

        dpi = 72 if dpi is None else dpi
        if padding == 'auto':
            width = self.rect.width * dpi / 72
            height = self.rect.height * dpi / 72
            pad = int(max(width, height) * 0.1)
            padding = (pad, pad, pad, pad)
        elif padding is None:
            padding = (0, 0, 0, 0)

        rect = self.rect
        if margin == 'auto':
            margin = (30, 30, 30, 30)  # from the paper
        if margin is not None:
            rect = Rect((rect.xmin - margin[0], rect.ymin - margin[1],
                         rect.xmax + margin[2], rect.ymax + margin[3]))

        img = self.page.get_image(dpi=dpi, rect=rect)
        img = ImageOps.expand(img, padding, fill="white")

        self._img = img
        self._img_dpi = dpi
        self._img_padding = padding
        self._img_margin = margin
        return self._img

    def text_positions(self, remove_table_offset: bool = False, outside: bool = False) -> Generator[tuple[int, int, int, int, str], None, None]:
        """
        Return the text positions of the cropped table.
        Any words that intersect the table are captured, even if they are not fully contained.

        :param remove_table_offset: if True, the positions are adjusted to be relative to the top-left corner of the table.
        :param outside: if True, returns the **complement** of the table: all the text positions outside the table.
            By default, it returns the text positions inside the table.
        :return: generator of text positions, each a tuple
            ``(x0, y0, x1, y1, "string")``
        """
        for w in self.page.get_positions_and_text():
            # keep a word when its intersection status differs from `outside`
            if Rect(w[:4]).is_intersecting(self.rect) == outside:
                continue
            if remove_table_offset:
                xmin, ymin = self.rect.xmin, self.rect.ymin
                yield (w[0] - xmin, w[1] - ymin, w[2] - xmin, w[3] - ymin, w[4])
            else:
                yield w

    def text(self):
        """
        Return the text of the cropped table.
        Any words that intersect the table are captured, even if they are not fully contained.

        :return: text of the cropped table
        """
        return position_words(self.text_positions())

    def visualize(self, show_text=False, **kwargs):
        """
        Visualize the cropped table on its full page.

        :param show_text: if True, also draw the box of every word on the page
            (all words of the page, not only those inside the table).
        :param kwargs: forwarded to ``plot_results_unwr``
        """
        img = self.page.get_image()
        confidences = [self.confidence_score]
        labels = [self.label]
        bboxes = [self.rect.bbox]
        if show_text:
            text_boxes = [w[:4] for w in self.page.get_positions_and_text()]
            # word boxes get a fixed placeholder confidence and the label -1
            confidences += [0.9] * len(text_boxes)
            labels += [-1] * len(text_boxes)
            bboxes += text_boxes
        plot_results_unwr(img, confidence=confidences, labels=labels, boxes=bboxes, id2label=None, **kwargs)

    def to_dict(self):
        """
        Serialize into a dict with keys filename, page_no, bbox, confidence_score, label.
        """
        return {
            "filename": self.page.get_filename(),
            "page_no": self.page.page_number,
            "bbox": self.rect.bbox,
            "confidence_score": self.confidence_score,
            "label": self.label
        }

    @staticmethod
    def from_dict(d: dict, page: BasePage):
        """
        Deserialize a CroppedTable object from dict.
        Because file locations may change, require the user to provide the original page -
        but as a helper method see PyPDFium2Utils.load_page_from_dict and PyPDFium2Utils.reload

        :param d: dict (as produced by :meth:`to_dict`). A dict with an ``'angle'`` key
            produces a :class:`RotatedCroppedTable`. A missing ``'label'`` defaults to 0.
        :param page: BasePage
        :return: CroppedTable object
        """
        if 'angle' in d:
            return RotatedCroppedTable.from_dict(d, page)
        return CroppedTable(page, d['bbox'], d['confidence_score'], d.get('label', 0))

    @staticmethod
    def from_image_only(img: PILImage) -> 'CroppedTable':
        """
        Create a :class:`~gmft.CroppedTable` object from an image only.

        :param img: PIL image
        :return: CroppedTable object covering the entire image, with the image pre-cached at dpi 72
        """
        page = ImageOnlyPage(img)
        # bbox is the entire image
        bbox = (0, 0, img.width, img.height)
        table = CroppedTable(page, bbox, confidence_score=1.0, label=0)
        table._img = img
        table._img_dpi = 72
        return table

    @property
    def bbox(self):
        """The table bounding box (``self.rect.bbox``)."""
        return self.rect.bbox
class TableDetectorConfig:
    """
    Configuration for the :class:`~gmft.TableDetector` class.

    Defaults are class attributes. Values passed to the constructor are stored on the
    instance and shadow those defaults.
    """
    image_processor_path: str = "microsoft/table-transformer-detection"
    detector_path: str = "microsoft/table-transformer-detection"
    no_timm: bool = True  # huggingface revision
    warn_uninitialized_weights: bool = False
    torch_device: torch.device = torch.device("cpu")
    detector_base_threshold: float = 0.9
    """Minimum confidence score required for a table"""

    @property
    def confidence_score_threshold(self):
        raise DeprecationWarning("Use detector_base_threshold instead.")

    @confidence_score_threshold.setter
    def confidence_score_threshold(self, value):
        raise DeprecationWarning("Use detector_base_threshold instead.")

    def __init__(self, image_processor_path: str = None, detector_path: str = None,
                 torch_device: torch.device = None, **kwargs):
        """
        :param image_processor_path: huggingface path of the image processor (None keeps the default)
        :param detector_path: huggingface path of the detection model (None keeps the default)
        :param torch_device: device the detection model is moved to, e.g. ``torch.device("cuda")``
            (None keeps the default)
        :param kwargs: any other public config attribute, e.g. ``detector_base_threshold=0.8``.
            This lets ``TableDetector(config={...})`` accept every setting, not only the three above.
        :raises TypeError: for a keyword that is not a public config attribute
        """
        if image_processor_path is not None:
            self.image_processor_path = image_processor_path
        if detector_path is not None:
            self.detector_path = detector_path
        if torch_device is not None:
            self.torch_device = torch_device
        for key, value in kwargs.items():
            if key.startswith('_') or not hasattr(TableDetectorConfig, key):
                raise TypeError(f"TableDetectorConfig got an unexpected keyword argument '{key}'")
            # deprecated properties raise DeprecationWarning here, same as direct assignment
            setattr(self, key, value)
class TableDetector:
    """
    Detects tables in a pdf page. Default implementation uses TableTransformerForObjectDetection.
    """

    def __init__(self, config: TableDetectorConfig = None, default_implementation=True):
        """
        Initialize the TableDetector.

        :param config: TableDetectorConfig, or a dict of its settings, or None for defaults
        :param default_implementation: Should be True, unless you are writing a custom subclass for TableDetector.
        """
        # future-proofing: allow subclasses for TableDetector to have different architectures
        if not default_implementation:
            return
        if config is None:
            config = TableDetectorConfig()
        elif isinstance(config, dict):
            config = TableDetectorConfig(**config)

        silence = not config.warn_uninitialized_weights
        if silence:
            previous_verbosity = transformers.logging.get_verbosity()
            transformers.logging.set_verbosity(transformers.logging.ERROR)
        try:
            self.image_processor = AutoImageProcessor.from_pretrained(config.image_processor_path)
            revision = "no_timm" if config.no_timm else None
            self.detector = TableTransformerForObjectDetection.from_pretrained(
                config.detector_path, revision=revision).to(config.torch_device)
        finally:
            # Restore the caller's verbosity even when model loading fails.
            if silence:
                transformers.logging.set_verbosity(previous_verbosity)
        self.config = config

    def extract(self, page: BasePage, config_overrides: TableDetectorConfig = None) -> list[CroppedTable]:
        """
        Detect tables in a page.

        :param page: BasePage
        :param config_overrides: override the config for this call only. Only the detection
            threshold is read per call; the model stays on the device chosen at construction.
        :return: list of CroppedTable objects. Detections with label 1 are returned as a
            :class:`RotatedCroppedTable` with angle 90.
        """
        if config_overrides is not None:
            config = copy.deepcopy(self.config)
            config.__dict__.update(config_overrides.__dict__)
        else:
            config = self.config

        img = page.get_image(72)  # use standard dpi = 72, which means we don't need any scaling
        # Inputs must be on the device the model was moved to in __init__ (self.config),
        # not on a device named by a per-call override.
        encoding = self.image_processor(img, return_tensors="pt").to(self.config.torch_device)
        with torch.no_grad():
            outputs = self.detector(**encoding)

        # keep only predictions above the threshold (excluding no-object class)
        target_sizes = torch.tensor([img.size[::-1]])  # (height, width)
        results = self.image_processor.post_process_object_detection(
            outputs, threshold=config.detector_base_threshold, target_sizes=target_sizes)[0]

        tables = []
        # .tolist()/.item() copy values back to host memory, so this also works when the model runs on GPU
        for box, score, label_tensor in zip(results["boxes"], results["scores"], results["labels"]):
            bbox = box.tolist()
            confidence_score = score.item()
            label = label_tensor.item()
            if label == 1:
                tables.append(RotatedCroppedTable(page, bbox, confidence_score, 90, label))
            else:
                tables.append(CroppedTable(page, bbox, confidence_score, label))
        return tables
class TATRTableDetector(TableDetector):
    """
    Table detector backed by Table Transformer (TATR).

    Currently identical to :class:`TableDetector`. Its :meth:`extract` returns a list of
    :class:`~gmft.CroppedTable` objects, not formatted tables. To obtain a
    :class:`~gmft.FormattedTable` (exportable to csv, df, etc.), pass each result to a
    table formatter such as :class:`~gmft.TATRTableFormatter`.
    """
    pass
class RotatedCroppedTable(CroppedTable):
    """
    Table that has been rotated.

    Note: ``self.bbox`` and ``self.rect`` are in coordinates of the original pdf.
    But text_positions() can possibly give transformed coordinates.

    Currently, only 0, 90, 180, and 270 degree rotations are supported.
    An angle of 90 would mean that a 90 degree cc rotation has been applied to a level image.
    In practice, the majority of rotated tables are rotated by 90 degrees.
    """

    _SUPPORTED_ANGLES = (0, 90, 180, 270)

    def __init__(self, page: BasePage, bbox: tuple[int, int, int, int], confidence_score: float, angle: float, label=0):
        """
        Currently, only 0, 90, 180, and 270 degree rotations are supported.

        :param page: BasePage
        :param bbox: table bounding box in pdf coordinates
        :param confidence_score: detector confidence
        :param angle: angle in degrees, counterclockwise.
            That is, 90 would mean that a 90 degree cc rotation has been applied to a level image.
            In practice, the majority of rotated tables are rotated by 90 degrees.
        :param label: detector label id
        :raises ValueError: if angle is not one of 0, 90, 180, 270
        """
        # validate before doing any other work
        if angle not in self._SUPPORTED_ANGLES:
            raise ValueError("Only 0, 90, 180, 270 are supported.")
        super().__init__(page, bbox, confidence_score, label)
        self.angle = angle

    def image(self, dpi: int = None, padding: str | tuple[int, int, int, int] = None,
              margin: str | tuple[int, int, int, int] = None, **kwargs) -> PILImage:
        """
        Return the image of the cropped table, rotated back to a level orientation.

        Parameters are the same as :meth:`CroppedTable.image`. Note that the image cached on
        the instance (``_img``) is the un-rotated crop produced by the parent method.
        """
        img = super().image(dpi=dpi, padding=padding, margin=margin, **kwargs)
        if self.angle != 0:
            # PIL rotates counterclockwise for positive angles, so -angle undoes the table's rotation.
            img = img.rotate(-self.angle, expand=True)
        return img

    def text_positions(self, remove_table_offset: bool = False, outside: bool = False) -> Generator[tuple[int, int, int, int, str], None, None]:
        """
        Return the text positions of the cropped table.

        If remove_table_offset is False, positions are relative to the top-left corner of the pdf (no adjustment for rotation).
        If remove_table_offset is True, positions are relative to a hypothetical pdf where the text in the table is perfectly level, and
        pdf's top-left corner is also the table's top-left corner (both at 0, 0).

        :param remove_table_offset: if True, the positions are adjusted to be relative to the top-left corner of the table.
        :param outside: if True, returns the **complement** of the table: all the text positions outside the table.
        :return: generator of text positions, which are tuples of (xmin, ymin, xmax, ymax, "string")
        """
        if self.angle == 0 or not remove_table_offset:
            yield from super().text_positions(remove_table_offset=remove_table_offset, outside=outside)
            return

        width, height = self.rect.width, self.rect.height
        for x0, y0, x1, y1, text in super().text_positions(remove_table_offset=True, outside=outside):
            if self.angle == 90:
                # undo a 90 degree cc rotation: rotate clockwise within a (width x height) box
                yield (height - y1, x0, height - y0, x1, text)
            elif self.angle == 180:
                yield (width - x1, height - y1, width - x0, height - y0, text)
            elif self.angle == 270:
                yield (y0, width - x1, y1, width - x0, text)

    def to_dict(self):
        """Serialize like :meth:`CroppedTable.to_dict`, plus the ``'angle'`` key."""
        d = super().to_dict()
        d['angle'] = self.angle
        return d

    @staticmethod
    def from_dict(d: dict, page: BasePage) -> Union[CroppedTable, 'RotatedCroppedTable']:
        """
        Create a :class:`RotatedCroppedTable` object from dict.

        Dicts without an ``'angle'`` key are delegated to :meth:`CroppedTable.from_dict`.
        A missing ``'label'`` defaults to 0.
        """
        if 'angle' not in d:
            return CroppedTable.from_dict(d, page)
        return RotatedCroppedTable(page, d['bbox'], d['confidence_score'], d['angle'], d.get('label', 0))
"""
Module containing methods of formatting tables: structural analysis, data extraction, and converting them into pandas dataframes.
Whenever possible, classes (like :class:`AutoTableFormatter`) should be imported from the top-level module, not from this module,
as the exact paths may change in future versions.
Example:
>>> from gmft import AutoTableFormatter
"""
from abc import ABC, abstractmethod
import copy
from gmft.pdf_bindings.common import BasePage
import torch
from gmft.table_detection import CroppedTable, RotatedCroppedTable
import transformers
from transformers import AutoImageProcessor, TableTransformerForObjectDetection
import pandas as pd
from gmft.table_function_algorithm import extract_to_df
from gmft.table_visualization import plot_results_unwr
class FormattedTable(RotatedCroppedTable):
    """
    A table that has been "formatted": structural analysis has attached header and
    data information to it, so it can be converted into a df, csv, etc.

    Note: this class does not use ABCMeta, so the ``@abstractmethod`` markers below
    document the expected overrides but are not enforced at instantiation time.
    """

    def __init__(self, cropped_table: CroppedTable, df: pd.DataFrame = None):
        """
        Build a formatted table that shares page/bbox/score/label with ``cropped_table``.

        :param cropped_table: the source table. Its ``angle`` is reused when it has one, else 0.
        :param df: an already-computed dataframe, if any
        """
        self._df = df

        source = cropped_table
        rotation = source.__dict__.get('angle', 0)
        super().__init__(
            page=source.page,
            bbox=source.rect.bbox,
            confidence_score=source.confidence_score,
            angle=rotation,
            label=source.label,
        )

        # Carry over the cached render. The image is copied so the two tables do not share pixels.
        cached_image = source._img
        if cached_image is None:
            self._img = None
        else:
            self._img = cached_image.copy()
        self._img_dpi = source._img_dpi
        self._img_padding = source._img_padding
        self._img_margin = source._img_margin

    def df(self):
        """
        Return the table as a pandas dataframe (the one given at construction, if any).
        """
        return self._df

    @abstractmethod
    def visualize(self):
        """
        Visualize the table. Subclasses provide the implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def to_dict(self):
        """
        Serialize self into a dict. Subclasses provide the implementation.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def from_dict(d: dict, page: BasePage):
        """
        Deserialize from a dict. Subclasses provide the implementation.
        """
        raise NotImplementedError
class TableFormatter(ABC):
    """
    Interface for components that turn a :class:`~gmft.CroppedTable` into a
    :class:`~gmft.FormattedTable`, from which csv, df, etc. can be exported.

    Subclasses must implement :meth:`extract`; this base class cannot be instantiated.
    """

    @abstractmethod
    def extract(self, table: CroppedTable) -> FormattedTable:
        """
        Run structural analysis on ``table`` and return the resulting
        :class:`~gmft.FormattedTable` (exportable to csv, html, etc.).
        """
        raise NotImplementedError
def _always_deprecated(message: str) -> property:
    """Return a property whose getter and setter both raise ``DeprecationWarning(message)``."""
    def _raise(*_ignored):
        raise DeprecationWarning(message)
    return property(_raise, _raise)


class TATRFormatConfig:
    """
    Configuration for :class:`~gmft.TATRTableFormatter`.

    All defaults are class attributes; only ``torch_device`` can be set through the constructor.
    """

    # ---- model settings ----
    warn_uninitialized_weights: bool = False
    image_processor_path: str = "microsoft/table-transformer-detection"
    formatter_path: str = "microsoft/table-transformer-structure-recognition"
    no_timm: bool = True  # use a model which uses AutoBackbone.
    torch_device: torch.device = torch.device("cpu")
    # https://huggingface.co/microsoft/table-transformer-structure-recognition/discussions/5
    # "microsoft/table-transformer-structure-recognition-v1.1-all"

    #: 0: errors only; 1: print warnings; 2: print warnings and info;
    #: 3: print warnings, info, and debug
    verbosity: int = 1

    #: Base threshold for the confidence demanded of a table feature (row/column).
    #: A low threshold is actually better: overzealous rows usually still keep numbers aligned
    #: and merely add empty rows, whereas fewer rows than expected merges cells, which is bad.
    formatter_base_threshold: float = 0.3

    #: Confidences required (>=) for a row/column feature to be considered good, keyed by
    #: TATRFormattedTable.id2label. Low confidences may be better than too-high ones
    #: (see formatter_base_threshold).
    cell_required_confidence = {
        0: 0.3,  # table
        1: 0.3,  # column
        2: 0.3,  # row
        3: 0.3,  # column header
        4: 0.5,  # projected row header
        5: 0.5,  # spanning cell
        6: 99,   # no object
    }

    # ---- df() settings ----

    # ---- options ----
    #: remove rows with no text
    remove_null_rows = True

    #: Enable multi-indices in the dataframe. If False, multiple headers are merged column-wise.
    enable_multi_header = False

    #: [Experimental] Enable semantic spanning cells, which often encode hierarchical multi-level indices.
    semantic_spanning_cells = False

    #: [Experimental] When semantic spanning cells are enabled and a left header is detected that
    #: might represent a group of rows, that value is duplicated for each row.
    #: Possible values: 'algorithm', 'deep', None
    semantic_hierarchical_left_fill = 'algorithm'

    # ---- large table ----
    #: If >= n rows are removed by non-maxima suppression (NMS), the table is classified as large.
    large_table_if_n_rows_removed = 8

    #: With large tables, table transformer struggles by placing too many overlapping rows.
    #: More rows give more information on the usual text size, which is used to guess a row
    #: height such that no rows are merged or overlapping. The large table assumption is only
    #: applied when (# of rows > large_table_threshold) AND
    #: (total overlap > large_table_row_overlap_threshold).
    #: Set 9999 to disable, or 0 to force the large table assumption every time.
    large_table_threshold = 10
    large_table_row_overlap_threshold = 0.2

    #: Refuse to proceed if the table predicts more rows than this (prevents memory issues for tiny text).
    large_table_maximum_rows = 1000

    #: True: always apply the large table assumption; False: never apply it;
    #: None: apply it heuristically according to the threshold and overlap.
    force_large_table_assumption = None

    # ---- rejection and warnings ----
    #: reject if total overlap is > 20% of table area
    total_overlap_reject_threshold = 0.2
    #: warn if total overlap is > 5% of table area
    total_overlap_warn_threshold = 0.05
    #: warn if non maxima suppression removes > 5 rows
    nms_warn_threshold = 5
    #: reject if iob between textbox and cell is < 5%
    iob_reject_threshold = 0.05
    #: warn if iob between textbox and cell is < 50%
    iob_warn_threshold = 0.5

    # ---- technical ----
    #: Non-maxima suppression: if two rows overlap by > threshold (default 10%), the one with lower
    #: confidence is removed. A subsequent technique fills in gaps created by NMS.
    _nms_overlap_threshold = 0.1

    #: In the large_table method, means within (20% * text_height) of each other are merged.
    #: Worth adjusting if text is split by subscripts/superscripts.
    _large_table_merge_distance = 0.2

    #: Text smaller than this height is ignored; prevents very small text from creating huge
    #: arrays under the large table assumption.
    _smallest_supported_text_height = 0.1

    # ---- deprecated: reading or assigning any of these raises DeprecationWarning ----
    aggregate_spanning_cells = _always_deprecated(
        "aggregate_spanning_cells has been removed.")
    corner_clip_outlier_threshold = _always_deprecated(
        "corner_clip_outlier_threshold has been removed.")
    spanning_cell_minimum_width = _always_deprecated(
        "spanning_cell_minimum_width has been removed.")
    deduplication_iob_threshold = _always_deprecated(
        "deduplication_iob_threshold is deprecated. See nms_overlap_threshold instead.")

    def __init__(self, torch_device: torch.device = None):
        """
        :param torch_device: device for the structure-recognition model; None keeps the class default.
        """
        if torch_device is None:
            return
        self.torch_device = torch_device
class TATRFormattedTable(FormattedTable):
    """
    FormattedTable, as seen by a Table Transformer (TATR).
    See :class:`~gmft.TATRTableFormatter`.
    """

    _POSSIBLE_ROWS = ['table row', 'table spanning cell', 'table projected row header']  # , 'table column header']
    _POSSIBLE_PROJECTING_ROWS = ['table projected row header']  # , 'table spanning cell']
    _POSSIBLE_COLUMN_HEADERS = ['table column header']
    _POSSIBLE_COLUMNS = ['table column']
    id2label = {
        0: 'table',
        1: 'table column',
        2: 'table row',
        3: 'table column header',
        4: 'table projected row header',
        5: 'table spanning cell',
        6: 'no object',
    }
    label2id = {v: k for k, v in id2label.items()}

    config: TATRFormatConfig
    outliers: dict[str, bool]

    effective_rows: list[tuple]
    "Rows as seen by the image --> df algorithm, which may differ from what the table transformer sees."

    effective_columns: list[tuple]
    "Columns as seen by the image --> df algorithm, which may differ from what the table transformer sees."

    effective_headers: list[tuple]
    "Headers as seen by the image --> df algorithm."

    effective_projecting: list[tuple]
    "Projected rows as seen by the image --> df algorithm."

    effective_spanning: list[tuple]
    "Spanning cells as seen by the image --> df algorithm."

    def __init__(self, cropped_table: CroppedTable, fctn_results: dict,
                 config: TATRFormatConfig = None):
        """
        :param cropped_table: the table the structure was recognized on
        :param fctn_results: dict with "scores", "labels" and "boxes" lists.
            TATRTableFormatter.extract and from_dict pass boxes already normalized with
            ``_normalize_bbox`` (dpi 72, padding and margin removed).
        :param config: TATRFormatConfig; None uses the defaults
        """
        super().__init__(cropped_table)
        self.fctn_results = fctn_results
        self.config = TATRFormatConfig() if config is None else config
        self.outliers = None

    def df(self, config_overrides: TATRFormatConfig = None):
        """
        Return the table as a pandas dataframe, computed by ``extract_to_df``.

        The result is cached. Without overrides, a cached dataframe is returned as-is.
        With ``config_overrides``, the dataframe is always recomputed using the merged config,
        and the cache is replaced. (Previously, overrides were silently ignored once a result
        was cached.)

        :param config_overrides: override the config settings for this call only
        """
        if config_overrides is None:
            if self._df is not None:  # cache
                return self._df
            config = self.config
        else:
            config = copy.deepcopy(self.config)
            config.__dict__.update(config_overrides.__dict__)
        self._df = extract_to_df(self, config=config)
        return self._df

    def visualize(self, filter=None, dpi=None, padding=None, margin=(10, 10, 10, 10), effective=False, **kwargs):
        """
        Visualize the table.

        :param filter: filter the labels to visualize. See TATRFormattedTable.id2label
        :param dpi: Sets the dpi. If none, then the dpi of the cached image is used (or 72).
        :param padding: padding passed to :meth:`image` and to the plot
        :param margin: margin (pdf units) passed to :meth:`image`; None renders without margin
        :param effective: if True, visualize the effective rows and columns, which may differ from the table transformer's output.
        """
        if dpi is None:
            dpi = self._img_dpi if self._img_dpi is not None else 72
        # The effective_* attributes are read below; make sure the df computation has run
        # (presumably extract_to_df populates them).
        if self._df is None:
            self.df()

        scale_by = dpi / 72
        if effective:
            vis = (self.effective_rows + self.effective_columns + self.effective_headers
                   + self.effective_projecting + self.effective_spanning)
            to_visualize = {
                "scores": [x['confidence'] for x in vis],
                "labels": [self.label2id[x['label']] for x in vis],
                # real lists (not one-shot generators) so the plot can index/re-iterate them
                "boxes": [[c * scale_by for c in x['bbox']] for x in vis],
            }
        else:
            # stored boxes are normalized to dpi 72; scale them to the rendered dpi
            to_visualize = {
                "scores": self.fctn_results["scores"],
                "labels": self.fctn_results["labels"],
                "boxes": [[c * scale_by for c in bbox] for bbox in self.fctn_results["boxes"]],
            }

        img = self.image(dpi=dpi, padding=padding, margin=margin)
        # image() records the margin it actually used; None means no margin was added.
        used_margin = self._img_margin if self._img_margin is not None else (0, 0, 0, 0)
        true_margin = [x * scale_by for x in used_margin]
        return plot_results_unwr(img, to_visualize['scores'], to_visualize['labels'], to_visualize['boxes'],
                                 TATRFormattedTable.id2label,
                                 filter=filter, padding=padding, margin=true_margin, **kwargs)

    def to_dict(self):
        """
        Serialize self into a dict.

        The config is copied, so mutating the result does not change this table's config.
        A ``torch.device`` value is stored as its string form (e.g. ``"cuda"``), so that the
        dict stays JSON-serializable when a GPU device was configured.
        """
        if self.angle != 0:
            parent = RotatedCroppedTable.to_dict(self)
        else:
            parent = CroppedTable.to_dict(self)
        config = dict(self.config.__dict__)
        if isinstance(config.get('torch_device'), torch.device):
            config['torch_device'] = str(config['torch_device'])
        return {
            **parent,
            'config': config,
            'outliers': self.outliers,
            'fctn_results': self.fctn_results,
        }

    @staticmethod
    def from_dict(d: dict, page: BasePage):
        """
        Deserialize from dict.
        A page is required partly because of memory management, since having this open a page may cause memory issues.

        :raises ValueError: if the dict has no ``'fctn_results'``
        """
        if 'fctn_results' not in d:
            raise ValueError("fctn_results not found in dict -- dict may be a CroppedTable but not a TATRFormattedTable.")
        cropped_table = CroppedTable.from_dict(d, page)

        config = TATRFormatConfig()
        for k, v in d.get('config', {}).items():
            if k == 'torch_device' and isinstance(v, str):
                # to_dict stores devices as strings
                v = torch.device(v)
            if v is not None and config.__dict__.get(k) != v:
                setattr(config, k, v)

        results = d['fctn_results']
        if 'fctn_scale_factor' in d or 'scale_factor' in d or 'fctn_padding' in d or 'padding' in d:
            # deprecated: this is for backwards compatibility with un-normalized boxes
            scale_factor = d.get('fctn_scale_factor', d.get('scale_factor', 1))
            padding = tuple(d.get('fctn_padding', d.get('padding', (0, 0))))
            # build a new dict instead of normalizing the caller's boxes in place
            results = {**results, "boxes": [
                _normalize_bbox(bbox, used_scale_factor=scale_factor, used_padding=padding)
                for bbox in results["boxes"]
            ]}

        table = TATRFormattedTable(cropped_table, results, config=config)
        table.outliers = d.get('outliers', None)
        return table
class TATRTableFormatter(TableFormatter):
    """
    Uses a TableTransformerForObjectDetection for small/medium tables, and a custom algorithm for large tables.
    Using :meth:`extract`, a :class:`~gmft.FormattedTable` is produced, which can be exported to csv, df, etc.
    """

    def __init__(self, config: TATRFormatConfig = None):
        """
        Load the image processor and structure-recognition model.

        :param config: TATRFormatConfig; None uses the defaults. The model is moved to
            ``config.torch_device``.
        """
        if config is None:
            config = TATRFormatConfig()

        # Save the caller's verbosity exactly once. Saving it twice (as before) overwrote it
        # with ERROR, so the restore below never brought back the original level.
        silence = not config.warn_uninitialized_weights
        if silence:
            previous_verbosity = transformers.logging.get_verbosity()
            transformers.logging.set_verbosity(transformers.logging.ERROR)
        try:
            self.image_processor = AutoImageProcessor.from_pretrained(config.image_processor_path)
            revision = "no_timm" if config.no_timm else None
            self.structor = TableTransformerForObjectDetection.from_pretrained(
                config.formatter_path, revision=revision).to(config.torch_device)
        finally:
            # restore even if model loading fails
            if silence:
                transformers.logging.set_verbosity(previous_verbosity)
        self.config = config

    def extract(self, table: CroppedTable, dpi=144, padding='auto', margin=None, config_overrides=None) -> FormattedTable:
        """
        Extract the data from the table.

        :param table: the cropped table to analyze
        :param dpi: render dpi for the model input (None means 72)
        :param padding: padding passed to ``table.image`` ('auto' by default)
        :param margin: margin passed to ``table.image``
        :param config_overrides: override the config for this call only. The model stays on
            the device chosen at construction.
        :return: a TATRFormattedTable whose boxes are normalized to dpi 72, relative to the table bbox
        """
        if config_overrides is not None:
            config = copy.deepcopy(self.config)
            config.__dict__.update(config_overrides.__dict__)
        else:
            config = self.config

        dpi = 72 if dpi is None else dpi
        image = table.image(dpi=dpi, padding=padding, margin=margin)
        # image() resolves 'auto'/None; read back the concrete values it used
        padding = table._img_padding
        margin = table._img_margin
        scale_factor = dpi / 72

        # inputs must be on the device the model was moved to in __init__
        encoding = self.image_processor(image, return_tensors="pt").to(self.config.torch_device)
        with torch.no_grad():
            outputs = self.structor(**encoding)

        target_sizes = [image.size[::-1]]  # (height, width)
        # A LOW threshold is preferable: the model is overzealous, but since each word is assigned
        # to its highest-intersecting row, same-row elements still stay together. Too few rows would
        # merge cells and lose information.
        results = self.image_processor.post_process_object_detection(
            outputs, threshold=config.formatter_base_threshold, target_sizes=target_sizes)[0]

        # .tolist() copies tensors to host memory as plain Python values (works for GPU tensors too)
        results = {k: v.tolist() for k, v in results.items()}
        # normalize results w.r.t. padding, scale factor and margin
        results["boxes"] = [
            _normalize_bbox(bbox, used_scale_factor=scale_factor, used_padding=padding, used_margin=margin)
            for bbox in results["boxes"]
        ]
        return TATRFormattedTable(table, results, config=config)
def _normalize_bbox(bbox: tuple[float, float, float, float], used_scale_factor: float,
used_padding: tuple[float, float], used_margin: tuple[float, float] =None):
"""
Normalize bbox such that:
1. padding is removed (so (0, 0) is the top-left of the cropped table)
2. scale factor is normalized (dpi=72)
3. margin is removed (so (0, 0) is the start of the original detected bbox)
"""
# print("Margin: ", used_margin)
if used_margin is None:
used_margin = (0, 0)
bbox = [bbox[0] - used_padding[0], bbox[1] - used_padding[1], bbox[2] - used_padding[0], bbox[3] - used_padding[1]]
bbox = [bbox[0] / used_scale_factor, bbox[1] / used_scale_factor, bbox[2] / used_scale_factor, bbox[3] / used_scale_factor]
bbox = [bbox[0] - used_margin[0], bbox[1] - used_margin[1], bbox[2] - used_margin[0], bbox[3] - used_margin[1]]
return bbox
import pytest
import torch
from gmft.pdf_bindings import PyPDFium2Document
from gmft.table_detection import TableDetector, TableDetectorConfig
from gmft.table_function import TATRFormatConfig, TATRTableFormatter
@pytest.fixture
def doc_tiny():
    """Provide the tiny sample PDF to a test, closing the document once the test is done."""
    sample = PyPDFium2Document("test/samples/tiny.pdf")
    yield sample
    sample.close()
def test_cuda(doc_tiny):
    """
    Run detection and formatting end-to-end with both models configured for CUDA.

    Skipped (not failed) on machines without a CUDA device.
    """
    if not torch.cuda.is_available():
        pytest.skip("cannot test device settings without cuda")
    cuda = torch.device("cuda")
    page = doc_tiny[0]
    detector = TableDetector(TableDetectorConfig(torch_device=cuda))
    formatter = TATRTableFormatter(TATRFormatConfig(torch_device=cuda))

    # The models must actually live on the GPU; otherwise this test proves nothing.
    assert next(detector.detector.parameters()).device.type == "cuda"
    assert next(formatter.structor.parameters()).device.type == "cuda"

    tables = detector.extract(page)
    assert tables, "expected at least one table in test/samples/tiny.pdf"
    ft = formatter.extract(tables[0])
    # results must have been copied back to host memory as plain Python lists
    assert isinstance(ft.fctn_results["boxes"], list)
    assert isinstance(ft.fctn_results["scores"], list)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment