Experimental tool using NLP to create grammatically separated subtitles from whisper transcripts
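Usage sketch (assuming the openai-whisper CLI and this gist saved as subwisp.py; flag spellings may vary by whisper version): first produce a transcript with word-level timestamps, e.g. `whisper audio.mp3 --word_timestamps True --output_format json`, then convert it with `python subwisp.py audio.json > audio.srt`.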
""" | |
MIT License | |
Copyright (c) 2024 Glenn Langford | |
Permission is hereby granted, free of charge, to any person obtaining a copy | |
of this software and associated documentation files (the "Software"), to deal | |
in the Software without restriction, including without limitation the rights | |
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
copies of the Software, and to permit persons to whom the Software is | |
furnished to do so, subject to the following conditions: | |
The above copyright notice and this permission notice shall be included in all | |
copies or substantial portions of the Software. | |
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
SOFTWARE. | |
""" | |
import os
import sys
import argparse
import logging
import json

from itertools import pairwise
from collections.abc import Iterator

from more_itertools import chunked

import spacy
from spacy.language import Language
from spacy.tokens import Doc, Span, Token
from spacy.matcher import Matcher
from whisper.utils import format_timestamp

def get_time_span(span: Span, timing: dict):
    start_token = span[0]
    end_token = span[-1]
    # back up over punctuation (which has no timing entry) to the nearest timed word
    while start_token.is_punct or not timing.get(start_token.idx, None):
        start_token = start_token.nbor(-1)
    while end_token.is_punct or not timing.get(end_token.idx, None):
        end_token = end_token.nbor(-1)
    end_index = end_token.idx
    start_index = start_token.idx
    start, _ = timing[start_index]
    _, end = timing.get(end_index, (None, None))
    if not end:
        logging.debug("Timing alignment error: %s %d", span.text, end_token.idx)
    return (start, end)
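
# Example (illustrative timings): `timing` maps the character offset of each word
# in the document text to its (start, end) times in seconds, e.g. for "Hello world":
#   {0: (0.0, 0.42), 6: (0.42, 0.88)}
# so a Span covering "world" resolves to (0.42, 0.88) via get_time_span().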

Token.set_extension("can_fragment_after", default=False)
Token.set_extension("fragment_reason", default="")
Span.set_extension("get_time_span", method=get_time_span)  # type: ignore

punct_pattern = [{'IS_PUNCT': True, 'ORTH': {"IN": [",", ":", ";"]}}]
conj_pattern = [{"POS": {"IN": ["CCONJ", "SCONJ"]}}]  # break prior to CCONJ or SCONJ
clause_pattern = [{"DEP": {"IN": ["advcl", "relcl", "acl", "acl:relcl"]}}]  # break at either side of clause
ac_comp_pattern = [{"DEP": {"IN": ["acomp", "ccomp"]}}]  # break at either side of adjectival or clausal complement
preposition_pattern = [{'POS': 'ADP'}]  # but break *after* ADP if preceded by a verb
dobj_pattern = [{'DEP': 'dobj'}, {'IS_PUNCT': False}]
# break after PART if preceded by verb (adjacent verb-particle pattern)
v_particle_pattern = [{'POS': 'VERB'}, {'POS': 'PART'}, {'POS': {"IN": ["VERB", "AUX"]}, 'OP': '!'}]
v_adj_pattern = [{'POS': "VERB"}, {"POS": "ADJ", "DEP": "amod"}]  # amod = adjectival modifier (adjacent adjective pattern)
#log_pattern = [{'POS': "VERB"}, {"POS": "SCONJ"}, {'POS': "VERB"}]  # pt "está a tornar"
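
# Example (illustrative): in "She left because it was late", conj_pattern matches
# the SCONJ "because"; the "conj" case in fragmenter() then marks the prior token
# "left" as a break point, so a subtitle line can end just before the clause.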

# https://spacy.io/usage/processing-pipelines
# See "Example: Stateful component with settings"
def _fragment_at(token: Token, reason: str):
    token._.can_fragment_after = True
    token._.fragment_reason = reason

# Stateful spaCy pipeline component with settings
@Language.factory("fragmenter", default_config={"verbal_pauses": []})
def create_fragmenter_component(nlp: Language, name: str, verbal_pauses: list[int]):
    return FragmenterComponent(nlp, verbal_pauses)

class FragmenterComponent:
    def __init__(self, nlp: Language, verbal_pauses: list):  # pauses must be serializable
        self.pauses = set(verbal_pauses)
        logging.info("Count of pauses: %d", len(self.pauses))
    def __call__(self, doc: Doc) -> Doc:  # process the document
        return fragmenter(doc, self.pauses)

def fragmenter(doc: Doc, pauses: set) -> Doc:
    matcher = Matcher(doc.vocab)
    matcher.add("clause", [clause_pattern])
    matcher.add("punct", [punct_pattern])
    matcher.add("conj", [conj_pattern])
    matcher.add("preposition", [preposition_pattern])
    matcher.add("dobj", [dobj_pattern])
    matcher.add("v_particle", [v_particle_pattern])
    matcher.add("ac_comp", [ac_comp_pattern])
    matcher.add("v_adj", [v_adj_pattern])
    matches = matcher(doc)
    conjunction_or_punct = frozenset(["CCONJ", "SCONJ", "PUNCT"])
    # Iterate over the matches
    for match_id, start, end in matches:
        rule_id = doc.vocab.strings[match_id]
        matched_span = doc[start:end]
        token = doc[start]
        if token.i < 2:
            continue
        match rule_id:
            case "punct":  # selected subset of punctuation characters
                _fragment_at(token, reason=rule_id)
            case "conj":  # coordinating and subordinating conjunctions
                prior = token.nbor(-1)
                if prior.pos_ not in conjunction_or_punct:  # e.g. "so that"
                    _fragment_at(prior, reason=rule_id)
            case "clause":
                subtree = list(token.subtree)
                if len(subtree) < 2:
                    continue
                clause_rule = f"{rule_id}:{token.text}"
                left = subtree[0]
                if left.i > 1 and not left.is_punct and left.text[0] != "'":  # left.i > 1 so both neighbours below exist
                    prior = left.nbor(-1)
                    if prior.pos_ not in conjunction_or_punct and not prior.nbor(-1).is_punct:
                        _fragment_at(prior, reason=clause_rule)
                right = subtree[-1]
                try:
                    if right.pos_ not in conjunction_or_punct and not right.nbor(1).is_punct:
                        _fragment_at(right, reason=clause_rule)
                except IndexError:
                    continue  # can safely ignore
            case "preposition":
                prior = token.nbor(-1)
                if prior.pos_ in conjunction_or_punct or token.ent_iob_ == 'I':  # prior is a conjunction or comma, or token is inside a named entity
                    continue
                next_token = token.nbor(1)
                # Keep preposition together with verb when appropriate
                if (token.dep_ == 'prt' or prior.pos_ in ['AUX', 'VERB']) and not next_token.is_punct:  # "going on, come up, end up"
                    _fragment_at(token, reason=f"{rule_id}-after")
                else:  # fragment prior to preposition
                    if token.i > 2 and not (prior.is_punct or prior.nbor(-1).is_punct):
                        _fragment_at(prior, reason=rule_id)
            case "v_particle":  # VERB + PART
                particle = matched_span[1]
                if particle.is_punct or particle.nbor(1).is_punct:
                    continue
                _fragment_at(particle, reason=rule_id)
            case "v_adj":
                _fragment_at(token, reason=rule_id)
            case "dobj":
                if token.pos_ not in conjunction_or_punct:
                    _fragment_at(token, reason=rule_id)
            case "ac_comp":
                if token.is_punct:  # first token could be punct with hyphenated words
                    continue
                subtree = list(token.subtree)
                left = subtree[0]
                if len(subtree) < 2:
                    continue
                ac_rule = f"{rule_id}:{token.text}"
                if left.i > 0 and left.text[0] != "'" and not (left.is_punct or left.nbor(-1).is_punct):  # leftmost token can be an apostrophe or a comma
                    _fragment_at(left.nbor(-1), reason=ac_rule)
                right = subtree[-1]
                try:
                    if not (right.is_punct or right.nbor(1).is_punct):
                        _fragment_at(right, reason=ac_rule)
                except IndexError:
                    logging.debug("ac_comp IndexError")
                    continue  # can safely ignore
    _scan_entities(doc)
    _scan_noun_phrases(doc)
    _scan_pauses(doc, pauses)
    return doc

def _scan_pauses(doc: Doc, pauses: set):
    for token in doc:
        if token.text[0] == '-':
            continue
        try:
            if token.idx in pauses and not token.nbor(1).is_punct:  # identify places where spaCy misses a sentence end, e.g. "The length was 2m."
                logging.debug("Candidate pause: %d %s %s", token.i, token.text, token.nbor(1).text)
                #_fragment_at(token, reason="verbal pause")  # detection only; fragmenting at pauses is currently disabled
        except IndexError:
            continue  # can be safely ignored

def _scan_entities(doc: Doc):
    for entity in doc.ents:
        if len(entity) < 2 or entity.label_ in ['PERSON', 'ORDINAL', 'PERCENT', 'TIME', 'CARDINAL'] or len(entity.text) < 10:
            continue
        token = entity[0]
        if token.i < 1:
            continue
        prior = token.nbor(-1)
        if (not prior.is_punct and
            prior.pos_ not in ['DET'] and
            not prior._.can_fragment_after):  # don't overwrite prior reason
            _fragment_at(prior, reason="entity->")
        try:
            after = entity[-1].nbor(1)
        except IndexError:
            continue  # entity ends the document
        if (after.pos_ != "PART" and
            not after.is_punct and  # e.g. "UK's ..." where entity is possessive
            not entity[-1]._.can_fragment_after):
            _fragment_at(entity[-1], reason="entity")

def _scan_noun_phrases(doc: Doc):
    for chunk in doc.noun_chunks:
        if len(chunk) < 2:
            continue
        token = chunk[0]
        if token.i > 0 and not token.is_punct:
            prior = token.nbor(-1)
            # Don't interfere with a prior decision of ADP or S/C CONJ at the start of noun phrases
            if (prior.pos_ not in ['ADP', 'SCONJ', 'CCONJ'] and
                not prior._.can_fragment_after and  # don't overwrite prior reason
                not prior.is_punct):  # e.g. don't fragment at a hyphen or the period of the previous sentence
                _fragment_at(prior, reason="NP->")
        try:
            after = chunk[-1].nbor(1)
            if (not after.is_punct and
                not chunk[-1]._.can_fragment_after):
                _fragment_at(chunk[-1], reason="NP")
        except IndexError:
            continue  # can safely ignore

def leading_whitespace(s: str) -> int:
    return len(s) - len(s.lstrip())

def load_whisper_json(file: str) -> tuple[str, dict]:
    doc_timing = {}
    doc_text = ""
    with open(file) as js:
        jsdata = json.load(js)
    logging.debug("Language: %s", jsdata['language'])
    for s in jsdata['segments']:
        if 'words' not in s:
            raise ValueError('JSON input file must contain word timestamps')
        for word_timed in s['words']:
            word = word_timed['word']
            if len(doc_text) == 0:
                word = word.lstrip()  # remove any leading whitespace from the first Whisper word
            doc_text += word
            # align the timing index with the spaCy token index
            start_index = len(doc_text) - len(word) + leading_whitespace(word)
            doc_timing[start_index] = (word_timed['start'], word_timed['end'])
    return doc_text, doc_timing
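
# Expected input shape, as read above (a sketch; real Whisper JSON carries more fields):
# {
#   "language": "en",
#   "segments": [
#     {"words": [{"word": " Hello", "start": 0.0, "end": 0.42},
#                {"word": " world", "start": 0.42, "end": 0.88}]}
#   ]
# }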

# Greedy scan from right to left for preferred grammar and line occupancy
def preferred_division_for(span: Span, max_width: int) -> int:
    # token.i: index of the token within the parent document
    # token.idx: character offset of the token within the parent document
    # Span.start_char: character offset for the start of the span
    # Span.end_char: character offset for the end of the span
    def is_grammatically_preferred(token: Token):
        return token._.can_fragment_after and (
            token._.fragment_reason in ['punct', 'conj', 'clause', 'entity', 'entity->', 'v_adj'])
    preferreds = (t for t in reversed(span) if is_grammatically_preferred(t))
    target_width = round(0.7 * max_width)
    for tp in preferreds:
        width = tp.idx + len(tp) - span.start_char
        if width > max_width:
            continue
        remainder_width = span.end_char - tp.idx - len(tp)
        if width <= remainder_width and width >= max_width / 3 and remainder_width <= max_width:  # prefer a pyramid that finishes the sentence
            logging.debug("Primary complete %s %s at %d : '%s'", tp.text, tp._.fragment_reason, tp.idx - span.start_char, span.text)
            return tp.i
        if width >= target_width and width <= remainder_width * 1.2:  # no more than 20% bigger than the remaining text in the sentence
            logging.debug("Primary selected %s %s at %d : '%s'", tp.text, tp._.fragment_reason, tp.idx - span.start_char, span.text)
            return tp.i
    logging.debug("No primary for '%s'", span.text)
    return 0
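
# Worked example (illustrative numbers): with max_width=42, target_width = round(0.7 * 42) = 29.
# A break giving width=20 and remainder_width=40 passes the first test (20 <= 40,
# 20 >= 14, 40 <= 42): a pyramid that finishes the sentence on the next line.
# A break giving width=40 and remainder_width=35 fails it (40 > 35) but passes the
# second test (40 >= 29 and 40 <= 35 * 1.2 = 42), giving a near-balanced split.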

# Secondary scan from left to right for fragmenting words, or degrade to a word within the line maximum
def secondary_division_for(span: Span, max_width: int) -> int:
    token_divider = 0  # index of token in parent doc
    start_index = span.start_char
    for token in span:
        if token.i == span[0].i:
            continue
        token_start = token.idx - start_index  # character offset from start of span
        if token_divider and token_start > max_width:  # do we have a prior divider, and is this token over budget?
            break
        token_end = token_start + len(token)
        if token._.can_fragment_after and token_end <= max_width and token.i + 2 < span[-1].i:  # select token to divide at
            token_divider = token.i
            if span.end_char - token.idx - len(token) <= max_width:  # look ahead to see if the Span remainder fits within max_width
                break  # prefer pyramid
        if not token_divider and token_end > max_width:  # if no grammatical option found, fall back to the last token within budget
            token_divider = token.i - 1 if token.pos_ != 'PUNCT' else token.i - 2  # don't orphan a period at end of sentence
            logging.info("Forced division after word '%s' : '%s'", span.doc[token_divider].text, span.text)
    return token_divider

def divide_span(span: Span, args) -> Iterator[Span]:
    max_width = args.width
    if span.end_char - span.start_char <= max_width:
        yield span
        return  # end the generator if the length is ok
    divider = preferred_division_for(span, max_width) or secondary_division_for(span, max_width)
    after_divider = divider + 1
    yield span.doc[span.start:after_divider]  # up to divider, inclusive
    if after_divider < span.end:
        yield from divide_span(span.doc[after_divider:span.end], args)

def iterate_document(doc: Doc, timing: dict, args):
    max_lines = args.lines
    for sentence in doc.sents:
        for chunk in chunked(divide_span(sentence, args), max_lines):
            subtitle = '\n'.join(line.text for line in chunk)
            sub_start, _ = chunk[0]._.get_time_span(timing)
            _, sub_end = chunk[-1]._.get_time_span(timing)
            yield sub_start, sub_end, subtitle

def write_srt(doc, timing, args):
    comma: str = ','
    for i, (start, end, text) in enumerate(iterate_document(doc, timing, args), start=1):
        ts1 = format_timestamp(start, always_include_hours=True, decimal_marker=comma)
        ts2 = format_timestamp(end, always_include_hours=True, decimal_marker=comma)
        print(f"{i}\n{ts1} --> {ts2}\n{text}\n")

def configure_spaCy(model: str, entities: str, pauses: list | None = None):
    if model.startswith('xx'):
        raise NotImplementedError("spaCy multilanguage models are not currently supported")
    nlp = spacy.load(model)
    nlp.add_pipe("fragmenter", config={"verbal_pauses": pauses or []}, last=True)
    if len(entities) > 0:
        nlp.add_pipe("entity_ruler", before='ner').from_disk(entities)
    return nlp

def scan_for_pauses(timing: dict) -> list[int]:
    pauses = []
    # requires Python 3.10 (itertools.pairwise; the match statement above does too)
    # Look at the silence between the end of one word and the start of the next
    for (k1, (_, end)), (_, (start, _)) in pairwise(sorted(timing.items())):
        gap = start - end
        if gap > 0.5:
            pauses.append(k1)  # the key in the timing dict before the pause
    return pauses
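
# Example (illustrative): pairwise yields overlapping neighbours, e.g.
# pairwise(sorted({0: (0.0, 0.4), 6: (1.2, 1.6)}.items())) -> ((0, (0.0, 0.4)), (6, (1.2, 1.6)))
# where the 0.8 s gap between end=0.4 and start=1.2 marks a verbal pause after offset 0.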

def main():
    parser = argparse.ArgumentParser(
        prog='subwisp',
        description='Convert a whisper .json transcript into .srt subtitles with sentences, grammatically separated where possible.',
        epilog='')
    parser.add_argument('input_file')
    parser.add_argument('-m', '--model', help='specify spaCy model', default="en_core_web_lg")
    parser.add_argument('-e', '--entities', help='optional custom entities for spaCy (.jsonl format)', default="")
    parser.add_argument('-w', '--width', help='maximum line width', default=42, type=int)
    parser.add_argument('-l', '--lines', help='maximum lines per subtitle', default=2, type=int, choices=range(1, 4))
    parser.add_argument('-d', '--debug', help='print debug information',
                        action="store_const", dest="loglevel", const=logging.DEBUG, default=logging.WARNING)
    parser.add_argument('--verbose', help='be verbose',
                        action="store_const", dest="loglevel", const=logging.INFO)
    args = parser.parse_args()
    logging.basicConfig(level=args.loglevel)
    if not os.path.isfile(args.input_file):
        logging.error("File not found: %s", args.input_file)
        sys.exit(1)
    if not args.model:
        logging.error("No spaCy model specified")
        sys.exit(1)
    if len(args.entities) > 0 and not os.path.isfile(args.entities):
        logging.error("Entities file not found: %s", args.entities)
        sys.exit(1)
    wtext, word_timing = load_whisper_json(args.input_file)
    verbal_pauses = scan_for_pauses(word_timing)
    nlp = configure_spaCy(args.model, args.entities, verbal_pauses)
    doc = nlp(wtext)
    write_srt(doc, word_timing, args)

if __name__ == '__main__':
    main()