@isofew
Last active December 7, 2022 10:46
Align two Japanese subtitle files based on katakana reading match scores
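To adjust the timestamps of a mistimed subtitle file using a speech-recognition transcript as the timing reference, an invocation along these lines should work (the script and file names below are placeholders; the flags come from parse_args in the script):

python align_subs.py \
    --srt-source sample.whisper.srt \
    --srt-target sample.ground.srt \
    --srt-output sample.synced.srt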
#!/usr/bin/env python
# coding: utf-8
import argparse
import re
import string
from copy import deepcopy

import numpy as np
import pysubs2
import spacy
from scipy.signal import correlate

nlp = spacy.load('ja_ginza')

alnum = string.ascii_letters + string.digits
katakana = 'ァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチヂッツヅテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャヤュユョヨラリルレロワン'
# map each katakana character to an integer id; characters outside this
# table (e.g. the long vowel mark ー) are silently dropped by maybe_kid
kid = {k: i for i, k in enumerate(katakana)}
maybe_kid = lambda k: [kid[k]] if k in kid else []
def get_kids(s, bra_half=re.compile(r'\([^)]*\)'), bra_full=re.compile('（[^）]*）')):
    # strip parenthesised annotations, both half-width and full-width brackets
    s = bra_half.sub('', s)
    s = bra_full.sub('', s)
    kids = []
    for t in nlp(s):
        # skip symbols, punctuation, and pure-ASCII tokens
        if t.pos_ in ['SYM', 'PUNCT']:
            continue
        if '記号' in t.tag_:
            continue
        if all(c in alnum for c in t.orth_):
            continue
        # use only the first reading, if it exists
        for rr in t.morph.get('Reading'):
            for r in rr:
                kids += maybe_kid(r)
            break
    return kids
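# For instance, if ja_ginza tokenises '今日は' into 今日 (Reading キョウ) and
# は (Reading ハ), then get_kids('今日は') returns
# [kid['キ'], kid['ョ'], kid['ウ'], kid['ハ']] (actual readings are model-dependent)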
def kvec(kids, n_kids=len(kid), vad=1):
    # binary indicator vector over katakana ids, with one extra trailing
    # element set to `vad` that marks the line as containing speech at all
    v = np.zeros(n_kids + 1)
    v[-1] = vad
    for k in kids:
        v[k] = 1
    return v
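# e.g. kvec([kid['キ'], kid['ウ']]) is zero everywhere except at the two
# katakana positions and the final activity element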
def expand_vec(sub, mat, wnd):
    # paint each line's vector onto a timeline discretised into
    # wnd-millisecond windows; result shape is (duration / wnd, len(kid) + 1)
    x = np.zeros((int(sub[-1].end / wnd), len(kid) + 1))
    for s, v in zip(sub, mat):
        si = int(s.start / wnd)
        se = int(s.end / wnd)
        x[si:se, :] = v[None]
    return x
def weighted(s, l, a=10, start=None, end=None):
    # sigmoid window over the rows of s: close to full weight just after
    # `start`, falling through 0.5 at offset `l` and towards zero beyond it
    # a : steepness of the falloff (relative to the full length)
    # l : target length, in rows/windows
    if start is None:
        start = 0
    if end is None:
        end = len(s)
    L = len(s)
    idx = np.arange(L)
    x = idx - start
    p = 1 / (1 + np.exp(a / L * (x - l)))
    return s * (p * (idx >= start) * (idx < end))[:, None]
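# Rough sanity check (values computed by hand, assuming the defaults above):
# weighted(np.ones((6, 2)), l=2)[:, 0] is approximately
# [0.97, 0.84, 0.50, 0.16, 0.03, 0.01] -- early rows keep their weight,
# rows past the target length are suppressed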
def parse_args():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--srt-source", type=str, required=True, help="source srt file produced by a speech recognition program")
    parser.add_argument("--srt-target", type=str, required=True, help="target srt file whose timestamps will be adjusted")
    parser.add_argument("--srt-output", type=str, required=True, help="output srt file")
    parser.add_argument("--wnd", default=10, type=int, help="(ms) window size to speed up correlation calculation")
    parser.add_argument("--chunk-gap", default=15 * 1000, type=int, help="(ms) minimum gap between subtitle chunks")
    parser.add_argument("--final-len", default=60 * 1000, type=int, help="(ms) minimum length for the final subtitle chunk")
    parser.add_argument("--target-len", default=20 * 60 * 1000, type=int, help="(ms) target length for subtitle chunks")
    parser.add_argument("--merge-wnd", default=2 * 1000, type=int, help="(ms) maximum delta difference for merging adjacent chunks")
    parser.add_argument("--max-offset", default=60 * 1000, type=int, help="(ms) maximum time-shifting offset relative to the previous chunk")
    return parser.parse_args()
def main():
    args = parse_args()
    srt_source = args.srt_source  # e.g. '../sample.whisper.srt'
    srt_target = args.srt_target  # e.g. '../sample.ground.srt'
    srt_output = args.srt_output  # e.g. '../sample.srt'
    wnd = args.wnd  # ms window size to speed up correlation calculation
    chunk_gap = args.chunk_gap  # ms
    final_len = args.final_len  # ms
    target_len = args.target_len  # ms
    merge_wnd = int(args.merge_wnd / wnd)  # ms -> windows
    max_offset = int(args.max_offset / wnd)  # ms -> windows
    sub_source = pysubs2.load(srt_source)
    sub_target = pysubs2.load(srt_target)
    mat_source = np.stack([kvec(get_kids(s.text)) for s in sub_source])
    mat_target = np.stack([kvec(get_kids(s.text)) for s in sub_target])
    emat_source = expand_vec(sub_source, mat_source, wnd)
    emat_target = expand_vec(sub_target, mat_target, wnd)
    # chunk boundaries: long silences in the target, provided enough
    # subtitle remains afterwards to form a meaningful final chunk
    gaps = [
        i for i in range(1, len(sub_target))
        if sub_target[i].start - sub_target[i-1].end > chunk_gap
        and sub_target[-1].end - sub_target[i].start > final_len
    ]
    # one time shift (delta) per chunk, estimated by cross-correlating the
    # source timeline with a sigmoid-weighted window of the target timeline
    deltas = []
    for s in [0] + gaps:
        if s > 0 and len(deltas) > 0:
            # on top of the previous delta, the chunk can shift left (-) at
            # most into the gap that separates it from the previous sub
            min_offset = int(deltas[-1] - (sub_target[s].start - sub_target[s-1].end) / wnd)
        else:
            min_offset = None
        c = correlate(
            emat_source,
            weighted(emat_target, l=target_len / wnd, start=sub_target[s].start / wnd)
        ).sum(1)
        x0 = len(emat_target) - 1  # the lag that corresponds to zero delta
        if min_offset is not None:
            # anything before/below the minimum is zeroed out
            if x0 + min_offset > 0:
                c[:x0 + min_offset] = 0
            # anything after/above the maximum is zeroed out too
            if x0 + min_offset + max_offset > 0:
                c[x0 + min_offset + max_offset:] = 0
        else:
            # for the first delta, select within (-max_offset, +max_offset)
            if x0 - max_offset > 0:
                c[:x0 - max_offset] = 0
            # assumes max_offset is positive
            c[x0 + max_offset:] = 0
        d = c.argmax() - x0
        print('s', s, 'd', d, 'min_offset', min_offset)
        if len(deltas) == 0 or abs(deltas[-1] - d) > merge_wnd:
            merging_pos = len(deltas)
            deltas.append(d)
        else:
            # close to the previous delta: flatten the whole run to one value
            md = deltas[merging_pos:] + [d]
            mean = md[0]  # alternatively: int(sum(md) / len(md))
            print('>> merge with previous deltas', md, '=>', mean)
            for i in range(merging_pos, len(deltas)):
                deltas[i] = mean
            deltas.append(mean)
    sub_out = deepcopy(sub_target)
    endpoints = [0] + gaps + [len(sub_target)]
    for s, e, d in zip(endpoints[:-1], endpoints[1:], deltas):
        for i in range(s, e):
            sub_out[i].shift(ms=d * wnd)
    with open(srt_output, 'w') as f:
        sub_out.to_file(f, 'srt')

if __name__ == '__main__':
    main()
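The core alignment trick, in isolation: build a binary activity signal for each subtitle file, cross-correlate them, and read the shift off the argmax. A minimal sketch with toy data (all names here are illustrative, not part of the script):

import numpy as np
from scipy.signal import correlate

source = np.array([0, 0, 0, 1, 1, 0, 1, 0, 0, 0], dtype=float)
target = np.array([0, 1, 1, 0, 1, 0, 0, 0, 0, 0], dtype=float)

c = correlate(source, target)  # 'full' correlation, length 10 + 10 - 1
x0 = len(target) - 1           # the index corresponding to zero shift
delta = c.argmax() - x0        # 2: delay target by two steps to align it

The script above does the same thing per chunk, over one dimension per katakana character plus an activity bit, rather than a single binary signal.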