Created
October 21, 2015 14:51
-
-
Save cinjon/83dd10545b444f516972 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Tables makes the tables for a given png. It does this with TableMaker and can be | |
run through 'png_to_tables(png_loc)'. | |
""" | |
import os | |
import subprocess | |
from collections import defaultdict | |
def png_to_tables(png):
    """Detect the tables in the image at path *png*.

    Convenience entry point: builds a ``TableMaker`` for the file and returns
    the list of ``Table`` objects produced by its ``make`` method.
    """
    maker = TableMaker(png)
    return maker.make()
class CV(object):
    """Thin wrapper around the OpenCV (``cv2``) calls this module needs.

    ``cv2`` is imported inside each method that uses it. An import done only
    in ``__init__`` would bind ``cv2`` as a local of ``__init__`` and leave the
    static methods unable to see it.
    """

    def __init__(self):
        # Fail fast at construction time if OpenCV is not installed.
        import cv2  # noqa: F401

    @staticmethod
    def img(loc):
        """Read and return the image stored at path *loc* via ``cv2.imread``."""
        import cv2
        return cv2.imread(loc)

    @staticmethod
    def resize(img, ny, nx):
        """Return *img* resized to *ny* rows by *nx* columns.

        ``cv2.resize`` takes its target size as (width, height), hence the
        (nx, ny) order below.
        """
        import cv2
        return cv2.resize(img, (nx, ny))

    @staticmethod
    def write(img, out):
        """Write *img* to the file path *out* via ``cv2.imwrite``."""
        import cv2
        cv2.imwrite(out, img)

    @staticmethod
    def nonwhite_pixels(img):
        """Return the (row, col) of every pixel that is not pure white (255).

        A multi-channel pixel counts as non-white when any channel differs from
        255; a single-channel (2-D) image is accepted as well. Positions are
        returned as tuples of Python ints in row-major order (row by row, left
        to right), which callers rely on to see each column's rows ascending.
        """
        import numpy as np
        mask = np.asarray(img) != 255
        if mask.ndim == 3:
            # Collapse the channel axis: non-white if any channel is not 255.
            mask = mask.any(axis=2)
        # np.nonzero yields indices in row-major (C-style) order.
        rows, cols = np.nonzero(mask)
        return list(zip(rows.tolist(), cols.tolist()))
class TableMaker(object):
    """Find ruled tables in a page image and build ``Table`` objects from them.

    The page is resized to 3300 rows x 2550 columns and reduced to the list of
    its non-white pixel positions (``self.nwp``); every later step works on
    that list. The default pixel thresholds were tuned for that page size.
    """

    def __init__(self, png_loc):
        # (row, col) positions of every non-white pixel of the resized page.
        self.nwp = self._get_non_white_positions(png_loc)

    @staticmethod
    def _get_non_white_positions(png):
        """Load *png*, resize to 3300 x 2550 (rows x cols), return non-white (row, col)s."""
        cv = CV()
        img = cv.img(png)
        resized = cv.resize(img, 3300, 2550)
        return cv.nonwhite_pixels(resized)

    def make(self):
        """Run the whole pipeline and return the detected tables (list of ``Table``)."""
        rlines = self._get_table_row_lines()
        clines = self._get_table_column_lines()
        rpoints, cpoints = self._get_reduced_row_and_col_points(rlines, clines)
        intersections = self._get_intersections(rpoints, cpoints)
        return self._get_tables(intersections)

    def _get_table_row_lines(self, threshold=None):
        """Return candidate horizontal rule lines as (row, [cols]) pairs.

        A pixel row qualifies when it holds more than *threshold* non-white
        pixels. Results are ordered by pixel count, largest first.

        :param threshold: minimum count (exclusive). ``None`` means 850, tuned
            for a 2550-column page. An explicit 0 is honoured.
        """
        if threshold is None:
            threshold = 850
        cols_by_row = defaultdict(list)
        for row, col in self.nwp:
            cols_by_row[row].append(col)
        rows_with_col_lists = sorted(cols_by_row.items(),
                                     key=lambda e: len(e[1]), reverse=True)
        return [e for e in rows_with_col_lists if len(e[1]) > threshold]

    def _get_table_column_lines(self, threshold=None):
        """Return candidate vertical rule segments as (col, [rows]) pairs.

        Each pixel column's non-white rows are split into runs of consecutive
        rows. Every run at least *threshold* rows long becomes one segment.
        Assumes ``self.nwp`` is in row-major order, so each column's rows
        arrive ascending.

        :param threshold: minimum run length (inclusive). ``None`` means 60,
            tuned for a 3300-row page. An explicit 0 is honoured.
        """
        if threshold is None:
            threshold = 60
        rows_by_col = defaultdict(list)
        for row, col in self.nwp:
            rows_by_col[col].append(row)
        cols_with_row_lists = sorted(rows_by_col.items(),
                                     key=lambda e: len(e[1]), reverse=True)
        col_line_segments = []
        for col, row_list in cols_with_row_lists:
            segments = []
            segment = []
            for row in row_list:
                if not segment or segment[-1] == row - 1:
                    segment.append(row)
                    continue
                # The run just broke: keep it if long enough, and start the
                # next run WITH this row (it must not be dropped).
                if len(segment) >= threshold:
                    segments.append(segment)
                segment = [row]
            if len(segment) >= threshold:
                segments.append(segment)
            col_line_segments.extend((col, s) for s in segments)
        return col_line_segments

    @staticmethod
    def _get_reduced_row_and_col_points(row_lines, col_lines, merge_distance=30):
        """Reduce row/column lines to their end points and merge near-duplicates.

        Each line becomes a dict ``{'s': (row, col), 'e': (row, col)}`` holding
        its Start and End point.

        Row lines are dropped when they start in the same column as the
        preceding entry and less than *merge_distance* rows below it. Column
        lines are dropped when they start on the same row as the preceding
        entry and less than *merge_distance* columns to its right. This
        collapses the many pixel lines of one thick rule into one.

        :param merge_distance: pixel gap below which adjacent lines merge.
            30 was tuned for a 3300x2550 page and may need adjusting for other
            documents.
        :returns: (row_points, col_points), each sorted by start row.
        """
        rstarts = [{'s': (row, cols[0]), 'e': (row, cols[-1])}
                   for row, cols in sorted(row_lines)]
        cstarts = sorted(
            [{'s': (rows[0], col), 'e': (rows[-1], col)} for col, rows in col_lines],
            key=lambda l: (l['s'][0], l['s'][1]))

        def get_reduced(starts, is_row=True):
            tmp = []
            for num, line in enumerate(starts):
                if not tmp:
                    tmp.append(line)
                    continue
                row, col = line['s']
                # Compare with the immediately preceding entry (not the last
                # kept one), so a chain of adjacent lines collapses into its
                # first member.
                prevrow, prevcol = starts[num - 1]['s']
                is_next_row = row - merge_distance < prevrow and prevcol == col
                is_next_col = row == prevrow and prevcol > col - merge_distance
                if (is_row and is_next_row) or (is_next_col and not is_row):
                    continue
                tmp.append(line)
            return tmp

        return (sorted(get_reduced(rstarts, True), key=lambda l: l['s'][0]),
                sorted(get_reduced(cstarts, False), key=lambda l: l['s'][0]))

    @staticmethod
    def _get_intersections(rs, cs):
        """Intersect every reduced row line in *rs* with every column line in *cs*.

        Row lines are lengthened by 5 columns and column lines by 5 rows at
        each end, so rules that stop just short of each other still meet.

        :returns: list of dicts ``{'row': (start, end), 'col': (start, end),
            'intersect': (row, col)}``, using the extended line end points.
        """
        def intersection(line1, line2):
            """Return the integer intersection point of two segments, or False.

            Each line is a pair of (x, y) points. False is returned for
            (near-)parallel lines and when the crossing lies outside either
            segment.
            """
            def near(a, b, rtol=1e-5, atol=1e-8):
                return abs(a - b) < (atol + rtol * abs(b))

            (x1, y1), (x2, y2) = line1
            (u1, v1), (u2, v2) = line2
            (a, b), (c, d) = (x2 - x1, u1 - u2), (y2 - y1, v1 - v2)
            e, f = u1 - x1, v1 - y1
            denom = float(a * d - b * c)
            if near(denom, 0):
                return False
            t = (e * d - b * f) / denom
            s = (a * f - e * c) / denom
            if 0 <= t <= 1 and 0 <= s <= 1:
                px = x1 + t * (x2 - x1)
                py = y1 + t * (y2 - y1)
                return (int(px), int(py))
            return False

        def extend_points(pts, is_columns):
            # Columns grow 5 rows at each end; rows grow 5 columns at each end.
            ext = 5 * int(is_columns)
            return [{'s': (p['s'][0] - ext, p['s'][1] + ext - 5),
                     'e': (p['e'][0] + ext, p['e'][1] + 5 - ext)}
                    for p in pts]

        extended_cs = extend_points(cs, True)
        intersections = []
        for row in extend_points(rs, False):
            row_line = (row['s'], row['e'])
            for col in extended_cs:
                col_line = (col['s'], col['e'])
                point = intersection(row_line, col_line)
                if not point:
                    continue
                intersections.append({'row': row_line,
                                      'col': col_line,
                                      'intersect': point})
        return intersections

    @staticmethod
    def _get_tables(intersections):
        """
        Group intersection dicts into ``Table`` objects:

        1. Get the top left intersection.
        2. Skate to the end of its row for the top right.
        3. Skate to the bottom of its column for the bottom left.
        4. Skate to the end of the bottom left's row for the bottom right.
        5. Include all the intersections in between those 4 points in the table.
        6. Repeat on the intersections left over to get the next table.

        Progress counts are printed to stdout.
        """
        def get_corners(subset):
            # Topmost row, then leftmost point on it.
            top_row = min(e['intersect'][0] for e in subset)
            top_left = sorted(
                [k for k in subset if k['intersect'][0] == top_row],
                key=lambda e: e['intersect'][1])[0]
            # Lowest point on the same column line as top_left.
            bottom_left = sorted(
                [k for k in subset if k['col'][1] == top_left['col'][1]],
                key=lambda e: e['intersect'][0])[-1]
            # Rightmost point on top_left's row line.
            top_right = sorted(
                [k for k in subset if k['row'][1] == top_left['row'][1]],
                key=lambda e: e['intersect'][1])[-1]
            # Rightmost point on bottom_left's row line.
            bottom_right = sorted(
                [k for k in subset if k['row'][1] == bottom_left['row'][1]],
                key=lambda e: e['intersect'][1])[-1]
            return tuple(p['intersect'] for p in
                         (top_left, top_right, bottom_right, bottom_left))

        def in_corners(row, col, tl, tr, br, bl):
            if row < tl[0] or col < tl[1]:
                return False
            if row > bl[0] or col > br[1]:
                return False
            return True

        def get_table_points(subset):
            # Split subset into (points outside this table, points inside it).
            tl, tr, br, bl = get_corners(subset)
            remaining = []
            ret = []
            for i in subset:
                row, col = i['intersect']
                if in_corners(row, col, tl, tr, br, bl):
                    ret.append(i)
                else:
                    remaining.append(i)
            return remaining, ret

        tables = []
        remaining = list(intersections)
        print('Total Points: %d' % len(intersections))
        # Terminates: the top-left corner is always inside its own table, so
        # every pass removes at least one point from `remaining`.
        while remaining:
            remaining, table_points = get_table_points(remaining)
            tables.append(Table([k['intersect'] for k in table_points]))
            print('Remaining Point Total: %d, Table length: %d.' % (
                len(remaining), len(table_points)))
        return tables
class Table(object):
    """
    Table class.

    To reconstruct the table, we need to relate the MTed text position to the
    CVed position of the cell that it's in. With that relation, in order to
    reconstruct the table, we need to know the corner points of the cell,
    what row it starts on, and what percentage of the row it takes up. We can
    then reconstruct the table row by row and rely on the text size to balloon
    as necessary within those confines.

    Built from a list of (row, col) intersection points. Attributes:
      width -- column distance between the leftmost and rightmost points
               (0 when there are no points).
      rows  -- list of rows, each a list of ``Cell``.
    """

    def __init__(self, points):
        row_indices = self._get_row_indices(points)
        col_indices = self._get_col_indices(points)
        self.width = col_indices[-1] - col_indices[0] if col_indices else 0
        self.rows = self.make_rows(points, row_indices)

    @staticmethod
    def _get_row_indices(points):
        """Return the distinct row coordinates of *points*, ascending."""
        return sorted(set(k[0] for k in points))

    @staticmethod
    def _get_col_indices(points):
        """Return the distinct column coordinates of *points*, ascending."""
        return sorted(set(k[1] for k in points))

    def make_rows(self, points, row_indices):
        """
        Build rows of cells between consecutive horizontal boundaries.

        Each entry of *row_indices* is a boundary. The points on it close the
        row that started at the previous boundary. Incorrect in some cases
        where the rows misalign; cell pairing may then be wrong.
        """
        ret = []
        top_points = None
        others = []
        for row_num, indice in enumerate(row_indices):
            row_points = sorted([p for p in points if p[0] == indice],
                                key=lambda k: k[1])
            if top_points is None:
                pass
            elif len(row_points) > len(top_points):
                # Only use row_points that are aligned with top_points;
                # this is for when the points aren't aligned between two rows.
                top_cols = set(e[1] for e in top_points)
                matches = [k for k in row_points if k[1] in top_cols]
                ret.append(self._make_row(top_points, matches, row_num))
            elif len(row_points) < len(top_points):
                # Only use top_points that are aligned with row_points; the
                # unaligned ones are carried down to the next boundary.
                row_cols = set(e[1] for e in row_points)
                matches = [k for k in top_points if k[1] in row_cols]
                others = [k for k in top_points if k not in matches]
                ret.append(self._make_row(matches, row_points, row_num))
            else:
                ret.append(self._make_row(top_points, row_points, row_num))
            top_points = sorted(row_points + others, key=lambda k: k[1])
            others = []
        return ret

    def _make_row(self, top_points, bottom_points, row_num):
        """Pair consecutive top/bottom points into cells for row *row_num*.

        *row_num* is the enumerate index of the closing boundary in
        ``make_rows``. Pairing stops at the shorter side, so misaligned
        boundaries cannot raise IndexError.
        """
        row = []
        cell_count = min(len(top_points), len(bottom_points)) - 1
        for indice in range(cell_count):
            tl, tr = top_points[indice], top_points[indice + 1]
            bl, br = bottom_points[indice], bottom_points[indice + 1]
            # Cell width as a fraction of the whole table's width.
            width = 1.0 * (tr[1] - tl[1]) / self.width if self.width else 0.0
            row.append(Cell(tl, tr, bl, br, row_num, width))
        return row
class Cell(object):
    """
    One cell of a table.

    Stores its four corner points (top-left, top-right, bottom-left,
    bottom-right), the row number it was given by its parent table, and its
    width relative to that table.
    """

    def __init__(self, tl, tr, bl, br, row_num, width):
        # Corner points, stored exactly as passed in.
        self.tl, self.tr = tl, tr
        self.bl, self.br = bl, br
        self.row_num = row_num
        # Relative width; Table passes a fraction of the table's width.
        self.width = width
def format_to_points(lines, is_row_first):
    """Flatten line groups into a flat list of (row, col) points.

    *lines* is a sequence of ``(key, [values])`` pairs. When *is_row_first*
    is true, each key is a row and its values are columns; otherwise each key
    is a column and its values are rows. Order follows the input.
    """
    return [(key, value) if is_row_first else (value, key)
            for key, values in lines
            for value in values]
def paint(img, points, color=None):
    """Set every (row, col) in *points* of *img* to *color*, in place.

    This function is destructive to *img*.

    :param color: value to write at each pixel. ``None`` selects
        ``[255, 0, 0]`` (blue in OpenCV's BGR channel order). Compared with
        ``is None`` so that array-valued colors do not trigger numpy's
        ambiguous-truth-value error.
    """
    if color is None:
        color = [255, 0, 0]  # blue
    for row, col in points:
        img[row][col] = color
def paint_file_tables(png_loc, row_threshold=None, col_threshold=None):
    """Debug helper: paint the detected rule-line pixels of one image.

    The image is resized to 3300 rows x 2550 columns. Pixels of detected row
    lines and column-line segments are painted with ``paint``'s default
    color, and the result is written next to the input as
    ``<name>.painted.png``. The output path is printed. If the image cannot
    be read or resized, a message is printed and None is returned.

    :param row_threshold: forwarded to ``TableMaker._get_table_row_lines``.
    :param col_threshold: forwarded to ``TableMaker._get_table_column_lines``.
    """
    cv = CV()
    img = cv.img(png_loc)
    if img is None:
        # cv2.imread returns None for unreadable files. The resize error path
        # below would otherwise crash on len(img[0]).
        print('Failed to read %s.' % png_loc)
        return
    try:
        reimg = cv.resize(img, 3300, 2550)
    except Exception as e:
        print('Failed to resize %s (%s). It has current size %d / %d.' % (
            png_loc, e, len(img[0]), len(img)))
        return
    # Reuse TableMaker's line detection on the pixels already computed here,
    # bypassing its __init__ (which would re-read and re-resize the file).
    maker = TableMaker.__new__(TableMaker)
    maker.nwp = cv.nonwhite_pixels(reimg)
    row_lines = maker._get_table_row_lines(threshold=row_threshold)
    col_lines = maker._get_table_column_lines(threshold=col_threshold)
    points = format_to_points(row_lines, True) + format_to_points(col_lines, False)
    paint(reimg, points)
    out_loc = os.path.splitext(png_loc)[0] + '.painted.png'
    print(out_loc)
    cv.write(reimg, out_loc)
def paint_dir_tables(png_dir, row_threshold=None, col_threshold=None):
    """Run ``paint_file_tables`` on each regular file in *png_dir*.

    Entries are processed in sorted order. The following are skipped:
      - hidden entries such as ``.DS_Store``;
      - sub-directories;
      - ``*.painted.png`` files, which this module writes into the same
        directory, so re-running does not paint its own output again.
    """
    for name in sorted(os.listdir(png_dir)):
        if name.startswith('.') or name.endswith('.painted.png'):
            continue
        full_path = os.path.join(png_dir, name)
        if not os.path.isfile(full_path):
            continue
        paint_file_tables(full_path, row_threshold, col_threshold)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment