Skip to content

Instantly share code, notes, and snippets.

@skoudoro
Last active October 17, 2019 19:39
Show Gist options
  • Save skoudoro/675b39533878e918bdff88c706446959 to your computer and use it in GitHub Desktop.
Save skoudoro/675b39533878e918bdff88c706446959 to your computer and use it in GitHub Desktop.
Test MP-PCA denoising with DIPY
from dipy.reconst.shore import ShoreModel
from dipy.reconst.forecast import ForecastModel
from dipy.data import (fetch_cenir_multib, read_cenir_multib, get_sphere,
fetch_isbi2013_2shell, read_isbi2013_2shell,
fetch_cfin_multib, read_cfin_dwi)
import time
import multiprocessing
import numpy as np
import math
def test_mppca():
    """Denoise the CFIN multi-shell DWI with MP-PCA and plot the results.

    Fetches the CFIN multi-b dataset via ``fetch_cfin_multib`` and keeps only
    the b=0, b=1000 and b=2000 volumes. It then denoises them twice with
    ``mppca``: once with default settings and once with ``use_fast=True``.
    The wall-clock time of each run is printed.

    A 2x3 figure is saved to ``denoised_mppca.png`` and then displayed. The
    top row shows the default run and the bottom row the fast run: original,
    denoised output and absolute residual.

    Note: despite the ``test_`` prefix this is a visual demo with no
    assertions; a pytest run would collect it and trigger the data fetch.
    """
    import matplotlib.pyplot as plt
    # load main pca function using Marcenko-Pastur distribution
    from dipy.denoise.localpca import mppca

    fetch_cfin_multib()
    img, gtab = read_cfin_dwi()
    # ``img.get_data()`` is deprecated (removed in nibabel >= 5);
    # ``np.asanyarray(img.dataobj)`` is its documented drop-in replacement.
    data = np.asanyarray(img.dataobj)

    # Keep only volumes whose b-value is exactly 0, 1000 or 2000.
    sel_b = np.isin(gtab.bvals, (0, 1000, 2000))
    data = data[..., sel_b]
    print(data.shape)

    t0 = time.perf_counter()
    denoised_arr = mppca(data, patch_radius=2)
    print("Time taken for local MP-PCA ", time.perf_counter() - t0)

    t0 = time.perf_counter()
    # NOTE(review): ``use_fast`` does not appear in released dipy's ``mppca``
    # signature. This assumes an experimental build whose fast path returns a
    # tuple with the denoised array first (hence ``[0]`` below) — confirm.
    denoised_arr_fast = mppca(data, patch_radius=2, use_fast=True)
    print("Time taken for fast local MP-PCA ", time.perf_counter() - t0)

    # Middle index along the third axis, last selected volume.
    sli = data.shape[2] // 2
    gra = data.shape[3] - 1
    # Work in float so the residual cannot wrap or overflow if the image is
    # stored as integers (the original squared the difference, which made
    # this worse).
    orig = data[:, :, sli, gra].astype(np.float64)
    den = denoised_arr[:, :, sli, gra]
    den_fast = denoised_arr_fast[0][:, :, sli, gra]
    # sqrt(x ** 2) == |x|: per-voxel absolute residuals.
    residual = np.abs(orig - den)
    residual_fast = np.abs(orig - den_fast)

    fig1, ax = plt.subplots(2, 3, figsize=(12, 6),
                            subplot_kw={'xticks': [], 'yticks': []})
    fig1.subplots_adjust(hspace=0.3, wspace=0.05)
    panels = (
        (orig, 'Original'),
        (den, 'Denoised Output'),
        (residual, 'Residuals'),
        (orig, 'Original'),
        (den_fast, 'Fast Denoised Output'),
        (residual_fast, 'Fast Residuals'),
    )
    for axis, (image, title) in zip(ax.flat, panels):
        axis.imshow(image.T, cmap='gray', interpolation='none',
                    origin='lower')
        axis.set_title(title)
    # Save *before* show(): after the interactive window is closed the figure
    # may be torn down, and a later savefig() can write an empty image.
    fig1.savefig('denoised_mppca.png')
    plt.show()
if __name__ == "__main__":
    # Only run the demo when executed as a script, never on import.
    test_mppca()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment