@nvbn
Created August 29, 2018 22:33
Bob's Burgers to The Simpsons with TensorFlow
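
# Overview: read subtitles for every Simpsons episode under `root`, read a scene
# from the Bob's Burgers episode "Tina-rannosaurus Wrecks", embed both sets of
# caption texts with the Universal Sentence Encoder, pick a semantically similar
# Simpsons line for each Bob's Burgers line, then cut the matching Simpsons
# clips with ffmpeg and concatenate them into output.mp4.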
from pathlib import Path
from typing import NamedTuple
from collections import defaultdict
from datetime import timedelta
from subprocess import call

from pycaption.srt import SRTReader
import lxml.html
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np

lang = 'en-US'

# Point `root` at a directory of season folders with matching .srt/.mp4 pairs,
# and `output_dir` at a scratch directory for the cut clips (both blank here).
output_dir = ''
root = Path('')


class Caption(NamedTuple):
    path: str    # path to the .srt file the caption came from
    start: int   # start time in microseconds (pycaption's unit)
    length: int  # duration in microseconds
    text: str    # caption text with markup stripped


def to_text(raw_text):
    """Strip markup tags and newlines from a raw caption, returning plain text."""
    if not raw_text:
        return ''
    raw_text = raw_text.replace('\n', ' ')
    return lxml.html.document_fromstring(raw_text).text_content()
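# For example, to_text('<i>Oh, no.</i>\nOkay.') comes back as 'Oh, no. Okay.'
# (illustrative input; the exact markup varies between subtitle files).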


def _read_subtitles(path, offset=0):
    with open(path, 'rb') as f:
        data = f.read().decode()[offset:]

    raw_captions = SRTReader().read(data, lang=lang).get_captions(lang)
    for raw_caption, next_raw_caption in zip(raw_captions, raw_captions[1:] + [None]):
        if next_raw_caption:
            # Stretch each caption up to the start of the next one, so pauses
            # between lines count towards the clip length.
            length = next_raw_caption.start - raw_caption.start
        else:
            length = raw_caption.end - raw_caption.start

        yield Caption(
            path=path,
            start=raw_caption.start,
            length=length,
            text=to_text(raw_caption.get_text()),
        )


def read_subtitles(path):
    # Materialise the generator so parse errors surface here; if the file
    # fails to parse, retry while skipping its first character (e.g. a BOM).
    try:
        return list(_read_subtitles(path, 0))
    except Exception:
        return list(_read_subtitles(path, 1))


# Group every Simpsons caption by its cleaned text, so identical lines from
# different episodes land in the same bucket.
data_text2captions = defaultdict(list)
for season in root.glob('*'):
    if season.is_dir():
        for subtitles in season.glob('*.srt'):
            print(subtitles)
            try:
                for caption in read_subtitles(subtitles.as_posix()):
                    data_text2captions[caption.text].append(caption)
            except Exception:
                print('pass', subtitles)

data_texts = [*data_text2captions]
print('got data texts')


# Tina-rannosaurus Wrecks
# https://www.opensubtitles.org/en/subtitles/5643476/bob-s-burgers-tina-rannosaurus-wrecks-en
# https://www.youtube.com/watch?v=hZ_EKHGgWJQ
# Keep only captions 1-53 of the episode.
play = [*read_subtitles('Bobs.Burgers.S03E07.HDTV.XviD-AFG.srt')][1:54]

play_text2captions = defaultdict(list)
for caption in play:
    play_text2captions[caption.text].append(caption)

play_texts = [*play_text2captions]
print('got play texts')


module_url = "https://tfhub.dev/google/universal-sentence-encoder/2"
embed = hub.Module(module_url)
print('got module')

vec_a = tf.placeholder(tf.float32, shape=None)
vec_b = tf.placeholder(tf.float32, shape=None)

# For evaluation we use exactly normalized rather than
# approximately normalized embeddings.
normalized_a = tf.nn.l2_normalize(vec_a, axis=1)
normalized_b = tf.nn.l2_normalize(vec_b, axis=1)
sim_scores = -tf.acos(tf.reduce_sum(tf.multiply(normalized_a, normalized_b), axis=1))
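# The dot product of two unit vectors is their cosine similarity, so acos(...)
# is the angular distance between the embeddings. Negating it turns the score
# into "larger means more similar", which the descending sort below relies on.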


def get_similarity_score(text_vec_a, text_vec_b):
    emba, embb, scores = session.run(
        [normalized_a, normalized_b, sim_scores],
        feed_dict={
            vec_a: text_vec_a,
            vec_b: text_vec_b,
        })
    return scores


def get_most_similar_text(vec_a, data_vectors):
    # Index 3 of the descending sort: the 4th-closest data caption, not the closest.
    scores = get_similarity_score([vec_a] * len(data_texts), data_vectors)
    return data_texts[sorted(enumerate(scores), key=lambda score: -score[1])[3][0]]


with tf.Session() as session:
    session.run([tf.global_variables_initializer(), tf.tables_initializer()])

    data_vecs, play_vecs = session.run([embed(data_texts), embed(play_texts)])
    data_vecs = np.array(data_vecs).tolist()
    play_vecs = np.array(play_vecs).tolist()
    print('got vecs')

    similar_texts = {play_text: get_most_similar_text(play_vecs[n], data_vecs)
                     for n, play_text in enumerate(play_texts)}
    print('got similarity')
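
# Optional sanity check (not part of the original pipeline): print a few of
# the matches to eyeball them before spending time in ffmpeg.
for play_text in play_texts[:5]:
    print(repr(play_text), '->', repr(similar_texts[play_text]))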


class Part(NamedTuple):
    video: str   # source .mp4 to cut from
    start: str   # start timestamp for ffmpeg's -ss, as H:MM:SS.mmm
    end: str     # clip duration for ffmpeg's -t, as H:MM:SS.mmm
    output: str  # where to write the cut clip


def generate_parts():
    for n, caption in enumerate(play):
        similar = similar_texts[caption.text]
        # Of the Simpsons captions with that text, take the one whose duration
        # differs most from the play caption's duration.
        similar_caption = sorted(
            data_text2captions[similar],
            key=lambda maybe_similar: abs(caption.length - maybe_similar.length),
            reverse=True)[0]

        yield Part(
            video=similar_caption.path.replace('.srt', '.mp4'),
            start=str(timedelta(microseconds=similar_caption.start))[:-3],
            end=str(timedelta(microseconds=similar_caption.length))[:-3],
            output=Path(output_dir).joinpath(f'part_{n}.mp4').as_posix())


parts = [*generate_parts()]
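
# Cut each matched clip out of its source episode: -ss seeks to the clip's
# start timestamp and -t limits it to the caption's duration; every clip is
# re-encoded with the same codecs and frame rate so the concat step below can
# join them cleanly.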
for part in parts:
    call(['ffmpeg', '-y', '-i', part.video,
          '-ss', part.start, '-t', part.end,
          '-c:v', 'libx264', '-c:a', 'aac', '-strict', 'experimental',
          '-vf', 'fps=30',
          '-b:a', '128k', part.output])

# Write the list file for ffmpeg's concat demuxer and join the clips;
# -safe 0 lets the demuxer accept the clip paths exactly as written in concat.txt.
concat = '\n'.join(f"file '{part.output}'" for part in parts) + '\n'
with open('concat.txt', 'w') as f:
    f.write(concat)

call(['ffmpeg', '-y', '-safe', '0', '-f', 'concat', '-i', 'concat.txt',
      '-c:v', 'libx264', '-c:a', 'aac', '-strict', 'experimental',
      '-vf', 'fps=30', 'output.mp4'])