Skip to content

Instantly share code, notes, and snippets.

@wosiu
Last active May 22, 2019 00:14
Show Gist options
  • Save wosiu/9fa50de9e47615b5fa08b23637e1f947 to your computer and use it in GitHub Desktop.
Save wosiu/9fa50de9e47615b5fa08b23637e1f947 to your computer and use it in GitHub Desktop.
Calamari OCR wrapper
'''
Calamari OCR Wrapper for in-code usage.
It assumes that images are already loaded into memory as np.array.
Extended with a prediction postprocessing API.
Currently only character whitelisting is implemented.
Calamari OCR: https://github.com/Calamari-OCR/calamari/
License: Apache 2.0
'''
import abc
import logging
import os
from typing import Dict, Optional, Tuple

import numpy as np
from calamari_ocr.ocr import Predictor
from calamari_ocr.ocr.datasets import DataSetType, create_dataset, DataSetMode, InputDataset, RawDataSet
from calamari_ocr.proto import Predictions
logger = logging.getLogger(__name__)
class PredictionPostprocessor:
    """
    Base interface for hooks that are handed a prediction after recognition.

    The prediction is a protocol buffer with fields:
        .sentence
        .positions
            .chars
                .char
                .probability

    NOTE(review): this class does not derive from ``abc.ABC``, so the
    ``abc.abstractmethod`` marker on ``process_prediction`` is not enforced
    and the base class itself can still be instantiated (as a no-op).
    """

    def __init__(self):
        """The base class keeps no state."""

    @abc.abstractmethod
    def process_prediction(self, fid: str, prediction, **kwargs):
        """
        Hook to be overridden by subclasses; the base implementation does nothing.
        :param fid: id of the text line the prediction belongs to
        :param prediction: prediction object for that text line
        :param kwargs: extra arguments supplied by the caller
        """
        return None

    def __call__(self, fid: str, prediction, **kwargs):
        """Make the instance callable; delegates to ``process_prediction`` and returns nothing."""
        self.process_prediction(fid, prediction, **kwargs)
class CalamariWrapper:
    """
    In-process wrapper around a Calamari ``Predictor`` for images already held in memory.

    The predictor is built either in the constructor or lazily on the first
    prediction, and is then reused by every later call.
    """

    def __init__(self, model_ckpt_path: str, lazy_load=False):
        """
        :param model_ckpt_path: checkpoint path handed to ``Predictor``
        :param lazy_load: when true, do not build the predictor until it is first needed
        """
        self.model_path = model_ckpt_path
        self.predictor = None
        if not lazy_load:
            self.load_model()

    def load_model(self):
        """Build the predictor for ``self.model_path``; does nothing if one already exists."""
        # Explicit None check: "not loaded yet" must not depend on the
        # truthiness of the predictor object.
        if self.predictor is not None:
            return
        self.predictor = Predictor(checkpoint=self.model_path,
                                   batch_size=1,
                                   auto_update_checkpoints=False,
                                   processes=os.cpu_count())
        logger.info("Model %s loaded.", self.model_path)

    def _predict_dataset(self, dataset):
        """
        Predict a complete dataset.
        Based on calamari_ocr.predictor.Predictor.predict_dataset(...)
        :param dataset: dataset to predict
        :return: whatever ``Predictor.predict_input_dataset`` produces for the dataset;
                 ``ocr_batch`` iterates it and pairs the results positionally with its input ids
        """
        self.load_model()
        input_dataset = InputDataset(dataset, self.predictor.data_preproc, None)
        return self.predictor.predict_input_dataset(input_dataset, progress_bar=False)

    def ocr_batch(self, batch: Dict[str, np.ndarray],
                  prediction_postprocessor: Optional["PredictionPostprocessor"] = None,
                  **kwargs) -> Tuple[Dict[str, str], Dict[str, float]]:
        """
        Predicts batch of images already loaded into memory.
        Based on calamari_ocr.scripts.predict.run(...)
        :param batch: dictionary with mapping: text line id -> image with text line already loaded as np.ndarray
        :param prediction_postprocessor: optional PredictionPostprocessor, called on every prediction
                                         before its text and confidence are read
        :param kwargs: arguments for prediction postprocessor
        :return: two dictionaries: 1) mapping: text line id -> recognized text,
                 2) mapping: text line id -> confidence (average char probability multiplied by 100)
        :raises ValueError: if no images are provided
        """
        # Reject an empty batch before building anything.
        if not batch:
            raise ValueError("Empty dataset provided.")

        # Materialise ids and images once so their pairing is explicit and the
        # dataset receives a concrete list rather than a dict view.
        fids = list(batch)
        images = [batch[fid] for fid in fids]

        dataset = RawDataSet(DataSetMode.PREDICT, images, None)
        logger.info("Found %s images in the dataset", len(dataset))
        # Original guard kept in case the dataset ends up with no samples.
        if len(dataset) == 0:
            raise ValueError("Empty dataset provided.")

        predictions = self._predict_dataset(dataset)
        txt_results = {}
        conf = {}
        for fid, result in zip(fids, predictions):
            prediction = result.prediction
            if prediction_postprocessor is not None:
                prediction_postprocessor(fid, prediction, **kwargs)
            txt_results[fid] = prediction.sentence
            conf[fid] = prediction.avg_char_probability * 100
        return txt_results, conf
class WhitelistPostProc(PredictionPostprocessor):
    """
    Rewrites a calamari prediction so that every position keeps only its most
    probable character among those present in the whitelist.
    """

    def __init__(self, whitelist: str):
        """
        :param whitelist: characters that may appear in the output; the empty
                          string is always added to the allowed set as well
        """
        super().__init__()
        self.whitelist = set(whitelist) | {''}

    def choose_char(self, chrs):
        """
        Pick the whitelisted candidate with the highest probability.
        :param chrs: candidates of one position, each exposing .char and .probability
        :return: (char, probability); ('', 0.0) when no whitelisted candidate
                 has a probability above zero
        """
        best_char, best_prob = '', 0.0
        for candidate in chrs:
            # Strict comparison: on a tie the earlier candidate is kept.
            if candidate.char in self.whitelist and candidate.probability > best_prob:
                best_char, best_prob = candidate.char, candidate.probability
        return best_char, best_prob

    def process_prediction(self, fid: str, prediction, **kwargs):
        """
        Overwrite prediction.sentence and prediction.avg_char_probability with
        the whitelisted choice for every position; a prediction without
        positions is left untouched.
        """
        positions = prediction.positions
        if not positions:
            return
        picked = [self.choose_char(pos.chars) for pos in positions]
        prediction.sentence = "".join(char for char, _ in picked)
        prediction.avg_char_probability = float(np.mean([prob for _, prob in picked]))
class BatchPredictionPostProcMapper(PredictionPostprocessor):
    """
    Aggregates other prediction postprocessors, so that each text line in one
    batch may have a different postprocessor.
    """

    def __init__(self, postprocessors: Optional[Dict[str, PredictionPostprocessor]] = None):
        """
        :param postprocessors: optional mapping: text line id -> postprocessor.
                               A dictionary passed in is stored as-is (not copied);
                               when omitted, each instance gets its own empty mapping.
        """
        super().__init__()
        # A `{}` default would be one dict shared by every instance created
        # without arguments, so registrations would leak between mappers.
        self.postprocessors = {} if postprocessors is None else postprocessors

    def set_postprocessor(self, fid: str, postprocessor: PredictionPostprocessor):
        """Register (or replace) the postprocessor used for text line ``fid``."""
        self.postprocessors[fid] = postprocessor

    def process_prediction(self, fid: str, prediction, **kwargs):
        """
        Delegate to the postprocessor registered for ``fid``; text lines
        without a registered postprocessor are left untouched.
        """
        p = self.postprocessors.get(fid)
        if p is not None:
            p.process_prediction(fid, prediction, **kwargs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment