Created
November 14, 2022 10:50
-
-
Save saattrupdan/bb6c9c52d9f4b35258db2b2456d31224 to your computer and use it in GitHub Desktop.
Create Danish WIT
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Unpack the WIT dataset and extract the Danish samples.""" | |
from datasets.arrow_dataset import Example | |
from datasets.dataset_dict import DatasetDict | |
from datasets.load import load_dataset | |
from pathlib import Path | |
from tqdm.auto import tqdm | |
import re | |
def _read_indices(path):
    """Return the integer sample indices stored one per line in ``path``.

    Blank lines (e.g. a trailing empty line) are ignored instead of making
    ``int('')`` raise.
    """
    return [int(line) for line in path.read_text().splitlines() if line.strip()]


def _scan_for_danish_samples(wit, path_to_indices, start_idx, num_found):
    """Append the index of every Danish sample in ``wit[start_idx:]`` to a file.

    Args:
        wit:
            The (non-streaming) WIT dataset, indexable by integer position.
        path_to_indices (Path):
            Checkpoint file; one index is appended per Danish sample found.
        start_idx (int):
            First dataset index to inspect.
        num_found (int):
            Number of Danish samples already recorded; only used for the
            progress statistics.
    """
    total = len(wit)
    num_skipped = 0
    # `initial` makes the bar start at the resume position, so `pbar.n`
    # always equals the number of samples scanned so far (including earlier runs)
    with tqdm(total=total, initial=start_idx, desc="Extracting Danish samples") as pbar:
        # Open the checkpoint once for the whole scan instead of once per hit
        with open(path_to_indices, 'a') as f:
            for idx in range(start_idx, total):
                try:
                    sample = wit[idx]
                except Exception:
                    # Best effort: skip samples that fail to load. Catching
                    # `Exception` (not a bare except) lets KeyboardInterrupt /
                    # SystemExit still stop the scan. The bar is advanced so it
                    # stays in sync with `idx`.
                    num_skipped += 1
                    pbar.update()
                    continue
                if 'da' in sample['wit_features']['language']:
                    f.write(f"{idx}\n")
                    # Flush right away so the checkpoint survives an interruption
                    f.flush()
                    num_found += 1
                    pbar.set_postfix_str(
                        f"Found {num_found} Danish samples "
                        f"({num_found / (1 + pbar.n):.2%})"
                    )
                pbar.update()
    if num_skipped:
        tqdm.write(f"Skipped {num_skipped} samples that could not be loaded")


def main():
    """Extract the Danish WIT samples, split them and push them to the Hub.

    The indices of Danish samples are checkpointed to
    ``danish_wit_samples.txt`` (one index per line) as they are found. A rerun
    resumes scanning right after the largest index already recorded. The
    selected samples are flattened with `extract_danish` and split into
    train / val (256) / test (1024). The result is pushed to
    ``alexandrainst/danish-wit``.

    Raises:
        ValueError:
            If too few Danish samples were found to build the val/test splits
            and a non-empty train split.
    """
    # Text file where the indices of the Danish samples are checkpointed
    path_to_indices = Path("danish_wit_samples.txt")
    path_to_indices.touch(exist_ok=True)
    indices_already_found = _read_indices(path_to_indices)

    # Load the dataset in non-streaming mode: `len(wit)`, `wit[idx]` and
    # `wit.select(...)` below all require random access
    wit = load_dataset("wikimedia/wit_base", split='train', streaming=False)

    # Resume right after the last Danish sample recorded in the checkpoint
    start_idx = max(indices_already_found) + 1 if indices_already_found else 0
    _scan_for_danish_samples(
        wit, path_to_indices, start_idx, len(indices_already_found)
    )

    # Re-read the checkpoint, which now also holds the indices found this run
    indices = _read_indices(path_to_indices)

    num_val, num_test = 256, 1024
    if len(indices) <= num_val + num_test:
        raise ValueError(
            f"Found only {len(indices)} Danish samples; need more than "
            f"{num_val + num_test} to build the train/val/test splits."
        )

    # Select the Danish samples and flatten them to their Danish content
    danish_wit = wit.select(indices).map(
        function=extract_danish, remove_columns=['wit_features'],
    )

    # Split off val+test first, then split that part into val and test
    train_valtest = danish_wit.train_test_split(test_size=num_val + num_test)
    val_test = train_valtest['test'].train_test_split(test_size=num_test)
    danish_wit = DatasetDict(dict(
        train=train_valtest['train'],
        val=val_test['train'],
        test=val_test['test'],
    ))

    danish_wit.push_to_hub(repo_id='alexandrainst/danish-wit')
def extract_danish(sample: "Example") -> "Example":
    """Flatten a WIT sample to its Danish-language fields.

    Each entry of ``sample['wit_features']`` other than ``language`` holds one
    value per language. The value at the position of ``'da'`` in
    ``sample['wit_features']['language']`` is copied to a top-level key of the
    same name.

    ``sample['caption_attribution_description']`` is then reduced to its Danish
    part, i.e. the text following ``"Dansk: "`` on that line. That text is cut
    at the next capitalised ``"Word:"`` label (e.g. ``"English:"``) and
    stripped. A description of None stays None. A description without a
    ``"Dansk: "`` part is replaced by None.

    Args:
        sample (Example):
            A sample from the WIT dataset.

    Returns:
        Example:
            The same sample, updated in place with the Danish texts.

    Raises:
        ValueError:
            If the sample has no Danish (``'da'``) entry in its WIT features.
    """
    languages = sample['wit_features']['language']
    try:
        da_idx = languages.index('da')
    except ValueError as err:
        raise ValueError(
            f"Sample has no Danish ('da') WIT features; languages: {languages}"
        ) from err

    # Copy the Danish value of every per-language WIT feature to the top level
    for key, values in sample['wit_features'].items():
        if key != 'language':
            sample[key] = values[da_idx]

    desc = sample['caption_attribution_description']
    # A missing (None) description is left untouched
    if desc is not None:
        # `.` does not match newlines, so only the rest of the "Dansk: " line is kept
        match = re.search(pattern=r'(?<=Dansk: )(.+)', string=desc)
        if match is None:
            # No Danish text present: store None rather than the non-Danish text
            sample['caption_attribution_description'] = None
        else:
            # Drop any following language label (e.g. "English: ...") and its text
            sample['caption_attribution_description'] = re.sub(
                pattern=r'[A-ZÆØÅ][a-zæøå]+\:.*',
                repl='',
                string=match.group(1),
            ).strip()
    return sample
if __name__ == "__main__":
    # Run the extraction pipeline only when executed as a script, not on import.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment