Skip to content

Instantly share code, notes, and snippets.

@emadeldeen24
Last active May 7, 2024 02:24
Show Gist options
  • Save emadeldeen24/53706e06313efa5618ee614cf1b8c899 to your computer and use it in GitHub Desktop.
Save emadeldeen24/53706e06313efa5618ee614cf1b8c899 to your computer and use it in GitHub Desktop.
SHHS1_preprocessing.py
import os
import torch
import numpy as np
from mne.io import read_raw_edf
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
import logging
logging.getLogger('mne').setLevel(logging.WARNING)
import xml.etree.ElementTree as ET
# Length of one epoch, in seconds.
EPOCH_SEC_SIZE = 30

# Annotation "EventConcept" text -> integer class label.
# Stage 3 and Stage 4 intentionally share class 3; any concept not listed
# here is skipped by the label parser.
ann2label = {
    "Wake|0": 0,
    "Stage 1 sleep|1": 1,
    "Stage 2 sleep|2": 2,
    "Stage 3 sleep|3": 3,
    "Stage 4 sleep|4": 3,
    "REM sleep|5": 4,
}
# Input EDF recordings, their NSRR XML annotations, and the output directory
# for the per-subject ``.pt`` files.
data_dir = "/mnt/data/emad/shhs/polysomnography/edfs/shhs1/"
ann_dir = "/mnt/data/emad/shhs/polysomnography/annotations-events-nsrr/shhs1/"
save_dir = "/mnt/data/emad/shhs/shhs1_pt/"


def _list_files(directory):
    """Return the names of the non-directory entries directly inside *directory*.

    Raises:
        FileNotFoundError: if *directory* does not exist.
    """
    with os.scandir(directory) as entries:
        return [entry.name for entry in entries if not entry.is_dir()]


def _paths_with_extension(directory, names, extension):
    """Return full paths for the *names* whose extension equals *extension*."""
    return [
        os.path.join(directory, name)
        for name in names
        if os.path.splitext(name)[1] == extension
    ]


filenames = _list_files(data_dir)
annotation = _list_files(ann_dir)
edf_fnames = _paths_with_extension(data_dir, filenames, '.edf')
ann_fnames = _paths_with_extension(ann_dir, annotation, '.xml')

# Check already preprocessed files.
# The output directory is created up front so that a first run (nothing saved
# yet) neither fails here nor later when the results are written.
os.makedirs(save_dir, exist_ok=True)
done_subjects = _list_files(save_dir)
# Saved files are named "<prefix>_<subject id>.<ext>"; keep only the id part.
ids = sorted(os.path.splitext(name)[0].split("_")[-1] for name in done_subjects)

# Filter on the file name only: the directory part of every path already
# contains "shhs", so testing the full path would never exclude anything.
ann_fnames = [path for path in ann_fnames if "shhs" in os.path.basename(path)]

edf_fnames.sort()
ann_fnames.sort()
edf_fnames = np.asarray(edf_fnames)
ann_fnames = np.asarray(ann_fnames)
# Module-level epoch -> label map, kept as a name for existing references.
# get_labels builds and returns its own dictionary on every call.
epoch_label_map = {}


def get_labels(ann_fname):
    """Parse one annotation XML file into a ``{(start, end): label}`` mapping.

    Every ``ScoredEvent`` element is inspected positionally: child 1 is the
    event concept, child 2 the start time and child 3 the duration. Only
    events whose concept is a key of ``ann2label`` are kept.

    Args:
        ann_fname: path to the annotation XML file.

    Returns:
        A new dict mapping ``(start_timestamp, end_timestamp)`` (floats, with
        ``end = start + duration``) to the integer stage label. Entries appear
        in document order; a repeated span keeps the last label seen.

    Raises:
        xml.etree.ElementTree.ParseError: if the file is not well-formed XML.
        ValueError / TypeError: if a stage event has a non-numeric or empty
            start or duration.
    """
    stage_spans = {}
    root = ET.parse(ann_fname).getroot()
    for event in root.iter('ScoredEvent'):
        # Events with fewer than four children cannot carry a start/duration.
        if len(event) < 4:
            continue
        concept = event[1].text
        # Check the concept first so that unrelated events with unusual
        # start/duration fields cannot abort the whole file.
        if concept not in ann2label:
            continue
        start_timestamp = float(event[2].text)
        end_timestamp = start_timestamp + float(event[3].text)
        stage_spans[(start_timestamp, end_timestamp)] = ann2label[concept]
    return stage_spans
def _split_into_epochs(signal, stage_spans, sampling_rate, epoch_sec=EPOCH_SEC_SIZE):
    """Cut the scored stretches of *signal* into fixed-length labelled epochs.

    Args:
        signal: 2-D array indexed as (sample, channel).
        stage_spans: mapping of ``(start_sec, end_sec)`` to an integer label.
        sampling_rate: samples per second of *signal*.
        epoch_sec: epoch length in seconds.

    Returns:
        ``(x, y)`` where ``x`` has shape (n_epochs, n_channels,
        samples_per_epoch) and ``y`` has shape (n_epochs,).

    Raises:
        ValueError: if no span contains at least one complete epoch.
    """
    samples_per_epoch = int(round(epoch_sec * sampling_rate))
    segments = []
    labels = []
    for (start_sec, end_sec), stage in stage_spans.items():
        # Convert span timestamps to sample indices.
        start_sample = int(start_sec * sampling_rate)
        end_sample = int(end_sec * sampling_rate)
        chunk = signal[start_sample:end_sample]
        n_epochs = len(chunk) // samples_per_epoch
        if n_epochs == 0:
            continue
        # Drop a trailing partial epoch (e.g. a span running past the end of
        # the recording) instead of failing the whole subject.
        chunk = chunk[: n_epochs * samples_per_epoch]
        segments.append(chunk.reshape(n_epochs, samples_per_epoch, -1))
        labels.extend([stage] * n_epochs)
    if not segments:
        raise ValueError("no complete scored epochs found")
    # (epoch, sample, channel) -> (epoch, channel, sample)
    return np.concatenate(segments).transpose(0, 2, 1), np.array(labels)


# Pair recordings with annotations by subject id rather than by position in
# the sorted lists, so a missing file on either side cannot misalign the rest.
ann_by_subject = {}
for ann_path in ann_fnames:
    name_parts = os.path.basename(ann_path).split("-")
    if len(name_parts) >= 2:
        ann_by_subject[name_parts[-2]] = ann_path

done_ids = set(ids)  # O(1) "already processed" checks
select_ch = ['EEG(sec)', 'ECG', 'EMG', 'EOG(L)', 'EOG(R)', 'EEG']

for edf_path in edf_fnames:
    # Reset the module-level map so labels from a previous subject can never
    # leak into this one.
    epoch_label_map = {}
    subject_id = os.path.basename(edf_path).split("-")[-1].split(".")[0]
    if subject_id in done_ids:
        continue
    ann_path = ann_by_subject.get(subject_id)
    if ann_path is None:
        print(f"####### NO ANNOTATION FOR SUBJECT {subject_id} #########")
        continue

    raw = None
    try:
        print(f"Preprocessing subject: {subject_id}")
        raw = read_raw_edf(str(edf_path), preload=True, stim_channel=None, verbose=None)
        sampling_rate = raw.info['sfreq']
        # NOTE(review): `scaling_time` is only accepted by older MNE releases;
        # confirm against the installed version.
        raw_ch_df = raw.to_data_frame(scaling_time=sampling_rate)[select_ch]
        raw_ch = raw_ch_df.values

        epoch_label_map = get_labels(ann_path)
        x, y = _split_into_epochs(raw_ch, epoch_label_map, sampling_rate)
        if len(x) != len(y):
            raise ValueError("epoch/label count mismatch")

        data_save = {
            "samples": torch.from_numpy(x).float(),
            "labels": torch.from_numpy(y),
        }
        torch.save(data_save, os.path.join(save_dir, f"shhs1_{subject_id}.pt"))
        print(f" ---------- Done with Subject {subject_id} ---------")
    except Exception as exc:
        # Best effort: report the failure with its cause and move on to the
        # next subject. KeyboardInterrupt/SystemExit still propagate.
        print(f"####### ISSUE WITH SUBJECT {subject_id} #########")
        print(f"    {type(exc).__name__}: {exc}")
    finally:
        # Release the recording whether or not processing succeeded.
        if raw is not None:
            raw.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment