Skip to content

Instantly share code, notes, and snippets.

@sai-prasanna
Created October 18, 2019 03:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sai-prasanna/f18908d2054257045a251306e7a01806 to your computer and use it in GitHub Desktop.
Save sai-prasanna/f18908d2054257045a251306e7a01806 to your computer and use it in GitHub Desktop.
import streamlit as st
import numpy as np
from lmproof.scorer import TransformerLMScorer
from lmproof.candidate_generators import (MatchedGenerator,
EnglishInflectedGenerator,
SpellCorrectGenerator)
@st.cache(ignore_hash=True)
def model():
    """Load and return the language-model scorer used for proofreading.

    The scorer is loaded with the arguments ``'en'`` and ``'cuda:0'``
    (presumably the language code and the target device -- confirm against
    ``TransformerLMScorer.load``).  The function is wrapped in ``st.cache``
    so the load result is reused by Streamlit instead of being recomputed.
    """
    lm_scorer = TransformerLMScorer.load('en', 'cuda:0')
    return lm_scorer
def data_point(gen_name, text, score):
    """Build one row of the results table.

    Returns a dict with the keys ``'gen'`` (name of the generator that
    produced the sentence), ``'text'`` (the sentence) and ``'score'``
    (the score assigned to it), in that order.
    """
    return dict(gen=gen_name, text=text, score=score)
st.title('Language Model Proofreader')

# The scorer comes from the cached loader; its batch size is set on the
# returned object before any scoring happens.
scorer = model()
scorer.batch_size = 5

# Candidate generators, keyed by the short name shown in the 'gen' column
# of the results table.
match_gen = MatchedGenerator.load('en')
inflect_gen = EnglishInflectedGenerator()
spell_correct_gen = SpellCorrectGenerator.load('en')
candidate_generators = {
    'match': match_gen,
    'inflect': inflect_gen,
    'spell': spell_correct_gen,
}

# Label typo fixed ('sentece' -> 'sentence').
sentence = st.text_area('Enter a sentence', value='Test sentence')

# State for the iterative correction loop below.
correction_gen_name = 'Correct:'    # label of the current best sentence
correction = sentence               # current best sentence, starts as the input
previous_candidates = {sentence}    # sentences that must not be proposed again
threshold = 0.1                     # margin added to the source sentence's score
# Iteratively replace the current sentence with its best-scoring candidate
# until no candidate beats the (biased) score of the current sentence.
while True:
    # Collect (generator name, candidate) pairs that were not proposed before.
    # Built as a list of pairs first so that an empty result is handled
    # gracefully (unpacking zip(*[]) into two names raises ValueError).
    candidate_pairs = [
        (gen_name, candidate)
        for gen_name, generator in candidate_generators.items()
        for candidate in generator.candidates(correction)
        if candidate not in previous_candidates
    ]
    gen_names = [gen_name for gen_name, _ in candidate_pairs]
    candidates = [candidate for _, candidate in candidate_pairs]

    # Do scoring in one shot to use batching internally; the first score
    # belongs to the current sentence, the rest to the candidates in order.
    source_score, *candidate_scores = scorer.score([correction] + candidates)
    # Add the threshold to bias towards the source sentence.
    biased_source_score = source_score + threshold

    # Show the current sentence and all candidates, best score first.
    rows = [data_point(correction_gen_name, correction, biased_source_score)]
    rows.extend(
        data_point(gen_name, text, score)
        for gen_name, text, score in zip(gen_names, candidates, candidate_scores)
    )
    st.subheader('Sentences')
    st.table(sorted(rows, reverse=True, key=lambda row: row['score']))

    # Nothing new to try: the current sentence is final.
    # (np.argmax would raise on an empty array.)
    if not candidates:
        break

    best_idx = int(np.argmax(np.array(candidate_scores)))
    if candidate_scores[best_idx] > biased_source_score:
        best_candidate = candidates[best_idx]
        correction_gen_name += ':' + gen_names[best_idx]
    else:
        best_candidate = None

    if not best_candidate:
        break

    correction = best_candidate
    # Remember everything proposed in this round.  set.union() returns a new
    # set and leaves the receiver untouched, so update() is required here.
    previous_candidates.update(candidates)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment