Skip to content

Instantly share code, notes, and snippets.

@trongan93
Last active October 18, 2020 03:10
Show Gist options
  • Save trongan93/9dcf243299ec630435e0476b753ef294 to your computer and use it in GitHub Desktop.
Save trongan93/9dcf243299ec630435e0476b753ef294 to your computer and use it in GitHub Desktop.
# Reference from **Bidimensional Empirical Mode Decomposition** code by Dawid Laszuk (laszukdawid@gmail.com).
# This version is modified for H-BEMD (sin and cos values), with optimized extrema detection and normalization values.
# By Trong-An Bui (trongan93@gmail.com - http://buitrongan.com)
class BEMD:
    """Bidimensional Empirical Mode Decomposition of a 2-D array.

    The decomposition repeatedly "sifts" the current residue: local minima
    and maxima are located, an envelope is interpolated through each set,
    and the mean of the two envelopes is subtracted. Every sifted component
    is stacked along the first axis of the returned array.

    Extrema positions are exchanged between the methods of this class as a
    pair ``(column_indices, row_indices)``, because
    ``extract_maxima_positions`` reverses the tuple returned by ``np.where``.
    """

    def __init__(self):
        # ProtoIMF related
        # NOTE(review): mse_thr is stored but no stopping decision in this
        # class depends on it.
        self.mse_thr = 0.01
        self.mean_thr = 0.01
        # Single iteration by default, otherwise results are terrible.
        # NOTE(review): bemd() stops once the iteration counter reaches
        # FIXE + 1, so FIXE = 1 allows two sifting passes -- confirm intent.
        self.FIXE = 1
        self.FIXE_H = 0
        self.MAX_ITERATION = 5

    def __call__(self, image, max_imf=-1):
        """Shorthand for :meth:`bemd`."""
        return self.bemd(image, max_imf=max_imf)

    def extract_max_min_spline(self, image, min_peaks_pos, max_peaks_pos):
        """Interpolate the lower and upper envelopes of ``image``.

        Args:
            image: 2-D array the extrema were detected on.
            min_peaks_pos: ``(column_indices, row_indices)`` of the minima.
            max_peaks_pos: ``(column_indices, row_indices)`` of the maxima.

        Returns:
            Tuple ``(min_env, max_env)``; both arrays have ``image.shape``.
        """
        n_rows, n_cols = image.shape[0], image.shape[1]
        # With the default 'xy' indexing, xi holds column indices and yi holds
        # row indices, and both have shape (n_rows, n_cols) == image.shape.
        xi, yi = np.meshgrid(np.arange(n_cols), np.arange(n_rows))

        min_cols, min_rows = min_peaks_pos[0], min_peaks_pos[1]
        max_cols, max_rows = max_peaks_pos[0], max_peaks_pos[1]

        # Positions are (column, row) but arrays are indexed [row, column];
        # indexing the other way round samples the transposed pixel and fails
        # outright on non-square images.
        min_val = np.asarray(image)[min_rows, min_cols]
        max_val = np.asarray(image)[max_rows, max_cols]

        min_env = self.spline_points(min_cols, min_rows, min_val, xi, yi)
        max_env = self.spline_points(max_cols, max_rows, max_val, xi, yi)
        return min_env, max_env

    @classmethod
    def spline_points(cls, X, Y, Z, xi, yi):
        """Creates a spline for given set of points.

        Uses Radial-basis function to extrapolate surfaces. It's not the best
        but gives something. Griddata algorithm didn't work.

        Args:
            X, Y: Coordinates of the known points.
            Z: Values at the known points.
            xi, yi: Coordinates at which the fitted surface is evaluated.

        Returns:
            The fitted surface evaluated at ``(xi, yi)``.
        """
        spline = Rbf(X, Y, Z, function='cubic')
        return spline(xi, yi)

    @classmethod
    def find_extrema_positions(cls, image):
        """Return ``(min_peaks_pos, max_peaks_pos)`` for ``image``.

        Each element is a ``(column_indices, row_indices)`` pair.
        """
        max_peaks_pos = cls.extract_maxima_positions(image)
        min_peaks_pos = cls.extract_minima_positions(image)
        return min_peaks_pos, max_peaks_pos

    @classmethod
    def extract_minima_positions(cls, image):
        """Locate minima as the maxima of the negated image."""
        return cls.extract_maxima_positions(-image)

    @classmethod
    def extract_maxima_positions(cls, image):
        """Locate maxima of ``image``.

        The image lowered by a small offset is used as the seed of a
        ``reconstruction`` by dilation (presumably
        ``skimage.morphology.reconstruction`` -- confirm against the file's
        imports). Pixels where the reconstructed surface stays below the
        original image are reported as maxima.

        Returns:
            The ``np.where`` index tuple in reversed order, i.e.
            ``(column_indices, row_indices)`` for a 2-D image.
        """
        seed_min = image - 0.000001
        dilated = reconstruction(seed_min, image, method='dilation')
        cleaned_image = image - dilated
        return np.where(cleaned_image > 0)[::-1]

    @classmethod
    def end_condition(cls, image, IMFs):
        """Return True when the components in ``IMFs`` sum back to ``image``.

        Args:
            image: The array being decomposed.
            IMFs: Components stacked along axis 0.
        """
        rec = np.sum(IMFs, axis=0)
        return bool(np.allclose(image, rec))

    def check_proto_imf(self, proto_imf, proto_imf_prev, mean_env):
        """Decide whether a sifted candidate can be accepted.

        Args:
            proto_imf: Candidate after the latest sifting pass.
            proto_imf_prev: Candidate before the latest sifting pass.
            mean_env: Mean envelope removed in the latest pass.

        Returns:
            True if any acceptance criterion holds, False otherwise.
        """
        # Mean envelope is flat (within mean_thr) once its own offset is removed.
        if np.all(np.abs(mean_env - mean_env.mean()) < self.mean_thr):
            return True

        # The latest sifting pass barely changed the candidate.
        if np.allclose(proto_imf, proto_imf_prev, rtol=0.01):
            return True

        # Mean absolute value of the candidate is below the threshold.
        if np.mean(np.abs(proto_imf)) < self.mean_thr:
            return True

        # The previous mean-square comparison against mse_thr returned False
        # on both of its branches, so it never influenced the outcome.
        return False

    def bemd(self, image, max_imf=-1):
        """Decompose ``image`` into sifted components plus a residue.

        Args:
            image: 2-D array to decompose.
            max_imf: Maximum number of sifted components to extract; values
                that are not positive disable the limit.

        Returns:
            Array whose first axis indexes the components. If what remains
            after subtracting all components is not close to zero, it is
            appended as the final entry.
        """
        image_s = image.copy()
        imfNo = 0
        IMF = np.empty((imfNo,) + image.shape)

        while True:
            res = image_s - np.sum(IMF[:imfNo], axis=0)
            # saveLogFile is not defined in this class; presumably a helper
            # that dumps the current residue to a CSV file -- confirm.
            saveLogFile('residue_' + str(imfNo) + '.csv',res)
            imf = res.copy()
            stop_sifting = False

            # Counters
            n = 0    # sifting passes for the current component
            n_h = 0  # consecutive passes accepted by check_proto_imf

            while not stop_sifting and n < self.MAX_ITERATION:
                n += 1
                min_peaks_pos, max_peaks_pos = self.find_extrema_positions(imf)

                # Envelopes are only built when there is more than one
                # extremum of each kind; otherwise sifting ends here.
                if len(min_peaks_pos[0]) <= 1 or len(max_peaks_pos[0]) <= 1:
                    break

                min_env, max_env = self.extract_max_min_spline(
                    imf, min_peaks_pos, max_peaks_pos)
                mean_env = 0.5 * (min_env + max_env)
                imf_old = imf.copy()
                imf = imf - mean_env

                if self.FIXE:
                    # Fixed number of passes.
                    if n >= self.FIXE + 1:
                        stop_sifting = True
                elif self.FIXE_H:
                    # Stop after FIXE_H consecutive accepted passes; the
                    # first pass is never evaluated.
                    if n == 1:
                        continue
                    if self.check_proto_imf(imf, imf_old, mean_env):
                        n_h += 1
                    else:
                        n_h = 0
                    # STOP if enough n_h
                    if n_h >= self.FIXE_H:
                        stop_sifting = True
                else:
                    # Default: stop at the first accepted pass.
                    if self.check_proto_imf(imf, imf_old, mean_env):
                        stop_sifting = True

            IMF = np.vstack((IMF, imf[None, :]))
            imfNo += 1

            if self.end_condition(image, IMF) or (max_imf > 0 and imfNo >= max_imf):
                break

        res = image_s - np.sum(IMF[:imfNo], axis=0)
        if not np.allclose(res, 0):
            IMF = np.vstack((IMF, res[None, :]))

        return IMF
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment