Skip to content

Instantly share code, notes, and snippets.

@menon92
Created September 23, 2019 19:46
Show Gist options
  • Save menon92/67e3830524998f14f3e50b356fa61c40 to your computer and use it in GitHub Desktop.
Save menon92/67e3830524998f14f3e50b356fa61c40 to your computer and use it in GitHub Desktop.
# coding:utf-8
import cv2
import numpy as np
def WaveletTransformAxisY(img):
    """Split *img* along axis 0 into a low- and a high-frequency band.

    Rows are taken in adjacent pairs (0, 1), (2, 3), ...; when the row count
    is odd the last row has no partner and is dropped. Any further axes
    (columns, channels) are processed element-wise.

    Args:
        img: NumPy array (or array-like) with at least one dimension.
            Integer and boolean inputs are promoted to float64 before any
            arithmetic.

    Returns:
        ``(size, c, d)``:
          * size -- number of row pairs, ``img.shape[0] // 2``; also the
            length of axis 0 of ``c`` and ``d``.
          * c -- mean of each row pair (low-frequency band).
          * d -- absolute difference of each row pair (high-frequency band).
            NOTE(review): the sign of the difference is discarded, so the
            split cannot be inverted from (c, d) alone.
    """
    img = np.asarray(img)
    # Integer inputs (e.g. uint8 images) would wrap around in the sum and
    # difference below: uint8(200) + uint8(100) == 44, uint8(10) - uint8(20) == 246.
    if img.dtype.kind in "biu":
        img = img.astype(np.float64)
    size = img.shape[0] // 2
    # Rows 0, 2, 4, ... and rows 1, 3, 5, ..., cut to the same length so an
    # unpaired trailing row is ignored.
    first = img[0:2 * size:2]
    second = img[1:2 * size:2]
    c = (first + second) / 2.
    d = abs(first - second)
    return size, c, d
def WaveletTransformLLAxisY(img):
    """Return only the low-frequency band of a row-pair split of *img*.

    Adjacent rows (0, 1), (2, 3), ... are averaged; when the row count is odd
    the last row has no partner and is dropped. Any further axes are
    processed element-wise.

    Args:
        img: NumPy array (or array-like) with at least one dimension.
            Integer and boolean inputs are promoted to float64 first.

    Returns:
        Array with ``img.shape[0] // 2`` rows holding the pairwise means.
    """
    img = np.asarray(img)
    # Integer inputs (e.g. uint8 images) would wrap around in the sum below:
    # uint8(250) + uint8(250) == 244.
    if img.dtype.kind in "biu":
        img = img.astype(np.float64)
    pairs = img.shape[0] // 2
    upper = img[0:2 * pairs:2]
    lower = img[1:2 * pairs:2]
    return (upper + lower) / 2.
def WaveletTransformAxisX(img):
    """Column-direction counterpart of WaveletTransformAxisY.

    Axes 0 and 1 are swapped so that columns become rows, the row-pair split
    of WaveletTransformAxisY is applied, and the two output bands are swapped
    back. Only the first two axes are exchanged, so a trailing channel axis,
    e.g. shape (rows, cols, 3), stays in place. (``img.T`` would reverse every
    axis and pair up channels instead of columns.)

    Args:
        img: NumPy array with rows on axis 0 and columns on axis 1.

    Returns:
        ``(size, dst_L, dst_H)``: ``size`` is whatever WaveletTransformAxisY
        reports for the swapped input (derived from the column count);
        ``dst_L`` / ``dst_H`` are the column-pair low / high bands in the
        input's (row, col, ...) orientation.
    """
    size, dst_L, dst_H = WaveletTransformAxisY(np.swapaxes(img, 0, 1))
    return size, np.swapaxes(dst_L, 0, 1), np.swapaxes(dst_H, 0, 1)
def WaveletTransformLLAxisX(img):
    """Return only the low-frequency band of a column-pair split of *img*.

    Axes 0 and 1 are swapped so that columns become rows, the pairwise mean
    of WaveletTransformLLAxisY is applied, and the result is swapped back.
    Only the first two axes are exchanged, so a trailing channel axis stays
    in place (``img.T`` would reverse every axis and pair up channels).

    Args:
        img: NumPy array with rows on axis 0 and columns on axis 1.

    Returns:
        Array in the input's (row, col, ...) orientation whose columns are
        the means of adjacent column pairs.
    """
    dst_L = WaveletTransformLLAxisY(np.swapaxes(img, 0, 1))
    return np.swapaxes(dst_L, 0, 1)
def WaveletTransform(img, n=1):
    """Apply *n* levels of the row/column band split to *img*.

    At each level the current region is split along rows by
    WaveletTransformAxisY into L and H bands, and each of those is split
    along columns by WaveletTransformAxisX. The resulting LL band is the
    input of the next level.

    Args:
        img: NumPy array with rows on axis 0 and columns on axis 1.
        n: number of levels; ``n <= 0`` returns an empty dict.

    Returns:
        dict mapping ``"<band>_<level>"`` to the band array, where band is
        one of LL, LH, HL, HH and level counts from 1. The first letter is
        the row-direction band, the second the column-direction band.
    """
    wavelets = {}
    roi = img
    for level in range(1, n + 1):
        _, band_L, band_H = WaveletTransformAxisY(roi)
        _, band_LL, band_LH = WaveletTransformAxisX(band_L)
        _, band_HL, band_HH = WaveletTransformAxisX(band_H)
        for name, band in (("LL", band_LL), ("LH", band_LH),
                           ("HL", band_HL), ("HH", band_HH)):
            wavelets["%s_%d" % (name, level)] = band
        # Only the low/low band is decomposed further at the next level.
        roi = band_LL
    return wavelets
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment