Skip to content

Instantly share code, notes, and snippets.

@Lauler
Created August 15, 2023 06:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Lauler/ccfa87faf006144209b3d4eda6b042fe to your computer and use it in GitHub Desktop.
Save Lauler/ccfa87faf006144209b3d4eda6b042fe to your computer and use it in GitHub Desktop.
Preprocess SRT subtitle files and bucket their blocks into roughly 30-second chunks.
import numpy as np
import os
import pandas as pd
import pysrt
import argparse
from tqdm import tqdm
# Command-line interface for the preprocessing script.
parser = argparse.ArgumentParser()
_cli_options = {
    "--data_dir": dict(
        type=str,
        default="kb_exempel_2",
        help="Directory containing subdirectories with audio and srt files.",
    ),
    "--output_file": dict(
        type=str,
        default="subs_preprocessed.parquet",
        help="Name of output file.",
    ),
}
for _flag, _kwargs in _cli_options.items():
    parser.add_argument(_flag, **_kwargs)
args = parser.parse_args()
# Index every subdirectory of data_dir together with its .wav and .srt file.
df = pd.DataFrame({"subdir": os.listdir(args.data_dir)})
files = {
    subdir: os.listdir(os.path.join(args.data_dir, subdir))
    for subdir in df["subdir"]
}
df["files"] = df["subdir"].map(files)


def _first_with_suffix(names, suffix):
    # Assumes each subdirectory holds at least one file of each kind;
    # raises IndexError (exactly like the original indexing) when none exists.
    return [name for name in names if name.endswith(suffix)][0]


df["audio"] = df["files"].map(lambda names: _first_with_suffix(names, ".wav"))
df["srt"] = df["files"].map(lambda names: _first_with_suffix(names, ".srt"))
df = df.drop(columns=["files"])
# Read every srt file listed in df; one output row per subtitle block,
# carrying along the subdirectory and matching audio/srt file names.
df_subs = []
for subdir, srt, audio in df[["subdir", "srt", "audio"]].itertuples(index=False):
    subs = pysrt.open(os.path.join(args.data_dir, subdir, srt))
    rows = [
        {
            "subdir": subdir,
            "start": block.start,
            "end": block.end,
            "text": block.text,
            "srt": srt,
            "audio": audio,
        }
        for block in subs
    ]
    df_subs.append(pd.DataFrame(rows))
df_subs = pd.concat(df_subs).reset_index(drop=True)
# Convert srt timestamps (pysrt SubRipTime objects) to integer milliseconds.
def _to_milliseconds(t):
    """Flatten an hours/minutes/seconds/milliseconds timestamp to total ms."""
    return ((t.hours * 60 + t.minutes) * 60 + t.seconds) * 1000 + t.milliseconds


df_subs["start_ms"] = df_subs["start"].map(_to_milliseconds)
df_subs["end_ms"] = df_subs["end"].map(_to_milliseconds)
# Duration of each individual subtitle block, in seconds.
df_subs["duration_s"] = (df_subs["end_ms"] - df_subs["start_ms"]) / 1000
# Divide the subtitle blocks of each audio file into ~30 second buckets.
# A new bucket starts as soon as the elapsed time since the bucket's first
# block reaches 30 seconds. NOTE: bucket numbers (observation_nr) restart
# at 0 for every audio file, so downstream grouping must include "audio".
df_groups = []
grouped = df_subs.groupby("audio")  # compute once instead of twice
for group, df_group in tqdm(grouped, total=grouped.ngroups):
    # Work on a copy so we never mutate the groupby's view of df_subs.
    df_group = df_group.copy()
    start = df_group["start_ms"].iloc[0]
    bucket_nr = 0
    bucket_cumsum = []  # elapsed seconds within the current bucket, per block
    bucket_nrs = []  # bucket number assigned to each block
    for i, end in enumerate(df_group["end_ms"]):
        if ((end - start) / 1000) >= 30:
            # Close the current bucket; the new one starts at this block.
            # (The original had a bare debug f-string here that referenced
            # prev_segment_length before assignment — NameError if the first
            # subtitle block already exceeded 30 s. Removed.)
            bucket_nr += 1
            start = df_group["start_ms"].iloc[i]
        bucket_cumsum.append((end - start) / 1000)
        bucket_nrs.append(bucket_nr)
    df_group["observation_nr"] = bucket_nrs
    df_group["bucket_cumsum"] = bucket_cumsum
    df_groups.append(df_group)
df_groups = pd.concat(df_groups).reset_index(drop=True)
# Per-bucket statistics. Bucket numbers (observation_nr) restart at 0 for
# every audio file, so every grouping here must key on BOTH audio and
# observation_nr — grouping on observation_nr alone (as the original did)
# merged same-numbered buckets from different audio files.
bucket_key = ["audio", "observation_nr"]
# The bucket's duration is the largest cumulative elapsed time seen inside it.
df_groups["bucket_duration_s"] = df_groups.groupby(bucket_key)["bucket_cumsum"].transform("max")
# Start and end times of each subtitle block relative to its bucket's start.
bucket_start_ms = df_groups.groupby(bucket_key)["start_ms"].transform("min")
df_groups["start_relative"] = df_groups["start_ms"] - bucket_start_ms
df_groups["end_relative"] = df_groups["end_ms"] - bucket_start_ms
# Round to nearest 20 ms (Whisper quantizes its timestamps to 20 ms steps),
# then convert from milliseconds to seconds.
df_groups["start_relative"] = np.round(df_groups["start_relative"] / 20) * 20 / 1000
df_groups["end_relative"] = np.round(df_groups["end_relative"] / 20) * 20 / 1000
# Absolute start/end (in ms) of the bucket each block belongs to.
df_groups["start_bucket"] = bucket_start_ms
df_groups["end_bucket"] = df_groups.groupby(bucket_key)["end_ms"].transform("max")
def format_timestamp(timestamp):
    """Render a time in seconds as a Whisper-style token, e.g. ``<|1.00|>``."""
    return f"<|{timestamp:.2f}|>"
# Wrap each subtitle block's text in its relative start/end timestamp tokens.
df_groups["start_timestamp"] = df_groups["start_relative"].map(format_timestamp)
df_groups["end_timestamp"] = df_groups["end_relative"].map(format_timestamp)
df_groups["text_timestamps"] = (
    df_groups["start_timestamp"] + df_groups["text"] + df_groups["end_timestamp"]
)
# Join the timestamped texts of each bucket into one training target.
# Group by audio AND observation_nr: bucket numbers restart at 0 for every
# audio file, so grouping on observation_nr alone (as the original did)
# would concatenate text from different audio files into one bucket.
df_groups["text_timestamps_bucket"] = df_groups.groupby(["audio", "observation_nr"])[
    "text_timestamps"
].transform(lambda texts: " ".join(texts))
# Persist only the columns consumed downstream.
OUTPUT_COLUMNS = [
    "subdir",
    "audio",
    "observation_nr",
    "start_ms",
    "end_ms",
    "start_bucket",
    "end_bucket",
    "text",
    "text_timestamps_bucket",
    "start_relative",
    "end_relative",
    "bucket_duration_s",
]
df_groups[OUTPUT_COLUMNS].to_parquet(args.output_file, index=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment