Skip to content

Instantly share code, notes, and snippets.

@JoeZiminski
Created September 21, 2023 19:20
Show Gist options
  • Save JoeZiminski/aa6ad688910555331b828c13df288cbc to your computer and use it in GitHub Desktop.
Save JoeZiminski/aa6ad688910555331b828c13df288cbc to your computer and use it in GitHub Desktop.
## GENERATE
from spikeinterface import extractors as se
from spikeinterface.preprocessing import bandpass_filter, common_reference, phase_shift, scale, whiten
import spikeinterface.sorters as ss
import numpy as np
import sys
from pathlib import Path
import spikeinterface as si
from spikeinterface.sorters import Kilosort2_5Sorter
from spikeinterface import load_extractor
# Load and Preprocess
# Folder holding the SpikeGLX recording to be sorted.
data_path = Path(r"/ceph/neuroinformatics/neuroinformatics/scratch/jziminski/ephys/code/test_data/mid")
# Folder handed to the sorter as its output_folder further down.
output_path = Path(r"/ceph/neuroinformatics/neuroinformatics/scratch/jziminski/ephys/code/sorter_output")

# Preprocessing chain, applied strictly in this order; each step wraps the
# recording returned by the previous one.
_preprocessing_steps = (
    (phase_shift, {}),
    (bandpass_filter, {"freq_min": 300, "freq_max": 6000}),
    (common_reference, {"operator": "median", "reference": "global"}),
    # Whitened output is stored as int16 with an integer scale of 200; the
    # sorter call below uses the matching scaleproc=200.
    (whiten, {"dtype": np.int16, "mode": "local", "int_scale": 200}),
)

recording = se.read_spikeglx(data_path)
for _step, _step_kwargs in _preprocessing_steps:
    recording = _step(recording, **_step_kwargs)
# Point SpikeInterface at the local Kilosort2.5 (MATLAB) checkout.
Kilosort2_5Sorter.set_kilosort2_5_path("/ceph/neuroinformatics/neuroinformatics/scratch/jziminski/ephys/code/Kilosort-2.5")

# Sort the already-preprocessed recording. Kilosort's own preprocessing is
# skipped (the recording was whitened above) while its drift correction is
# kept on. scaleproc=200 matches the int_scale=200 used when whitening.
# delete_recording_dat=False asks the wrapper to leave recording.dat on disk
# so it can be read back below.
sorting = ss.run_sorter(
    "kilosort2_5",
    recording,
    output_folder=output_path,
    delete_tmp_files=False,
    delete_recording_dat=False,
    skip_kilosort_preprocessing=True,
    do_correction=True,
    scaleproc=200,
)

# Re-load the binary file Kilosort worked on as a recording.
# NOTE(review): this relies on Kilosort2.5 writing its drift-corrected data
# back into recording.dat when skip_kilosort_preprocessing=True -- confirm
# against the SpikeInterface Kilosort2.5 wrapper of the installed version.
# channel_ids are copied from the source recording so both recordings (and
# the waveform extractors built on them) share channel identifiers; this
# assumes recording.dat keeps the source channel order. The data was
# bandpass-filtered before sorting, so it is flagged as filtered.
drift_corrected_recording = si.read_binary(
    output_path / "sorter_output" / "recording.dat",
    sampling_frequency=recording.get_sampling_frequency(),
    dtype=np.int16,
    num_channels=recording.get_num_channels(),
    channel_ids=recording.get_channel_ids(),
    is_filtered=True,
)
# Extract per-unit waveforms for the Kilosort sorting twice: once from the
# preprocessed recording as it was before sorting, and once from the binary
# Kilosort left behind after its drift correction.
#
# return_scaled=False on both: drift_corrected_recording is read from a raw
# int16 binary without any gain_to_uV, so unscaled values are the only units
# both extractors can share. This keeps template amplitudes directly
# comparable instead of one set possibly being gain-scaled and the other not.
#
# overwrite=True on both, so re-running the script replaces earlier output
# instead of erroring because the folder already exists (previously only the
# drift folder was overwritten).
waveforms_orig = si.extract_waveforms(
    recording,
    sorting,
    folder=output_path / "waveforms_orig",
    use_relative_path=True,
    overwrite=True,
    return_scaled=False,
)

# The binary file carries no probe geometry; borrow the channel locations of
# the source recording (assumes recording.dat keeps the same channel order).
drift_corrected_recording.set_channel_locations(recording.get_channel_locations())

waveforms_drift = si.extract_waveforms(
    drift_corrected_recording,
    sorting,
    folder=output_path / "waveforms_drift",
    use_relative_path=True,
    overwrite=True,
    return_scaled=False,
    # Let extraction proceed even if the binary recording is not flagged as
    # filtered.
    allow_unfiltered=True,
)
## VIEW
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
# Folders written by the GENERATE section, presumably reached here through the
# same storage mounted as drive X: on a Windows machine.
drift_path = Path(r"X:\neuroinformatics\scratch\jziminski\ephys\code\sorter_output\waveforms_drift")
orig_path = Path(r"X:\neuroinformatics\scratch\jziminski\ephys\code\sorter_output\waveforms_orig")

drift_templates = np.load(drift_path / "templates_average.npy")
orig_templates = np.load(orig_path / "templates_average.npy")

# Both extractors were built from the same sorting, so the template arrays
# must line up unit-for-unit. Raise rather than assert so the check is not
# stripped when Python runs with -O.
if drift_templates.shape != orig_templates.shape:
    raise ValueError(
        f"Template shapes differ: drift-corrected {drift_templates.shape} "
        f"vs original {orig_templates.shape}"
    )

# One PDF page per unit, overlaying the template with and without drift
# correction. Iterating the arrays walks their first axis (one unit each).
# NOTE(review): assumes templates are (num_units, num_samples, num_channels),
# so axis=1 of each per-unit slice averages across channels -- confirm for the
# SpikeInterface version in use.
with PdfPages("drift_vs_nodrift.pdf") as pdf:
    for unit_index, (drift_unit, orig_unit) in enumerate(zip(drift_templates, orig_templates)):
        fig, ax = plt.subplots()
        ax.plot(np.mean(drift_unit, axis=1), color="tab:orange", label="drift corrected")
        ax.plot(np.mean(orig_unit, axis=1), color="tab:blue", label="no drift correction")
        ax.set_title(f"Unit index {unit_index}")
        ax.legend()
        pdf.savefig(fig)
        # Close each figure explicitly so memory does not grow with unit count.
        plt.close(fig)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment