Skip to content

Instantly share code, notes, and snippets.

@Redchards
Created September 12, 2019 12:56
Show Gist options
  • Save Redchards/a9a2c491b4bbe7fa3cb81632915b6fa0 to your computer and use it in GitHub Desktop.
Save Redchards/a9a2c491b4bbe7fa3cb81632915b6fa0 to your computer and use it in GitHub Desktop.
class EMDEHGDataset(Dataset):
    """Entropy features of the EMD/Hilbert spectrum of TPEHG recordings.

    Every record found under ``data/tpehgdb`` is read with ``wfdb.rdrecord``.
    For each selected channel the signal is peak-normalised and decomposed
    with EMD, keeping at most ``n_imfs`` IMFs. Each IMF is summarised by two
    Shannon entropies: one of the histogram of its Hilbert amplitude
    envelope, and one of the histogram of its instantaneous frequency
    (``inst_freq``). The feature matrix is then log-transformed and
    standardised column-wise across the dataset.

    Items are ``(features, label)`` pairs. ``label`` is 1 when the number
    parsed from ``record.comments[2]`` is below 37 and 0 otherwise.
    NOTE(review): presumably gestation in weeks (preterm vs. term); confirm
    against the record header format.
    """

    # Four filter offsets exist (None, 'f1', 'f2', 'f3'). Each channel must
    # therefore own four consecutive columns of ``p_signal`` for every
    # (channel, filter) pair to select a distinct column. With a stride of 3,
    # (channel 0, 'f3') and (channel 1, None) would hit the same column.
    # NOTE(review): this matches a 3 channels x 4 variants = 12-signal
    # layout; confirm against a record's ``sig_name``.
    _SIGNALS_PER_CHANNEL = 4
    _FILTER_OFFSETS = {None: 0, 'f1': 1, 'f2': 2, 'f3': 3}

    def __init__(self, log_eps=1e-6, filt=None, channels=(2,), truncate=True,
                 average_time=False, max_scale=None, wavelet_per_octave=None,
                 n_imfs=10, amp_bins=20, freq_bins=100):
        """Load every record and compute its EMD entropy features and label.

        Args:
            log_eps: Added to ``|features|`` before the log transform.
            filt: ``None`` for the unfiltered signal, or one of ``'f1'``,
                ``'f2'``, ``'f3'`` to select a filtered variant.
            channels: Channels to use, each in ``{0, 1, 2}``. With several
                channels, their feature vectors are concatenated in order.
            truncate: If true, every record is cut to the shortest record
                length. Otherwise shorter records are zero-padded (centred)
                to the longest record length.
            average_time: If true, ``X`` is averaged over its last dimension.
                NOTE(review): ``X`` is 2-D (records x features) here, so this
                averages across features; confirm this is intended.
            max_scale, wavelet_per_octave: Stored as attributes only; not
                used by this class.
            n_imfs: Maximum number of IMFs kept per signal.
            amp_bins, freq_bins: Number of histogram bins for the amplitude
                and instantaneous-frequency entropies.

        Raises:
            ValueError: On an invalid ``filt`` or ``channels`` value, or when
                records yield different numbers of IMFs, so that their
                feature vectors cannot be stacked.
        """
        super().__init__()
        channels = tuple(channels)
        if not channels or any(c not in (0, 1, 2) for c in channels):
            raise ValueError('Invalid channel number, can only be in {0, 1, 2}')
        if filt is not None and filt not in ('f1', 'f2', 'f3'):
            raise ValueError('Invalid filter option, must be either None or of f1, f2, f3')
        offset = self._FILTER_OFFSETS[filt]

        self.max_scale = max_scale
        self.wavelet_per_octave = wavelet_per_octave
        self.log_eps = log_eps

        folder = os.path.join('data', 'tpehgdb')
        # Sorting makes the sample order reproducible. Iteration order of a
        # set of strings depends on per-process hash seeding.
        record_names = sorted({os.path.splitext(s)[0] for s in os.listdir(folder)})
        channel_selector = [c * self._SIGNALS_PER_CHANNEL + offset for c in channels]
        raw_records = [wfdb.rdrecord(os.path.join(folder, name)) for name in record_names]

        lengths = [record.p_signal.shape[0] for record in raw_records]
        self.timesteps = min(lengths) if truncate else max(lengths)

        # One (C, timesteps) array per record. Every record is kept, so
        # signals and labels stay aligned.
        signals = [self._load_signal(record, channel_selector, self.timesteps)
                   for record in raw_records]
        self.y = torch.LongTensor([self._label(record) for record in raw_records])

        # TODO: There's a lot of noise at the beginning and the end of most
        # EMDs, handle that more gracefully.
        features = [
            np.concatenate([
                self._emd_entropy_features(channel, n_imfs, amp_bins, freq_bins)
                for channel in signal
            ])
            for signal in tqdm(signals)
        ]
        feature_shapes = {f.shape for f in features}
        if len(feature_shapes) > 1:
            raise ValueError(
                'Records produced different numbers of IMFs (feature shapes %s); '
                'cannot stack them into one matrix' % sorted(feature_shapes))

        X = torch.from_numpy(np.stack(features).astype(np.float32))
        X = torch.log(torch.abs(X) + log_eps)
        # Clamp std so a constant feature column becomes 0 instead of NaN.
        std = X.std(dim=0).clamp_min(torch.finfo(X.dtype).eps)
        X = (X - X.mean(dim=0)) / std
        if average_time:
            X = X.mean(dim=-1)
        self.X = X
        self.shape = self.X.shape

    @staticmethod
    def _peak_normalize(values):
        """Divide by the maximum absolute value; all-zero input is returned unscaled."""
        peak = np.max(np.abs(values))
        return values / peak if peak > 0 else values

    @classmethod
    def _load_signal(cls, record, channel_selector, timesteps):
        """Return the selected columns as a peak-normalised float32 (C, timesteps) array.

        Longer records are truncated to ``timesteps``; shorter ones are
        zero-padded symmetrically (centred).
        """
        # Fancy indexing copies, so ``record.p_signal`` is never modified.
        x = cls._peak_normalize(record.p_signal[:timesteps, channel_selector].T)
        out = np.zeros((len(channel_selector), timesteps), dtype=np.float32)
        start = (timesteps - x.shape[1]) // 2
        out[:, start:start + x.shape[1]] = x
        return out

    @staticmethod
    def _label(record):
        """Return 1 if the number in the third header comment is below 37, else 0."""
        return 1 if float(record.comments[2].split(' ')[1]) < 37 else 0

    @classmethod
    def _histogram_entropy(cls, values, bins):
        """Return the Shannon entropy of the density histogram of peak-normalised values."""
        hist = np.histogram(cls._peak_normalize(values), bins=bins, density=True)[0]
        return scipy.stats.entropy(hist)

    @classmethod
    def _emd_entropy_features(cls, signal, n_imfs, amp_bins, freq_bins):
        """Return amplitude entropies followed by instantaneous-frequency entropies of a 1-D signal's IMFs."""
        imfs = EMD(cls._peak_normalize(signal)).decompose()[:n_imfs]
        analytic = scipy.signal.hilbert(imfs)
        # Envelope: sqrt(imf^2 + H[imf]^2), where H[imf] is the imaginary part
        # of the analytic signal.
        amplitude = np.sqrt(imfs * imfs + np.imag(analytic) ** 2)
        e_amp = [cls._histogram_entropy(a, amp_bins) for a in amplitude]
        e_freq = [cls._histogram_entropy(inst_freq(h)[0], freq_bins) for h in analytic]
        return np.concatenate((e_amp, e_freq))

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair at ``idx``."""
        return self.X[idx], self.y[idx]

    def __len__(self):
        """Return the number of records."""
        return len(self.y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment