Created September 21, 2023 19:20
Save JoeZiminski/aa6ad688910555331b828c13df288cbc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
## GENERATE
# Preprocess a SpikeGLX recording with SpikeInterface and sort it with
# Kilosort 2.5. Kilosort's own preprocessing is skipped and its drift
# correction is enabled. Waveforms are then extracted twice from the same
# sorting: once from the SpikeInterface-preprocessed recording, and once from
# the recording.dat binary that Kilosort leaves in its output folder. This
# lets the two sets of templates be compared.
from spikeinterface import extractors as se
from spikeinterface.preprocessing import bandpass_filter, common_reference, phase_shift, scale, whiten
import spikeinterface.sorters as ss
import numpy as np
import sys
from pathlib import Path
import spikeinterface as si
from spikeinterface.sorters import Kilosort2_5Sorter
from spikeinterface import load_extractor

# All inputs and outputs live under one directory on the cluster filesystem.
base_path = Path(r"/ceph/neuroinformatics/neuroinformatics/scratch/jziminski/ephys/code")
data_path = base_path / "test_data" / "mid"
output_path = base_path / "sorter_output"
kilosort_path = base_path / "Kilosort-2.5"

# Whitening writes int16 data scaled by this factor, and the same factor is
# passed to Kilosort as `scaleproc`. Keeping a single constant ensures the two
# can never disagree.
int_scale = 200
data_dtype = np.int16

# Load and preprocess.
recording = se.read_spikeglx(data_path)
recording = phase_shift(recording)
recording = bandpass_filter(recording, freq_min=300, freq_max=6000)
recording = common_reference(recording, operator="median", reference="global")
recording = whiten(recording, dtype=data_dtype, mode="local", int_scale=int_scale)

Kilosort2_5Sorter.set_kilosort2_5_path(str(kilosort_path))
sorting = ss.run_sorter(
    "kilosort2_5",
    recording,
    output_folder=output_path,
    delete_tmp_files=False,
    delete_recording_dat=False,  # recording.dat is read back below
    skip_kilosort_preprocessing=True,
    do_correction=True,
    scaleproc=int_scale,
)

# Read back the binary Kilosort worked on. NOTE(review): with Kilosort
# preprocessing skipped, this file presumably holds the drift-corrected data
# (that is the premise of this comparison). Confirm against the Kilosort
# wrapper version in use.
drift_corrected_recording = si.read_binary(
    output_path / "sorter_output" / "recording.dat",
    sampling_frequency=recording.get_sampling_frequency(),
    dtype=data_dtype,
    num_channels=recording.get_num_channels(),
    # Reuse the original channel ids so both recordings refer to channels
    # the same way.
    channel_ids=recording.get_channel_ids(),
)
# The raw binary carries no probe geometry, so copy it from the original.
drift_corrected_recording.set_channel_locations(recording.get_channel_locations())

# Both extractions overwrite. The sorting above is recomputed on every run,
# so waveforms cached by a previous run would not match it. With
# overwrite=False, an existing folder makes extraction fail.
waveforms_orig = si.extract_waveforms(
    recording,
    sorting,
    folder=output_path / "waveforms_orig",
    use_relative_path=True,
    overwrite=True,
)
waveforms_drift = si.extract_waveforms(
    drift_corrected_recording,
    sorting,
    folder=output_path / "waveforms_drift",
    use_relative_path=True,
    overwrite=True,
    # The binary reader does not know this data was already bandpass-filtered
    # before it was written.
    allow_unfiltered=True,
)
## VIEW
# For every unit, plot the channel-averaged template from the drift-corrected
# waveforms against the one from the original waveforms. Each unit gets one
# page of a single PDF.
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# NOTE(review): presumably the same sorter_output folder the GENERATE step
# writes to, seen through a Windows drive mapping. Confirm the mount.
sorter_output_path = Path(r"X:\neuroinformatics\scratch\jziminski\ephys\code\sorter_output")
drift_path = sorter_output_path / "waveforms_drift"
orig_path = sorter_output_path / "waveforms_orig"

drift_templates = np.load(drift_path / "templates_average.npy")
orig_templates = np.load(orig_path / "templates_average.npy")

# The two template arrays are compared unit-by-unit, so they must line up
# exactly. Raise instead of assert so the check is not stripped under
# `python -O`.
if drift_templates.shape != orig_templates.shape:
    raise ValueError(
        f"Template arrays differ in shape: drift-corrected {drift_templates.shape} "
        f"vs original {orig_templates.shape}"
    )

with PdfPages("drift_vs_nodrift.pdf") as pdf:
    for unit_index, (drift_template, orig_template) in enumerate(
        zip(drift_templates, orig_templates)
    ):
        fig, ax = plt.subplots()
        # axis=1 of each per-unit 2-D slice is averaged away. NOTE(review):
        # presumably that slice is (num_samples, num_channels), so this is the
        # mean over channels. Confirm.
        ax.plot(np.mean(drift_template, axis=1), color="tab:orange", label="drift corrected")
        ax.plot(np.mean(orig_template, axis=1), color="tab:blue", label="no drift correction")
        ax.set_title(f"Template index {unit_index}")
        ax.legend()
        pdf.savefig(fig)
        plt.close(fig)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.