Skip to content

Instantly share code, notes, and snippets.

@kingjr
Last active May 11, 2023 10:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kingjr/e519015c7bdb34eba3c1e15eb59583a1 to your computer and use it in GitHub Desktop.
Save kingjr/e519015c7bdb34eba3c1e15eb59583a1 to your computer and use it in GitHub Desktop.
to put in bm/studies/
import itertools
import json
import os
import re
import typing as tp
from collections import defaultdict
from pathlib import Path
import datalad
import datalad.api
import mne
import nibabel
import numpy as np
import pandas as pd
import tqdm
from . import api, utils
# Tasks removed by get_task_df(exclude=True): the two scrambled
# "notthefall" variants and "schema".
EXCLUDE_TASKS = (
    "notthefallshortscram",
    "notthefalllongscram",
    "schema",
)
class StudyPaths(utils.StudyPaths):
    """Filesystem layout of the Narratives study.

    Every path hangs off ``self.folder`` (presumably set by the parent
    ``utils.StudyPaths`` from the study name passed to it).
    """

    def __init__(self) -> None:
        super().__init__(Narrative2020Recording.study_name())
        root = self.folder  # e.g. Path("/datasets01/hasson_narratives/")
        self.narratives_path = root
        self.stimuli_path = root / "stimuli"
        # gentle forced-alignment outputs, one sub-folder per task
        self.gentle_path = self.stimuli_path / "gentle"
        self.scan_exclude_path = root / "code" / "scan_exclude.json"
        self.afni_dir = root / "derivatives" / "afni-nosmooth"
        # target directory (relative) used by Narrative2020Recording.download
        self.test_download = "draft_narrative"
class Narrative2020Recording(api.Recording):
    """Hasson-lab "Narratives" fMRI dataset, surface data on fsaverage6.

    Each instance corresponds to one row of ``get_task_df()``: a subject
    (``subject_uid``) listening to a story (``story``) during a run
    (``session``) in a given condition. That row is kept in ``self.info``.
    """

    data_url = "http://datasets.datalad.org/?dir=/labs/hasson/narratives/"
    paper_url = "https://www.nature.com/articles/s41597-021-01033-3"
    doi = "https://doi.org/10.6084/m9.figshare.14818587"
    licence = 'Creative Commons CC0 license which allows for free reuse without restriction.'
    modality = "audio"
    language = "en"
    device = "fmri"
    description = "345 subjects; 891 functional scans; 27 stories; each subject listened to a different set of stories"
    tr = 1.5  # seconds per volume; _load_raw samples at 1 / tr

    @classmethod
    def download(cls, n_subjects=None) -> None:
        """Install the Narratives datalad datasets and fetch the files used here.

        The top-level dataset, the ``afni-nosmooth`` and ``freesurfer``
        derivatives and (recursively) ``stimuli`` are installed without
        annexed data. Then ``fsaverage6``, ``code`` and the content of
        ``afni-nosmooth`` are fetched.

        Parameters
        ----------
        n_subjects : int or None
            Number of ``sub-*`` entries of ``afni-nosmooth`` to fetch, taken
            in sorted order so the subset is reproducible. ``None`` fetches
            all of them. Non-subject entries are always fetched.
        """
        paths = StudyPaths()
        path_dataset = "///labs/hasson/narratives"
        path_target = str(paths.test_download)
        Path(path_target).mkdir(exist_ok=True, parents=True)
        root = Path(path_target)
        afni_target = str(root / "derivatives" / "afni-nosmooth")
        freesurfer_target = str(root / "derivatives" / "freesurfer")
        stimuli_target = str(root / "stimuli")
        path_targets = [path_target, afni_target, freesurfer_target, stimuli_target]
        # The datalad source of each target mirrors its local sub-path.
        path_sources = [c.replace(path_target, path_dataset) for c in path_targets]
        recurrent = [False, False, False, True]
        for psource, ptarget, recursive in zip(path_sources, path_targets, recurrent):
            ds = datalad.api.Dataset(ptarget)
            if not ds.is_installed():
                # jobs="auto" corresponds to the number defined by the
                # 'datalad.runtime.max-annex-jobs' configuration item.
                datalad.api.install(
                    path=ptarget, source=psource, get_data=False,
                    description="Narratives dataset", recursive=recursive,
                    jobs="auto", branch=None)
            else:
                print(psource, " allready installed at ", ptarget)
        print("Finished downloading of metadata files")
        ds = datalad.api.Dataset(freesurfer_target)
        ds.get("fsaverage6")
        ds = datalad.api.Dataset(path_target)
        ds.get("code")
        print("Changing afni-nosmooth number of jobs")
        ds = datalad.api.Dataset(afni_target)
        config = datalad.config.ConfigManager(dataset=ds)
        # datalad reads the "datalad."-prefixed key; the unprefixed
        # "runtime.max-annex-jobs" was never picked up.
        config.set(var="datalad.runtime.max-annex-jobs", value="20")
        print("Downloading afni-nosmooth")
        entries = sorted(e for e in os.listdir(ds.path) if not e.startswith("."))
        subject_dirs = [e for e in entries if e.startswith("sub-")]
        other_entries = [e for e in entries if not e.startswith("sub-")]
        for d in tqdm.tqdm(other_entries + subject_dirs[:n_subjects]):
            print("Downloading ", d)
            ds.get(d)
        print("Ended download of afni-nosmooth")

    # pylint: disable=arguments-differ
    @classmethod
    def iter(cls) -> tp.Iterator["Narrative2020Recording"]:  # type: ignore
        """Yield one recording per row of ``get_task_df()``.

        Nothing is downloaded here; run ``download`` beforehand. The table is
        cached on the class as ``cls.df`` so ``__init__`` can find its row.
        """
        cls.df = get_task_df()
        for _, row in cls.df.iterrows():
            yield cls(subject_uid=row.subject,
                      story=row.task,
                      session=row.run,
                      condition=row.condition)

    def __init__(self, subject_uid: str,
                 session: str,
                 story: str,
                 condition: str,
                 ) -> None:
        """Select the unique ``df`` row matching the given identifiers.

        Raises AssertionError when zero or several rows match, or when the
        wav file or one of the two GIFTI files is missing.
        """
        recording_uid = "_".join([subject_uid, str(session), story, str(condition)])
        super().__init__(subject_uid=subject_uid, recording_uid=recording_uid)
        self.subject_uid = subject_uid
        self.story = story
        self.session = session
        # ``df`` is normally cached by ``iter``; build it on demand so a
        # recording can also be constructed directly.
        if getattr(type(self), "df", None) is None:
            type(self).df = get_task_df()
        rows = self.df.query(
            "subject==@subject_uid and task==@story and run==@session and condition==@condition")
        assert len(rows) == 1, f"expected exactly one matching row, got {len(rows)}"
        self.info = rows.iloc[0]
        for fpath in (self.info.wav_file, self.info.gii_fpath_left, self.info.gii_fpath_right):
            assert Path(fpath).exists(), f"missing file: {fpath}"

    def _load_raw(self) -> mne.io.RawArray:
        """Return both hemispheres' surface BOLD data as an ``mne`` RawArray.

        Each GIFTI data array becomes one row of a [T, V] matrix. The left
        and right hemispheres are concatenated along vertices ([T, 2 * V])
        and transposed to channels x time for MNE, sampled at ``1 / tr``.
        """
        hemispheres = []
        for hemi in ("left", "right"):
            gii = nibabel.load(self.info[f"gii_fpath_{hemi}"])
            hemispheres.append(np.vstack([da.data[np.newaxis, :] for da in gii.darrays]))
        bidata = np.concatenate(hemispheres, axis=1)  # [T, 2 * V]
        # NOTE(review): "mag" looks like a placeholder channel type; confirm.
        info = mne.create_info(bidata.shape[1], sfreq=1. / self.tr, ch_types="mag")
        return mne.io.RawArray(bidata.T, info=info)

    def _get_stimulus(self, task):
        """Load and annotate the gentle word alignment (align.csv) of ``task``.

        Adds the transcript columns of ``_preproc_stim`` and applies
        ``_fix_stimulus``. Missing onsets/offsets are linearly interpolated.
        ``start``, ``duration`` and ``condition="sentence"`` are then set.
        """
        paths = StudyPaths()
        stim_fname = paths.gentle_path / task / "align.csv"
        text_fname = paths.gentle_path / task / "transcript.txt"
        stim = pd.read_csv(stim_fname, names=["word", "word_low", "onset", "offset"])
        _preproc_stim(stim, text_fname, lower=False)
        _fix_stimulus(stim, task)
        stim["word_pp"] = stim["word_raw"]
        stim["sentence_id"] = stim["sequ_index"]
        stim["sequence_id"] = stim["sequ_index"]
        # some onset / offset are missing => interpolate
        stim[["onset", "offset"]] = stim[["onset", "offset"]].interpolate()
        stim["start"] = stim["onset"]
        stim["duration"] = stim["offset"] - stim["onset"]
        stim["condition"] = "sentence"
        return stim

    def _load_events(self, add_phones=False) -> pd.DataFrame:
        """Build the events table: one "sound" row, the words, optionally phones.

        Parameters
        ----------
        add_phones : bool
            Also append one "phoneme" event per phone from ``_get_phones``.

        Onsets and offsets are shifted by ``tr * info.stim_start_tr``, the
        time elapsed in the run before the audio starts.
        """
        events = self._get_stimulus(self.story)
        events["kind"] = "word"
        events["word"] = events["word_pp"]
        sound_row = pd.DataFrame([{
            "kind": "sound",
            "start": 0.,
            "filepath": str(self.info.wav_file),
            "offset": np.nanmax(events["offset"].values),
        }])
        events = pd.concat([sound_row, events], ignore_index=True)
        if add_phones:
            phone_df = _get_phones(self.story)
            phone_df["kind"] = "phoneme"
            events = pd.concat([events, phone_df], axis=0, ignore_index=True)
        # Realign with the beginning of the stimulus in the fMRI run.
        # NOTE(review): only "onset"/"offset" are shifted; "start" (set in
        # _get_stimulus and on the sound row) is not, and the sound row has
        # no "onset". Confirm which column utils._formatevents reads.
        shift = self.tr * self.info.stim_start_tr
        for time_col in ("onset", "offset"):
            events[time_col] += shift
        # NOTE(review): this overwrites the word durations from _get_stimulus.
        events["duration"] = 0.01
        return utils._formatevents(
            events,
            self.language,
            self.modality,
            block_by='sentence')
def _load_subjects_info():
    """Expand participants.tsv into one row per (participant, task).

    Returns a DataFrame with the columns:
    - subject: the participant id.
    - task: the audio name.
    - bold_task: the task name as listed in participants.tsv.
    - condition: overridden by the suffix of the task name for "notthefall*".
    - comprehension: a float, NaN for "n/a"; scores of "shapes" tasks are divided by 10.
    """
    participants = pd.read_csv(
        StudyPaths().narratives_path / "participants.tsv", sep="\t"
    ).astype("str")
    records = []
    for _, participant in participants.iterrows():
        per_task = zip(
            participant.task.split(","),
            participant.condition.split(","),
            participant.comprehension.split(","),
        )
        for bold_task, condition, score in per_task:
            if score == "n/a":
                comprehension = np.nan
            else:
                comprehension = float(score)
                if "shapes" in bold_task:
                    comprehension /= 10
            # milkyway and notthefall have different conditions, which
            # determines the name of the audio file.
            if bold_task.startswith("notthefall"):
                condition = bold_task.split("notthefall")[1]
                audio_task = bold_task
            elif bold_task == "milkyway":
                audio_task = bold_task + condition
            else:
                audio_task = bold_task
            records.append(
                {
                    "subject": participant.participant_id,
                    "task": audio_task,
                    "bold_task": bold_task,
                    "condition": condition,
                    "comprehension": comprehension,
                }
            )
    return pd.DataFrame(records)
def _get_wav_file(tasks):
    """Return the ``<task>_audio.wav`` path in the stimuli folder for each task.

    Raises AssertionError if any of these files does not exist.
    """
    stimuli_dir = StudyPaths().stimuli_path
    wav_paths = [stimuli_dir / f"{task}_audio.wav" for task in tasks]
    assert all(Path(wav_path).is_file() for wav_path in wav_paths)
    return wav_paths
def get_task_df(exclude=True, one_run_only=False, check_counts=True):
    """Return one row per recording, with both hemispheres' GIFTI files.

    Parameters
    ----------
    exclude : bool
        Drop scans flagged in ``scan_exclude.json`` and the tasks listed in
        ``EXCLUDE_TASKS``.
    one_run_only : bool
        Keep only the lowest run of each (subject, task, condition).
    check_counts : bool
        Sanity-check the result against the full dataset: 321 subjects,
        18 audio tasks and 16 BOLD tasks. Only applied when ``exclude`` is
        True, because keeping the ``EXCLUDE_TASKS`` tasks changes these
        counts. Set to False when working on a partial download.

    Returns
    -------
    pandas.DataFrame
        Shared columns:
        - ``subject``.
        - ``task``: the wav file name.
        - ``bold_task``: the fMRI file name.
        - ``condition``: whether the sound is scrambled, intact, etc.
        - ``run``: the session; only one per recording if ``one_run_only``.
        - ``wav_file`` and ``stim_start_tr``.

        The remaining per-hemisphere columns carry the suffixes
        ``_left``/``_right`` (e.g. ``gii_fpath_left``).
    """
    # Participants info (subject, task, condition etc.)
    subjects_df = _load_subjects_info()
    task_onsets = _get_task_onsets()
    bi_df = []
    for hemi in ("left", "right"):
        # Get gii files (+ mark excluded files)
        files_df = _get_gii_files_info(subjects_df, space="fsaverage6", hemi=hemi)
        df = pd.merge(subjects_df, files_df, on=["subject", "bold_task"], how="left")
        # Remove non-existing sessions
        df = df.dropna(subset=["gii_fname"])
        df = df.astype({"run": int})
        if exclude:
            df["exclude"] = df["exclude"].astype(bool)
            df = df.query("not exclude")
            df = df.query("task not in @EXCLUDE_TASKS")
        if one_run_only:
            df = (
                df.sort_values(["subject", "task", "run"])
                .groupby(["subject", "task", "condition"])
                .agg("first")
                .reset_index()
            )
        df["wav_file"] = _get_wav_file(df["task"].values)
        # Number of TRs recorded before the stimulus starts
        df["stim_start_tr"] = [task_onsets[task] for task in df["task"]]
        bi_df.append(df)
    df_out = pd.merge(
        bi_df[0],
        bi_df[1],
        on=['subject', 'task', 'condition', 'run', 'wav_file', 'stim_start_tr',
            'bold_task'],
        suffixes=("_left", "_right"),
    )
    if exclude and check_counts:
        n_subjects = df_out.subject.nunique()
        n_tasks = df_out.task.nunique()
        n_bold_tasks = df_out.bold_task.nunique()
        assert n_subjects == 321, f"expected 321 subjects, got {n_subjects}"
        assert n_tasks == 18, f"expected 18 tasks, got {n_tasks}"
        assert n_bold_tasks == 16, f"expected 16 bold tasks, got {n_bold_tasks}"
    return df_out
def _get_gii_files_info(subjects_df, space="fsaverage6", hemi="left"):
    """List the cleaned GIFTI BOLD files of every (subject, bold_task).

    Files matching
    ``<afni_dir>/<subject>/func/<subject>_task-<bold_task>_*space-<space>_hemi-<L|R>_desc-clean.func.gii``
    are returned one per row. The columns are:
    - ``subject``, ``bold_task``, ``gii_fname``, ``gii_fpath``.
    - ``run``: parsed from ``run-<x>_``, or 1 when absent.
    - ``exclude``: True when the file name contains one of the patterns
      listed for that task and subject in ``scan_exclude.json``.

    The frame keeps these columns even when no file is found.
    """
    paths = StudyPaths()
    hemi_code = hemi[0].upper()
    run_regex = re.compile(r"run-(\w*)_")
    records = []
    for _, row in subjects_df.iterrows():
        glob_pattern = (
            f"{row.subject}_task-{row.bold_task}_*"
            f"space-{space}_hemi-{hemi_code}_desc-clean.func.gii"
        )
        for file in (paths.afni_dir / row.subject / "func").glob(glob_pattern):
            runs = run_regex.findall(file.name)
            records.append({
                "subject": row.subject,
                "bold_task": row.bold_task,
                "gii_fname": file.name,
                "gii_fpath": file,
                "run": int(runs[0]) if runs else 1,
            })
    # Explicit columns keep the frame usable (and mergeable) when empty.
    files_df = pd.DataFrame(
        records, columns=["subject", "bold_task", "gii_fname", "gii_fpath", "run"]
    )
    files_df["run"] = files_df["run"].astype(int)
    # Check for excluded sessions: {bold_task: {subject: [patterns]}}
    with open(paths.scan_exclude_path, "r") as f:
        exclude_dic = json.load(f)
    files_df["exclude"] = [
        any(
            pattern in fname
            for pattern in exclude_dic.get(bold_task, {}).get(subject, [])
        )
        for subject, bold_task, fname in zip(
            files_df["subject"], files_df["bold_task"], files_df["gii_fname"]
        )
    ]
    return files_df
def _get_phones(task):
    """Return the phone-level alignment of ``task`` from gentle's align.json.

    Words without a "phones" entry are skipped. Each remaining word yields
    one row per phone, with these columns:
    - ``phone``.
    - ``onset``/``offset``: chained from the word's "start" by adding each
      phone's "duration".
    - ``phone_id``: the position of the phone within its word.
    - ``word``: the word's "word" field as stored in align.json.
    """
    json_name = StudyPaths().gentle_path / task / "align.json"
    with open(json_name, "r") as f:
        alignment = json.load(f)
    phones = []
    for word in alignment["words"]:
        if "phones" not in word:
            continue
        current = word["start"]
        for i, phone in enumerate(word["phones"]):
            phones.append({
                "phone": phone["phone"],
                "onset": current,
                "offset": current + phone["duration"],
                "phone_id": i,
                # previously the phone index ``i`` (a copy of phone_id)
                "word": word.get("word"),
            })
            current += phone["duration"]
    return pd.DataFrame(phones)
def _preproc_stim(df, text_fname, lower=False):
    """Annotate the gentle word alignment ``df`` in place with transcript info.

    The transcript in ``text_fname`` is normalized with ``format_text`` and
    tokenized twice: on spaces and with ``gentle_tokenizer``. Each gentle
    token is then mapped back to its space-separated transcript token.

    Columns added or overwritten:
    - ``word_raw``: the formatted transcript token.
    - ``is_eos`` / ``is_bos``: end- and beginning-of-sentence flags.
    - ``sequ_index``: the sentence index, starting at 0.
    - ``wordpos_in_seq``, ``wordpos_in_stim``, ``seq_len``.
    - ``is_eof`` / ``is_bof``: last- and first-word flags.
    - ``word``: NaN replaced by "".

    Raises AssertionError if the gentle token count differs from ``len(df)``.
    """
    with open(text_fname, encoding="utf-8") as f:
        text = f.read()
    text = format_text(text, lower=lower)
    transcript_tokens = space_tokenizer(text)
    gentle_tokens = gentle_tokenizer(text)
    assert len(gentle_tokens) == len(df)
    spans = match_transcript_tokens(transcript_tokens, gentle_tokens)
    assert len(spans) == len(gentle_tokens)
    tokens = format_tokens([span[0] for span in spans], lower=lower)
    df["word_raw"] = tokens
    # A token ends a sentence when it contains one of these marks.
    end_of_sentence_marks = (".", "!", "?")
    df["is_eos"] = [any(mark in tok for mark in end_of_sentence_marks) for tok in tokens]
    # A sentence begins at the first word and right after each sentence end.
    # (np.roll wrapped the last word's flag onto the first word, so a
    # transcript without final punctuation started at sequ_index -1.)
    df["is_bos"] = df["is_eos"].shift(1, fill_value=True)
    df["sequ_index"] = df["is_bos"].cumsum() - 1
    df["wordpos_in_seq"] = df.groupby("sequ_index").cumcount()
    df["wordpos_in_stim"] = np.arange(len(tokens))
    df["seq_len"] = df.groupby("sequ_index")["word_raw"].transform(len)
    positions = np.arange(len(df))
    df["is_eof"] = positions == len(df) - 1
    df["is_bof"] = positions == 0
    df["word_raw"] = df["word_raw"].fillna("")
    df["word"] = df["word"].fillna("")
def _get_task_onsets():
    """Return, for each audio task, the number of TRs before the audio starts."""
    onset_groups = (
        (0, ("21styear", "milkywayoriginal", "milkywaysynonyms",
             "milkywayvodka", "prettymouth", "pieman", "schema")),
        (8, ("piemanpni", "bronx", "black", "forgot")),
        (3, ("slumlordreach", "shapessocial", "shapesphysical", "sherlock",
             "merlin", "notthefallintact", "notthefallshortscram",
             "notthefalllongscram")),
        # lucy: 1 in events.tsv, 2 in paper
        (2, ("lucy", "tunnel")),
    )
    return {task: n_trs for n_trs, tasks in onset_groups for task in tasks}
def _fix_stimulus(
    stimulus,
    task,
    tasks_with_issues=("notthefallintact", "prettymouth", "merlin"),
    new_starts=([25.8], [21], [29, 29.15]),
):
    """Overwrite, in place, the first word timings of mis-aligned stories.

    When ``task`` is in ``tasks_with_issues``, the k-th row of ``stimulus``
    gets ``onset = v`` and ``offset = v + 0.1``. Here ``v`` is the k-th
    value of the matching entry of ``new_starts``. Other tasks are left
    untouched.
    """
    if task not in tasks_with_issues:
        return
    replacement_onsets = new_starts[tasks_with_issues.index(task)]
    for position, onset in enumerate(replacement_onsets):
        label = stimulus.index[position]
        stimulus.loc[label, "onset"] = onset
        stimulus.loc[label, "offset"] = onset + 0.1
def format_text(text, lower=True):
    """Normalize transcript punctuation so that sentence ends are explicit.

    A fixed, ordered list of substring replacements is applied, in this order:
    - newlines become spaces;
    - dashes, quotes, colons and ellipses become ". " or ", ";
    - hyphens become spaces.

    The result is lower-cased when ``lower`` is True. Order matters: each
    rule sees the output of the previous ones.
    """
    replacements = (
        ("\n", " "),
        (" -- ", ". "),
        (" – ", ", "),
        ("–", "-"),
        (' "', ". "),
        (' "', ". "),  # NOTE(review): identical to the previous rule (no-op)
        ('" ', ". "),
        ('". ', ". "),
        ('." ', ". "),
        ("?. ", "? "),
        (",. ", ", "),
        ("...", ". "),
        (".. ", ". "),
        (":", ". "),
        ("…", ". "),
        ("-", " "),
        (" ", " "),  # NOTE(review): space -> space is a no-op as written
    )
    for old, new in replacements:
        text = text.replace(old, new)
    return text.lower() if lower else text
def match_transcript_tokens(transcript_tokens, gentle_tokens):
    """Map each gentle token onto the transcript token that contains it.

    Parameters
    ----------
    transcript_tokens : list of (token, start, end)
        Tokens of the transcript, e.g. from ``space_tokenizer``.
    gentle_tokens : list of (word, start, end)
        Tokens from ``gentle_tokenizer`` on the same text.

    Returns
    -------
    list
        For every gentle token, the transcript token whose start offset is
        the closest one at or before the middle of the gentle span. The
        first transcript token is used when none starts early enough.
    """
    starts = np.array([token[1] for token in transcript_tokens])  # begin of each word
    matched = []
    for _, start, end in gentle_tokens:
        middle = (start + end) / 2
        diff = middle - starts  # new array; no copy needed
        # Ignore transcript tokens that start after the middle. ``np.Inf``
        # was removed in NumPy 2.0; ``np.inf`` is the canonical spelling.
        diff[diff < 0] = np.inf
        matched.append(transcript_tokens[int(np.argmin(diff))])
    return matched
def replace_special_character_chains(text):
    """Turn hyphens into spaces, then add the missing space in 'laughs:"You'."""
    without_hyphens = text.replace("-", " ")
    return without_hyphens.replace('laughs:"You', 'laughs: "You')
def gentle_tokenizer(raw_sentence):
    """Split text into ``(word, start, end)`` word tokens.

    A token is a run of word characters, possibly containing an apostrophe
    (' or ’) followed by a word character. ``end`` is exclusive. The token
    count is asserted against gentle's alignment in ``_preproc_stim``.
    """
    word_pattern = r"(\w|\’\w|\'\w)+"
    return [
        (match.group(), *match.span())
        for match in re.finditer(word_pattern, raw_sentence, re.UNICODE)
    ]
def split_with_index(s, c=" "):
    """Yield ``(start, end)`` for each maximal run of elements of ``s`` != ``c``.

    ``end`` is exclusive, so ``s[start:end]`` is the run. Runs of
    separators produce nothing.
    """
    run_start = None
    position = -1
    for position, char in enumerate(s):
        if char == c:
            if run_start is not None:
                yield run_start, position
                run_start = None
        elif run_start is None:
            run_start = position
    if run_start is not None:
        yield run_start, position + 1
def format_tokens(x, lower=False):
    """Apply ``format_text`` to every token of ``x`` and strip the result.

    Each token is padded with one space on both sides first (presumably so
    that rules expecting surrounding spaces also apply at token edges).
    Returns a numpy array shaped like ``np.array(x)``.
    """
    tokens = np.array(x)
    formatted = []
    for token in tokens.reshape(-1):
        padded = " " + token + " "
        formatted.append(format_text(padded, lower=lower).strip())
    return np.array(formatted).reshape(tokens.shape)
def space_tokenizer(text):
    """Split ``text`` on runs of spaces into ``(token, start, end)`` tuples.

    ``end`` is exclusive, so ``text[start:end] == token``.
    """
    tokens = []
    for start, end in split_with_index(text, c=" "):
        tokens.append((text[start:end], start, end))
    return tokens
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment