@glangford
Created February 7, 2024 21:26
Experimental tool using NLP to create grammatically separated subtitles from Whisper transcripts
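To use it, transcribe with word-level timestamps enabled and feed the resulting JSON to this script, which prints SRT cues to stdout. A minimal sketch of the workflow (the script filename subwisp.py is an assumption, and the flags shown are those of the openai-whisper CLI):

    whisper audio.mp3 --model small --word_timestamps True --output_format json
    python subwisp.py audio.json > audio.srt

A spaCy model is also required, e.g. python -m spacy download en_core_web_lg.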
"""
MIT License
Copyright (c) 2024 Glenn Langford
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import os
import argparse
import logging
import json
from more_itertools import chunked
from itertools import pairwise
from collections.abc import Iterator
import spacy
from spacy.language import Language
from spacy.tokens import Doc, Span, Token
from spacy.matcher import Matcher
from whisper.utils import format_timestamp
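# Map a spaCy Span back to a (start, end) time in seconds. The timing dict is keyed
# by character offset (token.idx) of each Whisper word; punctuation tokens and tokens
# without their own entry are skipped by walking backwards to the nearest timed word
# at each edge of the span.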
def get_time_span(span: Span, timing: dict):
start_token = span[0]
end_token = span[-1]
while start_token.is_punct or not timing.get(start_token.idx, None): # back up in tokens to get timing entry for a word
start_token = start_token.nbor(-1)
while end_token.is_punct or not timing.get(end_token.idx, None):
end_token = end_token.nbor(-1)
end_index = end_token.idx
start_index = start_token.idx
start, _ = timing[start_index]
#_, end = timing[end_index]
_, end = timing.get(end_index, (None, None))
if not end:
logging.debug("Timing aligment error: %s %d", span.text, end_token.idx)
return (start, end)
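# Custom attributes: the fragmenter marks tokens after which a subtitle line may be
# broken (can_fragment_after) and records the rule that allowed it (fragment_reason).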
Token.set_extension("can_fragment_after", default=False)
Token.set_extension("fragment_reason", default="")
Span.set_extension("get_time_span", method=get_time_span) # type: ignore
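# Matcher patterns used by the fragmenter. Each one identifies a grammatical
# construct around which a line break is acceptable, or should be avoided.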
punct_pattern = [{'IS_PUNCT': True, 'ORTH': {"IN": [",", ":", ";"]}}]
conj_pattern = [{"POS": {"IN": ["CCONJ", "SCONJ"]}}] # break prior to CCONJ or SCONJ
clause_pattern = [{"DEP": {"IN": ["advcl", "relcl", "acl", "acl:relcl"]}}] # break at either side of clause
ac_comp_pattern = [{"DEP": {"IN": ["acomp", "ccomp"]}}] # break at either side of adj or clausal complement
preposition_pattern = [{'POS': 'ADP'}] # but break *after* ADP if preceded by verb
dobj_pattern = [{'DEP': 'dobj'}, {'IS_PUNCT': False}]
# break after PART if preceded by verb (adjacent verb-particle pattern)
v_particle_pattern = [{'POS': 'VERB'}, {'POS': 'PART'}, {'POS': {"IN": ["VERB", "AUX"]}, 'OP': '!'}]
v_adj_pattern = [{'POS': "VERB"}, {"POS": "ADJ", "DEP": "amod"}] # amod = adjectival modifier (adjacent adjective pattern)
#log_pattern = [{'POS': "VERB"}, {"POS": "SCONJ"}, {'POS': "VERB"}] # pt "está a tornar"
# https://spacy.io/usage/processing-pipelines
# See "Example: Stateful component with settings"
def _fragment_at(token: Token, reason: str):
token._.can_fragment_after = True
token._.fragment_reason = reason
# Stateful spaCy pipeline component with settings
@Language.factory("fragmenter", default_config={"verbal_pauses": []})
def create_fragmenter_component(nlp: Language, name: str, verbal_pauses: list[int]):
return FragmenterComponent(nlp, verbal_pauses)
class FragmenterComponent:
def __init__(self, nlp: Language, verbal_pauses: list): # pauses must be serializable
self.pauses = set(verbal_pauses)
logging.info("Count of pauses: %d", len(self.pauses))
def __call__(self, doc: Doc) -> Doc: # process the document
return fragmenter(doc, self.pauses)
###@Language.component("fragmenter")
def fragmenter(doc: Doc, pauses: set) -> Doc:
matcher = Matcher(doc.vocab)
matcher.add("clause", [clause_pattern])
matcher.add("punct", [punct_pattern])
matcher.add("conj", [conj_pattern])
matcher.add("preposition", [preposition_pattern])
matcher.add("dobj", [dobj_pattern])
matcher.add("v_particle", [v_particle_pattern])
matcher.add("ac_comp", [ac_comp_pattern])
matcher.add("v_adj", [v_adj_pattern])
matches = matcher(doc)
conjunction_or_punct = frozenset(["CCONJ", "SCONJ", "PUNCT"])
# Iterate over the matches
for match_id, start, end in matches:
rule_id = doc.vocab.strings[match_id]
matched_span = doc[start:end]
##print(rule_id, start, end, doc[start].idx, matched_span.text)
token = doc[start]
#rule = re.sub("_pattern", "", rule_id)
if token.i < 2:
continue
match rule_id:
case "punct": # selected subset of punctuation characters
_fragment_at(token, reason=rule_id)
case "conj": # coordinating and subordinating conjunctions
prior = token.nbor(-1)
if prior.pos_ not in conjunction_or_punct: # eg. "so that"
_fragment_at(prior, reason=rule_id)
case "clause":
subtree = [t for t in token.subtree]
#print(" ", " ".join(t.text for t in subtree))
if len(subtree) < 2:
continue
clause_rule = f"{rule_id}:{token.text}"
left = subtree[0]
if left and left.i > 0 and not left.is_punct and left.text[0] != "'":
prior = left.nbor(-1)
if prior.pos_ not in conjunction_or_punct and not prior.nbor(-1).is_punct:
_fragment_at(prior, reason=clause_rule)
right = subtree[-1]
try:
if right.pos_ not in conjunction_or_punct and not right.nbor(1).is_punct:
_fragment_at(right, reason=clause_rule)
except IndexError:
continue # can safely ignore
case "preposition":
#print("", matched_span, token.text)
prior = token.nbor(-1)
if prior.pos_ in conjunction_or_punct or token.ent_iob_ == 'I': # prior is a conjunction, comma, or token is inside a named entity
continue
next = token.nbor(1)
# Keep preposition together with verb when appropriate
if (token.dep_ == 'prt' or prior.pos_ in ['AUX', 'VERB']) and not next.is_punct: # "going on, come up, end up"
_fragment_at(token, reason=f"{rule_id}-after")
else: # fragment prior to preposition
if token.i > 2 and not (prior.is_punct or prior.nbor(-1).is_punct):
_fragment_at(prior, reason=rule_id)
case "v_particle": # VERB + PART
#print(rule_id, matched_span)
particle = matched_span[1]
if particle.is_punct or particle.nbor(1).is_punct:
continue
_fragment_at(particle, reason=rule_id)
case "v_adj":
#print(rule_id, matched_span)
_fragment_at(token, reason=rule_id)
case "dobj":
if token.pos_ not in conjunction_or_punct:
_fragment_at(token, reason=rule_id)
case "ac_comp":
if token.is_punct: # first token could be punct with hyphenated words
continue
subtree = [t for t in token.subtree]
left = subtree[0]
if len(subtree) < 2:
continue
ac_rule = f"{rule_id}:{token.text}"
if left and left.i > 0 and left.text[0] != "'" and not (left.is_punct or left.nbor(-1).is_punct): # can get the leftmost token as an apostrophe or a comma
_fragment_at(left.nbor(-1), reason=ac_rule)
right = subtree[-1]
try:
if not (right.is_punct or right.nbor(1).is_punct):
_fragment_at(right, reason=ac_rule)
except IndexError:
logging.debug("ac_comp IndexError")
continue # can safely ignore
#print("acomp ", token.text, token.dep_, "tree:", " ".join(t.text for t in subtree))
_scan_entities(doc)
_scan_noun_phrases(doc)
_scan_pauses(doc, pauses)
return doc
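# Log long silences between words as candidate break points. Fragmenting at verbal
# pauses is currently disabled (the _fragment_at call below is commented out).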
def _scan_pauses(doc: Doc, pauses: set):
for token in doc:
if token.text[0] == '-':
continue
try:
            if token.idx in pauses and not token.nbor(1).is_punct: # identify places where spaCy misses a sentence end, eg. "The length was 2m."
logging.debug("Candidate pause: %d %s %s", token.i, token.text, token.nbor(1).text)
#_fragment_at(token, reason="verbal pause")
except IndexError:
continue # can be safely ignored
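# Prefer breaks immediately before and after long multi-word named entities, so an
# entity is not split across subtitle lines.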
def _scan_entities(doc: Doc):
for entity in doc.ents:
# if any( (len(entity) < 2, entity.label_ in ignore_entities, len(entity.text) < 10) )
if len(entity) < 2 or entity.label_ in ['PERSON', 'ORDINAL', 'PERCENT', 'TIME', 'CARDINAL'] or len(entity.text) < 10:
continue
token = entity[0]
if token.i < 1:
continue
prior = token.nbor(-1)
if (not prior.is_punct and
prior.pos_ not in ['DET'] and
not prior._.can_fragment_after): # don't overwrite prior reason
_fragment_at(prior, reason="entity->")
        try:
            after = entity[-1].nbor(1)
            if (after.pos_ != "PART" and
                    not after.is_punct and # eg "UK's ..." where entity is possessive
                    not entity[-1]._.can_fragment_after):
                _fragment_at(entity[-1], reason="entity")
        except IndexError:
            continue # entity ends at the last token of the document
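# Likewise mark break points at noun phrase boundaries, without overriding decisions
# already made by the matcher rules or the entity scan.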
def _scan_noun_phrases(doc: Doc):
for chunk in doc.noun_chunks:
if len(chunk) < 2:
continue
token = chunk[0]
if token.i > 0 and not token.is_punct:
prior = token.nbor(-1)
# Don't interfere with prior decision of ADP or S/C CONJ at the start of noun phrases
if (prior.pos_ not in ['ADP', 'SCONJ', 'CCONJ'] and
not prior._.can_fragment_after and # don't overwrite prior reason
not prior.is_punct): # eg. don't fragment at hyphen or period of previous sentence
_fragment_at(prior, reason="NP->")
try:
after = chunk[-1].nbor(1)
if (not after.is_punct and
not chunk[-1]._.can_fragment_after):
_fragment_at(chunk[-1], reason="NP")
except IndexError:
continue # can safely ignore
def leading_whitespace(s: str) -> int:
    return len(s) - len(s.lstrip())
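# Read a Whisper JSON transcript produced with word-level timestamps and flatten it
# into a single text string plus a dict mapping each word's character offset to its
# (start, end) timestamps; the offsets line up with spaCy's token.idx values.
# The input is expected to look roughly like this (only the fields read below matter):
#   {"language": "en",
#    "segments": [{"words": [{"word": " Hello", "start": 0.0, "end": 0.42}, ...]}, ...]}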
def load_whisper_json(file: str) -> tuple[str, dict]:
doc_timing = {}
doc_text = ""
    with open(file) as js:
        jsdata = json.load(js)
logging.debug("Language: %s", jsdata['language'])
for s in jsdata['segments']:
if 'words' not in s:
raise ValueError('JSON input file must contain word timestamps')
for word_timed in s['words']:
# print(word_timed['start'], word_timed['end'], word_timed['word'])
word = word_timed['word']
if len(doc_text) == 0:
word = word.lstrip() # remove any leading whitespace from first word in Whisper
start_index = 0
doc_text += word
start_index = len(doc_text) - len(word) + leading_whitespace(word) # align the timing index with the spaCy token index
doc_timing[start_index] = (word_timed['start'], word_timed['end'])
return doc_text, doc_timing
# Greedy scan from right to left for preferred grammar and occupancy
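# A division is accepted when the first line fits within max_width and either forms a
# "pyramid" (at least a third of the maximum width, no longer than the remainder, with
# the remainder fitting on a single line) or reaches roughly 70% of the maximum width
# while being at most 20% longer than the text that remains.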
def preferred_division_for(span: Span, max_width: int) -> int:
# token.i: index of the token within the parent document
# token.idx: character offset of the token within the parent document
# Span.start_char: character offset for the start of the span
# Span.end_char: character offset for the end of the span
def is_grammatically_preferred(token: Token):
return token._.can_fragment_after and (
token._.fragment_reason in ['punct', 'conj', 'clause', 'entity', 'entity->', 'v_adj'])
preferreds = (t for t in reversed(span) if is_grammatically_preferred(t))
target_width = round(0.7 * max_width)
for tp in preferreds:
width = tp.idx + len(tp) - span.start_char
if width > max_width:
continue
remainder_width = span.end_char - tp.idx - len(tp)
if width <= remainder_width and width >= max_width/3 and remainder_width <= max_width: # prefer pyramid to finish the sentence
logging.debug("Primary complete %s %s at %d : '%s'", tp.text, tp._.fragment_reason, tp.idx - span.start_char, span.text)
return tp.i
if width >= target_width and width <= remainder_width * 1.2: # no more than 20% bigger than remaining text in sentence
logging.debug("Primary selected %s %s at %d : '%s'", tp.text, tp._.fragment_reason, tp.idx - span.start_char, span.text)
return tp.i
logging.debug("No primary for '%s'", span.text)
return 0
# Secondary scan from left to right for fragmenting words, or degrade to a word within line maximum
def secondary_division_for(span: Span, max_width: int) -> int:
token_divider = 0 # index of token in parent doc
start_index = span.start_char
for token in span:
if token.i == span[0].i:
continue
token_start = token.idx - start_index # character offset from start of span
        if token_divider and token_start > max_width: # do we have a prior divider, and is this token over budget?
break
token_end = token_start + len(token)
if token._.can_fragment_after and token_end <= max_width and token.i + 2 < span[-1].i : # select token to divide at
token_divider = token.i
if span.end_char - token.idx - len(token) <= max_width: # lookahead to see if Span remainder is less than max_width
break # prefer pyramid
if not token_divider and token_end > max_width: # if no grammatical option found, fall back to last token within budget
token_divider = token.i - 1 if token.pos_ != 'PUNCT' else token.i - 2 # don't orphan a period at end of sentence
logging.info("Forced division after word '%s' : '%s'", span.doc[token_divider].text, span.text)
return token_divider
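# Recursively split a sentence span into line-sized spans, trying the preferred
# grammatical division first and falling back to the secondary scan.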
def divide_span(span: Span, args) -> Iterator[Span]:
max_width = args.width
if span.end_char - span.start_char <= max_width:
yield span
return # end the generator if the length is ok
###
divider = preferred_division_for(span, max_width) or secondary_division_for(span, max_width)
after_divider = divider + 1
yield span.doc[span.start:after_divider] # up to divider inclusive
if after_divider < span.end:
yield from divide_span(span.doc[after_divider:span.end], args)
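# Walk the document sentence by sentence, split each sentence into lines, group the
# lines into subtitles of at most max_lines, and derive each subtitle's start and end
# times from the word-level timestamps.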
def iterate_document(doc: Doc, timing: dict, args):
max_lines = args.lines
#max_width = args.width
for sentence in doc.sents:
for chunk in chunked(divide_span(sentence, args), max_lines):
subtitle = '\n'.join(line.text for line in chunk)
sub_start, _ = chunk[0]._.get_time_span(timing)
_, sub_end = chunk[-1]._.get_time_span(timing)
yield sub_start, sub_end, subtitle
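# Emit numbered SRT cues on stdout, using Whisper's timestamp formatter with a comma
# as the decimal marker (as required by the SRT format).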
def write_srt(doc, timing, args):
comma: str = ','
for i, (start, end, text) in enumerate(iterate_document(doc, timing, args), start=1):
ts1 = format_timestamp(start, always_include_hours=True, decimal_marker=comma)
ts2 = format_timestamp(end, always_include_hours=True, decimal_marker=comma)
print(f"{i}\n{ts1} --> {ts2}\n{text}\n")
def configure_spaCy(model: str, entities: str, pauses: list = []):
nlp = spacy.load(model)
#print(nlp.pipe_names, file=sys.stderr)
if model.startswith('xx'):
raise NotImplementedError("spaCy multilanguage models are not currently supported")
#nlp.add_pipe("fragmenter", name="fragmenter", last=True)
nlp.add_pipe("fragmenter", config={"verbal_pauses": pauses}, last=True)
if len(entities) > 0:
new_ruler = nlp.add_pipe("entity_ruler", before='ner').from_disk(entities)
return nlp
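# Find gaps of more than half a second between consecutive words. The returned keys
# are character offsets into the document text, the same keys used by the timing dict.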
def scan_for_pauses(doc_text: str, timing: dict) -> list[int]:
pauses = []
# requires Python 3.10
# Look at the silence between the end of one word and the start of the next
for (k1, (_, end)), (k2, (start, _)) in pairwise(sorted(timing.items())):
gap = start - end
if gap > 0.5:
pauses.append(k1) # the key in the timing dict before the pause
return pauses
def main():
parser = argparse.ArgumentParser(
prog='subwisp',
        description='Convert a Whisper .json transcript into sentence-based .srt subtitles, split at grammatical boundaries where possible.',
epilog='')
parser.add_argument('input_file')
parser.add_argument('-m', '--model', help='specify spaCy model', default="en_core_web_lg")
parser.add_argument('-e', '--entities', help='optional custom entities for spaCy (.jsonl format)', default="")
parser.add_argument('-w', '--width', help='maximum line width', default=42, type=int)
parser.add_argument('-l', '--lines', help='maximum lines per subtitle', default=2, type=int, choices=range(1,4))
parser.add_argument('-d', '--debug', help='print debug information',
action="store_const", dest="loglevel", const=logging.DEBUG, default=logging.WARNING)
parser.add_argument('--verbose', help='be verbose',
action="store_const", dest="loglevel", const=logging.INFO)
args = parser.parse_args()
logging.basicConfig(level=args.loglevel)
if not os.path.isfile(args.input_file):
logging.error("File not found: %s", args.input_file )
exit(-1)
if not args.model:
logging.error("No spacy model specified")
exit(-1)
if len(args.entities) > 0 and not os.path.isfile(args.entities):
logging.error("Entities file not found: %s", args.entities)
exit(-1)
wtext, word_timing = load_whisper_json(args.input_file)
verbal_pauses = scan_for_pauses(wtext, word_timing)
nlp = configure_spaCy(args.model, args.entities, verbal_pauses)
doc = nlp(wtext)
write_srt(doc, word_timing, args)
exit()
if __name__ == '__main__' :
main()