Skip to content

Instantly share code, notes, and snippets.

@cinjon
Created October 21, 2015 14:51
Show Gist options
  • Save cinjon/83dd10545b444f516972 to your computer and use it in GitHub Desktop.
Save cinjon/83dd10545b444f516972 to your computer and use it in GitHub Desktop.
"""
Detect ruled tables in a PNG page image and model them as Table objects.
Detection is done by TableMaker; the entry point is png_to_tables(png_loc),
which returns a list of Table objects.
"""
import os
import subprocess
from collections import defaultdict
def png_to_tables(png):
    """Detect the tables in the image file *png*; return a list of Table objects."""
    maker = TableMaker(png)
    tables = maker.make()
    return tables
class CV(object):
    """Thin wrapper around the OpenCV (cv2) calls used by this module.

    cv2 is imported lazily through _cv2() so every method, static or not,
    can reach it. A name imported inside __init__ would be local to
    __init__ and invisible to the static methods.
    """

    def __init__(self):
        # Fail fast with ImportError when OpenCV is not installed.
        self._cv2()

    @staticmethod
    def _cv2():
        """Return the cv2 module (Python caches it after the first import)."""
        import cv2
        return cv2

    @staticmethod
    def img(loc):
        """Read the image file at *loc* with cv2.imread.

        Per the OpenCV docs, imread returns None (it does not raise) when the
        file cannot be read.
        """
        return CV._cv2().imread(loc)

    @staticmethod
    def resize(img, ny, nx):
        """Resize *img* to *ny* rows by *nx* columns.

        cv2.resize takes its target size as (width, height), hence (nx, ny).
        """
        return CV._cv2().resize(img, (nx, ny))

    @staticmethod
    def write(img, out):
        """Write *img* to the path *out* with cv2.imwrite."""
        CV._cv2().imwrite(out, img)

    @staticmethod
    def nonwhite_pixels(img):
        """Return the (row, col) positions of every pixel that is not pure white.

        A multi-channel pixel is white only when all its channels equal 255.
        A 2-D single-channel image is also accepted. Positions are returned in
        row-major order as a list of (int, int) tuples. Empty input gives [].

        Raises:
            ValueError: if *img* is neither 2-D nor 3-D.
        """
        # Vectorised with numpy: a per-pixel Python loop over a
        # 3300x2550 page is very slow.
        import numpy as np
        arr = np.asarray(img)
        if arr.size == 0:
            return []
        if arr.ndim == 2:
            mask = arr != 255
        elif arr.ndim == 3:
            mask = (arr != 255).any(axis=-1)
        else:
            raise ValueError('expected a 2-D or 3-D image, got %d-D' % arr.ndim)
        rows, cols = np.nonzero(mask)  # row-major (C) order
        return list(zip(rows.tolist(), cols.tolist()))
class TableMaker(object):
    """Find ruled tables in a PNG page image.

    On construction the image is resized to 3300 rows x 2550 columns and
    reduced to self.nwp, the (row, col) positions of its non-white pixels.
    make() turns those positions into a list of Table objects. It finds long
    horizontal runs (row lines) and vertical runs (column segments),
    intersects them, and groups the intersections into tables. The pixel
    thresholds below were chosen for that 3300x2550 size.
    """

    def __init__(self, png_loc):
        self.nwp = self._get_non_white_positions(png_loc)

    @staticmethod
    def _get_non_white_positions(png):
        """Return the (row, col) positions of non-white pixels of *png*.

        The image is resized to 3300 rows x 2550 columns first.
        """
        cv = CV()
        img = cv.img(png)
        resized = cv.resize(img, 3300, 2550)
        return cv.nonwhite_pixels(resized)

    def make(self):
        """Detect row/column lines, intersect them and return the Table list."""
        rlines = self._get_table_row_lines()
        clines = self._get_table_column_lines()
        rpoints, cpoints = self._get_reduced_row_and_col_points(rlines, clines)
        intersections = self._get_intersections(rpoints, cpoints)
        return self._get_tables(intersections)

    def _get_table_row_lines(self, threshold=None):
        """Return pixel rows holding more than *threshold* non-white pixels.

        Each entry is (row, [cols...]), ordered by descending pixel count.
        *threshold* defaults to 850 (size of 2550/3300: 850 threshold). An
        explicit 0 is honoured.
        """
        if threshold is None:
            threshold = 850
        cols_by_row = defaultdict(list)
        for row, col in self.nwp:
            cols_by_row[row].append(col)
        rows_with_col_lists = sorted(cols_by_row.items(),
                                     key=lambda e: len(e[1]), reverse=True)
        return [e for e in rows_with_col_lists if len(e[1]) > threshold]

    def _get_table_column_lines(self, threshold=None):
        """Return vertical runs of at least *threshold* contiguous pixels.

        Each entry is (col, [consecutive rows...]). A pixel column can yield
        several segments. Columns are visited in descending pixel count.
        *threshold* defaults to 60 (size of 2550/3300: 60 threshold). An
        explicit 0 is honoured.
        """
        if threshold is None:
            threshold = 60
        rows_by_col = defaultdict(list)
        for row, col in self.nwp:
            rows_by_col[col].append(row)
        cols_with_row_lists = sorted(rows_by_col.items(),
                                     key=lambda e: len(e[1]), reverse=True)
        col_line_segments = []
        for col, row_list in cols_with_row_lists:
            segments = []
            segment = []
            # Sorting makes run detection independent of the order of self.nwp.
            for row in sorted(row_list):
                if segment and segment[-1] != row - 1:
                    # The run broke. Keep it if long enough, then start the next
                    # run AT this row. Dropping the row would start the next
                    # segment one pixel late.
                    if len(segment) >= threshold:
                        segments.append(segment)
                    segment = []
                segment.append(row)
            if len(segment) >= threshold:
                segments.append(segment)
            col_line_segments.extend((col, s) for s in segments)
        return col_line_segments

    @staticmethod
    def _get_reduced_row_and_col_points(row_lines, col_lines):
        """Reduce row/column lines to their start/end points.

        Lines that start within 30 px of the previous line are collapsed as
        duplicates, e.g. the pixel rows of one thick rule.

        Returns:
            (row_points, col_points): two lists of dicts
            {'s': (row, col), 'e': (row, col)} giving each line's start and
            end, both sorted by start row.
        """
        rstarts = [{'s': (r[0], r[1][0]), 'e': (r[0], r[1][-1])}
                   for r in sorted(row_lines)]
        cstarts = sorted(
            [{'s': (r[1][0], r[0]), 'e': (r[1][-1], r[0])} for r in col_lines],
            key=lambda l: (l['s'][0], l['s'][1]))

        def get_reduced(starts, is_row=True):
            # The 30 threshold might not be correct for other documents!!!
            tmp = []
            for num, line in enumerate(starts):
                if not tmp:
                    tmp.append(line)
                    continue
                row, col = line['s']
                # Compared with the previous line in *starts*, not the last
                # kept one.
                prevrow, prevcol = starts[num - 1]['s']
                is_next_row = row - 30 < prevrow and prevcol == col
                is_next_col = row == prevrow and prevcol > col - 30
                if (is_row and is_next_row) or (is_next_col and not is_row):
                    continue
                tmp.append(line)
            return tmp

        return (sorted(get_reduced(rstarts, True), key=lambda l: l['s'][0]),
                sorted(get_reduced(cstarts, False), key=lambda l: l['s'][0]))

    @staticmethod
    def _get_intersections(rs, cs):
        """Intersect the reduced row points *rs* with the column points *cs*.

        Both lists come from _get_reduced_row_and_col_points. Each line is
        first lengthened by 5 px at both ends, so lines that stop just short of
        each other still meet.

        Returns:
            A list of dicts {'row': (start, end), 'col': (start, end),
            'intersect': (row, col)}, one per crossing pair.
        """
        def intersection(line1, line2):
            """
            Return the coordinates of a point of intersection given two lines.
            line1 and line2: lines given by 2 points (a 2-tuple of (x,y)-coords).
            Returns False when the segments are parallel or do not cross.
            """
            def near(a, b, rtol=1e-5, atol=1e-8):
                return abs(a - b) < (atol + rtol * abs(b))

            (x1, y1), (x2, y2) = line1
            (u1, v1), (u2, v2) = line2
            # Solve line1(t) == line2(s) by Cramer's rule.
            (a, b), (c, d) = (x2 - x1, u1 - u2), (y2 - y1, v1 - v2)
            e, f = u1 - x1, v1 - y1
            denom = float(a * d - b * c)
            if near(denom, 0):
                return False
            t = (e * d - b * f) / denom
            s = (a * f - e * c) / denom
            if 0 <= t <= 1 and 0 <= s <= 1:
                px = x1 + t * (x2 - x1)
                py = y1 + t * (y2 - y1)
                return (int(px), int(py))
            return False

        def extend_points(pts, is_columns):
            # Columns are stretched along the row axis, rows along the column axis.
            ret = []
            extension = 5 * int(is_columns)
            for num, p in enumerate(pts):
                e = (p['e'][0] + extension, p['e'][1] + 5 - extension)
                s = (p['s'][0] - extension, p['s'][1] + extension - 5)
                ret.append({'e': e, 's': s, 'index': num})
            return ret

        extended_cs = extend_points(cs, True)
        extended_rs = extend_points(rs, False)
        intersections = []
        for row in extended_rs:
            row_start = row['s']
            row_end = row['e']
            for col in extended_cs:
                col_start = col['s']
                col_end = col['e']
                intersect = intersection((row_start, row_end),
                                         (col_start, col_end))
                if not intersect:
                    continue
                intersections.append({'row': (row_start, row_end),
                                      'col': (col_start, col_end),
                                      'intersect': intersect})
        return intersections

    @staticmethod
    def _get_tables(intersections):
        """
        Group intersection dicts (from _get_intersections) into Table objects.

        1. Get the top left intersection.
        2. Skate to the end of its row for the top right.
        3. Skate to the bottom of its column for the bottom left.
        4. Skate to the end of the bottom left's row for the bottom right.
        5. Include all the intersections in between those 4 points in the table.
        6. Repeat to get the next table.

        Every pass removes at least the top-left point, so the loop terminates.
        Progress is printed to stdout.
        """
        def get_corners(subset):
            position = min(e['intersect'][0] for e in subset)
            top_left = sorted(
                [k for k in subset if k['intersect'][0] == position],
                key=lambda e: e['intersect'][1])[0]
            # 'row'/'col' hold (start, end); the end point identifies the line.
            bottom_left = sorted(
                [k for k in subset if k['col'][1] == top_left['col'][1]],
                key=lambda e: e['intersect'][0])[-1]
            top_right = sorted(
                [k for k in subset if k['row'][1] == top_left['row'][1]],
                key=lambda e: e['intersect'][1])[-1]
            bottom_right = sorted(
                [k for k in subset if k['row'][1] == bottom_left['row'][1]],
                key=lambda e: e['intersect'][1])[-1]
            return (p['intersect'] for p in
                    [top_left, top_right, bottom_right, bottom_left])

        def in_corners(row, col, tl, tr, br, bl):
            if row < tl[0] or col < tl[1]:
                return False
            if row > bl[0] or col > br[1]:
                return False
            return True

        def get_table_points(subset):
            tl, tr, br, bl = get_corners(subset)
            remaining = []
            ret = []
            for i in subset:
                row, col = i['intersect']
                if in_corners(row, col, tl, tr, br, bl):
                    ret.append(i)
                else:
                    remaining.append(i)
            return remaining, ret

        tables = []
        remaining = list(intersections)
        print('Total Points: %d' % len(intersections))
        while remaining:
            remaining, table_points = get_table_points(remaining)
            tables.append(Table([k['intersect'] for k in table_points]))
            print('Remaining Point Total: %d, Table length: %d.' % (
                len(remaining), len(table_points)))
        return tables
class Table(object):
    """
    Table class.
    To reconstruct the table, we need to relate the MTed text position to the
    CVed position of the cell that it's in. With that relation, in order to
    reconstruct the table, we need to know the corner points of the cell,
    what row it starts on, and what percentage of the row it takes up. We can
    then reconstruct the table row by row and rely on the text size to balloon
    as necessary within those confines.

    Built from a list of (row, col) intersection points. self.width is the
    column span of the points. self.rows is a list of rows, each a list of
    Cell objects.
    """

    def __init__(self, points):
        row_indices = self._get_row_indices(points)
        col_indices = self._get_col_indices(points)
        # An empty point list gives an empty table instead of an IndexError.
        self.width = col_indices[-1] - col_indices[0] if col_indices else 0
        self.rows = self.make_rows(points, row_indices)

    @staticmethod
    def _get_row_indices(points):
        """Sorted distinct row coordinates of *points*."""
        return sorted(set(k[0] for k in points))

    @staticmethod
    def _get_col_indices(points):
        """Sorted distinct column coordinates of *points*."""
        return sorted(set(k[1] for k in points))

    def make_rows(self, points, row_indices):
        """
        Build rows of Cells from consecutive pairs of point rows.

        Each cell's row_num is the index of its row in the returned list.
        Incorrect in some cases where the rows misalign.
        """
        ret = []
        top_points = None
        others = []
        for indice in row_indices:
            row_points = sorted([p for p in points if p[0] == indice],
                                key=lambda k: k[1])
            if top_points is None:
                pass
            elif len(row_points) > len(top_points):
                # only use row_points that are aligned with top_points
                # this is for when the points aren't aligned between two rows
                top_cols = set(e[1] for e in top_points)
                matches = [k for k in row_points if k[1] in top_cols]
                ret.append(self._make_row(top_points, matches, len(ret)))
            elif len(row_points) < len(top_points):
                # only use top_points that are aligned with row_points; the
                # rest are carried down to the next row boundary
                row_cols = set(e[1] for e in row_points)
                matches = [k for k in top_points if k[1] in row_cols]
                others = [k for k in top_points if k not in matches]
                ret.append(self._make_row(matches, row_points, len(ret)))
            else:
                ret.append(self._make_row(top_points, row_points, len(ret)))
            top_points = sorted(row_points + others, key=lambda k: k[1])
            others = []
        return ret

    def _make_row(self, top_points, bottom_points, row_num):
        """Pair adjacent top/bottom points into Cells that all carry *row_num*.

        Pairing stops at the shorter of the two point lists, so a misaligned
        row yields fewer cells instead of raising IndexError.
        """
        row = []
        n_cells = min(len(top_points), len(bottom_points)) - 1
        for indice in range(n_cells):
            tl = top_points[indice]
            tr = top_points[indice + 1]
            bl = bottom_points[indice]
            br = bottom_points[indice + 1]
            # Fraction of the table width; 0 for a zero-width table.
            width = 1.0 * (tr[1] - tl[1]) / self.width if self.width else 0.0
            row.append(Cell(tl, tr, bl, br, row_num, width))
        return row


class Cell(object):
    """
    Cell class is a single cell in a table. It has four corner points
    (tl, tr, bl, br as (row, col) tuples), a row number in its parent table,
    and its width as a fraction of the table width.
    """

    def __init__(self, tl, tr, bl, br, row_num, width):
        self.tl = tl
        self.tr = tr
        self.bl = bl
        self.br = br
        self.row_num = row_num
        self.width = width
def format_to_points(lines, is_row_first):
    """Flatten line groups into a flat list of (row, col) points.

    *lines* looks like [(key, [values...]), ...]. Every value is paired with
    its key, in input order. When *is_row_first* is true the key is the row
    (key, value); otherwise it is the column (value, key).
    """
    if is_row_first:
        return [(key, value) for key, values in lines for value in values]
    return [(value, key) for key, values in lines for value in values]
def paint(img, points, color=None):
    """Set every (row, col) position in *points* of *img* to *color*, in place.

    *color* defaults to [255, 0, 0], which is blue in OpenCV's BGR channel
    order. Only None selects the default, so a falsy colour such as 0
    (black in a single-channel image) is honoured.
    """
    if color is None:
        color = [255, 0, 0]  # blue
    for row, col in points:
        img[row][col] = color
def paint_file_tables(png_loc, row_threshold=None, col_threshold=None):
    """Paint the detected table lines of one image for visual inspection.

    The image is resized to 3300 rows x 2550 columns. Every pixel of a
    detected row line or column segment is painted with paint()'s default
    colour. The result is written next to the input as <name>.painted.png and
    its path is printed. Read or resize failures are printed and the file is
    skipped.

    row_threshold / col_threshold are forwarded to TableMaker's row/column
    line detectors. None selects their defaults.
    """
    cv = CV()
    img = cv.img(png_loc)
    if img is None:
        # Unreadable file. Report it here, because len(img[0]) in the resize
        # error message below would itself fail on None.
        print('Failed to read %s.' % png_loc)
        return
    try:
        reimg = cv.resize(img, 3300, 2550)
    except Exception as e:
        print('Failed to resize %s (%s). It has current size %d / %d.' % (
            png_loc, e, len(img[0]), len(img)))
        return
    # Line detection lives on TableMaker. It resizes the image to the same
    # 3300x2550 size as reimg, so the painted coordinates line up.
    maker = TableMaker(png_loc)
    row_lines = maker._get_table_row_lines(threshold=row_threshold)
    col_lines = maker._get_table_column_lines(threshold=col_threshold)
    points = format_to_points(row_lines, True) + format_to_points(col_lines, False)
    paint(reimg, points)
    out_loc = os.path.splitext(png_loc)[0] + '.painted.png'
    print(out_loc)
    cv.write(reimg, out_loc)
def paint_dir_tables(png_dir, row_threshold=None, col_threshold=None):
    """Run paint_file_tables on every regular file in *png_dir*.

    macOS '.DS_Store' files and sub-directories are skipped. Files are
    processed in sorted name order so runs are reproducible. The thresholds
    are passed through unchanged.
    """
    for name in sorted(os.listdir(png_dir)):
        if name == '.DS_Store':
            continue
        full_path = os.path.join(png_dir, name)
        if not os.path.isfile(full_path):
            continue
        paint_file_tables(full_path, row_threshold, col_threshold)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment