Skip to content

Instantly share code, notes, and snippets.

@georgepar
Last active June 3, 2019 07:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save georgepar/09ea1ca5e9933fd52840c663ae41245b to your computer and use it in GitHub Desktop.
Save georgepar/09ea1ca5e9933fd52840c663ae41245b to your computer and use it in GitHub Desktop.
import os
import requests
from torch.utils.data import Dataset
from silx.io.dictdump import h5todict
def download_file(url, fname, chunk_size=1 << 20, timeout=60):
    """Download ``url`` to the local path ``fname`` and return ``fname``.

    The response body is streamed to disk in ``chunk_size``-byte pieces, so
    large files are never held fully in memory. Data is written first to
    ``fname + '.part'`` and atomically renamed to ``fname`` only after the
    whole body has been received. An interrupted download therefore never
    leaves a truncated ``fname`` that a later ``os.path.isfile`` check would
    mistake for a complete file. A stale ``.part`` file is simply overwritten
    on the next attempt.

    Args:
        url (str): URL to fetch.
        fname (str): destination path; missing parent directories are created.
        chunk_size (int): number of bytes requested per streamed chunk.
        timeout (float or None): requests connect/read timeout in seconds;
            ``None`` waits forever.

    Returns:
        str: ``fname``.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
    """
    parent = os.path.dirname(fname)
    if parent:
        os.makedirs(parent, exist_ok=True)
    tmp_name = fname + '.part'
    # Using the response as a context manager releases the connection even if
    # writing fails part-way through.
    with requests.get(url, stream=True, timeout=timeout) as resp:
        # Without this check an HTML error page would be saved as the file.
        resp.raise_for_status()
        with open(tmp_name, 'wb') as fd:
            # The default iter_content() chunk size is 1 byte; read in large chunks.
            for chunk in resp.iter_content(chunk_size=chunk_size):
                fd.write(chunk)
    os.replace(tmp_name, fname)
    return fname
class CMUMosi(Dataset):
    """PyTorch ``Dataset`` skeleton for the CMU-MOSI dataset.

    A loader for CMU-MOSEI should look very similar. The raw wavs still need
    to be downloaded separately and mapped to the file ids. No alignment
    between modalities is performed.

    NOTE(review): the label URLs point at CMU-MOSEI ``.csd`` files, and
    ``download_path`` defaults to ``./cmu_mosei``. The class name and the
    ``'Opinion Segment Labels'`` key read below refer to CMU-MOSI. Confirm
    which dataset is intended, and check that this key exists in the
    downloaded file for both tasks.

    NOTE: feature extraction is not implemented yet. ``self.data`` stays
    empty, so ``len(dataset) == 0`` until it is filled.

    Args:
        wavs_path (str): path to the raw wav files.
        download_path (str): directory where the label file is cached;
            it is created and the file downloaded if the file is missing.
        task (str): ``'sentiment'`` or ``'emotion'``; selects the label file.
            Any other value raises ``KeyError``.
    """

    def __init__(self, wavs_path, download_path='./cmu_mosei', task='sentiment'):
        self.wavs_path = wavs_path
        self.task = task
        label_urls = {
            "sentiment": "http://immortal.multicomp.cs.cmu.edu/CMU-MOSEI/labels/CMU_MOSEI_LabelsSentiment.csd",
            "emotion": "http://immortal.multicomp.cs.cmu.edu/CMU-MOSEI/labels/CMU_MOSEI_LabelsEmotions.csd",
        }
        # Cache the label file under its original name inside download_path.
        self.label_file = os.path.join(download_path,
                                       label_urls[task].split('/')[-1])
        if not os.path.isfile(self.label_file):
            # The download target directory may not exist yet.
            if download_path:
                os.makedirs(download_path, exist_ok=True)
            self.label_file = download_file(label_urls[task], self.label_file)
        label_dict = h5todict(self.label_file)['Opinion Segment Labels']['data']
        # Per the original author, each label is a dict of the form
        # {'features': numpy.array, 'intervals': numpy.array}. 'features' holds
        # the annotations; 'intervals' holds the start/end times of each label
        # in the video (TODO confirm against the .csd contents).
        # Take keys and values separately (same order as the dict). Unlike
        # zip(*items()), this also works when the dict is empty.
        self.file_ids = list(label_dict.keys())
        self.labels = list(label_dict.values())
        self.data = []  # TODO: read the files for self.file_ids and extract features

    def __len__(self):
        """Return the number of loaded samples (the length of ``self.data``)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(data, label)`` pair at position ``idx``."""
        return self.data[idx], self.labels[idx]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment